# Train a U-Net for direct light-field (LF) reconstruction from a focal stack (FS), accounting for the inverse crime.
Use train.ipynb instead for the version that does not account for the inverse crime.

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from unet_layers import *
import h5py
from torch.utils.data import DataLoader
from torchvision import transforms as T
from tensorboardX import SummaryWriter
import os
import numpy as np
import time
import matplotlib.pyplot as plt
from data_utils import FSdataset_h5,my_paired_RandomCrop,my_paired_normalize, my_paired_gamma_correction
from visualize import show_FS,show_EPI_xu,show_EPI_yv,show_SAI
from torch.optim.lr_scheduler import MultiStepLR
# Pin the process to a single GPU, enumerating devices in PCI bus order.
os.environ.update({
    "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
    "CUDA_VISIBLE_DEVICES": "1",  # only GPU 1 is visible to this process
})

# Seed NumPy and PyTorch identically for reproducible runs.
for _seed_fn in (np.random.seed, torch.manual_seed):
    _seed_fn(100)

# Deterministic cuDNN kernels; enabling benchmark mode may be faster but
# makes results non-deterministic.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [None]:
# ---- Hyperparameters ----
bs_train = 2
bs_val = 5
lr_tag = '5e-4'     # learning rate exactly as it appears in the log directory name
lr = float(lr_tag)  # == 5e-4; deriving it keeps lr and log_path from drifting apart

nF = 7  # passed to the network as nF; also encoded in the dataset file names below
lfsize = [185, 269, 7, 7]  # H, W, nv, nu
# NOTE(review): raw Lytro light fields are described as 376 x 541 x 14 x 14, and the
# referenced paper keeps the first 372 x 540 pixels and the central 8 x 8 views.
# The sizes used here (185 x 269 spatial, 7 x 7 views) do not match that crop --
# confirm how the dataset files were generated.

# FS is not normalized in MATLAB, while LF is already normalized to [0, 1];
# these per-key values are handed to my_paired_normalize.
transform_train = T.Compose([my_paired_normalize({'FS':9000,'LF':1})])
transform_val = T.Compose([my_paired_normalize({'FS':9000,'LF':1})])

# ---- Data ----
# The dataset name encodes nF, so build it from the variable instead of repeating it.
data_tag = 'FS_dmin_-1_dmax_0.3_nF_%d' % nF
FS_path = '/home/zyhuang/EVO970Plus/FS_dataset/%s_inverseCrime.h5' % data_tag
LF_path = '/home/zyhuang/EVO970Plus/LF_dataset/%s_inverseCrime.h5' % data_tag

ds_train = FSdataset_h5(FSdata_path=FS_path, LFdata_path=LF_path,
                        trainorval='train', transform=transform_train)
ds_val = FSdataset_h5(FSdata_path=FS_path, LFdata_path=LF_path,
                      trainorval='val', transform=transform_val)

# ---- Logging ----
# The run directory is derived from the actual hyperparameters so it cannot go stale.
log_path = 'logs/Avoid_invcrime/%s/unet_FS2LF_v3_tanh/lr_%s_bs_train_%d_bs_val_%d' % (
    data_tag, lr_tag, bs_train, bs_val)
writer = SummaryWriter(log_path)

train_loader = DataLoader(ds_train, batch_size=bs_train, shuffle=True, num_workers=10, pin_memory=True)
val_loader = DataLoader(ds_val, batch_size=bs_val, shuffle=False, num_workers=5, pin_memory=True)

In [None]:

device = torch.device("cuda")

# FS -> LF U-Net. lfsize is laid out as [H, W, nv, nu], hence nu=lfsize[3], nv=lfsize[2].
# Module.to() returns the module itself, so construction and placement can be chained.
net = unet_FS2LF_v3(nF=nF, nu=lfsize[3], nv=lfsize[2], box_constraint='tanh').to(device)

criterion = nn.L1Loss()  # alternative that was tried: nn.MSELoss()
optimizer = optim.Adam(net.parameters(), lr=lr)
# Optional step-decay schedule (currently disabled):
# scheduler = MultiStepLR(optimizer, milestones=[3, 6, 10, 20], gamma=0.5)

def my_psnr(I, Iref, peakval):
    """Return the peak signal-to-noise ratio of ``I`` against ``Iref`` in dB.

    The squared error is averaged over *all* elements, so a batched input is
    scored as one single image, not as the mean of per-sample PSNRs.

    Args:
        I: reconstructed tensor.
        Iref: reference tensor, broadcast-compatible with ``I``.
        peakval: maximum possible signal value (e.g. 1 for data in [0, 1]).

    Returns:
        0-dim tensor ``10 * log10(peakval**2 / MSE)``.
    """
    diff = I - Iref
    mean_sq_err = torch.mean(diff * diff)
    ratio = peakval ** 2 / mean_sq_err
    return 10 * torch.log10(ratio)

