In [1]:
import os
# Request the CPU backend through JAX's platform environment variable. It is
# set before `netket` is imported below — presumably so that JAX sees it when
# it is first loaded (TODO confirm).
os.environ["JAX_PLATFORM_NAME"] = "cpu"
import netket as nk
import numpy as np
import matplotlib.pyplot as plt

In [2]:
import netket.nn as nknn
import flax.linen as nn
import jax.numpy as jnp

class FFNN(nn.Module):
    """Feed-forward network with one complex-valued dense layer.

    The input is passed through a single dense layer, ``log_cosh`` is applied
    to every hidden unit, and the hidden units are summed over the last axis.

    Attributes:
        alpha: Multiplier for the hidden-layer width; the dense layer has
            ``alpha * x.shape[-1]`` features. The default (2) matches the
            previously hard-coded value.
        init_stddev: Standard deviation of the normal initializers used for
            both kernel and bias. The default (0.01) matches the previously
            hard-coded value.
    """

    alpha: int = 2
    init_stddev: float = 0.01

    @nn.compact
    def __call__(self, x):
        """Apply the network to ``x`` and return the sum over hidden units.

        The last axis of ``x`` is treated as the feature axis; it is consumed
        by the dense layer and the resulting hidden axis is summed away.
        """
        x = nn.Dense(features=self.alpha * x.shape[-1],
                     use_bias=True,
                     param_dtype=np.complex128,
                     kernel_init=nn.initializers.normal(stddev=self.init_stddev),
                     bias_init=nn.initializers.normal(stddev=self.init_stddev)
                    )(x)
        x = nknn.log_cosh(x)
        x = jnp.sum(x, axis=-1)
        return x

In [3]:
# Couplings J1 and J2 (J[0] is used on color-1 bonds, J[1] on color-2 bonds)
J = [1, 0.2]
# Number of sites
L = 8
# Define custom graph: every edge is [site_a, site_b, color].
# Color 1 joins sites i and i+1, color 2 joins sites i and i+2 (indices wrap mod L).
edge_colors = []
for i in range(L):
    edge_colors.append([i, (i+1)%L, 1])
    edge_colors.append([i, (i+2)%L, 2])
# Define the netket graph object
g = nk.graph.Graph(edges=edge_colors)

# Sigma^z*Sigma^z interactions
sigmaz = [[1, 0], [0, -1]]
mszsz = (np.kron(sigmaz, sigmaz))
# Exchange interactions
exchange = np.asarray([[0, 0, 0, 0], [0, 0, 2, 0], [0, 2, 0, 0], [0, 0, 0, 0]])
# One 4x4 two-site operator per entry; bond_color[k] gives the edge color that
# bond_operator[k] acts on.
# NOTE(review): the J1 exchange term carries a minus sign while the J2 term
# does not — presumably a deliberate sign convention; confirm.
bond_operator = [
    (J[0] * mszsz).tolist(),
    (J[1] * mszsz).tolist(),
    (-J[0] * exchange).tolist(),
    (J[1] * exchange).tolist(),
]
bond_color = [1, 2, 1, 2]

# Spin based Hilbert Space
hi = nk.hilbert.Spin(s=0.5, total_sz=0.0, N=g.n_nodes)
# Custom Hamiltonian operator
op = nk.operator.GraphOperator(hi, graph=g, bond_ops=bond_operator, bond_ops_colors=bond_color)

In [5]:
# Instantiate the network defined above with its default configuration.
model = FFNN()

In [6]:
# Exchange sampler: it preserves the global magnetization, which is a
# conserved quantity in the model.
sa = nk.sampler.MetropolisExchange(
    hilbert=hi,
    graph=g,
    d_max=2,
)
# Variational state combining the sampler with the network
vs = nk.vqs.MCState(sa, model, n_samples=1008)

In [53]:
def bias(vs):
    """Return the bias entries of the first parameter group of ``vs`` as a list.

    Parameters
    ----------
    vs : object
        Anything exposing ``vs.parameters`` as a nested mapping of the form
        ``{group_name: {"bias": ..., "kernel": ...}}``.

    Returns
    -------
    list
        The items produced by iterating over the ``"bias"`` entry of the
        first group in ``vs.parameters``.

    Raises
    ------
    KeyError
        If the first group has no ``"bias"`` entry.
    """
    head = list(vs.parameters.keys())[0]
    # Look the entry up by name rather than by position: relying on "bias"
    # being the first key would silently return the kernel if the mapping
    # were ordered differently.
    return list(vs.parameters[head]["bias"])

In [81]:
def real(c):
    """Return the real part of ``c`` converted to a built-in ``float``."""
    real_part = np.real(c)
    return float(real_part)
def img(c):
    """Return the imaginary part of ``c`` converted to a built-in ``float``."""
    imag_part = np.imag(c)
    return float(imag_part)
def r_i(c):
    """Split ``c`` into a ``(real part, imaginary part)`` tuple of floats."""
    re_part = real(c)
    im_part = img(c)
    return re_part, im_part

