# Cross-Scale Motif Analysis: Gemma 3 Family

This notebook analyzes how motif enrichment patterns change across model scales
using the Gemma 3 family (270M, 1B, 4B, 12B, 27B).

**Key question:** Does the structural grammar of computation change as models get bigger?

Sections:
1. Setup & load model registry
2. Per-model summaries (graph counts, sizes)
3. **Scale heatmap** — main result figure
4. **FFL backbone analysis** — 030T Z-score across scales
5. **Scaling curves** — key motifs vs log(params)
6. **Per-task scaling** — arithmetic, safety, reasoning separately
7. **Phase transition detection**
8. **Model similarity** — cosine matrix + dendrogram
9. **SP overlay** — all models' profiles overlaid
10. Comparison with Haiku/Gemma-2-2B/Qwen3-4B baseline

**Prerequisite:** Run the scale pipeline first:
```bash
python -m src.pipeline --scale-mode --data-dir data/raw --results-dir data/results --n-random 100
```

In [None]:
import sys
sys.path.insert(0, "..")

import json
import pickle
from pathlib import Path

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

from src.models import (
    get_model, gemma3_scaling_curve, GEMMA_3_MODELS, ALL_MODELS,
)
from src.motif_census import (
    TRIAD_LABELS, CONNECTED_TRIAD_INDICES,
    MOTIF_FFL, MOTIF_CHAIN, MOTIF_FAN_IN, MOTIF_FAN_OUT,
)
from src.scale_comparison import (
    ModelProfile, ScaleComparison, ScaleTrend,
    build_model_profile, run_scale_comparison,
    compute_scale_trends, detect_phase_transitions,
    pairwise_model_similarity, check_ffl_backbone,
    compare_task_across_scales,
)
from src.visualization import (
    plot_scale_trend, plot_scale_heatmap, plot_sp_overlay,
    plot_per_task_scaling, plot_scale_dendrogram,
    plot_cosine_similarity_matrix,
    MODEL_SCALE_COLORS,
)

# Render matplotlib figures inline in the notebook output.
%matplotlib inline
# Default DPI for every figure created in this notebook.
plt.rcParams['figure.dpi'] = 120

## 1. Model Registry

The Gemma 3 family spans 2 orders of magnitude in parameter count.

In [None]:
# Fetch the Gemma 3 specs from the project registry and print one table row per
# spec, in the order the registry returns them.
scaling_curve = gemma3_scaling_curve()

header_cells = [
    f"{'Model':<22s}",
    f"{'Params':>8s}",
    f"{'Layers':>6s}",
    f"{'Hidden':>6s}",
    f"{'log10(N)':>8s}",
    f"{'Transcoders':>11s}",
]
print("  ".join(header_cells))
print('-' * 75)

for spec in scaling_curve:
    # n_params below 1000 is shown with an "M" suffix; otherwise it is
    # integer-divided by 1000 and shown with a "B" suffix.
    if spec.n_params < 1000:
        size_label = f"{spec.n_params}M"
    else:
        size_label = f"{spec.n_params // 1000}B"

    # Flag whether any of the spec's transcoders is marked as a CLT.
    has_clt = any(t.is_clt for t in spec.transcoders)
    transcoder_kind = '+ CLT' if has_clt else 'PLT only'
    transcoder_label = f"{len(spec.transcoders)} ({transcoder_kind})"

    row_cells = [
        f"{spec.model_id:<22s}",
        f"{size_label:>8s}",
        f"{spec.n_layers:>6d}",
        f"{spec.hidden_dim:>6d}",
        f"{spec.log_params:>8.2f}",
        f"{transcoder_label:>11s}",
    ]
    print("  ".join(row_cells))

## 2. Load Scale Pipeline Results

In [None]:
# Load the artifacts written by the scale pipeline:
#   * scale_profiles.pkl  -> model_profiles (mapping of model_id -> profile object)
#   * scale_analysis.json -> scale_analysis (parsed JSON, or None when the file is absent)
# When the pickle is missing, model_profiles is set to {} so later cells can
# test it for truthiness and skip their work.
scale_dir = Path("../data/results/scale")

