In [1]:
%reload_ext autoreload
%autoreload 2
from tqdm.notebook import tqdm
import os
import numpy as np
import glob
import PIL.Image as Image
from tqdm.notebook import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import matplotlib.pyplot as plt

# Choose the compute device once; later cells move data batches onto `device`.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if device.type == 'cuda':
    print("The code will run on GPU.")
else:
    print("The code will run on CPU. Go to Edit->Notebook Settings and choose GPU as the hardware accelerator")

The code will run on GPU.


# Load dataset

In [12]:
class MyDataset(torch.utils.data.Dataset):
    """Two-domain image dataset (domain A and domain B) for CycleGAN-style training.

    Domain A images are the ``*.jpg`` files in ``{data_path}/{split}A`` and domain B
    images those in ``{data_path}/{split}B``, for every ``split`` in ``splits``
    (in that order, sorted by path within each split).

    The domains do not need to contain the same number of images: ``len()`` is the
    size of the larger domain and indices wrap around (modulo) on the smaller one,
    so every image of both domains is visited once per pass.

    Args:
        transform: callable applied to each loaded RGB PIL image
            (e.g. ``transforms.ToTensor()``).
        data_path: dataset root directory.
        splits: split-name prefixes of the sub-directories to include.

    Raises:
        FileNotFoundError: if either domain contains no ``*.jpg`` files.
    """

    def __init__(self, transform, data_path='horse2zebra', splits=('test', 'train')):
        self.transform = transform
        self.A_paths = self._collect_paths(data_path, splits, 'A')
        self.B_paths = self._collect_paths(data_path, splits, 'B')
        # Fail loudly on a wrong data_path instead of producing an empty dataset,
        # which would make the training loop silently run zero iterations.
        if not self.A_paths or not self.B_paths:
            missing = 'A' if not self.A_paths else 'B'
            raise FileNotFoundError(
                'no *.jpg images found for domain %s under %r' % (missing, data_path))

    @staticmethod
    def _collect_paths(data_path, splits, domain):
        """Return the *.jpg paths of one domain, split by split, sorted within each split."""
        paths = []
        for split in splits:
            # glob order is filesystem-dependent; sort so indices are reproducible.
            paths.extend(sorted(glob.glob(os.path.join(data_path, split + domain, '*.jpg'))))
        return paths

    def __len__(self):
        # Cover every image of the larger domain; the smaller one wraps around.
        return max(len(self.A_paths), len(self.B_paths))

    def _pair_paths(self, idx):
        """Map a dataset index to an (A path, B path) pair, wrapping the smaller domain."""
        return (self.A_paths[idx % len(self.A_paths)],
                self.B_paths[idx % len(self.B_paths)])

    @staticmethod
    def _load_rgb(path):
        """Load an image as RGB and close the underlying file handle."""
        with Image.open(path) as img:
            return img.convert('RGB')

    def __getitem__(self, idx):
        path_A, path_B = self._pair_paths(idx)
        return self.transform(self._load_rgb(path_A)), self.transform(self._load_rgb(path_B))

# Convert each PIL image to a float tensor in [0, 1].
transform = transforms.ToTensor()

dataset = MyDataset(transform)

batch_size = 1
# Shuffle so each training epoch visits the images in a new random order;
# with shuffle=False every epoch iterated the files in exactly the same order.
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

In [13]:
#A, B = next(iter(dataloader))
#
#plt.figure()
#
#fix_axis = lambda im: np.swapaxes(np.swapaxes(im, 0, 2), 0, 1)
#
#plt.subplot(1,2,1)
#plt.imshow(fix_axis(A))
#plt.axis('off')
#plt.subplot(1,2,2)
#plt.imshow(fix_axis(B))
#plt.axis('off')
#plt.show()



# Model definition

