# Configs

In [0]:
import torch
import torchvision
import torchvision.transforms as transforms
import torchvision.utils as vutils

import matplotlib.pyplot as plt
import numpy as np 
import random
import math

import os

import pickle


In [0]:
%cd /content
!mkdir datasets

# Code from https://github.com/phillipi/pix2pix/blob/master/datasets/download_dataset.sh
!wget -N https://people.eecs.berkeley.edu/~tinghuiz/projects/pix2pix/datasets/edges2shoes.tar.gz 
!tar -zxvf edges2shoes.tar.gz -C ./datasets/
!rm edges2shoes.tar.gz

!wget -N https://people.eecs.berkeley.edu/~tinghuiz/projects/pix2pix/datasets/edges2handbags.tar.gz
!tar -zxvf edges2handbags.tar.gz -C ./datasets/
!rm edges2handbags.tar.gz

In [0]:
# Domain A = handbags, domain B = shoes (both extracted into /content/datasets by the download cell).
A_DATA_PATH = '/content/datasets/edges2handbags'
B_DATA_PATH = '/content/datasets/edges2shoes'

modelName = "DiscoGAN-shoes2bags"

# Checkpoints, loss pickles and preview figures are written here.
# NOTE(review): assumes Google Drive is mounted at /gdrive — confirm before training.
log_PATH = os.path.join("/gdrive/My Drive/notebooks", "logs", "DiscoGAN")

batch_size = 200
# NOTE(review): instance_norm is not referenced anywhere else in this notebook.
instance_norm = batch_size == 1
workers = 2

epochs = 100

# Base channel widths of the generator / discriminator.
gf_dim = 64
df_dim = 64

# NOTE(review): lambda_a / lambda_b are not referenced by the training cell, which adds the
# reconstruction losses unweighted — confirm whether they were meant to weight them.
lambda_a = 10.0
lambda_b = 10.0
# Per-domain image size (each pix2pix file holds two in_h x in_w images side by side) and channels.
in_w = in_h = 64
c_dim = 3

learning_rate = 2e-4
betas = (0.5, 0.999)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

plt.rc('font', size=15)

# Seed every RNG library the notebook imports.
manualSeed = 3734
print("Random Seed: ", manualSeed)
random.seed(manualSeed)
np.random.seed(manualSeed)
torch.manual_seed(manualSeed)

# Data loading

In [0]:
# Each pix2pix file stores two images side by side, hence a target width of twice in_w.
# ToTensor yields values in [0, 1]; Normalize with mean 0.5 / std 0.5 maps them to [-1, 1]
# (the single-value mean/std is applied to every channel).
transform = transforms.Compose([
    transforms.Resize(size=(in_h, 2 * in_w)),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,), std=(0.5,)),
])

def transform_inverse(y):
  """Undo the [-1, 1] normalization applied by `transform`, mapping values back to [0, 1].

  Computes (y + 1) / 2, i.e. the same result as Normalize(mean=-1, std=2), but for any number
  of channels (not only 1 or 3) and without a torchvision round-trip.

  Args:
    y: tensor with values in [-1, 1].
  Returns:
    A new tensor with values in [0, 1]; `y` is not modified.
  """
  return (y + 1) / 2

def batch_transform_inverse(y):
  """Batched counterpart of `transform_inverse`: map an (N, C, H, W) batch from [-1, 1] to [0, 1].

  Computes (y + 1) / 2, the inverse of Normalize(mean=0.5, std=0.5), for every channel
  regardless of how many there are.

  Args:
    y: batch tensor with values in [-1, 1].
  Returns:
    A new tensor of the same shape with values in [0, 1].
  """
  return (y + 1) / 2

In [0]:
import os
from torch.utils.data import Dataset
from PIL import Image

