# Automatic Image Colorization

Based on [Let there be Color!](http://hi.cs.waseda.ac.jp/~iizuka/projects/colorization/en/)

# !!!!!!!!!!!!!!!!!!!!!!!!!
# DEKLARACJA !!!
# PRACY INŻ !!!!!!!!
# DO PIĄTKU !!!!!!!!
# !!!!!!!!!!!!!!!!!!!!!!!!!

In [1]:
import os
import time

import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

from skimage import color, io

torch.set_default_tensor_type('torch.FloatTensor')

### Hyperparameters

In [3]:
BATCH_SIZE = 8
EPOCHS = 100
img_dir_train ='./data/food41-120-train/'
img_dir_test = './data/food41-120-test/'

# TODO

- [x] convert dataset to tensor? (now: numpy array)
 - seems like dataloader does it automatically - but doesnt swap axes
 - transforms.toTensor does the job
- [x] normialize data?
 - only $ab$ channel + L channel
- [ ] transform images? - random crops itp
- [ ] deal with `torch.set_default_tensor_type('torch.DoubleTensor')` or `torch.FloatTensor`
- [x] refector asserts so they work with batch sizes different than 4
- [ ] shuffle data (trainloader)
- [ ] refactor?? - separate files
- [x] load / save model
- [x] extract hyperparameters
- [x] load whole dataset to memory
- [ ] read/write to zip
- [ ] 

### Custom dateset

- One needs to [swap axes](https://pytorch.org/tutorials/beginner/data_loading_tutorial.html#transforms)
    - numpy image: H x W x C
    - torch image: C X H X W


- Quick fact: $L \in [0, 100]$, $a \in [-127, 128]$, $b \in [-128, 127]$ [(source)](https://stackoverflow.com/questions/25294141/cielab-color-range-for-scikit-image)


- How an image is processed:
    1. Load i-th image to memory
    2. Convert it to LAB Space
    3. Convert to torch.tensor
    4. Split the image to $L$ and $ab$ channels
    5. Normalize $ab$ to $[0, 1]$. $L$ ~~remains unnormalized.~~ is normalized as well.
    6. $L$ will feed net, $a'b'$ will be its output
    7. Calculate $Loss(ab, a'b')$

In [4]:
class ImagesDateset(Dataset):
    
    def __init__(self, img_dir, all2mem=True):
        """
        All images from `img_dir` will be read.
        """
        self.img_dir = img_dir
        self.all2mem = all2mem
        self.img_names = [file for file in os.listdir(self.img_dir)]
        
        assert all([img.endswith('.jpg') for img in self.img_names]), "Must be *.jpg"
        
        if self.all2mem:
            self.images = [io.imread(os.path.join(self.img_dir, img)) 
                           for img in self.img_names]
        
    
    def __len__(self):
        return len(self.img_names)
    
   
    
    def __getitem__(self, idx):
        """
        Get an image in Lab color space.
        Returns a tuple (L, ab, name)
            - `L` stands for lightness - it's the net input
            - `ab` is chrominance - something that the net learns
            - `name` - image filename
        """
        
        if self.all2mem:
            image = self.images[idx]
        else:
            img_name = os.path.join(self.img_dir, self.img_names[idx])
            image = io.imread(img_name)
        
        
        assert image.shape == (224, 224, 3)
                
        img_lab = color.rgb2lab(image)
        img_lab = np.transpose(img_lab, (2, 0, 1))
        
        assert img_lab.shape == (3, 224, 224)
        
        img_lab = torch.tensor( img_lab.astype(np.float32) )
        
        assert img_lab.shape == (3, 224, 224)
               
        L  = img_lab[:1,:,:]
        ab = img_lab[1:,:,:]
        
        # Normalization
        L =   L / 100.0         # 0..1
        ab = (ab + 128.0) / 255.0 # 0..1
              
        assert L.shape == (1, 224, 224)
        assert ab.shape == (2, 224, 224)
        
        return L, ab, self.img_names[idx]   

### Load data

In [5]:
trainset = ImagesDateset(img_dir_train, all2mem=True)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=BATCH_SIZE,
                                          shuffle=False, num_workers=4)

testset = ImagesDateset(img_dir_test, all2mem=True)
testloader = torch.utils.data.DataLoader(testset, batch_size=BATCH_SIZE,
                                          shuffle=False, num_workers=0)

### The Network

In [6]:
class ColNet(nn.Module):
    def __init__(self):

        super(ColNet, self).__init__()
        
        ksize = np.array( [1, 64, 128, 128, 256, 256, 512, 512, 256, 128, 64, 64, 32] ) // 8
        ksize[0] = 1
        
        # 'Low-level features'
        self.conv1 = nn.Conv2d(in_channels=1,        out_channels=ksize[1], kernel_size=3, stride=2, padding=1)
        self.conv2 = nn.Conv2d(in_channels=ksize[1], out_channels=ksize[2], kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(in_channels=ksize[2], out_channels=ksize[3], kernel_size=3, stride=2, padding=1)
        self.conv4 = nn.Conv2d(in_channels=ksize[3], out_channels=ksize[4], kernel_size=3, stride=1, padding=1)
        self.conv5 = nn.Conv2d(in_channels=ksize[4], out_channels=ksize[5], kernel_size=3, stride=2, padding=1)
        self.conv6 = nn.Conv2d(in_channels=ksize[5], out_channels=ksize[6], kernel_size=3, stride=1, padding=1)
        
        # 'Mid-level fetures'
        self.conv7 = nn.Conv2d(in_channels=ksize[6], out_channels=ksize[7], kernel_size=3, stride=1, padding=1)
        self.conv8 = nn.Conv2d(in_channels=ksize[7], out_channels=ksize[8], kernel_size=3, stride=1, padding=1)
        
        # 'Colorization network'
        self.conv9 = nn.Conv2d(in_channels=ksize[8], out_channels=ksize[9], kernel_size=3, stride=1, padding=1)
        
        # Here comes upsample #1
        
        self.conv10 = nn.Conv2d(in_channels=ksize[9], out_channels=ksize[10], kernel_size=3, stride=1, padding=1)
        self.conv11 = nn.Conv2d(in_channels=ksize[10],out_channels=ksize[11], kernel_size=3, stride=1, padding=1)
        
        # Here comes upsample #2        
        
        self.conv12 = nn.Conv2d(in_channels=ksize[11], out_channels=ksize[12], kernel_size=3, stride=1, padding=1)
        self.conv13out = nn.Conv2d(in_channels=ksize[12], out_channels=2, kernel_size=3, stride=1, padding=1)
        
        
    def forward(self, x):
        
        # Low level
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        x = F.relu(self.conv4(x))
        x = F.relu(self.conv5(x))
        x = F.relu(self.conv6(x))
        
        
        # Mid level
        x = F.relu(self.conv7(x))
        x = F.relu(self.conv8(x))
        
        # assert x.shape[1:] == (256, 28, 28), "おわり： mid level"
        

        # Colorization Net
        x = F.relu(self.conv9(x))
        
        # assert x.shape[1:] == (128, 28, 28), "おわり： conv9"
        
        x = nn.functional.interpolate(input=x, scale_factor=2, mode='nearest')

        # assert x.shape[1:] == (128, 56, 56), "おわり： upsample1"
    
        x = F.relu(self.conv10(x))
        x = F.relu(self.conv11(x))
        
        x = nn.functional.interpolate(input=x, scale_factor=2, mode='nearest')


        x = F.relu(self.conv12(x))
        x = torch.sigmoid(self.conv13out(x))
        
        x = nn.functional.interpolate(input=x, scale_factor=2, mode='nearest')
        
        # assert x.shape[1:] == (2, 224, 224)
        
        return x
        

### Utilities

Unnormalize and return RGB image

In [7]:
def net_out2rgb(L, ab_out):
    """
    L - original `L` channel
    ab_out - learned `ab` channels which were the net's output
    
    Retruns: 3 channel RGB image
    """
    # Convert to numpy and unnnormalize
    L = L.numpy() * 100.0
    
    ab_out = np.floor(ab_out.numpy() * 255.0) - 128.0 
    
    # Transpose axis to HxWxC again
    L = L.transpose((1, 2, 0))
    ab_out = ab_out.transpose((1, 2, 0))

    # Stack layers
    lab_stack = np.dstack((L, ab_out))
    
    return color.lab2rgb(lab_stack)

### Training

In [8]:
net = ColNet()
criterion = nn.MSELoss()
optimizer = optim.Adam(net.parameters(), lr=0.001)

model_filename = ""

for epoch in range(EPOCHS): 

    epoch_loss = 0.0
    for i, data in enumerate(trainloader, 0):

        L, ab, _ = data
        
        optimizer.zero_grad()

        ab_outputs = net(L)
        
        loss = criterion(ab, ab_outputs)
        loss.backward()
        optimizer.step()

        running_loss = loss.item()
        
        print('[epoch: {} | batch: {}] loss: {:.3f}'
              .format(epoch + 1, i + 1, running_loss / BATCH_SIZE))
    
        epoch_loss += running_loss
    
    print('mean epoch loss: {:.2f}'.format(epoch_loss / (BATCH_SIZE * len(trainloader))))
        
    # Save
    model_filename = './model/colnet{}.pt'.format(time.strftime("%y%m%d-%H-%M-%S"))
    # model_filename = './model/colnet.pt'
    torch.save(net.state_dict(), model_filename)
    print('saved model to {}'.format(model_filename))
    
    print('End of epoch {}\n'.format(epoch + 1))

print('Finished Training')

[epoch: 1 | batch: 1] loss: 1725.632
[epoch: 1 | batch: 2] loss: 1774.514
[epoch: 1 | batch: 3] loss: 1319.859
[epoch: 1 | batch: 4] loss: 1445.086
[epoch: 1 | batch: 5] loss: 1338.139
[epoch: 1 | batch: 6] loss: 1581.826
[epoch: 1 | batch: 7] loss: 1346.652
[epoch: 1 | batch: 8] loss: 1549.567
[epoch: 1 | batch: 9] loss: 1141.875
[epoch: 1 | batch: 10] loss: 1079.381
[epoch: 1 | batch: 11] loss: 1222.143
[epoch: 1 | batch: 12] loss: 742.058
[epoch: 1 | batch: 13] loss: 455.419
mean epoch loss: 1286.32
saved model to ./model/colnet181105-21-29-32.pt
End of epoch 1

[epoch: 2 | batch: 1] loss: 854.145
[epoch: 2 | batch: 2] loss: 816.064
[epoch: 2 | batch: 3] loss: 655.662
[epoch: 2 | batch: 4] loss: 700.462
[epoch: 2 | batch: 5] loss: 748.077
[epoch: 2 | batch: 6] loss: 910.888
[epoch: 2 | batch: 7] loss: 779.830
[epoch: 2 | batch: 8] loss: 891.541
[epoch: 2 | batch: 9] loss: 595.433
[epoch: 2 | batch: 10] loss: 613.256
[epoch: 2 | batch: 11] loss: 574.222
[epoch: 2 | batch: 12] loss: 6

[epoch: 15 | batch: 9] loss: 460.988
[epoch: 15 | batch: 10] loss: 579.382
[epoch: 15 | batch: 11] loss: 516.359
[epoch: 15 | batch: 12] loss: 508.368
[epoch: 15 | batch: 13] loss: 279.195
mean epoch loss: 529.90
saved model to ./model/colnet181105-21-30-31.pt
End of epoch 15

[epoch: 16 | batch: 1] loss: 528.261
[epoch: 16 | batch: 2] loss: 507.688
[epoch: 16 | batch: 3] loss: 460.588
[epoch: 16 | batch: 4] loss: 693.373
[epoch: 16 | batch: 5] loss: 576.638
[epoch: 16 | batch: 6] loss: 652.002
[epoch: 16 | batch: 7] loss: 581.992
[epoch: 16 | batch: 8] loss: 508.949
[epoch: 16 | batch: 9] loss: 456.977
[epoch: 16 | batch: 10] loss: 578.102
[epoch: 16 | batch: 11] loss: 515.350
[epoch: 16 | batch: 12] loss: 507.599
[epoch: 16 | batch: 13] loss: 277.425
mean epoch loss: 526.53
saved model to ./model/colnet181105-21-30-35.pt
End of epoch 16

[epoch: 17 | batch: 1] loss: 523.985
[epoch: 17 | batch: 2] loss: 502.058
[epoch: 17 | batch: 3] loss: 462.033
[epoch: 17 | batch: 4] loss: 697.191


[epoch: 30 | batch: 1] loss: 504.025
[epoch: 30 | batch: 2] loss: 493.868
[epoch: 30 | batch: 3] loss: 456.323
[epoch: 30 | batch: 4] loss: 681.360
[epoch: 30 | batch: 5] loss: 536.167
[epoch: 30 | batch: 6] loss: 634.113
[epoch: 30 | batch: 7] loss: 579.640
[epoch: 30 | batch: 8] loss: 468.937
[epoch: 30 | batch: 9] loss: 443.544
[epoch: 30 | batch: 10] loss: 573.968
[epoch: 30 | batch: 11] loss: 535.609
[epoch: 30 | batch: 12] loss: 497.225
[epoch: 30 | batch: 13] loss: 249.779
mean epoch loss: 511.89
saved model to ./model/colnet181105-21-31-30.pt
End of epoch 30

[epoch: 31 | batch: 1] loss: 497.115
[epoch: 31 | batch: 2] loss: 473.210
[epoch: 31 | batch: 3] loss: 479.095
[epoch: 31 | batch: 4] loss: 700.843
[epoch: 31 | batch: 5] loss: 524.335
[epoch: 31 | batch: 6] loss: 634.755
[epoch: 31 | batch: 7] loss: 568.912
[epoch: 31 | batch: 8] loss: 482.210
[epoch: 31 | batch: 9] loss: 440.949
[epoch: 31 | batch: 10] loss: 572.082
[epoch: 31 | batch: 11] loss: 511.023
[epoch: 31 | batc

[epoch: 44 | batch: 6] loss: 617.993
[epoch: 44 | batch: 7] loss: 578.615
[epoch: 44 | batch: 8] loss: 440.489
[epoch: 44 | batch: 9] loss: 419.026
[epoch: 44 | batch: 10] loss: 559.169
[epoch: 44 | batch: 11] loss: 461.076
[epoch: 44 | batch: 12] loss: 481.874
[epoch: 44 | batch: 13] loss: 271.682
mean epoch loss: 494.68
saved model to ./model/colnet181105-21-32-27.pt
End of epoch 44

[epoch: 45 | batch: 1] loss: 498.998
[epoch: 45 | batch: 2] loss: 480.350
[epoch: 45 | batch: 3] loss: 450.229
[epoch: 45 | batch: 4] loss: 696.608
[epoch: 45 | batch: 5] loss: 495.188
[epoch: 45 | batch: 6] loss: 619.853
[epoch: 45 | batch: 7] loss: 581.773
[epoch: 45 | batch: 8] loss: 460.142
[epoch: 45 | batch: 9] loss: 421.735
[epoch: 45 | batch: 10] loss: 552.916
[epoch: 45 | batch: 11] loss: 492.331
[epoch: 45 | batch: 12] loss: 487.125
[epoch: 45 | batch: 13] loss: 248.836
mean epoch loss: 498.93
saved model to ./model/colnet181105-21-32-31.pt
End of epoch 45

[epoch: 46 | batch: 1] loss: 508.909


[epoch: 58 | batch: 11] loss: 443.754
[epoch: 58 | batch: 12] loss: 466.247
[epoch: 58 | batch: 13] loss: 197.211
mean epoch loss: 461.04
saved model to ./model/colnet181105-21-33-24.pt
End of epoch 58

[epoch: 59 | batch: 1] loss: 434.005
[epoch: 59 | batch: 2] loss: 452.095
[epoch: 59 | batch: 3] loss: 420.234
[epoch: 59 | batch: 4] loss: 671.390
[epoch: 59 | batch: 5] loss: 435.202
[epoch: 59 | batch: 6] loss: 555.899
[epoch: 59 | batch: 7] loss: 572.346
[epoch: 59 | batch: 8] loss: 430.602
[epoch: 59 | batch: 9] loss: 404.581
[epoch: 59 | batch: 10] loss: 524.195
[epoch: 59 | batch: 11] loss: 435.504
[epoch: 59 | batch: 12] loss: 457.227
[epoch: 59 | batch: 13] loss: 199.934
mean epoch loss: 461.02
saved model to ./model/colnet181105-21-33-29.pt
End of epoch 59

[epoch: 60 | batch: 1] loss: 451.165
[epoch: 60 | batch: 2] loss: 478.261
[epoch: 60 | batch: 3] loss: 413.139
[epoch: 60 | batch: 4] loss: 672.707
[epoch: 60 | batch: 5] loss: 402.909
[epoch: 60 | batch: 6] loss: 548.575
[

[epoch: 73 | batch: 1] loss: 391.601
[epoch: 73 | batch: 2] loss: 419.635
[epoch: 73 | batch: 3] loss: 411.147
[epoch: 73 | batch: 4] loss: 648.353
[epoch: 73 | batch: 5] loss: 388.080
[epoch: 73 | batch: 6] loss: 496.685
[epoch: 73 | batch: 7] loss: 513.841
[epoch: 73 | batch: 8] loss: 370.282
[epoch: 73 | batch: 9] loss: 375.507
[epoch: 73 | batch: 10] loss: 479.041
[epoch: 73 | batch: 11] loss: 413.937
[epoch: 73 | batch: 12] loss: 417.611
[epoch: 73 | batch: 13] loss: 179.021
mean epoch loss: 423.44
saved model to ./model/colnet181105-21-34-27.pt
End of epoch 73

[epoch: 74 | batch: 1] loss: 389.983
[epoch: 74 | batch: 2] loss: 416.206
[epoch: 74 | batch: 3] loss: 406.969
[epoch: 74 | batch: 4] loss: 643.046
[epoch: 74 | batch: 5] loss: 376.028
[epoch: 74 | batch: 6] loss: 484.944
[epoch: 74 | batch: 7] loss: 504.054
[epoch: 74 | batch: 8] loss: 365.361
[epoch: 74 | batch: 9] loss: 366.145
[epoch: 74 | batch: 10] loss: 473.022
[epoch: 74 | batch: 11] loss: 419.327
[epoch: 74 | batc

[epoch: 87 | batch: 6] loss: 457.231
[epoch: 87 | batch: 7] loss: 489.074
[epoch: 87 | batch: 8] loss: 351.952
[epoch: 87 | batch: 9] loss: 364.288
[epoch: 87 | batch: 10] loss: 454.716
[epoch: 87 | batch: 11] loss: 378.384
[epoch: 87 | batch: 12] loss: 394.196
[epoch: 87 | batch: 13] loss: 163.835
mean epoch loss: 398.89
saved model to ./model/colnet181105-21-35-24.pt
End of epoch 87

[epoch: 88 | batch: 1] loss: 379.368
[epoch: 88 | batch: 2] loss: 402.475
[epoch: 88 | batch: 3] loss: 384.605
[epoch: 88 | batch: 4] loss: 585.202
[epoch: 88 | batch: 5] loss: 353.079
[epoch: 88 | batch: 6] loss: 444.181
[epoch: 88 | batch: 7] loss: 474.007
[epoch: 88 | batch: 8] loss: 336.699
[epoch: 88 | batch: 9] loss: 365.537
[epoch: 88 | batch: 10] loss: 449.215
[epoch: 88 | batch: 11] loss: 380.964
[epoch: 88 | batch: 12] loss: 386.282
[epoch: 88 | batch: 13] loss: 166.664
mean epoch loss: 392.94
saved model to ./model/colnet181105-21-35-28.pt
End of epoch 88

[epoch: 89 | batch: 1] loss: 373.622


### Load and test model

In [9]:
!rm -rf out/*

In [10]:
model_filename = "./model/colnet181105-21-36-20.pt"

print("Make sure you're using up to date model!!!")    

net = ColNet()
net.load_state_dict(torch.load(model_filename))
net.eval()

print("Colorizing {} using {}\n".format(img_dir_test, model_filename))


with torch.no_grad():
    for batch_no, data in enumerate(testloader):
        
        print("Processing batch {} / {}".format(batch_no + 1, len(testloader)))
        L, ab, name = data
        ab_outputs = net(L)
        
        for i in range(L.shape[0]):
            img = net_out2rgb(L[i], ab_outputs[i])
            io.imsave(os.path.join("./out/", name[i]), img)

print("Saved all photos to ./out/")

Make sure you're using up to date model!!!
Colorizing ./data/food41-120-test/ using ./model/colnet181105-21-36-20.pt

Processing batch 1 / 3


  .format(dtypeobj_in, dtypeobj_out))


Processing batch 2 / 3
Processing batch 3 / 3
Saved all photos to ./out/


  warn('Color data out of range: Z < 0 in %s pixels' % invalid[0].size)
