In [119]:
import os
import time

import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

from skimage import color, io

torch.set_default_tensor_type('torch.FloatTensor')

# TODO

- [x] convert dataset to tensor? (now: numpy array)
 - seems like dataloader does it automatically - but doesnt swap axes
 - transforms.toTensor does the job
- [x] normialize data?
 - only $ab$ channel 
- [ ] transform images? - random crops itp
- [ ] on deploy, increase num_workers in dataloader (>1)
- [ ] deal with `torch.set_default_tensor_type('torch.DoubleTensor')`
- [ ] should I use `torch.nn.functional.{relu | sigmoid}` or just `torch.nn.*`?
- [x] refector asserts so they work with batch sizes different than 4
- [ ] shuffle data (trainloader)
- [ ] refactor?? - separate files
- [ ] load / save model

### Load images and convert them to [*CIE Lab*](https://en.m.wikipedia.org/wiki/CIELAB_color_space) color space

Warrning: `OpenCV` [uses](https://stackoverflow.com/questions/39316447/opencv-giving-wrong-color-to-colored-images-on-loading) BGR scheme, whereas `matplotlib` uses RGB. `scikit-image` uses RGB as well.


[Scikit color ranges: L: 0 to 100, a: -127 to 128, b: -128 to 127.](https://stackoverflow.com/questions/25294141/cielab-color-range-for-scikit-image)

### Custom dateset

Notes

One needs to [swap axes](https://pytorch.org/tutorials/beginner/data_loading_tutorial.html#transforms)
- numpy image: H x W x C
- torch image: C X H X W


Quick facts: $L \in [0, 100]$, $a \in [-127, 128]$, $b \in [-128, 127]$

How an image is processed:
1. Load i-th image to memory
2. Convert it to LAB Space
3. Convert to torch.tensor
4. Split the image to $L$ and $ab$ channels
5. Normalize $ab$ to $[0, 1]$. $L$ remains unnormalized.
6. $L$ will feed net, $a'b'$ will be its output
7. Calculate $Loss(ab, a'b')$

In [120]:
class ImagesDateset(Dataset):
    
    def __init__(self, img_dir):
        """
        All images from `img_dir` will be read.
        """
        self.img_dir = img_dir
        self.img_names = [file for file in os.listdir(self.img_dir)]
        
        assert all([img.endswith('.jpg') for img in self.img_names]), "Must be *.jpg"
        
    
    def __len__(self):
        return len(self.img_names)
    
    
    def __getitem__(self, idx):
        """
        Get an image in Lab color space.
        Returns a tuple (L, ab)
            - `L` stands for lightness - it's the net input
            - `ab` is chrominance - something that the net learns
        
        """
        
        img_name = os.path.join(self.img_dir, self.img_names[idx])
        image = io.imread(img_name)
        
        
        assert image.shape == (224, 224, 3)
                
        img_lab = color.rgb2lab(image)
        img_lab = np.transpose(img_lab, (2, 0, 1))
        
        assert img_lab.shape == (3, 224, 224)
        
        img_lab = torch.tensor( img_lab.astype(np.float32) )
        
        assert img_lab.shape == (3, 224, 224)
               
        L  = img_lab[:1,:,:]
        ab = img_lab[1:,:,:]
        
        # Normalization
        L = L/128.0   # 0..1
        ab = (ab+128) / 255.0 # -1 .. 1
              
        assert L.shape == (1, 224, 224)
        assert ab.shape == (2, 224, 224)
        
        return (L, ab)

### Load data

In [126]:
img_dir_train ='./data/apple_pie//'
img_dir_test = './data/apple_pie/'

trainset = ImagesDateset(img_dir_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=32,
                                          shuffle=False, num_workers=0)

testset = ImagesDateset(img_dir_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=32,
                                          shuffle=False, num_workers=0)

### The Network

In [127]:
np.array([2,4])/2

array([1., 2.])

In [136]:
class Net(nn.Module):
    def __init__(self):
        ksize = np.array( [1, 64, 128, 128, 256, 256, 512, 512, 256, 128, 64, 64, 32] ) // 8
        ksize[0] = 1
#         ksize = [1, 32, 128, 128, 256, 256, 512, 512, 256, 128, 64, 64, 32]
        super(Net, self).__init__()
        # 'Low-level features'
        # conv1 has only one in channel - because it's only L channel of a photo
        self.conv1 = nn.Conv2d(in_channels=1,        out_channels=ksize[1], kernel_size=3, stride=2, padding=1)
        self.conv2 = nn.Conv2d(in_channels=ksize[1], out_channels=ksize[2], kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(in_channels=ksize[2], out_channels=ksize[3], kernel_size=3, stride=2, padding=1)
        self.conv4 = nn.Conv2d(in_channels=ksize[3], out_channels=ksize[4], kernel_size=3, stride=1, padding=1)
        self.conv5 = nn.Conv2d(in_channels=ksize[4], out_channels=ksize[5], kernel_size=3, stride=2, padding=1)
        self.conv6 = nn.Conv2d(in_channels=ksize[5], out_channels=ksize[6], kernel_size=3, stride=1, padding=1)
        
        # 'Mid-level fetures'
        self.conv7 = nn.Conv2d(in_channels=ksize[6], out_channels=ksize[7], kernel_size=3, stride=1, padding=1)
        self.conv8 = nn.Conv2d(in_channels=ksize[7], out_channels=ksize[8], kernel_size=3, stride=1, padding=1)
        
        # 'Colorization network'
        self.conv9 = nn.Conv2d(in_channels=ksize[8], out_channels=ksize[9], kernel_size=3, stride=1, padding=1)
        
        # Here comes upsample #1
        
        self.conv10 = nn.Conv2d(in_channels=ksize[9], out_channels=ksize[10], kernel_size=3, stride=1, padding=1)
        self.conv11 = nn.Conv2d(in_channels=ksize[10],out_channels=ksize[11], kernel_size=3, stride=1, padding=1)
        
        # Here comes upsample #2        
        
        self.conv12 = nn.Conv2d(in_channels=ksize[11], out_channels=ksize[12], kernel_size=3, stride=1, padding=1)
        
        self.conv13out = nn.Conv2d(in_channels=ksize[12], out_channels=2, kernel_size=3, stride=1, padding=1)
        
        
    def forward(self, x):
        
        # Low level
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        x = F.relu(self.conv4(x))
        x = F.relu(self.conv5(x))
        x = F.relu(self.conv6(x))
        
        
        # Mid level
        x = F.relu(self.conv7(x))
        x = F.relu(self.conv8(x))
        
#         assert x.shape[1:] == (256, 28, 28), "おわり： mid level"
        

        # Colorization Net
        x = F.relu(self.conv9(x))
        
#         assert x.shape[1:] == (128, 28, 28), "おわり： conv9"
        
        x = nn.functional.interpolate(input=x, scale_factor=2, mode='nearest')

#         assert x.shape[1:] == (128, 56, 56), "おわり： upsample1"
    
        x = F.relu(self.conv10(x))
        x = F.relu(self.conv11(x))
        
        x = nn.functional.interpolate(input=x, scale_factor=2, mode='nearest')


        x = F.relu(self.conv12(x))
        x = torch.sigmoid(self.conv13out(x))
        
        x = nn.functional.interpolate(input=x, scale_factor=2, mode='nearest')
        
#         assert x.shape[1:] == (2, 224, 224)
        
        return x
        

### Utils

In [137]:
def net_out2rgb(L, ab_out):
    """
    L - original `L` channel
    ab_out - learned `ab` channels which were the net's output
    
    Retruns: 3 channel RGB image
    """
    # Convert to numpy and unnnormalize
    L = L.numpy() * 128.0
    
    ab_out = ab_out.numpy() * 256.0 - 128.0
    
    # Transpose axis to HxWxC again
    L = L.transpose((1, 2, 0))
    ab_out = ab_out.transpose((1, 2, 0))

    # Stack layers
    lab_stack = np.dstack((L, ab_out))
    
    return color.lab2rgb(lab_stack)

### Training

In [143]:
net = Net()
criterion = nn.MSELoss()
optimizer = optim.Adam(net.parameters(), lr=0.001)


for epoch in range(5): 

    epoch_loss = 0.0
    for i, data in enumerate(trainloader, 0):

        L, ab = data
        
        optimizer.zero_grad()

        ab_outputs = net(L)
        
        loss = criterion(ab, ab_outputs)
        loss.backward()
        optimizer.step()

        running_loss = loss.item()
        
#         if i % 2000 == 1999:    
        print('[%d, %5d] loss: %.3f' %
              (epoch + 1, i + 1, running_loss / 32))
    
        epoch_loss += running_loss
    
    print('epoch loss [%d] loss: %.3f' %
              (epoch + 1, epoch / 32))
    print('End of epoch {}'.format(epoch))
    torch.save(net.state_dict(), './model/colnet' + time.strftime("%y%m%d-%H-%M-%S") + ".pt")

print('Finished Training')

[1,     1] loss: 1375.594
[1,     2] loss: 1326.204
[1,     3] loss: 1378.108
[1,     4] loss: 1413.406
[1,     5] loss: 1054.143
[1,     6] loss: 1019.315
[1,     7] loss: 1030.867
[1,     8] loss: 918.058
[1,     9] loss: 740.116
[1,    10] loss: 400.310
[1,    11] loss: 383.798
[1,    12] loss: 629.999
[1,    13] loss: 503.124
[1,    14] loss: 361.580
[1,    15] loss: 402.080
[1,    16] loss: 490.566
[1,    17] loss: 575.011
[1,    18] loss: 597.113
[1,    19] loss: 648.988
[1,    20] loss: 435.035
[1,    21] loss: 477.658
[1,    22] loss: 318.393
[1,    23] loss: 549.432
[1,    24] loss: 424.274
[1,    25] loss: 451.949
[1,    26] loss: 461.863
[1,    27] loss: 424.949
[1,    28] loss: 349.140
[1,    29] loss: 464.093
[1,    30] loss: 210.156
epoch loss [1] loss: 0.000
End of epoch 0
[2,     1] loss: 443.143
[2,     2] loss: 488.819
[2,     3] loss: 505.988
[2,     4] loss: 551.310
[2,     5] loss: 365.880
[2,     6] loss: 365.141
[2,     7] loss: 405.466
[2,     8] loss: 486.333
[

### Save / Load model

In [117]:
torch.save(net.state_dict(), './data/colnet.pt')

In [145]:
net = Net()
net.load_state_dict(torch.load('./model/colnet181030-13-32-54.pt'))
net.eval()

with torch.no_grad():
    for data in testloader:

        L, ab = data
        ab_outputs = net(L)
        
        for i in range(L.shape[0]):
            print(i)
            img = net_out2rgb(L[i], ab_outputs[i])
            io.imsave("./out/{}.jpg".format(i), img)

0
1
2
3
4
5
6
7
8


  .format(dtypeobj_in, dtypeobj_out))


9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
2

  warn('Color data out of range: Z < 0 in %s pixels' % invalid[0].size)


18
19
20
21
22
23
24
25
26
27
28
29
30
31
0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
0
1
2
3
4
5
