In [None]:
import os
import math
import glob
import random
import itertools
import numpy as np
from PIL import Image
from skimage import io

import torch
import torch.nn as nn
import torch.optim as optim

import torch.nn.functional as F

import torch.utils.data
from torchvision.utils import save_image
import torchvision
from torch.autograd import Variable
import torchvision.datasets as datasets
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms

# Run on the first CUDA GPU when one is visible, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [None]:
from IPython.display import clear_output
import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
class MyDataset(Dataset):
    """Unpaired two-domain image dataset.

    Collects every ``*.jpg`` file directly inside ``dir_A`` and ``dir_B``,
    sorts each list, and pairs them by index, so item ``i`` is
    ``(A[i], B[i])``. The length is the size of the smaller domain; surplus
    files of the larger domain are never returned.

    Args:
        dir_A: Directory containing the domain-A images.
        dir_B: Directory containing the domain-B images.
        transform: Optional callable applied independently to each RGB PIL
            image (e.g. a ``torchvision.transforms.Compose``).

    Returns (``__getitem__``):
        ``(image_A, image_B)``: transformed images, or RGB PIL images when no
        transform is given.
    """

    def __init__(self, dir_A, dir_B, transform=None):
        # NOTE: despite their names, ``self.dir_A`` / ``self.dir_B`` hold lists
        # of file paths, not directories; the names are kept for compatibility.
        # Sorting makes the A/B pairing reproducible, because glob order
        # depends on the filesystem.
        self.dir_A = sorted(glob.glob(os.path.join(dir_A, '*.jpg')))
        self.dir_B = sorted(glob.glob(os.path.join(dir_B, '*.jpg')))
        self.transform = transform

    def __len__(self):
        return min(len(self.dir_A), len(self.dir_B))

    @staticmethod
    def _load_rgb(path):
        """Open ``path`` and return a fully loaded 3-channel RGB PIL image."""
        # Grayscale/CMYK JPEGs would otherwise become 1- or 4-channel tensors,
        # which the 3-channel networks cannot consume. ``convert`` returns a
        # loaded copy, so the file handle can be closed by the ``with`` block.
        with Image.open(path) as img:
            return img.convert('RGB')

    def __getitem__(self, idx):
        image_A = self._load_rgb(self.dir_A[idx])
        image_B = self._load_rgb(self.dir_B[idx])

        if self.transform:
            image_A = self.transform(image_A)
            image_B = self.transform(image_B)

        return image_A, image_B

### Generator

In [None]:
class ResBlock(nn.Module):
    """Residual block used in the generator bottleneck.

    Computes ``shortcut(x) + R(x)``, where ``R`` is
    reflect-pad(1) -> 3x3 conv -> InstanceNorm -> ReLU ->
    reflect-pad(1) -> 3x3 conv -> InstanceNorm.
    Height and width are preserved; no activation follows the sum.

    Args:
        in_channels: Number of channels of the input feature map.
        out_channels: Number of channels of the output feature map.

    When ``in_channels == out_channels`` the shortcut is the identity. It adds
    no parameters, so the state_dict keys are those of the plain residual
    block. Otherwise a 1x1 convolution projects the input to
    ``out_channels``, so the residual sum is always well-defined.
    """

    def __init__(self, in_channels=128, out_channels=128):
        super(ResBlock, self).__init__()

        self.padd_1 = nn.ReflectionPad2d(1)
        self.conv_1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=0, bias=True)
        # NOTE(review): with track_running_stats=True, InstanceNorm normalizes
        # with running averages in eval() mode instead of per-image statistics.
        self.norm_1 = nn.InstanceNorm2d(out_channels, affine=True, track_running_stats=True)

        self.padd_2 = nn.ReflectionPad2d(1)
        # conv_2 consumes conv_1's output, so its input width is out_channels.
        self.conv_2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=0, bias=True)
        self.norm_2 = nn.InstanceNorm2d(out_channels, affine=True, track_running_stats=True)

        if in_channels == out_channels:
            self.shortcut = nn.Identity()
        else:
            # Project the input so it can be added to the out_channels residual.
            self.shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=False)

    def forward(self, x):
        x_res = self.padd_1(x)
        x_res = self.conv_1(x_res)
        x_res = self.norm_1(x_res)
        x_res = F.relu(x_res)

        x_res = self.padd_2(x_res)
        x_res = self.conv_2(x_res)
        x_res = self.norm_2(x_res)

        return self.shortcut(x) + x_res
