<a href="https://colab.research.google.com/github/gut-puncture/double-inference/blob/main/double_inference.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
### 1️⃣ Setup & Config

# Fix version conflicts by installing compatible versions
!pip uninstall -y torch torchvision torchaudio
!pip install -q torch==2.6.0 torchvision==0.21.0 torchaudio==2.6.0 --index-url https://download.pytorch.org/whl/cu124
!pip install -q transformers==4.44.0 accelerate datasets==2.18.0

from google.colab import drive
drive.mount('/content/drive')

# All imports consolidated (removed lm-eval to avoid conflicts)
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForCausalLM
from typing import Tuple, Dict, Optional, List
import time
import json
from pathlib import Path
import datetime
import random
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from glob import glob
from dataclasses import dataclass

# GPU check
device = 'cuda' if torch.cuda.is_available() else 'cpu'
if device == 'cpu': raise ValueError('No GPU available')
gpu_name = torch.cuda.get_device_name(0)
vram_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
print(f'GPU: {gpu_name}')
print(f'VRAM: {vram_gb:.1f} GB')

# Config with constants
@dataclass
class ExperimentConfig:
    model_path: str = '/content/drive/MyDrive/phi3_3.8B'
    max_length: int = 2048
    temperature: float = 0.7
    top_p: float = 0.9
    seed: int = 42
    pass_type: str = 'baseline'
    second_pass_layers: int = None
    residual_variant: str = 'raw'
    attn_impl: str = 'auto'
    entropy_eps: float = 1e-9  # New constant
    min_vram_gb: float = 10.0  # New constant

config = ExperimentConfig()
if vram_gb < config.min_vram_gb: raise ValueError('Insufficient VRAM')
torch.manual_seed(config.seed)
random.seed(config.seed)
np.random.seed(config.seed)
print(config)


Found existing installation: torch 2.6.0+cu124
Uninstalling torch-2.6.0+cu124:
  Successfully uninstalled torch-2.6.0+cu124
Found existing installation: torchvision 0.21.0+cu124
Uninstalling torchvision-0.21.0+cu124:
  Successfully uninstalled torchvision-0.21.0+cu124
Found existing installation: torchaudio 2.6.0+cu124
Uninstalling torchaudio-2.6.0+cu124:
  Successfully uninstalled torchaudio-2.6.0+cu124
Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
GPU: NVIDIA A100-SXM4-40GB
VRAM: 42.5 GB
ExperimentConfig(model_path='/content/drive/MyDrive/phi3_3.8B', max_length=2048, temperature=0.7, top_p=0.9, seed=42, pass_type='baseline', second_pass_layers=None, residual_variant='raw', attn_impl='auto', entropy_eps=1e-09, min_vram_gb=10.0)


In [3]:
### 2️⃣ Utility – Attention Implementation Fallback

def get_fallback_chain(requested: str) -> List[str]:
    # Always use eager attention for Phi-3 compatibility
    return ['eager']

def resolve_attn_impl(requested: str) -> str:
    return get_fallback_chain(requested)[0]

In [4]:
### 3️⃣ Utility – Robust Embedding Builder

def build_inputs_embeds(model, input_ids: torch.Tensor) -> torch.Tensor:
    token_embeds = model.get_input_embeddings()(input_ids)
    seq_len = input_ids.size(1)
    position_ids = torch.arange(seq_len, device=input_ids.device).unsqueeze(0)
    pos_attrs = [
        ('model.embed_positions', lambda m: getattr(m.model, 'embed_positions', None)),
        ('wpe', lambda m: getattr(m.transformer, 'wpe', None) if hasattr(m, 'transformer') else None),
        ('embeddings.position_embeddings', lambda m: getattr(m.embeddings, 'position_embeddings', None) if hasattr(m, 'embeddings') else None)
    ]
    for _, getter in pos_attrs:
        pos_layer = getter(model)
        if pos_layer is not None and callable(pos_layer):
            pos_embeds = pos_layer(position_ids)
            if pos_embeds is not None and pos_embeds.shape[:2] == token_embeds.shape[:2]:
                token_embeds += pos_embeds
                break
    return token_embeds


