In [None]:
import sys
import os
# os.environ['JAX_PLATFORM_NAME'] = 'cpu'

import numpy as np
import time

# Add the parent directory of the notebook's working directory (the project root) to
# sys.path, presumably so that the `QES` imports below resolve without installing the package.
# NOTE: `os.path.dirname(os.curdir)` (used previously) always evaluates to '', which hid the
# fact that this is relative to the current working directory; spell that out explicitly.
script_dir = os.getcwd()
parent_dir = os.path.abspath(os.path.join(script_dir, os.pardir))
if parent_dir not in sys.path:
    sys.path.append(parent_dir)

# ------------------------------------------------------------------
#! General
import QES.general_python.common.binary as Binary
from QES.general_python.common.timer import Timer, timeit
from QES.general_python.common.binary import JAX_AVAILABLE, get_backend, get_global_logger
from QES.general_python.common.plot import Plotter, MatrixPrinter

# ------------------------------------------------------------------
#! Lattice
from QES.general_python.lattices import choose_lattice, plot_bonds

# ------------------------------------------------------------------
#! Sampler
import QES.Solver.MonteCarlo.sampler as Sampler
import QES.NQS.nqs as NQS
import QES.NQS.nqs_train as NQST
import QES.NQS.tdvp as TDVP
from QES.general_python.ml.schedulers import EarlyStopping, ConstantScheduler, ExponentialDecayScheduler

# ------------------------------------------------------------------
#! ODE solver
from QES.general_python.algebra.ode import choose_ode

# ------------------------------------------------------------------
#! Networks
from QES.general_python.ml.net_impl.networks.net_rbm import RBM
from QES.general_python.ml.net_impl.networks.net_cnn import CNN
from QES.general_python.ml.net_impl.activation_functions import relu_jnp, tanh_jnp, sigmoid_jnp, leaky_relu_jnp, elu_jnp, poly6_jnp, softplus_jnp

# ------------------------------------------------------------------
#! Hamiltonians
from QES.general_python.algebra.linalg import act, overlap
from QES.Algebra.Model.dummy import DummyHamiltonian
from QES.Algebra.Model.Interacting.Spin.heisenberg_kitaev import HeisenbergKitaev
from QES.Algebra.Model.Interacting.Spin.transverse_ising import TransverseFieldIsing

# ------------------------------------------------------------------
#! Linear algebra
import QES.general_python.algebra.solvers.stochastic_rcnfg as SR
import QES.general_python.algebra.solvers as solvers
import QES.general_python.algebra.preconditioners as preconditioners

# ------------------------------------------------------------------

#! Spin operators
import QES.Algebra.Operator.operators_spin as op_spin
#! Fermionic operators
import QES.Algebra.Operator.operators_spinless_fermions as op_sferm

#! Backends
# The rest of the notebook uses `jax`/`jnp` unconditionally (the dtypes below, `jax.jit` in the
# modifier section) and runs with backend='jax'. A `None` fallback therefore only turned a
# missing JAX install into an opaque AttributeError on `None.float32`; fail fast instead.
if not JAX_AVAILABLE:
    raise ImportError("JAX is required by this notebook (backend='jax'), but it is not available.")
import jax
import jax.numpy as jnp

# ------------------------------------------------------------------
lattice_type        = 'honeycomb'
# lattice_type        = 'square'
lx, ly, lz          = 8, 4, 1
# lx, ly, lz          = 10, 1, 1
# ------------------------------------------------------------------
scheduler           = 'ExponentialDecay'
ode_solver_type     = 'Euler'
# ode_solver_type   = 'Heun'
# ode_solver_type   = 'AdaptiveHeun'
# ode_solver_type   = 'RK4'
lr                  = 0.1 # is also dt
# ------------------------------------------------------------------
network_type        = 'RBM'
# network_type        = 'CNN'
alpha               = 3
# ------------------------------------------------------------------
ham_type            = 'HeisenbergKitaev'
# ham_type            = 'TransverseFieldIsing'
# ham_type            = 'DummyHamiltonian'
ham_dtype           = jnp.float32
# ------------------------------------------------------------------

