In [1]:
import random
import warnings
from random import sample

import pandas as pd
import torch
from torch import device, cuda
from torch.cuda.amp import GradScaler
from torch.nn import BCEWithLogitsLoss
from tqdm import tqdm
import wandb
from pytorch_model_summary import summary
from scipy.sparse import load_npz

import flyvision
from flyvision_ans import DECODING_CELLS
from flyvision.utils.activity_utils import LayerActivity
from from_image_to_video import image_paths_to_sequences
from from_retina_to_connectome_funcs import get_cell_type_indices, compute_voronoi_averages, from_retina_to_connectome
from logs_to_wandb import log_images_to_wandb, log_running_stats_to_wandb
from from_video_to_training_batched_funcs import get_files_from_directory, select_random_videos, paths_to_labels
from from_retina_to_connectome_utils import (
    hex_to_square_grid,
    initialize_results_df,
    predictions_and_corrects_from_model_results,
    update_results_df,
    update_running_loss,
    get_decision_making_neurons,
    vector_to_one_hot,
)
from adult_models import FullAdultModel


warnings.filterwarnings(
    'ignore',
    message='invalid value encountered in cast',
    category=RuntimeWarning,
    module='wandb.sdk.data_types.image'
)

torch.manual_seed(1234)
dtype = torch.float32

device_type = "cuda" if cuda.is_available() else "cpu"
# device_type = "cpu"
DEVICE = device(device_type)
sparse_layout = torch.sparse_coo

TRAINING_DATA_DIR = "images/easy_v2"
TESTING_DATA_DIR = "images/easy_images"
VALIDATION_DATA_DIR = "images/easyval_images"

debugging = True
debug_length = 100
validation_length = 50
wandb_ = False
wandb_images_every = 100
small = True
small_length = 1000

num_epochs = 1
batch_size = 1

dropout = .1
max_lr = 0.01
base_lr = 0.0002
weight_decay = 0.0001
NUM_CONNECTOME_PASSES = 10

use_one_cycle_lr = False

model_config = {
    "debugging": debugging,
    "num_epochs": num_epochs,
    "batch_size": batch_size,
    "dropout": dropout,
    "base_lr": base_lr,
    "max_lr": max_lr,
    "weight_decay": weight_decay,
    "num_connectome_passes": NUM_CONNECTOME_PASSES,
}

  _C._set_default_tensor_type(t)


In [2]:
# init stuff
extent, kernel_size = 15, 13
decision_making_vector = get_decision_making_neurons(dtype)
receptors = flyvision.rendering.BoxEye(extent=extent, kernel_size=kernel_size)
network_view = flyvision.NetworkView(flyvision.results_dir / "opticflow/000/0000")
network = network_view.init_network(chkpt="best_chkpt")
classification = pd.read_csv("adult_data/classification_clean.csv")
root_id_to_index = pd.read_csv("adult_data/root_id_to_index.csv")
dt = 1 / 100 # some parameter from flyvision
last_good_frame = 2
cell_type_plot = "TmY18"

cell_type_indices = get_cell_type_indices(classification, root_id_to_index, DECODING_CELLS)

training_videos = get_files_from_directory(TRAINING_DATA_DIR)
test_videos = get_files_from_directory(TESTING_DATA_DIR)
validation_videos = get_files_from_directory(TESTING_DATA_DIR)

if small:
    training_videos = sample(training_videos, small_length)
    test_videos = sample(test_videos, small_length)
    validation_videos = sample(validation_videos, int(small_length / 5))

if len(training_videos) == 0:
    print("I can't find any training images or videos!")

In [3]:
synaptic_matrix = load_npz("adult_data/synaptic_matrix_sparse.npz")
one_hot_decision_making = vector_to_one_hot(
    decision_making_vector, dtype, sparse_layout
).to(DEVICE)

model = FullAdultModel(
    synaptic_matrix,
    one_hot_decision_making,
    cell_type_indices,
    NUM_CONNECTOME_PASSES,
    log_transform_weights=True,
    sparse_layout=sparse_layout,
    dtype=dtype,
).to(DEVICE)

