# Test loss

In [9]:
import torch
import numpy as np
from torch import nn

In [3]:
def robust_l1(x, q=0.5, eps=1e-2):
    """Return the mean of ``(x**2 + eps) ** q`` over every element of ``x``.

    Args:
        x: input tensor.
        q: exponent applied to the shifted squares.
        eps: constant added to the squares before exponentiation.

    Returns:
        A zero-dimensional tensor holding the mean penalty.
    """
    shifted_squares = x.pow(2) + eps
    return shifted_squares.pow(q).mean()

In [43]:
def robust_l1_per_pix(x, q=0.5, eps=1e-2):
    """Return ``(x**2 + eps) ** q`` element-wise, without any reduction.

    Args:
        x: input tensor.
        q: exponent applied to the shifted squares.
        eps: constant added to the squares before exponentiation.

    Returns:
        A tensor with the same shape as ``x``.
    """
    return (x.pow(2) + eps).pow(q)

In [10]:
# Author: Jonas Wulff

import torch
import torch.nn.functional as F
from torch.autograd import Variable
import numpy as np
from math import exp

def gaussian(window_size, sigma):
    """Return a 1-D Gaussian kernel of length ``window_size`` that sums to one.

    The kernel is centred on index ``window_size // 2`` and uses standard
    deviation ``sigma``.
    """
    center = window_size // 2
    two_sigma_sq = float(2 * sigma ** 2)
    samples = [exp(-(idx - center) ** 2 / two_sigma_sq) for idx in range(window_size)]
    kernel = torch.Tensor(samples)
    return kernel / kernel.sum()

def create_window(window_size, channel, sigma=1.5):
    """Build a 2-D Gaussian window shaped for a grouped (per-channel) convolution.

    Args:
        window_size: side length of the square window.
        channel: number of copies of the window, one per group.
        sigma: standard deviation passed to ``gaussian``; defaults to the
            previously hard-coded value of 1.5.

    Returns:
        A contiguous float tensor of shape
        ``(channel, 1, window_size, window_size)``.
    """
    kernel_1d = gaussian(window_size, sigma).unsqueeze(1)
    # The outer product of the 1-D kernel with itself gives the separable 2-D kernel.
    kernel_2d = kernel_1d.mm(kernel_1d.t()).float().unsqueeze(0).unsqueeze(0)
    # A freshly built tensor already has requires_grad=False, so the deprecated
    # Variable(..., requires_grad=False) wrapper is not needed.
    return kernel_2d.expand(channel, 1, window_size, window_size).contiguous()

