Generates a focal-stack HDF5 (.h5) dataset using a shift-and-add refocusing algorithm that I implemented myself. (Used for network training; note that this does not avoid the inverse crime.)

In [None]:
from PIL import Image 
import numpy as np
import matplotlib.pyplot as plt
import h5py 
import os
from data_utils import read_lytroLF_as5D

In [None]:
# Light-field size used throughout this notebook, ordered [H, W, nv, nu]:
# spatial height/width followed by the number of angular views along v and u.
# The original Lytro light fields are 376 x 541 spatial pixels with 14 x 14 views;
# following the paper, only the first 372 x 540 spatial pixels and the central
# 8 x 8 sub-aperture images (SAIs) are used here.
lfsize = [
    372,  # H
    540,  # W
    8,    # nv
    8,    # nu
]



def bilinear_interpolate(im, x, y):
    """
    Input: im (H,W)
    x,y: index point or 2D grid (could be float number), in range 0,1,...(H-1) or (W-1) to interpolate
    
    Output: interpolated image at point or input 2D grid
    
    eg: 
    im = np.random.rand(100,200)
    x = [1,5]
    y= [2,55,22]
    X, Y= np.meshgrid(x,y,indexing = 'xy')
    bilinear_interpolate(im, X, Y)
    """
    x = np.asarray(x)
    y = np.asarray(y)

    x0 = np.floor(x).astype(int)
    x1 = x0 + 1
    y0 = np.floor(y).astype(int)
    y1 = y0 + 1

    x0 = np.clip(x0, 0, im.shape[1]-1);
    x1 = np.clip(x1, 0, im.shape[1]-1);
    y0 = np.clip(y0, 0, im.shape[0]-1);
    y1 = np.clip(y1, 0, im.shape[0]-1);
    
    
    Ia = im[ y0, x0 ]
    Ib = im[ y1, x0 ]
    Ic = im[ y0, x1 ]
    Id = im[ y1, x1 ]
    
    wa = (x1-x) * (y1-y)
    wb = (x1-x) * (y-y0)
    wc = (x-x0) * (y1-y)
    wd = (x-x0) * (y-y0)

    return wa*Ia + wb*Ib + wc*Ic + wd*Id

def shift_img_2D_bilinear(im,shift_x,shift_y):
    """
    Shift 2D image by pixel amount specified in shift_x and shift_y, using bi-linear interpolation.
    out of range region are filled with 0.
    input: 
        im: 2D image of shape H,W, shiting_x >0 means shift img to right. shifting_y >0 means shift img up
    Output: 
    shifted image, with 0 padding
    """
    H,W = im.shape
    X,Y = np.meshgrid(np.arange(W) - shift_x, np.arange(H) + shift_y, indexing = 'xy')
    return bilinear_interpolate(im, X, Y)
def generate_FS(LF,disparities,FSview_rounding = True):
    """
    Input：
    LF： LF array in shape of 3,H,W,nv,nu
    disparities: list of disparity at which to focus, the length gives the size of FS, i.e. nF. 
    Disparity <0 means refocusing farther than the focal plane
    FSview_rounding: Whether generate FS at integer u,v view point floor rounded (True) or just the true Center view (False).
    Output:
    refocused images in shape of nF,C,H,W
    """
    
    nC,H,W,nv,nu = LF.shape
    nF = len(disparities)
    FS = np.zeros([nF,3,H,W])
    I_tmp = np.zeros([3,H,W])
    
    X, Y = np.meshgrid(range(W),range(H),indexing = 'xy')
    for iF in range(nF):
        disp = disparities[iF]
        for iv in range(nv):
            for iu in range(nu):
                for ic in range(nC):
                    #I_tmp[ic] = bilinear_interpolate(LF[ic,:,:,iv,iu], X-disp*(iu-(nu-1)/2), Y+disp*(iv-(nv-1)/2)) # for MIT LF in https://github.com/MITComputationalCamera/LightFields, there is a sign difference because the LF given has opposite covention of positive v direction
                    if FSview_rounding:
                        I_tmp[ic] = bilinear_interpolate(LF[ic,:,:,iv,iu], X-disp*(iu-np.floor((nu-1)/2)), Y-disp*(iv-np.floor((nv-1)/2))) #For Flower CVPR LF dataset
                    else:
                        I_tmp[ic] = bilinear_interpolate(LF[ic,:,:,iv,iu], X-disp*(iu-(nu-1)/2), Y-disp*(iv-(nv-1)/2)) #For Flower CVPR LF dataset                       
                FS[iF] += I_tmp/(nv*nu)
    return FS
    

