In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from unet_layers import *
import h5py
from torch.utils.data import DataLoader
from torchvision import transforms as T
from tensorboardX import SummaryWriter
import os
import numpy as np
import time
import matplotlib.pyplot as plt
from data_utils import FSdataset_h5,my_paired_RandomCrop,my_paired_normalize, my_paired_gamma_correction
from visualize import show_FS,show_EPI_xu,show_EPI_yv,show_SAI
from torch.optim.lr_scheduler import MultiStepLR
# Number GPUs in PCI bus order and expose only GPU 1 to this process.
# These must be set before the first CUDA call (PyTorch initialises CUDA lazily, so setting
# them after `import torch` still works as long as no CUDA work has happened yet).
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ['CUDA_VISIBLE_DEVICES'] = "1"

# Seed every RNG the run may touch (Python, NumPy, PyTorch) so results are reproducible.
import random  # local import keeps this configuration cell self-contained

SEED = 100
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)  # seeds the CPU and all CUDA device generators
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False  # True may be faster but gives non-deterministic results

In [None]:
# ## Configuration: hyper-parameters, data and logging
bs_train = 2
bs_val = 5
lr = 3e-4

nF = 7  # passed to the network as nF; matches the "nF_7" tag in the dataset file name
lfsize = [185, 269, 7, 7]  # H, W, v, u
SAI_iv = 3  # v index (counting from 0) of the SAI selected as the all-in-focus image
SAI_iu = 3  # u index (counting from 0) of the SAI selected as the all-in-focus image
# Dimensions of the Lytro light fields: H, W, nv, nu.
# The original Lytro LF is 376 x 541 x 14 x 14; the paper keeps only the first 372/540 spatial
# pixels and the central 8 x 8 SAIs, which is followed here. The LF is further cropped to 7 x 7
# in the angular dimensions.
# NOTE(review): lfsize above is 185 x 269 spatially, not 372 x 540 — confirm how the two relate.

# FS is not normalized in MATLAB, while LF is already normalized to [0, 1].
transform_train = T.Compose([my_paired_normalize({'FS': 9000, 'LF': 1})])
transform_val = T.Compose([my_paired_normalize({'FS': 9000, 'LF': 1})])

# The FS and LF files share one file name under sibling directories of a common root.
# Set the FS_DATA_ROOT environment variable to use a different root (e.g. on another machine);
# the default reproduces the original hard-coded paths exactly.
DATA_ROOT = os.environ.get('FS_DATA_ROOT', '/home/zyhuang/EVO970Plus')
DATASET_FILE = 'FS_dmin_-1_dmax_0.3_nF_7_inverseCrime.h5'
FS_DATA_PATH = os.path.join(DATA_ROOT, 'FS_dataset', DATASET_FILE)
LF_DATA_PATH = os.path.join(DATA_ROOT, 'LF_dataset', DATASET_FILE)

ds_train = FSdataset_h5(FSdata_path=FS_DATA_PATH, LFdata_path=LF_DATA_PATH,
                        trainorval='train', transform=transform_train)
ds_val = FSdataset_h5(FSdata_path=FS_DATA_PATH, LFdata_path=LF_DATA_PATH,
                      trainorval='val', transform=transform_val)

# log_path = 'logs/trial'  # quick-experiment alternative
# NOTE: the run name below hand-encodes lr, the scheduler's gamma/milestones and the batch sizes;
# keep it in sync with the values above and with the scheduler in the model cell.
log_path = 'logs/Avoid_invcrime/Two stage model/ViewNet/FS_dmin_-1_dmax_0.3_nF_7/unet_FS2SAI_v3_tanh/lr_3e-4_gamma_0p5_3_6_10_20_bs_train_2_bs_val_5'
writer = SummaryWriter(log_path)

train_loader = DataLoader(ds_train, batch_size=bs_train, shuffle=True, num_workers=10, pin_memory=True)
val_loader = DataLoader(ds_val, batch_size=bs_val, shuffle=False, num_workers=3, pin_memory=True)

In [None]:
# ## Model, loss, optimizer and learning-rate schedule
device = torch.device("cuda")  # the single GPU left visible by CUDA_VISIBLE_DEVICES

