In [None]:
import PIL
import matplotlib
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import os
import json
import math
from utils_helpers import *

In [None]:
###################################################
### Define int encodings for string variables #####
###################################################

# Node type name -> integer code (the code is the position in the tuple).
nodes_type_int_encodings = {
    node_type: code
    for code, node_type in enumerate(("nuclei", "golgi"))
}

# Inverse lookup: integer code -> node type name.
reversed_nodes_type_int_encodings = dict(
    zip(nodes_type_int_encodings.values(), nodes_type_int_encodings.keys())
)

# Edge type name ("<source type>-<target type>") -> integer code.
edges_type_int_encodings = {
    edge_type: code
    for code, edge_type in enumerate(
        ("nuclei-nuclei", "golgi-golgi", "golgi-nuclei", "nuclei-golgi")
    )
}

# Inverse lookup: integer code -> edge type name.
reversed_edges_type_int_encodings = dict(
    zip(edges_type_int_encodings.values(), edges_type_int_encodings.keys())
)

def nx_convert_edges_df_to_list(edges_df):
    """
    Converts a Pandas DataFrame of edges to a list of edges.

    Rows are read by position, so the DataFrame may carry any index
    (for example the non-contiguous index left behind by a filter).

    Args:
    edges_df: A Pandas DataFrame of edges with "source" and "target" columns.
        Every other column is treated as an edge attribute.

    Returns:
    A list with one ``[source, target, attrs]`` entry per row, where ``attrs``
    is a dict mapping each remaining column name to that row's value.
    An empty DataFrame yields an empty list.
    """
    if len(edges_df) == 0:
        return []

    # Re-number the rows 0..n-1 on a copy so the label lookups below are
    # effectively positional; the caller's DataFrame is left untouched.
    rows = edges_df.reset_index(drop=True)
    attr_cols = [col for col in rows.columns if col not in ("source", "target")]

    edges_list = []
    for i in range(len(rows)):
        edge_attrs = {col: rows.loc[i, col] for col in attr_cols}
        edges_list.append([rows.loc[i, "source"], rows.loc[i, "target"], edge_attrs])
    return edges_list

def nx_convert_nodes_df_to_list(nodes_df):
    """
    Converts a Pandas DataFrame of nodes to a list of ``(node_id, attrs)`` tuples.

    Args:
    nodes_df: A Pandas DataFrame with the columns "ID", "Y", "X", "Z" and
        "node_type".

    Returns:
    A list with one ``(ID, {"Y": ..., "X": ..., "Z": ..., "node_type": ...})``
    tuple per row, in row order.
    """
    attr_keys = ("Y", "X", "Z", "node_type")
    return [
        (row["ID"], {key: row[key] for key in attr_keys})
        for _, row in nodes_df.iterrows()
    ]
    
def load_df_from_csv(gt_vectors_csv: str, column_names=['YN', 'XN', 'YG', 'XG', 'ZN', 'ZG'],
                     nuclei_columns=['YN', 'XN', 'ZN'],
                     golgi_columns=['YG', 'XG', 'ZG'],
                     nodes_type_int_encodings=nodes_type_int_encodings,
                    edges_type_int_encodings = edges_type_int_encodings):
    """
    Loads a nuclei/golgi CSV file and builds its node and edge DataFrames.

    Two file layouts are handled, selected from the path string:

    * path contains "automatic" but not "results": the file has a header row
      and a string "node_type" column, which is replaced by its integer code.
      The nodes DataFrame is a copy of the file and the edges DataFrame is empty.
    * anything else: the file has no header and each row holds one
      nucleus-golgi pair (columns named by ``column_names``). Nuclei get IDs
      ``0..n-1``, golgi get IDs ``n..2n-1``, coordinates are shifted by -1 and
      one edge is created per row from the nucleus to the golgi.

    Args:
    gt_vectors_csv: Path of the CSV file.
    column_names: Column names assigned to a header-less file.
    nuclei_columns: Columns holding the nuclei coordinates.
    golgi_columns: Columns holding the golgi coordinates.
    nodes_type_int_encodings: Mapping from node type name to integer code.
    edges_type_int_encodings: Accepted for interface compatibility; not used here.

    Returns:
    A tuple ``(df, nodes_df, edges_df)``: the DataFrame as read from disk,
    the nodes DataFrame and the edges DataFrame.
    """
    is_automatic_nodes_file = "automatic" in gt_vectors_csv and "results" not in gt_vectors_csv

    if is_automatic_nodes_file:
        df = pd.read_csv(gt_vectors_csv, delimiter=",")
        df["node_type"] = df["node_type"].apply(lambda type_name: nodes_type_int_encodings[type_name])
        # No pairing information in this layout, hence no edges.
        return df, df.copy(), pd.DataFrame(columns=['source', 'target', "edge_type"])

    df = pd.read_csv(gt_vectors_csv, delimiter=",", header=None)
    df.columns = column_names

    # One DataFrame per node type, both using the shared Y/X/Z column names.
    nuclei_df = (
        df[nuclei_columns]
        .reset_index(drop=True)
        .rename(columns={'YN': 'Y', 'XN': 'X', 'ZN': 'Z'})
    )
    nuclei_df["node_type"] = nodes_type_int_encodings['nuclei']

    golgi_df = (
        df[golgi_columns]
        .reset_index(drop=True)
        .rename(columns={'YG': 'Y', 'XG': 'X', 'ZG': 'Z'})
    )
    golgi_df["node_type"] = nodes_type_int_encodings['golgi']

    # Golgi IDs continue right after the nuclei IDs.
    n_nuclei, n_golgi = len(nuclei_df), len(golgi_df)
    golgi_df["ID"] = range(n_nuclei, n_nuclei + n_golgi)
    nuclei_df["ID"] = range(n_nuclei)

    nodes_df = pd.concat([nuclei_df, golgi_df], ignore_index=True)
    for axis in ("X", "Y", "Z"):
        nodes_df[axis] = nodes_df[axis].apply(lambda value: value - 1)

    # Row i of the file pairs nucleus i with golgi i.
    edges_df = pd.DataFrame({'source': nuclei_df['ID'], 'target': golgi_df['ID']})

    return df, nodes_df, edges_df

