# 🧠 Concept Encoder Analysis Notebook

A comprehensive analysis toolkit for understanding what the concepts in the ConceptEncoder architecture learn.

## Overview

This notebook provides:
1. **Concept Space Geometry** - Effective rank, isotropy, uniformity, collapse detection
2. **Attention Pattern Analysis** - Concept-token attention visualization
3. **Concept Specialization** - What tokens/patterns each concept captures
4. **Publication-Quality Visualizations** - Figures for research papers

### Research Background
- VICReg (Bardes et al., 2021): Variance-Invariance-Covariance analysis
- Perceiver IO (Jaegle et al., 2021): Cross-attention bottleneck analysis  
- Probing Tasks (Miaschi et al., 2020): Linguistic property probing
- Intrinsic Dimensionality (Aghajanyan et al., 2020): Effective dimensionality
- T-REGS (Mordacq et al., 2025): Uniformity and collapse metrics

In [None]:
# Setup paths and imports
import sys
import os

# Resolve the project root from the directory the kernel was started in.
# The later cells import `nn.*` and `analysis.*` as top-level packages, so the
# project root is the directory that contains the `analysis` folder:
#   - started inside `<project>/analysis`  -> root is the parent directory
#   - started anywhere else                -> the current directory is assumed
#     to already be the project root
# (The previous version added the directory two levels above the working
# directory, which is outside the project in both of these cases.)
cwd = os.getcwd()
if os.path.basename(cwd) == 'analysis':
    project_root = os.path.dirname(cwd)
else:
    project_root = cwd

# Make the original working directory and the project root importable.
# The root is inserted last so it ends up first on sys.path.
for _path in (cwd, project_root):
    if _path not in sys.path:
        sys.path.insert(0, _path)

# Run from the project root so relative paths such as ./Cache resolve there.
# Re-running this cell is a no-op once the kernel is already at the root.
if os.getcwd() != project_root:
    os.chdir(project_root)

print(f"Working directory: {os.getcwd()}")

In [None]:
# Core imports
import torch
import torch.nn.functional as F
import numpy as np
import pandas as pd
from collections import defaultdict
from pathlib import Path

# Visualization
import matplotlib.pyplot as plt
import seaborn as sns

# Global look-and-feel: seaborn-flavoured matplotlib style plus the "husl" palette
plt.style.use('seaborn-v0_8-whitegrid')
sns.set_palette("husl")

# Publication quality settings, applied to matplotlib's rcParams one key at a
# time (same keys, values and order as a single bulk update)
for rc_key, rc_value in (
    ('font.size', 12),
    ('axes.labelsize', 14),
    ('axes.titlesize', 16),
    ('xtick.labelsize', 12),
    ('ytick.labelsize', 12),
    ('legend.fontsize', 12),
    ('figure.figsize', (10, 8)),
    ('figure.dpi', 100),
    ('savefig.dpi', 300),
    ('savefig.bbox', 'tight'),
    ('font.family', 'sans-serif'),
):
    plt.rcParams[rc_key] = rc_value

print("‚úÖ Core imports complete!")

In [None]:
# Optional dependencies: each import is attempted once and the outcome is
# recorded in a HAS_* flag that later cells check before using the library.
try:
    from sklearn.cluster import KMeans
    from sklearn.decomposition import PCA
    from sklearn.manifold import TSNE
except ImportError:
    HAS_SKLEARN = False
    print("‚ö†Ô∏è sklearn not available - some analyses will be limited")
else:
    HAS_SKLEARN = True

try:
    import umap
except ImportError:
    HAS_UMAP = False
    print("‚ö†Ô∏è umap not available - using t-SNE/PCA as fallback")
else:
    HAS_UMAP = True

# Report the availability of both optional libraries
for library_label, is_available in (("sklearn", HAS_SKLEARN), ("UMAP", HAS_UMAP)):
    print(f"‚úÖ {library_label} available: {is_available}")

