In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from unet_layers import *
import h5py
from torch.utils.data import DataLoader
from torchvision import transforms as T
from tensorboardX import SummaryWriter
import os
import numpy as np
import time
import matplotlib.pyplot as plt
from data_utils import FSdataset,my_paired_RandomCrop,my_paired_normalize, my_paired_gamma_correction
from visualize import show_FS,show_EPI_xu,show_EPI_yv,show_SAI
from torch.optim.lr_scheduler import MultiStepLR
# Expose only the first GPU, enumerating devices by PCI bus ID (the ordering
# nvidia-smi uses) instead of CUDA's default fastest-first order. These must be
# set before the first CUDA call to take effect.
os.environ.update({"CUDA_DEVICE_ORDER": "PCI_BUS_ID",
                   "CUDA_VISIBLE_DEVICES": "0"})

# Seed the global NumPy and PyTorch RNGs for repeatable runs.
SEED = 100
np.random.seed(SEED)
torch.manual_seed(SEED)

# cuDNN: autotuning (benchmark=True) may be faster but can yield
# non-deterministic results, so it stays off. Deterministic algorithms are
# not forced either, so GPU results may still vary slightly between runs.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = False

In [None]:
# --- Hyperparameters --------------------------------------------------------
bs_train = 9
bs_val = 5
lr_str = '1e-3'          # learning rate exactly as it appears in the run name
lr = float(lr_str)

# nF: presumably the number of focal-stack slices; it also selects the
# FS_..._nF_<nF>.h5 file below, so the data file and the network agree.
nF = 7
# Light-field size H, W, nv, nu. The original Lytro LFs are 376 x 541 x 14 x 14;
# following the paper, only the first 372 x 540 spatial pixels and the central
# 8 x 8 sub-aperture views are used.
lfsize = [372, 540, 8, 8]

# --- Data augmentation ------------------------------------------------------
crop_size = 256          # size passed to my_paired_RandomCrop (training only)
gamma_range = (0.4, 1)   # arguments passed to my_paired_gamma_correction

transform_train = T.Compose([
    my_paired_normalize(255),
    my_paired_RandomCrop(crop_size),
    my_paired_gamma_correction(*gamma_range),
])
transform_val = T.Compose([my_paired_normalize(255)])

# --- Datasets ---------------------------------------------------------------
# The .h5 name encodes the focal-stack settings; nF is filled in from the
# variable above so the file name cannot drift from the configured nF.
fs_tag = 'FS_dmin_-1_dmax_0.3_nF_%d' % nF
FSdata_path = 'FS_dataset/%s.h5' % fs_tag
LFdata_folder = 'Flowers_8bit/'

ds_train = FSdataset(lfsize=lfsize, FSdata_path=FSdata_path, LFdata_folder=LFdata_folder,
                     trainorval='train', transform=transform_train)
ds_val = FSdataset(lfsize=lfsize, FSdata_path=FSdata_path, LFdata_folder=LFdata_folder,
                   trainorval='val', transform=transform_val)

# --- Logging ----------------------------------------------------------------
# The run directory (TensorBoard events + model.pth) is derived from the
# hyperparameters above, so changing one of them cannot silently log into,
# and overwrite the checkpoint of, a differently named run.
gamma_tag = '_'.join(str(g).replace('.', 'p') for g in gamma_range)  # (0.4, 1) -> '0p4_1'
# NOTE(review): the 'no_crop' path component does not match the random crop
# applied above (recorded as crop_<size>); renaming it would move logs and
# checkpoints, so it is kept as-is. 'unet_FS2LF_v3_tanh' must match the model
# built in the next cell.
log_path = ('logs/no_crop/%s/unet_FS2LF_v3_tanh/'
            'lr_%s_bs_train_%d_bs_val_%d_crop_%d_gamma_corr_%s'
            % (fs_tag, lr_str, bs_train, bs_val, crop_size, gamma_tag))
writer = SummaryWriter(log_path)

train_loader = DataLoader(ds_train, batch_size=bs_train, shuffle=True, num_workers=3, pin_memory=True)
val_loader = DataLoader(ds_val, batch_size=bs_val, shuffle=False, num_workers=0, pin_memory=True)

In [None]:

# Model, loss and optimiser. Use the GPU when one is visible; otherwise fall
# back to the CPU (much slower) instead of failing on torch.device("cuda").
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print('Using device: %s' % device)

# lfsize is [H, W, nv, nu], so index 3 is nu and index 2 is nv.
net = unet_FS2LF_v3(nF=nF, nu=lfsize[3], nv=lfsize[2], box_constraint='tanh')
net.to(device)

