In [1]:
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

import torch
import torch.nn as nn
import torch.nn.functional as F
import pickle
import cv2
import os

from dep2def import depth2defocus
from functools import partial

class pixel_estimator_with_weights(nn.Module):
    def __init__(self, Weights,device = "cuda:0"):
        ## Default: gpu mode
        super(pixel_estimator_with_weights, self).__init__()
        self.device = torch.device(device)
        self.w1 = torch.from_numpy(Weights[0].transpose(3,2,0,1)).to(self.device)
        self.b1 = torch.from_numpy(Weights[1]).to(self.device)
        self.w2 = torch.tensor(Weights[2].transpose(3,2,0,1)).to(self.device)
        self.b2 = torch.tensor(Weights[3]).to(self.device)
        self.w3 = torch.tensor(Weights[4].transpose(3,2,0,1)).to(self.device)
        self.b3 = torch.tensor(Weights[5]).to(self.device)
        self.w4 = torch.tensor(Weights[6]).reshape(4,4,8,1024).permute(3,2,0,1).to(self.device)
        self.b4 = torch.tensor(Weights[7]).to(self.device)
        self.w5 = torch.tensor(Weights[8]).reshape(1,1,1024,512).permute(3,2,0,1).to(self.device)
        self.b5 = torch.tensor(Weights[9]).to(self.device)
        self.w6 = torch.tensor(Weights[10]).reshape(1,1,512,10).permute(3,2,0,1).to(self.device)
        self.b6 = torch.tensor(Weights[11]).to(self.device)
        self.w7 = torch.tensor(Weights[12]).reshape(1,1,10,1).permute(3,2,0,1).to(self.device)
        self.b7 = torch.tensor(Weights[13]).to(self.device)

    def forward(self, x):
        x = F.relu(F.conv2d(x,self.w1,bias = self.b1,stride=1))
        x = F.relu(F.conv2d(x,self.w2,bias = self.b2,stride=1,dilation=8))
        x = F.relu(F.conv2d(x,self.w3,bias = self.b3,stride=1,dilation=32))
        x = F.leaky_relu(F.conv2d(x,self.w4,bias = self.b4,stride=1,dilation=128),0.1)
        x = F.leaky_relu(F.conv2d(x,self.w5,bias = self.b5,stride=1),0.1)
        x = F.leaky_relu(F.conv2d(x,self.w6,bias = self.b6,stride=1),0.1)
        x = F.conv2d(x,self.w7,bias = self.b7,stride=1)
        return x
    
    
def crop_patches(img, window= 1023, step = 512):
    patches = []
    H, W = img.shape
    for i in range(0, H-step, step):
        for j in range(0, W-step, step):
            patches.append(img[i:i+window, j:j+window])
    return np.stack(patches)


def gaf_func(img):
    assert img.max() <= 1.0
    if img.shape != (2160, 3840):
        img = cv2.resize(img, (3840, 2160))
    img = np.pad(img, ((200, 200), (128, 128)), 'reflect')
    H, W = img.shape
    
    patches = crop_patches(img)
    patches = torch.from_numpy(patches).float().unsqueeze(1).cuda()
        
    results = []
    with torch.no_grad():
        for i in range(patches.size()[0]):
            results.append(model(patches[i:i+1]))
    results = torch.stack(results)

    results = results.cpu().numpy()
    results = results.squeeze()
    
    k = 0
    sigma =1
    n_img = np.zeros((H-512, W-512))
    for i in range(0, H-512, 512):
        for j in range(0, W-512, 512):
            n_img[i:i+512, j:j+512] = results[k]
            k += 1

    n_img = np.clip(n_img, 0, 8)
    return n_img

In [8]:
path = '/home/qian/Downloads/DAVIS/viz_predictions/'

files = []
# r=root, d=directories, f = files
for r, d, f in os.walk(path):

    for file in f:
        if '.npy' in file:
            files.append(os.path.join(r, file))

for f in sorted(files):
    print(f)
#     print(f.replace('/viz_predictions/', '/JPEGImages/Full-Resolution/').replace('.npy', '.jpg'))

