In [2]:
# Show the GPU, driver and CUDA versions visible to this runtime.
!nvidia-smi

Thu Apr 29 13:35:35 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 465.19.01    Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla V100-SXM2...  Off  | 00000000:00:04.0 Off |                    0 |
| N/A   35C    P0    24W / 300W |      0MiB / 16160MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [None]:
# Mount Google Drive at /content/drive; the dataset and checkpoint paths used
# in later cells all live under /content/drive/MyDrive.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


Model Architecture

In [None]:
# concatenate version of U-net with ResNet-50 encoder
import torchvision
import torch
import torch.nn as nn
from torch.nn import init
from torch.cuda.amp import autocast, GradScaler
import numpy 
from PIL import Image
from numpy import asarray
from torchvision.transforms import ToTensor
from __future__ import print_function
import argparse
import os
import matplotlib.pyplot as plt
import numpy as np 
from torchvision import transforms
import imageio
from imageio import imread
import cv2
import matplotlib.pyplot as plt
# NOTE(review): this module-level ResNet-50 is not referenced by any code visible in
# this notebook (UNetWithResnet50Encoder.__init__ builds its own instance) — candidate
# for removal; confirm nothing outside this view relies on it.
resnet = torchvision.models.resnet.resnet50(pretrained=True)

class ConvBlock(nn.Module):
    """
    Conv2d -> BatchNorm2d -> (optional) ReLU building block.

    :param in_channels: channels of the incoming feature map
    :param out_channels: channels produced by the convolution
    :param padding: convolution padding (default 1)
    :param kernel_size: convolution kernel size (default 3)
    :param stride: convolution stride (default 1)
    :param with_nonlinearity: when False the ReLU is skipped and the
        batch-normalised activations are returned as they are
    """
    def __init__(self, in_channels, out_channels, padding=1, kernel_size=3, stride=1, with_nonlinearity=True):
        super().__init__()
        # Attribute names and creation order define the state_dict keys; keep them stable.
        self.conv = nn.Conv2d(in_channels, out_channels,
                              kernel_size=kernel_size, stride=stride, padding=padding)
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU()
        self.with_nonlinearity = with_nonlinearity

    def forward(self, x):
        normalised = self.bn(self.conv(x))
        return self.relu(normalised) if self.with_nonlinearity else normalised

class out(nn.Module):
    """
    Output head: three stacked convolutions that map decoder features to a
    3-channel image, optionally followed by a ReLU.

    Layout: Conv(in_channels -> out_channels) -> Conv(out_channels -> out_channels)
    -> Conv(out_channels -> 3, kernel 3).

    The intermediate width used to be hard-coded to 32, so the head only worked
    when ``out_channels == 32``. It now follows ``out_channels``; for the
    ``out(64, 32)`` configuration used by the network the layer shapes (and
    therefore existing checkpoints) are unchanged.

    :param in_channels: channels of the incoming feature map
    :param out_channels: width of the two intermediate convolutions
    :param padding: padding shared by all three convolutions
    :param kernel_size: kernel size of the first two convolutions
    :param stride: stride shared by all three convolutions
    :param with_nonlinearity: apply a final ReLU when True (clamps negatives to 0)
    """

    def __init__(self, in_channels, out_channels, padding=1, kernel_size=3, stride=1, with_nonlinearity=True):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, padding=padding, kernel_size=kernel_size, stride=stride)
        self.conv2 = nn.Conv2d(out_channels, out_channels, padding=padding, kernel_size=kernel_size, stride=stride)
        # Final projection to 3 output channels; its kernel size is fixed at 3.
        self.out = nn.Conv2d(out_channels, 3, padding=padding, kernel_size=3, stride=stride)
        self.relu = nn.ReLU()
        self.with_nonlinearity = with_nonlinearity

    def forward(self, x):
        x = self.conv(x)
        x = self.conv2(x)
        x = self.out(x)
        if self.with_nonlinearity:
            x = self.relu(x)
        return x


