# 0. Setup

## 0.1. Imports

In [19]:
import os
import random
import numpy as np

import torch
from torch.nn import functional as F

from part import Part
from graph import Graph

from torch_geometric.data import Data
from torch_geometric.nn import GCNConv
from torch.utils.data import Dataset
from sklearn.model_selection import train_test_split

from graph_loader import load_graphs
from gensim.models import Word2Vec
from typing import Set, List, Tuple

## 0.2. Hyperparameters

In [20]:
# General: ----------------------------
SEED = 3
# Seed every RNG the notebook relies on: `random` drives the random walks, torch drives the
# Xavier initialisation of the GCN weights; numpy is seeded as well for completeness.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
example_graph = 0

# Random Walks: -----------------------
num_walks = 16              # Walks started from every part ID of a graph
walk_length = 8             # Tokens per walk

# Embedding model: --------------------
embedding_vector_size=16    # Size of the embedding vector
window=5                    # Context window size: how many surrounding tokens are considered. ~5 because that is roughly the average graph size (raising it from 2 to 5 made a noticeable difference)
min_count=1                 # Minimum occurrences of a node in the walks to include it in the vocabulary
sg=1                        # Use Skip-Gram (sg=1) instead of CBOW (sg=0)
workers=4                   # Number of CPU threads to use
embedding_model_epochs=10   # Number of training epochs

# GCN model: --------------------------
gcn_input_dim = embedding_vector_size
gcn_hidden_dim = 32
gcn_output_dim = 16
gcn_learning_rate = 0.025
gcn_epochs = 5

# GAE model: --------------------------
gae_input_dim = embedding_vector_size
gae_hidden_dim = 32
gae_latent_dim = 16
gae_learning_rate = 0.01
gae_epochs = 200
batch_size = 32

# 1 Process Training Data

## 1.1 Helper Functions

In [4]:
def create_part_list(graph_tuple):
    """Return a ``(node_id, part_id)`` pair for every node of a dataset item.

    :param graph_tuple: Dataset item; element ``[1]`` is the graph object.
    :return: List of ``(node.get_id(), node.get_part().get_part_id())`` in ``get_nodes()`` order.
    """
    graph = graph_tuple[1]
    return [(node.get_id(), node.get_part().get_part_id()) for node in graph.get_nodes()]

In [5]:
def create_edge_list(graph_tuple):
    """Collect every undirected edge of a dataset item exactly once.

    Each endpoint is encoded as ``(node_id, int(part_id))`` and the two endpoints are put in
    ascending order, so ``(a, b)`` and ``(b, a)`` collapse into the same edge.

    :param graph_tuple: Dataset item; element ``[1]`` is the graph object whose
        ``get_edges()`` maps each node to its connected nodes.
    :return: List of unique edges ``((node_id, part_id), (node_id, part_id))``.
    """
    def endpoint(node):
        return node.get_id(), int(node.get_part().get_part_id())

    unique_edges = {
        tuple(sorted((endpoint(source), endpoint(target))))
        for source, neighbours in graph_tuple[1].get_edges().items()
        for target in neighbours
    }
    return list(unique_edges)


In [6]:
def prepare_graph_data(graph_dataset):
    """Build per-graph part lists and edge lists for every item of ``graph_dataset``.

    :param graph_dataset: Iterable of dataset items (each handed to ``create_part_list`` and
        ``create_edge_list``).
    :return: ``(part_lists, edge_lists)``: two dicts keyed by the item's enumeration index.
        Part lists are sorted by node ID; edge lists by ``(source node ID, target node ID)``.
    """
    def by_node_id(entry):
        return entry[0]

    def by_endpoints(edge):
        return edge[0][0], edge[1][0]

    part_lists, edge_lists = {}, {}
    for item_index, graph_tuple in enumerate(graph_dataset):
        part_lists[item_index] = sorted(create_part_list(graph_tuple), key=by_node_id)
        edge_lists[item_index] = sorted(create_edge_list(graph_tuple), key=by_endpoints)
    return part_lists, edge_lists



