In [4]:
import os
import matplotlib.pyplot as plt
import seaborn as sns
import networkx as nx
import pandas as pd
from collections import defaultdict

In [2]:
def analyze_visual_pathways(connections_df, neuron_types_df, muted_types=None, max_path_length=10):
    """
    Analyze paths from photoreceptors to Kenyon cells

    Enumerates every simple path (no neuron visited twice) from each
    photoreceptor (cell_type 'R1-6', 'R7' or 'R8') to each neuron whose
    cell_type contains 'KC', after optionally removing ("muting") some
    cell types from the network.

    Parameters:
    -----------
    connections_df: DataFrame with columns ['pre_root_id', 'post_root_id', 'syn_count']
        Rows sharing the same (pre, post) pair are summed into a single edge.
    neuron_types_df: DataFrame with columns ['root_id', 'cell_type']
        root_id must have the same dtype as the ids in connections_df,
        otherwise no neuron will match.
    muted_types: list of cell types to remove from analysis
    max_path_length: maximum number of edges (synaptic hops) per path,
        passed to networkx as `cutoff`. Default 10.

    Returns:
    --------
    paths: DataFrame with one row per path and columns
        ['source', 'target', 'path_types', 'path_length'], where
        path_length counts neurons (edges + 1). It is empty but keeps these
        columns when no path exists.
    G: networkx DiGraph for visualization; edge attribute 'weight' holds
        the summed syn_count.
    """
    # Filter out muted cell types (neither endpoint of an edge may be muted)
    if muted_types:
        active_neurons = neuron_types_df[~neuron_types_df['cell_type'].isin(muted_types)]
        active_ids = active_neurons['root_id']
        connections = connections_df[
            connections_df['pre_root_id'].isin(active_ids) &
            connections_df['post_root_id'].isin(active_ids)
        ]
    else:
        # Nothing below mutates these frames, so no defensive copy is needed.
        connections = connections_df
        active_neurons = neuron_types_df

    # The same pre/post pair may appear on several rows; sum them instead of
    # letting the last row silently overwrite the edge weight.
    edges = (
        connections
        .groupby(['pre_root_id', 'post_root_id'], as_index=False)['syn_count']
        .sum()
    )
    G = nx.DiGraph()
    G.add_weighted_edges_from(edges.itertuples(index=False, name=None))

    # Single O(1) lookup table instead of scanning neuron_types_df per node.
    # Keep the first row per root_id (what `.iloc[0]` picked previously).
    type_of = (
        neuron_types_df
        .drop_duplicates('root_id')
        .set_index('root_id')['cell_type']
        .to_dict()
    )

    # Identify source (R1-8) and target (KC) nodes
    is_photoreceptor = active_neurons['cell_type'].isin(['R1-6', 'R7', 'R8'])
    # na=False: a missing cell_type must not turn the mask into NaN
    is_kc = active_neurons['cell_type'].str.contains('KC', na=False)
    # Neurons without any (active) synapse are not nodes of G, and
    # all_simple_paths raises NodeNotFound for them, so drop them here.
    sources = [n for n in active_neurons.loc[is_photoreceptor, 'root_id'].unique() if n in G]
    targets = [n for n in active_neurons.loc[is_kc, 'root_id'].unique() if n in G]

    # Analyze paths; all_simple_paths simply yields nothing when no path exists
    records = []
    for source in sources:
        for target in targets:
            for path in nx.all_simple_paths(G, source, target, cutoff=max_path_length):
                # Ids present in connections but absent from the type table
                # are labelled 'Unknown' instead of raising IndexError.
                cell_types = [type_of.get(node, 'Unknown') for node in path]
                records.append({
                    'source': cell_types[0],
                    'target': cell_types[-1],
                    'path_types': '->'.join(cell_types),
                    'path_length': len(path),
                })

    # Explicit columns so downstream code works even when no path was found
    paths = pd.DataFrame(records, columns=['source', 'target', 'path_types', 'path_length'])
    return paths, G

