# Imports and configuration

In [None]:
from nyuv2_dataset import NYUv2Dataset
from cityscapes_dataset import CityscapesDataset
import torch
from torch.utils.data import DataLoader
from cross_stitchnet import CrossStitchNet
from densenet import DenseNet
from depthnet import DepthNet
from splitnet import SplitNet
from stan import STAN
from mtan import MTAN
from segnet import SegNet
from normalnet import NormalNet
from trainer import Trainer
from utils import count_params, visualize_results, build_stats_dict
import matplotlib.pyplot as plt
import os
import numpy as np

In [None]:
# Experiment configuration: pick the dataset, then derive the task list and batch size.
dataset_string = 'cityscapes'
if dataset_string == 'cityscapes':
    tasks = ['segmentation', 'depth']
    BATCH_SIZE = 8
else:
    # any other dataset string (the data-loading cell handles 'nyuv2') also trains normals
    tasks = ['segmentation', 'depth', 'normal']
    BATCH_SIZE = 2
LR = 1e-4
# Per-stage channel widths shared by every architecture below.
# NOTE: this name shadows the builtin `filter`; renaming it requires updating every later cell.
filter = [64, 128, 256, 512, 512]

In [None]:
# Prefer the second GPU (cuda:1) when it exists, otherwise the first GPU, otherwise the CPU.
# Hard-coding "cuda:1" alone would make every `.to(device)` fail on single-GPU machines.
if torch.cuda.is_available():
    device = torch.device("cuda:1" if torch.cuda.device_count() > 1 else "cuda:0")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

# Data Loading

In [None]:
def _print_batch_summary(image, targets, task_names):
    """Print the shape and max/min values of an image batch and its task targets.

    `targets` is the per-task dict yielded by the DataLoader; only the keys listed in
    `task_names` are reported, in that order.
    """
    labels = {'segmentation': 'Label', 'depth': 'Depth', 'normal': 'Normals'}
    shapes = ', '.join(f'{labels.get(t, t)}: {list(targets[t].shape)}' for t in task_names)
    print(f'Image: {list(image.shape)}, {shapes}')
    print(f'Image: {image.max().item()}, {image.min().item()}')
    for t in task_names:
        print(f'{labels.get(t, t)}: {targets[t].max().item()}, {targets[t].min().item()}')


def _save_sample_plot(array, filename, **imshow_kwargs):
    """Draw `array` with imshow on a fresh figure, save it to `filename`, then close it.

    A fresh figure per plot keeps successive images from stacking on the shared pyplot
    axes and frees the figure once it has been written.
    """
    fig, ax = plt.subplots()
    ax.imshow(array, **imshow_kwargs)
    fig.savefig(filename)
    plt.close(fig)


if dataset_string == 'nyuv2':
    print("NYUv2 Dataset")
    nyuv2_train = NYUv2Dataset(root="../dataset/nyuv2_preprocessed", split='train')
    nyuv2_val = NYUv2Dataset(root="../dataset/nyuv2_preprocessed", split='val')
    train_dl = DataLoader(nyuv2_train, batch_size=BATCH_SIZE, shuffle=True)
    val_dl = DataLoader(nyuv2_val, batch_size=BATCH_SIZE, shuffle=False)
    classes = nyuv2_train.get_classes()
else:
    print("Cityscapes Dataset")
    cityscapes_train = CityscapesDataset(root="../dataset/cityscapes_preprocessed")
    cityscapes_val = CityscapesDataset(root="../dataset/cityscapes_preprocessed", split='val')
    train_dl = DataLoader(cityscapes_train, batch_size=BATCH_SIZE, shuffle=True)
    val_dl = DataLoader(cityscapes_val, batch_size=BATCH_SIZE, shuffle=False)
    classes = cityscapes_train.get_classes()

# Sanity-check one (shuffled) training batch: shapes and value ranges for every task.
image, out = next(iter(train_dl))
_print_batch_summary(image, out, tasks)
print(f"Number of classes: {classes}")

# Save the first validation sample and each of its targets as images for visual inspection.
image, out = next(iter(val_dl))
_save_sample_plot(image[0].permute(1, 2, 0).cpu().numpy(), f'../{dataset_string}_image')
for t, target in out.items():
    sample = target[0]
    if t == 'normal':
        # move the leading (channel) axis last, as imshow expects for multi-channel images
        _save_sample_plot(sample.permute(1, 2, 0).cpu().numpy(), f'../{dataset_string}_{t}')
    elif t == 'depth':
        _save_sample_plot(sample.cpu().numpy(), f'../{dataset_string}_{t}', cmap='jet')
    else:
        _save_sample_plot(sample.cpu().numpy(), f'../{dataset_string}_{t}')

# Model Definitions and Parameter Counts

In [None]:
# Build every architecture with the shared `filter` widths so the next cell can compare
# their parameter counts against MTAN. Statement order is kept as in the original.
cross = CrossStitchNet(tasks=tasks, classes=classes, filter=filter, mid_layers=1)
dense = DenseNet(tasks=tasks, classes=classes, filter=filter, mid_layers=0)
depth = DepthNet(mid_layers=6, filter=filter)
mtan = MTAN(tasks=tasks, classes=classes, filter=filter, mid_layers=0)
norm = NormalNet(mid_layers=6, filter=filter)
seg = SegNet(classes=classes, mid_layers=6, filter=filter)
split = SplitNet(tasks=tasks, classes=classes, mid_layers=6, filter=filter)
# STAN receives a single task: the first entry of `tasks`.
stan = STAN(task=tasks[0], classes=classes, mid_layers=4, filter=filter)