# Load model profiles
# NOTE: pickle.load can execute arbitrary code while deserializing; only point
# this at files produced by your own pipeline run.
profiles_path = scale_dir / "scale_profiles.pkl"
if profiles_path.exists():
    with open(profiles_path, "rb") as f:
        model_profiles = pickle.load(f)
    print(f"Loaded profiles for {len(model_profiles)} model(s)")
    # List models smallest -> largest (the ordering used by the plotting cells)
    # rather than alphabetically by id; the id only breaks ties.
    for model_id, mp in sorted(
        model_profiles.items(),
        key=lambda kv: (kv[1].model_spec.n_params, kv[0]),
    ):
        p = mp.model_spec.n_params
        p_str = f"{p}M" if p < 1000 else f"{p // 1000}B"
        print(f"  {model_id} ({p_str}): {mp.n_total_graphs} graphs, {len(mp.task_profiles)} tasks")
        for task, tp in sorted(mp.task_profiles.items()):
            print(f"    {task}: {tp.n_graphs} graphs")
else:
    print(f"No scale profiles found at {profiles_path}")
    print("Run: python -m src.pipeline --scale-mode --n-random 100")
    model_profiles = {}

# Load scale analysis JSON
analysis_path = scale_dir / "scale_analysis.json"
if analysis_path.exists():
    # Explicit encoding so the read does not depend on the platform default.
    with open(analysis_path, encoding="utf-8") as f:
        scale_analysis = json.load(f)
    print("\nLoaded scale analysis summary")
else:
    scale_analysis = None

## 3. Scale Heatmap — Main Result

Mean Z-score for every (model size × motif class) pair. This is the primary visualization.

In [None]:
if not model_profiles:
    print("No model profiles loaded.")
else:
    # Figure height grows with the number of models, with a floor of 4.
    heatmap_height = max(4, len(model_profiles) * 1.2)
    fig = plot_scale_heatmap(
        model_profiles,
        metric="z_score",
        title="Motif Z-Score Profile Across Model Scales",
        figsize=(16, heatmap_height),
    )
    plt.show()

## 4. FFL Backbone Analysis

The central question: is the feedforward loop (030T) universally enriched across all model scales?

In [None]:
# FFL (030T) backbone check: print a per-model enrichment table, then plot the
# mean FFL Z-score (with std error bars) against log parameter count.
if model_profiles:
    universal, ffl_details = check_ffl_backbone(model_profiles)

    # Order models smallest -> largest once; the table and the plot both follow it.
    sorted_models = sorted(model_profiles.items(), key=lambda kv: kv[1].model_spec.n_params)
    labels = [mid for mid, _ in sorted_models]
    size_rank = {mid: rank for rank, mid in enumerate(labels)}

    print(f"FFL universally enriched: {universal}")
    print()
    print(f"{'Model':<22s}  {'Mean Z':>7s}  {'Std Z':>7s}  {'Enriched':>8s}  {'N':>4s}")
    print('-' * 55)
    # Rows follow model size so they line up with the plot below. Any id in
    # ffl_details that has no loaded profile is listed after the known models.
    table_order = sorted(
        ffl_details,
        key=lambda mid: (size_rank.get(mid, len(size_rank)), str(mid)),
    )
    for model_id in table_order:
        d = ffl_details[model_id]
        if d.get('mean_z') is not None:
            print(f"{model_id:<22s}  {d['mean_z']:>+7.2f}  {d['std_z']:>7.2f}  "
                  f"{d['pct_enriched']:>7.1f}%  {d['n_total']:>4d}")
        else:
            print(f"{model_id:<22s}  {'N/A':>7s}")

    # Plot FFL Z-score across scales
    x = [mp.model_spec.log_params for _, mp in sorted_models]
    y = [float(mp.overall_mean_z[MOTIF_FFL]) for _, mp in sorted_models]
    y_std = [float(mp.overall_std_z[MOTIF_FFL]) for _, mp in sorted_models]

    fig, ax = plt.subplots(figsize=(10, 5))
    colors = [MODEL_SCALE_COLORS.get(mid, 'gray') for mid in labels]
    ax.errorbar(x, y, yerr=y_std, fmt='o-', color='#d62728', linewidth=2.5,
                markersize=10, capsize=5, capthick=2, label='030T (FFL)')
    # Overdraw each point in its model-specific color (gray when the model has
    # no entry in MODEL_SCALE_COLORS).
    for xi, yi, c in zip(x, y, colors):
        ax.scatter([xi], [yi], c=c, s=100, zorder=5, edgecolors='black')

    # Reference lines: Z = 2.0 threshold and the zero baseline.
    ax.axhline(y=2.0, color='red', linestyle='--', alpha=0.3, label='Z = 2.0')
    ax.axhline(y=0, color='black', linewidth=0.5)

    tick_labels = [f"{mp.model_spec.n_params}M" if mp.model_spec.n_params < 1000
                   else f"{mp.model_spec.n_params // 1000}B" for _, mp in sorted_models]
    ax.set_xticks(x)
    ax.set_xticklabels(tick_labels, fontsize=11)
    ax.set_xlabel('Model Size', fontsize=12)
    ax.set_ylabel('Mean Z-score', fontsize=12)
    ax.set_title('FFL (030T) Enrichment Across Model Scales', fontsize=14, fontweight='bold')
    ax.legend(fontsize=10)
    plt.tight_layout()
    plt.show()
