<a href="https://colab.research.google.com/github/engelberger/neural-VMC/blob/main/Neural_Variational_Monte_Carlo_A_Mathematical_Walkthrough.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>


# Neural Variational Monte Carlo: A Mathematical Walkthrough

The goal of this tutorial is to build an intuitive and technical understanding of the mathematics presented in the amazing blog post ["Neural Variational Monte Carlo"](https://teddykoker.com/2024/11/neural-vmc-jax/) by Teddy Koker. We'll break down the core equations and concepts, interleaving explanations with relevant Python/JAX code snippets. Please note that the code is heavily based on Koker's blog post and its references. The original post motivated me to deepen my understanding of the PsiFormer from Google DeepMind and related ideas, so I prepared this Google Colab notebook to make experimenting easier.

**Assumed Background:** Familiarity with basic calculus, linear algebra (eigenvalues/vectors), probability (expectation values), Python, JAX, and some intuition about quantum mechanics (wavefunctions, energy levels).

Felipe Engelberger, April 2025.

## Setup: Imports

First, let's import the necessary libraries. We'll primarily use JAX for numerical computation and automatic differentiation, Equinox for building neural networks in JAX, and Optax for optimization.


In [None]:
# Install necessary Python libraries quietly (-q)
%pip install jax jaxlib optax equinox tqdm matplotlib pyscf -q
print("Libraries installed successfully!")

In [None]:
# Necessary imports (as used in the blog post)
import jax
import jax.numpy as jnp
import numpy as np
import math
import equinox as eqx # For building neural networks in JAX
from functools import partial
from collections.abc import Callable
import optax # For optimization algorithms
from tqdm.notebook import tqdm # For progress bars in notebooks

# Set a default random key for reproducibility in examples
key = jax.random.PRNGKey(42) # Changed seed for minor variation

## 1. The Goal: Solving the Schrödinger Equation

At the heart of quantum chemistry and physics lies the time-independent Schrödinger equation. For the electronic structure of a molecule or material (within the Born-Oppenheimer approximation where nuclei are fixed), it's written as:

$$
\hat{H}\psi(X) = E\psi(X)
$$

Let's break this down:

* **$\hat{H}$ (The Hamiltonian Operator):** Represents the total energy of the system's electrons. It's an operator containing terms for the kinetic energy of electrons and the potential energy from electron-electron repulsions and electron-nucleus attractions.
* **$\psi(X)$ (The Wavefunction):** The eigenfunction. It contains all information about the state of the $N$ electrons ($X = (x_1, ..., x_N)$, where $x_i$ includes position $\mathbf{r}_i$ and spin $\sigma_i$). The squared magnitude of the wavefunction, $|\psi(X)|^2$, gives (once normalized) the probability density of finding the electrons in the specific configuration $X$.
* **$E$ (The Energy Eigenvalue):** A scalar representing the total energy of the system when it's in the state described by $\psi(X)$.

**The Problem:** This is an eigenvalue problem. We are looking for specific functions $\psi$ (eigenfunctions) and corresponding scalar energies $E$ (eigenvalues) that satisfy the equation. There are typically many solutions ($\psi_n, E_n$).

**The Specific Goal (Ground State):** We are usually most interested in the lowest possible energy $E_0$ (the ground state energy) and its corresponding wavefunction $\psi_0$ (the ground state wavefunction). This represents the most stable configuration of the electrons. Finding this ground state is computationally very challenging for systems with many electrons.
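
To make the eigenvalue problem concrete, here is a minimal sketch on a toy system that is not part of the original post: a single particle in a one-dimensional harmonic well, $\hat{H} = -\frac{1}{2}\frac{d^2}{dx^2} + \frac{1}{2}x^2$, whose exact ground state energy is 0.5 in atomic units. Discretizing on a grid (the grid size and box length below are arbitrary choices) turns $\hat{H}$ into a matrix, and `jnp.linalg.eigh` returns its eigenvalues and eigenvectors.

```python
import jax.numpy as jnp

# Toy system (an assumption for illustration): 1D harmonic oscillator on a grid.
n_grid = 101
x = jnp.linspace(-6.0, 6.0, n_grid)
dx = x[1] - x[0]

# Second derivative via the central finite-difference stencil (1, -2, 1) / dx^2.
laplacian = (
    jnp.diag(jnp.full(n_grid, -2.0))
    + jnp.diag(jnp.ones(n_grid - 1), k=1)
    + jnp.diag(jnp.ones(n_grid - 1), k=-1)
) / dx**2

# H = kinetic + potential; the potential is diagonal in the position basis.
hamiltonian = -0.5 * laplacian + jnp.diag(0.5 * x**2)

# eigh returns eigenvalues in ascending order, so index 0 is the ground state.
energies, states = jnp.linalg.eigh(hamiltonian)
ground_energy = energies[0]
ground_state = states[:, 0]
print(f"Lowest eigenvalues: {energies[:3]}")
```

The lowest eigenvalue approximates $E_0$, with a small error from the finite grid spacing. This brute-force approach does not scale: for $N$ electrons in three dimensions the grid would need a number of points exponential in $3N$, which is why the rest of the notebook takes the variational route.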


## 2. The Strategy: The Variational Principle

Solving the Schrödinger equation directly is often impossible for complex systems. The Variational Principle gives us a powerful alternative strategy. It states that for any reasonably well-behaved "trial" wavefunction $\tilde{\psi}$ that we might guess, the expectation value of the energy calculated with this trial wavefunction will always be greater than or equal to the true ground state energy $E_0$:

$$
E_{var} = \frac{\langle\tilde{\psi}|\hat{H}|\tilde{\psi}\rangle}{\langle\tilde{\psi}|\tilde{\psi}\rangle} \ge E_0
$$

Let's unpack the notation:

* **$\tilde{\psi}$ (Trial Wavefunction):** Any function that could potentially represent the system's state. It doesn't have to be the true ground state wavefunction $\psi_0$.
* **$\langle\phi|\psi\rangle$ (Dirac Bra-Ket Notation):** Represents the inner product. For wavefunctions depending on configuration $X$, it means integrating their product over all possible configurations: $\int \phi^*(X) \psi(X) dX$. The `*` denotes the complex conjugate (though often wavefunctions can be chosen to be real).
* **$\langle\tilde{\psi}|\hat{H}|\tilde{\psi}\rangle$:** The expectation value of the Hamiltonian (average energy) for the trial state: $\int \tilde{\psi}^*(X) \hat{H} \tilde{\psi}(X) dX$.
* **$\langle\tilde{\psi}|\tilde{\psi}\rangle$:** The normalization factor: $\int |\tilde{\psi}(X)|^2 dX$.

**The Significance:** The Variational Principle turns the problem of finding the ground state into an **optimization problem**.

1.  The calculated energy $E_{var}$ provides an **upper bound** to the true ground state energy $E_0$.
2.  The equality $E_{var} = E_0$ holds if and only if the trial wavefunction $\tilde{\psi}$ is an exact ground state wavefunction, i.e. equal to $\psi_0$ up to a constant factor when the ground state is non-degenerate.

Therefore, if we can propose a family of trial wavefunctions parameterized by some variables $\theta$ (let's call them $\psi_\theta$), we can systematically adjust $\theta$ to minimize $E_{var}$. The lower the energy we achieve, the closer our $\psi_\theta$ is (in some sense) to the true $\psi_0$, and the closer our calculated energy is to $E_0$.
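
As a worked example of this optimization view (a sketch using a standard textbook trial family, not the neural network ansatz of the post): for the hydrogen atom with $\psi_\alpha(\mathbf{r}) = e^{-\alpha|\mathbf{r}|}$, the variational energy has the closed form $E_{var}(\alpha) = \alpha^2/2 - \alpha$ in atomic units. Plain gradient descent on $\alpha$ (the learning rate and step count are arbitrary choices) never goes below $E_0 = -0.5$ and converges to $\alpha = 1$, which is the exact ground state.

```python
import jax

def variational_energy(alpha):
    # For psi_alpha = exp(-alpha * r): <T> = alpha^2 / 2 and <V> = -alpha.
    return 0.5 * alpha**2 - alpha

energy_and_grad = jax.value_and_grad(variational_energy)

alpha = 0.3           # deliberately poor initial guess
learning_rate = 0.2   # arbitrary choice for this sketch
history = []
for step in range(100):
    energy, grad = energy_and_grad(alpha)
    history.append(float(energy))
    alpha = alpha - learning_rate * grad  # gradient descent update

print(f"alpha = {float(alpha):.4f}, E_var = {history[-1]:.6f}")
```

Every energy recorded in `history` is an upper bound on $E_0$, and the bound tightens as $\alpha$ approaches 1.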


## 3. The Method: Variational Monte Carlo (VMC)

VMC uses the Variational Principle combined with Monte Carlo methods (using random sampling to estimate quantities).

### 3.1. Parameterized Wavefunction and Loss Function

We introduce a trial wavefunction $\psi_\theta(X)$ that depends on a set of parameters $\theta$. In Neural VMC, $\theta$ represents the weights and biases of a neural network. Our goal is to find the optimal $\theta$ that minimizes the variational energy. We define a loss function $L(\theta)$ which is exactly the variational energy for $\psi_\theta$:

$$
L(\theta) = \frac{\langle\psi_\theta|\hat{H}|\psi_\theta\rangle}{\langle\psi_\theta|\psi_\theta\rangle} = \frac{\int \psi_\theta^*(X) \hat{H} \psi_\theta(X) dX}{\int |\psi_\theta(X)|^2 dX}
$$

Minimizing $L(\theta)$ with respect to $\theta$ brings us closer to the ground state energy $E_0$.


### 3.2. Rewriting the Loss for Monte Carlo Sampling

The integrals in the loss function are usually high-dimensional and impossible to compute analytically. Monte Carlo methods are ideal for estimating such integrals. To make this suitable for sampling, we rewrite the loss function. Assuming $\psi_\theta$ is real (common in VMC):

$$
L(\theta) = \frac{\int \psi_\theta(X) \hat{H} \psi_\theta(X) dX}{\int \psi_\theta(X)^2 dX}
$$

Now, multiply and divide the integrand of the numerator by $\psi_\theta(X)$ (this is valid wherever $\psi_\theta(X) \neq 0$; points where $\psi_\theta$ vanishes carry zero probability weight):

$$
L(\theta) = \frac{\int \psi_\theta(X)^2 \left( \frac{\hat{H} \psi_\theta(X)}{\psi_\theta(X)} \right) dX}{\int \psi_\theta(X)^2 dX}
$$

Notice that $\psi_\theta(X)^2 = |\psi_\theta(X)|^2$. This term represents the unnormalized probability density of finding the electrons in configuration $X$. Let's define the **Local Energy**:

$$
E_{local}(X; \theta) = \frac{\hat{H} \psi_\theta(X)}{\psi_\theta(X)}
$$

The local energy represents the energy of the system at a specific electron configuration $X$, given the current wavefunction parameters $\theta$. If $\psi_\theta$ were the exact eigenfunction $\psi_0$, then $\hat{H}\psi_0 = E_0\psi_0$, and $E_{local}(X)$ would be equal to the constant $E_0$ for all $X$. For an approximate trial wavefunction, $E_{local}(X)$ will fluctuate depending on the configuration $X$.

Substituting back into the loss expression, we see that $L(\theta)$ is the average of the local energy, weighted by the probability density $|\psi_\theta(X)|^2$:

$$
L(\theta) = \frac{\int |\psi_\theta(X)|^2 E_{local}(X; \theta) dX}{\int |\psi_\theta(X)|^2 dX}
$$

**The Significance:** This expresses the loss (average variational energy) as the expectation value of the local energy $E_{local}(X; \theta)$ when the configurations $X$ are sampled from the probability distribution $P_\theta(X) = \frac{|\psi_\theta(X)|^2}{\int |\psi_\theta(X')|^2 dX'}$:

$$
L(\theta) = \mathbb{E}_{X \sim |\psi_\theta|^2} [E_{local}(X; \theta)]
$$

This is perfect for Monte Carlo estimation:

1.  Generate many electron configurations (samples) $X_1, X_2, ..., X_M$ drawn from the distribution $|\psi_\theta(X)|^2$ (using Metropolis-Hastings).
2.  For each sample $X_i$, calculate its local energy $E_{local}(X_i; \theta)$.
3.  Estimate the loss (average energy) by averaging the local energies: $L(\theta) \approx \frac{1}{M} \sum_{i=1}^M E_{local}(X_i; \theta)$.
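
The three steps above can be sketched on the hydrogen atom with the trial wavefunction $\psi_\alpha = e^{-\alpha r}$ (a toy choice, not from the original post). For this trial function the local energy is $E_{local}(r) = -\alpha^2/2 + (\alpha - 1)/r$, and the radial part of $|\psi_\alpha|^2$ is a Gamma distribution, so this sketch draws samples directly with `jax.random.gamma` instead of Metropolis-Hastings. The exact expectation is $\alpha^2/2 - \alpha$.

```python
import jax
import jax.numpy as jnp

alpha = 0.8          # trial parameter (the exact ground state has alpha = 1.0)
n_samples = 100_000  # M in the text

# Step 1: sample from |psi|^2. In spherical coordinates the radial density is
# proportional to r^2 * exp(-2 * alpha * r), i.e. Gamma(shape=3, rate=2 * alpha).
key = jax.random.PRNGKey(0)
r = jax.random.gamma(key, 3.0, shape=(n_samples,)) / (2.0 * alpha)

# Step 2: local energy of each sample (it depends only on r for this trial function).
e_local = -0.5 * alpha**2 + (alpha - 1.0) / r

# Step 3: the sample mean estimates L(alpha).
energy_estimate = jnp.mean(e_local)
energy_exact = 0.5 * alpha**2 - alpha
std_error = jnp.std(e_local) / jnp.sqrt(n_samples)
print(f"estimate = {energy_estimate:.4f} +/- {std_error:.4f}, exact = {energy_exact:.4f}")
```

Setting `alpha = 1.0` makes every `e_local` equal to -0.5, so the estimate has zero variance. This is the constant-local-energy property of an exact eigenfunction.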

### 3.3. Calculating the Local Energy in Code

To calculate $E_{local}$, we need the Hamiltonian $\hat{H}$. As we'll see in Section 4, $\hat{H}$ consists of kinetic and potential energy terms. The following functions implement the calculation based on the Born-Oppenheimer Hamiltonian. The `kinetic_energy` function uses `jax.hessian` for automatic differentiation to compute the Laplacian part ($\nabla^2 \psi$), while `potential_energy` computes the multiplicative terms.

In [None]:
def kinetic_energy(wavefunction: Callable, pos: jax.Array) -> jax.Array:
    """
    Calculates the kinetic energy part of the local energy: -0.5 * Laplacian(psi) / psi.
    Uses JAX's automatic differentiation (`jax.hessian`) to compute the Laplacian.
    The Laplacian is the trace of the Hessian matrix of the wavefunction w.r.t position.

    Args:
        wavefunction: A function (e.g., a neural network) that takes electron positions
                      (flattened array) and returns the wavefunction value.
        pos: A flattened array [3N] of electron positions for a single configuration.

    Returns:
        The kinetic energy value for the given configuration.
    """
    # jax.hessian(wavefunction) returns a function that computes the Hessian matrix.
    # Applying it to (pos) computes the Hessian at the given position.
    # The Hessian is a matrix of second derivatives: H_ij = d^2 psi / (dr_i dr_j)
    # where r_i, r_j are components of the position vector (x1, y1, z1, x2, ...).
    hessian_matrix = jax.hessian(wavefunction)(pos)

    # The Laplacian is the sum of the diagonal elements of the Hessian:
    # Sum(d^2 psi / dr_i^2) where i runs over all 3N coordinates.
    laplacian = jnp.trace(hessian_matrix)

    # Kinetic energy formula
    # Need to handle potential division by zero if wavefunction is zero,
    # although in practice sampling avoids regions where psi is exactly zero.
    psi_val = wavefunction(pos)
    # Add a small epsilon for numerical stability if needed, though sampling density
    # usually keeps psi_val reasonably far from zero.
    epsilon = 1e-8
    # Use sign of psi_val to add epsilon safely
    psi_val_stable = psi_val + epsilon * jnp.sign(psi_val)
    # Handle case where psi_val was exactly zero
    psi_val_stable = jnp.where(psi_val == 0, epsilon, psi_val_stable)

    return -0.5 * laplacian / psi_val_stable


def potential_energy(atoms: jax.Array, charges: jax.Array, pos: jax.Array) -> jax.Array:
    """
    Calculates the potential energy part of the local energy.
    Includes electron-electron repulsion, electron-nucleus attraction,
    and nucleus-nucleus repulsion (constant term).

    Args:
        atoms: [M, 3] array of M atomic positions.
        charges: [M] array of atomic charges (atomic numbers).
        pos: [3N] flattened array of N electron positions for a single configuration.

    Returns:
        The potential energy value for the given configuration.
    """
    # Reshape electron positions into [N, 3] array
    n_electrons = pos.shape[0] // 3
    pos_reshaped = pos.reshape(n_electrons, 3)
    n_atoms = atoms.shape[0]
    epsilon = 1e-8 # Small value to prevent division by zero

    # --- Electron-Nucleus Attraction (VeN) ---
    # Calculate pairwise distances between N electrons and M atoms.
    # Use broadcasting: pos_reshaped[:, None, :] -> [N, 1, 3]
    #                  atoms[None, :, :]      -> [1, M, 3]
    # Resulting shape: [N, M, 3] difference vectors.
    delta_r_ea = pos_reshaped[:, None, :] - atoms[None, :, :]
    # Calculate norms (distances) along the last axis (axis=-1). Shape: [N, M]
    r_ea = jnp.linalg.norm(delta_r_ea, axis=-1)
    # Calculate potential: sum over electrons i and nuclei I of -Z_I / |r_i - R_I|
    # charges[None, :] -> [1, M]
    # r_ea             -> [N, M]
    v_ea = -jnp.sum(charges[None, :] / (r_ea + epsilon)) # Sum over all N*M pairs

    # --- Electron-Electron Repulsion (Vee) ---
    # Calculate pairwise distances between N electrons.
    if n_electrons > 1:
      # Get upper triangle indices (i < j) to avoid double counting and i=j terms.
      i_indices, j_indices = jnp.triu_indices(n_electrons, k=1)
      # Calculate difference vectors for pairs. Shape: [num_pairs, 3]
      delta_r_ee = pos_reshaped[i_indices] - pos_reshaped[j_indices]
      # Calculate norms (distances). Shape: [num_pairs]
      r_ee = jnp.linalg.norm(delta_r_ee, axis=-1)
      # Calculate potential: sum over pairs (i<j) of 1 / |r_i - r_j|
      v_ee = jnp.sum(1.0 / (r_ee + epsilon)) # Sum over all pairs
    else:
      v_ee = 0.0 # No electron-electron repulsion for single electron

    # --- Nucleus-Nucleus Repulsion (VNN) ---
    # This is constant for fixed nuclei, could be pre-calculated.
    v_aa = 0.0
    if n_atoms > 1:
        i_indices_a, j_indices_a = jnp.triu_indices(n_atoms, k=1)
        delta_r_aa = atoms[i_indices_a] - atoms[j_indices_a]
        r_aa = jnp.linalg.norm(delta_r_aa, axis=-1)
        # Product of charges for pairs Z_I * Z_J
        z_aa = charges[i_indices_a] * charges[j_indices_a]
        v_aa = jnp.sum(z_aa / (r_aa + epsilon)) # Sum over all pairs

    return v_ee + v_ea + v_aa


def local_energy(wavefunction: Callable, atoms: jax.Array, charges: jax.Array, pos: jax.Array) -> jax.Array:
    """
    Computes the total local energy for a given electron configuration.
    E_local = Kinetic + Potential = (H psi / psi)

    Args:
        wavefunction: The wavefunction evaluator function.
        atoms: [M, 3] atomic positions.
        charges: [M] atomic charges.
        pos: [3N] flattened electron positions.

    Returns:
        The total local energy E_local(pos).
    """
    # Note: The wavefunction itself is needed only for the kinetic energy term.
    ke = kinetic_energy(wavefunction, pos)
    pe = potential_energy(atoms, charges, pos)
    return ke + pe

### 3.4. Example: Hydrogen Atom

Let's test the local energy calculation for a simple case: the Hydrogen atom (1 proton at the origin, 1 electron). The exact ground state wavefunction is $\psi(r) = \exp(-|r|) / \sqrt{\pi}$. We can use a function proportional to this, $\psi(r) = \exp(-|r|)$, as the proportionality constant cancels out in $E_{local} = \hat{H}\psi / \psi$. The exact ground state energy is -0.5 Hartree (in atomic units).

For the exact wavefunction, $E_{local}$ should be constant and equal to the eigenvalue (-0.5) for *any* electron position $r$.

In [None]:
# Define the (unnormalized) ground state wavefunction for Hydrogen
def wavefunction_h(pos: jax.Array) -> jax.Array:
    """ Calculates exp(-|r|) for a single electron position vector r=(x,y,z)."""
    # pos is expected to be [x, y, z] for a single electron (N=1)
    # jnp.linalg.norm calculates the Euclidean distance |r| = sqrt(x^2+y^2+z^2)
    return jnp.exp(-jnp.linalg.norm(pos))

# Define the system: 1 proton (charge +1) at the origin (0,0,0)
atoms_h = jnp.array([[0.0, 0.0, 0.0]])
charges_h = jnp.array([1.0])

# Generate a random position for the electron
key, subkey = jax.random.split(key)
# pos_h should be shape (3,) for a single electron (N=1 => 3N=3)
pos_h = jax.random.normal(subkey, shape=(3,)) * 0.5 # Start near the nucleus
print(f"Hydrogen Atom:")
print(f"Electron position 1: {pos_h}")

# Calculate the local energy at this random position
e_local_h = local_energy(wavefunction_h, atoms_h, charges_h, pos_h)

print(f"Wavefunction value psi(pos1): {wavefunction_h(pos_h):.4f}")
# Note: For the exact wavefunction, E_local should be constant (-0.5) everywhere.
# JAX's automatic differentiation is exact up to floating-point round-off, so only tiny variations are expected.
print(f"Local Energy E_local(pos1): {e_local_h:.6f}")
print(f"Expected Energy E_exact:  -0.5")

# Verify E_local is constant (within numerical precision) for different positions
key, subkey = jax.random.split(key)
pos_h_2 = jax.random.normal(subkey, shape=(3,)) * 2.0 # Try a different position
print(f"\nElectron position 2: {pos_h_2}")
e_local_h_2 = local_energy(wavefunction_h, atoms_h, charges_h, pos_h_2)
print(f"Wavefunction value psi(pos2): {wavefunction_h(pos_h_2):.4f}")
print(f"Local Energy E_local(pos2): {e_local_h_2:.6f}")

# Assertion to check correctness
atol = 1e-5 # Tolerance for numerical errors
assert jnp.allclose(e_local_h, -0.5, atol=atol), f"E_local ({e_local_h}) should be -0.5 for Hydrogen ground state"
assert jnp.allclose(e_local_h_2, -0.5, atol=atol), f"E_local ({e_local_h_2}) should be constant and equal to -0.5"

print("\nAssertions passed!")

## 4. The Specifics: The Born-Oppenheimer Hamiltonian

Let's revisit the Hamiltonian $\hat{H}$ used in the `local_energy` calculation. It's the standard electronic Hamiltonian within the Born-Oppenheimer approximation (nuclei assumed fixed). In atomic units (where $\hbar = m_e = e = 4\pi\epsilon_0 = 1$):

$$
\hat{H} = \underbrace{-\frac{1}{2}\sum_{i=1}^N \nabla_i^2}_{\text{Kinetic Energy (KE)}} + \underbrace{\sum_{i=1}^N \sum_{j=i+1}^N \frac{1}{|\mathbf{r}_i - \mathbf{r}_j|}}_{\text{e-e Repulsion (Vee)}} \underbrace{- \sum_{i=1}^N \sum_{I=1}^M \frac{Z_I}{|\mathbf{r}_i - \mathbf{R}_I|}}_{\text{e-N Attraction (VeN)}} + \underbrace{\sum_{I=1}^M \sum_{J=I+1}^M \frac{Z_I Z_J}{|\mathbf{R}_I - \mathbf{R}_J|}}_{\text{N-N Repulsion (VNN)}}
$$

Where:
* $N$ is the number of electrons, $M$ is the number of nuclei.
* $\mathbf{r}_i$ is the position of the $i$-th electron.
* $\mathbf{R}_I$ is the position of the $I$-th nucleus.
* $Z_I$ is the charge (atomic number) of the $I$-th nucleus.
* $\nabla_i^2$ is the Laplacian operator acting on the coordinates of the $i$-th electron ($\frac{\partial^2}{\partial x_i^2} + \frac{\partial^2}{\partial y_i^2} + \frac{\partial^2}{\partial z_i^2}$).

**Mapping to Code:**
* **KE:** Handled by the `kinetic_energy` function using `jax.hessian` and the definition $KE \psi / \psi = (-0.5 \sum_i \nabla_i^2 \psi) / \psi$.
* **Vee, VeN, VNN:** These potential energy terms depend only on the positions of electrons ($\mathbf{r}_i$ / `pos`) and nuclei ($\mathbf{R}_I$ / `atoms`). They are calculated directly within the `potential_energy` function using distances. VNN is constant for a fixed geometry and doesn't depend on $\psi$.

The local energy $E_{local} = \hat{H}\psi / \psi$ is thus calculated as:
$E_{local} = (KE \psi / \psi) + Vee + VeN + VNN$.
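
An equivalent way to evaluate the kinetic term works with $\log|\psi|$ instead of $\psi$, using the identity $\nabla^2\psi/\psi = \nabla^2 \log|\psi| + |\nabla \log|\psi||^2$. The sketch below is an alternative to the `kinetic_energy` function above, not the notebook's method, and the helper names `log_psi_h` and `kinetic_energy_from_log` are hypothetical. It avoids dividing by $\psi$ and checks the result on the hydrogen ground state, where KE + VeN should equal -0.5.

```python
import jax
import jax.numpy as jnp

def log_psi_h(pos):
    # log|psi| for the hydrogen 1s state psi = exp(-|r|).
    return -jnp.linalg.norm(pos)

def kinetic_energy_from_log(log_psi, pos):
    # KE psi / psi = -0.5 * (laplacian(log|psi|) + |grad log|psi||^2)
    grad = jax.grad(log_psi)(pos)
    laplacian = jnp.trace(jax.hessian(log_psi)(pos))
    return -0.5 * (laplacian + jnp.sum(grad**2))

pos = jnp.array([0.3, -0.4, 1.2])  # arbitrary electron position, |r| = 1.3
kinetic = kinetic_energy_from_log(log_psi_h, pos)
potential = -1.0 / jnp.linalg.norm(pos)  # VeN for one proton at the origin
e_local = kinetic + potential
print(f"E_local = {e_local:.6f}")
```

Working in the log domain is a common choice when $\psi$ spans many orders of magnitude, because $\log|\psi|$ stays well scaled where $\psi$ itself would underflow.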


## 5. The Sampling: Metropolis-Hastings Algorithm

We need to generate samples $X$ from the probability distribution $P_\theta(X) \propto |\psi_\theta(X)|^2$. For complex, high-dimensional wavefunctions $\psi_\theta$ (especially from neural networks), we usually can't sample directly.

The Metropolis-Hastings algorithm is a Markov Chain Monte Carlo (MCMC) method used to generate samples from a probability distribution that we can evaluate (up to a normalization constant), even if we can't sample from it directly. It generates a sequence of samples where each sample depends only on the previous one (Markov property), and the distribution of these samples converges to the target distribution.

**Algorithm Steps (One Step):**

1.  **Start** with a configuration $X_t$.
2.  **Propose a Move:** Generate a candidate next configuration $X'$ by making a small random change to $X_t$. A common choice is a Gaussian move: add independent Gaussian (Normal) noise with standard deviation `step_size` to every coordinate of every electron:
$$
X' = X_t + \mathcal{N}(0, \text{step_size}^2 \mathbf{I})
$$
    This defines the proposal distribution $Q(X'|X_t)$. For this symmetric Gaussian proposal, $Q(X'|X_t) = Q(X_t|X')$.
3.  **Calculate Acceptance Ratio:** Compute how much more or less probable the proposed configuration $X'$ is compared to the current configuration $X_t$ under the target distribution $P_\theta \propto |\psi_\theta|^2$. This is the acceptance ratio $A$:
$$
A = \frac{P_\theta(X')}{P_\theta(X_t)} = \frac{|\psi_\theta(X')|^2 / \text{Norm}}{|\psi_\theta(X_t)|^2 / \text{Norm}} = \frac{|\psi_\theta(X')|^2}{|\psi_\theta(X_t)|^2}
$$
    *(Note: If Q wasn't symmetric, A would also include the Hastings ratio $Q(X_t|X') / Q(X'|X_t)$).*
4.  **Accept or Reject:** Generate a uniform random number $u$ between 0 and 1 ($u \sim U[0, 1]$).
    * If $u \le \min(1, A)$, **accept** the proposal: $X_{t+1} = X'$.
    * If $u > \min(1, A)$, **reject** the proposal: $X_{t+1} = X_t$ (The chain stays in the same state).

**Why it Works (Intuition):**
* If the proposed move $X'$ is more probable than $X_t$ (i.e., $A \ge 1$), the move is always accepted ($\min(1, A) = 1$). The algorithm preferentially moves towards regions of higher probability density $|\psi_\theta|^2$.
* If the proposed move $X'$ is less probable than $X_t$ (i.e., $A < 1$), the move is accepted only sometimes, with a probability equal to $A$. This allows the algorithm to explore lower probability regions occasionally, preventing it from getting stuck.

By repeating these steps many times, the sequence of configurations $X_1, X_2, ...$ generated (after an initial "warm-up" or "burn-in" period to forget the starting point) forms a representative sample from the target distribution $|\psi_\theta|^2$.

**Implementation:**
The code below uses `jax.vmap` to run multiple MCMC chains in parallel (one for each sample in a batch, often called "walkers") and `jax.lax.fori_loop` to efficiently perform multiple MCMC steps.

In [None]:
@partial(jax.vmap, in_axes=(None, 0, 0, None, 0)) # Vectorize over batch of positions, probs & keys
@eqx.filter_jit # Just-in-Time compile this function for speed
def metropolis_step(
    wavefunction: Callable,
    pos: jax.Array,
    prob: jax.Array, # Current probability |psi(pos)|^2
    step_size: float,
    key: jax.Array,
) -> tuple[jax.Array, jax.Array, jax.Array]:
    """Performs a single Metropolis-Hastings step for one configuration (walker)."""
    key_proposal, key_accept = jax.random.split(key)

    # 1. Propose move: Add Gaussian noise to all electron coordinates
    pos_proposal = pos + step_size * jax.random.normal(key_proposal, shape=pos.shape)

    # 2. Calculate acceptance ratio A = |psi(pos')|^2 / |psi(pos)|^2
    # We pass prob = |psi(pos)|^2 to avoid recomputing it.
    prob_proposal = wavefunction(pos_proposal) ** 2
    # Avoid division by zero if prob is zero (might happen with poor initialization/step size)
    acceptance_ratio = prob_proposal / jnp.where(prob == 0, 1e-8, prob)

    # 3. Accept or reject
    # Testing u < A is equivalent to u < min(1, A), because u never exceeds 1.
    accept_condition = jax.random.uniform(key_accept) < acceptance_ratio
    accept = accept_condition.astype(jnp.float32) # Convert bool to float for metrics

    # Update position and probability based on acceptance
    new_pos = jnp.where(accept_condition, pos_proposal, pos)
    new_prob = jnp.where(accept_condition, prob_proposal, prob)

    return new_pos, new_prob, accept


@eqx.filter_jit
def metropolis_mcmc(
    wavefunction: Callable,
    initial_pos: jax.Array, # Batch of initial positions [batch_size, 3N]
    step_size: float,
    mcmc_steps: int,
    key: jax.Array,
) -> tuple[jax.Array, jax.Array]:
    """
    Runs the Metropolis-Hastings MCMC sampling for multiple steps and a batch of walkers.

    Args:
        wavefunction: The wavefunction evaluator.
        initial_pos: Starting positions for the walkers [batch_size, 3N].
        step_size: Standard deviation for the Gaussian proposal distribution.
        mcmc_steps: Number of MCMC steps to perform.
        key: JAX random key.

    Returns:
        final_pos: Positions after mcmc_steps [batch_size, 3N].
        acceptance_rate: Average acceptance rate over the batch and steps.
    """
    batch_size = initial_pos.shape[0]

    # Define the loop body for jax.lax.fori_loop
    def step_fn(i, carry):
        pos, prob, total_accepts, key = carry
        step_keys = jax.random.split(key, batch_size + 1)
        key = step_keys[0]
        walker_keys = step_keys[1:] # One key per walker in the batch

        # Perform one Metropolis step for all walkers in parallel (thanks to vmap in metropolis_step)
        new_pos, new_prob, accepts = metropolis_step(wavefunction, pos, prob, step_size, walker_keys)

        # Update total number of accepted moves (sum over batch)
        total_accepts += jnp.sum(accepts)
        return new_pos, new_prob, total_accepts, key

    # Initial state for the loop
    # Calculate initial probabilities |psi(pos_initial)|^2 for the batch
    initial_prob = jax.vmap(lambda p: wavefunction(p)**2)(initial_pos)
    initial_carry = (initial_pos, initial_prob, 0.0, key)

    # Run the MCMC loop
    final_pos, _, total_accepts, _ = jax.lax.fori_loop(0, mcmc_steps, step_fn, initial_carry)

    # Calculate average acceptance rate
    acceptance_rate = total_accepts / (batch_size * mcmc_steps)

    return final_pos, acceptance_rate

# --- Example usage (commented out; uses the exact hydrogen wavefunction_h from Section 3.4) ---
# print("\nRunning Conceptual MCMC Example...")
# key, subkey1, subkey2 = jax.random.split(key, 3)
# N_walkers = 1024
# N_dims = 3 # For H atom (N=1)
# mcmc_steps_example = 100
# step_size_example = 0.5

# # Need initial positions for the walkers
# initial_positions = jax.random.normal(subkey1, shape=(N_walkers, N_dims)) # Batch for H atom

# # Use the exact H wavefunction for this example
# final_positions, accept_rate = metropolis_mcmc(
#     wavefunction_h, initial_positions, step_size_example, mcmc_steps_example, subkey2
# )

# print(f"MCMC finished.")
# print(f"Shape of final positions: {final_positions.shape}") # Should be (N_walkers, N_dims)
# print(f"Average acceptance rate: {accept_rate:.3f}")
# # Acceptance rate typically aimed for 0.3-0.6 by tuning step_size
# # final_positions now contains samples approximately from |wavefunction_h|^2
# print("-" * 20)
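
As a self-contained check of the algorithm in this section, the sketch below samples the hydrogen density $|\psi|^2 \propto e^{-2|\mathbf{r}|}$ and compares the sampled mean radius with the known value $\langle r \rangle = 1.5$ bohr. It is a simplified variant rather than the functions above: it uses `jax.lax.scan`, batches the walkers with array operations, and applies the acceptance test to log-probabilities ($\log u < \log|\psi(X')|^2 - \log|\psi(X_t)|^2$), which is equivalent to $u < A$ but less prone to underflow. Walker count, step size, and chain length are arbitrary choices.

```python
import jax
import jax.numpy as jnp

STEP_SIZE = 0.7  # proposal standard deviation (arbitrary choice)

def log_prob(x):
    # log|psi|^2 for psi = exp(-|r|), evaluated for a batch of walkers [B, 3].
    return -2.0 * jnp.linalg.norm(x, axis=-1)

def mh_step(carry, key):
    x, logp = carry
    key_proposal, key_accept = jax.random.split(key)
    # Symmetric Gaussian proposal for every walker at once.
    x_new = x + STEP_SIZE * jax.random.normal(key_proposal, x.shape)
    logp_new = log_prob(x_new)
    # Accept when log(u) < log A, i.e. u < A.
    u = jax.random.uniform(key_accept, logp.shape)
    accept = jnp.log(u) < logp_new - logp
    x = jnp.where(accept[:, None], x_new, x)
    logp = jnp.where(accept, logp_new, logp)
    return (x, logp), jnp.mean(accept)

key_init, key_chain = jax.random.split(jax.random.PRNGKey(0))
x0 = jax.random.normal(key_init, (4096, 3))   # 4096 walkers, one electron each
step_keys = jax.random.split(key_chain, 500)  # one key per MCMC step
(samples, _), accept_history = jax.lax.scan(mh_step, (x0, log_prob(x0)), step_keys)

mean_r = jnp.mean(jnp.linalg.norm(samples, axis=-1))
acceptance = jnp.mean(accept_history)
print(f"<r> = {mean_r:.3f} (exact 1.5), acceptance rate = {acceptance:.3f}")
```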

## 6. The Optimization: Gradient Descent

We want to minimize the loss $L(\theta) = \mathbb{E}_{X \sim |\psi_\theta|^2} [E_{local}(X; \theta)]$ with respect to the parameters $\theta$ (neural network weights/biases). We typically use gradient descent methods (like Adam). This requires calculating the gradient $\nabla_\theta L(\theta)$.

**The Challenge:** We cannot simply move the gradient inside the expectation and average the gradient of the local energy: $\mathbb{E}_{X \sim |\psi_\theta|^2} [\nabla_\theta E_{local}(X; \theta)]$ is a **biased (incorrect)** estimate of $\nabla_\theta L(\theta)$. The reason is that the sampling distribution $|\psi_\theta|^2$ *itself* depends on the parameters $\theta$: changing $\theta$ changes where we sample $X$, and that shift also changes the average energy. The same situation arises with the REINFORCE rule (score function estimator) in reinforcement learning.

**The Solution (Unbiased Gradient Estimator):** Through calculus (using the log-derivative trick, derived in the Appendix of the original blog post), one can derive an unbiased estimator for the gradient:

$$
\nabla_\theta L(\theta) = 2 \mathbb{E}_{X \sim |\psi_\theta|^2} \left[ \nabla_\theta \log |\psi_\theta(X)| \left( E_{local}(X; \theta) - L(\theta) \right) \right]
$$

Let's break this down:

* **$\mathbb{E}_{X \sim |\psi_\theta|^2}[...]$:** We estimate this as an average over samples $X_i$ drawn using Metropolis-Hastings from $|\psi_\theta|^2$.
* **$\nabla_\theta \log |\psi_\theta(X)|$:** This is the gradient of the logarithm of the (absolute value of the) wavefunction with respect to the parameters $\theta$ (often called the "score function"). For neural networks, this can be computed efficiently using automatic differentiation (backpropagation).
* **$E_{local}(X; \theta)$:** The local energy of the configuration $X$.
* **$L(\theta)$:** The current estimate of the average energy (loss), approximated by the sample mean $\frac{1}{M} \sum_i E_{local}(X_i; \theta)$.
* **$(E_{local}(X; \theta) - L(\theta))$:** The difference between the local energy of a specific sample and the current average energy (often called the "centered" local energy or baseline subtraction). This variance reduction technique is crucial for stabilizing training.

**Intuition:** The gradient pushes the parameters $\theta$ in a direction that:
* Increases the probability (increases $\log |\psi_\theta(X)|$) of configurations $X$ where the local energy $E_{local}(X)$ is *lower* than the current average $L(\theta)$.
* Decreases the probability of configurations where the local energy is *higher* than the average.

This systematically adjusts the wavefunction $\psi_\theta$ to concentrate probability density in low-energy regions, thus lowering the overall average energy $L(\theta)$.
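
A small numerical check of this estimator, again on the hydrogen trial function $\psi_\alpha = e^{-\alpha r}$ with a single parameter $\theta = \alpha$ (a toy setup, not from the original post). Here $\nabla_\alpha \log|\psi_\alpha| = -r$, $E_{local} = -\alpha^2/2 + (\alpha-1)/r$, and the exact gradient is $dL/d\alpha = \alpha - 1$. Samples are drawn directly from $|\psi_\alpha|^2$ with `jax.random.gamma` rather than by Metropolis-Hastings. The naive average of $\nabla_\alpha E_{local}$ is included to show the bias.

```python
import jax
import jax.numpy as jnp

alpha = 0.8
n_samples = 200_000
key = jax.random.PRNGKey(1)

# r ~ |psi_alpha|^2: radial density r^2 * exp(-2 * alpha * r) = Gamma(3, rate 2 * alpha).
r = jax.random.gamma(key, 3.0, shape=(n_samples,)) / (2.0 * alpha)

def local_energy_h(a, radius):
    # Closed-form local energy of psi = exp(-a * r) for the hydrogen atom.
    return -0.5 * a**2 + (a - 1.0) / radius

e_local = local_energy_h(alpha, r)
loss = jnp.mean(e_local)

# Estimator from the equation above: 2 * E[ grad log|psi| * (E_local - L) ],
# with grad log|psi| = d/d(alpha) (-alpha * r) = -r.
grad_log_psi = -r
grad_vmc = 2.0 * jnp.mean(grad_log_psi * (e_local - loss))

# Naive estimator: average of d E_local / d alpha over the same fixed samples.
d_elocal = jax.vmap(jax.grad(local_energy_h), in_axes=(None, 0))(alpha, r)
grad_naive = jnp.mean(d_elocal)

grad_exact = alpha - 1.0
print(f"VMC: {grad_vmc:.4f}, naive: {grad_naive:.4f}, exact: {grad_exact:.4f}")
```

For this trial function the naive average converges to 0 for every $\alpha$, so it would never move the parameter, while the centered estimator recovers $\alpha - 1$ up to sampling noise.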

**Implementation using Custom JVP:**
The original blog post uses `equinox.filter_custom_jvp` to define a function (`forward_loss_calculation` below) that calculates the loss $L(\theta)$ but whose gradient calculation rule (Jacobian-vector product, JVP) is explicitly defined to compute the unbiased estimator above. This leverages JAX's automatic differentiation framework while ensuring the correct gradient calculation specifically for VMC.
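
The mechanism can be seen in isolation with plain `jax.custom_jvp`, the counterpart of `eqx.filter_custom_jvp` for ordinary array arguments. In the toy sketch below (the function `square_with_clipped_grad` and its clipped derivative rule are hypothetical, chosen only so that the override is visible), the forward value is $x^2$, but JAX reports whatever derivative the custom rule defines.

```python
import jax
import jax.numpy as jnp

@jax.custom_jvp
def square_with_clipped_grad(x):
    return x**2

@square_with_clipped_grad.defjvp
def square_with_clipped_grad_jvp(primals, tangents):
    (x,) = primals
    (x_dot,) = tangents
    primal_out = square_with_clipped_grad(x)
    # Custom rule: the true derivative 2x, clipped to [-1, 1].
    # The tangent output must be linear in x_dot so JAX can transpose it for grad.
    tangent_out = jnp.clip(2.0 * x, -1.0, 1.0) * x_dot
    return primal_out, tangent_out

value = square_with_clipped_grad(3.0)
grad = jax.grad(square_with_clipped_grad)(3.0)
print(f"value = {float(value)}, grad = {float(grad)}")
```

The VMC loss follows the same pattern: the primal output is the mean local energy, and the tangent output is replaced by the estimator $2\,\mathbb{E}[\,d(\log|\psi_\theta|)\,(E_{local} - L)]$.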

In [None]:
def make_loss_and_grad(atoms: jax.Array, charges: jax.Array):
    """
    Creates functions to compute the VMC loss and its gradient estimator.

    Args:
        atoms: [M, 3] atomic positions.
        charges: [M] atomic charges.

    Returns:
        A function `calculate_loss_and_grad(wavefunction, pos)` that computes
        the loss, auxiliary local energies, and the VMC gradient estimator.
    """
    # Pre-define a vectorized version of local_energy calculation
    # It applies local_energy over the batch dimension (axis 0) of pos
    batch_local_energy = jax.vmap(local_energy, in_axes=(None, None, None, 0))

    # Define the forward loss function L(theta) = E[E_local]
    # We use a custom JVP rule to define its gradient correctly for VMC.
    @eqx.filter_custom_jvp
    def forward_loss_calculation(wavefunction: eqx.Module, pos: jax.Array):
        """
        Calculates the VMC loss L(theta) = mean(E_local) over the batch 'pos'.
        Also returns the E_local values as auxiliary output (needed for gradient).
        Args:
            wavefunction: The parameterized wavefunction model (an eqx.Module).
            pos: Batch of electron configurations [batch_size, 3N].
        Returns:
            loss: The estimated energy E = mean(E_local).
            local_energies: The local energy for each configuration in the batch.
        """
        # Calculate local energy for each configuration in the batch
        e_l = batch_local_energy(wavefunction, atoms, charges, pos)
        # The loss is the mean of the local energies
        loss = jnp.mean(e_l)
        # Return loss and the individual local energies (needed for gradient)
        return loss, e_l

    # Define the custom Jacobian-Vector Product (JVP) for the loss function.
    # The JVP defines how the output changes when the input primals (wavefunction params)
    # are perturbed along the direction of the tangents. For gradient calculation,
    # JAX uses the transpose of the JVP (the VJP - vector-Jacobian product).
    # Defining the JVP correctly here ensures `jax.grad` computes the right VMC gradient.
    @forward_loss_calculation.def_jvp
    def loss_jvp(primals, tangents):
        """
        Defines the custom gradient calculation (via JVP) for VMC loss.
        Implements: dL = 2 * E[ d(log|psi|) * (E_local - L) ] in JVP form.
        """
        # Unpack primals (inputs) and tangents (perturbations to inputs)
        wavefunction, pos = primals
        # Tangents are the perturbation directions: wavefunction_tangent holds the
        # direction for the wavefunction parameters. The tangent for pos is not needed.
        wavefunction_tangent, _ = tangents

        # --- Step 1: Calculate forward pass results ---
        # Calculate the loss and local energies using the original forward function
        loss, local_energy = forward_loss_calculation(wavefunction, pos)
        primals_out = loss, local_energy # Output of the forward pass

        # --- Step 2: Calculate the gradient term for the JVP ---
        # We need d(log|psi|) = (d psi / psi), representing the change in log|psi|
        # due to the parameter perturbation wavefunction_tangent.
        # We compute the JVP of log|psi| directly.

        # Define log|psi| function (add epsilon for numerical stability near psi=0)
        # Use vmap to apply log_wavefunction over the batch of positions.
        log_abs_wavefunction = lambda wf, p: jnp.log(jnp.abs(wf(p)) + 1e-8)
        batch_log_abs_wavefunction = jax.vmap(log_abs_wavefunction, in_axes=(None, 0))

        # Compute JVP for log|psi|:
        # filter_jvp calculates:
        #   primals_out_logpsi = log|psi|(wavefunction, pos)
        #   tangents_out_logpsi = d(log|psi|) = (grad_theta log|psi|) . wavefunction_tangent
        # We only need tangents_out_logpsi, which is the directional derivative.
        _, log_psi_tangent = eqx.filter_jvp(
            batch_log_abs_wavefunction,
            (wavefunction, pos), # Primals for log|psi|
            (wavefunction_tangent, None) # Tangents: None marks 'pos' as not differentiated
            # Note: eqx.filter(pos, eqx.is_inexact_array) would return 'pos' itself, i.e. a
            # nonzero tangent, which would add a spurious position-derivative term.
        )
        # log_psi_tangent now holds the batch of values for d(log|psi|) [shape: batch_size]

        # --- Step 3: Assemble the JVP output tangent for the loss ---
        # Calculate the VMC gradient contribution: 2 * E[ d(log|psi|) * (E_local - L) ]
        # In JVP terms, the output tangent (dL) is the dot product of the gradient
        # with the input tangent vector (wavefunction_tangent).
        # Here, log_psi_tangent *is* the dot product part related to log|psi|.
        batch_size = jnp.shape(local_energy)[0]
        centered_el = local_energy - loss
        # grad_term = 2.0 * jnp.mean(log_psi_tangent * centered_el)
        # Alternatively, using dot product as in the original blog:
        grad_term = 2.0 * jnp.dot(log_psi_tangent, centered_el) / batch_size

        # The tangent output corresponding to the loss is the computed grad_term.
        # The second output (local_energy) is auxiliary: with has_aux=True below it is
        # not differentiated, so its tangent never enters the gradient. Passing
        # local_energy through serves as a same-shaped placeholder.
        tangents_out = (grad_term, local_energy)

        return primals_out, tangents_out

    # Return a function that can be used with jax.value_and_grad or eqx.filter_value_and_grad
    @eqx.filter_jit
    def calculate_loss_and_grad(wavefunction: eqx.Module, pos: jax.Array):
        """
        Calculates VMC loss and gradient using the custom JVP rule.

        Args:
            wavefunction: The eqx.Module representing the wavefunction.
            pos: Batch of electron configurations [batch_size, 3N].

        Returns:
            value: A tuple (loss, local_energies).
            grad: The gradient of the loss w.r.t. the wavefunction parameters (pytree).
        """
        # has_aux=True because forward_loss_calculation returns (loss, e_l)
        # filter_value_and_grad works like jax.value_and_grad but handles eqx modules.
        value, grad = eqx.filter_value_and_grad(forward_loss_calculation, has_aux=True)(wavefunction, pos)
        # value = (loss, e_l)
        # grad = pytree matching wavefunction parameters
        return value, grad

    return calculate_loss_and_grad

# --- Example usage (conceptual - needs a wavefunction model and samples) ---
# print("\nConceptual Loss/Grad Example:")
# # Assume we have:
# # my_nn_wavefunction = PsiMLPJastrow(...) # Initialized model
# # sampled_positions = final_positions # From MCMC step

# # Create the loss/grad function for our system (e.g., Hydrogen)
# loss_and_grad_fn = make_loss_and_grad(atoms_h, charges_h)

# # Calculate loss and gradients
# (loss, e_locals), grads = loss_and_grad_fn(my_nn_wavefunction, sampled_positions)

# print(f"Calculated Loss (Energy): {loss:.6f}")
# print("Gradients computed (structure shown representatively):")
# print(jax.tree_util.tree_map(lambda x: x.shape if hasattr(x, 'shape') else x, grads)) # Show shapes of gradients
# # 'grads' can now be used by an optimizer like optax.adam to update my_nn_wavefunction
# print("-" * 20)

## 7. The Model: Neural Wavefunction Properties

How do we design the neural network $\psi_\theta$? It needs to satisfy some fundamental physical principles, particularly for systems with multiple electrons (fermions).

### 7.1. Antisymmetry (Pauli Exclusion Principle)

Electrons are fermions. The total wavefunction must be **antisymmetric** with respect to the exchange of the coordinates (position $\mathbf{r}$ and spin $\sigma$) of any two identical fermions (e.g., two electrons with the same spin). Let $x_i = (\mathbf{r}_i, \sigma_i)$.

$$
\psi_\theta(x_1, ..., x_i, ..., x_j, ..., x_N) = - \psi_\theta(x_1, ..., x_j, ..., x_i, ..., x_N) \quad \text{if } \sigma_i = \sigma_j
$$

If $x_i = x_j$, then $\psi = -\psi$, which implies $\psi = 0$. This means two electrons with the same spin cannot occupy the same position (more generally, the same quantum state) – the **Pauli Exclusion Principle**.

**Implementation using Determinants:**
A standard way to enforce antisymmetry is using **Slater determinants**.
1.  Define single-particle orbitals $\phi_k(x_j)$, which are functions (outputs of a neural network component) evaluated for electron $j$. Typically, separate orbitals are used for spin-up and spin-down electrons.
2.  Construct matrices (one for spin-up, one for spin-down) where the element $(k, j)$ is $\phi_k(x_j)$ (for electrons $j$ of the respective spin).
3.  The determinant of such a matrix $\det[\Phi]$ (where $\Phi_{kj} = \phi_k(x_j)$) is inherently antisymmetric with respect to swapping the coordinates $x_i$ and $x_j$ (which corresponds to swapping columns $i$ and $j$ of the matrix).
4.  The overall wavefunction is often represented as a product of spin-up and spin-down determinants: $\psi = \det[\Phi^\uparrow] \times \det[\Phi^\downarrow]$.
5.  To increase expressiveness, a sum of multiple determinants ($\sum_k c_k \det[\Phi^\uparrow_k] \times \det[\Phi^\downarrow_k]$) can be used, where each $\Phi_k$ uses a different set of orbitals derived from the NN. The blog post uses a simpler sum-of-determinants form over *all* electrons for a single spin configuration (implicitly assuming fixed spin assignments).

The `PsiMLP` model in the blog post implements this general idea, using an MLP to generate features that are then combined into determinants.
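
The antisymmetry property itself can be checked numerically with a few lines of NumPy. This is a minimal sketch, not the blog's model: `toy_orbitals` and `toy_slater` are hypothetical helpers with hand-picked orbitals, and all three electrons are assumed to share the same spin.

```python
import numpy as np

def toy_orbitals(r):
    """Three hypothetical single-particle orbitals evaluated at one position r of shape [3]."""
    d = np.linalg.norm(r)
    return np.array([np.exp(-d), r[0] * np.exp(-0.5 * d), r[1] * np.exp(-0.5 * d)])

def toy_slater(positions):
    """Determinant of the matrix with entry (j, k) = phi_k(r_j), for positions of shape [N, 3]."""
    # Rows are electrons and columns are orbitals (the transpose of Phi_kj above;
    # the determinant is the same either way).
    matrix = np.stack([toy_orbitals(r) for r in positions])
    return np.linalg.det(matrix)

rng = np.random.default_rng(0)
toy_pos = rng.normal(size=(3, 3))              # three same-spin electrons

psi = toy_slater(toy_pos)
psi_swapped = toy_slater(toy_pos[[1, 0, 2]])   # exchange electrons 0 and 1

coincident = toy_pos.copy()
coincident[1] = coincident[0]                  # put electron 1 on top of electron 0
psi_coincident = toy_slater(coincident)

print(psi, psi_swapped, psi_coincident)
```

Exchanging two electrons swaps two rows of the matrix, so the determinant flips sign, and two identical rows make it vanish, which is the Pauli exclusion statement from above.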

In [None]:
class Linear(eqx.Module):
    """Standard Linear layer."""
    weights: jax.Array
    bias: jax.Array
    in_features: int = eqx.field(static=True)
    out_features: int = eqx.field(static=True)

    def __init__(self, in_size: int, out_size: int, key: jax.Array):
        # Initialize weights with Glorot/Xavier uniform initialization.
        # (The blog used a simple uniform initialization based on input/output size.)
        lim = math.sqrt(6.0 / (in_size + out_size))
        self.weights = jax.random.uniform(key, (in_size, out_size), minval=-lim, maxval=lim)
        self.bias = jnp.zeros(out_size)
        self.in_features = in_size
        self.out_features = out_size

    def __call__(self, x: jax.Array) -> jax.Array:
        return jnp.dot(x, self.weights) + self.bias

class PsiMLP(eqx.Module):
    """
    MLP-based wavefunction model using a sum of Slater determinants.
    Based on the implementation described in the blog post.
    Assumes a single atom at the origin for simplicity in feature calculation.
    Uses fixed spin assignments (first n_up are spin up, rest are spin down).
    """
    spins: tuple[int, int] = eqx.field(static=True) # (n_up, n_down)
    n_electrons: int = eqx.field(static=True)
    n_determinants: int = eqx.field(static=True)
    num_atoms: int = eqx.field(static=True) # Store num_atoms statically

    # Network layers and parameters
    linears: list[Linear]
    orbitals: Linear # Final layer producing orbital values
    sigma: jax.Array # Parameters for Slater-like decay envelope
    pi: jax.Array    # Parameters for Slater-like decay envelope

    def __init__(
        self,
        hidden_sizes: list[int], # List of hidden layer dimensions
        spins: tuple[int, int],  # (n_up, n_down) electrons
        determinants: int,       # Number of determinants to sum
        num_atoms: int,          # Number of atoms (fixed geometry assumed)
        key: jax.Array,
    ):
        self.spins = spins
        n_up, n_down = spins
        self.n_electrons = n_up + n_down
        self.n_determinants = determinants
        self.num_atoms = num_atoms # Store number of atoms

        # Input features per electron:
        # From the blog's code: ae (vector r_i - R_I), r_ae (distance |r_i - R_I|), spin
        # For simplicity here, matching the blog, we'll compute features relative to
        # ONLY the first atom (assumed origin if num_atoms=1).
        # A more general implementation (like FermiNet) uses features from all atoms.
        # Features: displacement (3), distance (1), spin (1) = 5 features
        feature_size = 3 + 1 + 1
        sizes = [feature_size] + hidden_sizes
        n_layers = len(sizes) - 1

        # Create MLP layers
        # Split the key: one per MLP layer, plus one each for orbitals, pi, and sigma.
        # n_layers + 3 keys have valid indices 0 .. n_layers + 2.
        layer_keys = jax.random.split(key, n_layers + 3)
        mlp_keys = layer_keys[:n_layers]
        orbital_key = layer_keys[n_layers]
        pi_key = layer_keys[n_layers + 1]
        sigma_key = layer_keys[n_layers + 2]

        self.linears = []
        for i in range(n_layers):
            self.linears.append(Linear(sizes[i], sizes[i+1], mlp_keys[i]))

        # Final layer outputs values used to construct orbitals for all determinants
        # Output size = n_electrons * n_determinants (for each electron: one value per orbital per determinant)
        final_layer_out_size = self.n_electrons * determinants
        self.orbitals = Linear(sizes[-1], final_layer_out_size, orbital_key)

        # Parameters for the Slater-like exponential decay envelope
        # Shape: [num_atoms, final_layer_out_size]. Initialize carefully.
        # Blog uses ones, which might be okay. Let's use small values.
        self.pi = jax.random.normal(pi_key, (num_atoms, final_layer_out_size)) * 0.01
        # Sigma should be positive for decay, initialize near zero or small positive.
        self.sigma = jnp.abs(jax.random.normal(sigma_key, (num_atoms, final_layer_out_size)) * 0.01) + 0.01


    def __call__(self, pos: jax.Array) -> jax.Array:
        """ Evaluates the wavefunction psi(pos) for a single configuration."""
        # Reshape flat position array [3N] -> [N, 3]
        pos_reshaped = pos.reshape(self.n_electrons, 3)

        # --- Calculate Input Features ---
        # Simplified: Features relative to the first atom only (like blog code)
        # A better model would use features relative to all atoms.
        atom_pos_origin = jnp.zeros((1, 3)) # Assume first atom is at origin
        ae = pos_reshaped[:, None, :] - atom_pos_origin[None, :, :] # Shape [N, 1, 3]
        r_ae = jnp.linalg.norm(ae, axis=2, keepdims=True) # Shape [N, 1, 1]

        # Feature for spins: +1 for up, -1 for down [N, 1]
        n_up, n_down = self.spins
        spin_feat = jnp.concatenate([
            jnp.ones(n_up),
            jnp.ones(n_down) * -1
        ])[:, None] # Reshape to [N, 1]

        # Combine features: [N, 5]
        h = jnp.concatenate([ae.squeeze(axis=1), r_ae.squeeze(axis=1), spin_feat], axis=1)
        assert h.shape == (self.n_electrons, 5) # 3 + 1 + 1 = 5

        # --- Pass through MLP ---
        # Multi-layer perceptron with tanh activations
        for layer in self.linears:
            h = jnp.tanh(layer(h)) # Shape: [N, hidden_sizes[-1]]

        # --- Construct Orbitals & Determinants ---
        # Final linear layer output: [N, n_electrons * n_determinants]
        final_out = self.orbitals(h)

        # Apply Slater-like exponential decay envelope (matching blog structure)
        # Shapes: final_out: [N, N*D], pi: [A, N*D], sigma: [A, N*D], r_ae: [N, 1, 1]
        # where N=n_electrons, D=n_determinants, A=num_atoms
        # Broadcasting: sigma [A, N*D] * r_ae [N, 1, 1] -> [N, A, N*D]; multiplying by
        # pi [A, N*D] keeps [N, A, N*D], and sum(..., axis=1) over atoms gives [N, N*D].
        # Caveat: r_ae here only holds distances to the first atom, so for num_atoms > 1
        # every atom's envelope term would use the same distance.
        envelope = jnp.sum(self.pi * jnp.exp(-self.sigma * r_ae), axis=1) # Sum over atom dimension
        phi = final_out * envelope # Element-wise: [N, N*D] * [N, N*D] -> [N, N*D]

        # Reshape phi to form determinant matrices:
        # phi: [N_electrons, N_electrons * N_determinants]
        # Reshape -> [N_determinants, N_electrons, N_electrons]
        # Blog code: reshape(N, D, N).transpose(1, 0, 2) -> [D, N, N]
        phi_reshaped = phi.reshape(self.n_electrons, self.n_determinants, self.n_electrons)
        det_matrices = phi_reshaped.transpose(1, 0, 2) # Shape [n_determinants, n_electrons, n_electrons]

        # Calculate the determinant of each matrix in the stack.
        # jnp.linalg.det expects [..., N, N] and returns [...], so it batches over
        # the leading determinant axis without needing vmap.
        # Matching the blog, we use the plain determinant (might not be stable enough!).
        # jnp.linalg.slogdet, which returns sign and log(|det|), is the more stable option.
        determinants = jnp.linalg.det(det_matrices) # Shape [n_determinants]


        # The wavefunction value is the sum of the determinants
        psi_value = jnp.sum(determinants)

        return psi_value

### 7.2. Finite Integral & Asymptotic Behavior

Since $|\psi_\theta(X)|^2$ represents a probability density, its integral over all space must be finite (so it can be normalized to 1). Physically, we expect the probability of finding an electron infinitely far from the nuclei to be zero.

**Implementation:** This is typically ensured by making $\psi_\theta(X) \to 0$ as any electron position $|\mathbf{r}_i| \to \infty$. The `PsiMLP` model attempts this by multiplying the MLP output by an exponential decay factor (the `envelope`) related to the electron-nucleus distances $r_{ae}$, inspired by Slater-type orbitals (STOs). The `exp(-sigma * r_ae)` term in the `envelope` calculation within the `PsiMLP.__call__` method is designed to enforce this decay. The learnable parameters `pi` and `sigma` control the shape of this decay.
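
As a small numerical illustration of why an exponential envelope is enough, the sketch below integrates the radial norm of a single hypothetical orbital $\psi(r) = e^{-\sigma r}$ and compares it with the closed form $\int_0^\infty r^2 e^{-2\sigma r}\,dr = 1/(4\sigma^3)$. The value of `sigma` and the grid are arbitrary assumptions.

```python
import numpy as np

sigma = 1.0                                   # hypothetical decay rate
r = np.linspace(0.0, 40.0, 400_001)           # radial grid; exp(-80) is negligible at the cutoff
dr = r[1] - r[0]

# Radial part of the norm integral: r^2 |psi(r)|^2 with psi(r) = exp(-sigma r)
integrand = r**2 * np.exp(-2.0 * sigma * r)
radial_norm = np.sum(integrand) * dr          # simple Riemann sum
analytic = 1.0 / (4.0 * sigma**3)

print(radial_norm, analytic)
```

The integral is finite for any `sigma > 0`, which is why the initialization in `PsiMLP` keeps `sigma` positive.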

## 8. The Improvement: Jastrow Factor

The basic Slater determinant construction (even with an MLP and envelope) often struggles to capture the intricate correlations between electrons caused by their mutual repulsion ($1/|\mathbf{r}_i - \mathbf{r}_j|$ term in $\hat{H}$). Electrons tend to avoid each other (Coulomb hole), and this affects the wavefunction's shape, particularly when two electrons get close.

The **Jastrow factor**, $e^{J(X)}$, is a symmetric function of electron coordinates multiplied onto the determinantal part to explicitly introduce these correlations:

$$
\psi(X) = e^{J(X)} \times \psi_{\text{determinant}}(X)
$$

* It's **symmetric** ($J$ doesn't change if $x_i$ and $x_j$ are swapped), so it doesn't break the required antisymmetry of the determinant part.
* It typically depends on inter-particle distances ($|\mathbf{r}_i - \mathbf{r}_j|$, $|\mathbf{r}_i - \mathbf{R}_I|$).
* It helps enforce **cusp conditions**: the wavefunction should have specific non-analytic behavior (a "cusp") when two charged particles meet, related to the divergence of the Coulomb potential. Satisfying these conditions significantly improves accuracy.

The blog uses the **Padé-Jastrow** form which depends only on electron-electron distances $r_{ij} = |\mathbf{r}_i - \mathbf{r}_j|$ and incorporates the electron-electron cusp conditions:

$$
J(X) = \sum_{i<j} \frac{\alpha_{ij} r_{ij}}{1 + \beta r_{ij}}
$$

* $\alpha_{ij}$ depends on whether electrons $i$ and $j$ have the same spin ($\alpha = 1/4$ for parallel spins) or opposite spins ($\alpha = 1/2$ for antiparallel spins). These values ensure the wavefunction satisfies the correct cusp condition when $r_{ij} \to 0$.
* $\beta$ is a single additional learnable parameter that controls the range of the correlation effect.
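
The role of the two parameters can be read off from a single term of the sum. Differentiating with respect to $r_{ij}$ and taking the two limits gives (a short derivation using only the formula above):

$$
\frac{\partial}{\partial r_{ij}} \left( \frac{\alpha_{ij} r_{ij}}{1 + \beta r_{ij}} \right) = \frac{\alpha_{ij}}{(1 + \beta r_{ij})^2} \xrightarrow{\; r_{ij} \to 0 \;} \alpha_{ij}, \qquad \frac{\alpha_{ij} r_{ij}}{1 + \beta r_{ij}} \xrightarrow{\; r_{ij} \to \infty \;} \frac{\alpha_{ij}}{\beta}
$$

So the slope at coalescence is fixed to $\alpha_{ij}$ independently of $\beta$, while $\beta$ sets how quickly each pair term saturates at $\alpha_{ij}/\beta$.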

**Implementation:** We extend the `PsiMLP` class to include this factor.


In [None]:
class PsiMLPJastrow(PsiMLP):
    """ PsiMLP model extended with a Padé-Jastrow factor for e-e correlations."""

    beta: jax.Array # Single learnable parameter for the Jastrow factor

    def __init__(self, *args, **kwargs):
        # Initialize the base PsiMLP part. All arguments, including the PRNG key
        # (whether passed positionally or as key=...), are forwarded unchanged.
        super().__init__(*args, **kwargs)

        # Initialize the Jastrow parameter beta. It should be positive.
        # beta is set deterministically, so no extra PRNG key is needed.
        self.beta = jnp.array(1.0) # As used in blog post example

    def __call__(self, pos: jax.Array) -> jax.Array:
        """Evaluates psi(pos) = exp(J(pos)) * PsiMLP(pos)."""
        # 1. Calculate the determinant part using the parent PsiMLP's __call__
        # Note: super().__call__(pos) calls the PsiMLP.__call__ method
        det_sum = super().__call__(pos)

        # 2. Calculate the Jastrow factor exponent J(X)
        pos_reshaped = pos.reshape(self.n_electrons, 3)
        n_up = self.spins[0]

        if self.n_electrons > 1:
            i_indices, j_indices = jnp.triu_indices(self.n_electrons, k=1) # Pairs (i<j)

            # Electron-electron distances |r_i - r_j| for all pairs
            r_ee = jnp.linalg.norm(pos_reshaped[i_indices] - pos_reshaped[j_indices], axis=1)

            # Determine alpha based on spins of electron pairs
            # alpha = 1/4 for same spin, 1/2 for opposite spin
            # Check if both i and j are spin up OR both are spin down
            same_spin_condition = ((i_indices < n_up) & (j_indices < n_up)) | \
                                  ((i_indices >= n_up) & (j_indices >= n_up))
            alpha = jnp.where(same_spin_condition, 0.25, 0.5)

            # Calculate the sum in the Jastrow exponent
            # J = sum_{i<j} [ alpha * r_ee / (1 + beta * r_ee) ]
            # Add epsilon to denominator for stability if beta*r_ee can be near -1 (beta should be positive)
            epsilon_jastrow = 1e-8
            jastrow_exponent = jnp.sum(alpha * r_ee / (1.0 + self.beta * r_ee + epsilon_jastrow))
        else:
            jastrow_exponent = 0.0 # No Jastrow term for single electron

        # The Jastrow factor is exp(J)
        jastrow_factor = jnp.exp(jastrow_exponent)

        # 3. Multiply determinant sum by Jastrow factor
        psi_value = det_sum * jastrow_factor

        return psi_value

Adding this relatively simple Jastrow factor significantly improves accuracy by explicitly modeling electron correlation and satisfying cusp conditions, allowing the model to reach energies much closer to the true ground state, as demonstrated for the Lithium atom in the original blog post.
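
The symmetry claim can be checked directly. The sketch below re-implements the Padé-Jastrow exponent in NumPy under the same fixed spin assignment as `PsiMLPJastrow` (first `n_up` electrons are spin up); `toy_pade_jastrow` and the chosen positions are illustrative assumptions.

```python
import numpy as np

def toy_pade_jastrow(pos, n_up, beta=1.0):
    """Pade-Jastrow exponent J for positions [N, 3]; the first n_up electrons are spin up."""
    n = pos.shape[0]
    i, j = np.triu_indices(n, k=1)                        # all pairs i < j
    r = np.linalg.norm(pos[i] - pos[j], axis=1)           # pair distances
    same_spin = ((i < n_up) & (j < n_up)) | ((i >= n_up) & (j >= n_up))
    alpha = np.where(same_spin, 0.25, 0.5)
    return np.sum(alpha * r / (1.0 + beta * r))

# Lithium-like layout: electrons 0 and 1 are spin up, electron 2 is spin down
toy_pos = np.array([[0.0, 0.0, 0.0],
                    [1.0, 0.0, 0.0],
                    [0.0, 2.0, 0.0]])

j_ref = toy_pade_jastrow(toy_pos, n_up=2)
j_same_spin_swap = toy_pade_jastrow(toy_pos[[1, 0, 2]], n_up=2)   # swap the two up electrons
j_position_only_swap = toy_pade_jastrow(toy_pos[[0, 2, 1]], n_up=2)  # swap positions of an up/down pair

print(j_ref, j_same_spin_swap, j_position_only_swap)
```

Swapping two same-spin electrons leaves $J$ unchanged, so $e^{J}$ does not disturb the antisymmetry of the determinant part. Swapping only the positions of an up/down pair is not an exchange of the full coordinates $x = (\mathbf{r}, \sigma)$ under fixed spin labels, so $J$ is not required to be invariant there.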


## 9. Training Overview

The complete VMC process involves these steps, typically run iteratively:

1.  **Initialization:**
    * Initialize the neural network wavefunction parameters $\theta$ (e.g., `PsiMLPJastrow` model).
    * Initialize an optimizer (e.g., `optax.adam`).
    * Initialize a batch of random electron configurations $X_0$ (walkers), often sampled from a simple distribution (like Gaussian clouds around atoms).

2.  **Warm-up MCMC:** Run the Metropolis-Hastings algorithm (`metropolis_mcmc`) for a number of steps (`warmup_steps`) starting from $X_0$ using the initial $\psi_\theta$ to let the walkers "thermalize" and reach regions where $|\psi_\theta|^2$ is significant. Discard these initial steps. Let the positions after warmup be $X_{warm}$.

3.  **Training Loop** (for `num_iterations`):
    * **a. MCMC Sampling:** Run Metropolis-Hastings (`metropolis_mcmc`) starting from the previous step's positions $X_{t-1}$ for `mcmc_steps` using the *current* wavefunction $\psi_\theta$ to generate a batch of configurations $X_t$ sampled approximately from $|\psi_\theta|^2$. Keep track of the acceptance rate.
    * **b. Loss & Gradient Calculation:** Compute the loss $L(\theta)$ (average energy) and its unbiased gradient $\nabla_\theta L(\theta)$ using the sampled configurations $X_t$ and the `make_loss_and_grad` function (which uses `local_energy` and the custom VMC gradient rule). This involves calculating $E_{local}(X_t; \theta)$ for each walker and applying the gradient formula from Section 6.
    * **c. Parameter Update:** Update the wavefunction parameters $\theta$ using the optimizer (e.g., Adam) and the calculated gradient: $\theta_{new} \leftarrow \text{optimizer_update}(\nabla_\theta L(\theta), \theta, \text{optimizer_state})$. Update the optimizer state.
    * **d. Logging:** Record the energy (loss), acceptance rate, and potentially other metrics.
    * **e. Repeat:** Go back to step (a) with the updated $\theta_{new}$ and the latest walker positions $X_t$ as the starting point for the next MCMC run.

4.  **Convergence:** Continue the training loop until the energy $L(\theta)$ converges to a stable value (hopefully close to the true ground state energy $E_0$).

The `vmc` function in the original `blog.py` script orchestrates this entire process, including setting up the Adam optimizer (`optax.adam`) and managing the training state (model parameters, optimizer state, walker positions, random key).
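
To see the four stages end to end without the full notebook machinery, here is a self-contained toy version. It is a sketch under stated assumptions, not the blog's `vmc` function: it uses NumPy instead of JAX, plain gradient descent instead of `optax.adam`, and a 1D harmonic oscillator with the one-parameter trial wavefunction $\psi_a(x) = e^{-a x^2}$, whose exact optimum is $a = 0.5$ with energy $0.5$. All `toy_*` names are hypothetical.

```python
import numpy as np

def toy_log_psi(a, x):
    return -a * x**2                      # log of psi_a(x) = exp(-a x^2)

def toy_local_energy(a, x):
    return a + x**2 * (0.5 - 2.0 * a**2)  # (H psi)/psi for H = -1/2 d^2/dx^2 + 1/2 x^2

def toy_metropolis(a, x, rng, steps=20, width=1.0):
    """Random-walk Metropolis targeting |psi_a|^2, applied to all walkers at once."""
    for _ in range(steps):
        proposal = x + width * rng.normal(size=x.shape)
        log_ratio = 2.0 * (toy_log_psi(a, proposal) - toy_log_psi(a, x))
        accept = np.log(rng.uniform(size=x.shape)) < log_ratio
        x = np.where(accept, proposal, x)
    return x

rng = np.random.default_rng(0)
a = 1.2                                                  # 1. initial parameter
walkers = rng.normal(size=4096)                          # 1. initial walkers
walkers = toy_metropolis(a, walkers, rng, steps=100)     # 2. warm-up MCMC

learning_rate = 0.1
for iteration in range(200):                             # 3. training loop
    walkers = toy_metropolis(a, walkers, rng)            # 3a. sample from |psi_a|^2
    e_l = toy_local_energy(a, walkers)                   # 3b. local energies and loss
    energy = e_l.mean()
    # 3b. VMC gradient estimator; here d(log psi)/da = -x^2
    grad = 2.0 * np.mean((e_l - energy) * (-walkers**2))
    a = a - learning_rate * grad                         # 3c. parameter update

print(f"a = {a:.4f}, energy = {energy:.4f}")             # 4. inspect convergence
```

Because the exact ground state lies inside this one-parameter family, the local energy becomes constant at the optimum and the gradient noise vanishes there, so even plain gradient descent settles cleanly. For the neural wavefunctions in this notebook that is not the case, which is one reason an adaptive optimizer such as Adam is used instead.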

## Conclusion

This notebook-style walkthrough covered the core mathematical and computational components of the Neural Variational Monte Carlo method as presented in Teddy Koker's blog post:

* **Goal:** Solve the Schrödinger equation's ground state using the **Variational Principle**.
* **Method:** Minimize the expectation value of the **Local Energy** $E_{local} = \hat{H}\psi / \psi$, which provides an upper bound to the true energy.
* **Sampling:** Use **Metropolis-Hastings MCMC** to sample electron configurations $X$ from the probability distribution $|\psi_\theta(X)|^2$ defined by the trial wavefunction.
* **Model:** Parameterize the trial wavefunction $\psi_\theta$ using a **Neural Network** (e.g., `PsiMLP` + `Jastrow`), ensuring **antisymmetry** via Slater determinants and incorporating physics like **cusp conditions** and **asymptotic decay**.
* **Optimization:** Use **gradient descent** with a specific **unbiased gradient estimator** for the VMC loss, calculated efficiently using automatic differentiation with a custom gradient rule (`filter_custom_jvp`).

By combining these elements – variational optimization, Monte Carlo sampling, expressive neural network ansätze respecting physics, and correct gradient estimation – Neural VMC provides a powerful and increasingly popular framework for finding accurate solutions to quantum many-body problems in chemistry and physics.
