In [1]:
import sys
import os
import numbers
import numpy as np
from IPython.display import Math, display
from typing import Union, List, Sequence, Optional

# Add the parent directory (project root) to sys.path
script_dir              = os.path.dirname(os.curdir)
parent_dir              = os.path.abspath(os.path.join(script_dir, '..'))
if parent_dir not in sys.path:
    sys.path.append(parent_dir)

from QES.general_python.lattices.honeycomb import HoneycombLattice
from QES.general_python.lattices.square import SquareLattice
from QES.general_python.lattices.lattice import LatticeBC, Lattice
import QES.general_python.common.binary as _bin_mod
from QES.general_python.common.timer import Timer
from QES.general_python.common.display import (
    display_state,
    display_operator_action,
)

from QES.Algebra.Operator.operator import OperatorTypeActing, Operator, JAX_AVAILABLE
#! Spin operators
import QES.Algebra.Operator.operators_spin as op_spin
#! Fermionic operators
import QES.Algebra.Operator.operators_spinless_fermions as op_sferm


#! Backends
if JAX_AVAILABLE:
    import jax
    import jax.numpy as jnp
else:
    jax = None
    jnp = None
    
# Backend tag and the lattice shared by every cell below:
# a 1D square lattice with lx = 5 and periodic boundary conditions.
backend = 'np'
lat = SquareLattice(dim=1, lx=5, ly=1, lz=1, bc=LatticeBC.PBC)

#! Create functions for testing

def test_operator_on_state(op           : Union[Operator, Sequence[Operator]],
                        lat             : Lattice,
                        state           : Union[int, np.ndarray, "jnp.ndarray"],
                        *,
                        ns              : Optional[int] = None,
                        op_acting       : "OperatorTypeActing" = OperatorTypeActing.Local,
                        op_label        = None,
                        to_bin          = None,
                        just_time       = False
                        ) -> None:
    r"""
    Pretty-print (or just time) the action of *one or several* lattice
    operators on a basis state or an array state.

    Parameters
    ----------
    op : Operator or sequence[Operator]
        The operator(s) Ô acting on 0, 1, or 2 sites.
    lat : Lattice
        Lattice forwarded to the per-call helper; ``lat.ns`` is the default
        number of sites.
    state : int | np.ndarray | jax.numpy.ndarray
        *Basis state* (integer encoding) or array state :math:`|\psi\rangle`.
    ns : int, optional
        Number of sites iterated over.  If *None*, uses ``lat.ns``.
    op_acting : OperatorTypeActing, default = ``Local``
        Local       -> Ôᵢ applied for every site i,
        Correlation -> Ôᵢⱼ applied for every ordered pair (i, j),
        Global      -> Ô applied once.
    op_label : str | sequence[str], optional
        LaTeX label(s).  If *None*, uses each operator's ``name``.
    to_bin : Callable, optional
        Integer → binary-string formatter, passed through unchanged to the
        display helpers (``None`` is forwarded as-is).
    just_time : bool, default = False
        If True, the operators are applied and timed but nothing is displayed.
        Useful for benchmarking.

    Raises
    ------
    ValueError
        If *op_acting* is not Local, Correlation or Global, or if a sequence
        of labels does not match the number of operators.

    Notes
    -----
    The whole operator loop (including display, when enabled) runs inside a
    verbose ``Timer`` named "Operator action".
    """
    # ------------------------------------------------------------------
    ops      = (op,) if not isinstance(op, Sequence) else tuple(op)
    is_int   = isinstance(state, (numbers.Integral, int, np.integer))
    ns       = lat.ns if ns is None else ns
    # Labels are only needed for display; skip building them when timing.
    labels   = _prepare_labels(ops, op_label) if not just_time else [''] * len(ops)

    def display_state_in(*args, **kwargs):
        if not just_time:
            display_state(*args, **kwargs)

    def display_in(*args, **kwargs):
        if not just_time:
            display(*args, **kwargs)

    # ------------------------------------------------------------------
    initial_label = (f"Initial integer state (Ns={ns})" if is_int
                     else f"Initial state (Ns={ns})")
    display_state_in(state,
                ns,
                label  = initial_label,
                to_bin = to_bin)

    # ------------------------------------------------------------------
    with Timer(verbose=True, precision='us', name="Operator action"):
        for cur_op, lab in zip(ops, labels):
            display_in(Math(fr"\text{{Operator: }} {lab}, \text{{typeacting}}: {op_acting}"))

            if op_acting == OperatorTypeActing.Local:
                for i in range(ns):
                    display_in(Math(fr"\quad \text{{Site index: }} {i}"))
                    _dispatch(cur_op, state, lat, is_int, to_bin, lab, i=i, just_time=just_time)

            elif op_acting == OperatorTypeActing.Correlation:
                for i in range(ns):
                    for j in range(ns):
                        display_in(Math(fr"\text{{Site indices: }} {i}, {j}"))
                        _dispatch(cur_op, state, lat, is_int, to_bin, lab, i=i, j=j, just_time=just_time)

            elif op_acting == OperatorTypeActing.Global:
                _dispatch(cur_op, state, lat, is_int, to_bin, lab, just_time=just_time)

            else:
                raise ValueError(f"Operator acting type {op_acting!r} not supported.")

