<a href="https://colab.research.google.com/github/merail/Improving-Shape-Deformation-in-Unsupervised-Image-to-Image-Translation/blob/master/GANimorph.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
# required modules
import numpy as np 

import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.nn.init import kaiming_normal_, calculate_gain
import torch.optim as optim
from torch.autograd import Variable, grad

from functools import *
from scipy.io import loadmat

import os, time, imageio

from PIL import Image
from matplotlib import pyplot as plt
from os import listdir
from os.path import isfile, join

import itertools

In [2]:
# mount drive
from google.colab import drive
drive.mount('/content/gdrive', force_remount = True)

Mounted at /content/gdrive


In [0]:
class Initializer(nn.Module):
  def __init__(self):
    super(Initializer, self).__init__()
    
  def forward(self, layer):
    if layer.__class__.__name__.find('Conv2D') != -1:
      kaiming_normal_(layer.weight, a = calculate_gain("leaky_relu", negative_slope))
    if layer.__class__.__name__.find('ConvTransposed2D') != -1:
      kaiming_normal_(layer.weight, a = calculate_gain("relu"))
    if layer.__class__.__name__.find('InstanceNorm') != -1:
      nn.init.normal_(layer.weight, mean=1.0, std=0.002)
      layer.bias.data.zero_()

    return layer

In [0]:
class Store():
  def __init__(self):
    pass
  
  def make_store(self):
    current_time = time.strftime('%Y-%m-%d %H%M%S')
    epoch = 0
    
    samples_directory = os.path.join(store_directory, current_time, 'samples')
    weights_directory = os.path.join(store_directory, current_time, 'weights')
    losses_directory = os.path.join(store_directory, current_time, 'losses')
    os.makedirs(samples_directory)
    os.makedirs(weights_directory)
    os.makedirs(losses_directory)
    
    self.store_settings = {'epoch': epoch, 'samples_directory': samples_directory, 
            'weights_directory': weights_directory, 'losses_directory': losses_directory}
    
    return self.store_settings
    
  def restore_model(self):
    pattern = restore_file.split('-')
    epoch = int(pattern[2])

    samples_directory = os.path.join(restore_directory, 'samples')
    weights_directory = os.path.join(restore_directory, 'weights')
    losses_directory = os.path.join(restore_directory, 'losses')
    G_model = os.path.join(weights_directory, restore_file + '-G.pth')
    D_model = os.path.join(weights_directory, restore_file + '-D.pth')
    
    G.load_state_dict(torch.load(G_model))
    D.load_state_dict(torch.load(D_model))

    print('Restored from directory: %s, pattern: %s' % (restore_directory, restore_file))
    
    self.store_settings = {'epoch': epoch, 'samples_directory': samples_directory, 
            'weights_directory': weights_directory, 'losses_directory': losses_directory}
    
    return self.store_settings
  
  def save_model(self, it):
    g_file = os.path.join(self.store_settings['weights_directory'], '-%s' % (str(it).zfill(6))) + '-G.pth'
    torch.save(G.state_dict(), g_file)

    d_file = os.path.join(self.store_settings['weights_directory'], '-%s' % (str(it).zfill(6))) + '-D.pth'
    torch.save(D.state_dict(), d_file)
    
  def save_losses(self, it, losses_type, losses):
    path = os.path.join(self.store_settings['losses_directory'], '-%s' % (str(it).zfill(6))) + '-' + losses_type + '.txt'
    file = open(path, 'w')
    for item in losses:
      file.write("%s\n" % item)
    file.close()
    losses = []

    return losses

