In [1]:
import numpy as np 
import pandas as pd 
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
import torch.utils.data as data
import random
import torch.nn.functional as F

from torchvision import transforms
import numpy
from PIL import Image
from torch.nn.modules.utils import _ntuple
import datetime
import pandas
import math

import os , itertools

from numpy import load

import matplotlib.pyplot as plt

In [2]:
# Pick the first CUDA GPU when PyTorch can see one; otherwise run on the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0") if use_cuda else torch.device("cpu")

In [None]:
# Dataset root folder; `data_dir29` is an alias holding the same path.
data_dir = data_dir29 = './dataset'

In [None]:
# Central hyper-parameter table read throughout the notebook.
# NOTE(review): the cycleGAN class uses its own learning rate and loss weights;
# lrG/lrD/beta1/beta2/lambdaA/lambdaB are not read by it — confirm where (if anywhere) they apply.
params = dict(
    batch_size=1,
    input_size=128,      # side length images are resized to
    stack_num=29,        # number of slices along the last axis of a stack
    resize=128,
    crop_size=32,
    fliplr=True,
    num_epochs=100,
    decay_epoch=100,
    save_epoch=80,
    print_layer=1,       # slice index shown by the plotting helpers
    ngf=16,              # number of generator filters
    ndf=32,              # number of discriminator filters
    num_resnet=6,        # number of resnet blocks
    lrG=0.0002,          # learning rate for generator
    lrD=0.0002,          # learning rate for discriminator
    beta1=0.5,           # beta1 for Adam optimizer
    beta2=0.999,         # beta2 for Adam optimizer
    lambdaA=10,          # lambdaA for cycle loss
    lambdaB=10,          # lambdaB for cycle loss
)

In [None]:
class NormalizeInverse(transforms.Normalize):
    """Undo ``transforms.Normalize(mean, std)``.

    ``Normalize`` maps x -> (x - m) / s.  Passing m' = -mean/std and
    s' = 1/std to the parent turns that into y -> y * std + mean, i.e. the
    inverse mapping (std gets a 1e-7 guard against division by zero).
    """

    def __init__(self, mean, std):
        mean_t = torch.as_tensor(mean)
        inv_std = 1 / (torch.as_tensor(std) + 1e-7)
        super().__init__(mean=-mean_t * inv_std, std=inv_std)

    def __call__(self, tensor):
        # Work on a copy so the caller's tensor is never modified.
        copy = tensor.clone()
        return super().__call__(copy)

In [None]:
class ImagePool():
    """Buffer of previously generated images used to feed the discriminators.

    While the buffer holds fewer than ``pool_size`` images, every incoming
    image is stored and returned unchanged.  Once full, each incoming image
    is, with probability 0.5, swapped for a randomly chosen stored image (the
    stored one is returned, the new one takes its slot); otherwise it is
    returned as-is.  ``pool_size == 0`` disables the buffer entirely.
    """

    def __init__(self, pool_size):
        self.pool_size = pool_size
        if pool_size > 0:
            self.num_imgs = 0
            self.images = []

    def query(self, images):
        """Return a batch (same leading size as ``images``) mixing new and buffered images."""
        if self.pool_size == 0:
            return images

        picked = []
        for img in images.data:
            img = torch.unsqueeze(img, 0)
            if self.num_imgs >= self.pool_size:
                if random.uniform(0, 1) > 0.5:
                    slot = random.randint(0, self.pool_size - 1)
                    picked.append(self.images[slot].clone())
                    self.images[slot] = img
                else:
                    picked.append(img)
                continue
            # Buffer not full yet: store and pass through.
            self.num_imgs += 1
            self.images.append(img)
            picked.append(img)

        return Variable(torch.cat(picked, 0))

In [None]:
def plot_train_result(real_image, gen_image, epoch, save=False,  show=True, fig_size=(12, 12)):
    """Show two (real, generated) pairs in a 2x2 grid, one slice per panel.

    Args:
        real_image, gen_image: indexable with at least two entries; each entry
            is converted with ``to_np`` and indexed ``[0, 0, :, :, :]``, so it
            must have at least five dims (presumably (N, C, H, W, D) — TODO
            confirm against the caller).  The slice ``params['print_layer']``
            of the last axis is displayed.
        epoch: zero-based epoch; the caption and file name use ``epoch + 1``.
        save: if True, write ``./ResultsFiles/Result_epoch_<epoch+1>.png``;
            the folder is not created here and must already exist.
        show: display the figure if True, otherwise close it.
        fig_size: figure size in inches.
    """
    fig, axes = plt.subplots(2, 2, figsize=fig_size)

    imgs = [to_np(real_image[0]), to_np(gen_image[0]),
            to_np(real_image[1]), to_np(gen_image[1])]

    for ax, img in zip(axes.flatten(), imgs):
        ax.axis('off')
        volume = img[0, 0, :, :, :]
        # imshow scales the colormap to the data range by default, so the
        # normalised values can be displayed directly without inverting Normalize.
        ax.imshow(volume[:, :, params['print_layer']], cmap='gray', aspect='equal')
    fig.subplots_adjust(wspace=0, hspace=0)

    title = 'Epoch {0}'.format(epoch + 1)
    fig.text(0.5, 0.04, title, ha='center')

    # save figure
    if save:
        save_fn = './ResultsFiles/Result_epoch_{:d}'.format(epoch+1) + '.png'
        fig.savefig(save_fn)

    if show:
        plt.show()
    else:
        plt.close(fig)


In [None]:
def plot_train_result29(recon_image29, epoch, save=False,  show=True, fig_size=(5, 15)):
    """Show ``recon_image29[0, 0, :, :]`` as a grayscale image titled with the epoch.

    Args:
        recon_image29: tensor indexable as ``[0, 0, :, :]``; that 2-D slice is
            converted with ``to_np`` and displayed.
        epoch: zero-based epoch; the title and file name use ``epoch + 1``.
        save: if True, write ``./ResultsFiles/Result_epoch_<epoch+1>.png``;
            the folder must already exist.
        show: display the figure if True, otherwise close it.
        fig_size: figure size in inches.
    """
    # A fresh figure so `fig_size` is honoured and earlier plots are not drawn over.
    fig = plt.figure(figsize=fig_size)
    plt.axis('off')

    img1 = to_np(recon_image29[0, 0, :, :])
    plt.imshow(img1, cmap='gray', aspect='equal')

    # Attach the epoch title (it used to be built but never displayed).
    plt.title('Epoch {0}'.format(epoch + 1))

    if save:
        save_fn = './ResultsFiles/Result_epoch_{:d}'.format(epoch+1) + '.png'
        fig.savefig(save_fn)

    if show:
        plt.show()
    else:
        plt.close(fig)


