In [6]:
import sys
import os
import numbers
import numpy as np
from IPython.display import Math, display
from typing import Union, List, Sequence, Optional

# Make the project root importable. ``os.path.dirname(os.curdir)`` is the empty
# string, so ``parent_dir`` resolves to the parent of the kernel's current
# working directory (not of the notebook file itself).
script_dir = os.path.dirname(os.curdir)
parent_dir = os.path.abspath(os.path.join(script_dir, os.pardir))
if not parent_dir in sys.path:
    sys.path.append(parent_dir)

from QES.general_python.algebra.utils import JAX_AVAILABLE
from QES.general_python.lattices.honeycomb import HoneycombLattice
from QES.general_python.lattices.square import SquareLattice
from QES.general_python.lattices.lattice import LatticeBC, Lattice
import QES.general_python.common.binary as _bin_mod
from QES.general_python.common.timer import Timer
from QES.general_python.common.display import (
    display_state,
    display_operator_action,
)

from QES.Algebra.Operator.operator import OperatorTypeActing, Operator
#! Spin operators
import QES.Algebra.Operator.operators_spin as op_spin
#! Fermionic operators
import QES.Algebra.Operator.operators_spinless_fermions as op_sferm


#! Backends -- ``jax`` / ``jnp`` stay ``None`` placeholders unless JAX is available
jax = jnp = None
if JAX_AVAILABLE:
    import jax
    import jax.numpy as jnp

#! Configuration: NumPy backend label and a 1D lattice with lx = 5 sites (PBC)
backend = 'np'
lat = SquareLattice(
    dim = 1,
    lx  = 5,
    ly  = 1,
    lz  = 1,
    bc  = LatticeBC.PBC,
)

#! Create functions for testing

def test_operator_on_state(op           : Union[Operator, Sequence[Operator]],
                        lat             : Lattice,
                        state           : Union[int, np.ndarray, "jnp.ndarray"],
                        *,
                        ns              : Optional[int] = None,
                        op_acting       : "OperatorTypeActing" = OperatorTypeActing.Local,
                        op_label        = None,
                        to_bin          = None) -> None:
    r"""
    Pretty-print the action of *one or several* lattice operators
    on a basis state or wave-function.

    Parameters
    ----------
    op : Operator or sequence[Operator]
        The operator(s) Ô acting on 0, 1, or 2 sites.
    lat : Lattice
        Provides the number of sites ``lat.ns`` = :math:`N_s`.
    state : int | np.ndarray | jax.numpy.ndarray
        *Basis state* (integer encoding) or *wave-function* :math:`|\psi\rangle`.
    ns : int, optional
        Number of sites to loop over and to use when displaying the initial
        integer state.  Defaults to ``lat.ns``.  (The per-action display done
        by ``_dispatch`` is given ``lat.ns``.)
    op_acting : OperatorTypeActing, default = ``Local``
        How Ô acts: Local (Ôᵢ, applied for every site i), Correlation
        (Ôᵢⱼ, applied for every ordered pair (i, j)), or Global (Ô, applied once).
    op_label : str | sequence[str], optional
        LaTeX label(s).  If *None*, uses ``op.name`` for every operator.
    to_bin : Callable[[int, int], str], optional
        Integer → binary-string formatter, forwarded unchanged to the display
        helpers.  When *None* the formatting choice is left to those helpers.

    Raises
    ------
    ValueError
        If ``op_acting`` is not Local, Correlation or Global (checked before
        anything is displayed), or if ``op_label`` is a sequence whose length
        differs from the number of operators.

    Notes
    -----
    * For **integer states** the new basis state and its coefficient are shown.
    * For **array states** (NumPy / JAX) only the *first* coefficient returned
      by the operator is shown; the output state itself is not displayed.
    """
    # Validate up front so an unsupported mode fails before any output is printed.
    supported = (OperatorTypeActing.Local,
                 OperatorTypeActing.Correlation,
                 OperatorTypeActing.Global)
    if op_acting not in supported:
        raise ValueError(f"Operator acting type {op_acting!r} not supported.")

    # ------------------------------------------------------------------
    ops      = tuple(op) if isinstance(op, Sequence) else (op,)
    labels   = _prepare_labels(ops, op_label)
    # np.integer is registered as numbers.Integral, listed explicitly for clarity
    is_int   = isinstance(state, (numbers.Integral, np.integer))
    ns       = lat.ns if ns is None else ns

    # ------------------------------------------------------------------
    if is_int:
        display_state(state,
                    ns,
                    label  = f"Initial integer state (Ns={ns})",
                    to_bin = to_bin)
    else:
        print(f"Input state is a {type(state).__name__}: {state}")

    # ------------------------------------------------------------------
    with Timer(verbose=True):
        for cur_op, lab in zip(ops, labels):
            display(Math(fr"\text{{Operator: }} {lab}, \text{{typeacting}}: {op_acting}"))

            if op_acting == OperatorTypeActing.Local:
                for i in range(ns):
                    display(Math(fr"\quad \text{{Site index: }} {i}"))
                    _dispatch(cur_op, state, lat, is_int, to_bin, lab, i=i)

            elif op_acting == OperatorTypeActing.Correlation:
                for i in range(ns):
                    for j in range(ns):
                        display(Math(fr"\quad \text{{Site indices: }} {i}, {j}"))
                        _dispatch(cur_op, state, lat, is_int, to_bin, lab, i=i, j=j)

            else:   # OperatorTypeActing.Global -- the only remaining validated mode
                _dispatch(cur_op, state, lat, is_int, to_bin, lab)


