In [1]:
import copy
import os

# NetKet presumably reads this flag from the environment at import time, so it must be set
# before netket (and the project modules that import it) are imported.
os.environ['NETKET_EXPERIMENTAL_SHARDING'] = '1'

import flax
import netket as nk

from model import LogSlaterDeterminant, LogFullNeuralBackflow, CombinedModel
from hamiltonian.hubbard import Hubbard, Hubbard_extend
from drivers.VMC_infinity import VMCInfinity

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Hubbard parameters as passed to Hubbard / Hubbard_extend
# (conventionally t = hopping amplitude, U = on-site interaction).
t, U = 1.0, 1.0
# Particle number; passed to the Hamiltonian builders as (N_f, N_f) —
# presumably (N_up, N_down), confirm against hamiltonian.hubbard.
N_f = 8

# Lattice extent along the two axes.
Lx, Ly = 4, 4

In [3]:
# Both builders receive identical couplings, lattice extent, boundary flags and particle numbers.
# NOTE(review): [False, True] is presumably per-axis periodicity (open along x, periodic along y)
# — confirm in hamiltonian.hubbard.
hubbard_args = (t, U, [Lx, Ly], [False, True], (N_f, N_f))

# Physical system (hi, H, graph) and its extended counterpart used by the combined model.
hi, H, graph = Hubbard(*hubbard_args)
hi_help, H_full, graph_full = Hubbard_extend(*hubbard_args)

# Switch both operators to their JAX representation.
H, H_full = H.to_jax_operator(), H_full.to_jax_operator()

In [4]:
# Define the variational model.
# `model_base` is a single-system neural-backflow network on `hi`. The combined model acts on
# the extended space `hi_help` and holds two backflow sub-networks, passed as `sys_backflow`
# and `env_backflow` (presumably system and environment, per the keyword names).
# Optional: param_dtype=complex can be passed to these constructors for complex-valued
# parameters — confirm each model supports it.
model_base = LogFullNeuralBackflow.LogFullNeuralBackflow(hi, hidden_units=32)

# Deep copies give the combined model two distinct module instances and leave `model_base`
# itself untouched; it is reused for the single-system reference state `vstate_base`.
sys_backflow = copy.deepcopy(model_base)
env_backflow = copy.deepcopy(model_base)
model = CombinedModel.CombinedNeuralBackflow(
    hi_help,
    sys_backflow=sys_backflow,
    env_backflow=env_backflow,
)

In [5]:
# Metropolis exchange sampler on the extended Hilbert space, with moves defined by the
# edges of `graph_full`. One sweep performs hi_help.size elementary moves. That is 96 for
# the current 4x4 setup, matching the former hard-coded value; it now follows the system size.
sa = nk.sampler.MetropolisExchange(
    hi_help,
    graph=graph_full,
    n_chains=16,
    sweep_size=hi_help.size,
)
# NOTE(review): `op` is not used by any visible cell (no driver is built); kept for later use.
op = nk.optimizer.Sgd(learning_rate=0.01)

  sa =  nk.sampler.MetropolisExchange(hi_help, graph=graph_full, n_chains=16, sweep_size=96)


In [6]:
# Variational state of the combined model: 2**12 samples per estimate, discarding the first
# 16 samples of every Markov chain.
# Fixed seeds make parameter initialisation and sampling reproducible across kernel restarts.
vstate = nk.vqs.MCState(
    sa,
    model,
    n_samples=2**12,
    n_discard_per_chain=16,
    seed=0,
    sampler_seed=1,
)

In [7]:
# Energy estimate of the freshly initialised combined model on the extended Hamiltonian.
energy_combined_init = vstate.expect(H_full)
energy_combined_init

12.11+0.00j ± 0.18 [σ²=123.13, R̂=1.0036]

In [73]:
# Print one sampled configuration (vstate.samples[4][0]) as stacked Lx x Ly planes.
# The number of planes is inferred from the sample length (-1) rather than hard-coded,
# so this keeps working if the extended system size changes.
# NOTE(review): assumes modes are ordered plane-major, then over the Lx x Ly grid —
# confirm against the site ordering of hi_help / graph_full.
print(vstate.samples[4][0].reshape(-1, Lx, Ly))

[[[1 0 0 1]
  [1 0 1 0]
  [0 1 0 1]
  [0 1 1 0]]

 [[0 1 1 0]
  [0 0 0 1]
  [1 0 1 1]
  [1 0 0 1]]

 [[0 1 1 0]
  [0 0 1 1]
  [0 0 1 1]
  [1 0 0 1]]

 [[1 0 0 1]
  [1 1 0 0]
  [1 1 0 0]
  [0 1 1 0]]

 [[0 0 1 0]
  [0 1 0 1]
  [1 0 1 0]
  [1 0 1 1]]

 [[1 1 0 1]
  [0 0 1 1]
  [0 1 0 0]
  [0 1 0 1]]]


In [57]:
# Top-level groups of the combined model's parameter tree.
param_groups = vstate.parameters.keys()
param_groups

