In [1]:
import os
import torch
import torch.nn as nn
import numpy as np
from typing import Dict, List, Optional, Tuple, Union
import json
import pandas as pd
from datasets import load_dataset
from tqdm import tqdm
import re
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

2025-07-09 19:21:55.773315: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1752088916.148414      35 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1752088916.255592      35 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


In [2]:
!pip install huggingface_hub



In [None]:
class TruthFlowInterventionHook:
    """Forward hook that shifts one layer's hidden states with a TruthFlow correction.

    On each forward pass of the hooked module the hook:
      1. reads the hidden state at the last sequence position,
      2. queries the flow model for a correction at the fixed time step ``t = 0.5``,
      3. projects that correction onto the row space of ``svd_basis``,
      4. scales it by ``alpha`` and adds it to every position of the output.
    """

    def __init__(self, flow_model, svd_basis, layer_id: int,
                 alpha: float = 5):
        """
        Args:
            flow_model: Trained flow model for this layer; called as
                ``flow_model(t, h)`` and expected to return a tensor shaped like ``h``.
            svd_basis: SVD basis vectors [k, D] from project_svd.py
            layer_id: Which layer to intervene on
            alpha: Flow correction strength; multiplies the projected correction
                before it is added to the hidden state.
        """
        self.flow_model = flow_model
        self.svd_basis = svd_basis  # [k, D]
        self.layer_id = layer_id
        self.alpha = alpha

        # Last-position activations seen before correction, keyed by layer id.
        self.original_activations = {}

    def intervention_hook(self, module, input, output):
        """Forward-hook callback: return ``output`` with the correction added.

        Args:
            module: The hooked module (unused).
            input: The module's positional inputs (unused).
            output: Either a ``[B, T, D]`` tensor or a tuple whose first element
                is such a tensor; remaining tuple elements are passed through.

        Returns:
            An object of the same kind as ``output`` (tensor or tuple) in which
            the hidden states carry the correction at every position.
        """
        is_tuple = isinstance(output, tuple)
        output_tensor = output[0] if is_tuple else output
        batch_size, seq_len, hidden_dim = output_tensor.shape

        dtype = output_tensor.dtype
        device = output_tensor.device

        # The correction is derived from the LAST position of whatever sequence
        # the module just processed.
        # NOTE(review): when the input is "question + answer" that position is the
        # final answer token, not the end of the question - confirm this is intended.
        h_q = output_tensor[:, -1, :]  # [B, D]
        self.original_activations[self.layer_id] = h_q.clone()  # optional logging

        h_corrected = self.compute_truthflow_intervention(h_q)  # [B, D]
        h_corrected = h_corrected.to(dtype=dtype, device=device)

        # One correction vector per batch row, broadcast over all T positions.
        delta = h_corrected - h_q
        delta_broadcast = delta.unsqueeze(1).expand(-1, seq_len, -1)  # [B, T, D]
        output_corrected = output_tensor + delta_broadcast

        if is_tuple:
            return (output_corrected,) + output[1:]
        return output_corrected

    def compute_truthflow_intervention(self, h_q: torch.Tensor) -> torch.Tensor:
        """Return ``h_q + alpha * P(flow(0.5, h_q))`` where ``P`` projects onto ``svd_basis``.

        The arithmetic runs in the flow model's parameter dtype (float32 when the
        model has no parameters) so that a half-precision hidden state can be fed
        to full-precision flow weights; the result is cast back to ``h_q.dtype``.

        Args:
            h_q: Hidden states of shape ``[B, D]``.

        Returns:
            Corrected hidden states, same shape, dtype and device as ``h_q``.
        """
        device = h_q.device
        self.flow_model.to(device)
        self.flow_model.eval()

        first_param = next(self.flow_model.parameters(), None)
        compute_dtype = first_param.dtype if first_param is not None else torch.float32
        self.svd_basis = self.svd_basis.to(device=device, dtype=compute_dtype)

        with torch.no_grad():
            h_in = h_q.to(compute_dtype)
            # Fixed midpoint time step for every batch row.
            t = torch.full((h_in.size(0), 1), 0.5, device=device, dtype=compute_dtype)
            delta = self.flow_model(t, h_in)  # [B, D]

            # Keep only the component of the correction inside span(svd_basis).
            projections = torch.matmul(delta, self.svd_basis.T)          # [B, k]
            projected_delta = torch.matmul(projections, self.svd_basis)  # [B, D]

            # alpha was previously stored but never applied; alpha == 1 reproduces
            # the old (unscaled) correction.
            h_corrected = h_in + self.alpha * projected_delta
        return h_corrected.to(h_q.dtype)



