In [None]:
%matplotlib inline
import os
import torch
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import plotly.graph_objects as go

# Shared "viridis" colormap handle.
# NOTE(review): not referenced by the cells shown here — remove if nothing else uses it.
cmap = plt.get_cmap(name="viridis")

from configs import config
from connectome.core.graph_models import FullGraphModel
from connectome.core.data_processing import CompleteModelsDataProcessor
from utils.model_inspection_utils import get_neuron_type
from utils.model_inspection_funcs import (
    activation_cols_and_colours,
    neuron_data_from_image,
    propagate_neuron_data,
    sample_images,
)

# The model and checkpoint are inspected on the CPU (see the `.to(device)` and
# `map_location="cpu"` calls below).
# NOTE(review): `dtype` is not referenced by the cells shown here; the model is built
# with `config.dtype`.
device, dtype = torch.device("cpu"), torch.float32

In [None]:
# Build the connectome inputs from the settings in `config`, then instantiate the graph
# model from them (synapse count, decision-making vector, synaptic edge weights).
data_processor = CompleteModelsDataProcessor(
    neurons=config.neurons,
    voronoi_criteria=config.voronoi_criteria,
    random_synapses=config.random_synapses,
    log_transform_weights=config.log_transform_weights,
)

# Build the model for `device` (the CPU) rather than `config.DEVICE`. The notebook moves
# the model to `device` immediately and loads the checkpoint with map_location="cpu".
# Constructing it for a different device could leave any tensors the model keeps outside
# its registered parameters/buffers on that other device.
# NOTE(review): dtype still comes from `config.dtype`; confirm it matches the checkpoint.
model = FullGraphModel(
    input_shape=data_processor.number_of_synapses,
    num_connectome_passes=config.NUM_CONNECTOME_PASSES,
    decision_making_vector=data_processor.decision_making_vector,
    batch_size=config.batch_size,
    dtype=config.dtype,
    edge_weights=data_processor.synaptic_matrix.data,
    device=device,
    num_classes=len(config.CLASSES),
).to(device)

In [None]:
# Load the trained checkpoint plus the connectome and annotation tables used for plotting.
# NOTE(review): these files are read from two different `adult_data` directories:
# "adult_data/..." (relative to the working directory) and "../../adult_data/...".
# Confirm which one is intended and unify, e.g. via a single DATA_DIR constant.
# NOTE(review): torch.load unpickles the file, so only load trusted checkpoints. Consider
# weights_only=True if the checkpoint holds only tensors and plain containers.
checkpoint = torch.load("models/model_2024-05-20 03:41:43.pth", map_location="cpu")
model.load_state_dict(checkpoint["model"])
model.eval()

# One row per (pre_root_id, post_root_id) pair, with synapse counts summed over duplicates.
# numeric_only=True is what the former `.sum("syn_count")` did implicitly: that string was
# being consumed as the truthy `numeric_only` flag, not as a column selector. Non-numeric
# columns are dropped.
connections = (
    pd.read_csv(
        "adult_data/connections.csv",
        dtype={
            "pre_root_id": "string",
            "post_root_id": "string",
            "syn_count": np.int32,
        },
    )
    .groupby(["pre_root_id", "post_root_id"])
    .sum(numeric_only=True)
    .reset_index()
)

# Attach the learned per-edge multipliers as a plain NumPy array, rather than handing
# pandas a torch tensor.
# NOTE(review): assumes the multiplier has one entry per row above, in the same
# (pre_root_id, post_root_id)-sorted order the data processor used. TODO confirm.
connections["weight"] = model.connectome.edge_weight_multiplier.detach().cpu().numpy()

right_root_ids = data_processor.right_root_ids
# Neuron annotations restricted to the ids in right_root_ids; missing values become "Unknown".
# NOTE(review): root_id is read with pandas' default dtype here and only cast to string
# at the end of this cell. The merge relies on right_root_ids["root_id"] having a
# matching dtype; confirm.
all_neurons = (
    pd.read_csv("../../adult_data/classification_clean.csv")
    .merge(right_root_ids, on="root_id")
    .fillna("Unknown")
)
right_visual_neurons = data_processor.voronoi_cells.get_tesselated_neurons().merge(
    right_root_ids, on="root_id"
)
# Per-neuron table used as the starting point for image propagation; the x/y/z and
# PC1/PC2 columns are dropped.
neuron_data = pd.read_csv(
    "../../adult_data/right_visual_positions_selected_neurons.csv",
    dtype={"root_id": "string"},
).drop(columns=["x", "y", "z", "PC1", "PC2"])
data_cols = ["x_axis", "y_axis"]
all_coords = pd.read_csv("../../adult_data/all_coords_clean.csv", dtype={"root_id": "string"})
rational_cell_types = pd.read_csv("../../adult_data/rational_cell_types.csv")
# 1 if the neuron's cell type is listed in rational_cell_types.csv, else 0.
all_neurons["decision_making"] = (
    all_neurons["cell_type"].isin(rational_cell_types["cell_type"]).astype(int)
)
all_neurons["root_id"] = all_neurons["root_id"].astype("string")

