# Full Motif Analysis: Motif Zoo & Cross-Task Comparison

This notebook loads the pipeline results from `data/results/` and presents:
1. **Motif Zoo** — summary statistics for each triad class across all 99 graphs
2. **Z-score heatmap** — mean Z-scores by task category × motif class
3. **Task similarity** — cosine similarity matrix and dendrogram
4. **Interesting graphs** — graphs with extreme motif enrichment patterns
5. **Neuronpedia-style plots** — attribution graphs with highlighted motif instances

**Prerequisite:** Run the pipeline first:
```bash
python -m src.pipeline --n-random 100
```

In [None]:
import json
import pickle
import sys
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

# Make the project's `src` package importable before the local imports below.
sys.path.insert(0, "..")

from src.graph_loader import load_attribution_graph
from src.motif_census import (
    TRIAD_LABELS, CONNECTED_TRIAD_INDICES,
    MOTIF_FFL, MOTIF_CHAIN, MOTIF_FAN_IN, MOTIF_FAN_OUT,
    find_motif_instances,
)
from src.visualization import (
    TASK_COLORS,
    plot_zscore_heatmap,
    plot_sp_heatmap,
    plot_task_dendrogram,
    plot_cosine_similarity_matrix,
    plot_grouped_bar,
    plot_top_motif,
)

# Render matplotlib figures inline in the notebook output.
%matplotlib inline
# Default resolution (dots per inch) applied to every figure created below.
plt.rcParams['figure.dpi'] = 120

## 1. Load Pipeline Results

In [None]:
results_dir = Path("../data/results")

# Fail early with an actionable message when the pipeline outputs are absent,
# rather than surfacing a bare FileNotFoundError from the first open() call.
required_files = (
    "analysis_summary.json",
    "task_profiles.pkl",
    "clustering.pkl",
    "per_graph_results.pkl",
)
missing_files = [fname for fname in required_files if not (results_dir / fname).is_file()]
if missing_files:
    raise FileNotFoundError(
        f"Missing pipeline outputs in {results_dir}: {', '.join(missing_files)}. "
        "Run the pipeline first: python -m src.pipeline --n-random 100"
    )

# Load analysis summary (JSON). The encoding is pinned so the read does not
# depend on the platform's default locale encoding.
with open(results_dir / "analysis_summary.json", encoding="utf-8") as f:
    summary = json.load(f)

# NOTE: pickle.load can execute arbitrary code while deserializing; only load
# result files that were produced by your own pipeline run.

# Load task profiles (pickle; needed for the visualization functions)
with open(results_dir / "task_profiles.pkl", "rb") as f:
    task_profiles = pickle.load(f)

# Load clustering
with open(results_dir / "clustering.pkl", "rb") as f:
    clustering = pickle.load(f)

# Load per-graph results
with open(results_dir / "per_graph_results.pkl", "rb") as f:
    per_graph = pickle.load(f)

# Quick sanity report: total graph count and the number of graphs per category.
print(f"Loaded results for {len(summary['graphs'])} graphs across {len(task_profiles)} categories")
for name, profile in sorted(task_profiles.items()):
    print(f"  {name}: {profile.n_graphs} graphs")

## 2. Motif Zoo — Summary Statistics

For each of the 15 connected triad classes, we show:
- Mean/median Z-score across all graphs
- What percentage of graphs show enrichment (Z > 2) or depletion (Z < -2)
- Range of Z-scores observed

In [None]:
# Per-motif summary records from the "motif_zoo" section of the summary.
zoo = summary["motif_zoo"]["motif_summary"]

# Column titles, a horizontal rule, then one aligned line per motif record.
table_header = (
    f"{'Motif':<8s}  {'Mean Z':>7s}  {'Med Z':>7s}  {'Min Z':>7s}  {'Max Z':>7s}  "
    f"{'Enriched':>8s}  {'Depleted':>8s}  {'N':>3s}"
)
print(table_header)
print("-" * 75)

for motif in zoo:
    z_columns = (
        f"{motif['label']:<8s}  {motif['mean_z']:>+7.2f}  {motif['median_z']:>+7.2f}  "
        f"{motif['min_z']:>+7.1f}  {motif['max_z']:>+7.1f}  "
    )
    rate_columns = f"{motif['pct_enriched']:>7.1f}%  {motif['pct_depleted']:>7.1f}%  "
    count_column = f"{motif['n_graphs']:>3d}"
    print(z_columns + rate_columns + count_column)

In [None]:
# Enrichment/depletion rates per motif, drawn as mirrored bars around zero.
labels, enriched, depleted = [], [], []
for record in zoo:
    labels.append(record['label'])
    enriched.append(record['pct_enriched'])
    # Depletion is negated so its bars extend below the zero line.
    depleted.append(-record['pct_depleted'])