#------------------------------------------------------------------------
#! Helper functions
#------------------------------------------------------------------------

def _prepare_labels(ops, op_label):
    """Return a tuple of LaTeX labels matching *ops*."""
    if op_label is None:
        return tuple(getattr(op, "name", f"op_{k}") for k, op in enumerate(ops))
    if isinstance(op_label, Sequence):
        print(f"op_label: {op_label}", len(op_label))
        if len(op_label) != len(ops):
            raise ValueError("Length of op_label must match number of operators")
        return tuple(op_label)
    return tuple(op_label for _ in ops)

def _dispatch(op, state, lat, is_int, to_bin, lab, *, i=None, j=None, just_time=False):
    """
    Call *op* with the right signature and send its first output to
    `display_operator_action`.
    """
    
    if is_int:
        state_act = state
    else:
        state_act = state.copy()
    
    # call signature depends on OperatorTypeActing
    if i is None:                        # Global operator
        st_out, coeff = op(state_act)
    elif j is None:                      # Local operator
        if is_int:
            st_out, coeff = op(state_act, i)
        elif isinstance(state_act, (np.ndarray)):
            sites         = np.array([i])
            st_out, coeff = op(state_act, *sites)
        elif isinstance(state_act, (jnp.ndarray)):
            sites         = jnp.array([i])
            st_out, coeff = op(state_act, sites[0])
    else:                                # Correlation operator
        if is_int:
            st_out, coeff = op(state_act, i, j)
        elif isinstance(state_act, (np.ndarray)):
            sites         = np.array([i, j])
            st_out, coeff = op(state_act, *sites)
        elif isinstance(state_act, (jnp.ndarray)):
            sites         = jnp.array([i, j])
            st_out, coeff = op(state_act, sites[0], sites[1])
            
    #! choose what to show depending on state representation
    new_state = st_out
    new_coeff = coeff
    if not just_time:
        display_operator_action(f"\\quad \\quad {lab}",
                                i if j is None else (i, j),
                                state,
                                lat.ns,
                                new_state,
                                new_coeff,
                                to_bin = to_bin)
    return new_state, new_coeff

# ------------------------------------------------------------------

def initial_states(ns: int, display: bool = False) -> tuple:
    """
    Draw a random state of *ns* sites in three representations.

    Parameters
    ----------
    ns : int
        Number of sites.
    display : bool, default False
        If True, render every available representation with ``display_state``.

    Returns
    -------
    tuple
        ``(int_state, np_state, jnp_state)``:

        * ``int_state`` - random ``np.int64`` drawn from ``[0, 2**(ns % 64))``;
        * ``np_state``  - float32 NumPy array of length ``ns`` with 0/1 entries;
        * ``jnp_state`` - ``jnp.array(np_state)``, or ``None`` if JAX is not
          available.

    Notes
    -----
    For ``ns < 64`` the array states are obtained from ``int_state`` via
    ``int2base_np``.  For ``ns >= 64`` a single 64-bit integer cannot hold all
    sites: the arrays are then independent random 0/1 vectors and
    ``int_state`` is unrelated to them.
    """
    # int64 (not int32): 2**ns would overflow int32 for ns >= 31.
    int_state = np.random.randint(0, 2**(ns % 64), dtype=np.int64)
    if ns >= 64:
        np_state  = np.random.randint(0, 2, size=ns).astype(np.float32)
    else:
        np_state  = _bin_mod.int2base_np(int_state, size = ns, value_true=1, value_false=0).astype(np.float32)
    # jnp is set to None at import time when JAX is unavailable.
    jnp_state = jnp.array(np_state) if jnp is not None else None

    if display:
        display_state(int_state, ns,    label = "Integer state")
        display_state(np_state, ns,     label = "NumPy state")
        if jnp_state is not None:
            display_state(jnp_state, ns, label = "JAX state")

    return int_state, np_state, jnp_state

