In [None]:
# Global seed: handed to utils.deterministic_environment and used as the
# random_state of the dataset shuffle.
SEED = 1_212_027

In [None]:
import os
import sys

# Include the root directory (parent of the working directory) in the import
# path so that the `local` package can be imported by the cells below.
# Guarded so that re-running this cell does not pile up duplicate entries.
_repo_root = os.path.abspath("..")
if _repo_root not in sys.path:
    sys.path.append(_repo_root)

# Initialization

In [None]:
import local.utils as utils

# Project helper called with the global seed (presumably pins RNG state for
# reproducible runs — see local.utils).
utils.deterministic_environment(SEED)

# Resolve the compute device once; later cells read this module-level name.
device = utils.get_torch_device()
print(
    "Device `{}` for computations with PyTorch".format(device)
)

# Dataset Preparation

In [None]:
import os

# Dataset locations, resolved relative to the directory one level above the
# current working directory.
ROOT = os.path.abspath("..")
DIODE_ROOT = os.path.join(ROOT, "diode")
# Only the indoor scenes of the validation split are used.
DATASET_PATH = os.path.join(DIODE_ROOT, os.path.join("val", "indoors"))

In [None]:
from local.diode.utils import get_filelist, format_filelist
import pandas as pd

# Build the file table and shuffle all rows once with the fixed seed, so the
# positional train/valid split that follows is random yet repeatable.
data = (
    pd.DataFrame(format_filelist(get_filelist(DATASET_PATH)))
    .sample(frac=1, random_state=SEED)
)

In [None]:
import math

# 80/20 positional split of the already-shuffled table: the first `split`
# rows form the training set, the remainder the validation set.
split = math.floor(len(data) * 0.8)
train, valid = data.iloc[:split], data.iloc[split:]

# Model Creation

In [None]:
import torch

# U-Net architecture fetched from PyTorch Hub, created without pretrained
# weights: 3 input channels, 1 output channel.
model = torch.hub.load(
    'mateuszbuda/brain-segmentation-pytorch',
    'unet',
    in_channels=3,
    out_channels=1,
    init_features=32,
    pretrained=False,
)

# Model Summary

In [None]:
from torchinfo import summary

# Layer summary for a single 3x768x1024 input, expanded up to 8 nesting levels.
summary(
    model,
    input_size=(1, 3, 768, 1024),
    depth=8,
)

# Training

In [None]:
from local.diode.dataset import DIODEDataset
from local.diode.transforms import ClipMaxDepth, ZScoreNormalization, MinMaxScaler
from torch.utils.data import DataLoader
from torchvision.transforms import Compose, Grayscale, ToTensor

# Shared first step for both images and depth maps: convert to a tensor.
base_preprocess = Compose([ToTensor()])

# Image branch. Grayscale, since depth perception ideally does not need
# colors; the gray image is replicated to 3 channels, then normalized.
image_preprocess = Compose([
    Grayscale(num_output_channels=3),
    ZScoreNormalization(),
])

# Depth branch: clip the maximum depth at 300, then scale to [0, 1].
depth_preprocess = Compose([
    ClipMaxDepth(300),
    MinMaxScaler(0., 1.),
])

# One dataset per split; each gets its own composed image/depth pipelines.
datasets = {
    split_name: DIODEDataset(
        frame,
        image_transform=Compose([base_preprocess, image_preprocess]),
        depth_transform=Compose([base_preprocess, depth_preprocess]),
    )
    for split_name, frame in (("train", train), ("valid", valid))
}

# Batched, shuffled loaders keyed by the same split names.
dataloaders = {
    split_name: DataLoader(dataset, batch_size=2, shuffle=True, num_workers=1)
    for split_name, dataset in datasets.items()
}

In [None]:
import torch.nn as nn
from local.loss import ssim, depth_smoothness, weighted_loss

# Each loss term is listed next to the weight it is combined with
# (presumably a weighted sum — see local.loss.weighted_loss).
_weighted_terms = [
    (0.85, ssim.MaximumSSIMLoss(k1=0.01 * 2, k2=0.03 * 2)),
    (0.1, nn.MSELoss()),
    (0.9, depth_smoothness.InverseDepthSmoothnessLoss()),
]

criterion = weighted_loss.WeightedLoss(
    [term for _, term in _weighted_terms],
    weights=[weight for weight, _ in _weighted_terms],
)

In [None]:
from torch import optim

optimizer = optim.Adam(model.parameters(), lr=1e-2)
# Multiply the learning rate by 0.01 after every 2 scheduler steps; the
# scheduler is stepped once per epoch, after the training phase.
# NOTE(review): with these values the rate goes 1e-2 -> 1e-4 -> 1e-6 -> 1e-8
# within the first 7 epochs; a larger step_size (e.g. 10) may have been
# intended — confirm before relying on long runs.
adaptive_lr = optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.01)

In [None]:
import time


