## Image Colorization in the RGB Colorspace



In [None]:
# Mount Google Drive at /content/drive so the notebook can reach files stored there.
from google.colab import drive
drive.mount('/content/drive')


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
!unzip "/content/drive/My Drive/Object Detection/Image Colorization/landscape_images.zip" -d "/content/drive/My Drive/Object Detection/Image Colorization/landscape_images/"

Archive:  /content/drive/My Drive/Object Detection/Image Colorization/landscape_images.zip


KeyboardInterrupt: ignored

In [None]:
# Enable IPython's autoreload extension; mode 2 reloads imported modules before each cell runs,
# so edits to local .py files are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
# Make the project folder the working directory (relative paths and local imports used later
# resolve against it), then list its contents.
%cd '/content/drive/MyDrive/Object Detection/Image Colorization'
%ls

/content/drive/.shortcut-targets-by-id/1l9lgKWCgTf4rXpuIuwGAHyEDtRWf-MtR/Object Detection/Image Colorization
 BasicModel-epoch-44.pt
 basic_model.py
 [0m[01;34mbasicNet_files[0m/
 basicNet.html
 ColorizationModel-epoch-50.pt
 colorize_data.py
'evegaCopy of Image_Colorization_RGB.ipynb'
 [01;34mimages[0m/
 [01;34mlandscape_images[0m/
 landscape_images.zip
 [01;34moutputs[0m/
 [01;34m__pycache__[0m/
 rename.py
 ResNetUNetModel-epoch-after-26-24.pt
'swatibCopy of Image_Colorization_RGB.ipynb'
 unet_diagram.webp
 UNetModel-epoch-34.pt
 UnetResnet_diagram.png
'(Use_me_1)_Image_Colorization_RGB.ipynb'
'(Use_this_2)Image_Colorization_RGB.ipynb'


In [None]:
#!unzip '//content//drive//My Drive//442 Project//Image Colorisation//landscape_images.zip' -d '//content//drive//My Drive//442 Project//Image Colorisation//images'


Import the necessary modules 

In [None]:
import torch
import os
from torch.utils.data import Dataset, DataLoader
import numpy as np
from colorize_data import *
from torch.nn import *
import torch
import torch.nn.functional as F
import torchvision.models as models
from torchvision import datasets, transforms
import os, shutil, time
from collections import OrderedDict
import torch.nn as nn

# Preparing the data for training 
 
1. Load data from the image folder and create train and val dataloaders

In [None]:
# Folder that contains the train/ and val/ image splits
DATA_ROOT = '/content/drive/MyDrive/Object Detection/Image Colorization/landscape_images/landscape_images'
BATCH_SIZE = 32

# Train dataloader: samples are reshuffled every epoch
training_data = ColorizeData(os.path.join(DATA_ROOT, 'train'))
train_dataloader = DataLoader(training_data, batch_size=BATCH_SIZE, shuffle=True)


# Validation dataloader: kept in a fixed order so the per-batch losses printed
# during validation are comparable from one epoch (and one run) to the next
validation_data = ColorizeData(os.path.join(DATA_ROOT, 'val'))
val_dataloader = DataLoader(validation_data, batch_size=BATCH_SIZE, shuffle=False)


## Define the model 

#### 1. Basic Model 

In [None]:
class BasicNet(nn.Module):
    """Small convolutional encoder-decoder mapping a 1-channel image to a 3-channel image.

    Three stride-2 convolutions (1 -> 32 -> 64 -> 128 channels) each halve the
    spatial size for even inputs, two stride-1 convolutions work on the
    128-channel features, and three stride-2 transposed convolutions
    (128 -> 64 -> 32 -> 3 channels) each double the spatial size again.
    The last layer has no activation, so the output is unbounded.
    """

    def __init__(self, d=128):
        # `d` is part of the constructor signature but is not read anywhere in this class.
        super().__init__()

        # Keyword sets shared by the resolution-changing and resolution-preserving layers
        stride2 = dict(kernel_size=4, stride=2, padding=1)
        stride1 = dict(kernel_size=3, stride=1, padding=1)

        # Encoder
        self.conv1 = nn.Conv2d(1, 32, **stride2)
        self.conv2 = nn.Conv2d(32, 64, **stride2)
        self.conv3 = nn.Conv2d(64, 128, **stride2)
        self.batchnorm1 = nn.BatchNorm2d(128)

        # Bottleneck
        self.conv4 = nn.Conv2d(128, 128, **stride1)
        self.conv5 = nn.Conv2d(128, 128, **stride1)

        # Decoder
        self.convTrans1 = nn.ConvTranspose2d(128, 64, **stride2)
        self.batchnorm2 = nn.BatchNorm2d(64)

        self.convTrans2 = nn.ConvTranspose2d(64, 32, **stride2)
        self.batchnorm3 = nn.BatchNorm2d(32)

        self.convTrans3 = nn.ConvTranspose2d(32, 3, **stride2)

    def forward(self, input):
        """Run the encoder-decoder on `input` and return the 3-channel prediction."""
        features = input

        # Downsampling convolutions, each followed by ReLU
        for encode in (self.conv1, self.conv2, self.conv3):
            features = F.relu(encode(features))

        # batchnorm1 is applied to the output of conv4 (not conv3)
        features = F.relu(self.batchnorm1(self.conv4(features)))
        features = F.relu(self.conv5(features))

        # Upsampling: transposed convolution -> batch norm -> ReLU
        decoder_stages = (
            (self.convTrans1, self.batchnorm2),
            (self.convTrans2, self.batchnorm3),
        )
        for deconv, norm in decoder_stages:
            features = F.relu(norm(deconv(features)))

        return self.convTrans3(features)


#### 2. Colorization Model with Resnet Encoder

In [None]:
import torch.nn as nn
import torchvision.models as models


class ColorizationModel(nn.Module):
  """Grayscale-to-colour network: a truncated ResNet-18 encoder plus a convolutional upsampling decoder.

  The encoder is the first six child modules of a torchvision ResNet-18 whose
  first convolution is rewritten to take a single input channel. The decoder
  consumes 128 feature channels, contains three 2x upsampling steps (8x in
  total) and ends in a 3-channel convolution with no activation.
  """

  def __init__(self, input_size=256):
    # `input_size` is part of the constructor signature but is not read anywhere in this class.
    super().__init__()

    # Build the backbone and collapse its first-layer filters over the input-channel
    # axis so the layer accepts 1-channel (grayscale) images.
    backbone = models.resnet18(num_classes=100)
    grayscale_filters = backbone.conv1.weight.sum(dim=1, keepdim=True)
    backbone.conv1.weight = nn.Parameter(grayscale_filters)
    self.midlevel_resnet = nn.Sequential(*list(backbone.children())[:6])
    encoder_channels = 128

    def conv3x3(c_in, c_out):
      # 3x3 convolution that keeps the spatial size
      return nn.Conv2d(c_in, c_out, kernel_size=3, stride=1, padding=1)

    # Decoder: the layer order is kept exactly, since it determines the state_dict indices
    self.upsample = nn.Sequential(
      conv3x3(encoder_channels, 128),
      nn.BatchNorm2d(128),
      nn.ReLU(),
      nn.Upsample(scale_factor=2),
      conv3x3(128, 64),
      nn.BatchNorm2d(64),
      nn.ReLU(),
      conv3x3(64, 64),
      nn.BatchNorm2d(64),
      nn.ReLU(),
      nn.Upsample(scale_factor=2),
      conv3x3(64, 32),
      nn.BatchNorm2d(32),
      nn.ReLU(),
      conv3x3(32, 3),
      nn.Upsample(scale_factor=2),
    )

  def forward(self, input):
    """Encode `input` with the truncated ResNet and decode it to a 3-channel image."""
    return self.upsample(self.midlevel_resnet(input))

#### 3. ResNetUNet Model 
Uses the convolutional stages of a pretrained ResNet-18 as the encoder and U-Net-style upsampling with skip connections as the decoder.

In [None]:
import torch.nn as nn
import torchvision.models

# Define the ConvRelu block which does the Sequential operations of Convolution and ReLU
def ConvRelu(in_channels, out_channels, kernel, padding):
  """Return an nn.Sequential made of a Conv2d followed by an in-place ReLU.

  Args:
    in_channels: number of input channels of the convolution.
    out_channels: number of output channels of the convolution.
    kernel: kernel size passed to nn.Conv2d.
    padding: padding passed to nn.Conv2d.
  """
  convolution = nn.Conv2d(in_channels, out_channels, kernel, padding=padding)
  activation = nn.ReLU(inplace=True)
  return nn.Sequential(convolution, activation)


class ResNetUNet(nn.Module):
  """U-Net-style network whose encoder is built from a pretrained torchvision ResNet-18.

  The ResNet children are split into five encoder stages (conv0..conv4). Each
  stage output goes through a 1x1 ConvRelu before being concatenated with the
  bilinearly upsampled decoder features. A separate full-resolution branch
  (convOriginal0/1) is concatenated in at the last step. convOriginal0 is built
  for 3 input channels, and the final 1x1 convolution emits `n_class` channels
  with no activation.
  """

  def __init__(self, n_class):
    super().__init__()

    # The full backbone stays registered as `base_model`; `base_layers` is a plain
    # Python list of its children that the encoder stages below are sliced from.
    self.base_model = torchvision.models.resnet18(pretrained=True)
    self.base_layers = list(self.base_model.children())
    backbone_children = self.base_layers

    # Encoder stages, each followed by the 1x1 ConvRelu applied to its skip connection
    self.conv0 = nn.Sequential(*backbone_children[:3])
    self.convRelu0 = ConvRelu(64, 64, 1, 0)
    self.conv1 = nn.Sequential(*backbone_children[3:5])
    self.convRelu1 = ConvRelu(64, 64, 1, 0)
    self.conv2 = backbone_children[5]
    self.convRelu2 = ConvRelu(128, 128, 1, 0)
    self.conv3 = backbone_children[6]
    self.convRelu3 = ConvRelu(256, 256, 1, 0)
    self.conv4 = backbone_children[7]
    self.convRelu4 = ConvRelu(512, 512, 1, 0)

    # One shared parameter-free 2x bilinear upsampler for every decoder step
    self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

    # Decoder convolutions; input channels = skip channels + upsampled channels
    self.convRelu_up3 = ConvRelu(256 + 512, 512, 3, 1)
    self.convRelu_up2 = ConvRelu(128 + 512, 256, 3, 1)
    self.convRelu_up1 = ConvRelu(64 + 256, 256, 3, 1)
    self.convRelu_up0 = ConvRelu(64 + 256, 128, 3, 1)

    # Full-resolution branch working directly on the network input
    self.convOriginal0 = ConvRelu(3, 64, 3, 1)
    self.convOriginal1 = ConvRelu(64, 64, 3, 1)
    self.convOriginal2 = ConvRelu(64 + 128, 64, 3, 1)

    self.conv_last = nn.Conv2d(64, n_class, 1)

  def forward(self, input):
    """Encode `input`, decode with skip connections, and return `n_class` output channels."""
    # Full-resolution features taken straight from the input
    full_res = self.convOriginal1(self.convOriginal0(input))

    # Encoder pass through the ResNet stages
    enc0 = self.conv0(input)
    enc1 = self.conv1(enc0)
    enc2 = self.conv2(enc1)
    enc3 = self.conv3(enc2)
    enc4 = self.conv4(enc3)

    # Deepest features: 1x1 conv, then start upsampling
    decoded = self.upsample(self.convRelu4(enc4))

    # Each step: concatenate with the matching skip connection, convolve, upsample
    decoded = self.convRelu_up3(torch.cat([decoded, self.convRelu3(enc3)], dim=1))
    decoded = self.upsample(decoded)

    decoded = self.convRelu_up2(torch.cat([decoded, self.convRelu2(enc2)], dim=1))
    decoded = self.upsample(decoded)

    decoded = self.convRelu_up1(torch.cat([decoded, self.convRelu1(enc1)], dim=1))
    decoded = self.upsample(decoded)

    decoded = self.convRelu_up0(torch.cat([decoded, self.convRelu0(enc0)], dim=1))
    decoded = self.upsample(decoded)

    # Merge with the full-resolution branch and project to the output channels
    decoded = self.convOriginal2(torch.cat([decoded, full_res], dim=1))
    return self.conv_last(decoded)

#### 4. UNet Colorization Model

(Adapted from the PyTorch tutorial)

In [None]:

class UNet(nn.Module):
    """U-Net with four encoder levels, a bottleneck, and four decoder levels.

    Every level is a `_block` (two 3x3 conv + batch norm + ReLU groups). The
    encoder halves the resolution with 2x2 max pooling and doubles the channel
    count at each level; the decoder upsamples with stride-2 transposed
    convolutions and concatenates the matching encoder output. A final 1x1
    convolution followed by a sigmoid produces `out_channels` maps in (0, 1).
    """

    def __init__(self, in_channels=1, out_channels=3, init_features=32):
        super().__init__()

        width = init_features
        make_block = UNet._block
        halve = dict(kernel_size=2, stride=2)   # shared by the pooling layers
        double = dict(kernel_size=2, stride=2)  # shared by the transposed convolutions

        # Contracting path
        self.encoder1 = make_block(in_channels, width, name="enc1")
        self.pool1 = nn.MaxPool2d(**halve)
        self.encoder2 = make_block(width, width * 2, name="enc2")
        self.pool2 = nn.MaxPool2d(**halve)
        self.encoder3 = make_block(width * 2, width * 4, name="enc3")
        self.pool3 = nn.MaxPool2d(**halve)
        self.encoder4 = make_block(width * 4, width * 8, name="enc4")
        self.pool4 = nn.MaxPool2d(**halve)

        self.bottleneck = make_block(width * 8, width * 16, name="bottleneck")

        # Expanding path; decoder blocks take twice the channels because of the skip concatenation
        self.upconv4 = nn.ConvTranspose2d(width * 16, width * 8, **double)
        self.decoder4 = make_block((width * 8) * 2, width * 8, name="dec4")
        self.upconv3 = nn.ConvTranspose2d(width * 8, width * 4, **double)
        self.decoder3 = make_block((width * 4) * 2, width * 4, name="dec3")
        self.upconv2 = nn.ConvTranspose2d(width * 4, width * 2, **double)
        self.decoder2 = make_block((width * 2) * 2, width * 2, name="dec2")
        self.upconv1 = nn.ConvTranspose2d(width * 2, width, **double)
        self.decoder1 = make_block(width * 2, width, name="dec1")

        self.conv = nn.Conv2d(in_channels=width, out_channels=out_channels, kernel_size=1)

    def forward(self, x):
        """Return the sigmoid-activated prediction for the batch `x`."""
        # Encoder outputs are kept for the skip connections
        skip1 = self.encoder1(x)
        skip2 = self.encoder2(self.pool1(skip1))
        skip3 = self.encoder3(self.pool2(skip2))
        skip4 = self.encoder4(self.pool3(skip3))

        features = self.bottleneck(self.pool4(skip4))

        # Deepest level first: upsample, concatenate the skip, refine
        decoder_levels = (
            (self.upconv4, skip4, self.decoder4),
            (self.upconv3, skip3, self.decoder3),
            (self.upconv2, skip2, self.decoder2),
            (self.upconv1, skip1, self.decoder1),
        )
        for upconv, skip, decoder in decoder_levels:
            features = decoder(torch.cat((upconv(features), skip), dim=1))

        return torch.sigmoid(self.conv(features))

    @staticmethod
    def _block(in_channels, features, name):
        """Build two (3x3 conv without bias -> batch norm -> in-place ReLU) groups.

        Layers are keyed `<name>conv1`, `<name>norm1`, `<name>relu1`, then the
        same with suffix 2; these keys become part of the state_dict names.
        """
        layers = OrderedDict()
        channels_in = in_channels
        for suffix in ("1", "2"):
            layers[name + "conv" + suffix] = nn.Conv2d(
                in_channels=channels_in,
                out_channels=features,
                kernel_size=3,
                padding=1,
                bias=False,
            )
            layers[name + "norm" + suffix] = nn.BatchNorm2d(num_features=features)
            layers[name + "relu" + suffix] = nn.ReLU(inplace=True)
            # The second convolution maps features -> features
            channels_in = features
        return nn.Sequential(layers)

# Set the hyperparameters

In [None]:
# True when a CUDA device is available; read by the training, validation and test code
use_gpu = torch.cuda.is_available()

# Model to train
model = BasicNet()

# Loss and optimiser
criterion = nn.MSELoss()  # Change this to try different loss functions
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.01, weight_decay=0.0)

# Epoch budget, and the initial "best" validation loss (any real loss will be lower)
num_epochs, best_loss = 50, 1e10


In [None]:
# Print a torchsummary report of the model for a single 1 x 256 x 256 input.
# The model is passed through .cpu() for this call.
from torchsummary import summary
summary(model.cpu(), input_size=(1, 256, 256))

## Putting it all together 

1. Define the running-average meter used to track the loss and timings during training
(adapted from the PyTorch ImageNet example)

2. Define the training function

3. Define the validation function

In [None]:
# A handy class from the PyTorch ImageNet tutorial
class AverageMeter(object):
  """Keeps the most recent value and a running weighted average of a scalar.

  Attributes:
    val: the value passed to the latest `update` call.
    sum: the weighted sum of all values seen since the last reset.
    count: the total weight seen since the last reset.
    avg: `sum / count`.
  """

  def __init__(self):
    self.reset()

  def reset(self):
    """Set every statistic back to zero."""
    self.val = 0
    self.avg = 0
    self.sum = 0
    self.count = 0

  def update(self, val, n=1):
    """Record `val` with weight `n` (e.g. a batch mean and its batch size)."""
    weighted = val * n
    self.val = val
    self.sum += weighted
    self.count += n
    self.avg = self.sum / self.count
    

# Define the validation function 
def validation(val_loader, model, criterion):
    """Evaluate `model` on every batch of `val_loader` and return the average loss.

    The model is switched to eval mode and the whole pass runs with autograd
    disabled, so the function is safe to call without an outer
    `torch.no_grad()` block. When the module-level `use_gpu` flag is true, the
    model, the criterion and every batch are moved to the GPU.

    Args:
        val_loader: iterable with a length that yields (input_image, target_image) batches.
        model: the network to evaluate.
        criterion: loss module called as `criterion(predicted, target_image)`.

    Returns:
        The batch-size-weighted average of the per-batch loss (0 for an empty loader).
    """
    # Change the model to the eval mode during validation
    model.eval()

    # Move model and criterion once, instead of once per batch
    if use_gpu:
        model = model.cuda()
        criterion = criterion.cuda()

    # Prepare value counters and timers
    batch_time, data_time, losses = AverageMeter(), AverageMeter(), AverageMeter()
    end = time.time()

    # No gradients are needed for evaluation
    with torch.no_grad():
        for i, (input_image, target_image) in enumerate(val_loader):

            # Record time to load data
            data_time.update(time.time() - end)

            if use_gpu:
                input_image, target_image = input_image.cuda(), target_image.cuda()

            # Forward pass and loss
            predicted = model(input_image)
            loss = criterion(predicted, target_image)
            losses.update(loss.item(), input_image.size(0))

            # Record time to do forward pass
            batch_time.update(time.time() - end)
            end = time.time()

            # Print the running loss every 25 batches
            if i % 25 == 0:
                print('Validate: [{0}/{1}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'.format(i, len(val_loader),
                                                                      batch_time=batch_time, loss=losses))

    print('Completed Validation Step')
    return losses.avg



# Define the train function 
def training(train_loader, model, criterion, optimizer, epoch):
    """Train `model` for one pass over `train_loader` and return the average loss.

    When the module-level `use_gpu` flag is true, the model, the criterion and
    every batch are moved to the GPU.

    Args:
        train_loader: iterable with a length that yields (input_image, target_image) batches.
        model: the network to train; it is put into train mode.
        criterion: loss module called as `criterion(predicted, target_image)`.
        optimizer: optimizer already bound to the model's parameters.
        epoch: epoch index, used only in the progress messages.

    Returns:
        The batch-size-weighted average of the per-batch loss (0 for an empty loader).
    """
    print('Training epoch {}'.format(epoch))

    # Set the model to train mode
    model.train()

    # Move model and criterion once, instead of once per batch
    if use_gpu:
        model = model.cuda()
        criterion = criterion.cuda()

    # Prepare value counters and timers
    batch_time, data_time, losses = AverageMeter(), AverageMeter(), AverageMeter()
    end = time.time()

    for i, (input_image, target_image) in enumerate(train_loader):

        if use_gpu:
            input_image, target_image = input_image.cuda(), target_image.cuda()

        # Record time to load data (includes the host-to-device copy above)
        data_time.update(time.time() - end)

        # Forward pass and loss
        predicted = model(input_image)
        loss = criterion(predicted, target_image)
        losses.update(loss.item(), input_image.size(0))

        # Compute gradient and optimize in backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Record time to do forward and backward passes
        batch_time.update(time.time() - end)
        end = time.time()

        # Print the running loss every 25 batches
        if i % 25 == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'.format(
                      epoch, i, len(train_loader), batch_time=batch_time,
                      data_time=data_time, loss=losses))

    print('Completed training epoch {}'.format(epoch))

    return losses.avg
    

# Start the training process 

1. Begin training the model using the hyperparameters defined above and see model performance on validation data

In [None]:
train_losses = []
val_losses = []

# Checkpoints are written below ./models; create the folder up front because
# torch.save does not create missing parent directories.
os.makedirs('models', exist_ok=True)

# Start the process of training model
for epoch in range(num_epochs):

    # Train the model for one epoch, then call validation to track model performance
    trainLoss = training(train_dataloader, model, criterion, optimizer, epoch)

    with torch.no_grad():
      valLoss = validation(val_dataloader, model, criterion)

    # Save best model (lowest validation loss seen so far)
    if valLoss < best_loss:
      best_loss = valLoss
      torch.save(model.state_dict(), 'models/BasicModel-epoch-{}.pt'.format(epoch+1))

    # Save the train and val loss
    train_losses.append(trainLoss)
    val_losses.append(valLoss)
      
      

# Test model performance and Visualise results

In [None]:
# Plot the train loss and val loss curves 

import matplotlib.pyplot as plt
from PIL import Image
import numpy as np
from skimage import color, io
import torchvision.utils
import matplotlib.pyplot as plt
plt.ion()

fig = plt.figure(figsize=(15, 5))
plt.plot(train_losses, label='Train Loss')
plt.plot(val_losses, label='Val loss')

plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.title('Training and validation loss per epoch')
# The label= arguments above are only shown once a legend is drawn
plt.legend()
plt.show()

NameError: ignored

<Figure size 1080x360 with 0 Axes>

In [None]:
import matplotlib.pyplot as plt

# Model Path
model_path = '/content/drive/MyDrive/Object Detection/Image Colorization/ResNetUNetModel-epoch-after-26-24.pt'

# Load Model and set to evaluation mode
# NOTE(review): ResNetUNet's first layers are built for 3-channel input — confirm that
# ColorizeData yields 3-channel inputs for this model.
model = ResNetUNet(3)

model.load_state_dict(torch.load(model_path,map_location=torch.device('cpu')))
model.eval()

# Move the model to the GPU once, not once per image
if use_gpu:
    model = model.cuda()

# Define the test path
test_path = "/content/drive/MyDrive/Object Detection/Image Colorization/landscape_images/landscape_images/test_eleazar"

# Test Dataloader
test_data = ColorizeData(test_path)
test_dataloader = DataLoader(test_data, batch_size=1, shuffle=True)

# One meter for the whole test set: creating it inside the loop would reset it for
# every image, so the value printed at the end would cover only the last image.
losses = AverageMeter()

for i, (input_image, target_image) in enumerate(test_dataloader):

    with torch.no_grad():

        if use_gpu:
          input_image = input_image.cuda()
          target_image = target_image.cuda()

        predicted = model(input_image)

        # Calculate the loss
        loss = criterion(predicted, target_image)
        losses.update(loss.item(), input_image.size(0))

    # plot images
    fig, ax = plt.subplots(figsize=(15, 15), nrows=1, ncols=2)

    # First (only) item of the batch, moved to the CPU as a channels-first array
    predicted = predicted[0,:,:,:].cpu().numpy()
    target_image = target_image[0,:,:,:].cpu().numpy()

    # imshow expects channels last, hence the transpose
    ax[0].imshow(np.transpose(predicted, (1,2,0)))
    ax[0].title.set_text('Predicted')
    ax[1].imshow(np.transpose(target_image, (1,2,0)))
    ax[1].title.set_text('Ground Truth')
    plt.show()

# Average loss over all test images
print('Loss: %f' % (losses.avg))

RuntimeError: ignored

In [None]:
# Peak signal-to-noise ratio from the mean squared error: PSNR = 10 * log10(MAX^2 / MSE)
# NOTE(review): MAX is kept at 255 as in the original. If ColorizeData yields images
# scaled to [0, 1], the peak value should be 1.0 instead — confirm against colorize_data.py.
PEAK_VALUE = 255
mse = losses.avg
# A zero MSE means identical images; report infinity instead of dividing by zero
PSNR = float('inf') if mse == 0 else 10 * np.log10(PEAK_VALUE**2 / mse)
print(PSNR)



58.137399765304274


In [None]:
!pip install torchviz

Collecting torchviz
  Downloading torchviz-0.0.2.tar.gz (4.9 kB)
Building wheels for collected packages: torchviz
  Building wheel for torchviz (setup.py) ... [?25l[?25hdone
  Created wheel for torchviz: filename=torchviz-0.0.2-py3-none-any.whl size=4150 sha256=1f9e15c253fda8324ab1c763f850b219f5af71bdb268d16729090813f85ce143
  Stored in directory: /root/.cache/pip/wheels/04/38/f5/dc4f85c3909051823df49901e72015d2d750bd26b086480ec2
Successfully built torchviz
Installing collected packages: torchviz
Successfully installed torchviz-0.0.2


In [None]:
import torchviz 
from torchviz import *
from torchviz import make_dot, make_dot_from_trace

# Random single-channel 256x256 batch, used only to trace the autograd graph
x = torch.randn(1, 1, 256, 256)

model = BasicNet()

# render() appends the format extension itself, so pass the bare name:
# "basic.png" here would be written as "basic.png.png"
make_dot(model(x), params=dict(model.named_parameters())).render("basic",format="png")


'basic.png.png'

In [None]:
# Random single-channel 256x256 batch, used only to trace the autograd graph
x = torch.randn(1, 1, 256, 256)

model = ColorizationModel()

# render() appends the format extension itself, so pass the bare name:
# "colorization.png" here would be written as "colorization.png.png"
make_dot(model(x), params=dict(model.named_parameters())).render("colorization",format="png")

'colorization.png.png'

In [None]:
# Random 3-channel 256x256 batch, used only to trace the autograd graph
x = torch.randn(1, 3, 256, 256)

model = ResNetUNet(3)

# Build the graph of the forward pass, then write it out as a PNG
named_params = dict(model.named_parameters())
graph = make_dot(model(x), params=named_params)
graph.render("resnet", format="png")

'resnet.png'

In [None]:
# Random single-channel 256x256 batch, used only to trace the autograd graph
x = torch.randn(1, 1, 256, 256)

model = UNet()

# render() appends the format extension itself, so pass the bare name:
# "unet.png" here would be written as "unet.png.png"
make_dot(model(x), params=dict(model.named_parameters())).render("unet",format="png")

'unet.png.png'