<a href="https://colab.research.google.com/github/sswon314/CV_Colorization/blob/master/Computer_Vision_Team_Project.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Mount Google Drive at /content/gDrive; the dataset archives and the
# checkpoints used further down are read from / written to that mount.
from google.colab import drive
drive.mount("/content/gDrive")

# Notebook shell escape: print the GPU attached to this runtime.
!nvidia-smi

# Utils

In [None]:
# import modules
from torchvision import transforms
import cv2
import random
import numpy as np

import zipfile
import os

import torch
import torch.nn as nn
import torch.utils.data as data
import torch.nn.functional as F
import torch.optim as optim

import tqdm

import matplotlib.pyplot as plt
import matplotlib.image as img

### Transform (이미지 변환)

In [None]:
class ColorHintTransform(object):
    """Build the L / ab / hint / mask tensors for hint-based colorization.

    In "train" and "val" mode the transform resizes the BGR input to
    ``size`` x ``size``, draws a random sparse hint mask and returns
    ``(l, ab, ab_hint, mask)``.  In "test" mode the mask is derived from a
    separate mask image and ``(l, ab_hint, mask)`` is returned.  Every
    returned array goes through ``transforms.ToTensor()``.

    Args:
        size: Side length the input image is resized to.
        mode: One of "train", "val" or "test".
    """

    # Candidate thresholds for the random hint mask; one is picked per call.
    HINT_THRESHOLDS = (0.95, 0.97, 0.99)

    def __init__(self, size=256, mode="train"):
        super(ColorHintTransform, self).__init__()
        self.size = size
        self.mode = mode
        self.transform = transforms.Compose([transforms.ToTensor()])

    def bgr_to_lab(self, img):
        """Convert a BGR image to LAB and return ``(l, ab)``.

        ``l`` is the first LAB channel (lightness) and ``ab`` holds the
        remaining two (color) channels.
        """
        lab = cv2.cvtColor(img, cv2.COLOR_BGR2LAB)
        l, ab = lab[:, :, 0], lab[:, :, 1:]
        return l, ab

    def hint_mask(self, bgr, threshold=HINT_THRESHOLDS):
        """Return a random boolean mask of shape ``(h, w, 1)``.

        One value is chosen at random from ``threshold``; a pixel is True
        when its uniform random draw exceeds that value, so roughly
        ``1 - threshold`` of the pixels are kept as hints.

        Args:
            bgr: Three-dimensional image array; only its height and width
                are used.
            threshold: Sequence of candidate thresholds (an immutable tuple
                by default so the default cannot be mutated between calls).
        """
        h, w, _ = bgr.shape
        mask_threshold = random.choice(threshold)
        return np.random.random([h, w, 1]) > mask_threshold

    def img_to_mask(self, mask_img):
        """Return a boolean ``(h, w, 1)`` mask: True where channel 0 is >= 255."""
        return mask_img[:, :, 0, np.newaxis] >= 255

    def __call__(self, img, mask_img=None):
        """Apply the transform that matches ``self.mode``.

        Args:
            img: BGR image array.
            mask_img: Mask image, required in "test" mode.  It is used at its
                own resolution (it is not resized here), so it has to be
                broadcastable against the resized ``img``.

        Returns:
            ``(l, ab, ab_hint, mask)`` in "train"/"val" mode and
            ``(l, ab_hint, mask)`` in "test" mode.

        Raises:
            ValueError: "test" mode was requested without ``mask_img``.
            NotImplementedError: ``self.mode`` is not a supported mode.
        """
        if self.mode in ("train", "val"):
            image = cv2.resize(img, (self.size, self.size))
            mask = self.hint_mask(image, self.HINT_THRESHOLDS)

            # Only pixels selected by the mask keep their value.
            hint_image = image * mask

            # LAB channels of the full image and of the sparse hint image.
            l, ab = self.bgr_to_lab(image)
            _, ab_hint = self.bgr_to_lab(hint_image)

            return (self.transform(l), self.transform(ab),
                    self.transform(ab_hint), self.transform(mask))

        if self.mode == "test":
            if mask_img is None:
                raise ValueError("mask_img is required when mode is 'test'")

            image = cv2.resize(img, (self.size, self.size))
            mask = self.img_to_mask(mask_img)
            hint_image = image * mask

            l, _ = self.bgr_to_lab(image)
            _, ab_hint = self.bgr_to_lab(hint_image)

            return self.transform(l), self.transform(ab_hint), self.transform(mask)

        # Previously the exception class was *returned*, which silently handed
        # callers a non-tuple value; an unsupported mode must fail loudly.
        raise NotImplementedError("unsupported mode: {!r}".format(self.mode))