def add_legend_plot(ax, plot_styles):
    """
    Adds a legend describing the edge and node styles in ``plot_styles``.

    Args:
    ax: Object exposing ``legend(handles, labels, ...)``; the notebook passes
        both axes and whole figures.
    plot_styles: Mapping from style name to style dict. Every dict entry that
        has a "dashed" key is treated as an edge style (needs "color" and
        "alpha"); the "nuclei" and "golgi" entries describe the node markers
        (need "marker", "color" and "alpha").

    Returns:
    The legend object returned by ``ax.legend``.
    """
    Line2D = matplotlib.lines.Line2D
    handles, labels = [], []

    # Edge entries: one line sample per edge style.
    for label, style in plot_styles.items():
        if not (isinstance(style, dict) and "dashed" in style):
            continue
        handles.append(Line2D([0], [0],
                              color=style["color"],
                              linewidth=2,
                              linestyle="--" if style.get("dashed", False) else "-",
                              label=label,
                              alpha=style["alpha"]))
        labels.append(label.upper() + " Edge")

    # Node entries: a marker drawn on an invisible (white) line.
    for node_type, display_name in (("nuclei", "Nuclei"), ("golgi", "Golgi")):
        node_style = plot_styles[node_type]
        handles.append(Line2D([0], [0],
                              marker=node_style["marker"],
                              color="w",
                              alpha=node_style["alpha"],
                              label=display_name,
                              markerfacecolor=node_style["color"],
                              markersize=10))
        labels.append(display_name)

    return ax.legend(handles, labels, loc="upper right", fontsize="small", ncol=2)

def df_make_plot(ax, nodes_df, edges_df, edge_labels, title, plot_styles = {
                "nuclei":{"marker":"o","color":"red", "alpha":0.3},
                "golgi":{"marker":"o","color":"green", "alpha":0.3},
                "": {"color": "black",  "dashed": False, "alpha":1},
                #"tp": {"color": "black",  "dashed": False, "alpha":1},
                #"fp": {"color": "yellow", "dashed": False, "alpha":1},
                #"tn": None,
                #"fn": {"color": "blue",  "dashed": True, "alpha":1}
            }, add_legends=True):
    """
    Draws the graph described by ``nodes_df`` / ``edges_df`` on a 3D axes.

    Args:
    ax: The 3D matplotlib axes to draw on.
    nodes_df: Nodes DataFrame with "ID", "X", "Y", "Z" and "node_type" columns.
    edges_df: Edges DataFrame with "source" and "target" columns.
    edge_labels: One label per edge; each label selects the entry of
        ``plot_styles`` used to draw that edge.
    title: Title for the axes; skipped when empty or None.
    plot_styles: Node and edge style dictionaries (read only, never modified).
    add_legends: Whether to add a legend to ``ax``.

    Returns:
    The axes that was drawn on.
    """
    node_list = nx_convert_nodes_df_to_list(nodes_df)
    edge_list = nx_convert_edges_df_to_list(edges_df)

    GraphInfo.plot_graph_nx_matplotlib(node_list, edge_list, edge_labels, ax, dims = 3, plot_styles = plot_styles, reversed_nodes_type_int_encodings = reversed_nodes_type_int_encodings)
    if(title):
        # Title the axes we were given rather than whatever axes pyplot
        # currently considers active.
        ax.set_title(title, fontsize = 15)

    #Legends
    if(add_legends):
        add_legend_plot(ax, plot_styles)

    return ax