In [None]:
# Import ConceptEncoder modules
from nn.concept_encoder import ConceptEncoder, ConceptEncoderConfig
from nn.concept_encoder_perceiver import ConceptEncoderForMaskedLMPerceiver
from nn.concept_encoder_weighted import ConceptEncoderForMaskedLMWeighted

# Import analysis toolkit
from analysis.concept_analysis import (
    compute_concept_geometry_metrics,
    ConceptAttentionExtractor,
    ConceptSpecializationAnalyzer,
    ConceptVisualizer,
    ConceptAnalyzer,
    ConceptMetricsCallback
)

print("‚úÖ ConceptEncoder modules loaded!")

## 1. Load Model and Data

Configure the model checkpoint and load data for analysis.

**Important:** Update `MODEL_PATH` to point to your trained model checkpoint!

In [None]:
# ============================================================
# CONFIGURATION - UPDATE THESE FOR YOUR MODEL
# ============================================================
MODEL_PATH = "./Cache/Training/YOUR_MODEL_CHECKPOINT"  # <-- UPDATE THIS!
MODEL_TYPE = "perceiver_mlm"  # Options: "perceiver_mlm" or "weighted_mlm"
TOKENIZER_NAME = "bert-base-uncased"
OUTPUT_DIR = "./Cache/Outputs/concept_analysis"

# Ensure the output directory exists (no error if it is already there)
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Echo the active configuration
for config_line in (
    f"üìÅ Model path: {MODEL_PATH}",
    f"üîß Model type: {MODEL_TYPE}",
    f"üìÇ Output dir: {OUTPUT_DIR}",
):
    print(config_line)

# Warn early when the checkpoint path is still the placeholder / does not exist
model_path_missing = not os.path.exists(MODEL_PATH)
if model_path_missing:
    print("\n".join([
        "\n‚ö†Ô∏è  WARNING: Model path does not exist!",
        "   Please update MODEL_PATH to point to your trained checkpoint.",
        "   Example: './Cache/Training/perceiver_mlm_H512L4C128_20240115_120000'",
    ]))

In [None]:
# Load the tokenizer named by TOKENIZER_NAME and report its vocabulary size
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_NAME)
print(
    f"‚úÖ Tokenizer loaded: {TOKENIZER_NAME}",
    f"   Vocab size: {tokenizer.vocab_size}",
    sep="\n",
)

In [None]:
# Load model
# Prefer the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"üñ•Ô∏è  Using device: {device}")

# Any failure below (unknown MODEL_TYPE, bad checkpoint path, missing config
# attribute, ...) is reported and recorded in MODEL_LOADED instead of raising,
# so the remaining cells can skip themselves.
try:
    if MODEL_TYPE not in ("perceiver_mlm", "weighted_mlm"):
        raise ValueError(f"Unknown model type: {MODEL_TYPE}")

    # Pick the class matching MODEL_TYPE, then load the checkpoint
    model_class = (
        ConceptEncoderForMaskedLMPerceiver
        if MODEL_TYPE == "perceiver_mlm"
        else ConceptEncoderForMaskedLMWeighted
    )
    model = model_class.from_pretrained(MODEL_PATH)

    model = model.to(device)
    model.eval()

    # Print model config: (label, config attribute) pairs, read lazily in order
    config = model.config
    print("\nüìä Model Configuration:")
    for config_label, config_attr in (
        ("Vocab size", "vocab_size"),
        ("Hidden size", "hidden_size"),
        ("Num layers", "num_hidden_layers"),
        ("Num concepts", "concept_num"),
        ("Max sequence length", "max_sequence_length"),
    ):
        print(f"   {config_label}: {getattr(config, config_attr)}")

    # Count parameters
    total_params = 0
    for parameter in model.parameters():
        total_params += parameter.numel()
    print(f"\n   Total parameters: {total_params:,} ({total_params/1e6:.1f}M)")

    MODEL_LOADED = True
