# 🚀 Tunix Hackathon - Lightweight Version (PyTorch)

This is a simplified version that:
- ✅ Works without HF_TOKEN (uses public models)
- ✅ Uses PyTorch instead of JAX (easier setup)
- ✅ Simple GRPO implementation
- ✅ Direct dataset loading from Hugging Face

**Perfect for quick experiments on Kaggle!**

## Installation

In [None]:
# Install packages if not already installed
import subprocess
import sys

def install_if_needed(package, import_name=None):
    """Pip-install *package* into the running interpreter unless it already imports.

    Args:
        package: distribution name passed to ``pip install``.
        import_name: module name to try importing when it differs from the
            distribution name (e.g. ``"scikit-learn"`` is imported as
            ``"sklearn"``). Defaults to *package*.

    Raises:
        subprocess.CalledProcessError: if ``pip install`` fails.
    """
    import importlib

    module_name = import_name or package
    try:
        importlib.import_module(module_name)
        print(f"✅ {package} already installed")
    except ImportError:
        print(f"Installing {package}...")
        subprocess.check_call([sys.executable, "-m", "pip", "install", "-q", package])
        # Packages installed after interpreter start-up may not be seen by the
        # import system's cached finders until the caches are invalidated.
        importlib.invalidate_caches()

# Install required packages (each pip name is also its import name).
packages = ["torch", "transformers", "datasets", "accelerate"]
for pkg in packages:
    install_if_needed(pkg)

print("\n✅ All packages ready!")

## Imports

In [None]:
# Suppress transformers warnings about generation flags
# NOTE: TRANSFORMERS_VERBOSITY is assigned before `transformers` is imported
# below; keep that order so the library sees the setting when it first loads.
import warnings
import os
os.environ['TRANSFORMERS_VERBOSITY'] = 'error'  # Suppress info/warning messages
# NOTE(review): this filter ignores *every* UserWarning in the process (torch,
# datasets, ...), not only transformers' generation-flag warnings.
warnings.filterwarnings('ignore', category=UserWarning)

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from datasets import load_dataset
import random


## 1. Load Model (Public - No Token Required)

In [None]:
# Using Qwen 0.5B - truly public model, no token needed
# Alternative: google/gemma-2-2b-it (might need token)
MODEL = "Qwen/Qwen2.5-0.5B-Instruct"
FALLBACK_MODEL = "microsoft/Phi-3-mini-4k-instruct"

# The GRPO cell later trains these weights in place with plain torch.optim.AdamW.
# Pure float16 weights are unsafe for that: AdamW's eps=1e-8 rounds to 0 in
# float16 and tiny squared gradients underflow, which can turn weights into
# inf/NaN, and lr=1e-6 updates are below float16 resolution for typical weight
# magnitudes. float32 keeps the optimizer updates representable (and also
# avoids float16 kernels that may be missing when running on CPU).
DTYPE = torch.float32


def _load_model_and_tokenizer(name):
    """Load the tokenizer and causal-LM weights for Hub model *name*."""
    tok = AutoTokenizer.from_pretrained(name, trust_remote_code=True)
    lm = AutoModelForCausalLM.from_pretrained(
        name,
        torch_dtype=DTYPE,
        device_map="auto",
        trust_remote_code=True,
    )
    return tok, lm


print(f"Loading {MODEL}...")
try:
    tokenizer, model = _load_model_and_tokenizer(MODEL)
    print("✅ Model loaded successfully!")
except Exception as e:
    print(f"❌ Error loading {MODEL}: {e}")
    print(f"\nTrying alternative model: {FALLBACK_MODEL}")
    MODEL = FALLBACK_MODEL
    tokenizer, model = _load_model_and_tokenizer(MODEL)
    print("✅ Alternative model loaded!")

## 2. Load GSM8K Dataset