In [None]:
class GraphInfo:
    """
    Holds the DataFrames of one graph together with the representations derived
    from them (PyTorch Geometric graph, edge/node lists, sklearn-style arrays),
    plus static helpers for label conversion and plotting.
    """

    def __init__(self,
                 raw_df, nodes_df_original,
                 nodes_df, edges_df, edges_df_knn,
                 graph_id = None, k_inter = None, k_intra = None, node_feats = "all",
                 edge_feats = "all"):
        """
        Args:
            raw_df: DataFrame as read from disk.
            nodes_df_original: Nodes DataFrame before any preprocessing.
            nodes_df: Nodes DataFrame used to build the graph (needs "ID",
                "X", "Y", "Z" and "node_type").
            edges_df: Ground-truth edges DataFrame.
            edges_df_knn: Candidate edges DataFrame the PyG graph is built from.
            graph_id: Optional identifier of the graph.
            k_inter, k_intra: Stored as given; not used inside this class.
            node_feats: ``"all"`` to keep every node column, otherwise the
                collection of node columns to keep as features.
            edge_feats: ``"all"`` to keep every edge column, otherwise the
                collection of edge columns to keep as features.
        """
        self.graph_id = graph_id
        self.k_inter = k_inter
        self.k_intra = k_intra

        self.raw_df = raw_df

        self.nodes_df_original = nodes_df_original
        self.nodes_df = nodes_df

        self.edges_df = edges_df
        self.edges_df_knn = edges_df_knn

        self.concat_nodes_edges_df = apply_concat_nodes_edges_df(self.nodes_df, self.edges_df_knn)

        node_feats_cols_to_remove = self._feature_cols_to_remove(
            self.nodes_df.columns, node_feats, ["ID"])
        # NOTE(review): the removal list is derived from edges_df's columns but
        # is applied to edges_df_knn below - confirm both frames share columns.
        edge_feats_cols_to_remove = self._feature_cols_to_remove(
            self.edges_df.columns, edge_feats, ["source", "target", "edge_label"])

        self.pyg_graph = pyg_load_graph_from_df(self.nodes_df, self.edges_df_knn, to_undirected = False, graph_type = "homo",
                             node_feats_cols_to_remove = node_feats_cols_to_remove, edge_feats_cols_to_remove = edge_feats_cols_to_remove,
                            encoder = pyg_IdentityEncoder())

        self.pyg_graph_edge_list = self.edge_index_to_edge_list(self.pyg_graph.edge_index)
        self.pyg_graph_true_labels = self.pyg_graph.edge_label.detach().cpu().numpy()
        self.edge_list = nx_convert_edges_df_to_list(self.edges_df)
        self.edge_list_knn = nx_convert_edges_df_to_list(self.edges_df_knn)

        self.node_list = nx_convert_nodes_df_to_list(self.nodes_df)

        # NOTE(review): an older comment here said this was disabled because
        # automatic graphs have an empty edges_df, yet the call is active.
        self.sklearn_graph = {}
        self.sklearn_graph["X_TRUE"], self.sklearn_graph["Y_TRUE"], self.sklearn_graph["X_KNN"], self.sklearn_graph["Y_KNN"] = sklearn_dataset_to_pytorch(self, shuffle = True)

        # Only now narrow the stored frames to the selected feature columns;
        # everything above was built from the full frames.
        if node_feats != "all":
            self.nodes_df = self.nodes_df[node_feats]
        if edge_feats != "all":
            self.edges_df, self.edges_df_knn = self.edges_df[edge_feats] , self.edges_df_knn[edge_feats]

    @staticmethod
    def _feature_cols_to_remove(columns, feats, always_remove):
        """
        Lists the columns that must not be used as features.

        Args:
            columns: All column names of the DataFrame.
            feats: ``"all"`` or the collection of columns to keep.
            always_remove: Columns removed regardless of ``feats``.

        Returns:
            list: De-duplicated column names to remove (order not guaranteed).
        """
        if isinstance(feats, str) and feats == "all":
            # "all" is a sentinel. Testing `col not in "all"` would be a
            # substring test and would mark every column for removal.
            unselected = []
        else:
            unselected = [col for col in list(columns) if col not in feats]
        return list(set(list(always_remove) + unselected))

    @staticmethod
    def edge_index_to_edge_list(edge_index):
        """
        Convert PyTorch Geometric edge_index tensor to NetworkX edge_list format.

        Args:
            edge_index (torch.Tensor): Edge index tensor (2 x num_edges) in PyTorch Geometric format.

        Returns:
            list: List of ``[source, target]`` pairs, one per edge.
        """
        # Transpose (2 x E) -> (E x 2) so each row is one edge.
        return edge_index.t().tolist()

    @staticmethod
    def edge_list_to_edge_df(edge_list):
        """
        Builds an edges DataFrame from an edge list.

        Args:
            edge_list: Iterable of edges; items 0 and 1 of each edge are used
                as source and target.

        Returns:
            pd.DataFrame: Columns "source", "target" and "edge_label", with
            every edge_label set to 1.
        """
        data_list = [
            {"source": edge[0], "target": edge[1], "edge_label": 1}
            for edge in edge_list
        ]
        return pd.DataFrame(data_list)

    @staticmethod
    def convert_edge_pred_to_label(true, pred):
        """
        Maps one (true, predicted) pair of binary labels to "tp", "fn", "fp" or "tn".

        Raises:
            KeyError: If either value is not 0 or 1.
        """
        label_mapping = {
            (1, 1): "tp",
            (1, 0): "fn",
            (0, 1): "fp",
            (0, 0): "tn"
        }
        return label_mapping[(true, pred)]

    @staticmethod
    def convert_edge_preds_to_labels(true_labels, pred_labels):
        """
        Converts edge predicted output (0 or 1) to tp,fp,tn,fn

        Args:
            true_labels (list): List of true edge labels (0 or 1) corresponding to each edge.
            pred_labels (list): List of predicted edge labels (0 or 1) corresponding to each edge.

        Returns:
            pred_labels_list (list) : List of strings with "tp", "fp", "tn", "fn"
        """
        return [
            GraphInfo.convert_edge_pred_to_label(true, pred)
            for true, pred in zip(true_labels, pred_labels)
        ]

    @staticmethod
    def plot_graph_nx_matplotlib(node_list, edge_list, edge_labels, ax, dims = 2,
        plot_edges = True,
        plot_styles = {
            "nuclei":{"marker":"o","color":"red", "alpha":0.3},
            "golgi":{"marker":"o","color":"green", "alpha":0.3},
            "edge": {"color": "black",  "dashed": False, "alpha":1},
            #"tp": {"color": "black",  "dashed": False, "alpha":1},
            #"fp": {"color": "yellow", "dashed": False, "alpha":1},
            #"tn": None,
            #"fn": {"color": "blue",  "dashed": True, "alpha":1}
        }, reversed_nodes_type_int_encodings = reversed_nodes_type_int_encodings):
        """
        Draws nodes and (optionally) edges on a matplotlib axes.

        Args:
            node_list: ``(node_id, attrs)`` tuples; attrs needs "X", "Y", "Z"
                and an integer "node_type".
            edge_list: Edges whose items 0 and 1 are node ids.
            edge_labels: One label per edge, used as a key into ``plot_styles``.
                A style of ``None`` means the edge is not drawn.
            ax: Axes to draw on (a 3D axes when ``dims`` is 3).
            dims (int): 2 or 3.
            plot_edges (bool): Whether edges are drawn.
            plot_styles (dict): Node and edge styles (read only).
            reversed_nodes_type_int_encodings (dict): Integer code -> node type name.

        Returns:
            The axes that was drawn on.

        Raises:
            ValueError: If ``dims`` is neither 2 nor 3.
            KeyError: If an edge label or node type has no entry in ``plot_styles``.
        """
        assert (len(edge_list)==len(edge_labels))
        dims_allowed_values = [2,3]
        if dims not in dims_allowed_values:
            raise ValueError("Wrong dims! Allowed values", str(dims_allowed_values))

        node_pos = {}
        node_size = 30
        # Draw nodes
        for node_id, node_info in node_list:
            node_style = plot_styles[reversed_nodes_type_int_encodings[node_info["node_type"]]]
            pos = (node_info["X"], node_info["Y"], node_info["Z"])
            node_pos[node_id] = pos

            scatter_kwargs = {
                "s": node_size,
                "c": node_style["color"],
                "marker": node_style["marker"],
                "alpha": node_style["alpha"],
            }
            if dims == 2:
                ax.scatter(pos[0], pos[1], **scatter_kwargs)
            else:
                ax.scatter3D(pos[0], pos[1], pos[2], **scatter_kwargs)

        if plot_edges:
            linewidth = 1
            for edge, edge_label in zip(edge_list, edge_labels):
                edge_style = plot_styles[edge_label]
                if edge_style is None:
                    # Style explicitly disabled (e.g. true negatives).
                    continue

                line_kwargs = {
                    "linewidth": linewidth,
                    "color": edge_style["color"],
                    "alpha": edge_style["alpha"],
                    "linestyle": "--" if edge_style["dashed"] else "-",
                }
                start, end = node_pos[edge[0]], node_pos[edge[1]]
                # One [start, end] coordinate pair per plotted dimension.
                segments = [[start[axis], end[axis]] for axis in range(dims)]
                ax.plot(*segments, **line_kwargs)

        ax.set_xlabel("X")
        ax.set_ylabel("Y")
        if(dims==3):
            ax.set_zlabel("Z", labelpad=-2)  # Adjust this value as needed, default is 5
        return ax

    @staticmethod
    def make_graph_plot(node_list, edge_list, edge_labels, plot_styles, dims = 2,
                        title = "", figax = None, figsize = (6,4), plot_edges = True):
        """
        Plots a graph on ``figax``, or on a newly created figure when none is given.

        Args:
            node_list, edge_list, edge_labels, plot_styles: Forwarded to
                ``plot_graph_nx_matplotlib``.
            dims (int): 2 or 3.
            title (str): Title of the axes.
            figax: Existing axes to draw on; when None a new figure of size
                ``figsize`` with a single (2D or 3D) axes is created.
            figsize (tuple): Size of the figure created when ``figax`` is None.
            plot_edges (bool): Whether edges are drawn.

        Returns:
            The axes that was drawn on.
        """
        ax = figax
        if ax is None:
            # The drawing helper needs an Axes (it calls scatter/plot), so
            # create one instead of handing it the bare Figure.
            fig = plt.figure(figsize=figsize)
            ax = fig.add_subplot(111, projection="3d") if dims == 3 else fig.add_subplot(111)
        ax.set_title(title)

        return GraphInfo.plot_graph_nx_matplotlib(node_list, edge_list, edge_labels,
                                                  ax, dims = dims, plot_styles = plot_styles,
                                                  plot_edges = plot_edges)

    @staticmethod
    def visualize_nodes_df(df, plot_styles = {
            "nuclei":{"marker":"o","color":"red", "alpha":0.3},
            "golgi":{"marker":"o","color":"green", "alpha":0.3},
            "tp": {"color": "black",  "dashed": False, "alpha":1},
            "fp": {"color": "yellow", "dashed": False, "alpha":1},
            "tn": None,
            "fn": {"color": "blue",  "dashed": True, "alpha":1}
        }, figsize = (6,4)):
        """
        Scatter-plots the nodes of ``df`` on a new 3D figure (no edges).

        Args:
            df: Nodes DataFrame with "X", "Y", "Z" and an integer "node_type".
            plot_styles (dict): Styles; only the node type entries are used.
            figsize (tuple): Figure size.

        Returns:
            The 3D axes holding the scatter plot.
        """
        # Plot
        fig = plt.figure(figsize=figsize,dpi=250)
        ax = fig.add_subplot(111, projection="3d")
        ax.grid(False)
        ax.xaxis.pane.fill = ax.yaxis.pane.fill = ax.zaxis.pane.fill = False
        ax.set_xlabel("X")
        ax.set_ylabel("Y")
        ax.set_zlabel("Z")

        for idx, row in df.iterrows():
            node_style = plot_styles[reversed_nodes_type_int_encodings[row["node_type"]]]
            ax.scatter3D(row['X'], row['Y'], row['Z'], s=30, c=node_style["color"],
                         marker = node_style["marker"], alpha = node_style["alpha"])

        return ax

