In [None]:
# Load and split data

import pickle
from graph import Graph, Part
from typing import List, Set, Tuple, Dict
from sklearn.model_selection import train_test_split

# NOTE(security): pickle.load can execute arbitrary code while deserializing.
# Only run this on a graphs.dat file that comes from a trusted source.
with open("data/graphs.dat", "rb") as file:
    all_graphs: List[Graph] = pickle.load(file)

# The split needs only the in-memory list, so it runs after the file handle
# has been closed. Inputs (X) are each graph's parts; targets (y) are the
# graphs themselves.
# First split: 70% train / 30% held out. Fixed random_state for reproducibility.
X_train, X_temp, y_train, y_temp = train_test_split(
    [graph.get_parts() for graph in all_graphs],
    all_graphs,
    test_size=0.3,
    random_state=0,
)
# Second split: the held-out 30% is halved into validation and test sets.
X_val, X_test, y_val, y_test = train_test_split(
    X_temp, y_temp, test_size=0.5, random_state=0
)

In [None]:
# Train the model and collect the correctly and incorrectly predicted train/validation graphs

from model import InstanceBased

ib = InstanceBased(y_train, edge_pickle_store=True)
# NOTE(review): the value below is a plain ratio but is labelled "%"; whether it
# needs scaling by 100 depends on what evaluate_order returns - confirm in model.py.
print(f"Order Validation: {ib.evaluate_order(y_val) / len(y_val)}% accuracy\n")


def _partition_by_prediction(model, parts_list, graphs):
    """Split ``graphs`` by whether ``model`` rebuilds each one from its parts.

    :param model: object exposing ``createGraph(parts)``
    :param parts_list: per-graph inputs, aligned index-by-index with ``graphs``
    :param graphs: target graphs
    :return: tuple ``(right, wrong)`` of graphs predicted correctly / incorrectly
    """
    right, wrong = [], []
    for parts, graph in zip(parts_list, graphs):
        if graph == model.createGraph(parts):
            right.append(graph)
        else:
            wrong.append(graph)
    return right, wrong


# The printed value is the fraction of exactly reproduced graphs, i.e. an
# accuracy, so it is labelled "acc" rather than "err".
y_train_prev_right, y_train_prev_wrong = _partition_by_prediction(ib, X_train, y_train)
correct_train = len(y_train_prev_right)
print(f"train_acc: {correct_train / len(y_train)}")

y_val_prev_right, y_val_prev_wrong = _partition_by_prediction(ib, X_val, y_val)
correct_val = len(y_val_prev_right)
print(f"val_acc: {correct_val / len(y_val)}")

In [None]:
# Load model and store it for evaluation.py
# Calculate graph accuracy for training, validation, and test set

# Rebuild the model, this time with edge_pickle_load=True (presumably this reuses
# edge data written by the earlier edge_pickle_store=True run - TODO confirm in
# model.py), then serialize it to data/model.dat for evaluation.py.
ib = InstanceBased(y_train, edge_pickle_load=True)
# Plain string literal: the path has no placeholders, so no f-string is needed.
with open("data/model.dat", "wb") as file:
    pickle.dump(ib, file)

def checkAccuracy(name, graphs, model=None):
    """Print the percentage of graphs that the model rebuilds exactly from their parts.

    :param name: label printed in front of the result
    :param graphs: sized collection of graphs; each must provide ``get_parts()``
    :param model: object exposing ``createGraph(parts)``; when omitted, the
        notebook-level ``ib`` is used
    :return: None (the result is printed)
    """
    if model is None:
        model = ib  # fall back to the model built earlier in the notebook
    if len(graphs) == 0:
        # Accuracy is undefined for an empty set; report that instead of
        # failing with ZeroDivisionError.
        print(f"{name}: n/a (no graphs)")
        return
    correct_count = sum(
        1 for graph in graphs if graph == model.createGraph(graph.get_parts())
    )
    print(f"{name}: {100 * correct_count / len(graphs)}%")

# Emit a blank line, then report exact-graph accuracy for each split in turn.
print()
for set_name, set_graphs in (
    ("Training Set", y_train),
    ("Validation Set", y_val),
    ("Test Set", y_test),
):
    checkAccuracy(set_name, set_graphs)

In [None]:
# Calculating Edge Accuracy for validation and test set