In [3]:
def visualize_pathways(paths_df, top_n=10):
    """
    Create visualizations of pathway analysis

    Left panel: histogram of path lengths. Right panel: the `top_n` most
    frequent cell-type sequences.

    Parameters
    ----------
    paths_df : DataFrame with at least 'path_length' and 'path_types'
        columns (as returned by analyze_visual_pathways).
    top_n : number of most common pathways to display (default 10).

    Returns
    -------
    matplotlib.figure.Figure
    """
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))

    # Path length distribution; lengths are integers, so use one bin per value
    sns.histplot(data=paths_df, x='path_length', discrete=True, ax=ax1)
    ax1.set(title='Distribution of Path Lengths',
            xlabel='Neurons per path', ylabel='Number of paths')

    # Most common paths
    path_counts = paths_df['path_types'].value_counts().head(top_n)
    if path_counts.empty:
        # Avoid an empty/failing barplot when no path was found
        ax2.text(0.5, 0.5, 'No paths found', ha='center', va='center',
                 transform=ax2.transAxes)
    else:
        sns.barplot(x=path_counts.values, y=path_counts.index, ax=ax2)
    ax2.set(title='Most Common Pathways', xlabel='Number of paths', ylabel='')

    # Long "A->B->C" labels are clipped without this
    fig.tight_layout()
    return fig

def plot_cell_type_graph(G, neuron_types_df, layout='spring', max_legend_entries=20, **layout_kwargs):
    """
    Plot graph with nodes colored by cell type

    Parameters
    ----------
    G : networkx graph whose nodes are root ids.
    neuron_types_df : DataFrame with columns ['root_id', 'cell_type'].
    layout : name of a networkx layout function without the '_layout'
        suffix (e.g. 'spring', 'kamada_kawai', 'circular').
    max_legend_entries : a legend is drawn only when the number of distinct
        cell types in the graph is at most this value (default 20).
    **layout_kwargs : forwarded to the layout function, e.g. seed=0 for a
        reproducible spring layout.

    Returns
    -------
    matplotlib.figure.Figure
    """
    type_of = dict(zip(neuron_types_df['root_id'], neuron_types_df['cell_type']))
    node_types = [type_of.get(node) for node in G.nodes()]

    # nx.draw needs real colors, not cell-type names: give each cell type
    # present in the graph its own palette color. Nodes without a known
    # type are drawn grey.
    unknown_color = (0.5, 0.5, 0.5)
    present_types = sorted({t for t in node_types if pd.notna(t)}, key=str)
    colors = sns.color_palette('husl', len(present_types)) if present_types else []
    palette = dict(zip(present_types, colors))
    node_colors = [palette.get(t, unknown_color) for t in node_types]

    pos = getattr(nx, f'{layout}_layout')(G, **layout_kwargs)

    fig, ax = plt.subplots(figsize=(12, 12))
    nx.draw(G, pos,
            ax=ax,
            node_color=node_colors,
            node_size=20,
            with_labels=False,
            edge_color='grey',
            alpha=0.6)

    if 0 < len(present_types) <= max_legend_entries:
        # Empty scatters act as legend proxies (no extra imports needed)
        for cell_type in present_types:
            ax.scatter([], [], color=palette[cell_type], label=str(cell_type))
        ax.legend(title='Cell type', loc='best', fontsize='small')

    return fig

In [5]:
# Both input tables live in one directory; os.path.join keeps paths portable.
DATA_DIR = "new_data"
CONNECTIONS_PATH = os.path.join(DATA_DIR, "connections.csv")
CLASSIFICATION_PATH = os.path.join(DATA_DIR, "classification.csv")

# Fail early with an actionable message instead of a bare traceback.
for required_path in (CONNECTIONS_PATH, CLASSIFICATION_PATH):
    if not os.path.isfile(required_path):
        raise FileNotFoundError(
            f"{required_path!r} not found (cwd: {os.getcwd()!r}); "
            f"place the CSV there or update DATA_DIR."
        )

# Read ids as strings in BOTH tables: classification's root_id is a string,
# and isin() / graph-node lookups silently match nothing across mixed dtypes.
connections = pd.read_csv(
    CONNECTIONS_PATH,
    dtype={"pre_root_id": "string", "post_root_id": "string"},
)
neuron_types = (
    pd.read_csv(
        CLASSIFICATION_PATH,
        usecols=["root_id", "cell_type", "side"],
        dtype={"root_id": "string"},
    )
    # A row without an id cannot be matched to any connection.
    .dropna(subset=["root_id"])
    # Only fill the descriptive columns; never invent a root_id.
    .fillna({"cell_type": "Unknown", "side": "Unknown"})
)

FileNotFoundError: [Errno 2] No such file or directory: 'new_data\\classification.csv'