In [54]:
import cv2
import numpy as np
import os
import glob
import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt

from PIL import Image
from matplotlib import cm
import torch
import torch.nn as nn  # All neural network modules, nn.Linear, nn.Conv2d, BatchNorm, Loss functions
import torch.optim as optim  # For all Optimization algorithms, SGD, Adam, etc.
import torchvision.transforms as transforms  # Transformations we can perform on our dataset
import torchvision
import os
import pandas as pd

from skimage import io, color
from torch.utils.data import (
    Dataset,
    DataLoader,
)  # Gives easier dataset managment and creates mini batches



In [55]:
!pip install opencv-python
!pip install scikit-image



In [56]:
img_train_dir = os.path.join('./Dataset/face_images/face_images/train/face_images')
img_test_dir = os.path.join('./Dataset/face_images/face_images/test/face_images')
aug_train_dir = os.path.join('./Dataset/augment_images')
lab_train_dir = os.path.join('./Dataset/lab_images')
lab_test_dir = os.path.join('./Dataset/lab_images/test')
l_train_dir = os.path.join('./Dataset/l_images')
l_test_dir = os.path.join('./Dataset/l_images/test')
a_train_dir = os.path.join('./Dataset/a_images/')
a_test_dir = os.path.join('./Dataset/a_images/test')
b_train_dir = os.path.join('./Dataset/b_images/')
b_test_dir = os.path.join('./Dataset/b_images/test')

## Data Loading (Training)

In [57]:
image_list_a_train = glob.glob(os.path.join(a_train_dir,"*.jpg"))
image_list_a_train.sort()
image_list_b_train = glob.glob(os.path.join(b_train_dir,"*.jpg"))
image_list_b_train.sort()
input_image_list_train = glob.glob(os.path.join(l_train_dir,"*.jpg"))
input_image_list_train.sort()

In [58]:
class TrainDataset(Dataset):
    def __init__(self, csv_file, root_dir):
        self.root_dir = root_dir
        self.transform = transforms.Compose([transforms.ToTensor()])
        self.annotations = input_image_list_train
        
    def __len__(self):
        return len(self.annotations)

    def __getitem__(self, index):
        img_path =  input_image_list_train[index]
        image = io.imread(img_path)
        a_path = image_list_a_train[index]
        a = io.imread(a_path)
        b_path = image_list_b_train[index]
        b = io.imread(b_path)
        a = self.transform(a)
        b = self.transform(b)
        image = self.transform(image)

        return (image, a, b)

In [59]:
if torch.cuda.is_available():  
  dev = "cuda:0" 
else:  
  dev = "cpu"  
device = torch.device(dev) 

In [60]:
# Hyperparameters
learning_rate = 1e-7
batch_size = 10
num_epochs = 20

In [61]:
train_loader = DataLoader(TrainDataset(
    csv_file="exp_input.csv",
    root_dir="CAP5404/l_images",
), batch_size=batch_size, shuffle=True)

## Data Loading (Testing)

In [62]:
image_list_a_test = glob.glob(os.path.join(a_test_dir,"*.jpg"))
image_list_a_test.sort()
image_list_b_test = glob.glob(os.path.join(b_test_dir,"*.jpg"))
image_list_b_test.sort()
input_image_list_test = glob.glob(os.path.join(l_test_dir,"*.jpg"))
input_image_list_test.sort()

In [63]:
class TestDataset(Dataset):
    def __init__(self, csv_file, root_dir):
        self.root_dir = root_dir
        self.transform = transforms.Compose([transforms.ToTensor()])
        self.annotations = input_image_list_test
        
    def __len__(self):
        return len(self.annotations)

    def __getitem__(self, index):
        img_path =  input_image_list_test[index]
        image = io.imread(img_path)
        a_path = image_list_a_test[index]
        a = io.imread(a_path)
        b_path = image_list_b_test[index]
        b = io.imread(b_path)
        a = self.transform(a)
        b = self.transform(b)
        image = self.transform(image)

        return (image, a, b)

In [64]:
if torch.cuda.is_available():  
  dev = "cuda:0" 
else:  
  dev = "cpu"  
device = torch.device(dev) 

In [65]:
# Hyperparameters
learning_rate = 1e-7
batch_size = 10
num_epochs = 20

In [66]:
test_loader = DataLoader( TestDataset(
    csv_file="exp_input.csv",
    root_dir="CAP5404/l_images",
), batch_size=batch_size, shuffle=True)

## Model