class Bridge(nn.Module):
    """
    Bottleneck between encoder and decoder: two ConvBlocks applied in sequence,
    the first changing the channel count and the second keeping it.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = [
            ConvBlock(in_channels, out_channels),
            ConvBlock(out_channels, out_channels),
        ]
        self.bridge = nn.Sequential(*stages)

    def forward(self, x):
        bridged = self.bridge(x)
        return bridged


class UpBlockForUNetWithResNet50(nn.Module):
    """
    One decoder step: Upsample -> concatenate skip connection -> ConvBlock -> ConvBlock.

    :param in_channels: channels entering the first ConvBlock, i.e. upsampled
        channels plus skip-connection channels
    :param out_channels: channels produced by the block
    :param up_conv_in_channels: channels entering the upsampling layer
        (defaults to ``in_channels``)
    :param up_conv_out_channels: channels leaving the upsampling layer
        (defaults to ``out_channels``)
    :param upsampling_method: ``"conv_transpose"`` (2x transposed convolution) or
        ``"bilinear"`` (2x bilinear upsample followed by a 1x1 convolution)
    :raises ValueError: if ``upsampling_method`` is not one of the two supported names
    """

    def __init__(self, in_channels, out_channels, up_conv_in_channels=None, up_conv_out_channels=None,
                 upsampling_method="conv_transpose"):
        super().__init__()

        # Fail at construction time; previously an unknown method left
        # `self.upsample` undefined and only crashed inside forward().
        if upsampling_method not in ("conv_transpose", "bilinear"):
            raise ValueError(
                f"Unsupported upsampling_method {upsampling_method!r}; "
                "expected 'conv_transpose' or 'bilinear'"
            )

        if up_conv_in_channels is None:
            up_conv_in_channels = in_channels
        if up_conv_out_channels is None:
            up_conv_out_channels = out_channels

        if upsampling_method == "conv_transpose":
            self.upsample = nn.ConvTranspose2d(up_conv_in_channels, up_conv_out_channels, kernel_size=2, stride=2)
        else:
            # The 1x1 conv must be sized by the upsampling channels, not by the
            # ConvBlock channels: the two differ whenever up_conv_* are overridden.
            self.upsample = nn.Sequential(
                nn.Upsample(mode='bilinear', scale_factor=2),
                nn.Conv2d(up_conv_in_channels, up_conv_out_channels, kernel_size=1, stride=1)
            )
        self.conv_block_1 = ConvBlock(in_channels, out_channels)
        self.conv_block_2 = ConvBlock(out_channels, out_channels)

    def forward(self, up_x, down_x):
        """
        :param up_x: this is the output from the previous up block
        :param down_x: this is the output from the down block
        :return: upsampled feature map
        """
        x = self.upsample(up_x)
        x = torch.cat([x, down_x], 1)  # concatenate along the channel axis
        x = self.conv_block_1(x)
        x = self.conv_block_2(x)
        return x


class UNetWithResnet50Encoder(nn.Module):
    """
    Two-input U-Net with a shared, pretrained ResNet-50 encoder.

    Both inputs go through the same encoder. Their deepest feature maps are
    concatenated along the channel axis (2048 + 2048 = 4096), reduced by the
    bridge, and decoded by five up blocks. Only the skip connections of the
    FIRST input (``x``) feed the decoder; the second input (``y``) contributes
    through the bridge only. The output head produces a 3-channel map.
    """
    DEPTH = 6

    def __init__(self, n_classes=256):
        # NOTE(review): `n_classes` is accepted but not used anywhere in this class.
        super().__init__()
        resnet = torchvision.models.resnet.resnet50(pretrained=True)  # pretrained encoder
        resnet_children = list(resnet.children())

        self.input_block = nn.Sequential(*resnet_children)[:3]  # first three ResNet children
        self.input_pool = resnet_children[3]  # encoder max pooling layer
        # The nn.Sequential children of ResNet are its residual stages.
        self.down_blocks = nn.ModuleList(
            [child for child in resnet_children if isinstance(child, nn.Sequential)]
        )

        self.bridge = Bridge(4096, 2048)

        self.up_blocks = nn.ModuleList([
            UpBlockForUNetWithResNet50(2048, 1024),
            UpBlockForUNetWithResNet50(1024, 512),
            UpBlockForUNetWithResNet50(512, 256),
            UpBlockForUNetWithResNet50(in_channels=128 + 64, out_channels=128,
                                       up_conv_in_channels=256, up_conv_out_channels=128),
            UpBlockForUNetWithResNet50(in_channels=64 + 3, out_channels=64,
                                       up_conv_in_channels=128, up_conv_out_channels=64),
        ])

        self.out = out(64, 32)

    def _encode(self, x, collect_skips):
        """Run the shared encoder; optionally record the skip-connection tensors.

        :return: ``(deepest_features, skips)`` where ``skips`` maps ``"layer_k"``
            to the tensor kept at that depth (empty when ``collect_skips`` is False)
        """
        skips = dict()
        if collect_skips:
            skips["layer_0"] = x  # the raw input is the shallowest skip connection
        x = self.input_block(x)
        if collect_skips:
            skips["layer_1"] = x
        x = self.input_pool(x)

        for i, block in enumerate(self.down_blocks, 2):
            x = block(x)
            # The deepest stage output goes to the bridge, not to a skip connection.
            if collect_skips and i != (UNetWithResnet50Encoder.DEPTH - 1):
                skips[f"layer_{i}"] = x
        return x, skips

    def forward(self, x,y, with_output_feature_map=False):
        """
        :param x: first input batch; its encoder activations are used as skip connections
        :param y: second input batch; only its deepest encoder features are used
        :param with_output_feature_map: also return the last decoder feature map
        :return: the head output, or ``(output, feature_map)`` when requested
        """
        x, pre_pools = self._encode(x, collect_skips=True)
        # Skip tensors of `y` are never consumed by the decoder, so they are not stored.
        y, _ = self._encode(y, collect_skips=False)

        # concatenating the encoder outputs
        x = torch.cat([x,y],1)

        # bridge between encoder and decoder
        x = self.bridge(x)

        # decoder: walk the skip connections from deepest to shallowest
        for i, block in enumerate(self.up_blocks, 1):
            key = f"layer_{UNetWithResnet50Encoder.DEPTH - 1 - i}"
            x = block(x, pre_pools[key])
        output_feature_map = x
        x = self.out(x)

        if with_output_feature_map:
            return x, output_feature_map
        return x

Applying transformations

In [None]:
from skimage.io import imread
from torch.utils import data
from tqdm import tqdm

# converting the data to the tensors 
class SegmentationDataSet(data.Dataset):
    """
    Dataset of aligned (input, target, original) image triples.

    :param inputs: paths of the input images
    :param targets: paths of the target images
    :param originals: paths of the original images
    :param transform: optional callable invoked as ``transform(image=x)['image']``
        on the input only
    :param use_cache: when True, every triple is opened (and pre-transformed)
        once at construction time and served from memory afterwards
    :param pre_transform: optional callable applied to each image while caching
    :raises ValueError: if the three path lists differ in length
    """
    def __init__(self,
                 inputs: list,
                 targets: list,
                 originals: list,
                 transform=None,
                 use_cache=False,
                 pre_transform=None,
                 ):
        # zip() below would silently truncate to the shortest list while
        # __len__ reports len(inputs), which surfaces later as an IndexError.
        if not (len(inputs) == len(targets) == len(originals)):
            raise ValueError(
                "inputs, targets and originals must have the same length, got "
                f"{len(inputs)}, {len(targets)} and {len(originals)}"
            )

        self.inputs = inputs
        self.targets = targets
        self.originals = originals
        self.transform = transform
        self.inputs_dtype = torch.float32
        self.targets_dtype = torch.float32
        self.use_cache = use_cache
        self.pre_transform = pre_transform

        # caching the data to reduce training time
        if self.use_cache:
            self.cached_data = []
            triples = zip(self.inputs, self.targets, self.originals)
            for img_name, tar_name, org_name in tqdm(triples, total=len(self.inputs), desc='Caching'):
                self.cached_data.append((self._load_for_cache(img_name),
                                         self._load_for_cache(tar_name),
                                         self._load_for_cache(org_name)))

    def _load_for_cache(self, path):
        """Open one image, apply ``pre_transform`` if set, and release the file handle."""
        with Image.open(path) as image:
            if self.pre_transform is not None:
                return self.pre_transform(image)
            # copy() forces the pixel data into memory so the image stays
            # usable after the underlying file is closed.
            return image.copy()

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self,
                    index: int):
        if self.use_cache:
            x, y, z = self.cached_data[index]
        else:
            # NOTE(review): this path reads with imread and does not apply
            # `pre_transform`, so it returns different objects than the cached
            # path — confirm whether that is intended.
            x = imread(self.inputs[index])
            y = imread(self.targets[index])
            z = imread(self.originals[index])

        # Preprocessing (input only; target and original are returned as loaded)
        if self.transform is not None:
            x = self.transform(image=x)['image']
        return x, y, z

Training Function

In [None]:
# training functions 
from math import exp

class LogCoshLoss(torch.nn.Module):
  """
  Mean log-cosh loss: mean(log(cosh(yt - yt_prime))).

  Evaluated in the overflow-free form
      log(cosh(e)) = |e| + softplus(-2|e|) - log(2)
  because cosh(e) itself overflows to inf for large |e| (around |e| > 89 in
  float32), which made the naive expression return inf.
  """
  _LOG_2 = 0.6931471805599453  # natural log of 2

  def __init__(self):
    super().__init__()

  def forward(self,yt, yt_prime):
    abs_err = torch.abs(yt - yt_prime)
    stable_log_cosh = abs_err + torch.nn.functional.softplus(-2.0 * abs_err) - self._LOG_2
    return torch.mean(stable_log_cosh)

# Trainer function
class Trainer:
    """
    Mixed-precision training loop with gradient accumulation, best-checkpoint
    saving and early stopping.

    Each batch from the loaders is a triple ``(x, y, z)``: the model is called
    as ``model(x, y)`` and the loss is ``criterion(model(x, y), z)``.

    :param model: network to train
    :param device: device the batches are moved to
    :param criterion: loss module
    :param optimizer: optimizer stepped every ``accum_iter`` batches
    :param training_DataLoader: loader yielding ``(x, y, z)`` training batches
    :param validation_DataLoader: optional loader yielding validation batches;
        without it no checkpoint is written and early stopping is disabled
    :param lr_scheduler: optional scheduler; stepped per optimizer step when
        ``schd_batch_update`` is True, otherwise once per epoch
    :param epochs: maximum number of epochs
    :param accum_iter: number of batches whose gradients are accumulated
        before an optimizer step
    :param verbose_step: progress-bar description refresh interval (batches)
    :param EarlyStopping: stop after this many consecutive epochs without a
        new best validation loss
    :param notebook: use the notebook flavour of tqdm
    :param schd_batch_update: see ``lr_scheduler``
    :param scaler: gradient scaler for mixed precision; a fresh ``GradScaler``
        is created when omitted (previously a module-level ``scaler`` had to exist)
    :param checkpoint_path: where the best model is saved
    """
    def __init__(self,
                 model: torch.nn.Module,
                 device: torch.device,
                 criterion: torch.nn.Module,
                 optimizer: torch.optim.Optimizer,
                 training_DataLoader: torch.utils.data.DataLoader,
                 validation_DataLoader: torch.utils.data.DataLoader = None,
                 lr_scheduler: torch.optim.lr_scheduler = None,
                 epochs: int = 100,
                 accum_iter: int = 1,
                 verbose_step: int = 1,
                 EarlyStopping: int = 5,
                 notebook: bool = False,
                 schd_batch_update: bool =False,
                 scaler=None,
                 checkpoint_path: str = '/content/drive/MyDrive/Checkpoints/StyleTransferByUnet.pth'
                 ):

        self.model = model
        self.criterion = criterion
        self.optimizer = optimizer
        self.lr_scheduler = lr_scheduler
        self.training_DataLoader = training_DataLoader
        self.validation_DataLoader = validation_DataLoader
        self.device = device
        self.epochs = epochs
        self.notebook = notebook
        self.accum_iter = accum_iter
        self.schd_batch_update = schd_batch_update
        self.verbose_step = verbose_step
        self.EarlyStopping = EarlyStopping
        # Own the scaler instead of depending on a global defined in another cell.
        self.scaler = scaler if scaler is not None else GradScaler()
        self.checkpoint_path = checkpoint_path

        self.training_loss = []
        self.validation_loss = []
        self.learning_rate = []
        self.loss_min = float('inf')
        self.epoch = 0

    def _progress_bars(self):
        """Return the ``(tqdm, trange)`` pair matching the ``notebook`` flag."""
        if self.notebook:
            from tqdm.notebook import tqdm, trange
        else:
            from tqdm import tqdm, trange
        return tqdm, trange

    def run_trainer(self):
        """
        Train for up to ``epochs`` epochs.

        :return: ``(training_loss, validation_loss, learning_rate)``, one entry
            per completed epoch (``validation_loss`` stays empty without a
            validation loader)
        """
        _, trange = self._progress_bars()

        not_improving = 0
        for _ in trange(self.epochs, desc='Progress'):
            self.epoch += 1  # epoch counter

            self._train()

            # Checkpointing and early stopping are driven by the validation
            # loss; without a validation loader there is nothing to compare
            # (this used to raise IndexError on the empty loss list).
            if self.validation_DataLoader is None:
                continue

            self._validate()
            not_improving += 1

            latest = self.validation_loss[-1]
            if latest < self.loss_min:
                print(f'val_loss_min ({self.loss_min:.4f} --> {latest:.4f}). Saving model ...')
                self.loss_min = latest
                not_improving = 0
                torch.save({
                    'epoch': self.epoch,
                    'model_state_dict': self.model.state_dict(),
                    'optimizer_state_dict': self.optimizer.state_dict(),
                    'loss': self.criterion,  # note: the criterion module, not a loss value
                     }, self.checkpoint_path)
                print('-----------------------------------------------------------')

            # applying early stopping to the model
            if not_improving == self.EarlyStopping:
                print('Early Stopping...')
                break

        return self.training_loss, self.validation_loss, self.learning_rate

    # training part
    def _train(self):
        """Run one training epoch and append its mean loss and final learning rate."""
        tqdm, _ = self._progress_bars()

        running_loss = None  # exponential moving average, shown in the progress bar
        self.model.train()  # train mode
        train_losses = []  # accumulate the losses here
        n_batches = len(self.training_DataLoader)
        batch_iter = tqdm(enumerate(self.training_DataLoader), 'Training', total=n_batches)
        for i, (x, y, z) in batch_iter:
            inputs, target = x.to(self.device), y.to(self.device)  # send to device (GPU or CPU)
            origin = z.to(self.device)

            with autocast():
                prediction = self.model(inputs, target)  # one forward pass
                # the loss compares the prediction with the third element of the batch
                loss = self.criterion(prediction, origin)

            # NOTE(review): the loss is not divided by accum_iter, so accumulated
            # gradients are a sum (not a mean) over the accumulated batches —
            # confirm this is intended.
            self.scaler.scale(loss).backward()

            loss_value = loss.item()
            if running_loss is None:
                running_loss = loss_value
            else:
                running_loss = running_loss * .95 + loss_value * .05
            train_losses.append(loss_value)

            is_last_batch = (i + 1) == n_batches
            if ((i + 1) % self.accum_iter == 0) or is_last_batch:
                # may unscale_ here if desired (e.g., to allow clipping unscaled gradients)
                self.scaler.step(self.optimizer)  # update the parameters
                self.scaler.update()
                self.optimizer.zero_grad()

                if self.lr_scheduler is not None and self.schd_batch_update:
                    self.lr_scheduler.step()

            if ((i + 1) % self.verbose_step == 0) or is_last_batch:
                batch_iter.set_description(f'epoch {self.epoch} train loss: {running_loss:.4f}')

            del inputs, target, origin
            torch.cuda.empty_cache()

        if self.lr_scheduler is not None and not self.schd_batch_update:
            self.lr_scheduler.step()

        self.training_loss.append(np.mean(train_losses))
        self.learning_rate.append(self.optimizer.param_groups[0]['lr'])

    # validation part
    def _validate(self):
        """Run one pass over the validation loader and append its mean loss."""
        tqdm, _ = self._progress_bars()

        self.model.eval()  # evaluation mode
        valid_losses = []  # accumulate the losses here
        batch_iter = tqdm(enumerate(self.validation_DataLoader), 'Validation',
                          total=len(self.validation_DataLoader))

        for i, (x, y, z) in batch_iter:
            inputs, target = x.to(self.device), y.to(self.device)  # send to device (GPU or CPU)
            origin = z.to(self.device)
            with torch.no_grad():
                prediction = self.model(inputs, target)
                loss_value = self.criterion(prediction, origin).item()
            valid_losses.append(loss_value)

            batch_iter.set_description(f'epoch {self.epoch} validation {loss_value:.4f}')
            del inputs, target, origin
            torch.cuda.empty_cache()

        self.validation_loss.append(np.mean(valid_losses))
        torch.cuda.empty_cache()

Data Pre-processing

In [None]:
# pre processing
from torch.utils.data import DataLoader
import albumentations
from sklearn.model_selection import train_test_split
import pathlib
from torchvision import datasets, models, transforms
import random
# Dataset root directories. Both are absolute paths, so they are built directly
# (joining an absolute path onto the working directory yields the same path).
root = pathlib.Path('/content/drive/MyDrive/Dataset_for_StyleTransferByUnet/training_set')
root1 = pathlib.Path('/content/drive/MyDrive/Dataset_for_StyleTransferByUnet/validation_set')
def get_filenames_of_path(path: pathlib.Path, ext: str = '*'):
    """Return the sorted list of regular files in `path` matching the glob `ext`.

    Side effect: prints the pattern and then the resulting list.
    """
    print(ext)
    filenames = sorted(entry for entry in path.glob(ext) if entry.is_file())
    print(filenames)
    return filenames
# input and target files
# Each call returns a path-sorted list, and samples are later paired by position
# across the three lists — presumably file names match across folders; verify.
inputs = get_filenames_of_path(root / 'train_new')
targets = get_filenames_of_path(root / 'train_new2')
originals = get_filenames_of_path(root / 'train_new1')

# Validation counterparts of the three lists above.
validation_inputs = get_filenames_of_path(root1 / 'input')
validation_targets = get_filenames_of_path(root1 / 'output')
validation_originals = get_filenames_of_path(root1 / 'style')

# resizing the data to 256x256 and converting each image to a tensor
pre_transforms = transforms.Compose([
        transforms.Resize((256,256)),
        transforms.ToTensor(),
    ])

def get_train_transforms():
    """Build the albumentations training augmentation pipeline.

    Two independent stages, each applied with probability 0.66:
    additive Gaussian noise, then one of sharpen / emboss.
    """
    # The notebook imports the package as plain `albumentations`, so the `A`
    # alias used here was undefined and calling this function raised NameError.
    import albumentations as A

    # NOTE(review): the IAA* transforms depend on the installed albumentations
    # version — confirm they exist in the target environment.
    return A.Compose([
            A.OneOf([
                A.IAAAdditiveGaussianNoise(scale=(0.01 * 255, 0.02 * 255)),
            ], p=0.66),
            A.OneOf([
                A.IAASharpen(alpha=(0.1, 0.3)),
                A.IAAEmboss(alpha=(0.1, 0.4)),
            ], p=0.66),
        ], p=1.)

# NOTE(review): `random_seed` and `train_size` are assigned but not used anywhere
# in this cell — no random split happens here; training and validation data come
# from two separate directories.
random_seed = 42
# leftover from a train/validation split that is no longer performed
train_size = 0.8  # 80:20 split

# Training lists are used as-is (plain aliases, no copy).
inputs_train =  inputs 
targets_train = targets 
originals_train = originals 

# Validation lists are used as-is as well.
inputs_valid = validation_inputs 
targets_valid = validation_targets 
originals_valid = validation_originals 


# dataset training: no augmentation (transform=None); every image is opened,
# resized and tensorised once at construction time (use_cache=True)
dataset_train = SegmentationDataSet(inputs=inputs_train,
                                    targets=targets_train,
                                    originals= originals_train,
                                    transform=None,
                                    use_cache=True,
                                    pre_transform=pre_transforms)

# dataset validation (same settings as training)
dataset_valid = SegmentationDataSet(inputs=inputs_valid,
                                    targets=targets_valid,
                                    originals= originals_valid,
                                    transform=None,
                                    use_cache=True,
                                    pre_transform=pre_transforms)

# dataloader training
dataloader_training = DataLoader(dataset=dataset_train,
                                 batch_size=32,
                                 shuffle=True)
# dataloader validation
# NOTE(review): shuffle=True is also set for validation — confirm it is wanted.
dataloader_validation = DataLoader(dataset=dataset_valid,
                                   batch_size=32,
                                   shuffle=True)

# Drop the notebook-level names. The datasets are passed to the DataLoaders
# above, so presumably the cached tensors stay alive through them.
del originals_train
del originals_valid
del pre_transforms
del targets_train
del targets_valid
del inputs_train
del inputs_valid
del dataset_train
del dataset_valid

del inputs
del targets
del originals
import gc
gc.collect()
torch.cuda.empty_cache()

Training the model

In [None]:
# training the model 
device = torch.device('cuda')

model = UNetWithResnet50Encoder().to(device)

# criterion
criterion = torch.nn.MSELoss() #LogCoshLoss()

# optimizer
# Optimise every randomly initialised part of the network: the bridge, the up
# blocks AND the output head. Previously only `model.up_blocks` was handed to
# the optimizer, so the bridge and the head kept their random initial weights
# for the whole run. The pretrained ResNet encoder is still left out.
trainable_modules = (model.bridge, model.up_blocks, model.out)
decoder_parameters = [param for module in trainable_modules for param in module.parameters()]
optimizer = torch.optim.Adam(decoder_parameters, lr=1e-4, weight_decay=1e-6)
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=1, eta_min=1e-6, last_epoch=-1)

# Gradient scaler for mixed-precision training, kept at module level under the name `scaler`.
scaler = GradScaler()
trainer_n = Trainer(model = model,
                  device = device,
                  criterion = criterion,
                  optimizer = optimizer,
                  training_DataLoader = dataloader_training,
                  validation_DataLoader = dataloader_validation,
                  lr_scheduler = lr_scheduler,
                  epochs = 150,
                  accum_iter = 4,
                  verbose_step = 1,
                  EarlyStopping = 30,
                  notebook = True,
                  schd_batch_update = True,
                  )

# start training
training_losses, validation_losses, lr_rates = trainer_n.run_trainer()

# Plot training & validation loss values
plt.plot(training_losses)
plt.plot(validation_losses)
plt.title('Model loss')
plt.ylabel('Loss')
plt.xlabel('Epoch')
# The second curve is the validation loss, so it is labelled as such.
plt.legend(['Train', 'Validation'], loc='upper right')
plt.show()

# Release the training objects before the inference cell builds a fresh model.
del model, optimizer, dataloader_training, dataloader_validation, scaler, lr_scheduler
torch.cuda.empty_cache()

In [None]:
# making predictions from the trained model (generating output images)
device = torch.device('cuda')

# Rebuild the architecture, then restore the weights stored under
# 'model_state_dict' in the checkpoint written by Trainer.run_trainer.
model = UNetWithResnet50Encoder().to(device)
model.load_state_dict(torch.load('/content/drive/MyDrive/Checkpoints/StyleTransferByUnet.pth')['model_state_dict'])

# predict function 
def predict(img,
            tar,
            model,
            preprocess,
            postprocess,
            device,
            ):
    """
    Run the two-input model on one (image, target) pair.

    :param img: first model input, in whatever form `preprocess` accepts
    :param tar: second model input, in whatever form `preprocess` accepts
    :param model: network called as ``model(x, y)``; switched to eval mode here
    :param preprocess: callable turning one input into a tensor without batch axis
    :param postprocess: callable applied to the raw model output
    :param device: device the batched tensors are moved to (previously ignored
        in favour of a hard-coded ``.cuda()``)
    :return: whatever `postprocess` returns
    """
    model.eval()
    x = preprocess(img).unsqueeze(0).to(device)  # add the batch axis
    y = preprocess(tar).unsqueeze(0).to(device)
    with torch.no_grad():
        prediction = model(x, y)  # send through model/network
    return postprocess(prediction)
    
#  making predictions

# Absolute path, so it is built directly instead of being joined onto the cwd.
root = pathlib.Path('/content/drive/MyDrive/Dataset_for_StyleTransferByUnet/validation_set')
def get_filenames_of_path(path: pathlib.Path, ext: str = '*'):
    """Return the sorted list of regular files in `path` matching the glob `ext`.

    Redefines the helper of the same name from the data pre-processing cell,
    without its debug printing.
    """
    return sorted(file for file in path.glob(ext) if file.is_file())
    
# read images and store them in memory
# The three lists are paired by position; this relies on each folder listing being
# sorted — presumably file names match across the three folders; verify.
images = [Image.open(img_name) for img_name in get_filenames_of_path(root / 'input')]
targets = [Image.open(tar_name) for tar_name in get_filenames_of_path(root / 'output')]
origins = [Image.open(org_name) for org_name in get_filenames_of_path(root / 'style')]

# preprocess function
def preprocess(img):
  """Resize `img` to 256x256 and convert it to a tensor (Resize -> ToTensor)."""
  pre_transforms = transforms.Compose([
        transforms.Resize((256,256)),
        transforms.ToTensor(),
    ])
  return pre_transforms(img)


def inverse_normalize(tensor, mean, std):
  """Undo a per-channel normalisation IN PLACE: each slice along the first
  axis is multiplied by its `std` entry and then shifted by its `mean` entry.
  Returns the same (mutated) tensor object.
  """
  for channel, (shift, scale) in zip(tensor, zip(mean, std)):
      channel.mul_(scale)
      channel.add_(shift)
  return tensor


# postprocess function
def postprocess(img: torch.Tensor):
    """
    Convert a model output into a displayable uint8 image array.

    :param img: tensor of shape (1, C, H, W) or (C, H, W); values are clipped
        to [0, 1] before scaling
    :return: numpy array of shape (H, W, C), dtype uint8, values 0..255
        (fractions are truncated, not rounded)
    """
    img = img.squeeze(0).permute(1, 2, 0)  # (C,H,W) -> (H,W,C)
    clipped = img.detach().cpu().clamp(0.0, 1.0).numpy()
    return (clipped * 255).astype(np.uint8)


def preprocess_originals(img,
            preprocess,
            postprocess,
            ):
    """
    Push a reference image through the same pre/post-processing as the model
    outputs, without running any model, so both can be displayed side by side.

    :param img: image in whatever form `preprocess` accepts
    :param preprocess: callable turning the image into a tensor without batch axis
    :param postprocess: callable applied to the batched CPU tensor
    :return: whatever `postprocess` returns
    """
    # No model is involved here; the former `model.eval()` call only created a
    # hidden dependency on a global `model`.
    x = preprocess(img).unsqueeze(0).cpu()
    return postprocess(x)

# Run the model on every (image, target) pair, pairing the two lists by index.
output = [predict(img, targets[i], model, preprocess, postprocess, device) for i,img in enumerate(images)]
# Reference images pushed through the same pre/post-processing, for display.
origs = [preprocess_originals(img, preprocess, postprocess) for i,img in enumerate(origins)]

# One figure per sample: model output on the left, reference on the right.
for i in range(len(output)):
    plt.figure(figsize=(16, 12))

    plt.subplot(1, 2, 1)
    plt.xticks([])
    plt.yticks([])
    plt.title("Output")
    plt.imshow(output[i])

    plt.subplot(1, 2, 2)
    plt.xticks([])
    plt.yticks([])
    plt.title("Original")
    plt.imshow(origs[i])

    plt.show()

# Resize-only transform (no tensor conversion) used for the display copies below.
p_transforms = transforms.Compose([
    transforms.Resize((256,256)),
])

# Build 256x256 copies of the three image lists and drop the full-size lists.
# `gc` is imported in the data pre-processing cell.
image_new = [p_transforms(img_name) for img_name in images]
del images
gc.collect()
torch.cuda.empty_cache()
target_new = [p_transforms(tar_name) for tar_name in targets]
del targets
gc.collect()
torch.cuda.empty_cache()
origin_new = [p_transforms(tar_name) for tar_name in origins]
del origins
gc.collect()
torch.cuda.empty_cache()

gc.collect()
torch.cuda.empty_cache()

Display the Results

In [None]:
from matplotlib import pyplot as plt
import cv2 
# One row of four panels per sample (input, target, original, model output),
# limited to the first 100 samples.
for i, im in enumerate(output):
  if i >= 100:
    continue
  fig, axes = plt.subplots(nrows=1,ncols=4,figsize=(15,15))
  axes[0].imshow(image_new[i])
  axes[1].imshow(target_new[i])
  axes[2].imshow(origin_new[i])
  axes[3].imshow(im)

### Evaluation Metrics

SSIM calculation function


In [None]:
import torch
import torch.nn.functional as F
from torch.autograd import Variable
import numpy as np
from math import exp

def gaussian(window_size, sigma):
    """Return a 1-D Gaussian kernel of length `window_size`, centred on
    index ``window_size // 2`` and normalised so its entries sum to 1."""
    centre = window_size // 2
    two_sigma_sq = float(2 * sigma ** 2)
    weights = torch.Tensor([exp(-(pos - centre) ** 2 / two_sigma_sq) for pos in range(window_size)])
    return weights / weights.sum()

