In [None]:
import netket as nk
from jax import numpy as jnp
from jax.nn.initializers import truncated_normal, zeros
from netket import experimental as nkx
from netket.jax import dtype_real
from netket.nn import log_cosh

from ham import ColoredJ1J2, Hubbard, tVModel, J1J2OneD
import sym_sinekan
import mlp


def get_ham():
    """Build the lattice graph, Hilbert space and Hamiltonian selected by the global ``args``.

    All settings are read from the module-level ``args`` namespace (``L``, ``L2``,
    ``boundary``, ``ham``, ``ham_dim``, ``zero_mag``, ``Nf``, ``sign``, ``J2``,
    ``U``, ``V``, ``h``). Most branches assert that the couplings they do not use
    are left unset.

    Returns:
        tuple: ``(graph, hilbert, H)``.

    Raises:
        ValueError: if ``args.boundary`` or ``args.ham`` is not recognised.
        AssertionError: if an option incompatible with the chosen model is set.
    """
    L = args.L
    L2 = args.L2

    # Boundary conditions: "peri" -> periodic, "open" -> open.
    if args.boundary not in ("peri", "open"):
        raise ValueError(f"Unknown boundary: {args.boundary}")
    pbc = args.boundary == "peri"

    ham = args.ham

    # ---- Lattice ----
    if ham.endswith("tri"):
        assert args.ham_dim == 2
        graph = ColoredJ1J2((L, L2), pbc, back_diag=False)
    elif ham.endswith("kag"):
        assert args.ham_dim == 2
        # Nikita's kagome_sqrt18: extent = [3, 2]
        sqrt_three_quarters = 0.75**0.5
        graph = nk.graph.Lattice(
            basis_vectors=[[1, 0], [-0.5, sqrt_three_quarters]],
            extent=[L, L2],
            pbc=pbc,
            site_offsets=[
                [0.5, 0],
                [0.25, sqrt_three_quarters / 2],
                [0.75, sqrt_three_quarters / 2],
            ],
        )
    elif ham == "j1j2":
        assert args.ham_dim == 2
        graph = ColoredJ1J2((L, L2), pbc)
    else:
        # Hypercubic grid; only the 2D case may use a second edge length.
        extent = [L, L2] if args.ham_dim == 2 else [L] * args.ham_dim
        graph = nk.graph.Grid(extent=extent, pbc=pbc)

    # ---- Hilbert space ----
    if ham == "hubb":
        assert not args.zero_mag
        # Same particle number in both spin sectors.
        hilbert = nkx.hilbert.SpinOrbitalFermions(
            n_orbitals=graph.n_nodes, s=1 / 2, n_fermions=(args.Nf, args.Nf)
        )
    elif ham == "tv":
        assert not args.zero_mag
        hilbert = nkx.hilbert.SpinOrbitalFermions(
            n_orbitals=graph.n_nodes, n_fermions=args.Nf
        )
    else:
        assert not args.Nf
        # Only pass the magnetization constraint when it was requested.
        sz_constraint = {"total_sz": 0} if args.zero_mag else {}
        hilbert = nk.hilbert.Spin(s=1 / 2, N=graph.n_nodes, **sz_constraint)

    # ---- Hamiltonian ----
    J = 1
    sign = args.sign == "mars"

    if ham == "ising":
        assert args.sign == "none"
        assert not args.J2
        assert not args.U
        assert not args.V
        H = nk.operator.IsingJax(hilbert=hilbert, graph=graph, J=-J, h=args.h)
    elif ham.startswith("heis"):
        assert not args.J2
        assert not args.U
        assert not args.V
        assert not args.h
        if ham.endswith("tri"):
            # Two bond colors with equal coupling; the sign rule applies to the first only.
            couplings, sign_rule = [J, J], [sign, False]
        else:
            couplings, sign_rule = J, sign
        H = nk.operator.Heisenberg(
            hilbert=hilbert, graph=graph, J=couplings, sign_rule=sign_rule
        )
    elif ham == "j1j2":
        assert not args.U
        assert not args.V
        assert not args.h
        H = nk.operator.Heisenberg(
            hilbert=hilbert,
            graph=graph,
            J=[J, args.J2],
            sign_rule=[sign, False],
        )
    elif ham == "hubb":
        assert args.sign == "none"
        assert not args.J2
        assert not args.V
        assert not args.h
        H = Hubbard(hilbert=hilbert, graph=graph, U=args.U)
    elif ham == "tv":
        assert args.sign == "none"
        assert not args.J2
        assert not args.U
        assert not args.h
        H = tVModel(hilbert=hilbert, graph=graph, V=args.V)
    elif ham == "j1j2_1d":
        # J1J2OneD returns its own graph and Hilbert space, replacing the ones built above.
        # NOTE(review): unlike the other branches, nothing is asserted about
        # ham_dim / U / V / h here - confirm that is intended.
        H, graph, hilbert = J1J2OneD(
            L,
            args.J2,
            pbc,
            use_marshall=sign,
            total_sz=0 if args.zero_mag else None,
        )
    else:
        raise ValueError(f"Unknown ham: {args.ham}")

    return graph, hilbert, H


