In [4]:
import ast
import json
import os
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

In [5]:
def load_experiment_results(results_dir: str | Path):
    """
    Load all results from an experiment directory.
    
    Args:
        results_dir: Path to the experiment results directory
    
    Returns:
        dict: Dictionary containing all loaded results and statistics
    """
    results_dir = Path(results_dir)
    results = {}
    
    # Load arguments used for the run
    with open(results_dir / "args.json", 'r') as f:
        results['args'] = json.load(f)
    
    # Load data split information if it exists
    split_file = results_dir / "data_split.json"
    if split_file.exists():
        with open(split_file, 'r') as f:
            results['data_split'] = json.load(f)
    
    # Dictionary to store per-model results
    results['models'] = {}
    
    # Load results for each model
    for model_dir in results_dir.iterdir():
        if not model_dir.is_dir():
            continue
            
        model_results = {}
        
        # Load main test results
        test_results_file = model_dir / "test_results.csv"
        if test_results_file.exists():
            model_results['test_results'] = pd.read_csv(test_results_file)
        
        # Load per-layer likelihoods
        layer_ll_file = model_dir / "per_layer_likelihoods.csv"
        if layer_ll_file.exists():
            model_results['layer_likelihoods'] = pd.read_csv(layer_ll_file)
        
        # Load matrices
        matrices = {}
        for matrix_name in ['confusion', 'precision', 'recall', 'f1']:
            matrix_file = model_dir / f"full_tuple_{matrix_name}_matrix.csv"
            if matrix_file.exists():
                matrices[f'{matrix_name}_matrix'] = pd.read_csv(matrix_file, index_col=0)
        model_results['matrices'] = matrices
        
        # Load top-k accuracy
        topk_file = model_dir / "full_tuple_top_k_accuracy.csv"
        if topk_file.exists():
            model_results['top_k_accuracy'] = pd.read_csv(topk_file, index_col=0)
        
        # Load per-dimension results
        dimension_results = {}
        for dim_dir in model_dir.iterdir():
            if not dim_dir.is_dir() or not dim_dir.name.startswith('dimension_'):
                continue
                
            dim_matrices = {}
            for matrix_name in ['confusion', 'precision', 'recall', 'f1']:
                matrix_file = dim_dir / f"{matrix_name}_matrix.csv"
                if matrix_file.exists():
                    dim_matrices[f'{matrix_name}_matrix'] = pd.read_csv(matrix_file, index_col=0)
            
            topk_file = dim_dir / "top_k_accuracy.csv"
            if topk_file.exists():
                dim_matrices['top_k_accuracy'] = pd.read_csv(topk_file, index_col=0)
                
            dimension_results[dim_dir.name] = dim_matrices
        
        model_results['dimension_results'] = dimension_results
        
        # Add all model results to main results dict
        results['models'][model_dir.name] = model_results
    
    return results