In [None]:
# Download the "main" configuration of GSM8K and keep its train/test splits.
print("Downloading GSM8K via HuggingFace...")
ds = load_dataset("openai/gsm8k", "main")
train_data = ds["train"]
test_data = ds["test"]
print(f"✅ Train samples: {len(train_data)}")
print(f"✅ Test samples: {len(test_data)}")

## 3. Prepare Data

In [None]:
def encode(example):
    """Attach the model prompt and the reference solution to a GSM8K row.

    Adds ``"prompt"`` (``"Question: ...\\nAnswer:"``) and ``"target"`` (the
    untouched ``"answer"`` text) to *example* in place and returns the same
    dict, so it can be passed straight to ``Dataset.map``.
    """
    example.update(
        prompt=f"Question: {example['question']}\nAnswer:",
        target=example["answer"],
    )
    return example

# Add "prompt"/"target" columns to every training row.
train_data = train_data.map(encode)
print("✅ Data prepared!")

## 4. GRPO Functions

In [None]:
def _next_token_probs(next_token_logits, temperature, top_p):
    """Turn raw next-token logits into a NaN-free, top-p-filtered distribution.

    Args:
        next_token_logits: logits for the next position, indexed as
            (batch, vocab) by ``generate_answer``.
        temperature: divisor applied to the clamped logits (must be > 0).
        top_p: nucleus mass; tokens beyond the smallest high-probability prefix
            whose cumulative probability exceeds it are zeroed. The most likely
            token is always kept.

    Returns:
        A float32 tensor of the same shape with each row renormalized.
    """
    # float32 so softmax/cumsum over a large vocabulary do not lose precision
    # when the model runs in half precision.
    logits = next_token_logits.float()

    # 1) Replace NaN/Inf, then 2) clamp so exp() cannot overflow.
    logits = torch.nan_to_num(logits, neginf=-1e9, posinf=1e9)
    logits = torch.clamp(logits, -50, 50)

    # 3) Temperature, 4) softmax.
    probs = torch.softmax(logits / temperature, dim=-1)

    # 5) Defensive: drop any NaN that slipped through and renormalize.
    probs = torch.nan_to_num(probs, nan=0.0)
    probs = probs / probs.sum(dim=-1, keepdim=True)

    # 6) Nucleus (top-p) filtering.
    sorted_probs, sorted_indices = torch.sort(probs, descending=True)
    cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
    sorted_to_remove = cumulative_probs > top_p
    # Shift right so the token that crosses the threshold is kept as well.
    sorted_to_remove[..., 1:] = sorted_to_remove[..., :-1].clone()
    sorted_to_remove[..., 0] = False  # Keep at least one token.
    # Map the mask from sorted order back to vocabulary order.
    to_remove = sorted_to_remove.scatter(1, sorted_indices, sorted_to_remove)
    probs = probs.masked_fill(to_remove, 0.0)
    return probs / probs.sum(dim=-1, keepdim=True)