dict_keys(['env_backflow', 'sys_backflow'])

In [34]:
# Work on a copy of the combined model's parameters. The entries replaced later are then
# edited on this copy rather than on the tree held by `vstate`, and are written back explicitly.
pars = flax.core.copy(vstate.parameters, {})

In [56]:
# pars

In [46]:
# Single-system reference: a fermion-hop sampler on `hi`/`graph` and a variational state for
# `model_base`. It is used to load and check the pretrained backflow weights.
sa_base = nk.sampler.MetropolisFermionHop(hi, graph=graph)
# Fixed seeds make the initialisation and sampling of this reference state reproducible.
vstate_base = nk.vqs.MCState(
    sa_base,
    model_base,
    n_samples=2**12,
    n_discard_per_chain=16,
    seed=0,
    sampler_seed=1,
)

In [47]:
# Shapes of the dense-layer weights of the freshly initialised base model.
for layer_name in ("Dense_0", "Dense_1"):
    layer = vstate_base.parameters[layer_name]
    print(layer["kernel"].shape)
    print(layer["bias"].shape)

(32, 32)
(32,)
(32, 512)
(512,)


In [48]:
# Energy estimate of the untrained base model, as a reference before loading the checkpoint.
energy_base_init = vstate_base.expect(H)
energy_base_init

30.15+0.00j ± 0.28 [σ²=144.18, R̂=1.0120]

In [49]:
# Load pretrained weights for the single-system model from a flax msgpack checkpoint.
# NOTE(review): the checkpoint name mentions "U8" and "pbc", whereas this notebook builds the
# Hamiltonian with U = 1.0 and boundary flags [False, True] — confirm the checkpoint is
# intended for this setup.
checkpoint_path = "../test/4_4_U8_NNS/4_4_U8_N8_pbc_test.mpack"


def _leaf_shapes(tree):
    """Map each flattened key path of a (Frozen)dict pytree to the shape of its array leaf."""
    return {path: leaf.shape for path, leaf in flax.traverse_util.flatten_dict(tree).items()}


expected_shapes = _leaf_shapes(vstate_base.variables)
with open(checkpoint_path, "rb") as file:
    vstate_base.variables = flax.serialization.from_bytes(vstate_base.variables, file.read())
# from_bytes restores values into the target's structure; also verify every array kept its shape,
# so a checkpoint from a different architecture fails here instead of later.
assert _leaf_shapes(vstate_base.variables) == expected_shapes, (
    "checkpoint array shapes do not match model_base"
)

In [50]:
# Repeat the dense-layer shape report after loading the checkpoint (expected to be unchanged).
print(
    *(
        vstate_base.parameters[layer][leaf].shape
        for layer in ("Dense_0", "Dense_1")
        for leaf in ("kernel", "bias")
    ),
    sep="\n",
)

(32, 32)
(32,)
(32, 512)
(512,)


In [55]:
# Energy estimate of the base model with the pretrained weights loaded.
energy_base_loaded = vstate_base.expect(H)
energy_base_loaded

-6.62-0.00j ± 0.12 [σ²=6.78, R̂=1.0620]

In [58]:
# Inspect the system-backflow parameters in the copied parameter tree.
sys_backflow_params = pars['sys_backflow']
sys_backflow_params

{'Dense_0': {'bias': Array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],      dtype=float64),
  'kernel': Array([[-0.22161072, -0.39535937, -0.04723201, ...,  0.26079883,
           0.23031004,  0.00903337],
         [-0.22988689,  0.00165779, -0.02688494, ..., -0.02885858,
           0.08574698,  0.11739345],
         [ 0.08500679,  0.0007416 ,  0.22085993, ...,  0.10737033,
          -0.21011714, -0.0796543 ],
         ...,
         [ 0.051961  , -0.3449762 , -0.1146706 , ..., -0.10057967,
          -0.05921012,  0.19593621],
         [-0.22706153,  0.16084708, -0.21472319, ..., -0.14618996,
           0.18109311,  0.02442786],
         [ 0.15172325, -0.0654453 , -0.14680221, ..., -0.24764085,
           0.24570636,  0.01876069]], dtype=float64)},
 'Dense_1': {'bias': Array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.

In [60]:
# Initialise both backflow branches of the combined model with the pretrained single-system
# weights, each branch receiving its own deep copy.
for branch in ("sys_backflow", "env_backflow"):
    pars[branch] = copy.deepcopy(vstate_base.parameters)

In [61]:
# Write the edited parameter tree back into the combined variational state.
vstate.parameters = pars

In [62]:
# Energy estimate of the combined model initialised from the pretrained backflow weights.
# NOTE(review): in the recorded run this came out far above the freshly initialised combined
# model's estimate — investigate before relying on this initialisation.
energy_combined_loaded = vstate.expect(H_full)
energy_combined_loaded

520.4+0.0j ± 6.4 [σ²=92925.5, R̂=1.0059]