## For Generating .h5 dataset of SAI from trained ViewNet. 
### To run:
1. Modify the network checkpoint path (`path`) and the output path of the generated SAI dataset (the HDF5 file `f`).
2. Check that `FSdata_path`, `lfsize` (which depends on whether the inverse crime is avoided), `SAI_iv`, `SAI_iu` and `nF` are correct, and that the network constructor builds the right network.
3. Check that the right transform is used. Note: when avoiding the inverse crime, the FS has to be divided by about 9000.
4. Check avoid_invCrime boolean

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from unet_layers import *
import h5py
from torch.utils.data import DataLoader
from torchvision import transforms as T
from tensorboardX import SummaryWriter
import os
import numpy as np
import time
import matplotlib.pyplot as plt
from data_utils import FSdataset_withName,FSdataset_withName_h5,my_paired_RandomCrop,my_paired_normalize, my_paired_gamma_correction
from visualize import show_FS,show_EPI_xu,show_EPI_yv,show_SAI
from torch.optim.lr_scheduler import MultiStepLR
# Restrict the process to GPU 0, with devices numbered by PCI bus id.
os.environ.update({
    "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
    "CUDA_VISIBLE_DEVICES": "0",
})

# Seed both NumPy and PyTorch with the same fixed value for repeatable runs.
np.random.seed(100)
torch.manual_seed(100)

# Ask cuDNN for deterministic kernels; enabling benchmark mode may be faster
# but can give non-deterministic results, so it stays off.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [None]:
# DataLoader batch sizes. Both have to be 1: the export loop unpacks a single
# sample name per batch.
bs_train = 1
bs_val = 1

# True  -> use the "inverseCrime" .h5 datasets (7x7 views, FS divided by 9000).
# False -> use the GenPy focal stacks with the Flowers_8bit light fields (8x8 views).
avoid_invCrime = False

# Forwarded to the network constructor; matches the "nF_7" tag of the dataset files.
nF = 7
# (v, u) index, counting from 0, of the SAI selected as the all-in-focus image.
SAI_iv = 3
SAI_iu = 3

# lfsize holds the light-field dimensions as [H, W, nv, nu].
# The original Lytro LF is 376 x 541 x 14 x 14; following the paper, only the
# first 372 x 540 spatial pixels and the central 8 x 8 SAIs are used.
if avoid_invCrime:
    lfsize = [185, 269, 7, 7]
    fs_h5 = '/home/zyhuang/EVO970Plus/FS_dataset/FS_dmin_-1_dmax_0.3_nF_7_inverseCrime.h5'
    lf_h5 = '/home/zyhuang/EVO970Plus/LF_dataset/FS_dmin_-1_dmax_0.3_nF_7_inverseCrime.h5'
    # FS is not normalized in matlab, whereas LF is already normalized to [0, 1].
    transform_train = T.Compose([my_paired_normalize({'FS': 9000, 'LF': 1})])
    transform_val = T.Compose([my_paired_normalize({'FS': 9000, 'LF': 1})])
    ds_train = FSdataset_withName_h5(
        FSdata_path=fs_h5,
        LFdata_path=lf_h5,
        trainorval='train',
        transform=transform_train,
    )
    ds_val = FSdataset_withName_h5(
        FSdata_path=fs_h5,
        LFdata_path=lf_h5,
        trainorval='val',
        transform=transform_val,
    )
else:
    lfsize = [372, 540, 8, 8]
    fs_h5 = '/home/zyhuang/EVO970Plus/FS_dataset/FS_dmin_-1_dmax_0.3_nF_7_GenPy_FSview_rounding_true.h5'
    lf_dir = '/home/zyhuang/EVO970Plus/Flowers_8bit/'
    transform_train = T.Compose([my_paired_normalize(255)])
    transform_val = T.Compose([my_paired_normalize(255)])
    ds_train = FSdataset_withName(
        lfsize=lfsize,
        FSdata_path=fs_h5,
        LFdata_folder=lf_dir,
        trainorval='train',
        transform=transform_train,
    )
    ds_val = FSdataset_withName(
        lfsize=lfsize,
        FSdata_path=fs_h5,
        LFdata_folder=lf_dir,
        trainorval='val',
        transform=transform_val,
    )

# Both loaders iterate in dataset order (no shuffling) with identical worker settings.
loader_kwargs = dict(shuffle=False, num_workers=8, pin_memory=True)
train_loader = DataLoader(ds_train, batch_size=bs_train, **loader_kwargs)
val_loader = DataLoader(ds_val, batch_size=bs_val, **loader_kwargs)

# Checkpoint of the trained view net.
path = 'logs/Not_Avoid_invcrime/Two stage model/ViewNet/FS_dmin_-1_dmax_0.3_nF_7_GenPy_FSview_rounding_true/unet_FS2SAI_v3_tanh/lr_3e-4_bs_train_2_bs_val_5/model.pth'
# Output HDF5 file for the generated SAI dataset; mode 'w' truncates an existing file.
f = h5py.File('SAI_dataset/FS_dmin_-1_dmax_0.3_nF_7_GenPy_FSview_rounding_true_unet_FS2SAI_v3_tanh_lr_3e-4_bs_train_2_bs_val_5.h5', 'w')

In [None]:
# Build the view net on the GPU and restore its trained weights.
device = torch.device("cuda")
net = unet_FS2SAI_v3(nF=nF, nu=lfsize[3], nv=lfsize[2], box_constraint='tanh')
net.to(device)
# map_location remaps the stored tensors onto `device`, so a checkpoint saved
# from a different GPU index still loads when only one GPU is visible.
net.load_state_dict(torch.load(path, map_location=device))
# L1 distance between the reconstructed and the reference SAI (reporting only).
criterion = nn.L1Loss()

In [None]:
# Run the trained view net over the validation and training splits and store
# every reconstructed SAI in the output HDF5 file, one dataset per sample name
# under the 'val' / 'train' group. The L1 loss against the reference SAI is
# printed per sample and averaged per split.
try:
    net.eval()
    f.create_group('train')
    f.create_group('val')
    # Same order as before: validation split first, then training split.
    for split, loader in (('val', val_loader), ('train', train_loader)):
        Total_loss = 0
        n_batches = len(loader)
        for idx, data_withName in enumerate(loader, 0):
            data, name = data_withName
            # The DataLoader collate function wraps the name in a size-1 list.
            name = name[0]
            FS, LF = data['FS'].to(device), data['LF'].to(device)
            # Select the single view (SAI_iv, SAI_iu) of the light field: B,C,H,W
            SAI = LF[:, :, SAI_iv, SAI_iu, :, :]
            with torch.no_grad():
                reconSAI = net(FS)
                loss = criterion(reconSAI, SAI)
            print("[%d/%d], loss is %.4f" % (idx + 1, n_batches, loss.item()))
            Total_loss += loss.item()
            # torch.squeeze drops every size-1 dimension (the batch dimension here).
            f[split].create_dataset(name, data=torch.squeeze(reconSAI).cpu().numpy())
        # An empty split has no mean; skip the print instead of dividing by zero.
        if n_batches:
            print("Mean Loss is {:.4f}".format(Total_loss / n_batches))
finally:
    # Always close the HDF5 file, even if a batch fails, so that the data
    # written so far is flushed and the handle is not leaked.
    f.close()
    
    