class Generator(nn.Module):
    """ResNet-style image-to-image generator: c7s1-32, d64, d128, R128 x N, u64, u32, c7s1-3.

    Args:
        num_bottlenecks: Number of ``ResBlock(128, 128)`` modules placed
            between the downsampling and upsampling stages.

    The two stride-2 convolutions each halve height and width, and the two
    transposed convolutions each double them. For inputs whose height and
    width are divisible by 4, the output therefore has the input's size,
    e.g. ``(batch, 3, 128, 128) -> (batch, 3, 128, 128)``. The final
    ``tanh`` keeps outputs in ``[-1, 1]``.
    """

    def __init__(self, num_bottlenecks=6):
        super(Generator, self).__init__()

        def inorm(channels):
            # Every normalization layer in this network uses these settings.
            return nn.InstanceNorm2d(channels, affine=True, track_running_stats=True)

        # --- Encoder ---
        # c7s1-32: 7x7 conv; 3-pixel reflection padding keeps the spatial size.
        self.padd_1 = nn.ReflectionPad2d(3)
        self.conv_1 = nn.Conv2d(3, 32, kernel_size=7, stride=1, padding=0, bias=False)
        self.norm_1 = inorm(32)

        # d64, d128: stride-2 3x3 convs, each halving height and width.
        self.conv_2 = nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1, bias=True)
        self.norm_2 = inorm(64)
        self.conv_3 = nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1, bias=True)
        self.norm_3 = inorm(128)

        # --- Bottleneck: residual blocks at 128 channels ---
        self.Bottleneck = nn.Sequential(*(ResBlock(128, 128) for _ in range(num_bottlenecks)))

        # --- Decoder ---
        # u64, u32: stride-2 transposed convs, each doubling height and width.
        self.conv_4 = nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2,
                                         padding=1, output_padding=1, bias=True)
        self.norm_4 = inorm(64)
        self.conv_5 = nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2,
                                         padding=1, output_padding=1, bias=True)
        self.norm_5 = inorm(32)

        # c7s1-3: back to 3 channels.
        self.padd_6 = nn.ReflectionPad2d(3)
        self.conv_6 = nn.Conv2d(32, 3, kernel_size=7, stride=1, padding=0, bias=True)
        # NOTE(review): norm_6 normalizes each output image's channels before
        # tanh, pinning their per-image mean/std to the learned affine values;
        # confirm this is intended.
        self.norm_6 = inorm(3)

    def forward(self, x):
        """Translate ``x`` of shape (batch, 3, H, W) into an image batch of the same layout."""
        h = F.relu(self.norm_1(self.conv_1(self.padd_1(x))))

        for conv, norm in ((self.conv_2, self.norm_2), (self.conv_3, self.norm_3)):
            h = F.relu(norm(conv(h)))

        h = self.Bottleneck(h)

        for conv, norm in ((self.conv_4, self.norm_4), (self.conv_5, self.norm_5)):
            h = F.relu(norm(conv(h)))

        return torch.tanh(self.norm_6(self.conv_6(self.padd_6(h))))

### Discriminator

In [None]:
class Discriminator(nn.Module):
    """Fully convolutional discriminator that outputs a spatial map of scores (PatchGAN-style).

    Four 4x4 convolutions (64, 128, 256, 512 channels; the first three with
    stride 2), each followed by LeakyReLU(0.2) and, except the first, by
    InstanceNorm. A final 4x4 convolution reduces to one channel and a
    sigmoid squashes it into [0, 1].

    For a ``(batch, 3, 128, 128)`` input the output is ``(batch, 1, 14, 14)``;
    other input sizes give other grid sizes.
    """

    def __init__(self):
        super(Discriminator, self).__init__()

        def inorm(channels):
            return nn.InstanceNorm2d(channels, affine=True, track_running_stats=True)

        # 3 -> 64, stride 2; the first layer is not normalized.
        self.conv_1 = nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1, bias=True)

        # 64 -> 128 -> 256, stride 2 each.
        self.conv_2 = nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1, bias=True)
        self.norm_2 = inorm(128)
        self.conv_3 = nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1, bias=True)
        self.norm_3 = inorm(256)

        # 256 -> 512 at stride 1 (shrinks each side by one pixel).
        self.conv_4 = nn.Conv2d(256, 512, kernel_size=4, stride=1, padding=1, bias=True)
        self.norm_4 = inorm(512)

        # 512 -> 1-channel score map (again one pixel smaller per side).
        self.conv_output = nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=1, bias=False)

    def forward(self, x):
        """Score ``x`` (batch, 3, H, W); returns (batch, 1, H', W'), e.g. 14x14 for 128x128 inputs."""
        h = F.leaky_relu(self.conv_1(x), 0.2)

        for conv, norm in ((self.conv_2, self.norm_2),
                           (self.conv_3, self.norm_3),
                           (self.conv_4, self.norm_4)):
            h = F.leaky_relu(norm(conv(h)), 0.2)

        return torch.sigmoid(self.conv_output(h))

