# Get Dataset from Google Drive  
Upload your dataset to Google Drive first. The loader below expects `train`, `validation`, and `test` subfolders under the dataset root directory.

In [1]:
from google.colab import drive
# Mount Google Drive at /content/drive/ (Colab only); the dataset root_path set
# later in this notebook points inside this mount point.
drive.mount('/content/drive/')

Mounted at /content/drive/


# Color-hint Transform

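To change how many color hints are given, edit the threshold values in the `__call__` method. A pixel becomes a hint when a uniform random number in [0, 1) exceeds the threshold, so a higher threshold gives fewer hints (roughly a fraction 1 − threshold of the pixels; e.g. 0.99 keeps about 1%).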

In [2]:
import torch
from torch.autograd import Variable
from torchvision import transforms

import cv2
import random
import numpy as np

class ColorHintTransform(object):
  """Resize a BGR uint8 image and split it into LAB tensors plus sparse color hints.

  Modes:
    "training" / "validation": ``__call__`` returns ``(l, ab, ab_hint)``, where
      ``ab_hint`` is the ab part of the image after all non-hint pixels were
      set to black in BGR space.
    "testing": ``__call__`` returns ``(l, ab)``.
  Any other mode makes ``__call__`` raise ``NotImplementedError``.
  """

  def __init__(self, size=256, mode="training"):
    super(ColorHintTransform, self).__init__()
    self.size = size   # images are resized to (size, size)
    self.mode = mode
    # ToTensor converts the numpy arrays into float CxHxW tensors.
    self.transform = transforms.Compose([transforms.ToTensor()])

  def bgr_to_lab(self, img):
    """Convert a BGR image to LAB and return ``(L channel, ab channels)``."""
    lab = cv2.cvtColor(img, cv2.COLOR_BGR2LAB)
    l, ab = lab[:, :, 0], lab[:, :, 1:]
    return l, ab

  def hint_mask(self, bgr, threshold=(0.95, 0.97, 0.99)):
    """Return a random boolean hint mask of shape (h, w, 1) for image ``bgr``.

    One value is drawn from ``threshold``; each pixel is marked as a hint when a
    uniform sample in [0, 1) exceeds it, so higher thresholds give sparser hints.
    """
    h, w = bgr.shape[:2]
    mask_threshold = random.choice(threshold)
    # Compare against the single chosen threshold. Comparing against the whole
    # sequence would broadcast to (h, w, len(threshold)) and mask the B, G and R
    # channels independently, producing false colors in the hint image.
    return np.random.random([h, w, 1]) > mask_threshold

  def __call__(self, img):
    """Transform a BGR uint8 image according to ``self.mode``.

    Raises:
      NotImplementedError: if ``self.mode`` is not a supported mode.
    """
    threshold = (0.95, 0.97, 0.99)
    if self.mode in ("training", "validation"):
      image = cv2.resize(img, (self.size, self.size))
      mask = self.hint_mask(image, threshold)

      # The (h, w, 1) boolean mask broadcasts over the color channels, so
      # every non-hint pixel becomes black.
      hint_image = image * mask

      l, ab = self.bgr_to_lab(image)
      _, ab_hint = self.bgr_to_lab(hint_image)

      return self.transform(l), self.transform(ab), self.transform(ab_hint)

    if self.mode == "testing":
      image = cv2.resize(img, (self.size, self.size))
      l, ab = self.bgr_to_lab(image)
      return self.transform(l), self.transform(ab)

    raise NotImplementedError(f"Unsupported mode: {self.mode!r}")


# Dataloader for Colorization Dataset

In [3]:
import torch
import torch.utils.data  as data
import os
import cv2
from google.colab.patches import cv2_imshow
class ColorHintDataset(data.Dataset):
  """Folder-backed dataset yielding LAB tensors (and color hints) per image.

  Images are listed from ``root_path/train``, ``root_path/validation`` or
  ``root_path/test`` depending on the mode passed to ``set_mode``, which must
  be called before the dataset is sized or indexed.
  """

  # Supported modes and the subdirectory of root_path each one reads from.
  _MODE_DIRS = {"training": "train", "validation": "validation", "testing": "test"}

  def __init__(self, root_path, size):
    super(ColorHintDataset, self).__init__()

    self.root_path = root_path
    self.size = size          # side length images are resized to
    self.mode = None          # set by set_mode()
    self.transforms = None    # ColorHintTransform for the current mode
    self.examples = None      # image file paths for the current mode

  def set_mode(self, mode):
    """Select the data split and build the matching ColorHintTransform.

    Raises:
      NotImplementedError: if ``mode`` is not "training", "validation" or
        "testing"; the dataset state is left unchanged in that case.
    """
    if mode not in self._MODE_DIRS:
      raise NotImplementedError(f"Unsupported mode: {mode!r}")
    split_dir = os.path.join(self.root_path, self._MODE_DIRS[mode])
    # os.listdir order is arbitrary; sort so index -> file is reproducible.
    examples = [os.path.join(split_dir, name) for name in sorted(os.listdir(split_dir))]

    # Only update state once the directory listing has succeeded.
    self.mode = mode
    self.transforms = ColorHintTransform(self.size, mode)
    self.examples = examples

  def __len__(self):
    """Number of image files in the current split (requires set_mode first)."""
    return len(self.examples)

  def __getitem__(self, idx):
    """Load image ``idx`` and return a dict of tensors.

    Returns ``{"l", "ab"}`` in testing mode and ``{"l", "ab", "hint"}`` otherwise.

    Raises:
      OSError: if the file cannot be read/decoded as an image.
    """
    file_name = self.examples[idx]
    img = cv2.imread(file_name)
    # cv2.imread reports failure by returning None instead of raising.
    if img is None:
      raise OSError(f"Could not read image file: {file_name}")

    if self.mode == "testing":
      input_l, input_ab = self.transforms(img)
      return {"l": input_l, "ab": input_ab}

    l, ab, hint = self.transforms(img)
    return {"l": l, "ab": ab, "hint": hint}

# Example for Loading

In [None]:
import torch
import torch.utils.data  as data
import os
import cv2
from google.colab.patches import cv2_imshow
import matplotlib.pyplot as plt
from torchvision import transforms
import tqdm
from PIL import Image

def tensor2im(input_image, imtype=np.uint8):
  """Convert the first image of a batched tensor into an (H, W, 3) numpy image.

  Args:
    input_image: tensor such that ``input_image[0]`` is a (C, H, W) image;
      values are clipped to [0, 1]. Anything that is not a ``torch.Tensor`` is
      returned unchanged.
    imtype: dtype of the returned array.

  Returns:
    (H, W, 3) array scaled to [0, 255]; a single-channel image is repeated
    across three channels.
  """
  if not isinstance(input_image, torch.Tensor):
    return input_image
  image_numpy = input_image.detach()[0].cpu().float().numpy()
  if image_numpy.shape[0] == 1:
    image_numpy = np.tile(image_numpy, (3, 1, 1))
  image_numpy = np.clip(np.transpose(image_numpy, (1, 2, 0)), 0, 1) * 255.0
  # Round instead of truncating: data that was divided by 255 in float32 may
  # come back slightly below the original integer after multiplying by 255.
  return np.rint(image_numpy).astype(imtype)

# Change to your data root directory
root_path = "/content/drive/MyDrive/Multimedia_dataset"
# Depend on runtime setting: batches are moved to the GPU only when this is
# True and a CUDA device is actually available; otherwise they stay on the CPU.
use_cuda = True
device = torch.device("cuda" if use_cuda and torch.cuda.is_available() else "cpu")

train_dataset = ColorHintDataset(root_path, 256)
train_dataset.set_mode("training")

train_dataloader = data.DataLoader(train_dataset, batch_size=4, shuffle=True)

# Use ``batch`` rather than ``data`` so the torch.utils.data alias imported
# above is not shadowed by the loop variable.
for batch in tqdm.tqdm(train_dataloader):
  l = batch["l"].to(device)
  ab = batch["ab"].to(device)
  hint = batch["hint"].to(device)

  # Concatenate L with the ground-truth / hint ab channels along dim 1.
  gt_image = torch.cat((l, ab), dim=1)
  hint_image = torch.cat((l, hint), dim=1)

  # tensor2im converts only the first image of the batch (index 0).
  gt_np = tensor2im(gt_image)
  hint_np = tensor2im(hint_image)

  gt_bgr = cv2.cvtColor(gt_np, cv2.COLOR_LAB2BGR)
  hint_bgr = cv2.cvtColor(hint_np, cv2.COLOR_LAB2BGR)

  cv2_imshow(gt_bgr)
  cv2_imshow(hint_bgr)

  # Wait for Enter before showing the next batch.
  input()
