In [1]:
%matplotlib inline
import torch
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from imageio.v3 import imread

cmap = plt.cm.binary

import connectome
config = connectome.get_config()

from connectome.core.data_processing import DataProcessor
from connectome.core.graph_models import FullGraphModel
from connectome.core.train_funcs import get_activation_from_cell_type, assign_cell_type
from utils.model_inspection_utils import process_image, process_and_plot_data

device = torch.device("cpu")
dtype = torch.float32

In [None]:
data_processor = DataProcessor(config)
model = FullGraphModel(data_processor, config).to(device)

In [4]:
connections = data_processor._get_connections()

In [13]:
import os
from paths import PROJECT_ROOT
from utils.model_inspection_funcs import (
    neuron_data_from_image,
    propagate_neuron_data,
    sample_images,
)
from utils.model_inspection_utils import propagate_data_with_steps
neuron_data = pd.read_csv(
    os.path.join(PROJECT_ROOT, "new_data", "right_visual_positions_selected_neurons.csv"),
    dtype={"root_id": "string"},
).drop(columns=["x", "y", "z", "PC1", "PC2"])

num_passes = 4
base_dir = os.path.join(PROJECT_ROOT, "images", "one_to_ten", "train")
sub_dirs = ["yellow", "blue"]

sampled_images = sample_images(base_dir, sub_dirs, 1)
img = sampled_images[0]
activated_data = neuron_data_from_image(img, neuron_data)

propagation = (
    activated_data[["root_id", "activation"]]
    .fillna(0)
    .rename(columns={"activation": "input"})
)
activation = activated_data[["root_id", "activation"]]
connections["weight"] = 1

for i in range(num_passes):
    activation = propagate_data_with_steps(activation.copy(), connections, i)
    propagation = propagation.merge(activation, on="root_id", how="left").fillna(0)

In [15]:
# find percentage of non-zero values in activation_4
propagation["activation_4"].astype(bool).sum() / len(propagation)

0.5563860903477413

In [18]:
# reshuffle column post_rood_id of the dataframe connections
shuffled_connections = connections.copy()
shuffled_connections["post_root_id"] = np.random.permutation(
    connections["post_root_id"]
)
propagation = (
    activated_data[["root_id", "activation"]]
    .fillna(0)
    .rename(columns={"activation": "input"})
)
for i in range(num_passes):
    activation = propagate_data_with_steps(activation.copy(), shuffled_connections, i)
    propagation = propagation.merge(activation, on="root_id", how="left").fillna(0)

# find percentage of non-zero values in activation_4
propagation["activation_4"].astype(bool).sum() / len(propagation)

0.7491062723431914

In [19]:
# reshuffle column post_rood_id of the dataframe connections
random_equalized_connections = pd.read_csv(os.path.join(PROJECT_ROOT, "new_data", "connections_random.csv"))
propagation = (
    activated_data[["root_id", "activation"]]
    .fillna(0)
    .rename(columns={"activation": "input"})
)
for i in range(num_passes):
    activation = propagate_data_with_steps(activation.copy(), shuffled_connections, i)
    propagation = propagation.merge(activation, on="root_id", how="left").fillna(0)

# find percentage of non-zero values in activation_4
propagation["activation_4"].astype(bool).sum() / len(propagation)

0.7491062723431914

In [None]:
# horrible data stuff
checkpoint = torch.load("models/model_2024-05-20 03:41:43.pth", map_location="cpu")
model.load_state_dict(checkpoint["model"])
model.eval()
connections = (
    pd.read_csv(
        "adult_data/connections.csv",
        dtype={
            "pre_root_id": "string",
            "post_root_id": "string",
            "syn_count": np.int32,
        },
    )
    .groupby(["pre_root_id", "post_root_id"])
    .sum("syn_count")
    .reset_index()
)

connections["weight"] = model.connectome.edge_weight_multiplier.detach()
right_root_ids = data_processor.right_root_ids
all_neurons = (
    pd.read_csv("adult_data/classification_clean.csv")
    .merge(right_root_ids, on="root_id")
    .fillna("Unknown")
)
neuron_data = pd.read_csv(
    "adult_data/right_visual_positions_selected_neurons.csv",
    dtype={"root_id": "string"},
).drop(columns=["x", "y", "z", "PC1", "PC2"])
data_cols = ["x_axis", "y_axis"]
all_coords = pd.read_csv("adult_data/all_coords_clean.csv", dtype={"root_id": "string"})
rational_cell_types = pd.read_csv("adult_data/rational_cell_types.csv")
all_neurons["decision_making"] = np.where(
    all_neurons["cell_type"].isin(rational_cell_types["cell_type"].values.tolist()),
    1,
    0,
)
all_neurons["root_id"] = all_neurons["root_id"].astype("string")

In [None]:
data = pd.DataFrame(dms)
means = pd.DataFrame(data.mean(axis=0))
means = means.rename(columns={0: "mean"})
means["yellow"] = [int(a.split("_")[1]) for a in means.index]
means["blue"] = [int(a.split("_")[2]) for a in means.index]
means["color"] = means[["yellow", "blue"]].idxmax(axis=1)

In [None]:
sns.boxplot(x="color", y="mean", data=means)

In [None]:
means["pred"] = np.where(means["mean"] > means["mean"].mean(), "yellow", "blue")
# confusion matrix between color and pred

confusion_matrix(means["color"], means["pred"]) / len(means)

In [None]:
print(f"accuracy = {np.mean(means['color'] == means['pred'])}")

# With reshuffled weights

In [None]:
# reshuffle column post_rood_id of the dataframe connections
shuffled_connections = connections.copy()
shuffled_connections["post_root_id"] = np.random.permutation(connections["post_root_id"])

In [None]:
num_passes = 4
base_dir = "images/five_to_fifteen/train"
sub_dirs = ["yellow", "blue"]

sampled_images = sample_images(base_dir, sub_dirs, num_test_pairs)

dms = {}
for img in tqdm(sampled_images):
    activated_data = neuron_data_from_image(img, neuron_data)
    propagation = propagate_neuron_data(
        activated_data, shuffled_connections, all_coords, all_neurons, num_passes
    )
    dms[os.path.basename(img)] = propagation["decision_making"][
        all_neurons["decision_making"] == 1
    ]

In [None]:
data = pd.DataFrame(dms)
means = pd.DataFrame(data.mean(axis=0))
means = means.rename(columns={0: "mean"})
means["yellow"] = [int(a.split("_")[1]) for a in means.index]
means["blue"] = [int(a.split("_")[2]) for a in means.index]
means["color"] = means[["yellow", "blue"]].idxmax(axis=1)

In [None]:
%matplotlib inline
sns.boxplot(x="color", y="mean", data=means)
plt.show()

In [None]:
means["pred"] = np.where(means["mean"] > means["mean"].mean(), "yellow", "blue")
# confusion matrix between color and pred

confusion_matrix(means["color"], means["pred"]) / len(means)

In [None]:
print(f"accuracy = {np.mean(means['color'] == means['pred'])}")