In [None]:
import json
import torch
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from transformer_lens import HookedTransformer
from typing import Dict, List, Tuple, Optional, Any
from tqdm import tqdm
import gc
import seaborn as sns
from dataclasses import dataclass
import warnings
warnings.filterwarnings('ignore')

@dataclass
class BaselineResult:
    """Scoring record for one dataset item under the baseline prompt.

    Instances are built by ``BaselineExperiment.process_item``. They hold the
    raw answer strings, the token IDs those answers were scored with, and the
    log-probabilities used to decide which answer the model prefers.
    """
    subject: str  # item's 'subject' value, or 'Unknown' when the key is absent
    question: str  # item's 'question' text, as stored in the dataset
    factual_answer: str  # item's 'target_true' value
    counterfactual_answer: str  # item's 'target_new' value
    factual_tokens: List[int]  # token IDs of the factual answer
    counterfactual_tokens: List[int]  # token IDs of the counterfactual answer
    factual_logp: float  # summed log-probability of the factual answer tokens
    counterfactual_logp: float  # summed log-probability of the counterfactual answer tokens
    delta: float  # factual_logp - counterfactual_logp
    prediction: str  # "factual" when delta > 0, otherwise "counterfactual"

class BaselineExperiment:
    def __init__(self, dataset_path: str, model_name: str = "gpt2-medium"):
        self.model_name = model_name
        self.dataset_path = dataset_path
        self.model = None
        self.dataset = []
        self.results = []
        
        # Premise verbs to analyze
        self.premise_verbs = ['Redefine', 'Assess', 'Fact Check', 'Review', 'Validate', 'Verify']
        
    def setup_model(self):
        """Initialize GPT2-Medium model"""
        try:
            self.clear_memory()
            device = "cuda" if torch.cuda.is_available() else "cpu"
            
            print(f"Loading {self.model_name} on {device}...")
            
            # Load with appropriate settings for GPT2-Medium
            self.model = HookedTransformer.from_pretrained(
                self.model_name,
                device=device,
                torch_dtype=torch.float16 if device == "cuda" else torch.float32,
                n_devices=1
            )
            
            # Set model to evaluation mode
            self.model.eval()
            
            # Print model info
            print(f"âœ“ {self.model_name} loaded successfully")
            print(f"  Parameters: {sum(p.numel() for p in self.model.parameters()):,}")
            print(f"  Layers: {self.model.cfg.n_layers}")
            print(f"  Hidden size: {self.model.cfg.d_model}")
            print(f"  Heads: {self.model.cfg.n_heads}")
            
            return True
            
        except Exception as e:
            print(f"âœ— Error loading {self.model_name}: {e}")
            return False
    
    def load_dataset(self):
        """Load and filter dataset for baseline experiment"""
        try:
            with open(self.dataset_path, 'r') as f:
                self.dataset = json.load(f)
            
            print(f"âœ“ Dataset loaded: {len(self.dataset)} total prompts")
            
            # Group by premise verb for analysis
            self.verb_groups = {verb: [] for verb in self.premise_verbs}
            
            for item in self.dataset:
                # Extract premise verb from prompt (if available)
                if 'premise_verb' in item:
                    premise_verb = item['premise_verb']
                elif 'prompt' in item and ':' in item['prompt']:
                    premise_verb = item['prompt'].split(':')[0].strip()
                else:
                    premise_verb = 'Unknown'
                    
                if premise_verb in self.verb_groups:
                    self.verb_groups[premise_verb].append(item)
                elif premise_verb != 'Unknown':
                    # Add new verb to list if not already present
                    self.premise_verbs.append(premise_verb)
                    self.verb_groups[premise_verb] = [item]
            
            print("\nDataset Distribution:")
            print("-" * 40)
            total_grouped = 0
            for verb in self.premise_verbs:
                count = len(self.verb_groups.get(verb, []))
                if count > 0:
                    print(f"  {verb:15}: {count:4d} prompts")
                    total_grouped += count
            
            if total_grouped < len(self.dataset):
                print(f"  {'Uncategorized':15}: {len(self.dataset) - total_grouped:4d} prompts")
            
            return True
            
        except Exception as e:
            print(f"âœ— Error loading dataset: {e}")
            return False
    
    def create_baseline_prompt(self, question: str) -> str:
        """Format a question as the baseline prompt ``"<question> Answer:"``.

        Surrounding whitespace is removed, and a '?' is appended when the
        trimmed question does not already end with one.
        """
        trimmed = question.strip()
        terminator = '' if trimmed.endswith('?') else '?'
        return f"{trimmed}{terminator} Answer:"
    
    def tokenize_answer(self, answer: str) -> List[int]:
        """Tokenize answer into token IDs"""
        # Clean the answer and tokenize
        answer = answer.strip()
        tokens = self.model.tokenizer.encode(answer, add_special_tokens=False)
        return tokens
    
    def get_log_probabilities(self, prompt: str, target_tokens: List[int]) -> float:
        """
        Compute log probability of target tokens given prompt
        
        Args:
            prompt: Input prompt
            target_tokens: List of token IDs to compute probability for
        
        Returns:
            Total log probability of the target sequence
        """
        if not target_tokens:
            return -float('inf')
        
        try:
            # Tokenize prompt
            prompt_tokens = self.model.tokenizer.encode(prompt, add_special_tokens=False)
            
            # Combine prompt and target tokens
            all_tokens = prompt_tokens + target_tokens
            
            # Convert to tensor
            tokens_tensor = torch.tensor([all_tokens], device=self.model.cfg.device)
            
            with torch.no_grad():
                # Get logits for all positions
                logits = self.model(tokens_tensor)
                
                # Compute log probabilities using log_softmax
                log_probs = torch.log_softmax(logits, dim=-1)
                
                # Extract log probabilities for target tokens
                total_logp = 0.0
                
                for i, token_id in enumerate(target_tokens, start=len(prompt_tokens)):
                    # i-1 because logits are shifted by 1 (predicting next token)
                    if i-1 >= 0 and i-1 < log_probs.shape[1]:
                        token_logp = log_probs[0, i-1, token_id].item()
                        total_logp += token_logp
                    else:
                        # If position is out of bounds, skip
                        continue
            
            return total_logp
            
        except Exception as e:
            print(f"Error computing log probabilities: {e}")
            return -float('inf')
    
    def process_item(self, item: Dict) -> Optional[BaselineResult]:
        """Process a single dataset item"""
        try:
            # Extract components
            question = item['question']
            factual_answer = item['target_true']
            counterfactual_answer = item['target_new']
            subject = item.get('subject', 'Unknown')
            
            # Create baseline prompt
            prompt = self.create_baseline_prompt(question)
            
            # Tokenize answers
            factual_tokens = self.tokenize_answer(factual_answer)
            counterfactual_tokens = self.tokenize_answer(counterfactual_answer)
            
            # Get log probabilities
            factual_logp = self.get_log_probabilities(prompt, factual_tokens)
            counterfactual_logp = self.get_log_probabilities(prompt, counterfactual_tokens)
            
            # Compute delta
            delta = factual_logp - counterfactual_logp
            
            # Determine prediction
            prediction = "factual" if delta > 0 else "counterfactual"
            
            result = BaselineResult(
                subject=subject,
                question=question,
                factual_answer=factual_answer,
                counterfactual_answer=counterfactual_answer,
                factual_tokens=factual_tokens,
                counterfactual_tokens=counterfactual_tokens,
                factual_logp=factual_logp,
                counterfactual_logp=counterfactual_logp,
                delta=delta,
                prediction=prediction
            )
            
            return result
            
        except Exception as e:
            print(f"Error processing item: {e}")
            print(f"Item: {item}")
            return None
    
    def clear_memory(self):
        """Clear GPU memory"""
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.synchronize()
        gc.collect()
    
    def run_experiment(self, sample_size: Optional[int] = None, 
                      batch_size: int = 32):
        """Run the baseline experiment with batching for efficiency"""
        print(f"\n{'='*70}")
        print(f"EXPERIMENT 1: BASELINE (NO COUNTERFACTUAL) - {self.model_name.upper()}")
        print(f"{'='*70}")
        
        if not self.setup_model():
            return
        
        if not self.load_dataset():
            return
        
        # Process all items or sample
        all_items = []
        for verb in self.premise_verbs:
            if verb in self.verb_groups:
                all_items.extend(self.verb_groups[verb])
        
        # Also add items not in verb groups
        for item in self.dataset:
            if item not in all_items:
                all_items.append(item)
        
        if sample_size:
            all_items = all_items[:sample_size]
        
        print(f"\nProcessing {len(all_items)} prompts...")
        
        # Process items in batches for efficiency
        self.results = []
        for i in tqdm(range(0, len(all_items), batch_size), desc="Processing batches"):
            batch_items = all_items[i:i+batch_size]
            batch_results = []
            
            for item in batch_items:
                result = self.process_item(item)
                if result:
                    batch_results.append(result)
            
            self.results.extend(batch_results)
            
            # Clear memory periodically
            if i % (batch_size * 10) == 0 and i > 0:
                self.clear_memory()
        
        print(f"âœ“ Processed {len(self.results)} prompts successfully")
        
        # Analyze results
        self.analyze_results()
        self.plot_results()
        
        return self.results
    
    def analyze_results(self):
        """Analyze and report results"""
        if not self.results:
            print("No results to analyze")
            return
        
        # Calculate overall metrics
        total = len(self.results)
        factual_count = sum(1 for r in self.results if r.prediction == "factual")
        counterfactual_count = total - factual_count
        
        factual_percent = (factual_count / total) * 100 if total > 0 else 0
        counterfactual_percent = (counterfactual_count / total) * 100 if total > 0 else 0
        
        # Calculate average log probabilities and delta
        avg_factual_logp = np.mean([r.factual_logp for r in self.results]) if self.results else 0
        avg_counterfactual_logp = np.mean([r.counterfactual_logp for r in self.results]) if self.results else 0
        avg_delta = np.mean([r.delta for r in self.results]) if self.results else 0
        
        # Calculate standard deviations
        std_factual_logp = np.std([r.factual_logp for r in self.results]) if self.results else 0
        std_counterfactual_logp = np.std([r.counterfactual_logp for r in self.results]) if self.results else 0
        std_delta = np.std([r.delta for r in self.results]) if self.results else 0
        
        print(f"\n{'='*60}")
        print("EXPERIMENT 1 RESULTS - BASELINE")
        print(f"{'='*60}")
        print(f"\nOverall Metrics:")
        print(f"  Total prompts analyzed: {total}")
        print(f"  Factual predictions: {factual_count} ({factual_percent:.1f}%)")
        print(f"  Counterfactual predictions: {counterfactual_count} ({counterfactual_percent:.1f}%)")
        print(f"\nAverage Log Probabilities (mean Â± std):")
        print(f"  logp(fact): {avg_factual_logp:.4f} Â± {std_factual_logp:.4f}")
        print(f"  logp(cf):   {avg_counterfactual_logp:.4f} Â± {std_counterfactual_logp:.4f}")
        print(f"  Î”:          {avg_delta:.4f} Â± {std_delta:.4f} (logp(fact) - logp(cf))")
        
        # Analyze by premise verb
        print(f"\n{'='*60}")
        print("ANALYSIS BY PREMISE VERB (PV)")
        print(f"{'='*60}")
        print(f"\n{'Premise Verb':<15} {'Count':<8} {'%Factual':<10} {'%CF':<10} {'Avg Î”':<10} {'Std Î”':<10}")
        print("-" * 80)
        
        verb_stats = {}
        for verb in self.premise_verbs:
            verb_results = [r for r in self.results 
                          if any(verb in str(r.subject) or verb in str(r.question) 
                                or verb in str(r.factual_answer) or verb in str(r.counterfactual_answer))]
            
            # Also check if verb appears in any text field
            if not verb_results:
                verb_results = [r for r in self.results 
                              if verb in r.question or verb in r.subject]
            
            if verb_results:
                verb_total = len(verb_results)
                verb_factual = sum(1 for r in verb_results if r.prediction == "factual")
                verb_factual_pct = (verb_factual / verb_total) * 100 if verb_total > 0 else 0
                verb_delta_avg = np.mean([r.delta for r in verb_results]) if verb_results else 0
                verb_delta_std = np.std([r.delta for r in verb_results]) if verb_results else 0
                
                verb_stats[verb] = {
                    'count': verb_total,
                    'factual_pct': verb_factual_pct,
                    'counterfactual_pct': 100 - verb_factual_pct,
                    'avg_delta': verb_delta_avg,
                    'std_delta': verb_delta_std
                }
                
                print(f"{verb:<15} {verb_total:<8} {verb_factual_pct:<10.1f} "
                      f"{100-verb_factual_pct:<10.1f} {verb_delta_avg:<10.4f} {verb_delta_std:<10.4f}")
        
        return verb_stats
    
    def plot_results(self):
        """Render a 2x3 diagnostic figure for ``self.results``, save it as PNG and show it.

        Panels: (1) histogram of delta with a normal-fit overlay, (2) scatter of
        logp(fact) against logp(cf), (3) prediction counts, (4) delta box plots
        per premise verb, (5) empirical CDF of delta, (6) 2D histogram of the
        two log-probabilities. ``create_summary_figure`` is called afterwards
        with the same data. Returns None and does nothing when there are no
        results.
        """
        if not self.results:
            return
        
        # Set style (these calls change global matplotlib/seaborn state)
        plt.style.use('seaborn-v0_8-darkgrid')
        sns.set_palette("husl")
        
        fig, axes = plt.subplots(2, 3, figsize=(18, 12))
        
        # Plot 1: Distribution of delta values
        deltas = [r.delta for r in self.results]
        axes[0, 0].hist(deltas, bins=50, alpha=0.7, color='skyblue', edgecolor='black', density=True)
        axes[0, 0].axvline(x=0, color='red', linestyle='--', alpha=0.7, linewidth=2, label='Î”=0')
        axes[0, 0].axvline(x=np.mean(deltas), color='green', linestyle='-', alpha=0.7, linewidth=2, 
                          label=f'Mean Î”={np.mean(deltas):.2f}')
        
        # Add normal distribution overlay (same mean/std as the data)
        # NOTE(review): scipy is imported here, not with the other imports at the top of the cell.
        from scipy.stats import norm
        mu, sigma = np.mean(deltas), np.std(deltas)
        x = np.linspace(min(deltas), max(deltas), 100)
        axes[0, 0].plot(x, norm.pdf(x, mu, sigma), 'r-', alpha=0.5, label='Normal fit')
        
        axes[0, 0].set_xlabel('Î” = logp(fact) - logp(cf)', fontsize=12)
        axes[0, 0].set_ylabel('Density', fontsize=12)
        # NOTE(review): the title hard-codes 'GPT2-Medium' and does not use self.model_name.
        axes[0, 0].set_title(f'Distribution of Î” Values (GPT2-Medium)', fontsize=14, fontweight='bold')
        axes[0, 0].legend(fontsize=10)
        axes[0, 0].grid(True, alpha=0.3)
        
        # Plot 2: Scatter plot of logp(fact) vs logp(cf), coloured by delta
        factual_logps = [r.factual_logp for r in self.results]
        counterfactual_logps = [r.counterfactual_logp for r in self.results]
        
        scatter = axes[0, 1].scatter(factual_logps, counterfactual_logps, 
                                     c=deltas, cmap='RdYlBu', alpha=0.7, 
                                     edgecolors='black', linewidth=0.3, s=50)
        # Reference diagonal; its extent is taken from the factual values only
        axes[0, 1].plot([min(factual_logps), max(factual_logps)], 
                       [min(factual_logps), max(factual_logps)], 
                       'r--', alpha=0.5, linewidth=2, label='y=x (equal prob)')
        axes[0, 1].set_xlabel('logp(fact)', fontsize=12)
        axes[0, 1].set_ylabel('logp(cf)', fontsize=12)
        axes[0, 1].set_title('Factual vs Counterfactual Log Probabilities', fontsize=14, fontweight='bold')
        axes[0, 1].legend(fontsize=10)
        axes[0, 1].grid(True, alpha=0.3)
        
        # Add colorbar
        cbar = plt.colorbar(scatter, ax=axes[0, 1])
        cbar.set_label('Î” value', fontsize=12)
        
        # Plot 3: Prediction distribution (green for factual, red otherwise)
        predictions = [r.prediction for r in self.results]
        prediction_counts = pd.Series(predictions).value_counts()
        colors = ['#4CAF50' if p == 'factual' else '#F44336' for p in prediction_counts.index]
        
        bars = axes[0, 2].bar(prediction_counts.index, prediction_counts.values, 
                             color=colors, alpha=0.8, edgecolor='black', linewidth=1.5)
        axes[0, 2].set_xlabel('Prediction', fontsize=12)
        axes[0, 2].set_ylabel('Count', fontsize=12)
        axes[0, 2].set_title('Prediction Distribution', fontsize=14, fontweight='bold')
        
        # Add percentage labels and value labels above each bar
        total = len(self.results)
        for i, (bar, (pred, count)) in enumerate(zip(bars, prediction_counts.items())):
            percentage = (count / total) * 100
            height = bar.get_height()
            axes[0, 2].text(bar.get_x() + bar.get_width()/2., height + total*0.005,
                           f'{count}\n({percentage:.1f}%)', 
                           ha='center', va='bottom', fontweight='bold', fontsize=10)
        
        # Head-room so the labels fit above the tallest bar
        axes[0, 2].set_ylim(0, max(prediction_counts.values) * 1.15)
        
        # Plot 4: delta distribution by premise verb (if available)
        # A result is attributed to a verb when the verb text occurs in its
        # question or subject; results carry no premise-verb field of their own.
        premise_deltas = {}
        for verb in self.premise_verbs:
            verb_deltas = [r.delta for r in self.results 
                         if verb in r.question or verb in r.subject]
            if verb_deltas:
                premise_deltas[verb] = verb_deltas
        
        if premise_deltas:
            # Create boxplot, one box per verb that matched any result
            positions = range(1, len(premise_deltas) + 1)
            box_data = [premise_deltas[verb] for verb in premise_deltas.keys()]
            
            bp = axes[1, 0].boxplot(box_data, positions=positions, 
                                    labels=premise_deltas.keys(), patch_artist=True,
                                    medianprops=dict(color='black', linewidth=2),
                                    whiskerprops=dict(color='gray', linewidth=1.5),
                                    capprops=dict(color='gray', linewidth=1.5))
            
            # Color boxes based on median delta (green when positive, pink otherwise)
            for i, (patch, verb) in enumerate(zip(bp['boxes'], premise_deltas.keys())):
                median_val = np.median(premise_deltas[verb])
                patch.set_facecolor('#90EE90' if median_val > 0 else '#FFB6C1')
                patch.set_alpha(0.7)
            
            axes[1, 0].axhline(y=0, color='red', linestyle='--', alpha=0.7, linewidth=2)
            axes[1, 0].set_xlabel('Premise Verb (PV)', fontsize=12)
            axes[1, 0].set_ylabel('Î”', fontsize=12)
            axes[1, 0].set_title('Î” Distribution by Premise Verb', fontsize=14, fontweight='bold')
            axes[1, 0].tick_params(axis='x', rotation=45)
            axes[1, 0].grid(True, alpha=0.3, axis='y')
            
            # Add sample size annotations near the bottom of the axes
            for i, verb in enumerate(premise_deltas.keys()):
                count = len(premise_deltas[verb])
                axes[1, 0].text(i+1, axes[1, 0].get_ylim()[0] * 0.95, 
                               f'n={count}', ha='center', va='top', fontsize=9)
        else:
            # No verb matched any result: show a placeholder instead of an empty plot
            axes[1, 0].text(0.5, 0.5, 'No premise verb data\navailable for plotting', 
                           ha='center', va='center', transform=axes[1, 0].transAxes, fontsize=12)
            axes[1, 0].set_title('Î” Distribution by Premise Verb', fontsize=14, fontweight='bold')
        
        # Plot 5: Cumulative distribution of delta
        sorted_deltas = np.sort(deltas)
        cumulative = np.arange(1, len(sorted_deltas) + 1) / len(sorted_deltas)
        
        axes[1, 1].plot(sorted_deltas, cumulative, 'b-', linewidth=2.5, alpha=0.7, label='CDF')
        axes[1, 1].axvline(x=0, color='red', linestyle='--', alpha=0.7, linewidth=2, label='Î”=0')
        axes[1, 1].axhline(y=0.5, color='gray', linestyle=':', alpha=0.5, linewidth=1)
        
        # Find and mark median: index of the first sorted value >= the median
        median_delta = np.median(deltas)
        median_idx = np.searchsorted(sorted_deltas, median_delta)
        axes[1, 1].plot(median_delta, cumulative[median_idx], 'ro', markersize=10, 
                       label=f'Median Î”={median_delta:.2f}')
        
        axes[1, 1].set_xlabel('Î” = logp(fact) - logp(cf)', fontsize=12)
        axes[1, 1].set_ylabel('Cumulative Probability', fontsize=12)
        axes[1, 1].set_title('Cumulative Distribution of Î”', fontsize=14, fontweight='bold')
        axes[1, 1].legend(fontsize=10, loc='lower right')
        axes[1, 1].grid(True, alpha=0.3)
        
        # Plot 6: Heatmap of logp correlation
        if len(factual_logps) > 0 and len(counterfactual_logps) > 0:
            # Create 2D histogram
            heatmap, xedges, yedges = np.histogram2d(factual_logps, counterfactual_logps, bins=30)
            
            # Plot heatmap (transposed so x is logp(fact) and y is logp(cf))
            im = axes[1, 2].imshow(heatmap.T, extent=[xedges[0], xedges[-1], yedges[0], yedges[-1]], 
                                  origin='lower', aspect='auto', cmap='YlOrRd', alpha=0.8)
            
            # Add diagonal line spanning the range of both series
            min_val = min(min(factual_logps), min(counterfactual_logps))
            max_val = max(max(factual_logps), max(counterfactual_logps))
            axes[1, 2].plot([min_val, max_val], [min_val, max_val], 'r--', alpha=0.7, linewidth=2)
            
            axes[1, 2].set_xlabel('logp(fact)', fontsize=12)
            axes[1, 2].set_ylabel('logp(cf)', fontsize=12)
            axes[1, 2].set_title('Density Heatmap of Log Probabilities', fontsize=14, fontweight='bold')
            
            # Add colorbar
            # NOTE(review): imshow is given raw bin counts, although the label says 'Density'.
            cbar = plt.colorbar(im, ax=axes[1, 2])
            cbar.set_label('Density', fontsize=12)
        else:
            axes[1, 2].text(0.5, 0.5, 'Insufficient data\nfor heatmap', 
                           ha='center', va='center', transform=axes[1, 2].transAxes, fontsize=12)
            axes[1, 2].set_title('Density Heatmap of Log Probabilities', fontsize=14, fontweight='bold')
        
        # NOTE(review): the figure title and the file name below also hard-code the model name.
        plt.suptitle(f'GPT2-Medium Baseline Experiment Results\nExperiment 1: Baseline (No Counterfactual)', 
                    fontsize=16, fontweight='bold', y=1.02)
        plt.tight_layout()
        
        # Save figure into the current working directory under a timestamped name
        timestamp = pd.Timestamp.now().strftime("%Y%m%d_%H%M%S")
        filename = f'gpt2_medium_baseline_experiment_{timestamp}.png'
        plt.savefig(filename, dpi=300, bbox_inches='tight', facecolor='white')
        print(f"âœ“ Plot saved as {filename}")
        plt.show()
        
        # Also create a summary figure
        self.create_summary_figure(deltas, factual_logps, counterfactual_logps, predictions)
    
    def create_summary_figure(self, deltas, factual_logps, counterfactual_logps, predictions):
        """Draw, save and show a two-panel summary: prediction counts and delta statistics.

        Args:
            deltas: Sequence of delta values, one per result.
            factual_logps: Accepted for symmetry with the caller; not used here.
            counterfactual_logps: Accepted for symmetry with the caller; not used here.
            predictions: List of prediction labels; entries equal to 'factual'
                are counted as factual, everything else as counterfactual.
        """
        _, (ax_counts, ax_stats) = plt.subplots(1, 2, figsize=(14, 6))
        
        # Left panel: how many prompts fell on each side
        total = len(predictions)
        factual_count = predictions.count('factual')
        cf_count = total - factual_count
        
        bars = ax_counts.bar(
            ['Factual', 'Counterfactual'],
            [factual_count, cf_count],
            color=['#4CAF50', '#F44336'],
            alpha=0.8,
            edgecolor='black',
            linewidth=1.5,
        )
        ax_counts.set_ylabel('Count', fontsize=12)
        ax_counts.set_title('Prediction Summary', fontsize=14, fontweight='bold')
        
        # Label each bar with its height and share of the total
        for bar in bars:
            height = bar.get_height()
            percentage = (height / total) * 100
            label_x = bar.get_x() + bar.get_width()/2.
            label_y = height + total*0.01
            ax_counts.text(label_x, label_y,
                           f'{height}\n({percentage:.1f}%)',
                           ha='center', va='bottom', fontweight='bold', fontsize=11)
        
        ax_counts.set_ylim(0, max(factual_count, cf_count) * 1.2)
        
        # Right panel: text box with delta statistics
        stats_text = f"""
        Î” Statistics (logp(fact) - logp(cf)):
        
        Mean Î”: {np.mean(deltas):.4f}
        Median Î”: {np.median(deltas):.4f}
        Std Î”: {np.std(deltas):.4f}
        Min Î”: {np.min(deltas):.4f}
        Max Î”: {np.max(deltas):.4f}
        
        Factual > CF: {sum(1 for d in deltas if d > 0)} prompts
        CF > Factual: {sum(1 for d in deltas if d < 0)} prompts
        Equal: {sum(1 for d in deltas if d == 0)} prompts
        """
        
        box_style = dict(boxstyle='round', facecolor='lightblue', alpha=0.8)
        ax_stats.text(0.1, 0.5, stats_text, fontsize=11, family='monospace',
                      verticalalignment='center', transform=ax_stats.transAxes,
                      bbox=box_style)
        ax_stats.axis('off')
        ax_stats.set_title('Statistical Summary', fontsize=14, fontweight='bold')
        
        plt.suptitle('GPT2-Medium Baseline Experiment - Quick Summary',
                     fontsize=16, fontweight='bold')
        plt.tight_layout()
        
        # Save under a timestamped name in the current working directory
        stamp = pd.Timestamp.now().strftime("%Y%m%d_%H%M%S")
        filename = f'gpt2_medium_baseline_summary_{stamp}.png'
        plt.savefig(filename, dpi=300, bbox_inches='tight', facecolor='white')
        print(f"âœ“ Summary plot saved as {filename}")
        plt.show()
    
    def save_results(self, output_path: str = None):
        """Save detailed results to JSON file"""
        if not self.results:
            print("No results to save")
            return
        
        # Create default output path if not provided
        if output_path is None:
            timestamp = pd.Timestamp.now().strftime("%Y%m%d_%H%M%S")
            output_path = f'gpt2_medium_baseline_results_{timestamp}.json'
        
        # Convert results to serializable format
        serializable_results = []
        for result in self.results:
            serializable_results.append({
                'subject': result.subject,
                'question': result.question,
                'factual_answer': result.factual_answer,
                'counterfactual_answer': result.counterfactual_answer,
                'factual_tokens': result.factual_tokens,
                'counterfactual_tokens': result.counterfactual_tokens,
                'factual_logp': float(result.factual_logp),
                'counterfactual_logp': float(result.counterfactual_logp),
                'delta': float(result.delta),
                'prediction': result.prediction
            })
        
        # Calculate summary statistics
        factual_count = sum(1 for r in self.results if r.prediction == "factual")
        total = len(self.results)
        
        # Save to file
        output_data = {
            'experiment': 'Experiment 1: Baseline (No Counterfactual)',
            'model': self.model_name,
            'timestamp': pd.Timestamp.now().isoformat(),
            'total_prompts': total,
            'summary': {
                'factual_count': factual_count,
                'counterfactual_count': total - factual_count,
                'factual_percent': (factual_count / total) * 100 if total > 0 else 0,
                'counterfactual_percent': ((total - factual_count) / total) * 100 if total > 0 else 0,
                'avg_factual_logp': float(np.mean([r.factual_logp for r in self.results]) if self.results else 0),
                'avg_counterfactual_logp': float(np.mean([r.counterfactual_logp for r in self.results]) if self.results else 0),
                'avg_delta': float(np.mean([r.delta for r in self.results]) if self.results else 0),
                'std_factual_logp': float(np.std([r.factual_logp for r in self.results]) if self.results else 0),
                'std_counterfactual_logp': float(np.std([r.counterfactual_logp for r in self.results]) if self.results else 0),
                'std_delta': float(np.std([r.delta for r in self.results]) if self.results else 0),
            },
            'results': serializable_results
        }
        
        with open(output_path, 'w') as f:
            json.dump(output_data, f, indent=2)
        
        print(f"âœ“ Results saved to {output_path}")
        
        # Also save a CSV version for easier analysis
        csv_path = output_path.replace('.json', '.csv')
        df_data = []
        for r in self.results:
            df_data.append({
                'subject': r.subject,
                'question': r.question,
                'factual_answer': r.factual_answer,
                'counterfactual_answer': r.counterfactual_answer,
                'factual_logp': r.factual_logp,
                'counterfactual_logp': r.counterfactual_logp,
                'delta': r.delta,
                'prediction': r.prediction
            })
        
        df = pd.DataFrame(df_data)
        df.to_csv(csv_path, index=False)
        print(f"âœ“ CSV results saved to {csv_path}")

