In [1]:
import torch
from hand_to_neuro_SingleSessionSingleTrialDataset import SingleSessionSingleTrialDataset
import numpy as np
from pynwb import NWBHDF5IO

import os

dataset_path = "000070"
nwb_file_path = os.path.join(
    dataset_path, "sub-Jenkins", "sub-Jenkins_ses-20090916_behavior+ecephys.nwb")
nwb_file_path = "sub-Jenkins_ses-20090923_behavior+ecephys.nwb"
io = NWBHDF5IO(nwb_file_path, 'r')
nwb_file = io.read()
hand_data = nwb_file.processing['behavior'].data_interfaces['Position']['Hand'].data[:]
hand_timestamps = nwb_file.processing['behavior'].data_interfaces['Position']['Hand'].timestamps[:]
trial_data = nwb_file.intervals['trials']

unit_spike_times = [nwb_file.units[unit_id]['spike_times'].iloc[0][:]
                    for unit_id in range(len(nwb_file.units))]
n_neurons = len(unit_spike_times)
n_context_bins = 50

trials_start_from = int(2000 * 0.9)
n_trials = int(2000 * 0.1)
datasets = [SingleSessionSingleTrialDataset(
    trial_data, hand_data, hand_timestamps, unit_spike_times, trial_id, bin_size=0.02, n_context_bins=n_context_bins) for trial_id in range(trials_start_from, trials_start_from + n_trials)]
dataset = torch.utils.data.ConcatDataset(datasets)
print(f"Dataset from {n_trials} trials has {len(dataset)} samples")

FileNotFoundError: [Errno 2] Unable to synchronously open file (unable to open file: name = 'sub-Jenkins_ses-20090923_behavior+ecephys.nwb', errno = 2, error message = 'No such file or directory', flags = 0, o_flags = 0)

In [4]:
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

# Convert dataset to PyTorch tensors and move to GPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
X = []
y = []
for i in range(len(dataset)):
    features, labels = dataset[i]
    X.append(features[:].flatten())
    y.append(labels)
X_test = torch.stack(X).to(device)
y_test = torch.stack(y).to(device)

In [5]:
# Define model
input_size = n_neurons * n_context_bins
model = nn.Sequential(
    nn.Linear(input_size, 256),
    nn.ReLU(),
    nn.Linear(256, 128),
    nn.ReLU(),
    nn.Linear(128, 2)
).to(device)

# Load the trained model
model.load_state_dict(torch.load('hand_to_neuro_mlp.pth'))
model.eval()  # Set the model to evaluation mode


Sequential(
  (0): Linear(in_features=9600, out_features=256, bias=True)
  (1): ReLU()
  (2): Linear(in_features=256, out_features=128, bias=True)
  (3): ReLU()
  (4): Linear(in_features=128, out_features=2, bias=True)
)

In [2]:
import pygame
import numpy as np
import time

# Initialize Pygame
pygame.init()

# # Set up display
# screen = pygame.display.set_mode((0, 0), pygame.FULLSCREEN)
# WIDTH, HEIGHT = screen.get_size()
# Set up display
WIDTH = 1200
HEIGHT = 800  # Increased height to accommodate velocity plots
screen = pygame.display.set_mode((WIDTH, HEIGHT))

pygame.display.set_caption("Neural Spike Train and Velocity Visualization")

# Colors
BLACK = (0, 0, 0)
GRAY = (140, 140, 140)  # For grid lines
WHITE = (255, 255, 255)
DARK_GRAY = (40, 40, 40)  # Darker gray for velocity lines

# X offset for plots
X_OFFSET = 45

# Get spike data
spike_data = X_test.reshape(-1, n_neurons, n_context_bins)[:, :, -1].T
total_bins = spike_data.shape[1]
window_size = 500
bin_step = 1  # Number of bins to advance each frame (1 bin = 20ms)

# Get model predictions
y_pred = model(torch.tensor(X_test, dtype=torch.float32)).detach().numpy()
y_true = y_test.detach().numpy()

# Normalize predictions and true values to have mean 1
y_pred = y_pred / np.mean(np.abs(y_pred))
y_true = y_true / np.mean(np.abs(y_true))

