1. Environment Check

In [None]:
!nvidia-smi
import torch
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print('Using device:', device)


2. Clone Repository

In [None]:
!git clone https://github.com/cocooda/JEPAPrimitiveLayer.git
%cd JEPAPrimitiveLayer
import sys
sys.path.append("/kaggle/working/JEPAPrimitiveLayer")

3. Imports & Config

In [None]:
import os
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader

# ------------------ Import config from repo ------------------
from config import EMBED_DIM, PATCH_SIZE, IMAGE_H, IMAGE_W, TOKEN_DIM, ACTION_DIM, MASK_RATIO, VICREG_WEIGHT, DRIFT_WEIGHT, JEPA_WEIGHT, EMA_DECAY, BATCH_SIZE, NUM_STEPS, LR, DEVICE, DATA_ROOT

# Your modules
from utils.dataset import DrivingSceneDataset
from utils.patch_utils import unpatchify
from models.primitive_layer import PrimitiveLayer


4. Load Dataset

In [None]:
# Kaggle input mount for the exported BEV maps. This deliberately overrides the
# DATA_ROOT value imported from config.
DATA_ROOT = "/kaggle/input/test1t/exported_maps"
if not os.path.exists(DATA_ROOT):
    raise FileNotFoundError(f"Dataset path not found: {DATA_ROOT} (is the Kaggle dataset attached?)")

dataset = DrivingSceneDataset(DATA_ROOT)

# drop_last=True discards the final partial batch. A dataset smaller than BATCH_SIZE
# would therefore yield zero batches: training would silently do nothing and the
# inference cell's next(iter(loader)) would raise StopIteration. Fail early instead.
if len(dataset) < BATCH_SIZE:
    raise ValueError(
        f"Dataset at {DATA_ROOT} has {len(dataset)} samples, fewer than BATCH_SIZE={BATCH_SIZE}; "
        "with drop_last=True the DataLoader would produce no batches."
    )

loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)
print(f"Loaded {len(dataset)} samples from {DATA_ROOT}")

5. Initialize Model


In [None]:
# Seed PyTorch's global RNGs (CPU and CUDA) before building the model. Weight initialisation
# and later random draws that use the global generator (e.g. the DataLoader's shuffle
# order) then repeat across "Restart & Run All". GPU kernels may still be
# non-deterministic, so this improves rather than guarantees reproducibility.
SEED = 42
torch.manual_seed(SEED)

model = PrimitiveLayer(patch_size=PATCH_SIZE).to(DEVICE)
optimizer = torch.optim.Adam(model.parameters(), lr=LR)

6. Load Checkpoint (If One Exists)

In [None]:
CKPT_DIR = "/kaggle/working/checkpoints"
os.makedirs(CKPT_DIR, exist_ok=True)
ckpt_path = os.path.join(CKPT_DIR, "primitive_layer.pth")

if os.path.exists(ckpt_path):
    # map_location lets a checkpoint written in a GPU session be restored in a CPU-only
    # session (and vice versa) instead of failing during deserialisation.
    state_dict = torch.load(ckpt_path, map_location=DEVICE)
    model.load_state_dict(state_dict)
    # Deliberately NOT calling model.eval() here: the next cell resumes training from these
    # weights, and eval mode would disable training-time behaviour (e.g. dropout or
    # batch-norm statistic updates, if the model uses them).
    # Only model weights are restored; the Adam optimizer state starts fresh because
    # the save cell stores model.state_dict() only.
    print(f"Checkpoint loaded from {ckpt_path}")
else:
    print("No checkpoint found. Training from scratch.")

7. Training Loop

In [None]:
def _endless_batches(data_loader):
    """Yield batches indefinitely, starting a fresh pass (and so a fresh shuffle)
    over ``data_loader`` whenever the previous pass is exhausted.

    Raises:
        ValueError: if a full pass yields no batches, which would otherwise make
            this generator loop forever.
    """
    while True:
        produced_any = False
        for batch in data_loader:
            produced_any = True
            yield batch
        if not produced_any:
            raise ValueError(
                "DataLoader produced no batches; is the dataset smaller than BATCH_SIZE with drop_last=True?"
            )


