# Model Architecture

In [None]:
import torch
import torch.nn as nn
from functools import reduce
from torch.autograd import Variable

In [None]:
class shave_block(nn.Module):
    """Crop ``s`` pixels from every border of the last two (spatial) dims.

    Used on the identity branch of the residual blocks so it matches the size
    of the main branch's two unpadded 3x3 convolutions.

    Args:
        s: number of pixels removed from each side of H and W.
    """

    def __init__(self, s):
        super().__init__()
        self.s = s

    def forward(self, x):
        s = self.s
        # Explicit end indices: the slice ``s:-s`` is empty when s == 0
        # because ``-0`` is ``0``.
        return x[:, :, s:x.size(2) - s, s:x.size(3) - s]

class LambdaBase(nn.Sequential):
    """Sequential container that fans a single input out to every child.

    Subclasses decide how ``lambda_func`` is applied to the prepared value.

    Args:
        fn: callable stored as ``self.lambda_func`` for subclasses to use.
        *args: child modules, each of which receives the same input.
    """

    def __init__(self, fn, *args):
        super().__init__(*args)
        self.lambda_func = fn

    def forward_prepare(self, input):
        """Run ``input`` through each child independently.

        Returns a list with one output per child, or ``input`` itself when
        the container has no children (pass-through).
        """
        children = list(self._modules.values())
        if not children:
            return input
        return [child(input) for child in children]

class Lambda(LambdaBase):
    """Apply ``lambda_func`` once to the whole prepared value."""

    def forward(self, input):
        prepared = self.forward_prepare(input)
        return self.lambda_func(prepared)

class LambdaMap(LambdaBase):
    """Apply ``lambda_func`` to each element of the prepared value.

    Returns a list (one entry per child output). This is the ConcatTable
    analogue used by ``generator``.
    """

    def forward(self, input):
        fn = self.lambda_func
        return [fn(item) for item in self.forward_prepare(input)]

class LambdaReduce(LambdaBase):
    """Fold the prepared value left-to-right with the binary ``lambda_func``.

    With no children, the prepared value is the input itself, e.g. the list
    produced by a preceding ``LambdaMap``. With ``x + y`` this is a CAddTable.
    """

    def forward(self, input):
        combine = self.lambda_func
        pieces = self.forward_prepare(input)
        return reduce(combine, pieces)

def _residual_block(channels=128):
    """Build one residual block (Lua Torch ConcatTable + CAddTable port).

    The main branch is two unpadded 3x3 convolutions, each trimming 1 pixel
    per side. The identity branch is cropped by 2 pixels per side with
    ``shave_block(2)`` so the two branches line up before they are summed.
    """
    return nn.Sequential(
        LambdaMap(lambda x: x,  # ConcatTable: feed x to both branches
            nn.Sequential(
                nn.Conv2d(channels, channels, (3, 3)),
                nn.BatchNorm2d(channels),
                nn.ReLU(),
                nn.Conv2d(channels, channels, (3, 3)),
                nn.BatchNorm2d(channels),
            ),
            shave_block(2),
        ),
        LambdaReduce(lambda x, y: x + y),  # CAddTable: sum the branches
    )


def generator(n_residual=5):
    """Build the colorization generator.

    Maps a (N, 1, H, W) luminance batch to a (N, 2, H, W) chroma batch in
    [-1, 1] (final Tanh). The output spatial size equals the input's when
    H and W are multiples of 4.

    The module layout (and therefore the state_dict keys) for the default
    ``n_residual=5`` is identical to the original hand-unrolled version, so
    existing checkpoints keep loading.

    Args:
        n_residual: number of residual blocks at 1/4 resolution.
    """
    # Each residual block trims 2 px per side at 1/4 resolution, i.e. 8 px per
    # side at input resolution. Reflection-pad by exactly that amount so the
    # output size matches the input (40 px for the default 5 blocks).
    pad = 8 * n_residual
    # Conv2d(in_channels, out_channels, kernel, stride, padding)
    return nn.Sequential(
        nn.ReflectionPad2d((pad, pad, pad, pad)),
        nn.Conv2d(1, 32, (9, 9), (1, 1), (4, 4)),
        nn.BatchNorm2d(32),
        nn.ReLU(),
        # Two stride-2 convs: downsample to 1/4 resolution.
        nn.Conv2d(32, 64, (3, 3), (2, 2), (1, 1)),
        nn.BatchNorm2d(64),
        nn.ReLU(),
        nn.Conv2d(64, 128, (3, 3), (2, 2), (1, 1)),
        nn.BatchNorm2d(128),
        nn.ReLU(),
        *[_residual_block(128) for _ in range(n_residual)],
        # Two stride-2 transposed convs (output_padding=1): back to full size.
        nn.ConvTranspose2d(128, 64, (3, 3), (2, 2), (1, 1), (1, 1)),
        nn.BatchNorm2d(64),
        nn.ReLU(),
        nn.ConvTranspose2d(64, 32, (3, 3), (2, 2), (1, 1), (1, 1)),
        nn.BatchNorm2d(32),
        nn.ReLU(),
        # Two output channels (U, V), squashed to [-1, 1].
        nn.Conv2d(32, 2, (9, 9), (1, 1), (4, 4)),
        nn.Tanh(),
    )

