## Initializing Libraries

In [14]:
import os
import glob
import time
import numpy as np
from PIL import Image
from pathlib import Path
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
from skimage.color import rgb2lab, lab2rgb

import torch
from torch import nn, optim
from torchvision import transforms
from torchvision.utils import make_grid
from torch.utils.data import Dataset, DataLoader
import torchvision.datasets
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
from numpy import cov
from numpy import trace
from numpy import iscomplexobj
from numpy import asarray
from numpy.random import randint
from scipy.linalg import sqrtm
from PIL import Image
from torchvision.utils import save_image


## Declaring Path Variables for Train & Validation

In [15]:
# Dataset location. os.path.join keeps the path portable; the previous literal
# "data\cars_train" relied on the invalid "\c" escape and only resolved on Windows.
path = os.path.join("data", "cars_train")
SIZE = 256  # side length (pixels) every image is resized to by the dataset

# Split configuration. A fixed seed keeps the train/val membership identical
# across kernel restarts, which matters when a saved checkpoint is reloaded.
N_SAMPLES = 8_000
N_TRAIN = 6_000
SEED = 42
rng = np.random.default_rng(SEED)

# Sorted so the sampled subset does not depend on the OS directory listing order.
paths = sorted(glob.glob(os.path.join(path, "*.jpg")))
if len(paths) < N_SAMPLES:
    raise ValueError(
        f"expected at least {N_SAMPLES} .jpg files in {path!r}, found {len(paths)}"
    )

paths_subset = rng.choice(paths, N_SAMPLES, replace=False)
rand_idxs = rng.permutation(N_SAMPLES)
train_index = rand_idxs[:N_TRAIN]
val_index = rand_idxs[N_TRAIN:]
results = paths_subset[rand_idxs]  # the whole subset, in shuffled order
train = paths_subset[train_index]
val = paths_subset[val_index]
print(len(train), len(val), len(results))


# Output folders for generated images and their ground-truth counterparts.
for out_dir in (
    "land_data",
    "coco_data",
    "orig_coco_data",
    "orig_land_data",
    "cars_data",
    "orig_cars_data",
):
    os.makedirs(out_dir, exist_ok=True)

6000 2000 8000


## Loading Data With Augmentation

In [16]:

class ColorizationDataset(Dataset):
    """Dataset that turns image files into normalised Lab tensors for colorization.

    Each item is a dict with two float32 tensors:
        'L'  -- channel 0 of the Lab image, computed as ``L / 50 - 1``
        'ab' -- channels 1 and 2 of the Lab image, computed as ``ab / 110``

    Args:
        paths: indexable collection of image file paths.
        split: 'train' (resize + random vertical flip) or 'val' (resize only).

    Raises:
        ValueError: if ``split`` is neither 'train' nor 'val'.
    """

    def __init__(self, paths, split='train'):
        if split == 'train':
            self.transforms = transforms.Compose([
                transforms.Resize((SIZE, SIZE)),
                # NOTE(review): a vertical flip turns images upside-down; a horizontal
                # flip is the more usual augmentation for photos -- confirm intent.
                transforms.RandomVerticalFlip(),
            ])
        elif split == 'val':
            self.transforms = transforms.Resize((SIZE, SIZE))
        else:
            # Previously an unknown split left self.transforms undefined and only
            # failed later, inside __getitem__, with an AttributeError.
            raise ValueError(f"split must be 'train' or 'val', got {split!r}")

        self.split = split
        self.size = SIZE
        self.paths = paths

    def __getitem__(self, idx):
        """Load image ``idx``, convert it to Lab and return the scaled channels."""
        # The context manager releases the file handle as soon as the pixels are decoded.
        with Image.open(self.paths[idx]) as handle:
            img = handle.convert("RGB")
        img = self.transforms(img)
        img = np.array(img)
        img_lab = rgb2lab(img).astype("float32")
        img_lab = transforms.ToTensor()(img_lab)  # H x W x C array -> C x H x W tensor
        # Scaling targets roughly [-1, 1], assuming L in [0, 100] and |a|, |b| <= 110.
        L = img_lab[[0], ...] / 50. - 1.
        ab = img_lab[[1, 2], ...] / 110.

        return {'L': L, 'ab': ab}

    def __len__(self):
        """Number of image paths in the dataset."""
        return len(self.paths)

