In [None]:
import numpy as np
from utils import *

import networkx as nx
from factor import Factor
from typing import List

Okay, second try. This time I'm actually implementing belief propagation. Four steps now.

1. Belief prop for my Factor class
2. Belief prop for tensor networks
3. (This is in `tn_simple_update.ipynb`!) Use the Simple- or Full-Update algorithm from `quimb` to do the belief prop (which is probably what we will use in the end)
4. (Here again) Implement belief prop as shown in the paper `Tensor networks contraction and the belief propagation algorithm` by Alkabetz and Arad

In [None]:
# Here is a network to work with

def construct_graph_from_edges(edges):
    """
    Build a `networkx.MultiGraph` from an edge list.

    Each entry of `edges` has the form (node1, node2, size). Every edge is given a
    one-letter variable label (starting at 'i', in list order) and its size. Afterwards
    each node receives a `factor` attribute holding a `Factor` over the variables of its
    incident edges, filled with a random tensor of the matching shape.
    """

    def get_edge_data(graph, node, sort_by=0):
        """ Collect the attribute values of all edges touching `node`; returns the de-duplicated rows, transposed. """
        rows = [
            list(attrs.values())
            for nbr in graph.neighbors(node)
            for _key, attrs in graph.get_edge_data(node, nbr).items()
        ]
        return np.unique(rows, axis=sort_by).T

    graph = nx.MultiGraph()
    for offset, edge in enumerate(edges):
        # consecutive letters 'i', 'j', 'k', ... name the variable carried by each edge
        graph.add_edge(*edge[:2], var=chr(ord('i') + offset), size=edge[2])

    # snapshot the node list because the node attributes are rewritten while looping
    for node in list(graph.nodes):
        var_names, dims = get_edge_data(graph, node)
        # np.unique yields a string array, so the sizes have to be cast back to int
        shape = list(map(int, dims))
        data = normalize(np.random.rand(*shape), p=1, axis=0)
        graph.add_node(node, factor=Factor(list(var_names), data))

    return graph

def draw_graph(graph):
    """ Draw `graph` with a spring layout, labelling every edge with the variable(s) it carries. """

    def get_edge_labels(graph):
        """ Map each (i, j) node pair to a comma-separated string of its edge variables. """
        labels = {}
        for u, v, attrs in graph.edges(data=True):
            pair = (u, v)
            if pair in labels:
                labels[pair] = f"{labels[pair]}, {attrs['var']}"
            else:
                labels[pair] = str(attrs['var'])
        # self-loop labels get a prefix of underscores
        for pair in labels:
            if pair[0] == pair[1]:
                labels[pair] = "______" + labels[pair]
        return labels

    pos = nx.spring_layout(graph)
    # the line width grows with the number of parallel edges between the two nodes
    widths = [3**(graph.number_of_edges(u, v) - 1) for u, v in graph.edges()]
    nx.draw(graph, pos, with_labels=True, width=widths)
    nx.draw_networkx_edge_labels(graph, pos, edge_labels=get_edge_labels(graph))

# Each edge is (node1, node2, size) and corresponds to one variable / index
# (networkx doesn't support hyperedges).
edges = [
    ('a', 'a', 2),
    ('a', 'c', 2),
    ('b', 'c', 3),
    ('c', 'd', 2),
    ('c', 'd', 3),
    ('d', 'd', 2),
]
net = construct_graph_from_edges(edges)
draw_graph(net)
print("Network:", {name: attrs['factor'].variables for name, attrs in net.nodes(data=True)})

None

In [None]:
def construct_graph_from_factors(factors: List[Factor]):
    """
    Constructs a graph from a list of `Factor` objects.

    Every factor becomes one node, labelled with its concatenated variable names and
    carrying the factor in the node attribute `factor`. For every variable shared by two
    factors, one edge (attributes `var` and `size`) connects the two nodes; two factors
    sharing several variables are connected by one edge per shared variable. A variable
    that appears in no other factor is attached as a self-loop.

    Parameters
        factors (List[Factor]): The factors of the network.

    Returns
        nx.MultiGraph: The graph described above.

    Raises
        Exception: If no node holding a given factor can be found.
    """
    graph = nx.MultiGraph()
    for factor in factors:
        graph.add_node(''.join(factor.variables), factor=factor)

    def find_node(target):
        """ Return the label of the first node whose factor equals `target`. """
        for n, d in graph.nodes(data=True):
            if d['factor'] == target:
                return n
        raise Exception("Could not find node for factor:", target)

    for factor in factors:
        node = find_node(factor)
        for v in factor.variables:
            size = factor.data.shape[factor.variables.index(v)]
            shared = False
            for other_factor in factors:
                if v in other_factor.variables and other_factor != factor:
                    shared = True
                    other_node = find_node(other_factor)
                    # The pair is visited from both sides, so only add the edge for `v` if
                    # it is not there yet. Edges for *other* variables between the same
                    # two nodes must not block this one.
                    existing = graph.get_edge_data(node, other_node, default={})
                    if all(attrs.get('var') != v for attrs in existing.values()):
                        graph.add_edge(node, other_node, var=v, size=size)
            # If there is no other factor containing v, add a self-loop
            if not shared:
                graph.add_edge(node, node, var=v, size=size)

    return graph

