# CycleGAN

In [15]:
import os
import sys
import glob
import random
import itertools
from PIL import Image

import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

from utils import ReplayBuffer

In [4]:
# Dataset root; a relative path, so it is resolved against the current working directory.
DATASET_PATH = os.path.join(os.pardir, "Datasets", "summer2winter_yosemite")

In [14]:
# Make the sibling helper directory importable (presumably where `utils`, which
# provides ReplayBuffer, lives — verify). The path is relative to the current
# working directory. The membership check keeps re-runs of this cell from
# piling up duplicate sys.path entries.
# NOTE(review): in file order this runs *after* `from utils import ReplayBuffer`;
# when executed top-to-bottom it must come before that import.
if "../functions/" not in sys.path:
    sys.path.append("../functions/")

In [5]:
class ResidualBlock(nn.Module):
    """Residual unit: two padded 3x3 convolutions plus an identity skip.

    ``forward(x)`` returns ``conv_block(x) + x``. Each 3x3 convolution is
    preceded by a reflection pad of 1, so height, width and channel count are
    preserved and the input can be added back unchanged.

    Args:
        in_features: number of input channels (also the number of output channels).
    """

    def __init__(self, in_features):
        super().__init__()

        # Reflection padding (rather than zero padding) avoids adding artificial
        # borders. InstanceNorm normalizes each sample's channels on their own
        # statistics instead of batch statistics.
        # Layer order: pad, conv, norm, ReLU, pad, conv, norm (no ReLU after the
        # second stage, so the residual sum is not clipped).
        layers = []
        for stage in range(2):
            layers.append(nn.ReflectionPad2d(1))
            layers.append(nn.Conv2d(in_features, in_features, 3))
            layers.append(nn.InstanceNorm2d(in_features))
            if stage == 0:
                layers.append(nn.ReLU(True))
        self.conv_block = nn.Sequential(*layers)

    def forward(self, x):
        residual = self.conv_block(x)
        return residual + x

In [11]:
class Generator(nn.Module):
    """ResNet-style image-to-image generator.

    Layout:
        7x7 conv (input_nc -> 64)
        -> two stride-2 3x3 convs (64 -> 128 -> 256), each halving H and W
        -> ``n_residual_blocks`` ResidualBlocks at 256 channels
        -> two stride-2 3x3 transposed convs (256 -> 128 -> 64), each doubling H and W
        -> 7x7 conv (64 -> output_nc) -> tanh

    Args:
        input_nc: number of channels of the input image.
        output_nc: number of channels of the generated image.
        n_residual_blocks: number of residual blocks in the bottleneck.

    The output lies in [-1, 1] because of the final tanh. The spatial size
    round-trips exactly when the input height and width are divisible by 4.
    """

    def __init__(self, input_nc, output_nc, n_residual_blocks=9):
        super(Generator, self).__init__()

        # Initial 7x7 convolution; a reflection pad of 3 keeps H and W unchanged.
        model = [
            nn.ReflectionPad2d(3),
            nn.Conv2d(input_nc, 64, 7),
            nn.InstanceNorm2d(64),
            nn.ReLU(True),
        ]

        in_features = 64
        out_features = in_features * 2

        # Encoding: each step halves H/W and doubles the channel count.
        for _ in range(2):
            model += [
                nn.Conv2d(in_features, out_features, 3, stride=2, padding=1),
                nn.InstanceNorm2d(out_features),
                nn.ReLU(True),
            ]
            in_features = out_features
            out_features = in_features * 2

        # Residual transformations at the bottleneck width.
        for _ in range(n_residual_blocks):
            model += [ResidualBlock(in_features)]

        # Decoding: each step doubles H/W and halves the channel count.
        # Integer division: nn.ConvTranspose2d requires an int channel count.
        out_features = in_features // 2
        for _ in range(2):
            model += [
                nn.ConvTranspose2d(
                    in_features, out_features, 3, stride=2, padding=1, output_padding=1
                ),
                nn.InstanceNorm2d(out_features),
                nn.ReLU(True),
            ]
            in_features = out_features
            out_features = in_features // 2

        # Output projection back to image channels; tanh bounds values to [-1, 1].
        model += [nn.ReflectionPad2d(3), nn.Conv2d(in_features, output_nc, 7), nn.Tanh()]

        self.model = nn.Sequential(*model)

    def forward(self, x):
        return self.model(x)


