In [33]:
%load_ext autoreload
%autoreload 2
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
import pickle
from typing import Dict, List, Set, Tuple

from graph import Graph
from part import Part
from node import Node


# Seed torch's global RNG so that random draws made by later cells start from
# a fixed state. This only holds when the notebook is run top-to-bottom on a
# fresh kernel.
torch.manual_seed(42)


<torch._C.Generator at 0x75d925702790>

In [74]:
# Toy training data: two small directed graphs in plain dict/tuple form.
# A node record holds a graph-local 'node_id' plus its 'part_id' and
# 'family_id'; an edge is a (source node_id, target node_id) tuple.

# Graph 1 -- a three-node chain: 0 -> 1 -> 2
nodes_g1 = [
    {'node_id': node_id, 'part_id': part_id, 'family_id': family_id}
    for node_id, (part_id, family_id) in enumerate([(2, 0), (2, 1), (5, 0)])
]
edges_g1 = [(0, 1), (1, 2)]

# Graph 2 -- node 0 points at nodes 1 and 2, and node 2 points at node 3
nodes_g2 = [
    {'node_id': node_id, 'part_id': part_id, 'family_id': family_id}
    for node_id, (part_id, family_id) in enumerate([(2, 1), (3, 1), (5, 2), (2, 2)])
]
edges_g2 = [(0, 1), (0, 2), (2, 3)]

# One (nodes, edges) pair per graph.
train_graphs_old = [(nodes_g1, edges_g1), (nodes_g2, edges_g2)]


In [21]:

from typing import List

# Load the pre-built list of Graph objects from disk.
# SECURITY: pickle.load can execute arbitrary code while deserialising, so only
# point this at a graphs.dat file that you produced yourself or fully trust.
# NOTE(review): the saved output of this cell is a KeyboardInterrupt, i.e. the
# load was aborted in the recorded run -- presumably it is slow; confirm.
with open('./data/graphs.dat', 'rb') as file:
    train_graphs_list: List[Graph] = pickle.load(file)
# NOTE(review): no other cell visible here reads the global `x`; it looks like
# a leftover debugging anchor -- confirm before removing.
x = 1


KeyboardInterrupt: 

In [102]:
%load_ext autoreload
%autoreload 2
from importlib import reload

import part
import graph

# Reload the module to reflect changes -- and do it *before* copying its names
# into the notebook namespace. importlib.reload() re-executes the module and
# rebinds the names inside the module object only; names that were already
# pulled in with `from graph import *` are not rebound and would keep pointing
# at the pre-reload objects (e.g. a stale `Graph` class).
reload(graph)

from graph import *
from node import *
from part import *

# Rebuild the two toy graphs with the project's Graph / Part classes.
# The (part_id, family_id) pairs and the edge layout below match
# nodes_g1/edges_g1 and nodes_g2/edges_g2 from the dict-based example cell.
graph1 = Graph(construction_id=123)  # any ID you want

# Create the Part objects
# (part0 and part1 share part_id 2 and differ only in family_id)
part0_g1 = Part(part_id=2, family_id=0)
part1_g1 = Part(part_id=2, family_id=1)
part2_g1 = Part(part_id=5, family_id=0)


# Chain: part0 - part1 - part2.
# NOTE(review): presumably add_edge also registers the parts as nodes of the
# graph (no separate add-node call is made here) -- confirm in graph.py.
graph1.add_edge(part0_g1, part1_g1)
graph1.add_edge(part1_g1, part2_g1)

graph2 = Graph(construction_id=456)

part0_g2 = Part(part_id=2, family_id=1)
part1_g2 = Part(part_id=3, family_id=1)
part2_g2 = Part(part_id=5, family_id=2)
part3_g2 = Part(part_id=2, family_id=2)

# part0 connects to part1 and part2; part2 connects on to part3.
graph2.add_edge(part0_g2, part1_g2)
graph2.add_edge(part0_g2, part2_g2)
graph2.add_edge(part2_g2, part3_g2)

# Training set in Graph-object form, consumed by LinkPredictionDataset.
train_graphs = [graph1, graph2]