logger              = get_global_logger()
backend             = 'jax'
seed                = 0
# NOTE(review): complex128 needs JAX's x64 mode (`jax_enable_x64`); confirm QES enables it,
# otherwise JAX silently downcasts to complex64.
dtypex              = jnp.complex128
be_modules          = get_backend(backend, random=True, seed=seed, scipy=True)
# get_backend may return either a bare module or a (module, (rng, rng_key), scipy) tuple;
# normalise both forms to the same three names.
backend_np, (rng, rng_k), backend_sp = be_modules if isinstance(be_modules, tuple) else (be_modules, (None, None), None)

# ------------------------------------------------------------------


### Lattice interface

In [None]:
# Build the lattice geometry; the boundary-condition code 'mbc' is passed straight to QES.
lattice  = choose_lattice(typek=lattice_type, lx=lx, ly=ly, lz=lz, bc='mbc')
ns       = lattice.ns
# sites per (lx * ly * lz) cell, presumably the number of sites per unit cell
mult     = ns // (lx * ly * lz)
# a single configuration is a flat vector of ns entries
st_shape = (ns,)
lattice.print_forward(logger=logger)

# lattice
# MatrixPrinter.print_matrix(A)

### Operators to test later on

In [None]:
def _global_spin_op(factory, op_sites):
    """Create a global spin operator via `factory` acting on `op_sites` of the lattice."""
    return factory(
        lattice  = lattice,
        ns       = lattice.ns,
        type_act = op_spin.OperatorTypeActing.Global,
        sites    = op_sites,
    )

# operators on site 0
sig_z   = _global_spin_op(op_spin.sig_z, [0])
sig_x   = _global_spin_op(op_spin.sig_x, [0])
# '_c' variants: the same operators built on sites [0, 1]
sig_z_c = _global_spin_op(op_spin.sig_z, [0, 1])
sig_x_c = _global_spin_op(op_spin.sig_x, [0, 1])

# For a two-site system, print a hand-built reference matrix (0.5 * sigma_x ⊗ 1) for comparison.
if lattice.ns == 2:
    matrix_test = np.kron(op_spin._SIG_X, op_spin._SIG_0) * (0.5)
    print('Matrix test:', matrix_test)

### Hamiltonian - Hamiltonian operator $H$ and its expectation value $\langle H \rangle$.

In [None]:
# One builder per supported Hamiltonian; only the selected one is constructed.
_hamiltonian_builders = {
    'HeisenbergKitaev': lambda: HeisenbergKitaev(
        lattice       = lattice,
        hilbert_space = None,
        hx            = 0.0,
        hz            = 0.0,
        kx            = 1.0,
        ky            = 1.0,
        kz            = 1.0,
        j             = 0.0,
        dlt           = 1.0,
        dtype         = ham_dtype,
        use_forward   = False,
        backend       = backend,
    ),
    # Original notes: hx = -0.7 and j = -1.0 give the same model in a spin-1/2 convention.
    'TransverseFieldIsing': lambda: TransverseFieldIsing(
        lattice       = lattice,
        hilbert_space = None,
        hz            = 2.5,
        hx            = -1.4,
        j             = 4.0,
        dtype         = ham_dtype,
        backend       = backend,
    ),
    'DummyHamiltonian': lambda: DummyHamiltonian(
        lattice       = lattice,
        hilbert_space = None,
        dtype         = ham_dtype,
        backend       = backend,
    ),
}

if ham_type not in _hamiltonian_builders:
    raise ValueError(f"Unknown Hamiltonian type: {ham_type}")
hamil = _hamiltonian_builders[ham_type]()

logger.title('Hamiltonian', desired_size=150, fill='#', color='red')
hamil

