In [0]:
!wget https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/horse2zebra.zip
!unzip horse2zebra.zip

In [0]:
!pip3 install torch
!pip3 install torchvision
!pip install -U -q PyDrive
from pydrive.auth import GoogleAuth
from pydrive.drive import GoogleDrive
from google.colab import auth
from oauth2client.client import GoogleCredentials

# 1. Authenticate and create the PyDrive client.
# Interactive Colab sign-in; must run before the application-default
# credentials are requested below.
auth.authenticate_user()
gauth = GoogleAuth()
# Reuse the credentials obtained by the Colab sign-in for PyDrive.
gauth.credentials = GoogleCredentials.get_application_default()
# `drive` is the handle the training cell uses to upload model checkpoints.
drive = GoogleDrive(gauth)


In [0]:
import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.optim import lr_scheduler
import numpy as np
import random
import itertools

# Use the first CUDA GPU when one is available; otherwise fall back to the CPU
# so that every later `.to(device)` / `device=device` call still works.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")





In [0]:
 def weights_init_normal(m):
    """Re-initialise the parameters of a single module in place.

    Meant to be passed to ``nn.Module.apply``. The decision is made purely on
    the module's class name:

    * names containing ``'Conv'``: weights drawn from N(0.0, 0.02);
    * names containing ``'BatchNorm2d'``: weights drawn from N(1.0, 0.02)
      and biases set to 0.

    Any other module is left untouched.
    """
    layer_kind = type(m).__name__

    if 'Conv' in layer_kind:
        torch.nn.init.normal_(m.weight.data, mean=0.0, std=0.02)
        return

    if 'BatchNorm2d' in layer_kind:
        torch.nn.init.normal_(m.weight.data, mean=1.0, std=0.02)
        torch.nn.init.constant_(m.bias.data, 0.0)
        
class ReplayBuffer():
    """Bounded history of generated images for discriminator training.

    Until the buffer holds ``max_size`` images every pushed image is stored
    and returned unchanged. Once full, each pushed image is, with probability
    0.5, swapped with a randomly chosen stored image (the old image is
    returned and the new one takes its slot); otherwise it is returned as is
    and not stored.
    """

    def __init__(self, max_size=50):
        """Create an empty buffer.

        Args:
            max_size: maximum number of single-image tensors kept.

        Raises:
            ValueError: if ``max_size`` is not positive (a full buffer must
                have at least one slot to sample from).
        """
        if max_size <= 0:
            raise ValueError("ReplayBuffer max_size must be positive, got {}".format(max_size))
        self.max_size = max_size
        self.data = []

    def push_and_pop(self, data):
        """Push a batch of images and return a same-sized batch to train on.

        Args:
            data: tensor whose first dimension indexes the images of a batch.

        Returns:
            A new float32 tensor with one entry per input image, detached
            from the autograd graph and located on the same device as
            ``data``.
        """
        to_return = []
        # `.data` drops the autograd history, so stored images never keep a
        # generator graph alive.
        for element in data.data:
            element = torch.unsqueeze(element, 0)
            if len(self.data) < self.max_size:
                self.data.append(element)
                to_return.append(element)
            elif random.uniform(0, 1) > 0.5:
                i = random.randint(0, self.max_size - 1)
                to_return.append(self.data[i].clone())
                self.data[i] = element
            else:
                to_return.append(element)
        # torch.cat already allocates a fresh tensor, so no copy-construction
        # through torch.tensor(...) (which warns) and no dependency on a
        # module-level `device` is needed.
        return torch.cat(to_return).to(dtype=torch.float32)
      
      
class LambdaLR():
    """Learning-rate multiplier: constant at first, then linearly decreasing.

    ``step`` is meant to be used as the ``lr_lambda`` callable of a
    scheduler. It returns 1.0 while ``epoch + offset`` has not passed
    ``decay_start_epoch`` and decreases linearly afterwards, reaching 0.0
    when ``epoch + offset`` equals ``n_epochs``.
    """

    def __init__(self, n_epochs, offset, decay_start_epoch):
        assert ((n_epochs - decay_start_epoch) > 0), "Decay must start before the training session ends!"
        self.n_epochs, self.offset, self.decay_start_epoch = n_epochs, offset, decay_start_epoch

    def step(self, epoch):
        """Return the multiplier for ``epoch`` (not clamped below zero)."""
        decay_span = self.n_epochs - self.decay_start_epoch
        epochs_into_decay = max(0, epoch + self.offset - self.decay_start_epoch)
        return 1.0 - epochs_into_decay / decay_span