# Factors of the tree in Barber Figure 14.1 and 14.2
barber_factors = [
    Factor('A', np.array([0.01, 0.99])),
    Factor('AB', np.array([[0.1, 0.9], [0.001, 0.999]])),
    Factor('C', np.array([0.001, 0.999])),
    Factor('BCD', np.array([[[0.99, 0.01], [0.9, 0.1]], [[0.95, 0.05], [0.01, 0.99]]])),
    Factor('DE', np.array([[0.9, 0.1], [0.3, 0.7]])),
    Factor('DF', np.array([[0.2, 0.8], [0.1, 0.9]])),
]
net_barber = construct_graph_from_factors(barber_factors)
# 'DE' and 'DF' both contain D and were therefore linked; drop that edge again
net_barber.remove_edge('DE', 'DF')
draw_graph(net_barber)
print("Network:", {name: attrs['factor'].variables for name, attrs in net_barber.nodes(data=True)})

In [None]:
# Here is an attempt to add tensors to both nodes and edges
# tree = nx.Graph()

# nodes = [(0, 2), (1, 2), (2, 2), (3, 2), (4, 2), (5, 2)]
# edges = [(0, 1), (0, 2), (1, 3), (1, 4), (2, 5)]
# # add a size attribute to each node
# for node, size in nodes:
#     tree.add_node(node, size=size, f=np.random.rand(size))
# # add a tensor to each edge as an attribute with the size of the respective nodes
# for i,j in edges:
#     f = np.random.rand(tree.nodes[i]['size'], tree.nodes[j]['size'])
#     # tensor = normalize(tensor, 1, axis=1)
#     tree.add_edge(i, j, f=f)

# 1. Belief propagation

In [None]:
def belief_prop(graph, query, max_iterations=100, eps=1e-6):
    """
        Computes the marginal distribution of the given query variables using the belief propagation algorithm.

        Parameters
            graph (nx.Graph): The graph for which the marginals should be computed.
            Each node should have the attribute 'factor' containing a Factor object, and each edge should have
            an attribute 'var' containing the name of the variable present in the factors connected by the edge.

            query (list[str]): A list of variables for which the marginals should be computed

            max_iterations (int): Upper bound on the number of message-passing sweeps.

            eps (float): The sweeps stop once the summed absolute change of all messages drops below this value.

        Returns
            The normalized marginal over the query variables, i.e. the result of `Factor.normalize()`
            (presumably a Factor; callers in this notebook read its `.data`).

        NOTE(review): the final marginal below is built from the node factors only; the
        `messages` computed in the loop are not used for it -- confirm this is intended.
    """

    # Big table of messages
    messages = {}

    # Initialize messages
    for node in graph.nodes:
        f = graph.nodes[node]['factor']
        for neighbor in graph.neighbors(node):
            # incoming messages to the node
            # (initialized to all ones, with the shape and variables of the receiving factor)
            messages[(neighbor, node)] = Factor(
                potentials=np.ones(f.data.shape),
                variables=f.variables.copy()
            )

    # Iterate until convergence
    for i in range(max_iterations):
        # print("Iteration", i)
        # shallow copy is enough: entries are replaced by new objects below, not modified
        old_messages = messages.copy()
        # Compute messages
        for node in graph.nodes:
            for neighbor in graph.neighbors(node):
                # print(f"Computing message {node} -> {neighbor}")
                # Aggregate messages from other neighbors
                aggregate = Factor()
                for other_neighbor in graph.neighbors(node):
                    if other_neighbor != neighbor:
                        aggregate = aggregate * messages[(other_neighbor, node)]
                # Multiply with the marginal of the node
                total = graph.nodes[node]['factor'] * aggregate
                # Sum out all variables that are not in the neighbor
                for v in total.variables:
                    if v not in graph.nodes[neighbor]['factor'].variables:
                        total = total.marginalize(v)
                # print("total:",total.variables, "neighbor:", graph.nodes[neighbor]['factor'].variables)
                # Normalize the message
                message = total.normalize()
                # print(f"New message {node} -> {neighbor}:", message)
                # Update the message
                # (sanity check only: report if the message's variables changed between sweeps)
                if i > 0 and (message.variables != old_messages[(node, neighbor)].variables):
                    print(f"Message variables {message.variables} do not match old message variables {old_messages[(node, neighbor)].variables}")
                messages[(node, neighbor)] = message
        # Check for convergence
        # (skipped after the first sweep, where the previous messages are still the all-ones initialization)
        if i == 0:
            continue
        error = 0
        for key in messages:
            error += np.sum(np.abs(messages[key].data - old_messages[key].data))
        # print("Error:", error)
        if error < eps:
            print("Converged after", i, "iterations")
            break

    # Compute marginals for the query variables
    # (product of all factors that contain at least one query variable)
    marginal = Factor()
    for node in graph.nodes:
        if any([v in graph.nodes[node]['factor'].variables for v in query]):
            marginal = marginal * graph.nodes[node]['factor']
    # sum out everything that was not asked for
    for v in marginal.variables:
        if v not in query:
            marginal = marginal.marginalize(v)

    # bring the axes into the order given by `query`
    marginal.transpose(query, inplace=True)
    return marginal.normalize()

# Try both example networks; the resulting tuple is displayed as the cell output
belief_prop(net, 'ij'), belief_prop(net_barber, 'C')

# 2. Belief propagation for tensor networks

In [None]:
import quimb as qu
import quimb.tensor as qtn

