In [1]:
import warnings
import pandas as pd
import torch
from torch import device, cuda, autocast
from torch.cuda.amp import GradScaler
from torch.nn import BCEWithLogitsLoss
from tqdm import tqdm
import wandb

import flyvision
from flyvision_ans import DECODING_CELLS
from flyvision.utils.activity_utils import LayerActivity
from from_retina_to_connectome_funcs import from_retina_to_model, get_decision_making_neurons, get_cell_type_indices, compute_accuracy, get_tensor_items
from logs_to_wandb import log_images_to_wandb
from graph_models import GNNModel
from from_video_to_training_batched_funcs import get_files_from_directory, select_random_videos, paths_to_labels, \
    load_custom_sequences

# wandb's image conversion emits a RuntimeWarning about casts; silence only that one.
warnings.filterwarnings(
    action="ignore",
    message="invalid value encountered in cast",
    category=RuntimeWarning,
    module="wandb.sdk.data_types.image",
)

# --- Reproducibility and device ---------------------------------------------
torch.manual_seed(42)
if cuda.is_available():
    device_type = "cuda"
else:
    device_type = "cpu"
DEVICE = device(device_type)

# --- Data locations -----------------------------------------------------------
TRAINING_DATA_DIR = "videos/easy_videos"
TESTING_DATA_DIR = "videos/easyval_videos"

# --- Model / batching -----------------------------------------------------------
NUM_CONNECTOME_PASSES = 16  # forwarded to GNNModel as num_passes
batch_size = 2
last_good_frame = 2  # frame index handed to from_retina_to_model and the image logger

# --- Debugging and logging switches ---------------------------------------------
debugging = True
debug_length = 19  # number of training iterations when `debugging` is True
wandb_ = True  # master switch for Weights & Biases logging
wandb_images_every = 5  # log example images every N training iterations
cell_type_plot = "TmY18"  # cell type passed to the wandb image logger

  _C._set_default_tensor_type(t)


In [2]:
# One-off initialisation: eye renderer, flyvision network, connectome tables
# and the lists of training / validation videos.

# Geometry of the BoxEye renderer.
extent = 15
kernel_size = 13

decision_making_vector = get_decision_making_neurons()

# flyvision objects: the eye that renders frames and the network that is
# simulated on them, initialised from the "best_chkpt" checkpoint.
receptors = flyvision.rendering.BoxEye(extent=extent, kernel_size=kernel_size)
network_view = flyvision.NetworkView(flyvision.results_dir / "opticflow/000/0000")
network = network_view.init_network(chkpt="best_chkpt")
dt = 1 / 100  # time step argument passed to network.simulate

# Connectome lookup tables and the indices derived from them.
classification = pd.read_csv("adult_data/classification_clean.csv")
root_id_to_index = pd.read_csv("adult_data/root_id_to_index.csv")
cell_type_indices = get_cell_type_indices(
    classification, root_id_to_index, DECODING_CELLS
)

# Video files available for training and validation.
all_videos = get_files_from_directory(TRAINING_DATA_DIR)
all_validation_videos = get_files_from_directory(TESTING_DATA_DIR)

In [3]:
# Learning rate; also reported to wandb in the training cell's run config.
lr = 0.01

# Build the connectome GNN, then move it to the selected device.
model = GNNModel(
    num_node_features=1,
    decision_making_vector=decision_making_vector,
    num_passes=NUM_CONNECTOME_PASSES,
    cell_type_indices=cell_type_indices,
    batch_size=batch_size,
    visual_input_persistence_rate=0.8,
)
model = model.to(DEVICE)

# Loss on raw logits (the sigmoid is applied inside the loss).
criterion = BCEWithLogitsLoss()

# Optimiser plus the gradient scaler used with autocast in the training loop.
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
scaler = GradScaler()

In [5]:
# Training loop.
#
# Per iteration: pick a random batch of videos, render them through the eye
# model, simulate the flyvision network on every rendered sequence, convert the
# resulting layer activity into model inputs, and take one optimiser step on
# the GNN under mixed precision.
if wandb_:
    wandb.init(
        project="adult_connectome",
        config={"learning_rate": lr, "batch_size": batch_size}
    )
    data_table = wandb.Table(columns=["Predictions", "True Labels"])
    # Register the gradient/parameter hooks once. Calling wandb.watch on the
    # same model inside the loop re-registers it on every iteration.
    wandb.watch(model, criterion, log="all", log_freq=10)

probabilities = []
accuracies = []
already_selected = []
iterations = debug_length if debugging else len(all_videos) // batch_size

# Nothing in the loop switches the model to eval mode, so set train mode once.
model.train()

