# Cycle GAN

In [3]:
import os

from torch import nn
from torch.nn import functional as F

In [4]:
# Relative path to the summer2winter_yosemite dataset directory, resolved
# against the parent of the current working directory.
DATASET_PATH = os.path.join(os.pardir, 'Datasets', 'summer2winter_yosemite')

In [5]:


class ResidualBlock(nn.Module):
    """Residual unit: two padded 3x3 convolutions plus an identity shortcut.

    The channel count is unchanged (``in_features`` in, ``in_features`` out),
    and each 3x3 convolution is preceded by a 1-pixel reflection pad, so the
    output has the same shape as the input and can be summed with it.

    Args:
        in_features: Number of channels of the input (and output) tensor.
    """

    def __init__(self, in_features):
        super().__init__()

        def padded_conv_norm():
            # One pad -> conv -> norm stage; fresh modules on every call so
            # the two stages do not share weights.
            return [
                # Reflection padding is preferred over zero padding because
                # it better conserves the image's distribution.
                nn.ReflectionPad2d(1),
                nn.Conv2d(in_features, in_features, 3),
                # Instance norm normalizes less well than batch norm, but it
                # is better at conserving the contrast.
                nn.InstanceNorm2d(in_features),
            ]

        # Layout: [pad, conv, norm, ReLU, pad, conv, norm]. No activation
        # follows the second stage; its output goes straight into the sum.
        self.conv_block = nn.Sequential(
            *padded_conv_norm(),
            nn.ReLU(True),
            *padded_conv_norm(),
        )

    def forward(self, x):
        """Return ``x`` plus the convolutional transformation of ``x``."""
        residual = self.conv_block(x)
        return residual + x

In [9]:
class Generator(nn.Module):
    def __init__(self, input_nc, output_nc, n_residual_blocks=9):
        super(Generator, self).__init__()

        model = [
            nn.ReflectionPad2d(3),
            nn.Conv2d(input_nc, 64, F),
            nn.InstanceNorm2d(64),
            nn.ReLU(True),
        ]

        in_features = 64
        out_features = in_features * 2

        # Encoding
        for _ in range(2):
            model += [
                nn.Conv2d(in_features, out_features, 3, stride=2, padding=1),
                nn.InstanceNorm2d(out_features),
                nn.ReLU(True),
            ]
            in_features = out_features
            out_features = in_features * 2

        # Residual transformations
        for _ in range(n_residual_blocks):
            model += [ResidualBlock(in_features)]

        # Decoding
        out_features /= 2
        