In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from unet_layers import *
import h5py
from torch.utils.data import DataLoader
from torchvision import transforms as T
from tensorboardX import SummaryWriter
import os
import numpy as np
import time
import matplotlib.pyplot as plt
from data_utils import FSdataset_withSAI_h5,my_triplet_RandomCrop,my_triplet_normalize, my_triplet_CenterCrop
from visualize import show_FS,show_EPI_xu,show_EPI_yv,show_SAI
from torch.optim.lr_scheduler import MultiStepLR
from DIBR_modules import depth_rendering_pt,transform_ray_depths_pt,depth_consistency_loss_pt,image_derivs_pt,tv_loss_pt
# Enumerate GPUs in PCI bus order and expose only physical GPU 1 to this
# process. CUDA reads these variables when it is first initialised, so they
# must be set before any CUDA call is made.
os.environ.update({"CUDA_DEVICE_ORDER": "PCI_BUS_ID", "CUDA_VISIBLE_DEVICES": "1"})

# Seed the global numpy and torch RNGs so that weight initialisation and
# DataLoader shuffling repeat from run to run.
SEED = 100
np.random.seed(SEED)
torch.manual_seed(SEED)

# Prefer reproducible cuDNN kernels over autotuned (possibly faster,
# non-deterministic) ones.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [None]:
# ---- Optimisation hyper-parameters ----
bs_train = 1      # training batch size
bs_val = 1        # validation batch size
lr = 3e-4         # Adam learning rate

# ---- Light-field / model geometry ----
nF = 7                       # first argument of depth_network_pt; presumably the focal-stack depth (cf. "nF_7" in the data paths) — confirm
lfsize = [185, 269, 7, 7]    # H, W, v, u
# NOTE(review): earlier notes state the Lytro LF is 376 x 541 x 14 x 14 and that
# the paper keeps the first 372/540 spatial pixels and the central 8x8 SAIs,
# "which is being followed here" — yet lfsize above is 185 x 269 with 7x7 views.
# Confirm which cropping actually applies to these datasets.
disp_mult = 1                # forwarded to depth_network_pt
SAI_iv = 3                   # v index (counting from 0) of the SAI used as the all-in-focus image
SAI_iu = 3                   # u index (counting from 0) of the SAI used as the all-in-focus image

# ---- Loss weights ----
lam_tv = 0.01     # weight of the TV regulariser on the ray depths
lam_dc = 0.005    # weight of the depth-consistency loss; author's note: 10x larger keeps the
                  # depth fields consistent but is not optimal yet (open question: try 5x?)

# ---- Data ----
# Train and val read the same HDF5 files; the split is chosen by `trainorval`.
transform_train = T.Compose([my_triplet_normalize({'FS': 9000, 'LF': 1})])
transform_val = T.Compose([my_triplet_normalize({'FS': 9000, 'LF': 1})])

fs_h5_path = '/home/zyhuang/EVO970Plus/FS_dataset/FS_dmin_-1_dmax_0.3_nF_7_inverseCrime.h5'
lf_h5_path = '/home/zyhuang/EVO970Plus/LF_dataset/FS_dmin_-1_dmax_0.3_nF_7_inverseCrime.h5'
sai_h5_path = '/home/zyhuang/EVO970Plus/SAI_dataset/FS_dmin_-1_dmax_0.3_nF_7_unet_FS2SAI_v3_tanh_lr_3e-4_bs_train_2_bs_val_5_inverseCrime.h5'
h5_paths = dict(FSdata_path=fs_h5_path, LFdata_path=lf_h5_path, SAIdata_folder=sai_h5_path)

ds_train = FSdataset_withSAI_h5(**h5_paths, trainorval='train', transform=transform_train)
ds_val = FSdataset_withSAI_h5(**h5_paths, trainorval='val', transform=transform_val)

# NOTE(review): the run directory spells out lr / lam_tv / lam_dc / batch sizes by
# hand — keep it in sync when changing the values above.
log_path = 'logs/Avoid_invcrime/Two stage model/DIBR/FS_dmin_-1_dmax_0.3_nF_7/concat_SAI_True_disp_mult_1_detach_ray_depths/lr_3e-4_lam_tv_1e-2_lam_dc_5e-3_bs_train_1_bs_val_1'
writer = SummaryWriter(log_path)

loader_kwargs = dict(num_workers=3, pin_memory=True)
train_loader = DataLoader(ds_train, batch_size=bs_train, shuffle=True, **loader_kwargs)
val_loader = DataLoader(ds_val, batch_size=bs_val, shuffle=False, **loader_kwargs)

In [None]:

# Build both networks directly on the GPU and one Adam optimiser over their
# combined parameters, so depth_Net and refine_Net are trained jointly.
device = torch.device("cuda")
depth_Net = depth_network_pt(nF, lfsize, disp_mult, concat_SAI=True).to(device)
refine_Net = refineNet().to(device)

# L1 (mean absolute error) loss, used for both LF reconstruction terms.
criterion = nn.L1Loss()

