In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
import torch.utils.data as data
import torch.nn.functional as F
from torchvision import transforms
from torch.nn.modules.utils import _ntuple

import pandas as pd
import pandas
import os , itertools
from PIL import Image
import matplotlib.pyplot as plt

import datetime

import numpy
import numpy as np
from numpy import load
import random

In [None]:
params = dict(
    # ---- data loading / augmentation ----
    batch_size=1,
    input_size=256,      # target size for transforms.Resize
    resize_scale=286,    # square resize before the random crop
    resize=128,          # NOTE(review): not referenced in the visible notebook
    crop_size=256,       # random crop size taken from the resized image
    fliplr=True,         # random horizontal flip with p=0.5
    # ---- schedule ----
    num_epochs=100,      # NOTE(review): train() is called with 200 epochs explicitly
    decay_epoch=100,     # NOTE(review): no LR decay is implemented in the visible notebook
    save_epoch=70,       # checkpoints are written for epochs after this one
    # ---- model ----
    ngf=32,              # number of generator filters
    ndf=64,              # number of discriminator filters
    num_resnet=6,        # number of resnet blocks
    # ---- optimisation ----
    lrG=0.0002,          # learning rate for generator
    lrD=0.0002,          # learning rate for discriminator
    beta1=0.5,           # beta1 for Adam optimizer
    beta2=0.999,         # beta2 for Adam optimizer
    # NOTE(review): cycleGAN uses its own LAMBDA_CYCLE; these two are not read by it.
    lambdaA=10,          # lambdaA for cycle loss
    lambdaB=10,          # lambdaB for cycle loss
)

In [None]:
# Root of the unpaired dataset; list its sub-folders (the splits read below).
data_dir = './vangogh2photo'
print(os.listdir(data_dir))

In [None]:
# Display whether PyTorch can see a CUDA device.
torch.cuda.is_available()

In [None]:
class ImagePool():
    """Buffer of previously seen images that may be swapped in for new ones.

    ``query`` returns a batch the same size as its input. While the pool holds
    fewer than ``pool_size`` images, every incoming image is stored and
    returned unchanged. Once the pool is full, each incoming image is either
    returned as-is (probability 0.5) or swapped with a randomly chosen stored
    image: the stored one is returned and the new one takes its place.

    Args:
        pool_size: capacity of the pool; ``<= 0`` disables pooling and
            ``query`` returns its input untouched.
    """

    def __init__(self, pool_size):
        self.pool_size = pool_size
        if self.pool_size > 0:
            self.num_imgs = 0
            self.images = []

    def query(self, images):
        """Return a batch mixing ``images`` with pooled ones (see class docstring).

        Args:
            images: tensor whose first dimension is the batch dimension.

        Returns:
            A tensor of the same shape, detached from the autograd graph
            (unless pooling is disabled, in which case ``images`` itself is returned).
        """
        # `<= 0` (not `== 0`): a negative size never created the pool state below.
        if self.pool_size <= 0:
            return images
        return_images = []
        for image in images.detach():
            image = image.unsqueeze(0)
            if self.num_imgs < self.pool_size:
                self.num_imgs += 1
                self.images.append(image)
                return_images.append(image)
            elif random.uniform(0, 1) > 0.5:
                random_id = random.randint(0, self.pool_size - 1)
                # Hand out a copy of the stored image, then replace it with the new one.
                return_images.append(self.images[random_id].clone())
                self.images[random_id] = image
            else:
                return_images.append(image)
        return torch.cat(return_images, 0)

In [None]:
def to_np(x):
    """Return tensor ``x`` as a NumPy array, detached from autograd and moved to the CPU."""
    detached = x.detach()
    return detached.cpu().numpy()

In [None]:
class DatasetFromFolder(data.Dataset):
    """Dataset that loads every file in ``<image_dir>/<subfolder>`` as an RGB PIL image.

    Preprocessing is applied in this order, each step only when configured:
    square resize to ``resize_scale``, random ``crop_size`` x ``crop_size``
    crop, random horizontal flip (p=0.5), then ``transform``.

    Args:
        image_dir: dataset root directory.
        subfolder: sub-directory holding the images (files are sorted by name).
        transform: optional callable applied last (e.g. a torchvision Compose).
        resize_scale: side length of the square resize; ``None`` skips resizing.
        crop_size: side length of the random crop; ``None`` skips cropping.
        fliplr: enable random horizontal flips.
    """

    def __init__(self, image_dir, subfolder='train', transform=None, resize_scale=None, crop_size=None, fliplr=False):
        super(DatasetFromFolder, self).__init__()
        self.input_path = os.path.join(image_dir, subfolder)
        self.image_filenames = sorted(os.listdir(self.input_path))
        self.transform = transform

        self.resize_scale = resize_scale
        self.crop_size = crop_size
        self.fliplr = fliplr

    def __getitem__(self, index):
        # Load Image; convert() returns a new image, so the file can be closed immediately.
        img_fn = os.path.join(self.input_path, self.image_filenames[index])
        with Image.open(img_fn) as src:
            img = src.convert('RGB')

        # preprocessing
        if self.resize_scale:
            img = img.resize((self.resize_scale, self.resize_scale), Image.BILINEAR)

        if self.crop_size:
            # Use the actual image size so cropping also works without resize_scale.
            # random.randint is inclusive on both ends, so the last offset that keeps
            # the crop inside the image is size - crop_size. For an image smaller
            # than the crop, offset 0 is used and PIL fills the overflow.
            width, height = img.size
            x = random.randint(0, max(0, width - self.crop_size))
            y = random.randint(0, max(0, height - self.crop_size))
            img = img.crop((x, y, x + self.crop_size, y + self.crop_size))

        if self.fliplr and random.random() < 0.5:
            img = img.transpose(Image.FLIP_LEFT_RIGHT)

        if self.transform is not None:
            img = self.transform(img)

        return img

    def __len__(self):
        return len(self.image_filenames)

