## Project 7: self-supervised single-image depth-estimation using neural-networks

### Authors:
#### ◘ Sharhad Bashar
#### ◘ Lizhe Chen
#### ◘ Genséric Ghiro
#### ◘ Futian Zhang

## 1) Abstract: A few short sentences highlighting the main points of the project.

## 2) Introduction (3-4 paragraphs): reviewing your topic, related technical ideas/algorithms, your selected methodology/approach, and its motivation.

## 3) Contributions section: This section should have one separate short paragraph (or a bullet) dedicated to each co-author clearly indicating her/his specific contributions. 

#### Sharhad Bashar:
       •
       •
       •
       
#### Lizhe Chen:
       •
       •
       •
       
#### Genséric Ghiro:
       •
       •
       •
       
#### Futian Zhang:
       •
       •
       •

## 4) Outline section: the overall structure/organization of the report

5aa) Imports
5a) Dataloader
5b) Network
5c) Validator
5d) Trainer
5e) Training
6) Results display
7) Conclusion
8) References

## 5aa) Imports

In [11]:
%matplotlib inline

In [18]:
from __future__ import absolute_import, division, print_function

import os
import glob
import random
import numpy as np
import copy
from PIL import Image 
import torch
import time
from matplotlib import pyplot as plt

#5a)
import torch.utils.data as data
from torchvision import transforms
import torchvision.transforms.functional as tF

#5b)
import torch.nn as nn
import torch.nn.functional as F
import importlib
import torchvision.models as models
from loss import MonodepthLoss #make sure loss.py is in folder

#5c) 
import torch.optim as optim
from torch.utils.data import DataLoader

#5d)
import pickle

#5e)
from torch.utils.data.dataloader import *

## 5a) Dataloader

In [13]:
def pil_loader(path):
    """Read the image file at ``path`` and return it converted to RGB.

    The file is opened explicitly and handed to PIL inside ``with`` blocks so
    the handle is closed on exit, which avoids a ResourceWarning
    (https://github.com/python-pillow/Pillow/issues/835).
    """
    with open(path, 'rb') as stream, Image.open(stream) as raw:
        rgb = raw.convert('RGB')
    return rgb
        
class JointRandomFlip(object):
    """Randomly mirror a (left, right) pair as one joint operation."""

    def __call__(self, L, R):
        """Return the pair untouched, or mirrored when a uniform draw exceeds 0.5.

        In the mirrored case both inputs are flipped horizontally and they
        swap positions: the result is ``(hflip(R), hflip(L))``.
        """
        keep_as_is = not (np.random.random_sample() > 0.5)
        if keep_as_is:
            return (L, R)

        mirrored_right = tF.hflip(R)
        mirrored_left = tF.hflip(L)
        return (mirrored_right, mirrored_left)
    
class JointRandomColorAug(object):
    """Apply one shared random colour augmentation to both inputs of a pair.

    Each of ``gamma``, ``brightness`` and ``color_shift`` is a ``(low, high)``
    range from which a value is drawn uniformly with ``np.random``.
    """

    def __init__(self,gamma=(0.8,1.2),brightness=(0.5,2.0),color_shift=(0.8,1.2)):
        self.gamma = gamma
        self.brightness = brightness
        self.color_shift = color_shift

    def __call__(self, L, R):
        """Return ``(L, R)`` unchanged, or augmented when a uniform draw exceeds 0.5.

        The augmentation raises each input to a random power, scales it by a
        random brightness, scales the first three entries along dim 0 by
        individual random factors, and clamps the result to [0, 1]. The same
        random values are used for ``L`` and ``R``; the inputs themselves are
        not modified.
        """
        if not (np.random.random_sample() > 0.5):
            return L, R

        # Draw in the order gamma -> brightness -> colours.
        exponent = np.random.uniform(*self.gamma)
        gain = np.random.uniform(*self.brightness)
        low, high = self.color_shift[0], self.color_shift[1]
        channel_shifts = np.random.uniform(low, high, 3)

        augmented = []
        for view in (L, R):
            out = (view ** exponent) * gain
            for channel, shift in enumerate(channel_shifts):
                out[channel, :, :] *= shift
            # saturate
            augmented.append(torch.clamp(out, 0, 1))

        return augmented[0], augmented[1]

