In [1]:
import warnings
import pandas as pd
import torch
from torch import device, cuda, autocast
from torch.cuda.amp import GradScaler
from torch.nn import BCEWithLogitsLoss
from tqdm import tqdm
import wandb
from random import sample
from scipy.sparse import load_npz

import flyvision
from flyvision_ans import DECODING_CELLS
from flyvision.utils.activity_utils import LayerActivity
from from_image_to_video import image_paths_to_sequences
from from_retina_to_connectome_funcs import from_retina_to_model, get_cell_type_indices, compute_voronoi_averages, from_retina_to_connectome
from logs_to_wandb import log_images_to_wandb, log_running_stats_to_wandb
from from_video_to_training_batched_funcs import get_files_from_directory, select_random_videos, paths_to_labels, \
    load_custom_sequences
from from_retina_to_connectome_utils import (
    hex_to_square_grid,
    initialize_results_df,
    predictions_and_corrects_from_model_results,
    update_results_df,
    update_running_loss,
    get_decision_making_neurons,
    vector_to_one_hot,
)
from adult_models import FullAdultModel


warnings.filterwarnings(
    'ignore',
    message='invalid value encountered in cast',
    category=RuntimeWarning,
    module='wandb.sdk.data_types.image'
)

torch.manual_seed(1234)
dtype = torch.float32

device_type = "cuda" if cuda.is_available() else "cpu"
# device_type = "cpu"
DEVICE = device(device_type)
sparse_layout = torch.sparse_coo

TRAINING_DATA_DIR = "images/easy_v2"
TESTING_DATA_DIR = "images/easy_images"
VALIDATION_DATA_DIR = "images/easyval_images"

debugging = False
debug_length = 1000
validation_length = 50
wandb_ = True
wandb_images_every = 100
small = True
small_length = 2000

num_epochs = 1
batch_size = 1

dropout = .1
max_lr = 0.01
base_lr = 0.0001
weight_decay = 0.0001
NUM_CONNECTOME_PASSES = 8

use_one_cycle_lr = False

model_config = {
    "debugging": debugging,
    "num_epochs": num_epochs,
    "batch_size": batch_size,
    "dropout": dropout,
    "base_lr": base_lr,
    "max_lr": max_lr,
    "weight_decay": weight_decay,
    "num_connectome_passes": NUM_CONNECTOME_PASSES,
}

  _C._set_default_tensor_type(t)


In [2]:
# init stuff
extent, kernel_size = 15, 13
decision_making_vector = get_decision_making_neurons(dtype)
receptors = flyvision.rendering.BoxEye(extent=extent, kernel_size=kernel_size)
network_view = flyvision.NetworkView(flyvision.results_dir / "opticflow/000/0000")
network = network_view.init_network(chkpt="best_chkpt")
classification = pd.read_csv("adult_data/classification_clean.csv")
root_id_to_index = pd.read_csv("adult_data/root_id_to_index.csv")
dt = 1 / 100 # some parameter from flyvision
last_good_frame = 2
cell_type_plot = "TmY18"

cell_type_indices = get_cell_type_indices(classification, root_id_to_index, DECODING_CELLS)

training_videos = get_files_from_directory(TRAINING_DATA_DIR)
test_videos = get_files_from_directory(TESTING_DATA_DIR)
validation_videos = get_files_from_directory(TESTING_DATA_DIR)

if small:
    training_videos = sample(training_videos, small_length)
    test_videos = sample(test_videos, small_length)
    validation_videos = sample(validation_videos, int(small_length / 5))

if len(training_videos) == 0:
    print("I can't find any training images or videos!")

In [3]:
synaptic_matrix = load_npz("adult_data/synaptic_matrix_sparse.npz")
one_hot_decision_making = vector_to_one_hot(
    decision_making_vector, dtype, sparse_layout
).to(DEVICE)

model = FullAdultModel(
    synaptic_matrix,
    one_hot_decision_making,
    NUM_CONNECTOME_PASSES,
    dtype=dtype,
).to(DEVICE)

optimizer = torch.optim.Adam(model.parameters(), lr=base_lr)
scaler = GradScaler()

# Initialize the loss function
criterion = BCEWithLogitsLoss()

In [4]:
if wandb_:
    wandb.init(project="adult_connectome", config=model_config)