params = [*depth_Net.parameters(), *refine_Net.parameters()]
optimizer = optim.Adam(params, lr=lr)

def my_psnr(I, Iref, peakval):
    """Peak signal-to-noise ratio, in dB, between two tensors.

    Args:
        I: estimated tensor.
        Iref: reference tensor (same or broadcastable shape as ``I``).
        peakval: maximum possible signal value, e.g. 1 for data in [0, 1].

    Returns:
        0-dim tensor ``10 * log10(peakval**2 / MSE)``, where the MSE is averaged
        over every element (all batch items are pooled). Identical inputs give
        ``inf``.
    """
    squared_error = (I - Iref) ** 2
    mse = torch.mean(squared_error)
    ratio = peakval ** 2 / mse
    return 10 * torch.log10(ratio)

In [None]:
# Joint training of depth_Net and refine_Net.
#
# Each epoch:
#   1. one pass over train_loader (forward, loss, Adam step). Scalars are logged
#      to TensorBoard every `log_every` steps and image grids every `vis_every` steps;
#   2. a no-grad pass over val_loader. The batch-size-weighted mean loss, output
#      loss and PSNR are printed and logged against the epoch index;
#   3. both networks' state_dicts are saved under log_path, overwriting the
#      previous epoch's checkpoint.
# Uses globals from the earlier cells: depth_Net, refine_Net, optimizer, criterion,
# train_loader, val_loader, writer, device, lfsize, SAI_iu, SAI_iv, lam_tv,
# lam_dc, log_path and my_psnr.
num_epochs = 50
log_every = 10     # steps between scalar logs
vis_every = 400    # steps between image logs


def _forward_batch(FS, est_SAI):
    """Run depth estimation, DIBR rendering and refinement on one batch.

    Returns (ray_depths, lf_shear, lf_denoised) with layouts
    B,v,u,H,W / B,C,v,u,H,W / B,C,v,u,H,W respectively.
    refine_Net receives detached depths, so the refinement loss does not
    back-propagate into depth_Net through that input. Under torch.no_grad the
    detach is a no-op.
    """
    ray_depths = depth_Net(FS, est_SAI)                 # B,v,u,H,W
    depths_hwvu = ray_depths.permute(0, 3, 4, 1, 2)     # B,H,W,v,u, layout used by the DIBR helpers
    # Render the 3 colour channels independently, each B,H,W,v,u, and stack them on a new last axis.
    per_channel = [
        depth_rendering_pt(est_SAI[:, c, :, :], depths_hwvu, lfsize, SAI_iu, SAI_iv)
        for c in range(3)
    ]
    lf_shear = torch.stack(per_channel, dim=5).permute(0, 5, 3, 4, 1, 2)  # B,C,v,u,H,W
    lf_denoised = refine_Net(lf_shear, ray_depths.detach())               # B,C,v,u,H,W
    return ray_depths, lf_shear, lf_denoised


def _compute_losses(ray_depths, lf_shear, lf_denoised, LF):
    """Return (loss, shear_loss, output_loss, tv_loss, depth_consistency_loss).

    loss = shear_loss + output_loss + tv_loss + depth_consistency_loss, where
    the TV and depth-consistency terms are already scaled by lam_tv / lam_dc.
    """
    depths_hwvu = ray_depths.permute(0, 3, 4, 1, 2)
    shear_loss = criterion(lf_shear, LF)
    tv_loss = lam_tv * tv_loss_pt(depths_hwvu)
    depth_consistency_loss = lam_dc * depth_consistency_loss_pt(depths_hwvu, lfsize, SAI_iu, SAI_iv)
    output_loss = criterion(lf_denoised, LF)
    loss = shear_loss + output_loss + tv_loss + depth_consistency_loss
    return loss, shear_loss, output_loss, tv_loss, depth_consistency_loss


