In [None]:
!pip -q install --upgrade pip
!pip -q install "netket" "flax" "optax" "einops" "tqdm"

import os
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

import netket as nk
import jax
import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt
from flax import linen as nn
from tqdm import tqdm

# Switch JAX to double precision (float64 / complex128) for all computations
# that follow, then report which backend devices JAX can see.
jax.config.update("jax_enable_x64", True)
available_devices = jax.devices()
print("JAX devices:", available_devices)

def make_j1j2_chain(L, J2, total_sz=0.0, J1=1.0):
    """Build the periodic spin-1/2 J1-J2 Heisenberg chain as NetKet objects.

    Every bond (i, j) carries J * (sigma^z_i sigma^z_j + sigma^x_i sigma^x_j
    + sigma^y_i sigma^y_j), i.e. the Hamiltonian is written in Pauli-matrix
    units (4x the usual S = 1/2 normalisation S_i . S_j):

        H = J1 * sum_i sigma_i . sigma_{i+1} + J2 * sum_i sigma_i . sigma_{i+2}

    The nearest-neighbour exchange enters with sign -J1. This is the
    sublattice (Marshall sign rule) rotation, which leaves the spectrum
    unchanged on an even periodic ring. It makes the J2 = 0 ground state
    sign-free in the computational basis. sigma^z correlations are unaffected
    by the rotation.

    Args:
        L: number of sites; must be even (see above).
        J2: next-nearest-neighbour coupling.
        total_sz: total S^z sector to restrict the Hilbert space to.
        J1: nearest-neighbour coupling (default 1.0, the original hard-coded value).

    Returns:
        (g, hi, H): the coloured graph, the constrained spin Hilbert space and
        the GraphOperator Hamiltonian.

    Raises:
        ValueError: if L is odd, where the sublattice rotation is not a
            symmetry of the periodic ring.
    """
    if L % 2 != 0:
        raise ValueError(f"make_j1j2_chain requires an even chain length, got L={L}")

    # Edge colour 1 = nearest-neighbour bonds, colour 2 = next-nearest-neighbour bonds.
    edges = []
    for i in range(L):
        edges.append([i, (i + 1) % L, 1])
        edges.append([i, (i + 2) % L, 2])
    g = nk.graph.Graph(edges=edges)
    hi = nk.hilbert.Spin(s=0.5, N=L, total_sz=total_sz)

    sigmaz = np.array([[1, 0], [0, -1]], dtype=np.float64)
    mszsz = np.kron(sigmaz, sigmaz)
    # sigma^x sigma^x + sigma^y sigma^y in the two-site computational basis.
    exchange = np.array(
        [[0, 0, 0, 0],
         [0, 0, 2, 0],
         [0, 2, 0, 0],
         [0, 0, 0, 0]], dtype=np.float64
    )
    bond_ops = [
        (J1 * mszsz).tolist(),
        (J2 * mszsz).tolist(),
        (-J1 * exchange).tolist(),  # sign flipped by the sublattice rotation
        (J2 * exchange).tolist(),   # same-sublattice bonds: sign unchanged
    ]
    # bond_colors[k] selects which edge colour bond_ops[k] is applied to.
    bond_colors = [1, 2, 1, 2]
    H = nk.operator.GraphOperator(hi, g, bond_ops=bond_ops, bond_ops_colors=bond_colors)
    return g, hi, H

In [None]:
class TransformerLogPsi(nn.Module):
    """Pre-LayerNorm Transformer encoder returning a complex log-amplitude.

    Each site's spin becomes a token (entries > 0 -> token 1, otherwise
    token 0). The token embedding is added to a learned positional embedding
    and passed through `n_layers` residual blocks (self-attention, then a GELU
    MLP). The result is mean-pooled over sites and projected to two reals,
    returned as `re + 1j * im`.

    Assumes `sigma` is batched as (batch, L) — TODO confirm for other callers;
    the pooling reduces axis 1.
    """

    # Number of sites; fixes the shape of the positional embedding.
    L: int
    # Width of token/positional embeddings and of the residual stream.
    d_model: int = 96
    # Attention heads per block.
    n_heads: int = 4
    # Number of Transformer blocks.
    n_layers: int = 6
    # MLP hidden width as a multiple of d_model.
    mlp_mult: int = 4

    @nn.compact
    def __call__(self, sigma):
        # NOTE: submodules are created in the same order as before so flax's
        # automatic names (Embed_0, LayerNorm_k, SelfAttention_k, Dense_k)
        # and hence parameter initialisation are unchanged.
        token_ids = (sigma > 0).astype(jnp.int32)
        hidden = nn.Embed(num_embeddings=2, features=self.d_model)(token_ids)
        hidden = hidden + self.param(
            "pos_embedding",
            nn.initializers.normal(0.02),
            (1, self.L, self.d_model),
        )

        for _block in range(self.n_layers):
            attn_input = nn.LayerNorm()(hidden)
            hidden = hidden + nn.SelfAttention(
                num_heads=self.n_heads,
                qkv_features=self.d_model,
                out_features=self.d_model,
            )(attn_input)

            mlp_input = nn.LayerNorm()(hidden)
            mlp_hidden = nn.gelu(nn.Dense(self.mlp_mult * self.d_model)(mlp_input))
            hidden = hidden + nn.Dense(self.d_model)(mlp_hidden)

        pooled = jnp.mean(nn.LayerNorm()(hidden), axis=1)
        re_im = nn.Dense(2)(pooled)
        # First output -> real part (log-modulus), second -> imaginary part (phase).
        return re_im[..., 0] + 1j * re_im[..., 1]