# Performance Metrics

In [None]:
"""
Group results across different trials
"""

def plot_results_dl(results_list_pytorch):
    """
    Builds the results table for a list of deep-learning runs and sorts it.

    Args:
    results_list_pytorch: List of per-run result dictionaries, passed to
        ``plot_table`` unchanged.

    Returns:
    The DataFrame produced by ``plot_table``, sorted by algorithm,
    normalisation, K Inter, train/test data and constraints.
    """
    # Locations of the metric dictionaries requested from each run's results.
    metric_entry_paths = [
        ["@best", "metrics"],
        ["@best", "@constraints", "metrics"],
        ["@constraints", "metrics"],
        ["@constraints_opt", "metrics"],
    ]
    sort_keys = ["Algorithm", "Normalize", "K Inter", 'Data Train', 'Data Test', 'Constraints']

    results_table = plot_table(results_list_pytorch, metrics_dict_entries = metric_entry_paths)
    return results_table.sort_values(by=sort_keys)

def convert_to_final_format(output_df):
    """
    Reduces the results table to the columns reported in the final summary.

    Drops the configuration and raw-count columns, renames "Edge Feat." to
    "Angles" and turns it into a boolean: True when any edge feature name of
    the run contains "angle". The input DataFrame is not modified.

    Args:
    output_df: Results DataFrame whose "Edge Feat." cells are iterables of
        feature-name strings.

    Returns:
    The reduced DataFrame.
    """
    dropped_columns = [
        "Node Feat.", "Scale", "Normalize", "K Intra", "K Inter",
        "TP Percent", "TP Total Count", "TP", "FP", "TN", "FN",
    ]

    def _mentions_angle(edge_feats):
        return any("angle" in item for item in edge_feats)

    final_df = (
        output_df
        .drop(dropped_columns, axis = 1)
        .rename(columns={"Edge Feat.": "Angles"})
    )
    final_df["Angles"] = final_df["Angles"].apply(_mentions_angle)
    return final_df

