<a href="https://colab.research.google.com/github/knoriy/depth_estimation/blob/master/model_torch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Setup

In [0]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.utils as vutils

In [0]:
# import os
# import time
# import tqdm
# import shutil
# import imageio
# import PIL.Image
# import numpy as np
# import pandas as pd
# import matplotlib.pyplot as plt
# import IPython.display as display

## Globals

In [0]:
# Training hyper-parameters.
EPOCHS = 1000  # Number of training iterations (passes of the outer training loop)
LEARNING_RATE = 1e-3
WEIGHT_DECAY = 0

# Input images are square: width and height share the same value.
IMG_HEIGHT = 28
IMG_WIDTH = IMG_HEIGHT

In [0]:
# Filesystem locations for input data, saved weights and generated output
# (presumably paths inside a Colab runtime -- adjust when running elsewhere).
SOURCE_DATA_DIR, CHECKPOINT_DIR, OUTPUT_DIR = (
    "/content/Data",
    "/content/checkpoint",
    "/content/output",
)

# Utils

# Model

In [0]:
class Model(nn.Module):
  """Encoder / residual / decoder CNN mapping a 3-channel image to a 3-channel map in [-1, 1].

  Channel flow: 3 -> 64 -> 128 -> 256 (encoder), four 256-channel residual
  blocks, three 256-channel transposed convolutions (decoder), and a final
  transposed convolution to 3 channels followed by ``tanh``.

  Every trainable layer is created once in ``__init__`` so it is registered as
  a sub-module: ``model.parameters()`` sees it, the optimizer can update it and
  ``state_dict()`` can save it. Building layers inside ``forward`` would give
  fresh random weights on every call.

  The encoder uses no padding, so the output is not the size of the input
  (a 28x28 input yields a 25x25 output).
  """

  def __init__(self):
    super(Model, self).__init__()
    # Stateless layers: one instance each, shared by all blocks.
    self.relu = nn.ReLU()
    self.Dropout2d = nn.Dropout2d(0.1)
    self.MaxPool2d = nn.MaxPool2d(3, 1)
    # Kept as a factory because every block needs its own BatchNorm2d(num_features).
    self.BatchNorm2d = nn.BatchNorm2d
    # Alias retained from the original class definition.
    self.conv1 = self.conv_block

    # padding_mode is kept from the original calls; with the default padding=0
    # it has no effect on the encoder convolutions.
    self.encoder = nn.Sequential(
        self._make_conv_block(3, 64, kernel_size=2, stride=1, padding_mode='replicate'),
        self._make_conv_block(64, 128, kernel_size=2, stride=2, padding_mode='replicate'),
        self._make_conv_block(128, 256, kernel_size=2, stride=2, padding_mode='replicate'),
    )
    self.residual = nn.ModuleList(
        [self._make_identity_branch(256, 256, kernel_size=2, padding_mode='replicate')
         for _ in range(4)]
    )
    # ConvTranspose2d only supports padding_mode='zeros', so the default is used.
    self.decoder = nn.Sequential(
        self._make_deconv_block(256, 256, kernel_size=2, stride=2),
        self._make_deconv_block(256, 256, kernel_size=2, stride=2),
        self._make_deconv_block(256, 256, kernel_size=2, stride=2),
    )
    self.output_layer = nn.ConvTranspose2d(256, 3, kernel_size=2, stride=1)

  # ------------------------------------------------------------------
  # Private layer factories
  # ------------------------------------------------------------------
  def _make_conv_block(self, in_channels, filters, kernel_size, stride=1, padding_mode='zeros'):
    """Return Conv2d -> BatchNorm2d -> ReLU -> MaxPool2d -> Dropout2d as one module."""
    return nn.Sequential(
        nn.Conv2d(in_channels, filters, kernel_size=kernel_size, stride=stride,
                  padding_mode=padding_mode),
        self.BatchNorm2d(filters),
        self.relu,
        self.MaxPool2d,
        self.Dropout2d,
    )

  def _make_deconv_block(self, in_channels, filters, kernel_size, stride=1, padding_mode='zeros'):
    """Return ConvTranspose2d -> BatchNorm2d -> ReLU -> Dropout2d as one module."""
    return nn.Sequential(
        nn.ConvTranspose2d(in_channels, filters, kernel_size=kernel_size, stride=stride,
                           padding_mode=padding_mode),
        self.BatchNorm2d(filters),
        self.relu,
        self.Dropout2d,
    )

  def _make_identity_branch(self, in_channels, filters, kernel_size, padding_mode='zeros'):
    """Return the three-convolution branch of a residual block (without the skip add).

    ``padding='same'`` keeps the spatial size unchanged so the branch output
    can be added to its input.
    """
    def conv(channels_in):
      return nn.Conv2d(channels_in, filters, kernel_size=kernel_size, stride=1,
                       padding='same', padding_mode=padding_mode)

    return nn.Sequential(
        conv(in_channels), self.BatchNorm2d(filters), self.relu, self.Dropout2d,
        conv(filters), self.BatchNorm2d(filters), self.relu, self.Dropout2d,
        conv(filters), self.BatchNorm2d(filters),
    )

  def _apply_residual(self, branch, x):
    """Return ``relu(branch(x) + x)``."""
    return self.relu(torch.add(branch(x), x))

  def _adhoc(self, block, x):
    """Place a freshly built block on x's device and mirror this model's train/eval mode."""
    return block.to(x.device).train(self.training)

  # ------------------------------------------------------------------
  # Public one-off helpers (original signatures)
  # ------------------------------------------------------------------
  def conv_block(self, x, filters, kernel_size, stride=1, padding_mode='zeros'):
    """Apply a newly created, randomly initialised encoder block to ``x``.

    The block is NOT registered on the model, so it is never trained or saved;
    ``forward`` uses the registered blocks in ``self.encoder`` instead.

    Args:
      x: batched input of shape (N, C, H, W); C sets the block's in_channels.
      filters: number of output channels.
      kernel_size: convolution kernel size.
      stride: convolution stride.
      padding_mode: forwarded to ``nn.Conv2d``.
    """
    block = self._make_conv_block(x.shape[1], filters, kernel_size, stride, padding_mode)
    return self._adhoc(block, x)(x)

  # Decoder
  def deconv_block(self, x, filters, kernel_size, stride=1, padding_mode='zeros'):
    """Apply a newly created, randomly initialised decoder block to ``x``.

    Unregistered one-off block; see ``conv_block``.

    Raises:
      ValueError: from ``nn.ConvTranspose2d`` if ``padding_mode`` is not 'zeros'.
    """
    block = self._make_deconv_block(x.shape[1], filters, kernel_size, stride, padding_mode)
    return self._adhoc(block, x)(x)

  # Residual
  def identity_block(self, x, filters, kernel_size, padding_mode='zeros'):
    """Apply a newly created, randomly initialised residual block to ``x``.

    Unregistered one-off block; see ``conv_block``. ``filters`` must equal the
    channel count of ``x`` for the shortcut addition to be shape-compatible.
    """
    branch = self._make_identity_branch(x.shape[1], filters, kernel_size, padding_mode)
    return self._apply_residual(self._adhoc(branch, x), x)

  def forward(self, x):
    """Run the network on a batch ``x`` of shape (N, 3, H, W) and return the tanh output."""
    # Encoder
    x = self.encoder(x)

    # Residual blocks
    for branch in self.residual:
      x = self._apply_residual(branch, x)

    # Decoder
    x = self.decoder(x)

    x = self.output_layer(x)
    return torch.tanh(x)


