In [None]:
%matplotlib inline
import os
import torch
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
cmap = plt.get_cmap("viridis")
from tqdm.contrib.concurrent import process_map
import multiprocessing

from configs import config
from connectome.core.data_processing import DataProcessor
from connectome.core.graph_models import FullGraphModel
from utils.model_inspection_funcs import neuron_data_from_image, propagate_data_without_deciding, sample_images

# Sample size handed to sample_images() further down; whether it counts images
# or yellow/blue pairs is defined by that helper.
num_test_pairs = 20
# The model is moved to this device after construction and the whole analysis
# runs on it.
device = torch.device("cpu")
# NOTE(review): not referenced by the visible cells (the model is built with
# config.dtype) — confirm whether it is still needed.
dtype = torch.float32

In [None]:
# Connectome preprocessing, configured entirely from the project config.
processor_kwargs = dict(
    neurons=config.neurons,
    voronoi_criteria=config.voronoi_criteria,
    random_synapses=config.random_synapses,
    log_transform_weights=config.log_transform_weights,
)
data_processor = DataProcessor(**processor_kwargs)

# Graph model sized from the processed connectome.
# NOTE(review): the model is constructed with config.DEVICE / config.dtype and
# only afterwards moved to the notebook-level `device` — confirm that the two
# agree (or that .to(device) relocates everything the constructor allocates).
model_kwargs = dict(
    input_shape=data_processor.number_of_synapses,
    num_connectome_passes=config.NUM_CONNECTOME_PASSES,
    decision_making_vector=data_processor.decision_making_vector,
    batch_size=config.batch_size,
    dtype=config.dtype,
    edge_weights=data_processor.synaptic_matrix.data,
    device=config.DEVICE,
    num_classes=len(config.CLASSES),
)
model = FullGraphModel(**model_kwargs).to(device)

In [None]:
# Load the trained checkpoint and every connectome table the analysis needs.
# NOTE(review): the inputs are addressed relative to different directories
# ("models/", "adult_data/" and "../../adult_data/") — confirm that all of
# them resolve from the directory this notebook is run in.
checkpoint = torch.load("models/model_2024-05-20 03:41:43.pth", map_location=device)
model.load_state_dict(checkpoint["model"])
model.eval()

# One row per (pre, post) neuron pair, with the numeric columns summed.
# `numeric_only=True` is spelled out: the previous `.sum("syn_count")` passed
# the string positionally as `numeric_only`, which only worked because a
# non-empty string is truthy.
connections = (
    pd.read_csv(
        "adult_data/connections.csv",
        dtype={
            "pre_root_id": "string",
            "post_root_id": "string",
            "syn_count": np.int32,
        },
    )
    .groupby(["pre_root_id", "post_root_id"])
    .sum(numeric_only=True)
    .reset_index()
)

# Attach the learned per-edge multipliers as a plain NumPy array.
# NOTE(review): assumes the model holds exactly one multiplier per aggregated
# connection, in the same row order as this table — confirm against
# DataProcessor.
connections["weight"] = model.connectome.edge_weight_multiplier.detach().cpu().numpy()

# Classification of the neurons selected by the data processor; missing
# annotations are filled with "Unknown".
right_root_ids = data_processor.right_root_ids
all_neurons = (
    pd.read_csv("../../adult_data/classification_clean.csv")
    .merge(right_root_ids, on="root_id")
    .fillna("Unknown")
)

# Positions of the selected visual neurons; the raw coordinates and principal
# components are dropped.
neuron_data = pd.read_csv(
    "../../adult_data/right_visual_positions_selected_neurons.csv",
    dtype={"root_id": "string"},
).drop(columns=["x", "y", "z", "PC1", "PC2"])
data_cols = ["x_axis", "y_axis"]

all_coords = pd.read_csv("../../adult_data/all_coords_clean.csv", dtype={"root_id": "string"})
rational_cell_types = pd.read_csv("../../adult_data/rational_cell_types.csv")

# 1 for neurons whose cell type is listed in rational_cell_types, else 0.
all_neurons["decision_making"] = (
    all_neurons["cell_type"].isin(rational_cell_types["cell_type"]).astype(int)
)
# root_id is converted to string so it can be merged with the string-typed
# root_id columns of the other tables.
all_neurons["root_id"] = all_neurons["root_id"].astype("string")

In [None]:
def process_image(args):
    """Propagate one image through the connectome; worker for ``process_map``.

    Args:
        args: 5-tuple ``(img, neuron_data, connections, all_coords,
            num_passes)``. Everything is packed into one argument because the
            function is mapped over a list of task tuples.

    Returns:
        ``(basename of img, propagation)`` where ``propagation`` is whatever
        ``propagate_data_without_deciding`` returns for the activations that
        ``neuron_data_from_image`` derives from the image.
    """
    image_path, neurons, synapses, coords, n_passes = args
    result = propagate_data_without_deciding(
        neuron_data_from_image(image_path, neurons), synapses, coords, n_passes
    )
    return os.path.basename(image_path), result

