In [1]:
import torch
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
cmap = plt.get_cmap("viridis")
from tqdm.notebook import tqdm
from tqdm.contrib.concurrent import process_map
import multiprocessing

import config
from graph_models import FullGraphModel
from data_processing import DataProcessor
from model_inspection_funcs import sample_images
from no_training import process_image_without_deciding

num_test_pairs = 500
device = torch.device("cpu")
dtype = torch.float32

In [2]:
data_processor = DataProcessor(
    neurons=config.neurons,
    voronoi_criteria=config.voronoi_criteria,
    random_synapses=config.random_synapses,
    log_transform_weights=config.log_transform_weights,
)

model = FullGraphModel(
    input_shape=data_processor.number_of_synapses,
    num_connectome_passes=config.NUM_CONNECTOME_PASSES,
    decision_making_vector=data_processor.decision_making_vector,
    batch_size=config.batch_size,
    dtype=config.dtype,
    edge_weights=data_processor.synaptic_matrix.data,
    device=config.DEVICE,
    num_classes=len(config.CLASSES),
).to(device)

In [3]:
# horrible data stuff
checkpoint = torch.load(
    "models/n_all_v_R7_r_False_lr_0.003_p_4_2024-05-27 21:45.pth", map_location="cpu"
)
model.load_state_dict(checkpoint["model"])
model.eval()
connections = (
    pd.read_csv(
        "adult_data/connections.csv",
        dtype={
            "pre_root_id": "string",
            "post_root_id": "string",
            "syn_count": np.int32,
        },
    )
    .groupby(["pre_root_id", "post_root_id"])
    .sum("syn_count")
    .reset_index()
)

connections["weight"] = model.connectome.edge_weight_multiplier.detach()
right_root_ids = data_processor.root_ids
all_neurons = (
    pd.read_csv("adult_data/classification_clean.csv")
    .merge(right_root_ids, on="root_id")
    .fillna("Unknown")
)
neuron_data = pd.read_csv(
    "adult_data/right_visual_positions_selected_neurons.csv",
    dtype={"root_id": "string"},
).drop(columns=["x", "y", "z", "PC1", "PC2"])
data_cols = ["x_axis", "y_axis"]
all_coords = pd.read_csv("adult_data/all_coords_clean.csv", dtype={"root_id": "string"})
rational_cell_types = pd.read_csv("adult_data/rational_cell_types.csv")
all_neurons["decision_making"] = np.where(
    all_neurons["cell_type"].isin(rational_cell_types["cell_type"].values.tolist()),
    1,
    0,
)
all_neurons["root_id"] = all_neurons["root_id"].astype("string")

neurons_in_coords = all_neurons.merge(all_coords, on="root_id", how="right")[
    ["root_id", "cell_type"]
].fillna("Unknown")

# Set all cell_types with less than "n" samples to "others"
n = 1

counts = neurons_in_coords["cell_type"].value_counts()

small_categories = counts[counts < n].index
neurons_in_coords["cell_type"] = neurons_in_coords["cell_type"].apply(
    lambda x: "others" if x in small_categories else x
)

In [4]:
num_passes = 4
base_dir = "images/one_colour"
sub_dirs = [str(i) for i in range(1, 10)]

sampled_images = sample_images(base_dir, sub_dirs, num_test_pairs)

tasks = [
    (img, neuron_data, connections, all_coords, num_passes)
    for img in sampled_images
]
result_tuples = process_map(
    process_image_without_deciding, tasks, max_workers=multiprocessing.cpu_count() - 2, chunksize=1
)
dms = dict(result_tuples)

  0%|          | 0/3600 [00:00<?, ?it/s]

In [5]:
data = pd.DataFrame(dms)
df = data.T
df["num_points"] = [int(a.split("_")[1]) for a in df.index]

In [9]:
model_response = {0: 1, 1: .5, 2: .4, 3: .2, 4: 0}
def get_normalized_response(n, total):
    acts = []
    for i in range(1, total):
        dist = abs(n - i)
        if dist in model_response.keys():
            acts.append(model_response[dist])
        else:
            acts.append(0)
    
    return acts

In [10]:
means = df.groupby("num_points").mean()
# normalize all columns to 0-1
means = (means - means.min()) / (means.max() - means.min())
# remove columns with missing data
means = means.dropna(axis=1)
for i in range(1, 10):
    means[f"tuning_curve_{i}"] = get_normalized_response(i, 10)
temp = means.T

In [15]:
temp.to_csv("neuron_responses.csv")

In [18]:
import gc