def train_model(model, criterion, optimizer, scheduler, epochs=50):
    """Train and validate ``model``, checkpointing the best weights per phase.

    Every epoch runs a "train" pass followed by a "valid" pass over the
    module-level ``dataloaders``. Whenever the sample-weighted average loss of
    a phase improves on the best value seen so far for that phase, the current
    weights are written to ``model_<phase>.pth`` in the working directory.

    Args:
        model: Network to optimize; moved to the module-level ``device``.
        criterion: Loss module, invoked as ``criterion(target, output)``.
        optimizer: Optimizer stepped once per training batch.
        scheduler: Learning-rate scheduler stepped once per epoch, right after
            the training pass.
        epochs: Number of epochs to run.

    Returns:
        The model moved back to the CPU. It holds the weights of the *last*
        epoch; the best checkpoints are only stored on disk and are not
        reloaded here.
    """
    # Infinity (rather than a large finite sentinel) guarantees that the first
    # epoch always produces a checkpoint for each phase.
    best_loss = {phase: float("inf") for phase in ("train", "valid")}

    model = model.to(device)
    criterion = criterion.to(device)
    for epoch in range(1, epochs + 1):
        print(f"Epoch {epoch}/{epochs}")
        print('-' * 10)

        # Training and validation phases
        for phase in ("train", "valid"):
            is_training = phase == "train"
            # train(False) is equivalent to eval()
            model.train(is_training)

            phase_tick = time.time()
            phase_loss = 0.0
            phase_total = 0

            for batch in dataloaders[phase]:
                inputs = batch["input"].to(device)
                targets = batch["target"].to(device)
                batch_size = inputs.shape[0]

                if is_training:
                    optimizer.zero_grad()

                # Autograd bookkeeping is only needed while training
                with torch.set_grad_enabled(is_training):
                    outputs = model(inputs)

                    # Repeat the channel dimension 3 times before the loss
                    outputs = torch.repeat_interleave(outputs, 3, dim=1)
                    targets = torch.repeat_interleave(targets, 3, dim=1)

                    loss = criterion(targets, outputs)

                    if is_training:
                        loss.backward()
                        optimizer.step()

                # Accumulate a plain float: summing the loss tensor itself
                # would keep its autograd history alive for the whole epoch.
                phase_loss += loss.item() * batch_size
                phase_total += batch_size

            if is_training:
                scheduler.step()

            phase_time_diff = time.time() - phase_tick
            epoch_loss = phase_loss / phase_total
            print("{} Loss: {:.4f} - {} minute(s) {:.4f} second(s)".format(
                phase, epoch_loss, phase_time_diff // 60, phase_time_diff % 60
            ))

            # Checkpoint whenever this phase reaches a new best average loss
            if epoch_loss < best_loss[phase]:
                best_loss[phase] = epoch_loss
                torch.save(model.state_dict(), "model_{}.pth".format(phase))

        print()

    # Hand the modules back on the CPU. The best weights are NOT restored
    # here; they live in the "model_<phase>.pth" checkpoints.
    model = model.cpu()
    criterion = criterion.cpu()
    return model

In [None]:
# Capture the returned model instead of leaving the call as the cell's last
# expression, which would print the entire module repr as cell output.
model = train_model(model, criterion, optimizer, adaptive_lr, epochs=30)

In [None]:
# Persist the weights currently held by the model (state after training).
torch.save(obj=model.state_dict(), f='model_final.pth')

# Evaluating Model

In [None]:
# Read the checkpoint written for the best validation loss, mapped onto the
# CPU, load it into the model, then switch to inference mode.
best_valid_state = torch.load('model_valid.pth', map_location=torch.device('cpu'))
model.load_state_dict(best_valid_state)
model.eval()

In [None]:
from torchvision.transforms import ToTensor


# Evaluation loader over the validation frame, one sample per batch. Only the
# tensor conversion is applied here, so the un-normalized tensors stay
# available for plotting.
test_dataset = DataLoader(
    dataset=DIODEDataset(
        valid,
        image_transform=base_preprocess,
        depth_transform=base_preprocess,
    ),
    batch_size=1,
    shuffle=True,
    num_workers=2,
)

In [None]:
from local.utils import visualize_depth_map
from piqa import SSIM
import numpy as np


# Build the metric objects once; re-creating them for every sample is
# loop-invariant work.
ssim_metric = SSIM(k1=0.01 * 2, k2=0.03 * 2)
mse_metric = nn.MSELoss()

# Score and plot the first four samples drawn from the evaluation loader.
for batch_index, batch in enumerate(test_dataset):
    input, target, mask = batch["input"], batch["target"], batch["mask"]

    # Inference only: apply the image pipeline and skip autograd bookkeeping.
    with torch.no_grad():
        output = model(image_preprocess(input))

    # Metrics are computed on 3-channel copies; the target goes through the
    # depth pipeline (clip + scale) so it is comparable with the prediction.
    criterion_target = torch.repeat_interleave(depth_preprocess(target), 3, dim=1)
    criterion_output = torch.repeat_interleave(output, 3, dim=1)
    print(
        ssim_metric(criterion_target, criterion_output),
        mse_metric(criterion_target, criterion_output),
    )

    # Drop the batch axis and move channels last for plotting.
    input = np.transpose(input.squeeze(axis=0), (1, 2, 0))
    target = np.transpose(target.squeeze(axis=0), (1, 2, 0))
    output = np.transpose(output.squeeze(axis=0), (1, 2, 0))
    # NOTE(review): unlike the tensors above, mask is not squeezed first, so
    # this assumes the batched mask is 3-D — confirm against DIODEDataset.
    mask = np.transpose(mask, (1, 2, 0))

    # target = np.ma.masked_where(~mask, target)
    visualize_depth_map(input, target, output)

    if batch_index == 3:
        break