# Automatic Image Colorization

Based on [Let there be Color!](http://hi.cs.waseda.ac.jp/~iizuka/projects/colorization/en/)

In [None]:
import os
import time

import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

from skimage import color, io

# Create floating-point tensors as 32-bit floats by default.
# `torch.set_default_tensor_type` is deprecated in recent PyTorch releases;
# `torch.set_default_dtype` is the supported way to select the default dtype.
torch.set_default_dtype(torch.float32)

### Hyperparameters

In [None]:
# Number of images per mini-batch handed to the data loaders.
BATCH_SIZE = 32

# Number of passes over the training set.
EPOCHS = 3

# Image directories for training and for testing (currently the same one).
img_dir_train = './data/caprese_salad/'
img_dir_test = './data/caprese_salad/'

# TODO

- [x] convert dataset to tensor? (now: numpy array)
 - seems like dataloader does it automatically - but doesn't swap axes
 - `transforms.ToTensor` does the job
- [x] normalize data?
 - $ab$ channels + $L$ channel
- [ ] transform images? - random crops etc.
- [ ] deal with `torch.set_default_tensor_type('torch.DoubleTensor')` or `torch.FloatTensor`
- [x] refactor asserts so they work with batch sizes different than 4
- [ ] shuffle data (trainloader)
- [ ] refactor?? - separate files
- [x] load / save model
- [x] extract hyperparameters
- [x] load whole dataset to memory
- [ ] read/write to zip
- [ ] 

### Custom dataset

- One needs to [swap axes](https://pytorch.org/tutorials/beginner/data_loading_tutorial.html#transforms)
    - numpy image: H x W x C
    - torch image: C x H x W


- Quick fact: $L \in [0, 100]$, $a \in [-127, 128]$, $b \in [-128, 127]$ [(source)](https://stackoverflow.com/questions/25294141/cielab-color-range-for-scikit-image)


- How an image is processed:
    1. Load i-th image to memory
    2. Convert it to LAB Space
    3. Convert to torch.tensor
    4. Split the image to $L$ and $ab$ channels
    5. Normalize $ab$ to $[0, 1]$. $L$ ~~remains unnormalized.~~ is normalized as well.
    6. $L$ will feed net, $a'b'$ will be its output
    7. Calculate $Loss(ab, a'b')$

In [None]:
class ImagesDateset(Dataset):
    """
    Dataset of 224x224 RGB `*.jpg` images served in normalized Lab color space.

    Every item is a tuple `(L, ab, name)`, see `__getitem__`.
    """

    def __init__(self, img_dir, all2mem=True):
        """
        All images from `img_dir` will be read.

        img_dir - directory that contains only `*.jpg` files
        all2mem - if True, every image is decoded right away and kept in
                  memory; otherwise an image is read from disk on each access
        """
        self.img_dir = img_dir
        self.all2mem = all2mem

        # os.listdir() returns entries in arbitrary order. Sorting makes an
        # index refer to the same file on every run and on every machine.
        self.img_names = sorted(os.listdir(self.img_dir))

        assert all(img.endswith('.jpg') for img in self.img_names), "Must be *.jpg"

        if self.all2mem:
            self.images = [io.imread(os.path.join(self.img_dir, img))
                           for img in self.img_names]

    def __len__(self):
        """Number of images found in `img_dir`."""
        return len(self.img_names)

    def __getitem__(self, idx):
        """
        Get an image in Lab color space.
        Returns a tuple (L, ab, name)
            - `L` stands for lightness - it's the net input;
              float32 tensor of shape (1, 224, 224) holding L / 100
            - `ab` is chrominance - something that the net learns;
              float32 tensor of shape (2, 224, 224) holding (ab + 128) / 255
            - `name` - image filename
        """
        if self.all2mem:
            image = self.images[idx]
        else:
            image = io.imread(os.path.join(self.img_dir, self.img_names[idx]))

        assert image.shape == (224, 224, 3)

        # H x W x C (numpy image layout) -> C x H x W (torch image layout)
        img_lab = np.transpose(color.rgb2lab(image), (2, 0, 1))
        img_lab = torch.tensor(img_lab.astype(np.float32))

        # Slicing (instead of indexing) keeps the channel dimension of `L`.
        L = img_lab[:1, :, :]
        ab = img_lab[1:, :, :]

        # Normalization
        L = L / 100.0              # 0..1
        ab = (ab + 128.0) / 255.0  # 0..1

        assert L.shape == (1, 224, 224)
        assert ab.shape == (2, 224, 224)

        return L, ab, self.img_names[idx]

### Load data

In [None]:
# Training data. The loader reshuffles the samples every epoch so that the
# mini-batches are not built from the same images in the same order each time.
trainset = ImagesDateset(img_dir_train, all2mem=True)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=BATCH_SIZE,
                                          shuffle=True, num_workers=4)

# Test data. Order is kept fixed and loading happens in the main process.
testset = ImagesDateset(img_dir_test, all2mem=True)
testloader = torch.utils.data.DataLoader(testset, batch_size=BATCH_SIZE,
                                         shuffle=False, num_workers=0)

### The Network

In [None]:
class ColNet(nn.Module):
    """
    Convolutional net that maps a 1-channel input to a 2-channel output.

    Three stride-2 convolutions shrink the feature maps and three
    nearest-neighbour x2 upsamplings enlarge them again. A sigmoid is
    applied to the output of the last convolution.
    """

    def __init__(self):
        super().__init__()

        # Channel widths of the consecutive layers (listed values divided by 8).
        widths = np.array([1, 64, 128, 128, 256, 256, 512, 512, 256, 128, 64, 64, 32]) // 8
        widths[0] = 1

        def conv3x3(c_in, c_out, stride=1):
            """3x3 convolution with padding 1."""
            return nn.Conv2d(in_channels=c_in, out_channels=c_out,
                             kernel_size=3, stride=stride, padding=1)

        # 'Low-level features'
        self.conv1 = conv3x3(1, widths[1], stride=2)
        self.conv2 = conv3x3(widths[1], widths[2])
        self.conv3 = conv3x3(widths[2], widths[3], stride=2)
        self.conv4 = conv3x3(widths[3], widths[4])
        self.conv5 = conv3x3(widths[4], widths[5], stride=2)
        self.conv6 = conv3x3(widths[5], widths[6])

        # 'Mid-level features'
        self.conv7 = conv3x3(widths[6], widths[7])
        self.conv8 = conv3x3(widths[7], widths[8])

        # 'Colorization network' (upsample #1 follows conv9)
        self.conv9 = conv3x3(widths[8], widths[9])

        # (upsample #2 follows conv11)
        self.conv10 = conv3x3(widths[9], widths[10])
        self.conv11 = conv3x3(widths[10], widths[11])

        # (upsample #3 follows the output layer)
        self.conv12 = conv3x3(widths[11], widths[12])
        self.conv13out = conv3x3(widths[12], 2)

    def forward(self, x):
        """Run the net on a batch `x` of shape (N, 1, H, W)."""
        # Low-level features (conv1..conv6) followed by mid-level ones (conv7, conv8)
        encoder = (self.conv1, self.conv2, self.conv3, self.conv4,
                   self.conv5, self.conv6, self.conv7, self.conv8)
        for conv in encoder:
            x = F.relu(conv(x))

        # Colorization net
        x = F.relu(self.conv9(x))
        x = F.interpolate(x, scale_factor=2, mode='nearest')

        for conv in (self.conv10, self.conv11):
            x = F.relu(conv(x))
        x = F.interpolate(x, scale_factor=2, mode='nearest')

        x = F.relu(self.conv12(x))
        x = torch.sigmoid(self.conv13out(x))

        # For a 224x224 input the result is (N, 2, 224, 224) again.
        return F.interpolate(x, scale_factor=2, mode='nearest')
        

### Utilities

Unnormalize and return RGB image

In [None]:
def net_out2rgb(L, ab_out):
    """
    Unnormalize a Lab image and convert it to RGB.

    L - original `L` channel, tensor laid out as 1 x H x W, scaled by 1/100
    ab_out - learned `ab` channels which were the net's output,
             tensor laid out as 2 x H x W, scaled as (ab + 128) / 255

    Returns: 3 channel RGB image (H x W x 3), as produced by `color.lab2rgb`
    """
    # Convert to numpy and unnormalize.
    # detach().cpu() lets the function also accept tensors that require grad
    # or live on a GPU; for plain CPU tensors it changes nothing.
    L = L.detach().cpu().numpy() * 100.0
    ab_out = np.floor(ab_out.detach().cpu().numpy() * 255.0) - 128.0

    # Transpose axes to HxWxC again
    L = L.transpose((1, 2, 0))
    ab_out = ab_out.transpose((1, 2, 0))

    # Stack layers: L first, then a and b
    lab_stack = np.dstack((L, ab_out))

    return color.lab2rgb(lab_stack)

### Training

In [None]:
net = ColNet()
criterion = nn.MSELoss()
optimizer = optim.Adam(net.parameters(), lr=0.001)

# The checkpoint is overwritten after every epoch.
# Timestamped alternative:
#   './model/colnet{}.pt'.format(time.strftime("%y%m%d-%H-%M-%S"))
model_filename = './model/colnet.pt'

# torch.save() does not create missing directories. Create the target up
# front instead of failing after a whole epoch of training.
os.makedirs(os.path.dirname(model_filename), exist_ok=True)

for epoch in range(EPOCHS):

    epoch_loss = 0.0
    for i, data in enumerate(trainloader, 0):

        L, ab, _ = data

        optimizer.zero_grad()

        ab_outputs = net(L)

        # Loss modules are called as criterion(input, target).
        loss = criterion(ab_outputs, ab)
        loss.backward()
        optimizer.step()

        # nn.MSELoss() with its default reduction already returns the mean
        # over all elements of the batch, so it must not be divided by
        # BATCH_SIZE again.
        running_loss = loss.item()

        print('[epoch: {} | batch: {}] loss: {:.3f}'
              .format(epoch + 1, i + 1, running_loss))

        epoch_loss += running_loss

    # Mean of the per-batch losses; max() guards against an empty loader.
    print('mean epoch loss: {:.2f}'.format(epoch_loss / max(len(trainloader), 1)))

    # Save
    torch.save(net.state_dict(), model_filename)
    print('saved model to {}'.format(model_filename))

    print('End of epoch {}\n'.format(epoch + 1))

print('Finished Training')

### Load and test model

In [None]:
model_filename = "./model/colnet.pt"
out_dir = "./out/"

print("Make sure you're using up to date model!!!")

# Restore the trained weights and switch the net to evaluation mode.
net = ColNet()
net.load_state_dict(torch.load(model_filename))
net.eval()

print("Colorizing {} using {}\n".format(img_dir_test, model_filename))

# io.imsave() does not create missing directories.
os.makedirs(out_dir, exist_ok=True)

# No gradients are needed for inference.
with torch.no_grad():
    for batch_no, data in enumerate(testloader):

        print("Processing batch {} / {}".format(batch_no + 1, len(testloader)))
        L, ab, name = data
        ab_outputs = net(L)

        # The last batch may be smaller than BATCH_SIZE; zip handles that.
        for L_single, ab_single, img_name in zip(L, ab_outputs, name):
            img = net_out2rgb(L_single, ab_single)

            # net_out2rgb returns the float image from skimage's lab2rgb
            # (expected range 0..1). Convert it to 8-bit explicitly rather
            # than relying on an implicit, lossy conversion inside imsave.
            img = (np.clip(img, 0.0, 1.0) * 255.0).round().astype(np.uint8)

            io.imsave(os.path.join(out_dir, img_name), img)

print("Saved all photos to ./out/")