<a href="https://colab.research.google.com/github/sswon314/DL_Object_Detection_Project/blob/master/Computer_Vision_Team_Project.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
from google.colab import drive
# Mount Google Drive at /content/gDrive; the dataset archives and the
# checkpoints used by later cells are read from / written to paths under it.
drive.mount("/content/gDrive")

# Utils

### Transform (이미지 변환)

In [None]:
from torchvision import transforms

import cv2
import random
import numpy as np


class ColorHintTransform(object):
    """Convert a BGR image into LAB tensors for hint-based colorization.

    In "train" / "val" mode a random sparse mask decides which pixels keep
    their colour and act as hints.  In "test" mode the mask is supplied by
    the caller as an image.

    Args:
        size: Side length in pixels of the square the image is resized to.
        mode: One of "train", "val" or "test".
    """

    def __init__(self, size=256, mode="train"):
        super(ColorHintTransform, self).__init__()
        self.size = size
        self.mode = mode
        self.transform = transforms.Compose([transforms.ToTensor()])

    def bgr_to_lab(self, img):
        """Convert a BGR image to LAB and return ``(l, ab)``.

        ``l`` is channel 0 of the LAB image (lightness) and ``ab`` holds the
        remaining two channels (colour).
        """
        lab = cv2.cvtColor(img, cv2.COLOR_BGR2LAB)
        return lab[:, :, 0], lab[:, :, 1:]

    def hint_mask(self, bgr, threshold=(0.95, 0.97, 0.99)):
        """Build a random boolean hint mask of shape ``(h, w, 1)``.

        Args:
            bgr: Image whose height and width define the mask size.
            threshold: Candidate cut-offs; one is picked at random and a
                pixel is marked True when its uniform sample exceeds it.
        """
        h, w, _ = bgr.shape
        mask_threshold = random.choice(threshold)
        return np.random.random([h, w, 1]) > mask_threshold

    def img_to_mask(self, mask_img):
        """Turn a mask image into a boolean ``(h, w, 1)`` mask.

        A pixel is True when the first channel of ``mask_img`` is >= 255.
        """
        return mask_img[:, :, 0, np.newaxis] >= 255

    def __call__(self, img, mask_img=None):
        """Apply the transform for the configured mode.

        Args:
            img: BGR image array.
            mask_img: Mask image; required in "test" mode, ignored otherwise.

        Returns:
            "train"/"val": tensors ``(l, ab, ab_hint)``.
            "test": tensors ``(l, ab_hint)``.

        Raises:
            ValueError: In "test" mode when ``mask_img`` is missing.
            NotImplementedError: When ``self.mode`` is not a supported mode.
        """
        if self.mode in ("train", "val"):
            image = cv2.resize(img, (self.size, self.size))
            mask = self.hint_mask(image)

            # Only the pixels selected by the mask keep their value.
            hint_image = image * mask

            l, ab = self.bgr_to_lab(image)
            _, ab_hint = self.bgr_to_lab(hint_image)

            return self.transform(l), self.transform(ab), self.transform(ab_hint)

        if self.mode == "test":
            if mask_img is None:
                raise ValueError("mask_img is required when mode is 'test'")

            image = cv2.resize(img, (self.size, self.size))
            if mask_img.shape[:2] != image.shape[:2]:
                # The mask must match the resized image for the element-wise
                # product below; nearest-neighbour picks existing mask values
                # instead of blending them.
                mask_img = cv2.resize(mask_img, (self.size, self.size),
                                      interpolation=cv2.INTER_NEAREST)
            hint_image = image * self.img_to_mask(mask_img)

            l, _ = self.bgr_to_lab(image)
            _, ab_hint = self.bgr_to_lab(hint_image)

            return self.transform(l), self.transform(ab_hint)

        # Raise (rather than return the exception class) so that an invalid
        # mode fails here and not later when the caller unpacks the result.
        raise NotImplementedError("unsupported mode: %r" % (self.mode,))

### Data Unzip

In [None]:
# Extract the dataset zip archives stored on Google Drive
import zipfile
import os

# Extract the train / val data
trainFileName="colorization_dataset.zip"
trainZipPath="/content/gDrive/MyDrive/CV/colorization_dataset.zip"

# Copy the zip from its Drive path into the current directory -> unzip -> delete the copied zip
!cp "{trainZipPath}" .
!unzip -q "{trainFileName}"
!rm "{trainFileName}"


# Extract the test data
testFileName="test_dataset.zip"
testZipPath="/content/gDrive/MyDrive/CV/test_dataset.zip"

# Copy the zip from its Drive path into the current directory -> unzip into cv_project -> delete the copied zip
!cp "{testZipPath}" .
!unzip -q "{testFileName}" -d "cv_project"
!rm "{testFileName}"


