In [1]:
import argparse
import os
import torch
from tqdm import tqdm
from matplotlib import pyplot as plt
import random
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from model_our_sim import DPSimulator
import numpy as np
from loss import DPLoss, EdgeLoss
from skimage.metrics import peak_signal_noise_ratio, structural_similarity
from dataset import MyDataset, collate_fn_replace_corrupted
import functools
import h5py
from pathlib import Path



In [2]:


def save_h5py_file(name, my_dict):
    """Write every entry of ``my_dict`` as a dataset into a new HDF5 file.

    Args:
        name: path of the HDF5 file. It is opened in ``'w'`` mode, so an
            existing file at that path is overwritten.
        my_dict: mapping from dataset name to value. Each value is wrapped in
            a list, converted with ``np.array`` and squeezed before storing.
    """
    # The context manager closes the file even when a dataset fails to be
    # created, so a partially written file handle is never leaked.
    with h5py.File(name, 'w') as h:
        for k, v in my_dict.items():
            h.create_dataset(k, data=np.array([v]).squeeze())



def MAE_PSNR_SSIM(batch_img_1, batch_img_2, reduction='sum', mask=None):
    """Compute masked MAE, PSNR and SSIM between two batches of images.

    Both batches are clipped to [0, 1], moved to the CPU and transposed with
    axes (0, 2, 3, 1) before the per-image metrics are evaluated, so they must
    be 4-D tensors (presumably (N, C, H, W) -- confirm against the caller).

    Args:
        batch_img_1: first batch of images (torch tensor).
        batch_img_2: second batch of images, same shape as ``batch_img_1``.
        reduction: ``'sum'`` adds the per-image metrics over the batch,
            ``'mean'`` averages them.
        mask: optional tensor; non-zero entries mark the pixels to evaluate.
            It is transposed like the images and broadcast to the image shape,
            so a single-channel mask applies to every channel. ``None``
            evaluates every pixel.

    Returns:
        Tuple ``(mae, psnr, ssim)`` reduced over the batch.

    Raises:
        ValueError: if ``reduction`` is neither ``'sum'`` nor ``'mean'``.
    """
    # Fail early instead of silently returning None after all the work is done.
    if reduction not in ('sum', 'mean'):
        raise ValueError("reduction must be 'sum' or 'mean', got {!r}".format(reduction))

    def _my_mae(_x, _y, _mask):
        _diff = np.abs(_x - _y) * _mask
        return _diff[np.nonzero(_mask)].mean()

    def _my_ssim(_x, _y, _win_size=7, _data_range=1.0, _mask=None):
        # NOTE(review): newer scikit-image releases take `channel_axis` instead of
        # `multichannel` -- confirm against the installed version before changing.
        _, _ssim_mat = structural_similarity(_x, _y, data_range=_data_range, multichannel=True, win_size=_win_size, full=True)
        # Drop the border where the SSIM window does not fit entirely inside the image.
        _pad = (_win_size - 1) // 2
        _inner_mask = _mask[_pad:-_pad, _pad:-_pad]
        _ssim_mat = _ssim_mat[_pad:-_pad, _pad:-_pad] * _inner_mask
        return _ssim_mat[np.nonzero(_inner_mask)].mean()

    def _my_psnr(_x, _y, _data_range=1.0, _mask=None):
        _diff = ((_x - _y) ** 2) * _mask
        return 10 * np.log10((_data_range ** 2) / _diff[np.nonzero(_mask)].mean())

    batch_img_1, batch_img_2 = torch.clip(batch_img_1, 0, 1), torch.clip(batch_img_2, 0, 1)
    batch_1, batch_2 = batch_img_1.detach().cpu().numpy(), batch_img_2.detach().cpu().numpy()
    batch_1, batch_2 = batch_1.transpose(0, 2, 3, 1), batch_2.transpose(0, 2, 3, 1)
    if mask is None:
        batch_mask = np.ones_like(batch_1)
    else:
        batch_mask = mask.detach().cpu().numpy().transpose(0, 2, 3, 1)
    mae, psnr, ssim = [], [], []
    for x, y, m in zip(batch_1, batch_2, batch_mask):
        # np.nonzero() on a single-channel mask yields indices whose channel
        # component is always 0, which would score only the first image
        # channel. Broadcasting makes the mask select every channel.
        m = np.broadcast_to(m, x.shape)
        mae.append(_my_mae(x, y, _mask=m))
        psnr.append(_my_psnr(x, y, _mask=m))
        ssim.append(_my_ssim(x, y, _mask=m))
    if reduction == 'sum':
        return np.sum(mae), np.sum(psnr), np.sum(ssim)
    return np.mean(mae), np.mean(psnr), np.mean(ssim)
    
    

