# In [None]:
import torch
import numpy as np
from sklearn.preprocessing import MinMaxScaler

def generate_spike_train(data, max_rate, time_window):
    """Encode a sequence of values as concatenated binary spike windows.

    Each value ``v`` in ``data`` becomes a window of ``time_window`` slots
    containing ``int(v * max_rate)`` ones at distinct, randomly chosen
    positions (drawn with ``torch.randperm``). ``max_rate`` is therefore a
    spike *count* per window, not a per-second rate. The count is clamped to
    ``[0, time_window]``.

    Args:
        data: Iterable of scalars. It is presumably min-max scaled to [0, 1]
            upstream; TODO confirm at call sites. Values outside that range
            are clamped rather than rejected.
        max_rate: Spikes per window for a value of 1.0. The count is
            truncated to int and capped at ``time_window``.
        time_window: Number of slots per value (a non-negative int).

    Returns:
        1-D float tensor of length ``len(data) * time_window``, with one
        window per value in input order. Empty ``data`` yields an empty
        tensor.
    """
    spike_trains = []
    for value in data:
        spikes = torch.zeros(time_window)
        # int() truncates toward zero. Clamping at 0 stops a negative value
        # from becoming a negative slice bound, which would select from the
        # *end* of the permutation and emit spikes. Clamping at time_window
        # simply saturates the window.
        num_spikes = min(max(int(value * max_rate), 0), time_window)
        # Always draw the permutation, even for 0 spikes, so the RNG stream is
        # consumed exactly once per value and seeded runs stay reproducible.
        spike_times = torch.randperm(time_window)[:num_spikes]
        spikes[spike_times] = 1
        spike_trains.append(spikes)
    if not spike_trains:
        # torch.cat raises on an empty list of tensors.
        return torch.zeros(0)
    return torch.cat(spike_trains)

def minmax_scale_across_timepoints(data, scaler=None):
    """Scale each feature using statistics pooled over all samples and timepoints.

    The ``(samples, timepoints, features)`` input is flattened to
    ``(samples * timepoints, features)``, so the scaler treats every
    (sample, timepoint) pair as one row. The scaler is fitted and applied to
    those rows, and the result is reshaped back to the input's 3-D shape.

    Args:
        data: Array-like of shape ``(samples, timepoints, features)`` exposing
            ``.shape`` and ``.reshape`` (e.g. a numpy array).
        scaler: Object with a scikit-learn style ``fit_transform``. If omitted,
            a fresh ``MinMaxScaler()`` is created for this call. A
            caller-supplied scaler is fitted in place, so it keeps the fitted
            statistics afterwards.

    Returns:
        The scaled data reshaped to ``(samples, timepoints, features)``, as
        produced by ``scaler.fit_transform`` (typically a numpy array).
    """
    if scaler is None:
        # Create per call: a ``MinMaxScaler()`` default argument would be built
        # once at definition time and then shared (and refitted) by every call.
        scaler = MinMaxScaler()
    samples, timepoints, features = data.shape
    data_reshaped = data.reshape(samples * timepoints, features)
    data_scaled_reshaped = scaler.fit_transform(data_reshaped)
    return data_scaled_reshaped.reshape(samples, timepoints, features)

from tqdm import tqdm
def rate_encoder(dataset, max_rate=10, time_window=20):
    """Rate-encode a 3-D dataset into binary spike trains.

    Every (sample, feature) series ``dataset[i][:, j]`` is passed through
    ``generate_spike_train``. That helper expands each timepoint into a
    window of ``time_window`` slots and concatenates the windows, so the
    time axis grows by a factor of ``time_window``.

    Args:
        dataset: Indexable of shape ``(samples, timepoints, features)``
            (e.g. a numpy array or torch tensor).
        max_rate: Spikes per window for a value of 1.0, forwarded to
            ``generate_spike_train``.
        time_window: Number of slots each timepoint expands into.

    Returns:
        Float tensor of shape ``(samples, timepoints * time_window, features)``.
    """
    n_samples, n_timepoints, n_features = dataset.shape
    encoded = torch.zeros(n_samples, n_timepoints * time_window, n_features)
    # Sample-major, then feature: this fixes the order of the random draws
    # made inside generate_spike_train.
    for sample_idx in range(n_samples):
        sample = dataset[sample_idx]
        for feature_idx in range(n_features):
            encoded[sample_idx, :, feature_idx] = generate_spike_train(
                sample[:, feature_idx], max_rate, time_window
            )
    return encoded