In [6]:
import jax
import jax.numpy as jnp
from jax.numpy.linalg import eigh
from scipy.sparse import csr_matrix, kron, identity

# Define Pauli matrices and identity matrix using scipy sparse matrices
X = csr_matrix([[0, 1], [1, 0]], dtype=jnp.float32)
Z = csr_matrix([[1, 0], [0, -1]], dtype=jnp.float32)
I = identity(2, dtype=jnp.float32)
n_op = csr_matrix([[0, 0], [0, 1]], dtype=jnp.float32)

class Rydberg:
    """
    Hamiltonian for the Rydberg Model.

    Attributes
    ----------
    system_size : int
        Size of the physical system.
    positions : array_like
        Positions of the atoms.
    periodic : bool
        Whether the system has periodic boundary conditions.
    """

    def __init__(self, system_size, positions, periodic=True):
        self.system_size = system_size
        self.n = system_size
        self.param_dim = 2
        self.param_range = jnp.array([[0, -2], [5, 2]])
        self.Omega = 1
        self.delta = 1
        self.positions = jnp.array(positions)
        self.periodic = periodic

    def update_param(self, param):
        assert len(param) == 2
        self.Omega = param[0]
        self.delta = param[1]

    def full_H(self, param=None):
        if param is None:
            Omega, delta = self.Omega, self.delta
        else:
            Omega, delta = param

        # Initialize the Hamiltonian
        Hamiltonian = csr_matrix((2 ** self.n, 2 ** self.n), dtype=jnp.float32)

        # Add the - Omega/2 * sum(sigma_x) term
        for i in range(self.n):
            H_x = 1
            for j in range(self.n):
                if i == j:
                    H_x = kron(H_x, X, format='csr')
                else:
                    H_x = kron(H_x, I, format='csr')
            Hamiltonian -= (Omega / 2) * H_x

        # Add the - delta * sum(n_i) term
        for i in range(self.n):
            H_n = 1
            for j in range(self.n):
                if i == j:
                    H_n = kron(H_n, n_op, format='csr')
                else:
                    H_n = kron(H_n, I, format='csr')
            Hamiltonian -= delta * H_n

        # Add the V_ij * n_i * n_j term
        for i in range(self.n):
            for j in range(i + 1, self.n):
                V_ij = 7 / jnp.linalg.norm(self.positions[i] - self.positions[j]) ** 6
                H_nn = 1
                for k in range(self.n):
                    if k == i or k == j:
                        H_nn = kron(H_nn, n_op, format='csr')
                    else:
                        H_nn = kron(H_nn, I, format='csr')
                Hamiltonian += V_ij * H_nn

        return Hamiltonian

    def DMRG(self, param=None, verbose=False):
        if param is None:
            param = [self.Omega, self.delta]

        full_Hamiltonian = self.full_H(param).toarray()

        # Compute the ground state energy using jax.numpy.linalg.eigh
        evals, evecs = eigh(full_Hamiltonian)

        E_ground = evals[0]
        psi_ground = evecs[:, 0]

        if verbose:
            print(f"Ground state energy: {E_ground:.13f}")

        return E_ground, psi_ground

# Example usage
positions = [[i, 0] for i in range(10)]
rydberg = Rydberg(system_size=10, positions=positions)
rydberg.update_param([2, -1])
H = rydberg.full_H()
E_ground, psi_ground = rydberg.DMRG(verbose=True)


ValueError: could not interpret dimensions