# Judicaita: GRPO Training with Google Tunix on TPU

## 🎯 Hackathon Context

This notebook demonstrates **GRPO (Group Relative Policy Optimization)** training for the Judicaita legal AI assistant using:
- **Google Tunix** for RL training infrastructure
- **Gemma 3-1B-IT** as the base model
- **TPU v2-8+** for accelerated training
- **LoRA adapters** for parameter-efficient fine-tuning

This notebook was developed for the Kaggle hackathon to train models that generate explainable legal reasoning in a structured, XML-tagged output format.

## ⚡ TPU Requirements

**IMPORTANT**: This notebook requires:
- Google Colab with TPU runtime (TPU v2-8 or higher)
- Runtime type: TPU (not CPU or GPU)
- To enable: Runtime → Change runtime type → Hardware accelerator: TPU

## 📋 What This Notebook Does

1. **Environment Setup**: Install Tunix, JAX, and dependencies for TPU
2. **Model Loading**: Download and initialize Gemma 3-1B-IT with LoRA
3. **Dataset Preparation**: Format training data with XML-tagged reasoning
4. **Reward Function**: Score outputs based on format, reasoning length, and correctness
5. **GRPO Training**: Train with `GRPOLearner` and `RLCluster` on TPU
6. **Export**: Package trained LoRA adapters for Kaggle submission

## 🔄 Data Flow

```
Dataset → Prompts → Model Rollouts → Reward Scoring → GRPO Updates
                                                           ↓
                                              LoRA Adapter Checkpoints
```
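
The "Reward Scoring → GRPO Updates" step relies on comparing each completion with the other completions sampled for the same prompt. The sketch below shows that group-relative normalization as described in the GRPO paper listed in the references. The function name, the reward values, and the choice of population standard deviation are assumptions for illustration; the source does not confirm how Tunix computes this internally.

```python
from statistics import mean, pstdev


def group_relative_advantages(rewards, eps=1e-6):
    """Normalize rewards within one prompt's group of sampled completions."""
    mu = mean(rewards)
    sigma = pstdev(rewards)  # population std over the group (an assumption)
    # eps keeps the division defined when every reward in the group is equal
    return [(r - mu) / (sigma + eps) for r in rewards]


# Hypothetical rewards for four completions of the same prompt
rewards = [0.9, 0.5, 0.5, 0.1]
advantages = group_relative_advantages(rewards)
for r, a in zip(rewards, advantages):
    print(f"reward={r:.1f} -> advantage={a:+.2f}")
```

Completions that score above the group mean receive a positive advantage and are reinforced; those below the mean are discouraged. When every completion in a group earns the same reward, all advantages are zero and that prompt contributes no learning signal.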

## ⚠️ Differences from Main Codebase

| Aspect | Main Codebase | This Notebook |
|--------|---------------|---------------|
| Format | Step-by-step format | XML `<reasoning>`/`<answer>` |
| Framework | PyTorch | JAX/Flax |
| Training | Custom GRPO | Tunix GRPOLearner |
| Hardware | GPU/CPU | TPU v2-8+ |

## 📚 References

- [Google Tunix Documentation](https://tunix.readthedocs.io/)
- [Tunix GRPO Gemma Example](https://github.com/google/tunix/tree/main/examples/grpo_gemma)
- [Gemma Model Card](https://ai.google.dev/gemma/docs)
- [GRPO Paper](https://arxiv.org/abs/2402.03300)
- [Judicaita Repository](https://github.com/clduab11/judicAIta)

## ⚠️ Known Limitations

- **TPU Required**: Cannot run on CPU/GPU without code modifications
- **Memory**: TPU v2-8 has ~64GB of HBM in total (8GB per core); larger models may need v3 or higher
- **Dataset**: Assumes generic legal reasoning tasks (not LegalBench-specific)
- **Checkpoints**: Large checkpoint files may exceed Colab storage limits
- **API Stability**: Tunix API may change; verify imports match your version


## 📦 Step 1: Install Dependencies

Install required packages for TPU training with Tunix and Gemma.

In [None]:
# Install core dependencies
# IMPORTANT: Google Colab TPU requires specific JAX/Flax versions for compatibility
# See: https://tunix.readthedocs.io/en/latest/installation.html

# Install Tunix with TPU support
!pip install -q "google-tunix[tpu]>=0.5.0"

# Install JAX with TPU support - use versions compatible with Colab TPU
!pip install -q jax==0.4.35 jaxlib==0.4.35
!pip install -q flax==0.10.2

# Install ML dependencies
# Version specifiers are quoted so the shell does not treat ">" as an output redirect
!pip install -q "transformers>=4.40.0"
!pip install -q "huggingface_hub>=0.20.0"
!pip install -q "datasets>=2.14.0"
!pip install -q "sentencepiece>=0.1.99"
!pip install -q "safetensors>=0.4.0"

# Verify installation
print("\n📦 Verifying package versions:")
try:
    import jax
    print(f"   JAX: {jax.__version__}")
except ImportError as e:
    print(f"   ❌ JAX import failed: {e}")

try:
    import flax
    print(f"   Flax: {flax.__version__}")
except ImportError as e:
    print(f"   ❌ Flax import failed: {e}")

try:
    import tunix
    print(f"   Tunix: {tunix.__version__ if hasattr(tunix, '__version__') else 'installed'}")
except ImportError as e:
    print(f"   ⚠️ Tunix import will be available after restart: {e}")

print("\n✅ Installation step finished. Check the versions above for any ❌ entries.")
print("⚠️  IMPORTANT: Runtime restart required for TPU libraries.")
print("   Go to: Runtime → Restart runtime")
print("   Then continue with the next cell.")


### ⚠️ Runtime Restart Required

**STOP HERE** and restart the runtime:
1. Click `Runtime` → `Restart runtime` in the menu
2. After restart, continue from the next cell

This is necessary for TPU libraries to be properly loaded.

## 🚀 Step 2: Initialize TPU Runtime

Set up JAX to use TPU devices.

In [None]:
import jax
import jax.numpy as jnp
from jax.tools import colab_tpu

# Initialize TPU
print("Initializing TPU runtime...")
try:
    colab_tpu.setup_tpu()
    print("✅ TPU initialized successfully!")
except Exception as e:
    print(f"❌ TPU initialization failed: {e}")
    print("\n🔧 Troubleshooting:")
    print("   1. Check runtime type: Runtime → Change runtime type → TPU")
    print("   2. Ensure TPU v2-8 or higher is available")
    print("   3. Try restarting the runtime")
    print("   4. Check Google Cloud TPU quota if using custom project")
    raise

# Verify TPU devices (check before indexing into the device list)
devices = jax.devices()
if not devices or devices[0].platform != "tpu":
    raise RuntimeError("No TPU devices detected! Please check your runtime configuration.")

print(f"\n📊 TPU Device Information:")
print(f"   Number of devices: {len(devices)}")
print(f"   Device type: {devices[0].platform}")
print(f"   Devices: {devices}")

print("\n✅ TPU setup complete and verified!")

## 🔐 Step 3: Authenticate with Hugging Face

Login to Hugging Face to download the Gemma model.

In [None]:
from huggingface_hub import login, snapshot_download
import os

# Login to Hugging Face
# You'll be prompted to enter your HF token
# Get your token from: https://huggingface.co/settings/tokens
print("Please enter your Hugging Face token:")
login()

print("\n✅ Authenticated with Hugging Face!")

## 📥 Step 4: Download Gemma 3-1B-IT Model

Download the model files and initialize the tokenizer.

**Note**: This notebook uses `google/gemma-3-1b-it`, the instruction-tuned 1B Gemma 3 checkpoint on Hugging Face. If you switch to a different Gemma checkpoint, update `MODEL_ID` in the cell below.

In [None]:
from huggingface_hub import snapshot_download
from transformers import AutoTokenizer
import os

# Download model
MODEL_ID = "google/gemma-3-1b-it"  # Instruction-tuned Gemma 3 1B checkpoint
CACHE_DIR = "./gemma_model_cache"

print(f"Downloading {MODEL_ID}...")
model_path = snapshot_download(
    repo_id=MODEL_ID,
    cache_dir=CACHE_DIR,
    local_dir=f"{CACHE_DIR}/gemma",
    local_dir_use_symlinks=False
)
print(f"✅ Model downloaded to: {model_path}")

# Initialize tokenizer
print("\nInitializing tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(model_path)
print(f"✅ Tokenizer initialized")
print(f"   Vocab size: {len(tokenizer)}")
print(f"   Special tokens: {tokenizer.special_tokens_map}")

# Test tokenization
test_text = "What is the legal precedent for breach of contract?"
tokens = tokenizer(test_text, return_tensors="np")
print(f"\n📝 Test tokenization:")
print(f"   Input: {test_text}")
print(f"   Token count: {len(tokens['input_ids'][0])}")

## 🔧 Step 5: Create Preprocessing Function

Gemma models don't have native system role support. We'll prepend the system prompt to the first user turn.

In [None]:
def preprocess_with_system_prompt(messages, system_prompt):
    """
    Prepend system prompt to first user message.
    
    Gemma doesn't support system role natively, so we merge it with
    the first user turn as a workaround.
    
    Args:
        messages: List of message dicts with 'role' and 'content'
        system_prompt: System instruction string
        
    Returns:
        Modified messages list with system prompt prepended
    """
    if not messages:
        return messages
    
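    # Copy each message dict; list.copy() alone is shallow and would mutate the caller's messages
    processed = [dict(msg) for msg in messages]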
    
    # Find first user message
    for i, msg in enumerate(processed):
        if msg.get('role') == 'user':
            # Prepend system prompt
            original_content = msg['content']
            processed[i]['content'] = f"{system_prompt}\n\n{original_content}"
            break
    
    return processed

# Define system prompt for legal reasoning
SYSTEM_PROMPT = """You are a legal AI assistant. For each question, provide your analysis in this exact format:
<reasoning>Your step-by-step legal reasoning here. Include relevant legal principles, precedents, and analysis. Aim for at least 100 tokens of detailed reasoning.</reasoning>
<answer>Your final answer or conclusion here.</answer>

Always use this XML format and ensure your reasoning is thorough and well-explained."""

# Test preprocessing
test_messages = [
    {"role": "user", "content": "Is a non-compete clause enforceable in California?"}
]
processed = preprocess_with_system_prompt(test_messages, SYSTEM_PROMPT)
print("📝 Test preprocessing:")
print(f"Original: {test_messages[0]['content'][:50]}...")
print(f"\nProcessed length: {len(processed[0]['content'])} chars")
print(f"System prompt prepended: {'<reasoning>' in processed[0]['content']}")
print("\n✅ Preprocessing function ready!")

## 📊 Task 2: Prepare Training Dataset

Create a dataset with XML-tagged reasoning format compatible with Tunix GRPO.

### JSONL Format Requirements

Each training example must be a JSON object with:
- `prompt`: The question or task
- `ground_truth`: The correct answer for evaluation
- `metadata` (optional): Additional info like task_id, difficulty, etc.
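
As a minimal illustration of this format, the sketch below serializes one record to a single JSONL line and parses it back. The field values, including `task_id` and `difficulty`, are hypothetical and only show the expected shape.

```python
import json

record = {
    "prompt": "Under what circumstances can a contract be voided for duress?",
    "ground_truth": "When one party was forced into the agreement by improper pressure.",
    "metadata": {"task_id": "example-001", "difficulty": "easy"},  # hypothetical values
}

# One JSON object per line, no pretty-printing
line = json.dumps(record, ensure_ascii=False)

# Reading back: parse each non-empty line independently
parsed = [json.loads(raw) for raw in line.splitlines() if raw.strip()]

# Check the two mandatory fields before using a record for training
REQUIRED = ("prompt", "ground_truth")
missing = [key for key in REQUIRED if key not in parsed[0]]
print("missing fields:", missing)
```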

In [None]:
import json
import re
from typing import List, Dict, Any

def prepare_dataset_for_tunix(examples: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """
    Prepare dataset in Tunix-compatible JSONL format.
    
    Args:
        examples: List of dicts with 'question' and 'answer' fields
        
    Returns:
        List of dicts with 'prompt', 'ground_truth', and 'metadata'
    """
    prepared = []
    
    for idx, ex in enumerate(examples):
        prepared.append({
            "prompt": ex.get("question", ex.get("prompt", "")),
            "ground_truth": ex.get("answer", ex.get("ground_truth", "")),
            "metadata": {
                "example_id": idx,
                "original_question": ex.get("question", ""),
                "task_type": ex.get("task_type", "general_reasoning")
            }
        })
    
    return prepared

# Create synthetic sample data (replace with real data)
sample_examples = [
    {
        "question": "Can an employer in California enforce a non-compete clause against a former employee?",
        "answer": "No, non-compete clauses are generally unenforceable in California except in limited circumstances involving sale of business or dissolution of partnership.",
        "task_type": "legal_qa"
    },
    {
        "question": "What is the statute of limitations for filing a breach of contract claim?",
        "answer": "The statute of limitations varies by jurisdiction. In many states, it is 4-6 years for written contracts and 2-3 years for oral contracts.",
        "task_type": "legal_qa"
    },
    {
        "question": "Under what circumstances can a contract be voided for duress?",
        "answer": "A contract can be voided for duress when one party was forced to enter the agreement through threats, violence, or other improper pressure that overcame their free will.",
        "task_type": "legal_qa"
    },
    {
        "question": "What is required to establish an attorney-client privilege?",
        "answer": "Attorney-client privilege requires: (1) an attorney-client relationship, (2) confidential communication, (3) made for the purpose of seeking or providing legal advice.",
        "task_type": "legal_qa"
    },
]

# Prepare dataset
prepared_dataset = prepare_dataset_for_tunix(sample_examples)

print(f"✅ Prepared {len(prepared_dataset)} training examples")
print(f"\n📝 Sample example:")
print(json.dumps(prepared_dataset[0], indent=2))

# Note: In production, load from file or HuggingFace dataset
print("\n💡 To load from file:")
print("   # with open('data.jsonl', 'r') as f:")
print("   #     examples = [json.loads(line) for line in f]")
print("   #     prepared_dataset = prepare_dataset_for_tunix(examples)")

### Prompt Template with XML Format

Create a template that formats prompts to expect XML-tagged reasoning.

In [None]:
def create_prompt_template(question: str, system_prompt: str = SYSTEM_PROMPT) -> str:
    """
    Create a formatted prompt with XML output expectations.
    
    Args:
        question: The legal question to answer
        system_prompt: System instructions for format
        
    Returns:
        Formatted prompt string
    """
    template = f"""{system_prompt}

Question: {question}

Response:"""
    return template

def validate_xml_format(response: str) -> bool:
    """
    Validate that response contains proper XML tags.
    
    Args:
        response: Model generated response
        
    Returns:
        True if valid XML format, False otherwise
    """
    # Check for both opening and closing tags
    has_reasoning = '<reasoning>' in response and '</reasoning>' in response
    has_answer = '<answer>' in response and '</answer>' in response
    
    return has_reasoning and has_answer

# Apply template to all examples
templated_prompts = []
for example in prepared_dataset:
    templated = {
        "prompt": create_prompt_template(example["prompt"]),
        "ground_truth": example["ground_truth"],
        "metadata": example["metadata"],
        "original_prompt": example["prompt"]
    }
    templated_prompts.append(templated)

print(f"✅ Created {len(templated_prompts)} templated prompts")
print(f"\n📝 Sample templated prompt (first 300 chars):")
print(templated_prompts[0]["prompt"][:300])
print("...")

# Test validation
test_valid = "<reasoning>This is reasoning</reasoning><answer>This is answer</answer>"
test_invalid = "This is just text without tags"
print(f"\n✅ Validation test:")
print(f"   Valid format: {validate_xml_format(test_valid)}")
print(f"   Invalid format: {validate_xml_format(test_invalid)}")
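
Substring checks accept responses whose tags are repeated or appear in the wrong order. If a stricter gate is wanted, one option is sketched here. It assumes that a valid response is exactly one `<reasoning>` block followed by one `<answer>` block, with only whitespace around them; the name `validate_xml_format_strict` is hypothetical and not part of the notebook.

```python
import re

# Hypothetical stricter pattern: one <reasoning> block, then one <answer> block,
# with only whitespace before, between, and after them.
STRICT_PATTERN = re.compile(
    r"\s*<reasoning>(?P<reasoning>.+?)</reasoning>"
    r"\s*<answer>(?P<answer>.+?)</answer>\s*",
    re.DOTALL,
)
TAGS = ("<reasoning>", "</reasoning>", "<answer>", "</answer>")


def validate_xml_format_strict(response: str) -> bool:
    match = STRICT_PATTERN.fullmatch(response)
    if match is None:
        return False
    # Reject repeated tags that the lazy groups could otherwise swallow
    if any(response.count(tag) != 1 for tag in TAGS):
        return False
    # Both blocks must contain something other than whitespace
    return bool(match.group("reasoning").strip()) and bool(match.group("answer").strip())


print(validate_xml_format_strict("<reasoning>a</reasoning>\n<answer>b</answer>"))
print(validate_xml_format_strict("<answer>b</answer><reasoning>a</reasoning>"))
```

Whether this strictness helps depends on the model: early in training it may reject nearly every completion, which leaves the format reward with little signal to learn from.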

### Tokenization and Batching

Tokenize prompts and prepare batches for training.
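
To make the effect of `padding="max_length"` and `truncation=True` concrete, here is a pure-Python sketch of what happens to one sequence of token IDs. It is an illustration only, not the tokenizer's implementation: the helper name, right-side padding, and `pad_id=0` are assumptions, and the real padding side and pad token depend on the tokenizer configuration.

```python
def pad_or_truncate(token_ids, max_length, pad_id=0):
    """Return (ids, attention_mask), both exactly max_length long (right padding)."""
    ids = list(token_ids[:max_length])        # truncation drops tokens past max_length
    mask = [1] * len(ids)                     # 1 marks a real token
    padding = max_length - len(ids)
    return ids + [pad_id] * padding, mask + [0] * padding  # 0 marks padding


short_ids, short_mask = pad_or_truncate([5, 6, 7], max_length=5)
long_ids, long_mask = pad_or_truncate([1, 2, 3, 4, 5, 6, 7], max_length=5)
print(short_ids, short_mask)
print(long_ids, long_mask)
```

Every row ends up with the same length, which is what allows the batch to be stacked into a single fixed-shape array.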

In [None]:
import numpy as np
from typing import List, Dict

# Set maximum prompt length
MAX_PROMPT_LENGTH = 512  # Adjust based on your needs (512 or 1024)
MAX_RESPONSE_LENGTH = 512

def tokenize_prompts(prompts: List[str], tokenizer, max_length: int = MAX_PROMPT_LENGTH):
    """
    Tokenize prompts with padding and truncation.
    
    Args:
        prompts: List of prompt strings
        tokenizer: HuggingFace tokenizer
        max_length: Maximum token length
        
    Returns:
        Dict with input_ids and attention_mask
    """
    tokenized = tokenizer(
        prompts,
        padding="max_length",
        truncation=True,
        max_length=max_length,
        return_tensors="np"
    )
    return tokenized

def create_training_batches(dataset: List[Dict], batch_size: int = 4):
    """
    Create batches from dataset.
    
    Args:
        dataset: List of training examples
        batch_size: Number of examples per batch
        
    Returns:
        List of batches, each batch is a list of examples
    """
    batches = []
    for i in range(0, len(dataset), batch_size):
        batch = dataset[i:i + batch_size]
        batches.append(batch)
    return batches

# Tokenize all prompts
all_prompts = [ex["prompt"] for ex in templated_prompts]
tokenized_prompts = tokenize_prompts(all_prompts, tokenizer, MAX_PROMPT_LENGTH)

print(f"✅ Tokenized {len(all_prompts)} prompts")
print(f"   Max length: {MAX_PROMPT_LENGTH} tokens")
print(f"   Shape: {tokenized_prompts['input_ids'].shape}")

# Create final dataset for training
training_dataset = []
for i, ex in enumerate(templated_prompts):
    training_dataset.append({
        "prompt": ex["prompt"],
        "prompt_tokens": tokenized_prompts['input_ids'][i],
        "attention_mask": tokenized_prompts['attention_mask'][i],
        "ground_truth": ex["ground_truth"],
        "metadata": ex["metadata"]
    })

print(f"\n✅ Final training dataset: {len(training_dataset)} examples")
print(f"   Each example has: {list(training_dataset[0].keys())}")

# Validate dataset format
required_fields = ["prompt", "ground_truth", "metadata"]
all_valid = all(all(field in ex for field in required_fields) for ex in training_dataset)
print(f"\n✅ Dataset validation: {'PASSED' if all_valid else 'FAILED'}")

if not all_valid:
    print("❌ Some examples missing required fields!")
else:
    print("   All examples have required fields: prompt, ground_truth, metadata")

## 🎯 Task 3: Implement Custom Reward Function

Create a reward function that scores:
1. **Format**: Proper XML tags
2. **Reasoning Length**: At least 100 tokens
3. **Answer Correctness**: Match with ground truth
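
The three components are combined as a weighted sum. The worked example below uses the weights from the composite reward function defined later in this notebook (0.3, 0.3, 0.4); the component scores are hypothetical.

```python
# Weights taken from the composite reward function in this notebook
WEIGHTS = {"format": 0.3, "reasoning": 0.3, "correctness": 0.4}


def combine(components):
    """Weighted sum of component scores, each expected in [0, 1]."""
    return sum(WEIGHTS[name] * components[name] for name in WEIGHTS)


# Hypothetical completion: valid tags, 60 of the 100 target reasoning tokens,
# and an answer whose overlap score with the ground truth is 0.5
components = {"format": 1.0, "reasoning": 60 / 100, "correctness": 0.5}
total = combine(components)
print(f"total reward = {total:.2f}")  # 0.3*1.0 + 0.3*0.6 + 0.4*0.5 = 0.68
```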

In [None]:
import re
from typing import Tuple, Optional

def extract_xml_content(response: str) -> Tuple[Optional[str], Optional[str]]:
    """
    Extract content from <reasoning> and <answer> XML tags.
    
    Args:
        response: Model-generated response string
        
    Returns:
        Tuple of (reasoning_content, answer_content)
        Returns (None, None) if tags are malformed or missing
    """
    try:
        # Extract reasoning
        reasoning_match = re.search(r'<reasoning>(.*?)</reasoning>', response, re.DOTALL)
        reasoning = reasoning_match.group(1).strip() if reasoning_match else None
        
        # Extract answer
        answer_match = re.search(r'<answer>(.*?)</answer>', response, re.DOTALL)
        answer = answer_match.group(1).strip() if answer_match else None
        
        return reasoning, answer
    except Exception as e:
        print(f"Warning: Error extracting XML content: {e}")
        return None, None

# Test extraction with edge cases
test_cases = [
    # Valid case
    "<reasoning>Step by step analysis here</reasoning><answer>Final answer</answer>",
    # Missing tags
    "Just plain text without tags",
    # Partial tags
    "<reasoning>Incomplete reasoning",
    # Nested content
    "<reasoning>Analysis with <term>nested</term> content</reasoning><answer>Yes</answer>",
    # Multi-line
    """<reasoning>
Line 1 of reasoning
Line 2 of reasoning
</reasoning>
<answer>Final answer</answer>"""
]

print("🧪 Testing XML extraction:")
for i, test in enumerate(test_cases, 1):
    reasoning, answer = extract_xml_content(test)
    print(f"\nTest {i}:")
    print(f"  Reasoning found: {reasoning is not None}")
    print(f"  Answer found: {answer is not None}")
    if reasoning:
        print(f"  Reasoning preview: {reasoning[:50]}...")
    if answer:
        print(f"  Answer: {answer}")

print("\n✅ XML extraction function tested with edge cases")

In [None]:
def compute_format_reward(response: str) -> float:
    """
    Reward for valid XML format.
    
    Returns:
        1.0 if both <reasoning> and <answer> tags present and valid
        0.0 otherwise
    """
    reasoning, answer = extract_xml_content(response)
    
    # Check both tags present and have content
    if reasoning is not None and answer is not None:
        if len(reasoning.strip()) > 0 and len(answer.strip()) > 0:
            return 1.0
    
    return 0.0

def compute_reasoning_length_reward(response: str, tokenizer, min_tokens: int = 100) -> float:
    """
    Reward based on reasoning length.
    
    Returns:
        1.0 if reasoning >= min_tokens
        proportional score (tokens/min_tokens) if fewer
        0.0 if no reasoning found
    """
    reasoning, _ = extract_xml_content(response)
    
    if reasoning is None:
        return 0.0
    
    # Tokenize reasoning to count tokens (skip special tokens such as BOS so they don't inflate the count)
    tokens = tokenizer(reasoning, add_special_tokens=False, return_tensors="np")["input_ids"]
    num_tokens = len(tokens[0])
    
    # Return 1.0 if meets threshold, otherwise proportional
    if num_tokens >= min_tokens:
        return 1.0
    else:
        return num_tokens / min_tokens

def compute_answer_correctness_reward(response: str, ground_truth: str, tokenizer) -> float:
    """
    Reward based on answer correctness.
    
    Returns:
        1.0 for exact match (after normalization)
        Partial credit for token overlap (Jaccard similarity)
        0.0 if no answer found
    """
    _, answer = extract_xml_content(response)
    
    if answer is None:
        return 0.0
    
    # Normalize for comparison
    answer_norm = answer.lower().strip()
    ground_truth_norm = ground_truth.lower().strip()
    
    # Check exact match
    if answer_norm == ground_truth_norm:
        return 1.0
    
    # Tokenize both for overlap calculation
    answer_tokens = set(tokenizer.tokenize(answer_norm))
    truth_tokens = set(tokenizer.tokenize(ground_truth_norm))
    
    # Calculate Jaccard similarity (intersection / union)
    if len(answer_tokens) == 0 or len(truth_tokens) == 0:
        return 0.0
    
    intersection = len(answer_tokens & truth_tokens)
    union = len(answer_tokens | truth_tokens)
    
    jaccard = intersection / union if union > 0 else 0.0
    
    # Return Jaccard similarity as partial credit
    return jaccard

print("✅ Reward component functions defined:")
print("   - compute_format_reward()")
print("   - compute_reasoning_length_reward()")
print("   - compute_answer_correctness_reward()")
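
The partial-credit path in the correctness reward is a Jaccard similarity over token sets. The following worked example shows the arithmetic on a small case. Whitespace splitting stands in for the model tokenizer here, which is an assumption made so the sketch runs without the tokenizer loaded.

```python
def jaccard(a_tokens, b_tokens):
    """Jaccard similarity between two token collections, compared as sets."""
    a, b = set(a_tokens), set(b_tokens)
    if not a or not b:
        return 0.0
    return len(a & b) / len(a | b)


# Whitespace splitting replaces the real tokenizer (illustration only)
answer = "no non-compete clauses are generally unenforceable".split()
truth = "non-compete clauses are unenforceable in california".split()

score = jaccard(answer, truth)
print(f"shared={sorted(set(answer) & set(truth))}")
print(f"jaccard={score:.3f}")  # 4 shared tokens / 8 distinct tokens = 0.500
```

Because the comparison is set-based, word order and repetition are ignored. Two answers with opposite conclusions can still share most of their tokens, so this score is best read as a rough overlap signal, not a measure of legal correctness.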

In [None]:
from typing import Dict, List

def composite_reward_function(
    prompts: List[str],
    completions: List[str],
    metadata: List[Dict],
    tokenizer  # Required for token counting
) -> List[float]:
    """
    Composite reward function compatible with Tunix GRPO.
    
    This function scores model outputs based on:
    - Format reward (30%): Valid XML tags (<reasoning>/<answer>)
    - Reasoning reward (30%): Sufficient reasoning length (>=100 tokens)
    - Correctness reward (40%): Answer matches ground truth
    
    Args:
        prompts: List of input prompts
        completions: List of model completions
        metadata: List of metadata dicts with ground_truth
        tokenizer: HuggingFace tokenizer for token counting
        
    Returns:
        List of scalar rewards (one per example)
    """
    # Reward weights
    FORMAT_WEIGHT = 0.3
    REASONING_WEIGHT = 0.3
    CORRECTNESS_WEIGHT = 0.4
    
    rewards = []
    
    for i, (prompt, completion, meta) in enumerate(zip(prompts, completions, metadata)):
        # Compute each reward component
        format_reward = compute_format_reward(completion)
        reasoning_reward = compute_reasoning_length_reward(completion, tokenizer, min_tokens=100)
        
        # Get ground truth from metadata
        ground_truth = meta.get("ground_truth", "")
        correctness_reward = compute_answer_correctness_reward(completion, ground_truth, tokenizer)
        
        # Aggregate rewards
        total_reward = (
            FORMAT_WEIGHT * format_reward +
            REASONING_WEIGHT * reasoning_reward +
            CORRECTNESS_WEIGHT * correctness_reward
        )
        
        rewards.append(total_reward)
        
        # Log breakdown for first few examples
        if i < 3:
            print(f"\n📊 Example {i} reward breakdown:")
            print(f"   Format: {format_reward:.2f} (weight: {FORMAT_WEIGHT})")
            print(f"   Reasoning: {reasoning_reward:.2f} (weight: {REASONING_WEIGHT})")
            print(f"   Correctness: {correctness_reward:.2f} (weight: {CORRECTNESS_WEIGHT})")
            print(f"   Total: {total_reward:.2f}")
    
    return rewards


def tunix_reward_wrapper(prompts: List[str], outputs: List[str]) -> List[float]:
    """
    Wrapper function matching Tunix RewardFn signature: (prompts, outputs) -> rewards.
    
    This wrapper adapts our composite_reward_function to Tunix's expected signature.
    It extracts metadata from the global training_dataset for ground truth comparison.
    
    Args:
        prompts: List of prompt strings
        outputs: List of generated output strings
        
    Returns:
        List of float reward values
    """
    # Build metadata from training dataset (prompts contain the original questions)
    metadata = []
    for prompt in prompts:
        # Find matching ground truth from training_dataset
        found = False
        for example in training_dataset:
            if example["prompt"] in prompt or prompt in example["prompt"]:
                metadata.append({"ground_truth": example["ground_truth"]})
                found = True
                break
        if not found:
            metadata.append({"ground_truth": ""})
    
    return composite_reward_function(prompts, outputs, metadata, tokenizer)


# Test reward function
print("🧪 Testing reward function...")
test_prompts = ["Test question"]
test_completions = [
    "<reasoning>This is a detailed legal analysis with sufficient tokens to explain the reasoning behind the answer. We consider precedent, statutory law, and policy implications.</reasoning><answer>Yes, it is enforceable.</answer>"
]
test_metadata = [{"ground_truth": "Yes, it is enforceable."}]

test_rewards = composite_reward_function(test_prompts, test_completions, test_metadata, tokenizer)
print(f"\n✅ Reward function test complete")
print(f"   Test reward: {test_rewards[0]:.2f}")
print("\n✅ Composite reward function ready for Tunix GRPO!")
print("\n💡 Use tunix_reward_wrapper() as the reward function for GRPOLearner")


In [None]:
# Verify Tunix installation before training setup
print("📦 Verifying Tunix installation...")

import sys

# Check Tunix availability
try:
    import tunix
    print(f"✅ Tunix installed: {tunix.__version__ if hasattr(tunix, '__version__') else 'version unknown'}")
except ImportError as e:
    print(f"❌ Tunix not available: {e}")
    print("\n🔧 To install Tunix:")
    print("   !pip install 'google-tunix[tpu]>=0.5.0'")
    print("   Then restart runtime and run this cell again.")
    raise

# Check required submodules
modules_to_check = [
    ("tunix.rl.grpo.grpo_learner", "GRPOConfig, GRPOLearner"),
    ("tunix.rl.rl_cluster", "RLCluster"),
    ("tunix.models.gemma", "GemmaForCausalLM"),
    ("tunix.peft.lora", "LoRAConfig"),
]

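print("\n📋 Checking Tunix submodules:")
import importlib

all_available = True
for module_path, expected_exports in modules_to_check:
    try:
        module = importlib.import_module(module_path)
        # Also confirm the names the training cells import are actually exported
        missing = [n for n in expected_exports.split(", ") if not hasattr(module, n)]
        if missing:
            print(f"   ⚠️ {module_path}: missing {', '.join(missing)}")
            all_available = False
        else:
            print(f"   ✅ {module_path}")
    except ImportError as e:
        print(f"   ❌ {module_path}: {e}")
        all_available = False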

if all_available:
    print("\n✅ All Tunix modules available!")
else:
    print("\n⚠️ Some modules not available. Check Tunix version and installation.")
    print("   The training cells may need adaptation for your Tunix version.")

# Check JAX backend
print("\n📊 JAX Backend Status:")
import jax
print(f"   JAX version: {jax.__version__}")
print(f"   Backend: {jax.default_backend()}")
print(f"   Devices: {jax.device_count()} ({jax.devices()[0].platform if jax.devices() else 'none'})")

print("\n✅ Environment verified - ready for training setup!")


## 🚀 Task 4: Configure and Execute GRPO Training

Set up LoRA adapters and run GRPO training on TPU.
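
To give a sense of why LoRA is parameter-efficient, the sketch below counts the parameters a single adapter pair adds to one projection matrix. The rank and alpha match this notebook's LoRA configuration; the 1024 x 1024 projection shape is a hypothetical placeholder, not Gemma's actual dimensions.

```python
def lora_param_count(d_in, d_out, rank):
    """Parameters added by one LoRA pair: A (d_in x rank) plus B (rank x d_out)."""
    return rank * (d_in + d_out)


RANK, ALPHA = 16, 32        # values from this notebook's LoRA configuration
D_IN, D_OUT = 1024, 1024    # hypothetical projection shape, not Gemma's real dimensions

full = D_IN * D_OUT
added = lora_param_count(D_IN, D_OUT, RANK)
print(f"full matrix: {full:,} params")
print(f"LoRA update: {added:,} params ({added / full:.1%} of full)")
print(f"scaling factor alpha/rank = {ALPHA / RANK}")
```

The low-rank update is scaled by alpha/rank, so with alpha set to twice the rank the adapter output is multiplied by 2 before being added to the frozen layer's output.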

In [None]:
# LoRA Hyperparameters for parameter-efficient fine-tuning
LORA_CONFIG = {
    "rank": 16,           # LoRA rank (16 or 32 recommended)
    "alpha": 32,          # LoRA alpha (typically 2x rank)
    "dropout": 0.05,      # LoRA dropout for regularization
    "target_modules": ["q_proj", "k_proj", "v_proj", "o_proj"],  # Attention layers
}

# GRPO Configuration matching Tunix GRPOConfig parameters
# Reference: https://tunix.readthedocs.io/en/latest/api/grpo.html
GRPO_CONFIG = {
    # Rollout settings
    "num_generations": 4,           # Number of response samples per prompt for GRPO
    "max_tokens_to_generate": 512,  # Maximum tokens for rollout generation
    
    # GRPO algorithm hyperparameters
    "beta": 0.04,                   # KL penalty coefficient (prevents policy divergence)
    "epsilon": 0.2,                 # PPO-style clipping parameter
    
    # Training settings
    "learning_rate": 1e-5,          # Learning rate for LoRA parameters
    "batch_size": 4,                # Batch size per TPU core (adjust for memory)
    "num_iterations": 2,            # Number of training epochs/iterations
    
    # Evaluation and checkpointing
    "eval_every_n_steps": 50,       # Evaluate model every N steps
    "checkpoint_every_n_steps": 100, # Save checkpoint every N steps
}

# Training configuration for RLCluster
TRAINING_CONFIG = {
    "warmup_steps": 10,             # Learning rate warmup steps
    "weight_decay": 0.01,           # Weight decay for regularization
    "max_grad_norm": 1.0,           # Gradient clipping threshold
    "log_every_n_steps": 10,        # Log metrics every N steps
}

print("✅ Configuration defined:")
print("\n🔧 LoRA Configuration:")
for k, v in LORA_CONFIG.items():
    print(f"   {k}: {v}")
print("\n🎯 GRPO Configuration:")
for k, v in GRPO_CONFIG.items():
    print(f"   {k}: {v}")
print("\n📊 Training Configuration:")
for k, v in TRAINING_CONFIG.items():
    print(f"   {k}: {v}")

print("\n💡 Hyperparameter Rationale:")
print("   - LoRA rank=16: Balance between capacity and memory efficiency")
print("   - num_generations=4: Completions sampled per prompt; rewards are normalized within each group")
print("   - beta=0.04: Conservative KL penalty to prevent policy divergence")
print("   - learning_rate=1e-5: Safe starting point for LoRA fine-tuning")
print("   - max_tokens_to_generate=512: Sufficient for detailed legal reasoning")


### 🔧 Initialize Training Components

This section sets up the Tunix GRPO training infrastructure:

1. **Import Tunix modules**: GRPOConfig, GRPOLearner, RLCluster
2. **Load and configure models**: Actor (trainable) and Reference (frozen) policies
3. **Setup TPU mesh**: Configure sharding for distributed training
4. **Initialize learner**: Create GRPOLearner with reward function

**Prerequisites**:
- TPU runtime initialized (verified in Step 2)
- Model downloaded (completed in Step 4)
- Reward function defined (completed above)
- Training dataset prepared (completed above)

**Documentation**:
- [Tunix GRPO Guide](https://tunix.readthedocs.io/en/latest/tutorials/grpo.html)
- [Official GRPO Gemma Example](https://github.com/google/tunix/tree/main/examples/grpo_gemma)
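
For intuition about how `epsilon` and `beta` from the configuration act during an update, here is a per-token sketch of the GRPO objective as formulated in the GRPO paper listed in the references. It is a scalar illustration under that assumption; the function name and log-probability values are hypothetical, and the source does not confirm that Tunix computes the loss in exactly this form.

```python
import math


def grpo_token_objective(logp_new, logp_old, logp_ref, advantage,
                         epsilon=0.2, beta=0.04):
    """Per-token GRPO objective (to be maximized) for one sampled token."""
    # Probability ratio between the current policy and the rollout policy
    ratio = math.exp(logp_new - logp_old)
    clipped = max(1.0 - epsilon, min(1.0 + epsilon, ratio))
    # PPO-style clipped surrogate: take the more pessimistic of the two terms
    surrogate = min(ratio * advantage, clipped * advantage)
    # KL estimator from the GRPO paper: r - log(r) - 1, with r = pi_ref / pi_new
    log_r = logp_ref - logp_new
    kl = math.exp(log_r) - log_r - 1.0
    return surrogate - beta * kl


# Identical policies: ratio is 1 and the KL term vanishes
same = grpo_token_objective(-1.0, -1.0, -1.0, advantage=1.0)
# Policy moved far toward a good token: the gain is capped at (1 + epsilon) * advantage
moved = grpo_token_objective(-0.5, -1.0, -1.0, advantage=1.0)
print(f"same policy: {same:.4f}")
print(f"moved policy: {moved:.4f}")
```

Clipping limits how much a single update can profit from moving away from the rollout policy, while the `beta`-weighted KL term penalizes drift from the frozen reference model.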


In [None]:
# Import Tunix GRPO modules
print("📦 Importing Tunix modules...")

try:
    from tunix.rl.grpo.grpo_learner import GRPOConfig, GRPOLearner
    from tunix.rl import rl_cluster as rl_cluster_lib
    from tunix.rl.rollout import base_rollout
    from tunix.models import gemma as gemma_lib
    print("✅ Tunix modules imported successfully!")
except ImportError as e:
    print(f"❌ Tunix import failed: {e}")
    print("\n🔧 Troubleshooting:")
    print("   1. Verify Tunix is installed: pip install google-tunix[tpu]")
    print("   2. Restart runtime after installation")
    print("   3. Check Tunix version compatibility")
    raise

import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec
import os

# Create checkpoint directories
CHECKPOINT_DIR = "./checkpoints"
FINAL_DIR = "./final_checkpoint"
os.makedirs(CHECKPOINT_DIR, exist_ok=True)
os.makedirs(FINAL_DIR, exist_ok=True)

print(f"✅ Checkpoint directories created:")
print(f"   Intermediate: {CHECKPOINT_DIR}")
print(f"   Final: {FINAL_DIR}")

# Setup TPU mesh for distributed training
print("\n🔧 Setting up TPU mesh...")
devices = jax.devices()
num_devices = len(devices)

# Create 1D mesh for data parallelism across TPU cores
mesh = Mesh(devices, axis_names=("data",))
print(f"✅ TPU mesh created with {num_devices} devices")
print(f"   Mesh shape: {mesh.shape}")
print(f"   Axis names: {mesh.axis_names}")

# Load Gemma model for GRPO training
print("\n📥 Loading Gemma model for GRPO...")

# Create model configuration
model_config = gemma_lib.GemmaConfig.from_pretrained(model_path)
print(f"   Model config loaded: {type(model_config).__name__}")

# Initialize actor model (trainable policy with LoRA)
print("\n🎭 Initializing actor model (trainable)...")
actor_model = gemma_lib.GemmaForCausalLM.from_pretrained(
    model_path,
    dtype=jnp.bfloat16,  # Use bfloat16 for TPU efficiency
)

# Apply LoRA configuration to actor model
from tunix.peft import lora as lora_lib

lora_config = lora_lib.LoRAConfig(
    rank=LORA_CONFIG["rank"],
    alpha=LORA_CONFIG["alpha"],
    dropout=LORA_CONFIG["dropout"],
    target_modules=LORA_CONFIG["target_modules"],
)
actor_model = lora_lib.apply_lora(actor_model, lora_config)
print(f"   LoRA applied: rank={LORA_CONFIG['rank']}, alpha={LORA_CONFIG['alpha']}")

# Initialize reference model (frozen copy for KL penalty)
print("\n📋 Initializing reference model (frozen)...")
reference_model = gemma_lib.GemmaForCausalLM.from_pretrained(
    model_path,
    dtype=jnp.bfloat16,
)
# Reference model parameters are frozen (no gradients)
print("   Reference model loaded (frozen for KL divergence)")

print("\n✅ Models initialized successfully!")
print(f"   Actor model: LoRA-adapted, trainable")
print(f"   Reference model: Frozen for KL penalty calculation")

# Create RLCluster configuration
print("\n🔧 Creating RLCluster...")

# Define the sharding spec for data parallelism (batch axis split across the "data" mesh axis)
data_sharding = NamedSharding(mesh, PartitionSpec("data"))

# Create RLCluster with actor and reference models
rl_cluster = rl_cluster_lib.RLCluster(
    actor_model=actor_model,
    reference_model=reference_model,
    tokenizer=tokenizer,
    mesh=mesh,
    data_sharding=data_sharding,
)
print("✅ RLCluster created successfully!")

# Create GRPO configuration
print("\n🎯 Creating GRPOConfig...")
grpo_config = GRPOConfig(
    num_generations=GRPO_CONFIG["num_generations"],
    max_tokens_to_generate=GRPO_CONFIG["max_tokens_to_generate"],
    beta=GRPO_CONFIG["beta"],
    epsilon=GRPO_CONFIG["epsilon"],
    learning_rate=GRPO_CONFIG["learning_rate"],
    warmup_steps=TRAINING_CONFIG["warmup_steps"],
    weight_decay=TRAINING_CONFIG["weight_decay"],
    max_grad_norm=TRAINING_CONFIG["max_grad_norm"],
)
print(f"✅ GRPOConfig created:")
print(f"   num_generations: {grpo_config.num_generations}")
print(f"   max_tokens_to_generate: {grpo_config.max_tokens_to_generate}")
print(f"   beta (KL penalty): {grpo_config.beta}")
print(f"   learning_rate: {grpo_config.learning_rate}")

# Initialize GRPO Learner
print("\n🎓 Initializing GRPOLearner...")
grpo_learner = GRPOLearner(
    rl_cluster=rl_cluster,
    algo_config=grpo_config,
    reward_fns=[tunix_reward_wrapper],  # Use our wrapped reward function
)
print("✅ GRPOLearner initialized!")
print("   Reward function: tunix_reward_wrapper (composite XML/length/correctness)")

print("\n" + "="*60)
print("✅ TRAINING SETUP COMPLETE")
print("="*60)
print("\nReady to execute GRPO training loop in the next cell.")


In [None]:
# Execute GRPO Training
print("🎯 Starting GRPO Training...")
print("="*60)

import time
from datetime import datetime

# Prepare training dataset in Tunix format
print("\n📊 Preparing training data...")
train_prompts = [ex["prompt"] for ex in training_dataset]
print(f"   Training examples: {len(train_prompts)}")

# Training configuration
num_iterations = GRPO_CONFIG["num_iterations"]
batch_size = GRPO_CONFIG["batch_size"]
eval_every = GRPO_CONFIG["eval_every_n_steps"]
checkpoint_every = GRPO_CONFIG["checkpoint_every_n_steps"]
log_every = TRAINING_CONFIG["log_every_n_steps"]

print(f"\n📋 Training Configuration:")
print(f"   Iterations: {num_iterations}")
print(f"   Batch size: {batch_size}")
print(f"   Eval every: {eval_every} steps")
print(f"   Checkpoint every: {checkpoint_every} steps")

# Training metrics storage
training_metrics = {
    "losses": [],
    "rewards": [],
    "kl_divergences": [],
    "steps": [],
}

# Execute training
start_time = time.time()
global_step = 0

try:
    with mesh:
        for iteration in range(num_iterations):
            print(f"\n{'='*60}")
            print(f"📈 Iteration {iteration + 1}/{num_iterations}")
            print(f"{'='*60}")
            
            iteration_start = time.time()
            
            # Create batches for this iteration
            num_batches = (len(train_prompts) + batch_size - 1) // batch_size
            
            for batch_idx in range(num_batches):
                # Get batch prompts
                start_idx = batch_idx * batch_size
                end_idx = min(start_idx + batch_size, len(train_prompts))
                batch_prompts = train_prompts[start_idx:end_idx]
                
                # Execute GRPO training step
                step_metrics = grpo_learner.train_step(
                    prompts=batch_prompts,
                )
                
                global_step += 1
                
                # Store metrics
                training_metrics["losses"].append(step_metrics.get("loss", 0.0))
                training_metrics["rewards"].append(step_metrics.get("mean_reward", 0.0))
                training_metrics["kl_divergences"].append(step_metrics.get("kl_divergence", 0.0))
                training_metrics["steps"].append(global_step)
                
                # Log progress
                if global_step % log_every == 0:
                    print(f"\n   Step {global_step}:")
                    print(f"      Loss: {step_metrics.get('loss', 0.0):.4f}")
                    print(f"      Mean Reward: {step_metrics.get('mean_reward', 0.0):.4f}")
                    print(f"      KL Divergence: {step_metrics.get('kl_divergence', 0.0):.4f}")
                
                # Evaluation
                if global_step % eval_every == 0:
                    print(f"\n   📊 Evaluation at step {global_step}:")
                    # Generate sample output
                    sample_prompt = train_prompts[0]
                    sample_output = grpo_learner.generate(
                        prompts=[sample_prompt],
                        max_tokens=GRPO_CONFIG["max_tokens_to_generate"],
                    )[0]
                    
                    # Validate output format
                    has_format = validate_xml_format(sample_output)
                    reasoning, answer = extract_xml_content(sample_output)
                    
                    print(f"      Valid XML format: {has_format}")
                    if reasoning:
                        reasoning_tokens = len(tokenizer.encode(reasoning))
                        print(f"      Reasoning tokens: {reasoning_tokens}")
                    print(f"      Sample output preview: {sample_output[:200]}...")
                
                # Checkpoint
                if global_step % checkpoint_every == 0:
                    checkpoint_path = f"{CHECKPOINT_DIR}/step_{global_step}"
                    grpo_learner.save_checkpoint(checkpoint_path)
                    print(f"\n   💾 Checkpoint saved: {checkpoint_path}")
            
            iteration_time = time.time() - iteration_start
            print(f"\n   ⏱️ Iteration {iteration + 1} completed in {iteration_time:.1f}s")
            
            # Iteration summary
            recent_losses = training_metrics["losses"][-num_batches:]
            recent_rewards = training_metrics["rewards"][-num_batches:]
            print(f"   📊 Iteration Summary:")
            print(f"      Avg Loss: {sum(recent_losses)/len(recent_losses):.4f}")
            print(f"      Avg Reward: {sum(recent_rewards)/len(recent_rewards):.4f}")

    # Training complete
    total_time = time.time() - start_time
    print(f"\n{'='*60}")
    print("✅ TRAINING COMPLETE!")
    print(f"{'='*60}")
    print(f"   Total steps: {global_step}")
    print(f"   Total time: {total_time:.1f}s ({total_time/60:.1f} minutes)")
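    n_final = max(1, len(training_metrics["losses"][-10:]))  # fewer than 10 steps may have run
    print(f"   Final avg loss: {sum(training_metrics['losses'][-10:])/n_final:.4f}")
    print(f"   Final avg reward: {sum(training_metrics['rewards'][-10:])/n_final:.4f}")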

    # Save final checkpoint
    print(f"\n💾 Saving final checkpoint to {FINAL_DIR}...")
    grpo_learner.save_checkpoint(FINAL_DIR)
    print("✅ Final checkpoint saved!")
    
except KeyboardInterrupt:
    print("\n⚠️ Training interrupted by user!")
    print(f"   Completed steps: {global_step}")
    # Save emergency checkpoint
    emergency_path = f"{CHECKPOINT_DIR}/interrupted_step_{global_step}"
    grpo_learner.save_checkpoint(emergency_path)
    print(f"   Emergency checkpoint saved: {emergency_path}")
    
except Exception as e:
    print(f"\n❌ Training error: {e}")
    print(f"   Last completed step: {global_step}")
    # Try to save checkpoint on error
    try:
        error_path = f"{CHECKPOINT_DIR}/error_step_{global_step}"
        grpo_learner.save_checkpoint(error_path)
        print(f"   Error checkpoint saved: {error_path}")
    except Exception:  # a bare except would also swallow KeyboardInterrupt/SystemExit
        print("   Could not save error checkpoint")
    raise

# Display training metrics summary
print("\n📊 Training Metrics Summary:")
print(f"   Steps: {len(training_metrics['steps'])}")
if training_metrics["steps"]:  # min()/max() raise ValueError on empty lists
    print(f"   Loss range: {min(training_metrics['losses']):.4f} - {max(training_metrics['losses']):.4f}")
    print(f"   Reward range: {min(training_metrics['rewards']):.4f} - {max(training_metrics['rewards']):.4f}")


## 📦 Task 5: Export LoRA Adapters and Create Kaggle Submission

Package trained adapters for Kaggle submission.

In [None]:
import os
import shutil

# Create kaggle_upload directory
KAGGLE_DIR = "./kaggle_upload"
os.makedirs(KAGGLE_DIR, exist_ok=True)

print(f"✅ Created Kaggle submission directory: {KAGGLE_DIR}")
print("\n📋 Export checklist:")
print("   [ ] adapter_config.json - LoRA configuration")
print("   [ ] adapter_model.safetensors - LoRA weights")
print("   [ ] tokenizer files (if modified)")
print("   [ ] README with inference instructions")

In [None]:
# Export LoRA Adapters using Tunix API
print("📦 Exporting LoRA adapters...")

import json
import shutil
from safetensors.flax import save_file as save_safetensors

# Export LoRA weights from trained model
print("\n📤 Extracting LoRA weights from actor model...")

try:
    # Method 1: Use Tunix's built-in export (preferred)
    grpo_learner.export_lora_adapters(
        output_dir=KAGGLE_DIR,
        format="safetensors"
    )
    print("✅ LoRA adapters exported using Tunix API")
    
except AttributeError:
    # Method 2: Manual extraction if export method not available
    print("   Using manual extraction method...")
    from tunix.peft import lora as lora_lib
    
    # Extract LoRA weights
    lora_weights = lora_lib.extract_lora_weights(actor_model)
    
    # Save in safetensors format
    adapter_path = f"{KAGGLE_DIR}/adapter_model.safetensors"
    save_safetensors(lora_weights, adapter_path)
    print(f"✅ LoRA weights saved: {adapter_path}")

# Create adapter_config.json
adapter_config = {
    "peft_type": "LORA",
    "task_type": "CAUSAL_LM",
    "r": LORA_CONFIG["rank"],
    "lora_alpha": LORA_CONFIG["alpha"],
    "lora_dropout": LORA_CONFIG["dropout"],
    "target_modules": LORA_CONFIG["target_modules"],
    "inference_mode": True,
    "base_model_name_or_path": MODEL_ID,
    "bias": "none",
    "fan_in_fan_out": False,
}

config_path = f"{KAGGLE_DIR}/adapter_config.json"
with open(config_path, 'w') as f:
    json.dump(adapter_config, f, indent=2)
print(f"✅ Config saved: {config_path}")

# Copy tokenizer files
print("\n📁 Copying tokenizer files...")
tokenizer_files = ["tokenizer.json", "tokenizer_config.json", "special_tokens_map.json"]
for fname in tokenizer_files:
    src = f"{model_path}/{fname}"
    dst = f"{KAGGLE_DIR}/{fname}"
    if os.path.exists(src):
        shutil.copy2(src, dst)
        print(f"   Copied: {fname}")

# Create README
readme_content = f"""# Judicaita GRPO-Trained LoRA Adapters

## Model Information

- **Base Model**: {MODEL_ID}
- **Training Method**: GRPO (Group Relative Policy Optimization)
- **Framework**: Google Tunix + JAX/Flax
- **LoRA Rank**: {LORA_CONFIG["rank"]}
- **LoRA Alpha**: {LORA_CONFIG["alpha"]}
- **Training Platform**: Google Colab TPU

## Inference Usage

### With Transformers + PEFT

```python
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

# Load base model
base_model = AutoModelForCausalLM.from_pretrained("{MODEL_ID}")
tokenizer = AutoTokenizer.from_pretrained("{MODEL_ID}")

# Load LoRA adapters
model = PeftModel.from_pretrained(
    base_model,
    "."  # Directory containing adapter_config.json (this directory)
)

# Generate
prompt = "Your legal question here"
inputs = tokenizer(prompt, return_tensors="pt")
outputs = model.generate(**inputs, max_length=512)
response = tokenizer.decode(outputs[0])
```

### Expected Output Format

The model generates responses in XML format:
```xml
<reasoning>
Detailed legal reasoning with analysis...
</reasoning>
<answer>
Final answer or conclusion
</answer>
```

## Training Details

- **Reward Function**: Composite (30% format + 30% length + 40% correctness)
- **GRPO Beta (KL penalty)**: {GRPO_CONFIG["beta"]}
- **Num Generations**: {GRPO_CONFIG["num_generations"]}
- **Learning Rate**: {GRPO_CONFIG["learning_rate"]}

## Validation Criteria

- Reasoning should be >= 100 tokens
- Both XML tags must be present
- Answer should be relevant to the question

## License

Same as base model ({MODEL_ID})
"""

with open(f"{KAGGLE_DIR}/README.md", 'w') as f:
    f.write(readme_content)
print("✅ README.md created")

# List exported files
print("\n📋 Exported files:")
for item in os.listdir(KAGGLE_DIR):
    item_path = os.path.join(KAGGLE_DIR, item)
    if os.path.isdir(item_path):
        print(f"   {item}/")
    else:  # report files even when they are empty
        print(f"   {item}: {os.path.getsize(item_path)/1024:.1f} KB")

print("\n✅ Export complete!")


### Validate Exported Model

Test the exported adapters with inference.

In [None]:
# Validate Exported Model with Inference
print("🧪 Running Inference Validation...")
print("="*60)

# Test prompts for validation
test_prompts = [
    "Is a verbal contract enforceable in most jurisdictions?",
    "What are the elements required to prove negligence?",
    "Can a contract be voided if one party was under duress?",
]

print("\n📝 Test Prompts:")
for i, prompt in enumerate(test_prompts, 1):
    print(f"   {i}. {prompt}")

# Generate responses using trained model
print("\n🔄 Generating responses with trained model...")

validation_results = []

for i, prompt in enumerate(test_prompts):
    # Create full prompt with system instructions
    full_prompt = create_prompt_template(prompt)
    
    # Generate response
    try:
        response = grpo_learner.generate(
            prompts=[full_prompt],
            max_tokens=GRPO_CONFIG["max_tokens_to_generate"],
            temperature=0.7,
        )[0]
    except Exception as e:
        print(f"\n❌ Generation error for prompt {i+1}: {e}")
        continue
    
    # Validate format
    has_valid_format = validate_xml_format(response)
    reasoning, answer = extract_xml_content(response)
    
    # Count reasoning tokens
    reasoning_tokens = 0
    if reasoning:
        reasoning_tokens = len(tokenizer.encode(reasoning))
    
    # Compute reward
    reward = composite_reward_function(
        [full_prompt],
        [response],
        [{"ground_truth": ""}],  # No ground truth for test prompts
        tokenizer
    )[0]
    
    result = {
        "prompt": prompt,
        "response": response,
        "valid_format": has_valid_format,
        "reasoning_tokens": reasoning_tokens,
        "has_reasoning": reasoning is not None,
        "has_answer": answer is not None,
        "reward": reward,
    }
    validation_results.append(result)
    
    # Display results
    print(f"\n{'='*60}")
    print(f"📋 Test {i+1}: {prompt[:50]}...")
    print(f"{'='*60}")
    print(f"   ✓ Valid XML format: {has_valid_format}")
    print(f"   ✓ Reasoning tokens: {reasoning_tokens}")
    print(f"   ✓ Has reasoning: {reasoning is not None}")
    print(f"   ✓ Has answer: {answer is not None}")
    print(f"   ✓ Reward score: {reward:.3f}")
    
    if reasoning:
        print(f"\n   📝 Reasoning preview:")
        print(f"      {reasoning[:200]}...")
    if answer:
        print(f"\n   💡 Answer:")
        print(f"      {answer[:200]}")

# Summary
print("\n" + "="*60)
print("📊 VALIDATION SUMMARY")
print("="*60)

valid_count = sum(1 for r in validation_results if r["valid_format"])
avg_reasoning_tokens = sum(r["reasoning_tokens"] for r in validation_results) / len(validation_results) if validation_results else 0
avg_reward = sum(r["reward"] for r in validation_results) / len(validation_results) if validation_results else 0

print(f"   Total test prompts: {len(test_prompts)}")
print(f"   Valid XML format: {valid_count}/{len(validation_results)} ({100*valid_count/len(validation_results):.0f}%)" if validation_results else "   No results")
print(f"   Avg reasoning tokens: {avg_reasoning_tokens:.0f}")
print(f"   Avg reward score: {avg_reward:.3f}")

# Quality assessment
print("\n📈 Quality Assessment:")
if avg_reward >= 0.7:
    print("   ✅ EXCELLENT: Model produces high-quality legal reasoning")
elif avg_reward >= 0.5:
    print("   ✅ GOOD: Model produces adequate legal reasoning")
elif avg_reward >= 0.3:
    print("   ⚠️ FAIR: Model needs more training for better quality")
else:
    print("   ❌ POOR: Model requires significant improvement")

if valid_count == len(validation_results) and validation_results:
    print("   ✅ All outputs have valid XML format")
elif valid_count > 0:
    print(f"   ⚠️ Some outputs missing proper XML tags ({len(validation_results) - valid_count} invalid)")
else:
    print("   ❌ No outputs have valid XML format - check training")

print("\n✅ Validation complete!")


In [None]:
import zipfile
import os

# Create zip archive
def create_submission_zip(source_dir: str, output_file: str):
    """
    Create a zip archive for Kaggle submission.
    
    Args:
        source_dir: Directory containing files to zip
        output_file: Output zip file path
    """
    with zipfile.ZipFile(output_file, 'w', zipfile.ZIP_DEFLATED) as zipf:
        for root, dirs, files in os.walk(source_dir):
            for file in files:
                file_path = os.path.join(root, file)
                arcname = os.path.relpath(file_path, source_dir)
                zipf.write(file_path, arcname)
                print(f"   Added: {arcname}")
    
    # Get zip file size
    size_mb = os.path.getsize(output_file) / (1024 * 1024)
    return size_mb

# Create submission
submission_zip = "./judicaita_submission.zip"
print("📦 Creating Kaggle submission package...")
print(f"   Source: {KAGGLE_DIR}")
print(f"   Output: {submission_zip}")
print("\n📄 Files included:")

try:
    size = create_submission_zip(KAGGLE_DIR, submission_zip)
    print(f"\n✅ Submission package created!")
    print(f"   File: {submission_zip}")
    print(f"   Size: {size:.2f} MB")
    
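    print("\n📋 Submission Checklist:")
    # Report what is actually in the export directory instead of hard-coded marks
    for required in ["adapter_config.json", "README.md", "adapter_model.safetensors"]:
        found = os.path.exists(os.path.join(KAGGLE_DIR, required))
        print(f"   {'✅' if found else '⚠️ '} {required}" + ("" if found else " (not found; export after training)"))
    print("   ⚠️  Validation results (add after testing)")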
    
    print("\n🎯 Next Steps:")
    print("   1. Complete GRPO training")
    print("   2. Export adapter weights to kaggle_upload/")
    print("   3. Run inference validation")
    print("   4. Re-run this cell to create final zip")
    print("   5. Upload to Kaggle competition")
    
except Exception as e:
    print(f"❌ Error creating zip: {e}")
    print("   Make sure kaggle_upload directory has content")

### 🔧 Troubleshooting Guide

#### Tunix Import Errors
- **ModuleNotFoundError: No module named 'tunix'**
  - Ensure you installed with TPU extras: `pip install "google-tunix[tpu]"`
  - Restart runtime after installation
  - Verify version: `python -c "import tunix; print(tunix.__version__)"`

- **ImportError: cannot import name 'GRPOLearner'**
  - Check Tunix version >= 0.5.0
  - Verify correct import path: `from tunix.rl.grpo.grpo_learner import GRPOLearner`
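
The version checks above can also be run from inside the notebook. The following is a minimal sketch using only the standard library. The distribution name `google-tunix` is taken from the install command in this guide, and `installed_version` is a hypothetical helper, not part of Tunix.

```python
from importlib import metadata


def installed_version(dist_name):
    """Return the installed version of a pip distribution, or None if absent."""
    try:
        return metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        return None


# Distribution names are assumptions based on the pip commands in this notebook.
for name in ("google-tunix", "jax", "jaxlib"):
    print(f"{name}: {installed_version(name) or 'not installed'}")
```

Reading the metadata avoids importing the package, so it still works when the import itself is what fails.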

#### JAX/TPU Initialization Issues
- **RuntimeError: TPU not found**
  - Verify Colab runtime is set to TPU: Runtime → Change runtime type → TPU
  - Try restarting the runtime completely
  - Check TPU quota in Google Cloud Console if using custom project

- **JAX version mismatch errors**
  - Install pinned versions: `pip install jax==0.4.35 jaxlib==0.4.35`
  - Restart runtime after JAX installation
  - Verify: `python -c "import jax; print(jax.__version__, jax.devices())"`
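
To confirm that the runtime really exposes TPU cores, it can help to count devices by platform. The sketch below assumes each device object has a `.platform` attribute, as the objects returned by `jax.devices()` do. `summarize_devices` and `require_tpu` are illustrative helpers, and the stand-in devices exist only so the example runs without a TPU.

```python
from collections import Counter
from types import SimpleNamespace


def summarize_devices(devices):
    """Count devices per platform string (for example "tpu" or "cpu")."""
    return dict(Counter(d.platform for d in devices))


def require_tpu(devices, expected=8):
    """Raise if the number of TPU devices differs from `expected`."""
    counts = summarize_devices(devices)
    found = counts.get("tpu", 0)
    if found != expected:
        raise RuntimeError(f"expected {expected} TPU devices, found {counts}")
    return found


# Stand-in devices for illustration; in the notebook pass jax.devices() instead.
fake_devices = [SimpleNamespace(platform="tpu") for _ in range(8)]
print(summarize_devices(fake_devices))
print(require_tpu(fake_devices))
```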

#### RLCluster Configuration Errors
- **ValueError: Mesh shape mismatch**
  - Ensure mesh is created with correct number of devices
  - Check `len(jax.devices())` matches expected TPU cores
  - For TPU v2-8, expect 8 devices

- **Sharding errors during training**
  - Verify data_sharding is compatible with batch size
  - Reduce batch_size to 1 or 2 for debugging
  - Check model dtype is bfloat16 for TPU
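
With a 1D `data` mesh, a sharded batch dimension generally has to divide evenly across devices, and the training loop's last batch can be shorter than `batch_size`. One possible workaround, sketched below, pads a batch by repeating its last prompt. `pad_to_multiple` is a hypothetical helper. Whether padding (rather than dropping the remainder) is appropriate depends on how Tunix batches rollouts, which this sketch does not model.

```python
def pad_to_multiple(items, multiple):
    """Return a copy of `items` padded with its last element to a multiple of `multiple`."""
    if multiple <= 0:
        raise ValueError("multiple must be positive")
    items = list(items)
    if not items:
        return items
    remainder = len(items) % multiple
    if remainder == 0:
        return items
    return items + [items[-1]] * (multiple - remainder)


# A short final batch of 3 prompts padded for 8 devices.
padded = pad_to_multiple(["prompt a", "prompt b", "prompt c"], 8)
print(len(padded))
```

Repeated prompts are counted more than once in the update, so this is best kept for debugging shape errors.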

#### Memory Errors (OOM)
- **Out of Memory during rollout generation**
  - Reduce `num_generations` from 4 to 2
  - Reduce `max_tokens_to_generate` from 512 to 256
  - Reduce `batch_size` from 4 to 2 or 1

- **Out of Memory during backward pass**
  - Use smaller LoRA rank: try rank=8 instead of 16
  - Enable gradient checkpointing if available
  - Reduce sequence length
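
As a rough guide to how much each of these reductions helps, the number of tokens held during rollout scales with batch size, generations per prompt, and sequence length. The arithmetic below is only a proxy for memory use, not a measurement. The prompt length of 256 tokens is an assumed value, and `rollout_tokens` is an illustrative helper.

```python
def rollout_tokens(batch_size, num_generations, prompt_len, max_new_tokens):
    """Upper bound on tokens processed in one rollout step."""
    return batch_size * num_generations * (prompt_len + max_new_tokens)


# Defaults mentioned in this guide, with an assumed 256-token prompt.
baseline = rollout_tokens(batch_size=4, num_generations=4, prompt_len=256, max_new_tokens=512)
# All three reductions applied at once.
reduced = rollout_tokens(batch_size=2, num_generations=2, prompt_len=256, max_new_tokens=256)

print(baseline)            # 12288
print(reduced)             # 2048
print(baseline / reduced)  # 6.0
```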

#### Reward Function Issues
- **Reward function signature mismatch**
  - Tunix expects `reward_fn(prompts: List[str], outputs: List[str]) -> List[float]`
  - Use `tunix_reward_wrapper` instead of `composite_reward_function` directly
  - Ensure function returns Python list of floats, not numpy/jax arrays

- **All rewards are 0.0**
  - Check if model is generating XML tags properly
  - Verify `extract_xml_content()` is working correctly
  - Test reward function manually with sample outputs
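
A manual test can be as small as scoring one well-formed and one malformed output. The sketch below uses a toy format-only reward with the `(prompts, outputs) -> list of floats` shape described above. `toy_format_reward` and its regex are stand-ins for illustration, not the notebook's `tunix_reward_wrapper` or `validate_xml_format`.

```python
import re

# Matches a <reasoning> block followed by an <answer> block.
_XML_PATTERN = re.compile(
    r"<reasoning>(.*?)</reasoning>\s*<answer>(.*?)</answer>", re.DOTALL
)


def toy_format_reward(prompts, outputs):
    """Return 1.0 for outputs containing both XML blocks, else 0.0."""
    return [1.0 if _XML_PATTERN.search(out) else 0.0 for out in outputs]


def as_float_list(values):
    """Coerce array-like scores to a plain Python list of floats."""
    return [float(v) for v in values]


good = "<reasoning>Duress vitiates consent.</reasoning>\n<answer>Yes</answer>"
bad = "Yes, the contract is voidable."
scores = as_float_list(toy_format_reward(["q1", "q2"], [good, bad]))
print(scores)  # [1.0, 0.0]
```

If a check like this returns only zeros on outputs that look correct, the extraction logic is the likely culprit rather than the model.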

#### Checkpoint Issues
- **Checkpoint save fails**
  - Ensure checkpoint directory exists and is writable
  - Check disk space (Colab has ~100GB limit)
  - For large models, consider saving to Google Drive

- **Checkpoint load fails**
  - Verify checkpoint path is correct
  - Check if checkpoint was saved completely (no interruption)
  - When loading converted weights in PyTorch, try `load_state_dict(..., strict=False)` to ignore missing keys
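
Free disk space and checkpoint size can be checked before and after a save. This is a minimal standard-library sketch. `free_gb` and `dir_size_mb` are hypothetical helpers, and the temporary directory stands in for a real checkpoint directory.

```python
import os
import shutil
import tempfile


def free_gb(path="."):
    """Free space, in GiB, on the filesystem containing `path`."""
    return shutil.disk_usage(path).free / 1024**3


def dir_size_mb(path):
    """Total size, in MiB, of all files under `path`."""
    total = 0
    for root, _dirs, files in os.walk(path):
        for name in files:
            total += os.path.getsize(os.path.join(root, name))
    return total / 1024**2


# Demo on a throwaway directory holding one 2 MiB file.
with tempfile.TemporaryDirectory() as demo_dir:
    with open(os.path.join(demo_dir, "weights.bin"), "wb") as f:
        f.write(b"\0" * (2 * 1024 * 1024))
    demo_size = dir_size_mb(demo_dir)

print(f"demo checkpoint size: {demo_size:.1f} MB, free disk: {free_gb():.1f} GB")
```

A checkpoint whose size is far below that of earlier checkpoints is a hint that the save was interrupted.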

#### Training Not Converging
- **Loss not decreasing**
  - Try lower learning rate: 5e-6 or 1e-6
  - Increase warmup steps
  - Check if rewards are providing meaningful signal

- **KL divergence too high**
  - Increase beta (KL penalty coefficient)
  - Reduce learning rate
  - Ensure reference model is properly frozen

- **Rewards not improving**
  - Verify ground truth data quality
  - Check reward function components individually
  - Increase training iterations
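
One way to judge whether rewards carry a meaningful signal is to look at the group-relative advantages they produce. GRPO normalizes each reward against the other generations for the same prompt, so a group whose rewards are all equal contributes no gradient. The sketch below follows the normalization described in the GRPO paper using the population standard deviation. Tunix's exact implementation may differ, and `group_advantages` is an illustrative helper.

```python
from statistics import mean, pstdev


def group_advantages(rewards, eps=1e-6):
    """Normalize rewards within one prompt's group: (r - mean) / (std + eps)."""
    mu = mean(rewards)
    sigma = pstdev(rewards)
    return [(r - mu) / (sigma + eps) for r in rewards]


# Varied rewards: generations above the group mean get positive advantages.
print(group_advantages([0.2, 0.9, 0.5, 0.4]))

# Identical rewards: every advantage is zero, so this group teaches nothing.
print(group_advantages([0.3, 0.3, 0.3, 0.3]))
```

If most groups look like the second case, increasing `num_generations` or making the reward components more fine-grained gives the optimizer more to work with.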

#### Export Issues
- **safetensors export fails**
  - Install safetensors: `pip install "safetensors>=0.4.0"` (quote the requirement so the shell does not treat `>` as a redirect)
  - Verify weights are on CPU before saving
  - Check file path permissions

- **Exported adapters don't load in PyTorch**
  - Ensure adapter_config.json has correct format
  - Verify target_modules match PyTorch model layer names
  - Check if conversion from Flax to PyTorch is needed
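
A quick structural check of `adapter_config.json` can rule out the simplest loading failures. The sketch below checks for a subset of the keys this notebook writes. Treating them as required is an assumption, `missing_adapter_keys` is a hypothetical helper, and the demo values are placeholders.

```python
import json
import os
import tempfile

# Keys written by this notebook's export cell; treated here as a checklist.
REQUIRED_KEYS = (
    "peft_type",
    "r",
    "lora_alpha",
    "target_modules",
    "base_model_name_or_path",
)


def missing_adapter_keys(config_path):
    """Return the checklist keys that are absent from the config file."""
    with open(config_path) as f:
        config = json.load(f)
    return [key for key in REQUIRED_KEYS if key not in config]


# Demo with a deliberately incomplete config (placeholder values).
with tempfile.TemporaryDirectory() as demo_dir:
    demo_path = os.path.join(demo_dir, "adapter_config.json")
    with open(demo_path, "w") as f:
        json.dump({"peft_type": "LORA", "r": 16, "lora_alpha": 32}, f)
    demo_missing = missing_adapter_keys(demo_path)

print(demo_missing)  # ['target_modules', 'base_model_name_or_path']
```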

#### Colab-Specific Issues
- **Runtime disconnection during training**
  - Save checkpoints frequently (every 50-100 steps)
  - Keep browser tab active
  - Consider using Colab Pro for longer runtime

- **Storage limit reached**
  - Clear old checkpoints: keep only latest + final
  - Export to Google Drive
  - Use smaller checkpoint format
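
Clearing old checkpoints can be automated. The sketch below keeps the newest `keep` checkpoints and deletes the rest, assuming checkpoints are directories named `step_<N>` as in the training cell. `prune_checkpoints` is a hypothetical helper, and names that do not match the pattern (such as `interrupted_step_<N>`) are left alone.

```python
import os
import re
import shutil
import tempfile

STEP_DIR = re.compile(r"^step_(\d+)$")


def prune_checkpoints(ckpt_dir, keep=1):
    """Delete all but the `keep` highest-numbered step_<N> directories."""
    steps = []
    for name in os.listdir(ckpt_dir):
        match = STEP_DIR.match(name)
        if match and os.path.isdir(os.path.join(ckpt_dir, name)):
            steps.append((int(match.group(1)), name))
    steps.sort()  # numeric order, so step_100 sorts after step_50
    to_remove = steps[:-keep] if keep > 0 else steps
    removed = []
    for _step, name in to_remove:
        shutil.rmtree(os.path.join(ckpt_dir, name))
        removed.append(name)
    return removed


# Demo on a throwaway directory.
with tempfile.TemporaryDirectory() as demo_dir:
    for name in ("step_10", "step_50", "step_100", "interrupted_step_5"):
        os.makedirs(os.path.join(demo_dir, name))
    demo_removed = prune_checkpoints(demo_dir, keep=1)
    demo_left = sorted(os.listdir(demo_dir))

print(demo_removed)  # ['step_10', 'step_50']
print(demo_left)     # ['interrupted_step_5', 'step_100']
```

Deletion is permanent, so copy anything worth keeping to Google Drive before pruning.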


## 🎉 Conclusion

This notebook demonstrates end-to-end GRPO training for legal reasoning using Google Tunix on TPU:

### What We Built

1. ✅ **TPU Setup**: Initialized JAX with TPU v2-8 using `colab_tpu.setup_tpu()`
2. ✅ **Model Loading**: Downloaded Gemma 3-1B-IT and initialized with LoRA adapters
3. ✅ **Dataset Preparation**: Created XML-formatted prompts for legal reasoning
4. ✅ **Reward Function**: Implemented composite scoring (format + length + correctness)
5. ✅ **GRPO Training**: Executed training with `GRPOLearner` and `RLCluster`
6. ✅ **Export**: Packaged LoRA adapters in safetensors format for submission

### Training Results

After training, the model should:
- Generate responses in valid XML format (`<reasoning>...</reasoning><answer>...</answer>`)
- Produce detailed legal reasoning (100+ tokens)
- Provide accurate answers based on legal principles

### Files Produced

| File | Description |
|------|-------------|
| `adapter_config.json` | LoRA configuration for PEFT |
| `adapter_model.safetensors` | Trained LoRA weights |
| `README.md` | Inference instructions |
| `judicaita_submission.zip` | Kaggle submission package |

### Next Steps

1. **Upload to Kaggle**: Submit `judicaita_submission.zip` to the competition
2. **Fine-tune Further**: Increase training iterations for better results
3. **Add More Data**: Include additional legal reasoning examples
4. **Evaluate on LegalBench**: Test on official benchmark tasks

### Resources

- [Tunix Documentation](https://tunix.readthedocs.io/)
- [Tunix GRPO Gemma Example](https://github.com/google/tunix/tree/main/examples/grpo_gemma)
- [Judicaita Repository](https://github.com/clduab11/judicAIta)
- [Gemma Model Cards](https://ai.google.dev/gemma)
- [JAX TPU Guide](https://jax.readthedocs.io/en/latest/notebooks/TPU_Colab.html)

### Troubleshooting & Support

If you encounter issues:
1. Check the Troubleshooting Guide section above
2. Open an issue: https://github.com/clduab11/judicAIta/issues
3. Review Tunix documentation for API changes

### Contributing

Improvements welcome! Submit a PR with:
- Additional reward function components
- Better data preprocessing
- Performance optimizations
- Documentation improvements

---

**Made with ❤️ for the Kaggle hackathon and legal tech community**

*Powered by Google Tunix, JAX, and Gemma*
