In [52]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from model import neural_model as nm
import pandas as pd

In [53]:
df = pd.read_csv('final_extracted_a0f66459.csv')
electrode_col_names = [col for col in df.columns if 'GRID' in col ]
X = df[electrode_col_names]
Y = df['mvmt']
Y = (Y=='r_arm_1').astype(int)
num_samples = len(Y)

train_X = X[:int(num_samples*.8)]
test_X = X[int(num_samples*.8):int(num_samples*.9)]
val_X = X[int(num_samples*.9):]
train_Y = Y[:int(num_samples*.8)]
test_Y = Y[int(num_samples*.8):int(num_samples*.9)]
val_Y = Y[int(num_samples*.9):]

In [62]:
class ecog_dataset(Dataset):
    def __init__(self, x, y):
        self.data = x
        self.keys = y
        
    def __len__(self):
        return len(self.keys)
    
    def __getitem__(self, index):
        return torch.tensor(self.data.iloc[index]), torch.tensor([self.keys.iloc[index]])

In [63]:
trainset = ecog_dataset(train_X, train_Y)
testset = ecog_dataset(test_X, test_Y)
valset = ecog_dataset(val_X, val_Y)

trainloader = DataLoader(trainset, batch_size=1, shuffle=False) #dataset to train on (80% data)
testloader = DataLoader(testset, batch_size=1, shuffle=False) #dataset to test each epoch on (10% data)
validloader = DataLoader(valset, batch_size=1, shuffle=False) #dataset to validate perf on (10% data)

In [94]:
#define network and params
net = nm(64)

loss_func = nn.BCELoss()
optimizer = optim.Adam(net.parameters(), lr=.1)

In [95]:
#train
for epoch in range(10):
    running_loss = 0.0
    train_count = 0
    for i, data in enumerate(trainloader):
        inputs, labels = data
        optimizer.zero_grad()
        outputs = net(inputs.float())
        loss = loss_func(outputs.float(), labels.float())
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        train_count += 1
    #do epoch test on separate test set
    test_loss = 0.0
    test_count = 0
    for i, data in enumerate(testloader, 0):
        inputs, labels = data
        outputs = net(inputs.float())
        loss = loss_func(outputs, labels.float())
        test_loss += loss.item()
        test_count += 1
    print("Epoch: ", epoch, 'train loss: ', running_loss/train_count, 
          'test loss: ', test_loss/test_count)

Epoch:  0 train loss:  61.028315946348734 test loss:  61.30952380952381
Epoch:  1 train loss:  61.028315946348734 test loss:  61.30952380952381
Epoch:  2 train loss:  61.028315946348734 test loss:  61.30952380952381
Epoch:  3 train loss:  61.028315946348734 test loss:  61.30952380952381
Epoch:  4 train loss:  61.028315946348734 test loss:  61.30952380952381
Epoch:  5 train loss:  61.028315946348734 test loss:  61.30952380952381
Epoch:  6 train loss:  61.028315946348734 test loss:  61.30952380952381
Epoch:  7 train loss:  61.028315946348734 test loss:  61.30952380952381
Epoch:  8 train loss:  61.028315946348734 test loss:  61.30952380952381
Epoch:  9 train loss:  61.028315946348734 test loss:  61.30952380952381


In [98]:
#validate
validate_loss = 0.0
validate_count = 0
y_pred = []
y_true = []
for i, data in enumerate(validloader, 0):
    inputs, labels = data
    outputs = net(inputs.float())
    loss = loss_func(outputs, labels.float())
    y_pred.append(outputs.item())
    y_true.append(labels.item())
    validate_loss += loss.item()
    validate_count += 1
print(y_true)
print(y_pred)
print('validation loss: ', validate_loss / validate_count)

[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,

In [99]:
from sklearn.metrics import balanced_accuracy_score
balanced_accuracy_score(y_true, y_pred)

0.5