In [None]:
import tensorflow as tf
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from scipy import stats as scipy_stats
import re
import ast
from typing import Dict, List, Any
from pathlib import Path


In [None]:
import logging
# Silence TensorFlow log output below ERROR level so that dataset loading
# does not flood the notebook. Both the stdlib logger named 'tensorflow' and
# the logger object returned by tf.get_logger() are set to ERROR.
logging.getLogger('tensorflow').setLevel(logging.ERROR)
tf.get_logger().setLevel('ERROR')


In [None]:
class LLMResultsAnalyzer:
    def __init__(self):
        """Initialize the analyzer with dataset paths"""
        self.datasets = {
            'congruent': 'congruent_incongruent_1000/test/comparison_congruent_size_test.tfrecord',
            'incongruent': 'congruent_incongruent_1000/test/comparison_incongruent_size_test.tfrecord',
            'random': 'congruent_incongruent_1000/test/comparison_random_string_size_test.tfrecord',
            'permuted': 'congruent_incongruent_1000/test/comparison_permuted_size_test.tfrecord'
        }
        
        self.prompt_types = ['baseline', 'numberline', 'road', 'circle', 'clusters', 
                            '3d_space', 'cloud', 'hyperbolic', 'cot']
        
        self.ground_truth = self.load_ground_truth()
    
    def parse_tfrecord_example(self, serialized_item: bytes) -> Dict[str, Any]:
        """Decode one serialized TFRecord example using the full six-field schema.

        Args:
            serialized_item: a single serialized example.

        Returns:
            Whatever ``tf.io.parse_single_example`` produces for the schema
            below (a mapping keyed by the feature names).
        """
        def scalar(dtype):
            # Shape [] declares a single value per example.
            return tf.io.FixedLenFeature([], dtype)

        schema = {
            'question': scalar(tf.string),
            'answer': scalar(tf.string),
            'metadata': scalar(tf.string),
            'question_only': scalar(tf.string),
            # Variable-length list of strings.
            'alternative_answers': tf.io.VarLenFeature(tf.string),
            'index': scalar(tf.int64),
        }
        parsed = tf.io.parse_single_example(serialized_item, schema)
        return parsed
    
    def load_ground_truth(self) -> Dict[str, List[str]]:
        """Read the reference answers for every condition in ``self.datasets``.

        For each condition the TFRecord file is parsed and the ``answer`` field
        of up to the first 1000 records is decoded as UTF-8 and lower-cased.

        Returns:
            Mapping of condition name to its list of lower-cased answers, in
            file order.
        """
        def string_scalar():
            return tf.io.FixedLenFeature([], tf.string)

        answers_by_condition = {}

        for condition, filepath in self.datasets.items():
            print(f"Loading ground truth for {condition}...")
            raw_records = tf.data.TFRecordDataset(filepath)

            # Fields parsed for every dataset.
            feature_description = {
                'question': string_scalar(),
                'answer': string_scalar(),
                'metadata': string_scalar(),
            }
            # The permuted dataset is parsed with only the three fields above;
            # all other datasets additionally parse these.
            if condition != 'permuted':
                feature_description.update({
                    'question_only': string_scalar(),
                    'alternative_answers': tf.io.VarLenFeature(tf.string),
                    'index': tf.io.FixedLenFeature([], tf.int64),
                })

            parsed_dataset = raw_records.map(
                lambda x: tf.io.parse_single_example(x, feature_description)
            )

            # Take 1000 examples at most.
            answers = [
                record['answer'].numpy().decode('utf-8').lower()
                for record in parsed_dataset.take(1000)
            ]

            answers_by_condition[condition] = answers
            print(f"  Loaded {len(answers)} answers")

        return answers_by_condition
    
    def extract_answer_from_response(self, response: str) -> str:
        """Extract yes/no answer from model response"""
        response_lower = response.lower().strip()
        
        # Look for explicit answer patterns
        patterns = [
            r'answer:\s*(yes|no)',
            r'answer\s*:\s*(yes|no)',
            r'\*\*answer:\s*(yes|no)\*\*',
            r'\*\*answer\s*:\s*(yes|no)\*\*',
            r'therefore.*answer is\s*(yes|no)',
            r'the answer is\s*\*?\*?(yes|no)',
            r'answer yes or no\.\s*(yes|no)',
            r'^(yes|no)$',
        ]
        
        for pattern in patterns:
            match = re.search(pattern, response_lower, re.MULTILINE | re.IGNORECASE)
            if match:
                return match.group(1).lower()
        
        # Fallback: look for last occurrence of yes or no
        if 'yes' in response_lower or 'no' in response_lower:
            yes_pos = response_lower.rfind('yes')
            no_pos = response_lower.rfind('no')
            if yes_pos > no_pos:
                return 'yes'
            elif no_pos > yes_pos:
                return 'no'
        
        return 'unclear'  # Could not extract answer
    
    def parse_model_results(self, results_file: str, model_name: str) -> pd.DataFrame:
        """Parse results file from a model"""
        print(f"\nParsing results for {model_name}...")
        
        # Read the results file
        with open(results_file, 'r') as f:
            content = f.read()
        
        # Split by delimiter
        responses = content.split('---DELIM---')
        
        results = []
        
        # We expect 36,000 responses (4 conditions × 9 prompts × 1000 examples)
        expected_per_condition = 9 * 1000
        
        for condition_idx, condition in enumerate(['congruent', 'incongruent', 'random', 'permuted']):
            for prompt_idx, prompt_type in enumerate(self.prompt_types):
                for example_idx in range(1000):
                    # Calculate overall index
                    overall_idx = (condition_idx * expected_per_condition + 
                                 prompt_idx * 1000 + example_idx)
                    
                    if overall_idx >= len(responses):
                        print(f"Warning: Missing response at index {overall_idx}")
                        continue
                    
                    response = responses[overall_idx].strip()
                    
                    # Extract predicted answer
                    predicted = self.extract_answer_from_response(response)
                    
                    # Get ground truth
                    ground_truth = self.ground_truth[condition][example_idx]
                    
                    # Check if correct
                    is_correct = 1 if predicted == ground_truth else 0
                    
                    results.append({
                        'model': model_name,
                        'condition': condition,
                        'prompt_type': prompt_type,
                        'example_idx': example_idx,
                        'predicted': predicted,
                        'ground_truth': ground_truth,
                        'is_correct': is_correct,
                        'response_length': len(response)
                    })
        
        df = pd.DataFrame(results)
        print(f"  Parsed {len(df)} responses")
        return df
    
    def calculate_confidence_interval(self, accuracies: List[float], confidence: float = 0.95) -> tuple:
        """Calculate confidence interval for accuracy"""
        n = len(accuracies)
        if n == 0:
            return 0, 0
        
        mean_acc = np.mean(accuracies)
        std_err = np.std(accuracies, ddof=1) / np.sqrt(n)
        
        # Use t-distribution for small samples
        t_val = scipy_stats.t.ppf((1 + confidence) / 2, n - 1)
        margin_error = t_val * std_err
        
        return mean_acc - margin_error, mean_acc + margin_error
    
    def analyze_single_model(self, df: pd.DataFrame, model_name: str):
        """Analyze results for a single model"""
        print(f"\n{'='*60}")
        print(f"ANALYSIS FOR {model_name.upper()}")
        print(f"{'='*60}")
        
        # Calculate accuracy by condition and prompt type
        accuracy_summary = []
        
        for condition in df['condition'].unique():
            for prompt_type in df['prompt_type'].unique():
                subset = df[(df['condition'] == condition) & 
                           (df['prompt_type'] == prompt_type)]
                
                if len(subset) > 0:
                    acc = subset['is_correct'].mean()
                    ci_low, ci_high = self.calculate_confidence_interval(subset['is_correct'].values)
                    
                    accuracy_summary.append({
                        'condition': condition,
                        'prompt_type': prompt_type,
                        'accuracy': acc,
                        'ci_low': ci_low,
                        'ci_high': ci_high,
                        'n': len(subset)
                    })
        
        accuracy_df = pd.DataFrame(accuracy_summary)
        
        # Print summary
        print("\nAccuracy by Condition and Prompt Type:")
        print("-" * 60)
        for condition in ['congruent', 'incongruent', 'random', 'permuted']:
            print(f"\n{condition.upper()}:")
            condition_data = accuracy_df[accuracy_df['condition'] == condition]
            for _, row in condition_data.iterrows():
                print(f"  {row['prompt_type']:12}: {row['accuracy']:.3f} "
                      f"(CI: [{row['ci_low']:.3f}, {row['ci_high']:.3f}])")
        
        return accuracy_df
    
    def compare_geometries(self, df: pd.DataFrame, model_name: str):
        """Compare linear vs non-linear geometries"""
        linear_geometries = ['numberline', 'road']
        nonlinear_geometries = ['circle', 'clusters', '3d_space', 'cloud', 'hyperbolic']
        
        results = []
        
        for condition in df['condition'].unique():
            # Baseline
            baseline = df[(df['condition'] == condition) & 
                         (df['prompt_type'] == 'baseline')]
            baseline_acc = baseline['is_correct'].mean() if len(baseline) > 0 else 0
            
            # Linear geometries
            linear = df[(df['condition'] == condition) & 
                       (df['prompt_type'].isin(linear_geometries))]
            linear_acc = linear['is_correct'].mean() if len(linear) > 0 else 0
            
            # Non-linear geometries
            nonlinear = df[(df['condition'] == condition) & 
                          (df['prompt_type'].isin(nonlinear_geometries))]
            nonlinear_acc = nonlinear['is_correct'].mean() if len(nonlinear) > 0 else 0
            
            # CoT
            cot = df[(df['condition'] == condition) & 
                    (df['prompt_type'] == 'cot')]
            cot_acc = cot['is_correct'].mean() if len(cot) > 0 else 0
            
            results.append({
                'condition': condition,
                'baseline': baseline_acc,
                'linear': linear_acc,
                'nonlinear': nonlinear_acc,
                'cot': cot_acc,
                'linear_improvement': linear_acc - baseline_acc,
                'nonlinear_improvement': nonlinear_acc - baseline_acc,
                'cot_improvement': cot_acc - baseline_acc
            })
        
        comparison_df = pd.DataFrame(results)
        
        print(f"\n{model_name} - Geometry Comparison:")
        print("-" * 60)
        for _, row in comparison_df.iterrows():
            print(f"{row['condition']:12}: Baseline {row['baseline']:.3f} | "
                  f"Linear {row['linear']:.3f} (Δ={row['linear_improvement']:+.3f}) | "
                  f"Non-linear {row['nonlinear']:.3f} (Δ={row['nonlinear_improvement']:+.3f}) | "
                  f"CoT {row['cot']:.3f} (Δ={row['cot_improvement']:+.3f})")
        
        return comparison_df
    
    def create_model_comparison_plot(self, all_results: Dict[str, pd.DataFrame]):
        """Draw one grouped bar chart per model and save them as a single figure.

        Each panel shows, per condition, the mean of the ``accuracy`` column
        for the baseline prompt, the linear prompts, the non-linear prompts and
        the CoT prompt. The figure is written to ``all_models_comparison.png``
        and shown.

        Args:
            all_results: maps model name to its accuracy summary (the frame
                returned by ``analyze_single_model``, with ``condition``,
                ``prompt_type`` and ``accuracy`` columns).

        Panels are laid out three per row. The grid is 3x3 for up to nine
        models and grows by rows beyond that, so no model is dropped; unused
        panels are hidden. Nothing is drawn or saved when ``all_results`` is
        empty.
        """
        if not all_results:
            print("No model results available - skipping comparison plot")
            return

        conditions = ['congruent', 'incongruent', 'random', 'permuted']
        # Legend label -> prompt types averaged into that bar.
        prompt_groups = {
            'Baseline': ['baseline'],
            'Linear': ['numberline', 'road'],
            'Non-linear': ['circle', 'clusters', '3d_space', 'cloud', 'hyperbolic'],
            'CoT': ['cot'],
        }

        n_models = len(all_results)
        n_cols = 3
        # Keep the 3x3 layout (18x12 in) for <= 9 models; add rows beyond that.
        n_rows = max(3, (n_models + n_cols - 1) // n_cols)
        fig, axes = plt.subplots(n_rows, n_cols, figsize=(18, 4 * n_rows))
        axes = axes.flatten()

        for ax, (model_name, df) in zip(axes, all_results.items()):
            data_for_plot = []
            for condition in conditions:
                in_condition = df['condition'] == condition
                bar_values = {'condition': condition}
                for label, members in prompt_groups.items():
                    selected = df[in_condition & df['prompt_type'].isin(members)]
                    bar_values[label] = selected['accuracy'].mean()
                data_for_plot.append(bar_values)

            plot_df = pd.DataFrame(data_for_plot)
            plot_df.set_index('condition')[list(prompt_groups)].plot(
                kind='bar', ax=ax, width=0.8
            )

            ax.set_title(model_name.replace('_', ' ').title())
            ax.set_xlabel('Condition')
            ax.set_ylabel('Accuracy')
            ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha='right')
            ax.legend(loc='upper right', fontsize=8)
            ax.set_ylim(0, 1.0)
            ax.grid(axis='y', alpha=0.3)

        # Hide grid cells that did not receive a model.
        for unused_ax in axes[n_models:]:
            unused_ax.set_visible(False)

        fig.suptitle('Model Comparison: Geometry Effects on Transitive Inference', fontsize=14, y=1.02)
        fig.tight_layout()
        fig.savefig('all_models_comparison.png', dpi=300, bbox_inches='tight')
        plt.show()
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)
    
    def run_analysis(self, model_files: Dict[str, str]):
        """Run complete analysis for all models"""
        all_results = {}
        all_comparisons = {}
        
        for model_name, results_file in model_files.items():
            if not Path(results_file).exists():
                print(f"Skipping {model_name} - file not found: {results_file}")
                continue
            
            # Parse results
            df = self.parse_model_results(results_file, model_name)
            
            # Analyze
            accuracy_df = self.analyze_single_model(df, model_name)
            comparison_df = self.compare_geometries(df, model_name)
            
            # Store results
            all_results[model_name] = accuracy_df
            all_comparisons[model_name] = comparison_df
            
            # Save to CSV
            df.to_csv(f'{model_name}_detailed_results.csv', index=False)
            accuracy_df.to_csv(f'{model_name}_accuracy_summary.csv', index=False)
            comparison_df.to_csv(f'{model_name}_geometry_comparison.csv', index=False)
        
        # Create comparison plot
        self.create_model_comparison_plot(all_results)
        
        return all_results, all_comparisons