def make_dataloaders(batch_size=64, n_workers=0, pin_memory=True, **kwargs):
    """Build a DataLoader over a ColorizationDataset.

    Args:
        batch_size: number of samples per batch.
        n_workers: forwarded to the DataLoader as ``num_workers``.
        pin_memory: forwarded to the DataLoader unchanged.
        **kwargs: forwarded to ``ColorizationDataset`` (e.g. ``paths``, ``split``).

    Returns:
        The DataLoader; no shuffling option is passed, so batches follow dataset order.
    """
    loader_options = {
        'batch_size': batch_size,
        'num_workers': n_workers,
        'pin_memory': pin_memory,
    }
    return DataLoader(ColorizationDataset(**kwargs), **loader_options)

In [17]:
# One loader per path array. `results` holds the whole sampled subset and goes
# through the 'val' pipeline (resize only, no augmentation).
train_data = make_dataloaders(split='train', paths=train)
val_data = make_dataloaders(split='val', paths=val)
result_dl = make_dataloaders(split='val', paths=results)

# Pull one training batch up front so a broken pipeline fails here.
data = next(iter(train_data))

# Number of batches per loader.
print(*(len(loader) for loader in (train_data, val_data, result_dl)))

94 32 125


## Generative Adversarial Network Architecture

In [18]:
class UnetBlock(nn.Module):
    """One down/up level of the U-Net, optionally wrapping a deeper level.

    Args:
        nf: channel count produced by this block's up-convolution (and, unless
            ``input_c`` is given, expected by its down-convolution).
        ni: channel count produced by the down-convolution.
        submodule: the next-deeper block, placed between the down and up paths.
        input_c: input channels of the down-convolution; defaults to ``nf``.
        dropout: append ``Dropout(0.5)`` after the up path (middle blocks only).
        innermost: build the bottleneck block (no submodule, no down-norm).
        outermost: build the top block (no skip connection, ``Tanh`` output).

    The order of layers inside ``self.model`` is part of the state-dict layout
    and must not change.
    """

    def __init__(self, nf, ni, submodule=None, input_c=None, dropout=False, innermost=False, outermost=False):
        super().__init__()
        self.outermost = outermost
        in_channels = nf if input_c is None else input_c

        # The down-convolution is always created before the up-convolution.
        down_conv = nn.Conv2d(in_channels, ni, kernel_size=4, stride=2, padding=1, bias=False)

        if innermost:
            # Bottleneck: nothing is concatenated below, so the up-conv sees `ni` channels.
            up_conv = nn.ConvTranspose2d(ni, nf, kernel_size=4, stride=2, padding=1, bias=False)
            layers = [
                nn.LeakyReLU(0.2, True),
                down_conv,
                nn.ReLU(True),
                up_conv,
                nn.BatchNorm2d(nf),
            ]
        elif outermost:
            # Top level: the submodule returns its input concatenated with its output,
            # hence `ni * 2` channels going into the (biased) up-convolution.
            up_conv = nn.ConvTranspose2d(ni * 2, nf, kernel_size=4, stride=2, padding=1)
            layers = [
                down_conv,
                submodule,
                nn.ReLU(True),
                up_conv,
                nn.Tanh(),
            ]
        else:
            up_conv = nn.ConvTranspose2d(ni * 2, nf, kernel_size=4, stride=2, padding=1, bias=False)
            layers = [
                nn.LeakyReLU(0.2, True),
                down_conv,
                nn.BatchNorm2d(ni),
                submodule,
                nn.ReLU(True),
                up_conv,
                nn.BatchNorm2d(nf),
            ]
            if dropout:
                layers.append(nn.Dropout(0.5))

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Run the block; every block except the outermost concatenates its input as a skip."""
        out = self.model(x)
        if self.outermost:
            return out
        # The leading LeakyReLU is in-place, so `x` has already been activated here.
        return torch.cat([x, out], 1)

class Unet(nn.Module):
    """U-Net generator assembled from nested ``UnetBlock`` levels, innermost first.

    Args:
        input_c: channels of the network input.
        output_c: channels of the network output.
        n_down: nominal number of down-sampling levels; ``n_down - 5`` dropout
            blocks at full width are inserted above the bottleneck.
        total_filters: base width; the widest levels use ``total_filters * 8``.
    """

    def __init__(self, input_c=1, output_c=2, n_down=8, total_filters=64):
        super().__init__()
        widest = total_filters * 8

        # Bottleneck, then the full-width levels that carry dropout.
        block = UnetBlock(widest, widest, innermost=True)
        for _ in range(n_down - 5):
            block = UnetBlock(widest, widest, submodule=block, dropout=True)

        # Three levels that halve the width each time going outwards.
        inner_width = widest
        for _ in range(3):
            outer_width = inner_width // 2
            block = UnetBlock(outer_width, inner_width, submodule=block)
            inner_width = outer_width

        self.model = UnetBlock(output_c, inner_width, input_c=input_c, submodule=block, outermost=True)

    def forward(self, x):
        """Apply the nested blocks to ``x``."""
        return self.model(x)
class Discriminator(nn.Module):
    """Convolutional discriminator producing a single-channel map of raw scores.

    Args:
        input_c: channels of the input image stack.
        total_filters: width of the first convolution; doubled at every further stage.
        n_down: number of normalised stages after the first one; the last of
            them uses stride 1, the others stride 2.
    """

    def __init__(self, input_c, total_filters=64, n_down=3):
        super().__init__()
        # First stage: no normalisation.
        stages = [self.get_layers(input_c, total_filters, norm=False)]

        for i in range(n_down):
            stride = 1 if i == (n_down - 1) else 2
            stages.append(
                self.get_layers(total_filters * 2 ** i, total_filters * 2 ** (i + 1), s=stride)
            )

        # Output stage: bare convolution down to one channel (no norm, no activation).
        stages.append(self.get_layers(total_filters * 2 ** n_down, 1, s=1, norm=False, act=False))

        self.model = nn.Sequential(*stages)

    def get_layers(self, ni, nf, k=4, s=2, p=1, norm=True, act=True):
        """Return ``Conv2d`` followed by an optional LeakyReLU and an optional BatchNorm2d.

        The convolution carries a bias only when no normalisation follows. The
        activation is placed before the normalisation; that order fixes the
        state-dict indices and is therefore left as is.
        """
        conv = nn.Conv2d(ni, nf, k, s, p, bias=not norm)
        tail = []
        if act:
            tail.append(nn.LeakyReLU(0.2, True))
        if norm:
            tail.append(nn.BatchNorm2d(nf))
        return nn.Sequential(conv, *tail)

    def forward(self, x):
        """Score the input stack ``x``."""
        return self.model(x)

In [20]:
class GANLoss(nn.Module):
    """Adversarial loss that builds the real/fake target tensor itself.

    Args:
        gan_mode: 'vanilla' for ``BCEWithLogitsLoss`` or 'lsgan' for ``MSELoss``.
            Previously this argument was accepted but ignored.
        real_label: target value used for "real" predictions.
        fake_label: target value used for "fake" predictions.

    Raises:
        ValueError: if ``gan_mode`` is not one of the supported names.
    """

    def __init__(self, gan_mode='vanilla', real_label=1.0, fake_label=0.0):
        super().__init__()
        # Buffers (not parameters) so the labels follow the module across devices.
        self.register_buffer('real_label', torch.tensor(real_label))
        self.register_buffer('fake_label', torch.tensor(fake_label))
        if gan_mode == 'vanilla':
            self.loss = nn.BCEWithLogitsLoss()
        elif gan_mode == 'lsgan':
            self.loss = nn.MSELoss()
        else:
            raise ValueError(f"unsupported gan_mode {gan_mode!r}; expected 'vanilla' or 'lsgan'")
        self.gan_mode = gan_mode

    def get_labels(self, preds, target_is_real):
        """Return the real or fake label broadcast to the shape of ``preds``."""
        labels = self.real_label if target_is_real else self.fake_label
        return labels.expand_as(preds)

    def forward(self, preds, target_is_real):
        """Loss of ``preds`` against an all-real or all-fake target.

        Defined as ``forward`` rather than ``__call__`` so that nn.Module's own
        call machinery (hooks etc.) keeps working; ``loss_fn(preds, flag)`` is
        still the way to invoke it.
        """
        labels = self.get_labels(preds, target_is_real)
        return self.loss(preds, labels)

In [21]:
def weight_initialize(net, init='norm', gain=0.02):
    """Re-initialise the convolution and BatchNorm2d weights of ``net`` in place.

    Args:
        net: module whose sub-modules are visited via ``net.apply``.
        init: accepted for interface compatibility; the code does not read it.
        gain: standard deviation of the normal distributions that are sampled.

    Returns:
        The same ``net`` object.
    """

    def init_func(m):
        kind = type(m).__name__

        # BatchNorm2d: scale drawn around 1, shift zeroed.
        if 'BatchNorm2d' in kind:
            nn.init.normal_(m.weight.data, 1., gain)
            nn.init.constant_(m.bias.data, 0.)
            return

        # Everything that is not a weighted "Conv*" layer is left untouched.
        if 'Conv' not in kind or not hasattr(m, 'weight'):
            return

        nn.init.normal_(m.weight.data, mean=0.0, std=gain)
        bias = getattr(m, 'bias', None)
        if bias is not None:
            nn.init.constant_(bias.data, 0.0)

    net.apply(init_func)
    return net

def init_model(model, device):
    """Move ``model`` to ``device``, re-initialise its weights, and return it."""
    on_device = model.to(device)
    return weight_initialize(on_device)

In [22]:
class GANModel(nn.Module):
    """Generator + discriminator pair with their losses and optimisers.

    The generator ``nG`` maps the 1-channel ``L`` input to a 2-channel colour
    prediction; the discriminator ``nD`` scores the 3-channel stack of ``L``
    with either the predicted or the ground-truth colour channels.

    Args:
        nG: optional ready-made generator; when ``None`` a freshly initialised
            ``Unet`` is created.
        lr_G: Adam learning rate of the generator.
        lr_D: Adam learning rate of the discriminator.
        beta1, beta2: Adam betas shared by both optimisers.
        lambda_L1: weight of the L1 term in the generator loss.
    """

    def __init__(self, nG=None, lr_G=2e-4, lr_D=2e-4,
                 beta1=0.5, beta2=0.999, lambda_L1=100.):
        super().__init__()

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.lambda_L1 = lambda_L1

        # Explicit None check: module truthiness can be False for empty containers,
        # which would silently discard a generator the caller passed in.
        if nG is not None:
            self.nG = nG.to(self.device)
        else:
            self.nG = init_model(Unet(input_c=1, output_c=2, n_down=8, total_filters=64), self.device)
        self.nD = init_model(Discriminator(input_c=3, n_down=3, total_filters=64), self.device)
        self.GAN_loss = GANLoss(gan_mode='vanilla').to(self.device)
        self.L1_loss = nn.L1Loss()
        self.Gen_optimize = optim.Adam(self.nG.parameters(), lr=lr_G, betas=(beta1, beta2))
        self.Disc_optimize = optim.Adam(self.nD.parameters(), lr=lr_D, betas=(beta1, beta2))

    def bD(self):
        """Compute the discriminator loss for the current batch and back-propagate it."""
        fake_image = torch.cat([self.L, self.fake_color], dim=1)
        # detach(): the discriminator step must not push gradients into the generator.
        fake_preds = self.nD(fake_image.detach())
        self.Disc_fake = self.GAN_loss(fake_preds, False)
        real_image = torch.cat([self.L, self.ab], dim=1)
        real_preds = self.nD(real_image)
        self.Disc_real = self.GAN_loss(real_preds, True)
        self.Disc_Loss = (self.Disc_fake + self.Disc_real) * 0.5
        self.Disc_Loss.backward()

    def bG(self):
        """Compute the generator loss (adversarial + weighted L1) and back-propagate it."""
        fake_image = torch.cat([self.L, self.fake_color], dim=1)
        fake_preds = self.nD(fake_image)
        # The generator is rewarded when the discriminator labels its output as real.
        self.loss_G_GAN = self.GAN_loss(fake_preds, True)
        self.Gen_L1 = self.L1_loss(self.fake_color, self.ab) * self.lambda_L1
        self.Gen_Loss = self.loss_G_GAN + self.Gen_L1
        self.Gen_Loss.backward()

    def optimize(self):
        """Run one training step: discriminator update first, then generator update."""
        self.forward()

        self.nD.train()
        self.grad_required(self.nD, True)
        self.Disc_optimize.zero_grad()
        self.bD()
        self.Disc_optimize.step()

        self.nG.train()
        # Freeze the discriminator so the generator step computes no gradients for it.
        self.grad_required(self.nD, False)
        self.Gen_optimize.zero_grad()
        self.bG()
        self.Gen_optimize.step()

    def grad_required(self, model, requires_grad=True):
        """Set ``requires_grad`` on every parameter of ``model``."""
        for p in model.parameters():
            p.requires_grad = requires_grad

    def init_input(self, data):
        """Store the batch's 'L' and 'ab' tensors on the model's device."""
        self.L = data['L'].to(self.device)
        self.ab = data['ab'].to(self.device)

    def forward(self):
        """Run the generator on the stored ``L`` and keep the result in ``fake_color``."""
        self.fake_color = self.nG(self.L)
    

## Loss Metrics

In [23]:
class AverageMeter:
    """Running, count-weighted average of the values passed to ``update``."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero the sample count, the weighted sum and the average."""
        self.count = 0.
        self.avg = 0.
        self.sum = 0.

    def update(self, val, count=1):
        """Add ``val`` with weight ``count`` and refresh the average."""
        self.count = self.count + count
        self.sum = self.sum + count * val
        self.avg = self.sum / self.count