# Example usage for GPT2-Medium
if __name__ == "__main__":
    # Configuration for GPT2-Medium
    DATASET_PATH = "./Data/gpt2_with_questions_merged.json"  # Update this path
    MODEL_NAME = "gpt2-medium"
    SAMPLE_SIZE = 100  # Set to None for full dataset, or integer to keep only the first N prompts
    BATCH_SIZE = 16  # Items per progress-bar step (items are scored one at a time)
    
    print(f"Running Experiment 1: Baseline with {MODEL_NAME}")
    print(f"Dataset: {DATASET_PATH}")
    print(f"Sample size: {SAMPLE_SIZE if SAMPLE_SIZE else 'Full dataset'}")
    print(f"Batch size: {BATCH_SIZE}")
    
    # Run experiment
    experiment = BaselineExperiment(DATASET_PATH, MODEL_NAME)
    results = experiment.run_experiment(sample_size=SAMPLE_SIZE, batch_size=BATCH_SIZE)
    
    # Save detailed results (prints a notice and writes nothing when there are none)
    experiment.save_results()
    
    if results:
        # Print example results
        print("\nðŸ“‹ Example Results (first 5 prompts):")
        print("=" * 80)
        for i, result in enumerate(experiment.results[:5]):
            print(f"\nExample {i+1}:")
            print(f"  Subject: {result.subject}")
            print(f"  Question: {result.question}")
            print(f"  Factual answer: '{result.factual_answer}'")
            print(f"  Counterfactual answer: '{result.counterfactual_answer}'")
            print(f"  Prediction: {result.prediction}")
            print(f"  logp(fact): {result.factual_logp:.4f}")
            print(f"  logp(cf): {result.counterfactual_logp:.4f}")
            print(f"  Î”: {result.delta:.4f}")
            print("-" * 40)
        
        print("\nâœ… Experiment completed successfully!")
    else:
        # run_experiment returns None when the model or dataset failed to load
        # and an empty list when no prompt could be scored; neither is a success.
        print("\nExperiment did not produce any results; see the messages above.")

  from .autonotebook import tqdm as notebook_tqdm