In [67]:
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models 
class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        
        self.cnn = nn.Sequential(nn.Conv2d(1,3,kernel_size=2, padding=0,stride=2),
                                 nn.ReLU(),
                                nn.Conv2d(3,3,kernel_size=2, padding=0,stride=2),
                                nn.ReLU(),
                                nn.Conv2d(3,3,kernel_size=2, padding=0,stride=2),
                                nn.ReLU(),
                                nn.Conv2d(3,3,kernel_size=2, padding=0,stride=2),
                                nn.ReLU(),
                                nn.Conv2d(3,3,kernel_size=2, padding=0,stride=2),
                                nn.ReLU(),
                                nn.Conv2d(3,3,kernel_size=2, padding=0,stride=2),
                                nn.ReLU(),
                                nn.Conv2d(3, 128, kernel_size=3, stride=1, padding=1),
                                nn.BatchNorm2d(128),
                                nn.ReLU(),
                                nn.Upsample(scale_factor=2),
                                nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1),
                                nn.BatchNorm2d(64),
                                nn.ReLU(),
                                nn.Upsample(scale_factor=2),
                                nn.Conv2d(64, 32, kernel_size=3, stride=1, padding=1),
                                nn.BatchNorm2d(32),
                                nn.ReLU(),
                                nn.Upsample(scale_factor=2),
                                nn.Conv2d(32, 16, kernel_size=3, stride=1, padding=1),
                                nn.BatchNorm2d(16),
                                nn.ReLU(),
                                nn.Upsample(scale_factor=2),
                                nn.Conv2d(16, 9, kernel_size=3, stride=1, padding=1),
                                nn.BatchNorm2d(9),
                                nn.ReLU(),
                                nn.Upsample(scale_factor=2),
                                nn.Conv2d(9, 2, kernel_size=3, stride=1, padding=1),
                                nn.ReLU(),
                                nn.Upsample(scale_factor=2))
    
                                
                                 
    def forward(self, input):
        output = self.cnn(input)
        #         mean_chrominance = self.regressor(output)

        return  output

In [68]:
# Model
model = CNN()
model.to(device)