### Data Unzip

In [None]:
# Extract the dataset zip archives stored on the mounted Drive.

# Train / val archive
trainFileName="colorization_dataset.zip"
trainZipPath="/content/gDrive/MyDrive/CV/colorization_dataset.zip"

# Copy the archive into the current directory -> extract -> delete the copy
!cp "{trainZipPath}" .
!unzip -q "{trainFileName}"
!rm "{trainFileName}"


# Test archive
testFileName="test_dataset.zip"
testZipPath="/content/gDrive/MyDrive/CV/test_dataset.zip"

# Copy the archive into the current directory -> extract into cv_project -> delete the copy
!cp "{testZipPath}" .
!unzip -q "{testFileName}" -d "cv_project"
!rm "{testFileName}"


# Expected directory layout after extraction
# NOTE(review): the train/val archive is extracted without -d, so it presumably
# already contains the cv_project/ folder itself - confirm against the archive.
# cv_project
#   └ train
#   └ val
#   └ test_dataset
#       └ hint
#       └ mask

### Data Augmentation

In [None]:
# Offline augmentation: for every original training image write a horizontally
# flipped copy and a 180-degree rotated copy next to it in the train directory.
root_path = "cv_project"
train_dir = os.path.join(root_path, "train")
examples = [os.path.join(root_path, "train", dirs) for dirs in os.listdir(train_dir)]

# File-name prefixes produced by this cell. Files carrying them are skipped so
# that re-running the cell does not augment its own output again.
augmented_prefixes = ("flip_image", "rot_image")
originals = sorted(
    path for path in examples
    if not os.path.basename(path).startswith(augmented_prefixes)
)

print("Before Augmentation: {}".format(len(examples)))
# enumerate over the sorted originals gives each source image a stable index
# (and avoids the quadratic list.index lookup per file).
for index, path in enumerate(originals):
    image = cv2.imread(path)
    if image is None:
        # cv2.imread returns None for unreadable / non-image files.
        print("Skipping unreadable file: {}".format(path))
        continue

    # Keep the image in BGR: cv2.imwrite expects BGR input, so converting to
    # RGB before writing would store the copies with red and blue swapped.
    flipped = cv2.flip(image, 1)
    rotated = cv2.rotate(image, cv2.ROTATE_180)
    cv2.imwrite(os.path.join(train_dir, 'flip_image{}.png'.format(index)), flipped)
    cv2.imwrite(os.path.join(train_dir, 'rot_image{}.png'.format(index)), rotated)

examples = [os.path.join(root_path, "train", dirs) for dirs in os.listdir(train_dir)]
print("After Augmentation: {}".format(len(examples)))

### DataLoader