In [14]:
class ResNetBlock(nn.Module):
    """Residual block computing ``relu(x + conv3x3(relu(conv3x3(x))))``.

    Both 3x3 convolutions use stride 1 and padding 1, so the spatial size and the
    channel count (``n_features``) are preserved and the skip connection is a plain add.

    Args:
        n_features: number of input and output channels.
    """

    def __init__(self, n_features=32):
        super().__init__()
        conv_kwargs = dict(kernel_size=3, padding=1, stride=1)
        # Attribute name is part of the state_dict keys ("convolutional.*"); keep it.
        self.convolutional = nn.Sequential(
            nn.Conv2d(n_features, n_features, **conv_kwargs),
            nn.ReLU(),
            nn.Conv2d(n_features, n_features, **conv_kwargs),
        )

    def forward(self, x):
        residual = self.convolutional(x)
        return F.relu(residual + x)

class Generator(nn.Module):
    """ResNet-style image-to-image generator: encoder -> residual blocks -> decoder.

    Maps a (N, 3, H, W) image to a (N, 3, H, W) image when H and W are divisible
    by 4 (two stride-2 downsamplings, mirrored by two stride-2 upsamplings),
    e.g. 256x256 -> 256x256.

    Args:
        n_resblocks: number of ``ResNetBlock`` modules (256 channels each) in the
            bottleneck.

    The output is squashed to (0, 1) by a sigmoid. This matches the [0, 1] range of
    the real images produced by ``transforms.ToTensor()``, which the identity and
    cycle L1 losses compare against. A final ReLU would be unbounded above and would
    pass no gradient for negative pre-activations.
    """

    def __init__(self, n_resblocks=5):
        super().__init__()

        self.encoding = nn.Sequential(
            nn.Conv2d(3, 64, 7, stride=1, padding=3),   # 256 -> 256
            nn.InstanceNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 128, 3, stride=2, padding=1),  # 256 -> 128
            nn.InstanceNorm2d(128),
            nn.ReLU(),
            nn.Conv2d(128, 256, 3, stride=2, padding=1),  # 128 -> 64
            nn.InstanceNorm2d(256),
            nn.ReLU()
        )

        self.resnet = nn.Sequential(
            *[ResNetBlock(n_features=256) for _ in range(n_resblocks)]
        )

        self.decoding = nn.Sequential(
            nn.ConvTranspose2d(256, 128, 3, stride=2, padding=1, output_padding=1),  # 64 -> 128
            nn.InstanceNorm2d(128),
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, 3, stride=2, padding=1, output_padding=1),  # 128 -> 256
            nn.InstanceNorm2d(64),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 3, 7, stride=1, padding=3),  # 256 -> 256, back to RGB
            # Bounded output in (0, 1) to match ToTensor() images. This layer has no
            # parameters, so state_dict keys are the same as before.
            nn.Sigmoid()
        )

    def forward(self, x):
        x = self.encoding(x)
        x = self.resnet(x)
        x = self.decoding(x)
        return x


class Discriminator(nn.Module):
    """Patch discriminator that outputs a map of raw (unnormalized) realness scores.

    Each output element scores one receptive-field patch of the input. For a
    (N, 3, 256, 256) input the output is (N, 1, 30, 30).

    Args:
        leak: negative slope of the LeakyReLU activations.

    The training loop compares these scores with ``MSELoss`` against 1 (real) and
    0 (fake), a least-squares GAN objective, so no final activation is applied.
    In particular, a softmax over dim 0 would normalize across the batch, making
    every output exactly 0 when the batch size is 1 and leaving no gradient signal.
    """

    def __init__(self, leak=0.2):
        super().__init__()
        # Layer order/indices are kept so state_dict keys (conv.0, conv.3, conv.6,
        # conv.9, conv.12) are unchanged.
        self.conv = nn.Sequential(
            nn.Conv2d(3, 64, 4, stride=2, padding=1),  # 256 -> 128
            nn.InstanceNorm2d(64),
            nn.LeakyReLU(leak),
            nn.Conv2d(64, 128, 4, stride=2, padding=1),  # 128 -> 64
            nn.InstanceNorm2d(128),
            nn.LeakyReLU(leak),
            nn.Conv2d(128, 256, 4, stride=2, padding=1),  # 64 -> 32
            nn.InstanceNorm2d(256),
            nn.LeakyReLU(leak),
            nn.Conv2d(256, 512, 4, stride=1, padding=1),  # 32 -> 31
            nn.InstanceNorm2d(512),
            nn.LeakyReLU(leak),
            nn.Conv2d(512, 1, 4, stride=1, padding=1),  # 31 -> 30, raw scores
        )

    def forward(self, x):
        return self.conv(x)

        
        
        