from itertools import permutations
import numpy as np

def evaluate(data_set: List[Tuple[Set[Part], Graph]], model=None) -> float:
    """
    Evaluates a prediction model on a given data set.
    :param data_set: list of (input parts, target graph) pairs
    :param model: prediction model exposing ``createGraph(parts)``; when omitted,
                  the notebook-level ``ib`` is used
    :return: evaluation score (edge accuracy in percent); NaN when the data set
             contains no parts at all, since the ratio is undefined then
    """
    if model is None:
        model = ib  # fall back to the model built earlier in the notebook

    sum_correct_edges = 0
    edges_counter = 0

    for input_parts, target_graph in data_set:
        predicted_graph = model.createGraph(input_parts)

        # Denominator grows by n * n per graph (n = number of input parts);
        # presumably the number of adjacency-matrix cells compared in
        # edge_accuracy - TODO confirm against Graph.get_adjacency_matrix.
        edges_counter += len(input_parts) * len(input_parts)
        sum_correct_edges += edge_accuracy(predicted_graph, target_graph)

    if edges_counter == 0:
        # Empty data set (or only empty part sets): avoid ZeroDivisionError.
        return float("nan")
    return sum_correct_edges / edges_counter * 100


def edge_accuracy(predicted_graph: Graph, target_graph: Graph) -> int:
    """
    Counts the adjacency-matrix entries on which prediction and target agree, taking the best
    result over all part orderings generated for the predicted graph.
    :param predicted_graph: graph produced by the model
    :param target_graph: reference graph; must have the same node count and the same parts
    :return: highest number of matching adjacency-matrix entries over all orderings
    """
    assert len(predicted_graph.get_nodes()) == len(target_graph.get_nodes()), 'Mismatch in number of nodes.'
    assert predicted_graph.get_parts() == target_graph.get_parts(), 'Mismatch in expected and given parts.'

    # Every part ordering to try for the predicted graph
    candidate_orders: List[Tuple[Part]] = __generate_part_list_permutations(predicted_graph.get_parts())

    # The target graph is pinned to the first ordering; only the prediction is re-ordered
    reference_matrix = target_graph.get_adjacency_matrix(candidate_orders[0])

    # Seeded with 0 so the result equals a running maximum that starts at zero
    match_counts = [0]
    for order in candidate_orders:
        candidate_matrix = predicted_graph.get_adjacency_matrix(order)
        match_counts.append(np.sum(candidate_matrix == reference_matrix))

    return max(match_counts)


def __generate_part_list_permutations(parts: Set[Part]) -> List[Tuple[Part]]:
    """
    Builds the part orderings used for graph comparison. Parts that are equivalent to each other
    (per ``Part.equivalent``) are grouped; every group with more than one member is permuted, and
    the permutations of all such groups are combined. Parts without an equivalent partner are
    appended once, in a fixed order, at the end of every ordering.
    :param parts: Set of parts to order
    :return: List of orderings, each a tuple containing every part exactly once
    """
    # Group parts by equivalence; the first part seen of each kind is the group's key
    groups: Dict[Part, Set[Part]] = {}
    for part in parts:
        representative = next((key for key in groups if part.equivalent(key)), None)
        if representative is None:
            groups[part] = {part}
        else:
            groups[representative].add(part)

    # Separate groups that need permuting from parts that occur only once
    repeated_groups: List[Set[Part]] = []
    unique_parts: List[Part] = []
    for members in groups.values():
        if len(members) > 1:
            repeated_groups.append(members)
        elif len(members) == 1:
            unique_parts.append(next(iter(members)))

    # Cartesian product of the per-group permutations, built up one group at a time
    orderings: List[Tuple[Part]] = [()]
    for members in repeated_groups:
        group_perms = list(permutations(members))
        orderings = [prefix + suffix for prefix in orderings for suffix in group_perms]

    # Parts that occur once go at the end of every ordering
    unique_tail = tuple(unique_parts)
    orderings = [ordering + unique_tail for ordering in orderings]
    assert all([len(ordering) == len(parts) for ordering in orderings]), 'Mismatching number of elements in permutation(s).'
    return orderings

# Edge accuracy in percent: validation split first, then the test split.
for split_parts, split_graphs in ((X_val, y_val), (X_test, y_test)):
    print(evaluate(list(zip(split_parts, split_graphs))))