else:
    print("No model profiles loaded.")

## 5. Scaling Curves — Key Motifs vs log(params)

In [None]:
if not model_profiles:
    print("No model profiles loaded.")
else:
    trends = compute_scale_trends(model_profiles, metric="z_score")

    # Four named motifs plus two raw triad indices.
    # NOTE(review): 5 and 11 are assumed to be the indices of 111U and 030C —
    # confirm against TRIAD_LABELS.
    key_motifs = [MOTIF_FFL, MOTIF_CHAIN, MOTIF_FAN_IN, MOTIF_FAN_OUT, 5, 11]
    fig = plot_scale_trend(
        trends,
        motif_indices=key_motifs,
        title="Key Motif Z-Scores Across Model Scales",
        figsize=(14, 7),
    )
    plt.show()

    # Trend table: steepest slope (by magnitude) first, connected triads only.
    print(f"\n{'Motif':<8s}  {'Direction':>10s}  {'Slope':>8s}  {'R²':>6s}  {'p-val':>8s}  {'ρ':>6s}  {'ρ p-val':>8s}")
    print('-' * 65)
    by_steepness = sorted(trends, key=lambda trend: abs(trend.slope), reverse=True)
    for trend in by_steepness:
        if trend.motif_index not in CONNECTED_TRIAD_INDICES:
            continue
        # Significant trends get a trailing " *" marker.
        marker = " *" if trend.is_significant else ""
        fit_part = (
            f"{trend.motif_label:<8s}  {trend.trend_direction:>10s}  "
            f"{trend.slope:>+8.3f}  {trend.r_squared:>6.3f}  "
        )
        rank_part = (
            f"{trend.p_value:>8.4f}  {trend.spearman_rho:>+6.3f}  "
            f"{trend.spearman_p:>8.4f}{marker}"
        )
        print(fit_part + rank_part)

## 6. Per-Task Scaling

How do scaling patterns differ by task category?

In [None]:
# Plot per-task scaling curves for up to four tasks that every loaded model has.
if model_profiles:
    # Union of every task seen in any model (used only to report what is skipped).
    all_tasks = set()
    for mp in model_profiles.values():
        all_tasks.update(mp.task_profiles.keys())

    # Find common tasks: keep only tasks present in EVERY model. The original
    # code used the union here, contradicting the "common" name and comment.
    common_tasks = sorted(
        task for task in all_tasks
        if all(task in mp.task_profiles for mp in model_profiles.values())
    )

    skipped_tasks = sorted(all_tasks.difference(common_tasks))
    if skipped_tasks:
        print("Skipping tasks not present for every model: "
              + ", ".join(str(task) for task in skipped_tasks))
    if not common_tasks:
        print("No task is shared by all loaded models.")

    for task_name in common_tasks[:4]:  # Show first 4 tasks
        fig = plot_per_task_scaling(
            model_profiles,
            task_name,
            motif_indices=[MOTIF_FFL, MOTIF_CHAIN, MOTIF_FAN_IN, MOTIF_FAN_OUT],
            figsize=(12, 5),
        )
        plt.show()
else:
    print("No model profiles loaded.")

## 7. Phase Transition Detection

Are there abrupt jumps in motif enrichment between adjacent model sizes?

In [None]:
if not model_profiles:
    print("No model profiles loaded.")
else:
    transitions = detect_phase_transitions(model_profiles, min_effect_size=1.0)

    if not transitions:
        print("No phase transitions detected (all changes gradual).")
    else:
        print(f"Found {len(transitions)} phase transition(s) (Cohen's d >= 1.0):\n")
        print(f"{'Motif':<8s}  {'At size':>10s}  {'Before':>8s}  {'After':>8s}  {'Cohen d':>8s}")
        print('-' * 50)

        # Show at most the 20 transitions with the largest effect size.
        strongest = sorted(transitions, key=lambda tr: tr.effect_size, reverse=True)[:20]
        for pt in strongest:
            # Same size formatting as elsewhere: "<n>M" below 1000, else "<n // 1000>B".
            if pt.transition_point < 1000:
                size_label = f"{pt.transition_point}M"
            else:
                size_label = f"{pt.transition_point // 1000}B"
            print(f"{pt.motif_label:<8s}  {size_label:>10s}  {pt.before_mean:>+8.2f}  {pt.after_mean:>+8.2f}  {pt.effect_size:>8.2f}")

