In [1]:
import os

import numpy as np
import pandas as pd
import torch
from scipy.sparse import coo_matrix
from torch import autocast, device, cuda
from torch.cuda.amp import GradScaler
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from tqdm import tqdm

from adult_models_helpers import get_synapse_df
from flyvis.examples.flyvision_ans import DECODING_CELLS
from graph_models import GNNModel
from retina_to_connectome import get_activation_tensor, get_batch_voronoi_averages, voronoi_averages_to_df

# Run on the GPU when one is present; the string form is reused later by autocast.
if cuda.is_available():
    device_type = "cuda"
else:
    device_type = "cpu"
DEVICE = device(device_type)

# Last frame index passed to get_activation_tensor when building activations.
last_good_frame = 8

  _C._set_default_tensor_type(t)


In [2]:
# Load the decoding activations exported from flyvis.
activations_dir = "flyvis/parsed_objects"
# NOTE: allow_pickle=True lets np.load execute pickled code; only load trusted files this way.
activations = np.load(os.path.join(activations_dir, "decoding_activations.npy"), allow_pickle=True)

# Placeholder labels: every sample gets label 1 until the real labels
# ("decoding_labels.npy" in the same directory) are used.
labels = torch.ones(activations.shape[0], dtype=torch.long, device=DEVICE)

# Neuron classification table, keeping only the first row seen for each root_id.
classification = pd.read_csv("adult_data/classification.csv").drop_duplicates(subset='root_id')

In [3]:
# For each decoding cell type that appears in the classification table, compute Voronoi
# averages of its activations with one region per classified neuron of that type.
avgs_dict = {}
for cell_type in tqdm(DECODING_CELLS):
    number_of_cells = int((classification["cell_type"] == cell_type).sum())
    if not number_of_cells:
        continue  # type absent from the classification table: nothing to map onto
    # Divide by 255, presumably rescaling 8-bit values to [0, 1] — TODO confirm.
    activation_tensor = get_activation_tensor(activations, cell_type, last_frame=last_good_frame) / 255
    avgs_dict[cell_type] = get_batch_voronoi_averages(activation_tensor, n_centers=number_of_cells)

100%|██████████| 34/34 [00:11<00:00,  3.00it/s]


In [4]:
# Convert the per-cell-type Voronoi averages into one DataFrame; the next cell treats
# its last column as the cell type and the remaining columns as activations.
result_df = voronoi_averages_to_df(avgs_dict)

In [5]:
# Extract cell types and activations
cell_types = result_df.iloc[:, -1]  # Last column for cell type
# NOTE(review): this rebinds `activations`, which earlier held the raw array loaded from
# decoding_activations.npy; re-running the Voronoi-averaging cell after this one would
# pass this DataFrame to get_activation_tensor instead.
activations = result_df.iloc[:, :-1]  # Exclude the last column

# Set to an int to make the shuffled root_id assignment reproducible (None = random each run).
ROOT_ID_SHUFFLE_SEED = None

# Shuffled root_ids for every cell type in the classification table.
root_id_mapping = {
    cell_type: group['root_id'].sample(frac=1, random_state=ROOT_ID_SHUFFLE_SEED).values
    for cell_type, group in classification.groupby("cell_type")
}


def _pick_root_id(cell_type, occurrence):
    """Return the shuffled root_id at position `occurrence` for `cell_type`, wrapping around."""
    root_ids = root_id_mapping[cell_type]
    return root_ids[occurrence % len(root_ids)]


def assign_root_ids(row, occurrence=None):
    """Return a root_id for one row of ``result_df`` (cell type in its last column).

    Args:
        row: a row of ``result_df``; ``row.iloc[-1]`` is its cell type.
        occurrence: 0-based position of the row among the rows of the same cell type.
            When omitted, ``row.name`` (the index label) is used, which is only a
            per-type position if the index restarts at 0 for every cell type.

    Raises:
        KeyError: if the cell type has no entry in ``root_id_mapping``.
    """
    return _pick_root_id(row.iloc[-1], row.name if occurrence is None else occurrence)


