In [None]:
import os

import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

from skimage import color, io

# Make float64 the default floating-point dtype so that newly created tensors and
# module parameters match the double-precision Lab arrays produced by the dataset.
# `torch.set_default_tensor_type` is deprecated; `set_default_dtype` is its replacement.
torch.set_default_dtype(torch.float64)

# TODO

- [x] convert dataset to tensor? (now: numpy array)
 - seems like dataloader does it automatically - but doesn't swap axes
 - `transforms.ToTensor` does the job
- [x] normalize data?
 - only $ab$ channel 
- [ ] transform images? - random crops, etc.
- [ ] on deploy, increase num_workers in dataloader (>1)
- [ ] deal with `torch.set_default_tensor_type('torch.DoubleTensor')`
- [ ] should I use `torch.nn.functional.{relu | sigmoid}` or just `torch.nn.*`?
- [x] refactor asserts so they work with batch sizes different than 4

### Load images and convert them to [*CIE Lab*](https://en.m.wikipedia.org/wiki/CIELAB_color_space) color space

Warning: `OpenCV` [uses](https://stackoverflow.com/questions/39316447/opencv-giving-wrong-color-to-colored-images-on-loading) BGR scheme, whereas `matplotlib` uses RGB. `scikit-image` uses RGB as well.


[Scikit color ranges: L: 0 to 100, a: -127 to 128, b: -128 to 127.](https://stackoverflow.com/questions/25294141/cielab-color-range-for-scikit-image)

### Custom dataset

Notes

One needs to [swap axes](https://pytorch.org/tutorials/beginner/data_loading_tutorial.html#transforms)
- numpy image: H x W x C
- torch image: C x H x W


Quick facts: $L \in [0, 100]$, $a \in [-127, 128]$, $b \in [-128, 127]$

How an image is processed:
1. Load i-th image to memory
2. Convert it to LAB Space
3. Convert to torch.tensor
4. Split the image to $L$ and $ab$ channels
5. Normalize $ab$ to $[0, 1]$. $L$ remains unnormalized.
6. $L$ will feed net, $a'b'$ will be its output
7. Calculate $Loss(ab, a'b')$

In [None]:
class ImagesDateset(Dataset):
    """
    Dataset of square RGB ``*.jpg`` images served in CIE Lab color space.

    Every item is a tuple ``(L, ab)`` of double tensors:
        - `L`  has shape ``(1, img_size, img_size)`` and is left unnormalized
        - `ab` has shape ``(2, img_size, img_size)`` and is mapped with
          ``(ab + 128) / 256``
    """

    def __init__(self, img_dir, img_size=224):
        """
        All images from `img_dir` will be read.

        Args:
            img_dir: directory whose entries must all end with ``.jpg``.
            img_size: expected height and width (in pixels) of every image.

        Raises:
            AssertionError: if any entry of `img_dir` does not end with ``.jpg``.
        """
        self.img_dir = img_dir
        self.img_size = img_size
        # os.listdir returns entries in arbitrary order; sorting makes the
        # index -> image mapping reproducible across runs and machines.
        self.img_names = sorted(os.listdir(self.img_dir))

        assert all(img.endswith('.jpg') for img in self.img_names), "Must be *.jpg"

    def __len__(self):
        """Number of images found in `img_dir`."""
        return len(self.img_names)

    def __getitem__(self, idx):
        """
        Get an image in Lab color space.
        Returns a tuple (L, ab)
            - `L` stands for lightness - it's the net input
            - `ab` is chrominance - something that the net learns

        Raises:
            AssertionError: if the image is not ``img_size x img_size`` RGB.
        """
        size = self.img_size

        img_name = os.path.join(self.img_dir, self.img_names[idx])
        image = io.imread(img_name)

        assert image.shape == (size, size, 3)

        img_lab = color.rgb2lab(image)

        # ToTensor swaps the axes from H x W x C to C x H x W. The Lab array is
        # floating point, so ToTensor does not rescale it (it only rescales uint8 input).
        img_lab = transforms.ToTensor()(img_lab)
        img_lab = img_lab.double()

        assert img_lab.shape == (3, size, size)

        # Channel 0 is lightness, channels 1-2 are chrominance.
        L = img_lab[:1, :, :]
        ab = img_lab[1:, :, :]

        # Normalize to (0, 1)
        ab = (ab + 128.0) / 256.0

        assert L.shape == (1, size, size)
        assert ab.shape == (2, size, size)

        return (L, ab)

### Load data

In [None]:
# Directory holding the training images
img_dir = './data/rawcropped224/'

# Shuffled mini-batches of four images, loaded in the main process (num_workers=0)
trainset = ImagesDateset(img_dir)
trainloader = DataLoader(
    dataset=trainset,
    batch_size=4,
    shuffle=True,
    num_workers=0,
)

### The Network

In [None]:
class Net(nn.Module):
    """
    Fully convolutional colorization network.

    Takes the lightness channel of an image, shape ``(N, 1, H, W)``, and predicts
    its two chrominance channels, shape ``(N, 2, H, W)``, squashed to ``(0, 1)``
    by a sigmoid. Three stride-2 convolutions reduce the resolution by a factor
    of 8 and three nearest-neighbour upsamplings restore it, so ``H`` and ``W``
    must be divisible by 8.
    """

    def __init__(self):
        super(Net, self).__init__()
        # 'Low-level features'
        # conv1 has only one in channel - because it's only L channel of a photo
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=64, kernel_size=3, stride=2, padding=1)
        self.conv2 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=2, padding=1)
        self.conv4 = nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, stride=1, padding=1)
        self.conv5 = nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=2, padding=1)
        self.conv6 = nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, stride=1, padding=1)

        # 'Mid-level features'
        self.conv7 = nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1)
        self.conv8 = nn.Conv2d(in_channels=512, out_channels=256, kernel_size=3, stride=1, padding=1)

        # 'Colorization network'
        self.conv9 = nn.Conv2d(in_channels=256, out_channels=128, kernel_size=3, stride=1, padding=1)

        # Here comes upsample #1

        self.conv10 = nn.Conv2d(in_channels=128, out_channels=64, kernel_size=3, stride=1, padding=1)
        self.conv11 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1)

        # Here comes upsample #2

        self.conv12 = nn.Conv2d(in_channels=64, out_channels=32, kernel_size=3, stride=1, padding=1)

        self.conv13out = nn.Conv2d(in_channels=32, out_channels=2, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        """
        Predict normalized ab channels from the L channel.

        Args:
            x: batch of lightness images, shape ``(N, 1, H, W)`` with ``H`` and
               ``W`` divisible by 8 (e.g. 224 x 224).

        Returns:
            Tensor of shape ``(N, 2, H, W)`` with values in ``(0, 1)``.

        Raises:
            AssertionError: if the spatial size is not divisible by 8 or an
                intermediate activation has an unexpected shape.
        """
        height, width = x.shape[-2:]
        # Downsampling by 8 followed by upsampling by 8 only reproduces the input
        # resolution when both dimensions are multiples of 8.
        assert height % 8 == 0 and width % 8 == 0, (
            "input height and width must be divisible by 8, got %dx%d" % (height, width)
        )
        low_h, low_w = height // 8, width // 8

        # Low level
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        x = F.relu(self.conv4(x))
        x = F.relu(self.conv5(x))
        x = F.relu(self.conv6(x))

        # Mid level
        x = F.relu(self.conv7(x))
        x = F.relu(self.conv8(x))

        assert x.shape[1:] == (256, low_h, low_w), "おわり： mid level"

        # Colorization Net
        x = F.relu(self.conv9(x))

        assert x.shape[1:] == (128, low_h, low_w), "おわり： conv9"

        x = F.interpolate(x, scale_factor=2, mode='nearest')

        assert x.shape[1:] == (128, 2 * low_h, 2 * low_w), "おわり： upsample1"

        x = F.relu(self.conv10(x))
        x = F.relu(self.conv11(x))

        x = F.interpolate(x, scale_factor=2, mode='nearest')

        x = F.relu(self.conv12(x))
        # Sigmoid keeps the prediction in the same (0, 1) range as the normalized ab target
        x = torch.sigmoid(self.conv13out(x))

        x = F.interpolate(x, scale_factor=2, mode='nearest')

        assert x.shape[1:] == (2, height, width)

        return x
        

### Training

In [None]:
net = Net()
criterion = nn.MSELoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

NUM_EPOCHS = 2
# Report the mean loss over this many mini-batches (1 = report every batch)
LOG_EVERY = 1

for epoch in range(NUM_EPOCHS):

    running_loss = 0.0
    for i, data in enumerate(trainloader):

        # L is the network input, ab is the normalized chrominance target
        L, ab = data

        optimizer.zero_grad()

        ab_outputs = net(L)

        # MSELoss expects (input, target): prediction first, ground truth second
        loss = criterion(ab_outputs, ab)
        loss.backward()
        optimizer.step()

        # print statistics
        running_loss += loss.item()
        if (i + 1) % LOG_EVERY == 0:
            # Divide by the number of batches actually accumulated, so the
            # printed value is a real mean loss
            print('[%d, %5d] loss: %.3f' %
                  (epoch + 1, i + 1, running_loss / LOG_EVERY))
            running_loss = 0.0

print('Finished Training')

In [None]:
# Visualize one prediction from the last training batch.
# Index 0 always exists, even when the last batch holds a single image.
sample_idx = 0

# Convert to numpy array and map from the normalized [0, 1] range back to the ab range
# (`ab_outputs` is the network output produced by the training loop)
out_np = ab_outputs[sample_idx].detach().numpy() * 256 - 128.0

# Take L channel and convert to numpy
Lnew = L[sample_idx].numpy()

# Transpose axis to HxWxC again
Lnew = Lnew.transpose((1, 2, 0))
out_np = out_np.transpose((1, 2, 0))

print(Lnew.shape)
print(type(Lnew))
print(out_np.shape)
print(type(out_np))


# Stack layers: L first, then the two predicted ab channels
LABstackeed = np.dstack((Lnew, out_np))

print("lab_ok.shape")
print(LABstackeed.shape)

# Print img
io.imshow(color.lab2rgb(LABstackeed))