1. Environment Check

In [None]:

!nvidia-smi
import torch

# Choose the compute device once, up front; later cells move the model and
# every batch onto DEVICE.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
print("Using device:", DEVICE)

2. Clone repo

In [None]:
!git clone https://github.com/cocooda/JEPAPrimitiveLayer.git
%cd JEPAPrimitiveLayer

3. Imports & Config

In [None]:
# Standard-library and third-party imports used by this and later cells.
# `os` is needed right below for the checkpoint paths, `DataLoader` by the
# dataset cell and `plt` by the plotting cells; none of them were imported
# anywhere, so a fresh "Restart & Run All" stopped with a NameError.
import os
import sys

import matplotlib.pyplot as plt
from torch.utils.data import DataLoader

# Add the cloned repo to the Python path so the `utils` / `models` packages
# resolve. Guarded so that re-running the cell does not append duplicates.
REPO_ROOT = "/kaggle/working/JEPAPrimitiveLayer"
if REPO_ROOT not in sys.path:
    sys.path.append(REPO_ROOT)

# Project-local imports (require the repo on sys.path, see above).
from utils.dataset import DrivingSceneDataset
from utils.patch_utils import unpatchify
from models.primitive_layer import PrimitiveLayer

# --- Hyperparameters -------------------------------------------------------
# In the cells visible in this notebook only PATCH_SIZE, BATCH_SIZE, NUM_STEPS
# and LR are read; the other values are not passed to PrimitiveLayer here.
# NOTE(review): confirm whether the remaining constants are still needed or
# whether the model reads its own configuration.
EMBED_DIM = 128
PATCH_SIZE = 4
IMAGE_H = 32
IMAGE_W = 32
# 3 values per pixel of a PATCH_SIZE x PATCH_SIZE patch
# (presumably 3 colour channels -- TODO confirm).
TOKEN_DIM = 3 * PATCH_SIZE * PATCH_SIZE
ACTION_DIM = 4
MASK_RATIO = 0.15
VICREG_WEIGHT = 0.1
DRIFT_WEIGHT = 0.05
JEPA_WEIGHT = 1.0
EMA_DECAY = 0.99
BATCH_SIZE = 8
NUM_STEPS = 50  # upper bound on optimisation steps in the training cell
LR = 1e-3

# --- Paths -----------------------------------------------------------------
DATA_ROOT = "/kaggle/input/test1t/exported_maps"
CKPT_DIR = "/kaggle/working/checkpoints"
os.makedirs(CKPT_DIR, exist_ok=True)
CKPT_PATH = os.path.join(CKPT_DIR, "primitive_layer.pth")


4. Load dataset

In [None]:
# Build the dataset from DATA_ROOT and wrap it in a shuffling loader.
# drop_last=True discards a trailing batch smaller than BATCH_SIZE, so every
# batch yielded by `loader` has exactly BATCH_SIZE samples.
dataset = DrivingSceneDataset(DATA_ROOT)
loader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    drop_last=True,
)
print(f"Loaded {len(dataset)} samples from {DATA_ROOT}")

5. Initialize Model


In [None]:
# Instantiate the model, move it to the selected device, and attach an Adam
# optimizer over all of its parameters.
model = PrimitiveLayer(patch_size=PATCH_SIZE)
model = model.to(DEVICE)
optimizer = torch.optim.Adam(params=model.parameters(), lr=LR)

6. Load checkpoint if one exists

In [None]:
# Resume from a previously saved checkpoint when one is present.
# The location comes from CKPT_DIR / CKPT_PATH in the config cell instead of
# repeating the path literals; `ckpt_path` is kept because the save cell
# writes to it.
os.makedirs(CKPT_DIR, exist_ok=True)
ckpt_path = CKPT_PATH

if os.path.exists(ckpt_path):
    # map_location remaps the stored tensors onto the current device, so a
    # checkpoint written on a GPU session still loads on a CPU-only one.
    # torch.load unpickles the file: only load checkpoints you trust.
    state_dict = torch.load(ckpt_path, map_location=DEVICE)
    model.load_state_dict(state_dict)
    # The model is deliberately NOT switched to eval mode here: the next cell
    # trains it, and the inference cell calls model.eval() itself.
    print(f"Checkpoint loaded from {ckpt_path}")
