<a href="https://colab.research.google.com/github/knoriy/depth_estimation/blob/master/model_torch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Setup

In [0]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.utils as vutils
import torchvision.models as models
import torchvision.transforms as transforms


In [0]:
# import os
# import time
# import tqdm
# import shutil
# import imageio
# import PIL.Image
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import IPython.display as display

## Globals

In [0]:
EPOCHS = 1000  # Intended number of training iterations (train_loop does not read this constant)
BATCH_SIZE = 16  # Samples per batch for the DataLoaders


# Hyper-parameters passed to the Adam optimizer.
LEARNING_RATE = 1e-3
WEIGHT_DECAY  = 0

# Square image side length; only referenced by the evaluation cell.
IMG_WIDTH = IMG_HEIGHT = 28

In [0]:
# Filesystem locations under /content.
# NOTE(review): presumably the Colab runtime's working directory - confirm.
SOURCE_DATA_DIR = "/content/Data"
CHECKPOINT_DIR  = "/content/checkpoint"  # target of save_model / torch.load below
OUTPUT_DIR      = "/content/output"

# Utils

In [0]:
def im_show(tensor):
  """Display an image tensor with matplotlib.

  Args:
    tensor: Either a 4-D batch laid out as (N, C, H, W), of which only the
      first image is shown, or a single 3-D image laid out as (C, H, W).
      The tensor may require grad and may live on any device.
  """
  # Use the argument itself; the previous 4-D branch read a global `image`.
  image_chw = tensor[0] if tensor.dim() == 4 else tensor

  # permute moves the channel axis last. view() would only reinterpret the
  # memory with a new shape and scramble the pixels.
  image_hwc = image_chw.detach().cpu().permute(1, 2, 0)

  # matplotlib expects (H, W) for single-channel images, not (H, W, 1).
  if image_hwc.shape[-1] == 1:
    image_hwc = image_hwc.squeeze(-1)

  plt.imshow(image_hwc.numpy())
  plt.show()

In [0]:
def imshow(img):
    """Un-normalise a channel-first image tensor and display it.

    Args:
        img: Tensor indexed as (C, H, W). It is rescaled with
            ``img / 2 + 0.5`` before plotting.
            NOTE(review): presumably undoes a mean=0.5 / std=0.5
            normalisation - confirm against the caller.
    """
    unnormalized = img / 2 + 0.5
    as_array = unnormalized.numpy()
    # Move channels last, the layout plt.imshow expects.
    channels_last = np.transpose(as_array, axes=(1, 2, 0))
    plt.imshow(channels_last)
    plt.show()

## Ben's dataloader

In [0]:
import csv

# data_descriptions.csv must sit one directory above the working directory.
# Its first row holds the folder names; the last row holds the number of
# samples stored in each folder, in the same column order.
with open('../data_descriptions.csv', newline='') as csvfile:
  rows = list(csv.reader(csvfile, delimiter=',', quotechar='|'))

if len(rows) < 2:
  raise ValueError(
    "data_descriptions.csv needs a header row of folder names followed by "
    "a row of per-folder file counts"
  )

folder_names = rows[0]
num_files = [int(count) for count in rows[-1]]

# Files are numbered globally across folders: 00001, 00002, ...
list_of_numbers = [f"{i:05}" for i in range(1, sum(num_files) + 1)]

colour_filenames = [f"colour_{num}.raw" for num in list_of_numbers]
# Depth maps use the "depth_" prefix (they were previously built with
# "colour_", so every depth path pointed at a colour-named file).
depth_filenames = [f"depth_{num}.raw" for num in list_of_numbers]