x = np.arange(len(labels))
fig, ax = plt.subplots(figsize=(14, 5))

# Each series is (bar heights, bar colour, legend text), drawn in this order.
bar_series = [
    (enriched, '#d62728', 'Enriched (Z > 2)'),
    (depleted, '#1f77b4', 'Depleted (Z < -2)'),
]
for heights, colour, legend_text in bar_series:
    ax.bar(x, heights, color=colour, alpha=0.8, label=legend_text)
ax.axhline(0, color='black', linewidth=0.5)

# Axis cosmetics: one tick per motif and a symmetric percentage range.
ax.set_ylim(-100, 100)
ax.set_xticks(x)
ax.set_xticklabels(labels, rotation=45, ha='right', fontsize=9)
ax.set_ylabel('% of graphs', fontsize=11)
ax.set_title('Motif Enrichment/Depletion Rates Across All Attribution Graphs', fontsize=13, fontweight='bold')
ax.legend(fontsize=10)

plt.tight_layout()
plt.show()

## 3. Z-Score Distribution per Motif

Violin plots showing the distribution of Z-scores for each motif class across all graphs.

In [None]:
# Collect all finite Z-scores per motif class across every graph.
motif_labels = [TRIAD_LABELS[i] for i in CONNECTED_TRIAD_INDICES]
z_data = {label: [] for label in motif_labels}

for g in summary['graphs']:
    for label in motif_labels:
        z = g['z_scores'].get(label, 0)
        # A JSON null loads as None, which np.isfinite rejects with a
        # TypeError, so exclude it explicitly alongside NaN and +/-inf.
        if z is not None and np.isfinite(z):
            z_data[label].append(z)

# violinplot raises on an empty dataset, so only plot the motifs that have at
# least one finite Z-score and report any that were left out.
plotted_labels = [label for label in motif_labels if z_data[label]]
skipped_labels = [label for label in motif_labels if not z_data[label]]
if skipped_labels:
    print(f"Skipping motifs with no finite Z-scores: {', '.join(map(str, skipped_labels))}")

if plotted_labels:
    fig, ax = plt.subplots(figsize=(14, 6))
    positions = list(range(len(plotted_labels)))
    parts = ax.violinplot(
        [z_data[label] for label in plotted_labels],
        positions=positions,
        showmeans=True,
        showmedians=True,
    )

    # Color the violin bodies
    for pc in parts['bodies']:
        pc.set_facecolor('#6baed6')
        pc.set_alpha(0.7)

    # Reference lines at the +/-2 significance cut-offs and at zero.
    ax.axhline(2.0, color='red', linestyle='--', alpha=0.5, label='Z = +/-2')
    ax.axhline(-2.0, color='red', linestyle='--', alpha=0.5)
    ax.axhline(0, color='black', linewidth=0.5)
    ax.set_xticks(positions)
    ax.set_xticklabels(plotted_labels, rotation=45, ha='right', fontsize=9)
    ax.set_ylabel('Z-score', fontsize=11)
    ax.set_title('Z-Score Distribution per Motif Class (All Graphs)', fontsize=13, fontweight='bold')
    ax.legend(fontsize=9)
    plt.tight_layout()
    plt.show()
else:
    print("No finite Z-scores available; nothing to plot.")

## 4. Cross-Task Z-Score Heatmap

The main result figure: mean Z-scores for each task category × motif class.
Red = enriched, blue = depleted.

In [None]:
# Draw the Z-score heatmap from the loaded task profiles.
zscore_heatmap_options = {
    'title': 'Mean Motif Z-Scores by Task Category',
    'figsize': (16, 7),
}
fig = plot_zscore_heatmap(task_profiles, **zscore_heatmap_options)
plt.show()

In [None]:
# Draw the significance-profile heatmap from the same task profiles.
sp_heatmap_options = {
    'title': 'Mean Significance Profiles by Task Category',
    'figsize': (16, 7),
}
fig = plot_sp_heatmap(task_profiles, **sp_heatmap_options)
plt.show()

In [None]:
# Draw the grouped-bar view of the task profiles.
grouped_bar_options = {
    'title': 'Motif Z-Scores by Task Category (Grouped Bar)',
    'figsize': (18, 7),
}
fig = plot_grouped_bar(task_profiles, **grouped_bar_options)
plt.show()

## 5. Task Similarity & Clustering

In [None]:
# Unpack the stored similarity matrix together with its task names.
sim_data = summary['similarity_matrix']
task_names = sim_data['task_names']
sim_matrix = np.array(sim_data['matrix'])

