In [None]:
import os
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
import copy
import csv
import datetime
import io
import itertools
import json
import math
import pickle
import random
import statistics
import sys
from typing import List, Dict, Tuple

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from tqdm import tqdm

import torch

import torch.optim.lr_scheduler as lr_scheduler
import torch_geometric
import torch_geometric.transforms as T
import torch_geometric.transforms
import torch_geometric.datasets
import torch_geometric.nn
from torch_geometric.utils import to_networkx

from IPython.display import display, HTML
import warnings
warnings.filterwarnings('ignore', module='sklearn')

In [None]:
from env_variables import *
from utils_helpers import *

In [None]:
results_output_dir = "./results"
# exist_ok=True avoids the check-then-create race and keeps the cell idempotent on re-run;
# unlike the old os.path.exists() guard it fails loudly if "./results" exists but is a file.
os.makedirs(results_output_dir, exist_ok=True)

In [None]:
# Global compute device for this notebook (CPU only).
device = torch.device("cpu")

# Deep Learning with PyTorch
Package versions:
- torchmetrics==0.5.0
- pytorch-lightning==1.5.10

In [None]:
def apply_constraints_func(probabilities, data, threshold=0.5, debug = False, constraints_type="greedy"):
    """Turn per-edge link probabilities into 0/1 labels under a one-link-per-node constraint.

    Only edges whose type is ``edges_type_int_encodings['nuclei-golgi']`` or
    ``edges_type_int_encodings['golgi-nuclei']`` can receive label 1.

    Args:
        probabilities: indexable sequence with one score per edge of ``data.edge_index``,
            aligned by position with ``data.edge_types`` (in this notebook: the numpy array of
            sigmoid outputs built in ``eval_link_predictor``).
        data: graph object exposing ``edge_types`` and ``edge_index`` tensors.
        threshold: used by "greedy" only; an edge is accepted only if its probability is
            strictly greater than it. Ignored by "optimization".
        debug: unused.
        constraints_type: "greedy" or "optimization".

    Returns:
        list of ints (0/1) with the same length as ``probabilities``.

    Raises:
        ValueError: for an unknown ``constraints_type``.

    NOTE(review): ``edges_type_int_encodings`` and ``nx`` (networkx) are not defined in this cell;
    presumably they come from the star imports of env_variables / utils_helpers - confirm.
    """
    # Move edge metadata to numpy so it can be indexed by position like `probabilities`.
    edge_types = data.edge_types.detach().cpu().numpy()
    source_nodes = data.edge_index[0].detach().cpu().numpy()
    target_nodes = data.edge_index[1].detach().cpu().numpy()
    # Only these edge types are allowed to become positive links.
    allowed_edge_types = set([edges_type_int_encodings['nuclei-golgi'], edges_type_int_encodings['golgi-nuclei']])

    if constraints_type=="greedy":
        """
        Greedy Approach to apply constraints to the model output.
        """
        # Edge positions in descending order of probability (Python's sort is stable, so ties keep
        # their original edge order).
        sorted_indices = sorted(range(len(probabilities)), key=lambda k: probabilities[k], reverse=True)
        
        edge_types_sorted = [edge_types[i] for i in sorted_indices]
        source_nodes_sorted = [source_nodes[i] for i in sorted_indices]
        target_nodes_sorted = [target_nodes[i] for i in sorted_indices]
        
        probabilities_sorted = [probabilities[i] for i in sorted_indices]

        # Node ids that already have an accepted link
        assigned_nodes = set()

        # Output labels, indexed like the ORIGINAL (unsorted) edges
        pred_labels = [0] * len(probabilities)
        
        # Walk edges from most to least likely and accept an edge when neither endpoint is taken yet
        for i in range(len(sorted_indices)):
            src = source_nodes_sorted[i]
            tgt = target_nodes_sorted[i]
            edge_type = edge_types_sorted[i]
            prob = probabilities_sorted[i]
            
            # If the edge is nuclei-golgi or golgi-nuclei and the nodes are not already assigned and probability > threshold
            if (edge_type in allowed_edge_types) and (src not in assigned_nodes) and (tgt not in assigned_nodes):

                if prob > threshold:
                    # Map back to the original edge position before writing the label
                    pred_labels[sorted_indices[i]] = 1
                    assigned_nodes.add(src)
                    assigned_nodes.add(tgt)
    elif constraints_type=="optimization":
        """
        Apply modified version of the Jonker-Volgenant algorithm to get global labels.
        """
        # Build a NetworkX graph for the global assignment.
        G = nx.Graph()
        
        # NOTE(review): every endpoint of every edge is added here, including nodes that have no
        # nuclei-golgi / golgi-nuclei edge; such nodes end up isolated in G.
        # NOTE(review): a node that occurs in both lists keeps bipartite=1 (the second call overwrites
        # the attribute), so this attribute does not describe a real bipartition.
        G.add_nodes_from(source_nodes, bipartite=0)
        G.add_nodes_from(target_nodes, bipartite=1)

        # Cost of an edge = 1 / probability (MAX_VALUE for probability 0); the matching below
        # minimises the sum of these costs.
        MAX_VALUE = 1e9  # A very large number
        inverse_probabilities = [1 / prob if prob != 0 else MAX_VALUE for prob in probabilities]


        # Add weighted edges with inverse probabilities
        for i in range(len(source_nodes)):
            src = source_nodes[i]
            tgt = target_nodes[i]
            edge_type = edge_types[i]
            weight = inverse_probabilities[i]

            # Add edge only if it's nuclei-golgi or golgi-nuclei.
            # nx.Graph is undirected: if both (a, b) and (b, a) occur, the weight added last is kept.
            if edge_type in allowed_edge_types:
                G.add_edge(src, tgt, weight=weight)

        # Compute minimum weight full matching.
        # Called without `top_nodes`, so networkx has to infer the two sides itself.
        # NOTE(review): per the networkx docs this may raise (AmbiguousSolution when G is disconnected -
        # any isolated node above makes it so - or ValueError when no full matching exists); confirm
        # on real graphs.
        matching = nx.bipartite.minimum_weight_full_matching(G, weight='weight')

        # node -> matched partner, stored in both directions
        preds_dict = {}

        # Assign labels based on matching
        for src, tgt in matching.items():
            preds_dict[src] = tgt
            preds_dict[tgt] = src

        # Initialize predicted labels
        pred_labels = [0] * len(probabilities)
        
        # An edge is positive when its endpoints were matched to each other (either direction).
        # Note: `threshold` is not applied in this branch.
        for i in range(len(source_nodes)):
            src = source_nodes[i]
            tgt = target_nodes[i]
            
            if preds_dict.get(src) == tgt or preds_dict.get(tgt) == src:
                pred_labels[i] = 1

    else:
        raise ValueError("Wrong constraints type!")

    return pred_labels

In [None]:
"""
MPNN Classifier built from scratch
"""

from torch_scatter import scatter
from typing import List, Any, Iterable

def dims_from_multipliers(output_dim: int, multipliers: Iterable[int]) -> Tuple[int, ...]:
    """Scale ``output_dim`` by every multiplier and truncate to int.

    Example: ``dims_from_multipliers(16, (4, 1)) == (64, 16)``.
    """
    scaled_dims = [int(output_dim * factor) for factor in multipliers]
    return tuple(scaled_dims)

from torch.nn import LeakyReLU