# 0-based position of each row among the rows of its own cell type. Unlike row.name, this
# does not depend on how voronoi_averages_to_df indexes its output: with a global 0..N-1
# index, `row.name % len(root_ids)` can give two rows of the same type the same root_id,
# and one of them is then silently removed by drop_duplicates below.
occurrence_in_type = cell_types.groupby(cell_types.to_numpy(), sort=False).cumcount()
result_df['root_id'] = [
    _pick_root_id(cell_type, occurrence)
    for cell_type, occurrence in zip(cell_types, occurrence_in_type)
]

# Remove duplicated root_ids (this only drops rows when a cell type has more rows than
# root_ids, so the assignment above wraps around).
result_df = result_df.drop_duplicates(subset='root_id')


In [6]:
# Attach the Voronoi activations to every classified neuron (left join on root_id).
# result_df.columns[-2] is the cell-type column, because root_id was appended last.
# Neurons without an activation, and any other missing value, become 0.
activation_df = (
    classification
    .drop(columns=["flow", "super_class", "class", "sub_class",
                   "hemibrain_type", "hemilineage", "side", "nerve"])
    .merge(result_df.drop(columns=[result_df.columns[-2]]), on='root_id', how='left')
    .fillna(0)
)

In [7]:
# Load the synapse table; the next cell uses its pre_root_id, post_root_id and syn_count columns.
synapse_df = get_synapse_df()

In [8]:
# Step 1: identify the neurons present in both the activation table and the synapse table.
neurons_merged = pd.unique(activation_df['root_id'])

# Unique root_ids in synapse_df (both pre and post)
neurons_synapse_pre = pd.unique(synapse_df['pre_root_id'])
neurons_synapse_post = pd.unique(synapse_df['post_root_id'])
neurons_synapse = np.unique(np.concatenate([neurons_synapse_pre, neurons_synapse_post]))

# np.intersect1d returns sorted unique values, so graph node i is common_neurons[i].
common_neurons = np.intersect1d(neurons_merged, neurons_synapse)

# Step 2: keep only synapses whose pre- and postsynaptic neurons are both graph nodes.
filtered_synapse_df = synapse_df[
    synapse_df['pre_root_id'].isin(common_neurons) & synapse_df['post_root_id'].isin(common_neurons)
]

# Map neuron root_ids to node (matrix row/column) indices.
root_id_to_index = {root_id: index for index, root_id in enumerate(common_neurons)}

pre_indices = filtered_synapse_df['pre_root_id'].map(root_id_to_index).values
post_indices = filtered_synapse_df['post_root_id'].map(root_id_to_index).values
# syn_count provides the values of the non-zero entries.
data = filtered_synapse_df['syn_count'].values

# Sparse (pre x post) matrix of synapse counts.
synaptic_matrix_sparse = coo_matrix(
    (data, (pre_indices, post_indices)),
    shape=(len(common_neurons), len(common_neurons)),
    dtype=np.int64,  # or np.float32/np.float64 if memory issue persists
)
# COO stores repeated (pre, post) entries separately (e.g. if synapse_df has one row per
# neuropil). Its row/col arrays are used directly as the GNN edge list later, so merge
# repeats into a single entry whose value is the summed syn_count.
synaptic_matrix_sparse.sum_duplicates()

In [9]:
# Keep only neurons that are graph nodes, and order the rows by node index (root_id_to_index)
# so that row i of activation_data is the feature row of node i in synaptic_matrix_sparse
# and the GNN edge list. The left merge above keeps the row order of `classification`,
# which is not necessarily the sorted node order of common_neurons.
activation_df = (
    activation_df[activation_df['root_id'].isin(list(root_id_to_index.keys()))]
    .assign(_node_index=lambda frame: frame['root_id'].map(root_id_to_index))
    .sort_values('_node_index')
    .drop(columns='_node_index')
    .reset_index(drop=True)
)
assert np.array_equal(activation_df['root_id'].to_numpy(), common_neurons), "activation rows must follow node order"
activation_data = activation_df.drop(columns=["root_id", "cell_type"])