In [None]:
class ColorHintDataset(data.Dataset):
    """Dataset of colorization samples read from ``root_path``.

    * "train" / "val": images are listed from ``<root_path>/train`` or
      ``<root_path>/val``; each sample is a dict with keys
      ``"l"``, ``"ab"``, ``"hint"`` and ``"mask"``.
    * "test": hint images come from ``<root_path>/test_dataset/hint`` and mask
      images from ``<root_path>/test_dataset/mask``; each sample is a dict
      with keys ``"l"``, ``"hint"``, ``"mask"`` and ``"file_name"``.

    Args:
        root_path: Dataset root directory.
        size: Side length passed on to ``ColorHintTransform``.
        mode: "train", "val" or "test".

    Raises:
        NotImplementedError: ``mode`` is not one of the supported values.
    """

    def __init__(self, root_path, size, mode="train"):
        super(ColorHintDataset, self).__init__()

        self.root_path = root_path
        self.size = size
        self.mode = mode
        self.transforms = ColorHintTransform(self.size, self.mode)
        self.examples = None
        self.hint = None
        self.mask = None

        if self.mode in ("train", "val"):
            self.examples = self._list_files(os.path.join(self.root_path, self.mode))
        elif self.mode == "test":
            # Both lists are sorted so that hint[idx] and mask[idx] line up.
            # Assumes a hint and its mask share the same file name - TODO confirm.
            self.hint = self._list_files(os.path.join(self.root_path, "test_dataset/hint"))
            self.mask = self._list_files(os.path.join(self.root_path, "test_dataset/mask"))
        else:
            raise NotImplementedError

    @staticmethod
    def _list_files(directory):
        """Return the paths of all entries in ``directory``, sorted by name.

        ``os.listdir`` yields entries in arbitrary order, so sorting is what
        makes the sample order (and the hint/mask pairing) deterministic.
        """
        return [os.path.join(directory, name) for name in sorted(os.listdir(directory))]

    def __len__(self):
        """Number of samples: hint images in test mode, examples otherwise."""
        if self.mode != "test":
            return len(self.examples)
        else:
            return len(self.hint)

    def __getitem__(self, idx):
        """Load sample ``idx`` from disk and run it through the transform."""
        if self.mode == "test":
            hint_file_name = self.hint[idx]
            mask_file_name = self.mask[idx]
            hint_img = cv2.imread(hint_file_name)
            mask_img = cv2.imread(mask_file_name)

            input_l, input_hint, mask = self.transforms(hint_img, mask_img)
            # The output name is rebuilt from the numeric stem of the hint file.
            stem = os.path.basename(hint_file_name).split('.')[0]
            sample = {"l": input_l, "hint": input_hint, "mask": mask,
                      "file_name": "image_%06d.png" % int(stem)}
        else:
            file_name = self.examples[idx]
            img = cv2.imread(file_name)
            l, ab, hint, mask = self.transforms(img)
            sample = {"l": l, "ab": ab, "hint": hint, "mask": mask}

        return sample

In [None]:
# Convert a tensor into an image array
def tensor2im(input_image, imtype=np.uint8):
    """Turn the first element of a batched tensor into an H x W x C image array.

    Non-tensor inputs are returned untouched.  A single-channel image is
    repeated to three channels; values are clipped to [0, 1], scaled by 255
    and cast to ``imtype``.
    """
    if not isinstance(input_image, torch.Tensor):
        return input_image

    # Only the first item of the batch is converted.
    chw = input_image.data[0].cpu().float().numpy()
    if chw.shape[0] == 1:
        chw = np.tile(chw, (3, 1, 1))

    hwc = np.transpose(chw, (1, 2, 0))
    scaled = np.clip(hwc, 0, 1) * 255.0
    return scaled.astype(imtype)


# Change to your data root directory
root_path = "cv_project"
# Use the GPU only when the runtime actually provides one; a hard-coded True
# makes every `.to('cuda')` call fail on a CPU-only runtime.
use_cuda = torch.cuda.is_available()

train_dataset = ColorHintDataset(root_path, 256, "train")
train_dataloader = data.DataLoader(train_dataset, batch_size=8, shuffle=True)

val_dataset = ColorHintDataset(root_path, 256, "val")
val_dataloader = data.DataLoader(val_dataset, batch_size=8, shuffle=False)

test_dataset = ColorHintDataset(root_path, 256, "test")
test_dataloader = data.DataLoader(test_dataset, batch_size=1, shuffle=False)

# len(dataset) is the number of samples, len(dataloader) the number of batches;
# the previous output printed the batch count under a "dataset length" label.
print('train dataset length:', len(train_dataset), '/ batches:', len(train_dataloader))
print('val dataset length: ', len(val_dataset), '/ batches:', len(val_dataloader))
print('test dataset length: ', len(test_dataset), '/ batches:', len(test_dataloader))

# Main Code

### Network

