In [7]:
import torch
from hand_to_neuro_SingleSessionSingleTrialDataset import SingleSessionSingleTrialDataset
import numpy as np
from pynwb import NWBHDF5IO

import os

dataset_path = "000070"
nwb_file_path = os.path.join(
    dataset_path, "sub-Jenkins", "sub-Jenkins_ses-20090916_behavior+ecephys.nwb")
io = NWBHDF5IO(nwb_file_path, 'r')
nwb_file = io.read()
hand_data = nwb_file.processing['behavior'].data_interfaces['Position']['Hand'].data[:]
hand_timestamps = nwb_file.processing['behavior'].data_interfaces['Position']['Hand'].timestamps[:]
trial_data = nwb_file.intervals['trials']

unit_spike_times = [nwb_file.units[unit_id]['spike_times'].iloc[0][:]
                    for unit_id in range(len(nwb_file.units))]
n_neurons = len(unit_spike_times)
n_future_vel_bins = 20

trials_start_from = int(2000 * 0.9)
n_trials = int(2000 * 0.01)
datasets = [SingleSessionSingleTrialDataset(
    trial_data, hand_data, hand_timestamps, unit_spike_times, trial_id, bin_size=0.02, n_future_vel_bins=n_future_vel_bins) for trial_id in range(trials_start_from, trials_start_from + n_trials)]
dataset = torch.utils.data.ConcatDataset(datasets)
print(f"Dataset from {n_trials} trials has {len(dataset)} samples")

  warn("Ignoring cached namespace '%s' version %s because version %s is already loaded."
  warn("Ignoring cached namespace '%s' version %s because version %s is already loaded."


Dataset from 20 trials has 20 samples


In [8]:
from hand_to_neuro_models import TransformerModel
from hand_to_neuro_dataloaders import get_max_trial_length

n_fr_bins = 9
d_model = 512
latent_dim = None
model_type = "transformer"  # transformer, lstm


n_trials = 200
n_epochs = 200
lr = 0.0005
weight_decay = 0.0


prefix = f"{model_type}_dm{d_model}"
if latent_dim is not None:
    prefix += f"_ld{latent_dim}"
prefix += f"_lr{lr}_wd{weight_decay}"
os.makedirs('model_data', exist_ok=True)
n_future_vel_bins = 20
n_fr_bins = 9
bin_size = 0.02


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")


max_trial_length = 206#get_max_trial_length(dataset, bin_size, min_max_trial_length_seconds=4)


input_size = (n_neurons) + 2 * n_future_vel_bins
hidden_size = d_model
model = TransformerModel(input_size, hidden_size,
                         n_neurons, n_fr_bins, max_trial_length).to(device)
checkpoint = torch.load(f'{prefix}_epoch{n_epochs}.pt', map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()


Using device: cpu


TransformerModel(
  (input_projection): Linear(in_features=232, out_features=512, bias=True)
  (pos_encoder): PositionalEncoding()
  (transformer_encoder): TransformerEncoder(
    (layers): ModuleList(
      (0-3): 4 x TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=512, out_features=512, bias=True)
        )
        (linear1): Linear(in_features=512, out_features=2048, bias=True)
        (dropout): Dropout(p=0.2, inplace=False)
        (linear2): Linear(in_features=2048, out_features=512, bias=True)
        (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (dropout1): Dropout(p=0.2, inplace=False)
        (dropout2): Dropout(p=0.2, inplace=False)
      )
    )
  )
  (output_projection): Linear(in_features=512, out_features=1728, bias=True)
  (unflatten): Unflatten(dim=2, unflattened_size=(192, 9))
)

In [6]:
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

# Convert dataset to PyTorch tensors and move to GPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
future_vels = []
spikes = []
spikes_future = []
for i in range(len(dataset)):
    future_vel, spike, spike_future = dataset[i]
    future_vels.append(future_vel)
    spikes.append(spike) 
    spikes_future.append(spike_future)
future_vels = torch.stack(future_vels).to(device)
spikes = torch.stack(spikes).to(device)
spikes_future = torch.stack(spikes_future).to(device)
print("future_vels.shape", future_vels.shape, "spikes.shape", spikes.shape, "spikes_future.shape", spikes_future.shape)


ValueError: too many values to unpack (expected 2)

In [11]:
# Print size of each key in the checkpoint and analyze what's taking up space
import sys
checkpoint = torch.load(f'{prefix}_epoch{n_epochs}.pt', map_location=device)

total_size = 0
sizes = {}

# Calculate size of each component
for key, value in checkpoint.items():
    if isinstance(value, torch.Tensor):
        size_bytes = value.element_size() * value.nelement()
        sizes[key] = size_bytes
    elif isinstance(value, dict):
        total_bytes = 0
        for k, v in value.items():
            if isinstance(v, torch.Tensor):
                total_bytes += v.element_size() * v.nelement()
        sizes[key] = total_bytes
    elif isinstance(value, list):
        size_bytes = sum(sys.getsizeof(item) for item in value)
        sizes[key] = size_bytes
    else:
        sizes[key] = sys.getsizeof(value)
    total_size += sizes[key]

# Print sizes sorted by largest first
print(f"Total checkpoint size: {total_size / (1024 * 1024):.2f} MB\n")
print("Breakdown by component:")
for key, size in sorted(sizes.items(), key=lambda x: x[1], reverse=True):
    pct = (size / total_size) * 100
    print(f"{key}: {size / (1024 * 1024):.2f} MB ({pct:.1f}%)")


Total checkpoint size: 52.52 MB

Breakdown by component:
model_state_dict: 52.50 MB (100.0%)
train_losses: 0.00 MB (0.0%)
val_losses: 0.00 MB (0.0%)
test_accs: 0.00 MB (0.0%)
epoch: 0.00 MB (0.0%)
train_loss: 0.00 MB (0.0%)
val_loss: 0.00 MB (0.0%)
test_acc: 0.00 MB (0.0%)
optimizer_state_dict: 0.00 MB (0.0%)


In [13]:
# Or if you only need the model weights
torch.save(model.state_dict(), f'{prefix}_epoch{n_epochs}.pt')