class JointToTensor(object):
    """Convert both members of a pair with ``tF.to_tensor``."""

    def __call__(self, L, R):
        """Return ``(to_tensor(L), to_tensor(R))``."""
        left_tensor = tF.to_tensor(L)
        right_tensor = tF.to_tensor(R)
        return left_tensor, right_tensor
    
class JointToImage(object):
    """Convert both members of a pair with ``transforms.ToPILImage``."""

    def __call__(self, L, R):
        """Return ``(ToPILImage()(L), ToPILImage()(R))``."""
        to_pil = transforms.ToPILImage()
        left_image = to_pil(L)
        right_image = transforms.ToPILImage()(R)
        return left_image, right_image
    
    
class JointCompose(object):
    """Chain several two-argument transforms into a single callable.

    Instances are used like a function: ``JointCompose(steps)(img, target)``.
    """

    def __init__(self, transforms):
        """
        params:
           transforms (list) : callables taking two inputs and returning two outputs
        """
        self.transforms = transforms

    def __call__(self, img, target):
        """
        Feed the pair through every transform in order and return the result.

        params:
            img    : first input; must expose a ``size`` attribute
            target : second input; its ``size`` must equal ``img.size``
        """
        assert img.size == target.size
        first, second = img, target
        for step in self.transforms:
            first, second = step(first, second)
        return first, second


class TwoViewDataset(data.Dataset):
    """Dataset of left/right image pairs stored in two sibling folders.

    Images are read from ``<data_path>/<split>/image_left`` and
    ``<data_path>/<split>/image_right`` where ``<split>`` is ``train`` when
    ``is_train`` is true and ``val`` otherwise. Only non-hidden ``.jpg`` files
    are used. Both folders are sorted by file name and paired by position, so
    corresponding left/right images must share a file name (or at least sort
    identically).

    params:
        data_path (str)       : dataset root folder
        resize_shape (tuple)  : size passed to ``PIL.Image.resize``
        is_train (bool)       : choose the ``train`` or the ``val`` split
        transforms (callable) : optional joint transform called with two images
        sanity_check          : accepted for compatibility; not used
    """

    def __init__(self, 
                 data_path,
                 resize_shape=(512,256), 
                 is_train=False,
                 transforms=None,
                 sanity_check=None):
        super(TwoViewDataset, self).__init__()
        self.data_path = data_path

        # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the same filter.
        # NOTE: the value is only stored; resize() below uses Pillow's default.
        self.interp = Image.LANCZOS
        self.resize_shape = resize_shape
        self.is_train = is_train
        self.transforms=transforms
        self.loader = pil_loader

        split = "train" if is_train else "val"
        self.imgR_folder = os.path.join(data_path, split, "image_right")
        self.imgL_folder = os.path.join(data_path, split, "image_left")

        # Sorted listings: os.listdir order is arbitrary, and an unsorted
        # listing can pair a left image with the wrong right image.
        self.imgR = self._list_images(self.imgR_folder)
        self.imgL = self._list_images(self.imgL_folder)

    @staticmethod
    def _list_images(folder):
        """Return sorted paths of the non-hidden ``.jpg`` files in ``folder``."""
        return sorted(
            os.path.join(folder, name)
            for name in os.listdir(folder)
            if name.lower().endswith(".jpg") and not name.startswith(".")
        )

    def get_color(self, path, do_flip):
        """Load ``path`` as RGB, optionally mirror it, and return it as a tensor."""
        color = self.loader(path)
        if do_flip:
            color = color.transpose(Image.FLIP_LEFT_RIGHT)
        return tF.to_tensor(color)

    def __len__(self):
        """Number of complete pairs (limited by the shorter of the two lists)."""
        return min(len(self.imgL), len(self.imgR))

    def __getitem__(self, index):
        """Return the ``(left, right)`` pair at ``index`` after resizing.

        When ``transforms`` is set it is called as ``transforms(right, left)``
        and its two results are unpacked in that same order.
        """
        # self.loader closes the underlying file once the image is decoded.
        colorR = self.loader(self.imgR[index]).resize(self.resize_shape)
        colorL = self.loader(self.imgL[index]).resize(self.resize_shape)

        if self.transforms is not None:
            colorR, colorL = self.transforms(colorR, colorL)
        return colorL, colorR