# Training

In [None]:
import torchvision.models as models
import os
from torch.utils import data
import numpy as np
from PIL import Image
from skimage.color import rgb2yuv,yuv2rgb
import cv2

In [None]:
# define data generator
class img_data(data.Dataset):
    """Dataset of (Y, UV) training pairs, one per regular file in ``path``.

    Each item is ``(y, uv)``:
      * ``y``: (1, H, W) float tensor, the YUV luminance shifted by -0.5.
      * ``uv``: (2, H, W) float tensor, U and V divided by 0.43601035 and
        0.61497538 (the U/V extremes of the YUV transform). This scales them
        to roughly [-1, 1], matching the generator's Tanh output.
    """

    def __init__(self, path):
        # Sorted for a deterministic index -> file mapping. Sub-directories
        # are skipped because they cannot be opened as images.
        candidates = (os.path.join(path, name) for name in os.listdir(path))
        self.files = sorted(p for p in candidates if os.path.isfile(p))

    def __len__(self):
        return len(self.files)

    def __getitem__(self, index):
        # Force 3-channel RGB, as the test-image and inference paths do.
        # Grayscale/RGBA/palette files would otherwise give rgb2yuv the wrong
        # channel count. The `with` closes the file handle promptly.
        with Image.open(self.files[index]) as src:
            img = src.convert('RGB')
        yuv = rgb2yuv(img)
        y = yuv[..., 0] - 0.5
        u_t = yuv[..., 1] / 0.43601035
        v_t = yuv[..., 2] / 0.61497538
        return (torch.Tensor(np.expand_dims(y, axis=0)),
                torch.Tensor(np.stack([u_t, v_t], axis=0)))

In [None]:
checkpoint_location = "drive/MyDrive/Grayscale_Colorize/checkpoints"
training_dir = "drive/MyDrive/Grayscale_Colorize/places365"
# Optional image rendered at each checkpoint to eyeball progress (None = off).
test_image = "drive/MyDrive/Grayscale_Colorize/test_img.jpg"

# ---- Hyperparameters ----
# Weight of the per-pixel MSE on the predicted UV channels, relative to the
# adversarial (BCE) term in the generator loss.
pixel_loss_weights = 1000.0
# Exclusive upper bound of the epoch range the training loop iterates over.
epochs = 65
# CUDA device index used for G, D and all training tensors.
gpu = 0
# Number of images per mini-batch.
batch_size = 20
# DataLoader worker processes that load and convert images in parallel.
num_workers = 6
# The generator is updated on every `g_every`-th iteration (1 = every batch);
# the discriminator is updated on every iteration.
g_every = 1
# Adam learning rate for the generator.
g_lr = 1e-4
# Adam learning rate for the discriminator.
d_lr = 1e-4
# Every `checkpoint_every` iterations: print losses, save weights, and render
# the test image.
checkpoint_every = 100
# State dict loaded into the discriminator before training (None = skip).
d_init = "drive/MyDrive/Grayscale_Colorize/D_init.pth"
# State dict loaded into the generator before training (None = skip).
g_init = "drive/MyDrive/Grayscale_Colorize/G_init.pth"

In [None]:
# Directory for per-epoch weight snapshots. exist_ok avoids the
# check-then-create race and makes re-running the cell harmless.
os.makedirs(os.path.join(checkpoint_location, 'weights'), exist_ok=True)

In [None]:
# Define G, same as torch version
G = generator().cuda(gpu)

# define D, 2 classes -> Real or Fake
D = models.resnet18(pretrained=False, num_classes=2)
# Add a fully connected layer
# Apply a linear transformation to the incoming data: y = xA^T + b
# Rescaled to 512 * 512
D.fc = nn.Sequential(nn.Linear(512, 1), nn.Sigmoid())
D = D.cuda(gpu)

trainset = img_data(training_dir)
params = {'batch_size': batch_size,
          'shuffle': True,
          'num_workers': num_workers}
training_generator = data.DataLoader(trainset, **params)

