In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from model import neural_model as nm
import pandas as pd
from sklearn.model_selection import train_test_split

In [2]:
df = pd.read_csv('final_extracted_a0f66459.csv')
electrode_col_names = [col for col in df.columns if 'GRID' in col ]
X = df[electrode_col_names]
Y = df['mvmt']
Y = (Y=='r_arm_1').astype(int)
num_samples = len(Y)

train_X, test_X, train_Y, test_Y = train_test_split(X, Y,stratify=Y, test_size=0.20)
test_X, val_X, test_Y, val_Y = train_test_split(test_X, test_Y, stratify=test_Y, test_size=0.50)

#train_X = X[:int(num_samples*.8)]
#test_X = X[int(num_samples*.8):int(num_samples*.9)]
#val_X = X[int(num_samples*.9):]
#train_Y = Y[:int(num_samples*.8)]
#test_Y = Y[int(num_samples*.8):int(num_samples*.9)]
#val_Y = Y[int(num_samples*.9):]

In [3]:
class ecog_dataset(Dataset):
    def __init__(self, x, y):
        self.data = x
        self.keys = y
        
    def __len__(self):
        return len(self.keys)
    
    def __getitem__(self, index):
        return torch.tensor(self.data.iloc[index]), torch.tensor([self.keys.iloc[index]])

In [4]:
trainset = ecog_dataset(train_X, train_Y)
testset = ecog_dataset(test_X, test_Y)
valset = ecog_dataset(val_X, val_Y)

trainloader = DataLoader(trainset, batch_size=100, shuffle=True) #dataset to train on (80% data)
testloader = DataLoader(testset, batch_size=100, shuffle=True) #dataset to test each epoch on (10% data)
validloader = DataLoader(valset, batch_size=100, shuffle=True) #dataset to validate perf on (10% data)

In [5]:
#define network and params
net = nm(64)

loss_func = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(net.parameters(), lr=.001)
best_loss = 99999
best_net = net

In [6]:
import copy
#train
for epoch in range(60):
    running_loss = 0.0
    train_count = 0
    #net = copy.deepcopy(best_net)
    for i, data in enumerate(trainloader, 0):
        net.train()
        inputs, labels = data
        #if(labels.item() == 1):
        optimizer.zero_grad()
        outputs = net(inputs.float())
        #print(outputs)
        loss = loss_func(outputs.float(), labels.float())
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        train_count += 1
    #do epoch test on separate test set
    test_loss = 0.0
    test_count = 0
    for i, data in enumerate(testloader, 0):
        inputs, labels = data
        outputs = net(inputs.float())
        loss = loss_func(torch.round(outputs), labels.float())
        test_loss += loss.item()
        test_count += 1
    if test_loss/test_count < best_loss:
        best_loss = test_loss/test_count
        best_net = copy.deepcopy(net)
        print('new best model!')
    print("Epoch: ", epoch, 'train loss: ', running_loss/train_count, 
          'test loss: ', test_loss/test_count)

new best model!
Epoch:  0 train loss:  0.6531955940382821 test loss:  0.6931471526622772
new best model!
Epoch:  1 train loss:  0.655912492956434 test loss:  0.672085165977478
Epoch:  2 train loss:  0.6521646082401276 test loss:  0.6814969480037689
Epoch:  3 train loss:  0.6533336213656834 test loss:  0.6767911016941071
new best model!
Epoch:  4 train loss:  0.6518688542502267 test loss:  0.660320520401001
Epoch:  5 train loss:  0.6521834305354527 test loss:  0.6697321832180023
Epoch:  6 train loss:  0.6534948987620217 test loss:  0.6720851957798004
Epoch:  7 train loss:  0.6507259437016079 test loss:  0.6673793196678162
Epoch:  8 train loss:  0.6508706339768001 test loss:  0.6697322130203247
Epoch:  9 train loss:  0.6485874312264579 test loss:  0.6720851957798004
Epoch:  10 train loss:  0.6477182635239193 test loss:  0.6626733541488647
Epoch:  11 train loss:  0.6459807881287166 test loss:  0.6767910420894623
Epoch:  12 train loss:  0.6437421824250903 test loss:  0.681632936000824
new 

In [7]:
#validate
validate_loss = 0.0
validate_count = 0
y_pred = []
y_true = []
for i, data in enumerate(validloader, 0):
    best_net.eval()
    inputs, labels = data
    outputs = best_net(inputs.float())
    loss = loss_func(outputs, labels.float())
    for item in outputs:
        y_pred.append(item.item())
    for item in labels:
        y_true.append(item.item())
    validate_loss += loss.item()
    validate_count += 1
print(y_true)
print(y_pred)
print('validation loss: ', validate_loss / validate_count)

[1, 0, 0, 1, 1, 1, 0, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 1, 0, 1, 1, 1, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 1, 0, 0, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 0, 0, 1, 0, 1, 1, 1, 0, 1, 0, 0, 1, 1, 1, 0, 0, 1, 1, 1, 0, 1, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 0, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 0, 0, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 0, 1]
[0.25887417793273926, 0.2576909363269806, 0.2576003670692444, 0.25725695490837097, 1.0, 0.5215379595756531, 0.25718677043914795, 1.0, 0.4778161942958832, 0.25729987025260925, 1.0, 1.0, 0.600465714931488, 0.27381980419158936, 0.2725762724876404, 0.2577058672904968, 1.0, 0.25732579827308655, 0.25777342915534973, 0.2562524378299713, 1.0, 0.3278510868549347, 0.25738486647605896, 0.25768572092056274, 0.44001269340515137, 0.44625625014305115, 0.25783470273017883, 0.38733914494514465, 0.7376481890

In [8]:
from sklearn.metrics import balanced_accuracy_score
y_pred_round = []
for pred in y_pred: #keep the network outputs within acceptable bounds
    y_pred_round.append(min(max(round(pred), 0),1))
print(y_pred_round)
print(y_true)
balanced_accuracy_score(y_true, y_pred_round)

[0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 1, 1, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 0, 0, 1, 1, 0, 0, 0, 0, 1, 1, 0, 0, 0, 1, 0, 0, 0, 0, 1, 1, 0, 0, 1, 1, 1, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 1, 1, 1, 0, 1, 0, 0, 0, 1, 0, 0, 1, 1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 1, 0, 1]
[1, 0, 0, 1, 1, 1, 0, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 1, 0, 1, 1, 1, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 1, 0, 0, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 0, 0, 1, 0, 1, 1, 1, 0, 1, 0, 0, 1, 1, 1, 0, 0, 1, 1, 1, 0, 1, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 0, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 0, 0, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1,

0.6833154588631837

In [9]:
torch.save(best_net.state_dict(), 'curr_best.pyt')