class ModerateDataset(torch.utils.data.Dataset):
  """Paired colour / depth ``.raw`` images, grouped by folder on disk.

  The class reads the module-level ``folder_names``, ``num_files``,
  ``colour_filenames`` and ``depth_filenames`` and loads images through
  ``import_raw_colour_image`` / ``import_raw_depth_image``.

  Indexing is zero-based and contiguous over all folders: ``dataset[0]`` is
  the first sample of the first folder and ``dataset[len(dataset) - 1]`` is
  the last sample of the last folder. Any number of folders is supported.
  """

  # Location hard-coded by the original author; pass ``base_dir`` to override.
  DEFAULT_BASE_DIR = "C:/Users/Ben/OneDrive - Bournemouth University/Computer Vision/Moderate collection"

  def __init__(self, col_dir='', depth_dir='', transform=None, trans_on=False, base_dir=None):
    """Index every colour/depth file pair.

    Args:
      col_dir: Stored on the instance; not used to locate files.
      depth_dir: Stored on the instance; not used to locate files.
      transform: Optional callable applied to the colour image only.
      trans_on: When truthy, ``__getitem__`` flips both images along axis 0,
        converts them to torch tensors and moves the colour channels first.
      base_dir: Directory containing one sub-directory per folder name, each
        with ``colour/`` and ``depth/`` inside. Defaults to
        ``DEFAULT_BASE_DIR``.
    """
    # Local import: the notebook's import cells do not provide Path.
    from pathlib import Path

    if base_dir is None:
      print("*************MAKE SURE THE PATH FILE IN THE FOR LOOP IS THE BASE IMAGE DIRECTORY ON YOUR COMPUTER**************")
      base_dir = self.DEFAULT_BASE_DIR
    base = Path(base_dir)

    # path_names[folder]['colour' | 'depth'][str(1-based index in folder)] -> Path
    self.path_names = {}
    # Flat, zero-based list of (colour_path, depth_path) used by __getitem__.
    self._samples = []

    offset = 0  # filenames are numbered globally, so track the running total
    for folder, n_files in zip(folder_names, num_files):
      colour_paths = {}
      depth_paths = {}
      for i in range(n_files):
        colour_path = base / folder / 'colour' / colour_filenames[offset + i]
        depth_path = base / folder / 'depth' / depth_filenames[offset + i]
        colour_paths[f'{i+1}'] = colour_path
        depth_paths[f'{i+1}'] = depth_path
        self._samples.append((colour_path, depth_path))
      self.path_names[f'{folder}'] = {'colour': colour_paths, 'depth': depth_paths}
      offset += n_files

    self.transform  = transform
    self.col_dir    = col_dir
    self.depth_dir  = depth_dir
    self.trans_on   = trans_on

  def __getitem__(self, idx):
    """Return the ``(colour_image, depth_image)`` pair at zero-based ``idx``.

    Raises:
      IndexError: If ``idx`` is outside ``[0, len(self))``. IndexError (rather
        than the previous NameError) is what Python's sequence protocol uses
        to stop iteration.
    """
    if not 0 <= idx < len(self._samples):
      raise IndexError('Index outside of range')

    # Paths are kept local so that a lookup never reuses a stale path left
    # behind by an earlier call.
    col_path, depth_path = self._samples[idx]

    col_img = import_raw_colour_image(col_path)
    depth_img = import_raw_depth_image(depth_path)

    if self.trans_on:
      # .copy() is required because np.flip returns a negatively-strided
      # view, which torch.from_numpy cannot wrap.
      col_img   = torch.from_numpy(np.flip(col_img, axis=0).copy())
      depth_img = torch.from_numpy(np.flip(depth_img, axis=0).copy())
      # Move the last axis first (same result as transpose(0,2) followed by
      # transpose(1,2)); assumes the loader returns H x W x C - TODO confirm.
      col_img   = col_img.permute(2, 0, 1)

    if self.transform:  # only the colour image is transformed
      col_img = self.transform(col_img)

    return col_img, depth_img

  def __len__(self):
    """Total number of colour/depth pairs across all folders."""
    return len(self._samples)

In [0]:
# trans_on=True makes __getitem__ return torch tensors (both images flipped
# along axis 0, colour image with its last axis moved first).
total_Data = ModerateDataset(trans_on=True)  ## instancing the dataset

### Train/test split

In [0]:
# 80 % of the samples train the model; the rest is shared between validation
# and test.
n_total = len(total_Data)
train_size = int(0.8 * n_total)
val_size = (n_total - train_size) // 2
# The test split takes whatever is left, so the three lengths always sum to
# n_total. random_split raises ValueError otherwise, which happened whenever
# the non-training remainder was odd.
test_size = n_total - train_size - val_size