## 5b) Network

In [14]:
class get_disp(nn.Module):
    """Two-channel output head: 3x3 conv -> batch norm -> sigmoid, scaled by 0.3."""

    def __init__(self, num_in_channels):
        super(get_disp, self).__init__()
        # One pixel of padding on every side, so the unpadded 3x3 convolution
        # keeps the input's height and width.
        self.p2d = (1, 1, 1, 1)
        layers = [
            nn.Conv2d(num_in_channels, 2, kernel_size=3, stride=1),
            nn.BatchNorm2d(2),
            torch.nn.Sigmoid(),
        ]
        self.disp = nn.Sequential(*layers)

    def forward(self, x):
        """Return ``0.3 * sigmoid(bn(conv(pad(x))))``; values lie in [0, 0.3]."""
        padded = F.pad(x, self.p2d)
        activated = self.disp(padded)
        return 0.3 * activated


class iconv(nn.Module):
    """Padded convolution followed by batch norm and ELU."""

    def __init__(self, num_in_channels, num_out_channels, kernel_size, stride):
        super(iconv, self).__init__()
        # Pad floor((k - 1) / 2) pixels on each of the four sides.
        half = int(np.floor((kernel_size - 1) / 2))
        self.p2d = (half,) * 4

        conv = nn.Conv2d(num_in_channels, num_out_channels,
                         kernel_size=kernel_size, stride=stride)
        norm = nn.BatchNorm2d(num_out_channels)
        self.iconv = nn.Sequential(conv, norm)

    def forward(self, x):
        """Return ``elu(bn(conv(pad(x))))``."""
        features = self.iconv(F.pad(x, self.p2d))
        return F.elu(features, inplace=True)

class upconv(nn.Module):
    """Bilinear resize by ``scale`` followed by a stride-1 ``iconv`` block."""

    def __init__(self, num_in_channels, num_out_channels, kernel_size, scale):
        super(upconv, self).__init__()
        self.scale = scale
        self.conv1 = iconv(num_in_channels, num_out_channels, kernel_size, 1)

    def forward(self, x):
        """Resize ``x`` by ``self.scale`` and run it through ``self.conv1``."""
        resized = nn.functional.interpolate(
            x,
            scale_factor=self.scale,
            mode='bilinear',
            align_corners=True,
        )
        return self.conv1(resized)