optimizer = torch.optim.Adam(model.parameters(), lr=base_lr)
scaler = GradScaler()

# Initialize the loss function
criterion = BCEWithLogitsLoss()

In [5]:
if wandb_:
    wandb.init(project="adult_connectome", config=model_config)

model.train()

results = initialize_results_df()
probabilities = []
accuracies = []
already_selected = []
running_loss = 0.0
total_correct = 0
total = 0

iterations = debug_length if debugging else len(training_videos) // batch_size

for i in tqdm(range(iterations)):
    batch_files, already_selected = select_random_videos(
        training_videos, batch_size, already_selected
    )
    labels = paths_to_labels(batch_files)
    batch_sequences = image_paths_to_sequences(batch_files)
    rendered_sequences = receptors(batch_sequences)

    layer_activations = []
    for rendered_sequence in rendered_sequences:
        # rendered sequences are in RGB; move it to 0-1 for better training
        rendered_sequence = torch.div(rendered_sequence, 255)
        simulation = network.simulate(rendered_sequence[None], dt)
        layer_activations.append(
            LayerActivity(simulation, network.connectome, keepref=True)
        )

    if wandb_ and i % wandb_images_every == 0:
        la_0 = hex_to_square_grid(
            layer_activations[0][cell_type_plot].squeeze()[-last_good_frame].cpu().numpy()
            ),
        log_images_to_wandb(batch_sequences[0], rendered_sequences[0], la_0, batch_files[0], frame=last_good_frame, cell_type=cell_type_plot)

    voronoi_averages_df = compute_voronoi_averages(
        layer_activations, classification, DECODING_CELLS, last_good_frame=last_good_frame
    )
    # normalize column wise (except last column)
    values_cols = voronoi_averages_df.columns != "index_name"
    voronoi_averages_df.loc[:, values_cols] = voronoi_averages_df.loc[:, values_cols].apply(
        lambda x: (x - x.min()) / (x.max() - x.min()), axis=0
        )

    activation_df = from_retina_to_connectome(
        voronoi_averages_df, classification, root_id_to_index
    )
    del layer_activations, rendered_sequences, rendered_sequence, simulation
    torch.cuda.empty_cache()

    optimizer.zero_grad()

    inputs = torch.tensor(activation_df.values, dtype=dtype, device=DEVICE)
    labels = torch.tensor(labels, dtype=dtype, device=DEVICE)

    out = model(inputs)
    loss = criterion(out, labels)
    print(out, labels, loss)
    # scaler.scale(loss).backward()
    # scaler.step(optimizer)
    # scaler.update()
    loss.backward()
    optimizer.step()

    # Calculate run parameters
    predictions, labels_cpu, correct = predictions_and_corrects_from_model_results(out, labels)
    results = update_results_df(results, batch_files, predictions, labels_cpu, correct)
    running_loss += update_running_loss(loss, inputs)
    total += batch_size
    total_correct += correct.sum()

    if wandb_:
        log_running_stats_to_wandb(0, i, running_loss, total_correct, total, results)

print(f"Finished training with loss {running_loss / total} and accuracy {total_correct / total}")
torch.cuda.empty_cache()

  0%|          | 0/100 [00:00<?, ?it/s]

tensor([-1.9091e+28], grad_fn=<ViewBackward0>) tensor([0.]) tensor(0., grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)


  1%|          | 1/100 [01:40<2:45:32, 100.33s/it]

tensor([-1.9359e+28], grad_fn=<ViewBackward0>) tensor([0.]) tensor(0., grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)


  3%|▎         | 3/100 [03:30<1:33:56, 58.11s/it] 

tensor([-2.2980e+28], grad_fn=<ViewBackward0>) tensor([1.]) tensor(2.2980e+28, grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)


  4%|▍         | 4/100 [03:33<57:40, 36.04s/it]  

tensor([-2.0525e+28], grad_fn=<ViewBackward0>) tensor([1.]) tensor(2.0525e+28, grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)
tensor([-2.0400e+28], grad_fn=<ViewBackward0>) tensor([0.]) tensor(0., grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)


  5%|▌         | 5/100 [03:44<42:53, 27.09s/it]