In [15]:
# Initialize the four CycleGAN networks and place them on `device`
# (GPU when available, otherwise CPU).
A2B = Generator()      # maps domain-A images to domain B
B2A = Generator()      # maps domain-B images to domain A
D_A = Discriminator()  # trained to score real domain-A images as real
D_B = Discriminator()  # trained to score real domain-B images as real

from utils import weights_init_normal

for model in (A2B, B2A, D_A, D_B):
    # `.to(device)` rather than `.cuda()` so the notebook also runs on CPU-only
    # machines, consistent with the device chosen in the setup cell.
    model.to(device)
    model.apply(weights_init_normal)
print(f'loaded all models to {device}')

loaded all models to GPU memory!!


# Training loop

In [None]:
import itertools
from torch.autograd import Variable  # no longer used below; kept in case other cells rely on it
from utils import Logger, ReplayBuffer, LambdaLR

# Legacy tensor-type alias, now chosen per device so it is valid on CPU-only machines too.
Tensor = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.FloatTensor

# EPOCHS !!
epochs = 100
# Epoch passed to utils.LambdaLR as its decay start (semantics defined in utils; confirm there).
decay_after_epochs = 50

# Loss weights relative to the adversarial terms (which have weight 1).
LAMBDA_CYCLE = 10.0
LAMBDA_IDENTITY = 5.0

# Losses: least-squares GAN objective plus L1 cycle-consistency and identity terms.
criterion_GAN = torch.nn.MSELoss()
criterion_cycle = torch.nn.L1Loss()
criterion_identity = torch.nn.L1Loss()

LR = 0.0002
ADAM_BETAS = (0.5, 0.999)

optimizer_G = torch.optim.Adam(itertools.chain(A2B.parameters(), B2A.parameters()), lr=LR, betas=ADAM_BETAS)
optimizer_D_A = torch.optim.Adam(D_A.parameters(), lr=LR, betas=ADAM_BETAS)
optimizer_D_B = torch.optim.Adam(D_B.parameters(), lr=LR, betas=ADAM_BETAS)

# LR schedulers (one utils.LambdaLR helper per optimizer).
lr_scheduler_G = torch.optim.lr_scheduler.LambdaLR(optimizer_G, lr_lambda=LambdaLR(epochs, 0, decay_after_epochs).step)
lr_scheduler_D_A = torch.optim.lr_scheduler.LambdaLR(optimizer_D_A, lr_lambda=LambdaLR(epochs, 0, decay_after_epochs).step)
lr_scheduler_D_B = torch.optim.lr_scheduler.LambdaLR(optimizer_D_B, lr_lambda=LambdaLR(epochs, 0, decay_after_epochs).step)

# Buffers of generated images; discriminators are fed fakes via push_and_pop
# (presumably a history pool of past fakes — see utils.ReplayBuffer).
fake_A_buffer = ReplayBuffer()
fake_B_buffer = ReplayBuffer()


def gan_loss(pred, target_value):
    """LSGAN loss of `pred` against a constant target with pred's own shape.

    Building the target with `full_like` keeps it on the same device/dtype as
    `pred` and avoids broadcasting a (batch,) target against a (N, 1, H, W)
    patch-score map, which breaks for batch sizes > 1.
    """
    return criterion_GAN(pred, torch.full_like(pred, target_value))


