In [1]:
from itertools import combinations, product
import networkx as nx
import pandas as pd
import pickle as pkl

In [None]:
# Empty score table: one row per participant id, one column per
# (modality, method-pair) combination. Cells start out as NaN and are
# filled in by the comparison loops.
all_p_ids = [
    16, 19, 21, 23, 25, 26,
    28, 30, 34, 37, 39, 41,
    42, 45, 46, 56, 64, 65,
]

# Outer column level: the modality.
layer_1 = ["visual", "audio", "physio"]
# Inner column level: every unordered pair of methods, labelled "first:second".
layer_2 = [
    ":".join(method_duo)
    for method_duo in combinations(("ges", "pruned_fci", "pruned_pc"), 2)
]

multi_columns = pd.MultiIndex.from_product([layer_1, layer_2])

results_df = pd.DataFrame(columns=multi_columns, index=all_p_ids)

Unnamed: 0_level_0,visual,visual,visual,audio,audio,audio,physio,physio,physio
Unnamed: 0_level_1,ges:pruned_fci,ges:pruned_pc,pruned_fci:pruned_pc,ges:pruned_fci,ges:pruned_pc,pruned_fci:pruned_pc,ges:pruned_fci,ges:pruned_pc,pruned_fci:pruned_pc
16,,,,,,,,,
19,,,,,,,,,
21,,,,,,,,,
23,,,,,,,,,
25,,,,,,,,,
26,,,,,,,,,
28,,,,,,,,,
30,,,,,,,,,
34,,,,,,,,,
37,,,,,,,,,


In [8]:
def get_edges(dag_path, method:str):
    """Load a pickled graph result and return its fully directed edges.

    The layout of the pickled object depends on ``method``:

    * ``"ges"``: a mapping whose ``"G"`` entry exposes a ``.graph`` matrix.
    * ``"fci"`` / ``"pruned_fci"``: an indexable object whose element ``0``
      exposes a ``.graph`` matrix.
    * ``"pc"`` / ``"pruned_pc"``: an object whose ``.G`` attribute exposes a
      ``.graph`` matrix.
    * ``"avg"``: the pickled object is itself the matrix.

    The matrix must provide ``.shape`` and ``[i, j]`` indexing (presumably a
    square 2-D numpy array; confirm against the code that writes the pickles).

    Args:
        dag_path: Path of the pickle file to read.
        method: Name of the method that produced the file (see above).

    Returns:
        A list of ``(i, j)`` index tuples, one per fully directed edge
        ``i --> j``, ordered by ``i`` and then by ``j``.

    Raises:
        ValueError: If ``method`` is not one of the supported names.
    """
    supported_methods = ("ges", "fci", "pruned_fci", "pc", "pruned_pc", "avg")
    # Fail fast with a clear message instead of crashing later on a None matrix.
    if method not in supported_methods:
        raise ValueError(
            f"Unknown method {method!r}; expected one of {supported_methods}"
        )

    # NOTE: unpickling can execute arbitrary code; only load trusted files.
    with open(dag_path, "rb") as file:
        dag_data = pkl.load(file)

    if method == "ges":
        graph_matrix = dag_data["G"].graph
    elif method in ("fci", "pruned_fci"):
        graph_matrix = dag_data[0].graph
    elif method in ("pc", "pruned_pc"):
        graph_matrix = dag_data.G.graph
    else:  # "avg": the loaded object is used as the matrix directly
        graph_matrix = dag_data

    num_nodes = graph_matrix.shape[0]

    # An entry pair (matrix[i, j], matrix[j, i]) == (-1, 1) marks a fully
    # directed edge i --> j; every other combination is ignored.
    return [
        (i, j)
        for i, j in product(range(num_nodes), range(num_nodes))
        if graph_matrix[i, j] == -1 and graph_matrix[j, i] == 1
    ]


def jaccard_similarity(graph1_edges, graph2_edges):
    """Return the Jaccard index of two edge collections.

    Both arguments are converted to sets, so duplicates are ignored. The
    result is the number of shared edges divided by the number of distinct
    edges overall. It is 0 when both collections are empty.
    """
    first, second = set(graph1_edges), set(graph2_edges)
    combined = first.union(second)
    if not combined:
        # No edges on either side: defined as 0 rather than dividing by zero.
        return 0
    shared = first.intersection(second)
    return len(shared) / len(combined)

In [4]:
# Participant ids per modality and discovery method. The ids are used below to
# build the pickle paths, so presumably each list names the participants for
# which that method's result file exists (TODO confirm).
# NOTE: the Jaccard cell further down reads this same `data` dict.
data = {
    "visual": {
        "ges": [16,19,23,25,26,28,30,34,37,39,42,45,46,56,64,65], # missing 21, 41
        "pruned_fci": [16,19,21,23,25,26,28,30,34,37,39,41,42,45,46,56,64,65],
        "pruned_pc": [16,19,21,23,25,26,28,30,34,37,39,41,42,45,46,56,64,65],
    },
    "audio": {
        "ges": [16,19,21,23,25,26,28,30,34,37,39,41,42,45,46,56,64,65],
        "pruned_fci": [16,19,21,23,25,26,28,30,34,37,39,41,42,45,46,56,64,65],
        "pruned_pc": [16,19,21,23,25,26,28,30,34,37,39,41,42,45,46,56,64,65],
    },
    "physio": {
        "ges": [16,19,23,25,26,28,30,34,37,39,41,42,45,46,56,64,65], # missing 21
        "pruned_fci": [16,19,21,23,25,26,28,30,34,37,39,41,42,45,46,56,64,65],
        "pruned_pc": [16,19,21,23,25,26,28,30,34,37,39,41,42,45,46,56,64,65],
    },
}