# Load Data
def load_results_data(results_folder, trial_names=("trial1", "trial2", "trial3")):
    """
    Loads the ``params.json`` of every run in every trial into one results table.

    The run sub-folders are discovered from the first trial folder (entries
    whose name contains "constraints" are skipped) and are expected to exist,
    each with a ``params.json``, in every trial.

    Args:
    results_folder: Folder that contains one sub-folder per trial.
    trial_names: Names of the trial sub-folders to read. Defaults to the three
        trials used throughout this notebook.

    Returns:
    A DataFrame with one row per table entry built by ``plot_results_dl``,
    reduced by ``convert_to_final_format``, with the metric columns cast to
    float and a fresh 0..n-1 index.

    Raises:
    FileNotFoundError: If a trial folder or a ``params.json`` is missing.
    """
    trial_folders = [os.path.join(results_folder, trial) for trial in trial_names]
    subfolders = [el for el in os.listdir(trial_folders[0]) if "constraints" not in el]

    # One params dictionary per (run sub-folder, trial), grouped by sub-folder.
    df_rows = []
    for subfolder in subfolders:
        for trial_folder in trial_folders:
            params_path = os.path.join(trial_folder, subfolder, "params.json")
            with open(params_path, 'r') as file:
                df_rows.append(json.load(file))

    output_df = plot_results_dl(df_rows)

    output_df = convert_to_final_format(output_df)
    numeric_cols = ["ROC AUC Score", "Accuracy", "Precision", "TPR", "FPR", "F1-Score"]
    output_df[numeric_cols] = output_df[numeric_cols].astype(float)
    output_df = output_df.reset_index(drop=True)
    return output_df

#With Edge Feats and With Node Feats
#results_folder="../results/deep_learning_with_edge_feats_with_node_feats/results_real_annotated_normalized"
results_folder = "../results/deep_learning_with_edge_feats_with_node_feats/results_synthetic_not_normalized"

output_df = load_results_data(results_folder)

# Reduce the dataset paths to their last component (the dataset name).
for path_col in ("Data Train", "Data Test"):
    output_df[path_col] = output_df[path_col].apply(os.path.basename)

# Only the training dataset is reported.
output_df = output_df.drop(columns=["Data Test"])
output_df

In [None]:
# Aggregate every metric over the trials of each configuration. The result has
# two-level columns: (metric, statistic).
grouped = (
    output_df
    .groupby(["Data Train", "Algorithm", "Angles", "Constraints"])
    .agg({
        metric: ['mean', 'std', 'min', 'max']
        for metric in ['ROC AUC Score', 'Accuracy', 'TPR', 'FPR', 'Precision', 'F1-Score']
    })
    # Turn all four grouping keys back into ordinary columns.
    .reset_index()
)

def format_column(_mean, _min, _max, _std):
    """
    Formats a metric as "<mean>±<spread>", both rounded to 3 decimals.

    The spread is the absolute distance between the mean and whichever of
    ``_min`` / ``_max`` has the larger absolute value (``_max`` on a tie).
    ``_std`` is accepted but not used.
    """
    # max() keeps the first of equal keys, so listing _max first keeps it on ties.
    reference = max((_max, _min), key=abs)
    spread = abs(_mean - reference)
    return "±".join([str(round(_mean, 3)), str(round(spread, 3))])

numeric_cols = ["ROC AUC Score", "Accuracy", "TPR", "FPR", "Precision", "F1-Score"]

# Collapse the four statistic columns of every metric into a single
# "<mean>±<spread>" text column stored under (metric, "").
for metric in numeric_cols:
    mean_key, min_key, max_key, std_key = [(metric, stat) for stat in ('mean', 'min', 'max', 'std')]
    grouped[(metric, "")] = grouped.apply(
        lambda row: format_column(row[mean_key], row[min_key], row[max_key], row[std_key]),
        axis=1,
    )
    grouped = grouped.drop(columns=[mean_key, min_key, max_key, std_key])

# Every remaining column has an empty second level, so joining the two levels
# leaves plain column names.
grouped.columns = grouped.columns.map(''.join)
grouped = grouped[["Data Train", "Algorithm", "Constraints", "Angles",
                   "ROC AUC Score", "Accuracy", "TPR", "FPR", "Precision", "F1-Score"]]
grouped

In [None]:
# Transform Algorithm column
# NOTE: this cell drops "Angles", so it can only be run once per `grouped` table.
grouped = (
    grouped
    .sort_values(by=['Data Train', 'Angles', 'Algorithm'], ascending=[True, True, False])
    .reset_index(drop=True)
)

# Replace the internal class names by the names used in the report.
grouped['Algorithm'] = grouped['Algorithm'].replace({
    'GNN_Classifier_NonRecurrent': 'MPNN Non-Recurrent',
    'GNN_Classifier_Recurrent': 'MPNN Recurrent'
})


