In [None]:
#%pylab inline
import numpy as np
import matplotlib.pyplot as plt
import torch, torch.nn

plt.rcParams['figure.figsize'] = [12, 6]

In [None]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(device_name)
# device = 'cpu'  # uncomment to force CPU execution
device

In [None]:
imsize = 384

# Seed PyTorch's random number generators so the synthetic noise below (and hence
# the figures and loss values derived from it) is repeatable when the notebook is
# re-run on the same device and PyTorch build.
torch.manual_seed(0)

# Base ramp of shape (imsize, imsize): every element of row r holds the value r.
disp = torch.tile(torch.arange(imsize, dtype=torch.float, device=device).view(-1, 1), (1, imsize))
# Two noisy observations derived from the ramp: disp1 keeps its direction, disp2 is
# the flipped ramp (500 - disp). Each gets independent uniform noise in [0, 100).
disp1 = disp + 100 * torch.rand(disp.shape, device=device)
disp2 = 500 - disp + 100 * torch.rand(disp.shape, device=device)

def norm_disp_midas(d, eps=1e-8):
    """Normalise a disparity map a la MiDaS (Eq. 5+6).

    The map is shifted by its median ``t`` and divided by the mean absolute
    deviation from that median ``s``, i.e. ``(d - t) / s``. Both statistics are
    printed as a side effect.

    Args:
        d: floating-point tensor holding the disparity map; the statistics are
            computed over all of its elements.
        eps: lower bound applied to ``s`` in the division only. It keeps a
            constant map (``s == 0``) from turning into NaNs (0 / 0); such a map
            is returned as all zeros instead. Maps with ``s >= eps`` are
            unaffected.

    Returns:
        Tensor of the same shape as ``d`` containing the normalised map.
    """
    t = torch.median(d)
    s = torch.abs(d - t).mean()
    # Report the raw (unclamped) statistics so a degenerate map is visible in the log.
    print(f"MiDaS normalisation: t={t:.3f}, s={s:.3f}")
    return (d - t) / torch.clamp(s, min=eps)

# Bring both noisy maps into the same shift/scale-normalised range
# (disp1 is processed first, then disp2, so the log lines appear in that order).
disp1, disp2 = (norm_disp_midas(noisy_map) for noisy_map in (disp1, disp2))

# Show the clean ramp next to the two normalised noisy maps, one panel per map.
panels = [("disp", disp), ("disp1", disp1), ("disp2", disp2)]
for column, (label, dmap) in enumerate(panels, start=1):
    plt.subplot(1, 3, column)
    plt.title(label)
    plt.imshow(dmap.cpu())  # matplotlib needs the data on the host
    plt.colorbar()

In [None]:
## Initialise scale + offset parameters
# A coarse gridsize x gridsize grid of per-cell scale and offset values is optimised
# so that `upscale(scale) * disp2 + upscale(offset)` matches disp1.
gridsize = 17
# Start from the identity transform: scale = 1, offset = 0.
scale  = torch.ones(1, 1, gridsize, gridsize, requires_grad=True, device=device)
# scale  = torch.rand(1, 1, gridsize, gridsize, requires_grad=True, device=device)  # random initialisation
offset = torch.zeros(1, 1, gridsize, gridsize, requires_grad=True, device=device)
# Bilinear interpolation of a (1, 1, gridsize, gridsize) grid up to the size of `disp`.
upscale = torch.nn.Upsample(size=disp.shape, mode='bilinear', align_corners=True)

# optimizer = torch.optim.SGD([scale, offset], lr=0.001)
optimizer = torch.optim.Adam([scale, offset], lr=.01)

for iteration in range(1000):
    # Forward pass: spatially varying affine correction of disp2.
    disp2_scaled = upscale(scale) * disp2 + upscale(offset)
    data_residual = disp1 - disp2_scaled
    data_term = (data_residual ** 2).mean()  # no robust function

    # Hedman & Kopf 2018, Eq. 6
    # Squared differences between horizontally (x) and vertically (y) adjacent grid cells.
    smoothness_term_x = ((scale[:, :, :, 1:] - scale[:, :, :, :-1]) ** 2).mean() + ((offset[:, :, :, 1:] - offset[:, :, :, :-1]) ** 2).mean()
    smoothness_term_y = ((scale[:, :, 1:, :] - scale[:, :, :-1, :]) ** 2).mean() + ((offset[:, :, 1:, :] - offset[:, :, :-1, :]) ** 2).mean()
    smoothness_term = 1e3 * (smoothness_term_x + smoothness_term_y)

    # Scale regulariser is disabled (zero); swap in one of the alternatives below to enable it.
    scale_term = torch.zeros(1, device=device).squeeze()
#     scale_term = 1e-4 * (1. / (scale + 1e-6)).mean()  # Hedman & Kopf 2018, Eq. 7
#     scale_term = 1e-4 * ((1 - scale) ** 2).mean()  # suggested squared error term

    loss = data_term + smoothness_term + scale_term

    if iteration % 100 == 0:
        # .item() extracts the Python float from each scalar tensor for formatting.
        print(f"{iteration:3d}. loss = {loss.item():.3f} = {data_term.item():.3f} + {smoothness_term.item():.3f} + {scale_term.item():.3f}")

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

# The values computed inside the loop predate the last optimizer step. Re-evaluate
# them with the final parameters (without building an autograd graph) so that later
# cells visualise the actual optimisation result.
with torch.no_grad():
    disp2_scaled = upscale(scale) * disp2 + upscale(offset)
    data_residual = disp1 - disp2_scaled

In [None]:
# Visualise scale + offset
# Both coarse grids are drawn with the same fixed colour range [-2, 2].
for column, (label, grid) in enumerate([("scale", scale), ("offset", offset)], start=1):
    plt.subplot(1, 2, column)
    plt.title(label)
    plt.imshow(grid[0, 0].detach().cpu(), vmin=-2, vmax=2)
    plt.colorbar()

In [None]:
# Compare the target map, the corrected disp2 and their difference side by side.
result_panels = [
    ("GT (disp1)", disp1.cpu()),
    ("disp2_scaled", disp2_scaled[0, 0].detach().cpu()),
    ("data_residual", data_residual[0, 0].detach().cpu()),
]
for column, (label, image) in enumerate(result_panels, start=1):
    plt.subplot(1, 3, column)
    plt.title(label)
    plt.imshow(image)
    plt.colorbar()