In [1]:
import numpy as np
import matplotlib.pyplot as plt

import snntorch as snn
from snntorch import spikeplot as splt
from snntorch import functional as SF
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

import torch
from torch.utils.data import Dataset, DataLoader

from tqdm import tqdm 
import psweep as ps # <--

import warnings
# Silence pandas FutureWarnings. NOTE: this ignores *every* FutureWarning for the whole
# kernel session (not only pandas'), so upcoming deprecations in other libraries are hidden too.
warnings.simplefilter(action='ignore', category=FutureWarning)

In [2]:
import sys
sys.path.append('../code/utils')  # Add the utils directory to the Python path

import utils_data, utils_spikes, utils_events, utils_tensor

# Now that we have a working dataloader and network, we can sweep over a few parameters
### Meta-parameters: learning rate (`lr`), membrane decay (`beta`), firing threshold
### Network parameters of a simple feed-forward MLP: depth, width (the sweep below varies only the hidden width, `num_hidden`)

In [3]:
# Experiment specific parameters
chip_id = 9501  # experiment ID
chip_session = 0  # 2 for post-training, 0 for pre-training

# Stable parameters
data_path = '../data/cortical_labs_data/'  # path to data
fs = 20000  # sampling frequency (samples per second)
binsize = 10  # ms, bin size for spike counts
array_size = 1024  # number of electrodes in the array

# Torch parameters
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
len_trial = 100  # trial length in bins (in ms: len_trial * binsize)

# Reproducibility: seed the global RNGs once so weight initialisation and DataLoader
# shuffling are repeatable on Restart & Run All.
seed = 42
np.random.seed(seed)
torch.manual_seed(seed)

# Loading and binning the data with our new dataloader

In [4]:
# Load the recording, extract spike times and split them by electrode region
data_subset, events = utils_data.load_file(chip_id, chip_session, data_path)
spiketimes = utils_data.get_spiketimes(data_subset, array_size, fs)
sensory_spikes, up1_spikes, up2_spikes, down1_spikes, down2_spikes = utils_data.get_electrode_regions(data_subset, spiketimes, do_plot=False)

all_spikes = [sensory_spikes, up1_spikes, up2_spikes, down1_spikes, down2_spikes]
# Latest spike over all regions and channels, converted to ms (times are multiplied by
# 1000, so presumably stored in seconds). Channels without any spike are skipped because
# max() of an empty sequence raises ValueError.
max_time_ms = max(
    max(channel_spikes)
    for spike_list in all_spikes
    for channel_spikes in spike_list
    if len(channel_spikes) > 0
) * 1000

# Bin every region with the same bin size and time span (same call order as before)
region_spikes = {
    'sensory': sensory_spikes,
    'up1': up1_spikes,
    'down1': down1_spikes,
    'up2': up2_spikes,
    'down2': down2_spikes,
}
binned_by_region = {
    tag: utils_tensor.spike_times_to_bins(spikes, binsize, max_time_ms, spike_tag=tag)
    for tag, spikes in region_spikes.items()
}
sensory_spikes_binned = binned_by_region['sensory']
up1_spikes_binned = binned_by_region['up1']
down1_spikes_binned = binned_by_region['down1']
up2_spikes_binned = binned_by_region['up2']
down2_spikes_binned = binned_by_region['down2']

# Verify that every binned tensor is binary. check_binary appears to return a bool (the
# cell output shows `True`); assert on each result so a failure in ANY region stops the
# notebook instead of only the last call's result being displayed.
for tag, binned in binned_by_region.items():
    assert utils_tensor.check_binary(binned, f"{tag}_spikes_binned"), f"{tag}_spikes_binned is not binary"

Loading data...: 100%|██████████| 29/29 [00:01<00:00, 28.15it/s]


Stimulation mode: full game


Binning sensory channels: 100%|██████████| 500/500 [00:00<00:00, 523.05it/s]
  return torch.tensor(binned_spikes)
Binning up1 channels: 100%|██████████| 100/100 [00:00<00:00, 734.27it/s]
Binning down1 channels: 100%|██████████| 100/100 [00:00<00:00, 759.29it/s]
Binning up2 channels: 100%|██████████| 100/100 [00:00<00:00, 689.70it/s]
Binning down2 channels: 100%|██████████| 100/100 [00:00<00:00, 710.56it/s]


True

# A bit of preprocessing to get the data into a PyTorch-friendly format

In [5]:
# Processing events: convert timestamps to milliseconds on *copies* of the events, so
# re-running this cell never rescales already-converted timestamps a second time.
events_ms = []
for event in events:
    event_ms = event.copy()  # shallow copy; `events` itself stays untouched
    event_ms['norm_timestamp'] = event['norm_timestamp'] / fs * 1000  # samples -> s -> ms
    events_ms.append(event_ms)
# Relabel the first event (presumably the game start) as a motor layout for convenience
events_ms[0]['event'] = 'motor layout: 0'

event_types = ['ball missed', 'ball bounce', 'ball return', 'motor layout: 0']  # these are all the labels
# Bin the events with the same bin size as the spikes so both share one time axis
labels = torch.tensor(utils_tensor.events_to_bins(events_ms, event_types, binsize, max_time_ms))
assert labels.shape[-1] == sensory_spikes_binned.shape[-1], "labels and spike data have different numbers of time bins"

transformed_data, transformed_labels = utils_tensor.transform_data(labels, sensory_spikes_binned, len_trial)  # change data format
# Trials are counted on dim 1 of the data and dim 0 of the labels
assert transformed_data.shape[1] == transformed_labels.shape[0], "labels and data have different numbers of trials"

# Create Dataset
dataset = utils_tensor.CustomDataset(transformed_data, transformed_labels)

# Now we define the network