def norm_dep(dep):
    """Min-max normalise each depth map of a batch, marking invalid pixels.

    For every sample along the first dimension, pixels equal to 0 are treated
    as invalid. The remaining values are rescaled to [0, 1] using the sample's
    own minimum and maximum over the valid pixels, and invalid pixels are set
    to -1. The input tensor is left untouched.

    Args:
        dep: batch of depth maps (torch tensor, first dimension is the batch).

    Returns:
        A new tensor with the same shape and dtype as ``dep``.
    """
    all_new_dep = torch.zeros_like(dep)
    for i, sample in enumerate(dep):
        valid = sample != 0
        # Fill invalid pixels with the maximum so they do not affect min/max.
        # torch.where builds a new tensor, so the caller's `dep` is not modified.
        filled = torch.where(valid, sample, sample.max())
        lo = filled.min()
        span = filled.max() - lo
        # A sample with constant depth has a zero span; divide by 1 instead so
        # the result is 0 rather than NaN.
        safe_span = torch.where(span == 0, torch.ones_like(span), span)
        new_dep = (filled - lo) / safe_span
        all_new_dep[i] = new_dep.masked_fill(~valid, -1)
    return all_new_dep



In [3]:


def _forward_batch(model, data, criterion_dp, criterion_edge, device):
    """Move one batch to ``device``, run the model under autocast and compute the loss."""
    sharp, dep, coc = data['sharp'].to(device), data['dep'].to(device), data['coc'].to(device)
    dp_l, dp_r = data['dp_l'].to(device), data['dp_r'].to(device)
    # The validity mask has to be taken from the raw depth, before normalisation.
    mask = torch.where(dep == 0, 0, 1)
    dep = norm_dep(dep)

    with torch.cuda.amp.autocast():
        pred_l, pred_r, _, _ = model(sharp, dep, coc)
        loss = criterion_dp(pred_l, pred_r, dp_l, dp_r, mask) + criterion_edge(pred_l, dp_l, mask) + criterion_edge(pred_r, dp_r, mask)

    return loss, pred_l, pred_r, dp_l, dp_r, mask, sharp.shape[0]


def _batch_metrics(pred_l, pred_r, dp_l, dp_r, mask):
    """Return the batch-summed MAE/PSNR/SSIM, averaged over the left and right views."""
    mae_l, psnr_l, ssim_l = MAE_PSNR_SSIM(pred_l, dp_l, 'sum', mask)
    mae_r, psnr_r, ssim_r = MAE_PSNR_SSIM(pred_r, dp_r, 'sum', mask)
    return (mae_l + mae_r) / 2, (psnr_l + psnr_r) / 2, (ssim_l + ssim_r) / 2