## 8. Model Similarity — Cosine Matrix & Dendrogram

In [None]:
if not model_profiles or len(model_profiles) < 2:
    print("Need at least 2 model profiles for similarity analysis.")
else:
    # Pairwise similarity matrix plus the model names that label its axes.
    sim_matrix, model_names = pairwise_model_similarity(model_profiles)

    fig = plot_cosine_similarity_matrix(
        sim_matrix,
        model_names,
        title="Model Cosine Similarity (SP Vectors)",
        figsize=(8, 7),
    )
    plt.show()

    # The full comparison supplies the linkage matrix for the dendrogram;
    # the dendrogram is skipped when that matrix is empty.
    result = run_scale_comparison(model_profiles)
    has_linkage = result.linkage_matrix.size > 0
    if has_linkage:
        fig = plot_scale_dendrogram(
            result.linkage_matrix,
            result.model_names,
            title="Model Similarity Dendrogram (Cosine Distance on SP Vectors)",
        )
        plt.show()

## 9. SP Overlay — All Models' Profiles

In [None]:
if not model_profiles:
    print("No model profiles loaded.")
else:
    # Overlay figure built by plot_sp_overlay from all loaded model profiles.
    overlay_options = {
        "title": "Significance Profiles Across Gemma 3 Scales",
        "figsize": (18, 7),
    }
    fig = plot_sp_overlay(model_profiles, **overlay_options)
    plt.show()

## 10. Comparison with Haiku / Gemma-2-2B / Qwen3-4B Baseline

If results from the original pipeline run (Haiku / Gemma-2-2B / Qwen3-4B) exist,
compare each Gemma 3 model's mean Z-score per motif against that baseline.

In [None]:
# Compare the original pipeline's task profiles with each Gemma 3 model:
# print a per-task summary of the original run, then a table of mean Z-scores
# per connected triad (baseline column first, Gemma 3 models smallest -> largest).
# Load original pipeline task profiles for comparison
original_profiles_path = Path("../data/results/task_profiles.pkl")
if original_profiles_path.exists() and model_profiles:
    # NOTE: pickle.load can execute arbitrary code while deserializing; only
    # load files produced by your own pipeline run.
    with open(original_profiles_path, "rb") as f:
        original_profiles = pickle.load(f)

    # Compare overall mean Z-scores
    # Report the actual number of graphs instead of a hard-coded count.
    n_original_graphs = sum(tp.n_graphs for tp in original_profiles.values())
    print(f"Original Haiku analysis ({n_original_graphs} graphs):")
    for task, tp in sorted(original_profiles.items()):
        print(f"  {task}: {tp.n_graphs} graphs, mean |Z| = {np.mean(np.abs(tp.mean_z)):.2f}")

    # Build aggregate profile for Haiku: pool the Z-score vectors of every task.
    all_z = []
    for tp in original_profiles.values():
        all_z.extend(tp.z_score_vectors)

    if all_z:
        haiku_mean_z = np.array(all_z).mean(axis=0)

        # Sort once, smallest -> largest; reused for the header and every row.
        models_by_size = sorted(model_profiles.items(), key=lambda kv: kv[1].model_spec.n_params)

        # Compare with each Gemma 3 model
        header = f"\n{'Motif':<8s}  {'Haiku':>8s}"
        for _, mp in models_by_size:
            p = mp.model_spec.n_params
            label = f"{p}M" if p < 1000 else f"{p // 1000}B"
            header += f"  {label:>8s}"
        print(header)
        print('-' * (18 + 10 * len(model_profiles)))

        for idx in CONNECTED_TRIAD_INDICES:
            label = TRIAD_LABELS[idx]
            # Guard the baseline column the same way as the Gemma 3 columns:
            # a vector too short for this index prints N/A.
            if len(haiku_mean_z) > idx:
                cells = [f"{haiku_mean_z[idx]:>+8.2f}"]
            else:
                cells = [f"{'N/A':>8s}"]
            for _, mp in models_by_size:
                if len(mp.overall_mean_z) > idx:
                    cells.append(f"{mp.overall_mean_z[idx]:>+8.2f}")
                else:
                    cells.append(f"{'N/A':>8s}")
            print(f"{label:<8s}  " + "  ".join(cells))
    else:
        # Previously this case printed nothing after the summary.
        print("Original profiles contain no Z-score vectors; skipping comparison table.")
else:
    if not model_profiles:
        print("No model profiles loaded.")
    else:
        print(f"Original profiles not found at {original_profiles_path}")