In [None]:
def print_images(im_list, save_dir, epoch_num, save_mode_on=True):
    """
        Plot the first four tensors of ``im_list`` side by side, one 2-D slice each.

        Each tensor is squeezed and has its axes fully reversed (what ``Tensor.T``
        did; ``.T`` on tensors with more than two dims is deprecated in PyTorch).
        It is then mapped with ``(x + 1) / 2`` (tanh range to [0, 1]) and indexed
        at ``params['print_layer']`` along the new first axis.  Slices are
        detached and moved to the CPU before plotting.

        Args:
            im_list: at least four tensors, shown with the titles below in order.
            save_dir: folder for ``epoch-<epoch_num>.png`` when saving.
            epoch_num: label used in the file name.
            save_mode_on: save and close the figure if True, otherwise show it.
    """

    #A list that holds the necessary plot titles
    titles = ['Real-Blur', 'Fake-Sharp (B->S)', 'Recon-Blur (B->S->B)', 'Identity-Sharp (S->S)']

    fig, axarr = plt.subplots(1, 4, figsize=(12, 12))

    for j, title in enumerate(titles):
        im = im_list[j].squeeze()
        # Reverse all axes explicitly; equivalent to the deprecated >2-D `.T`.
        im = im.permute(*reversed(range(im.dim())))
        im = (im + 1) / 2.0
        imm = im[params['print_layer'], :, :]
        axarr[j].axis('off')
        axarr[j].imshow(imm.detach().cpu())
        axarr[j].set_title(title, fontweight="bold")

    fig.tight_layout()

    if save_mode_on:
        fig.savefig(os.path.join(save_dir, 'epoch-{}.png'.format(epoch_num)))
        plt.close(fig)
    else:
        plt.show()

In [None]:
class GetMyData(data.Dataset):
    """Dataset of stacks stored as ``.npz`` files under ``image_dir/subfolder``.

    Each file must contain an array under the key ``'arr_0'``; slices are taken
    as ``arr[:, :, k]``.  Files are served in sorted filename order.

    Args:
        image_dir: dataset root folder.
        subfolder: folder under ``image_dir`` holding the ``.npz`` files.
        transform: callable applied to each slice ``arr[:, :, k]`` for the first
            ``params['stack_num']`` slices; each output gets a trailing axis and
            the results are concatenated along axis 3.
        crop_size: if truthy, a random ``crop_size`` x ``crop_size`` window of the
            first two axes is cut; it is only returned when ``transform`` is None.
        fliplr: accepted for signature compatibility; unused.
    """

    def __init__(self, image_dir, subfolder='train', transform=None, crop_size=None, fliplr=False):
        super(GetMyData, self).__init__()
        self.input_path = os.path.join(image_dir, subfolder)
        self.image_filenames = sorted(os.listdir(self.input_path))
        self.transform = transform
        self.crop_size = crop_size

    def __getitem__(self, index):
        img_fn = os.path.join(self.input_path, self.image_filenames[index])
        # An NpzFile keeps the archive open until closed; the context manager
        # releases it instead of leaking one file handle per item.
        # NOTE(review): allow_pickle=True permits unpickling object arrays from
        # the file — only use with trusted data.
        with np.load(img_fn, mmap_mode='r', allow_pickle=True) as imgnpz:
            imarray = numpy.array(imgnpz['arr_0'])
        iim = imarray

        if self.crop_size:
            # randint is inclusive at both ends; assuming the stack is
            # params['input_size'] pixels on each cropped axis, the last valid
            # start index is input_size - crop_size.
            max_start = params['input_size'] - self.crop_size
            x = random.randint(0, max_start)
            y = random.randint(0, max_start)
            iim = torch.from_numpy(imarray)[x:x + self.crop_size, y:y + self.crop_size, :]

        if self.transform is not None:
            # NOTE(review): slices come from the full array, so any crop above is
            # discarded whenever a transform is set — confirm this is intended.
            # The first slice is always processed, as before.
            slices = [self.transform(imarray[:, :, k])[..., np.newaxis]
                      for k in range(max(params['stack_num'], 1))]
            # One concatenate instead of re-growing the array once per slice.
            iim = slices[0] if len(slices) == 1 else np.concatenate(slices, 3)

        return iim

    def __len__(self):
        return len(self.image_filenames)

In [None]:
class GetVerticalData(data.Dataset):
    """Dataset of 2-D images read with PIL from ``image_dir/subfolder``.

    Files are served in sorted filename order.

    Args:
        image_dir: dataset root folder.
        subfolder: folder under ``image_dir`` holding the images.
        transform: callable applied to the full image array; when given, its
            result is what ``__getitem__`` returns.
        crop_size: if truthy, a random ``crop_size`` x ``crop_size`` window is
            cut as a tensor; it is only returned when ``transform`` is None.
        fliplr: accepted for signature compatibility; unused.
    """

    def __init__(self, image_dir, subfolder='train', transform=None,crop_size=None, fliplr=False):
        super(GetVerticalData, self).__init__()
        self.input_path = os.path.join(image_dir, subfolder)
        self.image_filenames = sorted(os.listdir(self.input_path))
        self.transform = transform
        self.crop_size = crop_size

    def __getitem__(self, index):
        #Load Image (the context manager closes the file handle)
        img_fn = os.path.join(self.input_path, self.image_filenames[index])
        with Image.open(img_fn) as img:
            imarray = numpy.array(img)
        imm = imarray

        if self.crop_size:
            # randint is inclusive at both ends; assuming the image is
            # params['input_size'] pixels on each side, the last valid start is
            # input_size - crop_size.
            max_start = params['input_size'] - self.crop_size
            x = random.randint(0, max_start)
            y = random.randint(0, max_start)
            imm = torch.from_numpy(imarray)[x:x + self.crop_size, y:y + self.crop_size]

        if self.transform is not None:
            # NOTE(review): the transform receives the full, uncropped array, so
            # any crop above is discarded when a transform is set — confirm intent.
            imm = self.transform(imarray)

        return imm

    def __len__(self):
        return len(self.image_filenames)

In [None]:
class GetSharpData(data.Dataset):
    """Dataset of 2-D images read with PIL from ``image_dir/subfolder``.

    Files are served in sorted filename order.

    Args:
        image_dir: dataset root folder.
        subfolder: folder under ``image_dir`` holding the images.
        transform: callable applied to the full image array; when given, its
            result is what ``__getitem__`` returns.
        crop_size: if truthy, a random band of ``crop_size`` rows (all columns)
            is cut as a tensor; it is only returned when ``transform`` is None.
        fliplr: accepted for signature compatibility; unused.
    """

    def __init__(self, image_dir, subfolder='train', transform=None,crop_size=None, fliplr=False):
        super(GetSharpData, self).__init__()
        self.input_path = os.path.join(image_dir, subfolder)
        self.image_filenames = sorted(os.listdir(self.input_path))
        self.transform = transform
        self.crop_size = crop_size

    def __getitem__(self, index):
        # Load Image (the context manager closes the file handle)
        img_fn = os.path.join(self.input_path, self.image_filenames[index])
        with Image.open(img_fn) as img:
            imarray = numpy.array(img)
        imm = imarray

        if self.crop_size:
            # randint is inclusive at both ends; assuming the image is
            # params['input_size'] rows tall, the last valid start row is
            # input_size - crop_size.
            x = random.randint(0, params['input_size'] - self.crop_size)
            imm = torch.from_numpy(imarray)[x:x + self.crop_size, :]

        if self.transform is not None:
            # NOTE(review): the transform receives the full, uncropped array, so
            # any crop above is discarded when a transform is set — confirm intent.
            imm = self.transform(imarray)

        return imm

    def __len__(self):
        return len(self.image_filenames)