In [None]:
#%matplotlib auto
lf = read_lytroLF_as5D('Flowers_8bit/IMG_3679_eslf.png',lfsize)
plt.figure()
plt.imshow(lf[:,:,:,7,0].transpose(1,2,0))

In [None]:
disps=np.linspace(-1,0.3,7)
FS = generate_FS(lf,disps)
#plt.imshow(FS[0].transpose(1,2,0).astype(int))
for i in range(len(disps)):
    plt.figure()
    plt.title('%f' %disps[i])
    plt.imshow(FS[i].transpose(1,2,0).astype(int))

In [None]:
import random
from random import shuffle
def prepare_FS_dataset(LF_datafolder_path,disparities,FS_saving_path,val_size= None, peek = False):
    """
    e.g:prepare_FS_dataset('Flowers_8bit',np.linspace(-1,0.3,7),'FS_dataset/test.h5')
    val_size determine the size of validation set, the size of training set is the rest. 
    """
    random.seed(0) #for deterministic shuffling
    LF_filenames = [f for f in os.listdir(LF_datafolder_path) if not f.startswith('.')]
    shuffle(LF_filenames) # make the list more random
    LF_filenames_val = LF_filenames[:val_size]
    LF_filenames_train = LF_filenames[val_size:] #rest of the data
    f = h5py.File(FS_saving_path, 'w')
    f.create_group('train')
    f.create_group('val')
    #f.create_dataset('disparities',data=np.array(disparities))
    
    
    
    #generate data for validation
    i = 0
    for LF_fname in LF_filenames_val:
        print("Val dataset [%f/%f]" %(i+1,len(LF_filenames_val)))
        lf = read_lytroLF_as5D(os.path.join(LF_datafolder_path,LF_fname),lfsize)
        FS = generate_FS(lf,disparities)
        FS = np.rint(FS).astype(int) #round and convert FS to nearest integer
        f['val'].create_dataset(LF_fname, data=FS)
        if peek:
            plt.figure() #fi
            plt.title('%f' %disparities[0])
            plt.imshow(FS[0].transpose(1,2,0).astype(int))
            plt.show()
            plt.figure()
            plt.title('%f' %disparities[-1])
            plt.imshow(FS[-1].transpose(1,2,0).astype(int))
            plt.show()
        i += 1    
    
    #generate data for training
    i = 0
    for LF_fname in LF_filenames_train:
        print("Training dataset [%f/%f]" %(i+1,len(LF_filenames_train)))
        lf = read_lytroLF_as5D(os.path.join(LF_datafolder_path,LF_fname),lfsize)
        FS = generate_FS(lf,disparities)
        FS = np.rint(FS).astype(int) #round and convert FS to nearest integer
        f['train'].create_dataset(LF_fname, data=FS)
        if peek:
            plt.figure() #fi
            plt.title('%f' %disparities[0])
            plt.imshow(FS[0].transpose(1,2,0).astype(int))
            plt.show()
            plt.figure()
            plt.title('%f' %disparities[-1])
            plt.imshow(FS[-1].transpose(1,2,0).astype(int))
            plt.show()
        i += 1
    f.close()
        

In [None]:
prepare_FS_dataset('Flowers_8bit',np.linspace(-1,0.3,7),'FS_dataset/FS_dmin_-1_dmax_0.3_nF_7_GenInPy_FSview_rounding_true.h5',val_size=100, peek = False)