train_dataset, validation_dataset, test_dataset = torch.utils.data.random_split(
  total_Data, [train_size, val_size, test_size]
)


# DataLoader is referenced through torch.utils.data because the notebook's
# import cells never import the bare name.
train_dl        = torch.utils.data.DataLoader(train_dataset,       batch_size=BATCH_SIZE, shuffle=True,  num_workers=0)
validation_dl   = torch.utils.data.DataLoader(validation_dataset,  batch_size=BATCH_SIZE, shuffle=True,  num_workers=0)
test_dl         = torch.utils.data.DataLoader(test_dataset,        batch_size=BATCH_SIZE, shuffle=True,  num_workers=0)

## Kari's Dataloader


# Model

In [0]:
class ResBlock(nn.Module):
  """Residual block: conv-BN-ReLU, conv-BN, add the skip connection, ReLU.

  Args:
    in_channels: Channels of the input feature map.
    out_channels: Channels produced by both convolutions.
    kernel_size: Kernel size of both convolutions.
    stride: Stride of the first convolution (the second always uses 1).
    padding: Padding of both convolutions. The skip connection only lines up
      when the second convolution preserves the spatial size
      (e.g. kernel_size=3, padding=1).
    padding_mode: Forwarded to ``nn.Conv2d``.

  When ``in_channels == out_channels`` and ``stride == 1`` the skip path is
  an identity and the block has exactly the parameters it always had. In any
  other configuration a bias-free 1x1 convolution projects the input so that
  the addition is shape-compatible.
  """

  def __init__(self, in_channels, out_channels, kernel_size, stride, padding, padding_mode="zeros"):
    super().__init__()
    self.Conv2d_1   = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, padding_mode=padding_mode)
    # The second convolution consumes the first one's output, so its input
    # width is out_channels, and it must not down-sample a second time.
    self.Conv2d_2   = nn.Conv2d(out_channels, out_channels, kernel_size, 1, padding, padding_mode=padding_mode)
    # NOTE(review): one BatchNorm2d is applied after both convolutions, so the
    # two positions share affine parameters and running statistics. Kept as is
    # so existing state_dict keys remain valid.
    self.BatchNorm2d  = nn.BatchNorm2d(out_channels)
    self.ReLU       = nn.ReLU()

    if in_channels != out_channels or stride not in (1, (1, 1)):
      self.shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False)
    else:
      self.shortcut = nn.Identity()  # parameter-free

  def forward(self, x):
    """Apply the block to ``x`` and return the activated residual sum."""
    x_short = self.shortcut(x)

    x = self.Conv2d_1(x)
    x = self.BatchNorm2d(x)
    x = self.ReLU(x)

    x = self.Conv2d_2(x)
    x = self.BatchNorm2d(x)
    x = torch.add(x, x_short)
    x = self.ReLU(x)

    return x

class ConvBlock(nn.Module):
  """Convolution -> batch norm -> channel dropout (p=0.1) -> ReLU.

  Args:
    in_channels: Channels of the input feature map.
    out_channels: Channels produced by the convolution.
    kernel_size: Convolution kernel size.
    stride: Convolution stride.
    padding: Convolution padding.
    padding_mode: Forwarded to ``nn.Conv2d``.
  """

  def __init__(self, in_channels, out_channels, kernel_size, stride, padding, padding_mode="zeros"):
    super().__init__()
    self.conv2d_1 = nn.Conv2d(
      in_channels=in_channels,
      out_channels=out_channels,
      kernel_size=kernel_size,
      stride=stride,
      padding=padding,
      padding_mode=padding_mode,
    )
    self.batchNorm = nn.BatchNorm2d(num_features=out_channels)
    # Constructed with the convolution's geometry but not applied in forward.
    self.MaxPool2d = nn.MaxPool2d(kernel_size=kernel_size, stride=stride, padding=padding)
    self.Dropout2d = nn.Dropout2d(p=0.1)
    self.ReLU = nn.ReLU()

  def forward(self, x):
    """Run ``x`` through the four stages in order and return the result."""
    stages = (self.conv2d_1, self.batchNorm, self.Dropout2d, self.ReLU)
    for stage in stages:
      x = stage(x)
    return x