def train(args):
    """Train ``DPSimulator`` and validate it after every epoch.

    Reads ``data_dir``, ``required_dep_percent``, ``bs``, ``n_worker``,
    ``device``, ``lr``, ``n_epoch``, ``loss_type`` and ``task`` from ``args``.

    Side effects:
        * ``log_<task>.txt`` is truncated at start and gets one line per epoch.
        * ``./checkpoints`` is created if missing; the model state dict is
          saved there whenever validation SSIM, PSNR or MAE improves.

    Raises:
        ValueError: if the training or validation loader yields no batches.
    """
    ## dataloader
    train_set = MyDataset(args.data_dir, partition='train', required_dep_percent=args.required_dep_percent)
    val_set = MyDataset(args.data_dir, partition='valid')
    train_loader = DataLoader(train_set, batch_size=args.bs, shuffle=True, num_workers=args.n_worker, drop_last=True,
                              collate_fn=functools.partial(collate_fn_replace_corrupted, torch_dataset=train_set))
    val_loader = DataLoader(val_set, batch_size=1, shuffle=False, num_workers=args.n_worker, drop_last=False)
    print('train size: {}, validation size: {}'.format(len(train_set), len(val_set)))
    # The per-epoch averages divide by the number of processed samples, so an
    # empty loader is reported up front rather than as a ZeroDivisionError later.
    if len(train_loader) == 0 or len(val_loader) == 0:
        raise ValueError('training and validation loaders must each yield at least one batch')

    ## initialization
    model = DPSimulator(k_size=5)
    model = model.to(args.device)
    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, args.n_epoch, eta_min=1e-6)
    # One scaler for the whole run: re-creating it each epoch would throw away
    # the loss-scale value it has adapted so far.
    scaler = torch.cuda.amp.GradScaler()
    criterion_dp = DPLoss(loss_type=args.loss_type, use_mask=True)
    criterion_edge = EdgeLoss(device=args.device, use_mask=True)
    min_val_loss, min_val_mae, max_val_ssim, max_val_psnr = float('inf'), float('inf'), 0, 0
    log_path = 'log_{}.txt'.format(args.task)
    # Start from an empty log file; epochs append to it below.
    with open(log_path, 'w'):
        pass
    Path('./checkpoints').mkdir(exist_ok=True, parents=True)
    print('init done')

    ## loop
    for epoch in range(args.n_epoch):
        ## train phase
        curr_train_loss, curr_train_mae, curr_train_psnr, curr_train_ssim = 0, 0, 0, 0
        num_train = 0
        model.train()
        for data in tqdm(train_loader):
            optimizer.zero_grad()
            loss, pred_l, pred_r, dp_l, dp_r, mask, batch_size = _forward_batch(
                model, data, criterion_dp, criterion_edge, args.device)

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            curr_train_loss += loss.item()
            mae, psnr, ssim = _batch_metrics(pred_l, pred_r, dp_l, dp_r, mask)
            curr_train_mae, curr_train_psnr, curr_train_ssim = curr_train_mae + mae, curr_train_psnr + psnr, curr_train_ssim + ssim
            num_train += batch_size

        scheduler.step()
        curr_train_mae, curr_train_psnr, curr_train_ssim = curr_train_mae / num_train, curr_train_psnr / num_train, curr_train_ssim / num_train
        print('Epoch: {}, curr_train_loss: {:.5f}, curr_train_mae: {:.5f}, curr_train_psnr: {:.5f}, '
              'curr_train_ssim: {:.5f}'.format(epoch, curr_train_loss, curr_train_mae, curr_train_psnr, curr_train_ssim))

        ## validation phase
        curr_val_loss, curr_val_mae, curr_val_psnr, curr_val_ssim = 0, 0, 0, 0
        num_val = 0
        model.eval()
        with torch.no_grad():
            for data in tqdm(val_loader):
                loss, pred_l, pred_r, dp_l, dp_r, mask, batch_size = _forward_batch(
                    model, data, criterion_dp, criterion_edge, args.device)

                curr_val_loss += loss.item()
                mae, psnr, ssim = _batch_metrics(pred_l, pred_r, dp_l, dp_r, mask)
                num_val += batch_size
                curr_val_mae, curr_val_psnr, curr_val_ssim = curr_val_mae + mae, curr_val_psnr + psnr, curr_val_ssim + ssim

            ## checkpoints
            min_val_loss = min(min_val_loss, curr_val_loss)
            curr_val_mae, curr_val_psnr, curr_val_ssim = curr_val_mae / num_val, curr_val_psnr / num_val, curr_val_ssim / num_val
            if max_val_ssim < curr_val_ssim:
                max_val_ssim = curr_val_ssim
                torch.save(model.state_dict(), './checkpoints/{}_max_val_ssim.cp'.format(args.task))
            if max_val_psnr < curr_val_psnr:
                max_val_psnr = curr_val_psnr
                torch.save(model.state_dict(), './checkpoints/{}_max_val_psnr.cp'.format(args.task))
            if min_val_mae > curr_val_mae:
                min_val_mae = curr_val_mae
                torch.save(model.state_dict(), './checkpoints/{}_min_val_mae.cp'.format(args.task))

        ## logger to txt
        print('Epoch: {}, curr_val_loss: {:.5f}, min_val_loss: {:.5f}, curr_val_mae: {:.5f}, min_val_mae: {:.5f}, curr_val_psnr: {:.5f}, max_val_psnr: {:.5f}, '
              'curr_val_ssim: {:.5f}, max_val_ssim: {:.5f}'.format(epoch, curr_val_loss, min_val_loss, curr_val_mae, min_val_mae, curr_val_psnr, max_val_psnr, curr_val_ssim, max_val_ssim))

        # `with` guarantees the log is flushed and closed even if the write fails.
        with open(log_path, 'a') as f:
            f.write('Epoch: {}, curr_train_loss: {:.5f}, curr_train_mae: {:.5f}, curr_train_psnr: {:.5f}, curr_train_ssim: {:.5f}, curr_val_loss: {:.5f}, '
                    'min_val_loss: {:.5f}, curr_val_mae: {:.5f}, min_val_mae: {:.5f}, curr_val_psnr: {:.5f}, max_val_psnr: {:.5f}, curr_val_ssim: {:.5f}, '
                    'max_val_ssim: {:.5f}\n'.format(epoch, curr_train_loss, curr_train_mae, curr_train_psnr, curr_train_ssim, curr_val_loss, min_val_loss,
                     curr_val_mae, min_val_mae, curr_val_psnr, max_val_psnr, curr_val_ssim, max_val_ssim))

    print('training finished')




