In [1]:
import os
import sys
import os.path as path

import numpy as np
import matplotlib.pyplot as plt
import pandas

from PIL import Image
from skimage import color

import torch
import torch.nn.functional as F

In [2]:
def load_img(img_path):
    """Read an image file into a NumPy array that always has a channel axis.

    Args:
        img_path: Path of the image file, passed to ``PIL.Image.open``.

    Returns:
        The pixel data as a NumPy array. A 2-D (single-channel) image is
        repeated three times along a new last axis so the result is
        ``(H, W, 3)``; arrays that already have a channel axis are returned
        as decoded.
    """
    # The context manager closes the underlying file handle as soon as the
    # pixels have been decoded into the array, instead of leaving it open
    # until the Image object is garbage collected.
    with Image.open(img_path) as img_file:
        out_np = np.asarray(img_file)

    if out_np.ndim == 2:
        # Grayscale input: replicate the single plane into three channels.
        out_np = np.tile(out_np[:, :, None], 3)

    return out_np


def resize_img(img, HW=(256, 256), resample=3):
    """Resize an image array to ``HW`` using Pillow.

    Args:
        img: Image array accepted by ``PIL.Image.fromarray``.
        HW: Target size given as (height, width).
        resample: Pillow resampling filter code forwarded to ``Image.resize``
            (3 is Pillow's integer value for bicubic).

    Returns:
        The resized image as a NumPy array.
    """
    target_h = HW[0]
    target_w = HW[1]

    pil_img = Image.fromarray(img)
    # Pillow expects (width, height), the reverse of the HW convention here.
    pil_resized = pil_img.resize((target_w, target_h), resample=resample)

    return np.asarray(pil_resized)


def preprocess_img(img_rgb_orig, HW=(256, 256), resample=3):
    """Resize an RGB image, convert it to Lab and split it into L / ab tensors.

    Args:
        img_rgb_orig: RGB image array at its original resolution.
        HW: Target (height, width) forwarded to ``resize_img``.
        resample: Resampling filter code forwarded to ``resize_img``.

    Returns:
        A pair ``(tens_rs_l, tens_rs_ab)`` of ``torch.Tensor`` objects with
        shapes ``(1, 1, H, W)`` and ``(1, 2, H, W)``: the L channel and the
        two ab channels of the resized image, each with a leading batch axis.
    """
    lab_resized = color.rgb2lab(resize_img(img_rgb_orig, HW=HW, resample=resample))

    # Channel 0 is L; channels 1 and 2 are a and b.
    lightness = lab_resized[:, :, 0]
    chroma_chw = np.moveaxis(lab_resized[:, :, 1:], 2, 0)  # (H, W, 2) -> (2, H, W)

    tens_rs_l = torch.Tensor(lightness)[None, None, :, :]
    tens_rs_ab = torch.Tensor(chroma_chw)[None, :, :, :]

    return tens_rs_l, tens_rs_ab


def get_files_in_dir(directory):
    """List the regular files directly inside ``directory``.

    Sub-directories (and anything else that is not a regular file) are
    skipped; the listing is not recursive.

    Args:
        directory: Path of the directory to scan.

    Returns:
        A list of paths (``directory`` joined with each file name), sorted by
        file name so the order is the same on every run and platform
        (``os.listdir`` alone returns entries in arbitrary order).
    """
    candidate_paths = (
        os.path.join(directory, filename) for filename in sorted(os.listdir(directory))
    )

    return [candidate for candidate in candidate_paths if os.path.isfile(candidate)]


def load_images(filenames, use_gpu=False):
    """Load image files and turn each into an (L, ab) tensor pair.

    Every file is read with ``load_img`` and converted with ``preprocess_img``
    using its default target size.

    Args:
        filenames: Iterable of image file paths.
        use_gpu: When true, move both tensors of every pair to the CUDA
            device, so inputs and targets live on the same device.

    Returns:
        A list of ``(tens_l, tens_ab)`` tuples with the batch axis removed,
        i.e. shapes ``(1, H, W)`` and ``(2, H, W)``. An empty ``filenames``
        yields an empty list.
    """
    data = []

    for file in filenames:
        tens_l_rs, tens_ab_rs = preprocess_img(load_img(file))

        if use_gpu:
            tens_l_rs = tens_l_rs.cuda()
            tens_ab_rs = tens_ab_rs.cuda()

        # preprocess_img returns (1, C, H, W); index away the batch axis rather
        # than reshaping to a hard-coded 256x256.
        data.append((tens_l_rs[0], tens_ab_rs[0]))

    return data


