In [7]:
import glob
import os
import sys

import matplotlib.pyplot as plt
import numpy as np
import torch

# Make the project importable from its subfolder
sys.path.append('./project')

from models.transformer import Aggregator
from models.tokenizer import Tokenizer
from models.vggt import VGGT, unflatten_tokens

from dataloader.projection import *
from dataloader.dataset import *

# Test the Aggregator

In [2]:
# Small standalone Aggregator for the smoke test below; embed_dim=64 is the same
# value as the per-token size C used for the random input in the following cell.
test_model = Aggregator(embed_dim=64)

In [15]:
B, S, P, C = 5, 3, 20, 64 # B events, S images per event, P tokens per image, C elements per token

device = "cuda" if torch.cuda.is_available() else "cpu"

# Seed the generator so the smoke test sees the same inputs after a kernel restart.
torch.manual_seed(0)

# Draw the random inputs with torch directly on the target device:
# float32 tokens and two int64 coordinates per token in [1, 8).
# This avoids the NumPy float64 -> tensor -> float32 -> device round-trip.
test_input = torch.randn(B, S, P, C, device=device)
test_pos = torch.randint(1, 8, size=(B, S, P, 2), device=device)
test_model = test_model.to(device)

In [20]:
# Call the module itself instead of .forward(): for an nn.Module, __call__ also runs
# any registered hooks, and it matches how the combined model is invoked further down.
# The call returns two values; the second one is kept as test_idx.
test_output, test_idx = test_model(test_input, test_pos)

In [23]:
# Shape of the first entry of the output; the recorded result below is
# (B, S, P + 5, 2 * C) for B, S, P, C = 5, 3, 20, 64.
test_output[0].shape

torch.Size([5, 3, 25, 128])

# Test the tokenizer

In [6]:
# Rebind test_model to a Tokenizer built with its default arguments for the next test.
test_model = Tokenizer()

In [7]:
B, H = 5, 16 # B total patches, each image a HxH square

device = "cuda" if torch.cuda.is_available() else "cpu"

# Seed the generator so the smoke test sees the same inputs after a kernel restart.
torch.manual_seed(0)

# Random float32 patches created directly on the target device, already in the
# (B, 1, H, H) layout with an explicit single-channel axis.
test_input = torch.randn(B, 1, H, H, device=device)
test_model = test_model.to(device)

In [8]:
# Call the module itself instead of .forward() so that nn.Module.__call__ runs
# any registered hooks.
test_output = test_model(test_input)

In [10]:
# The recorded result below is (B, 256, 1, 1): 256 values per input patch,
# with two trailing axes of size 1.
test_output.shape

torch.Size([5, 256, 1, 1])

In [11]:
# Collapse everything after the batch axis into one token vector per patch.
# flatten(start_dim=1) keeps whatever batch size the tensor has instead of
# hard-coding 5, so the cell stays correct if B is changed above.
test_output.flatten(start_dim=1).shape

torch.Size([5, 256])

# Test both

In [2]:
# Build both stages in one statement; the Tokenizer is still constructed first.
# embed_dim=256 is the same number as the 256 values per patch seen in the
# tokenizer test output above.
test_tokenizer, test_aggregator = Tokenizer(), Aggregator(embed_dim=256)

In [9]:
B, S, P, H = 1, 3, 5, 16 # B events, S images per event, P patches per image, HxH patches

device = "cuda" if torch.cuda.is_available() else "cpu"

# Seed the generator so the smoke test sees the same inputs after a kernel restart.
torch.manual_seed(0)

# Random float32 patches and two int64 coordinates per patch in [1, H], both
# created directly on the target device (no NumPy round-trip or dtype cast).
test_input = torch.randn(B, S, P, H, H, device=device)
test_pos = torch.randint(1, H + 1, size=(B, S, P, 2), device=device)
test_tokenizer, test_aggregator = test_tokenizer.to(device), test_aggregator.to(device)

In [14]:
# Simple flattening for this test. In the real case this would involve recording the sequence lengths
# Every patch is passed as one single-channel image, (B*S*P, 1, H, H) -- the same layout
# the standalone tokenizer test uses -- and the 256 values per patch are regrouped
# into (B, S, P, 256) tokens for the aggregator.
# Both modules are called directly rather than via .forward() so nn.Module hooks run.
test_tokens = test_tokenizer(test_input.view(B*S*P, 1, H, H)).view(B, S, P, 256)
test_output, test_idx = test_aggregator(test_tokens, test_pos)

In [16]:
# Shape of the first entry of the output; the recorded result below is
# (B, S, P + 5, 2 * 256) for B, S, P = 1, 3, 5.
test_output[0].shape

torch.Size([1, 3, 10, 512])

# Test the combined model

In [60]:
# Load the dataset as previously demonstrated, also get device
# The file location can be overridden through the LARNET_H5_PATH environment variable so
# the notebook also runs outside the original author's account; when the variable is
# unset the original hard-coded location is used.
path = os.environ.get(
    "LARNET_H5_PATH",
    "/sdf/home/y/youngsam/data/dune/larnet/h5/DataAccessExamples/tutorial_example_v1.h5",
)

dataset = Dataset(path)

device = "cuda" if torch.cuda.is_available() else "cpu"

In [61]:
# Grab a sample
sample = dataset.choose_events(10, 3)
patch_counts, all_coords, all_patches = stack_patches(sample)

# Convert with an explicit dtype on the target device. The legacy torch.Tensor(...)
# constructor always builds a float32 tensor first, so the integer counts and
# coordinates took a detour through float32 before being cast back to int32.
patch_counts = torch.as_tensor(patch_counts, dtype=torch.int32, device=device)
all_coords = torch.as_tensor(all_coords, dtype=torch.int32, device=device)
all_patches = torch.as_tensor(all_patches, dtype=torch.float32, device=device)

In [62]:
# Instantiate the combined model with its default arguments and place it on `device`
# in a single expression.
test_model = VGGT().to(device)

In [63]:
# One forward pass of the combined model on the stacked sample. The call returns
# two values: the outputs (indexed per block in the next cell) and patch_start_idx.
test_output, patch_start_idx = test_model(patch_counts, all_coords, all_patches)

In [64]:
len(test_output), test_output[-1].shape
# 24 blocks, one result per block; the final result is N x S x (P+5) x (2*D)
# P+5 because 1 camera token and 4 register tokens are added
# TODO: confirm why the last dimension is 2*D -- presumably two D-sized feature sets are
# concatenated inside the aggregator; verify against the Aggregator / VGGT code

(24, torch.Size([10, 3, 71, 512]))

In [65]:
# Try some backpropagation
# Sum the last block's output into a scalar so backward() can be called on it,
# then time a single backward pass with %time.
loss = test_output[-1].sum()
%time loss.backward()

CPU times: user 347 ms, sys: 137 ms, total: 484 ms
Wall time: 615 ms


In [44]:
# A bit of performance information
%timeit dataset.choose_events(10, 3)
%timeit stack_patches(sample)
%timeit test_model.forward(patch_counts, all_coords, all_patches)

139 ms ± 16 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
250 μs ± 4.6 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
151 ms ± 4.74 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [45]:
# Does it scale with batch size?
%timeit dataset.choose_events(50, 3)
%timeit stack_patches(sample)
%timeit test_model.forward(patch_counts, all_coords, all_patches)

713 ms ± 94 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
252 μs ± 13.2 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
160 ms ± 9.64 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