class MLP(torch.nn.Module):
    """Feed-forward stack: ``Linear -> nonlinearity [-> BatchNorm1d] [-> Dropout]`` per entry of ``fc_dims``.

    Args:
        input_dim: width of the input features.
        fc_dims: list or tuple of output widths, one per Linear layer.
        nonlinearity: activation module; the same instance is inserted after every Linear layer.
        dropout_p: dropout probability; no Dropout layer is added when it is 0.
        use_batchnorm: insert BatchNorm1d after the activation.
        last_output_free: leave the final Linear layer bare (no activation, normalisation or
            dropout), e.g. so it can emit raw logits.

    BatchNorm1d and Dropout are never inserted after a layer of width 1.
    """

    def __init__(self, input_dim: int, fc_dims: Iterable[int], nonlinearity: torch.nn.Module,
                 dropout_p: float = 0, use_batchnorm: bool = False, last_output_free: bool = False):
        super().__init__()
        assert isinstance(fc_dims, (list, tuple)), f"fc_dims must be a list or a tuple, but got {type(fc_dims)}"

        # Constructor arguments kept for introspection (`output_dim`, other models' dim checks).
        self.input_dim = input_dim
        self.fc_dims = fc_dims
        self.nonlinearity = nonlinearity
        self.dropout_p = dropout_p
        self.use_batchnorm = use_batchnorm

        final_index = len(fc_dims) - 1
        in_features = input_dim
        stack: List[torch.nn.Module] = []
        for index, out_features in enumerate(fc_dims):
            stack.append(torch.nn.Linear(in_features, out_features))
            bare_output = last_output_free and index == final_index
            if not bare_output:
                stack.append(nonlinearity)
                # Normalisation / dropout are skipped on scalar (width-1) outputs.
                if out_features != 1:
                    if use_batchnorm:
                        stack.append(torch.nn.BatchNorm1d(out_features))
                    if dropout_p > 0:
                        stack.append(torch.nn.Dropout(p=dropout_p))
            in_features = out_features
        self.fc_layers = torch.nn.Sequential(*stack)

    def forward(self, x):
        return self.fc_layers(x)

    @property
    def output_dim(self) -> int:
        """Width of the last layer."""
        return self.fc_dims[-1]
    
class BasicEdgeModel(torch.nn.Module):
    """Edge update of neural message passing.

    The new embedding of every edge is ``edge_mlp([x_src | x_tgt | edge_attr])``.
    """

    def __init__(self, edge_mlp):
        super(BasicEdgeModel, self).__init__()
        self.edge_mlp = edge_mlp

    def forward(self, x, edge_index, edge_attr):
        src_idx, tgt_idx = edge_index
        # One row per edge: features of both endpoints followed by the edge's own features.
        merged_features = torch.cat((x[src_idx], x[tgt_idx], edge_attr), dim=1)
        assert len(merged_features) == len(src_idx), f"Merged input has wrong length {merged_features.shape} != {edge_attr.shape}"
        return self.edge_mlp(merged_features)
    
class UniformAggNodeModel(torch.nn.Module):
    """Node update of neural message passing.

    Every edge produces two messages, one for each endpoint, each shaped
    ``[x_receiver | x_sender | edge_attr]``. ``flow_model`` transforms all messages, they are
    reduced per receiving node with ``node_agg_mode`` (e.g. 'max') and ``node_mlp`` turns the
    result into the new node embedding.
    """

    def __init__(self, flow_model, node_mlp, node_agg_mode: str):
        super().__init__()
        assert (flow_model.output_dim == node_mlp.input_dim), f"Flow models have incompatible output/input dims"
        self.flow_model = flow_model
        self.node_mlp = node_mlp
        self.node_agg_mode = node_agg_mode

    def forward(self, x, edge_index, edge_attr, **kwargs):
        senders, receivers = edge_index

        # Messages delivered to the edge's end node: [receiver, sender, edge]
        to_receivers = torch.hstack((x[receivers], x[senders], edge_attr))
        assert len(to_receivers) == len(edge_attr)
        # Messages delivered to the edge's start node (same layout, roles swapped)
        to_senders = torch.hstack((x[senders], x[receivers], edge_attr))
        assert to_receivers.shape == to_senders.shape, f"{to_receivers.shape} != {to_senders.shape}"

        # [2 * n_edges, message_dim] -> transformed messages
        messages = self.flow_model(torch.vstack((to_receivers, to_senders)))

        # Row k of `messages` belongs to node `owners[k]`; reduce per node.
        owners = torch.cat((receivers, senders))
        pooled = scatter(src=messages, index=owners,
                         reduce=self.node_agg_mode, dim=0, dim_size=len(x))
        return self.node_mlp(pooled)
    
class InitialUniformAggNodeModel(torch.nn.Module):
    """Initial node update for nodes that start without features.

    The embedding of each node is ``node_mlp`` applied to the ``node_agg_mode`` reduction of
    the embeddings of all edges incident to it (an edge counts for both endpoints). The
    result therefore reflects the graph structure around the node.

    Since nodes carry no features, an MLP before the aggregation would only transform edge
    features. That work is left to the initial edge model, so ``node_mlp`` runs only after
    aggregation. ``node_attr`` is accepted for interface compatibility but not used.
    """

    def __init__(self, node_mlp, node_agg_mode: str):
        super().__init__()
        self.node_mlp = node_mlp
        self.node_agg_mode = node_agg_mode

    def forward(self, node_attr, edge_index, edge_attr, num_nodes: int, **kwargs):
        senders, receivers = edge_index
        # First copy of the edge embeddings goes to the end nodes, second copy to the start nodes.
        incident_edges = torch.vstack((edge_attr, edge_attr))
        owners = torch.cat((receivers, senders))
        pooled = scatter(src=incident_edges, index=owners,
                         reduce=self.node_agg_mode, dim=0, dim_size=num_nodes)
        return self.node_mlp(pooled)
    
class InitialUniformAggNodeModelWithNodeFeats(torch.nn.Module):
    """Initial node update that combines raw node features with the surrounding edges.

    Edge embeddings incident to each node (an edge counts for both endpoints) are reduced
    with ``node_agg_mode``. The pooled vector is concatenated after the node's raw features
    (``node_attr``), and ``node_mlp`` maps ``[node_attr | pooled]`` to the initial node
    embedding. The aggregation is the same as in ``InitialUniformAggNodeModel``; unlike
    that class, node features ARE used here.
    """

    def __init__(self, node_mlp, node_agg_mode: str):
        super().__init__()
        self.node_mlp = node_mlp
        self.node_agg_mode = node_agg_mode

    def forward(self, node_attr, edge_index, edge_attr, num_nodes: int, **kwargs):
        senders, receivers = edge_index
        pooled = scatter(src=torch.vstack((edge_attr, edge_attr)),
                         index=torch.cat((receivers, senders)),
                         reduce=self.node_agg_mode, dim=0, dim_size=num_nodes)
        node_input = torch.cat((node_attr, pooled), dim=1)
        return self.node_mlp(node_input)


class MessagePassingNetworkRecurrent(torch.nn.Module):
    """Message passing that reuses the same edge and node models at every step."""

    def __init__(self, edge_model: torch.nn.Module, node_model: torch.nn.Module, steps: int):
        """
        Args:
            edge_model: edge update model, applied at every step
            node_model: node update model, applied at every step except the last
            steps: number of message-passing iterations
        """
        super().__init__()
        self.edge_model = edge_model
        self.node_model = node_model
        self.steps = steps

    def forward(self, x, edge_index, edge_attr, num_nodes: int):
        """
        Runs ``self.steps`` rounds of message passing.
        Args:
            x: node features matrix
            edge_index: tensor with shape [2, M], with M being the number of edges
            edge_attr: edge features matrix (ordered by edge_index); every edge update reads these
                same input features, only the node features evolve between steps
            num_nodes: unused
        Returns: the node features after the last node update and the edge features of the
            final step (the last step skips the node update - only edges are classified)
        """
        final_step = self.steps - 1
        for step in range(self.steps):
            edge_attr_mpn = self.edge_model(x, edge_index, edge_attr)
            if step < final_step:
                x = self.node_model(x, edge_index, edge_attr_mpn)
        return x, edge_attr_mpn