def _log_train_images(FS, est_SAI, LF, ray_depths, lf_shear, lf_denoised, step):
    """Write the periodic training visualisations to TensorBoard.

    Every tensor is detached and copied to the CPU first, so the plotting
    helpers and the writer never receive CUDA tensors or tensors that
    require grad.
    """
    shear_cpu = lf_shear.detach().cpu()
    denoised_cpu = lf_denoised.detach().cpu()
    LF_cpu = LF.detach().cpu()

    est_SAI_grid = show_SAI(est_SAI.detach().cpu(), isshow=False)
    FS_grid = show_FS(FS.detach().cpu(), isshow=False)
    LambertLF_SAI_grid = show_SAI(shear_cpu, [(0, 0), (3, 3)], isshow=False)
    LambertLF_EPI_xu_grid = show_EPI_xu(shear_cpu, [(100, 0), (150, 3)], isshow=False)
    reconLF_SAI_grid = show_SAI(denoised_cpu, [(0, 0), (3, 3)], isshow=False)
    reconLF_EPI_xu_grid = show_EPI_xu(denoised_cpu, [(100, 0), (150, 3)], isshow=False)
    SAI_grid = show_SAI(LF_cpu, [(0, 0), (3, 3)], isshow=False)
    EPI_xu_grid = show_EPI_xu(LF_cpu, [(100, 0), (150, 3)], isshow=False)

    # Tile the depths into a (B*v*H, u*W) mosaic: rows indexed by (b, v, y), columns by (u, x).
    # NOTE(review): this passes a 2-D image with add_image's default dataformats; newer
    # tensorboardX versions may require dataformats='HW' — confirm against the installed version.
    depth_img = ray_depths.detach().cpu()
    B, v, u, H, W = depth_img.shape
    writer.add_image('Training ray_depths', depth_img.permute(0, 1, 3, 2, 4).reshape(B * v * H, u * W), step)
    writer.add_image('Training est SAI', est_SAI_grid, step)
    writer.add_image('Training FS', FS_grid, step)
    writer.add_image('Training LambertLF SAI [0,0], [3,3]', LambertLF_SAI_grid, step)
    writer.add_image('Training LambertLF EPI [100,0], [150,3]', LambertLF_EPI_xu_grid, step)
    writer.add_image('Training reconLF SAI [0,0], [3,3]', reconLF_SAI_grid, step)
    writer.add_image('Training reconLF EPI [100,0], [150,3]', reconLF_EPI_xu_grid, step)
    writer.add_image('Training trueLF SAI [0,0], [3,3]', SAI_grid, step)
    writer.add_image('Training trueLF EPI [100,0], [150,3]', EPI_xu_grid, step)


step = 0
for epoch in range(num_epochs):
    print("Current epoch number%d" % epoch)

    # ---- training pass ----
    # The validation pass below switches both nets to eval mode, so switch back each epoch.
    depth_Net.train()
    refine_Net.train()
    for data in train_loader:
        FS, LF, est_SAI = data['FS'].to(device), data['LF'].to(device), data['SAI'].to(device)
        ray_depths, lf_shear, lf_denoised = _forward_batch(FS, est_SAI)
        loss, shear_loss, output_loss, tv_loss, depth_consistency_loss = _compute_losses(
            ray_depths, lf_shear, lf_denoised, LF)

        print('Train Loss is %.3f' % loss.item())
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if step % log_every == 0:
            writer.add_scalar('shear_loss', shear_loss.item(), step)
            writer.add_scalar('tv_loss', tv_loss.item(), step)
            writer.add_scalar('depth_consistency_loss', depth_consistency_loss.item(), step)
            writer.add_scalar('output_loss', output_loss.item(), step)
            writer.add_scalar('loss', loss.item(), step)
        if step % vis_every == 0:
            print('Train visualization')
            _log_train_images(FS, est_SAI, LF, ray_depths, lf_shear, lf_denoised, step)
        step += 1

    # ---- validation pass over the entire val set ----
    # Per-batch means are weighted by batch size, so the epoch averages stay exact
    # even when the last val batch is smaller than bs_val.
    val_loss_sum = 0.0
    val_output_loss_sum = 0.0
    val_psnr_sum = 0.0
    n_val_samples = 0
    torch.cuda.empty_cache()
    depth_Net.eval()
    refine_Net.eval()
    with torch.no_grad():
        for data in val_loader:
            FS, LF, est_SAI = data['FS'].to(device), data['LF'].to(device), data['SAI'].to(device)
            ray_depths, lf_shear, lf_denoised = _forward_batch(FS, est_SAI)
            loss, shear_loss, output_loss, tv_loss, depth_consistency_loss = _compute_losses(
                ray_depths, lf_shear, lf_denoised, LF)
            # Peak value 1: assumes LF lies in [0, 1] after normalisation — confirm.
            PSNR = my_psnr(lf_denoised, LF, 1)

            batch_size = LF.shape[0]
            val_loss_sum += loss.item() * batch_size
            val_output_loss_sum += output_loss.item() * batch_size
            val_psnr_sum += PSNR.item() * batch_size
            n_val_samples += batch_size
            print('Minibatch val_loss at the end of epoch %d is:%.4f' % (epoch, loss.item()))

    Full_loss = val_loss_sum / n_val_samples
    Full_output_loss = val_output_loss_sum / n_val_samples
    # Batch-size-weighted mean of per-batch PSNRs; each batch PSNR uses that batch's pooled MSE.
    Full_PSNR = val_psnr_sum / n_val_samples
    print('Full val_loss at the end of epoch %d is:%.4f' % (epoch, Full_loss))
    print('Full_PSNR at the end of epoch %d is:%.4f' % (epoch, Full_PSNR))
    writer.add_scalar('Full val loss', Full_loss, epoch)
    writer.add_scalar('Full Val LF output loss', Full_output_loss, epoch)
    writer.add_scalar('Full Val PSNR', Full_PSNR, epoch)
    torch.save(refine_Net.state_dict(), os.path.join(log_path, 'model_refine_Net.pth'))
    torch.save(depth_Net.state_dict(), os.path.join(log_path, 'model_depth_Net.pth'))
