## Topological Visualizations

### Basic 3D cube visualization

In [None]:
#Import libraries
import networkx as nx
import plotly.graph_objects as go
import numpy as np
from lxml import etree

#Create a directed Graph object
G = nx.DiGraph()

#Define dimensions of the Hypercube (a length x width x depth grid)
length, width, depth = 2,2,2

# Visit grid points in (i, j, k) lexicographic order and link each one to its
# +1 neighbour along every axis that stays inside the grid.  Each link is a
# pair of directed edges (send and receive), so the graph is symmetric.
for grid_point in np.ndindex(length, width, depth):
    G.add_node(grid_point)
    i, j, k = grid_point
    for neighbour, inside in (
        ((i + 1, j, k), i + 1 < length),   # horizontal
        ((i, j + 1, k), j + 1 < width),    # vertical
        ((i, j, k + 1), k + 1 < depth),    # depth wise
    ):
        if inside:
            G.add_edge(grid_point, neighbour)
            G.add_edge(neighbour, grid_point)

#Steps for plotting onto Plotly

# Nodes logic: one blue marker per grid point.  Row order of `nodes`
# follows the iteration order of G.nodes().
nodes = np.array(G.nodes())
scatter = go.Scatter3d(
    x=nodes[:, 0], y=nodes[:, 1], z=nodes[:, 2],
    mode='markers',
    marker=dict(size=10, color='blue')
)

# Edge logic: turn every directed edge into (start_row, end_row) indices
# into `nodes`.  The node -> row map is built once, so each edge is resolved
# with a dict lookup instead of a scan over all nodes.
node_index = {node: row for row, node in enumerate(G.nodes())}
edges = [(node_index[src], node_index[dst]) for src, dst in G.edges()]

# Arrow logic: a thin black shaft plus a cone whose tip sits on the end node
arrows = []
for start_row, end_row in edges:
    start_node = nodes[start_row]
    end_node = nodes[end_row]

    # Line for the arrow shaft
    line = go.Scatter3d(
        x=[start_node[0], end_node[0]],
        y=[start_node[1], end_node[1]],
        z=[start_node[2], end_node[2]],
        mode='lines',
        line=dict(color='black', width=1)
    )

    # Add an arrowhead as a cone at the end of each edge, pointing along it
    arrow = go.Cone(
        x=[end_node[0]], y=[end_node[1]], z=[end_node[2]],
        u=[end_node[0] - start_node[0]],
        v=[end_node[1] - start_node[1]],
        w=[end_node[2] - start_node[2]],
        sizemode="absolute",
        sizeref=0.2,
        anchor="tip",
        showscale=False,
        colorbar=None,
        colorscale=None
    )

    # Add both the shaft and the arrowhead to the list
    arrows.append(line)
    arrows.append(arrow)

# Set up the layout for 3D view
layout = go.Layout(
    scene=dict(
        xaxis=dict(range=[-2, 2], title='X'),
        yaxis=dict(range=[-2, 2], title='Y'),
        zaxis=dict(range=[-2, 2], title='Z'),
        aspectratio=dict(x=1, y=1, z=1)
    )
)

# Create the figure with initial frames
fig = go.Figure(data=[scatter] + arrows, layout=layout)

fig.show()

### Parsing the XML into per-timestep data

In [26]:
from lxml import etree
import networkx as nx
import plotly.graph_objects as go
import numpy as np
import dash
from dash import dcc, html

# Nested mapping built from the XML schedule:
#   timestep_data[step_id][gpu_id][tb_id] -> {"pkt_type", "tb_gpuid_send",
#                                            "tb_gpuid_recv", "chunk_id"}
# All ids and values are the attribute strings exactly as read from the XML.
timestep_data = {}

tree = etree.parse('gpu_3_link_500_bw_50_chunk_1024_chunk_coll_2.xml')

root = tree.getroot()

for gpu in root.findall("gpu"):
    gpu_id = gpu.get("id")
    for tb in gpu.findall("tb"):
        tb_id = tb.get("id") #thread block ID
        tb_gpuid_send = tb.get("send") #GPU this threadblock sends to (used as a GPU index when plotting)
        tb_gpuid_recv = tb.get("recv") #GPU this threadblock receives from (used as a GPU index when plotting)
        for step in tb.findall("step"):
            step_id = step.get("s") #step ID number
            # The step's dependency attributes (has_dep / depid / deps) are not
            # read: the plots only need the transfer itself.
            record = {
                "pkt_type": step.get("type"),  # packet type; "s" (send) and "r" (receive) are plotted
                "tb_gpuid_send": tb_gpuid_send,
                "tb_gpuid_recv": tb_gpuid_recv,
                "chunk_id": step.get("srcoff"),  # source offset, used as the chunk id
            }
            # Create the step / gpu levels on demand.  If the same
            # (step, gpu, threadblock) triple occurs twice, the first record wins.
            timestep_data.setdefault(step_id, {}).setdefault(gpu_id, {}).setdefault(tb_id, record)

### Combining the basic 3D topology visualization with the per-timestep data (one hypercube figure per timestep)

In [None]:
#Dimensions for the topology
length, width, depth = 3,3,3


