In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
import networkx as nx
import pandas as pd
from utils.model_inspection_funcs import propagate_data_with_steps
from scripts.no_training import get_data

In [None]:
def propagate_neuron_data(neuron_data, connections, neurons, num_passes):
    """Propagate activations for ``num_passes`` steps and weight the final step.

    Starting from ``neuron_data[["root_id", "activation"]]``, calls
    ``propagate_data_with_steps`` with step indices ``0 .. num_passes - 1``.
    Each call receives a copy of the previous call's result. Every result is
    left-joined on ``root_id`` onto a running table, with missing values
    filled by 0.

    The initial activation (NaN replaced by 0) is kept as an ``input`` column.
    The last column of the running table is multiplied into
    ``neurons["decision_making"]``, stored as ``decision_making``, and then
    dropped. With ``num_passes == 0`` that last column is ``input`` itself.
    Rows whose ``root_id`` is not in ``neurons`` are discarded (inner join).

    NOTE(review): the per-step column names depend on what
    ``propagate_data_with_steps`` returns; confirm they are unique per step.
    """
    seed = neuron_data[["root_id", "activation"]]
    history = seed.fillna(0).rename(columns={"activation": "input"})

    current = seed
    for step in range(num_passes):
        # Pass a copy so the helper can't mutate the frame we keep iterating on.
        current = propagate_data_with_steps(current.copy(), connections, step)
        history = history.merge(current, on="root_id", how="left").fillna(0)

    # Identify the final step's column before decision_making is joined in.
    last_col = history.columns[-1]
    result = history.merge(neurons[["root_id", "decision_making"]], on="root_id")
    result["decision_making"] = result["decision_making"] * result[last_col]
    return result.drop(columns=[last_col])


def analyze_detour_pathways(
    neuron_data, connections, neurons, ablated_types, num_passes=4
):
    """
    Analyze how information flows when specific cell types are ablated.

    Runs ``propagate_neuron_data`` twice: once on the intact network and once
    with every neuron whose ``cell_type`` is in ``ablated_types`` removed,
    together with every connection that touches such a neuron.

    Parameters:
    -----------
    neuron_data: DataFrame with at least ``root_id``, ``cell_type`` and
        ``activation`` columns.
    connections: DataFrame of synaptic connections with ``pre_root_id`` and
        ``post_root_id`` columns (other columns are passed through as-is).
    neurons: DataFrame with ``root_id`` and ``decision_making`` columns,
        forwarded to ``propagate_neuron_data``.
    ablated_types: cell types to remove; a list-like, or a single string.
    num_passes: number of propagation steps (default 4).

    Returns:
    --------
    (baseline, ablated): the ``propagate_neuron_data`` results for the intact
    and the ablated network.
    """
    # Series.isin rejects a bare string, so accept a single cell type too.
    if isinstance(ablated_types, str):
        ablated_types = [ablated_types]

    # First get baseline propagation
    baseline = propagate_neuron_data(
        neuron_data, connections, neurons, num_passes
    )

    # Remove ablated cell types
    is_ablated = neuron_data["cell_type"].isin(ablated_types)
    ablated_neurons = neuron_data[~is_ablated].copy()
    ablated_ids = neuron_data.loc[is_ablated, "root_id"]

    # Drop only synapses that touch an ablated neuron. Keeping only synapses
    # whose endpoints are both in neuron_data would also drop connections to
    # neurons absent from neuron_data. The baseline run keeps those, so the
    # comparison would no longer isolate the ablation.
    ablated_connections = connections[
        ~connections["pre_root_id"].isin(ablated_ids)
        & ~connections["post_root_id"].isin(ablated_ids)
    ]

    # Get propagation with ablated cells
    ablated = propagate_neuron_data(
        ablated_neurons, ablated_connections, neurons, num_passes
    )

    return baseline, ablated