In [75]:
# Take the second bias entry (index 1) once and reuse it, instead of rebuilding
# the whole bias list for every access. `c` and `d` refer to the same value.
c = bias(vs)[1]
d = c
print(c)

(0.006194963183593213-0.009955006159631892j)


In [82]:
# Show the selected bias entry as a (real, imaginary) pair of floats.
print(r_i(c))

(0.006194963183593213, -0.009955006159631892)


In [70]:
# Real part of the selected bias entry, as a plain float.
real(c)

0.006194963183593213

In [76]:
# Imaginary part of the selected bias entry, as a plain float.
img(d)

-0.009955006159631892

In [7]:
# Names of the top-level parameter groups held by the variational state.
vs.parameters.keys()

dict_keys(['Dense_0'])

In [16]:
# Report how many top-level parameter groups there are, and their names.
print(
    'n:', len(list(vs.parameters.keys())),
    ';',
    'list:', (list(vs.parameters.keys())),
)

n: 1 ; list: ['Dense_0']


In [20]:
# Name of the first top-level parameter group.
head = list(vs.parameters.keys())[0]

In [21]:
# Names of the entries stored inside the first parameter group.
vs.parameters[head].keys()

dict_keys(['bias', 'kernel'])

In [26]:
# Entry names of the first parameter group, as a list so they can be indexed.
body = list(vs.parameters[head].keys())

In [27]:
# Display the list of entry names.
body

['bias', 'kernel']

In [28]:
# Print each entry name of the first parameter group on its own line.
for l in body:
    print(l)

bias
kernel


In [34]:
# Name and value of the first entry in the group (the recorded run shows 'bias').
body[0], vs.parameters[head][body[0]]

('bias',
 Array([ 0.00328815-0.00179911j,  0.00619496-0.00995501j,
        -0.00693332+0.01156536j, -0.00104578-0.00348196j,
        -0.00694921-0.00155494j, -0.01096324+0.00516608j,
        -0.00709209-0.00829739j, -0.0009933 -0.00349228j,
        -0.0008565 +0.01301766j, -0.00733595-0.01038364j,
         0.0003041 -0.00179853j,  0.0011776 -0.00491084j,
        -0.00203536+0.01145279j, -0.00695663-0.00143488j,
        -0.00823136-0.0004503j ,  0.00394141-0.00866509j],      dtype=complex128))

In [43]:
# Number of items obtained when iterating over the first entry.
len(list(vs.parameters[head][body[0]]))

16

In [46]:
# First item of the first entry; the recorded output is a complex array scalar.
list(vs.parameters[head][body[0]])[0]                       

Array(0.00328815-0.00179911j, dtype=complex128)

In [48]:
# Real part of that item; without float() the result is still an array scalar.
np.real(list(vs.parameters[head][body[0]])[0])

Array(0.00328815, dtype=float64)

In [49]:
# Real part of the first item converted to a built-in float.
float(np.real(list(vs.parameters[head][body[0]])[0]))

0.0032881458733503284

In [50]:
# Imaginary part of the first item converted to a built-in float.
float(np.imag(list(vs.parameters[head][body[0]])[0]))

-0.001799107774068141

In [35]:
# Name and value of the second entry in the group (the recorded run shows 'kernel').
body[1], vs.parameters[head][body[1]]

('kernel',
 Array([[-1.22312651e-03+6.40885627e-03j, -2.36817418e-03-6.63413610e-03j,
         -2.52069953e-03+5.60791472e-03j, -7.36043016e-03-5.06539920e-03j,
         -8.91450170e-04+4.20144580e-03j, -1.93334328e-03+7.62155089e-04j,
         -9.44795450e-04-1.00851804e-02j, -1.25521314e-03-5.22405766e-03j,
          3.18137380e-03-1.24095499e-02j, -1.01066333e-02+9.64963355e-03j,
         -5.52574583e-03+8.78693693e-05j,  2.64726795e-03+1.50777527e-02j,
         -2.69146599e-03-2.32198276e-03j,  5.10107247e-03-1.52501419e-03j,
         -3.72063804e-03+5.02269010e-04j,  5.34610281e-03-1.91641763e-03j],
        [-1.00775701e-03+2.44660648e-04j,  1.16975892e-02-4.13453459e-03j,
         -1.96104595e-03+5.35713400e-04j,  8.27713098e-03-1.31668416e-02j,
          1.66392822e-03+1.12595901e-03j,  2.43026953e-03+9.34693749e-04j,
          4.91904362e-03-7.99409063e-03j,  1.94016898e-03-2.59468500e-03j,
         -1.44375184e-02-7.73208782e-03j,  1.07481238e-02-7.10312167e-04j,
          2.4

In [36]:
# Shape of the first entry of the group.
np.shape(vs.parameters[head][body[0]])

(16,)

In [37]:
# Shape of the second entry of the group.
np.shape(vs.parameters[head][body[1]])

(8, 16)