In [1]:
from __future__ import print_function
from __future__ import division
import torch
torch.manual_seed(0)
import torch.nn as nn
import torch.optim as optim
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
from torch.utils.data import DataLoader, Dataset
import matplotlib.pyplot as plt
%matplotlib inline
import time
import os
import copy
import pandas as pd
print("PyTorch Version: ",torch.__version__)
print("Torchvision Version: ",torchvision.__version__)

  return f(*args, **kwds)


PyTorch Version:  0.4.1
Torchvision Version:  0.2.1


In [2]:
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

### Dataset

In [3]:
b_sz = 4

In [4]:
class ColorizeDataset(Dataset):

    def __init__(self, root_dir, transform=None):
        """
        Args:
            root_dir (string): Directory with all the images; make sure the directory 
                only contains images.
            transform (callable, optional): Optional transform to be applied
                on a sample.
        """
        self.root_dir = root_dir
        self.transform = transform
        self._files = [x for x in os.listdir(self.root_dir) if x.find('.ipynb') == -1]

    def __len__(self):
        return len(self._files) 

    def __getitem__(self, idx):
        img_name = os.path.join(self.root_dir, self._files[idx])
        sample = Image.open(img_name)
        if self.transform:
            sample = self.transform(sample)

        return sample[0,:,:], sample[1:,:,:]/128

In [5]:
from PIL import Image

In [6]:
from skimage.color import rgb2lab, lab2rgb, rgb2gray, xyz2lab
from skimage.io import imsave

In [7]:
def transform_rgb2lab(image):
    image = image.convert("RGB")
    image = np.asarray(image)
    image = rgb2lab(image/255)
    return image

In [8]:
from torchvision.transforms import Lambda
from torchvision.transforms import RandomAffine
from torchvision.transforms import RandomHorizontalFlip
from torchvision.transforms import Resize

In [9]:
# shear, zoom, rotation and horizontal flip
transform = transforms.Compose([RandomAffine(degrees=0.2, shear=0.2),
                                RandomHorizontalFlip(p=0.5),
                                Resize((256, 256)),
                                Lambda(lambda image: transform_rgb2lab(image)),
                                transforms.ToTensor(),
                               ])

In [10]:
dataset = ColorizeDataset('../Full-version/Train/', transform)

TODO: Resolve error `num_workers` > 1

In [11]:
dataloader = DataLoader(dataset, batch_size=b_sz, shuffle=True, num_workers=1)

### Network

In [12]:
import torch.nn.functional as F

In [13]:
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.xavier_uniform_(m.weight.data)
        nn.init.constant_(m.bias.data, 0)