In [None]:
# Exact-diagonalisation (ED) reference. Every ED result starts as None so that later cells can
# pass them through unconditionally (e.g. as `true_values` / `true_en`) even when ED is skipped.
# Previously `energy_0` was only defined for hilbert_size <= 2**12, which made later cells
# that reference it raise NameError.
eigv            = [None]
energy_0        = None
ed_sig_x_exp    = None
ed_sig_z_exp    = None
ed_sig_x_exp_c  = None
ed_sig_z_exp_c  = None

if hamil.hilbert_size <= 2**20:
    time0   = time.time()
    hamil.build(use_numpy=True)
    time1   = time.time()
    logger.info(f"Time to build Hamiltonian: {time1 - time0:.2f} seconds", color='green')

    # full diagonalisation for small spaces, Lanczos (k = 50) for larger ones
    if hamil.hilbert_size <= 2**12:
        hamil.diagonalize()
    else:
        hamil.diagonalize(method = 'lanczos', k = 50)
    time2   = time.time()
    logger.info(f"Time to diagonalize Hamiltonian: {time2 - time1:.2f} seconds", color='blue')
    eigv    = hamil.get_eigval()

    #! Test the operator expectation in the ground state
    if hamil.hilbert_size <= 2**12:
        gs              = hamil.get_eigvec(0)
        # test energy
        hamil_mat       = hamil.hamil
        energy_0        = overlap(gs, hamil_mat, backend = np)
        logger.info(f"Energy of the ground state: {energy_0:.4f}", color='green')

        sig_x_op_mat    = sig_x.matrix(dim = hamil.hilbert_size, use_numpy = True)
        sig_z_op_mat    = sig_z.matrix(dim = hamil.hilbert_size, use_numpy = True)
        sig_x_op_mat_c  = sig_x_c.matrix(dim = hamil.hilbert_size, use_numpy = True)
        sig_z_op_mat_c  = sig_z_c.matrix(dim = hamil.hilbert_size, use_numpy = True)

        ed_sig_x_exp    = overlap(gs, sig_x_op_mat, backend = np)
        ed_sig_z_exp    = overlap(gs, sig_z_op_mat, backend = np)
        ed_sig_x_exp_c  = overlap(gs, sig_x_op_mat_c, backend = np)
        ed_sig_z_exp_c  = overlap(gs, sig_z_op_mat_c, backend = np)
        logger.info(f"sig_x expectation value: {ed_sig_x_exp:.4f}", color='green')
        logger.info(f"sig_z expectation value: {ed_sig_z_exp:.4f}", color='green')
        logger.info(f"sig_x expectation value (c): {ed_sig_x_exp_c:.4f}", color='green')
        logger.info(f"sig_z expectation value (c): {ed_sig_z_exp_c:.4f}", color='green')
    else:
        logger.info(f"(TODO) Cannot compute expectation values for Hamiltonian: {hamil.hilbert_size} > 2^12", color='red')
else:
    logger.info(f"Cannot diagonalize Hamiltonian: {hamil.hilbert_size} > 2^20", color='red')
    

#### Plot the spectrum (only when ED produced more than one eigenvalue)

In [None]:
# Spectrum per site from ED; skipped when ED did not run (eigv == [None]) or returned one value.
if eigv is not None and len(eigv) > 1:
    fig, ax     = Plotter.get_subplots(nrows=1, ncols=1, figsize=(4, 3), dpi=100)
    x           = np.arange(0, len(eigv))
    y           = eigv
    e0_per_site = eigv[0] / lattice.ns
    panel       = ax[0]
    panel.plot(x, y / lattice.ns, 'o', markersize=2)
    panel.set_xlabel(r'$\mathcal{e}$')
    panel.set_ylabel(r'$E/N_s$')
    # dashed reference line at the lowest eigenvalue per site
    panel.axhline(e0_per_site, color='r', linestyle='--', label=f'Ground state {e0_per_site:.3e}')
    panel.legend()

### Network - variational ansatz body $\psi _\theta (s)$ and its gradient $\nabla \psi _\theta (s)$.