def _algorithm_with_angle_suffix(row):
    """Appends whether the run used angular features to the algorithm name."""
    suffix = 'w/ Angular Features' if row['Angles'] else 'w/o Angular Features'
    return f"{row['Algorithm']} {suffix}"


# Add Angular Features information to Algorithm column
grouped['Algorithm'] = grouped.apply(_algorithm_with_angle_suffix, axis=1)

# Drop unnecessary columns
#grouped = grouped.drop(columns=['Data Train', 'Angles'])
grouped = grouped.drop(columns=['Angles'])
grouped

In [None]:
# Print the summary table as LaTeX source; print() is used on purpose so the
# raw text can be copied out of the cell output.
print(grouped.to_latex())

# Plots

In [None]:
#Plot Input Dataset
#input_folder = "../data/vectors"#real manually annotated
#input_folder = "../data/vectors_automatic_csv"#real automatic segmentation
input_folder = "../data/synthetic_algo_200_points/"#synthetic

# os.listdir returns entries in arbitrary order; sort so every run places the
# same file in the same panel.
input_files = sorted(os.listdir(input_folder))

subplot_params = {"projection":"3d"}
fig_all = plt.figure(figsize=(18, 9), dpi= 300)#12,9
plt.subplots_adjust(wspace=0.4, hspace=0.15)#wspace-> horizontal hspace-> vertical
columns = 4
# Keep the usual 2x4 layout, but add rows when there are more than 8 files.
rows = max(2, math.ceil(len(input_files) / columns))
plot_styles = {
                "nuclei":{"marker":"o","color":"red", "alpha":0.3},
                "golgi":{"marker":"o","color":"green", "alpha":0.3},
                "": {"color": "black",  "dashed": False, "alpha":1},
                #"tp": {"color": "black",  "dashed": False, "alpha":1},
                #"fp": {"color": "yellow", "dashed": False, "alpha":1},
                #"tn": None,
                #"fn": {"color": "blue",  "dashed": True, "alpha":1}
}

for (i,file) in enumerate(input_files):
    file_path_input = os.path.join(input_folder, file)
    # Panel title: file name without its extension.
    title = os.path.splitext(os.path.basename(file_path_input))[0]

    #True Graph
    figax = plt.subplot(rows, columns, i+1, **subplot_params)
    df, nodes_df, edges_df = load_df_from_csv(file_path_input)
    # All ground-truth edges share the "" style of plot_styles.
    edge_labels = [""]*len(edges_df)
    df_figure = df_make_plot(figax, nodes_df, edges_df, edge_labels, title, add_legends=False)


# One legend for the whole figure instead of one per panel.
add_legend_plot(fig_all, plot_styles)
plt.show()

In [None]:
#Plot predicted with labels
def annote_figure(figax, annotations, fontsize = 7):
    """
    Writes the given text lines to the right of an axes, one below the other.

    Args:
    figax: Axes (anything exposing ``annotate``) to attach the text to.
    annotations: Text lines, written top to bottom.
    fontsize: Font size of the text.
    """
    # Anchor in axes-fraction coordinates: x > 1 lies right of the axes.
    start_point_x = 1.1
    start_point_y = 0.9

    for line_index in range(len(annotations)):
        figax.annotate(annotations[line_index],
                       xy=(start_point_x, start_point_y - line_index * 0.1),
                       xycoords="axes fraction", fontsize=fontsize, color="black")
    return

def get_edge_label(edges_df_true, edges_df_pred):
    """
    Labels every predicted edge as "tp" or "fp" and appends the missed true
    edges as "fn".

    Edges are compared as undirected pairs, so (a, b) matches (b, a). Each
    edge is compared as a whole pair, which keeps the labels correct when a
    node takes part in several predicted edges.

    Args:
    edges_df_true: Ground-truth edges with "source" and "target" columns.
    edges_df_pred: Predicted edges with "source" and "target" columns. An
        "edge_label" column is added to this DataFrame in place.

    Returns:
    A new DataFrame holding the labelled predicted edges followed by one
    "fn" row (source, target, edge_label) per true edge that was not
    predicted, with a fresh 0..n-1 index.
    """
    def _undirected_edges(edges_df):
        return [frozenset(edge) for edge in zip(edges_df["source"], edges_df["target"])]

    pred_edges = _undirected_edges(edges_df_pred)
    true_edge_set = set(_undirected_edges(edges_df_true))
    pred_edge_set = set(pred_edges)

    #Get TP and FP
    edges_df_pred["edge_label"] = [
        "tp" if edge in true_edge_set else "fp" for edge in pred_edges
    ]

    #Get FN
    fn_rows = [
        {"source": source, "target": target, "edge_label": "fn"}
        for source, target in zip(edges_df_true["source"], edges_df_true["target"])
        if frozenset((source, target)) not in pred_edge_set
    ]
    if not fn_rows:
        # Nothing to append; still return a fresh, re-indexed DataFrame.
        return edges_df_pred.reset_index(drop=True)

    fn_edges_df = pd.DataFrame(fn_rows)
    return pd.concat([edges_df_pred, fn_edges_df], ignore_index=True, sort=False)

input_folder = "../data/vectors"

#Manually annotated data - Classical Bipartite Matching algorithms
#results_folder = "../results/HopcroftKarp_RealDataManuallyAnnotated_k7"
#results_folder = "../results/JonkerVolgenant_RealDataManuallyAnnotated_k7"

#Manually annotated data - MPNN Recurrent With Angular Features
results_folder = "../results/deep_learning_with_edge_feats_with_node_feats/results_real_annotated_normalized/trial1/Results_0"#without constraints
#results_folder = "../results/deep_learning_with_edge_feats_with_node_feats/results_real_annotated_normalized/trial1/Results_0_constraints"#MPNN Recurrent Greedy W/o Threshold
#results_folder = "../results/deep_learning_with_edge_feats_with_node_feats/results_real_annotated_normalized/trial1/Results_0_constraints_threshold"#MPNN Recurrent Greedy W/ Threshold