def get_net(graph, hilbert):
    """Instantiate the variational ansatz named by the global ``args.net``.

    Parameters:
        graph: lattice graph; passed as ``symmetries`` to the GCNN and as ``graph``
            to the LSTM network.
        hilbert: Hilbert space; its ``size`` sets the initializer scale for the
            Jastrow and RBM models, and it is passed to the LSTM network.

    Returns:
        The constructed model object.

    Raises:
        ValueError: if ``args.net`` is not one of the handled names.
        AssertionError: if ``layers`` / ``features`` are set for a model that
            requires them to be 1.
    """
    n_sites = hilbert.size
    net = args.net

    if net == "jas":
        assert args.layers == 1
        assert args.features == 1
        return nk.models.Jastrow(
            param_dtype=args.dtype,
            kernel_init=truncated_normal(stddev=1 / n_sites),
        )

    if net == "rbm":
        assert args.layers == 1
        alpha = args.features
        # The initializer width scales differently for real and non-real parameters.
        if jnp.issubdtype(args.dtype, jnp.floating):
            stddev = 1 / (alpha**0.5 * n_sites)
        else:
            stddev = 1 / (alpha**0.25 * n_sites**0.75)
        return nk.models.RBM(
            alpha=alpha,
            param_dtype=args.dtype,
            activation=log_cosh,
            kernel_init=truncated_normal(stddev=stddev),
            hidden_bias_init=zeros,
            visible_bias_init=zeros,
        )

    if net == "gcnn":
        return nk.models.GCNN(
            symmetries=graph,
            layers=args.layers,
            features=args.features,
            param_dtype=args.dtype,
        )

    if net == "rnn_lstm":
        return nkx.models.FastLSTMNet(
            hilbert=hilbert,
            layers=args.layers,
            features=args.features,
            graph=graph,
            param_dtype=args.dtype,
        )

    # The remaining models are project-local and are sized by ``layers_hidden``.
    if net == "mlp":
        return mlp.MLP(layers_hidden=args.layers_hidden)

    if net == "symmlp":
        return mlp.SymmetricMLP(layers_hidden=args.layers_hidden)

    if net == "sinekan":
        return sym_sinekan.SineKAN(
            layers_hidden=args.layers_hidden, grid_size=args.grid_size
        )

    if net == "sym_sinekan":
        return sym_sinekan.SymmetricSineKAN1D(
            layers_hidden=args.layers_hidden, grid_size=args.grid_size
        )

    raise ValueError(f"Unknown net: {args.net}")


def get_sampler(graph, hilbert):
    """Choose the sampler matching the global ``args`` (model, constraint and network).

    Parameters:
        graph: lattice graph; only used by the exchange sampler (doubled via a
            disjoint union for the Hubbard model).
        hilbert: Hilbert space to sample from.

    Returns:
        A NetKet sampler instance.

    Raises:
        NotImplementedError: for an ``rnn*`` network combined with a fermionic
            model (``hubb`` / ``tv``) or with ``--zero_mag``.
    """
    constrained = args.ham in ["hubb", "tv"] or args.zero_mag
    is_rnn = args.net.startswith("rnn")

    if not constrained:
        if is_rnn:
            return nk.sampler.ARDirectSampler(hilbert, dtype=dtype_real(args.dtype))
        return nk.sampler.MetropolisLocal(
            hilbert, n_chains=args.batch_size, dtype=dtype_real(args.dtype)
        )

    # From here on the Hilbert space carries a particle-number / magnetization constraint.
    if is_rnn:
        raise NotImplementedError

    if args.net.endswith("sinekan"):
        # NOTE(review): the sinekan networks keep the local-move sampler even in the
        # constrained case - confirm this special case is intended.
        return nk.sampler.MetropolisLocal(
            hilbert, n_chains=args.batch_size, dtype=dtype_real(args.dtype)
        )

    if args.ham == "hubb":
        # Two copies of the lattice for the exchange moves of the Hubbard model.
        graph = nk.graph.disjoint_union(graph, graph)
    return nk.sampler.MetropolisExchange(
        hilbert,
        graph=graph,
        n_chains=args.batch_size,
        dtype=dtype_real(args.dtype),
    )