else:
    print("No checkpoint found. Training from scratch.")

7. Training loop

In [None]:
# Train for at most NUM_STEPS optimisation steps. The loop makes a single pass
# over `loader`, so it ends after NUM_STEPS batches or when the loader is
# exhausted, whichever comes first.

# Make sure the model is in training mode: the checkpoint cell (after a
# successful load) and the inference cell (when this cell is re-run) both
# leave it in eval mode.
model.train()

losses = []  # per-step total loss as Python floats, plotted in a later cell
for step, (frames, kin) in enumerate(loader):
    if step >= NUM_STEPS:
        break
    frames, kin = frames.to(DEVICE), kin.to(DEVICE)

    optimizer.zero_grad()
    # The model returns (predictions, total loss, dict of loss terms); the
    # predictions are not needed while training.
    _, total_loss, loss_dict = model(frames, kin)
    total_loss.backward()
    optimizer.step()

    # Convert the loss tensor to a float once and reuse it for the history
    # and the log line.
    loss_value = total_loss.item()
    losses.append(loss_value)
    print(f"Step {step+1}/{NUM_STEPS} | Loss = {loss_value:.6f} | JEPA={loss_dict['JEPA'].item():.6f} VICReg={loss_dict['VICReg'].item():.6f} Drift={loss_dict['Drift'].item():.6f}")


8. Save checkpoint

In [None]:
# Persist only the model's state_dict (not the optimizer) to ckpt_path.
weights = model.state_dict()
torch.save(weights, ckpt_path)
print(f"Checkpoint saved at {ckpt_path}")

9. Plot Training Loss

In [None]:
# Plot the per-step total loss collected in `losses` by the training cell,
# using an explicit figure/axes pair.
fig, ax = plt.subplots()
ax.plot(losses)
ax.set_title("Training Loss on Tokenized BEV Dataset")
ax.set_xlabel("Step")
ax.set_ylabel("Loss")
plt.show()

10. Inference & Visualization

In [None]:
# Run the model on one batch in eval mode and show up to four predictions.
model.eval()
sample_frames, sample_kin = next(iter(loader))
sample_frames, sample_kin = sample_frames.to(DEVICE), sample_kin.to(DEVICE)

# No gradients are needed for visualisation.
with torch.no_grad():
    pred_tokens, _, _ = model(sample_frames, sample_kin)  # rank-3: (B, N, D)

# Unpatchify tokens back to image tensors.
B, N, D = pred_tokens.shape
# Assumes the N tokens form a square ph x pw grid -- TODO confirm against
# utils.patch_utils.unpatchify.
ph = pw = int(N ** 0.5)
# Presumably returns (B, C, H, W); the code below treats axis 1 as channels.
pred_imgs = unpatchify(pred_tokens.cpu(), ph, pw, patch_size=PATCH_SIZE)

# --- Convert to 3-channel RGB for plotting ---
num_channels = pred_imgs.shape[1]
if num_channels >= 3:
    pred_imgs_rgb = pred_imgs[:, :3, :, :]  # take first 3 channels
else:
    # Fewer than 3 channels: collapse to one grey channel and replicate it.
    # The previous `repeat(1, 3 // C, 1, 1)` left a 2-channel tensor at 2
    # channels (3 // 2 == 1), which imshow cannot draw. For a single channel
    # the mean is the channel itself, so that case is unchanged.
    pred_imgs_rgb = pred_imgs.mean(dim=1, keepdim=True).repeat(1, 3, 1, 1)

# --- Plot ---
plt.figure(figsize=(12,4))
for i in range(min(4, B)):
    plt.subplot(1,4,i+1)
    plt.imshow(pred_imgs_rgb[i].permute(1,2,0))  # channels-last: (H, W, 3)
    plt.axis("off")
plt.show()