def create_window(window_size, channel):
    """Build the 2-D Gaussian window (sigma = 1.5) used for SSIM.

    The 1-D kernel is turned into a 2-D one via an outer product and expanded
    to shape (channel, 1, window_size, window_size), i.e. one filter per
    channel for a grouped convolution.
    """
    _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
    _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
    # Plain tensor: the deprecated autograd.Variable wrapper is no longer needed.
    return _2D_window.expand(channel, 1, window_size, window_size).contiguous()

def _ssim(img1, img2, window, window_size, channel, size_average = True):
    mu1 = F.conv2d(img1, window, padding = window_size//2, groups = channel)
    mu2 = F.conv2d(img2, window, padding = window_size//2, groups = channel)

    mu1_sq = mu1.pow(2)
    mu2_sq = mu2.pow(2)
    mu1_mu2 = mu1*mu2

    sigma1_sq = F.conv2d(img1*img1, window, padding = window_size//2, groups = channel) - mu1_sq
    sigma2_sq = F.conv2d(img2*img2, window, padding = window_size//2, groups = channel) - mu2_sq
    sigma12 = F.conv2d(img1*img2, window, padding = window_size//2, groups = channel) - mu1_mu2

    C1 = 0.01**2
    C2 = 0.03**2

    ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2))

    if size_average:
        return ssim_map.mean()
    else:
        return ssim_map.mean(1).mean(1).mean(1)

class SSIM(torch.nn.Module):
    """Module wrapper around `_ssim` that caches the Gaussian window.

    The window is rebuilt (and the cache replaced) whenever the channel count
    or the tensor type of the first input differs from the cached window.
    """
    def __init__(self, window_size = 11, size_average = True):
        super(SSIM, self).__init__()
        self.window_size = window_size
        self.size_average = size_average
        self.channel = 1
        self.window = create_window(window_size, self.channel)

    def forward(self, img1, img2):
        _batch, channel, _height, _width = img1.size()

        cache_is_stale = (channel != self.channel
                          or self.window.data.type() != img1.data.type())
        if cache_is_stale:
            fresh = create_window(self.window_size, channel)
            if img1.is_cuda:
                fresh = fresh.cuda(img1.get_device())
            # match dtype/device type of the input before caching
            self.window = fresh.type_as(img1)
            self.channel = channel

        return _ssim(img1, img2, self.window, self.window_size, channel, self.size_average)

def ssim(img1, img2, window_size = 11, size_average = True):
    """Functional SSIM: builds a Gaussian window matching `img1`'s channel
    count, device and type, then delegates to `_ssim`. Inputs must be 4-D."""
    _batch, channel, _height, _width = img1.size()

    base_window = create_window(window_size, channel)
    placed = base_window.cuda(img1.get_device()) if img1.is_cuda else base_window
    window = placed.type_as(img1)

    return _ssim(img1, img2, window, window_size, channel, size_average)

PSNR calculation function

In [None]:
# psnr for pwct results
from matplotlib import pyplot as plt
import cv2 
from math import log10, sqrt
import math 
import PIL
# NOTE(review): `psnr_values` is re-initialised in the next cell and not filled here.
psnr_values = []
from PIL import Image
# Relative directories, resolved against the current working directory.
# os.path.join with a single argument returns that argument unchanged.
image_dir = os.path.join('./train_new1/')   
result_dir = os.path.join('./post_results_unet_input/')
# These rebind `targets` and `images`, names that held PIL image lists in the
# inference cell. The entries here are bare file names.
# NOTE(review): the two listings are not sorted, so positions in `targets` and
# `images` presumably do not correspond — verify before pairing them.
targets = [f for f in os.listdir(image_dir)]
images = [f for f in os.listdir(result_dir)]

print(targets)
print(images)
gc.collect()
torch.cuda.empty_cache()
def psnr(img1, img2):
    """
    Peak signal-to-noise ratio in dB between two images on a 0..255 scale.

    Inputs are converted to float64 first: with uint8 arrays the subtraction
    and squaring wrap around modulo 256 and produce a wrong MSE.

    :param img1: array-like image
    :param img2: array-like image of the same shape
    :return: 100 when the images are identical, otherwise 20*log10(255/sqrt(MSE))
    """
    diff = numpy.asarray(img1, dtype=numpy.float64) - numpy.asarray(img2, dtype=numpy.float64)
    mse = numpy.mean(diff ** 2)
    if mse == 0:
      return 100
    PIXEL_MAX = 255.0
    return 20 * math.log10((PIXEL_MAX)/ sqrt(mse))

PSNR and SSIM calculation

In [None]:
# Average PSNR and SSIM between every model output and its reference image.
#
# Uses only names the notebook defines: `output` (HxWx3 uint8 arrays returned
# by postprocess) and `origin_new` (256x256 PIL reference images). The previous
# version relied on `tf` and `output_model`, neither of which is imported or
# defined in this notebook, and overwrote the `output` list inside the loop.
n_pairs = min(len(origin_new), len(output))

psnr_values = []
ssim_values = []
for i in range(n_pairs):
    # float64 arrays on the 0..255 scale expected by psnr()
    reference = np.asarray(origin_new[i].convert('RGB'), dtype=np.float64)
    generated = np.asarray(output[i], dtype=np.float64)
    psnr_values.append(psnr(reference, generated))

    # ssim() expects (N, C, H, W) float tensors; its constants assume a [0, 1] range
    reference_t = torch.from_numpy(reference).permute(2, 0, 1).unsqueeze(0).float() / 255.0
    generated_t = torch.from_numpy(generated).permute(2, 0, 1).unsqueeze(0).float() / 255.0
    ssim_values.append(ssim(reference_t, generated_t).item())

avgPSNR = sum(psnr_values) / n_pairs if n_pairs else float('nan')
ssim_value = sum(ssim_values) / n_pairs if n_pairs else float('nan')

print(len(origin_new))
print('Average PSNR = ', avgPSNR)

print("SSIM Values")
print('Average SSIM = ', ssim_value)