In [1]:
import numpy as np

# Relative path, resolved against the kernel's working directory.
NPZ_PATH = "../../data/npz_all/npz/layout/xla/random/train/alexnet_train_batch_32.npz"

# np.load on an .npz returns an NpzFile that keeps the underlying file open;
# the context manager closes it once the arrays we need have been read out
# (indexing an NpzFile materialises the array in memory).
with np.load(NPZ_PATH) as file:
    edge_index = file["edge_index"]
    node_config_ids = file["node_config_ids"]

In [2]:
import numpy as np
from collections import defaultdict, deque


class Graph:
    """Directed graph stored as an adjacency-set mapping.

    ``self.graph`` maps each source node id to the set of its direct
    successors. Node ids can be any hashable values.
    """

    def __init__(self):
        # source node -> set of successor nodes
        self.graph = defaultdict(set)

    def add_edge(self, u, v):
        """Add the directed edge ``u -> v`` (duplicate edges are ignored)."""
        self.graph[u].add(v)

    def trim_and_merge(self, specified_nodes, return_distance: bool):
        """Collapse the graph onto ``specified_nodes``.

        For every specified source node, a BFS walks outward through
        *non*-specified nodes only. Each specified node it reaches becomes a
        direct successor of the source in the trimmed graph. Specified nodes
        are never walked through, so every path stops at the first specified
        node on it.

        Args:
            specified_nodes: iterable of node ids to keep (list, set, ...).
                Sources are processed in its iteration order, and repeated ids
                are processed once.
            return_distance: when True, also return the distance from each
                source to each specified node it reaches.

        Returns:
            ``trimmed_graph`` (``defaultdict(set)``: source -> reached specified
            nodes), or ``(trimmed_graph, distance_between_nodes)`` when
            ``return_distance`` is True. A distance is the shortest-path hop
            count plus one, because the source itself is counted as distance 1.
            A direct edge therefore has distance 2.
        """
        # Materialise once so a one-shot iterator also works, and build a set
        # for O(1) membership tests (callers pass a list).
        order = list(specified_nodes)
        keep = set(order)

        trimmed_graph = defaultdict(set)
        distance_between_nodes = defaultdict(lambda: defaultdict(int))
        processed = set()

        for src in order:
            if src in processed:  # repeated id in specified_nodes
                continue
            processed.add(src)

            visited = {src}
            queue = deque([(src, 1)])
            while queue:
                node, distance = queue.popleft()
                # .get avoids inserting empty entries into self.graph.
                for neighbor in self.graph.get(node, ()):
                    if neighbor in keep:
                        trimmed_graph[src].add(neighbor)
                        # BFS pops nodes in non-decreasing distance order, so the
                        # first time a specified node is reached is along a
                        # shortest path. Never overwrite it with a longer one.
                        if neighbor not in distance_between_nodes[src]:
                            distance_between_nodes[src][neighbor] = distance + 1
                    elif neighbor not in visited:
                        visited.add(neighbor)
                        queue.append((neighbor, distance + 1))

        if return_distance:
            return trimmed_graph, distance_between_nodes
        return trimmed_graph


def get_config_graph(origin_edges, config_node_ids, return_distance=False):
    """Build the edge list of the graph collapsed onto the configurable nodes.

    Two configurable nodes ``a`` and ``b`` are joined by ``a -> b`` when the
    original directed graph has a path from ``a`` to ``b`` whose intermediate
    nodes are all non-configurable (see ``Graph.trim_and_merge``). Node ids in
    the result are the original node ids, not positions within
    ``config_node_ids``.

    Args:
        origin_edges: iterable of ``(src, tgt)`` pairs, e.g. an ``(E, 2)`` array.
        config_node_ids: array-like of the node ids to keep.
        return_distance: if True, also return one weight per edge.

    Returns:
        ``trimmed_edges``, an ``(E', 2)`` array (shape ``(0, 2)`` when empty).
        When ``return_distance`` is True, returns ``(trimmed_edges, weights)``
        instead, where ``weights[i] = max(distance) / distance[i]``. The
        shortest connection therefore gets the largest weight and the longest
        gets 1.0.
    """
    g = Graph()
    for src, tgt in origin_edges:
        g.add_edge(src, tgt)

    # np.asarray lets plain Python sequences be passed as well as ndarrays.
    node_ids = np.asarray(config_node_ids).tolist()
    result = g.trim_and_merge(node_ids, return_distance)
    if return_distance:
        trimmed_graph, distances = result
    else:
        trimmed_graph, distances = result, None

    trimmed_edges = []
    edge_distances = []
    for src, tgts in trimmed_graph.items():
        for tgt in tgts:
            trimmed_edges.append([src, tgt])
            if distances is not None:
                edge_distances.append(distances[src][tgt])

    if trimmed_edges:
        trimmed_edges = np.array(trimmed_edges)
    else:
        # Keep a 2-column shape so callers can still transpose to [2, 0].
        trimmed_edges = np.empty((0, 2), dtype=np.int64)

    if not return_distance:
        return trimmed_edges

    edge_distances = np.array(edge_distances, dtype=np.float64)
    if edge_distances.size == 0:
        return trimmed_edges, edge_distances
    weights = edge_distances.max() / edge_distances
    return trimmed_edges, weights

In [3]:
# Collapse the full graph onto the configurable nodes and weight each edge by
# inverse relative path length.
config_edge_index, config_edge_weight = get_config_graph(
    origin_edges=edge_index,
    config_node_ids=node_config_ids,
    return_distance=True,
)

In [4]:
import torch
from torch_geometric.nn import GCNConv, GraphConv

FEAT_DIM = 64  # width of the random node features and of the GCN layer

# Seed so the random features and GCNConv's weight init are reproducible.
torch.manual_seed(0)

# Bind the tensors to new names so re-running this cell does not pass an
# already-converted tensor to torch.from_numpy (which only accepts ndarrays).
# PyG expects edge_index with shape [2, num_edges].
config_edge_index_t = torch.from_numpy(config_edge_index).long().t().contiguous()
config_edge_weight_t = torch.from_numpy(config_edge_weight).float()

# Edge endpoints are original (un-relabelled) node ids, so features must
# cover every id up to the largest one referenced by an edge.
num_nodes = int(config_edge_index_t.max()) + 1
config_node_feat = torch.randn(num_nodes, FEAT_DIM)

gcn = GCNConv(FEAT_DIM, FEAT_DIM)
gcn(config_node_feat, config_edge_index_t, edge_weight=config_edge_weight_t)


tensor([[ 0.1426, -1.7442,  1.5470,  ...,  2.1013,  0.6848,  1.3438],
        [-1.2576, -2.0359,  1.1427,  ...,  1.7395,  1.2310,  0.2977],
        [ 0.8821,  0.2583, -0.6701,  ..., -0.9313, -0.0867, -0.8881],
        ...,
        [ 0.3846,  0.9855,  0.0725,  ..., -1.0511, -0.3521, -1.5276],
        [ 0.9387, -0.3078, -0.3238,  ..., -0.1412, -1.0043, -0.8093],
        [ 0.5443, -0.0848,  1.7449,  ...,  1.7213,  0.0217,  1.5148]],
       grad_fn=<AddBackward0>)