# Rayleigh-Gauss-Newton optimization: implementing the necessary elements

In this notebook, we use NetKet to implement the Rayleigh-Gauss-Newton (RGN) optimization method presented in the paper by Robert J. Webber and Michael Lindsey.

Recall the following equation:
$$
\Delta \vec{\theta} = P^{-1} \vec{F}.
$$
For SR, we went up to first order in our expansion. If we go to second order, we retrieve the Rayleigh-Gauss-Newton method:
$$
\begin{array}{|c|c|}
\hline
\mathrm{Method} & \mathrm{Preconditioner \,} P\\
\hline 
\mathrm{Gradient \, descent} & \epsilon^{-1} I \\
\mathrm{Natural \, gradient \, descent} & \epsilon^{-1}(S + \eta I) \\
\mathrm{Rayleigh-Gauss-Newton} & H + \epsilon^{-1}(S + \eta I) \\
\hline
\end{array}
$$

$H$ is the Hessian. However, computing the Hessian can be expensive, and thus we approximate it by: 
$$
\hat{H}_{ij} = \mathrm{Cov}_{\hat{\rho}}[\nu_i(\sigma), E_{L,j}(\sigma)] - \hat{g}_i \mathbb{E}_{\hat{\rho}}[\nu_j (\sigma)] - \hat{\mathcal{E}} \hat{S}_{ij}
$$

where:
$$
\hat{\mathcal{E}} = \mathbb{E}_{\hat{\rho}}[E_L(\sigma)]
$$
$$
F_i = \hat{g}_i = \mathrm{Cov}_{\hat{\rho}}[\nu_i(\sigma), E_L(\sigma)] = \left\langle \left(\nu_i(\sigma) - \langle \nu_i(\sigma) \rangle\right) E_L(\sigma) \right\rangle = \left\langle \left( \frac{\partial_{\theta_i} \braket{\sigma|\psi}}{\braket{\sigma|\psi}} - \left\langle \frac{\partial_{\theta_i} \braket{\sigma|\psi}}{\braket{\sigma|\psi}} \right\rangle \right) \frac{\braket{\sigma|\hat{H}|\psi}}{\braket{\sigma|\psi}} \right\rangle = \left\langle \frac{\partial_{\theta_i} \braket{\sigma|\psi}}{\braket{\sigma|\psi}} \left( \frac{\braket{\sigma|\hat{H}|\psi}}{\braket{\sigma|\psi}} - \left\langle \frac{\braket{\sigma|\hat{H}|\psi}}{\braket{\sigma|\psi}} \right\rangle \right) \right\rangle
$$
$$
\hat{S}_{ij} = \mathrm{Cov}_{\hat{\rho}}[\nu_i(\sigma), \nu_j(\sigma)]
$$

Let us start implementing everything from scratch. We first import the necessary libraries, define the system and the Hamiltonian (that we will take to be $H = \sum_{i = 1}^L S_i^z S_{i + 1}^z + \Delta \sum_{i = 1}^L \left( S_i^x S_{i + 1}^x + S_i^y S_{i + 1}^y\right)$ with $\Delta = -1$).

In [1]:
# Import netket library
import netket as nk

# Import Json, this will be needed to load log files
import json

import os

# Helper libraries
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import time

# jax and jax.numpy
import jax 
import jax.numpy as jnp

# Flax is a framework to define models using JAX
import flax
# we refer to `flax.linen` as `nn`. It's a repository of 
# layers, initializers and nonlinear functions.
import flax.linen as nn

from tqdm.auto import tqdm

  from .autonotebook import tqdm as notebook_tqdm


We first define the system, the model and the Hamiltonian, which we will take to be $H = \sum_{i = 1}^L S_i^z S_{i + 1}^z + \Delta \sum_{i = 1}^L \left( S_i^x S_{i + 1}^x + S_i^y S_{i + 1}^y\right)$. We will take $\Delta = -1$.

In [2]:
# Number of sites of the chain
L = 10

# Define a 1d chain with periodic boundary conditions
g = nk.graph.Hypercube(length=L, n_dim=1, pbc=True)
    
# Define the spin-1/2 Hilbert space with one spin per node of the graph.
# NOTE: no `total_sz` argument is passed, so the total magnetization is NOT
# constrained here.
hi = nk.hilbert.Spin(s=1/2, N=g.n_nodes)
# Initialize the model (an RBM with complex parameters)
model_RBM = nk.models.RBM(param_dtype = complex)

# Define the local Metropolis sampler on the Hilbert space
sampler = nk.sampler.MetropolisLocal(hi)

# Create the Monte Carlo variational state
vstate = nk.vqs.MCState(sampler, model_RBM, n_samples=300)

# Apply function of the model (private attribute of the variational state)
psi_apply_fun = vstate._apply_fun
# Function returning log psi for given configurations
logpsi = vstate.log_value

# Parameters of the model (as a pytree)
parameters = vstate.parameters
# parameters = jax.tree_util.tree_map(lambda x: x.real, parameters)

# Flatten the parameter pytree into one vector; the second return value maps a
# flat vector back to the pytree structure.
parameters_concatenated, unconcatenate_function = nk.jax.tree_ravel(parameters)

# Full variable collection; compute_E_sigma_i reads variables["params"] from it
variables = vstate.variables

# Get the samples and flatten every leading axis so that each row is one configuration
samples = vstate.samples
samples_2d = samples.reshape(-1, samples.shape[-1])

def generate_hamiltonian(Delta):
    """Build the anisotropic nearest-neighbour Hamiltonian on the periodic chain.

    Starting from an empty ``LocalOperator`` on the global Hilbert space
    ``hi``, each bond (i, i+1 mod L) contributes

        sigma^z_i sigma^z_j + Delta * (sigma^x_i sigma^x_j + sigma^y_i sigma^y_j)

    Note that the terms are built from the Pauli operators returned by
    ``nk.operator.spin.sigma{x,y,z}``.

    Args:
        Delta: coefficient multiplying the XX + YY part of every bond.

    Returns:
        The accumulated operator.
    """
    sigmax = nk.operator.spin.sigmax
    sigmay = nk.operator.spin.sigmay
    sigmaz = nk.operator.spin.sigmaz

    hamiltonian = nk.operator.LocalOperator(hi)

    for site in range(L):
        # Periodic boundary: the last site couples back to site 0
        neighbour = (site + 1) % L

        # ZZ coupling on this bond
        hamiltonian = hamiltonian + sigmaz(hi, site) @ sigmaz(hi, neighbour)

        # Delta-weighted XX + YY coupling on this bond
        hamiltonian = hamiltonian + Delta * (
            sigmax(hi, site) @ sigmax(hi, neighbour)
            + sigmay(hi, site) @ sigmay(hi, neighbour)
        )

    return hamiltonian

# Anisotropy used for the checks in the following cells
delta = -1
hamiltonian = generate_hamiltonian(delta)

# Convert the Hamiltonian to a JAX operator (via its Pauli-string form); the
# functions below call get_conn_padded on it.
hamiltonian_jax = hamiltonian.to_pauli_strings().to_jax_operator()

def generate_hamiltonian_tfi(h_var):
    """Build a transverse-field-Ising-type operator on the periodic chain.

    Starting from an empty ``LocalOperator`` on the global Hilbert space
    ``hi``, every site ``i`` subtracts ``sigma^z_i sigma^z_{i+1}`` (neighbour
    index taken modulo ``L``) and ``h_var * sigma^x_i``.

    Args:
        h_var: coefficient of the single-site sigma^x term.

    Returns:
        The accumulated operator.
    """
    sigmax = nk.operator.spin.sigmax
    sigmaz = nk.operator.spin.sigmaz

    tfi = nk.operator.LocalOperator(hi)

    for site in range(L):
        # Periodic boundary: the last site couples back to site 0
        neighbour = (site + 1) % L

        # -sigma^z_i sigma^z_j coupling on this bond
        tfi = tfi - sigmaz(hi, site) @ sigmaz(hi, neighbour)

        # -h * sigma^x_i transverse-field term on this site
        tfi = tfi - h_var * sigmax(hi, site)

    return tfi

  self.n_samples = n_samples


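Let us compute $S$, $\vec{F}$ and $\hat{\mathcal{E}} = \mathbb{E}_{\hat{\rho}}[E_L(\sigma)] = \mathbb{E}_{\hat{\rho}}\left[\frac{(\hat{H} \psi)(\sigma)}{\psi(\sigma)}\right] \approx \frac{1}{N_S} \sum_{\sigma} \sum_{\eta \,:\, \hat{H}_{\sigma \eta} \neq 0} \hat{H}_{\sigma \eta} \frac{\psi(\eta)}{\psi(\sigma)}$, where the outer sum runs over the $N_S$ samples and the inner sum over the configurations $\eta$ connected to $\sigma$ by $\hat{H}$.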

In [3]:
# Dense Jacobian of the model output with respect to the parameters, evaluated
# at every sample with mode='holomorphic'. The cells below treat axis 0 as the
# sample index and axis 1 as the (flattened) parameter index.
jacobian = nk.jax.jacobian(psi_apply_fun, parameters, samples_2d, mode='holomorphic', dense=True)

############################################################################### Compute S

def compute_S(jacobian_var, samples_2d_var):
    """Sample covariance matrix of the log-derivatives.

    Args:
        jacobian_var: array with one row per sample and one column per
            parameter.
        samples_2d_var: the samples; only its number of rows is used, as the
            normalisation count.

    Returns:
        ``centered^H @ centered / n_samples`` where ``centered`` is the
        Jacobian with its per-column mean removed.
    """
    n_samples = samples_2d_var.shape[0]

    # Remove the per-parameter mean so that S is a covariance, not a raw moment
    column_mean = np.sum(jacobian_var, axis=0) / n_samples
    centered = jacobian_var - column_mean

    gram = centered.conj().T @ centered
    gram /= n_samples
    return gram

# Covariance matrix of the log-derivatives for the current samples
S = compute_S(jacobian, samples_2d) # Checked against qgt = nk.optimizer.qgt.QGTJacobianDense(vstate).to_dense()

############################################################################### Compute E_loc

def compute_E_loc(vstate_log_value, hamiltonian_jax_var, samples_2d_var):
    """Local energies E_loc(sigma) = sum_eta H[sigma, eta] * psi(eta) / psi(sigma).

    Args:
        vstate_log_value: callable returning log psi for a batch of
            configurations. It is applied both to the samples and to the
            connected configurations.
        hamiltonian_jax_var: operator exposing ``get_conn_padded``; it is
            called one sample at a time through ``jax.vmap``.
        samples_2d_var: the samples, one configuration per row.

    Returns:
        One local energy per row of ``samples_2d_var``.
    """
    # Connected configurations eta and matrix elements H[sigma, eta], per sample
    eta, H_sigmaeta = jax.vmap(hamiltonian_jax_var.get_conn_padded)(samples_2d_var)

    logpsi_sigma = vstate_log_value(samples_2d_var)
    # Use the callable that was passed in. Reading the global `vstate` here
    # would silently evaluate a different state than the one requested.
    logpsi_eta = vstate_log_value(eta)

    # Add a trailing axis so log psi(sigma) broadcasts against the connected axis
    amplitude_ratio = jnp.exp(logpsi_eta - logpsi_sigma[:, None])

    # Sum the weighted ratios over the connected configurations
    return jnp.sum(H_sigmaeta * amplitude_ratio, axis=1)

# Local energy of every sample
E_loc = compute_E_loc(logpsi, hamiltonian_jax, samples_2d) # Checked against E_loc_exact = vstate.local_estimators(hamiltonian_jax).reshape(304, )

############################################################################### Compute F

def compute_F(jacobian, E_loc):
    """Energy gradient estimate F = 2 * mean(conj(centered Jacobian) * E_loc).

    Args:
        jacobian: array with one row per sample and one column per parameter.
        E_loc: one local energy per sample (same ordering as the rows of
            ``jacobian``).

    Returns:
        Vector with one entry per parameter.
    """
    # Normalise by the number of rows actually passed in. The original read the
    # global `samples_2d`, which breaks whenever it is stale or undefined.
    n_samples = jacobian.shape[0]

    O_mean = np.sum(jacobian, axis=0) / n_samples
    delta_J = jacobian - O_mean

    # Conjugate-transpose contraction over the sample axis
    F_w = delta_J.conj().T @ E_loc
    return (F_w * 2) / n_samples

# Gradient (force) vector for the current samples
F = compute_F(jacobian, E_loc) # Checked against E_exact , F_exact = vstate.expect_and_grad(hamiltonian_jax)

############################################################################### Compute mathcal_E

def compute_mathcal_E(E_loc_var):
    """Average the local energies over their leading axis.

    For a 1-D vector of local energies the result is a scalar array.
    """
    energy_estimate = jnp.mean(E_loc_var, axis=0)
    return energy_estimate

# Sample mean of the local energies
mathcal_E = compute_mathcal_E(E_loc) # Checked against E_exact , F_exact = vstate.expect_and_grad(hamiltonian_jax)

Now, let us compute $\mathbb{E}\left[ \frac{\partial_{\theta_i} \psi(\sigma)}{\psi(\sigma)} \right] \approx \frac{1}{N_S} \sum_{\sigma} \frac{\partial \log \psi(\sigma)}{\partial \theta_i}$. We will create a vector of size ($N_p$,) where the i-th component is exactly $\mathbb{E}\left[ \frac{\partial_{\theta_i} \psi(\sigma)}{\psi(\sigma)} \right]$.

In [4]:
def compute_expectation_of_derivative(jacobian_var):
    """Per-parameter mean of the Jacobian rows (mean over axis 0)."""
    mean_log_derivative = jnp.mean(jacobian_var, axis=0)
    return mean_log_derivative

def compute_expectation_of_derivative_full_sum(jacobian_times_p_var):
    """Per-parameter sum of the rows of the given array (sum over axis 0).

    The caller is expected to have multiplied each row by its weight already,
    so that the plain sum is the weighted expectation.
    """
    summed_log_derivative = jnp.sum(jacobian_times_p_var, axis=0)
    return summed_log_derivative

# Vector whose i-th entry is the sample mean of the i-th column of the Jacobian
exp_logpsi_vector = compute_expectation_of_derivative(jacobian)

We move on to the next term. We first have to compute $E_{\sigma, i} = \sum_{\sigma'} \braket{\sigma | H | \sigma'} \exp{(\log{\psi(\sigma')} - \log{\psi(\sigma)})} \tilde{J}_{\sigma', i}$.

In [5]:
def compute_E_sigma_i(psi_apply_fun_var, hamiltonian_jax_var, samples_2d_var, variables_var):
    """Compute E[sigma, i] = sum_eta H[sigma, eta] * psi(eta)/psi(sigma) * J[eta, i].

    Args:
        psi_apply_fun_var: model apply function, called as
            ``psi_apply_fun_var(variables, configurations)`` and returning
            log psi.
        hamiltonian_jax_var: operator exposing ``get_conn_padded``, called
            once on the whole batch of samples.
        samples_2d_var: the samples, one configuration per row.
        variables_var: model variables; ``variables_var["params"]`` is the
            parameter pytree the Jacobian is taken with respect to.

    Returns:
        Array indexed by (sample, parameter).
    """
    # Connected configurations and matrix elements for all samples
    eta, H_sigmaeta = hamiltonian_jax_var.get_conn_padded(samples_2d_var)

    num_samples, num_connections, n_sites = eta.shape

    # Evaluate log psi with the apply function that was passed in. Using the
    # global `psi_apply_fun` here would tie the result to whatever state the
    # notebook last defined.
    logpsi_eta = psi_apply_fun_var(variables_var, eta)
    logpsi_sigma = psi_apply_fun_var(variables_var, samples_2d_var)[:, None]

    # psi(eta) / psi(sigma) for every (sample, connected configuration) pair
    exp_term = jnp.exp(logpsi_eta - logpsi_sigma)

    # Jacobian at every connected configuration: flatten to a plain batch for
    # nk.jax.jacobian, then restore the (sample, connection, parameter) layout.
    tilde_jacobian = nk.jax.jacobian(
        psi_apply_fun_var,
        variables_var["params"],
        eta.reshape(-1, n_sites),
        mode='holomorphic',
        dense=True,
    ).reshape(num_samples, num_connections, -1)

    # Contract over the connected configurations
    E_sigma_i = jnp.sum(
        H_sigmaeta[:, :, None] * exp_term[:, :, None] * tilde_jacobian, axis=1
    )

    return E_sigma_i

# Compute E_sigma_i for the current samples and model variables
E_sigma_i = compute_E_sigma_i(psi_apply_fun, hamiltonian_jax, samples_2d, variables)

Now we compute $\tilde{H}_{ij} = \frac{1}{N_S} \sum_{\sigma} \Delta J_{\sigma, i} E_{\sigma, j}$, where $\Delta J_{\sigma, i}$ is the Jacobian with its sample mean subtracted.

In [6]:
def compute_tilde_H(jacobian_var, E_sigma_i_var, samples_2d_var):
    """Sample average of the outer products (centered Jacobian row) x (E_sigma row).

    Args:
        jacobian_var: array indexed by (sample, parameter).
        E_sigma_i_var: array indexed by (sample, parameter).
        samples_2d_var: not used by the computation; kept for interface
            compatibility.

    Returns:
        Square array indexed by (parameter i, parameter j).
    """
    # Centre every column of the Jacobian on its sample mean
    centered = jacobian_var - jnp.mean(jacobian_var, axis=0)

    # Per-sample outer product, indexed (sample, i, j), then averaged over samples
    per_sample_outer = centered[..., None] * E_sigma_i_var[:, None, :]
    return jnp.mean(per_sample_outer, axis=0)

def compute_tilde_H_full_sum(jacobian_times_p_var, E_sigma_i_var, samples_2d_var):
    """Sum over rows of the outer products (shifted weighted Jacobian row) x (E_sigma row).

    Args:
        jacobian_times_p_var: array indexed by (configuration, parameter); the
            caller passes the Jacobian already multiplied row-wise by the
            probability of each configuration.
        E_sigma_i_var: array indexed by (configuration, parameter).
        samples_2d_var: not used by the computation; kept for interface
            compatibility.

    Returns:
        Square array indexed by (parameter i, parameter j).
    """
    # Column sums of the weighted Jacobian, i.e. the weighted mean per parameter
    weighted_mean = jnp.sum(jacobian_times_p_var, axis=0)

    # NOTE(review): the mean is subtracted from p * J rather than from J before
    # weighting, so the mean term ends up multiplied by an unweighted sum of
    # E_sigma rows. A p-weighted covariance would use p * (J - mean); confirm
    # whether this is intended.
    shifted = jacobian_times_p_var - weighted_mean

    # Per-configuration outer product, indexed (configuration, i, j), then summed
    per_config_outer = shifted[..., None] * E_sigma_i_var[:, None, :]
    return jnp.sum(per_config_outer, axis=0)

# Compute H̃ (sample average of the centered-Jacobian / E_sigma outer products)
tilde_H = compute_tilde_H(jacobian, E_sigma_i, samples_2d)

We now put everything together:
$$
\hat{H}_{ij} = \tilde{H}_{ij} - F_i \, \mathbb{E}_{\hat{\rho}}[\nu_j(\sigma)] - \hat{\mathcal{E}} \hat{S}_{ij}.
$$

In [7]:
def compute_H_matrix(tilde_H_var, F_var, exp_logpsi_vector_var, mathcal_E_var, S_var):
    """Assemble H_ij = tilde_H_ij - E * S_ij - F_i * v_j.

    Args:
        tilde_H_var: square matrix indexed by (parameter, parameter).
        F_var: vector with one entry per parameter (row factor of the outer
            product).
        exp_logpsi_vector_var: vector with one entry per parameter (column
            factor of the outer product).
        mathcal_E_var: scalar multiplying ``S_var``.
        S_var: square matrix with the same shape as ``tilde_H_var``.

    Returns:
        The assembled square matrix.
    """
    # Outer product F_i * v_j
    outer_term = F_var[:, None] * exp_logpsi_vector_var[None, :]

    hessian = tilde_H_var - mathcal_E_var * S_var
    hessian -= outer_term
    return hessian

# Assemble the approximate Hessian from the pieces computed in the cells above
H_matrix = compute_H_matrix(tilde_H, F, exp_logpsi_vector, mathcal_E, S)

In [8]:
# Sanity check: the matrix should be square, with one row/column per parameter
H_matrix.shape

(120, 120)

# Optimization using the Rayleigh-Gauss-Newton preconditioner
Let us now create a function that builds the preconditioner $P = \eta^{-1}\left(\alpha H + S + \lambda I\right)$, where $\lambda$ is the diagonal shift. We fix $\eta$ and the diagonal shift to values that work well for SR, and then study the effect of $\alpha$.

In [9]:
def compute_preconditionner(H_matrix_var, S_var, eta_var, alpha_var, diag_shift):
    """Build the preconditioner P = (alpha * H + S + diag_shift * I) / eta.

    Args:
        H_matrix_var: square matrix (the approximate Hessian).
        S_var: square matrix of the same shape.
        eta_var: scalar the whole expression is divided by.
        alpha_var: scalar weight of ``H_matrix_var``.
        diag_shift: scalar added on the diagonal.

    Returns:
        The preconditioner matrix.
    """
    # Use the S that was passed in. The original read the global `S`, so a
    # caller-supplied matrix only influenced the size of the identity.
    regularised_S = S_var + diag_shift * jnp.eye(S_var.shape[0])
    return (alpha_var * H_matrix_var + regularised_S) / eta_var

We need to solve the following system:
$$
\Delta \vec{\theta} = P^{-1} \vec{F}
$$

where $\Delta \vec{\theta}$ is the unknown vector.

In [10]:
# With eta=1, alpha=0 and diag_shift=0 the Hessian term and the diagonal shift
# both drop out of the preconditioner formula.
P = compute_preconditionner(H_matrix, S, 1, 0, 0)
# delta_theta = np.linalg.solve(P, F)

In [11]:
# First value returned by Lanczos exact diagonalization, used as the reference energy
print(nk.exact.lanczos_ed(hamiltonian_jax)[0])

-18.06178541796815


In [12]:
# Real part of the sampled energy estimate computed earlier in the notebook
mathcal_E.item().real

-9.755156010020093

# Comparison for different values of $\epsilon$

Let us now run the simulation for different values of epsilon and compare them to the SR optimization of NetKet.

In [13]:
def update_parameters(parameters_var, delta_theta_var):
    """Return the parameters after subtracting the update step from them."""
    updated = parameters_var - delta_theta_var
    return updated

In [None]:
# Experiment with Monte Carlo sampling: for every anisotropy in `delta_list`,
# run (1) the hand-written RGN optimisation for every alpha in `alpha_list`,
# (2) NetKet's SGD with the SR preconditioner and (3) NetKet's plain SGD,
# and write the energy histories to the RGN_energy_log folder.

# Anisotropies of the Hamiltonian to scan
delta_list = [5]

# RGN settings: step scale, weights of the Hessian term, diagonal shift
eta = 5e-03
alpha_list = [1/eta]
diag_shift_1 = 1e-03

# NetKet SR settings
lr_2 = 5e-02
diag_shift_2 = 1e-03

# NetKet plain gradient-descent learning rate
lr_3 = 5e-02

iterations = 400
sample_number_def = 5000


for delta in delta_list:
    hamiltonian = generate_hamiltonian(delta)
    hamiltonian_jax = hamiltonian.to_pauli_strings().to_jax_operator()
    
    # Energy history of every RGN run, keyed by alpha
    energy_histories = {}
    
    # Run RGN for the different alpha values
    for alpha in alpha_list:
        print(f"Running for alpha = {alpha}")
    
        energy_history = []  # Reset before each run
    
        # Start every run from a fresh state built with the same seed
        vstate = nk.vqs.MCState(sampler, model_RBM, n_samples = sample_number_def, seed = jax.random.PRNGKey(0))
        
        # vstate = nk.vqs.FullSumState(hi, model_RBM, seed = jax.random.PRNGKey(0))
    
        for i in tqdm(range(iterations)): 
            # Apply function and log-amplitude function of the current state
            psi_apply_fun = vstate._apply_fun
            logpsi = vstate.log_value
    
            # Current parameters, also as one flat vector with its inverse map
            parameters = vstate.parameters
            # parameters = jax.tree_util.tree_map(lambda x: x.real, parameters)
    
            parameters_concatenated, unconcatenate_function = nk.jax.tree_ravel(parameters)

            variables = vstate.variables
    
            # Samples, flattened so that each row is one configuration
            samples = vstate.samples
            samples_2d = samples.reshape(-1, samples.shape[-1])

            # samples_2d = hi.all_states()
    
            # Jacobian of the model output at every sample
            jacobian = nk.jax.jacobian(psi_apply_fun, parameters, samples_2d, mode='holomorphic', dense=True)
    
            # Building blocks of the preconditioner
            S = compute_S(jacobian, samples_2d)
            E_loc = compute_E_loc(logpsi, hamiltonian_jax, samples_2d)
            F = compute_F(jacobian, E_loc)
            mathcal_E = compute_mathcal_E(E_loc)
    
            exp_logpsi_vector = compute_expectation_of_derivative(jacobian)
            E_sigma_i = compute_E_sigma_i(psi_apply_fun, hamiltonian_jax, samples_2d, variables)
            tilde_H = compute_tilde_H(jacobian, E_sigma_i, samples_2d)

            # Rebuild the approximate Hessian from this iteration's quantities.
            # Without this line the preconditioner silently reuses the stale
            # H_matrix left over from an earlier cell (other Hamiltonian, other
            # samples) and tilde_H above is never used.
            H_matrix = compute_H_matrix(tilde_H, F, exp_logpsi_vector, mathcal_E, S)
    
            # Compute the preconditioner
            P = compute_preconditionner(H_matrix, S, eta, alpha, diag_shift_1)

            # Disabled safeguard kept for reference: if np.linalg.cond(P) > 1e10,
            # warn, add 1e-06 * identity to P and set alpha = 0.
            
            # Solve for the parameter update with the pseudo-inverse
            # delta_theta = np.linalg.solve(P, F)
            # delta_theta, _, _, _ = lstsq(P, F)
            delta_theta = np.linalg.pinv(P) @ F
    
            # Update parameters
            new_pars = update_parameters(parameters_concatenated, delta_theta)
            vstate.parameters = unconcatenate_function(new_pars)
    
            # Store the energy estimated before this update
            energy_history.append(mathcal_E.item().real)
    
        # Store the energy history for this alpha
        energy_histories[alpha] = energy_history
        
    #####################################################################################################
    # Define folder and filename (the name carries the last alpha of alpha_list)
    folder_path = "RGN_energy_log"  # Change this to your preferred folder name
    file_name = f"energy_histories_delta={delta}_alpha={alpha}_diag_shift_1={diag_shift_1}_sampling=True.json"
    
    # Create the folder if it does not exist
    os.makedirs(folder_path, exist_ok=True)
    
    # Full path to the JSON file
    file_path = os.path.join(folder_path, file_name)
    
    # Save the dictionary as JSON
    with open(file_path, "w") as json_file:
        json.dump(energy_histories, json_file, indent=4)
    
    print(f"Data saved to {file_path}")
    
    #####################################################################################################

    # Reference run 1: NetKet VMC with SGD and the SR preconditioner
    vstate = nk.vqs.MCState(sampler, model_RBM, n_samples = sample_number_def, seed = jax.random.PRNGKey(0))
    # vstate = nk.vqs.FullSumState(hi, model_RBM, seed = jax.random.PRNGKey(0))
    optimizer = nk.optimizer.Sgd(learning_rate=lr_2)
    
    gs = nk.driver.VMC(
        hamiltonian, optimizer, variational_state=vstate, preconditioner=nk.optimizer.SR(diag_shift=diag_shift_2, holomorphic=True)
    )
    
    # Construct the Json logger
    log_file = f"RGN_energy_log/NetKet_SR_log_delta={delta}_lr_2={lr_2}_diag_shift_2={diag_shift_2}_sampling=True.json"
    log = nk.logging.JsonLog(log_file)
    
    gs.run(n_iter=iterations, out=log)
    #####################################################################################################
    
    # Reference run 2: NetKet VMC with plain SGD (no preconditioner)
    vstate = nk.vqs.MCState(sampler, model_RBM, n_samples = sample_number_def, seed = jax.random.PRNGKey(0))
    # vstate = nk.vqs.FullSumState(hi, model_RBM, seed = jax.random.PRNGKey(0))
    optimizer = nk.optimizer.Sgd(learning_rate=lr_3)
    
    gs = nk.driver.VMC(
        hamiltonian, optimizer, variational_state=vstate
    )
    
    # Construct the Json logger.
    # NOTE: the file name is tagged with lr_2 / diag_shift_2 although this run
    # uses lr_3; it is left as is because the plotting cell reads this name.
    log_file = f"RGN_energy_log/NetKet_GD_log_delta={delta}_lr_2={lr_2}_diag_shift_2={diag_shift_2}_sampling=True.json"
    log = nk.logging.JsonLog(log_file)
    
    gs.run(n_iter=iterations, out=log)

  self.n_samples = n_samples


Running for alpha = 200.0


 94%|██████████████████████████████████████▎  | 374/400 [01:17<00:05,  5.01it/s]

In [None]:
# Load the three energy logs written by the sampling experiment and plot the
# absolute deviation from the exact ground-state energy on a log scale.

# Define folder and filenames
folder_path = "RGN_energy_log"
energy_file = f"energy_histories_delta={delta}_alpha={alpha}_diag_shift_1={diag_shift_1}_sampling=True.json"
sr_file = f"NetKet_SR_log_delta={delta}_lr_2={lr_2}_diag_shift_2={diag_shift_2}_sampling=True.json.log"
gd_file = f"NetKet_GD_log_delta={delta}_lr_2={lr_2}_diag_shift_2={diag_shift_2}_sampling=True.json.log"
energy_path = os.path.join(folder_path, energy_file)
sr_path = os.path.join(folder_path, sr_file)
gd_path = os.path.join(folder_path, gd_file)

# Load the energy history JSON file (keys are the alpha values, as strings)
with open(energy_path, "r") as json_file:
    energy_histories = json.load(json_file)

# Load the SR optimization JSON file
with open(sr_path, "r") as json_file:
    sr_data = json.load(json_file)
    sr_iters = sr_data["Energy"]["iters"]
    sr_energy = sr_data["Energy"]["Mean"]["real"]

# Load the GD optimization JSON file
with open(gd_path, "r") as json_file:
    gd_data = json.load(json_file)
    gd_iters = gd_data["Energy"]["iters"]
    gd_energy = gd_data["Energy"]["Mean"]["real"]

# Define a color map with one color per RGN run
colors = plt.cm.viridis(np.linspace(0, 1, len(energy_histories)))

# Compute the exact ground state energy
exact_energy = nk.exact.lanczos_ed(hamiltonian_jax)[0]

# Plot the RGN curves. Each curve is labelled with its own key; the original
# used the global `alpha`, which gives every curve the same legend entry as
# soon as more than one alpha is stored.
plt.figure(figsize=(10, 6))
for (alpha_key, energy_history), color in zip(energy_histories.items(), colors):
    plt.plot(range(len(energy_history)), np.abs(np.array(energy_history)-exact_energy), label=f"alpha = {alpha_key}", color=color)

# Plot SR optimization curve
plt.plot(sr_iters, np.abs(np.array(sr_energy)-exact_energy), label="SR Optimization (NetKet)", color="red", linestyle="--", linewidth=2)

# Plot GD optimization curve
plt.plot(gd_iters, np.abs(np.array(gd_energy)-exact_energy), label="GD Optimization (NetKet)", color="blue", linestyle="--", linewidth=2)

# Labels and title
# NOTE(review): the plotted quantity is |E - E_exact| (an absolute error),
# while the axis label says "Energy" and the title says "Relative Error".
plt.xlabel("Iteration Steps", fontsize=20)
plt.ylabel("Energy", fontsize=20)
plt.yscale("log")
plt.title(f"Relative Error for Different Epsilon and NetKet SR (Delta = {delta}, alpha = {alpha}, eta = {eta}, lr_2 = {lr_2}, diag_shift_1 = {diag_shift_1}, diag_shift_2 = {diag_shift_2})")
plt.legend()
plt.grid()

# plt.xlim(0, 300)

# Show the plot
plt.show()

# Full summation

Same as before, but we now sum over all the elements of the Hilbert space instead of sampling. A subtlety we have to take care of is that expectation values must now be weighted explicitly by the probability of each configuration:
$$
\mathbb{E}[E_{\text{loc}}(\sigma)] = \sum_{\sigma} p(\sigma) E_{\text{loc}} = \sum_{\sigma} \frac{|\braket{\sigma | \psi}|^2}{\braket{\psi | \psi}} E_{\text{loc}}
$$

In [None]:
# Experiment with exact summation over the whole Hilbert space: for every
# anisotropy in `delta_list`, run (1) the hand-written RGN optimisation for
# every alpha in `alpha_list`, (2) NetKet's SGD with the SR preconditioner and
# (3) NetKet's plain SGD, all on a FullSumState.

# Anisotropies of the Hamiltonian to scan
delta_list = [10]

# RGN settings: step scale, weights of the Hessian term, diagonal shift
eta = 5e-03
alpha_list = [eta]
diag_shift_1 = 1e-06

# NetKet SR settings
lr_2 = 5e-03
diag_shift_2 = 1e-02

# NetKet plain gradient-descent learning rate
lr_3 = 5e-03

iterations = 1000

for delta in delta_list:
    hamiltonian = generate_hamiltonian(delta)
    hamiltonian_jax = hamiltonian.to_pauli_strings().to_jax_operator()
    
    # Energy history of every RGN run, keyed by alpha
    energy_histories = {}
    
    # Run RGN for the different alpha values
    for alpha in alpha_list:
        print(f"Running for alpha = {alpha}")
    
        energy_history = []  # Reset before each run
    
        # Start every run from a fresh full-summation state built with the same seed
        # vstate = nk.vqs.MCState(sampler, model_RBM, n_samples = sample_number_def, seed = jax.random.PRNGKey(0))  # Ensure vstate is fresh
        
        vstate = nk.vqs.FullSumState(hi, model_RBM, seed = jax.random.PRNGKey(0))
    
        for i in tqdm(range(iterations)): 
            # Apply function and log-amplitude function of the current state
            psi_apply_fun = vstate._apply_fun
            logpsi = vstate.log_value
    
            # Current parameters, also as one flat vector with its inverse map
            parameters = vstate.parameters
            # parameters = jax.tree_util.tree_map(lambda x: x.real, parameters)
    
            parameters_concatenated, unconcatenate_function = nk.jax.tree_ravel(parameters)

            variables = vstate.variables
    
            # No sampling here: every basis state of the Hilbert space is used
            # samples = vstate.samples
            # samples_2d = samples.reshape(-1, samples.shape[-1])

            samples_2d = hi.all_states()
    
            # Jacobian at every basis state, and the same Jacobian with each row
            # multiplied by the probability of its state
            jacobian = nk.jax.jacobian(psi_apply_fun, parameters, samples_2d, mode='holomorphic', dense=True)
            p = vstate.probability_distribution().reshape(-1, 1)  # Column vector so it broadcasts over the parameter axis
            jacobian_times_p = p * jacobian  # Row-wise weighting of the Jacobian

    
            # Building blocks of the preconditioner. S and F come from NetKet
            # here instead of compute_S / compute_F.
            S = nk.optimizer.qgt.QGTJacobianDense(vstate, holomorphic = True).to_dense()
            E_loc = compute_E_loc(logpsi, hamiltonian_jax, samples_2d)
            E_exact , F = vstate.expect_and_grad(hamiltonian_jax)
            F, _ = nk.jax.tree_ravel(F)
            # Probability-weighted sum of the local energies
            mathcal_E = jnp.sum(vstate.probability_distribution()*E_loc)
    
            exp_logpsi_vector = compute_expectation_of_derivative_full_sum(jacobian_times_p)
            E_sigma_i = compute_E_sigma_i(psi_apply_fun, hamiltonian_jax, samples_2d, variables)
            tilde_H = compute_tilde_H_full_sum(jacobian_times_p, E_sigma_i, samples_2d)
            H_matrix = compute_H_matrix(tilde_H, F, exp_logpsi_vector, mathcal_E, S)
    
            # Compute the preconditioner
            P = compute_preconditionner(H_matrix, S, eta, alpha, diag_shift_1)
    
            # Solve for the parameter update with the pseudo-inverse
            delta_theta = np.linalg.pinv(P) @ F
    
            # Update parameters
            new_pars = update_parameters(parameters_concatenated, delta_theta)
            vstate.parameters = unconcatenate_function(new_pars)
    
            # Store the energy computed before this update
            energy_history.append(mathcal_E.item().real)
    
        # Store the energy history for this alpha
        energy_histories[alpha] = energy_history

    #####################################################################################################
    # Define folder and filename (the name carries the last alpha of alpha_list)
    folder_path = "RGN_energy_log"  # Change this to your preferred folder name
    file_name = f"energy_histories_delta={delta}_alpha={alpha}_diag_shift_1={diag_shift_1}.json"
    
    # Create the folder if it does not exist
    os.makedirs(folder_path, exist_ok=True)
    
    # Full path to the JSON file
    file_path = os.path.join(folder_path, file_name)
    
    # Save the dictionary as JSON
    with open(file_path, "w") as json_file:
        json.dump(energy_histories, json_file, indent=4)
    
    print(f"Data saved to {file_path}")

    #####################################################################################################
    # Reference run 1: NetKet VMC with SGD and the SR preconditioner
    # vstate = nk.vqs.MCState(sampler, model_RBM, n_samples = sample_number_def, seed = jax.random.PRNGKey(0))
    vstate = nk.vqs.FullSumState(hi, model_RBM, seed = jax.random.PRNGKey(0))
    optimizer = nk.optimizer.Sgd(learning_rate=lr_2)
    
    gs = nk.driver.VMC(
        hamiltonian, optimizer, variational_state=vstate, preconditioner=nk.optimizer.SR(diag_shift=diag_shift_2, holomorphic=True)
    )
    
    # Construct the Json logger
    log_file = f"RGN_energy_log/NetKet_SR_log_delta={delta}_lr_2={lr_2}_diag_shift_2={diag_shift_2}.json"
    log = nk.logging.JsonLog(log_file)
    
    gs.run(n_iter=iterations, out=log)
    #####################################################################################################
    
    # Reference run 2: NetKet VMC with plain SGD (no preconditioner)
    # vstate = nk.vqs.MCState(sampler, model_RBM, n_samples = sample_number_def, seed = jax.random.PRNGKey(0))
    vstate = nk.vqs.FullSumState(hi, model_RBM, seed = jax.random.PRNGKey(0))
    optimizer = nk.optimizer.Sgd(learning_rate=lr_3)
    
    gs = nk.driver.VMC(
        hamiltonian, optimizer, variational_state=vstate
    )
    
    # Construct the Json logger.
    # NOTE: the file name is tagged with lr_2 / diag_shift_2 although this run uses lr_3.
    log_file = f"RGN_energy_log/NetKet_GD_log_delta={delta}_lr_2={lr_2}_diag_shift_2={diag_shift_2}.json"
    log = nk.logging.JsonLog(log_file)
    
    gs.run(n_iter=iterations, out=log)

In [None]:
# Load the three energy logs written by the full-summation experiment and plot
# the absolute deviation from the exact ground-state energy on a log scale.

# Define folder and filenames
folder_path = "RGN_energy_log"
energy_file = f"energy_histories_delta={delta}_alpha={alpha}_diag_shift_1={diag_shift_1}.json"
sr_file = f"NetKet_SR_log_delta={delta}_lr_2={lr_2}_diag_shift_2={diag_shift_2}.json.log"
gd_file = f"NetKet_GD_log_delta={delta}_lr_2={lr_2}_diag_shift_2={diag_shift_2}.json.log"
energy_path = os.path.join(folder_path, energy_file)
sr_path = os.path.join(folder_path, sr_file)
gd_path = os.path.join(folder_path, gd_file)

# Load the energy history JSON file (keys are the alpha values, as strings)
with open(energy_path, "r") as json_file:
    energy_histories = json.load(json_file)

# Load the SR optimization JSON file
with open(sr_path, "r") as json_file:
    sr_data = json.load(json_file)
    sr_iters = sr_data["Energy"]["iters"]
    sr_energy = sr_data["Energy"]["Mean"]["real"]

# Load the GD optimization JSON file
with open(gd_path, "r") as json_file:
    gd_data = json.load(json_file)
    gd_iters = gd_data["Energy"]["iters"]
    gd_energy = gd_data["Energy"]["Mean"]["real"]

# Define a color map with one color per RGN run
colors = plt.cm.viridis(np.linspace(0, 1, len(energy_histories)))

# Compute the exact ground state energy
exact_energy = nk.exact.lanczos_ed(hamiltonian_jax)[0]

# Plot the RGN curves (every curve gets the same "RGN" legend entry)
plt.figure(figsize=(10, 6))
for (epsilon, energy_history), color in zip(energy_histories.items(), colors):
    plt.plot(range(len(energy_history)), np.abs(np.array(energy_history)-exact_energy), label=f"RGN", color=color)

# Plot SR optimization curve
plt.plot(sr_iters, np.abs(np.array(sr_energy)-exact_energy), label="SR Optimization (NetKet)", color="red", linestyle="--", linewidth=2)

# Plot GD optimization curve
plt.plot(gd_iters, np.abs(np.array(gd_energy)-exact_energy), label="GD Optimization (NetKet)", color="blue", linestyle="--", linewidth=2)

# Labels and title
# NOTE(review): the plotted quantity is |E - E_exact|, i.e. an absolute error,
# although the labels call it "Relative Error".
plt.xlabel("Iteration Steps", fontsize=20)
plt.ylabel("Relative Error", fontsize=20)
plt.yscale("log")
plt.title(f"Plot of the Relative Error for different optimization techniques (Full Sum), Delta={delta}", fontsize=13)
plt.legend()
plt.grid()

# Show the plot
plt.show()