In [None]:
class GetVerticalData_scale(data.Dataset):
    """Dataset of 2-D PIL images restricted to columns 11..17 (7 columns).

    Files in ``image_dir/subfolder`` are served in sorted filename order.  When
    ``transform`` is given, the column band is converted with
    ``torch.from_numpy`` and passed through it; otherwise the NumPy band is
    returned.
    """

    def __init__(self, image_dir, subfolder='train', transform=None):
        # Fixed: this used to call super() with the undefined name
        # `GetVerticalData_new`, so every construction raised NameError.
        super(GetVerticalData_scale, self).__init__()
        self.input_path = os.path.join(image_dir, subfolder)
        self.image_filenames = sorted(os.listdir(self.input_path))
        self.transform = transform

    def __getitem__(self, index):
        #Load Image (the context manager closes the file handle)
        img_fn = os.path.join(self.input_path, self.image_filenames[index])
        with Image.open(img_fn) as img:
            imarray = numpy.array(img)
        imarray = imarray[:, 11:18]

        if self.transform is not None:
            imarray = self.transform(torch.from_numpy(imarray))

        return imarray

    def __len__(self):
        return len(self.image_filenames)

In [None]:
class ReplicationPad3d(torch.nn.modules.padding._ReplicationPadNd):
    """Edge-replicating padding for 5-D input.

    ``padding`` may be a single int (expanded to six equal values) or an
    iterable of six values; the parent class's forward applies it in
    'replicate' mode.
    """

    def __init__(self, padding):
        super().__init__()
        expand_to_six = _ntuple(6)
        self.padding = expand_to_six(padding)

In [None]:
def to_np(x):
    """Return ``x``'s data (outside autograd) as a NumPy array in host memory."""
    host_tensor = x.data.cpu()
    return host_tensor.numpy()

In [None]:
class ResnetBlock(torch.nn.Module):
    """3-D residual block.

    out = LeakyReLU(BN(conv(LeakyReLU(BN(conv(x))))) + shortcut), where
    the shortcut is ``downsample(x)`` when a (truthy) ``downsample`` module is
    given and ``x`` otherwise.  Both convolutions keep ``features_Nr`` channels.
    """

    def __init__(self,features_Nr,kernel_size=3,stride=1,padding=1,downsample=None):
        super(ResnetBlock,self).__init__()
        # Creation order is kept: it fixes parameter initialisation order and
        # the state_dict key names (conv1, Bn1, ..., Bn2).
        self.conv1 = torch.nn.Conv3d(features_Nr, features_Nr, kernel_size, stride, padding)
        self.Bn1 = torch.nn.BatchNorm3d(features_Nr)
        self.LRe1 = torch.nn.LeakyReLU(0.2)

        self.conv2 = torch.nn.Conv3d(features_Nr, features_Nr, kernel_size, stride, padding)
        self.Bn2 = torch.nn.BatchNorm3d(features_Nr)
        self.LRe2 = torch.nn.LeakyReLU(0.2)
        self.downsample = downsample

    def forward(self,x):
        branch = self.Bn2(self.conv2(self.LRe1(self.Bn1(self.conv1(x)))))
        shortcut = self.downsample(x) if self.downsample else x
        return self.LRe2(branch + shortcut)



In [None]:
class ConvBlock(torch.nn.Module):
    """Conv3d -> optional InstanceNorm3d -> activation.

    Args:
        input_size, output_size: input / output channel counts.
        kernel_size, stride, padding: forwarded to ``torch.nn.Conv3d``.
        bias: accepted for call compatibility but ignored; the convolution is
            always built with ``Conv3d``'s default bias.
        activation: one of 'relu', 'lrelu' (slope 0.2), 'tanh', 'no_act'.
        batch_norm: despite the name, enables an ``InstanceNorm3d`` layer.

    Raises:
        ValueError: from ``forward`` if ``activation`` is not one of the above
            (it used to return None silently).
    """
    def __init__(self,input_size,output_size,kernel_size=3,stride=2,bias=False,padding=1,activation='relu',batch_norm=True):
        super(ConvBlock,self).__init__()
        self.conv = torch.nn.Conv3d(input_size,output_size,kernel_size,stride,padding)
        self.batch_norm = batch_norm
        self.bn = torch.nn.InstanceNorm3d(output_size)
        self.activation = activation
        self.relu = torch.nn.ReLU(True)
        self.lrelu = torch.nn.LeakyReLU(0.2,True)
        self.tanh = torch.nn.Tanh()

    def forward(self,x):
        out = self.conv(x)
        if self.batch_norm:
            out = self.bn(out)

        if self.activation == 'relu':
            return self.relu(out)
        if self.activation == 'lrelu':
            return self.lrelu(out)
        if self.activation == 'tanh':
            return self.tanh(out)
        if self.activation == 'no_act':
            return out
        raise ValueError('Unknown activation: {!r}'.format(self.activation))

In [None]:
class ConvBlock2d(torch.nn.Module):
    """Conv2d -> optional InstanceNorm2d -> activation.

    Args:
        input_size, output_size: input / output channel counts.
        kernel_size, stride, padding: forwarded to ``torch.nn.Conv2d``.
        activation: one of 'relu', 'lrelu' (slope 0.2), 'tanh', 'no_act'.
        batch_norm: despite the name, enables an ``InstanceNorm2d`` layer.

    Raises:
        ValueError: from ``forward`` if ``activation`` is not one of the above
            (it used to return None silently).
    """
    def __init__(self,input_size,output_size,kernel_size=3,stride=2,padding=1,activation='relu',batch_norm=True):
        super(ConvBlock2d,self).__init__()
        self.conv = torch.nn.Conv2d(input_size,output_size,kernel_size,stride,padding)
        self.batch_norm = batch_norm
        self.bn = torch.nn.InstanceNorm2d(output_size)
        self.activation = activation
        self.relu = torch.nn.ReLU(True)
        self.lrelu = torch.nn.LeakyReLU(0.2,True)
        self.tanh = torch.nn.Tanh()

    def forward(self,x):
        out = self.conv(x)
        if self.batch_norm:
            out = self.bn(out)

        if self.activation == 'relu':
            return self.relu(out)
        if self.activation == 'lrelu':
            return self.lrelu(out)
        if self.activation == 'tanh':
            return self.tanh(out)
        if self.activation == 'no_act':
            return out
        raise ValueError('Unknown activation: {!r}'.format(self.activation))

In [None]:
class DeconvBlock(torch.nn.Module):
    """ConvTranspose3d -> optional InstanceNorm3d -> activation.

    Args:
        input_size, output_size: input / output channel counts.
        kernel_size, stride, padding, output_padding: forwarded to
            ``torch.nn.ConvTranspose3d`` (the defaults double each spatial dim).
        bias: accepted for call compatibility but ignored.
        activation: one of 'relu', 'lrelu' (slope 0.2), 'tanh', 'no_act'.
        batch_norm: despite the name, enables an ``InstanceNorm3d`` layer.

    Raises:
        ValueError: from ``forward`` if ``activation`` is not one of the above.
    """
    def __init__(self,input_size,output_size,kernel_size=3,stride=2,bias=False,padding=1,output_padding=1,activation='relu',batch_norm=True):
        super(DeconvBlock,self).__init__()
        self.deconv = torch.nn.ConvTranspose3d(input_size,output_size,kernel_size,stride,padding,output_padding)
        self.batch_norm = batch_norm
        self.bn = torch.nn.InstanceNorm3d(output_size)
        self.activation = activation
        self.relu = torch.nn.ReLU(True)
        # These were referenced by forward() but never created, so 'lrelu' and
        # 'tanh' raised AttributeError. They hold no parameters.
        self.lrelu = torch.nn.LeakyReLU(0.2,True)
        self.tanh = torch.nn.Tanh()

    def forward(self,x):
        out = self.deconv(x)
        if self.batch_norm:
            out = self.bn(out)

        if self.activation == 'relu':
            return self.relu(out)
        if self.activation == 'lrelu':
            return self.lrelu(out)
        if self.activation == 'tanh':
            return self.tanh(out)
        if self.activation == 'no_act':
            return out
        raise ValueError('Unknown activation: {!r}'.format(self.activation))

