In [51]:
import torch
from torch import nn
from tqdm.auto import tqdm
from torchvision import transforms
from torchvision.utils import make_grid
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

**Mount Google Drive**

In [52]:
from google.colab import drive
# Mount Google Drive at /content/drive so the TIFF files under MyDrive can be read by the data-loading cell.
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


**Helper Functions**

In [53]:
def show_tensor_images(image_tensor, num_images=25, size=(1, 28, 28), nrow=4):
    """Plot the first ``num_images`` images of a tensor as a single grid.

    Args:
        image_tensor: tensor holding one or more images; it is reshaped to
            ``(-1, *size)``. Values are shown as-is (no [-1, 1] -> [0, 1]
            shift); the callers in this notebook pass [0, 1] data (inputs
            divided by 255 and sigmoid outputs).
        num_images: maximum number of images placed in the grid.
        size: ``(channels, height, width)`` of one image.
        nrow: number of images per grid row (default 4).
    """
    # reshape (not view) also works on non-contiguous tensors, e.g. crop() output.
    images = image_tensor.detach().cpu().reshape(-1, *size)
    grid = make_grid(images[:num_images], nrow=nrow)
    # make_grid returns (C, H, W); imshow wants (H, W, C).
    plt.imshow(grid.permute(1, 2, 0).squeeze())
    plt.show()

def crop(image, new_shape):
    """Center-crop the two spatial dimensions of a 4-D tensor.

    Args:
        image: tensor indexed as (batch, channels, height, width).
        new_shape: sequence such as ``torch.Size`` whose entries [2] and [3]
            give the target height and width; entries [0] and [1] are ignored.

    Returns:
        A slice (view) of ``image`` of spatial size new_shape[2] x new_shape[3],
        taken around the center of the image.

    Raises:
        ValueError: if the target height or width exceeds the image's. Without
            this check the start index goes negative and Python slicing wraps,
            silently returning a wrongly sized tensor.
    """
    height, width = image.shape[2], image.shape[3]
    new_height, new_width = new_shape[2], new_shape[3]
    if new_height > height or new_width > width:
        raise ValueError(
            f"cannot crop {height}x{width} image to larger size {new_height}x{new_width}"
        )
    top = height // 2 - new_height // 2
    left = width // 2 - new_width // 2
    return image[:, :, top:top + new_height, left:left + new_width]


**ContractingBlock**

In [54]:
class ContractingBlock(nn.Module):
    """One U-Net encoder stage: (3x3 conv -> ReLU) twice, then 2x2 max-pool.

    The channel count doubles (input_channels -> 2 * input_channels). The
    convolutions are unpadded, so each trims 2 pixels from H and W. The
    stride-2 pool then halves H and W (flooring odd sizes).
    """

    def __init__(self, input_channels):
        super().__init__()
        doubled = 2 * input_channels
        self.conv1 = nn.Conv2d(input_channels, doubled, kernel_size=3)
        self.conv2 = nn.Conv2d(doubled, doubled, kernel_size=3)
        self.activation = nn.ReLU()
        self.max_pool = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        hidden = self.activation(self.conv1(x))
        hidden = self.activation(self.conv2(hidden))
        return self.max_pool(hidden)

**ExpandingBlock**

In [55]:
class ExpandingBlock(nn.Module):
    """One U-Net decoder stage.

    The stage runs these steps in order:
    1. Bilinearly upsample by 2.
    2. Halve the channels with a 2x2 conv, which also trims H and W by 1.
    3. Concatenate the center-cropped skip tensor along the channel axis.
    4. Apply two unpadded 3x3 convs, each followed by ReLU.

    Because conv2 expects ``input_channels`` inputs, the skip tensor must carry
    ``input_channels // 2`` channels. The output has ``input_channels // 2``
    channels.
    """

    def __init__(self, input_channels):
        super().__init__()
        half = input_channels // 2
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.conv1 = nn.Conv2d(input_channels, half, kernel_size=2)
        self.conv2 = nn.Conv2d(input_channels, half, kernel_size=3)
        self.conv3 = nn.Conv2d(half, half, kernel_size=3)
        self.activation = nn.ReLU()

    def forward(self, x, skip_con_x):
        upsampled = self.conv1(self.upsample(x))
        # The encoder feature map is larger; trim it to the decoder's H x W.
        skip = crop(skip_con_x, upsampled.shape)
        merged = torch.cat((upsampled, skip), dim=1)
        merged = self.activation(self.conv2(merged))
        return self.activation(self.conv3(merged))

**Feature Map Block**

In [56]:
class FeatureMapBlock(nn.Module):
    """Pointwise (1x1) convolution mapping input_channels to output_channels.

    H and W are unchanged. The U-Net uses it to lift the input to the hidden
    width and to project the final features to the output channels.
    """

    def __init__(self, input_channels, output_channels):
        super().__init__()
        self.conv = nn.Conv2d(input_channels, output_channels, kernel_size=1)

    def forward(self, x):
        return self.conv(x)

**U-Net Architecture**