class DeConvBlock(nn.Module):
  """Convolution -> batch norm -> nearest-neighbour upsampling -> channel
  dropout (p=0.1) -> ReLU.

  Args:
    in_channels: Channels of the input feature map.
    out_channels: Channels produced by the convolution.
    kernel_size: Convolution kernel size.
    stride: Convolution stride.
    padding: Convolution padding.
    padding_mode: Forwarded to ``nn.Conv2d``.
    scale_factor: Spatial multiplier used by the upsampling layer.
  """

  def __init__(self, in_channels, out_channels, kernel_size, stride, padding, padding_mode="zeros", scale_factor=2):
    super().__init__()
    self.conv2d_1 = nn.Conv2d(
      in_channels=in_channels,
      out_channels=out_channels,
      kernel_size=kernel_size,
      stride=stride,
      padding=padding,
      padding_mode=padding_mode,
    )
    self.batchNorm = nn.BatchNorm2d(num_features=out_channels)
    self.UpSample2d = nn.UpsamplingNearest2d(scale_factor=scale_factor)
    self.Dropout2d = nn.Dropout2d(p=0.1)
    self.ReLU = nn.ReLU()

  def forward(self, x):
    """Run ``x`` through the five stages in order and return the result."""
    stages = (self.conv2d_1, self.batchNorm, self.UpSample2d, self.Dropout2d, self.ReLU)
    for stage in stages:
      x = stage(x)
    return x

class Model(nn.Module):
  """Convolutional encoder-decoder mapping a 3-channel image to a 3-channel
  image.

  Two of the three encoder blocks use stride 2 and two of the three decoder
  blocks upsample by 2. A ``ResBlock`` is constructed as ``ResBlock_1`` but
  is not applied in ``forward``.
  """

  def __init__(self):
    super().__init__()
    # Encoder: 3 -> 64 -> 128 -> 256 channels.
    self.ConvBlock_1 = ConvBlock(in_channels=3, out_channels=64, kernel_size=3, stride=1, padding=1, padding_mode="replicate")
    self.ConvBlock_2 = ConvBlock(in_channels=64, out_channels=128, kernel_size=3, stride=2, padding=1, padding_mode="replicate")
    self.ConvBlock_3 = ConvBlock(in_channels=128, out_channels=256, kernel_size=3, stride=2, padding=1, padding_mode="replicate")

    # Registered (so its weights are part of the state dict) but unused.
    self.ResBlock_1 = ResBlock(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1, padding_mode="replicate")

    # Decoder: 256 -> 128 -> 64 -> 3 channels.
    self.DeConvBlock_1 = DeConvBlock(in_channels=256, out_channels=128, kernel_size=3, stride=1, padding=1, padding_mode="replicate", scale_factor=2)
    self.DeConvBlock_2 = DeConvBlock(in_channels=128, out_channels=64, kernel_size=3, stride=1, padding=1, padding_mode="replicate", scale_factor=2)
    self.DeConvBlock_3 = DeConvBlock(in_channels=64, out_channels=3, kernel_size=3, stride=1, padding=1, padding_mode="replicate", scale_factor=1)

  def forward(self, x):
    """Encode then decode ``x`` and squash the result with tanh."""
    encoder = (self.ConvBlock_1, self.ConvBlock_2, self.ConvBlock_3)
    decoder = (self.DeConvBlock_1, self.DeConvBlock_2, self.DeConvBlock_3)
    for block in encoder + decoder:
      x = block(x)
    return torch.tanh(x)

# Instantiate the network; it is not moved to a GPU (the call below is
# commented out).
model = Model()
# model.cuda

In [0]:
# Smoke test: push one random 1x3x256x256 input through the untrained model,
# display the output and show its shape.
image = model(torch.rand((1,3,256,256)))
im_show(image)
image.shape

## Loss

In [0]:
def MSE_loss(pred, true):
  """Return the mean squared error between ``pred`` and ``true``.

  The squared differences are averaged over every element, giving a scalar
  tensor.
  """
  return nn.functional.mse_loss(pred, true, reduction="mean")

In [0]:

# VGG-16 with pretrained weights, used below as a fixed feature extractor.
# NOTE(review): newer torchvision releases deprecate `pretrained=` in favour
# of `weights=` - confirm against the installed version.
vgg = models.vgg16(pretrained=True)