In [None]:
class GetMyData(data.Dataset):
    """Dataset of multi-channel arrays stored as ``.npz`` files.

    Each file in ``<image_dir>/<subfolder>`` is read with ``np.load``, and its
    ``'arr_0'`` entry (the default key written by ``np.savez``) is used. The
    array is indexed as ``[row, col, channel]``; presumably (H, W, C), so TODO
    confirm against the data producer.

    Preprocessing, in order: random ``crop_size`` x ``crop_size`` crop,
    random left-right flip (p=0.5), then ``transform`` applied to each 2-D
    channel slice separately. The per-channel results are stacked along a new
    last axis with ``np.concatenate``. Without a transform, the (cropped)
    array is returned as a torch tensor.

    Args:
        image_dir: dataset root directory.
        subfolder: sub-directory holding the ``.npz`` files (sorted by name).
        transform: optional callable applied per channel slice.
        crop_size: side length of the random crop; ``None`` skips cropping.
        fliplr: enable random left-right flips.
        stack_num: number of leading channels passed through ``transform``;
            ``None`` means all channels of the array.
    """

    def __init__(self, image_dir, subfolder='train', transform=None, crop_size=None, fliplr=False, stack_num=None):
        super(GetMyData, self).__init__()
        self.input_path = os.path.join(image_dir, subfolder)
        self.image_filenames = sorted(os.listdir(self.input_path))
        self.transform = transform
        self.crop_size = crop_size
        self.fliplr = fliplr
        self.stack_num = stack_num

    def _load(self, index):
        """Read the ``'arr_0'`` array of file ``index`` and close the archive."""
        img_fn = os.path.join(self.input_path, self.image_filenames[index])
        # NOTE(review): allow_pickle=True lets a crafted file run arbitrary code
        # during loading; only point this dataset at trusted data.
        archive = np.load(img_fn, mmap_mode='r', allow_pickle=True)
        try:
            return numpy.array(archive['arr_0'])
        finally:
            close = getattr(archive, 'close', None)
            if close is not None:
                close()

    def __getitem__(self, index):
        imarray = self._load(index)

        if self.crop_size:
            # Bounds come from the array itself. random.randint is inclusive, so the
            # largest offset that keeps the crop inside the array is size - crop_size.
            x = random.randint(0, max(0, imarray.shape[0] - self.crop_size))
            y = random.randint(0, max(0, imarray.shape[1] - self.crop_size))
            imarray = imarray[x:x + self.crop_size, y:y + self.crop_size, ...]

        if self.fliplr and random.random() < 0.5:
            imarray = imarray[:, ::-1, ...]

        # Slicing/flipping yields strided views; torch.from_numpy rejects negative strides.
        imarray = np.ascontiguousarray(imarray)

        if self.transform is None:
            return torch.from_numpy(imarray)

        # Transform each channel slice of the *cropped* array and stack them on a new last axis.
        num_channels = self.stack_num if self.stack_num is not None else imarray.shape[2]
        slices = [np.asarray(self.transform(imarray[:, :, i]))[..., np.newaxis]
                  for i in range(num_channels)]
        return np.concatenate(slices, axis=-1)

    def __len__(self):
        return len(self.image_filenames)

In [None]:
def plot_train_result(real_image, gen_image, recon_image, epoch, save=False,  show=True, fig_size=(15, 15)):
    """Show a 2x3 grid: real | generated | reconstructed, for two samples.

    Args:
        real_image, gen_image, recon_image: sequences of at least two tensors;
            element 0 fills the top row and element 1 the bottom row. Each
            tensor must squeeze to (C, H, W).
        epoch: 0-based epoch index (displayed and used in the file name as epoch + 1).
        save: write the figure to ``Result_epoch_<epoch+1>.png`` in the CWD.
        show: display the figure; otherwise it is closed.
        fig_size: matplotlib figure size in inches.
    """
    fig, axes = plt.subplots(2, 3, figsize=fig_size)
    imgs = [to_np(real_image[0]), to_np(gen_image[0]), to_np(recon_image[0]),
            to_np(real_image[1]), to_np(gen_image[1]), to_np(recon_image[1])]
    for ax, img in zip(axes.flatten(), imgs):
        ax.axis('off')
        img = img.squeeze()
        # Min-max scale to 0-255. A constant image has zero range, so fall back
        # to 1 to avoid dividing by zero (it then renders as all zeros).
        value_range = img.max() - img.min()
        if value_range == 0:
            value_range = 1
        img = (((img - img.min()) * 255) / value_range).transpose(1, 2, 0).astype(np.uint8)
        ax.imshow(img, cmap=None, aspect='equal')
    plt.subplots_adjust(wspace=0, hspace=0)

    title = 'Epoch {0}'.format(epoch + 1)
    fig.text(0.5, 0.04, title, ha='center')

    # save figure (explicitly this figure, not whatever pyplot considers current)
    if save:
        save_fn = 'Result_epoch_{:d}'.format(epoch+1) + '.png'
        fig.savefig(save_fn)

    if show:
        plt.show()
    else:
        plt.close(fig)