In [None]:
class Generator(torch.nn.Module):
    """3-D encoder / resnet / decoder generator ending in tanh.

    Encoder: conv1 (kernel 5, stride 1) then conv2, conv3 (stride 2), reaching
    ``4 * num_filter`` channels.  ``num_resnet`` ResnetBlocks follow.  Decoder:
    two stride-2 transposed convs, then deconv3 with kernel (5, 5, 6) and
    padding (2, 2, 1), which trims the last axis by 3.  For a 29-slice last axis
    that undoes the 29 -> 15 -> 8 -> 16 -> 32 round trip.

    ``pad1``/``pad2`` are constructed but not used in ``forward``.
    """

    def __init__(self,input_dim,num_filter,output_dim,num_resnet):
        super(Generator,self).__init__()

        #Replication padding (currently unused in forward)
        self.pad1 = ReplicationPad3d(3)
        self.pad2 = ReplicationPad3d(2)

        #Encoder
        self.conv1 = ConvBlock(input_dim,num_filter,kernel_size=5,stride=1,padding=2,activation='relu',batch_norm=True)
        self.conv2 = ConvBlock(num_filter,num_filter*2,activation='relu',batch_norm=True)
        self.conv3 = ConvBlock(num_filter*2,num_filter*4,activation='relu',batch_norm=True)

        #Resnet blocks
        self.resnet_blocks = torch.nn.Sequential(
            *[ResnetBlock(num_filter*4) for _ in range(num_resnet)])

        #Decoder
        self.deconv1 = DeconvBlock(num_filter*4,num_filter*2,activation='relu',batch_norm=True)
        self.deconv2 = DeconvBlock(num_filter*2,num_filter,activation='relu',batch_norm=True)
        self.deconv3 = ConvBlock(num_filter,output_dim,kernel_size=(5,5,6),stride=1,padding=(2,2,1),activation='tanh',batch_norm=False)

    def forward(self,x):
        #Encoder
        enc1 = self.conv1(x)
        enc2 = self.conv2(enc1)
        enc3 = self.conv3(enc2)

        #Resnet blocks
        res = self.resnet_blocks(enc3)

        #Decoder
        dec1 = self.deconv1(res)
        dec2 = self.deconv2(dec1)
        out = self.deconv3(dec2)

        return out

    def normal_weight_init(self,mean=0.0,std=0.02):
        """Re-draw convolution weights from N(mean, std).

        Walks every sub-module (``self.modules()``), so the ResnetBlocks nested
        in ``self.resnet_blocks`` are reached; ``children()`` only yielded the
        Sequential, and the old ResnetBlock branch referenced a non-existent
        ``m.conv``.  ConvBlock/DeconvBlock biases are left unchanged;
        ResnetBlock convolution biases are set to 0.
        """
        for m in self.modules():
            if isinstance(m,ConvBlock):
                torch.nn.init.normal_(m.conv.weight,mean,std)
            elif isinstance(m,DeconvBlock):
                torch.nn.init.normal_(m.deconv.weight,mean,std)
            elif isinstance(m,ResnetBlock):
                for conv in (m.conv1, m.conv2):
                    torch.nn.init.normal_(conv.weight,mean,std)
                    torch.nn.init.constant_(conv.bias,0)

In [None]:
class Discriminator(torch.nn.Module):
    """2-D convolutional discriminator producing a spatial map of raw scores.

    Four stride-2 ``ConvBlock2d`` layers (kernel 4, LeakyReLU; the first one
    without normalisation), then a kernel-2 conv to ``output_dim`` channels
    with no activation.
    """
    def __init__(self,input_dim,num_filter,output_dim):
        super(Discriminator,self).__init__()
        conv1 = ConvBlock2d(input_dim,num_filter,kernel_size=4,stride=2,padding=1,activation='lrelu',batch_norm=False)
        conv2 = ConvBlock2d(num_filter,num_filter*2,kernel_size=4,stride=2,padding=1,activation='lrelu')
        conv3 = ConvBlock2d(num_filter*2,num_filter*4,kernel_size=4,stride=2,padding=1,activation='lrelu')
        conv4 = ConvBlock2d(num_filter*4,num_filter*8,kernel_size=4,stride=2,padding=1,activation='lrelu')
        conv5 = ConvBlock2d(num_filter*8,output_dim,kernel_size=2,stride=1,padding=1,activation='no_act',batch_norm=False)
        self.conv_blocks = torch.nn.Sequential(
            conv1,
            conv2,
            conv3,
            conv4,
            conv5
            )

    def forward(self,x):
        out = self.conv_blocks(x)
        return out

    def normal_weight_init(self,mean=0.0,std=0.02):
        """Re-draw every ConvBlock2d convolution weight from N(mean, std).

        The blocks live inside ``self.conv_blocks`` and are ``ConvBlock2d``, so
        the previous ``children()`` + ``ConvBlock`` check never matched and
        initialised nothing.  Biases are left unchanged.
        """
        for m in self.modules():
            if isinstance(m,ConvBlock2d):
                torch.nn.init.normal_(m.conv.weight,mean,std)

In [None]:
# Square preprocessing: array/tensor -> PIL -> resize to input_size x input_size
# -> tensor -> Normalize(0.5, 0.5).
transform = transforms.Compose(
    [
        transforms.ToPILImage(),
        transforms.Resize(size=(params['input_size'], params['input_size'])),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5], std=[0.5]),
    ]
)

In [None]:
# Same as `transform`, but resized to input_size x stack_num.
transform29 = transforms.Compose(
    [
        transforms.ToPILImage(),
        transforms.Resize(size=(params['input_size'], params['stack_num'])),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5], std=[0.5]),
    ]
)

In [None]:
#for linux or Mac os
def manage_folders():
    """Create a timestamped output tree under ``<cwd>/Output`` and return its leaf folders.

    Layout::

        Output/<YYYY_MM_DD-HH:MM>/
            loss-graphs/
            generated-images/train/
            generated-images/te/
            saved-models/

    Returns:
        tuple: ``(tr_im_save_dir, te_im_save_dir, graph_save_dir, model_save_dir)``.

    Raises:
        FileExistsError: if the timestamped folder already exists, e.g. when
            called twice within the same minute.

    The ``:`` in the timestamp is not a valid character in Windows folder names.
    """

    def _ensure_dir(*parts):
        # Create the directory if missing and return its path.
        path = os.path.join(*parts)
        if not os.path.isdir(path):
            os.mkdir(path)
        return path

    stamp = datetime.datetime.now().strftime("%Y_%m_%d-%H:%M")

    output_root = _ensure_dir(os.getcwd(), 'Output')

    # Strict mkdir: an existing run folder raises instead of being reused.
    output_folder = os.path.join(output_root, stamp)
    os.mkdir(output_folder)

    graph_save_dir = _ensure_dir(output_folder, 'loss-graphs')
    im_save_dir = _ensure_dir(output_folder, 'generated-images')
    tr_im_save_dir = _ensure_dir(im_save_dir, 'train')
    te_im_save_dir = _ensure_dir(im_save_dir, 'te')
    model_save_dir = _ensure_dir(output_folder, 'saved-models')

    assert os.path.isdir(im_save_dir), 'Check your im_save_dir path.'
    assert os.path.isdir(graph_save_dir), 'Check your graph_save_dir path.'

    print('-----Directories to save the output-----\nTrain Fake Images: {}\nVal Fake Images: {}\nLosses: {}\nModel: {}'.format(tr_im_save_dir, te_im_save_dir, graph_save_dir, model_save_dir))

    return tr_im_save_dir, te_im_save_dir, graph_save_dir, model_save_dir

