<a href="https://colab.research.google.com/github/dntjr41/CV_TermP/blob/main/Image_Colorization_UNet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Computer Vision - Term Project [ConvNet Challenge]

In [None]:
# Connect Google Drive
# Mounts Google Drive at /content/drive; the dataset archives are later read
# from /content/drive/MyDrive. `google.colab` is only importable inside a
# Colab runtime.

from google.colab import drive
drive.mount('/content/drive')

In [None]:
import tqdm
import zipfile
import os

train_file_name = 'colorization_dataset.zip'
test_file_name = 'test_dataset.zip'

train_zip_path = '/content/drive/MyDrive/colorization_dataset.zip'
test_zip_path = '/content/drive/MyDrive/test_dataset.zip'

!cp = "{train_zip_path}" .
!unzip -q '{train_file_name}'
!rm '{train_file_name}'

!cp = "{test_zip_path}" .
!unzip -q '{test_file_name}'
!rm '{test_file_name}'

# Check Dataset & Color Hint (기존에 올라온 코드 [dataloader])

In [None]:
import os

# Print the number of entries in each dataset folder, one count per line,
# in the order: train, val, test hints, test masks.
for split_dir in ('./cv_project/train', './cv_project/val',
                  './test_dataset/hint', './test_dataset/mask'):
    print(len(os.listdir(split_dir)))

In [None]:
import torch
from torch.autograd import Variable
from torchvision import transforms
import torch.utils.data as data

import os
import cv2
import random
import numpy as np

import matplotlib.pyplot as plt
from torchvision import transforms
import tqdm
from PIL import Image
import numpy as np

class ColorHintTransform(object):
  """Convert a BGR image into the Lab tensors used for hint-based colorization.

  In "train"/"val" mode the image is resized, a random sparse hint mask is
  drawn, and the call returns ``(L, ab, ab_hint)``.  In "test" mode the hint
  mask comes from a separate mask image and the call returns ``(L, ab_hint)``.
  Every returned array is passed through ``torchvision.transforms.ToTensor``.

  Args:
    size: Side length (pixels) of the square the image is resized to.
    mode: One of "train", "val" or "test".
  """

  def __init__(self, size=256, mode="train"):
    super(ColorHintTransform, self).__init__()
    self.size = size
    self.mode = mode
    self.transform = transforms.Compose([transforms.ToTensor()])

  def bgr_to_lab(self, img):
    """Convert a BGR image to Lab and return ``(L, ab)`` as separate arrays."""
    lab = cv2.cvtColor(img, cv2.COLOR_BGR2LAB)
    l, ab = lab[:, :, 0], lab[:, :, 1:]
    return l, ab

  def hint_mask(self, bgr, threshold=(0.95, 0.97, 0.99)):
    """Draw a random boolean hint mask of shape ``(H, W, 1)``.

    One value is picked at random from ``threshold``; a pixel is kept as a
    hint when its uniform random draw exceeds that value, so a higher
    threshold yields fewer hints.

    Args:
      bgr: Image whose first two dimensions give the mask height and width.
      threshold: Candidate cut-off values (any non-empty sequence).
    """
    # Only height and width are needed, so 2-D inputs are accepted as well.
    h, w = bgr.shape[:2]
    mask_threshold = random.choice(threshold)
    mask = np.random.random([h, w, 1]) > mask_threshold
    return mask

  def img_to_mask(self, mask_img):
    """Return a boolean ``(H, W, 1)`` mask: True where channel 0 is >= 255."""
    mask = mask_img[:, :, 0, np.newaxis] >= 255
    return mask

  def __call__(self, img, mask_img=None):
    """Apply the transform for the configured mode.

    Args:
      img: BGR image array (as returned by ``cv2.imread``).
      mask_img: Mask image array; required in "test" mode, ignored otherwise.

    Returns:
      ``(L, ab, ab_hint)`` tensors in "train"/"val" mode, ``(L, ab_hint)``
      tensors in "test" mode.

    Raises:
      ValueError: In "test" mode when ``mask_img`` is not given.
      NotImplementedError: When ``self.mode`` is not a supported mode.
    """
    if self.mode in ("train", "val"):
      image = cv2.resize(img, (self.size, self.size))
      mask = self.hint_mask(image)

      # Pixels outside the mask become black before the Lab conversion.
      hint_image = image * mask

      l, ab = self.bgr_to_lab(image)
      _, ab_hint = self.bgr_to_lab(hint_image)

      return self.transform(l), self.transform(ab), self.transform(ab_hint)

    if self.mode == "test":
      if mask_img is None:
        raise ValueError('mask_img is required when mode is "test"')

      image = cv2.resize(img, (self.size, self.size))
      # The hint image was resized above; bring the mask to the same size so
      # the two stay aligned.  Nearest-neighbour keeps the 0/255 values intact.
      if mask_img.shape[:2] != image.shape[:2]:
        mask_img = cv2.resize(mask_img, (self.size, self.size),
                              interpolation=cv2.INTER_NEAREST)
      hint_image = image * self.img_to_mask(mask_img)

      l, _ = self.bgr_to_lab(image)
      _, ab_hint = self.bgr_to_lab(hint_image)

      return self.transform(l), self.transform(ab_hint)

    # The original returned the exception class here instead of raising it,
    # which let an unsupported mode pass silently.
    raise NotImplementedError("unsupported mode: %r" % (self.mode,))


