<a href="https://colab.research.google.com/github/dntjr41/CV_TermP/blob/main/attention_unet_colab.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# 드라이브 마운트

In [None]:
from google.colab import drive

# Google Drive is mounted here; later cells cd into a folder below it.
MOUNT_POINT = '/content/drive'
# force_remount=True asks Colab to mount afresh even if a mount already exists.
drive.mount(MOUNT_POINT, force_remount=True)

# 경로 지정

In [None]:
%cd drive/My\ Drive/2022_cv_project
!pwd
!ls -la

# 압축 풀기

In [None]:
import tqdm
import zipfile
import os
import glob

# # 폴더 복사 하기
# import shutil
# shutil.copytree('./cv_project', './cv_project_na')

# # 파일 크기 확인
# filepaths = os.listdir('./cv_project/train')
# print(len(filepaths))
# filepaths = os.listdir('./cv_project/val')
# print(len(filepaths))
# filepaths = os.listdir('./cv_project/mask')
# print(len(filepaths))
# filepaths = os.listdir('./cv_project/hint')
# print(len(filepaths))

# 압축 풀기
file_name = 'cv_project_na'
!unzip -qq '{file_name}'

# 파일 크기 확인
filepaths = os.listdir('./cv_project_na/train')
print(len(filepaths))
filepaths = os.listdir('./cv_project_na/val')
print(len(filepaths))
filepaths = os.listdir('./cv_project_na/mask')
print(len(filepaths))
filepaths = os.listdir('./cv_project_na/hint')
print(len(filepaths))

# Colab Pro 버전 — GPU / RAM 런타임 확인

In [None]:
gpu_info = !nvidia-smi
gpu_info = '\n'.join(gpu_info)
if gpu_info.find('failed') >= 0:
  print('Not connected to a GPU')
else:
  print(gpu_info)

from psutil import virtual_memory
ram_gb = virtual_memory().total / 1e9
print('Your runtime has {:.1f} gigabytes of available RAM\n'.format(ram_gb))

if ram_gb < 20:
  print('Not using a high-RAM runtime')
else:
  print('You are using a high-RAM runtime!')

# Install pytorch-ssim, pytorch-msssim (psnr 필요)

In [None]:
# Install the metric packages used by the evaluation code
# (presumably the SSIM/PSNR helpers imported from utils — confirm).
!pip install pytorch_ssim
!pip install pytorch_msssim

# Training

In [None]:
import os
from data.dataset import ColorHintDataset
import torch.utils.data as data
import torch
import cv2
import tqdm
import numpy as np
import torch.optim as optim
import matplotlib.pyplot as plt
from data.transform import tensor2im
from model.res_unet.res_unet import ResUnet
from model.res_unet.res_unet_plus import ResUnetPlusPlus
from model.res_unet.unet import UNet
# from model.att_unet.att_unet import Unet
from model.res_unet.attention_unet import AttentionUNet
import matplotlib.image as img
import copy, time
# from model.sm.model import ResAttdU_Net
from utils import AverageMeter, SSIM, psnr, save_img
from torchsummary import summary

# Pick the first CUDA device when one is available, otherwise fall back to CPU.
if torch.cuda.is_available():
  device = "cuda:0"
  print('device 0 :', torch.cuda.get_device_name(0))
else:
  device = "cpu"


def _run_phase(model, loader, phase, criterion, optimizer):
    """Run one pass over ``loader`` and return the sample-weighted mean loss.

    In the ``'train'`` phase the model is put in training mode, autograd is
    enabled and ``optimizer`` is stepped after every batch. In any other phase
    the model is put in eval mode and autograd is disabled.
    """
    is_train = phase == 'train'
    model.train(is_train)  # train(False) is exactly what model.eval() does
    losses = AverageMeter()

    for i, batch in enumerate(tqdm.tqdm(loader)):
        l = batch["l"].to(device)
        ab = batch["ab"].to(device)
        hint = batch["hint"].to(device)

        # The network input is the L channel stacked with the colour hints;
        # the target is the ab channels.
        hint_image = torch.cat((l, hint), dim=1).float()

        # Track history only while training.
        with torch.set_grad_enabled(is_train):
            outputs = model(hint_image)
            loss = criterion(outputs, ab)
            if is_train:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

        losses.update(loss.item(), hint_image.size(0))
        if is_train and i % 500 == 0:
            print('\t Loss {loss.val:.4f} ({loss.avg:.4f})\t'.format(loss=losses))

    return losses.avg