In [0]:
# generator
class Generator(nn.Module):
    def __init__(self):
      super(Generator, self).__init__()

      self.initializer = Initializer()
      subDepth = 3
      i = 1

      self.model = nn.ModuleList()

      self.model += [nn.Sequential(nn.Conv2d(in_channels = input_channel_size, out_channels = NF, kernel_size = 4, stride = 2, padding = 1, bias = False),
                                   nn.InstanceNorm2d(NF, affine=True),
                                   nn.LeakyReLU(negative_slope = negative_slope))]

      self.model += [nn.Sequential(nn.Conv2d(in_channels = NF, out_channels = NF * 2, kernel_size = 4, stride = 2, padding = 1, bias = False),
                                   nn.InstanceNorm2d(NF * 2, affine=True),
                                   nn.LeakyReLU(negative_slope = negative_slope))]

      self.model += [self.build_res_group(subDepth, NF * 2, NF * 2, i)]

      self.model += [nn.Sequential(nn.Conv2d(in_channels = NF*2, out_channels = NF * 4, kernel_size = 4, stride = 2, padding = 1, bias=False),
                     nn.InstanceNorm2d(NF * 4, affine=True),
                     nn.LeakyReLU(negative_slope = negative_slope))]

      self.model += [self.build_res_group(subDepth, NF * 4, NF * 4, i)]

      self.model += [nn.Sequential(nn.Conv2d(in_channels = NF * 4, out_channels = NF * 8, kernel_size = 4, stride = 2, padding = 1, bias=False),
                     nn.InstanceNorm2d(NF * 8, affine=True),
                     nn.LeakyReLU(negative_slope = negative_slope))]

      self.model += [self.build_res_group(subDepth, NF * 8, NF * 8, i)]

      self.model += [nn.Sequential(nn.ConvTranspose2d(in_channels = NF * 8, out_channels = NF * 4, kernel_size = 2, stride=2, bias=False),
                     nn.InstanceNorm2d(NF * 4, affine=True),
                     nn.ReLU())]

      self.model += [self.build_res_group(subDepth, NF * 8, NF * 4, i)]

      self.model += [nn.Sequential(nn.ConvTranspose2d(in_channels = NF * 4, out_channels = NF * 2, kernel_size = 2, stride=2, bias=False),
                     nn.InstanceNorm2d(NF * 2, affine=True),
                     nn.ReLU())]

      self.model += [self.build_res_group(subDepth, NF * 4, NF * 2, i)]

      self.model += [nn.Sequential(nn.ConvTranspose2d(in_channels = NF * 2, out_channels = NF, kernel_size = 2, stride=2, bias=False),
                     nn.InstanceNorm2d(NF, affine=True),
                     nn.ReLU())]

      self.model += [nn.Sequential(nn.ConvTranspose2d(in_channels = NF, out_channels = 3, kernel_size = 2, stride=2, bias=False),
                     nn.InstanceNorm2d(3, affine=True),
                     nn.Sigmoid())]

      for i in range(0, len(self.model)):
        for j in range(0, len(self.model[i])):
            self.model[i][j] = self.initializer(self.model[i][j])
      
    def build_res_group(self, depth, in_channels, out_channels, i):
      residual_group = nn.ModuleList()
      buf_channels = in_channels

      for k in range(depth):
        for m in range(depth):
          if ((m + 1) % 3 == 0):
            residual_group += [nn.Conv2d(in_channels = (buf_channels + out_channels), out_channels = in_channels, kernel_size = 3, stride = 1, padding = 1, bias=False)]
            residual_group += [nn.InstanceNorm2d(in_channels, affine=True)]
            residual_group += [nn.LeakyReLU(negative_slope = negative_slope)]
            buf_channels = out_channels
          else:
            residual_group += [nn.Conv2d(in_channels = in_channels, out_channels = out_channels, kernel_size = 3, stride = 1, padding = 1, bias=False)]
            residual_group += [nn.InstanceNorm2d(out_channels, affine=True)]
            residual_group += [nn.LeakyReLU(negative_slope = negative_slope)]
            
            in_channels = out_channels
      
      return nn.Sequential(*residual_group)

    def forward(self, x):
      for i in range(0, len(self.model)):
        if i == 3:
          skip_in_1 = x
        if i == 5:
          skip_in_2 = x
        if i == 8:
          x = torch.cat((x, skip_in_2), 1)
        if i == 10:  
          x = torch.cat((x, skip_in_1), 1)
        
        for j in range (0, len(self.model[i])):
          if ((j % 9) == 0):
            skip_for_residual_block = x
          if ((j + 3) % 9 == 0):
            x = torch.cat((x, skip_for_residual_block), 1)
          x = self.model[i][j](x)
          # print(i, " ", j)
          #print(x)
          #print("gen", x.shape)

        #print("============")
        
      return x

