In [None]:
from graph import Graph
from part_wrapper import PartWrapper

# Util libraries
import os
import pickle
from typing import List, Set, Dict, Tuple

# ML libraries
from sklearn.model_selection import train_test_split
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset

In [None]:
# Load every pickled graph from disk.
# NOTE: unpickling can execute arbitrary code, so only open files from a trusted source.
with open('data/graphs.dat', 'rb') as file:
    all_graphs: List[Graph] = pickle.load(file)

# Features are each graph's parts, targets are the graphs themselves.
# First split: 70% train / 30% held out; second split halves the held-out part
# into validation and test. Both splits use a fixed random_state.
X_train, X_temp, y_train, y_temp = train_test_split(
    [graph.get_parts() for graph in all_graphs],
    all_graphs,
    test_size=0.3,
    random_state=0,
)
X_val, X_test, y_val, y_test = train_test_split(
    X_temp,
    y_temp,
    test_size=0.5,
    random_state=0,
)

# Quick look at the first training sample and the size of the training split
print(X_train[0])
print(y_train[0])
print(len(y_train))
print(y_train[0].get_edges())

In [None]:
# Prepare the data for embedding:
# 1. Map each part to an index; the index determines the part's one-hot encoding
# 2. Create a training set for the embedding. Each training sample consists of the one-hot encoding of a part and the one-hot encoding of one of its neighbors

def map_parts_to_index(X_train: List[Set["PartWrapper"]]) -> Tuple[Dict["PartWrapper", int], int]:
    """Assign a consecutive index to every distinct part in the training set.

    The index of a part is the position of the 1 in its one-hot encoding.

    Args:
        X_train: one collection of parts per training sample.

    Returns:
        A tuple ``(parts_dict, size)`` where ``parts_dict`` maps each distinct
        part to its index (0 .. size - 1) and ``size`` is the number of
        distinct parts.
    """
    # dict.fromkeys drops duplicates while keeping first-seen order, so indices
    # follow the order in which parts are first encountered.
    # NOTE(review): when the samples are real sets, that encounter order depends
    # on set iteration order; confirm whether the mapping needs to be persisted
    # next to anything derived from it.
    unique_parts = dict.fromkeys(part for parts in X_train for part in parts)
    parts_dict = {part: index for index, part in enumerate(unique_parts)}
    return parts_dict, len(parts_dict)

def create_embedding_training_set(X_train: List[Set["PartWrapper"]], graphs: List["Graph"]) -> List[Tuple[List[int], List[int]]]:
    """Build (part, neighbor) one-hot pairs for training the embedding.

    Args:
        X_train: one collection of parts per training sample; defines the
            part-to-index mapping and therefore the one-hot vector length.
        graphs: the graphs whose edges are turned into training pairs.

    Returns:
        A list of tuples. Each tuple holds the one-hot encoding of a part and
        the one-hot encoding of one of its neighbors. All pairs that share the
        same source node reuse the same source list object.

    Raises:
        KeyError: if a graph contains a part that does not occur in X_train.
    """
    training_set = []
    # Build the part vocabulary once; it is the same for every graph.
    mapped_parts, mapped_parts_size = map_parts_to_index(X_train)

    def one_hot(part) -> List[int]:
        # Fresh zero vector with a single 1 at the part's index
        vector = [0] * mapped_parts_size
        vector[mapped_parts[part]] = 1
        return vector

    # iterate through each graph in the training set
    for graph in graphs:
        # Fetch the edges once per graph. One node is the key, its neighbors are the values
        edges = graph.get_edges()
        for node in edges:
            part_one_hot_encoded = one_hot(node.get_part())
            # iterate through all neighbors of a part
            for neighbor_node in edges[node]:
                neighbor_part_one_hot_encoded = one_hot(neighbor_node.get_part())
                training_set.append((part_one_hot_encoded, neighbor_part_one_hot_encoded))

    return training_set

# Sanity check: how many (part, neighbor) pairs does the training split produce?
embedding_pair_count = len(create_embedding_training_set(X_train, y_train))
print(embedding_pair_count)

# Create a dataset with the right format for the network
class EmbeddingDataset(Dataset):
    """Dataset over (part, neighbor) pairs that yields them as float tensors."""

    def __init__(self, data):
        # list of (part, neighbor) tuples, kept as given
        self.data = data

    def __len__(self):
        """Number of stored pairs."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the pair at ``idx`` as two ``torch.float`` tensors."""
        source_vec, target_vec = self.data[idx]
        source_tensor = torch.tensor(source_vec, dtype=torch.float)
        target_tensor = torch.tensor(target_vec, dtype=torch.float)
        return source_tensor, target_tensor
    