The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


KeyboardInterrupt: 

In [77]:
from torch.utils.data import Dataset
import torch

class LinkPredictionDatasetOld(Dataset):
    """Pairwise link-prediction samples built from plain dict/tuple graphs.

    ``graphs`` is an iterable of ``(nodes, edges)`` pairs. ``nodes`` is a list
    of dicts with the keys ``'node_id'``, ``'part_id'`` and ``'family_id'``;
    ``edges`` is an iterable of ``(source_id, target_id)`` tuples.

    For every ordered pair of distinct node ids in a graph, one sample
    ``(part_i, fam_i, part_j, fam_j, label)`` is stored, where ``label`` is 1
    if the directed edge ``(i, j)`` is listed in ``edges`` and 0 otherwise.
    """

    def __init__(self, graphs):
        self.samples = []
        for node_records, edge_pairs in graphs:
            # node_id -> (part_id, family_id)
            features_by_id = {}
            for record in node_records:
                node_id = record['node_id']
                features_by_id[node_id] = (record['part_id'], record['family_id'])

            directed_edges = set(edge_pairs)
            ordered_ids = [record['node_id'] for record in node_records]

            # Source-major enumeration of all ordered pairs, skipping self-pairs.
            self.samples.extend(
                (
                    *features_by_id[src],
                    *features_by_id[dst],
                    int((src, dst) in directed_edges),
                )
                for src in ordered_ids
                for dst in ordered_ids
                if src != dst
            )

    def __len__(self):
        """Number of stored node-pair samples."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return the ``(part_i, fam_i, part_j, fam_j, label)`` tuple at ``idx``."""
        return self.samples[idx]


In [72]:
from torch.utils.data import Dataset
import torch

class LinkPredictionDataset(Dataset):
    """Pairwise link-prediction samples built from project ``Graph`` objects.

    For every ordered pair of distinct node ids in a graph, one sample
    ``(part_i, fam_i, part_j, fam_j, label)`` is stored. The label lookup
    sorts the two Node objects before testing membership in the edge set, so
    ``(i, j)`` and ``(j, i)`` always receive the same label.
    """

    def __init__(self, graphs: List[Graph]):
        """
        graphs: list of Graph objects
        """
        self.samples = []

        for g in graphs:
            nodes = g.get_nodes()                # List[Node]
            known_edges = set(g.get_edges())     # set of (Node, Node) tuples

            # node id -> (part_id, family_id), the features of one endpoint
            features_by_id = {}
            for node in nodes:
                features_by_id[node.get_id()] = (
                    node.get_part().get_part_id(),
                    node.get_part().get_family_id(),
                )

            # Ids in node-list order, plus a reverse map back to the Node
            # objects, which are what the edge tuples contain.
            ordered_ids = [node.get_id() for node in nodes]
            node_by_id = {node.get_id(): node for node in nodes}

            for src in ordered_ids:
                for dst in ordered_ids:
                    if src == dst:
                        continue  # no self-pairs

                    # NOTE(review): assumes get_edges() yields each edge as an
                    # ascending (smaller, larger) Node tuple and that Node
                    # defines an ordering -- confirm in graph.py / node.py.
                    endpoints = tuple(sorted((node_by_id[src], node_by_id[dst])))
                    is_edge = int(endpoints in known_edges)

                    self.samples.append(
                        (*features_by_id[src], *features_by_id[dst], is_edge)
                    )

    def __len__(self):
        """Number of stored node-pair samples."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return the ``(part_i, fam_i, part_j, fam_j, edge_label)`` tuple at ``idx``."""
        return self.samples[idx]


In [94]:
# Find maximum IDs to size embeddings.
# Ids are used directly as embedding indices, so each vocabulary must have
# (largest observed id + 1) rows.
#
# Comprehensions keep the loop variables out of the notebook's global
# namespace. In particular nothing here may be named `graph`: that global must
# stay bound to the imported `graph` module, otherwise `reload(graph)` in the
# reload cell raises TypeError on the next run.
train_parts = [
    node.get_part()
    for train_graph in train_graphs
    for node in train_graph.get_nodes()
]
all_part_ids = [p.get_part_id() for p in train_parts]
all_family_ids = [p.get_family_id() for p in train_parts]