def create_hypercube():
    """
    Build the length x width x depth grid topology as a directed graph.

    Nodes are ``(i, j, k)`` tuples (row, column, layer).  Each node carries a
    ``label`` of the form ``"gpu : N"``, where N counts layer by layer, then
    row by row, then column by column.  Nodes are also inserted in that order,
    so the N-th node of ``G.nodes()`` is GPU N.  Every pair of nodes that
    differ by one along a single axis is joined by two directed edges, one
    each way.

    The dimensions are read from the module-level ``length``, ``width`` and
    ``depth`` at call time.  Despite the name, the result is a non-wrapping
    3D mesh rather than a hypercube.
    """
    graph = nx.DiGraph()

    def link(a, b):
        # One bidirectional link = a send edge plus a receive edge
        graph.add_edge(a, b)
        graph.add_edge(b, a)

    for k in range(depth):
        for i in range(length):
            for j in range(width):
                graph.add_node((i, j, k), label=f"gpu : {(k * length + i) * width + j}")
                if j + 1 < width:
                    link((i, j, k), (i, j + 1, k))   # right-hand neighbour
            if i + 1 < length:
                for j in range(width):
                    link((i, j, k), (i + 1, j, k))   # next row, same layer
        if k + 1 < depth:
            for i in range(length):
                for j in range(width):
                    link((i, j, k), (i, j, k + 1))   # next layer
    return graph


# Iterate through each timestep, and create a hypercube per timestep.
# Draw one figure per timestep: every GPU as a node, plus one red arrow per
# distinct (source gpu, destination gpu, chunk) transfer recorded for that step.
for timestep, time_data in timestep_data.items():
    G = create_hypercube()

    #define nodes (row N of `nodes` is the N-th node of G, i.e. GPU N)
    nodes = np.array(G.nodes())

    #Plot all the nodes using plotly, one hover label per drawn node
    scatter = go.Scatter3d(
        x=nodes[:, 0], y=nodes[:, 1], z=nodes[:, 2],
        mode='markers',
        marker=dict(size=10, color='blue'),
        hoverinfo='text',
        hovertext=[f"gpu : {gpu_id}" for gpu_id in range(len(nodes))]
    )

    #Edge logic: one (source gpu, destination gpu, chunk) triple per send/receive step
    edges = []
    for gpu_id, gpu_data in time_data.items():
        for tb_data in gpu_data.values():
            if tb_data['pkt_type'] == 's':
                edges.append((int(gpu_id), int(tb_data['tb_gpuid_send']), int(tb_data['chunk_id'])))
            elif tb_data['pkt_type'] == 'r':
                edges.append((int(tb_data['tb_gpuid_recv']), int(gpu_id), int(tb_data['chunk_id'])))
    # A transfer may be recorded by both its send step and its receive step;
    # drop exact duplicates (keeping first-seen order) so each arrow is drawn once.
    edges = list(dict.fromkeys(edges))

    arrows = []
    for src, dst, chunk_id in edges:
        start_node = nodes[src]
        end_node = nodes[dst]

        # Line for the arrow shaft
        line = go.Scatter3d(
            x=[start_node[0], end_node[0]],
            y=[start_node[1], end_node[1]],
            z=[start_node[2], end_node[2]],
            mode='lines',
            line=dict(color='red', width=2),
            hoverinfo='text',
            hovertext=f"Chunk_id : {chunk_id}"
        )

        # Add an arrowhead as a cone at the end of each edge
        arrow = go.Cone(
            x=[end_node[0]], y=[end_node[1]], z=[end_node[2]],
            u=[end_node[0] - start_node[0]],
            v=[end_node[1] - start_node[1]],
            w=[end_node[2] - start_node[2]],
            sizemode="absolute",
            sizeref=0.2,
            anchor="tip",
            showscale=False,
            colorbar=None,
            colorscale=None
        )

        # Add both the shaft and the arrowhead to the list
        arrows.append(line)
        arrows.append(arrow)

    # Set up the layout for 3D view; the title tells the figures apart
    layout = go.Layout(
        title=f"Time step = {timestep}",
        scene=dict(
            xaxis=dict(range=[-3, 3], title='X'),
            yaxis=dict(range=[-3, 3], title='Y'),
            zaxis=dict(range=[-3, 3], title='Z'),
            aspectratio=dict(x=1, y=1, z=1)
        )
    )

    fig = go.Figure(data=[scatter] + arrows, layout=layout)

    fig.show()

### Adding Next and Previous buttons to step through the hypercube figure for each timestep

In [None]:
import plotly.graph_objects as go
import numpy as np
import networkx as nx
import ipywidgets as widgets
from IPython.display import display, clear_output

length, width, depth = 3, 3, 3

# Function to create the hypercube graph (a non-wrapping 3D mesh)
def create_hypercube():
    """
    Return the length x width x depth grid as a bidirectional ``nx.DiGraph``.

    Nodes are ``(i, j, k)`` tuples inserted layer-major (k, then i, then j)
    and labelled ``"gpu : N"`` in that same order, so the N-th node is GPU N.
    Axis-adjacent nodes are joined by a directed edge in each direction.
    Reads the module-level ``length``/``width``/``depth`` when called.
    """
    G = nx.DiGraph()

    # Pass 1: every node, in GPU-number order.
    coords = [(i, j, k) for k in range(depth) for i in range(length) for j in range(width)]
    for gpu_number, coord in enumerate(coords):
        G.add_node(coord, label=f"gpu : {gpu_number}")

    # Pass 2: links, added in the same layer -> row -> column sequence.
    for k in range(depth):
        for i in range(length):
            for j in range(width - 1):
                G.add_edge((i, j, k), (i, j + 1, k))
                G.add_edge((i, j + 1, k), (i, j, k))
            if i + 1 < length:
                for j in range(width):
                    G.add_edge((i, j, k), (i + 1, j, k))
                    G.add_edge((i + 1, j, k), (i, j, k))
        if k + 1 < depth:
            for i in range(length):
                for j in range(width):
                    G.add_edge((i, j, k), (i, j, k + 1))
                    G.add_edge((i, j, k + 1), (i, j, k))

    return G