except Exception as load_error:
    print(f"\n‚ùå Failed to load model: {load_error}")
    print("   Please update MODEL_PATH and re-run this cell.")
    MODEL_LOADED = False

## 2. Concept Space Geometry Analysis

Analyze the geometric properties of the learned concept space:

| Metric | What It Measures | Healthy Range |
|--------|------------------|---------------|
| **Effective Rank** | How many dimensions are actually used | > 0.5 (normalized) |
| **Isotropy** | Are all dimensions equally utilized? | > 0.01 |
| **Uniformity** | Are concepts well-distributed on hypersphere? | < 0.1 |
| **Max Similarity** | Collapse detection (too similar?) | < 0.5 |

In [None]:
# Load a sample dataset for analysis
from datasets import load_dataset
from torch.utils.data import DataLoader

if MODEL_LOADED:
    print("üìö Loading WikiText dataset for analysis...")
    
    # Use WikiText for analysis (clean, well-known text)
    dataset = load_dataset("wikitext", "wikitext-103-v1", split="validation[:1000]")
    
    # Filter and tokenize
    def tokenize_function(examples):
        """Tokenize one batch of rows, dropping texts of 20 or fewer non-blank characters.

        Args:
            examples: batch dict with a 'text' list, as passed by
                `Dataset.map(..., batched=True)`.

        Returns:
            The tokenizer output for the kept texts (truncated and padded to
            128 tokens), or a dict of empty lists when no text survives the
            filter. The number of returned rows may be smaller than the
            number of input rows.
        """
        texts = [t for t in examples['text'] if len(t.strip()) > 20]
        if not texts:
            # Return the same columns the tokenizer emits for a non-empty batch
            # (e.g. BERT also produces token_type_ids); a hard-coded two-column
            # dict would give this batch a different schema from the others.
            return {name: [] for name in tokenizer.model_input_names}
        return tokenizer(texts, truncation=True, max_length=128, padding='max_length')
    
    # The original columns are removed because the filter changes the row count
    tokenized_dataset = dataset.map(tokenize_function, batched=True, remove_columns=dataset.column_names)
    # Only these two columns are exposed (as torch tensors) to the DataLoader
    tokenized_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask'])
    
    # shuffle=False keeps the batch order fixed from run to run
    dataloader = DataLoader(tokenized_dataset, batch_size=32, shuffle=False)
    print(f"‚úÖ Dataset loaded: {len(tokenized_dataset)} samples, {len(dataloader)} batches")
else:
    print("‚ö†Ô∏è  Model not loaded. Please load model first.")

In [None]:
# Collect concept representations from multiple batches
if MODEL_LOADED:
    print("üîÑ Collecting concept representations...")
    
    all_concepts = []
    num_batches = min(10, len(dataloader))  # Limit for efficiency
    
    # The encoder sub-module is exposed as `encoder` on some model classes and
    # as `concept_encoder` on others; resolve it once instead of per batch.
    concept_encoder = model.encoder if hasattr(model, 'encoder') else model.concept_encoder
    
    # Inference only: no gradients are needed for the analysis
    with torch.no_grad():
        for i, batch in enumerate(dataloader):
            if i >= num_batches:
                break
            
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            
            # Get concept representations from encoder
            concept_output = concept_encoder(input_ids=input_ids, attention_mask=attention_mask)
            
            # NOTE(review): the encoder's return type is not visible from this
            # notebook. A bare tensor is used as-is; if it is a ModelOutput-like
            # container, its `last_hidden_state` (or first element) is assumed
            # to hold the concept representations — confirm against the model code.
            if not torch.is_tensor(concept_output):
                if hasattr(concept_output, 'last_hidden_state'):
                    concept_output = concept_output.last_hidden_state
                else:
                    concept_output = concept_output[0]
            
            # Move to CPU so accumulated batches do not hold device memory
            all_concepts.append(concept_output.cpu())
    
    # torch.cat fails with an opaque error on an empty list; explain the cause
    if not all_concepts:
        raise ValueError(
            "No batches were available to collect concept representations from; "
            "the tokenized dataset is empty."
        )
    
    # Combine all concepts: (total_samples, num_concepts, hidden_size)
    concepts_tensor = torch.cat(all_concepts, dim=0)
    print(f"‚úÖ Collected concepts shape: {concepts_tensor.shape}")
    print(f"   = {concepts_tensor.shape[0]} samples x {concepts_tensor.shape[1]} concepts x {concepts_tensor.shape[2]} hidden_dim")