class ResnetDispModel(nn.Module):
    """Encoder-decoder network with a torchvision ResNet-18 encoder.

    ``forward`` returns four two-channel maps (``disp1`` .. ``disp4``), each
    produced by a ``get_disp`` head at a different decoder stage.

    params:
        num_input_channel : stored on the instance; the first convolution is
                            taken from ResNet-18 unchanged
        encoder (str)     : accepted for compatibility; ResNet-18 is always used
        pretrained (bool) : forwarded to ``models.resnet18``
        criterion         : stored on the instance; not used in this class
    """

    def __init__(self, num_input_channel, encoder='resnet18', pretrained=True, criterion=None):
        super(ResnetDispModel, self).__init__()
        self.criterion = criterion
        self.num_input_channel = num_input_channel
        resnet = models.resnet18(pretrained=pretrained)
        filters_res18 = [64, 128, 256, 512]
        # Children 1..3 of the torchvision ResNet, i.e. the modules between
        # conv1 and layer1 (bn1, relu, maxpool in torchvision's ResNet
        # definition -- verify against the installed version).
        resnet_pool1 = list(resnet.children())[1:4]

        self.conv1 = resnet.conv1
        self.maxpool = nn.Sequential(*resnet_pool1)

        self.layer1 = resnet.layer1
        self.layer2 = resnet.layer2
        self.layer3 = resnet.layer3
        self.layer4 = resnet.layer4

        self.upconv6 = upconv(filters_res18[3], 512, 3, 2)
        self.iconv6 = iconv(filters_res18[2] + 512, 512, 3, 1)

        self.upconv5 = upconv(512, 256, 3, 2)
        self.iconv5 = iconv(filters_res18[1] + 256, 256, 3, 1)

        self.upconv4 = upconv(256, 128, 3, 2)
        self.iconv4 = iconv(filters_res18[0] + 128, 128, 3, 1)
        self.disp4_layer = get_disp(128)

        # Scale 1: this stage keeps the resolution of the previous one.
        self.upconv3 = upconv(128, 64, 3, 1)
        self.iconv3 = iconv(64 + 64 + 2, 64, 3, 1)
        self.disp3_layer = get_disp(64)

        self.upconv2 = upconv(64, 32, 3, 2)
        self.iconv2 = iconv(64 + 32 + 2, 32, 3, 1)
        self.disp2_layer = get_disp(32)

        self.upconv1 = upconv(32, 16, 3, 2)
        self.iconv1 = iconv(16 + 2, 16, 3, 1)
        self.disp1_layer = get_disp(16)

        # Xavier-initialise convolutions, but leave the encoder alone when it
        # was loaded with pretrained weights: re-initialising every Conv2d
        # would overwrite those weights and make ``pretrained=True`` a no-op
        # for the convolutions.
        pretrained_conv_ids = set()
        if pretrained:
            encoder_parts = (self.conv1, self.maxpool, self.layer1,
                             self.layer2, self.layer3, self.layer4)
            for part in encoder_parts:
                for module in part.modules():
                    if isinstance(module, nn.Conv2d):
                        pretrained_conv_ids.add(id(module))

        for m in self.modules():
            if isinstance(m, nn.Conv2d) and id(m) not in pretrained_conv_ids:
                nn.init.xavier_uniform_(m.weight)

    def forward(self, x):
        """Run the encoder and decoder; return ``(disp1, disp2, disp3, disp4)``.

        The intermediate maps are also stored on the instance as ``disp1`` ..
        ``disp4`` and ``udisp2`` .. ``udisp4``.
        NOTE(review): the skip concatenations require matching spatial sizes,
        which presumably means the input height and width must be divisible by
        32 -- confirm against the inputs used.
        """
        # encoder
        x_conv1 = self.conv1(x)
        x_pool1 = self.maxpool(x_conv1)
        x1 = self.layer1(x_pool1)
        x2 = self.layer2(x1)
        x3 = self.layer3(x2)
        x4 = self.layer4(x3)
        # skips
        skip1 = x_conv1
        skip2 = x_pool1
        skip3 = x1
        skip4 = x2
        skip5 = x3

        # decoder
        upconv6 = self.upconv6(x4)
        concat6 = torch.cat((upconv6, skip5), 1)
        iconv6 = self.iconv6(concat6)

        upconv5 = self.upconv5(iconv6)
        concat5 = torch.cat((upconv5, skip4), 1)
        iconv5 = self.iconv5(concat5)

        upconv4 = self.upconv4(iconv5)
        concat4 = torch.cat((upconv4, skip3), 1)
        iconv4 = self.iconv4(concat4)
        self.disp4 = self.disp4_layer(iconv4)
        # udisp4 stays at the current resolution (scale 1) for the next
        # concatenation; the returned disp4 is the half-resolution version.
        self.udisp4 = nn.functional.interpolate(self.disp4, scale_factor=1, mode='bilinear', align_corners=True)
        self.disp4 = nn.functional.interpolate(self.disp4, scale_factor=0.5, mode='bilinear', align_corners=True)

        upconv3 = self.upconv3(iconv4)
        concat3 = torch.cat((upconv3, skip2, self.udisp4), 1)
        iconv3 = self.iconv3(concat3)
        self.disp3 = self.disp3_layer(iconv3)
        self.udisp3 = nn.functional.interpolate(self.disp3, scale_factor=2, mode='bilinear', align_corners=True)

        upconv2 = self.upconv2(iconv3)
        concat2 = torch.cat((upconv2, skip1, self.udisp3), 1)
        iconv2 = self.iconv2(concat2)
        self.disp2 = self.disp2_layer(iconv2)
        self.udisp2 = nn.functional.interpolate(self.disp2, scale_factor=2, mode='bilinear', align_corners=True)

        upconv1 = self.upconv1(iconv2)
        concat1 = torch.cat((upconv1, self.udisp2), 1)
        iconv1 = self.iconv1(concat1)
        self.disp1 = self.disp1_layer(iconv1)
        
        return self.disp1, self.disp2, self.disp3, self.disp4


#ResnetDispModel(3)

## 5c) Validator

