# 🚀 Tunix Hackathon - Lightweight Version (PyTorch)

This is a simplified version that:
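- ✅ Works without HF_TOKEN when the chosen model is public (note: Gemma checkpoints on the Hugging Face Hub are typically gated, so you may still need to accept the license and log in)
- ✅ Uses PyTorch instead of JAX (easier setup)
- ✅ Simple GRPO implementation
- ✅ Direct dataset loading from HuggingFace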

**Perfect for quick experiments on Kaggle!**

## Installation

In [None]:
# Quietly (-q) install the notebook's runtime packages. torch, transformers and
# datasets are imported below; bitsandbytes is installed but not referenced by
# any code visible in this notebook.
# NOTE(review): accelerate is presumably needed for device_map="auto" — confirm.
!pip install -q torch transformers datasets accelerate bitsandbytes

## Imports

In [None]:
import random
import re

import torch
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM

## 1. Load Model (Public - No Token Required)

In [None]:
# Instruction-tuned Gemma 2 (2B) checkpoint pulled from the Hugging Face Hub.
# NOTE(review): this was previously described as a public, token-free model, but
# google/gemma-2-2b-it is normally a gated repository (license acceptance plus
# an authenticated login) — confirm access before relying on a token-free run.
MODEL = "google/gemma-2-2b-it"

print(f"Loading {MODEL}...")
tokenizer = AutoTokenizer.from_pretrained(MODEL)
# Weights are loaded in float16 and device placement is delegated to the
# loader via device_map="auto".
# NOTE(review): the training cell later optimizes these float16 parameters
# directly with lr=1e-6; verify that such small updates are not lost to
# half-precision rounding.
model = AutoModelForCausalLM.from_pretrained(
    MODEL,
    torch_dtype=torch.float16,
    device_map="auto"
)
print("âœ… Model loaded successfully!")

## 2. Load GSM8K Dataset

In [None]:
# Fetch the "main" configuration of the openai/gsm8k dataset and keep handles
# to its "train" and "test" splits for the cells below.
print("Downloading GSM8K via HuggingFace...")
ds = load_dataset("openai/gsm8k", "main")
train_data = ds["train"]
test_data = ds["test"]
# Report split sizes as a quick sanity check that the download succeeded.
print(f"âœ… Train samples: {len(train_data)}")
print(f"âœ… Test samples: {len(test_data)}")

## 3. Prepare Data

In [None]:
def encode(example):
    """Attach the training prompt and target to a single dataset record.

    Reads the record's ``"question"`` and ``"answer"`` fields, then stores
    ``"prompt"`` (the question wrapped in the ``Question: ...\\nAnswer:``
    template) and ``"target"`` (the answer, unchanged) on the same record.

    Args:
        example: Mapping with ``"question"`` and ``"answer"`` entries.

    Returns:
        The same ``example`` object, modified in place.
    """
    question, solution = example["question"], example["answer"]

    example["prompt"] = "Question: {}\nAnswer:".format(question)
    example["target"] = solution
    return example

# Run encode() over every training record so each one gains "prompt" and
# "target" fields. Only the train split is mapped here; test_data is left
# as loaded.
train_data = train_data.map(encode)
print("âœ… Data prepared!")

## 4. GRPO Functions

In [None]:
def generate_answer(prompt, max_new_tokens=64):
    """Sample one completion for ``prompt`` and return only the new text.

    Uses the module-level ``tokenizer`` and ``model``. Sampling is enabled
    (``do_sample=True``, ``temperature=0.8``, ``top_p=0.9``), so repeated calls
    with the same prompt can return different completions.

    Args:
        prompt: Text to condition the generation on.
        max_new_tokens: Upper bound on the number of generated tokens.

    Returns:
        The decoded continuation with special tokens removed and surrounding
        whitespace stripped. The prompt itself is never part of the result.
    """
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    output = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=0.8,
        top_p=0.9,
    )

    # The returned sequence starts with the prompt tokens. Drop them by token
    # count instead of str.replace(): decoding the prompt tokens is not
    # guaranteed to reproduce the prompt string exactly, and replace() would
    # also delete any repetition of the prompt inside the completion.
    prompt_length = inputs["input_ids"].shape[1]
    completion_ids = output[0][prompt_length:]
    return tokenizer.decode(completion_ids, skip_special_tokens=True).strip()

# Integers or decimals with optional thousands separators. The lookbehind stops
# a match (or its leading minus sign) from starting inside a longer digit run,
# so "10-18" yields 10 and 18 rather than 10 and -18.
_NUMBER_RE = re.compile(r"(?<!\d)-?\d[\d,]*(?:\.\d+)?")


def _to_number(text):
    """Parse ``text`` as a float, ignoring thousands separators; None if not numeric."""
    try:
        return float(text.replace(",", ""))
    except ValueError:
        return None


def reward_fn(pred, gold):
    """Score a prediction against a GSM8K-style reference answer.

    The reference's final answer is whatever follows the last ``####`` marker
    in ``gold``. The prediction earns the reward when any complete number it
    contains equals that answer numerically, so ``18`` matches ``18`` and
    ``18.0`` but not ``118``, and ``1,000`` matches ``1000``.

    Args:
        pred: Model output text.
        gold: Reference solution text ending in ``#### <answer>``.

    Returns:
        ``1.0`` when the reference answer appears in ``pred``, else ``0.0``.
        An empty reference answer always scores ``0.0``.
    """
    gold_text = gold.split("####")[-1].strip()
    if not gold_text:
        # An empty string is a substring of everything; never reward it.
        return 0.0

    gold_value = _to_number(gold_text)
    if gold_value is None:
        # Non-numeric reference: fall back to literal containment.
        return 1.0 if gold_text in pred else 0.0

    candidates = (_to_number(match) for match in _NUMBER_RE.findall(pred))
    return 1.0 if any(value == gold_value for value in candidates) else 0.0

## 5. GRPO Training (Simplified)

In [None]:
# Simplified policy-gradient training loop.
#
# Each step samples GROUP random training examples, generates one completion
# per example, scores it with reward_fn, and uses (reward - group mean) as the
# advantage that weights the completion's negative log-likelihood.
#
# NOTE(review): the baseline here is the mean reward over *different* prompts;
# GRPO as usually described samples several completions of the *same* prompt.
# Confirm whether this simplification is intended.
# NOTE(review): the optimizer updates whatever dtype the model was loaded in;
# with half-precision weights an lr of 1e-6 may round to no-op updates — verify.
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-6)
EPOCHS = 1
GROUP = 2   # Number of samples per step
STEPS = 200  # Number of training steps

print("Starting GRPO training...")
print(f"Epochs: {EPOCHS}, Steps: {STEPS}, Group size: {GROUP}")

for epoch in range(EPOCHS):
    print(f"\n=== Epoch {epoch+1} ===")

    for idx in range(STEPS):
        # Sample GROUP random training examples (with replacement).
        batch = [train_data[random.randrange(len(train_data))] for _ in range(GROUP)]

        prompts = [b["prompt"] for b in batch]
        golds = [b["target"] for b in batch]

        # Generate one completion per prompt and score it.
        preds = [generate_answer(p) for p in prompts]
        rewards = [reward_fn(pred, gold) for pred, gold in zip(preds, golds)]

        # Advantage relative to the group average.
        avg_reward = sum(rewards) / GROUP
        advantages = [r - avg_reward for r in rewards]

        step_loss = 0.0
        has_gradient = False
        for prompt_text, pred, advantage in zip(prompts, preds, advantages):
            # A zero advantage contributes no gradient, and an empty completion
            # has no tokens to score (the mean over zero labels would be NaN).
            if advantage == 0.0 or not pred:
                continue

            prompt_ids = tokenizer(prompt_text, return_tensors="pt")["input_ids"]
            completion_ids = tokenizer(
                pred, return_tensors="pt", add_special_tokens=False
            )["input_ids"]
            if completion_ids.shape[1] == 0:
                continue

            # The model must see prompt + completion as one sequence. Labels
            # must have the same shape as input_ids; prompt positions are set
            # to -100 so only completion tokens contribute to the loss.
            input_ids = torch.cat([prompt_ids, completion_ids], dim=1).to(model.device)
            labels = input_ids.clone()
            labels[:, : prompt_ids.shape[1]] = -100

            outputs = model(
                input_ids=input_ids,
                attention_mask=torch.ones_like(input_ids),
                labels=labels,
            )

            # outputs.loss is the mean negative log-likelihood of the
            # completion; weighting it by the advantage raises the likelihood
            # of above-average completions and lowers it for below-average
            # ones. Backpropagating per sample accumulates the same gradient
            # as averaging first, without holding every graph in memory.
            weighted_nll = outputs.loss * advantage / GROUP
            weighted_nll.backward()
            step_loss += weighted_nll.item()
            has_gradient = True

        # Skip the optimizer update when nothing was backpropagated, so weight
        # decay and stale momentum do not move the weights on a no-signal step.
        if has_gradient:
            optimizer.step()
        optimizer.zero_grad()

        if idx % 20 == 0:
            print(f"Step {idx}, Loss = {step_loss:.4f}, Rewards = {rewards}, Avg Reward = {avg_reward:.2f}")

print("\nðŸŽ‰ Training Done!")

## 6. Test on Sample

In [None]:
# Test on a sample
sample = test_data[0]
prompt = f"Question: {sample['question']}\nAnswer:"
prediction = generate_answer(prompt)

print("Question:", sample['question'])
print("\nPrediction:", prediction)
print("\nGround Truth:", sample['answer'])