def generate_answer(prompt, max_new_tokens=64, temperature=0.7, top_p=0.9):
    """Sample a completion for *prompt* with a hand-written loop (no ``model.generate()``).

    Each step runs a full forward pass of the module-level ``model`` over the
    sequence so far, sanitizes the last-position logits, and samples with
    temperature + top-p. It falls back to greedy decoding if
    ``torch.multinomial`` fails. Generation stops at ``tokenizer.eos_token_id``
    or after *max_new_tokens* tokens.

    Args:
        prompt: text to condition on.
        max_new_tokens: upper bound on the number of sampled tokens.
        temperature: sampling temperature.
        top_p: nucleus-sampling threshold.

    Returns:
        The decoded newly generated text (prompt excluded), stripped, or
        ``"[Generation Error]"`` if any exception occurred.
    """
    # Restore the caller's train/eval mode afterwards. Forcing train mode
    # unconditionally would silently undo a prior model.eval().
    was_training = model.training
    try:
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
        input_ids = inputs["input_ids"]
        prompt_len = input_ids.shape[1]
        generated_ids = input_ids.clone()
        warned = False

        model.eval()  # Set to eval mode for generation

        with torch.no_grad():
            for step in range(max_new_tokens):
                # Logits for the next token: shape (batch, vocab_size).
                next_token_logits = model(input_ids=generated_ids).logits[:, -1, :]
                probs = _next_token_probs(next_token_logits, temperature, top_p)

                # 7) Sample, falling back to greedy if the distribution is invalid.
                try:
                    next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
                except RuntimeError as e:
                    if not warned:  # Report only the first failure.
                        print(f"⚠️ Sampling failed at step {step}, using greedy: {e}")
                        warned = True
                    next_tokens = torch.argmax(probs, dim=-1)

                generated_ids = torch.cat([generated_ids, next_tokens.unsqueeze(1)], dim=1)

                eos_id = tokenizer.eos_token_id
                if eos_id is not None and next_tokens.item() == eos_id:
                    break

        # Decode only the new tokens. Decoding the whole sequence and
        # string-replacing the prompt breaks whenever the decode round-trip does
        # not reproduce the prompt exactly. Question text (including its
        # numbers) would then leak into the text that reward_fn scores.
        new_tokens = generated_ids[0, prompt_len:]
        return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

    except Exception as e:
        print(f"❌ Generation error: {e}")
        import traceback
        traceback.print_exc()
        return "[Generation Error]"
    finally:
        model.train(was_training)

def reward_fn(pred, gold):
    """Return 1.0 if the gold final answer appears as a number in *pred*, else 0.0.

    *gold* is a GSM8K-style solution whose final answer follows ``####``.
    Numbers are compared as whole numeric tokens, with thousands separators
    ignored (``"1,000"`` matches ``"1000"``), not by substring. So a gold
    answer of ``"4"`` is no longer credited for a prediction containing ``"14"``.
    A non-numeric final answer falls back to a plain substring check.

    Args:
        pred: generated text; ``"[Generation Error]"``, empty, or non-string
            scores 0.0.
        gold: reference solution string.

    Returns:
        1.0 or 0.0.
    """
    import re

    if not isinstance(pred, str) or not pred or pred == "[Generation Error]":
        return 0.0
    if not isinstance(gold, str):
        return 0.0

    gold_num = gold.split("####")[-1].strip()
    if not gold_num:
        # An empty answer would otherwise "match" every prediction.
        return 0.0

    try:
        target = float(gold_num.replace(",", ""))
    except ValueError:
        return 1.0 if gold_num in pred else 0.0

    # Signed integers/decimals; commas only as proper thousands groups.
    number_pattern = r"-?(?:\d{1,3}(?:,\d{3})+|\d+)(?:\.\d+)?"
    for token in re.findall(number_pattern, pred):
        if float(token.replace(",", "")) == target:
            return 1.0
    return 0.0

## 5. GRPO Training (Simplified)

In [None]:
# Quick model health check: sample a short completion for a trivial prompt.
# generate_answer reports its own failures by returning "[Generation Error]"
# instead of raising, so that sentinel must be checked explicitly.
print("Testing model generation...")
test_prompt = "Question: What is 2+2?\nAnswer:"
try:
    test_output = generate_answer(test_prompt, max_new_tokens=20)
    if test_output == "[Generation Error]":
        print("❌ Model test failed: generation returned an error (see traceback above).")
        print("Please check model loading or try a different model.")
    else:
        print(f"✅ Model test successful: {test_output[:50]}...")
except Exception as e:
    print(f"❌ Model test failed: {e}")
    print("Please check model loading or try a different model.")

In [None]:
# Plain (non-mixed-precision) torch.optim.AdamW on float16 weights is unsafe.
# Its eps=1e-8 rounds to 0 in float16 and tiny squared gradients underflow,
# which can turn weights into inf/NaN. lr=1e-6 updates are also below float16
# resolution for typical weight magnitudes. Upcast before building the optimizer.
# NOTE(review): assumes the model is not offloaded/split in a way that
# prevents an in-place .float() — confirm if a larger model is used.
if any(p.dtype == torch.float16 for p in model.parameters()):
    print("⚠️ Model weights are float16; converting to float32 for stable AdamW updates.")
    model.float()

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-6)
EPOCHS = 1
GROUP = 2    # Completions sampled per step (the GRPO "group")
STEPS = 200  # Training steps per epoch

