In [1]:
import os

# Cap the CUDA caching allocator's block split size; set before torch is imported
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:100'

import pandas as pd
import random
import torch
from torch import device, cuda, autocast
from torch.cuda.amp import GradScaler
from torch.nn import BCEWithLogitsLoss
from tqdm import tqdm

from flyvision_ans import ResponseProcessor, DECODING_CELLS
from from_retina_to_connectome_funcs import from_retina_to_model
from graph_models import GNNModel

# Run on the GPU when one is available, otherwise fall back to the CPU
device_type = "cuda" if cuda.is_available() else "cpu"
DEVICE = device(device_type)

# Seed every RNG this notebook draws from: torch, and Python's `random`,
# which drives the manual shuffle of the samples further down
SEED = 42
torch.manual_seed(SEED)
random.seed(SEED)

batch_size = 10  # forwarded to from_retina_to_model
last_good_frame = 8  # forwarded to from_retina_to_model

  _C._set_default_tensor_type(t)


In [2]:
# Load the table that marks which neurons are used for classification
rational_neurons = pd.read_csv("adult_data/rational_neurons.csv", index_col=0)
# Flatten the table's values into a half-precision tensor on the active device
decision_making_vector = torch.tensor(
    rational_neurons.to_numpy().squeeze(),
    dtype=torch.float16,
    device=DEVICE,
).detach()

In [3]:
# Compute the layer activations for the yellow videos, then for the blue ones
response_processor = ResponseProcessor("very_toy_videos/yellow")
layer_activations_yellow = response_processor.compute_layer_activations()
response_processor = ResponseProcessor("very_toy_videos/blue")
layer_activations_blue = response_processor.compute_layer_activations()

# Labels tensor: 0 for each yellow entry followed by 1 for each blue entry
labels_0 = torch.zeros(len(layer_activations_yellow), dtype=torch.long)
labels_1 = torch.ones(len(layer_activations_blue), dtype=torch.long)
combined_labels = torch.cat([labels_0, labels_1], dim=0)

# Yellow activations first, blue second, matching the label order above
combined_activations = layer_activations_yellow + layer_activations_blue

# Drop the per-colour references and release cached CUDA memory
del layer_activations_yellow
del layer_activations_blue
torch.cuda.empty_cache()

100%|██████████| 20/20 [00:01<00:00, 18.14it/s]
100%|██████████| 20/20 [00:01<00:00, 19.92it/s]


In [4]:
# Shuffle by hand (the dataloader's own shuffle is broken): draw one random
# permutation and apply it to both the activations and their labels
indices = list(range(len(combined_activations)))
random.shuffle(indices)

permutation = torch.tensor(indices)
shuffled_labels = combined_labels[permutation]
shuffled_list = [combined_activations[position] for position in indices]

In [5]:
# Build the loader that the training loop iterates over from the shuffled samples
loader = from_retina_to_model(
    shuffled_list,
    shuffled_labels,
    DECODING_CELLS,
    last_good_frame,
    DEVICE,
    batch_size,
)
# Release unused cached CUDA memory once the loader has been built
torch.cuda.empty_cache()

100%|██████████| 34/34 [00:04<00:00,  8.49it/s]


In [6]:
# Initialize the model and its optimizer
model = GNNModel(num_node_features=1, decision_making_vector=decision_making_vector).to(DEVICE)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
# Gradient scaling is only meaningful for CUDA mixed precision; when disabled the
# scaler passes the loss through and simply calls optimizer.step()
scaler = GradScaler(enabled=(device_type == "cuda"))

# Initialize the loss function (the model output is treated as raw logits)
criterion = BCEWithLogitsLoss()

# Loss of every batch, kept so training progress can be inspected afterwards
batch_losses = []

model.train()
# Wrap the loader itself (not enumerate(loader)) so tqdm can read its length
# and display a total / ETA instead of a bare iteration counter
for batch in tqdm(loader, desc="training"):
    batch = batch.to(DEVICE)
    optimizer.zero_grad()

    # Forward pass and loss under automatic mixed precision
    with autocast(device_type):
        out = model(batch)
        # batch.y gets a trailing axis and a float dtype before being compared to `out`
        loss = criterion(out, batch.y.unsqueeze(-1).float())

    # Backward pass on the (possibly scaled) loss, then optimizer step and scale update
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()

    batch_losses.append(loss.item())

4it [00:03,  1.23it/s]


# TODO

1. Literature review to identify the "thinking" neurons [x]
2. Identify these neurons in the classification dataframe and create a class_labels tensor [x]
3. Train the model with the class_labels tensor [x]
4. Check model accuracy and Weber ratio
5. Try with other model architectures, especially with a one-hot encoding for each neuron type to simulate different neurons