In [None]:
def main():
    """Run the full analysis for every configured model and print a summary.

    Returns:
        The ``(all_results, all_comparisons)`` pair produced by
        ``LLMResultsAnalyzer.run_analysis``.
    """
    analyzer = LLMResultsAnalyzer()

    # Model name -> results file.
    model_files = dict([
        ('gemini_2.5_flash', 'llm_output/gemini-2.5-flash-all_prompts.txt'),
        ('gemini_2.5_flash_lite', 'llm_output/gemini-2.5-flash-lite-all_prompts.txt'),
        ('gemini_2.5_pro', 'llm_output/gemini-2.5-pro-all_prompts.txt'),
        ('gemma3_1b', 'llm_output/gemma3_1b-all_prompts.txt'),
        ('gemma3_4b', 'llm_output/gemma3_4b-all_prompts.txt'),
        ('gemma3_12b', 'llm_output/gemma3_12b-all_prompts.txt'),
        ('gemma3_27b', 'llm_output/gemma3_27b-all_prompts.txt'),
    ])

    all_results, all_comparisons = analyzer.run_analysis(model_files)

    # Closing banner followed by the list of generated artefacts.
    rule = "=" * 60
    closing_lines = (
        "\n" + rule,
        "ANALYSIS COMPLETE",
        rule,
        "\nFiles generated:",
        "- [model_name]_detailed_results.csv - Full results for each model",
        "- [model_name]_accuracy_summary.csv - Accuracy by condition and prompt",
        "- [model_name]_geometry_comparison.csv - Linear vs non-linear comparison",
        "- all_models_comparison.png - Visual comparison across all models",
    )
    for line in closing_lines:
        print(line)

    return all_results, all_comparisons


