# Install packages

In [None]:
# !pip3 install numpy
# !pip3 install azure-storage-blob
# !pip3 install opencv-python

# Data handlers

This section contains the code that creates the train/validation split files and the data loaders used during training.

## Create data splits

In [None]:
from azure.storage.blob import ContainerClient
import re
import io
import cv2
import os

WRK_PTH = os.getcwd()

# TartanAir difficulty levels walked by create_blob_splits().
ENV_MODES = ['Easy']
SIDES = ['Left', 'Right']
# Number of trailing trajectories of every environment/mode kept for validation.
NUM_VAL_TRAJECTS = 2
# Seeded shuffling of trajectory folders / training entries in create_blob_splits().
SHUFFLE_TRAJ_DIRS = False
SHUFFLE_IMAGES = False
# Frames dropped from the start (abs of the first value) and from the end
# (second value) of every trajectory.
FRAME_LIMITS = [-1, 2]


# Set RUN = True to (re)build the split files. Alternative: derive the flag from
# whether both split files already exist:
#   not (os.path.exists(os.path.join(WRK_PTH, 'train_files.txt'))
#        and os.path.exists(os.path.join(WRK_PTH, 'val_files.txt')))
RUN = False
if RUN:
    # Dataset website: http://theairlab.org/tartanair-dataset/
    account_url = 'https://tartanair.blob.core.windows.net/'
    container_name = 'tartanair-release1'

    # Credential-less client; the helper functions below use this global.
    container_client = ContainerClient(
        account_url=account_url,
        container_name=container_name,
        credential=None,
    )


def get_environment_list():
    '''
    Return the names of all top-level prefixes (environments) in the container.

    Uses the module-level ``container_client`` (only created when ``RUN`` is set).
    '''
    return [prefix.name for prefix in container_client.walk_blobs()]


def get_trajectory_list(envname, easy_hard='Easy'):
    '''
    Return the trajectory folders (named 'P0XX') under ``<envname>/<easy_hard>/``.

    ``easy_hard`` must be 'Easy' or 'Hard' (checked with an assert). Only
    prefixes whose last non-empty path component starts with 'P' are kept.
    '''
    assert easy_hard in ('Easy', 'Hard')
    prefix = envname + '/' + easy_hard + '/'
    result = []
    for item in container_client.walk_blobs(name_starts_with=prefix):
        components = [part for part in item.name.split('/') if part]
        if components[-1][0] == 'P':
            result.append(item.name)
    return result


def _list_blobs_in_folder(folder_name):
    """
    Return the names of all blobs whose name starts with ``folder_name``,
    i.e. the contents of that virtual folder of the Azure container.
    """
    blobs = container_client.list_blobs(name_starts_with=folder_name)
    return [blob.name for blob in blobs]


def get_image_list(trajdir, left_right='left'):
    """Return the ``.png`` blob names inside ``<trajdir>/image_<left_right>/``.

    ``left_right`` must be 'left' or 'right' (checked with an assert).
    """
    assert left_right in ('left', 'right')
    folder = trajdir + '/image_' + left_right + '/'
    return [name for name in _list_blobs_in_folder(folder) if name.endswith('.png')]


def append_paths_srings(env_name, env_mode, traj_dir_list, path_list):
    for traj_dir in traj_dir_list:
        traj_name = traj_dir.split('/')[-2]
        img_pth_list = get_image_list(traj_dir)
        img_pth_list = img_pth_list[abs(FRAME_LIMITS[0]):-FRAME_LIMITS[1]]

        for img_pth in img_pth_list:
            frame_id = re.findall('\d\d\d\d\d\d', img_pth)
            
            if len(frame_id) > 0:
                frame_id = frame_id[0]
            else:
                continue
            
            if 'right' in img_pth:
                side = 'r'
            elif 'left' in img_pth:
                side = 'l'
            else:
                raise AttributeError

            path_list += [
                "{} {} {}".format("{}_{}_{}".format(
                    env_mode, env_name, traj_name), frame_id, side
                )
            ]


def write_paths(file_path, paths):
    """Write every entry of ``paths`` to ``file_path`` on its own line, overwriting the file."""
    with open(file_path, 'w') as out:
        out.writelines("{}\n".format(entry) for entry in paths)


def create_blob_splits():
    envlist = get_environment_list()
    
    train_paths = []
    val_paths = []

    for env_name in envlist:
        for mode in ENV_MODES:
            env_trajects = get_trajectory_list(env_name, easy_hard=mode)
            
            if SHUFFLE_TRAJ_DIRS:
                rng = np.random.default_rng(42)
                rng.shuffle(env_trajects)

            train_trajects = env_trajects[:-NUM_VAL_TRAJECTS]
            val_trajects = env_trajects[-NUM_VAL_TRAJECTS:]
            
            append_paths_srings(env_name[:-1], mode, train_trajects, train_paths)
            append_paths_srings(env_name[:-1], mode, val_trajects, val_paths)
            
            # TODO: kiszedni
            break
        # TODO: kiszedni    
        break

    if SHUFFLE_IMAGES:
        rng = np.random.default_rng(42)
        rng.shuffle(train_paths)

    write_paths(os.path.join(WRK_PTH, 'train_files.txt'), train_paths)
    write_paths(os.path.join(WRK_PTH, 'val_files.txt'), val_paths)


if RUN:
    import time

    start = time.perf_counter()
    create_blob_splits()
    # Elapsed wall-clock time (seconds) for building the split files.
    print(time.perf_counter() - start)

## Data loaders

### Base class

In [None]:
import random
import numpy as np
from PIL import Image  # using pillow-simd for increased speed
import os
import PIL.Image as pil
from zipfile import ZipFile
import re
import torch
import torch.utils.data as data
from torchvision import transforms
from azure.storage.blob import ContainerClient


def pil_zip_loader(path):
    # open path as file to avoid ResourceWarning
    # (https://github.com/python-pillow/Pillow/issues/835)
    pattern_dir = re.compile(".*\.zip")
    match_dir = re.match(pattern_dir, path)
    img_path = path.split(".zip")[-1]
    img_path = img_path.replace("\\", "/")[1:]
    with ZipFile(match_dir[0], 'r') as zip_dir:
        with zip_dir.open(img_path.replace("\\", "/")) as f:
            with pil.open(f) as img:
                return img.convert('RGB')


# From https://github.com/nianticlabs/monodepth2
def pil_loader(path):
    """Read the image file at ``path`` and return it converted to RGB.

    The file is opened explicitly so its handle is closed deterministically,
    avoiding a ResourceWarning (https://github.com/python-pillow/Pillow/issues/835).
    """
    with open(path, 'rb') as handle, Image.open(handle) as picture:
        return picture.convert('RGB')


