In [1]:
import numpy as np
import scipy as sp
import matplotlib.pyplot as plt
from scipy import ndimage
from scipy import signal
import scipy.sparse.linalg
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.functional import conv2d
import functools
from matplotlib.widgets import Slider, Button, RadioButtons
from ipywidgets import interact, interactive, fixed, interact_manual
import ipywidgets as widgets
from functools import partial
%load_ext autoreload
%load_ext line_profiler
%autoreload 2

In [2]:
from crf.gaussian_matrix import LatticeGaussian, LatticeFilter#, LSHGaussian

In [3]:
from crf.utils import read_image, read_pfm, read_pgm
from crf.features import Vgg16features
from crf.crf import *
from crf.depth import *

In [4]:
# Stereo pair (left / right views) and the ground-truth disparity map.
# NOTE(review): the .pgm ground truth may be stored at a scaled disparity
# resolution; check its units before comparing it with predictions.
img1, img2 = (read_image(path) for path in ('imL.png', 'imR.png'))
gt_depth = read_pgm('truedisp.row3.col3.pgm')
# Stay on the CPU: the denseCRF module below pins its own device to CPU too.
device = torch.device('cpu')

In [5]:
# Which VGG16 feature level (0-based) feeds the CRF's reference features.
q = 0

# Pretrained VGG16 feature extractor, switched to inference mode.
VGG = Vgg16features()
VGG.to(device)
VGG.eval()

# NOTE(review): get_features appears to return an indexable collection of numpy
# feature maps (indexed by q below). The learned projection down to `d`
# channels happens later in ReferenceMatrix, not here.
n_levels_requested = q + 1
features = VGG.get_features(img1, k=n_levels_requested)
selected_feature_map = features[q]
torch_features = torch.from_numpy(selected_feature_map).to(device).detach()

In [6]:
np.shape(img1)  # (height, width, channels) of the left image

(288, 384, 3)

In [7]:
d = 5  # default size of the learned CNN-feature projection


class ReferenceMatrix(nn.Module):
    """Build the per-pixel reference matrix consumed by the lattice filter.

    Each pixel is described by its (row, col) position, its colour channels and a
    learned linear projection of CNN features. Positions and colours are divided
    by the learnable bandwidths ``sigma_p`` and ``sigma_c`` respectively.

    Args:
        in_features: channel count of ``nn_features`` (its last dimension).
        out_features: size of the learned projection. ``None`` (the default)
            means "use the notebook-level ``d`` at construction time".
        feature_scale: constant the projected features are divided by.
    """

    def __init__(self, in_features=64, out_features=None, feature_scale=10):
        super().__init__()
        if out_features is None:
            out_features = d
        self.projection = nn.Linear(in_features, out_features)
        self.feature_scale = feature_scale
        self.sigma_c = nn.Parameter(torch.tensor(.1).float())
        self.sigma_p = nn.Parameter(torch.tensor(.1).float())

    def forward(self, img, nn_features):
        """Return the (h*w, 2 + c + out_features) reference matrix.

        Args:
            img: numpy array of shape (h, w, c); colour values are divided by sigma_c.
            nn_features: tensor of shape (h, w, in_features). It must share h and w
                with ``img`` so the channel concatenation lines up.

        Returns:
            Tensor whose rows are pixels in row-major order and whose columns are
            [scaled (i, j) | scaled colour channels | projected features].
        """
        projected_features = self.projection(nn_features) / self.feature_scale
        # Create the non-learned channels on the projection's device so that
        # torch.cat never mixes CPU and GPU tensors.
        target_device = projected_features.device
        h, w = img.shape[:2]
        scaled_rgb = torch.from_numpy(img).float().to(target_device) / self.sigma_c
        # Pixel coordinates divided by the image diagonal, so they lie in [0, 1).
        ij = np.mgrid[:h, :w].transpose((1, 2, 0)) / np.sqrt(h**2 + w**2)
        scaled_ij = torch.from_numpy(ij).float().to(target_device) / self.sigma_p
        stacked = torch.cat([scaled_ij, scaled_rgb, projected_features], dim=-1)
        # -1 lets the colour-channel count vary (e.g. grayscale) instead of
        # hard-coding 2 positional + 3 colour columns.
        return stacked.reshape(h * w, -1)
    
class denseCRF(nn.Module):
    """Dense CRF over disparity labels with learnable reference features and weights.

    Args:
        n_iters: number of mean-field iterations run in ``forward``.
        num_classes: number of labels; labels are the floats 0 .. num_classes-1.
            Must match the label dimension of the unary energies given to ``forward``.
        gamma: ``gamma`` argument forwarded to ``charbonneir`` when building the
            label-compatibility matrix.
    """

    def __init__(self, n_iters=5, num_classes=48, gamma=3):
        super().__init__()
        self.device = torch.device('cpu')
        self.n_iters = n_iters
        self.reference = ReferenceMatrix()
        self.labels = torch.arange(num_classes).float().to(self.device)
        # Label-compatibility matrix. It is a plain attribute, not a Parameter, so
        # it is fixed; its overall strength is learned through w1 in forward().
        self.Mu = compatibility_matrix(partial(charbonneir, gamma=gamma), self.labels)
        self.w1 = nn.Parameter(torch.tensor(1).float())
        self.E_weight = nn.Parameter(torch.tensor(1).float())

    def forward(self, E, img, nn_features):
        """Run mean-field inference and return ``Q_out @ labels`` per pixel.

        Args:
            E: unary energies. In this notebook they are (h*w, num_labels), built
               by reshape(-1, L) — TODO confirm for other callers.
            img, nn_features: forwarded to ReferenceMatrix to build the filter's
               reference features.

        Returns:
            The label-weighted sum of the mean-field output (the expected label,
            assuming Q_out holds per-pixel label marginals — confirm).
        """
        ref = self.reference(img, nn_features)
        W = LatticeGaussian(ref)
        # Use the unaries actually passed in. Previously this read the notebook
        # global E_0 and silently ignored E.
        Q_out = mean_field_infer(E * self.E_weight, W, self.Mu * self.w1, self.n_iters)
        expected_depths = Q_out @ self.labels.to(Q_out.device)
        return expected_depths

