In [157]:
import os
os.environ["JAX_PLATFORM_NAME"] = "cpu"
import netket as nk
import numpy as np
import matplotlib.pyplot as plt
import netket.nn as nknn
import flax.linen as nn
import jax.numpy as jnp

class FFNN(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.Dense(features=2*x.shape[-1], 
                     use_bias=True, 
                     param_dtype=np.complex128, 
                     kernel_init=nn.initializers.normal(stddev=0.01), 
                     bias_init=nn.initializers.normal(stddev=0.01)
                    )(x)
        x = nknn.log_cosh(x)
        x = jnp.sum(x, axis=-1)
        return x
    

In [168]:
class FNN(nn.Module):
    def __init__(self, features):
        super().__init__()
        # Inicializar a camada Dense no construtor
        self.dense_layer = nn.Dense(features=2 * features,
                                     use_bias=True,
                                     param_dtype=np.complex128,
                                     kernel_init=nn.initializers.normal(stddev=0.01),
                                     bias_init=nn.initializers.normal(stddev=0.01))

    def __call__(self, x):
        # Acessar a camada Dense diretamente dentro do método __call__
        x = self.dense_layer(x)
        x = nknn.log_cosh(x)
        x = jnp.sum(x, axis=-1)
        return x

    def get_kernel(self):
        # Método para acessar o kernel da camada Dense
        return self.dense_layer.kernel

    def get_bias(self):
        # Método para acessar o bias da camada Dense
        return self.dense_layer.bias


In [158]:
import warnings
warnings.filterwarnings("ignore")

In [159]:
def info(e):
    head   = list(e.parameters.keys())[0]
    body   = list(e.parameters[head].keys())
    bias   = e.parameters[head][body[0]]
    kernel = e.parameters[head][body[1]]
    return  head, body, list(bias), list(kernel)
def real(c):
    return float(np.real(c))  
def img(c):
    return float(np.imag(c))    
def r_i(c):
    return real(c),img(c)   

In [169]:
def run(it):
    print("------------",it,"-------------")
    J = [1, 0.2];L = 8;edge_colors = []
    for i in range(L):
        edge_colors.append([i, (i+1)%L, 1])
        edge_colors.append([i, (i+2)%L, 2])
    g = nk.graph.Graph(edges=edge_colors)
    sigmaz = [[1, 0], [0, -1]];  mszsz = (np.kron(sigmaz, sigmaz))
    exchange = np.asarray([[0, 0, 0, 0], [0, 0, 2, 0], [0, 2, 0, 0], [0, 0, 0, 0]])
    bond_operator = [
        (J[0] * mszsz).tolist(),    (J[1] * mszsz).tolist(),
        (-J[0] * exchange).tolist(),(J[1] * exchange).tolist(),
        ]
    bond_color = [1, 2, 1, 2];   sigmaz = [[1, 0], [0, -1]]
    mszsz = (np.kron(sigmaz, sigmaz));
    exchange = np.asarray([[0, 0, 0, 0], [0, 0, 2, 0], [0, 2, 0, 0], [0, 0, 0, 0]])
    bond_operator = [
        (J[0] * mszsz).tolist(),    (J[1] * mszsz).tolist(),
        (-J[0] * exchange).tolist(),(J[1] * exchange).tolist(),
    ]
    bond_color = [1, 2, 1, 2]
    hi = nk.hilbert.Spin(s=0.5, total_sz=0.0, N=g.n_nodes)
    op = nk.operator.GraphOperator(hi, graph=g, bond_ops=bond_operator, bond_ops_colors=bond_color)
    
    #model = FFNN()
    model = FNN(2)
    
    sa = nk.sampler.MetropolisExchange(hilbert=hi, graph=g, d_max = 2)
    vs = nk.vqs.MCState(sa, model, n_samples=1008)
    vs_i = vs;
    opt = nk.optimizer.Sgd(learning_rate=0.01)
    sr = nk.optimizer.SR(diag_shift=0.01)
    gs = nk.VMC(hamiltonian=op, optimizer=opt, variational_state=vs, preconditioner=sr)
    # We need to specify the local operators as a matrix acting on a local Hilbert space 
    sf = []
    sites = []
    structure_factor = nk.operator.LocalOperator(hi, dtype=complex)
    for i in range(0, L):
        for j in range(0, L):
            structure_factor += (nk.operator.spin.sigmaz(hi, i)*nk.operator.spin.sigmaz(hi, j))*((-1)**(i-j))/L
    gs.run(out='test', n_iter=it, obs={'Structure Factor': structure_factor}) 
    print("------------  -------------")
    vs_f = vs
    
    return vs_i, vs_f, gs

In [170]:
it = 0; bia_i_list  = [];bia_f_list = [] ; kernel_i_list = []; kernel_f_list = []
#for i in range(40,50):
for i in range(0,10):
    #it = i*(10) + 10
    it = i + 10
    vs_i, vs_f, g  = run(it)
    head_i, body_i, bias_i, kernel_i   = info(vs_i); 
    head_f, body_f, bias_f, kernel_f  = info(vs_f)
    bia_i_list.append(bias_i)
    bia_f_list.append(bias_f)
    kernel_i_list.append(kernel_i)
    kernel_f_list.append(kernel_f)    
    break;

------------ 10 -------------


IncorrectPostInitOverrideError: Overrode `.__post_init__()` without calling `super().__post_init__()` (https://flax.readthedocs.io/en/latest/api_reference/flax.errors.html#flax.errors.IncorrectPostInitOverrideError)

In [123]:
g

Vmc(
  step_count = 10,
  state = MCState(hilbert = Spin(s=1/2, total_sz=0.0, N=8), sampler = MetropolisSampler(rule = ExchangeRule(# of clusters: 28), n_chains = 16, n_sweeps = 8, reset_chains = False, machine_power = 2, dtype = <class 'float'>), n_samples = 1008))

In [124]:
g.state

MCState(
  hilbert = Spin(s=1/2, total_sz=0.0, N=8),
  sampler = MetropolisSampler(rule = ExchangeRule(# of clusters: 28), n_chains = 16, n_sweeps = 8, reset_chains = False, machine_power = 2, dtype = <class 'float'>),
  n_samples = 1008,
  n_discard_per_chain = 100,
  sampler_state = MetropolisSamplerState(# accepted = 16472/20864 (78.94938650306749%), rng state=[255598534 894310797]),
  n_parameters = 144)

In [125]:
g.state.parameters.keys()

dict_keys(['Dense_0'])

In [126]:
g.state.parameters['Dense_0']

{'bias': Array([ 0.00289549+0.00071373j,  0.00268039-0.0077054j ,
         0.00611114+0.00492813j,  0.01275886-0.00354826j,
         0.01353529-0.01003271j,  0.00196755+0.00234869j,
        -0.00673705-0.00397639j,  0.00738052-0.00537823j,
        -0.00291331-0.00053627j,  0.01008471+0.00221863j,
        -0.00738753-0.01855009j,  0.00660092-0.00456029j,
        -0.00504508-0.00513531j,  0.00332576+0.0027621j ,
         0.0020937 +0.00265095j,  0.00083739+0.00417381j],      dtype=complex128),
 'kernel': Array([[-0.06053235-0.00549636j, -0.01179664-0.0422946j ,
          0.03270989-0.0407181j , -0.0461898 -0.01874433j,
         -0.08880351+0.05412178j,  0.02386505-0.04582802j,
         -0.12397818-0.06794977j,  0.06107529-0.17246938j,
         -0.01910214-0.0684051j , -0.03983793+0.1009705j ,
          0.02697425+0.01636251j,  0.18881455+0.0143407j ,
          0.04328187+0.01596088j,  0.06963702+0.05557598j,
          0.01614129-0.01297855j, -0.02164263+0.09245701j],
        [ 0.11023386

In [127]:
vs_i.parameters['Dense_0']

{'bias': Array([ 0.00289549+0.00071373j,  0.00268039-0.0077054j ,
         0.00611114+0.00492813j,  0.01275886-0.00354826j,
         0.01353529-0.01003271j,  0.00196755+0.00234869j,
        -0.00673705-0.00397639j,  0.00738052-0.00537823j,
        -0.00291331-0.00053627j,  0.01008471+0.00221863j,
        -0.00738753-0.01855009j,  0.00660092-0.00456029j,
        -0.00504508-0.00513531j,  0.00332576+0.0027621j ,
         0.0020937 +0.00265095j,  0.00083739+0.00417381j],      dtype=complex128),
 'kernel': Array([[-0.06053235-0.00549636j, -0.01179664-0.0422946j ,
          0.03270989-0.0407181j , -0.0461898 -0.01874433j,
         -0.08880351+0.05412178j,  0.02386505-0.04582802j,
         -0.12397818-0.06794977j,  0.06107529-0.17246938j,
         -0.01910214-0.0684051j , -0.03983793+0.1009705j ,
          0.02697425+0.01636251j,  0.18881455+0.0143407j ,
          0.04328187+0.01596088j,  0.06963702+0.05557598j,
          0.01614129-0.01297855j, -0.02164263+0.09245701j],
        [ 0.11023386

In [83]:
g.state.parameters['Dense_0']

{'bias': Array([ 0.00604849-0.01302215j, -0.00518966-0.0038263j ,
         0.001783  -0.00801423j, -0.00597949+0.00020621j,
         0.00345921-0.00441961j,  0.00640831+0.00180515j,
        -0.00165057+0.0123789j ,  0.00469725-0.01190759j,
         0.00593717-0.00024499j,  0.0071412 -0.0008751j ,
         0.00118892+0.01014857j, -0.00190333+0.00536017j,
         0.01084671+0.00443603j, -0.00639983+0.0081103j ,
        -0.0083662 +0.0003642j ,  0.00944342-0.00475401j],      dtype=complex128),
 'kernel': Array([[ 7.38653532e-02+0.04805963j, -1.43404073e-01-0.0982466j ,
         -7.27544184e-02-0.01769517j, -3.85064769e-03-0.0237708j ,
         -5.63476092e-02+0.1137615j , -1.32094516e-02-0.00662134j,
         -6.03774043e-02+0.10430256j, -1.44080605e-02-0.02324688j,
          2.63994204e-02-0.05071438j, -5.44174025e-02-0.02511771j,
          8.47932542e-02-0.12819425j, -4.01239474e-02+0.06479916j,
         -7.76071100e-02-0.00452977j,  8.45778513e-02-0.03518249j,
          9.86825199e-02

In [None]:
vs.parameters[head]

In [81]:
g

Vmc(
  step_count = 10,
  state = MCState(hilbert = Spin(s=1/2, total_sz=0.0, N=8), sampler = MetropolisSampler(rule = ExchangeRule(# of clusters: 28), n_chains = 16, n_sweeps = 8, reset_chains = False, machine_power = 2, dtype = <class 'float'>), n_samples = 1008))

In [56]:
vs_i

MCState(
  hilbert = Spin(s=1/2, total_sz=0.0, N=8),
  sampler = MetropolisSampler(rule = ExchangeRule(# of clusters: 28), n_chains = 16, n_sweeps = 8, reset_chains = False, machine_power = 2, dtype = <class 'float'>),
  n_samples = 1008,
  n_discard_per_chain = 100,
  sampler_state = MetropolisSamplerState(# accepted = 16652/20864 (79.81211656441718%), rng state=[3398082654 3858810580]),
  n_parameters = 144)

In [57]:
g.state

MCState(
  hilbert = Spin(s=1/2, total_sz=0.0, N=8),
  sampler = MetropolisSampler(rule = ExchangeRule(# of clusters: 28), n_chains = 16, n_sweeps = 8, reset_chains = False, machine_power = 2, dtype = <class 'float'>),
  n_samples = 1008,
  n_discard_per_chain = 100,
  sampler_state = MetropolisSamplerState(# accepted = 16652/20864 (79.81211656441718%), rng state=[3398082654 3858810580]),
  n_parameters = 144)

In [64]:
head_g, body_g, bias_g, kernel_g   = info(g.state)
head_i, body_i, bias_i, kernel_i   = info(vs_i)

In [71]:
for l in kernel_i:
        for m in l:
            print(r_i(m)) 
            break;

(0.07386535320872238, 0.04805963479200247)
(-0.06783379714057702, 0.011686767102873415)
(0.06429412057769147, -0.06364318843887515)
(-0.0867604555440013, -0.09438920621183337)
(0.08281528643286239, -0.04460515073052329)
(-0.08568492551333044, 0.019274965657488347)
(0.07543171291539558, 0.035675979321252466)
(-0.08775834679399584, 0.04558047725822235)


In [72]:
for l in kernel_g:
        for m in l:
            print(r_i(m)) 
            break;

(0.07386535320872238, 0.04805963479200247)
(-0.06783379714057702, 0.011686767102873415)
(0.06429412057769147, -0.06364318843887515)
(-0.0867604555440013, -0.09438920621183337)
(0.08281528643286239, -0.04460515073052329)
(-0.08568492551333044, 0.019274965657488347)
(0.07543171291539558, 0.035675979321252466)
(-0.08775834679399584, 0.04558047725822235)


In [70]:
for b in bias_g:
    print(r_i(b))
    break;

(0.006048493641296908, -0.01302215213909139)


In [None]:
for b in bias_i:
    print(r_i(b))
    break;

In [None]:
for b in bias_i:
    print(r_i(b))
    break;

In [54]:
g

Vmc(
  step_count = 10,
  state = MCState(hilbert = Spin(s=1/2, total_sz=0.0, N=8), sampler = MetropolisSampler(rule = ExchangeRule(# of clusters: 28), n_chains = 16, n_sweeps = 8, reset_chains = False, machine_power = 2, dtype = <class 'float'>), n_samples = 1008))

In [62]:
vs_i

MCState(
  hilbert = Spin(s=1/2, total_sz=0.0, N=8),
  sampler = MetropolisSampler(rule = ExchangeRule(# of clusters: 28), n_chains = 16, n_sweeps = 8, reset_chains = False, machine_power = 2, dtype = <class 'float'>),
  n_samples = 1008,
  n_discard_per_chain = 100,
  sampler_state = MetropolisSamplerState(# accepted = 16652/20864 (79.81211656441718%), rng state=[3398082654 3858810580]),
  n_parameters = 144)

In [48]:
for bia in bia_i_list:
    for b in bia:
        print(r_i(b))

(0.0012089183343395712, 0.010439892721967579)
(0.011614629010124156, -0.006845203808681925)
(-0.0025105833600316397, 0.005030369890678689)
(-0.00486317502361419, -0.004519582937478232)
(3.3757970504233817e-05, -0.00263628985926806)
(0.011468057335219363, -0.004594882938407451)
(-0.0022727840978556514, -0.0014105674535117721)
(-0.015417960174287978, -0.005411387289561282)
(0.004217750859182926, -0.012272186627498193)
(0.011031790553683645, -0.007603066595460171)
(0.0020719623056127184, 0.007872558967794688)
(-0.004201621892680013, 0.006984906035982861)
(-0.008867641487892314, 6.720291500030104e-06)
(-0.004678142890638788, 0.004850019635618519)
(-0.010780293815930044, -0.0011236152901618248)
(-0.008554623208940331, 0.00014468448832657767)
(-0.005176146919097537, 0.009737032311977638)
(-0.0018918275234901516, 0.007406472998204376)
(0.012286691002789271, 0.007939149821143177)
(-0.0075727461307336106, -0.0017283608224576133)
(-0.006291208095292822, -0.00783944268820518)
(0.00056332344628001

In [49]:
for bia in bia_f_list:
    for b in bia:
        print(r_i(b))

(0.0012089183343395712, 0.010439892721967579)
(0.011614629010124156, -0.006845203808681925)
(-0.0025105833600316397, 0.005030369890678689)
(-0.00486317502361419, -0.004519582937478232)
(3.3757970504233817e-05, -0.00263628985926806)
(0.011468057335219363, -0.004594882938407451)
(-0.0022727840978556514, -0.0014105674535117721)
(-0.015417960174287978, -0.005411387289561282)
(0.004217750859182926, -0.012272186627498193)
(0.011031790553683645, -0.007603066595460171)
(0.0020719623056127184, 0.007872558967794688)
(-0.004201621892680013, 0.006984906035982861)
(-0.008867641487892314, 6.720291500030104e-06)
(-0.004678142890638788, 0.004850019635618519)
(-0.010780293815930044, -0.0011236152901618248)
(-0.008554623208940331, 0.00014468448832657767)
(-0.005176146919097537, 0.009737032311977638)
(-0.0018918275234901516, 0.007406472998204376)
(0.012286691002789271, 0.007939149821143177)
(-0.0075727461307336106, -0.0017283608224576133)
(-0.006291208095292822, -0.00783944268820518)
(0.00056332344628001

In [50]:
for kernel in kernel_i_list:
    for l in kernel:
        for m in l:
            print(r_i(m))    

(-0.08315867050330772, -0.05893580988611658)
(0.03900183390662743, 0.006263200027144738)
(0.04022863506300141, 0.004045480307960666)
(0.008874870673260602, 0.1628776072011607)
(-0.01934942989001642, -0.09901489027225925)
(0.03588481196187401, 0.0012019566918465552)
(0.14992732995517627, -0.07132474109730393)
(0.0021220158527113396, -0.0820365530092662)
(-0.040706748851998854, 0.014243345495966392)
(0.01600646804922593, 0.01841333716360235)
(-0.04685332429115798, -0.029224653911940738)
(0.02415232632570817, 0.07725467830303592)
(0.0704468503649793, 0.05275964552876016)
(0.10550040579515152, 0.09701877657250478)
(-0.12615803254137986, 0.02754442531475753)
(0.04732794963343678, -0.046917892321991124)
(-0.009506305698808823, -0.003364067282619499)
(-0.05941242095968713, -0.016538200631886386)
(-0.06499658569623355, -0.03777132205080851)
(-0.037387521911037155, 0.14242180337268856)
(0.04232432247782635, -0.036432991249748935)
(-0.08854258759049885, 0.11082393648881415)
(-0.15873628559109834

In [51]:
for kernel in kernel_f_list:
    for l in kernel:
        for m in l:
            print(r_i(m))      

(-0.08315867050330772, -0.05893580988611658)
(0.03900183390662743, 0.006263200027144738)
(0.04022863506300141, 0.004045480307960666)
(0.008874870673260602, 0.1628776072011607)
(-0.01934942989001642, -0.09901489027225925)
(0.03588481196187401, 0.0012019566918465552)
(0.14992732995517627, -0.07132474109730393)
(0.0021220158527113396, -0.0820365530092662)
(-0.040706748851998854, 0.014243345495966392)
(0.01600646804922593, 0.01841333716360235)
(-0.04685332429115798, -0.029224653911940738)
(0.02415232632570817, 0.07725467830303592)
(0.0704468503649793, 0.05275964552876016)
(0.10550040579515152, 0.09701877657250478)
(-0.12615803254137986, 0.02754442531475753)
(0.04732794963343678, -0.046917892321991124)
(-0.009506305698808823, -0.003364067282619499)
(-0.05941242095968713, -0.016538200631886386)
(-0.06499658569623355, -0.03777132205080851)
(-0.037387521911037155, 0.14242180337268856)
(0.04232432247782635, -0.036432991249748935)
(-0.08854258759049885, 0.11082393648881415)
(-0.15873628559109834