<h3>Imports</h3>

In [1]:
import cv2
import numpy as np
import torch
import os
import math
import random
import torch
from torch import nn, optim
from torch.utils.data import Dataset
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.io import read_image
import matplotlib.pyplot as plt

device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
print('Using {} device'.format(device))
if device == 'cuda:0':
    print("Device: %d" % torch.cuda.current_device())
    print("Device count: %d" % torch.cuda.device_count())
    print("Device name: %s" % torch.cuda.get_device_name(torch.cuda.current_device()))
    
# uncomment if its peak hours
# device = 'cpu'

Using cuda:0 device
Device: 0
Device count: 1
Device name: GeForce GTX 1080 Ti


<h3>Everything from the other notebook</h3>

In [2]:
train_ours = os.listdir('../../../../../datasets/cs274-fa21-A00-public/TrainingSet/OURS')
train_stanf = os.listdir('../../../../../datasets/cs274-fa21-A00-public/TrainingSet/STANFORD')
test_paper = os.listdir('../../../../../datasets/cs274-fa21-A00-public/TestSet/PAPER')
test_extra = os.listdir('../../../../../datasets/cs274-fa21-A00-public/TestSet/EXTRA')
# print(train_ours)
train_dir = '../../../../../datasets/cs274-fa21-A00-public/TrainingSet/'
test_dir = '../../../../../datasets/cs274-fa21-A00-public/TestSet/'
names = ["OURS/"+a for a in train_ours] + ["STANFORD/"+b for b in train_stanf]
test_names = ["PAPER/"+a for a in test_paper] + ["EXTRA/"+b for b in test_extra]
print(test_names)

patchSize = 60
patchStride = 16

disparity_levels = 100
disparity_range = (-21, 21)
n = 4

def back_warp(L_pi, q, d):
    # used to create features
    # given h x w x 3 x 4 lightfield views, output h x w x 3 x 4 x L (L=100 disparity levels) tensor
    # At each last index k, the image represents that view warped by disparity level k
    # q is the novel view pos; d is a linspace of 100 disparity levels -21, ... , 21
    h, w = L_pi.shape[:2]
    xs = torch.linspace(0, w, steps=w)
    ys = torch.linspace(0, h, steps=h)
    y, x = torch.meshgrid(ys, xs)
    x = x.repeat(100, 1, 1).transpose(0, 1).transpose(1, 2) # size h x w x 100
    y = y.repeat(100, 1, 1).transpose(0, 1).transpose(1, 2)
    L_pi = torch.permute(L_pi, (3, 2, 0, 1)) # 4 x 3 x h x w for grid_sample
    back_warped = torch.zeros(h, w, 1, 4, disparity_levels)
    for i in range(4):
        pi = torch.tensor([0 if (i == 0 or i == 2) else 7, 0 if (i == 0 or i == 1) else 7])
#         print(pi.shape)
#         print(q.shape)
#         print(d.shape)
#         print(x.shape)
        xd = x + (pi[0]-q[0])*d
        yd = y + (pi[1]-q[1])*d
        # grid sample assumes top left is -1, -1 and bottom right is 1, 1
        xd = torch.clamp((xd*2.0/w) - 1.0, min=-1, max=1)
        yd = torch.clamp((yd*2.0/h) - 1.0, min=-1, max=1)
        grid = torch.stack((xd, yd), dim=3) # h x w x 100 x 2
        grid = torch.permute(grid, (2, 0, 1, 3)) # 100 x h x w x 2 for grid_sample
        back_warped[:, :, :, i, :] = torch.permute(
            F.grid_sample(L_pi[i, :, :, :].unsqueeze(0).repeat(disparity_levels, 1, 1, 1), # d x 3 x h x w
                          grid, # d x h x w x 2
                          mode='bicubic'), 
            (2, 3, 1, 0)) # grid sample outputs d x 3 x h x w
    return back_warped # h x w x 3 x 4 x disparity_levels

