In [None]:
from torchvision import transforms
from torch.utils.data import DataLoader

from utils.dataset import CustomDataset
from torchsampler import ImbalancedDatasetSampler

import torch
import torch.nn as nn
import os
os.environ['KMP_DUPLICATE_LIB_OK']='True'
import numpy as np 

import random

from torchsummary import summary

from torch.nn import functional as F 

from tqdm import tqdm

import matplotlib.pyplot as plt


In [None]:
# --- Dataset locations -------------------------------------------------------
# Training and validation images live in separate folders but point at the
# same mask folder.
tr_img_dir = '/mnt/HDD/octc/mask_abstract/train'
vl_img_dir = '/mnt/HDD/octc/mask_abstract/test'
tr_mask_dir = '/mnt/HDD/octc/mask_abstract/mask'
vl_mask_dir = '/mnt/HDD/octc/mask_abstract/mask'

# --- Image transforms --------------------------------------------------------
# Training: replicate the grayscale image to 3 channels, then apply light
# photometric augmentation before converting to a tensor.
# NOTE(review): the single positional argument of RandomAdjustSharpness is the
# sharpness factor (its probability stays at the library default) — confirm
# that 0.5 was meant as a factor rather than a probability.
train_transform = transforms.Compose(
    [
        transforms.Grayscale(num_output_channels=3),
        transforms.GaussianBlur(kernel_size=3, sigma=(0.1, 2.0)),
        transforms.RandomAdjustSharpness(sharpness_factor=0.5),
        transforms.RandomAutocontrast(p=0.5),
        transforms.ColorJitter(brightness=0.05, contrast=0.05, saturation=0.05, hue=0.05),
        transforms.ToTensor(),
    ]
)

# Validation: no augmentation, only the channel replication and tensor conversion.
valid_transform = transforms.Compose(
    [
        transforms.Grayscale(num_output_channels=3),
        transforms.ToTensor(),
    ]
)

# --- Datasets ----------------------------------------------------------------
# The training set shuffles its masks; the validation set runs in testing mode
# with a fixed mask order. The validation loader is unpacked as
# (images, masks, paths) further down, the training loader as (images, masks).
train_dataset = CustomDataset(
    image_dir=tr_img_dir,
    mask_dir=tr_mask_dir,
    transform=train_transform,
    mask_transform=None,
    testing=False,
    mask_shuffle=True,
)
valid_dataset = CustomDataset(
    image_dir=vl_img_dir,
    mask_dir=vl_mask_dir,
    transform=valid_transform,
    mask_transform=None,
    testing=True,
    mask_shuffle=False,
)

# --- Loaders -----------------------------------------------------------------
tr_batch = 6
vl_batch = 2
train_loader = DataLoader(train_dataset, batch_size=tr_batch, shuffle=True)
valid_loader = DataLoader(valid_dataset, batch_size=vl_batch, shuffle=False)


In [None]:
def plotting(images, masks, input_images):
    """Show one sample of a batch as three side-by-side grayscale panels.

    The sample at batch index 1, channel 0 is drawn from each argument, so
    every batch passed in must hold at least two samples.

    Args:
        images: batch of ground-truth images, indexed as [sample, channel].
        masks: batch of masks, indexed the same way.
        input_images: batch of masked generator inputs, indexed the same way.
    """
    panels = (
        (131, images, 'image[GT]'),
        (132, masks, 'mask'),
        (133, input_images, 'input_image'),
    )

    plt.figure(dpi=256)
    for position, batch, caption in panels:
        plt.subplot(position)
        plt.imshow(batch[1, 0], cmap='gray')
        plt.title(caption)
    plt.tight_layout()
    plt.show()

In [None]:
# Preview the first training batch to sanity-check the masking step.
for images, masks in train_loader:
    print(images.shape, masks.shape)
    # Wherever the mask is non-zero, the image pixel is overwritten with the mask value.
    # NOTE(review): the boolean indexing needs masks and images to share a shape and no
    # resize happens here, so the dataset presumably returns image-shaped masks — confirm.
    hole = masks != 0
    input_images = images.clone()
    input_images[hole] = masks[hole]
    # The masked input is built, so reduce the mask to a single channel again.
    masks = masks[:, 0, :, :].unsqueeze(1)

    plotting(images, masks, input_images)
    break

# Same preview for the first validation batch (this loader also yields file paths).
for images, masks, paths in valid_loader:
    hole = masks != 0
    input_images = images.clone()
    input_images[hole] = masks[hole]
    # Reduce the mask to a single channel again.
    masks = masks[:, 0, :, :].unsqueeze(1)

    plotting(images, masks, input_images)
    break


In [None]:
from model.aotgan import InpaintGenerator, Discriminator
# Build the generator and discriminator with their default constructor arguments.
# NOTE(review): the next cell constructs both networks again and moves them to the
# selected device, replacing these instances — confirm whether this cell is still needed.
netG = InpaintGenerator()
netD = Discriminator()

In [None]:
from torchsummary import summary
from loss.loss import L1, Perceptual, Style, smgan 
import torch.optim as optim
import torch