In [None]:
# Keyword arguments shared by every ansatz: input layout, dtypes and RNG seed.
_net_common = dict(input_shape=st_shape, dtype=dtypex, param_dtype=dtypex, seed=seed)

if network_type == 'RBM':
    # hidden-unit density: n_hidden = alpha * ns
    net = RBM(n_hidden=int(alpha * ns), visible_bias=True, bias=True, **_net_common)
elif network_type == 'CNN':
    # alpha doubles as the number of convolutional layers for the CNN
    n_layers = alpha
    net = CNN(
        reshape_dims     = (lx, ly * mult),
        features         = (8,) * n_layers,
        strides          = [(1, 1)] * n_layers,
        kernel_sizes     = [(2, 2)] * n_layers,
        activations      = [elu_jnp] * n_layers,
        final_activation = elu_jnp,
        output_shape     = (1,),
        **_net_common,
    )
else:
    raise ValueError(f"Unknown network type: {network_type}")
net

### Sampler - sampling from the distribution $p_\theta (s)$.

In [None]:
n_chains        = 5
n_samples       = 200
n_therm_steps   = 25

# Monte Carlo sampler driven by `net`; sweeps are capped at 28 steps for large lattices.
sampler_kwargs  = dict(
    net         = net,
    shape       = st_shape,
    rng         = rng,
    rng_k       = rng_k,
    numchains   = n_chains,
    numsamples  = n_samples,
    sweep_steps = min(ns, 28),
    backend     = backend_np,
    therm_steps = n_therm_steps,
    mu          = 2.0,
    seed        = seed,
    dtype       = dtypex,
    statetype   = np.float64,
    makediffer  = True,
)
sampler         = Sampler.MCSampler(**sampler_kwargs)

# Toggle for the optional benchmark/inspection cells that follow.
do_tests        = False
sampler_fun     = sampler.get_sampler_jax()
sampler

#### Optional tests

In [None]:
%%timeit -r 5 -n 5
# Benchmark one full `sampler.sample()` call (5 repeats x 5 loops). With do_tests=False the
# body is a no-op and %%timeit only measures the `if` check.
if do_tests:
    sampler.sample()

In [None]:
%%timeit -r 5 -n 5
# Benchmark the function returned by `sampler.get_sampler_jax()`, called directly with the
# sampler's current states, `rng_k` and the network parameters. Skipped unless do_tests.
# NOTE(review): the same `sampler.rng_k` is passed on every repetition; confirm whether
# sampler_fun is expected to advance it.
if do_tests:
    sampler_fun(sampler.states, sampler.rng_k, net.get_params())

In [None]:
if do_tests:
    samples = sampler.sample()
    # A bare expression inside an `if` block is never echoed by Jupyter (only a cell's final
    # top-level expression is displayed), so log the shape explicitly.
    logger.info(f"samples[1][0] shape: {samples[1][0].shape}")

In [None]:
import time

# Benchmark helper. Original note: this appears to be roughly 5-10x faster than vmc_jax.
def multiple_samples(n):
    """
    Call ``sampler.sample()`` ``n`` times and print the elapsed time of each call.

    ``block_until_ready()`` is called on ``samples[0][0]`` before stopping the timer, so
    JAX's asynchronous dispatch does not end the measurement before the work has finished.
    ``time.perf_counter`` is used because it is the monotonic clock intended for measuring
    short intervals (``time.time`` can jump when the wall clock is adjusted).

    Parameters
    ----------
    n : int
        Number of sampling calls to time.

    Returns
    -------
    The result of the last ``sampler.sample()`` call, or an empty list when ``n == 0``.
    """
    samples = []
    for i in range(n):
        start   = time.perf_counter()
        samples = sampler.sample()
        samples[0][0].block_until_ready()
        elapsed = time.perf_counter() - start
        print(f"Time taken for iteration {i}: {elapsed:.4f} seconds")
    return samples

