In [2]:
import pickle
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt

from sklearn.preprocessing import LabelEncoder, StandardScaler
from sklearn.metrics import confusion_matrix

from tensorflow.keras.utils import to_categorical
import tensorflow as tf
from tensorflow.keras.models import Sequential, load_model
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, ReduceLROnPlateau

from read_emg import *
from build_CNN import *

import seaborn as sns
plt.style.use('default')
plt.rcParams['xtick.labelsize'] = 12
plt.rcParams['ytick.labelsize'] = 12
plt.rcParams['legend.fontsize'] = 12
plt.rcParams['legend.title_fontsize'] = 16
plt.rcParams['axes.labelsize'] = 16

In [3]:
# %%time
# df = read_file()
# df.to_pickle('DataCollection.pkl')

In [4]:
df = pd.read_pickle('DataCollection.pkl')
df['Trial_num'] = df['Trial_num'].astype(int)
Gesture_list = ['Cylinder Grasp', 'Wrist Extension', 'Fist', 'Finger mass extension', 'Opposition', 'Lateral Pinch']

# Within session classification

In [46]:
%%time
subject_list = df['ID'].unique()
acc_list = []
confusion_matrix_list = []
stats = []
for subject in tqdm(['0001', '1001', '1003', '1004', '1005', '1006', '1007', '1111', '1234', '9999']):
    for session in ['S1', 'S2']:
        print(f"\n\x1b[31m\"Current session: {session}\"\x1b[0m")
        ### Takes only the gestures that we are interested in
        df_subject = df[(df['ID'] == subject) 
                        & (df['session'] == session) 
                        & (df['Gesture'] != '1') 
                        & (df['Gesture'] != '0') 
                        & (df['Gesture'] != '3')
                        & (df['Gesture'] != '8')
                        & (df['Gesture'] != '10')
                        & (df['Gesture'] != '11')]
        
        ### Skip current loop if no matching subject is found
        if len(df_subject) == 0:
            continue
            
        ### Iterate through all trials in sessions    
        for trial_test in [1, 2, 3]:
            df_subject_train = df_subject[(df_subject['Trial_num'] != trial_test) & (df_subject.index > 100)].copy()
            df_subject_test = df_subject[(df_subject['Trial_num'] == trial_test) & (df_subject.index > 100)].copy()
            
            ### Take the last 100 samples (0.5 second as the validation set)
            df_subject_train.reset_index(inplace = True, drop = True)
            df_subject_val = df_subject_train.groupby(['Gesture', 'session', 'Trial_num']).tail(100)

            df_subject_train.loc[df_subject_val.index, :] = np.nan
            df_subject_train.dropna(inplace = True, axis = 0)
            
            ### Preprocess the data into overlapping windows and filter with a high pass filter
            print('\n==================================')
            print('Preprocess and filter the EMG data')
            print('==================================')

            print('Training Data')
            X_train, y_train, _= preprocess(df_subject_train, window_size = 52, nonoverlap_size = 5)
            
            print('Validation Data')
            X_val, y_val, _ = preprocess(df_subject_val, window_size = 52, nonoverlap_size = 5)
            
            print('Test Data')
            X_test, y_test, _ = preprocess(df_subject_test, window_size = 52, nonoverlap_size = 5)
            
            ### Label encoding the gesture
            le = LabelEncoder()
            y_train = le.fit_transform(y_train)
            y_val = le.transform(y_val)
            y_test = le.transform(y_test)
            
            ### One-hot encoding
            y_train = to_categorical(y_train)
            y_val = to_categorical(y_val)
            y_test = to_categorical(y_test)
            
            ### Feature scaling using standardization
            standard_scaler = StandardScaler()
            X_train = standard_scaler.fit_transform(X_train.reshape(-1, 8)).reshape(X_train.shape)
            X_val = standard_scaler.transform(X_val.reshape(-1, 8)).reshape(X_val.shape)
            X_test = standard_scaler.transform(X_test.reshape(-1, 8)).reshape(X_test.shape)
            
            ### Reshape the array into (data length, window_size, EMG_channel_size, channel [Greyscale])
            ### The default shape of the EMG snapshot is (data_length, 52, 8, 1)
            X_train = X_train.reshape((-1, X_train.shape[1], X_train.shape[2], 1))
            X_val = X_val.reshape((-1, X_val.shape[1], X_val.shape[2], 1))
            X_test = X_test.reshape((-1, X_test.shape[1], X_test.shape[2], 1))
                
            ### Create the CNN model with dropout and L2 regularization)
            model = get_CNN_model(X_train.shape, y_train.shape[1], dr = 0.5, wd = 0.01)

            model.compile(optimizer = tf.keras.optimizers.Adam(learning_rate = 0.001), 
                      loss = tf.keras.losses.CategoricalCrossentropy(),
                      metrics = [tf.keras.metrics.CategoricalAccuracy(name = 'acc')])
            reduce_lr = ReduceLROnPlateau(monitor = 'val_acc', patience = 3, mode = 'max', verbose = 0)
            early_stopping = EarlyStopping(monitor = 'val_acc', patience = 10, mode = 'max', verbose = 0)
            print('\n==================================')
            print('         Start training           ')
            print('==================================')
            print(f'Shape of training data {X_train.shape}')
            print(f'Shape of validation data {X_val.shape}')
            print(f'Shape of test data {X_test.shape}')
            history = model.fit(X_train, y_train, 
                                validation_data = (X_val, y_val), 
                                callbacks = [reduce_lr, early_stopping], 
                                epochs = 30, verbose = 0)
            acc_val = model.evaluate(X_val, y_val, verbose = 0)[1]
            acc = model.evaluate(X_test, y_test, verbose = 0)[1]
            print(f'Accuracy of trial {trial_test}: %.2f' % (acc))
            acc_list.append(acc)
            y_pred = np.argmax(model.predict(X_test), axis = -1)
            y_test = np.argmax(y_test, axis = -1)
            matrix = confusion_matrix(y_test, y_pred)
            cm = pd.DataFrame(matrix, columns = Gesture_list, index = Gesture_list)
            cm['ID'] = subject
            confusion_matrix_list.append(cm)
            stats.append({'ID': subject,
                          'Trial_num': trial_test,
                          'Session': session,
                          'Accuracy': acc,
                          'confusion_matrix': cm, 
                          'history_list': history})    

  0%|                                                                                           | 0/10 [00:00<?, ?it/s]


