In [24]:
# This script only saves the spike times for spinnaker implementation

import os
import numpy as np
import matplotlib.pyplot as plt
import scipy as sc
from scipy.signal import butter, lfilter, welch, square  #for signal filtering

In [25]:
# general stuff
fs = 200  # sampling frequency of MYO
VERBOSE = True
data_dir = '/Users/Anonymous/Desktop/EMG/Roshambo/'
classes = ['rock', 'paper', 'scissor']
classes_dict = {'rock': 0, 'paper': 1, 'scissor': 2}
classes_inv = {v: k for k, v in classes_dict.items()}

In [26]:
class Person(object):
    def __init__(self, name, emg, ann, classes= ['rock', 'paper', 'scissor']):
        self.name = name
        self.emg = emg
        self.ann = ann
        self.trials = {c: [] for c in classes}
        self.begs = {c: [] for c in classes}
        self.ends = {c: [] for c in classes}
        self.x = {c: [] for c in classes}
        self.y = {c: [] for c in classes}
        self.ts = {c: [] for c in classes}
        self.pol = {c: [] for c in classes}
        self.emg_spks = {c: [] for c in classes}
        self.spk_trials = {c: [] for c in classes}  

In [27]:
subjects = {}
names = sorted([name for name in os.listdir(data_dir) if "emg" in name])

In [28]:
for name in names:
    _emg = np.load(data_dir + '{}'.format(name)).astype('float32')
    _ann = np.concatenate([np.array(['none']), np.load(data_dir + '{}'.format(name.replace("emg","ann")))[:-1]])
    subjects["_".join(name.split("_")[:2])] = Person(name.split("_")[0], _emg, _ann, classes=classes)
    print("Loaded {}: EMG = [{}] // ANN = [{}]".format("_".join(name.split("_")[:2]), _emg.shape, len(_ann)))
   

Loaded subject01_session01: EMG = [(8962, 8)] // ANN = [8962]
Loaded subject01_session02: EMG = [(8933, 8)] // ANN = [8933]
Loaded subject01_session03: EMG = [(8989, 8)] // ANN = [8989]
Loaded subject02_session01: EMG = [(8990, 8)] // ANN = [8990]
Loaded subject02_session02: EMG = [(8985, 8)] // ANN = [8985]
Loaded subject02_session03: EMG = [(8975, 8)] // ANN = [8975]
Loaded subject03_session01: EMG = [(8976, 8)] // ANN = [8976]
Loaded subject03_session02: EMG = [(8949, 8)] // ANN = [8949]
Loaded subject03_session03: EMG = [(8981, 8)] // ANN = [8981]
Loaded subject04_session01: EMG = [(8953, 8)] // ANN = [8953]
Loaded subject04_session02: EMG = [(8943, 8)] // ANN = [8943]
Loaded subject04_session03: EMG = [(8953, 8)] // ANN = [8953]
Loaded subject05_session01: EMG = [(9185, 8)] // ANN = [9185]
Loaded subject05_session02: EMG = [(9146, 8)] // ANN = [9146]
Loaded subject05_session03: EMG = [(9150, 8)] // ANN = [9150]
Loaded subject06_session01: EMG = [(8984, 8)] // ANN = [8984]
Loaded s

In [29]:
for name, data in subjects.items():
    for _class in classes:
        _annotation = np.float32(data.ann == _class)
        derivative = np.diff(_annotation)/1.0
        begins = np.where(derivative == 1)[0]
        ends = np.where(derivative == -1)[0]
        for b, e in zip(begins, ends):
            _trials = data.emg[b:e]
            data.trials[_class].append(_trials / np.std(_trials))
            data.begs[_class].append(b)
            data.ends[_class].append(e)
print("Done sorting trials!")


Done sorting trials!


In [30]:
# check that we get 5 trials per subject per gesture
for sub_name, sub_data in subjects.items():
    for _class, trials in sub_data.trials.items():
        assert (len(trials) == 5), "Something wrong with the number of trials!"
print("All good!")

All good!


In [31]:
def signal_to_spike_refractory(interpfact, time, amplitude, thr_up, thr_dn, refractory_period):
    actual_dc = 0
    spike_up = []
    spike_up = []
    spike_dn = []
    last_sample = interpfact * refractory_period

    f = sc.interpolate.interp1d(time, amplitude)
    rangeint = np.round((np.max(time) - np.min(time)) * interpfact)
    xnew = np.linspace(np.min(time), np.max(time), num=int(rangeint), endpoint=True)
    data = np.reshape([xnew, f(xnew)], (2, len(xnew))).T

    i = 0
    while i < (len(data) - int(last_sample)):
        if ((actual_dc + thr_up) < data[i, 1]):
            spike_up.append(data[i, 0])  # spike up
            actual_dc = data[i, 1]  # update current dc value
            i += int(refractory_period * interpfact)
        elif ((actual_dc - thr_dn) > data[i, 1]):
            spike_dn.append(data[i, 0])  # spike dn
            actual_dc = data[i, 1]  # update curre
            i += int(refractory_period * interpfact)
        else:
            i += 1

    return spike_up, spike_dn