def main():
    """Train AttentionUNet to predict ab channels from the L channel plus hints.

    Resumes from a fixed checkpoint and trains for ``epochs`` epochs with an
    L1 loss and Adam. The learning rate is multiplied by 0.95 after every
    training epoch. After each epoch the model is evaluated on the validation
    split; whenever the validation loss improves, a checkpoint named after the
    (1-based) epoch and loss is written to ``./checkpoints/``.

    Finally the elapsed time is printed and the loss curves are plotted.

    Returns:
        The model with the best-validation weights loaded.
    """
    # Change to your data root directory
    root_path = "./cv_project_na"
    check_path = './checkpoints/'
    resume_path = os.path.join(check_path, 'model-epoch-5-losses-0.01241.pth')
    epochs = 10

    # Make the output directories (exist_ok keeps the cell re-runnable).
    for directory in (check_path, './outputs/', './outputs/test',
                      './outputs/GroundTruth', './outputs/Hint', './outputs/Output'):
        os.makedirs(directory, exist_ok=True)

    # Load the data
    train_dataset = ColorHintDataset(root_path, 256, "train")
    val_dataset = ColorHintDataset(root_path, 256, "val")
    dataloaders = {
        'train': torch.utils.data.DataLoader(train_dataset, batch_size=6, num_workers=2, shuffle=True),
        'valid': torch.utils.data.DataLoader(val_dataset, batch_size=6, num_workers=2, shuffle=False),
    }
    print('train dataset: ', len(train_dataset))
    print('validation dataset: ', len(val_dataset))

    # Select the model. Alternatives: ResUnet(3), ResUnetPlusPlus(3), UNet().
    # Only the network actually trained is constructed, so the alternatives
    # no longer occupy memory.
    model = AttentionUNet(3).to(device)

    # Resume from an earlier checkpoint. map_location lets a checkpoint saved
    # on a GPU load on a CPU-only runtime as well.
    model.load_state_dict(torch.load(resume_path, map_location=device))

    criterion = torch.nn.L1Loss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
    # summary(model, (3, 256, 256))
    scheduler = optim.lr_scheduler.MultiplicativeLR(optimizer, lr_lambda=lambda epoch: 0.95)

    since = time.time()
    train_loss, valid_loss = [], []
    best_model_wts = copy.deepcopy(model.state_dict())
    best_loss = float('inf')
    best_idx = 0  # 1-based epoch of the best checkpoint; 0 means none saved

    for epoch in range(1, epochs + 1):
        print('Epoch {}/{}'.format(epoch, epochs))
        print('-' * 10)

        # Each epoch has a training and a validation phase.
        for phase in ['train', 'valid']:
            avg_loss = _run_phase(model, dataloaders[phase], phase, criterion, optimizer)
            if phase == 'train':
                scheduler.step()
                train_loss.append(avg_loss)
            else:
                valid_loss.append(avg_loss)
            print(' {} Loss: {:.3f} '.format(phase, avg_loss))

        # Keep a copy of the best weights and checkpoint them to disk.
        if valid_loss[-1] < best_loss:
            best_idx = epoch
            best_loss = valid_loss[-1]
            best_model_wts = copy.deepcopy(model.state_dict())
            torch.save(model.state_dict(), os.path.join(
                check_path, 'model-epoch-{}-losses-{:.5f}.pth'.format(epoch, best_loss)))
            print('==> best model saved - %d / %.3f' % (best_idx, best_loss))

    # Training result
    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    print('Best valid Loss: %d - %.4f' % (best_idx, best_loss))

    # Hand back the best model rather than the last-epoch one.
    model.load_state_dict(best_model_wts)

    # Plot the training procedure (x-axis uses the same 1-based epochs as the logs).
    epoch_axis = np.arange(1, epochs + 1)
    plt.figure()
    plt.title('LOSS')
    plt.plot(epoch_axis, train_loss, epoch_axis, valid_loss, 'r-')
    plt.legend(['Train', 'Validation'], loc='best')
    plt.show()

    return model

if __name__ == '__main__':
    main()

# Test & Predict

In [None]:
from data.dataset import ColorHintDataset
import torch
import torch.utils.data as data
import cv2
import tqdm
import os
from data.transform import tensor2im
from model.res_unet.res_unet import ResUnet
from model.res_unet.res_unet_plus import ResUnetPlusPlus
from model.res_unet.unet import UNet
from model.res_unet.attention_unet import AttentionUNet
import matplotlib.pyplot as plt
from google.colab.patches import cv2_imshow

# Run on the first CUDA device if present; otherwise stay on the CPU.
if torch.cuda.is_available():
  device = "cuda:0"
  print('device 0 :', torch.cuda.get_device_name(0))
else:
  device = "cpu"
  
def main(show_images=True):
    """Colourise every image of the 'test' split and save the predictions.

    For each test sample the model predicts the ab channels from the L
    channel plus hints. The prediction is stacked back onto L, converted from
    Lab to BGR and written to ``outputs/test/<file name>``.

    Args:
        show_images: when True (the default, matching the previous behaviour)
            the hint image and the prediction are shown with ``cv2_imshow``
            and the loop waits for Enter before moving on. Pass False to
            process the whole split unattended.
    """
    # Change to your data root directory
    root_path = "./cv_project_na"
    checkpoint = './checkpoints/model-epoch-5-losses-0.01241.pth'
    out_dir = 'outputs/test'

    test_dataset = ColorHintDataset(root_path, 256, 'test')
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=False)
    print('test dataset: ', len(test_dataset))

    # map_location lets a GPU-saved checkpoint load on a CPU-only runtime too.
    model = AttentionUNet(3).to(device)
    model.load_state_dict(torch.load(checkpoint, map_location=device))
    model.eval()

    os.makedirs(out_dir, exist_ok=True)

    # Pure inference: disabling autograd avoids building a graph per image.
    with torch.no_grad():
        for batch in tqdm.tqdm(test_loader):
            l = batch["l"].to(device)
            hint = batch["hint"].to(device)
            file_name = batch["file_name"][0]

            hint_image = torch.cat((l, hint), dim=1)
            hint_np = tensor2im(hint_image)

            # Model output is the predicted ab channels; stack L back on top.
            output = model(hint_image.float())
            output = torch.cat((l, output), dim=1)
            out_hint_np = tensor2im(output)

            hint_bgr = cv2.cvtColor(hint_np, cv2.COLOR_LAB2BGR)
            out_hint_bgr = cv2.cvtColor(out_hint_np, cv2.COLOR_LAB2BGR)

            if show_images:
                cv2_imshow(hint_bgr)
                cv2_imshow(out_hint_bgr)
                input()  # pause until Enter so each pair can be inspected

            # Save the prediction. The bracket stripping guards against the
            # name arriving as a stringified one-element list.
            fname = str(file_name).replace("['", '').replace("']", '')
            out_path = os.path.join(out_dir, fname)
            if not cv2.imwrite(out_path, out_hint_bgr):
                print('failed to write', out_path)

if __name__ == '__main__':
    main()