part_vocab_size = max(all_part_ids) + 1
family_vocab_size = max(all_family_ids) + 1

# NOTE(review): the vocab sizes come from train_graphs, while the training
# cell feeds dataset_2 (built from train_graphs_old). Both toy sets use the
# same ids today; keep them in sync or size the vocab from the data trained on.
dataset = LinkPredictionDataset(train_graphs)
dataset_2 = LinkPredictionDatasetOld(train_graphs_old)
print("Number of training pairs:", len(dataset))



Number of training pairs: 18


In [95]:
class EdgePredictor(nn.Module):
    """Scores an ordered node pair (i, j) from part and family ids.

    Each endpoint is described by a part embedding and a family embedding.
    The four vectors are concatenated and passed through a two-layer MLP that
    emits one raw logit per pair; no sigmoid is applied inside the model.

    Args:
        part_vocab_size: number of rows in the part embedding table.
        family_vocab_size: number of rows in the family embedding table.
        embed_dim: width of every embedding vector.
        hidden_dim: width of the hidden layer.
    """

    def __init__(self, part_vocab_size, family_vocab_size,
                 embed_dim=16, hidden_dim=32):
        super().__init__()
        # The same two lookup tables serve both endpoints of a pair.
        self.part_embedding = nn.Embedding(part_vocab_size, embed_dim)
        self.family_embedding = nn.Embedding(family_vocab_size, embed_dim)
        # (part, family) for i plus (part, family) for j -> four embeddings wide.
        concat_dim = embed_dim * 4
        self.fc1 = nn.Linear(concat_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, 1)
        self.relu = nn.ReLU()

    def forward(self, part_i, fam_i, part_j, fam_j):
        """Return one logit per pair; the size-1 output column is squeezed away."""
        pair_repr = torch.cat(
            [
                self.part_embedding(part_i),
                self.family_embedding(fam_i),
                self.part_embedding(part_j),
                self.family_embedding(fam_j),
            ],
            dim=1,
        )
        hidden = self.relu(self.fc1(pair_repr))
        return self.fc2(hidden).squeeze(1)



In [96]:
# Training hyperparameters.
BATCH_SIZE = 64
EPOCHS = 500
LR = 0.001

# NOTE(review): training uses dataset_2 (the dict-based samples from
# LinkPredictionDatasetOld); `dataset`, built from the Graph objects, is not
# used in this cell -- confirm that this is intended.
dataloader = DataLoader(dataset_2, batch_size=BATCH_SIZE, shuffle=True)

# Positional 16 and 32 are embed_dim and hidden_dim.
model = EdgePredictor(part_vocab_size, family_vocab_size, 16, 32)
# EdgePredictor.forward returns raw logits, so the loss must be the
# logit-based binary cross-entropy.
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters(), lr=LR)

model.train()
for epoch in range(EPOCHS):
    total_loss = 0.0  # sum of per-batch losses for this epoch
    for batch in dataloader:
        # Each sample is (part_i, fam_i, part_j, fam_j, label); the loader
        # hands back one batched value per field.
        part_i, fam_i, part_j, fam_j, label = batch
        # Integer indices for the embedding lookups, float targets for the loss.
        part_i = part_i.long()
        fam_i  = fam_i.long()
        part_j = part_j.long()
        fam_j  = fam_j.long()
        label  = label.float()

        optimizer.zero_grad()
        logits = model(part_i, fam_i, part_j, fam_j)
        loss = criterion(logits, label)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    # Unweighted mean over batches (len(dataloader) is the batch count).
    avg_loss = total_loss / len(dataloader)
    print(f"Epoch {epoch+1}/{EPOCHS} - Loss: {avg_loss:.4f}")



