In [130]:
import torch
from torch_geometric.nn import SAGEConv

class GraphSAGE(torch.nn.Module):
    """One-layer GraphSAGE model built from a single ``SAGEConv``.

    Args:
        in_channels: Number of input features per node.
        out_channels: Number of output features per node.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # The attribute name determines the parameter keys in state_dict
        # ("conv1.*"), so it is kept as-is.
        self.conv1 = SAGEConv(in_channels, out_channels)

    def forward(self, x, edge_index):
        """Apply the SAGE convolution and return the updated node features.

        Args:
            x: Node feature matrix; in this notebook it is built as
                (num_nodes, in_channels).
            edge_index: Integer connectivity tensor; in this notebook it is
                built with shape (2, num_edges).

        Returns:
            The result of ``self.conv1(x, edge_index)``.
        """
        return self.conv1(x, edge_index)

# Smoke test on random data: 10 input features per node -> 2 output features per node.
torch.manual_seed(0)  # fix weight init and dummy data so re-runs reproduce the same `out`
net = GraphSAGE(10, 2)

# Dummy data
x = torch.randn((100, 10))  # 100 nodes with 10 features each
edge_index = torch.randint(0, 100, (2, 500))  # 500 random edges as a (2, num_edges) tensor of node indices

# Forward pass
out = net(x, edge_index)
assert out.shape == (100, 2), f"unexpected output shape {tuple(out.shape)}"


In [131]:
# Load the pre-built combined graph from disk.
# NOTE: pickle.load can execute arbitrary code embedded in the file, so only
# load pickles you produced yourself or otherwise trust.
import pickle

# Relative paths resolve against the kernel's current working directory.
GRAPH_PKL_PATH = '../../Dominators/graph/cat_combined_graph.pkl'

with open(GRAPH_PKL_PATH, 'rb') as graph_file:
    graph = pickle.load(graph_file)

In [132]:
# function that converts intger into binary list with fixed length
def int2bin(num, max_length):
    return [int(x) for x in bin(num)[2:].zfill(max_length)]


In [133]:
# iterate through the graph and get a mapping of node to label
node2label = {}
for node in graph.nodes():
    if node.split("__")[0] not in node2label:
        node2label[node.split("__")[0]] = len(node2label)

max_len = len(bin(len(node2label))[2:])

# change the label to list in node2label
for node in node2label:
    node2label[node.split("__")[0]] = int2bin(node2label[node.split("__")[0]], max_len)

In [134]:
# create a mapping of node to idx
node2idx = {}
for idx, node in enumerate(graph.nodes()):
    node2idx[node] = idx

In [135]:
# create an edge list
edge_list_start = []
edge_list_end = []

for edge in graph.edges():
    edge_list_start.append(node2idx[edge[0]])
    edge_list_end.append(node2idx[edge[1]])

edge_list = [edge_list_start, edge_list_end]
edge_list = torch.tensor(edge_list, dtype=torch.long)

# create a feature matrix
feature_matrix = []
for node in graph.nodes():
    feature_matrix.append(node2label[node.split("__")[0]])

feature_matrix = torch.tensor(feature_matrix, dtype=torch.float)

In [136]:
# experiment with the model

# Assuming your graph has 10 features per node and you want to output 2 features per node
net = GraphSAGE(max_len, 128)

# Forward pass
out = net(feature_matrix, edge_list)