# Frequency Analysis

This notebook provides a detailed analysis of the frequency structure in compressed modular addition models:
- Validation accuracy curve
- Interaction matrix visualization
- Eigendecomposition analysis (effective rank, explained variance)
- FFT analysis of eigenvectors
- Frequency heatmaps per bottleneck size

In [None]:
import sys
sys.path.insert(0, '..')

import numpy as np
import matplotlib.pyplot as plt
import ipywidgets as widgets
from IPython.display import display
from pathlib import Path
from tqdm import tqdm

from models import load_sweep_results, get_device
from js_div import (
    compute_interaction_matrices,
    compute_eigen_data,
    compute_frequency_heatmap,
    analyze_eigenvector_fft,
    entropy_effective_rank,
    ratio_effective_rank,
    cumulative_explained_variance,
    components_for_variance_threshold
)

# Configuration
# SWEEP_PATH is a relative path, so it is resolved against the notebook's
# current working directory when the sweep is loaded.
SWEEP_PATH = Path('../comp_diagrams/sweep_results_0401.pkl')
# DEVICE is passed to compute_interaction_matrices further down.
DEVICE = get_device()
print(f'Device: {DEVICE}')

## Load Data

In [None]:
# Load sweep results
# The loader returns three values:
#   models_state - one entry per trained model (its length is printed below)
#   val_acc      - mapping from bottleneck dimension to validation accuracy,
#                  as consumed by the accuracy plot
#   P            - used throughout as the number of remainders and the largest
#                  bottleneck size (presumably the modulus; confirm in models.py)
# NOTE(review): the sweep file has a .pkl extension; if load_sweep_results
# unpickles it, only load files from a trusted source.
models_state, val_acc, P = load_sweep_results(SWEEP_PATH)
print(f'Loaded {len(models_state)} models with P={P}')
print(f'Validation accuracies available: {len(val_acc)}')

## Validation Accuracy Curve

In [None]:
if len(val_acc) == 0:
    print('No validation accuracies available.')
else:
    # Bottleneck sizes are listed from largest to smallest for the curve.
    dims_sorted = sorted(val_acc.keys(), reverse=True)
    accs_sorted = [val_acc[dim] for dim in dims_sorted]

    fig, ax = plt.subplots(figsize=(12, 5))
    ax.plot(
        dims_sorted,
        accs_sorted,
        'o-',
        linewidth=2,
        markersize=6,
        label='Validation Accuracy',
    )
    ax.set_title('Validation Accuracy vs. Bottleneck Dimension', fontsize=14)
    ax.set_xlabel('Hidden Dimension (Bottleneck)', fontsize=12)
    ax.set_ylabel('Validation Accuracy', fontsize=12)
    ax.set_ylim(0, 1.05)
    ax.legend(fontsize=11)
    ax.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.show()

    # Report the smallest bottleneck that reaches each accuracy level
    # (>= 0.999 and >= 0.95 respectively).
    perfect_dims = [dim for dim, accuracy in val_acc.items() if accuracy >= 0.999]
    if perfect_dims:
        print(f'Smallest dimension with perfect accuracy: {min(perfect_dims)}')

    high_acc_dims = [dim for dim, accuracy in val_acc.items() if accuracy >= 0.95]
    if high_acc_dims:
        print(f'Smallest dimension with >95% accuracy: {min(high_acc_dims)}')

## Compute Interaction Matrices and Eigendecomposition

In [None]:
# Compute interaction matrices
# The result is indexed as int_mats[d_hidden][remainder] by the
# visualization cell that follows.
print('Computing interaction matrices...')
int_mats = compute_interaction_matrices(models_state, P, DEVICE)
print(f'Computed interaction matrices for {len(int_mats)} models')

# Compute eigendecomposition
# Each eigen_data[d_hidden] entry is read later through its 'eigenvalues'
# and 'eigenvectors' keys.
print('\nComputing eigendecomposition...')
eigen_data = compute_eigen_data(int_mats, P)
print(f'Computed eigendecomposition for {len(eigen_data)} models')