class LightFieldDataSet(Dataset):
    def __init__(self, img_dir, img_names):
        self.img_dir = img_dir
        self.img_names = img_names
    def __len__(self):
        return len(self.img_names)
    def __getitem__(self, idx):
        with torch.no_grad():
            # LOAD THE IMAGE AND TURN IT FROM LIGHTFIELD TO SEPARATED IMAGES
            img_path = os.path.join(self.img_dir, self.img_names[idx])#'/'.join(self.img_dir, self.img_names[idx])#self.img_dir, self.img_names[idx])
            image = plt.imread(img_path)
            numImgsX = 14
            numImgsY = 14
            h = int(image.shape[0] / numImgsY)
            w = int(image.shape[1] / numImgsX)
            fullLF = np.zeros((h, w, 3, numImgsX, numImgsY))
            for ax in range(numImgsX):
                for ay in range(numImgsY):
                    #print(im1[ay::numImgsY, ax::numImgsX, :].shape)
                    fullLF[:,:,:, ax, ay] = image[ay::numImgsY, ax::numImgsX, :3]
            if h == 375 and w == 540:
                np.pad(fullLF, ((0, 1), (0, 1), (0, 0), (0, 0), (0, 0)))
            if h == 375 and w == 541:
                np.pad(fullLF, ((0, 1), (0, 0), (0, 0), (0, 0), (0, 0)))
            fullLF = fullLF[:, :, :, 4:12, 4:12] # 8 middle images
            inr = np.array(((0, 0), (7, 7)))
            inc = np.array(((0, 7), (0, 7)))
            inLF = fullLF[:, :, :, inc, inr] # 4 corner images
    #         print(inLF.shape)

            # CONVERT FULL AND INPUT LIGHTFIELDS TO TENSORS
            fullTens = torch.from_numpy(fullLF).float() # should be h x w x 3 x 8 x 8
            inTens = torch.from_numpy(inLF).float() # h x w x 3 x 2 x 2

            # CREATE THE INPUT, FEATURE, GROUND TRUTH PATCHES, AND REFERENCE VIEW FOR THE NETWORK
            # create input patches
            inP = torch.reshape(inTens, (h, w, 3, -1))
            inP = inP.unfold(0, patchSize, patchStride).unfold(1, patchSize, patchStride)
            inP = torch.flatten(inP, start_dim=0, end_dim=1)

            # choose random q
            q = [0, 0]
            while q[0] == 0 or q[0] == 7 or q[1] == 0 or q[1] == 7:
                q[0] = random.randrange(8)
                q[1] = random.randrange(8)
            q = torch.tensor(q)

            # create feature patches
            disps = torch.linspace(disparity_range[0], disparity_range[1], 100)
            inTens_gray = 0.299*inTens[:, :, 0, :, :] + 0.587*inTens[:, :, 1, :, :] + 0.114*inTens[:, :, 2, :, :] #gray
            warped = back_warp(torch.reshape(inTens_gray, (h, w, 1, -1)), q, disps) # h x w x 3(1) x 4 x disparity_levels
            warped = torch.squeeze(warped) # remove singleton dimension... gives h x w x 4 x disparity_levels
            mean_feat = torch.mean(warped, dim=2)
            stdev_feat = torch.sqrt((1.0/3.0)*torch.sum(torch.pow(warped - mean_feat.unsqueeze(2), 2), dim=2))
            featP = torch.cat((mean_feat, stdev_feat), dim=2).unfold(0, patchSize, patchStride).unfold(1, patchSize, patchStride)
            featP = torch.flatten(featP, start_dim=0, end_dim=1)
            
            # create gt patches
            gtP = fullTens[:, :, :, q[0], q[1]].unfold(0, patchSize, patchStride).unfold(1, patchSize, patchStride)
            gtP = torch.flatten(gtP, start_dim=0, end_dim=1)[:, :, 12:48, 12:48] # output patches are 36x36
            
#             print("inP")
#             print(inP.shape)
#             print("featP")
#             print(featP.shape)
#             print("gtP")
#             print(gtP.shape)
#             print("q")
#             print(q.shape)
            
            # randomly permute the tensors for training
            perm = torch.randperm(inP.shape[0])
            inP = inP[perm, :, :, :, :]
            featP = featP[perm, :, :, :]
            gtP = gtP[perm, :, :, :]

            return [inP, featP, gtP, q]

# train_ours = LightFieldDataSet('../../../../../datasets/cs274-fa21-A00-public/TrainingSet/OURS')
# train_stanf = LightFieldDataSet('../../../../../datasets/cs274-fa21-A00-public/TrainingSet/STANFORD')
trainset = LightFieldDataSet(train_dir, names)
testset = LightFieldDataSet(test_dir, test_names)

train_dataloader = DataLoader(trainset, batch_size=1, shuffle=True)
test_dataloader = DataLoader(testset, batch_size=1, shuffle=True)

class DispEst(nn.Module):
    def __init__(self):
        super(DispEst, self).__init__()
        self.c1 = nn.Conv2d(200, 100, 7)
        self.c2 = nn.Conv2d(100, 100, 5)
        self.c3 = nn.Conv2d(100, 50, 3)
        self.c4 = nn.Conv2d(50, 1, 1)
        
    def forward(self, x):
        # x is n x 200 x h x w
        # go through the layers, relu after each:
        x = self.c1(x)
        x = F.relu(x)
        x = self.c2(x)
        x = F.relu(x)
        x = self.c3(x)
        x = F.relu(x)
        x = self.c4(x)
        
        return x
    
class ColorEst(nn.Module):
    def __init__(self):
        super(ColorEst, self).__init__()
        self.c1 = nn.Conv2d(3*4 + 3, 100, 7)
        self.c2 = nn.Conv2d(100, 100, 5)
        self.c3 = nn.Conv2d(100, 50, 3)
        self.c4 = nn.Conv2d(50, 3, 1)
        
    def forward(self, x):
        # x is n x 3N+3 x h x w
        # go through the layers, relu after each:
        x = self.c1(x)
        x = F.relu(x)
        x = self.c2(x)
        x = F.relu(x)
        x = self.c3(x)
        x = F.relu(x)
        x = self.c4(x)
        
        return x
    
def forward_warp(inp, q, disp):
    # inp is n x 3 x 4 x 60 x 60
    # q is 2,
    # disp is n x 1 x 60 x 60 (?)
    ph, pw = disp.shape[2:]
    xs = torch.linspace(0, pw, steps=pw)
    ys = torch.linspace(0, ph, steps=ph)
    y, x = torch.meshgrid(ys, xs)
    y = y.repeat(disp.shape[0], 1, 1) # n x 60 x 60
    x = x.repeat(disp.shape[0], 1, 1)
    out = torch.zeros(inp.shape[0], 3*4, ph, pw).to(device)# n x 3*4 x 60 x 60
    for i in range(4):
        pi = torch.tensor([0 if (i == 0 or i == 2) else 7, 0 if (i == 0 or i == 1) else 7])
        # send x, y, pi to device
        x = x.to(device)
        y = y.to(device)
        pi = pi.to(device)
        xd = x + (pi[0]-q[0])*disp.squeeze() # still n x 60 x 60
        yd = y + (pi[1]-q[1])*disp.squeeze()
        # grid sample assumes top left is -1, -1 and bottom right is 1, 1
        xd = torch.clamp((xd*2.0/pw) - 1.0, min=-1, max=1)
        yd = torch.clamp((yd*2.0/ph) - 1.0, min=-1, max=1)
        grid = torch.stack((xd, yd), dim=3) # n x 60 x 60 x 2
        out[:, 3*i:3*i+3, :, :] = F.grid_sample(inp[:, :, i, :, :], grid, mode='bicubic')
    return out
    