In [None]:
def tn_mul(t1, t2):
    """
    Element-wise product of two tensors over the union of their indices.

    Works on copies of `t1` and `t2`: each copy first receives the indices it is missing
    from the other one, the second copy is transposed into the index order of the first,
    and the two are multiplied element-wise.
    """
    def add_indices(target, like):
        """ Append to `target` every index of `like` that it does not have yet. """
        for ind in like.inds[::-1]:
            if ind in target.inds:
                continue
            # NOTE(review): size=0 is passed as a placeholder; the product below presumably
            # relies on broadcasting over the new axis -- confirm against quimb's `Tensor.new_ind`.
            target.new_ind(ind, size=0, axis=-1)

    left, right = t1.copy(), t2.copy()
    # give both operands the same set of indices
    add_indices(target=left, like=right)
    add_indices(target=right, like=left)
    # align the axis order of the right operand with the left one
    right.transpose(*left.inds, inplace=True)
    return left * right

# Equivalent to `belief_prop` above, but using `quimb.tensor.TensorNetwork` instead of `networkx.Graph`
def belief_prop_tn(tn, query, max_iterations=100, eps=1e-6):
    """
    Computes the marginals for given indices in `query`.

    Parameters
        tn (qtn.TensorNetwork): The network; messages are keyed by the first tag of each tensor.
        query (iterable of str): The indices for which the marginal should be computed.
        max_iterations (int): Upper bound on the number of message-passing sweeps.
        eps (float): The sweeps stop once the summed absolute change of all messages drops below this value.

    Returns
        The marginal over `query`, divided by the sum of its entries.

    NOTE(review): the final marginal is built from the tensors only; the `messages`
    computed in the loop are not used for it -- confirm this is intended.
    """
    def first_tag(tensor):
        """ The first tag of a tensor identifies it in the message table. """
        return list(tensor.tags)[0]

    # Initialize messages
    messages = {}
    for n in tn.tensors:
        # all-ones tensors with the shape and indices of n, one per neighbor of n
        for m in tn.select_neighbors(n.tags):
            messages[(first_tag(n), first_tag(m))] = qtn.Tensor(
                data=np.ones(n.data.shape),
                inds=n.inds,
                tags=n.tags
            )

    # Iterate until convergence
    for i in range(max_iterations):
        # shallow copy is enough: entries are replaced by new tensors below, not modified
        old_messages = messages.copy()
        # Compute messages
        for n in tn.tensors:
            for m in tn.select_neighbors(n.tags):
                key = (first_tag(n), first_tag(m))
                # Aggregate messages from other neighbors
                aggregate = qtn.Tensor(data=1, inds=[], tags=[])
                for other_neighbor in tn.select_neighbors(n.tags):
                    if other_neighbor != m:
                        aggregate = tn_mul(aggregate, messages[(first_tag(other_neighbor), first_tag(n))])
                # Multiply with the marginal of the node
                total = tn_mul(n, aggregate)
                # Sum out all variables that are not in the neighbor
                for v in total.inds:
                    if v not in m.inds:
                        total = total.sum_reduce(v)
                # Normalize the message
                message = total / total.data.sum()
                # Update the message
                # (the warning looks the old message up under the same key as the check;
                # it used the raw tag sets before, which are not keys of the table)
                if i > 0 and (message.inds != old_messages[key].inds):
                    print(f"Message indices {message.inds} do not match old message indices {old_messages[key].inds}")
                messages[key] = message
        # Check for convergence (skipped after the first sweep)
        if i == 0:
            continue
        error = 0
        for key in messages:
            error += np.sum(np.abs(messages[key].data - old_messages[key].data))
        if error < eps:
            print("Converged after", i, "iterations")
            break

    # Compute marginals for the query variables
    marginal = qtn.Tensor(data=1, inds=[], tags=[])
    for n in tn.tensors:
        if any([v in n.inds for v in query]):
            marginal = tn_mul(marginal, n)
    # sum out everything that was not asked for
    for v in marginal.inds:
        if v not in query:
            marginal = marginal.sum_reduce(v)

    # bring the axes into the order given by `query`
    marginal.transpose(*query, inplace=True)
    return marginal / marginal.data.sum()

def get_tn(graph):
    """ Convert a factor graph into a quimb tensor network: one tensor per node, tagged with the node label. """
    network = qtn.TensorNetwork()
    for label, attrs in graph.nodes(data=True):
        factor = attrs['factor']
        # the factor's variables become the tensor's index names
        network |= qtn.Tensor(factor.data, inds=factor.variables, tags=label)
    return network

In [None]:
# Define a tensor network from the graph `net`
T = get_tn(net)
# T.draw()
query = 'ij'

# Sanity check: the graph-based and the tensor-network-based implementation must agree
assert np.allclose(belief_prop(net, query).data, belief_prop_tn(T, query).data)

# Displayed as the cell output
belief_prop_tn(T, query)

In [None]:
# Same consistency check for the Barber network
T_barber = get_tn(net_barber)
query = 'C'
assert np.allclose(belief_prop(net_barber, query).data, belief_prop_tn(T_barber, query).data)
# Displayed as the cell output
belief_prop_tn(T_barber, query)

In [None]:
# Scratch cell: two random tensors sharing the indices 'b' and 'c', combined into a network
t1 = qtn.Tensor(np.random.rand(2, 3, 4), inds='abc', tags='t1')
t2 = qtn.Tensor(np.random.rand(3, 4, 5), inds='bcd', tags='t2')
tn = t1 | t2
# This is the expression used above to turn a tensor's tags into a message-table key
list(t1.tags)[0]

# 4. Belief propagation for tensor networks 2

In [None]:
import quimb as qu
import quimb.tensor as qtn

