In [None]:
from collections import defaultdict

class Graph:
    def __init__(self):
        self.graph = defaultdict(list)

    def addEdge(self, u, v):
        self.graph[u].append(v)

    def dfs(self, start, end, specified):
        stack = [(start, [start])]
        paths = []
        
        while stack:
            node, path = stack.pop()
            if node == end:
                if any(spec_node in path for spec_node in specified):
                    paths.append(path)
                continue
            
            for neighbor in self.graph[node]:
                if neighbor not in path:
                    stack.append((neighbor, path + [neighbor]))
                    
        return paths
    
    def find_sources_and_sinks(self):
        in_degree = defaultdict(int)
        out_degree = defaultdict(int)
        
        for node, neighbors in self.graph.items():
            out_degree[node] += len(neighbors)
            for neighbor in neighbors:
                in_degree[neighbor] += 1
                
        sources = [node for node, degree in in_degree.items() if degree == 0]
        sinks = [node for node, degree in out_degree.items() if degree == 0]
        
        return sources, sinks
    
    def trim_and_merge(self, specified):
        sources, sinks = self.find_sources_and_sinks()
        
        merged_graph = Graph()
        
        for start in sources:
            for end in sinks:
                paths = self.dfs(start, end, specified)
                for path in paths:
                    merge_node = None
                    for node in path:
                        if node in specified:
                            if merge_node:
                                merged_graph.addEdge(merge_node, node)
                            merge_node = node
                        elif merge_node is None:
                            merge_node = node
                        
                    if merge_node and merge_node != end:
                        merged_graph.addEdge(merge_node, end)
        
        return merged_graph


# Example
specified = {2, 3, 6}  # Specified nodes
g = Graph()
g.addEdge(1, 2)
g.addEdge(2, 3)
g.addEdge(3, 4)
g.addEdge(4, 5)
g.addEdge(5, 6)

new_g = g.trim_and_merge(specified)

# Now, new_g is the new graph where nodes have been merged and trimmed
# according to the specified nodes and the rule described.
print("Trimmed and Merged Graph:")
for node, neighbors in new_g.graph.items():
    print(f"{node} -> {neighbors}")

In [None]:
from collections import defaultdict, deque

class Graph:
    def __init__(self):
        self.graph = defaultdict(set)

    def add_edge(self, u, v):
        self.graph[u].add(v)

    def trim_and_merge(self, specified_nodes):
        ancestors = set(specified_nodes)
        to_merge = set()

        # Find ancestors of specified nodes using BFS
        queue = deque(specified_nodes)
        while queue:
            node = queue.popleft()
            for parent in self.graph:
                if node in self.graph[parent]:
                    if parent not in ancestors:
                        ancestors.add(parent)
                        queue.append(parent)
                        if parent not in specified_nodes:
                            to_merge.add(parent)

        # Create a new graph with trimmed and merged nodes
        trimmed_graph = defaultdict(set)
        for node in ancestors:
            if node not in to_merge:
                for child in self.graph[node]:
                    if child in to_merge:
                        trimmed_graph[node].add('merged_node')
                    else:
                        trimmed_graph[node].add(child)

        # Add the merged node
        if to_merge:
            for node in to_merge:
                for child in self.graph[node]:
                    if child not in to_merge:
                        trimmed_graph['merged_node'].add(child)

        return trimmed_graph

# Usage
g = Graph()
g.add_edge('a', 'b')
g.add_edge('b', 'c')
g.add_edge('c', 'd')
g.add_edge('d', 'e')

specified_nodes = ['c', 'e']
trimmed_graph = g.trim_and_merge(specified_nodes)
print(dict(trimmed_graph))  

In [None]:
from collections import defaultdict, deque


class Graph:
    def __init__(self):
        self.graph = defaultdict(set)

    def add_edge(self, u, v):
        self.graph[u].add(v)

    def trim_and_merge(self, specified_nodes):
        trimmed_graph = defaultdict(set)

        specified_nodes = set(specified_nodes)

        for i, src in enumerate(specified_nodes):
            other_specified_nodes = specified_nodes - {src}

            visited = set()
            queue = deque([src])
            while queue:
                node = queue.popleft()
                if node in other_specified_nodes:
                    trimmed_graph[src].add(node)
                visited.add(node)
                for neighbor in self.graph[node]:
                    if neighbor not in visited and neighbor not in queue:
                        queue.append(neighbor)
                            
        # Ensure all specified nodes are in the graph even if they are isolated
        for node in specified_nodes:
            if node not in trimmed_graph:
                trimmed_graph[node] = set()

        return dict(trimmed_graph)  # convert defaultdict to dict for cleaner representation