def train_discriminator(D, optimizer, real, fake, buffer):
    """Run one optimizer step for discriminator `D`.

    Scores `real` against target 1 and the fake returned by
    `buffer.push_and_pop(fake)` (detached so no gradient reaches the generators)
    against target 0, averaging the two losses.

    Returns:
        (loss, fake_used): the averaged loss and the fake batch actually scored.
    """
    optimizer.zero_grad()
    loss_real = gan_loss(D(real), 1.0)
    fake_used = buffer.push_and_pop(fake)
    loss_fake = gan_loss(D(fake_used.detach()), 0.0)
    loss = (loss_real + loss_fake) * 0.5
    loss.backward()
    optimizer.step()
    return loss, fake_used


# Checkpoints are written below; make sure the directory exists.
os.makedirs('output', exist_ok=True)

# progress logger
logger = Logger(epochs, len(dataloader))

for epoch in range(epochs):
    for i, batch in enumerate(dataloader):
        # Batches are already tensors: move them directly (torch.tensor(tensor)
        # would make an extra copy and emit a warning).
        real_A = batch[0].to(device)
        real_B = batch[1].to(device)

        ###### Generators A2B and B2A ######
        optimizer_G.zero_grad()

        # Identity loss: a generator fed an image already in its output domain
        # should return it unchanged.
        same_B = A2B(real_B)
        loss_identity_B = criterion_identity(same_B, real_B) * LAMBDA_IDENTITY
        same_A = B2A(real_A)
        loss_identity_A = criterion_identity(same_A, real_A) * LAMBDA_IDENTITY

        # GAN loss: generators try to make the discriminators score their fakes as real (1).
        fake_B = A2B(real_A)
        loss_GAN_A2B = gan_loss(D_B(fake_B), 1.0)
        fake_A = B2A(real_B)
        loss_GAN_B2A = gan_loss(D_A(fake_A), 1.0)

        # Cycle loss: translating there and back should recover the input.
        recovered_A = B2A(fake_B)
        loss_cycle_ABA = criterion_cycle(recovered_A, real_A) * LAMBDA_CYCLE
        recovered_B = A2B(fake_A)
        loss_cycle_BAB = criterion_cycle(recovered_B, real_B) * LAMBDA_CYCLE

        # Total generator loss
        loss_G = loss_identity_A + loss_identity_B + loss_GAN_A2B + loss_GAN_B2A + loss_cycle_ABA + loss_cycle_BAB
        loss_G.backward()
        optimizer_G.step()

        ###### Discriminators A and B ######
        loss_D_A, fake_A = train_discriminator(D_A, optimizer_D_A, real_A, fake_A, fake_A_buffer)
        loss_D_B, fake_B = train_discriminator(D_B, optimizer_D_B, real_B, fake_B, fake_B_buffer)

        # Progress report (http://localhost:8097)
        logger.log({'loss_G': loss_G, 'loss_G_identity': (loss_identity_A + loss_identity_B), 'loss_G_GAN': (loss_GAN_A2B + loss_GAN_B2A),
                    'loss_G_cycle': (loss_cycle_ABA + loss_cycle_BAB), 'loss_D': (loss_D_A + loss_D_B)},
                    images={'real_A': real_A, 'real_B': real_B, 'fake_A': fake_A, 'fake_B': fake_B})

    # Update learning rates once per epoch
    lr_scheduler_G.step()
    lr_scheduler_D_A.step()
    lr_scheduler_D_B.step()

    # Save model checkpoints (the discriminators were previously saved via the
    # undefined names `A`/`B`, which crashed at the end of the first epoch).
    torch.save(A2B.state_dict(), 'output/A2B.pth')
    torch.save(B2A.state_dict(), 'output/B2A.pth')
    torch.save(D_A.state_dict(), 'output/D_A.pth')
    torch.save(D_B.state_dict(), 'output/D_B.pth')

    