# Wrap the (part, neighbor) one-hot pairs of the training split so a DataLoader can batch them
embedding_training_set = EmbeddingDataset(
    data=create_embedding_training_set(X_train=X_train, graphs=y_train)
)



In [None]:
# Create neural network to learn the embeddings

# Pick the compute device: prefer CUDA, then the MPS backend, and fall back to the CPU
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
print(f"Using {device} device")


# Define the network architecture
class EmbeddingNetwork(nn.Module):
    """Feed-forward network: linear layer, ReLU, linear layer.

    The output is raw scores (no softmax), as expected by nn.CrossEntropyLoss.
    """

    def __init__(self, input_size: int, hidden_size: int, output_size: int):
        super().__init__()
        self.fc1 = nn.Linear(in_features=input_size, out_features=hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(in_features=hidden_size, out_features=output_size)

    def forward(self, x):
        """Map ``x`` through fc1, ReLU and fc2 and return the unnormalized scores."""
        hidden = self.relu(self.fc1(x))
        # No softmax here: nn.CrossEntropyLoss works on raw scores
        return self.fc2(hidden)
    

# Define the model
# Seed PyTorch so weight initialisation and batch shuffling are repeatable between runs
torch.manual_seed(0)

# Build the part vocabulary once; input and output are both one-hot vectors over all parts
input_size = map_parts_to_index(X_train)[1] # Number of different parts in the training set
hidden_size = 10 # Number of embedding dimensions
output_size = input_size # Number of different parts in the training set
model = EmbeddingNetwork(input_size, hidden_size, output_size)
model.to(device)

print(model)

# Loss and Optimizer
# The targets are one-hot float vectors, which nn.CrossEntropyLoss interprets as
# class probabilities (needs a PyTorch version that supports probability targets).
criterion = nn.CrossEntropyLoss() # Cross-entropy loss
optimizer = torch.optim.Adam(model.parameters()) # Example optimizer

# Train the model
# Reuse the pairs already held by the dataset instead of rebuilding the whole training set
train_dataset = embedding_training_set.data
train_loader = DataLoader(dataset=embedding_training_set, batch_size=64, shuffle=True)

num_epochs = 100
for epoch in range(num_epochs):
    print(f"Epoch {epoch+1}")
    for parts, neighbors in train_loader:
        optimizer.zero_grad()
        parts = parts.to(device)
        neighbors = neighbors.to(device)
        output = model(parts)
        loss = criterion(output, neighbors)
        loss.backward()
        optimizer.step()

# Extract the embeddings from the model and save them in pytorch format
# fc1.weight has shape (hidden_size, input_size): column i belongs to the part with
# index i, and len(embeddings) is therefore hidden_size. The fc1 bias is not saved.
embeddings = model.fc1.weight.detach()
detached_embeddings = embeddings.cpu()
print(len(embeddings))
print(embeddings)
# Make sure the target directory exists before writing
os.makedirs('./models', exist_ok=True)
torch.save(detached_embeddings, './models/embeddings.pt')



In [None]:
# Example code to load the embeddings in a new environment
loaded_weights = torch.load('./models/embeddings.pt')

# Move the weights to the device used by the new model/environment.
# This example targets MPS when it is available and otherwise stays on the CPU.
if torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
loaded_weights = loaded_weights.to(device)
print(loaded_weights)

In [None]:
# check equivalence implementation
# Scratch checks of how PartWrapper instances compare: part1 and part2 are built
# from identical arguments, part3 differs in the second argument.
# NOTE(review): the arguments are presumably (part_id, family_id) - confirm against PartWrapper.
part1 = PartWrapper(202, 2)
part2 = PartWrapper(202, 2)
part3 = PartWrapper(202, 3)
# Object identity vs. hash of two separately constructed but identically parameterised parts
print(f"Part1: id={id(part1)}, hash={hash(part1)}")
print(f"Part2: id={id(part2)}, hash={hash(part2)}")

# PartWrapper's own equivalence check
print(part1.equivalent(part2))

# Field-by-field comparison through the getters
print(part1.get_part_id() == part2.get_part_id())
print(part1.get_family_id() == part2.get_family_id())
# Equality on fresh temporaries, once through the operator and once by calling __eq__ directly
print(PartWrapper(202, 2) == PartWrapper(202, 2))
print(PartWrapper(202, 2).__eq__(PartWrapper(202, 2)))  # direct call bypasses the operator's fallback handling
# Same arguments (part1 vs part2) and differing second argument (part1 vs part3)
print(part1 == part2)
print(part1 == part3)