In [3]:
import os
import random

import numpy as np
import pandas as pd
import torch
from torch import autocast, device, cuda
from torch.cuda.amp import GradScaler
from torch.nn import BCEWithLogitsLoss
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from tqdm import tqdm

from graph_models import GNNModel


# Run on the GPU when one is available. `device_type` stays a plain string
# because the training cell passes it to `autocast`.
device_type = "cuda" if cuda.is_available() else "cpu"
DEVICE = device(device_type)

# Number of graphs (samples) per DataLoader batch.
batch_size = 10

# NOTE(review): the Voronoi-averaging cell passes the imported
# `LAST_GOOD_FRAME` constant, not this value; no visible cell reads it.
last_good_frame = 8

# Seed every RNG the notebook may touch (model init, training, and any
# shuffling done by project helpers) so Restart & Run All is reproducible.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)  # also seeds all CUDA devices

In [4]:
# --- Load data ---------------------------------------------------------------
# Decoding activations exported from flyvis.
activations_dir = "flyvis/parsed_objects"
activations_path = os.path.join(activations_dir, "decoding_activations.npy")
# NOTE(review): allow_pickle=True can execute code stored in the file; only
# load trusted files this way.
activations = np.load(activations_path, allow_pickle=True)

# Placeholder targets: every sample gets label 1 until real decoding labels
# ("decoding_labels.npy" in the same directory) are wired in.
n_samples = activations.shape[0]
labels = torch.ones(n_samples, dtype=torch.long, device=DEVICE)

# Neuron classification table, keeping the first row per root_id.
classification = pd.read_csv("adult_data/classification.csv")
classification = classification.drop_duplicates(subset="root_id")

In [5]:
from retina_to_connectome_funcs import create_root_id_mapping
from flyvis.examples.flyvision_ans import DECODING_CELLS
from connectome_model_preparation import prepare_connectome_input, compute_voronoi_averages, LAST_GOOD_FRAME, \
    get_synaptic_matrix, get_rational_neurons

# Aggregate the flyvis activations for the DECODING_CELLS types with the
# project helper `compute_voronoi_averages` (per-Voronoi-region averages,
# judging by the name — confirm in connectome_model_preparation). The
# imported LAST_GOOD_FRAME is passed here; the config cell's lowercase
# `last_good_frame` is not used.
result_df = compute_voronoi_averages(
    activations, classification, DECODING_CELLS, last_good_frame=LAST_GOOD_FRAME
)

# Create a dictionary to hold shuffled root_ids for each cell type
# (cell type -> sequence of root_ids; assign_root_ids indexes into it).
# NOTE(review): if the shuffle is random, the root_id assignment changes
# between runs unless the RNG it uses is seeded — confirm.
root_id_mapping = create_root_id_mapping(classification)

def assign_root_ids(row, mapping=None):
    """Pick a connectome root_id for one row of ``result_df``.

    Rows of the same cell type are assigned root_ids round-robin: the row's
    index label (``row.name``) modulo the number of root_ids available for
    that cell type selects the entry. This assumes an integer index.

    Args:
        row: One row of ``result_df`` as passed by ``DataFrame.apply(axis=1)``.
            Its LAST value is taken as the cell type (positional, so the
            cell-type column must be the last column).
        mapping: Dict from cell type to a sequence of root_ids. Defaults to
            the notebook-level ``root_id_mapping``, so existing
            ``result_df.apply(assign_root_ids, axis=1)`` calls are unchanged.

    Returns:
        The selected root_id.

    Raises:
        KeyError: if the row's cell type has no entry in ``mapping``.
        ZeroDivisionError: if that cell type maps to an empty sequence.
    """
    if mapping is None:
        mapping = root_id_mapping
    cell_type = row.iloc[-1]
    root_ids = mapping[cell_type]
    return root_ids[row.name % len(root_ids)]

# Apply the function to result_df, creating a new 'root_id' column
# (round-robin root_id per cell type, see assign_root_ids).
result_df["root_id"] = result_df.apply(assign_root_ids, axis=1)

# Remove duplicated root_ids (pandas keeps the first row for each root_id).
result_df = result_df.drop_duplicates(subset="root_id")

# One row per neuron in `classification`: drop descriptive columns, left-join
# the activation columns on root_id, then replace every remaining NaN
# (neurons with no activations, and any NaN in classification columns) by 0.
# NOTE(review): `result_df.columns[-2]` is positional. Assuming "root_id" was
# newly appended above, it is the column that was last before it, i.e. the
# cell-type column read by assign_root_ids via `row.iloc[-1]`. Confirm this
# still holds if compute_voronoi_averages' output layout changes.
activation_df = pd.merge(
    classification.drop(
        columns=[
            "flow",
            "super_class",
            "class",
            "sub_class",
            "hemibrain_type",
            "hemilineage",
            "side",
            "nerve",
        ]
    ),
    result_df.drop(columns=[result_df.columns[-2]]),
    on="root_id",
    how="left",
).fillna(0)
# Project helper: returns the synaptic matrix (later read via `.row`/`.col`,
# so presumably COO) and a mapping whose keys are the root_ids it contains.
synaptic_matrix_sparse, root_id_to_index = get_synaptic_matrix(activation_df)
# Keep only neurons that appear in the synaptic matrix.
activation_df = activation_df[
    activation_df["root_id"].isin(list(root_id_to_index.keys()))
]
# Node-feature matrix: identifier columns removed. Rows are neurons; the
# graph-building cells treat each column as one sample.
activation_data = activation_df.drop(columns=["root_id", "cell_type"])