In [5]:
class TruthFlowInterventionManager:
    """Loads per-layer TruthFlow components and attaches/detaches them as forward hooks."""

    # Directory that holds train_flow_model.py (defines FixedFlowModel).
    FLOW_MODEL_SOURCE_DIR = '/kaggle/input/final-req-fixed'

    def __init__(self, model, flow_models_dir: str, svd_basis_dir: str,
                 intervention_layers: List[int], device: str = "cuda",
                 pair: str = "natural_questions", alpha: float = 5):
        """
        Args:
            model: The language model to intervene on
            flow_models_dir: Directory containing trained flow models, named
                ``flow_model_layer{layer}_truthfulqa_{pair}.pt``
            svd_basis_dir: Directory containing SVD basis vectors, named
                ``layer{layer}_truthfulqa_{pair}.pt``
            intervention_layers: List of layer indices to intervene on
            device: Device to run on
            pair: Suffix used in the checkpoint file names above
            alpha: Correction strength forwarded to every TruthFlowInterventionHook

        Raises:
            FileNotFoundError: If a flow model or SVD basis file is missing.
        """
        self.model = model
        self.intervention_layers = intervention_layers
        self.device = device
        self.pair = pair
        self.alpha = alpha
        self.hooks = {}               # layer_id -> handle from register_forward_hook
        self.intervention_hooks = {}  # layer_id -> TruthFlowInterventionHook

        # Load all required models and basis vectors
        self.load_intervention_components(flow_models_dir, svd_basis_dir)

    def load_intervention_components(self, flow_models_dir: str, svd_basis_dir: str):
        """Load the flow model and SVD basis of each intervention layer.

        Populates ``self.intervention_hooks`` with one hook object per layer.

        Raises:
            FileNotFoundError: If either file of a layer does not exist.
        """
        print("Loading intervention components...")

        for layer_id in self.intervention_layers:
            flow_path = os.path.join(flow_models_dir, f"flow_model_layer{layer_id}_truthfulqa_{self.pair}.pt")
            svd_path = os.path.join(svd_basis_dir, f"layer{layer_id}_truthfulqa_{self.pair}.pt")

            if not os.path.exists(flow_path):
                raise FileNotFoundError(f"Flow model not found: {flow_path}")
            if not os.path.exists(svd_path):
                raise FileNotFoundError(f"SVD basis not found: {svd_path}")

            print(f"✓ Loading flow model from {flow_path}")
            print(f"✓ Loading SVD basis from {svd_path}")

            # SECURITY NOTE: torch.load unpickles the file and can execute
            # arbitrary code; only point these directories at trusted checkpoints.
            flow_model = self.create_flow_model(hidden_dim=self.get_hidden_dim_for_layer(layer_id))
            flow_model.load_state_dict(torch.load(flow_path, map_location=self.device))
            flow_model.to(self.device)

            svd_basis = torch.load(svd_path, map_location=self.device)

            self.intervention_hooks[layer_id] = TruthFlowInterventionHook(
                flow_model=flow_model,
                svd_basis=svd_basis,
                layer_id=layer_id,
                alpha=self.alpha,
            )

            print(f"✓ Initialized intervention for layer {layer_id}")

    def get_hidden_dim_for_layer(self, layer_id: int) -> int:
        """Return the hidden size used for ``layer_id``.

        The value is read from ``model.config.hidden_size`` and is therefore the
        same for every layer; ``layer_id`` is accepted but not consulted.
        """
        return self.model.config.hidden_size

    def create_flow_model(self, hidden_dim: int):
        """Instantiate an untrained ``FixedFlowModel(hidden_dim)``.

        ``FixedFlowModel`` is imported from ``train_flow_model`` located in
        ``FLOW_MODEL_SOURCE_DIR``; the directory is added to ``sys.path`` once.
        """
        import sys

        # Called once per layer: avoid stacking duplicate sys.path entries.
        if self.FLOW_MODEL_SOURCE_DIR not in sys.path:
            sys.path.append(self.FLOW_MODEL_SOURCE_DIR)
        from train_flow_model import FixedFlowModel
        return FixedFlowModel(hidden_dim)

    def get_layer_module(self, layer_id: int):
        """Return ``model.model.layers[layer_id]`` (layout used for Gemma 2 here)."""
        layer = self.model.model.layers[layer_id]
        print(f"Hooking into model layer {layer_id}: {layer}")
        return layer

    def register_hooks(self):
        """Attach one forward hook per intervention layer.

        Safe to call repeatedly: a handle left over from an earlier call is
        removed first, so a layer never ends up with the correction applied twice
        and ``remove_hooks`` can always detach everything.
        """
        print("Registering intervention hooks...")

        for layer_id in self.intervention_layers:
            stale_handle = self.hooks.pop(layer_id, None)
            if stale_handle is not None:
                stale_handle.remove()

            module = self.get_layer_module(layer_id)
            hook = module.register_forward_hook(self.intervention_hooks[layer_id].intervention_hook)
            self.hooks[layer_id] = hook
            print(f"✓ Hook registered for layer {layer_id}")

    def remove_hooks(self):
        """Detach every hook registered by this manager and forget the handles."""
        for layer_id, hook in self.hooks.items():
            hook.remove()
        self.hooks.clear()
        print("✓ Removed all intervention hooks")

