In [None]:
import gc
import multiprocessing
import warnings

# Silence FutureWarnings before the heavy third-party imports below.
warnings.simplefilter(action='ignore', category=FutureWarning)

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from tqdm.contrib.concurrent import process_map
from tqdm.notebook import tqdm

from configs import config
from connectome.core.data_processing import DataProcessor
from connectome.core.graph_models import FullGraphModel
from scripts.no_training import process_image_without_deciding
from utils.model_inspection_funcs import sample_images

cmap = plt.get_cmap("viridis")

# Torch placement/precision used by this notebook (the model is moved to
# `device` after construction and the checkpoint is loaded for CPU).
device = torch.device("cpu")
dtype = torch.float32

# Number of images requested from `sample_images` further down.
num_test_pairs = 500

In [None]:
# Data pipeline configured entirely from the shared project config.
data_processor = DataProcessor(
    neurons=config.neurons,
    voronoi_criteria=config.voronoi_criteria,
    random_synapses=config.random_synapses,
    log_transform_weights=config.log_transform_weights,
)

# Graph model sized from what the data processor exposes.
# NOTE(review): the constructor receives `config.DEVICE` / `config.dtype`
# while the instance is then moved to the notebook-level `device`; confirm the
# two agree when the config targets something other than the CPU.
model = FullGraphModel(
    input_shape=data_processor.number_of_synapses,
    num_connectome_passes=config.NUM_CONNECTOME_PASSES,
    decision_making_vector=data_processor.decision_making_vector,
    batch_size=config.batch_size,
    dtype=config.dtype,
    edge_weights=data_processor.synaptic_matrix.data,
    device=config.DEVICE,
    num_classes=len(config.CLASSES),
)
model = model.to(device)

In [None]:
# Load the trained weights and the connectome / neuron metadata tables.
# NOTE(review): the checkpoint and connections.csv are read relative to the
# working directory while the other CSVs use "../../adult_data/"; confirm that
# both locations exist from wherever this notebook is run.
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
checkpoint = torch.load(
    "models/n_all_v_R7_r_False_lr_0.003_p_4_2024-05-27 21:45.pth", map_location=device
)
model.load_state_dict(checkpoint["model"])
model.eval()

# One row per (pre, post) neuron pair with the numeric columns summed.
# `numeric_only=True` states explicitly what the previous positional string
# argument only achieved by being truthy: non-numeric columns are dropped.
connections = (
    pd.read_csv(
        "adult_data/connections.csv",
        dtype={
            "pre_root_id": "string",
            "post_root_id": "string",
            "syn_count": np.int32,
        },
    )
    .groupby(["pre_root_id", "post_root_id"])
    .sum(numeric_only=True)
    .reset_index()
)

# Attach the learned per-edge multiplier as a plain NumPy column.
# Assumes the parameter holds one value per aggregated connection row, in the
# same order — TODO confirm against the model definition.
connections["weight"] = model.connectome.edge_weight_multiplier.detach().cpu().numpy()

# Neuron classification restricted to the root ids known to the data processor.
right_root_ids = data_processor.root_ids
all_neurons = (
    pd.read_csv("../../adult_data/classification_clean.csv")
    .merge(right_root_ids, on="root_id")
    .fillna("Unknown")
)

# Per-neuron visual positions; the raw coordinates and PCA columns are dropped.
neuron_data = pd.read_csv(
    "../../adult_data/right_visual_positions_selected_neurons.csv",
    dtype={"root_id": "string"},
).drop(columns=["x", "y", "z", "PC1", "PC2"])
data_cols = ["x_axis", "y_axis"]
all_coords = pd.read_csv("../../adult_data/all_coords_clean.csv", dtype={"root_id": "string"})

# Flag neurons whose cell type appears in the "rational" cell-type list.
rational_cell_types = pd.read_csv("../../adult_data/rational_cell_types.csv")
all_neurons["decision_making"] = np.where(
    all_neurons["cell_type"].isin(rational_cell_types["cell_type"]),
    1,
    0,
)
# Match the string dtype of `all_coords["root_id"]` before merging on it.
all_neurons["root_id"] = all_neurons["root_id"].astype("string")