# Function to create scatter plot for nodes and edges per timestep
def update_plot(timestep):
    """
    Clear the cell output, redraw the navigation buttons and show the grid
    figure for ``timestep``.

    Transfers come from ``timestep_data[str(timestep)]`` (step ids are the
    strings read from the XML).  A step id with no entry is drawn as nodes
    only instead of raising ``KeyError``.  Each distinct (source gpu,
    destination gpu, chunk) transfer is drawn once as a straight black line
    with a cone at its destination.
    """
    clear_output(wait=True)
    display(prev_button, next_button)

    # Initialize the graph data for timestep changes; row N of `nodes` is GPU N
    G = create_hypercube()
    nodes = np.array(G.nodes())

    time_data = timestep_data.get(str(timestep), {})

    # Nodes scatter plot: one hover label per node actually drawn
    scatter = go.Scatter3d(
        x=nodes[:, 0], y=nodes[:, 1], z=nodes[:, 2],
        mode='markers',
        marker=dict(size=10, color='blue'),
        hoverinfo='text',
        hovertext=[f"gpu : {gpu_id}" for gpu_id in range(len(nodes))]
    )

    # Edges: one (source gpu, destination gpu, chunk) triple per send/receive step
    edges = []
    for gpu_id, gpu_data in time_data.items():
        for tb_data in gpu_data.values():
            if tb_data['pkt_type'] == 's':
                edges.append((int(gpu_id), int(tb_data['tb_gpuid_send']), int(tb_data['chunk_id'])))
            elif tb_data['pkt_type'] == 'r':
                edges.append((int(tb_data['tb_gpuid_recv']), int(gpu_id), int(tb_data['chunk_id'])))
    # Drop exact duplicates (e.g. a transfer seen from both its send and its
    # receive step) while keeping first-seen order.
    edges = list(dict.fromkeys(edges))

    arrows = []
    for src, dst, chunk_id in edges:
        start_node = nodes[src]
        end_node = nodes[dst]

        line = go.Scatter3d(
            x=[start_node[0], end_node[0]],
            y=[start_node[1], end_node[1]],
            z=[start_node[2], end_node[2]],
            mode='lines',
            line=dict(color='black', width=2),
            hoverinfo='text',
            hovertext=f"Chunk_id : {chunk_id}"
        )

        # Arrowhead: cone with its tip on the destination, pointing along the edge
        arrow = go.Cone(
            x=[end_node[0]], y=[end_node[1]], z=[end_node[2]],
            u=[end_node[0] - start_node[0]],
            v=[end_node[1] - start_node[1]],
            w=[end_node[2] - start_node[2]],
            sizemode="absolute",
            sizeref=0.2,
            anchor="tip",
            showscale=False,
            colorbar=None,
            colorscale=None
        )
        arrows.append(line)
        arrows.append(arrow)

    # 3D Plot Layout
    layout = go.Layout(
        title=f"Time step = {timestep}"
    )

    fig = go.Figure(data=[scatter] + arrows, layout=layout)
    fig.show()

# Timestep navigation: index of the step on screen and the last valid index
current_timestep = 0
max_timestep = len(timestep_data) - 1

def on_next_button_clicked(b):
    """Advance to the next timestep and redraw; do nothing at the last one."""
    global current_timestep
    if current_timestep >= max_timestep:
        return
    current_timestep += 1
    update_plot(current_timestep)

def on_prev_button_clicked(b):
    """Step back to the previous timestep and redraw; do nothing at the first one."""
    global current_timestep
    if current_timestep <= 0:
        return
    current_timestep -= 1
    update_plot(current_timestep)

# Build the two buttons and wire each one to its handler
next_button = widgets.Button(description="Next")
next_button.on_click(on_next_button_clicked)
prev_button = widgets.Button(description="Previous")
prev_button.on_click(on_prev_button_clicked)

# Show the controls and the first timestep
display(prev_button, next_button)
update_plot(current_timestep)

### Replacing the straight line edges with curved edges (Hypercube)

In [None]:
import plotly.graph_objects as go
import numpy as np
import networkx as nx
import ipywidgets as widgets
from IPython.display import display, clear_output

length, width, depth = 6, 6, 6