In [0]:
class Discriminator(nn.Module):
    def __init__(self):
      super(Discriminator, self).__init__()

      subDepth = 3
      self.width = 16

      self.initializer = Initializer()

      self.model = nn.ModuleList()

      self.model += [nn.Sequential(nn.Conv2d(in_channels = input_channel_size, out_channels = NF * 2, kernel_size = 4, stride = 2, padding = 1, bias = False),
                     nn.ReLU())]
      self.model += [nn.Sequential(nn.Conv2d(in_channels = NF * 2, out_channels = NF * 4, kernel_size = 4, stride = 2, padding = 1, bias = False),
                     nn.InstanceNorm2d(NF * 4, affine=True),
                     nn.LeakyReLU(negative_slope = negative_slope))]
      self.model += [nn.Sequential(nn.Conv2d(in_channels = NF * 4, out_channels = NF * 8, kernel_size = 4, stride = 2, padding = 1, bias = False),
                     nn.InstanceNorm2d(NF * 8, affine=True),
                     nn.LeakyReLU(negative_slope = negative_slope))]
      self.model += [nn.Sequential(nn.Conv2d(in_channels = NF * 8, out_channels = NF * 8, kernel_size = 3, stride = 1, padding = 1, bias = False),
                     nn.InstanceNorm2d(NF * 8, affine=True),
                     nn.LeakyReLU(negative_slope = negative_slope))]
      self.model += [nn.Sequential(nn.Conv2d(in_channels = NF * 8, out_channels = NF * 8, kernel_size = 3, stride = 1, padding = 2, dilation = 2, bias=False),
                     nn.InstanceNorm2d(NF * 8, affine=True),
                     nn.LeakyReLU(negative_slope = negative_slope))]
      self.model += [nn.Sequential(nn.Conv2d(in_channels = NF * 8, out_channels = NF * 8, kernel_size = 3, stride = 1, padding = 4, dilation = 4, bias=False),
                     nn.InstanceNorm2d(NF * 8, affine=True),
                     nn.LeakyReLU(negative_slope = negative_slope))]
      self.model += [nn.Sequential(nn.Conv2d(in_channels = NF * 8, out_channels = NF * 8, kernel_size = 3, stride = 1, padding = 8, dilation = 8, bias=False),
                     nn.InstanceNorm2d(NF * 8, affine=True),
                     nn.LeakyReLU(negative_slope = negative_slope))]
      self.model += [nn.Sequential(nn.Conv2d(in_channels = NF * 16, out_channels = NF * 8, kernel_size = 3, stride = 1, padding = 1, bias = False),
                     nn.InstanceNorm2d(NF * 8, affine=True),
                     nn.LeakyReLU(negative_slope = negative_slope))]
      self.model += [nn.Sequential(nn.Conv2d(in_channels = NF * 8, out_channels = 1, kernel_size = 3, stride = 1, padding = 1, bias = False))]

      for i in range(0, len(self.model)):
        for j in range(0, len(self.model[i])):
            self.model[i][j] = self.initializer(self.model[i][j])
      
    def forward(self, x):
      feats = []
      for i in range(0, len(self.model)):
        if i != 0 and i != 1:
          feats.append(x)
        if i == 4:
          a = x
        if i == 7:
          x = torch.cat((x, a), 1)
        x = self.model[i](x)
        #print(i)
        #print(x)
        #print("dis", x.shape)

      #print("==========")

      return x, feats

