In [1]:
import numpy as np
from skimage import color

# import torchvision.transforms as transforms
from torchvision import transforms
from torch.utils.data import DataLoader
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
num_epochs = 1000
batch_size = 512
learning_rate = 1e-3
use_gpu = True

In [3]:
#https://colab.research.google.com/github/smartgeometry-ucl/dl4g/blob/master/colorization.ipynb#scrollTo=j2MERVvtGYEy
#https://colab.research.google.com/drive/1r45y6bnxT1d8qUe5YDovWYUbfX1hMnAz#scrollTo=rRoQxRmuWqnG not using right now
class ColorNet(nn.Module):
    def __init__(self, d=128):
        super(ColorNet, self).__init__()
        
        self.conv1 = nn.Conv2d(1, 32, kernel_size=4, stride=2, padding=1) # out: 32 x 16 x 16
        self.conv1_bn = nn.BatchNorm2d(32)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1) # out: 64 x 8 x 8
        self.conv2_bn = nn.BatchNorm2d(64)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1) # out: 128 x 4 x 4
        self.conv3_bn = nn.BatchNorm2d(128)
        self.conv4 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1) # out: 128 x 4 x 4
        self.conv4_bn = nn.BatchNorm2d(128)
        self.conv5 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1) # out: 128 x 4 x 4
        self.conv5_bn = nn.BatchNorm2d(128)
        self.conv6 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1) # out: 128 x 4 x 4
        self.conv6_bn = nn.BatchNorm2d(128)
        self.tconv1 = nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1) # out: 64 x 8 x 8
        self.tconv1_bn = nn.BatchNorm2d(64)
        self.tconv2 = nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1) # out: 32 x 16 x 16
        self.tconv2_bn = nn.BatchNorm2d(32)
        self.tconv3 = nn.ConvTranspose2d(32, 2, kernel_size=4, stride=2, padding=1) # out: 2 x 32 x 32

    def forward(self, input):
        x = F.relu(self.conv1_bn(self.conv1(input)))
        x = F.relu(self.conv2_bn(self.conv2(x)))
        x = F.relu(self.conv3_bn(self.conv3(x)))
        x = F.relu(self.conv4_bn(self.conv4(x)))
        x = F.relu(self.conv5_bn(self.conv5(x)))
        x = F.relu(self.conv6_bn(self.conv6(x)))
        x = F.relu(self.tconv1_bn(self.tconv1(x)))
        x = F.relu(self.tconv2_bn(self.tconv2(x)))
        x = self.tconv3(x)

        return x

cnet = ColorNet()

device = torch.device("cuda:0" if use_gpu and torch.cuda.is_available() else "cpu")
cnet = cnet.to(device)

num_params = sum(p.numel() for p in cnet.parameters() if p.requires_grad)
print('Number of parameters: %d' % (num_params))

Number of parameters: 773698


In [4]:
def import_image(img):
    return torch.FloatTensor(np.transpose(color.rgb2lab(np.array(img)), (2, 0, 1)))

img_transform = transforms.Compose([
    transforms.Lambda(import_image)
])


train_transforms = transforms.Compose([
                                transforms.RandomApply(torch.nn.ModuleList([
                                transforms.RandomRotation(30),
                                transforms.RandomResizedCrop(224),
                                transforms.RandomHorizontalFlip(),
                                #transforms.Grayscale(num_output_channels=1),
                                #transforms.Grayscale(num_output_channels=3),
                                
                                ]), p=0.3),
                                transforms.RandomGrayscale(p=0.1),
                                transforms.RandomRotation(30),
                                transforms.RandomVerticalFlip(p=0.5),
                                transforms.RandomResizedCrop(224),
                                transforms.RandomHorizontalFlip(),
                                #transforms.Lambda(color.rgb2lab),
                                #transforms.ToTensor(),
                                transforms.Normalize([0, 0, 0], [1, 1, 1]),
                                transforms.RandomApply(torch.nn.ModuleList([     
                                    transforms.Normalize([0, 0, 0], [1, 1, 1]),
                                    transforms.Normalize([0.5, 0.5, 0.5], [0.01, 0.01, 0.01]),
                                    transforms.Normalize([0.8, 0.8, 0.8], [0.2, 0.2, 0.2]),
                                ])),
                                transforms.Lambda(import_image)
                           
])



test_transform = transforms.Compose([
        transforms.Lambda(color.rgb2lab),
        transforms.ToTensor()
    ])

TypeError: torchvision.transforms.transforms.RandomRotation is not a Module subclass

In [38]:
from torchvision import datasets, transforms

#dataset_notransform = datasets.ImageFolder("face_image_testset/", transform = test_transform)

dataset_notransform = datasets.ImageFolder("face_image_testset/", transform = img_transform)

train_dataloader = torch.utils.data.DataLoader(dataset_notransform,            
                                          batch_size=batch_size, 
                                          #GPU_data = True,
                                          pin_memory = True,
                                          num_workers=1,
                                         
                                         )


# train_dataloader = torch.utils.data.DataLoader(data_set_train_tensor_lab,            
#                                           batch_size=32, 
#                                           #GPU_data = True,
#                                           pin_memory = True,
#                                           num_workers=1,
                                         
#                                          )

In [39]:
import cv2
import torch
import numpy as np
import matplotlib.pyplot as plt
import torchvision.transforms as transforms
from PIL import Image
%matplotlib inline
data_iter = iter(train_dataloader)
data_iter

<torch.utils.data.dataloader._MultiProcessingDataLoaderIter at 0x7fd755507160>

In [40]:
optimizer = torch.optim.Adam(params=cnet.parameters(), lr=learning_rate)

