# Analogy Analysis with Dirichlet Energy

This notebook analyzes how a Transformer processes analogical reasoning by computing the Dirichlet energy of its embeddings across layers.

**Hypothesis**: As the model processes the analogy, the embeddings of functor-related entities (`<e1>` and `<e6>`) become more similar in deeper layers, and this convergence correlates with a correct prediction.

## 1. Setup

In [1]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import json
from pathlib import Path
import os

# Hugging Face authentication (needed for gated models such as Gemma).
# Paste a real token below, or export HF_TOKEN before starting the notebook.
# The placeholder itself is never exported: previously it overwrote any token
# already present in the environment and was then sent as a (bogus) credential.
_HF_TOKEN_PLACEHOLDER = "YOUR_HF_TOKEN"
HF_TOKEN = "YOUR_HF_TOKEN"  # replace with your token, or leave as-is to use the environment
if HF_TOKEN != _HF_TOKEN_PLACEHOLDER:
    os.environ["HF_TOKEN"] = HF_TOKEN

# Prefer CUDA, then Apple-Silicon MPS, then fall back to CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
print(f"Using device: {DEVICE}")

Using device: cuda


In [None]:
# Visualization style configuration (large fonts for readability)
def setup_plot_style():
    """Reset matplotlib to the notebook's plot style: ggplot look with large fonts.

    The ggplot style sheet is applied *before* the custom rcParams.
    ``plt.style.use`` overwrites every rcParam the sheet defines (ggplot sets
    ``font.size``, ``axes.labelsize``, ``axes.titlesize``, ``axes.linewidth``,
    ``axes.grid`` and ``grid.linestyle``). Applying it last, as before, silently
    discarded the large-font settings below.
    """
    plt.close()  # close the current figure so the next plot starts fresh
    plt.style.use("ggplot")
    # Custom overrides go last so they win over the style sheet.
    plt.rcParams.update({
        'font.family': 'sans-serif',
        'font.sans-serif': ['Inter', 'Helvetica', 'Arial', 'DejaVu Sans'],
        'font.size': 28,
        'axes.labelsize': 32,
        'axes.titlesize': 28,
        'axes.linewidth': 2.0,
        'axes.spines.top': False,
        'axes.spines.right': False,
        'xtick.labelsize': 26,
        'ytick.labelsize': 26,
        'xtick.major.width': 1.5,
        'ytick.major.width': 1.5,
        'xtick.major.size': 8,
        'ytick.major.size': 8,
        'legend.fontsize': 24,
        'legend.frameon': False,
        'axes.grid': True,
        'grid.alpha': 0.2,
        'grid.linestyle': ':',
        'figure.dpi': 150,
        'savefig.dpi': 300,
        'lines.linewidth': 3,
        'lines.markersize': 10,
    })


setup_plot_style()

## 2. Model Loading

In [None]:
from embeddings import EmbeddingExtractor, load_sample, print_tokenization

# Hugging Face model id and the dtype the model is loaded in.
MODEL_NAME = "google/gemma-2-2b"
DTYPE = torch.float32

# Build the extractor on the device chosen in the setup cell, then load the weights.
extractor = EmbeddingExtractor(model_name=MODEL_NAME, device=DEVICE, dtype=DTYPE)
extractor.load_model()

## 3. Sample Loading and Embedding Extraction

In [None]:
# Load one analogy sample; later cells read its prompt, target token,
# graph edges and node positions.
sample = load_sample("samples/sample_1.json")

prompt = sample["prompt"]
target_token = sample["target_token"]
sample_name = sample.get("name", "sample")  # falls back to "sample" when unnamed

# Analysis configuration
APPLY_RMSNORM = True       # forwarded as apply_ln= to extract_all_layer_embeddings
PROXY_METHOD = "position"  # "position" or "mean": picks the Dirichlet-energy helper in Section 4

print(f"Sample: {sample_name}")
print(f"Prompt: {prompt!r}")
print(f"Target token: '{target_token}'")
print(f"Apply RMSNorm: {APPLY_RMSNORM}")
print(f"Proxy method: {PROXY_METHOD}")

In [None]:
# Extract hidden-state embeddings for every layer; apply_ln is driven by
# APPLY_RMSNORM (presumably toggles the final RMSNorm - confirm in embeddings.py).
embeddings, token_info = extractor.extract_all_layer_embeddings(
    prompt,
    apply_ln=APPLY_RMSNORM,
)
print(f"Embeddings shape: {embeddings.shape}")
print_tokenization(token_info)

## 4. Dirichlet Energy Computation

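$$E(G) = \sum_{(i,j) \in \text{edges}} \|x_i - x_j\|^2$$

where $x_i$ is the embedding of graph node $i$ at a given layer; the energy is computed separately for every layer.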

In [None]:
from dirichlet_energy import (
    compute_dirichlet_energy_with_positions,
    compute_dirichlet_energy_with_mean_positions,
    normalize_dirichlet_energy,
    print_dirichlet_results,
)

# Map each supported PROXY_METHOD to the helper that computes the per-layer
# Dirichlet energy. Previously any value other than "mean" (including a typo)
# silently fell through to the position-based helper, so unknown values are
# now rejected explicitly.
DIRICHLET_FNS = {
    "position": compute_dirichlet_energy_with_positions,
    "mean": compute_dirichlet_energy_with_mean_positions,
}
if PROXY_METHOD not in DIRICHLET_FNS:
    raise ValueError(
        f"Unknown PROXY_METHOD {PROXY_METHOD!r}; expected one of {sorted(DIRICHLET_FNS)}"
    )

# Graph edges as (source, target) node pairs from the sample definition.
edges = [(e["source"], e["target"]) for e in sample["graph"]["edges"]]
print(f"Graph edges: {edges}")

dirichlet_results = DIRICHLET_FNS[PROXY_METHOD](
    extractor, embeddings, token_info, edges, sample["node_positions"]
)

normalized_results = normalize_dirichlet_energy(dirichlet_results, embeddings)
print_dirichlet_results(dirichlet_results, normalized_results)

## 5. Functor Similarity & Logit Lens

In [None]:
from functor_similarity import compute_functor_similarity_all_layers, print_functor_similarity_results
from logit_lens import compute_logit_lens_all_layers, compute_correlation_with_metrics, print_logit_lens_results

# --- Functor similarity ---
# Functors are the tokens whose string contains "~"; comparing their
# similarity only makes sense when there are at least two of them.
functor_token = "~"
n_functors = len([info for info in token_info if functor_token in info.token_str])
print(f"Found {n_functors} functors")

functor_results = None
if n_functors >= 2:
    functor_results = compute_functor_similarity_all_layers(
        extractor, embeddings, token_info, functor_token, n_functors
    )
    print_functor_similarity_results(functor_results)

# --- Logit lens ---
# Per-layer prediction at the last prompt position (position=-1), keeping the top 10 tokens.
logit_lens_results = compute_logit_lens_all_layers(
    extractor.model,
    embeddings,
    target_token,
    position=-1,
    top_k=10,
)
print_logit_lens_results(logit_lens_results, target_token)

## 6. Correlation Analysis

In [None]:
# Correlation computation: Pearson correlation between the per-layer
# logit-lens P(target) and the per-layer Dirichlet energy.
target_probs = np.array([layer_res.target_token_prob for layer_res in logit_lens_results])
dirichlet_energies = np.array([layer_res.energy for layer_res in dirichlet_results])
corr = np.corrcoef(target_probs, dirichlet_energies)[0][1]
print(f"Correlation (Dirichlet vs P(target)): {corr:.4f}")

## 7. Visualization

In [None]:
setup_plot_style()
layers = list(range(len(dirichlet_results)))
sequence_str = f"Sequence: {prompt}[{target_token}]"

fig, ax1 = plt.subplots(figsize=(16, 8))
color1, color2 = '#4C72B0', '#C44E52'

# Left axis: raw Dirichlet energy per layer.
ax1.set_xlabel('Layer')
ax1.set_ylabel('Dirichlet Energy', color=color1)
line1 = ax1.plot(layers, dirichlet_energies, color=color1, marker='o', linewidth=4, markersize=12, label='Dirichlet Energy')
ax1.tick_params(axis='y', labelcolor=color1)

# Right axis (shared x): logit-lens probability of the target token per layer.
ax2 = ax1.twinx()
ax2.set_ylabel('P(target)', color=color2)
line2 = ax2.plot(layers, target_probs, color=color2, marker='s', linewidth=4, markersize=12, label='P')
ax2.tick_params(axis='y', labelcolor=color2)
ax2.spines['right'].set_visible(True)

# One legend covering the lines of both axes.
ax1.legend(line1 + line2, [l.get_label() for l in line1 + line2], loc='upper left')
fig.suptitle(f'Dirichlet Energy vs P(target) (Corr: {corr:.4f})', fontsize=36, y=0.98)
ax1.set_title(sequence_str, fontsize=20, color='gray', pad=15)
fig.tight_layout()
fig.subplots_adjust(top=0.85)

# savefig does not create missing directories, and results/ is otherwise only
# created by the save-results cell further down.
Path('results').mkdir(parents=True, exist_ok=True)
plt.savefig('results/correlation_plot.pdf', bbox_inches='tight')
plt.show()

In [None]:
# Normalized Dirichlet Energy vs P(target)
setup_plot_style()

# Recomputed here (same values as in the previous plot cell) so this cell
# does not depend on that cell having been run first.
layers = list(range(len(dirichlet_results)))
color1, color2 = '#4C72B0', '#C44E52'

fig, ax1 = plt.subplots(figsize=(16, 8))

# Left axis: normalized Dirichlet energy per layer.
ax1.set_xlabel('Layer')
ax1.set_ylabel('Normalized Dirichlet Energy', color=color1)
line1 = ax1.plot(layers, normalized_results, color=color1, marker='o', linewidth=4, markersize=12, label='Norm. Dirichlet')
ax1.tick_params(axis='y', labelcolor=color1)

# Right axis (shared x): logit-lens probability of the target token per layer.
ax2 = ax1.twinx()
ax2.set_ylabel('P(target)', color=color2)
line2 = ax2.plot(layers, target_probs, color=color2, marker='s', linewidth=4, markersize=12, label='P')
ax2.tick_params(axis='y', labelcolor=color2)
ax2.spines['right'].set_visible(True)

ax1.legend(line1 + line2, [l.get_label() for l in line1 + line2], loc='upper left')
fig.suptitle('Normalized Dirichlet Energy vs P(target)', fontsize=36, y=0.98)
fig.tight_layout()
fig.subplots_adjust(top=0.88)

# savefig does not create missing directories.
Path('results').mkdir(parents=True, exist_ok=True)
plt.savefig('results/dirichlet_vs_probability.pdf', bbox_inches='tight')
plt.show()

## 8. Summary

In [None]:
print("=" * 70)
print("ANALYSIS SUMMARY")
print("=" * 70)
print(f"\nModel: {MODEL_NAME}")
print(f"Prompt: {prompt}")
print(f"Target: {target_token}")
print(f"\nSettings: RMSNorm={APPLY_RMSNORM}, Proxy={PROXY_METHOD}")
print(f"\nCorrelation (Dirichlet vs P(target)): {corr:.4f}")
print(f"Final P(target): {target_probs[-1]:.4f}")
print(f"Max P(target): {max(target_probs):.4f} at layer {np.argmax(target_probs)}")

# Interpret the correlation. np.corrcoef returns NaN when a series has zero
# variance; NaN fails every comparison below, so it is reported explicitly
# instead of being mislabelled as a weak correlation.
if np.isnan(corr):
    print("\n~ Correlation is undefined (NaN) - a series may be constant")
elif corr < -0.5:
    print("\n✓ Strong negative correlation → supports hypothesis")
    print("  Prediction probability increases as Dirichlet energy decreases")
elif corr > 0.5:
    print("\n⚠ Positive correlation (possibly due to embedding norm growth)")
else:
    print("\n~ Weak correlation - results are inconclusive")
print("=" * 70)

In [None]:
# Save results to results/<sample_name>/<norm_dir>/<proxy_dir>/analysis_results.json
norm_dir = "rmsnorm" if APPLY_RMSNORM else "no_rmsnorm"
proxy_dir = "mean_proxy" if PROXY_METHOD == "mean" else "num_proxy"
output_path = Path("results") / sample_name / norm_dir / proxy_dir
output_path.mkdir(parents=True, exist_ok=True)


def _json_number(value):
    """Convert a numeric scalar to a JSON-safe float, mapping NaN/inf to None.

    ``float()`` accepts Python/NumPy scalars and single-element tensors; the
    exact type returned by the helper modules is not visible here. Without the
    conversion, ``json.dump`` rejects e.g. ``np.float32``. Non-finite values
    become ``None`` because ``json.dump`` would otherwise emit the non-standard
    tokens ``NaN``/``Infinity``.
    """
    value = float(value)
    return value if np.isfinite(value) else None


results = {
    "model": MODEL_NAME, "prompt": prompt, "target_token": target_token,
    "correlation": _json_number(corr),
    "dirichlet_energy": [{"layer": r.layer, "energy": _json_number(r.energy)} for r in dirichlet_results],
    "logit_lens": [{"layer": r.layer, "prob": _json_number(r.target_token_prob)} for r in logit_lens_results],
}
with open(output_path / "analysis_results.json", "w", encoding="utf-8") as f:
    json.dump(results, f, indent=2)
print(f"Results saved to {output_path}")

In [None]:
# Correlations of the per-layer metrics with P(target). When functor-similarity
# results exist, the helper also reports the cosine-similarity correlation;
# otherwise only the Dirichlet correlation is computed directly.
if not functor_results:
    target_probs = np.array([layer_res.target_token_prob for layer_res in logit_lens_results])
    dirichlet_energies = np.array([layer_res.energy for layer_res in dirichlet_results])
    corr = np.corrcoef(target_probs, dirichlet_energies)[0, 1]
    correlations = {"dirichlet_energy": corr}
    print(f"Correlation (Dirichlet vs P(target)): {corr:.4f}")
else:
    correlations = compute_correlation_with_metrics(
        logit_lens_results, dirichlet_results, functor_results
    )
    dirichlet_corr = correlations['dirichlet_energy']
    print(f"Correlation (Dirichlet vs P(target)): {dirichlet_corr:.4f}")
    cosine_corr = correlations['cosine_similarity']
    print(f"Correlation (Cosine Sim vs P(target)): {cosine_corr:.4f}")

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import json
from pathlib import Path

# Hugging Face authentication (for gated models like Gemma)
# NOTE(review): this cell repeats the setup cell of Section 1.
# Paste a real token below, or export HF_TOKEN before launching. The placeholder
# is never exported, so an HF_TOKEN already set in the environment is not
# clobbered by it.
import os
_HF_TOKEN_PLACEHOLDER = "YOUR_HF_TOKEN"
HF_TOKEN = "YOUR_HF_TOKEN"  # replace with your token, or leave as-is to use the environment
if HF_TOKEN != _HF_TOKEN_PLACEHOLDER:
    os.environ["HF_TOKEN"] = HF_TOKEN

# Device configuration: CUDA, then Apple MPS, then CPU
if torch.cuda.is_available():
    DEVICE = "cuda"
elif torch.backends.mps.is_available():
    DEVICE = "mps"
else:
    DEVICE = "cpu"
print(f"Using device: {DEVICE}")