model.train()

results = initialize_results_df()
probabilities = []
accuracies = []
already_selected = []
running_loss = 0.0
total_correct = 0
total = 0

iterations = debug_length if debugging else len(training_videos) // batch_size

for i in tqdm(range(iterations)):
    batch_files, already_selected = select_random_videos(
        training_videos, batch_size, already_selected
    )
    labels = paths_to_labels(batch_files)
    batch_sequences = image_paths_to_sequences(batch_files)
    rendered_sequences = receptors(batch_sequences)

    layer_activations = []
    for rendered_sequence in rendered_sequences:
        # rendered sequences are in RGB; move it to 0-1 for better training
        rendered_sequence = torch.div(rendered_sequence, 255)
        simulation = network.simulate(rendered_sequence[None], dt)
        layer_activations.append(
            LayerActivity(simulation, network.connectome, keepref=True)
        )

    if wandb_ and i % wandb_images_every == 0:
        la_0 = hex_to_square_grid(
            layer_activations[0][cell_type_plot].squeeze()[-last_good_frame].cpu().numpy()
            ),
        log_images_to_wandb(batch_sequences[0], rendered_sequences[0], la_0, batch_files[0], frame=last_good_frame, cell_type=cell_type_plot)

    voronoi_averages_df = compute_voronoi_averages(
        layer_activations, classification, DECODING_CELLS, last_good_frame=last_good_frame
    )
    # normalize column wise (except last column)
    values_cols = voronoi_averages_df.columns != "index_name"
    voronoi_averages_df.loc[:, values_cols] = voronoi_averages_df.loc[:, values_cols].apply(
        lambda x: (x - x.min()) / (x.max() - x.min()), axis=0
        )

    activation_df = from_retina_to_connectome(
        voronoi_averages_df, classification, root_id_to_index
    )
    del layer_activations, rendered_sequences, rendered_sequence, simulation
    torch.cuda.empty_cache()

    optimizer.zero_grad()

    inputs = torch.tensor(activation_df.values, dtype=dtype, device=DEVICE).to_sparse(
        layout=sparse_layout
    )
    labels = torch.tensor(labels, dtype=dtype, device=DEVICE)

    out = model(inputs)
    loss = criterion(out, labels)
    # scaler.scale(loss).backward()
    # scaler.step(optimizer)
    # scaler.update()
    loss.backward()
    optimizer.step()

    # Calculate run parameters
    predictions, labels_cpu, correct = predictions_and_corrects_from_model_results(out, labels)
    results = update_results_df(results, batch_files, predictions, labels_cpu, correct)
    running_loss += update_running_loss(loss, inputs)
    total += labels.shape[0]
    total_correct += correct.sum()

    if wandb_:
        log_running_stats_to_wandb(0, i, running_loss, total_correct, total, results)