# Directory layout after extraction
# cv_project
#   └ train
#   └ val
#   └ test_dataset
#       └ hint
#       └ mask

### DataLoader

In [None]:
import torch
import torch.utils.data as data
import os
import cv2


class ColorHintDataset(data.Dataset):
    """Dataset yielding LAB tensors for hint-based colorization.

    Args:
        root_path: Directory containing ``train``, ``val`` and
            ``test_dataset/{hint,mask}``.
        size: Side length passed to ``ColorHintTransform``.
        mode: "train", "val" or "test".

    Raises:
        NotImplementedError: If ``mode`` is not one of the supported modes.
    """

    def __init__(self, root_path, size, mode="train"):
        super(ColorHintDataset, self).__init__()

        self.root_path = root_path
        self.size = size
        self.mode = mode
        self.transforms = ColorHintTransform(self.size, self.mode)
        self.examples = None
        self.hint = None
        self.mask = None

        if self.mode in ("train", "val"):
            self.examples = self._list_files(os.path.join(self.root_path, self.mode))
        elif self.mode == "test":
            # Hint and mask files are paired by index in __getitem__, so both
            # listings must be in the same, deterministic order.
            # NOTE(review): assumes matching file names in both directories.
            self.hint = self._list_files(os.path.join(self.root_path, "test_dataset/hint"))
            self.mask = self._list_files(os.path.join(self.root_path, "test_dataset/mask"))
        else:
            raise NotImplementedError

    @staticmethod
    def _list_files(directory):
        """Return the paths of all entries of *directory*, sorted by name.

        ``os.listdir`` returns entries in arbitrary order, hence the sort.
        """
        return [os.path.join(directory, name) for name in sorted(os.listdir(directory))]

    @staticmethod
    def _read_image(path):
        """Read an image with OpenCV, failing loudly if it cannot be read."""
        img = cv2.imread(path)
        if img is None:
            raise OSError("could not read image: %s" % path)
        return img

    def __len__(self):
        if self.mode != "test":
            return len(self.examples)
        return len(self.hint)

    def __getitem__(self, idx):
        """Return the sample dict for index *idx*.

        "test" mode: keys ``l``, ``hint`` and ``file_name``.
        Other modes: keys ``l``, ``ab`` and ``hint``.
        """
        if self.mode == "test":
            hint_file_name = self.hint[idx]
            mask_file_name = self.mask[idx]
            hint_img = self._read_image(hint_file_name)
            mask_img = self._read_image(mask_file_name)

            input_l, input_hint = self.transforms(hint_img, mask_img)
            # The output name is derived from the numeric stem of the hint file.
            sample = {"l": input_l, "hint": input_hint,
                      "file_name": "image_%06d.png" % int(os.path.basename(hint_file_name).split('.')[0])}
        else:
            file_name = self.examples[idx]
            img = self._read_image(file_name)
            l, ab, hint = self.transforms(img)
            sample = {"l": l, "ab": ab, "hint": hint}

        return sample

In [None]:
import torch
import torch.utils.data as data
import cv2
import tqdm
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.image as img

def tensor2im(input_image, imtype=np.uint8):
    """Convert the first element of a batched tensor into an HWC image array.

    Non-tensor inputs are returned unchanged.  Single-channel tensors are
    replicated to three channels.  Values are clipped to [0, 1], scaled by
    255 and cast to ``imtype``.
    """
    if not isinstance(input_image, torch.Tensor):
        return input_image

    chw = input_image.data[0].cpu().float().numpy()
    if chw.shape[0] == 1:
        chw = np.tile(chw, (3, 1, 1))

    hwc = np.transpose(chw, (1, 2, 0))
    scaled = np.clip(hwc, 0, 1) * 255.0
    return scaled.astype(imtype)


# Change to your data root directory
root_path = "cv_project"
# Depends on the runtime setting: cells below move tensors to CUDA when this is True
use_cuda = True
# use_cuda = torch.cuda.is_available()

# One dataset / loader pair per split, all built with size 256; only the
# training loader shuffles.
train_dataset = ColorHintDataset(root_path, 256, "train")
train_dataloader = data.DataLoader(train_dataset, batch_size=4, shuffle=True)

val_dataset = ColorHintDataset(root_path, 256, "val")
val_dataloader = data.DataLoader(val_dataset, batch_size=4, shuffle=False)

test_dataset = ColorHintDataset(root_path, 256, "test")
test_dataloader = data.DataLoader(test_dataset, batch_size=1, shuffle=False)