In [6]:
def get_metrics_from_confusion(conf_mat: pd.DataFrame) -> tuple[pd.Series, pd.Series, pd.Series]:
    """Calculate per-category precision, recall and F1 from a confusion matrix.

    Args:
        conf_mat: Square pandas DataFrame with identical row and column labels,
                 where rows are true labels and columns are predicted labels

    Returns:
        Tuple of (precision, recall, f1) where each is a float pandas Series
        with one value per category. Precision is indexed by the matrix
        columns, recall and F1 by the matrix index. A metric whose denominator
        is not positive is reported as 0.0.

    Raises:
        ValueError: If the matrix is not square or its row and column labels
            differ.
    """
    # Check the shape first: comparing two indexes of different length raises
    # an unrelated "Lengths must match" error instead of the message below.
    n_rows, n_cols = conf_mat.shape
    if n_rows != n_cols or not (conf_mat.index == conf_mat.columns).all():
        raise ValueError("Confusion matrix must have identical row and column labels")

    def safe_ratio(numerator: np.ndarray, denominator: np.ndarray) -> np.ndarray:
        """Element-wise numerator / denominator, 0.0 where denominator is not > 0."""
        ratio = np.zeros(len(numerator), dtype=float)
        np.divide(numerator, denominator, out=ratio, where=denominator > 0)
        return ratio

    # True positives sit on the diagonal; the rest of each column is a false
    # positive and the rest of each row a false negative.
    true_pos = np.diag(conf_mat.to_numpy()).astype(float)
    false_pos = conf_mat.sum(axis=0).to_numpy(dtype=float) - true_pos
    false_neg = conf_mat.sum(axis=1).to_numpy(dtype=float) - true_pos

    precision = safe_ratio(true_pos, true_pos + false_pos)
    recall = safe_ratio(true_pos, true_pos + false_neg)
    # F1 is the harmonic mean of precision and recall.
    f1 = safe_ratio(2 * (precision * recall), precision + recall)

    return (
        pd.Series(precision, index=conf_mat.columns, dtype=float),
        pd.Series(recall, index=conf_mat.index, dtype=float),
        pd.Series(f1, index=conf_mat.index, dtype=float),
    )

In [7]:
def get_subcats(label: str) -> tuple[str, str]:
    """Split a category label into its two subcategories.

    Args:
        label: String of format '("subcat1_i", "subcat2_i")', i.e. the text
            form of a 2-tuple, optionally wrapped in outer single quotes.
            Unquoted labels such as '(subcat1_i, subcat2_i)' are also accepted.

    Returns:
        Tuple of (subcat1_i, subcat2_i)

    Raises:
        ValueError: If the label cannot be split into exactly two parts.
    """
    # Drop any outer single quotes wrapped around the whole label.
    text = label.strip("'")

    # Preferred path: the label is the literal form of a 2-tuple of strings.
    # Using the literal parser keeps commas and quote characters that occur
    # inside a subcategory name intact, which a plain split(',') cannot do.
    try:
        parsed = ast.literal_eval(text)
    except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
        parsed = None
    if (isinstance(parsed, tuple) and len(parsed) == 2
            and all(isinstance(part, str) for part in parsed)):
        return parsed

    # Fallback for labels that are not valid Python literals (e.g. unquoted
    # names): strip the parentheses and split on the single comma. Unpacking
    # raises ValueError when there are not exactly two parts.
    first, second = text.strip('()').split(',')
    cleaned = []
    for part in (first, second):
        # Clean up whitespace, then double quotes, then single quotes.
        part = part.strip().strip('"')
        part = part.strip().strip("'")
        cleaned.append(part)
    return (cleaned[0], cleaned[1])


In [8]:
# Models analysed when the caller does not pass `model_names`.
_DEFAULT_MODEL_NAMES = ('gpt2-small', 'gpt2-medium', 'gpt2-large',
                        'pythia-160m', 'pythia-410m', 'pythia-1b')


def _finish_figure(figures_dir: str, filename: str, show_plots: bool) -> None:
    """Save the current pyplot figure as figures_dir/filename, optionally show it, then close it."""
    plt.savefig(os.path.join(figures_dir, filename), bbox_inches='tight')
    if show_plots:
        plt.show()
    plt.close()


def _collapse_confusion(conf_mat: pd.DataFrame, parts: list[str], categories: list[str]) -> pd.DataFrame:
    """Sum a full-tuple confusion matrix into a matrix over one subcategory.

    parts[k] is the subcategory of the k-th row/column label of conf_mat; cell
    (i, j) of conf_mat is added to cell (parts[i], parts[j]) of the result.
    Raises KeyError when a parsed subcategory is not listed in categories.
    """
    positions = pd.Index(categories).get_indexer(parts)
    if (positions < 0).any():
        unknown = sorted({part for part, pos in zip(parts, positions) if pos < 0})
        raise KeyError(f"Subcategories not present in the test results: {unknown}")
    collapsed = np.zeros((len(categories), len(categories)))
    # Unbuffered scatter-add: repeated (row, column) targets accumulate.
    np.add.at(collapsed, (positions[:, None], positions[None, :]),
              conf_mat.to_numpy(dtype=float))
    return pd.DataFrame(collapsed, index=categories, columns=categories)