In [None]:
# Run the whole pipeline once. `results` holds the per-model accuracy
# summaries and `comparisons` the per-model geometry comparisons; the CSV files
# read by the cells below are written as a side effect.
results, comparisons = main()

In [None]:
    
# Define model result files
model_files = {
    'gemini_2.5_flash': 'llm_output/gemini-2.5-flash-all_prompts.txt',
    'gemini_2.5_flash_lite': 'llm_output/gemini-2.5-flash-lite-all_prompts.txt',
    'gemini_2.5_pro': 'llm_output/gemini-2.5-pro-all_prompts.txt',
    'gemma3_1b': 'llm_output/gemma3_1b-all_prompts.txt',
    'gemma3_4b': 'llm_output/gemma3_4b-all_prompts.txt',
    'gemma3_12b': 'llm_output/gemma3_12b-all_prompts.txt',
    'gemma3_27b': 'llm_output/gemma3_27b-all_prompts.txt'
}

# Analyzer instance kept at notebook level for interactive use.
analyzer = LLMResultsAnalyzer()

# The `main()` call in the previous cell already ran the complete analysis for
# exactly these files, wrote the CSVs and produced all_models_comparison.png.
# Reuse its outputs instead of parsing every response file a second time and
# redrawing the same figure.
all_results, all_comparisons = results, comparisons