class ColorHintDataset(data.Dataset):
  """Dataset yielding Lab tensors for hint-based colorization.

  Call :meth:`set_mode` before use.  In "train"/"val" mode images are read
  from ``<root_path>/<mode>``; in "test" mode hint images and mask images are
  read from ``<root_path>/hint`` and ``<root_path>/mask``.

  Args:
    root_path: Directory containing the mode sub-folders.
    size: Side length passed on to ``ColorHintTransform``.
  """

  def __init__(self, root_path, size):
    super(ColorHintDataset, self).__init__()

    self.root_path = root_path
    self.size = size
    # Stays None until set_mode() is called; defined here so __len__ can give
    # a clear error instead of an AttributeError.
    self.mode = None
    self.transforms = None
    self.examples = None
    self.hint = None
    self.mask = None

  @staticmethod
  def _list_files(directory):
    """Return the full paths of the entries in ``directory``, sorted by name.

    ``os.listdir`` gives entries in arbitrary order; sorting makes the sample
    order reproducible and keeps the hint and mask lists aligned when both
    folders use the same file names.
    """
    return [os.path.join(directory, name) for name in sorted(os.listdir(directory))]

  @staticmethod
  def _read_image(path):
    """Read an image with OpenCV, raising instead of returning None."""
    img = cv2.imread(path)
    if img is None:
      raise IOError("could not read image: %s" % path)
    return img

  def set_mode(self, mode):
    """Select "train", "val" or "test" and index the matching files.

    Raises:
      NotImplementedError: For any other mode (nothing is modified).
    """
    if mode not in ("train", "val", "test"):
      raise NotImplementedError("unsupported mode: %r" % (mode,))

    self.mode = mode
    self.transforms = ColorHintTransform(self.size, mode)
    if mode == "test":
      self.hint = self._list_files(os.path.join(self.root_path, "hint"))
      self.mask = self._list_files(os.path.join(self.root_path, "mask"))
    else:
      self.examples = self._list_files(os.path.join(self.root_path, mode))

  def __len__(self):
    """Number of samples: hint images in "test" mode, examples otherwise."""
    if self.mode is None:
      raise RuntimeError("call set_mode() before using the dataset")
    if self.mode == "test":
      return len(self.hint)
    return len(self.examples)

  def __getitem__(self, idx):
    """Load and transform sample ``idx``.

    Returns:
      "test" mode: dict with keys "l", "hint" and "file_name", where the file
      name is ``image_%06d.png`` built from the integer stem of the hint file.
      Other modes: dict with keys "l", "ab" and "hint".
    """
    if self.mode == "test":
      hint_file_name = self.hint[idx]
      mask_file_name = self.mask[idx]
      hint_img = self._read_image(hint_file_name)
      mask_img = self._read_image(mask_file_name)

      input_l, input_hint = self.transforms(hint_img, mask_img)
      sample = {"l": input_l, "hint": input_hint,
                "file_name": "image_%06d.png" % int(os.path.basename(hint_file_name).split('.')[0])}
    else:
      file_name = self.examples[idx]
      img = self._read_image(file_name)
      l, ab, hint = self.transforms(img)
      sample = {"l": l, "ab": ab, "hint": hint}

    return sample