## Interaction Matrix Visualization

In [None]:
def show_interaction_matrix(d_hidden=P, remainder=0):
    """Visualize interaction matrix for a given bottleneck size and remainder.

    Parameters
    ----------
    d_hidden : int
        Bottleneck size; used as the key into ``int_mats``.
    remainder : int
        Index of the matrix to show within ``int_mats[d_hidden]``.

    The matrix is drawn with the diverging ``RdBu`` colormap. The color limits
    are made symmetric around zero so that the neutral midpoint of the colormap
    corresponds to a zero entry and the two hues correspond to the two signs.
    """
    mat = int_mats[d_hidden][remainder]

    # Largest absolute entry defines symmetric limits (-peak, +peak). When the
    # peak is zero, NaN, or the matrix is empty, fall back to matplotlib's
    # automatic scaling instead of passing degenerate limits.
    values = np.asarray(mat)
    peak = float(np.abs(values).max()) if values.size else 0.0
    color_limits = {'vmin': -peak, 'vmax': peak} if peak > 0 else {}

    fig, ax = plt.subplots(figsize=(8, 7))
    im = ax.imshow(mat, cmap='RdBu', aspect='auto', origin='lower', **color_limits)
    ax.set_xlabel('Input index')
    ax.set_ylabel('Input index')
    ax.set_title(f'Interaction Matrix: d_hidden={d_hidden}, remainder={remainder}')
    plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
    plt.tight_layout()
    plt.show()

# Sliders driving show_interaction_matrix: bottleneck size in 1..P and
# remainder in 0..P-1.
d_slider = widgets.IntSlider(description='d_hidden', value=P, min=1, max=P, step=1)
r_slider = widgets.IntSlider(description='remainder', value=0, min=0, max=P - 1, step=1)
out = widgets.interactive_output(
    show_interaction_matrix,
    {'d_hidden': d_slider, 'remainder': r_slider},
)
display(widgets.HBox([d_slider, r_slider]), out)

## Effective Rank Analysis

In [None]:
# Effective-rank summaries, keyed by bottleneck size
entropy_ranks = {}
ratio_ranks = {}

# Each estimator is paired with the dictionary that collects its statistics.
rank_estimators = (
    (entropy_effective_rank, entropy_ranks),
    (ratio_effective_rank, ratio_ranks),
)

for d_hidden in range(1, P + 1):
    evals = eigen_data[d_hidden]['eigenvalues']  # (P, 2P)

    for estimate_rank, summary_by_dim in rank_estimators:
        # One effective-rank value per remainder, then mean/std across remainders.
        per_remainder = [estimate_rank(evals[r]) for r in range(P)]
        summary_by_dim[d_hidden] = {
            'mean': np.mean(per_remainder),
            'std': np.std(per_remainder),
            'per_remainder': per_remainder,
        }

print('Computed effective ranks for all models')

In [None]:
# Effective rank (mean with std error bars) as a function of bottleneck size
dims = list(range(1, P + 1))
entropy_means = [entropy_ranks[dim]['mean'] for dim in dims]
entropy_stds = [entropy_ranks[dim]['std'] for dim in dims]
ratio_means = [ratio_ranks[dim]['mean'] for dim in dims]
ratio_stds = [ratio_ranks[dim]['std'] for dim in dims]

fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# One row per panel: axis, data, marker format, extra styling, legend label, title.
panels = (
    (axes[0], entropy_means, entropy_stds, 'o-', {},
     'Entropy-based', 'Entropy-based Effective Rank'),
    (axes[1], ratio_means, ratio_stds, 's-', {'color': 'orange'},
     'Ratio-based', 'Ratio-based Effective Rank'),
)
for panel_ax, panel_means, panel_stds, marker_fmt, extra_style, legend_label, title in panels:
    panel_ax.errorbar(
        dims,
        panel_means,
        yerr=panel_stds,
        fmt=marker_fmt,
        capsize=3,
        linewidth=2,
        markersize=4,
        label=legend_label,
        **extra_style,
    )
    panel_ax.set_xlabel('Bottleneck Dimension', fontsize=12)
    panel_ax.set_ylabel('Effective Rank', fontsize=12)
    panel_ax.set_title(title, fontsize=14)
    panel_ax.grid(True, alpha=0.3)
    panel_ax.legend()

plt.tight_layout()
plt.show()

## Cumulative Explained Variance

In [None]:
# For each explained-variance threshold, record how many components every
# bottleneck size needs (summarised across remainders).
variance_thresholds = [0.90, 0.95, 0.99]
components_needed = {}
for thresh in variance_thresholds:
    components_needed[thresh] = {}

for d_hidden in range(1, P + 1):
    evals = eigen_data[d_hidden]['eigenvalues']  # (P, 2P), already sorted by |λ|

    for thresh in variance_thresholds:
        counts = [
            components_for_variance_threshold(evals[r], thresh)
            for r in range(P)
        ]
        summary = {
            'mean': np.mean(counts),
            'std': np.std(counts),
            'max': np.max(counts),
            'per_remainder': counts,
        }
        components_needed[thresh][d_hidden] = summary

print(f'Computed components needed for variance thresholds: {variance_thresholds}')

In [None]:
# Components needed per variance threshold, as a function of bottleneck size
fig, ax = plt.subplots(figsize=(12, 6))

dims = list(range(1, P + 1))
colors = ['green', 'blue', 'red']

# One error-bar curve per threshold, each drawn in its own color.
for series_idx, thresh in enumerate(variance_thresholds[:len(colors)]):
    per_dim = components_needed[thresh]
    ax.errorbar(
        dims,
        [per_dim[dim]['mean'] for dim in dims],
        yerr=[per_dim[dim]['std'] for dim in dims],
        fmt='o-',
        capsize=3,
        linewidth=2,
        markersize=4,
        color=colors[series_idx],
        label=f'{int(thresh*100)}% variance',
    )

ax.set_title('Components Needed for Variance Threshold', fontsize=14)
ax.set_xlabel('Bottleneck Dimension', fontsize=12)
ax.set_ylabel('Number of Components Needed', fontsize=12)
ax.legend()
ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

## FFT Analysis of Eigenvectors

In [None]:
def plot_eigenvector_fft(d_hidden=P, remainder=0, evec_idx=0):
    """Show one eigenvector and the FFT magnitude spectrum of each of its halves.

    The figure is a 2x2 grid: the full eigenvector with a marker at index P,
    the two halves (``[:P]`` labelled input a, ``[P:]`` labelled input b)
    overlaid, and one stem plot of ``|rfft|`` per half.

    Parameters
    ----------
    d_hidden : int
        Bottleneck size; key into ``eigen_data``.
    remainder : int
        First index into the stored eigenvalues / eigenvectors.
    evec_idx : int
        Which eigenvector (last axis of the eigenvector array) to display.
    """
    entry = eigen_data[d_hidden]
    evec = entry['eigenvectors'][remainder, :, evec_idx]
    eval_val = entry['eigenvalues'][remainder, evec_idx]

    # (segment, line format, legend label, FFT panel title) for each half.
    halves = (
        (evec[:P], 'b-', 'Input a', 'FFT of Input a component'),
        (evec[P:], 'r-', 'Input b', 'FFT of Input b component'),
    )

    fig, axes = plt.subplots(2, 2, figsize=(14, 10))

    # Top-left: the whole eigenvector, with the split between halves marked.
    full_ax = axes[0, 0]
    full_ax.plot(evec, 'b-', linewidth=1)
    full_ax.axvline(P, color='r', linestyle='--', alpha=0.5, label='a|b boundary')
    full_ax.set_title(f'Eigenvector (λ={eval_val:.4f})')

    # Top-right: both halves on a shared index axis.
    split_ax = axes[0, 1]
    for segment, line_fmt, legend_label, _ in halves:
        split_ax.plot(segment, line_fmt, linewidth=1, label=legend_label)
    split_ax.set_title('Eigenvector (split by input)')

    for top_ax in (full_ax, split_ax):
        top_ax.set_xlabel('Index')
        top_ax.set_ylabel('Value')
        top_ax.legend()
        top_ax.grid(True, alpha=0.3)

    # Bottom row: magnitude spectrum of each half.
    freqs = np.fft.rfftfreq(P, d=1.0) * P  # Convert to integer frequencies
    for column, (segment, _, _, fft_title) in enumerate(halves):
        spectrum_ax = axes[1, column]
        spectrum_ax.stem(freqs, np.abs(np.fft.rfft(segment)), basefmt=' ')
        spectrum_ax.set_xlabel('Frequency')
        spectrum_ax.set_ylabel('Magnitude')
        spectrum_ax.set_title(fft_title)
        spectrum_ax.grid(True, alpha=0.3)

    fig.suptitle(f'd_hidden={d_hidden}, remainder={remainder}, eigenvector #{evec_idx+1}', fontsize=14)
    plt.tight_layout()
    plt.show()