class pix2pixDataset(Dataset):
  """Dataset over a pix2pix-style folder whose images hold two views side by side.

  Reads every file in ``<root>/val`` (when ``test`` is true) or ``<root>/train`` and, per
  index, returns the left and right halves of the transformed image.

  Args:
    root: dataset directory containing ``train`` and ``val`` sub-folders.
    test: read ``val`` instead of ``train``.
    transform: callable applied to the RGB PIL image. It must return a (C, H, W) tensor so
      the two halves can be sliced along the last axis.
  """
  def __init__(self, root, test=False, transform=None):
    self.test = test
    self.transform = transform
    self.root_dir = os.path.join(root, 'val' if test else 'train')
    # Sorted so the index -> file mapping is reproducible (os.listdir order is arbitrary).
    self.image_list = [os.path.join(self.root_dir, f) for f in sorted(os.listdir(self.root_dir))]

  def __getitem__(self, idx):
    # Load eagerly inside `with` so the file handle is closed; convert so grayscale/RGBA files
    # also yield 3 channels.
    with Image.open(self.image_list[idx]) as img:
      image = img.convert('RGB')
    if self.transform:
      image = self.transform(image)
    # Split at half the width rather than at a hard-coded size.
    half = image.shape[-1] // 2
    return image[:, :, :half], image[:, :, half:]

  def __len__(self):
    return len(self.image_list)


In [0]:
# Training loaders use drop_last: every training epoch zips A and B batches, so a ragged final
# batch would pair differently-sized batches. A batch of one would also make the 1x1 BatchNorm
# bottleneck of Generator(extra_layers=True) fail in training mode.
train_a_set = pix2pixDataset(root=A_DATA_PATH, transform=transform)
train_a_loader = torch.utils.data.DataLoader(train_a_set, batch_size=batch_size,
                                             shuffle=True, num_workers=workers, drop_last=True)
train_b_set = pix2pixDataset(root=B_DATA_PATH, transform=transform)
train_b_loader = torch.utils.data.DataLoader(train_b_set, batch_size=batch_size,
                                             shuffle=True, num_workers=workers, drop_last=True)

# Held-out loaders: fixed order so the preview images are the same on every run.
test_a_set = pix2pixDataset(root=A_DATA_PATH, test=True, transform=transform)
test_a_loader = torch.utils.data.DataLoader(test_a_set, batch_size=batch_size, shuffle=False, num_workers=workers)
test_b_set = pix2pixDataset(root=B_DATA_PATH, test=True, transform=transform)
test_b_loader = torch.utils.data.DataLoader(test_b_set, batch_size=batch_size, shuffle=False, num_workers=workers)

train_a_iter = iter(train_a_loader)
train_b_iter = iter(train_b_loader)
test_a_iter = iter(test_a_loader)
test_b_iter = iter(test_b_loader)

In [0]:
def compare_imshow(first_batch, second_batch, first_title="first_batch", second_title="second_batch", nrow=10, third_batch=None, third_title="third_batch"):
  """Show two (or three) image batches side by side as grids in one large figure.

  Each batch is tiled with torchvision's make_grid (nrow images per row, padding 2,
  min-max normalized) and drawn in its own column of a 1x3 subplot layout. The third
  column is only drawn when `third_batch` is given.
  """
  panels = [(first_batch, first_title), (second_batch, second_title)]
  if third_batch is not None:
    panels.append((third_batch, third_title))

  plt.figure(figsize=(40, 70))
  for column, (batch, title) in enumerate(panels, start=1):
    grid = vutils.make_grid(batch, nrow=nrow, padding=2, normalize=True).cpu()
    plt.subplot(1, 3, column)
    plt.axis("off")
    plt.title(title)
    # make_grid returns (C, H, W); imshow expects (H, W, C).
    plt.imshow(np.transpose(grid, (1, 2, 0)))


# Peek at one training batch per domain; only the second (right-half) image of each pair is shown.
_, domain_a_batch = next(train_a_iter)
_, domain_b_batch = next(train_b_iter)
compare_imshow(domain_a_batch, domain_b_batch, "domain A(bag)", "domain B(shoes)")
print("domain A Batch size : {} ".format(str(domain_a_batch.size())))
print("domain B Batch size : {} ".format(str(domain_b_batch.size())))

# Ops: layer-building helpers

In [0]:
import torch.nn as nn

def conv_bn_layer(in_channels, out_channels, kernel_size, stride=1, padding=0):
  """Conv2d without bias followed by BatchNorm2d (whose learnable shift plays the bias role)."""
  conv = nn.Conv2d(in_channels, out_channels, kernel_size,
                   stride=stride, padding=padding, bias=False)
  norm = nn.BatchNorm2d(out_channels, eps=1e-5, momentum=0.1)
  return nn.Sequential(conv, norm)
def tconv_bn_layer(in_channels, out_channels, kernel_size, stride=1, padding=0):
  """ConvTranspose2d without bias followed by BatchNorm2d (the upsampling counterpart of conv_bn_layer)."""
  upconv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size,
                              stride=stride, padding=padding, bias=False)
  norm = nn.BatchNorm2d(out_channels, eps=1e-5, momentum=0.1)
  return nn.Sequential(upconv, norm)