In [None]:
def print_images(im_list, save_dir, epoch_num, save_mode_on=True):
    """Plot up to four image tensors side by side and save or show the figure.

    PyTorch conv layers use (N, C, H, W) while ``imshow`` expects (H, W, C).
    Each tensor is therefore detached, moved to the CPU, reduced to its first
    sample if it has a batch dimension, and permuted to (H, W, C). A
    single-channel image is drawn as 2-D.

    Args:
        im_list: tensors in the order Real-B, Fake-A, Recon-B, Identity-A,
            with values in [-1, 1] (Normalize(0.5, 0.5) inputs / tanh outputs).
        save_dir: directory for ``epoch-<epoch_num>.png`` when saving.
        epoch_num: label used in the file name.
        save_mode_on: save and close the figure if True, otherwise show it.
    """
    # A list that holds the necessary plot titles
    titles = ['Real-B', 'Fake-A (B->A)', 'Recon-B (B->A->B)', 'Identity-A (A->A)']

    fig, axarr = plt.subplots(1, 4, figsize=(12, 12))

    for ax, title, tensor in zip(axarr, titles, im_list):
        img = tensor.detach().cpu()
        if img.dim() == 4:
            img = img[0]  # first sample of the batch
        if img.dim() == 3:
            # (C, H, W) -> (H, W, C). Using .T would reverse all axes and show the image transposed.
            img = img[0] if img.shape[0] == 1 else img.permute(1, 2, 0)
        # Map [-1, 1] to [0, 1] and clip so imshow does not clip with a warning.
        img = ((img + 1) / 2.0).clamp(0.0, 1.0)
        ax.axis('off')
        ax.imshow(img.numpy())
        ax.set_title(title, fontweight="bold")

    fig.tight_layout()

    if save_mode_on:
        fig.savefig(os.path.join(save_dir, 'epoch-{}.png'.format(epoch_num)))
        plt.close(fig)
    else:
        plt.show()

In [None]:
class ResnetBlock(torch.nn.Module):
    """Residual block: ``x + F(x)``.

    F is reflection-pad(1) -> conv -> InstanceNorm -> ReLU -> reflection-pad(1)
    -> conv -> InstanceNorm. The skip connection requires F to preserve the
    shape, which holds for the defaults (3x3 kernel, stride 1, padding 0 after
    the explicit 1-pixel reflection pad).

    Args:
        num_filter: input and output channel count.
        kernel_size, stride, padding: forwarded to both Conv2d layers.
    """
    def __init__(self, num_filter, kernel_size=3, stride=1, padding=0):
        super(ResnetBlock, self).__init__()
        conv1 = torch.nn.Conv2d(num_filter, num_filter, kernel_size, stride, padding)
        conv2 = torch.nn.Conv2d(num_filter, num_filter, kernel_size, stride, padding)
        # InstanceNorm2d (no affine params) is stateless, so one instance can be reused.
        bn = torch.nn.InstanceNorm2d(num_filter)
        relu = torch.nn.ReLU(True)
        pad = torch.nn.ReflectionPad2d(1)

        # Layer order (and hence state_dict keys) kept stable for saved checkpoints.
        self.resnet_block = torch.nn.Sequential(
            pad,
            conv1,
            bn,
            relu,
            pad,
            conv2,
            bn
            )

    def forward(self, x):
        # The identity shortcut is what makes this a residual block.
        return x + self.resnet_block(x)

In [None]:
class ConvBlock(torch.nn.Module):
    """Conv2d -> optional normalization -> activation.

    Despite the ``batch_norm`` / ``bn`` names, the normalization is
    ``InstanceNorm2d``.

    Args:
        input_size, output_size: in/out channels.
        kernel_size, stride, padding: Conv2d geometry (defaults halve H and W).
        activation: one of ``'relu'``, ``'lrelu'`` (slope 0.2), ``'tanh'``, ``'no_act'``.
        batch_norm: apply the normalization layer after the conv.

    Raises:
        ValueError: if ``activation`` is not one of the supported names.
    """
    ACTIVATIONS = ('relu', 'lrelu', 'tanh', 'no_act')

    def __init__(self, input_size, output_size, kernel_size=3, stride=2, padding=1, activation='relu', batch_norm=True):
        super(ConvBlock, self).__init__()
        # Fail at construction time instead of returning None from forward().
        if activation not in self.ACTIVATIONS:
            raise ValueError('unknown activation {!r}; expected one of {}'.format(activation, self.ACTIVATIONS))
        self.conv = torch.nn.Conv2d(input_size, output_size, kernel_size, stride, padding)
        self.batch_norm = batch_norm
        self.bn = torch.nn.InstanceNorm2d(output_size)
        self.activation = activation
        self.relu = torch.nn.ReLU(True)
        self.lrelu = torch.nn.LeakyReLU(0.2, True)
        self.tanh = torch.nn.Tanh()

    def forward(self, x):
        out = self.conv(x)
        if self.batch_norm:
            out = self.bn(out)

        if self.activation == 'relu':
            return self.relu(out)
        if self.activation == 'lrelu':
            return self.lrelu(out)
        if self.activation == 'tanh':
            return self.tanh(out)
        return out  # 'no_act'

In [None]:
class DeconvBlock(torch.nn.Module):
    """ConvTranspose2d -> optional normalization -> activation.

    Despite the ``batch_norm`` / ``bn`` names, the normalization is
    ``InstanceNorm2d``. With the defaults (kernel 3, stride 2, padding 1,
    output_padding 1) H and W are doubled.

    Args:
        input_size, output_size: in/out channels.
        kernel_size, stride, padding, output_padding: ConvTranspose2d geometry.
        activation: one of ``'relu'``, ``'lrelu'`` (slope 0.2), ``'tanh'``, ``'no_act'``.
        batch_norm: apply the normalization layer after the transposed conv.

    Raises:
        ValueError: if ``activation`` is not one of the supported names.
    """
    ACTIVATIONS = ('relu', 'lrelu', 'tanh', 'no_act')

    def __init__(self, input_size, output_size, kernel_size=3, stride=2, padding=1, output_padding=1, activation='relu', batch_norm=True):
        super(DeconvBlock, self).__init__()
        if activation not in self.ACTIVATIONS:
            raise ValueError('unknown activation {!r}; expected one of {}'.format(activation, self.ACTIVATIONS))
        self.deconv = torch.nn.ConvTranspose2d(input_size, output_size, kernel_size, stride, padding, output_padding)
        self.batch_norm = batch_norm
        self.bn = torch.nn.InstanceNorm2d(output_size)
        self.activation = activation
        self.relu = torch.nn.ReLU(True)
        # These were referenced in forward() but never defined (AttributeError).
        self.lrelu = torch.nn.LeakyReLU(0.2, True)
        self.tanh = torch.nn.Tanh()

    def forward(self, x):
        out = self.deconv(x)
        if self.batch_norm:
            out = self.bn(out)

        if self.activation == 'relu':
            return self.relu(out)
        if self.activation == 'lrelu':
            return self.lrelu(out)
        if self.activation == 'tanh':
            return self.tanh(out)
        return out  # 'no_act'