tensor([-2.0352e+28], grad_fn=<ViewBackward0>) tensor([1.]) tensor(2.0352e+28, grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)


  6%|▌         | 6/100 [03:52<32:26, 20.70s/it]

tensor([-2.0057e+28], grad_fn=<ViewBackward0>) tensor([1.]) tensor(2.0057e+28, grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)


  7%|▋         | 7/100 [04:01<26:00, 16.77s/it]

tensor([-2.0504e+28], grad_fn=<ViewBackward0>) tensor([1.]) tensor(2.0504e+28, grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)


  8%|▊         | 8/100 [04:10<21:46, 14.20s/it]

tensor([-2.1570e+28], grad_fn=<ViewBackward0>) tensor([0.]) tensor(0., grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)


  9%|▉         | 9/100 [04:58<37:52, 24.97s/it]

tensor([-2.1573e+28], grad_fn=<ViewBackward0>) tensor([1.]) tensor(2.1573e+28, grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)


 10%|█         | 10/100 [05:18<34:57, 23.31s/it]

tensor([-2.0067e+28], grad_fn=<ViewBackward0>) tensor([0.]) tensor(0., grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)


 11%|█         | 11/100 [05:26<27:47, 18.74s/it]

tensor([-2.1916e+28], grad_fn=<ViewBackward0>) tensor([0.]) tensor(0., grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)


 12%|█▏        | 12/100 [05:34<22:50, 15.57s/it]

tensor([-2.1321e+28], grad_fn=<ViewBackward0>) tensor([0.]) tensor(0., grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)


 13%|█▎        | 13/100 [05:43<19:22, 13.37s/it]

tensor([-1.9914e+28], grad_fn=<ViewBackward0>) tensor([0.]) tensor(0., grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)


 14%|█▍        | 14/100 [05:51<16:59, 11.86s/it]

tensor([-2.1555e+28], grad_fn=<ViewBackward0>) tensor([1.]) tensor(2.1555e+28, grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)


 15%|█▌        | 15/100 [05:59<15:16, 10.78s/it]

tensor([-1.9592e+28], grad_fn=<ViewBackward0>) tensor([0.]) tensor(0., grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)


 16%|█▌        | 16/100 [06:08<14:02, 10.03s/it]

tensor([-2.0204e+28], grad_fn=<ViewBackward0>) tensor([1.]) tensor(2.0204e+28, grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)


 17%|█▋        | 17/100 [06:16<13:13,  9.56s/it]

tensor([-2.0986e+28], grad_fn=<ViewBackward0>) tensor([1.]) tensor(2.0986e+28, grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)


 18%|█▊        | 18/100 [07:03<28:22, 20.76s/it]

tensor([-2.1704e+28], grad_fn=<ViewBackward0>) tensor([1.]) tensor(2.1704e+28, grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)


 19%|█▉        | 19/100 [07:27<29:23, 21.77s/it]

tensor([-2.0164e+28], grad_fn=<ViewBackward0>) tensor([1.]) tensor(2.0164e+28, grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)


 20%|██        | 20/100 [07:36<23:39, 17.75s/it]

tensor([-2.1087e+28], grad_fn=<ViewBackward0>) tensor([1.]) tensor(2.1087e+28, grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)


 21%|██        | 21/100 [07:44<19:34, 14.87s/it]

tensor([-2.0943e+28], grad_fn=<ViewBackward0>) tensor([1.]) tensor(2.0943e+28, grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)


 22%|██▏       | 22/100 [07:52<16:40, 12.83s/it]

tensor([-1.9500e+28], grad_fn=<ViewBackward0>) tensor([0.]) tensor(0., grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)


 23%|██▎       | 23/100 [07:59<14:22, 11.20s/it]

tensor([-2.0378e+28], grad_fn=<ViewBackward0>) tensor([0.]) tensor(0., grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)


 24%|██▍       | 24/100 [08:06<12:40, 10.01s/it]