def tconv_layer(in_channels, out_channels, kernel_size, stride=1, padding=0):
  """Plain ConvTranspose2d with bias and no normalization (used for the generator's output layer)."""
  layer = nn.ConvTranspose2d(in_channels, out_channels, kernel_size,
                             stride=stride, padding=padding)
  return layer

def conv_layer(in_channels, out_channels, kernel_size, stride=1, padding=0):
  """Plain Conv2d with bias and no normalization (first generator conv, first/last discriminator convs)."""
  layer = nn.Conv2d(in_channels, out_channels, kernel_size,
                    stride=stride, padding=padding)
  return layer

def fc_layer(in_features, out_features):
  """Fully connected layer with bias (not referenced elsewhere in this notebook)."""
  layer = nn.Linear(in_features=in_features, out_features=out_features)
  return layer

def fc_bn_layer(in_features, out_features):
  """Bias-free Linear followed by BatchNorm1d (not referenced elsewhere in this notebook)."""
  linear = nn.Linear(in_features, out_features, bias=False)
  norm = nn.BatchNorm1d(out_features)
  return nn.Sequential(linear, norm)

# Models

In [0]:
# DCGAN based Encoder-Decoder Generator refer to https://github.com/SKTBrain/DiscoGAN/blob/master/discogan/model.py

import torch.nn as nn
import torch.nn.functional as F

class Generator(nn.Module):
  """DCGAN-style encoder-decoder generator mapping c_dim-channel images to c_dim-channel images.

  Encoder: four 4x4 / stride-2 convs (no BatchNorm on the first), doubling channels from gf_dim
  with LeakyReLU(0.2). Decoder: mirrored transposed convs with ReLU, then a final transposed
  conv to c_dim channels and Tanh.

  Tanh (rather than Sigmoid) keeps outputs in [-1, 1], the same range the data `transform`
  normalizes real images to. With Sigmoid, fakes would live in [0, 1], so the discriminators
  could separate them trivially and the MSE cycle loss could never reach zero.

  Args:
    extra_layers: when true, the encoder adds a 4x4 / stride-1 / pad-0 conv down to a
      100-channel bottleneck (1x1 for 64x64 inputs). The decoder starts with a matching
      transposed conv back to gf_dim*8 channels.
  """
  def __init__(self, extra_layers=False):
    super(Generator, self).__init__()
    basic_layers = 4

    # Encoder. Module order (and hence state_dict keys) matches the original layout.
    encoder_list = [conv_layer(c_dim, gf_dim, 4, 2, 1), nn.LeakyReLU(0.2, inplace=True)]
    out_dim = gf_dim
    for _ in range(basic_layers - 1):
      in_dim, out_dim = out_dim, out_dim * 2
      encoder_list += [conv_bn_layer(in_dim, out_dim, 4, 2, 1), nn.LeakyReLU(0.2, inplace=True)]
    if extra_layers:
      in_dim, out_dim = out_dim, 100
      encoder_list += [conv_bn_layer(in_dim, out_dim, 4, 1, 0), nn.LeakyReLU(0.2, inplace=True)]
    self.encoder = nn.Sequential(*encoder_list)

    # Decoder
    decoder_list = []
    if extra_layers:
      in_dim, out_dim = out_dim, gf_dim * 8
      decoder_list += [tconv_bn_layer(in_dim, out_dim, 4, 1, 0), nn.ReLU(True)]
    for _ in range(basic_layers - 1):
      in_dim, out_dim = out_dim, out_dim // 2
      decoder_list += [tconv_bn_layer(in_dim, out_dim, 4, 2, 1), nn.ReLU(True)]
    decoder_list += [tconv_layer(out_dim, c_dim, 4, 2, 1), nn.Tanh()]
    self.decoder = nn.Sequential(*decoder_list)

  def forward(self, x):
    x = self.encoder(x)
    x = self.decoder(x)
    return x


In [0]:
# DCGAN based Discriminator refer to https://github.com/SKTBrain/DiscoGAN/blob/master/discogan/model.py