In [17]:
class Discrimator(nn.Module):
    """PatchGAN-style discriminator.

    Five 4x4 convolutions (input_nc -> 64 -> 128 -> 256 -> 512 -> 1 channels)
    produce a spatial map of scores, one per receptive-field patch.
    - The first three convolutions use stride 2; the last two use stride 1.
    - InstanceNorm is applied after every convolution except the first and
      the last.
    - LeakyReLU(0.2) follows every convolution except the last.
    - No activation is applied to the final score map.

    ``forward`` averages the score map over its spatial dimensions and returns
    a tensor of shape ``(batch, 1)``.

    Args:
        input_nc: number of channels of the input image.
    """

    def __init__(self, input_nc):
        super(Discrimator, self).__init__()

        model = [
            # No normalization on the first layer.
            nn.Conv2d(input_nc, 64, 4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 128, 4, stride=2, padding=1),
            nn.InstanceNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(128, 256, 4, stride=2, padding=1),
            nn.InstanceNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(256, 512, 4, padding=1),
            nn.InstanceNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),
            # One score channel per spatial location (patch).
            nn.Conv2d(512, 1, 4, padding=1)
        ]

        self.model = nn.Sequential(*model)

    def forward(self, x):
        x = self.model(x)
        # Average the patch-score map over H and W, then flatten to (batch, 1).
        return F.avg_pool2d(x, x.size()[2:]).view(x.size()[0], -1)


# Correctly spelled alias; the original `Discrimator` name is kept for existing callers.
Discriminator = Discrimator

In [18]:
class ImageDataset(Dataset):
    """Unpaired two-domain image dataset.

    Reads image files from ``<base_dir>/<split>/A/`` and ``<base_dir>/<split>/B/``
    (sorted by path). Item ``idx`` pairs the ``idx``-th domain-A image with a
    uniformly random domain-B image, so pairs are not aligned. The A index wraps
    around, because the dataset length is the size of the larger domain.

    Args:
        base_dir: dataset root directory.
        transforms: optional list of callables, composed with
            ``torchvision.transforms.Compose`` and applied to every image. When
            ``None``, the opened PIL images are returned untransformed.
        split: sub-directory to read (default ``'train'``).

    NOTE(review): images are not converted to a fixed mode (e.g. RGB). A
    grayscale file would yield a different channel count — confirm the
    dataset is uniform.
    """

    def __init__(self, base_dir, transforms=None, split='train'):
        # The `transforms` parameter shadows the module-level torchvision
        # `transforms` import, so Compose is imported explicitly here.
        if transforms is None:
            self.transform = None
        else:
            from torchvision.transforms import Compose
            self.transform = Compose(transforms)
        self.files_A = sorted(glob.glob(os.path.join(base_dir, f"{split}/A/*.*")))
        self.files_B = sorted(glob.glob(os.path.join(base_dir, f"{split}/B/*.*")))

    def __len__(self):
        # One pass covers every image of the larger domain.
        return max(len(self.files_A), len(self.files_B))

    def _load(self, path):
        """Open the image at ``path`` and apply the configured transform, if any."""
        image = Image.open(path)
        return image if self.transform is None else self.transform(image)

    def __getitem__(self, idx):
        # __len__ may exceed len(files_A); wrap instead of raising IndexError.
        image_A = self._load(self.files_A[idx % len(self.files_A)])
        image_B = self._load(self.files_B[random.randint(0, len(self.files_B) - 1)])

        return {"A": image_A, "B": image_B}

In [None]:
class Class():
    
    def __init__(self) -> None:
        