tensor([-2.0669e+28], grad_fn=<ViewBackward0>) tensor([0.]) tensor(0., grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)


 25%|██▌       | 25/100 [08:14<11:28,  9.18s/it]

tensor([-2.1046e+28], grad_fn=<ViewBackward0>) tensor([1.]) tensor(2.1046e+28, grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)


 26%|██▌       | 26/100 [08:21<10:43,  8.70s/it]

tensor([-2.0129e+28], grad_fn=<ViewBackward0>) tensor([0.]) tensor(0., grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)


 27%|██▋       | 27/100 [08:30<10:27,  8.60s/it]

tensor([-2.1324e+28], grad_fn=<ViewBackward0>) tensor([1.]) tensor(2.1324e+28, grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)


 28%|██▊       | 28/100 [08:38<10:08,  8.45s/it]

tensor([-2.1571e+28], grad_fn=<ViewBackward0>) tensor([0.]) tensor(0., grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)


 29%|██▉       | 29/100 [08:45<09:44,  8.23s/it]

tensor([-2.1259e+28], grad_fn=<ViewBackward0>) tensor([0.]) tensor(0., grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)


 30%|███       | 30/100 [08:53<09:20,  8.01s/it]

tensor([-2.1365e+28], grad_fn=<ViewBackward0>) tensor([0.]) tensor(0., grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)


 31%|███       | 31/100 [09:00<09:01,  7.84s/it]

tensor([-1.9932e+28], grad_fn=<ViewBackward0>) tensor([0.]) tensor(0., grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)


 32%|███▏      | 32/100 [09:08<08:43,  7.70s/it]

tensor([-2.2942e+28], grad_fn=<ViewBackward0>) tensor([1.]) tensor(2.2942e+28, grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)


 33%|███▎      | 33/100 [09:15<08:28,  7.60s/it]

tensor([-2.1829e+28], grad_fn=<ViewBackward0>) tensor([0.]) tensor(0., grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)


 34%|███▍      | 34/100 [09:22<08:16,  7.52s/it]

tensor([-2.2325e+28], grad_fn=<ViewBackward0>) tensor([0.]) tensor(0., grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)


 35%|███▌      | 35/100 [09:30<08:06,  7.48s/it]

tensor([-2.0886e+28], grad_fn=<ViewBackward0>) tensor([1.]) tensor(2.0886e+28, grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)


 36%|███▌      | 36/100 [09:37<07:56,  7.44s/it]

tensor([-2.1121e+28], grad_fn=<ViewBackward0>) tensor([1.]) tensor(2.1121e+28, grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)


 37%|███▋      | 37/100 [09:45<07:49,  7.46s/it]

tensor([-2.0575e+28], grad_fn=<ViewBackward0>) tensor([1.]) tensor(2.0575e+28, grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)


 38%|███▊      | 38/100 [09:53<07:53,  7.64s/it]

tensor([-2.0466e+28], grad_fn=<ViewBackward0>) tensor([0.]) tensor(0., grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)


 39%|███▉      | 39/100 [10:01<07:56,  7.81s/it]

tensor([-2.1311e+28], grad_fn=<ViewBackward0>) tensor([1.]) tensor(2.1311e+28, grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)


 40%|████      | 40/100 [10:09<07:54,  7.90s/it]

tensor([-1.9891e+28], grad_fn=<ViewBackward0>) tensor([1.]) tensor(1.9891e+28, grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)


 41%|████      | 41/100 [10:17<07:53,  8.03s/it]

tensor([-1.9931e+28], grad_fn=<ViewBackward0>) tensor([1.]) tensor(1.9931e+28, grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)


 42%|████▏     | 42/100 [10:26<07:51,  8.12s/it]

tensor([-2.1171e+28], grad_fn=<ViewBackward0>) tensor([1.]) tensor(2.1171e+28, grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)


 43%|████▎     | 43/100 [10:34<07:48,  8.23s/it]

tensor([-2.3469e+28], grad_fn=<ViewBackward0>) tensor([0.]) tensor(0., grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)


 44%|████▍     | 44/100 [10:43<07:45,  8.31s/it]