# NOTE(review): len() of a DataLoader is the number of batches, not samples,
# so the "dataset length" labels below report batch counts — confirm intent.
print('train dataset length:' , len(train_dataloader)) 
print('val dataset length: ', len(val_dataloader))
print('test dataset length: ', len(test_dataloader))

### 이미지 변환 결과 확인

In [None]:
def checkImage(dataloader, mode):
    """Display samples from *dataloader* for visual inspection.

    For every batch, the first sample is converted from LAB to RGB and shown
    with matplotlib.

    Args:
        dataloader: Loader yielding dicts with ``l`` and ``hint`` tensors
            (plus ``ab`` for "train" / "val").
        mode: "train" or "val" shows ground truth next to the hint image;
            "test" shows the hint image only.  Any other value does nothing.
    """
    device = 'cuda' if use_cuda else 'cpu'

    if mode == "train" or mode == "val":
        for batch in tqdm.tqdm(dataloader):
            l = batch["l"].to(device)
            ab = batch["ab"].to(device)
            hint = batch["hint"].to(device)

            # Ground truth = L joined with the full ab channels;
            # hint image = L joined with the sparse ab hints.
            gt_image = torch.cat((l, ab), dim=1)
            hint_image = torch.cat((l, hint), dim=1)

            gt_np = tensor2im(gt_image)
            hint_np = tensor2im(hint_image)

            # matplotlib expects RGB, so convert LAB -> RGB.
            gt_rgb = cv2.cvtColor(gt_np, cv2.COLOR_LAB2RGB)
            hint_rgb = cv2.cvtColor(hint_np, cv2.COLOR_LAB2RGB)

            plt.subplot(1, 2, 1)
            plt.imshow(gt_rgb)

            plt.subplot(1, 2, 2)
            plt.imshow(hint_rgb)
            plt.show()

    elif mode == "test":
        # Iterate the loader that was passed in, not the module-level
        # test_dataloader.
        for batch in tqdm.tqdm(dataloader):
            l = batch["l"].to(device)
            hint = batch["hint"].to(device)

            hint_image = torch.cat((l, hint), dim=1)
            hint_np = tensor2im(hint_image)

            hint_rgb = cv2.cvtColor(hint_np, cv2.COLOR_LAB2RGB)

            plt.figure(1)
            plt.imshow(hint_rgb)
            plt.show()
    
# checkImage(train_dataloader,"train")
# checkImage(val_dataloader,"val")
# checkImage(test_dataloader,"test")

# Main Code

### Network

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F


class DoubleConv(nn.Module):
    """Two consecutive stages of 3x3 convolution -> batch norm -> ReLU.

    The first stage maps ``nin`` to ``nout`` channels, the second keeps
    ``nout`` channels.  Padding of 1 keeps the spatial size unchanged.
    """

    def __init__(self, nin, nout):
        super().__init__()
        stages = []
        for in_channels in (nin, nout):
            stages.append(nn.Conv2d(in_channels, nout, kernel_size=3, stride=1, padding=1))
            stages.append(nn.BatchNorm2d(nout))
            stages.append(nn.ReLU(inplace=True))
        self.double_conv = nn.Sequential(*stages)

    def forward(self, x):
        out = self.double_conv(x)
        return out


class Down(nn.Module):
    """Encoder step: 2x2 max-pooling followed by a DoubleConv."""

    def __init__(self, nin, nout):
        super().__init__()
        pool = nn.MaxPool2d(2)
        conv = DoubleConv(nin, nout)
        self.down_conv = nn.Sequential(pool, conv)

    def forward(self, x):
        pooled_and_convolved = self.down_conv(x)
        return pooled_and_convolved


class Up(nn.Module):
    """Decoder step: bilinear 2x upsampling, skip concatenation, DoubleConv.

    ``nin`` is the channel count after concatenating the skip tensor with
    the upsampled tensor; ``nout`` is the channel count produced.
    """

    def __init__(self, nin, nout):
        super().__init__()
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.up_conv = DoubleConv(nin, nout)

    def forward(self, x1, x2):
        """Upsample ``x1``, pad it to the size of ``x2`` and fuse the two."""
        upsampled = self.up(x1)

        # Pad the upsampled map so its height / width match the skip tensor.
        gap_h = x2.size(2) - upsampled.size(2)
        gap_w = x2.size(3) - upsampled.size(3)
        left, top = gap_w // 2, gap_h // 2
        upsampled = F.pad(upsampled, (left, gap_w - left, top, gap_h - top))

        merged = torch.cat([x2, upsampled], dim=1)
        return self.up_conv(merged)


class OutConv(nn.Module):
    """Final 1x1 convolution mapping ``nin`` channels to ``nout`` channels."""

    def __init__(self, nin, nout):
        super().__init__()
        self.conv = nn.Conv2d(in_channels=nin, out_channels=nout, kernel_size=1)

    def forward(self, x):
        projected = self.conv(x)
        return projected


