In [5]:
import pandas as pd
import networkx as nx
import plotly.graph_objects as go
import matplotlib.colors as mcolors
import matplotlib.pyplot as plt


def visualize_graph_from_files_with_metrics_and_coordinates(coordinates_file, metrics_file, metrics_to_plot,
                                                            statistics_to_plot, *, missing_value=0):
    """
    Show interactive 3D node plots placed at MNI coordinates and coloured by metric statistics.

    One Plotly figure is shown for every (metric, statistic) combination. Each node is drawn at
    its ('x.mni', 'y.mni', 'z.mni') position and coloured (plasma colorscale) by the value of the
    ``statistic`` column in the metrics-file row whose 'Metric' equals ``metric``. If a node has
    several such rows, the first one is used.

    :param coordinates_file: Path to the CSV file with the node coordinates. Must contain the
        columns 'Node', 'x.mni', 'y.mni' and 'z.mni'.
    :param metrics_file: Path to the CSV file with the node metrics, one row per (Node, Metric)
        pair. Must contain 'Node', 'Metric' and every column listed in ``statistics_to_plot``.
    :param metrics_to_plot: List of metrics to visualize (e.g., ['closeness', 'degree']); compared
        exactly (case-sensitive) against the 'Metric' column.
    :param statistics_to_plot: List of statistic columns to visualize (e.g., ['Mean', 'Median']).
    :param missing_value: Colour value given to a node that has no row for the plotted metric
        (default 0).
    :raises KeyError: If a required column is missing from either CSV file.
    :raises ValueError: If the two files have no 'Node' value in common.
    """
    import warnings

    coordinate_columns = ['x.mni', 'y.mni', 'z.mni']

    # Load data from both files
    coordinates_data = pd.read_csv(coordinates_file)
    metrics_data = pd.read_csv(metrics_file)

    # Remove extra spaces in column names
    coordinates_data.columns = coordinates_data.columns.str.strip()
    metrics_data.columns = metrics_data.columns.str.strip()

    # Validate every column used below up front, so a malformed file fails here with a clear
    # message rather than with a KeyError deep inside the plotting loop.
    if 'Node' not in coordinates_data.columns or 'Node' not in metrics_data.columns:
        raise KeyError("'Node' column not found in one of the CSV files.")

    if 'Metric' not in metrics_data.columns:
        raise KeyError("'Metric' column not found in the metrics CSV file.")

    missing_coordinates = [col for col in coordinate_columns if col not in coordinates_data.columns]
    if missing_coordinates:
        raise KeyError(f"Coordinate column(s) {missing_coordinates} not found in the coordinates CSV file.")

    missing_statistics = [col for col in statistics_to_plot if col not in metrics_data.columns]
    if missing_statistics:
        raise KeyError(f"Statistic column(s) {missing_statistics} not found in the metrics CSV file.")

    # Keep only nodes present in both files (what an inner merge on 'Node' would keep), in the
    # coordinates file's order. The two frames are deliberately not merged: a column name shared
    # by both files would otherwise be renamed with _x/_y suffixes and break the lookups below.
    shared = coordinates_data['Node'].isin(metrics_data['Node'])
    node_coordinates = (coordinates_data.loc[shared, ['Node'] + coordinate_columns]
                        .drop_duplicates(subset='Node'))
    if node_coordinates.empty:
        raise ValueError("No 'Node' value is shared between the coordinates and metrics CSV files; "
                         "check that both files label nodes the same way.")

    # Build the graph with one node per brain region, carrying only its coordinates. Per-metric
    # values are looked up per figure instead of being stored on the node (a node has one row
    # per metric, so storing them would keep only the last one).
    G = nx.Graph()
    for node, x_mni, y_mni, z_mni in node_coordinates.itertuples(index=False, name=None):
        G.add_node(node, x=x_mni, y=y_mni, z=z_mni)

    nodes = list(G.nodes())
    pos = {node: (G.nodes[node]['x'], G.nodes[node]['y'], G.nodes[node]['z']) for node in nodes}
    x = [pos[node][0] for node in nodes]
    y = [pos[node][1] for node in nodes]
    z = [pos[node][2] for node in nodes]

    # Edge coordinates do not depend on metric/statistic, so compute them once. No edges are
    # added above; this trace only draws something if edges are added to G.
    edge_x, edge_y, edge_z = [], [], []
    for source, target in G.edges():
        x0, y0, z0 = pos[source]
        x1, y1, z1 = pos[target]
        edge_x.extend([x0, x1, None])
        edge_y.extend([y0, y1, None])
        edge_z.extend([z0, z1, None])

    def build_figure(metric, statistic, title):
        """Return the Plotly figure for one metric/statistic pair."""
        metric_rows = metrics_data[metrics_data['Metric'] == metric]
        if metric_rows.empty:
            # Usually a typo or case mismatch in metrics_to_plot; make it visible.
            warnings.warn(f"Metric {metric!r} not found in the 'Metric' column; "
                          f"all nodes are coloured with {missing_value!r}.", stacklevel=3)
        # First row per node wins; a dict gives an O(1) label lookup per node.
        first_rows = metric_rows.drop_duplicates(subset='Node')
        value_by_node = dict(zip(first_rows['Node'], first_rows[statistic]))
        values = [value_by_node.get(node, missing_value) for node in nodes]

        trace_nodes = go.Scatter3d(
            x=x, y=y, z=z,
            mode='markers+text',
            text=[str(node) for node in nodes],
            textposition="top center",
            marker=dict(
                size=10,
                color=values,
                colorscale='plasma',
                colorbar=dict(title=f'{metric} ({statistic})')
            ),
            name=f'Nodes ({metric} - {statistic})'
        )

        trace_edges = go.Scatter3d(
            x=edge_x, y=edge_y, z=edge_z,
            mode='lines',
            line=dict(color='black', width=2),
            name='Edges'
        )

        layout = go.Layout(
            title=title,
            scene=dict(
                xaxis=dict(title='X'),
                yaxis=dict(title='Y'),
                zaxis=dict(title='Z')
            ),
            showlegend=True
        )

        return go.Figure(data=[trace_edges, trace_nodes], layout=layout)

    # Visualize the graph for each combination of metric and statistic selected
    for metric in metrics_to_plot:
        for statistic in statistics_to_plot:
            build_figure(metric, statistic,
                         title=f"Graph Colored by {metric.capitalize()} - {statistic}").show()


In [6]:
# Input files: node MNI coordinates, and per-node metric statistics for the comparison of interest
coordinates_file = 'aal116.csv'
metrics_file = '../computing/analysis/ppmi/60_70/comparison/pd/node_differences.csv'

# Values of the 'Metric' column to plot, and which statistic column colours the nodes
metrics_to_plot = ['closeness', 'degree', 'clustering']
statistics_to_plot = ['Mean']

# Shows one interactive 3D figure per (metric, statistic) pair
visualize_graph_from_files_with_metrics_and_coordinates(
    coordinates_file,
    metrics_file,
    metrics_to_plot,
    statistics_to_plot,
)