What defines a test? At what point does a "training test" stop being a test and simply become training?

In [1]:
import torch
import sys
import glob
import matplotlib.pyplot as plt
import numpy as np

# Make the project subfolder importable so that the imports below resolve.
# './project' is a relative path, so it is looked up from the kernel's working directory.
sys.path.append('./project')

from models.transformer import Aggregator
from models.tokenizer import Tokenizer
from models.vggt import VGGT, unflatten_tokens
from heads.camera_head import CameraHead

# NOTE(review): wildcard imports — presumably where `Dataset` and `stack_patches`
# (used in later cells) come from; confirm and consider importing them by name.
from dataloader.projection import *
from dataloader.dataset import *

In [2]:
# Build the dataset from the tutorial example file.
# NOTE(review): this is an absolute, user-specific path; adjust it for your environment.
path = "/sdf/home/y/youngsam/data/dune/larnet/h5/DataAccessExamples/tutorial_example_v1.h5"
dataset = Dataset(path)

# Prefer the GPU when PyTorch can see one; otherwise stay on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [3]:
# Build the network directly on the selected device, then its optimizer and loss.
model = VGGT().to(device)

optimizer = torch.optim.Adam(model.parameters())  # Adam with PyTorch's default hyperparameters
loss_fn = torch.nn.MSELoss()

# Global step counter; the training cell prints and increments it on every iteration.
i = 0

In [10]:
# Run 100 optimisation steps, each on a freshly drawn sample, and print the loss per step.
# The global counter `i` keeps counting across re-runs of this cell.
# NOTE(review): MSE on raw quaternions treats q and -q as different targets even though
# they encode the same rotation — confirm the targets use a consistent sign convention.
for _step in range(100):
    # Take a sample; the middle return value is not needed here.
    # NOTE(review): the meaning of the arguments (10, 3) is defined by
    # Dataset.choose_events — presumably event/view counts; confirm.
    sample, _, rotations = dataset.choose_events(10, 3)
    patch_counts, all_coords, all_patches = stack_patches(sample)

    # Convert straight to the final dtype on the target device. The legacy
    # torch.Tensor(...) constructor always builds a float32 tensor first, which
    # rounds integers above 2**24 and costs an extra copy before the cast.
    patch_counts = torch.as_tensor(patch_counts, dtype=torch.int32, device=device)
    all_coords = torch.as_tensor(all_coords, dtype=torch.int32, device=device)
    all_patches = torch.as_tensor(all_patches, dtype=torch.float32, device=device)

    # Zero gradients left over from the previous step
    optimizer.zero_grad()

    # Make predictions; test_output and patch_start_idx are kept as notebook
    # globals so they can still be inspected after the loop.
    predictions, test_output, patch_start_idx = model(patch_counts, all_coords, all_patches)
    pred_quaternions = predictions["pose_enc"]

    # Compute the ground truth: one quaternion per rotation, nested like `rotations`
    quaternions = torch.as_tensor(
        np.array([[r.as_quat() for r in row] for row in rotations]),
        dtype=torch.float32,
        device=device,
    )

    # Run backprop and update the weights
    loss = loss_fn(pred_quaternions, quaternions)
    loss.backward()
    optimizer.step()

    # Print loss
    print(f"i={i}, loss={loss.item()}")
    i += 1

i=0, loss=1.136203646659851
i=1, loss=6.668038845062256
i=2, loss=14.131070137023926
i=3, loss=1.113821029663086
i=4, loss=0.4904232323169708
i=5, loss=0.2852943539619446
i=6, loss=0.5181164145469666
i=7, loss=0.27671828866004944
i=8, loss=0.3216931223869324
i=9, loss=0.4111461043357849
i=10, loss=0.3387817442417145
i=11, loss=0.3271327614784241
i=12, loss=0.2937907576560974
i=13, loss=0.35033178329467773
i=14, loss=0.2920547425746918
i=15, loss=0.27357518672943115
i=16, loss=0.3281884789466858
i=17, loss=0.35970374941825867
i=18, loss=0.2719650864601135
i=19, loss=0.2856592833995819
i=20, loss=0.30406367778778076
i=21, loss=0.34112176299095154
i=22, loss=0.2863227128982544
i=23, loss=0.2695299983024597
i=24, loss=0.25360867381095886
i=25, loss=0.2793818414211273
i=26, loss=0.25466814637184143
i=27, loss=0.2809293270111084
i=28, loss=0.28252068161964417
i=29, loss=0.25788435339927673
i=30, loss=0.2546480596065521
i=31, loss=0.25088146328926086
i=32, loss=0.28435787558555603
i=33, loss=