In [1]:
import networkx as nx
import numpy as np
import os
import pickle as pkl
import pandas as pd

In [2]:
def _remove_cycles(candidate_avg_dag: nx.DiGraph, probability_matrix: np.array) -> nx.DiGraph:
    """Break every directed cycle by deleting its weakest edge.

    Repeatedly finds a cycle in ``candidate_avg_dag`` and removes the edge on it
    with the smallest ``probability_matrix[u, v]`` score, until the graph is a DAG.
    Among equally weak edges, the first one in the cycle's order is removed.

    Args:
        candidate_avg_dag: directed graph whose node labels are valid indices into
            ``probability_matrix``. It is modified in place.
        probability_matrix: 2-D array of edge scores; the edge ``u -> v`` is scored
            by ``probability_matrix[u, v]``.

    Returns:
        The same graph object, now acyclic.
    """
    print("\tstarted removing cycles...")
    while True:
        if nx.is_directed_acyclic_graph(candidate_avg_dag):
            return candidate_avg_dag

        weakest_edge = None
        weakest_score = float("inf")
        # strict "<" keeps the earliest edge among equal scores
        for source, target in nx.find_cycle(candidate_avg_dag):
            score = probability_matrix[source, target]
            if score < weakest_score:
                weakest_score, weakest_edge = score, (source, target)

        print(f"\t\tremoving edge with score {weakest_score}")
        candidate_avg_dag.remove_edge(*weakest_edge)

def _connect_subgraphs(candidate_avg_dag: nx.DiGraph, probability_matrix: np.array) -> nx.DiGraph:
    """Join the multi-node weakly connected components of a graph.

    While the graph has two or more weakly connected components with more than
    one node, the two largest are joined by the single highest-scoring edge
    between them, in either direction. A single edge between two previously
    disconnected components cannot close a cycle, so an acyclic input stays
    acyclic. Single-node (isolated) components are ignored and left isolated,
    so the result is not guaranteed to be weakly connected.

    NOTE: the added edge may never appear in any participant graph. If every
    candidate score is 0, the first pair examined is used, so the chosen edge is
    effectively arbitrary.

    Args:
        candidate_avg_dag: directed graph whose node labels are valid indices into
            ``probability_matrix``. It is modified in place.
        probability_matrix: 2-D array of edge scores; the edge ``u -> v`` is scored
            by ``probability_matrix[u, v]``.

    Returns:
        The same graph object, with the multi-node components joined.
    """
    print("\tstarted connecting subgraphs...")
    # connectivity is undefined for the null graph (networkx raises on it)
    if candidate_avg_dag.number_of_nodes() == 0:
        return candidate_avg_dag

    while not nx.is_weakly_connected(candidate_avg_dag):
        subgraphs = [
            component
            for component in nx.weakly_connected_components(candidate_avg_dag)
            if len(component) > 1
        ]
        print(f"\t\tsubgraphs found: {len(subgraphs)}")

        # nothing left to join: any remaining disconnection comes only from
        # isolated nodes (this also covers a graph with no edges at all)
        if len(subgraphs) <= 1:
            break

        # sort subgraphs by size, largest first
        subgraphs.sort(key=len, reverse=True)
        largest_subgraph, second_largest_subgraph = subgraphs[0], subgraphs[1]

        best_edge = None
        max_score = -float("inf")

        # find the highest-scoring edge between the two subgraphs in either
        # direction; strict ">" keeps the first candidate among ties
        for u in largest_subgraph:
            for v in second_largest_subgraph:
                for source, target in ((u, v), (v, u)):
                    score = probability_matrix[source, target]
                    if score > max_score:
                        max_score = score
                        best_edge = (source, target)

        # add the best edge to the graph
        print(f"\t\tadding edge with score {max_score}")
        candidate_avg_dag.add_edge(*best_edge)

    return candidate_avg_dag

In [3]:
# Participant IDs whose per-participant DAG pickles are averaged, per modality
# and causal-discovery method.
data = {
    "visual": {
        "ges": [16,19,23,25,26,28,30,34,37,39,42,45,46,56,64,65], # missing 21, 41
        "pruned_fci": [16,19,21,23,25,26,28,30,34,37,39,41,42,45,46,56,64,65],
        "pruned_pc": [16,19,21,23,25,26,28,30,34,37,39,41,42,45,46,56,64,65],
    },
    "audio": {
        "ges": [16,19,21,23,25,26,28,30,34,37,39,41,42,45,46,56,64,65],
        "pruned_fci": [16,19,21,23,25,26,28,30,34,37,39,41,42,45,46,56,64,65],
        "pruned_pc": [16,19,21,23,25,26,28,30,34,37,39,41,42,45,46,56,64,65],
    },
    "physio": {
        "ges": [16,19,23,25,26,28,30,34,37,39,41,42,45,46,56,64,65], # missing 21
        "pruned_fci": [16,19,21,23,25,26,28,30,34,37,39,41,42,45,46,56,64,65],
        "pruned_pc": [16,19,21,23,25,26,28,30,34,37,39,41,42,45,46,56,64,65],
    },
}

