# First approach, simple neural networks

In [None]:
import pickle
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from typing import List, Tuple
import graph
from graph import *
from node import *
from part import *
from sklearn.model_selection import train_test_split
from evaluation import MyPredictionModel, evaluate, load_model, EdgePredictor, OmarPredictionModel


In [None]:
# Load the pickled graphs.
# NOTE(security): pickle.load can execute arbitrary code embedded in the file;
# only load ./data/graphs.dat when it comes from a trusted source.
with open('./data/graphs.dat', 'rb') as file:
    train_graphs_list: List[Graph] = pickle.load(file)

# Hold out 20% of the graphs for evaluation (fixed seed for reproducibility).
# The split does not need the file handle, so it runs after the file is closed.
train_graphs_list, test_graphs = train_test_split(train_graphs_list, test_size=0.2, random_state=42)

# Collect every part id and family id that occurs in the training split.
all_part_ids = []
all_family_ids = []
# The loop variable is deliberately not called `graph`: at module level that
# name would rebind (and hide) the imported `graph` module for all later cells.
for g in train_graphs_list:
    for n in g.get_nodes():
        node_part = n.get_part()
        all_part_ids.append(int(node_part.get_part_id()))
        all_family_ids.append(int(node_part.get_family_id()))

# Vocabulary size is "largest id + 1" so that every seen id is a valid index.
# NOTE(review): sizes are derived from the training split only; an id that
# occurs solely in test_graphs and exceeds these maxima would be out of
# range - confirm this cannot happen for this data set.
part_vocab_size = max(all_part_ids) + 1
family_vocab_size = max(all_family_ids) + 1
print(f"Part Vocab Size: {part_vocab_size}")
print(f"Family Vocab Size: {family_vocab_size}")

In [None]:
############################################################
# 2) Build a PyTorch Dataset
############################################################
class LinkPredictionDataset(Dataset):
    """Pairwise link-prediction samples derived from a list of graphs.

    For every graph, each ordered pair of distinct node ids produces one
    sample ``(part_i, family_i, part_j, family_j, label)`` made of plain ints.
    ``label`` is 1 when the two nodes are connected in that graph and 0
    otherwise.  Both orientations ``(i, j)`` and ``(j, i)`` are emitted.
    """

    def __init__(self, graphs: List[Graph]):
        """Materialise all samples for ``graphs`` into ``self.samples``."""
        super().__init__()
        self.samples = []

        for g in graphs:
            nodes = g.get_nodes()
            # get_edges() is consumed as a mapping node -> neighbouring nodes
            # and turned into a set of id-ordered node pairs for fast lookup.
            connected = self.get_edge_list(g.get_edges())

            # Per-graph lookup tables keyed by node id.
            features = {}
            for n in nodes:
                features[n.get_id()] = (
                    n.get_part().get_part_id(),
                    n.get_part().get_family_id(),
                )
            ids = [n.get_id() for n in nodes]
            by_id = {n.get_id(): n for n in nodes}

            for src_id in ids:
                for dst_id in ids:
                    if src_id == dst_id:
                        # A node is never paired with itself.
                        continue
                    src_part, src_family = features[src_id]
                    dst_part, dst_family = features[dst_id]

                    # Edges are undirected: order the two nodes by id so the
                    # lookup key matches how get_edge_list stores pairs.
                    endpoints = sorted(
                        [by_id[src_id], by_id[dst_id]],
                        key=lambda node: node.get_id(),
                    )
                    is_edge = tuple(endpoints) in connected

                    self.samples.append(
                        (
                            int(src_part),
                            int(src_family),
                            int(dst_part),
                            int(dst_family),
                            int(is_edge),
                        )
                    )

    def __len__(self):
        """Return the total number of node-pair samples."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return sample ``idx`` as ``(part_i, fam_i, part_j, fam_j, label)``."""
        return self.samples[idx]

    def get_edge_list(self, __edges: Dict[Node, List[Node]]):
        """Flatten an adjacency mapping into a set of undirected node pairs.

        Each ``(source, neighbour)`` entry becomes a 2-tuple of nodes ordered
        by ``get_id()``, so ``(a, b)`` and ``(b, a)`` collapse into a single
        element.  Despite the method name, the result is a ``set``.
        """
        return {
            tuple(sorted([src, dst], key=lambda n: n.get_id()))
            for src, neighbors in __edges.items()
            for dst in neighbors
        }


def train_edge_predictor(model, train_graphs_list, optimizer, criterion, epochs=100, dataloader=None):
    """Train a pairwise edge classifier and print the mean loss per epoch.

    Args:
        model: Module called as ``model(part_i, fam_i, part_j, fam_j)`` with
            long tensors; its output is handed to ``criterion`` unchanged.
        train_graphs_list: Graphs from which the training samples are built
            when no ``dataloader`` is passed.
        optimizer: Optimizer that updates the model's parameters.
        criterion: Loss called as ``criterion(logits, label)`` with float labels.
        epochs: Number of passes over the batches.
        dataloader: Optional iterable yielding
            ``(part_i, fam_i, part_j, fam_j, label)`` tensor batches.  When
            omitted, a shuffled ``DataLoader`` with batch size 64 is built
            over ``LinkPredictionDataset(train_graphs_list)``, so the function
            no longer depends on a module-level ``dataloader`` variable.

    Returns:
        None. The model is updated in place.
    """
    if dataloader is None:
        # Building the dataset enumerates all node pairs of every graph; pass
        # a ready-made loader to avoid repeating that work.
        dataloader = DataLoader(
            LinkPredictionDataset(train_graphs_list),
            batch_size=64,
            shuffle=True,
        )

    model.train()
    for epoch in range(epochs):
        total_loss = 0.0
        num_batches = 0
        for part_i, fam_i, part_j, fam_j, label in dataloader:
            optimizer.zero_grad()
            # Ids are cast to long (index type), labels to float (loss target).
            logits = model(part_i.long(), fam_i.long(), part_j.long(), fam_j.long())
            loss = criterion(logits, label.float())
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            num_batches += 1

        # Counting batches keeps the average defined for an empty loader.
        avg_loss = total_loss / num_batches if num_batches else 0.0
        print(f"Epoch {epoch+1}/{epochs} - Loss: {avg_loss:.4f}")