tensor([-2.0255e+28], grad_fn=<ViewBackward0>) tensor([1.]) tensor(2.0255e+28, grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)


 45%|████▌     | 45/100 [10:52<07:47,  8.50s/it]

tensor([-2.0181e+28], grad_fn=<ViewBackward0>) tensor([1.]) tensor(2.0181e+28, grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)


 46%|████▌     | 46/100 [11:49<20:53, 23.21s/it]

tensor([-2.1735e+28], grad_fn=<ViewBackward0>) tensor([1.]) tensor(2.1735e+28, grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)


 47%|████▋     | 47/100 [12:01<17:24, 19.70s/it]

tensor([-2.0290e+28], grad_fn=<ViewBackward0>) tensor([0.]) tensor(0., grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)


 48%|████▊     | 48/100 [12:09<14:07, 16.30s/it]

tensor([-2.1051e+28], grad_fn=<ViewBackward0>) tensor([1.]) tensor(2.1051e+28, grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)


 49%|████▉     | 49/100 [12:17<11:48, 13.90s/it]

tensor([-2.1905e+28], grad_fn=<ViewBackward0>) tensor([1.]) tensor(2.1905e+28, grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)


 50%|█████     | 50/100 [12:26<10:10, 12.20s/it]

tensor([-1.9989e+28], grad_fn=<ViewBackward0>) tensor([0.]) tensor(0., grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)


 51%|█████     | 51/100 [12:34<09:02, 11.08s/it]

tensor([-2.1822e+28], grad_fn=<ViewBackward0>) tensor([1.]) tensor(2.1822e+28, grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)


 52%|█████▏    | 52/100 [12:42<08:11, 10.25s/it]

tensor([-2.0750e+28], grad_fn=<ViewBackward0>) tensor([1.]) tensor(2.0750e+28, grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)


 53%|█████▎    | 53/100 [12:53<08:09, 10.41s/it]

tensor([-1.9932e+28], grad_fn=<ViewBackward0>) tensor([1.]) tensor(1.9932e+28, grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)


 54%|█████▍    | 54/100 [13:02<07:42, 10.05s/it]

tensor([-2.1033e+28], grad_fn=<ViewBackward0>) tensor([0.]) tensor(0., grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)


 55%|█████▌    | 55/100 [13:11<07:08,  9.51s/it]

tensor([-2.2112e+28], grad_fn=<ViewBackward0>) tensor([1.]) tensor(2.2112e+28, grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)


 56%|█████▌    | 56/100 [13:53<14:09, 19.32s/it]

tensor([-2.0846e+28], grad_fn=<ViewBackward0>) tensor([0.]) tensor(0., grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)


 57%|█████▋    | 57/100 [14:01<11:27, 16.00s/it]

tensor([-2.1376e+28], grad_fn=<ViewBackward0>) tensor([1.]) tensor(2.1376e+28, grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)


 58%|█████▊    | 58/100 [14:12<10:14, 14.64s/it]

tensor([-1.9849e+28], grad_fn=<ViewBackward0>) tensor([1.]) tensor(1.9849e+28, grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)


 59%|█████▉    | 59/100 [14:40<12:37, 18.46s/it]

tensor([-2.1967e+28], grad_fn=<ViewBackward0>) tensor([1.]) tensor(2.1967e+28, grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)


 60%|██████    | 60/100 [14:48<10:16, 15.42s/it]

tensor([-2.1796e+28], grad_fn=<ViewBackward0>) tensor([0.]) tensor(0., grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)


 61%|██████    | 61/100 [14:57<08:38, 13.30s/it]

tensor([-1.9422e+28], grad_fn=<ViewBackward0>) tensor([1.]) tensor(1.9422e+28, grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)


 62%|██████▏   | 62/100 [15:05<07:28, 11.79s/it]

tensor([-1.9596e+28], grad_fn=<ViewBackward0>) tensor([1.]) tensor(1.9596e+28, grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)


 63%|██████▎   | 63/100 [15:15<06:58, 11.32s/it]