In [5]:
### 4️⃣ Class DoublePassPhi3

class DoublePassPhi3(torch.nn.Module):
    def __init__(self, model_path: str, config: ExperimentConfig):
        super().__init__()
        self.config = config
        self.device = torch.device('cuda')
        attn_impl = resolve_attn_impl(config.attn_impl)

        try:
            self.model = AutoModelForCausalLM.from_pretrained(
                model_path, torch_dtype=torch.float16, device_map='auto',
                trust_remote_code=True, attn_implementation=attn_impl
            )
            print(f'✅ Model loaded successfully with {attn_impl} attention')
        except Exception as e:
            error_msg = f'❌ Failed to load model with {attn_impl} attention: {str(e)}'
            raise RuntimeError(error_msg)

        self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
        if self.tokenizer.pad_token is None: self.tokenizer.pad_token = self.tokenizer.eos_token
        self.model.eval()
        self.num_layers = len(self.model.model.layers)

    # NO CHANGE to get_residual_stream, it is correct.
    def get_residual_stream(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None,
                           past_key_values=None, use_cache: bool = True) -> Tuple[torch.Tensor, torch.Tensor, Optional[List]]:
        with torch.no_grad():
            outputs = self.model(
                input_ids,
                attention_mask=attention_mask,
                past_key_values=past_key_values,
                use_cache=use_cache,
                output_hidden_states=True
            )
            logits = outputs.logits
            residual_stream = outputs.hidden_states[-1]
            new_past_key_values = outputs.past_key_values if use_cache else None
            return logits, residual_stream, new_past_key_values

    # BUG FIX 1: Added position_ids to the layer call.
    def second_pass_forward(self, residual_stream: torch.Tensor, attention_mask: Optional[torch.Tensor] = None,
                      num_layers: Optional[int] = None) -> torch.Tensor:
      with torch.no_grad():
        hidden_states = self.model.model.norm(residual_stream) if self.config.residual_variant == 'layernorm' and hasattr(self.model.model, 'norm') else residual_stream
        layers_to_use = min(num_layers or self.num_layers, self.num_layers)

        # Create position_ids based on the sequence length of the residual stream
        seq_len = residual_stream.size(1)
        position_ids = torch.arange(seq_len, device=self.device).unsqueeze(0)

        # The model expects a 4D attention mask for the layers. We can create it from the 2D mask.
        # This is often handled internally, but being explicit is safer in custom loops.
        causal_mask = self.model._prepare_decoder_attention_mask(
            attention_mask, (1, seq_len), hidden_states, 0
        )

        for i in range(layers_to_use):
            outputs = self.model.model.layers[i](
                hidden_states,
                attention_mask=causal_mask,
                position_ids=position_ids, # Pass the position IDs
                use_cache=False
            )
            hidden_states = outputs[0]

        hidden_states = self.model.model.norm(hidden_states) if hasattr(self.model.model, 'norm') else hidden_states
        return self.model.lm_head(hidden_states)

    # BUG FIX 2: Completely rewritten generate loop for correct state management.
    def generate(self, prompt: str, max_new_tokens: int = 100) -> Dict:
        inputs = self.tokenizer(prompt, return_tensors='pt', truncation=True).to(self.device)
        input_ids = inputs['input_ids']
        attention_mask = inputs['attention_mask']

        generated_token_ids = input_ids
        entropies1, entropies2 = [], []
        start_time = time.time()

        for _ in range(max_new_tokens):
            # --- Pass 1: Standard autoregressive step to get the next token logits ---
            # For the first pass, we can use the efficient KV cache.
            # We pass the *entire* generated sequence so far. The model's use_cache
            # logic will handle only processing the last token.
            outputs = self.model(
                input_ids=generated_token_ids,
                attention_mask=attention_mask,
                use_cache=True, # Note: This will be overwritten if past_key_values is passed
                output_hidden_states=True,
            )

            # The logits for the VERY LAST token are what we need to predict the next one.
            logits1 = outputs.logits[:, -1, :]

            probs1 = F.softmax(logits1, dim=-1)
            entropy1 = -torch.sum(probs1 * torch.log(probs1 + self.config.entropy_eps), dim=-1).item()
            entropies1.append(entropy1)

            final_logits = logits1
            entropies2.append(entropy1) # Default for baseline

            # --- Pass 2: The "Double Pass" logic ---
            # This is only performed if not in baseline mode.
            if self.config.pass_type != 'baseline':
                # To get the full residual stream, we MUST run a forward pass
                # on the entire sequence WITHOUT the KV cache. This is inefficient
                # but is the only way to correctly implement the experimental idea.
                with torch.no_grad():
                    full_outputs = self.model(
                        input_ids=generated_token_ids,
                        attention_mask=attention_mask,
                        output_hidden_states=True,
                        use_cache=False # CRITICAL: No cache to get the full residual
                    )

                full_residual_stream = full_outputs.hidden_states[-1]

                num_l = self.config.second_pass_layers if self.config.pass_type == 'double_partial' else None
                logits2 = self.second_pass_forward(full_residual_stream, attention_mask, num_l)

                final_logits = logits2[:, -1, :] # We only care about the last token's logits

                probs2 = F.softmax(final_logits, dim=-1)
                entropy2 = -torch.sum(probs2 * torch.log(probs2 + self.config.entropy_eps), dim=-1).item()
                entropies2[-1] = entropy2 # Overwrite the last entropy value


            # --- Sampling ---
            probs = F.softmax(final_logits / self.config.temperature, dim=-1)

            # Top-p (nucleus) sampling
            sorted_probs, sorted_indices = torch.sort(probs, descending=True)
            cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
            sorted_indices_to_remove = cumulative_probs > self.config.top_p
            sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
            sorted_indices_to_remove[..., 0] = 0
            indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
            probs[indices_to_remove] = 0.0

            # Sample from the filtered distribution
            next_token = torch.multinomial(probs, num_samples=1)

            if next_token.item() == self.tokenizer.eos_token_id:
                break

            # Append the new token for the next iteration
            generated_token_ids = torch.cat([generated_token_ids, next_token], dim=-1)
            attention_mask = torch.cat([attention_mask, torch.ones((1,1), device=self.device)], dim=-1)


        generated_text = self.tokenizer.decode(generated_token_ids[0], skip_special_tokens=True)
        # We need to remove the original prompt from the output
        if generated_text.startswith(prompt):
            generated_text = generated_text[len(prompt):]


        return {
            'generated_text': generated_text.strip(),
            'entropies1': entropies1,
            'entropies2': entropies2,
            'num_tokens': len(entropies1),
            'time': time.time() - start_time
        }