def tn_mul(t1, t2):
    """
    Element-wise product of two tensors over the union of their indices.

    Works on copies of `t1` and `t2`: each copy first receives the indices it is missing
    from the other one, the second copy is transposed into the index order of the first,
    and the two are multiplied element-wise.
    """
    def add_indices(target, like):
        """ Append to `target` every index of `like` that it does not have yet. """
        for ind in like.inds[::-1]:
            if ind in target.inds:
                continue
            # NOTE(review): size=0 is passed as a placeholder; the product below presumably
            # relies on broadcasting over the new axis -- confirm against quimb's `Tensor.new_ind`.
            target.new_ind(ind, size=0, axis=-1)

    left, right = t1.copy(), t2.copy()
    # give both operands the same set of indices
    add_indices(target=left, like=right)
    add_indices(target=right, like=left)
    # align the axis order of the right operand with the left one
    right.transpose(*left.inds, inplace=True)
    return left * right

# Equivalent to `belief_prop` above, but using `quimb.tensor.TensorNetwork` instead of `networkx.Graph`
def belief_prop_tn(tn, query, max_iterations=100, eps=1e-6):
    """
    Computes the marginals for given indices in `query`.

    Parameters
        tn (qtn.TensorNetwork): The network; messages are keyed by the first tag of each tensor.
        query (iterable of str): The indices for which the marginal should be computed.
        max_iterations (int): Upper bound on the number of message-passing sweeps.
        eps (float): The sweeps stop once the summed absolute change of all messages drops below this value.

    Returns
        The marginal over `query`, divided by the sum of its entries.

    NOTE(review): the final marginal is built from the tensors only; the `messages`
    computed in the loop are not used for it -- confirm this is intended.
    """
    def first_tag(tensor):
        """ The first tag of a tensor identifies it in the message table. """
        return list(tensor.tags)[0]

    # Initialize messages
    messages = {}
    for n in tn.tensors:
        # all-ones tensors with the shape and indices of n, one per neighbor of n
        for m in tn.select_neighbors(n.tags):
            messages[(first_tag(n), first_tag(m))] = qtn.Tensor(
                data=np.ones(n.data.shape),
                inds=n.inds,
                tags=n.tags
            )

    # Iterate until convergence
    for i in range(max_iterations):
        # shallow copy is enough: entries are replaced by new tensors below, not modified
        old_messages = messages.copy()
        # Compute messages
        for n in tn.tensors:
            for m in tn.select_neighbors(n.tags):
                key = (first_tag(n), first_tag(m))
                # Aggregate messages from other neighbors
                aggregate = qtn.Tensor(data=1, inds=[], tags=[])
                for other_neighbor in tn.select_neighbors(n.tags):
                    if other_neighbor != m:
                        aggregate = tn_mul(aggregate, messages[(first_tag(other_neighbor), first_tag(n))])
                # Multiply with the marginal of the node
                total = tn_mul(n, aggregate)
                # Sum out all variables that are not in the neighbor
                for v in total.inds:
                    if v not in m.inds:
                        total = total.sum_reduce(v)
                # Normalize the message
                message = total / total.data.sum()
                # Update the message
                # (the warning looks the old message up under the same key as the check;
                # it used the raw tag sets before, which are not keys of the table)
                if i > 0 and (message.inds != old_messages[key].inds):
                    print(f"Message indices {message.inds} do not match old message indices {old_messages[key].inds}")
                messages[key] = message
        # Check for convergence (skipped after the first sweep)
        if i == 0:
            continue
        error = 0
        for key in messages:
            error += np.sum(np.abs(messages[key].data - old_messages[key].data))
        if error < eps:
            print("Converged after", i, "iterations")
            break

    # Compute marginals for the query variables
    marginal = qtn.Tensor(data=1, inds=[], tags=[])
    for n in tn.tensors:
        if any([v in n.inds for v in query]):
            marginal = tn_mul(marginal, n)
    # sum out everything that was not asked for
    for v in marginal.inds:
        if v not in query:
            marginal = marginal.sum_reduce(v)

    # bring the axes into the order given by `query`
    marginal.transpose(*query, inplace=True)
    return marginal / marginal.data.sum()

def get_tn(graph):
    """ Convert a factor graph into a quimb tensor network: one tensor per node, tagged with the node label. """
    network = qtn.TensorNetwork()
    for label, attrs in graph.nodes(data=True):
        factor = attrs['factor']
        # the factor's variables become the tensor's index names
        network |= qtn.Tensor(factor.data, inds=factor.variables, tags=label)
    return network

# 5. Simple Update Algorithm with tnsu package for Hamiltonians

## Load Dataset

In [None]:
import sys
import os
import torch
import numpy as np
# Make the Ising dataset module importable. The path is relative to this notebook and is
# assembled with os.path.join so that it also resolves on systems that do not use
# backslashes as path separator.
ISING_DIR = os.path.join('..', '..', 'dataset', 'ising')
sys.path.insert(1, ISING_DIR)


from isingModel import IsingModelDataset

# Alternative dataset file in the same folder: "nk2_2000_12_True.pt"
data_file = os.path.join(ISING_DIR, 'data', 'nk_2000_12_True.pt')
# Load the dataset
dataset = IsingModelDataset.load(data_file)

In [None]:
# Pick a random data point ...
rand_idx = int(torch.randint(len(dataset), (1,)))
data_point = dataset[rand_idx]

# ... and then replace it by a data point with exactly 4 nodes.
idx = -1

# There is no `break`, so the *last* matching index in the dataset is kept
for _idx, point in enumerate(dataset):
    if len(point.x_nodes) == 4:
        idx = _idx
