In [None]:

import os
import numpy as np
import pandas as pd
from tqdm import tqdm

import torch

from graph_models import GNNModel
from retina_to_connectome import get_activation_tensor, get_batch_voronoi_averages
from adult_models_helpers import get_synapse_df
from utils import flush_cuda_memory
from flyvis.examples.flyvision_ans import DECODING_CELLS

# Run on the GPU when torch can see one, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Debugging aid for CUDA errors.
# NOTE(review): presumably only effective if set before CUDA is initialised — confirm.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
# Passed as `last_frame` to get_activation_tensor when the activations are extracted.
last_good_frame = 8

In [None]:
# Load the decoding activations and the neuron classification table.
activations_dir = "flyvis/parsed_objects"
activations_path = os.path.join(activations_dir, "decoding_activations.npy")
# NOTE(review): allow_pickle=True unpickles objects stored in the file; only load
# files from a trusted source.
activations = np.load(activations_path, allow_pickle=True)

# Read the classification table and remove duplicated root_ids in one step.
classification = pd.read_csv("adult_data/classification.csv").drop_duplicates(subset="root_id")

In [None]:
# Preview the first 20 rows of the de-duplicated classification table.
classification.head(20)

In [None]:
# Output of get_batch_voronoi_averages for every decoding cell type, keyed by cell type.
avgs_dict = {}
for cell_type in tqdm(DECODING_CELLS):
    # Number of classified neurons labelled with this cell type.
    number_of_cells = int((classification["cell_type"] == cell_type).sum())
    if number_of_cells <= 0:
        # Cell type does not occur in the classification table: nothing to average.
        continue
    # Activations are divided by 255 before averaging.
    activation_tensor = get_activation_tensor(activations, cell_type, last_frame=last_good_frame) / 255
    # One Voronoi centre is requested per classified neuron of this type.
    avgs_dict[cell_type] = get_batch_voronoi_averages(activation_tensor, n_centers=number_of_cells)

In [None]:
def voronoi_averages_to_df(dict_with_voronoi_averages):
    """Stack the matrices of a dictionary into one long DataFrame.

    Every value is transposed, so each column of the stored matrix becomes one
    row of the output, and each of those rows is tagged with the dictionary key
    it came from in an ``index_name`` column.

    Args:
        dict_with_voronoi_averages: mapping from a label to a 2-D object that
            supports ``.transpose()`` and can be wrapped in a DataFrame.

    Returns:
        A single DataFrame with the transposed matrices concatenated row-wise
        (in dictionary order) and a fresh 0..N-1 index.
    """
    labelled_frames = [
        pd.DataFrame(values.transpose()).assign(index_name=label)
        for label, values in dict_with_voronoi_averages.items()
    ]

    # Row-wise concatenation; the per-frame indices are discarded.
    return pd.concat(labelled_frames, axis=0, ignore_index=True)

In [None]:
# Stack the per-cell-type matrices into one long DataFrame: one row per column of
# each matrix, tagged with its cell type in the 'index_name' column.
result_df = voronoi_averages_to_df(avgs_dict)

In [None]:
# Extract cell types and activations by column NAME rather than by position, so the
# cell still works when it is re-run after the 'root_id' column has been appended.
# The activation columns get their own name: rebinding `activations` here would
# clobber the raw array loaded from disk that the averaging cell passes to
# get_activation_tensor.
cell_types = result_df['index_name']
voronoi_activations = result_df.drop(columns=['index_name', 'root_id'], errors='ignore')

# Shuffled root_ids for each cell type (the shuffle is unseeded, so the concrete
# assignment differs between runs).
root_id_mapping = {
    cell_type: group['root_id'].sample(frac=1).values
    for cell_type, group in classification.groupby("cell_type")
}

# Position of every row among the rows of its own cell type (0, 1, 2, ... per type).
# Using the global row index instead only yields distinct root_ids when the rows of
# a type are contiguous and exactly as numerous as its root_ids.
within_type_position = cell_types.groupby(cell_types).cumcount()


def assign_root_ids(row):
    """Pick the root_id for one row of ``result_df``.

    Args:
        row: a row of ``result_df``; its ``index_name`` entry is the cell type and
            ``row.name`` is the row's index label.

    Returns:
        The root_id found at this row's within-type position in the shuffled
        root_ids of its cell type. The position wraps around (modulo) if there
        are more rows than root_ids for that type.

    Raises:
        KeyError: if the cell type has no entry in ``root_id_mapping``.
    """
    root_ids = root_id_mapping[row['index_name']]
    return root_ids[within_type_position[row.name] % len(root_ids)]


# Apply the function to result_df, creating (or overwriting) the 'root_id' column
result_df['root_id'] = result_df.apply(assign_root_ids, axis=1)

# Remove duplicated root_ids (only possible when the modulo above wrapped around)
result_df = result_df.drop_duplicates(subset='root_id')


In [None]:
# Left-join the averaged activations onto the classification table by root_id.
# - The listed classification columns are dropped before the join.
# - result_df.columns[-2] is dropped from the right-hand side; with 'root_id' appended
#   last, that is the column sitting just before it.
# - Every missing value in the joined frame (e.g. neurons without an activation row)
#   is filled with 0.
activation_df = (
    classification
    .drop(columns=[
        "flow", "super_class", "class", "sub_class",
        "hemibrain_type", "hemilineage", "side", "nerve",
    ])
    .merge(
        result_df.drop(columns=[result_df.columns[-2]]),
        on='root_id',
        how='left',
    )
    .fillna(0)
)