In [0]:
class GANimorph(nn.Module):
  def __init__(self):
    super(GANimorph, self).__init__()
    os.environ['CUDA_VISIBLE_DEVICES'] = '0'

    self.store = Store()
    if restore:
      self.store_settings = self.store.restore_model()
    else:
      self.store_settings = self.store.make_store()

    self.d_period = 1
    self.g_period = 2

    self.GAN_last_loss = 1
    self.FM_last_loss = 1
    self.CYC_last_loss = 1
    self.L1_last_loss = 1

    self.sigmoid = nn.Sigmoid().cuda()
    self.criterion = nn.BCELoss()

    self.real_labels = Variable(torch.full((16, 1, 16, 16), real_label)).cuda()
    self.fake_labels = Variable(torch.full((16, 1, 16, 16), fake_label)).cuda()

  def make_batch(self):
    indexes_A = np.random.randint(len(files_A), size = minibatch_size)
    indexes_B = np.random.randint(len(files_B), size = minibatch_size)

    names_A = []
    names_B = []
    for i in range (0, minibatch_size):
      names_A.append(files_A[indexes_A[i]])
      names_B.append(files_B[indexes_B[i]])

    batch_A = Variable(torch.FloatTensor(minibatch_size, 3, 128, 128)).cuda()
    batch_B = Variable(torch.FloatTensor(minibatch_size, 3, 128, 128)).cuda()
    for i in range(0, minibatch_size):
      batch_A[i] = torch.transpose(torch.transpose(torch.from_numpy(np.array(Image.open(data_path_A + names_A[i]))), 1, 2), 0, 1)
      batch_B[i] = torch.transpose(torch.transpose(torch.from_numpy(np.array(Image.open(data_path_B + names_B[i]))), 1, 2), 0, 1)

    return batch_A, batch_B

  def calculate_scheduled_GAN_loss(self, loss, iteration):
    moving_average_loss = beta*self.GAN_last_loss + (1-beta)*loss
    #print(moving_average_loss)
    self.GAN_last_loss = moving_average_loss
    if iteration % s == 0:
      loss = loss / (moving_average_loss.item() + epsilon)
    
    return loss

  def calculate_scheduled_FM_loss(self, loss, iteration):
    moving_average_loss = beta*self.FM_last_loss + (1-beta)*loss
    self.FM_last_loss = moving_average_loss
    if iteration % s == 0:
      loss = loss / (moving_average_loss.item() + epsilon)

    return loss

  def calculate_scheduled_CYC_loss(self, loss, iteration):
    moving_average_loss = beta*self.CYC_last_loss + (1-beta)*loss
    self.CYC_last_loss = moving_average_loss
    if iteration % s == 0:
      loss = loss / (moving_average_loss.item() + epsilon)

    return loss

  def calculate_scheduled_L1_loss(self, loss, iteration):
    moving_average_loss = beta*self.L1_last_loss + (1-beta)*loss
    self.L1_last_loss = moving_average_loss
    if iteration % s == 0:
      loss = loss / (moving_average_loss.item() + epsilon)

    return loss

  def calculate_L1_loss(self, X, XYX, Y, YXY):
    return (torch.mean(abs(XYX - X)) + torch.mean(abs(YXY - Y)))

  def calculate_GAN_discriminator_loss(self, dis_real, dis_fake):
    dis_real = self.sigmoid(dis_real)
    dis_fake = self.sigmoid(dis_fake)

    D_loss_real = self.criterion(dis_real, self.real_labels)
    D_loss_fake = self.criterion(dis_fake, self.fake_labels)
    D_loss = 0.5*(D_loss_real + D_loss_fake)

    return D_loss

  def calculate_GAN_generator_loss(self, dis_fake_for_gen):
    dis_fake_for_gen = self.sigmoid(dis_fake_for_gen)

    G_loss = self.criterion(dis_fake_for_gen, self.real_labels)

    return G_loss

  def calculate_feature_match_loss(self, feats_real, feats_fake):
    losses = Variable(torch.FloatTensor(7, 16, 16)).cuda()
    for i in range(0, len(feats_real)):
      loss = torch.mean((torch.mean(feats_real[i], 0) - torch.mean(feats_fake[i], 0))**2)
      losses[i] = loss
    
    result = torch.mean(losses)

    return result

  def pytorch_fspecial_gauss(self, size, sigma):
    """Function to mimic the 'fspecial' gaussian MATLAB function
    """
    x_data, y_data = np.mgrid[-size//2 + 1:size//2 + 1, -size//2 + 1:size//2 + 1]

    x_data = np.expand_dims(x_data, axis=-1)
    x_data = torch.from_numpy(np.expand_dims(x_data, axis=-1))

    y_data = np.expand_dims(y_data, axis=-1)
    y_data = torch.from_numpy(np.expand_dims(y_data, axis=-1))

    g = torch.exp(-((x_data**2 + y_data**2)/(2.0*sigma**2)))

    return g / torch.sum(g)

  def calculate_SSIM_loss(self, img1, img2, cs_map=False, mean_metric=True, size=8, sigma=1.5, l = 0):
    window = self.pytorch_fspecial_gauss(size, sigma) # window shape [size, size]
    window = torch.transpose(torch.transpose(window, 1, 3), 0, 2)
    K1 = 0.03
    K2 = 0.05
    L = 1  # depth of image (255 in case the image has a differnt scale)
    C1 = (K1*L)**2
    C2 = (K2*L)**2
    #mu1 = torch.FloatTensor(, requires_grad = True)
    mu1 = F.conv2d(img1.cpu(), window)
    mu2 = F.conv2d(img2.cpu(), window)
    #window = window.cuda()
    # mu1 = F.conv2d(img1, window)
    # mu2 = F.conv2d(img2, window)

    mu1_sq = mu1*mu1
    mu2_sq = mu2*mu2
    mu1_mu2 = mu1*mu2
    sigma1_sq = F.conv2d((img1*img1).cpu(), window) - mu1_sq
    sigma2_sq = F.conv2d((img2*img2).cpu(), window)  - mu2_sq
    #print(F.conv2d((img1*img2), window).shape)
    #print(mu1_mu2.shape)
    sigma12 = F.conv2d((img1*img2).cpu(), window)  - mu1_mu2
    # sigma1_sq = F.conv2d((img1*img1), window) - mu1_sq
    # sigma2_sq = F.conv2d((img2*img2), window)  - mu2_sq
    # sigma12 = F.conv2d((img1*img2), window)  - mu1_mu2
    
    sigma1_sq = abs(sigma1_sq)
    sigma2_sq = abs(sigma2_sq)
    sigma12 = abs(sigma12)
    if cs_map:
        value = (((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*
                    (sigma1_sq + sigma2_sq + C2)),
                (2.0*sigma12 + C2)/(sigma1_sq + sigma2_sq + C2))
    else:
        value = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*
                    (sigma1_sq + sigma2_sq + C2))

    if mean_metric:
        value = torch.mean(value)
        
    return value

  def calculate_MS_SSIM_loss(self, img1, img2, mean_metric=True, level=5):

    weight = torch.FloatTensor([0.0448, 0.2856, 0.3001, 0.2363, 0.1333])
    mssim = []
    mcs = []
    for l in range(level):
        ssim_map, cs_map = self.calculate_SSIM_loss(img1, img2, cs_map=True, mean_metric=False, l = 0)
        mssim.append(torch.mean(torch.FloatTensor(ssim_map)))

        mcs.append(torch.mean(torch.FloatTensor(cs_map)))
        filtered_im1 = torch.nn.functional.avg_pool2d(img1, kernel_size = 2, stride = 2)
        filtered_im2 = torch.nn.functional.avg_pool2d(img2, kernel_size = 2, stride = 2)
        img1 = filtered_im1
        img2 = filtered_im2

    # list to tensor of dim D+1
    mssim = torch.stack(mssim, axis=0)
    mcs = torch.stack(mcs, axis=0)

    value = (torch.prod(mcs[0:level-1]**weight[0:level-1])*
                       (mssim[level-1]**weight[level-1])).cuda()

    if mean_metric:
        value = torch.mean(value)

    return value

  def calculate_DSSIM_loss(self, img1, img2):
    img1 = img1.view(img1.shape[0], img1.shape[1], -1, img1.shape[2], img1.shape[3])
    img2 = img2.view(img2.shape[0], img2.shape[1], -1, img2.shape[2], img2.shape[3])
    img1 = torch.unbind(img1, axis=1)
    img2 = torch.unbind(img2, axis=1)

    value = torch.stack([self.calculate_MS_SSIM_loss(i1, i2) for i1, i2 in zip(img1, img2)], axis=0)

    return (1.0 - torch.sum(value)/3)

  def viz3(self, a, b, c, d, e, f, iteration):
    im1 = torch.cat([a, b, c], axis=3)
    im2 = torch.cat([d, e, f], axis=3)
    im = torch.cat([im1, im2], axis=2)
    im = torch.transpose(torch.transpose(im, 1, 2), 2, 3)
    im = (im) * 255
    im = torch.clamp(im, 0, 255)
    im = im.type(torch.IntTensor)
    #print(im)

    imageio.imwrite(os.path.join(self.store_settings['samples_directory'], '%d.png' % (iteration)), im[0].numpy())

  def train(self):
    G1.cuda()
    G2.cuda()
    D1.cuda()
    D2.cuda()

    self.generator_optimizer = optim.Adam(itertools.chain(G1.parameters(), G2.parameters()), lr = learning_rate, betas = (beta1, beta2))
    self.discriminator_optimizer = optim.Adam(itertools.chain(D1.parameters(), D2.parameters()), lr = learning_rate, betas = (beta1, beta2))

    #for epoch in range(0, number_of_epoches):
    for iteration in range(0, number_of_iterations):
        A, B = self.make_batch()
        A = A / 255.0
        #print(A)
        B = B / 255.0

        AB = G1(A)
        BA = G2(B)
        ABA = G2(AB)
        BAB = G1(BA)
        
        self.generator_optimizer.zero_grad()

        _, A_feats_real = D1(A)
        A_dis_fake_for_gen, A_feats_fake = D1(BA)

        G_loss_A = self.calculate_GAN_generator_loss(A_dis_fake_for_gen)

        recon_loss_A = self.calculate_DSSIM_loss(A, ABA)

        fm_loss_A = self.calculate_feature_match_loss(A_feats_real, A_feats_fake)

        _, B_feats_real = D2(B)
        B_dis_fake_for_gen, B_feats_fake = D2(AB)

        G_loss_B = self.calculate_GAN_generator_loss(B_dis_fake_for_gen)

        recon_loss_B = self.calculate_DSSIM_loss(B, BAB)

        fm_loss_B = self.calculate_feature_match_loss(B_feats_real, B_feats_fake)

        recon_loss_l = self.calculate_L1_loss(A, ABA, B, BAB)

        # g_loss = (self.calculate_scheduled_GAN_loss(G_loss_A + G_loss_B, iteration) * 0.7 + \
        #         self.calculate_scheduled_FM_loss(fm_loss_A + fm_loss_B, iteration) * 0.3) * (1 - rate)  + \
        #         (self.calculate_scheduled_CYC_loss((recon_loss_A + recon_loss_B), iteration) * 0.7 + \
        #         self.calculate_scheduled_L1_loss((recon_loss_l),iteration) * 0.3) * rate
        g_loss = (self.calculate_scheduled_GAN_loss(G_loss_A + G_loss_B, iteration) * 0.7 + \
                self.calculate_scheduled_FM_loss(fm_loss_A + fm_loss_B, iteration) * 0.3) * (1 - rate)  + \
                (self.calculate_scheduled_CYC_loss((recon_loss_A + recon_loss_B), iteration) * 0.7 + \
                self.calculate_scheduled_L1_loss((recon_loss_l),iteration) * 0.3) * rate
        print(iteration, " g_loss", g_loss.item())

        g_loss.backward()

        self.generator_optimizer.step()
        
        if iteration % 2 == 0:
          self.discriminator_optimizer.zero_grad()

        A_dis_real, _ = D1(A)
        A_dis_fake, _ = D1(BA.detach())
        D_loss_A  = self.calculate_GAN_discriminator_loss(A_dis_real, A_dis_fake)
        
        B_dis_real, B_feats_real = D2(B)
        B_dis_fake, _ = D2(AB.detach())
        D_loss_B  = self.calculate_GAN_discriminator_loss(B_dis_real, B_dis_fake)

        d_loss = D_loss_A + D_loss_B
        print(iteration, " d_loss", d_loss.item())
        
        if iteration % 2 == 0:
          d_loss.backward()

        self.discriminator_optimizer.step()

        if iteration % freq == 0:
          self.viz3(A, AB, ABA, B, BA, BAB, iteration)
      
        if iteration % 500 == 0 and iteration > 0:
          g_file1 = os.path.join(self.store_settings['weights_directory'], '%s' % (str(iteration).zfill(6))) + '-G1.pth'
          torch.save(G1.state_dict(), g_file1)
          g_file2 = os.path.join(self.store_settings['weights_directory'], '%s' % (str(iteration).zfill(6))) + '-G2.pth'
          torch.save(G1.state_dict(), g_file2)

          d1_file = os.path.join(self.store_settings['weights_directory'], '%s' % (str(iteration).zfill(6))) + '-D1.pth'
          torch.save(D1.state_dict(), d1_file)

          d2_file = os.path.join(self.store_settings['weights_directory'], '%s' % (str(iteration).zfill(6))) + '-D2.pth'
          torch.save(D2.state_dict(), d2_file)

In [8]:
learning_rate = 2e-4
minibatch_size = 16

beta1 = 0.95
beta2 = 0.999

beta = 0.9999
epsilon = 1e-10
s = 200

freq = 100

negative_slope = 0.2

number_of_epoches = 150
number_of_iterations = 150000

lambda_gan = 0.49
lambda_fm = 0.21
lambda_cyc = 0.3
lambda_ss = 0.70
lambda_l1 = 0.3

input_channel_size = 3
NF = 64

negative_slope = 0.2

real_label = 1
fake_label = 0

rate = 0.33

restore = False
store_directory = '/content/gdrive/My Drive/Colab Notebooks/'
data_path_A = "/content/gdrive/My Drive/Colab Notebooks/cat_dog_face/cat_dog_face/trainA/"
data_path_B = "/content/gdrive/My Drive/Colab Notebooks/cat_dog_face/cat_dog_face/trainB/"

files_A = [f for f in listdir(data_path_A) if isfile(join(data_path_A, f))]
files_B = [f for f in listdir(data_path_B) if isfile(join(data_path_B, f))]

G1 = Generator()
print(G1)
G2 = Generator()
D1 = Discriminator()
D2 = Discriminator()
print(D1)
GAN = GANimorph()

Generator(
  (initializer): Initializer()
  (model): ModuleList(
    (0): Sequential(
      (0): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
      (2): LeakyReLU(negative_slope=0.2)
    )
    (1): Sequential(
      (0): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
      (2): LeakyReLU(negative_slope=0.2)
    )
    (2): Sequential(
      (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (1): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
      (2): LeakyReLU(negative_slope=0.2)
      (3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (4): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=F

In [0]:
GAN.train()

0  g_loss 0.9993885159492493
0  d_loss 1.4367892742156982




1  g_loss 1.6410890817642212
1  d_loss 2.0186209678649902
2  g_loss 1.285774827003479
2  d_loss 1.6933786869049072
3  g_loss 1.1092331409454346
3  d_loss 1.5622360706329346
4  g_loss 1.0447908639907837
4  d_loss 1.842872142791748
5  g_loss 0.9269403219223022
5  d_loss 1.6180064678192139
6  g_loss 0.9380228519439697
6  d_loss 1.975632905960083
7  g_loss 0.8771974444389343
7  d_loss 1.7196929454803467
8  g_loss 0.8689675331115723
8  d_loss 1.7146685123443604
9  g_loss 0.859859049320221
9  d_loss 1.5997593402862549
10  g_loss 0.9024825096130371
10  d_loss 1.5883815288543701
11  g_loss 0.9179650545120239
11  d_loss 1.5458338260650635
12  g_loss 0.9078595638275146
12  d_loss 1.4635419845581055
13  g_loss 0.8920004367828369
13  d_loss 1.4276866912841797
14  g_loss 0.9214564561843872
14  d_loss 1.4190531969070435
15  g_loss 0.9292880296707153
15  d_loss 1.4028412103652954
16  g_loss 0.953708291053772
16  d_loss 1.3789482116699219
17  g_loss 0.9508684873580933
17  d_loss 1.3783705234527588
18 



101  g_loss 0.9519901275634766
101  d_loss 1.2124052047729492
102  g_loss 0.9351452589035034
102  d_loss 1.201958179473877
103  g_loss 0.9478635191917419
103  d_loss 1.1678481101989746
104  g_loss 0.9402936100959778
104  d_loss 1.161047101020813
105  g_loss 0.9275811910629272
105  d_loss 1.232434630393982
106  g_loss 0.8774231672286987
106  d_loss 1.2607792615890503
107  g_loss 0.9208969473838806
107  d_loss 1.1953434944152832
108  g_loss 0.8723577857017517
108  d_loss 1.2786520719528198
109  g_loss 0.8489018082618713
109  d_loss 1.3289191722869873
110  g_loss 0.8018914461135864
110  d_loss 1.3361005783081055
111  g_loss 0.8418487310409546
111  d_loss 1.358597755432129
112  g_loss 0.8685532212257385
112  d_loss 1.3287553787231445
113  g_loss 0.9093838930130005
113  d_loss 1.3834452629089355
114  g_loss 0.9018370509147644
114  d_loss 1.41756272315979
115  g_loss 0.915449321269989
115  d_loss 1.3977108001708984
116  g_loss 0.8863747715950012
116  d_loss 1.4397461414337158
117  g_loss 0.9



201  g_loss 0.825435996055603
201  d_loss 1.1813108921051025
202  g_loss 0.7576178312301636
202  d_loss 1.2314971685409546
203  g_loss 0.6968289613723755
203  d_loss 1.3181157112121582
204  g_loss 0.6318624019622803
204  d_loss 1.4572274684906006
205  g_loss 0.6249956488609314
205  d_loss 1.4812884330749512
206  g_loss 0.6437456607818604
206  d_loss 1.3859570026397705
207  g_loss 0.7059550285339355
207  d_loss 1.2985341548919678
208  g_loss 0.7515849471092224
208  d_loss 1.29058039188385
209  g_loss 0.8002193570137024
209  d_loss 1.2788782119750977
210  g_loss 0.826686680316925
210  d_loss 1.2428746223449707
211  g_loss 0.851099967956543
211  d_loss 1.3219034671783447
212  g_loss 0.8879623413085938
212  d_loss 1.2616240978240967
213  g_loss 0.9283022284507751
213  d_loss 1.3040094375610352
214  g_loss 0.9287434816360474
214  d_loss 1.3226079940795898
215  g_loss 0.929886519908905
215  d_loss 1.3062407970428467
216  g_loss 0.8923872709274292
216  d_loss 1.3240227699279785
217  g_loss 0.



301  g_loss 0.8451089859008789
301  d_loss 1.2807040214538574
302  g_loss 0.8921597003936768
302  d_loss 1.2527170181274414
303  g_loss 0.9175256490707397
303  d_loss 1.2268917560577393
304  g_loss 0.9038141965866089
304  d_loss 1.2507777214050293
305  g_loss 0.9248878955841064
305  d_loss 1.259894847869873
306  g_loss 0.9020220041275024
306  d_loss 1.3099119663238525
307  g_loss 0.8935126662254333
307  d_loss 1.2853219509124756
308  g_loss 0.8977693319320679
308  d_loss 1.2441339492797852
309  g_loss 0.8599498271942139
309  d_loss 1.3083207607269287
310  g_loss 0.8493108749389648
310  d_loss 1.3030409812927246
311  g_loss 0.8330950140953064
311  d_loss 1.3268790245056152
312  g_loss 0.7775962352752686
312  d_loss 1.3735175132751465
313  g_loss 0.7680667638778687
313  d_loss 1.364845871925354
314  g_loss 0.6920194029808044
314  d_loss 1.5678281784057617
315  g_loss 0.6866662502288818
315  d_loss 1.6514651775360107
316  g_loss 0.6847243309020996
316  d_loss 1.6392414569854736
317  g_los



401  g_loss 0.7237588763237
401  d_loss 1.2967281341552734
402  g_loss 0.7140353322029114
402  d_loss 1.268786072731018
403  g_loss 0.7153395414352417
403  d_loss 1.3078714609146118
404  g_loss 0.7273573875427246
404  d_loss 1.2917757034301758
405  g_loss 0.7170485258102417
405  d_loss 1.3139467239379883
406  g_loss 0.7148227691650391
406  d_loss 1.3336381912231445
407  g_loss 0.7167955636978149
407  d_loss 1.3519604206085205
408  g_loss 0.7113552689552307
408  d_loss 1.3242595195770264
409  g_loss 0.7599316239356995
409  d_loss 1.233015775680542
410  g_loss 0.7672397494316101
410  d_loss 1.1884820461273193
411  g_loss 0.8024869561195374
411  d_loss 1.2706172466278076
412  g_loss 0.8547443747520447
412  d_loss 1.1953274011611938
413  g_loss 0.8912354707717896
413  d_loss 1.1801129579544067
414  g_loss 0.9547518491744995
414  d_loss 1.1080713272094727
415  g_loss 0.9973344206809998
415  d_loss 1.126231074333191
416  g_loss 0.9873263835906982
416  d_loss 1.0958366394042969
417  g_loss 0.



501  g_loss 0.8292044997215271
501  d_loss 1.2800405025482178
502  g_loss 0.8317092657089233
502  d_loss 1.3144090175628662
503  g_loss 0.8980922698974609
503  d_loss 1.3097920417785645
504  g_loss 0.9190289974212646
504  d_loss 1.2810559272766113
505  g_loss 0.9941084980964661
505  d_loss 1.239613652229309
506  g_loss 0.998464047908783
506  d_loss 1.184538722038269
507  g_loss 1.0216301679611206
507  d_loss 1.152845025062561
508  g_loss 1.0391157865524292
508  d_loss 1.1396182775497437
509  g_loss 1.0449113845825195
509  d_loss 1.062534213066101
510  g_loss 1.0230944156646729
510  d_loss 1.0454440116882324
511  g_loss 1.0194833278656006
511  d_loss 1.0508376359939575
512  g_loss 1.000247836112976
512  d_loss 1.0640169382095337
513  g_loss 1.013935923576355
513  d_loss 1.0623767375946045
514  g_loss 0.9549675583839417
514  d_loss 1.074341893196106
515  g_loss 0.9317483305931091
515  d_loss 1.1110560894012451
516  g_loss 0.8706883192062378
516  d_loss 1.1529953479766846
517  g_loss 0.84



601  g_loss 1.0179753303527832
601  d_loss 1.0668354034423828
602  g_loss 1.0554184913635254
602  d_loss 1.0681769847869873
603  g_loss 0.9916607737541199
603  d_loss 1.0574918985366821
604  g_loss 0.8759697079658508
604  d_loss 1.1021041870117188
605  g_loss 0.88789963722229
605  d_loss 1.0950345993041992
606  g_loss 0.8607437610626221
606  d_loss 1.1507043838500977
607  g_loss 0.9226773977279663
607  d_loss 1.064910650253296
608  g_loss 0.9649515748023987
608  d_loss 1.0898559093475342
609  g_loss 1.1167356967926025
609  d_loss 1.0057770013809204
610  g_loss 1.1941654682159424
610  d_loss 0.947697639465332
611  g_loss 1.2686820030212402
611  d_loss 1.0725009441375732
612  g_loss 1.2669482231140137
612  d_loss 0.9014875888824463
613  g_loss 1.2381768226623535
613  d_loss 0.9164122939109802