### Models, losses and optimizers

In [None]:
# generator_A2B translates domain-A images into domain B, generator_B2A the
# reverse; both use six residual blocks.
generator_A2B = Generator(num_bottlenecks=6).to(device)
generator_B2A = Generator(num_bottlenecks=6).to(device)


In [None]:
# discriminator_A scores domain-A images, discriminator_B scores domain-B images.
discriminator_A, discriminator_B = Discriminator().to(device), Discriminator().to(device)

In [None]:
# Adversarial loss (mean squared error against real/fake targets) and
# cycle-consistency loss (mean absolute error between an image and its reconstruction).
criterion_GAN, criterion_cycle = torch.nn.MSELoss(), torch.nn.L1Loss()


In [None]:
# Identity loss: mean absolute error when a generator is fed an image already in its target domain.
criterion_identity = torch.nn.L1Loss(reduction='mean')

In [None]:
# One Adam optimizer updates both generators jointly; each discriminator has
# its own. All three use lr=2e-4 and betas=(0.5, 0.999).
optimizer_G = torch.optim.Adam(
    [*generator_A2B.parameters(), *generator_B2A.parameters()],
    lr=2e-4,
    betas=(0.5, 0.999),
)
optimizer_D_A = torch.optim.Adam(discriminator_A.parameters(),
                                 lr=2e-4, betas=(0.5, 0.999))
optimizer_D_B = torch.optim.Adam(discriminator_B.parameters(),
                                 lr=2e-4, betas=(0.5, 0.999))

In [None]:
batch_size = 4
image_size = 128  # height = width of the images fed to the networks

# Resize scales the shorter side to image_size; CenterCrop then makes every
# sample exactly image_size x image_size, which default batch collation and
# the fixed-size buffers below need. For square source images the crop is a no-op.
dataset = MyDataset('data/trainA', 'data/trainB', transform=transforms.Compose([
    transforms.Resize(image_size),
    transforms.CenterCrop(image_size),
    transforms.ToTensor()]))
dataloader = DataLoader(dataset, batch_size=batch_size,
                        shuffle=True, num_workers=4)

# Choose the tensor type from `device` instead of hard-coding
# torch.cuda.FloatTensor: constructing a CUDA tensor fails on CPU-only
# machines even though `device` falls back to the CPU.
Tensor = torch.cuda.FloatTensor if device.type == 'cuda' else torch.FloatTensor
input_A = Tensor(batch_size, 3, image_size, image_size)
input_B = Tensor(batch_size, 3, image_size, image_size)

In [None]:
# Shorter side -> 128, then a 128x128 center crop (a no-op for square images).
# The generator returns an output of the input's size only when height and
# width are divisible by 4, and the side-by-side grids need matching shapes.
test_dataset = MyDataset('data/testA', 'data/testB', transform=transforms.Compose([
    transforms.Resize(128),
    transforms.CenterCrop(128),
    transforms.ToTensor()]))
test_dataloader = DataLoader(test_dataset, batch_size=1,
                             shuffle=True, num_workers=4)

In [None]:
# Epoch counter read and advanced by the training-loop cell, which runs while it
# is below 200; re-running this cell resets the count. It also appears in the
# names of the saved images and checkpoints.
epoch = 0

### Training
Input A contains summer images; input B contains winter images.

In [None]:
# CycleGAN training. The loop runs while `epoch` < num_epochs, so re-running
# this cell continues from the current epoch instead of starting over.
num_epochs = 200
output_dir = 'test_outputs/model'
image_dir = os.path.join(output_dir, 'images')

