In [28]:
import sys
import os
# os.environ['JAX_PLATFORM_NAME'] = 'cpu'

import numpy as np
import time

# Make the project root (parent of the kernel's working directory) importable,
# so the `QES` imports below resolve. Jupyter kernels do not define `__file__`,
# so the current working directory stands in for the notebook's location.
# Note: `os.path.dirname(os.curdir)` would be '' here, not a directory.
script_dir = os.getcwd()
parent_dir = os.path.abspath(os.path.join(script_dir, '..'))
if parent_dir not in sys.path:
    # Appended, not prepended: a QES copy found earlier on sys.path wins.
    sys.path.append(parent_dir)

# ------------------------------------------------------------------
#! General
from QES.general_python.common.timer import Timer, timeit
from QES.general_python.common.plot import Plotter, MatrixPrinter
from QES.general_python.common.binary import JAX_AVAILABLE, get_backend, get_global_logger

# ------------------------------------------------------------------
#! Sampler
import QES.Solver.MonteCarlo.sampler as Sampler

# ------------------------------------------------------------------
#! Networks
from QES.general_python.ml.net_impl.networks.net_rbm import RBM
from QES.general_python.ml.net_impl.networks.net_cnn import CNN
from QES.general_python.ml.net_impl.activation_functions import relu_jnp, tanh_jnp, sigmoid_jnp, leaky_relu_jnp, elu_jnp, poly6_jnp, softplus_jnp

#! Backends
if JAX_AVAILABLE:
    import jax
    import jax.numpy as jnp
else:
    jax = None
    jnp = np

# ------------------------------------------------------------------
network_type    = 'rbm'  # 'cnn' or 'rbm'
lx              = 5      # first CNN reshape dimension
ly              = 2      # CNN reshapes the input to (lx, ly * mult)
ns              = 10     # length of the flat network input, st_shape = (ns,)
if ns <= 0 or ns % (lx * ly) != 0:
    # The CNN reshapes the flat (ns,) input to (lx, ly * mult); that is only
    # consistent when ns is a positive multiple of lx * ly. Otherwise `mult`
    # floors (to 0 when ns < lx * ly) and the failure surfaces much later.
    raise ValueError(f"ns={ns} must be a positive multiple of lx*ly={lx * ly}")
mult            = ns // (lx * ly)
st_shape        = (ns, )
alpha           = 2      # RBM: n_hidden = alpha * ns; CNN: layer count (per-layer settings repeated alpha times)
dtypex          = jnp.complex64
seed            = 1234

logger          = get_global_logger()
backend         = 'jax'
be_modules      = get_backend(backend, random=True, seed=seed, scipy=True)
# get_backend may return a bare module or a (module, (rng, rng_k), scipy) tuple;
# normalize to the tuple form before unpacking.
backend_np, (rng, rng_k), backend_sp = be_modules if isinstance(be_modules, tuple) else (be_modules, (None, None), None)

In [29]:
# Build the variational ansatz selected in the config cell. Both architectures
# share the input shape, complex dtype and seed; anything else is rejected.
if network_type not in ('rbm', 'cnn'):
    raise ValueError(f"Unknown network type: {network_type}")

if network_type == 'cnn':
    # `alpha` sets how many convolutional stages are stacked.
    net = CNN(
        input_shape         = st_shape,
        reshape_dims        = (lx, ly * mult),
        features            = (8,) * alpha,
        kernel_sizes        = [(2, 2)] * alpha,
        strides             = [(1, 1)] * alpha,
        activations         = [elu_jnp] * alpha,
        final_activation    = elu_jnp,
        output_shape        = (1,),
        dtype               = dtypex,
        param_dtype         = dtypex,
        seed                = seed,
    )
else:
    # `alpha` is the hidden-unit density: n_hidden = alpha * ns.
    net = RBM(
        input_shape         = st_shape,
        n_hidden            = int(alpha * ns),
        visible_bias        = True,
        bias                = True,
        dtype               = dtypex,
        param_dtype         = dtypex,
        seed                = seed,
    )
net