print(f"Finished training with loss {running_loss / total} and accuracy {total_correct / total}")
torch.cuda.empty_cache()

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33meudald[0m. Use [1m`wandb login --relogin`[0m to force relogin


100%|██████████| 2000/2000 [3:01:46<00:00,  5.45s/it]  

Finished training with loss 5.30792955799e+28 and accuracy 0.475





In [10]:
already_selected_validation = []
validation_results = initialize_results_df()
validation_iterations = validation_length if validation_length is not None else len(validation_videos) // batch_size
for _ in tqdm(range(validation_iterations)):
    batch_files, already_selected_validation = select_random_videos(validation_videos, batch_size, already_selected_validation)

    labels = paths_to_labels(batch_files)
    batch_sequences = image_paths_to_sequences(batch_files)
    rendered_sequences = receptors(batch_sequences)
    layer_activations = []
    for rendered_sequence in rendered_sequences:
        simulation = network.simulate(rendered_sequence[None], dt)
        layer_activations.append(LayerActivity(simulation, network.connectome, keepref=True))

    voronoi_averages_df = compute_voronoi_averages(
        layer_activations, classification, DECODING_CELLS, last_good_frame=last_good_frame
    )
    # normalize column wise (except last column)
    values_cols = voronoi_averages_df.columns != "index_name"
    voronoi_averages_df.loc[:, values_cols] = voronoi_averages_df.loc[:, values_cols].apply(
        lambda x: (x - x.min()) / (x.max() - x.min()), axis=0
        )

    activation_df = from_retina_to_connectome(
        voronoi_averages_df, classification, root_id_to_index
    )
    del layer_activations, rendered_sequences, rendered_sequence, simulation
    torch.cuda.empty_cache()

    model.eval() 

    total_correct = 0
    total = 0
    running_loss = 0.0

    with torch.no_grad():
        inputs = torch.tensor(activation_df.values, dtype=dtype, device=DEVICE).to_sparse(
            layout=sparse_layout
        )
        labels = torch.tensor(labels, dtype=dtype, device=DEVICE)

        outputs = model(inputs)
        loss = criterion(outputs, labels)

        predictions, batch_labels_cpu, correct = predictions_and_corrects_from_model_results(outputs, labels)
        validation_results = update_results_df(validation_results, batch_files, predictions, batch_labels_cpu, correct)
        running_loss += update_running_loss(loss, inputs)
        total += batch_labels_cpu.shape[0]
        total_correct += correct.sum().item()

print(
    f"Validation Loss: {running_loss / total}, "
    f"Validation Accuracy: {total_correct / total}"
)

  2%|▏         | 1/50 [00:03<02:55,  3.58s/it]

Validation Loss: 1.4639311281903427e+29, Validation Accuracy: 0.0


  4%|▍         | 2/50 [00:07<02:51,  3.57s/it]

Validation Loss: 1.444757998579572e+29, Validation Accuracy: 0.0


  6%|▌         | 3/50 [00:10<02:47,  3.56s/it]

Validation Loss: 1.3978831607726631e+29, Validation Accuracy: 0.0


  8%|▊         | 4/50 [00:14<02:43,  3.56s/it]

Validation Loss: 1.3787176700515676e+29, Validation Accuracy: 0.0


 10%|█         | 5/50 [00:17<02:39,  3.54s/it]

Validation Loss: 1.4185685971495152e+29, Validation Accuracy: 0.0


 12%|█▏        | 6/50 [00:21<02:36,  3.56s/it]

Validation Loss: 0.0, Validation Accuracy: 1.0


 14%|█▍        | 7/50 [00:24<02:33,  3.56s/it]

Validation Loss: 1.4385053256433898e+29, Validation Accuracy: 0.0


 16%|█▌        | 8/50 [00:28<02:29,  3.57s/it]

Validation Loss: 0.0, Validation Accuracy: 1.0


 18%|█▊        | 9/50 [00:32<02:26,  3.58s/it]

Validation Loss: 0.0, Validation Accuracy: 1.0


 20%|██        | 10/50 [00:35<02:22,  3.57s/it]

Validation Loss: 0.0, Validation Accuracy: 1.0


 22%|██▏       | 11/50 [00:39<02:19,  3.58s/it]

Validation Loss: 0.0, Validation Accuracy: 1.0


 24%|██▍       | 12/50 [00:42<02:15,  3.56s/it]

Validation Loss: 0.0, Validation Accuracy: 1.0


 26%|██▌       | 13/50 [00:46<02:10,  3.54s/it]

Validation Loss: 0.0, Validation Accuracy: 1.0


 28%|██▊       | 14/50 [00:49<02:08,  3.56s/it]

Validation Loss: 1.4814516468715036e+29, Validation Accuracy: 0.0


 30%|███       | 15/50 [00:53<02:04,  3.56s/it]

Validation Loss: 0.0, Validation Accuracy: 1.0


 32%|███▏      | 16/50 [00:57<02:01,  3.56s/it]

Validation Loss: 0.0, Validation Accuracy: 1.0


 34%|███▍      | 17/50 [01:00<01:57,  3.55s/it]

Validation Loss: 0.0, Validation Accuracy: 1.0


 36%|███▌      | 18/50 [01:04<01:52,  3.53s/it]

Validation Loss: 0.0, Validation Accuracy: 1.0


 38%|███▊      | 19/50 [01:07<01:50,  3.56s/it]

Validation Loss: 1.4263331900725547e+29, Validation Accuracy: 0.0


 40%|████      | 20/50 [01:11<01:46,  3.56s/it]

Validation Loss: 0.0, Validation Accuracy: 1.0


 42%|████▏     | 21/50 [01:14<01:43,  3.56s/it]

Validation Loss: 1.4435393539393594e+29, Validation Accuracy: 0.0


 44%|████▍     | 22/50 [01:18<01:40,  3.58s/it]

Validation Loss: 0.0, Validation Accuracy: 1.0


 46%|████▌     | 23/50 [01:21<01:36,  3.56s/it]

Validation Loss: 0.0, Validation Accuracy: 1.0


 48%|████▊     | 24/50 [01:25<01:32,  3.57s/it]

Validation Loss: 0.0, Validation Accuracy: 1.0


 50%|█████     | 25/50 [01:29<01:29,  3.57s/it]

Validation Loss: 1.4123818700710356e+29, Validation Accuracy: 0.0


 52%|█████▏    | 26/50 [01:32<01:25,  3.56s/it]

Validation Loss: 1.4473418771859164e+29, Validation Accuracy: 0.0


 54%|█████▍    | 27/50 [01:36<01:21,  3.55s/it]

Validation Loss: 1.442428330618242e+29, Validation Accuracy: 0.0


 56%|█████▌    | 28/50 [01:39<01:18,  3.55s/it]

Validation Loss: 1.4154992139221672e+29, Validation Accuracy: 0.0


 58%|█████▊    | 29/50 [01:43<01:14,  3.55s/it]

Validation Loss: 1.4577770839562963e+29, Validation Accuracy: 0.0


 60%|██████    | 30/50 [01:46<01:10,  3.54s/it]

Validation Loss: 0.0, Validation Accuracy: 1.0


 62%|██████▏   | 31/50 [01:50<01:07,  3.54s/it]

Validation Loss: 0.0, Validation Accuracy: 1.0


 64%|██████▍   | 32/50 [01:53<01:04,  3.57s/it]

Validation Loss: 1.448585179001663e+29, Validation Accuracy: 0.0


 66%|██████▌   | 33/50 [01:57<01:00,  3.57s/it]

Validation Loss: 0.0, Validation Accuracy: 1.0


 68%|██████▊   | 34/50 [02:01<00:56,  3.55s/it]

Validation Loss: 1.3709921418301583e+29, Validation Accuracy: 0.0


 70%|███████   | 35/50 [02:04<00:53,  3.55s/it]

Validation Loss: 0.0, Validation Accuracy: 1.0


 72%|███████▏  | 36/50 [02:08<00:49,  3.55s/it]

Validation Loss: 1.4142123027489078e+29, Validation Accuracy: 0.0


 74%|███████▍  | 37/50 [02:11<00:46,  3.56s/it]

Validation Loss: 1.4200233705060174e+29, Validation Accuracy: 0.0


 76%|███████▌  | 38/50 [02:15<00:42,  3.56s/it]

Validation Loss: 0.0, Validation Accuracy: 1.0


 78%|███████▊  | 39/50 [02:18<00:39,  3.56s/it]

Validation Loss: 1.4220221488411625e+29, Validation Accuracy: 0.0


 80%|████████  | 40/50 [02:22<00:35,  3.57s/it]

Validation Loss: 1.4622584047410816e+29, Validation Accuracy: 0.0


 82%|████████▏ | 41/50 [02:25<00:32,  3.57s/it]

Validation Loss: 0.0, Validation Accuracy: 1.0


 84%|████████▍ | 42/50 [02:29<00:28,  3.55s/it]

Validation Loss: 1.442885600355889e+29, Validation Accuracy: 0.0


 86%|████████▌ | 43/50 [02:33<00:24,  3.56s/it]

Validation Loss: 1.4251050693413527e+29, Validation Accuracy: 0.0


 88%|████████▊ | 44/50 [02:36<00:21,  3.54s/it]

Validation Loss: 1.354377266702138e+29, Validation Accuracy: 0.0


 90%|█████████ | 45/50 [02:40<00:17,  3.55s/it]

Validation Loss: 1.4410740231651899e+29, Validation Accuracy: 0.0


 92%|█████████▏| 46/50 [02:43<00:14,  3.55s/it]

Validation Loss: 0.0, Validation Accuracy: 1.0


 94%|█████████▍| 47/50 [02:47<00:10,  3.55s/it]

Validation Loss: 0.0, Validation Accuracy: 1.0


 96%|█████████▌| 48/50 [02:50<00:07,  3.55s/it]

Validation Loss: 1.4535989980831753e+29, Validation Accuracy: 0.0


 98%|█████████▊| 49/50 [02:54<00:03,  3.56s/it]

Validation Loss: 1.435180281348937e+29, Validation Accuracy: 0.0


100%|██████████| 50/50 [02:57<00:00,  3.56s/it]

Validation Loss: 0.0, Validation Accuracy: 1.0





In [11]:
validation_results

Unnamed: 0,Image,Prediction,True label,Is correct
0,images/easy_images/yellow/img2_15_10_26.png,1.0,1.0,1


When cuda breaks:


In [4]:
!sudo rmmod nvidia_uvm
!sudo modprobe nvidia_uvm

[sudo] password for eudald: 
[sudo] password for eudald: 


In [4]:
from pytorch_model_summary import summary

print(summary(model, torch.zeros(inputs.shape), show_input=True))

NameError: name 'inputs' is not defined

In [6]:
# Correctly print number of non-zero entries
print("Number of non-zero entries (trainable parameters):", model.connectome.shared_weights.numel())
print("Shape of the sparse tensor:", model.connectome.shared_weights.shape)
print("Number of parameters (non-zero entries):", model.connectome.shared_weights._nnz())


Number of non-zero entries (trainable parameters): 3871467
Shape of the sparse tensor: torch.Size([3871467])


NotImplementedError: Could not run 'aten::_nnz' with arguments from the 'CPU' backend. This could be because the operator doesn't exist for this backend, or was omitted during the selective/custom build process (if using custom build). If you are a Facebook employee using PyTorch on mobile, please visit https://fburl.com/ptmfixes for possible resolutions. 'aten::_nnz' is only available for these backends: [Meta, SparseCPU, SparseCUDA, SparseMeta, SparseCsrCPU, SparseCsrCUDA, BackendSelect, Python, FuncTorchDynamicLayerBackMode, Functionalize, Named, Conjugate, Negative, ZeroTensor, ADInplaceOrView, AutogradOther, AutogradCPU, AutogradCUDA, AutogradHIP, AutogradXLA, AutogradMPS, AutogradIPU, AutogradXPU, AutogradHPU, AutogradVE, AutogradLazy, AutogradMTIA, AutogradPrivateUse1, AutogradPrivateUse2, AutogradPrivateUse3, AutogradMeta, AutogradNestedTensor, Tracer, AutocastCPU, AutocastCUDA, FuncTorchBatched, BatchedNestedTensor, FuncTorchVmapMode, Batched, VmapMode, FuncTorchGradWrapper, PythonTLSSnapshot, FuncTorchDynamicLayerFrontMode, PreDispatch, PythonDispatcher].

Meta: registered at ../aten/src/ATen/core/MetaFallbackKernel.cpp:23 [backend fallback]
SparseCPU: registered at aten/src/ATen/RegisterSparseCPU.cpp:1387 [kernel]
SparseCUDA: registered at aten/src/ATen/RegisterSparseCUDA.cpp:1573 [kernel]
SparseMeta: registered at aten/src/ATen/RegisterSparseMeta.cpp:249 [kernel]
SparseCsrCPU: registered at aten/src/ATen/RegisterSparseCsrCPU.cpp:1135 [kernel]
SparseCsrCUDA: registered at aten/src/ATen/RegisterSparseCsrCUDA.cpp:1276 [kernel]
BackendSelect: fallthrough registered at ../aten/src/ATen/core/BackendSelectFallbackKernel.cpp:3 [backend fallback]
Python: registered at ../aten/src/ATen/core/PythonFallbackKernel.cpp:154 [backend fallback]
FuncTorchDynamicLayerBackMode: registered at ../aten/src/ATen/functorch/DynamicLayer.cpp:498 [backend fallback]
Functionalize: registered at ../aten/src/ATen/FunctionalizeFallbackKernel.cpp:324 [backend fallback]
Named: registered at ../aten/src/ATen/core/NamedRegistrations.cpp:7 [backend fallback]
Conjugate: registered at ../aten/src/ATen/ConjugateFallback.cpp:17 [backend fallback]
Negative: registered at ../aten/src/ATen/native/NegateFallback.cpp:19 [backend fallback]
ZeroTensor: registered at ../aten/src/ATen/ZeroTensorFallback.cpp:86 [backend fallback]
ADInplaceOrView: fallthrough registered at ../aten/src/ATen/core/VariableFallbackKernel.cpp:86 [backend fallback]
AutogradOther: registered at ../torch/csrc/autograd/generated/VariableType_4.cpp:17434 [autograd kernel]
AutogradCPU: registered at ../torch/csrc/autograd/generated/VariableType_4.cpp:17434 [autograd kernel]
AutogradCUDA: registered at ../torch/csrc/autograd/generated/VariableType_4.cpp:17434 [autograd kernel]
AutogradHIP: registered at ../torch/csrc/autograd/generated/VariableType_4.cpp:17434 [autograd kernel]
AutogradXLA: registered at ../torch/csrc/autograd/generated/VariableType_4.cpp:17434 [autograd kernel]
AutogradMPS: registered at ../torch/csrc/autograd/generated/VariableType_4.cpp:17434 [autograd kernel]
AutogradIPU: registered at ../torch/csrc/autograd/generated/VariableType_4.cpp:17434 [autograd kernel]
AutogradXPU: registered at ../torch/csrc/autograd/generated/VariableType_4.cpp:17434 [autograd kernel]
AutogradHPU: registered at ../torch/csrc/autograd/generated/VariableType_4.cpp:17434 [autograd kernel]
AutogradVE: registered at ../torch/csrc/autograd/generated/VariableType_4.cpp:17434 [autograd kernel]
AutogradLazy: registered at ../torch/csrc/autograd/generated/VariableType_4.cpp:17434 [autograd kernel]
AutogradMTIA: registered at ../torch/csrc/autograd/generated/VariableType_4.cpp:17434 [autograd kernel]
AutogradPrivateUse1: registered at ../torch/csrc/autograd/generated/VariableType_4.cpp:17434 [autograd kernel]
AutogradPrivateUse2: registered at ../torch/csrc/autograd/generated/VariableType_4.cpp:17434 [autograd kernel]
AutogradPrivateUse3: registered at ../torch/csrc/autograd/generated/VariableType_4.cpp:17434 [autograd kernel]
AutogradMeta: registered at ../torch/csrc/autograd/generated/VariableType_4.cpp:17434 [autograd kernel]
AutogradNestedTensor: registered at ../torch/csrc/autograd/generated/VariableType_4.cpp:17434 [autograd kernel]
Tracer: registered at ../torch/csrc/autograd/generated/TraceType_4.cpp:13162 [kernel]
AutocastCPU: fallthrough registered at ../aten/src/ATen/autocast_mode.cpp:378 [backend fallback]
AutocastCUDA: fallthrough registered at ../aten/src/ATen/autocast_mode.cpp:244 [backend fallback]
FuncTorchBatched: registered at ../aten/src/ATen/functorch/LegacyBatchingRegistrations.cpp:720 [backend fallback]
BatchedNestedTensor: registered at ../aten/src/ATen/functorch/LegacyBatchingRegistrations.cpp:746 [backend fallback]
FuncTorchVmapMode: fallthrough registered at ../aten/src/ATen/functorch/VmapModeRegistrations.cpp:28 [backend fallback]
Batched: registered at ../aten/src/ATen/LegacyBatchingRegistrations.cpp:1075 [backend fallback]
VmapMode: fallthrough registered at ../aten/src/ATen/VmapModeRegistrations.cpp:33 [backend fallback]
FuncTorchGradWrapper: registered at ../aten/src/ATen/functorch/TensorWrapper.cpp:203 [backend fallback]
PythonTLSSnapshot: registered at ../aten/src/ATen/core/PythonFallbackKernel.cpp:162 [backend fallback]
FuncTorchDynamicLayerFrontMode: registered at ../aten/src/ATen/functorch/DynamicLayer.cpp:494 [backend fallback]
PreDispatch: registered at ../aten/src/ATen/core/PythonFallbackKernel.cpp:166 [backend fallback]
PythonDispatcher: registered at ../aten/src/ATen/core/PythonFallbackKernel.cpp:158 [backend fallback]


In [13]:
inputs.shape

torch.Size([134191, 2])

In [None]:
model.adul