In [None]:
# Reload imported modules automatically before each cell runs, so edits to the
# local project modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import os
from pathlib import Path

import numpy as np
import torch
import torch.utils.data as data
from torch import Tensor
from torch.nn import Module
from torch.optim import AdamW
from torch.optim import Optimizer
from torch.utils.data import DataLoader
from torch.utils.data import Dataset

from dpt import DepthSense
from util.loss import GeoNetLoss

In [None]:
# Seed PyTorch's global random number generator for repeatable runs.
# Only PyTorch is seeded in this cell; NumPy's RNG is left untouched.
torch.random.manual_seed(7643)

In [None]:
class DepthSenseDataset(Dataset):
    """
    Map-style dataset that reads one sample per zero-padded directory.

    Sample ``i`` is stored in ``<root_dir>/<i padded to 7 digits>/`` as three
    NumPy files: ``frame.npy``, ``depth.npy`` and ``normal.npy``.

    Args:
        root_dir: Directory containing the per-sample sub-directories.
        size: Number of samples exposed by the dataset. Defaults to 44924,
            the value that used to be hard-coded in ``__len__``.

    Raises:
        ValueError: If ``size`` is negative.
    """

    def __init__(self, root_dir: str, size: int = 44924):
        if size < 0:
            raise ValueError(f"size must be non-negative, got {size}")
        self.root_dir: Path = Path(root_dir)
        # The sample count is configured rather than discovered: listing the
        # root directory is very slow on a mounted Google Drive, so sample
        # paths are derived from the index instead.
        self.size: int = size

    @staticmethod
    def _load(path: Path) -> Tensor:
        """Read a ``.npy`` file and return its content as a float32 tensor."""
        return torch.from_numpy(np.load(path)).float()

    def __getitem__(self, i: int) -> tuple[Tensor, Tensor, Tensor]:
        """
        Load sample ``i`` from disk.

        Returns:
            The ``(image, depth, normal)`` tensors, all converted to float32.

        Raises:
            IndexError: If ``i`` is outside ``[0, len(self))``.
            FileNotFoundError: If one of the sample files is missing.
        """
        index = int(i)
        if not 0 <= index < self.size:
            raise IndexError(
                f"index {index} is out of range for dataset of size {self.size}"
            )
        sample_dir: Path = self.root_dir / f"{index:07d}"
        image: Tensor = self._load(sample_dir / "frame.npy")
        depth: Tensor = self._load(sample_dir / "depth.npy")
        normal: Tensor = self._load(sample_dir / "normal.npy")
        return image, depth, normal

    def __len__(self) -> int:
        """Return the configured number of samples."""
        return self.size

In [None]:
# Parameters and hyperparameters used for training.
description: str = "DepthSense for Metric Depth and Normal Estimation"
# Checkpoint path template; the "{}" placeholder receives the dataset name.
model_path: str = "models/teacher_{}.pth"

batch_size: int = 4
# Optimizer settings (betas, decay, eps and lr are handed to AdamW).
betas: tuple[float, float] = (0.9, 0.999)
# Name of the dataset folder under "datasets/" -- a plain string, not a
# Dataset instance.
dataset_name: str = "hypersim"
decay: float = 1e-2
# Prefer CUDA, then Apple's MPS backend, and fall back to the CPU.
device: str = (
    "cuda"
    if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available()
    else "cpu"
)
# Model configuration passed to the DepthSense constructor.
encoder: str = "vits"
epochs: int = 10
eps: float = 1e-8
features: int = 128
lr: float = 1e-4
# Forwarded to the model on every forward pass.
refine_edges: bool = False

# Data splitting: 90% of the samples for training, the remainder for validation.
dataset: Dataset = DepthSenseDataset(f"datasets/{dataset_name}")
data_size: int = len(dataset)
train_size: int = int(0.9 * data_size)
val_size: int = data_size - train_size
# Use a dedicated, explicitly seeded generator so the partition does not depend
# on how much of the global RNG stream has already been consumed (for example
# when this cell is re-run in a live kernel). Otherwise a re-run could move
# training samples into the validation set of an already trained checkpoint.
split_generator: torch.Generator = torch.Generator().manual_seed(7643)
train_set, val_set = data.random_split(
    dataset, [train_size, val_size], generator=split_generator
)

