Test-time inference for DIBR_train_invC_ipynb


In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from unet_layers import *
import h5py
from torch.utils.data import DataLoader
from torchvision import transforms as T
from tensorboardX import SummaryWriter
import os
import numpy as np
import time
import matplotlib.pyplot as plt
from data_utils import FSdataset_h5,FSdataset_withSAI_h5,my_triplet_RandomCrop,my_paired_normalize,my_triplet_normalize, my_triplet_CenterCrop
from visualize import show_FS,show_EPI_xu,show_EPI_yv,show_SAI
from torch.optim.lr_scheduler import MultiStepLR
from DIBR_modules import depth_rendering_pt,transform_ray_depths_pt,depth_consistency_loss_pt,image_derivs_pt,tv_loss_pt
# Number CUDA devices by PCI bus ID and expose only devices 0 and 1.
os.environ.update({
    "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
    "CUDA_VISIBLE_DEVICES": "0,1",
})

# Seed the NumPy and PyTorch generators with the same fixed value.
_SEED = 100
np.random.seed(_SEED)
torch.manual_seed(_SEED)

# Ask cuDNN for deterministic kernels. Turning `benchmark` on may be faster,
# but can give non-deterministic results.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

In [2]:
# ---------------------------------------------------------------------------
# Run configuration
# ---------------------------------------------------------------------------
bs_train = 1
bs_val = 1
lr = 3e-4

# Input mode of the depth network:
#   SAI_only   - use only the SAI for depth estimation (no focal stack).
#   concat_SAI - concatenate the SAI with the FS along the colour channel
#                (only matters when SAI_only is False).
SAI_only = False
concat_SAI = True

nF = 7
lfsize = [185, 269, 7, 7]  # H, W, v, u
disp_mult = 1
# NOTE(review): an earlier comment here described Lytro light fields of
# 376 x 541 x 14 x 14 cropped to 372/540 spatial pixels and the central 8 x 8
# views, which does not match the `lfsize` above - confirm which one applies.

# Zero-based (v, u) index of the SAI selected as the all-in-focus image.
SAI_iv = 3
SAI_iu = 3

# Loss weights.
lam_tv = 0.01
lam_dc = 0.005  # 10x larger keeps the depth fields consistent; not tuned yet, maybe try 5x

# ---------------------------------------------------------------------------
# Datasets
# ---------------------------------------------------------------------------
FSdata_path = '/home/zyhuang/EVO970Plus/FS_dataset/FS_dmin_-1_dmax_0.3_nF_7_inverseCrime.h5'
LFdata_path = '/home/zyhuang/EVO970Plus/LF_dataset/FS_dmin_-1_dmax_0.3_nF_7_inverseCrime.h5'
_h5_paths = {'FSdata_path': FSdata_path, 'LFdata_path': LFdata_path}

# The FS is not normalised in MATLAB whereas the LF is already in [0, 1],
# hence the different normalisation constants.
if SAI_only:
    transform_train = T.Compose([my_paired_normalize({'FS': 9000, 'LF': 1})])
    transform_val = T.Compose([my_paired_normalize({'FS': 9000, 'LF': 1})])
    ds_train = FSdataset_h5(trainorval='train', transform=transform_train, **_h5_paths)
    ds_val = FSdataset_h5(trainorval='val', transform=transform_val, **_h5_paths)
else:
    SAIdata_folder = '/home/zyhuang/EVO970Plus/SAI_dataset/FS_dmin_-1_dmax_0.3_nF_7_unet_FS2SAI_v3_tanh_lr_3e-4_bs_train_2_bs_val_5_inverseCrime.h5'
    transform_train = T.Compose([my_triplet_normalize({'FS': 9000, 'LF': 1})])
    transform_val = T.Compose([my_triplet_normalize({'FS': 9000, 'LF': 1})])
    ds_train = FSdataset_withSAI_h5(
        SAIdata_folder=SAIdata_folder,
        trainorval='train',
        transform=transform_train,
        **_h5_paths,
    )
    ds_val = FSdataset_withSAI_h5(
        SAIdata_folder=SAIdata_folder,
        trainorval='val',
        transform=transform_val,
        **_h5_paths,
    )

