In [1]:
%matplotlib inline
import os
import torch
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
cmap = plt.get_cmap("viridis")
from tqdm.notebook import tqdm
from tqdm.contrib.concurrent import process_map
import multiprocessing
from sklearn.metrics import confusion_matrix
import seaborn as sns

import config
from graph_models import FullGraphModel
from complete_training_data_processing import CompleteModelsDataProcessor
from model_inspection_funcs import sample_images
from no_training import process_image_without_deciding

# Run-wide settings for this notebook.
# `num_test_pairs` is forwarded to sample_images further down.
num_test_pairs = 1000
dtype = torch.float32
device = torch.device("cpu")

In [2]:
# Build the data pipeline from the shared experiment configuration.
processor_kwargs = dict(
    neurons=config.neurons,
    voronoi_criteria=config.voronoi_criteria,
    random_synapses=config.random_synapses,
    log_transform_weights=config.log_transform_weights,
)
data_processor = CompleteModelsDataProcessor(**processor_kwargs)

# The network takes its input size, decision-making vector and edge weights
# from the processor created above.
# NOTE(review): the constructor receives config.DEVICE / config.dtype while the
# instance is then moved to the notebook-level `device` -- confirm these agree.
model_kwargs = dict(
    input_shape=data_processor.number_of_synapses,
    num_connectome_passes=config.NUM_CONNECTOME_PASSES,
    decision_making_vector=data_processor.decision_making_vector,
    batch_size=config.batch_size,
    dtype=config.dtype,
    edge_weights=data_processor.synaptic_matrix.data,
    device=config.DEVICE,
    num_classes=len(config.CLASSES),
)
model = FullGraphModel(**model_kwargs)
model = model.to(device)

In [3]:
# Restore the saved weights (mapped onto the CPU) and switch the model to
# evaluation mode.
# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
checkpoint_path = "models/n_all_v_R7_r_False_lr_0.003_p_4_2024-05-27 12:45.pth"
checkpoint = torch.load(checkpoint_path, map_location="cpu")
model.load_state_dict(checkpoint["model"])
model.eval()
# Load the connection table and collapse it to one row per
# (pre_root_id, post_root_id) pair by summing its numeric columns.
# Root ids are read as strings; syn_count as 32-bit integers.
connections = (
    pd.read_csv(
        "adult_data/connections.csv",
        dtype={
            "pre_root_id": "string",
            "post_root_id": "string",
            "syn_count": np.int32,
        },
    )
    .groupby(["pre_root_id", "post_root_id"])
    # The first positional parameter of GroupBy.sum is `numeric_only`, so the
    # earlier `.sum("syn_count")` only worked because a non-empty string is
    # truthy. Pass the flag explicitly so the call states what it does.
    .sum(numeric_only=True)
    .reset_index()
)

# Attach the model's edge-weight multiplier as a `weight` column.
# NOTE(review): assumes the multiplier has one entry per row of `connections`,
# in the same row order -- confirm against the data processor.
connections["weight"] = model.connectome.edge_weight_multiplier.detach()

right_root_ids = data_processor.right_root_ids

# Classification table merged with right_root_ids on root_id (default inner
# merge); missing values are replaced by "Unknown".
classification = pd.read_csv("adult_data/classification_clean.csv")
all_neurons = classification.merge(right_root_ids, on="root_id").fillna("Unknown")

# Selected-neuron positions without the x/y/z and PC1/PC2 columns.
selected_positions = pd.read_csv(
    "adult_data/right_visual_positions_selected_neurons.csv",
    dtype={"root_id": "string"},
)
neuron_data = selected_positions.drop(columns=["x", "y", "z", "PC1", "PC2"])
data_cols = ["x_axis", "y_axis"]

all_coords = pd.read_csv(
    "adult_data/all_coords_clean.csv",
    dtype={"root_id": "string"},
)
rational_cell_types = pd.read_csv("adult_data/rational_cell_types.csv")
# Flag each neuron with 1 when its cell type is listed in rational_cell_types,
# otherwise 0.
decision_cell_types = rational_cell_types["cell_type"].values.tolist()
is_decision_making = all_neurons["cell_type"].isin(decision_cell_types)
all_neurons["decision_making"] = np.where(is_decision_making, 1, 0)

# root_id must be a string to match the dtype used when reading all_coords.
all_neurons["root_id"] = all_neurons["root_id"].astype("string")

# Right merge: every row of all_coords is kept; entries left empty by the merge
# are labelled "Unknown".
merged_with_coords = all_neurons.merge(all_coords, on="root_id", how="right")
neurons_in_coords = merged_with_coords[["root_id", "cell_type"]].fillna("Unknown")

# Relabel every cell type that occurs fewer than `n` times as "others".
# value_counts only lists values that occur in a plain (non-categorical)
# column, so with n = 1 nothing is relabelled; raise n to merge rare types.
n = 1

counts = neurons_in_coords["cell_type"].value_counts()
small_categories = counts[counts < n].index

is_small_category = neurons_in_coords["cell_type"].isin(small_categories)
neurons_in_coords["cell_type"] = neurons_in_coords["cell_type"].mask(
    is_small_category, "others"
)

In [4]:
# Sampling and per-image processing settings for this run.
num_passes = 4
base_dir = "images/zero_to_ten/train"
sub_dirs = ["yellow", "blue"]

sampled_images = sample_images(base_dir, sub_dirs, num_test_pairs)

# One argument tuple per sampled image; every tuple shares the same tables.
tasks = [
    (img, neuron_data, connections, all_coords, num_passes)
    for img in sampled_images
]