# Ensure training-mode behaviour. A previously executed cell (checkpoint load or
# inference) may have left the model in eval mode.
model.train()

losses = []
# Run exactly NUM_STEPS optimisation steps, spanning multiple epochs if needed. Iterating
# the loader only once would silently stop after a single epoch when NUM_STEPS > len(loader).
batches = _endless_batches(loader)
for step in range(NUM_STEPS):
    frames, kin = next(batches)
    frames, kin = frames.to(DEVICE), kin.to(DEVICE)

    optimizer.zero_grad()
    _, total_loss, loss_dict = model(frames, kin)
    total_loss.backward()
    optimizer.step()

    loss_value = total_loss.item()  # a single device->host sync for the total loss per step
    losses.append(loss_value)
    print(f"Step {step+1}/{NUM_STEPS} | Loss = {loss_value:.6f} | JEPA={loss_dict['JEPA'].item():.6f} VICReg={loss_dict['VICReg'].item():.6f} Drift={loss_dict['Drift'].item():.6f}")


8. Save Checkpoint

In [None]:
# Write to a temporary file first, then atomically swap it into place. An interrupted
# save (kernel restart, disk quota) then cannot leave a truncated checkpoint behind
# for the load cell to choke on in the next session.
tmp_ckpt_path = ckpt_path + ".tmp"
torch.save(model.state_dict(), tmp_ckpt_path)
os.replace(tmp_ckpt_path, ckpt_path)
print(f"Checkpoint saved at {ckpt_path}")

9. Plot Training Loss

In [None]:
# Total loss recorded at each optimisation step by the training loop.
fig, ax = plt.subplots()
ax.plot(losses)
ax.set_title("Training Loss on Tokenized BEV Dataset")
ax.set_xlabel("Step")
ax.set_ylabel("Loss")
plt.show()

10. Inference & Visualization

In [None]:
model.eval()
sample_frames, sample_kin = next(iter(loader))
sample_frames, sample_kin = sample_frames.to(DEVICE), sample_kin.to(DEVICE)

with torch.no_grad():
    pred_tokens, _, _ = model(sample_frames, sample_kin)

# Recover the patch grid (rows x cols) that unpatchify needs. Prefer the configured image
# size so non-square BEV maps work; fall back to a square grid if the token count does
# not match it.
# NOTE(review): assumes tokens are laid out row-major over an
# (IMAGE_H // PATCH_SIZE) x (IMAGE_W // PATCH_SIZE) grid -- confirm against utils.patch_utils.
B, N, D = pred_tokens.shape
ph, pw = IMAGE_H // PATCH_SIZE, IMAGE_W // PATCH_SIZE
if ph * pw != N:
    ph = pw = int(round(N ** 0.5))
if ph * pw != N:
    raise ValueError(
        f"Cannot arrange {N} tokens into a patch grid "
        f"(config grid is {IMAGE_H // PATCH_SIZE}x{IMAGE_W // PATCH_SIZE}, and {N} is not a perfect square)"
    )
pred_imgs = unpatchify(pred_tokens.cpu(), ph, pw, patch_size=PATCH_SIZE)

n_show = min(4, B)
fig, axes = plt.subplots(1, n_show, figsize=(12, 4), squeeze=False)
for ax, img in zip(axes[0], pred_imgs[:n_show]):
    if img.is_floating_point():
        img = img.float()  # matplotlib cannot render half-precision arrays
    if img.shape[0] == 1:
        # Single-channel map: imshow needs (H, W), not (H, W, 1); the colormap normalises it.
        shown = img[0]
    else:
        # (C, H, W) -> (H, W, C). imshow accepts C in {3, 4} here.
        shown = img.permute(1, 2, 0)
        if shown.is_floating_point():
            # imshow clips float RGB(A) to [0, 1] anyway; clamping explicitly avoids the warning.
            shown = shown.clamp(0, 1)
    ax.imshow(shown.numpy())
    ax.axis("off")
plt.show()