In [1]:
from collections import defaultdict

class Graph:
    def __init__(self):
        self.graph = defaultdict(list)

    def addEdge(self, u, v):
        self.graph[u].append(v)

    def dfs(self, start, end, specified):
        stack = [(start, [start])]
        paths = []
        
        while stack:
            node, path = stack.pop()
            if node == end:
                if any(spec_node in path for spec_node in specified):
                    paths.append(path)
                continue
            
            for neighbor in self.graph[node]:
                if neighbor not in path:
                    stack.append((neighbor, path + [neighbor]))
                    
        return paths
    
    def find_sources_and_sinks(self):
        in_degree = defaultdict(int)
        out_degree = defaultdict(int)
        
        for node, neighbors in self.graph.items():
            out_degree[node] += len(neighbors)
            for neighbor in neighbors:
                in_degree[neighbor] += 1
                
        sources = [node for node, degree in in_degree.items() if degree == 0]
        sinks = [node for node, degree in out_degree.items() if degree == 0]
        
        return sources, sinks
    
    def trim_and_merge(self, specified):
        sources, sinks = self.find_sources_and_sinks()
        
        merged_graph = Graph()
        
        for start in sources:
            for end in sinks:
                paths = self.dfs(start, end, specified)
                for path in paths:
                    merge_node = None
                    for node in path:
                        if node in specified:
                            if merge_node:
                                merged_graph.addEdge(merge_node, node)
                            merge_node = node
                        elif merge_node is None:
                            merge_node = node
                        
                    if merge_node and merge_node != end:
                        merged_graph.addEdge(merge_node, end)
        
        return merged_graph


# Example
specified = {2, 3, 6}  # Specified nodes
g = Graph()
g.addEdge(1, 2)
g.addEdge(2, 3)
g.addEdge(3, 4)
g.addEdge(4, 5)
g.addEdge(5, 6)

new_g = g.trim_and_merge(specified)

# Now, new_g is the new graph where nodes have been merged and trimmed
# according to the specified nodes and the rule described.
print("Trimmed and Merged Graph:")
for node, neighbors in new_g.graph.items():
    print(f"{node} -> {neighbors}")

Trimmed and Merged Graph:


In [6]:
from collections import defaultdict, deque

class Graph:
    def __init__(self):
        self.graph = defaultdict(set)

    def add_edge(self, u, v):
        self.graph[u].add(v)

    def trim_and_merge(self, specified_nodes):
        ancestors = set(specified_nodes)
        to_merge = set()

        # Find ancestors of specified nodes using BFS
        queue = deque(specified_nodes)
        while queue:
            node = queue.popleft()
            for parent in self.graph:
                if node in self.graph[parent]:
                    if parent not in ancestors:
                        ancestors.add(parent)
                        queue.append(parent)
                        if parent not in specified_nodes:
                            to_merge.add(parent)

        # Create a new graph with trimmed and merged nodes
        trimmed_graph = defaultdict(set)
        for node in ancestors:
            if node not in to_merge:
                for child in self.graph[node]:
                    if child in to_merge:
                        trimmed_graph[node].add('merged_node')
                    else:
                        trimmed_graph[node].add(child)

        # Add the merged node
        if to_merge:
            for node in to_merge:
                for child in self.graph[node]:
                    if child not in to_merge:
                        trimmed_graph['merged_node'].add(child)

        return trimmed_graph

# Usage
g = Graph()
g.add_edge('a', 'b')
g.add_edge('b', 'c')
g.add_edge('c', 'd')
g.add_edge('d', 'e')

specified_nodes = ['c', 'e']
trimmed_graph = g.trim_and_merge(specified_nodes)
print(dict(trimmed_graph))  

{'c': {'merged_node'}, 'merged_node': {'c', 'e'}}


In [40]:
from collections import defaultdict, deque


class Graph:
    def __init__(self):
        self.graph = defaultdict(set)

    def add_edge(self, u, v):
        self.graph[u].add(v)

    def trim_and_merge(self, specified_nodes):
        trimmed_graph = defaultdict(set)

        specified_nodes = set(specified_nodes)

        for i, src in enumerate(specified_nodes):
            other_specified_nodes = specified_nodes - {src}

            visited = set()
            queue = deque([src])
            while queue:
                node = queue.popleft()
                if node in other_specified_nodes:
                    trimmed_graph[src].add(node)
                visited.add(node)
                for neighbor in self.graph[node]:
                    if neighbor not in visited and neighbor not in queue:
                        queue.append(neighbor)
                            
        # Ensure all specified nodes are in the graph even if they are isolated
        for node in specified_nodes:
            if node not in trimmed_graph:
                trimmed_graph[node] = set()

        return dict(trimmed_graph)  # convert defaultdict to dict for cleaner representation

