# Judicaita: GRPO Training with Google Tunix on TPU

## 🎯 Hackathon Context

This notebook demonstrates **GRPO (Group Relative Policy Optimization)** training for the Judicaita legal AI assistant using:
- **Google Tunix** for RL training infrastructure
- **Gemma 3-1B-IT** as the base model
- **TPU v2-8+** for accelerated training
- **LoRA adapters** for parameter-efficient fine-tuning

This notebook was developed for the Kaggle hackathon to train models that generate explainable legal reasoning as structured, XML-tagged output.

## ⚡ TPU Requirements

**IMPORTANT**: This notebook requires:
- Google Colab with TPU runtime (TPU v2-8 or higher)
- Runtime type: TPU (not CPU or GPU)
- To enable: Runtime → Change runtime type → Hardware accelerator: TPU

## 📋 What This Notebook Does

1. **Environment Setup**: Install Tunix, JAX, and dependencies for TPU
2. **Model Loading**: Download and initialize Gemma 3-1B-IT
3. **Dataset Preparation**: Format training data with XML-tagged reasoning
4. **Reward Function**: Score outputs based on format, reasoning length, and correctness
5. **GRPO Training**: Train with LoRA adapters on TPU
6. **Export**: Package trained adapters for Kaggle submission

## 🔄 Data Flow

```
Dataset → Prompts → Model Rollouts → Reward Scoring → GRPO Updates
                                                           ↓
                                              LoRA Adapter Checkpoints
```
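The **Reward Scoring → GRPO Updates** step is where GRPO differs from PPO. GRPO does not use a learned value function. Instead, each rollout's reward is compared against the other rollouts sampled for the same prompt. The sketch below follows the formulation in the GRPO paper: group-normalized advantages, PPO-style ratio clipping with `epsilon`, and a per-token KL penalty to the reference model weighted by `beta`. Its defaults of 0.2 and 0.04 match the values configured in `GRPO_CONFIG` later in this notebook. It is written in NumPy for readability and is not Tunix's implementation, whose reductions and numerical details may differ.

```python
import numpy as np

def group_relative_advantages(rewards, eps=1e-6):
    """rewards: (num_prompts, num_generations). Returns (r - group mean) / (group std + eps)."""
    rewards = np.asarray(rewards, dtype=np.float64)
    mean = rewards.mean(axis=1, keepdims=True)
    std = rewards.std(axis=1, keepdims=True)
    return (rewards - mean) / (std + eps)

def grpo_loss(logp, old_logp, ref_logp, advantages, mask, epsilon=0.2, beta=0.04):
    """Clipped surrogate minus a KL penalty, averaged over each rollout's tokens, then over rollouts.

    logp, old_logp, ref_logp, mask: (num_rollouts, seq_len) per-token log-probs and 0/1 mask.
    advantages: (num_rollouts,), one value shared by every token of a rollout.
    Returns (loss, mean_kl).
    """
    ratio = np.exp(logp - old_logp)
    adv = np.asarray(advantages)[:, None]
    surrogate = np.minimum(ratio * adv, np.clip(ratio, 1.0 - epsilon, 1.0 + epsilon) * adv)
    # Non-negative per-token KL estimate against the frozen reference policy
    log_ref_ratio = ref_logp - logp
    kl = np.exp(log_ref_ratio) - log_ref_ratio - 1.0
    lengths = np.maximum(mask.sum(axis=1), 1.0)
    per_rollout = ((surrogate - beta * kl) * mask).sum(axis=1) / lengths
    mean_kl = ((kl * mask).sum(axis=1) / lengths).mean()
    return -per_rollout.mean(), mean_kl

# Two prompts x four rollouts each (num_generations=4)
rewards = np.array([[0.9, 0.5, 0.5, 0.1],    # mixed quality: useful learning signal
                    [0.7, 0.7, 0.7, 0.7]])   # identical rewards: advantages are ~0
adv = group_relative_advantages(rewards).reshape(-1)  # one advantage per rollout

# Toy per-token log-probs for 8 rollouts x 5 tokens. On the first update the policy
# equals the sampling policy and the reference, so ratio == 1 and KL == 0.
rng = np.random.default_rng(0)
logp = rng.normal(-1.0, 0.1, size=(8, 5))
mask = np.ones_like(logp)
loss, kl = grpo_loss(logp, logp, logp, adv, mask)
print(adv.round(3), loss, kl)
```

A prompt whose rollouts all receive the same reward contributes no gradient. This is why reward components that spread scores across rollouts, such as the partial credit for reasoning length, matter in practice.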

## ⚠️ Differences from Main Codebase

| Aspect | Main Codebase | This Notebook |
|--------|---------------|---------------|
| Format | Step-by-step format | XML `<reasoning>`/`<answer>` |
| Framework | PyTorch | JAX/Flax |
| Training | Custom GRPO | Tunix GRPOLearner |
| Hardware | GPU/CPU | TPU v2-8+ |

## 📚 References

- [Google Tunix Documentation](https://github.com/google/tunix)
- [Gemma Model Card](https://ai.google.dev/gemma/docs)
- [GRPO Paper (DeepSeekMath)](https://arxiv.org/abs/2402.03300)
- [Judicaita Repository](https://github.com/clduab11/judicAIta)

## ⚠️ Known Limitations

- **TPU Required**: Cannot run on CPU/GPU without code modifications
- **Memory**: A TPU v2-8 has 64 GB of HBM in total, split as 8 GB per core across 8 cores; larger models may need v3 or higher
- **Dataset**: Assumes generic legal reasoning tasks (not LegalBench-specific)
- **Checkpoints**: Large checkpoint files may exceed Colab storage limits

## 📦 Step 1: Install Dependencies

Install required packages for TPU training with Tunix and Gemma.

In [None]:
# Install core dependencies
# Version specifiers are quoted so the shell does not treat ">" as output redirection
!pip install -q google-tunix
!pip install -q "jax[tpu]>=0.4.20" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
!pip install -q "flax>=0.8.0"
!pip install -q "transformers>=4.50.0"  # Gemma 3 support was added in transformers 4.50
!pip install -q "huggingface_hub>=0.20.0"
!pip install -q "datasets>=2.14.0"
!pip install -q "sentencepiece>=0.1.99"

print("✅ Dependencies installed successfully!")
print("⚠️  IMPORTANT: Runtime restart required for TPU libraries.")
print("   Go to: Runtime → Restart runtime")
print("   Then continue with the next cell.")

### ⚠️ Runtime Restart Required

**STOP HERE** and restart the runtime:
1. Click `Runtime` → `Restart runtime` in the menu
2. After restart, continue from the next cell

This is necessary for TPU libraries to be properly loaded.

## 🚀 Step 2: Initialize TPU Runtime

Set up JAX to use TPU devices.
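A TPU v2-8 shows up in JAX as 8 devices. Assume rollouts are spread evenly across devices in a pure data-parallel layout, and that `batch_size` counts prompts (both are assumptions; Tunix's actual sharding may differ). Under those assumptions, the rollouts per step, `batch_size × num_generations`, should divide evenly by the device count. The helper below performs that check. Its name is illustrative.

```python
def rollouts_per_device(batch_size: int, num_generations: int, num_devices: int) -> int:
    """Rollouts each device handles per step under pure data parallelism.

    Raises ValueError if the rollouts cannot be split evenly across devices.
    """
    total_rollouts = batch_size * num_generations
    if total_rollouts % num_devices != 0:
        raise ValueError(
            f"{total_rollouts} rollouts ({batch_size} prompts x {num_generations} generations) "
            f"do not divide evenly across {num_devices} devices"
        )
    return total_rollouts // num_devices

# Values mirror this notebook's defaults on a v2-8 (8 devices)
print(rollouts_per_device(batch_size=4, num_generations=4, num_devices=8))
```

`GRPO_CONFIG` is defined further down. Once it exists, call `rollouts_per_device(GRPO_CONFIG["batch_size"], GRPO_CONFIG["num_generations"], len(jax.devices()))` to check the real configuration.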

In [None]:
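import jax
import jax.numpy as jnp

# Initialize TPU
print("Initializing TPU runtime...")
try:
    # Legacy Colab "TPU Node" runtimes needed this helper. Current Colab TPU VM
    # runtimes expose the TPU to JAX directly, and newer JAX releases may not ship it.
    from jax.tools import colab_tpu
    colab_tpu.setup_tpu()
    print("✅ Legacy Colab TPU setup complete")
except ImportError:
    print("ℹ️  jax.tools.colab_tpu not available; assuming a TPU VM runtime")
except Exception as e:
    print(f"❌ TPU initialization failed: {e}")
    print("\n🔧 Troubleshooting:")
    print("   1. Check runtime type: Runtime → Change runtime type → TPU")
    print("   2. Ensure TPU v2-8 or higher is available")
    print("   3. Try restarting the runtime")
    print("   4. Check Google Cloud TPU quota if using custom project")
    raise

# Verify TPU devices before indexing devices[0]. jax.devices() falls back to CPU
# when no accelerator is found, so check the platform rather than the count.
devices = jax.devices()
if not devices or devices[0].platform != "tpu":
    raise RuntimeError(
        f"No TPU devices detected (found: {devices}). Please check your runtime configuration."
    )

print(f"\n📊 TPU Device Information:")
print(f"   Number of devices: {len(devices)}")
print(f"   Device type: {devices[0].platform}")
print(f"   Devices: {devices}")

print("\n✅ TPU setup complete and verified!")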

## 🔐 Step 3: Authenticate with Hugging Face

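Log in to Hugging Face to download the Gemma model. Gemma checkpoints are gated, so accept the model license on its Hugging Face page, using the same account, before downloading.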

In [None]:
from huggingface_hub import login, snapshot_download
import os

# Login to Hugging Face
# You'll be prompted to enter your HF token
# Get your token from: https://huggingface.co/settings/tokens
print("Please enter your Hugging Face token:")
login()

print("\n✅ Authenticated with Hugging Face!")

## 📥 Step 4: Download Gemma 3-1B-IT Model

Download the model files and initialize the tokenizer.

**Note**: This notebook uses `google/gemma-3-1b-it`, the instruction-tuned 1B Gemma 3 checkpoint. To try another Gemma variant, change `MODEL_ID` in the cell below.

In [None]:
from transformers import AutoTokenizer
import os

# Download model
MODEL_ID = "google/gemma-3-1b-it"  # Instruction-tuned 1B Gemma 3 checkpoint (gated on Hugging Face)
CACHE_DIR = "./gemma_model_cache"

print(f"Downloading {MODEL_ID}...")
# local_dir_use_symlinks is deprecated in recent huggingface_hub releases,
# so it is omitted; files are written under local_dir.
model_path = snapshot_download(
    repo_id=MODEL_ID,
    cache_dir=CACHE_DIR,
    local_dir=f"{CACHE_DIR}/gemma",
)
print(f"✅ Model downloaded to: {model_path}")

# Initialize tokenizer
print("\nInitializing tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(model_path)
print(f"✅ Tokenizer initialized")
print(f"   Vocab size: {len(tokenizer)}")
print(f"   Special tokens: {tokenizer.special_tokens_map}")

# Test tokenization
test_text = "What is the legal precedent for breach of contract?"
tokens = tokenizer(test_text, return_tensors="np")
print(f"\n📝 Test tokenization:")
print(f"   Input: {test_text}")
print(f"   Token count: {len(tokens['input_ids'][0])}")

## 🔧 Step 5: Create Preprocessing Function

Gemma's chat format has no separate system turn, so we prepend the system prompt to the first user turn.
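Merging the system prompt is one half of the job. The prompt template built later in this notebook uses a plain `Question:` / `Response:` layout rather than Gemma's turn markers, and instruction-tuned checkpoints usually respond best in their own chat format. The preferred route is `tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)`. The sketch below hand-renders the same turn layout so that the structure is visible. It reflects the Gemma chat template as understood here, so compare its output with `apply_chat_template` on the downloaded tokenizer before relying on it. The function name is illustrative.

```python
from typing import Dict, List

def format_gemma_chat(messages: List[Dict[str, str]], add_generation_prompt: bool = True) -> str:
    """Render messages in Gemma's turn layout, without the leading <bos> token.

    System text must already be merged into the first user turn.
    """
    parts = []
    for msg in messages:
        # Gemma calls the assistant turn "model"
        role = "model" if msg["role"] == "assistant" else msg["role"]
        if role not in ("user", "model"):
            raise ValueError(f"Unsupported role for Gemma: {msg['role']!r}")
        parts.append(f"<start_of_turn>{role}\n{msg['content'].strip()}<end_of_turn>\n")
    if add_generation_prompt:
        # Open the model turn so generation starts right after it
        parts.append("<start_of_turn>model\n")
    return "".join(parts)

msgs = [{"role": "user", "content": "SYSTEM TEXT\n\nIs a verbal contract enforceable?"}]
print(format_gemma_chat(msgs))
```

This string omits `<bos>`, so tokenizing it with the tokenizer's defaults adds `<bos>` exactly once. The string from `apply_chat_template(..., tokenize=False)` already contains `<bos>`, so tokenize that one with `add_special_tokens=False` to avoid a duplicate.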

In [None]:
def preprocess_with_system_prompt(messages, system_prompt):
    """
    Prepend system prompt to first user message.
    
    Gemma doesn't support system role natively, so we merge it with
    the first user turn as a workaround.
    
    Args:
        messages: List of message dicts with 'role' and 'content'
        system_prompt: System instruction string
        
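    Returns:
        New messages list with the system prompt prepended (the input is not modified)
    """
    if not messages:
        return messages
    
    # Copy each message dict: list.copy() alone would share the dicts and
    # mutate the caller's messages when the content is rewritten below
    processed = [dict(msg) for msg in messages]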
    
    # Find first user message
    for i, msg in enumerate(processed):
        if msg.get('role') == 'user':
            # Prepend system prompt
            original_content = msg['content']
            processed[i]['content'] = f"{system_prompt}\n\n{original_content}"
            break
    
    return processed

# Define system prompt for legal reasoning
SYSTEM_PROMPT = """You are a legal AI assistant. For each question, provide your analysis in this exact format:
<reasoning>Your step-by-step legal reasoning here. Include relevant legal principles, precedents, and analysis. Aim for at least 100 tokens of detailed reasoning.</reasoning>
<answer>Your final answer or conclusion here.</answer>

Always use this XML format and ensure your reasoning is thorough and well-explained."""

# Test preprocessing
test_messages = [
    {"role": "user", "content": "Is a non-compete clause enforceable in California?"}
]
processed = preprocess_with_system_prompt(test_messages, SYSTEM_PROMPT)
print("📝 Test preprocessing:")
print(f"Original: {test_messages[0]['content'][:50]}...")
print(f"\nProcessed length: {len(processed[0]['content'])} chars")
print(f"System prompt prepended: {'<reasoning>' in processed[0]['content']}")
print("\n✅ Preprocessing function ready!")

## 📊 Step 6: Prepare Training Dataset

Create a dataset with XML-tagged reasoning format compatible with Tunix GRPO.

### JSONL Format Requirements

Each training example must be a JSON object with:
- `prompt`: The question or task
- `ground_truth`: The correct answer for evaluation
- `metadata` (optional): Additional info like task_id, difficulty, etc.
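Each line of a JSONL file holds one such object. The following is a small loader and writer pair, given as a sketch. The function names are illustrative, and `data.jsonl` in the usage line is a placeholder path.

```python
import json
from typing import Any, Dict, List

def load_jsonl(path: str) -> List[Dict[str, Any]]:
    """Read one JSON object per non-empty line, reporting the line number on bad input."""
    examples = []
    with open(path, "r", encoding="utf-8") as f:
        for line_no, line in enumerate(f, start=1):
            line = line.strip()
            if not line:
                continue  # tolerate blank lines, e.g. a trailing newline
            try:
                examples.append(json.loads(line))
            except json.JSONDecodeError as e:
                raise ValueError(f"{path}:{line_no}: invalid JSON ({e.msg})") from e
    return examples

def save_jsonl(path: str, records: List[Dict[str, Any]]) -> None:
    """Write each record as a single JSON line (UTF-8, non-ASCII kept readable)."""
    with open(path, "w", encoding="utf-8") as f:
        for record in records:
            f.write(json.dumps(record, ensure_ascii=False) + "\n")

# Usage: prepared_dataset = prepare_dataset_for_tunix(load_jsonl("data.jsonl"))
```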

In [None]:
import json
import re
from typing import List, Dict, Any

def prepare_dataset_for_tunix(examples: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """
    Prepare dataset in Tunix-compatible JSONL format.
    
    Args:
        examples: List of dicts with 'question' and 'answer' fields
        
    Returns:
        List of dicts with 'prompt', 'ground_truth', and 'metadata'
    """
    prepared = []
    
    for idx, ex in enumerate(examples):
        ground_truth = ex.get("answer", ex.get("ground_truth", ""))
        prepared.append({
            "prompt": ex.get("question", ex.get("prompt", "")),
            "ground_truth": ground_truth,
            "metadata": {
                "example_id": idx,
                "original_question": ex.get("question", ""),
                "task_type": ex.get("task_type", "general_reasoning"),
                # composite_reward_function reads the ground truth from metadata
                "ground_truth": ground_truth,
            }
        })
    
    return prepared

# Create synthetic sample data (replace with real data)
sample_examples = [
    {
        "question": "Can an employer in California enforce a non-compete clause against a former employee?",
        "answer": "No, non-compete clauses are generally unenforceable in California except in limited circumstances involving sale of business or dissolution of partnership.",
        "task_type": "legal_qa"
    },
    {
        "question": "What is the statute of limitations for filing a breach of contract claim?",
        "answer": "The statute of limitations varies by jurisdiction. In many states, it is 4-6 years for written contracts and 2-3 years for oral contracts.",
        "task_type": "legal_qa"
    },
    {
        "question": "Under what circumstances can a contract be voided for duress?",
        "answer": "A contract can be voided for duress when one party was forced to enter the agreement through threats, violence, or other improper pressure that overcame their free will.",
        "task_type": "legal_qa"
    },
    {
        "question": "What is required to establish an attorney-client privilege?",
        "answer": "Attorney-client privilege requires: (1) an attorney-client relationship, (2) confidential communication, (3) made for the purpose of seeking or providing legal advice.",
        "task_type": "legal_qa"
    },
]

# Prepare dataset
prepared_dataset = prepare_dataset_for_tunix(sample_examples)

print(f"✅ Prepared {len(prepared_dataset)} training examples")
print(f"\n📝 Sample example:")
print(json.dumps(prepared_dataset[0], indent=2))

# Note: In production, load from file or HuggingFace dataset
print("\n💡 To load from file:")
print("   # with open('data.jsonl', 'r') as f:")
print("   #     examples = [json.loads(line) for line in f]")
print("   #     prepared_dataset = prepare_dataset_for_tunix(examples)")

### Prompt Template with XML Format

Create a template that formats prompts to expect XML-tagged reasoning.

In [None]:
def create_prompt_template(question: str, system_prompt: str = SYSTEM_PROMPT) -> str:
    """
    Create a formatted prompt with XML output expectations.
    
    Args:
        question: The legal question to answer
        system_prompt: System instructions for format
        
    Returns:
        Formatted prompt string
    """
    template = f"""{system_prompt}

Question: {question}

Response:"""
    return template

def validate_xml_format(response: str) -> bool:
    """
    Validate that response contains proper XML tags.
    
    Args:
        response: Model generated response
        
    Returns:
        True if valid XML format, False otherwise
    """
    # Check for both opening and closing tags
    has_reasoning = '<reasoning>' in response and '</reasoning>' in response
    has_answer = '<answer>' in response and '</answer>' in response
    
    return has_reasoning and has_answer

# Apply template to all examples
templated_prompts = []
for example in prepared_dataset:
    templated = {
        "prompt": create_prompt_template(example["prompt"]),
        "ground_truth": example["ground_truth"],
        "metadata": example["metadata"],
        "original_prompt": example["prompt"]
    }
    templated_prompts.append(templated)

print(f"✅ Created {len(templated_prompts)} templated prompts")
print(f"\n📝 Sample templated prompt (first 300 chars):")
print(templated_prompts[0]["prompt"][:300])
print("...")

# Test validation
test_valid = "<reasoning>This is reasoning</reasoning><answer>This is answer</answer>"
test_invalid = "This is just text without tags"
print(f"\n✅ Validation test:")
print(f"   Valid format: {validate_xml_format(test_valid)}")
print(f"   Invalid format: {validate_xml_format(test_invalid)}")

### Tokenization and Batching

Tokenize prompts and prepare batches for training.
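`create_training_batches` in the cell below keeps file order and leaves a short final batch whenever the dataset size is not a multiple of `batch_size`. Under `jax.jit`, a new input shape triggers a recompile, so fixed-size batches are usually preferable on TPU. Below is a sketch of seeded shuffling that drops the ragged tail. The function name and the `drop_last` flag are illustrative and are not part of any library used here.

```python
from typing import Any, Dict, Iterator, List

import numpy as np

def shuffled_batches(
    dataset: List[Dict[str, Any]],
    batch_size: int,
    seed: int,
    drop_last: bool = True,
) -> Iterator[List[Dict[str, Any]]]:
    """Yield batches in a random order determined by `seed`; one full pass is one epoch."""
    order = np.random.default_rng(seed).permutation(len(dataset))
    # Dropping the ragged tail keeps every batch the same shape for jit-compiled steps
    stop = len(order) - len(order) % batch_size if drop_last else len(order)
    for start in range(0, stop, batch_size):
        yield [dataset[int(i)] for i in order[start:start + batch_size]]

# Use a different seed per epoch so each epoch sees a new order
demo = [{"id": i} for i in range(10)]
for epoch in range(2):
    print(epoch, [[ex["id"] for ex in batch] for batch in shuffled_batches(demo, 4, seed=epoch)])
```

With `training_dataset` in place of `demo`, each epoch visits examples in a fresh but reproducible order.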

In [None]:
import numpy as np
from typing import List, Dict

# Set maximum prompt length
MAX_PROMPT_LENGTH = 512  # Adjust based on your needs (512 or 1024)
MAX_RESPONSE_LENGTH = 512

def tokenize_prompts(prompts: List[str], tokenizer, max_length: int = MAX_PROMPT_LENGTH):
    """
    Tokenize prompts with padding and truncation.
    
    Args:
        prompts: List of prompt strings
        tokenizer: HuggingFace tokenizer
        max_length: Maximum token length
        
    Returns:
        Dict with input_ids and attention_mask
    """
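    # Decoder-only models generate after the last prompt token, so pad on the
    # left to keep every prompt flush against the generation position
    tokenizer.padding_side = "left"
    tokenized = tokenizer(
        prompts,
        padding="max_length",
        truncation=True,
        max_length=max_length,
        return_tensors="np"
    )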
    return tokenized

def create_training_batches(dataset: List[Dict], batch_size: int = 4):
    """
    Create batches from dataset.
    
    Args:
        dataset: List of training examples
        batch_size: Number of examples per batch
        
    Returns:
        List of batches, each batch is a list of examples
    """
    batches = []
    for i in range(0, len(dataset), batch_size):
        batch = dataset[i:i + batch_size]
        batches.append(batch)
    return batches

# Tokenize all prompts
all_prompts = [ex["prompt"] for ex in templated_prompts]
tokenized_prompts = tokenize_prompts(all_prompts, tokenizer, MAX_PROMPT_LENGTH)

print(f"✅ Tokenized {len(all_prompts)} prompts")
print(f"   Max length: {MAX_PROMPT_LENGTH} tokens")
print(f"   Shape: {tokenized_prompts['input_ids'].shape}")

# Create final dataset for training
training_dataset = []
for i, ex in enumerate(templated_prompts):
    training_dataset.append({
        "prompt": ex["prompt"],
        "prompt_tokens": tokenized_prompts['input_ids'][i],
        "attention_mask": tokenized_prompts['attention_mask'][i],
        "ground_truth": ex["ground_truth"],
        "metadata": ex["metadata"]
    })

print(f"\n✅ Final training dataset: {len(training_dataset)} examples")
print(f"   Each example has: {list(training_dataset[0].keys())}")

# Validate dataset format
required_fields = ["prompt", "ground_truth", "metadata"]
all_valid = all(all(field in ex for field in required_fields) for ex in training_dataset)
print(f"\n✅ Dataset validation: {'PASSED' if all_valid else 'FAILED'}")

if not all_valid:
    print("❌ Some examples missing required fields!")
else:
    print("   All examples have required fields: prompt, ground_truth, metadata")

## 🎯 Step 7: Implement Custom Reward Function

Create a reward function that scores:
1. **Format**: Proper XML tags
2. **Reasoning Length**: At least 100 tokens
3. **Answer Correctness**: Match with ground truth
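The example below works through how the three weights combine. It also shows a quirk of overlap-based correctness scoring: a flatly wrong answer can still earn half of the correctness credit, because most of its words match the ground truth. For readability, word sets stand in for Gemma's subword tokens. This is an assumption, and tokenizer-based scores will differ somewhat.

```python
import re

WEIGHTS = {"format": 0.3, "reasoning": 0.3, "correctness": 0.4}

def word_jaccard(a: str, b: str) -> float:
    """Jaccard similarity over lowercase word sets (stand-in for subword token sets)."""
    sa = set(re.findall(r"[a-z]+", a.lower()))
    sb = set(re.findall(r"[a-z]+", b.lower()))
    if not sa or not sb:
        return 0.0
    return len(sa & sb) / len(sa | sb)

def composite(format_r: float, reasoning_tokens: int, correctness_r: float, min_tokens: int = 100) -> float:
    """Weighted sum matching the 30/30/40 split, with proportional reasoning credit."""
    reasoning_r = min(reasoning_tokens / min_tokens, 1.0)
    return (WEIGHTS["format"] * format_r
            + WEIGHTS["reasoning"] * reasoning_r
            + WEIGHTS["correctness"] * correctness_r)

truth = "No, it is not enforceable."
wrong = "Yes, it is enforceable."
j = word_jaccard(wrong, truth)  # shared {it, is, enforceable} over {no, yes, it, is, not, enforceable}
print(j, composite(1.0, 60, j))
```

With full format credit and 60 reasoning tokens, the wrong answer scores 0.3 + 0.3 × 0.6 + 0.4 × 0.5 = 0.68 out of 1.0. If that matters for a task, one option is a stricter check on the answer's polarity, such as its leading yes or no.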

In [None]:
import re
from typing import Tuple, Optional

def extract_xml_content(response: str) -> Tuple[Optional[str], Optional[str]]:
    """
    Extract content from <reasoning> and <answer> XML tags.
    
    Args:
        response: Model-generated response string
        
    Returns:
        Tuple of (reasoning_content, answer_content)
        Returns (None, None) if tags are malformed or missing
    """
    try:
        # Extract reasoning
        reasoning_match = re.search(r'<reasoning>(.*?)</reasoning>', response, re.DOTALL)
        reasoning = reasoning_match.group(1).strip() if reasoning_match else None
        
        # Extract answer
        answer_match = re.search(r'<answer>(.*?)</answer>', response, re.DOTALL)
        answer = answer_match.group(1).strip() if answer_match else None
        
        return reasoning, answer
    except Exception as e:
        print(f"Warning: Error extracting XML content: {e}")
        return None, None

# Test extraction with edge cases
test_cases = [
    # Valid case
    "<reasoning>Step by step analysis here</reasoning><answer>Final answer</answer>",
    # Missing tags
    "Just plain text without tags",
    # Partial tags
    "<reasoning>Incomplete reasoning",
    # Nested content
    "<reasoning>Analysis with <term>nested</term> content</reasoning><answer>Yes</answer>",
    # Multi-line
    """<reasoning>
Line 1 of reasoning
Line 2 of reasoning
</reasoning>
<answer>Final answer</answer>"""
]

print("🧪 Testing XML extraction:")
for i, test in enumerate(test_cases, 1):
    reasoning, answer = extract_xml_content(test)
    print(f"\nTest {i}:")
    print(f"  Reasoning found: {reasoning is not None}")
    print(f"  Answer found: {answer is not None}")
    if reasoning:
        print(f"  Reasoning preview: {reasoning[:50]}...")
    if answer:
        print(f"  Answer: {answer}")

print("\n✅ XML extraction function tested with edge cases")

In [None]:
def compute_format_reward(response: str) -> float:
    """
    Reward for valid XML format.
    
    Returns:
        1.0 if both <reasoning> and <answer> tags present and valid
        0.0 otherwise
    """
    reasoning, answer = extract_xml_content(response)
    
    # Check both tags present and have content
    if reasoning is not None and answer is not None:
        if len(reasoning.strip()) > 0 and len(answer.strip()) > 0:
            return 1.0
    
    return 0.0

def compute_reasoning_length_reward(response: str, tokenizer, min_tokens: int = 100) -> float:
    """
    Reward based on reasoning length.
    
    Returns:
        1.0 if reasoning >= min_tokens
        proportional score (tokens/min_tokens) if fewer
        0.0 if no reasoning found
    """
    reasoning, _ = extract_xml_content(response)
    
    if reasoning is None:
        return 0.0
    
    # Count only the reasoning text's own tokens (add_special_tokens=False skips <bos>)
    tokens = tokenizer(reasoning, add_special_tokens=False)["input_ids"]
    num_tokens = len(tokens)
    
    # Return 1.0 if meets threshold, otherwise proportional
    if num_tokens >= min_tokens:
        return 1.0
    else:
        return num_tokens / min_tokens

def compute_answer_correctness_reward(response: str, ground_truth: str, tokenizer) -> float:
    """
    Reward based on answer correctness.
    
    Returns:
        1.0 for exact match (after normalization)
        Partial credit for token overlap (Jaccard similarity)
        0.0 if no answer found
    """
    _, answer = extract_xml_content(response)
    
    if answer is None:
        return 0.0
    
    # Normalize for comparison
    answer_norm = answer.lower().strip()
    ground_truth_norm = ground_truth.lower().strip()
    
    # Check exact match
    if answer_norm == ground_truth_norm:
        return 1.0
    
    # Tokenize both for overlap calculation
    answer_tokens = set(tokenizer.tokenize(answer_norm))
    truth_tokens = set(tokenizer.tokenize(ground_truth_norm))
    
    # Calculate Jaccard similarity (intersection / union)
    if len(answer_tokens) == 0 or len(truth_tokens) == 0:
        return 0.0
    
    intersection = len(answer_tokens & truth_tokens)
    union = len(answer_tokens | truth_tokens)
    
    jaccard = intersection / union if union > 0 else 0.0
    
    # Return Jaccard similarity as partial credit
    return jaccard

print("✅ Reward component functions defined:")
print("   - compute_format_reward()")
print("   - compute_reasoning_length_reward()")
print("   - compute_answer_correctness_reward()")

In [None]:
from typing import List

def composite_reward_function(
    prompts: List[str],
    completions: List[str],
    metadata: List[Dict],
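    tokenizer  # HuggingFace tokenizer, used for token counting
) -> List[float]:
    """
    Composite reward for GRPO. Wrap it (e.g. with a lambda or functools.partial)
    to match the exact reward-function signature your trainer expects.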
    
    Combines three reward components:
    - Format reward (30%): Valid XML tags
    - Reasoning reward (30%): Sufficient reasoning length
    - Correctness reward (40%): Answer accuracy
    
    Args:
        prompts: List of input prompts
        completions: List of model completions
        metadata: List of metadata dicts with ground_truth
        tokenizer: HuggingFace tokenizer for token counting
        
    Returns:
        List of scalar rewards (one per example)
    """
    # Reward weights
    FORMAT_WEIGHT = 0.3
    REASONING_WEIGHT = 0.3
    CORRECTNESS_WEIGHT = 0.4
    
    rewards = []
    
    for i, (prompt, completion, meta) in enumerate(zip(prompts, completions, metadata)):
        # Compute each reward component
        format_reward = compute_format_reward(completion)
        reasoning_reward = compute_reasoning_length_reward(completion, tokenizer, min_tokens=100)
        
        # Get ground truth from metadata
        ground_truth = meta.get("ground_truth", "")
        correctness_reward = compute_answer_correctness_reward(completion, ground_truth, tokenizer)
        
        # Aggregate rewards
        total_reward = (
            FORMAT_WEIGHT * format_reward +
            REASONING_WEIGHT * reasoning_reward +
            CORRECTNESS_WEIGHT * correctness_reward
        )
        
        rewards.append(total_reward)
        
        # Log breakdown for first few examples
        if i < 3:
            print(f"\n📊 Example {i} reward breakdown:")
            print(f"   Format: {format_reward:.2f} (weight: {FORMAT_WEIGHT})")
            print(f"   Reasoning: {reasoning_reward:.2f} (weight: {REASONING_WEIGHT})")
            print(f"   Correctness: {correctness_reward:.2f} (weight: {CORRECTNESS_WEIGHT})")
            print(f"   Total: {total_reward:.2f}")
    
    return rewards

# Test reward function with tokenizer parameter
test_prompts = ["Test question"]
test_completions = [
    "<reasoning>This is a detailed legal analysis with sufficient tokens to explain the reasoning behind the answer. We consider precedent, statutory law, and policy implications.</reasoning><answer>Yes, it is enforceable.</answer>"
]
test_metadata = [{"ground_truth": "Yes, it is enforceable."}]

test_rewards = composite_reward_function(test_prompts, test_completions, test_metadata, tokenizer)
print(f"\n✅ Reward function test complete")
print(f"   Test reward: {test_rewards[0]:.2f}")
print("\n✅ Composite reward function ready for Tunix GRPO!")
print("\n💡 Note: When using with Tunix, wrap this function to match their exact signature.")
print("   Example: lambda p, c, m: composite_reward_function(p, c, m, tokenizer)")

## 🚀 Step 8: Configure and Execute GRPO Training

Set up LoRA adapters and run GRPO training on TPU.
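The `rank` and `alpha` values configured below follow the standard LoRA formulation (Hu et al.), which is also what PEFT's `r` and `lora_alpha` fields describe. A frozen weight `W` is augmented with a trainable low-rank product `A @ B`, scaled by `alpha / rank`, which is 2.0 here. The NumPy sketch below shows one adapted linear layer. It uses the usual zero initialization of `B`, so training starts exactly at the base model. The layer size is arbitrary, and dropout on the adapter input is omitted.

```python
import numpy as np

def lora_linear(x, W, A, B, alpha: float, rank: int):
    """LoRA-adapted linear layer: x @ W + (alpha / rank) * (x @ A) @ B.

    W (d_in, d_out) is frozen; only A (d_in, rank) and B (rank, d_out) are trained.
    """
    return x @ W + (alpha / rank) * (x @ A) @ B

rng = np.random.default_rng(0)
d_in = d_out = 1024            # arbitrary layer size for illustration
rank, alpha = 16, 32           # mirrors LORA_CONFIG below
x = rng.normal(size=(2, d_in))
W = rng.normal(size=(d_in, d_out))
A = rng.normal(scale=0.01, size=(d_in, rank))
B = np.zeros((rank, d_out))    # zero init: the adapted layer starts identical to the base layer

y = lora_linear(x, W, A, B, alpha, rank)
print("matches base layer at init:", np.allclose(y, x @ W))
print(f"trainable: {A.size + B.size:,} vs frozen: {W.size:,} ({(A.size + B.size) / W.size:.1%})")
```

Doubling `rank` doubles the trainable parameters per adapted layer. Keeping `alpha / rank` fixed keeps the adapter's initial effective step size roughly comparable.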

In [None]:
# LoRA Hyperparameters
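LORA_CONFIG = {
    "rank": 16,  # LoRA rank (16 or 32)
    "alpha": 32,  # LoRA alpha (32 or 64)
    "dropout": 0.05,  # LoRA dropout
    # Attention projections, using Hugging Face/PEFT layer names; a Flax/Tunix
    # Gemma implementation may name these layers differently
    "target_modules": ["q_proj", "k_proj", "v_proj", "o_proj"],
}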

# GRPO Configuration
GRPO_CONFIG = {
    "num_generations": 4,  # Rollouts sampled per prompt (the GRPO group size)
    "num_iterations": 2,  # Policy updates per batch of rollouts (μ in the GRPO paper), not an epoch count
    "beta": 0.04,  # KL penalty coefficient
    "epsilon": 0.2,  # Clip range for the policy probability ratio (PPO-style)
    "learning_rate": 1e-5,  # Learning rate
    "batch_size": 4,  # Batch size (adjust for TPU memory)
    "gradient_accumulation_steps": 2,  # Gradient accumulation
    "eval_frequency": 50,  # Evaluate every N steps
    "checkpoint_interval": 100,  # Save checkpoint every N steps
}

print("✅ Configuration defined:")
print("\n🔧 LoRA Configuration:")
for k, v in LORA_CONFIG.items():
    print(f"   {k}: {v}")
print("\n🎯 GRPO Configuration:")
for k, v in GRPO_CONFIG.items():
    print(f"   {k}: {v}")

print("\n💡 Hyperparameter Rationale:")
print("   - LoRA rank=16: Balance between capacity and memory")
print("   - num_generations=4: each prompt's 4 rollouts form the group whose reward mean/std is the baseline")
print("   - beta=0.04: KL coefficient used in the GRPO paper (DeepSeekMath); limits drift from the reference model")
print("   - learning_rate=1e-5: Conservative for fine-tuning")

### ⚠️  Initialize Training Components

**IMPORTANT: The cells below contain TEMPLATE CODE only.**

The actual `google-tunix` library API may differ from this template. This notebook provides:
- Conceptual structure for GRPO training
- Typical patterns used in RL training
- Placeholder code to be adapted

**Before running training:**
1. Verify `google-tunix` is publicly available
2. Review official Tunix documentation and examples
3. Update import statements and API calls to match actual library
4. Test with small dataset first

**Alternatives if Tunix is unavailable:**
- Use [TRL (Transformer Reinforcement Learning)](https://github.com/huggingface/trl) library
- Adapt GRPO implementation from main Judicaita codebase (PyTorch)
- Use other RL frameworks (RLlib, Stable-Baselines3)
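Whichever trainer is used, `composite_reward_function` takes `metadata` and `tokenizer` as positional arguments. That signature is unlikely to match a trainer's reward callback as-is. The sketch below adapts it, **assuming** the trainer calls reward functions with `prompts` and `completions` and passes extra dataset columns as keyword arguments (here a `ground_truth` column aligned with `completions`). TRL's `GRPOTrainer` documents this keyword-column convention. Whether Tunix follows it is an assumption, so check its reward-function signature before use. `make_kwargs_reward` is a hypothetical helper.

```python
from typing import Any, Callable, Dict, List, Sequence

def make_kwargs_reward(
    reward_fn: Callable[..., List[float]],
    tokenizer: Any,
    ground_truth_key: str = "ground_truth",
) -> Callable[..., List[float]]:
    """Adapt reward_fn(prompts, completions, metadata, tokenizer) to a keyword-style callback."""
    def wrapped(prompts: Sequence[str], completions: Sequence[str], **kwargs: Any) -> List[float]:
        truths = kwargs.get(ground_truth_key)
        if truths is None:
            raise KeyError(f"reward function expected a '{ground_truth_key}' column")
        # If the trainer repeats each prompt num_generations times, the column must be repeated too
        if len(truths) != len(completions):
            raise ValueError("ground-truth column is not aligned with completions")
        metadata: List[Dict[str, Any]] = [{"ground_truth": t} for t in truths]
        return reward_fn(list(prompts), list(completions), metadata, tokenizer)
    return wrapped

# Usage (hypothetical): reward = make_kwargs_reward(composite_reward_function, tokenizer)
```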

In [None]:
# ⚠️  IMPORTANT: TEMPLATE CODE - ADAPT TO ACTUAL TUNIX API
# The code below is a template based on typical RL training patterns.
# The actual google-tunix library API may differ significantly.
# Refer to official Tunix documentation and examples to adapt this code.
# 
# If google-tunix is not publicly available, consider using:
# - TRL (Transformer Reinforcement Learning) library
# - Custom GRPO implementation (see main codebase)
# - Other RL frameworks like RLlib or SB3

print("📦 Setting up training components...")

# Create checkpoint directories
import os
CHECKPOINT_DIR = "./checkpoints"
FINAL_DIR = "./final_checkpoint"
os.makedirs(CHECKPOINT_DIR, exist_ok=True)
os.makedirs(FINAL_DIR, exist_ok=True)

print(f"✅ Checkpoint directories created:")
print(f"   Intermediate: {CHECKPOINT_DIR}")
print(f"   Final: {FINAL_DIR}")

# TEMPLATE: Typical Tunix-style setup (adapt to actual API)
print("\n📋 Training setup template (MUST BE ADAPTED):")
print("""
# Example imports (adapt based on actual library):
# from tunix.rl import RLCluster, GRPOLearner, GRPOConfig
# from tunix.models import load_model_with_lora
# import jax

# 1. Initialize model mesh for TPU sharding
# mesh = jax.sharding.Mesh(...)

# 2. Load base model and create LoRA config
# base_model = load_model_with_lora(
#     model_path=model_path,
#     lora_config=LORA_CONFIG
# )

# 3. Create RLCluster with actor and reference policies
# rl_cluster = RLCluster(
#     actor_policy=base_model,  # Model being trained
#     reference_policy=base_model,  # Frozen reference for KL penalty
#     tokenizer=tokenizer
# )

# 4. Create GRPO learner
# learner = GRPOLearner(
#     rl_cluster=rl_cluster,
#     reward_fn=lambda p, c, m: composite_reward_function(p, c, m, tokenizer),
#     config=GRPOConfig(**GRPO_CONFIG)
# )
""")

print("\n⚠️  ACTION REQUIRED:")
print("   1. Check if google-tunix is publicly available")
print("   2. Review actual Tunix API documentation")
print("   3. Replace template code with correct imports and calls")
print("   4. Test with a small dataset first")
print("\n💡 Alternative: Use TRL library or main codebase GRPO implementation")

In [None]:
# Training loop template
print("🎯 Training Loop Template:")
print(r"""
# Execute training within JAX mesh context
# with mesh:
#     for iteration in range(GRPO_CONFIG['num_iterations']):
#         print(f"\n{'='*60}")
#         print(f"Iteration {iteration + 1}/{GRPO_CONFIG['num_iterations']}")
#         print(f"{'='*60}")
#         
#         # Training step
#         metrics = learner.train_step(
#             dataset=training_dataset,
#             batch_size=GRPO_CONFIG['batch_size']
#         )
#         
#         # Log metrics
#         print(f"  Loss: {metrics.get('loss', 0):.4f}")
#         print(f"  Avg Reward: {metrics.get('avg_reward', 0):.4f}")
#         print(f"  Policy KL: {metrics.get('policy_kl', 0):.4f}")
#         
#         # Sample outputs
#         if iteration % GRPO_CONFIG['eval_frequency'] == 0:
#             sample_output = learner.generate(
#                 prompts=[training_dataset[0]['prompt']],
#                 max_length=MAX_RESPONSE_LENGTH
#             )
#             print(f"\n  Sample output: {sample_output[0][:200]}...")
#         
#         # Save checkpoint
#         if iteration % GRPO_CONFIG['checkpoint_interval'] == 0:
#             checkpoint_path = f"{CHECKPOINT_DIR}/iter_{iteration}"
#             learner.save_checkpoint(checkpoint_path)
#             print(f"  ✅ Checkpoint saved: {checkpoint_path}")
#     
#     # Save final checkpoint
#     learner.save_checkpoint(FINAL_DIR)
#     print(f"\n✅ Training complete! Final checkpoint: {FINAL_DIR}")
""")

print("\n💡 Training Tips:")
print("   - Monitor loss and reward trends")
print("   - Check sample outputs for quality")
print("   - Save checkpoints frequently for recovery")
print("   - Adjust batch_size if OOM errors occur")

print("\n⚠️  Since actual Tunix API may differ, refer to:")
print("   - https://github.com/google/tunix (if available)")
print("   - Tunix documentation and examples")
print("   - Adapt the template above to match the actual API")

## 📦 Step 9: Export LoRA Adapters and Create Kaggle Submission

Package trained adapters for Kaggle submission.
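Before uploading, check that the files from the export checklist actually exist and are not empty. The stdlib-only sketch below does that. Its default file list mirrors that checklist. Tokenizer files are left out because they are only needed when modified.

```python
from pathlib import Path
from typing import Dict, Iterable

EXPECTED_FILES = ("adapter_config.json", "adapter_model.safetensors", "README.md")

def check_export_dir(export_dir: str, required: Iterable[str] = EXPECTED_FILES) -> Dict[str, bool]:
    """Map each expected file name to True if it exists in export_dir and is non-empty."""
    root = Path(export_dir)
    status = {}
    for name in required:
        path = root / name
        status[name] = path.is_file() and path.stat().st_size > 0
    return status

# Usage, once KAGGLE_DIR is populated:
# missing = [name for name, ok in check_export_dir(KAGGLE_DIR).items() if not ok]
```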

In [None]:
import os
import shutil

# Create kaggle_upload directory
KAGGLE_DIR = "./kaggle_upload"
os.makedirs(KAGGLE_DIR, exist_ok=True)

print(f"✅ Created Kaggle submission directory: {KAGGLE_DIR}")
print("\n📋 Export checklist:")
print("   [ ] adapter_config.json - LoRA configuration")
print("   [ ] adapter_model.safetensors - LoRA weights")
print("   [ ] tokenizer files (if modified)")
print("   [ ] README with inference instructions")

In [None]:
# Template for exporting adapters
print("📦 Export Template (adapt to actual Tunix API):")
print("""
# Option 1: Export LoRA adapters separately
# learner.export_lora_adapters(
#     output_dir=KAGGLE_DIR,
#     format='safetensors'
# )

# Option 2: Export merged model (if needed)
# learner.export_merged_model(
#     output_dir=KAGGLE_DIR,
#     format='safetensors'
# )

# Save tokenizer files only (copying model_path would also copy the base model weights)
# tokenizer.save_pretrained(KAGGLE_DIR)
""")

# Create adapter_config.json manually as example
import json

adapter_config = {
    "peft_type": "LORA",
    "task_type": "CAUSAL_LM",
    "r": LORA_CONFIG["rank"],
    "lora_alpha": LORA_CONFIG["alpha"],
    "lora_dropout": LORA_CONFIG["dropout"],
    "target_modules": LORA_CONFIG["target_modules"],
    "inference_mode": False,
    "base_model_name_or_path": MODEL_ID,
}

config_path = f"{KAGGLE_DIR}/adapter_config.json"
with open(config_path, 'w') as f:
    json.dump(adapter_config, f, indent=2)

print(f"\n✅ Created {config_path}")
print("\n📄 adapter_config.json contents:")
print(json.dumps(adapter_config, indent=2))

# Create README
readme_content = """# Judicaita GRPO-Trained LoRA Adapters

## Model Information

- Base Model: {model_id}
- Training Method: GRPO (Group Relative Policy Optimization)
- Framework: Google Tunix + JAX/Flax
- LoRA Rank: {rank}
- LoRA Alpha: {alpha}

## Inference Usage

```python
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

# Load base model
base_model = AutoModelForCausalLM.from_pretrained("{model_id}")
tokenizer = AutoTokenizer.from_pretrained("{model_id}")

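# Load LoRA adapters
model = PeftModel.from_pretrained(
    base_model,
    "path/to/kaggle_upload"  # Directory containing adapter_config.json and the adapter weights
)

# Generate (use the same system prompt and template as in training)
prompt = "Your legal question here"
inputs = tokenizer(prompt, return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=512)
# Decode only the newly generated tokens, skipping the echoed prompt
response = tokenizer.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)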
```

## Expected Output Format

The model generates responses in XML format:
```xml
<reasoning>
Detailed legal reasoning with analysis...
</reasoning>
<answer>
Final answer or conclusion
</answer>
```

## Validation

- Reasoning should be >= 100 tokens
- Both XML tags must be present
- Answer should be relevant to the question
""".format(
    model_id=MODEL_ID,
    rank=LORA_CONFIG["rank"],
    alpha=LORA_CONFIG["alpha"]
)

with open(f"{KAGGLE_DIR}/README.md", 'w') as f:
    f.write(readme_content)

print(f"\n✅ Created README.md with inference instructions")
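`str.format` treats every `{` and `}` in the README template as a placeholder, so adding a literal brace later (for example a Python dict or JSON snippet in the usage section) raises a `KeyError` or `ValueError` unless each brace is doubled. A minimal alternative is the standard library's `string.Template`, which only substitutes `$name` placeholders. The template text, model ID, and rank/alpha values below are a shortened, illustrative stand-in, not the notebook's actual README or configuration.

```python
from string import Template

# Shortened, illustrative README template; literal braces need no escaping here
README_TEMPLATE = Template("""# Judicaita GRPO-Trained LoRA Adapters

- Base Model: $model_id
- LoRA Rank: $rank
- LoRA Alpha: $alpha

Example metadata: {"ground_truth": "..."}
""")

def render_readme(model_id: str, rank: int, alpha: int) -> str:
    """Fill the README placeholders; substitute() raises KeyError if one is missing."""
    return README_TEMPLATE.substitute(model_id=model_id, rank=rank, alpha=alpha)

# Example values only; the notebook reads these from MODEL_ID and LORA_CONFIG
readme = render_readme("google/gemma-3-1b-it", rank=16, alpha=32)
print(readme)
```

`substitute()` fails loudly on a missing value, while `safe_substitute()` would leave the `$name` placeholder in place instead.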

### Validate Exported Model

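Test the exported adapters by running inference on sample prompts and checking XML format compliance, reasoning length, and reward. Loading with PEFT requires both `adapter_config.json` and the adapter weights (e.g. `adapter_model.safetensors`) in `KAGGLE_DIR`.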

In [None]:
# Inference validation template (raw string so "\n" escapes print literally)
print("🧪 Inference Validation Template:")
print(r"""
# Load exported model and tokenizer
# from transformers import AutoModelForCausalLM, AutoTokenizer
# from peft import PeftModel
# 
# base_model = AutoModelForCausalLM.from_pretrained(MODEL_ID)
# tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
# model = PeftModel.from_pretrained(base_model, KAGGLE_DIR)
# 
# # Test prompts
# test_prompts = [
#     "Is a verbal contract enforceable?",
#     "What is required to prove negligence?",
# ]
# 
# print("\n🔍 Testing fine-tuned model...")
# for prompt in test_prompts:
#     full_prompt = create_prompt_template(prompt)
#     inputs = tokenizer(full_prompt, return_tensors="pt")
#     
#     outputs = model.generate(
#         **inputs,
#         max_new_tokens=MAX_RESPONSE_LENGTH,
#         temperature=0.7,
#         do_sample=True
#     )
#     
#     # Decode only the generated continuation, so any XML tags
#     # in the prompt template cannot make validation pass
#     prompt_len = inputs["input_ids"].shape[1]
#     response = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)
#     
#     # Validate format
#     has_valid_format = validate_xml_format(response)
#     reasoning, answer = extract_xml_content(response)
#     
#     if reasoning:
#         reasoning_tokens = len(tokenizer(reasoning)["input_ids"])
#     else:
#         reasoning_tokens = 0
#     
#     print(f"\n{'='*60}")
#     print(f"Prompt: {prompt}")
#     print(f"Valid format: {has_valid_format}")
#     print(f"Reasoning tokens: {reasoning_tokens}")
#     print(f"Response preview: {response[:200]}...")
#     
#     # Compute reward (placeholder ground truth, so the
#     # correctness component is not meaningful here)
#     reward = composite_reward_function(
#         [full_prompt],
#         [response],
#         [{"ground_truth": "test"}]
#     )[0]
#     print(f"Reward score: {reward:.2f}")
""")

print("\n✅ Validation template ready")
print("   Paste it into a new cell, uncomment it, and run it to verify model quality before submission")
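The README's validation criteria (both tags present, reasoning of at least 100 tokens) can be checked without loading a model. This is a minimal, hypothetical stand-in for the notebook's own `validate_xml_format` and `extract_xml_content` helpers, which are defined earlier and may behave differently. It pulls the tag contents with a regex and counts whitespace-separated words as a rough proxy for tokens; subword tokenizers typically produce more tokens than words, so the check errs on the strict side.

```python
import re
from typing import Optional, Tuple

# Hypothetical helper; the notebook's own extract_xml_content may differ
_TAG_PATTERN = re.compile(
    r"<reasoning>(?P<reasoning>.*?)</reasoning>\s*<answer>(?P<answer>.*?)</answer>",
    re.DOTALL,
)

def extract_sections(response: str) -> Tuple[Optional[str], Optional[str]]:
    """Return stripped (reasoning, answer), or (None, None) if either tag pair is missing."""
    match = _TAG_PATTERN.search(response)
    if match is None:
        return None, None
    return match.group("reasoning").strip(), match.group("answer").strip()

def check_response(response: str, min_reasoning_words: int = 100) -> dict:
    """Check tag presence and reasoning length, using word count as a token proxy."""
    reasoning, _answer = extract_sections(response)
    reasoning_words = len(reasoning.split()) if reasoning else 0
    return {
        "has_tags": reasoning is not None,
        "reasoning_words": reasoning_words,
        "long_enough": reasoning_words >= min_reasoning_words,
    }

sample = "<reasoning>" + "analysis " * 120 + "</reasoning>\n<answer>Yes, generally.</answer>"
print(check_response(sample))
```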

In [None]:
import os
import zipfile

# Create zip archive
def create_submission_zip(source_dir: str, output_file: str) -> float:
    """
    Create a zip archive for Kaggle submission.
    
    Args:
        source_dir: Directory containing files to zip
        output_file: Output zip file path
    
    Returns:
        Size of the zip file in MB
    
    Raises:
        FileNotFoundError: If source_dir does not exist
        ValueError: If source_dir contains no files
    """
    # os.walk silently yields nothing for a missing directory, so check explicitly
    if not os.path.isdir(source_dir):
        raise FileNotFoundError(f"Directory not found: {source_dir}")
    if not any(files for _, _, files in os.walk(source_dir)):
        raise ValueError(f"No files to package in {source_dir}")
    
    output_abs = os.path.abspath(output_file)
    with zipfile.ZipFile(output_file, 'w', zipfile.ZIP_DEFLATED) as zipf:
        for root, _dirs, files in os.walk(source_dir):
            for file in files:
                file_path = os.path.join(root, file)
                # Never add the archive to itself if it sits inside source_dir
                if os.path.abspath(file_path) == output_abs:
                    continue
                arcname = os.path.relpath(file_path, source_dir)
                zipf.write(file_path, arcname)
                print(f"   Added: {arcname}")
    
    # Get zip file size
    size_mb = os.path.getsize(output_file) / (1024 * 1024)
    return size_mb

# Create submission
submission_zip = "./judicaita_submission.zip"
print("📦 Creating Kaggle submission package...")
print(f"   Source: {KAGGLE_DIR}")
print(f"   Output: {submission_zip}")
print("\n📄 Files included:")

try:
    size = create_submission_zip(KAGGLE_DIR, submission_zip)
    print("\n✅ Submission package created!")
    print(f"   File: {submission_zip}")
    print(f"   Size: {size:.2f} MB")
    
    # Report which files are actually present instead of a fixed checklist
    print("\n📋 Submission Checklist:")
    for name in ("adapter_config.json", "README.md", "adapter_model.safetensors"):
        present = os.path.exists(os.path.join(KAGGLE_DIR, name))
        print(f"   {'✅' if present else '⚠️  missing:'} {name}")
    print("   ⚠️  Validation results (add after testing)")
    
    print("\n🎯 Next Steps:")
    print("   1. Complete GRPO training")
    print("   2. Export adapter weights to kaggle_upload/")
    print("   3. Run inference validation")
    print("   4. Re-run this cell to create final zip")
    print("   5. Upload to Kaggle competition")
    
except Exception as e:
    print(f"❌ Error creating zip: {e}")
    print(f"   Make sure {KAGGLE_DIR} exists and contains the exported files")
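Before uploading, it can help to reopen the archive and confirm it is readable and contains the expected files. A minimal sketch using only `zipfile`: `ZipFile.testzip()` returns the name of the first member whose CRC check fails, or `None` if all pass, and `namelist()` lists the archive paths. The required file names are assumptions taken from the checklist above.

```python
import zipfile
from typing import Iterable, List

def verify_submission_zip(zip_path: str, required: Iterable[str]) -> List[str]:
    """Return required files missing from the archive; raise if any member is corrupt."""
    with zipfile.ZipFile(zip_path) as zf:
        bad_member = zf.testzip()  # None means every member's CRC check passed
        if bad_member is not None:
            raise zipfile.BadZipFile(f"Corrupt member in archive: {bad_member}")
        names = set(zf.namelist())
    return [name for name in required if name not in names]

# Example usage against the archive built above (file names are assumptions):
# missing = verify_submission_zip(
#     submission_zip,
#     ["adapter_config.json", "README.md", "adapter_model.safetensors"],
# )
# print("Missing:", missing or "none")
```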

### 🔧 Troubleshooting Guide

#### Memory Errors
- **Reduce batch_size**: Try 2 or 1
- **Reduce num_generations**: Try 2 instead of 4 (GRPO needs at least 2 generations per prompt to compare rewards within a group)
- **Shorten sequences**: Lower MAX_PROMPT_LENGTH and MAX_RESPONSE_LENGTH
- **Use a smaller LoRA rank**: Try rank=8 instead of 16
- **Use gradient accumulation**: Lower the per-step batch size and increase gradient_accumulation_steps to keep the same effective batch size

#### TPU Timeouts
- **Reduce num_iterations**: Start with 1 iteration for testing
- **Save checkpoints frequently**: Checkpoint every 50 steps so an interrupted session can resume
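The gradient-accumulation tip rests on one identity: for a mean-reduced loss and equal-sized micro-batches, averaging the micro-batch gradients gives the same result as the full-batch gradient, so the effective batch size is `micro_batch_size × gradient_accumulation_steps` while only one micro-batch's activations need to be in memory at a time. The toy least-squares model below is plain Python rather than JAX and is not tied to any Tunix option; it only checks that identity.

```python
from typing import List, Tuple

Batch = List[Tuple[float, float]]  # (x, y) pairs for a one-parameter model y ≈ w * x

def grad_mse(w: float, batch: Batch) -> float:
    """Gradient of the mean squared error mean((w*x - y)^2) with respect to w."""
    return sum(2 * (w * x - y) * x for x, y in batch) / len(batch)

def accumulated_grad(w: float, batch: Batch, accumulation_steps: int) -> float:
    """Split the batch into equal micro-batches and average their gradients."""
    if len(batch) % accumulation_steps != 0:
        raise ValueError("batch size must be divisible by accumulation_steps")
    size = len(batch) // accumulation_steps
    micro_grads = [
        grad_mse(w, batch[i * size:(i + 1) * size])  # each call sees one micro-batch
        for i in range(accumulation_steps)
    ]
    return sum(micro_grads) / accumulation_steps

data = [(1.0, 2.0), (2.0, 3.0), (3.0, 7.0), (4.0, 8.0)]
full = grad_mse(0.5, data)
accumulated = accumulated_grad(0.5, data, accumulation_steps=2)
print(f"full-batch grad: {full}, accumulated grad: {accumulated}")
```

With unequal micro-batches the plain average would over-weight the smaller ones, which is why the sketch rejects batch sizes that do not divide evenly.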

#### Checkpoint Corruption
- **Keep multiple checkpoints**: Don't overwrite old checkpoints
- **Verify after saving**: Load checkpoint immediately after saving
- **Export early and often**: Export at multiple stages
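The checkpoint tips above can be combined into one save routine: write each checkpoint to a temporary file, rename it into place with `os.replace` (atomic on POSIX when both paths are on the same filesystem), reload it to verify, and delete only the oldest files beyond a retention limit. This is a framework-agnostic sketch that pickles a plain dict of Python values; the `ckpt_<step>.pkl` naming scheme and `keep_last` parameter are assumptions, a real JAX/Tunix run would more likely rely on its checkpointing library, and pickle files should only be loaded from trusted sources.

```python
import os
import pickle
import tempfile
from typing import Any, Dict, List

def list_checkpoints(ckpt_dir: str) -> List[str]:
    """Checkpoint file names, oldest step first (zero-padded steps sort correctly)."""
    return sorted(
        name for name in os.listdir(ckpt_dir)
        if name.startswith("ckpt_") and name.endswith(".pkl")
    )

def save_checkpoint(params: Dict[str, Any], ckpt_dir: str, step: int, keep_last: int = 3) -> str:
    """Atomically save params, verify by reloading, then prune older checkpoints."""
    if keep_last < 1:
        raise ValueError("keep_last must be at least 1")
    os.makedirs(ckpt_dir, exist_ok=True)
    final_path = os.path.join(ckpt_dir, f"ckpt_{step:08d}.pkl")

    # Write to a temp file in the same directory, then rename it into place,
    # so a crash mid-write never leaves a truncated ckpt_*.pkl behind
    fd, tmp_path = tempfile.mkstemp(dir=ckpt_dir, suffix=".tmp")
    with os.fdopen(fd, "wb") as f:
        pickle.dump(params, f)
        f.flush()
        os.fsync(f.fileno())
    os.replace(tmp_path, final_path)

    # Verify immediately by loading the checkpoint back
    with open(final_path, "rb") as f:
        if pickle.load(f) != params:
            raise IOError(f"Checkpoint verification failed: {final_path}")

    # Keep the newest `keep_last` checkpoints; older ones are pruned, never overwritten
    for old_name in list_checkpoints(ckpt_dir)[:-keep_last]:
        os.remove(os.path.join(ckpt_dir, old_name))
    return final_path

# Usage: checkpoint every 50 steps, keeping the last 3
ckpt_dir = tempfile.mkdtemp()
for step in range(0, 250, 50):
    save_checkpoint({"step": step, "w": [0.1 * step]}, ckpt_dir, step, keep_last=3)
print(list_checkpoints(ckpt_dir))
```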

#### Low Rewards
- **Check data quality**: Verify ground_truth accuracy
- **Adjust reward weights**: Increase correctness_weight
- **Train longer**: More training steps or epochs may help
- **Verify XML format**: Check model generates proper tags
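Two of the low-reward tips can be checked numerically. First, the composite score is a weighted sum, so raising `correctness_weight` changes how much a correct answer counts relative to format and length. Second, GRPO standardizes each reward against the other completions for the same prompt (subtract the group mean, divide by the group standard deviation, as in the GRPO paper), so if every generation in a group earns the same reward, all advantages are zero and that prompt contributes no learning signal. The component names, weights, and scores below are hypothetical, not the notebook's actual reward function, and the sketch uses the population standard deviation.

```python
import statistics
from typing import Dict, List

def composite_reward(components: Dict[str, float], weights: Dict[str, float]) -> float:
    """Weighted sum of per-component scores (each assumed to lie in [0, 1])."""
    return sum(weights[name] * components[name] for name in weights)

def group_advantages(rewards: List[float], eps: float = 1e-6) -> List[float]:
    """Group-relative advantages: (r - mean) / (std + eps) within one prompt's generations."""
    mean = statistics.fmean(rewards)
    std = statistics.pstdev(rewards)
    return [(r - mean) / (std + eps) for r in rewards]

# Hypothetical weights; increasing "correctness" emphasizes correct answers
weights = {"format": 0.3, "length": 0.2, "correctness": 0.5}
group = [
    {"format": 1.0, "length": 1.0, "correctness": 1.0},  # well-formed and correct
    {"format": 1.0, "length": 1.0, "correctness": 0.0},  # well-formed but wrong
    {"format": 0.0, "length": 0.5, "correctness": 0.0},  # missing tags
    {"format": 1.0, "length": 0.0, "correctness": 1.0},  # correct but too short
]
rewards = [composite_reward(c, weights) for c in group]
print("rewards:", rewards)
print("advantages:", [round(a, 3) for a in group_advantages(rewards)])

# Identical rewards give zero advantages, so this prompt provides no gradient signal
print("flat group:", group_advantages([0.5, 0.5, 0.5, 0.5]))
```

If many groups come out flat (for example, every completion fails the format check), adjusting weights alone will not help; the model first needs some completions that score differently.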

#### Export Issues
- **Check file permissions**: Ensure write access to kaggle_upload/
- **Verify adapter format**: Ensure safetensors format is used
- **Test loading**: Try loading exported adapters before zipping
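The export checks above can be combined into one pre-zip step. A minimal sketch, assuming the PEFT layout used in this notebook (`adapter_config.json` next to `adapter_model.safetensors`): it confirms the directory exists and is writable, that the config parses as JSON and carries the LoRA fields written earlier, and that a safetensors weights file is present. It only inspects files; actually loading the adapters with PEFT remains the stronger test.

```python
import json
import os
from typing import List

# Subset of the fields written to adapter_config.json earlier in this notebook
REQUIRED_CONFIG_KEYS = ("peft_type", "r", "lora_alpha", "target_modules", "base_model_name_or_path")

def check_export_dir(export_dir: str) -> List[str]:
    """Return a list of problems found in the export directory (empty if none)."""
    if not os.path.isdir(export_dir):
        return [f"{export_dir} does not exist"]
    problems = []
    if not os.access(export_dir, os.W_OK):
        problems.append(f"no write access to {export_dir}")

    config_path = os.path.join(export_dir, "adapter_config.json")
    try:
        with open(config_path) as f:
            config = json.load(f)
        missing = [k for k in REQUIRED_CONFIG_KEYS if k not in config]
        if missing:
            problems.append(f"adapter_config.json missing keys: {missing}")
    except (OSError, json.JSONDecodeError) as e:
        problems.append(f"cannot read adapter_config.json: {e}")

    if not os.path.isfile(os.path.join(export_dir, "adapter_model.safetensors")):
        problems.append("adapter_model.safetensors not found")
    return problems

# Usage: problems = check_export_dir(KAGGLE_DIR); print(problems or "Export looks complete")
```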

## 🎉 Conclusion

This notebook demonstrated:
1. ✅ TPU setup and Gemma 3 model loading
2. ✅ XML-formatted dataset preparation for legal reasoning
3. ✅ Custom reward function with format, length, and correctness components
4. ✅ GRPO training configuration with LoRA adapters
5. ✅ Export and Kaggle submission preparation

### Next Steps

1. **Adapt Tunix API**: Replace the placeholder training code with calls to the actual google-tunix API
2. **Load Real Data**: Replace synthetic examples with actual legal datasets
3. **Run Training**: Execute GRPO training on TPU
4. **Validate Outputs**: Test model quality and XML format compliance
5. **Submit to Kaggle**: Upload trained adapters

### Resources

- [Judicaita Repository](https://github.com/clduab11/judicAIta)
- [Google Tunix Documentation](https://github.com/google/tunix)
- [Gemma Model Cards](https://ai.google.dev/gemma/docs)
- [TPU Best Practices](https://cloud.google.com/tpu/docs/best-practices)

### Feedback

If you encounter issues or have improvements:
- Open an issue: https://github.com/clduab11/judicAIta/issues
- Submit a PR with fixes or enhancements

---

**Made with ❤️ for the Kaggle hackathon and legal tech community**