if do_tests:
    n_benchmark_calls = 50      # number of timed sampler.sample() calls
    samples = multiple_samples(n_benchmark_calls)

In [None]:
if do_tests:
    # Unpack only the part inspected here, without rebinding the module-level `x`/`y`
    # that the spectrum-plot cell defines (avoids hidden-state surprises on re-runs).
    # Presumably the middle element holds (states, ansatz values) — confirm against the sampler.
    _, sample_pair, _ = samples
    y_st, y_an        = sample_pair
    logger.info(f"y_st: {y_st}, shape: {y_st.shape}")

### Stepper - TDVP stepper for the time evolution of the state $\psi _\theta (s)$.

In [None]:
# Integrator for the parameter evolution; its step size is the learning rate `lr`
# (see the "is also dt" note in the configuration cell).
ode_solver = choose_ode(
    ode_type      = ode_solver_type,
    backend       = backend_np,
    dt            = lr,
    rhs_prefactor = -1.0,
)
ode_solver

### NQS - neural network quantum state $\psi_\theta (s)$ and its gradient $\nabla \psi_\theta (s)$.

In [None]:
n_epo           = 5000
n_sweep_steps   = ns        # NOTE(review): not used in the visible notebook; the sampler uses min(ns, 28)
n_batch         = 128

# Other
reg             = 5e-2      # SR diagonal shift; also the value held by the regularisation scheduler
maxiter         = 1000
tolerance       = 1e-8
use_min_sr      = False

# Solver
solver_id       = solvers.SolverType.SCIPY_CG
precond_id      = preconditioners.PreconditionersTypeSym.JACOBI
precond         = preconditioners.choose_precond(precond_id=precond_id, backend=backend_np)
# precond_id      = None

tdvp = TDVP.TDVP(
    use_sr          = True,
    # Previously hard-coded to False, silently ignoring the `use_min_sr` flag above.
    use_minsr       = use_min_sr,
    rhs_prefact     = 1.0,
    sr_lin_solver   = solver_id,
    sr_precond      = precond,
    sr_pinv_tol     = tolerance,
    sr_pinv_cutoff  = 1e-8,
    sr_snr_tol      = tolerance,
    sr_diag_shift   = reg,
    sr_lin_solver_t = solvers.SolverForm.GRAM,
    sr_maxiter      = maxiter,
    backend         = backend_np
)
tdvp

In [None]:
# Variational state: network + sampler + Hamiltonian, without lower-state penalties.
nqs = NQS.NQS(
    net          = net,
    sampler      = sampler,
    hamiltonian  = hamil,
    lower_betas  = None,
    lower_states = None,
    seed         = seed,
    beta         = 1.0,
    mu           = sampler.get_mu(),
    shape        = st_shape,
    backend      = backend_np,
    batch_size   = n_batch,
    dtype        = dtypex,
)
nqs.reset()

# Learning-rate schedule, early stopping and regularisation schedule for the trainer.
lr_schedule   = ExponentialDecayScheduler(lr, n_epo, lr_decay=3e-4, logger=logger, lr_clamp=3e-2)
early_stop    = EarlyStopping(patience = 500, min_delta=1e-4, logger=logger)
reg_schedule  = ConstantScheduler(reg, max_epochs=n_epo, lr_clamp=1e-2, logger=logger)

nqs_train = NQST.NQSTrainer(
    nqs           = nqs,
    ode_solver    = ode_solver,
    tdvp          = tdvp,
    n_batch       = n_batch,
    lr_scheduler  = lr_schedule,
    early_stopper = early_stop,
    reg_scheduler = reg_schedule,
    logger        = logger,
)
nqs

In [None]:
# Show the trainer object (its rich repr) as this cell's output.
nqs_train

### Test the training of the NQS with a simple Hamiltonian.

