# Topology Analysis 

In [8]:
import plotly.express as px
from skimage import io
import numpy as np


In [9]:
'''
These functions mock color segmentation.
They only work under the assumption that each distinct shape has a unique color
(which holds for the input dataset).
'''

def flood_fill_iterative(image, x, y, current_color, shape_id, shape_image):
    """Label the 4-connected region of ``current_color`` that contains (x, y).

    Every reachable pixel whose value equals ``current_color`` and whose entry
    in ``shape_image`` is still -1 is overwritten in place with ``shape_id``.
    An explicit stack is used instead of recursion so that large regions
    cannot exhaust Python's recursion limit.

    Args:
        image: Source image, indexed as ``image[row, col]``.
        x, y: Row and column of the seed pixel.
        current_color: Pixel value belonging to the region being filled.
        shape_id: Label written into ``shape_image``.
        shape_image: 2-D label array, modified in place; -1 marks unlabelled pixels.
    """
    n_rows, n_cols = image.shape[:2]
    pending = [(x, y)]

    while pending:
        r, c = pending.pop()
        # Guard clauses: out of bounds, already labelled, or a different colour.
        if not (0 <= r < n_rows and 0 <= c < n_cols):
            continue
        if shape_image[r, c] != -1:
            continue
        if not np.array_equal(image[r, c], current_color):
            continue

        shape_image[r, c] = shape_id

        # 4-way neighbourhood (down, up, right, left).
        pending.extend(((r + 1, c), (r - 1, c), (r, c + 1), (r, c - 1)))

def color_segment(image):
    """Label every uniformly coloured, 4-connected region of ``image``.

    Pixels are scanned in row-major order. Each pixel not yet covered by an
    earlier region seeds a flood fill, so regions are numbered 0, 1, 2, ...
    in the order their first pixel is met.

    Returns:
        tuple: ``(shape_image, num_shapes, shape_to_color_map)`` where
        ``shape_image`` is a (rows, cols) integer array of region ids,
        ``num_shapes`` is the number of regions found, and
        ``shape_to_color_map`` maps each region id to the pixel value of its
        seed pixel (the original colour of that shape).
    """
    height, width = image.shape[:2]
    labels = np.full((height, width), -1)
    id_to_color = {}
    next_id = 0

    for row in range(height):
        for col in range(width):
            if labels[row, col] != -1:
                continue  # already part of an earlier region
            flood_fill_iterative(image, row, col, image[row, col], next_id, labels)
            id_to_color[next_id] = image[row, col]
            next_id += 1

    return labels, next_id, id_to_color

In [10]:
from plotly.subplots import make_subplots
import plotly.graph_objects as go
import networkx as nx

def topo_pos(G):
    """Lay out a DAG generation by generation for display.

    Generation ``i`` (as yielded by ``nx.topological_generations``) is drawn
    around ``y = -i``. Within a generation, nodes are spread along x, shifted
    left by half the generation size, and each successive node is nudged up
    by 0.1 so neighbouring labels stay legible.

    Returns:
        dict: node -> (x, y) position.
    """
    positions = {}
    for depth, generation in enumerate(nx.topological_generations(G)):
        half_width = len(generation) / 2
        for k, node in enumerate(generation):
            positions[node] = (k - half_width, -depth + k * 0.1)
    return positions

def _rgb_string(color):
    """Format a pixel value as a Plotly ``'rgb(r,g,b)'`` colour string.

    Only the first three channels are used, so an alpha channel is ignored;
    a one- or two-channel value (grayscale) repeats its first channel.
    """
    channels = np.atleast_1d(color)
    if channels.size < 3:
        channels = np.repeat(channels[:1], 3)
    return f'rgb({channels[0]},{channels[1]},{channels[2]})'


def _graph_traces(graph, pos, levels, colors):
    """Return the (edge trace, node trace) pair drawing ``graph`` at ``pos``.

    Nodes are labelled with ``levels[node]`` and filled with ``colors[node]``.
    """
    edge_x, edge_y = [], []
    for u, v in graph.edges():
        x0, y0 = pos[u]
        x1, y1 = pos[v]
        # A None between segments makes Plotly break the line between edges.
        edge_x.extend([x0, x1, None])
        edge_y.extend([y0, y1, None])
    edge_trace = go.Scatter(x=edge_x, y=edge_y, line=dict(width=0.5, color='#888'),
                            hoverinfo='none', mode='lines')

    nodes = list(graph.nodes())
    node_trace = go.Scatter(x=[pos[n][0] for n in nodes],
                            y=[pos[n][1] for n in nodes],
                            text=[levels[n] for n in nodes],
                            mode='markers+text',
                            textposition="top center",
                            hoverinfo="none",
                            marker=dict(showscale=False,
                                        color=[colors[n] for n in nodes],
                                        size=10, line_width=2))
    return edge_trace, node_trace