In [None]:
def loss_metrics():
    """Return a fresh ``AverageMeter`` per tracked loss, keyed by the model attribute name."""
    tracked = ('Disc_Loss', 'Gen_Loss')
    return {loss_name: AverageMeter() for loss_name in tracked}

def update_losses(model, loss_meter_dict, count):
    """Feed each meter the current value of the equally named loss attribute on ``model``.

    Args:
        model: object exposing one attribute per key of ``loss_meter_dict``;
            each attribute must provide ``.item()``.
        loss_meter_dict: mapping from loss name to a meter with ``update(val, count=...)``.
        count: weight (batch size) passed to every meter.
    """
    for name in loss_meter_dict:
        current = getattr(model, name).item()
        loss_meter_dict[name].update(current, count=count)

def L2R(L, ab):
    """Convert a batch of normalised L and ab tensors to RGB.

    Undoes the dataset scaling (``(L + 1) * 50`` and ``ab * 110``), joins the
    channels, and runs ``lab2rgb`` on every image.

    Returns:
        A numpy array stacking the per-image RGB results along axis 0
        (channels last).
    """
    lightness = (L + 1.) * 50.
    chroma = ab * 110.
    # Channels-first tensor -> channels-last numpy batch, as lab2rgb is fed H x W x 3 images.
    lab_batch = torch.cat([lightness, chroma], dim=1).permute(0, 2, 3, 1).cpu().numpy()
    return np.stack([lab2rgb(lab_img) for lab_img in lab_batch], axis=0)
    