# L1 reconstruction loss between predicted and ground-truth light fields
# (nn.MSELoss() is a drop-in alternative).
criterion = nn.L1Loss()
optimizer = optim.Adam(net.parameters(), lr=lr)
# Optional LR schedule that was tried: MultiStepLR(optimizer,
# milestones=[3, 6, 10, 20], gamma=0.5), stepped once per epoch in the training loop.

In [None]:
# Training loop: optimise `net` on `train_loader`, logging the train loss,
# image summaries and a periodic validation pass to TensorBoard via `writer`.
NUM_EPOCHS = 50
LOSS_LOG_EVERY = 10   # steps between 'loss' scalar writes
PRINT_EVERY = 1       # steps between train-loss prints (1 = every step)
VIS_EVERY = 100       # steps between training image summaries
VAL_EVERY = 100       # steps between validation pass + checkpoint save
# Index pairs handed to show_SAI / show_EPI_xu (presumably (v, u) views and
# EPI slice selectors -- confirm in visualize.py). They also form the tag names.
SAI_VIEWS = [(0, 0), (5, 5)]
EPI_LINES = [(100, 0), (200, 3)]


def _coords_tag(coords):
    """Format [(a, b), (c, d)] as '[a,b], [c,d]' for use in TensorBoard tags."""
    return ', '.join('[%d,%d]' % tuple(c) for c in coords)


def log_lf_visuals(tag_prefix, FS, LF, reconLF, step):
    """Write SAI/EPI grids of the reconstructed and true LF and the focal stack.

    Args:
        tag_prefix: first word of every image tag ('Training' or 'Val').
        FS: focal-stack input batch fed to the network.
        LF: ground-truth light-field batch.
        reconLF: network output for `FS`.
        step: global step recorded with the images.

    Every tensor is detached and moved to the CPU before plotting, so the
    train and val paths hand the show_* helpers the same kind of input.
    """
    recon_cpu = reconLF.detach().cpu()
    true_cpu = LF.detach().cpu()
    sai_tag = _coords_tag(SAI_VIEWS)
    epi_tag = _coords_tag(EPI_LINES)

    reconSAI_grid = show_SAI(recon_cpu, SAI_VIEWS, isshow=False)
    reconEPI_xu_grid = show_EPI_xu(recon_cpu, EPI_LINES, isshow=False)
    SAI_grid = show_SAI(true_cpu, SAI_VIEWS, isshow=False)
    EPI_xu_grid = show_EPI_xu(true_cpu, EPI_LINES, isshow=False)
    FS_grid = show_FS(FS.detach().cpu(), isshow=False)

    writer.add_image('%s reconLF SAI %s' % (tag_prefix, sai_tag), reconSAI_grid, step)
    writer.add_image('%s reconLF EPI %s' % (tag_prefix, epi_tag), reconEPI_xu_grid, step)
    writer.add_image('%s trueLF SAI %s' % (tag_prefix, sai_tag), SAI_grid, step)
    writer.add_image('%s trueLF EPI %s' % (tag_prefix, epi_tag), EPI_xu_grid, step)
    writer.add_image('%s FS' % tag_prefix, FS_grid, step)


step = 0
for epoch in range(NUM_EPOCHS):
    print("Current epoch number %d" % epoch)
    for data in train_loader:
        net.train()
        # TODO (2019-04-16): check that the loaded FS/LF dimensions are
        # compatible with the network structure.
        FS, LF = data['FS'].to(device), data['LF'].to(device)
        reconLF = net(FS)
        loss = criterion(reconLF, LF)
        train_loss = loss.item()  # read once; each .item() copies to the host
        if step % PRINT_EVERY == 0:
            print('Train Loss is %.3f' % train_loss)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if step % LOSS_LOG_EVERY == 0:
            writer.add_scalar('loss', train_loss, step)
        if step % VIS_EVERY == 0:
            print('Train visualization')
            log_lf_visuals('Training', FS, LF, reconLF, step)
        step += 1

        # Validation is checked on the post-increment step (100, 200, ...),
        # whereas the training visuals above use the pre-increment step.
        if step % VAL_EVERY == 0:
            net.eval()
            # Always the first val batch (val_loader has shuffle=False), so the
            # images and 'Val loss' track the same samples over time. This is
            # NOT the loss over the whole validation set.
            data = next(iter(val_loader))
            FS, LF = data['FS'].to(device), data['LF'].to(device)
            with torch.no_grad():
                reconLF = net(FS)
                loss = criterion(reconLF, LF)
            val_loss = loss.item()
            print('Validation visualization')
            log_lf_visuals('Val', FS, LF, reconLF, step)
            print('Val Loss is %.3f' % val_loss)
            writer.add_scalar('Val loss', val_loss, step)
            # Overwrites the same file each time: latest weights, not the best.
            torch.save(net.state_dict(), os.path.join(log_path, 'model.pth'))
        
