In [1]:
import os

while "src" not in os.listdir():
    assert "/" != os.getcwd(), "src directory not found"
    os.chdir("..")

from tqdm.notebook import tqdm
import numpy as np
import matplotlib.pyplot as plt

import torch
from torch import nn
from torch.utils.data import DataLoader, random_split
from torch.nn import functional as F


from src.lib.nyu_dataset import NYUDataset, transform
from src.lib.depth_estimator import DepthEstimator
from src.lib.resnet_loader import load_classifier_resnet50, load_contrastive_resnet50

In [2]:
DATA_DIR = "data"
DATASET_FILE = "nyu_depth_v2_labeled.mat"

batch_size = 64
lr = 0.0001
epochs = 100

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


In [3]:
dataset = NYUDataset(os.path.join(DATA_DIR, DATASET_FILE), transform=transform)

n_train, n_val = int(0.8 * len(dataset)), int(0.1 * len(dataset))
n_test = len(dataset) - n_train - n_val

train_set, val_set, test_set = random_split(dataset, [n_train, n_val, n_test])

train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=True)

In [4]:
model = DepthEstimator(load_classifier_resnet50()).to(device)

In [5]:
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
loss_fn = nn.MSELoss()

In [6]:
validation_loss_history = []
loss_history = []

for epoch in range(epochs):
    pbar = tqdm(enumerate(train_loader), total=len(train_loader))
    model.train()
    for i, (color_map, depth_map) in pbar:
        color_map = color_map.to(device)
        depth_map = depth_map.to(device)

        pred = model(color_map)
        loss = loss_fn(pred, depth_map)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        loss_history.append(loss.item())

        # update pbar description with batch and loss
        pbar.set_description(f"Epoch {epoch + 1}/{epochs} - Batch {i + 1}/{len(train_loader)} - Loss: {loss.item():.4f}")
    
    # take tensors we don't need for validation off the gpu
    del color_map, depth_map, pred, loss
    torch.cuda.empty_cache()
    
    # compute validation loss
    val_loss = 0
    model.eval()
    with torch.no_grad():
        for i, (color_map, depth_map) in enumerate(val_loader):
            color_map = color_map.to(device)
            depth_map = depth_map.to(device)

            pred = model(color_map)
            val_loss += loss_fn(pred, depth_map).item()
        val_loss /= len(val_loader)

        validation_loss_history.append(val_loss)
    # display validation loss
    print(f"Epoch {epoch + 1}/{epochs} - Validation Loss: {val_loss:.4f}")

  0%|          | 0/19 [00:00<?, ?it/s]

Epoch 1/100 - Validation Loss: 0.3765


  0%|          | 0/19 [00:00<?, ?it/s]

Epoch 2/100 - Validation Loss: 0.2200


  0%|          | 0/19 [00:00<?, ?it/s]

Epoch 3/100 - Validation Loss: 0.2193


  0%|          | 0/19 [00:00<?, ?it/s]

Epoch 4/100 - Validation Loss: 0.1786


  0%|          | 0/19 [00:00<?, ?it/s]

Epoch 5/100 - Validation Loss: 0.2119


  0%|          | 0/19 [00:00<?, ?it/s]

Epoch 6/100 - Validation Loss: 0.1668


  0%|          | 0/19 [00:00<?, ?it/s]

Epoch 7/100 - Validation Loss: 0.1543


  0%|          | 0/19 [00:00<?, ?it/s]

Epoch 8/100 - Validation Loss: 0.1540


  0%|          | 0/19 [00:00<?, ?it/s]

Epoch 9/100 - Validation Loss: 0.1347


  0%|          | 0/19 [00:00<?, ?it/s]

Epoch 10/100 - Validation Loss: 0.1671


  0%|          | 0/19 [00:00<?, ?it/s]

Epoch 11/100 - Validation Loss: 0.1284


  0%|          | 0/19 [00:00<?, ?it/s]

Epoch 12/100 - Validation Loss: 0.1606


  0%|          | 0/19 [00:00<?, ?it/s]

Epoch 13/100 - Validation Loss: 0.1593


  0%|          | 0/19 [00:00<?, ?it/s]