# fixme: this should not be here
# Project helper; later cells convert it to a float16 tensor for GNNModel.
decision_making_vector = get_rational_neurons(root_id_to_index, activation_df)

  _C._set_default_tensor_type(t)
100%|██████████| 34/34 [00:10<00:00,  3.21it/s]


In [10]:
from scipy.sparse import load_npz

# Raw synaptic matrix as stored on disk.
# NOTE(review): `temp` is not read by any visible cell; the graph edges come
# from `synaptic_matrix_sparse` (get_synaptic_matrix). Remove if unneeded.
_synaptic_npz_path = "adult_data/synaptic_matrix_sparse.npz"
temp = load_npz(_synaptic_npz_path)

In [ ]:
# Build one PyG graph per sample. All graphs share the same synaptic edge
# list; each graph's node features are one column of `activation_data`.
edges = torch.tensor(
    np.array([synaptic_matrix_sparse.row, synaptic_matrix_sparse.col]),
    dtype=torch.long,
    device=DEVICE,
)  # [2, num_edges]: source/target node indices
activation_tensor = torch.tensor(
    activation_data.values, dtype=torch.float16, device=DEVICE
)  # [num_nodes, num_samples]

# as_tensor (not torch.tensor) keeps this cell idempotent: re-running it, or
# running it after another cell already converted the vector, neither
# re-copies nor warns.
decision_making_vector = torch.as_tensor(
    decision_making_vector, dtype=torch.float16, device=DEVICE
).detach()

assert activation_tensor.shape[1] <= labels.shape[0], (
    "more activation samples than labels"
)

# activation_tensor already lives on DEVICE, so no per-graph .to() is needed.
graph_list = [
    Data(
        x=activation_tensor[:, i].unsqueeze(1),  # [num_nodes, 1]: one feature per node
        edge_index=edges,
        y=labels[i],
    )
    for i in range(activation_tensor.shape[1])
]

# DataLoader to handle batches of graphs (order preserved).
loader = DataLoader(graph_list, batch_size=batch_size, shuffle=False)



In [4]:
# Build the per-sample graphs and train the GNN for one pass over the data.
# `batch_size` comes from the configuration cell.
edges = torch.tensor(
    np.array([synaptic_matrix_sparse.row, synaptic_matrix_sparse.col]),
    dtype=torch.long,
    device=DEVICE,
)  # [2, num_edges]
activation_tensor = torch.tensor(activation_data.values, dtype=torch.float16, device=DEVICE)  # [num_nodes, num_samples]

# as_tensor keeps the cell idempotent when decision_making_vector is already a tensor.
decision_making_vector = torch.as_tensor(
    decision_making_vector, dtype=torch.float16, device=DEVICE
).detach()

assert activation_tensor.shape[1] <= labels.shape[0], "more activation samples than labels"

# One graph per sample (column); tensors are already on DEVICE.
graph_list = [
    Data(x=activation_tensor[:, i].unsqueeze(1), edge_index=edges, y=labels[i])
    for i in range(activation_tensor.shape[1])
]
loader = DataLoader(graph_list, batch_size=batch_size, shuffle=False)

model = GNNModel(num_node_features=1, decision_making_vector=decision_making_vector).to(DEVICE)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
# Loss scaling protects float16 gradients from underflow under autocast. It
# only applies on CUDA; when disabled, scale/step/update reduce to a plain
# backward() + optimizer.step().
# NOTE(review): GradScaler refuses to unscale float16 *parameter* gradients.
# If GNNModel registers decision_making_vector as a trainable Parameter, keep
# it float32.
scaler = GradScaler(enabled=device_type == "cuda")
criterion = BCEWithLogitsLoss()

batch_losses = []  # per-batch training loss, for inspection after the loop
model.train()
for batch in tqdm(loader, desc="train"):
    batch = batch.to(DEVICE)

    optimizer.zero_grad()

    with autocast(device_type):
        out = model(batch)
        loss = criterion(out, batch.y.unsqueeze(-1).float())
    # Backward pass and optimize through the scaler.
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
    batch_losses.append(loss.item())


10it [00:01,  8.03it/s]


# TODO
1. Literature review to identify the "thinking" neurons
2. Identify these neurons in the classification dataframe and create a class_labels tensor
3. Train the model with the class_labels tensor
4. Check model accuracy and the Weber ratio
5. Try other model architectures, especially ones with a one-hot encoding of each neuron type to simulate different neuron types

**Candidate "thinking" neurons:**

- **Kenyon cells (KC):** KCab, KCapbp-m, KCapbp-ap1, KCapbp-ap2.
- **T4/T5 neurons:** involved in motion detection; they could possibly be implicated in processing visual information related to numerosity. Types in our list: T4a, T4b, T4c, T4d, T5a, T5b, T5c, T5d.
- **Central complex neurons:** involved in a variety of integrative brain functions, which could include decision-making. Types in our list: C2, C3. (NOTE: C2 and C3 are usually described as optic-lobe centrifugal neurons rather than central-complex neurons — verify this assignment.)
torch.unique(batch.y)  # sanity check: distinct targets in the last training batch (toy labels are all 1); relies on `batch` leaking from the training loop

In [13]:
# Size of the model output for the last training batch (`out` leaks from the
# training loop). It is 1 here because that batch held a single graph
# (`batch.y` has one element).
len(out)

1

In [15]:
# Targets of the last batch seen by the training loop (leaked loop variable).
batch.y

tensor([1])