In [1]:
import netket as nk
import netket.experimental as nkx

import numpy as np
from tqdm import tqdm

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from netket.experimental.operator.fermion import destroy as c
from netket.experimental.operator.fermion import create as cdag
from netket.experimental.operator.fermion import number as nc

In [3]:
L = 4  # Side of the square
graph = nk.graph.Square(L)
N = graph.n_nodes

t = 1.0    # hopping amplitude
N_f = 1    # number of fermions in every spin species
S = 5 / 2  # fermion spin -> 2S + 1 = 6 spin species

# Spin labels passed to the fermion operators below: the integers -2S, -2S+2, ..., 2S
# (i.e. 2*S_z; for S = 5/2 this is [-5, -3, -1, 1, 3, 5]). Deriving them from S keeps the
# Hilbert space, the per-spin particle numbers and the hopping loop consistent.
two_S = round(2 * S)
sz_values = list(range(-two_S, two_S + 1, 2))

hi_help = nkx.hilbert.SpinOrbitalFermions(
    N, s=S, n_fermions_per_spin=(N_f,) * len(sz_values)
)

# Nearest-neighbour hopping on every graph edge, identical for every spin species
# (quadratic Hamiltonian: no interaction terms).
H_help = 0.0
for (i, j) in graph.edges():
    for sz in sz_values:
        H_help -= t * (
            cdag(hi_help, i, sz) * c(hi_help, j, sz)
            + cdag(hi_help, j, sz) * c(hi_help, i, sz)
        )

In [6]:
# Number of spin sub-sectors of the Hilbert space (displayed; 6 in the recorded run for s = 5/2).
n_spin_sectors = hi_help.n_spin_subsectors
n_spin_sectors

6

In [5]:
import flax.linen as nn
from netket.utils.types import NNInitFunc
from netket.nn.masked_linear import default_kernel_init
from netket import jax as nkjax
from typing import Any, Callable, Sequence
from functools import partial

import jax
import jax.numpy as jnp
DType = Any  # loose type alias, used to annotate the model's `param_dtype` field

In [9]:
class LogSlaterDeterminant(nn.Module):
    """Log-amplitude given by a product of one Slater determinant per spin sub-sector.

    A single orbital matrix ``M`` of shape ``(n_orbitals, max_sz n_fermions_per_spin[sz])``
    is shared by all spin sectors. For a configuration ``n`` and each sector ``sz``, the
    rows of ``M`` selected by that sector's occupied orbitals (and its first
    ``n_fermions_per_spin[sz]`` columns) form a square matrix; the model returns the sum
    of their complex log-determinants, i.e. ``log prod_sz det(M_sz)``.

    Attributes:
        hilbert: Fermionic Hilbert space. Every spin sub-sector must have a fixed
            fermion number (no ``None`` entries in ``n_fermions_per_spin``).
        kernel_init: Initializer for the orbital matrix ``M``.
        param_dtype: dtype of ``M`` (e.g. ``complex`` for complex amplitudes).

    Raises:
        ValueError: at initialization, if some sector has an unconstrained fermion number.
    """

    hilbert: nkx.hilbert.SpinOrbitalFermions
    kernel_init: NNInitFunc = default_kernel_init
    param_dtype: DType = float

    def setup(self):
        n_per_spin = self.hilbert.n_fermions_per_spin
        if n_per_spin is None or any(nf is None for nf in n_per_spin):
            raise ValueError(
                "LogSlaterDeterminant requires a fixed number of fermions in every "
                "spin sector (n_fermions_per_spin must not contain None)."
            )
        # One column per fermion of the most populated sector; sectors with fewer
        # fermions use the leading columns. For equal counts this is the same shape
        # as (n_orbitals, n_fermions_per_spin[0]).
        self.M = self.param(
            "M",
            self.kernel_init,
            (self.hilbert.n_orbitals, max(n_per_spin, default=0)),
            self.param_dtype,
        )

    def __call__(self, n):
        n_orbitals = self.hilbert.n_orbitals
        n_fermions = self.hilbert.n_fermions

        # Static (trace-time) layout: (sector index, start position inside the sorted
        # list of occupied indices, number of fermions in that sector).
        sectors = []
        start = 0
        for sz, nf in enumerate(self.hilbert.n_fermions_per_spin):
            sectors.append((sz, start, nf))
            start += nf

        @partial(jnp.vectorize, signature="(n)->()")
        def log_sd(occ):
            # Sorted global indices of the occupied spin-orbitals. With a fixed count per
            # sector, sector sz's fermions are a contiguous run of this list.
            R = occ.nonzero(size=n_fermions)[0]

            log_psi = 0.0
            for sz, first, nf in sectors:
                if nf == 0:
                    continue  # empty sector: determinant of a 0x0 matrix is 1
                # Convert global indices (sector sz occupies [sz*n_orbitals, (sz+1)*n_orbitals)
                # in SpinOrbitalFermions) to orbital indices in [0, n_orbitals). Without this
                # shift every sector > 0 indexes past the rows of M, which JAX does not report.
                orbitals = R[first:first + nf] - sz * n_orbitals
                log_psi += nkjax.logdet_cmplx(self.M[orbitals, :nf])
            return log_psi

        return log_sd(n)

In [10]:
# VMC hyper-parameters, gathered in one place so the run can be tuned and reproduced.
SEED = 1234  # fixes parameter initialisation and Monte Carlo sampling
N_CHAINS = 16
SWEEP_SIZE = 64
N_SAMPLES = 512
N_DISCARD_PER_CHAIN = 16
LEARNING_RATE = 0.01
DIAG_SHIFT = 0.05
N_ITER = 300

model = LogSlaterDeterminant(hi_help, param_dtype=complex)
# exchange_spins=False: H_help has no spin-flip terms, so (presumably) moves only
# exchange particles within the same spin species — confirm against sampler docs.
sa = nkx.sampler.MetropolisParticleExchange(
    hi_help, graph=graph, n_chains=N_CHAINS, exchange_spins=False, sweep_size=SWEEP_SIZE
)
op = nk.optimizer.Sgd(learning_rate=LEARNING_RATE)
vstate = nk.vqs.MCState(
    sa,
    model,
    n_samples=N_SAMPLES,
    n_discard_per_chain=N_DISCARD_PER_CHAIN,
    seed=SEED,
    sampler_seed=SEED + 1,
)
# The ansatz has complex parameters (param_dtype=complex), hence holomorphic SR.
preconditioner = nk.optimizer.SR(diag_shift=DIAG_SHIFT, holomorphic=True)
gs = nk.VMC(H_help, op, variational_state=vstate, preconditioner=preconditioner)
bfsd_log = nk.logging.RuntimeLog()
gs.run(n_iter=N_ITER, out=bfsd_log)

100%|██████████| 300/300 [02:26<00:00,  2.05it/s, Energy=-23.9999959-0.0000032j ± 0.0000075 [σ²=0.0000000, R̂=1.0121]]


(RuntimeLog():
  keys = ['acceptance', 'Energy'],)

In [1]:
import os

# Show the kernel's working directory (the label reads "Current working directory:").
working_dir = os.getcwd()
print("当前工作目录:", working_dir)

当前工作目录: /root/netket_test/notebooks