In [14]:
class ColorNetBeta(nn.Module):
    def __init__(self):
        super(ColorNetBeta, self).__init__()
        self.conv1 = nn.Conv2d(1, 64, (3,3), stride=1)
        self.conv2 = nn.Conv2d(64, 64, (3,3), stride=2)
        self.conv3 = nn.Conv2d(64, 128, (3,3), stride=1)
        self.conv4 = nn.Conv2d(128, 128, (3,3), stride=2)
        self.conv5 = nn.Conv2d(128, 256, (3,3), stride=1)
        self.conv6 = nn.Conv2d(256, 256, (3,3), stride=2)
        self.conv7 = nn.Conv2d(256, 512, (3,3), stride=1)
        self.conv8 = nn.Conv2d(512, 256, (3,3), stride=1)
        self.conv9 = nn.Conv2d(256, 128, (3,3), stride=1)
        self.upsample10 = nn.Upsample(scale_factor=(2,2))
        self.conv11 = nn.Conv2d(128, 64, (3,3), stride=1)
        self.upsample12 = nn.Upsample(scale_factor=(2,2))
        self.conv13 = nn.Conv2d(64, 32, (3,3), stride=1)
        self.conv14 = nn.Conv2d(32, 2, (3,3), stride=1)
        self.upsample15 = nn.Upsample(scale_factor=(2,2))
   
    def same_pad(self, input, k=(3,3), d=(1,1), s=(1,1)):
        
        i = (input.size(-2), input.size(-1))
        # i = (i_H, i_W)
        # k = (k_H, k_W)
        # d = (d, d); dilation
        # s = (s, s); stride
        
        # tensorflow style - same padding output calculation
        calc_eff_k = lambda k, d: (k - 1) * d + 1
        k = tuple([calc_eff_k(x, y) for x, y in zip(k, d)])
        calc_o = lambda i, s: np.ceil(i/s)
        o = tuple([calc_o(x, y) for x, y in zip(i, s)])
        calc_p = lambda o, s, k, i: max(0, (o-1)*s + k-i)
        p = tuple([calc_p(w, x, y, z) for w, x, y, z in zip(o, s, k, i)])
                               
        # left, right , up , bottom
        padding = [p[1]//2, p[1]//2 + p[1]%2, p[0]//2, p[0]//2 + p[0]%2]
        padding = [int(p) for p in padding]
        return F.pad(input, padding)

    def forward(self, x):
        x = F.relu(self.conv1(self.same_pad(x)))
        x = F.relu(self.conv2(self.same_pad(x, s=(2, 2))))
        x = F.relu(self.conv3(self.same_pad(x)))
        x = F.relu(self.conv4(self.same_pad(x, s=(2, 2))))
        x = F.relu(self.conv5(self.same_pad(x)))
        x = F.relu(self.conv6(self.same_pad(x, s=(2, 2))))
        x = F.relu(self.conv7(self.same_pad(x)))
        x = F.relu(self.conv8(self.same_pad(x)))
        x = F.relu(self.conv9(self.same_pad(x)))
        x = self.upsample10(x)
        x = F.relu(self.conv11(self.same_pad(x)))
        x = self.upsample12(x)
        x = F.relu(self.conv13(self.same_pad(x)))
        x = F.tanh(self.conv14(self.same_pad(x)))
        x = self.upsample15(x)
        
        return x

In [15]:
device = torch.device("cuda:0" if (torch.cuda.is_available()) else "cpu")

In [16]:
net = ColorNetBeta()

In [17]:
net=net.apply(weights_init)

In [18]:
net=net.to(device)

In [19]:
criterion = nn.MSELoss()

In [20]:
optimizer = optim.RMSprop(net.parameters(), lr=0.0001, momentum=0.9)

### Training Loop

In [21]:
from tqdm import tqdm

In [22]:
for epoch in tqdm(range(200)):  
    running_loss = 0.0
    # get the inputs
    for X, Y in dataloader:
        bz = X.shape[0]
        X, Y = X.view(bz, -1, 256, 256), Y.view(bz, -1, 256, 256)
        inputs, labels = X.float().to(device), Y.float().to(device)
        # zero the parameter gradients
        optimizer.zero_grad()
        # forward + backward + optimize
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

    # print statistics
    running_loss += loss.item()
    if epoch % 10 == 9:    # print every 10 epochs
        print(f'epoch: {epoch + 1}, loss: {running_loss / 4}')
        running_loss = 0.0

print('Finished Training')

  5%|▌         | 10/200 [00:07<02:30,  1.26it/s]

epoch: 10, loss: 0.0025241198018193245


 10%|█         | 20/200 [00:15<02:20,  1.28it/s]

epoch: 20, loss: 0.001218198100104928


 15%|█▌        | 30/200 [00:23<02:12,  1.29it/s]

epoch: 30, loss: 0.001563854282721877


 20%|██        | 40/200 [00:30<02:03,  1.29it/s]

epoch: 40, loss: 0.0017088993918150663


 25%|██▌       | 50/200 [00:38<01:56,  1.29it/s]

epoch: 50, loss: 0.0016615402419120073


 30%|███       | 60/200 [00:46<01:48,  1.29it/s]

epoch: 60, loss: 0.0017109338659793139


 35%|███▌      | 70/200 [00:54<01:40,  1.29it/s]

epoch: 70, loss: 0.000922243227250874


 40%|████      | 80/200 [01:01<01:32,  1.29it/s]

epoch: 80, loss: 0.0017484532436355948


 45%|████▌     | 90/200 [01:09<01:24,  1.29it/s]

epoch: 90, loss: 0.0010772462701424956


 50%|█████     | 100/200 [01:17<01:17,  1.30it/s]

epoch: 100, loss: 0.0006255232729017735


 55%|█████▌    | 110/200 [01:24<01:09,  1.30it/s]

epoch: 110, loss: 0.0006199725903570652


 60%|██████    | 120/200 [01:32<01:01,  1.30it/s]

epoch: 120, loss: 0.0019473774591460824


 65%|██████▌   | 130/200 [01:40<00:53,  1.30it/s]

epoch: 130, loss: 0.001128384843468666


 70%|███████   | 140/200 [01:47<00:46,  1.30it/s]

epoch: 140, loss: 0.0008004392730072141


 75%|███████▌  | 150/200 [01:55<00:38,  1.30it/s]

epoch: 150, loss: 0.0011302819475531578


 80%|████████  | 160/200 [02:03<00:30,  1.30it/s]

epoch: 160, loss: 0.0005982593866065145


 85%|████████▌ | 170/200 [02:11<00:23,  1.30it/s]

epoch: 170, loss: 0.000842070730868727


 90%|█████████ | 180/200 [02:18<00:15,  1.30it/s]

epoch: 180, loss: 0.0003926165518350899


 95%|█████████▌| 190/200 [02:26<00:07,  1.30it/s]

epoch: 190, loss: 0.00070851860800758


100%|██████████| 200/200 [02:34<00:00,  1.30it/s]

epoch: 200, loss: 0.0004900220083072782
Finished Training





### Test

In [27]:
class ColorizeDataset(Dataset):

    def __init__(self, root_dir, transform=None):
        """
        Args:
            root_dir (string): Directory with all the images; make sure the directory 
                only contains images.
            transform (callable, optional): Optional transform to be applied
                on a sample.
        """
        self.root_dir = root_dir
        self.transform = transform
        self._files = [x for x in os.listdir(self.root_dir) if x.find('.ipynb') == -1]

    def __len__(self):
        return len(self._files) 

    def __getitem__(self, idx):
        img_name = os.path.join(self.root_dir, self._files[idx])
        sample = Image.open(img_name)
        if self.transform:
            sample = self.transform(sample)

        return img_name, sample[0,:,:], sample[1:,:,:]/128 # added file_name

In [28]:
test_dataset = ColorizeDataset('../Full-version/Test/', transform)

In [29]:
test_dataloader = DataLoader(test_dataset, shuffle=False)

In [30]:
save_dir = './result_pytorch'
if not os.path.exists(save_dir): os.makedirs(save_dir)

In [37]:
with torch.no_grad():
    for img_name, X, Y in test_dataloader:
        bz = X.shape[0] # bz should be 1
        X, Y = X.view(bz, -1, 256, 256), Y.view(bz, -1, 256, 256)
        inputs, labels = X.float().to(device), Y.float().to(device)
        # inference
        preds = net(inputs)
        loss = criterion(preds, labels)
        print(f'loss: {loss}')
        ab = np.transpose(preds.cpu().numpy(), (0, 2, 3, 1)) * 128
        L = np.transpose(inputs.cpu().numpy(), (0, 2, 3, 1))

        output = np.empty(shape=(256, 256, 3))
        output[:,:,0] = np.squeeze(L)
        output[:,:,1:] = np.squeeze(ab)
        imsave(f"{os.path.join(save_dir, os.path.basename(img_name[0]))[:-4]}.png", lab2rgb(output))

  .format(dtypeobj_in, dtypeobj_out))


loss: 0.0018670042045414448
loss: 0.0025375480763614178
loss: 0.006047520786523819
loss: 0.0003591009881347418
loss: 0.016005825251340866
loss: 0.002735728397965431
loss: 0.005061606410890818
loss: 0.008008704520761967