In [None]:
if test_image is not None:
    # Load the preview image as 256x256 RGB and close the file handle.
    with Image.open(test_image) as src:
        test_img = src.convert('RGB').resize((256, 256))
    test_yuv = rgb2yuv(test_img)
    # Luminance as a (1, 1, 256, 256) batch.
    test_inf = test_yuv[..., 0].reshape(1, 1, 256, 256)
    # Same -0.5 shift the training data uses. A plain tensor replaces the
    # deprecated Variable wrapper, which is a no-op since PyTorch 0.4.
    test_var = torch.Tensor(test_inf - 0.5).cuda(gpu)

In [None]:
# Load the initial weights
if d_init is not None:
    D.load_state_dict(torch.load(d_init, map_location='cuda:0'))
if g_init is not None:
    G.load_state_dict(torch.load(g_init, map_location='cuda:0'))

print("Initial Weights Loaded")

In [None]:
if test_image is not None:
    # Inference only: no autograd graph is needed.
    # NOTE(review): G is still in training mode here, so BatchNorm uses this
    # single image's statistics and updates its running stats; confirm
    # whether G.eval() is intended.
    with torch.no_grad():
        test_res = G(test_var)
    uv = test_res.cpu().detach().numpy()
    # Undo the training-time normalisation of U and V.
    uv[:, 0, :, :] *= 0.436
    uv[:, 1, :, :] *= 0.615
    test_yuv = np.concatenate([test_inf, uv], axis=1)[0]  # (3, 256, 256)
    test_rgb = yuv2rgb(test_yuv.transpose(1, 2, 0))
    # Scale [0, 1] -> [0, 255] (not 256) and convert to uint8 explicitly.
    # Channels are reordered RGB -> BGR for OpenCV.
    test_bgr = (test_rgb.clip(min=0, max=1) * 255).round().astype(np.uint8)[:, :, [2, 1, 0]]
    cv2.imwrite(os.path.join(checkpoint_location, 'test_init.jpg'), test_bgr)

In [None]:
i = 0
adversarial_loss = torch.nn.BCELoss()
optimizer_G = torch.optim.Adam(G.parameters(), lr=g_lr, betas=(0.5, 0.999))
optimizer_D = torch.optim.Adam(D.parameters(), lr=d_lr, betas=(0.5, 0.999))
# Epoch to resume from. The original run hard-coded 17 (continuing an earlier
# session); set to 0 for a fresh run.
start_epoch = 17
for epoch in range(start_epoch, epochs):
    for y, uv in training_generator:
        # Adversarial targets: 1 = real image, 0 = generated image.
        valid = torch.ones(y.size(0), 1).cuda(gpu)
        fake = torch.zeros(y.size(0), 1).cuda(gpu)

        yvar = y.cuda(gpu)
        uvvar = uv.cuda(gpu)
        real_imgs = torch.cat([yvar, uvvar], dim=1)

        # ---- Generator step: fool D and match the true chroma ----
        optimizer_G.zero_grad()
        uvgen = G(yvar)
        gen_imgs = torch.cat([yvar, uvgen], dim=1)
        g_loss_gan = adversarial_loss(D(gen_imgs), valid)
        g_loss = g_loss_gan + pixel_loss_weights * torch.mean((uvvar - uvgen) ** 2)
        if i % g_every == 0:
            g_loss.backward()
            optimizer_G.step()

        # ---- Discriminator step ----
        # g_loss.backward() also accumulated gradients into D; clear them.
        optimizer_D.zero_grad()
        real_loss = adversarial_loss(D(real_imgs), valid)
        fake_loss = adversarial_loss(D(gen_imgs.detach()), fake)
        d_loss = (real_loss + fake_loss) / 2
        d_loss.backward()
        optimizer_D.step()

        i += 1
        if i % checkpoint_every == 0:
            print ("Epoch: %d: [D loss: %f] [G total loss: %f] [G GAN Loss: %f]" % (epoch, d_loss.item(), g_loss.item(), g_loss_gan.item()))

            torch.save(D.state_dict(), os.path.join(checkpoint_location, 'weights', 'D' + str(epoch) + '.pth'))
            torch.save(G.state_dict(), os.path.join(checkpoint_location, 'weights', 'G' + str(epoch) + '.pth'))
            if test_image is not None:
                # Preview render; no autograd graph needed.
                # NOTE(review): G stays in training mode, so this forward pass
                # also updates BatchNorm running stats; confirm that is intended.
                with torch.no_grad():
                    test_res = G(test_var)
                test_uv = test_res.cpu().numpy()
                test_uv[:, 0, :, :] *= 0.436
                test_uv[:, 1, :, :] *= 0.615
                test_yuv = np.concatenate([test_inf, test_uv], axis=1)[0]
                test_rgb = yuv2rgb(test_yuv.transpose(1, 2, 0))
                # [0, 1] -> [0, 255] uint8, RGB -> BGR for OpenCV.
                test_bgr = (test_rgb.clip(min=0, max=1) * 255).round().astype(np.uint8)[:, :, [2, 1, 0]]
                cv2.imwrite(os.path.join(checkpoint_location, 'test_epoch_' + str(epoch) + '.jpg'), test_bgr)
