In [2]:
%load_ext autoreload
%autoreload 2

import torch
import sys
import glob
import matplotlib.pyplot as plt
import numpy as np

# # Importing the project from a subfolder
sys.path.append('./project')

from models.transformer import Aggregator
from models.tokenizer import Tokenizer
from models.vggt import VGGT, unflatten_tokens, reflatten_tokens
from heads.camera_head import CameraHead

from dataloader.preprocessing import *
from dataloader.dataset import *

# Test the model

In [3]:
# Load the dataset as previously demonstrated, also get device
path = "/sdf/home/y/youngsam/data/dune/larnet/h5/DataAccessExamples/tutorial_example_v1.h5"

dataset = dataset_from_file(path)

device = "cuda" if torch.cuda.is_available() else "cpu"

In [4]:
# Grab a sample
sample, rotations = dataset.choose_events(10, 3)
patch_counts, all_coords, all_patches, all_depths = stack_patches(sample)
patch_counts = torch.Tensor(patch_counts).int().to(device)
all_coords = torch.Tensor(all_coords).int().to(device)
all_patches = torch.Tensor(all_patches).to(device)
all_depths = torch.Tensor(all_depths).to(device)

In [45]:
# Initialize a model
test_model = VGGT().to(device)

In [46]:
# Run the model
predictions, test_output, patch_start_idx = test_model(patch_counts, all_coords, all_patches)

In [47]:
len(test_output), test_output[-1].shape
# 24 blocks, results from every block; final result is NxSx(P+5)x(2*D)
# P+5 because 1 camera token and 4 register tokens added
# D*2 because ???

(24, torch.Size([10, 3, 62, 512]))

In [48]:
predictions["depth"][0]

tensor([[-0.0003, -0.0321, -0.0547,  0.0458,  0.0070, -0.0137,  0.0181,  0.0267,
          0.0393, -0.0516,  0.0026, -0.0366,  0.0019, -0.0055,  0.0129, -0.0038],
        [-0.0007, -0.0560, -0.0212,  0.0273, -0.0168,  0.0453,  0.0542,  0.0464,
         -0.0120, -0.0246, -0.0320, -0.0202,  0.0071, -0.0035, -0.0135,  0.0143],
        [ 0.0280,  0.0683,  0.0230,  0.0507, -0.0201,  0.0106,  0.0294, -0.0430,
         -0.0429, -0.0530, -0.0362, -0.0334, -0.0125, -0.0489, -0.0415, -0.0246],
        [-0.0473, -0.0183, -0.0106, -0.0333, -0.0382, -0.0202,  0.0741, -0.0279,
          0.0365, -0.0687, -0.0187, -0.0627,  0.0025, -0.0485,  0.0028, -0.0252],
        [ 0.0366,  0.0365,  0.0530,  0.0574,  0.0017, -0.0026, -0.0254,  0.0128,
          0.0401,  0.0755,  0.0427, -0.0183, -0.0595,  0.0010, -0.0258, -0.0370],
        [ 0.0008,  0.0015, -0.0351, -0.0371, -0.0490,  0.0505, -0.0676,  0.0587,
          0.0179,  0.0049, -0.0843,  0.0194, -0.0533, -0.0466,  0.0026,  0.0480],
        [-0.0608,  0.0