In [None]:
from google.colab import drive
MOUNT_POINT = '/content/drive'  # Drive contents become reachable under this directory
drive.mount(MOUNT_POINT)

In [None]:
cd drive/My \Drive/Acad/ADS/Project/

In [None]:
import torch
import torch.nn as nn

class UNet(nn.Module):
    """U-Net with a sigmoid output, used here for segmentation.

    Encoder: four double-conv stages (64, 128, 256, 512 channels) separated by
    2x2 max-pooling. Decoder: three stages, each upsampling bilinearly by 2,
    reducing channels with a 3x3 conv, concatenating the matching encoder
    feature map (skip connection) and applying a double-conv. A final 1x1 conv
    maps 64 channels to ``out_channels`` and a sigmoid squashes values to (0, 1).

    Args:
        in_channels: number of channels in the input batch (default 3).
        out_channels: number of channels in the output map (default 1).

    Input height and width should be multiples of 8 (three pooling steps);
    otherwise the upsampled maps may not match the skip connections and
    ``torch.cat`` raises a size-mismatch error.
    """

    def __init__(self, in_channels=3, out_channels=1):
        super().__init__()

        # Contracting path
        self.enc1 = self.conv_block(in_channels, 64)
        self.enc2 = self.conv_block(64, 128)
        self.enc3 = self.conv_block(128, 256)
        self.enc4 = self.conv_block(256, 512)

        self.pool = nn.MaxPool2d(2)

        # Expansive path: each dec* takes upsampled channels + skip-connection channels.
        self.up3 = self.upconv_block(512, 256)
        self.dec3 = self.conv_block(512, 256)   # 256 (up3) + 256 (enc3)
        self.up2 = self.upconv_block(256, 128)
        self.dec2 = self.conv_block(256, 128)   # 128 (up2) + 128 (enc2)
        self.up1 = self.upconv_block(128, 64)
        self.dec1 = self.conv_block(128, 64)    # 64 (up1) + 64 (enc1)

        self.out_conv = nn.Conv2d(64, out_channels, kernel_size=1)

    def conv_block(self, in_channels, out_channels):
        """Two 3x3 same-padding convolutions, each followed by ReLU (spatial size kept)."""
        block = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True)
        )
        return block

    def upconv_block(self, in_channels, out_channels):
        """2x bilinear upsampling followed by a 3x3 same-padding conv (no activation)."""
        block = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        )
        return block

    def forward(self, x):
        """Map an (N, in_channels, H, W) batch to an (N, out_channels, H, W) map in (0, 1)."""
        # Contracting path: resolution halves after each pool (H -> H/2 -> H/4 -> H/8).
        e1 = self.enc1(x)
        p1 = self.pool(e1)
        e2 = self.enc2(p1)
        p2 = self.pool(e2)
        e3 = self.enc3(p2)
        p3 = self.pool(e3)
        e4 = self.enc4(p3)

        # Expansive path: upsample, concatenate the same-resolution encoder features, refine.
        up3 = self.up3(e4)
        merge3 = torch.cat([up3, e3], dim=1)
        d3 = self.dec3(merge3)

        up2 = self.up2(d3)
        merge2 = torch.cat([up2, e2], dim=1)
        d2 = self.dec2(merge2)

        up1 = self.up1(d2)
        merge1 = torch.cat([up1, e1], dim=1)
        d1 = self.dec1(merge1)

        out = self.out_conv(d1)
        return torch.sigmoid(out)

In [None]:
'''
Notebook to generate segmentation maps using trained UNet model
'''

import matplotlib.pyplot as plt
%matplotlib inline
import torch
import torch.nn as nn
import numpy as np
from tqdm.notebook import tqdm

device = "cuda"
model = UNet().to(device)
model = torch.load('./Weights/UNet_ckpt_2.pt')

x_test = np.load('./Data/unet/test_images.npy').astype(np.float32)[:20].reshape(-1,5,3,128,128)

out = []
for i in tqdm(range(x_test.shape[0])):

    data = torch.from_numpy(x_test[i].reshape(5,3,128,128))
    data = data.to(device)
    recon = model(data)
    out.append(recon.cpu().detach().numpy().reshape(5,1,128,128))
dataSR = np.asarray(out)
print(dataSR.shape)

# Ground-truth labels for the same first 20 test samples, one (1, 128, 128) map per sample.
x = np.load('./Data/unet/test_labels.npy')[:20].astype(np.float32).reshape(-1,1,128,128)
x_out = dataSR.astype(np.float32)

x_test = x_test.reshape(-1,3,128,128) # input images
x_out = x_out.reshape(-1,1,128,128) # model predictions

# Predictions and labels must line up sample-for-sample before per-sample errors are taken.
assert x_out.shape == x.shape, f"prediction/label shape mismatch: {x_out.shape} vs {x.shape}"

print("Metrics:")
criteria = nn.MSELoss()
# .item() turns each 0-d loss tensor into a plain float before NumPy averages them.
losses = [criteria(torch.from_numpy(pred), torch.from_numpy(label)).item()
          for pred, label in zip(x_out, x)]
print("Average MSE over segmentation samples: " + str('%.5f'%np.average(losses)))

In [None]:
# Visualize samples
import cv2
import os

# Unused in the visible code; kept in case later cells reference these aliases.
dataLR = x_test
dataHR = x

n_samples = min(20, len(x_test))  # at most 20 panels, but never index past the data
os.makedirs('./Results/Samples', exist_ok=True)  # savefig fails if the folder is missing
for i in range(n_samples):
    # One row per sample: input image | predicted map | ground-truth map.
    f, axarr = plt.subplots(nrows=1, ncols=3, figsize=(16, 3))
    # NOTE(review): imshow expects float RGB in [0, 1]; confirm test_images.npy is normalised.
    axarr[0].imshow(x_test[i].transpose(1, 2, 0))  # (3, H, W) -> (H, W, 3)
    axarr[0].set_title('Input Image')
    axarr[1].imshow(x_out[i][0], cmap='gray')
    axarr[1].set_title('Model Output Labels')
    axarr[2].imshow(x[i][0], cmap='gray')
    axarr[2].set_title('Ground Truth Labels')
    f.savefig('./Results/Samples/Sample' + str(i+1) + '.png', format='png', dpi=300)
    plt.close(f)  # release each figure so they don't accumulate in memory