In [0]:
class ResnetGenerator(nn.Module):
  """Resnet-based image-to-image generator.

  Layout: a 7x7 stem convolution, two stride-2 downsampling convolutions,
  ``n_blocks`` residual blocks, two stride-2 transposed convolutions and a
  final 7x7 convolution followed by ``Tanh``.
  """

  def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6, padding_type='reflect'):
    """Build the layer stack.

    Args:
      input_nc: number of channels of the input image.
      output_nc: number of channels of the generated image.
      ngf: number of filters produced by the stem convolution.
      norm_layer: callable building a normalisation layer from a channel count.
      use_dropout: forwarded to every ``ResnetBlock``.
      n_blocks: number of residual blocks in the middle of the network.
      padding_type: forwarded to every ``ResnetBlock``.
    """
    super().__init__()
    # Convolutions followed by a norm layer only get a bias for InstanceNorm2d.
    use_bias = norm_layer == nn.InstanceNorm2d
    n_downsampling = 2

    layers = []

    # Stem: keep the spatial size, lift to ngf channels.
    layers.extend([
        nn.ReflectionPad2d(3),
        nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=use_bias),
        norm_layer(ngf),
        nn.ReLU(True),
    ])

    # Encoder: each level halves the resolution and doubles the channels.
    for level in range(n_downsampling):
      in_ch = ngf * 2 ** level
      layers.extend([
          nn.Conv2d(in_ch, in_ch * 2, kernel_size=3, stride=2, padding=1, bias=use_bias),
          norm_layer(in_ch * 2),
          nn.ReLU(True),
      ])

    # Residual trunk at the lowest resolution.
    trunk_ch = ngf * 2 ** n_downsampling
    layers.extend(
        ResnetBlock(trunk_ch, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias)
        for _ in range(n_blocks)
    )

    # Decoder: mirror of the encoder.
    for level in range(n_downsampling, 0, -1):
      in_ch = ngf * 2 ** level
      out_ch = int(in_ch / 2)
      layers.extend([
          nn.ConvTranspose2d(in_ch, out_ch, kernel_size=3, stride=2, padding=1, output_padding=1, bias=use_bias),
          norm_layer(out_ch),
          nn.ReLU(True),
      ])

    # Head: project to the output channels and squash with tanh.
    layers.extend([
        nn.ReflectionPad2d(3),
        nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0),
        nn.Tanh(),
    ])

    self.model = nn.Sequential(*layers)

  def forward(self, input):
    """Run ``input`` through the sequential layer stack."""
    return self.model(input)

class ResnetBlock(nn.Module):
  """Residual block: ``x + conv_block(x)`` with two padded 3x3 convolutions.

  ``conv_block`` is: pad/conv, norm, ReLU, optional Dropout(0.5), pad/conv,
  norm. Both convolutions preserve the spatial size so the skip connection
  can be added element-wise.
  """

  def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias):
    """Build the block.

    Args:
      dim: number of input and output channels.
      padding_type: ``'reflect'``, ``'replicate'`` or ``'zero'``.
      norm_layer: callable building a normalisation layer from a channel count.
      use_dropout: insert ``nn.Dropout(0.5)`` after the first ReLU.
      use_bias: whether the convolutions have a bias term.

    Raises:
      NotImplementedError: if ``padding_type`` is not one of the supported
        values. Without padding the convolutions would shrink the feature
        map and the residual addition would fail at forward time.
    """
    super(ResnetBlock, self).__init__()
    conv_block = []

    conv_block += self._padded_conv(dim, padding_type, use_bias)
    conv_block += [norm_layer(dim), nn.ReLU(True)]

    if use_dropout:
      conv_block += [nn.Dropout(0.5)]

    conv_block += self._padded_conv(dim, padding_type, use_bias)
    conv_block += [norm_layer(dim)]

    self.conv_block = nn.Sequential(*conv_block)

  @staticmethod
  def _padded_conv(dim, padding_type, use_bias):
    """Return the layers for one size-preserving 3x3 convolution."""
    if padding_type == 'reflect':
      return [nn.ReflectionPad2d(1), nn.Conv2d(dim, dim, kernel_size=3, padding=0, bias=use_bias)]
    if padding_type == 'replicate':
      return [nn.ReplicationPad2d(1), nn.Conv2d(dim, dim, kernel_size=3, padding=0, bias=use_bias)]
    if padding_type == 'zero':
      # Zero padding is done by the convolution itself.
      return [nn.Conv2d(dim, dim, kernel_size=3, padding=1, bias=use_bias)]
    raise NotImplementedError('padding [%s] is not implemented' % padding_type)

  def forward(self, x):
    """Return the input plus the output of the convolutional branch."""
    out = x + self.conv_block(x)
    return out

