In [3]:
import torch
from hand_to_neuro_SingleSessionSingleTrialDataset import SingleSessionSingleTrialDataset
import numpy as np
from pynwb import NWBHDF5IO

import os

# Open the Jenkins 2009-09-16 session from dataset 000070 and build one dataset per trial.
dataset_path = "000070"
nwb_file_path = os.path.join(dataset_path, "sub-Jenkins",
                             "sub-Jenkins_ses-20090916_behavior+ecephys.nwb")
# NOTE(review): `io` is never closed on purpose -- pynwb objects such as `trial_data`
# may still read lazily from the file, so keep it open while they are in use.
io = NWBHDF5IO(nwb_file_path, 'r')
nwb_file = io.read()

# Hand position samples and their timestamps, read fully into memory.
hand_position = nwb_file.processing['behavior'].data_interfaces['Position']['Hand']
hand_data = hand_position.data[:]
hand_timestamps = hand_position.timestamps[:]
trial_data = nwb_file.intervals['trials']

# One spike-time array per unit in the units table.
unit_spike_times = []
for unit_idx in range(len(nwb_file.units)):
    unit_spike_times.append(nwb_file.units[unit_idx]['spike_times'].iloc[0][:])
n_neurons = len(unit_spike_times)
n_future_vel_bins = 20

# Only a small slice of trials is used: n_trials consecutive trials starting at trials_start_from.
trials_start_from = int(2000 * 0.5)
n_trials = int(2000 * 0.01)
trial_ids = range(trials_start_from, trials_start_from + n_trials)
datasets = []
for trial_id in trial_ids:
    datasets.append(SingleSessionSingleTrialDataset(
        trial_data, hand_data, hand_timestamps, unit_spike_times, trial_id,
        bin_size=0.02, n_future_vel_bins=n_future_vel_bins))
dataset = torch.utils.data.ConcatDataset(datasets)
print(f"Dataset from {n_trials} trials has {len(dataset)} samples")

  warn("Ignoring cached namespace '%s' version %s because version %s is already loaded."
  warn("Ignoring cached namespace '%s' version %s because version %s is already loaded."


Dataset from 20 trials has 20 samples


In [14]:
from hand_to_neuro_models import TransformerModel
from hand_to_neuro_dataloaders import get_max_trial_length

# Hyperparameters of the trained transformer. They determine both the architecture and the
# checkpoint filename, so they must match the checkpoint being loaded below.
n_fr_bins = 9  # number of discrete firing-rate classes the model predicts per neuron and bin
d_model = 512
latent_dim = None
model_type = "transformer"  # transformer, lstm

# NOTE(review): this overrides `n_trials` from the data-loading cell and is not used below.
n_trials = 200
n_epochs = 200
lr = 0.0005
weight_decay = 0.0

prefix = f"{model_type}_dm{d_model}"
if latent_dim is not None:
    prefix += f"_ld{latent_dim}"
prefix += f"_lr{lr}_wd{weight_decay}"
os.makedirs('model_data', exist_ok=True)
n_future_vel_bins = 20
bin_size = 0.02

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Hard-coded so the model matches the checkpoint (presumably the positional-encoding length
# used at training time); the commented call recomputes it from the dataset instead.
max_trial_length = 206  # get_max_trial_length(dataset, bin_size, min_max_trial_length_seconds=4)

# Input per time bin: one spike count per neuron plus (x, y) for each future velocity bin.
input_size = (n_neurons) + 2 * n_future_vel_bins
hidden_size = d_model
model = TransformerModel(input_size, hidden_size,
                         n_neurons, n_fr_bins, max_trial_length).to(device)
# The file holds a bare state_dict, so weights_only=True is sufficient and avoids
# unpickling arbitrary Python objects from disk.
checkpoint = torch.load(f'{prefix}_epoch{n_epochs}.pt', map_location=device, weights_only=True)
model.load_state_dict(checkpoint)

model.eval()


Using device: cuda


TransformerModel(
  (input_projection): Linear(in_features=232, out_features=512, bias=True)
  (pos_encoder): PositionalEncoding()
  (transformer_encoder): TransformerEncoder(
    (layers): ModuleList(
      (0-3): 4 x TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=512, out_features=512, bias=True)
        )
        (linear1): Linear(in_features=512, out_features=2048, bias=True)
        (dropout): Dropout(p=0.2, inplace=False)
        (linear2): Linear(in_features=2048, out_features=512, bias=True)
        (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (dropout1): Dropout(p=0.2, inplace=False)
        (dropout2): Dropout(p=0.2, inplace=False)
      )
    )
  )
  (output_projection): Linear(in_features=512, out_features=1728, bias=True)
  (unflatten): Unflatten(dim=2, unflattened_size=(192, 9))
)

In [15]:
from torch import nn
# Define forward model (taken from neuro_to_hand_visualize.ipynb): an MLP mapping a flattened
# window of 50 spike bins for every neuron to a 2-D (x, y) velocity.
forward_model_input_size = n_neurons * 50
forward_model = nn.Sequential(
    nn.Linear(forward_model_input_size, 256),
    nn.ReLU(),
    nn.Linear(256, 128),
    nn.ReLU(),
    nn.Linear(128, 2)
).to(device)

# Load the trained weights. map_location lets a checkpoint saved on a GPU load on a CPU-only
# machine (matching how the transformer checkpoint is loaded); weights_only avoids
# unpickling arbitrary objects.
forward_model.load_state_dict(
    torch.load('neuro_to_hand_mlp.pth', map_location=device, weights_only=True))
forward_model.eval()  # Set the model to evaluation mode

Sequential(
  (0): Linear(in_features=9600, out_features=256, bias=True)
  (1): ReLU()
  (2): Linear(in_features=256, out_features=128, bias=True)
  (3): ReLU()
  (4): Linear(in_features=128, out_features=2, bias=True)
)

In [29]:
import pygame
import numpy as np
import time

n_context_bins = 50  # number of recorded spike bins given to the model as context before it starts generating
current_bin = n_context_bins  # next bin the model will generate

# Concatenate trials 3..19 of `dataset` (one sample per trial) along the time axis into a
# single long recording. The end index is clamped so a smaller dataset does not IndexError.
first_trial_idx = 3
end_trial_idx = min(20, len(dataset))
test_dataset = dataset
all_spikes = []
all_velocities = []
for i in range(first_trial_idx, end_trial_idx):
    future_velocities, spikes, _ = test_dataset[i]
    all_spikes.append(spikes)
    all_velocities.append(future_velocities)
if not all_spikes:
    raise ValueError(f"Need more than {first_trial_idx} samples in `dataset`, got {len(dataset)}")

# Concatenate along time dimension
spikes = torch.cat(all_spikes, dim=0).to(device)  # spikes.shape = (total_timesteps, n_neurons)
future_velocities = torch.cat(all_velocities, dim=0).to(device)  # velocities.shape = (total_timesteps, 2, n_future_vel_bins)
pred_velocities = np.zeros((len(spikes), 2))  # forward-model (x, y) predictions, filled in frame by frame

spikes[n_context_bins:] = 0  # remove all spikes from the future that the model will predict
max_n_bins = len(spikes)  # for testing, set it to the length of all concatenated trials

# Initialize Pygame
pygame.init()

# Set up display
fullscreen = True
if fullscreen:
    screen = pygame.display.set_mode((0, 0), pygame.FULLSCREEN)
    WIDTH, HEIGHT = screen.get_size()
else:
    WIDTH = 1200
    HEIGHT = 800  # Increased height to accommodate velocity plots
    screen = pygame.display.set_mode((WIDTH, HEIGHT))

pygame.display.set_caption("Neural Spike Train and Velocity Visualization")

# Colors
BLACK = (0, 0, 0)
GRAY = (140, 140, 140)  # For grid lines and true-velocity traces
WHITE = (255, 255, 255)
DARK_GRAY = (40, 40, 40)  # Zero line of each velocity plot

# X offset for plots (leaves room for the rotated channel labels on the left)
X_OFFSET = 45

window_size = 500  # number of time bins shown across the screen
bin_step = 1  # Number of bins to advance each frame (1 bin = 20ms)

# Layout: the spike raster takes the top 4/5 of the screen, each velocity plot 1/10.
spike_plot_height = HEIGHT // 5 * 4
# Clamp to >= 1 px: on small screens floor division can reach 0, which would draw nothing
# and break the channel-grid loop (zero range step, division by neuron_height).
neuron_height = max(spike_plot_height // n_neurons, 1)
time_bin_width = max(WIDTH // window_size, 1)
plot_height = HEIGHT // 10  # Reduced height for each velocity plot

def normalize_for_plot(value, height, value_range=28):
    """Map a value (scalar or numpy array) to a pixel offset inside a plot `height` px tall.

    A value of 0 maps to the plot's vertical middle (``height // 2``); a span of
    ``value_range`` value-units corresponds to the full plot height, so values in
    ``[-value_range / 2, value_range / 2]`` stay inside the plot.

    Args:
        value: value(s) to map; numpy arrays are handled element-wise.
        height: plot height in pixels.
        value_range: span of values mapped onto ``height`` pixels (default 28, the
            scale this notebook's velocity plots were tuned for).

    Returns:
        ``height // 2 + (value * height) // value_range`` (floor division, as before).
    """
    return height // 2 + (value * height // value_range)


# Create font for labels
font = pygame.font.SysFont('arial', 24)

running = True
clock = pygame.time.Clock()

# Pre-create surface for spike data
spike_surface = pygame.Surface((WIDTH, spike_plot_height))

# --- Loop-invariant work, done once instead of every frame ---
# `future_velocities` is never modified in the loop, so its numpy copy (first future bin of
# each timestep, scaled by 200 for plotting) is computed once.
true_velocities_numpy = future_velocities.cpu().numpy()[:, :, 0] * 200
# Number of past bins the forward model consumes (its input is n_neurons * this many values).
forward_context_bins = forward_model_input_size // n_neurons

# Static labels and the rotated y-axis label never change between frames.
time_label = font.render("Time", True, WHITE)
channels_label = font.render("Channels", True, WHITE)
x_vel_label = font.render("velocity X (prediction: WHITE)", True, WHITE)
y_vel_label = font.render("velocity Y (prediction: WHITE)", True, WHITE)
channels_surface = pygame.Surface((200, 30))
channels_surface.fill(BLACK)
channels_surface.blit(channels_label, (50, 0))
channels_surface = pygame.transform.rotate(channels_surface, 90)


def _draw_velocity_panel(true_vels, pred_vels, y_top, x_coords, x_coords_next, n_pred_segments):
    """Draw one velocity component: its zero line, the true trace (GRAY) and the prediction (WHITE).

    `true_vels` / `pred_vels` hold one more sample than there are segments: segment i joins
    sample i to sample i + 1 at x positions x_coords[i] -> x_coords_next[i]. Only the first
    `n_pred_segments` predicted segments are drawn, because predictions are only filled in
    up to the current bin.
    """
    pygame.draw.line(screen, DARK_GRAY, (X_OFFSET, y_top + plot_height//2), (WIDTH, y_top + plot_height//2), 1)
    true_ys = y_top + normalize_for_plot(true_vels, plot_height)
    pred_ys = y_top + normalize_for_plot(pred_vels, plot_height)
    for seg in range(len(x_coords)):
        pygame.draw.line(screen, GRAY,
                         (int(x_coords[seg]), int(true_ys[seg])),
                         (int(x_coords_next[seg]), int(true_ys[seg + 1])), 2)
        if seg < n_pred_segments:
            pygame.draw.line(screen, WHITE,
                             (int(x_coords[seg]), int(pred_ys[seg])),
                             (int(x_coords_next[seg]), int(pred_ys[seg + 1])), 2)


try:
    while running and current_bin < len(spikes):
        for event in pygame.event.get():
            if event.type == pygame.QUIT:
                running = False
            elif event.type == pygame.KEYDOWN and event.key == pygame.K_ESCAPE:  # escape key exits
                running = False

        # Clear screen
        screen.fill(BLACK)
        spike_surface.fill(BLACK)

        # Sample the transformer's spike prediction for `current_bin`, conditioned on up to
        # max_trial_length previous bins of (partly generated) spikes and future velocities.
        with torch.no_grad():
            give_context_from = max(current_bin - max_trial_length, 0)
            give_context_to = current_bin
            outputs = model(spikes[give_context_from:give_context_to].unsqueeze(0),
                            future_velocities[give_context_from:give_context_to].unsqueeze(0))  # shape: [1, n_timesteps, n_neurons, n_fr_bins]
            pred_probs = torch.softmax(outputs, dim=3)  # shape: [1, n_timesteps, n_neurons, n_fr_bins]
            pred_sample = torch.multinomial(pred_probs.reshape(-1, n_fr_bins), 1)
            pred_sample = pred_sample.reshape(outputs.shape[0], outputs.shape[1], outputs.shape[2])  # shape: [1, n_timesteps, n_neurons]
            last_sample = pred_sample[:, -1, :].squeeze(0)  # shape: [n_neurons] -- prediction for the next bin
            spikes[current_bin] = last_sample

        # Spike raster: copy only the visible window to the CPU, not the whole recording.
        show_spikes_from = max(current_bin + n_future_vel_bins - window_size, n_context_bins)
        show_spikes_to = min(current_bin + n_future_vel_bins, len(spikes))
        window_data = spikes[show_spikes_from:show_spikes_to].cpu().numpy().T / 8  # [n_neurons, n_window_bins]

        # Convert to pixel values (0-255)
        pixel_values = np.minimum(window_data * 255 * 1.5, 255).astype(np.uint8)

        # Draw only the active cells; np.nonzero yields them in the same neuron-major,
        # time-minor order as a nested neuron/time loop.
        for neuron, t in zip(*np.nonzero(pixel_values)):
            intensity = int(pixel_values[neuron, t])
            pygame.draw.rect(spike_surface, (intensity, intensity, intensity),
                             (X_OFFSET + int(t) * time_bin_width, int(neuron) * neuron_height,
                              time_bin_width, neuron_height))

        # Draw the spike surface to the screen
        screen.blit(spike_surface, (0, 0))

        # Channel boundary lines with their channel index. The step is the full raster height
        # (n_neurons rows), so this marks channel 0 at the top and, when it fits inside the
        # plot, the row just below the last channel.
        for i in range(0, spike_plot_height, neuron_height * n_neurons):
            pygame.draw.line(screen, GRAY, (X_OFFSET, i), (WIDTH, i), 1)
            label = font.render(str(i // neuron_height), True, WHITE)
            rotated_label = pygame.transform.rotate(label, 90)
            screen.blit(rotated_label, (10, i))

        # Decode velocity from the last `forward_context_bins` bins of spikes, flattened
        # neuron-major (transpose to [n_neurons, bins], then flatten).
        with torch.no_grad():
            forward_input = spikes[current_bin - forward_context_bins:current_bin].T.reshape(-1).unsqueeze(0)
            y_pred = forward_model(forward_input).cpu().numpy()
        pred_velocities[current_bin] = y_pred.flatten() * 200

        # Velocity plots show one more bin than the spike window so each segment has an endpoint.
        show_vel_from = max(current_bin + n_future_vel_bins - window_size, n_context_bins)
        show_vel_to = min(current_bin + n_future_vel_bins + 1, len(spikes))
        t_range = np.arange(show_vel_to - show_vel_from - 1)
        max_pred_show_t = max(len(t_range) - n_future_vel_bins, 0)
        x_coords = X_OFFSET + t_range * time_bin_width
        x_coords_next = X_OFFSET + (t_range + 1) * time_bin_width

        y_offset = spike_plot_height - 10  # X panel starts just below the spike plot
        _draw_velocity_panel(true_velocities_numpy[show_vel_from:show_vel_to, 0],
                             pred_velocities[show_vel_from:show_vel_to, 0],
                             y_offset, x_coords, x_coords_next, max_pred_show_t)
        y_offset += plot_height + 10  # Y panel directly below the X panel
        _draw_velocity_panel(true_velocities_numpy[show_vel_from:show_vel_to, 1],
                             pred_velocities[show_vel_from:show_vel_to, 1],
                             y_offset, x_coords, x_coords_next, max_pred_show_t)

        # Axis labels
        screen.blit(time_label, (WIDTH // 2 - 30, HEIGHT - 30))
        screen.blit(x_vel_label, (WIDTH // 2 - 150, spike_plot_height - 10))
        screen.blit(y_vel_label, (WIDTH // 2 - 150, spike_plot_height + plot_height - 15))
        screen.blit(channels_surface, (10, spike_plot_height // 2 - 100))

        # Update display
        pygame.display.flip()

        # Advance by bin_step bins; `>=` (not `==`) so a bin_step > 1 cannot step past the end.
        current_bin += bin_step
        if current_bin >= len(spikes):
            break

        # Control frame rate to 50 FPS (20ms per frame)
        clock.tick(50)
finally:
    # Always release the display window, even if the model or drawing raises mid-animation.
    pygame.quit()


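# Old Code

Legacy cells kept for reference; the visualization above does not need them. Note that the next cell reassigns `spikes`, so re-run the visualization's data-preparation cell after running it.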


In [None]:
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

# Legacy: gather every (future velocity, spikes, future spikes) sample of `dataset` into
# three stacked tensors on `device`.
# NOTE(review): torch.stack needs every sample to have the same shape -- if samples are whole
# trials of different lengths this raises. This cell also overwrites the global `spikes`.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
future_vel_list, spike_list, spike_future_list = [], [], []
for sample_idx in range(len(dataset)):
    sample_future_vel, sample_spikes, sample_spikes_future = dataset[sample_idx]
    future_vel_list.append(sample_future_vel)
    spike_list.append(sample_spikes)
    spike_future_list.append(sample_spikes_future)

future_vels = torch.stack(future_vel_list).to(device)
spikes = torch.stack(spike_list).to(device)
spikes_future = torch.stack(spike_future_list).to(device)
print("future_vels.shape", future_vels.shape, "spikes.shape", spikes.shape, "spikes_future.shape", spikes_future.shape)


In [5]:
# This code resaves the checkpoint to take less space (for old train code, not used anymore).
# Old-format files wrap the weights under 'model_state_dict'; they are rewritten in place as a
# bare state_dict (the other entries are dropped). A file that is already a bare state_dict
# (the format the loading cell above expects) is only loaded, so re-running this cell is safe.
model = TransformerModel(input_size, hidden_size,
                         n_neurons, n_fr_bins, max_trial_length).to(device)
checkpoint_path = f'{prefix}_epoch{n_epochs}.pt'
checkpoint = torch.load(checkpoint_path, map_location=device)
if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
    model.load_state_dict(checkpoint['model_state_dict'])
    torch.save(model.state_dict(), checkpoint_path)
else:
    # Already converted: the file holds the state_dict itself.
    model.load_state_dict(checkpoint)