In [8]:
## Hyper parameters
# --- disparity sweep ---
ws = 9           # disparity aggregation window size
down_factor = 1  # spatial stride applied to the cost volume (1 = full resolution)

# --- CRF / filtering ---
gamma = 3        # Charbonnier turning point
sigma_c, sigma_p, sigma_f = .1, .1, 3  # filter stdevs: color, position, feature channels
n_iters = 10     # number of mean-field message-passing iterations
# NOTE(review): as written, gamma, sigma_* and n_iters are not passed to the
# denseCRF constructed below (it is built with n_iters=2 and its own initial
# sigmas); confirm whether that is intended.

In [9]:
# Unary potentials: matching cost for every candidate disparity, from a window sweep.
disp_energy = disparity_badness(img1, img2, ws, criterion=AD)
L = disp_energy.shape[-1]            # number of candidate disparities
disps = disp_energy.argmin(axis=-1)  # winner-take-all disparity per pixel

downsampled_out = disp_energy[::down_factor, ::down_factor]
h, w, _ = downsampled_out.shape
n = h * w  # number of pixels (CRF nodes)

# (n, L) unary energies and the per-pixel label distribution they imply on their own.
E_0 = torch.from_numpy(downsampled_out.reshape(-1, L)).to(device, torch.float32)
P_0 = torch.softmax(-E_0, dim=1)

In [10]:
# Baseline: per-pixel expected disparity under the unary-only distribution P_0.
labels = torch.arange(L).float()
baseline_depth = (P_0 @ labels.to(device)).reshape(h, w).cpu().numpy()

# Size the CRF's label set from the cost volume rather than relying on the
# constructor default (48), so its labels and compatibility matrix always match
# E_0's last dimension.
D = denseCRF(n_iters=2, num_classes=L).to(device)

In [11]:
import time

In [12]:
gt = torch.from_numpy(gt_depth.reshape(-1)).float()


def _depth_loss(crf_depth):
    """Mean squared error between rescaled CRF output and rescaled ground truth.

    Pixels where ``gt == 0`` are excluded. Presumably that marks unknown
    disparity in the ground-truth file — confirm.
    """
    # NOTE(review): 4*crf_depth vs gt/4 implies the ground truth is stored at
    # 16x the predicted disparity scale; confirm the intended units.
    diff = (4 * crf_depth - gt / 4)[gt != 0]
    return (diff ** 2).mean()


def testrun():
    """Run one forward + backward pass of the CRF and return the loss tensor."""
    loss = _depth_loss(D(E_0, img1, torch_features))
    loss.backward()
    return loss


def testrun_nograd():
    """Run one forward pass of the CRF without autograd and return the loss tensor."""
    with torch.no_grad():
        return _depth_loss(D(E_0, img1, torch_features))


# Benchmark: average wall-clock seconds per forward+backward pass, plus a summary
# of object allocations made during the runs. The run count gets its own name so
# it does not overwrite `n` (the pixel count) from the unary-potential cell.
n_runs = 5
from pympler.tracker import SummaryTracker
tracker = SummaryTracker()

t0 = time.perf_counter()
for _ in range(n_runs):
    testrun()
print((time.perf_counter() - t0) / n_runs)

tracker.print_diff()




[2.68100977e-01 2.17308283e-01 7.87857273e+01 5.84125519e-05
 7.58018255e-01]
[4.58715916e-01 2.16243982e-01 8.14373910e+01 4.50611115e-05
 1.38150525e+00]
[2.55412579e-01 2.16589689e-01 7.98322637e+01 3.29017639e-05
 7.51201630e-01]
[4.51885939e-01 2.12408543e-01 8.47038596e+01 4.50611115e-05
 1.38389325e+00]
[2.52475977e-01 2.17891693e-01 7.97238255e+01 3.43322754e-05
 8.37720156e-01]
[4.49995279e-01 2.13065863e-01 8.35526903e+01 4.50611115e-05
 1.38874626e+00]
[2.67519236e-01 2.18491316e-01 8.07413890e+01 3.21865082e-05
 7.50349760e-01]
[4.51190233e-01 2.12913036e-01 8.17247176e+01 4.52995300e-05
 1.39198279e+00]
[2.68355131e-01 2.16835499e-01 7.95785453e+01 3.17096710e-05
 7.53378630e-01]
[4.53297377e-01 2.15456247e-01 8.29399123e+01 4.38690186e-05
 1.38188171e+00]
174.14342527389528
                               types |   # objects |   total size
                        <class 'list |       21905 |      2.02 MB
                         <class 'str |       22933 |      1.64 MB
   

In [None]:
E_0.size()  # (num_pixels, num_disparities)