class Discriminator(nn.Module):
  """DCGAN-style discriminator producing a real/fake probability per image.

  Four 4x4 / stride-2 convs (BatchNorm on all but the first) with LeakyReLU(0.2), then a
  4x4 / stride-1 / pad-0 conv to a single channel and a sigmoid. For 64x64 inputs the
  output shape is (N, 1, 1, 1).
  """
  def __init__(self):
    super(Discriminator, self).__init__()
    # Attribute names form the state_dict keys written/read by the checkpoint cells; keep them.
    self.conv1 = conv_layer(c_dim, df_dim, 4, 2, 1)
    self.conv2 = conv_bn_layer(df_dim, 2 * df_dim, 4, 2, 1)
    self.conv3 = conv_bn_layer(2 * df_dim, 4 * df_dim, 4, 2, 1)
    self.conv4 = conv_bn_layer(4 * df_dim, 8 * df_dim, 4, 2, 1)
    self.conv5 = conv_layer(8 * df_dim, 1, 4, 1, 0)
    self.sigmoid = nn.Sigmoid()

  def forward(self, x):
    for stage in (self.conv1, self.conv2, self.conv3, self.conv4):
      x = F.leaky_relu(stage(x), 0.2, inplace=True)
    return self.sigmoid(self.conv5(x))

In [0]:
def weights_init(m):
  """DCGAN-style initialization, meant for ``model.apply(weights_init)``.

  - Modules whose class name contains 'Conv' (Conv2d, ConvTranspose2d): weights ~ N(0, 0.02)
    and biases zeroed.
  - Modules whose class name contains 'BatchNorm': scale ~ N(1, 0.02) and shift zeroed.

  Biases are only touched when present. conv_bn_layer / tconv_bn_layer build bias-free convs,
  which is why an unguarded bias init would raise. Other modules are left unchanged.
  """
  classname = m.__class__.__name__
  if 'Conv' in classname:
    nn.init.normal_(m.weight, 0.0, 0.02)
    if m.bias is not None:
      nn.init.zeros_(m.bias)
  elif 'BatchNorm' in classname:
    # Reference: https://discuss.pytorch.org/t/weight-initialization-for-batchnorm-in-dcgan-tutorial/32351
    nn.init.normal_(m.weight, 1.0, 0.02)
    if m.bias is not None:
      nn.init.zeros_(m.bias)

# Train configs

In [0]:
import torch.optim as optim
import torch.nn as nn

# G_AB: A (bag) -> B (shoes), G_BA: B -> A; one discriminator per domain.
G_AB = Generator(extra_layers=True).to(device)
G_BA = Generator(extra_layers=True).to(device)
D_A = Discriminator().to(device)
D_B = Discriminator().to(device)

# `d` measures cycle reconstruction; `bce` scores the discriminators' sigmoid outputs.
d = nn.MSELoss()
bce = nn.BCELoss()

G_AB_optimizer = optim.Adam(G_AB.parameters(), lr=learning_rate, betas=betas)
G_BA_optimizer = optim.Adam(G_BA.parameters(), lr=learning_rate, betas=betas)
D_A_optimizer = optim.Adam(D_A.parameters(), lr=learning_rate, betas=betas)
D_B_optimizer = optim.Adam(D_B.parameters(), lr=learning_rate, betas=betas)


def _stack_fixed_conditions(loader, max_batches=4):
  """Concatenate the second item of each of the first `max_batches` batches of `loader` along dim 0.

  Raises:
    ValueError: if `loader` yields no batches.
  """
  batches = []
  for batch_idx, (_, images) in enumerate(loader):
    if batch_idx == max_batches:
      break
    batches.append(images)
  if not batches:
    raise ValueError("loader yielded no batches to build fixed conditions from")
  # Single concat instead of re-concatenating inside the loop.
  return torch.cat(batches, 0)


# Fixed held-out images reused by every preview so results are comparable across epochs.
fixed_a_condition = _stack_fixed_conditions(test_a_loader).to(device)
fixed_b_condition = _stack_fixed_conditions(test_b_loader).to(device)

# apply() re-initializes the parameters in place, so the optimizers created above still track them.
print(D_A.apply(weights_init))
print(D_B.apply(weights_init))
print(G_AB.apply(weights_init))
print(G_BA.apply(weights_init))

In [0]:
# eval(): BatchNorm uses its running statistics and does not update them, so this preview
# leaves the models' state untouched; train() restores training mode afterwards.
G_AB.eval()
G_BA.eval()
with torch.no_grad():
  fake_batch = G_AB(fixed_a_condition)
  fake_recon_batch = G_BA(fake_batch)