# fail loudly if the dataset contains no 4-node point
assert idx != -1
data_point = dataset[idx]
print(data_point)


In [None]:
# Overrides the data point selected above with one stored on disk.
# NOTE(review): torch.load unpickles the file -- only load files from a trusted source.
data_point = torch.load("test_point.pt")

In [None]:
import scipy.sparse as sp
from functools import reduce

# Module-level caches of pre-computed gate matrices (dense / sparse); filled lazily by `parse_hamiltonian`
matmap_np, matmap_sp = None, None

def parse_hamiltonian(hamiltonian, sparse=False, scaling=1, buffer=None, max_buffer_n=0, dtype=float): # I'd usually default to complex, but because we're only dealing with Ising models here, float is more handy
    """Parse a string representation of a Hamiltonian into a matrix representation. The result is guaranteed to be Hermitian.

    The returned matrix is always a fresh object: it never aliases an entry of the buffer,
    so callers may modify it freely.

    Parameters:
        hamiltonian (str): The Hamiltonian to parse.
        sparse (bool): Whether to use sparse matrices (csr_matrix) or dense matrices (numpy.array).
        scaling (float): A constant factor to scale the Hamiltonian by.
        buffer (dict): A dictionary to store calculated chunks in. If `None`, it defaults to the global `matmap_np` (or `matmap_sp` if `sparse == True`). Give `buffer={}` and leave `max_buffer_n == 0` (default) to disable the buffer.
        max_buffer_n (int): The maximum length (number of qubits) for new chunks to store in the buffer (default: 0). If `0`, no new chunks will be stored in the buffer.
        dtype: The dtype of the matrices (default: float). Any Hamiltonian containing "Y" requires a complex dtype.

    Returns:
        numpy.ndarray | scipy.sparse.csr_matrix: The matrix representation of the Hamiltonian.

    Raises:
        ValueError: If the Hamiltonian contains "Y" but `dtype` is not complex, if a chunk has the wrong number of gates, or if a scalar term is complex.
        MemoryError: If a dense matrix of the required size would not fit into the available memory.

    Example:
    >>> parse_hamiltonian('0.5*(XX + YY + ZZ + II)', dtype=complex) # SWAP
    array([[ 1.+0.j  0.+0.j  0.+0.j  0.+0.j]
           [ 0.+0.j  0.+0.j  1.+0.j  0.+0.j]
           [ 0.+0.j  1.+0.j  0.+0.j  0.+0.j]
           [ 0.+0.j  0.+0.j  0.+0.j  1.+0.j]])
    >>> parse_hamiltonian('-(XX + YY + .5*ZZ) + 1.5', dtype=complex)
    array([[ 1.+0.j  0.+0.j  0.+0.j  0.+0.j]
           [ 0.+0.j  2.+0.j -2.+0.j  0.+0.j]
           [ 0.+0.j -2.+0.j  2.+0.j  0.+0.j]
           [ 0.+0.j  0.+0.j  0.+0.j  1.+0.j]])
    >>> parse_hamiltonian('0.5*(II + ZI - ZX + IX)') # CNOT

    """
    kron = sp.kron if sparse else np.kron

    # Initialize the matrix map (again if it was built for a different dtype)
    global matmap_np, matmap_sp
    if matmap_np is None or matmap_sp is None or matmap_np["I"].dtype != dtype:
        # numpy versions
        matmap_np = {
            "H": np.array([[1, 1], [1, -1]], dtype=dtype) / np.sqrt(2),
            "X": np.array([[0, 1], [1, 0]], dtype=dtype),
            "Z": np.array([[1, 0], [0, -1]], dtype=dtype),
            "I": np.array([[1, 0], [0, 1]], dtype=dtype),
        }
        # composites
        matmap_np.update({
            "ZZ": np.kron(matmap_np['Z'], matmap_np['Z']),
            "IX": np.kron(matmap_np['I'], matmap_np['X']),
            "XI": np.kron(matmap_np['X'], matmap_np['I']),
            "YY": np.array([[ 0,  0,  0, -1],  # to avoid complex numbers
                            [ 0,  0,  1,  0],
                            [ 0,  1,  0,  0],
                            [-1,  0,  0,  0]], dtype=dtype)
        })
        for i in range(2, 11):
            matmap_np["I"*i] = np.eye(2**i, dtype=dtype)
        # add 'Y' only if dtype supports imaginary numbers
        if np.issubdtype(dtype, np.complexfloating):
            matmap_np["Y"] = np.array([[0, -1j], [1j, 0]], dtype=dtype)

        # sparse versions
        matmap_sp = {k: sp.csr_array(v) for k, v in matmap_np.items()}

    if not np.issubdtype(dtype, np.complexfloating) and "Y" in hamiltonian:
        # `dtype` may be a Python type or a numpy dtype instance; only the former has `__name__`
        raise ValueError(f"The Pauli matrix Y is not supported for dtype {getattr(dtype, '__name__', str(dtype))}.")

    matmap = matmap_sp if sparse else matmap_np

    # only use buffer if pre-computed chunks are available or if new chunks are allowed to be stored
    use_buffer = buffer is None or len(buffer) > 0 or max_buffer_n > 0
    if use_buffer and buffer is None:
        buffer = matmap

    def calculate_chunk_matrix(chunk, sparse=False, scaling=1):
        """ Kronecker product of the gates named in `chunk`, times `scaling`; reuses buffered (sub-)chunks where possible. """
        # if scaling != 1:  # only relevant for int dtype
            # scaling = np.array(scaling, dtype=dtype)
        if use_buffer:
            if chunk in buffer:
                return buffer[chunk] if scaling == 1 else scaling * buffer[chunk]
            if len(chunk) == 1:
                return matmap[chunk[0]] if scaling == 1 else scaling * matmap[chunk[0]]
            # Check if a part of the chunk has already been calculated (longest sub-chunks first)
            for i in range(len(chunk)-1, 1, -1):
                for j in range(len(chunk)-i+1):
                    subchunk = chunk[j:j+i]
                    if subchunk in buffer:
                        # If so, calculate the rest of the chunk recursively
                        parts = [chunk[:j], subchunk, chunk[j+i:]]
                        # remove empty chunks
                        parts = [c for c in parts if c != ""]
                        # See where to apply the scaling (on the shortest part, exactly once)
                        shortest = min(parts, key=len)
                        # Calculate each part recursively
                        for i, c in enumerate(parts):
                            if c == subchunk:
                                if c == shortest:
                                    parts[i] = scaling * buffer[c]
                                    shortest = ""
                                else:
                                    parts[i] = buffer[c]
                            else:
                                if c == shortest:
                                    parts[i] = calculate_chunk_matrix(c, sparse=sparse, scaling=scaling)
                                    shortest = ""
                                else:
                                    parts[i] = calculate_chunk_matrix(c, sparse=sparse, scaling=1)
                        return reduce(kron, parts)

        # Calculate the chunk matrix gate by gate
        if use_buffer and len(chunk) <= max_buffer_n:
            gates = [matmap[gate] for gate in chunk]
            chunk_matrix = reduce(kron, gates)
            # store the unscaled matrix so that it can be reused with any scaling
            buffer[chunk] = chunk_matrix
            if scaling != 1:
                chunk_matrix = scaling * chunk_matrix
        else:
            gates = [scaling * matmap[chunk[0]]] + [matmap[gate] for gate in chunk[1:]]
            chunk_matrix = reduce(kron, gates)

        return chunk_matrix

    # Remove whitespace
    hamiltonian = hamiltonian.replace(" ", "")
    # replace - with +-, except before e
    hamiltonian = hamiltonian \
                    .replace("-", "+-") \
                    .replace("e+-", "e-") \
                    .replace("(+-", "(-")

    # print("parse_hamiltonian: Pre-processed Hamiltonian:", hamiltonian)

    # Find parts in parentheses
    part = ""
    parts = []
    depth = 0
    current_part_weight = ""
    for i, c in enumerate(hamiltonian):
        if c == "(":
            if depth == 0:
                # for top-level parts search backwards for the weight
                weight = ""
                for j in range(i-1, -1, -1):
                    if hamiltonian[j] in ["("]:
                        break
                    weight += hamiltonian[j]
                    if hamiltonian[j] in ["+", "-"]:
                        break
                weight = weight[::-1]
                if weight != "":
                    current_part_weight = weight
            depth += 1
        elif c == ")":
            depth -= 1
            if depth == 0:
                part += c
                parts.append((current_part_weight, part))
                part = ""
                current_part_weight = ""
        if depth > 0:
            part += c

    # print("Parts found:", parts)

    # Replace parts in parentheses with a placeholder
    for i, (weight, part) in enumerate(parts):
        hamiltonian = hamiltonian.replace(weight+part, f"+part{i}", 1)
        # remove * at the end of the weight
        if weight != "" and weight[-1] == "*":
            weight = weight[:-1]
        if weight in ["", "+", "-"]:
            weight += "1"
        # Calculate the part recursively
        part = part[1:-1] # remove parentheses
        parts[i] = parse_hamiltonian(part, sparse=sparse, scaling=float(weight), buffer=buffer, max_buffer_n=max_buffer_n, dtype=dtype)

    # print("Parts replaced:", parts)

    # Parse the rest of the Hamiltonian
    chunks = hamiltonian.split("+")
    # Remove empty chunks
    chunks = [c for c in chunks if c != ""]
    # If parts are present, use them to determine the number of qubits
    if parts:
        n = int(np.log2(parts[0].shape[0]))
    else: # Use chunks to determine the number of qubits
        n = 0
        for c in chunks:
            if c[0] in ["-", "+"]:
                c = c[1:]
            if "*" in c:
                c = c.split("*")[1]
            if c.startswith("part"):
                continue
            try:
                float(c)
                continue
            except ValueError:
                n = len(c)
                break
        if n == 0:
            print("Warning: Hamiltonian is a scalar!")

    if not sparse and n > 10:
        # check if we would blow up the memory
        # NOTE(review): `psutil` and `duh` are not defined in this cell; presumably they come from `utils` -- verify.
        mem_required = 2**(2*n) * np.array(1, dtype=dtype).nbytes
        mem_available = psutil.virtual_memory().available
        if mem_required > mem_available:
            raise MemoryError(f"This would blow up you memory ({duh(mem_required)} required)! Try using `sparse=True`.")

    if sparse:
        H = sp.csr_array((2**n, 2**n), dtype=dtype)
    else:
        if n > 10:
            print(f"Warning: Using a dense matrix for a {n}-qubit Hamiltonian is not recommended. Use sparse=True.")
        H = np.zeros((2**n, 2**n), dtype=dtype)

    for chunk in chunks:
        # print("Processing chunk:", chunk)
        chunk_matrix = None
        if chunk == "":
            continue
        # Parse the weight of the chunk

        if chunk.startswith("part"):
            weight = 1  # parts are already scaled
            chunk_matrix = parts[int(chunk.split("part")[1])]
        elif "*" in chunk:
            weight = float(chunk.split("*")[0])
            chunk = chunk.split("*")[1]
        elif len(chunk) == n+1 and chunk[0] in ["-", "+"] and n >= 1 and chunk[1] in matmap:
            weight = float(chunk[0] + "1")
            chunk = chunk[1:]
        elif (chunk[0] in ["-", "+", "."] or chunk[0].isdigit()) and all([c not in matmap for c in chunk[1:]]):
            # a pure scalar term: a multiple of the identity
            if len(chunk) == 1 and chunk[0] in ["-", "."]:
                chunk = 0
            weight = complex(chunk)
            if np.iscomplex(weight):
                raise ValueError("Complex scalars would make the Hamiltonian non-Hermitian!")
            weight = weight.real
            # weight = np.array(weight, dtype=dtype)  # only relevant for int dtype
            chunk_matrix = np.eye(2**n, dtype=dtype)
        elif len(chunk) != n:
            raise ValueError(f"Gate count must be {n} but was {len(chunk)} for chunk \"{chunk}\"")
        else:
            weight = 1

        if chunk_matrix is None:
            chunk_matrix = calculate_chunk_matrix(chunk, sparse=sparse, scaling = scaling * weight)
        elif scaling * weight != 1:
            chunk_matrix = scaling * weight * chunk_matrix

        # Add the chunk to the Hamiltonian
        # print("Adding chunk", weight, chunk, "for hamiltonian", scaling, hamiltonian)
        # print(type(H), H.dtype, type(chunk_matrix), chunk_matrix.dtype)
        if len(chunks) == 1:
            # `chunk_matrix` may be the very object stored in the buffer; hand out a copy
            # so that callers modifying the result cannot corrupt the cache
            H = chunk_matrix.copy()
        else:
            H += chunk_matrix

    if sparse:
        # Compare H with its conjugate transpose entry by entry. (Comparing the raw `.data`
        # arrays does not work: they are stored in different orders for H and H^dagger.)
        diff = H - H.conj().T
        deviation = diff.data if sp.issparse(diff) else diff
        assert np.allclose(deviation, 0), f"The given Hamiltonian {hamiltonian} is not Hermitian: {H.data}"
    else:
        assert np.allclose(H, H.conj().T), f"The given Hamiltonian {hamiltonian} is not Hermitian: {H}"

    return H