In [None]:
def visualize_pathways(paths_df, top_n=10):
    """
    Create a two-panel summary of pathway analysis results.

    Parameters:
    -----------
    paths_df: DataFrame with a numeric ``path_length`` column and a hashable
        ``path_types`` column.
    top_n: how many of the most frequent ``path_types`` values to show in the
        right-hand panel (default 10).

    Returns:
    --------
    The matplotlib Figure holding both panels.
    """
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))

    # Path length distribution
    sns.histplot(data=paths_df, x='path_length', ax=ax1)
    ax1.set_title('Distribution of Path Lengths')
    ax1.set_xlabel('Path length')
    ax1.set_ylabel('Number of paths')

    # Most common paths
    path_counts = paths_df['path_types'].value_counts().head(top_n)
    # Cast labels to str so non-string path types (e.g. tuples) still render
    # as categorical tick labels on the y axis.
    sns.barplot(x=path_counts.values, y=path_counts.index.astype(str), ax=ax2)
    ax2.set_title('Most Common Pathways')
    ax2.set_xlabel('Count')
    ax2.set_ylabel('Path types')

    fig.tight_layout()
    return fig

def plot_cell_type_graph(G, neuron_types_df, layout='spring'):
    """
    Plot a networkx graph with nodes colored by cell type.

    Parameters:
    -----------
    G: networkx graph whose node keys are matched against ``root_id``.
    neuron_types_df: DataFrame with ``root_id`` and ``cell_type`` columns.
    layout: name of a networkx layout, resolved as ``nx.<layout>_layout``
        (e.g. 'spring', 'circular', 'kamada_kawai').

    Returns:
    --------
    The matplotlib Figure containing the drawing.

    Raises:
    -------
    ValueError: if networkx has no ``<layout>_layout`` function.
    """
    from matplotlib.colors import to_rgba

    cell_types = pd.Series(
        neuron_types_df['cell_type'].values,
        index=neuron_types_df['root_id']
    ).to_dict()

    layout_func = getattr(nx, f'{layout}_layout', None)
    if layout_func is None:
        raise ValueError(f"Unknown networkx layout: {layout!r}")
    pos = layout_func(G)

    # Cell-type labels such as "L1" are not matplotlib colors, so passing them
    # straight to node_color raises. Map each type to a palette color instead.
    # Nodes without a (non-missing) cell type are drawn grey.
    node_types = [cell_types.get(node) for node in G.nodes()]
    present_types = sorted({t for t in node_types if pd.notna(t)}, key=str)
    cmap = plt.get_cmap('tab20')
    palette = {t: cmap(i % cmap.N) for i, t in enumerate(present_types)}
    grey = to_rgba('grey')
    node_colors = [palette.get(t, grey) for t in node_types]

    fig, ax = plt.subplots(figsize=(12, 12))
    nx.draw(G, pos,
            ax=ax,
            node_color=node_colors,
            node_size=20,
            with_labels=False,
            edge_color='grey',
            alpha=0.6)

    # A legend is only meaningful while every type has a distinct color.
    if 0 < len(present_types) <= cmap.N:
        for cell_type in present_types:
            ax.plot([], [], marker='o', linestyle='',
                    color=palette[cell_type], label=str(cell_type))
        ax.legend(title='cell_type', loc='best', fontsize='small')

    return fig

In [None]:
# Load the tables used below. get_data returns five items; only the 1st, 3rd
# and 4th are kept. NOTE(review): assigning to `_` shadows IPython's
# last-output variable `_`, so don't rely on `_` in later cells.
connections, _, all_neurons, neuron_data, _ = get_data("new_data")

# Baseline: no cell types muted

In [None]:
# Control run: an empty ablation list removes no neurons. `connectivity` is the
# (baseline, ablated) tuple returned by analyze_detour_pathways.
ablated_types = []
connectivity = analyze_detour_pathways(
    neuron_data=neuron_data,
    connections=connections,
    neurons=all_neurons,
    ablated_types=ablated_types,
)

In [None]:
# NOTE(review): `paths` is not defined in any earlier cell, so this cell raised
# NameError on Restart & Run All. visualize_pathways needs a DataFrame with
# 'path_length' and 'path_types' columns, and analyze_detour_pathways does not
# return one. TODO: build `paths` before this cell.
if "paths" in globals():
    visualize_pathways(paths)
else:
    print("Skipping pathway plot: `paths` (DataFrame with 'path_length' and "
          "'path_types' columns) has not been defined.")

In [None]:
# Cell types to mute in the next experiment: "L1" through "L5".
muted_types = [f"L{index}" for index in range(1, 6)]