class DispColorNet(nn.Module):
    def __init__(self):
        super(DispColorNet, self).__init__()
        self.d = DispEst()
        self.c = ColorEst()
    def forward(self, x, feat_d, q):
        # disparity estimation network takes in n x 200 x h x w features, outputs n x 1 x h x w disparity
        disp = self.d(feat_d)
        
        # forward warp, cat with disp and q: n x 3*4+3 x h x w
        fw = forward_warp(x, q, disp)
        q_rep = q[None, :, None, None].repeat((fw.shape[0], 1, fw.shape[2], fw.shape[3])) # n x 2 x h x w
        feat_c = torch.cat((fw, disp, q_rep), dim=1)
        
        # color estimation network takes in n x 3*4+3 x h x w, outputs n x 3 x h-12, w-12
        out = self.c(feat_c)
        
        return out
    
# evaluate model on a test image
def eval_on_image(im_path, net, q):
    net.eval()
    with torch.no_grad():
        # LOAD THE IMAGE AND TURN IT FROM LIGHTFIELD TO SEPARATED IMAGES
        img_path = im_path
        image = plt.imread(img_path)#torch.permute(read_image(img_path), (1, 2, 0)).float()
        numImgsX = 14
        numImgsY = 14
        h = int(image.shape[0] / numImgsY)
        w = int(image.shape[1] / numImgsX)
        fullLF = np.zeros((h, w, 3, numImgsX, numImgsY))
        for ax in range(numImgsX):
            for ay in range(numImgsY):
                #print(im1[ay::numImgsY, ax::numImgsX, :].shape)
                fullLF[:,:,:, ax, ay] = image[ay::numImgsY, ax::numImgsX, :3]
        if h == 375 and w == 540:
            np.pad(fullLF, ((0, 1), (0, 1), (0, 0), (0, 0), (0, 0)))
        if h == 375 and w == 541:
            np.pad(fullLF, ((0, 1), (0, 0), (0, 0), (0, 0), (0, 0)))
        fullLF = fullLF[:, :, :, 4:12, 4:12] # 8 middle images
        inr = np.array(((0, 0), (7, 7)))
        inc = np.array(((0, 7), (0, 7)))
        inLF = fullLF[:, :, :, inc, inr] # 4 corner images
#         print(inLF.shape)

        # CONVERT FULL AND INPUT LIGHTFIELDS TO TENSORS
        fullTens = torch.from_numpy(fullLF.copy()).float() # should be h x w x 3 x 8 x 8
        inTens = torch.from_numpy(inLF.copy()).float() # h x w x 3 x 2 x 2

        # CREATE THE INPUT, FEATURE, GROUND TRUTH PATCHES, AND REFERENCE VIEW FOR THE NETWORK
        # create input
        inP = torch.reshape(inTens, (h, w, 3, -1)) # h x w x 3 x 4
        in_eval = torch.permute(inP.unsqueeze(0), (0, 3, 4, 1, 2)) # 1 x h x w x 3 x 4 -> 1 x 3 x 4 x h x w for the network

        # load q into tensor
        q_eval = torch.tensor(q).long()

        # create feature patches
        disps = torch.linspace(disparity_range[0], disparity_range[1], 100)
        inTens_gray = 0.299*inTens[:, :, 0, :, :] + 0.587*inTens[:, :, 1, :, :] + 0.114*inTens[:, :, 2, :, :] #gray
        warped = back_warp(torch.reshape(inTens_gray, (h, w, 1, -1)), q_eval, disps) # h x w x 3(1) x 4 x disparity_levels
        warped = torch.squeeze(warped) # remove singleton dimension... gives h x w x 4 x disparity_levels
        mean_feat = torch.mean(warped, dim=2)
        stdev_feat = torch.sqrt((1.0/3.0)*torch.sum(torch.pow(warped - mean_feat.unsqueeze(2), 2), dim=2))
        featP = torch.cat((mean_feat, stdev_feat), dim=2) # h x w x 200
        feat_eval = torch.permute(featP.unsqueeze(0), (0, 3, 1, 2)) # 1 x h x w x 200 -> 1 x 200 x h x w for the network

        # create gt tensor
        gt = fullTens[12:-12, 12:-12, :, q_eval[0], q_eval[1]]
        
        # evaluate the network
        out_tens = net.forward(in_eval.to(device), feat_eval.to(device), q_eval.to(device))
        
        out_disp = net.d(feat_eval.to(device))
        
    # return tuple of the input, the features, the gt, and the final output
    return in_eval, feat_eval, gt, out_tens, out_disp


def colorcorr(frame):
    gamma = 1.5
    I = np.power(frame, 1.0/gamma) # apply gamma correction
    I = cv2.cvtColor(I, cv2.COLOR_RGB2HSV) # rgb to hsv
    I[:, :, 1] = I[:, :, 1] * gamma # mult value by gamma
    I = cv2.cvtColor(I, cv2.COLOR_HSV2RGB) # hsv to rgb
    return np.clip(I, 0, 1) # convert color puts negative values; just clip them