# ------------------------------------------------------------------

07_05_2025_19-31_54 [INFO] Log file created: ./log/QES_07_05_2025_19-31_54.log
07_05_2025_19-31_54 [INFO] Log level set to: info
07_05_2025_19-31_54 [INFO] ############Global logger initialized.############
07_05_2025_19-31_54 [INFO] JAX backend available and successfully imported
07_05_2025_19-31_54 [INFO] 	JAX 64-bit precision enabled.
07_05_2025_19-31_54 [INFO] Setting JAX as the active backend.
07_05_2025_19-31_54 [INFO] **************************************************
07_05_2025_19-31_54 [INFO] Backend Configuration:
07_05_2025_19-31_54 [INFO] 		NumPy Version: 2.1.3
07_05_2025_19-31_54 [INFO] 		SciPy Version: 1.15.2
07_05_2025_19-31_54 [INFO] 		JAX Version: 0.5.3
07_05_2025_19-31_54 [INFO] 		Active Backend: jax
07_05_2025_19-31_54 [INFO] 			JAX Available: True
07_05_2025_19-31_54 [INFO] 			Default Seed: 42
07_05_2025_19-31_54 [INFO] 		JAX Backend Details:
07_05_2025_19-31_54 [INFO] 				Main Module: jax.numpy
07_05_2025_19-31_54 [INFO] 				Random Module: jax.random (+ PRNGKey)
07

## Spin operators $\sigma^x$, $\sigma^y$, $\sigma^z$

The Pauli matrices are defined as:
$$
\sigma^x = \begin{pmatrix}
0 & 1 \\
1 & 0
\end{pmatrix}, \quad
\sigma^y = \begin{pmatrix}
0 & -i \\
i & 0
\end{pmatrix}, \quad
\sigma^z = \begin{pmatrix}
1 & 0 \\
0 & -1
\end{pmatrix}
$$

### a) Local operators

In [2]:
# Build every spin operator in its Local form (one site index per call).
type_act = op_spin.OperatorTypeActing.Local

(sig_x_l, sig_y_l, sig_z_l,
 sig_p_l, sig_m_l,
 sig_pm_l, sig_mp_l) = (
    make(lattice=lat, type_act=type_act)
    for make in (op_spin.sig_x, op_spin.sig_y, op_spin.sig_z,
                 op_spin.sig_p, op_spin.sig_m,
                 op_spin.sig_pm, op_spin.sig_mp)
)

# Operators actually exercised below; sig_p_l, sig_m_l, sig_pm_l and
# sig_mp_l are built but currently left out.
operators_l = [sig_x_l, sig_y_l, sig_z_l]

# Labels follow the full operator order above and are trimmed to the
# number of enabled operators.
labels_l = [
    '\\sigma^x', '\\sigma^y', '\\sigma^z',
    '\\sigma^+', '\\sigma^-',
    '[\\sigma^+]\\sigma^{\\pm}', '[\\sigma^-]\\sigma^{\\mp}',
][:len(operators_l)]

# Random test state in integer, NumPy and JAX form.
int_state, np_state, jnp_state = initial_states(lat.ns, display=True)

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

#### i) Integer states

In [3]:
# Local operators on the integer-encoded basis state, with full display.
test_operator_on_state(operators_l, lat, int_state,
                       op_acting=OperatorTypeActing.Local,
                       op_label=labels_l,
                       to_bin=None,
                       just_time=False)


op_label: ['\\sigma^x', '\\sigma^y', '\\sigma^z', '\\sigma^+', '\\sigma^-', '[\\sigma^+]\\sigma^{\\pm}', '[\\sigma^-]\\sigma^{\\mp}'] 7


<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

Operator action: 1230875.750000 us


#### ii) NumPy states

In [4]:
# Same Local operators on the NumPy representation; timing only, no display.
# NOTE(review): the recorded run raised a numba TypingError inside
# sigma_y_np (float(bool) in binary.flip_array_np_nspin) - see output below.
test_operator_on_state(operators_l, lat, np_state,
                       op_acting=OperatorTypeActing.Local,
                       op_label=labels_l,
                       to_bin=None,
                       just_time=True)


Operator action: 399.375000 us