In [None]:
def prepare_data_for_plot(propagation, activation_cols, min_size=2):
    """Label each row with an activation category and derive a plot marker size.

    Args:
        propagation: DataFrame of propagated activations (presumably one row per neuron)
            that contains every column in ``activation_cols``. It is copied, not modified.
        activation_cols: Activation column names; each row's values in these columns are
            passed to ``get_neuron_type`` together with this list.
        min_size: Marker size for a row whose plotted activation is 0. Defaults to 2,
            the previously hard-coded offset.

    Returns:
        A copy of ``propagation`` with three added columns:

        * ``color``: label returned by ``get_neuron_type``.
        * ``plot_activation``: the row's value in the column named by ``color``, or 0
          when ``color`` is ``"no_activation"``.
        * ``size``: ``log1p(|plot_activation|) + min_size``.
    """
    df = propagation.copy()

    # result_type="reduce" guarantees a Series even when `df` has no rows. Without it,
    # pandas may return an empty DataFrame there, which cannot be assigned to one column.
    df["color"] = df[activation_cols].apply(
        lambda row: get_neuron_type(row, activation_cols), axis=1, result_type="reduce"
    )
    df["plot_activation"] = df.apply(
        lambda row: 0 if row["color"] == "no_activation" else row[row["color"]],
        axis=1,
        result_type="reduce",
    )
    # Log-scale the magnitude so strongly activated rows don't dwarf the rest.
    df["size"] = np.log1p(np.abs(df["plot_activation"])) + min_size
    return df


In [None]:
def plot_brain(df, image_path, colours):
    """Show neurons as an interactive 3-D scatter, one plotly trace per category.

    Args:
        df: DataFrame with ``x``, ``y``, ``z`` coordinate columns plus the ``color`` and
            ``size`` columns added by ``prepare_data_for_plot``.
        image_path: Path of the input image; its file name is used as the figure title.
        colours: Mapping of category label (a value of ``df["color"]``) to marker colour.
            Traces are added in the mapping's iteration order.

    Displays the figure with ``fig.show()`` and returns ``None``.
    """
    fig = go.Figure()
    for category, colour in colours.items():
        subset = df[df["color"] == category]
        fig.add_trace(
            go.Scatter3d(
                x=subset["x"],
                y=subset["y"],
                z=subset["z"],
                mode="markers",
                marker=dict(
                    size=subset["size"], color=colour, line=dict(width=0), opacity=0.5
                ),
                name=category,
            )
        )

    def _bare_axis(title):
        # Hide background panes, grid lines and zero lines; set the axis title once.
        return dict(title=title, showbackground=False, showgrid=False, zeroline=False)

    fig.update_layout(
        title=os.path.basename(image_path),
        scene=dict(xaxis=_bare_axis("X"), yaxis=_bare_axis("Y"), zaxis=_bare_axis("Z")),
        scene_aspectmode="auto",
        margin=dict(l=0, r=0, b=0, t=30),
    )

    fig.show()

In [None]:
def process_and_plot(img_path, neuron_data, coords, neurons, num_passes, connections_df=None):
    """Propagate one image's activations through the connectome and plot them in 3-D.

    Args:
        img_path: Path of the input image.
        neuron_data: Per-neuron table that ``neuron_data_from_image`` combines with the image.
        coords: Neuron coordinate table passed to ``propagate_neuron_data``.
        neurons: Neuron annotation table passed to ``propagate_neuron_data``.
        num_passes: Number of propagation passes. It is passed to both
            ``propagate_neuron_data`` and ``activation_cols_and_colours``.
        connections_df: Edge table to propagate over. Defaults to the notebook-level
            ``connections`` DataFrame, looked up at call time.
    """
    if connections_df is None:
        connections_df = connections
    # Use a separate name so the `neuron_data` argument is not shadowed.
    image_neuron_data = neuron_data_from_image(img_path, neuron_data)
    propagation = propagate_neuron_data(
        image_neuron_data, connections_df, coords, neurons, num_passes
    )
    activation_cols, colours = activation_cols_and_colours(num_passes)
    df = prepare_data_for_plot(propagation, activation_cols)
    plot_brain(df, img_path, colours)

In [None]:
# Sample training images from the "yellow" and "blue" sub-directories and plot the first.
# NOTE(review): if sample_images samples at random, seed it so reruns show the same images.
IMAGE_DIR = "images/five_to_fifteen/train"
IMAGE_SUB_DIRS = ["yellow", "blue"]
NUM_SAMPLED_IMAGES = 10
PROPAGATION_PASSES = 4

base_dir, sub_dirs = IMAGE_DIR, IMAGE_SUB_DIRS
sampled_images = sample_images(base_dir, sub_dirs, NUM_SAMPLED_IMAGES)
process_and_plot(
    sampled_images[0], neuron_data, all_coords, all_neurons, num_passes=PROPAGATION_PASSES
)

In [None]:
# Plot the second sampled image with 4 propagation passes.
process_and_plot(
    img_path=sampled_images[1],
    neuron_data=neuron_data,
    coords=all_coords,
    neurons=all_neurons,
    num_passes=4,
)