In [None]:
class Generator(torch.nn.Module):
    """Encoder / residual-bottleneck / decoder image generator.

    Layout: reflection pad 3 -> 7x7 conv -> two stride-2 convs (channels x2
    each) -> ``num_resnet`` ResnetBlocks at ``num_filter * 4`` channels -> two
    stride-2 transposed convs -> reflection pad 3 -> 7x7 conv with tanh.

    Args:
        input_dim: channels of the input image.
        num_filter: channels of the first conv layer.
        output_dim: channels of the output image.
        num_resnet: number of ResnetBlocks in the bottleneck.
    """
    def __init__(self, input_dim, num_filter, output_dim, num_resnet):
        super(Generator, self).__init__()

        # Reflection padding, shared by the first and last 7x7 convs.
        self.pad = torch.nn.ReflectionPad2d(3)
        # Encoder
        self.conv1 = ConvBlock(input_dim, num_filter, kernel_size=7, stride=1, padding=0)
        self.conv2 = ConvBlock(num_filter, num_filter*2)
        self.conv3 = ConvBlock(num_filter*2, num_filter*4)
        # Resnet blocks
        self.resnet_blocks = torch.nn.Sequential(*[ResnetBlock(num_filter*4) for _ in range(num_resnet)])
        # Decoder
        self.deconv1 = DeconvBlock(num_filter*4, num_filter*2)
        self.deconv2 = DeconvBlock(num_filter*2, num_filter)
        self.deconv3 = ConvBlock(num_filter, output_dim, kernel_size=7, stride=1, padding=0, activation='tanh', batch_norm=False)

    def forward(self, x):
        # Encoder
        enc1 = self.conv1(self.pad(x))
        enc2 = self.conv2(enc1)
        enc3 = self.conv3(enc2)
        # Resnet blocks
        res = self.resnet_blocks(enc3)
        # Decoder
        dec1 = self.deconv1(res)
        dec2 = self.deconv2(dec1)
        out = self.deconv3(self.pad(dec2))
        return out

    def normal_weight_init(self, mean=0.0, std=0.02):
        """Draw every Conv2d / ConvTranspose2d weight from N(mean, std) and zero its bias.

        Walks ``self.modules()`` (recursively) so that convs nested inside the
        ResnetBlock ``Sequential`` containers are reached too.
        """
        for m in self.modules():
            if isinstance(m, (torch.nn.Conv2d, torch.nn.ConvTranspose2d)):
                torch.nn.init.normal_(m.weight, mean, std)
                if m.bias is not None:
                    torch.nn.init.constant_(m.bias, 0)

In [None]:
class Discriminator(torch.nn.Module):
    """Convolutional discriminator that outputs a spatial map of raw scores.

    Five 4x4 ConvBlocks: three with stride 2, then two with stride 1. All but
    the last use LeakyReLU(0.2). The first and last skip normalization, and
    the last has no activation.

    Args:
        input_dim: channels of the input image.
        num_filter: channels of the first conv (doubled in each of the next three).
        output_dim: channels of the score map.
    """
    def __init__(self, input_dim, num_filter, output_dim):
        super(Discriminator, self).__init__()
        conv1 = ConvBlock(input_dim, num_filter, kernel_size=4, stride=2, padding=1, activation='lrelu', batch_norm=False)
        conv2 = ConvBlock(num_filter, num_filter*2, kernel_size=4, stride=2, padding=1, activation='lrelu')
        conv3 = ConvBlock(num_filter*2, num_filter*4, kernel_size=4, stride=2, padding=1, activation='lrelu')
        conv4 = ConvBlock(num_filter*4, num_filter*8, kernel_size=4, stride=1, padding=1, activation='lrelu')
        conv5 = ConvBlock(num_filter*8, output_dim, kernel_size=4, stride=1, padding=1, activation='no_act', batch_norm=False)
        self.conv_blocks = torch.nn.Sequential(
            conv1,
            conv2,
            conv3,
            conv4,
            conv5
            )

    def forward(self, x):
        return self.conv_blocks(x)

    def normal_weight_init(self, mean=0.0, std=0.02):
        """Draw every Conv2d weight from N(mean, std) and zero its bias.

        Walks ``self.modules()`` recursively. Iterating ``children()`` only
        yields the ``Sequential`` wrapper and would initialise nothing.
        """
        for m in self.modules():
            if isinstance(m, torch.nn.Conv2d):
                torch.nn.init.normal_(m.weight, mean, std)
                if m.bias is not None:
                    torch.nn.init.constant_(m.bias, 0)

In [None]:
# Variant whose timestamp uses only '_' separators (no ':'), so the folder name
# is also valid on Windows.