In [None]:
def _log_visualization(prefix, reconLF, LF, FS, global_step):
    """Write SAI / EPI / focal-stack image grids of one batch to TensorBoard.

    Args:
        prefix: tag prefix ('Training' or 'Val') prepended to every image tag.
        reconLF: network output for the batch.
        LF: ground-truth light field for the batch.
        FS: input focal stack for the batch.
        global_step: TensorBoard x-axis value (training step or epoch).

    All tensors are detached and moved to the CPU before plotting, so callers
    may pass GPU tensors that are still attached to the autograd graph.
    """
    reconLF, LF, FS = reconLF.detach().cpu(), LF.detach().cpu(), FS.detach().cpu()
    reconSAI_grid = show_SAI(reconLF, [(0, 0), (3, 3)], isshow=False)
    reconEPI_xu_grid = show_EPI_xu(reconLF, [(100, 0), (150, 3)], isshow=False)
    SAI_grid = show_SAI(LF, [(0, 0), (3, 3)], isshow=False)
    EPI_xu_grid = show_EPI_xu(LF, [(100, 0), (150, 3)], isshow=False)
    FS_grid = show_FS(FS, isshow=False)
    writer.add_image('%s reconLF SAI [0,0], [3,3]' % prefix, reconSAI_grid, global_step)
    writer.add_image('%s reconLF EPI [100,0], [150,3]' % prefix, reconEPI_xu_grid, global_step)
    writer.add_image('%s trueLF SAI [0,0], [3,3]' % prefix, SAI_grid, global_step)
    writer.add_image('%s trueLF EPI [100,0], [150,3]' % prefix, EPI_xu_grid, global_step)
    writer.add_image('%s FS' % prefix, FS_grid, global_step)


# Train for 100 epochs; after each epoch evaluate on the full validation set,
# log to TensorBoard and checkpoint the weights into log_path.
step = 0
for epoch in range(100):
    # scheduler.step()  # re-enable together with the MultiStepLR scheduler
    print("Current epoch number%d" % epoch)

    # ---- Training ----
    net.train()  # the validation pass below switches to eval mode, so reset every epoch
    for idx, data in enumerate(train_loader, 0):
        # 2019 4 16: check loaded FS, LF dimension and make sure they are compatible with the network structure
        FS, LF = data['FS'].to(device), data['LF'].to(device)
        reconLF = net(FS)
        loss = criterion(reconLF, LF)
        print('Train Loss is %.3f' % (loss.item()))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if step % 10 == 0:
            writer.add_scalar('output_loss', loss.item(), step)
        if step % 1000 == 0:
            print('Train visualization')
            _log_visualization('Training', reconLF, LF, FS, step)
        step = step + 1

    # ---- Validation over the entire val set ----
    net.eval()
    Full_output_loss = 0
    Full_PSNR = 0
    n_val = 0  # number of validation samples seen; the last batch may be smaller
    with torch.no_grad():
        for idx, data in enumerate(val_loader, 0):
            FS, LF = data['FS'].to(device), data['LF'].to(device)
            batch_size = FS.size(0)
            reconLF = net(FS)
            loss = criterion(reconLF, LF)
            # Minibatch PSNR treats the whole batch as one image; it is only printed.
            PSNR = my_psnr(reconLF, LF, 1)
            # Per-sample PSNR (peak value 1) so the epoch average is a true mean over samples.
            per_sample_mse = ((reconLF - LF) ** 2).reshape(batch_size, -1).mean(dim=1)
            per_sample_psnr = 10 * torch.log10(1 / per_sample_mse)

            # Visualize the first validation batch.
            if idx == 0:
                print('Validation visualization')
                _log_visualization('Val', reconLF, LF, FS, epoch)
            print('Minibatch val_loss at the end of epoch %d is:%.4f' % (epoch, loss.item()))
            print('Minibatch PSNR at the end of epoch %d is:%.4f' % (epoch, PSNR.item()))
            # The criterion returns a per-batch mean, so weight it by the batch size
            # to get an exact dataset mean even when the last batch is smaller.
            Full_output_loss += loss.item() * batch_size
            Full_PSNR += per_sample_psnr.sum().item()
            n_val += batch_size

    Full_output_loss = Full_output_loss / n_val
    Full_PSNR = Full_PSNR / n_val
    print('Full_output_loss at the end of epoch %d is:%.4f' % (epoch, Full_output_loss))
    print('Full_PSNR at the end of epoch %d is:%.4f' % (epoch, Full_PSNR))
    writer.add_scalar('Full Val LF output loss', Full_output_loss, epoch)
    writer.add_scalar('Full Val PSNR', Full_PSNR, epoch)
    torch.save(net.state_dict(), os.path.join(log_path, 'model.pth'))
        
        