def separate_io(test_data):
    """Split a list of (L, ab) tensor pairs into two batched tensors.

    Args:
        test_data: Sequence of ``(l_tensor, ab_tensor)`` pairs.

    Returns:
        ``(inputs, outputs)`` where ``inputs`` is the concatenation of all L
        tensors reshaped to ``(-1, 1, 256, 256)`` and ``outputs`` the
        concatenation of all ab tensors reshaped to ``(-1, 2, 256, 256)``.
    """
    l_parts = []
    ab_parts = []
    for l_tensor, ab_tensor in test_data:
        l_parts.append(l_tensor)
        ab_parts.append(ab_tensor)

    inputs = torch.cat(l_parts).reshape(-1, 1, 256, 256)
    outputs = torch.cat(ab_parts).reshape(-1, 2, 256, 256)

    return inputs, outputs

In [3]:
import torch.nn as nn
import torch
import matplotlib.pyplot as plt


# Custom loss built on mean squared error, with a small extra term that favours
# more extreme (colorful) predictions.
def MSEColorfulLoss(output, target):
    """Mean of the squared error plus a bounded colorfulness term.

    The extra term is ``0.15 * tanh(6 * (|target| - |output| - 0.333))``: it
    grows when the prediction has a smaller magnitude than the target and
    shrinks when it has a larger one, and ``tanh`` keeps it within +/-0.15.

    Args:
        output: Predicted tensor.
        target: Ground-truth tensor, broadcastable against ``output``.

    Returns:
        A scalar tensor holding the mean loss over all elements.
    """
    squared_error = (output - target) ** 2
    magnitude_gap = torch.abs(target) - torch.abs(output) - 0.333
    colorfulness_term = 0.15 * torch.tanh(6 * magnitude_gap)

    return torch.mean(squared_error + colorfulness_term)