## 1.2 Prepare Datasets

### 1.2.1 Graph Dataset

In [7]:
class GraphDataset(Dataset):
    """One split (train / validation / test) of the graphs stored in a dataset file.

    The file is loaded with ``load_graphs`` and split 70% / 15% / 15% with
    ``train_test_split``. The three splits are only disjoint when they are built from the
    same file with the same ``seed``. Part-ID mappings are built from *all* graphs, so every
    split shares the same global part index.
    """

    def __init__(self, file_path: str, train=False, validation=False, test=False, seed=42):
        if not os.path.exists(file_path):
            raise FileNotFoundError(f"Dataset file not found at {file_path}")
        # Validate the split flags before paying for loading the whole file.
        if sum([train, validation, test]) != 1:
            raise ValueError("Exactly one of 'train', 'validation', or 'test' must be True.")

        self.graphs = load_graphs(file_path)

        # Global mappings over all graphs (not just within a certain split).
        self.family_part_dict = {}  # part_id -> family_id
        unique_parts = set()
        for graph in self.graphs:
            for part in graph.get_parts():
                part_id = int(part.get_part_id())
                unique_parts.add(part_id)
                self.family_part_dict[part_id] = int(part.get_family_id())

        unique_parts = sorted(unique_parts)
        self.total_global_part_to_idx = {part: idx for idx, part in enumerate(unique_parts)}  # part_id -> index
        self.idx_to_part_id = {idx: part for part, idx in self.total_global_part_to_idx.items()}  # index -> part_id
        self.total_num_unique_parts = len(unique_parts)

        # Split: 70% training, 15% validation, 15% test
        train_graphs, held_out_graphs = train_test_split(self.graphs, test_size=0.3, random_state=seed)
        validation_graphs, test_graphs = train_test_split(held_out_graphs, test_size=0.5, random_state=seed)

        if train:
            self.graphs = train_graphs
        elif validation:
            self.graphs = validation_graphs
        else:
            self.graphs = test_graphs

    def __len__(self):
        return len(self.graphs)

    def __getitem__(self, idx):
        """Return ``(parts, graph)`` for the graph at position ``idx`` of this split."""
        graph = self.graphs[idx]
        return graph.get_parts(), graph

    def get_part_frequency_vector(self, idx):
        """Count how often each globally indexed part occurs in graph ``idx`` of this split.

        This computation used to run (and be discarded) on every ``__getitem__`` call.

        :return: int32 array of length ``total_num_unique_parts``; entry ``k`` counts the part
            whose ID maps to ``k`` in ``total_global_part_to_idx``.
        """
        counts = np.zeros(self.total_num_unique_parts, dtype=np.int32)
        for part in self.graphs[idx].get_parts():
            counts[self.total_global_part_to_idx[int(part.get_part_id())]] += 1
        return counts

### 1.2.2 Read Datasets

In [8]:
# All three splits are built from the same file and SEED, so they partition the graphs consistently.
_dataset_path = "data/graphs.dat"
training_set, validation_set, testing_set = (
    GraphDataset(_dataset_path, seed=SEED, **{split_flag: True})
    for split_flag in ("train", "validation", "test")
)

### 1.2.3 Prepare Part and Edge Data for all graphs

In [21]:
# Per-graph (node_id, part_id) lists and deduplicated edge lists for the training split.
graph_parts_dict, graph_edge_dict = prepare_graph_data(graph_dataset=training_set)

## 1.3. Create Embeddings

Embeddings are learned feature vectors for each node's part: e.g. with `embedding_vector_size = 16`, part 1 is represented by a vector of 16 embedding values.

### 1.3.1. Random Walk Embeddings