# One row per entry of `all_coords` (right join) with its cell type, or
# "Unknown" when the neuron has no classification.
neurons_in_coords = all_neurons.merge(all_coords, on="root_id", how="right")[
    ["root_id", "cell_type"]
].fillna("Unknown")

# Set all cell_types with less than "n" samples to "others".
# With n = 1 no observed cell type can fall below the threshold, so this is
# currently a no-op; raise n to actually pool rare types.
n = 1

counts = neurons_in_coords["cell_type"].value_counts()

small_categories = counts[counts < n].index
neurons_in_coords["cell_type"] = neurons_in_coords["cell_type"].mask(
    neurons_in_coords["cell_type"].isin(small_categories), "others"
)

In [None]:
# Run every sampled image through `process_image_without_deciding` in parallel.
num_passes = 4
base_dir = "images/one_colour"
sub_dirs = [str(i) for i in range(1, 10)]

# NOTE(review): if `sample_images` draws randomly, seed it for reproducibility.
sampled_images = sample_images(base_dir, sub_dirs, num_test_pairs)

# Each task carries the image plus the shared tables the worker needs.
tasks = [
    (img, neuron_data, connections, all_coords, num_passes)
    for img in sampled_images
]
# Leave two cores free, but never ask for fewer than one worker: a process
# pool rejects max_workers <= 0, which happened on machines with <= 2 cores.
result_tuples = process_map(
    process_image_without_deciding,
    tasks,
    max_workers=max(1, multiprocessing.cpu_count() - 2),
    chunksize=1,
)
# The worker returns (key, value) pairs; collect them into a dict.
dms = dict(result_tuples)

In [None]:
# `dms` keys become columns of `data`; transpose so each key is a row of `df`.
data = pd.DataFrame(dms)
df = data.transpose()
# The second underscore-separated token of each row label is the point count.
df["num_points"] = list(
    map(lambda image_name: int(image_name.split("_")[1]), df.index)
)

In [None]:
# Idealised response keyed by the absolute distance between the preferred
# value and the presented value; distances not listed count as 0.
model_response = {0: 1, 1: .5, 2: .4, 3: .2, 4: 0}


def get_normalized_response(n, total):
    """Return the idealised tuning curve for a unit that prefers ``n``.

    Args:
        n: Preferred value of the unit.
        total: Exclusive upper bound; responses are produced for 1..total-1.

    Returns:
        A list with one entry per value ``i`` in ``range(1, total)``, looked
        up in ``model_response`` by ``abs(n - i)`` and 0 when the distance is
        not a key of that table.
    """
    return [model_response.get(abs(n - i), 0) for i in range(1, total)]

In [None]:
# Average every column of `df` within each `num_points` group.
group_means = df.groupby("num_points").mean()

# Min-max scale each column to 0-1; a column that is constant across groups
# divides by zero, becomes NaN and is removed by the dropna below.
column_min = group_means.min()
column_span = group_means.max() - column_min
means = ((group_means - column_min) / column_span).dropna(axis=1)

# Append one idealised tuning curve per preferred value 1..9. Each curve has
# nine entries, so this requires exactly nine `num_points` groups.
for i in range(1, 10):
    means[f"tuning_curve_{i}"] = get_normalized_response(i, 10)

# Rows of `temp` are the columns of `means` (plus the curves); columns are
# the `num_points` values.
temp = means.T

In [None]:
# Persist the normalised responses (and the appended tuning curves) to the working directory.
temp.to_csv("neuron_responses.csv")

In [None]:
# Release the large intermediates; only `temp` and the metadata tables are
# needed from here on. Re-running earlier cells after this one requires
# recomputing `dms` first. (`gc` is imported with the other imports at the top.)
del tasks, result_tuples, dms, data, df
gc.collect()

In [None]:
# Shape of the transposed selection of rows whose maximum sits in column 1.
temp.loc[temp[1] == temp.max(axis=1)].transpose().shape

