In [1]:
import argparse
import os
import torch
from tqdm import tqdm
import random
import torch.nn as nn
from torch.utils.data import DataLoader
from model_our_sim import DPSimulator
import numpy as np
from skimage.metrics import structural_similarity
from dataset import MyDataset
import h5py
from pathlib import Path



In [2]:


def save_h5py_file(name, my_dict):
    """Write every entry of ``my_dict`` as a dataset in a new HDF5 file.

    Args:
        name: Output file path. An existing file at this path is truncated
            (mode ``'w'``).
        my_dict: Mapping from dataset name to an array-like value. Each value
            is wrapped in a one-element array and then squeezed, so all
            singleton axes (e.g. a batch dimension of size 1) are dropped
            before writing.
    """
    # The context manager closes the file even if create_dataset raises,
    # instead of leaking an open HDF5 handle.
    with h5py.File(name, 'w') as h5_file:
        for key, value in my_dict.items():
            h5_file.create_dataset(key, data=np.array([value]).squeeze())




def MAE_PSNR_SSIM(batch_img_1, batch_img_2, reduction='sum', mask=None):
    """Compute masked MAE, PSNR and SSIM between two image batches.

    Both batches are clipped to [0, 1], and each metric is averaged only over
    pixels where the mask is non-zero. Each selected value is also multiplied
    by the mask value.

    Args:
        batch_img_1, batch_img_2: Image tensors in (B, C, H, W) layout; the
            code transposes axes (0, 2, 3, 1) to obtain channels-last arrays.
        reduction: ``'sum'`` or ``'mean'`` over the batch.
        mask: Optional tensor in the same (B, C', H, W) layout, where C' is
            either C or 1. A single-channel mask applies to every channel.
            ``None`` means every pixel counts.

    Returns:
        Tuple ``(mae, psnr, ssim)`` reduced over the batch.

    Raises:
        ValueError: If ``reduction`` is neither ``'sum'`` nor ``'mean'``.

    Note:
        A sample whose mask has no non-zero pixel yields NaN (mean of an empty
        selection). Identical masked pixels give an infinite PSNR.
    """
    import inspect

    if reduction not in ('sum', 'mean'):
        raise ValueError("reduction must be 'sum' or 'mean', got {!r}".format(reduction))

    # scikit-image 0.19 deprecated `multichannel` in favour of `channel_axis`.
    # Pass whichever keyword the installed version actually exposes.
    if 'channel_axis' in inspect.signature(structural_similarity).parameters:
        ssim_channel_kwargs = {'channel_axis': -1}
    else:
        ssim_channel_kwargs = {'multichannel': True}

    def _my_mae(_x, _y, _mask):
        _diff = np.abs(_x - _y) * _mask
        return _diff[np.nonzero(_mask)].mean()

    def _my_ssim(_x, _y, _win_size=7, _data_range=1.0, _mask=None):
        _, _ssim_mat = structural_similarity(_x, _y, data_range=_data_range, win_size=_win_size, full=True,
                                             **ssim_channel_kwargs)
        # Discard a border of width (win_size - 1) // 2 along H and W. Explicit
        # stop indices keep this correct when the pad is 0, where `-0` would
        # produce an empty slice.
        _pad = (_win_size - 1) // 2
        _crop = (slice(_pad, _ssim_mat.shape[0] - _pad), slice(_pad, _ssim_mat.shape[1] - _pad))
        _inner_mask = _mask[_crop]
        _ssim_mat = _ssim_mat[_crop] * _inner_mask
        return _ssim_mat[np.nonzero(_inner_mask)].mean()

    def _my_psnr(_x, _y, _data_range=1.0, _mask=None):
        _diff = ((_x - _y) ** 2) * _mask
        return 10 * np.log10((_data_range ** 2) / _diff[np.nonzero(_mask)].mean())

    def _to_channels_last(_t):
        # Promote half/bfloat16 tensors (e.g. autocast outputs) to at least
        # float32. bfloat16 has no NumPy dtype, and float16 arithmetic in NumPy
        # is imprecise.
        _t = _t.detach().to(torch.promote_types(_t.dtype, torch.float32))
        return _t.cpu().numpy().transpose(0, 2, 3, 1)

    batch_1 = _to_channels_last(torch.clip(batch_img_1, 0, 1))
    batch_2 = _to_channels_last(torch.clip(batch_img_2, 0, 1))
    if mask is None:
        batch_mask = np.ones_like(batch_1)
    else:
        batch_mask = mask.detach().cpu().numpy().transpose(0, 2, 3, 1)

    mae, psnr, ssim = [], [], []
    for x, y, m in zip(batch_1, batch_2, batch_mask):
        # Broadcast a single-channel (H, W, 1) mask to (H, W, C). Otherwise
        # np.nonzero(m) returns channel index 0 only, and every metric would be
        # computed on the first channel alone.
        m = np.broadcast_to(m, x.shape)
        mae.append(_my_mae(x, y, _mask=m))
        psnr.append(_my_psnr(x, y, _mask=m))
        ssim.append(_my_ssim(x, y, _mask=m))

    reduce_fn = np.sum if reduction == 'sum' else np.mean
    return reduce_fn(mae), reduce_fn(psnr), reduce_fn(ssim)
    
    