/home/qian/Downloads/DAVIS/viz_predictions/ ['mallard-water-', 'drift-turn', 'camel', 'bike-packing', 'horsejump-low', 'blackswan', 'rhino', 'mallard-fly-', 'dance-twirl', 'breakdance-flare', 'surf-', 'color-run-', 'swing', 'pigs', 'scooter-board-', 'car-roundabout', 'elephant', 'stroller-', 'varanus-cage', 'india-', 'drift-chicane-', '.ipynb_checkpoints', 'upside-down-', 'rallye', 'car-turn', 'breakdance', 'kite-surf-', 'tennis', 'lindy-hop', 'dog-', 'flamingo', 'dogs-scale', 'paragliding-', 'car-shadow', 'night-race-', 'dog-agility-', 'boxing-fisheye-', 'koala-', 'cat-girl-', 'planes-water-', 'hike', 'soapbox', 'hockey', 'boat', 'bmx-trees-', 'tuk-tuk', 'judo-', 'train', 'tractor-sand', 'disc-jockey-', 'drift-straight-', 'bmx-bumps-', 'kid-football', 'drone-', 'dogs-jump', 'rollerblade-', 'dance-jump', 'dog-gooses', 'lucia', 'scooter-black-', 'shooting-', 'libby-', 'goat', 'scooter-gray-', 'kite-walk-', 'classic-car', 'crossing', 'loading-', 'paragliding-launch', 'lady-running-', 'so

/home/qian/Downloads/DAVIS/viz_predictions/night-race-/00007.npy
/home/qian/Downloads/DAVIS/viz_predictions/night-race-/00008.npy
/home/qian/Downloads/DAVIS/viz_predictions/night-race-/00009.npy
/home/qian/Downloads/DAVIS/viz_predictions/night-race-/00010.npy
/home/qian/Downloads/DAVIS/viz_predictions/night-race-/00011.npy
/home/qian/Downloads/DAVIS/viz_predictions/night-race-/00012.npy
/home/qian/Downloads/DAVIS/viz_predictions/night-race-/00013.npy
/home/qian/Downloads/DAVIS/viz_predictions/night-race-/00014.npy
/home/qian/Downloads/DAVIS/viz_predictions/night-race-/00015.npy
/home/qian/Downloads/DAVIS/viz_predictions/night-race-/00016.npy
/home/qian/Downloads/DAVIS/viz_predictions/night-race-/00017.npy
/home/qian/Downloads/DAVIS/viz_predictions/night-race-/00018.npy
/home/qian/Downloads/DAVIS/viz_predictions/night-race-/00019.npy
/home/qian/Downloads/DAVIS/viz_predictions/night-race-/00020.npy
/home/qian/Downloads/DAVIS/viz_predictions/night-race-/00021.npy
/home/qian/Downloads/DAVI

In [None]:
width = 1080 # img.shape[1]
f = 25
fn = 2
FoV_h = 10 * np.pi / 180
pp = 2 * f * np.tan(FoV_h / 2) / width  # pixel pitch in mm
gamma = 2.4
# use partial is recommended to set lens parameter
myd2d = partial(depth2defocus, f=f, fn=fn, pp=pp, r_step=1, inpaint_occlusion=False)  # this would fix f, fn, pp, and r_step

In [3]:
if __name__=='__main__':  
    model = torch.load('autofocus.pth')
    model.eval()
    
    with torch.no_grad():
        for idx, img_path in enumerate(sorted(files)):
            print(idx)
#             print(img_path)
            img = cv2.imread(img_path)#/255.0
            dpt = np.load(dpt_path)
            focus = np.random.rand() * (dpt.max() - dpt.min()) + dpt.min()
            img = myd2d(img, dpt, focus, inpaint_occlusion=False)
            gray_img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)/255.0
            gaf = gaf_func(gray_img)
            af_path = img_path.replace('/JPEGImages/', '/viz_z2/')
#             print(af_path)
            last_slash_idx = af_path.rfind('/')
            directory = af_path[:last_slash_idx]
#             print(directory)
            if not os.path.exists(directory):
                os.makedirs(directory)
            cv2.imwrite(af_path, (gaf /8.0 * 255.0).astype(np.uint8))

0
1
2
3
4
5
6


KeyboardInterrupt: 