In [None]:
def structure_factor(vs, L):
    """Static structure factor S(q) estimated from the state's current samples.

    Uses the translation-averaged estimator

        S(q) = (1/L) * < |sum_j sigma_j exp(-i q j)|^2 >
             = sum_r exp(-i q r) * (1/L) * sum_i <sigma_i sigma_{i+r}>,

    with q = 2*pi*k/L (k = 0..L-1) and periodic site indices. Averaging over
    every reference site i, rather than fixing i = 0, is what makes this the
    structure factor when the ansatz is not translation invariant (a learned
    positional embedding breaks that symmetry). It also makes S(q) real and
    non-negative by construction, so no `abs` of a complex FFT is needed.

    Args:
        vs: object with a `samples` attribute reshapeable to (-1, L), e.g. a
            NetKet MCState. Sample values are used as-is; presumably +-1 for
            a spin-1/2 Spin Hilbert space, in which case S(q) is 4x the
            S = 1/2 spin-operator convention and sum_q S(q) == L.
        L: number of sites.

    Returns:
        (q, Sq): arrays of shape (L,) holding the momenta and S(q).
    """
    # np.asarray also converts JAX arrays; float64 avoids int8 sample overflow.
    spins = np.asarray(vs.samples, dtype=np.float64).reshape(-1, L)
    q = np.arange(L) * 2 * np.pi / L
    sigma_q = np.fft.fft(spins, axis=1)
    Sq = np.mean(np.abs(sigma_q) ** 2, axis=0) / L
    return q, Sq

def exact_energy(L, J2):
    """Lowest eigenvalue of the J1-J2 chain in the total S^z = 0 sector.

    Builds the Hamiltonian with `make_j1j2_chain` and runs NetKet's Lanczos
    exact diagonalisation for a single eigenvalue (no eigenvectors).
    """
    hamiltonian = make_j1j2_chain(L, J2, total_sz=0.0)[2]
    eigenvalues = nk.exact.lanczos_ed(hamiltonian, k=1, compute_eigenvectors=False)
    return eigenvalues[0]

def run_vmc(L, J2, n_iter=250, seed=None):
    """Optimise a TransformerLogPsi ansatz for the J1-J2 chain with VMC + SR.

    Args:
        L: chain length (used for the Hamiltonian and the model).
        J2: next-nearest-neighbour coupling.
        n_iter: number of optimisation steps.
        seed: optional integer seed for parameter initialisation and for the
            sampler. None keeps NetKet's default seeding (non-reproducible runs).

    Returns:
        (vs, energy, var): the optimised MCState, plus per-iteration arrays of
        the real part of the energy mean and of the energy variance.
    """
    g, hi, H = make_j1j2_chain(L, J2, total_sz=0.0)
    model = TransformerLogPsi(L=L)
    # Exchange moves along graph edges conserve total S^z, keeping the chains
    # inside the total_sz=0 sector of the Hilbert space.
    sampler = nk.sampler.MetropolisExchange(
        hilbert=hi,
        graph=g,
        n_chains_per_rank=64,
    )
    vs = nk.vqs.MCState(
        sampler,
        model,
        n_samples=4096,
        n_discard_per_chain=128,
        seed=seed,
        sampler_seed=seed,
    )
    opt = nk.optimizer.Adam(learning_rate=2e-3)
    diag_shift = 1e-2

    if hasattr(nk.driver, "VMC_SR"):
        # The model has far more parameters than samples per step. In this
        # regime NetKet recommends VMC_SR (kernel/minSR formulation): the same
        # SR update, but much cheaper and numerically more stable than the
        # QGT-based preconditioner.
        driver = nk.driver.VMC_SR(H, opt, diag_shift=diag_shift, variational_state=vs)
    else:
        # Older NetKet without VMC_SR: fall back to the QGT-based SR preconditioner.
        driver = nk.driver.VMC(
            H, opt, variational_state=vs,
            preconditioner=nk.optimizer.SR(diag_shift=diag_shift),
        )

    # `run` returns the tuple of attached loggers (empty for out=None), not the
    # logged data, so the energy history is collected via a RuntimeLog.
    log = nk.logging.RuntimeLog()
    driver.run(n_iter=n_iter, out=log)
    energy_history = log.data["Energy"]
    # With a complex wavefunction the estimated mean carries a small,
    # unphysical imaginary part; keep only the real energy.
    energy = np.asarray(energy_history.Mean).real
    var = np.asarray(energy_history.Variance).real
    return vs, energy, var