In [0]:
class NLayerDiscriminator(nn.Module):
  """Convolutional discriminator producing a single-channel score map.

  All convolutions use a 4x4 kernel with padding 1. The first one and the
  ``n_layers - 1`` that follow use stride 2; one more stride-1 convolution
  widens the features and a final stride-1 convolution maps to one channel.
  Channel counts double at every step and are capped at ``ndf * 8``.
  """

  def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d):
    """Build the layer stack.

    Args:
      input_nc: number of channels of the input image.
      ndf: number of filters of the first convolution.
      n_layers: controls the depth (see class description).
      norm_layer: callable building a normalisation layer from a channel count.
    """
    super().__init__()
    # Convolutions followed by BatchNorm2d are built without a bias term.
    use_bias = norm_layer != nn.BatchNorm2d
    conv_kwargs = dict(kernel_size=4, padding=1)

    # Channel multipliers of the strided layers, then (in, out, stride)
    # for every normalised convolution including the trailing stride-1 one.
    strided_mults = [min(2 ** n, 8) for n in range(1, n_layers)]
    last_strided = strided_mults[-1] if strided_mults else 1
    plan = list(zip([1] + strided_mults[:-1], strided_mults, [2] * len(strided_mults)))
    plan.append((last_strided, min(2 ** n_layers, 8), 1))

    layers = [nn.Conv2d(input_nc, ndf, stride=2, **conv_kwargs), nn.LeakyReLU(0.2, True)]
    for mult_in, mult_out, stride in plan:
      layers.append(nn.Conv2d(ndf * mult_in, ndf * mult_out, stride=stride, bias=use_bias, **conv_kwargs))
      layers.append(norm_layer(ndf * mult_out))
      layers.append(nn.LeakyReLU(0.2, True))

    final_mult = plan[-1][1]
    layers.append(nn.Conv2d(ndf * final_mult, 1, stride=1, **conv_kwargs))
    self.model = nn.Sequential(*layers)

  def forward(self, input):
    """Return the score map for ``input``."""
    return self.model(input)
    

In [0]:
class GANLOSS(nn.Module):
  """Least-squares GAN objective (MSE against a constant real/fake label).

  The label values are registered as buffers so they follow the module when
  it is moved with ``.to(device)``.
  """

  def __init__(self, target_real_label = 1.0, target_fake_label = 0.0):
    """Store the label values and create the underlying MSE loss.

    Args:
      target_real_label: target value used for predictions on real images.
      target_fake_label: target value used for predictions on fake images.
    """
    super(GANLOSS, self).__init__()
    self.register_buffer('real_label', torch.tensor(target_real_label))
    self.register_buffer('fake_label', torch.tensor(target_fake_label))
    self.gan_mode = 'lsgan'
    self.loss = nn.MSELoss()

  def get_target_tensor(self, prediction, target_is_real):
    """Return the chosen label expanded (without copying) to ``prediction``'s shape."""
    if target_is_real:
      target_tensor = self.real_label
    else:
      target_tensor = self.fake_label
    return target_tensor.expand_as(prediction)

  def forward(self, prediction, target_is_real):
    """Return the MSE between ``prediction`` and the requested label.

    Implemented as ``forward`` (rather than overriding ``__call__``) so that
    ``criterion(prediction, flag)`` goes through ``nn.Module.__call__`` and
    module hooks keep working.

    Args:
      prediction: discriminator output tensor.
      target_is_real: truthy to compare against the real label, falsy for
        the fake label.
    """
    target_tensor = self.get_target_tensor(prediction, target_is_real)
    loss = self.loss(prediction, target_tensor)
    return loss