TypingError: Failed in nopython mode pipeline (step: nopython frontend)
[1m[1m[1m[1m[1mFailed in nopython mode pipeline (step: nopython frontend)
[1m[1m[1m[1mNo implementation of function Function(<class 'float'>) found for signature:
 
 >>> float(bool)
 
There are 4 candidate implementations:
[1m  - Of which 2 did not match due to:
  Overload in function 'Float.generic': File: numba/core/typing/old_builtins.py: Line 992.
    With argument(s): '(bool)':[0m
[1m   Rejected as the implementation raised a specific error:
     NumbaTypeError: [1mfloat() only support for numbers[0m[0m
  raised from /Users/makskliczkowski/miniconda3/lib/python3.12/site-packages/numba/core/typing/old_builtins.py:1005
[1m  - Of which 2 did not match due to:
  Overload of function 'float': File: numba/experimental/jitclass/overloads.py: Line 137.
    With argument(s): '(bool)':[0m
[1m   No match.[0m
[0m
[0m[1mDuring: resolving callee type: Function(<class 'float'>)[0m
[0m[1mDuring: typing of call at /Users/makskliczkowski/Codes/QuantumEigenSolver/Python/QES/general_python/common/binary.py (343)
[0m
[1m
File "../QES/general_python/common/binary.py", line 343:[0m
[1mdef flip_array_np_nspin(n           : Array,
    <source elided>
    """
[1m    n[k] = float(not (n[k] == spin_value)) * spin_value
[0m    [1m^[0m[0m

[0m[1mDuring: Pass nopython_type_inference[0m
[0m[1mDuring: resolving callee type: type(CPUDispatcher(<function sigma_y_np at 0x3131b67a0>))[0m
[0m[1mDuring: typing of call at /Users/makskliczkowski/Codes/QuantumEigenSolver/Python/QES/Algebra/Operator/operator.py (1684)
[0m
[0m[1mDuring: resolving callee type: type(CPUDispatcher(<function sigma_y_np at 0x3131b67a0>))[0m
[0m[1mDuring: typing of call at /Users/makskliczkowski/Codes/QuantumEigenSolver/Python/QES/Algebra/Operator/operator.py (1684)
[0m
[1m
File "../QES/Algebra/Operator/operator.py", line 1684:[0m
[1m        def fun_np(state, i):
            <source elided>
            sites_1 = np.array([i], dtype=np.int32)
[1m            return op_func_np(state, sites_1, *extra_args)
[0m            [1m^[0m[0m

[0m[1mDuring: Pass nopython_type_inference[0m

#### iii) JAX states

In [57]:
# sigma^x and sigma^z on the JAX representation, with full display.
test_operator_on_state([sig_x_l, sig_z_l], lat, jnp_state,
                       op_acting=OperatorTypeActing.Local,
                       op_label=['\\sigma^x', '\\sigma^z'],
                       to_bin=None,
                       just_time=False)

op_label: ['\\sigma^x', '\\sigma^z'] 2


<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

Operator action: 31927.042000 us


### b) Correlation functions

In [64]:
# Two-site (Correlation) versions of the Pauli operators.
sig_x_c, sig_y_c, sig_z_c = (
    make(lattice=lat, type_act=op_spin.OperatorTypeActing.Correlation)
    for make in (op_spin.sig_x, op_spin.sig_y, op_spin.sig_z)
)

# Fresh random test state in integer, NumPy and JAX form.
int_state, np_state, jnp_state = initial_states(lat.ns, display = True)

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

#### i) Integer states

In [67]:
# Correlation operators on the integer-encoded basis state, with full display.
# NOTE(review): the recorded run failed in numba with
# "Untyped global name 'op_func_int'" (operator.py, fun_int) - see output below.
test_operator_on_state([sig_x_c, sig_y_c, sig_z_c], lat, int_state,
                       op_acting=OperatorTypeActing.Correlation,
                       op_label=['\\sigma^x', '\\sigma^y', '\\sigma^z'],
                       to_bin=None,
                       just_time=False)

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

Operator action: 123653.333000 us


TypingError: Failed in nopython mode pipeline (step: nopython frontend)
[1m[1mUntyped global name 'op_func_int':[0m [1m[1mCannot determine Numba type of <class 'function'>[0m
[1m
File "../QES/Algebra/Operator/operator.py", line 1711:[0m
[1m        def fun_int(state, i, j):
            <source elided>
            sites_2 = np.array([i, j], dtype=np.int32)
[1m            return op_func_int(state, ns, sites_2, *extra_args)
[0m            [1m^[0m[0m
[0m
[0m[1mDuring: Pass nopython_type_inference[0m