In [1]:
%matplotlib inline
import os
import torch
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
cmap = plt.get_cmap("viridis")
from tqdm.notebook import tqdm
from tqdm.contrib.concurrent import process_map
import multiprocessing
import seaborn as sns

import config
from graph_models import FullGraphModel
from data_processing import DataProcessor
from model_inspection_funcs import neuron_data_from_image, propagate_data_without_deciding, sample_images

num_test_pairs = 20
device = torch.device("cpu")
dtype = torch.float32

In [2]:
data_processor = DataProcessor(config)
model = FullGraphModel(data_processor, config).to(device)

In [None]:
# horrible data stuff
checkpoint = torch.load("models/m_2025-02-05 11:46_qr7soxee.pth", map_location="cpu")
model.load_state_dict(checkpoint["model"])
model.eval()


In [None]:
connections = (
    pd.read_csv(
        "new_data/connections_refined.csv",
        dtype={
            "pre_root_id": "string",
            "post_root_id": "string",
            "syn_count": np.int32,
        },
    )
    .groupby(["pre_root_id", "post_root_id"])
    .sum("syn_count")
    .reset_index()
)

connections["weight"] = model.connectome.edge_weight_multiplier.detach()
right_root_ids = data_processor.right_root_ids

neuron_data = pd.read_csv(
    "new_data/right_visual_positions_selected_neurons.csv",
    dtype={"root_id": "string"},
).drop(columns=["x", "y", "z", "PC1", "PC2"])

ValueError: Length of values (14996423) does not match length of index (15091983)

In [None]:
connections = data_processor._get_connections(True)
neuron_classification = data_processor._get_neurons()

In [27]:
all_neurons = neuron_classification.merge(data_processor.root_ids, on="root_id").fillna("Unknown")
ix_conns = connections.merge(
    all_neurons[["root_id", "index_id"]], left_on="pre_root_id", right_on="root_id"
).merge(
    all_neurons[["root_id", "index_id"]],
    left_on="post_root_id",
    right_on="root_id",
    suffixes=("_pre", "_post"),
)

# Sort and group
ix_conns = ix_conns.sort_values(["index_id_pre", "index_id_post"])
grouped = (
    ix_conns.groupby(["index_id_pre", "index_id_post"], sort=True)["syn_count"]
    .sum()
    .reset_index()
)

# Round any values very close to zero to exactly zero
grouped["syn_count"] = np.where(
    np.abs(grouped["syn_count"]) < 1e-10, 0, grouped["syn_count"]
)

In [28]:
grouped = grouped[grouped["syn_count"] != 0]

In [26]:
synapses = model.connectome.edge_weight.detach().numpy()

In [24]:
connections["weight"] = model.connectome.edge_weight_multiplier.detach()

ValueError: Length of values (14996423) does not match length of index (15053597)

In [4]:
def process_image(args):
    img, neuron_data, connections, all_coords, num_passes = args
    activated_data = neuron_data_from_image(img, neuron_data)
    propagation = propagate_data_without_deciding(
        activated_data, connections, all_coords, num_passes
    )
    return os.path.basename(img), propagation

In [None]:
num_passes = 4
base_dir = "images/one_to_ten/train"
sub_dirs = ["yellow", "blue"]

sampled_images = sample_images(base_dir, sub_dirs, num_test_pairs)

tasks = [
    (img, neuron_data, connections, all_coords, num_passes)
    for img in sampled_images
]
result_tuples = process_map(
    process_image, tasks, max_workers=multiprocessing.cpu_count() - 2, chunksize=1
)
dms = dict(result_tuples)

  0%|          | 0/40 [00:00<?, ?it/s]

In [13]:
data = pd.DataFrame(dms)
means = pd.DataFrame(data.mean(axis=0))

In [58]:
neurons_in_coords = all_neurons.merge(
    all_coords, on="root_id", how="right"
    )[[ "root_id", "cell_type"]].fillna("Unknown")

# Set all cell_types with less than "20" samples to "others"
n = 100

counts = neurons_in_coords["cell_type"].value_counts()

small_categories = counts[counts < n].index
neurons_in_coords["cell_type"] = neurons_in_coords["cell_type"].apply(
    lambda x: "others" if x in small_categories else x
)
data["cell_type"] = neurons_in_coords["cell_type"]

In [61]:
# group by cell_type and take the mean of all the columns
means = pd.DataFrame(data.groupby("cell_type").mean()).T
means["yellow"] = [int(a.split("_")[1]) for a in means.index]
means["blue"] = [int(a.split("_")[2]) for a in means.index]
means["color"] = means[["yellow", "blue"]].idxmax(axis=1)
means = means.drop(columns=["yellow", "blue"])

In [63]:
col_means = means.groupby("color").mean().T
col_means["diff"] = (col_means["yellow"] - col_means["blue"]) / np.where(col_means["yellow"] != 0, col_means["yellow"], 1)
col_means["abs_diff"] = np.abs(col_means["diff"])
col_means.sort_values("abs_diff", ascending=False)

color,blue,yellow,diff,abs_diff
cell_type,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
LC28,28.714743,-10.778294,3.664127,3.664127
TmY3,53.932394,15.232143,-2.540696,2.540696
CB3816,7.197084,2.308986,-2.116989,2.116989
Dm11,-0.159142,0.702482,1.226543,1.226543
Tm20,0.519000,-7.775932,1.066744,1.066744
...,...,...,...,...
Tm5a,55.354335,54.923887,-0.007837,0.007837
MTe01,598.510789,599.744590,0.002057,0.002057
R1-6,0.102870,0.102780,-0.000874,0.000874
LB3,0.000000,0.000000,0.000000,0.000000


In [55]:
ct_counts = all_neurons["cell_type"].value_counts()
ct_counts

In [95]:
def accuracy_per_celltype(ct):
    temp = means[[ct, "color"]]
    temp["pred"] = np.where(temp[ct] > temp[ct].median(), "yellow", "blue")
    acc = (temp["color"] == temp["pred"]).mean()
    return acc if acc > 0.5 else 1 - acc

In [96]:
accs = {}
for ct in list(col_means.index):
    accs[ct] = accuracy_per_celltype(ct)

In [104]:
df = pd.DataFrame(accs, index=[0]).T
df.sort_values(0, ascending=False)