In [0]:
import torch.utils.data as data
from PIL import Image
import os 
import os.path

def make_dataset(dir):
  """Return the paths of all files below ``dir`` in a deterministic order.

  Directories are visited in sorted order of their path and the files of
  each directory are sorted by name, so the result does not depend on the
  order in which the filesystem lists entries.

  Args:
    dir: root directory to walk recursively.

  Returns:
    List of file paths (``os.path.join(root, fname)``). Empty when ``dir``
    contains no files or does not exist.
  """
  images = []
  for root, _, fnames in sorted(os.walk(dir)):
    # os.walk yields file names in arbitrary order; sort them explicitly.
    for fname in sorted(fnames):
      images.append(os.path.join(root, fname))
  return images

def loader(path):
  """Open the image at ``path`` and return it as an RGB ``PIL.Image``.

  The file is opened in a ``with`` block so its handle is closed as soon as
  the pixel data has been converted; the returned image is an independent
  copy.
  """
  with Image.open(path) as img:
    return img.convert("RGB")

class ImageLoader(data.Dataset):
  """Dataset over every file found below a root directory.

  Items are loaded with ``loader`` and, when a ``transform`` was given,
  passed through it before being returned.
  """

  def __init__(self, root, transform=None):
    """Collect the file list.

    Args:
      root: directory scanned with ``make_dataset``.
      transform: optional callable applied to each loaded image.
    """
    imgs = make_dataset(root)

    self.root = root
    self.imgs = imgs
    self.transform = transform
    self.loader = loader

  def __getitem__(self, index):
    """Return the (optionally transformed) image at ``index``.

    Raises:
      IndexError: if ``index`` is out of range; plain ``for`` iteration over
        the dataset relies on this to terminate.
    """
    path = self.imgs[index]
    img = self.loader(path)
    if self.transform is not None:
      img = self.transform(img)
    # Return in both cases; previously the transformed image was dropped.
    return img

  def __len__(self):
    """Return the number of files in the dataset."""
    return len(self.imgs)
     
def set_requires_grad(nets, requires_grad = False):
  """Set ``requires_grad`` on every parameter of one or several networks.

  Args:
    nets: a single network or a list of networks; ``None`` entries in the
      list are skipped.
    requires_grad: value assigned to each parameter's ``requires_grad``.
  """
  networks = nets if isinstance(nets, list) else [nets]
  for network in networks:
    if network is None:
      continue
    for weight in network.parameters():
      weight.requires_grad = requires_grad
  

In [0]:
# Build every object the training cell needs: losses, the two generators and
# two discriminators, replay buffers, optimizers, LR schedulers and datasets.

# NOTE(review): these three criteria are not referenced by the other cells
# visible in this notebook; the training loop uses criterion_GAN /
# criterion_cycle / criterion_identity defined further down.
criterionGAN  = GANLOSS().to(device)
criterionCycle = torch.nn.L1Loss()
criterionIdt = torch.nn.L1Loss()
# 3-channel in / 3-channel out generators and 3-channel discriminators, all
# with their default architecture settings.
netG_A2B = ResnetGenerator(3,3).to(device)
netG_B2A = ResnetGenerator(3,3).to(device)
netD_A = NLayerDiscriminator(3).to(device)
netD_B = NLayerDiscriminator(3).to(device)

# Re-initialise Conv* and BatchNorm2d parameters of every sub-module.
netG_A2B.apply(weights_init_normal)
netG_B2A.apply(weights_init_normal)
netD_A.apply(weights_init_normal)
netD_B.apply(weights_init_normal)