In [16]:
class Validator:
    """Evaluate a network on a validation loader with ``MonodepthLoss``.

    params:
        val_loader  : iterable yielding ``(left, right)`` batches
        batch_size  : divisor used when normalising the accumulated loss
        params_file : stored on the instance; not used in this class
        use_gpu     : run on ``cuda:0`` when true, otherwise on the CPU
    """

    def __init__(self, val_loader, batch_size, params_file=None, use_gpu=False):
        self.use_gpu = use_gpu
        self.params_file = params_file
        self.val_loader = val_loader
        self.device = "cuda:0" if use_gpu else "cpu"
        self.loss = MonodepthLoss(
            n=4,
            SSIM_w=0.85,
            disp_gradient_w=0.1, lr_w=1).to(self.device)
        # Every per-batch loss ever computed, across all validate() calls.
        self.val_losses = []
        self.batch_size = batch_size

    def validate(self, network):
        """Return the summed batch losses divided by ``batch_size * n_batches``.

        The network is switched to eval mode and is left in that mode. Each
        batch loss is also appended to ``self.val_losses``.

        Raises ZeroDivisionError if the loader yields no batches.
        """
        network.eval()
        if self.use_gpu:
            # Move once up front instead of once per batch.
            network = network.to(self.device)

        total_loss = 0
        counter = 0
        # Validation needs no gradients; skipping autograd saves memory.
        with torch.no_grad():
            for left, right in self.val_loader:
                if self.use_gpu:
                    left = left.to(self.device)
                    right = right.to(self.device)

                model_outputs = network(left)

                batch_loss = self.loss(model_outputs, [left, right]).item()
                self.val_losses.append(batch_loss)
                total_loss += batch_loss
                counter += 1

        total_loss /= self.batch_size * counter

        return total_loss

## 5d) Trainer

In [17]:
class Trainer:
    """Train a network on left/right batches with ``MonodepthLoss``.

    params:
        network      : the model to optimise
        train_loader : iterable yielding ``(left, right)`` batches
        optimizer    : optimizer built over ``network``'s parameters
        batch_size   : divisor used when normalising the accumulated loss
        params_file  : optional state-dict file loaded at the start of
                       ``run_train``
        use_gpu      : run on ``cuda:0`` when true, otherwise on the CPU
    """

    def __init__(self, network, train_loader, optimizer, batch_size, params_file=None, use_gpu=False):
        self.net = network
        self.use_gpu = use_gpu
        self.optimizer = optimizer
        self.validator = None
        self.history = {"Train": [], "Val": []}
        self.params_file = params_file
        self.train_loader = train_loader
        self.batch_size = batch_size
        self.device = "cuda:0" if use_gpu else "cpu"

        self.loss_function = MonodepthLoss(
            n=4,
            SSIM_w=0.85,
            disp_gradient_w=0.1, lr_w=1).to(self.device)

    def setValidator(self, validator):
        """Attach the validator that ``run_train`` uses after each epoch."""
        self.validator = validator

    def saveParams(self, path):
        """Save the network's state dict to ``path``."""
        torch.save(self.net.state_dict(), path)

    def loadModel(self, path):
        """Load a state dict from ``path`` into the network.

        ``map_location`` remaps the stored tensors to this trainer's device,
        so parameters saved on a GPU can be loaded on a CPU-only machine.
        """
        state = torch.load(path, map_location=self.device)
        self.net.load_state_dict(state)

    def train(self):
        """Run one pass over ``train_loader``.

        Returns the summed batch losses divided by ``n_batches * batch_size``.
        Raises ZeroDivisionError if the loader yields no batches.
        """
        self.net.train()
        if self.use_gpu:
            # Move once up front instead of once per batch.
            self.net = self.net.to(self.device)

        total_loss = 0.0
        counter = 0
        for left, right in self.train_loader:
            if self.use_gpu:
                left = left.to(self.device)
                right = right.to(self.device)

            self.optimizer.zero_grad()
            disps = self.net(left)
            loss = self.loss_function(disps, [left, right])
            loss.backward()
            self.optimizer.step()
            total_loss += loss.item()
            counter += 1

        main_loss = total_loss / (counter * self.batch_size)

        return main_loss

    def run_train(self, epoch):
        """Train for ``epoch`` epochs, validating after each one if possible.

        When a validator is attached, the parameters are written to
        ``params.pkl`` each time the validation score improves on the best
        score so far (the baseline is a validation run before training).
        The loss history is pickled to ``train_history.pickle`` after every
        epoch.
        """
        if self.params_file:
            self.loadModel(self.params_file)
        prev_score = np.inf
        if self.validator:
            prev_score = self.validator.validate(self.net)

        for e in range(epoch):

            loss = self.train()
            print("Epoch: {} Loss: {}".format(e, loss))
            self.history["Train"].append(loss)

            if self.validator:
                val_score = self.validator.validate(self.net)
                self.history["Val"].append(val_score)
                if val_score < prev_score:
                    print("update model file with prev_score {} and current score {}".format(prev_score, val_score))
                    self.saveParams('params.pkl')
                    prev_score = val_score

            # Rewritten every epoch so an interrupted run keeps its history.
            with open('train_history.pickle', 'wb') as handle:
                pickle.dump(self.history, handle, protocol=pickle.HIGHEST_PROTOCOL)

    def copyNetwork(self):
        """Return a deep copy of the current network."""
        return copy.deepcopy(self.net)