print("Starting GRPO training...")
print(f"Epochs: {EPOCHS}, Steps: {STEPS}, Group size: {GROUP}")

model.train()  # Set to training mode

for epoch in range(EPOCHS):
    print(f"\n=== Epoch {epoch+1} ===")

    for idx in range(STEPS):
        # Draw GROUP training examples uniformly at random (with replacement).
        batch = [train_data[random.randint(0, len(train_data) - 1)] for _ in range(GROUP)]
        prompts = [b["prompt"] for b in batch]
        golds = [b["target"] for b in batch]

        # Roll out one completion per prompt and score it.
        preds = [generate_answer(p) for p in prompts]
        rewards = [reward_fn(pred, gold) for pred, gold in zip(preds, golds)]

        # Advantage of each sample relative to the group's mean reward.
        avg_reward = sum(rewards) / GROUP
        advantages = [r - avg_reward for r in rewards]

        losses = []
        for prompt, pred, advantage in zip(prompts, preds, advantages):
            # A zero advantage contributes exactly zero loss: skip its forward pass.
            if advantage == 0:
                continue

            prompt_ids = tokenizer(prompt, return_tensors="pt").to(model.device)["input_ids"]
            # The decoded completion is re-tokenized as the target; this
            # approximates (but need not equal) the token ids actually sampled.
            pred_ids = tokenizer(
                pred, return_tensors="pt", add_special_tokens=False
            ).to(model.device)["input_ids"]
            if pred_ids.shape[1] == 0:
                # No target tokens: the averaged loss would be NaN and would
                # poison the whole step.
                continue

            full_input_ids = torch.cat([prompt_ids, pred_ids], dim=1)
            labels = full_input_ids.clone()
            labels[:, :prompt_ids.shape[1]] = -100  # Ignore prompt positions in the loss.

            # With labels, the model returns the cross-entropy (NLL) averaged
            # over the non-ignored, i.e. completion, positions.
            nll = model(input_ids=full_input_ids, labels=labels).loss

            # Minimizing advantage * NLL raises the likelihood of above-average
            # completions and lowers it for below-average ones.
            losses.append(nll * advantage)

        if not losses:
            # All rewards tied (or no usable targets): no policy-gradient signal,
            # so skip the update rather than step AdamW on zero gradients.
            if idx % 20 == 0:
                print(f"Step {idx}, no learning signal, Rewards = {rewards}, Avg Reward = {avg_reward:.2f}")
            continue

        # Divide by GROUP (not len(losses)) to keep the full-group average scale.
        loss = sum(losses) / GROUP
        if not torch.isfinite(loss):
            print(f"⚠️ Step {idx}: Invalid loss, skipping...")
            optimizer.zero_grad()
            continue

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)  # Gradient clipping
        optimizer.step()
        optimizer.zero_grad()

        if idx % 20 == 0:
            print(f"Step {idx}, Loss = {loss.item():.4f}, Rewards = {rewards}, Avg Reward = {avg_reward:.2f}")

print("\n🎉 Training Done!")
model.eval()  # Set back to eval mode

## 6. Test on Sample

In [None]:
# Test on one held-out example. encode() builds the prompt so the format
# always matches what the model was trained on. dict() copies the row because
# encode() mutates its argument.
sample = test_data[0]
prompt = encode(dict(sample))["prompt"]
prediction = generate_answer(prompt)

print("Question:", sample['question'])
print("\nPrediction:", prediction)
print("\nGround Truth:", sample['answer'])
print("\nReward:", reward_fn(prediction, sample['answer']))