class Colorizer(nn.Module):
    """Convolutional encoder-decoder that predicts ab chroma from an L channel.

    Three stride-2 convolutions downsample the input by a factor of 8, a
    stride-1 convolution plus dropout form the bottleneck, and three stride-2
    transposed convolutions upsample by 8 again before two stride-1
    convolutions produce the 2-channel output. The output has the same spatial
    size as the input when height and width are multiples of 8.

    The network works on normalised values. The public methods take raw L and
    ab tensors and apply ``(L - L_CENTER) / L_SCALE`` on the way in and
    ``AB_SCALE`` on the way out.
    """

    # Normalisation constants shared by training and inference.
    L_CENTER = 50.0
    L_SCALE = 100.0
    AB_SCALE = 110.0

    def __init__(self):
        super().__init__()
        k = 3
        # Attribute names and layer order define the state_dict keys; keep them
        # stable so previously saved weight files still load.
        self.layer1 = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=7, stride=2, padding=7//2),
            nn.ReLU(True),
            nn.BatchNorm2d(16)
        )
        self.layer2 = nn.Sequential(
            nn.Conv2d(16, 64, kernel_size=5, stride=2, padding=5//2),
            nn.ReLU(True),
            nn.BatchNorm2d(64)
        )
        self.layer3 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=k, stride=2, padding=k//2),
            nn.ReLU(True),
            nn.BatchNorm2d(128)
        )
        self.layer4 = nn.Sequential(
            nn.Conv2d(128, 128, kernel_size=k, stride=1, padding=k//2),
            nn.ReLU(True),
            nn.BatchNorm2d(128)
        )
        self.drop_out = nn.Dropout()
        self.layer5 = nn.Sequential(
            nn.ConvTranspose2d(128, 64, kernel_size=k, stride=2, padding=k//2, output_padding=1),
            nn.ReLU(True),
            nn.BatchNorm2d(64)
        )
        self.layer6 = nn.Sequential(
            nn.ConvTranspose2d(64, 32, kernel_size=k, stride=2, padding=k//2, output_padding=1),
            nn.ReLU(True)
        )
        self.layer7 = nn.Sequential(
            nn.ConvTranspose2d(32, 16, kernel_size=k, stride=2, padding=k//2, output_padding=1),
            nn.ReLU(True)
        )
        self.layer8 = nn.Sequential(
            nn.Conv2d(16, 8, kernel_size=k, stride=1, padding=k//2),
            nn.ReLU(True)
        )
        self.layer9 = nn.Conv2d(8, 2, kernel_size=5, stride=1, padding=5//2)

    # low level functionality

    def forward(self, x):
        """Map a normalised L tensor ``(N, 1, H, W)`` to normalised ab ``(N, 2, H, W)``."""
        out = self.layer1(x)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = self.drop_out(out)
        out = self.layer5(out)
        out = self.layer6(out)
        out = self.layer7(out)
        out = self.layer8(out)
        out = self.layer9(out)
        return out

    def _infer(self, x):
        """Run the network on raw L values in evaluation mode without gradients.

        Dropout is disabled and BatchNorm uses its running statistics, so the
        result is deterministic. The previous training/eval mode is restored
        afterwards. Returns the normalised (unscaled) ab prediction.
        """
        was_training = self.training
        super().train(False)
        try:
            with torch.no_grad():
                return self((x - self.L_CENTER) / self.L_SCALE)
        finally:
            super().train(was_training)

    def use(self, x):
        """Predict ab channels for raw L input ``x`` of shape ``(N, 1, H, W)``.

        Returns a gradient-free tensor of shape ``(N, 2, H, W)`` scaled by
        ``AB_SCALE``.
        """
        return self._infer(x) * self.AB_SCALE

    def test(self, x, y, criterion=None):
        """Return the loss (a Python float) of the prediction for ``x`` against ``y``.

        Args:
            x: Raw L input tensor.
            y: Raw ab target tensor; divided by ``AB_SCALE`` before comparison.
            criterion: Loss callable ``(output, target) -> scalar tensor``.
                Defaults to ``MSEColorfulLoss`` when ``None``.
        """
        if criterion is None:
            criterion = MSEColorfulLoss
        return criterion(self._infer(x), y / self.AB_SCALE).item()

    def fit(self, train_loader, learning_rate=0.1, num_epochs=10, criterion=None, optimizer=None):
        """Optimise the network on batches of ``(L, ab)`` tensors.

        Args:
            train_loader: Iterable yielding ``(images, expected)`` batches of
                raw L inputs and raw ab targets.
            learning_rate: Learning rate for the default Adam optimizer;
                ignored when ``optimizer`` is given.
            num_epochs: Number of passes over ``train_loader``.
            criterion: Loss callable; defaults to ``MSEColorfulLoss``.
            optimizer: Optimizer to use; defaults to Adam over all parameters.

        Returns:
            A list with the loss value of every optimisation step.
        """
        if criterion is None:
            criterion = MSEColorfulLoss

        if optimizer is None:
            optimizer = torch.optim.Adam(self.parameters(), lr=learning_rate)

        # Make sure dropout / BatchNorm are in training mode even if the model
        # was switched to eval mode earlier.
        super().train(True)

        loss_list = []
        for epoch in range(num_epochs):
            for i, (images, expected) in enumerate(train_loader):
                # Run the forward pass
                outputs = self((images - self.L_CENTER) / self.L_SCALE)
                loss = criterion(outputs, expected / self.AB_SCALE)
                loss_list.append(loss.item())

                # Backprop and perform the optimisation step
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                # Report the first batch of every 10th epoch (or of every epoch
                # for very short runs).
                if ((epoch + 1) % 10 == 0 or num_epochs < 5) and i == 0:
                    print('Epoch [{}/{}], Loss: {:.4f}'
                          .format(epoch + 1, num_epochs, loss.item()))

        return loss_list

    def train(self, train_loader=True, learning_rate=0.1, num_epochs=10, criterion=None, optimizer=None):
        """Train on a data loader, or switch training mode when given a bool.

        This method overrides ``nn.Module.train(mode)``, which PyTorch itself
        calls (``eval()`` is ``train(False)``). To keep both contracts working:

        * a ``bool`` first argument (or no argument) sets the module's
          training mode and returns ``self``, exactly like ``nn.Module.train``;
        * anything else is treated as a data loader and forwarded to ``fit``,
          whose per-step loss list is returned.
        """
        if isinstance(train_loader, bool):
            return super().train(train_loader)

        return self.fit(train_loader, learning_rate=learning_rate, num_epochs=num_epochs,
                        criterion=criterion, optimizer=optimizer)

    def save(self, filename="model_weights.pt"):
        """Write the model's state dict to ``filename``."""
        torch.save(self.state_dict(), filename)

    def load(self, filename="model_weights.pt"):
        """Load a state dict from ``filename`` into this model.

        Tensors are mapped to CPU while reading so weights saved on a GPU load
        on CPU-only machines; ``load_state_dict`` then copies them into the
        model's existing parameters.
        """
        self.load_state_dict(torch.load(filename, map_location="cpu"))

    # high level functionality

    def train_from_dir(self, train_dir, test_dir=None, learning_rate=0.001, num_epochs=200, batch_size=10, criterion=None, model_filename=None):
        """Train from the images directly inside a directory (not subdirectories).

        Args:
            train_dir: Directory holding the training images.
            test_dir: Optional directory of held-out images to score afterwards.
            learning_rate: Learning rate for the Adam optimizer.
            num_epochs: Number of training epochs.
            batch_size: Batch size of the training data loader.
            criterion: Loss callable used for training and for the test score;
                defaults to ``MSEColorfulLoss``.
            model_filename: When given, the trained weights are saved there.

        Returns:
            The test-set loss as a float when ``test_dir`` is given, else ``None``.
        """
        train_data = load_images(get_files_in_dir(train_dir))

        train_loader = torch.utils.data.DataLoader(dataset=train_data, batch_size=batch_size, shuffle=True, num_workers=8)
        self.fit(train_loader, learning_rate=learning_rate, num_epochs=num_epochs, criterion=criterion)

        if model_filename is not None:
            self.save(model_filename)

        if test_dir is not None:
            test_x, test_y = separate_io(load_images(get_files_in_dir(test_dir)))
            # Score with the same criterion that was used for training.
            return self.test(test_x, test_y, criterion=criterion)

        return None

    def colorize(self, img_path, out_path=None, HW=(256, 256)):
        """Colorize an image file and optionally save the result.

        The notebook's demo cells call ``model.colorize(in_path, out_path)``.
        The image is resized to ``HW`` for the network; the predicted ab
        channels are upsampled back to the source resolution and combined with
        the source image's own L channel.

        Args:
            img_path: Path of the image to colorize (color or grayscale).
            out_path: Where to write the colorized image; the parent directory
                is created if needed. Nothing is written when ``None``.
            HW: (height, width) the image is resized to before inference.

        Returns:
            The colorized image as a ``uint8`` RGB array at source resolution.
        """
        # Drop a possible alpha channel; rgb2lab expects three channels.
        img_rgb = load_img(img_path)[:, :, :3]
        height, width = img_rgb.shape[:2]

        tens_l_rs, _ = preprocess_img(img_rgb, HW=HW)
        device = next(self.parameters()).device
        tens_ab = self.use(tens_l_rs.to(device))

        # Bring the low-resolution chroma prediction back to the source size.
        tens_ab = F.interpolate(tens_ab, size=(height, width), mode="bilinear", align_corners=False)

        lab = np.empty((height, width, 3), dtype=np.float64)
        lab[:, :, 0] = color.rgb2lab(img_rgb)[:, :, 0]
        lab[:, :, 1:] = tens_ab[0].cpu().numpy().transpose(1, 2, 0)

        rgb_float = np.clip(color.lab2rgb(lab), 0.0, 1.0)
        img_rgb_out = (rgb_float * 255).round().astype(np.uint8)

        if out_path is not None:
            out_dir = os.path.dirname(out_path)
            if out_dir:
                os.makedirs(out_dir, exist_ok=True)
            Image.fromarray(img_rgb_out).save(out_path)

        return img_rgb_out

In [None]:
# Build the model, then either train it and cache the weights, or reuse the
# cached weights when the file already exists.
weights_file = "flower_colorizer_rmse.pt"
model = Colorizer()

if not path.exists(weights_file):
    # No cached weights yet: train from the image folder and save the result.
    model.train_from_dir(
        "../../Data/Colorization/train",
        learning_rate=0.001,
        num_epochs=200,
        batch_size=100,
        model_filename=weights_file,
    )
else:
    # Cached weights found: load them instead of retraining.
    model.load(weights_file)

In [None]:
# Colorize the four sample flower images. This replaces five copy-pasted cells
# (the first image was processed twice) with one loop over the image index.
for flower_idx in range(1, 5):
    model.colorize(
        "imgs/test_flower_{:02d}.jpg".format(flower_idx),
        "out_imgs/test_flower_{:02d}_colorized.png".format(flower_idx),
    )