In [None]:
# https://github.com/LeeJunHyun/Image_Segmentation
class up_conv(nn.Module):
    """Decoder upsampling stage.

    A 3x3 convolution maps ``ch_in`` to ``ch_out`` channels, then a stride-2
    transposed convolution doubles the spatial size.  Each of the two layers
    is followed by batch normalization and ReLU.
    """

    def __init__(self, ch_in, ch_out):
        super().__init__()
        stages = [
            nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True),
            nn.BatchNorm2d(ch_out),
            nn.ReLU(inplace=True),
            # kernel 3 / stride 2 / padding 1 / output_padding 1 -> exactly 2x size
            nn.ConvTranspose2d(ch_out, ch_out, kernel_size=3, stride=2,
                               padding=1, output_padding=1),
            nn.BatchNorm2d(ch_out),
            nn.ReLU(inplace=True),
        ]
        self.up = nn.Sequential(*stages)

    def forward(self, x):
        return self.up(x)

class Recurrent_block(nn.Module):
    """3x3 convolution + batch norm + ReLU that keeps ``ch_out`` channels.

    Despite the name, ``forward`` applies the convolution stack exactly once.
    """

    def __init__(self, ch_out):
        super().__init__()
        self.ch_out = ch_out
        conv3x3 = nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1, bias=True)
        self.conv = nn.Sequential(
            conv3x3,
            nn.BatchNorm2d(ch_out),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.conv(x)

class Attention_block(nn.Module):
    """Attention gate that rescales the skip feature ``x`` using the gating signal ``g``.

    ``g`` and ``x`` are each projected to ``F_int`` channels with a 1x1
    convolution + batch norm, summed and passed through ReLU; a further 1x1
    convolution + batch norm + sigmoid produces a single-channel map that
    multiplies ``x``.

    Args:
        F_g: Channel count of ``g``.
        F_l: Channel count of ``x``.
        F_int: Channel count of the intermediate projection.
    """

    def __init__(self, F_g, F_l, F_int):
        super().__init__()

        def pointwise(in_ch, out_ch):
            # 1x1 convolution followed by batch normalization
            return [
                nn.Conv2d(in_ch, out_ch, kernel_size=1, stride=1, padding=0, bias=True),
                nn.BatchNorm2d(out_ch),
            ]

        self.W_g = nn.Sequential(*pointwise(F_g, F_int))
        self.W_x = nn.Sequential(*pointwise(F_l, F_int))
        self.psi = nn.Sequential(*pointwise(F_int, 1), nn.Sigmoid())
        self.relu = nn.ReLU(inplace=True)

    def forward(self, g, x):
        combined = self.relu(self.W_g(g) + self.W_x(x))
        attention = self.psi(combined)
        return x * attention

class RRCNN_block(nn.Module):
    """Residual block: 1x1 channel projection plus two stacked ``Recurrent_block`` layers.

    The input is first projected from ``ch_in`` to ``ch_out`` channels; the
    output is that projection added to the result of running it through the
    two-block stack.
    """

    def __init__(self, ch_in, ch_out):
        super().__init__()
        refinement = [Recurrent_block(ch_out) for _ in range(2)]
        self.RCNN = nn.Sequential(*refinement)
        self.Conv_1x1 = nn.Conv2d(ch_in, ch_out, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        projected = self.Conv_1x1(x)
        refined = self.RCNN(projected)
        return projected + refined

class R2AttU_Net(nn.Module):
    """U-shaped encoder/decoder built from ``RRCNN_block`` stages with attention-gated skips.

    The encoder has five stages (64 -> 1024 channels) separated by 2x2 max
    pooling.  Each of the four decoder stages upsamples with ``up_conv``,
    gates the matching encoder feature with ``Attention_block``, concatenates
    both along the channel axis and refines the result with ``RRCNN_block``.
    A final 1x1 convolution maps 64 channels to ``output_ch``.

    Args:
        img_ch: Number of input channels.
        output_ch: Number of output channels.
        t: Accepted but not used anywhere in this class.
    """

    def __init__(self, img_ch=4, output_ch=3, t=2):
        super().__init__()

        self.Maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
        # Registered but not referenced by forward().
        self.Upsample = nn.Upsample(scale_factor=2)

        # Encoder stages
        self.RRCNN1 = RRCNN_block(ch_in=img_ch, ch_out=64)
        self.RRCNN2 = RRCNN_block(ch_in=64, ch_out=128)
        self.RRCNN3 = RRCNN_block(ch_in=128, ch_out=256)
        self.RRCNN4 = RRCNN_block(ch_in=256, ch_out=512)
        self.RRCNN5 = RRCNN_block(ch_in=512, ch_out=1024)

        # Decoder stage 5: 1024 -> 512
        self.Up5 = up_conv(ch_in=1024, ch_out=512)
        self.Att5 = Attention_block(F_g=512, F_l=512, F_int=256)
        self.Up_RRCNN5 = RRCNN_block(ch_in=1024, ch_out=512)

        # Decoder stage 4: 512 -> 256
        self.Up4 = up_conv(ch_in=512, ch_out=256)
        self.Att4 = Attention_block(F_g=256, F_l=256, F_int=128)
        self.Up_RRCNN4 = RRCNN_block(ch_in=512, ch_out=256)

        # Decoder stage 3: 256 -> 128
        self.Up3 = up_conv(ch_in=256, ch_out=128)
        self.Att3 = Attention_block(F_g=128, F_l=128, F_int=64)
        self.Up_RRCNN3 = RRCNN_block(ch_in=256, ch_out=128)

        # Decoder stage 2: 128 -> 64
        self.Up2 = up_conv(ch_in=128, ch_out=64)
        self.Att2 = Attention_block(F_g=64, F_l=64, F_int=32)
        self.Up_RRCNN2 = RRCNN_block(ch_in=128, ch_out=64)

        self.Conv_1x1 = nn.Conv2d(64, output_ch, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        # Encoder: each stage after the first runs on the pooled previous stage.
        enc1 = self.RRCNN1(x)
        enc2 = self.RRCNN2(self.Maxpool(enc1))
        enc3 = self.RRCNN3(self.Maxpool(enc2))
        enc4 = self.RRCNN4(self.Maxpool(enc3))
        enc5 = self.RRCNN5(self.Maxpool(enc4))

        # Decoder: upsample, gate the skip feature, concatenate, refine.
        up5 = self.Up5(enc5)
        skip4 = self.Att5(g=up5, x=enc4)
        dec5 = self.Up_RRCNN5(torch.cat((skip4, up5), dim=1))

        up4 = self.Up4(dec5)
        skip3 = self.Att4(g=up4, x=enc3)
        dec4 = self.Up_RRCNN4(torch.cat((skip3, up4), dim=1))

        up3 = self.Up3(dec4)
        skip2 = self.Att3(g=up3, x=enc2)
        dec3 = self.Up_RRCNN3(torch.cat((skip2, up3), dim=1))

        up2 = self.Up2(dec3)
        skip1 = self.Att2(g=up2, x=enc1)
        dec2 = self.Up_RRCNN2(torch.cat((skip1, up2), dim=1))

        return self.Conv_1x1(dec2)

### Train

In [None]:
# Where training results are stored.
# The location is on the mounted Drive so checkpoints can be reloaded later.
save_path = '/content/gDrive/MyDrive/CV'
os.makedirs(save_path, exist_ok=True)
best_output_path = os.path.join(save_path, 'best_model.tar')  # best-performing model
last_output_path = os.path.join(save_path, 'last_model.tar')  # most recently trained model

### Hyperparameters ###
# Number of epochs to run per execution of the training loop
EPOCH = 200
LR = 0.00025

# Honor the use_cuda flag instead of calling .cuda() unconditionally.
device = torch.device("cuda" if use_cuda else "cpu")

# Model
trainNet = R2AttU_Net().to(device)

# Loss function and optimizer
criterion = nn.L1Loss()
optimizer = optim.Adam(trainNet.parameters(), lr=LR)


# Resume from the last checkpoint when one exists
if os.path.isfile(last_output_path):
    print("체크포인트 불러옴")
    # map_location lets a checkpoint written on a GPU runtime load on a CPU one.
    checkpoint_dict = torch.load(last_output_path, map_location=device)
    train_info = checkpoint_dict["train_info"]
    val_info = checkpoint_dict["val_info"]
    best_loss = checkpoint_dict["best_loss"]
    start_epoch = checkpoint_dict["epoch"] + 1
    print("Epoch {}번부터 시작".format(start_epoch))

    trainNet.load_state_dict(checkpoint_dict['model_weight'])

    # Plot the stored per-epoch train / validation loss history.
    epochAxis = np.arange(0, len(train_info))
    train_curve = [info["loss"].cpu() for info in train_info]
    val_curve = [info["loss"].cpu() for info in val_info]

    plt.title("LOSS")
    plt.plot(epochAxis, train_curve, epochAxis, val_curve, "r-")
    plt.legend(["TRAIN", "VALIDATION"])
    plt.show()

# Otherwise start from scratch
else:
    print("체크포인트 없음")
    train_info = []
    val_info = []
    # Any first validation loss must count as an improvement.
    best_loss = float("inf")
    start_epoch = 0

In [None]:
# Function that runs one training epoch
def train_model(net, train_dataloader):
    """Train ``net`` for one pass over ``train_dataloader``.

    Uses the module-level ``optimizer``, ``criterion`` and ``use_cuda``.
    The network input is the channel-wise concatenation of L, the ab hint and
    the mask; the target is the concatenation of L and ab.

    Args:
        net: Model to train; expected to already live on the device selected
            by ``use_cuda``.
        train_dataloader: Iterable of batches with keys "l", "ab", "hint", "mask".

    Returns:
        Mean of the per-batch losses (a detached tensor).
    """
    device = torch.device("cuda" if use_cuda else "cpu")
    total_loss = 0
    iteration = 0
    net.train()

    for batch in tqdm.auto.tqdm(train_dataloader):
        # One code path for both devices: the former CPU branch stored the mask
        # under a different name and then failed with a NameError.
        l_imgs = batch["l"].to(device)
        ab_imgs = batch["ab"].to(device)
        hint_imgs = batch["hint"].to(device)
        mask_imgs = batch["mask"].to(device)

        gt_image = torch.cat((l_imgs, ab_imgs), dim=1)
        hint_image = torch.cat((l_imgs, hint_imgs, mask_imgs), dim=1)

        optimizer.zero_grad()
        output = net(hint_image)

        loss = criterion(output, gt_image)
        loss.backward()
        optimizer.step()

        # detach() keeps the running sum from holding on to the autograd graph
        total_loss += loss.detach()
        iteration += 1

    total_loss /= iteration
    return total_loss

# Function that runs one validation pass
def val_model(net, val_dataloader):
    """Evaluate ``net`` on ``val_dataloader`` without updating any weights.

    Uses the module-level ``criterion`` and ``use_cuda``.  Inputs and targets
    are assembled exactly as in training: input = L + ab hint + mask,
    target = L + ab.

    Args:
        net: Model to evaluate; expected to already live on the device
            selected by ``use_cuda``.
        val_dataloader: Iterable of batches with keys "l", "ab", "hint", "mask".

    Returns:
        Mean of the per-batch losses (a detached tensor).
    """
    device = torch.device("cuda" if use_cuda else "cpu")
    total_loss = 0
    iteration = 0
    net.eval()

    # Gradients are never needed here; disabling them locally means the
    # function does not depend on the caller wrapping it in no_grad().
    with torch.no_grad():
        for batch in tqdm.auto.tqdm(val_dataloader):
            # One code path for both devices: the former CPU branch never
            # defined the mask tensor and failed with a NameError.
            l_imgs = batch["l"].to(device)
            ab_imgs = batch["ab"].to(device)
            hint_imgs = batch["hint"].to(device)
            mask_imgs = batch["mask"].to(device)

            gt_image = torch.cat((l_imgs, ab_imgs), dim=1)
            hint_image = torch.cat((l_imgs, hint_imgs, mask_imgs), dim=1)

            # No squeeze(1): the output is compared channel-for-channel with
            # gt_image, and squeezing a multi-channel output is a no-op anyway.
            output = net(hint_image)

            loss = criterion(output, gt_image)

            total_loss += loss.detach()
            iteration += 1

    total_loss /= iteration
    return total_loss

In [None]:
# Run EPOCH more epochs, continuing after the last checkpointed epoch
for epoch in range(start_epoch, start_epoch + EPOCH):
    # Train on the training split. Passing val_dataloader here would fit the
    # model to the very data used for model selection below.
    t_loss = train_model(trainNet, train_dataloader)
    print('[TRAINING] Epoch: {} train_loss: {}'.format(epoch, t_loss))
    train_info.append({'loss': t_loss})

    with torch.no_grad():
        v_loss = val_model(trainNet, val_dataloader)
        print('[VALIDATION] Epoch: {} validation_loss: {}'.format(epoch, v_loss))
        val_info.append({'loss': v_loss})

    # Save the model whenever the validation loss reaches a new minimum
    if best_loss > v_loss:
        best_loss = v_loss
        print("최고 성능 모델 저장")
        torch.save({
            "memo": "This is Best Model",
            "epoch": epoch,
            "best_loss": best_loss,
            "model_weight": trainNet.state_dict(),
        }, best_output_path)

    # Always save the state of the current epoch (used to resume training)
    print("{}번째 모델 저장".format(epoch))
    torch.save({
        "memo": "This is Last Model",
        "epoch": epoch,
        "best_loss": best_loss,
        "train_info": train_info,
        "val_info": val_info,
        "model_weight": trainNet.state_dict(),
    }, last_output_path)

### Test

In [None]:
# Function that runs inference on the test set
def test_model(net, test_dataloader):
    """Colorize every test sample, save it as an image and display it.

    Each result is written to ``cv_project/outputs/<file_name>`` and shown
    with matplotlib.  Uses the module-level ``use_cuda`` flag.

    Args:
        net: Trained model; expected to already live on the device selected
            by ``use_cuda``.
        test_dataloader: Iterable of batches with keys "l", "hint", "mask"
            and "file_name".
    """
    device = torch.device("cuda" if use_cuda else "cpu")
    output_dir = 'cv_project/outputs/'
    # Create the output directory once instead of on every iteration.
    os.makedirs(output_dir, exist_ok=True)

    net.eval()
    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        for batch in tqdm.auto.tqdm(test_dataloader):
            l = batch["l"].to(device)
            hint = batch["hint"].to(device)
            mask = batch["mask"].to(device)

            hint_image = torch.cat((l, hint, mask), dim=1)

            output = net(hint_image)
            # tensor2im converts only the first item of the batch.
            output_np = tensor2im(output)

            output_bgr = cv2.cvtColor(output_np, cv2.COLOR_LAB2BGR)

            # Save the result image (for submission); cv2.imwrite expects BGR.
            cv2.imwrite(output_dir + batch['file_name'][0], output_bgr)

            # Show the result; matplotlib expects RGB, so convert for display
            # only, otherwise red and blue appear swapped.
            plt.imshow(cv2.cvtColor(output_bgr, cv2.COLOR_BGR2RGB))
            plt.show()

            # Uncomment to wait for a key press before the next image;
            # without it all images are processed and saved automatically.
            # input()

In [None]:
import numpy as np
import matplotlib.pyplot as plt

save_path = '/content/gDrive/MyDrive/CV'
last_model_path = os.path.join(save_path, 'last_model.tar')
best_model_path = os.path.join(save_path, 'best_model.tar')

# Honor the use_cuda flag; map_location lets checkpoints written on a GPU
# runtime be loaded on a CPU-only one.
device = torch.device("cuda" if use_cuda else "cpu")

# Report how far training got and plot the stored loss history
if os.path.isfile(last_model_path):
    last_state_dict = torch.load(last_model_path, map_location=device)
    print("Epoch {}까지 돌렸음".format(last_state_dict["epoch"]))

    epochAxis = np.arange(0, len(last_state_dict["train_info"]))
    train_curve = [info["loss"].cpu() for info in last_state_dict["train_info"]]
    val_curve = [info["loss"].cpu() for info in last_state_dict["val_info"]]

    plt.title("LOSS")
    plt.plot(epochAxis, train_curve, epochAxis, val_curve, "r-")
    plt.legend(["TRAIN", "VALIDATION"])
    plt.show()


# Run the test set through the best checkpoint, if one was saved
if os.path.isfile(best_model_path):
    best_state_dict = torch.load(best_model_path, map_location=device)
    print("Epoch {}일때 최고 성능이었음".format(best_state_dict["epoch"]))

    testNet = R2AttU_Net().to(device)
    testNet.load_state_dict(best_state_dict['model_weight'])

    test_model(testNet, test_dataloader)