def _ssim(img1, img2, window, window_size, channel, size_average = True):
    mu1 = F.conv2d(img1, window, padding = window_size//2, groups = channel)
    mu2 = F.conv2d(img2, window, padding = window_size//2, groups = channel)

    mu1_sq = mu1.pow(2)
    mu2_sq = mu2.pow(2)
    mu1_mu2 = mu1*mu2

    sigma1_sq = F.conv2d(img1*img1, window, padding = window_size//2, groups = channel) - mu1_sq
    sigma2_sq = F.conv2d(img2*img2, window, padding = window_size//2, groups = channel) - mu2_sq
    sigma12 = F.conv2d(img1*img2, window, padding = window_size//2, groups = channel) - mu1_mu2

    C1 = 0.01**2
    C2 = 0.03**2

    ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2))

    return ssim_map
    #if size_average:
    #    return ssim_map.mean()
    #else:
    #    return ssim_map.mean(1).mean(1).mean(1)

class SSIM(torch.nn.Module):
    """Module form of ``_ssim`` that keeps its Gaussian window between calls."""

    def __init__(self, window_size = 11, size_average = True):
        super(SSIM, self).__init__()
        self.window_size = window_size
        self.size_average = size_average
        # Start with a single-channel window; forward() rebuilds it when the
        # input's channel count or tensor type differs.
        self.channel = 1
        self.window = create_window(window_size, self.channel)

    def forward(self, img1, img2):
        """Return ``_ssim`` of the two images using the cached window when it fits."""
        _batch, n_channels, _rows, _cols = img1.size()

        cache_hit = n_channels == self.channel and self.window.data.type() == img1.data.type()
        if not cache_hit:
            rebuilt = create_window(self.window_size, n_channels)
            if img1.is_cuda:
                rebuilt = rebuilt.cuda(img1.get_device())
            rebuilt = rebuilt.type_as(img1)

            # Remember the new window so later calls with the same kind of
            # input can reuse it.
            self.window = rebuilt
            self.channel = n_channels

        return _ssim(img1, img2, self.window, self.window_size, n_channels, self.size_average)

def ssim(img1, img2, window_size = 13, size_average = True):
    """Functional SSIM: build a window matching ``img1`` and delegate to ``_ssim``.

    A new window is created on every call, moved to ``img1``'s CUDA device
    when ``img1`` is a CUDA tensor, and cast to ``img1``'s tensor type.
    """
    _batch, n_channels, _rows, _cols = img1.size()

    kernel = create_window(window_size, n_channels)
    kernel = kernel.cuda(img1.get_device()) if img1.is_cuda else kernel
    kernel = kernel.type_as(img1)

    return _ssim(img1, img2, kernel, window_size, n_channels, size_average)


In [23]:
# Random stand-in inputs for the loss experiments below.
tgt_img = torch.rand(4, 3, 256, 832)

# Two reference tensors with the same shape as the target.
ref_imgs = [torch.rand(4, 3, 256, 832) for _ in range(2)]

# Random mask tensor; the loss cells slice it along dim 1, one slice per reference.
occ_masks = torch.rand(4, 4, 640, 640)

In [27]:
# Variant 1: take the minimum over reference images of the SSIM term only,
# then add the mean of the photometric term.
# NOTE: this initial value is overwritten by the final assignment of the cell.
reconstruction_loss = 0
h, w = 640, 640

# Resize the target and every reference image to (h, w) by adaptive average pooling.
tgt_img_scaled = nn.functional.adaptive_avg_pool2d(tgt_img, (h, w))
ref_imgs_scaled = [nn.functional.adaptive_avg_pool2d(ref_img, (h, w)) for ref_img in ref_imgs]

weight = 0.5
wssim = 0.85

ssim_losses = []
photometric_losses = []

for i, ref_img in enumerate(ref_imgs_scaled):

    # 0 where every entry along dim 1 of ref_img is exactly 0, 1 elsewhere.
    valid_pixels = 1 - (ref_img == 0).prod(1, keepdim=True).type_as(ref_img)
    diff = (tgt_img_scaled - ref_img) * valid_pixels
    ssim_loss = 1 - ssim(tgt_img_scaled, ref_img) * valid_pixels
    # Ratio of total mask elements to valid ones; rescales the losses below.
    oob_normalization_const = valid_pixels.nelement()/valid_pixels.sum()

    # x == x is false only for NaN, so this is a NaN check on the constant.
    assert((oob_normalization_const == oob_normalization_const).item() == 1)

    # Weight both terms by (1 - occ_masks[:, i]), expanded to the term's shape.
    diff = diff *(1-occ_masks[:,i:i+1]).expand_as(diff)
    ssim_loss = ssim_loss*(1-occ_masks[:,i:i+1]).expand_as(ssim_loss)

    # reconstruction_loss +=  oob_normalization_const*((1- wssim)*robust_l1_per_pix(diff, q=qch) + weight*wssim*ssim_loss).min() + lambda_oob*robust_l1(1 - valid_pixels, q=qch)
    # The SSIM term is kept unreduced; robust_l1 reduces the photometric term to a scalar.
    ssim_losses.append(oob_normalization_const*weight*wssim*ssim_loss)
    photometric_losses.append(oob_normalization_const*(1 - wssim)*robust_l1(diff))
    # assert((reconstruction_loss == reconstruction_loss).item() == 1)
    #weight /= 2.83

# Stack along a new leading dim indexed by reference image.
ssim_losses = torch.stack(ssim_losses)
photometric_losses = torch.stack(photometric_losses)
# Element-wise minimum over references for the SSIM term, averaged, plus the mean photometric term.
reconstruction_loss = torch.min(ssim_losses,0)[0].mean() + torch.mean(photometric_losses)

In [62]:
# Variant 2: combine the SSIM, photometric and out-of-bounds terms per element,
# then take the minimum of the combined error over reference images.
# NOTE: this initial value is overwritten by the final assignment of the cell.
reconstruction_loss = 0
h, w = 640, 640

# Resize the target and every reference image to (h, w) by adaptive average pooling.
tgt_img_scaled = nn.functional.adaptive_avg_pool2d(tgt_img, (h, w))
ref_imgs_scaled = [nn.functional.adaptive_avg_pool2d(ref_img, (h, w)) for ref_img in ref_imgs]

weight = 0.5
wssim = 0.85

pe_losses = []

for i, ref_img in enumerate(ref_imgs_scaled):

    # 0 where every entry along dim 1 of ref_img is exactly 0, 1 elsewhere.
    valid_pixels = 1 - (ref_img == 0).prod(1, keepdim=True).type_as(ref_img)
    diff = (tgt_img_scaled - ref_img) * valid_pixels
    ssim_loss = 1 - ssim(tgt_img_scaled, ref_img) * valid_pixels
    # Ratio of total mask elements to valid ones; rescales the losses below.
    oob_normalization_const = valid_pixels.nelement()/valid_pixels.sum()

    # x == x is false only for NaN, so this is a NaN check on the constant.
    assert((oob_normalization_const == oob_normalization_const).item() == 1)

    # Weight both terms by (1 - occ_masks[:, i]), expanded to the term's shape.
    diff = diff *(1-occ_masks[:,i:i+1]).expand_as(diff)
    ssim_loss = ssim_loss*(1-occ_masks[:,i:i+1]).expand_as(ssim_loss)

    # reconstruction_loss +=  oob_normalization_const*((1- wssim)*robust_l1_per_pix(diff, q=qch) + weight*wssim*ssim_loss).min() + lambda_oob*robust_l1(1 - valid_pixels, q=qch)
    # NOTE(review): `qch` and `lambda_oob` are not assigned in any visible cell;
    # they presumably come from earlier notebook state (photometric_reconstruction_loss
    # declares defaults qch=0.5, lambda_oob=0) - confirm before re-running.
    pe = oob_normalization_const*weight*wssim*ssim_loss + oob_normalization_const*(1 - wssim)*robust_l1_per_pix(diff, q=qch) + lambda_oob*robust_l1_per_pix(1 - valid_pixels, q=qch) 
    # Average the combined error over dim 1.
    pe = pe.mean(1)
    pe_losses.append(pe)
    # assert((reconstruction_loss == reconstruction_loss).item() == 1)
    #weight /= 2.83

# Stack along a new leading dim indexed by reference image.
pe_losses = torch.stack(pe_losses)

# Element-wise minimum over the reference images.
pe_losses = pe_losses.min(0)[0]

reconstruction_loss = pe_losses.mean()


In [55]:
ssim_losses.shape

torch.Size([2, 4, 640, 640])

In [61]:
photometric_losses.min(0)[0].mean()

tensor(0.0193)

In [46]:
robust_l1_per_pix(diff).shape

torch.Size([4, 3, 640, 640])

In [50]:
ssim(tgt_img_scaled, ref_img).mean(1, True).shape

torch.Size([4, 1, 640, 640])

In [42]:
robust_l1(diff)

tensor(0.1600)

In [None]:
# Accumulate the blended photometric / SSIM terms for the current `diff` and `ssim_loss`.
# The right-hand side must sit on the same line as `+=`: a line break directly
# after the operator is a SyntaxError.
reconstruction_loss += (1 - wssim)*robust_l1(diff, q=qch) + wssim*ssim_loss.mean()

In [38]:
# Repeat the mean of tgt_img over dim 1 (keepdim=True) three times and
# concatenate along dim 1 to check the resulting shape.
quang = [tgt_img.mean(1,True),tgt_img.mean(1,True),tgt_img.mean(1,True)]
torch.cat(quang,1).shape

torch.Size([4, 3, 256, 832])

In [63]:
reconstruction_loss

tensor(0.1539)

In [51]:
reconstruction_loss

tensor(0.1528)

In [14]:
def photometric_reconstruction_loss(tgt_img, ref_imgs, intrinsics, intrinsics_inv, depth, explainability_mask, pose, rotation_mode='euler', padding_mode='zeros', lambda_oob=0, qch=0.5, wssim=0.5):
    """Sum a per-scale loss over paired depth maps and explainability masks.

    ``depth`` and ``explainability_mask`` may each be a single object or a
    list/tuple; a single object is wrapped in a one-element list. The two
    sequences are then iterated in lockstep with ``zip``, so extra entries
    in the longer one are ignored.

    NOTE(review): ``depth_occlusion_masks`` and ``one_scale`` are not defined
    in the visible part of this file, and ``tgt_img``, ``ref_imgs``,
    ``rotation_mode``, ``padding_mode``, ``lambda_oob``, ``qch`` and ``wssim``
    are not used in this body - presumably ``one_scale`` was meant to be a
    nested function closing over them; confirm.

    Returns:
        ``0`` when there is nothing to iterate over, otherwise the
        accumulated sum of ``one_scale`` results.
    """
    # isinstance (rather than an exact type() match) keeps list/tuple
    # subclasses from being wrapped a second time.
    if not isinstance(explainability_mask, (tuple, list)):
        explainability_mask = [explainability_mask]
    if not isinstance(depth, (list, tuple)):
        depth = [depth]

    loss = 0
    for d, mask in zip(depth, explainability_mask):
        occ_masks = depth_occlusion_masks(d, pose, intrinsics, intrinsics_inv)
        loss += one_scale(d, mask, occ_masks)
    return loss



# Test network

In [65]:
import torch
import torch.nn as nn


def downsample_conv(in_planes, out_planes, kernel_size=3):
    """Return a stride-2 convolution followed by a stride-1 convolution, each with ReLU.

    Both convolutions use ``kernel_size`` and a padding of
    ``(kernel_size - 1) // 2``.
    """
    pad = (kernel_size - 1) // 2
    strided = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=2, padding=pad)
    refine = nn.Conv2d(out_planes, out_planes, kernel_size=kernel_size, padding=pad)
    return nn.Sequential(strided, nn.ReLU(inplace=True), refine, nn.ReLU(inplace=True))


def predict_disp(in_planes):
    """Return a 3x3 convolution to a single channel followed by a sigmoid."""
    layers = [
        nn.Conv2d(in_planes, 1, kernel_size=3, padding=1),
        nn.Sigmoid(),
    ]
    return nn.Sequential(*layers)


def conv(in_planes, out_planes):
    """Return a 3x3 convolution with padding 1 followed by an in-place ReLU."""
    layers = [
        nn.Conv2d(in_planes, out_planes, kernel_size=3, padding=1),
        nn.ReLU(inplace=True),
    ]
    return nn.Sequential(*layers)


def upconv(in_planes, out_planes):
    """Return a stride-2 3x3 transposed convolution followed by an in-place ReLU.

    With ``padding=1`` and ``output_padding=1`` the spatial size is doubled.
    """
    deconv = nn.ConvTranspose2d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=2,
        padding=1,
        output_padding=1,
    )
    return nn.Sequential(deconv, nn.ReLU(inplace=True))


def crop_like(input, ref):
    """Crop ``input`` along dims 2 and 3 to the sizes of ``ref``.

    Asserts that ``input`` is at least as large as ``ref`` in both of those
    dimensions; the crop keeps the leading rows and columns.
    """
    rows_fit = input.size(2) >= ref.size(2)
    assert rows_fit and input.size(3) >= ref.size(3)
    max_rows, max_cols = ref.size(2), ref.size(3)
    return input[:, :, :max_rows, :max_cols]


class DispNetS6(nn.Module):
    """Encoder-decoder network that predicts single-channel maps at six scales.

    Seven ``downsample_conv`` stages encode the input; seven ``upconv`` stages
    decode it, each concatenated with the encoder output of matching size.
    Every prediction is ``alpha * predict_disp(features) + beta``.

    Args:
        alpha: multiplier applied to each prediction head's output.
        beta: offset added to each scaled prediction.
    """

    def __init__(self, alpha=10, beta=0.01):
        super(DispNetS6, self).__init__()

        self.alpha = alpha
        self.beta = beta

        # Encoder: each stage halves the spatial size (stride-2 first convolution).
        conv_planes = [32, 64, 128, 256, 512, 512, 512]
        self.conv1 = downsample_conv(3,              conv_planes[0], kernel_size=7)
        self.conv2 = downsample_conv(conv_planes[0], conv_planes[1], kernel_size=5)
        self.conv3 = downsample_conv(conv_planes[1], conv_planes[2])
        self.conv4 = downsample_conv(conv_planes[2], conv_planes[3])
        self.conv5 = downsample_conv(conv_planes[3], conv_planes[4])
        self.conv6 = downsample_conv(conv_planes[4], conv_planes[5])
        self.conv7 = downsample_conv(conv_planes[5], conv_planes[6])

        # Decoder: transposed convolutions that double the spatial size.
        upconv_planes = [512, 512, 256, 128, 64, 32, 16]
        self.upconv7 = upconv(conv_planes[6],   upconv_planes[0])
        self.upconv6 = upconv(upconv_planes[0], upconv_planes[1])
        self.upconv5 = upconv(upconv_planes[1], upconv_planes[2])
        self.upconv4 = upconv(upconv_planes[2], upconv_planes[3])
        self.upconv3 = upconv(upconv_planes[3], upconv_planes[4])
        self.upconv2 = upconv(upconv_planes[4], upconv_planes[5])
        self.upconv1 = upconv(upconv_planes[5], upconv_planes[6])

        # Fusion convolutions applied after each skip concatenation. The
        # "1 +" inputs also receive the upsampled prediction of the coarser scale.
        self.iconv7 = conv(upconv_planes[0] + conv_planes[5], upconv_planes[0])
        self.iconv6 = conv(upconv_planes[1] + conv_planes[4], upconv_planes[1])
        self.iconv5 = conv(upconv_planes[2] + conv_planes[3], upconv_planes[2])
        self.iconv4 = conv(upconv_planes[3] + conv_planes[2], upconv_planes[3])
        self.iconv3 = conv(1 + upconv_planes[4] + conv_planes[1], upconv_planes[4])
        self.iconv2 = conv(1 + upconv_planes[5] + conv_planes[0], upconv_planes[5])
        self.iconv1 = conv(1 + upconv_planes[6], upconv_planes[6])

        # One prediction head per output scale.
        self.predict_disp6 = predict_disp(upconv_planes[1])
        self.predict_disp5 = predict_disp(upconv_planes[2])
        self.predict_disp4 = predict_disp(upconv_planes[3])
        self.predict_disp3 = predict_disp(upconv_planes[4])
        self.predict_disp2 = predict_disp(upconv_planes[5])
        self.predict_disp1 = predict_disp(upconv_planes[6])

    def init_weights(self):
        """Xavier-uniform initialise every (transposed) convolution and zero its bias."""
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
                # xavier_uniform_ is the in-place replacement for the
                # deprecated nn.init.xavier_uniform.
                nn.init.xavier_uniform_(m.weight.data)
                if m.bias is not None:
                    m.bias.data.zero_()

    def _scaled_disp(self, head, features):
        """Apply a prediction head and map its output through ``alpha * y + beta``."""
        return self.alpha * head(features) + self.beta

    @staticmethod
    def _upsample_like(disp, ref):
        """Bilinearly upsample ``disp`` by 2 and crop it to ``ref``'s spatial size."""
        # interpolate replaces the deprecated nn.functional.upsample;
        # align_corners=False is the behaviour upsample falls back to when the
        # argument is omitted, so it is stated explicitly here.
        doubled = nn.functional.interpolate(disp, scale_factor=2, mode='bilinear', align_corners=False)
        return crop_like(doubled, ref)

    def forward(self, x):
        """Run the network on ``x``.

        Returns:
            In training mode the tuple ``(disp1, ..., disp6)``, ordered from
            the finest prediction (cropped to ``x``'s spatial size) to the
            coarsest; otherwise only ``disp1``.
        """
        out_conv1 = self.conv1(x)
        out_conv2 = self.conv2(out_conv1)
        out_conv3 = self.conv3(out_conv2)
        out_conv4 = self.conv4(out_conv3)
        out_conv5 = self.conv5(out_conv4)
        out_conv6 = self.conv6(out_conv5)
        out_conv7 = self.conv7(out_conv6)

        # crop_like trims the upsampled features to the skip tensor's size
        # before they are concatenated.
        out_upconv7 = crop_like(self.upconv7(out_conv7), out_conv6)
        concat7 = torch.cat((out_upconv7, out_conv6), 1)
        out_iconv7 = self.iconv7(concat7)

        out_upconv6 = crop_like(self.upconv6(out_iconv7), out_conv5)
        concat6 = torch.cat((out_upconv6, out_conv5), 1)
        out_iconv6 = self.iconv6(concat6)
        disp6 = self._scaled_disp(self.predict_disp6, out_iconv6)

        out_upconv5 = crop_like(self.upconv5(out_iconv6), out_conv4)
        concat5 = torch.cat((out_upconv5, out_conv4), 1)
        out_iconv5 = self.iconv5(concat5)
        disp5 = self._scaled_disp(self.predict_disp5, out_iconv5)

        out_upconv4 = crop_like(self.upconv4(out_iconv5), out_conv3)
        concat4 = torch.cat((out_upconv4, out_conv3), 1)
        out_iconv4 = self.iconv4(concat4)
        disp4 = self._scaled_disp(self.predict_disp4, out_iconv4)

        # From here on the coarser prediction is upsampled and fed to the next scale.
        out_upconv3 = crop_like(self.upconv3(out_iconv4), out_conv2)
        disp4_up = self._upsample_like(disp4, out_conv2)
        concat3 = torch.cat((out_upconv3, out_conv2, disp4_up), 1)
        out_iconv3 = self.iconv3(concat3)
        disp3 = self._scaled_disp(self.predict_disp3, out_iconv3)

        out_upconv2 = crop_like(self.upconv2(out_iconv3), out_conv1)
        disp3_up = self._upsample_like(disp3, out_conv1)
        concat2 = torch.cat((out_upconv2, out_conv1, disp3_up), 1)
        out_iconv2 = self.iconv2(concat2)
        disp2 = self._scaled_disp(self.predict_disp2, out_iconv2)

        # The finest scale has no encoder skip; only the upsampled disp2 is concatenated.
        out_upconv1 = crop_like(self.upconv1(out_iconv2), x)
        disp2_up = self._upsample_like(disp2, x)
        concat1 = torch.cat((out_upconv1, disp2_up), 1)
        out_iconv1 = self.iconv1(concat1)
        disp1 = self._scaled_disp(self.predict_disp1, out_iconv1)

        if self.training:
            return disp1, disp2, disp3, disp4, disp5, disp6
        else:
            return disp1


In [66]:
# Instantiate the network with its default alpha=10 and beta=0.01.
model = DispNetS6()

In [None]:
def save_checkpoint(save_path, dispnet_state, filename='checkpoint.pth.tar'):
    """Serialise ``dispnet_state`` to ``<save_path>/dispnet_<filename>``.

    Args:
        save_path: destination directory, as a plain string or a path object.
        dispnet_state: object handed to ``torch.save``.
        filename: suffix of the written file; the prefix ``dispnet_`` is
            prepended.
    """
    # Imported locally so the function does not rely on a file-level import.
    from pathlib import Path

    # Coerce to Path so a plain-string directory also supports the "/" join.
    target_dir = Path(save_path)
    file_prefixes = ['dispnet']
    states = [dispnet_state]
    for (prefix, state) in zip(file_prefixes, states):
        torch.save(state, str(target_dir / '{}_{}'.format(prefix, filename)))