def _plot_metric_pair(first: pd.DataFrame, second: pd.DataFrame, metric_label: str,
                      filename: str, figures_dir: str, show_plots: bool, cmap: str) -> None:
    """Draw the model-by-subcategory heatmaps of one metric side by side and save them."""
    plt.figure(figsize=(18, 8))
    panels = ((first, 'First'), (second, 'Second'))
    for position, (matrix, ordinal) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        sns.heatmap(matrix, annot=True, fmt='.2f', cmap=cmap, vmin=0, vmax=1)
        plt.title(f'{metric_label} - {ordinal} Subcategory')
        plt.yticks(rotation=0)
    plt.tight_layout()
    _finish_figure(figures_dir, filename, show_plots)


def plot_results(results_dir: str,
                 figures_dir: str | None = None,
                 show_plots: bool = False,
                 model_names: list[str] | None = None):
    """
    Plot and save analysis results for a set of models.

    For every model this writes '<model>_column_sums.png',
    '<model>_confusion_matrix.png' and '<model>_subcategory_matrices.png';
    afterwards it writes 'precision_matrices.png', 'recall_matrices.png' and
    'f1_matrices.png', which compare all models per subcategory.

    Args:
        results_dir: Directory containing results
        figures_dir: Directory to save figures (defaults to results_dir/figs/)
        show_plots: Whether to display plots in notebook
        model_names: List of model names to include in analysis. Defaults to
            gpt2-small/medium/large and pythia-160m/410m/1b.

    Raises:
        KeyError: If a requested model has no results under results_dir, or a
            confusion-matrix label names a subcategory that does not occur in
            that model's test results.
    """
    # A None sentinel avoids sharing one mutable default list between calls.
    if model_names is None:
        model_names = list(_DEFAULT_MODEL_NAMES)

    # Set up figures directory
    if figures_dir is None:
        figures_dir = os.path.join(results_dir, 'figs')
    os.makedirs(figures_dir, exist_ok=True)

    results = load_experiment_results(results_dir)
    models = results['models']

    # Fail before drawing anything rather than part-way through the loop.
    missing = [name for name in model_names if name not in models]
    if missing:
        raise KeyError(f"No results found in {results_dir} for model(s): {missing}")

    # Columns of the summary tables come from the first model; all models are
    # assumed to share the same subcategories.
    first_test_results = models[model_names[0]]['test_results']
    subcats1 = list(first_test_results['true_category_1'].unique())
    subcats2 = list(first_test_results['true_category_2'].unique())

    # One row per model, one column per subcategory.
    precision1_matrix = pd.DataFrame(index=model_names, columns=subcats1, dtype=float)
    precision2_matrix = pd.DataFrame(index=model_names, columns=subcats2, dtype=float)
    recall1_matrix = pd.DataFrame(index=model_names, columns=subcats1, dtype=float)
    recall2_matrix = pd.DataFrame(index=model_names, columns=subcats2, dtype=float)
    f1_1_matrix = pd.DataFrame(index=model_names, columns=subcats1, dtype=float)
    f1_2_matrix = pd.DataFrame(index=model_names, columns=subcats2, dtype=float)

    cmap_val = 'YlOrRd'
    for model_nm in model_names:
        model_results = models[model_nm]
        conf_mat = model_results['matrices']['confusion_matrix']

        # Column sums histogram
        plt.figure(figsize=(8, 5))
        plt.hist(conf_mat.sum(axis=0), bins=50)
        plt.title(f'Distribution of Column Sums - {model_nm}')
        plt.xlabel('Sum')
        plt.ylabel('Count')
        _finish_figure(figures_dir, f'{model_nm}_column_sums.png', show_plots)

        # Confusion matrix heatmap
        plt.figure(figsize=(8, 5))
        sns.heatmap(conf_mat, annot=False, cmap=cmap_val, xticklabels=False, yticklabels=False)
        plt.title(f'Confusion Matrix - {model_nm}')
        plt.tight_layout()
        _finish_figure(figures_dir, f'{model_nm}_confusion_matrix.png', show_plots)

        # Get labels from the confusion matrix indices and verify they match columns
        labels = conf_mat.index
        assert all(labels == conf_mat.columns), "Row and column labels don't match"

        # Subcategory lists of this particular model
        test_results = model_results['test_results']
        model_subcats1 = list(test_results['true_category_1'].unique())
        model_subcats2 = list(test_results['true_category_2'].unique())

        # Parse each label once, then collapse the full matrix over each
        # subcategory by summing the matching blocks.
        parsed_labels = [get_subcats(label) for label in labels]
        subcat1_matrix = _collapse_confusion(
            conf_mat, [pair[0] for pair in parsed_labels], model_subcats1)
        subcat2_matrix = _collapse_confusion(
            conf_mat, [pair[1] for pair in parsed_labels], model_subcats2)

        # Plot subcategory matrices
        plt.figure(figsize=(20, 8))
        plt.subplot(1, 2, 1)
        sns.heatmap(subcat1_matrix, annot=True, fmt='.0f', xticklabels=model_subcats1,
                    yticklabels=model_subcats1, cmap=cmap_val)
        plt.title(f'Confusion Matrix for First Subcategory - {model_nm}')

        plt.subplot(1, 2, 2)
        sns.heatmap(subcat2_matrix, annot=True, fmt='.0f', xticklabels=model_subcats2,
                    yticklabels=model_subcats2, cmap=cmap_val)
        plt.title(f'Confusion Matrix for Second Subcategory - {model_nm}')
        plt.tight_layout()
        _finish_figure(figures_dir, f'{model_nm}_subcategory_matrices.png', show_plots)

        # Calculate metrics for subcategories
        prec1, rec1, f1_1 = get_metrics_from_confusion(subcat1_matrix)
        prec2, rec2, f1_2 = get_metrics_from_confusion(subcat2_matrix)

        # Store in the summary tables (assignment aligns on subcategory label)
        precision1_matrix.loc[model_nm] = prec1
        precision2_matrix.loc[model_nm] = prec2
        recall1_matrix.loc[model_nm] = rec1
        recall2_matrix.loc[model_nm] = rec2
        f1_1_matrix.loc[model_nm] = f1_1
        f1_2_matrix.loc[model_nm] = f1_2

    # Cross-model comparison figures, one per metric
    _plot_metric_pair(precision1_matrix, precision2_matrix, 'Precision',
                      'precision_matrices.png', figures_dir, show_plots, cmap_val)
    _plot_metric_pair(recall1_matrix, recall2_matrix, 'Recall',
                      'recall_matrices.png', figures_dir, show_plots, cmap_val)
    _plot_metric_pair(f1_1_matrix, f1_2_matrix, 'F1',
                      'f1_matrices.png', figures_dir, show_plots, cmap_val)