DAG_PATH = "../results_dag"


def _load_graph_matrix(method: str, dag_file_path: str):
    """Unpickle one participant's result and return its ``.graph`` matrix.

    The pickle layout differs per method: ``dag_data["G"]`` for "ges",
    ``dag_data[0]`` for "fci"/"pruned_fci" and ``dag_data.G`` for "pc"/"pruned_pc".
    The returned object is presumably a NumPy array (it is used with ``.shape``
    and ``.T`` by the caller) -- TODO confirm against the discovery library.

    NOTE(review): ``pickle.load`` can execute arbitrary code; only point this at
    trusted, locally generated result files.

    Raises:
        ValueError: if ``method`` is not one of the layouts above.
    """
    with open(dag_file_path, "rb") as file:
        dag_data = pkl.load(file)

    if method == "ges":
        return dag_data["G"].graph
    if method in ("fci", "pruned_fci"):
        return dag_data[0].graph
    if method in ("pc", "pruned_pc"):
        return dag_data.G.graph
    raise ValueError(f"Unknown method {method!r} for {dag_file_path}")


for modality, participants_by_method in data.items():
    # number of nodes in every graph = feature columns of the PCA subset
    # (max nodes = n principal components + 2 features); only the header is needed
    modality_columns = pd.read_csv(f"../data/subset_{modality}_pca.csv", nrows=0).columns
    num_nodes = sum(1 for col_name in modality_columns if col_name != "Participant")

    for method, dag_participant_ids in participants_by_method.items():
        output_path = os.path.join(DAG_PATH, modality, f"avg_{method}_dag.pkl")

        # edge_tally_matrix[i, j] counts the participant graphs containing i --> j
        edge_tally_matrix = np.zeros((num_nodes, num_nodes), dtype=int)

        print(f"Generating average graph for {modality} {method}...")

        for dag_id in dag_participant_ids:
            dag_file_path = os.path.join(
                DAG_PATH,
                modality,
                f"{method}_dag_participant_{dag_id}.pkl"
            )
            graph_matrix = _load_graph_matrix(method, dag_file_path)

            # ensure matrix contains all nodes
            assert graph_matrix.shape == (num_nodes, num_nodes), f"Graph matrix does not contain expected nodes: {dag_file_path}"

            # graph_matrix[i, j] == -1 and graph_matrix[j, i] == 1 encodes i --> j;
            # bidirected (i <-> j) and partially directed (i o-> j) edges are not counted
            directed_edges = (graph_matrix == -1) & (graph_matrix.T == 1)
            edge_tally_matrix += directed_edges.astype(int)

        # fraction of participant graphs containing each directed edge
        probability_matrix = edge_tally_matrix / len(dag_participant_ids)

        # a fixed threshold of .5 risks producing no graph, so threshold at the mean
        # probability of the edges that appear in at least one participant graph
        observed_probabilities = probability_matrix[probability_matrix != 0]
        if observed_probabilities.size == 0:
            # the mean of no values is NaN, and an edgeless candidate graph leaves
            # _connect_subgraphs with no multi-node component to join
            print(f"\tno directed edges found for {modality} {method}, skipping")
            continue
        threshold = observed_probabilities.mean()
        adjacency_matrix = (probability_matrix >= threshold).astype(int)
        candidate_avg_dag = nx.DiGraph(adjacency_matrix)

        candidate_avg_dag = _remove_cycles(candidate_avg_dag, probability_matrix)
        candidate_avg_dag = _connect_subgraphs(candidate_avg_dag, probability_matrix)

        # encode the averaged DAG with the same convention: u --> v as [u, v] = -1, [v, u] = 1
        avg_graph_matrix = np.zeros((num_nodes, num_nodes), dtype=int)
        for source, target in candidate_avg_dag.edges:
            avg_graph_matrix[source, target] = -1
            avg_graph_matrix[target, source] = 1

        with open(output_path, "wb") as f:
            pkl.dump(avg_graph_matrix, f)

Generating average graph for visual ges...
	started removing cycles...
		removing edge with score 0.3125
		removing edge with score 0.25
		removing edge with score 0.25
		removing edge with score 0.3125
		removing edge with score 0.3125
		removing edge with score 0.25
		removing edge with score 0.25
		removing edge with score 0.3125
		removing edge with score 0.25
		removing edge with score 0.375
		removing edge with score 0.25
		removing edge with score 0.25
		removing edge with score 0.25
		removing edge with score 0.25
		removing edge with score 0.25
		removing edge with score 0.25
		removing edge with score 0.25
		removing edge with score 0.25
		removing edge with score 0.3125
		removing edge with score 0.25
		removing edge with score 0.3125
		removing edge with score 0.3125
		removing edge with score 0.25
		removing edge with score 0.3125
		removing edge with score 0.25
		removing edge with score 0.25
		removing edge with score 0.25
		removing edge with score 0.3125
		removing edg