In [None]:
# Load the synapse table; later cells read its 'pre_root_id', 'post_root_id' and
# 'syn_count' columns.
synapse_df = get_synapse_df()

In [None]:
from scipy.sparse import coo_matrix

# Step 1: find the neurons present in both tables.
# Every root_id that occurs in the synapse table, on either side of a connection.
synapse_root_ids = np.unique(
    np.concatenate([synapse_df['pre_root_id'].values, synapse_df['post_root_id'].values])
)
# Neurons that have an activation row AND occur in the synapse table.
common_neurons = np.intersect1d(activation_df['root_id'].unique(), synapse_root_ids)

# Step 2: keep only synapses whose two endpoints are both common neurons.
pre_is_common = synapse_df['pre_root_id'].isin(common_neurons)
post_is_common = synapse_df['post_root_id'].isin(common_neurons)
filtered_synapse_df = synapse_df[pre_is_common & post_is_common]

# Step 3: translate root_ids into matrix indices (position within common_neurons).
root_id_to_index = dict(zip(common_neurons, range(len(common_neurons))))
pre_indices = filtered_synapse_df['pre_root_id'].map(root_id_to_index).values
post_indices = filtered_synapse_df['post_root_id'].map(root_id_to_index).values

# Step 4: build the sparse matrix in COO format. Rows are indexed by the presynaptic
# neuron, columns by the postsynaptic neuron, and the stored values are syn_count.
n_common = len(common_neurons)
synaptic_matrix_sparse = coo_matrix(
    (filtered_synapse_df['syn_count'].values, (pre_indices, post_indices)),
    shape=(n_common, n_common),
    dtype=np.int64,  # or np.float32/np.float64 if memory issue persists
)

In [None]:
# Keep only the neurons that are nodes of the synaptic graph AND order the rows by
# their matrix index. The sparse matrix (and therefore the edge list built from it)
# addresses neurons by their position in `common_neurons`, whereas activation_df
# still follows the row order of the classification table. Without this reordering,
# row i of the node features is not guaranteed to be the neuron with matrix index i.
node_index = activation_df['root_id'].map(root_id_to_index)  # NaN for neurons outside the graph
activation_df = activation_df.loc[node_index.dropna().sort_values().index]

# Sanity check: exactly one row per graph node, in matrix-index order.
assert np.array_equal(
    activation_df['root_id'].map(root_id_to_index).to_numpy(),
    np.arange(len(root_id_to_index)),
), "activation rows are not aligned with the synaptic matrix indices"

# Numeric activations only: the identifier and label columns are not node features.
activation_data = activation_df.drop(columns=["root_id", "cell_type"])

In [None]:
from torch.cuda.amp import GradScaler
from torch import autocast
from torch.nn import CrossEntropyLoss
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
import torch

device_type = "cuda" if torch.cuda.is_available() else "cpu"
#device_type = "cpu" # for debugging
DEVICE = torch.device(device_type)

batch_size = 10

# Edge list of the synaptic graph, shape [2, num_edges]: first row holds the matrix row
# (presynaptic) indices, second row the matrix column (postsynaptic) indices. The two
# index arrays are stacked into one ndarray first, because building a tensor from a
# Python list of ndarrays is slow.
edges = torch.tensor(
    np.vstack([synaptic_matrix_sparse.row, synaptic_matrix_sparse.col]),
    dtype=torch.long,
    device=DEVICE,
)
# Rows are nodes (neurons); every column becomes one graph below.
# NOTE(review): the features are stored in float16 — confirm the model copes with
# half-precision inputs when running on the CPU.
activation_tensor = torch.tensor(activation_data.values, dtype=torch.float16, device=DEVICE)
num_graphs = activation_tensor.shape[1]

# Random class labels for each graph
# FIXME!!!! placeholder labels; replace with real ones.
# One label per graph: sized from the data instead of a hard-coded 100, so labels and
# graphs cannot get out of step.
class_labels = torch.tensor(
    np.round(np.random.rand(num_graphs)).astype("int"), dtype=torch.long, device=DEVICE
)

# One graph per sample: same connectivity, one scalar feature per node. The label is
# attached as `y`, so the DataLoader batches it together with its graph and no manual
# slicing by batch index is needed.
graph_list = []
for i in range(num_graphs):
    node_features = activation_tensor[:, i].unsqueeze(1)  # Shape [num_nodes, 1]
    graph = Data(x=node_features, edge_index=edges, y=class_labels[i : i + 1])
    graph_list.append(graph)

# DataLoader to handle batches of graphs
loader = DataLoader(graph_list, batch_size=batch_size, shuffle=False)

# Initialize the model, optimizer and loss function
model = GNNModel(num_node_features=1, num_classes=2).to(DEVICE)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
criterion = CrossEntropyLoss()
# Loss scaling for mixed precision. It is disabled off-GPU, where the scaler then
# passes the loss through unchanged and step() simply calls optimizer.step().
scaler = GradScaler(enabled=(device_type == "cuda"))

model.train()
for batch in tqdm(loader):
    batch = batch.to(DEVICE)

    optimizer.zero_grad()

    with autocast(device_type):
        # Assumes the model returns one row of class scores per graph — TODO confirm.
        out = model(batch)
        loss = criterion(out, batch.y)

    # Backward pass and optimize, going through the scaler so that small
    # half-precision gradients do not underflow.
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()


# TODO
1. Literature review to identify the "thinking" neurons
2. Identify these neurons in the classification dataframe and create a class_labels tensor
3. Train the model with the class_labels tensor
4. Check model accuracy and Weber ratio
5. Try with other model architectures, especially with a one-hot encoding for each neuron type to simulate different neurons