for i in tqdm(range(iterations)):
    batch_files, already_selected = select_random_videos(
        all_videos, batch_size, already_selected
    )
    labels = paths_to_labels(batch_files)
    batch_sequences = load_custom_sequences(batch_files)
    rendered_sequences = receptors(batch_sequences)

    layer_activations = []
    for rendered_sequence in rendered_sequences:
        # rendered sequences are in RGB; move it to 0-1 for better training
        rendered_sequence = torch.div(rendered_sequence, 255)
        simulation = network.simulate(rendered_sequence[None], dt)
        layer_activations.append(
            LayerActivity(simulation, network.connectome, keepref=True)
        )

    if wandb_ and i % wandb_images_every == 0:
        log_images_to_wandb(
            batch_sequences[0],
            rendered_sequences[0],
            layer_activations[0],
            batch_files[0],
            frame=last_good_frame,
            cell_type=cell_type_plot,
        )

    # Drop the large intermediates before building the model inputs.
    del rendered_sequences, rendered_sequence, simulation
    torch.cuda.empty_cache()

    inputs, labels = from_retina_to_model(
        layer_activations, labels, DECODING_CELLS, last_good_frame, classification, root_id_to_index
    )
    torch.cuda.empty_cache()

    inputs = inputs.to(DEVICE)
    labels = labels.to(DEVICE).unsqueeze(-1).float()
    optimizer.zero_grad()

    with autocast(device_type):
        out = model(inputs)
        loss = criterion(out, labels)
        # Convert logits to probabilities. Detach so the stored history does
        # not keep every iteration's autograd graph alive.
        prob = torch.sigmoid(out).detach()
        probabilities.append(prob)
        accuracies.append(compute_accuracy(prob, labels))

    if wandb_:
        wandb.log({
            "loss": loss.item(),
            # Running mean accuracy over all iterations so far.
            "acc": sum(accuracies) / len(accuracies)}
        )

        predictions = get_tensor_items(out)
        true_labels = get_tensor_items(labels)
        for pred, label in zip(predictions, true_labels):
            data_table.add_data(pred, label)

    # Scaled backward pass and optimiser step for mixed precision.
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
    # Release gradients right away so they are not held during the next
    # iteration's rendering and simulation.
    optimizer.zero_grad()

if wandb_:
    wandb.log({"predictions_vs_labels": data_table})

VBox(children=(Label(value='0.003 MB of 0.003 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011112039977777183, max=1.0…

100%|██████████| 19/19 [02:47<00:00,  8.84s/it]


In [58]:
# Validation loop.
#
# Runs the same pipeline as training (render -> flyvision simulation -> model
# inputs -> GNN forward pass) on the validation videos, without gradient
# computation, and records the loss of every batch.
already_selected_validation = []
# Collected across all batches; must live outside the loop or it is reset
# on every iteration and only ever holds the last batch's loss.
val_loss = []

model.eval()  # evaluation mode for the whole validation pass

for _ in tqdm(range(len(all_validation_videos) // batch_size)):
    batch_files, already_selected_validation = select_random_videos(
        all_validation_videos, batch_size, already_selected_validation
    )

    labels = paths_to_labels(batch_files)  # Convert paths to labels
    batch_sequences = load_custom_sequences(batch_files)  # Load the video sequences
    rendered_sequences = receptors(batch_sequences)

    layer_activations = []
    for rendered_sequence in rendered_sequences:
        # Same 0-1 scaling as the training loop, so the model is validated on
        # inputs in the range it was trained on.
        rendered_sequence = torch.div(rendered_sequence, 255)
        simulation = network.simulate(rendered_sequence[None], dt)
        layer_activations.append(LayerActivity(simulation, network.connectome, keepref=True))

    # Drop the large intermediates before building the model inputs.
    del rendered_sequences, rendered_sequence, simulation
    torch.cuda.empty_cache()

    # Preparing the data for the model, similar to training
    inputs, labels = from_retina_to_model(
        layer_activations, labels, DECODING_CELLS, last_good_frame, classification, root_id_to_index
    )
    torch.cuda.empty_cache()

    with torch.no_grad():  # Disable gradient computation
        inputs = inputs.to(DEVICE)
        labels = labels.to(DEVICE).unsqueeze(-1).float()

        with autocast(device_type):
            predictions = model(inputs)
            loss = criterion(predictions, labels)

        batch_loss = loss.item()
        val_loss.append(batch_loss)

        # Respect the same logging switch as the training loop.
        if wandb_:
            wandb.log({"validation_loss": batch_loss})
            # Log other metrics similarly

  0%|          | 0/1960 [00:00<?, ?it/s]


RuntimeError: CUDA error: unspecified launch failure
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.


When CUDA breaks (for example with the `unspecified launch failure` above), reload the `nvidia_uvm` kernel module:


In [2]:
# Reload the nvidia_uvm kernel module: remove it, then load it again.
# Both commands run through sudo and therefore prompt for a password
# (see the prompts in this cell's output).
!sudo rmmod nvidia_uvm
!sudo modprobe nvidia_uvm

[sudo] password for eudald: 
[sudo] password for eudald: 