Epoch 14/100 - Validation Loss: 0.1425


  0%|          | 0/19 [00:00<?, ?it/s]

Epoch 15/100 - Validation Loss: 0.1269


  0%|          | 0/19 [00:00<?, ?it/s]

Epoch 16/100 - Validation Loss: 0.1403


  0%|          | 0/19 [00:00<?, ?it/s]

Epoch 17/100 - Validation Loss: 0.1485


  0%|          | 0/19 [00:00<?, ?it/s]

Epoch 18/100 - Validation Loss: 0.1295


  0%|          | 0/19 [00:00<?, ?it/s]

Epoch 19/100 - Validation Loss: 0.1324


  0%|          | 0/19 [00:00<?, ?it/s]

Epoch 20/100 - Validation Loss: 0.1268


  0%|          | 0/19 [00:00<?, ?it/s]

Epoch 21/100 - Validation Loss: 0.1427


  0%|          | 0/19 [00:00<?, ?it/s]

Epoch 22/100 - Validation Loss: 0.1391


  0%|          | 0/19 [00:00<?, ?it/s]

Epoch 23/100 - Validation Loss: 0.1703


  0%|          | 0/19 [00:00<?, ?it/s]

Epoch 24/100 - Validation Loss: 0.1866


  0%|          | 0/19 [00:00<?, ?it/s]

Epoch 25/100 - Validation Loss: 0.1240


  0%|          | 0/19 [00:00<?, ?it/s]

Epoch 26/100 - Validation Loss: 0.1237


  0%|          | 0/19 [00:00<?, ?it/s]

In [None]:
from datetime import datetime
import json

time_str = f"{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}"
experiment_dir = os.path.join("experiments", time_str)
os.mkdir(experiment_dir)

# save model
with open(os.path.join(experiment_dir, "model.pth"), "wb") as f:
    torch.save(model.state_dict(), f)

config = {
    "batch_size": batch_size,
    "lr": lr,
    "epochs": epochs,
    "device": device.type,
    "experiment_dir": experiment_dir
}
# save config
with open(os.path.join(experiment_dir, "config.json"), "w") as f:
    json.dump(config, f)

In [None]:
# Plot loss history
plt.plot(loss_history)
plt.xlabel("Batch")
plt.ylabel("Loss")
plt.title("Training Loss")
# save figure
plt.savefig(os.path.join(experiment_dir, "training_loss_history.png"))
plt.show()

# Plot validation loss history
plt.plot(validation_loss_history)
plt.xlabel("Epoch")
plt.ylabel("Validation Loss")
plt.title("Validation Loss")
# save figure
plt.savefig(os.path.join(experiment_dir, "validation_loss_history.png"))
plt.show()

In [None]:
# Visualize the models predictions
# pick 9 random samples
idxs = np.random.choice(len(dataset), 9)
samples = [dataset[idx] for idx in idxs]
color_maps = []
depth_maps = []
preds = []
for color_tensor, depth_tensor in samples:
    # to numpy array
    color_map = color_tensor.cpu().numpy().transpose(1, 2, 0)
    depth_map = depth_tensor.cpu().numpy().squeeze()
    depth_prediction = model(color_tensor.unsqueeze(0).to(device)).cpu().detach().squeeze().numpy()

    color_maps.append(color_map)
    depth_maps.append(depth_map)
    preds.append(depth_prediction)

# plot 3x9 grid of color, depth, and prediction
fig, axes = plt.subplots(3, 9, figsize=(9*4, 3*4))
for i in range(9):
    axes[0, i].imshow(color_maps[i])
    depth_map = depth_maps[i]
    # scale between 0 and 1
    depth_map = (depth_map - depth_map.min()) / (depth_map.max() - depth_map.min())
    pred = preds[i]
    # scale between 0 and 1
    pred = (pred - pred.min()) / (pred.max() - pred.min())
    axes[2, i].imshow(depth_map)
    axes[1, i].imshow(pred)

for ax in axes.ravel():
    ax.axis("off")

# save figure
plt.savefig(os.path.join(experiment_dir, "predictions.png"))

plt.show()
