In [None]:
import numpy as np
from tqdm.notebook import tqdm
import torch
from dpt.models import DPTDepthModel
from finetune.datasets import Nutrition5k
from finetune.utils import calculate_metrics, visualize_image

# Run on the first CUDA GPU when one is visible, otherwise fall back to CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# Evaluation inputs: location of the Nutrition5k data and the checkpoint to test.
data_path = "data/nutrition5k/"
weight_path = 'weights/aug_dpt.pt'

# Test split only. batch_size=1 with shuffle=False keeps the k-th loader
# batch aligned with dataset[k].
dataset = Nutrition5k(split='test', dataset_path=data_path)
loader = torch.utils.data.DataLoader(dataset, shuffle=False, batch_size=1)

In [None]:
@torch.no_grad()
def evaluate_dpt(model_path, loader, device, scale=0.0000305, shift=0.1378):
    """Run a DPT depth model over ``loader`` and collect predictions and metrics.

    Args:
        model_path: Checkpoint path, forwarded as ``path`` to ``DPTDepthModel``.
        loader: Iterable yielding ``(rgb_img, depth, mask)`` tensor batches.
        device: ``torch.device`` the model and the RGB inputs are moved to.
        scale: Forwarded unchanged to ``DPTDepthModel``.
        shift: Forwarded unchanged to ``DPTDepthModel``.
            NOTE(review): together with ``invert=True`` these presumably map
            the network output to metric depth — confirm in ``dpt.models``.

    Returns:
        ``(batch_pred, batch_metrics, overall_metrics)`` where ``batch_pred``
        is a list of per-batch prediction arrays, ``batch_metrics`` is a list
        of per-batch ``calculate_metrics`` results, and ``overall_metrics`` is
        ``calculate_metrics`` over all batches concatenated along axis 0.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    model = DPTDepthModel(
        path=model_path,
        scale=scale,
        shift=shift,
        invert=True,
        backbone="vitb_rn50_384",
        non_negative=True,
        enable_attention_hooks=False,
    ).to(device)
    # nn.Module instances start in training mode; switch to inference mode so
    # dropout/normalization layers behave deterministically during evaluation.
    model.eval()

    batch_pred = []
    batch_depth = []
    batch_mask = []
    batch_metrics = []
    for rgb_img, depth, mask in tqdm(loader):
        rgb_img = rgb_img.to(device)
        # No .detach() needed: gradients are disabled by @torch.no_grad().
        pred = model(rgb_img).cpu().numpy()
        depth = depth.cpu().numpy()
        mask = mask.cpu().numpy()
        batch_metrics.append(calculate_metrics(pred, depth, mask))
        batch_pred.append(pred)
        batch_depth.append(depth)
        batch_mask.append(mask)

    if not batch_pred:
        # np.concatenate would otherwise fail with an unhelpful message.
        raise ValueError("loader yielded no batches; nothing to evaluate")

    # Aggregate metrics over the whole split, not an average of per-batch values.
    overall_metrics = calculate_metrics(
        np.concatenate(batch_pred),
        np.concatenate(batch_depth),
        np.concatenate(batch_mask),
    )
    return batch_pred, batch_metrics, overall_metrics

In [None]:
# Evaluate the checkpoint on the full test split and display the aggregate
# metrics (the last expression is rendered as the cell output).
batch_pred, batch_metrics, overall_metrics = evaluate_dpt(weight_path, loader, device)
overall_metrics

In [None]:
# Inspect the test samples with the lowest per-sample accuracy.
N_WORST = 10  # how many lowest-accuracy samples to visualise

# The loader uses batch_size=1 and shuffle=False, so batch i corresponds to dataset[i].
batch_acc = [metrics['accuracy'][0] for metrics in batch_metrics]
examined_indices = np.argsort(batch_acc)[:N_WORST]

for i in examined_indices:
    # Fetch the sample once; the original indexed dataset[i] twice per iteration.
    sample = dataset[i]
    # CHW -> HWC, then undo the input normalisation for display.
    # NOTE(review): `* 0.5 + 0.5` assumes inputs were normalised with
    # mean=std=0.5 — confirm against Nutrition5k's transforms.
    rgb_image = np.clip(sample[0].transpose(1, 2, 0) * 0.5 + 0.5, 0, 1)
    visualize_image(rgb_image, batch_pred[i].squeeze(), sample[1], figsize=(20, 15), fontsize=40, norm_value=(1, 4), cmap="jet")