## Introduction
Some networks have many skip connections between nodes on the same horizontal coordinate. The lines representing these connections often overlap, creating confusion.

The purpose of this notebook is to develop an algorithm that determines appropriate x-coordinates for the vertical part of every skip connection, so that the drawing is as clear as possible.

## Design
* Each x-coordinate containing nodes will be treated separately, as skip connections on different x-coordinates don't overlap
* Skip connection lines will be placed between offsets of -0.4 and 0.4 from the x-coordinate of the nodes, because each integer x-coordinate is expected to contain its own nodes
* At each x-coordinate, skip connections will first be divided into two sets, namely those drawn on the left side of the nodes and those drawn on the right side. After this, the two sides will be considered separately

### Step 1: Create potential overlap matrix
* Square matrix `A` with number of rows and columns equal to number of skip connections
* Each row of `A` represents a skip connection already placed
* Each column of `A` represents a skip connection about to be placed
* `A[i,j]` is the number of intersections when connection `i` is placed **inside** connection `j`
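  * `A` will contain 0's, 1's and 2's. The 1's are symmetric under transposition, except when two connections share an end node: a connection from level 0 to 4 and one from level 2 to 4 give `A[i,j] = 1` but `A[j,i] = 0`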
  * Connections with many 2's in **rows** should typically be placed far from the nodes. 
  * Connections with low total row counts should typically be placed close to the nodes
  * The sum over a row in `A` is the cost (number of intersections) of placing the connection closest to the nodes 
  * The sum over a column in `A` is the cost (number of intersections) of placing the connection furthest from the nodes 
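
As a small worked example of these definitions, the sketch below builds `A` for three hypothetical connections given as `(from_y, to_y)` pairs. It uses plain Python lists instead of NumPy so that it runs on its own, and the name `overlap_matrix` is an illustration, not part of `idlmav`.

```python
from typing import List, Tuple

def overlap_matrix(spans: List[Tuple[int, int]]) -> List[List[int]]:
    """A[i][j] = number of end points of connection j strictly inside the span of connection i."""
    n = len(spans)
    A = [[0] * n for _ in range(n)]
    for i, (lo, hi) in enumerate(spans):
        for j, (y_from, y_to) in enumerate(spans):
            if i == j:
                continue
            A[i][j] = int(lo < y_from < hi) + int(lo < y_to < hi)
    return A

# Hypothetical connections: a long one, one nested inside it, one straddling its lower end
spans = [(0, 6), (2, 4), (3, 8)]
A = overlap_matrix(spans)
# A == [[0, 2, 1],
#       [0, 0, 1],
#       [1, 1, 0]]
row_sums = [sum(row) for row in A]        # [3, 1, 2]: cost of being closest to the nodes
col_sums = [sum(col) for col in zip(*A)]  # [1, 3, 2]: cost of being furthest from the nodes
```

Connection 0 has a 2 in its row, so it is a candidate for the outside, while connection 1 has the lowest row sum and is a candidate for the inside.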

### Step 2: Assign skip connections to sides
* Define 3 matrices of the same size: `A` from the previous step and `AL` and `AR` for the left and right sides, separately
  * `A` is kept constant, whereas `AL` and `AR` are both initialized to zeros and populated from `A` as connections are added to each side
  * Once a connection is added to a side, the corresponding row is copied from `A` to the relevant side matrix
  * 2's in the matrices are ignored (set to zero) for this step.
* The immediate cost of adding a connection to a specific side is the sum over the corresponding column of that side's matrix 
  * The **difference** between a column sum in `AL` and `AR` is the **difference in cost** between adding a new connection to the left vs right
* Algorithm:
  * Start by adding the connection with the highest column sum in `A`, to `AL`
  * Proceed greedily by checking which connection has the largest **difference in cost** between `AL` and `AR`. Assign that connection to the lowest-cost side
  * If multiple connections are tied in cost difference, assign the one with the **smallest** total column sum in `A` first
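
The greedy assignment above can be sketched as follows. This is a minimal illustration on plain Python lists, not the notebook's implementation. Three choices are assumptions: the starting connection is picked from the column sums with 2's zeroed, ties in cost between the two sides go to the right, and remaining ties are broken by lowest index. The name `assign_sides` is hypothetical.

```python
from typing import List, Tuple

def assign_sides(A: List[List[int]]) -> Tuple[List[int], List[int]]:
    """Greedy left/right split; returns (left_idxs, right_idxs)."""
    n = len(A)
    left: List[int] = []
    right: List[int] = []
    if n == 0:
        return left, right
    # 2's are ignored (set to zero) for this step
    A1 = [[1 if v == 1 else 0 for v in row] for row in A]
    col_sums = [sum(col) for col in zip(*A1)]

    def cost(side: List[int], j: int) -> int:
        # Column sum of the side matrix: only rows already on that side contribute
        return sum(A1[i][j] for i in side)

    first = max(range(n), key=lambda j: col_sums[j])
    left.append(first)
    unassigned = set(range(n)) - {first}
    while unassigned:
        # Largest absolute cost difference first; ties go to the smallest column sum, then index
        j = min(unassigned, key=lambda k: (-abs(cost(left, k) - cost(right, k)), col_sums[k], k))
        if cost(right, j) <= cost(left, j):
            right.append(j)
        else:
            left.append(j)
        unassigned.remove(j)
    return left, right

A = [[0, 2, 1],
     [0, 0, 1],
     [1, 1, 0]]
left_idxs, right_idxs = assign_sides(A)
# left_idxs == [2], right_idxs == [0, 1]
```

Connection 2 crosses both of the others once, so it ends up alone on the left, and the two nested connections share the right side.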

### Step 3: Determine x-coordinates of skip connections
* Treat the left and right sides separately. Their problems are now independent
  * For each side, start by copying the relevant rows and columns from `A` to `B`, which starts as a copy and becomes zero as this step progresses
* In a separate matrix `C`, keep track of where lines are drawn to avoid using a segment twice
  * `C[i,j]` is non-zero if there is a line segment that skips the `i'th` level at the `j'th` x-offset grid line
  * `C` starts as a single column and grows as grid lines are added in the horizontal dimension
  * The spacing of the grid lines is calculated afterwards, when `C` is fully populated
* Start placing connections from the inside out
* When placing a connection:
  * Check `C` to determine the closest possible position to the nodes without using any cell twice
  * Update `C` to reflect the new placement
  * Set both the row and column corresponding to the index of the placed connection to zero in `B`
* In a loop, place connections in the following order. If a connection is placed, restart the loop without completing it.
  * Any unplaced connections with a zero row sum in `B`
  * Any connections with one or more 2's in their columns but no 2's in their rows
  * The connection with the lowest row sum
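
The bookkeeping role of `C` can be sketched as below. This is an illustration under stated assumptions, not the notebook's code: each grid line is stored as a set of occupied levels instead of a matrix column, a connection from `y_from` to `y_to` is assumed to occupy the levels strictly between its end nodes, and only the chosen grid line is marked. The name `place_connection` is hypothetical.

```python
from typing import List, Set

def place_connection(columns: List[Set[int]], y_from: int, y_to: int) -> int:
    """Return the innermost grid line index that is free for all skipped levels, and mark it."""
    skipped = set(range(y_from + 1, y_to))  # levels strictly between the two end nodes
    x = 0
    while True:
        if x == len(columns):
            columns.append(set())           # grow the grid by one line
        if not (columns[x] & skipped):
            columns[x] |= skipped           # no cell is used twice
            return x
        x += 1

columns: List[Set[int]] = []
placement_order = [(1, 3), (0, 4), (4, 6), (0, 6)]  # hypothetical inside-out order
offsets = [place_connection(columns, y0, y1) for y0, y1 in placement_order]
# offsets == [0, 1, 0, 2]
```

The third connection reuses grid line 0 because it skips a level that nothing else on that line occupies, which is what keeps the drawing narrow.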

### Additional ideas
* Test step 2 using graphs with nodes arranged vertically and connections represented with semi-circle arcs. This will ensure that longer connections are drawn wider

## Imports

In [16]:
import sys, importlib
from pathlib import Path
nb_dir = Path.cwd()
project_root = nb_dir if nb_dir.name == "idlmav" else nb_dir.parent
sys.path.append(str(project_root))

In [228]:
from typing import List, Tuple, Set, Type
import numpy as np
from numpy.typing import NDArray
import random

from idlmav.mavtypes import MavNode, MavGraph, MavConnection
from idlmav.static_viewers import SemiCircleViewer
from idlmav.renderers.figure_renderer import FigureRenderer
from idlmav.mavutils import plotly_renderer

In [131]:
def reload_imports():
    # NB update low-level dependencies before high-level dependencies
    importlib.reload(sys.modules['idlmav.mavtypes'])
    importlib.reload(sys.modules['idlmav.mavutils'])
    importlib.reload(sys.modules['idlmav.layout'])
    importlib.reload(sys.modules['idlmav.static_viewers'])
    importlib.reload(sys.modules['idlmav.renderers.renderer_utils'])
    importlib.reload(sys.modules['idlmav.renderers.figure_renderer'])
    global MavNode, MavGraph, MavConnection, SemiCircleViewer, FigureRenderer
    from idlmav.mavtypes import MavNode, MavGraph, MavConnection
    from idlmav.static_viewers import SemiCircleViewer
    from idlmav.renderers.figure_renderer import FigureRenderer

## Graph generation and viewer

In [24]:
def add_fictional_metadata(nodes:List[MavNode]):
    for n in nodes:
        n.operation = 'sample'
        n.activations = np.random.randint(low=1, high=100, size=(3,)) * 10
        n.params = np.random.randint(low=100, high=1000) * 10
        n.flops = np.random.randint(low=100, high=1000) * 10

In [145]:
def create_sample_graph():
    nodes = [MavNode(str(ni), x=0, y=ni) for ni in range(7)]  # Straight vertical line
    add_fictional_metadata(nodes)
    connections = [
        MavConnection(nodes[0], nodes[1]),
        MavConnection(nodes[0], nodes[2]),
        MavConnection(nodes[1], nodes[3]),
        MavConnection(nodes[1], nodes[4]),
        MavConnection(nodes[2], nodes[5]),
        MavConnection(nodes[2], nodes[6]),
        MavConnection(nodes[3], nodes[5]),
        MavConnection(nodes[4], nodes[5]),
        MavConnection(nodes[5], nodes[6]),
    ]
    return MavGraph(nodes=nodes, connections=connections)

In [87]:
def create_random_sample_graph(num_nodes, num_connections):
    nodes = [MavNode(str(ni), x=0, y=ni) for ni in range(num_nodes)]  # Straight vertical line
    add_fictional_metadata(nodes)
    connection_tuples: List[Tuple[int, int]] = []  # (from_index, to_index) pairs already used
    connections: List[MavConnection] = []
    num_attempts = 0
    while len(connections) < num_connections and num_attempts < num_connections*10:
        n1 = random.randint(0, num_nodes-1)
        n2 = random.randint(0, num_nodes-1)
        if n1>n2: n1, n2 = n2, n1  # Ensure n1 <= n2
        reject = n1==n2 or (n1,n2) in connection_tuples
        if not reject:
            connection_tuples.append((n1,n2))
            connections.append(MavConnection(nodes[n1], nodes[n2]))
        num_attempts += 1
    connections = sorted(connections, key=lambda c: c.from_node.y*100 + c.to_node.y)
    return MavGraph(nodes=nodes, connections=connections)

In [63]:
reload_imports()

In [65]:
g1 = create_sample_graph()
SemiCircleViewer(g1).draw()

In [88]:
random.seed(2)
g2 = create_random_sample_graph(10,11)
SemiCircleViewer(g2).draw()

## Step 1: Create potential overlap matrix
* Square matrix `A` with number of rows and columns equal to number of skip connections
* Each row of `A` represents a skip connection already placed
* Each column of `A` represents a skip connection about to be placed
* `A[i,j]` is the number of intersections when connection `i` is placed **inside** connection `j`
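  * `A` will contain 0's, 1's and 2's. The 1's are symmetric under transposition, except when two connections share an end node: a connection from level 0 to 4 and one from level 2 to 4 give `A[i,j] = 1` but `A[j,i] = 0`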
  * Connections with many 2's in **rows** should typically be placed far from the nodes. 
  * Connections with low total row counts should typically be placed close to the nodes
  * The sum over a row in `A` is the cost (number of intersections) of placing the connection closest to the nodes 
  * The sum over a column in `A` is the cost (number of intersections) of placing the connection furthest from the nodes 

In [50]:
def is_skip_connection(c:MavConnection, g:MavGraph):
    n0, n1 = c.from_node, c.to_node
    x0, y0, x1, y1 = n0.x, n0.y, n1.x, n1.y
    if x0 != x1: return False
    nodes_on_line = [n for n in g.nodes if n.x == x0]  # Perform 1st check on all nodes
    nodes_on_segment = [n for n in nodes_on_line if n.y > y0 and n.y < y1]  # Perform 2nd and 3rd checks on subset of nodes
    return True if nodes_on_segment else False
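
To illustrate what counts as a skip connection, the sketch below applies the same test to stand-in nodes. The `Node` class and the `skips_a_node` helper are hypothetical substitutes for `MavNode` and `is_skip_connection`, used only so that the example runs without `idlmav`.

```python
from dataclasses import dataclass
from typing import List

@dataclass
class Node:  # hypothetical stand-in for MavNode, keeping only the coordinates
    x: int
    y: int

def skips_a_node(n0: Node, n1: Node, nodes: List[Node]) -> bool:
    """True if n0 and n1 share an x-coordinate and another node lies strictly between them."""
    if n0.x != n1.x:
        return False
    return any(n.x == n0.x and n0.y < n.y < n1.y for n in nodes)

nodes = [Node(0, y) for y in range(4)] + [Node(1, 2)]
adjacent = skips_a_node(nodes[0], nodes[1], nodes)   # False: no node between y=0 and y=1
skipping = skips_a_node(nodes[0], nodes[2], nodes)   # True: the node at y=1 is skipped
diagonal = skips_a_node(nodes[0], nodes[4], nodes)   # False: different x-coordinates
```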

In [89]:
g = g2
targ_x = 0  # Value at which to process the vertical connections
skip_connections = [c for c in g.connections if c.from_node.x == targ_x and c.to_node.x == targ_x and is_skip_connection(c, g)]
[(c.from_node.y, c.to_node.y) for c in skip_connections]

[(0, 4), (0, 9), (1, 5), (2, 4), (2, 6), (3, 9), (4, 9), (5, 8), (6, 8)]

In [90]:
num_skip_connections = len(skip_connections)
A = np.zeros((num_skip_connections, num_skip_connections))

for i0,c0 in enumerate(skip_connections):
    y0_from, y0_to = c0.from_node.y, c0.to_node.y
    for i1,c1 in enumerate(skip_connections):
        if i1==i0: continue
        y1_from, y1_to = c1.from_node.y, c1.to_node.y
        if y1_from > y0_from and y1_from < y0_to: A[i0,i1] += 1
        if y1_to > y0_from and y1_to < y0_to: A[i0,i1] += 1
A


array([[0., 0., 1., 1., 1., 1., 0., 0., 0.],
       [1., 0., 2., 2., 2., 1., 1., 2., 2.],
       [1., 0., 0., 2., 1., 1., 1., 0., 0.],
       [0., 0., 0., 0., 0., 1., 0., 0., 0.],
       [1., 0., 1., 1., 0., 1., 1., 1., 0.],
       [1., 0., 1., 1., 1., 0., 1., 2., 2.],
       [0., 0., 1., 0., 1., 0., 0., 2., 2.],
       [0., 0., 0., 0., 1., 0., 0., 0., 1.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0.]])
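
As a quick reading aid for the matrix above, the sketch below copies the printed values into plain Python lists (as integers) and computes the row and column sums described in the bullets. The variable names are illustrative.

```python
spans = [(0, 4), (0, 9), (1, 5), (2, 4), (2, 6), (3, 9), (4, 9), (5, 8), (6, 8)]
A = [
    [0, 0, 1, 1, 1, 1, 0, 0, 0],
    [1, 0, 2, 2, 2, 1, 1, 2, 2],
    [1, 0, 0, 2, 1, 1, 1, 0, 0],
    [0, 0, 0, 0, 0, 1, 0, 0, 0],
    [1, 0, 1, 1, 0, 1, 1, 1, 0],
    [1, 0, 1, 1, 1, 0, 1, 2, 2],
    [0, 0, 1, 0, 1, 0, 0, 2, 2],
    [0, 0, 0, 0, 1, 0, 0, 0, 1],
    [0, 0, 0, 0, 0, 0, 0, 0, 0],
]
row_sums = [sum(row) for row in A]        # [4, 13, 6, 1, 6, 9, 6, 2, 0]
col_sums = [sum(col) for col in zip(*A)]  # [4, 0, 6, 7, 7, 5, 4, 7, 7]

# Connections that cause no intersections at the innermost / outermost position
innermost_free = [spans[i] for i, s in enumerate(row_sums) if s == 0]  # [(6, 8)]
outermost_free = [spans[j] for j, s in enumerate(col_sums) if s == 0]  # [(0, 9)]
```

The short connection `(6, 8)` can sit next to the nodes at no cost, and the longest connection `(0, 9)` can sit on the outside at no cost, which matches the intuition in the bullets.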

In [91]:
reload_imports()

In [93]:
SemiCircleViewer(g).draw(300)

## Step 2: Assign sides
* Define 3 matrices of the same size: `A` from the previous step and `AL` and `AR` for the left and right sides, separately
  * `A` is kept constant, whereas `AL` and `AR` are both initialized to zeros and populated from `A` as connections are assigned to each side
  * Once a connection is assigned to a side, the corresponding row is copied from `A` to the relevant side matrix
  * 2's in the matrices are ignored (set to zero) for this step.
* The immediate cost of assigning a connection to a specific side is the sum over the corresponding column of that side's matrix 
  * The **difference** between a column sum in `AL` and `AR` is the **difference in cost** between assigning a new connection to the left vs right
* Algorithm
  * Start by assigning the connection with the highest column sum in `A`, to `AL`
  * Proceed greedily by checking which connection has the largest **difference in cost** between `AL` and `AR`. Assign that connection to the lowest-cost side
  * If multiple connections are tied in cost difference, assign the one with the **smallest** total column sum in `A` first

In [None]:
A1 = A.copy()
A1[A1==2] = 0  # Consider only 1's
AL = np.zeros(A.shape)
AR = np.zeros(A.shape)
left_idxs:List[int] = []
right_idxs:List[int] = []
unassigned_idxs = set([i for i in range(num_skip_connections)])

def step2_cost(A:NDArray, idx:int):
    return A[:,idx].sum()

def step2_costdiff(AL:NDArray, AR:NDArray, idx:int):
    return step2_cost(AL,idx) - step2_cost(AR,idx)

def step2_assign(idx:int, to_right:bool, A1:NDArray, AL:NDArray, AR:NDArray, left_idxs:List[int], right_idxs:List[int], unassigned_idxs:Set[int]):
    if to_right:
        AR[idx,:] = A1[idx,:]
        right_idxs.append(idx)  # Record the index on the same side as the matrix that was updated
    else:
        AL[idx,:] = A1[idx,:]
        left_idxs.append(idx)
    unassigned_idxs.remove(idx)

col_sums = A1.sum(axis=0)
idx = col_sums.argmax()
step2_assign(idx, False, A1, AL, AR, left_idxs, right_idxs, unassigned_idxs)
while unassigned_idxs:
    cost_diffs = AR.sum(axis=0) - AL.sum(axis=0)
    abs_cost_diffs:NDArray = np.abs(cost_diffs)
    abs_cost_diffs[left_idxs + right_idxs] = -1
    idx = abs_cost_diffs.argmax()  # TODO: tie-break on smallest total column sum in A??
    to_right = cost_diffs[idx] <= 0
    step2_assign(idx, to_right, A1, AL, AR, left_idxs, right_idxs, unassigned_idxs)

for idx in left_idxs: skip_connections[idx].offset = -1
for idx in right_idxs: skip_connections[idx].offset = 1
SemiCircleViewer(g).draw(300)


## Step 3: Determine x-coordinates of skip connections
* Treat the left and right sides separately. Their problems are now independent
  * For each side, start by copying the relevant rows and columns from `A` to `B`, which starts as a copy and becomes zero as this step progresses
* In a separate matrix `C`, keep track of where lines are drawn to avoid using a segment twice
  * `C[i,j]` is non-zero if there is a line segment that skips the `i'th` level at the `j'th` x-offset grid line
  * `C` starts as a single column and grows as grid lines are added in the horizontal dimension
  * The spacing of the grid lines is calculated afterwards, when `C` is fully populated
* Start placing connections from the inside out
* When placing a connection:
  * Check `C` to determine the closest possible position to the nodes without using any cell twice
  * Update `C` to reflect the new placement
  * Set both the row and column corresponding to the index of the placed connection to zero in `B`
* In a loop, place connections in the following order. If a connection is placed, restart the loop without completing it.
  * Any unplaced connections with a zero row sum in `B`
  * Any connections with one or more 2's in their columns but no 2's in their rows
  * The connection with the lowest row sum

In [227]:
def step3_ensure_width(C:NDArray, w:int):
    if C.shape[1] >= w: return C
    return np.append(C, np.zeros((C.shape[0], w-C.shape[1])), axis=1)

def step3_get_idx_to_place(B:NDArray, placed_idxs:List[int], unplaced_idxs:Set[int]):
    # Check for any unplaced index with a zero row sum
    row_sums = B.sum(axis=1)
    row_sums[placed_idxs] = 999999
    zero_row_sum_idxs = (row_sums==0).nonzero()[0]
    if len(zero_row_sum_idxs) > 0: return zero_row_sum_idxs[0]

    # Check for any index with 2's in columns, but not in rows
    # * Placed indices are already zeroed out
    col_2_counts = (B==2).sum(axis=0)
    row_2_counts = (B==2).sum(axis=1)  # Count 2's per row (axis=1); axis=0 would repeat the column counts
    diff_2_counts = col_2_counts - row_2_counts
    diff_2_counts[row_2_counts > 0] = 0
    diff_2_counts[placed_idxs] = 0
    if (diff_2_counts>0).any(): return diff_2_counts.argmax()

    # If this point is reached, place the unplaced connection with the lowest row sum
    return row_sums.argmin()

def step3_place(B:NDArray, C:NDArray, placed_idxs:List[int], unplaced_idxs:Set[int], idx:int, c:MavConnection, side_factor):
    y0, y1 = c.from_node.y+1, c.to_node.y-1
    x = 0  # Offset (grid line index) at which to place connection
    C = step3_ensure_width(C, x+1)
    occupied = C[y0:y1+1,x].any()  # Check only the candidate grid line
    while occupied:
        x += 1
        C = step3_ensure_width(C, x+1)
        occupied = C[y0:y1+1,x].any()
    B[idx,:] = 0
    B[:,idx] = 0
    C[y0:y1+1,x] = 1  # Mark only the grid line that was used, so C[i,j] matches its definition
    c.offset = side_factor * (x+1)  # side_factor is -1 for left or +1 for right hand side
    placed_idxs.append(idx)
    unplaced_idxs.remove(idx)
    return C  # Return the only object that may have been copied. Others were modified in-place

for side in range(2):
    if side==0:
        side_connections = [skip_connections[i] for i in left_idxs]
        B = A.copy()[left_idxs,:][:,left_idxs]
        side_factor = -1
    else:
        side_connections = [skip_connections[i] for i in right_idxs]
        B = A.copy()[right_idxs,:][:,right_idxs]
        side_factor = 1
    num_side_connections = len(side_connections)
    if num_side_connections == 0: continue  # Nothing to place on this side; max() below would fail
    max_y = max([c.to_node.y for c in side_connections])
    C = np.zeros((max_y+1,1))

    placed_idxs:List[int] = []
    unplaced_idxs = set([i for i in range(num_side_connections)])
    while unplaced_idxs:
        idx = step3_get_idx_to_place(B, placed_idxs, unplaced_idxs)
        C = step3_place(B, C, placed_idxs, unplaced_idxs, idx, side_connections[idx], side_factor)

# Normalize
max_abs_offset = max([abs(c.offset) for c in skip_connections])
scale_factor = 0.4 / max_abs_offset
for c in skip_connections: c.offset = c.offset*scale_factor

In [230]:
def step3_calc_offsets(A:NDArray, skip_connections:List[MavConnection], left_idxs:List[int], right_idxs:List[int]):
    for side in range(2):
        if side==0:
            side_connections = [skip_connections[i] for i in left_idxs]
            B = A.copy()[left_idxs,:][:,left_idxs]
            side_factor = -1
        else:
            side_connections = [skip_connections[i] for i in right_idxs]
            B = A.copy()[right_idxs,:][:,right_idxs]
            side_factor = 1
        num_side_connections = len(side_connections)
        if num_side_connections == 0: continue  # Nothing to place on this side; max() below would fail
        max_y = max([c.to_node.y for c in side_connections])
        C = np.zeros((max_y+1,1))

        placed_idxs:List[int] = []
        unplaced_idxs = set([i for i in range(num_side_connections)])
        while unplaced_idxs:
            idx = step3_get_idx_to_place(B, placed_idxs, unplaced_idxs)
            C = step3_place(B, C, placed_idxs, unplaced_idxs, idx, side_connections[idx], side_factor)

    # Normalize
    max_abs_offset = max([abs(c.offset) for c in skip_connections])
    scale_factor = 0.4 / max_abs_offset
    for c in skip_connections: c.offset = c.offset*scale_factor
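
The normalization step on its own can be sketched as below, using a plain list of integer grid offsets instead of `MavConnection` objects. The function name `normalize_offsets` and the guard for an empty or all-zero input are additions for illustration and are not in the cell above.

```python
from typing import List

def normalize_offsets(offsets: List[float], max_offset: float = 0.4) -> List[float]:
    """Scale signed grid offsets so that the largest magnitude becomes max_offset."""
    largest = max((abs(o) for o in offsets), default=0.0)
    if largest == 0:
        return list(offsets)  # nothing to scale; avoids dividing by zero
    scale = max_offset / largest
    return [o * scale for o in offsets]

# Two connections on the left (negative) and three on the right (positive)
scaled = normalize_offsets([-2, -1, 1, 2, 3])
```

Because both sides share one scale factor, the side with fewer grid lines does not reach the full offset of 0.4, and the spacing between lines is the same on the left and the right.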

In [231]:
%timeit step3_calc_offsets(A, skip_connections, left_idxs, right_idxs)

351 μs ± 18.1 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [232]:
reload_imports()

In [233]:
fig = FigureRenderer(g).render(add_slider=True, num_levels_displayed=len(g.nodes))
fig.update_xaxes(range=[-0.5, 0.5])
fig.show(renderer='notebook_connected')

### Exploratory code

In [206]:
M = np.random.randint(0,3,(5,5))
M

array([[1, 0, 0, 1, 1],
       [1, 2, 0, 1, 2],
       [1, 2, 0, 2, 0],
       [2, 0, 1, 2, 1],
       [2, 1, 0, 0, 2]])

In [208]:
col_2_counts = (M==2).sum(axis=0)
row_2_counts = (M==2).sum(axis=1)
diff_2_counts = col_2_counts - row_2_counts
diff_2_counts[row_2_counts > 0] = 0
diff_2_counts

array([2, 0, 0, 0, 0])

In [209]:
diff_2_counts.argmax()

0