CNN(
  (cnn): Sequential(
    (0): Conv2d(1, 3, kernel_size=(2, 2), stride=(2, 2))
    (1): ReLU()
    (2): Conv2d(3, 3, kernel_size=(2, 2), stride=(2, 2))
    (3): ReLU()
    (4): Conv2d(3, 3, kernel_size=(2, 2), stride=(2, 2))
    (5): ReLU()
    (6): Conv2d(3, 3, kernel_size=(2, 2), stride=(2, 2))
    (7): ReLU()
    (8): Conv2d(3, 3, kernel_size=(2, 2), stride=(2, 2))
    (9): ReLU()
    (10): Conv2d(3, 3, kernel_size=(2, 2), stride=(2, 2))
    (11): ReLU()
    (12): Conv2d(3, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (14): ReLU()
    (15): Upsample(scale_factor=2.0, mode=nearest)
    (16): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (17): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (18): ReLU()
    (19): Upsample(scale_factor=2.0, mode=nearest)
    (20): Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padd

In [69]:
# Loss and optimizer
loss = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=1e-2)

In [70]:
# Train Network
import torchvision.transforms as transforms
trans = transforms.Compose([transforms.ToTensor()])
Train_losses = []
Test_losses = []
for epoch in range(num_epochs):
    model.train()
    for batch_idx,(data, a , b) in enumerate(train_loader):
        #img_path =  input_image_list[index] 
        data = data.to(device=device)
        target_a = a.to(device=device)
        target_b = b.to(device=device)
        target = torch.cat((a, b), 1)
        # forward
        pred = model(data)
        loss_val = loss(pred, target)

        Train_losses.append(loss_val.item())

        # backward
        optimizer.zero_grad()
        
        loss_val.backward()

        # gradient descent or adam step
        optimizer.step()
    
    model.eval()
    for batch_idx,(data, a , b) in enumerate(test_loader):
        #img_path =  input_image_list[index] 
        data = data.to(device=device)
        target_a = a.to(device=device)
        target_b = b.to(device=device)
        target = torch.cat((a, b), 1)
        # forward
        pred = model(data)
        
        loss_val = loss(pred, target)

        Test_losses.append(loss_val.item())

        # backward
        optimizer.zero_grad()
        
        loss_val.backward()

        # gradient descent or adam step
        optimizer.step()

print(f"Average Train MSE is {sum(Train_losses)/len(Train_losses)}")
print(f"Average Test MSE is {sum(Test_losses)/len(Test_losses)}")

KeyboardInterrupt: 

In [313]:
for batch_idx,(data, a , b) in enumerate(test_loader):
        #img_path =  input_image_list[index] 
        data = data.to(device=device)
        target_a = a.to(device=device)
        target_b = b.to(device=device)
        # forward
        pred = model(data)
        save_predict_image(pred, data, batch_idx)

<Figure size 432x288 with 0 Axes>

In [314]:
def to_rgb(grayscale_input, ab_input, save_path=None, save_name=None):
    '''Show/save rgb image from grayscale and ab channels
    Input save_path in the form {'grayscale': '/path/', 'colorized': '/path/'}'''
    plt.clf() # clear matplotlib 
    color_image = torch.cat((grayscale_input, ab_input), 0).detach().numpy() # combine channels
    color_image = color_image + 1
    color_image = color_image - color_image.min()
    color_image = color_image / (color_image.max() - color_image.min())
    color_image = color_image.transpose((1, 2, 0))  # rescale for matplotlib
    color_image[:, :, 0:1] = color_image[:, :, 0:1] * 100
    color_image[:, :, 1:3] = color_image[:, :, 1:3] * 255 - 128
    color_image = color.lab2rgb(color_image.astype(np.float64))
    grayscale_input = grayscale_input.squeeze().numpy()
    plt.imsave(arr=grayscale_input, fname='{}{}'.format(save_path + "grayscale", save_name), cmap='gray')
    plt.imsave(arr=color_image, fname='{}{}'.format(save_path, save_name))

In [315]:
def save_predict_image(pred, data, batch_idx):
    for i in range(len(data)):
        name = str(batch_idx) + " " + str(i)
        to_rgb(data[i], pred[i], save_path = "./predicted_images/", save_name = f"{name}.jpg")

In [305]:
model = torch.load("./colorizer_250.pt", map_location=torch.device('cpu'))

# NCD

In [316]:
ncd_test = glob.glob(os.path.join("./Dataset/Gray/Cherry/","*.jpg"))
ncd_test.sort()
ncd_original = glob.glob(os.path.join("./Dataset/ColorfulOriginal/Cherry/","*.jpg"))
ncd_original.sort()

In [317]:
class NCDTestDataset(Dataset):
    def __init__(self, csv_file, root_dir):
        self.root_dir = root_dir
        self.transform = transforms.Compose([transforms.ToTensor()])
        self.annotations = ncd_test
        
    def __len__(self):
        return len(self.annotations)

    def __getitem__(self, index):
        img_path = ncd_test[index]
        img_path_o = ncd_original[index]
        T = transforms.Resize((128, 128))
        image = Image.open(img_path)
        image_o = Image.open(img_path_o)
        image = T(image)
        image_o = T(image_o)
        image = self.transform(image)
        image_o = self.transform(image_o)
        return (image, image_o)

In [318]:
ncd_test_loader = DataLoader(NCDTestDataset(
    csv_file="exp_input.csv",
    root_dir="CAP5404/l_images",
), batch_size=batch_size, shuffle=True)

In [319]:
for b, (image, image_o) in enumerate(ncd_test_loader):
    pred = model(image)
    
    save_predict_image_ncd(pred, image, b, image_o)

<Figure size 432x288 with 0 Axes>

In [320]:
from skimage.metrics import structural_similarity as ssim

In [321]:
def to_rgb_ncd(grayscale_input, ab_input, original, save_path=None, save_name=None):
    '''Show/save rgb image from grayscale and ab channels
    Input save_path in the form {'grayscale': '/path/', 'colorized': '/path/'}'''
    plt.clf() # clear matplotlib 
    color_image = torch.cat((grayscale_input, ab_input), 0).detach().numpy() # combine channels
    color_image = color_image + 1
    color_image = color_image - color_image.min()
    color_image = color_image / (color_image.max() - color_image.min())
    color_image = color_image.transpose((1, 2, 0))  # rescale for matplotlib
    color_image[:, :, 0:1] = color_image[:, :, 0:1] * 100
    color_image[:, :, 1:3] = color_image[:, :, 1:3] * 255 - 128
    color_image = color.lab2rgb(color_image.astype(np.float64))
    grayscale_input = grayscale_input.squeeze().numpy()
    original = original.numpy()
    original = original.transpose()

    plt.imsave(arr=grayscale_input, fname='{}{}'.format(save_path + "grayscale", save_name), cmap='gray')
    plt.imsave(arr=original, fname='{}{}'.format(save_path + "original", save_name))
    plt.imsave(arr=color_image, fname='{}{}'.format(save_path, save_name))

In [124]:
def save_predict_image_ncd(pred, data, batch_idx, image_o):
    for i in range(len(data)):
        name = str(batch_idx) + " " + str(i)
        to_rgb_ncd(data[i], pred[i], image_o[i], save_path = "./ncd_predicted_images/", save_name = f"{name}.jpg")

In [289]:
original = glob.glob(os.path.join(r"./ncd_predicted_images/original*.jpg"))
predicted = glob.glob(os.path.join(r"./ncd_predicted_images/[0,1,2]*.jpg"))

In [290]:
predicted.sort()
original.sort()

In [292]:
len(predicted)

29

In [293]:
for i in range(len(original)):
    a = Image.open(original[i])
    a = np.array(a)
    b = Image.open(predicted[i])
    b = np.array(b)
    print(ssim(a, b, channel_axis = -1))

0.3580939012461708
0.4464937619769425
0.36875340394001227
0.34434305279784244
0.41021283981408335
0.4150615420542478
0.37169179272873526
0.2727458441569837
0.3770320595549275
0.3308623104401629
0.41748914096369893
0.45059996720694223
0.36869617798962323
0.44936479363564086
0.42191882162953426
0.5533975959985756
0.598348637153439
0.5852125063175916
0.5530422724326791
0.3044505503521544
0.2423542802433228
0.31980157753880206
0.3392061665853175
0.40082590219422776
0.40835275179971076
0.3649599161375409
0.3135458061661558
0.3035956974943903
0.4316711712778649


IndexError: list index out of range