Running Experiment 1: Baseline with gpt2-medium
Dataset: ./Data/gpt2_with_questions_merged.json
Sample size: 100
Batch size: 16

EXPERIMENT 1: BASELINE (NO COUNTERFACTUAL) - GPT2-MEDIUM
Loading gpt2-medium on cuda...




Loaded pretrained model gpt2-medium into HookedTransformer
âœ“ gpt2-medium loaded successfully
  Parameters: 406,236,241
  Layers: 24
  Hidden size: 1024
  Heads: 16
âœ“ Dataset loaded: 28953 total prompts

Dataset Distribution:
----------------------------------------
  Redefine       : 4329 prompts
  Assess         : 4924 prompts
  Fact Check     : 4916 prompts
  Review         : 4942 prompts
  Validate       : 4915 prompts
  Verify         : 4927 prompts

Processing 100 prompts...


Processing batches: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 7/7 [00:06<00:00,  1.09it/s]


âœ“ Processed 100 prompts successfully

EXPERIMENT 1 RESULTS - BASELINE

Overall Metrics:
  Total prompts analyzed: 100
  Factual predictions: 100 (100.0%)
  Counterfactual predictions: 0 (0.0%)

Average Log Probabilities (mean Â± std):
  logp(fact): -8.1685 Â± 1.7179
  logp(cf):   -12.8848 Â± 2.6871
  Î”:          4.7162 Â± 2.2637 (logp(fact) - logp(cf))

ANALYSIS BY PREMISE VERB (PV)

Premise Verb    Count    %Factual   %CF        Avg Î”      Std Î”     
--------------------------------------------------------------------------------


TypeError: 'bool' object is not iterable

In [2]:
pip install torch transformer_lens pandas matplotlib numpy tqdm seaborn scipy

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


Collecting typeguard<5.0,>=4.2 (from transformer_lens)
  Using cached typeguard-4.4.4-py3-none-any.whl.metadata (3.3 kB)
Using cached typeguard-4.4.4-py3-none-any.whl (34 kB)
Installing collected packages: typeguard
  Attempting uninstall: typeguard
    Found existing installation: typeguard 2.13.3
    Uninstalling typeguard-2.13.3:
      Successfully uninstalled typeguard-2.13.3
[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
pysvelte 1.0.0 requires typeguard~=2.0, but you have typeguard 4.4.4 which is incompatible.
inseq 0.6.0 requires typeguard<=2.13.3, but you have typeguard 4.4.4 which is incompatible.[0m[31m
[0mSuccessfully installed typeguard-4.4.4
Note: you may need to restart the kernel to use updated packages.