In [None]:
# Datasets and loaders; training loaders shuffle, test loaders keep file order.
#   *_S   : 2-D images via GetSharpData (trainA2D / testA2D), `transform`
#   *_B   : .npz stacks via GetMyData (trainB / testB), `transform`
#   29_S  : 2-D images via GetVerticalData (trainA29 / testA29), `transform29`
train_data_S = GetSharpData(data_dir, subfolder='trainA2D',
                            transform=transform, crop_size=params['crop_size'])
train_data_loader_S = torch.utils.data.DataLoader(dataset=train_data_S,
                                                  batch_size=1, shuffle=True)

train_data_B = GetMyData(data_dir, subfolder='trainB',
                         transform=transform, crop_size=params['crop_size'])
train_data_loader_B = torch.utils.data.DataLoader(dataset=train_data_B,
                                                  batch_size=1, shuffle=True)

test_data_S = GetSharpData(data_dir, subfolder='testA2D',
                           transform=transform, crop_size=params['crop_size'])
test_data_loader_S = torch.utils.data.DataLoader(dataset=test_data_S,
                                                 batch_size=1, shuffle=False)

test_data_B = GetMyData(data_dir, subfolder='testB',
                        transform=transform, crop_size=params['crop_size'])
test_data_loader_B = torch.utils.data.DataLoader(dataset=test_data_B,
                                                 batch_size=1, shuffle=False)

train29_data_S = GetVerticalData(data_dir, subfolder='trainA29',
                                 transform=transform29, crop_size=params['crop_size'])
train29_data_loader_S = torch.utils.data.DataLoader(dataset=train29_data_S,
                                                    batch_size=params['batch_size'], shuffle=True)

test29_data_S = GetVerticalData(data_dir, subfolder='testA29',
                                transform=transform29, crop_size=params['crop_size'])
test29_data_loader_S = torch.utils.data.DataLoader(dataset=test29_data_S,
                                                   batch_size=params['batch_size'], shuffle=False)

In [None]:
# Re-select the compute device (same rule as the setup cell: first GPU if available, else CPU).
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# L1 criterion placed on the notebook's selected `device` rather than a
# hard-coded `.cuda()`, so the cell follows the same CPU/GPU choice as the rest.
L1_Loss = torch.nn.L1Loss().to(device)

# Average-loss history lists.
# NOTE(review): the cycleGAN class keeps its own loss dictionaries and does not
# append to these; confirm whether they are filled elsewhere.
D_B2S_avg_losses = []
D29_B2S_avg_losses = []
D_S2B_avg_losses = []

G_B2S_avg_losses = []
G_S2B_avg_losses = []
cycle_B2S_avg_losses = []
cycle_S2B_avg_losses = []

In [None]:
# Module-level history pools, each holding up to `num_pool` images.
# NOTE(review): the cycleGAN class creates its own pools in __init__ and does
# not use these globals.
num_pool = 50
fake_B2S_pool, fake29_B2S_pool, fake_S2B_pool = (ImagePool(num_pool) for _ in range(3))

In [None]:
def gan_loss(inputs, is_real):
    """Least-squares GAN loss: MSE between ``inputs`` and an all-ones target
    (``is_real`` True) or an all-zeros target (``is_real`` False).

    The target is built with ``ones_like``/``zeros_like`` so it always matches
    ``inputs``' device and dtype instead of relying on the global ``device``.
    """
    target = torch.ones_like(inputs) if is_real else torch.zeros_like(inputs)
    return F.mse_loss(inputs, target)
    
    
def gan_loss_max(inputs, is_real, stack_num=None):
    """Worst-slice least-squares GAN loss over the last axis of a 5-D tensor.

    Computes ``MSE(inputs[:, :, :, :, i], target)`` for each slice ``i`` in
    ``range(stack_num)``, with an all-ones target when ``is_real`` and all-zeros
    otherwise, and returns the largest one.  Slice 0 is always scored.

    Args:
        inputs: tensor indexed as ``[:, :, :, :, i]``.
        is_real: choose the all-ones (True) or all-zeros (False) target.
        stack_num: number of slices to score; defaults to ``params['stack_num']``.
    """
    if stack_num is None:
        stack_num = params['stack_num']
    fill = 1.0 if is_real else 0.0

    def _slice_loss(i):
        # Target built like the slice so device and dtype always match.
        piece = inputs[:, :, :, :, i]
        return F.mse_loss(piece, torch.full_like(piece, fill))

    worst = _slice_loss(0)
    for i in range(1, stack_num):
        worst = max(worst, _slice_loss(i))
    return worst

    
def cycle_loss_original(reconstructed_images, real_images):
    """Plain (unweighted) mean absolute error between reconstructions and originals."""
    loss = F.l1_loss(reconstructed_images, real_images)
    return loss


def cycle_loss_weighted(reconstructed_images, real_images,p_similarity, stack_num=None):
    """Weighted mean of per-slice L1 losses over the last axis.

    Returns ``sum_i L1(recon[..., i], real[..., i]) * p_similarity[i] / stack_num``
    for ``i`` in ``range(stack_num)`` (slice 0 is always included).

    Args:
        reconstructed_images, real_images: tensors indexed as ``[:, :, :, :, i]``.
        p_similarity: per-slice weights, indexable by slice number.
        stack_num: number of slices; defaults to ``params['stack_num']``.
    """
    if stack_num is None:
        stack_num = params['stack_num']

    total = F.l1_loss(reconstructed_images[:, :, :, :, 0], real_images[:, :, :, :, 0]) * p_similarity[0]
    for i in range(1, stack_num):
        total = total + F.l1_loss(reconstructed_images[:, :, :, :, i], real_images[:, :, :, :, i]) * p_similarity[i]

    return total / stack_num


def cycle_loss(reconstructed_images, real_images,p_similarity, stack_num=None):
    """Largest weighted per-slice L1 loss over the last axis.

    Returns ``max_i L1(recon[..., i], real[..., i]) * p_similarity[i]`` for
    ``i`` in ``range(stack_num)`` (slice 0 is always included).

    Args:
        reconstructed_images, real_images: tensors indexed as ``[:, :, :, :, i]``.
        p_similarity: per-slice weights, indexable by slice number.
        stack_num: number of slices; defaults to ``params['stack_num']``.
    """
    if stack_num is None:
        stack_num = params['stack_num']

    worst = F.l1_loss(reconstructed_images[:, :, :, :, 0], real_images[:, :, :, :, 0]) * p_similarity[0]
    for i in range(1, stack_num):
        candidate = F.l1_loss(reconstructed_images[:, :, :, :, i], real_images[:, :, :, :, i]) * p_similarity[i]
        worst = max(worst, candidate)

    return worst
    