In [None]:





if __name__ == '__main__':
    # Build the run configuration with argparse so every option has a
    # documented default.
    parser = argparse.ArgumentParser()

    # train-related
    parser.add_argument('--task', default='DP_simulator', type=str, help='task name')
    parser.add_argument('--device', default='cuda:5', type=str, help='cuda device')
    parser.add_argument('--n_epoch', default=100, type=int, help='number of epochs')
    parser.add_argument('--bs', default=8, type=int, help='batch size')
    parser.add_argument('--lr', default=5e-4, type=float, help='learning rate')
    parser.add_argument('--loss_type', default='charbonnier', type=str, help='loss type')

    # others
    parser.add_argument('--seed', default=0, type=int, help='random seed')
    parser.add_argument('--data_dir', default='/dataset/workspace2021/li/final_data', type=str,
                        help='data directory')
    parser.add_argument('--n_worker', default=8, type=int, help='number of workers')
    parser.add_argument('--required_dep_percent', default=0.8, type=float,
                        help='required percent of valid depths of a patch in training')

    # An explicit empty argument list is parsed instead of sys.argv, so every
    # option takes the default declared above.
    _args = parser.parse_args(args=[])

    # fix seed for Python, NumPy and PyTorch
    random.seed(_args.seed)
    np.random.seed(_args.seed)
    torch.manual_seed(_args.seed)
    if _args.device != 'cpu':
        # Ask cuDNN for deterministic kernels and disable its autotuner.
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    # train
    train(_args)