def manage_folders_mac():
    """Create a timestamped run folder under ./Output with its sub-folders.

    Layout::

        ./Output/<YYYY_mm_dd_HH_MM>/
            loss-graphs/
            generated-images/train/
            generated-images/te/
            saved-models/

    Returns:
        (tr_im_save_dir, te_im_save_dir, graph_save_dir, model_save_dir)

    Raises:
        FileExistsError: if a run folder with the same minute-resolution
            timestamp already exists, which keeps an earlier run's outputs intact.
    """
    currentDT = datetime.datetime.now().strftime("%Y_%m_%d_%H_%M")

    output_root = os.path.join('./', 'Output')
    os.makedirs(output_root, exist_ok=True)

    output_folder = os.path.join(output_root, currentDT)
    os.mkdir(output_folder)

    graph_save_dir = os.path.join(output_folder, 'loss-graphs')
    im_save_dir = os.path.join(output_folder, 'generated-images')
    tr_im_save_dir = os.path.join(im_save_dir, 'train')
    te_im_save_dir = os.path.join(im_save_dir, 'te')
    model_save_dir = os.path.join(output_folder, 'saved-models')

    # makedirs also creates im_save_dir on the way to its train/te children,
    # and exist_ok avoids the check-then-create race of isdir() + mkdir().
    for folder in (graph_save_dir, tr_im_save_dir, te_im_save_dir, model_save_dir):
        os.makedirs(folder, exist_ok=True)

    print('-----Directories to save the output-----\nTrain Fake Images: {}\nTest Fake Images: {}\nLosses: {}\nModel: {}'.format(tr_im_save_dir, te_im_save_dir, graph_save_dir, model_save_dir))

    return tr_im_save_dir, te_im_save_dir, graph_save_dir, model_save_dir

In [None]:
# Linux/macOS variant: the timestamp contains ':' and is not a valid Windows folder name.
def manage_folders():
    """Create a timestamped run folder under ``<cwd>/Output`` with its sub-folders.

    Layout::

        <cwd>/Output/<YYYY_mm_dd-HH:MM>/
            loss-graphs/
            generated-images/train/
            generated-images/te/
            saved-models/

    Returns:
        (tr_im_save_dir, te_im_save_dir, graph_save_dir, model_save_dir)

    Raises:
        FileExistsError: if a run folder with the same minute-resolution
            timestamp already exists, which keeps an earlier run's outputs intact.
    """
    currentDT = datetime.datetime.now().strftime("%Y_%m_%d-%H:%M")

    output_root = os.path.join(os.getcwd(), 'Output')
    os.makedirs(output_root, exist_ok=True)

    output_folder = os.path.join(output_root, currentDT)
    os.mkdir(output_folder)

    graph_save_dir = os.path.join(output_folder, 'loss-graphs')
    im_save_dir = os.path.join(output_folder, 'generated-images')
    tr_im_save_dir = os.path.join(im_save_dir, 'train')
    te_im_save_dir = os.path.join(im_save_dir, 'te')
    model_save_dir = os.path.join(output_folder, 'saved-models')

    # makedirs also creates im_save_dir on the way to its train/te children,
    # and exist_ok avoids the check-then-create race of isdir() + mkdir().
    for folder in (graph_save_dir, tr_im_save_dir, te_im_save_dir, model_save_dir):
        os.makedirs(folder, exist_ok=True)

    print('-----Directories to save the output-----\nTrain Fake Images: {}\nVal Fake Images: {}\nLosses: {}\nModel: {}'.format(tr_im_save_dir, te_im_save_dir, graph_save_dir, model_save_dir))

    return tr_im_save_dir, te_im_save_dir, graph_save_dir, model_save_dir

In [None]:
# Tensor pipeline applied after the dataset's own resize / crop / flip.
_transform_steps = [
    transforms.Resize(size=params['input_size']),                      # shorter side -> input_size
    transforms.ToTensor(),                                             # PIL image -> float tensor in [0, 1]
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),  # [0, 1] -> [-1, 1] per channel
]
transform = transforms.Compose(_transform_steps)

In [None]:
# One dataset + loader per domain (A/B) and split (train/test).
# NOTE(review): the test splits also get random crop/flip augmentation and shuffling;
# confirm this is intended for evaluation.
train_data_A = DatasetFromFolder(data_dir, subfolder='trainA', transform=transform,
                                resize_scale=params['resize_scale'], crop_size=params['crop_size'], fliplr=params['fliplr'])
train_data_loader_A = torch.utils.data.DataLoader(dataset=train_data_A, batch_size=params['batch_size'], shuffle=True)


train_data_B = DatasetFromFolder(data_dir, subfolder='trainB', transform=transform,
                                resize_scale=params['resize_scale'], crop_size=params['crop_size'], fliplr=params['fliplr'])
train_data_loader_B = torch.utils.data.DataLoader(dataset=train_data_B, batch_size=params['batch_size'], shuffle=True)


test_data_A = DatasetFromFolder(data_dir, subfolder='testA', transform=transform,
                                resize_scale=params['resize_scale'], crop_size=params['crop_size'], fliplr=params['fliplr'])
# Must wrap the *test* dataset; wrapping train_data_A would evaluate on training images.
test_data_loader_A = torch.utils.data.DataLoader(dataset=test_data_A, batch_size=params['batch_size'], shuffle=True)


test_data_B = DatasetFromFolder(data_dir, subfolder='testB', transform=transform,
                                resize_scale=params['resize_scale'], crop_size=params['crop_size'], fliplr=params['fliplr'])
test_data_loader_B = torch.utils.data.DataLoader(dataset=test_data_B, batch_size=params['batch_size'], shuffle=True)


In [None]:

# Spot-check one sample per split; each print shows the sample tensor's shape.
tryimgA, tryimgB = train_data_A[99], train_data_B[9]
tryimgAt, tryimgBt = test_data_A[19], test_data_B[9]
for sample in (tryimgA, tryimgB, tryimgAt, tryimgBt):
    print(sample.shape)


In [None]:
# Prefer the first CUDA GPU when available, otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print("device", device)

In [None]:
# Fixed test samples used for per-epoch result plots; unsqueeze adds the batch dim (1xCxHxW).
# NOTE(review): the dataset applies a random crop/flip, so these differ between runs.
test_real_A_data = test_data_A[11].unsqueeze(0)
test_real_B_data = test_data_B[96].unsqueeze(0)