# Use the GPU when one is visible, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Fresh networks on the target device. The discriminator is built exactly once
# (it was previously constructed twice, with the first instance thrown away).
netG = InpaintGenerator().to(device)
netD = Discriminator().to(device)

# map_location keeps the checkpoints loadable when they were saved from a CUDA
# device but this session only has a CPU (the fallback branch of `device`).
G_check = torch.load('/mnt/HDD/oci_models/aotgan/premodel/G0000000.pt', map_location=device)
# The discriminator checkpoint is loaded but never applied to netD in this cell.
D_check = torch.load('/mnt/HDD/oci_models/aotgan/premodel/D0000000.pt', map_location=device) #<-- Discriminator is not needed
O_check = torch.load('/mnt/HDD/oci_models/aotgan/premodel/O0000000.pt', map_location=device)

# Only the generator starts from the pretrained weights.
netG.load_state_dict(G_check)

optimG = optim.Adam(params=netG.parameters(), lr=0.0001, betas=(0, 0.9))
optimD = optim.Adam(params=netD.parameters(), lr=0.0001, betas=(0, 0.9))
# Restore both optimizer states from the pretrained run.
# NOTE(review): optimD's restored state belongs to the pretrained discriminator,
# while netD above is freshly initialised — confirm this mix is intended.
optimG.load_state_dict(O_check['optimG'])
optimD.load_state_dict(O_check['optimD'])

# Generator loss terms (reconstruction, perceptual, style) and the GAN loss helper.
g_loss_1 = L1()
g_loss_2 = Perceptual()
g_loss_3 = Style()
loss_gan = smgan()

def G_LOSS(Adv, L1, Per, Style, w1 = 0.01, w2=1 , w3 = 0.1, w4 = 100):
    """Combine the four generator loss terms into one weighted sum.

    Args:
        Adv: adversarial loss term, scaled by ``w1``.
        L1: L1 reconstruction loss term, scaled by ``w2``.
        Per: perceptual loss term, scaled by ``w3``.
        Style: style loss term, scaled by ``w4``.
        w1, w2, w3, w4: weights of the respective terms.

    Returns:
        ``Adv*w1 + L1*w2 + Per*w3 + Style*w4``.
    """
    adversarial_term = Adv * w1
    reconstruction_term = L1 * w2
    perceptual_term = Per * w3
    style_term = Style * w4
    return adversarial_term + reconstruction_term + perceptual_term + style_term

# Per-epoch loss history: t_/v_ = train/validation, g/d = generator/discriminator.
# Every key starts with its own empty list.
metrics = {
    history_key: []
    for history_key in ('t_g_loss', 't_d_loss', 'v_g_loss', 'v_d_loss')
}

In [None]:
def save_validation(labels, masks, input_images, pred_images, pred_masks, comp_images,  epoch, save_dir):
    """Display a 2x3 grayscale overview of the first sample of a validation batch.

    Each panel shows sample 0, channel 0 of one tensor after moving it to the
    CPU and converting it to a NumPy array.

    Args:
        labels: ground-truth images.
        masks: masks fed to the generator.
        input_images: masked generator inputs.
        pred_images: raw generator outputs.
        pred_masks: mask prediction returned alongside the GAN losses.
        comp_images: composited result images.
        epoch: currently unused (only needed by the disabled savefig call).
        save_dir: currently unused (only needed by the disabled savefig call).
    """
    panels = (
        (231, labels, "original"),
        (232, masks, "mask"),
        (233, input_images, 'input image'),
        (234, pred_images, 'pred before image'),
        (235, pred_masks, 'pred Discriminator mask'),
        (236, comp_images, 'Result image'),
    )

    plt.figure(dpi=128)
    for position, batch, caption in panels:
        plt.subplot(position)
        plt.imshow(batch[0, 0].cpu().detach().numpy(), cmap='gray')
        plt.title(caption)
    plt.tight_layout()
    # Writing the figure to disk is disabled; the figure is only shown inline:
    # plt.savefig(os.path.join(save_dir, f'epoch_{epoch}.png'))
    plt.show()
    plt.close()

def save_model(netG, netD, optimG, optimD, epoch, save_dir):
    """Write one checkpoint file holding both networks and both optimizers.

    Args:
        netG: generator; its ``state_dict()`` is stored under ``'netG'``.
        netD: discriminator; its ``state_dict()`` is stored under ``'netD'``.
        optimG: generator optimizer, stored under ``'optimG'``.
        optimD: discriminator optimizer, stored under ``'optimD'``.
        epoch: value used in the file name ``epoch_<epoch>.pt``.
        save_dir: directory that receives the file.
    """
    checkpoint_path = os.path.join(save_dir, f'epoch_{epoch}.pt')
    checkpoint = dict(
        netG=netG.state_dict(),
        netD=netD.state_dict(),
        optimG=optimG.state_dict(),
        optimD=optimD.state_dict(),
    )
    torch.save(checkpoint, checkpoint_path)