else:
    print("‚ö†Ô∏è  Model not loaded.")

In [None]:
# Compute geometry metrics and print a sectioned, human-readable report.
# Each section prints the raw value(s) and then a verdict chosen by threshold.
if not MODEL_LOADED:
    print("‚ö†Ô∏è  Model not loaded.")
else:
    report_rule = "=" * 60
    print("üìê Computing concept space geometry metrics...\n")

    metrics = compute_concept_geometry_metrics(concepts_tensor)

    print(report_rule)
    print("             CONCEPT SPACE GEOMETRY REPORT")
    print(report_rule)

    # --- Effective rank: > 0.5 good, > 0.2 moderate, otherwise low ---
    print("\nüìä EFFECTIVE RANK (Information Utilization)")
    print(f"   Raw effective rank: {metrics['effective_rank']:.2f} / {concepts_tensor.shape[2]} dims")
    norm_rank = metrics['normalized_effective_rank']
    print(f"   Normalized (0-1):   {norm_rank:.3f}")
    rank_verdict = (
        "   ‚úÖ Good - Concepts use most of the embedding space"
        if norm_rank > 0.5
        else "   ‚ö†Ô∏è  Moderate - Some dimensions may be underutilized"
        if norm_rank > 0.2
        else "   ‚ùå Low - Possible dimensional collapse!"
    )
    print(rank_verdict)

    # --- Isotropy: > 0.01 good, otherwise low ---
    print("\nüîÆ ISOTROPY (Dimension Utilization Uniformity)")
    isotropy = metrics['isotropy']
    print(f"   Isotropy score: {isotropy:.4f}")
    isotropy_verdict = (
        "   ‚úÖ Good - Dimensions are utilized fairly uniformly"
        if isotropy > 0.01
        else "   ‚ö†Ô∏è  Low - Some dimensions dominate the representation"
    )
    print(isotropy_verdict)

    # --- Uniformity: < 0.1 good, otherwise high ---
    print("\nüéØ UNIFORMITY (Distribution on Hypersphere)")
    uniformity = metrics['uniformity']
    print(f"   Uniformity loss: {uniformity:.4f}")
    uniformity_verdict = (
        "   ‚úÖ Good - Concepts are well-distributed"
        if uniformity < 0.1
        else "   ‚ö†Ô∏è  High - Concepts may be clustered"
    )
    print(uniformity_verdict)

    # --- Pairwise similarity: max > 0.9 critical, > 0.5 warning, otherwise good ---
    print("\nüîç SIMILARITY ANALYSIS (Collapse Detection)")
    print(f"   Mean pairwise similarity: {metrics['mean_similarity']:.4f}")
    max_similarity = metrics['max_similarity']
    print(f"   Max pairwise similarity:  {max_similarity:.4f}")
    print(f"   Min pairwise similarity:  {metrics['min_similarity']:.4f}")
    similarity_verdict = (
        "   ‚ùå CRITICAL: Some concepts are nearly identical (collapse!)"
        if max_similarity > 0.9
        else "   ‚ö†Ô∏è  Some concepts are highly similar"
        if max_similarity > 0.5
        else "   ‚úÖ Good diversity between concepts"
    )
    print(similarity_verdict)

    # --- Variance (reported only, no verdict) ---
    print("\nüìà VARIANCE ANALYSIS")
    print(f"   Mean variance: {metrics['mean_variance']:.4f}")
    print(f"   Variance std:  {metrics['var_std']:.4f}")

    # --- Norms (reported only, no verdict) ---
    print("\nüìè NORM STATISTICS")
    print(f"   Mean norm: {metrics['mean_norm']:.4f}")
    print(f"   Norm std:  {metrics['norm_std']:.4f}")

    print("\n" + report_rule)