In [None]:
# Parameter budget of every architecture, using MTAN as the reference point.
(mtan_params, cross_params, dense_params, depth_params,
 norm_params, seg_params, split_params, stan_params) = (
    count_params(m) for m in (mtan, cross, dense, depth, norm, seg, split, stan)
)
print(f"MTAN: {mtan_params}")
for label, n_params in (
    ("Cross", cross_params),
    ("Dense", dense_params),
    ("Depth", depth_params),
    ("Norm", norm_params),
    ("Seg", seg_params),
    ("Split", split_params),
    ("STAN", stan_params),
):
    # second field: does this model have at least as many parameters as MTAN?
    print(f"{label}: {n_params}, {n_params >= mtan_params}")

# Model Training

In [None]:
# Move the model to train onto the selected device and report its size.
model = mtan.to(device)
model_params = count_params(model)
print(f"{model.name} has {model_params} parameters")

In [None]:
# Adam over all model parameters at the configured learning rate.
opt = torch.optim.Adam(params=model.parameters(), lr=LR)
# dwa=False: the evaluation cell maps this setting to the '_equal' checkpoint directories.
trainer = Trainer(
    model,
    opt,
    dataset_string,
    device,
    save_path='../',
    dwa=False,
)

In [None]:
# Run training for 5 epochs, validating on val_dl; `check` and `grad` are passed through
# unchanged to Trainer.train (see its definition for their exact meaning).
trainer.train(
    train_dl,
    val_dl,
    epochs=5,
    check=1,
    save=False,
    grad=True,
)

# Model Evaluation

In [None]:
# Choose the checkpoint to evaluate. Keep exactly one constructor active; it must match the
# architecture and hyper-parameters the checkpoint was trained with (see the definitions above).
dwa_model = False  # multi-task models only: True -> '_dwa' checkpoint dir, False -> '_equal'
model = DenseNet(filter=filter, classes=classes, mid_layers=0, tasks=tasks)
# model = MTAN(filter=filter, mid_layers=0, classes=classes, tasks=tasks)
# model = CrossStitchNet(filter=filter, classes=classes, mid_layers=1, tasks=tasks)
# model = SplitNet(filter=filter, mid_layers=6, classes=classes, tasks=tasks)
# model = DepthNet(filter=filter, mid_layers=6)
# model = SegNet(filter=filter, mid_layers=6, classes=classes)
# model = NormalNet(filter=filter, mid_layers=6)
# model = STAN(filter=filter, mid_layers=4, classes=classes, task='segmentation')
# model = STAN(filter=filter, mid_layers=4, classes=classes, task='depth')
# model = STAN(filter=filter, mid_layers=4, classes=classes, task='normal')
path = f'../models/{dataset_string}/{model.name}'
if len(model.tasks) > 1:
    path += '_dwa' if dwa_model else '_equal'
path += f'/{model.name}_100.pth'
print(path)
# map_location remaps tensors saved on another device (e.g. cuda:1) onto the current one,
# so a checkpoint trained on a different GPU setup still loads here.
model.load_state_dict(torch.load(path, map_location=device, weights_only=True))
# Evaluate on the selected device and in inference mode (affects dropout / batch-norm layers).
model = model.to(device)
model.eval()

In [None]:
# Save visualizations for the first `nresults` validation samples.
nresults = 10
id_result = 0  # dataset index of the first sample in the current batch
for image, out in val_dl:
    state = visualize_results(model, device, image, out, id_result, nresults, out=True, save=True, save_path='../', dataset_str=dataset_string)
    # advance by the actual batch length so the offset is right even if BATCH_SIZE was
    # changed after val_dl was built, or for a short final batch
    id_result += len(image)
    # NOTE(review): a truthy return presumably means `nresults` samples are done — confirm in utils.
    if state:
        break

In [None]:
def _flatten_stats(stats):
    """Flatten the nested metric dict into parallel name / formatted-value lists.

    Each metric's ``compute()`` result is printed and formatted with 4 decimals. Every
    metric except 'ad' yields a single tensor. 'ad' yields a dict whose entries are
    reported individually; its 'tolls' entry holds one value per tolerance listed in the
    metric's ``tolls`` attribute.

    Returns:
        (names, values): two lists of equal length, values as '%.4f' strings.
    """
    names, values = [], []

    def _add(name, tensor):
        value = tensor.cpu().item()
        names.append(name)
        values.append(f'{value:.4f}')
        print(f"{name}: {value:.4f}")

    for task_metrics in stats.values():
        for metric_name, metric in task_metrics.items():
            result = metric.compute()
            if metric_name != 'ad':
                _add(metric_name, result)
                continue
            for key, item in result.items():
                if key != 'tolls':
                    _add(f'{metric_name}_{key}', item)
                else:
                    # one reported value per tolerance threshold, named after the threshold
                    for tol, tol_value in zip(metric.tolls, item):
                        _add(f'{key}_{tol}', tol_value)
    return names, values


stats = build_stats_dict(model, device)
# A throwaway Trainer (no optimizer) is used only to run one validation pass that updates `stats`.
train_stats = Trainer(model, None, dataset_string, device, dwa=dwa_model, save_path='../tmp')
loss = train_stats._val_epoch(val_dl, stats)

save_path = '../'
if len(model.tasks) == 1:
    path = save_path + f"results/{dataset_string}/{model.name}"
    os.makedirs(path, exist_ok=True)
else:
    dwa_string = 'dwa' if dwa_model else 'equal'
    path = save_path + f"results/{dataset_string}/{model.name}_{dwa_string}"
    for t in model.tasks:
        os.makedirs(path + f'/{t}', exist_ok=True)

stats_str, stats_val = _flatten_stats(stats)
# One "name: value" line per metric.
np.savetxt(path + '/stats.txt', list(zip(stats_str, stats_val)), delimiter=': ', fmt='%s')