G_AB.train()
G_BA.train()
compare_imshow(fixed_a_condition, fake_batch, first_title="bag", second_title="bag->shoes", third_batch=fake_recon_batch, third_title="bag->shoes->bag")

In [0]:
# eval(): BatchNorm uses its running statistics and does not update them, so this preview
# leaves the models' state untouched; train() restores training mode afterwards.
G_AB.eval()
G_BA.eval()
with torch.no_grad():
  fake_batch = G_BA(fixed_b_condition)
  fake_recon_batch = G_AB(fake_batch)
G_AB.train()
G_BA.train()
compare_imshow(fixed_b_condition, fake_batch, first_title="shoes", second_title="shoes->bag", third_batch=fake_recon_batch, third_title="shoes->bag->shoes")

In [0]:
img_a_list = []
img_b_list = []
D_A_GAN_losses = []
D_B_GAN_losses = []
G_AB_GAN_losses = []
G_BA_GAN_losses = []

recon_A_losses = []
recon_B_losses = []

# The training loop zips the two loaders, so an epoch lasts as long as the shorter one.
train_data_length = min(len(train_a_loader), len(train_b_loader))

# Log every 500 steps, but at least once per epoch: with the default data an epoch is shorter
# than 500 steps, and nothing would ever be logged, plotted or checkpointed.
iter_per_plot = max(1, min(500, train_data_length))
# Number of logged points per epoch; the loss-plot cells use it to convert indices to epochs.
plot_per_eps = train_data_length // iter_per_plot
transform_PIL = transforms.ToPILImage()

def log_list_save(l, file_name, log_dir=None):
  """Pickle `l` (e.g. a list of logged losses) to ``<log_dir>/<file_name>.logs``.

  Args:
    l: picklable object to store.
    file_name: base file name; ".logs" is appended.
    log_dir: target directory, defaulting to the global `log_PATH`. Created if missing, so the
      first save does not fail with FileNotFoundError.
  """
  log_dir = log_PATH if log_dir is None else log_dir
  os.makedirs(log_dir, exist_ok=True)
  with open(os.path.join(log_dir, file_name + ".logs"), "wb") as fp:
    pickle.dump(l, fp)

def log_list_load(file_name, log_dir=None):
  """Unpickle and return the object stored in ``<log_dir>/<file_name>.logs``.

  Args:
    file_name: base file name; ".logs" is appended.
    log_dir: directory to read from, defaulting to the global `log_PATH`.

  Only use this on files written by log_list_save: pickle.load can execute arbitrary code
  from an untrusted file.
  """
  log_dir = log_PATH if log_dir is None else log_dir
  with open(os.path.join(log_dir, file_name + ".logs"), "rb") as fp:
    return pickle.load(fp)

# Train

In [0]:
def _save_triplet_figure(source, translated, reconstructed, titles, out_path):
  """Draw source / translation / round-trip batches as three image columns, save to `out_path`, close and return the figure."""
  fig = plt.figure(figsize=(40, 70))
  for column, (batch, title) in enumerate(zip((source, translated, reconstructed), titles), start=1):
    plt.subplot(1, 3, column)
    plt.axis("off")
    plt.title(title)
    plt.imshow(np.transpose(vutils.make_grid(batch, nrow=1, padding=5, normalize=True).cpu(), (1, 2, 0)))
  plt.savefig(out_path)
  plt.close(fig)
  return fig


def _gan_loss(prediction, is_real):
  """BCE of discriminator outputs against an all-real (1) or all-fake (0) target.

  The target is built with *_like, so it matches the discriminator's (N, 1, 1, 1) output
  exactly and follows each batch's own size.
  """
  target = torch.ones_like(prediction) if is_real else torch.zeros_like(prediction)
  return bce(prediction, target)


os.makedirs(log_PATH, exist_ok=True)