In [6]:
### 5️⃣ Quick Smoke Test

try:
    model = DoublePassPhi3(config.model_path, config)
    result = model.generate('Sanity check:', max_new_tokens=10)
    print(f'Smoke test passed: {result["generated_text"]} in {result["time"]:.2f}s')
    del model
    torch.cuda.empty_cache()
except Exception as e:
    raise RuntimeError(f'Smoke test failed: {str(e)}')




Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]



✅ Model loaded successfully with eager attention
Smoke test passed: the size of the dataset should be divisible by in 1.05s


In [7]:
### 6️⃣ Simple Benchmark System

# Simple Q&A dataset for testing
SIMPLE_QA_DATASET = [
    {"question": "What is the capital of France?", "answer": "Paris"},
    {"question": "What is 2 + 2?", "answer": "4"},
    {"question": "What is the largest planet in our solar system?", "answer": "Jupiter"},
    {"question": "Who wrote Romeo and Juliet?", "answer": "Shakespeare"},
    {"question": "What is the square root of 16?", "answer": "4"},
    {"question": "What year did World War II end?", "answer": "1945"},
    {"question": "What is the chemical symbol for gold?", "answer": "Au"},
    {"question": "What is the capital of Japan?", "answer": "Tokyo"},
    {"question": "How many days are in a leap year?", "answer": "366"},
    {"question": "What is the speed of light?", "answer": "299,792,458"},
]