In [33]:
g = Graph()
g.add_edge('a', 'b')
g.add_edge('a', 'c')
g.add_edge('b', 'c')
g.add_edge('c', 'd')
g.add_edge('d', 'e')
g.add_edge('c', 'e')
g.add_edge('e', 'f')

In [34]:
g.trim_and_merge(['c', 'f'])


Specified nodes: {'f'}
deque(['c'])
deque(['e', 'd'])
deque(['d', 'f'])
deque(['f'])
Specified nodes: {'c'}
deque(['f'])


{'c': {'f'}, 'f': set()}

In [35]:
import numpy as np

file = np.load(r'H:\data\gfos\predict-ai-model-runtime\npz_all\npz\layout\xla\default\train\alexnet_train_batch_32.npz')
edge_index = file['edge_index']
node_config_ids = file['node_config_ids']

In [41]:
g = Graph()

for src, tgt in edge_index:
    g.add_edge(src, tgt)

In [47]:
trimmed_graph = g.trim_and_merge(node_config_ids.tolist())

In [48]:
trimmed_edges = []

for src, tgts in trimmed_graph.items():
    if not tgts:
        continue
    for tgt in tgts:
        trimmed_edges.append([src, tgt])
        
trimmed_edges = np.array(trimmed_edges)

In [44]:
import matplotlib.pyplot as plt
import networkx as nx

def draw_graph(graph_dict, title="Graph"):
    G = nx.DiGraph()
    for node, children in graph_dict.items():
        for child in children:
            G.add_edge(node, child)
            
    pos = nx.spring_layout(G, seed=42)
    plt.figure(figsize=(8, 6))
    nx.draw(G, pos, with_labels=True, node_color='skyblue', node_size=2000, edge_color='black', linewidths=1, font_size=15, arrowsize=20, connectionstyle='arc3,rad=0.1')
    plt.title(title)
    plt.show()

In [46]:
import graphviz

dot = graphviz.Digraph("original_graph")

# Add nodes and edges to the graph
for edge in edge_index:
    dot.edge(str(edge[0]), str(edge[1]))

# Set configurable nodes to red
for node in node_config_ids:
    dot.node(str(node), fillcolor="red", style="filled")

# # Set operation names as labels
# for node, node_idx in enumerate(model["node_opcode"]):
#     dot.node(str(node), label=node_idx2name[node_idx])

dot.render(f"../../output/original_graph.gv")

'..\\..\\output\\original_graph.gv.pdf'

In [50]:
dot = graphviz.Digraph("trimmed_graph")

# Add nodes and edges to the graph
for edge in trimmed_edges:
    dot.edge(str(edge[0]), str(edge[1]))

# Set configurable nodes to red
for node in node_config_ids:
    dot.node(str(node), fillcolor="red", style="filled")

# # Set operation names as labels
# for node, node_idx in enumerate(model["node_opcode"]):
#     dot.node(str(node), label=node_idx2name[node_idx])

dot.render(f"../../output/trimmed_graph.gv")

'..\\..\\output\\trimmed_graph.gv.pdf'

In [51]:
def get_config_graph(origin_edges, config_node_ids):
    g = Graph()

    for src, tgt in origin_edges:
        g.add_edge(src, tgt)

    trimmed_graph = g.trim_and_merge(config_node_ids.tolist())

    trimmed_edges = []

    for src, tgts in trimmed_graph.items():
        if not tgts:
            continue
        for tgt in tgts:
            trimmed_edges.append([src, tgt])

    trimmed_edges = np.array(trimmed_edges)
    
    return trimmed_edges

In [52]:
get_config_graph(edge_index, node_config_ids)

array([[262, 227],
       [262, 103],
       [262, 231],
       [262, 136],
       [262, 235],
       [262, 236],
       [262, 110],
       [262, 239],
       [262, 144],
       [262, 241],
       [262, 243],
       [262, 116],
       [262,  86],
       [262, 123],
       [262, 125],
       [262,  95],
       [136, 103],
       [136, 110],
       [136, 116],
       [136,  86],
       [136, 123],
       [136, 125],
       [136,  95],
       [270, 227],
       [270, 103],
       [270, 231],
       [270, 136],
       [270, 235],
       [270, 236],
       [270, 110],
       [270, 239],
       [270, 144],
       [270, 241],
       [270, 116],
       [270,  86],
       [270, 123],
       [270, 125],
       [270,  95],
       [144, 103],
       [144, 136],
       [144, 110],
       [144, 116],
       [144,  86],
       [144, 123],
       [144, 125],
       [144,  95],
       [277, 227],
       [277, 103],
       [277, 231],
       [277, 136],
       [277, 235],
       [277, 236],
       [277,