In [None]:
#@title mount drive
# Mount Google Drive at /content/drive (Colab only). The imports cell changes
# into /content/drive/MyDrive/DepthEstimation/, so this cell must run first.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
#@title imports
%%capture
%cd /content/drive/MyDrive/DepthEstimation/
%ls

%load_ext autoreload
%autoreload 2
%load_ext tensorboard

import matplotlib.pyplot as plt
import os
from glob import glob 
import numpy as np
import json
import ipywidgets as widgets
from torchvision import models
import torch

# internal codes
from torch_implementation.scripts.dataloaders import DepthDataLoader
# from torch_implementation.scripts.models.ordinary_unet import OrdinaryUNet

# from torch_implementation.scripts.utils import good
# from torch_implementation.scripts.utils import update_train_filenames_file

from torch_implementation.config import Config

In [None]:
#@title OrdinaryUnet model
class UpConv(torch.nn.Module):
    """One decoder stage: upsample, concatenate a skip connection, two 3x3 convs.

    The skip-connection tensor is not an argument of ``forward``. It is
    captured by ``forward_hook_callback`` when that bound method is registered
    as a forward hook on an encoder module. It is stored in
    ``self.feature_map`` until the next encoder pass overwrites it.

    Args:
        num_decode_filters: output channels of this stage. conv1 expects the
            incoming decoder tensor to carry ``2 * num_decode_filters``
            channels.
        feature_map_num_filters: channel count of the captured skip tensor.
    """

    def __init__(self, num_decode_filters, feature_map_num_filters):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(
            feature_map_num_filters + num_decode_filters * 2,
            num_decode_filters,
            3,
            padding='same')
        self.conv2 = torch.nn.Conv2d(
            num_decode_filters,
            num_decode_filters,
            3,
            padding='same')
        self.leaky_relu = torch.nn.LeakyReLU(0.2)
        # Filled in by forward_hook_callback during the encoder's forward pass.
        self.feature_map = None

    def forward_hook_callback(self, *hook_args):
        """Store the hooked module's output as this stage's skip connection.

        PyTorch invokes forward hooks as ``hook(module, input, output)``. The
        older two-argument form ``(input, output)`` is accepted too. Only the
        last argument (the module output) is used.

        The output is intentionally NOT detached. A detached skip tensor would
        stop gradients from this decoder stage reaching the encoder.
        """
        if not hook_args:
            raise TypeError("forward_hook_callback expects at least the module output")
        self.feature_map = hook_args[-1]

    def forward(self, input):
        """Upsample ``input`` to the skip tensor's size, concat on channels, conv twice.

        Args:
            input: decoder tensor, presumably (N, 2*num_decode_filters, H, W).

        Returns:
            Tensor with ``num_decode_filters`` channels at the skip tensor's
            spatial size ('same' padding keeps H and W).

        Raises:
            RuntimeError: if no skip tensor has been captured yet.
        """
        skip = self.feature_map
        if skip is None:
            raise RuntimeError(
                "UpConv.feature_map is unset: register forward_hook_callback "
                "as a forward hook and run the encoder first.")
        # Resize to the skip tensor's exact size rather than a fixed 2x, so odd
        # encoder sizes still concatenate. When the skip is exactly twice the
        # input size, this equals UpsamplingBilinear2d(scale_factor=2), which
        # also uses align_corners=True.
        upsampled = torch.nn.functional.interpolate(
            input, size=tuple(skip.shape[-2:]), mode='bilinear', align_corners=True)
        merged = torch.cat([upsampled, skip], dim=1)  # (N,C1,H,W)+(N,C2,H,W) -> (N,C1+C2,H,W)
        hidden = self.leaky_relu(self.conv1(merged))
        return self.leaky_relu(self.conv2(hidden))