# Function to create the hypercube graph (a non-wrapping 3D mesh)
def create_hypercube():
    """
    Return the length x width x depth grid as a bidirectional ``nx.DiGraph``.

    Nodes are ``(i, j, k)`` tuples labelled ``"gpu : N"``, numbered and
    inserted layer by layer, row by row, column by column, so the N-th node
    of ``G.nodes()`` is GPU N.  Every pair of axis-adjacent nodes gets one
    directed edge each way.  Reads the module-level ``length``/``width``/
    ``depth`` when called.
    """
    G = nx.DiGraph()
    node_number = 0
    for k in range(depth):
        for i in range(length):
            for j in range(width):
                new_node = (i, j, k)
                G.add_node(new_node, label=f"gpu : {node_number}")
                node_number += 1

                if (j + 1 < width):
                    G.add_edge(new_node, (i, j + 1, k))
                    G.add_edge((i, j + 1, k), new_node)

            # Link row i to row i+1 column by column.  Use (i, j, k) here:
            # `new_node` still holds the row's last node at this point.
            if (i + 1 < length):
                for j in range(width):
                    G.add_edge((i, j, k), (i + 1, j, k))
                    G.add_edge((i + 1, j, k), (i, j, k))

        # Link layer k to layer k+1 node by node
        if (k + 1 < depth):
            for i in range(length):
                for j in range(width):
                    G.add_edge((i, j, k), (i, j, k + 1))
                    G.add_edge((i, j, k + 1), (i, j, k))

    return G

def create_curve(start, end, height_factor=0.2):
    """
    Return a 50 x 3 array of points on a quadratic arc from ``start`` to ``end``.

    At parameter t in [0, 1] the arc sits on the straight chord plus a
    sideways offset of ``4 t (1 - t) * |end - start| * height_factor``.
    Edges whose endpoints share a z coordinate bow along z; all others bow
    along x.  A negative ``height_factor`` bows the arc to the opposite side.
    """
    bump = np.linalg.norm(end - start) * height_factor
    axis = 2 if start[2] == end[2] else 0
    apex = (start + end) / 2
    apex[axis] += bump

    ts = np.linspace(0, 1, 50)
    chord = np.outer(1 - ts, start) + np.outer(ts, end)
    return chord + np.outer(4 * ts * (1 - ts), apex - 0.5 * (start + end))

# Function to create scatter plot for nodes and edges per timestep
def update_plot(timestep):
    """
    Clear the cell output, redraw the navigation buttons and show the 3D mesh
    figure for ``timestep``.

    Transfers come from ``timestep_data[str(timestep)]`` (step ids are the
    strings read from the XML).  A step id with no entry is drawn as nodes
    only instead of raising ``KeyError``.  Each distinct (source gpu,
    destination gpu, chunk) transfer becomes a curved line plus a cone at its
    destination.  When both directions of a link are active, the one seen
    second bends the opposite way so the two arcs do not overlap.
    """
    clear_output(wait=True)
    display(prev_button, next_button)

    # Initialize the graph data for timestep changes; row N of `nodes` is GPU N
    G = create_hypercube()
    nodes = np.array(G.nodes())

    time_data = timestep_data.get(str(timestep), {})

    # Nodes scatter plot: one hover label per node actually drawn
    scatter = go.Scatter3d(
        x=nodes[:, 0], y=nodes[:, 1], z=nodes[:, 2],
        mode='markers',
        marker=dict(size=12, color='blue'),
        hoverinfo='text',
        hovertext=[f"gpu : {gpu_id}" for gpu_id in range(len(nodes))]
    )

    # Edges: one (source gpu, destination gpu, chunk) triple per send/receive step
    edges = []
    for gpu_id, gpu_data in time_data.items():
        for tb_data in gpu_data.values():
            if tb_data['pkt_type'] == 's':
                edges.append((int(gpu_id), int(tb_data['tb_gpuid_send']), int(tb_data['chunk_id'])))
            elif tb_data['pkt_type'] == 'r':
                edges.append((int(tb_data['tb_gpuid_recv']), int(gpu_id), int(tb_data['chunk_id'])))
    # Drop exact duplicates while keeping first-seen order, so which direction
    # of a link gets the positive bend is reproducible.
    edges = list(dict.fromkeys(edges))

    arrows = []
    # Directed links already drawn with a positive bend, still awaiting their reverse
    unpaired = []

    for src, dst, chunk_id in edges:
        start_node = nodes[src]
        end_node = nodes[dst]
        forward = (tuple(start_node), tuple(end_node))
        reverse = (forward[1], forward[0])

        # If the opposite direction is pending, bend this arc the other way
        # and consume that pending entry.
        if reverse in unpaired:
            curve = create_curve(start_node, end_node, -0.05)
            unpaired.remove(reverse)
        else:
            unpaired.append(forward)
            curve = create_curve(start_node, end_node, 0.05)

        arrows.append(go.Scatter3d(
            x=curve[:, 0], y=curve[:, 1], z=curve[:, 2],
            mode='lines',
            line=dict(color='black', width=2),
            hoverinfo='text',
            hovertext=f"Chunk_id : {chunk_id}"
        ))

        # Orient the arrowhead along the curve's final segment
        tangent = curve[-1] - curve[-7]
        arrows.append(go.Cone(
            x=[end_node[0]], y=[end_node[1]], z=[end_node[2]],
            u=[tangent[0]],
            v=[tangent[1]],
            w=[tangent[2]],
            sizemode="absolute",
            sizeref=0.1,  # Adjust the size reference to control arrow size
            anchor="tip",
            showscale=False,
            colorscale="Viridis"
        ))

    # 3D Plot Layout
    layout = go.Layout(
        title=f"Time step = {timestep}"
    )

    fig = go.Figure(data=[scatter] + arrows, layout=layout)
    fig.show()

# Timestep navigation: index of the step on screen and the last valid index
current_timestep = 0
max_timestep = len(timestep_data) - 1