In [0]:
class Model(nn.Module):
  """Convolutional encoder/decoder mapping a 3-channel image to a 3-channel map in [-1, 1].

  Channel flow: 3 -> 64 -> 128 -> 256 (encoder), 256 -> 256 -> 128 -> 64
  (transposed-conv decoder), then a transposed convolution to 3 channels
  followed by ``tanh``. The residual stage is disabled in this variant.

  All trainable layers are created in ``__init__`` so they are registered as
  sub-modules and therefore visible to ``parameters()`` and ``state_dict()``.

  The encoder uses no padding, so the output is not the size of the input
  (a 28x28 input yields a 25x25 output).
  """

  def __init__(self):
    super(Model, self).__init__()
    # Encoder convolutions; each in_channels matches the previous out_channels.
    # padding_mode has no effect while padding stays at its default of 0.
    self.conv2d1 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=2, stride=1, padding_mode='replicate')
    self.conv2d2 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=2, stride=2, padding_mode='replicate')
    self.conv2d3 = nn.Conv2d(in_channels=128, out_channels=256, kernel_size=2, stride=2, padding_mode='replicate')

    # Decoder; ConvTranspose2d only supports padding_mode='zeros' (the default).
    self.ConvTranspose2d1 = nn.ConvTranspose2d(in_channels=256, out_channels=256, kernel_size=2, stride=2)
    self.ConvTranspose2d2 = nn.ConvTranspose2d(in_channels=256, out_channels=128, kernel_size=2, stride=2)
    self.ConvTranspose2d3 = nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=2, stride=2)
    self.output_layer = nn.ConvTranspose2d(in_channels=64, out_channels=3, kernel_size=2, stride=1)

    # Kept as a factory: every convolution needs its own BatchNorm2d(num_features).
    self.BatchNorm2d = nn.BatchNorm2d
    self.conv_bn1 = self.BatchNorm2d(64)
    self.conv_bn2 = self.BatchNorm2d(128)
    self.conv_bn3 = self.BatchNorm2d(256)
    self.deconv_bn1 = self.BatchNorm2d(256)
    self.deconv_bn2 = self.BatchNorm2d(128)
    self.deconv_bn3 = self.BatchNorm2d(64)

    # Stateless layers: one instance each, shared by all stages.
    self.relu = nn.ReLU()
    self.MaxPool2d = nn.MaxPool2d(3, 1)
    self.Dropout2d1 = nn.Dropout2d(0.1)

  def _adhoc(self, block, x):
    """Place a freshly built block on x's device and mirror this model's train/eval mode."""
    return block.to(x.device).train(self.training)

  # Decoder
  def deconv_block(self, x, filters, kernel_size, stride=1, padding_mode='zeros'):
    """Apply a newly created ConvTranspose2d -> BatchNorm2d -> ReLU -> Dropout2d block to ``x``.

    The block is randomly initialised and NOT registered on the model, so it is
    never trained or saved; ``forward`` uses the layers created in ``__init__``.

    Args:
      x: batched input of shape (N, C, H, W); C sets the block's in_channels.
      filters: number of output channels.
      kernel_size: transposed-convolution kernel size.
      stride: transposed-convolution stride.
      padding_mode: forwarded to ``nn.ConvTranspose2d``.

    Raises:
      ValueError: from ``nn.ConvTranspose2d`` if ``padding_mode`` is not 'zeros'.
    """
    block = nn.Sequential(
        nn.ConvTranspose2d(x.shape[1], filters, kernel_size=kernel_size, stride=stride,
                           padding_mode=padding_mode),
        self.BatchNorm2d(filters),
        self.relu,
        self.Dropout2d1,
    )
    return self._adhoc(block, x)(x)

  # Residual
  def identity_block(self, x, filters, kernel_size, padding_mode='zeros'):
    """Apply a newly created three-convolution residual block to ``x``.

    Unregistered one-off block (see ``deconv_block``). The convolutions use
    ``padding='same'`` so the branch keeps the spatial size of ``x``;
    ``filters`` must equal the channel count of ``x`` for the shortcut addition.
    """
    def conv(channels_in):
      return nn.Conv2d(channels_in, filters, kernel_size=kernel_size, stride=1,
                       padding='same', padding_mode=padding_mode)

    branch = nn.Sequential(
        conv(x.shape[1]), self.BatchNorm2d(filters), self.relu, self.Dropout2d1,
        conv(filters), self.BatchNorm2d(filters), self.relu, self.Dropout2d1,
        conv(filters), self.BatchNorm2d(filters),
    )
    x_shortcut = x
    x = self._adhoc(branch, x)(x)
    x = torch.add(x, x_shortcut)
    return self.relu(x)

  def forward(self, x):
    """Run the network on a batch ``x`` of shape (N, 3, H, W) and return the tanh output."""
    # Encoder: conv -> batch norm -> ReLU -> max pool -> dropout, three times.
    encoder_stages = (
        (self.conv2d1, self.conv_bn1),
        (self.conv2d2, self.conv_bn2),
        (self.conv2d3, self.conv_bn3),
    )
    for conv, norm in encoder_stages:
      x = self.relu(norm(conv(x)))
      x = self.Dropout2d1(self.MaxPool2d(x))

    # NOTE: the residual stage (four 256-channel identity blocks) is
    # intentionally left out of this variant, as in the original cell.

    # Decoder: transposed conv -> batch norm -> ReLU -> dropout, three times.
    decoder_stages = (
        (self.ConvTranspose2d1, self.deconv_bn1),
        (self.ConvTranspose2d2, self.deconv_bn2),
        (self.ConvTranspose2d3, self.deconv_bn3),
    )
    for deconv, norm in decoder_stages:
      x = self.Dropout2d1(self.relu(norm(deconv(x))))

    x = self.output_layer(x)
    return torch.tanh(x)