class TartanAirDataset(data.Dataset):
    """Map-style dataset that downloads TartanAir frames from Azure blob storage on demand.

    ``filenames`` holds split lines of the form ``"<mode>_<env>_<traj> <frame_id> <side>"``.
    The item layout follows monodepth2's ``MonoDataset`` (see ``__getitem__``).

    Args:
        data_path: stored as ``self.data_path``; not used for loading in this class.
        filenames: list of split lines.
        height, width: target resolution at scale 0.
        frame_idxs: temporal offsets (ints) and/or "s" for the opposite stereo image.
        num_scales: number of pyramid scales produced by ``preprocess``.
        is_train: enables random colour augmentation and horizontal flips.
        img_ext: image file extension used when building blob names.
    """

    def __init__(self,
                 data_path,
                 filenames,
                 height,
                 width,
                 frame_idxs,
                 num_scales,
                 is_train=False,
                 img_ext='.png'):

        # Dataset website: http://theairlab.org/tartanair-dataset/
        account_url = 'https://tartanair.blob.core.windows.net/'
        container_name = 'tartanair-release1'
        self.container_client = ContainerClient(account_url=account_url,
                                                container_name=container_name,
                                                credential=None)

        self.data_path = data_path
        self.filenames = filenames
        self.height = height
        self.width = width
        self.num_scales = num_scales
        # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the same filter.
        self.interp = Image.LANCZOS

        self.frame_idxs = frame_idxs

        self.is_train = is_train
        self.img_ext = img_ext

        self.to_tensor = transforms.ToTensor()

        # We need to specify augmentations differently in newer versions of torchvision.
        # We first try the newer tuple version; if this fails we fall back to scalars
        try:
            self.brightness = (0.8, 1.2)
            self.contrast = (0.8, 1.2)
            self.saturation = (0.8, 1.2)
            self.hue = (-0.1, 0.1)
            transforms.ColorJitter.get_params(
                self.brightness, self.contrast, self.saturation, self.hue)
        except TypeError:
            self.brightness = 0.2
            self.contrast = 0.2
            self.saturation = 0.2
            self.hue = 0.1

        # One resize transform per pyramid level (scale i -> size / 2**i).
        self.resize = {}
        for i in range(self.num_scales):
            s = 2 ** i
            self.resize[i] = transforms.Resize((self.height // s, self.width // s),
                                               interpolation=self.interp)

        self.load_depth = self.check_depth()
        self.og_shape = (480, 640)  # (height, width)
        # Kept for API compatibility; get_color() downloads from blob storage instead.
        self.loader = pil_zip_loader
        # Intrinsics normalised by the original image size (fx / W, fy / H, principal
        # point at 0.5, 0.5); scaled to each pyramid level in __getitem__.
        # From https://arxiv.org/pdf/2011.00359.pdf
        fx = 320 / self.og_shape[1]
        fy = 320 / self.og_shape[0]
        self.K = np.array([
            [fx, 0., 0.5, 0.],
            [0., fy, 0.5, 0.],
            [0., 0., 1.0, 0.],
            [0., 0., 0.0, 1.],
        ], dtype=np.float32)
        self.fov = 90
        self.side_map = {"l": "left", "r": "right"}

    def preprocess(self, inputs, color_aug):
        """Resize colour images to the required scales and augment if required

        We create the color_aug object in advance and apply the same augmentation to all
        images in this item. This ensures that all images input to the pose network receive the
        same augmentation.
        """
        for k in list(inputs):
            if "color" in k:
                n, im, _ = k
                # Each scale is resized from the previous one (scale -1 is the raw image).
                for i in range(self.num_scales):
                    inputs[(n, im, i)] = self.resize[i](inputs[(n, im, i - 1)])

        for k in list(inputs):
            f = inputs[k]
            if "color" in k:
                n, im, i = k
                inputs[(n, im, i)] = self.to_tensor(f)
                inputs[(n + "_aug", im, i)] = self.to_tensor(color_aug(f))

    def __len__(self):
        return len(self.filenames)

    def __getitem__(self, index):
        """Returns a single training item from the dataset as a dictionary.

        Values correspond to torch tensors.
        Keys in the dictionary are either strings or tuples:

            ("color", <frame_id>, <scale>)          for raw colour images,
            ("color_aug", <frame_id>, <scale>)      for augmented colour images,
            ("K", scale) or ("inv_K", scale)        for camera intrinsics,
            "stereo_T"                              for camera extrinsics, and
            "depth_gt"                              for ground truth depth maps.

        <frame_id> is either:
            an integer (e.g. 0, -1, or 1) representing the temporal step relative to 'index',
        or
            "s" for the opposite image in the stereo pair.

        <scale> is an integer representing the scale of the image relative to the fullsize image:
            -1      images at native resolution as loaded from disk
            0       images resized to (self.width,      self.height     )
            1       images resized to (self.width // 2, self.height // 2)
            2       images resized to (self.width // 4, self.height // 4)
            3       images resized to (self.width // 8, self.height // 8)
        """
        inputs = {}

        do_color_aug = self.is_train and random.random() > 0.5
        do_flip = self.is_train and random.random() > 0.5

        line = self.filenames[index].split()
        folder = line[0]

        if len(line) == 3:
            frame_index = int(line[1])
            side = line[2]
        else:
            frame_index = 0
            side = None

        for i in self.frame_idxs:
            if i == "s":
                other_side = {"r": "l", "l": "r"}[side]
                inputs[("color", i, -1)] = self.get_color(folder, frame_index, other_side, do_flip)
            else:
                inputs[("color", i, -1)] = self.get_color(folder, frame_index + i, side, do_flip)

        # adjusting intrinsics to match each scale in the pyramid
        for scale in range(self.num_scales):
            K = self.K.copy()

            K[0, :] *= self.width // (2 ** scale)
            K[1, :] *= self.height // (2 ** scale)

            inv_K = np.linalg.pinv(K)

            inputs[("K", scale)] = torch.from_numpy(K)
            inputs[("inv_K", scale)] = torch.from_numpy(inv_K)

        if do_color_aug:
            color_aug = transforms.ColorJitter(
                self.brightness, self.contrast, self.saturation, self.hue)
        else:
            color_aug = (lambda x: x)

        self.preprocess(inputs, color_aug)

        for i in self.frame_idxs:
            del inputs[("color", i, -1)]
            del inputs[("color_aug", i, -1)]

        if self.load_depth:
            depth_gt = self.get_depth(folder, frame_index, side, do_flip)
            inputs["depth_gt"] = np.expand_dims(depth_gt, 0)
            inputs["depth_gt"] = torch.from_numpy(inputs["depth_gt"].astype(np.float32))

        if "s" in self.frame_idxs:
            stereo_T = np.eye(4, dtype=np.float32)
            baseline_sign = -1 if do_flip else 1
            side_sign = -1 if side == "l" else 1
            stereo_T[0, 3] = side_sign * baseline_sign * 0.1

            inputs["stereo_T"] = torch.from_numpy(stereo_T)

        return inputs

    def get_color(self, folder, frame_index, side, do_flip):
        """Download one frame from blob storage and return it as an RGB PIL image.

        Raises:
            OSError: if the downloaded bytes cannot be decoded as an image.
        """
        image_file = self.get_image_path(folder, frame_index, side)
        blob_client = self.container_client.get_blob_client(blob=image_file)
        raw = blob_client.download_blob().readall()
        img = cv2.imdecode(np.frombuffer(raw, dtype=np.uint8), cv2.IMREAD_COLOR)
        if img is None:
            raise OSError("Could not decode image blob {!r}".format(image_file))
        color = pil.fromarray(img[:, :, [2, 1, 0]], "RGB")  # BGR -> RGB

        if do_flip:
            color = color.transpose(pil.FLIP_LEFT_RIGHT)

        return color

    def get_image_path(self, folder, frame_index, side):
        """Build the blob name ``<env>/<mode>/<traj>/image_<side>/<frame:06d>_<side><ext>``.

        ``folder`` is ``"<mode>_<env>_<traj>"``. The environment name may itself
        contain underscores (e.g. ``abandonedfactory_night``), so everything between
        the first and the last '_' is taken as the environment.

        Raises:
            ValueError: if ``folder`` has fewer than three '_'-separated parts.
        """
        folder_parts = folder.split('_')
        if len(folder_parts) < 3:
            raise ValueError(
                "Error in getting image path: malformed folder {!r}".format(folder))
        mode, traj = folder_parts[0], folder_parts[-1]
        env = '_'.join(folder_parts[1:-1])
        side_name = self.side_map[side]
        img_folder = "image_{}".format(side_name)
        f_str = "{:06d}_{}{}".format(frame_index, side_name, self.img_ext)
        # Blob names always use '/'; os.path.join would emit backslashes on Windows.
        return "/".join((env, mode, traj, img_folder, f_str))

    def check_depth(self):
        """Depth ground truth is not available for this dataset."""
        return None

    def get_depth(self, folder, frame_index, side, do_flip):
        raise NotImplementedError

# Models and utils

## Util functions

### Transformations

In [None]:
import numpy as np
import torch
from torch import nn
from torch.nn import functional as F


def disp2depth(disp, min_depth, max_depth):
    """Map the network's sigmoid disparity output to (scaled disparity, depth).

    ``disp`` is rescaled linearly from [0, 1] to [1 / max_depth, 1 / min_depth];
    depth is the reciprocal of the rescaled disparity. The formula is described
    in the 'additional considerations' section of the monodepth2 paper.
    from https://github.com/nianticlabs/monodepth2/blob/master/layers.py
    """
    lo_disp, hi_disp = 1 / max_depth, 1 / min_depth
    scaled_disp = lo_disp + disp * (hi_disp - lo_disp)
    return scaled_disp, 1 / scaled_disp


def get_translation_matrix(translation_vector):
    """Build a batch of 4x4 homogeneous translation matrices.

    Every output matrix is the identity with its last column's top three entries
    set to the corresponding translation vector; the batch size is
    ``translation_vector.shape[0]`` and the result lives on the input's device.
    from https://github.com/nianticlabs/monodepth2/blob/master/layers.py
    """
    batch = translation_vector.shape[0]
    T = torch.eye(4, device=translation_vector.device).repeat(batch, 1, 1)
    T[:, :3, 3] = translation_vector.contiguous().view(-1, 3)
    return T


def rot_from_axisangle(vec):
    """Convert an axisangle rotation into a 4x4 transformation matrix
    (adapted from https://github.com/Wallacoloo/printipi)
    Input 'vec' has to be Bx1x3
    from https://github.com/nianticlabs/monodepth2/blob/master/layers.py

    The rotation angle is the L2 norm of 'vec' and the axis is 'vec' divided by
    that norm. The entries follow Rodrigues' formula
    R = cos(a) I + (1 - cos(a)) k k^T + sin(a) [k]_x, written into the top-left
    3x3 block of a zero-initialised (B, 4, 4) tensor on vec's device whose
    [3, 3] entry is set to 1 (no translation).
    """
    # L2 norm over dim 2 with keepdim -> (B, 1, 1) for a Bx1x3 input.
    angle = torch.norm(vec, 2, 2, True)
    # The 1e-7 avoids dividing by zero for a zero rotation vector.
    axis = vec / (angle + 1e-7)

    ca = torch.cos(angle)
    sa = torch.sin(angle)
    C = 1 - ca

    # Axis components, each reshaped to (B, 1, 1).
    x = axis[..., 0].unsqueeze(1)
    y = axis[..., 1].unsqueeze(1)
    z = axis[..., 2].unsqueeze(1)

    # Shared products of the Rodrigues expansion.
    xs = x * sa
    ys = y * sa
    zs = z * sa
    xC = x * C
    yC = y * C
    zC = z * C
    xyC = x * yC
    yzC = y * zC
    zxC = z * xC

    rot = torch.zeros((vec.shape[0], 4, 4)).to(device=vec.device)

    rot[:, 0, 0] = torch.squeeze(x * xC + ca)
    rot[:, 0, 1] = torch.squeeze(xyC - zs)
    rot[:, 0, 2] = torch.squeeze(zxC + ys)
    rot[:, 1, 0] = torch.squeeze(xyC + zs)
    rot[:, 1, 1] = torch.squeeze(y * yC + ca)
    rot[:, 1, 2] = torch.squeeze(yzC - xs)
    rot[:, 2, 0] = torch.squeeze(zxC - ys)
    rot[:, 2, 1] = torch.squeeze(yzC + xs)
    rot[:, 2, 2] = torch.squeeze(z * zC + ca)
    rot[:, 3, 3] = 1

    return rot


def transformation_from_parameters(axisangle, translation, invert=False):
    """Turn the network's (axisangle, translation) prediction into 4x4 matrices.

    Returns ``T(t) @ R``; with ``invert`` set, returns the inverse transform
    ``R^T @ T(-t)`` instead.
    from https://github.com/nianticlabs/monodepth2/blob/master/layers.py
    """
    rotation = rot_from_axisangle(axisangle)
    if not invert:
        return torch.matmul(get_translation_matrix(translation.clone()), rotation)

    # Inverse of x -> R x + t is x -> R^T (x - t) = R^T @ T(-t).
    inverted_rotation = rotation.transpose(1, 2)
    negated_translation = translation.clone() * -1
    return torch.matmul(inverted_rotation, get_translation_matrix(negated_translation))


class BackprojectDepth(nn.Module):
    """Lift a depth map to homogeneous 3D camera-space points.

    Each pixel (u, v) becomes the column [d * K^-1[:3, :3] [u, v, 1]^T ; 1], so the
    output has shape (batch_size, 4, height * width).
    from https://github.com/nianticlabs/monodepth2/blob/master/layers.py
    """
    def __init__(self, batch_size, height, width):
        super(BackprojectDepth, self).__init__()

        self.batch_size = batch_size
        self.height = height
        self.width = width

        cols, rows = np.meshgrid(range(width), range(height), indexing='xy')
        grid = np.stack([cols, rows], axis=0).astype(np.float32)  # (2, H, W)
        self.id_coords = nn.Parameter(torch.from_numpy(grid), requires_grad=False)

        self.ones = nn.Parameter(torch.ones(batch_size, 1, height * width),
                                 requires_grad=False)

        # Homogeneous pixel coordinates [u; v; 1], one copy per batch element.
        flat_uv = self.id_coords.view(2, -1).unsqueeze(0).repeat(batch_size, 1, 1)
        self.pix_coords = nn.Parameter(torch.cat([flat_uv, self.ones], 1),
                                       requires_grad=False)

    def forward(self, depth, inv_K):
        rays = torch.matmul(inv_K[:, :3, :3], self.pix_coords)
        points = depth.view(self.batch_size, 1, -1) * rays
        return torch.cat([points, self.ones], 1)


class Project3D(nn.Module):
    """Project homogeneous 3D points through K @ T into a normalised sampling grid.

    Pixel positions are divided by (width - 1, height - 1) and mapped to [-1, 1],
    giving a (B, height, width, 2) tensor in the layout ``F.grid_sample`` expects.
    from https://github.com/nianticlabs/monodepth2/blob/master/layers.py
    """
    def __init__(self, height, width, eps=1e-7):
        super(Project3D, self).__init__()

        self.height = height
        self.width = width
        self.eps = eps

    def forward(self, points, K, T):
        projection = (K @ T)[:, :3, :]
        cam = projection @ points

        # Perspective divide; eps guards against zero depth.
        depth = cam[:, 2:3, :] + self.eps
        uv = (cam[:, :2, :] / depth).view(-1, 2, self.height, self.width).permute(0, 2, 3, 1)

        extent = uv.new_tensor([self.width - 1, self.height - 1])
        return (uv / extent - 0.5) * 2


### Losses

In [None]:
import torch
from torch import nn
from torch.nn import functional as F
import numpy as np


def get_smooth_loss(disp: torch.Tensor, img: torch.Tensor):
    """Edge-aware smoothness loss for a disparity image.

    Absolute horizontal/vertical disparity gradients are down-weighted by
    exp(-|image gradient|) (averaged over colour channels), then averaged; the
    horizontal and vertical means are summed.
    from https://github.com/nianticlabs/monodepth2/blob/master/layers.py
    """
    def _dx(t):
        return t[:, :, :, :-1] - t[:, :, :, 1:]

    def _dy(t):
        return t[:, :, :-1, :] - t[:, :, 1:, :]

    weight_x = torch.exp(-_dx(img).abs().mean(1, keepdim=True))
    weight_y = torch.exp(-_dy(img).abs().mean(1, keepdim=True))

    smooth_x = (_dx(disp).abs() * weight_x).mean()
    smooth_y = (_dy(disp).abs() * weight_y).mean()
    return smooth_x + smooth_y


class SSIM(nn.Module):
    """Per-pixel SSIM dissimilarity ``clamp((1 - SSIM) / 2, 0, 1)`` over 3x3 windows.

    Inputs are reflection-padded by one pixel so the output keeps the input size.
    from https://github.com/nianticlabs/monodepth2/blob/master/layers.py
    """
    def __init__(self):
        super(SSIM, self).__init__()
        self.mu_x_pool = nn.AvgPool2d(3, 1)
        self.mu_y_pool = nn.AvgPool2d(3, 1)
        self.sig_x_pool = nn.AvgPool2d(3, 1)
        self.sig_y_pool = nn.AvgPool2d(3, 1)
        self.sig_xy_pool = nn.AvgPool2d(3, 1)

        self.refl = nn.ReflectionPad2d(1)

        # Stabilising constants.
        self.C1 = 0.01 ** 2
        self.C2 = 0.03 ** 2

    def forward(self, x, y):
        x_pad, y_pad = self.refl(x), self.refl(y)

        mu_x = self.mu_x_pool(x_pad)
        mu_y = self.mu_y_pool(y_pad)
        mu_x_sq, mu_y_sq, mu_xy = mu_x ** 2, mu_y ** 2, mu_x * mu_y

        sigma_x = self.sig_x_pool(x_pad ** 2) - mu_x_sq
        sigma_y = self.sig_y_pool(y_pad ** 2) - mu_y_sq
        sigma_xy = self.sig_xy_pool(x_pad * y_pad) - mu_xy

        numerator = (2 * mu_xy + self.C1) * (2 * sigma_xy + self.C2)
        denominator = (mu_x_sq + mu_y_sq + self.C1) * (sigma_x + sigma_y + self.C2)
        return torch.clamp((1 - numerator / denominator) / 2, 0, 1)


class MVS3DLoss(nn.Module):
    """
    Multi-view 3D structure consistency loss.

    Back-projects source and target depth maps with the same inverse intrinsics
    and penalises the absolute difference between the target cloud and the source
    cloud transformed by ``pose``, reduced by ``reduction`` ('mean', 'sum', 'none').
    """
    def __init__(self, batch_size: int, height: int, width: int, reduction: str = 'mean'):
        super(MVS3DLoss, self).__init__()
        self.back_project = BackprojectDepth(batch_size, height, width)
        self.reduction = reduction

    def forward(
            self,
            depth_src: torch.Tensor,
            depth_tgt: torch.Tensor,
            pose: torch.Tensor,
            inv_intrinsics: torch.Tensor
    ):
        cloud_src = self.back_project(depth_src, inv_intrinsics)
        cloud_tgt = self.back_project(depth_tgt, inv_intrinsics)

        reducers = {
            'mean': lambda err: err.mean(),
            'sum': lambda err: err.sum(),
            'none': lambda err: err,
        }
        reduce_fn = reducers.get(self.reduction)
        if reduce_fn is None:
            raise NotImplementedError
        return reduce_fn((cloud_tgt - pose @ cloud_src).abs())


class EpipolarLoss(nn.Module):
    """
    Epipolar-geometry loss on optical flow.

    For every pixel p with flow target q = p + flow(p) the residual is
    |(K^-1 p)^T (R [t]_x) (K^-1 q)|, where R = pose[:, :3, :3], t = pose[:, :3, -1]
    and [t]_x is the skew-symmetric cross-product matrix of t. Only each pixel's
    own correspondence contributes. The previous version summed the bilinear form
    over all pixel pairs, and used a signed residual, which is unbounded below.

    NOTE(review): the textbook essential matrix is [t]_x R; the R [t]_x order is
    kept from the original. Confirm it matches the pose convention in use.

    ``pix_group_size`` is kept for backward compatibility only. The per-pixel form
    needs O(H * W) memory, so no chunking is performed.
    """
    def __init__(self, batch_size: int, height: int, width: int, pix_group_size: int = 128, reduction: str = 'mean'):
        super(EpipolarLoss, self).__init__()

        self.batch_size = batch_size
        self.height = height
        self.width = width
        self.pix_group_size = pix_group_size
        self.reduction = reduction

        meshgrid = np.meshgrid(range(self.width), range(self.height), indexing='xy')
        self.id_coords = np.stack(meshgrid, axis=0).astype(np.float32)
        self.id_coords = nn.Parameter(torch.from_numpy(self.id_coords),
                                      requires_grad=False)

        self.ones = nn.Parameter(torch.ones(self.batch_size, 1, self.height * self.width),
                                 requires_grad=False)

        # Homogeneous pixel coordinates [u; v; 1] of shape (batch_size, 3, H * W).
        self.pix_coords = torch.unsqueeze(torch.stack(
            [self.id_coords[0].view(-1), self.id_coords[1].view(-1)], 0), 0)
        self.pix_coords = self.pix_coords.repeat(batch_size, 1, 1)
        self.pix_coords = nn.Parameter(torch.cat([self.pix_coords, self.ones], 1),
                                       requires_grad=False)

    def forward(
            self,
            flow: torch.Tensor,
            pose: torch.Tensor,
            inv_intrinsics: torch.Tensor
    ) -> torch.Tensor:
        """
        Args:
            flow: (B, 2, H, W) optical flow in pixels.
            pose: relative pose, presumably (B, 4, 4); rows 0-2 supply R and t.
            inv_intrinsics: inverse intrinsics; only the top-left 3x3 is used.

        Returns:
            Scalar for 'mean' / 'sum'; a (B, 1, H, W) residual map for 'none'.
        """
        rotation = pose[:, :3, :3]
        translation = pose[:, :3, -1]  # (B, 3)

        # Skew-symmetric cross-product matrix [t]_x, built per batch element.
        tx, ty, tz = translation.unbind(dim=1)
        zero = torch.zeros_like(tx)
        t_skew = torch.stack([
            torch.stack([zero, -tz, ty], dim=1),
            torch.stack([tz, zero, -tx], dim=1),
            torch.stack([-ty, tx, zero], dim=1),
        ], dim=1)  # (B, 3, 3)

        # Move pixels with flow (homogeneous coordinate stays 1).
        flattened_flow = flow.reshape(flow.shape[0], flow.shape[1], -1)
        pix_flow = torch.cat([self.pix_coords[:, :2] + flattened_flow,
                              self.pix_coords[:, 2:]], dim=1)

        # Normalised camera rays of each pixel and of its flow target.
        inv_k = inv_intrinsics[:, :3, :3]
        rays_src = inv_k @ self.pix_coords
        rays_tgt = inv_k @ pix_flow

        # Row-wise bilinear form: one residual per pixel, shape (B, H * W).
        residual = (rays_src * (rotation @ t_skew @ rays_tgt)).sum(dim=1).abs()

        if self.reduction == 'mean':
            return residual.mean()
        elif self.reduction == 'sum':
            return residual.sum()
        elif self.reduction == 'none':
            return residual.view(-1, 1, self.height, self.width)
        else:
            raise NotImplementedError


class AdaptivePhotometricLoss(nn.Module):
    """Pixel-wise minimum of two photometric errors: rigid 3D warp vs. optical-flow warp.

    Each error is ``r * SSIM_dissimilarity + (1 - r) * L1``, averaged over colour
    channels. The result is reduced by ``reduction`` ('mean', 'sum', 'none').
    """
    def __init__(self, batch_size: int, height: int, width: int, r: float = 0.85, reduction: str = 'mean'):
        super(AdaptivePhotometricLoss, self).__init__()
        self.batch_size = batch_size
        self.height = height
        self.width = width
        self.reduction = reduction

        self.project = Project3D(height, width)
        self.back_project = BackprojectDepth(batch_size, height, width)
        self.ssim = SSIM()
        self.r = r

        meshgrid = np.meshgrid(range(self.width), range(self.height), indexing='xy')
        self.id_coords = np.stack(meshgrid, axis=0).astype(np.float32)
        self.id_coords = nn.Parameter(torch.from_numpy(self.id_coords),
                                      requires_grad=False)

        self.ones = nn.Parameter(torch.ones(self.batch_size, 1, self.height * self.width),
                                 requires_grad=False)

        self.pix_coords = torch.unsqueeze(torch.stack(
            [self.id_coords[0].view(-1), self.id_coords[1].view(-1)], 0), 0)
        self.pix_coords = self.pix_coords.repeat(batch_size, 1, 1)
        self.pix_coords = nn.Parameter(torch.cat([self.pix_coords, self.ones], 1),
                                       requires_grad=False)

    @staticmethod
    def _pixels_to_grid(pix: torch.Tensor, height: int, width: int) -> torch.Tensor:
        """Map (B, 2, H, W) pixel coordinates to the (B, H, W, 2) [-1, 1] grid of grid_sample.

        Uses the ``align_corners=True`` convention (pixel 0 -> -1, pixel size-1 -> 1).
        """
        grid = pix.permute(0, 2, 3, 1)
        scale = grid.new_tensor([width - 1, height - 1])
        return grid / scale * 2 - 1

    def forward(
            self,
            img_src: torch.Tensor,
            img_tgt: torch.Tensor,
            depth: torch.Tensor,
            flow: torch.Tensor,
            pose: torch.Tensor,
            inv_intrinsics: torch.Tensor
    ):
        # Project3D multiplies by K, so it needs the forward intrinsics, not K^-1.
        intrinsics = torch.inverse(inv_intrinsics)

        # 3D warp target image
        cam_coord_src = self.back_project(depth, inv_intrinsics)
        pix_coords_tgt = self.project(cam_coord_src, intrinsics, pose)
        warped_tgt_3d = F.grid_sample(img_src, pix_coords_tgt,
                                      padding_mode='border', align_corners=True)

        # Flow warp target image; flow is in pixels, grid_sample needs [-1, 1] coordinates.
        pix_coords = self.pix_coords[:, :2].view(self.batch_size, 2, self.height, self.width)
        flow_grid = self._pixels_to_grid(pix_coords + flow, self.height, self.width)
        warped_tgt_flow = F.grid_sample(img_src, flow_grid,
                                        padding_mode='border', align_corners=True)

        # self.ssim already returns the dissimilarity clamp((1 - SSIM) / 2, 0, 1),
        # so it is used directly as the structural error term.
        ssim_3d = self.ssim(img_tgt, warped_tgt_3d).mean(1, keepdim=True)
        ssim_flow = self.ssim(img_tgt, warped_tgt_flow).mean(1, keepdim=True)

        l1_3d = (img_tgt - warped_tgt_3d).abs().mean(1, keepdim=True)
        l1_flow = (img_tgt - warped_tgt_flow).abs().mean(1, keepdim=True)

        s_3d = self.r * ssim_3d + (1 - self.r) * l1_3d
        s_flow = self.r * ssim_flow + (1 - self.r) * l1_flow

        # Pixel-wise minimum of the two errors, shape (B, 1, H, W).
        apc_loss = torch.stack([s_3d, s_flow], dim=1)
        apc_loss = apc_loss.min(dim=1)[0]

        if self.reduction == 'mean':
            return apc_loss.mean()
        elif self.reduction == 'sum':
            return apc_loss.sum()
        elif self.reduction == 'none':
            return apc_loss
        else:
            raise NotImplementedError


class FwdBwdFlowConsistency(nn.Module):
    """Forward-backward optical-flow consistency loss over non-occluded pixels.

    The backward flow is warped into the forward frame (and vice versa). The
    residual ``|fwd + warped_bwd|`` is averaged over the pixels where
    ``2**scale * |residual|`` is below ``max(beta * 2**scale * |flow|, alpha)``.
    """
    def __init__(
            self,
            batch_size: int,
            height: int,
            width: int,
            alpha: float = 3.,
            beta: float = 0.05,
            scale: int = 0,
            reduction: str = 'mean'
    ):
        super(FwdBwdFlowConsistency, self).__init__()
        self.batch_size = batch_size
        self.height = height
        self.width = width
        self.alpha = alpha
        self.beta = beta
        self.scale = scale
        self.reduction = reduction

        meshgrid = np.meshgrid(range(self.width), range(self.height), indexing='xy')
        self.id_coords = np.stack(meshgrid, axis=0).astype(np.float32)
        self.id_coords = nn.Parameter(torch.from_numpy(self.id_coords),
                                      requires_grad=False)

        self.ones = nn.Parameter(torch.ones(self.batch_size, 1, self.height * self.width),
                                 requires_grad=False)

        self.pix_coords = torch.unsqueeze(torch.stack(
            [self.id_coords[0].view(-1), self.id_coords[1].view(-1)], 0), 0)
        self.pix_coords = self.pix_coords.repeat(batch_size, 1, 1)
        self.pix_coords = nn.Parameter(torch.cat([self.pix_coords, self.ones], 1),
                                       requires_grad=False)

    def _warp(self, src: torch.Tensor, flow: torch.Tensor) -> torch.Tensor:
        """Bilinearly sample ``src`` at ``p + flow(p)`` for every pixel ``p``.

        ``flow`` is (B, 2, H, W) in pixels; grid_sample needs a (B, H, W, 2) grid
        in [-1, 1] (align_corners=True convention).
        """
        base = self.pix_coords[:, :2].view(-1, 2, self.height, self.width)
        target = (base + flow).permute(0, 2, 3, 1)
        scale = target.new_tensor([self.width - 1, self.height - 1])
        grid = target / scale * 2 - 1
        return F.grid_sample(src, grid, padding_mode='border', align_corners=True)

    def forward(self, flow_fwd: torch.Tensor, flow_bwd: torch.Tensor):
        # Warp each flow into the other frame.
        bwd2fwd = self._warp(flow_bwd, flow_fwd)
        fwd2bwd = self._warp(flow_fwd, flow_bwd)

        # Consistency error (a consistent pair cancels out).
        diff_fwd = (bwd2fwd + flow_fwd).abs()
        diff_bwd = (fwd2bwd + flow_bwd).abs()

        # Occlusion masks are treated as constants (no gradient through the bounds).
        factor = 2 ** self.scale
        with torch.no_grad():
            bound_fwd = (self.beta * factor * flow_fwd.norm(p=2, dim=1, keepdim=True)).clamp_min(self.alpha)
            bound_bwd = (self.beta * factor * flow_bwd.norm(p=2, dim=1, keepdim=True)).clamp_min(self.alpha)
            noc_mask_tgt = factor * diff_fwd.norm(p=2, dim=1, keepdim=True) < bound_fwd
            noc_mask_src = factor * diff_bwd.norm(p=2, dim=1, keepdim=True) < bound_bwd

        # Masked means; clamp_min(1) avoids 0/0 when every pixel is masked out.
        loss_fwd = (diff_fwd.mean(dim=1, keepdim=True) * noc_mask_tgt).sum() / noc_mask_tgt.sum().clamp_min(1)
        loss_bwd = (diff_bwd.mean(dim=1, keepdim=True) * noc_mask_src).sum() / noc_mask_src.sum().clamp_min(1)
        consistency_loss = (loss_fwd + loss_bwd) / 2

        if self.reduction == 'mean':
            return consistency_loss.mean()
        elif self.reduction == 'sum':
            return consistency_loss.sum()
        elif self.reduction == 'none':
            return consistency_loss
        else:
            raise NotImplementedError


def run_once(f):
    """Decorator: call ``f`` only on its first invocation; later calls return None.

    The "already ran" flag is stored on the wrapper (``wrapper.has_run``). When used
    on a method, the flag is therefore shared by every instance of the class, so
    only the very first call across all instances executes.
    """
    import functools

    @functools.wraps(f)
    def wrapper(*args, **kwargs):
        if not wrapper.has_run:
            wrapper.has_run = True
            return f(*args, **kwargs)
        return None
    wrapper.has_run = False
    return wrapper


class GLNetLoss(nn.Module):
    """Self-supervised GLNet training objective.

    For every scale and every frame group that has both a pose and a forward
    flow, it sums:

    * a multi-view 3D structure loss (``mvs3d_loss``),
    * an epipolar loss on the predicted flow (``epipolar_loss``),
    * an adaptive photometric loss. This takes the per-pixel minimum of a
      rigid (depth + pose) warp and a flow warp (``adaptive_photometric_loss``).

    Optional disparity smoothness, flow smoothness and forward-backward flow
    consistency terms are added on top. The total is averaged over
    ``self.scales``.

    Flows are treated as displacements in pixel units of the scale they were
    predicted at: they are added to integer pixel coordinates before warping.
    """

    def __init__(
            self,
            img_size: tuple,
            scales: int,
            mvs_weight: float,
            epi_weight: float,
            apc_weight: float,
            disp_smooth: float,
            flow_smooth: float,
            flow_cons_params: tuple = None,
            flow_cons_weight: float = 0.,
            ssim_r: float = 0.85,
            reduction: str = 'mean'
    ):
        """
        Args:
            img_size: (height, width) of the full-resolution images.
            scales: number of pyramid scales. Scale ``s`` uses ``img_size // 2**s``.
            mvs_weight: weight of the multi-view 3D structure term.
            epi_weight: weight of the epipolar term.
            apc_weight: weight of the adaptive photometric term.
            disp_smooth: disparity smoothness weight (0 disables the term).
            flow_smooth: flow smoothness weight (0 disables the term).
            flow_cons_params: ``(alpha, beta)`` of the occlusion bound
                ``max(alpha, beta * 2**scale * |flow|)`` used by the consistency term.
            flow_cons_weight: weight of the consistency term (0 disables it).
            ssim_r: SSIM vs. L1 mixing ratio of the photometric term.
            reduction: 'mean', 'sum' or 'none'.
        """
        super(GLNetLoss, self).__init__()
        self.height = img_size[0]
        self.width = img_size[1]
        self.scales = scales
        self.mvs_weight = mvs_weight
        self.epi_weight = epi_weight
        self.apc_weight = apc_weight
        self.disp_smooth = disp_smooth
        self.flow_smooth = flow_smooth
        self.flow_cons_params = flow_cons_params
        self.flow_cons_weight = flow_cons_weight
        self.ssim_r = ssim_r
        self.reduction = reduction

        # Unbatched (1, 3, H*W) coordinates per scale. Batched copies are built
        # lazily (and rebuilt on batch/device/dtype changes) by __set_batch_size.
        self.batch_size = None
        self.pix_coords_pyramid = {}
        for scale in range(scales):
            self.pix_coords_pyramid[scale] = \
                self.__generate_pix_coords(self.height // (2 ** scale), self.width // (2 ** scale))

        self.ssim_f_3d = SSIM()
        self.ssim_f_flow = SSIM()

    def __generate_pix_coords(self, height: int, width: int):
        """Return homogeneous pixel coordinates of shape (1, 3, height*width) with rows (x, y, 1)."""
        meshgrid = np.meshgrid(range(width), range(height), indexing='xy')
        id_coords = np.stack(meshgrid, axis=0).astype(np.float32)
        id_coords = torch.from_numpy(id_coords)
        ones = torch.ones(1, 1, height * width)
        pix_coords = torch.unsqueeze(
            torch.stack([id_coords[0].view(-1), id_coords[1].view(-1)], 0)
            , 0
        )
        pix_coords = torch.cat([pix_coords, ones], 1)
        return pix_coords

    def __set_batch_size(self, batch_size: int, ref_tensor: torch.Tensor):
        """Make the coordinate pyramid match ``batch_size`` and ``ref_tensor``'s device/dtype.

        The pyramid is rebuilt whenever any of these change, for example a
        smaller final batch or moving between CPU and GPU. This check is per
        instance. The previous one-shot ``run_once`` flag was shared by all
        instances and never re-ran.
        """
        cached = self.pix_coords_pyramid.get(0)
        if (cached is not None
                and self.batch_size == batch_size
                and cached.shape[0] == batch_size
                and cached.device == ref_tensor.device
                and cached.dtype == ref_tensor.dtype):
            return
        self.batch_size = batch_size
        for scale in range(self.scales):
            coords = self.__generate_pix_coords(self.height // (2 ** scale), self.width // (2 ** scale))
            self.pix_coords_pyramid[scale] = coords.repeat(batch_size, 1, 1).to(ref_tensor)

    def _reduce(self, values: torch.Tensor) -> torch.Tensor:
        """Apply ``self.reduction`` ('mean', 'sum' or 'none') to ``values``."""
        if self.reduction == 'mean':
            return values.mean()
        if self.reduction == 'sum':
            return values.sum()
        if self.reduction == 'none':
            return values
        raise NotImplementedError(f"Unsupported reduction: {self.reduction}")

    def _flow_to_grid(self, flow: torch.Tensor, scale: int) -> torch.Tensor:
        """Turn a pixel-unit flow into a ``grid_sample`` grid (B, H, W, 2) in [-1, 1].

        Assumes ``flow`` is (B, 2, H, W) at ``scale``. The normalisation matches
        ``project_3d`` (``align_corners=True`` convention).
        """
        self.__set_batch_size(flow.shape[0], flow)
        height = self.height // (2 ** scale)
        width = self.width // (2 ** scale)
        base = self.pix_coords_pyramid[scale][:, :2].reshape(self.batch_size, 2, height, width)
        target = base + flow
        grid_x = 2.0 * target[:, 0] / max(width - 1, 1) - 1.0
        grid_y = 2.0 * target[:, 1] / max(height - 1, 1) - 1.0
        return torch.stack([grid_x, grid_y], dim=-1)

    def forward(
            self,
            inputs: dict,
            depths: dict,
            poses: dict,
            flows_fwd: dict,
            scales: int,
            disps: dict = None,
            flows_bwd: dict = None
    ):
        """Compute the total loss and its per-term breakdown.

        Args:
            inputs: batch dict read with keys ``('color', frame, scale)``,
                ``('K', scale)`` and ``('inv_K', scale)``.
            depths: ``{frame: {('depth', scale): tensor}}``.
            poses: ``{frame_group: {'axisangle': ..., 'translation': ...}}``.
            flows_fwd: ``{frame_group: {('flow', scale): tensor}}``.
            scales: number of scales to iterate over.
            disps: optional ``{frame: {('disp', scale): tensor}}`` for disparity smoothness.
            flows_bwd: optional backward flows keyed like ``flows_fwd``.

        Returns:
            ``(total_loss, loss_parts)``. Both are divided by ``self.scales``.
        """
        pose_frames = list(poses.keys())
        flow_frames = list(flows_fwd.keys())

        calc_disp_smooth = (disps is not None and self.disp_smooth != 0.)
        calc_flow_smooth = self.flow_smooth != 0.
        calc_flow_consistency = (flows_bwd is not None and self.flow_cons_weight != 0.)

        mvs_loss = 0.
        epipolar_loss = 0.
        apc_loss = 0.
        disp_smooth_loss = 0.
        flow_smooth_loss = 0.
        flow_consistency_loss = 0.

        for scale in range(scales):
            for frame_group in list(set(pose_frames) & set(flow_frames)):
                pose_params = poses[frame_group]
                flow_fwd = flows_fwd[frame_group][('flow', scale)]
                depth_src = depths[frame_group[0]][('depth', scale)]
                depth_tgt = depths[frame_group[-1]][('depth', scale)]
                inv_intrinsics = inputs[('inv_K', scale)]
                intrinsics = inputs[('K', scale)]
                image_src = inputs[('color', frame_group[0], scale)]
                image_tgt = inputs[('color', frame_group[-1], scale)]

                pose = transformation_from_parameters(pose_params['axisangle'], pose_params['translation'])

                # Multi-view 3D structure loss
                mvs_loss += self.mvs3d_loss(depth_src, depth_tgt, pose, inv_intrinsics, scale)

                # Epipolar loss
                epipolar_loss += self.epipolar_loss(flow_fwd, pose, inv_intrinsics, scale, pix_group_size=128)

                # Adaptive Photometric Loss
                apc_loss += self.adaptive_photometric_loss(image_src, image_tgt, depth_src, flow_fwd,
                                                           pose, intrinsics, inv_intrinsics, scale)

            if calc_disp_smooth:
                # Disparity smoothness loss on mean-normalised disparity
                for frame_id, disp_i in disps.items():
                    disp_e = disp_i[('disp', scale)]
                    mean_disp = disp_e.mean(2, True).mean(3, True)
                    norm_disp = disp_e / (mean_disp + 1e-7)
                    disp_smooth_loss += \
                        (self.disp_smooth / (2 ** scale) *
                         get_smooth_loss(norm_disp, inputs[('color', frame_id, scale)]))

            if calc_flow_smooth:
                # Flow smoothness loss (halved when backward flows also contribute)
                div = (2 ** (scale + 1)) if flows_bwd is not None else (2 ** scale)
                for frame_group, flow_fwd_i in flows_fwd.items():
                    for chan in range(2):
                        flow_smooth_loss += (self.flow_smooth / div *
                                             get_smooth_loss(flow_fwd_i[('flow', scale)][:, chan].unsqueeze(1),
                                                             inputs[('color', frame_group[-1], scale)]))
                        if flows_bwd is not None:
                            flow_bwd_i = flows_bwd[frame_group]
                            flow_smooth_loss += (self.flow_smooth / div *
                                                 get_smooth_loss(flow_bwd_i[('flow', scale)][:, chan].unsqueeze(1),
                                                                 inputs[('color', frame_group[0], scale)]))

            if calc_flow_consistency:
                # Forward-backward flow consistency
                for f_idx, flow_fwd_i in flows_fwd.items():
                    flow_bwd_i = flows_bwd[f_idx]
                    flow_consistency_loss += self.fwd_bwd_flow_consistency(flow_fwd_i[('flow', scale)],
                                                                           flow_bwd_i[('flow', scale)],
                                                                           scale)

        total_loss = self.mvs_weight * mvs_loss + self.epi_weight * epipolar_loss + self.apc_weight * apc_loss
        if calc_disp_smooth:
            total_loss += disp_smooth_loss
        if calc_flow_smooth:
            total_loss += flow_smooth_loss
        if calc_flow_consistency:
            total_loss += (self.flow_cons_weight * flow_consistency_loss)
        total_loss /= self.scales

        loss_parts = {
            'mvs': mvs_loss / self.scales,
            'epi': epipolar_loss / self.scales,
            'apc': apc_loss / self.scales,
            'ds': disp_smooth_loss / self.scales,
            'fs': flow_smooth_loss / self.scales,
            'fc': flow_consistency_loss / self.scales
        }
        return total_loss, loss_parts

    def back_project(self, depth: torch.Tensor, inv_intrinsics: torch.Tensor, scale: int) -> torch.Tensor:
        """Lift every pixel to a homogeneous 3D point: ``[depth * K^-1 (x, y, 1); 1]`` of shape (B, 4, H*W)."""
        self.__set_batch_size(depth.shape[0], depth)

        cam_points = torch.matmul(inv_intrinsics[:, :3, :3], self.pix_coords_pyramid[scale])
        cam_points = depth.view(self.batch_size, 1, -1) * cam_points
        cam_points = torch.cat([cam_points, self.pix_coords_pyramid[scale][:, -1].unsqueeze(1)], 1)

        return cam_points

    def project_3d(self, cloud: torch.Tensor, intrinsics: torch.Tensor, pose: torch.Tensor, scale: int) -> torch.Tensor:
        """Project homogeneous points through ``K @ pose`` into a ``grid_sample`` grid (B, H, W, 2) in [-1, 1]."""
        self.__set_batch_size(cloud.shape[0], cloud)

        P = torch.matmul(intrinsics, pose)[:, :3, :]
        cam_points = torch.matmul(P, cloud)
        pix_coords = cam_points[:, :2, :] / (cam_points[:, 2, :].unsqueeze(1) + 1e-6)
        pix_coords = pix_coords.view(-1, 2, self.height // (2 ** scale), self.width // (2 ** scale))
        pix_coords = pix_coords.permute(0, 2, 3, 1)
        pix_coords[..., 0] /= self.width // (2 ** scale) - 1
        pix_coords[..., 1] /= self.height // (2 ** scale) - 1
        pix_coords = (pix_coords - 0.5) * 2
        return pix_coords

    def mvs3d_loss(
            self,
            depth_src: torch.Tensor,
            depth_tgt: torch.Tensor,
            pose: torch.Tensor,
            inv_intrinsics: torch.Tensor,
            scale: int
    ):
        """L1 distance between the target point cloud and the pose-transformed source cloud (same pixel index)."""
        self.__set_batch_size(depth_src.shape[0], depth_src)
        cloud_src = self.back_project(depth_src, inv_intrinsics, scale)
        cloud_tgt = self.back_project(depth_tgt, inv_intrinsics, scale)
        return self._reduce((cloud_tgt - pose @ cloud_src).abs())

    def epipolar_loss(
            self,
            flow: torch.Tensor,
            pose: torch.Tensor,
            inv_intrinsics: torch.Tensor,
            scale: int,
            pix_group_size: int = 1,
    ):
        """Per-pixel epipolar residual ``|x_i^T K^-T R [t]_x K^-1 x'_i|`` where ``x'_i = x_i + flow_i``.

        Only the matching pair (pixel ``i`` and its own flow target) is
        constrained. The previous matrix form also summed every cross-pixel
        pair ``(i, j)``, which has no geometric meaning.

        Args:
            flow: flow at ``scale``; assumed (B, 2, H, W) in pixel units.
            pose: (B, 4, 4) relative transform.
            inv_intrinsics: (B, 4, 4); only the top-left 3x3 block is used.
            scale: pyramid scale.
            pix_group_size: kept for backward compatibility. The per-pixel
                form needs no chunking, so the value is ignored.

        Returns:
            Residuals of shape (B, H*W), reduced according to ``self.reduction``.
        """
        self.__set_batch_size(pose.shape[0], pose)
        pix_coords = self.pix_coords_pyramid[scale]
        rotation = pose[:, :3, :3]
        translation = pose[:, :3, -1]

        # Translation skew-matrix [t]_x
        t_skew = torch.zeros_like(rotation)
        t_skew[:, 0, 1] = -translation[:, 2]
        t_skew[:, 0, 2] = translation[:, 1]
        t_skew[:, 1, 2] = -translation[:, 0]
        t_skew = t_skew - t_skew.transpose(1, 2)

        # Move pixels with flow (homogeneous coordinates)
        flattened_flow = flow.reshape(flow.shape[0], flow.shape[1], -1)
        pix_flow = torch.ones_like(pix_coords)
        pix_flow[:, :2] = pix_coords[:, :2] + flattened_flow

        inv_k = inv_intrinsics[:, :3, :3]
        epi_matrix = inv_k.transpose(1, 2) @ rotation @ t_skew @ inv_k
        # Row-wise dot product == diagonal of pix_coords^T @ epi_matrix @ pix_flow.
        residual = (pix_coords * (epi_matrix @ pix_flow)).sum(dim=1).abs()
        return self._reduce(residual)

    def adaptive_photometric_loss(
            self,
            img_src: torch.Tensor,
            img_tgt: torch.Tensor,
            depth: torch.Tensor,
            flow: torch.Tensor,
            pose: torch.Tensor,
            intrinsics: torch.Tensor,
            inv_intrinsics: torch.Tensor,
            scale: int
    ):
        """Per-pixel minimum of the SSIM/L1 error of a rigid warp and of a flow warp of ``img_src``."""
        self.__set_batch_size(pose.shape[0], pose)
        r = self.ssim_r

        # 3D (depth + pose) warp of the source image
        cam_coord_src = self.back_project(depth, inv_intrinsics, scale)
        pix_coords_tgt = self.project_3d(cam_coord_src, intrinsics, pose, scale)
        warped_tgt_3d = F.grid_sample(img_src, pix_coords_tgt, padding_mode='border', align_corners=True)

        # Flow warp: grid_sample needs normalised [-1, 1] coordinates, not raw pixels.
        warped_tgt_flow = F.grid_sample(img_src, self._flow_to_grid(flow, scale),
                                        padding_mode='border', align_corners=True)

        # Pixel-wise minimum of SSIM/L1 maps
        ssim_3d = self.ssim_f_3d(img_tgt, warped_tgt_3d).mean(1, keepdim=True)
        ssim_flow = self.ssim_f_flow(img_tgt, warped_tgt_flow).mean(1, keepdim=True)

        l1_3d = (img_tgt - warped_tgt_3d).abs().mean(1, keepdim=True)
        l1_flow = (img_tgt - warped_tgt_flow).abs().mean(1, keepdim=True)

        s_3d = r * (1 - ssim_3d) / 2 + (1 - r) * l1_3d
        s_flow = r * (1 - ssim_flow) / 2 + (1 - r) * l1_flow

        apc_loss = torch.stack([s_3d, s_flow], dim=1).min(dim=1)[0]
        return self._reduce(apc_loss)

    def fwd_bwd_flow_consistency(self, flow_fwd: torch.Tensor, flow_bwd: torch.Tensor, scale: int):
        """Masked forward-backward flow consistency at ``scale``.

        Each flow is sampled at ``pixel + other_flow`` and added to the other
        flow. Pixels whose residual exceeds ``max(alpha, beta * |flow|)`` (at
        full resolution) are masked out. Returns 0 rather than NaN when every
        pixel is masked.
        """
        self.__set_batch_size(flow_fwd.shape[0], flow_fwd)
        alpha = self.flow_cons_params[0]
        beta = self.flow_cons_params[1]

        # Warp each flow to the other frame's pixel grid
        bwd2fwd = F.grid_sample(flow_bwd, self._flow_to_grid(flow_fwd, scale),
                                padding_mode='border', align_corners=True)
        fwd2bwd = F.grid_sample(flow_fwd, self._flow_to_grid(flow_bwd, scale),
                                padding_mode='border', align_corners=True)

        # Consistency error
        diff_fwd = (bwd2fwd + flow_fwd).abs()
        diff_bwd = (fwd2bwd + flow_bwd).abs()

        # Condition (bound is not back-propagated through the clamp)
        bound_fwd = beta * (2 ** scale) * flow_fwd.norm(p=2, dim=1, keepdim=True)
        with torch.no_grad():
            bound_fwd = bound_fwd.clamp_min(alpha)

        bound_bwd = beta * (2 ** scale) * flow_bwd.norm(p=2, dim=1, keepdim=True)
        with torch.no_grad():
            bound_bwd = bound_bwd.clamp_min(alpha)

        # Mask
        noc_mask_src = ((2 ** scale) * diff_bwd.norm(p=2, dim=1, keepdim=True) < bound_bwd)
        noc_mask_tgt = ((2 ** scale) * diff_fwd.norm(p=2, dim=1, keepdim=True) < bound_fwd)

        # Consistency loss (pixel count clamped to avoid 0 / 0)
        loss_fwd = (diff_fwd.mean(dim=1, keepdim=True) * noc_mask_tgt).sum() / noc_mask_tgt.sum().clamp_min(1)
        loss_bwd = (diff_bwd.mean(dim=1, keepdim=True) * noc_mask_src).sum() / noc_mask_src.sum().clamp_min(1)
        consistency_loss = (loss_fwd + loss_bwd) / 2
        return consistency_loss

## Model

### Modules

#### Depth decoder

In [None]:
"""
Modified from https://github.com/nianticlabs/monodepth2/blob/master/networks/depth_decoder.py
and https://github.com/nianticlabs/monodepth2/blob/master/layers.py
"""

# Copyright Niantic 2019. Patent Pending. All rights reserved.
#
# This software is licensed under the terms of the Monodepth2 licence
# which allows for non-commercial use only, the full terms of which are made
# available in the LICENSE file.


import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F
from collections import OrderedDict


class ConvBlock(nn.Module):
    """A ``Conv3x3`` layer whose output is passed through an in-place ELU."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = Conv3x3(in_channels, out_channels)
        self.nonlin = nn.ELU(inplace=True)

    def forward(self, x):
        return self.nonlin(self.conv(x))


class Conv3x3(nn.Module):
    """Pad by one pixel (reflection by default, zeros otherwise), then apply a 3x3 convolution.

    The one-pixel pad keeps the spatial size unchanged.
    """

    def __init__(self, in_channels, out_channels, use_refl=True):
        super().__init__()
        pad_layer = nn.ReflectionPad2d if use_refl else nn.ZeroPad2d
        self.pad = pad_layer(1)
        self.conv = nn.Conv2d(int(in_channels), int(out_channels), 3)

    def forward(self, x):
        padded = self.pad(x)
        return self.conv(padded)


def upsample(x, size=None):
    """Nearest-neighbour resize of ``x`` to ``size`` if given, otherwise by a factor of 2."""
    resize_kwargs = {'scale_factor': 2} if size is None else {'size': size}
    return F.interpolate(x, mode="nearest", **resize_kwargs)


class DepthDecoder(nn.Module):
    """U-Net style decoder that turns encoder features into sigmoid disparity maps.

    Args:
        num_ch_enc: channel count of each of the five encoder feature levels.
        scales: decoder levels at which a ``("disp", level)`` output is produced.
        num_output_channels: channels of each disparity map.
        use_skips: concatenate the matching encoder feature at each level.
    """

    def __init__(self, num_ch_enc, scales=range(4), num_output_channels=1, use_skips=True):
        super().__init__()

        self.num_output_channels = num_output_channels
        self.use_skips = use_skips
        self.upsample_mode = 'nearest'
        self.scales = scales

        self.num_ch_enc = num_ch_enc
        self.num_ch_dec = np.array([16, 32, 64, 128, 256])

        # Creation order below fixes the ModuleList (and hence state_dict) layout.
        convs = OrderedDict()
        for level in reversed(range(5)):
            first_in = self.num_ch_enc[-1] if level == 4 else self.num_ch_dec[level + 1]
            convs[("upconv", level, 0)] = ConvBlock(first_in, self.num_ch_dec[level])

            second_in = self.num_ch_dec[level]
            if self.use_skips and level > 0:
                second_in += self.num_ch_enc[level - 1]
            convs[("upconv", level, 1)] = ConvBlock(second_in, self.num_ch_dec[level])

        for s in self.scales:
            convs[("dispconv", s)] = Conv3x3(self.num_ch_dec[s], self.num_output_channels)
        self.convs = convs

        self.decoder = nn.ModuleList(list(self.convs.values()))
        self.sigmoid = nn.Sigmoid()

    def forward(self, input_features):
        """Decode ``input_features`` (shallow to deep) into ``{("disp", level): tensor}``."""
        self.outputs = {}

        x = input_features[-1]
        for level in reversed(range(5)):
            x = self.convs[("upconv", level, 0)](x)
            has_lower = level > 0
            # Match the next encoder feature's size exactly; the last level doubles.
            target_size = input_features[level - 1].shape[-2:] if has_lower else None
            pieces = [upsample(x, target_size)]
            if self.use_skips and has_lower:
                pieces.append(input_features[level - 1])
            x = self.convs[("upconv", level, 1)](torch.cat(pieces, 1))
            if level in self.scales:
                self.outputs[("disp", level)] = self.sigmoid(self.convs[("dispconv", level)](x))

        return self.outputs

#### Flow decoder

In [None]:
"""
Modified from https://github.com/yzcjtr/GeoNet/blob/master/geonet_nets.py
"""

import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F
from collections import OrderedDict


class ConvBlock(nn.Module):
    """A ``Conv3x3`` layer whose output is passed through an in-place ELU."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = Conv3x3(in_channels, out_channels)
        self.nonlin = nn.ELU(inplace=True)

    def forward(self, x):
        return self.nonlin(self.conv(x))


class Conv3x3(nn.Module):
    """Pad by one pixel (reflection by default, zeros otherwise), then apply a 3x3 convolution.

    The one-pixel pad keeps the spatial size unchanged.
    """

    def __init__(self, in_channels, out_channels, use_refl=True):
        super().__init__()
        pad_layer = nn.ReflectionPad2d if use_refl else nn.ZeroPad2d
        self.pad = pad_layer(1)
        self.conv = nn.Conv2d(int(in_channels), int(out_channels), 3)

    def forward(self, x):
        padded = self.pad(x)
        return self.conv(padded)


def upsample(x, size=None):
    """Nearest-neighbour resize of ``x`` to ``size`` if given, otherwise by a factor of 2."""
    resize_kwargs = {'scale_factor': 2} if size is None else {'size': size}
    return F.interpolate(x, mode="nearest", **resize_kwargs)


class FlowDecoder(nn.Module):
    """U-Net style decoder producing 2-channel-per-frame flow maps at several levels.

    Outputs are the raw ``flowconv`` activations multiplied by ``flow_scale`` (0.1).

    Args:
        num_ch_enc: channel count of each of the five encoder feature levels.
        scales: decoder levels at which a ``("flow", level)`` output is produced.
        num_output_frames: number of flows predicted (2 channels each).
        use_skips: concatenate the matching encoder feature at each level.
    """

    def __init__(self, num_ch_enc, scales=range(4), num_output_frames=1, use_skips=True):
        super().__init__()

        self.num_output_frames = num_output_frames
        self.use_skips = use_skips
        self.upsample_mode = 'nearest'
        self.scales = scales

        self.num_ch_enc = num_ch_enc
        self.num_ch_dec = np.array([16, 32, 64, 128, 256])

        self.flow_scale = 0.1

        # Creation order below fixes the ModuleList (and hence state_dict) layout.
        convs = OrderedDict()
        for level in reversed(range(5)):
            first_in = self.num_ch_enc[-1] if level == 4 else self.num_ch_dec[level + 1]
            convs[("upconv", level, 0)] = ConvBlock(first_in, self.num_ch_dec[level])

            second_in = self.num_ch_dec[level]
            if self.use_skips and level > 0:
                second_in += self.num_ch_enc[level - 1]
            convs[("upconv", level, 1)] = ConvBlock(second_in, self.num_ch_dec[level])

        for s in self.scales:
            convs[("flowconv", s)] = Conv3x3(self.num_ch_dec[s], 2 * self.num_output_frames)
        self.convs = convs

        self.decoder = nn.ModuleList(list(self.convs.values()))

    def forward(self, input_features):
        """Decode ``input_features`` (shallow to deep) into ``{("flow", level): tensor}``."""
        self.outputs = {}

        x = input_features[-1]
        for level in reversed(range(5)):
            # Unlike DepthDecoder, upsampling happens before the first conv.
            x = upsample(x, input_features[level - 1].shape[-2:] if level else None)
            pieces = [self.convs[("upconv", level, 0)](x)]
            if self.use_skips and level:
                pieces.append(input_features[level - 1])
            x = self.convs[("upconv", level, 1)](torch.cat(pieces, 1))
            if level in self.scales:
                self.outputs[("flow", level)] = self.flow_scale * self.convs[("flowconv", level)](x)

        return self.outputs

#### PoseCNN

In [None]:
"""
From https://github.com/nianticlabs/monodepth2/blob/master/networks/pose_cnn.py
"""

# Copyright Niantic 2019. Patent Pending. All rights reserved.
#
# This software is licensed under the terms of the Monodepth2 licence
# which allows for non-commercial use only, the full terms of which are made
# available in the LICENSE file.


import torch
import torch.nn as nn


class PoseCNN(nn.Module):
    """Small strided CNN regressing relative camera motion from stacked frames.

    The input is ``num_input_frames`` RGB images concatenated on the channel
    axis. Each output frame gets 6 parameters (axis-angle rotation +
    translation), or 8 when ``intrinsics`` is True (2 extra values returned as
    ``'intrinsics'``). Raw predictions are scaled by 0.01.
    """

    def __init__(self, num_input_frames, intrinsics=False, num_output_frames=None):
        super(PoseCNN, self).__init__()

        self.num_input_frames = num_input_frames
        self.num_output_frames = (num_input_frames - 1) if num_output_frames is None else num_output_frames
        self.intrinsics = intrinsics
        self.output_parameters = 6 if not intrinsics else 8

        self.convs = {}
        self.convs[0] = nn.Conv2d(3 * num_input_frames, 16, 7, 2, 3)
        self.convs[1] = nn.Conv2d(16, 32, 5, 2, 2)
        self.convs[2] = nn.Conv2d(32, 64, 3, 2, 1)
        self.convs[3] = nn.Conv2d(64, 128, 3, 2, 1)
        self.convs[4] = nn.Conv2d(128, 256, 3, 2, 1)
        self.convs[5] = nn.Conv2d(256, 256, 3, 2, 1)
        self.convs[6] = nn.Conv2d(256, 256, 3, 2, 1)

        self.pose_conv = nn.Conv2d(256, self.output_parameters * self.num_output_frames, 1)

        self.num_convs = len(self.convs)

        self.relu = nn.ReLU(True)

        # Registers the plain-dict convs as sub-modules.
        self.net = nn.ModuleList(list(self.convs.values()))

    def forward(self, out):
        """Return a dict with 'axisangle', 'translation' (and 'intrinsics') tensors."""
        for i in range(self.num_convs):
            out = self.convs[i](out)
            out = self.relu(out)

        out = self.pose_conv(out)
        out = out.mean(3).mean(2)

        out = 0.01 * out.view(-1, self.num_output_frames, 1, self.output_parameters)

        # Layout per frame: [axisangle(3) | translation(3) | intrinsics(2, optional)].
        # Translation is always 3 values; the intrinsics branch previously sliced 4.
        axisangle = out[..., :3].squeeze(1)
        translation = out[..., 3:6].squeeze(1)
        if not self.intrinsics:
            return {'axisangle': axisangle, 'translation': translation}
        intrinsics = out[..., 6:].squeeze(1)
        return {'axisangle': axisangle, 'translation': translation, 'intrinsics': intrinsics}


#### ResNet encoder

In [None]:
"""
Modified from https://github.com/nianticlabs/monodepth2/blob/master/networks/resnet_encoder.py
"""

# Copyright Niantic 2019. Patent Pending. All rights reserved.
#
# This software is licensed under the terms of the Monodepth2 licence
# which allows for non-commercial use only, the full terms of which are made
# available in the LICENSE file.

import numpy as np

import torch
import torch.nn as nn
import torchvision.models as models
import torch.utils.model_zoo as model_zoo


class ResNetMultiImageInput(models.ResNet):
    """ResNet whose first convolution takes ``num_input_images`` stacked RGB frames.

    Adapted from https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py
    """
    def __init__(self, block, layers, num_classes=1000, num_input_images=1):
        super(ResNetMultiImageInput, self).__init__(block, layers)
        self.inplanes = 64
        self.conv1 = nn.Conv2d(
            num_input_images * 3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Rebuild layer1..layer4 in order; _make_layer advances self.inplanes.
        stage_specs = ((64, 1), (128, 2), (256, 2), (512, 2))
        for stage_idx, (planes, stride) in enumerate(stage_specs):
            stage = self._make_layer(block, planes, layers[stage_idx], stride=stride)
            setattr(self, "layer{}".format(stage_idx + 1), stage)

        # Kaiming init for convolutions, identity-like init for batch norms.
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(module, nn.BatchNorm2d):
                nn.init.ones_(module.weight)
                nn.init.zeros_(module.bias)


def resnet_multiimage_input(num_layers, pretrained=False, num_input_images=1):
    """Build a ResNet-18/50 accepting ``num_input_images`` stacked frames.

    Args:
        num_layers (int): ResNet depth, 18 or 50.
        pretrained (bool): load ImageNet weights. The first conv kernel is tiled
            across the stacked frames and divided by the frame count.
        num_input_images (int): number of frames stacked on the channel axis.
    """
    assert num_layers in [18, 50], "Can only run with 18 or 50 layer resnet"
    layer_counts, block_cls = {
        18: ([2, 2, 2, 2], models.resnet.BasicBlock),
        50: ([3, 4, 6, 3], models.resnet.Bottleneck),
    }[num_layers]
    model = ResNetMultiImageInput(block_cls, layer_counts, num_input_images=num_input_images)

    if not pretrained:
        return model

    state = model_zoo.load_url(models.resnet.model_urls['resnet{}'.format(num_layers)])
    first_conv = state['conv1.weight']
    state['conv1.weight'] = torch.cat([first_conv] * num_input_images, 1) / num_input_images
    model.load_state_dict(state)
    return model


class ResnetEncoder(nn.Module):
    """Pytorch module for a resnet encoder returning five feature levels.

    Args:
        num_layers: ResNet depth (18, 34, 50, 101 or 152).
        pretrained: load ImageNet weights.
        num_input_images: frames stacked on the channel axis. Values above 1
            use ``resnet_multiimage_input``, which only supports depths 18 and 50.

    Raises:
        ValueError: for an unsupported ``num_layers``.
    """
    def __init__(self, num_layers, pretrained, num_input_images=1):
        super(ResnetEncoder, self).__init__()

        self.num_ch_enc = np.array([64, 64, 128, 256, 512])

        resnets = {18: models.resnet18,
                   34: models.resnet34,
                   50: models.resnet50,
                   101: models.resnet101,
                   152: models.resnet152}

        if num_layers not in resnets:
            raise ValueError("{} is not a valid number of resnet layers".format(num_layers))

        if num_input_images > 1:
            self.encoder = resnet_multiimage_input(num_layers, pretrained, num_input_images)
        else:
            # Honour the caller's flag; this was previously hard-coded to True.
            self.encoder = resnets[num_layers](pretrained=pretrained)

        # Bottleneck ResNets expand channels by 4 after the stem.
        if num_layers > 34:
            self.num_ch_enc[1:] *= 4

    def forward(self, input_image, norm_input=False):
        """Return ``[stem, layer1, layer2, layer3, layer4]`` features.

        When ``norm_input`` is True the input is first normalised as
        ``(x - 0.45) / 0.225``.
        """
        self.features = []
        if norm_input:
            x = (input_image - 0.45) / 0.225
        else:
            x = input_image
        x = self.encoder.conv1(x)
        x = self.encoder.bn1(x)
        self.features.append(self.encoder.relu(x))
        self.features.append(self.encoder.layer1(self.encoder.maxpool(self.features[-1])))
        self.features.append(self.encoder.layer2(self.features[-1]))
        self.features.append(self.encoder.layer3(self.features[-1]))
        self.features.append(self.encoder.layer4(self.features[-1]))

        return self.features

### GLNet model 

In [None]:
import numpy as np
import torch
from torch import nn
from itertools import combinations

class GLNet:
    """Bundle of the GLNet pose, depth and flow networks with their loss and optimiser.

    Depth is predicted for every window of ``depth_input_num`` consecutive
    frames and keyed by the window's middle frame id. Pose and forward/backward
    flow are predicted for every combination (of size ``pose_input_num`` /
    ``flow_input_num``) of those middle frames. Training minimises
    ``GLNetLoss`` with a single Adam optimiser over all sub-networks.
    """

    def __init__(
            self,
            pose_input_num: int,
            pose_output_num: int,
            depth_input_num: int,
            depth_output_num: int,
            flow_input_num: int,
            flow_output_num: int,
            loss_parameters: dict,
            pred_intrinsics: bool = False,
            depth_res_layers: int = 18,
            depth_use_skips: bool = True,
            flow_res_layers: int = 18,
            flow_use_skips: bool = True,
            resnet_pretrained: bool = True,
            shared_resnet: bool = True,
            depth_limits: tuple = (0.1, 100.),
            frame_ids: list = None,
            scales: int = 4
    ):
        self.shared_resnet = shared_resnet
        self.intrinsics = pred_intrinsics

        self.pose_input_num = pose_input_num
        self.depth_input_num = depth_input_num
        self.flow_input_num = flow_input_num

        self.scales = scales

        # Kept as a NumPy array: it is fancy-indexed with lists of positions later.
        self.frame_ids = np.sort(frame_ids) if frame_ids is not None else np.array([0, 1])
        self.depth_limits = depth_limits

        if shared_resnet and \
                (depth_input_num != flow_input_num or depth_res_layers != flow_res_layers):
            raise AttributeError("In case of shared resnet the flow and depth input nums must match, "
                                 "as well as the num of res layers!")

        self.camera_net = PoseCNN(pose_input_num, pred_intrinsics, pose_output_num)

        self.depth_encoder = ResnetEncoder(depth_res_layers, resnet_pretrained, depth_input_num)
        self.depth_decoder = DepthDecoder(self.depth_encoder.num_ch_enc, range(scales),
                                          depth_output_num, depth_use_skips)

        if self.shared_resnet:
            self.flow_encoder = self.depth_encoder
        else:
            self.flow_encoder = ResnetEncoder(flow_res_layers, resnet_pretrained, flow_input_num)

        self.flow_decoder = FlowDecoder(self.flow_encoder.num_ch_enc, range(scales),
                                        flow_output_num, flow_use_skips)

        self.glnet_loss = GLNetLoss(**loss_parameters)

        parameter_list = [
            {'params': self.depth_decoder.parameters()},
            {'params': self.depth_encoder.parameters()},
            {'params': self.camera_net.parameters()}
        ]
        if not self.shared_resnet:
            parameter_list += [{'params': self.flow_encoder.parameters()}]
        # The flow decoder must be optimised too. It was previously omitted, so
        # its randomly initialised weights were never updated.
        parameter_list += [{'params': self.flow_decoder.parameters()}]
        self.optimizer = torch.optim.Adam(parameter_list, lr=2e-4, betas=(0.9, 0.999))

    def _networks(self):
        """Return the distinct sub-networks; a shared encoder is listed once."""
        nets = [self.depth_decoder, self.depth_encoder, self.flow_decoder]
        if not self.shared_resnet:
            nets.append(self.flow_encoder)
        nets.append(self.camera_net)
        return nets

    def train(self):
        for net in self._networks():
            net.train()

    def eval(self):
        for net in self._networks():
            net.eval()

    def cuda(self):
        for net in self._networks():
            net.cuda()

    def cpu(self):
        for net in self._networks():
            net.cpu()

    def _disps_to_depths(self, disps: dict) -> dict:
        """Map ``{frame: {('disp', s)}}`` to ``{frame: {('depth', s)}}``.

        Each depth is element ``[1]`` of ``disp2depth`` (presumably the depth
        map; verify against ``disp2depth``).
        """
        min_depth, max_depth = self.depth_limits[0], self.depth_limits[1]
        return {
            frame: {('depth', scale): disp2depth(disp[('disp', scale)], min_depth, max_depth)[1]
                    for scale in range(self.scales)}
            for frame, disp in disps.items()
        }

    def training_step(self, batch: dict, *args, **kwargs):
        """Run one optimisation step on ``batch`` (uses ``('color_aug', frame, 0)`` images)."""
        self.train()

        all_color_aug = torch.stack([batch[('color_aug', i, 0)] for i in self.frame_ids], dim=1)
        disps, poses, flows_fwd, flows_bwd = self.__predict_for_train_val(all_color_aug)
        depths = self._disps_to_depths(disps)

        loss, loss_parts = self.glnet_loss(batch, depths, poses, flows_fwd, self.scales, disps, flows_bwd)

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        loss_parts['total_loss'] = loss
        return loss_parts

    @torch.no_grad()
    def validation_step(self, batch: dict, *args, **kwargs):
        """Evaluate the loss on ``batch`` (uses un-augmented ``('color', frame, 0)`` images)."""
        self.eval()

        all_color = torch.stack([batch[('color', i, 0)] for i in self.frame_ids], dim=1)
        disps, poses, flows_fwd, flows_bwd = self.__predict_for_train_val(all_color)
        depths = self._disps_to_depths(disps)

        loss, loss_parts = self.glnet_loss(batch, depths, poses, flows_fwd, self.scales, disps, flows_bwd)

        loss_parts['total_loss'] = loss
        return loss_parts

    def __predict_for_train_val(self, image_stack: torch.Tensor, is_train: bool = True):
        """Run all networks on ``image_stack`` (frames on dim 1) and return their outputs.

        Returns ``(disps, poses, flows_fwd, flows_bwd)``. Disparities are keyed
        by the middle frame id of each depth window; poses and flows are keyed by
        tuples of frame ids. Backward flows (frames reversed) are only computed
        when ``is_train`` is True.
        """
        inp_shape = image_stack.shape
        key_features = []

        disps = {}
        poses = {}
        flows_fwd = {}
        flows_bwd = {}

        # Depth estimation over sliding windows of frames
        for f_idx in range(len(self.frame_ids) - self.depth_input_num + 1):
            frame_ids = self.frame_ids[f_idx:f_idx + self.depth_input_num]
            key_features.append(f_idx + (self.depth_input_num - 1) // 2)
            disps[frame_ids[(self.depth_input_num - 1) // 2]] = \
                self.depth_decoder(
                    self.depth_encoder(
                        image_stack[:, f_idx:f_idx + self.depth_input_num]
                        .view(inp_shape[0], -1, *inp_shape[-2:]).contiguous()
                    )
                )

        # Pose estimation
        for feature_group in combinations(key_features, self.pose_input_num):
            fg_list = list(feature_group)
            frame_group = self.frame_ids[fg_list]

            poses[tuple(frame_group)] = self.camera_net(
                image_stack[:, fg_list]
                .view(inp_shape[0], -1, *inp_shape[-2:]).contiguous()
            )

        # Forward and Backward Flow estimation
        for feature_group in combinations(key_features, self.flow_input_num):
            fg_list = list(feature_group)
            frame_group = self.frame_ids[fg_list]

            flows_fwd[tuple(frame_group)] = \
                self.flow_decoder(
                    self.flow_encoder(
                        image_stack[:, fg_list]
                        .view(inp_shape[0], -1, *inp_shape[-2:]).contiguous()
                    )
                )

            if is_train:
                flows_bwd[tuple(frame_group)] = \
                    self.flow_decoder(
                        self.flow_encoder(
                            image_stack[:, fg_list[::-1]]
                            .view(inp_shape[0], -1, *inp_shape[-2:]).contiguous()
                        )
                    )

        return disps, poses, flows_fwd, flows_bwd



# Main

## Config

In [None]:
import os
from datetime import datetime

# Use CUDA whenever PyTorch can see a GPU.
ON_GPU = torch.cuda.is_available()

# Data pipeline settings consumed by setup_dataloaders().
DATA_CFG = dict(
    batch_size=16,
    img_size=(240, 320),  # (height, width) as unpacked by setup_dataloaders()
    frame_idxs=[-1, 0, 1, 2],
    scales=4,

    train_txt=os.path.join(os.getcwd(), "train_files.txt"),
    val_txt=os.path.join(os.getcwd(), "val_files.txt"),
    test_txt=None,
    num_workers=6,
)


# Loss hyper-parameters; image size and scale count mirror DATA_CFG.
GLNET_LOSS_CFG = dict(
    img_size=DATA_CFG['img_size'],
    scales=DATA_CFG['scales'],
    mvs_weight=10.0,
    epi_weight=10.0,
    apc_weight=1.0,
    disp_smooth=0.5,
    flow_smooth=0.2,
    flow_cons_params=(3.0, 0.05),
    flow_cons_weight=0.2,
    ssim_r=0.85,
    reduction='mean',
)


# Keyword arguments for GLNet(**GLNET_CFG). The frame list and loss config are
# shared by reference with DATA_CFG / GLNET_LOSS_CFG, not copied.
GLNET_CFG = dict(
    # Camera (pose) network
    pose_input_num=2,
    pose_output_num=1,
    pred_intrinsics=False,
    resnet_pretrained=True,

    # Depth network
    depth_input_num=3,
    depth_output_num=1,
    depth_res_layers=18,
    depth_use_skips=True,

    # Optical-flow network
    flow_input_num=2,
    flow_output_num=1,
    flow_res_layers=18,
    flow_use_skips=True,

    shared_resnet=False,
    frame_ids=DATA_CFG['frame_idxs'],
    scales=DATA_CFG['scales'],
    depth_limits=(0.1, 50.0),
    loss_parameters=GLNET_LOSS_CFG,
)


# Training-loop settings read by main(): validation runs every
# `eval_every_n_steps` batches on the first `eval_ratio` fraction of the
# validation loader.
TRAIN_CFG = dict(
    epochs=10,
    eval_every_n_steps=993,
    eval_ratio=0.1,
)

## Main file

In [None]:

import os
from tqdm.notebook import tqdm

from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter


# Every run writes into its own timestamped folder (DDMMYY_HHMMSS) under
# ./logs; main() also saves its per-epoch checkpoints there.
log_path = os.path.join(os.getcwd(), 'logs', datetime.now().strftime('%d%m%y_%H%M%S'))
# exist_ok: re-running this cell within the same second yields the same
# timestamp and must not raise FileExistsError.
os.makedirs(log_path, exist_ok=True)

# TensorBoard logging is off by default; set USE_TENSORBOARD = True to write
# scalars into `log_path`. When off, `logger` is None and logging is skipped.
USE_TENSORBOARD = False
logger = SummaryWriter(log_path) if USE_TENSORBOARD else None


def setup_dataloaders():
    """Build the training and validation DataLoaders described by DATA_CFG.

    Reads the file lists named by ``DATA_CFG['train_txt']`` and
    ``DATA_CFG['val_txt']``. Each list is wrapped in a ``TartanAirDataset``
    built with the configured image size, frame indices and scale count. The
    training dataset gets ``is_train=True`` and the validation dataset gets
    ``is_train=False``.

    Returns:
        tuple: ``(train_loader, val_loader)``.
    """
    def readlines(filename):
        """Return the non-blank lines of a UTF-8 text file, without newlines."""
        with open(filename, 'r', encoding='utf-8') as f:
            # Skip blank lines so they never become empty dataset entries.
            return [line for line in f.read().splitlines() if line.strip()]

    def make_loader(filenames, is_train):
        """Wrap `filenames` in a TartanAirDataset and a DataLoader."""
        dataset = TartanAirDataset(
            data_path=None,
            filenames=filenames,
            height=DATA_CFG['img_size'][0],
            width=DATA_CFG['img_size'][1],
            frame_idxs=DATA_CFG['frame_idxs'],
            num_scales=DATA_CFG['scales'],
            is_train=is_train
        )
        return DataLoader(
            dataset,
            batch_size=DATA_CFG['batch_size'],
            # NOTE(review): the training loader is not shuffled either; confirm
            # this is intentional (the split step also has SHUFFLE_* = False).
            shuffle=False,
            num_workers=DATA_CFG['num_workers'],
            pin_memory=True
        )

    # Read both lists up front so a missing file fails before any dataset is built.
    train_file_names = readlines(DATA_CFG['train_txt'])
    val_file_names = readlines(DATA_CFG['val_txt'])

    train_loader = make_loader(train_file_names, is_train=True)
    val_loader = make_loader(val_file_names, is_train=False)
    return train_loader, val_loader



def validate(model, epoch, val_loader, partition=1.0, logger=None, step=None):
    """Evaluate `model` on the leading `partition` fraction of `val_loader`.

    Args:
        model: object with ``eval()`` and ``validation_step(batch)``. The latter
            must return a dict of loss values (each supporting ``.item()``)
            that includes ``'total_loss'``.
        epoch: zero-based epoch index. It is shown in the progress bar and used
            as the logging step when `step` is None.
        val_loader: sized iterable of dict batches (e.g. a DataLoader).
        partition: fraction of the batches to evaluate (1.0 = all). A
            non-empty loader always evaluates at least one batch.
        logger: optional writer with ``add_scalar``. When given, the average
            total loss is logged under ``'Validation/Avg Loss'``.
        step: step to log at; defaults to `epoch`.

    Returns:
        The average ``'total_loss'`` over the evaluated batches, or None if no
        batch was evaluated (empty loader).

    Side effects:
        Leaves `model` in eval mode; callers that continue training must call
        ``model.train()``. Batch tensors are moved to the GPU in place when
        ON_GPU is set.
    """
    # NOTE(review): no torch.no_grad() here; confirm validation_step disables
    # autograd itself, otherwise validation builds unnecessary graphs.
    model.eval()
    num_batches = len(val_loader)
    # Clamp to [1, num_batches] so a small `partition` cannot yield an empty
    # evaluation (and a division by zero when averaging).
    val_len = min(num_batches, max(1, int(partition * num_batches))) if num_batches else 0
    val_loss = 0.
    num_evaluated = 0
    with tqdm(total=val_len, leave=False) as pbar:
        for bidx, batch in enumerate(val_loader):
            # Evaluate only the first `val_len` batches (indices 0 .. val_len - 1).
            if bidx >= val_len:
                break

            # Data and model
            if ON_GPU:
                for key, value in batch.items():
                    batch[key] = value.cuda()

            val_losses = model.validation_step(batch)

            # Overall loss
            val_loss += val_losses['total_loss'].item()
            num_evaluated += 1

            # Update progress bar to contain all loss elements
            pbar_dict = {
                'Epoch': f'{epoch+1}/{TRAIN_CFG["epochs"]}',
                'AvgLoss': val_loss / num_evaluated
            }
            for key, value in val_losses.items():
                pbar_dict[key] = value.item()
            pbar.set_postfix(pbar_dict)
            pbar.update(1)

    if num_evaluated == 0:
        return None

    # Average over the batches actually evaluated, not the requested count.
    avg_loss = val_loss / num_evaluated
    if logger is not None:
        logger.add_scalar('Validation/Avg Loss', avg_loss, epoch if step is None else step)
    return avg_loss


def main():
    """Train GLNet for ``TRAIN_CFG['epochs']`` epochs, checkpointing each epoch.

    Builds the loaders with `setup_dataloaders` and calls
    ``model.training_step`` on every training batch. Validation, on the
    ``TRAIN_CFG['eval_ratio']`` fraction of the validation set, runs every
    ``TRAIN_CFG['eval_every_n_steps']`` batches and again at the end of each
    epoch. All scalars are logged against a global step that keeps increasing
    across epochs. Sub-network and optimizer state dicts are saved to
    ``<log_path>/model_epoch{epoch}.pth``.
    """
    train_loader, val_loader = setup_dataloaders()
    steps_per_epoch = len(train_loader)
    eval_every = TRAIN_CFG['eval_every_n_steps']

    model = GLNet(**GLNET_CFG)
    if ON_GPU:
        model.cuda()

    for epoch in range(TRAIN_CFG['epochs']):
        model.train()
        epoch_loss = 0.

        # Training phase
        with tqdm(total=steps_per_epoch, leave=False) as pbar:
            for bidx, batch in enumerate(train_loader):
                # Global step, shared by train and validation logging so curves
                # from different epochs do not overwrite each other.
                global_step = epoch * steps_per_epoch + bidx

                # Data and model
                if ON_GPU:
                    for key, value in batch.items():
                        batch[key] = value.cuda()

                # No backward/optimizer call here: training_step presumably
                # performs the update itself; confirm in GLNet.
                losses = model.training_step(batch)

                # Validate every n steps if configured
                if (val_loader is not None and eval_every is not None
                        and (bidx + 1) % eval_every == 0):
                    validate(model, epoch, val_loader, TRAIN_CFG['eval_ratio'],
                             logger, global_step)
                    # validate() leaves the model in eval mode.
                    model.train()

                # Log loss elements every step
                if logger is not None:
                    logger.add_scalars('Train/Losses', losses, global_step)

                # Overall loss
                epoch_loss += losses['total_loss'].item()

                # Update progress bar to contain loss elements
                pbar_dict = {
                    'Epoch': f'{epoch+1}/{TRAIN_CFG["epochs"]}',
                    'AvgLoss': epoch_loss / (bidx + 1)
                }
                for key, value in losses.items():
                    pbar_dict[key] = value.item()
                pbar.set_postfix(pbar_dict)
                pbar.update(1)

        # End-of-epoch validation; pass the logger so the result is recorded
        # (the progress bar uses leave=False and disappears).
        if val_loader is not None:
            validate(model, epoch, val_loader, TRAIN_CFG['eval_ratio'],
                     logger, (epoch + 1) * steps_per_epoch)

        torch.save(
            {
                'model': {
                    'depth_decoder': model.depth_decoder.state_dict(),
                    'depth_encoder': model.depth_encoder.state_dict(),
                    'flow_decoder': model.flow_decoder.state_dict(),
                    'flow_encoder': model.flow_encoder.state_dict(),
                    'camera_net': model.camera_net.state_dict()},
                'optimizer': model.optimizer.state_dict(),
                'epoch': epoch,
                # Sum of per-batch total losses over the epoch (kept for
                # compatibility with existing checkpoints)...
                'train_loss': epoch_loss,
                # ...and the per-batch average, which is comparable across runs.
                'train_loss_avg': epoch_loss / max(steps_per_epoch, 1)
            },
            os.path.join(log_path, f'model_epoch{epoch}.pth')
        )


# In Jupyter __name__ is '__main__', so the cell still trains when executed;
# importing this file as a module no longer starts training.
if __name__ == '__main__':
    main()
