## Model setup and training

In [None]:
# In[1]:

import os
import glob
import time
from datetime import timedelta
import cv2
import numpy as np
import pandas as pd

import torch
from torch.utils.data import DataLoader
from torchvision.transforms import Compose

import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger

from dpt.plmodels import InteriorNetDPT
from dpt.transforms import Resize, NormalizeImage, PrepareForNet
from data.InteriorNetDataset import InteriorNetDataset
from data.metrics import SILog, get_metrics
from util.gpu_config import get_batch_size

In [None]:
# Fix the torch and numpy RNG seeds.
torch.manual_seed(0)
np.random.seed(0)

# k8s paths
k8s = True
k8s_repo = r'opt/repo/dynamic-inference'
k8s_pvc = r'christh9-pvc'

# path settings (defaults; replaced below when k8s is True)
input_path = 'input'
output_path = 'output_monodepth'
model_path = 'weights/dpt_hybrid_nyu-2ce69ec7.pt'
dataset_path = 'video_inference_common/resources'
logs_path = 'train-logs'
# Always define logs_dir: later cells read it, and it used to be assigned only
# inside the k8s branch, which raised NameError whenever k8s was False.
logs_dir = logs_path

if k8s:
    # Repository resources live under the repo checkout; the weights and the
    # training logs live under the PVC directory.
    input_path = os.path.join(k8s_repo, input_path)
    output_path = os.path.join(k8s_repo, output_path)
    model_path = os.path.join(k8s_pvc, 'dpt-hybrid-nyu.pt')
    dataset_path = os.path.join(k8s_repo, dataset_path)
    logs_dir = os.path.join(k8s_pvc, logs_path)
    # The k8s paths above are relative, so make '/' the working directory
    # they are resolved against.
    os.chdir('/')

In [None]:
# Width and height handed to the Resize step.
net_w, net_h = 640, 480

# Normalisation step with mean 0.5 / std 0.5 on each of the three channels.
normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])

# Resize step, built separately so the pipeline below reads as a flat list.
resize = Resize(
    net_w,
    net_h,
    resize_target=None,
    keep_aspect_ratio=True,
    ensure_multiple_of=32,
    resize_method="minimal",
    image_interpolation_method=cv2.INTER_CUBIC,
)

# Preprocessing pipeline: resize -> normalise -> PrepareForNet.
transform = Compose([resize, normalization, PrepareForNet()])

In [None]:
# training hyperparameters
# (No timer is started here: the next cell starts its own before timing anything.)

# Batch size is chosen by the project helper get_batch_size.
batch_size = get_batch_size(None)
lr = 1e-4
num_epochs = 200

print('-- Hyperparams --')
print(f'Batchsize: {batch_size}')
print(f'Learning rate: {lr}')
print(f'Epochs: {num_epochs}')
print('-----------------')

In [None]:
# Timer for this cell; note that it covers model construction as well as the
# dataset/dataloader setup reported by the print at the bottom.
start = time.time()

# model setup
model = InteriorNetDPT(batch_size, lr, num_epochs, model_path)

# logging setup
# Experiment index = number of entries in logs_dir whose name contains '.pt'.
# A logs directory that does not exist yet counts as zero previous experiments
# instead of raising FileNotFoundError from os.listdir.
if os.path.isdir(logs_dir):
    exp_idx = sum(1 for name in os.listdir(logs_dir) if '.pt' in name)
else:
    exp_idx = 0

# dataloader setup
interiornet_dataset = InteriorNetDataset(dataset_path, transform=transform, subsample=True)
# Four loader workers per visible GPU; load in the main process when there is no GPU.
num_workers = 4 * torch.cuda.device_count() if torch.cuda.is_available() else 0
dataloader = DataLoader(interiornet_dataset,
                        batch_size=model.hparams.batch_size,
                        shuffle=True,
                        pin_memory=True,
                        num_workers=num_workers)


print(f'Created datasets in {timedelta(seconds=round(time.time()-start,2))}')

In [None]:
# ddp doesn't work on jupyter
if not torch.cuda.is_available():
    # No GPU: fall back to a single-epoch trainer.
    trainer = pl.Trainer(max_epochs=1)
else:
    # Train on every visible GPU for the epoch count stored in the model's hparams.
    gpu_count = torch.cuda.device_count()
    trainer = pl.Trainer(gpus=gpu_count, max_epochs=model.hparams.num_epochs)

In [None]:
print('Training')

# Time the whole fit call and report it as a timedelta rounded to 1/100 s.
start = time.time()
trainer.fit(model, dataloader)
elapsed = timedelta(seconds=round(time.time() - start, 2))

print(f'Training completed in {elapsed}')
print(f'Training checkpoints and logs are saved in {trainer.log_dir}')

## Model evaluation

In [None]:
# Imports needed only by the evaluation section.
import matplotlib.pyplot as plt
from util.validate_nyu import BadPixelMetric

# Render all following figures with matplotlib's dark background style.
plt.style.use('dark_background')

In [None]:
# Evaluate on the GPU when one is available, otherwise on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model.to(device)

# compare to original model predictions: a second InteriorNetDPT built from the
# same arguments as `model`, used as the baseline in the evaluation loop.
original_model = InteriorNetDPT(batch_size, lr, num_epochs, model_path)
original_model.to(device)
print('Loaded original model')

In [None]:
metric = BadPixelMetric()

# Switch both networks to eval mode before predicting. torch.no_grad() only
# disables gradient tracking; it does not change module training/eval state,
# and a newly constructed module starts out in training mode.
model.eval()
original_model.eval()

with torch.no_grad():
    for i, sample in enumerate(interiornet_dataset):
        image = torch.tensor(sample['image']).to(device)
        depth = torch.tensor(sample['depth']).to(device)

        # Add a leading batch dimension for the models and the metric.
        batch = image.unsqueeze(0)
        target = depth.unsqueeze(0)
        # Mask that is True wherever the ground-truth depth is not NaN.
        valid = ~torch.isnan(target)

        # Both predictions go to the metric in the same batched form, so the
        # fine-tuned and the original model are scored identically.
        out = model(batch)
        out_nyu = original_model(batch)
        print(f'Frame {i} finetune err:', metric(out, target, valid).item())
        print(f'Frame {i} NYU err:', metric(out_nyu, target, valid).item())

        # Drop the batch dimension and move everything to the CPU for plotting.
        panels = [
            ('Finetuned', out.squeeze(0).to('cpu')),
            ('Truth', depth.to('cpu')),
            ('NYU', out_nyu.squeeze(0).to('cpu')),
        ]

        fig = plt.figure(figsize=(12, 12))
        for j, (name, img) in enumerate(panels):
            ax = fig.add_subplot(1, 3, j + 1)
            if j == 0:
                # Only the left-most panel carries the frame label.
                ax.set_ylabel(f'Frame {i}', rotation=0, size='large')
            ax.imshow(img)
            ax.set_title(name)

            ax.yaxis.set_label_coords(-0.2, 0.5)
            ax.get_xaxis().set_ticks([])
            ax.get_yaxis().set_ticks([])

        # Show every frame (the last one included), then release the figure so
        # open figures do not pile up while iterating over the dataset.
        plt.show()
        plt.close(fig)