In [42]:
# Imports
import torch
import torch.nn as nn 
from torchvision import transforms
from  matplotlib import pyplot as plt
from utils import CycleDataset, show_batch
import itertools
from tqdm.notebook import tqdm
from torchvision.utils import make_grid

In [43]:
# Class for storing things such as learning rate, image size...
class Args:
    def __init__(self):
        self.lr = 2e-4
        self.epochs = 200
        self.b1 = 0.5
        self.b2 = 0.999
        self.img_size = 64
        self.pixels = int(self.img_size ** 2)
        self.channels = 3
        self.img_tuple = (self.channels, self.img_size, self.img_size)
        self.batch_size = 1
        self.d_loss_threshold = 0.5
        self.n_res_blocks = 9
        self.decay_epoch = 5
        self.start_epoch = 1
args = Args()

In [47]:
# Loading the images.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
transform = transforms.Compose([
    transforms.Resize(int(args.img_size*1.12), Image.BICUBIC),
    transforms.RandomCrop((args.img_size, args.img_size)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
loader = torch.utils.data.DataLoader(
    CycleDataset("data/", transform=transform, unaligned=True),
    batch_size=args.batch_size,
    shuffle=True
)
test_loader = torch.utils.data.DataLoader(
    CycleDataset("data/", transform=transform, unaligned=True),
    batch_size=4,
    shuffle=True
)

In [49]:
# Defining the Generator. 
class ResidualBlock(nn.Module):
    def __init__(self, in_features):
        super().__init__()
        self.pad1 = nn.ReflectionPad2d(1)
        self.pad2 = nn.ReflectionPad2d(1)
        self.conv1 = nn.Conv2d(in_features, in_features, 3)
        self.conv2 = nn.Conv2d(in_features, in_features, 3)
        self.norm1 = nn.InstanceNorm2d(in_features)
        self.norm2 = nn.InstanceNorm2d(in_features)
    
    def block(self, x):
        x = self.pad1(x)
        x = self.conv1(x)
        x = self.norm1(x)
        x = x.relu()
        x = self.pad2(x)
        x = self.conv2(x)
        return self.norm2(x)
    
    def forward(self, x):
        return x + self.block(x)
    
class Generator(nn.Module):
    def __init__(self):
        super().__init__()
        self.out_features = 64
        self.in_features = self.out_features
        self.__compile_model()
        
    # First convolutional block. 
    def __conv1(self):
        self.model += [
            nn.ReflectionPad2d(args.channels),
            nn.Conv2d(args.channels, self.out_features, 7),
            nn.InstanceNorm2d(self.out_features),
            nn.ReLU(inplace=True)
        ]
                       
    # Downsampling
    def __downsample(self):
        for _ in range(2):
            self.out_features *= 2  
            self.model += [
                nn.Conv2d(self.in_features, self.out_features, 3, stride=2, padding=1),
                nn.InstanceNorm2d(self.out_features),
                nn.ReLU(inplace=True)
            ]
            self.in_features = self.out_features

    def __residual_blocks(self):
        for _ in range(args.n_res_blocks):
            self.model += [ResidualBlock(self.out_features)]

    # Upsampling 
    def __upsample(self):
        for _ in range(2):
            self.out_features //= 2
            self.model += [
                nn.Upsample(scale_factor=2),
                nn.Conv2d(self.in_features, self.out_features, 3, stride=1, padding=1),
                nn.ReLU(inplace=True)
            ]
            self.in_features = self.out_features

    # Output layer
    def __output(self):
        self.model += [
            nn.ReflectionPad2d(args.channels),
            nn.Conv2d(self.out_features, args.channels, 7),
            nn.Tanh()
        ]
    
    # Compiles the model
    def __compile_model(self):
        self.model = []
        self.__conv1()
        self.__downsample()
        self.__residual_blocks()
        self.__upsample()
        self.__output()
        self.model = nn.Sequential(*self.model)
    
    def forward(self, x):
        return self.model(x)


In [50]:
# Defining the discriminator.
class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.output_shape = (1, args.img_size // 2 ** 4, args.img_size // 2 ** 4)
        self.__compile_model()
    
    def __discriminator_block(self, in_filters, out_filters, normalize=True):
        layers = [nn.Conv2d(in_filters, out_filters, 4, stride=2, padding=1)]
        if normalize:
            layers.append(nn.InstanceNorm2d(out_filters))
        layers.append(nn.LeakyReLU(0.2, inplace=True))
        return layers 

    def __compile_model(self):
        self.model = nn.Sequential(
            *self.__discriminator_block(args.channels, 64,  normalize=False),
            *self.__discriminator_block(64, 128),
            *self.__discriminator_block(128, 256),
            *self.__discriminator_block(256, 512),
            nn.ZeroPad2d((1, 0, 1, 0)),
            nn.Conv2d(512, 1, 4, padding=1)
        )
    
    def forward(self, x):
        return self.model(x)

In [51]:
# Initializes weights to a Gaussian distribution. 
def init_weights(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
        if hasattr(m, 'bias') and m.bias is not None:
            nn.init.constant_(m.bias.data, 0.0)
    elif classname.find('BatchNorm2d') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0.0)

In [52]:
# Initalizing the networks. 
gen_ab = Generator().to(device)
gen_ba = Generator().to(device)
disc_a = Discriminator().to(device)
disc_b = Discriminator().to(device)
loss_fn_gan = nn.MSELoss()
loss_fn_cycle = nn.L1Loss()
loss_fn_identity = nn.L1Loss()

gen_ab.apply(init_weights)
gen_ba.apply(init_weights)
disc_a.apply(init_weights)
disc_b.apply(init_weights)

Discriminator(
  (model): Sequential(
    (0): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): LeakyReLU(negative_slope=0.2, inplace=True)
    (2): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (3): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (4): LeakyReLU(negative_slope=0.2, inplace=True)
    (5): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (6): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (7): LeakyReLU(negative_slope=0.2, inplace=True)
    (8): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (9): InstanceNorm2d(512, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (10): LeakyReLU(negative_slope=0.2, inplace=True)
    (11): ZeroPad2d((1, 0, 1, 0))
    (12): Conv2d(512, 1, kernel_size=(4, 4), stride=(1, 1), padding=(1, 1))
  )
)

In [53]:
# Setting up optimizers. 
optim_g = torch.optim.Adam(
    itertools.chain(gen_ab.parameters(), gen_ba.parameters()), lr=args.lr, betas=(args.b1, args.b2)
    )
optim_d_a = torch.optim.Adam(
    disc_a.parameters(), lr=args.lr, betas=(args.b1, args.b2)
)
optim_d_b = torch.optim.Adam(
    disc_b.parameters(), lr=args.lr, betas=(args.b1, args.b2)
)

In [54]:
# Learning rate scheduler.
class LambdaLR:
    def __init__(self, n_epochs, offset, decay_start_epoch):
        assert (n_epochs - decay_start_epoch) > 0, "n_epochs must be larger than decay_start_epoch"
        self.n_epochs = n_epochs 
        self.offset = offset 
        self.decay_start_epoch = decay_start_epoch
    
    def step(self, epoch):
        return 1.0 - max(0, epoch+self.offset - self.decay_start_epoch) / (self.n_epochs - self.decay_start_epoch)

In [55]:
lr_scheduler_g = torch.optim.lr_scheduler.LambdaLR(
    optim_g,
    lr_lambda=LambdaLR(args.epochs, args.start_epoch, args.decay_epoch).step
)
lr_scheduler_d_a = torch.optim.lr_scheduler.LambdaLR(
    optim_d_a,
    lr_lambda=LambdaLR(args.epochs, args.start_epoch, args.decay_epoch).step
)
lr_scheduler_d_b = torch.optim.lr_scheduler.LambdaLR(
    optim_d_b,
    lr_lambda=LambdaLR(args.epochs, args.start_epoch, args.decay_epoch).step
)

In [57]:
# Function to show some real and fake images. 
def sample_imgs():
    imgs = next(iter(test_loader))
    gen_ab.eval()
    gen_ba.eval()
    real_B = imgs['B'].to(device) # Photos
    with torch.no_grad():
        fake_A = gen_ba(real_B).detach()
    
    show_batch(real_B.cpu(), title="Photos")
    show_batch(fake_A.cpu(), title="Photos with style transfer applied")


In [None]:
# Training loop 
g_loss_list, d_loss_list = [], []

for epoch in range(args.start_epoch, args.epochs + args.start_epoch):
    g_epoch_loss = 0
    d_epoch_loss = 0
    for _, batch in tqdm(enumerate(loader)):
        real_a = batch['A'].to(device)
        real_b = batch['B'].to(device)
        real_label = torch.ones((real_a.size(0), *disc_a.output_shape), dtype=torch.float).to(device)
        fake_label = torch.zeros((real_a.size(0), *disc_a.output_shape), dtype=torch.float).to(device)

        ############
        # GENERATORS
        ############
        gen_ab.train()
        gen_ba.train()
        optim_g.zero_grad()
        
        # Identity loss 
        # Feeding A to gen_ba should produce A, same goes for B and gen_ab. 
        loss_id_a = loss_fn_identity(gen_ba(real_a), real_a)
        loss_id_b = loss_fn_identity(gen_ab(real_b), real_b) 
        loss_id = (loss_id_a + loss_id_b) / 2

        # GAN loss
        fake_b = gen_ab(real_a) # Fake photo.
        loss_gan_ab = loss_fn_gan(disc_b(fake_b), real_label) 
        fake_a = gen_ba(real_b) # Fake painting.
        loss_gan_ba = loss_fn_gan(disc_a(fake_a), real_label)
        loss_gan = (loss_gan_ab + loss_gan_ba) / 2

        # Cycle loss 
        recov_a = gen_ba(fake_b) # Fake painting created from fake photo. 
        loss_cycle_a = loss_fn_cycle(recov_a, real_a)
        recov_b = gen_ab(fake_a) # Fake photo created from fake painting. 
        loss_cycle_b = loss_fn_cycle(recov_b, real_b)
        loss_cycle = (loss_cycle_a + loss_cycle_b) / 2

        # Total loss 
        loss_g = loss_gan + (10 * loss_cycle) + (5 * loss_id)
        loss_g.backward()
        optim_g.step()

        #################
        # DISCRIMINATOR A 
        #################
        optim_d_a.zero_grad()
        loss_real = loss_fn_gan(disc_a(real_a), real_label) # Real paintings should be classified as real.
        loss_fake = loss_fn_gan(disc_a(fake_a.detach()), fake_label) # Fake paintings should be classified as fake. 
        loss_d_a = (loss_real + loss_fake) / 2
        loss_d_a.backward()
        optim_d_a.step()

        #################
        # DISCRIMINATOR B
        #################
        optim_d_b.zero_grad()
        loss_real = loss_fn_gan(disc_b(real_b), real_label)  # Real photos should be classified as real.
        loss_fake = loss_fn_gan(disc_b(fake_b.detach()), fake_label) # Fake photos should be classified as fake. 
        loss_d_b = (loss_real + loss_fake) / 2
        loss_d_b.backward()
        optim_d_b.step()

        # Total loss 
        loss_d = (loss_d_a + loss_d_b) / 2 

        d_epoch_loss += loss_d.item()
        g_epoch_loss += loss_g.item()

    g_epoch_loss /= len(loader)
    d_epoch_loss /= len(loader)
    print(f"EPOCH: {epoch}\nGenerator loss: {g_epoch_loss:.3f}\nDiscriminator loss: {d_epoch_loss:.3f}\n")
    g_loss_list.append(g_epoch_loss)
    d_loss_list.append(d_epoch_loss)

    # Generating some images to show later. 
    if epoch % 25 == 0 or epoch == 1:
        sample_imgs()

In [None]:
#  Plotting the loss and some generated images.
plt.plot(g_loss_list, label="Generator loss")
plt.plot(d_loss_list, label="Discriminator loss")
plt.legend()
plt.show()
sample_imgs()