In [None]:
# Show the Hamiltonian string of the selected data point and the shape of its matrix form
print(data_point.hamiltonian)
hamiltonian = parse_hamiltonian(data_point.hamiltonian)
print(hamiltonian.shape)

In [None]:
def get_dims(data_point):
    """
    Returns the grid dimensions of a data point as an array of two entries.

    A one-dimensional `grid_extent` gets a second dimension of size 1 appended; if its
    only entry is 0, the number of nodes (`len(data_point.x_nodes)`) is used instead.
    `data_point.grid_extent` itself is left untouched.

    Raises
        AssertionError: If the grid has more than two dimensions or if the product of the
        dimensions does not match the number of nodes.
    """
    # work on a copy: the 0 -> node count replacement below must not leak into the data point
    dims = np.array(data_point.grid_extent)

    assert len(dims) <= 2

    if len(dims) == 1:
        if dims[0] == 0:
            dims[0] = len(data_point.x_nodes)
        dims = np.append(dims, 1)

    assert np.prod(dims) == len(data_point.x_nodes)
    return dims

# Grid dimensions of the selected data point
dims = get_dims(data_point)

In [None]:
import numpy as np

def rectangular_peps_pbc(height: int, width: int):
    """
    Structure matrix of a (height x width) rectangular lattice tensor network with periodic
    boundary conditions (pbc). Rows correspond to the height * width tensors, columns to the edges.
    :param height: The height of the tensor network.
    :param width: The width of the tensor network.
    :return: a structure matrix
    """
    grid = (height, width)

    # Every bond is stored as (flat index of node a, leg of a, flat index of node b, leg of b).
    # Per site: first the bond to the site below, then the bond to the site on the right,
    # both wrapping around; a direction of extent 1 has no bonds.
    bonds = []
    for i in range(height):
        for j in range(width):
            here = np.ravel_multi_index([i, j], grid)
            if height > 1:
                below = np.ravel_multi_index([(i + 1) % height, j], grid)
                bonds.append((here, 4, below, 2))
            if width > 1:
                right = np.ravel_multi_index([i, (j + 1) % width], grid)
                bonds.append((here, 3, right, 1))

    # one row per tensor, one column per bond, entries are the leg numbers
    structure_matrix = np.zeros(shape=[height * width, len(bonds)], dtype=int)
    for col, (node_a, leg_a, node_b, leg_b) in enumerate(bonds):
        structure_matrix[node_a, col] = leg_a
        structure_matrix[node_b, col] = leg_b

    # renumber the legs of every tensor to 1..k according to a constant order
    for r in range(structure_matrix.shape[0]):
        occupied = np.nonzero(structure_matrix[r, :])[0]
        legs = structure_matrix[r, occupied]
        numbering = np.array(range(1, len(legs) + 1))
        structure_matrix[r, occupied] = numbering[np.argsort(legs)]

    return structure_matrix