In [6]:
# Define Network
class Net(nn.Module):
    def __init__(self, num_inputs, num_hidden, num_outputs, beta=0.95):
        super().__init__()

        # Initialize layers
        self.fc1 = nn.Linear(num_inputs, num_hidden)
        self.lif1 = snn.Leaky(beta=beta)
        self.fc2 = nn.Linear(num_hidden, num_outputs)
        self.lif2 = snn.Leaky(beta=beta)

    def forward(self, x):
        # Initialize hidden states at t=0
        mem1 = self.lif1.init_leaky()
        mem2 = self.lif2.init_leaky()

        # Record the final layer
        spk2_rec = []
        mem2_rec = []

        for step in range(x.size(1)):
            cur1 = self.fc1(x[:,step])
            spk1, mem1 = self.lif1(cur1, mem1)
            cur2 = self.fc2(spk1)
            spk2, mem2 = self.lif2(cur2, mem2)
            spk2_rec.append(spk2)
            mem2_rec.append(mem2)

        return torch.stack(spk2_rec, dim=0), torch.stack(mem2_rec, dim=0)

In [7]:
def train_dishnet(params):
    """Train a fresh ``Net`` on the notebook-level ``dataset`` and return its last loss.

    Args:
        params: mapping with the keys ``num_inputs``, ``num_hidden``, ``num_outputs``,
            ``beta`` (LIF decay, not Adam's), ``num_epochs``, ``lr`` and ``batch_size``.
            Optional keys:
            ``max_batches_per_epoch`` -- cap on minibatches per epoch (default 100,
            ``None`` for no cap);
            ``seed`` -- if given, ``torch.manual_seed(seed)`` is called before the model
            is built and the data shuffled, making the run repeatable.
            Extra keys (e.g. psweep bookkeeping fields) are ignored.

    Returns:
        float: loss of the last minibatch trained on, or ``nan`` if no batch ran
        (``num_epochs == 0`` or an empty dataset).

    Note:
        Reads the globals ``dataset`` and ``device`` defined earlier in the notebook.
    """
    # Unpack the parameters from the dict
    num_inputs = params['num_inputs']
    num_hidden = params['num_hidden']
    num_outputs = params['num_outputs']
    beta = params['beta']
    num_epochs = params['num_epochs']
    lr = params['lr']
    batch_size = params['batch_size']
    max_batches = params.get('max_batches_per_epoch', 100)
    seed = params.get('seed')

    if seed is not None:
        torch.manual_seed(seed)

    data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    dishnet = Net(num_inputs=num_inputs, num_hidden=num_hidden, num_outputs=num_outputs, beta=beta).to(device)

    loss_fn = SF.ce_count_loss()
    # Adam's moment-decay betas -- unrelated to the LIF `beta` above
    optimizer = torch.optim.Adam(dishnet.parameters(), lr=lr, betas=(0.9, 0.999))

    loss_hist = []
    dishnet.train()  # nothing below switches modes, so set it once

    # Outer training loop
    for epoch in tqdm(range(num_epochs)):
        # Minibatch training loop. The batch index restarts every epoch, so the cap
        # applies per epoch (a never-reset running counter only ever cut one epoch short).
        for batch_idx, (data, targets) in enumerate(data_loader):
            if max_batches is not None and batch_idx >= max_batches:
                break

            data = data.to(device)
            targets = targets.to(device)

            # Forward pass: output spikes over all time steps
            spk_rec, _ = dishnet(data)
            loss_val = loss_fn(spk_rec, targets)

            # Gradient calculation + weight update
            optimizer.zero_grad()
            loss_val.backward()
            optimizer.step()

            # Store loss history for future plotting
            loss_hist.append(loss_val.item())

    print('\n')
    # NOTE(review): a single minibatch loss is a noisy sweep metric; consider averaging the last epoch
    return loss_hist[-1] if loss_hist else float('nan')

In [8]:
# Define the lists of values for each parameter you want to sweep over
num_inputs = ps.plist("num_inputs", [transformed_data.shape[-1]])  # last dim of the transformed data
num_hidden = ps.plist("num_hidden", [50, 100, 150])  # the only parameter actually varied
num_outputs = ps.plist("num_outputs", [len(event_types)])  # one output neuron per event type
beta = ps.plist("beta", [0.9])
num_epochs = ps.plist("num_epochs", [50])
lr = ps.plist("lr", [1e-3])
batch_size = ps.plist("batch_size", [32])

# Create the parameter grid (cartesian product of the lists above)
param_grid = ps.pgrid((num_inputs, num_hidden, num_outputs, beta, num_epochs, lr, batch_size))

# Define a function to run one instance of the experiment
def run_experiment(params):
    """psweep worker: train one configuration and report its final minibatch loss."""
    return {'loss': train_dishnet(params)}

# Run the parameter sweep
results = ps.run_local(run_experiment, param_grid, verbose=True)

# Show each swept value next to its final loss, best first
results[['num_hidden', 'loss']].sort_values('loss')


                               batch_size  beta     lr  num_epochs  num_hidden  num_inputs  num_outputs
2023-07-06 20:30:01.849459410          32   0.9  0.001          50          50         500            4


100%|██████████| 50/50 [00:50<00:00,  1.02s/it]




                               batch_size  beta     lr  num_epochs  num_hidden  num_inputs  num_outputs
2023-07-06 20:30:52.633007050          32   0.9  0.001          50         100         500            4


100%|██████████| 50/50 [00:52<00:00,  1.05s/it]




                               batch_size  beta     lr  num_epochs  num_hidden  num_inputs  num_outputs
2023-07-06 20:31:45.305738211          32   0.9  0.001          50         150         500            4


100%|██████████| 50/50 [00:57<00:00,  1.15s/it]






