In [1]:
import cv2
import numpy as np
import os
import glob
import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt

from PIL import Image
from matplotlib import cm
import torch
import torch.nn as nn  # All neural network modules, nn.Linear, nn.Conv2d, BatchNorm, Loss functions
import torch.optim as optim  # For all Optimization algorithms, SGD, Adam, etc.
import torchvision.transforms as transforms  # Transformations we can perform on our dataset
import torchvision
import os
import pandas as pd

from skimage import io, color
from torch.utils.data import (
    Dataset,
    DataLoader,
)  # Gives easier dataset managment and creates mini batches



In [2]:
!pip install opencv-python
!pip install scikit-image

Defaulting to user installation because normal site-packages is not writeable
Defaulting to user installation because normal site-packages is not writeable


In [3]:
img_train_dir = os.path.join('./Dataset/face_images/face_images/train/face_images')
img_test_dir = os.path.join('./Dataset/face_images/face_images/test/face_images')
aug_train_dir = os.path.join('./Dataset/augment_images')
lab_train_dir = os.path.join('./Dataset/lab_images')
lab_test_dir = os.path.join('./Dataset/lab_images/test')
l_train_dir = os.path.join('./Dataset/l_images')
l_test_dir = os.path.join('./Dataset/l_images/test')
a_train_dir = os.path.join('./Dataset/a_images/')
a_test_dir = os.path.join('./Dataset/a_images/test')
b_train_dir = os.path.join('./Dataset/b_images/')
b_test_dir = os.path.join('./Dataset/b_images/test')

## Data Loading (Training)

In [4]:
image_list_a_train = glob.glob(os.path.join(a_train_dir,"*.jpg"))
image_list_a_train.sort()
image_list_b_train = glob.glob(os.path.join(b_train_dir,"*.jpg"))
image_list_b_train.sort()
input_image_list_train = glob.glob(os.path.join(l_train_dir,"*.jpg"))
input_image_list_train.sort()

In [5]:
class TrainDataset(Dataset):
    def __init__(self, csv_file, root_dir):
        self.root_dir = root_dir
        self.transform = transforms.Compose([transforms.ToTensor()])
        self.annotations = input_image_list_train
        
    def __len__(self):
        return len(self.annotations)

    def __getitem__(self, index):
        img_path =  input_image_list_train[index]
        image = io.imread(img_path)
        a_path = image_list_a_train[index]
        a = io.imread(a_path)
        b_path = image_list_b_train[index]
        b = io.imread(b_path)
        a = self.transform(a)
        b = self.transform(b)
        image = self.transform(image)

        return (image, a, b)

In [7]:
# Hyperparameters
learning_rate = 1e-7
batch_size = 10
num_epochs = 10

In [8]:
train_loader = DataLoader(TrainDataset(
    csv_file="exp_input.csv",
    root_dir="CAP5404/l_images",
), batch_size=batch_size, shuffle=True)

## Data Loading (Testing)

In [9]:
image_list_a_test = glob.glob(os.path.join(a_test_dir,"*.jpg"))
image_list_a_test.sort()
image_list_b_test = glob.glob(os.path.join(b_test_dir,"*.jpg"))
image_list_b_test.sort()
input_image_list_test = glob.glob(os.path.join(l_test_dir,"*.jpg"))
input_image_list_test.sort()

In [10]:
class TestDataset(Dataset):
    def __init__(self, csv_file, root_dir):
        self.root_dir = root_dir
        self.transform = transforms.Compose([transforms.ToTensor()])
        self.annotations = input_image_list_test
        
    def __len__(self):
        return len(self.annotations)

    def __getitem__(self, index):
        img_path =  input_image_list_test[index]
        image = io.imread(img_path)
        a_path = image_list_a_test[index]
        a = io.imread(a_path)
        b_path = image_list_b_test[index]
        b = io.imread(b_path)
        a = self.transform(a)
        b = self.transform(b)
        image = self.transform(image)

        return (image, a, b)

In [11]:
# Hyperparameters
learning_rate = 1e-7
batch_size = 10
num_epochs = 20

In [12]:
test_loader = DataLoader( TestDataset(
    csv_file="exp_input.csv",
    root_dir="CAP5404/l_images",
), batch_size=batch_size, shuffle=True)

## Model