def rectangular_peps_obc(height: int, width: int):
    """
    Builds the structure matrix of a (height x width) rectangular lattice tensor network
    with open (non-periodic) boundary conditions (obc). The network has height x width
    tensors; the matrix has one row per tensor and one column per edge.
    :param height: The height of the tensor network.
    :param width: The width of the tensor network.
    :return: a structure matrix
    """
    lattice_shape = (height, width)

    # Each bond is (site_a, leg_a, site_b, leg_b): leg 4 of a site joins leg 2 of the site
    # in the next row, leg 3 joins leg 1 of the site in the next column. Sites in the last
    # row / last column have no such neighbour.
    bonds = []
    for row in range(height):
        has_next_row = row < height - 1
        for col in range(width):
            if has_next_row:
                bonds.append(((row, col), 4, (row + 1, col), 2))
            if col < width - 1:
                bonds.append(((row, col), 3, (row, col + 1), 1))

    structure_matrix = np.zeros(shape=[height * width, len(bonds)], dtype=int)

    # one column per bond, holding the leg label at each of its two end sites
    for column, (site_a, leg_a, site_b, leg_b) in enumerate(bonds):
        structure_matrix[np.ravel_multi_index(site_a, lattice_shape), column] = leg_a
        structure_matrix[np.ravel_multi_index(site_b, lattice_shape), column] = leg_b

    # relabel every node's legs as 1..k (rows are views, so this edits the matrix in place)
    for node_legs in structure_matrix:
        occupied = np.flatnonzero(node_legs)
        node_legs[occupied] = np.argsort(node_legs[occupied]) + 1

    return structure_matrix



