<a href="https://colab.research.google.com/github/gut-puncture/double-inference/blob/main/double_inference.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Double Forward Pass Research: Phi-3 Mini 4K Instruct

## Research Objective
Quantitatively test whether performing TWO sequential forward passes per generated token (feeding the final residual stream of pass #1 back into layer #0) improves language-reasoning ability of the 3.8B-parameter "microsoft/phi-3-mini-4k-instruct" model.

## Hypothesis
Extra compute refines internal representations, lowering token-level entropy and boosting accuracy on reasoning benchmarks relative to the single-pass baseline.
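
Token-level entropy can be read as the Shannon entropy of the next-token distribution. The sketch below shows that quantity in plain Python, assuming natural-log units and a toy four-token vocabulary. The helper name `softmax_entropy` is illustrative and is not part of the notebook.

```python
import math
from typing import Sequence

def softmax_entropy(logits: Sequence[float]) -> float:
    """Shannon entropy (nats) of softmax(logits), computed stably via log-sum-exp."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    log_probs = [x - log_z for x in logits]
    return -sum(math.exp(lp) * lp for lp in log_probs)

# Toy vocabulary of four tokens (assumed values, for illustration only)
uniform = softmax_entropy([0.0, 0.0, 0.0, 0.0])  # maximally uncertain
peaked = softmax_entropy([8.0, 0.0, 0.0, 0.0])   # one token dominates
```

Under the hypothesis, second-pass logits would look more like the `peaked` case than first-pass logits at the same position.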

## Benchmarks
- **MMLU**: 57 subjects, 5-shot, accuracy
- **BigBench-Hard (BBH)**: 23 tasks, 0-shot, accuracy
- **GSM8K**: Math word problems, 5-shot chain-of-thought, answer accuracy

## Experimental Conditions
1. Baseline (single-pass generation)
2. Double-pass (full N layers)
3. Partial-pass grid (k ∈ {1, 2, 4, 8, N/2, N})
4. Residual-norm ablation (raw residual vs. residual re-normalized with the model's final norm, labeled `layernorm`)
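
The four conditions above can be enumerated as a small grid. This is a sketch under stated assumptions: `Condition`, `build_conditions`, and `relative_layer_cost` are hypothetical names, the layer count of 32 is an illustrative value (the notebook reads the real count from the loaded model), and cost is approximated by counting transformer blocks only.

```python
from dataclasses import dataclass
from typing import List, Optional

@dataclass(frozen=True)
class Condition:
    pass_type: str                      # 'baseline', 'double_full', or 'double_partial'
    second_pass_layers: Optional[int]   # k for partial passes, None otherwise
    residual_variant: str               # 'raw' or 'layernorm'

def build_conditions(n_layers: int) -> List[Condition]:
    """Enumerate baseline, full double pass (k = N), norm ablation, and the partial grid."""
    conditions = [
        Condition("baseline", None, "raw"),
        Condition("double_full", None, "raw"),        # k = N
        Condition("double_full", None, "layernorm"),  # residual-norm ablation
    ]
    # k = N is already covered by double_full, so only k < N is listed here
    partial_ks = sorted(k for k in {1, 2, 4, 8, n_layers // 2} if 0 < k < n_layers)
    conditions += [Condition("double_partial", k, "raw") for k in partial_ks]
    return conditions

def relative_layer_cost(cond: Condition, n_layers: int) -> float:
    """Blocks executed per token, relative to a single pass (ignores embeddings and lm_head)."""
    if cond.pass_type == "baseline":
        return 1.0
    k = cond.second_pass_layers if cond.second_pass_layers is not None else n_layers
    return 1.0 + k / n_layers

conditions = build_conditions(32)  # 32 is an assumed layer count
```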

## Environment Setup
- Google Colab A100 GPU
- Model path: `/content/drive/MyDrive/phi3_3.8B`
- All logs saved locally (no external services)

## Expected Runtime
~10 GPU-hours total for all variants across three benchmarks


In [1]:
# Install required packages
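# Quote version specifiers and extras so the shell does not treat ">" as a redirect or "[...]" as a glob
!pip install -q "transformers>=4.41.0" accelerate torch tensorboard "lm-eval[api]" datasets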
# Skip flash-attn installation - we'll use auto-fallback to SDPA instead

# Mount Google Drive
from google.colab import drive
drive.mount('/content/drive')

# Verify model path exists
import os
model_path = "/content/drive/MyDrive/phi3_3.8B"
if os.path.exists(model_path):
    print(f"✓ Model found at {model_path}")
    print(f"Contents: {os.listdir(model_path)}")
else:
    print(f"✗ Model not found at {model_path}")
    print("Please ensure the model is uploaded to the correct location")


ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
gcsfs 2025.3.2 requires fsspec==2025.3.2, but you have fsspec 2025.3.0 which is incompatible.
Mounted at /content/drive
✓ Model found at /content/drive/MyDrive/phi3_3.8B
Contents: ['.cache', 'LICENSE', 'added_tokens.json', '.gitattributes', 'README.md', 'SECURITY.md', 'NOTICE.md', 'config.json', 'CODE_OF_CONDUCT.md', 'configuration_phi3.py', 'generation_config.json', 'model.safetensors.index.json', 'modeling_phi3.py', 'special_tokens_map.json', 'sample_finetune.py', 'tokenizer_config.json', 'tokenizer.model', 'tokenizer.json', 'model-00002-of-00002.safetensors', 'model-00001-of-00002.safetensors']


In [2]:
# Core imports
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import (
    AutoTokenizer, AutoModelForCausalLM, AutoConfig,
    GenerationConfig, set_seed
)
from transformers.models.phi3.modeling_phi3 import Phi3Model, Phi3ForCausalLM
import numpy as np
import json
import csv
import time
import hashlib
from datetime import datetime
from pathlib import Path
from dataclasses import dataclass, asdict
from typing import Dict, List, Optional, Tuple, Any, Union
from torch.utils.tensorboard import SummaryWriter
import logging
import subprocess
import sys

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Check GPU availability
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name()}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")


Using device: cuda
GPU: NVIDIA A100-SXM4-40GB
GPU Memory: 42.5 GB


In [3]:
# Configuration and logging utilities

def resolve_attn_impl(requested: str) -> str:
    """
    Resolve attention implementation with automatic fallback.

    Args:
        requested: 'auto', 'flash2', 'sdpa', or 'eager'

    Returns:
        The actual implementation string for transformers
    """
    if requested in {"sdpa", "eager"}:
        return requested
    if requested == "flash2":
        return "flash_attention_2"

    # Auto-detect: try flash attention first, fall back to SDPA
    try:
        # Importing the kernel entry point verifies that flash_attn is actually usable
        from flash_attn import flash_attn_func  # noqa: F401
        logger.info("✓ Flash Attention 2 available - using flash_attention_2")
        return "flash_attention_2"
    except Exception as e:  # ImportError, or a broken CUDA/ABI build
        logger.info(f"Flash Attention 2 not available ({str(e)[:50]}...) - falling back to SDPA")
        return "sdpa"

@dataclass
class ExperimentConfig:
    """Configuration for double-pass experiments"""
    model_path: str
    pass_type: str  # 'baseline', 'double_full', 'double_partial'
    second_pass_layers: Optional[int] = None
    residual_variant: str = 'raw'  # 'raw' or 'layernorm'
    attn_impl: str = 'auto'  # 'auto', 'flash2', 'sdpa', 'eager'
    seed: int = 42
    max_length: int = 2048
    temperature: float = 0.7
    top_p: float = 0.9
    do_sample: bool = True
    timestamp: Optional[str] = None
    git_sha: Optional[str] = None

    def __post_init__(self):
        if self.timestamp is None:
            self.timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        if self.git_sha is None:
            try:
                self.git_sha = subprocess.check_output(
                    ['git', 'rev-parse', 'HEAD'], stderr=subprocess.DEVNULL
                ).decode().strip()
            except (subprocess.CalledProcessError, OSError):
                # Not a git checkout, or git is not installed
                self.git_sha = "unknown"

class ExpLogger:
    """Local experiment logger - no external services"""

    def __init__(self, save_dir: str):
        self.save_dir = Path(save_dir)
        self.save_dir.mkdir(parents=True, exist_ok=True)

        # Initialize files
        self.token_stats_file = self.save_dir / "token_stats.jsonl"
        self.results_file = self.save_dir / "results.csv"
        self.tb_dir = self.save_dir / "tb"
        self.tb_dir.mkdir(exist_ok=True)

        # TensorBoard writer
        self.tb_writer = SummaryWriter(str(self.tb_dir))

        # Initialize CSV
        self.csv_headers = [
            'benchmark', 'variant', 'accuracy', 'num_samples',
            'avg_entropy_base', 'avg_entropy_double', 'runtime_sec',
            'tokens_generated', 'compute_overhead'
        ]

        with open(self.results_file, 'w', newline='') as f:
            writer = csv.DictWriter(f, fieldnames=self.csv_headers)
            writer.writeheader()

    def log_token_stats(self, stats: Dict[str, Any]):
        """Log per-token statistics"""
        with open(self.token_stats_file, 'a') as f:
            f.write(json.dumps(stats) + '\n')

    def log_scalar(self, tag: str, value: float, step: int):
        """Log scalar to TensorBoard"""
        self.tb_writer.add_scalar(tag, value, step)

    def log_histogram(self, tag: str, values: np.ndarray, step: int):
        """Log histogram to TensorBoard"""
        self.tb_writer.add_histogram(tag, values, step)

    def log_text(self, tag: str, text: str, step: int):
        """Log text to TensorBoard"""
        self.tb_writer.add_text(tag, text, step)

    def log_results(self, results: Dict[str, Any]):
        """Log aggregate results to CSV"""
        with open(self.results_file, 'a', newline='') as f:
            writer = csv.DictWriter(f, fieldnames=self.csv_headers)
            writer.writerow(results)

    def save_config(self, config: ExperimentConfig):
        """Save experiment configuration"""
        config_path = self.save_dir / "run_config.json"
        with open(config_path, 'w') as f:
            json.dump(asdict(config), f, indent=2)

    def finalize(self):
        """Close TensorBoard writer"""
        self.tb_writer.close()

        # Create summary
        summary_path = self.save_dir / "summary.txt"
        with open(summary_path, 'w') as f:
            f.write(f"Experiment completed at {datetime.now()}\n")
            f.write(f"Results saved to: {self.save_dir}\n")
            f.write(f"View TensorBoard: tensorboard --logdir {self.tb_dir}\n")

print("✓ Configuration and logging utilities loaded")


✓ Configuration and logging utilities loaded


In [4]:
# Double-pass Phi3 model implementation

class DoublePassPhi3(nn.Module):
    """
    Phi3 model wrapper that supports double forward passes.

    The key insight: after the first forward pass, we take the residual stream
    (hidden states before final LayerNorm) and feed it back into the transformer
    blocks for a second pass, then sample from the second pass logits.
    """

    def __init__(self, model_path: str, config: ExperimentConfig):
        super().__init__()
        self.config = config
        self.device = device

        # Resolve attention implementation
        attn_impl = resolve_attn_impl(config.attn_impl)
        print(f"Using attention implementation: {attn_impl}")

        # Load model and tokenizer
        print(f"Loading model from {model_path}...")
        self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
        self.model = AutoModelForCausalLM.from_pretrained(
            model_path,
            torch_dtype=torch.float16,
            device_map="auto",
            trust_remote_code=True,
            attn_implementation=attn_impl
        )

        # Ensure pad token exists
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        self.model.eval()
        self.num_layers = len(self.model.model.layers)

        print(f"✓ Model loaded with {self.num_layers} layers")
        print(f"✓ Model parameters: {sum(p.numel() for p in self.model.parameters()) / 1e9:.1f}B")

        # Set up generation config
        self.generation_config = GenerationConfig(
            max_length=config.max_length,
            temperature=config.temperature,
            top_p=config.top_p,
            do_sample=config.do_sample,
            pad_token_id=self.tokenizer.pad_token_id,
            eos_token_id=self.tokenizer.eos_token_id,
            use_cache=True
        )

    def get_residual_stream(self, input_ids: torch.Tensor, attention_mask: torch.Tensor = None,
                           past_key_values=None, use_cache: bool = True) -> Tuple[torch.Tensor, torch.Tensor, Any]:
        """
        Forward pass that returns logits and the residual stream before final LayerNorm.

        Returns:
            logits: [batch_size, seq_len, vocab_size]
            residual_stream: [batch_size, seq_len, hidden_size]
            past_key_values: KV cache for next iteration
        """
        # Run the model's own forward so it builds the causal mask and rotary position
        # inputs; a forward hook on the last block captures the hidden states before
        # the final norm (calling blocks directly with a 2D mask would skip both).
        captured = {}

        def _capture(module, args, output):
            # Decoder blocks return a tuple whose first element is the hidden states
            captured['residual'] = output[0] if isinstance(output, tuple) else output

        handle = self.model.model.layers[-1].register_forward_hook(_capture)
        try:
            with torch.no_grad():
                outputs = self.model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    past_key_values=past_key_values,
                    use_cache=use_cache
                )
        finally:
            handle.remove()

        # Residual stream BEFORE final norm (cloned so later passes cannot alias it)
        residual_stream = captured['residual'].clone()

        return outputs.logits, residual_stream, outputs.past_key_values

    def second_pass_forward(self, residual_stream: torch.Tensor, attention_mask: torch.Tensor = None,
                           num_layers: Optional[int] = None) -> torch.Tensor:
        """
        Second forward pass using residual stream as input.

        Args:
            residual_stream: [batch_size, seq_len, hidden_size]
            attention_mask: attention mask
            num_layers: number of layers to use (for partial passes)

        Returns:
            logits: [batch_size, seq_len, vocab_size]
        """
        with torch.no_grad():
            # Apply residual variant
            if self.config.residual_variant == 'layernorm':
                if hasattr(self.model.model, 'norm'):
                    hidden_states = self.model.model.norm(residual_stream)
                else:
                    # Use a simple LayerNorm if model doesn't have norm
                    hidden_states = F.layer_norm(residual_stream, residual_stream.shape[-1:])
            else:
                hidden_states = residual_stream

            # Determine number of layers to use
            layers_to_use = num_layers if num_layers is not None else self.num_layers
            layers_to_use = min(layers_to_use, self.num_layers)

            # Run the first `layers_to_use` blocks through the base model's own forward,
            # so it builds the causal mask and rotary position inputs the blocks expect
            # (calling a block with only a 2D padding mask supplies neither).
            base = self.model.model
            all_layers = base.layers
            base.layers = all_layers[:layers_to_use]  # temporary truncation for partial passes
            try:
                outputs = base(
                    inputs_embeds=hidden_states,
                    attention_mask=attention_mask,
                    use_cache=False  # Don't use cache for second pass
                )
            finally:
                base.layers = all_layers  # always restore the full stack

            # The base model already applies the final norm; project to vocabulary logits
            logits = self.model.lm_head(outputs.last_hidden_state)

            return logits

    def generate_with_double_pass(self, prompt: str, max_new_tokens: int = 100,
                                 logger: Optional[ExpLogger] = None) -> Dict[str, Any]:
        """
        Generate text using double forward pass strategy.

        Returns:
            Dictionary with generated text, token stats, and metadata
        """
        # Tokenize input
        inputs = self.tokenizer(prompt, return_tensors="pt", padding=True, truncation=True)
        input_ids = inputs["input_ids"].to(self.device)
        attention_mask = inputs["attention_mask"].to(self.device)

        # Initialize tracking
        generated_tokens = []
        token_stats = []
        past_key_values = None

        start_time = time.time()

        for step in range(max_new_tokens):
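            # First forward pass over the full sequence. The KV cache is disabled because
            # input_ids is re-fed in full each step and the second pass needs the residual
            # stream at every position, not only the newest one.
            logits1, residual_stream, _ = self.get_residual_stream(
                input_ids, attention_mask, past_key_values=None, use_cache=False
            )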

            # Get logits for current position
            current_logits1 = logits1[:, -1, :]  # [batch_size, vocab_size]

            if self.config.pass_type == 'baseline':
                # Use first pass logits directly
                final_logits = current_logits1
                entropy1 = self._compute_entropy(current_logits1)
                entropy2 = entropy1  # Same for baseline
            else:
                # Second forward pass
                num_layers = self.config.second_pass_layers if self.config.pass_type == 'double_partial' else None
                logits2 = self.second_pass_forward(residual_stream, attention_mask, num_layers)
                current_logits2 = logits2[:, -1, :]

                # Use second pass logits for sampling
                final_logits = current_logits2
                entropy1 = self._compute_entropy(current_logits1)
                entropy2 = self._compute_entropy(current_logits2)

            # Sample next token
            probs = F.softmax(final_logits / self.config.temperature, dim=-1)

            if self.config.do_sample:
                # Top-p sampling
                sorted_probs, sorted_indices = torch.sort(probs, descending=True)
                cumulative_probs = torch.cumsum(sorted_probs, dim=-1)

                # Remove tokens with cumulative probability above top_p
                sorted_indices_to_remove = cumulative_probs > self.config.top_p
                sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                sorted_indices_to_remove[..., 0] = 0

                indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
                probs[indices_to_remove] = 0
                probs = probs / probs.sum(dim=-1, keepdim=True)

                next_token = torch.multinomial(probs, num_samples=1)
            else:
                next_token = torch.argmax(probs, dim=-1, keepdim=True)

            # Record token statistics
            token_stats.append({
                'step': step,
                'token_id': next_token.item(),
                'token_text': self.tokenizer.decode(next_token.item()),
                'entropy_pass1': entropy1.item(),
                'entropy_pass2': entropy2.item(),
                'top5_tokens_pass1': torch.topk(F.softmax(current_logits1, dim=-1), 5).indices.tolist(),
                'top5_probs_pass1': torch.topk(F.softmax(current_logits1, dim=-1), 5).values.tolist(),
                'top5_tokens_pass2': torch.topk(F.softmax(final_logits, dim=-1), 5).indices.tolist() if self.config.pass_type != 'baseline' else None,
                'top5_probs_pass2': torch.topk(F.softmax(final_logits, dim=-1), 5).values.tolist() if self.config.pass_type != 'baseline' else None,
            })

            # Log to experiment logger if provided
            if logger:
                logger.log_token_stats({
                    'prompt_hash': hashlib.md5(prompt.encode()).hexdigest()[:8],
                    'step': step,
                    'token_id': next_token.item(),
                    'entropy_base': entropy1.item(),
                    'entropy_double': entropy2.item(),
                    'pass_type': self.config.pass_type
                })

            # Append token and continue
            generated_tokens.append(next_token.item())
            input_ids = torch.cat([input_ids, next_token], dim=-1)

            # Update attention mask
            attention_mask = torch.cat([attention_mask, torch.ones_like(next_token)], dim=-1)

            # Stop if EOS token
            if next_token.item() == self.tokenizer.eos_token_id:
                break

        generation_time = time.time() - start_time

        # Decode generated text
        generated_text = self.tokenizer.decode(generated_tokens, skip_special_tokens=True)
        full_text = self.tokenizer.decode(input_ids[0], skip_special_tokens=True)

        return {
            'prompt': prompt,
            'generated_text': generated_text,
            'full_text': full_text,
            'token_stats': token_stats,
            'generation_time': generation_time,
            'num_tokens': len(generated_tokens),
            'config': self.config
        }

    def _compute_entropy(self, logits: torch.Tensor) -> torch.Tensor:
        """Compute entropy of logits"""
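        # Cast to float32: float16 softmax over the full vocabulary loses precision
        log_probs = F.log_softmax(logits.float(), dim=-1)
        probs = log_probs.exp()
        entropy = -(probs * log_probs).sum(dim=-1)  # Shannon entropy in nats
        return entropy.mean()  # Average over batch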

print("✓ DoublePassPhi3 model implementation loaded")


✓ DoublePassPhi3 model implementation loaded
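
The top-p step inside `generate_with_double_pass` can be hard to follow in tensor form. The sketch below restates the same rule on a plain list of probabilities: a token is kept while the cumulative mass of the tokens ranked before it has not exceeded `top_p`, so the token that crosses the threshold is still kept. The function name `top_p_filter` and the example distribution are illustrative, not taken from the notebook.

```python
from typing import List

def top_p_filter(probs: List[float], top_p: float) -> List[float]:
    """Zero out tokens outside the nucleus and renormalize the rest."""
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    kept = [0.0] * len(probs)
    cumulative = 0.0
    for i in order:
        kept[i] = probs[i]
        cumulative += probs[i]
        if cumulative > top_p:  # the token that crosses the threshold is still kept
            break
    total = sum(kept)
    return [p / total for p in kept]

# With top_p = 0.7 the two most likely tokens form the nucleus
filtered = top_p_filter([0.5, 0.3, 0.15, 0.05], top_p=0.7)
```

Sampling then draws from `filtered`, so the two least likely tokens can never be chosen in this example.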


In [5]:
# lm-eval integration for benchmarks

class LMEvalWrapper:
    """Wrapper to integrate DoublePassPhi3 with lm-eval-harness"""

    def __init__(self, double_pass_model: DoublePassPhi3):
        self.model = double_pass_model
        self.tokenizer = double_pass_model.tokenizer
        self.device = double_pass_model.device

    def loglikelihood(self, requests):
        """Compute log-likelihood for multiple choice tasks"""
        results = []

        for context, continuation in requests:
            # Tokenize context and continuation
            context_tokens = self.tokenizer.encode(context, add_special_tokens=False)
            continuation_tokens = self.tokenizer.encode(continuation, add_special_tokens=False)

            # Full sequence
            full_tokens = context_tokens + continuation_tokens
            input_ids = torch.tensor([full_tokens], device=self.device)

            # Get logits from model
            with torch.no_grad():
                if self.model.config.pass_type == 'baseline':
                    # Standard forward pass
                    outputs = self.model.model(input_ids)
                    logits = outputs.logits
                else:
                    # Double pass
                    logits1, residual_stream, _ = self.model.get_residual_stream(input_ids)
                    if self.model.config.pass_type == 'double_partial':
                        logits = self.model.second_pass_forward(
                            residual_stream,
                            num_layers=self.model.config.second_pass_layers
                        )
                    else:
                        logits = self.model.second_pass_forward(residual_stream)

                # Calculate log-likelihood for continuation tokens
                shift_logits = logits[..., len(context_tokens)-1:-1, :].contiguous()
                shift_labels = torch.tensor(continuation_tokens, device=self.device).unsqueeze(0)

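                log_probs = F.log_softmax(shift_logits.float(), dim=-1)
                log_likelihood = log_probs.gather(-1, shift_labels.unsqueeze(-1)).squeeze(-1).sum().item()

                # is_greedy: True when every continuation token is the argmax prediction
                is_greedy = bool((log_probs.argmax(dim=-1) == shift_labels).all().item())

                # Return (log_likelihood, is_greedy)
                results.append((log_likelihood, is_greedy))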

        return results

    def generate_until(self, requests):
        """Generate text until stop sequences"""
        results = []

        for context, gen_kwargs in requests:
            # Extract generation parameters
            max_gen_toks = gen_kwargs.get('max_gen_toks', 100)
            until = gen_kwargs.get('until', [])

            # Generate with double pass
            result = self.model.generate_with_double_pass(
                context,
                max_new_tokens=max_gen_toks
            )

            generated = result['generated_text']

            # Apply stop sequences
            for stop_seq in until:
                if stop_seq in generated:
                    generated = generated.split(stop_seq)[0]
                    break

            results.append(generated)

        return results

def run_benchmark(model_wrapper: LMEvalWrapper, benchmark_name: str,
                 config: ExperimentConfig, logger: ExpLogger) -> Dict[str, Any]:
    """Run a specific benchmark using lm-eval"""

    print(f"Running {benchmark_name} benchmark...")
    start_time = time.time()

    # Map benchmark names to lm-eval task names
    task_mapping = {
        'mmlu': 'mmlu',
        'bbh': 'bigbench_hard',
        'gsm8k': 'gsm8k'
    }

    task_name = task_mapping.get(benchmark_name.lower(), benchmark_name)

    try:
        # Import lm-eval
        from lm_eval import evaluator
        from lm_eval.models.huggingface import HFLM

        # Create a custom model class that uses our wrapper
        class CustomHFLM(HFLM):
            def __init__(self, wrapper):
                self.wrapper = wrapper
                self.tokenizer = wrapper.tokenizer
                self.device = wrapper.device

            def loglikelihood(self, requests):
                return self.wrapper.loglikelihood(requests)

            def generate_until(self, requests):
                return self.wrapper.generate_until(requests)

        # Create model instance
        model = CustomHFLM(model_wrapper)

        # Run evaluation
        results = evaluator.simple_evaluate(
            model=model,
            tasks=[task_name],
            num_fewshot=5 if benchmark_name.lower() in ['mmlu', 'gsm8k'] else 0,
            batch_size=1,  # Keep small for memory
            device=str(model_wrapper.device),
            no_cache=True
        )

        runtime = time.time() - start_time

        # Extract results
        task_results = results['results'][task_name]
        accuracy = task_results.get('acc', task_results.get('exact_match', 0.0))

        # Log results
        result_dict = {
            'benchmark': benchmark_name,
            'variant': f"{config.pass_type}_{config.second_pass_layers if config.second_pass_layers else 'full'}_{config.residual_variant}",
            'accuracy': accuracy,
            'num_samples': task_results.get('num_samples', 0),
            'avg_entropy_base': 0.0,  # Will be computed from token stats
            'avg_entropy_double': 0.0,  # Will be computed from token stats
            'runtime_sec': runtime,
            'tokens_generated': 0,  # Will be computed from token stats
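            'compute_overhead': (
                1.0 if config.pass_type == 'baseline'
                # Approximate block count: N first-pass layers plus k second-pass layers
                else 1.0 + (config.second_pass_layers or model_wrapper.model.num_layers)
                / model_wrapper.model.num_layers
            )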
        }

        logger.log_results(result_dict)
        logger.log_scalar(f"{benchmark_name}/accuracy", accuracy, 0)

        print(f"✓ {benchmark_name} completed: {accuracy:.3f} accuracy in {runtime:.1f}s")

        return {
            'benchmark': benchmark_name,
            'accuracy': accuracy,
            'runtime': runtime,
            'full_results': task_results
        }

    except Exception as e:
        print(f"✗ Error running {benchmark_name}: {str(e)}")
        return {
            'benchmark': benchmark_name,
            'accuracy': 0.0,
            'runtime': 0.0,
            'error': str(e)
        }

print("✓ lm-eval integration loaded")


✓ lm-eval integration loaded
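
The index arithmetic in `loglikelihood` relies on the fact that the logits at position t score the token at position t + 1. A minimal pure-Python sketch of that alignment follows, assuming a toy three-token vocabulary. `continuation_loglikelihood` is a hypothetical helper, not an lm-eval API.

```python
import math
from typing import Sequence, Tuple

def continuation_loglikelihood(
    logits: Sequence[Sequence[float]],  # one row of vocabulary logits per input position
    context_len: int,                   # number of context tokens (must be >= 1)
    continuation: Sequence[int],        # token ids to score
) -> Tuple[float, bool]:
    """Sum log P(continuation | context) and report whether greedy decoding would match."""
    total, is_greedy = 0.0, True
    for offset, token in enumerate(continuation):
        row = logits[context_len - 1 + offset]  # row t predicts token t + 1
        m = max(row)
        log_z = m + math.log(sum(math.exp(x - m) for x in row))
        total += row[token] - log_z
        argmax = max(range(len(row)), key=row.__getitem__)
        is_greedy = is_greedy and argmax == token
    return total, is_greedy

# Sequence [0, 1, 2]: one context token, two continuation tokens (toy logits)
toy_logits = [[0.0, 2.0, 0.0], [0.0, 0.0, 2.0], [0.0, 0.0, 0.0]]
ll, greedy = continuation_loglikelihood(toy_logits, context_len=1, continuation=[1, 2])
```

The last row of `toy_logits` is never read, which mirrors the `[..., len(context_tokens)-1:-1, :]` slice in the wrapper.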


In [6]:
# Experiment runner and main loop

def run_experiment(config: ExperimentConfig, benchmarks: List[str] = None) -> Dict[str, Any]:
    """Run a complete experiment with the given configuration"""

    if benchmarks is None:
        benchmarks = ['mmlu', 'bbh', 'gsm8k']

    # Create logger
    run_dir = f"runs/{config.timestamp}-{config.pass_type}"
    if config.second_pass_layers:
        run_dir += f"-{config.second_pass_layers}layers"
    run_dir += f"-{config.residual_variant}"

    logger = ExpLogger(run_dir)
    logger.save_config(config)

    print(f"Starting experiment: {config.pass_type}")
    print(f"Logs will be saved to: {run_dir}")

    # Initialize model
    try:
        model = DoublePassPhi3(config.model_path, config)
        wrapper = LMEvalWrapper(model)

        # Run benchmarks
        results = {}
        for benchmark in benchmarks:
            print(f"\n{'='*50}")
            print(f"Running {benchmark.upper()} benchmark")
            print(f"{'='*50}")

            result = run_benchmark(wrapper, benchmark, config, logger)
            results[benchmark] = result

            # Log to TensorBoard
            logger.log_scalar(f"accuracy/{benchmark}", result['accuracy'], 0)
            logger.log_scalar(f"runtime/{benchmark}", result['runtime'], 0)

        # Compute aggregate metrics
        total_accuracy = np.mean([r['accuracy'] for r in results.values()])
        total_runtime = sum([r['runtime'] for r in results.values()])

        logger.log_scalar("accuracy/overall", total_accuracy, 0)
        logger.log_scalar("runtime/total", total_runtime, 0)

        print(f"\n{'='*50}")
        print(f"EXPERIMENT COMPLETE")
        print(f"{'='*50}")
        print(f"Overall accuracy: {total_accuracy:.3f}")
        print(f"Total runtime: {total_runtime:.1f}s")
        print(f"Results saved to: {run_dir}")

        logger.finalize()

        return {
            'config': config,
            'results': results,
            'overall_accuracy': total_accuracy,
            'total_runtime': total_runtime,
            'run_dir': run_dir
        }

    except Exception as e:
        print(f"✗ Experiment failed: {str(e)}")
        logger.finalize()
        raise e

def run_all_experiments(model_path: str, benchmarks: List[str] = None) -> List[Dict[str, Any]]:
    """Run all experimental conditions"""

    if benchmarks is None:
        benchmarks = ['mmlu', 'bbh', 'gsm8k']

    # Define experimental conditions
    experiments = []

    # 1. Baseline
    experiments.append(ExperimentConfig(
        model_path=model_path,
        pass_type='baseline',
        residual_variant='raw'
    ))

    # 2. Double-pass full
    experiments.append(ExperimentConfig(
        model_path=model_path,
        pass_type='double_full',
        residual_variant='raw'
    ))

    # 3. Double-pass with LayerNorm
    experiments.append(ExperimentConfig(
        model_path=model_path,
        pass_type='double_full',
        residual_variant='layernorm'
    ))

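    # 4. Partial double-pass (k in {1, 2, 4, 8, N/2}); k = N is covered by double_full above
    num_hidden_layers = AutoConfig.from_pretrained(model_path, trust_remote_code=True).num_hidden_layers
    for num_layers in [1, 2, 4, 8, num_hidden_layers // 2]: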
        experiments.append(ExperimentConfig(
            model_path=model_path,
            pass_type='double_partial',
            second_pass_layers=num_layers,
            residual_variant='raw'
        ))

    print(f"Running {len(experiments)} experiments across {len(benchmarks)} benchmarks")
    print(f"Estimated total time: {len(experiments) * len(benchmarks) * 30 / 60:.1f} minutes")

    # Run all experiments
    all_results = []
    for i, config in enumerate(experiments):
        print(f"\n{'='*60}")
        print(f"EXPERIMENT {i+1}/{len(experiments)}")
        print(f"Config: {config.pass_type}, layers: {config.second_pass_layers}, variant: {config.residual_variant}")
        print(f"{'='*60}")

        try:
            result = run_experiment(config, benchmarks)
            all_results.append(result)

            # Clear GPU cache between experiments
            torch.cuda.empty_cache()

        except Exception as e:
            print(f"✗ Experiment {i+1} failed: {str(e)}")
            continue

    print(f"\n{'='*60}")
    print(f"ALL EXPERIMENTS COMPLETE")
    print(f"{'='*60}")
    print(f"Completed {len(all_results)}/{len(experiments)} experiments")

    return all_results

print("✓ Experiment runner loaded")


✓ Experiment runner loaded
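
The `avg_entropy_base`, `avg_entropy_double`, and `tokens_generated` columns are written as placeholders and are meant to be derived from `token_stats.jsonl`. One way to do that is sketched below, assuming the file holds one JSON record per line with the `entropy_base` and `entropy_double` keys that `log_token_stats` receives, and that generation was called with a logger so the file exists. `summarize_token_stats` is a hypothetical helper name.

```python
import json
from pathlib import Path
from typing import Dict, Union

def summarize_token_stats(path: Union[str, Path]) -> Dict[str, float]:
    """Average the per-token entropies stored in a JSONL file."""
    count, sum_base, sum_double = 0, 0.0, 0.0
    with open(path) as f:
        for line in f:
            if not line.strip():
                continue  # tolerate blank lines
            record = json.loads(line)
            sum_base += record["entropy_base"]
            sum_double += record["entropy_double"]
            count += 1
    if count == 0:
        return {"tokens_generated": 0, "avg_entropy_base": 0.0, "avg_entropy_double": 0.0}
    return {
        "tokens_generated": count,
        "avg_entropy_base": sum_base / count,
        "avg_entropy_double": sum_double / count,
    }
```

The returned keys match the CSV headers in `ExpLogger`, so the dictionary could be merged into a results row before `log_results` is called.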


In [7]:
# Quick test and demo cell

def test_model_loading():
    """Test that the model loads correctly"""
    print("Testing model loading...")

    config = ExperimentConfig(
        model_path=model_path,
        pass_type='baseline'
    )

    try:
        model = DoublePassPhi3(model_path, config)
        print("✓ Model loaded successfully")
        print(f"  - Layers: {model.num_layers}")
        print(f"  - Parameters: {sum(p.numel() for p in model.model.parameters()) / 1e9:.1f}B")
        return model
    except Exception as e:
        print(f"✗ Model loading failed: {str(e)}")
        return None

def demo_generation():
    """Demo generation with different pass types"""
    print("\nTesting generation with different pass types...")

    test_prompts = [
        "The capital of France is",
        "2 + 2 equals",
        "The largest planet in our solar system is"
    ]

    # Test baseline
    print("\n--- BASELINE ---")
    config_baseline = ExperimentConfig(
        model_path=model_path,
        pass_type='baseline',
        attn_impl='auto',  # Auto-detect: Flash Attention 2 if available, else SDPA
        temperature=0.7,
        max_length=50
    )

    try:
        model_baseline = DoublePassPhi3(model_path, config_baseline)

        for prompt in test_prompts:
            result = model_baseline.generate_with_double_pass(prompt, max_new_tokens=20)
            print(f"Prompt: {prompt}")
            print(f"Output: {result['generated_text']}")
            print(f"Tokens: {result['num_tokens']}, Time: {result['generation_time']:.2f}s")
            print()

        del model_baseline
        torch.cuda.empty_cache()

    except Exception as e:
        print(f"✗ Baseline generation failed: {str(e)}")

    # Test double-pass
    print("\n--- DOUBLE PASS ---")
    config_double = ExperimentConfig(
        model_path=model_path,
        pass_type='double_full',
        attn_impl='auto',  # Auto-detect: Flash Attention 2 if available, else SDPA
        temperature=0.7,
        max_length=50
    )

    try:
        model_double = DoublePassPhi3(model_path, config_double)

        for prompt in test_prompts:
            result = model_double.generate_with_double_pass(prompt, max_new_tokens=20)
            print(f"Prompt: {prompt}")
            print(f"Output: {result['generated_text']}")
            print(f"Tokens: {result['num_tokens']}, Time: {result['generation_time']:.2f}s")

            # Show entropy comparison
            if result['token_stats']:
                avg_entropy1 = np.mean([s['entropy_pass1'] for s in result['token_stats']])
                avg_entropy2 = np.mean([s['entropy_pass2'] for s in result['token_stats']])
                print(f"Avg entropy - Pass 1: {avg_entropy1:.3f}, Pass 2: {avg_entropy2:.3f}")
            print()

        del model_double
        torch.cuda.empty_cache()

    except Exception as e:
        print(f"✗ Double-pass generation failed: {str(e)}")

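# Run tests
print("Running quick tests...")
test_model = test_model_loading()
if test_model is not None:
    # Free the test instance first: demo_generation() loads its own models
    del test_model
    torch.cuda.empty_cache()
    demo_generation()
    # demo_generation() catches its own errors, so this does not imply success
    print("✓ Quick tests finished - review the output above for any ✗ failures")
else:
    print("✗ Tests failed - check model path and setup")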


Running quick tests...
Testing model loading...
Using attention implementation: sdpa
Loading model from /content/drive/MyDrive/phi3_3.8B...




✗ Model loading failed: Phi3ForCausalLM does not support an attention implementation through torch.nn.functional.scaled_dot_product_attention yet. Please request the support for this architecture: https://github.com/huggingface/transformers/issues/28005. If you believe this error is a bug, please open an issue in Transformers GitHub repository and load your model with the argument `attn_implementation="eager"` meanwhile. Example: `model = AutoModel.from_pretrained("openai/whisper-tiny", attn_implementation="eager")`
✗ Tests failed - check model path and setup


In [None]:
## 🚀 Run Full Experiments

**Important**: The cell below will run all experiments. This will take several hours on an A100 GPU.

**Before running**:
1. Make sure the model loaded successfully in the test above
2. Ensure you have sufficient GPU memory
3. Consider running a single experiment first to verify everything works

**What will happen**:
- 7 different experimental conditions will be tested
- Each will be evaluated on MMLU, BigBench-Hard, and GSM8K
- Results will be saved to local `runs/` directory
- TensorBoard logs will be created for visualization

**Estimated time**: 6-10 hours total


In [None]:
# Run all experiments
# WARNING: This will take several hours!

# Set random seed for reproducibility
set_seed(42)

# Run all experiments
print("Starting full experimental suite...")
print("This will take several hours - grab a coffee (or several)!")

try:
    all_results = run_all_experiments(model_path)

    # Print summary
    print("\n" + "="*80)
    print("FINAL RESULTS SUMMARY")
    print("="*80)

    for result in all_results:
        config = result['config']
        variant_name = f"{config.pass_type}"
        if config.second_pass_layers:
            variant_name += f"_{config.second_pass_layers}layers"
        variant_name += f"_{config.residual_variant}"

        print(f"\n{variant_name}:")
        print(f"  Overall accuracy: {result['overall_accuracy']:.3f}")
        print(f"  Runtime: {result['total_runtime']:.1f}s")

        for benchmark, bench_result in result['results'].items():
            print(f"  {benchmark}: {bench_result['accuracy']:.3f}")

    print(f"\nAll results saved to individual run directories in 'runs/'")
    print("Use TensorBoard to visualize: tensorboard --logdir runs/")

except Exception as e:
    print(f"Experiment suite failed: {str(e)}")
    import traceback
    traceback.print_exc()


In [None]:
# Alternative: Run single experiment for testing
# Use this cell to test a single configuration before running the full suite

def run_single_experiment_test():
    """Run a single experiment for testing purposes"""

    print("Running single experiment test...")

    # Configuration for a smoke test; 'double_full' exercises the two-pass path but is slower than 'baseline'
    config = ExperimentConfig(
        model_path=model_path,
        pass_type='double_full',  # or 'baseline' for faster testing
        residual_variant='raw',
        attn_impl='auto',  # Auto-detect: Flash Attention 2 if available, else SDPA
        temperature=0.7,
        max_length=512  # Shorter for testing
    )

    # Run just one benchmark for testing
    benchmarks = ['mmlu']  # Start with just MMLU

    try:
        result = run_experiment(config, benchmarks)
        print(f"\nTest completed successfully!")
        print(f"MMLU accuracy: {result['results']['mmlu']['accuracy']:.3f}")
        print(f"Runtime: {result['total_runtime']:.1f}s")
        print(f"Results saved to: {result['run_dir']}")

        return result

    except Exception as e:
        print(f"Single experiment test failed: {str(e)}")
        import traceback
        traceback.print_exc()
        return None

# Run the single experiment test (comment out the line below to skip it)
test_result = run_single_experiment_test()


Running single experiment test...
Starting experiment: double_full
Logs will be saved to: runs/20250705_233613-double_full-raw
Loading model from /content/drive/MyDrive/phi3_3.8B...
✗ Experiment failed: FlashAttention2 has been toggled on, but it cannot be used due to the following error: the package flash_attn seems to be not installed. Please refer to the documentation of https://huggingface.co/docs/transformers/perf_infer_gpu_one#flashattention-2 to install Flash Attention 2.
Single experiment test failed: FlashAttention2 has been toggled on, but it cannot be used due to the following error: the package flash_attn seems to be not installed. Please refer to the documentation of https://huggingface.co/docs/transformers/perf_infer_gpu_one#flashattention-2 to install Flash Attention 2.


Traceback (most recent call last):
  File "/tmp/ipython-input-11-229376830.py", line 22, in run_single_experiment_test
    result = run_experiment(config, benchmarks)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/ipython-input-6-640186334.py", line 67, in run_experiment
    raise e
  File "/tmp/ipython-input-6-640186334.py", line 23, in run_experiment
    model = DoublePassPhi3(config.model_path, config)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/ipython-input-4-234219190.py", line 20, in __init__
    self.model = AutoModelForCausalLM.from_pretrained(
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/transformers/models/auto/auto_factory.py", line 593, in from_pretrained
    return model_class.from_pretrained(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/transformers/modeling_utils.py", line 311, in _wrapper
    return func(*args, **kwargs)
           ^

In [None]:
# Analysis and visualization utilities

def analyze_results(runs_dir: str = "runs/"):
    """Analyze results from all completed experiments"""

    import pandas as pd
    import matplotlib.pyplot as plt

    # Collect all results.csv files
    results_files = list(Path(runs_dir).glob("*/results.csv"))

    if not results_files:
        print("No results found. Run experiments first!")
        return

    # Load all results
    all_data = []
    for file in results_files:
        df = pd.read_csv(file)
        run_name = file.parent.name
        df['run_name'] = run_name
        all_data.append(df)

    combined_df = pd.concat(all_data, ignore_index=True)

    print(f"Found {len(combined_df)} benchmark results across {len(results_files)} runs")

    # Create summary plots
    fig, axes = plt.subplots(2, 2, figsize=(15, 10))

    # 1. Accuracy by variant
    accuracy_by_variant = combined_df.groupby('variant')['accuracy'].mean().sort_values(ascending=False)
    axes[0, 0].bar(range(len(accuracy_by_variant)), accuracy_by_variant.values)
    axes[0, 0].set_xticks(range(len(accuracy_by_variant)))
    axes[0, 0].set_xticklabels(accuracy_by_variant.index, rotation=45, ha='right')
    axes[0, 0].set_title('Average Accuracy by Variant')
    axes[0, 0].set_ylabel('Accuracy')

    # 2. Runtime by variant
    runtime_by_variant = combined_df.groupby('variant')['runtime_sec'].mean()
    axes[0, 1].bar(range(len(runtime_by_variant)), runtime_by_variant.values)
    axes[0, 1].set_xticks(range(len(runtime_by_variant)))
    axes[0, 1].set_xticklabels(runtime_by_variant.index, rotation=45, ha='right')
    axes[0, 1].set_title('Average Runtime by Variant')
    axes[0, 1].set_ylabel('Runtime (seconds)')

    # 3. Accuracy by benchmark
    benchmark_pivot = combined_df.pivot_table(values='accuracy', index='benchmark', columns='variant', aggfunc='mean')
    benchmark_pivot.plot(kind='bar', ax=axes[1, 0])
    axes[1, 0].set_title('Accuracy by Benchmark and Variant')
    axes[1, 0].set_ylabel('Accuracy')
    axes[1, 0].legend(bbox_to_anchor=(1.05, 1), loc='upper left')

    # 4. Compute overhead vs accuracy
    axes[1, 1].scatter(combined_df['compute_overhead'], combined_df['accuracy'], alpha=0.6)
    axes[1, 1].set_xlabel('Compute Overhead')
    axes[1, 1].set_ylabel('Accuracy')
    axes[1, 1].set_title('Accuracy vs Compute Overhead')

    plt.tight_layout()
    plt.savefig(Path(runs_dir) / "analysis_summary.png", dpi=300, bbox_inches='tight')
    plt.show()

    # Print summary statistics
    print("\n" + "="*50)
    print("SUMMARY STATISTICS")
    print("="*50)

    print(f"\nBest performing variant:")
    best_variant = accuracy_by_variant.index[0]
    best_accuracy = accuracy_by_variant.iloc[0]
    print(f"  {best_variant}: {best_accuracy:.3f} accuracy")

    print(f"\nBaseline performance:")
    baseline_results = combined_df[combined_df['variant'].str.contains('baseline')]
    if not baseline_results.empty:
        baseline_acc = baseline_results['accuracy'].mean()
        print(f"  Baseline: {baseline_acc:.3f} accuracy")

        improvement = best_accuracy - baseline_acc
        # Use the '+' format flag so a negative difference is not printed as "+-0.010"
        print(f"  Best improvement: {improvement:+.3f} accuracy ({improvement/baseline_acc*100:+.1f}%)")

    print(f"\nRuntime comparison:")
    baseline_runtime = combined_df[combined_df['variant'].str.contains('baseline')]['runtime_sec'].mean()
    double_runtime = combined_df[combined_df['variant'].str.contains('double')]['runtime_sec'].mean()
    if baseline_runtime > 0 and double_runtime > 0:
        overhead = double_runtime / baseline_runtime
        print(f"  Baseline: {baseline_runtime:.1f}s")
        print(f"  Double-pass: {double_runtime:.1f}s")
        print(f"  Overhead: {overhead:.1f}x")

    return combined_df

def launch_tensorboard(runs_dir: str = "runs/"):
    """Launch TensorBoard to visualize results"""

    print(f"Launching TensorBoard for {runs_dir}")
    print("Note: In Colab, you may need to use the public URL")

    # Optional: install the TensorBoard profiler plugin (not required to load the extension below)
    try:
        get_ipython().system('pip install -q tensorboard-plugin-profile')

        # Load tensorboard extension
        get_ipython().run_line_magic('load_ext', 'tensorboard')

        # Launch tensorboard
        get_ipython().run_line_magic('tensorboard', f'--logdir {runs_dir}')

    except Exception as e:
        print(f"Could not launch TensorBoard in notebook: {str(e)}")
        print(f"Run manually: tensorboard --logdir {runs_dir}")

# Uncomment to run analysis after experiments complete
# results_df = analyze_results()
# launch_tensorboard()


In [None]:
# Attention Implementation Demo

def demo_attention_implementations():
    """Demo different attention implementations to verify fallback works"""

    print("Testing different attention implementations...")
    print("This verifies that the automatic fallback from Flash Attention to SDPA works correctly.\n")

    test_prompt = "The capital of France is"

    # Test different attention implementations
    attention_configs = [
        ("Auto-detect (recommended)", "auto"),
        ("Force SDPA (safe fallback)", "sdpa"),
        ("Force Eager (slowest but reliable)", "eager"),
        ("Force Flash Attention 2 (will fail if not available)", "flash2")
    ]

    for name, attn_impl in attention_configs:
        print(f"--- {name} ---")

        config = ExperimentConfig(
            model_path=model_path,
            pass_type='baseline',  # Use baseline for faster testing
            attn_impl=attn_impl,
            temperature=0.7,
            max_length=100
        )

        try:
            start_time = time.time()
            model = DoublePassPhi3(model_path, config)
            load_time = time.time() - start_time

            # Quick generation test
            result = model.generate_with_double_pass(test_prompt, max_new_tokens=10)

            print(f"✓ Success! Load time: {load_time:.2f}s")
            print(f"  Generated: {result['generated_text']}")
            print(f"  Generation time: {result['generation_time']:.3f}s")

            # Clean up
            del model
            torch.cuda.empty_cache()

        except Exception as e:
            print(f"✗ Failed: {str(e)}")

        print()

# Summary of the available options
print("Available attention implementation options:")
print("- 'auto': Automatically detect Flash Attention 2, fallback to SDPA")
print("- 'flash2': Force Flash Attention 2 (fails if not available)")
print("- 'sdpa': Force PyTorch SDPA (fast, reliable)")
print("- 'eager': Force eager attention (slowest but most compatible)")
print("\nRecommendation: Use 'auto' for automatic fallback behavior")

# Uncomment to test attention implementations
# demo_attention_implementations()


In [None]:
## 📊 Interactive Demo & Exploration

Use the cells below to interactively explore the double-pass generation and compare outputs between different configurations.


In [None]:
# Interactive prompt testing

def interactive_comparison(prompt: str, max_tokens: int = 50):
    """Compare baseline vs double-pass generation for a given prompt"""

    print(f"Prompt: '{prompt}'")
    print("="*60)

    configs = [
        ("Baseline", ExperimentConfig(model_path=model_path, pass_type='baseline')),
        ("Double-Pass Full", ExperimentConfig(model_path=model_path, pass_type='double_full')),
        ("Double-Pass 4 Layers", ExperimentConfig(model_path=model_path, pass_type='double_partial', second_pass_layers=4)),
        ("Double-Pass + LayerNorm", ExperimentConfig(model_path=model_path, pass_type='double_full', residual_variant='layernorm'))
    ]

    results = {}

    for name, config in configs:
        print(f"\n--- {name} ---")
        try:
            model = DoublePassPhi3(model_path, config)
            result = model.generate_with_double_pass(prompt, max_new_tokens=max_tokens)

            print(f"Generated: {result['generated_text']}")
            print(f"Time: {result['generation_time']:.2f}s")
            print(f"Tokens: {result['num_tokens']}")

            if result['token_stats']:
                entropies1 = [s['entropy_pass1'] for s in result['token_stats']]
                entropies2 = [s['entropy_pass2'] for s in result['token_stats']]
                print(f"Avg entropy pass 1: {np.mean(entropies1):.3f}")
                print(f"Avg entropy pass 2: {np.mean(entropies2):.3f}")

                # Show token-by-token comparison for first few tokens
                print("Token details (first 5):")
                for i, stats in enumerate(result['token_stats'][:5]):
                    print(f"  {i+1}. '{stats['token_text']}' - H1: {stats['entropy_pass1']:.3f}, H2: {stats['entropy_pass2']:.3f}")

            results[name] = result

            # Clean up
            del model
            torch.cuda.empty_cache()

        except Exception as e:
            print(f"Error: {str(e)}")
            results[name] = None

    return results

# Test prompts for different reasoning types
test_prompts = [
    "The largest city in Australia is",
    "If I have 3 apples and give away 2, how many do I have left?",
    "The chemical symbol for gold is",
    "What is the square root of 64?",
    "Complete this sequence: 2, 4, 8, 16,",
    "The author of '1984' is"
]

print("Available test prompts:")
for i, prompt in enumerate(test_prompts):
    print(f"{i+1}. {prompt}")

# Example usage:
# results = interactive_comparison("The largest city in Australia is", max_tokens=30)


In [None]:
## 🔧 Flash Attention Compatibility Fix

### Problem Solved
This notebook now handles the Flash Attention 2 / PyTorch 2.6 compatibility issue automatically:

- **Before**: Hard-coded `flash_attention_2` caused crashes when FA2 wasn't compatible
- **After**: Intelligent fallback system that detects available attention implementations

### New Attention Parameter
All `ExperimentConfig` instances now support an `attn_impl` parameter:

```python
config = ExperimentConfig(
    model_path=model_path,
    pass_type='baseline',
    attn_impl='auto',  # ← New parameter
    # ... other params
)
```

### Attention Implementation Options

| Option | Description | When to Use |
|--------|-------------|-------------|
| `'auto'` | **Recommended**. Auto-detect Flash Attention 2, fallback to SDPA | Default for all use cases |
| `'sdpa'` | Force PyTorch SDPA (Scaled Dot Product Attention) | When you want consistent behavior |
| `'eager'` | Force basic eager attention | Debugging or maximum compatibility |
| `'flash2'` | Force Flash Attention 2 | Only when you know FA2 is available |
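
The fallback order in the table can be sketched as below. This is a minimal illustration under stated assumptions, not the notebook's `DoublePassPhi3` loader: `load_with_fallback`, `FALLBACK_ORDER`, and the `loader` callable are hypothetical names, and the final fall-through to `eager` is an assumption motivated by the SDPA error in the quick-test output. The strings `flash_attention_2`, `sdpa`, and `eager` are values that Hugging Face `from_pretrained` accepts for `attn_implementation`.

```python
from typing import Callable, Tuple, TypeVar

T = TypeVar("T")

# Assumed mapping from this notebook's attn_impl option to the
# attn_implementation strings tried in order (hypothetical helper).
FALLBACK_ORDER = {
    "auto": ("flash_attention_2", "sdpa", "eager"),
    "flash2": ("flash_attention_2",),
    "sdpa": ("sdpa",),
    "eager": ("eager",),
}


def load_with_fallback(loader: Callable[[str], T], attn_impl: str = "auto") -> Tuple[T, str]:
    """Try each candidate implementation in order; return (model, impl_used)."""
    errors = []
    for impl in FALLBACK_ORDER[attn_impl]:
        try:
            return loader(impl), impl
        except (ImportError, ValueError) as exc:
            # Assumption: a missing flash_attn package surfaces as ImportError
            # and an unsupported SDPA architecture as ValueError.
            errors.append(f"{impl}: {exc}")
    raise RuntimeError("No attention implementation could be loaded:\n" + "\n".join(errors))


# Usage in the notebook (needs transformers, so shown as a comment only):
# from transformers import AutoModelForCausalLM
# model, used = load_with_fallback(
#     lambda impl: AutoModelForCausalLM.from_pretrained(model_path, attn_implementation=impl)
# )
```

Passing the loader as a callable keeps the fallback logic independent of the model class, so it can be checked with a stub before any weights are loaded.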

### Performance Impact
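- **Flash Attention 2**: Fastest (reference point, 1.0x)
- **SDPA**: Good performance (roughly 1.3-1.5x slower than FA2)
- **Eager**: Slowest (roughly 2-3x slower than FA2)

These ratios are rough estimates; the actual gap depends on sequence length, batch size, and hardware.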

For research purposes, the performance difference between FA2 and SDPA is usually acceptable.
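
The ratios above can be checked on your own hardware with a small timing helper. The sketch below is illustrative: `median_seconds` and `relative_slowdown` are hypothetical names, and the numbers in the usage lines are placeholders, not measurements from this notebook.

```python
import statistics
import time
from typing import Callable, Dict


def median_seconds(fn: Callable[[], object], repeats: int = 3) -> float:
    """Median wall-clock time of fn() over several repeats."""
    samples = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn()
        samples.append(time.perf_counter() - start)
    return statistics.median(samples)


def relative_slowdown(seconds_by_impl: Dict[str, float]) -> Dict[str, float]:
    """Express each timing as a multiple of the fastest implementation."""
    fastest = min(seconds_by_impl.values())
    return {name: secs / fastest for name, secs in seconds_by_impl.items()}


# Placeholder timings in seconds, for illustration only
ratios = relative_slowdown({"flash2": 2.0, "sdpa": 3.0, "eager": 5.0})
print(ratios)
```

When timing GPU generation, call `torch.cuda.synchronize()` inside the timed function so that asynchronous kernels are included in the measurement.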


In [None]:
## 📝 Usage Instructions

### Quick Start
1. **Run the setup cells** (1-4) to install packages and load utilities
2. **Test model loading** with the quick test cell
3. **Try interactive demo** with a few prompts to verify everything works
4. **Run single experiment** to test one configuration
5. **Run full experiment suite** (will take 6-10 hours)
6. **Analyze results** using the analysis utilities

### File Structure
After running experiments, you'll have:
```
runs/
├── YYYYMMDD_HHMMSS-baseline-raw/
│   ├── run_config.json
│   ├── token_stats.jsonl
│   ├── results.csv
│   ├── tb/                # TensorBoard logs
│   └── summary.txt
├── YYYYMMDD_HHMMSS-double_full-raw/
│   └── ...
└── analysis_summary.png   # Generated by analysis
```
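
Run directory names can be split back into their parts when collecting results. The sketch below assumes the `YYYYMMDD_HHMMSS-<pass_type>-<residual_variant>` pattern shown above (for example `20250705_233613-double_full-raw`); `RunInfo` and `parse_run_name` are hypothetical helpers, not part of the notebook, and any extra segments that partial-pass runs may add are kept inside `pass_type`.

```python
from datetime import datetime
from typing import NamedTuple


class RunInfo(NamedTuple):
    started: datetime
    pass_type: str
    residual_variant: str


def parse_run_name(name: str) -> RunInfo:
    """Split 'YYYYMMDD_HHMMSS-<pass_type>-<residual_variant>' into its parts."""
    stamp, rest = name.split("-", 1)                    # timestamp comes first
    pass_type, residual_variant = rest.rsplit("-", 1)   # variant comes last
    started = datetime.strptime(stamp, "%Y%m%d_%H%M%S")
    return RunInfo(started, pass_type, residual_variant)


info = parse_run_name("20250705_233613-double_full-raw")
print(info.pass_type, info.residual_variant, info.started.isoformat())
```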

### Key Functions
- `interactive_comparison(prompt)` - Compare different configurations on a prompt
- `run_single_experiment_test()` - Test one configuration quickly
- `run_all_experiments()` - Run full experimental suite
- `analyze_results()` - Generate summary plots and statistics
- `launch_tensorboard()` - View detailed logs

### Expected Results
- **Hypothesis**: Double-pass should reduce entropy and improve accuracy
- **Metrics**: Track accuracy, entropy, runtime, compute overhead
- **Benchmarks**: MMLU, BigBench-Hard, GSM8K
- **Variants**: Baseline, full double-pass, partial double-pass, LayerNorm variants
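
The entropy part of the hypothesis can be made concrete with a short sketch. It assumes token-level entropy means the Shannon entropy of the softmax over next-token logits, measured in nats (the notebook's own unit is not stated here), and it reads the `entropy_pass1` and `entropy_pass2` keys used in `token_stats`. `softmax_entropy` and `mean_entropy_change` are hypothetical helper names.

```python
import math
from typing import Dict, List, Sequence


def softmax_entropy(logits: Sequence[float]) -> float:
    """Shannon entropy (in nats) of softmax(logits)."""
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]  # subtract the max for stability
    total = sum(exps)
    probs = [e / total for e in exps]
    return -sum(p * math.log(p) for p in probs if p > 0.0)


def mean_entropy_change(token_stats: List[Dict[str, float]]) -> float:
    """Mean of (entropy_pass2 - entropy_pass1); negative means pass 2 is sharper."""
    deltas = [s["entropy_pass2"] - s["entropy_pass1"] for s in token_stats]
    return sum(deltas) / len(deltas)


# Illustrative values only, not results from any run
uniform = softmax_entropy([0.0, 0.0, 0.0, 0.0])   # equals log(4) for 4 equal logits
peaked = softmax_entropy([10.0, 0.0, 0.0, 0.0])   # much lower: one token dominates
change = mean_entropy_change([
    {"entropy_pass1": 1.0, "entropy_pass2": 0.5},
    {"entropy_pass1": 2.0, "entropy_pass2": 1.5},
])
```

Under the hypothesis, `mean_entropy_change` over a run's `token_stats` would be negative; a value near zero or positive would count as evidence against it.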

### Flash Attention Compatibility
- **Fixed**: Automatic fallback from Flash Attention 2 to SDPA when FA2 isn't available
- **Default**: All configs use `attn_impl='auto'` for automatic detection
- **Performance**: SDPA is roughly 1.3-1.5x slower than FA2 but much faster than eager attention
- **Override**: Set `attn_impl='sdpa'` or `attn_impl='eager'` to force specific implementations

### Troubleshooting
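- **Flash Attention errors**: `attn_impl='auto'` falls back to SDPA when FA2 is unavailable. If the installed `transformers` version reports that `Phi3ForCausalLM` does not support SDPA (as in the quick-test output above), set `attn_impl='eager'`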
- If model fails to load, check the path `/content/drive/MyDrive/phi3_3.8B`
- If out of memory, reduce `max_length` in configs
- If lm-eval fails, try installing older version: `pip install lm-eval==0.3.0`
- For TensorBoard issues, run manually: `tensorboard --logdir runs/`