In [32]:
def gen_spike_time(time_series_data):
    spike_time_array_up = []
    spike_time_array_dn = []
    for channel_number in range((time_series_data.shape)[1]):
        
      raw_channel = time_series_data[:,channel_number]
      _t = 1000*np.arange(0, raw_channel.shape[0] / fs, 1. / fs) 
        
      spk_up, spk_dn = signal_to_spike_refractory(interpolation, _t, raw_channel, th_up, th_dn, refractory)
      spike_time_array_up.append(spk_up)
      spike_time_array_dn.append(spk_dn)
      
    return spike_time_array_up, spike_time_array_dn

In [33]:
X_EMG = []
Y_EMG = []
SUB_EMG = []
SES_EMG = []
TRI_EMG = []

for name, data in subjects.items():
    for gesture in classes:
        for trial in range(5):
            X_EMG.append(data.trials[gesture][trial])
            Y_EMG.append(classes_dict[gesture])
            SUB_EMG.append(int(name[7:9]))
            SES_EMG.append(int(name[17:19]))
            TRI_EMG.append(trial)

X_EMG = np.array(X_EMG)
# X_EMG_SPIKE = np.array(X_EMG_SPIKE)
Y_EMG = np.array(Y_EMG)
SUB_EMG = np.array(SUB_EMG)
SES_EMG = np.array(SES_EMG)
TRI_EMG = np.array(TRI_EMG)
X_EMG_uniform = np.ones((450,400,8))
for i in range(len(X_EMG)) :
  trial_length = X_EMG[i].shape[0]
  if (trial_length > 400): 
    X_EMG_uniform[i] = X_EMG[i][0:400]
  elif (trial_length < 400) : 
    short = 400 -trial_length
    pad = np.zeros((short,8))
    X_EMG_uniform[i] = np.append(X_EMG[i],pad, axis=0)
  else : 
    X_EMG_uniform[i] = X_EMG[i]
print(len(X_EMG))
print(len(X_EMG_uniform))
print(len(Y_EMG))
print(list(set(Y_EMG)))
print(list(set(SUB_EMG)))
print(list(set(SES_EMG)))
print(list(set(TRI_EMG)))



450
450
450
[0, 1, 2]
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
[1, 2, 3]
[0, 1, 2, 3, 4]


  X_EMG = np.array(X_EMG)


In [34]:
X_EMG_Train = []
Y_EMG_Train = []
X_EMG_Test = []
Y_EMG_Test = []
for i in range(450) : 
    if (SES_EMG[i]) == 3 :
        X_EMG_Test.append(X_EMG_uniform[i])
        Y_EMG_Test.append(Y_EMG[i])
    else : 
        X_EMG_Train.append(X_EMG_uniform[i])
        Y_EMG_Train.append(Y_EMG[i])
        

In [35]:
interpolation = 5
refractory = 0.00
th_up = th_dn = 0.4
n_ch = 8
fs = 200

In [36]:
spike_times_train_up = []
spike_times_train_dn = []
for i in range(len(X_EMG_Train)):
    spk_up, spk_dn = gen_spike_time(X_EMG_Train[i])
    spike_times_train_up.append(spk_up)
    spike_times_train_dn.append(spk_dn)
  # X_EMG_SPIKE = gen_X_spike(X_EMG[i])

In [37]:
spike_times_test_up = []
spike_times_test_dn = []
for i in range(len(X_EMG_Test)):
    spk_up, spk_dn = gen_spike_time(X_EMG_Test[i])
    spike_times_test_up.append(spk_up)
    spike_times_test_dn.append(spk_dn)
  # X_EMG_SPIKE = gen_X_spike(X_EMG[i])

In [38]:
len(spike_times_test_up[1][1])

233

In [39]:
def gen_spike_rate(spike_time_array):
    
    nb_trials = len(spike_time_array)
    nb_electrodes = len(spike_time_array[1])
    spike_sum = 0
    for trial_number in range(nb_trials):
        for channel_number in range(nb_electrodes):
            spike_sum = spike_sum + len(spike_time_array[trial_number][channel_number])
            time_max = 2
    av_spike_sum = spike_sum/ (nb_trials*nb_electrodes)
    spike_rate = av_spike_sum/time_max
    return spike_rate

In [40]:
rate_up_test = gen_spike_rate(spike_times_test_up)
rate_dn_test = gen_spike_rate(spike_times_test_up)
rate_up_train = gen_spike_rate(spike_times_train_up)
rate_dn_train = gen_spike_rate(spike_times_train_up)
print(rate_up_test)
print(rate_dn_test)
print(rate_up_train)
print(rate_dn_train)

print((rate_up_test+rate_dn_test+rate_up_train+rate_dn_train) / 4)

174.57916666666668
174.57916666666668
168.71416666666667
168.71416666666667
171.64666666666668


In [41]:
len(spike_times_train_dn[1][6])

880

In [43]:

spike_times_test_up = np.array(spike_times_test_up)
spike_times_test_up = np.array(spike_times_test_up)
spike_times_train_up = np.array(spike_times_train_up)
spike_times_train_up = np.array(spike_times_train_up)

In [44]:
np.savez('/Users/Anonymous/Desktop/EMG/EMG_dataset_with_spike_time_python3.npz', Y_EMG_Train=Y_EMG_Train,Y_EMG_Test=Y_EMG_Test,spike_times_train_up = spike_times_train_up ,spike_times_train_dn = spike_times_train_dn,spike_times_test_up = spike_times_test_up ,spike_times_test_dn = spike_times_test_dn)