tensor([-2.1844e+28], grad_fn=<ViewBackward0>) tensor([0.]) tensor(0., grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)


 64%|██████▍   | 64/100 [16:02<13:07, 21.86s/it]

tensor([-2.1300e+28], grad_fn=<ViewBackward0>) tensor([1.]) tensor(2.1300e+28, grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)


 65%|██████▌   | 65/100 [16:19<12:03, 20.67s/it]

tensor([-2.0880e+28], grad_fn=<ViewBackward0>) tensor([0.]) tensor(0., grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)


 66%|██████▌   | 66/100 [16:28<09:36, 16.97s/it]

tensor([-2.2207e+28], grad_fn=<ViewBackward0>) tensor([1.]) tensor(2.2207e+28, grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)


 67%|██████▋   | 67/100 [16:36<07:51, 14.28s/it]

tensor([-2.0572e+28], grad_fn=<ViewBackward0>) tensor([0.]) tensor(0., grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)


 68%|██████▊   | 68/100 [16:43<06:32, 12.27s/it]

tensor([-1.9704e+28], grad_fn=<ViewBackward0>) tensor([0.]) tensor(0., grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)


 69%|██████▉   | 69/100 [16:51<05:39, 10.96s/it]

tensor([-2.0599e+28], grad_fn=<ViewBackward0>) tensor([1.]) tensor(2.0599e+28, grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)


 70%|███████   | 70/100 [16:59<05:02, 10.08s/it]

tensor([-2.1889e+28], grad_fn=<ViewBackward0>) tensor([1.]) tensor(2.1889e+28, grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)


 71%|███████   | 71/100 [17:07<04:32,  9.41s/it]

tensor([-2.0915e+28], grad_fn=<ViewBackward0>) tensor([0.]) tensor(0., grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)


 72%|███████▏  | 72/100 [17:15<04:11,  8.99s/it]

tensor([-2.1738e+28], grad_fn=<ViewBackward0>) tensor([0.]) tensor(0., grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)


 73%|███████▎  | 73/100 [17:23<03:53,  8.65s/it]

tensor([-2.0553e+28], grad_fn=<ViewBackward0>) tensor([0.]) tensor(0., grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)


 74%|███████▍  | 74/100 [17:31<03:40,  8.49s/it]

tensor([-2.1114e+28], grad_fn=<ViewBackward0>) tensor([0.]) tensor(0., grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)


 75%|███████▌  | 75/100 [17:39<03:28,  8.33s/it]

tensor([-2.1258e+28], grad_fn=<ViewBackward0>) tensor([0.]) tensor(0., grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)


 76%|███████▌  | 76/100 [17:47<03:15,  8.16s/it]

tensor([-1.9611e+28], grad_fn=<ViewBackward0>) tensor([1.]) tensor(1.9611e+28, grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)


 77%|███████▋  | 77/100 [17:55<03:06,  8.13s/it]

tensor([-2.0739e+28], grad_fn=<ViewBackward0>) tensor([1.]) tensor(2.0739e+28, grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)


 78%|███████▊  | 78/100 [18:03<02:56,  8.00s/it]

tensor([-2.0621e+28], grad_fn=<ViewBackward0>) tensor([0.]) tensor(0., grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)


 79%|███████▉  | 79/100 [18:11<02:49,  8.08s/it]

tensor([-2.0348e+28], grad_fn=<ViewBackward0>) tensor([1.]) tensor(2.0348e+28, grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)


 80%|████████  | 80/100 [18:19<02:41,  8.06s/it]

tensor([-2.0493e+28], grad_fn=<ViewBackward0>) tensor([1.]) tensor(2.0493e+28, grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)


 81%|████████  | 81/100 [18:27<02:32,  8.04s/it]

tensor([-2.0749e+28], grad_fn=<ViewBackward0>) tensor([0.]) tensor(0., grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)


 82%|████████▏ | 82/100 [18:35<02:23,  8.00s/it]

tensor([-1.9978e+28], grad_fn=<ViewBackward0>) tensor([0.]) tensor(0., grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)


 83%|████████▎ | 83/100 [18:43<02:16,  8.01s/it]