# Model initialization.
# Checkpoint file name: the "{}" placeholder of the path template is filled in
# with the dataset name.
model_name: str = model_path.replace("{}", dataset_name)
model: DepthSense = DepthSense(encoder, features, device=device)
criterion: Module = GeoNetLoss()
# AdamW arguments spelled out by keyword so each hyperparameter is visibly
# bound to the intended optimizer setting.
optimizer: Optimizer = AdamW(
    model.parameters(),
    lr=lr,
    betas=betas,
    eps=eps,
    weight_decay=decay,
)

In [None]:
# OPTIONAL. Model loading, if not training from scratch.
# The checkpoint written by the training cell is a state dict, so it must be
# loaded INTO the existing model rather than assigned to `model`. Keeping the
# same model object also keeps the optimizer valid, since it was built from
# this model's parameters.
try:
    state_dict = torch.load(model_name, map_location=device, weights_only=True)
except FileNotFoundError:
    # No checkpoint yet: keep the freshly initialized model.
    print(f"No checkpoint found at {model_name}; using a freshly initialized model.")
else:
    model.load_state_dict(state_dict)

model.eval()

In [None]:
# Training.
model.train()

train_loader: DataLoader = DataLoader(train_set, batch_size, shuffle=True)
max_iters: int = len(train_loader)
for e in range(epochs):
    # Sum of the batch losses seen so far in this epoch.
    running_loss: float = 0.0
    for i, (x, z_gt, n_gt) in enumerate(train_loader):
        # Move to appropriate device. The frame axes are reordered with
        # permute(0, 3, 1, 2) -- presumably channels-last to channels-first;
        # TODO confirm against the stored frame layout.
        x = x.to(device).permute(0, 3, 1, 2)
        z_gt = z_gt.to(device)
        n_gt = n_gt.to(device)
        # Forward pass.
        z_hat, n_hat = model(x, refine_edges)
        loss: Tensor = criterion(z_hat, z_gt, n_hat, n_gt)
        # Backward pass.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Statistics recollection and display.
        running_loss += loss.item()
        # Mean over the i + 1 batches processed so far. Dividing by `i` would
        # raise ZeroDivisionError on the first batch. A separate name keeps
        # `loss` bound to the tensor.
        mean_loss: float = running_loss / (i + 1)
        print(f"Epoch {e}, iter: {i + 1} -- Loss: {mean_loss:.3f}")

# Save current model. Create the target directory first so a missing folder
# cannot discard the result of a full training run.
Path(model_name).parent.mkdir(parents=True, exist_ok=True)
torch.save(model.state_dict(), model_name)

In [None]:
# Validation.
model.eval()

# No shuffling: a fixed order keeps the composition of the batches, and hence
# the batch-averaged loss, stable between runs.
val_loader: DataLoader = DataLoader(val_set, batch_size, shuffle=False)
max_iters: int = len(val_loader)
running_loss: float = 0.0
# Gradients are not needed for evaluation; disabling them avoids building
# autograd graphs and saves memory.
with torch.no_grad():
    for i, (x, z_gt, n_gt) in enumerate(val_loader):
        # Move to appropriate device, applying the same axis permutation to
        # the frames as the training loop does.
        x = x.to(device).permute(0, 3, 1, 2)
        z_gt = z_gt.to(device)
        n_gt = n_gt.to(device)
        # Forward pass.
        # NOTE(review): four outputs are unpacked here while the training loop
        # unpacks two -- confirm what the model returns in eval mode.
        z, n, z_hat, n_hat = model(x, refine_edges)
        loss: Tensor = criterion(z_hat, z_gt, n_hat, n_gt)
        running_loss += loss.item()

# Statistics recollection and display: mean of the per-batch losses.
avg_loss: float = running_loss / max_iters
print(f"Validation loss: {avg_loss:.3f}")

In [None]:
# Pending tasks.

# TODO: Train the teacher and distill its knowledge into the students.