# Schedule: `epoch` is the first epoch to run (and the scheduler offset),
# training stops at `n_epochs`, LR decay begins at `decay_epoch`.
n_epochs = 200
epoch = 0
decay_epoch = 100

# Losses used by the training loop.
criterion_GAN = torch.nn.MSELoss()
criterion_cycle = torch.nn.L1Loss()
criterion_identity = torch.nn.L1Loss()

# History of generated images (default capacity of ReplayBuffer) that the
# discriminators are trained on.
fake_A_buffer = ReplayBuffer()
fake_B_buffer = ReplayBuffer()

# One optimizer shared by both generators, one per discriminator.
optimizer_G = torch.optim.Adam(itertools.chain(netG_A2B.parameters(), netG_B2A.parameters()), lr=0.0001)
optimizer_D_A  = torch.optim.Adam(netD_A.parameters(), lr=0.0001)
optimizer_D_B  = torch.optim.Adam(netD_B.parameters(), lr=0.0001)

# torch's LambdaLR scheduler driven by the notebook's own LambdaLR helper:
# multiplier 1.0 up to decay_epoch, then linear decay to 0 at n_epochs.
lr_scheduler_G = torch.optim.lr_scheduler.LambdaLR(optimizer_G, lr_lambda=LambdaLR(n_epochs, epoch, decay_epoch).step)
lr_scheduler_D_A = torch.optim.lr_scheduler.LambdaLR(optimizer_D_A, lr_lambda=LambdaLR(n_epochs, epoch, decay_epoch).step)
lr_scheduler_D_B = torch.optim.lr_scheduler.LambdaLR(optimizer_D_B, lr_lambda=LambdaLR(n_epochs, epoch, decay_epoch).step)


# No transform is passed, so indexing these datasets yields what `loader`
# returns (RGB PIL images).
realA_loader = ImageLoader('./horse2zebra/trainA')
realB_loader = ImageLoader('./horse2zebra/trainB')

In [0]:

def _image_to_batch(img):
  """Convert an RGB PIL image into a (1, 3, H, W) float32 tensor on `device`.

  The (H, W, 3) pixel array is transposed to channel-first order (a plain
  reshape would interleave rows and channels) and scaled from [0, 255] to
  [-1, 1] to match the tanh output range of the generators.
  """
  pixels = np.asarray(img, dtype=np.float32)
  chw = np.ascontiguousarray(pixels.transpose(2, 0, 1)) / 127.5 - 1.0
  return torch.from_numpy(chw).unsqueeze(0).to(device)


def _gan_loss(prediction, target):
  """MSE between a discriminator score map and a constant target label."""
  # Expand the 1-element label to the score-map shape so MSELoss does not
  # have to broadcast mismatched sizes.
  return criterion_GAN(prediction, target.expand_as(prediction))


def _discriminator_step(net_d, optimizer, real, fake):
  """Run one optimizer step of `net_d` on a real and a (detached) fake batch."""
  optimizer.zero_grad()
  loss_real = _gan_loss(net_d(real), target_real)
  loss_fake = _gan_loss(net_d(fake.detach()), target_fake)
  loss = (loss_real + loss_fake) * 0.5
  loss.backward()
  optimizer.step()
  return loss


def _as_number(loss):
  """Return a Python float for logging, or None if no step has run yet."""
  return None if loss is None else loss.item()


loss_D_B = None
loss_D_A = None
loss_G = None

target_real = torch.ones([1], device=device)
target_fake = torch.zeros([1], device=device)

# Checkpoint file name -> network. The Drive file handles are created once
# and reused, so every epoch updates the same four Drive files instead of
# creating four new ones.
checkpoints = {
    'genA.pb': netG_A2B,
    'genB.pb': netG_B2A,
    'disA.pb': netD_A,
    'disB.pb': netD_B,
}
drive_files = {}