class MessagePassingNetworkNonRecurrent(torch.nn.Module):
    """Message passing with a separate edge model per step and a separate node model per step (except the last)."""

    def __init__(self, edge_models: List[torch.nn.Module], node_models: List[torch.nn.Module], steps: int):
        """
        Args:
            edge_models: a list/tuple of `steps` callable Edge Update models
            node_models: a list/tuple of `steps - 1` callable Node Update models
            steps: number of message-passing iterations
        """
        super().__init__()
        assert len(edge_models) == steps, f"steps={steps} not equal edge models {len(edge_models)}"
        assert len(node_models) == steps - 1, f"steps={steps} -1 not equal node models {len(node_models)}"
        self.edge_models = torch.nn.ModuleList(edge_models)
        self.node_models = torch.nn.ModuleList(node_models)
        self.steps = steps

    def forward(self, x, edge_index, edge_attr, num_nodes: int):
        """
        Runs ``self.steps`` rounds of message passing; step ``i`` uses ``edge_models[i]`` and,
        except in the last step, ``node_models[i]``.
        Args:
            x: node features matrix
            edge_index: tensor with shape [2, M], with M being the number of edges, indicating nonzero entries in the
            graph adjacency (i.e. edges)
            edge_attr: edge features matrix (ordered by edge_index); every edge update reads these same input features
            num_nodes: unused
        Returns: node features after the last node update and the edge features of the final step
        """
        # The node models are indexed instead of padding the ModuleList. The previous
        # `self.node_models.append(None)` mutated the module on every call, growing it without bound.
        edge_embeddings = []
        last_step = self.steps - 1
        for step, edge_model in enumerate(self.edge_models):
            # Edge Update
            edge_attr_mpn = edge_model(x, edge_index, edge_attr)
            edge_embeddings.append(edge_attr_mpn)

            if step == last_step:
                continue  # do not process nodes in the last step - only edge features are used for classification
            # Node Update
            x = self.node_models[step](x, edge_index, edge_attr_mpn)
        assert len(edge_embeddings) == self.steps, f"Collected {len(edge_embeddings)} edge embeddings for {self.steps} steps"
        return x, edge_embeddings[-1]

class GraphClassifierMPNN(torch.nn.Module):
    """Edge classifier built on neural message passing.

    ``forward`` embeds the raw edge features with ``initial_edge_model``, builds initial node
    embeddings by aggregating those edge embeddings (optionally together with the raw node
    features), refines them with ``mpn_model`` and maps every final edge embedding to a single
    logit with ``edge_classifier``.
    """
    def __init__(self, dataset_num_node_features, dataset_num_edge_features, is_recurrent=True, use_node_feats= False):
        """ Top level model class holding all components necessary to perform classification on a graph

        :param dataset_num_node_features: number of raw node features; only used when ``use_node_feats`` is True
        :param dataset_num_edge_features: number of raw edge features (input width of ``initial_edge_model``)
        :param is_recurrent: True -> one edge model and one node model shared by all MPN steps;
            False -> separate models per step
        :param use_node_feats: True -> initial node embeddings come from the raw node features concatenated
            with the aggregated edge embeddings; False -> from the aggregated edge embeddings only
        """
        super().__init__()

        # Fixed hyper-parameters. Layer widths are multipliers applied to edge_dim / node_dim via
        # dims_from_multipliers; every nonlinearity is LeakyReLU(negative_slope=0.2, inplace=True).
        params_gnn = {'initial_edge_model_input_dim': dataset_num_edge_features,
        'edge_dim': 16,
        'fc_dims_initial_edge_model_multipliers': (1, 1),
        'nonlinearity_initial_edge': LeakyReLU(negative_slope=0.2, inplace=True),
        'fc_dims_initial_node_model_multipliers': (2, 4, 1),
        'nonlinearity_initial_node': LeakyReLU(negative_slope=0.2, inplace=True),
        'directed_flow_agg': 'max',
        'fc_dims_edge_model_multipliers': (4, 1),
        'nonlinearity_edge': LeakyReLU(negative_slope=0.2, inplace=True),
        'fc_dims_directed_flow_model_multipliers': (2, 1),
        'nonlinearity_directed_flow': LeakyReLU(negative_slope=0.2, inplace=True),
        'fc_dims_total_flow_model_multipliers': (4, 2, 1),
        'nonlinearity_total_flow': LeakyReLU(negative_slope=0.2, inplace=True),
        'fc_dims_edge_classification_model_multipliers': (4, 2, 1),
        'nonlinearity_edge_classification': LeakyReLU(negative_slope=0.2, inplace=True),
        'use_batchnorm': False,
        'mpn_steps': 4,
        'is_recurrent': is_recurrent,
        'use_node_feats':use_node_feats,
        'node_dim_multiplier': 2}
        
        edge_dim = params_gnn["edge_dim"]
        node_dim = edge_dim * params_gnn["node_dim_multiplier"]  # Have nodes hold 2x info of edges
        use_batchnorm = params_gnn["use_batchnorm"]

        # Initial edge model: raw edge features -> edge_dim
        fc_dims_initial_edge = dims_from_multipliers(edge_dim, params_gnn["fc_dims_initial_edge_model_multipliers"])
        self.initial_edge_model = MLP(params_gnn["initial_edge_model_input_dim"], fc_dims_initial_edge,
                            params_gnn["nonlinearity_initial_edge"], use_batchnorm=use_batchnorm)
        
        # Initial node model: edge embeddings reduced per node with 'directed_flow_agg', then an MLP -> node_dim
        initial_node_agg_mode = params_gnn["directed_flow_agg"]
        fc_dims_initial_node = dims_from_multipliers(node_dim, params_gnn["fc_dims_initial_node_model_multipliers"])
        
        if(not use_node_feats):#Do not use features to compute initial node embeddings
            self.initial_node_model = InitialUniformAggNodeModel(MLP(edge_dim, fc_dims_initial_node,
                                                                    params_gnn["nonlinearity_initial_node"], use_batchnorm=use_batchnorm),
                                                                initial_node_agg_mode)
        else:
            # The MLP input is [raw node features | aggregated edge embeddings]
            self.initial_node_model = InitialUniformAggNodeModelWithNodeFeats(MLP(dataset_num_node_features+edge_dim, fc_dims_initial_node,
                                                                    params_gnn["nonlinearity_initial_node"], use_batchnorm=use_batchnorm),
                                                                initial_node_agg_mode)
        # Edge classification model: widths (64, 32, 16, 1) for edge_dim=16; last_output_free keeps the final
        # layer linear, so the classifier outputs one raw logit per edge
        fc_dims_edge_classification_model = dims_from_multipliers(edge_dim, params_gnn["fc_dims_edge_classification_model_multipliers"]) + (1,)
        self.edge_classifier = MLP(edge_dim, fc_dims_edge_classification_model,
                            params_gnn["nonlinearity_edge_classification"], last_output_free=True)
        
        # Define models in MPN
        edge_models, node_models = [], []
        steps = params_gnn["mpn_steps"]
        assert steps > 1, "Fewer than 2 MPN steps does not make sense as in that case nodes do not get a chance to update"
        is_recurrent = params_gnn["is_recurrent"]
        for step in range(steps):
            # Edge model: input is [x_src | x_tgt | edge_attr] (see BasicEdgeModel)
            edge_model_input = node_dim * 2 + edge_dim  # edge_dim * 5
            fc_dims_edge = dims_from_multipliers(edge_dim, params_gnn["fc_dims_edge_model_multipliers"])
            edge_models.append(BasicEdgeModel(MLP(edge_model_input, fc_dims_edge, params_gnn["nonlinearity_edge"], use_batchnorm=use_batchnorm)))

            if step == steps - 1: # don't need a node update at the last step
                continue

            # Node model: per-message MLP on [x_receiver | x_sender | edge_attr], reduced per node,
            # then a second MLP (see UniformAggNodeModel)
            flow_model_input = node_dim * 2 + edge_dim  # two nodes and their edge
            fc_dims_directed_flow = dims_from_multipliers(node_dim, params_gnn["fc_dims_directed_flow_model_multipliers"])
            fc_dims_aggregated_flow = dims_from_multipliers(node_dim, params_gnn["fc_dims_total_flow_model_multipliers"])
            
            node_agg_mode = params_gnn["directed_flow_agg"]

            individual_flow_model = MLP(flow_model_input, fc_dims_directed_flow,params_gnn["nonlinearity_directed_flow"], use_batchnorm=use_batchnorm)
            aggregated_flow_model = MLP(node_dim, fc_dims_aggregated_flow,params_gnn["nonlinearity_total_flow"], use_batchnorm=use_batchnorm)
            node_models.append(UniformAggNodeModel(individual_flow_model,aggregated_flow_model, node_agg_mode))

            if is_recurrent:  # only one model to use at each step
                break

        if is_recurrent:
            assert len(edge_models) == len(node_models) == 1
            self.mpn_model = MessagePassingNetworkRecurrent(edge_models[0], node_models[0], steps)
        else:
            self.mpn_model = MessagePassingNetworkNonRecurrent(edge_models, node_models, steps)

        # NOTE(review): hard-coded. It is only forwarded as `device=` to initial_node_model, whose
        # forward swallows it in **kwargs without using it.
        self.device = torch.device('cpu')
    
    def forward(self, data):
        """Return per-edge logits of shape [num_edges, 1].

        ``data`` must expose ``x``, ``edge_index``, ``edge_attr`` and ``num_nodes`` (PyG-style).
        """
        node_attr, edge_index, edge_attr, num_nodes = data.x, data.edge_index.long(), data.edge_attr, data.num_nodes

        # Initial Edge embeddings with Null node embeddings
        edge_attr = self.initial_edge_model(edge_attr)
        
        # Initial Node embeddings (node_attr is only used when use_node_feats=True)
        x = self.initial_node_model(node_attr, edge_index, edge_attr, num_nodes, device=self.device)
        assert len(x) == num_nodes
        
        x, final_edge_embeddings = self.mpn_model(x, edge_index, edge_attr, num_nodes)
        return self.edge_classifier(final_edge_embeddings)
    
    def forward_graph(self, graph, criterion = None):
        """Run the model on ``graph.pyg_graph``.

        Returns ``(out, loss, true)``: flattened per-edge logits, ``criterion(out, true)`` (None
        when no criterion is given) and the labels ``graph.pyg_graph.edge_label``.
        """
        out = self.forward(graph.pyg_graph).view(-1)
        loss = None
        true = graph.pyg_graph.edge_label
        if(criterion):
            loss = criterion(out, true)
        return out, loss, true