In [None]:
# The learning-rate scheduler is enabled only for the Euler integrator, presumably because
# there the step size dt is the learning rate itself. The regularisation scheduler is off.
use_lr_sched = (ode_solver_type == 'Euler')
history, history_std, timings = nqs_train.train(
    n_epochs          = n_epo,
    reset             = False,
    use_lr_scheduler  = use_lr_sched,
    use_reg_scheduler = False,
)

Training:  10%|▉         | 493/5000 [1:55:08<17:44:22, 14.17s/it, E/N=-2.6935e-01+6.7591e-05j ± 1.0928e-02, lr=8.6e-02, sig=5.0e-02, t_sample=2.60e+00s, t_step=1.14e+01s, t_update=5.36e-04s, t_gradient=1.93e-02s, t_prepare=5.86e-01s, t_solve=1.08e+01s, t_total=2.54e+01s]

In [None]:
# Reference ground-state energy for the training report; falls back to a hard-coded value
# when ED was skipped (eigv[0] is None).
# NOTE(review): the origin of this constant is not recorded here — confirm it matches the
# current lattice_type / lx / ly / ham_type before trusting the plotted error.
reference_energy_fallback = -18.583590144097595
energy_plot = reference_energy_fallback if eigv[0] is None else eigv[0]

report_plot_kw = {
    'ylim_0'     : (-0.5, 3e-2),
    'ylim_1'     : (5e-2, 1e0),
    'inset_axes' : (0.18, 0.02, 0.4, 0.4),
    'annotate_x' : 0.93,
    'annotate_y' : 0.05,
}
fig, ax = nqs_train.report_gs(
    eigv    = energy_plot,
    last_n  = 0.05,
    savedir = f'./data/nqs_train/{ham_type}/',
    plot_kw = report_plot_kw,
)

In [None]:
# Sample the trained state and compare observables (and the energy) with the ED references;
# the reference entries are None whenever ED was skipped.
eval_operators  = [sig_z, sig_x, sig_z_c, sig_x_c]
ed_references   = [ed_sig_z_exp, ed_sig_x_exp, ed_sig_z_exp_c, ed_sig_x_exp_c]
obs_plot_kwargs = {
    'xlim_0'    : (-0.5, 0.1),
    'xlim_1'    : (-0.5, 0.1),
    'inset_axes': (0.6, 0.2, 0.35, 0.4),
}
results, energy, timings = nqs.eval_observables(
    operators   = eval_operators,
    true_values = ed_references,
    n_chains    = n_chains,
    n_samples   = n_samples,
    batch_size  = 100,
    logger      = logger,
    plot        = True,
    get_energy  = True,
    bins        = 200,
    true_en     = eigv[0],
    plot_kwargs = obs_plot_kwargs,
)

### Test the state modifiers

In [None]:
from QES.Algebra.Operator.operator import initial_states, create_operator

def _make_global_inv_operator(op_func_jnp, op_name, op_sites):
    """Wrap a `*_inv_jnp` spin kernel as a global QES operator acting on `op_sites`.

    Only the JAX implementation is supplied (integer/NumPy variants are None); every operator
    built here uses extra_args=(True, 1.0) and modifies=False.
    """
    return create_operator(
        type_act    = op_spin.OperatorTypeActing.Global,
        op_func_int = None,
        op_func_np  = None,
        op_func_jnp = op_func_jnp,
        ns          = lattice.ns,
        sites       = op_sites,
        name        = op_name,
        extra_args  = (True, 1.0),
        modifies    = False,
    )

sites_all       = []        # NOTE(review): not used in the visible notebook
# locals: operators on site 0
sites           = [0]
operators_inv   = [
    _make_global_inv_operator(op_spin.sigma_z_inv_jnp, 'sig_z_0', sites),
    _make_global_inv_operator(op_spin.sigma_x_inv_jnp, 'sig_x_0', sites),
]

# correlations
# NOTE(review): these modifiers act on sites [0, 3], while the ED reference operators
# sig_z_c / sig_x_c act on sites [0, 1] — confirm which pair is intended before comparing.
sites_corr = jnp.array([0, 3], dtype=jnp.int32)