In [None]:
def top_neurons_tuned_to_i(
    df,
    i,
    method="pearson",
    max_neurons=12000,
    min_correlation=0.9,
    neuron_table=None,
):
    """Rank the rows of ``df`` that peak at column ``i`` by similarity to tuning curve ``i``.

    Args:
        df: Frame whose rows are response profiles (including a row labelled
            ``tuning_curve_{i}``) and whose columns include ``i``.
        i: Column label at which the selected rows must reach their maximum.
        method: Correlation method forwarded to ``DataFrame.corr``.
        max_neurons: Give up when more than this many rows are selected; the
            full pairwise correlation matrix is computed, so cost grows
            quadratically with the selection size.
        min_correlation: Keep only rows whose correlation with the tuning
            curve is strictly greater than this value.
        neuron_table: Frame merged index-to-index with the correlations.
            Defaults to the notebook-level ``neurons_in_coords``.

    Returns:
        ``neuron_table`` rows joined with a ``tuning_curve_{i}`` correlation
        column, sorted by that column in descending order, or the string
        ``"Too many neurons tuned to this curve"`` when the selection exceeds
        ``max_neurons`` (callers test for ``str`` to detect this).

    Raises:
        KeyError: If the ``tuning_curve_{i}`` row is not part of the selection.
    """
    col_name = f"tuning_curve_{i}"
    if neuron_table is None:
        neuron_table = neurons_in_coords

    # Keep the rows whose largest value is in column ``i``; transposing turns
    # each selected row into a column so they can be correlated pairwise.
    peaks_at_i = df[i] == df.max(axis=1)
    tuned = df[peaks_at_i].T
    print(tuned.shape[1])
    if tuned.shape[1] > max_neurons:
        return "Too many neurons tuned to this curve"
    if col_name not in tuned.columns:
        raise KeyError(f"{col_name} is not among the rows peaking at column {i}")

    correlations = tuned.corr(method=method)[col_name].sort_values(ascending=False)
    # The curve correlates perfectly with itself; drop it from the ranking.
    correlations = correlations.drop(col_name)
    top_correlations = correlations[correlations > min_correlation]

    # Index-to-index join: the row labels of ``df`` must match the index of
    # ``neuron_table`` (presumably positional for the default table — verify).
    return neuron_table.merge(
        top_correlations, left_index=True, right_index=True
    ).sort_values(by=col_name, ascending=False)

In [None]:
# Spearman-ranked table (or the "too many" message string) per value 1..9.
tns = {
    f"num_{i}": top_neurons_tuned_to_i(temp, i, method="spearman")
    for i in range(1, 10)
}

In [None]:
# Cell-type counts per entry of `tns`; entries that came back as a message
# string instead of a table are skipped.
vcs = {
    name: tn["cell_type"].value_counts()
    for name, tn in tns.items()
    if not isinstance(tn, str)
}

In [None]:
# Check whether some cell types appear in every entry of `vcs`.
# The running intersection starts from the first available entry instead of a
# hard-coded "num_3", which raised KeyError whenever that entry was skipped.
common = None
for name, vc in vcs.items():
    print(name)
    cell_types = set(vc.index)
    common = cell_types if common is None else common & cell_types
    print(common)
if common is None:
    # No entry produced a table, so there is nothing in common.
    common = set()


In [None]:
# Cell-type counts for the "num_2" table (KeyError if that entry was skipped when building `vcs`).
vcs["num_2"]

In [None]:
# Cell-type counts for the "num_3" table (KeyError if that entry was skipped when building `vcs`).
vcs["num_3"]

In [None]:
# Result of top_neurons_tuned_to_i for i = 3: a ranked table, or the "too many neurons" message string.
tns["num_3"]

In [None]:
# Result of top_neurons_tuned_to_i for i = 4: a ranked table, or the "too many neurons" message string.
tns["num_4"]

In [None]:
# Result of top_neurons_tuned_to_i for i = 5: a ranked table, or the "too many neurons" message string.
tns["num_5"]

In [None]:
# Result of top_neurons_tuned_to_i for i = 6: a ranked table, or the "too many neurons" message string.
tns["num_6"]