In [None]:
for sample in (test_real_A_data, test_real_B_data):
    print(sample.shape)

# Two untrained generators, used here only to confirm output shapes.
generator_A2B = Generator(3, params['ngf'], 3, params['num_resnet'])
generator_B2A = Generator(3, params['ngf'], 3, params['num_resnet'])
fake_B_check = generator_A2B(test_real_A_data)
print(fake_B_check.shape)
print(generator_B2A(fake_B_check).shape)

In [None]:
def cycle_loss(reconstructed_images, real_images):
    """Mean absolute (L1) error between reconstructed images and the originals."""
    return F.l1_loss(reconstructed_images, real_images)


def identity_loss(inputs, real_images):
    """Mean absolute (L1) error between a generator's output and the image it was given."""
    return F.l1_loss(inputs, real_images)


def gan_loss(inputs, is_real):
    """Least-squares GAN loss: MSE between discriminator outputs and an all-ones / all-zeros target.

    Args:
        inputs: discriminator output tensor.
        is_real: compare against ones if True (real), zeros otherwise (fake).

    The target is created with ``ones_like`` / ``zeros_like``, so it matches
    the device and dtype of ``inputs``. No global device or CPU->GPU copy is needed.
    """
    target = torch.ones_like(inputs) if is_real else torch.zeros_like(inputs)
    return F.mse_loss(inputs, target)

In [None]:
# Smoke test: L1 cycle loss between an untrained generator's output and its input.
test = cycle_loss(
    generator_A2B(test_real_A_data),
    test_real_A_data,
)
print(test)