def tensor2im(input_image, imtype=np.uint8):
    """Convert the first image of a tensor batch to an HxWxC NumPy image.

    Non-tensor inputs are returned unchanged.  For a tensor, element 0 along
    the first dimension is taken, a single-channel image is repeated to three
    channels, values are clipped to [0, 1], scaled by 255 and cast to
    ``imtype``.
    """
    if not isinstance(input_image, torch.Tensor):
        return input_image

    # Only the first item of the batch is converted.
    chw = input_image.data[0].cpu().float().numpy()
    if chw.shape[0] == 1:
        chw = np.tile(chw, (3, 1, 1))

    hwc = np.transpose(chw, (1, 2, 0))
    scaled = np.clip(hwc, 0, 1) * 255.0
    return scaled.astype(imtype)



In [None]:
# import torch
# import torch.utils.data as data
# import cv2
# import tqdm
# import numpy as np
# import matplotlib.pyplot as plt
# import matplotlib.image as img

# # Change to your data root directory
# root_path = "./test_dataset"
# # Depend on runtime setting
# use_cuda = True

# train_dataset = ColorHintDataset(root_path, 256, "test")
# train_dataloader = data.DataLoader(train_dataset, batch_size=4, shuffle=True)

# for i, data in enumerate(tqdm.tqdm(train_dataloader)):
#     if use_cuda:
#         l = data["l"].to('cuda')
#         ab = data["ab"].to('cuda')
#         hint = data["hint"].to('cuda')
#     else:
#         l = data["l"]
#         ab = data["ab"]
#         hint = data["hint"]

#     gt_image = torch.cat((l, ab), dim=1)
#     hint_image = torch.cat((l, hint), dim=1)

#     gt_np = tensor2im(gt_image)
#     hint_np = tensor2im(hint_image)

#     gt_bgr = cv2.cvtColor(gt_np, cv2.COLOR_LAB2RGB)
#     hint_bgr = cv2.cvtColor(hint_np, cv2.COLOR_LAB2RGB)

#     plt.figure(1)
#     plt.imshow(hint_bgr)
#     plt.show()

#     input()

# Network Construction (여기부터 구현)

# Unet

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F