In [None]:
# NOTE(review): the model above is built with config.NUM_CONNECTOME_PASSES;
# confirm that a separately hard-coded pass count is intended here.
num_passes = 4
base_dir = "images/five_to_fifteen/train"
sub_dirs = ["yellow", "blue"]

sampled_images = sample_images(base_dir, sub_dirs, num_test_pairs)

# One task tuple per image; process_image unpacks it in the worker process.
tasks = [
    (img, neuron_data, connections, all_coords, num_passes)
    for img in sampled_images
]
# Leave two cores free, but never request fewer than one worker: on machines
# with two or fewer cores `cpu_count() - 2` is <= 0, which the process pool
# rejects.
result_tuples = process_map(
    process_image,
    tasks,
    max_workers=max(1, multiprocessing.cpu_count() - 2),
    chunksize=1,
)
# Keyed by image basename, so two sampled images sharing a basename would
# overwrite each other here.
dms = dict(result_tuples)

In [None]:
# One column per sampled image (keyed by basename), one row per propagated entry.
data = pd.DataFrame.from_dict(dms)
# Column-wise mean of every image, kept as a single-column frame.
means = data.mean(axis=0).to_frame()

In [None]:
# Cell type for every row of all_coords (right merge keeps all coordinate
# rows); neurons without a classification become "Unknown".
neurons_in_coords = all_neurons.merge(
    all_coords, on="root_id", how="right"
)[["root_id", "cell_type"]].fillna("Unknown")

# Collapse every cell type with fewer than `n` neurons into "others".
n = 100

counts = neurons_in_coords["cell_type"].value_counts()

small_categories = counts[counts < n].index
# Vectorised replacement: keep the label unless it belongs to a small category.
neurons_in_coords["cell_type"] = neurons_in_coords["cell_type"].where(
    ~neurons_in_coords["cell_type"].isin(small_categories), "others"
)
# NOTE(review): this assignment aligns on the index, so it assumes `data` and
# `neurons_in_coords` share the same row index — confirm against the index of
# the propagation results.
data["cell_type"] = neurons_in_coords["cell_type"]

In [None]:
# Average within each cell type, then transpose: one row per image, one column
# per cell type.
means = data.groupby("cell_type").mean().T
# The second and third "_"-separated fields of the image name are read as the
# yellow and blue counts; the image is labelled with whichever is larger
# (idxmax returns the first column on a tie, i.e. "yellow").
means["color"] = pd.DataFrame(
    {
        "yellow": [int(name.split("_")[1]) for name in means.index],
        "blue": [int(name.split("_")[2]) for name in means.index],
    },
    index=means.index,
).idxmax(axis=1)

In [None]:
# Mean value per cell type (rows) for each image colour (columns).
col_means = means.groupby("color").mean().T
# Difference between the colours relative to the yellow mean; a zero yellow
# mean is replaced by 1 in the denominator to avoid dividing by zero.
col_means["diff"] = (
    col_means["yellow"]
    .sub(col_means["blue"])
    .div(col_means["yellow"].mask(col_means["yellow"].eq(0), 1))
)
col_means["abs_diff"] = col_means["diff"].abs()
col_means.sort_values("abs_diff", ascending=False)

In [None]:
# Number of classified neurons per cell type, most frequent first.
ct_counts = all_neurons.loc[:, "cell_type"].value_counts()
ct_counts

In [None]:
def accuracy_per_celltype(ct, means_df=None):
    """Score how well one cell type's value separates yellow from blue images.

    Each image is predicted "yellow" when its value in column ``ct`` is above
    that column's median and "blue" otherwise. The predictions are compared
    with the ``"color"`` column.

    Args:
        ct: Name of the cell-type column to evaluate.
        means_df: Frame with one row per image, a column named ``ct`` and a
            ``"color"`` column. Defaults to the notebook-level ``means``.

    Returns:
        The fraction of correctly predicted images, or its complement when
        that is larger, so the result is never below 0.5 (the direction of
        the relationship is ignored).
    """
    frame = means if means_df is None else means_df
    values = frame[ct]
    # Build the prediction as its own Series rather than adding a column to a
    # slice of the frame, which triggered SettingWithCopyWarning.
    predicted = pd.Series(
        np.where(values > values.median(), "yellow", "blue"), index=frame.index
    )
    acc = (frame["color"] == predicted).mean()
    return max(acc, 1 - acc)

In [None]:
# Separation accuracy for every cell type present in col_means.
accs = {cell_type: accuracy_per_celltype(cell_type) for cell_type in col_means.index}

In [None]:
# One row per cell type, accuracy in column 0, best-separating types first.
df = pd.Series(accs).to_frame(0)
df.sort_values(0, ascending=False)