In [13]:
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models 
class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        
        self.cnn = nn.Sequential(nn.Conv2d(1,3,kernel_size=2, padding=0,stride=2),
                                 nn.ReLU(),
                                nn.Conv2d(3,3,kernel_size=2, padding=0,stride=2),
                                nn.ReLU(),
                                nn.Conv2d(3,3,kernel_size=2, padding=0,stride=2),
                                nn.ReLU(),
                                nn.Conv2d(3,3,kernel_size=2, padding=0,stride=2),
                                nn.ReLU(),
                                nn.Conv2d(3,3,kernel_size=2, padding=0,stride=2),
                                nn.ReLU(),
                                nn.Conv2d(3,3,kernel_size=2, padding=0,stride=2),
                                nn.ReLU(),
                                nn.Conv2d(3, 128, kernel_size=3, stride=1, padding=1),
                                nn.BatchNorm2d(128),
                                nn.ReLU(),
                                nn.Upsample(scale_factor=2),
                                nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1),
                                nn.BatchNorm2d(64),
                                nn.ReLU(),
                                nn.Upsample(scale_factor=2),
                                nn.Conv2d(64, 32, kernel_size=3, stride=1, padding=1),
                                nn.BatchNorm2d(32),
                                nn.ReLU(),
                                nn.Upsample(scale_factor=2),
                                nn.Conv2d(32, 16, kernel_size=3, stride=1, padding=1),
                                nn.BatchNorm2d(16),
                                nn.ReLU(),
                                nn.Upsample(scale_factor=2),
                                nn.Conv2d(16, 9, kernel_size=3, stride=1, padding=1),
                                nn.BatchNorm2d(9),
                                nn.ReLU(),
                                nn.Upsample(scale_factor=2),
                                nn.Conv2d(9, 2, kernel_size=3, stride=1, padding=1),
                                nn.ReLU(),
                                nn.Upsample(scale_factor=2))
    
                                
                                 
    def forward(self, input):
        output = self.cnn(input)
        #         mean_chrominance = self.regressor(output)

        return  output

In [14]:
# Model
model = CNN()

In [15]:
# Loss and optimizer
loss = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=1e-2)

In [17]:
from datetime import datetime as dt

start = dt.now()

In [18]:
# Train Network
import torchvision.transforms as transforms
trans = transforms.Compose([transforms.ToTensor()])
Train_losses = []
Test_losses = []
for epoch in range(num_epochs):
    model.train()
    for batch_idx,(data, a , b) in enumerate(train_loader):
        #img_path =  input_image_list[index] 
        data = data
        target_a = a
        target_b = b
        target = torch.cat((a, b), 1)
        # forward
        pred = model(data)
        loss_val = loss(pred, target)

        Train_losses.append(loss_val.item())

        # backward
        optimizer.zero_grad()
        
        loss_val.backward()

        # gradient descent or adam step
        optimizer.step()
    
    model.eval()
    for batch_idx,(data, a , b) in enumerate(test_loader):
        #img_path =  input_image_list[index] 
        data = data
        target_a = a
        target_b = b
        target = torch.cat((a, b), 1)
        # forward
        pred = model(data)
        
        loss_val = loss(pred, target)

        Test_losses.append(loss_val.item())

        # backward
        optimizer.zero_grad()
        
        loss_val.backward()

        # gradient descent or adam step
        optimizer.step()

print(f"Average Train MSE is {sum(Train_losses)/len(Train_losses)}")
print(f"Average Test MSE is {sum(Test_losses)/len(Test_losses)}")

Average Train MSE is 0.005564737640198372
Average Test MSE is 0.0008029561029616161


In [19]:
running_secs = (dt.now() - start).seconds

In [20]:
running_secs

487

In [25]:
torch.save(model, 'colorizer_cpu_model.pt')

In [29]:
for batch_idx,(data, a , b) in enumerate(test_loader):
        #img_path =  input_image_list[index] 
        data = data
        target_a = a
        target_b = b
        # forward
        pred = model(data)
        save_predict_image(pred, data, batch_idx)

  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)


<Figure size 432x288 with 0 Axes>

In [27]:
def to_rgb(grayscale_input, ab_input, save_path=None, save_name=None):
    '''Show/save rgb image from grayscale and ab channels
    Input save_path in the form {'grayscale': '/path/', 'colorized': '/path/'}'''
    plt.clf() # clear matplotlib 
    color_image = torch.cat((grayscale_input, ab_input), 0).detach().numpy() # combine channels
    color_image = color_image + 1
    color_image = color_image - color_image.min()
    color_image = color_image / (color_image.max() - color_image.min())
    color_image = color_image.transpose((1, 2, 0))  # rescale for matplotlib
    color_image[:, :, 0:1] = color_image[:, :, 0:1] * 100
    color_image[:, :, 1:3] = color_image[:, :, 1:3] * 255 - 128
    color_image = color.lab2rgb(color_image.astype(np.float64))
    grayscale_input = grayscale_input.squeeze().numpy()
    plt.imsave(arr=grayscale_input, fname='{}{}'.format(save_path + "grayscale", save_name), cmap='gray')
    plt.imsave(arr=color_image, fname='{}{}'.format(save_path, save_name))

In [26]:
def save_predict_image(pred, data, batch_idx):
    for i in range(len(data)):
        name = str(batch_idx) + " " + str(i)
        to_rgb(data[i], pred[i], save_path = "./predicted_images_cpu/", save_name = f"{name}.jpg")