# L1 reconstruction loss between the predicted and the ground-truth SAI
# (nn.MSELoss() was the alternative that was tried).
criterion = nn.L1Loss()

# Network mapping a focal stack to one sub-aperture image; the 'tanh' box constraint
# presumably bounds the output range — see unet_layers to confirm.
net = unet_FS2SAI_v3(nF=nF, nu=lfsize[3], nv=lfsize[2], box_constraint='tanh').to(device)

optimizer = optim.Adam(params=net.parameters(), lr=lr)
# Multiply the LR by 0.5 at each milestone.
scheduler = MultiStepLR(optimizer, milestones=[3, 6, 10, 20], gamma=0.5)

In [None]:
# ## Training loop
# Each epoch runs three phases:
#   1. One optimisation pass over train_loader. The loss scalar is logged every 10 steps and
#      image grids every 400 steps.
#   2. One full pass over val_loader, computing the mean loss over the whole val set.
#   3. A checkpoint written to <log_path>/model.pth, overwritten every epoch.
step = 0
for epoch in range(100):
    print("Current epoch number %d" % epoch)

    # ---- training ----
    net.train()  # switch mode once per phase instead of once per batch
    for data in train_loader:
        FS, LF = data['FS'].to(device), data['LF'].to(device)
        SAI = LF[:, :, SAI_iv, SAI_iu, :, :]  # B,C,H,W
        reconSAI = net(FS)
        loss = criterion(reconSAI, SAI)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_loss = loss.item()  # one host sync per step, reused for printing and logging
        print('Train Loss is %.4f' % train_loss)
        if step % 10 == 0:
            writer.add_scalar('loss', train_loss, step)
        if step % 400 == 0:
            print('Train visualization')
            reconSAI_grid = show_SAI(reconSAI.detach().cpu(), isshow = False)
            SAI_grid = show_SAI(SAI.detach().cpu(), isshow = False)
            # Move FS to CPU like the other grids (and like the validation branch below).
            FS_grid = show_FS(FS.detach().cpu(), isshow = False)
            writer.add_image('Training reconLF SAI', reconSAI_grid, step)
            writer.add_image('Training trueLF SAI', SAI_grid, step)
            writer.add_image('Training FS', FS_grid, step)
        step += 1

    # ---- validation: mean loss over the entire val dataset ----
    net.eval()
    Full_loss = 0.0
    n_val_samples = 0
    for idx, data in enumerate(val_loader, 0):
        FS, LF = data['FS'].to(device), data['LF'].to(device)
        SAI = LF[:, :, SAI_iv, SAI_iu, :, :]  # B,C,H,W
        with torch.no_grad():
            reconSAI = net(FS)
            loss = criterion(reconSAI, SAI)
        val_loss = loss.item()
        # criterion (nn.L1Loss, reduction='mean') returns a per-batch mean. Weight it by the batch
        # size so a short final batch does not skew the dataset mean. This assumes every sample
        # has the same number of elements.
        n_batch = SAI.size(0)
        Full_loss += val_loss * n_batch
        n_val_samples += n_batch
        print('Minibatch val_loss at the end of epoch %d is:%.4f' %(epoch, val_loss))

        if idx == 0:
            print('Validation visualization for first batch')
            reconSAI_grid = show_SAI(reconSAI.detach().cpu(), isshow = False)
            SAI_grid = show_SAI(SAI.detach().cpu(), isshow = False)
            FS_grid = show_FS(FS.cpu(), isshow = False)
            writer.add_image('Val reconLF SAI', reconSAI_grid, epoch)
            writer.add_image('Val trueLF SAI', SAI_grid, epoch)
            writer.add_image('Val FS', FS_grid, epoch)

    Full_loss = Full_loss / max(n_val_samples, 1)  # guard against an empty val loader
    print('Full val_loss at the end of epoch %d is:%.4f' %(epoch,Full_loss))
    writer.add_scalar('Full Val loss', Full_loss, epoch)
    torch.save(net.state_dict(), os.path.join(log_path, 'model.pth'))

    # Advance the LR schedule only after this epoch's optimizer.step() calls. PyTorch >= 1.1
    # requires this order; stepping first skips a value and decays one epoch early. With this
    # order, MultiStepLR milestones count completed epochs.
    scheduler.step()
        