# os.listdir returns entries in arbitrary order; sort so every run places the
# same file in the same panel.
input_files = sorted(os.listdir(input_folder))

subplot_params = {"projection":"3d"}
fig_all = plt.figure(figsize=(18, 9), dpi= 300)
plt.subplots_adjust(wspace=0.4, hspace=0.15)#wspace-> horizontal hspace-> vertical
columns = 4
# Keep the usual 2x4 layout, but add rows when there are more than 8 files.
rows = max(2, math.ceil(len(input_files) / columns))
plot_styles = {
                "nuclei":{"marker":"o","color":"red", "alpha":0.3},
                "golgi":{"marker":"o","color":"green", "alpha":0.3},
                #"": {"color": "black",  "dashed": False, "alpha":1},
                "tp": {"color": "black",  "dashed": False, "alpha":1},
                "fp": {"color": "blue", "dashed": True, "alpha":1},
                "tn": None,
                "fn": {"color": "yellow",  "dashed": True, "alpha":1}
}

# Edge counts accumulated over all files (shown when no params.json exists).
tp_total=0
fp_total=0
fn_total=0

for (index_fig,file) in enumerate(input_files):
    file_path_input = os.path.join(input_folder, file)
    # The prediction file has the same name as the ground-truth file.
    file_path_results = os.path.join(results_folder, file)
    title = os.path.splitext(os.path.basename(file_path_input))[0]

    df_true, nodes_df_true, edges_df_true = load_df_from_csv(file_path_input)
    df_pred, nodes_df_pred, edges_df_pred = load_df_from_csv(file_path_results)
    nodes_df_pred[["X", "Y", "Z"]] += 1#correct the transformation that is applied when reading file in load_df_from_csv

    #map to same reference of IDs for nodes
    # True and predicted nodes are matched by identical (X, Y, Z) coordinates:
    # every coordinate triple gets one global ID.
    nodes_df = pd.concat([nodes_df_true, nodes_df_pred], ignore_index=True, sort=False).reset_index(drop=True)
    nodes_df["ID"] = nodes_df.index
    nodes_dict = dict(zip(zip(nodes_df['X'], nodes_df['Y'], nodes_df['Z']), nodes_df['ID']))

    nodes_df_true["NEW_ID"] = nodes_df_true.apply(lambda x: nodes_dict[tuple([x["X"], x["Y"], x["Z"]])], axis=1)#get new global ID
    nodes_df_true_dict = dict(zip(nodes_df_true.ID, nodes_df_true.NEW_ID))
    edges_df_true["source"] = edges_df_true["source"].apply(lambda x : nodes_df_true_dict[x])
    edges_df_true["target"] = edges_df_true["target"].apply(lambda x : nodes_df_true_dict[x])

    nodes_df_pred["NEW_ID"] = nodes_df_pred.apply(lambda x: nodes_dict[tuple([x["X"], x["Y"], x["Z"]])], axis=1)#get new global ID
    nodes_df_pred_dict = dict(zip(nodes_df_pred.ID, nodes_df_pred.NEW_ID))
    edges_df_pred["source"] = edges_df_pred["source"].apply(lambda x : nodes_df_pred_dict[x])
    edges_df_pred["target"] = edges_df_pred["target"].apply(lambda x : nodes_df_pred_dict[x])

    # Label predicted edges tp/fp and append the missed true edges as fn.
    edges_df_pred= get_edge_label(edges_df_true, edges_df_pred)

    #Plot Graph
    figax = plt.subplot(rows, columns, index_fig+1, **subplot_params)
    edge_labels = edges_df_pred["edge_label"].tolist()
    df_figure = df_make_plot(figax, nodes_df, edges_df_pred, edge_labels, title, add_legends=False, plot_styles=plot_styles)

    # Per-file counts; a label that never occurs is missing from value_counts.
    label_counts = dict(edges_df_pred["edge_label"].value_counts())
    tp_count = label_counts.get('tp', 0)
    fp_count = label_counts.get('fp', 0)
    fn_count = label_counts.get('fn', 0)
    tp_total+=tp_count
    fp_total+=fp_count
    fn_total+=fn_count
    sample_metrics = [
        f"TP: {tp_count}",
        f"FP: {fp_count}",
        f"FN: {fn_count}"
    ]

    annote_figure(figax, sample_metrics)

fig_all_legend = add_legend_plot(fig_all, plot_styles)


# Prefer the aggregated metrics stored with the results; fall back to the
# counts accumulated above when the folder has no params.json.
params_file_path = os.path.join(results_folder, "params.json")
if os.path.isfile(params_file_path):
    with open(params_file_path, "r") as f:
        params_file = json.load(f)

    # The metrics entry to read depends on which variant the folder name denotes.
    if "constraints" in results_folder:
        if "threshold" in results_folder:
            params_dict = params_file["aggregated_metrics"]["@best"]["@constraints"]["metrics"]
        else:
            params_dict = params_file["aggregated_metrics"]["@constraints"]["metrics"]
    else:
        params_dict = params_file["aggregated_metrics"]["@best"]["metrics"]

    count_annotations = [
            "Overall Metrics:",
            f"Acc.: {round(params_dict['acc'],3)}",
            f"TPR: {round(params_dict['TPR'],3)}",
            f"FPR: {round(params_dict['FPR'],3)}",
            f"Prec.: {round(params_dict['precision'],3)}",
            f"TP: {params_dict['tp']}",
            f"FP: {params_dict['fp']}",
            f"TN: {params_dict['tn']}",
            f"FN: {params_dict['fn']}"
    ]
    count_annotations = "\n".join(count_annotations)