def on_next_button_clicked(b):
    """Show the next timestep, if there is one."""
    global current_timestep
    if current_timestep < max_timestep:
        current_timestep += 1
        update_plot(current_timestep)

def on_prev_button_clicked(b):
    """Show the previous timestep, if there is one."""
    global current_timestep
    if current_timestep > 0:
        current_timestep -= 1
        update_plot(current_timestep)

# Create Next and Previous buttons
next_button = widgets.Button(description="Next")
prev_button = widgets.Button(description="Previous")

# Bind button clicks to functions
next_button.on_click(on_next_button_clicked)
prev_button.on_click(on_prev_button_clicked)

# Display buttons and initial plot
display(prev_button, next_button)
update_plot(current_timestep)

### Torus3D visualization

In [None]:
import plotly.graph_objects as go
import numpy as np
import networkx as nx
import ipywidgets as widgets
from IPython.display import display, clear_output

length, width, depth = 3, 3, 3

# Function to create the 3D torus graph
def create_torus3d():
    """
    Return a length x width x depth 3D torus as a bidirectional ``nx.DiGraph``.

    Nodes are ``(i, j, k)`` tuples labelled ``"gpu : N"`` and inserted in
    GPU-number order (layer, then row, then column).  Every axis wraps
    around: the last node along an axis is also linked to the first.  Each
    link is a directed edge in both directions.  Along an axis of size 1 the
    wrap link is a self-loop; along an axis of size 2 it coincides with the
    ordinary link.
    """
    G = nx.DiGraph()
    node_number = 0
    for k in range(depth):
        for i in range(length):
            for j in range(width):
                new_node = (i, j, k)
                G.add_node(new_node, label=f"gpu : {node_number}")
                node_number += 1

                next_j = (j + 1) % width
                G.add_edge(new_node, (i, next_j, k))
                G.add_edge((i, next_j, k), new_node)

            # Link row i to the next row (wrapping) column by column.  Use
            # (i, j, k): `new_node` still holds the row's last node here.
            next_i = (i + 1) % length
            for j in range(width):
                G.add_edge((i, j, k), (next_i, j, k))
                G.add_edge((next_i, j, k), (i, j, k))

        # Link layer k to the next layer (wrapping) node by node
        next_k = (k + 1) % depth
        for i in range(length):
            for j in range(width):
                G.add_edge((i, j, k), (i, j, next_k))
                G.add_edge((i, j, next_k), (i, j, k))

    return G

def create_curve(start, end, height_factor=0.2):
    """
    Return a 50 x 3 array of points on a quadratic arc from ``start`` to ``end``.

    The arc's midpoint is pushed sideways by
    ``|end - start| * height_factor * span``, where ``span`` is the edge's
    extent along its own axis.  Wrap-around torus links (span > 1) therefore
    bulge further than the unit links they pass over, on every axis:

    * x-axis edges (same y and z) bow along z, scaled by |dx|;
    * y-axis edges (same z, different y) bow along z, scaled by |dy|;
    * z-axis edges bow along x, scaled by |dz|.

    A negative ``height_factor`` bows the arc to the opposite side.
    """
    mid = (start + end) / 2
    span = np.linalg.norm(end - start) * height_factor

    # Add some height to the midpoint, proportional to the edge's axis extent
    if (start[2] == end[2]):
        if (start[1] == end[1]):
            mid[2] += span * abs(end[0] - start[0])
        else:
            mid[2] += span * abs(end[1] - start[1])
    else:
        mid[0] += span * abs(end[2] - start[2])

    t = np.linspace(0, 1, 50)

    curve = np.outer(1-t, start) + np.outer(t, end) + np.outer(4*t*(1-t), mid-0.5*(start+end))
    return curve