In [None]:
# Model name -> results file; the key order defines the order of all
# per-model plots below.
model_files = dict([
    ('gemini_2.5_flash', 'llm_output/gemini-2.5-flash-all_prompts.txt'),
    ('gemini_2.5_flash_lite', 'llm_output/gemini-2.5-flash-lite-all_prompts.txt'),
    ('gemini_2.5_pro', 'llm_output/gemini-2.5-pro-all_prompts.txt'),
    ('gemma3_1b', 'llm_output/gemma3_1b-all_prompts.txt'),
    ('gemma3_4b', 'llm_output/gemma3_4b-all_prompts.txt'),
    ('gemma3_12b', 'llm_output/gemma3_12b-all_prompts.txt'),
    ('gemma3_27b', 'llm_output/gemma3_27b-all_prompts.txt'),
])

model_names = [*model_files]

In [None]:

# Per-model grouped bar chart of accuracy by prompt type (x) and condition
# (hue), with the confidence interval stored in the summary CSV drawn as
# asymmetric error bars. One PNG is written per model.
for model_name in model_names:
    summary_path = Path(f'{model_name}_accuracy_summary.csv')
    if not summary_path.exists():
        # The analysis skips models whose results file is missing, so their
        # summary CSV does not exist either.
        print(f"Skipping {model_name} - file not found: {summary_path}")
        continue

    df = pd.read_csv(summary_path)
    print(model_name)

    # Error-bar lengths below and above each bar.
    df['yerr_lower'] = df['accuracy'] - df['ci_low']
    df['yerr_upper'] = df['ci_high'] - df['accuracy']

    # Fix the category orders explicitly so that bars can be matched back to
    # their rows by name instead of relying on the artist creation order.
    prompt_order = list(dict.fromkeys(df['prompt_type']))
    condition_order = list(dict.fromkeys(df['condition']))

    fig, ax = plt.subplots(figsize=(10, 6))
    sns.barplot(ax=ax, data=df, x='prompt_type', y='accuracy', hue='condition',
                order=prompt_order, hue_order=condition_order)
    ax.set_ylim([0, 1])
    ax.set_title(model_name)

    # One bar container per hue level; within a container the x category is
    # recovered from the bar centre (category k is drawn around x = k).
    ci_lookup = df.set_index(['condition', 'prompt_type'])
    for condition, container in zip(condition_order, ax.containers):
        for bar in container:
            x = bar.get_x() + bar.get_width() / 2
            y = bar.get_height()
            if not (np.isfinite(x) and np.isfinite(y)):
                continue
            category_idx = int(round(x))
            if not 0 <= category_idx < len(prompt_order):
                continue
            key = (condition, prompt_order[category_idx])
            if key not in ci_lookup.index:
                continue
            row = ci_lookup.loc[key]

            # Add error bar using asymmetric values
            ax.errorbar(
                x, y,
                yerr=[[row['yerr_lower']], [row['yerr_upper']]],
                fmt='none',
                c='black',
                capsize=5,
                linewidth=1
            )

    # Move the legend seaborn created to the right of the axes rather than
    # adding a second, figure-level legend.
    sns.move_legend(ax, 'center left', bbox_to_anchor=(1.02, 0.5))
    fig.tight_layout()
    # bbox_inches='tight' keeps the outside legend inside the saved image.
    fig.savefig(f'accuracy_{model_name}.png', bbox_inches='tight')
    plt.close(fig)