def simple_benchmark(model: DoublePassPhi3, dataset: List[Dict] = None) -> Dict:
    """Run a simple Q&A benchmark."""
    if dataset is None:
        dataset = SIMPLE_QA_DATASET

    correct = 0
    total = len(dataset)
    results = []
    total_time = 0

    for item in dataset:
        question = item["question"]
        expected = item["answer"].lower()

        start_time = time.time()
        result = model.generate(question, max_new_tokens=20)
        inference_time = time.time() - start_time
        total_time += inference_time

        generated = result['generated_text'].lower().strip()

        # Simple scoring - check if expected answer is in generated text
        is_correct = expected in generated or any(word in generated for word in expected.split())
        if is_correct:
            correct += 1

        results.append({
            'question': question,
            'expected': expected,
            'generated': result['generated_text'],
            'correct': is_correct,
            'time': inference_time
        })

    accuracy = correct / total
    avg_time = total_time / total

    return {
        'accuracy': accuracy,
        'avg_time': avg_time,
        'total_time': total_time,
        'samples': total,
        'results': results
    }


In [8]:
### 7️⃣ Simple Benchmark Runner

def run_simple_benchmark(model: DoublePassPhi3) -> Dict:
    """Run the simple Q&A benchmark."""
    try:
        print(f"Running simple benchmark with {model.config.pass_type} method...")
        results = simple_benchmark(model)
        print(f"✅ Completed: {results['accuracy']:.2f} accuracy, {results['avg_time']:.2f}s avg time")
        return results
    except Exception as e:
        print(f'❌ Error in benchmark: {str(e)}')
        return {'accuracy': 0.0, 'samples': 0, 'avg_time': 0.0, 'total_time': 0.0}


In [11]:
run_simple_benchmark(DoublePassPhi3(config.model_path, config))



Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

✅ Model loaded successfully with eager attention
Running simple benchmark with baseline method...
✅ Completed: 0.80 accuracy, 1.04s avg time