Epoch 1/500 - Loss: 0.7597
Epoch 2/500 - Loss: 0.7394
Epoch 3/500 - Loss: 0.7211
Epoch 4/500 - Loss: 0.7042
Epoch 5/500 - Loss: 0.6882
Epoch 6/500 - Loss: 0.6738
Epoch 7/500 - Loss: 0.6609
Epoch 8/500 - Loss: 0.6489
Epoch 9/500 - Loss: 0.6378
Epoch 10/500 - Loss: 0.6277
Epoch 11/500 - Loss: 0.6182
Epoch 12/500 - Loss: 0.6092
Epoch 13/500 - Loss: 0.6004
Epoch 14/500 - Loss: 0.5919
Epoch 15/500 - Loss: 0.5841
Epoch 16/500 - Loss: 0.5767
Epoch 17/500 - Loss: 0.5702
Epoch 18/500 - Loss: 0.5638
Epoch 19/500 - Loss: 0.5575
Epoch 20/500 - Loss: 0.5514
Epoch 21/500 - Loss: 0.5455
Epoch 22/500 - Loss: 0.5399
Epoch 23/500 - Loss: 0.5345
Epoch 24/500 - Loss: 0.5292
Epoch 25/500 - Loss: 0.5239
Epoch 26/500 - Loss: 0.5189
Epoch 27/500 - Loss: 0.5139
Epoch 28/500 - Loss: 0.5090
Epoch 29/500 - Loss: 0.5040
Epoch 30/500 - Loss: 0.4991
Epoch 31/500 - Loss: 0.4943
Epoch 32/500 - Loss: 0.4895
Epoch 33/500 - Loss: 0.4847
Epoch 34/500 - Loss: 0.4798
Epoch 35/500 - Loss: 0.4749
Epoch 36/500 - Loss: 0.4699
E

In [97]:
def predict_edges(model, new_nodes, threshold=0.5):
    """Predict directed edges among a set of nodes with a trained pair scorer.

    Args:
        model: callable taking four 1-element tensors
            ``(part_i, fam_i, part_j, fam_j)`` and returning a logit tensor;
            it is switched to eval mode via ``model.eval()`` and left there.
        new_nodes: iterable of dicts with the keys ``'node_id'``,
            ``'part_id'`` and ``'family_id'``.
        threshold: an edge is kept when ``sigmoid(logit)`` is strictly
            greater than this value.

    Returns:
        List of ``(source_id, target_id)`` tuples, ordered source-major in the
        order the node ids first appear in ``new_nodes``.
    """
    model.eval()

    # node_id -> (part_id, family_id); a repeated node_id keeps the last record.
    features_by_id = {}
    for record in new_nodes:
        node_id = record['node_id']
        features_by_id[node_id] = (record['part_id'], record['family_id'])

    edges = []
    # Inference only: no autograd bookkeeping for any of the forward passes.
    with torch.no_grad():
        for src, (src_part, src_fam) in features_by_id.items():
            for dst, (dst_part, dst_fam) in features_by_id.items():
                if src == dst:
                    continue  # never score a node against itself
                logit = model(
                    torch.tensor([src_part]),
                    torch.tensor([src_fam]),
                    torch.tensor([dst_part]),
                    torch.tensor([dst_fam]),
                )
                probability = torch.sigmoid(logit).item()
                if probability > threshold:
                    edges.append((src, dst))

    return edges


In [100]:
# Node lists (no edges) to query the trained model with.
test_nodes = [
    {'node_id': node_id, 'part_id': part_id, 'family_id': family_id}
    for node_id, (part_id, family_id) in enumerate([(2, 1), (5, 2), (2, 2)])
]
# test_nodes_2 repeats the node features of training graph 2 (nodes_g2), so the
# call below checks how well the model reproduces a graph it was trained on.
test_nodes_2 = [
    {'node_id': node_id, 'part_id': part_id, 'family_id': family_id}
    for node_id, (part_id, family_id) in enumerate([(2, 1), (3, 1), (5, 2), (2, 2)])
]

pred_edges = predict_edges(model, test_nodes_2, threshold=0.5)
print("Predicted edges:", pred_edges)


Predicted edges: [(0, 1), (0, 2), (2, 3)]