# For every modality, compare each unordered pair of methods participant by
# participant and store a graph-edit-distance value in results_df.
for modality, methods in data.items():
    method_pairs = list(combinations(methods, 2))
    for pair in method_pairs:
        # only participants listed for BOTH methods can be compared
        participants = set(methods[pair[0]]) & set(methods[pair[1]])
        for participant_id in sorted(participants):
            print(f"Processing: {modality}, {participant_id}, {pair[0]}:{pair[1]}")
            dag_path_0 = f"../results_dag/{modality}/{pair[0]}_dag_participant_{participant_id}.pkl"
            dag_path_1 = f"../results_dag/{modality}/{pair[1]}_dag_participant_{participant_id}.pkl"

            edges_0 = get_edges(dag_path_0, pair[0])
            edges_1 = get_edges(dag_path_1, pair[1])

            # Graphs are built from the directed-edge lists only.
            # NOTE(review): nodes that have no fully directed edge never enter
            # the graph, and no node_match is passed, so node identity is
            # presumably not part of the comparison. Confirm this is intended.
            dag_0 = nx.DiGraph(edges_0)
            dag_1 = nx.DiGraph(edges_1)

            # optimize_graph_edit_distance returns a generator; only its first
            # yielded value is used below (per the networkx docs the generator
            # yields successive approximations; verify the first is good enough).
            ged = nx.optimize_graph_edit_distance(dag_0, dag_1)

            results_df.loc[participant_id, (modality, f"{pair[0]}:{pair[1]}")] = next(ged)

# persist the table; participants skipped above keep their empty cells
results_df.to_csv("ged_approximations.csv")

Processing: visual, 16, ges:pruned_fci
Processing: visual, 19, ges:pruned_fci
Processing: visual, 23, ges:pruned_fci
Processing: visual, 25, ges:pruned_fci
Processing: visual, 26, ges:pruned_fci
Processing: visual, 28, ges:pruned_fci
Processing: visual, 30, ges:pruned_fci
Processing: visual, 34, ges:pruned_fci
Processing: visual, 37, ges:pruned_fci
Processing: visual, 39, ges:pruned_fci
Processing: visual, 42, ges:pruned_fci
Processing: visual, 45, ges:pruned_fci
Processing: visual, 46, ges:pruned_fci
Processing: visual, 56, ges:pruned_fci
Processing: visual, 64, ges:pruned_fci
Processing: visual, 65, ges:pruned_fci
Processing: visual, 16, ges:pruned_pc
Processing: visual, 19, ges:pruned_pc
Processing: visual, 23, ges:pruned_pc
Processing: visual, 25, ges:pruned_pc
Processing: visual, 26, ges:pruned_pc
Processing: visual, 28, ges:pruned_pc
Processing: visual, 30, ges:pruned_pc
Processing: visual, 34, ges:pruned_pc
Processing: visual, 37, ges:pruned_pc
Processing: visual, 39, ges:pruned

In [9]:
# Same traversal as the graph-edit-distance cell, but scoring each method pair
# with the Jaccard index of their directed-edge sets.
# Depends on `data`, `results_df`, `get_edges` and `jaccard_similarity` from
# earlier cells. It writes to the same results_df cells as the GED loop, so
# the GED values held in memory are overwritten here.
for modality, methods in data.items():
    method_pairs = list(combinations(methods, 2))
    for pair in method_pairs:
        # only participants listed for BOTH methods can be compared
        participants = set(methods[pair[0]]) & set(methods[pair[1]])
        for participant_id in sorted(participants):
            print(f"Processing: {modality}, {participant_id}, {pair[0]}:{pair[1]}")
            dag_path_0 = f"../results_dag/{modality}/{pair[0]}_dag_participant_{participant_id}.pkl"
            dag_path_1 = f"../results_dag/{modality}/{pair[1]}_dag_participant_{participant_id}.pkl"

            edges_0 = get_edges(dag_path_0, pair[0])
            edges_1 = get_edges(dag_path_1, pair[1])

            # edge tuples are compared as (i, j) index pairs, so direction matters
            jaccard = jaccard_similarity(edges_0, edges_1)

            results_df.loc[participant_id, (modality, f"{pair[0]}:{pair[1]}")] = jaccard

Processing: visual, 16, ges:pruned_fci
Processing: visual, 19, ges:pruned_fci
Processing: visual, 23, ges:pruned_fci
Processing: visual, 25, ges:pruned_fci
Processing: visual, 26, ges:pruned_fci
Processing: visual, 28, ges:pruned_fci
Processing: visual, 30, ges:pruned_fci
Processing: visual, 34, ges:pruned_fci
Processing: visual, 37, ges:pruned_fci
Processing: visual, 39, ges:pruned_fci
Processing: visual, 42, ges:pruned_fci
Processing: visual, 45, ges:pruned_fci
Processing: visual, 46, ges:pruned_fci
Processing: visual, 56, ges:pruned_fci
Processing: visual, 64, ges:pruned_fci
Processing: visual, 65, ges:pruned_fci
Processing: visual, 16, ges:pruned_pc
Processing: visual, 19, ges:pruned_pc
Processing: visual, 23, ges:pruned_pc
Processing: visual, 25, ges:pruned_pc
Processing: visual, 26, ges:pruned_pc
Processing: visual, 28, ges:pruned_pc
Processing: visual, 30, ges:pruned_pc
Processing: visual, 34, ges:pruned_pc
Processing: visual, 37, ges:pruned_pc
Processing: visual, 39, ges:pruned

In [11]:
# Save the current contents of results_df as the Jaccard score table.
# NOTE(review): results_df is shared with the GED cell, so this file holds
# Jaccard scores only if the Jaccard loop was the last one to fill the table.
results_df.to_csv("inter_graph_jaccard.csv")