# Defining Coupling Graphs
---
### Basic Information
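**Description:** This notebook defines the hardware architectures used in the numerical experiments, each represented by a coupling graph: a square grid, an octagonal lattice, and a heavy-hex lattice. It also provides helpers to select the central unit cell of a graph and to convert node labels `(x, y, type)` into node indices.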

In [19]:
import rustworkx as rx
from rustworkx.visualization import mpl_draw

In [20]:
def set_grid_graph(width, height, draw=False):
    """Build a ``width`` x ``height`` square-lattice coupling graph.

    Every node payload is its label ``(x, y, 0)``; the trailing ``0`` is the
    (only) node type, kept so labels share the 3-tuple format of the other
    graph builders in this notebook.

    Args:
        width: number of columns; ``x`` runs over ``0 .. width - 1``.
        height: number of rows; ``y`` runs over ``0 .. height - 1``.
        draw: if True, call ``mpl_draw(G, with_labels=True)`` on the result.

    Returns:
        ``[G, node_dict, node_dict_reverse]``: the ``rx.PyGraph``, a mapping
        node index -> label, and the inverse mapping label -> node index.
    """
    graph = rx.PyGraph()
    index_to_label = {}
    label_to_index = {}

    # Row-major insertion: label (x, y, 0) receives index y * width + x.
    all_labels = [(x, y, 0) for y in range(height) for x in range(width)]
    for index, label in enumerate(all_labels):
        graph.add_node(label)
        index_to_label[index] = label
        label_to_index[label] = index

    def connect(label_a, label_b):
        # Edges carry no payload.
        graph.add_edge(label_to_index[label_a], label_to_index[label_b], None)

    # All horizontal edges first, then all vertical ones (this fixes edge indices).
    for y in range(height):
        for x in range(width - 1):
            connect((x, y, 0), (x + 1, y, 0))
    for y in range(height - 1):
        for x in range(width):
            connect((x, y, 0), (x, y + 1, 0))

    if draw:
        mpl_draw(graph, with_labels=True)
    return [graph, index_to_label, label_to_index]

In [21]:
def set_octagonal_graph(width, height, draw=False):
    """Build a ``width`` x ``height`` lattice of 8-node rings (octagons).

    Cell ``(x, y)`` holds eight nodes labelled ``(x, y, t)`` for ``t`` in
    ``0..7``, wired into a ring ``t -- (t + 1) % 8``. Adjacent rings are joined
    by two edges each:

    * horizontally, ``(x, y, 1) -- (x + 1, y, 6)`` and ``(x, y, 2) -- (x + 1, y, 5)``;
    * vertically, ``(x, y, 0) -- (x, y + 1, 3)`` and ``(x, y, 7) -- (x, y + 1, 4)``.

    Args:
        width: number of ring columns.
        height: number of ring rows.
        draw: if True, call ``mpl_draw(G, with_labels=True)`` on the result.

    Returns:
        ``[G, node_dict, node_dict_reverse]``: the ``rx.PyGraph``, a mapping
        node index -> label, and the inverse mapping label -> node index.
    """
    ring_size = 8
    graph = rx.PyGraph()
    index_to_label = {}
    label_to_index = {}

    # Cells row-major, ring positions innermost: (x, y, t) -> (y * width + x) * 8 + t.
    all_labels = [
        (x, y, t)
        for y in range(height)
        for x in range(width)
        for t in range(ring_size)
    ]
    for index, label in enumerate(all_labels):
        graph.add_node(label)
        index_to_label[index] = label
        label_to_index[label] = index

    def connect(label_a, label_b):
        # Edges carry no payload.
        graph.add_edge(label_to_index[label_a], label_to_index[label_b], None)

    # Ring edges inside every cell.
    for y in range(height):
        for x in range(width):
            for t in range(ring_size):
                connect((x, y, t), (x, y, (t + 1) % ring_size))

    # Links to the ring on the right.
    for y in range(height):
        for x in range(width - 1):
            connect((x, y, 1), (x + 1, y, 6))
            connect((x, y, 2), (x + 1, y, 5))

    # Links to the ring in the next row.
    for y in range(height - 1):
        for x in range(width):
            connect((x, y, 0), (x, y + 1, 3))
            connect((x, y, 7), (x, y + 1, 4))

    if draw:
        mpl_draw(graph, with_labels=True)
    return [graph, index_to_label, label_to_index]

In [22]:
def set_heavy_hex_graph(width, height, draw=False):
    """Build a heavy-hex coupling graph: horizontal chains joined by link nodes.

    Layout produced by this function:

    * ``height + 1`` horizontal chains ("rows") of ``row_width = 4 * width + 3``
      nodes each. The node at position ``i`` of row ``y`` is labelled
      ``((i - 2*y) // 4, y, (i + 2*y) % 4)``, which is equivalent to
      ``i == 4*x + 2*y + t`` with chain type ``t`` in ``0..3``.
    * ``height`` layers of ``width + 1`` link nodes of type 4, labelled
      ``(k - y // 2, y, 4)`` for ``k`` in ``0..width``. Link node ``(x, y, 4)``
      joins ``(x, y, 0)`` on row ``y`` to ``(x - 1, y + 1, 2)`` on row ``y + 1``.

    Args:
        width: number of hexagon columns (controls ``row_width``).
        height: number of hexagon rows (there are ``height + 1`` chains).
        draw: if True, call ``mpl_draw(G, with_labels=True)`` on the result.

    Returns:
        ``[G, node_dict, node_dict_reverse]``: the ``rx.PyGraph``, a mapping
        node index -> label, and the inverse mapping label -> node index.
    """
    num_types = 5
    link_type = num_types - 1  # type 4 marks the link (vertical) nodes
    row_width = width * (num_types - 1) + 3

    G = rx.PyGraph()
    node_dict = {}
    node_dict_reverse = {}

    # Chain nodes come first, row by row, so row y occupies node indices
    # y * row_width .. (y + 1) * row_width - 1 (relied on by the chain edges below).
    labels = [
        ((i - 2 * y) // 4, y, (i + 2 * y) % 4)
        for y in range(height + 1)
        for i in range(row_width)
    ]
    # Link nodes follow all chain nodes.
    labels += [
        (k - y // 2, y, link_type)
        for y in range(height)
        for k in range(width + 1)
    ]
    for count, label in enumerate(labels):
        G.add_node(label)
        node_dict[count] = label
        node_dict_reverse[label] = count

    # Chain edges: consecutive nodes along each row.
    for y in range(height + 1):
        row_start = y * row_width
        for i in range(row_width - 1):
            G.add_edge(row_start + i, row_start + i + 1, None)

    # Link edges: chain (x, y, 0) -- link (x, y, 4) -- chain (x - 1, y + 1, 2).
    # The x offset y // 2 shifts the link columns so that consecutive rows use
    # alternating attachment points, which yields the heavy-hex pattern.
    for y in range(height):
        for k in range(width + 1):
            x = k - y // 2
            chain_node = node_dict_reverse[(x, y, 0)]
            link_node = node_dict_reverse[(x, y, link_type)]
            next_chain_node = node_dict_reverse[(x - 1, y + 1, 2)]
            G.add_edge(chain_node, link_node, None)
            G.add_edge(link_node, next_chain_node, None)

    if draw:
        mpl_draw(G, with_labels=True)
    return [G, node_dict, node_dict_reverse]

In [23]:
def generating_set(size, graph_type, node_dict_reverse):
    """Return the node indices of the unit cell at the centre of a graph.

    Args:
        size: ``(width, height)`` that the graph was built with.
        graph_type: one of ``'grid'``, ``'octagonal'`` or ``'heavy_hex'``.
        node_dict_reverse: mapping label ``(x, y, type)`` -> node index.

    Returns:
        List of node indices for the central cell, ordered by node type
        (1 node for grid, 8 for octagonal, 5 for heavy_hex).

    Raises:
        ValueError: if ``graph_type`` is not one of the supported names.
        KeyError: if the central label is missing from ``node_dict_reverse``.
    """
    if graph_type == 'grid':
        node_types, center_x, center_y = range(1), size[0] // 2, size[1] // 2
    elif graph_type == 'octagonal':
        node_types, center_x, center_y = range(8), size[0] // 2, size[1] // 2
    elif graph_type == 'heavy_hex':
        center_y = size[1] // 2
        # Heavy-hex x labels shift left by one every two rows.
        node_types, center_x = range(5), size[0] // 2 - (center_y // 2)
    else:
        raise ValueError("Unsupported graph type.")

    return [node_dict_reverse[(center_x, center_y, t)] for t in node_types]

In [24]:
import numpy as np

def retrieve_node_indices(node_pos_array, node_dict_reverse, graph_type):
    """Convert node labels ``(x, y, type)`` into node indices in one batch.

    For ``'grid'``, ``'octagonal'`` and ``'heavy_hex'`` the index is computed
    arithmetically from the label, following the node insertion order used by
    ``set_grid_graph``, ``set_octagonal_graph`` and ``set_heavy_hex_graph``.
    ``node_dict_reverse`` is only used to recover the graph size. For any other
    ``graph_type`` every label is looked up in ``node_dict_reverse``.

    Args:
        node_pos_array: array-like of shape ``(N, 3)``. Row ``i`` is
            ``(x, y, type)`` of the i-th queried node (one row per node, NOT one
            row per coordinate). For the arithmetic layouts a list of tuples or
            a single ``(x, y, type)`` triple is accepted too.
        node_dict_reverse: mapping label tuple -> node index, as returned by
            the graph builders.
        graph_type: ``'grid'``, ``'octagonal'``, ``'heavy_hex'``, or anything
            else for the generic dictionary lookup.

    Returns:
        numpy array of shape ``(N,)`` with the node indices. On the arithmetic
        paths, labels that are not in the graph are not detected.

    Raises:
        KeyError: on the generic path, if a label is not in ``node_dict_reverse``.
    """
    if graph_type not in ('grid', 'octagonal', 'heavy_hex'):
        # Generic fallback: one dictionary lookup per queried node.
        return np.array([node_dict_reverse[tuple(node_pos)] for node_pos in node_pos_array])

    positions = np.asarray(node_pos_array)
    if positions.size == 0:
        return np.zeros(0, dtype=int)
    # A single (x, y, type) triple becomes a one-row (1, 3) array.
    positions = np.atleast_2d(positions)
    xs = positions[:, 0]
    ys = positions[:, 1]

    # Every existing label, column-wise: row 0 = x, row 1 = y, row 2 = type.
    label_array = np.array(list(node_dict_reverse.keys())).T

    if graph_type == 'grid':
        width = label_array[0].max() + 1
        # Nodes were inserted row by row.
        return width * ys + xs

    if graph_type == 'octagonal':
        num_node_types = 8
        width = label_array[0].max() + 1
        # Cells were inserted row by row, with 8 consecutive ring nodes per cell.
        return (width * ys + xs) * num_node_types + positions[:, 2]

    # heavy_hex
    num_node_types = 5
    link_type = num_node_types - 1
    types = positions[:, 2]
    # On row 0, both chain and link x labels run 0..width, so their maximum
    # equals the width the graph was built with.
    width = label_array[0, label_array[1] == 0].max()
    num_rows = label_array[1].max() + 1  # number of horizontal chains
    row_width = width * (num_node_types - 1) + 3
    first_node_in_vertical = num_rows * row_width  # link nodes follow all chain nodes

    is_chain = types < link_type
    is_link = types == link_type
    # Chain position inside its row is 4*x + 2*y + type, written as
    # 4*(x + y//2) + 2*(y % 2) + type.
    chain_index = (
        ys * row_width
        + (xs + ys // 2) * (num_node_types - 1)
        + (ys % 2) * 2
        + types
    )
    link_index = first_node_in_vertical + ys * (width + 1) + xs + ys // 2
    return is_chain * chain_index + is_link * link_index