In [6]:
# Disabled draft of a generation helper. It is wrapped in a bare string literal,
# so none of it executes; the cell merely evaluates to this string. Nothing in
# the cells shown here calls generate_with_truthflow.
'''
def generate_with_truthflow(model, tokenizer, prompt: str,
                          intervention_manager: TruthFlowInterventionManager,
                          max_length: int = 100, temperature: float = 0.7,
                          do_sample: bool = True) -> str:
    """Generate text with TruthFlow interventions applied"""
    
    # Register hooks
    intervention_manager.register_hooks()

    try:
        # Tokenize input
        inputs = tokenizer(prompt, return_tensors="pt").to(intervention_manager.device)

        # Generate with interventions
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_length=max_length,
                temperature=temperature,
                do_sample=do_sample,
                pad_token_id=tokenizer.eos_token_id
            )

        # Decode output
        generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)

        return generated_text

    finally:
        # Always remove hooks
        intervention_manager.remove_hooks()


'''

In [7]:
from datasets import load_dataset

def load_eval_data(start: int = 408, stop: int = 817):
    """Load the slice of TruthfulQA (multiple-choice) used for evaluation.

    Args:
        start: Index of the first validation row to keep (inclusive).
        stop: Index one past the last validation row to keep (exclusive).

    Returns:
        The ``validation`` split of ``truthful_qa`` / ``multiple_choice``
        restricted to rows ``start .. stop - 1``. The defaults keep rows
        408-816 (presumably the held-out second half - TODO confirm against the
        split used to train the flow models).
    """
    validation = load_dataset("truthful_qa", "multiple_choice", split="validation")
    return validation.select(range(start, stop))


In [None]:
from scipy.special import logsumexp
def _truthful_indices(labels, num_choices):
    """Turn a TruthfulQA ``labels`` field into the positions of the correct choices."""
    if isinstance(labels, int):
        # A bare int is taken to be the index of the single correct choice.
        return [labels]
    # Otherwise ``labels`` is a per-choice 0/1 mask aligned with ``choices``.
    return [idx for idx, flag in enumerate(list(labels)[:num_choices]) if flag == 1]


def _choice_log_probs(model, tokenizer, question, choices, device):
    """Score each choice by the mean log-probability of its tokens given the question."""
    # Tokenize the question once; its length marks where the answer tokens start.
    # NOTE(review): assumes tokenizing "question + ' ' + choice" keeps the question
    # tokens as an exact prefix - confirm for the tokenizer in use.
    question_len = tokenizer(question, return_tensors="pt").input_ids.shape[1]

    scores = []
    for choice in choices:
        full_input_ids = tokenizer(f"{question} {choice}", return_tensors="pt").input_ids.to(device)
        answer_tokens = full_input_ids[0, question_len:]
        if answer_tokens.numel() == 0:
            # Nothing to score; rank this choice last instead of producing NaN.
            scores.append(float("-inf"))
            continue

        with torch.no_grad():
            logits = model(full_input_ids).logits[0]

        # Logits at position i predict token i + 1, hence the shift by one.
        # Upcast so log_softmax is not computed in half precision.
        answer_log_probs = torch.log_softmax(logits[question_len - 1:-1].float(), dim=-1)
        token_log_probs = answer_log_probs.gather(1, answer_tokens.unsqueeze(1)).squeeze(1)
        scores.append(token_log_probs.mean().item())
    return scores


def _normalized_probs(log_probs):
    """Softmax over per-choice scores, shifted by the max for numerical stability."""
    log_probs = np.asarray(log_probs, dtype=np.float64)
    top = np.max(log_probs)
    if not np.isfinite(top):
        return np.full(len(log_probs), 1.0 / len(log_probs))
    weights = np.exp(log_probs - top)
    return weights / weights.sum()


