# LeJEPA Edge Captioner - Training Notebook

Train the VL-JEPA-style embedding-prediction model on COCO Captions.

**Requirements:**
- Kaggle GPU (T4/P100)
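- COCO dataset (loaded from the Hugging Face Hub as `HuggingFaceM4/COCO`)
- Gemma-3 access (accept the model license on Hugging Face)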

## 1. Setup

In [None]:
# Install the third-party packages this notebook pulls in via pip
# (-q keeps pip's output short).
!pip install -q timm datasets transformers accelerate

# Clone the project repo. `|| true` keeps the cell from failing when the
# clone command returns an error (for example when the directory already exists).
!git clone https://github.com/omar-A-hassan/lejepa.git || true
# Switch into the repo root: the next cell adds '.' to sys.path and imports
# `lejepa_caption` from there.
%cd lejepa

In [None]:
import torch
import sys
sys.path.insert(0, '.')

from lejepa_caption.models import LeJEPACaptioner, get_captioner

# Select the compute device: CUDA when PyTorch reports one, otherwise the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(f'Device: {device}')

# Only query the GPU name when a CUDA device was actually selected.
if device == 'cuda':
    print(f'GPU: {torch.cuda.get_device_name(0)}')
else:
    print('GPU: N/A')

## 2. Load Model

In [None]:
# Build the captioner from the 'small' preset (the 'tiny' preset trains faster).
model = get_captioner('small')

# Report the parameter count of each component, in millions.
params = model.num_parameters
for label, key in (
    ('Encoder', 'encoder'),
    ('Connector', 'connector'),
    ('Predictor', 'predictor'),
    ('Total', 'total'),
):
    print(f'{label}: {params[key] / 1e6:.1f}M')

## 3. Load Gemma-3 (for target embeddings)

In [None]:
from transformers import AutoTokenizer, AutoModelForCausalLM

# Gemma-3 checkpoint used as the source of target embeddings.
# The model is gated: its license must be accepted on HuggingFace first.
LLM_NAME = 'google/gemma-3-270m-it'

tokenizer = AutoTokenizer.from_pretrained(LLM_NAME)

# Load the weights in bfloat16 and let `device_map='auto'` choose placement.
llm = AutoModelForCausalLM.from_pretrained(
    LLM_NAME, torch_dtype=torch.bfloat16, device_map='auto'
)

# The LLM is never trained here: put it in eval mode and freeze every parameter.
llm.eval()
llm.requires_grad_(False)

# Prints the full shape of the input-embedding weight matrix.
print(f'LLM embedding dim: {llm.get_input_embeddings().weight.shape}')

## 4. Load COCO Dataset

In [None]:
from datasets import load_dataset
from torch.utils.data import DataLoader
from torchvision import transforms

# Load only the first 10,000 training rows so the demo stays small.
coco = load_dataset('HuggingFaceM4/COCO', split='train[:10000]')

# Image preprocessing: resize, centre-crop to 224, convert to a tensor,
# then normalise each channel with the mean/std values below.
transform = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

def collate_fn(batch):
    """Collate dataset rows into a stacked image tensor and a list of captions.

    Each row's ``image`` is converted to RGB and passed through ``transform``;
    the caption is read from ``row['sentences'][0]['raw']``.

    Returns:
        A ``(images, captions)`` tuple: the transformed images stacked with
        ``torch.stack``, and the captions in the same order as ``batch``.
    """
    pixel_tensors = []
    captions = []
    for row in batch:
        pixel_tensors.append(transform(row['image'].convert('RGB')))
        captions.append(row['sentences'][0]['raw'])
    return torch.stack(pixel_tensors), captions

# Shuffled loader over the COCO subset, 16 rows per batch, collated by collate_fn.
train_loader = DataLoader(
    dataset=coco,
    batch_size=16,
    shuffle=True,
    collate_fn=collate_fn,
)
print(f'Batches: {len(train_loader)}')

## 5. Training Loop

In [None]:
import torch.nn.functional as F
from torch.optim import AdamW
from tqdm import tqdm

# Move the captioner to the selected device and optimise all of its parameters.
model = model.to(device)
optimizer = AdamW(model.parameters(), lr=1e-4, weight_decay=0.01)

MAX_LEN = 50  # captions are padded/truncated to this many tokens
EPOCHS = 3

# Look the frozen LLM embedding layer up once instead of on every batch.
# The LLM was loaded with device_map='auto', so its embedding weights are not
# guaranteed to live on `device`; remember where they actually are.
embed_layer = llm.get_input_embeddings()
embed_device = embed_layer.weight.device

for epoch in range(EPOCHS):
    model.train()
    total_loss = 0.0

    for images, captions in tqdm(train_loader, desc=f'Epoch {epoch+1}'):
        images = images.to(device)

        # Forward: predict MAX_LEN embeddings per image.
        pred_embeds = model(images, num_tokens=MAX_LEN)

        # Target embeddings come from the frozen LLM, so no gradients are needed.
        with torch.no_grad():
            tokens = tokenizer(
                captions,
                padding='max_length',
                truncation=True,
                max_length=MAX_LEN,
                return_tensors='pt',
            )
            # Run the lookup on the embedding layer's own device, then bring the
            # float32 result to the device the captioner's predictions are on.
            input_ids = tokens.input_ids.to(embed_device)
            target_embeds = embed_layer(input_ids).float().to(device)

        # MSE over every position of the padded sequence.
        # NOTE(review): the tokenizer's attention mask is not used, so padded
        # positions contribute to the loss too - confirm this is intended.
        loss = F.mse_loss(pred_embeds, target_embeds)

        # Backward pass with gradient-norm clipping at 1.0.
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        total_loss += loss.item()

    # Mean per-batch loss for the epoch.
    avg_loss = total_loss / len(train_loader)
    print(f'Epoch {epoch+1}: Loss = {avg_loss:.4f}')

## 6. Save Model

In [None]:
# Bundle the trained weights with the model's config and write them to disk.
checkpoint = {
    'model_state_dict': model.state_dict(),
    'config': model.config,
}
torch.save(checkpoint, 'lejepa_captioner.pt')

print('Model saved!')

## 7. Test Inference

In [None]:
# Smoke-test inference on a single image.
model.eval()

# Draw a fresh batch from the loader instead of reusing the `images` variable
# left over from the training cell, so this cell does not depend on that loop
# having run in the current session.
sample_batch, _ = next(iter(train_loader))

# Put the sample on whichever device the model's parameters currently live on.
model_device = next(model.parameters()).device
sample_img = sample_batch[:1].to(model_device)

with torch.no_grad():
    pred = model(sample_img)
    print(f'Predicted embedding shape: {pred.shape}')

    # Optional: Decode using LLM
    # output = llm.generate(inputs_embeds=pred, max_new_tokens=20)
    # print(tokenizer.decode(output[0]))