## 5e) Training Network

In [None]:
# Build the datasets and loaders, train for 50 epochs, then grab one
# validation batch for the results cell.
use_gpu = False

val_dataset = TwoViewDataset("data/dataset/", is_train=False, transforms=JointToTensor())
trn_dataset = TwoViewDataset("data/dataset/", is_train=True, transforms=JointToTensor())
val_loader = data.DataLoader(val_dataset, batch_size=1, num_workers=1, shuffle=False)
trn_loader = data.DataLoader(trn_dataset, batch_size=8, num_workers=1, shuffle=False)

network = ResnetDispModel(3)
# use_gpu must be passed by keyword: the third positional parameter of both
# Validator and Trainer is params_file, not use_gpu.
val = Validator(val_loader, 1, use_gpu=use_gpu)

opt = torch.optim.SGD(network.parameters(), lr=1e-2, weight_decay=1e-6,momentum=0.5, nesterov=False)

trn = Trainer(network, trn_loader, opt, 8, use_gpu=use_gpu)

trn.setValidator(val)
trn.run_train(50)

trained_net = trn.net


# Keep the left image of the first validation batch as the display sample.
for left, right in val_loader:
    sample = left
    break

sample_np = np.array(sample)
sample_np.shape
# sample.size()
# sample.size()

## 6) Comparing Results

In [None]:

# Run the trained network on the sample batch and show channel 0 of each of
# the four output maps side by side.
trained_net.eval()
# Use whatever device the network lives on, so this also works on the CPU.
sample_device = next(trained_net.parameters()).device
with torch.no_grad():
    disp1, disp2, disp3, disp4 = trained_net(sample.to(sample_device))

# With the batch of one produced by val_loader, squeeze() drops the batch
# axis and [0] then selects the first of the two output channels.
disp1_np = np.squeeze(disp1.cpu().detach().numpy())[0]
disp2_np = np.squeeze(disp2.cpu().detach().numpy())[0]
disp3_np = np.squeeze(disp3.cpu().detach().numpy())[0]
disp4_np = np.squeeze(disp4.cpu().detach().numpy())[0]

# One subplot per map: repeated plt.imshow calls on the same axes would leave
# only the last image visible.
fig, axes = plt.subplots(2, 2, figsize=(12, 6))
panels = [("disp1", disp1_np), ("disp2", disp2_np),
          ("disp3", disp3_np), ("disp4", disp4_np)]
for ax, (title, disp_np) in zip(axes.ravel(), panels):
    ax.imshow(disp_np, cmap="plasma")
    ax.set_title(title)
    ax.axis("off")
plt.tight_layout()
plt.show()

## 7) Conclusion section (2-4 paragraphs): summarizing your observations, results, etc.

## 8) References

### • Godard, C., Mac Aodha, O., & Brostow, G. (2019). Unsupervised Monocular Depth Estimation with Left-Right Consistency. Retrieved 13 December 2019, from https://arxiv.org/abs/1609.03677

### • OniroAI, MonoDepth-PyTorch, (2018), GitHub repository, https://github.com/OniroAI/MonoDepth-PyTorch