# TT-10: Validation-Focused Training with Example-Based Learning

This notebook implements a validation-focused approach where the model is trained on examples (like test-time training) and validated on real comments with known labels.

**Key Concept:**
- **Training**: Model learns from positive/negative examples (not actual comments)
- **Validation**: Model predicts on real `body` comments with `rule_violation` labels
- **Analysis**: Comprehensive metrics to understand generalization from examples to real data
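
The example-based setup above can be illustrated with a minimal sketch of how one `train.csv`-style row expands into four training rows. It mirrors the idea behind `get_example_based_training_data` in `utils.py`, but the `expand_row` helper and the sample row are illustrative assumptions, not the notebook's own code.

```python
import random


def expand_row(row, rng):
    """Turn one train.csv-style row into four example-based training rows."""
    out = []
    for label, kind, other in ((1, "positive", "negative"), (0, "negative", "positive")):
        for i in (1, 2):
            out.append({
                "rule": row["rule"],
                "subreddit": row["subreddit"],
                # The example being classified plays the role of the comment.
                "body": row[f"{kind}_example_{i}"],
                # The sibling example of the same kind is shown in the prompt.
                f"{kind}_example": row[f"{kind}_example_{3 - i}"],
                # The opposite-kind example is picked at random.
                f"{other}_example": row[f"{other}_example_{rng.choice((1, 2))}"],
                "rule_violation": label,
            })
    return out


# Hypothetical row, made up for illustration only.
row = {
    "rule": "No advertising",
    "subreddit": "example",
    "positive_example_1": "Buy cheap watches here",
    "positive_example_2": "Visit my shop for discounts",
    "negative_example_1": "I liked this article",
    "negative_example_2": "Thanks for sharing",
}
rows = expand_row(row, random.Random(42))
for r in rows:
    print(r["rule_violation"], "|", r["body"])
```

Each example is classified exactly once, and it never appears in its own prompt as a demonstration, because the sibling example of the same kind is used instead.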

**Features:**
- **Stratified Sampling**: Controllable % of training data while maintaining rule distribution
- **Example-Based Training**: Similar to the test-time training approach
- **Real Comment Validation**: Test on actual comments with ground truth labels
- **Comprehensive Metrics**: AUC, F1, Recall, Precision, Confusion Matrix
- **Visualizations**: Performance plots and analysis
- **4-bit + LoRA**: Memory-efficient training without DoRA for vLLM compatibility

**Benefits:**
- **Understand Generalization**: How well example-based training transfers to real comments
- **Validate Test-Time Training**: Effectiveness of the example-learning approach
- **Performance Analysis**: Detailed metrics with controllable data amounts

In [None]:
# Install dependencies - 4-bit BitsAndBytes + LoRA + Validation setup
!uv pip install --system --no-index --find-links='/kaggle/input/jigsaw-packages2/whls/' 'trl==0.21.0' 'optimum==1.27.0' 'bitsandbytes==0.46.1' 'deepspeed==0.17.4' 'logits-processor-zoo==0.2.1' 'vllm==0.10.0'
!uv pip install --system --no-index --find-links='/kaggle/input/jigsaw-packages2/whls/' 'triton==3.2.0'
!uv pip install --system --no-index --find-links='/kaggle/input/jigsaw-packages2/whls/' 'clean-text'
# Install latest PEFT for LoRA support (no DoRA for vLLM compatibility)
!uv pip install --system --no-index -U --no-deps --find-links='/kaggle/input/jigsaw-packages2/whls/' 'peft' 'accelerate' 'datasets'
# Install analysis libraries
!uv pip install --system --no-index --find-links='/kaggle/input/jigsaw-packages2/whls/' 'scikit-learn' 'matplotlib' 'seaborn'

print("✅ Dependencies installed for TT-10 validation-focused training")
print("📊 Analysis libraries: scikit-learn, matplotlib, seaborn")

# 1. Configuration and Data Setup

In [None]:
%%writefile constants.py
# Using base Qwen3 1.7B model from Kaggle input (no internet needed)
BASE_MODEL_PATH = "/kaggle/input/qwen-3/transformers/1.7b/1"  # Update this path as needed
LORA_PATH = "qwen3_1.7b_4bit_lora_validation/"  # 4-bit LoRA output path for validation
DATA_PATH = "/kaggle/input/jigsaw-agile-community-rules/"

# TT-10 Validation Parameters
TRAINING_DATA_PERCENTAGE = 1.0  # Controllable % of training data (0.1 = 10%, 1.0 = 100%)
USE_STRATIFIED_SAMPLING = True  # Maintain rule distribution when sampling

POSITIVE_ANSWER = "Yes"
NEGATIVE_ANSWER = "No"
COMPLETE_PHRASE = "Answer:"
BASE_PROMPT = '''You are given a comment from reddit and a rule. Your task is to classify whether the comment violates the rule. Only respond Yes/No.'''

print("✅ Using Qwen3 1.7B model from local Kaggle input")
print(f"🎯 TT-10: Example-based training with {TRAINING_DATA_PERCENTAGE*100:.0f}% of data")
print(f"📊 Stratified sampling: {USE_STRATIFIED_SAMPLING}")

In [None]:
%%writefile utils.py
import pandas as pd
from datasets import Dataset
from constants import POSITIVE_ANSWER, NEGATIVE_ANSWER, COMPLETE_PHRASE, BASE_PROMPT, TRAINING_DATA_PERCENTAGE, USE_STRATIFIED_SAMPLING
import random, numpy as np
from sklearn.model_selection import train_test_split
random.seed(42)
np.random.seed(42)


def build_prompt(row):
    return f"""
{BASE_PROMPT}

Subreddit: r/{row["subreddit"]}
Rule: {row["rule"]}
Examples:
1) {row["positive_example"]}
{COMPLETE_PHRASE} Yes

2) {row["negative_example"]}
{COMPLETE_PHRASE} No

---
Comment: {row["body"]}
{COMPLETE_PHRASE}"""


def get_example_based_training_data(data_path):
    """
    TT-10: Create training data from examples (like test-time training)
    This trains the model on examples, not actual comments
    """
    train_dataset = pd.read_csv(f"{data_path}/train.csv")
    
    # Sample data if needed while maintaining rule distribution
    if TRAINING_DATA_PERCENTAGE < 1.0:
        if USE_STRATIFIED_SAMPLING:
            # Stratified sampling to maintain rule distribution
            train_dataset = train_dataset.groupby('rule', group_keys=False).apply(
                lambda x: x.sample(frac=TRAINING_DATA_PERCENTAGE, random_state=42)
            ).reset_index(drop=True)
            print(f"📊 Stratified sampling: {len(train_dataset)} samples ({TRAINING_DATA_PERCENTAGE*100:.0f}%)")
        else:
            # Simple random sampling
            train_dataset = train_dataset.sample(frac=TRAINING_DATA_PERCENTAGE, random_state=42).reset_index(drop=True)
            print(f"📊 Random sampling: {len(train_dataset)} samples ({TRAINING_DATA_PERCENTAGE*100:.0f}%)")
    
    print(f"📊 Training data size: {len(train_dataset)} samples")
    print(f"📊 Rule distribution: {train_dataset['rule'].value_counts().to_dict()}")
    
    flatten = []
    
    # Create training data from examples (similar to test-time training)
    for violation_type in ["positive", "negative"]:
        for i in range(1, 3):
            sub_dataset = train_dataset[["rule","subreddit",
                                        "positive_example_1","positive_example_2",
                                        "negative_example_1","negative_example_2"]].copy()

            if violation_type == "positive":
                # Use positive example as the "body" to classify
                body_col = f"positive_example_{i}"
                other_positive_col = f"positive_example_{3-i}"  # other positive
                sub_dataset["body"] = sub_dataset[body_col]
                sub_dataset["positive_example"] = sub_dataset[other_positive_col]
                # negative_example randomly selected
                sub_dataset["negative_example"] = np.where(
                    np.random.rand(len(sub_dataset)) < 0.5,
                    sub_dataset["negative_example_1"],
                    sub_dataset["negative_example_2"]
                )
                sub_dataset["rule_violation"] = 1  # Positive examples violate rules

            else:  # violation_type == "negative"
                # Use negative example as the "body" to classify
                body_col = f"negative_example_{i}"
                other_negative_col = f"negative_example_{3-i}"
                sub_dataset["body"] = sub_dataset[body_col]
                sub_dataset["negative_example"] = sub_dataset[other_negative_col]
                sub_dataset["positive_example"] = np.where(
                    np.random.rand(len(sub_dataset)) < 0.5,
                    sub_dataset["positive_example_1"],
                    sub_dataset["positive_example_2"]
                )
                sub_dataset["rule_violation"] = 0  # Negative examples don't violate rules

            # Drop original candidate columns
            sub_dataset.drop(columns=["positive_example_1","positive_example_2",
                                      "negative_example_1","negative_example_2"], inplace=True)

            flatten.append(sub_dataset)

    # Merge all DataFrames
    example_training_df = pd.concat(flatten, axis=0)
    example_training_df = example_training_df.drop_duplicates(ignore_index=True)
    
    print(f"📊 Example-based training dataset: {len(example_training_df)} samples")
    print(f"📊 Positive examples: {sum(example_training_df['rule_violation'] == 1)}")
    print(f"📊 Negative examples: {sum(example_training_df['rule_violation'] == 0)}")
    
    return example_training_df


def get_real_comment_validation_data(data_path):
    """
    TT-10: Get real comments with labels for validation
    This is what we actually want to predict
    """
    train_dataset = pd.read_csv(f"{data_path}/train.csv")
    
    # Use actual comments and their labels for validation
    validation_df = train_dataset[["body", "rule", "subreddit", "rule_violation",
                                  "positive_example_1","positive_example_2",
                                  "negative_example_1","negative_example_2"]].copy()

    # Randomly select positive_example and negative_example for prompts
    validation_df["positive_example"] = np.where(
        np.random.rand(len(validation_df)) < 0.5,
        validation_df["positive_example_1"],
        validation_df["positive_example_2"]
    )
    validation_df["negative_example"] = np.where(
        np.random.rand(len(validation_df)) < 0.5,
        validation_df["negative_example_1"],
        validation_df["negative_example_2"]
    )

    # Drop original candidate columns
    validation_df.drop(columns=["positive_example_1","positive_example_2",
                               "negative_example_1","negative_example_2"], inplace=True)
    
    print(f"📊 Real comment validation dataset: {len(validation_df)} samples")
    print(f"📊 Rule violations: {sum(validation_df['rule_violation'] == 1)} positive, {sum(validation_df['rule_violation'] == 0)} negative")
    
    return validation_df


def build_dataset(dataframe):
    dataframe["prompt"] = dataframe.apply(build_prompt, axis=1)

    columns = ["prompt"]
    if "rule_violation" in dataframe:
        dataframe["completion"] = dataframe["rule_violation"].map(
            {
                1: POSITIVE_ANSWER,
                0: NEGATIVE_ANSWER,
            }
        )
        columns.append("completion")

    dataframe = dataframe[columns]
    dataset = Dataset.from_pandas(dataframe)
    return dataset


def build_validation_dataset(dataframe):
    """Build dataset for validation (keep labels for evaluation)"""
    dataframe["prompt"] = dataframe.apply(build_prompt, axis=1)
    dataframe = dataframe[["prompt", "rule_violation"]]  # Keep true labels for evaluation
    dataset = Dataset.from_pandas(dataframe)
    return dataset

In [None]:
%%writefile train.py
import pandas as pd
import torch
from trl import SFTTrainer, SFTConfig
from peft import LoraConfig
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from tqdm.auto import tqdm
from transformers.utils import is_torch_bf16_gpu_available
from utils import build_dataset, get_example_based_training_data
from constants import DATA_PATH, BASE_MODEL_PATH, LORA_PATH


def main():
    # TT-10: Get example-based training data (train on examples, not real comments)
    train_df = get_example_based_training_data(DATA_PATH)
    train_dataset = build_dataset(train_df)
    
    print(f"Training dataset size: {len(train_dataset)} samples")
    
    # BitsAndBytes 4-bit quantization config
    quantization_config = BitsAndBytesConfig(
        load_in_4bit=True,  # Enable 4-bit quantization
        bnb_4bit_compute_dtype=torch.float16,  # Compute in FP16
        bnb_4bit_use_double_quant=True,  # Use double quantization for better quality
        bnb_4bit_quant_type="nf4"  # Use NF4 quantization
    )
    print("✅ BitsAndBytes 4-bit quantization config created")
    
    # LoRA configuration (no DoRA for vLLM compatibility)
    lora_config = LoraConfig(
        r=16,  # LoRA rank
        lora_alpha=32,  # LoRA alpha  
        lora_dropout=0.05,  # LoRA dropout
        bias="none",
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
        task_type="CAUSAL_LM",
        # No use_dora=True for vLLM compatibility
    )
    print("✅ LoRA config created (no DoRA for vLLM compatibility)")
    
    # Training config optimized for validation
    training_args = SFTConfig(
        num_train_epochs=1,  # Single epoch for validation
        
        # Batch sizes for 4-bit training
        per_device_train_batch_size=4,
        gradient_accumulation_steps=4,  # Effective batch size = 4*4*2 = 32
        
        optim="paged_adamw_8bit",  # 8-bit optimizer for memory efficiency
        learning_rate=1e-4,  # Learning rate
        weight_decay=0.01,
        max_grad_norm=1.0,
        
        lr_scheduler_type="cosine",
        warmup_ratio=0.03,
        
        bf16=is_torch_bf16_gpu_available(),
        fp16=not is_torch_bf16_gpu_available(),
        dataloader_pin_memory=True,
        
        gradient_checkpointing=True,
        gradient_checkpointing_kwargs={"use_reentrant": False},
    
        save_strategy="no",  # Don't save during validation training
        report_to="none",
    
        completion_only_loss=True,
        packing=False,
        remove_unused_columns=False,
    )
    print("✅ Training config created for example-based learning")
    
    # Load model with BitsAndBytes quantization
    base_model = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL_PATH,
        quantization_config=quantization_config,
        torch_dtype=torch.float16,
        # device_map is deliberately left unset to avoid conflicts with DeepSpeed/Accelerate device placement
        trust_remote_code=True,
        local_files_only=True,  # Use only local files (no internet)
    )
    print("✅ Base model loaded with 4-bit quantization")
    
    # Create SFTTrainer
    trainer = SFTTrainer(
        model=base_model,  # Pass loaded model directly
        args=training_args,
        train_dataset=train_dataset,
        peft_config=lora_config,
    )
    
    print("🚀 Starting example-based training (like test-time training)...")
    trainer.train()
    
    # Save LoRA adapters
    trainer.save_model(LORA_PATH)
    print(f"✅ LoRA adapters saved to: {LORA_PATH}")


if __name__ == "__main__":
    main()

In [None]:
%%writefile validation.py
import os
os.environ["VLLM_USE_V1"] = "0"

import vllm
import torch
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import (accuracy_score, f1_score, precision_score, recall_score, 
                           roc_auc_score, confusion_matrix, classification_report, roc_curve)
from logits_processor_zoo.vllm import MultipleChoiceLogitsProcessor
from vllm.lora.request import LoRARequest
from utils import build_validation_dataset, get_real_comment_validation_data
from constants import BASE_MODEL_PATH, LORA_PATH, DATA_PATH, POSITIVE_ANSWER, NEGATIVE_ANSWER


def run_validation():
    """Run validation on real comments using example-trained model"""
    
    # Get real comment validation data
    val_df = get_real_comment_validation_data(DATA_PATH)
    val_dataset = build_validation_dataset(val_df)
    
    print(f"🔍 Running validation on {len(val_dataset)} real comments")
    
    # Initialize vLLM with LoRA
    llm = vllm.LLM(
        BASE_MODEL_PATH,
        tensor_parallel_size=1,
        gpu_memory_utilization=0.95,
        trust_remote_code=True,
        dtype="half",
        enforce_eager=True,
        max_model_len=2836,
        disable_log_stats=True,
        enable_prefix_caching=True,
        enable_lora=True,
        max_lora_rank=64,
    )

    tokenizer = llm.get_tokenizer()
    mclp = MultipleChoiceLogitsProcessor(tokenizer, choices=[POSITIVE_ANSWER, NEGATIVE_ANSWER])

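    # Materialize as plain lists; newer `datasets` releases may return a lazy column object here
    texts = list(val_dataset["prompt"])
    true_labels = list(val_dataset["rule_violation"])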

    # Generate predictions
    outputs = llm.generate(
        texts,
        vllm.SamplingParams(
            skip_special_tokens=True,
            max_tokens=1,
            logits_processors=[mclp],
            logprobs=2,  # Get log probabilities for AUC calculation
        ),
        use_tqdm=True,
        lora_request=LoRARequest("default", 1, LORA_PATH)
    )

    # Extract predictions and probabilities
    predictions = []
    probabilities = []  # For AUC calculation
    
    for i, out in enumerate(outputs):
        log_probs = {lp.decoded_token: lp.logprob for lp in out.outputs[0].logprobs[0].values()}
        
        # Get prediction (highest probability)
        if POSITIVE_ANSWER in log_probs and NEGATIVE_ANSWER in log_probs:
            if log_probs[POSITIVE_ANSWER] > log_probs[NEGATIVE_ANSWER]:
                predictions.append(1)
            else:
                predictions.append(0)
            
            # Calculate probability for positive class (for AUC)
            exp_pos = np.exp(log_probs[POSITIVE_ANSWER])
            exp_neg = np.exp(log_probs[NEGATIVE_ANSWER])
            prob_positive = exp_pos / (exp_pos + exp_neg)
            probabilities.append(prob_positive)
        else:
            # Fallback if logprobs not available
            predictions.append(0)
            probabilities.append(0.5)

    return true_labels, predictions, probabilities, val_df


def calculate_and_display_metrics(true_labels, predictions, probabilities):
    """Calculate comprehensive metrics and display results"""
    
    # Basic metrics
    accuracy = accuracy_score(true_labels, predictions)
    f1 = f1_score(true_labels, predictions)
    precision = precision_score(true_labels, predictions)
    recall = recall_score(true_labels, predictions)
    auc = roc_auc_score(true_labels, probabilities)
    
    print("=" * 60)
    print("📊 TT-10 VALIDATION RESULTS")
    print("=" * 60)
    print(f"🎯 Accuracy:  {accuracy:.4f}")
    print(f"🎯 F1 Score:  {f1:.4f}")
    print(f"🎯 Precision: {precision:.4f}")
    print(f"🎯 Recall:    {recall:.4f}")
    print(f"🎯 AUC Score: {auc:.4f}")
    print("=" * 60)
    
    # Confusion matrix
    cm = confusion_matrix(true_labels, predictions)
    print("\n📈 Confusion Matrix:")
    print(f"True Negative: {cm[0,0]:4d} | False Positive: {cm[0,1]:4d}")
    print(f"False Negative: {cm[1,0]:4d} | True Positive:  {cm[1,1]:4d}")
    
    # Classification report
    print("\n📋 Classification Report:")
    print(classification_report(true_labels, predictions, target_names=['No Violation', 'Violation']))
    
    return {
        'accuracy': accuracy,
        'f1': f1,
        'precision': precision,
        'recall': recall,
        'auc': auc,
        'confusion_matrix': cm
    }


def create_visualizations(true_labels, predictions, probabilities, metrics):
    """Create comprehensive visualizations"""
    
    fig, axes = plt.subplots(2, 2, figsize=(15, 12))
    fig.suptitle('TT-10: Example-Based Training Validation Results', fontsize=16, fontweight='bold')
    
    # 1. Confusion Matrix Heatmap
    cm = metrics['confusion_matrix']
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', ax=axes[0,0],
                xticklabels=['No Violation', 'Violation'],
                yticklabels=['No Violation', 'Violation'])
    axes[0,0].set_title('Confusion Matrix')
    axes[0,0].set_xlabel('Predicted')
    axes[0,0].set_ylabel('Actual')
    
    # 2. ROC Curve
    fpr, tpr, _ = roc_curve(true_labels, probabilities)
    axes[0,1].plot(fpr, tpr, linewidth=2, label=f'ROC Curve (AUC = {metrics["auc"]:.3f})')
    axes[0,1].plot([0, 1], [0, 1], 'k--', linewidth=1, label='Random Classifier')
    axes[0,1].set_xlabel('False Positive Rate')
    axes[0,1].set_ylabel('True Positive Rate')
    axes[0,1].set_title('ROC Curve')
    axes[0,1].legend()
    axes[0,1].grid(True, alpha=0.3)
    
    # 3. Probability Distribution
    pos_probs = [probabilities[i] for i in range(len(probabilities)) if true_labels[i] == 1]
    neg_probs = [probabilities[i] for i in range(len(probabilities)) if true_labels[i] == 0]
    
    axes[1,0].hist(neg_probs, bins=30, alpha=0.7, label='No Violation', color='blue', density=True)
    axes[1,0].hist(pos_probs, bins=30, alpha=0.7, label='Violation', color='red', density=True)
    axes[1,0].set_xlabel('Predicted Probability')
    axes[1,0].set_ylabel('Density')
    axes[1,0].set_title('Probability Distribution by True Label')
    axes[1,0].legend()
    axes[1,0].grid(True, alpha=0.3)
    
    # 4. Metrics Bar Chart
    metric_names = ['Accuracy', 'F1 Score', 'Precision', 'Recall', 'AUC']
    metric_values = [metrics['accuracy'], metrics['f1'], metrics['precision'], metrics['recall'], metrics['auc']]
    
    bars = axes[1,1].bar(metric_names, metric_values, color=['skyblue', 'lightgreen', 'orange', 'pink', 'gold'])
    axes[1,1].set_ylabel('Score')
    axes[1,1].set_title('Performance Metrics')
    axes[1,1].set_ylim(0, 1)
    axes[1,1].grid(True, alpha=0.3, axis='y')
    
    # Add value labels on bars
    for bar, value in zip(bars, metric_values):
        axes[1,1].text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.01,
                      f'{value:.3f}', ha='center', va='bottom', fontweight='bold')
    
    plt.tight_layout()
    plt.savefig('/kaggle/working/tt10_validation_results.png', dpi=300, bbox_inches='tight')
    plt.show()


def analyze_by_rule(true_labels, predictions, probabilities, val_df):
    """Analyze performance by rule type"""
    
    # Add predictions to dataframe
    analysis_df = val_df.copy()
    analysis_df['predictions'] = predictions
    analysis_df['probabilities'] = probabilities
    
    print("\n📊 PERFORMANCE BY RULE:")
    print("=" * 60)
    
    rule_metrics = []
    for rule in analysis_df['rule'].unique():
        rule_data = analysis_df[analysis_df['rule'] == rule]
        
        rule_true = rule_data['rule_violation'].values
        rule_pred = rule_data['predictions'].values
        rule_prob = rule_data['probabilities'].values
        
        if len(np.unique(rule_true)) > 1:  # Check if both classes exist
            rule_auc = roc_auc_score(rule_true, rule_prob)
        else:
            rule_auc = np.nan
            
        rule_acc = accuracy_score(rule_true, rule_pred)
        rule_f1 = f1_score(rule_true, rule_pred) if len(np.unique(rule_true)) > 1 else np.nan
        
        print(f"Rule: {rule}")
        print(f"  Samples: {len(rule_data)}")
        print(f"  Accuracy: {rule_acc:.3f}")
        print(f"  F1 Score: {rule_f1:.3f}" if not np.isnan(rule_f1) else "  F1 Score: N/A")
        print(f"  AUC Score: {rule_auc:.3f}" if not np.isnan(rule_auc) else "  AUC Score: N/A")
        print()
        
        rule_metrics.append({
            'rule': rule,
            'samples': len(rule_data),
            'accuracy': rule_acc,
            'f1': rule_f1,
            'auc': rule_auc
        })
    
    # Save detailed results
    analysis_df.to_csv('/kaggle/working/tt10_detailed_results.csv', index=False)
    pd.DataFrame(rule_metrics).to_csv('/kaggle/working/tt10_rule_metrics.csv', index=False)
    
    return rule_metrics


def main():
    print("🔬 TT-10: Example-Based Training Validation")
    print("📚 Training: Model learned from examples (like test-time training)")
    print("🧪 Validation: Testing on real comments with ground truth labels")
    print("=" * 70)
    
    # Run validation
    true_labels, predictions, probabilities, val_df = run_validation()
    
    # Calculate metrics
    metrics = calculate_and_display_metrics(true_labels, predictions, probabilities)
    
    # Create visualizations
    create_visualizations(true_labels, predictions, probabilities, metrics)
    
    # Analyze by rule
    rule_metrics = analyze_by_rule(true_labels, predictions, probabilities, val_df)
    
    print("✅ Validation completed!")
    print("📈 Visualizations saved: /kaggle/working/tt10_validation_results.png")
    print("📊 Detailed results: /kaggle/working/tt10_detailed_results.csv")
    print("📋 Rule metrics: /kaggle/working/tt10_rule_metrics.csv")
    
    return metrics, rule_metrics


if __name__ == "__main__":
    main()

In [None]:
%%writefile accelerate_config.yaml
compute_environment: LOCAL_MACHINE
debug: false
deepspeed_config:
  gradient_accumulation_steps: 4
  gradient_clipping: 1.0
  train_batch_size: 32
  train_micro_batch_size_per_gpu: 4
  
  zero_stage: 2
  offload_optimizer_device: none
  offload_param_device: none
  zero3_init_flag: false
  
  stage3_gather_16bit_weights_on_model_save: false
  stage3_max_live_parameters: 1e8
  stage3_max_reuse_distance: 1e8
  stage3_prefetch_bucket_size: 5e7
  stage3_param_persistence_threshold: 1e5
  
  zero_allow_untested_optimizer: true
  zero_force_ds_cpu_optimizer: false
  
  fp16:
    enabled: true
    loss_scale: 0
    initial_scale_power: 16
    loss_scale_window: 1000
    hysteresis: 2
    min_loss_scale: 1
  
distributed_type: DEEPSPEED
downcast_bf16: 'no'
dynamo_config:
  dynamo_backend: INDUCTOR
  dynamo_use_fullgraph: false
  dynamo_use_dynamic: false
enable_cpu_affinity: false
machine_rank: 0
main_training_function: main
mixed_precision: fp16
num_machines: 1
num_processes: 2
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false

In [None]:
!accelerate launch --config_file accelerate_config.yaml train.py

In [None]:
!python validation.py

In [None]:
# Display saved results
import pandas as pd
import matplotlib.pyplot as plt

# Load detailed results
try:
    detailed_results = pd.read_csv('/kaggle/working/tt10_detailed_results.csv')
    print("📊 Detailed Results Shape:", detailed_results.shape)
    print("\n📋 Sample Results:")
    print(detailed_results[['rule', 'rule_violation', 'predictions', 'probabilities']].head(10))
    
    # Load rule metrics
    rule_metrics = pd.read_csv('/kaggle/working/tt10_rule_metrics.csv')
    print("\n📈 Rule-wise Performance:")
    print(rule_metrics)
    
except FileNotFoundError as e:
    print(f"❌ Results files not found: {e}")
    print("Run the validation cell first to generate results.")

In [None]:
# Additional analysis - data distribution and performance insights
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import accuracy_score

try:
    detailed_results = pd.read_csv('/kaggle/working/tt10_detailed_results.csv')
    
    # Analyze performance by confidence level
    print("🎯 Performance Analysis by Confidence Level:")
    print("=" * 50)
    
    # Create confidence bins
    detailed_results['confidence'] = np.abs(detailed_results['probabilities'] - 0.5) * 2  # 0 = least confident, 1 = most confident
    detailed_results['confidence_bin'] = pd.cut(detailed_results['confidence'], 
                                               bins=[0, 0.3, 0.6, 1.0], 
                                               labels=['Low', 'Medium', 'High'],
                                               include_lowest=True)  # keep rows whose confidence is exactly 0
    
    # Calculate accuracy by confidence bin
    confidence_analysis = detailed_results.groupby('confidence_bin').agg({
        'rule_violation': 'count',
        'predictions': lambda x: accuracy_score(detailed_results.loc[x.index, 'rule_violation'], x)
    }).rename(columns={'rule_violation': 'sample_count', 'predictions': 'accuracy'})
    
    print(confidence_analysis)
    
    # Data distribution analysis
    print("\n📊 Data Distribution Analysis:")
    print("=" * 50)
    print("Overall rule violation distribution:")
    print(detailed_results['rule_violation'].value_counts(normalize=True))
    
    print("\nRule violation distribution by rule:")
    rule_dist = detailed_results.groupby('rule')['rule_violation'].agg(['count', 'mean'])
    rule_dist.columns = ['total_samples', 'violation_rate']
    print(rule_dist)
    
except FileNotFoundError:
    print("❌ Run validation first to generate analysis data.")
except Exception as e:
    print(f"❌ Analysis error: {e}")

# 📊 TT-10 Analysis Guide

## 🎯 **What TT-10 Tests:**
- **Generalization**: How well example-based training transfers to real comments
- **Effectiveness**: Whether the test-time training approach works for this task
- **Performance**: Comprehensive metrics on real comment classification

## 🔧 **How to Adjust Training Data:**

### **Change Data Percentage** (Cell 4 - `constants.py`):
```python
TRAINING_DATA_PERCENTAGE = 0.5  # Use 50% of training data
TRAINING_DATA_PERCENTAGE = 0.1  # Use 10% of training data
TRAINING_DATA_PERCENTAGE = 1.0  # Use 100% of training data (default)
```

### **Toggle Stratified Sampling** (Cell 4 - `constants.py`):
```python
USE_STRATIFIED_SAMPLING = True   # Maintain rule distribution (recommended)
USE_STRATIFIED_SAMPLING = False  # Random sampling
```
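
As a rough illustration of what stratified sampling preserves, the sketch below samples a fixed fraction from each rule group using only the standard library. The `stratified_sample` helper and the toy rule names are assumptions made for this example; the notebook itself does this with `groupby('rule')` and `sample(frac=...)` in pandas.

```python
import random
from collections import Counter, defaultdict


def stratified_sample(rows, key, frac, seed=42):
    """Sample the same fraction from every group so group proportions are kept."""
    rng = random.Random(seed)
    groups = defaultdict(list)
    for r in rows:
        groups[r[key]].append(r)
    sample = []
    for _, members in sorted(groups.items()):
        k = round(len(members) * frac)
        sample.extend(rng.sample(members, k))
    return sample


# Toy data: an 80/20 split between two hypothetical rules.
rows = [{"rule": "no_ads", "id": i} for i in range(80)]
rows += [{"rule": "no_legal_advice", "id": i} for i in range(20)]

sample = stratified_sample(rows, "rule", 0.5)
counts = Counter(r["rule"] for r in sample)
print(counts)  # the 80/20 ratio is kept: 40 and 10
```

With plain random sampling the per-rule counts would fluctuate around these values instead of matching them exactly, which matters most for rules with few rows.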

## 📈 **Understanding Results:**

### **Key Metrics:**
- **AUC Score**: Most important - measures discrimination ability (0.5 = random, 1.0 = perfect)
- **F1 Score**: Balance of precision and recall
- **Accuracy**: Overall correctness
- **Confusion Matrix**: Detailed breakdown of correct/incorrect predictions
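
To make the metric definitions concrete, here is a small standard-library sketch that computes precision, recall, F1 and AUC by hand on toy data. The helper names and the toy labels and scores are assumptions for illustration; `validation.py` uses the scikit-learn functions instead. AUC is computed as the fraction of (positive, negative) pairs that the score ranks correctly, with ties counted as half.

```python
def binary_metrics(y_true, y_pred):
    """Precision, recall and F1 for the positive class (label 1)."""
    tp = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 1)
    fp = sum(1 for t, p in zip(y_true, y_pred) if t == 0 and p == 1)
    fn = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 0)
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return precision, recall, f1


def auc_score(y_true, scores):
    """AUC as the share of positive/negative pairs ranked in the right order."""
    pos = [s for t, s in zip(y_true, scores) if t == 1]
    neg = [s for t, s in zip(y_true, scores) if t == 0]
    wins = sum(1.0 if p > n else 0.5 if p == n else 0.0 for p in pos for n in neg)
    return wins / (len(pos) * len(neg))


# Toy labels and predicted probabilities of a violation.
y_true = [1, 1, 1, 0, 0, 0]
scores = [0.9, 0.8, 0.4, 0.6, 0.3, 0.1]
y_pred = [1 if s > 0.5 else 0 for s in scores]

precision, recall, f1 = binary_metrics(y_true, y_pred)
auc = auc_score(y_true, scores)
print(f"precision={precision:.3f} recall={recall:.3f} f1={f1:.3f} auc={auc:.3f}")
```

Note that AUC depends only on how the scores rank the samples, while precision, recall and F1 depend on the 0.5 threshold used to turn scores into predictions.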

### **Visualizations Generated:**
1. **Confusion Matrix**: Shows prediction accuracy breakdown
2. **ROC Curve**: Illustrates true vs false positive rates
3. **Probability Distribution**: How confident the model is by true label
4. **Metrics Bar Chart**: Visual comparison of all performance metrics
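
The probabilities plotted above come from renormalizing the log-probabilities of the two answer tokens, as `validation.py` does. The sketch below shows that calculation in isolation; the `positive_probability` name and the max-subtraction step for numerical stability are additions made for this example.

```python
import math


def positive_probability(logprob_yes, logprob_no):
    """Two-way softmax over the 'Yes' and 'No' log-probabilities."""
    # Subtracting the maximum keeps exp() from underflowing to zero
    # when both log-probabilities are very negative.
    m = max(logprob_yes, logprob_no)
    e_yes = math.exp(logprob_yes - m)
    e_no = math.exp(logprob_no - m)
    return e_yes / (e_yes + e_no)


# If the model puts 0.6 on "Yes" and 0.2 on "No", the renormalized
# probability of a violation is 0.6 / (0.6 + 0.2).
p = positive_probability(math.log(0.6), math.log(0.2))
print(round(p, 4))
```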

### **Rule-wise Analysis**:
- Performance broken down by individual rules
- Identifies which rules are easier/harder to learn
- Shows sample distribution across rules

## 🚀 **Optimization Tips:**

### **If Performance is Low:**
1. **Increase Training Data**: Set `TRAINING_DATA_PERCENTAGE = 1.0`
2. **Adjust LoRA Parameters**: Increase rank (`r=32`) in `train.py`
3. **More Training**: Increase `num_train_epochs` in `train.py`

### **If Training is Too Slow:**
1. **Reduce Data**: Set `TRAINING_DATA_PERCENTAGE = 0.3`
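2. **Smaller Batches**: Reduce `per_device_train_batch_size` in `train.py` (this mainly relieves memory pressure; keep `train_micro_batch_size_per_gpu` and `train_batch_size` in `accelerate_config.yaml` in sync)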
3. **Lower Rank**: Reduce LoRA rank (`r=8`) in `train.py`
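
When changing batch settings, the values in `train.py` and `accelerate_config.yaml` should agree. The sketch below checks that relationship with the numbers used in this notebook (batch size 4, accumulation 4, 2 processes). The helper functions and the dictionary layout are illustrative assumptions, not an API of DeepSpeed or Accelerate.

```python
def effective_batch_size(per_device, grad_accum, num_processes):
    """Samples consumed per optimizer step across all processes."""
    return per_device * grad_accum * num_processes


def configs_consistent(train_args, ds_config, num_processes):
    """Check that trainer settings match the DeepSpeed section of the config."""
    expected = effective_batch_size(
        train_args["per_device_train_batch_size"],
        train_args["gradient_accumulation_steps"],
        num_processes,
    )
    return (
        ds_config["train_micro_batch_size_per_gpu"] == train_args["per_device_train_batch_size"]
        and ds_config["gradient_accumulation_steps"] == train_args["gradient_accumulation_steps"]
        and ds_config["train_batch_size"] == expected
    )


train_args = {"per_device_train_batch_size": 4, "gradient_accumulation_steps": 4}
ds_config = {
    "train_micro_batch_size_per_gpu": 4,
    "gradient_accumulation_steps": 4,
    "train_batch_size": 32,
}

total = effective_batch_size(4, 4, 2)
ok = configs_consistent(train_args, ds_config, num_processes=2)
print(total, ok)

# Halving the per-device batch size without touching the config breaks the match.
smaller = dict(train_args, per_device_train_batch_size=2)
print(configs_consistent(smaller, ds_config, num_processes=2))
```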

## 💡 **Key Insights:**
- **High AUC (>0.8)**: Example-based training works well
- **Low AUC (<0.6)**: May need more data or different approach
- **Rule Variation**: Some rules may be inherently harder to learn
- **Confidence Analysis**: Higher confidence predictions should be more accurate
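
The confidence analysis can be sketched without pandas as follows. Confidence is defined as in the analysis cell, `abs(p - 0.5) * 2`, and binned at 0.3 and 0.6; the helper names and toy data are assumptions for illustration.

```python
from collections import defaultdict


def confidence(prob):
    """0 = least confident (p = 0.5), 1 = most confident (p = 0 or 1)."""
    return abs(prob - 0.5) * 2


def confidence_bin(conf):
    if conf <= 0.3:
        return "Low"
    if conf <= 0.6:
        return "Medium"
    return "High"


def accuracy_by_bin(y_true, probs):
    """Return {bin: (sample_count, accuracy)} using a 0.5 decision threshold."""
    hits = defaultdict(list)
    for t, p in zip(y_true, probs):
        pred = 1 if p > 0.5 else 0
        hits[confidence_bin(confidence(p))].append(pred == t)
    return {b: (len(v), sum(v) / len(v)) for b, v in hits.items()}


# Toy labels and predicted probabilities of a violation.
y_true = [1, 0, 1, 0, 1, 0]
probs = [0.875, 0.125, 0.75, 0.75, 0.5625, 0.5625]

result = accuracy_by_bin(y_true, probs)
for name in ("Low", "Medium", "High"):
    print(name, result[name])
```

If accuracy does not rise from the Low bin to the High bin on real data, the model's probabilities are poorly calibrated even if its AUC looks acceptable.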

This validation approach helps understand whether the test-time training methodology is effective for your specific classification task!

# 🚀 TT-10 Speed Optimization Guide

## ⚡ **Training Speed Optimization Strategies**

### **1. Alternative Libraries for Faster Training**

#### **A. Unsloth - 2x-5x Faster Training**
```python
# Swap the TRL + PEFT model loading for Unsloth (run the install in its own cell;
# without internet access the wheel has to come from an attached dataset)
!pip install unsloth

from unsloth import FastLanguageModel  # import before trl so Unsloth's patches apply
import torch
from trl import SFTTrainer, SFTConfig

# Load the 4-bit model through Unsloth
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=BASE_MODEL_PATH,
    max_seq_length=2048,
    dtype=None,  # Auto-detect
    load_in_4bit=True,
)

# Add LoRA adapters
model = FastLanguageModel.get_peft_model(
    model,
    r=16,  # LoRA rank
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
    lora_alpha=16,
    lora_dropout=0,
    bias="none",
    use_gradient_checkpointing=True,
    random_state=3407,
)

# Training arguments. Names follow the pinned trl==0.21.0; older TRL releases
# take tokenizer= and dataset_text_field= on the trainer instead.
training_args = SFTConfig(
    per_device_train_batch_size=2,
    gradient_accumulation_steps=4,
    warmup_steps=5,
    max_steps=60,  # Caps the run at 60 steps for a quick trial; remove to train a full epoch
    learning_rate=2e-4,
    fp16=not torch.cuda.is_bf16_supported(),
    bf16=torch.cuda.is_bf16_supported(),
    logging_steps=1,
    optim="adamw_8bit",
    weight_decay=0.01,
    lr_scheduler_type="linear",
    seed=3407,
    output_dir="outputs",
    dataset_num_proc=2,
    packing=False,
    completion_only_loss=True,  # same prompt/completion setup as train.py
)

# train_dataset is the prompt/completion dataset from build_dataset(),
# so no dataset_text_field is needed
trainer = SFTTrainer(
    model=model,
    processing_class=tokenizer,
    train_dataset=train_dataset,
    args=training_args,
)
```

#### **B. Flash Attention v2 - 2x Faster Attention**
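```python
# Install Flash Attention (needs a CUDA build toolchain or a prebuilt wheel)
!pip install flash-attn --no-build-isolation

# FlashAttention-2 requires an Ampere-or-newer GPU; it does not run on older cards such as the T4
import torch
from transformers import AutoModelForCausalLM
model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL_PATH,
    torch_dtype=torch.float16,
    attn_implementation="flash_attention_2",  # Enable Flash Attention
    quantization_config=quantization_config,  # the BitsAndBytesConfig defined in train.py
)
```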

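#### **C. SDPA - Memory-Efficient Attention Built into PyTorch**
```python
# Alternative to Flash Attention that works on older GPUs.
# attn_implementation="sdpa" selects PyTorch's own scaled_dot_product_attention,
# so no extra package (such as xformers) has to be installed for it.
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL_PATH,
    torch_dtype=torch.float16,
    attn_implementation="sdpa",  # PyTorch scaled dot-product attention
)
```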

### **2. Quantization Alternatives**

#### **A. GPTQ Quantization (Faster Inference)**
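```python
# Use GPTQ instead of BitsAndBytes for faster inference
!pip install auto-gptq optimum

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, GPTQConfig

tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_PATH)

# Quantizing on load runs calibration on the "c4" dataset, which must be
# downloaded, so this step needs internet access
model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL_PATH,
    device_map="auto",
    torch_dtype=torch.float16,
    quantization_config=GPTQConfig(
        bits=4,
        dataset="c4",
        tokenizer=tokenizer,
    )
)
```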

#### **B. AWQ Quantization (Better Quality)**
```python
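# AWQ often preserves quality better than GPTQ at 4 bits
!pip install autoawq

from awq import AutoAWQForCausalLM

# from_quantized expects a checkpoint that has already been AWQ-quantized,
# so BASE_MODEL_PATH must point to such a checkpoint here
model = AutoAWQForCausalLM.from_quantized(
    BASE_MODEL_PATH,
    quant_filename="awq_model_w4_g128.pt",
    fuse_layers=True,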
)
```

### **3. Training Framework Alternatives**

#### **A. PyTorch Lightning - Better Structure**
```python
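import torch
import lightning as L
from lightning.pytorch.callbacks import ModelCheckpoint, LearningRateMonitor

class LoRALightningModule(L.LightningModule):
    def __init__(self, model, tokenizer):
        super().__init__()
        self.model = model
        self.tokenizer = tokenizer

    def training_step(self, batch, batch_idx):
        # Causal-LM loss: the model shifts the labels internally.
        # Assumes the batch has no "labels" key of its own.
        outputs = self.model(**batch, labels=batch["input_ids"])
        # Log the loss so ModelCheckpoint(monitor="train_loss") can find it
        self.log("train_loss", outputs.loss, prog_bar=True)
        return outputs.loss

    def configure_optimizers(self):
        # Optimize only the trainable (LoRA) parameters
        trainable = [p for p in self.parameters() if p.requires_grad]
        return torch.optim.AdamW(trainable, lr=2e-4)

# Usage
trainer = L.Trainer(
    max_epochs=1,
    accelerator="gpu",
    devices=1,
    precision="16-mixed",
    callbacks=[
        ModelCheckpoint(save_top_k=1, monitor="train_loss"),
        LearningRateMonitor(logging_interval="step"),
    ]
)
# train_dataloader is assumed to be defined elsewhere in the notebook
trainer.fit(LoRALightningModule(model, tokenizer), train_dataloader)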
```

#### **B. HuggingFace Accelerate - Simpler Multi-GPU**
```python
from accelerate import Accelerator

accelerator = Accelerator()

# Automatic device placement
model, optimizer, train_dataloader = accelerator.prepare(
    model, optimizer, train_dataloader
)

# Automatic mixed precision
with accelerator.autocast():
    outputs = model(**batch)
    loss = outputs.loss

accelerator.backward(loss)
optimizer.step()  # The prepared optimizer is stepped directly
optimizer.zero_grad()
```

### **4. Data Processing Optimizations**

#### **A. Faster Data Loading with HuggingFace Datasets**
```python
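from datasets import load_dataset
import multiprocessing

# Option 1: load into memory-mapped Arrow storage and preprocess in parallel
dataset = load_dataset("csv", data_files="train.csv", split="train")
dataset = dataset.map(
    preprocess_function,
    num_proc=multiprocessing.cpu_count(),
    batched=True,
    batch_size=1000,
)

# Option 2: stream very large files. Streamed datasets are processed lazily,
# and their map() does not accept num_proc
streamed = load_dataset("csv", data_files="train.csv", split="train", streaming=True)
streamed = streamed.map(preprocess_function, batched=True, batch_size=1000)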
```

#### **B. Memory-Efficient Data Processing**
```python
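# Use IterableDataset for memory efficiency
import pandas as pd
from torch.utils.data import IterableDataset

class StreamingDataset(IterableDataset):
    def __init__(self, data_path, preprocess):
        self.data_path = data_path
        self.preprocess = preprocess  # Callable applied to each row (a dict)

    def __iter__(self):
        # Read the CSV in chunks so only 1000 rows are in memory at a time
        for chunk in pd.read_csv(self.data_path, chunksize=1000):
            # Iterating a DataFrame directly yields column names, so convert to row dicts
            for row in chunk.to_dict("records"):
                yield self.preprocess(row)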
```

### **5. Advanced Speed Techniques**

#### **A. Gradient Checkpointing + Offloading**
```python
# Gradient checkpointing trades extra compute for lower activation memory
model.gradient_checkpointing_enable()
model.enable_input_require_grads()

# Use disk offloading for very large models (intended for inference, not training)
from accelerate import disk_offload
model = disk_offload(model, offload_dir="./offload")
```

#### **B. Dynamic Batch Sizing**
```python
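# Adjust batch size based on available memory
def get_optimal_batch_size(model, tokenizer, max_length=2048, max_batch_size=256):
    """Find the largest power-of-two batch size whose forward pass fits in memory.

    This probes inference memory only; training needs extra room for
    gradients and optimizer state, so treat the result as an upper bound.
    """
    # Build one sequence truncated to max_length tokens
    encoded = tokenizer("test " * max_length, return_tensors="pt",
                        truncation=True, max_length=max_length)
    batch_size = 1
    while batch_size <= max_batch_size:
        try:
            # Repeat the sequence along the batch dimension and move it to the model's device
            inputs = {k: v.repeat(batch_size, 1).to(model.device) for k, v in encoded.items()}
            with torch.no_grad():
                model(**inputs)
            batch_size *= 2
        except RuntimeError:  # CUDA out-of-memory errors subclass RuntimeError
            torch.cuda.empty_cache()
            break
    # Last size that succeeded (0 if even a single sequence did not fit)
    return batch_size // 2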
```
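
The doubling search above can only land on a power of two. As a sketch, the same idea can be refined with a binary search between the last size that fit and the first that failed. Here `fits` is a hypothetical callable (for example, a wrapper that runs one forward pass and returns `False` on an out-of-memory error); it is an assumption for illustration, not part of the notebook.

```python
from typing import Callable

def largest_fitting_batch_size(fits: Callable[[int], bool], upper_limit: int = 1024) -> int:
    """Return the largest b in [1, upper_limit] with fits(b) True, or 0 if none.

    Assumes fits is monotonic: if a batch size fails, every larger one fails too.
    """
    if not fits(1):
        return 0
    low, high = 1, 2
    # Exponential phase: double until a size fails or the limit is passed
    while high <= upper_limit and fits(high):
        low = high
        high *= 2
    # high is now an exclusive upper bound (a failing size or upper_limit + 1)
    high = min(high, upper_limit + 1)
    # Binary phase: narrow the gap between the known-good and known-bad sizes
    while high - low > 1:
        mid = (low + high) // 2
        if fits(mid):
            low = mid
        else:
            high = mid
    return low

# Stand-in predicate for illustration: pretend anything above 24 runs out of memory
best = largest_fitting_batch_size(lambda b: b <= 24)
print(best)
```

In practice the predicate should free cached memory after a failure (for instance with `torch.cuda.empty_cache()`) so that one failed probe does not distort the next.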

#### **C. Model Parallelism**
```python
# Spread layers across GPUs for very large models (layer-wise placement)
from transformers import AutoModelForCausalLM

# device_map="auto" lets Accelerate place layers on the available GPUs at load time,
# so a separate load_checkpoint_and_dispatch call is not needed
model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL_PATH,
    device_map="auto",  # Automatic device mapping
    max_memory={0: "10GiB", 1: "10GiB"},  # Memory limits per GPU
)
```

### **6. Performance Comparison Table**

| Method | Speed Improvement (approx.) | Memory Usage | Quality Impact | Complexity |
|--------|------------------|--------------|----------------|------------|
| **Unsloth** | Up to 2x-5x faster training | Same | Same/Better | Low |
| **Flash Attention** | Up to 2x faster attention | ~20% less | Same | Low |
| **GPTQ** | Up to 3x faster inference | ~75% less than FP16 | Slight decrease | Medium |
| **AWQ** | Up to 3x faster inference | ~75% less than FP16 | Minimal decrease | Medium |
| **PyTorch Lightning** | Same | Same | Same | Low |
| **Accelerate** | Same | Better distribution | Same | Low |
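
The figures in this table are approximate, so it is worth measuring on your own hardware. A minimal timing sketch, assuming `fn` is any zero-argument callable you supply (for GPU work it should call `torch.cuda.synchronize()` before returning so that queued kernels are included in the measurement):

```python
import time
from statistics import median
from typing import Callable

def benchmark(fn: Callable[[], object], warmup: int = 1, repeats: int = 5) -> float:
    """Return the median wall-clock seconds of fn() over `repeats` timed runs."""
    for _ in range(warmup):
        fn()  # Warm-up runs absorb one-time costs such as compilation or caching
    timings = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn()
        timings.append(time.perf_counter() - start)
    return median(timings)

def speedup(baseline_seconds: float, optimized_seconds: float) -> float:
    """How many times faster the optimized run is than the baseline."""
    return baseline_seconds / optimized_seconds

# Illustration with a trivial CPU workload; swap in a training or inference step
elapsed = benchmark(lambda: sum(range(10_000)))
print(f"median: {elapsed:.6f}s")
```

Comparing medians rather than single runs keeps one slow outlier from skewing the ratio.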

### **7. Recommended Optimization Stack**

For **maximum speed** with TT-10:
```python
# 1. Use Unsloth for training
!pip install unsloth

# 2. Add Flash Attention
!pip install flash-attn --no-build-isolation

# 3. Use GPTQ for inference
!pip install auto-gptq

# 4. Optimize data loading
from datasets import Dataset
dataset = Dataset.from_csv("train.csv").to_iterable_dataset()
```

### **8. Monitoring and Profiling**

#### **A. Training Speed Monitoring**
```python
from torch.profiler import profile, ProfilerActivity

# Profile a short run: profiling a full training job produces very large traces,
# so lower max_steps (or use a small data subset) before calling this
with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=True) as prof:
    trainer.train()

print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
```

#### **B. Memory Usage Tracking**
```python
import psutil
import GPUtil

def log_system_stats():
    # CPU and RAM
    cpu_percent = psutil.cpu_percent()
    ram_percent = psutil.virtual_memory().percent
    
    # GPU memory (memoryUtil is a 0-1 fraction of the first GPU's memory)
    gpus = GPUtil.getGPUs()
    gpu_mem_percent = gpus[0].memoryUtil * 100 if gpus else 0
    
    print(f"CPU: {cpu_percent}%, RAM: {ram_percent}%, GPU memory: {gpu_mem_percent:.1f}%")
```

### **9. Quick Wins (5-30 minutes)**

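1. **Reduce data size**: `TRAINING_DATA_PERCENTAGE = 0.3` (fewer examples, so training time drops roughly in proportion)
2. **Use smaller batches**: `per_device_train_batch_size = 2` (lowers peak memory; pair with gradient accumulation to keep the effective batch size)
3. **Enable gradient checkpointing**: `gradient_checkpointing=True` (saves memory at the cost of some extra compute per step)
4. **Use 8-bit optimizer**: `optim="paged_adamw_8bit"` (shrinks optimizer state memory)
5. **Reduce LoRA rank**: `r=8` instead of `r=16` (halves the adapter parameter count)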
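
To see what lowering the rank buys, recall that a LoRA adapter on a linear layer with input size `d_in` and output size `d_out` adds `r * (d_in + d_out)` trainable parameters. The sketch below applies that formula; the layer shapes are hypothetical placeholders, not values read from the notebook's model.

```python
def lora_param_count(d_in: int, d_out: int, r: int) -> int:
    """Trainable parameters added by one LoRA adapter: A is (r x d_in), B is (d_out x r)."""
    return r * (d_in + d_out)

# Hypothetical (d_in, d_out) shapes for three projections in one transformer layer
layer_shapes = {
    "q_proj": (4096, 4096),
    "v_proj": (4096, 1024),
    "down_proj": (11008, 4096),
}

def total_lora_params(shapes: dict, r: int) -> int:
    """Sum the adapter parameters over all targeted projections."""
    return sum(lora_param_count(d_in, d_out, r) for d_in, d_out in shapes.values())

params_r16 = total_lora_params(layer_shapes, r=16)
params_r8 = total_lora_params(layer_shapes, r=8)
print(f"r=16: {params_r16:,} params, r=8: {params_r8:,} params")
```

Halving the rank halves the adapter parameters and their optimizer state. The frozen base model still dominates compute, so the effect on step time is usually smaller than the effect on memory.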

### **10. Advanced Optimizations (1-4 hours)**

1. **Switch to Unsloth**: Replace TRL training loop
2. **Add Flash Attention**: Modify model loading
3. **Use GPTQ quantization**: For inference speed
4. **Implement dynamic batching**: Automatic batch size optimization
5. **Add model parallelism**: For multi-GPU setups

**Expected Results (rough estimates, not benchmarked in this notebook):**
- **Unsloth + Flash Attention**: up to 3x-8x faster training
- **GPTQ Inference**: up to 5x-10x faster inference
- **Combined optimizations**: up to 10x-20x total speedup in the best case

These optimizations can substantially reduce training time while keeping model quality close to the baseline. Re-check the validation metrics after each change.
```