torch.save(D.state_dict(), os.path.join(checkpoint_location, 'D_final.pth'))
torch.save(G.state_dict(), os.path.join(checkpoint_location, 'G_final.pth'))


# Colorize new Grayscale Images using the Generator

In [None]:
from scipy.ndimage import zoom

In [None]:
# Batch-colorization settings. `input` may be a single image or a directory of
# images; results go to `output`. (`input` shadows the built-in but is kept
# because the driver cell reads it by that name.)
# NOTE(review): the training loop saves `G_final.pth` inside `checkpoints/`,
# while this path points to `G_Final.pth` one level up; confirm the file was
# copied/renamed here.
input, output, model = (
    "drive/MyDrive/Grayscale_Colorize/test_input",
    "drive/MyDrive/Grayscale_Colorize/test_output",
    "drive/MyDrive/Grayscale_Colorize/G_Final.pth",
)
# CUDA device index; a negative value runs inference on the CPU.
gpu = 0

In [None]:
G = generator()

# Load the trained generator onto the requested device (gpu < 0 = CPU).
# map_location sends every stored tensor straight to the target device,
# whichever GPU the checkpoint was saved from. The previous CPU mapping
# {'cuda:1': 'cpu'} left cuda:0 tensors unmapped and failed on CPU-only
# machines.
if gpu >= 0:
    G = G.cuda(gpu)
    G.load_state_dict(torch.load(model, map_location='cuda:%d' % gpu))
else:
    G.load_state_dict(torch.load(model, map_location='cpu'))


In [None]:
def inference(G, in_path, out_path):
    """Colorize the image at ``in_path`` with generator ``G`` and save it.

    The image's luminance (Y of YUV, shifted by -0.5 as in training) is fed
    to G. G's two output channels are rescaled to the U/V ranges and resized
    back to the input resolution, because G's output only matches the input
    size when H and W are multiples of 4. They are then recombined with the
    original Y, converted to RGB and written with OpenCV.

    Args:
        G: generator module; the input is sent to the device G lives on.
        in_path: path of the source image (any PIL-readable format).
        out_path: destination path; its extension selects the format.

    Raises:
        OSError: if OpenCV reports that ``out_path`` could not be written.
    """
    with Image.open(in_path) as src:
        p = src.convert('RGB')
    img_yuv = rgb2yuv(p)
    H, W, _ = img_yuv.shape
    # (1, 1, H, W) luminance batch.
    infimg = np.expand_dims(np.expand_dims(img_yuv[..., 0], axis=0), axis=0)
    device = next(G.parameters()).device
    img_tensor = torch.Tensor(infimg - 0.5).to(device)
    with torch.no_grad():  # inference only: skip building the autograd graph
        res = G(img_tensor)
    uv = res.cpu().numpy()
    # Undo the training-time normalisation of U and V.
    uv[:, 0, :, :] *= 0.436
    uv[:, 1, :, :] *= 0.615
    _, _, H1, W1 = uv.shape
    uv = zoom(uv, (1, 1, H / H1, W / W1))
    yuv = np.concatenate([infimg, uv], axis=1)[0]
    rgb = yuv2rgb(yuv.transpose(1, 2, 0))
    # [0, 1] -> [0, 255] uint8 (not x256), RGB -> BGR for OpenCV.
    bgr = (rgb.clip(min=0, max=1) * 255).round().astype(np.uint8)[:, :, [2, 1, 0]]
    if not cv2.imwrite(out_path, bgr):
        raise OSError("cv2.imwrite failed for %s" % out_path)

In [None]:
if not os.path.isdir(input):
    # Single-image mode. If `output` is an existing directory (as configured
    # above), write the result inside it under the input's file name.
    target = output
    if os.path.isdir(output):
        target = os.path.join(output, os.path.basename(input))
    inference(G, input, target)
else:
    # Directory mode: colorize every regular file, keeping its name.
    os.makedirs(output, exist_ok=True)
    for f in sorted(os.listdir(input)):
        src_path = os.path.join(input, f)
        if not os.path.isfile(src_path):
            continue  # sub-directories cannot be colorized
        inference(G, src_path, os.path.join(output, f))
        print("DONE", os.path.join(output, f))