In [None]:
import torch
import torch.nn.functional as F
from datasets import load_dataset
import numpy as np
from tqdm import tqdm as tqdm

# Load the commaai/commavq dataset (num_proc=40 parallelises the preparation)
# and collect the paths of the token files that will be evaluated.
ds = load_dataset(
  "commaai/commavq",
  num_proc=40,
)
files = ds['40']['path']  # split '40' is the val set

BOS_TOKEN = 1024           # token id prepended to every frame
TOKENS_PER_FRAME = 129     # per-frame token count, including the BOS token
BS = 10                    # number of slices per batch
CONTEXT_SIZE_FRAMES = 20   # frames per slice (model context length)
N_FRAMES = 1200            # frames per file
# Frames covered by the input slices: everything except the last slice, which
# is dropped so that the targets (inputs shifted by one token) stay in range.
N = N_FRAMES - CONTEXT_SIZE_FRAMES

# np.split below needs equally sized slices; fail early with a clear message.
assert N_FRAMES % CONTEXT_SIZE_FRAMES == 0, "N_FRAMES must be a multiple of CONTEXT_SIZE_FRAMES"

# Create the data slicing here
# Non-overlapping slices of CONTEXT_SIZE_FRAMES frames each
# (59 slices of 20 frames with the defaults; the last slice is dropped).
# The target is just the slice shifted by 1 token.
token_positions = np.arange(0, N*TOKENS_PER_FRAME)
slice_positions = np.array(np.split(token_positions, N//CONTEXT_SIZE_FRAMES))
# Batch them: a list of (<=BS, CONTEXT_SIZE_FRAMES*TOKENS_PER_FRAME) index arrays;
# the final batch may hold fewer than BS slices.
indices = [slice_positions[i:i+BS] for i in range(0, len(slice_positions), BS)]

In [None]:
# Per-file mean next-token cross-entropy of `model`; the progress bar shows
# the running average over the files processed so far.
total_losses = []

pbar = tqdm(files)
for path in pbar:
  # Lay the file out as N_FRAMES rows of TOKENS_PER_FRAME-1 tokens, prepend a
  # BOS column so each row has TOKENS_PER_FRAME entries, then flatten the
  # whole file into one token stream on the GPU.
  frame_tokens = np.load(path).reshape(N_FRAMES, TOKENS_PER_FRAME-1) # TOKENS_PER_FRAME includes the BOS token
  bos_column = np.ones(len(frame_tokens), dtype=np.int64)*BOS_TOKEN
  tokens = np.c_[bos_column, frame_tokens].reshape(-1)
  tokens = torch.from_numpy(tokens).long().cuda()

  losses, sizes = [], []
  with torch.no_grad(): # potentially add AMP context etc.
    for ii in indices:
      batch_size, seq_len = ii.shape
      flat_positions = ii.ravel()

      # inputs: one row per slice in the batch
      x = tokens[flat_positions].reshape(batch_size, seq_len)

      # your model here!
      # assumes the model returns logits with the vocabulary on the last axis
      pred = model(x)

      # targets: the same positions shifted by one token
      y = tokens[flat_positions+1].reshape(batch_size, seq_len)

      logits = pred.reshape(-1, pred.size(-1))
      batch_mean = F.cross_entropy(logits, y.reshape(-1)).detach().cpu().numpy()

      # weight the batch mean by its slice count so the smaller final batch
      # does not get over-represented in the per-file average
      losses.append(batch_mean * batch_size)
      sizes.append(batch_size)

  total_loss = np.sum(losses)/np.sum(sizes)
  total_losses.append(total_loss)
  pbar.set_description(f"total loss {np.mean(total_losses)}")