def identity_loss(inputs, real_images):
    """Mean absolute error between a generator's output on its own domain and its input."""
    return F.l1_loss(input=inputs, target=real_images)

In [None]:
def gaussian(sigma,mu, x):
    """Scaled normal density, rounded to two decimals.

    Returns ``round(pdf * 3.4 + 0.1, 2)`` where ``pdf`` is the N(mu, sigma^2)
    density at ``x``.
    """
    norm = 1 / (sigma * math.sqrt(2*math.pi))
    neg_inv_two_var = -1.0 / (2 * sigma * sigma)
    delta = x - mu
    density = norm * math.exp(neg_inv_two_var * delta * delta)

    return round(density * 3.4 + 0.1, 2)

# Per-slice weights over the stack axis (length params['stack_num']), built from
# the rounded Gaussian bump returned by `gaussian`. Both loops run on module level,
# so `i`, `n` and `N` remain as notebook globals afterwards.
#
# anisotropic: multiplies the adversarial terms that cycleGAN computes on a randomly
# chosen slice index. Mirror-symmetric: index i and stack_num-1-i get the same
# value. The bump is centred on the two ends of the stack, so the weight is largest
# at the first/last slice (1.67 for stack_num=29) and falls to 1.0 at the centre.
anisotropic=numpy.ones(params['stack_num'])
n=(params['stack_num']+1)/2
n=int(n)   # size of the first half, centre slice included (15 for stack_num=29)


for i in range(0,n):
    anisotropic[i]=round(gaussian(1.6,0,0.2*i),2)+0.72

# Second half mirrors the first: distance is measured from the last slice.
for i in range(n,params['stack_num']):
    anisotropic[i]=round(gaussian(1.6,0,0.2*(params['stack_num']-1-i)),2)+0.72
    

# similarity: per-slice weights passed to cycle_loss (as p_similarity) in
# cycleGAN.backward_G; largest at both ends of the stack, smallest near the centre.
# NOTE(review): the halves are built differently — first half uses sigma 0.46 with no
# extra rounding; second half uses sigma 0.45, a linear tilt -0.01*(i-N) and rounding —
# so the profile is not mirror-symmetric. Confirm this asymmetry is intended.
similarity=numpy.ones(params['stack_num'])

for i in range(0,n):
    similarity[i]=gaussian(0.46,0,0.05*i)
    
for i in range(n,params['stack_num']):
    N=(params['stack_num']-1)/2   # centre index; loop-invariant, recomputed each pass
    similarity[i]=round(gaussian(0.45,0,0.05*(params['stack_num']-1-i))-0.01*(i-N),2)