tensor([-2.0008e+28], grad_fn=<ViewBackward0>) tensor([0.]) tensor(0., grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)


 84%|████████▍ | 84/100 [18:50<02:06,  7.91s/it]

tensor([-2.0593e+28], grad_fn=<ViewBackward0>) tensor([0.]) tensor(0., grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)


 85%|████████▌ | 85/100 [18:58<01:58,  7.88s/it]

tensor([-2.0741e+28], grad_fn=<ViewBackward0>) tensor([0.]) tensor(0., grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)


 86%|████████▌ | 86/100 [19:06<01:50,  7.87s/it]

tensor([-2.0477e+28], grad_fn=<ViewBackward0>) tensor([1.]) tensor(2.0477e+28, grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)


 87%|████████▋ | 87/100 [19:14<01:42,  7.85s/it]

tensor([-2.0812e+28], grad_fn=<ViewBackward0>) tensor([0.]) tensor(0., grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)


 88%|████████▊ | 88/100 [19:22<01:33,  7.78s/it]

tensor([-2.0396e+28], grad_fn=<ViewBackward0>) tensor([1.]) tensor(2.0396e+28, grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)


 89%|████████▉ | 89/100 [19:30<01:26,  7.87s/it]

tensor([-2.1575e+28], grad_fn=<ViewBackward0>) tensor([0.]) tensor(0., grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)


 90%|█████████ | 90/100 [19:37<01:18,  7.85s/it]

tensor([-2.1355e+28], grad_fn=<ViewBackward0>) tensor([1.]) tensor(2.1355e+28, grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)


 91%|█████████ | 91/100 [19:45<01:10,  7.86s/it]

tensor([-2.0679e+28], grad_fn=<ViewBackward0>) tensor([0.]) tensor(0., grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)


 92%|█████████▏| 92/100 [19:53<01:02,  7.76s/it]

tensor([-2.2145e+28], grad_fn=<ViewBackward0>) tensor([0.]) tensor(0., grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)


 93%|█████████▎| 93/100 [20:01<00:54,  7.77s/it]

tensor([-2.0994e+28], grad_fn=<ViewBackward0>) tensor([1.]) tensor(2.0994e+28, grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)


 94%|█████████▍| 94/100 [20:08<00:46,  7.80s/it]

tensor([-1.9272e+28], grad_fn=<ViewBackward0>) tensor([0.]) tensor(0., grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)


 95%|█████████▌| 95/100 [20:16<00:39,  7.82s/it]

tensor([-2.0553e+28], grad_fn=<ViewBackward0>) tensor([1.]) tensor(2.0553e+28, grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)


 96%|█████████▌| 96/100 [20:24<00:31,  7.83s/it]

tensor([-2.1868e+28], grad_fn=<ViewBackward0>) tensor([1.]) tensor(2.1868e+28, grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)


 97%|█████████▋| 97/100 [20:32<00:23,  7.91s/it]

tensor([-2.0207e+28], grad_fn=<ViewBackward0>) tensor([0.]) tensor(0., grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)


 98%|█████████▊| 98/100 [20:40<00:15,  7.88s/it]

tensor([-2.1138e+28], grad_fn=<ViewBackward0>) tensor([1.]) tensor(2.1138e+28, grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)


 99%|█████████▉| 99/100 [20:48<00:07,  8.00s/it]

tensor([-2.0372e+28], grad_fn=<ViewBackward0>) tensor([1.]) tensor(2.0372e+28, grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)


100%|██████████| 100/100 [20:56<00:00, 12.57s/it]

Finished training with loss 1.4883389570724512e+33 and accuracy 0.47





In [6]:
# Inspect the SciPy sparse connectome matrix loaded earlier (the repr shows its shape, dtype and stored-element count).
synaptic_matrix

<134191x134191 sparse matrix of type '<class 'numpy.int64'>'
	with 3871467 stored elements in COOrdinate format>