def get_vstate(sampler, model):
    """Wrap ``sampler`` and ``model`` into a Monte Carlo variational state.

    The sample count, chunk size and seed are taken from the global ``args``;
    no samples are discarded per chain.

    Returns:
        nk.vqs.MCState: the variational state.
    """
    mc_options = dict(
        n_samples=args.batch_size,
        n_discard_per_chain=0,
        chunk_size=args.chunk_size,
        seed=args.seed,
    )
    return nk.vqs.MCState(sampler, model, **mc_options)

In [None]:
import argparse
import os
from datetime import datetime
import numpy as np
import netket as nk
import flax

# --- Provided Functions (unchanged) --- #

def parse_tuple(arg_str):
    """Convert a comma-separated string such as ``"64,64"`` into a tuple of ints.

    Intended as an ``argparse`` ``type=`` callable: whitespace around each item is
    ignored, and any item that is not an integer is reported as an
    ``argparse.ArgumentTypeError``.
    """
    try:
        values = [int(token.strip()) for token in arg_str.split(",")]
    except ValueError:
        raise argparse.ArgumentTypeError(
            f"Invalid tuple format: '{arg_str}'. Expected format: '64,64'."
        )
    return tuple(values)

def get_parser():
    """Create the command-line parser for the physics, network, optimizer and system options.

    The options are declared as a table of ``(flags, add_argument keywords)`` per
    argument group and registered in the listed order. Abbreviated option names
    are disabled.

    Returns:
        argparse.ArgumentParser: the configured parser.
    """
    sections = (
        (
            "physics parameters",
            (
                (("--ham",), dict(
                    type=str,
                    default="ising",
                    choices=["ising", "heis", "heis_tri", "heis_kag", "j1j2", "hubb", "tv", "j1j2_1d"],
                    help="Hamiltonian type",
                )),
                (("--boundary",), dict(
                    type=str,
                    default="open",
                    choices=["open", "peri"],
                    help="boundary conditions",
                )),
                (("--sign",), dict(
                    type=str,
                    default="none",
                    choices=["none", "mars"],
                    help="sign rule",
                )),
                (("--ham_dim",), dict(
                    type=int,
                    default=1,
                    choices=[1, 2],
                    help="dimension of the lattice",
                )),
                (("--L",), dict(type=int, default=4, help="edge length of the lattice")),
                (("--L2",), dict(type=int, default=0, help="another edge length of the lattice")),
                (("--J2",), dict(type=float, default=0, help="2nd nearest neighbor interaction")),
                (("--U",), dict(type=float, default=0, help="on-site interaction")),
                (("--V",), dict(type=float, default=0, help="repulsive interaction")),
                (("--h",), dict(type=float, default=0, help="external field")),
                (("--zero_mag",), dict(
                    action="store_true",
                    help="use zero magnetization constraint",
                )),
                (("--Nf",), dict(type=int, default=0, help="number of fermions")),
            ),
        ),
        (
            "network parameters",
            (
                (("--net",), dict(
                    type=str,
                    default="jas",
                    choices=["jas", "rbm", "gcnn", "rnn_lstm", "sym_sinekan", "sinekan", "mlp", "symmlp"],
                    help="network type",
                )),
                (("--layers",), dict(type=int, default=1, help="number of layers")),
                (("--layers_hidden",), dict(
                    type=parse_tuple,
                    help="Comma-separated tuple for hidden layers, e.g., '64,64'.",
                )),
                (("--grid_size",), dict(type=int, default=8, help="grid size of a sinekan model")),
                (("--features",), dict(type=int, default=1, help="number of features")),
                (("--dtype",), dict(
                    type=str,
                    default="float32",
                    choices=["float32", "float64", "complex64", "complex128"],
                    help="data type",
                )),
            ),
        ),
        (
            "optimizer parameters",
            (
                (("--seed",), dict(type=int, default=0, help="random seed, 0 for randomized")),
                (("--optimizer",), dict(
                    type=str,
                    default="sr",
                    choices=["adam", "sgd", "sr", "custom"],
                    help="optimizer type",
                )),
                (("--split_real",), dict(
                    action="store_true",
                    help="split real and imaginary parts of parameters in the optimizer",
                )),
                (("--batch_size",), dict(type=int, default=1024, help="batch size")),
                (("--lr",), dict(type=float, default=1e-3, help="learning rate")),
                (("--decay_time",), dict(
                    type=float,
                    default=1000,
                    help="controls decay time for learning rate",
                )),
                (("--diag_shift",), dict(type=float, default=1e-2, help="diagonal shift of SR")),
                (("--max_step",), dict(
                    type=int,
                    default=10**4,
                    help="number of training/sampling steps",
                )),
                (("--drop_step",), dict(
                    type=int,
                    default=10**4,
                    help="steps after which to decrease learning rate near converging point",
                )),
                (("--grad_clip",), dict(
                    type=float,
                    default=0,
                    help="global norm to clip gradients, 0 for disabled",
                )),
                (("--chunk_size",), dict(
                    type=int,
                    default=1024,
                    help="chunk size, 0 for disabled",
                )),
                (("--estim_size",), dict(
                    type=int,
                    default=1024**2,
                    help="batch size to estimate the Hamiltonian, 0 for matching 'batch_size'",
                )),
            ),
        ),
        (
            "system parameters",
            (
                (("--show_progress",), dict(action="store_true", help="show progress")),
                (("--cuda",), dict(
                    type=str,
                    default="auto",
                    help="GPU ID, empty string for disabling GPU, multi-GPU is not supported yet",
                )),
                (("--run_name",), dict(
                    type=str,
                    default="",
                    help="output subdirectory to keep repeated runs, empty string for disabled",
                )),
                (("-o", "--out_dir"), dict(
                    type=str,
                    default="./out",
                    help="output directory, empty string for disabled",
                )),
            ),
        ),
    )

    parser = argparse.ArgumentParser(allow_abbrev=False)
    for title, options in sections:
        section = parser.add_argument_group(title)
        for flags, spec in options:
            section.add_argument(*flags, **spec)

    return parser