def visualize(model, data, save=True):
    """Plot up to five samples of ``data``: grayscale input, prediction, ground truth.

    The generator is run in eval mode without gradients and put back into
    train mode afterwards.

    Args:
        model: GAN wrapper exposing ``nG``, ``init_input``, ``forward``,
            ``fake_color``, ``ab`` and ``L``.
        data: batch dict accepted by ``model.init_input``.
        save: when true, also write the figure to ``colorization_<timestamp>.png``.
    """
    model.nG.eval()
    with torch.no_grad():
        model.init_input(data)
        model.forward()
    model.nG.train()

    # Only the first five samples are drawn, so only those are converted to RGB.
    # Slicing also copes with batches smaller than five, which used to raise IndexError.
    max_images = 5
    L = model.L[:max_images]
    fake_color = model.fake_color.detach()[:max_images]
    real_color = model.ab[:max_images]
    fake, real = L2R(L, fake_color), L2R(L, real_color)

    fig = plt.figure(figsize=(15, 8))
    for i in range(L.size(0)):
        ax = plt.subplot(3, 5, i + 1)
        ax.imshow(L[i][0].cpu(), cmap='gray')
        ax.axis("off")
        ax = plt.subplot(3, 5, i + 1 + 5)
        ax.imshow(fake[i])
        ax.axis("off")
        ax = plt.subplot(3, 5, i + 1 + 10)
        ax.imshow(real[i])
        ax.axis("off")
    plt.show()
    if save:
        fig.savefig(f"colorization_{time.time()}.png")
    # Release the figure; otherwise one accumulates per call during long runs.
    plt.close(fig)
        