## 3. Publication-Quality Visualizations

Generate figures suitable for a research paper:

1. **Concept Similarity Matrix** - Pairwise cosine similarity heatmap
2. **Singular Value Spectrum** - Shows dimensionality usage (effective rank)
3. **2D Projections** - PCA/t-SNE/UMAP visualizations of concept space

In [None]:
# 3.1 Concept Similarity Matrix
# Defines `visualizer` and `concept_prototypes`, which the following plotting
# cells reuse.
if not MODEL_LOADED:
    print("‚ö†Ô∏è  Model not loaded.")
else:
    visualizer = ConceptVisualizer(save_dir=OUTPUT_DIR)

    # One prototype per concept: the mean over the sample axis,
    # giving (num_concepts, hidden_size)
    concept_prototypes = torch.mean(concepts_tensor, dim=0)

    similarity_title = "Concept Pairwise Cosine Similarity"
    fig = visualizer.plot_concept_similarity_matrix(concept_prototypes, title=similarity_title)
    plt.show()
    # NOTE(review): the file name printed here is assumed to match what
    # ConceptVisualizer writes into save_dir — confirm.
    print(f"üíæ Saved to: {OUTPUT_DIR}/concept_similarity_matrix.png")

In [None]:
# 3.2 Singular Value Spectrum
# Relies on `visualizer` and `concept_prototypes` from the similarity-matrix cell.
if not MODEL_LOADED:
    print("‚ö†Ô∏è  Model not loaded.")
else:
    spectrum_title = "Singular Value Spectrum of Concept Representations"
    fig = visualizer.plot_svd_spectrum(concept_prototypes, title=spectrum_title)
    plt.show()
    print(f"üíæ Saved to: {OUTPUT_DIR}/svd_spectrum.png")

    # Reading guide for the plot, printed line by line
    interpretation_lines = (
        "\nüìä Interpretation:",
        "   - Rapid decay = low effective dimensionality (few dominant directions)",
        "   - Slow decay = high effective dimensionality (many useful directions)",
        "   - Knee/elbow = boundary between significant and noise dimensions",
    )
    for interpretation_line in interpretation_lines:
        print(interpretation_line)

In [None]:
# 3.3 2D Projections (PCA, t-SNE)
# Fixed seed so the random t-SNE subset (and therefore the figure's input)
# is the same on every Restart & Run All.
PROJECTION_SEED = 42

if MODEL_LOADED and HAS_SKLEARN:
    # Flatten concepts for projection: (num_samples * num_concepts, hidden_size)
    flat_concepts = concepts_tensor.reshape(-1, concepts_tensor.shape[-1]).numpy()
    
    # Create labels for coloring (concept index); the tiling matches the
    # row order produced by the reshape above (concept index varies fastest)
    n_samples, n_concepts, hidden_dim = concepts_tensor.shape
    concept_labels = np.tile(np.arange(n_concepts), n_samples)
    
    # PCA projection (all points)
    fig = visualizer.plot_2d_projection(
        flat_concepts,
        labels=concept_labels,
        method='pca',
        title="PCA Projection of Concept Space"
    )
    plt.show()
    print(f"üíæ Saved to: {OUTPUT_DIR}/pca_projection.png")
    
    # t-SNE projection (on subset for efficiency): at most 500 points per
    # concept on average, sampled without replacement from a seeded generator.
    # NOTE(review): any randomness inside the visualizer's t-SNE itself is not
    # controlled from here.
    subset_size = min(500 * n_concepts, len(flat_concepts))
    subset_rng = np.random.default_rng(PROJECTION_SEED)
    indices = subset_rng.choice(len(flat_concepts), subset_size, replace=False)
    
    fig = visualizer.plot_2d_projection(
        flat_concepts[indices],
        labels=concept_labels[indices],
        method='tsne',
        title="t-SNE Projection of Concept Space"
    )
    plt.show()
    print(f"üíæ Saved to: {OUTPUT_DIR}/tsne_projection.png")
