# ResNet-BK Interpretability

Visualize and understand ResNet-BK internals.

This notebook visualizes:
- G_ii diagonal elements (real and imaginary)
- Learned potential v_i

Planned but not yet implemented here (see **Next steps** below):
- Expert routing patterns
- Attention-like patterns from BK-Core

In [None]:
!pip install datasets torch matplotlib seaborn -q

In [None]:
# Repo setup: find the project checkout (first directory among the candidates
# that contains `src/`), cloning it under the current directory if none is found.
# Then make it the working directory and prepend it to sys.path so that
# `import src...` resolves.
import os
import sys
import subprocess
import pathlib

REPO_URL = 'https://github.com/neko-jpg/Project-ResNet-BK-An-O-N-Language-Model-Architecture.git'
REPO_DIR = 'Project-ResNet-BK-An-O-N-Language-Model-Architecture'

cwd = pathlib.Path.cwd()
candidates = [cwd, cwd.parent, cwd / REPO_DIR, cwd.parent / REPO_DIR]

root = None
for candidate in candidates:
    if (candidate / 'src').exists():
        root = candidate
        break

if root is None:
    # Nothing found nearby: use ./REPO_DIR, cloning it only if it is absent.
    root = cwd / REPO_DIR
    if not root.exists():
        subprocess.run(['git', 'clone', REPO_URL, str(root)], check=True)

if pathlib.Path.cwd() != root:
    os.chdir(root)

root_str = str(pathlib.Path.cwd())
if root_str not in sys.path:
    sys.path.insert(0, root_str)
print('PWD:', root_str)


In [None]:
import torch
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from src.models import LanguageModel
from src.utils import get_data_loader

# Run on the GPU when PyTorch can see one, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

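## Load Trained Model

The next cell expects a checkpoint at `checkpoints/resnet_bk_final.pt` (relative to the repository root) containing `config` and `model_state_dict` entries.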

In [None]:
# Rebuild the model from the saved config and restore its trained weights.
# NOTE(review): torch.load unpickles the file, so only load checkpoints you trust.
ckpt_file = 'checkpoints/resnet_bk_final.pt'
checkpoint = torch.load(ckpt_file, map_location=device)
config = checkpoint['config']

model = LanguageModel(**config)
model = model.to(device)
model.load_state_dict(checkpoint['model_state_dict'])
# Inference mode for the analysis below.
model.eval()

print("Model loaded successfully")

## Visualize G_ii Diagonal Elements

In [None]:
# Get a sample batch from the project's data loader.
# NOTE(review): (20, 128) are passed positionally to get_data_loader; confirm
# their meaning (presumably batch size and sequence length) in src.utils.
train_data, vocab, get_batch = get_data_loader(20, 128)
x_batch, _ = get_batch(train_data, 0)
# Swap the first two dims before feeding the model
# (presumably (seq, batch) -> (batch, seq); TODO confirm against get_batch).
x_batch = x_batch.t().contiguous().to(device)

# Capture the 2-channel (Re, Im) G_ii features of the first block. These are
# what `output_proj` *consumes* (presumably nn.Linear(2, d_model); TODO confirm),
# so the hook records its first input. Recording its output would plot two
# projected hidden channels rather than Re/Im(G_ii).
G_ii_features = []

def hook_fn(module, inputs, output):
    """Forward hook: store the first positional input of `module` on the CPU."""
    G_ii_features.append(inputs[0].detach().cpu())

hook = model.blocks[0].bk_layer.output_proj.register_forward_hook(hook_fn)
try:
    with torch.no_grad():
        _ = model(x_batch)
finally:
    # Remove the hook even if the forward pass fails, so re-running this cell
    # never leaves stale hooks registered on the model.
    hook.remove()

if not G_ii_features:
    raise RuntimeError("output_proj was not called during the forward pass")

# Features of the first sample in the batch: one (Re, Im) pair per position.
features = G_ii_features[0][0].numpy()  # (N, 2)
if features.ndim != 2 or features.shape[-1] != 2:
    raise ValueError(f"Expected G_ii features of shape (N, 2), got {features.shape}")

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))

# (axis, feature column, legend label, title, line color or None for default)
panels = [
    (ax1, 0, 'Real(G_ii)', 'Real Part of G_ii', None),
    (ax2, 1, 'Imag(G_ii)', 'Imaginary Part of G_ii', 'orange'),
]
for ax, col, label, title, color in panels:
    ax.plot(features[:, col], label=label, color=color)
    ax.set_xlabel('Position')
    ax.set_ylabel('Value')
    ax.set_title(title)
    ax.grid(True, alpha=0.3)
    ax.legend()

plt.tight_layout()
plt.show()

## Visualize Learned Potential v_i

In [None]:
# Capture the learned potential of the first block: the output of its `v_proj`
# module for the sample batch `x_batch` prepared earlier in the notebook.
# NOTE(review): this is the raw v_proj output; any post-processing the layer
# applies afterwards (e.g. clamping) is not reflected here — confirm in the model code.
v_values = []

def hook_v(module, inputs, output):
    """Forward hook: store the output of `module` on the CPU."""
    v_values.append(output.detach().cpu())

hook = model.blocks[0].bk_layer.v_proj.register_forward_hook(hook_v)
try:
    with torch.no_grad():
        _ = model(x_batch)
finally:
    # Remove the hook even if the forward pass fails, so re-runs never stack hooks.
    hook.remove()

if not v_values:
    raise RuntimeError("v_proj was not called during the forward pass")

# First sample of the batch. squeeze(-1) drops only a trailing singleton
# dimension, so a length-1 sequence still yields a 1-D array (a bare
# squeeze() would collapse it to a 0-d scalar).
v = v_values[0][0].squeeze(-1).numpy()

fig, ax = plt.subplots(figsize=(10, 4))
ax.plot(v)
ax.set_xlabel('Position')
ax.set_ylabel('Potential v_i')
ax.set_title('Learned Potential Across Sequence')
ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

print("Potential statistics:")
for stat_name, stat_value in (("Mean", v.mean()), ("Std", v.std()),
                              ("Min", v.min()), ("Max", v.max())):
    print(f"  {stat_name}: {stat_value:.4f}")

## Summary

This notebook demonstrates how to:
- Extract and visualize BK-Core features
- Analyze learned potential patterns
- Inspect model internals with PyTorch forward hooks

**Next steps:**
- Analyze expert routing patterns
- Compare patterns across different layers
- Correlate patterns with linguistic structure