In [None]:
import networkx as nx
import matplotlib.pyplot as plt

def initialize_grid(l_x, l_y):
    """Build a 2D grid graph together with integer labels and plot positions.

    Args:
        l_x: Size of the second grid dimension passed to ``nx.grid_2d_graph``.
        l_y: Size of the first grid dimension passed to ``nx.grid_2d_graph``.

    Returns:
        A ``(G, labels, pos)`` tuple where ``G`` is the grid graph,
        ``labels`` maps every node to its index in ``G.nodes()`` iteration
        order, and ``pos`` maps every node ``(x, y)`` to the plotting
        coordinate ``(y, -x)``.
    """
    grid = nx.grid_2d_graph(l_y, l_x)

    # Number the nodes consecutively in the graph's own iteration order.
    labels = {}
    for index, node in enumerate(grid.nodes()):
        labels[node] = index

    # Swap the coordinates and negate the first one to get drawing positions.
    pos = {}
    for first, second in grid.nodes():
        pos[(first, second)] = (second, -first)

    return grid, labels, pos

def draw_grid_with_labels(G, labels, pos):
    """Draw the graph in a new 8x8 figure with each node showing its label.

    Shortest-path highlighting is not performed here (it was removed for now).

    Args:
        G: Graph to draw.
        labels: Mapping from node to the label text drawn on that node.
        pos: Mapping from node to its drawing position.
    """
    # Keyword options forwarded to nx.draw: labels come from the given
    # mapping rather than from the node identifiers themselves.
    render_options = {
        "labels": labels,
        "with_labels": True,
        "node_size": 2000,
        "font_size": 15,
    }

    plt.figure(figsize=(8, 8))
    nx.draw(G, pos, **render_options)
    plt.show()


def find_all_shortest_paths(G, labels, start_label, end_label):
    """Find every shortest path between two labelled nodes and print them.

    The paths are printed as lists of labels, but returned as lists of nodes.

    Args:
        G: Graph to search.
        labels: Mapping from node to label.
        start_label: Label of the node the paths start from.
        end_label: Label of the node the paths end at.

    Returns:
        A list of paths, each path being a list of nodes of ``G``.

    Raises:
        KeyError: If ``start_label`` or ``end_label`` is not a value in
            ``labels``.
    """
    # Invert the node -> label mapping so the endpoints can be found by label.
    node_for_label = dict(zip(labels.values(), labels.keys()))
    source_node = node_for_label[start_label]
    target_node = node_for_label[end_label]

    shortest = nx.all_shortest_paths(G, source=source_node, target=target_node)
    paths = list(shortest)

    # Translate each node path into its labels purely for the printed report.
    path_labels = []
    for path in paths:
        path_labels.append([labels[node] for node in path])

    print(f"All shortest paths from {start_label} to {end_label}: {path_labels}")
    return paths


# Takes the integer labels of the two nodes whose labels should be swapped.
def swap_labels(G, labels, pos, label1, label2):
    """Exchange the labels of two nodes and redraw the graph.

    The nodes are located by their current labels; ``labels`` is modified
    in place.

    Args:
        G: Graph that is redrawn after the swap.
        labels: Mapping from node to label; updated in place.
        pos: Mapping from node to drawing position, used for the redraw.
        label1: Current label of the first node.
        label2: Current label of the second node.

    Raises:
        StopIteration: If no node currently carries ``label1`` or ``label2``.
    """

    def _first_node_with(wanted):
        # First node (in mapping order) whose current label equals `wanted`.
        matches = (
            candidate
            for candidate, current in labels.items()
            if current == wanted
        )
        return next(matches)

    first_node = _first_node_with(label1)
    second_node = _first_node_with(label2)

    # Exchange the two dictionary entries through a temporary.
    held = labels[first_node]
    labels[first_node] = labels[second_node]
    labels[second_node] = held

    # Show the graph again so the new labelling is visible.
    draw_grid_with_labels(G, labels, pos)

# def swap_labels(G, node1, node2):
#     """
#     Swap labels of two nodes.
#     """
#     # Get current labels
#     label1, label2 = G.nodes[node1]['label'], G.nodes[node2]['label']
    
#     # Swap labels
#     G.nodes[node1]['label'], G.nodes[node2]['label'] = label2, label1
    
#     # Update the drawing
#     labels = nx.get_node_attributes(G, 'label')
#     draw_grid_with_labels(G, pos, labels)

def find_common_subpaths(paths, min_length=2):
    """Count, for each consecutive subpath, how many pairs of paths share it.

    Every unordered pair of paths is compared. For each window length from
    ``min_length`` up to the length of the shorter path of the pair, the
    consecutive subpaths appearing in both paths are recorded.

    Args:
        paths: Iterable of paths; each path is a sequence of hashable items.
        min_length: Smallest subpath length (number of items) to consider.

    Returns:
        A dict mapping each shared subpath (as a tuple) to the number of
        distinct path pairs in which it occurs.
    """
    # Tuples can be sliced into hashable windows.
    frozen = [tuple(path) for path in paths]

    def _windows(sequence, size):
        # All consecutive slices of `sequence` that are `size` items long.
        return {
            sequence[start:start + size]
            for start in range(len(sequence) - size + 1)
        }

    sharing_pairs = {}
    for first_idx, first in enumerate(frozen):
        later_paths = enumerate(frozen[first_idx + 1:], start=first_idx + 1)
        for second_idx, second in later_paths:
            longest = min(len(first), len(second))
            for size in range(min_length, longest + 1):
                shared = _windows(first, size) & _windows(second, size)
                for subpath in shared:
                    pair = (first_idx, second_idx)
                    sharing_pairs.setdefault(subpath, set()).add(pair)

    # Collapse each set of index pairs into a simple count.
    return {subpath: len(pairs) for subpath, pairs in sharing_pairs.items()}
