<a href="https://colab.research.google.com/github/dntjr41/CV_TermP/blob/main/TermP.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Computer Vision - Term Project [ConvNet Challenge]

In [None]:
# Mount Google Drive so the dataset archives stored under MyDrive are reachable.

from google.colab import drive

DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

In [None]:

import tqdm
import zipfile
import os

train_file_name = 'colorization_dataset.zip'
test_file_name = 'test_dataset.zip'

train_zip_path = '/content/drive/MyDrive/colorization_dataset.zip'
test_zip_path = '/content/drive/MyDrive/test_dataset.zip'

!cp = "{train_zip_path}" .
!unzip -q '{train_file_name}'
!rm '{train_file_name}'

!cp = "{test_zip_path}" .
!unzip -q '{test_file_name}'
!rm '{test_file_name}'

# Check Dataset & Color Hint (기존에 올라온 코드 [dataloader] — the provided baseline dataloader code)

In [None]:
import os

print(len(os.listdir('./cv_project/train')))
print(len(os.listdir('./cv_project/val')))

print(len(os.listdir('./test_dataset/hint')))
print(len(os.listdir('./test_dataset/mask')))

In [None]:
from torchvision import transforms

import cv2
import random
import numpy as np
import torch
import torch.utils.data as data
import os

class ColorHintTransform(object):
    """Turn a BGR image into L / ab / color-hint tensors.

    In "train" and "val" mode the image is resized to ``size`` x ``size`` and a
    random sparse hint mask is sampled. In "test" mode the hint mask is taken
    from a provided mask image instead. Any other mode raises
    ``NotImplementedError`` when called.
    """

    # Candidate thresholds for the random hint mask. A pixel becomes a hint
    # when a uniform [0, 1) draw exceeds the chosen threshold, so 0.95 keeps
    # about 5% of pixels and 0.99 about 1%.
    DEFAULT_THRESHOLDS = (0.95, 0.97, 0.99)

    def __init__(self, size=256, mode="train"):
        super(ColorHintTransform, self).__init__()
        self.size = size
        self.mode = mode
        self.transform = transforms.Compose([transforms.ToTensor()])

    def bgr_to_lab(self, img):
        """Convert a BGR image to OpenCV LAB and split it into (l, ab).

        ``l`` is the first channel (H, W); ``ab`` holds the remaining two
        channels (H, W, 2).
        """
        lab = cv2.cvtColor(img, cv2.COLOR_BGR2LAB)
        return lab[:, :, 0], lab[:, :, 1:]

    def hint_mask(self, bgr, threshold=DEFAULT_THRESHOLDS):
        """Return a random boolean hint mask of shape (H, W, 1) for ``bgr``.

        One value of ``threshold`` is picked at random. Each pixel is kept when
        a uniform [0, 1) sample exceeds that threshold.
        """
        h, w = bgr.shape[:2]
        mask_threshold = random.choice(threshold)
        return np.random.random([h, w, 1]) > mask_threshold

    def img_to_mask(self, mask_img):
        """Return a boolean (H, W, 1) mask that is True where channel 0 is >= 255."""
        return mask_img[:, :, 0, np.newaxis] >= 255

    def __call__(self, img, mask_img=None):
        """Build the tensors for one image.

        Returns ``(l, ab, ab_hint)`` in "train"/"val" mode and ``(l, ab_hint)``
        in "test" mode.

        Raises:
            ValueError: in "test" mode when ``mask_img`` is None.
            NotImplementedError: for any other mode.
        """
        if self.mode in ("train", "val"):
            image = cv2.resize(img, (self.size, self.size))
            mask = self.hint_mask(image, self.DEFAULT_THRESHOLDS)

            # Zero out every non-hint pixel; broadcasting (H, W, 1) over the
            # colour channels keeps the uint8 dtype.
            hint_image = image * mask

            l, ab = self.bgr_to_lab(image)
            _, ab_hint = self.bgr_to_lab(hint_image)

            return self.transform(l), self.transform(ab), self.transform(ab_hint)

        if self.mode == "test":
            if mask_img is None:
                raise ValueError('mask_img is required in "test" mode')
            image = cv2.resize(img, (self.size, self.size))
            # Resize the mask as well so it always matches the resized image;
            # nearest-neighbour interpolation keeps the mask values binary.
            mask_img = cv2.resize(mask_img, (self.size, self.size),
                                  interpolation=cv2.INTER_NEAREST)
            hint_image = image * self.img_to_mask(mask_img)

            l, _ = self.bgr_to_lab(image)
            _, ab_hint = self.bgr_to_lab(hint_image)

            return self.transform(l), self.transform(ab_hint)

        raise NotImplementedError("unsupported mode: %r" % (self.mode,))