In [None]:
"""
MLP Classifier built from scratch
"""

class GraphClassifierMLP(torch.nn.Module):
    """Plain MLP edge classifier operating on pre-computed per-edge feature vectors.

    ``dimensions`` lists the layer widths from input to output. A ReLU follows every Linear
    layer except the last, so the network emits raw logits. At least two widths are required;
    fewer raise IndexError.
    """

    def __init__(self, dimensions):
        super().__init__()
        layers = []
        for in_features, out_features in zip(dimensions[:-1], dimensions[1:]):
            layers += [torch.nn.Linear(in_features, out_features), torch.nn.ReLU()]
        # The output layer emits logits: drop its trailing ReLU
        layers.pop()
        self.mlp = torch.nn.Sequential(*layers)

    def forward(self, x):
        return self.mlp(x)

    def forward_graph(self, graph, criterion = None):
        """Classify ``graph.edge_x`` (numpy) and return ``(logits, loss or None, labels)``."""
        features = torch.from_numpy(graph.edge_x).to(torch.float)
        logits = self.forward(features).view(-1)
        labels = torch.from_numpy(graph.edge_y)
        loss = criterion(logits, labels) if criterion else None
        return logits, loss, labels

In [None]:
class ModelWrapper:
    """Bundle a model with its hyper-parameters and (de)serialise both with pickle.

    Security: ``load`` un-pickles the file, which can execute arbitrary code. Only load
    files you created yourself or otherwise trust.

    Example::

        wrapper = ModelWrapper(model, {"lr": 1e-3})
        wrapper.save("model.pkl")
        restored = ModelWrapper.load("model.pkl")
        restored.model, restored.params
    """

    def __init__(self, model, params):
        self.model = model
        self.params = params

    def save(self, filename):
        """Pickle ``{'model': ..., 'params': ...}`` to ``filename``."""
        payload = {'model': self.model, 'params': self.params}
        with open(filename, 'wb') as handle:
            pickle.dump(payload, handle)
        print("Model saved successfully.")

    @classmethod
    def load(cls, filename):
        """Rebuild a wrapper from a file written by ``save`` (trusted files only)."""
        with open(filename, 'rb') as handle:
            payload = pickle.load(handle)
        return cls(payload['model'], payload['params'])


In [None]:
##################################################################################################
## Functions to train and evaluate the neural network
#################################################################################################
import sklearn.metrics
from sklearn.metrics import roc_auc_score
from torch_geometric.utils import negative_sampling

def train_link_predictor(
    model, train_data, val_data, optimizer, criterion, n_epochs=100, debug = False,
    early_stopper = None, scheduler = None, apply_constraints = True
):
    """Train ``model`` graph by graph (one optimiser step per graph) for ``n_epochs`` epochs.

    Args:
        model: exposes ``forward_graph(graph, criterion=...)`` returning ``(out, loss, true)``.
        train_data: list of training graphs. The caller's list is NOT modified; a private copy
            is shuffled every epoch.
        val_data: graphs evaluated with ``eval_link_predictor``; ``train_data`` is passed to it to
            pick the decision threshold.
        optimizer: torch optimiser over the model parameters.
        criterion: loss used for training and for the validation loss.
        n_epochs: number of epochs.
        debug: print the mean training loss and aggregated validation metrics every 10 epochs.
        early_stopper: optional object with ``early_stop(val_loss) -> bool``; when given, the
            validation set is evaluated every epoch and training stops once it returns True.
        scheduler: optional LR scheduler, stepped once per completed epoch.
        apply_constraints: forwarded to ``eval_link_predictor``.

    Returns:
        The trained model (the same object that was passed in).
    """
    # Shuffle a private copy so the caller's list order is left untouched. Shuffling the same
    # copy every epoch reproduces the old in-place shuffle sequence.
    epoch_graphs = list(train_data)
    for epoch in range(1, n_epochs + 1):
        model.train()
        random.shuffle(epoch_graphs)
        epoch_losses = []
        for graph in epoch_graphs:
            optimizer.zero_grad()
            out, loss, true = model.forward_graph(graph, criterion = criterion)
            loss.backward()
            optimizer.step()
            epoch_losses.append(loss.item())

        report = debug and epoch % 10 == 0
        if report or early_stopper is not None:
            # Validation metrics are needed for reporting and/or early stopping.
            metrics = eval_link_predictor(model, train_data, val_data, criterion = criterion,
                                          apply_constraints = apply_constraints)
            aggregated = metrics["aggregated_metrics"]
            if report:
                train_loss = sum(epoch_losses) / len(epoch_losses) if epoch_losses else float("nan")
                print(f"Epoch: {epoch:03d}, Train Loss: {train_loss:.3f}, Metrics:", aggregated)
            if early_stopper is not None and early_stopper.early_stop(aggregated["loss"]):
                break

        if scheduler:
            scheduler.step()

    return model

@torch.no_grad()
def aggregate_metrics_all(metrics_list, loss_criterion=None):
    """Combine the per-graph metric dicts produced by ``eval_link_predictor``.

    Args:
        metrics_list: non-empty list of per-graph metric dicts.
        loss_criterion: when truthy, the per-graph ``"loss"`` tensors are averaged into ``"loss"``.

    Returns:
        dict with the mean ``"rouc_auc_score"``, the FIRST graph's ``"rouc_auc_curve"`` (not
        averaged), optionally ``"loss"``, and for each of "@best", "@0.5", "@constraints",
        "@constraints_opt" that is present in the per-graph dicts, ``aggregate_metrics`` over their
        ``"metrics"`` entries (plus a nested "@constraints" entry when the per-graph dicts have one).
        The "@constraints*" keys only exist when evaluation ran with ``apply_constraints=True``;
        absent keys are skipped instead of raising KeyError.

    Raises:
        ValueError: if ``metrics_list`` is empty.
    """
    if not metrics_list:
        raise ValueError("aggregate_metrics_all() needs at least one per-graph metrics dict")

    first = metrics_list[0]
    aggregated_metrics = {}

    aggregated_metrics["rouc_auc_score"] = statistics.mean([metric["rouc_auc_score"] for metric in metrics_list])
    aggregated_metrics["rouc_auc_curve"] = first["rouc_auc_curve"]

    # Aggregate loss if provided
    if loss_criterion:
        aggregated_metrics["loss"] = torch.mean(torch.stack([metric["loss"] for metric in metrics_list]), dim=0)

    # Aggregate thresholded / constrained metrics that are present
    for key in ("@best", "@0.5", "@constraints", "@constraints_opt"):
        if key not in first:
            continue
        aggregated_metrics[key] = {
            "metrics": aggregate_metrics([metric[key]["metrics"] for metric in metrics_list]),
        }
        if "@constraints" in first[key]:
            aggregated_metrics[key]["@constraints"] = {
                "metrics": aggregate_metrics([metric[key]["@constraints"]["metrics"] for metric in metrics_list]),
            }

    return aggregated_metrics

