In [1]:
import pickle
from graph import Graph, Part
from typing import List, Callable, Set
from sklearn.model_selection import train_test_split

# Load the pickled graphs and split them into train / validation / test sets:
# 30% is held out first, then that hold-out is halved into validation and test.
# X_* holds each graph's parts (graph.get_parts()), Y_* the graph itself.
# NOTE: pickle.load can execute arbitrary code - only load trusted files.
with open('data/graphs.dat', 'rb') as file:
    all_graphs: List[Graph] = pickle.load(file)
    X_train, X_temp, Y_train, Y_temp = train_test_split(
        [graph.get_parts() for graph in all_graphs],
        all_graphs,
        test_size=0.3,
        random_state=0,
    )
    X_val, X_test, Y_val, Y_test = train_test_split(
        X_temp,
        Y_temp,
        test_size=0.5,
        random_state=0,
    )

In [2]:
from node import Node

class Ordering:
    """Orders parts by an aggregate of the edge counts seen for each part id.

    The aggregate is computed once, at construction time, from a list of
    graphs; `sort` then ranks parts by that stored value, highest first.
    """

    def __init__(self, y: List[Graph], degreeAggregate: Callable[[List[int]], int]):
        """Collect per-part-id edge counts from `y` and aggregate them.

        Args:
            y: Graphs whose `get_edges()` mapping (node -> edges) is scanned.
            degreeAggregate: Reduces the list of edge counts collected for one
                part id to a single sort key.
        """
        counts_by_part = {}
        for g in y:
            for vertex, neighbours in g.get_edges().items():
                part_id = vertex.get_part().get_part_id()
                counts_by_part.setdefault(part_id, []).append(len(neighbours))

        # part id -> aggregated sort key
        self.keys = {}
        for part_id, counts in counts_by_part.items():
            self.keys[part_id] = degreeAggregate(counts)

    def sort(self, x: Set[Part]) -> List[Part]:
        """Return the parts of `x` as a list, largest aggregate first.

        Raises:
            KeyError: if a part's id was not seen when the ordering was built.
        """
        def aggregate_of(part):
            return self.keys[part.get_part_id()]

        # Ascending sort, then flipped: ties end up in the reverse of the
        # order sorted() produced for them.
        return sorted(x, key=aggregate_of)[::-1]

In [3]:
# Build the flat record lists for the GWP_* files from the training graphs:
#   gwp_a               - edge list as (node id, node id) pairs
#   gwp_graph_indicator - graph id of every node, in node-id order
#   gwp_graph_labels    - one label per graph (constant 1)
#   gwp_node_labels     - (family id, part id) of every node, in node-id order
# Node ids are global: they keep counting across graphs.
# NOTE(review): node and graph ids start at 0 here; the TU Dortmund benchmark
# format these file names follow is normally 1-based - confirm what the
# consuming loader expects.
ordering = Ordering(Y_train, lambda n: sum(n)/len(n))  # sort key = mean edge count

current_part_id = 0
current_graph_id = 0
gwp_a = []
gwp_graph_indicator = []
gwp_graph_labels = []
gwp_node_labels = []
for x, y in zip(X_train, Y_train):
    ordered = ordering.sort(x)
    adjacency_matrix = y.get_adjacency_matrix(ordered)

    # Global id of this graph's first node. It must be captured before the
    # node loop: current_part_id advances once per node, so using it directly
    # as the offset would shift every edge after the first row.
    graph_offset = current_part_id

    for i in range(adjacency_matrix.shape[0]):
        # Upper triangle only (diagonal included): each connection is
        # recorded once.
        for j in range(i, adjacency_matrix.shape[1]):
            if adjacency_matrix[i][j] == 1:
                gwp_a.append((graph_offset + j, graph_offset + i))  # Adjacency of i and j

        gwp_graph_indicator.append(current_graph_id)
        gwp_node_labels.append((ordered[i].get_family_id(), ordered[i].get_part_id()))
        current_part_id += 1

    gwp_graph_labels.append(1)
    current_graph_id += 1

# Write the collected records to the GWP_* text files, one record per line.
# The files are opened with 'w' so that re-running this cell replaces them;
# appending would add a second copy of the data whose ids restart at 0.
with open('GRAN/transformed_data/GWP_A.txt', 'w') as f:
    # One edge per line: "<node id>,<node id>"
    f.writelines(f"{first},{second}\n" for first, second in gwp_a)
with open('GRAN/transformed_data/GWP_graph_indicator.txt', 'w') as f:
    # Line k holds the graph id of node k
    f.writelines(f"{indicator}\n" for indicator in gwp_graph_indicator)
with open('GRAN/transformed_data/GWP_graph_labels.txt', 'w') as f:
    # Line k holds the label of graph k
    f.writelines(f"{label}\n" for label in gwp_graph_labels)
with open('GRAN/transformed_data/GWP_node_labels.txt', 'w') as f:
    # Line k holds "<family id>,<part id>" of node k
    f.writelines(f"{family_id},{part_id}\n" for family_id, part_id in gwp_node_labels)