def save_loss(metrics, save_dir):
    """Plot every loss history into one figure and dump the raw histories.

    Writes ``loss.png`` (one line per key of ``metrics``, with a legend) and
    ``metrics.npy`` into ``save_dir``.

    Args:
        metrics: mapping from a curve name to its sequence of values.
        save_dir: directory that receives both files.
    """
    figure_path = os.path.join(save_dir, 'loss.png')
    history_path = os.path.join(save_dir, 'metrics.npy')

    # One labelled curve per tracked loss.
    plt.figure(dpi=128)
    for curve_name in metrics:
        plt.plot(metrics[curve_name], label=curve_name)
    plt.legend()
    plt.savefig(figure_path)
    plt.close()

    # NOTE(review): `metrics` is passed to np.save as-is (a dict at the call site), so
    # reading it back presumably needs np.load(..., allow_pickle=True) — confirm.
    np.save(history_path, metrics)


In [None]:
save_path = '/mnt/HDD/oci_models/aotgan/240505'
os.makedirs(save_path, exist_ok=True)

NUM_EPOCHS = 30
SAVE_EVERY = 5  # preview / checkpoint / loss-plot interval, in epochs


def _mask_inputs(images, masks):
    """Build the generator input for one batch.

    Pixels where the mask is non-zero are overwritten with the mask value.
    The boolean indexing requires `masks` to have the same shape as `images`.

    Returns:
        (input_images, masks) where the returned mask keeps only channel 0,
        with the channel axis retained.
    """
    hole = masks != 0
    input_images = images.clone()
    input_images[hole] = masks[hole]
    return input_images, masks[:, 0, :, :].unsqueeze(1)


def _composite(images, pred_images, masks):
    """Paste generator output into the ground truth wherever the mask is non-zero.

    `masks` is the single-channel mask; it is repeated across the image
    channels only to build the boolean selection.
    """
    hole = masks.repeat(1, images.size(1), 1, 1) != 0
    comp_images = images.clone()
    comp_images[hole] = pred_images[hole]
    return comp_images


for epoch in range(NUM_EPOCHS):
    t_g_losses, t_d_losses, v_g_losses, v_d_losses = 0., 0., 0., 0.

    # ---------------------------------------------------------------- training
    netG.train()
    netD.train()
    for images, masks in tqdm(train_loader, desc=f'epoch {epoch} [train]', leave=False):
        images, masks = images.to(device), masks.to(device)
        input_images, masks = _mask_inputs(images, masks)

        pred_images = netG(input_images, masks)
        comp_images = _composite(images, pred_images, masks)

        l1 = g_loss_1(pred_images, images)
        l2 = g_loss_2(pred_images, images)
        l3 = g_loss_3(pred_images, images)
        pred_masks, (d_loss, adv_loss) = loss_gan(netD=netD, fake=comp_images, real=images, masks=masks)
        g_loss = G_LOSS(adv_loss, l1, l2, l3)

        optimG.zero_grad()
        optimD.zero_grad()

        # NOTE(review): two separate backward passes only work if loss_gan builds
        # d_loss on a graph that does not share buffers with g_loss — confirm in smgan.
        g_loss.backward()
        d_loss.backward()

        optimG.step()
        optimD.step()

        t_g_losses += g_loss.item()
        t_d_losses += d_loss.item()

    # -------------------------------------------------------------- validation
    netG.eval()
    netD.eval()
    with torch.no_grad():
        for images, masks, paths in tqdm(valid_loader, desc=f'epoch {epoch} [valid]', leave=False):
            images, masks = images.to(device), masks.to(device)
            input_images, masks = _mask_inputs(images, masks)

            pred_images = netG(input_images, masks)
            comp_images = _composite(images, pred_images, masks)

            l1 = g_loss_1(pred_images, images)
            l2 = g_loss_2(pred_images, images)
            l3 = g_loss_3(pred_images, images)
            pred_masks, (d_loss, adv_loss) = loss_gan(netD=netD, fake=comp_images, real=images, masks=masks)
            g_loss = G_LOSS(adv_loss, l1, l2, l3)

            v_g_losses += g_loss.item()
            v_d_losses += d_loss.item()

    # Record this epoch's averages BEFORE saving, so the loss plot and
    # metrics.npy written below include the epoch that just finished.
    metrics['t_g_loss'].append(t_g_losses / len(train_loader))
    metrics['t_d_loss'].append(t_d_losses / len(train_loader))
    metrics['v_g_loss'].append(v_g_losses / len(valid_loader))
    metrics['v_d_loss'].append(v_d_losses / len(valid_loader))

    # Save periodically, and always after the final epoch so the last model is kept.
    # The preview uses the tensors of the last validation batch.
    if epoch % SAVE_EVERY == 0 or epoch == NUM_EPOCHS - 1:
        save_validation(images, masks, input_images, pred_images, pred_masks, comp_images,  epoch, save_dir = save_path)
        save_model(netG, netD, optimG, optimD, epoch, save_dir = save_path)
        save_loss(metrics, save_dir = save_path)

    print("#" * 100)
    print(f"Train - G LOSS : {metrics['t_g_loss'][-1]} | D LOSS : {metrics['t_d_loss'][-1]}\n")
    print(f"Valid - G LOSS : {metrics['v_g_loss'][-1]} | D LOSS : {metrics['v_d_loss'][-1]}\n")
    print("#" * 100)
    