@torch.no_grad()
def _mean_train_optimal_threshold(model, train_data, default=0.5):
    """Mean of the per-graph Youden-J optimal thresholds over ``train_data`` (``default`` if none qualifies).

    Graphs whose labels contain a single class are skipped. Their ROC curve is undefined
    (NaN TPR or FPR), so ``np.argmax`` would return index 0. That index's threshold lies
    above every score (``inf`` in recent scikit-learn), and one such graph would drag the
    mean to ``inf``. This would silently turn every "@best" prediction into 0.
    """
    thresholds = []
    for graph in train_data:
        out, _, true = model.forward_graph(graph, criterion = None)
        if len(np.unique(true)) < 2:
            continue
        pred = out.sigmoid().cpu().numpy()
        fpr, tpr, roc_thresholds = sklearn.metrics.roc_curve(true, pred)
        sensitivity = tpr
        specificity = 1 - fpr
        optimal_idx = np.argmax(sensitivity + specificity - 1)
        thresholds.append(roc_thresholds[optimal_idx])
    return statistics.mean(thresholds) if thresholds else default


@torch.no_grad()
def eval_link_predictor(model, train_data, test_data, criterion = None, plot_roc_curve=False, debug = False, 
                                 apply_constraints=True):
    """Evaluate an edge classifier on ``test_data``.

    The threshold for the "@best" metrics is the mean Youden-J optimal threshold measured on
    ``train_data``, so it is never tuned on the graphs being evaluated. It falls back to 0.5
    when no training graph has both classes.

    Args:
        model: exposes ``forward_graph(graph, criterion=...)`` returning ``(logits, loss, true)``.
        train_data: graphs used only to choose the decision threshold.
        test_data: graphs to evaluate; each one also gets its metrics attached as ``.metrics``.
        criterion: optional loss; when given, per-graph ``"loss"`` and the aggregated loss are reported.
        plot_roc_curve: unused, kept for API compatibility.
        debug: unused, kept for API compatibility.
        apply_constraints: also compute constrained predictions (greedy at 0.5, at the best
            threshold and at 0, plus the global "optimization" matching).

    Returns:
        dict with "individual_metrics" (graph_id -> metrics of that graph) and "aggregated_metrics".
    """
    model.eval()

    metrics_dict = {}
    metrics_dict["individual_metrics"] = {}  # graph_id -> metrics of that graph
    metrics_dict["aggregated_metrics"] = {}

    optimal_threshold = _mean_train_optimal_threshold(model, train_data)

    for graph in test_data:
        graph_id = graph.graph_id
        tp_total_count = len(graph.edge_list)

        metrics = {}
        out, loss, true = model.forward_graph(graph, criterion = criterion)
        if criterion is not None:
            metrics["loss"] = loss

        pred = out.sigmoid().cpu().numpy()

        # ROC AUC is undefined for constant predictions or single-class labels; report 0 then.
        if len(np.unique(pred)) == 1 or len(np.unique(true)) == 1:
            rouc_auc_score = 0
        else:
            rouc_auc_score = round(roc_auc_score(true, pred), 3)
        metrics["rouc_auc_score"] = rouc_auc_score

        fpr, tpr, thresholds = sklearn.metrics.roc_curve(true, pred)
        metrics["rouc_auc_curve"] = {"fpr": fpr.tolist(), "tpr": tpr.tolist(), "thresholds": thresholds.tolist()}

        # Fixed 0.5 threshold
        pred_labels_05 = (pred > 0.5).astype(int)
        metrics["@0.5"] = {}
        metrics["@0.5"]["pred_labels"] = pred_labels_05  # kept to plot the predicted graph
        metrics["@0.5"]["metrics"] = eval_metrics(true, pred_labels_05, tp_total_count)

        # Threshold chosen on the training graphs
        pred_labels_best = (pred > optimal_threshold).astype(int)
        metrics["@best"] = {}
        metrics["@best"]["metrics"] = eval_metrics(true, pred_labels_best, tp_total_count)
        metrics["@best"]["pred_labels"] = pred_labels_best
        metrics["@best"]["optimal_threshold"] = optimal_threshold
        metrics["figures"] = {}

        if apply_constraints:
            pyg_graph = graph.pyg_graph

            # Greedy one-link-per-node constraint, links accepted above 0.5
            pred_labels_constraints_05 = apply_constraints_func(pred, pyg_graph, threshold=0.5, constraints_type="greedy")
            metrics["@0.5"]["@constraints"] = {}
            metrics["@0.5"]["@constraints"]["pred_labels"] = pred_labels_constraints_05
            metrics["@0.5"]["@constraints"]["metrics"] = eval_metrics(true, pred_labels_constraints_05, tp_total_count)

            # Greedy constraint, links accepted above the training-derived threshold
            pred_labels_constraints_best = apply_constraints_func(pred, pyg_graph, threshold=optimal_threshold, constraints_type="greedy")
            metrics["@best"]["@constraints"] = {}
            metrics["@best"]["@constraints"]["metrics"] = eval_metrics(true, pred_labels_constraints_best, tp_total_count)
            metrics["@best"]["@constraints"]["pred_labels"] = pred_labels_constraints_best

            # Greedy constraint, any probability above 0
            pred_labels_constraints = apply_constraints_func(pred, pyg_graph, threshold=0, constraints_type="greedy")
            metrics["@constraints"] = {}
            metrics["@constraints"]["metrics"] = eval_metrics(true, pred_labels_constraints, tp_total_count)
            metrics["@constraints"]["pred_labels"] = pred_labels_constraints

            # Global assignment via minimum-weight full matching ("optimization", Jonker-Volgenant style)
            pred_labels_constraints_opt = apply_constraints_func(pred, pyg_graph, constraints_type="optimization")
            metrics["@constraints_opt"] = {}
            metrics["@constraints_opt"]["metrics"] = eval_metrics(true, pred_labels_constraints_opt, tp_total_count)
            metrics["@constraints_opt"]["pred_labels"] = pred_labels_constraints_opt

        metrics["pred_edge_probabilities"] = pred
        metrics_dict["individual_metrics"][graph_id] = metrics
        graph.metrics = metrics  # also attach the metrics to the graph object itself

    metrics_dict["aggregated_metrics"] = aggregate_metrics_all(list(metrics_dict["individual_metrics"].values()),
                                                               loss_criterion = criterion)

    return metrics_dict

In [None]:
def build_model(model_type, dataset_num_node_features, dataset_num_edge_features, dataset_num_total_features,
                dataset_num_classes):
    """Create a fresh, untrained classifier for ``model_type``.

    The four "GNN_Classifier_*" names select a ``GraphClassifierMPNN`` with the
    matching ``is_recurrent`` / ``use_node_feats`` flags. "MLP" builds a
    ``GraphClassifierMLP`` with layer sizes
    ``(dataset_num_total_features, 100, 100, dataset_num_classes)``.

    Raises:
        ValueError: for any other ``model_type``.
    """
    # (name, is_recurrent, use_node_feats) for every message-passing variant.
    mpnn_variants = (
        ("GNN_Classifier_Recurrent_WithoutNodeFeats", True, False),
        ("GNN_Classifier_NonRecurrent_WithoutNodeFeats", False, False),
        ("GNN_Classifier_Recurrent_WithNodeFeats", True, True),
        ("GNN_Classifier_NonRecurrent_WithNodeFeats", False, True),
    )
    for variant_name, recurrent, with_node_feats in mpnn_variants:
        if model_type == variant_name:
            return GraphClassifierMPNN(dataset_num_node_features, dataset_num_edge_features,
                                       is_recurrent=recurrent, use_node_feats=with_node_feats)

    if model_type == "MLP":
        layer_sizes = (dataset_num_total_features, 100, 100, dataset_num_classes)
        return GraphClassifierMLP(layer_sizes)

    raise ValueError("Wrong type of classifier!")