# Perceptual-loss reference:
# https://towardsdatascience.com/pytorch-implementation-of-perceptual-losses-for-real-time-style-transfer-8d608e2e9902

class VGGFearureExtractor(nn.Module):
  """Collect intermediate activations from a VGG network's ``features`` stack.

  Args:
    vgg: Object exposing a ``features`` module whose direct children are
      applied in sequence (e.g. a torchvision VGG-16).

  The wrapped layers are switched to eval mode and frozen. They act as a
  fixed loss network, so backpropagating a loss must not accumulate ``.grad``
  on their weights. Gradients still flow through to the input.
  """

  def __init__(self, vgg):
    super().__init__()
    self.vgg_layers   = vgg.features
    # Child name inside vgg.features -> label used in the returned dict.
    self.layer_names  = { '3': "relu1_2", '8': "relu2_2", '15': "relu3_3", '22': "relu4_3"}

    self.vgg_layers.eval()
    for param in self.vgg_layers.parameters():
      param.requires_grad_(False)

  def forward(self, x):
    """Return ``{label: activation}`` for every layer in ``layer_names``."""
    output = {}
    for name, module in self.vgg_layers.named_children():
      x = module(x)
      if name in self.layer_names:
        output[self.layer_names[name]] = x
        # Every tapped layer has been seen; the remaining layers would only
        # waste compute.
        if len(output) == len(self.layer_names):
          break

    return output

# Shared extractor instance used by VGG_loss.
vgg_extractor = VGGFearureExtractor(vgg)

def VGG_loss(pred, true):
  """Perceptual loss: summed mean-squared feature differences.

  Both inputs are passed through ``vgg_extractor``. For every returned
  feature map the mean squared difference between prediction and target is
  computed, and the per-layer values are added together.

  Returns:
    A scalar tensor (or the integer 0 if the extractor returns no features).
  """
  pred_feats = vgg_extractor(pred)
  true_feats = vgg_extractor(true)

  total = 0
  for layer_name, pred_map in pred_feats.items():
    diff = pred_map - true_feats[layer_name]
    layer_mse = torch.mean(diff ** 2)
    total = torch.add(layer_mse, total)

  return total

# Random 1x3x256x256 tensor; the training cell uses it as both input and
# target.
img = torch.rand((1,3,256,256))
# VGG_loss(torch.rand((1,3,256,256)),torch.rand((1,3,256,256)))
# VGG_loss(img, img)


In [0]:

# ResNet-18 with pretrained weights, used below as a feature extractor.
# NOTE(review): newer torchvision releases deprecate `pretrained=` in favour
# of `weights=` - confirm against the installed version.
resnet = models.resnet18(pretrained=True)

class ResNetFeatureExtractor(nn.Module):
  """Collect activations at several depths of a network's top-level children.

  Args:
    Resnet: Module whose direct children are applied in sequence (e.g. a
      torchvision ResNet-18). The extractor uses this argument rather than
      any module-level global.

  ``Feature_block`` lists the cut points: the i-th returned feature is the
  output after the first ``Feature_block[i]`` children. ``levels`` holds the
  consecutive stages between those cut points in an ``nn.ModuleList``, so
  they are registered sub-modules (``.to()``, ``.eval()`` and ``state_dict``
  see them) and each stage runs once per forward pass.

  The wrapped layers are put in eval mode and frozen. They act as a fixed
  loss network, so BatchNorm must use its stored statistics and must not
  update them. Gradients still flow through to the input.
  """

  def __init__(self, Resnet):
    super().__init__()
    self.Feature_block = [4,5,6,7]

    children = list(Resnet.children())
    self.levels = nn.ModuleList()
    start = 0
    for stop in self.Feature_block:
      self.levels.append(nn.Sequential(*children[start:stop]))
      start = stop

    self.eval()
    for param in self.parameters():
      param.requires_grad_(False)

  def forward(self, x):
    """Return a list with one activation per entry of ``Feature_block``."""
    output = []
    for level in self.levels:
      x = level(x)  # continue from the previous cut point
      output.append(x)

    return output

