# Importing libraries

In [None]:
import wandb
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning import LightningDataModule, LightningModule, Trainer
import pytorch_lightning as pl

import os
import cv2
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision.datasets import ImageFolder
from torch.utils.data import Dataset, DataLoader
import albumentations as albu
from albumentations.pytorch import ToTensorV2

from utils import print_image, print_train_image, print_test_image
from modeling import ResNetUNetGenerator, Discriminator
from dataset import Gray_colored_dataset

# Use the first CUDA device when one is available; otherwise fall back to CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# torch.cuda.get_device_name() raises when no CUDA device exists, so only
# query it on the GPU path. Otherwise this cell crashes on CPU-only kernels.
if device.type == "cuda":
    print("device name", torch.cuda.get_device_name(device))
else:
    print("device name", "cpu")

# Initializing dataset

In [None]:
# Training augmentations, applied in order:
# - scale the shorter side to 256 px, then take a random 256x256 crop;
# - apply light horizontal/vertical flips;
# - normalize with mean = std = 0.5 (maps presumably-uint8 pixels to [-1, 1]);
# - convert to a tensor.
# `additional_targets` routes the 'grayscale_image' input through the same
# pipeline as the main 'image', so both stay aligned.
transforms = albu.Compose(
    [
        albu.SmallestMaxSize(256),
        albu.RandomCrop(256, 256),
        albu.HorizontalFlip(p=0.2),
        albu.VerticalFlip(p=0.2),
        albu.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
        ToTensorV2(),
    ],
    additional_targets={'grayscale_image': 'image'},
)

dataset_path = '../input/flickr30k/images'
dataset = Gray_colored_dataset(dataset_path, transforms)
dataloader = DataLoader(dataset, batch_size=16, shuffle=True)

# Peek at a single shuffled batch and display its first input image.
first_batch = next(iter(dataloader))
inputs, labels = first_batch
print('Input Image')
print_image(inputs[0])

# Initializing Generator and Discriminator
***
The Generator is a U-Net with a pretrained ResNet encoder (`ResNetUNetGenerator`) and a few tweaks.

The Discriminator is a standard convolutional classifier that adapts to any input image size.

We train the model with:
- binary cross-entropy (BCE) loss on the Discriminator's scores
- the mean of MAE and RMSE between the generated image and the ground truth

In [None]:
# Build the LightningModule that wraps the generator and the discriminator.
# Later cells pass it to `trainer.fit` and read `model.generator` /
# `model.discriminator`.
# NOTE(review): `Colorization` is neither imported nor defined anywhere in this
# notebook, so this cell raises NameError on a fresh kernel. It presumably
# lives in `modeling` next to ResNetUNetGenerator/Discriminator. TODO: add that
# import (or restore the class-definition cell) before this line.
model = Colorization()

# Training process
***
Training this GAN is fairly simple:
1. First, we compute the Discriminator's loss on the ground-truth image and backpropagate it to accumulate the Discriminator's gradients.
2. Next, we generate a colorized image and compute the Discriminator's loss on it. We accumulate those gradients as well and then update the Discriminator's weights.
3. Finally, we compute all of the Generator's losses and update its weights.

We also freeze the Generator's ResNet layers for the first third of the first epoch so that the pretrained weights are not wrecked early in training.

In [None]:
# W&B experiment tracking. WandbLogger is imported in the first cell but was
# never instantiated, so `wandb_logger` was undefined on a fresh kernel.
# NOTE(review): the project name is a placeholder. Set it to your W&B project.
wandb_logger = WandbLogger(project="gan-colorization")

trainer = Trainer(
    logger=wandb_logger,    # W&B integration
    accelerator="auto",     # GPU when available, CPU otherwise
    devices="auto",         # all visible GPUs (what gpus=-1 meant); 1 process on CPU
    max_epochs=15,          # number of epochs
)
# `gpus=-1` was replaced above: it errors on machines without a GPU and the
# argument was removed in Lightning 2.0.

trainer.fit(model)

# Check performance on old photos

In [None]:
# Build generator inputs from a folder of old photos. The preprocessing mirrors
# the training pipeline:
# - resize the shorter side to 256 and take a 256x256 crop (center rather than
#   random, for deterministic evaluation);
# - replicate the gray channel to 3 channels;
# - normalize with mean = std = 0.5 (as `albu.Normalize` does in training), so
#   inputs land in [-1, 1]. Without this step the generator would receive
#   [0, 1] inputs it was never trained on.
transform_to_input_image = torchvision.transforms.Compose([
    torchvision.transforms.Resize(256),
    torchvision.transforms.CenterCrop(256),
    torchvision.transforms.Grayscale(num_output_channels=3),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

# ImageFolder expects images inside at least one class sub-directory of the root.
gray_test_dataset = ImageFolder('../input/test-images', transform=transform_to_input_image)

# Switch layers such as dropout/batch-norm to inference behavior before generating.
model.eval()
# NOTE(review): `print_images_from_dataset` is not imported or defined in this
# notebook (utils only provides print_image/print_train_image/print_test_image),
# so this raises NameError on a fresh kernel. TODO: import or define it. Also
# confirm it de-normalizes from [-1, 1] when displaying.
print_images_from_dataset(model, gray_test_dataset)

# Save generator and discriminator weights

In [None]:
# Persist the generator and discriminator weights as separate checkpoints.
# Each file holds a dict with a single 'model_state_dict' entry.
for submodule, checkpoint_path in (
    (model.generator, './generator.pth'),
    (model.discriminator, './discriminator.pth'),
):
    torch.save({'model_state_dict': submodule.state_dict()}, checkpoint_path)