def schedule_training_GNN(job_parameters, model, graph_list_train, graph_list_val, debug=False):
    """Set up the loss and optimizer described by ``job_parameters`` and train ``model``.

    Args:
        job_parameters: job dict; reads "knn_intra_nodes", "lr", "n_epochs",
            "criterion", "pos_weight", "early_stopper" and "scheduler".
        model: untrained model (e.g. from ``build_model``).
        graph_list_train: training graphs; each must expose ``k_inter`` when
            ``job_parameters["pos_weight"]`` is truthy.
        graph_list_val: graphs handed to ``train_link_predictor`` as validation set.
        debug: forwarded to ``train_link_predictor``.

    Returns:
        The model returned by ``train_link_predictor``.

    Raises:
        KeyError: if ``job_parameters["criterion"]`` is not a supported loss name.
    """
    # Local import: the notebook's top-level import cell does not import ``statistics``.
    import statistics

    k_intra = job_parameters["knn_intra_nodes"]

    criterion_switch = {"BCEWithLogitsLoss": torch.nn.BCEWithLogitsLoss}
    criterion_function = criterion_switch[job_parameters["criterion"]]

    if job_parameters["pos_weight"]:
        # Weight positive edges by (mean k_inter + 2 * k_intra - 1), never below 1.
        # NOTE(review): presumably approximates the negative:positive edge ratio,
        # assuming about one true edge per node among its candidates - confirm.
        k_inter_mean = statistics.mean([g.k_inter for g in graph_list_train])
        # Always a float (also when clamped), so the tensor keeps a floating dtype.
        weight = max(float(k_inter_mean + 2 * k_intra - 1), 1.0)
        criterion = criterion_function(pos_weight=torch.tensor([weight]))
    else:
        criterion = criterion_function()

    optimizer = torch.optim.Adam(params=model.parameters(), lr=job_parameters["lr"])

    return train_link_predictor(model, graph_list_train, graph_list_val, optimizer, criterion,
                                n_epochs=job_parameters["n_epochs"], debug=debug,
                                early_stopper=job_parameters["early_stopper"],
                                scheduler=job_parameters["scheduler"])

In [None]:
def get_cv_groups(data_type):
    """Return the cross-validation split to use for ``data_type``.

    "Real" and "Real_automatic" share a fixed three-fold split by crop file name.
    Every other data type gets the sentinel string "even"; ``train_dl`` replaces it
    with a split of the graph ids computed by ``distribute_elements_to_lists``.
    """
    if data_type not in ("Real", "Real_automatic"):
        return "even"
    return [
        ["Crop1.csv", "Crop2.csv", "Crop3.csv", "Crop4.csv"],
        ["Crop5_BC.csv", "Crop6_BC.csv"],
        ["Crop7_BC.csv", "Crop8_BC.csv"],
    ]

def get_job_params_dl(debug=False):
    """Expand the hyper-parameter grid below into one job dict per combination.

    Each key of ``search_space`` maps to the list of values to try. The Cartesian
    product of all lists (in key order) yields the jobs. Every job additionally gets:
      - "scale_features": True when the training data type contains "Real";
      - "index_train" / "index_test": "all" (use every graph);
      - "cross_validation_groups_train" / "cross_validation_groups_test":
        the split returned by ``get_cv_groups`` for the train / test data type.

    Args:
        debug: if True, print the number of jobs and the jobs as JSON.

    Returns:
        List of job dicts, one per hyper-parameter combination.
    """
    # Data type options for "data_type_train" / "data_type_test":
    #   "Real"           - manually annotated data
    #   "Real_automatic" - centroids detected automatically by the CNN model (test data)
    #   "../data/synthetic_algo_<N>_points", N in 100, 200, ..., 1000 - synthetic data
    # Listing several values in any list produces every combination of train/test data.
    search_space = {
        "data_type_train": ["Real"],
        "data_type_test": ["Real_automatic"],
        "model_type": [
            # "GNN_Classifier_Recurrent_WithoutNodeFeats",
            # "GNN_Classifier_NonRecurrent_WithoutNodeFeats",
            "GNN_Classifier_Recurrent_WithNodeFeats",
            "GNN_Classifier_NonRecurrent_WithNodeFeats",
            "MLP",
        ],
        "knn_inter_nodes": [10],  # other options: 7, "min"
        "knn_inter_nodes_max": [7],
        "knn_intra_nodes": [0],
        "normalize": [True],  # or [False]
        "node_feats": [
            ["Y", "X", "Z", "node_type"],  # optionally also "ID"
        ],
        "edge_feats": [
            ["delta_x", "delta_y", "delta_z", "weight",
             "angle_orientation_theta", "angle_orientation_phi"],
            ["delta_x", "delta_y", "delta_z", "weight"],
            # Other feature sets:
            #   ["delta_x", "delta_y", "delta_z", "weight",
            #    "x1", "y1", "z1", "node_type1", "x2", "y2", "z2", "node_type2"]
            #   the set above + ["angle_orientation_theta", "angle_orientation_phi"]
            #   ["angle_orientation_theta", "angle_orientation_phi"]
        ],
        "to_undirected": [False],
        "lr": [1e-3],
        "n_epochs": [100],
        "early_stopper": [None],
        "scheduler": [None],
        "pos_weight": [True],
        "criterion": ["BCEWithLogitsLoss"],
        "device": ["cpu"],
    }

    option_names = list(search_space)
    jobs = []
    for option_values in itertools.product(*(search_space[name] for name in option_names)):
        job = {name: value for name, value in zip(option_names, option_values)}
        job["scale_features"] = "Real" in job["data_type_train"]
        job["index_train"] = "all"
        job["index_test"] = "all"
        job["cross_validation_groups_train"] = get_cv_groups(job["data_type_train"])
        job["cross_validation_groups_test"] = get_cv_groups(job["data_type_test"])
        jobs.append(job)

    if debug:
        print("Total Number of jobs is:", len(jobs))
        print(json.dumps(jobs))
    return jobs

In [None]:
def get_graph_list_dl(jobs, debug=False):
    """Load, once per distinct parameter set, every train/test graph list the jobs need.

    For each job, both the training and the test data types are loaded with the job's
    k-NN, normalization, scaling and feature options. Lists already loaded under the
    same key are reused. Note that "scale_features" (derived from the training data
    type) is applied to the test data as well.

    Args:
        jobs: list of job dicts (see ``get_job_params_dl``).
        debug: if True, print the key of every graph list that gets loaded.

    Returns:
        Dict mapping cache key -> list returned by ``get_graph_list``. ``train_dl``
        rebuilds the same key format to look the lists up.
    """
    def graph_cache_key(params, data_type):
        # Must stay identical to the key format rebuilt in train_dl.
        key_parts = [data_type, params["knn_inter_nodes"], params["knn_intra_nodes"],
                     params["knn_inter_nodes_max"], params["normalize"],
                     params["scale_features"], params["node_feats"], params["edge_feats"]]
        return "_".join(str(part) for part in key_parts)

    graph_list_dict_deep_learning = {}
    for params in tqdm(jobs):
        # Train first, then test: both share every loading option except the data type.
        for data_type in (params["data_type_train"], params["data_type_test"]):
            graph_key = graph_cache_key(params, data_type)
            if graph_key in graph_list_dict_deep_learning:
                continue
            if debug:
                print("Loading graph list:", graph_key)
            graph_list_dict_deep_learning[graph_key] = get_graph_list(
                data_type, params["knn_inter_nodes"], params["knn_intra_nodes"],
                params["knn_inter_nodes_max"], normalize=params["normalize"],
                scale_feats=params["scale_features"],
                node_feats=params["node_feats"], edge_feats=params["edge_feats"],
                shuffle=False)
    return graph_list_dict_deep_learning

In [None]:
def _dl_graph_key(job_parameters, data_type):
    """Cache key under which get_graph_list_dl stores the graph list for ``data_type``."""
    key_parts = [data_type, job_parameters["knn_inter_nodes"], job_parameters["knn_intra_nodes"],
                 job_parameters["knn_inter_nodes_max"], job_parameters["normalize"],
                 job_parameters["scale_features"], job_parameters["node_feats"],
                 job_parameters["edge_feats"]]
    return "_".join(str(part) for part in key_parts)