# Loss weights.
lambda_identity = 5.0
lambda_cycle = 10.0
# NOTE(review): discriminator losses are scaled by 0.005 (0.5 is the common
# CycleGAN choice); with Adam a constant loss scale largely cancels — confirm intent.
disc_loss_scale = 0.005

# save_image and torch.save do not create missing directories.
os.makedirs(image_dir, exist_ok=True)

# The per-epoch image/log code below uses the last batch's tensors.
if len(dataloader) == 0:
    raise RuntimeError('The training dataloader is empty; check data/trainA and data/trainB.')

while epoch < num_epochs:
    # Train every network in train mode, even if another cell left one in eval().
    for net in (generator_A2B, generator_B2A, discriminator_A, discriminator_B):
        net.train()

    for batch in dataloader:
        # Plain device transfer replaces the CUDA-only Tensor/Variable copies.
        real_A = batch[0].to(device)
        real_B = batch[1].to(device)

        ### Generators ###
        optimizer_G.zero_grad()

        # Identity loss: G_A2B(B) should equal B and G_B2A(A) should equal A.
        same_B = generator_A2B(real_B)
        loss_identity_B = criterion_identity(same_B, real_B) * lambda_identity
        same_A = generator_B2A(real_A)
        loss_identity_A = criterion_identity(same_A, real_A) * lambda_identity

        # Adversarial loss: each generator wants the other domain's discriminator
        # to score its fakes as real (1). Targets follow the discriminator's
        # output shape instead of a hard-coded 14x14 grid.
        fake_B = generator_A2B(real_A)
        discriminator_out_A = discriminator_B(fake_B)
        gen_loss_A = criterion_GAN(discriminator_out_A, torch.ones_like(discriminator_out_A))

        fake_A = generator_B2A(real_B)
        discriminator_out_B = discriminator_A(fake_A)
        gen_loss_B = criterion_GAN(discriminator_out_B, torch.ones_like(discriminator_out_B))

        # Cycle-consistency loss: A -> B -> A and B -> A -> B should reconstruct the input.
        recovered_A = generator_B2A(fake_B)
        cycle_loss_ABA = criterion_cycle(recovered_A, real_A) * lambda_cycle
        recovered_B = generator_A2B(fake_A)
        cycle_loss_BAB = criterion_cycle(recovered_B, real_B) * lambda_cycle

        generator_loss = (loss_identity_B + loss_identity_A + gen_loss_A + gen_loss_B
                          + cycle_loss_ABA + cycle_loss_BAB)
        generator_loss.backward()
        optimizer_G.step()

        ### Discriminator A: real A -> 1, generated A -> 0 ###
        # zero_grad also discards the gradients generator_loss.backward() left here.
        optimizer_D_A.zero_grad()
        pred_real_A = discriminator_A(real_A)
        disc_loss_real_A = criterion_GAN(pred_real_A, torch.ones_like(pred_real_A))
        # detach(): the discriminator update must not backpropagate into the generator.
        pred_fake_A = discriminator_A(fake_A.detach())
        disc_loss_fake_A = criterion_GAN(pred_fake_A, torch.zeros_like(pred_fake_A))
        disc_loss_A = (disc_loss_real_A + disc_loss_fake_A) * disc_loss_scale
        disc_loss_A.backward()
        optimizer_D_A.step()

        ### Discriminator B: real B -> 1, generated B -> 0 ###
        optimizer_D_B.zero_grad()
        pred_real_B = discriminator_B(real_B)
        disc_loss_real_B = criterion_GAN(pred_real_B, torch.ones_like(pred_real_B))
        pred_fake_B = discriminator_B(fake_B.detach())
        disc_loss_fake_B = criterion_GAN(pred_fake_B, torch.zeros_like(pred_fake_B))
        disc_loss_B = (disc_loss_real_B + disc_loss_fake_B) * disc_loss_scale
        disc_loss_B.backward()
        optimizer_D_B.step()

    # Last training batch: real B images followed by their B -> A translations.
    # The path is passed positionally because newer torchvision renamed
    # save_image's `filename` keyword to `fp`.
    img = torchvision.utils.make_grid(torch.cat([real_B, fake_A.detach()], dim=0), nrow=4)
    save_image(img, os.path.join(image_dir, str(epoch) + '.jpg'))

    # Translate one test image B -> A. generator_B2A is the network used, so it
    # is the one switched to eval() and back to train() afterwards; no_grad
    # skips building an autograd graph.
    generator_B2A.eval()
    try:
        with torch.no_grad():
            for test in test_dataloader:
                test_real_B = test[1].to(device)
                test_fake_A = generator_B2A(test_real_B)
                img = torchvision.utils.make_grid([test_real_B.squeeze(0), test_fake_A.squeeze(0)], nrow=2)
                save_image(img, os.path.join(image_dir, 'test_' + str(epoch) + '.jpg'))
                break
    finally:
        generator_B2A.train()

    if epoch % 5 == 0 and epoch != 0:
        torch.save(generator_B2A.state_dict(), os.path.join(output_dir, 'B2A_gen_' + str(epoch) + '.pth'))
        torch.save(generator_A2B.state_dict(), os.path.join(output_dir, 'A2B_gen_' + str(epoch) + '.pth'))

    # Losses of the epoch's last batch (argument 2 is discriminator B).
    print("The generator loss {0:2f}, discriminator A loss {1:2f}, discriminator B loss {2:2f}".format(
        generator_loss.item(), disc_loss_A.item(), disc_loss_B.item()))
    epoch += 1
    