In [0]:
# Instantiate the network and print its module tree.
model = Model()
print(model)

In [0]:
# model(Some_image)

## Loss

In [0]:
def custom_loss(pred, true):
  """Return the mean squared error between ``pred`` and ``true`` as a scalar tensor."""
  return F.mse_loss(pred, true)

## Optimizer

In [0]:
# Adam over the parameters of `model`, the network that train_loop updates.
# (`net` is not defined at this point in the notebook.)
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)

# Train

In [0]:
def checkpoint(model, dir):
  """Write ``model``'s weights to ``dir`` and return that path.

  ``torch.utils.checkpoint.checkpoint`` is the gradient-checkpointing utility
  (it re-runs a function during backward to save memory) and does not persist
  anything, so the weights are serialised with ``torch.save`` instead. As in
  ``save_model``, ``dir`` is used directly as the destination file path.

  Args:
    model: module whose ``state_dict()`` is saved.
    dir: destination path passed to ``torch.save``.

  Returns:
    The path the weights were written to (``dir``).
  """
  torch.save(model.state_dict(), dir)
  return dir

def train_step():
  """Placeholder for a single training step; currently does nothing and returns None."""
  return None

def train_loop():
  """Train the global ``model`` for ``EPOCHS`` passes over the global ``images``.

  Each sample is fed through the model, the loss against the same sample is
  back-propagated and the global ``optimizer`` takes one step. The loss of the
  last sample is printed once per epoch; nothing is printed when ``images``
  is empty.
  """
  model.train()  # enable dropout / batch-norm statistics updates
  for _ in range(EPOCHS):
    loss = None
    for img in images:
      model.zero_grad()

      # NOTE(review): the original passed a placeholder string here. The image
      # is used as input because it is also the loss target below -- confirm
      # the intended input/target pair and that their shapes match.
      prediction = model(img)

      loss = custom_loss(prediction, img)
      loss.backward()

      optimizer.step()
    if loss is not None:
      print("Loss: {}".format(loss.item()))

train_loop()

# Evaluate

In [0]:
# Evaluate `model` on every sample without tracking gradients.
# Eval mode disables dropout and makes batch norm use its running statistics;
# the previous train/eval mode is restored afterwards.
was_training = model.training
model.eval()
try:
  with torch.no_grad():
    for image in images:
      # NOTE(review): assumes each image is already a batched (N, C, H, W)
      # tensor; the flattening used previously suits a fully connected
      # network, not a convolutional one -- confirm against the data pipeline.
      prediction = model(image)
      print(custom_loss(prediction, image))
finally:
  model.train(was_training)

# Save model

In [0]:
def save_model(dir, model):
  """Serialise ``model``'s state dict to the path ``dir`` with ``torch.save``."""
  weights = model.state_dict()
  torch.save(weights, dir)

# Persist the trained network; `model` is the instance created and trained
# in the earlier cells (`net` is not defined at this point).
save_model(CHECKPOINT_DIR, model)

# Load Model

In [0]:
# Rebuild the architecture, then restore the weights stored at CHECKPOINT_DIR.
# The network class in this notebook is `Model`; no `Net` class is defined.
net = Model()
net.load_state_dict(torch.load(CHECKPOINT_DIR))