In [22]:
def generate_random_walks_single_graph(edges, walks_per_node=None, max_walk_length=None, rng=None):
    """Generate uniform random walks over the part IDs of a single graph.

    Nodes are identified by their part ID (``edge[i][1]``), so nodes sharing a part collapse
    into one token. An edge seen several times appears several times in the neighbour list,
    which makes that neighbour proportionally more likely to be chosen.

    :param edges: Iterable of ``((node_id, part_id), (node_id, part_id))`` edges.
    :param walks_per_node: Number of rounds; each round starts one walk from every part ID.
        Defaults to the notebook-level ``num_walks``.
    :param max_walk_length: Number of tokens per walk. Defaults to the notebook-level ``walk_length``.
    :param rng: Object with a ``choice`` method (e.g. ``random.Random``). Defaults to the
        ``random`` module, so results follow ``random.seed(SEED)``.
    :return: List of walks (lists of part IDs), ordered round by round; within a round, start
        nodes follow their first appearance in ``edges``.
    """
    if walks_per_node is None:
        walks_per_node = num_walks
    if max_walk_length is None:
        max_walk_length = walk_length
    if rng is None:
        rng = random

    # Undirected adjacency list keyed by part ID.
    adjacency = {}
    for edge in edges:
        part_a, part_b = edge[0][1], edge[1][1]
        adjacency.setdefault(part_a, []).append(part_b)
        adjacency.setdefault(part_b, []).append(part_a)

    walks = []
    for _ in range(walks_per_node):
        for start in adjacency:
            walk = [start]
            while len(walk) < max_walk_length:
                neighbours = adjacency.get(walk[-1])
                if not neighbours:  # defensive: dead end, stop this walk early
                    break
                walk.append(rng.choice(neighbours))
            walks.append(walk)

    return walks

In [23]:
# Generate random walks for all training graphs.
# graph_edge_dict is keyed by the same enumeration index as training_set (see prepare_graph_data),
# so it is iterated directly instead of re-reading every dataset item just to get its index.
random_walks = {
    index: generate_random_walks_single_graph(edges)
    for index, edges in graph_edge_dict.items()
}