# Keep two cores free, but never request fewer than one worker: on a machine
# with one or two cores `cpu_count() - 2` is <= 0 and the process pool refuses
# to start.
num_workers = max(1, multiprocessing.cpu_count() - 2)
result_tuples = process_map(
    process_image_without_deciding, tasks, max_workers=num_workers, chunksize=1
)
# dict() requires every result to be a (key, value) pair.
dms = dict(result_tuples)

  0%|          | 0/2000 [00:00<?, ?it/s]

In [None]:
# One column per key of `dms`, then transposed so each key becomes a row label.
data = pd.DataFrame(dms)
df = data.T


def _points_in_label(label):
    """Sum the integers in the 2nd and 3rd "_"-separated fields of ``label``."""
    fields = label.split("_")
    return int(fields[1]) + int(fields[2])


df["num_points"] = [_points_in_label(label) for label in df.index]

In [None]:
# Idealised response keyed by distance from the preferred number:
# 1 at distance 0, falling to 0 at distance 4.
model_response = {0: 1, 1: 0.5, 2: 0.4, 3: 0.2, 4: 0}


def get_normalized_response(n, total):
    """Return the idealised response curve for preferred number ``n``.

    The result has ``total`` entries. Entry ``k`` (1-based) is
    ``model_response[abs(n - k)]``, or 0 when that distance has no entry in
    ``model_response``.
    """
    return [
        model_response.get(abs(n - position), 0)
        for position in range(1, total + 1)
    ]

[1, 0.5, 0.4, 0.2, 0, 0, 0, 0, 0, 0]

In [None]:
# Average every column within each num_points group.
group_means = df.groupby("num_points").mean()

# Min-max scale each column to the 0-1 range, then drop every column that
# contains a missing value (a constant column yields 0/0 = NaN here).
column_min = group_means.min()
column_span = group_means.max() - column_min
means = ((group_means - column_min) / column_span).dropna(axis=1)

# Append the ten idealised curves as extra columns.
# NOTE(review): each curve has 10 entries, so this assumes `means` has exactly
# 10 rows -- confirm against the num_points values present in `df`.
for i in range(1, 11):
    means[f"tuning_curve_{i}"] = get_normalized_response(i, 10)

temp = means.T

In [None]:
def top_neurons_tuned_to_i(df, i):
    """Find the profiles that peak at ``i`` and follow the reference curve.

    Parameters
    ----------
    df : pandas.DataFrame
        One row per response profile and one column per number. It must
        contain a column labelled ``i`` and a row labelled
        ``tuning_curve_{i}`` holding the reference curve.
    i : int
        The number (column label) whose tuned profiles are wanted.

    Returns
    -------
    pandas.DataFrame or str
        ``neurons_in_coords`` joined on its index with the correlations that
        exceed 0.9, sorted from highest to lowest correlation. When more than
        10000 profiles peak at ``i`` the message string
        ``"Too many neurons tuned to this curve"`` is returned instead.
    """
    col_name = f"tuning_curve_{i}"

    # Keep the rows whose maximum sits in column ``i``, then transpose so that
    # every kept profile becomes a column.
    tuned = df[df[i] == df.max(axis=1)].T

    # After the transpose the profiles are the columns. len(tuned) counts the
    # rows (one per number), so the previous guard could never trigger.
    if tuned.shape[1] > 10000:
        return "Too many neurons tuned to this curve"

    # Correlate each profile with the reference curve only; a full pairwise
    # corr() matrix grows quadratically with the number of profiles.
    # corrwith returns an unnamed Series, and the merge and sort below need
    # the column name.
    correlations = (
        tuned.corrwith(tuned[col_name])
        .rename(col_name)
        .sort_values(ascending=False)
    )
    # The reference curve correlates perfectly with itself; leave it out.
    correlations = correlations.drop(col_name)
    top_correlations = correlations[correlations > 0.9]

    return neurons_in_coords.merge(
        top_correlations, left_index=True, right_index=True
    ).sort_values(by=col_name, ascending=False)

In [32]:
# One result per number 1..10, keyed "num_1" ... "num_10".
tns = {
    f"num_{number}": top_neurons_tuned_to_i(temp, number)
    for number in range(1, 11)
}

: 

In [17]:
# Display the dictionary of per-number results.
tns

{}

In [68]:
# Shape of every result stored in `tns`, in insertion order.
[result.shape for result in tns.values()]

[(2614, 3), (55, 3), (1443, 3), (618, 3)]

In [73]:
# Cell-type frequency tables for the results of numbers 1 to 4.
vc1, vc2, vc3, vc4 = (
    tns[f"num_{number}"]["cell_type"].value_counts() for number in range(1, 5)
)

In [74]:
# Display the cell-type counts computed from tns["num_1"].
vc1

Unknown      779
Mi15         116
T2a           70
Tm1           48
Tm20          43
            ... 
SMP404b        1
DNge044        1
CB2171         1
CB.FB5D10      1
Lawf1          1
Name: cell_type, Length: 552, dtype: int64

In [82]:
# Display the result stored under "num_1".
tns["num_1"]

Unnamed: 0,root_id,cell_type,tuning_curve_1
43113,720575940620321728,Unknown,0.998762
60712,720575940623381956,Unknown,0.998254
68576,720575940624814333,Tm5c,0.997962
21669,720575940614778535,Tm5c,0.997962
49962,720575940621529435,Unknown,0.997559
...,...,...,...
4025,720575940606915714,Tm20,0.900063
105763,720575940631711563,CB3851,0.900051
18284,720575940613686698,Unknown,0.900048
117907,720575940635248223,Unknown,0.900038