# Function to create scatter plot for nodes and edges per timestep
def update_plot(timestep):
    """
    Clear the cell output, redraw the navigation buttons and show the 3D torus
    figure for ``timestep``.

    Transfers come from ``timestep_data[str(timestep)]`` (step ids are the
    strings read from the XML).  A step id with no entry is drawn as nodes
    only instead of raising ``KeyError``.  Each distinct (source gpu,
    destination gpu, chunk) transfer becomes a curved line plus a cone at its
    destination.  When both directions of a link are active, the one seen
    second bends the opposite way so the two arcs do not overlap.
    """
    clear_output(wait=True)
    display(prev_button, next_button)

    # Initialize the graph data for timestep changes; row N of `nodes` is GPU N
    G = create_torus3d()
    nodes = np.array(G.nodes())

    time_data = timestep_data.get(str(timestep), {})

    # Nodes scatter plot: one hover label per node actually drawn
    scatter = go.Scatter3d(
        x=nodes[:, 0], y=nodes[:, 1], z=nodes[:, 2],
        mode='markers',
        marker=dict(size=12, color='blue'),
        hoverinfo='text',
        hovertext=[f"gpu : {gpu_id}" for gpu_id in range(len(nodes))]
    )

    # Edges: one (source gpu, destination gpu, chunk) triple per send/receive step
    edges = []
    for gpu_id, gpu_data in time_data.items():
        for tb_data in gpu_data.values():
            if tb_data['pkt_type'] == 's':
                edges.append((int(gpu_id), int(tb_data['tb_gpuid_send']), int(tb_data['chunk_id'])))
            elif tb_data['pkt_type'] == 'r':
                edges.append((int(tb_data['tb_gpuid_recv']), int(gpu_id), int(tb_data['chunk_id'])))
    # Drop exact duplicates while keeping first-seen order, so which direction
    # of a link gets the positive bend is reproducible.
    edges = list(dict.fromkeys(edges))

    arrows = []
    # Directed links already drawn with a positive bend, still awaiting their reverse
    unpaired = []

    for src, dst, chunk_id in edges:
        start_node = nodes[src]
        end_node = nodes[dst]
        forward = (tuple(start_node), tuple(end_node))
        reverse = (forward[1], forward[0])

        # If the opposite direction is pending, bend this arc the other way
        # and consume that pending entry.
        if reverse in unpaired:
            curve = create_curve(start_node, end_node, -0.05)
            unpaired.remove(reverse)
        else:
            unpaired.append(forward)
            curve = create_curve(start_node, end_node, 0.05)

        arrows.append(go.Scatter3d(
            x=curve[:, 0], y=curve[:, 1], z=curve[:, 2],
            mode='lines',
            line=dict(color='black', width=2),
            hoverinfo='text',
            hovertext=f"Chunk_id : {chunk_id}"
        ))

        # Orient the arrowhead along the curve's final segment
        tangent = curve[-1] - curve[-7]
        arrows.append(go.Cone(
            x=[end_node[0]], y=[end_node[1]], z=[end_node[2]],
            u=[tangent[0]],
            v=[tangent[1]],
            w=[tangent[2]],
            sizemode="absolute",
            sizeref=0.1,  # Adjust the size reference to control arrow size
            anchor="tip",
            showscale=False,
            colorscale="Viridis"
        ))

    # 3D Plot Layout
    layout = go.Layout(
        title=f"Time step = {timestep}"
    )

    fig = go.Figure(data=[scatter] + arrows, layout=layout)
    fig.show()

# Timestep navigation: index of the step on screen and the last valid index
current_timestep = 0
max_timestep = len(timestep_data) - 1

def on_next_button_clicked(b):
    """Show the next timestep, if there is one."""
    global current_timestep
    if current_timestep < max_timestep:
        current_timestep += 1
        update_plot(current_timestep)

def on_prev_button_clicked(b):
    """Show the previous timestep, if there is one."""
    global current_timestep
    if current_timestep > 0:
        current_timestep -= 1
        update_plot(current_timestep)

# Create Next and Previous buttons
next_button = widgets.Button(description="Next")
prev_button = widgets.Button(description="Previous")

# Bind button clicks to functions
next_button.on_click(on_next_button_clicked)
prev_button.on_click(on_prev_button_clicked)

# Display buttons and initial plot
display(prev_button, next_button)
update_plot(current_timestep)

### Mesh2D visualization

In [None]:
import plotly.graph_objects as go
import numpy as np
import networkx as nx
import ipywidgets as widgets
from IPython.display import display, clear_output

length, width= 3, 3

# Function to create the 2D mesh graph
def create_mesh2D():
    """
    Return a length x width 2D mesh (no wrap-around) as a bidirectional ``nx.DiGraph``.

    Nodes are ``(i, j, 0)`` tuples (a constant z of 0, since the plotting code
    reads a z column) labelled ``"gpu : N"`` and inserted row by row, column
    by column, so the N-th node is GPU N.  Every pair of axis-adjacent nodes
    gets one directed edge each way.
    """
    G = nx.DiGraph()
    node_number = 0
    for i in range(length):
        for j in range(width):
            new_node = (i, j, 0)
            G.add_node(new_node, label=f"gpu : {node_number}")
            node_number += 1

            if (j + 1 < width):
                G.add_edge(new_node, (i, j + 1, 0))
                G.add_edge((i, j + 1, 0), new_node)

        # Link row i to row i+1 column by column.  Use (i, j, 0): `new_node`
        # still holds the row's last node at this point.
        if (i + 1 < length):
            for j in range(width):
                G.add_edge((i, j, 0), (i + 1, j, 0))
                G.add_edge((i + 1, j, 0), (i, j, 0))

    return G

def create_curve(start, end, height_factor=0.2):
    """
    Return a 50 x 3 array of points on a quadratic arc from ``start`` to ``end``.

    At parameter t in [0, 1] the arc sits on the straight chord plus an
    in-plane offset of ``4 t (1 - t) * |end - start| * height_factor``.
    Edges whose endpoints share a y coordinate bow along y; all others bow
    along x.  A negative ``height_factor`` bows the arc to the opposite side.
    """
    t = np.linspace(0, 1, 50)
    bow_axis = 1 if start[1] == end[1] else 0
    mid = (start + end) / 2
    mid[bow_axis] += np.linalg.norm(end - start) * height_factor
    straight = np.outer(1 - t, start) + np.outer(t, end)
    return straight + np.outer(4 * t * (1 - t), mid - 0.5 * (start + end))