def plot_touch_graph(image, touch_graph, levels, tree, shape_to_color_map):
    """Show an image, its touch graph and its nesting tree side by side.

    The figure has three panels: (1) the input image, (2) the touch graph
    with a spring layout, (3) the tree built from ``tree.adj_list`` laid out
    with ``topo_pos``. Graph nodes are labelled with their level and painted
    in the colour of the corresponding shape.

    Args:
        image: Image shown in the first panel.
        touch_graph: shape id -> iterable of touching shape ids.
        levels: ``levels[shape]`` is the text drawn next to each node.
        tree: Object whose ``adj_list`` maps each node to its children.
        shape_to_color_map: shape id -> pixel colour of that shape.
    """
    # Colours are looked up per node id rather than by position, so they stay
    # correct whatever order networkx yields the nodes in.
    colors = {shape: _rgb_string(col) for shape, col in shape_to_color_map.items()}

    touch = nx.Graph(touch_graph)
    touch_edges, touch_nodes = _graph_traces(touch, nx.spring_layout(touch), levels, colors)

    hierarchy = nx.DiGraph(tree.adj_list)
    tree_edges, tree_nodes = _graph_traces(hierarchy, topo_pos(hierarchy), levels, colors)

    fig = make_subplots(rows=1, cols=3)
    fig.add_trace(tree_edges, row=1, col=3)
    fig.add_trace(tree_nodes, row=1, col=3)
    fig.add_trace(touch_edges, row=1, col=2)
    fig.add_trace(touch_nodes, row=1, col=2)
    fig.add_trace(px.imshow(image).data[0], row=1, col=1)

    # Hide the legend, grid lines, tick labels and zero lines on every panel.
    fig.update_layout(showlegend=False)
    fig.update_xaxes(showgrid=False, showticklabels=False, zeroline=False)
    fig.update_yaxes(showgrid=False, showticklabels=False, zeroline=False)

    fig.show()


In [11]:
def build_touch_graph(image_shapes, num_classes, include_diagonals=True):
    """
    Build the undirected touch graph of a labelled image.

    Two shape ids are adjacent when at least one pair of neighbouring pixels
    carries those two ids. Horizontal and vertical neighbours are always
    considered. Diagonal neighbours (pixels touching only at a corner) are
    considered too unless ``include_diagonals`` is False. The default, True,
    keeps the historical behaviour.

    Args:
    - image_shapes: 2-D array of shape ids (e.g. from ``color_segment``);
      every value must lie in ``range(num_classes)``.
    - num_classes: Number of distinct shape ids.
    - include_diagonals: Whether corner-touching pixels count as adjacent.

    Returns:
    - dict mapping each shape id to the set of shape ids it touches. The
      relation is symmetric: ``b in graph[a]`` iff ``a in graph[b]``.
    """
    adjacency_list = {class_id: set() for class_id in range(num_classes)}

    # Only "forward" offsets are scanned so each pixel pair is visited once;
    # the edge is then recorded in both directions to keep the graph symmetric.
    directions = [(0, 1), (1, 0)]             # right, bottom
    if include_diagonals:
        directions += [(1, 1), (1, -1)]       # bottom-right, bottom-left

    rows, cols = image_shapes.shape[:2]
    for row in range(rows):
        for col in range(cols):
            current_class = int(image_shapes[row, col])
            for dr, dc in directions:
                new_row, new_col = row + dr, col + dc
                if 0 <= new_row < rows and 0 <= new_col < cols:
                    adjacent_class = int(image_shapes[new_row, new_col])
                    if current_class != adjacent_class:
                        adjacency_list[current_class].add(adjacent_class)
                        adjacency_list[adjacent_class].add(current_class)

    return adjacency_list

