In [1]:
# Import necessary libraries
import netket.experimental as nkx
import functools
from functools import partial
import flax.linen as nn
import numpy as np
import jax.numpy as jnp
import flax
import optax
import csv
import numpy as np
import os
import netket.experimental as nkx
import sys 
from math import pi
import json 

os.environ["JAX_PLATFORM_NAME"] = "cuda"
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

import jax
# Check NetKet installation and print version
import netket as nk
print(f"NetKet version: {nk.__version__}")
# Print available JAX devices for the current process
print(jax.devices())


L = 4
# Build square lattice with nearest and next-nearest neighbor edges
lattice = nk.graph.Square(L, max_neighbor_order=2)
hi = nk.hilbert.Spin(s=1 / 2, N=lattice.n_nodes, inverted_ordering=False)
# Heisenberg with coupling J=1.0 for nearest neighbors
# and J=0.5 for next-nearest neighbors
H = nk.operator.Heisenberg(hilbert=hi, graph=lattice, J=[1.0, 0.51])

sparse_ham = H.to_sparse()
sparse_ham.shape

from scipy.sparse.linalg import eigsh

eig_vals, eig_vecs = eigsh(sparse_ham, k=2, which="SA")

print("eigenvalues with scipy sparse:", eig_vals)

E_gs = eig_vals[0]
print("Ground state energy from Exact Diagonalization:", E_gs)

class FFN(nn.Module):
    alpha : int = 1
    @nn.compact
    def __call__(self, x):
        x = nn.Dense(features=self.alpha * x.shape[-1], 
                     use_bias=True, 
                     param_dtype=np.complex128, 
                     kernel_init=nn.initializers.normal(stddev=0.01), 
                     bias_init=nn.initializers.normal(stddev=0.01)
                    )(x)
        x = nk.nn.log_cosh(x)
        x = jnp.sum(x, axis=-1)
        return x

model = FFN(alpha=2)
sampler = nk.sampler.MetropolisLocal(hi)
vstate = nk.vqs.MCState(sampler, model, n_samples=1024)

optimizer = nk.optimizer.Sgd(learning_rate=0.01)

sr = nk.optimizer.SR(diag_shift=1e-6, holomorphic=False)
# Notice the use, again of Stochastic Reconfiguration, which considerably improves the optimisation
gs = nk.driver.VMC(H, optimizer, variational_state=vstate,preconditioner=sr)

log=nk.logging.RuntimeLog()
gs.run(n_iter=100,out=log)

ffn_energy=vstate.expect(H)
error=abs((ffn_energy.mean-E_gs)/E_gs)
print("Optimized energy and relative error: ",ffn_energy,error)

  from .autonotebook import tqdm as notebook_tqdm


NetKet version: 3.15.0
[CudaDevice(id=0)]
eigenvalues with scipy sparse: [-33.73952093 -32.60188042]
Ground state energy from Exact Diagonalization: -33.73952093296431


100%|██████████| 100/100 [00:12<00:00,  8.04it/s, Energy=-23.00-0.32j ± 0.13 [σ²=10.04, R̂=1.0939]]    


Optimized energy and relative error:  -22.72+0.21j ± 0.31 [σ²=95.00, R̂=1.0127] 0.3267585996206149


In [None]:
import netket as nk
import numpy as np
import json
from math import pi
import jax
import jax.numpy as jnp

L = 4
# Build square lattice with nearest and next-nearest neighbor edges
lattice = nk.graph.Square(L, max_neighbor_order=2)
hi = nk.hilbert.Spin(s=1 / 2, total_sz=0, N=lattice.n_nodes, inverted_ordering=False)
# Heisenberg with coupling J=1.0 for nearest neighbors
# and J=0.5 for next-nearest neighbors
H = nk.operator.Heisenberg(hilbert=hi, graph=lattice, J=[1.0, 0.5])



input = hi.random_state(size=64, key=jax.random.PRNGKey(1))
input = jnp.array(input)
print(input)
print(input.shape)


# Print model architecture and parameter count
def print_model_info(model):
    """Prints the model architecture and parameter count."""
    key = jax.random.PRNGKey(0)
    params = model.init(key, input)
    print("Model Architecture:")
    print(model)
    total_params = sum(jnp.prod(jnp.array(param.shape)) for param in jax.tree_util.tree_leaves(params))
    print(f"Total Trainable Parameters: {total_params}")

# Find an approximate ground state
machine = nk.models.GCNN(
    symmetries=lattice,
    parity=1,
    layers=4,
    features=4,
    param_dtype=complex,
)

print_model_info(machine)

n_chains = 128
n_samples = 4096
n_discard_per_chain = 10  # Number of samples to discard per chain
chunk_size = 4096

sampler = nk.sampler.MetropolisExchange(hi, n_chains=1024, graph=lattice, d_max=L, dtype=jnp.int8)
vstate = nk.vqs.MCState(sampler=sampler,model=machine,n_samples=n_samples,n_discard_per_chain=n_discard_per_chain,chunk_size=chunk_size)

opt = nk.optimizer.Sgd(learning_rate=0.01)
sr = nk.optimizer.SR(diag_shift=1e-4, holomorphic=False)

gs = nk.driver.VMC(H, opt, variational_state=vstate, preconditioner=sr)
gs.run(n_iter=100, out="ground_state")

data = json.load(open("ground_state.log"))
print("Energy averaged over last ten steps:", np.mean(data["Energy"]["Mean"]["real"][-10:]))
print("Energy per site averaged over last ten steps:", np.mean(data["Energy"]["Mean"]["real"][-10:]) / (lattice.n_nodes))
print("Energy std over last ten steps:", np.std(data["Energy"]["Mean"]["real"][-10:]) / (lattice.n_nodes))

from scipy.sparse.linalg import eigsh

eig_vals, eig_vecs = eigsh(H.to_sparse(), k=2, which="SA")

print("eigenvalues with scipy sparse:", eig_vals)

E_gs = eig_vals[0]

[[-1. -1. -1. ...  1.  1.  1.]
 [-1. -1.  1. ...  1.  1.  1.]
 [-1. -1. -1. ... -1.  1.  1.]
 ...
 [-1.  1.  1. ...  1. -1. -1.]
 [-1.  1.  1. ...  1.  1.  1.]
 [ 1. -1. -1. ... -1.  1. -1.]]
(64, 16)
Model Architecture:
GCNN_Parity_FFT(
    # attributes
    symmetries = HashableArray([[ 0  1  2 ... 13 14 15]
     [ 0  3  2 ... 15 14 13]
     [ 0  1  2 ...  5  6  7]
     ...
     [ 7 11 15 ...  8 12  0]
     [13  9  5 ...  8  4  0]
     [15 11  7 ...  8  4  0]],
     shape=(128, 16), dtype=int64, hash=-7783991654347323897)
    product_table = HashableArray([[  0   1   2 ... 125 126 127]
     [  1   0   3 ... 111 108 109]
     [  2   3   0 ...  60  63  62]
     ...
     [ 62  63  60 ...   0   3   2]
     [109 108 111 ...   3   0   1]
     [127 126 125 ...   2   1   0]],
     shape=(128, 128), dtype=int64, hash=-6398037290674299010)
    shape = (np.int64(4), np.int64(4))
    layers = 4
    features = (4, 4, 4, 4)
    characters = HashableArray([1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1

  0%|          | 0/100 [00:00<?, ?it/s]