In [None]:
class cycleGAN(nn.Module):
    """Blur (B) <-> sharp (S) CycleGAN with 3-D generators and 2-D slice discriminators.

    Components:
        generator_B2S / generator_S2B: ``Generator(1, params['ngf'], 1, params['num_resnet'])``.
        discriminator_S / discriminator_B: ``Discriminator(1, params['ndf'], 1)``; they
            score 2-D slices cut from the 5-D generator outputs.
        optimizer_G / optimizer_D: Adam over both generators / both discriminators
            with ``lr=learning_rate`` (other Adam settings left at PyTorch defaults).

    ``optimize_parameters`` runs one forward pass, then the generator step, then the
    discriminator step.  With ``is_training = False`` the same losses are computed but
    no optimiser step is taken, the image pools are bypassed, and nothing is added to
    the progress curves.  With ``save_losses = True`` the current losses are appended
    to the ``tr_*`` (training) or ``te_*`` (evaluation) loss dictionaries.

    NOTE(review): params['lrG'/'lrD'/'beta1'/'beta2'/'lambdaA'/'lambdaB'] are not read
    here; ``learning_rate`` and the LAMBDA_* constants below are used instead.
    """

    def __init__(self, learning_rate=2e-4):
        super().__init__()

        self.learning_rate = learning_rate

        #Loss function coeffs
        self.LAMBDA_CYCLE = 10.5
        self.LAMBDA_ID = 0.5
        # Weight of the S-discriminator terms on slices taken along axes 2/3 (and on
        # real29_S), versus (1 - LAMBDA_CROSS) for slices along the last axis.
        self.LAMBDA_CROSS = 0.35   #0.3works

        # Training-step counters and sampled loss histories for the progress plots.
        self.counter = 0
        self.counter1 = 0
        self.progress = []
        self.progress1 = []

        #Image pool parameter
        pool_size = 50

        #Discriminate test and train behaviour
        self.is_training = True
        self.save_losses = False

        #Initialize the image pools for both domains.
        self.fake_S_pool = ImagePool(pool_size)
        self.fake_B_pool = ImagePool(pool_size)
        # Kept for compatibility; backward_D does not draw from it.
        self.fake29_S_pool = ImagePool(pool_size)

        #Create dictionaries to save the entire loss progress
        self.tr_gen_loss_dict = {
            'loss_gen_S2B': [],
            'loss_gen_B2S': [],

            'loss_iden_S2B': [],
            'loss_iden_B2S': [],

            'loss_cycle_B2S2B': [],
            'loss_gen_total': []
        }
        self.tr_dis_loss_dict = {
            'loss_dis_B': [],
            'loss_dis_S': [],
            'loss_dis_total': []
        }
        self.te_gen_loss_dict = {
            'loss_gen_S2B': [],
            'loss_gen_B2S': [],

            'loss_iden_S2B': [],
            'loss_iden_B2S': [],

            'loss_cycle_B2S2B': [],
            'loss_gen_total': []
        }
        self.te_dis_loss_dict = {
            'loss_dis_B': [],
            'loss_dis_S': [],
            'loss_dis_total': []
        }

        self.im_list = []

        self.generator_S2B = Generator(1,params['ngf'],1,params['num_resnet'])
        self.generator_B2S = Generator(1,params['ngf'],1,params['num_resnet'])

        self.discriminator_S = Discriminator(1,params['ndf'],1)
        self.discriminator_B = Discriminator(1,params['ndf'],1)

        self.optimizer_G = torch.optim.Adam(itertools.chain(self.generator_S2B.parameters(), self.generator_B2S.parameters()), lr=self.learning_rate)
        self.optimizer_D = torch.optim.Adam(itertools.chain(self.discriminator_S.parameters(), self.discriminator_B.parameters()), lr=self.learning_rate)

    def forward(self, real_B):
        """Run the generators on a batch from domain B.

        Returns ``(fake_S2B, recon_S2B, fake_B2S, identity_S2B, identity_B2S, fake29_B2S)``:
            fake_B2S     = generator_B2S(real_B)
            fake_S2B     = generator_S2B(fake_B2S); recon_S2B is the same tensor
            identity_S2B = generator_S2B(real_B)
            identity_B2S = generator_B2S(fake_B2S)
            fake29_B2S   = fake_B2S[:, :, jj, :, :] for a random jj
        Also stores ``[real_B, fake_B2S, recon_S2B, identity_B2S]`` in ``self.im_list``.
        """
        fake_B2S = self.generator_B2S(real_B)
        fake_S2B = self.generator_S2B(fake_B2S)
        recon_S2B = fake_S2B

        jj = random.randint(0,params['input_size']-1)
        fake29_B2S = fake_B2S[:,:,jj,:,:]

        identity_S2B = self.generator_S2B(real_B)
        identity_B2S = self.generator_B2S(fake_B2S)

        self.im_list = [real_B,fake_B2S,recon_S2B,identity_B2S]

        return fake_S2B, recon_S2B,fake_B2S,identity_S2B,identity_B2S,fake29_B2S

    def backward_G(self, real_B, fake_S2B, fake_B2S, recon_S2B, identity_S2B, identity_B2S, fake29_B2S):
        """Compute the generator loss and, when training, take one optimizer_G step.

        total = adversarial S2B + adversarial B2S
                + LAMBDA_CYCLE * cycle_loss(recon_S2B, real_B, similarity)
                + LAMBDA_ID * (identity_S2B vs real_B + identity_B2S vs fake_B2S).
        Adversarial terms score random 2-D slices, weighted by ``anisotropic``.
        ``fake29_B2S`` is accepted for interface compatibility but unused.
        """
        if self.is_training:
            # Freeze the discriminators so only generator parameters get gradients.
            self.set_requires_grad([self.discriminator_B,self.discriminator_S], False)
            self.optimizer_G.zero_grad()

        loss_identity_S2B = identity_loss(identity_S2B, real_B)
        loss_identity_B2S = identity_loss(identity_B2S, fake_B2S)

        # S2B adversarial term: one random slice along the last axis.
        ii0 = random.randint(0,params['stack_num']-1)
        loss_gan_gen_S2B = anisotropic[ii0]*gan_loss(self.discriminator_B(fake_S2B[:,:,:,:,ii0]), True)

        # B2S adversarial term: a random last-axis slice, plus (weight LAMBDA_CROSS) a
        # random slice along axis 2 (pp == 1) or axis 3 (pp == 0). Both cross terms are
        # evaluated; the unselected one is multiplied by zero.
        jj = random.randint(0,params['input_size']-1)
        ii = random.randint(0,params['stack_num']-1)
        pp = random.randint(0,1)
        loss_gan_gen_B2S = (
            anisotropic[ii]*(1-self.LAMBDA_CROSS)*gan_loss(self.discriminator_S(fake_B2S[:,:,:,:,ii]), True)
            + self.LAMBDA_CROSS*0.5*pp*gan_loss(self.discriminator_S(fake_B2S[:,:,jj,:,:]), True)
            + self.LAMBDA_CROSS*0.5*(1-pp)*gan_loss(self.discriminator_S(fake_B2S[:,:,:,jj,:]), True)
        )

        loss_cycle_B2S2B = cycle_loss(recon_S2B, real_B, similarity)

        # Total generator loss
        loss_gen_total = loss_gan_gen_S2B + loss_gan_gen_B2S \
            + loss_cycle_B2S2B * self.LAMBDA_CYCLE \
            + (loss_identity_S2B+loss_identity_B2S) * self.LAMBDA_ID

        if self.is_training:
            # Sample every 10th training step for plot_progress. Evaluation steps are
            # not recorded so the curve is not mixed with test losses.
            self.counter += 1
            if self.counter % 10 == 0:
                self.progress.extend([loss_gen_total.item(), loss_gan_gen_S2B.item(),
                                      loss_gan_gen_B2S.item(), loss_cycle_B2S2B.item()])
            if self.counter % 1000 == 0:
                print("counter = ", self.counter)

            # Calculate gradients
            loss_gen_total.backward()
            self.optimizer_G.step()

        if self.save_losses:
            loss_dict = self.tr_gen_loss_dict if self.is_training else self.te_gen_loss_dict
            loss_dict['loss_gen_S2B'].append(loss_gan_gen_S2B.item())
            loss_dict['loss_gen_B2S'].append(loss_gan_gen_B2S.item())
            loss_dict['loss_iden_S2B'].append(loss_identity_S2B.item())
            loss_dict['loss_iden_B2S'].append(loss_identity_B2S.item())
            loss_dict['loss_cycle_B2S2B'].append(loss_cycle_B2S2B.item())
            loss_dict['loss_gen_total'].append(loss_gen_total.item())

    def backward_D(self, real_S, real_B, fake_S2B, fake_B2S,real29_S,fake29_B2S):
        """Compute both discriminator losses and, when training, take one optimizer_D step.

        During training the fakes are first passed through the history pools.
        ``fake29_B2S`` is accepted for interface compatibility but unused.
        """
        if self.is_training:
            # Pools only during training: querying them at evaluation time would insert
            # test-set fakes into the buffers later used to train the discriminators.
            fake_S2B = self.fake_B_pool.query(fake_S2B)
            fake_B2S = self.fake_S_pool.query(fake_B2S)
            self.set_requires_grad([self.discriminator_B,self.discriminator_S], True)
            self.optimizer_D.zero_grad()

        # Discriminator S on real sharp data: real_S and the vertical image real29_S.
        loss_gan_dis_S_real = (1-self.LAMBDA_CROSS)*gan_loss(self.discriminator_S(real_S), True) \
            + self.LAMBDA_CROSS*gan_loss(self.discriminator_S(real29_S), True)

        # Discriminator S on fakes: same slice scheme as the B2S term in backward_G.
        jj = random.randint(0,params['input_size']-1)
        ii = random.randint(0,params['stack_num']-1)
        pp = random.randint(0,1)
        loss_gan_dis_S_fake = (
            anisotropic[ii]*(1-self.LAMBDA_CROSS)*gan_loss(self.discriminator_S(fake_B2S[:,:,:,:,ii].detach()), False)
            + self.LAMBDA_CROSS*0.5*pp*gan_loss(self.discriminator_S(fake_B2S[:,:,jj,:,:].detach()), False)
            + self.LAMBDA_CROSS*0.5*(1-pp)*gan_loss(self.discriminator_S(fake_B2S[:,:,:,jj,:].detach()), False)
        )

        # Discriminator B should classify real_b as B
        ii = random.randint(0,params['stack_num']-1)
        loss_gan_dis_B_real = anisotropic[ii]*gan_loss(self.discriminator_B(real_B[:,:,:,:,ii]), True)
        # Discriminator B should classify generated fake_a2b as not B
        ii = random.randint(0,params['stack_num']-1)
        loss_gan_dis_B_fake = anisotropic[ii]*gan_loss(self.discriminator_B(fake_S2B[:,:,:,:,ii].detach()), False)

        # Total discriminator loss
        loss_dis_S = (loss_gan_dis_S_real + loss_gan_dis_S_fake) * 0.5
        loss_dis_B = (loss_gan_dis_B_real + loss_gan_dis_B_fake) * 0.5
        loss_dis_total = loss_dis_S + loss_dis_B

        if self.is_training:
            # Sample every 10th training step for plot_progress1 (evaluation excluded).
            self.counter1 += 1
            if self.counter1 % 10 == 0:
                self.progress1.extend([loss_dis_total.item(), loss_dis_S.item(), loss_dis_B.item()])
            if self.counter1 % 1000 == 0:
                print("counter1 = ", self.counter1)

            # Calculate gradients and update both discriminators
            loss_dis_total.backward()
            self.optimizer_D.step()

        # Save train and test losses separately
        if self.save_losses:
            loss_dict = self.tr_dis_loss_dict if self.is_training else self.te_dis_loss_dict
            loss_dict['loss_dis_B'].append(loss_dis_B.item())
            loss_dict['loss_dis_S'].append(loss_dis_S.item())
            loss_dict['loss_dis_total'].append(loss_dis_total.item())

    def set_requires_grad(self, nets, requires_grad=False):
        """Set ``requires_grad`` on every parameter of ``nets`` (a module or list; None entries skipped)."""
        if not isinstance(nets, list):
            nets = [nets]
        for net in nets:
            if net is not None:
                for param in net.parameters():
                    param.requires_grad = requires_grad

    def plot_progress(self):
        """Plot the sampled generator losses recorded by ``backward_G``.

        ``self.progress`` is a flat list filled four values at a time (total, S2B
        adversarial, B2S adversarial, cycle); it is reshaped to one row per sample so
        each loss gets its own column (previously all four were interleaved in one
        column and the other columns were empty).
        """
        if not self.progress:
            return

        df1 = pandas.DataFrame(np.reshape(self.progress, (-1, 4)),
                               columns=['loss_plot', 'loss_plot1', 'loss_plot2', 'loss_plot3'])

        fig, ax = plt.subplots()
        # stacked=False: the total already contains the components, so stacking the
        # areas would double-count them.
        ax = df1.plot.area(ax=ax, stacked=False)

        ax.autoscale()
        ax.set_ylim(0, None)
        ax.margins(x=0)

        plt.show()

    def plot_progress1(self):
        """Plot the sampled discriminator losses recorded by ``backward_D``.

        ``self.progress1`` is a flat list filled three values at a time (total, S, B);
        it is reshaped to one row per sample, one column per loss.
        """
        if not self.progress1:
            return

        df1 = pandas.DataFrame(np.reshape(self.progress1, (-1, 3)),
                               columns=['loss_plott', 'loss_plott1', 'loss_plott2'])

        fig, ax = plt.subplots()
        # stacked=False: the total already contains the components.
        ax = df1.plot.area(ax=ax, stacked=False)

        ax.autoscale()
        ax.set_ylim(0, None)
        ax.margins(x=0)

        plt.show()

    def optimize_parameters(self, real_S, real_B,real29_S):
        """One step: forward pass, generator update, then discriminator update."""
        fake_S2B, recon_S2B,fake_B2S,identity_S2B,identity_B2S,fake29_B2S = self.forward(real_B)
        self.backward_G(real_B, fake_S2B, fake_B2S, recon_S2B, identity_S2B, identity_B2S, fake29_B2S)
        self.backward_D(real_S, real_B, fake_S2B, fake_B2S,real29_S,fake29_B2S)