def get_ham_net_name(args):
    """Derive the two name strings that identify a run from its arguments.

    Parameters:
        args: namespace holding the physics, network and optimizer options.

    Returns:
        tuple: ``(ham_name, net_name)``. ``ham_name`` encodes the model, boundary,
        optional sign rule, dimension and size, followed by every non-zero coupling
        and constraint. ``net_name`` encodes the network (with its size fields), the
        optimizer and the optional ``split_real`` / ``grad_clip`` settings.
    """
    # Collect template pieces first; they are filled from the namespace in one pass.
    ham_pieces = ["{ham}_{boundary}"]
    if args.sign != "none":
        ham_pieces.append("_{sign}")
    ham_pieces.append("_{ham_dim}d_L{L}")
    if args.L2 and args.L2 != args.L:
        # The second edge length is only spelled out for non-square lattices.
        ham_pieces.append(",{L2}")
    for attr in ("J2", "U", "V", "h"):
        if getattr(args, attr):
            ham_pieces.append("_" + attr + "={" + attr + ":g}")
    if args.zero_mag:
        ham_pieces.append("_zm")
    if args.Nf:
        ham_pieces.append("_Nf{Nf}")
    ham_name = "".join(ham_pieces).format(**vars(args))

    net_pieces = ["{net}"]
    if args.net == "rbm":
        net_pieces.append("_a{features}")
    elif args.net != "jas":
        net_pieces.append("_l{layers}_f{features}")
    net_pieces.append("_{optimizer}")
    if args.split_real:
        net_pieces.append("_sp")
    if args.grad_clip:
        net_pieces.append("_gc{grad_clip:g}")
    net_name = "".join(net_pieces).format(**vars(args))

    return ham_name, net_name