del tasks, result_tuples, dms, data, df
gc.collect()

0

In [28]:
(temp[temp[1] == temp.max(axis=1)]).T.shape

(9, 61359)

In [41]:
def top_neurons_tuned_to_i(df, i, method="pearson"):
    col_name = f"tuning_curve_{i}"
    
    tuned = (df[df[i] == df.max(axis=1)]).T
    print(tuned.shape[1])
    if tuned.shape[1] > 12000:
        return "Too many neurons tuned to this curve"
    correlations = tuned.corr(method)[col_name].sort_values(ascending=False)
    # remove "tune_curve_i" from the list
    correlations = correlations.drop(col_name)
    top_correlations = correlations[correlations > 0.9]

    return neurons_in_coords.merge(
        top_correlations, left_index=True, right_index=True
    ).sort_values(by=col_name, ascending=False)

In [42]:
tns = {}
for i in range(1, 10):
    tns[f"num_{i}"] = top_neurons_tuned_to_i(temp, i, method="spearman")

61359
27569
3359
664
413
501
903
3009
4538


: 

In [39]:
vcs = {}
for name, tn in tns.items():
    if type(tn) != str:
        vcs[name] = tn["cell_type"].value_counts()    

In [40]:
# check if some cell types are in all the top neurons
common = set(vcs["num_3"].index)
for name, vc in vcs.items():
    if type(vc) != str:
        print(name)
        common = common.intersection(set(vc.index))
        print(common)


num_3
{'CB1379', 'CB2974', 'CB1294', 'CB2612', 'CB0559', 'T3', 'CB1955', 'CB1917', 'DNge094', 'CB2771', 'Mi10', 'CB0927', 'CB2376', 'CB1569', 'SLP162b', 'CB1380', 'CB2987', 'L5', 'CB0693', 'Li13', 'Tm16', 'AOTU008d', 'CB3425', 'CB1356', 'LTe63', 'cL13', 'CB2779', 'Dm1', 'CB2298', 'CB2250', 'CB2007', 'Mi1', 'CB2041', 'CB3754', 'CB1896', 'DNge133', 'LAL030c', 'MTe53', 'SMP022a', 'CB2956', 'CB2271', 'LLPC2', 'CB1774', 'DNge081', 'CB1698', 'CB1016', 'DNg17', 'CB2228', 'CB1660', 'LC10a', 'CB0800', 'LTe09', 'CB0955', 'DNge229', 'Tm5c', 'LLPC4', 'CB3439', 'CB3531', 'CB1518', 'CB2550', 'CB1586', 'AVLP312a', 'SMP469c', 'CB2485', 'CB2785', 'CB1505', 'lLN2X12', 'MeTu2', 'PLP191,PLP192b', 'CB1682', 'CB1691', 'CB1618', 'CB0224', 'CB1964', 'CB3550', 'PS004b', 'CB3088', 'DNge044', 'DNge012', 'CB2763', 'CB3011', 'CB3858', 'CB2455', 'CB2665', 'T4a', 'CB3382', 'CB2840', 'CB2582', 'CB1049', 'CB3816', 'CB1783', 'TmY16', 'CB1919', 'CB0496', 'CB1730', 'AOTU050b', 'CB2000', 'CB2505', 'CB2151', 'LTe49b', 'CB1

In [26]:
vcs["num_2"]

Unknown    9
T2a        3
Pm2        2
Tm3        2
CB2846     1
CB3379     1
L5         1
Tm2        1
Dm2        1
L3         1
CB3064     1
CB1415     1
L1         1
Mi9        1
CB0989     1
CB2592     1
CB2995     1
CB1301     1
CB2809     1
Name: cell_type, dtype: int64

In [27]:
vcs["num_3"]

LCe03      1
L1         1
Unknown    1
CB3092     1
CB2453     1
Name: cell_type, dtype: int64

In [28]:
tns["num_3"]

Unnamed: 0,root_id,cell_type,tuning_curve_3
133598,720575940644282528,LCe03,0.920465
68448,720575940624792051,L1,0.919709
50908,720575940621678942,Unknown,0.907401
128216,720575940639686973,CB3092,0.906043
42416,720575940620173733,CB2453,0.901155


In [29]:
tns["num_4"]

Unnamed: 0,root_id,cell_type,tuning_curve_4


In [30]:
tns["num_5"]

Unnamed: 0,root_id,cell_type,tuning_curve_5


In [31]:
tns["num_6"]

Unnamed: 0,root_id,cell_type,tuning_curve_6
55556,720575940622478341,CB3857,0.904525