# Function to create scatter plot for nodes and edges per timestep
def update_plot(timestep):
    """
    Clear the cell output, redraw the navigation buttons and show the 2D mesh
    figure for ``timestep``.

    Transfers come from ``timestep_data[str(timestep)]`` (step ids are the
    strings read from the XML).  A step id with no entry is drawn as nodes
    only instead of raising ``KeyError``.  Each distinct (source gpu,
    destination gpu, chunk) transfer becomes a curved line plus a cone at its
    destination.  When both directions of a link are active, the one seen
    second bends the opposite way so the two arcs do not overlap.
    """
    clear_output(wait=True)
    display(prev_button, next_button)

    # Initialize the graph data for timestep changes; row N of `nodes` is GPU N
    G = create_mesh2D()
    nodes = np.array(G.nodes())

    time_data = timestep_data.get(str(timestep), {})

    # Nodes scatter plot: one hover label per node actually drawn.  Sized from
    # `nodes` because this 2D cell has no `depth` of its own.
    scatter = go.Scatter3d(
        x=nodes[:, 0], y=nodes[:, 1], z=nodes[:, 2],
        mode='markers',
        marker=dict(size=12, color='blue'),
        hoverinfo='text',
        hovertext=[f"gpu : {gpu_id}" for gpu_id in range(len(nodes))]
    )

    # Edges: one (source gpu, destination gpu, chunk) triple per send/receive step
    edges = []
    for gpu_id, gpu_data in time_data.items():
        for tb_data in gpu_data.values():
            if tb_data['pkt_type'] == 's':
                edges.append((int(gpu_id), int(tb_data['tb_gpuid_send']), int(tb_data['chunk_id'])))
            elif tb_data['pkt_type'] == 'r':
                edges.append((int(tb_data['tb_gpuid_recv']), int(gpu_id), int(tb_data['chunk_id'])))
    # Drop exact duplicates while keeping first-seen order, so which direction
    # of a link gets the positive bend is reproducible.
    edges = list(dict.fromkeys(edges))

    arrows = []
    # Directed links already drawn with a positive bend, still awaiting their reverse
    unpaired = []

    for src, dst, chunk_id in edges:
        start_node = nodes[src]
        end_node = nodes[dst]
        forward = (tuple(start_node), tuple(end_node))
        reverse = (forward[1], forward[0])

        # If the opposite direction is pending, bend this arc the other way
        # and consume that pending entry.
        if reverse in unpaired:
            curve = create_curve(start_node, end_node, -0.1)
            unpaired.remove(reverse)
        else:
            unpaired.append(forward)
            curve = create_curve(start_node, end_node, 0.1)

        arrows.append(go.Scatter3d(
            x=curve[:, 0], y=curve[:, 1], z=curve[:, 2],
            mode='lines',
            line=dict(color='black', width=2),
            hoverinfo='text',
            hovertext=f"Chunk_id : {chunk_id}"
        ))

        # Orient the arrowhead along the curve's final segment
        tangent = curve[-1] - curve[-7]
        arrows.append(go.Cone(
            x=[end_node[0]], y=[end_node[1]], z=[end_node[2]],
            u=[tangent[0]],
            v=[tangent[1]],
            w=[tangent[2]],
            sizemode="absolute",
            sizeref=0.1,  # Adjust the size reference to control arrow size
            anchor="tip",
            showscale=False,
            colorscale="Viridis"
        ))

    # 3D Plot Layout
    layout = go.Layout(
        title=f"Time step = {timestep}"
    )

    fig = go.Figure(data=[scatter] + arrows, layout=layout)
    fig.show()

# Timestep navigation: index of the step on screen and the last valid index
current_timestep = 0
max_timestep = len(timestep_data) - 1

def on_next_button_clicked(b):
    """Show the next timestep, if there is one."""
    global current_timestep
    if current_timestep < max_timestep:
        current_timestep += 1
        update_plot(current_timestep)

def on_prev_button_clicked(b):
    """Show the previous timestep, if there is one."""
    global current_timestep
    if current_timestep > 0:
        current_timestep -= 1
        update_plot(current_timestep)

# Create Next and Previous buttons
next_button = widgets.Button(description="Next")
prev_button = widgets.Button(description="Previous")

# Bind button clicks to functions
next_button.on_click(on_next_button_clicked)
prev_button.on_click(on_prev_button_clicked)

# Display buttons and initial plot
display(prev_button, next_button)
update_plot(current_timestep)

### Torus2D visualization

In [None]:
import plotly.graph_objects as go
import numpy as np
import networkx as nx
import ipywidgets as widgets
from IPython.display import display, clear_output

length, width= 3, 3

# Function to create the 2D torus graph
def create_torus2D():
    """
    Return a length x width 2D torus as a bidirectional ``nx.DiGraph``.

    Nodes are ``(i, j, 0)`` tuples labelled ``"gpu : N"`` and inserted row by
    row, column by column, so the N-th node is GPU N.  Rows and columns wrap:
    the last node along each axis is also linked to the first.  Every link is
    a directed edge in each direction.
    """
    G = nx.DiGraph()
    node_number = 0
    for i in range(length):
        for j in range(width):
            new_node = (i, j, 0)
            G.add_node(new_node, label=f"gpu : {node_number}")
            node_number += 1

            next_j = (j + 1) % width
            G.add_edge(new_node, (i, next_j, 0))
            G.add_edge((i, next_j, 0), new_node)

        # Link row i to the next row (wrapping) column by column.  Use
        # (i, j, 0): `new_node` still holds the row's last node at this point.
        next_i = (i + 1) % length
        for j in range(width):
            G.add_edge((i, j, 0), (next_i, j, 0))
            G.add_edge((next_i, j, 0), (i, j, 0))

    return G

