In [None]:
import quimb.tensor as qtn
import quimb as qu
import netket as nk

import numpy as np
import torch
import matplotlib.pyplot as plt

The task in this notebook is to find the ground-state energy of a given Ising-model Hamiltonian with tensor-network methods, mainly the (simple) update algorithm, using DMRG for comparison.
For the update we use the implementation in the `quimb` library. But first, let's quickly review the basics of the update algorithm.

The imaginary time evolution (ITE) operator is defined as follows:
$$
U(\tau) = e^{-\tau H}
$$
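For $\tau \rightarrow \infty$ the ITE operator, followed by normalization, maps any state that has a non-zero overlap with the ground state onto the ground state of the Hamiltonian $H$ (or onto the ground-state subspace, if it is degenerate). The ITE operator is not unitary, so it does not describe a physical time evolution. Instead it acts like a Boltzmann weight / softmax with inverse temperature $\beta = \tau$: every energy eigenstate is damped by $e^{-\tau E_n}$, so the lowest-energy component dominates.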
$$
\ket{\psi_0} \propto \lim_{\tau \rightarrow \infty} U(\tau) \ket{\psi_\text{init}}
$$
To avoid huge numbers in the exponential, we iterate the ITE operator in small steps $\delta \tau$ until convergence.
$$
\psi_{\tau + \delta \tau} = \frac{U(\delta \tau) \psi_{\tau}}{\| U(\delta \tau) \psi_{\tau} \|_2}
$$
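Furthermore, the update algorithm uses the fact that the exponential of a sum of commuting matrices factorizes into a product of exponentials: $[A, B] = 0 \implies e^{A + B} = e^A e^B$ (the converse does not hold in general).
So, if the Hamiltonian can be decomposed into mutually commuting terms $H = \sum_i H_i$ (e.g. terms that act on disjoint parts of the system), we can decompose the ITE operator as follows:
$$
U(\delta \tau) = e^{-\delta \tau H} = e^{-\delta \tau \sum_i H_i} = \prod_i e^{-\delta \tau H_i}
$$
If the terms do not commute, this factorization is only approximate (Trotter–Suzuki decomposition). This is the case for the transverse-field Ising model, where $\sigma^z\sigma^z$ and $\sigma^x$ terms act on the same sites. The first-order error per step is of order $\delta \tau^2$, which is one more reason to use small steps $\delta \tau$.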

With `quimb`'s implementation we can find the ground-state energy of a given Hamiltonian. As a simple example, consider:
$$
H = \sigma^z_1 \sigma^z_2 + \sigma^z_2 \sigma^z_3 + \sigma^z_3 \sigma^z_4
$$
The ground-state energy is $E_0 = -3$. The ground state is two-fold degenerate, $\ket{0101}$ and $\ket{1010}$ (neighbouring spins anti-aligned).

## Helper Functions

In [None]:
import scipy.sparse as sp
from functools import reduce

matmap_np, matmap_sp = None, None

def parse_hamiltonian(hamiltonian, sparse=False, scaling=1, buffer=None, max_buffer_n=0, dtype=float): # I'd usually default to complex, but because we're only dealing with Ising models here, float is more handy
    """Parse a string representation of a Hamiltonian into a matrix representation. The result is guaranteed to be Hermitian.

    The string is a sum of Pauli-string terms, e.g. ``"0.5*XZ - IZ + 1.5"``:
    - A term may carry a numeric weight (``w*PAULIS``) or a sign.
    - A parenthesized sub-expression may be weighted as a whole (``w*(...)``).
    - A bare number is a multiple of the identity.

    Parameters:
        hamiltonian (str): The Hamiltonian to parse.
        sparse (bool): Whether to use sparse matrices (csr_array) or dense matrices (numpy.array).
        scaling (float): A constant factor to scale the Hamiltonian by.
        buffer (dict): A dictionary to store calculated chunks in. If `None`, it defaults to the global `matmap_np` (or `matmap_sp` if `sparse == True`). Give `buffer={}` and leave `max_buffer_n == 0` (default) to disable the buffer.
        max_buffer_n (int): The maximum length (number of qubits) for new chunks to store in the buffer (default: 0). If `0`, no new chunks will be stored in the buffer.
        dtype: dtype of the result (default: ``float``). A lone ``Y`` needs a complex dtype. Adjacent
            ``YY`` pairs are real and also work with real dtypes.

    Returns:
        numpy.ndarray | scipy.sparse sparse array: The matrix representation of the Hamiltonian. The
        result never aliases the internal buffer, so it is safe to modify it in place.

    Raises:
        ValueError: If the string is malformed, or contains a lone ``Y`` with a real dtype.
        MemoryError: If a dense matrix for n > 10 qubits would not fit into the available memory
            (only checked when ``psutil`` is installed).
        AssertionError: If the result is not Hermitian.

    Example:
    >>> parse_hamiltonian('0.5*(XX + YY + ZZ + II)') # SWAP
    array([[1., 0., 0., 0.],
           [0., 0., 1., 0.],
           [0., 1., 0., 0.],
           [0., 0., 0., 1.]])
    >>> parse_hamiltonian('-(XX + YY + .5*ZZ) + 1.5')
    array([[ 1.,  0.,  0.,  0.],
           [ 0.,  2., -2.,  0.],
           [ 0., -2.,  2.,  0.],
           [ 0.,  0.,  0.,  1.]])
    >>> parse_hamiltonian('0.5*(II + ZI - ZX + IX)') # CNOT
    array([[1., 0., 0., 0.],
           [0., 1., 0., 0.],
           [0., 0., 0., 1.],
           [0., 0., 1., 0.]])
    """
    kron = sp.kron if sparse else np.kron

    # Initialize the matrix map (rebuilt whenever the requested dtype changes)
    global matmap_np, matmap_sp
    if matmap_np is None or matmap_sp is None or matmap_np["I"].dtype != dtype:
        # numpy versions
        matmap_np = {
            "H": np.array([[1, 1], [1, -1]], dtype=dtype) / np.sqrt(2),
            "X": np.array([[0, 1], [1, 0]], dtype=dtype),
            "Z": np.array([[1, 0], [0, -1]], dtype=dtype),
            "I": np.array([[1, 0], [0, 1]], dtype=dtype),
        }
        # composites
        matmap_np.update({
            "ZZ": np.kron(matmap_np['Z'], matmap_np['Z']),
            "IX": np.kron(matmap_np['I'], matmap_np['X']),
            "XI": np.kron(matmap_np['X'], matmap_np['I']),
            "YY": np.array([[ 0,  0,  0, -1],  # Y (x) Y is real, so it is available even for real dtypes
                            [ 0,  0,  1,  0],
                            [ 0,  1,  0,  0],
                            [-1,  0,  0,  0]], dtype=dtype)
        })
        for i in range(2, 11):
            matmap_np["I"*i] = np.eye(2**i, dtype=dtype)
        # add 'Y' only if dtype supports imaginary numbers
        if np.issubdtype(dtype, np.complexfloating):
            matmap_np["Y"] = np.array([[0, -1j], [1j, 0]], dtype=dtype)

        # sparse versions
        matmap_sp = {k: sp.csr_array(v) for k, v in matmap_np.items()}

    # A lone Y needs complex numbers, but adjacent "YY" pairs are real (see the "YY" composite above),
    # so only reject Y's that are not part of such a pair.
    if not np.issubdtype(dtype, np.complexfloating) and "Y" in hamiltonian.replace(" ", "").replace("YY", ""):
        dtype_name = getattr(dtype, "__name__", str(dtype))
        raise ValueError(f"The Pauli matrix Y is not supported for dtype {dtype_name}.")

    matmap = matmap_sp if sparse else matmap_np
    # Single-character gate labels. "Y" always counts as a gate label (for real dtypes it only occurs
    # in "YY" pairs), so a chunk such as "-YY" is never mistaken for a numeric literal.
    gate_chars = {k for k in matmap if len(k) == 1} | {"Y"}

    # only use buffer if pre-computed chunks are available or if new chunks are allowed to be stored
    use_buffer = buffer is None or len(buffer) > 0 or max_buffer_n > 0
    if use_buffer and buffer is None:
        buffer = matmap

    def calculate_chunk_matrix(chunk, sparse=False, scaling=1):
        # Kronecker product of the single-qubit gates in `chunk`, times `scaling`. Re-uses buffered
        # sub-chunks where possible.
        if use_buffer:
            if chunk in buffer:
                return buffer[chunk] if scaling == 1 else scaling * buffer[chunk]
            if len(chunk) == 1:
                return matmap[chunk[0]] if scaling == 1 else scaling * matmap[chunk[0]]
            # Check if a part of the chunk has already been calculated
            for i in range(len(chunk)-1, 1, -1):
                for j in range(len(chunk)-i+1):
                    subchunk = chunk[j:j+i]
                    if subchunk in buffer:
                        # If so, calculate the rest of the chunk recursively
                        parts = [chunk[:j], subchunk, chunk[j+i:]]
                        # remove empty chunks
                        parts = [c for c in parts if c != ""]
                        # See where to apply the scaling
                        shortest = min(parts, key=len)
                        # Calculate each part recursively
                        for i, c in enumerate(parts):
                            if c == subchunk:
                                if c == shortest:
                                    parts[i] = scaling * buffer[c]
                                    shortest = ""
                                else:
                                    parts[i] = buffer[c]
                            else:
                                if c == shortest:
                                    parts[i] = calculate_chunk_matrix(c, sparse=sparse, scaling=scaling)
                                    shortest = ""
                                else:
                                    parts[i] = calculate_chunk_matrix(c, sparse=sparse, scaling=1)
                        return reduce(kron, parts)

        # Calculate the chunk matrix gate by gate
        if use_buffer and len(chunk) <= max_buffer_n:
            gates = [matmap[gate] for gate in chunk]
            chunk_matrix = reduce(kron, gates)
            buffer[chunk] = chunk_matrix
            if scaling != 1:
                chunk_matrix = scaling * chunk_matrix
        else:
            gates = [scaling * matmap[chunk[0]]] + [matmap[gate] for gate in chunk[1:]]
            chunk_matrix = reduce(kron, gates)

        return chunk_matrix

    # Remove whitespace
    hamiltonian = hamiltonian.replace(" ", "")
    # replace - with +-, except before e
    hamiltonian = hamiltonian \
                    .replace("-", "+-") \
                    .replace("e+-", "e-") \
                    .replace("(+-", "(-")

    # Find parts in parentheses
    part = ""
    parts = []
    depth = 0
    current_part_weight = ""
    for i, c in enumerate(hamiltonian):
        if c == "(":
            if depth == 0:
                # for top-level parts search backwards for the weight
                weight = ""
                for j in range(i-1, -1, -1):
                    if hamiltonian[j] in ["("]:
                        break
                    weight += hamiltonian[j]
                    if hamiltonian[j] in ["+", "-"]:
                        break
                weight = weight[::-1]
                if weight != "":
                    current_part_weight = weight
            depth += 1
        elif c == ")":
            depth -= 1
            if depth == 0:
                part += c
                parts.append((current_part_weight, part))
                part = ""
                current_part_weight = ""
        if depth > 0:
            part += c

    # Replace parts in parentheses with a placeholder
    for i, (weight, part) in enumerate(parts):
        hamiltonian = hamiltonian.replace(weight+part, f"+part{i}", 1)
        # remove * at the end of the weight
        if weight != "" and weight[-1] == "*":
            weight = weight[:-1]
        if weight in ["", "+", "-"]:
            weight += "1"
        # Calculate the part recursively
        part = part[1:-1] # remove parentheses
        parts[i] = parse_hamiltonian(part, sparse=sparse, scaling=float(weight), buffer=buffer, max_buffer_n=max_buffer_n, dtype=dtype)

    # Parse the rest of the Hamiltonian
    chunks = hamiltonian.split("+")
    # Remove empty chunks
    chunks = [c for c in chunks if c != ""]
    # If parts are present, use them to determine the number of qubits
    if parts:
        n = int(np.log2(parts[0].shape[0]))
    else: # Use chunks to determine the number of qubits
        n = 0
        for c in chunks:
            if c[0] in ["-", "+"]:
                c = c[1:]
            if "*" in c:
                c = c.split("*")[1]
            if c.startswith("part"):
                continue
            try:
                float(c)
                continue
            except ValueError:
                n = len(c)
                break
        if n == 0:
            print("Warning: Hamiltonian is a scalar!")

    if not sparse and n > 10:
        # check if we would blow up the memory (psutil is optional; without it the check is skipped)
        mem_required = 2**(2*n) * np.array(1, dtype=dtype).nbytes
        try:
            import psutil
        except ImportError:
            psutil = None
        if psutil is not None and mem_required > psutil.virtual_memory().available:
            raise MemoryError(f"This would blow up your memory ({mem_required / 2**30:.2f} GiB required)! Try using `sparse=True`.")

    if sparse:
        H = sp.csr_array((2**n, 2**n), dtype=dtype)
    else:
        if n > 10:
            print(f"Warning: Using a dense matrix for a {n}-qubit Hamiltonian is not recommended. Use sparse=True.")
        H = np.zeros((2**n, 2**n), dtype=dtype)

    for chunk in chunks:
        chunk_matrix = None
        if chunk == "":
            continue
        # Parse the weight of the chunk
        if chunk.startswith("part"):
            weight = 1  # parts are already scaled
            chunk_matrix = parts[int(chunk.split("part")[1])]
        elif "*" in chunk:
            weight = float(chunk.split("*")[0])
            chunk = chunk.split("*")[1]
        elif len(chunk) == n+1 and chunk[0] in ["-", "+"] and n >= 1 and chunk[1] in gate_chars:
            weight = float(chunk[0] + "1")
            chunk = chunk[1:]
        elif (chunk[0] in ["-", "+", "."] or chunk[0].isdigit()) and all(c not in gate_chars for c in chunk[1:]):
            # scalar term: a multiple of the identity
            if len(chunk) == 1 and chunk[0] in ["-", "."]:
                chunk = 0
            weight = complex(chunk)
            if np.iscomplex(weight):
                raise ValueError("Complex scalars would make the Hamiltonian non-Hermitian!")
            weight = weight.real
            # keep the identity in the requested (sparse or dense) format
            if sparse:
                chunk_matrix = sp.csr_array(sp.identity(2**n, dtype=dtype, format="csr"))
            else:
                chunk_matrix = np.eye(2**n, dtype=dtype)
        elif len(chunk) != n:
            raise ValueError(f"Gate count must be {n} but was {len(chunk)} for chunk \"{chunk}\"")
        else:
            weight = 1

        if chunk_matrix is None:
            chunk_matrix = calculate_chunk_matrix(chunk, sparse=sparse, scaling = scaling * weight)
        elif scaling * weight != 1:
            chunk_matrix = scaling * weight * chunk_matrix

        # Add the chunk to the Hamiltonian
        if len(chunks) == 1:
            # copy: chunk_matrix may be an entry of the (global) buffer, which must not leak to callers
            H = chunk_matrix.copy()
        else:
            H += chunk_matrix

    if sp.issparse(H):
        # Compare H with its conjugate transpose entry-wise. Comparing the raw `.data` buffers of
        # the CSR matrix and of its CSC transpose only checks whether the entries are real.
        deviation = abs(H - H.conj().T).max()
        assert deviation <= 1e-8 + 1e-5 * abs(H).max(), f"The given Hamiltonian {hamiltonian} is not Hermitian: {H.data}"
    else:
        assert np.allclose(H, H.conj().T), f"The given Hamiltonian {hamiltonian} is not Hermitian: {H}"

    return H

In [None]:
# Check our parser against `qu.ham_heis`. quimb works with spin-1/2 operators (S = sigma/2), so
# interaction terms come out divided by 4 and magnetic-field terms divided by -2.
heis_checks = [
    (dict(n=3, j=(1, 1, 1)), '0.25*(ZZI + IZZ + XXI + IXX + YYI + IYY)'),
    (dict(n=3, j=(0, 0, 1), b=(-1, 0, 0)), '0.25*(ZZI + IZZ) + 0.5*(XII + IXI + IIX)'),
]
for quimb_kwargs, ham_string in heis_checks:
    H1 = qu.ham_heis(**quimb_kwargs, sparse=False).astype(complex)
    H2 = parse_hamiltonian(ham_string, dtype=complex)
    assert np.allclose(H1, H2)
    plt.figure(figsize=(2,2))
    plt.imshow(H2.real, cmap='bwr')

In [None]:
def ising_model_graph(graph, J=(-1,1), h=(-1,1), g=(-1,1)):
    """Build an Ising Hamiltonian string for `graph` that `parse_hamiltonian` can read.

    The Hamiltonian is  sum_{(i,j) in edges} J_ij Z_i Z_j + sum_i h_i Z_i + sum_i g_i X_i.

    Each of `J`, `h`, `g` may be:
      - a scalar: the same value everywhere;
      - a 2-element sequence ``(low, high)``: values drawn with ``np.random.uniform``;
      - an explicit array: `J` of shape (n, n), where only J[i, j] with i < j on graph edges is used,
        and `h`/`g` of shape (n,).
    `h` or `g` may be None to drop that field. Terms with a zero coefficient are omitted.
    Note: on a 2-node graph, a length-2 `h`/`g` is interpreted as a range, not as per-site values.

    Returns:
        str: The terms joined with ' + ' ('' if every coefficient is zero).
    """
    if not isinstance(graph, nk.graph.Graph):
        raise ValueError(f"graph must be a nk.graph.Graph, but is {type(graph)}")

    n_qubits = graph.n_nodes
    edges = graph.edges()

    J = np.array(J)
    if J.shape == ():
        # upper-triangular matrix with every coupling equal to J
        J = np.triu(np.ones((n_qubits, n_qubits)), k=1) * J
    elif J.shape == (2,):
        # upper-triangular matrix of couplings drawn uniformly from the given range
        J = np.triu(np.random.uniform(J[0], J[1], (n_qubits, n_qubits)), k=1)
    elif J.shape != (n_qubits, n_qubits):
        raise ValueError(f"J must be a scalar, 2-element vector, or matrix of shape {(n_qubits, n_qubits)}, but is {J.shape}")

    def _per_site(values, name):
        # expand a field specification (None, scalar, range or per-site vector) to per-site values
        if values is None:
            return None
        values = np.array(values)
        if values.shape == ():
            return np.ones(n_qubits) * values
        if values.shape == (2,):
            return np.random.uniform(values[0], values[1], n_qubits)
        if values.shape == (n_qubits,):
            return values
        raise ValueError(f"{name} must be a scalar, 2-element vector, or vector of shape {(n_qubits,)}, but is {values.shape}")

    h = _per_site(h, "h")  # longitudinal fields
    g = _per_site(g, "g")  # transverse fields

    terms = []
    # pairwise ZZ interactions
    for i, j in edges:
        assert i < j, f"edges must be sorted, but ({i}, {j}) is not"
        if J[i,j] != 0:
            terms.append(str(J[i,j]) + '*' + 'I'*i + 'Z' + 'I'*(j-i-1) + 'Z' + 'I'*(n_qubits-j-1))
    # local fields: h couples to Z, g couples to X
    for field, pauli in ((h, 'Z'), (g, 'X')):
        if np.any(field):
            terms.extend(
                str(field[i]) + '*' + 'I'*i + pauli + 'I'*(n_qubits-i-1)
                for i in range(n_qubits) if field[i] != 0
            )

    return ' + '.join(terms)

def edges_from_graph(graph, undirected=False):
    """Return the edges of `graph` as a pair of index arrays ``(rows, cols)``.

    The result can be used directly for fancy indexing, e.g. ``J[rows, cols]``.

    Parameters:
        graph: Any object whose ``edges()`` method returns ``(i, j)`` pairs (e.g. a netket graph).
        undirected (bool): If True, also include the reversed pair ``(j, i)`` of every edge.

    Returns:
        tuple[numpy.ndarray, numpy.ndarray]: Integer arrays of source and target indices. Duplicate
        edges are removed, and the pairs are sorted lexicographically by ``(i, j)``.
    """
    # one row per edge, shape (n_edges, 2); dtype=int keeps an empty edge list usable as an index
    pairs = np.asarray(list(graph.edges()), dtype=int).reshape(-1, 2)
    if undirected:
        # append the reversed (j, i) pairs: reverse the columns, not the order of the edges
        pairs = np.concatenate([pairs, pairs[:, ::-1]], axis=0)
    if len(pairs):
        # deduplicate and sort the (i, j) rows (not the two index vectors)
        pairs = np.unique(pairs, axis=0)
    return pairs[:, 0], pairs[:, 1]

def random_ising_own(N: int, graph: nk.graph.Grid):
    """Generate `N` random Ising models on `graph` as `parse_hamiltonian`-compatible strings.

    Every coupling and field is drawn from U(-1, 1) with the global numpy RNG (couplings first,
    then longitudinal fields, then transverse fields).

    Returns:
        list of tuples ``(hamiltonian_string, graph.extent, hyperparameters)``. The
        hyperparameters dict holds ``"J"`` (the sampled couplings on the graph edges, 0 elsewhere),
        ``"h"`` (longitudinal fields) and ``"g"`` (transverse fields).
    """
    n_sites = graph.n_nodes
    couplings = np.random.uniform(-1, 1, size=(N, n_sites, n_sites))
    long_fields = np.random.uniform(-1, 1, size=(N, n_sites))
    trans_fields = np.random.uniform(-1, 1, size=(N, n_sites))

    # index arrays of the graph edges, used to mask the coupling matrices
    edge_idx = edges_from_graph(graph)

    def _build(J_k, h_k, g_k):
        ham_string = ising_model_graph(graph, J_k, h_k, g_k)
        # keep only the couplings that live on graph edges
        J_on_edges = np.zeros((n_sites, n_sites))
        J_on_edges[edge_idx] = J_k[edge_idx]
        return (ham_string, graph.extent, {"J": J_on_edges, "h": h_k, "g": g_k})

    return [_build(J_k, h_k, g_k) for J_k, h_k, g_k in zip(couplings, long_fields, trans_fields)]

def random_ising_nk(N: int, graph: nk.graph.Grid):
    """Generate `N` random transverse-field Ising models on `graph` as netket operators.

    Each model has one coupling J ~ U(-1, 1), shared by all edges, and one transverse field
    g ~ U(-1, 1), shared by all sites. There is no longitudinal field.

    Returns:
        list of tuples ``(ising, graph.extent, hyperparameters)``:
        - ``ising`` is a ``nk.operator.Ising``.
        - ``hyperparameters`` holds ``"J"`` (J on the graph edges, 0 elsewhere), ``"h"`` (all
          zeros) and ``"g"`` (g on every site).
        Each model gets its own arrays.
    """
    J = np.random.uniform(-1, 1, size=N)  # coupling constant, one per model
    g = np.random.uniform(-1, 1, size=N)  # transverse field, one per model
    n = graph.n_nodes
    edges = edges_from_graph(graph)
    hilbert = nk.hilbert.Spin(s=0.5, N=n)

    hamiltonians = []
    for i in range(N):
        # netket's Ising has a -h*sum(X) field term, so h=-g yields +g*sum(X) (same sign as the
        # g*X terms of ising_model_graph)
        ising = nk.operator.Ising(
            hilbert=hilbert,
            graph=graph,
            J=J[i], h=-g[i]
        )

        # coupling matrix: J on the graph edges, 0 elsewhere
        J_i = np.zeros((n, n))
        J_i[edges] = J[i]

        # fresh arrays per model, so mutating one model's hyperparameters cannot affect another
        hamiltonians.append((ising, graph.extent, {"J": J_i, "h": np.zeros(n), "g": g[i] * np.ones(n)}))
    return hamiltonians

## MPS with DMRG (QUIMB)

In [None]:
# The dataset stores pickled Python objects (not only tensors). torch >= 2.6 defaults to
# `weights_only=True`, which refuses them, so the flag is set explicitly. This unpickles arbitrary
# objects, so only load files you trust.
data = torch.load('../../dataset/ising/data/MPS_6000_N200_16.pt', weights_only=False)
data[800].grid_extent

In [None]:
# (quimb is imported as `qu` / `qtn` in the first cell)

# Example setup: define these based on your data structure
idx = 600
n = tuple(data[idx].grid_extent)  # grid extent; only n[0] (the chain length) is used below
pbc = data[idx].pbc  # Whether periodic boundary conditions are applied
assert not pbc, "Periodic boundary conditions are not supported yet."

# Initialize the Hamiltonian builder for a 1D chain with spin-1/2 (s=1/2)
ham_builder = qtn.SpinHam1D(S=1/2)

# Single-site terms. The builder uses spin operators S = sigma/2, so the factor 2 turns
# h*S^z + g*S^x into h*sigma^z + g*sigma^x.
for i, (h, g) in enumerate(data[idx].x_nodes):
    h, g = float(h), float(g)  # convert from torch to float
    ham_builder[i] += 2 * h, 'Z'
    ham_builder[i] += 2 * g, 'X'

# Nearest-neighbour ZZ couplings: 4*J*S^z S^z = J*sigma^z sigma^z
for (a, b), J_ab in zip(data[idx].edge_index.T, data[idx].x_edges):
    # use plain Python ints/floats as site indices and weights (as the PEPS cells do)
    a, b, J_ab = int(a), int(b), float(J_ab)
    if J_ab != 0 and (abs(a - b) == 1 or (pbc and (abs(a - b) == n[0] - 1))):
        # NOTE(review): the term is added for both (a, b) and (b, a). Confirm that SpinHam1D does
        # not count it twice, and whether edge_index already lists both directions of each edge.
        ham_builder[a, b] += 4 * J_ab, 'Z', 'Z'
        ham_builder[b, a] += 4 * J_ab, 'Z', 'Z'

# Build the MPO for the Hamiltonian
H_mpo = ham_builder.build_mpo(n[0])

# Set up DMRG with the maximum bond dimension and truncation cutoff passed to the constructor
# (`bond_dims` / `cutoffs` are the DMRG parameters for these; they are not `opts` entries)
dmrg = qtn.DMRG2(H_mpo, bond_dims=30, cutoffs=1e-10)
dmrg.solve(tol=1e-6, verbosity=1)  # sweep until the energy changes by less than `tol`


In [None]:
# Report the DMRG result (energy and the optimized MPS)
print(f"The ground state energy is {dmrg.energy!s}")
print(f"The ground state is {dmrg.state!s}")

## PEPS with Simple/Full Update (QUIMB)

In [None]:
# Pickled Python objects need `weights_only=False` with torch >= 2.6 (the default became
# `weights_only=True`). This unpickles arbitrary objects, so only load files you trust.
data = torch.load('../../dataset/ising/data/own_ham_2000_12_True.pt', weights_only=False)
data[800].grid_extent

In [None]:
idx = 1400  # set to None to build a random 5x1 model with netket instead of using the dataset

if idx is None:
    n = (5,1)
    graph = nk.graph.Grid(n, pbc=False)

    def pos(node):
        # site coordinates from the netket graph
        x, y = graph.positions[node]
        return int(x), int(y)

    H, _, hyp = random_ising_nk(1, graph)[0]
    # Local terms in quimb's convention (see the qu.ham_heis check cell): interactions are divided
    # by 4 and fields by -2, hence the factors 4 and -2 below.
    local_two_site_hamiltonians = {}  # dict for qtn.LocalHam2D H2
    for a, b in graph.edges():
        J_ab = hyp['J'][a,b]
        local_two_site_hamiltonians[pos(a),pos(b)] = qu.ham_heis(2, j=(0, 0, 4*J_ab))

    local_one_site_hamiltonians = {}  # dict for qtn.LocalHam2D H1
    for a in graph.nodes():
        h_a = hyp['h'][a]
        g_a = hyp['g'][a]
        local_one_site_hamiltonians[pos(a)] = qu.ham_heis(1, b=(-2*g_a, 0, -2*h_a))
else:
    H = data[idx].hamiltonian
    # Use plain ints. 1D chains are treated as an L x 1 grid, because LocalHam2D and PEPS need two
    # dimensions (and pos() below reads n[1]).
    n = tuple(int(d) for d in data[idx].grid_extent)
    if len(n) == 1:
        n = (n[0], 1)
    print("grid extent:", n)

    def pos(node):
        # raster order: node index -> (x, y) = (row, column)
        y = node % n[1]
        x = (node - y) // n[1]
        return x,y

    # data[idx].x_nodes contains [n_nodes, 2] local field values (h, g) for each node.
    # Spin operators are sigma/2, so the factor 2 gives h*sigma^z + g*sigma^x.
    local_one_site_hamiltonians = {}  # dict for qtn.LocalHam2D H1
    for i, (h,g) in enumerate(data[idx].x_nodes):
        h, g = float(h), float(g)  # convert from torch to float
        local_one_site_hamiltonians[pos(i)] = qu.spin_operator('Z') * h * 2
        local_one_site_hamiltonians[pos(i)] += qu.spin_operator('X') * g * 2

    # data[idx].x_edges contains [n_edges, 1] coupling values for each edge
    # and data[idx].edge_index contains [2, n_edges] indices of the nodes that the edge connects
    local_two_site_hamiltonians = {}  # dict for qtn.LocalHam2D H2
    for (a,b), J_ab in zip(data[idx].edge_index.T, data[idx].x_edges):
        a, b, J_ab = int(a), int(b), float(J_ab)  # convert from torch to int/float
        if J_ab != 0:
            local_two_site_hamiltonians[pos(a),pos(b)] = qu.ham_heis(2, j=(0, 0, 4*J_ab))  # factor of 4 because of different convention

    print("Data label ground energy", float(data[idx].y_energy))

# exact reference energy from the full sparse Hamiltonian
if isinstance(H, str):
    ham_full = parse_hamiltonian(H, sparse=True, dtype=float)
else:
    ham_full = H.to_sparse()

energy_exact = qu.groundenergy(ham_full)
ground_state_exact = qu.groundstate(ham_full)

print(f'Exact ground state energy: {energy_exact}')

ham_local = qtn.LocalHam2D(*n, H2=local_two_site_hamiltonians, H1=local_one_site_hamiltonians)

psi0 = qtn.PEPS.rand(*n, bond_dim=4)
psi0.show()
su = qtn.SimpleUpdate(
    psi0 = psi0,
    ham = ham_local,
    chi = 15,
    compute_energy_every = None,
    compute_energy_per_site = True,
    keep_best = True,
    progbar = True
)
# anneal the imaginary-time step size
for tau in [0.1, 0.01, 0.001]:
    su.evolve(100, tau=tau)

# energies are tracked per site; multiply by the number of sites for the total energy
print(f'Approximated ground state energy: {(su.best["energy"] * np.prod(n)):.6f}')

# A subsequent qtn.FullUpdate (chi=8, tau in [0.3, 0.1, 0.03], 50 steps each) did not improve the
# result and is very slow, so it is not run here.

# store the total (instead of per-site) energy in su.best
su.best['energy'] *= np.prod(n)

In [None]:
idx = 800

H = data[idx].hamiltonian
# plain ints; 1D chains are treated as an L x 1 grid
n = tuple(int(d) for d in data[idx].grid_extent)
if len(n) == 1:
    n = (n[0], 1)
print("grid extent:", n)

def pos(node):
    # raster order: node index -> (x, y) = (row, column)
    y = node % n[1]
    x = (node - y) // n[1]
    return x,y


# data[idx].x_nodes contains [n_nodes, 2] local field values (h, g) for each node.
# Spin operators are sigma/2, so the factor 2 gives h*sigma^z + g*sigma^x.
local_one_site_hamiltonians = {}  # dict for qtn.LocalHamGen H1
for i, (h,g) in enumerate(data[idx].x_nodes):
    h, g = float(h), float(g)  # convert from torch to float
    local_one_site_hamiltonians[pos(i)] = qu.spin_operator('Z') * h * 2
    local_one_site_hamiltonians[pos(i)] += qu.spin_operator('X') * g * 2

# data[idx].x_edges contains [n_edges, 1] coupling values for each edge
# and data[idx].edge_index contains [2, n_edges] indices of the nodes that the edge connects
edges = []
local_two_site_hamiltonians = {}  # dict for qtn.LocalHamGen H2
for (a,b), J_ab in zip(data[idx].edge_index.T, data[idx].x_edges):
    a, b, J_ab = int(a), int(b), float(J_ab)  # convert from torch to int/float
    if J_ab != 0:
        edge = (pos(a), pos(b))
        edges.append(edge)
        local_two_site_hamiltonians[edge] = qu.ham_heis(2, j=(0, 0, 4*J_ab))  # factor of 4 because of different convention

print(edges)

print("Data label ground energy", float(data[idx].y_energy))

# exact reference energy from the full sparse Hamiltonian
if isinstance(H, str):
    ham_full = parse_hamiltonian(H, sparse=True, dtype=float)
else:
    ham_full = H.to_sparse()

energy_exact = qu.groundenergy(ham_full)
ground_state_exact = qu.groundstate(ham_full)

print(f'Exact ground state energy: {energy_exact}')

ham_local = qtn.LocalHamGen(H2=local_two_site_hamiltonians, H1=local_one_site_hamiltonians)
# NOTE(review): confirm that LocalHamGen actually reads an `ordering` attribute; otherwise this has no effect
ham_local.ordering = 'raster'

# NOTE(review): the network is built from `edges` only, so a site whose couplings are all zero would
# be missing from psi0 while still having a one-site term. Confirm that this cannot happen for the dataset.
psi0 = qtn.TN_from_edges_rand(edges, D=15, phys_dim=2)

su = qtn.SimpleUpdateGen(
    psi0 = psi0,
    ham = ham_local,
    compute_energy_every = None,
    compute_energy_per_site = True,
    keep_best = True,
    progbar = True
)
# anneal the imaginary-time step size
for tau in [0.1, 0.01, 0.001]:
    su.evolve(100, tau=tau)

# energies are tracked per site; multiply by the number of sites for the total energy
print(f'Approximated ground state energy: {(su.best["energy"] * np.prod(n)):.6f}')

# A subsequent FullUpdate did not improve the result and is very slow, so it is not run here.

# store the total (instead of per-site) energy in su.best
su.best['energy'] *= np.prod(n)

In [None]:
L = 20

# Define any geometry here. With range(L - 1) the `% L` never wraps around, so this is an open
# chain; iterate over range(L) to close it into a ring.
edges = [(i, (i + 1) % L) for i in range(L - 1)]
print(edges)

# Heisenberg interaction (qu.ham_heis defaults) on every bond ...
two = {edge: qu.ham_heis(2).real for edge in edges}
# ... and a field -0.5 * (S^z + S^x) on every site
one = {i: (qu.spin_operator('Z')* -0.5 + qu.spin_operator('X') * -0.5) for i in range(L)}

ham = qtn.LocalHamGen(H2=two, H1=one)
print(one)
print(two)

psi = qtn.TN_from_edges_rand(edges, D=15, phys_dim=2)

su = qtn.SimpleUpdateGen(psi, ham, compute_energy_per_site=True, keep_best=True, progbar=True)
# 30 imaginary-time steps for each (decreasing) step size
for tau in (0.3, 0.1, 0.001):
    su.evolve(30, tau=tau)

# energies are tracked per site; multiply by L for the total energy
print(su.best['energy'] * L)
# NOTE(review): quimb's compute_local_expectation_exact presumably expects a dict of local terms
# (e.g. `ham.terms`) rather than the LocalHamGen object itself. Confirm against the quimb docs.
su.state.compute_local_expectation_exact(ham)

### Calculate RDMs

In [None]:
# Use the best state found by the last simple-update run.
psi = su.best['state']

# Grid shape: m rows, n columns. A 1D extent is treated as an m x 1 grid.
extent = data[idx].grid_extent
m, n = (extent[0], 1) if len(extent) == 1 else tuple(extent)

# per-site local dimensions (spin-1/2), nested row by row like the grid
dims = [[2] * n] * m

print("dims:", dims)

def pos(node):
    """Map a raster-ordered node index to its (x, y) = (row, column) coordinate.

    Reads the notebook-global `n` (number of columns) set in the previous statements.
    """
    return node // n, node % n

def compute_rdm(peps, sites, dims):
    """
    Compute the reduced density matrix (RDM) of a dense state on the given sites.

    Parameters:
    peps: Despite its name, the state as a dense quimb vector, e.g. the result of
        ``TensorNetwork.to_dense()``. It is not a PEPS object.
    sites (list of tuples): The (row, column) coordinates of the sites to keep, e.g. ``[pos(i)]``
        for a 1-RDM or ``[pos(i), pos(j)]`` for a 2-RDM.
    dims (nested list): The local Hilbert-space dimension of every site, laid out row by row like
        the grid.

    Returns:
    numpy.ndarray: The RDM of `sites`, passed through ``qu.normalize``. For a density matrix this
        presumably rescales to unit trace, so an unnormalized input state is fine (confirm with the
        quimb docs).
    """
    return qu.normalize(qu.partial_trace(peps, dims=dims, keep=sites))

site1 = pos(0)
site2 = pos(1)

print(site1, site2)

psi_dense = psi.to_dense()

# 1-RDMs: Frobenius-norm error with respect to the dataset labels
for i in range(len(data[idx].y_node_rdms)):
    one_rdm = compute_rdm(psi_dense, [pos(i)], dims)
    # explicit conversion: the labels may be torch tensors, while the RDM is a numpy array
    error = np.linalg.norm(one_rdm - np.asarray(data[idx].y_node_rdms[i]))
    print(f"1RDM {i}\t{error:.4f}")

# 2-RDMs on every edge; errors up to 1e-6 are reported as "*check*"
for x in range(len(data[idx].y_edge_rdms)):
    # plain ints, so pos() returns plain int coordinates (as in the other cells)
    i, j = (int(v) for v in data[idx].edge_index[:, x])
    two_rdm = compute_rdm(psi_dense, [pos(i), pos(j)], dims)
    error = np.linalg.norm(two_rdm - np.asarray(data[idx].y_edge_rdms[x]))
    if error > 1e-6:
        print(f"2RDM {i, j}\terror: {error:.4f}")
    else:
        print(f"\t2RDM {i, j}\t*check*")

In [None]:
def obtainRDMs(data_point, ground_state):
    """Compute all 1-RDMs (one per node) and 2-RDMs (one per edge) of `ground_state`.

    Parameters:
        data_point: Dataset entry providing ``grid_extent``, ``x_nodes`` (one row per node) and
            ``edge_index`` (shape ``[2, n_edges]``).
        ground_state: Tensor network state with a ``to_dense()`` method, e.g. ``su.best['state']``.

    Returns:
        tuple[list, list]: ``(one_rdms, two_rdms)``, in node order and in ``edge_index`` column
        order respectively.
    """
    # Use the grid of *this* data point rather than the notebook globals `data[idx]` and `n`.
    # A 1D extent is treated as an m x 1 grid.
    extent = tuple(int(d) for d in data_point.grid_extent)
    m, n_cols = (extent[0], 1) if len(extent) == 1 else extent
    dims = [[2] * n_cols for _ in range(m)]

    def site(node):
        # raster order: node index -> (row, column)
        return divmod(int(node), n_cols)

    dense_state = ground_state.to_dense()

    one_rdms = [
        compute_rdm(dense_state, [site(node_idx)], dims)
        for node_idx in range(data_point.x_nodes.shape[0])
    ]

    two_rdms = []
    for edge_idx in range(data_point.edge_index.shape[1]):
        a, b = data_point.edge_index[:, edge_idx]
        two_rdms.append(compute_rdm(dense_state, [site(a), site(b)], dims))

    return one_rdms, two_rdms

one_rdms, two_rdms = obtainRDMs(data[idx], psi)

## MPS with Two-Site DMRG (TeNPy)

In [None]:
# Pickled Python objects need `weights_only=False` with torch >= 2.6 (the default became
# `weights_only=True`). This unpickles arbitrary objects, so only load files you trust.
data = torch.load('../../dataset/ising/data/own_ham_2000_12_True.pt', weights_only=False)
idx = 11
print("grid extent:", data[idx].grid_extent)
print("pbc:", data[idx].pbc)

In [None]:
# data = torch.load('../../dataset/ising/data/own_ham_mpsobc_10000_12_True.pt')
# idx = 60
# print(data[idx].grid_extent)
# print("pbc:", data[idx].pbc)

# print(data[60].grid_extent)
# print(data[3000].grid_extent)
# print(data[4000].grid_extent)
# print(data[6000].grid_extent)
# print(data[8000].grid_extent)

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import tenpy
from tenpy.networks.site import SpinHalfSite
from tenpy.models.tf_ising import TFIChain
from tenpy.models.model import CouplingMPOModel
from tenpy.models.lattice import Chain
from tenpy.networks.mps import MPS
from tenpy.algorithms import dmrg

class Custom(CouplingMPOModel):
    """Ising chain with site-dependent fields and arbitrary ZZ couplings.

    H = -sum_i (h_i Sigmaz_i + g_i Sigmax_i) + sum_(i,j) J_ij Sigmaz_i Sigmaz_j

    Expected model parameters:
      L            -- number of sites of the Chain lattice
      local_fields -- iterable of (i, (h_i, g_i))
      couplings    -- iterable of ((i, j), J_ij)
      bc_MPS       -- MPS boundary condition, 'finite' or 'infinite'
    """
    default_lattice = "Chain"
    force_default_lattice = False

    def init_sites(self, model_param):
        """One spin-1/2 site per lattice site, without charge conservation."""
        return SpinHalfSite(conserve=None)

    def init_lattice(self, model_param):
        """Build a Chain whose lattice boundary matches the MPS boundary."""
        sites = self.init_sites(model_param)
        bc_MPS = model_param["bc_MPS"]
        # bc_MPS is 'finite' or 'infinite' (never 'open'); the previous test
        # `bc_MPS == "open"` therefore always produced a periodic lattice.
        bc = "open" if bc_MPS == "finite" else "periodic"
        lat = Chain(model_param["L"], sites, bc=bc, bc_MPS=bc_MPS)
        self.L = lat.N_sites
        return lat

    def init_terms(self, model_params):
        """Add the on-site fields and every non-zero ZZ coupling."""
        # Read the terms from the model parameters instead of notebook globals,
        # so the model depends only on what it was constructed with.
        for i, (h, g) in model_params["local_fields"]:
            self.add_onsite_term(-h, i, 'Sigmaz')
            self.add_onsite_term(-g, i, 'Sigmax')

        for (i, j), J in model_params["couplings"]:
            # add_coupling_term expects the left operator on the smaller site index
            i, j = min(i, j), max(i, j)
            if J != 0:
                self.add_coupling_term(float(J), int(i), int(j), 'Sigmaz', 'Sigmaz')


print(data[idx].grid_extent)

# Extract data from your dataset
# local_fields: [(site, (h, g))]; couplings: [((a, b), J_ab)] -- the shapes the
# Custom model's init_terms iterates over.
local_fields = [(i, (float(h), float(g))) for i, (h, g) in enumerate(data[idx].x_nodes)]
couplings = [((int(a), int(b)), float(J_ab)) for (a, b), J_ab in zip(data[idx].edge_index.T, data[idx].x_edges)]

# Define model parameters
# Periodic data points are treated as an infinite MPS, open ones as a finite MPS.
model_params = {
    'L': int(data[idx].grid_extent[0]),
    'local_fields': local_fields,
    'couplings': couplings,
    'conserve': None,
    'bc_MPS': 'infinite' if data[idx].pbc else 'finite'
    }

# NOTE(review): `bc` is not used anywhere in this cell.
bc = 'periodic' if data[idx].pbc else 'open'

# Now create the model using the lattice and model_params
model = Custom(model_params)

# Initialize MPS
# Start from the all-"up" product state.
psi0 = MPS.from_product_state(model.lat.mps_sites(), ["up"] * model.lat.N_sites, bc=model_params["bc_MPS"])

# DMRG parameters
dmrg_params = {
    'mixer': True,
    'max_E_err': 1.e-10,
    'trunc_params': {
        'chi_max': 30,
        'svd_min': 1.e-10
    },
    'combine': True,
}

# Run DMRG
eng = dmrg.TwoSiteDMRGEngine(psi0, model, dmrg_params)
E, psi = eng.run()


print(f"Data label ground energy {(float(data[idx].y_energy)):.6f}")
# For the infinite-MPS (pbc) case E is multiplied by the number of sites,
# presumably because TeNPy reports the energy per site there -- confirm.
if data[idx].pbc:
    print(f'Approximated ground state energy: {(E * model.lat.N_sites):.6f}')
else:
    print(f'Approximated ground state energy: {E:.6f}')


### Computing the 1- and 2-RDMs

In [None]:
def find_indices(data, idx, i, j):
    """
    Find the column of data[idx].edge_index that stores the directed edge (i, j).

    Args:
    data: Indexable collection of graph data points, each with an ``edge_index``
        tensor of shape (2, n_edges).
    idx (int): The index of the specific data item.
    i, j (int): Node indices to match in the first and second row of edge_index.

    Returns:
    int: The first column index where edge_index[0] == i and edge_index[1] == j.

    Raises:
    ValueError: If the edge (i, j) does not occur in edge_index.
    """
    edge_index = data[idx].edge_index
    # The old "bounds check" compared node indices with edge_index.shape, i.e.
    # (2, n_edges), and rejected every i >= 2. Matching directly is sufficient:
    # an out-of-range node simply has no match.
    matches = ((edge_index[0] == i) & (edge_index[1] == j)).nonzero(as_tuple=True)[0]
    if matches.numel() == 0:
        raise ValueError(f"Edge ({i}, {j}) not found in edge_index of data point {idx}.")
    return matches[0].item()


def compute_RDM_TeNPy(psi, indices):
    """Return the 1- or 2-site reduced density matrix of a TeNPy MPS as a NumPy array.

    Parameters:
    psi (MPS): TeNPy MPS, e.g. the DMRG ground state.
    indices (list): One site index (2x2 result) or two site indices (4x4 result).

    Returns:
    numpy.ndarray: The RDM after conversion to the label convention (see comments).

    Raises:
    ValueError: If ``indices`` does not contain exactly one or two entries.
    """
    one_rdm = len(indices) == 1
    two_rdm = len(indices) == 2

    if not one_rdm and not two_rdm:
        raise ValueError("The 'indices' argument must be a tuple of one or two integers.")
    
    if one_rdm:
        rdm = psi.get_rho_segment(indices)
        rdm = rdm.to_ndarray()
        # Reversing both axes and negating the off-diagonal entries is exactly
        # sigma^y @ rdm @ sigma^y. Presumably this compensates for the TeNPy
        # model using -h/-g fields where the quimb models use +h/+g (a global
        # sigma^y rotation maps one onto the other) -- confirm.
        rdm = rdm[::-1,::-1]
        off = np.eye(rdm.shape[0]) != 1
        rdm[off] *= -1
    elif two_rdm:
        rdm = psi.get_rho_segment(indices)
        rdm = rdm.to_ndarray()
        # assumes the array legs are (p0, p1, p0*, p1*) so this gives a 4x4 matrix -- TODO confirm
        rdm = rdm.reshape(4, 4)
        # Two-site version of the same map: reverse both axes and negate every
        # entry that is on neither the diagonal nor the anti-diagonal, which is
        # (sigma^y x sigma^y) @ rdm @ (sigma^y x sigma^y).
        rdm = rdm[::-1, ::-1]
        off = np.logical_and(np.eye(4) != 1, np.eye(4)[::-1] != 1)
        rdm[off] *= -1
        # if indices are not adjacent, we need to transform the 2-RDM
        def swap(a, i, j):
            # Swap two scalar entries in place; the right-hand side values are
            # read before either entry is written.
            a[i], a[j] = a[j], a[i]
        if abs(indices[0] - indices[1]) != 1:
            # swap indices to match the label
            # NOTE(review): these four entry swaps are not a qubit permutation;
            # they look tuned to the labels of non-adjacent (presumably
            # wrap-around) edges -- verify against the label generator.
            swap(rdm, (0,2), (1,0))
            swap(rdm, (0,3), (1,1))
            swap(rdm, (2,2), (3,0))
            swap(rdm, (2,3), (3,1))

    return rdm

# Assert that every 1-RDM and every edge 2-RDM of the TeNPy MPS `psi` matches its label.
for i in range(len(data[idx].y_node_rdms)):
    one_rdm = compute_RDM_TeNPy(psi, [i])
    assert np.allclose(one_rdm, data[idx].y_node_rdms[i]), f"1RDM {i} is not close to the label"
print("All 1RDMs are close to the label")

for x in range(len(data[idx].y_edge_rdms)):
    i, j = data[idx].edge_index[:, x]  # endpoints of edge x
    two_rdm = compute_RDM_TeNPy(psi, [i, j])
    error = np.linalg.norm(two_rdm - data[idx].y_edge_rdms[x])  # NOTE(review): computed but never used
    assert np.allclose(two_rdm, data[idx].y_edge_rdms[x]), f"2RDM {int(i), int(j)} is not close to the label"
print("All 2RDMs are close to the label")

### Calculating All 1- and 2-RDMs

In [None]:
def obtainRDMs(data_point, ground_state):
    """Collect every single-site RDM and every edge 2-RDM of a TeNPy MPS.

    This redefines the dense-state ``obtainRDMs`` from earlier in the notebook;
    from here on the MPS-based version (via ``compute_RDM_TeNPy``) is used.

    Returns
    -------
    (list, list)
        One RDM per node, and one 2-RDM per column of ``edge_index``.
    """
    node_count = data_point.x_nodes.shape[0]
    edge_count = data_point.edge_index.shape[1]

    one_rdms = [compute_RDM_TeNPy(ground_state, [site]) for site in range(node_count)]

    two_rdms = []
    for column in range(edge_count):
        src = data_point.edge_index[0, column]
        dst = data_point.edge_index[1, column]
        two_rdms.append(compute_RDM_TeNPy(ground_state, [src, dst]))

    return one_rdms, two_rdms

one_rdms, two_rdms = obtainRDMs(data[idx], psi)

# Pass/fail overview against the labels.
print("1RDMs")
for i in range(len(one_rdms)):
    print(f"{i}: {np.allclose(one_rdms[i], data[idx].y_node_rdms[i])}")
print("2RDMs")
for i in range(len(two_rdms)):
    print(f"{i}: {np.allclose(two_rdms[i], data[idx].y_edge_rdms[i])}")


# Frobenius-norm error of every RDM.
for i in range(len(data[idx].y_node_rdms)):
    error = np.linalg.norm(one_rdms[i] - data[idx].y_node_rdms[i])
    print(f"1RDM {i}\t{error:.4f}")
for x in range(len(data[idx].y_edge_rdms)):
    # Look up the endpoints of *this* edge. Previously `i`/`j` were stale
    # leftovers from earlier loops, so every line printed the wrong edge.
    i, j = data[idx].edge_index[:, x]
    error = np.linalg.norm(two_rdms[x] - data[idx].y_edge_rdms[x])
    if error > 1e-6:
        print(f"2RDM {int(i), int(j)}\terror: {error:.4f}")
    else:
        print(f"\t2RDM {int(i), int(j)}\t*check*")
        

## Putting Everything Together

In [None]:
from abc import ABC, abstractmethod
import random
import platform

import numpy as np
import matplotlib.pyplot as plt
import torch

import quimb.tensor as qtn
import quimb as qu

import tenpy
from tenpy.networks.site import SpinHalfSite
from tenpy.models.tf_ising import TFIChain
from tenpy.models.model import CouplingMPOModel
from tenpy.models.lattice import Chain
from tenpy.networks.mps import MPS
from tenpy.algorithms import dmrg

USE_FUNC = True  # NOTE(review): not referenced anywhere in this part of the notebook; possibly a leftover flag.

class TensorNetworkAlgorithm(ABC):
    """Common interface of the ground-state solvers compared in this notebook.

    Call order: ``set_datapoint`` -> ``create_model`` -> ``run``; results are
    then read with ``getEnergy``, ``getRDMs`` and ``getGroundState``.
    """

    @abstractmethod
    def __init__(self, params):
        """Store the solver hyper-parameters given in the ``params`` dict."""

    @abstractmethod
    def set_datapoint(self, data_point):
        """Select the graph data point (lattice, fields, couplings) to solve."""

    @abstractmethod
    def create_model(self):
        """Build the Hamiltonian / initial state for the selected data point."""

    @abstractmethod
    def run(self):
        """Run the ground-state search."""

    @abstractmethod
    def getEnergy(self):
        """Return the ground-state energy found by ``run``."""

    @abstractmethod
    def getRDMs(self):
        """Return ``(one_rdms, two_rdms)`` of the state found by ``run``."""

    @abstractmethod
    def getGroundState(self):
        """Return the state found by ``run``."""

    @staticmethod
    def isMPS(data_point):
        """True when the data point lives on a 1-D lattice (MPS geometry).

        That is the case when ``grid_extent`` has a single entry or its second
        extent is 1.
        """
        extent = data_point.grid_extent
        if len(extent) == 1:
            return True
        return extent[1] == 1

class SimpleUpdate(TensorNetworkAlgorithm):
    """Ground-state search on a 2-D grid (PEPS) with quimb's simple update.

    Required ``params`` keys:
      chi       -- forwarded as ``chi`` to ``qtn.SimpleUpdate``
      bond_dim  -- bond dimension of the random initial PEPS
      num_iters -- number of evolution steps for each entry of ``tau``
      tau       -- sequence of imaginary-time steps, used in the given order
    """

    def __init__(self, params):
        self.params = params
        self.chi = params['chi']
        self.bond_dim = params['bond_dim']
        self.num_iters = params['num_iters']
        self.tau = params['tau']
        self.n = None            # grid shape (rows, cols); set by set_datapoint
        self.psi = None          # best state found by run()
        self.hamiltonian = None
        self.energy = None       # total (not per-site) energy of self.psi
        self.data_point = None
        self.psi0 = None         # initial state of the evolution
        self.edges = []

    @staticmethod
    def pos(node, n):
        """Convert a row-major flat node index into (row, col) on a grid of shape ``n``."""
        # int() accepts torch/NumPy scalars (e.g. entries of edge_index) and
        # yields plain hashable ints, which quimb expects as site coordinates.
        node, cols = int(node), int(n[1])
        y = node % cols
        x = (node - y) // cols
        return x, y

    def set_datapoint(self, data_point):
        """Select the data point and draw a random initial PEPS of matching size."""
        self.n = tuple(int(extent) for extent in data_point.grid_extent)
        self.data_point = data_point
        self.psi0 = qtn.PEPS.rand(*self.n, bond_dim=self.bond_dim)

    def create_model(self):
        """Build H = sum_i (h_i Z_i + g_i X_i) + sum_(a,b) J_ab Z_a Z_b as a qtn.LocalHam2D."""
        # quimb's spin operators are spin-1/2 matrices (S = sigma / 2); the
        # factors 2 and 4 convert to the Pauli-matrix convention of the labels.
        local_one_site_hamiltonians = {}  # dict for qtn.LocalHam2D H1
        for i, (h, g) in enumerate(self.data_point.x_nodes):
            h, g = float(h), float(g)  # convert from torch to float
            local_one_site_hamiltonians[SimpleUpdate.pos(i, self.n)] = (
                qu.spin_operator('Z') * h * 2 + qu.spin_operator('X') * g * 2
            )

        local_two_site_hamiltonians = {}  # dict for qtn.LocalHam2D H2
        for (a, b), J_ab in zip(self.data_point.edge_index.T, self.data_point.x_edges):
            J_ab = float(J_ab)  # convert from torch to float
            if J_ab != 0:  # zero couplings are left out entirely
                site_a = SimpleUpdate.pos(a, self.n)
                site_b = SimpleUpdate.pos(b, self.n)
                local_two_site_hamiltonians[site_a, site_b] = qu.ham_heis(2, j=(0, 0, 4 * J_ab))

        self.hamiltonian = qtn.LocalHam2D(
            *self.n, H2=local_two_site_hamiltonians, H1=local_one_site_hamiltonians
        )

    def run(self):
        """Evolve once per time step in ``tau`` and keep the lowest-energy state."""
        su = qtn.SimpleUpdate(
            psi0=self.psi0,
            ham=self.hamiltonian,
            chi=self.chi,
            compute_energy_every=None,
            compute_energy_per_site=True,
            keep_best=True,
            progbar=True,
        )
        for tau in self.tau:
            su.evolve(self.num_iters, tau=tau)
        self.psi = su.best['state']
        # energies are tracked per site (compute_energy_per_site=True)
        self.energy = su.best['energy'] * np.prod(self.n)

    def getEnergy(self):
        """Total energy of the best state found by ``run``."""
        return self.energy

    def getRDMs(self):
        """Return (one_rdms, two_rdms) computed from the dense ground-state vector.

        One normalised 2x2 RDM per node and one normalised 4x4 RDM per column
        of ``edge_index``. Needs the full state vector, so it is only feasible
        for small grids.
        """
        m, n = self.n
        dims = [[2] * n for _ in range(m)]  # one qubit per grid site

        def compute_rdm(state, sites, dims):
            """Normalised RDM of the dense vector ``state`` on grid coordinates ``sites``."""
            return qu.normalize(qu.partial_trace(state, dims=dims, keep=sites))

        ground_state = self.psi.to_dense()

        one_rdms = [
            compute_rdm(ground_state, [SimpleUpdate.pos(node_idx, self.n)], dims)
            for node_idx in range(self.data_point.x_nodes.shape[0])
        ]

        two_rdms = []
        for edge_idx in range(self.data_point.edge_index.shape[1]):
            a, b = self.data_point.edge_index[:, edge_idx]
            sites = [SimpleUpdate.pos(a, self.n), SimpleUpdate.pos(b, self.n)]
            two_rdms.append(compute_rdm(ground_state, sites, dims))

        return one_rdms, two_rdms

    def getGroundState(self):
        """Best state found by ``run``."""
        return self.psi
    
class SimpleUpdateGen(SimpleUpdate):
    """Simple update on an arbitrary interaction graph (quimb ``SimpleUpdateGen``).

    Sites are addressed by their flat node index instead of (row, col).
    Additional ``params`` key:
      max_bond -- forwarded to ``partial_trace`` when computing RDMs
    """

    def __init__(self, params):
        super().__init__(params)
        self.max_bond = params['max_bond']

    def set_datapoint(self, data_point):
        """Select the data point and build a random TN on its non-zero couplings."""
        grid = tuple(int(extent) for extent in data_point.grid_extent)
        # a 1-D extent is treated as an (L, 1) grid so that np.prod(self.n) == L
        self.n = grid if len(grid) != 1 else (grid[0], 1)
        self.data_point = data_point
        # only non-zero couplings become bonds of the tensor network
        self.edges = [
            (int(a), int(b))
            for (a, b), J_ab in zip(data_point.edge_index.T, data_point.x_edges)
            if float(J_ab) != 0
        ]
        self.psi0 = qtn.TN_from_edges_rand(self.edges, D=self.bond_dim, phys_dim=2)

    def run(self):
        """Evolve once per time step in ``tau`` and keep the lowest-energy state."""
        su = qtn.SimpleUpdateGen(
            psi0=self.psi0,
            ham=self.hamiltonian,
            compute_energy_every=None,
            compute_energy_per_site=True,
            keep_best=True,
            progbar=True,
        )

        for tau in self.tau:
            su.evolve(self.num_iters, tau=tau)

        self.psi = su.best['state']
        # energies are tracked per site (compute_energy_per_site=True)
        self.energy = su.best['energy'] * np.prod(self.n)

    def create_model(self):
        """Build H = sum_i (h_i Z_i + g_i X_i) + sum_(a,b) J_ab Z_a Z_b as a qtn.LocalHamGen."""
        # same spin-1/2 -> Pauli conversion factors (2 and 4) as SimpleUpdate
        local_one_site_hamiltonians = {}  # dict for qtn.LocalHamGen H1
        for i, (h, g) in enumerate(self.data_point.x_nodes):
            h, g = float(h), float(g)  # convert from torch to float
            local_one_site_hamiltonians[i] = (
                qu.spin_operator('Z') * h * 2 + qu.spin_operator('X') * g * 2
            )

        local_two_site_hamiltonians = {}  # dict for qtn.LocalHamGen H2
        for (a, b), J_ab in zip(self.data_point.edge_index.T, self.data_point.x_edges):
            a, b, J_ab = int(a), int(b), float(J_ab)  # convert from torch to int/float
            if J_ab != 0:
                local_two_site_hamiltonians[a, b] = qu.ham_heis(2, j=(0, 0, 4 * J_ab))

        self.hamiltonian = qtn.LocalHamGen(
            H2=local_two_site_hamiltonians, H1=local_one_site_hamiltonians
        )

    def getRDMs(self):
        """Return (one_rdms, two_rdms) from normalised partial traces of ``self.psi``.

        The previous version called ``rdm.draw()`` on every RDM (leftover
        debugging) and carried an unused coordinate helper; both are removed.
        """
        one_rdms = [
            self.psi.partial_trace([node_idx], self.max_bond, "auto", normalized=True)
            for node_idx in range(self.data_point.x_nodes.shape[0])
        ]

        two_rdms = []
        for edge_idx in range(self.data_point.edge_index.shape[1]):
            a, b = (int(v) for v in self.data_point.edge_index[:, edge_idx])
            two_rdms.append(
                self.psi.partial_trace([a, b], self.max_bond, "auto", normalized=True)
            )

        return one_rdms, two_rdms

class FullUpdate(SimpleUpdate):
    """Full-update PEPS optimisation, warm-started from a simple-update run.

    When running on a non-Windows machine where cupy is importable and sees a
    CUDA device, the full update runs on the GPU in float32. The final state
    is copied back to NumPy so ``getRDMs`` keeps working on host arrays.
    """

    def __init__(self, params):
        super().__init__(params)
        # Previously every non-Windows OS was assumed to have cupy + a GPU,
        # which crashed on CPU-only Linux or macOS. Now cupy must actually work.
        self.is_gpu = platform.system() != 'Windows' and self._cupy_available()

    @staticmethod
    def _cupy_available():
        """Return True if cupy imports and reports at least one CUDA device."""
        try:
            import cupy as cp
            return cp.cuda.runtime.getDeviceCount() > 0
        except Exception:  # ImportError, or CUDA runtime errors without a driver/device
            return False

    def run(self):
        """Run simple update first, then refine its best state with full update."""
        super().run()

        self.psi0 = self.psi  # warm start from the simple-update result

        if self.is_gpu:
            import cupy as cp

            def to_backend(x):
                return cp.asarray(x).astype('float32')

            self.hamiltonian.apply_to_arrays(to_backend)
            self.psi0.apply_to_arrays(to_backend)

        fu = qtn.FullUpdate(
            psi0=self.psi0,
            ham=self.hamiltonian,
            chi=self.chi,
            compute_energy_every=None,
            compute_energy_per_site=True,
            keep_best=True,
            progbar=True,
        )
        for tau in self.tau:
            fu.evolve(self.num_iters, tau=tau)
        self.psi = fu.best['state']
        # energies are tracked per site (compute_energy_per_site=True)
        self.energy = fu.best['energy'] * np.prod(self.n)

        if self.is_gpu:
            # getRDMs builds a dense NumPy state vector, so move the result back to host memory
            self.psi.apply_to_arrays(cp.asnumpy)

class CustomIsingMPOModel(CouplingMPOModel):
    """Ising chain with site-dependent fields and arbitrary ZZ couplings (TeNPy).

    H = -sum_i (h_i Sigmaz_i + g_i Sigmax_i) + sum_(i,j) J_ij Sigmaz_i Sigmaz_j

    Expected model parameters:
      L            -- number of sites of the Chain lattice
      local_fields -- iterable of (i, (h_i, g_i))
      couplings    -- iterable of ((i, j), J_ij)
      bc_MPS       -- MPS boundary condition, 'finite' or 'infinite'
    """
    default_lattice = "Chain"
    force_default_lattice = False

    def init_sites(self, model_param):
        """One spin-1/2 site per lattice site, without charge conservation."""
        return SpinHalfSite(conserve=None)

    def init_lattice(self, model_param):
        """Build a Chain whose lattice boundary matches the MPS boundary."""
        sites = self.init_sites(model_param)
        bc_MPS = model_param["bc_MPS"]
        # bc_MPS is 'finite' or 'infinite' (never 'open'); the previous test
        # `bc_MPS == "open"` therefore always produced a periodic lattice.
        bc = "open" if bc_MPS == "finite" else "periodic"
        lat = Chain(model_param["L"], sites, bc=bc, bc_MPS=bc_MPS)
        self.L = lat.N_sites
        return lat

    def init_terms(self, model_params):
        """Add the on-site fields and every non-zero ZZ coupling."""
        for i, (h, g) in model_params["local_fields"]:
            self.add_onsite_term(-h, i, 'Sigmaz')
            self.add_onsite_term(-g, i, 'Sigmax')

        for (i, j), J in model_params["couplings"]:
            # add_coupling_term expects the left operator on the smaller site index
            i, j = min(i, j), max(i, j)
            if J != 0:
                self.add_coupling_term(float(J), int(i), int(j), 'Sigmaz', 'Sigmaz')

class DMRG(TensorNetworkAlgorithm):
    """Ground-state search for 1-D chains with TeNPy's DMRG.

    Required ``params`` keys: ``max_E_err``, ``chi_max``, ``svd_min`` (forwarded
    to the TeNPy DMRG engine). Optional ``seed`` makes the random initial
    product state reproducible; without it the global ``random`` state is used.

    Periodic data points are run as an infinite MPS (``bc_MPS='infinite'``).
    """

    def __init__(self, params):
        self.params = params
        self.max_E_err = params['max_E_err']
        self.chi_max = params['chi_max']
        self.svd_min = params['svd_min']
        seed = params.get('seed')
        # None keeps the previous behaviour of drawing from the global `random` module
        self._rng = random if seed is None else random.Random(seed)
        self.psi = None
        self.energy = None
        self.data_point = None
        self.psi0 = None
        self.model = None

    def set_datapoint(self, data_point):
        """Select the data point and reset results of any previous run."""
        self.data_point = data_point
        self.L = int(self.data_point.grid_extent[0])
        self.bc_MPS = 'infinite' if self.data_point.pbc else 'finite'
        self.bc = 'periodic' if self.data_point.pbc else 'open'
        self.psi = None
        self.energy = None
        self.psi0 = None
        self.model = None

    def create_model(self):
        """Build the TeNPy model and a random up/down product state as DMRG start."""
        local_fields = [(i, (float(h), float(g))) for i, (h, g) in enumerate(self.data_point.x_nodes)]
        couplings = [((int(a), int(b)), float(J_ab)) for (a, b), J_ab in zip(self.data_point.edge_index.T, self.data_point.x_edges)]

        model_params = {
            'L': self.L,
            'local_fields': local_fields,
            'couplings': couplings,
            'conserve': None,
            'bc_MPS': self.bc_MPS
        }

        self.model = CustomIsingMPOModel(model_params)
        seq = [self._rng.choice(["up", "down"]) for _ in range(self.model.lat.N_sites)]
        self.psi0 = MPS.from_product_state(self.model.lat.mps_sites(), seq, bc=self.bc_MPS)

    def run(self):
        """Run DMRG and store the total ground-state energy and the MPS."""
        dmrg_params = {
            'mixer': True,
            'max_E_err': self.max_E_err,
            'trunc_params': {
                'chi_max': self.chi_max,
                'svd_min': self.svd_min
            },
            'combine': True,
        }

        # two-site DMRG for chains longer than 2 sites, single-site otherwise
        if self.L > 2:
            eng = dmrg.TwoSiteDMRGEngine(self.psi0, self.model, dmrg_params)
        else:
            eng = dmrg.SingleSiteDMRGEngine(self.psi0, self.model, dmrg_params)

        self.energy, self.psi = eng.run()
        if self.bc_MPS == 'infinite':
            # For an infinite MPS TeNPy reports the energy per site; rescale to a
            # total energy, as the stand-alone DMRG cell earlier does for pbc.
            self.energy *= self.model.lat.N_sites

    def getEnergy(self):
        """Ground-state energy found by ``run``."""
        return self.energy

    def getRDMs(self):
        """Return (one_rdms, two_rdms) of the MPS, converted to the label convention."""
        one_rdms = []
        two_rdms = []

        def compute_RDM_TeNPy(psi, indices):
            # Same conversion as the stand-alone compute_RDM_TeNPy: reversing both
            # axes and negating the selected entries equals conjugation by
            # sigma^y (1 site) or sigma^y x sigma^y (2 sites).

            one_rdm = len(indices) == 1
            two_rdm = len(indices) == 2

            if not one_rdm and not two_rdm:
                raise ValueError("The 'indices' argument must be a tuple of one or two integers.")

            if one_rdm:
                rdm = psi.get_rho_segment(indices)
                rdm = rdm.to_ndarray()
                rdm = rdm[::-1,::-1]
                off = np.eye(rdm.shape[0]) != 1
                rdm[off] *= -1
            elif two_rdm:
                rdm = psi.get_rho_segment(indices)
                rdm = rdm.to_ndarray()
                rdm = rdm.reshape(4, 4)
                rdm = rdm[::-1, ::-1]
                off = np.logical_and(np.eye(4) != 1, np.eye(4)[::-1] != 1)
                rdm[off] *= -1
                # if indices are not adjacent, we need to transform the 2-RDM
                def swap(a, i, j):
                    a[i], a[j] = a[j], a[i]
                if abs(indices[0] - indices[1]) != 1:
                    # swap indices to match the label
                    swap(rdm, (0,2), (1,0))
                    swap(rdm, (0,3), (1,1))
                    swap(rdm, (2,2), (3,0))
                    swap(rdm, (2,3), (3,1))

            return rdm

        for node_idx in range(self.data_point.x_nodes.shape[0]):
            one_rdms.append(compute_RDM_TeNPy(self.psi, [node_idx]))

        for edge_idx in range(self.data_point.edge_index.shape[1]):
            edge = self.data_point.edge_index[:, edge_idx]
            # plain ints rather than 0-d torch tensors for TeNPy
            two_rdms.append(compute_RDM_TeNPy(self.psi, [int(edge[0]), int(edge[1])]))

        return one_rdms, two_rdms

    def getGroundState(self):
        """MPS found by ``run``."""
        return self.psi
    
class DMRG_QUIMB(TensorNetworkAlgorithm):
    """Open-boundary 1-D ground-state search with quimb's two-site DMRG.

    Optional ``params`` keys (defaults): ``max_bond`` (30), ``cutoff`` (1e-10),
    ``tol`` (1e-6), ``verbosity`` (0).
    """

    def __init__(self, params):
        self.params = params
        self.max_bond = params.get('max_bond', 30)  # Maximum bond dimension
        self.cutoff = params.get('cutoff', 1e-10)  # Truncation cutoff for SVD
        self.tolerance = params.get('tol', 1e-6)  # Tolerance for convergence
        self.verbosity = params.get('verbosity', 0)  # Verbosity level
        self.psi = None  # Ground state MPS
        self.energy = None
        self.data_point = None
        self.pbc = False
        self.L = None  # number of sites; set by set_datapoint
        self.H_mpo = None  # Hamiltonian MPO; set by create_model

    def set_datapoint(self, data_point):
        """Select a 1-D, open-boundary data point.

        Raises ValueError for non-MPS geometries; asserts that pbc is False.
        """
        if not TensorNetworkAlgorithm.isMPS(data_point):
            raise ValueError("DMRG_QUIMB only supports MPS formatted data points.")
        self.data_point = data_point
        self.L = int(data_point.grid_extent[0])  # chain length as a plain int
        self.pbc = data_point.pbc  # Periodic boundary conditions flag
        assert not self.pbc, "Periodic boundary conditions are not supported."

    def create_model(self):
        """Build the Hamiltonian MPO H = sum_i (h_i Z_i + g_i X_i) + sum J_ab Z_a Z_b."""
        assert self.data_point is not None
        # Spin-1/2 builder; the factors 2 and 4 below presumably convert spin
        # operators (sigma / 2) to Pauli matrices, as in SimpleUpdate -- confirm.
        ham_builder = qtn.SpinHam1D(S=1/2)

        for i, (h, g) in enumerate(self.data_point.x_nodes):
            h, g = float(h), float(g)
            ham_builder[i] += 2*h, 'Z'
            ham_builder[i] += 2*g, 'X'

        for (a, b), J_ab in zip(self.data_point.edge_index.T, self.data_point.x_edges):
            a, b, J_ab = int(a), int(b), float(J_ab)
            # Zero couplings contribute nothing. The previous condition
            # (`J_ab != 0 or ...`) also referenced an undefined global `n` in
            # its PBC branch; PBC is rejected in set_datapoint anyway.
            if J_ab == 0:
                continue
            # NOTE(review): the term is registered under both (a, b) and (b, a);
            # confirm SpinHam1D does not count it twice.
            ham_builder[a, b] += 4*J_ab, 'Z', 'Z'
            ham_builder[b, a] += 4*J_ab, 'Z', 'Z'

        self.H_mpo = ham_builder.build_mpo(self.L)

    def run(self):
        """Run two-site DMRG on the MPO and store the state and energy."""
        dmrg = qtn.DMRG2(self.H_mpo)
        dmrg.opts['max_bond'] = self.max_bond
        dmrg.opts['cutoff'] = self.cutoff
        dmrg.solve(tol=self.tolerance, verbosity=self.verbosity)
        self.psi = dmrg.state
        self.energy = dmrg.energy

    def getEnergy(self):
        """Real part of the DMRG energy, or None before ``run``."""
        # `is not None`: an energy of exactly 0.0 used to be reported as None.
        return self.energy.real if self.energy is not None else None

    def getRDMs2(self):
        """Alternative RDM extraction via MPS partial traces (no dense state).

        NOTE(review): not used by the evaluation loop; kept for comparison.
        """
        one_rdms = []
        two_rdms = []

        for node_idx in range(self.data_point.x_nodes.shape[0]):
            rdm = self.psi.partial_trace(keep=[node_idx])
            assert len(rdm.tensors) == 1
            one_rdms.append(qu.normalize(rdm.tensors[0].data))

        for edge_idx in range(self.data_point.edge_index.shape[1]):
            a, b = (int(v) for v in self.data_point.edge_index[:, edge_idx])
            two_rdm = self.psi.partial_trace(keep=[a, b])
            # contract every index shared by more than one tensor
            keys = list(two_rdm.ind_map.keys())
            for ind in keys:
                if len(two_rdm.ind_map[ind]) > 1:
                    two_rdm.contract_ind(ind)
            assert len(two_rdm.tensors) == 1
            dim = int(np.sqrt(np.prod(two_rdm.tensors[0].data.shape)))
            two_rdms.append(qu.normalize(two_rdm.tensors[0].data.reshape(dim, dim)))

        return one_rdms, two_rdms

    def getRDMs(self):
        """Return (one_rdms, two_rdms) computed from the dense ground-state vector."""
        dims = [2] * self.L  # one qubit per chain site

        def compute_rdm(state, sites, dims):
            """Normalised RDM of the dense vector ``state`` on site indices ``sites``."""
            return qu.normalize(qu.partial_trace(state, dims=dims, keep=sites))

        ground_state = self.psi.to_dense()

        one_rdms = [
            compute_rdm(ground_state, [node_idx], dims)
            for node_idx in range(self.data_point.x_nodes.shape[0])
        ]

        two_rdms = []
        for edge_idx in range(self.data_point.edge_index.shape[1]):
            a, b = (int(v) for v in self.data_point.edge_index[:, edge_idx])
            two_rdms.append(compute_rdm(ground_state, [a, b], dims))

        return one_rdms, two_rdms

    def getGroundState(self):
        """MPS found by ``run``."""
        return self.psi

### Test on Samples from the Dataset

In [None]:
import warnings

# Forward slashes work on Windows as well as Linux/macOS; the previous
# '..\\..\\' paths only resolved on Windows.
# NOTE(review): torch.load is pickle-based -- only load trusted dataset files.
#data = torch.load('../../dataset/ising/data/PEPS_4400_N200_16.pt')
data = torch.load('../../dataset/ising/data/MPS_6000_N200_16.pt')

# Simple update (PEPS data points)
params_SU = {
    'chi': 15,
    'bond_dim': 4,
    'num_iters': 100,
    'tau': [0.1, 0.01, 0.001]
}

# TeNPy DMRG
params_DMRG = {
    'max_E_err': 1.e-10,
    'chi_max': 30,
    'svd_min': 1.e-10
}

# quimb DMRG (MPS data points)
params_DMRG_QUIMB = {
    'max_bond': 60,
    'cutoff': 1.e-10,
    'tol': 1.e-6,
    'verbosity': 0
}

# General-graph simple update
params_SU_gen = {
    'chi': 30,
    'max_bond': 30,
    'bond_dim': 2,
    'num_iters': 100,
    'tau': [0.1, 0.01, 0.001]
}

def print_and_save(output):
    """Echo ``output`` to stdout and append it as one line to results_aux.txt."""
    log_path = "results_aux.txt"
    print(output)
    with open(log_path, "a") as log:
        log.write(output + '\n')

def matrix_sqrt(A):
    """Principal square root of a Hermitian positive semi-definite matrix.

    Eigenvalues below zero (numerical noise) are clamped to 0. Works for real
    and complex inputs, including batches; the result has the dtype of the
    eigenvectors returned by ``torch.linalg.eigh``, i.e. the input dtype.
    """
    vals, vecs = torch.linalg.eigh(A)
    sqrt_vals = torch.sqrt(torch.clamp(vals, min=0))  # Ensure eigenvalues are non-negative
    # Match the eigenvector dtype: the old hard-coded complex128 cast made the
    # matmul fail for float64 or complex64 inputs (identical for complex128).
    sqrt_diag = torch.diag_embed(sqrt_vals).to(dtype=vecs.dtype)
    return vecs @ sqrt_diag @ vecs.conj().transpose(-2, -1)

def fidelity_torch(rho, sigma):
    """Uhlmann fidelity F(rho, sigma) = (Tr sqrt(sqrt(rho) sigma sqrt(rho)))^2.

    ``rho`` and ``sigma`` are density matrices as torch tensors. The result is
    the square of the real part of the trace.
    """
    root_rho = matrix_sqrt(rho)
    inner_root = matrix_sqrt(root_rho @ sigma @ root_rho)
    return torch.trace(inner_root).real ** 2


def isMPS(data_point):
    """True if ``data_point`` lives on a 1-D lattice (single extent, or second extent 1)."""
    extent = data_point.grid_extent
    if len(extent) == 1:
        return True
    return extent[1] == 1

def checkEnergy(energy, label):
    """True if ``energy`` matches ``label`` to within an absolute tolerance of 1e-2."""
    absolute_tolerance = 1e-2  # np.allclose's default rtol (1e-5) applies as well
    return np.allclose(energy, label, atol=absolute_tolerance)

def checkRDMs(rdms, labels, tol=0.1):
    """Compare predicted RDMs with their labels using the Uhlmann fidelity.

    Parameters
    ----------
    rdms : sequence of array-like
        Predicted density matrices. ``.copy()`` is taken first; some TeNPy
        RDMs are ``[::-1, ::-1]`` (negative-stride) views.
    labels : sequence of array-like
        Reference density matrices, in the same order as ``rdms``.
    tol : float, optional
        A prediction counts as correct when its fidelity is within ``tol`` of 1.

    Returns
    -------
    (bool, list)
        Whether every fidelity is within ``tol`` of 1, and the list of fidelities.
    """
    fidelities = [
        fidelity_torch(
            torch.tensor(rdms[i].copy(), dtype=torch.cdouble),
            torch.tensor(labels[i], dtype=torch.cdouble),
        ).numpy()
        for i in range(len(rdms))
    ]
    # Fidelity is 1 for identical states. The old check `np.allclose(errors, 0.1)`
    # only passed when every fidelity was ~0.1, i.e. for badly wrong RDMs.
    return np.allclose(fidelities, 1.0, rtol=0.0, atol=tol), fidelities

#Ignore UserWarnings
with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    for idx in [200,600]:
        data_point = data[idx]
        print(data_point.grid_extent)
        if isMPS(data_point):
            print_and_save(f"{idx}, MPS: {data_point.grid_extent}, PBC: {data_point.pbc}")
            # alg = DMRG(params_DMRG)
            alg = DMRG_QUIMB(params_DMRG_QUIMB)
        else:
            print_and_save(f"{idx}, PEPS: {data_point.grid_extent}, PBC: {data_point.pbc}")
            alg = SimpleUpdate(params_SU)
        # alg = SimpleUpdateGen(params_SU_gen)
        alg.set_datapoint(data_point)
        alg.create_model()
        alg.run()
        correct_energy = checkEnergy(alg.getEnergy(), data_point.y_energy)

        print("Predicted energy:", alg.getEnergy())
        print("Label energy:", data_point.y_energy)
        print("PBC: ", data_point.pbc)


        one_rdms, two_rdms = alg.getRDMs()
        correct_one_rdms , one_rdm_errors = checkRDMs(one_rdms, data_point.y_node_rdms)
        correct_two_rdms , two_rdm_errors = checkRDMs(two_rdms, data_point.y_edge_rdms)

        #print total distance between calculated and label RDMs
        # print_and_save(f"Total distance between calculated and label 1RDMs: {np.array(one_rdm_errors)}")
        # print_and_save(f"Total distance between calculated and label 2RDMs: {np.array(two_rdm_errors)}")
        #Average distance between calculated and label RDMs
        print_and_save(f"Average distance between calculated and label 1RDMs: {np.mean(one_rdm_errors)}")
        print_and_save(f"Average distance between calculated and label 2RDMs: {np.mean(two_rdm_errors)}")
        

        print_and_save(f'Correct Energy: {correct_energy}, Correct 1RDMs: {correct_one_rdms}, Correct 2RDMs: {correct_two_rdms}')

        # if not correct_energy or not correct_one_rdms or not correct_two_rdms:
        #     if not correct_energy:
        #         print_and_save(f"Energy Error: {np.linalg.norm(alg.getEnergy() - data_point.y_energy):.6f}")

        #     if not correct_one_rdms:
        #         pr = False
        #         one_rdm_errors = []
        #         for i in range(len(one_rdms)):
        #             error = np.linalg.norm(one_rdms[i] - data_point.y_node_rdms[i])
        #             if error > 0.05:
        #                 pr = True
        #                 one_rdm_errors.append(f"[({i}):{error:.4f}]")
        #             print(f"1RDM calc{i}\t{one_rdms[i]}")
        #             print(f"1RDM labl{i}\t{data_point.y_node_rdms[i]}")
        #         if pr:
        #             print_and_save("One RDM Errors:")
        #             print_and_save(",".join(one_rdm_errors))

        #     if not correct_two_rdms:
        #         pr = False
        #         two_rdm_errors = []
        #         for x in range(len(two_rdms)):
        #             i, j = data_point.edge_index[:, x]
        #             error = np.linalg.norm(two_rdms[x] - data_point.y_edge_rdms[x])
        #             if error > 0.05:
        #                 pr = True
        #                 two_rdm_errors.append(f"[{int(i), int(j)}:{error:.4f}]")
        #             print(f"2RDM calc{i,j}\t{two_rdms[x]}")
        #             print(f"2RDM labl{i,j}\t{data_point.y_edge_rdms[x]}")
        #         if pr:
        #             print_and_save("Two RDM Errors:")
        #             print_and_save(",".join(two_rdm_errors))