elif not MODEL_LOADED:
    print("‚ö†Ô∏è  Model not loaded.")
else:
    print("‚ö†Ô∏è  sklearn required for 2D projections.")

## 4. Summary Report

Generate a comprehensive summary of the analysis results for documentation and papers.

In [None]:
# Generate summary report
import json
from datetime import datetime


def _to_jsonable(value):
    """Convert NumPy / PyTorch values to plain Python types that json.dump accepts.

    0-dim tensors and NumPy scalars become Python scalars, other tensors and
    arrays become (nested) lists, and everything else is returned unchanged.
    """
    if isinstance(value, torch.Tensor):
        value = value.detach().cpu()
        return value.item() if value.ndim == 0 else value.tolist()
    if isinstance(value, np.generic):
        return value.item()
    if isinstance(value, np.ndarray):
        return value.tolist()
    return value


if MODEL_LOADED:
    report = {
        "timestamp": datetime.now().isoformat(),
        "model_path": MODEL_PATH,
        "model_type": MODEL_TYPE,
        "model_config": {
            "vocab_size": config.vocab_size,
            "hidden_size": config.hidden_size,
            "num_layers": config.num_hidden_layers,
            "num_concepts": config.concept_num,
            "max_sequence_length": config.max_sequence_length
        },
        # Metric values may be NumPy or torch scalars; normalise them so the
        # report can always be serialised (previously only floats were converted)
        "geometry_metrics": {name: _to_jsonable(value) for name, value in metrics.items()},
        "samples_analyzed": concepts_tensor.shape[0]
    }
    
    # Print summary
    print("=" * 70)
    print("                    CONCEPT ENCODER ANALYSIS SUMMARY")
    print("=" * 70)
    print(f"\nüìÖ Timestamp: {report['timestamp']}")
    print(f"üìÅ Model: {MODEL_PATH}")
    print(f"\nüîß Configuration:")
    for k, v in report['model_config'].items():
        print(f"   {k}: {v}")
    
    print(f"\nüìä Key Metrics:")
    print(f"   Effective Rank (normalized): {metrics['normalized_effective_rank']:.3f}")
    print(f"   Isotropy: {metrics['isotropy']:.4f}")
    print(f"   Uniformity: {metrics['uniformity']:.4f}")
    print(f"   Max Similarity: {metrics['max_similarity']:.4f}")
    
    # Health Assessment
    # NOTE(review): these cut-offs (0.3 / 0.7 / 0.5) differ from the ones used
    # in the detailed geometry report cell (0.2-0.5 / 0.5-0.9 / 0.1), so a
    # model can be warned about there and still be reported healthy here —
    # confirm whether the looser thresholds are intentional.
    print(f"\nüè• Health Assessment:")
    issues = []
    if metrics['normalized_effective_rank'] < 0.3:
        issues.append("Low effective rank - possible dimensional collapse")
    if metrics['max_similarity'] > 0.7:
        issues.append("High concept similarity - concepts may not be diverse")
    if metrics['uniformity'] > 0.5:
        issues.append("High uniformity loss - concepts clustered")
    
    if not issues:
        print("   ‚úÖ Model appears healthy!")
    else:
        for issue in issues:
            print(f"   ‚ö†Ô∏è  {issue}")
    
    # Save report (explicit encoding so the file does not depend on the platform default)
    report_path = os.path.join(OUTPUT_DIR, "analysis_report.json")
    with open(report_path, 'w', encoding='utf-8') as f:
        json.dump(report, f, indent=2)
    print(f"\nüíæ Report saved to: {report_path}")
    
    print("\n" + "=" * 70)
else:
    print("‚ö†Ô∏è  Model not loaded.")