#------------------------------------------------------------------------
#! Helper functions
#------------------------------------------------------------------------

def _prepare_labels(ops, op_label):
    """Return a tuple of LaTeX labels matching *ops*."""
    if op_label is None:
        return tuple(getattr(op, "name", f"op_{k}") for k, op in enumerate(ops))
    if isinstance(op_label, Sequence):
        if len(op_label) != len(ops):
            raise ValueError("Length of op_label must match number of operators")
        return tuple(op_label)
    return tuple(op_label for _ in ops)

def _dispatch(op, state, lat, is_int, to_bin, lab, *, i=None, j=None):
    """
    Call *op* with the right signature and send its first output to
    `display_operator_action`.
    """
    
    if is_int:
        state_act = state
    else:
        state_act = state.copy()
    
    # call signature depends on OperatorTypeActing
    if i is None:                        # Global operator
        st_out, coeff = op(state_act)
    elif j is None:                      # Local operator
        if is_int:
            st_out, coeff = op(state_act, i)
        elif isinstance(state_act, (np.ndarray)):
            sites         = np.array([i])
            st_out, coeff = op(state_act, sites)
        elif isinstance(state_act, (jnp.ndarray)):
            sites         = jnp.array([i])
            st_out, coeff = op(state_act, sites[0])
    else:                                # Correlation operator
        if is_int:
            st_out, coeff = op(state_act, i, j)
        elif isinstance(state_act, (np.ndarray)):
            sites         = np.array([i, j])
            st_out, coeff = op(state_act, sites)
        elif isinstance(state_act, (jnp.ndarray)):
            sites         = jnp.array([i, j])
            st_out, coeff = op(state_act, sites[0], sites[1])
            
    # --------------- choose what to show depending on state representation ---
    if is_int:
        new_state = None if st_out is None else st_out[0]
        new_coeff = 0    if st_out is None else coeff[0]
    else:
        # For array-based states: show the first coefficient, state display N/A
        new_state = None
        new_coeff = coeff[0] if np.size(coeff) else coeff

    display_operator_action(f"\\quad \\quad {lab}",
                            i if j is None else (i, j),
                            state,
                            lat.ns,
                            new_state,
                            new_coeff,
                            to_bin = to_bin)

## Spin operators $\sigma^x$, $\sigma^y$, $\sigma^z$

The Pauli matrices are defined as:
$$
\sigma^x = \begin{pmatrix}
0 & 1 \\
1 & 0
\end{pmatrix}, \quad
\sigma^y = \begin{pmatrix}
0 & -i \\
i & 0
\end{pmatrix}, \quad
\sigma^z = \begin{pmatrix}
1 & 0 \\
0 & -1
\end{pmatrix}
$$

### a) Local operators

In [7]:
# Local (single-site) Pauli operators on the lattice ``lat``.
sig_x = op_spin.sig_x(
    lattice  = lat,
    type_act = op_spin.OperatorTypeActing.Local
)
sig_y = op_spin.sig_y(
    lattice  = lat,
    type_act = op_spin.OperatorTypeActing.Local
)
sig_z = op_spin.sig_z(
    lattice  = lat,
    type_act = op_spin.OperatorTypeActing.Local
)

# Test states: a random basis state (integer encoding) plus all-ones NumPy / JAX
# arrays. The RNG is seeded so "Restart & Run All" reproduces the same outputs.
SEED      = 42
rng       = np.random.default_rng(SEED)
int_state = rng.integers(0, 2**lat.ns, dtype=np.int32)
np_state  = np.ones((lat.ns), dtype = np.float64)
# ``jnp`` is set to None when JAX is not installed; skip the JAX state then.
jnp_state = jnp.ones((lat.ns), dtype = jnp.float64) if jnp is not None else None

int_state, np_state, jnp_state

(19, array([1., 1., 1., 1., 1.]), Array([1., 1., 1., 1., 1.], dtype=float64))

#### a) Integer states

In [None]:
# Act with σ^x and σ^z on every site of the random integer basis state.
test_operator_on_state(
    [sig_x, sig_z],
    lat,
    int_state,
    op_acting = OperatorTypeActing.Local,
    op_label  = [r'\sigma^x', r'\sigma^z'],
    to_bin    = lambda value, width: f"{value:0{width}b}",
)


<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

None: 0.1108 seconds
