In [14]:
import os
import random
import torch.nn as nn
import torch.nn.functional as F

from torch_geometric.data import Data
from torch_geometric.nn import GCNConv, GAE
from torch.utils.data import Dataset
from graph_loader import load_graphs

In [15]:
SEED = 0
random.seed(SEED)

In [None]:
"""

    graph.get_edges()

    Edge:
        node: Aktuelle Node
        connected_nodes: List an Nodes, von node

        E.g.: edges.get_items() liefer alle edges:
        node: Node
        connected_nodes: [Dict(Nodes)]
        Verbindung Node 2 zu Node 0:
            Node(NodeID=2, Part=Part(PartID=58, FamilyID=31))
            [Node(NodeID=0, Part=Part(PartID=1621, FamilyID=0))]

        Verbindung Node 0 zu Nodes 1, 2, 3, 4, 5:
            Node(NodeID=0, Part=Part(PartID=1621, FamilyID=0)),
            [Node(NodeID=1, Part=Part(PartID=58, FamilyID=31)), Node(NodeID=2, Part=Part(PartID=58, FamilyID=31)), Node(NodeID=3, Part=Part(PartID=58, FamilyID=31)), Node(NodeID=4, Part=Part(PartID=58, FamilyID=31))]




"""

In [84]:
def create_edge_list(graph):
    edge_set = set()
    edges = graph.get_edges()
    for node, connected_nodes in edges.items():
        for connected_node in connected_nodes:

            # Store edges by node_ID and part_ID + node_ID and part_ID or source and target
            # Make sure each edge is only stored once (unidirectionally)
            edge = tuple(sorted((
                (node.get_id(), node.get_part().get_part_id()),
                (connected_node.get_id(), connected_node.get_part().get_part_id())
            )))
            edge_set.add(edge)

    return list(edge_set)


In [85]:
def create_part_list(graph):
    part_list = []
    nodes = graph.get_nodes()
    for node in nodes:
        part_list.append((node.get_id(), node.get_part().get_part_id()))
    return part_list

In [102]:
class GraphDataset(Dataset):
    def __init__(self, file_path: str, train=False, validation=False, test=False, seed=42):
        if not os.path.exists(file_path):
            raise FileNotFoundError(f"Dataset file not found at {file_path}")

        self.graphs = load_graphs(file_path)

        if sum([train, validation, test]) != 1:
            raise ValueError("Exactly one of 'train', 'validation', or 'test' must be True.")

    def __getitem__(self, idx):
        graph = self.graphs[idx]
        return graph

In [103]:
def prepare_graph_data(graph_dataset):
    edge_list_dict = {}
    part_list_dict = {}

    for index, graph in enumerate(graph_dataset):
        edge_list_dict[index] =  create_edge_list(graph)
        part_list_dict[index] = create_part_list(graph)

    return edge_list_dict, part_list_dict


In [104]:
#training_set = GraphDataset("data/graphs.dat", train = True, seed=SEED)
#validation_set = GraphDataset("data/graphs.dat", validation = True, seed=SEED)
testing_set = GraphDataset("data/graphs.dat", test = True, seed=SEED)

In [107]:
edge_list, parts_list = prepare_graph_data(testing_set)
print(len(edge_list), len(parts_list))
print(edge_list[0])
print(parts_list[0])

11159 11159
[((0, '1621'), (4, '58')), ((0, '1621'), (3, '58')), ((0, '1621'), (2, '58')), ((0, '1621'), (1, '58'))]
[(2, '58'), (3, '58'), (1, '58'), (4, '58'), (0, '1621')]


# Model: Graph Neural Network

In [None]:
class GCNEncoder(nn.Module):
    def __init__(self, input_dim, hidden_dim, embedding_dim):
        super(GCNEncoder, self).__init__()
        self.conv1 = GCNConv(input_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, embedding_dim)

    def forward(self, x, edge_index):
        x = F.relu(self.conv1(x, edge_index))
        x = self.conv2(x, edge_index)
        return x

# Initialize GCN