def _dl_cv_datasets(graph_list_train, graph_list_test, indexes_train, indexes_test,
                    cross_validation_groups_train, cross_validation_groups_test):
    """Split the graphs into folds: a list of {"train": [...], "test": [...]} dicts."""
    if not cross_validation_groups_train:
        # Without cross-validation: a single split restricted to the selected graph ids.
        return [{
            "train": [g for g in graph_list_train if g.graph_id in indexes_train],
            "test": [g for g in graph_list_test if g.graph_id in indexes_test],
        }]

    # NOTE: as before, index_train / index_test are not applied when cross-validating.
    # Build each group's id set once instead of once per graph.
    train_group_ids = [set(group) for group in cross_validation_groups_train]
    cv_datasets = []
    # Fold k tests on test group k and trains on every train group except group k,
    # concatenated in group order.
    for fold, test_group in enumerate(cross_validation_groups_test):
        test_ids = set(test_group)
        fold_test = [g for g in graph_list_test if g.graph_id in test_ids]
        fold_train = []
        for group_index, group_ids in enumerate(train_group_ids):
            if group_index != fold:
                fold_train.extend(g for g in graph_list_train if g.graph_id in group_ids)
        cv_datasets.append({"train": fold_train, "test": fold_test})
    return cv_datasets


def train_dl(graph_list_dict_deep_learning, jobs, debug=False):
    """Train and evaluate one model per cross-validation fold for every job.

    For each job the pre-loaded train/test graph lists are fetched from
    ``graph_list_dict_deep_learning``, split into folds, and a freshly built model
    is trained on each fold's training graphs and evaluated on its test graphs.

    Side effect: each job dict is updated in place with "num_classes",
    "num_node_features", "num_edge_features" and "num_total_features", read from the
    first training graph.

    Args:
        graph_list_dict_deep_learning: dict returned by ``get_graph_list_dl``.
        jobs: list of job dicts from ``get_job_params_dl``. An optional
            "number_cross_validation_groups" entry (default 3) sets the number of
            folds used when a CV split is "even".
        debug: forwarded to ``schedule_training_GNN``.

    Returns:
        ``(results_list_pytorch, models_list)``: one results dict per job (keys
        "cv_results", "job_parameters", "aggregated_metrics") and, per job, the list
        of ``ModelWrapper`` objects trained on each fold.
    """
    results_list_pytorch = []
    models_list = []

    for job_parameters in tqdm(jobs, total=len(jobs)):
        graph_list_train = graph_list_dict_deep_learning[
            _dl_graph_key(job_parameters, job_parameters["data_type_train"])]
        graph_list_test = graph_list_dict_deep_learning[
            _dl_graph_key(job_parameters, job_parameters["data_type_test"])]

        first_graph = graph_list_train[0]
        job_parameters["num_classes"] = 1
        # Graphs without nodes report 0 node features.
        job_parameters["num_node_features"] = (first_graph.pyg_graph.x.shape[1]
                                               if first_graph.pyg_graph.x.shape[0] > 0 else 0)
        job_parameters["num_edge_features"] = first_graph.pyg_graph.edge_attr.shape[1]
        job_parameters["num_total_features"] = first_graph.edge_x.shape[1]

        indexes_train = job_parameters["index_train"]
        indexes_test = job_parameters["index_test"]
        cross_validation_groups_train = job_parameters.get("cross_validation_groups_train", [])
        cross_validation_groups_test = job_parameters.get("cross_validation_groups_test", [])
        num_cv_groups = job_parameters.get("number_cross_validation_groups", 3)

        if indexes_train == "all":
            indexes_train = [g.graph_id for g in graph_list_train]
        if indexes_test == "all":
            indexes_test = [g.graph_id for g in graph_list_test]
        # "even": split the selected graph ids into num_cv_groups groups.
        if cross_validation_groups_train == "even":
            cross_validation_groups_train = distribute_elements_to_lists(indexes_train, num_cv_groups)
        if cross_validation_groups_test == "even":
            cross_validation_groups_test = distribute_elements_to_lists(indexes_test, num_cv_groups)

        cv_dataset_list = _dl_cv_datasets(graph_list_train, graph_list_test,
                                          set(indexes_train), set(indexes_test),
                                          cross_validation_groups_train, cross_validation_groups_test)

        # Train and evaluate one model per fold.
        results = {"cv_results": [], "job_parameters": job_parameters, "aggregated_metrics": None}
        models_list_cv = []
        for dataset in cv_dataset_list:
            fold_train, fold_test = dataset["train"], dataset["test"]

            model = build_model(job_parameters["model_type"], job_parameters["num_node_features"],
                                job_parameters["num_edge_features"], job_parameters["num_total_features"],
                                job_parameters["num_classes"])
            model = schedule_training_GNN(job_parameters, model, fold_train, fold_test, debug=debug)

            result = {"graphs": {"train": fold_train, "test": fold_test}}
            result["eval"] = eval_link_predictor(model, fold_train, fold_test,
                                                 criterion=None, apply_constraints=True,
                                                 plot_roc_curve=False, debug=False)
            results["cv_results"].append(result)
            models_list_cv.append(ModelWrapper(model, job_parameters))
        models_list.append(models_list_cv)

        # Merge the per-graph metrics of all folds (a graph id seen twice keeps the
        # last fold's entry) and aggregate them.
        all_metrics = {}
        for fold_result in results["cv_results"]:
            all_metrics.update(fold_result["eval"]["individual_metrics"])
        results["aggregated_metrics"] = aggregate_metrics_all(list(all_metrics.values()))

        results_list_pytorch.append(results)
    return results_list_pytorch, models_list

In [None]:
def plot_results_dl(results_list_pytorch):
    """Tabulate the aggregated metrics of every job and show the table.

    Displays the full sorted table, then the same table without the
    "Data Train" / "Data Test" columns, and prints that reduced table as LaTeX.

    Returns:
        The reduced DataFrame (without "Data Train" / "Data Test").
    """
    # Metric entries to report: "@best" (presumably the optimal-threshold
    # predictions), greedy constraints at that threshold, greedy constraints
    # without a threshold, and the optimization-based constraints.
    metric_paths = [
        ["@best", "metrics"],
        ["@best", "@constraints", "metrics"],
        ["@constraints", "metrics"],
        ["@constraints_opt", "metrics"],
    ]
    sort_columns = ["Algorithm", "Normalize", "K Inter", "Data Train", "Data Test", "Constraints"]

    full_table = plot_table(results_list_pytorch, metrics_dict_entries=metric_paths)
    full_table = full_table.sort_values(by=sort_columns)
    display(full_table)

    reduced_table = full_table.drop(columns=["Data Train", "Data Test"])
    display(reduced_table)
    print(plot_df_to_latex(reduced_table))
    return reduced_table

In [None]:
jobs_dl = get_job_params_dl()

# Optional filter: keep only jobs that train and test on the same data type (useful for
# synthetic data, so each synthetic set is only evaluated on itself).
# jobs_dl = [job for job in jobs_dl if job["data_type_train"] == job["data_type_test"]]

assert jobs_dl

len(jobs_dl)

In [None]:
# Load every graph list the jobs need (each distinct parameter set is loaded once).
graph_list_dl = get_graph_list_dl(jobs_dl, debug=False)

In [None]:
# Train and evaluate one model per cross-validation fold for every job.
results_list_dl, models_list = train_dl(graph_list_dl, jobs_dl, debug=False)

In [None]:
# Allow up to 500 rows so the whole results table is rendered.
pd.set_option("display.max_rows", 500)
output_df = plot_results_dl(results_list_dl)

In [None]:
def convert_to_final_format(output_df):
    """Reduce the results table to the columns of the final report.

    Drops bookkeeping and confusion-matrix columns and renames "K Inter" to "K".
    Replaces "Edge Feat." with a boolean "Angles" column that is True when any
    edge feature name contains "angle". The input frame is not modified.

    Args:
        output_df: table as returned by ``plot_results_dl``. "Edge Feat." is expected
            to hold sequences of feature names. NOTE(review): if it held plain strings,
            the check would iterate characters and always yield False - confirm.

    Returns:
        New DataFrame with columns Algorithm, Constraints, Angles, K, ROC AUC Score,
        Accuracy, TPR, FPR, Precision, F1-Score (in that order).
    """
    unreported_columns = ["Node Feat.", "Scale", "Normalize", "K Intra", "TP Percent",
                          "TP Total Count", "TP", "FP", "TN", "FN"]
    reported_columns = ["Algorithm", "Constraints", "Angles", "K", "ROC AUC Score",
                        "Accuracy", "TPR", "FPR", "Precision", "F1-Score"]

    def has_angle_feature(edge_features):
        return any("angle" in feature for feature in edge_features)

    slimmed = (output_df
               .drop(columns=unreported_columns)
               .rename(columns={"K Inter": "K", "Edge Feat.": "Angles"}))
    slimmed["Angles"] = slimmed["Angles"].apply(has_angle_feature)
    return slimmed[reported_columns]