In [10]:
# Cell types flagged with rational == 1 in data/cell_type_rational.csv.
cell_type_rational = pd.read_csv("data/cell_type_rational.csv")
rational_cell_types = cell_type_rational.loc[cell_type_rational["rational"] == 1, "cell_type"]

# Node indices (via root_id_to_index) of the neurons belonging to those cell types.
rational_root_ids = activation_df.loc[activation_df["cell_type"].isin(rational_cell_types), "root_id"]
rational_indices = [root_id_to_index[root_id] for root_id in rational_root_ids]

# 0/1 indicator over graph nodes. It is indexed by node index, so it is sized by the
# number of nodes rather than by the row count of activation_df.
decision_making_vector = np.zeros(len(root_id_to_index), dtype=int)
decision_making_vector[rational_indices] = 1

In [11]:
batch_size = 10

# Edge list of shape (2, num_edges): presynaptic node indices, then postsynaptic node indices.
edges = torch.tensor(np.array([synaptic_matrix_sparse.row, synaptic_matrix_sparse.col]), dtype=torch.long, device=DEVICE)
# One row per graph node, one column per sample.
activation_tensor = torch.tensor(activation_data.values, dtype=torch.float16, device=DEVICE)
# Node features and edge indices must describe the same set of nodes.
assert activation_tensor.shape[0] == synaptic_matrix_sparse.shape[0], "node features and graph size disagree"

# Move the decision-making vector to the device. torch.as_tensor also accepts the tensor
# produced by a previous run of this cell, so re-running the cell is safe.
decision_making_vector = torch.as_tensor(decision_making_vector, dtype=torch.float16, device=DEVICE).detach()

# One graph per sample: shared edges, one scalar feature per node, and that sample's label.
graph_list = [
    Data(x=activation_tensor[:, i].unsqueeze(1), edge_index=edges, y=labels[i])
    for i in range(activation_tensor.shape[1])
]

# DataLoader to handle batches of graphs
loader = DataLoader(graph_list, batch_size=batch_size, shuffle=False)

model = GNNModel(num_node_features=1, decision_making_vector=decision_making_vector).to(DEVICE)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
# Loss scaling for mixed precision; without CUDA it disables itself and acts as a pass-through.
scaler = GradScaler()
criterion = BCEWithLogitsLoss()

model.train()
for batch in tqdm(loader):
    batch = batch.to(DEVICE)
    optimizer.zero_grad()

    with autocast(device_type):
        out = model(batch)
        loss = criterion(out, batch.y.unsqueeze(-1).float())
    # Scale the loss so small float16 gradients do not underflow; scaler.step unscales the
    # gradients and skips the update if they contain inf/NaN.
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()


10it [00:00, 10.64it/s]


# TODO
1. Literature review to identify the "thinking" neurons
2. Identify these neurons in the classification dataframe and create a class_labels tensor
3. Train the model with the class_labels tensor
4. Check model accuracy and the Weber ratio
5. Try other model architectures, especially one with a one-hot encoding of each neuron type to simulate different kinds of neurons

Kenyon cells (KCs; intrinsic neurons of the mushroom body): KCab, KCapbp-m, KCapbp-ap1, KCapbp-ap2
T4/T5 neurons: these are involved in motion detection and could possibly be implicated in processing visual information related to numerosity. The relevant types in our list are T4a, T4b, T4c, T4d, T5a, T5b, T5c and T5d.
Central complex neurons: these are involved in a variety of integrative brain functions, which could include decision-making. The candidates from our list are C2 and C3 (NOTE: C2 and C3 are usually described as optic-lobe centrifugal neurons connecting the medulla and the lamina, not as central-complex neurons — verify before treating them as such).
torch.unique(batch.y)  # debug: distinct labels in the last training batch

In [13]:
len(out)  # debug: number of model outputs for the last training batch

1

In [15]:
batch.y  # debug: labels of the last training batch

tensor([1])