[[1334, 198, 1334, 83, 1334, 58, 1334, 58], [168, 1334, 83, 1334, 58, 1334, 198, 1334], [198, 1334, 83, 1334, 198, 1334, 58, 1334], [83, 1334, 168, 1334, 58, 1334, 83, 1334], [58, 1334, 83, 1334, 83, 1334, 58, 1334], [1334, 198, 1334, 168, 1334, 198, 1334, 198], [168, 1334, 83, 1334, 83, 1334, 58, 1334], [198, 1334, 58, 1334, 83, 1334, 83, 1334], [83, 1334, 58, 1334, 198, 1334, 83, 1334], [58, 1334, 83, 1334, 168, 1334, 168, 1334], [1334, 198, 1334, 83, 1334, 83, 1334, 168], [168, 1334, 83, 1334, 58, 1334, 58, 1334], [198, 1334, 83, 1334, 168, 1334, 58, 1334], [83, 1334, 83, 1334, 58, 1334, 198, 1334], [58, 1334, 83, 1334, 198, 1334, 83, 1334], [1334, 58, 1334, 58, 1334, 58, 1334, 83], [168, 1334, 83, 1334, 83, 1334, 58, 1334], [198, 1334, 83, 1334, 168, 1334, 58, 1334], [83, 1334, 83, 1334, 83, 1334, 58, 1334], [58, 1334, 168, 1334, 168, 1334, 83, 1334], [1334, 83, 1334, 198, 1334, 198, 1334, 83], [168, 1334, 83, 1334, 168, 1334, 58, 1334], [198, 1334, 58, 1334, 83, 1334, 83, 1334], [

### 1.3.2. Train Word2Vec Model with Random Walk Embeddings

In [24]:
# Flatten the per-graph random walks into one corpus of "sentences" for the embedding model.
flat_random_walks = [walk for walks in random_walks.values() for walk in walks]

# Train the part-embedding model. `seed=SEED` ties the vector initialisation to the notebook
# seed instead of gensim's own default. Note: with workers > 1 training is still not
# bit-for-bit reproducible because of thread scheduling (use workers=1 for that).
word2vec_model = Word2Vec(
    sentences=flat_random_walks,
    vector_size=embedding_vector_size,
    window=window,
    min_count=min_count,
    sg=sg,
    workers=workers,
    epochs=embedding_model_epochs,
    seed=SEED,
)
word2vec_model.save("parts_embeddings.model")
print("Embedding model saved \n")

Embedding model saved 



## 1.4. Prepare Training Data

In [17]:
# Create one PyG `Data` object (node features, edges, edge labels) per training graph.
#
# NOTE: `edge_index` of each Data object holds the positive AND the negative pairs, i.e. every
# node pair (i, j) with i < j exactly once. The GCN therefore message-passes over this complete
# graph, and train_model scores exactly these pairs against `edge_label`.
all_graph_data = []
word_vectors = word2vec_model.wv

for graph_id in graph_parts_dict.keys():

    # Retrieve parts and edges for the current graph
    graph_parts = graph_parts_dict[graph_id]
    graph_edges = graph_edge_dict[graph_id]
    num_nodes = len(graph_parts)

    # 1. Node features: the embedding of each node's part ID (nodes are sorted by node ID).
    #    Parts that never appeared in a random walk (e.g. in graphs without edges) are not in the
    #    vocabulary and get a zero vector. `key_to_index` is checked explicitly because, for int
    #    keys, gensim also treats `wv[key]` as a positional index and could silently return the
    #    vector of a different part.
    part_vectors = [
        word_vectors[int(part_id)]
        if int(part_id) in word_vectors.key_to_index
        else np.zeros(word_vectors.vector_size, dtype=np.float32)
        for _, part_id in graph_parts
    ]
    # Stack into one ndarray first: torch.tensor() on a list of ndarrays is extremely slow.
    parts_array = (
        np.stack(part_vectors)
        if part_vectors
        else np.zeros((0, word_vectors.vector_size), dtype=np.float32)
    )
    parts = torch.tensor(parts_array, dtype=torch.float)

    # 2. Positive edges in COO format, shape (2, E); view(-1, 2) keeps the shape valid when E == 0.
    edge_index = torch.tensor(
        [(edge[0][0], edge[1][0]) for edge in graph_edges], dtype=torch.long
    ).view(-1, 2).t().contiguous()

    # 3. Negative edges: every pair (i, j), i < j, that is not a positive edge.
    #    Assumes node IDs are 0 .. num_nodes-1 (they index rows of `parts`) -- TODO confirm.
    #    triu_indices yields the same pairs in the same order as torch.combinations(..., r=2).T
    #    and gives a well-formed (2, 0) tensor for graphs with fewer than two nodes.
    all_edges = torch.triu_indices(num_nodes, num_nodes, offset=1, device=edge_index.device)
    is_positive = (all_edges.unsqueeze(-1) == edge_index.unsqueeze(1)).all(dim=0).any(dim=1)
    neg_edge_index = all_edges[:, ~is_positive]

    # 4. Edge labels: 1 for positive pairs, 0 for negative pairs, in the order of edge_index below.
    pos_edge_label = torch.ones(edge_index.size(1))
    neg_edge_label = torch.zeros(neg_edge_index.size(1))
    edge_label = torch.cat([pos_edge_label, neg_edge_label], dim=0)
    edge_index = torch.cat([edge_index, neg_edge_index], dim=1)

    # 5. Data object for the current graph.
    single_graph_data = Data(x=parts, edge_index=edge_index, edge_label=edge_label)
    all_graph_data.append(single_graph_data)

  parts = torch.tensor(parts_list, dtype=torch.float)


# 3. Setup GCN

## 3.1 Define GCN


In [25]:
class GCN_Graph_Predictor(torch.nn.Module):
    """Two-layer GCN that scores pairs of parts as potential edges of an assembly graph.

    Node embeddings come from two ``GCNConv`` layers. The score of a pair (i, j) is the inner
    product of their embeddings, and its sigmoid is treated as the edge probability. Layer sizes
    and the default number of training epochs are read from the notebook hyperparameters
    (``gcn_input_dim``, ``gcn_hidden_dim``, ``gcn_output_dim``, ``gcn_epochs``).
    """

    def __init__(self):
        super(GCN_Graph_Predictor, self).__init__()
        self.conv1 = GCNConv(gcn_input_dim, gcn_hidden_dim)
        self.conv2 = GCNConv(gcn_hidden_dim, gcn_output_dim)

    def forward(self, x, edge_index):
        """Compute node embeddings: conv1 -> ReLU -> conv2.

        :param x: Node feature matrix, one row per node (presumably (num_nodes, gcn_input_dim)).
        :param edge_index: Message-passing edges as a (2, num_edges) index tensor.
        :return: Node embeddings from the second GCN layer.
        """
        hidden = F.relu(self.conv1(x, edge_index))
        return self.conv2(hidden, edge_index)

    def initialize_weights(self):
        """Re-initialise every GCNConv layer: Xavier-uniform weights, zero biases.

        PyG's ``GCNConv`` may keep its bias on the layer itself (``layer.bias``) rather than on
        its internal linear map (``layer.lin``), so both locations are checked.
        """
        for layer in self.modules():
            if isinstance(layer, GCNConv):
                torch.nn.init.xavier_uniform_(layer.lin.weight)
                for bias in (getattr(layer.lin, "bias", None), getattr(layer, "bias", None)):
                    if bias is not None:
                        torch.nn.init.zeros_(bias)

    @staticmethod
    def _pair_logits(node_embeddings, pair_index):
        """Inner-product score for every (source, target) column of ``pair_index``."""
        return (node_embeddings[pair_index[0]] * node_embeddings[pair_index[1]]).sum(dim=-1)

    def train_model(self, data, optimizer, epochs=None):
        """Train the GCN to reconstruct the edge labels of each graph.

        Each graph is one optimisation step. Node embeddings are computed over
        ``graph_data.edge_index``, every pair in ``edge_index`` is scored, and the scores are
        compared with ``graph_data.edge_label`` using binary cross-entropy.

        :param data: Sequence of PyG ``Data`` objects with ``x``, ``edge_index`` and ``edge_label``.
        :param optimizer: Optimizer over this model's parameters.
        :param epochs: Number of passes over ``data``; defaults to the notebook-level ``gcn_epochs``.
        :return: List with the mean loss of every epoch.
        :raises ValueError: If ``data`` is empty.
        """
        if epochs is None:
            epochs = gcn_epochs
        if len(data) == 0:
            raise ValueError("train_model needs at least one graph")

        self.train()
        all_losses = []

        for epoch in range(epochs):
            total_loss = 0.0

            for graph_data in data:
                optimizer.zero_grad()

                node_embeddings = self(graph_data.x, graph_data.edge_index)
                # Score only the labelled pairs instead of building the full n x n matrix.
                edge_logits = self._pair_logits(node_embeddings, graph_data.edge_index)
                # BCE on logits is numerically stable; the former sigmoid + clamp + BCE gave zero
                # gradient to confidently wrong (clamped) predictions.
                loss = F.binary_cross_entropy_with_logits(edge_logits, graph_data.edge_label)

                loss.backward()
                optimizer.step()
                total_loss += loss.item()

            avg_loss = total_loss / len(data)
            all_losses.append(avg_loss)
            print(f"Epoch {epoch + 1}/{epochs}, Loss: {avg_loss:.4f}")

        return all_losses

    def _embed_parts(self, part_ids: List[int], embedding_model) -> torch.Tensor:
        """Stack the embedding vectors of ``part_ids`` into a float tensor (one row per part).

        Unknown part IDs get a zero vector. ``key_to_index`` is checked explicitly because, for
        int keys, gensim also accepts ``wv[key]`` as a positional index and could silently return
        the vector of a different part.
        """
        word_vectors = embedding_model.wv
        rows = [
            word_vectors[pid]
            if pid in word_vectors.key_to_index
            else np.zeros(word_vectors.vector_size, dtype=np.float32)
            for pid in part_ids
        ]
        # One ndarray first: torch.tensor() on a list of ndarrays is extremely slow.
        return torch.tensor(np.stack(rows), dtype=torch.float)

    def predict_graph(self, parts: Set[Part], embedding_model=None) -> Graph:
        """Predict an assembly graph for the given set of parts.

        Every pair of parts is scored by the GCN, and a spanning tree is built from the pairs
        in order of decreasing edge probability (see ``create_mst``).

        :param parts: Set of Part objects.
        :param embedding_model: Trained Word2Vec model providing node features; defaults to the
            notebook-level ``word2vec_model``.
        :return: The predicted graph (a spanning tree over ``parts``).
        """
        if embedding_model is None:
            embedding_model = word2vec_model

        # Sort by part ID so the node order, and hence the result, is deterministic.
        ordered_parts = sorted(parts, key=lambda p: int(p.get_part_id()))
        num_parts = len(ordered_parts)
        if num_parts < 2:
            return self.create_mst([], ordered_parts)  # no pair to score

        x = self._embed_parts([int(p.get_part_id()) for p in ordered_parts], embedding_model)

        # Candidate pairs (i, j) with i < j. The training graphs' edge_index contains exactly these
        # pairs (positive + negative edges, assuming node IDs 0..n-1), so the same complete graph is
        # used for message passing here. An empty graph would produce node embeddings the model
        # was never trained on.
        pair_index = torch.triu_indices(num_parts, num_parts, offset=1)

        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():  # inference only: no autograd graph needed
                node_embeddings = self(x, pair_index)
                pair_probs = torch.sigmoid(self._pair_logits(node_embeddings, pair_index))
        finally:
            self.train(was_training)

        # Most probable pairs first, as required by the Kruskal step in create_mst.
        order = torch.argsort(pair_probs, descending=True)
        candidate_edges = [tuple(pair) for pair in pair_index[:, order].t().tolist()]
        return self.create_mst(candidate_edges, ordered_parts)

    def create_mst(
        self,
        edge_list: List[Tuple[int, int]],
        parts: List[Part],
        draw: bool = True,
    ) -> Graph:
        """Build a spanning tree over ``parts`` from candidate edges using Kruskal's algorithm.

        Edges are taken in the given order and kept only if they connect two parts that are not
        yet connected, so cycles and self-loops are skipped. When ``edge_list`` is sorted by
        decreasing edge probability, as ``predict_graph`` provides it, the result is the
        maximum-weight spanning tree with respect to those probabilities. If the candidates do
        not connect all parts, a spanning forest is returned. Presumably parts without a kept
        edge do not appear in the Graph, since nodes are only added via
        ``add_undirected_edge`` here (TODO confirm against Graph).

        :param edge_list: Candidate edges as ``(source_index, target_index)`` into ``parts``.
        :param parts: Parts in the order referenced by ``edge_list``.
        :param draw: Whether to call ``draw()`` on the resulting graph.
        :return: The constructed Graph.
        """
        parent = list(range(len(parts)))  # union-find forest over part indices

        def find(index):
            while parent[index] != index:
                parent[index] = parent[parent[index]]  # path halving
                index = parent[index]
            return index

        building_graph = Graph()
        for source_index, target_index in edge_list:
            source_root, target_root = find(source_index), find(target_index)
            if source_root == target_root:
                continue  # would close a cycle
            parent[source_root] = target_root
            building_graph.add_undirected_edge(parts[source_index], parts[target_index])

        if draw:
            building_graph.draw()

        return building_graph


## 3.2 Initialize GCN

In [26]:
# Fresh GCN with Xavier-initialised layers, optimised with Adam.
gcn_model = GCN_Graph_Predictor()
gcn_model.initialize_weights()
optimizer = torch.optim.Adam(params=gcn_model.parameters(), lr=gcn_learning_rate)

## 3.3 Train GCN

In [27]:
# Train on every prepared graph; the returned per-epoch mean losses are displayed as cell output.
gcn_model.train_model(data=all_graph_data, optimizer=optimizer)

KeyboardInterrupt: 