In [19]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

from data_loader import prepare_dataloader


In [23]:
# Training-mode loader over ../data/kitti with two samples per batch.
# n_img is presumably the number of images found -- confirm in data_loader.py.
n_img, data_loader = prepare_dataloader(
    '../data/kitti',
    mode='train',
    batch_size=2,
)

Use a dataset with 108 images


In [29]:
class MonodepthLoss(nn.Module):
    """Loss module for stereo disparity estimation (work in progress).

    The constructor stores the number of scales and three weighting
    factors. ``forward`` currently only prepares its inputs (image pyramids
    and per-view disparity maps); no loss term is computed from them yet.
    """

    def __init__(self, n_scales=4, ssim_w=0.85, disp_grad_w=1.0, lr_w=1.0):
        """Store the configuration of the loss.

        Parameters
        ----------
        n_scales : int
            number of pyramid levels, counting the full-resolution input
        ssim_w, disp_grad_w, lr_w : float
            weighting factors kept on the instance; nothing in this class
            reads them yet (presumably the SSIM, disparity-gradient and
            left-right terms -- TODO confirm once the terms are implemented)
        """
        super().__init__()

        self.n_scales = n_scales
        self.ssim_w = ssim_w
        self.disp_grad_w = disp_grad_w
        self.lr_w = lr_w

    def scale_pyramid(self, img):
        """Build a list of ``n_scales`` progressively smaller copies of ``img``.

        Parameters
        ----------
        img : tensor indexed as [batch, channels, height, width]
            full-resolution input; it becomes element 0 of the result unchanged

        Returns
        -------
        list of tensors
            ``img`` followed by bilinear resizes to
            ``(height // 2**k, width // 2**k)`` for ``k = 1 .. n_scales - 1``
        """
        height, width = img.size(2), img.size(3)
        reduced = [
            F.interpolate(
                img,
                size=(height // 2 ** level, width // 2 ** level),
                mode='bilinear',
                align_corners=True,
            )
            for level in range(1, self.n_scales)
        ]
        return [img] + reduced

    def forward(self, disparities, images):
        """ Prepare the inputs of the loss function of eqn (1)

        Parameters
        ----------
        disparities : tuple of n_scales tensors, each of size [batch_size, 2, height, width]
            output disparities at n_scales different scales of decreasing size;
            channel 0 is taken as the left estimate, channel 1 as the right one

        images : tuple of 2 tensors, each of size [batch_size, 3 (rgb), height, width]
            left and right image

        Returns
        -------
        None
            The intent is to return the total loss at all scales, but the
            body as written stops after input preparation and therefore
            returns nothing.

        """

        def channel_maps(index):
            # Integer-index one channel, then restore the channel axis so
            # every map stays 4-D: [batch_size, 1, height, width].
            return [d[:, index, :, :].unsqueeze(dim=1) for d in disparities]

        # input images, each expanded to one tensor per scale
        left, right = images
        left_pyramid = self.scale_pyramid(left)
        right_pyramid = self.scale_pyramid(right)

        # left and right disparity estimates, one entry per scale
        disp_left_est = channel_maps(0)
        disp_right_est = channel_maps(1)
        

In [27]:
# Pull a single batch from the loader and pair up its two views.
data = next(iter(data_loader))
images = (data['left_image'], data['right_image'])
print(images[0].shape)

# Disparity maps previously saved to disk as d0.npy .. d3.npy, one file per scale.
disparities = [
    torch.from_numpy(np.load('../data/output/d{}.npy'.format(scale)))
    for scale in range(4)
]

torch.Size([2, 3, 256, 512])


In [31]:
# Instantiate the loss module with its default settings.
loss = MonodepthLoss()

# Invoke the module on the saved disparities and the sampled image pair.
loss(disparities=disparities, images=images)

torch.Size([2, 3, 256, 512])