def log_results(loss_meter_dict):
    """Print ``<name>: <average>`` (five decimals) for every meter in the dict."""
    for loss_name in loss_meter_dict:
        loss_meter = loss_meter_dict[loss_name]
        print(f"{loss_name}: {loss_meter.avg:.5f}")

## Training the Model

In [None]:
epoch = 1  # number of epochs for the training call in this cell


def train_model(model, train_data, epochs, display_every=94, checkpoint_path='checkpoint22.pt'):
    """Train ``model`` and save its state dict after every epoch.

    Args:
        model: GAN wrapper providing ``init_input``, ``optimize`` and ``state_dict``.
        train_data: iterable of training batches (dicts with an 'L' tensor).
        epochs: number of passes over ``train_data``.
        display_every: log the running losses and plot samples every this many iterations.
        checkpoint_path: file the state dict is written to at the end of each epoch.

    Note:
        Reads the notebook-level ``val_data`` loader to fetch the batch used for
        the periodic visualisation.
    """
    # Fixed held-out batch for visualisation. It was previously bound to `data`,
    # which the training loop variable overwrote at once, so it was never shown.
    vis_batch = next(iter(val_data))
    for e in range(epochs):
        loss_meter_dict = loss_metrics()  # fresh running averages per epoch
        for i, data in enumerate(tqdm(train_data), start=1):
            model.init_input(data)
            model.optimize()
            update_losses(model, loss_meter_dict, count=data['L'].size(0))
            if i % display_every == 0:
                print(f"\nEpoch {e+1}/{epochs}")
                print(f"Iteration {i}/{len(train_data)}")
                log_results(loss_meter_dict)
                visualize(model, vis_batch, save=False)
        torch.save(model.state_dict(), checkpoint_path)