In [None]:
# Sweep J2 at fixed chain length, recording the converged VMC energy and the
# peak of the structure factor for each coupling.
L = 24
J2_values = np.linspace(0.0, 0.7, 6)

energies = []
structure_peaks = []

for J2 in tqdm(J2_values, desc="J2 sweep"):
    vs, e, var = run_vmc(L, J2)
    # Average the last 10% of iterations: a single-iteration MC estimate is
    # noisy, and the real part drops the spurious imaginary component of a
    # complex wavefunction's energy estimate.
    n_tail = max(1, len(e) // 10)
    energies.append(float(np.mean(np.real(e[-n_tail:]))))
    q, Sq = structure_factor(vs, L)
    structure_peaks.append(float(np.max(Sq)))

In [None]:
# Benchmark the VMC energy against exact diagonalisation on a small chain.
L_ed = 14
J2_test = 0.5
E_ed = exact_energy(L_ed, J2_test)

vs_small, e_small, _ = run_vmc(L_ed, J2_test, n_iter=200)
# Tail-average the real part of the energy (last 10% of iterations) rather
# than trusting the single, noisy final-iteration estimate.
n_tail_small = max(1, len(e_small) // 10)
E_vmc = float(np.mean(np.real(e_small[-n_tail_small:])))

print(f"ED Energy (L={L_ed}):", E_ed)
print("VMC Energy:", E_vmc)
print("Abs gap:", abs(E_vmc - E_ed))
print("Rel. error:", abs(E_vmc - E_ed) / abs(E_ed))

# Summary figure: small-chain convergence vs ED, and the J2 sweep results.
fig, (ax_conv, ax_energy, ax_sq) = plt.subplots(1, 3, figsize=(12, 4))

# Real parts only: complex energies would trigger a ComplexWarning when plotted.
ax_conv.plot(np.real(e_small), label="VMC")
ax_conv.axhline(float(np.real(E_ed)), color="k", linestyle="--", label="ED")
ax_conv.set(
    title=f"Energy Convergence (L={L_ed}, J2={J2_test})",
    xlabel="iteration",
    ylabel="energy",
)
ax_conv.legend()

ax_energy.plot(J2_values, np.real(energies), "o-")
ax_energy.set(title=f"Energy vs J2 (L={L})", xlabel="J2", ylabel="energy")

ax_sq.plot(J2_values, structure_peaks, "o-")
ax_sq.set(title=f"Structure Factor Peak (L={L})", xlabel="J2", ylabel="max_q S(q)")

fig.tight_layout()
plt.show()

[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/1.8 MB[0m [31m?[0m eta [36m-:--:--[0m[2K   [91m━━━━━[0m[91m╸[0m[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.3/1.8 MB[0m [31m7.7 MB/s[0m eta [36m0:00:01[0m[2K   [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m [32m1.8/1.8 MB[0m [31m23.9 MB/s[0m eta [36m0:00:01[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.8/1.8 MB[0m [31m17.1 MB/s[0m eta [36m0:00:00[0m
[?25hJAX devices: [CpuDevice(id=0)]


The number of samples (4096) is less than or equal to the number of parameters (673922).

In this regime, the standard QGT-based SR formulation is inefficient and potentially unstable.
You should switch to the kernel/minSR formulation by using:

    nk.driver.VMC_SR(hamiltonian, optimizer, diag_shift=0.01, variational_state=your_state)

instead of:

    nk.driver.VMC(hamiltonian, optimizer, preconditioner=nk.optimizer.SR())

VMC_SR automatically chooses the optimal implementation and is always recommended when using SR.
This provides the same mathematical result but with much better performance and numerical stability.


-------------------------------------------------------
For more detailed informations, visit the following link:
	 https://netket.readthedocs.io/en/latest/api/errors.html
-------------------------------------------------------

  vmc = nk.driver.VMC(H, opt, variational_state=vs, preconditioner=sr)


  0%|          | 0/250 [00:00<?, ?it/s]