# Builds a rectangular PEPS structure matrix with open (OBC) or periodic (PBC) boundaries.
# The structure matrix is defined in https://arxiv.org/pdf/1808.00680.pdf
def get_structure_matrix(data_point, circular = False):
    """
    Returns the structure matrix for the lattice of a given data point.

    :param data_point: data point whose lattice dimensions are obtained via `get_dims`.
    :param circular: if True use periodic boundary conditions, otherwise open ones.
    :return: the structure matrix produced by the selected lattice builder.
    """
    lattice = get_dims(data_point)
    build_lattice = rectangular_peps_pbc if circular else rectangular_peps_obc
    return build_lattice(lattice[0], lattice[1])


# Structure matrix of the current data point with periodic boundaries (circular=True),
# printed for inspection.
smat = get_structure_matrix(data_point, circular=True)
print(smat)

In [None]:
import numpy as np
import tnsu.tensor_network as tn
from simple_update import SimpleUpdate


# Tensor network built from the structure matrix `smat`, with virtual_dim=2 and spin_dim=2.
tensornet = tn.TensorNetwork(structure_matrix=smat, virtual_dim=2, spin_dim=2)


# Pauli matrices (note: pauli_x is an integer array, pauli_z a float array because of `1.`)
pauli_x = np.array([[0, 1],
                    [1, 0]])
pauli_z = np.array([[1., 0],
                    [0, -1]])

# ITE time constants, decreasing from 0.1 to 1e-5
dts = [0.1, 0.01, 0.001, 0.0001, 0.00001]

# Local spin operators
# NOTE(review): s_i holds pauli_x while s_j holds pauli_z - presumably SimpleUpdate pairs
# s_i[n] with s_j[n] in the two-body term; confirm that an X-Z coupling is intended.
s_i = [pauli_x]
s_j = [pauli_z]

# The field-spin operators
s_k = [pauli_z]

# The maximal virtual bond dimension (used for SU truncation)
d_max = 5

# The Hamiltonian's 2-body interaction constants, taken from the data point's edge features
# (squeezed and converted to numpy)
j_ij = data_point.x_edges.squeeze().numpy()

# The Hamiltonian's 1-body field constant, taken from the data point's node features
# (squeezed and converted to numpy)
h_k = data_point.x_nodes.squeeze().numpy()

# When False, `hamiltonian=None` is passed below and the name `hamiltonian` is never
# evaluated. It is not defined in this cell, so it must exist before this is set to True.
USE_HAMILTONIAN = False



# Configure the Simple Update run: at most 200 iterations, convergence error 1e-6,
# with energy logging and progress printing switched on.
su = SimpleUpdate(tensor_network=tensornet, 
                          dts=dts, 
                          j_ij=j_ij, 
                          h_k=h_k, 
                          s_i=s_i, 
                          s_j=s_j, 
                          s_k=s_k,
                          hamiltonian=hamiltonian if USE_HAMILTONIAN else None, 
                          d_max=d_max, 
                          max_iterations=200, 
                          convergence_error=1e-6, 
                          log_energy=True,
                          print_process=True)


# Execute the configured run.
su.run()

## Check Values

In [None]:
#Obtain Labels
# Reference values stored on the data point: energy, node RDMs and edge RDMs
# (presumably the one-body and two-body reduced density matrices - confirm).
y_energy = data_point.y_energy
one_rdms = data_point.y_node_rdms
two_rdms = data_point.y_edge_rdms  # not used further in this cell

#Obtain Predictions
# NOTE(review): the prediction is a per-site energy; whether y_energy is also per-site
# is not visible here - confirm the two are comparable.
pred_energy = su.energy_per_site()
print("Calculated Energy: ", pred_energy)
print("True Energy: ", y_energy.item())


#Compare the first one-body RDM
# RDM of the tensor with index 0 against the first stored node RDM.
pred_one_rdm = su.tensor_rdm(tensor_index=0)
print("Calculated One RDM: ", pred_one_rdm)
print("True One RDM: ", one_rdms[0])