def create_curve(start, end, height_factor):
    """
    Return a 50 x 3 array of points on a quadratic arc from ``start`` to ``end``.

    The arc's midpoint is pushed by ``|end - start| * height_factor``:

    * unit links bow in the plane (x-axis edges along y, y-axis edges along x);
    * wrap-around links (gap > 1 along their axis) are lifted along z so they
      arc over the nodes they would otherwise cross.

    A negative ``height_factor`` bows the arc to the opposite side.
    """
    lift = np.linalg.norm(end - start) * height_factor
    same_row = start[1] == end[1]
    gap = abs(start[0] - end[0]) if same_row else abs(start[1] - end[1])

    if gap > 1:
        axis = 2
    elif same_row:
        axis = 1
    else:
        axis = 0

    mid = (start + end) / 2
    mid[axis] += lift

    t = np.linspace(0, 1, 50)
    straight = np.outer(1 - t, start) + np.outer(t, end)
    return straight + np.outer(4 * t * (1 - t), mid - 0.5 * (start + end))

# Function to create scatter plot for nodes and edges per timestep
def update_plot(timestep):
    """
    Clear the cell output, redraw the navigation buttons and show the 2D torus
    figure for ``timestep``.

    Transfers come from ``timestep_data[str(timestep)]`` (step ids are the
    strings read from the XML).  A step id with no entry is drawn as nodes
    only instead of raising ``KeyError``.  Each distinct (source gpu,
    destination gpu, chunk) transfer becomes a curved line plus a cone at its
    destination.  When both directions of a link are active, the one seen
    second bends the opposite way so the two arcs do not overlap.
    """
    clear_output(wait=True)
    display(prev_button, next_button)

    # Initialize the graph data for timestep changes; row N of `nodes` is GPU N
    G = create_torus2D()
    nodes = np.array(G.nodes())

    time_data = timestep_data.get(str(timestep), {})

    # Nodes scatter plot: one hover label per node actually drawn.  Sized from
    # `nodes` because this 2D cell has no `depth` of its own.
    scatter = go.Scatter3d(
        x=nodes[:, 0], y=nodes[:, 1], z=nodes[:, 2],
        mode='markers',
        marker=dict(size=12, color='blue'),
        hoverinfo='text',
        hovertext=[f"gpu : {gpu_id}" for gpu_id in range(len(nodes))]
    )

    # Edges: one (source gpu, destination gpu, chunk) triple per send/receive step
    edges = []
    for gpu_id, gpu_data in time_data.items():
        for tb_data in gpu_data.values():
            if tb_data['pkt_type'] == 's':
                edges.append((int(gpu_id), int(tb_data['tb_gpuid_send']), int(tb_data['chunk_id'])))
            elif tb_data['pkt_type'] == 'r':
                edges.append((int(tb_data['tb_gpuid_recv']), int(gpu_id), int(tb_data['chunk_id'])))
    # Drop exact duplicates while keeping first-seen order, so which direction
    # of a link gets the positive bend is reproducible.
    edges = list(dict.fromkeys(edges))

    arrows = []
    # Directed links already drawn with a positive bend, still awaiting their reverse
    unpaired = []

    for src, dst, chunk_id in edges:
        start_node = nodes[src]
        end_node = nodes[dst]
        forward = (tuple(start_node), tuple(end_node))
        reverse = (forward[1], forward[0])

        # If the opposite direction is pending, bend this arc the other way
        # and consume that pending entry.
        if reverse in unpaired:
            curve = create_curve(start_node, end_node, -0.05)
            unpaired.remove(reverse)
        else:
            unpaired.append(forward)
            curve = create_curve(start_node, end_node, 0.05)

        arrows.append(go.Scatter3d(
            x=curve[:, 0], y=curve[:, 1], z=curve[:, 2],
            mode='lines',
            line=dict(color='black', width=2),
            hoverinfo='text',
            hovertext=f"Chunk_id : {chunk_id}"
        ))

        # Orient the arrowhead along the curve's final segment
        tangent = curve[-1] - curve[-7]
        arrows.append(go.Cone(
            x=[end_node[0]], y=[end_node[1]], z=[end_node[2]],
            u=[tangent[0]],
            v=[tangent[1]],
            w=[tangent[2]],
            sizemode="absolute",
            sizeref=0.1,  # Adjust the size reference to control arrow size
            anchor="tip",
            showscale=False,
            colorscale="Viridis"
        ))

    # 3D Plot Layout
    layout = go.Layout(
        title=f"Time step = {timestep}"
    )

    fig = go.Figure(data=[scatter] + arrows, layout=layout)
    fig.show()

# Timestep navigation: index of the step on screen and the last valid index
current_timestep = 0
max_timestep = len(timestep_data) - 1

def on_next_button_clicked(b):
    """Show the next timestep, if there is one."""
    global current_timestep
    if current_timestep < max_timestep:
        current_timestep += 1
        update_plot(current_timestep)

def on_prev_button_clicked(b):
    """Show the previous timestep, if there is one."""
    global current_timestep
    if current_timestep > 0:
        current_timestep -= 1
        update_plot(current_timestep)

# Create Next and Previous buttons
next_button = widgets.Button(description="Next")
prev_button = widgets.Button(description="Previous")

# Bind button clicks to functions
next_button.on_click(on_next_button_clicked)
prev_button.on_click(on_prev_button_clicked)

# Display buttons and initial plot
display(prev_button, next_button)
update_plot(current_timestep)