# set to training mode
cnet.train()

train_loss_avg = []

print('Training ...')
for epoch in range(num_epochs):
    train_loss_avg.append(0)
    num_batches = 0
    
    for lab_batch, _ in train_dataloader:
        
        lab_batch = lab_batch.to(device)
        
        # apply the color net to the luminance component of the Lab images
        # to get the color (ab) components
        predicted_ab_batch = cnet(lab_batch[:, 0:1, :, :])
        
        # loss is the L2 error to the actual color (ab) components
        loss = F.mse_loss(predicted_ab_batch, lab_batch[:, 1:3, :, :])
        
        # backpropagation
        optimizer.zero_grad()
        loss.backward()
        
        # one step of the optmizer (using the gradients from backpropagation)
        optimizer.step()
        
        train_loss_avg[-1] += loss.item()
        num_batches += 1
        
    train_loss_avg[-1] /= num_batches
    print('Epoch [%d / %d] average reconstruction error: %f' % (epoch+1, num_epochs, train_loss_avg[-1]))

Training ...
Epoch [1 / 1000] average reconstruction error: 171.294006
Epoch [2 / 1000] average reconstruction error: 168.707581
Epoch [3 / 1000] average reconstruction error: 164.994019
Epoch [4 / 1000] average reconstruction error: 162.055008
Epoch [5 / 1000] average reconstruction error: 159.845139
Epoch [6 / 1000] average reconstruction error: 157.655960
Epoch [7 / 1000] average reconstruction error: 155.875183
Epoch [8 / 1000] average reconstruction error: 154.195450
Epoch [9 / 1000] average reconstruction error: 152.804169
Epoch [10 / 1000] average reconstruction error: 151.272659
Epoch [11 / 1000] average reconstruction error: 149.911896
Epoch [12 / 1000] average reconstruction error: 148.537888
Epoch [13 / 1000] average reconstruction error: 147.147369
Epoch [14 / 1000] average reconstruction error: 145.795578
Epoch [15 / 1000] average reconstruction error: 144.425812
Epoch [16 / 1000] average reconstruction error: 143.053772
Epoch [17 / 1000] average reconstruction error: 141.

Epoch [141 / 1000] average reconstruction error: 14.934711
Epoch [142 / 1000] average reconstruction error: 14.661499
Epoch [143 / 1000] average reconstruction error: 14.376368
Epoch [144 / 1000] average reconstruction error: 14.096889
Epoch [145 / 1000] average reconstruction error: 13.849844
Epoch [146 / 1000] average reconstruction error: 13.612430
Epoch [147 / 1000] average reconstruction error: 13.356872
Epoch [148 / 1000] average reconstruction error: 13.122458
Epoch [149 / 1000] average reconstruction error: 12.903575
Epoch [150 / 1000] average reconstruction error: 12.684789
Epoch [151 / 1000] average reconstruction error: 12.504235
Epoch [152 / 1000] average reconstruction error: 12.323533
Epoch [153 / 1000] average reconstruction error: 12.256669
Epoch [154 / 1000] average reconstruction error: 12.045737
Epoch [155 / 1000] average reconstruction error: 11.794779
Epoch [156 / 1000] average reconstruction error: 11.567083
Epoch [157 / 1000] average reconstruction error: 11.4424

Epoch [282 / 1000] average reconstruction error: 2.788204
Epoch [283 / 1000] average reconstruction error: 2.750992
Epoch [284 / 1000] average reconstruction error: 2.726539
Epoch [285 / 1000] average reconstruction error: 2.701017
Epoch [286 / 1000] average reconstruction error: 2.658462
Epoch [287 / 1000] average reconstruction error: 2.636829
Epoch [288 / 1000] average reconstruction error: 2.609573
Epoch [289 / 1000] average reconstruction error: 2.577991
Epoch [290 / 1000] average reconstruction error: 2.559901
Epoch [291 / 1000] average reconstruction error: 2.556753
Epoch [292 / 1000] average reconstruction error: 2.582909
Epoch [293 / 1000] average reconstruction error: 2.588148
Epoch [294 / 1000] average reconstruction error: 2.574328
Epoch [295 / 1000] average reconstruction error: 2.534306
Epoch [296 / 1000] average reconstruction error: 2.499971
Epoch [297 / 1000] average reconstruction error: 2.522962
Epoch [298 / 1000] average reconstruction error: 2.524253
Epoch [299 / 1

Epoch [424 / 1000] average reconstruction error: 1.247269
Epoch [425 / 1000] average reconstruction error: 1.245176
Epoch [426 / 1000] average reconstruction error: 1.245852
Epoch [427 / 1000] average reconstruction error: 1.245902
Epoch [428 / 1000] average reconstruction error: 1.248883
Epoch [429 / 1000] average reconstruction error: 1.258593
Epoch [430 / 1000] average reconstruction error: 1.290590
Epoch [431 / 1000] average reconstruction error: 1.322280
Epoch [432 / 1000] average reconstruction error: 1.376687
Epoch [433 / 1000] average reconstruction error: 1.400208
Epoch [434 / 1000] average reconstruction error: 1.394993
Epoch [435 / 1000] average reconstruction error: 1.345701
Epoch [436 / 1000] average reconstruction error: 1.415611
Epoch [437 / 1000] average reconstruction error: 1.484224
Epoch [438 / 1000] average reconstruction error: 1.383358
Epoch [439 / 1000] average reconstruction error: 1.278136
Epoch [440 / 1000] average reconstruction error: 1.395993
Epoch [441 / 1

KeyboardInterrupt: 