# Three independent copies: `model` is trained below, `clf` is used by the
# validation cell and `cl` by the result-export cell.
model = GANModel()
cl = GANModel()
clf = GANModel()

load = True
if load:
    print('loading model')
    # Read the checkpoint once. map_location lets a checkpoint saved from GPU
    # load on a CPU-only machine. load_state_dict copies the values, so the
    # same dict can safely initialise all three models.
    state = torch.load('checkpoint21.pt', map_location=device)
    for gan in (cl, clf, model):
        gan.load_state_dict(state)
train_model(model, train_data, epoch)

## Validating and Visualizing the Training Results

In [None]:

def visualiz(model, data, save=True):
    """Show grayscale input, prediction and ground truth for the first samples of ``data``.

    This was a line-for-line copy of ``visualize`` minus the file export, so it
    now delegates to it. ``save`` is kept for signature compatibility and, as
    before, is ignored: this function never writes a file.
    """
    visualize(model, data, save=False)
display_every = 8

loss_meter_dict = loss_metrics()


def _validation_step(gan, batch):
    """Compute ``Disc_Loss`` and ``Gen_Loss`` for ``batch`` without touching any weights.

    Mirrors the loss formulas of ``GANModel.bD`` / ``GANModel.bG`` but runs in
    eval mode under ``torch.no_grad()`` and performs no backward pass or
    optimiser step. The results are stored on ``gan`` under the attribute
    names that ``update_losses`` reads.
    """
    gan.nG.eval()
    gan.nD.eval()
    with torch.no_grad():
        gan.init_input(batch)
        gan.forward()
        fake_pair = torch.cat([gan.L, gan.fake_color], dim=1)
        real_pair = torch.cat([gan.L, gan.ab], dim=1)
        fake_preds = gan.nD(fake_pair)
        real_preds = gan.nD(real_pair)
        gan.Disc_Loss = (gan.GAN_loss(fake_preds, False) + gan.GAN_loss(real_preds, True)) * 0.5
        gan.Gen_Loss = gan.GAN_loss(fake_preds, True) + gan.L1_loss(gan.fake_color, gan.ab) * gan.lambda_L1