In [None]:
def train(train_dataset_S, train_dataset_B,test_dataset_S,test_dataset_B,train_dataset29_S,test_dataset29_S ,epochs, device):
 
    model = cycleGAN().to(device)

    for epoch in range(epochs):

        print('Epoch', epoch+1, '------------------')
        

        #Training
        temp = 1
        model.is_training = True
        
        for i, (data_S, data_B,data29_S) in enumerate(zip(train_dataset_S,train_dataset_B,train_dataset29_S)):
        
            #Input image data
            data_S = data_S.to(device)
            data29_S = data29_S.to(device)
            data_B = data_B.to(device)
      

            #Save loss values at the end of each epoch
            if temp == train_dataset_S.__len__():
                model.save_losses = True

            model.optimize_parameters(data_S, data_B,data29_S)

            temp = temp+1
            
        model.plot_progress()
        model.plot_progress1()

        print('Tr - Total Generator Loss:', np.round(model.tr_gen_loss_dict['loss_gen_total'][-1], decimals=3))
        print('Tr - Total Dicriminator Loss:', np.round(model.tr_dis_loss_dict['loss_dis_total'][-1], decimals=3))

        model.save_losses = False

        if epoch % 10 == 0:
            
            print_images(model.im_list, tr_im_save_dir, str(epoch), save_mode_on=True)

        #Test
        with torch.set_grad_enabled(False):

            temp = 1
            model.is_training = False
            for i, (data_S, data_B,data29_S) in enumerate(zip(test_dataset_S,test_dataset_B,test_dataset29_S)):

                data_S = data_S.to(device)
                data29_S = data29_S.to(device)
                data_B = data_B.to(device)

                if temp == test_dataset_S.__len__():
                    model.save_losses = True

                model.optimize_parameters(data_S, data_B,data29_S)

                temp = temp+1

            print('----')
            print('te - Total Generator Loss:', np.round(model.te_gen_loss_dict['loss_gen_total'][-1], decimals=3))
            print('te - Total Dicriminator Loss:', np.round(model.te_dis_loss_dict['loss_dis_total'][-1], decimals=3))

            model.save_losses = False

            print_images(model.im_list, te_im_save_dir, str(epoch), save_mode_on=True)
            
            test_real_Blur = test_blur_data.cuda()
    
            Results = model.forward(test_real_Blur)
    
        
            test_fake_Blur = Results[0]
    
            test_recon_Sharp = Results[4]
    
            test_fake_Sharp = Results[2]
    
            test_recon_Blur = Results[1]
        
            test_recon_Blur29 = Results[5]

            ##plot_train_result([test_real_Blur, test_recon_Blur], [test_fake_Sharp, test_recon_Sharp],epoch, save=False)
            plot_train_result29(test_recon_Blur29,epoch, save=False)
            
            if epoch  > params['save_epoch']:
                torch.save(model,os.path.join(model_save_dir,'model_%03d.pth'%epoch))
    
 
    #Save gen and disc loss values to respective csv files.
    df = pd.DataFrame.from_dict(model.tr_gen_loss_dict)
    df.to_csv(os.path.join(graph_save_dir, 'tr_gen_losses.csv'), index=False)
    
    df = pd.DataFrame.from_dict(model.tr_dis_loss_dict)
    df.to_csv(os.path.join(graph_save_dir, 'tr_dis_losses.csv'), index=False)
    
    #Save gen and disc loss values to respective csv files.
    df = pd.DataFrame.from_dict(model.te_gen_loss_dict)
    df.to_csv(os.path.join(graph_save_dir, 'te_gen_losses.csv'), index=False)
    
    df = pd.DataFrame.from_dict(model.te_dis_loss_dict)
    df.to_csv(os.path.join(graph_save_dir, 'te_dis_losses.csv'), index=False)

    #Save entire model architecture and params.
    torch.save(model, 'model.pth')

In [36]:
# Output directories (printed below) that train() uses for generated images,
# loss CSVs and model checkpoints.
(tr_im_save_dir,
 te_im_save_dir,
 graph_save_dir,
 model_save_dir) = manage_folders()

-----Directories to save the output-----
Train Fake Images: /Users/Honglei/Output/2023_05_03-18:22/generated-images/train
Val Fake Images: /Users/Honglei/Output/2023_05_03-18:22/generated-images/te
Losses: /Users/Honglei/Output/2023_05_03-18:22/loss-graphs
Model: /Users/Honglei/Output/2023_05_03-18:22/saved-models


In [None]:
%%time
train(train_data_loader_S,train_data_loader_B,test_data_loader_S,test_data_loader_B,train29_data_loader_S,test29_data_loader_S,200, device)

Epoch 1 ------------------