In [6]:
already_selected_validation = []
total_correct = 0
total = 0
running_loss = 0.0
validation_results = initialize_results_df()

validation_iterations = validation_length if validation_length is not None else len(validation_videos) // batch_size
for _ in tqdm(range(validation_iterations)):
    batch_files, already_selected_validation = select_random_videos(validation_videos, batch_size, already_selected_validation)

    labels = paths_to_labels(batch_files)
    batch_sequences = image_paths_to_sequences(batch_files)
    rendered_sequences = receptors(batch_sequences)
    layer_activations = []
    for rendered_sequence in rendered_sequences:
        simulation = network.simulate(rendered_sequence[None], dt)
        layer_activations.append(LayerActivity(simulation, network.connectome, keepref=True))

    voronoi_averages_df = compute_voronoi_averages(
        layer_activations, classification, DECODING_CELLS, last_good_frame=last_good_frame
    )
    # normalize column wise (except last column)
    values_cols = voronoi_averages_df.columns != "index_name"
    voronoi_averages_df.loc[:, values_cols] = voronoi_averages_df.loc[:, values_cols].apply(
        lambda x: (x - x.min()) / (x.max() - x.min()), axis=0
        )

    activation_df = from_retina_to_connectome(
        voronoi_averages_df, classification, root_id_to_index
    )
    del layer_activations, rendered_sequences, rendered_sequence, simulation
    torch.cuda.empty_cache()

    model.eval()
    with torch.no_grad():
        inputs = torch.tensor(activation_df.values, dtype=dtype, device=DEVICE)
        labels = torch.tensor(labels, dtype=dtype, device=DEVICE)

        outputs = model(inputs)
        loss = criterion(outputs, labels)

        predictions, batch_labels_cpu, correct = predictions_and_corrects_from_model_results(outputs, labels)
        validation_results = update_results_df(validation_results, batch_files, predictions, batch_labels_cpu, correct)
        running_loss += update_running_loss(loss, inputs)
        total += batch_labels_cpu.shape[0]
        total_correct += correct.sum().item()

if wandb_:
    log_running_stats_to_wandb(0, 0, running_loss, total_correct, total, validation_results)

print(
    f"Validation Loss: {running_loss / total}, "
    f"Validation Accuracy: {total_correct / total}"
)

100%|██████████| 50/50 [03:48<00:00,  4.56s/it]

Validation Loss: 0.0, Validation Accuracy: 1.0





In [12]:
# Debug check of the running sample counter `total` (assigned by both the training and validation cells).
total

1

In [8]:
# Per-sample validation outcomes accumulated via update_results_df in the validation loop.
validation_results

Unnamed: 0,Image,Prediction,True label,Is correct
0,images/easy_images/yellow/img2_15_10_51_equali...,0.0,1.0,0
0,images/easy_images/blue/img2_5_10_61.png,0.0,0.0,1
0,images/easy_images/blue/img2_8_12_83.png,0.0,0.0,1
0,images/easy_images/blue/img_12_16_116.png,0.0,0.0,1
0,images/easy_images/yellow/img_16_8_142.png,0.0,1.0,0
0,images/easy_images/yellow/img_6_4_100_equalize...,0.0,1.0,0
0,images/easy_images/blue/img_6_12_120.png,0.0,0.0,1
0,images/easy_images/yellow/img_12_9_64.png,0.0,1.0,0
0,images/easy_images/yellow/img_9_6_150.png,0.0,1.0,0
0,images/easy_images/blue/img_8_12_30.png,0.0,0.0,1


In [7]:
from pytorch_model_summary import summary
input_shape = torch.Size([134191, 1])
print(summary(model, torch.zeros(input_shape), show_input=True))

-------------------------------------------------------------------------------
              Layer (type)         Input Shape         Param #     Tr. Param #
   RetinaConnectionLayer-1         [134191, 1]      48,060,108      48,060,108
         AdultConnectome-2         [134191, 1]       3,871,467       3,871,467
                  Linear-3              [5467]           5,468           5,468
Total params: 51,937,043
Trainable params: 51,937,043
Non-trainable params: 0
-------------------------------------------------------------------------------