def post_init_args(args):
    """Normalise a freshly parsed namespace in place and attach derived fields.

    Fills in sentinel values (``L2``, ``seed``, ``diag_shift``, ``chunk_size``,
    ``estim_size``), converts ``dtype`` to the matching NumPy type, and adds
    ``ham_name``, ``net_name``, ``full_out_dir`` and ``log_filename``.

    Raises:
        AssertionError: if ``L2`` is set for a 1D lattice.
        ValueError: if ``args.dtype`` is not a supported data type.
    """
    if args.ham_dim == 1:
        assert args.L2 == 0
    elif args.L2 == 0:
        # A missing second edge length means a square lattice.
        args.L2 = args.L

    if args.seed == 0:
        # The seed depends on the time and the PID
        seed_source = (datetime.now(), os.getpid())
        args.seed = hash(seed_source) & 0xFFFFFFFF

    if args.optimizer == "sr" and args.diag_shift == 0:
        args.diag_shift = args.lr

    if args.chunk_size == 0:
        args.chunk_size = None

    if args.estim_size == 0:
        args.estim_size = args.batch_size

    args.ham_name, args.net_name = get_ham_net_name(args)

    # Accept either the string alias or the NumPy type itself; store the NumPy type.
    supported_dtypes = (
        ("float32", np.float32),
        ("float64", np.float64),
        ("complex64", np.complex64),
        ("complex128", np.complex128),
    )
    for alias, np_type in supported_dtypes:
        if args.dtype in [alias, np_type]:
            args.dtype = np_type
            break
    else:
        raise ValueError(f"Unknown dtype: {args.dtype}")

    if not args.out_dir:
        # Output disabled: no directory and no log file.
        args.full_out_dir = None
        args.log_filename = None
        return

    args.full_out_dir = "{out_dir}/{ham_name}/{net_name}/".format(**vars(args))
    if args.run_name:
        args.full_out_dir = "{full_out_dir}{run_name}/".format(**vars(args))
    args.log_filename = args.full_out_dir + "out"