fig = plot_cosine_similarity_matrix(
    sim_matrix,
    task_names,
    figsize=(9, 8),
    title='Task Category Cosine Similarity (SP Vectors)',
)
plt.show()

In [None]:
linkage_matrix, cluster_names = clustering['linkage'], clustering['names']

# An empty linkage leaves nothing to draw: report that and skip the plot.
if len(linkage_matrix) == 0:
    print('Not enough categories for clustering.')
else:
    fig = plot_task_dendrogram(
        linkage_matrix,
        cluster_names,
        figsize=(12, 6),
        title='Task Category Dendrogram (Cosine Distance on SP Vectors)',
    )
    plt.show()

## 6. Kruskal-Wallis Test Results

Which motif classes differ significantly across task categories?

In [None]:
kw = summary['kruskal_wallis']


def _fmt_or_na(value, spec):
    """Format ``value`` with ``spec`` when it is finite; otherwise return 'N/A'."""
    return format(value, spec) if np.isfinite(value) else 'N/A'


print(f"{'Motif':<8s}  {'H-stat':>8s}  {'p-value':>10s}  {'Significant':>11s}")
print('-' * 42)

# Only rows whose motif index is one of the connected triads are reported.
connected_rows = (row for row in kw if row['motif_index'] in CONNECTED_TRIAD_INDICES)
for row in connected_rows:
    if row['significant']:
        sig = 'YES ***'
    else:
        sig = 'no'
    h = _fmt_or_na(row['H_statistic'], '.2f')
    p = _fmt_or_na(row['p_value'], '.4f')
    print(f"{row['label']:<8s}  {h:>8s}  {p:>10s}  {sig:>11s}")

## 7. Most Interesting Graphs

Graphs ranked by maximum absolute Z-score — these have the most extreme motif enrichment patterns.

In [None]:
interesting = summary['motif_zoo']['interesting_graphs']

# Column titles, a rule, then at most the first 20 entries numbered from 1.
column_titles = (
    f"{'Rank':>4s}  {'Category':<18s}  {'Name':<35s}  {'Max |Z|':>7s}  {'# Sig':>5s}  {'Nodes':>5s}  {'Edges':>5s}"
)
print(column_titles)
print('-' * 95)

for rank, entry in enumerate(interesting[:20], start=1):
    identity_columns = (
        f"{rank:>4d}  {entry['category']:<18s}  {entry['name']:<35s}  {entry['max_abs_z']:>7.1f}  "
    )
    size_columns = (
        f"{entry['n_significant_motifs']:>5d}  {entry['n_nodes']:>5d}  {entry['n_edges']:>5d}"
    )
    print(identity_columns + size_columns)

## 8. Neuronpedia-Style Visualization of Top Interesting Graphs

We pick the top-ranked interesting graphs and show each pruned attribution graph
with its highest-weight feedforward loop (FFL) motif highlighted, falling back to
the highest-weight chain motif when the pruned graph contains no FFL.

In [None]:
# Find the per-graph result entries for the top interesting graphs
def find_graph_path(per_graph, category, name):
    """Look up the file path of the graph called ``name`` within ``category``.

    Paths stored in the pipeline results are relative to the project root,
    so '..' is prepended because the notebook runs from notebooks/.

    Returns the path of the first entry whose 'name' matches, or ``None``
    when the category is absent or has no entry with that name.
    """
    matching_paths = (
        Path("..") / entry['path']
        for entry in per_graph.get(category, [])
        if entry['name'] == name
    )
    return next(matching_paths, None)


# Show attribution graphs for the top 3 interesting graphs.
# The pruning threshold is defined once so the loader call and the plot title
# cannot drift apart.
prune_threshold = 1.0
n_show = min(3, len(interesting))
for rank, info in enumerate(interesting[:n_show], start=1):
    path = find_graph_path(per_graph, info['category'], info['name'])
    if path is None:
        print(f"Could not find path for {info['name']}")
        continue

    print(f"\n{'='*70}")
    print(f"Graph #{rank}: {info['category']}/{info['name']}")
    print(f"Max |Z|={info['max_abs_z']:.1f}, "
          f"{info['n_significant_motifs']} significant motifs")
    print(f"Prompt: {info['prompt']}")
    print(f"{'='*70}")

    # Load with pruning for clearer visualization
    g = load_attribution_graph(path, weight_threshold=prune_threshold)
    print(f"Pruned graph: {g.vcount()} nodes, {g.ecount()} edges")

    # Try FFL first, then chain; the for/else reports when neither is present.
    for motif_name, motif_id in [('FFL', MOTIF_FFL), ('Chain', MOTIF_CHAIN)]:
        instances = find_motif_instances(g, motif_id, sort_by='weight')
        if instances:
            print(f"Found {len(instances)} {motif_name} instances")
            # \u2014 is an em dash; the escape keeps the title intact even if
            # the notebook file is re-saved with a different encoding.
            fig, inst = plot_top_motif(
                g, motif_id, rank=0,
                title=(f"{info['category']}/{info['name']} \u2014 Top {motif_name} "
                       f"(pruned, threshold={prune_threshold})"),
                figsize=(18, 14),
            )
            print(f"Top {motif_name} weight: {inst.total_weight:.1f}")
            for node_idx, role in inst.node_roles.items():
                print(f"  {role:12s}: {g.vs[node_idx]['clerp']} (layer {g.vs[node_idx]['layer']})")
            plt.show()
            break
    else:
        print("No FFL or Chain instances found in pruned graph.")