def sig_z_c_modifier(x):
    """Apply `sigma_z_inv_jnp` to `x` on the correlation sites `sites_corr`."""
    return op_spin.sigma_z_inv_jnp(x, sites=sites_corr)

def sig_x_c_modifier(x):
    """Apply `sigma_x_inv_jnp` to `x` on the correlation sites `sites_corr`."""
    return op_spin.sigma_x_inv_jnp(x, sites=sites_corr)

# k-space: momentum k applied over every lattice site
k       = 1.0
sites_k = jnp.arange(0, lattice.ns, dtype=jnp.int32)

def sig_k_modifier(x):
    """Apply `sigma_k_inv_jnp` to `x` with momentum `k` over all sites."""
    return op_spin.sigma_k_inv_jnp(x, k=k, sites=sites_k)


names = ['sig_z', 'sig_x', 'sig_z_c', 'sig_x_c', 'sig_k']

# Site-0 operators expose a `.jax` callable (presumably already jitted); the closures above
# are jitted explicitly.
sig_z_modifier_jit      = operators_inv[0].jax
sig_x_modifier_jit      = operators_inv[1].jax
sig_z_c_modifier_jit, sig_x_c_modifier_jit, sig_k_modifier_jit = map(
    jax.jit, (sig_z_c_modifier, sig_x_c_modifier, sig_k_modifier)
)

In [None]:
# test modifiers: build a reference configuration from the third value returned by
# `initial_states`. The affine map sends 0 -> -1 and 1 -> +1 (if the input is 0/1-valued).
_, _, st_raw = initial_states(lattice.ns, display=True)
st_jax       = 2 * (st_raw - 0.5)

In [None]:
# Apply each site-0 modifier to the reference state and log its (states, values) result.
for modifier in (sig_z_modifier_jit, sig_x_modifier_jit):
    st_out, st_val = modifier(st_jax)
    logger.title('Modifier', desired_size=150, fill='#', color='red')
    logger.info(f"Modified state: {st_out[0]}")
    logger.info(f"Original state: {st_jax}")
    logger.info(f"Modified state value: {st_val}")
    logger.info(f"Modified state shape: {st_out.shape}")

#### $\sigma_z$ - Pauli Z operator at site $i=0$.

In [None]:
# Baseline: evaluate the unmodified ansatz on a batch of two copies of the reference state.
nqs.unset_modifier()
two_states      = jnp.array([st_jax] * 2, dtype=jnp.float32)
nqs_ansatz_old  = nqs.ansatz
params_before   = nqs.get_params()
nqs_eval_before = nqs_ansatz_old(params_before, two_states)
nqs_eval_before

In [None]:
# Activate the site-0 sigma_z modifier; its name comes from the corresponding operator object.
sig_z_inv_op = operators_inv[0]
nqs.set_modifier(modifier=sig_z_modifier_jit, name=sig_z_inv_op.name)

Evaluate the ansatz again, now with the $\sigma_z$ modifier applied, and compare with the baseline above.

In [None]:
# Re-read the ansatz after set_modifier and evaluate it on the same two-state batch.
nqs_ansatz_new  = nqs.ansatz
params_after    = nqs.get_params()
nqs_eval_new    = nqs_ansatz_new(params_after, two_states)
nqs_eval_new

In [None]:
# Observable check with the sigma_z modifier active. The reference energy is eigv[0], as in
# the earlier observable cell. `energy_0` is only defined when the ground state was computed
# explicitly (Hilbert space <= 2**12), so referencing it raised NameError for larger systems.
results, energy, timings = nqs.eval_observables(
    operators      = [sig_z],
    true_values    = [ed_sig_z_exp],
    n_chains       = n_chains,
    n_samples      = n_samples,
    batch_size     = 100,
    logger         = logger,
    plot           = True,
    get_energy     = True,
    bins           = 200,
    true_en        = eigv[0],
)