['PAPER/Cars.png', 'PAPER/Flower1.png', 'PAPER/Flower2.png', 'PAPER/Seahorse.png', 'PAPER/Rock.png', 'EXTRA/IMG_1528_eslf.png', 'EXTRA/IMG_1586_eslf.png', 'EXTRA/IMG_1184_eslf.png', 'EXTRA/IMG_1324_eslf.png', 'EXTRA/IMG_1541_eslf.png', 'EXTRA/IMG_1312_eslf.png', 'EXTRA/IMG_1086_eslf.png', 'EXTRA/IMG_1325_eslf.png', 'EXTRA/IMG_1316_eslf.png', 'EXTRA/IMG_1306_eslf.png', 'EXTRA/IMG_1555_eslf.png', 'EXTRA/IMG_1085_eslf.png', 'EXTRA/IMG_1411_eslf.png', 'EXTRA/IMG_1321_eslf.png', 'EXTRA/IMG_1419_eslf.png', 'EXTRA/IMG_1390_eslf.png', 'EXTRA/IMG_1743_eslf.png', 'EXTRA/IMG_1389_eslf.png', 'EXTRA/IMG_1327_eslf.png', 'EXTRA/IMG_1328_eslf.png', 'EXTRA/IMG_1317_eslf.png', 'EXTRA/IMG_1340_eslf.png', 'EXTRA/IMG_1554_eslf.png', 'EXTRA/IMG_1187_eslf.png', 'EXTRA/IMG_1320_eslf.png']


<h3>Create Synthesized Views Video File</h3>

In [8]:
def draw_uv_on_frame(frame, u, v):
    # returns another frame that has the u, v representation drawn-on... 4 corner boxes are color, selected box is sel_color
    # frame is h x w x 3
    new_frame = frame
    offset = (6, 6) # row/col offset
    box_shape = (6, 6) # row/col box shape
    # draw corners
    color = np.array([1.0, 0.5, 0])
    new_frame[offset[0]+0*box_shape[0]:offset[0]+1*box_shape[0], offset[1]+0*box_shape[1]:offset[1]+1*box_shape[1], :] = color # tl
    new_frame[offset[0]+0*box_shape[0]:offset[0]+1*box_shape[0], offset[1]+7*box_shape[1]:offset[1]+8*box_shape[1], :] = color # tr
    new_frame[offset[0]+7*box_shape[0]:offset[0]+8*box_shape[0], offset[1]+0*box_shape[1]:offset[1]+1*box_shape[1], :] = color # bl
    new_frame[offset[0]+7*box_shape[0]:offset[0]+8*box_shape[0], offset[1]+7*box_shape[1]:offset[1]+8*box_shape[1], :] = color # br
    # draw selected box
    sel_color = np.array([0.8, 0.8, 0.8])
    new_frame[offset[0]+v*box_shape[0]:offset[0]+(v+1)*box_shape[0], offset[1]+u*box_shape[1]:offset[1]+(u+1)*box_shape[1], :] = sel_color
    return new_frame
    

load_from_file = True
im_filepath = "../../../../../datasets/cs274-fa21-A00-public/TestSet/EXTRA/IMG_1325_eslf.png"
net_filepath = "./best3_loss_18.pt"
net = DispColorNet()

if load_from_file:
    checkpoint = torch.load(net_filepath, map_location='cpu')
    net.load_state_dict(checkpoint['model_state_dict'])
#     optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    e0 = checkpoint['epoch']

net.to(torch.device(device))

frameSize = (517, 352)
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
out = cv2.VideoWriter('./1325final.mp4', fourcc, float(30), frameSize)
repetitions = 1
held_frames = 3

mode = 'circle' # horizontal, vertical, or circle

if mode == 'horizontal':
    for rep in range(repetitions):
        print("Repetition: %d" % rep)
        for v in range(8):
            for u in range(8)[::1 if v % 2 == 0 else -1]:
                print("Writing view at (u, v) = %d, %d" % (u, v))
                in_eval, feat_eval, gt, output, out_disp = eval_on_image(im_filepath, net, [u, v])
                curr_frame = torch.clamp(output[0, :, :, :], min=0.0, max=1.0).permute((1, 2, 0)).cpu().numpy()
                curr_frame = colorcorr(curr_frame)
                img = draw_uv_on_frame(curr_frame, u, v)
                img = (img*255).astype(np.uint8)[:, :, ::-1]
                for t in range(held_frames):
                    out.write(img)
elif mode == 'vertical':
    for rep in range(repetitions):
        print("Repetition: %d" % rep)
        for u in range(8):
            for v in range(8)[::1 if u % 2 == 0 else -1]:
                print("Writing view at (u, v) = %d, %d" % (u, v))
                in_eval, feat_eval, gt, output, out_disp = eval_on_image(im_filepath, net, [u, v])
                curr_frame = torch.clamp(output[0, :, :, :], min=0.0, max=1.0).permute((1, 2, 0)).cpu().numpy()
                curr_frame = colorcorr(curr_frame)
                img = draw_uv_on_frame(curr_frame, u, v)
                img = (img*255).astype(np.uint8)[:, :, ::-1]
                for t in range(held_frames):
                    out.write(img)
elif mode == 'circle':
    us = list(range(8)) + [7]*6 + list(range(8)[::-1]) + [0]*6
    vs = [0]*7 + list(range(8)) + [7]*6 + list(range(8)[::-1])[:-1]
    for u, v in zip(us, vs):
        print("Writing view at (u, v) = %d, %d" % (u, v))
        in_eval, feat_eval, gt, output, out_disp = eval_on_image(im_filepath, net, [u, v])
        curr_frame = torch.clamp(output[0, :, :, :], min=0.0, max=1.0).permute((1, 2, 0)).cpu().numpy()
        curr_frame = colorcorr(curr_frame)
        img = draw_uv_on_frame(curr_frame, u, v)
        img = (img*255).astype(np.uint8)[:, :, ::-1]
        for t in range(held_frames):
            out.write(img)
    