else:
    count_annotations = "\n".join([
            "Overall Metrics:",
            #f"Acc.: {params_dict['acc']}",
            #f"Prec.: {round((tp_total/(tp_total+fp_total)),3)}",
            #f"Recall: {round((tp_total/(tp_total+fp_total)),3)}",
            f"TP: {tp_total}",
            f"FP: {fp_total}",
            #f"TN: {tn_total}",
            f"FN: {fn_total}"
    ])

# Add the annotation to the figure
fig_all.text(0.98, 0.9, count_annotations, ha='right', va='top', fontsize="small")
plt.show()

### True vs Pred graph by graph

In [None]:
#Plot input dataset vs predicted
input_folder = "../data/vectors"

# One column per entry, drawn to the right of the ground-truth panel.
results_folder_list = [
    {"title":"(b) Jonker-Volgenant", "path": r"../results/JonkerVolgenant_RealDataCNNAutomatic_k10/"},
    {"title":"(c) MPNN Non-Recurrent w/o Constraints", "path":r"../results/deep_learning_with_edge_feats_with_node_feats/results_real_automatic_normalized/trial1/Results_3"},
    {"title":"(d) MPNN Non-Recurrent w/ Constraints \"Greedy w/ Threshold\"", "path": r"../results/deep_learning_with_edge_feats_with_node_feats/results_real_automatic_normalized/trial1/Results_3_constraints_threshold"},
]


figures = []
number_of_cols = len(results_folder_list)+1
# os.listdir returns entries in arbitrary order; sort so `figures` has a
# reproducible order.
for file in sorted(os.listdir(input_folder)):
    file_path_input = os.path.join(input_folder, file)
    graph_name = os.path.splitext(os.path.basename(file_path_input))[0]
    # Check if the graph_name ends with "_BC"
    if graph_name.endswith("_BC"):
        # Remove the "_BC" suffix
        graph_name = graph_name[:-3]

    # NOTE(review): the title says "without Angular Features" while the result
    # paths are under "with_edge_feats" - confirm the title matches the runs.
    title = "Predictions for "+ graph_name+" without Angular Features"

    subplot_params = {"projection":"3d"}
    fig_all = plt.figure(figsize=(26, 9), dpi= 300, layout='constrained')#18,9
    #plt.subplots_adjust(wspace=0.4, hspace=0.15)#wspace-> horizontal hspace-> vertical

    #True Graph
    figax = plt.subplot(1, number_of_cols, 1, **subplot_params)
    df, nodes_df, edges_df = load_df_from_csv(file_path_input)
    edge_labels = [""]*len(edges_df)
    df_figure = df_make_plot(figax, nodes_df, edges_df, edge_labels, "(a) Ground-truth")

    # Predictions: subplot positions start at 2, right after the ground truth.
    for i, results_folder in enumerate(results_folder_list, start=2):
        # The prediction file has the same name as the ground-truth file.
        file_path_results = os.path.join(results_folder["path"], file)
        figax = plt.subplot(1, number_of_cols, i, **subplot_params)
        df, nodes_df, edges_df = load_df_from_csv(file_path_results)
        edge_labels = [""]*len(edges_df)
        df_figure = df_make_plot(figax, nodes_df, edges_df, edge_labels, results_folder["title"])

    #Append figure to figures List
    fig_all.suptitle(title, y=0.95, fontsize=20)#y=0.85
    plt.show()
    figures.append(fig_all)

### Merge images

In [7]:
from PIL import Image

def merge_images_vertically(image_paths, output_path):
    """
    Stacks images on top of each other and saves the result.

    Every image is scaled (keeping its aspect ratio) to the width of the
    widest input, then pasted top to bottom in the order given.

    Args:
    image_paths: Paths of the images to merge, top to bottom.
    output_path: Path of the merged image; the file extension selects the
        output format.

    Raises:
    ValueError: If ``image_paths`` is empty.
    """
    from contextlib import ExitStack

    if not image_paths:
        raise ValueError("image_paths must contain at least one image path")

    # The ExitStack closes every opened image file when the block is left,
    # including when an error occurs part-way through.
    with ExitStack() as stack:
        images = [stack.enter_context(Image.open(image_path)) for image_path in image_paths]

        # Find the maximum width
        max_width = max(image.size[0] for image in images)

        # Resize all images to the maximum width
        resized_images = [
            image.resize((max_width, int(image.size[1] * (max_width / image.size[0]))), Image.LANCZOS)
            for image in images
        ]

        total_height = sum(image.size[1] for image in resized_images)

        new_image = Image.new('RGB', (max_width, total_height))

        y_offset = 0
        for image in resized_images:
            new_image.paste(image, (0, y_offset))
            y_offset += image.size[1]

        new_image.save(output_path)

# Example usage:
# Crops to merge, listed top to bottom; commented entries are left out of the
# merged image.
image_paths = [
    #r"../figures/real_automatic_normalized/crop1.png",
      #r"../figures/real_automatic_normalized/crop2.png",
      #r"../figures/real_automatic_normalized/crop3.png",
      r"../figures/real_automatic_normalized/crop4.png",
      r"../figures/real_automatic_normalized/crop5.png",
      r"../figures/real_automatic_normalized/crop6.png",
      r"../figures/real_automatic_normalized/crop7.png",
      #r"../figures/real_automatic_normalized/crop8.png",
]  # Replace with your image paths
# NOTE(review): the file name mentions crops 4-5-7 but crop6 is merged as
# well - confirm which of the two is intended.
output_path = "merged-crops-4-5-7.jpg"  # Specify the output path

# Writes the merged image to output_path (relative to the working directory).
merge_images_vertically(image_paths, output_path)