class DoubleConv(nn.Module):
    """Two consecutive (3x3 conv -> batch norm -> ReLU) stages.

    The convolutions use stride 1 and padding 1, so height and width are
    unchanged; the channel count goes ``nin -> nout -> nout``.
    """

    def __init__(self, nin, nout):
        super(DoubleConv, self).__init__()
        stages = []
        # First stage maps nin -> nout, second stage maps nout -> nout.
        for in_channels in (nin, nout):
            stages.extend([
                nn.Conv2d(in_channels, nout, kernel_size=3, stride=1, padding=1),
                nn.BatchNorm2d(nout),
                nn.ReLU(inplace=True),
            ])
        self.double_conv = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through both stages."""
        out = self.double_conv(x)
        return out


class Down(nn.Module):
    """Encoder step: 2x2 max-pooling followed by a ``DoubleConv(nin, nout)``."""

    def __init__(self, nin, nout):
        super(Down, self).__init__()
        pool = nn.MaxPool2d(2)
        convs = DoubleConv(nin, nout)
        self.down_conv = nn.Sequential(pool, convs)

    def forward(self, x):
        """Pool ``x`` and then apply the double convolution."""
        out = self.down_conv(x)
        return out


class Up(nn.Module):
    """Decoder step: upsample, pad to the skip tensor's size, concatenate, convolve.

    ``nin`` is the channel count after concatenation (skip + upsampled).
    """

    def __init__(self, nin, nout):
        super(Up, self).__init__()
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.double_conv = DoubleConv(nin, nout)

    def forward(self, x1, x2):
        """Merge the upsampled ``x1`` with the skip connection ``x2``.

        ``x1`` is upsampled by a factor of 2 and padded so its height and
        width equal those of ``x2``; the result of
        ``DoubleConv(cat([x2, x1], dim=1))`` is returned.
        """
        upsampled = self.up(x1)

        # Size gap to the skip tensor; any odd remainder goes to the
        # bottom/right side.
        gap_h = x2.shape[2] - upsampled.shape[2]
        gap_w = x2.shape[3] - upsampled.shape[3]
        left, top = gap_w // 2, gap_h // 2
        pad_spec = (left, gap_w - left, top, gap_h - top)
        upsampled = F.pad(upsampled, pad_spec)

        merged = torch.cat([x2, upsampled], dim=1)
        return self.double_conv(merged)


class OutConv(nn.Module):
    """Final 1x1 convolution mapping ``nin`` feature channels to ``nout`` outputs."""

    def __init__(self, nin, nout):
        super().__init__()
        self.conv = nn.Conv2d(in_channels=nin, out_channels=nout, kernel_size=1)

    def forward(self, x):
        """Apply the 1x1 convolution to ``x``."""
        projected = self.conv(x)
        return projected


class UNet(nn.Module):
    """U-Net with four down-sampling and four up-sampling steps.

    Encoder widths are 64, 128, 256, 512, 512.  Each ``Up`` step receives the
    concatenation of the previous decoder output and the matching encoder
    feature map.

    Args:
        nin: Number of input channels.
        nout: Number of output channels.
    """

    def __init__(self, nin, nout):
        super(UNet, self).__init__()
        # Encoder
        self.in_conv = DoubleConv(nin, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        self.down4 = Down(512, 512)
        # Decoder: input widths count the concatenated skip connection.
        self.up1 = Up(1024, 256)
        self.up2 = Up(512, 128)
        self.up3 = Up(256, 64)
        self.up4 = Up(128, 64)
        self.out_conv = OutConv(64, nout)

    def forward(self, x):
        """Encode ``x``, decode with skip connections, and project to ``nout`` channels."""
        # Collect encoder outputs; they are consumed deepest-first below.
        skips = [self.in_conv(x)]
        for down in (self.down1, self.down2, self.down3):
            skips.append(down(skips[-1]))

        x = self.down4(skips[-1])
        for up in (self.up1, self.up2, self.up3, self.up4):
            x = up(x, skips.pop())

        return self.out_conv(x)
# from torch import nn

# class conv_block(nn.Module):
#     def __init__(self,ch_in,ch_out):
#         super(conv_block,self).__init__()
#         self.conv = nn.Sequential(
#             nn.Conv2d(ch_in, ch_out, kernel_size=3,stride=1,padding=1,bias=True),
#             nn.BatchNorm2d(ch_out),
#             nn.ReLU(inplace=True),
#             nn.Conv2d(ch_out, ch_out, kernel_size=3,stride=1,padding=1,bias=True),
#             nn.BatchNorm2d(ch_out),
#             nn.ReLU(inplace=True)
#         )


#     def forward(self,x):
#         x = self.conv(x)
#         return x

# class up_conv(nn.Module):
#     def __init__(self,ch_in,ch_out):
#         super(up_conv,self).__init__()
#         self.up = nn.Sequential(
#             nn.Upsample(scale_factor=2),
#             nn.Conv2d(ch_in,ch_out,kernel_size=3,stride=1,padding=1,bias=True),
# 		    nn.BatchNorm2d(ch_out),
# 			nn.ReLU(inplace=True)
#         )

#     def forward(self,x):
#         x = self.up(x)
#         return x

# class U_Net(nn.Module):
#     def __init__(self,img_ch=3,output_ch=3):
#         super(U_Net,self).__init__()
        
#         self.Maxpool = nn.MaxPool2d(kernel_size=2,stride=2)

#         self.Conv1 = conv_block(ch_in=img_ch,ch_out=64)
#         self.Conv2 = conv_block(ch_in=64,ch_out=128)
#         self.Conv3 = conv_block(ch_in=128,ch_out=256)
#         self.Conv4 = conv_block(ch_in=256,ch_out=512)
#         self.Conv5 = conv_block(ch_in=512,ch_out=1024)

#         self.Up5 = up_conv(ch_in=1024,ch_out=512)
#         self.Up_conv5 = conv_block(ch_in=1024, ch_out=512)

#         self.Up4 = up_conv(ch_in=512,ch_out=256)
#         self.Up_conv4 = conv_block(ch_in=512, ch_out=256)
        
#         self.Up3 = up_conv(ch_in=256,ch_out=128)
#         self.Up_conv3 = conv_block(ch_in=256, ch_out=128)
        
#         self.Up2 = up_conv(ch_in=128,ch_out=64)
#         self.Up_conv2 = conv_block(ch_in=128, ch_out=64)

#         self.Conv_1x1 = nn.Conv2d(64,output_ch,kernel_size=1,stride=1,padding=0)


#     def forward(self,x):
#         # encoding path
#         x1 = self.Conv1(x)

#         x2 = self.Maxpool(x1)
#         x2 = self.Conv2(x2)
        
#         x3 = self.Maxpool(x2)
#         x3 = self.Conv3(x3)

#         x4 = self.Maxpool(x3)
#         x4 = self.Conv4(x4)

#         x5 = self.Maxpool(x4)
#         x5 = self.Conv5(x5)

#         # decoding + concat path
#         d5 = self.Up5(x5)
#         d5 = torch.cat((x4,d5),dim=1)
        
#         d5 = self.Up_conv5(d5)
        
#         d4 = self.Up4(d5)
#         d4 = torch.cat((x3,d4),dim=1)
#         d4 = self.Up_conv4(d4)

#         d3 = self.Up3(d4)
#         d3 = torch.cat((x2,d3),dim=1)
#         d3 = self.Up_conv3(d3)

#         d2 = self.Up2(d3)
#         d2 = torch.cat((x1,d2),dim=1)
#         d2 = self.Up_conv2(d2)

#         d1 = self.Conv_1x1(d2)

#         return d1


# Utils

In [None]:
# Install the SSIM packages imported by the next cell.
# pytorch_msssim supplies the `ssim` function called inside SSIM();
# pytorch_ssim is imported as a module but is not referenced in the visible cells.
!pip install pytorch_ssim
!pip install pytorch_msssim

In [None]:
import cv2
from skimage.metrics import structural_similarity
from pytorch_msssim import ssim
import pytorch_ssim
from torchvision.transforms import ToTensor
from torch.autograd import Variable


## calculate loss per image##
class AverageMeter(object):
  """Track the latest value and the running weighted average of a metric.

  Attributes:
    val: Most recent value passed to ``update``.
    sum: Sum of ``val * n`` over all updates.
    count: Sum of ``n`` over all updates.
    avg: ``sum / count``.
  """

  def __init__(self):
    self.reset()

  def reset(self):
    """Set every statistic back to zero."""
    self.val = 0
    self.avg = 0
    self.sum = 0
    self.count = 0

  def update(self, val, n=1):
    """Record ``val`` as observed ``n`` times and refresh the average."""
    self.val = val
    self.sum = self.sum + val * n
    self.count = self.count + n
    self.avg = self.sum / self.count

#SSIM##
def SSIM(img_A, img_B):
    """Return the SSIM between two images, rounded to two decimals.

    Both images are converted with ``ToTensor`` and given a leading batch
    dimension; they are moved to the GPU when CUDA is available.

    Args:
        img_A: First image (array or PIL image accepted by ``ToTensor``).
        img_B: Second image, same size as ``img_A``.

    Returns:
        The SSIM value as a Python float rounded to 2 decimal places.
    """
    img1 = ToTensor()(img_A).unsqueeze(0)
    img2 = ToTensor()(img_B).unsqueeze(0)
    if torch.cuda.is_available():
        img1 = img1.cuda()
        img2 = img2.cuda()
    # ToTensor rescales uint8 images to [0, 1], while pytorch_msssim's ssim
    # defaults to data_range=255.  Passing the real range keeps the SSIM
    # stabilisation constants on the right scale.
    # NOTE(review): assumes the inputs are uint8 images (as produced by
    # cv2.cvtColor on uint8 data) - confirm if float images are ever passed.
    ssim_val = round(ssim(img1, img2, data_range=1.0).item(), 2)
    return ssim_val

def psnr(img_A, img_B):
    """Return the peak signal-to-noise ratio of the two images via ``cv2.PSNR``."""
    return cv2.PSNR(img_A, img_B)

##SAVE IMG##
def save_img(gt, hint, output, num):
    """Write the ground-truth, hint and output images for sample ``num``.

    The output file name embeds the SSIM and PSNR of ``output`` against
    ``gt`` in the form ``<num>_ssim:<value>_psnr:<value>.png``.
    """
    ssim_score = SSIM(gt, output)
    psnr_score = psnr(gt, output)

    targets = (
        ("outputs/GroundTruth/{}gt.png".format(num), gt),
        ("outputs/Hint/{}hint.png".format(num), hint),
        ("outputs/Output/{}_ssim:{}_psnr:{}.png".format(num, ssim_score, psnr_score), output),
    )
    for path, image in targets:
        cv2.imwrite(path, image)





# Train, Validation, Test Function

In [None]:
import torch
import tqdm
import cv2
import os
import shutil

##TRAIN##
def train(train_loader, model, criterion, optimizer, epoch):
    """Train ``model`` for one epoch.

    Each batch supplies "l", "ab" and "hint" tensors.  The network input is
    ``cat(l, hint)`` and the target is ``cat(l, ab)`` along the channel
    dimension.  Progress is printed every 100 batches.

    Args:
        train_loader: Iterable of batches (dicts with keys "l", "ab", "hint").
        model: Network to optimise; batches are moved to its device.
        criterion: Loss called as ``criterion(prediction, target)``.
        optimizer: Optimizer stepping the model parameters.
        epoch: Epoch index, used only for logging.
    """
    print('Starting training epoch {}'.format(epoch))
    model.train()

    # Follow the model's device rather than assuming CUDA, so the same code
    # runs on CPU-only runtimes as well.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    # Running loss statistics for the epoch.
    losses = AverageMeter()

    for i, batch in enumerate(tqdm.tqdm(train_loader)):
        l = batch["l"].to(device)
        ab = batch["ab"].to(device)
        hint = batch["hint"].to(device)

        gt_image = torch.cat((l, ab), dim=1)
        hint_image = torch.cat((l, hint), dim=1)

        # Run forward pass
        output_hint = model(hint_image)
        loss = criterion(output_hint, gt_image)
        losses.update(loss.item(), hint_image.size(0))

        # Compute gradient and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Print model accuracy -- in the code below, val refers to value, not validation
        if i % 100 == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'.format(epoch, i, len(train_loader), loss=losses))

    print('Finished training epoch {}'.format(epoch))


def validate(val_loader, model, criterion, save_images, epoch):
    """Evaluate ``model`` on the validation set and optionally save images.

    When ``save_images`` is truthy, ``outputs/Output`` is emptied first and,
    for every batch, the first image of the batch (ground truth, hint and
    prediction) is written via ``save_img`` using the batch index as its
    number.

    Args:
        val_loader: Iterable of batches (dicts with keys "l", "ab", "hint").
        model: Network to evaluate; batches are moved to its device.
        criterion: Loss called as ``criterion(prediction, target)``.
        save_images: Whether to write result images to ``outputs/``.
        epoch: Unused; kept for interface compatibility.

    Returns:
        The average loss over all validation samples.
    """
    model.eval()

    # Follow the model's device rather than assuming CUDA.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    # Running loss statistics for the validation pass.
    losses = AverageMeter()

    if save_images:
        # Start from an empty Output folder; a missing folder is not an error.
        try:
            shutil.rmtree("outputs/Output")
        except FileNotFoundError:
            pass
        # cv2.imwrite does not create directories, so make sure they exist.
        for folder in ("outputs/GroundTruth", "outputs/Hint", "outputs/Output"):
            os.makedirs(folder, exist_ok=True)

    # Gradients are never needed here; disabling them locally keeps the
    # function safe even if the caller forgets torch.no_grad().
    with torch.no_grad():
        for i, batch in enumerate(tqdm.tqdm(val_loader)):
            l = batch["l"].to(device)
            ab = batch["ab"].to(device)
            hint = batch["hint"].to(device)

            gt_image = torch.cat((l, ab), dim=1)
            hint_image = torch.cat((l, hint), dim=1)
            output_hint = model(hint_image)

            loss = criterion(output_hint, gt_image)
            losses.update(loss.item(), hint_image.size(0))

            # Print model accuracy -- in the code below, val refers to both value and validation
            if i % 100 == 0:
                print('Validate: [{0}/{1}]\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'.format(i, len(val_loader), loss=losses))

            if save_images:
                # tensor2im converts only the first image of each batch.
                out_hint_np = tensor2im(output_hint)
                out_hint_bgr = cv2.cvtColor(out_hint_np, cv2.COLOR_LAB2BGR)

                hint_np = tensor2im(hint_image)
                hint_bgr = cv2.cvtColor(hint_np, cv2.COLOR_LAB2BGR)

                gt_np = tensor2im(gt_image)
                gt_bgr = cv2.cvtColor(gt_np, cv2.COLOR_LAB2BGR)

                save_img(gt_bgr, hint_bgr, out_hint_bgr, i)

    print('Finished validation.')
    return losses.avg


def test(test_loader, model):
    """Colorize every test sample and write the results to ``outputs/test``.

    Each batch supplies "l", "hint" and "file_name".  The network input is
    ``cat(l, hint)`` along the channel dimension; each prediction is
    converted from Lab to BGR and saved under its own file name.

    Args:
        test_loader: Iterable of batches (dicts with keys "l", "hint",
            "file_name").
        model: Network to run; batches are moved to its device.
    """
    model.eval()

    # Follow the model's device rather than assuming CUDA.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    # cv2.imwrite does not create directories.
    os.makedirs("outputs/test", exist_ok=True)

    with torch.no_grad():
        for i, batch in enumerate(tqdm.tqdm(test_loader)):
            l = batch["l"].to(device)
            hint = batch["hint"].to(device)

            hint_image = torch.cat((l, hint), dim=1)
            output_hint = model(hint_image)

            # Save every image in the batch, not only the first one;
            # tensor2im converts element 0 of whatever it is given, so it is
            # handed one-item slices.
            for idx, file_name in enumerate(batch["file_name"]):
                out_hint_np = tensor2im(output_hint[idx:idx + 1])
                out_hint_bgr = cv2.cvtColor(out_hint_np, cv2.COLOR_LAB2BGR)
                cv2.imwrite("outputs/test/" + file_name, out_hint_bgr)

    print('Finished test.')


# Training

In [None]:
import os
import torch
import torch.utils.data
from torch import nn

## DATALOADER ##
# Change to your data root directory
root_path = "./cv_project"
# Use the GPU only when one is actually available, so the model/criterion
# setup below does not fail on a CPU-only runtime.
use_cuda = torch.cuda.is_available()

train_dataset = ColorHintDataset(root_path, 256)
train_dataset.set_mode("train")

val_dataset = ColorHintDataset(root_path, 256)
val_dataset.set_mode("val")

train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=4, shuffle=True)
val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=4, shuffle=False)

# 3 input channels (L + ab hint) and 3 output channels (L + ab).
model = UNet(nin=3, nout=3)
# print(model)
# PATH = "model-epoch-8-losses-0.00763.pth"
# model.load_state_dict(torch.load(PATH))

criterion = nn.L1Loss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=0.0)

# Move model and loss function to GPU
if use_cuda:
    criterion = criterion.cuda()
    model = model.cuda()

# Make folders and set parameters
for folder in ('outputs/GroundTruth', 'outputs/Hint', 'outputs/Output', 'checkpoints'):
    os.makedirs(folder, exist_ok=True)
save_images = True
best_losses = float('inf')
epochs = 3

# Train model
for epoch in range(epochs):
    # Train for one epoch, then validate
    train(train_dataloader, model, criterion, optimizer, epoch)
    with torch.no_grad():
        losses = validate(val_dataloader, model, criterion, save_images, epoch)
    # Save a checkpoint only when the validation loss improves; the file name
    # carries the 1-based epoch number and the loss.
    if losses < best_losses:
        best_losses = losses
        torch.save(model.state_dict(), 'checkpoints/model-epoch-{}-losses-{:.5f}.pth'.format(epoch + 1, losses))



# Testing

In [None]:
import os
import torch
import torch.utils.data
from torch import nn

import matplotlib.pyplot as plt
import numpy as np
import pylab

## DATALOADER ##
# Change to your data root directory
root_path = "./test_dataset"
# Use the GPU only when one is actually available.
use_cuda = torch.cuda.is_available()

test_dataset = ColorHintDataset(root_path, 256)
test_dataset.set_mode("test")

# Default DataLoader settings: batch size 1, no shuffling.
test_dataloader = torch.utils.data.DataLoader(test_dataset)

model = UNet(nin=3, nout=3)
print(model)
# Checkpoint to evaluate; the name must match a file written by the training
# cell (it embeds the epoch number and validation loss).
PATH = "./checkpoints/model-epoch-3-losses-0.01682.pth"
# Load onto the CPU first so a checkpoint saved from a GPU run can also be
# restored on a CPU-only runtime; the model is moved to the GPU below.
model.load_state_dict(torch.load(PATH, map_location='cpu'))
model.eval()

os.makedirs('./outputs/test', exist_ok=True)

# Move model to GPU
if use_cuda:
    model = model.cuda()

# Run inference without tracking gradients
with torch.no_grad():
    test(test_dataloader, model)


## calculate and save psnr, ssim ##
# The scores are parsed back out of the validation output file names, which
# have the form "<index>_ssim:<value>_psnr:<value>.png".

# change to your Output data directory
output_path = "./outputs/Output"
# os.listdir returns entries in arbitrary order; sort by the numeric index
# prefix so the saved arrays and the plots follow the sample order.
file_list = sorted(os.listdir(output_path), key=lambda f: int(f.split('_')[0]))
if not file_list:
    raise RuntimeError('no result images found in {}'.format(output_path))

# NOTE: these arrays were previously named `ssim` and `psnr`, which overwrote
# the imported `ssim` function and the `psnr()` helper and broke any later
# validation run in the same session.
ssim_scores = np.zeros(len(file_list))
psnr_scores = np.zeros(len(file_list))

for i, img_name in enumerate(file_list):
    print(img_name)
    name = img_name.replace('.png', '')   # remove '.png'
    temp = name.split('_')
    ssim_scores[i] = float(temp[1].replace('ssim:', ''))
    psnr_scores[i] = float(temp[2].replace('psnr:', ''))

ssim_avg = ssim_scores.mean()
psnr_avg = psnr_scores.mean()

print('Average of ssim: {}'.format(ssim_avg))
print('Average of psnr: {}'.format(psnr_avg))

np.save(os.path.join('./', 'ssim.npy'), ssim_scores)
np.save(os.path.join('./', 'psnr.npy'), psnr_scores)

# plot and save ssim curve
plt.figure()
plt.title('ssim')
pylab.xlim(0, len(file_list) + 1)
pylab.ylim(0, 1.1)
plt.plot(range(1, len(file_list) + 1), ssim_scores, label='ssim')
plt.legend()
plt.savefig(os.path.join('./', 'ssim.pdf'))
plt.show()
plt.close()

# plot and save psnr curve
plt.figure()
plt.title('psnr')
pylab.xlim(0, len(file_list) + 1)
pylab.ylim(0, 100)
plt.plot(range(1, len(file_list) + 1), psnr_scores, label='psnr')
plt.legend()
plt.savefig(os.path.join('./', 'psnr.pdf'))
plt.show()
plt.close()