In [15]:
# All runs share one results root; each run lives in its own sub-directory
# (names look like <YYYYMMDD>_<n> — TODO confirm).
RESULTS_ROOT = "/home/mattylev/projects/transformers/distributions/kde_classification/results"
results_path_all_points = f"{RESULTS_ROOT}/20250107_4/"
results_path_tails = f"{RESULTS_ROOT}/20250109_1/"
results_path_tails_2 = f"{RESULTS_ROOT}/20250109_2/"
# NOTE(review): the loader call `results = load_experiment_results(results_path)` was
# left commented out here and `results_path` is not defined in any visible cell, yet
# later cells read `results`. Load it explicitly from one of the paths above.


In [13]:
# Generate and save every figure for this run; with no figures_dir given the
# files are written to <results_dir>/figs/.
plot_results(results_dir=results_path_tails)

In [11]:
# Generate and save every figure for this run; with no figures_dir given the
# files are written to <results_dir>/figs/.
plot_results(results_dir=results_path_all_points)

In [None]:
# Generate and save every figure for this run; with no figures_dir given the
# files are written to <results_dir>/figs/.
plot_results(results_dir=results_path_tails_2)

In [None]:
## We're getting exactly the same confusion matrix, and therefore the same derived metrics, for the tails-only analysis
## and the all-points analysis. That cannot be right. Both analyses have finished running, so load the test results of
## each run and compare them to look for hints about what is going wrong.