for ep in range(epochs):
  # zip stops at the shorter loader, so an epoch has train_data_length steps.
  for i, ((_, a_data), (_, b_data)) in enumerate(zip(train_a_loader, train_b_loader)):
    x_A = a_data.to(device)
    x_B = b_data.to(device)

    # Translate each domain once. Detached copies train the discriminators; the attached ones
    # carry gradients back to the generators below.
    x_AB = G_AB(x_A)
    x_BA = G_BA(x_B)

    # Train D
    ## D_A: real bags vs. shoes->bag fakes
    D_A.zero_grad()
    D_A_GAN_loss = _gan_loss(D_A(x_BA.detach()), False) + _gan_loss(D_A(x_A), True)
    D_A_GAN_loss.backward()
    D_A_optimizer.step()

    ## D_B: real shoes vs. bag->shoes fakes
    D_B.zero_grad()
    D_B_GAN_loss = _gan_loss(D_B(x_AB.detach()), False) + _gan_loss(D_B(x_B), True)
    D_B_GAN_loss.backward()
    D_B_optimizer.step()

    # Train G
    G_AB.zero_grad()
    G_BA.zero_grad()
    ## GAN loss: each generator tries to make the target domain's discriminator answer "real"
    G_AB_GAN_loss = _gan_loss(D_B(x_AB), True)
    G_BA_GAN_loss = _gan_loss(D_A(x_BA), True)

    ## Reconstruction loss: A->B->A and B->A->B should return the input
    x_ABA = G_BA(x_AB)
    A_recon = d(x_ABA, x_A)
    x_BAB = G_AB(x_BA)
    B_recon = d(x_BAB, x_B)

    G_loss = G_AB_GAN_loss + G_BA_GAN_loss + A_recon + B_recon
    # This backward also fills D_A / D_B gradients; their zero_grad() clears them before the next D step.
    G_loss.backward()
    G_AB_optimizer.step()
    G_BA_optimizer.step()

    if (i + 1) % iter_per_plot == 0:
      print('Epoch [{}/{}], Step [{}/{}], D_A_loss: {:.4f}, D_B_loss: {:.4f},G_AB_loss: {:.4f}, G_BA_loss:{:.4f},A_recon_loss:{:.4f},B_recon_loss:{:.4f}'
            .format(ep, epochs, i + 1, train_data_length, D_A_GAN_loss.item(), D_B_GAN_loss.item(), G_AB_GAN_loss.item(), G_BA_GAN_loss.item(), A_recon.item(), B_recon.item()))
      D_A_GAN_losses.append(D_A_GAN_loss.item())
      D_B_GAN_losses.append(D_B_GAN_loss.item())
      G_AB_GAN_losses.append(G_AB_GAN_loss.item())
      G_BA_GAN_losses.append(G_BA_GAN_loss.item())

      recon_A_losses.append(A_recon.item())
      recon_B_losses.append(B_recon.item())

      # Preview on the fixed held-out images in eval mode (BatchNorm running stats untouched).
      G_AB.eval()
      G_BA.eval()
      with torch.no_grad():
        fake_B = G_AB(fixed_a_condition)
        fake_B_A = G_BA(fake_B)
        fake_A = G_BA(fixed_b_condition)
        fake_A_B = G_AB(fake_A)
      G_AB.train()
      G_BA.train()

      img_a_list.append(_save_triplet_figure(
          fixed_a_condition, fake_B, fake_B_A,
          ("bag", "bag -> shoes", "bag -> shoes -> bag"),
          os.path.join(log_PATH, modelName + "A-" + str(ep) + ".png")))
      img_b_list.append(_save_triplet_figure(
          fixed_b_condition, fake_A, fake_A_B,
          ("shoes", "shoes -> bag", "shoes -> bag -> shoes"),
          os.path.join(log_PATH, modelName + "B-" + str(ep) + ".png")))

      log_list_save(D_A_GAN_losses, "D_A_GAN_losses")
      log_list_save(D_B_GAN_losses, "D_B_GAN_losses")
      log_list_save(G_AB_GAN_losses, "G_AB_GAN_losses")
      log_list_save(G_BA_GAN_losses, "G_BA_GAN_losses")
      log_list_save(recon_A_losses, "recon_A_losses")
      log_list_save(recon_B_losses, "recon_B_losses")

      torch.save(D_A.state_dict(), os.path.join(log_PATH, "D_A_" + modelName + ".pth"))
      torch.save(D_B.state_dict(), os.path.join(log_PATH, "D_B_" + modelName + ".pth"))
      torch.save(G_AB.state_dict(), os.path.join(log_PATH, "G_AB_" + modelName + ".pth"))
      torch.save(G_BA.state_dict(), os.path.join(log_PATH, "G_BA_" + modelName + ".pth"))

## Test and plot logs

In [0]:
# eval(): BatchNorm uses the learned running statistics and does not update them, so this
# preview does not perturb the state that is checkpointed afterwards.
G_AB.eval()
G_BA.eval()
with torch.no_grad():
  fake_batch = G_AB(fixed_a_condition)
  fake_recon_batch = G_BA(fake_batch)
