# CycleGAN: Unpaired Image-to-Image Translation

## 0. Some setup before we begin...

In this folder, please run the following commands to create empty folders for logging:

```
mkdir checkpoints_cyclegan
mkdir samples_cyclegan
```

Also, make sure the unzipped dataset folder ```summer2winter_yosemite``` is in the same folder as this file.

## 1. Load and Visualize the Data 

In [None]:
import os
import torch
from torch.utils.data import DataLoader
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
import warnings

%matplotlib inline

# Some matplotlib settings
plt.rcParams["figure.figsize"] = (16, 10)
plt.rcParams["axes.grid"] = False
plt.rcParams["xtick.major.bottom"] = False
plt.rcParams["ytick.major.left"] = False

# Device Settings
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Some training settings:
batch_size = 16
num_workers = 0

The dataset can be downloaded from [this link](https://s3.amazonaws.com/video.udacity-data.com/topher/2018/November/5be66e78_summer2winter-yosemite/summer2winter-yosemite.zip).

Move the unzipped ```summer2winter_yosemite``` folder into this repository, next to this notebook.

### Define function to generate train and test dataloaders

The following function returns train and test dataloaders, which we will use to visualize the data and train the model.

In [None]:
def get_data_loaders(image_type, image_dir='summer2winter_yosemite', transform=None, batch_size=16, num_workers=0):
    
    # get training and test directories
    image_type_test = "test_{}".format(image_type)
    train_path = os.path.join(image_dir, image_type)
    test_path = os.path.join(image_dir, image_type_test)

    # define datasets using ImageFolder
    train_dataset = datasets.ImageFolder(train_path, transform)
    test_dataset = datasets.ImageFolder(test_path, transform)

    # create and return DataLoaders
    train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

    return train_loader, test_loader

Now we can call the function to get two sets of dataloaders. 

But first, we must define an image transformation that converts the images from PIL format to Tensor format.

We also normalize each channel with a mean of 0.5 and a standard deviation of 0.5, which maps pixel values from [0, 1] to [-1, 1].

`transforms.Compose()` takes a list of transformations and applies them in sequence. 

In [None]:
image_size = 256
transformations = transforms.Compose([transforms.Resize(image_size), transforms.ToTensor(), 
                                      transforms.Normalize(mean=[0.5,0.5,0.5], std=[0.5,0.5,0.5])])

summer_dataloader, test_summer_dataloader = get_data_loaders(image_type='summer', transform=transformations, 
                                                            batch_size=batch_size, num_workers=num_workers)
winter_dataloader, test_winter_dataloader = get_data_loaders(image_type='winter', transform=transformations,
                                                            batch_size=batch_size, num_workers=num_workers)

A dataset created by ```datasets.ImageFolder``` returns two values per sample: 

1. the image
2. the corresponding label

When we wrap a dataloader around the dataset, each iteration yields a batch of images and the corresponding batch of labels. 

In our case, we are not interested in the labels (they are all 0), so to visualize the data we keep only the first of the returned values. To fetch a single batch, we create an iterator from the dataloader with `iter()` and call `next()` on it once.


### Visualize the data 

We define an inverse transform that changes a tensor back to a PIL Image. First, we "unnormalize" the tensor by applying the inverse of the normalization above. Then we call ```transforms.ToPILImage()```, which converts the tensor to a PIL Image.
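
To see why `mean=-1` and `std=1/0.5` undo the earlier normalization, recall that `transforms.Normalize` computes `(x - mean) / std` per channel. The plain-Python sketch below checks the round trip on a few scalar pixel values. The helper name `normalize` is a hypothetical stand-in for illustration and is not part of torchvision.

```python
def normalize(value, mean, std):
    """Scalar version of the per-channel formula (x - mean) / std."""
    return (value - mean) / std

# Forward: mean=0.5, std=0.5 maps [0, 1] to [-1, 1].
# Inverse: mean=-1, std=1/0.5 maps [-1, 1] back to [0, 1].
pixels = [0.0, 0.25, 0.5, 1.0]
normalized = [normalize(p, mean=0.5, std=0.5) for p in pixels]
restored = [normalize(n, mean=-1, std=1 / 0.5) for n in normalized]

print(normalized)  # [-1.0, -0.5, 0.0, 1.0]
print(restored)    # [0.0, 0.25, 0.5, 1.0]
```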

In [None]:
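def get_pil_img_from_tensor(img_tensor):
    # Normalize computes (x - mean) / std, so mean=-1 and std=2 give (x + 1) / 2,
    # which undoes the earlier (x - 0.5) / 0.5 normalization.
    inverse_transform = transforms.Compose([transforms.Normalize(mean=[-1, -1, -1], std=[1/0.5, 1/0.5, 1/0.5]), 
                                            transforms.ToPILImage()])
    return inverse_transform(img_tensor)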

def show_img(pil_img):
    plt.imshow(pil_img)

In [None]:
batch_summer_imgs, _ = next(iter(summer_dataloader))
batch_winter_imgs, _ = next(iter(winter_dataloader))

In [None]:
show_img(get_pil_img_from_tensor(torchvision.utils.make_grid(batch_summer_imgs)))

In [None]:
show_img(get_pil_img_from_tensor(torchvision.utils.make_grid(batch_winter_imgs)))

## 2. Define the Model 

A CycleGAN is composed of two discriminators and two generators.

### Discriminators
The discriminators, $D_X$ and $D_Y$, in this CycleGAN are convolutional neural networks that see an image and attempt to classify it as real or fake. In this case, real is indicated by an output close to 1 and fake by an output close to 0. The discriminators have the following architecture:

<img src='imgs/discriminator_layers.png' width=80% />

### A function that creates a general convolution layer 

In [None]:
import torch.nn as nn
import torch.nn.functional as F

# helper conv function. Note that DEFAULT stride = 2
def conv(in_channels, out_channels, kernel_size, stride=2, padding=1, batch_norm=True):
    """Creates a convolutional layer, with optional batch normalization.
    """
    layers = []
    conv_layer = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, 
                           kernel_size=kernel_size, stride=stride, padding=padding, bias=False)
    
    layers.append(conv_layer)

    if batch_norm:
        layers.append(nn.BatchNorm2d(out_channels))
    return nn.Sequential(*layers)

### The Discriminator

Let's create the discriminator!

> **Exercise**: Create a discriminator model using the `conv` function above. Refer to the image above for what the network should look like!

In [None]:
class Discriminator(nn.Module):
    
    def __init__(self, conv_dim=64):
        super(Discriminator, self).__init__()

        # Define all convolutional layers
        # Should accept an RGB image as input and output a single value

    def forward(self, x):
        # Define feed forward
        
        return out


In [None]:
# Test Discriminator

with torch.no_grad():
    D_X = Discriminator()
    rand_input = torch.randn(3, 3, 128, 128)
    out = D_X(rand_input)
    assert(out.size() == (3, 1, 8, 8))

### The Generator

<img src='imgs/cyclegan_generator_ex.png' width=90% />


### Residual Block

<img src='imgs/resnet_block.png' width=40%/>

Convolution output size formula:
```
output_size = 1 + floor((input_size - kernel_size + 2*padding) / stride)
```
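
As a quick sanity check of the formula, the sketch below evaluates it in plain Python. The function name `conv_output_size` and the kernel size of 4 are assumptions made for illustration; the defaults mirror the `conv` helper above (stride 2, padding 1).

```python
def conv_output_size(input_size, kernel_size, stride=2, padding=1):
    """Spatial size after one convolution (integer division rounds down)."""
    return 1 + (input_size - kernel_size + 2 * padding) // stride

# With stride=2, padding=1 and an assumed kernel_size of 4,
# each layer halves the spatial size.
print(conv_output_size(128, kernel_size=4))  # 64

# kernel_size=3 with stride=1 and padding=1 keeps the size unchanged,
# which is what a residual block needs to add its input to its output.
print(conv_output_size(128, kernel_size=3, stride=1, padding=1))  # 128
```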

In [None]:
# residual block class
class ResidualBlock(nn.Module):

    def __init__(self, conv_dim):
        super(ResidualBlock, self).__init__()
        # conv_dim = number of input (and output) channels
        
        # define two convolutional layers + batch normalization that will act as our residual function, F(x)
        # layers should have the same shape input as output; I suggest a kernel_size of 3
        
        self.conv_layer1 = ?
        
        self.conv_layer2 = ?
        
    def forward(self, x):
        # Fill out feed forward. Use F.relu() for relu
        out = ?
        out = ?
        return F.relu(out)

### Deconv Layer 

In [None]:
def deconv(in_channels, out_channels, kernel_size, stride=2, padding=1, batch_norm=True):
    """Creates a transpose convolutional layer, with optional batch normalization.
    """
    layers = []
    # append transpose conv layer
    layers.append(nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding, bias=False))
    # optional batch norm layer
    if batch_norm:
        layers.append(nn.BatchNorm2d(out_channels))
    return nn.Sequential(*layers)

### The Generator

Use:
* 3 conv layers,
* Resnet layers (given)
* 3 transpose conv layers

Output shape of transpose conv can be calculated by the following formula:

```
output_size = stride * (input_size-1) + kernel_size - 2*padding
```

Select the correct input dims, output dims, and kernel size for each layer. (Default stride = 2, default padding = 1.)
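
The transpose convolution formula can be checked the same way. This is a minimal sketch in plain Python, assuming no output padding and no dilation; the function name `deconv_output_size` and the kernel sizes tried are illustrative choices, not values prescribed by this notebook.

```python
def deconv_output_size(input_size, kernel_size, stride=2, padding=1):
    """Spatial size after one transpose convolution (no output padding, no dilation)."""
    return stride * (input_size - 1) + kernel_size - 2 * padding

# With the default stride and padding, a kernel size of 4 exactly doubles the size...
print(deconv_output_size(64, kernel_size=4))  # 128

# ...while a kernel size of 3 comes up one pixel short.
print(deconv_output_size(64, kernel_size=3))  # 127
```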

In [None]:
class CycleGenerator(nn.Module):
    
    def __init__(self, conv_dim=64, n_res_blocks=6):
        super(CycleGenerator, self).__init__()

        # 1. Encoder 

        # 2. Resnet part. Hint: add Residual blocks to a list, then use nn.Sequential(*res_layers)
        # This is useful for taking a variable number of res_blocks

        # 3. Decoder
        
    def forward(self, x):
        """Given an image x, returns a transformed image."""
        # define feedforward behavior, applying activations as necessary. Make sure to finish off with tanh activation

        return out

In [None]:
# Test Generator
with torch.no_grad():
    G_XtoY = CycleGenerator()
    rand_input = torch.randn(3, 3, 128, 128)
    out = G_XtoY(rand_input)
    assert(out.size() == rand_input.size())

### Model Creator 

In [None]:
def create_model(g_conv_dim=64, d_conv_dim=64, n_res_blocks=6, device='cuda'):
    
    # Instantiate generators
    G_XtoY = ?
    G_YtoX = ?
    
    # Instantiate discriminators
    D_X = ?
    D_Y = ?

    # Cast to appropriate device. 
    G_XtoY.to(device)
    G_YtoX.to(device)
    D_X.to(device)
    D_Y.to(device)
    print('Models loaded on {}'.format(device))

    return G_XtoY, G_YtoX, D_X, D_Y

In [None]:
G_XtoY, G_YtoX, D_X, D_Y = create_model(device=device)

At this point, you can open a terminal on the server and run ```nvidia-smi```. On the first GPU, you will see that some memory has been allocated for your model (~1GB). 

If you want to skip training and just visualize samples, skip to Section 6.

## 3. Loss functions 

* Use `torch.mean()` for mean across inputs.
* Use `torch.abs()` to get absolute value of a tensor
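
For reference, one common way to write these losses is given below. It assumes a least-squares formulation (as the function names suggest) with real targets of 1 and fake targets of 0. Here $D_i$ denotes the $i$-th of $N$ discriminator outputs, $x_j$ and $\hat{x}_j$ are the $j$-th of $M$ elements of the real and reconstructed images, and $\lambda$ is `lambda_weight`. This notation is introduced here for illustration only.

$$
\mathcal{L}_{\text{real}} = \frac{1}{N}\sum_{i=1}^{N}\left(D_i - 1\right)^2,
\qquad
\mathcal{L}_{\text{fake}} = \frac{1}{N}\sum_{i=1}^{N} D_i^{2}
$$

$$
\mathcal{L}_{\text{cycle}} = \lambda \cdot \frac{1}{M}\sum_{j=1}^{M}\left| x_j - \hat{x}_j \right|
$$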

In [None]:
def real_mse_loss(D_out):
    # How close is the produced output to being REAL?
    return ?

def fake_mse_loss(D_out):
    # How close is the produced output to being FAKE?
    return ?

def cycle_consistency_loss(real_im, reconstructed_im, lambda_weight):
    # Reconstruction loss
    reconstr_loss = ?
    return lambda_weight*reconstr_loss

## 4. Optimizers 

- We will use Adam with an initial learning rate of ```lr=0.0002```. 

- In PyTorch, we pass the model's parameters into the optimizer. The model's parameters can be obtained by calling ```model.parameters()```. 

- ```model.parameters()``` returns an iterator rather than a list. When we want to optimize the parameters of two models with one optimizer, we can convert each model's parameters to a list, concatenate the two lists, and pass the result into the optimizer. 
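
The last point can be sketched on two tiny stand-in models. The `nn.Linear` layers and the names `model_a`, `model_b`, and `combined_params` are hypothetical placeholders chosen for illustration; the same pattern should apply to the two generators.

```python
import torch.nn as nn
import torch.optim as optim

# Two tiny stand-in models (hypothetical; the real generators are much larger).
model_a = nn.Linear(4, 2)
model_b = nn.Linear(2, 1)

# parameters() returns an iterator, so convert each to a list before concatenating.
combined_params = list(model_a.parameters()) + list(model_b.parameters())
optimizer = optim.Adam(combined_params, lr=0.0002, betas=(0.5, 0.999))

# Each Linear layer has a weight and a bias, so the optimizer tracks 4 tensors.
print(len(optimizer.param_groups[0]["params"]))  # 4
```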

In [None]:
import torch.optim as optim

# Hyperparameters for Adam optimizer
lr = 0.0002
beta1 = 0.5
beta2 = 0.999

d_x_optimizer = optim.Adam(params=D_X.parameters(), lr=lr, betas=[beta1, beta2])
d_y_optimizer = optim.Adam(params=D_Y.parameters(), lr=lr, betas=[beta1, beta2])

# Create a single optimizer for both generators
generator_params = ?
g_optimizer = optim.Adam(params=generator_params, lr=lr, betas=[beta1, beta2])

## 5. Training 

### Training Loop 

Skip this part if you want to go directly to visualizing the results with a pretrained network.

Training takes around **80 minutes** on a single **NVIDIA 2080Ti** GPU (CUDA 10.2) for 6000 iterations and uses around **7GB** of VRAM. Note that the `n_epochs` argument of the training loop counts these single-batch iterations, not full passes over the dataset.

In [None]:
import time
from helpers import save_samples, checkpoint

In [None]:
def training_loop(dataloader_X, dataloader_Y, test_dataloader_X, test_dataloader_Y, 
                  n_epochs=1000):
    
    since = time.time()
    print_every=50
    sample_every=100
    
    losses = []      # keep track of losses over time

    
    test_iter_X = iter(test_dataloader_X)
    test_iter_Y = iter(test_dataloader_Y)

    # Get some fixed data from domains X and Y for sampling. These are images that are held
    # constant throughout training, that allow us to inspect the model's performance.
    fixed_X = next(test_iter_X)[0]
    fixed_Y = next(test_iter_Y)[0]

    # batches per epoch
    iter_X = iter(dataloader_X)
    iter_Y = iter(dataloader_Y)
    batches_per_epoch = min(len(iter_X), len(iter_Y))

    for epoch in range(1, n_epochs+1):

        # Reset iterators for each epoch
        if epoch % batches_per_epoch == 0:
            iter_X = iter(dataloader_X)
            iter_Y = iter(dataloader_Y)

        images_X, _ = next(iter_X)
        images_Y, _ = next(iter_Y)
        
        # move images to device
        images_X = images_X.to(device)
        images_Y = images_Y.to(device)

        # ============================================
        #            TRAIN THE DISCRIMINATORS
        # ============================================
        """D_X """ 
        d_x_optimizer.zero_grad()

        # Train with real X images
        
        out_x = ?
        D_X_real_loss = ?
        
        # Train with fake X images
        fake_X = ?
        out_x = ?
        D_X_fake_loss = ?
        
        d_x_loss = D_X_real_loss + D_X_fake_loss
        d_x_loss.backward()
        d_x_optimizer.step()
        
        """D_Y """ 
        d_y_optimizer.zero_grad()
        
        # Train with real Y images
        out_y = ?
        D_Y_real_loss = ?
        
        # Train with fake Y images
        fake_Y = ?
        out_y = ?
        D_Y_fake_loss = ?

        d_y_loss = D_Y_real_loss + D_Y_fake_loss
        d_y_loss.backward()
        d_y_optimizer.step()


        # =========================================
        #            TRAIN THE GENERATORS
        # =========================================

        """G : G_XtoY + G_YtoX 
        For reconstructed loss, use a lambda weight of 10
        """ 
        

        g_optimizer.zero_grad()
        
        fake_X = ?
        out_x = ?
        g_YtoX_loss = ?

        fake_Y = ?
        out_y = ?
        g_XtoY_loss = ?

        reconstructed_X = ?
        reconstructed_x_loss = ?
        
        reconstructed_Y = ?
        reconstructed_y_loss = ?

        g_total_loss = g_YtoX_loss + g_XtoY_loss + reconstructed_y_loss + reconstructed_x_loss
        g_total_loss.backward()
        g_optimizer.step()

        # Print
        if epoch % print_every == 0:
            # append real and fake discriminator losses and the generator loss
            losses.append((d_x_loss.item(), d_y_loss.item(), g_total_loss.item()))
            time_elapsed = time.time() - since
            print('Epoch [{:5d}/{:5d}] | d_X_loss: {:6.4f} | d_Y_loss: {:6.4f} | g_total_loss: {:6.4f} time : {:.0f}m {:.0f}s'.format(
                    epoch, n_epochs, d_x_loss.item(), d_y_loss.item(), g_total_loss.item(),time_elapsed // 60, time_elapsed % 60))

        
        if epoch % sample_every == 0:
            G_YtoX.eval() # set generators to eval mode for sample generation
            G_XtoY.eval()
            save_samples(epoch, fixed_Y, fixed_X, G_YtoX, G_XtoY, batch_size=16, sample_dir='samples_cyclegan')
            G_YtoX.train()
            G_XtoY.train()

        checkpoint_every=1
        if epoch % checkpoint_every == 0:
            checkpoint(epoch, G_XtoY, G_YtoX, D_X, D_Y)

    return losses

In [None]:
n_epochs = 6000 # number of training iterations; keep this small when first testing that the model works

losses = training_loop(summer_dataloader, winter_dataloader, test_summer_dataloader, test_winter_dataloader, n_epochs=n_epochs)

### Visualize Training Losses

In [None]:
fig, ax = plt.subplots(figsize=(12,8))
losses = np.array(losses)
plt.plot(losses.T[0], label='Discriminator, X', alpha=0.5)
plt.plot(losses.T[1], label='Discriminator, Y', alpha=0.5)
plt.plot(losses.T[2], label='Generators', alpha=0.5)
plt.title("Training Losses")
plt.legend()

## 6. Visualize the Results 

### (Optional) Load Pretrained Models

If we lack the time or resources to fully train the model, we can load a pretrained model and visualize the results. Use the ```save_samples``` function, as in the training loop, to generate the samples.

In [None]:
from helpers import load_checkpoints
load_checkpoints(G_XtoY, G_YtoX, D_X, D_Y, checkpoint_dir='checkpoints_cyclegan_pretrained')

### Visualize Samples 

In [None]:
import matplotlib.image as mpimg

# helper visualization code
def view_samples(iteration, sample_dir='samples_cyclegan'):
    
    # samples are named by iteration
    path_XtoY = os.path.join(sample_dir, 'sample-{:06d}-X-Y.png'.format(iteration))
    path_YtoX = os.path.join(sample_dir, 'sample-{:06d}-Y-X.png'.format(iteration))
    
    # read in those samples; stop early if the files do not exist
    try: 
        x2y = mpimg.imread(path_XtoY)
        y2x = mpimg.imread(path_YtoX)
    except OSError:
        print('Invalid number of iterations.')
        return
    
    fig, (ax1, ax2) = plt.subplots(figsize=(18,20), nrows=2, ncols=1, sharey=True, sharex=True)
    ax1.imshow(x2y)
    ax1.set_title('X to Y')
    ax2.imshow(y2x)
    ax2.set_title('Y to X')

In [None]:
# view samples at iteration 6000
view_samples(6000, 'samples_cyclegan')

In [None]:
view_samples(1000, 'samples_cyclegan')