{'accuracy': 0.8,
 'avg_time': 1.0375317096710206,
 'total_time': 10.375317096710205,
 'samples': 10,
 'results': [{'question': 'What is the capital of France?',
   'expected': 'paris',
   'generated': '# Answer\nThe capital of France is Paris.',
   'correct': True,
   'time': 0.7548317909240723},
  {'question': 'What is 2 + 2?',
   'expected': '4',
   'generated': "I know this one, it's 4. Now, can you give me a harder question",
   'correct': True,
   'time': 1.1072328090667725},
  {'question': 'What is the largest planet in our solar system?',
   'expected': 'jupiter',
   'generated': 'Jupiter is the largest planet in our solar system. It is a gas giant and is known for',
   'correct': True,
   'time': 1.1078200340270996},
  {'question': 'Who wrote Romeo and Juliet?',
   'expected': 'shakespeare',
   'generated': '# Answer\nWilliam Shakespeare wrote "Romeo and Juliet."',
   'correct': True,
   'time': 0.984142541885376},
  {'question': 'What is the square root of 16?',
   'expected'

In [9]:
### 8️⃣ Single Experiment Driver

def run_experiment(config: ExperimentConfig) -> Dict:
    """Run a single experiment with the given configuration."""
    print(f"\n🚀 Starting experiment: {config.pass_type}")

    model = DoublePassPhi3(config.model_path, config)
    results = run_simple_benchmark(model)

    # Clean up memory
    del model
    torch.cuda.empty_cache()

    # Save results
    timestamp = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
    path = Path(f'/content/drive/MyDrive/runs/{timestamp}-{config.pass_type}.json')
    path.parent.mkdir(exist_ok=True, parents=True)

    with open(path, 'w') as f:
        json.dump(results, f, indent=2, default=str)

    print(f'📊 Results: Accuracy {results["accuracy"]:.3f}, Avg Time {results["avg_time"]:.2f}s')
    print(f'💾 Saved to: {path}')

    return results


In [10]:
### 9️⃣ Grid Search Launcher


# Programmatic experiment grid generation
def generate_experiment_configs(custom_experiments: List[Dict] = None) -> List[ExperimentConfig]:
    """Generate list of experiment configurations. Optionally add custom experiments."""
    experiments = []

    # Standard experiment set
    experiments.append(ExperimentConfig(pass_type='baseline', residual_variant='raw'))
    experiments.append(ExperimentConfig(pass_type='double_full', residual_variant='raw'))
    experiments.append(ExperimentConfig(pass_type='double_full', residual_variant='layernorm'))

    # Systematic partial pass experiments
    for layers in [1, 2, 4, 8]:
        experiments.append(ExperimentConfig(
            pass_type='double_partial',
            second_pass_layers=layers,
            residual_variant='raw'
        ))

    # Add custom experiments if provided
    if custom_experiments:
        for custom in custom_experiments:
            experiments.append(ExperimentConfig(**custom))

    return experiments

def run_all_experiments(custom_experiments: List[Dict] = None):
    """Run all experiments and compare results."""
    exps = generate_experiment_configs(custom_experiments)
    print(f'🔬 Running {len(exps)} experiments, estimated time: {len(exps)*5:.0f} minutes')

    all_results = []
    for i, exp in enumerate(exps, 1):
        print(f"\n📊 Experiment {i}/{len(exps)}")
        try:
            res = run_experiment(exp)
            variant_name = f"{exp.pass_type}_{exp.residual_variant}"
            if exp.second_pass_layers:
                variant_name += f"_L{exp.second_pass_layers}"
            all_results.append((variant_name, res['accuracy']))
        except Exception as e:
            print(f"❌ Failed experiment {exp.pass_type}: {str(e)}")
            continue

    # Sort by accuracy
    all_results.sort(key=lambda x: x[1], reverse=True)

    print('\n🏆 LEADERBOARD:')
    print('=' * 40)
    for rank, (name, acc) in enumerate(all_results, 1):
        print(f'{rank}. {name}: {acc:.3f}')

    return all_results


In [11]:
### 🔟 Interactive Playground & Quick Test

def playground():
    """Interactive playground to compare baseline vs double pass."""
    print("🎮 Interactive Playground - Compare Baseline vs Double Pass")
    print("Type 'q' to quit, or enter any prompt to test both methods.\n")

    while True:
        prompt = input('📝 Enter prompt (or q to quit): ')
        if prompt.lower() == 'q': break

        print(f"\n🔍 Testing prompt: '{prompt}'")
        print("-" * 60)

        # Baseline experiment
        print("🚀 Running baseline...")
        base_config = ExperimentConfig(pass_type='baseline')
        base_model = DoublePassPhi3(base_config.model_path, base_config)
        base_res = base_model.generate(prompt, max_new_tokens=30)
        del base_model

        # Double pass experiment
        print("🚀 Running double pass...")
        double_config = ExperimentConfig(pass_type='double_full')
        double_model = DoublePassPhi3(double_config.model_path, double_config)
        double_res = double_model.generate(prompt, max_new_tokens=30)
        del double_model

        # Clear CUDA cache to prevent memory leak
        torch.cuda.empty_cache()

        # Calculate entropy differences
        base_ent = sum(base_res['entropies1']) / len(base_res['entropies1']) if base_res['entropies1'] else 0
        double_ent = sum(double_res['entropies2']) / len(double_res['entropies2']) if double_res['entropies2'] else 0

        print(f"\n📊 RESULTS:")
        print(f"🔵 Baseline: {base_res['generated_text']}")
        print(f"   Time: {base_res['time']:.2f}s, Entropy: {base_ent:.3f}")
        print(f"🟣 Double Pass: {double_res['generated_text']}")
        print(f"   Time: {double_res['time']:.2f}s, Entropy: {double_ent:.3f}")
        print(f"⚡ Speed ratio: {double_res['time']/base_res['time']:.1f}x slower")
        print(f"🧠 Entropy change: {base_ent - double_ent:.3f}\n")

def quick_test():
    """Run a quick test of both methods."""
    print("🧪 Quick Test - Baseline vs Double Pass\n")

    test_prompts = [
        "The capital of France is",
        "2 + 2 equals",
        "The largest planet is"
    ]

    for prompt in test_prompts:
        print(f"🔍 Testing: '{prompt}'")

        # Baseline
        base_config = ExperimentConfig(pass_type='baseline')
        base_model = DoublePassPhi3(base_config.model_path, base_config)
        base_res = base_model.generate(prompt, max_new_tokens=10)
        del base_model

        # Double pass
        double_config = ExperimentConfig(pass_type='double_full')
        double_model = DoublePassPhi3(double_config.model_path, double_config)
        double_res = double_model.generate(prompt, max_new_tokens=10)
        del double_model

        torch.cuda.empty_cache()

        print(f"  🔵 Baseline: {base_res['generated_text']}")
        print(f"  🟣 Double: {double_res['generated_text']}")
        print(f"  ⚡ Speed: {double_res['time']/base_res['time']:.1f}x slower\n")

# Uncomment to run:
# playground()
# quick_test()


In [12]:
# Cell 6.1: GSM8K BENCHMARKING SYSTEM

from datasets import load_dataset
import re

def parse_gsm8k_answer(generated_text: str) -> Optional[float]:
    """
    Parses the model's generated text to find the final numerical answer for GSM8K.
    It looks for the '####' marker or the last number in the string.
    """
    # Look for the #### pattern, which is standard for GSM8K reasoning chains
    match = re.search(r"####\s*([0-9,.]+)", generated_text)
    if match:
        # Extract number, remove commas, and convert to float
        return float(match.group(1).replace(',', ''))

    # If #### is not found, fall back to finding the last number in the string
    # This is less reliable but a good fallback.
    matches = re.findall(r"([0-9,.]+)", generated_text)
    if matches:
        return float(matches[-1].replace(',', ''))

    return None # Return None if no number is found

def run_gsm8k_benchmark(model: DoublePassPhi3, num_questions: int = 100) -> Dict:
    """
    Runs a benchmark on the GSM8K dataset.

    Args:
        model: An initialized DoublePassPhi3 model.
        num_questions: The number of questions to run from the test set.
                       Set to 'all' to run the entire benchmark (takes a long time!).
    """
    print(f"🚀 Loading GSM8K dataset...")
    try:
        # Load the 'main' configuration of gsm8k
        dataset = load_dataset("gsm8k", "main", split="test")
    except Exception as e:
        print(f"❌ Failed to load dataset: {e}")
        return {}

    if num_questions != 'all':
        dataset = dataset.select(range(num_questions))
        print(f"📊 Running on the first {num_questions} questions of the GSM8K test set.")
    else:
        print(f"📊 Running on the ENTIRE {len(dataset)} questions of the GSM8K test set. This will take a while!")

    correct = 0
    total = len(dataset)
    results = []
    total_time = 0

    for i, item in enumerate(dataset):
        question = item["question"]
        # The expected answer is just the number
        expected_answer_text = item["answer"].split("####")[-1].strip()
        expected = float(expected_answer_text.replace(',', ''))

        # Generate an answer. We need more tokens for GSM8K's reasoning.
        result = model.generate(question, max_new_tokens=256)
        total_time += result['time']

        # Parse the final number from the model's output
        generated = parse_gsm8k_answer(result['generated_text'])
        is_correct = generated is not None and generated == expected

        if is_correct:
            correct += 1

        print(f"  Q{i+1}/{total}: Expected: {expected}, Got: {generated} -> {'✅ Correct' if is_correct else '❌ Incorrect'}")

        results.append({
            'question': question,
            'expected': expected,
            'generated_text': result['generated_text'],
            'parsed_answer': generated,
            'correct': is_correct,
            'time': result['time']
        })

    accuracy = correct / total if total > 0 else 0
    avg_time = total_time / total if total > 0 else 0

    benchmark_results = {
        'accuracy': accuracy,
        'avg_time': avg_time,
        'total_time': total_time,
        'samples': total,
        'pass_type': model.config.pass_type,
        'results': results
    }

    print(f"\n✅ GSM8K BENCHMARK COMPLETE ({model.config.pass_type})")
    print(f"   Accuracy: {accuracy:.3f} ({correct}/{total})")
    print(f"   Average Time per Question: {avg_time:.2f}s")

    return benchmark_results

In [13]:
# Cell to run the GSM8K benchmark comparison

# --- CONFIGURATION ---
# Set the number of questions you want to test.
# Use a small number like 20 for a quick test.
# Use 'all' for the full benchmark (can take hours).
NUM_QUESTIONS_TO_RUN = 20

# --- 1. RUN BASELINE MODEL ---
print("="*60)
print("📊 STARTING GSM8K BENCHMARK: BASELINE")
print("="*60)
# Create a config for the baseline run
baseline_config = ExperimentConfig(pass_type='baseline')
# Initialize the model with this config
baseline_model = DoublePassPhi3(baseline_config.model_path, baseline_config)
# Run the benchmark
baseline_results = run_gsm8k_benchmark(baseline_model, num_questions=NUM_QUESTIONS_TO_RUN)
# Clean up memory
del baseline_model
torch.cuda.empty_cache()


# --- 2. RUN DOUBLE PASS MODEL ---
print("\n" + "="*60)
print("📊 STARTING GSM8K BENCHMARK: DOUBLE PASS (FULL)")
print("="*60)
# Create a config for your experimental run
# We will test a full double pass here
double_pass_config = ExperimentConfig(
    pass_type='double_full',
    residual_variant='raw' # or 'layernorm'
)
# Initialize the model
double_pass_model = DoublePassPhi3(double_pass_config.model_path, double_pass_config)
# Run the benchmark
double_pass_results = run_gsm8k_benchmark(double_pass_model, num_questions=NUM_QUESTIONS_TO_RUN)
# Clean up memory
del double_pass_model
torch.cuda.empty_cache()


# --- 3. COMPARE RESULTS ---
print("\n" + "="*60)
print("🏆 FINAL GSM8K BENCHMARK COMPARISON")
print("="*60)
if baseline_results:
    print(f"🔵 Baseline Accuracy: {baseline_results['accuracy']:.3f}")
    print(f"   Avg. Time: {baseline_results['avg_time']:.2f}s")
if double_pass_results:
    print(f"🟣 Double Pass Accuracy: {double_pass_results['accuracy']:.3f}")
    print(f"   Avg. Time: {double_pass_results['avg_time']:.2f}s")



📊 STARTING GSM8K BENCHMARK: BASELINE


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

✅ Model loaded successfully with eager attention
🚀 Loading GSM8K dataset...


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.
Downloading data: 100%|██████████| 2.31M/2.31M [00:00<00:00, 9.12MB/s]
Downloading data: 100%|██████████| 419k/419k [00:00<00:00, 3.79MB/s]


Generating train split:   0%|          | 0/7473 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/1319 [00:00<?, ? examples/s]

📊 Running on the first 20 questions of the GSM8K test set.


ValueError: could not convert string to float: '.'

In [None]:
### 1️⃣1️⃣ Results Analysis & Plots

def analyze_results(runs_dir='/content/drive/MyDrive/runs'):
    """Analyze experiment results with visualization."""
    files = glob(f'{runs_dir}/*.json')
    if not files:
        print('No results found');
        return

    data = []
    for f in files:
        with open(f) as jf:
            res = json.load(jf)
            variant = Path(f).stem.split('-')[1]
            for bench, metrics in res.items():
                data.append({
                    'variant': variant,
                    'benchmark': bench,
                    'accuracy': metrics['accuracy'],
                    'runtime': metrics['runtime']
                })

    df = pd.DataFrame(data)
    fig, axs = plt.subplots(1, 2, figsize=(12,4))

    # Average accuracy by variant
    df.groupby('variant')['accuracy'].mean().sort_values(ascending=False).plot.bar(ax=axs[0])
    axs[0].set_title('Avg Accuracy by Variant')
    axs[0].set_ylabel('Accuracy')

    # Accuracy vs Runtime scatter
    df.plot.scatter(x='runtime', y='accuracy', ax=axs[1])
    axs[1].set_title('Accuracy vs Runtime')
    axs[1].set_xlabel('Runtime (seconds)')
    axs[1].set_ylabel('Accuracy')

    plt.tight_layout()
    plt.savefig(f'{runs_dir}/analysis.png')
    plt.show()

    return df

# analyze_results()  # Uncomment to run


In [None]:
### 1️⃣2️⃣ Usage Examples & How to Run

print("✅ All functions loaded successfully!")
print("\n🚀 HOW TO RUN THE EXPERIMENT:")
print("=" * 50)

print("\n1️⃣ QUICK START:")
print("   quick_test()  # Run quick comparison")

print("\n2️⃣ SINGLE EXPERIMENTS:")
print("   # Baseline method")
print("   baseline_config = ExperimentConfig(pass_type='baseline')")
print("   result1 = run_experiment(baseline_config)")
print("")
print("   # Double pass method")
print("   double_config = ExperimentConfig(pass_type='double_full')")
print("   result2 = run_experiment(double_config)")

print("\n3️⃣ FULL COMPARISON:")
print("   run_all_experiments()  # Compare all methods")

print("\n4️⃣ INTERACTIVE MODE:")
print("   playground()  # Test custom prompts")

print("\n5️⃣ CUSTOM EXPERIMENTS:")
print("   custom_exps = [")
print("       {'pass_type': 'double_partial', 'second_pass_layers': 4}")
print("   ]")
print("   run_all_experiments(custom_experiments=custom_exps)")

print("\n💡 TIPS:")
print("- Run smoke test first to verify model loading")
print("- Start with quick_test() for immediate results")
print("- Each experiment takes ~2-5 minutes")
print("- Results saved to /content/drive/MyDrive/runs/")
print("- Use Ctrl+C to stop long-running experiments")

print("\n🎯 RECOMMENDED WORKFLOW:")
print("1. Run the smoke test (cell 5)")
print("2. Try quick_test() for immediate comparison")
print("3. Run single experiments to understand differences")
print("4. Use playground() to test specific prompts")
print("5. Run full comparison with run_all_experiments()")


In [None]:
# Additional utility functions and debug helpers

def print_model_summary(model):
    """Print model architecture summary."""
    print(f"Model: {model.__class__.__name__}")
    print(f"Device: {model.device}")
    print(f"Layers: {model.num_layers}")
    print(f"Config: {model.config}")

# Quick profiling: torch.utils.bottleneck.run(model.generate('test'))
# Toggle attention: config.attn_impl = 'eager'
# Memory check: print(f"VRAM used: {torch.cuda.memory_allocated()/1e9:.1f}GB")


In [None]:
### 1️⃣4️⃣ Cleanup & Tips


# Tips:
# - Restart runtime if OOM: Runtime > Restart session
# - Download results: Files > Mount Drive > Copy from /content/drive/MyDrive/runs
# - Expected times: Baseline ~30min/benchmark, Double ~60min
# - Checklist: Run smoke test, check GPU, verify model path, start with single exp


In [None]:
### 1️⃣5️⃣ Ready to Run! 🚀

print("🎉 SETUP COMPLETE!")
print("All functions are loaded and ready to use.")
print("\n" + "="*60)
print("🚀 TO START EXPERIMENTING, RUN ONE OF THESE:")
print("="*60)
print("👉 quick_test()           # 2-minute quick demo")
print("👉 playground()           # Interactive testing")
print("👉 run_all_experiments()  # Full comparison (~30 min)")
print("="*60)

print("\n💡 WHAT EACH METHOD DOES:")
print("🔵 Baseline: Normal single forward pass (standard AI)")
print("🟣 Double Pass: Runs model twice for each word (experimental)")
print("\n🎯 GOAL: See if 'thinking twice' improves accuracy!")

# Uncomment any of these to run:
# quick_test()
# playground()
# run_all_experiments()
