# **Torch Implmentation of CycleGANs**

### **What is CycleGANs?**
CycleGANs solves the computer vision problem of translating the style of one image (or sets of images) to the another using artificial neural networks, convolutional layers, and deep learning methods such as gradient descent, to optimize the output for a particular style.

Implementation adopted from and based on these sources:
1. [**CycleGANs Paper**](./data/1703.10593.pdf)
2. [**GitHub Repository of CycleGANs Paper**](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix)
3. [**Code based on this Torch Implementation**](https://github.com/yunjey/mnist-svhn-transfer)

### Package Requirements

- Python version 3.x
- Torch version 1.9
- Torchvision version 0.10
- imageio version 2.9.0
- numpy version 1.19.5

In [44]:
import torch
import torchvision
import os
import pickle
import imageio
import numpy as np
from torch.backends import cudnn
from torch.autograd import Variable
from torch import optim
from torchvision import datasets
from torchvision import transforms

### Datasets

- svhn: [The Street View House Numbers (SVHN) Dataset](http://ufldl.stanford.edu/housenumbers/)
- mnist: [Modified National Institute of Standards and Technology database](https://en.wikipedia.org/wiki/MNIST_database)

In [45]:
svhn_path='./svhn'
mnist_path='./mnist'

transform = transforms.Compose([transforms.Resize(32),
                                transforms.ToTensor(),
                                transforms.Normalize((0.5,), (0.5,))])

svhn = datasets.SVHN(root=svhn_path, 
                     download=True, 
                     transform=transform)

mnist = datasets.MNIST(root=mnist_path, 
                       download=True, 
                       transform=transform)
    
svhn_loader = torch.utils.data.DataLoader(dataset=svhn,
                                          batch_size=64,
                                          shuffle=True,
                                          num_workers=2)

mnist_loader = torch.utils.data.DataLoader(dataset=mnist,
                                           batch_size=64,
                                               
                                           shuffle=True,
                                           num_workers=2)

Using downloaded and verified file: ./svhn/train_32x32.mat


In [46]:
import torch.nn as nn
import torch.nn.functional as F

- Conv2d and ConvTranspose2d Layers with BatchNorm2d

In [33]:
def deconv(c_in, c_out, k_size, stride=2, pad=1, bn=True):
    layers=[]
    layers.append(nn.ConvTranspose2d(c_in, c_out, k_size, stride, pad, bias=False))
    if bn:
        layers.append(nn.BatchNorm2d(c_out))
    return nn.Sequential(*layers)

def conv(c_in, c_out, k_size, stride=2, pad=1, bn=True):
    layers=[]
    layers.append(nn.Conv2d(c_in, c_out, k_size, stride, pad, bias=False))
    if bn:
        layers.append(nn.BatchNorm2d(c_out))
    return nn.Sequential(*layers)

### Generators

In [41]:
class G12(nn.Module):
    def __init__(self, conv_dim=64):
        super(G12, self).__init__()
        self.conv1 = conv(1, conv_dim, 4)
        self.conv2 = conv(conv_dim, conv_dim * 2, 4)
        
        self.conv3 = conv(conv_dim * 2, conv_dim * 2, 3, 1, 1)
        self.conv4 = conv(conv_dim * 2, conv_dim * 2, 3, 1, 1)
        
        self.deconv1 = deconv(conv_dim * 2, conv_dim, 4)
        self.deconv2 = deconv(conv_dim, 3, 4, bn=False)
        
    def forward(self, x):
        out = F.leaky_relu(self.conv1(x), 0.05)
        out = F.leaky_relu(self.conv2(out), 0.05)
        
        out = F.leaky_relu(self.conv3(out), 0.05)
        out = F.leaky_relu(self.conv4(out), 0.05)
        
        out = F.leaky_relu(self.deconv1(out), 0.05)
        out = torch.tanh(self.deconv2(out))
        return out

In [42]:
class G21(nn.Module):
    def __init__(self, conv_dim=64):
        super(G21, self).__init__()
        self.conv1 = conv(3, conv_dim, 4)
        self.conv2 = conv(conv_dim, conv_dim * 2, 4)
        
        self.conv3 = conv(conv_dim * 2, conv_dim * 2, 3, 1, 1)
        self.conv4 = conv(conv_dim * 2, conv_dim * 2, 3, 1, 1)
        
        self.deconv1 = deconv(conv_dim * 2, conv_dim, 4)
        self.deconv2 = deconv(conv_dim, 1, 4, bn=False)
        
    def forward(self, x):
        out = F.leaky_relu(self.conv1(x), 0.05)
        out = F.leaky_relu(self.conv2(out), 0.05)
    
        out = F.leaky_relu(self.conv3(out), 0.05)
        out = F.leaky_relu(self.conv4(out), 0.05)
    
        out = F.leaky_relu(self.deconv1(out), 0.05)
        out = torch.tanh(self.deconv2(out))
        return out

### Discriminators

In [36]:
class D1(nn.Module):   
    def __init__(self, conv_dim=64, use_labels=False):
        super(D1, self).__init__()
        self.conv1 = conv(1, conv_dim, 4, bn=False)
        self.conv2 = conv(conv_dim, conv_dim * 2, 4)
        self.conv3 = conv(conv_dim * 2, conv_dim * 4, 4)
        n_out = 11 if use_labels else 1
        self.fc = conv(conv_dim * 4, n_out, 4, 1, 0, False)
    
    def forward(self, x):
        out = F.leaky_relu(self.conv1(x), 0.05)
        out = F.leaky_relu(self.conv2(out), 0.05)
        out = F.leaky_relu(self.conv3(out), 0.05)
        out = self.fc(out).squeeze()
        return out

In [37]:
class D2(nn.Module):
    def __init__(self, conv_dim=64, use_labels=False):
        super(D2, self).__init__()
        self.conv1 = conv(3, conv_dim, 4, bn=False)
        self.conv2 = conv(conv_dim, conv_dim * 2, 4)
        self.conv3 = conv(conv_dim * 2, conv_dim * 4, 4)
        n_out = 11 if use_labels else 1
        self.fc = conv(conv_dim * 4, n_out, 4, 1, 0, False)
        
    def forward(self, x):
        out = F.leaky_relu(self.conv1(x), 0.05)
        out = F.leaky_relu(self.conv2(out), 0.05)
        out = F.leaky_relu(self.conv3(out), 0.05)
        out = self.fc(out).squeeze()
        return out

### CycleGans

In [60]:
class CycleGans(object):
    def __init__(self, s, m):
        self.s = s
        self.m = m
        self.use_reconst_loss = True
        self.use_labels = True
        self.num_classes = 10
        self.beta1 = 0.5
        self.beta2 = 0.999
        self.g_conv_dim = 64
        self.d_conv_dim = 64
        self.train_iters = 500000
        self.batch_size = 64
        self.lr = 0.0002
        self.log_step = 1000
        self.sample_step = 5000
        self.sample_path = './samples' 
        self.model_path = './models'
        self.build_model()
    
    def build_model(self):
        self.g12 = G12(conv_dim=self.g_conv_dim)
        self.g21 = G21(conv_dim=self.g_conv_dim)
        self.d1 = D1(conv_dim=self.d_conv_dim, use_labels=self.use_labels)
        self.d2 = D2(conv_dim=self.d_conv_dim, use_labels=self.use_labels)
        
        g_params = list(self.g12.parameters()) + list(self.g21.parameters())
        d_params = list(self.d1.parameters()) + list(self.d2.parameters())
        
        self.g_optimizer = optim.Adam(g_params, self.lr, [self.beta1, self.beta2])
        self.d_optimizer = optim.Adam(g_params, self.lr, [self.beta1, self.beta2])
        
        if torch.cuda.is_available():
            self.g12.cuda()
            self.g21.cuda()
            self.d1.cuda()
            self.d2.cuda()
            
    def merge_images(self, sources, targets, k=10):
        _, _, h, w = sources.shape
        row = int(np.sqrt(self.batch_size))
        merged = np.zeros([3, row * h, row * w * 2])
        for idx, (s, t) in enumerate(zip(sources, targets)):
            i = idx // row
            j = idx % row
            merged[:, i*h:(i + 1) * h, (j * 2) * h: (j * 2 + 1) * h] = s
            merged[:, i*h:(i + 1) * h, (j * 2 + 1) * h:(j * 2 + 2) * h] = t
        return merged.transpose(1, 2, 0)
    
    def to_var(self, x):
        if torch.cuda.is_available():
            x = x.cuda()
        return Variable(x)
    
    def to_data(self, x):
        if torch.cuda.is_available():
            x = x.cpu()
        return torch.from_numpy(x.data.numpy())
    
    def reset_grad(self):
        self.g_optimizer.zero_grad()
        self.d_optimizer.zero_grad()
        
    def train(self):
        svhn_iter = iter(self.s)
        mnist_iter = iter(self.m)
        iter_per_epoch = min(len(svhn_iter), len(mnist_iter))
        
        fixed_svhn = self.to_var(svhn_iter.next()[0])
        fixed_mnist = self.to_var(mnist_iter.next()[0])
        
        criterion = nn.CrossEntropyLoss()
        
        for step in range(self.train_iters + 1):
            if (step + 1) % iter_per_epoch == 0:
                mnist_iter = iter(self.m)
                svhn_iter = iter(self.s)
                
            svhn, s_labels = svhn_iter.next()
            svhn, s_labels = self.to_var(svhn), self.to_var(s_labels).long().squeeze()
            mnist, m_labels = mnist_iter.next()
            mnist, m_labels = self.to_var(mnist), self.to_var(m_labels)
            
            if self.use_labels:
                mnist_fake_labels = self.to_var(
                    torch.Tensor([self.num_classes] * svhn.size(0)).long())
                svhn_fake_labels = self.to_var(
                    torch.Tensor([self.num_classes] * mnist.size(0)).long())
                
            self.reset_grad()
            out = self.d1(mnist)
            if self.use_labels:
                d1_loss = criterion(out, m_labels)
            else:
                d1_loss = torch.mean((out - 1) ** 2)
                
            out = self.d2(svhn)
            if self.use_labels:
                d2_loss = criterion(out, s_labels)
            else:
                d2_loss = torch.mean((out - 1) ** 2)
                
            d_mnist_loss = d1_loss
            d_svhn_loss = d2_loss
            d_real_loss = d1_loss + d2_loss
            d_real_loss.backward()
            self.d_optimizer.step()
            
            self.reset_grad()
            fake_svhn = self.g12(mnist)
            out = self.d2(fake_svhn)
            if self.use_labels:
                d2_loss = criterion(out, svhn_fake_labels)
            else:
                d2_loss = torch.mean(out**2)
            fake_mnist = self.g21(svhn)
            out = self.d1(fake_mnist)
            if self.use_labels:
                d1_loss = criterion(out, mnist_fake_labels)
            else:
                d1_loss = torch.mean(out**2)
            d_fake_loss = d1_loss + d2_loss
            d_fake_loss.backward()
            self.d_optimizer.step()
            
            self.reset_grad()
            fake_svhn =self.g12(mnist)
            out = self.d2(fake_svhn)
            reconst_mnist = self.g21(fake_svhn)
            if self.use_labels:
                g_loss = criterion(out, m_labels)
            else:
                g_loss = torch.mean((out - 1) ** 2)
                
            if self.use_reconst_loss:
                g_loss += torch.mean((mnist - reconst_mnist) ** 2)
            
            g_loss.backward()
            self.g_optimizer.step()
            
            self.reset_grad()
            fake_mnist = self.g21(svhn)
            out = self.d1(fake_mnist)
            reconst_svhn = self.g12(fake_mnist)
            if self.use_labels:
                g_loss = criterion(out, s_labels)
            else:
                g_loss = torch.mean((out - 1) ** 2)
            
            if self.use_reconst_loss:
                g_loss += torch.mean((svhn - reconst_svhn) ** 2)
            
            g_loss.backward()
            self.g_optimizer.step()
            
            if (step + 1) % self.log_step == 0:
                print('Step [%d/%d], d_real_loss: %.4f, d_fake_loss: %.4f, g_loss; %.4f' % (step+1,
                                                                                            self.train_iters,
                                                                                            d_real_loss.data.item(),
                                                                                            d_fake_loss.data.item(),
                                                                                            g_loss.data.item()))
                
            if (step + 1) % self.sample_step == 0:
                fake_svhn = self.g12(fixed_mnist)
                fake_mnist = self.g21(fixed_svhn)
                
                mnist, fake_mnist = self.to_data(fixed_mnist), self.to_data(fake_mnist)
                svhn, fake_svhn = self.to_data(fixed_svhn), self.to_data(fake_svhn)
                
                merged = self.merge_images(mnist, fake_svhn)
                path = os.path.join(self.sample_path, 'sample-%d-m-s.png' %(step + 1))
                imageio.imwrite(path, (merged[:, :, 0]*255).astype(np.uint8))
                print('saved %s' %path)
                
                merged = self.merge_images(svhn, fake_mnist)
                path = os.path.join(self.sample_path, 'sample-%d-s-m.png' %(step + 1))
                imageio.imwrite(path, (merged[:, :, 0]*255).astype(np.uint8))
                print('saved %s' %path)
                
            if (step + 1) % 5000 == 0:
                
                g12_path = os.join(self.model_path, 'g12-%d.pkl' % (step+1))
                g21_path = os.join(self.model_path, 'g21-%d.pkl' %(step + 1))
                d1_path = os.path.join(self.model_path, 'd1-%d.pkl' %(step + 1))
                d2_path = os.path.join(self.model_path, 'd2-%d.pkl' %(step + 1))
                
                torch.save(self.g12.state_dict(), g12_path)
                torch.save(self.g21.state_dict(), g21_path)
                torch.save(self.d1.state_dict(), d1_path)
                torch.save(self.d2.state_dict(), d2_path)

In [61]:
gans = CycleGans(svhn_loader, mnist_loader)
cudnn.benchmark = True

if not os.path.exists('./models'):
    os.makedirs('./models')
if not os.path.exists('./samples'):
    os.makedirs('./samples')
    
gans.train()

Step [10/50], d_real_loss: 4.8725, d_fake_loss: 2.2086, g_loss; 3.1746
Step [20/50], d_real_loss: 4.8294, d_fake_loss: 1.6392, g_loss; 3.1834
Step [30/50], d_real_loss: 4.9983, d_fake_loss: 1.5757, g_loss; 3.1731
Step [40/50], d_real_loss: 5.1253, d_fake_loss: 1.4935, g_loss; 3.1146
Step [50/50], d_real_loss: 4.8509, d_fake_loss: 1.4044, g_loss; 3.1127
saved ./samples/sample-50-m-s.png
saved ./samples/sample-50-s-m.png