In [None]:
# Per-model bar chart of the pooled accuracies (baseline / linear / nonlinear /
# cot) from the geometry comparison CSV, coloured by condition.
for model_name in model_names:
    df = pd.read_csv(f'{model_name}_geometry_comparison.csv')

    # Long format: one row per (condition, prompt group).
    accuracy_columns = ['baseline', 'linear', 'nonlinear', 'cot']
    df_pivoted = df[['condition', *accuracy_columns]].melt(
        id_vars='condition',
        var_name='prompt_type',
        value_name='accuracy',
    )

    fig, ax = plt.subplots()
    ax.set_ylim((0, 1))
    sns.barplot(data=df_pivoted, x='prompt_type', y='accuracy', hue='condition',
                palette='husl', ax=ax)
    ax.set_title(model_name)
    fig.tight_layout()
    fig.savefig(f'geometry_comparison_{model_name}.png')

In [None]:
# `df` is whatever the loop in the previous cell left behind, i.e. the
# geometry comparison table of the last model in `model_names`.
df[['condition', 'baseline', 'linear', 'nonlinear', 'cot']]

In [None]:
# Load every model's detailed results and count how many responses could not
# be mapped to yes/no (predicted == 'unclear').
dfs, n_unclear_responses = [], {}
for model_name in model_names:
    detailed_path = f'{model_name}_detailed_results.csv'
    # reset_index() materialises the row number as an 'index' column.
    df = pd.read_csv(detailed_path).reset_index()

    is_unclear = df['predicted'] == 'unclear'
    n_unclear_responses[model_name] = np.sum(is_unclear)
    dfs.append(df)