In [57]:
class Unet(nn.Module):
    """U-Net mapping an image to per-pixel logits.

    The network is built from three kinds of layers:
    - A 1x1 conv lifts ``input_channels`` to ``hidden_chans``.
    - Four ContractingBlocks double the width at each stage (up to
      16 * hidden_chans).
    - Four ExpandingBlocks halve it again, each merging the matching encoder
      output as a skip connection.

    A final 1x1 conv projects to ``output_channels``. No activation follows it,
    so the output holds raw logits. All 3x3 convs are unpadded, so the output
    is spatially smaller than the input (e.g. 512x512 -> 373x373).
    """

    def __init__(self, input_channels, output_channels, hidden_chans=64):
        super().__init__()
        # Channel widths per level: 1x, 2x, 4x, 8x, 16x hidden_chans.
        widths = [hidden_chans * 2 ** level for level in range(5)]
        self.up_feature = FeatureMapBlock(input_channels, widths[0])
        self.contract1 = ContractingBlock(widths[0])
        self.contract2 = ContractingBlock(widths[1])
        self.contract3 = ContractingBlock(widths[2])
        self.contract4 = ContractingBlock(widths[3])
        self.expanding1 = ExpandingBlock(widths[4])
        self.expanding2 = ExpandingBlock(widths[3])
        self.expanding3 = ExpandingBlock(widths[2])
        self.expanding4 = ExpandingBlock(widths[1])
        self.down_feature = FeatureMapBlock(widths[0], output_channels)

    def forward(self, x):
        # Encoder: keep every intermediate map; the shallower ones become skips.
        features = [self.up_feature(x)]
        for encoder in (self.contract1, self.contract2, self.contract3, self.contract4):
            features.append(encoder(features[-1]))
        # Decoder: start from the bottleneck and consume skips deepest-first.
        out = features.pop()
        for decoder in (self.expanding1, self.expanding2, self.expanding3, self.expanding4):
            out = decoder(out, features.pop())
        return self.down_feature(out)

**Setup Components**

In [58]:
import torch.nn.functional as F
# The U-Net ends in a 1x1 conv with no activation, so it emits logits;
# BCEWithLogitsLoss applies the sigmoid internally.
criterion = nn.BCEWithLogitsLoss()
n_epochs = 200
input_dim = 1        # channels per input image
label_dim = 1        # channels per predicted mask
display_step = 20    # plot input / label / prediction every N optimizer steps
batch_size = 4
lr = 0.0002
initial_shape = 512  # expected input side length
target_shape = 373   # U-Net output side length for a 512x512 input (unpadded convs shrink it)
# Fall back to CPU so the notebook still runs on a runtime without a GPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [62]:
from skimage import io
import numpy as np
# Folder holding the training TIFF stacks on the mounted Drive.
# Assumes train-labels.tif sits next to train-volume.tif.
data_dir = '/content/drive/MyDrive/Generative In Action/ImplicitModel/GAN/src/Unet-data'

# Input slices, divided by 255 (assumes 8-bit TIFFs) and given a channel axis.
volumes = torch.Tensor(io.imread(f'{data_dir}/train-volume.tif'))[:, None, :, :] / 255
# Masks must come from the labels stack. Reading train-volume.tif again would
# train the network to reproduce its own input.
labels = torch.Tensor(io.imread(f'{data_dir}/train-labels.tif', plugin="tifffile"))[:, None, :, :] / 255
# Unpadded convs make the U-Net output smaller than its input, so crop the
# masks to the output size so the loss compares aligned pixels.
labels = crop(labels, torch.Size([labels.shape[0], 1, target_shape, target_shape]))
dataset = torch.utils.data.TensorDataset(volumes, labels)

In [63]:
# Shuffled mini-batches of (image, mask) pairs.
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Network on the configured device, optimized with Adam.
unet = Unet(input_dim, label_dim).to(device)
unet_opt = torch.optim.Adam(unet.parameters(), lr=lr)

# Global optimizer-step counter; the training loop uses it to schedule plots.
cur_step = 0

**Training**

In [65]:
for epoch in range(n_epochs):
    # Name the batch masks `targets` so the loop does not overwrite the
    # dataset-level `labels` tensor created in the data-loading cell.
    for real, targets in tqdm(dataloader):
        real = real.to(device)
        targets = targets.to(device)

        # One optimization step: forward, BCE-with-logits loss, backward, update.
        unet_opt.zero_grad()
        pred = unet(real)
        loss = criterion(pred, targets)
        loss.backward()
        unet_opt.step()

        if cur_step % display_step == 0:
            print(f"Epoch {epoch}: Step {cur_step}: U-Net loss: {loss.item()}")
            # Crop inputs to the output size so they match the masks shown below.
            show_tensor_images(
                crop(real, torch.Size([len(real), 1, target_shape, target_shape])),
                size=(input_dim, target_shape, target_shape),
            )
            show_tensor_images(targets, size=(label_dim, target_shape, target_shape))
            # The network emits logits; sigmoid maps them to [0, 1] for display.
            show_tensor_images(torch.sigmoid(pred), size=(label_dim, target_shape, target_shape))
        cur_step += 1

  0%|          | 0/8 [00:00<?, ?it/s]

RuntimeError: ignored