def set_env(args):
    """Export the GPU / XLA environment variables and seed NumPy's global RNG.

    ``CUDA_DEVICE_ORDER`` and ``CUDA_VISIBLE_DEVICES`` are only written when
    ``args.cuda`` is not ``"auto"``; XLA memory preallocation is always disabled.
    """
    env_updates = {}
    if args.cuda != "auto":
        env_updates["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
        env_updates["CUDA_VISIBLE_DEVICES"] = args.cuda
    env_updates["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
    os.environ.update(env_updates)

    np.random.seed(args.seed)

# --- End Provided Functions --- #

# Here is our helper function to build the variational state from the parsed arguments.
# (This function uses your already defined get_ham, get_net, and get_sampler functions.)
def get_vstate_from_args(args, ham_kwargs=None, net_kwargs=None, sampler_kwargs=None):
    """
    Builds the Hamiltonian, model and sampler, and wraps them in a variational state.

    Parameters:
      args : An argparse.Namespace object. In this function it supplies the MCState
             settings (batch_size, chunk_size, seed). Note that get_ham, get_net and
             get_sampler do not receive it: they read the module-level `args`, so
             that global must already be set (normally to this same namespace).
      ham_kwargs, net_kwargs, sampler_kwargs : Optional dicts forwarded as keyword
             arguments to get_ham, get_net and get_sampler. As defined in this
             notebook those builders accept no extra keyword arguments, so a
             non-empty dict raises TypeError; leave them as None unless the
             builders are extended.

    Returns:
      A 4-tuple (vstate, H, graph, hilbert):
        vstate  : The NetKet variational state (nk.vqs.MCState).
        H       : The Hamiltonian (useful for computing expectation values).
        graph   : The lattice graph returned by get_ham.
        hilbert : The Hilbert space returned by get_ham.
    """
    ham_kwargs = ham_kwargs or {}
    net_kwargs = net_kwargs or {}
    sampler_kwargs = sampler_kwargs or {}

    # Build the Hamiltonian, Hilbert space, and lattice graph.
    graph, hilbert, H = get_ham(**ham_kwargs)

    # Build the neural network model.
    model = get_net(graph, hilbert, **net_kwargs)

    # Build the sampler.
    sampler = get_sampler(graph, hilbert, **sampler_kwargs)

    # Construct the variational state; no samples are discarded per chain.
    vstate = nk.vqs.MCState(
        sampler,
        model,
        n_samples=args.batch_size,
        n_discard_per_chain=0,
        chunk_size=args.chunk_size,
        seed=args.seed,
    )
    return vstate, H, graph, hilbert

# ---- In a Jupyter Notebook, simulate command-line arguments ---- #
# Reference command line that this cell was adapted from:
#   --ham j1j2_1d --boundary peri --sign mars --J2 0.4 --ham_dim 1 --L 100 --zero_mag
#   --net sym_sinekan --layers_hidden 64,64,1 --grid_size 8 --seed 123 --optimizer custom
#   --drop_step 50_000 --decay_time 1_000 --max_step 60_000 --show_progress --lr 1e-4
# NOTE: the list below does NOT reproduce that command exactly. It selects
# --net symmlp with --layers_hidden 256,256,1, --drop_step 30000, --max_step 34000
# and --lr 1e-3, and leaves --grid_size at its parser default.

# The arguments are passed to the parser as a list of strings:
args_list = [
    "--ham", "j1j2_1d",
    "--boundary", "peri",
    "--sign", "mars",
    "--J2", "0.4",
    "--ham_dim", "1",
    "--L", "100",
    "--zero_mag",
    "--net", "symmlp",
    "--layers_hidden", "256,256,1",
    # "--grid_size", "8",
    # "--net", "rbm",
    # "--features", "128",
    "--seed", "123",
    "--optimizer", "custom",
    "--drop_step", "30000",
    "--decay_time", "1000",
    "--max_step", "34000",
    "--show_progress",
    "--lr", "1e-3"
]

# Parse the simulated arguments. `args` becomes the module-level namespace that
# get_ham, get_net and get_sampler read.
parser = get_parser()
args = parser.parse_args(args_list)

# Post-process the arguments and set environment variables.
post_init_args(args)
set_env(args)

# ---- Build the variational state using our helper function ---- #
initial_vstate, H, graph, hilbert = get_vstate_from_args(args)

# 'initial_vstate' is the freshly initialised variational state built from the
# parameters above. It can be passed to a VMC driver, used to compute expectation
# values, or used as the template when restoring a saved state via
# flax.serialization (as done in a later cell).

print("Variational state built with", initial_vstate.n_parameters, "parameters.")
print("Hamiltonian:", H)


Variational state built with 91905 parameters.
Hamiltonian: LocalOperator(dim=100, #acting_on=200 locations, constant=0.0, dtype=float64)


In [None]:
# Directory holding the saved run. The sub-path mirrors the
# "{out_dir}/{ham_name}/{net_name}/[{run_name}/]" layout built in post_init_args, so it is
# derived from `args` instead of being typed out by hand; a hand-written copy silently
# points at the wrong checkpoint as soon as `args_list` is edited.
# For the arguments above this evaluates to
# ".../out/j1j2_1d_peri_mars_1d_L100_J2=0.4_zm/symmlp_l1_f1_custom/".
DRIVE_OUT_DIR = "/content/drive/MyDrive/Spin_Lattice/out"
fp = f"{DRIVE_OUT_DIR}/{args.ham_name}/{args.net_name}/"
if args.run_name:
    fp += f"{args.run_name}/"

In [None]:
# Restore the saved state from `fp`, using the freshly built `initial_vstate` as the
# deserialization target.
# (Alternative source used earlier: "/content/final_vstate (3).mpack")
checkpoint_path = fp + "final_vstate.mpack"
with open(checkpoint_path, "rb") as file:
    serialized_vstate = file.read()
vstate = flax.serialization.from_bytes(initial_vstate, serialized_vstate)

In [None]:
# Evaluate the expectation value once and reuse it: the original line called
# vstate.expect(H) three times (mean, error, variance), repeating the whole estimate
# for each printed number. A single evaluation also guarantees that all three
# figures come from the same estimate.
# To report another observable (e.g. a structure factor operator), pass it to
# vstate.expect in place of H.
energy_stats = vstate.expect(H)
print(
    energy_stats.mean.real,
    "+/-",
    energy_stats.error_of_mean.real,
    ", variance:",
    energy_stats.variance.real,
)

-37.907705218400494 +/- 0.018664025171303465 , variance: 0.3567061356493305