In [None]:
g = Graph()
g.add_edge('a', 'b')
g.add_edge('a', 'c')
g.add_edge('b', 'c')
g.add_edge('c', 'd')
g.add_edge('d', 'e')
g.add_edge('c', 'e')
g.add_edge('e', 'f')

In [None]:
g.trim_and_merge(['c', 'f'])


In [1]:
import numpy as np

file = np.load(r'H:\data\gfos\predict-ai-model-runtime\npz_all\npz\layout\xla\default\train\alexnet_train_batch_32.npz')
edge_index = file['edge_index']
node_config_ids = file['node_config_ids']

In [None]:
from gfos.data.graph import Graph

In [None]:
g = Graph()

for src, tgt in edge_index:
    g.add_edge(src, tgt)
    
trimmed_graph, paths = g.trim_return_path(node_config_ids.tolist(), True)

trimmed_edges = []

for src, tgts in trimmed_graph.items():
    if not tgts:
        continue
    for tgt in tgts:
        trimmed_edges.append([src, tgt])
        
trimmed_edges = np.array(trimmed_edges)

In [None]:
import matplotlib.pyplot as plt
import networkx as nx

def draw_graph(graph_dict, title="Graph"):
    G = nx.DiGraph()
    for node, children in graph_dict.items():
        for child in children:
            G.add_edge(node, child)
            
    pos = nx.spring_layout(G, seed=42)
    plt.figure(figsize=(8, 6))
    nx.draw(G, pos, with_labels=True, node_color='skyblue', node_size=2000, edge_color='black', linewidths=1, font_size=15, arrowsize=20, connectionstyle='arc3,rad=0.1')
    plt.title(title)
    plt.show()

In [None]:
import numpy as np
from collections import defaultdict, deque


class Graph:
    def __init__(self):
        self.graph = defaultdict(set)

    def add_edge(self, u, v):
        self.graph[u].add(v)

    def trim_and_merge(self, specified_nodes: set, return_distance: bool):
        trimmed_graph = defaultdict(set)
        visited_global = set()  # to keep track of globally visited nodes
        if return_distance:
            distance_between_nodes = defaultdict(lambda: defaultdict(int))

        for src in specified_nodes:
            if src in visited_global:  # skip already visited nodes
                continue

            visited = set([src])

            if return_distance:
                queue = deque([(src, 1)])
            else:
                queue = deque([src])

            while queue:
                if return_distance:
                    node, distance = queue.popleft()
                else:
                    node = queue.popleft()
                visited_global.add(node)
                for neighbor in self.graph[node]:
                    if neighbor in specified_nodes:
                        trimmed_graph[src].add(neighbor)
                        if return_distance:
                            distance_between_nodes[src][neighbor] = (
                                distance + 1
                            )
                    elif neighbor not in visited:
                        visited.add(neighbor)
                        if return_distance:
                            queue.append((neighbor, distance + 1))
                        else:
                            queue.append(neighbor)

        if return_distance:
            return trimmed_graph, distance_between_nodes
        else:
            return trimmed_graph


def get_config_graph(origin_edges, config_node_ids, return_distance=False):
    g = Graph()

    for src, tgt in origin_edges:
        g.add_edge(src, tgt)

    trimmed_graph = g.trim_and_merge(config_node_ids.tolist(), return_distance)
    if return_distance:
        trimmed_graph, distances = trimmed_graph

    trimmed_edges = []

    for src, tgts in trimmed_graph.items():
        if not tgts:
            continue
        for tgt in tgts:
            trimmed_edges.append([src, tgt])

    trimmed_edges = np.array(trimmed_edges)
    weights = [distances[src][tgt] for src, tgt in trimmed_edges]
    weights = np.array(weights)
    weights = weights.max() / weights

    return trimmed_edges, weights if return_distance else trimmed_edges

In [None]:
g = Graph()

for src, tgt in edge_index:
    g.add_edge(src, tgt)
    
trimmed_graph, distances = g.trim_and_merge_with_distance(node_config_ids.tolist())

trimmed_edges = []

for src, tgts in trimmed_graph.items():
    if not tgts:
        continue
    for tgt in tgts:
        trimmed_edges.append([src, tgt])
        
trimmed_edges = np.array(trimmed_edges)

In [None]:
edge_weights = [
    distances[src][tgt]
    for src, tgt in trimmed_edges
]