# Calculate scaling factors
spike_plot_height = HEIGHT // 4 * 3
neuron_height = spike_plot_height // n_neurons
time_bin_width = WIDTH // window_size
plot_height = HEIGHT // 10  # Reduced height for each velocity plot

# Normalize spike data for color intensity
spike_data_normalized = (spike_data - spike_data.min()) / (spike_data.max() - spike_data.min())
spike_data_normalized = spike_data_normalized.detach().numpy()

# Create font for labels
font = pygame.font.SysFont('arial', 24)

def normalize_for_plot(value, height):
    # Normalize values to fit in plot height
    return height // 2 + (value * height // 28)

running = True
current_bin = 0
clock = pygame.time.Clock()

# Pre-create surface for spike data
spike_surface = pygame.Surface((WIDTH, spike_plot_height))

while running:
    for event in pygame.event.get():
        if event.type == pygame.QUIT:
            running = False
        elif event.type == pygame.KEYDOWN:
            if event.key == pygame.K_ESCAPE:  # Add escape key to exit
                running = False
            
    # Clear screen
    screen.fill(BLACK)
    spike_surface.fill(BLACK)
    
    # Draw spike trains first using numpy operations
    if current_bin + window_size <= total_bins:
        # Get the current window of spike data
        window_data = spike_data_normalized[:, current_bin:current_bin+window_size]
        
        # Convert to pixel values (0-255)
        pixel_values = np.minimum(window_data * 255 * 1.5, 255).astype(np.uint8)
        
        # Create a surface from the numpy array
        for neuron in range(n_neurons):
            row_data = pixel_values[neuron]
            for t, intensity in enumerate(row_data):
                if intensity > 0:  # Only draw if there's activity
                    pygame.draw.rect(spike_surface, (intensity, intensity, intensity),
                                   (X_OFFSET + t * time_bin_width, neuron * neuron_height,
                                    time_bin_width, neuron_height))
    
    # Draw the spike surface to the screen
    screen.blit(spike_surface, (0, 0))
    
    # Draw grid lines and channel numbers on top
    for i in range(0, spike_plot_height, neuron_height * n_neurons):  # Draw every 200 channels
        pygame.draw.line(screen, GRAY, (X_OFFSET, i), (WIDTH, i), 1)
        # Draw channel number
        label = font.render(str(i // neuron_height), True, WHITE)
        # Rotate the label surface
        rotated_label = pygame.transform.rotate(label, 90)
        screen.blit(rotated_label, (10, i))
    
    # Draw velocity plots
    y_offset = spike_plot_height - 80  # Start below spike plot
    
    # Draw X velocity plot
    pygame.draw.line(screen, DARK_GRAY, (X_OFFSET, y_offset + plot_height//2), (WIDTH, y_offset + plot_height//2), 1)
    
    # Pre-calculate positions for velocity plots
    if current_bin + window_size <= len(y_true):
        t_range = np.arange(window_size-1)
        x_coords = X_OFFSET + t_range * time_bin_width
        x_coords_next = X_OFFSET + (t_range + 1) * time_bin_width
        
        # X velocity
        true_y_x = y_offset + normalize_for_plot(y_true[current_bin:current_bin+window_size-1, 0], plot_height)
        true_y_x_next = y_offset + normalize_for_plot(y_true[current_bin+1:current_bin+window_size, 0], plot_height)
        pred_y_x = y_offset + normalize_for_plot(y_pred[current_bin:current_bin+window_size-1, 0], plot_height)
        pred_y_x_next = y_offset + normalize_for_plot(y_pred[current_bin+1:current_bin+window_size, 0], plot_height)
        
        # Draw lines in batches
        for i in range(len(x_coords)):
            pygame.draw.line(screen, GRAY, 
                           (int(x_coords[i]), int(true_y_x[i])), 
                           (int(x_coords_next[i]), int(true_y_x_next[i])), 2)
            pygame.draw.line(screen, WHITE,
                           (int(x_coords[i]), int(pred_y_x[i])),
                           (int(x_coords_next[i]), int(pred_y_x_next[i])), 2)
    
    # Draw Y velocity plot
    y_offset += plot_height + 110
    pygame.draw.line(screen, DARK_GRAY, (X_OFFSET, y_offset + plot_height//2), (WIDTH, y_offset + plot_height//2), 1)
    
    if current_bin + window_size <= len(y_true):
        # Y velocity
        true_y_y = y_offset + normalize_for_plot(y_true[current_bin:current_bin+window_size-1, 1], plot_height)
        true_y_y_next = y_offset + normalize_for_plot(y_true[current_bin+1:current_bin+window_size, 1], plot_height)
        pred_y_y = y_offset + normalize_for_plot(y_pred[current_bin:current_bin+window_size-1, 1], plot_height)
        pred_y_y_next = y_offset + normalize_for_plot(y_pred[current_bin+1:current_bin+window_size, 1], plot_height)
        
        # Draw lines in batches
        for i in range(len(x_coords)):
            pygame.draw.line(screen, GRAY,
                           (int(x_coords[i]), int(true_y_y[i])),
                           (int(x_coords_next[i]), int(true_y_y_next[i])), 2)
            pygame.draw.line(screen, WHITE,
                           (int(x_coords[i]), int(pred_y_y[i])),
                           (int(x_coords_next[i]), int(pred_y_y_next[i])), 2)
    
    # Draw axis labels
    time_label = font.render("Time", True, WHITE)
    channels_label = font.render("Channels", True, WHITE)
    x_vel_label = font.render("velocity X (prediction: WHITE)", True, WHITE)
    y_vel_label = font.render("velocity Y (prediction: WHITE)", True, WHITE)
    
    screen.blit(time_label, (WIDTH // 2 - 30, HEIGHT - 30))
    screen.blit(x_vel_label, (WIDTH // 2 - 150, spike_plot_height - 90))
    screen.blit(y_vel_label, (WIDTH // 2 - 150, spike_plot_height + plot_height - 15))
    
    # Rotate and draw y-axis label
    channels_surface = pygame.Surface((200, 30))
    channels_surface.fill(BLACK)
    channels_surface.blit(channels_label, (50, 0))
    channels_surface = pygame.transform.rotate(channels_surface, 90)
    screen.blit(channels_surface, (10, spike_plot_height // 2 - 100))
    
    # Update display
    pygame.display.flip()
    
    # Move window by one bin (20ms) each frame
    current_bin += bin_step
    if current_bin + window_size > total_bins:
        current_bin = 0

    lambda_decay = 0.95  # Exponential decay factor
    integrated_v = np.zeros(2)
    for i in range(current_bin + window_size):
        integrated_v = lambda_decay * integrated_v + y_pred[i]
    integrated_v = integrated_v / 70
    integrated_v = integrated_v.clip(-1, 1)
    integrated_v = integrated_v * 10
    # now integrated v is always in the window -10cm to 10cm

    # Control frame rate to 50 FPS (20ms per frame)
    clock.tick(50)

pygame.quit()


pygame 2.6.1 (SDL 2.28.4, Python 3.10.12)
Hello from the pygame community. https://www.pygame.org/contribute.html


NameError: name 'X_test' is not defined

# OLD CODE

In [76]:
import pygame
import numpy as np
import time

# Initialize Pygame
pygame.init()

# Set up display
WIDTH = 1200
HEIGHT = 800
screen = pygame.display.set_mode((WIDTH, HEIGHT))
pygame.display.set_caption("Neural Spike Train Visualization")

# Colors
BLACK = (0, 0, 0)
GRAY = (40, 40, 40)  # For grid lines
WHITE = (255, 255, 255)

# Get spike data
spike_data = X_test.reshape(-1, n_neurons, n_context_bins)[:, :, -1].T
total_bins = spike_data.shape[1]
window_size = 100

# Calculate scaling factors
neuron_height = HEIGHT // n_neurons
time_bin_width = WIDTH // window_size

# Normalize spike data for color intensity
spike_data_normalized = (spike_data - spike_data.min()) / (spike_data.max() - spike_data.min())

# Create font for labels
font = pygame.font.Font(None, 36)

running = True
current_bin = 0
clock = pygame.time.Clock()

while running:
    for event in pygame.event.get():
        if event.type == pygame.QUIT:
            running = False
            
    # Clear screen
    screen.fill(BLACK)
    
    # Draw grid lines
    for i in range(0, HEIGHT, neuron_height * 200):  # Draw every 200 channels
        pygame.draw.line(screen, GRAY, (0, i), (WIDTH, i), 1)
        # Draw channel number
        label = font.render(str(i // neuron_height), True, WHITE)
        screen.blit(label, (10, i))
    
    # Draw spike trains
    if current_bin + window_size <= total_bins:
        for neuron in range(n_neurons):
            for t in range(window_size):
                intensity = int(spike_data_normalized[neuron, current_bin + t] * 255)
                color = (intensity, intensity, intensity)  # White/gray scale
                if intensity > 0:  # Only draw if there's activity
                    pygame.draw.rect(screen, color,
                                   (t * time_bin_width, neuron * neuron_height,
                                    time_bin_width, neuron_height))
    
    # Draw axis labels
    time_label = font.render("Time", True, WHITE)
    channels_label = font.render("Channels", True, WHITE)
    screen.blit(time_label, (WIDTH // 2 - 30, HEIGHT - 30))
    # Rotate and draw y-axis label
    channels_surface = pygame.Surface((200, 30))
    channels_surface.fill(BLACK)
    channels_surface.blit(channels_label, (0, 0))
    channels_surface = pygame.transform.rotate(channels_surface, 90)
    screen.blit(channels_surface, (10, HEIGHT // 2 - 100))
    
    # Update display
    pygame.display.flip()
    
    # Move window
    current_bin += 1
    if current_bin + window_size > total_bins:
        current_bin = 0
        
    # Control frame rate
    clock.tick(20)  # 30 FPS

pygame.quit()


In [77]:
import pygame
import numpy as np
import time

# Initialize Pygame
pygame.init()

# Set up display
WIDTH = 1200
HEIGHT = 900  # Increased height to accommodate velocity plots
screen = pygame.display.set_mode((WIDTH, HEIGHT))
pygame.display.set_caption("Neural Spike Train and Velocity Visualization")

# Colors
BLACK = (0, 0, 0)
GRAY = (140, 140, 140)  # For grid lines
WHITE = (255, 255, 255)
DARK_GRAY = (40, 40, 40)  # Darker gray for velocity lines

# Get spike data
spike_data = X_test.reshape(-1, n_neurons, n_context_bins)[:, :, -1].T
total_bins = spike_data.shape[1]
window_size = 100
bin_step = 1  # Number of bins to advance each frame (1 bin = 20ms)

# Get model predictions
y_pred = model(torch.tensor(X_test, dtype=torch.float32)).detach().numpy()
y_true = y_test.detach().numpy()

# Normalize predictions and true values to have mean 1
y_pred = y_pred / np.mean(np.abs(y_pred))
y_true = y_true / np.mean(np.abs(y_true))

# Calculate scaling factors
spike_plot_height = HEIGHT // 4 * 3
neuron_height = spike_plot_height // n_neurons
time_bin_width = WIDTH // window_size
plot_height = HEIGHT // 10  # Reduced height for each velocity plot

# Normalize spike data for color intensity
spike_data_normalized = (spike_data - spike_data.min()) / (spike_data.max() - spike_data.min())
# Create font for labels
font = pygame.font.SysFont('arial', 24)

def normalize_for_plot(value, height):
    # Normalize values to fit in plot height
    return height // 2 + (value * height // 20)

running = True
current_bin = 0
clock = pygame.time.Clock()

while running:
    for event in pygame.event.get():
        if event.type == pygame.QUIT:
            running = False
            
    # Clear screen
    screen.fill(BLACK)
    
    # Draw spike trains first
    if current_bin + window_size <= total_bins:
        for neuron in range(n_neurons):
            for t in range(window_size):
                intensity = int(spike_data_normalized[neuron, current_bin + t] * 255) * 1.5
                if intensity > 255:
                    intensity = 255
                color = (intensity, intensity, intensity)  # White/gray scale
                if intensity > 0:  # Only draw if there's activity
                    pygame.draw.rect(screen, color,
                                   (t * time_bin_width, neuron * neuron_height,
                                    time_bin_width, neuron_height))
    
    # Draw grid lines and channel numbers on top
    for i in range(0, spike_plot_height, neuron_height * n_neurons):  # Draw every 200 channels
        pygame.draw.line(screen, GRAY, (0, i), (WIDTH, i), 1)
        # Draw channel number
        label = font.render(str(i // neuron_height), True, WHITE)
        # Rotate the label surface
        rotated_label = pygame.transform.rotate(label, 90)
        screen.blit(rotated_label, (10, i))
    
    # Draw velocity plots
    y_offset = spike_plot_height - 40  # Start below spike plot
    
    # Draw X velocity plot
    pygame.draw.line(screen, DARK_GRAY, (0, y_offset + plot_height//2), (WIDTH, y_offset + plot_height//2), 1)
    for t in range(window_size-1):
        if current_bin + t + 1 < len(y_true):
            # True X velocity
            start_y = y_offset + normalize_for_plot(y_true[current_bin + t, 0], plot_height)
            end_y = y_offset + normalize_for_plot(y_true[current_bin + t + 1, 0], plot_height)
            start_pos = (int(t * time_bin_width), int(start_y))
            end_pos = (int((t + 1) * time_bin_width), int(end_y))
            pygame.draw.line(screen, GRAY, start_pos, end_pos, 2)
            
            # Predicted X velocity
            start_y = y_offset + normalize_for_plot(y_pred[current_bin + t, 0], plot_height)
            end_y = y_offset + normalize_for_plot(y_pred[current_bin + t + 1, 0], plot_height)
            start_pos = (int(t * time_bin_width), int(start_y))
            end_pos = (int((t + 1) * time_bin_width), int(end_y))
            pygame.draw.line(screen, WHITE, start_pos, end_pos, 2)
    
    # Draw Y velocity plot
    y_offset += plot_height + 60
    pygame.draw.line(screen, DARK_GRAY, (0, y_offset + plot_height//2), (WIDTH, y_offset + plot_height//2), 1)
    for t in range(window_size-1):
        if current_bin + t + 1 < len(y_true):
            # True Y velocity
            start_y = y_offset + normalize_for_plot(y_true[current_bin + t, 1], plot_height)
            end_y = y_offset + normalize_for_plot(y_true[current_bin + t + 1, 1], plot_height)
            start_pos = (int(t * time_bin_width), int(start_y))
            end_pos = (int((t + 1) * time_bin_width), int(end_y))
            pygame.draw.line(screen, GRAY, start_pos, end_pos, 2)
            
            # Predicted Y velocity
            start_y = y_offset + normalize_for_plot(y_pred[current_bin + t, 1], plot_height)
            end_y = y_offset + normalize_for_plot(y_pred[current_bin + t + 1, 1], plot_height)
            start_pos = (int(t * time_bin_width), int(start_y))
            end_pos = (int((t + 1) * time_bin_width), int(end_y))
            pygame.draw.line(screen, WHITE, start_pos, end_pos, 2)
    
    # Draw axis labels
    time_label = font.render("Time", True, WHITE)
    channels_label = font.render("Channels", True, WHITE)
    x_vel_label = font.render("velocity X (prediction: WHITE)", True, WHITE)
    y_vel_label = font.render("velocity Y (prediction: WHITE)", True, WHITE)
    
    screen.blit(time_label, (WIDTH // 2 - 30, HEIGHT - 30))
    screen.blit(x_vel_label, (WIDTH // 2 - 150, spike_plot_height - 80))
    screen.blit(y_vel_label, (WIDTH // 2 - 150, spike_plot_height + plot_height - 25))
    
    # Rotate and draw y-axis label
    channels_surface = pygame.Surface((200, 30))
    channels_surface.fill(BLACK)
    channels_surface.blit(channels_label, (50, 0))
    channels_surface = pygame.transform.rotate(channels_surface, 90)
    screen.blit(channels_surface, (10, spike_plot_height // 2 - 100))
    
    # Update display
    pygame.display.flip()
    
    # Move window by one bin (20ms) each frame
    current_bin += bin_step
    if current_bin + window_size > total_bins:
        current_bin = 0
        
    # Control frame rate to 50 FPS (20ms per frame)
    clock.tick(50)

pygame.quit()


  y_pred = model(torch.tensor(X_test, dtype=torch.float32)).detach().numpy()


KeyboardInterrupt: 

In [73]:
import pygame
import numpy as np
import time

# Initialize Pygame
pygame.init()

# Set up display
WIDTH = 1200
HEIGHT = 900  # Increased height to accommodate velocity plots
screen = pygame.display.set_mode((WIDTH, HEIGHT))
pygame.display.set_caption("Neural Spike Train and Velocity Visualization")

# Colors
BLACK = (0, 0, 0)
GRAY = (140, 140, 140)  # For grid lines
WHITE = (255, 255, 255)
DARK_GRAY = (40, 40, 40)  # Darker gray for velocity lines

# Get spike data
spike_data = X_test.reshape(-1, n_neurons, n_context_bins)[:, :, -1].T
total_bins = spike_data.shape[1]
window_size = 100
bin_step = 1  # Number of bins to advance each frame (1 bin = 20ms)

# Get model predictions
y_pred = model(torch.tensor(X_test, dtype=torch.float32)).detach().numpy()
y_true = y_test.detach().numpy()

# Normalize predictions and true values to have mean 1
y_pred = y_pred / np.mean(np.abs(y_pred))
y_true = y_true / np.mean(np.abs(y_true))

# Calculate scaling factors
spike_plot_height = HEIGHT // 4 * 3
neuron_height = spike_plot_height // n_neurons
time_bin_width = WIDTH // window_size
plot_height = HEIGHT // 10  # Reduced height for each velocity plot

# Normalize spike data for color intensity
spike_data_normalized = (spike_data - spike_data.min()) / (spike_data.max() - spike_data.min())
spike_data_normalized = spike_data_normalized.detach().numpy()

# Create font for labels
font = pygame.font.SysFont('arial', 24)

def normalize_for_plot(value, height):
    # Normalize values to fit in plot height
    return height // 2 + (value * height // 20)

running = True
current_bin = 0
clock = pygame.time.Clock()

# Pre-create surface for spike data
spike_surface = pygame.Surface((WIDTH, spike_plot_height))

while running:
    for event in pygame.event.get():
        if event.type == pygame.QUIT:
            running = False
            
    # Clear screen
    screen.fill(BLACK)
    spike_surface.fill(BLACK)
    
    # Draw spike trains first using numpy operations
    if current_bin + window_size <= total_bins:
        # Get the current window of spike data
        window_data = spike_data_normalized[:, current_bin:current_bin+window_size]
        
        # Convert to pixel values (0-255)
        pixel_values = np.minimum(window_data * 255 * 1.5, 255).astype(np.uint8)
        
        # Create a surface from the numpy array
        for neuron in range(n_neurons):
            row_data = pixel_values[neuron]
            for t, intensity in enumerate(row_data):
                if intensity > 0:  # Only draw if there's activity
                    pygame.draw.rect(spike_surface, (intensity, intensity, intensity),
                                   (t * time_bin_width, neuron * neuron_height,
                                    time_bin_width, neuron_height))
    
    # Draw the spike surface to the screen
    screen.blit(spike_surface, (0, 0))
    
    # Draw grid lines and channel numbers on top
    for i in range(0, spike_plot_height, neuron_height * n_neurons):  # Draw every 200 channels
        pygame.draw.line(screen, GRAY, (0, i), (WIDTH, i), 1)
        # Draw channel number
        label = font.render(str(i // neuron_height), True, WHITE)
        # Rotate the label surface
        rotated_label = pygame.transform.rotate(label, 90)
        screen.blit(rotated_label, (10, i))
    
    # Draw velocity plots
    y_offset = spike_plot_height - 40  # Start below spike plot
    
    # Draw X velocity plot
    pygame.draw.line(screen, DARK_GRAY, (0, y_offset + plot_height//2), (WIDTH, y_offset + plot_height//2), 1)
    
    # Pre-calculate positions for velocity plots
    if current_bin + window_size <= len(y_true):
        t_range = np.arange(window_size-1)
        x_coords = t_range * time_bin_width
        x_coords_next = (t_range + 1) * time_bin_width
        
        # X velocity
        true_y_x = y_offset + normalize_for_plot(y_true[current_bin:current_bin+window_size-1, 0], plot_height)
        true_y_x_next = y_offset + normalize_for_plot(y_true[current_bin+1:current_bin+window_size, 0], plot_height)
        pred_y_x = y_offset + normalize_for_plot(y_pred[current_bin:current_bin+window_size-1, 0], plot_height)
        pred_y_x_next = y_offset + normalize_for_plot(y_pred[current_bin+1:current_bin+window_size, 0], plot_height)
        
        # Draw lines in batches
        for i in range(len(x_coords)):
            pygame.draw.line(screen, GRAY, 
                           (int(x_coords[i]), int(true_y_x[i])), 
                           (int(x_coords_next[i]), int(true_y_x_next[i])), 2)
            pygame.draw.line(screen, WHITE,
                           (int(x_coords[i]), int(pred_y_x[i])),
                           (int(x_coords_next[i]), int(pred_y_x_next[i])), 2)
    
    # Draw Y velocity plot
    y_offset += plot_height + 60
    pygame.draw.line(screen, DARK_GRAY, (0, y_offset + plot_height//2), (WIDTH, y_offset + plot_height//2), 1)
    
    if current_bin + window_size <= len(y_true):
        # Y velocity
        true_y_y = y_offset + normalize_for_plot(y_true[current_bin:current_bin+window_size-1, 1], plot_height)
        true_y_y_next = y_offset + normalize_for_plot(y_true[current_bin+1:current_bin+window_size, 1], plot_height)
        pred_y_y = y_offset + normalize_for_plot(y_pred[current_bin:current_bin+window_size-1, 1], plot_height)
        pred_y_y_next = y_offset + normalize_for_plot(y_pred[current_bin+1:current_bin+window_size, 1], plot_height)
        
        # Draw lines in batches
        for i in range(len(x_coords)):
            pygame.draw.line(screen, GRAY,
                           (int(x_coords[i]), int(true_y_y[i])),
                           (int(x_coords_next[i]), int(true_y_y_next[i])), 2)
            pygame.draw.line(screen, WHITE,
                           (int(x_coords[i]), int(pred_y_y[i])),
                           (int(x_coords_next[i]), int(pred_y_y_next[i])), 2)
    
    # Draw axis labels
    time_label = font.render("Time", True, WHITE)
    channels_label = font.render("Channels", True, WHITE)
    x_vel_label = font.render("velocity X (prediction: WHITE)", True, WHITE)
    y_vel_label = font.render("velocity Y (prediction: WHITE)", True, WHITE)
    
    screen.blit(time_label, (WIDTH // 2 - 30, HEIGHT - 30))
    screen.blit(x_vel_label, (WIDTH // 2 - 150, spike_plot_height - 80))
    screen.blit(y_vel_label, (WIDTH // 2 - 150, spike_plot_height + plot_height - 25))
    
    # Rotate and draw y-axis label
    channels_surface = pygame.Surface((200, 30))
    channels_surface.fill(BLACK)
    channels_surface.blit(channels_label, (50, 0))
    channels_surface = pygame.transform.rotate(channels_surface, 90)
    screen.blit(channels_surface, (10, spike_plot_height // 2 - 100))
    
    # Update display
    pygame.display.flip()
    
    # Move window by one bin (20ms) each frame
    current_bin += bin_step
    if current_bin + window_size > total_bins:
        current_bin = 0
        
    # Control frame rate to 50 FPS (20ms per frame)
    clock.tick(25)

pygame.quit()


  y_pred = model(torch.tensor(X_test, dtype=torch.float32)).detach().numpy()