In [None]:
# Translate one standalone image with a trained generator and save the input
# and output side by side.
# TODO(review): pick the direction intended for this image; generator_A2B maps
# domain A (summer) to domain B (winter), generator_B2A the reverse.
inference_generator = generator_A2B

# Shorter side -> 128, then a 128x128 center crop: the generator returns an
# output of the input's size only when height and width are divisible by 4,
# and make_grid needs both images to share a shape.
transform_list = [transforms.Resize(128),
                  transforms.CenterCrop(128),
                  transforms.ToTensor()
                  ]

transform = transforms.Compose(transform_list)

# convert('RGB') guarantees the 3 channels the generator expects; the `with`
# block closes the file handle.
with Image.open('face_dataset/celeba/devushka_test.jpg') as source_image:
    landmarkk_test = transform(source_image.convert('RGB'))
landmarkk_test = landmarkk_test.to(device)

# nn.Module.eval() only switches the mode (it takes no input); the forward
# pass is a separate call. The previous mode is restored afterwards.
was_training = inference_generator.training
inference_generator.eval()
try:
    with torch.no_grad():
        fake_test = inference_generator(landmarkk_test.unsqueeze(0))
finally:
    inference_generator.train(was_training)
print(fake_test.shape)
print(landmarkk_test.shape)

# The input already has 3 channels, so it is placed next to the output as-is.
test_img = torchvision.utils.make_grid([landmarkk_test.cpu(), fake_test.squeeze(0).cpu()], nrow=2)
# Path passed positionally: newer torchvision renamed save_image's `filename` to `fp`.
save_image(test_img, 'devushka_fake_test' + str(epoch) + '.jpg')

In [None]:
# Test set: shorter side -> 128, then a 128x128 center crop so every sample has
# a fixed size (a no-op for square images).
test_dataset = MyDataset('data/testA', 'data/testB', transform=transforms.Compose([
    transforms.Resize(128),
    transforms.CenterCrop(128),
    transforms.ToTensor()]))
test_dataloader = DataLoader(test_dataset, batch_size=1,
                             shuffle=True, num_workers=4)

# save_image does not create missing directories.
os.makedirs('test_outputs/model/images', exist_ok=True)

# Translate one randomly drawn test image from domain B to domain A.
# generator_B2A is the network used, so it is the one put into eval(); its
# previous mode is restored afterwards. no_grad skips graph construction.
was_training = generator_B2A.training
generator_B2A.eval()
try:
    with torch.no_grad():
        for test in test_dataloader:
            real_B = test[1].to(device)
            fake_A = generator_B2A(real_B)
            img = torchvision.utils.make_grid([real_B.squeeze(0), fake_A.squeeze(0)], nrow=2)
            # Path passed positionally: newer torchvision renamed `filename` to `fp`.
            save_image(img, 'test_outputs/model/images/test' + '_' + str(epoch) + '.jpg')
            break
finally:
    generator_B2A.train(was_training)