out.release()

Writing view at (u, v) = 0, 0
Writing view at (u, v) = 1, 0
Writing view at (u, v) = 2, 0
Writing view at (u, v) = 3, 0
Writing view at (u, v) = 4, 0
Writing view at (u, v) = 5, 0
Writing view at (u, v) = 6, 0
Writing view at (u, v) = 7, 0
Writing view at (u, v) = 7, 1
Writing view at (u, v) = 7, 2
Writing view at (u, v) = 7, 3
Writing view at (u, v) = 7, 4
Writing view at (u, v) = 7, 5
Writing view at (u, v) = 7, 6
Writing view at (u, v) = 7, 7
Writing view at (u, v) = 6, 7
Writing view at (u, v) = 5, 7
Writing view at (u, v) = 4, 7
Writing view at (u, v) = 3, 7
Writing view at (u, v) = 2, 7
Writing view at (u, v) = 1, 7
Writing view at (u, v) = 0, 7
Writing view at (u, v) = 0, 6
Writing view at (u, v) = 0, 5
Writing view at (u, v) = 0, 4
Writing view at (u, v) = 0, 3
Writing view at (u, v) = 0, 2
Writing view at (u, v) = 0, 1


<h3>Create Disparity Features Video</h3>

In [4]:
load_from_file = True
im_filepath = "../../../../../datasets/cs274-fa21-A00-public/TestSet/PAPER/Seahorse.png"
net_filepath = "./best_loss_54.pt"
net = DispColorNet()

if load_from_file:
    checkpoint = torch.load(net_filepath, map_location='cpu')
    net.load_state_dict(checkpoint['model_state_dict'])
#     optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    e0 = checkpoint['epoch']

net.to(torch.device(device))

frameSize = (541, 376)
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
out = cv2.VideoWriter('./dstdev33.mp4', fourcc, float(30), frameSize)

repetitions = 4
held_frames = 2
u, v = 3, 3

in_eval, feat_eval, gt, output, out_disp = eval_on_image(im_filepath, net, [u, v])
feats = feat_eval[0, :, :, :].permute((1, 2, 0)).cpu().numpy()
offset = 100

for rep in range(repetitions):
    print("Repetition: %d" % rep)
    for d in range(disparity_levels):
        print("Writing disparity = %d" % d)
        curr_frame = np.repeat(feats[:, :, d+offset][:, :, None], 3, axis=2)
        print(curr_frame.shape)
        img = (curr_frame*255).astype(np.uint8)
        for t in range(held_frames):
            out.write(img)

out.release()