# Folder holding the trained checkpoints. Alternative (SAI-only) run:
#model_folder = 'logs/Avoid_invcrime/Two stage model/DIBR/SAI_only_True_disp_mult_1_detach_ray_depths/lr_3e-4_lam_tv_1e-2_lam_dc_5e-3_bs_train_1_bs_val_1'
model_folder = 'logs/Avoid_invcrime/Two stage model/DIBR/FS_dmin_-1_dmax_0.3_nF_7_GenMat/concat_SAI_True_disp_mult_1_detach_ray_depths/lr_3e-4_lam_tv_1e-2_lam_dc_5e-3_bs_train_1_bs_val_1'

# ---------------------------------------------------------------------------
# Data loaders (only the training loader shuffles)
# ---------------------------------------------------------------------------
_loader_opts = {'num_workers': 3, 'pin_memory': True}
train_loader = DataLoader(ds_train, batch_size=bs_train, shuffle=True, **_loader_opts)
val_loader = DataLoader(ds_val, batch_size=bs_val, shuffle=False, **_loader_opts)

In [3]:
# Build both networks on the inference device and restore their trained weights.
device = torch.device("cpu")
criterion = nn.L1Loss()

# depth_network_pt / refineNet are not defined in this file; presumably they
# come from the wildcard `unet_layers` import - verify.
depth_Net = depth_network_pt(nF, lfsize, disp_mult, concat_SAI=concat_SAI, SAI_only=SAI_only)
refine_Net = refineNet()
depth_Net.to(device)
refine_Net.to(device)

# map_location remaps the stored tensors onto `device` while deserialising, so
# checkpoints written from a GPU run also load on a machine without CUDA.
depth_Net.load_state_dict(
    torch.load(model_folder + '/model_depth_Net.pth', map_location=device)
)
refine_Net.load_state_dict(
    torch.load(model_folder + '/model_refine_Net.pth', map_location=device)
)
def my_psnr(I,Iref,peakval):
    """Return the peak signal-to-noise ratio (in dB) of ``I`` against ``Iref``.

    Args:
        I: tensor to evaluate.
        Iref: reference tensor, subtracted element-wise from ``I``.
        peakval: peak signal value; it is squared in the numerator.

    Returns:
        ``10 * log10(peakval**2 / MSE)``, where the MSE is the mean over all
        elements of the squared difference. Identical inputs give ``inf``.
    """
    residual = I - Iref
    mean_sq_err = (residual ** 2).mean()
    ratio = peakval ** 2 / mean_sq_err
    return 10 * torch.log10(ratio)

In [None]:


# Evaluate the trained depth / refinement networks over the entire validation
# set: accumulate the losses and PSNR, and keep every estimate for export.
Full_loss = 0
Full_output_loss = 0
Full_PSNR = 0
PSNR_all = []
batch_times = []  # wall-clock seconds spent in the forward pass of each batch

# Result buffers, sized from the configuration (lfsize is H, W, v, u) instead
# of hard-coded numbers so they follow any change of `lfsize` / `nF`.
lf_H, lf_W, lf_nv, lf_nu = lfsize
n_val = len(ds_val)
reconLF_all = np.zeros([n_val, 3, lf_nv, lf_nu, lf_H, lf_W])
trueLF_all = np.zeros([n_val, 3, lf_nv, lf_nu, lf_H, lf_W])
# assumes each focal stack is (3, nF, H, W) - TODO confirm against the dataset.
# Stays all-zero when SAI_only is True, because no focal stack is loaded then.
FS_all = np.zeros([n_val, 3, nF, lf_H, lf_W])
depth_all = np.zeros([n_val, lf_nv, lf_nu, lf_H, lf_W])
torch.cuda.empty_cache()

# Inference mode only needs to be selected once, not on every iteration.
depth_Net.eval()
refine_Net.eval()