# Sliders driving plot_eigenvector_fft: bottleneck size, remainder, and which
# of the first ten eigenvectors to show.
d_slider = widgets.IntSlider(description='d_hidden', value=P, min=1, max=P, step=1)
r_slider = widgets.IntSlider(description='remainder', value=0, min=0, max=P - 1, step=1)
e_slider = widgets.IntSlider(description='evec_idx', value=0, min=0, max=9, step=1)
out = widgets.interactive_output(
    plot_eigenvector_fft,
    {'d_hidden': d_slider, 'remainder': r_slider, 'evec_idx': e_slider},
)
display(widgets.HBox([d_slider, r_slider, e_slider]), out)

## Frequency Heatmaps

In [None]:
# 2x2 grid of frequency heatmaps at P, 3P/4, P/2 and P/4 (integer division)
bottleneck_sizes_to_plot = [P, 3 * P // 4, P // 2, P // 4]
fig, axes = plt.subplots(2, 2, figsize=(14, 12))

for panel_idx, d_hidden in enumerate(bottleneck_sizes_to_plot):
    ax = axes.flat[panel_idx]
    heatmap = compute_frequency_heatmap(eigen_data, d_hidden, P, n_evecs=4)
    im = ax.imshow(heatmap, cmap='hot', origin='lower', aspect='auto')
    ax.set_title(f'd_hidden={d_hidden}')
    ax.set_xlabel('Frequency')
    ax.set_ylabel('Remainder')
    plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04)

fig.suptitle('Frequency Content by Remainder (weighted by eigenvalue magnitude)', fontsize=14)
plt.tight_layout()
plt.show()

In [None]:
# Callback for the interactive frequency heatmap explorer
def show_freq_heatmap(d_hidden=P, n_evecs=4):
    """Draw the frequency heatmap for one bottleneck size.

    Parameters
    ----------
    d_hidden : int
        Bottleneck size forwarded to ``compute_frequency_heatmap``.
    n_evecs : int
        Number of eigenvectors forwarded to ``compute_frequency_heatmap``.
    """
    fig, ax = plt.subplots(figsize=(10, 8))
    im = ax.imshow(
        compute_frequency_heatmap(eigen_data, d_hidden, P, n_evecs),
        cmap='hot',
        origin='lower',
        aspect='auto',
    )
    ax.set_title(f'Frequency Heatmap (d_hidden={d_hidden}, top {n_evecs} evecs)', fontsize=14)
    ax.set_xlabel('Frequency', fontsize=12)
    ax.set_ylabel('Remainder', fontsize=12)
    plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
    plt.tight_layout()
    plt.show()