class MyUNet(nn.Module):
    """U-Net with four down steps and four up steps.

    Takes a 3-channel input and produces a 2-channel output at the same
    spatial size.
    """

    def __init__(self):
        super().__init__()
        # Encoder
        self.in_conv = DoubleConv(3, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        self.down4 = Down(512, 512)
        # Decoder: each Up receives the concatenation of a skip tensor and
        # the upsampled tensor, hence the doubled input channel counts.
        self.up1 = Up(1024, 256)
        self.up2 = Up(512, 128)
        self.up3 = Up(256, 64)
        self.up4 = Up(128, 64)
        self.out_conv = OutConv(64, 2)

    def forward(self, x):
        skip1 = self.in_conv(x)
        skip2 = self.down1(skip1)
        skip3 = self.down2(skip2)
        skip4 = self.down3(skip3)
        decoded = self.down4(skip4)

        # Walk back up, fusing each level with its encoder counterpart.
        for up_step, skip in ((self.up1, skip4),
                              (self.up2, skip3),
                              (self.up3, skip2),
                              (self.up4, skip1)):
            decoded = up_step(decoded, skip)

        return self.out_conv(decoded)

In [None]:
# https://github.com/ousinkou/Gachon_SW_Colorization_Contest
import torch
import torch.nn as nn

class upsample_block(nn.Module):
    """3x3 conv -> BN -> ReLU, then a stride-2 transposed conv -> BN -> ReLU.

    The transposed convolution (kernel 3, stride 2, padding 1, output
    padding 1) doubles the spatial size; channels go from ``ch_in`` to
    ``ch_out``.
    """

    def __init__(self, ch_in, ch_out):
        super().__init__()

        conv = nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1)
        conv_bn = nn.BatchNorm2d(ch_out)
        deconv = nn.ConvTranspose2d(ch_out, ch_out, kernel_size=3, stride=2,
                                    padding=1, output_padding=1)
        deconv_bn = nn.BatchNorm2d(ch_out)

        self.up = nn.Sequential(
            conv,
            conv_bn,
            nn.ReLU(inplace=True),
            deconv,
            deconv_bn,
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.up(x)

class Residual_recurrent_block(nn.Module):
    """1x1 projection followed by a residual pair of 3x3 convolutions.

    The input is first projected from ``ch_in`` to ``ch_out`` channels; the
    output is that projection plus conv -> BN -> ReLU -> conv -> BN applied
    to it (no activation after the second batch norm).
    """

    def __init__(self, ch_in, ch_out):
        super().__init__()

        conv_kwargs = dict(kernel_size=3, stride=1, padding=1, bias=True)
        self.RCNN = nn.Sequential(
            nn.Conv2d(ch_out, ch_out, **conv_kwargs),
            nn.BatchNorm2d(ch_out),
            nn.ReLU(inplace=True),
            nn.Conv2d(ch_out, ch_out, **conv_kwargs),
            nn.BatchNorm2d(ch_out),
        )
        self.Conv_1x1 = nn.Conv2d(ch_in, ch_out, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        projected = self.Conv_1x1(x)
        return projected + self.RCNN(projected)



class Attention_block(nn.Module):
    """Attention gate that rescales ``x`` by a single-channel sigmoid map.

    Args:
        upsample: Channel count of the gating tensor ``g``.
        downsample: Channel count of the tensor ``x`` being gated.
        ch_result: Channel count of the intermediate projection.
    """

    def __init__(self, upsample, downsample, ch_result):
        super().__init__()

        def projection(in_channels, out_channels):
            # 1x1 convolution followed by batch normalisation.
            return [
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=True),
                nn.BatchNorm2d(out_channels),
            ]

        self.skip = nn.Sequential(*projection(upsample, ch_result))
        self.up = nn.Sequential(*projection(downsample, ch_result))
        self.concat = nn.Sequential(*projection(ch_result, 1), nn.Sigmoid())
        self.relu = nn.ReLU(inplace=True)

    def forward(self, g, x):
        gate = self.skip(g)
        feat = self.up(x)
        weights = self.concat(self.relu(gate + feat))
        return x * weights

class ResAttdU_Net(nn.Module):
    """U-Net built from residual blocks with attention-gated skip connections.

    Args:
        img_ch: Number of input channels.
        output_ch: Number of output channels.
    """

    def __init__(self, img_ch=3, output_ch=3):
        super().__init__()

        self.Maxpool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Encoder: two residual blocks per level, channels 64 -> 1024.
        self.DownSample1 = Residual_recurrent_block(img_ch, 64)
        self.DownSample1_1 = Residual_recurrent_block(64, 64)

        self.DownSample2 = Residual_recurrent_block(64, 128)
        self.DownSample2_1 = Residual_recurrent_block(128, 128)

        self.DownSample3 = Residual_recurrent_block(128, 256)
        self.DownSample3_1 = Residual_recurrent_block(256, 256)

        self.DownSample4 = Residual_recurrent_block(256, 512)
        self.DownSample4_1 = Residual_recurrent_block(512, 512)

        self.DownSample5 = Residual_recurrent_block(512, 1024)
        self.DownSample5_1 = Residual_recurrent_block(1024, 1024)

        # Decoder: upsample, gate the skip tensor, then two residual blocks
        # on the concatenation (hence the doubled input channel counts).
        self.Up5 = upsample_block(1024, 512)
        self.Att5 = Attention_block(512, 512, 256)
        self.UpSample5 = Residual_recurrent_block(1024, 512)
        self.UpSample5_1 = Residual_recurrent_block(512, 512)

        self.Up4 = upsample_block(512, 256)
        self.Att4 = Attention_block(256, 256, 128)
        self.UpSample4 = Residual_recurrent_block(512, 256)
        self.UpSample4_1 = Residual_recurrent_block(256, 256)

        self.Up3 = upsample_block(256, 128)
        self.Att3 = Attention_block(128, 128, 64)
        self.UpSample3 = Residual_recurrent_block(256, 128)
        self.UpSample3_1 = Residual_recurrent_block(128, 128)

        self.Up2 = upsample_block(128, 64)
        self.Att2 = Attention_block(64, 64, 32)
        self.UpSample2 = Residual_recurrent_block(128, 64)
        self.UpSample2_1 = Residual_recurrent_block(64, 64)

        self.Conv_1x1 = nn.Conv2d(64, output_ch, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        # Encoding path: every level after the first starts with max-pooling.
        enc1 = self.DownSample1_1(self.DownSample1(x))
        enc2 = self.DownSample2_1(self.DownSample2(self.Maxpool(enc1)))
        enc3 = self.DownSample3_1(self.DownSample3(self.Maxpool(enc2)))
        enc4 = self.DownSample4_1(self.DownSample4(self.Maxpool(enc3)))
        decoded = self.DownSample5_1(self.DownSample5(self.Maxpool(enc4)))

        # Decoding path, deepest level first.
        stages = (
            (self.Up5, self.Att5, self.UpSample5, self.UpSample5_1, enc4),
            (self.Up4, self.Att4, self.UpSample4, self.UpSample4_1, enc3),
            (self.Up3, self.Att3, self.UpSample3, self.UpSample3_1, enc2),
            (self.Up2, self.Att2, self.UpSample2, self.UpSample2_1, enc1),
        )
        for upsample, attention, conv_a, conv_b, skip in stages:
            decoded = upsample(decoded)
            gated_skip = attention(g=decoded, x=skip)
            decoded = torch.cat((gated_skip, decoded), dim=1)
            decoded = conv_b(conv_a(decoded))

        return self.Conv_1x1(decoded)

In [None]:
# https://github.com/psyrocloud/MS-SSIM_L1_LOSS
import torch
import torch.nn as nn
import torch.nn.functional as F


class MS_SSIM_L1_LOSS(nn.Module):
    """Mix of an MS-SSIM term and a Gaussian-weighted L1 term.

    The convolutions use ``groups=3``, so inputs must have 3 channels.
    The Gaussian masks are placed on CUDA when it is available (running on
    CPU works but is slow) and follow the device of the input in ``forward``.

    Args:
        gaussian_sigmas: Sigmas of the Gaussian kernels, smallest to largest.
        data_range: Value range of the inputs.
        K: Pair of constants used to derive the stabilisers C1 and C2.
        alpha: Weight of the MS-SSIM term; the L1 term gets ``1 - alpha``.
        compensation: Factor the mixed loss is multiplied by.
        cuda_dev: CUDA device index used when CUDA is available.
    """

    def __init__(self, gaussian_sigmas=(0.5, 1.0, 2.0, 4.0, 8.0),
                 data_range = 1.0,
                 K=(0.01, 0.03),
                 alpha=0.025,
                 compensation=200.0,
                 cuda_dev=0,):
        super(MS_SSIM_L1_LOSS, self).__init__()
        self.DR = data_range
        self.C1 = (K[0] * data_range) ** 2
        self.C2 = (K[1] * data_range) ** 2
        self.pad = int(2 * gaussian_sigmas[-1])
        self.alpha = alpha
        self.compensation = compensation
        filter_size = int(4 * gaussian_sigmas[-1] + 1)
        g_masks = torch.zeros((3 * len(gaussian_sigmas), 1, filter_size, filter_size))
        for idx, sigma in enumerate(gaussian_sigmas):
            # Three copies per sigma, stored as r0,g0,b0,r1,g1,b1,...,rM,gM,bM
            # NOTE(review): conv2d with groups=3 assigns consecutive blocks of
            # output channels to each input channel, which may not match this
            # interleaved layout — confirm against the upstream implementation.
            kernel = self._fspecial_gauss_2d(filter_size, sigma)
            g_masks[3 * idx + 0, 0, :, :] = kernel
            g_masks[3 * idx + 1, 0, :, :] = kernel
            g_masks[3 * idx + 2, 0, :, :] = kernel
        # Fall back to CPU instead of failing when no GPU is present.
        if torch.cuda.is_available():
            g_masks = g_masks.cuda(cuda_dev)
        self.g_masks = g_masks

    def _fspecial_gauss_1d(self, size, sigma):
        """Create 1-D gauss kernel
        Args:
            size (int): the size of gauss kernel
            sigma (float): sigma of normal distribution
        Returns:
            torch.Tensor: 1D kernel (size)
        """
        coords = torch.arange(size).to(dtype=torch.float)
        coords -= size // 2
        g = torch.exp(-(coords ** 2) / (2 * sigma ** 2))
        g /= g.sum()
        return g.reshape(-1)

    def _fspecial_gauss_2d(self, size, sigma):
        """Create 2-D gauss kernel
        Args:
            size (int): the size of gauss kernel
            sigma (float): sigma of normal distribution
        Returns:
            torch.Tensor: 2D kernel (size x size)
        """
        gaussian_vec = self._fspecial_gauss_1d(size, sigma)
        return torch.outer(gaussian_vec, gaussian_vec)

    def forward(self, x, y):
        """Return the scalar mixed loss between ``x`` and ``y``.

        Both inputs are expected as ``[B, 3, H, W]`` tensors.
        """
        # Keep the masks on the same device as the inputs.
        if self.g_masks.device != x.device:
            self.g_masks = self.g_masks.to(x.device)
        masks = self.g_masks

        mux = F.conv2d(x, masks, groups=3, padding=self.pad)
        muy = F.conv2d(y, masks, groups=3, padding=self.pad)

        mux2 = mux * mux
        muy2 = muy * muy
        muxy = mux * muy

        sigmax2 = F.conv2d(x * x, masks, groups=3, padding=self.pad) - mux2
        sigmay2 = F.conv2d(y * y, masks, groups=3, padding=self.pad) - muy2
        sigmaxy = F.conv2d(x * y, masks, groups=3, padding=self.pad) - muxy

        # l(j), cs(j) in MS-SSIM
        l  = (2 * muxy    + self.C1) / (mux2    + muy2    + self.C1)  # [B, 15, H, W]
        cs = (2 * sigmaxy + self.C2) / (sigmax2 + sigmay2 + self.C2)

        lM = l[:, -1, :, :] * l[:, -2, :, :] * l[:, -3, :, :]
        PIcs = cs.prod(dim=1)

        loss_ms_ssim = 1 - lM * PIcs  # [B, H, W]

        loss_l1 = F.l1_loss(x, y, reduction='none')  # [B, 3, H, W]
        # Average the Gaussian-weighted L1 loss over the 3 channels, using the
        # last three masks (largest sigma).
        gaussian_l1 = F.conv2d(loss_l1, masks.narrow(dim=0, start=-3, length=3),
                               groups=3, padding=self.pad).mean(1)  # [B, H, W]

        loss_mix = self.alpha * loss_ms_ssim + (1 - self.alpha) * gaussian_l1 / self.DR
        loss_mix = self.compensation * loss_mix

        return loss_mix.mean()

### Train

In [None]:
# Where training results are stored.
# Checkpoints go to Google Drive so they can be reloaded in a later session.
save_path = '/content/gDrive/MyDrive/CV'
os.makedirs(save_path, exist_ok=True)
best_output_path = os.path.join(save_path, 'best_model.tar')  # best-performing model
last_output_path = os.path.join(save_path, 'last_model.tar')  # most recently trained model

### Hyperparameters ###
# Number of epochs to train per run
EPOCH = 100
LR = 0.0005

# Model
# trainNet = MyUNet().cuda()
trainNet = ResAttdU_Net().cuda()

# Loss function and optimizer
import torch.optim as optim
# criterion = nn.L1Loss()
criterion = MS_SSIM_L1_LOSS()
optimizer = torch.optim.Adam(trainNet.parameters(), lr=LR)


# Resume from the last checkpoint when one exists.
# state_dict = memo, epoch, best_loss, train_info, val_info, model_weight, optim_state
if os.path.isfile(last_output_path):
    print("체크포인트 불러옴")
    state_dict = torch.load(last_output_path)
    train_info = state_dict["train_info"]
    val_info = state_dict["val_info"]
    best_loss = state_dict["best_loss"]
    start_epoch = state_dict["epoch"] + 1
    print("Epoch {}번부터 시작".format(start_epoch))

    trainNet.load_state_dict(state_dict['model_weight'])
    # Restore the optimizer state when the checkpoint carries it; checkpoints
    # without that key still load.
    if "optim_state" in state_dict:
        optimizer.load_state_dict(state_dict["optim_state"])

# Otherwise start from scratch.
else:
    print("체크포인트 없음")
    train_info = []
    val_info = []
    # Infinity guarantees the first validation loss is recorded as the best,
    # whatever the scale of the loss.
    best_loss = float("inf")
    start_epoch = 0

In [None]:
def train_model(net, train_dataloader):
    """Train *net* for one epoch and return the mean batch loss.

    Uses the module-level ``optimizer``, ``criterion`` and ``use_cuda``.

    Args:
        net: Model to train; it receives the L channel joined with the ab
            hints and its output is compared against L joined with ab.
        train_dataloader: Loader yielding dicts with ``l``, ``ab``, ``hint``.

    Returns:
        Mean of the detached per-batch losses.

    Raises:
        ValueError: If the loader yields no batches.
    """
    # Import the submodule explicitly: importing the ``tqdm`` package alone
    # does not necessarily load ``tqdm.auto``.
    from tqdm.auto import tqdm as progress_bar

    # Honour use_cuda for every tensor instead of calling .cuda() regardless.
    device = 'cuda' if use_cuda else 'cpu'

    total_loss = 0
    iteration = 0
    net.train()

    for batch in progress_bar(train_dataloader):
        l_imgs = batch["l"].to(device)
        ab_imgs = batch["ab"].to(device)
        hint_imgs = batch["hint"].to(device)

        gt_image = torch.cat((l_imgs, ab_imgs), dim=1).float()
        hint_image = torch.cat((l_imgs, hint_imgs), dim=1).float()

        optimizer.zero_grad()
        # squeeze(1) only has an effect when the channel dimension is 1.
        output = net(hint_image).squeeze(1)
        # output = torch.cat((l_imgs, output), dim=1)

        loss = criterion(output, gt_image)
        # loss = criterion(output, ab_imgs)
        loss.backward()

        optimizer.step()
        total_loss += loss.detach()
        iteration += 1

    if iteration == 0:
        raise ValueError("train_dataloader produced no batches")

    total_loss /= iteration
    return total_loss

def val_model(net, val_dataloader):
    """Evaluate *net* on the validation loader and return the mean batch loss.

    Uses the module-level ``criterion`` and ``use_cuda``.  Gradients are
    disabled inside, so the function is safe to call without an outer
    ``torch.no_grad()``.

    Args:
        net: Model to evaluate.
        val_dataloader: Loader yielding dicts with ``l``, ``ab``, ``hint``.

    Returns:
        Mean of the per-batch losses.

    Raises:
        ValueError: If the loader yields no batches.
    """
    # Import the submodule explicitly: importing the ``tqdm`` package alone
    # does not necessarily load ``tqdm.auto``.
    from tqdm.auto import tqdm as progress_bar

    # Honour use_cuda for every tensor instead of calling .cuda() regardless.
    device = 'cuda' if use_cuda else 'cpu'

    total_loss = 0
    iteration = 0
    net.eval()

    with torch.no_grad():
        for batch in progress_bar(val_dataloader):
            l_imgs = batch["l"].to(device)
            ab_imgs = batch["ab"].to(device)
            hint_imgs = batch["hint"].to(device)

            gt_image = torch.cat((l_imgs, ab_imgs), dim=1).float()
            hint_image = torch.cat((l_imgs, hint_imgs), dim=1).float()

            # squeeze(1) only has an effect when the channel dimension is 1.
            output = net(hint_image).squeeze(1)
            # output = torch.cat((l_imgs, output), dim=1)

            loss = criterion(output, gt_image)
            # loss = criterion(output, ab_imgs)

            total_loss += loss.detach()
            iteration += 1

    if iteration == 0:
        raise ValueError("val_dataloader produced no batches")

    total_loss /= iteration
    return total_loss

In [None]:
# Run EPOCH more epochs, continuing from the epoch after the last checkpoint.
for epoch in range(start_epoch, start_epoch + EPOCH):
    t_loss = train_model(trainNet, train_dataloader)
    print('[TRAINING] Epoch: {} train_score: {}'.format(epoch, t_loss))
    train_info.append({'loss': t_loss})

    with torch.no_grad():
        v_loss = val_model(trainNet, val_dataloader)
        print('[VALIDATION] Epoch: {} loss: {}'.format(epoch, v_loss))
        val_info.append({'loss': v_loss})

    is_best = best_loss > v_loss
    if is_best:
        best_loss = v_loss

    # Shared checkpoint payload.  The optimizer state is included so that a
    # resumed run can continue with the same optimizer statistics.
    checkpoint = {
        "epoch": epoch,
        "best_loss": best_loss,
        "train_info": train_info,
        "val_info": val_info,
        "model_weight": trainNet.state_dict(),
        "optim_state": optimizer.state_dict(),
    }

    # Save the model with the lowest validation loss so far.
    if is_best:
        print("최고 성능 모델 저장")
        torch.save(dict(checkpoint, memo="This is Best Model"), best_output_path)

    # Always save the model of the current epoch.
    print("{}번째 모델 저장".format(epoch))
    torch.save(dict(checkpoint, memo="This is Last Model"), last_output_path)

### Test

In [None]:
def test_model(net, test_dataloader):
    """Run *net* on the test loader, save each result and display it.

    For every batch the first sample is colorized, written to
    ``cv_project/outputs/<file_name>`` and shown next to its hint image.
    Uses the module-level ``use_cuda``.

    Args:
        net: Trained model.
        test_dataloader: Loader yielding dicts with ``l``, ``hint`` and
            ``file_name``.
    """
    # Import the submodule explicitly: importing the ``tqdm`` package alone
    # does not necessarily load ``tqdm.auto``.
    from tqdm.auto import tqdm as progress_bar

    # Honour use_cuda for every tensor instead of calling .cuda() regardless.
    device = 'cuda' if use_cuda else 'cpu'

    output_dir = 'cv_project/outputs/'
    os.makedirs(output_dir, exist_ok=True)

    net.eval()
    # Inference only: no gradients are needed.
    with torch.no_grad():
        for batch in progress_bar(test_dataloader):
            l = batch["l"].to(device)
            hint = batch["hint"].to(device)

            hint_image = torch.cat((l, hint), dim=1)
            hint_np = tensor2im(hint_image)

            # squeeze(1) only has an effect when the channel dimension is 1.
            output = net(hint_image.float()).squeeze(1)
            # output = torch.cat((l, output), dim=1)
            output_np = tensor2im(output)

            hint_rgb = cv2.cvtColor(hint_np, cv2.COLOR_LAB2RGB)
            output_bgr = cv2.cvtColor(output_np, cv2.COLOR_LAB2BGR)
            # matplotlib expects RGB while cv2.imwrite expects BGR, so keep
            # one copy of the result in each channel order.
            output_rgb = cv2.cvtColor(output_bgr, cv2.COLOR_BGR2RGB)

            # Save the result image (for submission).
            cv2.imwrite(output_dir + batch['file_name'][0], output_bgr)

            # Show the hint image and the result side by side.
            plt.subplot(1, 2, 1)
            plt.imshow(hint_rgb)

            plt.subplot(1, 2, 2)
            plt.imshow(output_rgb)
            plt.show()

            # Uncomment to pause until a key is entered before the next image.
            # input()

In [None]:
import numpy as np
import matplotlib.pyplot as plt

save_path = '/content/gDrive/MyDrive/CV'
last_model_path = os.path.join(save_path, 'last_model.tar')
best_model_path = os.path.join(save_path, 'best_model.tar')

# Plot the loss history stored in the most recent checkpoint, if present.
if os.path.isfile(last_model_path):
    last_state_dict = torch.load(last_model_path)
    print("Epoch {}까지 돌렸음".format(last_state_dict["epoch"]))
    epochAxis = np.arange(0, len(last_state_dict["train_info"]))

    # The stored losses are tensors; move them to the CPU for plotting.
    train_curve = [info["loss"].cpu() for info in last_state_dict["train_info"]]
    val_curve = [info["loss"].cpu() for info in last_state_dict["val_info"]]

    plt.title("LOSS")
    plt.plot(epochAxis, train_curve, epochAxis, val_curve, "r-")
    plt.legend(["TRAIN", "VALIDATION"])
    plt.show()


# Run the test set through the best checkpoint, if present.
if os.path.isfile(best_model_path):
    best_state_dict = torch.load(best_model_path)
    print("Epoch {}일때 최고 성능이었음".format(best_state_dict["epoch"]))

    testNet = ResAttdU_Net().cuda()
    # testNet = MyUNet().cuda()

    testNet.load_state_dict(best_state_dict['model_weight'])

    test_model(testNet, test_dataloader)