# Validation must only measure the model. The previous loop called
# clf.optimize(), i.e. ran optimiser steps on the validation set.
for i, data in enumerate(tqdm(val_data), start=1):
    _validation_step(clf, data)
    update_losses(clf, loss_meter_dict, count=data['L'].size(0))
    if i % display_every == 0:
        print(f"Iteration {i}/{len(val_data)}")
        log_results(loss_meter_dict)
        visualiz(clf, data)

# Leave both networks in train mode, the state optimize() used to leave them in.
clf.nG.train()
clf.nD.train()

## Saving Model Results to Folders

In [None]:
count = 0  # running index used to name the exported image files


def visual(model, data, save=True):
    """Colorize ``data`` with ``model`` and write prediction / ground-truth PNG pairs.

    Predictions go to ``cars_data/<n>.png`` and the matching originals to
    ``orig_cars_data/<n>.png``, where ``n`` is the module-level ``count``,
    incremented once per image. ``save`` is accepted but not read.
    """
    global count

    # Inference pass: eval mode, no gradients, then back to train mode.
    model.nG.eval()
    with torch.no_grad():
        model.init_input(data)
        model.forward()
    model.nG.train()

    lightness = model.L
    fake_imgs = L2R(lightness, model.fake_color.detach())
    real_imgs = L2R(lightness, model.ab)

    def to_image(rgb):
        # Values are scaled by 255 and truncated to uint8.
        return Image.fromarray((rgb * 255).astype(np.uint8))

    for fake_rgb, real_rgb in zip(fake_imgs, real_imgs):
        img1, img2 = to_image(fake_rgb), to_image(real_rgb)
        img1.save(f'cars_data/{count}.png')
        img2.save(f'orig_cars_data/{count}.png')
        count += 1
display_every = 8  # kept as a notebook-level name; this loop does not read it

# Export colorized / original pairs for every batch of the result loader.
# visual() runs its own eval-mode forward pass, so the model is only read.
# The previous loop also called cl.optimize() per batch, which kept training
# the model on the very images being exported.
for data in tqdm(result_dl):
    visual(cl, data)