# Shared extractor instance used by ResNet_loss.
resnet_extractor = ResNetFeatureExtractor(resnet)

# resnet_extractor(torch.rand((1,3,256,256)))


def ResNet_loss(pred, true):
  """Perceptual loss: summed mean-squared feature differences.

  Both inputs are passed through ``resnet_extractor``. For every returned
  feature level the mean squared difference between prediction and target is
  computed, and the per-level values are added together.

  Returns:
    A scalar tensor (or the integer 0 if the extractor returns no features).
  """
  pred_levels = resnet_extractor(pred)
  true_levels = resnet_extractor(true)

  total = 0
  for pred_map, true_map in zip(pred_levels, true_levels):
    diff = pred_map - true_map
    level_mse = torch.mean(diff ** 2)
    total = torch.add(level_mse, total)

  return total

# Sanity check: comparing a tensor with itself should give a loss of zero.
# ResNet_loss(torch.rand((1,3,256,256)),torch.rand((1,3,256,256)))
ResNet_loss(img, img)


## Optimizer

In [0]:
# Adam over the model's parameters, configured from the global
# hyper-parameters.
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)

# Train

In [0]:
# %%time

def checkpoint(model, dir):
  """Save ``model``'s parameters to the file path ``dir`` and return the path.

  Args:
    model: Module whose ``state_dict()`` is written.
    dir: Destination file path (name kept for backward compatibility even
      though it shadows the ``dir`` builtin).

  The previous body called ``torch.utils.checkpoint.checkpoint``, which is
  activation checkpointing (it trades compute for memory during backward).
  It does not persist anything, and it ignored ``dir``.
  """
  torch.save(model.state_dict(), dir)
  return dir

# Training set for the loop below: a single image.
images = [img]

def train_loop(epochs=100):
  """Train ``model`` to reproduce each tensor in ``images`` under ``VGG_loss``.

  Uses the module-level ``model``, ``optimizer`` and ``images``. After every
  epoch the last loss is printed and the last prediction is displayed.

  Args:
    epochs: Number of passes over ``images`` (defaults to the previously
      hard-coded 100).

  Returns:
    ``(history, prediction)``: the loss of every optimisation step as Python
    floats, and the model output of the final step (``None`` if ``images``
    is empty).
  """
  history = []
  loss = prediction = None
  for epoch in range(epochs):
    for img in images:
      model.zero_grad()

      prediction = model(img)

      loss = VGG_loss(prediction, img)
      # .item() yields a Python float on any device; .detach().numpy() only
      # works for CPU tensors.
      history.append(loss.item())

      loss.backward()

      optimizer.step()

    if prediction is None:
      break  # nothing to train on, so there is nothing to report or show

    print("Loss: {}".format(loss))
    with torch.no_grad():
      im_show(prediction)
      display.clear_output(wait=True)
  return history, prediction


# Run training; keep the per-step losses and the last model output.
history, final_prediction = train_loop()

In [0]:
# Plot the recorded loss values as a line chart (x axis = step index).
s = pd.Series(history)
s.plot.line()

In [0]:
# Mean signed element-wise difference between the final prediction and the
# training image.
torch.mean(final_prediction - img)

# Evaluate

In [0]:
# Score the trained network on `images` with the loss used during training.
# The previous cell referenced an undefined `net` and `loss`, and flattened
# the input to IMG_WIDTH*IMG_HEIGHT, which the convolutional model cannot
# consume.
model.eval()  # disable dropout and use BatchNorm running statistics
with torch.no_grad():
  for image in images:
    prediction = model(image)
    print(VGG_loss(prediction, image))

# Save model

In [0]:
def save_model(dir, model):
  """Write ``model``'s parameters (its ``state_dict``) to the path ``dir``."""
  weights = model.state_dict()
  torch.save(weights, dir)

# Persist the trained network. The notebook's instance is named `model`; no
# `net` exists at this point.
save_model(CHECKPOINT_DIR, model)

# Load Model

In [0]:
# Rebuild the architecture, then restore the saved parameters into it. The
# notebook defines the network class as `Model`; there is no `Net`.
net = Model()
net.load_state_dict(torch.load(CHECKPOINT_DIR))