In [1]:
import argparse
import os
import torch
from tqdm import tqdm
import random
import torch.nn as nn
from torch.utils.data import DataLoader
from model_our_sim import DPSimulator
import numpy as np
from dataset import HypersimDataset
import h5py
from pathlib import Path



In [2]:


def save_h5py_file(name, my_dict):
    """Write every entry of ``my_dict`` as a dataset into a new HDF5 file.

    Args:
        name: Path of the HDF5 file to create; it is opened in ``'w'`` mode.
        my_dict: Mapping from dataset name to the value to store. Each value
            is wrapped in a list, converted to a NumPy array and squeezed, so
            all length-1 axes are removed before writing.
    """
    # The context manager closes the file even if create_dataset raises,
    # so a failed write does not leave an open handle behind.
    with h5py.File(name, 'w') as h:
        for k, v in my_dict.items():
            h.create_dataset(k, data=np.array([v]).squeeze())



def norm_dep(dep):
    """Min-max normalize each item of ``dep`` independently.

    Entries equal to 0 are treated as missing: they are excluded from the
    minimum by temporarily replacing them with the item's maximum, and they
    are set to -1 in the result. All other entries are rescaled with the
    item's own min and max.

    The input tensor is not modified. If an item has no value range (all
    filled values are equal), its non-missing entries are set to 0 instead
    of the NaN that a division by zero would produce.

    Args:
        dep: Tensor whose first dimension is iterated item by item.

    Returns:
        A new tensor created with ``torch.zeros_like(dep)`` holding the
        normalized items.
    """
    all_new_dep = torch.zeros_like(dep)
    for i, x in enumerate(dep):
        valid = x != 0
        # Build a filled copy rather than writing into ``x``: ``x`` is a view
        # of ``dep``, so in-place assignment would alter the caller's tensor.
        filled = torch.where(valid, x, x.max())
        lo = filled.min()
        span = filled.max() - lo
        # Replace a zero span by one so the division yields 0 rather than NaN.
        safe_span = torch.where(span == 0, torch.ones_like(span), span)
        new_dep = (filled - lo) / safe_span
        all_new_dep[i] = new_dep.masked_fill(~valid, -1)
    return all_new_dep



In [3]:


def generate_dp_from_rgbd(args, hypersim_partition):
    """Run the DP simulator over one partition and save one .h5 file per sample.

    Args:
        args: Namespace providing ``data_dir``, ``generated_dp_dir``,
            ``cp_dir``, ``device`` and ``n_worker``.
        hypersim_partition: Name of the sub-directory, below both
            ``args.data_dir`` and ``args.generated_dp_dir``, to process.
    """
    # Fields of each loaded sample that are written to the output file as-is.
    passthrough_keys = ('sharp', 'dep', 'coc', 'normal', 'focus_dis', 'thin_lens_focal_len_in_mm',
                        'f_number', 'M', 'pixel_size', 'af_pt')

    ## dataloader
    dataset = HypersimDataset(os.path.join(args.data_dir, hypersim_partition))
    loader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=args.n_worker, drop_last=False)
    print(f'validation size: {len(dataset)}')

    ## initialization
    out_dir = os.path.join(args.generated_dp_dir, hypersim_partition)
    Path(out_dir).mkdir(parents=True, exist_ok=True)
    # Weights are loaded onto the CPU first and the model is moved afterwards.
    state = torch.load(args.cp_dir, map_location='cpu')
    model = DPSimulator(k_size=5)
    model.load_state_dict(state)
    model = model.to(args.device)
    print('init done')

    ## test
    model.eval()
    with torch.no_grad():
        for data in tqdm(loader):
            sharp = data['sharp'].to(args.device)
            dep = data['dep'].to(args.device)
            coc = data['coc'].to(args.device)
            normalized_dep = norm_dep(dep)

            with torch.cuda.amp.autocast():
                outputs = model(sharp, normalized_dep, coc)
            # The model returns four values; only the first two are kept.
            pred_l, pred_r, _, _ = outputs
            pred_l = torch.clip(pred_l, 0, 1)
            pred_r = torch.clip(pred_r, 0, 1)

            ## save h5 file
            record = {'dp_l': pred_l.detach().cpu().numpy(), 'dp_r': pred_r.detach().cpu().numpy()}
            record.update({key: data[key].numpy() for key in passthrough_keys})
            # batch_size is 1, so the first name in the batch identifies the sample.
            out_path = os.path.join(out_dir, f"{data['curr_name'][0]}.h5")
            save_h5py_file(out_path, record)

    print('test finished')






In [None]:





if __name__ == '__main__':

    parser = argparse.ArgumentParser()
    parser.add_argument('--device', type=str, default='cuda:5', help='torch device, e.g. cuda:0 or cpu')
    parser.add_argument('--seed', type=int, default=0, help='random seed')
    parser.add_argument('--n_worker', type=int, default=8, help='number of workers')
    parser.add_argument('--cp_dir', type=str, default='./checkpoints/DP_simulator_flip_max_val_ssim.cp', help='path to the model checkpoint file')
    parser.add_argument('--data_dir', type=str, default='/dataset/workspace2022/li/selected_hypersim', help='data directory')
    parser.add_argument('--generated_dp_dir', type=str, default='/dataset/workspace2022/li/compact_generated_dp_from_rgbd_flip', help='output directory for the generated DP data')
    # An explicit empty argument list is parsed, so sys.argv is ignored and
    # every option takes its default; edit the defaults above to reconfigure.
    _args = parser.parse_args(args=[])

    # fix seed
    np.random.seed(_args.seed)
    torch.manual_seed(_args.seed)
    random.seed(_args.seed)
    if _args.device != 'cpu':
        # Disable cuDNN autotuning and request deterministic kernels.
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True

    # generate DP data for every partition
    for partition in ['partition_0', 'partition_1', 'partition_2', 'partition_3', 'partition_4']:
        generate_dp_from_rgbd(_args, partition)