def norm_dep(dep):
    """Min-max normalise each depth map in a batch, marking invalid pixels with -1.

    Zero-valued pixels are treated as invalid. For each sample:
      - invalid pixels are temporarily filled with the sample's maximum, so
        they do not lower the minimum;
      - the map is rescaled with (x - min) / (max - min);
      - invalid pixels are then set to -1.

    Args:
        dep: Tensor whose first dimension indexes samples. It is not modified.

    Returns:
        A new tensor with the same shape and dtype as ``dep``. If all valid
        pixels of a sample share one value, they become 0 instead of NaN. An
        all-zero sample becomes all -1.
    """
    all_new_dep = torch.zeros_like(dep)
    for i, sample in enumerate(dep):
        valid = sample != 0
        # Work on a copy: `sample` is a view into `dep`, so writing into it
        # would silently alter the caller's tensor.
        filled = sample.clone()
        filled[~valid] = filled.max()
        lo, hi = filled.min(), filled.max()
        span = hi - lo
        if span > 0:
            new_dep = (filled - lo) / span
        else:
            # Constant map: avoid 0 / 0 = NaN.
            new_dep = torch.zeros_like(filled)
        new_dep[~valid] = -1
        all_new_dep[i] = new_dep
    return all_new_dep



In [3]:



def test(args):
    """Evaluate the pretrained DP simulator on the test split.

    For each test sample:
      - predict the left and right dual-pixel views;
      - copy the sharp input into pixels whose depth is zero;
      - accumulate masked MAE / PSNR / SSIM against the ground-truth views
        (left/right averaged).
    Finally, print the per-sample averages. Every 10th sample's inputs,
    predictions and adder outputs are written to
    ``<args.res_dir>/<curr_name>.h5``.

    Args:
        args: Namespace providing ``data_dir``, ``res_dir``, ``n_worker``,
            ``device`` and optionally ``cp_dir`` (state-dict checkpoint path;
            defaults to ``./checkpoints/DP_simulator.cp`` when absent).
    """
    def _np(tensor):
        return tensor.detach().cpu().numpy()

    ## dataloader
    test_set = MyDataset(args.data_dir, partition='test')
    num_samples = len(test_set)
    val_loader = DataLoader(test_set, batch_size=1, shuffle=False, num_workers=args.n_worker, drop_last=False)
    Path(args.res_dir).mkdir(parents=True, exist_ok=True)
    print('test size: {}'.format(num_samples))

    ## initialization
    cp_path = getattr(args, 'cp_dir', './checkpoints/DP_simulator.cp')
    model = DPSimulator(k_size=5)
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    model.load_state_dict(torch.load(cp_path, map_location='cpu'))
    model = model.to(args.device)
    print('init done')

    # Mixed precision only makes sense on a CUDA device.
    use_amp = torch.device(args.device).type == 'cuda'

    ## test
    curr_val_mae, curr_val_psnr, curr_val_ssim = 0, 0, 0
    model.eval()
    with torch.no_grad():
        for i, data in enumerate(tqdm(val_loader)):
            sharp, dep, coc = data['sharp'].to(args.device), data['dep'].to(args.device), data['coc'].to(args.device)
            dp_l, dp_r = data['dp_l'].to(args.device), data['dp_r'].to(args.device)
            # Zero depth marks invalid pixels; build the mask before normalising.
            mask = torch.where(dep == 0, 0, 1)
            # Normalise a copy, so data['dep'] (saved below) can never be
            # modified; on CPU, .to(device) returns the very same tensor.
            dep = norm_dep(dep.clone())

            with torch.autocast(device_type='cuda', enabled=use_amp):
                pred_l, pred_r, adder_l, adder_r = model(sharp, dep, coc)

            # Pixels without valid depth take the sharp input instead of the
            # model output.
            invalid = (mask == 0).expand_as(pred_l)
            pred_l[invalid] = sharp[invalid]
            pred_r[invalid] = sharp[invalid]

            mae_l, psnr_l, ssim_l = MAE_PSNR_SSIM(pred_l, dp_l, 'sum', mask)
            mae_r, psnr_r, ssim_r = MAE_PSNR_SSIM(pred_r, dp_r, 'sum', mask)
            curr_val_mae += (mae_l + mae_r) / 2
            curr_val_psnr += (psnr_l + psnr_r) / 2
            curr_val_ssim += (ssim_l + ssim_r) / 2

            ## save h5 file
            if i % 10 == 0:
                my_dict = {'dp_l': _np(data['dp_l']), 'dp_r': _np(data['dp_r']), 'sharp': _np(data['sharp']),
                           'pred_l': _np(pred_l), 'pred_r': _np(pred_r),
                           'dep': _np(data['dep']), 'coc': _np(data['coc']),
                           'mask': _np(mask), 'focus_dis': _np(data['focus_dis']), 'focus_pt': _np(data['focus_pt']),
                           'adder_l': _np(adder_l), 'adder_r': _np(adder_r)}
                save_h5py_file(os.path.join(args.res_dir, '{}.h5'.format(data['curr_name'][0])), my_dict)

    ## results
    if num_samples == 0:
        print('test set is empty; no metrics to report')
    else:
        curr_val_mae, curr_val_psnr, curr_val_ssim = (curr_val_mae / num_samples, curr_val_psnr / num_samples,
                                                      curr_val_ssim / num_samples)
        print('curr_val_mae: {:.5f}, curr_val_psnr: {:.5f}, curr_val_ssim: {:.5f}'.format(curr_val_mae, curr_val_psnr, curr_val_ssim))

    print('test finished')







In [None]:





if __name__ == '__main__':

    parser = argparse.ArgumentParser()
    parser.add_argument('--device', type=str, default='cuda:0', help='cuda device')
    parser.add_argument('--n_worker', type=int, default=8, help='number of workers')
    parser.add_argument('--seed', type=int, default=0, help='random seed')
    # Default matches the checkpoint the evaluation has been loading.
    parser.add_argument('--cp_dir', type=str, default='./checkpoints/DP_simulator.cp', help='checkpoint file path')
    parser.add_argument('--data_dir', type=str, default='/dataset/workspace2021/li/final_data', help='data directory')
    parser.add_argument('--res_dir', type=str, default='./res_our_sim', help='result directory')
    # Parse an empty list so the notebook kernel's own command-line arguments
    # are never interpreted as script options; edit the defaults above instead.
    _args = parser.parse_args(args=[])

    # Fix seeds for reproducibility.
    np.random.seed(_args.seed)
    torch.manual_seed(_args.seed)
    random.seed(_args.seed)
    if _args.device != 'cpu':
        # Trade cuDNN autotuning speed for deterministic kernels.
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True

    # Evaluate on the test split.
    test(_args)


