# Get Dataset from Google Drive
Please upload the dataset archive `Multimedia_dataset.zip` to the root of your Google Drive (`MyDrive`) first.

In [None]:
from google.colab import drive

# Mount Google Drive so the dataset archive and model checkpoints are reachable.
DRIVE_MOUNT_POINT = '/content/drive/'
drive.mount(DRIVE_MOUNT_POINT)

Drive already mounted at /content/drive/; to attempt to forcibly remount, call drive.mount("/content/drive/", force_remount=True).


In [None]:
import os
import zipfile
import tqdm

import shutil

# Dataset archive uploaded to the root of Google Drive.
file_name = "Multimedia_dataset.zip"
drive_dir = "/content/drive/MyDrive"
zip_path = os.path.join(drive_dir, file_name)

# Copy the archive into the working directory, extract it there, and always
# delete the local copy afterwards. Done in Python rather than shell magics so
# a failed copy/extract raises instead of being printed and ignored.
shutil.copy(zip_path, file_name)
try:
  with zipfile.ZipFile(file_name) as archive:
    archive.extractall(".")
finally:
  os.remove(file_name)

# Color Hint Transform
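To change how many color hints are given, edit the threshold values in `ColorHintTransform.__call__`. A pixel becomes a hint when a uniform random draw in [0, 1) exceeds the chosen threshold, so higher thresholds give fewer hints (e.g. 0.99 keeps about 1% of the pixels).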

In [None]:
import torch
from torch.autograd import Variable
from torchvision import transforms

import cv2
import random
import numpy as np

class ColorHintTransform(object):
  """Build colorization input/target tensors from a BGR image.

  ``mode`` selects what ``__call__`` does:

  * "training" / "validation": resize to ``size`` x ``size``, draw a random
    sparse hint mask and return ``(l, ab, ab_hint)`` tensors.
  * "testing": resize to ``size`` x ``size``, take the hint mask from the
    supplied mask image and return ``(l, ab_hint)`` tensors.

  To change how many hints are given, edit ``threshold`` in ``__call__``: a
  pixel becomes a hint when a uniform draw in [0, 1) exceeds the threshold,
  so 0.99 keeps roughly 1% of the pixels.
  """

  def __init__(self, size=256, mode="training"):
    super(ColorHintTransform, self).__init__()
    self.size = size
    self.mode = mode
    self.transform = transforms.Compose([transforms.ToTensor()])

  def bgr_to_lab(self, img):
    """Convert a BGR image to Lab and return it split into (L, ab) arrays."""
    lab = cv2.cvtColor(img, cv2.COLOR_BGR2LAB)
    l, ab = lab[:, :, 0], lab[:, :, 1:]
    return l, ab

  def hint_mask(self, bgr, threshold=(0.95, 0.97, 0.99)):
    """Return a random boolean hint mask of shape (H, W, 1) for ``bgr``.

    One value is chosen from ``threshold`` at random; each pixel is then kept
    independently when a uniform draw in [0, 1) exceeds that value.
    """
    h, w = bgr.shape[:2]
    mask_threshold = random.choice(threshold)
    return np.random.random([h, w, 1]) > mask_threshold

  def img_to_mask(self, mask_img):
    """Return an (H, W, 1) boolean mask that is True where channel 0 is >= 255."""
    return mask_img[:, :, 0, np.newaxis] >= 255

  def __call__(self, img, mask_img=None):
    """Transform ``img`` according to ``self.mode``.

    Args:
      img: BGR image array (as read by ``cv2.imread``).
      mask_img: mask image; required in "testing" mode, ignored otherwise.

    Returns:
      ``(l, ab, ab_hint)`` in training/validation mode, ``(l, ab_hint)`` in
      testing mode.

    Raises:
      ValueError: if ``mask_img`` is missing in testing mode.
      NotImplementedError: for any other mode.
    """
    threshold = (0.95, 0.97, 0.99)
    if self.mode in ("training", "validation"):
      image = cv2.resize(img, (self.size, self.size))
      mask = self.hint_mask(image, threshold)

      # Zero every non-hint pixel; only the kept pixels contribute ab hints.
      hint_image = image * mask

      l, ab = self.bgr_to_lab(image)
      _, ab_hint = self.bgr_to_lab(hint_image)

      return self.transform(l), self.transform(ab), self.transform(ab_hint)

    if self.mode == "testing":
      if mask_img is None:
        raise ValueError("mask_img is required in testing mode")
      image = cv2.resize(img, (self.size, self.size))
      # The mask must match the resized image; nearest-neighbour keeps it binary
      # (and is a no-op when the mask is already size x size).
      mask_img = cv2.resize(mask_img, (self.size, self.size),
                            interpolation=cv2.INTER_NEAREST)
      hint_image = image * self.img_to_mask(mask_img)

      l, _ = self.bgr_to_lab(image)
      _, ab_hint = self.bgr_to_lab(hint_image)

      return self.transform(l), self.transform(ab_hint)

    # Previously this *returned* the exception class, silently yielding a
    # non-tensor sample; raise it so misuse fails loudly.
    raise NotImplementedError("unsupported mode: %r" % (self.mode,))

# Dataloader for Colorization Dataset

In [None]:
import torch
import torch.utils.data  as data
import os
import cv2
from google.colab.patches import cv2_imshow

class ColorHintDataset(data.Dataset):
  """Image dataset for hint-based colorization.

  Call ``set_mode`` with "training", "validation" or "testing" before use.
  Training/validation read every file in ``<root_path>/train`` or
  ``<root_path>/validation``. Testing pairs the files in ``<root_path>/hint``
  with those in ``<root_path>/mask`` by sorted file-name order (this assumes
  both folders hold matching names - TODO confirm for new test sets).
  """

  def __init__(self, root_path, size):
    super(ColorHintDataset, self).__init__()

    self.root_path = root_path
    self.size = size
    self.mode = None  # set by set_mode(); checked by __len__
    self.transforms = None
    self.examples = None
    self.hint = None
    self.mask = None

  def _list_sorted(self, subdir):
    """Return the full paths of the entries in ``root_path/subdir``, sorted by name."""
    directory = os.path.join(self.root_path, subdir)
    # os.listdir order is arbitrary; sorting makes epochs reproducible and keeps
    # hint/mask lists aligned index-by-index.
    return [os.path.join(directory, name) for name in sorted(os.listdir(directory))]

  @staticmethod
  def _read_image(path):
    """Read an image with OpenCV, raising instead of returning None on failure."""
    img = cv2.imread(path)
    if img is None:
      raise FileNotFoundError("could not read image: %s" % path)
    return img

  def set_mode(self, mode):
    """Select the split and build the matching ColorHintTransform.

    Raises:
      ValueError: in testing mode, if hint/ and mask/ hold different counts.
      NotImplementedError: for an unknown mode.
    """
    self.mode = mode
    self.transforms = ColorHintTransform(self.size, mode)
    if mode == "training":
      self.examples = self._list_sorted("train")
    elif mode == "validation":
      self.examples = self._list_sorted("validation")
    elif mode == "testing":
      self.hint = self._list_sorted("hint")
      self.mask = self._list_sorted("mask")
      if len(self.hint) != len(self.mask):
        raise ValueError("hint/ has %d files but mask/ has %d"
                         % (len(self.hint), len(self.mask)))
    else:
      raise NotImplementedError

  def __len__(self):
    if self.mode is None:
      raise RuntimeError("call set_mode() before using the dataset")
    if self.mode == "testing":
      return len(self.hint)
    return len(self.examples)

  def __getitem__(self, idx):
    """Return a sample dict.

    Testing: {"l", "hint", "file_name"}; otherwise {"l", "ab", "hint"}.
    """
    if self.mode == "testing":
      hint_file_name = self.hint[idx]
      hint_img = self._read_image(hint_file_name)
      mask_img = self._read_image(self.mask[idx])

      input_l, input_hint = self.transforms(hint_img, mask_img)
      # Output name is the zero-padded integer stem of the hint file,
      # e.g. "7.png" -> "image_000007.png".
      stem = os.path.basename(hint_file_name).split('.')[0]
      return {"l": input_l, "hint": input_hint,
              "file_name": "image_%06d.png" % int(stem)}

    img = self._read_image(self.examples[idx])
    l, ab, hint = self.transforms(img)
    return {"l": l, "ab": ab, "hint": hint}

# Network
An example U-Net model.

In [None]:
""" Parts of the U-Net model """

import torch
import torch.nn as nn
import torch.nn.functional as F


class DoubleConv(nn.Module):
    """Two stacked (3x3 conv -> BatchNorm -> ReLU) stages.

    The 3x3 kernels use padding 1, so the spatial size is preserved.
    ``mid_channels`` (default: ``out_channels``) is the width between stages.
    """

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        hidden = mid_channels or out_channels
        layers = []
        for c_in, c_out in ((in_channels, hidden), (hidden, out_channels)):
            layers += [
                nn.Conv2d(c_in, c_out, kernel_size=3, padding=1),
                nn.BatchNorm2d(c_out),
                nn.ReLU(inplace=True),
            ]
        self.double_conv = nn.Sequential(*layers)

    def forward(self, x):
        out = self.double_conv(x)
        return out


class Down(nn.Module):
    """Halve the spatial resolution with 2x2 max-pooling, then apply DoubleConv."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        pool = nn.MaxPool2d(2)
        convs = DoubleConv(in_channels, out_channels)
        self.maxpool_conv = nn.Sequential(pool, convs)

    def forward(self, x):
        out = self.maxpool_conv(x)
        return out


class Up(nn.Module):
    """Upsample ``x1`` by 2, pad it to ``x2``'s size, concat ``[x2, x1]``, then DoubleConv."""

    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()
        half = in_channels // 2
        if not bilinear:
            # Learnable upsampling that also halves the channel count.
            self.up = nn.ConvTranspose2d(in_channels, half, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_channels, out_channels)
            return
        # Bilinear upsampling keeps the channel count; DoubleConv's middle
        # width is set to half the input channels instead.
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.conv = DoubleConv(in_channels, out_channels, half)

    def forward(self, x1, x2):
        upsampled = self.up(x1)
        # Tensors are (N, C, H, W); pad the upsampled map up to x2's H and W.
        pad_h = x2.size(2) - upsampled.size(2)
        pad_w = x2.size(3) - upsampled.size(3)
        # F.pad takes the last dimension first: (left, right, top, bottom).
        upsampled = F.pad(upsampled, [pad_w // 2, pad_w - pad_w // 2,
                                      pad_h // 2, pad_h - pad_h // 2])
        merged = torch.cat([x2, upsampled], dim=1)
        return self.conv(merged)


class OutConv(nn.Module):
    """1x1 convolution mapping feature channels to the output channels."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, 1)

    def forward(self, x):
        out = self.conv(x)
        return out

In [None]:
""" Full assembly of the parts to form the complete network """
class UNet(nn.Module):
    """U-Net with a DoubleConv stem, four Down stages, four Up stages and a 1x1 head.

    Args:
        n_channels: number of input channels.
        n_classes: number of output channels.
        bilinear: use bilinear upsampling (True) or transposed convolutions.
            With bilinear upsampling the deepest and decoder widths are halved,
            because Up does not reduce channels itself in that mode.

    Up pads the upsampled map whenever its size differs from the skip
    connection (e.g. inputs whose sides are not divisible by 16).
    """

    def __init__(self, n_channels, n_classes, bilinear=True):
        super(UNet, self).__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear

        f = 2 if bilinear else 1
        self.inc = DoubleConv(n_channels, 64)
        self.down1, self.down2, self.down3, self.down4 = (
            Down(64, 128), Down(128, 256), Down(256, 512), Down(512, 1024 // f))
        self.up1, self.up2, self.up3, self.up4 = (
            Up(1024, 512 // f, bilinear), Up(512, 256 // f, bilinear),
            Up(256, 128 // f, bilinear), Up(128, 64, bilinear))
        self.outc = OutConv(64, n_classes)

    def forward(self, x):
        skip1 = self.inc(x)
        skip2 = self.down1(skip1)
        skip3 = self.down2(skip2)
        skip4 = self.down3(skip3)
        out = self.down4(skip4)
        # Decoder: each Up merges with the encoder output of matching resolution.
        decoder = ((self.up1, skip4), (self.up2, skip3),
                   (self.up3, skip2), (self.up4, skip1))
        for up, skip in decoder:
            out = up(out, skip)
        return self.outc(out)

# Tensorboard
Visualizes training progress. Run these cells before the training phase.

In [None]:
# Install tensorboardX, which provides the SummaryWriter used by the training cell.
!pip install tensorboardX

Collecting tensorboardX
[?25l  Downloading https://files.pythonhosted.org/packages/07/84/46421bd3e0e89a92682b1a38b40efc22dafb6d8e3d947e4ceefd4a5fabc7/tensorboardX-2.2-py2.py3-none-any.whl (120kB)
[K     |██▊                             | 10kB 21.2MB/s eta 0:00:01[K     |█████▍                          | 20kB 27.3MB/s eta 0:00:01[K     |████████▏                       | 30kB 21.0MB/s eta 0:00:01[K     |██████████▉                     | 40kB 16.5MB/s eta 0:00:01[K     |█████████████▋                  | 51kB 9.9MB/s eta 0:00:01[K     |████████████████▎               | 61kB 10.1MB/s eta 0:00:01[K     |███████████████████             | 71kB 9.9MB/s eta 0:00:01[K     |█████████████████████▊          | 81kB 11.0MB/s eta 0:00:01[K     |████████████████████████▌       | 92kB 11.3MB/s eta 0:00:01[K     |███████████████████████████▏    | 102kB 9.5MB/s eta 0:00:01[K     |██████████████████████████████  | 112kB 9.5MB/s eta 0:00:01[K     |████████████████████████████████| 12

In [None]:
# Load the TensorBoard notebook extension. %reload_ext is used instead of the
# commented-out %load_ext, presumably so re-running this cell does not fail.
# %load_ext tensorboard
%reload_ext tensorboard

In [None]:
# Show TensorBoard for "logs", the directory SummaryWriter("logs") writes to during training.
%tensorboard --logdir logs

# Training Phase
U-Net training code.

In [None]:
import torch
import torch.nn as nn
import torch.utils.data  as data

from torchvision import transforms
from tensorboardX import SummaryWriter
import torchvision.utils as tvutils

import os
import matplotlib.pyplot as plt
import numpy as np
import tqdm.notebook as tq
from PIL import Image
from skimage.measure.simple_metrics import compare_psnr

def batch_PSNR(img, imclean, data_range):
    """Return the mean PSNR (dB) over a batch of image pairs.

    Computed directly with numpy (10 * log10(data_range**2 / MSE) per image),
    so it no longer relies on ``skimage.measure.simple_metrics.compare_psnr``,
    which current scikit-image releases have removed.

    Args:
        img: tensor whose first dimension indexes the images in the batch.
        imclean: tensor of the same shape (PSNR is symmetric in the pair).
        data_range: max minus min possible pixel value (1.0 for [0, 1] images).

    Returns:
        The per-image PSNR averaged over the batch, as a float; ``inf`` when
        every pair matches exactly.

    Raises:
        ValueError: if the shapes differ or the batch is empty.
    """
    Img = img.data.cpu().numpy().astype(np.float64)
    Iclean = imclean.data.cpu().numpy().astype(np.float64)
    if Img.shape != Iclean.shape:
        raise ValueError("shape mismatch: %s vs %s" % (Img.shape, Iclean.shape))
    if Img.shape[0] == 0:
        raise ValueError("cannot compute PSNR of an empty batch")
    # Mean squared error of each image over all of its channels and pixels.
    mse = ((Img - Iclean) ** 2).reshape(Img.shape[0], -1).mean(axis=1)
    # An identical pair has MSE 0, i.e. infinite PSNR; silence the warning.
    with np.errstate(divide="ignore"):
        psnr = 10.0 * np.log10((data_range ** 2) / mse)
    return float(psnr.mean())

# Change to your data root directory
root_path = "/content/"
save_path = "/content/drive/MyDrive/Colorization_models"

# Depend on runtime setting. Falls back to the CPU when CUDA is unavailable,
# instead of leaving the batch tensors undefined.
use_cuda = True
device = torch.device("cuda" if use_cuda and torch.cuda.is_available() else "cpu")

num_epochs = 30
log_interval = 100  # training steps between TensorBoard logs

# Dataloader setting
train_dataset = ColorHintDataset(root_path, 128)
train_dataset.set_mode("training")

val_dataset = ColorHintDataset(root_path, 128)
val_dataset.set_mode("validation")

train_dataloader = data.DataLoader(train_dataset, batch_size=4, shuffle=True)
val_dataloader = data.DataLoader(val_dataset, batch_size=4, shuffle=True)

# Model declaration: 2 input channels (ab hint) -> 2 output channels (ab).
# NOTE(review): the network receives only the sparse ab hint, never the L
# (lightness) channel, so it cannot see the grayscale image it is colorizing.
# Feeding torch.cat((l, hint), dim=1) into UNet(3, 2) is likely intended, but
# that changes the checkpoint format the testing cell loads - confirm first.
net = UNet(2, 2)
model = nn.DataParallel(net).to(device)

# Summed squared error; reduction="sum" replaces the deprecated size_average=False.
criterion = nn.MSELoss(reduction="sum").to(device)

# optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001, eps=1e-08)

# Checkpoints are written here after every epoch; make sure the folder exists.
os.makedirs(save_path, exist_ok=True)

step = 0
# Tensorboard writer
writer = SummaryWriter("logs")

for epoch in range(num_epochs):
  print('Epoch {}/{}'.format(epoch + 1, num_epochs))
  print('-' * 10)

  # The previous epoch's validation switched the model to eval mode.
  model.train()
  # `batch`, not `data`, so the torch.utils.data module is not shadowed.
  for i, batch in enumerate(tq.tqdm(train_dataloader)):
    l = batch["l"].to(device)
    ab = batch["ab"].to(device)
    hint = batch["hint"].to(device)

    optimizer.zero_grad()
    preds = model(hint)
    loss = criterion(preds, ab)
    loss.backward()
    optimizer.step()

    if step % log_interval == 0:
      # Reassemble L + ab so ground truth and prediction compare as full images.
      gt_image = torch.cat((l, ab), dim=1)
      pred_image = torch.cat((l, preds), dim=1)
      psnr_train = batch_PSNR(gt_image, pred_image, 1.)

      # Log the scalar values
      writer.add_scalar('loss', loss.item(), step)
      writer.add_scalar('PSNR on training data', psnr_train, step)

      # log the images => Tensorboard, indexed by step like the scalars
      Img = tvutils.make_grid(gt_image.data, nrow=4, normalize=True, scale_each=True)
      Irecon = tvutils.make_grid(pred_image.data, nrow=4, normalize=True, scale_each=True)
      writer.add_image('GT image', Img, step)
      writer.add_image('reconstructed image', Irecon, step)
      print("[epoch %d][%d/%d] loss: %.4f PSNR_train: %.4f" %
            (epoch + 1, i + 1, len(train_dataloader), loss.item(), psnr_train))
    step += 1

  checkpoint_file = os.path.join(save_path, "{}.tar".format(epoch + 1))
  torch.save(model.module.state_dict(), checkpoint_file)
  print("saved at {}".format(checkpoint_file))

  # Validation on training phase
  psnr_val = []
  model.eval()
  with torch.no_grad():
    for j, val_batch in enumerate(tq.tqdm(val_dataloader)):
      l = val_batch["l"].to(device)
      ab = val_batch["ab"].to(device)
      hint = val_batch["hint"].to(device)

      preds = model(hint)

      gt_image = torch.cat((l, ab), dim=1)
      pred_image = torch.cat((l, preds), dim=1)
      psnr_val.append(batch_PSNR(gt_image, pred_image, 1.))

      if j == 0:
        # One sample grid per epoch; GT and prediction share the same step.
        val_Img = tvutils.make_grid(gt_image.data, nrow=4, normalize=True, scale_each=True)
        val_Irecon = tvutils.make_grid(pred_image.data, nrow=4, normalize=True, scale_each=True)
        writer.add_image('validation gt image', val_Img, epoch)
        writer.add_image('validation reconstructed image', val_Irecon, epoch)

  mean_val = np.mean(psnr_val)
  print("\n[epoch %d] PSNR_val: %.4f" % (epoch + 1, mean_val))
  writer.add_scalar('PSNR on validation data', mean_val, epoch)

# Flush any pending TensorBoard events.
writer.close()

In [None]:
import torch
import torch.nn as nn
import torch.utils.data  as data

import os
import cv2
import numpy as np
import tqdm.notebook as tq
from PIL import Image
from skimage.measure.simple_metrics import compare_psnr

def image_save(img, path):
  """Convert a Lab image to BGR and write it to ``path`` with OpenCV.

  Args:
    img: a torch tensor (assumed (3, H, W) L/a/b channels scaled to [0, 1];
      values outside that range are clamped), or an array that is passed
      straight to ``cv2.cvtColor`` (presumably (H, W, 3) uint8 Lab).
    path: destination file; OpenCV picks the format from the extension.

  Raises:
    OSError: if OpenCV reports that the file could not be written.
  """
  if isinstance(img, torch.Tensor):
    # ToPILImage converts float tensors to 8-bit; clamp first so network outputs
    # slightly outside [0, 1] cannot wrap around to the opposite extreme.
    img = np.asarray(transforms.ToPILImage()(img.detach().cpu().clamp(0, 1)))
  img = cv2.cvtColor(img, cv2.COLOR_LAB2BGR)
  # cv2.imwrite signals failure (e.g. missing directory) only via its return value.
  if not cv2.imwrite(path, img):
    raise OSError("failed to write image: %s" % path)

def batch_PSNR(img, imclean, data_range):
    """Return the mean PSNR (dB) over a batch of image pairs.

    Computed directly with numpy (10 * log10(data_range**2 / MSE) per image),
    so it no longer relies on ``skimage.measure.simple_metrics.compare_psnr``,
    which current scikit-image releases have removed.

    Args:
        img: tensor whose first dimension indexes the images in the batch.
        imclean: tensor of the same shape (PSNR is symmetric in the pair).
        data_range: max minus min possible pixel value (1.0 for [0, 1] images).

    Returns:
        The per-image PSNR averaged over the batch, as a float; ``inf`` when
        every pair matches exactly.

    Raises:
        ValueError: if the shapes differ or the batch is empty.
    """
    Img = img.data.cpu().numpy().astype(np.float64)
    Iclean = imclean.data.cpu().numpy().astype(np.float64)
    if Img.shape != Iclean.shape:
        raise ValueError("shape mismatch: %s vs %s" % (Img.shape, Iclean.shape))
    if Img.shape[0] == 0:
        raise ValueError("cannot compute PSNR of an empty batch")
    # Mean squared error of each image over all of its channels and pixels.
    mse = ((Img - Iclean) ** 2).reshape(Img.shape[0], -1).mean(axis=1)
    # An identical pair has MSE 0, i.e. infinite PSNR; silence the warning.
    with np.errstate(divide="ignore"):
        psnr = 10.0 * np.log10((data_range ** 2) / mse)
    return float(psnr.mean())

# Change to your data root directory
image_path = "/content/drive/MyDrive/Multimedia_test_dataset/colorization2/"
checkpoint_path = "/content/drive/MyDrive/Colorization_models/25.tar"
result_save_path = "/content/drive/MyDrive/Multimedia_test_dataset/colorization_test_result"

# Depend on runtime setting. Falls back to the CPU when CUDA is unavailable,
# instead of leaving `l`/`hint` undefined.
use_cuda = True
device = torch.device("cuda" if use_cuda and torch.cuda.is_available() else "cpu")

test_dataset = ColorHintDataset(image_path, 128)
test_dataset.set_mode("testing")

test_dataloader = data.DataLoader(test_dataset, batch_size=1, shuffle=False)

# Must match the architecture used in training (2 input -> 2 output channels).
net = UNet(2, 2)
# map_location lets a checkpoint saved from GPU tensors load on any device.
net.load_state_dict(torch.load(checkpoint_path, map_location=device))
net.to(device)
model = nn.DataParallel(net)

model.eval()

# Output folder must exist or the image writes fail.
os.makedirs(result_save_path, exist_ok=True)

with torch.no_grad():
  # `batch`, not `data`, so the torch.utils.data module is not shadowed.
  for batch in tq.tqdm(test_dataloader):
    l = batch["l"].to(device)
    hint = batch["hint"].to(device)
    file_names = batch["file_name"]

    out_test = model(hint)
    # Prepend the input L channel to the predicted ab channels.
    pred_image = torch.cat((l, out_test), dim=1)
    for idx, name in enumerate(file_names):
      image_save(pred_image[idx], os.path.join(result_save_path, name))