In [None]:
class cycleGAN(nn.Module):
    """CycleGAN container: two generators, two discriminators, optimizers and loss logs.

    ``generator_A2B`` maps domain A to B and ``generator_B2A`` maps B to A;
    ``discriminator_A`` / ``discriminator_B`` score images of domain A / B.
    Call :meth:`optimize_parameters` once per batch pair:

    * ``is_training`` True: generators then discriminators are updated.
      False: losses are only computed.
    * ``save_losses`` True: the batch's losses are appended to the
      ``tr_*_loss_dict`` (training) or ``te_*_loss_dict`` (test) lists.

    Args:
        learning_rate: Adam learning rate for every optimizer.
    """

    def __init__(self, learning_rate=2e-4):
        super().__init__()

        self.learning_rate = learning_rate

        # Loss function coeffs
        self.LAMBDA_CYCLE = 10.5
        self.LAMBDA_ID = 0.5
        self.LAMBDA_CROSS = 0.35   # 0.3 works; NOTE(review): not used by any loss in this class

        # NOTE(review): the optimizers below read betas from params; these two are informational only.
        self.beta1 = 0.5     # beta1 for Adam optimizer
        self.beta2 = 0.999   # beta2 for Adam optimizer

        # Call counters and loss samples (every 10th call) consumed by plot_progress.
        self.counter = 0
        self.counter1 = 0
        self.progress = []
        self.progress1 = []

        # Image pool parameter
        pool_size = 50

        # Discriminate test and train behaviour
        self.is_training = True
        self.save_losses = False

        # Initialize the image pools for both domains.
        self.fake_A_pool = ImagePool(pool_size)
        self.fake_B_pool = ImagePool(pool_size)

        # Loss histories, appended to while save_losses is True.
        self.tr_gen_loss_dict = self._new_gen_loss_dict()
        self.tr_dis_loss_dict = self._new_dis_loss_dict()
        self.te_gen_loss_dict = self._new_gen_loss_dict()
        self.te_dis_loss_dict = self._new_dis_loss_dict()

        # Tensors from the most recent forward(): real B, fake A, recon B, identity A.
        self.im_list = []

        self.generator_A2B = Generator(3, params['ngf'], 3, params['num_resnet'])
        self.generator_B2A = Generator(3, params['ngf'], 3, params['num_resnet'])

        self.discriminator_A = Discriminator(3, params['ndf'], 1)
        self.discriminator_B = Discriminator(3, params['ndf'], 1)

        betas = (params['beta1'], params['beta2'])
        # NOTE(review): optimizer_total is never stepped; optimizer_G drives the generators.
        self.optimizer_total = torch.optim.Adam(itertools.chain(self.generator_A2B.parameters(), self.generator_B2A.parameters()), lr=self.learning_rate, betas=betas)
        self.optimizer_G = torch.optim.Adam(itertools.chain(self.generator_A2B.parameters(), self.generator_B2A.parameters()), lr=self.learning_rate, betas=betas)
        self.optimizer_D = torch.optim.Adam(itertools.chain(self.discriminator_A.parameters(), self.discriminator_B.parameters()), lr=self.learning_rate, betas=betas)

    @staticmethod
    def _new_gen_loss_dict():
        """Empty generator-loss history, one list per loss name."""
        return {
            'loss_gen_A2B': [],
            'loss_gen_B2A': [],
            'loss_iden_A2B': [],
            'loss_iden_B2A': [],
            'loss_cycle_B2A2B': [],
            'loss_gen_total': [],
        }

    @staticmethod
    def _new_dis_loss_dict():
        """Empty discriminator-loss history, one list per loss name."""
        return {
            'loss_dis_B': [],
            'loss_dis_A': [],
            'loss_dis_total': [],
        }

    def forward(self, real_A, real_B):
        """Run both generators on both domains.

        Returns:
            (fake_A2B, recon_A2B, fake_B2A, recon_B2A, identity_A2B, identity_B2A), where
            recon_A2B = G_A2B(G_B2A(real_B)) reconstructs B and
            recon_B2A = G_B2A(G_A2B(real_A)) reconstructs A.
        """
        fake_B2A = self.generator_B2A(real_B)
        fake_A2B = self.generator_A2B(real_A)

        recon_A2B = self.generator_A2B(fake_B2A)
        recon_B2A = self.generator_B2A(fake_A2B)

        identity_A2B = self.generator_A2B(real_B)
        identity_B2A = self.generator_B2A(real_A)

        self.im_list = [real_B, fake_B2A, recon_A2B, identity_B2A]

        return fake_A2B, recon_A2B, fake_B2A, recon_B2A, identity_A2B, identity_B2A

    def backward_G(self, real_A, real_B, fake_A2B, fake_B2A, recon_A2B, recon_B2A, identity_A2B, identity_B2A):
        """Compute the generator losses and, when training, step ``optimizer_G``."""
        if self.is_training:
            # Freeze the discriminators so this backward pass only fills generator grads.
            self.set_requires_grad([self.discriminator_B, self.discriminator_A], False)
            self.optimizer_G.zero_grad()

        loss_identity_A2B = identity_loss(identity_A2B, real_B)
        loss_identity_B2A = identity_loss(identity_B2A, real_A)

        # Generators try to make the discriminators label their fakes as real.
        loss_gan_gen_B2A = gan_loss(self.discriminator_A(fake_B2A), True)
        loss_gan_gen_A2B = gan_loss(self.discriminator_B(fake_A2B), True)

        # NOTE(review): only the B->A->B cycle is penalized; recon_B2A (A->B->A) is
        # computed but unused. Confirm this one-sided cycle loss is intentional.
        loss_cycle_B2A2B = cycle_loss(recon_A2B, real_B)

        # Total generator loss
        loss_gen_total = loss_gan_gen_A2B + loss_gan_gen_B2A \
            + loss_cycle_B2A2B * self.LAMBDA_CYCLE \
            + (loss_identity_A2B + loss_identity_B2A) * self.LAMBDA_ID

        self.counter += 1
        if self.counter % 10 == 0:
            # Four values per sample, in the order plot_progress expects.
            self.progress.extend(loss.item() for loss in
                                 (loss_gen_total, loss_gan_gen_A2B, loss_gan_gen_B2A, loss_cycle_B2A2B))
        if self.counter % 1000 == 0:
            print("counter = ", self.counter)

        if self.is_training:
            # Calculate gradients
            loss_gen_total.backward()
            self.optimizer_G.step()

        if self.save_losses:
            history = self.tr_gen_loss_dict if self.is_training else self.te_gen_loss_dict
            history['loss_gen_A2B'].append(loss_gan_gen_A2B.item())
            history['loss_gen_B2A'].append(loss_gan_gen_B2A.item())
            history['loss_iden_A2B'].append(loss_identity_A2B.item())
            # Previously never recorded, leaving this list shorter than the others.
            history['loss_iden_B2A'].append(loss_identity_B2A.item())
            history['loss_cycle_B2A2B'].append(loss_cycle_B2A2B.item())
            history['loss_gen_total'].append(loss_gen_total.item())

    def backward_D(self, real_A, real_B, fake_A2B, fake_B2A):
        """Compute both discriminator losses (MSE via gan_loss) and, when training, step ``optimizer_D``.

        Fakes are routed through the image pools first, so a discriminator may
        be shown an earlier generator output instead of the current one.
        """
        fake_A2B = self.fake_B_pool.query(fake_A2B)
        fake_B2A = self.fake_A_pool.query(fake_B2A)

        if self.is_training:
            self.set_requires_grad([self.discriminator_B, self.discriminator_A], True)
            self.optimizer_D.zero_grad()

        loss_gan_dis_A_real = gan_loss(self.discriminator_A(real_A), True)
        loss_gan_dis_A_fake = gan_loss(self.discriminator_A(fake_B2A.detach()), False)

        loss_gan_dis_B_real = gan_loss(self.discriminator_B(real_B), True)
        loss_gan_dis_B_fake = gan_loss(self.discriminator_B(fake_A2B.detach()), False)

        # Total discriminator loss
        loss_dis_A = (loss_gan_dis_A_real + loss_gan_dis_A_fake) * 0.5
        loss_dis_B = (loss_gan_dis_B_real + loss_gan_dis_B_fake) * 0.5
        loss_dis_total = loss_dis_A + loss_dis_B

        self.counter1 += 1
        if self.counter1 % 10 == 0:
            self.progress1.extend(loss.item() for loss in (loss_dis_total, loss_dis_A, loss_dis_B))
        if self.counter1 % 1000 == 0:
            print("counter1 = ", self.counter1)

        if self.is_training:
            # Calculate gradients and update D_A and D_B's weights
            loss_dis_total.backward()
            self.optimizer_D.step()

        # Save train and test losses separately
        if self.save_losses:
            history = self.tr_dis_loss_dict if self.is_training else self.te_dis_loss_dict
            history['loss_dis_B'].append(loss_dis_B.item())
            history['loss_dis_A'].append(loss_dis_A.item())
            history['loss_dis_total'].append(loss_dis_total.item())

    def set_requires_grad(self, nets, requires_grad=False):
        """Set ``requires_grad`` on every parameter of ``nets`` (a module or list; ``None`` entries skipped)."""
        if not isinstance(nets, list):
            nets = [nets]
        for net in nets:
            if net is not None:
                for param in net.parameters():
                    param.requires_grad = requires_grad

    def plot_progress(self):
        """Plot the generator losses sampled every 10th ``backward_G`` call.

        ``self.progress`` is a flat list filled four values at a time, in the
        order: total generator loss, A2B GAN loss, B2A GAN loss, B->A->B cycle loss.
        """
        columns = ['loss_plot', 'loss_plot1', 'loss_plot2', 'loss_plot3']
        width = len(columns)
        rows = [self.progress[start:start + width]
                for start in range(0, len(self.progress) - width + 1, width)]
        df = pandas.DataFrame(rows, columns=columns)
        if df.empty:
            print('No generator losses recorded yet.')
            return

        fig, ax = plt.subplots()
        # Separate lines rather than a stacked area: the total already contains
        # the other terms, so stacking would double-count them.
        df.plot(ax=ax)

        ax.autoscale()
        ax.set_ylim(0, None)
        ax.margins(x=0)

        plt.show()

    def optimize_parameters(self, real_A, real_B):
        """One step: forward pass, generator update, then discriminator update (with pooled fakes)."""
        fake_A2B, recon_A2B, fake_B2A, recon_B2A, identity_A2B, identity_B2A = self.forward(real_A, real_B)
        self.backward_G(real_A, real_B, fake_A2B, fake_B2A, recon_A2B, recon_B2A, identity_A2B, identity_B2A)
        self.backward_D(real_A, real_B, fake_A2B, fake_B2A)