18_05_2025_20-37_24 [INFO] 	->[34m[GeneralNet] Holomorphic check result (||∇Re[f] - i*∇Im[f]|| / ||∇Re[f]|| ≈ 0): True[0m
18_05_2025_20-37_24 [INFO] 	->[34m[GeneralNet] FlaxInterface initialized: dtype=complex64, is_complex=True, nparams=230, is_holomorphic=True[0m


ComplexRBM(shape=(10,), hidden=20, bias=on, visible_bias=on, dtype=complex64, params=230, analytic_grad=False, initialized)

In [30]:
# Set True to make the timing cells below actually time the sampler.
do_tests        = False

# Markov-chain layout.
n_chains        = 5
n_samples       = 200
n_therm_steps   = 25

sampler = Sampler.MCSampler(
    # model and state representation
    net             = net,
    shape           = st_shape,
    backend         = backend_np,
    dtype           = dtypex,
    statetype       = np.float64,
    # randomness
    rng             = rng,
    rng_k           = rng_k,
    seed            = seed,
    # chain schedule; sweep steps = ns, capped at 28
    numchains       = n_chains,
    numsamples      = n_samples,
    therm_steps     = n_therm_steps,
    sweep_steps     = min(ns, 28),
    mu              = 2.0,
    makediffer      = True,
)
sampler_fun = sampler.get_sampler_jax()
sampler

MCSampler(shape=(10,), mu=2.0, beta=1.0, therm_steps=25, sweep_steps=10, numsamples=200, numchains=5, backend=jax)

### Benchmark: repeated calls to the `sample()` method

In [31]:
%%timeit -r 5 -n 5
if do_tests:
    sampler.sample()

58.3 ns ± 50.3 ns per loop (mean ± std. dev. of 5 runs, 5 loops each)


### Benchmark: repeated calls to the sampler function returned by `get_sampler_jax()`

In [32]:
%%timeit -r 5 -n 5
if do_tests:
    sampler_fun(sampler.states, sampler.rng_k, net.get_params())

46.7 ns ± 41 ns per loop (mean ± std. dev. of 5 runs, 5 loops each)


In [49]:
def multiple_samples(n, warmup=0, sample_fn=None):
    """
    Call the sampler ``n`` times, timing each call, and log summary statistics.

    Parameters
    ----------
    n : int
        Number of timed calls. ``n <= 0`` logs a message and returns empty
        results instead of failing in the statistics step.
    warmup : int, optional
        Untimed calls made before the timed ones, e.g. to absorb one-off costs
        such as compilation (the first call is visibly slower in the logged
        output). Default 0 keeps every call in the statistics.
    sample_fn : callable, optional
        Zero-argument callable to time. Defaults to ``sampler.sample``.

    Returns
    -------
    samples : list
        Return value of each timed call, in call order.
    times : numpy.ndarray
        Duration of each timed call in seconds, as reported by ``timeit``.
    """
    if sample_fn is None:
        sample_fn = sampler.sample

    if n <= 0:
        # np.max / np.min of an empty array raise ValueError; report and stop.
        logger.info(f"Nothing to sample (n={n})", color='red', lvl=0)
        return [], np.array([])

    for _ in range(max(0, warmup)):
        sample_fn()

    # Log every iteration for short runs, otherwise about ten progress lines.
    log_every = 1 if n < 50 else n // 10

    samples = []
    times   = []
    logger.info(f"Sampling {n} times", color='green')
    for i in range(n):
        s, t = timeit(sample_fn)
        samples.append(s)
        times.append(t)
        if i % log_every == 0:
            logger.info(f"Iteration {i}: {t:.4e} seconds", color='blue', lvl=1)

    # statistics
    times = np.asarray(times)
    logger.info(f"Mean time: {np.mean(times):.4e} seconds", color='white', lvl=0)
    logger.info(f"Max time: {np.max(times):.4e} seconds", color='red', lvl=0)
    logger.info(f"Min time: {np.min(times):.4e} seconds", color='green', lvl=0)
    logger.info(f"Std time: {np.std(times):.4e} seconds", color='yellow', lvl=0)

    return samples, times


samples, times = multiple_samples(100)

18_05_2025_20-40_55 [INFO] [32mSampling 100 times[0m
18_05_2025_20-40_55 [INFO] 	->[34mIteration 0: 5.2204e-02 seconds[0m
18_05_2025_20-40_56 [INFO] 	->[34mIteration 10: 3.7067e-02 seconds[0m
18_05_2025_20-40_56 [INFO] 	->[34mIteration 20: 3.7114e-02 seconds[0m
18_05_2025_20-40_57 [INFO] 	->[34mIteration 30: 3.6813e-02 seconds[0m
18_05_2025_20-40_57 [INFO] 	->[34mIteration 40: 3.6877e-02 seconds[0m
18_05_2025_20-40_57 [INFO] 	->[34mIteration 50: 3.6599e-02 seconds[0m
18_05_2025_20-40_58 [INFO] 	->[34mIteration 60: 3.7736e-02 seconds[0m
18_05_2025_20-40_58 [INFO] 	->[34mIteration 70: 3.7103e-02 seconds[0m
18_05_2025_20-40_58 [INFO] 	->[34mIteration 80: 3.6729e-02 seconds[0m
18_05_2025_20-40_59 [INFO] 	->[34mIteration 90: 3.6184e-02 seconds[0m
18_05_2025_20-40_59 [INFO] [0mMean time: 3.7490e-02 seconds[0m
18_05_2025_20-40_59 [INFO] [31mMax time: 5.2204e-02 seconds[0m
18_05_2025_20-40_59 [INFO] [32mMin time: 3.6057e-02 seconds[0m
18_05_2025_20-40_59 [INFO] [33