In [20]:
# Show the rows of the confusion matrix whose total differs from 40.
# NOTE(review): 40 is presumably the expected number of test samples per label — confirm.
# NOTE(review): `conf_mat` is only assigned at top level in a later cell (execution
# count 51), so this cell fails under Restart Kernel -> Run All.
conf_mat.sum(axis=1)[conf_mat.sum(axis=1) != 40]

('about celebrities', 'in all lower case')    38
('about poetry', 'in all lower case')         41
('about vacation', 'in all lower case')       41
dtype: int64

In [51]:
# Column totals of the pythia-410m full-tuple confusion matrix, largest first, to see
# which labels collect the most counts (columns are presumably predicted labels, as
# get_metrics_from_confusion documents — confirm for this file).
# NOTE(review): `results` is not assigned by any visible top-level cell (the loader
# call in the paths cell is commented out); it must be loaded before this cell runs.
conf_mat = results['models']['pythia-410m']['matrices']['confusion_matrix']
col_sums = conf_mat.sum(axis=0)
sorted_sums = col_sums.sort_values(ascending=False)
print(sorted_sums.head(20))


('about music', 'in a kind tone')                    664
('about politics', 'in english')                     617
('about poetry', 'in a concise style')               239
('about politics', 'in a concise style')             235
('about celebrities', 'in a concise style')          181
('about celebrities', 'in all lower case')           178
('about celebrities', 'in english')                  148
('about music', 'in spanish')                        143
('about poetry', 'in japanese')                      125
('about computer science', 'in an angry tone')       105
('about science fiction', 'in japanese')              96
('about american football', 'in an angry tone')       76
('about american football', 'in all lower case')      71
('about celebrities', 'in japanese')                  69
('about computer science', 'in a childish style')     68
('about american football', 'in all upper case')      67
('about poetry', 'in spanish')                        61
('about effective altruism', 'i

In [42]:
# Look up one confusion-matrix entry: row labels[0], column labels[1].
# NOTE(review): `model_results` and `labels` are not assigned at top level in any
# visible cell (they only appear as locals inside functions); this relies on
# leftover kernel state and will fail on a fresh run.
model_results['matrices']['confusion_matrix'].loc[labels[0], labels[1]]

1

In [41]:
# Inspect the row labels of the confusion matrix to check their exact string format.
# NOTE(review): `model_results` is not assigned at top level in any visible cell;
# this relies on leftover kernel state and will fail on a fresh run.
model_results['matrices']['confusion_matrix'].index

Index(['('about american football', 'in a childish style')',
       '('about american football', 'in a concise style')',
       '('about american football', 'in a didactic style')',
       '('about american football', 'in a flowery style')',
       '('about american football', 'in a helpful style')',
       '('about american football', 'in a kind tone')',
       '('about american football', 'in all lower case')',
       '('about american football', 'in all upper case')',
       '('about american football', 'in an angry tone')',
       '('about american football', 'in english')',
       ...
       '('about vacation', 'in a didactic style')',
       '('about vacation', 'in a flowery style')',
       '('about vacation', 'in a helpful style')',
       '('about vacation', 'in a kind tone')',
       '('about vacation', 'in all lower case')',
       '('about vacation', 'in all upper case')',
       '('about vacation', 'in an angry tone')',
       '('about vacation', 'in english')', '('about v