class ColorHintDataset(data.Dataset):
    """Dataset of colorization samples read from ``root_path``.

    Directory layout used per mode:
      * "train": ``root_path/train``
      * "val":   ``root_path/val``
      * "test":  ``root_path/hint`` and ``root_path/mask`` (paired by index)

    Raises:
        NotImplementedError: for an unknown mode.
        ValueError: in "test" mode when hint and mask counts differ.
    """

    def __init__(self, root_path, size, mode="train"):
        super(ColorHintDataset, self).__init__()

        self.root_path = root_path
        self.size = size
        self.mode = mode
        self.transforms = ColorHintTransform(self.size, self.mode)
        self.examples = None
        self.hint = None
        self.mask = None

        if self.mode == "train":
            self.examples = self._list_files("train")
        elif self.mode == "val":
            # Build the paths from the same "val" directory that is listed.
            self.examples = self._list_files("val")
        elif self.mode == "test":
            # Both lists are sorted so hint[idx] and mask[idx] pair up
            # deterministically. This assumes matching hint/mask files share
            # the same name order -- TODO confirm against the test set.
            self.hint = self._list_files("hint")
            self.mask = self._list_files("mask")
            if len(self.hint) != len(self.mask):
                raise ValueError("hint/mask count mismatch: %d vs %d"
                                 % (len(self.hint), len(self.mask)))
        else:
            raise NotImplementedError

    def _list_files(self, subdir):
        """Return the sorted full paths of the entries in ``root_path/subdir``."""
        directory = os.path.join(self.root_path, subdir)
        return [os.path.join(directory, name) for name in sorted(os.listdir(directory))]

    @staticmethod
    def _read_image(path):
        """Read an image with OpenCV, raising instead of returning None on failure."""
        image = cv2.imread(path)
        if image is None:
            raise FileNotFoundError("could not read image: %s" % path)
        return image

    def __len__(self):
        if self.mode != "test":
            return len(self.examples)
        return len(self.hint)

    def __getitem__(self, idx):
        """Load one sample as a dict.

        In test mode the dict has "l", "hint" and "file_name"
        (``image_%06d.png`` built from the hint file's numeric stem).
        Otherwise it has "l", "ab" and "hint".
        """
        if self.mode == "test":
            hint_file_name = self.hint[idx]
            mask_file_name = self.mask[idx]
            hint_img = self._read_image(hint_file_name)
            mask_img = self._read_image(mask_file_name)

            input_l, input_hint = self.transforms(hint_img, mask_img)
            stem = os.path.basename(hint_file_name).split('.')[0]
            return {"l": input_l, "hint": input_hint,
                    "file_name": "image_%06d.png" % int(stem)}

        img = self._read_image(self.examples[idx])
        l, ab, hint = self.transforms(img)
        return {"l": l, "ab": ab, "hint": hint}

In [None]:
import torch
import torch.utils.data as data
import cv2
import tqdm
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.image as img

def tensor2im(input_image, imtype=np.uint8):
    """Convert the first image of a batched tensor into an (H, W, C) array.

    Non-tensor inputs are returned unchanged. Values are clipped to [0, 1],
    scaled to [0, 255], rounded and cast to ``imtype``. A single-channel
    image is tiled to three channels.
    """
    if not isinstance(input_image, torch.Tensor):
        return input_image
    # Take the first element of the batch, channel-first.
    image_numpy = input_image.detach()[0].cpu().float().numpy()
    if image_numpy.shape[0] == 1:
        image_numpy = np.tile(image_numpy, (3, 1, 1))
    image_numpy = np.clip(np.transpose(image_numpy, (1, 2, 0)), 0, 1) * 255.0
    # Round before casting: astype truncates, so float round-off just below an
    # integer (e.g. from k/255*255) would otherwise drop to k-1.
    return np.round(image_numpy).astype(imtype)


# Change to your data root directory
root_path = "./cv_project"
# Use the GPU only when the current runtime actually has one.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")

train_dataset = ColorHintDataset(root_path, 256, "train")
train_dataloader = data.DataLoader(train_dataset, batch_size=4, shuffle=True)

# Show the ground truth and the hint image of the first sample in each batch.
# Press Enter for the next batch, or type "q" then Enter to stop.
# The loop variable is `batch` (not `data`) so it does not shadow the
# `torch.utils.data` module imported as `data`.
for i, batch in enumerate(tqdm.tqdm(train_dataloader)):
    l = batch["l"].to(device)
    ab = batch["ab"].to(device)
    hint = batch["hint"].to(device)

    gt_image = torch.cat((l, ab), dim=1)
    hint_image = torch.cat((l, hint), dim=1)

    gt_np = tensor2im(gt_image)
    hint_np = tensor2im(hint_image)

    # The arrays hold OpenCV 8-bit LAB; convert to RGB for matplotlib.
    gt_rgb = cv2.cvtColor(gt_np, cv2.COLOR_LAB2RGB)
    hint_rgb = cv2.cvtColor(hint_np, cv2.COLOR_LAB2RGB)

    plt.figure(1)
    plt.imshow(gt_rgb)
    plt.figure(2)
    plt.imshow(hint_rgb)
    plt.show()

    if input().strip().lower() == "q":
        break

# Network Construction (여기부터 구현 — implementation starts here)