G_AB.train()
G_BA.train()
compare_imshow(fixed_a_condition, fake_batch, first_title="bag", second_title="bag->shoes", third_batch=fake_recon_batch, third_title="bag->shoes->bag")

In [0]:
# eval(): BatchNorm uses the learned running statistics and does not update them, so this
# preview does not perturb the state that is checkpointed afterwards.
G_AB.eval()
G_BA.eval()
with torch.no_grad():
  fake_batch = G_BA(fixed_b_condition)
  fake_recon_batch = G_AB(fake_batch)
G_AB.train()
G_BA.train()
compare_imshow(fixed_b_condition, fake_batch, first_title="shoes", second_title="shoes->bag", third_batch=fake_recon_batch, third_title="shoes->bag->shoes")

In [0]:
# bag -> shoes: G_AB produces domain-B images and D_B is its adversary.
plt.title("bag -> shoes Gan loss")

plt.rc('font', size=9)
# x in epochs, derived from the number of points actually logged.
X = np.arange(len(G_AB_GAN_losses)) / max(plot_per_eps, 1)
plt.plot(X, G_AB_GAN_losses, label="G loss")
plt.plot(X, D_B_GAN_losses, label="D loss")
plt.legend(loc=2)
plt.xticks(np.arange(0, epochs + 1, 20))
plt.ylabel("loss")
plt.xlabel("Epochs")
# "b2s" so the shoes -> bag figure ("s2b") does not overwrite this one.
plt.savefig(os.path.join(log_PATH, modelName + "b2s_loss_figure.png"))

In [0]:
# shoes -> bag: G_BA produces domain-A images and D_A is its adversary.
plt.title("shoes -> bag Gan loss")

plt.rc('font', size=9)
# x in epochs, derived from the number of points actually logged.
X = np.arange(len(G_BA_GAN_losses)) / max(plot_per_eps, 1)
plt.plot(X, G_BA_GAN_losses, label="G loss")
plt.plot(X, D_A_GAN_losses, label="D loss")
plt.legend(loc=2)
plt.xticks(np.arange(0, epochs + 1, 20))
plt.ylabel("loss")
plt.xlabel("Epochs")
plt.savefig(os.path.join(log_PATH, modelName + "s2b_loss_figure.png"))


In [0]:
plt.title("cycle consistency loss")

plt.rc('font', size=9)
# x in epochs, derived from the number of points actually logged.
X = np.arange(len(recon_A_losses)) / max(plot_per_eps, 1)
# recon_A_losses holds the A->B->A (bag) round trip, recon_B_losses the B->A->B (shoes) one.
plt.plot(X, recon_A_losses, label="bag->shoes->bag cycle loss")
plt.plot(X, recon_B_losses, label="shoes->bag->shoes cycle loss")
plt.legend(loc=2)
plt.xticks(np.arange(0, epochs + 1, 20))
plt.ylabel("loss")
plt.xlabel("Epochs")
plt.savefig(os.path.join(log_PATH, modelName + "cycle_loss_figure.png"))

# Model save and load

In [0]:
# Persist all four networks' state_dicts (G_AB / G_BA are the two generators).
os.makedirs(log_PATH, exist_ok=True)
torch.save(D_A.state_dict(), os.path.join(log_PATH, "D_A_" + modelName + ".pth"))
torch.save(D_B.state_dict(), os.path.join(log_PATH, "D_B_" + modelName + ".pth"))
torch.save(G_AB.state_dict(), os.path.join(log_PATH, "G_AB_" + modelName + ".pth"))
torch.save(G_BA.state_dict(), os.path.join(log_PATH, "G_BA_" + modelName + ".pth"))

In [0]:
# map_location=device lets checkpoints written on a GPU runtime be restored on a CPU-only one (and vice versa).
D_A.load_state_dict(torch.load(os.path.join(log_PATH, "D_A_" + modelName + ".pth"), map_location=device))
D_B.load_state_dict(torch.load(os.path.join(log_PATH, "D_B_" + modelName + ".pth"), map_location=device))
G_AB.load_state_dict(torch.load(os.path.join(log_PATH, "G_AB_" + modelName + ".pth"), map_location=device))
G_BA.load_state_dict(torch.load(os.path.join(log_PATH, "G_BA_" + modelName + ".pth"), map_location=device))