# ParisiJAX Demo: Spin Glass Physics

This notebook demonstrates the ParisiJAX library, which uses JAX to simulate and analyze the Sherrington–Kirkpatrick (SK) spin-glass model.

In [None]:
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np

from parisijax.core import hamiltonian, solver, mcmc
from parisijax.analysis import overlap
from parisijax.viz import animation

# Seed JAX's functional PRNG exactly once. Later cells never reuse `key`
# directly; they split it into a fresh `subkey` each time, so
# Restart Kernel -> Run All reproduces the same random draws.
SEED = 42
key = jax.random.PRNGKey(SEED)

## 1. SK Hamiltonian Basics

Generate a random SK coupling matrix and compute the energies of a batch of random spin configurations.

In [None]:
# Build one SK instance with `n_spins` spins, then evaluate its energy on a
# batch of independent random spin configurations.
n_spins = 100

# sample_couplings returns a batch of coupling matrices; we ask for one and
# keep that single instance.
key, subkey = jax.random.split(key)
J = hamiltonian.sample_couplings(subkey, n_spins, n_samples=1)[0]

# Ten random spin configurations to evaluate on the same J.
key, subkey = jax.random.split(key)
spins = hamiltonian.random_spins(subkey, n_spins, n_samples=10)

# Map sk_energy over the leading axis of `spins` (in_axes=0) while broadcasting
# the shared coupling matrix J (in_axes=None).
batched_sk_energy = jax.vmap(hamiltonian.sk_energy, in_axes=(0, None))
energies = batched_sk_energy(spins, J)

# Normalise once and reuse for both statistics.
energy_per_spin = energies / n_spins
print(f"Energy per spin: mean = {jnp.mean(energy_per_spin):.4f}, std = {jnp.std(energy_per_spin):.4f}")

## 2. Replica-Symmetric Solution

Compute the replica-symmetric (RS) free energy over a range of inverse temperatures β and locate the critical inverse temperature β_c.

In [None]:
# Scan the replica-symmetric (RS) free energy over a grid of inverse
# temperatures β.
betas = np.linspace(0.1, 2.0, 20)
f_rs = [solver.rs_free_energy(b) for b in betas]

# The returned value is drawn on the β axis below, i.e. it is treated as a
# critical *inverse* temperature.
beta_c = solver.find_critical_temperature()

# Explicit Figure/Axes interface instead of the pyplot state machine.
fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(betas, f_rs, 'o-', linewidth=2, markersize=6)
ax.axvline(beta_c, color='red', linestyle='--', label=f'β_c = {beta_c:.3f}')
ax.set_xlabel('Inverse Temperature β', fontsize=12)
ax.set_ylabel('Free Energy per Spin', fontsize=12)
ax.set_title('Replica-Symmetric Free Energy', fontsize=14)
ax.legend()
ax.grid(True, alpha=0.3)
plt.show()

## 3. MCMC Simulation

Run a Metropolis Monte Carlo simulation at low temperature (β = 2.0) and track the energy trajectories of 10 independent samples.

In [None]:
# Metropolis sampling of the SK instance J at low temperature (β = 2.0), zero
# external field; we keep both the final spins and the energy history.
beta = 2.0
key, subkey = jax.random.split(key)

mcmc_settings = dict(h=0.0, n_steps=500, n_samples=10, method='metropolis')
final_spins, energy_trajectory = mcmc.run_mcmc(subkey, J, beta, **mcmc_settings)

# Plot the recorded energies for the first 10 samples.
animation.plot_energy_trajectory(energy_trajectory, n_show=10)
plt.show()

## 4. Overlap Distribution

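Compute the overlap distribution P(q) at a high temperature (β = 0.5, paramagnetic phase) and a low temperature (β = 2.0, spin-glass phase), and compare the Edwards–Anderson parameter q_EA in the two phases.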

In [None]:
# Compare the overlap distribution P(q) in the paramagnetic phase (high T,
# β = 0.5) with the spin-glass phase (low T, β = 2.0) on the same couplings J.
# The (label, β) table is the single source of truth for sampling, plotting
# and the printed summary, so the three can no longer drift apart.
OVERLAP_RUNS = {"High": 0.5, "Low": 2.0}  # temperature label -> inverse temperature β
OVERLAP_SAMPLER_KWARGS = dict(n_samples=200, n_steps=1000, burnin=200)

overlaps_by_label = {}
for label, beta_val in OVERLAP_RUNS.items():
    # One split per temperature, high T first, so the random stream matches
    # the original two-call version exactly.
    key, subkey = jax.random.split(key)
    overlaps_by_label[label] = overlap.sample_overlap_distribution(
        subkey, J, beta=beta_val, **OVERLAP_SAMPLER_KWARGS
    )

# Keep the original per-temperature names available after this cell.
overlaps_high = overlaps_by_label["High"]
overlaps_low = overlaps_by_label["Low"]

# One panel per temperature, in table order.
fig, axes = plt.subplots(1, len(OVERLAP_RUNS), figsize=(14, 5))
for ax, (label, beta_val) in zip(axes, OVERLAP_RUNS.items()):
    animation.plot_overlap_distribution(overlaps_by_label[label], beta=beta_val, ax=ax)
plt.tight_layout()
plt.show()

# Plain string: the original f-string here had no placeholders.
print("Edwards-Anderson parameter:")
for label, beta_val in OVERLAP_RUNS.items():
    q_ea = overlap.compute_edwards_anderson_parameter(overlaps_by_label[label])
    print(f"  {label} T (β={beta_val}): q_EA = {q_ea:.4f}")

## 5. Parisi k-RSB Solution (Optional)

Optimize the k-step replica-symmetry-breaking (k-RSB) Parisi free energy; this demo uses k = 3. Note: this is computationally intensive.

In [None]:
# Optimise a k-step replica-symmetry-breaking (k-RSB) Parisi ansatz at β = 1.5
# and zero field. A small k keeps the demo affordable.
beta = 1.5
k = 3  # Number of RSB levels

parisi_settings = {"h": 0.0, "k": k, "n_steps": 200, "learning_rate": 0.01, "n_quad": 16}
q_opt, m_opt, f_opt, history = solver.optimize_parisi(beta, **parisi_settings)

# Report the optimised breakpoints q, the block sizes m and the free energy f.
print(f"\nOptimal Parisi solution at β = {beta}:")
print(f"  q = {q_opt}")
print(f"  m = {m_opt}")
print(f"  f = {f_opt:.4f}")

# Visualise the step-function Parisi order parameter built from (q, m).
animation.plot_parisi_function(q_opt, m_opt)
plt.show()

## 6. Summary

This demo showed:
- Basic SK Hamiltonian operations
- RS free energy and phase transition
- MCMC simulations with JAX (on CPU or GPU, whichever is available)
- Overlap distribution analysis
- Parisi RSB solution (optional)

For more examples, see the full documentation!