In [12]:
def find_levels_and_parents(touch_graph, border_class):
    '''
    Assign a nesting level and a parent to every shape of a touch graph.

    ``border_class`` gets level 0 and no parent. Shapes are then explored
    depth-first. For an explored shape ``s0`` the children are:
      * every shape without a level yet that touches ``s0``, plus
      * repeatedly, until nothing new is found, every shape without a level
        yet that touches at least two of the children collected so far.
    All children get level ``levels[s0] + 1`` and parent ``s0``, and are then
    explored in turn.

    Args:
        touch_graph: Adjacency mapping indexable by shape ids 0..n-1 (e.g. the
            dict returned by ``build_touch_graph``); assumed symmetric.
        border_class: Id of the outermost shape (root of the hierarchy).

    Returns:
        (levels, parents): two lists of length ``len(touch_graph)``;
        ``parents[border_class]`` is None.

    Raises:
        AssertionError: (only when assertions are enabled) if some shape is
            never reached from ``border_class``.
    '''
    num_shapes = len(touch_graph)
    levels = [None] * num_shapes
    parents = [None] * num_shapes
    levels[border_class] = 0
    parents[border_class] = None

    def find_levels_and_parents_(s0):
        assert levels[s0] is not None
        # Only shapes without a level can become children; this excludes the
        # parent and siblings of s0 as well as every other placed shape.
        nodes_on_level = [n for n in touch_graph[s0] if levels[n] is None]
        if not nodes_on_level:
            return  # found a leaf

        while True:
            # NOTE: could be optimised by updating touch_count incrementally
            # with the newly discovered nodes instead of recounting.
            touch_count = [0] * num_shapes
            for n in nodes_on_level:
                for nn in touch_graph[n]:
                    touch_count[nn] += 1

            # `levels[n] is None` is required here too: without it a shape
            # that is already placed (e.g. a sibling of s0) touching two of
            # s0's children would be re-parented under s0 with a wrong level,
            # which can make the parent relation inconsistent or even cyclic.
            new_nodes = [n for n, touches in enumerate(touch_count)
                         if touches >= 2
                         and levels[n] is None
                         and n not in nodes_on_level]
            if not new_nodes:
                break
            nodes_on_level += new_nodes

        # Assign level to all siblings
        for n in nodes_on_level:
            levels[n] = levels[s0] + 1
            parents[n] = s0

        # Recursively explore nodes. Must be done after assigning the level so
        # that deeper calls see these shapes as already placed.
        for n in nodes_on_level:
            find_levels_and_parents_(n)

    find_levels_and_parents_(border_class)

    # All levels must be assigned, i.e. every shape is reachable from the border.
    assert all(level is not None for level in levels), \
        'some shapes are unreachable from border_class'

    return levels, parents

In [13]:
class Tree:
    """Rooted tree stored as an adjacency list mapping a node to its children.

    Args:
        root: Id of the root node.
        nodes: Iterable of every node id; each starts with an empty child set.
    """

    def __init__(self, root, nodes):
        self.root = root
        # A distinct empty set per node (never share one set between keys).
        self.adj_list = {node: set() for node in nodes}


def build_tree(parents, s0):
    """Turn a parent list into a ``Tree`` rooted at ``s0``.

    Args:
        parents: ``parents[n]`` is the parent id of node ``n``; the entry for
            the root ``s0`` must be None.
        s0: Id of the root node.

    Returns:
        Tree: nodes ``0 .. len(parents) - 1``, where ``adj_list[p]`` is the
        set of children of ``p``.
    """
    assert parents[s0] is None

    tree = Tree(s0, range(len(parents)))
    children_of = tree.adj_list
    for child, parent in enumerate(parents):
        if child == s0:
            continue  # the root has no parent to attach to
        children_of[parent].add(child)
    return tree


In [14]:
import os
import glob

# Process every topology image. glob's own ordering is arbitrary, so sort the
# paths to get a deterministic, reproducible sequence of plots.
for image_path in sorted(glob.glob(os.path.join(os.path.abspath("topologies"), '*.png'))):
    image = io.imread(image_path)
    image_shapes, num_classes, shape_to_color_map = color_segment(image)
    touch_graph = build_touch_graph(image_shapes, num_classes)
    # Assumes the top-left pixel belongs to the outermost shape (the border).
    s0 = int(image_shapes[0, 0])
    levels, parents = find_levels_and_parents(touch_graph, s0)
    tree = build_tree(parents, s0)
    plot_touch_graph(image, touch_graph, levels, tree, shape_to_color_map)