for idx, data in enumerate(val_loader, 0):
    if SAI_only:
        # Here est_SAI is the true SAI: in SAI_only mode the camera captures it.
        FS = None
        LF = data['LF'].to(device)
        est_SAI = data['LF'][:, :, SAI_iv, SAI_iu, :, :].to(device)
    else:
        # To see how much a perfectly estimated SAI would help, substitute
        # data['LF'][:,:,SAI_iv,SAI_iu,:,:] for data['SAI'] below.
        FS = data['FS'].to(device)
        LF = data['LF'].to(device)
        est_SAI = data['SAI'].to(device)

    with torch.no_grad():
        t_start = time.perf_counter()
        if SAI_only:
            ray_depths = depth_Net(est_SAI)  # B,v,u,H,W
        else:
            ray_depths = depth_Net(FS, est_SAI)  # B,v,u,H,W
        ray_depths_hw = ray_depths.permute(0, 3, 4, 1, 2)  # B,H,W,v,u

        # Render every colour channel of the SAI separately, then stack the
        # channels on a trailing axis and move them to the front.
        rendered = [
            depth_rendering_pt(est_SAI[:, ch, :, :], ray_depths_hw, lfsize, SAI_iu, SAI_iv)  # B,H,W,v,u
            for ch in range(3)
        ]
        lf_shear = torch.stack(rendered, dim=5).permute(0, 5, 3, 4, 1, 2)  # B,C,v,u,H,W
        lf_denoised = refine_Net(lf_shear, ray_depths)  # B,C,v,u,H,W
        batch_times.append(time.perf_counter() - t_start)

        shear_loss = criterion(lf_shear, LF)
        tv_loss = lam_tv * tv_loss_pt(ray_depths_hw)
        depth_consistency_loss = lam_dc * depth_consistency_loss_pt(ray_depths_hw, lfsize, SAI_iu, SAI_iv)
        output_loss = criterion(lf_denoised, LF)
        loss = shear_loss + output_loss + tv_loss + depth_consistency_loss
        PSNR = my_psnr(lf_denoised, LF, 1)

        # The loader does not shuffle, so batch `idx` covers consecutive
        # samples starting at idx * batch_size (the last batch may be shorter).
        lo = idx * val_loader.batch_size
        hi = lo + LF.shape[0]
        depth_all[lo:hi] = ray_depths.cpu().numpy()
        reconLF_all[lo:hi] = lf_denoised.cpu().numpy()
        trueLF_all[lo:hi] = LF.cpu().numpy()
        if FS is not None:
            # No focal stack exists in SAI_only mode; writing it
            # unconditionally raised a NameError there.
            FS_all[lo:hi] = FS.cpu().numpy()
    Full_loss += loss.item()
    Full_output_loss += output_loss.item()
    Full_PSNR += PSNR.item()
    PSNR_all.append(PSNR.item())
    print('Minibatch val_loss at the end is:%.4f' %(loss.item()))


# Averages over batches; this assumes each batch has the same size.
Full_loss = Full_loss/len(val_loader)
Full_output_loss = Full_output_loss/len(val_loader)
Full_PSNR = Full_PSNR/len(val_loader)
# A later cell displays `Avg_time`, but nothing in this file defines it.
# NOTE(review): defined here as the mean forward-pass time per batch (depth
# network + rendering + refinement) - confirm this matches the original timing.
Avg_time = sum(batch_times) / len(batch_times) if batch_times else float('nan')
print('Full val_loss at the end is:%.4f' %(Full_loss))
print('Full_PSNR at the end is:%.4f' %(Full_PSNR))



In [45]:
import scipy.io as sio

# Export only the first `save_howmany` validation samples to avoid writing a
# very large .mat file.
save_howmany = 30
result_mat_name = 'Avoid_invcrime_Two stage model_DIBR_FS_dmin_-1_dmax_0.3_nF_7_GenMat_concat_SAI_True_disp_mult_1_detach_ray_depths_lr_3e-4_lam_tv_1e-2_lam_dc_5e-3_bs_train_1_bs_val_1_100testsample(30saved)_result.mat'

# Focal stacks, estimated ray depths, ground-truth and reconstructed light
# fields, each truncated to the first `save_howmany` entries.
saved_arrays = {
    'FS_all': FS_all[:save_howmany],
    'depth': depth_all[:save_howmany],
    'trueLF_all': trueLF_all[:save_howmany],
    'reconLF_all': reconLF_all[:save_howmany],
}
sio.savemat(result_mat_name, saved_arrays, do_compression=True)

In [13]:
Avg_time

3.960785622596741