## 9. Per-Category Motif Instance Counts

How many motif instances of each type appear per category (averaged over graphs)?

In [None]:
# Mean motif-instance counts per category, one table row per non-empty category.
instance_types = ['FFL', 'chain', 'fan_in', 'fan_out', 'cycle']

print(f"{'Category':<18s}  {'N':>3s}  {'FFL':>7s}  {'Chain':>7s}  {'Fan-in':>7s}  {'Fan-out':>8s}  {'Cycle':>6s}")
print('-' * 65)

for category in sorted(per_graph):
    results_list = per_graph[category]
    if results_list:
        n = len(results_list)
        # Average each instance type over all graphs in this category.
        means = {
            itype: np.mean([r['instance_counts'][itype] for r in results_list])
            for itype in instance_types
        }
        left_columns = (
            f"{category:<18s}  {n:>3d}  {means['FFL']:>7.0f}  {means['chain']:>7.0f}  "
        )
        right_columns = (
            f"{means['fan_in']:>7.0f}  {means['fan_out']:>8.0f}  {means['cycle']:>6.0f}"
        )
        print(left_columns + right_columns)

## 10. Graph Size Distribution

In [None]:
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Per-graph node counts, edge counts and categories, kept as parallel lists.
node_counts = [g['n_nodes'] for g in summary['graphs']]
edge_counts = [g['n_edges'] for g in summary['graphs']]
categories = [g['category'] for g in summary['graphs']]
unique_cats = sorted(set(categories))

# TASK_COLORS comes from the notebook's import cell; categories without an
# entry fall back to a neutral grey.
for cat in unique_cats:
    cat_nodes = [n for n, c in zip(node_counts, categories) if c == cat]
    cat_edges = [e for e, c in zip(edge_counts, categories) if c == cat]
    color = TASK_COLORS.get(cat, '#7f7f7f')
    # One horizontal strip of points per category on each panel.
    axes[0].scatter(cat_nodes, [cat]*len(cat_nodes), c=color, s=40, alpha=0.7, label=cat)
    axes[1].scatter(cat_edges, [cat]*len(cat_edges), c=color, s=40, alpha=0.7, label=cat)

axes[0].set_xlabel('Node Count', fontsize=11)
axes[0].set_title('Node Counts by Category', fontsize=12, fontweight='bold')
axes[1].set_xlabel('Edge Count', fontsize=11)
axes[1].set_title('Edge Counts by Category', fontsize=12, fontweight='bold')

for ax in axes:
    ax.grid(axis='x', alpha=0.3)

plt.tight_layout()
plt.show()

# min()/max() raise on an empty sequence, so only summarise when graphs exist.
if node_counts:
    print(f"Node count range: {min(node_counts)} - {max(node_counts)} (mean {np.mean(node_counts):.0f})")
    print(f"Edge count range: {min(edge_counts)} - {max(edge_counts)} (mean {np.mean(edge_counts):.0f})")
else:
    print("No graphs in summary; size ranges unavailable.")

## 11. Pairwise Task Comparisons

Which task pairs have the most/least similar motif profiles?

In [None]:
pairwise = summary['pairwise_comparisons']
# Highest cosine similarity first.
pairwise_sorted = sorted(
    pairwise,
    key=lambda pair: pair['cosine_similarity'],
    reverse=True,
)

# Each section is a heading plus the slice of sorted pairs listed beneath it.
report_sections = [
    ("Most similar pairs:", pairwise_sorted[:5]),
    ("\nLeast similar pairs:", pairwise_sorted[-5:]),
]
for heading, pairs in report_sections:
    print(heading)
    for p in pairs:
        pair_text = f"  {p['task_a']} vs {p['task_b']}: cos_sim={p['cosine_similarity']:.3f}, "
        motif_text = f"{p['n_significant_motifs']} sig. motifs"
        print(pair_text + motif_text)