# Importing libraries

In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader

from utils import loss_holder, print_image, print_train_image, print_test_image
from modeling import ResNetUNetGenerator, Discriminator
from dataset import Gray_colored_dataset


# Initializing dataset

In [None]:
# Both pipelines resize the shorter side to 256 and center-crop to 256x256.
# The model-input pipeline additionally converts the image to one grayscale channel.
IMAGE_SIZE = 256
_shared_geometry = [
    transforms.Resize(IMAGE_SIZE),
    transforms.CenterCrop(IMAGE_SIZE),
]

transform_to_input_image = transforms.Compose(
    [*_shared_geometry, transforms.Grayscale(num_output_channels=1), transforms.ToTensor()]
)
transform_to_target_image = transforms.Compose([*_shared_geometry, transforms.ToTensor()])

dataset_path = '../input/flickr30k/images'
# NOTE(review): the argument order (target transform, then input transform) is kept
# from the original; confirm it against Gray_colored_dataset's signature.
dataset = Gray_colored_dataset(dataset_path, transform_to_target_image, transform_to_input_image)
dataloader = DataLoader(dataset, batch_size=16, shuffle=True)

# Sanity check: display the first model input from one shuffled batch.
first_batch = next(iter(dataloader))
inputs, labels = first_batch
print('Input Image')
print_image(inputs[0])

# Initializing Generator and Discriminator
***
The Generator is a U-Net-style network with ResNet layers (`ResNetUNetGenerator`).

The Discriminator is a standard convolutional classifier whose constructor takes the input image size (256 here), so it can be adapted to other image sizes.

We train the model with:
- BCELoss on the Discriminator's scores (the adversarial loss, weighted by 1e-2 in the Generator's objective)
- RMSE between the generated image and the ground truth

In [None]:
# `device` is used here and throughout the training loop but was never defined.
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

ndf = 16  # NOTE(review): not referenced anywhere else in this notebook
beta1 = 0.5  # Adam beta1, lowered from PyTorch's default of 0.9

# The generator class imported from `modeling` is ResNetUNetGenerator; the previous
# `ResNetUNet()` call referenced an undefined name.
# NOTE(review): assumes the constructor needs no arguments, as in the original call.
netG = ResNetUNetGenerator().to(device)
netD = Discriminator(256).to(device)  # 256 matches the crop size of the dataset transforms

criterion = nn.BCELoss().to(device)  # adversarial loss on discriminator outputs
MSE_loss = nn.MSELoss().to(device)   # its square root is used as the RMSE reconstruction loss

real_label = 1.
fake_label = 0.
lr = 0.0001
num_epochs = 15

optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(beta1, 0.999))
optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=(beta1, 0.999))



# Training process
***
Training this GAN follows the standard loop:
1. First, we compute the Discriminator's loss on the ground-truth color images and backpropagate it.
2. Next, we generate colorized images, accumulate the Discriminator's gradients from its loss on those generated images, and update the Discriminator's weights.
3. Finally, we compute the Generator's losses (adversarial + RMSE) and update its weights.

We also freeze the Generator's ResNet layers for the first third of the first epoch, so that early updates do not wreck their well-pretrained weights.

In [None]:
print("Starting Training Loop...")
netG.train()
netD.train()

# Start with the generator's pretrained layers frozen; they are unfrozen
# after the first third of the first epoch (see below).
netG.freeze_parameters()

losses_holder = loss_holder()

# Print running loss means about 10 times per epoch. max(1, ...) prevents a
# ZeroDivisionError in the modulo below when the dataloader has < 10 batches.
log_every = max(1, len(dataloader) // 10)

for epoch in range(num_epochs):

    for i, (inputs, labels) in enumerate(dataloader):

        # Presumably unfreeze_parameters() clears netG.frozen, so this fires only
        # once, a third of the way into the first epoch — TODO confirm.
        if netG.frozen and i > len(dataloader) // 3:
            netG.unfreeze_parameters()

        ## Train with all-real batch
        netD.zero_grad()
        inputs = inputs.to(device)
        labels = labels.to(device)
        b_size = inputs.shape[0]
        label = torch.full((b_size,), real_label, dtype=torch.float, device=device)
        # Forward pass real batch through D
        output = netD(labels).view(-1)
        # Calculate loss on all-real batch
        errD_real = criterion(output, label)
        # Calculate gradients for D in backward pass
        errD_real.backward()

        ## Train with all-fake batch
        fake = netG(inputs)
        label.fill_(fake_label)
        # detach(): the discriminator update must not backpropagate into G
        output = netD(fake.detach()).view(-1)
        errD_fake = criterion(output, label)
        errD_fake.backward()
        # Update D using the gradients accumulated from both the real and fake batches
        optimizerD.step()

        ## Update G
        netG.zero_grad()
        label.fill_(real_label)  # fake labels are real for generator cost
        output = netD(fake).view(-1)
        # Adversarial loss: how far D is from classifying the fakes as real
        Adv_loss = criterion(output, label)
        # Reconstruction loss: RMSE between the generated and ground-truth images
        RMSE_err = torch.sqrt(MSE_loss(fake, labels))
        # Total generator loss; the adversarial term is down-weighted by 1e-2
        perc_loss = Adv_loss * 1e-2 + RMSE_err
        perc_loss.backward()
        losses_holder.add_batch_(errD_real.item(), errD_fake.item(), perc_loss.item(), Adv_loss.item())

        optimizerG.step()

        # Output training stats
        if i % log_every == 0:
            current_error_means = losses_holder.get_means()
            print('[%d/%d][%d%%]\tLoss_D_real: %.4f\tLoss_D_fake: %.4f\tLoss_G: %.4f\tAdversarial loss: %.4f\t'
                  % (epoch, num_epochs, (i * 100 / len(dataloader)),
                     current_error_means['D_real_loss'], current_error_means['D_fake_loss'],
                     current_error_means['G_loss'], current_error_means['Adv_loss']))
            losses_holder.clear_values_()
            
    
    

# Check performance on training data

In [None]:
# Draw one fresh shuffled batch and visualize the generator on its first pair
# (model input, ground-truth target).
batch_iterator = iter(dataloader)
inp, lab = next(batch_iterator)
sample_input, sample_target = inp[0], lab[0]
print_train_image(netG, sample_input, sample_target)

# Check performance on old photos

In [None]:
# Load the test photos with the same grayscale input transform used during training.
gray_test_dataset = ImageFolder('../input/test-images', transform=transform_to_input_image)
# NOTE(review): `print_images_from_dataset` is not imported or defined in this notebook
# (the `utils` import provides `print_test_image` instead), so this line raises a
# NameError as written. Confirm the intended helper and import it.
print_images_from_dataset(netG, gray_test_dataset)

# Save models and optimizers

In [None]:
# Save each network together with its optimizer state, one checkpoint file per network.
_checkpoints = (
    (netG, optimizerG, './generator.pth'),
    (netD, optimizerD, './discriminator.pth'),
)
for _net, _opt, _ckpt_path in _checkpoints:
    _state = {
        'model_state_dict': _net.state_dict(),
        'optimizer_state_dict': _opt.state_dict(),
    }
    torch.save(_state, _ckpt_path)