In [9]:
import dgl
import torch
from dgl.nn import SAGEConv

In [12]:
g

Graph(num_nodes=6, num_edges=12,
      ndata_schemes={}
      edata_schemes={})

In [11]:
# Case 1: Homogeneous graph
g = dgl.graph(([0,1,2,3,2,5], [1,2,3,4,0,3]))
g = dgl.add_self_loop(g)
feat = torch.ones(6, 10)
conv = SAGEConv(10, 2, 'pool')
res = conv(g, feat, torch.ones(6))
res

AssertionError: 

In [None]:
trimmed_edges, edge_weights = get_config_graph(edge_index, node_config_ids, return_distance=True)

In [2]:
from gfos.data.graph import get_config_graph

In [3]:
trimmed_edges, edge_weights = get_config_graph(edge_index, node_config_ids)

In [4]:
edge_weights

array([2.25 , 1.8  , 1.5  , 1.5  , 3.   , 1.5  , 1.8  , 1.5  , 2.25 ,
       1.5  , 3.   , 1.5  , 2.25 , 2.25 , 4.5  , 1.5  , 1.8  , 1.5  ,
       1.8  , 1.5  , 1.5  , 1.125, 1.   , 1.8  , 1.125, 1.8  , 1.5  ,
       1.8  , 1.5  , 1.8  , 4.5  , 2.25 , 1.5  , 2.25 , 1.5  , 3.   ,
       3.   , 1.5  , 2.25 , 2.25 , 1.5  , 1.8  , 4.5  , 1.5  , 2.25 ])

In [None]:
import graphviz

dot = graphviz.Digraph("original_graph")

# Add nodes and edges to the graph
for edge in edge_index:
    dot.edge(str(edge[0]), str(edge[1]))

# Set configurable nodes to red
for node in node_config_ids:
    dot.node(str(node), fillcolor="red", style="filled")

# # Set operation names as labels
# for node, node_idx in enumerate(model["node_opcode"]):
#     dot.node(str(node), label=node_idx2name[node_idx])

dot.render(f"../../output/original_graph.gv")

In [None]:
dot = graphviz.Digraph("trimmed_graph")

# Add nodes and edges to the graph
for edge in trimmed_edges:
    dot.edge(str(edge[0]), str(edge[1]))

# Set configurable nodes to red
for node in node_config_ids:
    dot.node(str(node), fillcolor="red", style="filled")

# # Set operation names as labels
# for node, node_idx in enumerate(model["node_opcode"]):
#     dot.node(str(node), label=node_idx2name[node_idx])

dot.render(f"../../output/trimmed_graph.gv")

In [None]:
import graphviz


dot = graphviz.Digraph("trimmed_graph_2")

# Add nodes and edges to the graph
for edge, distance in zip(trimmed_edges, edge_weights):
    dot.edge(str(edge[0]), str(edge[1]), label=str(distance))

# Set configurable nodes to red
for node in node_config_ids:
    dot.node(str(node), fillcolor="red", style="filled")

# # Set operation names as labels
# for node, node_idx in enumerate(model["node_opcode"]):
#     dot.node(str(node), label=node_idx2name[node_idx])

dot.render(f"../../output/trimmed_graph_2_weights_src.gv")

In [None]:
import graphviz


dot = graphviz.Digraph("trimmed_graph_2")

# Add nodes and edges to the graph
for edge in trimmed_edges:
    dot.edge(str(edge[0]), str(edge[1]))

# Set configurable nodes to red
for node in node_config_ids:
    dot.node(str(node), fillcolor="red", style="filled")

# # Set operation names as labels
# for node, node_idx in enumerate(model["node_opcode"]):
#     dot.node(str(node), label=node_idx2name[node_idx])

dot.render(f"../../output/trimmed_graph_2_src.gv")

In [None]:
import graphviz


dot = graphviz.Digraph("trimmed_graph_2")

# Add nodes and edges to the graph
for edge in trimmed_edges:
    dot.edge(str(edge[0]), str(edge[1]), label=",".join([str(p) for p in paths[edge[0]][edge[1]]]))

# Set configurable nodes to red
for node in node_config_ids:
    dot.node(str(node), fillcolor="red", style="filled")

# # Set operation names as labels
# for node, node_idx in enumerate(model["node_opcode"]):
#     dot.node(str(node), label=node_idx2name[node_idx])

dot.render(f"../../output/trimmed_graph_2_src_paths.gv")