### Now train it and write it to disk

In [None]:
# Turn the training graphs into node-pair samples and batch them.
dataset = LinkPredictionDataset(train_graphs_list)
dataloader = DataLoader(dataset, batch_size=64, shuffle=True)

# Binary edge / no-edge objective computed on raw logits.
criterion = nn.BCEWithLogitsLoss()

# Embedding-based pair classifier and its Adam optimizer.
model_EdgePredictor = EdgePredictor(
    part_vocab_size,
    family_vocab_size,
    embed_dim=16,
    hidden_dim=32,
)
optimizer = optim.Adam(model_EdgePredictor.parameters(), lr=0.001)

# Fit for 50 epochs, then persist only the learned weights (state dict).
train_edge_predictor(model_EdgePredictor, train_graphs_list, optimizer, criterion, epochs=50)
torch.save(model_EdgePredictor.state_dict(), "model_EdgePredictor.pth")


### Now evaluate it

In [None]:
# Load the prediction model from the weights file written by the training cell.
model_file_path = 'model_EdgePredictor.pth'
prediction_model: MyPredictionModel = load_model(model_file_path)

# Score the model on at most the first 500 graphs of the held-out test split.
# Each instance pairs a graph's parts with the graph itself.
held_out_graphs = test_graphs[:500]
instances = [(g.get_parts(), g) for g in held_out_graphs]
eval_score = evaluate(prediction_model, instances)
print(eval_score)


# Second Method: GNN

In [None]:
def train_graph_predictor(model, train_graphs_list, optimizer, criterion, epochs=100):
    """Train a whole-graph adjacency predictor, one graph per optimizer step.

    For each graph the nodes are sorted by ``(part_id, family_id)``, the model
    is called as ``model(part_ids, family_ids)`` on long tensors in that order,
    and the flattened output is compared with the flattened adjacency matrix
    requested in the same part order.

    Args:
        model: Module whose flattened output has one logit per adjacency entry.
        train_graphs_list: Sized collection of graphs to train on; may be empty.
        optimizer: Optimizer that updates the model's parameters.
        criterion: Loss called as ``criterion(logits_flat, target_flat)``.
        epochs: Number of passes over ``train_graphs_list``.

    Returns:
        None. The model is updated in place and progress is printed.
    """
    model.train()
    # An empty training list would otherwise divide by zero in the average.
    num_graphs = max(len(train_graphs_list), 1)

    for epoch in range(epochs):
        print("EPOCH:", epoch)
        total_loss = 0.0

        # `g` rather than `graph`, which is also the name of an imported module.
        for g in train_graphs_list:
            optimizer.zero_grad()

            # Fix a deterministic node order shared by the inputs and the target.
            nodes = sorted(
                g.get_nodes(),
                key=lambda node: (node.get_part().get_part_id(), node.get_part().get_family_id())
            )
            parts = [node.get_part() for node in nodes]

            # Integer ids for the model input.
            part_ids = torch.tensor(
                [int(p.get_part_id()) for p in parts],
                dtype=torch.long
            )
            family_ids = torch.tensor(
                [int(p.get_family_id()) for p in parts],
                dtype=torch.long
            )

            # Target adjacency, requested in the same order as the id tensors.
            # NOTE(review): assumes get_adjacency_matrix returns a square 0/1
            # matrix indexed by the given part order - confirm in graph.py.
            adjacency_matrix = torch.tensor(
                g.get_adjacency_matrix(tuple(parts)),
                dtype=torch.float32
            )

            logits = model(part_ids, family_ids)

            # Compare entry by entry after flattening both sides.
            loss = criterion(logits.flatten(), adjacency_matrix.flatten())

            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        avg_loss = total_loss / num_graphs
        print(f"Epoch {epoch+1}/{epochs} - Loss: {avg_loss:.4f}")


### Now train it and write it to disk

In [None]:
# Binary objective on raw logits, one logit per adjacency entry.
criterion = nn.BCEWithLogitsLoss()

# GNN-based predictor and its Adam optimizer.
model = OmarPredictionModel(
    part_vocab_size,
    family_vocab_size,
    embed_dim=1,
    gnn_hidden_dim=32,
)
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Fit for 100 epochs, then persist only the learned weights (state dict).
train_graph_predictor(model, train_graphs_list, optimizer, criterion, epochs=100)
torch.save(model.state_dict(), "graph_predictor_model.pth")



### Now evaluate it

In [None]:
# Load the prediction model from the weights file written by the training cell.
model_file_path = 'graph_predictor_model.pth'
prediction_model: MyPredictionModel = load_model(model_file_path)

# Score the model on at most the first 500 graphs of the held-out test split.
# Each instance pairs a graph's parts with the graph itself.
held_out_graphs = test_graphs[:500]
instances = [(g.get_parts(), g) for g in held_out_graphs]
eval_score = evaluate(prediction_model, instances)
print(eval_score)