Repetition: 0
Writing disparity = 0
(376, 541, 3)
Writing disparity = 1
(376, 541, 3)
Writing disparity = 2
(376, 541, 3)
Writing disparity = 3
(376, 541, 3)
Writing disparity = 4
(376, 541, 3)
Writing disparity = 5
(376, 541, 3)
Writing disparity = 6
(376, 541, 3)
Writing disparity = 7
(376, 541, 3)
Writing disparity = 8
(376, 541, 3)
Writing disparity = 9
(376, 541, 3)
Writing disparity = 10
(376, 541, 3)
Writing disparity = 11
(376, 541, 3)
Writing disparity = 12
(376, 541, 3)
Writing disparity = 13
(376, 541, 3)
Writing disparity = 14
(376, 541, 3)
Writing disparity = 15
(376, 541, 3)
Writing disparity = 16
(376, 541, 3)
Writing disparity = 17
(376, 541, 3)
Writing disparity = 18
(376, 541, 3)
Writing disparity = 19
(376, 541, 3)
Writing disparity = 20
(376, 541, 3)
Writing disparity = 21
(376, 541, 3)
Writing disparity = 22
(376, 541, 3)
Writing disparity = 23
(376, 541, 3)
Writing disparity = 24
(376, 541, 3)
Writing disparity = 25
(376, 541, 3)
Writing disparity = 26
(376, 541, 

(376, 541, 3)
Writing disparity = 27
(376, 541, 3)
Writing disparity = 28
(376, 541, 3)
Writing disparity = 29
(376, 541, 3)
Writing disparity = 30
(376, 541, 3)
Writing disparity = 31
(376, 541, 3)
Writing disparity = 32
(376, 541, 3)
Writing disparity = 33
(376, 541, 3)
Writing disparity = 34
(376, 541, 3)
Writing disparity = 35
(376, 541, 3)
Writing disparity = 36
(376, 541, 3)
Writing disparity = 37
(376, 541, 3)
Writing disparity = 38
(376, 541, 3)
Writing disparity = 39
(376, 541, 3)
Writing disparity = 40
(376, 541, 3)
Writing disparity = 41
(376, 541, 3)
Writing disparity = 42
(376, 541, 3)
Writing disparity = 43
(376, 541, 3)
Writing disparity = 44
(376, 541, 3)
Writing disparity = 45
(376, 541, 3)
Writing disparity = 46
(376, 541, 3)
Writing disparity = 47
(376, 541, 3)
Writing disparity = 48
(376, 541, 3)
Writing disparity = 49
(376, 541, 3)
Writing disparity = 50
(376, 541, 3)
Writing disparity = 51
(376, 541, 3)
Writing disparity = 52
(376, 541, 3)
Writing disparity = 53
(

<h3>Create Extrapolated Views Video</h3>

In [9]:
# extrapolate model on a test image
def eval_ext_image(im_path, net, q):
    net.eval()
    with torch.no_grad():
        # LOAD THE IMAGE AND TURN IT FROM LIGHTFIELD TO SEPARATED IMAGES
        img_path = im_path
        image = plt.imread(img_path)#torch.permute(read_image(img_path), (1, 2, 0)).float()
        numImgsX = 14
        numImgsY = 14
        h = int(image.shape[0] / numImgsY)
        w = int(image.shape[1] / numImgsX)
        fullLF = np.zeros((h, w, 3, numImgsX, numImgsY))
        for ax in range(numImgsX):
            for ay in range(numImgsY):
                #print(im1[ay::numImgsY, ax::numImgsX, :].shape)
                fullLF[:,:,:, ax, ay] = image[ay::numImgsY, ax::numImgsX, :3]
        if h == 375 and w == 540:
            np.pad(fullLF, ((0, 1), (0, 1), (0, 0), (0, 0), (0, 0)))
        if h == 375 and w == 541:
            np.pad(fullLF, ((0, 1), (0, 0), (0, 0), (0, 0), (0, 0)))
        fullLF = fullLF[:, :, :, 4:12, 4:12] # 8 middle images
        inr = np.array(((0, 0), (7, 7)))
        inc = np.array(((0, 7), (0, 7)))
        inLF = fullLF[:, :, :, inc, inr] # 4 corner images
#         print(inLF.shape)

        # CONVERT FULL AND INPUT LIGHTFIELDS TO TENSORS
        fullTens = torch.from_numpy(fullLF.copy()).float() # should be h x w x 3 x 8 x 8
        inTens = torch.from_numpy(inLF.copy()).float() # h x w x 3 x 2 x 2

        # CREATE THE INPUT, FEATURE, GROUND TRUTH PATCHES, AND REFERENCE VIEW FOR THE NETWORK
        # create input
        inP = torch.reshape(inTens, (h, w, 3, -1)) # h x w x 3 x 4
        in_eval = torch.permute(inP.unsqueeze(0), (0, 3, 4, 1, 2)) # 1 x h x w x 3 x 4 -> 1 x 3 x 4 x h x w for the network

        # load q into tensor
        q_eval = torch.tensor(q).long()

        # create feature patches
        disps = torch.linspace(disparity_range[0], disparity_range[1], 100)
        inTens_gray = 0.299*inTens[:, :, 0, :, :] + 0.587*inTens[:, :, 1, :, :] + 0.114*inTens[:, :, 2, :, :] #gray
        warped = back_warp(torch.reshape(inTens_gray, (h, w, 1, -1)), q_eval, disps) # h x w x 3(1) x 4 x disparity_levels
        warped = torch.squeeze(warped) # remove singleton dimension... gives h x w x 4 x disparity_levels
        mean_feat = torch.mean(warped, dim=2)
        stdev_feat = torch.sqrt((1.0/3.0)*torch.sum(torch.pow(warped - mean_feat.unsqueeze(2), 2), dim=2))
        featP = torch.cat((mean_feat, stdev_feat), dim=2) # h x w x 200
        feat_eval = torch.permute(featP.unsqueeze(0), (0, 3, 1, 2)) # 1 x h x w x 200 -> 1 x 200 x h x w for the network

        # create gt tensor
#         gt = fullTens[12:-12, 12:-12, :, q_eval[0], q_eval[1]]
        gt = "sorry no gt"
        
        # evaluate the network
        out_tens = net.forward(in_eval.to(device), feat_eval.to(device), q_eval.to(device))
        
        out_disp = net.d(feat_eval.to(device))
        
    # return tuple of the input, the features, the gt, and the final output
    return in_eval, feat_eval, gt, out_tens, out_disp

def draw_ext_uv_on_frame(frame, u, v):
    # returns another frame that has the u, v representation drawn-on... 4 corner boxes are color, selected box is sel_color
    # frame is h x w x 3
    new_frame = frame
    offset = (18, 18) # row/col offset
    box_shape = (6, 6) # row/col box shape
    # draw corners
    color = np.array([1.0, 0.5, 0])
    new_frame[offset[0]+0*box_shape[0]:offset[0]+1*box_shape[0], offset[1]+0*box_shape[1]:offset[1]+1*box_shape[1], :] = color # tl
    new_frame[offset[0]+0*box_shape[0]:offset[0]+1*box_shape[0], offset[1]+7*box_shape[1]:offset[1]+8*box_shape[1], :] = color # tr
    new_frame[offset[0]+7*box_shape[0]:offset[0]+8*box_shape[0], offset[1]+0*box_shape[1]:offset[1]+1*box_shape[1], :] = color # bl
    new_frame[offset[0]+7*box_shape[0]:offset[0]+8*box_shape[0], offset[1]+7*box_shape[1]:offset[1]+8*box_shape[1], :] = color # br
    # draw selected box
    sel_color = np.array([0.8, 0.8, 0.8])
    new_frame[offset[0]+v*box_shape[0]:offset[0]+(v+1)*box_shape[0], offset[1]+u*box_shape[1]:offset[1]+(u+1)*box_shape[1], :] = sel_color
    return new_frame

load_from_file = True
im_filepath = "../../../../../datasets/cs274-fa21-A00-public/TestSet/PAPER/Flower1.png"
net_filepath = "./best3_loss_18.pt"
net = DispColorNet()

if load_from_file:
    checkpoint = torch.load(net_filepath, map_location='cpu')
    net.load_state_dict(checkpoint['model_state_dict'])
#     optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    e0 = checkpoint['epoch']

net.to(torch.device(device))

frameSize = (517, 352)
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
out = cv2.VideoWriter('./flextfinal.mp4', fourcc, float(30), frameSize)

repetitions = 1
held_frames = 2
for rep in range(repetitions):
    print("Repetition: %d" % rep)    
    us = list(range(-1, 9)) + [8]*8 + list(range(-1, 9)[::-1]) + [-1]*9
    vs = [-1]*9 + list(range(-1, 9)) + [8]*8 + list(range(-1, 9)[::-1])
    us += list(range(-2, 10)) + [9]*10 + list(range(-2, 10)[::-1]) + [-2]*11
    vs += [-2]*11 + list(range(-2, 10)) + [9]*10 + list(range(-2, 10)[::-1])
    for u, v in zip(us, vs):
        print("Writing view at (u, v) = %d, %d" % (u, v))
        in_eval, feat_eval, gt, output, out_disp = eval_ext_image(im_filepath, net, [u, v])
        curr_frame = torch.clamp(output[0, :, :, :], min=0.0, max=1.0).permute((1, 2, 0)).cpu().numpy()
        curr_frame = colorcorr(curr_frame)
        img = draw_ext_uv_on_frame(curr_frame, u, v)
        img = (img*255).astype(np.uint8)[:, :, ::-1]
        for t in range(held_frames):
            out.write(img)

out.release()

Repetition: 0
Writing view at (u, v) = -1, -1
Writing view at (u, v) = 0, -1
Writing view at (u, v) = 1, -1
Writing view at (u, v) = 2, -1
Writing view at (u, v) = 3, -1
Writing view at (u, v) = 4, -1
Writing view at (u, v) = 5, -1
Writing view at (u, v) = 6, -1
Writing view at (u, v) = 7, -1
Writing view at (u, v) = 8, -1
Writing view at (u, v) = 8, 0
Writing view at (u, v) = 8, 1
Writing view at (u, v) = 8, 2
Writing view at (u, v) = 8, 3
Writing view at (u, v) = 8, 4
Writing view at (u, v) = 8, 5
Writing view at (u, v) = 8, 6
Writing view at (u, v) = 8, 7
Writing view at (u, v) = 8, 8
Writing view at (u, v) = 7, 8
Writing view at (u, v) = 6, 8
Writing view at (u, v) = 5, 8
Writing view at (u, v) = 4, 8
Writing view at (u, v) = 3, 8
Writing view at (u, v) = 2, 8
Writing view at (u, v) = 1, 8
Writing view at (u, v) = 0, 8
Writing view at (u, v) = -1, 8
Writing view at (u, v) = -1, 7
Writing view at (u, v) = -1, 6
Writing view at (u, v) = -1, 5
Writing view at (u, v) = -1, 4
Writing vi

<h3>Create Phone Views Video</h3>

In [5]:
# extrapolate model on a test image
def eval_phone_image(ims, net, q):
    net.eval()
    with torch.no_grad():
        # LOAD THE IMAGE AND TURN IT FROM LIGHTFIELD TO SEPARATED IMAGES
        image = plt.imread(ims[0])[::8, ::8, :]
        h = int(image.shape[0])
        w = int(image.shape[1])
        inLF = np.zeros((h, w, 3, 2, 2))
        for i in range(4):
            inLF[:,:,:, i % 2, i // 2] = np.power(plt.imread(ims[i])[::8, ::8, :]/255.0, 1.5)

        # CONVERT FULL AND INPUT LIGHTFIELDS TO TENSORS
        inTens = torch.from_numpy(inLF.copy()).float() # h x w x 3 x 2 x 2

        # CREATE THE INPUT, FEATURE, GROUND TRUTH PATCHES, AND REFERENCE VIEW FOR THE NETWORK
        # create input
        inP = torch.reshape(inTens, (h, w, 3, -1)) # h x w x 3 x 4
        in_eval = torch.permute(inP.unsqueeze(0), (0, 3, 4, 1, 2)) # 1 x h x w x 3 x 4 -> 1 x 3 x 4 x h x w for the network

        # load q into tensor
        q_eval = torch.tensor(q).long()

        # create feature patches
        disps = torch.linspace(disparity_range[0], disparity_range[1], 100)
        inTens_gray = 0.299*inTens[:, :, 0, :, :] + 0.587*inTens[:, :, 1, :, :] + 0.114*inTens[:, :, 2, :, :] #gray
        warped = back_warp(torch.reshape(inTens_gray, (h, w, 1, -1)), q_eval, disps) # h x w x 3(1) x 4 x disparity_levels
        warped = torch.squeeze(warped) # remove singleton dimension... gives h x w x 4 x disparity_levels
        mean_feat = torch.mean(warped, dim=2)
        stdev_feat = torch.sqrt((1.0/3.0)*torch.sum(torch.pow(warped - mean_feat.unsqueeze(2), 2), dim=2))
        featP = torch.cat((mean_feat, stdev_feat), dim=2) # h x w x 200
        feat_eval = torch.permute(featP.unsqueeze(0), (0, 3, 1, 2)) # 1 x h x w x 200 -> 1 x 200 x h x w for the network

        # create gt tensor
#         gt = fullTens[12:-12, 12:-12, :, q_eval[0], q_eval[1]]
        gt = "sorry no gt"
        
        # evaluate the network
        out_tens = net.forward(in_eval.to(device), feat_eval.to(device), q_eval.to(device))
        
        out_disp = net.d(feat_eval.to(device))
        
    # return tuple of the input, the features, the gt, and the final output
    return in_eval, feat_eval, gt, out_tens, out_disp

def phone_uv_on_frame(frame, u, v):
    # returns another frame that has the u, v representation drawn-on... 4 corner boxes are color, selected box is sel_color
    # frame is h x w x 3
    new_frame = frame
    offset = (6, 6) # row/col offset
    box_shape = (6, 6) # row/col box shape
    # draw corners
    color = np.array([1.0, 0.5, 0])
    new_frame[offset[0]+0*box_shape[0]:offset[0]+1*box_shape[0], offset[1]+0*box_shape[1]:offset[1]+1*box_shape[1], :] = color # tl
    new_frame[offset[0]+0*box_shape[0]:offset[0]+1*box_shape[0], offset[1]+7*box_shape[1]:offset[1]+8*box_shape[1], :] = color # tr
    new_frame[offset[0]+7*box_shape[0]:offset[0]+8*box_shape[0], offset[1]+0*box_shape[1]:offset[1]+1*box_shape[1], :] = color # bl
    new_frame[offset[0]+7*box_shape[0]:offset[0]+8*box_shape[0], offset[1]+7*box_shape[1]:offset[1]+8*box_shape[1], :] = color # br
    # draw selected box
    sel_color = np.array([0.8, 0.8, 0.8])
    new_frame[offset[0]+v*box_shape[0]:offset[0]+(v+1)*box_shape[0], offset[1]+u*box_shape[1]:offset[1]+(u+1)*box_shape[1], :] = sel_color
    return new_frame
    

load_from_file = True
im_filepath = ["./im0.JPG", "./im1.JPG", "./im2.JPG", "./im3.JPG"]
net_filepath = "./best3_loss_18.pt"
net = DispColorNet()

if load_from_file:
    checkpoint = torch.load(net_filepath, map_location='cpu')
    net.load_state_dict(checkpoint['model_state_dict'])
#     optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    e0 = checkpoint['epoch']

net.to(torch.device(device))

h, w, _ = eval_phone_image(im_filepath, net, [3, 3])[3][0,:,:,:].permute((1, 2, 0)).shape

frameSize = (w, h)
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
out = cv2.VideoWriter('./blphonefinal.mp4', fourcc, float(30), frameSize)
repetitions = 1
held_frames = 3

mode = 'circle' # horizontal, vertical, or circle

if mode == 'horizontal':
    for rep in range(repetitions):
        print("Repetition: %d" % rep)
        for v in range(8):
            for u in range(8)[::1 if v % 2 == 0 else -1]:
                print("Writing view at (u, v) = %d, %d" % (u, v))
                in_eval, feat_eval, gt, output, out_disp = eval_phone_image(im_filepath, net, [u, v])
                curr_frame = torch.clamp(output[0, :, :, :], min=0.0, max=1.0).permute((1, 2, 0)).cpu().numpy()
                curr_frame = colorcorr(curr_frame)
                img = phone_uv_on_frame(curr_frame, u, v)
                img = (img*255).astype(np.uint8)[:, :, ::-1]
                for t in range(held_frames):
                    out.write(img)
elif mode == 'vertical':
    for rep in range(repetitions):
        print("Repetition: %d" % rep)
        for u in range(8):
            for v in range(8)[::1 if u % 2 == 0 else -1]:
                print("Writing view at (u, v) = %d, %d" % (u, v))
                in_eval, feat_eval, gt, output, out_disp = eval_phone_image(im_filepath, net, [u, v])
                curr_frame = torch.clamp(output[0, :, :, :], min=0.0, max=1.0).permute((1, 2, 0)).cpu().numpy()
                curr_frame = colorcorr(curr_frame)
                img = phone_uv_on_frame(curr_frame, u, v)
                img = (img*255).astype(np.uint8)[:, :, ::-1]
                for t in range(held_frames):
                    out.write(img)
elif mode == 'circle':
    us = list(range(8)) + [7]*6 + list(range(8)[::-1]) + [0]*6
    vs = [0]*7 + list(range(8)) + [7]*6 + list(range(8)[::-1])[:-1]
    for u, v in zip(us, vs):
        print("Writing view at (u, v) = %d, %d" % (u, v))
        in_eval, feat_eval, gt, output, out_disp = eval_phone_image(im_filepath, net, [u, v])
        curr_frame = torch.clamp(output[0, :, :, :], min=0.0, max=1.0).permute((1, 2, 0)).cpu().numpy()
        curr_frame = colorcorr(curr_frame)
        img = phone_uv_on_frame(curr_frame, u, v)
        img = (img*255).astype(np.uint8)[:, :, ::-1]
        for t in range(held_frames):
            out.write(img)
    
out.release()

Writing view at (u, v) = 0, 0
Writing view at (u, v) = 1, 0
Writing view at (u, v) = 2, 0
Writing view at (u, v) = 3, 0
Writing view at (u, v) = 4, 0
Writing view at (u, v) = 5, 0
Writing view at (u, v) = 6, 0
Writing view at (u, v) = 7, 0
Writing view at (u, v) = 7, 1
Writing view at (u, v) = 7, 2
Writing view at (u, v) = 7, 3
Writing view at (u, v) = 7, 4
Writing view at (u, v) = 7, 5
Writing view at (u, v) = 7, 6
Writing view at (u, v) = 7, 7
Writing view at (u, v) = 6, 7
Writing view at (u, v) = 5, 7
Writing view at (u, v) = 4, 7
Writing view at (u, v) = 3, 7
Writing view at (u, v) = 2, 7
Writing view at (u, v) = 1, 7
Writing view at (u, v) = 0, 7
Writing view at (u, v) = 0, 6
Writing view at (u, v) = 0, 5
Writing view at (u, v) = 0, 4
Writing view at (u, v) = 0, 3
Writing view at (u, v) = 0, 2
Writing view at (u, v) = 0, 1