In [None]:
def _run_cycle_gan_pass(model, dataset_A, dataset_B, device):
    """Call ``model.optimize_parameters`` on every (A, B) batch pair of one pass.

    ``zip`` stops at the shorter loader, so ``save_losses`` is enabled on batch
    ``min(len(dataset_A), len(dataset_B))``, the last pair actually processed.
    It is disabled again afterwards.
    """
    last_step = min(len(dataset_A), len(dataset_B))
    for step, (data_A, data_B) in enumerate(zip(dataset_A, dataset_B), start=1):
        model.save_losses = (step == last_step)
        model.optimize_parameters(data_A.to(device), data_B.to(device))
    model.save_losses = False


def _save_loss_csv(loss_dict, path):
    """Write ``{name: [values]}`` to CSV; shorter lists are padded with NaN instead of raising."""
    columns = {name: pd.Series(values, dtype=float) for name, values in loss_dict.items()}
    pd.DataFrame(columns).to_csv(path, index=False)


def train(train_dataset_A, train_dataset_B,test_dataset_A,test_dataset_B ,epochs, device):
    """Train a fresh cycleGAN for ``epochs`` epochs, evaluating on the test loaders after each.

    Output directories come from the notebook globals set by ``manage_folders()``
    (``tr_im_save_dir``, ``te_im_save_dir``, ``model_save_dir``, ``graph_save_dir``).
    The fixed samples ``test_real_A_data`` / ``test_real_B_data`` are also globals.

    Per epoch:
      * one training pass, then one loss-only test pass without gradients;
      * a train image grid on epochs 0, 10, 20, ...; a test image grid every epoch;
      * a ``Result_epoch_<n>.png`` plot for the fixed test samples;
      * a full model pickle once ``epoch > params['save_epoch']``.
    At the end, four loss CSVs go to ``graph_save_dir`` and the model to ``model.pth``.
    """
    model = cycleGAN().to(device)

    for epoch in range(epochs):

        print('Epoch', epoch+1, '------------------')

        # Training
        model.is_training = True
        _run_cycle_gan_pass(model, train_dataset_A, train_dataset_B, device)

        print('Tr - Total Generator Loss:', np.round(model.tr_gen_loss_dict['loss_gen_total'][-1], decimals=3))
        print('Tr - Total Dicriminator Loss:', np.round(model.tr_dis_loss_dict['loss_dis_total'][-1], decimals=3))

        if epoch % 10 == 0:
            print_images(model.im_list, tr_im_save_dir, str(epoch), save_mode_on=True)

        # test
        with torch.set_grad_enabled(False):

            model.is_training = False
            _run_cycle_gan_pass(model, test_dataset_A, test_dataset_B, device)

            print('----')
            print('te - Total Generator Loss:', np.round(model.te_gen_loss_dict['loss_gen_total'][-1], decimals=3))
            print('te - Total Dicriminator Loss:', np.round(model.te_dis_loss_dict['loss_dis_total'][-1], decimals=3))

            print_images(model.im_list, te_im_save_dir, str(epoch), save_mode_on=True)

            # Fixed samples for a visual progress check across epochs.
            # .to(device) instead of .cuda() so CPU-only runs work too.
            test_real_A = test_real_A_data.to(device)
            test_real_B = test_real_B_data.to(device)

            # forward() order: fake_A2B, recon_A2B, fake_B2A, recon_B2A, identity_A2B, identity_B2A.
            # recon_A2B reconstructs B (B->A->B) and recon_B2A reconstructs A (A->B->A).
            test_fake_B, test_recon_B, test_fake_A, test_recon_A, _, _ = model.forward(test_real_A, test_real_B)

            plot_train_result([test_real_B, test_real_A], [test_fake_A, test_fake_B], [test_recon_B, test_recon_A],
                            epoch, save=True)

        if epoch > params['save_epoch']:
            torch.save(model, os.path.join(model_save_dir, 'model_%03d.pth' % epoch))

    # Save gen and disc loss values to respective csv files.
    _save_loss_csv(model.tr_gen_loss_dict, os.path.join(graph_save_dir, 'tr_gen_losses.csv'))
    _save_loss_csv(model.tr_dis_loss_dict, os.path.join(graph_save_dir, 'tr_dis_losses.csv'))
    _save_loss_csv(model.te_gen_loss_dict, os.path.join(graph_save_dir, 'te_gen_losses.csv'))
    _save_loss_csv(model.te_dis_loss_dict, os.path.join(graph_save_dir, 'te_dis_losses.csv'))

    # Save entire model architecture and params.
    torch.save(model, 'model.pth')

In [None]:
# Create Output/<timestamp>/... and bind the four save directories that train() reads as globals.
tr_im_save_dir, te_im_save_dir, graph_save_dir, model_save_dir = manage_folders()

In [None]:
%%time
# Train for 200 epochs; note params['num_epochs'] (100) is not used by this call.
train(train_data_loader_A,train_data_loader_B,test_data_loader_A,test_data_loader_B,200, device)