# Sliders driving show_freq_heatmap: bottleneck size and eigenvector count.
d_slider = widgets.IntSlider(description='d_hidden', value=P, min=1, max=P, step=1)
n_slider = widgets.IntSlider(description='n_evecs', value=4, min=1, max=10, step=1)
out = widgets.interactive_output(
    show_freq_heatmap,
    {'d_hidden': d_slider, 'n_evecs': n_slider},
)
display(widgets.HBox([d_slider, n_slider]), out)

## Dominant Frequencies Analysis

In [None]:
def count_dominant_frequencies(fft_magnitudes, energy_threshold=0.95):
    """Count number of frequencies needed to capture energy_threshold of total FFT energy.

    Parameters
    ----------
    fft_magnitudes : array-like
        FFT magnitudes (any array-like, e.g. ndarray, list or tuple). The input
        is flattened; complex coefficients are accepted and reduced to their
        absolute value.
    energy_threshold : float
        Fraction of the total energy (sum of squared magnitudes) to capture.

    Returns
    -------
    int
        Smallest number of bins, taken in order of decreasing energy, whose
        cumulative energy fraction reaches ``energy_threshold``. Returns 0 when
        the total energy is below 1e-12 (including empty input). The count is
        never larger than the number of bins.
    """
    # np.ravel accepts plain sequences, so list input no longer fails on `** 2`.
    magnitudes = np.ravel(fft_magnitudes)
    energy = np.abs(magnitudes) ** 2
    total_energy = energy.sum()
    if total_energy < 1e-12:
        return 0

    # Strongest bins first, then the running fraction of total energy.
    sorted_energy = np.sort(energy)[::-1]
    cumulative = np.cumsum(sorted_energy) / total_energy

    # First position whose cumulative fraction is >= the threshold. Rounding can
    # leave the last fraction just below 1.0, hence the clamp to the bin count.
    idx = np.searchsorted(cumulative, energy_threshold)
    return int(min(idx + 1, magnitudes.size))

# For every bottleneck size, count the dominant frequencies in the first half
# (indices [:P]) of the leading eigenvectors, pooled over all remainders.
n_top_evecs = 4
freq_threshold = 0.90
avg_dominant_freqs = {}

for d_hidden in range(1, P + 1):
    evecs = eigen_data[d_hidden]['eigenvectors']
    # Never ask for more eigenvectors than the array actually holds.
    n_used = min(n_top_evecs, evecs.shape[2])

    n_freqs_per_evec = [
        count_dominant_frequencies(
            np.abs(np.fft.rfft(evecs[r, :, i][:P])),
            freq_threshold,
        )
        for r in range(P)
        for i in range(n_used)
    ]

    avg_dominant_freqs[d_hidden] = {
        'mean': np.mean(n_freqs_per_evec),
        'std': np.std(n_freqs_per_evec),
        'median': np.median(n_freqs_per_evec),
    }

# Mean (with std error bars) number of dominant frequencies per bottleneck size
dims = list(range(1, P + 1))
means = []
stds = []
for dim in dims:
    stats = avg_dominant_freqs[dim]
    means.append(stats['mean'])
    stds.append(stats['std'])

fig, ax = plt.subplots(figsize=(12, 6))
ax.errorbar(
    dims,
    means,
    yerr=stds,
    fmt='o-',
    capsize=3,
    linewidth=2,
    markersize=4,
)
ax.set_title('Average Dominant Frequencies vs. Bottleneck Dimension', fontsize=14)
ax.set_xlabel('Bottleneck Dimension', fontsize=12)
ax.set_ylabel(f'Avg # Frequencies for {int(freq_threshold*100)}% Energy', fontsize=12)
ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

## Summary

Key observations:
- For small bottlenecks, eigenvectors have more dominant frequencies
- Frequency content stabilizes as bottleneck dimension increases
- The interaction matrices show clear frequency patterns (cosine-like structure)