In [1]:
import os

# import matplotlib.pyplot as plt
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

# import cv2
from skimage import color, io

torch.set_default_tensor_type('torch.DoubleTensor')

# TODO

- [x] convert dataset to tensor? (now: numpy array)
 - seems like dataloader does it automatically - but doesnt swap axes
- [ ] normialize data?
- [ ] transform images?
- [ ] on deploy, increase num_workers in dataloader (>1)
- [ ] deal with `torch.set_default_tensor_type('torch.DoubleTensor')`

### Load images and convert them to [*CIE Lab*](https://en.m.wikipedia.org/wiki/CIELAB_color_space) color space

Warrning: `OpenCV` [uses](https://stackoverflow.com/questions/39316447/opencv-giving-wrong-color-to-colored-images-on-loading) BGR scheme, whereas `matplotlib` uses RGB. `scikit-image` uses RGB as well.


[Scikit color ranges: L: 0 to 100, a: -127 to 128, b: -128 to 127.](https://stackoverflow.com/questions/25294141/cielab-color-range-for-scikit-image)

### Custom dateset

Notes

One needs to [swap axes](https://pytorch.org/tutorials/beginner/data_loading_tutorial.html#transforms)
- numpy image: H x W x C
- torch image: C X H X W

In [47]:
class ImagesDateset(Dataset):
    
    def __init__(self, img_dir):
        """
        All images from `img_dir` will be read.
        """
        self.img_dir = img_dir
        self.img_names = [file for file in os.listdir(self.img_dir)]
        
        assert all([img.endswith('.jpg') for img in self.img_names]), "Must be *.jpg"
        
    
    def __len__(self):
        return len(self.img_names)
    
    
    def __getitem__(self, idx):
        """
        Get an image in Lab color space.
        Returns a tuple (L, ab)
            - `L` stands for lightness - it's the net input
            - `ab` is chrominance - something that the net learns
        
        """
        
        img_name = os.path.join(self.img_dir, self.img_names[idx])
        image = io.imread(img_name)
        
        assert image.shape == (224, 224, 3)
                
        img_lab = color.rgb2lab(image)
        
        tsfm2tensor = transforms.ToTensor()
        
        img_lab = tsfm2tensor(img_lab)
        
        img_lab = img_lab.double()
        

        assert img_lab.shape == (3, 224, 224)
               
        L  = img_lab[:1,:,:]
        ab = img_lab[1:,:,:]
              
        assert L.shape == (1, 224, 224)
        assert ab.shape == (2, 224, 224)
        
        return (L, ab)

### Load data

In [48]:
img_dir='./data/rawcropped224/'

trainset = ImagesDateset(img_dir)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
                                          shuffle=True, num_workers=0)

### The Network

In [52]:
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        # 'Low-level features'
        # conv1 has only one in channel - because it's only L channel of a photo
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=64, kernel_size=3, stride=2, padding=1)
        self.conv2 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=2, padding=1)
        self.conv4 = nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, stride=1, padding=1)
        self.conv5 = nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=2, padding=1)
        self.conv6 = nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, stride=1, padding=1)
        
        # 'Mid-level fetures'
        self.conv7 = nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1)
        self.conv8 = nn.Conv2d(in_channels=512, out_channels=256, kernel_size=3, stride=1, padding=1)
        
        # 'Colorization network'
        self.conv9 = nn.Conv2d(in_channels=256, out_channels=128, kernel_size=3, stride=1, padding=1)
        
        
        # UPSAMPLE 1
        # Till now the picture should be H/8 x W/8
        # since we fix H=W=224, then H/8=28. We want upsample to 56x56
        # DEPRECATED - see forward()
#         self.upsample1 = nn.Upsample(size=(256, 56, 56), mode='nearest')
        
        
        self.conv10 = nn.Conv2d(in_channels=128, out_channels=64, kernel_size=3, stride=1, padding=1)
        self.conv11 = nn.Conv2d(in_channels=64,  out_channels=64, kernel_size=3, stride=1, padding=1)
        
        # UPSAMPLE 2   -   We got 56x56, want: 112x112
        # DEPRECATED - see forward()
#         self.upsample2 = nn.Upsample(size=(64, 112, 112), mode='nearest')
        
        
        self.conv12 = nn.Conv2d(in_channels=64, out_channels=32, kernel_size=3, stride=1, padding=1)
        
        # ??????????????
        
    def forward(self, x):
        
        # Low level
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        x = F.relu(self.conv4(x))
        x = F.relu(self.conv5(x))
        x = F.relu(self.conv6(x))
        
        
        # Mid level
        x = F.relu(self.conv7(x))
        x = F.relu(self.conv8(x))
        
        assert x.shape == (4, 256, 28, 28), "おわり： mid level"
        

        # Colorization Net
        x = F.relu(self.conv9(x))
        
        assert x.shape == (4, 128, 28, 28),　"おわり： conv9"
        
        
#         x = self.upsample1(x)
        x = nn.functional.interpolate(input=x, scale_factor=2, mode='nearest')
    
        assert x.shape == (4, 128, 56, 56), "おわり： upsample1"
    
        print(x.shape)
        
    
        x = F.relu(self.conv10(x))
        x = F.relu(self.conv11(x))
        
#         x = self.upsample2(x)
        x = nn.functional.interpolate(input=x, scale_factor=2, mode='nearest')


        x = F.relu(self.conv12(x))

        return x
        

In [53]:
net = Net()
dataiter = iter(trainloader)

L, ab = dataiter.next()

print("L.shape")
print(L.shape)

print("\n---\n")

out = net(L)

L.shape
torch.Size([4, 1, 224, 224])

---

Mid おわり。x.shape: torch.Size([4, 256, 28, 28])
torch.Size([4, 128, 56, 56])


In [54]:
out.shape

torch.Size([4, 32, 112, 112])