def evaluate_layerwise_mcq(
    model_name: str,
    flow_models_dir: str,
    svd_basis_dir: str,
    tokenizer=None,
    layers=(8, 10, 12, 20),
    device="cuda",
    pair="nq",
):
    """Evaluate TruthfulQA MC1/MC2 with the TruthFlow intervention at one layer at a time.

    For every layer in ``layers`` a TruthFlowInterventionManager is hooked onto
    the model, every evaluation question is scored, and the hooks are removed
    again (also when scoring raises).

    * MC1: share of questions whose best-scoring ``mc1_targets`` choice is the
      one flagged correct.
    * MC2: mean normalized probability mass on the ``mc2_targets`` choices
      flagged correct (falls back to ``mc1_targets`` if the item has no
      ``mc2_targets``).

    ``labels`` in both target sets is a per-choice 0/1 mask, not a list of indices.

    Args:
        model_name: Hugging Face model id or path.
        flow_models_dir: Directory with the trained flow models.
        svd_basis_dir: Directory with the SVD bases.
        tokenizer: Optional pre-loaded tokenizer; loaded from ``model_name`` if omitted.
        layers: Layer indices to evaluate, one run per layer.
        device: Device the model and inputs are placed on.
        pair: Checkpoint-name suffix passed to TruthFlowInterventionManager.

    Returns:
        ``{layer: {"MC1": pct, "MC2": pct}}``; the same mapping is written to
        ``truthflow_mcq_scores.json`` in the working directory.
    """
    from transformers import AutoModelForCausalLM, AutoTokenizer
    import json
    import os

    import torch

    eval_data = load_eval_data()

    # Read the access token from the environment instead of hard-coding it;
    # None lets the library fall back to locally stored credentials.
    hf_token = os.environ.get("HF_TOKEN")

    tokenizer = tokenizer or AutoTokenizer.from_pretrained(model_name, token=hf_token)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.float16,
        token=hf_token,
    ).to(device)
    print("Number of layers in model:", len(model.model.layers))

    all_mc_scores = {}

    for layer in layers:
        print(f"\n🔍 Evaluating intervention at layer {layer}")

        manager = TruthFlowInterventionManager(
            model=model,
            flow_models_dir=flow_models_dir,
            svd_basis_dir=svd_basis_dir,
            intervention_layers=[layer],
            device=device,
            pair=pair)
        manager.register_hooks()

        mc1_total = 0
        mc2_total = 0.0

        try:
            for item in tqdm(eval_data):
                question = item["question"]
                mc1 = item["mc1_targets"]
                mc2 = item.get("mc2_targets") or mc1

                # Score every distinct choice once, then look scores up per target set.
                distinct_choices = list(dict.fromkeys(list(mc1["choices"]) + list(mc2["choices"])))
                score_of = dict(zip(
                    distinct_choices,
                    _choice_log_probs(model, tokenizer, question, distinct_choices, device),
                ))

                # MC1: highest-scoring choice must be the correct one.
                mc1_scores = np.array([score_of[c] for c in mc1["choices"]])
                mc1_correct = _truthful_indices(mc1["labels"], len(mc1["choices"]))
                if int(np.argmax(mc1_scores)) in mc1_correct:
                    mc1_total += 1

                # MC2: normalized probability mass on the true answers.
                mc2_probs = _normalized_probs([score_of[c] for c in mc2["choices"]])
                mc2_correct = _truthful_indices(mc2["labels"], len(mc2["choices"]))
                mc2_total += float(sum(mc2_probs[i] for i in mc2_correct))
        finally:
            # Never leave the intervention attached to the shared model.
            manager.remove_hooks()

        mc1_score = 100 * mc1_total / len(eval_data)
        mc2_score = 100 * mc2_total / len(eval_data)
        all_mc_scores[layer] = {"MC1": float(mc1_score), "MC2": float(mc2_score)}

        print(f"✓ Layer {layer} | MC1: {mc1_score:.2f} | MC2: {mc2_score:.2f}")

    with open("truthflow_mcq_scores.json", "w") as f:
        json.dump(all_mc_scores, f, indent=2)

    return all_mc_scores


In [None]:
# Run the MC1/MC2 evaluation for Gemma-2-2B with the intervention on one layer.
# Flow-model checkpoints and SVD bases are read from the same input directory.
results = evaluate_layerwise_mcq(
    layers=[20],  # intermediate layers (as per paper)
    model_name="google/gemma-2-2b",
    svd_basis_dir="/kaggle/input/combined-correct",
    flow_models_dir="/kaggle/input/combined-correct",
)