n_unclear_responses

In [None]:
# Stack the per-model tables into one frame. ignore_index=True gives the
# combined frame a fresh 0..N-1 index; without it every model's rows reuse the
# same labels, which leaves duplicate index labels in the combined frame.
all_detailed_results_df = pd.concat(dfs, ignore_index=True)
all_detailed_results_df.head()

In [None]:
# Mean response length (characters) per model.
fig, ax = plt.subplots(figsize=(10, 5))
sns.barplot(
    ax=ax,
    data=all_detailed_results_df,
    x='model',
    y='response_length',
)
fig.tight_layout()

In [None]:
# Mean response length (characters) per model, split by condition.
fig, ax = plt.subplots(figsize=(10, 5))
sns.barplot(
    ax=ax,
    data=all_detailed_results_df,
    x='model',
    y='response_length',
    hue='condition',
)
fig.tight_layout()

In [None]:
# Mean response length (characters) per model, split by prompt type.
fig, ax = plt.subplots(figsize=(10, 5))
sns.barplot(
    ax=ax,
    data=all_detailed_results_df,
    x='model',
    y='response_length',
    hue='prompt_type',
    palette='husl',
)
fig.tight_layout()

Next, look at the distribution of response lengths per model. Are there two types of responses (quick answer versus reasoning trace), or is it more of a continuum?

In [None]:
# Response-length histogram for the first model only.
mn = model_names[0]
fig, ax = plt.subplots()
first_model_rows = all_detailed_results_df[all_detailed_results_df['model'] == mn]
sns.histplot(ax=ax, data=first_model_rows, x='response_length')
ax.set_title(mn)

In [None]:
# Response-length histogram for every model, saved as one PNG each.
for model_name in model_names:
    model_rows = all_detailed_results_df[all_detailed_results_df['model'] == model_name]

    fig, ax = plt.subplots()
    sns.histplot(data=model_rows, x='response_length', ax=ax)
    ax.set_title(model_name)
    fig.tight_layout()
    fig.savefig(f'response_length_{model_name}.png')
    # Close each figure so the loop does not accumulate open figures.
    plt.close(fig)