class UnetDecoder(torch.nn.Module):
    """Decoder head: ReLU + 1x1 conv, a stack of UpConv stages, then a 1-channel map.

    Args:
        num_decode_filters: channel widths. Element 0 is the incoming encoder
            width (and conv1's width). Elements 1.. are the UpConv stage widths.
        num_feature_map_filters: skip-connection channel count for each UpConv
            stage, paired in order with ``num_decode_filters[1:]``.
        max_depth: scale applied to the sigmoid output (default 80.0).
        eps: offset added after scaling to keep the prediction away from 0.

    ``forward`` returns ``nearest_upsample_2x(sigmoid(conv2(...))) * max_depth
    + eps``. Pass ``max_depth=1.0, eps=0.0`` to get the raw sigmoid map in
    [0, 1].

    NOTE(review): the previous version computed this scaled value but returned
    the unscaled sigmoid map. Confirm that the loss and targets expect metric
    depth.
    """

    def __init__(self, num_decode_filters, num_feature_map_filters,
                 max_depth=80.0, eps=1e-5):
        super().__init__()
        self.max_depth = max_depth
        self.eps = eps
        self.conv1 = torch.nn.Conv2d(
            num_decode_filters[0],
            num_decode_filters[0],
            kernel_size=1,
            padding='same')
        self.relu = torch.nn.ReLU()
        self.upconvlist = torch.nn.ModuleList(
            UpConv(decode_filters, skip_filters)
            for decode_filters, skip_filters in zip(num_decode_filters[1:],
                                                    num_feature_map_filters))
        self.conv2 = torch.nn.Conv2d(
            num_decode_filters[-1],
            1,
            kernel_size=3,
            padding='same')
        self.sigmoid = torch.nn.Sigmoid()
        self.upsample = torch.nn.UpsamplingNearest2d(scale_factor=2)

    def forward(self, input):
        """Decode encoder features into a depth map twice the last stage's size.

        Each UpConv stage must already hold its skip tensor (captured by a
        forward hook during the encoder pass).
        """
        upconv = self.conv1(self.relu(input))
        for layer in self.upconvlist:
            upconv = layer(upconv)
        normalized = self.upsample(self.sigmoid(self.conv2(upconv)))
        return normalized * self.max_depth + self.eps

class OrdinaryUnet(torch.nn.Module):
    """DenseNet-169 encoder + UnetDecoder, wired with forward-hook skip connections.

    Forward hooks on four encoder modules hand their outputs to the matching
    UpConv stage, so a single ``forward`` call populates every skip connection
    before the decoder runs.

    Args:
        pretrained: load ImageNet weights into the DenseNet-169 backbone
            (default True, the previous hard-coded behavior).
    """

    def __init__(self, pretrained=True):
        super().__init__()
        self.encoder = self._build_densenet169_features(pretrained)

        # Width of the encoder's final feature map (1664 for DenseNet-169).
        # It is fixed by the architecture, not by the input resolution.
        num_decode_filters = self.encoder.norm5.num_features
        self.decoder = UnetDecoder(
            [num_decode_filters // 2 ** k for k in range(5)],
            [self.encoder.transition2[-2].out_channels,
             self.encoder.transition1[-2].out_channels,
             self.encoder.conv0.out_channels,
             self.encoder.conv0.out_channels])

        # Encoder module feeding each UpConv stage, deepest first. The trailing
        # comments give the Keras equivalents and their (W, H, C) shapes for a
        # 3x256x512 (CHW) input.
        skip_sources = [
            self.encoder.transition2,  # keras' pool3_pool  32, 16, 256
            self.encoder.transition1,  # keras' pool2_pool  64, 32, 128
            self.encoder.pool0,        # keras' pool1       128, 64, 64
            self.encoder.relu0,        # keras' conv1/relu  256, 128, 64
        ]
        # Handles are kept so the hooks can be detached later via handle.remove().
        self._skip_hook_handles = [
            source.register_forward_hook(stage.forward_hook_callback)
            for source, stage in zip(skip_sources, self.decoder.upconvlist)]

    @staticmethod
    def _build_densenet169_features(pretrained):
        """Return torchvision DenseNet-169's ``features`` submodule."""
        weights_enum = getattr(models, 'DenseNet169_Weights', None)
        if weights_enum is None:
            # Older torchvision releases only understand the legacy flag.
            backbone = models.densenet169(pretrained=pretrained)
        else:
            backbone = models.densenet169(
                weights=weights_enum.IMAGENET1K_V1 if pretrained else None)
        return backbone.features

    def forward(self, input):
        """Predict a depth map for ``input``.

        Running the encoder fires the skip-connection hooks, so the decoder
        must run after it within the same call.
        """
        encoder_output = self.encoder(input)
        return self.decoder(encoder_output)