# Compact reporting table: bookkeeping columns dropped, edge features reduced to an "Angles" flag.
final_output_df = convert_to_final_format(output_df)
print(plot_df_to_latex(final_output_df))
final_output_df

In [None]:
def array_to_csv(array: np.ndarray, csv_path: str, separator=",",
                 columns_order=("YN", "XN", "YG", "XG", "ZN", "ZG")):
    """Write coordinate pairs to a headerless delimited text file.

    Each input row is read positionally as ``(XN, YN, ZN, XG, YG, ZG)`` (presumably
    nucleus "N" and Golgi "G" coordinates) and written with its values reordered
    according to ``columns_order``. Rows are separated by "\\n" with no trailing
    newline after the last row. An empty ``array`` produces an empty file.

    Args:
        array: 2-D array-like; every row must have at least six values.
        csv_path: destination file (overwritten).
        separator: value separator within a row.
        columns_order: names (from XN, YN, ZN, XG, YG, ZG) in output order; any
            sequence works. Defaults to a tuple to avoid a shared mutable default.

    Raises:
        IndexError: if a row has fewer than six values.
        KeyError: if ``columns_order`` contains an unknown name.
    """
    # Name of each of the six values in an input row, by position.
    input_columns = ("XN", "YN", "ZN", "XG", "YG", "ZG")

    # Build all rows first so a malformed row does not leave a half-written file.
    lines = []
    for row in array:
        values_by_name = {name: row[position] for position, name in enumerate(input_columns)}
        lines.append(separator.join(str(values_by_name[name]) for name in columns_order))

    with open(csv_path, "w") as fp:
        fp.write("\n".join(lines))

In [None]:
# Output-folder suffix and key path to the predicted labels inside a graph's metrics
# dict, one entry per prediction variant written by save_results (in writing order).
_SAVE_RESULTS_VARIANTS = (
    ("", ("@best", "pred_labels")),
    ("_constraints", ("@constraints", "pred_labels")),
    ("_constraints_threshold", ("@best", "@constraints", "pred_labels")),
    ("_constraints_opt", ("@constraints_opt", "pred_labels")),
)


def _save_results_params(folder, results_info):
    """Create ``folder`` if needed and write ``results_info`` to ``folder/params.json``."""
    os.makedirs(folder, exist_ok=True)
    with open(os.path.join(folder, "params.json"), "w") as f:
        # Same encoder as the "Aggregated Metrics" printout of the same data.
        json.dump(results_info, f, indent=2, cls=CustomEncoder)


def _save_results_edges(edge_df, pred_labels, graph_test, folder):
    """Label a copy of ``edge_df`` with ``pred_labels`` and write it to ``folder/<graph_id>``."""
    labelled_edges = edge_df.copy()
    labelled_edges["edge_label"] = pred_labels
    edges_array = pred_df_to_csv(labelled_edges, graph_test.nodes_df_original)
    array_to_csv(edges_array, os.path.join(folder, graph_test.graph_id))


def save_results(output_folder, results_list_pytorch):
    """Write job parameters and predicted edges of every job under ``output_folder``.

    For job number ``n``, four sibling folders are created, one per prediction variant:
      - ``Results_n``                       <- metrics["@best"]["pred_labels"]
      - ``Results_n_constraints``           <- metrics["@constraints"]["pred_labels"]
      - ``Results_n_constraints_threshold`` <- metrics["@best"]["@constraints"]["pred_labels"]
      - ``Results_n_constraints_opt``       <- metrics["@constraints_opt"]["pred_labels"]
    Each folder gets a ``params.json`` (every results entry except "cv_results"),
    and one file per test graph, named after its ``graph_id`` and written by
    ``array_to_csv``. Progress and aggregated metrics are printed.

    Args:
        output_folder: root output directory (created if missing).
        results_list_pytorch: list of results dicts as returned by ``train_dl``.

    Raises:
        ValueError: if a graph id occurs more than once among a fold's test graphs
            (their output files would overwrite each other).
    """
    os.makedirs(output_folder, exist_ok=True)

    for results_count, results_entry in enumerate(results_list_pytorch):
        results_info = {k: v for k, v in results_entry.items() if k != "cv_results"}

        variant_folders = []
        for suffix, label_path in _SAVE_RESULTS_VARIANTS:
            folder = os.path.join(output_folder, "Results_" + str(results_count) + suffix)
            _save_results_params(folder, results_info)
            variant_folders.append((folder, label_path))

        job_parameters = results_entry["job_parameters"]
        print("K_Intra", job_parameters["knn_intra_nodes"], "K_Inter", job_parameters["knn_inter_nodes"])
        print("Aggregated Metrics", json.dumps(results_entry["aggregated_metrics"], indent=1, cls=CustomEncoder))

        for cv_crop_index, cv_crop in enumerate(results_entry["cv_results"]):
            graphs_list_test = cv_crop["graphs"]["test"]
            graph_individual_metrics = cv_crop["eval"]["individual_metrics"]
            graphs_indexes = {"train": [g.graph_id for g in cv_crop["graphs"]["train"]],
                              "test": [g.graph_id for g in graphs_list_test]}

            print("##############################################################\n")
            print("CV Crop Index", cv_crop_index, "", graphs_indexes)

            test_ids = graphs_indexes["test"]
            if len(set(test_ids)) != len(test_ids):
                duplicated = list(dict.fromkeys(gid for gid in test_ids if test_ids.count(gid) > 1))
                raise ValueError("Multiple graphs with same ID! Duplicated ids: " + str(duplicated))

            for graph_test in graphs_list_test:
                graph_metrics = graph_individual_metrics[graph_test.graph_id]
                edge_list = GraphInfo.edge_index_to_edge_list(graph_test.pyg_graph.edge_index)
                edge_df = GraphInfo.edge_list_to_edge_df(edge_list)
                for folder, label_path in variant_folders:
                    pred_labels = graph_metrics
                    for key in label_path:
                        pred_labels = pred_labels[key]
                    _save_results_edges(edge_df, pred_labels, graph_test, folder)

In [None]:
# Write predictions of every job under the configured results directory.
save_results(os.path.join(results_output_dir, "results_real_test"), results_list_dl)

In [None]:
# Set SAVE_MODELS = True to save every trained model (one file per job and CV fold).
SAVE_MODELS = False


def save_models(models_list, output_folder):
    """Save every trained model wrapper under ``output_folder``.

    File names follow ``<job>_cv-<fold>_<model_type>_K-<knn_inter_nodes>_Angles-<bool>``,
    where ``Angles`` is True when any edge feature name contains "angle" (the same
    rule as the "Angles" column of ``convert_to_final_format``).

    Args:
        models_list: per job, the list of per-fold model wrappers, as returned by
            ``train_dl``. NOTE(review): assumes each wrapper exposes the job dict as
            ``.params`` and has a ``.save(path)`` method - confirm against ModelWrapper.
        output_folder: destination directory (created if missing).
    """
    os.makedirs(output_folder, exist_ok=True)
    for job_index, models_cv_list in enumerate(models_list):
        for cv_crop, model in enumerate(models_cv_list):
            params = model.params
            # Inspect the edge-feature names; iterating ``params`` itself would only see
            # the dict keys, so the flag would always be False.
            uses_angles = any("angle" in feature for feature in params["edge_feats"])
            model_name = (f"{job_index}_cv-{cv_crop}_{params['model_type']}"
                          f"_K-{params['knn_inter_nodes']}_Angles-{uses_angles}")
            model_path = os.path.join(output_folder, model_name)
            print("Saving model to", model_path)
            model.save(model_path)


if SAVE_MODELS:
    save_models(models_list, os.path.join("./models", "models_test"))