for epoch in range(epoch, n_epochs):

  # The two datasets are unpaired; zip stops at the shorter one.
  for real_A, real_B in zip(realA_loader, realB_loader):

    real_A = _image_to_batch(real_A)
    real_B = _image_to_batch(real_B)

    # ---- Generators ----
    optimizer_G.zero_grad()

    # Identity loss: a generator fed an image of its target domain should
    # return it unchanged.
    same_A = netG_B2A(real_A)
    loss_identity_A = criterion_identity(same_A, real_A) * 5.0
    same_B = netG_A2B(real_B)
    loss_identity_B = criterion_identity(same_B, real_B) * 5.0

    # Adversarial loss: each fake is scored by the discriminator of the
    # domain it is supposed to belong to (fake_A -> netD_A, fake_B -> netD_B),
    # i.e. the same discriminator that is trained on it below.
    fake_A = netG_B2A(real_B)
    loss_GAN_B2A = _gan_loss(netD_A(fake_A), target_real)

    fake_B = netG_A2B(real_A)
    loss_GAN_A2B = _gan_loss(netD_B(fake_B), target_real)

    # Cycle-consistency loss: A -> B -> A and B -> A -> B.
    recovered_A = netG_B2A(fake_B)
    recovered_B = netG_A2B(fake_A)
    loss_cycle_ABA = criterion_cycle(recovered_A, real_A)*10.0
    loss_cycle_BAB = criterion_cycle(recovered_B, real_B)*10.0

    loss_G = loss_identity_A + loss_identity_B + loss_GAN_A2B + loss_GAN_B2A + loss_cycle_ABA + loss_cycle_BAB
    loss_G.backward()
    optimizer_G.step()

    # ---- Discriminators (trained on a history of generated images) ----
    fake_A = fake_A_buffer.push_and_pop(fake_A)
    loss_D_A = _discriminator_step(netD_A, optimizer_D_A, real_A, fake_A)

    fake_B = fake_B_buffer.push_and_pop(fake_B)
    loss_D_B = _discriminator_step(netD_B, optimizer_D_B, real_B, fake_B)

  lr_scheduler_G.step()
  lr_scheduler_D_A.step()
  lr_scheduler_D_B.step()

  print("{}:  {}, {} :: {}".format(epoch + 1, _as_number(loss_D_B), _as_number(loss_D_A), _as_number(loss_G)))

  # Save every network locally and push it to Google Drive.
  for filename, net in checkpoints.items():
    torch.save(net, filename)
    if filename not in drive_files:
      drive_files[filename] = drive.CreateFile({'title' : filename})
    drive_files[filename].SetContentFile(filename)
    drive_files[filename].Upload()


In [0]:


import matplotlib.pyplot as plt

def _pil_to_batch(img):
  """Convert an RGB PIL image into a (1, 3, H, W) float32 tensor in [-1, 1] on `device`."""
  pixels = np.asarray(img, dtype=np.float32)
  # Transpose (H, W, 3) -> (3, H, W); a reshape would scramble the pixels.
  chw = np.ascontiguousarray(pixels.transpose(2, 0, 1)) / 127.5 - 1.0
  return torch.from_numpy(chw).unsqueeze(0).to(device)


def _batch_to_image(batch):
  """Convert a (1, 3, H, W) generator output in [-1, 1] into an (H, W, 3) uint8 array."""
  chw = batch.squeeze(0).cpu().numpy()
  hwc = (chw.transpose(1, 2, 0) + 1.0) * 127.5
  return hwc.clip(0, 255).astype(np.uint8)


real_A = realA_loader[0]
real_B = realB_loader[0]

# Inference only: no autograd graph is needed.
with torch.no_grad():
  fake_B = _batch_to_image(netG_A2B(_pil_to_batch(real_A)))
  fake_A = _batch_to_image(netG_B2A(_pil_to_batch(real_B)))

fig = plt.figure()

# (subplot position, image, title); positions match the 2x2 grid layout:
# real images in the left column, generated images in the right column.
panels = [
    (1, real_A, "real A"),
    (3, real_B, "real B"),
    (2, fake_A, "fake A (from real B)"),
    (4, fake_B, "fake B (from real A)"),
]
for position, image, title in panels:
  a = fig.add_subplot(2, 2, position)
  a.imshow(image)
  a.set_title(title)