[31m"Current session: S1"[0m

Preprocess and filter the EMG data
Training Data
Processing subject: 0001

Validation Data
Processing subject: 0001

Test Data
Processing subject: 0001


         Start training           
Shape of training data (3083, 52, 8, 1)
Shape of validation data (240, 52, 8, 1)
Shape of test data (886, 52, 8, 1)
Accuracy of trial 1: 0.92

Preprocess and filter the EMG data
Training Data
Processing subject: 0001

Validation Data
Processing subject: 0001

Test Data
Processing subject: 0001


         Start training           
Shape of training data (3079, 52, 8, 1)
Shape of validation data (240, 52, 8, 1)
Shape of test data (890, 52, 8, 1)
Accuracy of trial 2: 1.00

Preprocess and filter the EMG data
Training Data
Processing subject: 0001

Validation Data
Processing subject: 0001

Test Data
Processing subject: 0001


         Start training           
Shape of training data (3061, 52, 8, 1)
Shape of validation data (240, 52, 8, 1)
Shape of test data (908, 52, 8, 1

 10%|████████▏                                                                         | 1/10 [01:46<15:56, 106.31s/it]

Accuracy of trial 3: 0.98

[31m"Current session: S1"[0m

Preprocess and filter the EMG data
Training Data
Processing subject: 1001

Validation Data
Processing subject: 1001

Test Data
Processing subject: 1001


         Start training           
Shape of training data (1689, 52, 8, 1)
Shape of validation data (120, 52, 8, 1)
Shape of test data (962, 52, 8, 1)
Accuracy of trial 1: 0.75

Preprocess and filter the EMG data
Training Data
Processing subject: 1001

Validation Data
Processing subject: 1001

Test Data
Processing subject: 1001


         Start training           
Shape of training data (1689, 52, 8, 1)
Shape of validation data (120, 52, 8, 1)
Shape of test data (962, 52, 8, 1)
Accuracy of trial 2: 0.96

Preprocess and filter the EMG data
Training Data
Processing subject: 1001

Validation Data
Processing subject: 1001

Test Data
Processing subject: 1001


         Start training           
Shape of training data (1684, 52, 8, 1)
Shape of validation data (120, 52, 8, 1)
Shape o

 20%|████████████████▍                                                                 | 2/10 [03:32<14:07, 105.99s/it]


[31m"Current session: S1"[0m

Preprocess and filter the EMG data
Training Data
Processing subject: 1003

Validation Data
Processing subject: 1003

Test Data
Processing subject: 1003


         Start training           
Shape of training data (1677, 52, 8, 1)
Shape of validation data (120, 52, 8, 1)
Shape of test data (970, 52, 8, 1)
Accuracy of trial 1: 0.98

Preprocess and filter the EMG data
Training Data
Processing subject: 1003

Validation Data
Processing subject: 1003

Test Data
Processing subject: 1003


         Start training           
Shape of training data (1685, 52, 8, 1)
Shape of validation data (120, 52, 8, 1)
Shape of test data (962, 52, 8, 1)
Accuracy of trial 2: 0.98

Preprocess and filter the EMG data
Training Data
Processing subject: 1003

Validation Data
Processing subject: 1003

Test Data
Processing subject: 1003


         Start training           
Shape of training data (1692, 52, 8, 1)
Shape of validation data (120, 52, 8, 1)
Shape of test data (955, 52, 8, 1

 30%|████████████████████████▌                                                         | 3/10 [05:37<13:24, 114.95s/it]


[31m"Current session: S1"[0m

Preprocess and filter the EMG data
Training Data
Processing subject: 1004

Validation Data
Processing subject: 1004

Test Data
Processing subject: 1004


         Start training           
Shape of training data (1677, 52, 8, 1)
Shape of validation data (120, 52, 8, 1)
Shape of test data (951, 52, 8, 1)
Accuracy of trial 1: 0.83

Preprocess and filter the EMG data
Training Data
Processing subject: 1004

Validation Data
Processing subject: 1004

Test Data
Processing subject: 1004


         Start training           
Shape of training data (1668, 52, 8, 1)
Shape of validation data (120, 52, 8, 1)
Shape of test data (960, 52, 8, 1)
Accuracy of trial 2: 1.00

Preprocess and filter the EMG data
Training Data
Processing subject: 1004

Validation Data
Processing subject: 1004

Test Data
Processing subject: 1004


         Start training           
Shape of training data (1671, 52, 8, 1)
Shape of validation data (120, 52, 8, 1)
Shape of test data (957, 52, 8, 1

 40%|████████████████████████████████▊                                                 | 4/10 [07:36<11:39, 116.65s/it]


[31m"Current session: S1"[0m

Preprocess and filter the EMG data
Training Data
Processing subject: 1005

Validation Data
Processing subject: 1005

Test Data
Processing subject: 1005


         Start training           
Shape of training data (1681, 52, 8, 1)
Shape of validation data (120, 52, 8, 1)
Shape of test data (964, 52, 8, 1)
Accuracy of trial 1: 0.99

Preprocess and filter the EMG data
Training Data
Processing subject: 1005

Validation Data
Processing subject: 1005

Test Data
Processing subject: 1005


         Start training           
Shape of training data (1685, 52, 8, 1)
Shape of validation data (120, 52, 8, 1)
Shape of test data (960, 52, 8, 1)
Accuracy of trial 2: 0.98

Preprocess and filter the EMG data
Training Data
Processing subject: 1005

Validation Data
Processing subject: 1005

Test Data
Processing subject: 1005


         Start training           
Shape of training data (1684, 52, 8, 1)
Shape of validation data (120, 52, 8, 1)
Shape of test data (961, 52, 8, 1

 50%|█████████████████████████████████████████                                         | 5/10 [09:37<09:50, 118.13s/it]


[31m"Current session: S1"[0m

Preprocess and filter the EMG data
Training Data
Processing subject: 1006

Validation Data
Processing subject: 1006

Test Data
Processing subject: 1006


         Start training           
Shape of training data (1683, 52, 8, 1)
Shape of validation data (120, 52, 8, 1)
Shape of test data (962, 52, 8, 1)
Accuracy of trial 1: 0.78

Preprocess and filter the EMG data
Training Data
Processing subject: 1006

Validation Data
Processing subject: 1006

Test Data
Processing subject: 1006


         Start training           
Shape of training data (1681, 52, 8, 1)
Shape of validation data (120, 52, 8, 1)
Shape of test data (964, 52, 8, 1)
Accuracy of trial 2: 0.91

Preprocess and filter the EMG data
Training Data
Processing subject: 1006

Validation Data
Processing subject: 1006

Test Data
Processing subject: 1006


         Start training           
Shape of training data (1686, 52, 8, 1)
Shape of validation data (120, 52, 8, 1)
Shape of test data (959, 52, 8, 1

 60%|█████████████████████████████████████████████████▏                                | 6/10 [11:44<08:03, 120.99s/it]


[31m"Current session: S1"[0m

Preprocess and filter the EMG data
Training Data
Processing subject: 1007

Validation Data
Processing subject: 1007

Test Data
Processing subject: 1007


         Start training           
Shape of training data (1825, 52, 8, 1)
Shape of validation data (130, 52, 8, 1)
Shape of test data (955, 52, 8, 1)
Accuracy of trial 1: 0.85

Preprocess and filter the EMG data
Training Data
Processing subject: 1007

Validation Data
Processing subject: 1007

Test Data
Processing subject: 1007


         Start training           
Shape of training data (1819, 52, 8, 1)
Shape of validation data (130, 52, 8, 1)
Shape of test data (961, 52, 8, 1)
Accuracy of trial 2: 0.97

Preprocess and filter the EMG data
Training Data
Processing subject: 1007

Validation Data
Processing subject: 1007

Test Data
Processing subject: 1007


         Start training           
Shape of training data (1817, 52, 8, 1)
Shape of validation data (130, 52, 8, 1)
Shape of test data (963, 52, 8, 1

 70%|█████████████████████████████████████████████████████████▍                        | 7/10 [14:01<06:19, 126.33s/it]


[31m"Current session: S1"[0m

Preprocess and filter the EMG data
Training Data
Processing subject: 1111

Validation Data
Processing subject: 1111

Test Data
Processing subject: 1111


         Start training           
Shape of training data (2264, 52, 8, 1)
Shape of validation data (120, 52, 8, 1)
Shape of test data (1251, 52, 8, 1)
Accuracy of trial 1: 0.64

Preprocess and filter the EMG data
Training Data
Processing subject: 1111

Validation Data
Processing subject: 1111

Test Data
Processing subject: 1111


         Start training           
Shape of training data (2268, 52, 8, 1)
Shape of validation data (120, 52, 8, 1)
Shape of test data (1247, 52, 8, 1)
Accuracy of trial 2: 0.83

Preprocess and filter the EMG data
Training Data
Processing subject: 1111

Validation Data
Processing subject: 1111

Test Data
Processing subject: 1111


         Start training           
Shape of training data (2258, 52, 8, 1)
Shape of validation data (120, 52, 8, 1)
Shape of test data (1257, 52, 8

 80%|█████████████████████████████████████████████████████████████████▌                | 8/10 [16:21<04:21, 130.70s/it]


[31m"Current session: S1"[0m

Preprocess and filter the EMG data
Training Data
Processing subject: 1234

Validation Data
Processing subject: 1234

Test Data
Processing subject: 1234


         Start training           
Shape of training data (1672, 52, 8, 1)
Shape of validation data (120, 52, 8, 1)
Shape of test data (964, 52, 8, 1)
Accuracy of trial 1: 0.94

Preprocess and filter the EMG data
Training Data
Processing subject: 1234

Validation Data
Processing subject: 1234

Test Data
Processing subject: 1234


         Start training           
Shape of training data (1682, 52, 8, 1)
Shape of validation data (120, 52, 8, 1)
Shape of test data (954, 52, 8, 1)
Accuracy of trial 2: 1.00

Preprocess and filter the EMG data
Training Data
Processing subject: 1234

Validation Data
Processing subject: 1234

Test Data
Processing subject: 1234


         Start training           
Shape of training data (1678, 52, 8, 1)
Shape of validation data (120, 52, 8, 1)
Shape of test data (958, 52, 8, 1

 90%|█████████████████████████████████████████████████████████████████████████▊        | 9/10 [19:15<02:24, 144.33s/it]


[31m"Current session: S1"[0m

Preprocess and filter the EMG data
Training Data
Processing subject: 9999

Validation Data
Processing subject: 9999

Test Data
Processing subject: 9999


         Start training           
Shape of training data (1934, 52, 8, 1)
Shape of validation data (120, 52, 8, 1)
Shape of test data (1097, 52, 8, 1)
Accuracy of trial 1: 0.99

Preprocess and filter the EMG data
Training Data
Processing subject: 9999

Validation Data
Processing subject: 9999

Test Data
Processing subject: 9999


         Start training           
Shape of training data (1942, 52, 8, 1)
Shape of validation data (120, 52, 8, 1)
Shape of test data (1089, 52, 8, 1)
Accuracy of trial 2: 0.98

Preprocess and filter the EMG data
Training Data
Processing subject: 9999

Validation Data
Processing subject: 9999

Test Data
Processing subject: 9999


         Start training           
Shape of training data (1946, 52, 8, 1)
Shape of validation data (120, 52, 8, 1)
Shape of test data (1085, 52, 8

100%|█████████████████████████████████████████████████████████████████████████████████| 10/10 [23:12<00:00, 139.21s/it]

Wall time: 23min 12s





In [10]:
pd.options.display.float_format = "{:,.3f}".format
# stats_df = pd.DataFrame(stats).set_index('ID')
# stats_df.iloc[:, :-1].to_pickle('Stats_within_session.pkl')
stats_df = pd.read_pickle('Stats_within_session.pkl')
acc_list = stats_df['Accuracy']

# Confusion matrix of all subjects 
(x axis is the ground truth and y axis is the predicted labels)

In [12]:
print('Accuracy %.2f +/- %.2f' %(np.mean(acc_list), np.std(acc_list)))

Accuracy 0.94 +/- 0.07


In [14]:
Matrix_1 = stats_df['confusion_matrix'].sum().iloc[:, :6]
Matrix_1 = Matrix_1 / Matrix_1.astype('float').sum(axis = 1)
Matrix_1 * 100
# fig = plt.figure(figsize = (12, 7))
# sns.heatmap(Matrix_1, cmap = 'Blues', annot=True, fmt=".2%", linewidths=1.0, annot_kws={"fontsize":14})
# plt.title('Confusion matrix', fontsize = 18)
# plt.ylabel('Target', labelpad = 30)
# plt.xlabel('Predicted', labelpad = 30)
# plt.tight_layout()

Unnamed: 0,Cylinder Grasp,Wrist Extension,Fist,Finger mass extension,Opposition,Lateral Pinch
Cylinder Grasp,88.415,0.05,0.34,4.492,4.172,2.494
Wrist Extension,0.02,96.575,0.05,0.35,2.519,0.489
Fist,0.651,0.01,96.156,0.2,0.0,2.973
Finger mass extension,4.906,0.02,0.11,93.277,0.637,1.047
Opposition,1.352,0.757,0.771,0.17,93.289,3.681
Lateral Pinch,0.971,0.617,0.39,0.42,3.156,94.444
