In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from sklearn.metrics import accuracy_score
from sklearn.metrics import classification_report

In [2]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [3]:
#%cd /content/drive/MyDrive/Georgetown/BrainDecoding

/content/drive/MyDrive/Georgetown/BrainDecoding


In [4]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [5]:
Xtrain = torch.load('preprocessed_data/Xtrain.pt').to(device)
ytrain = torch.load('preprocessed_data/ytrain.pt').unsqueeze(1).to(device)
Xdev = torch.load('preprocessed_data/Xdev.pt').to(device)
ydev= torch.load('preprocessed_data/ydev.pt').unsqueeze(1).to(device)

In [6]:
# Class for custom RNN model
class RNNModel(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(RNNModel, self).__init__()
        self.hidden_size = hidden_size
        self.rnn = nn.RNN(input_size, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)
        self.activation = nn.Sigmoid()

    def forward(self, x):
        h0 = torch.zeros(1, x.size(0), self.hidden_size).to(x.device)
        out, _ = self.rnn(x, h0)
        out = self.fc(out[:, -1, :])  # Select the last time step output because this is sequence classification
        out = self.activation(out)
        return out

In [7]:
# Hyperparams
vector_length = 200
input_size = vector_length
hidden_size = 64
output_size = 1
epochs = 1000
lr = 0.001

In [8]:
# Create instance of RNN model
model = RNNModel(input_size, hidden_size, output_size).to(device)
pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(pytorch_total_params)
criterion = nn.BCELoss()
optimizer = optim.Adam(model.parameters(), lr=lr)

# Training loop
for epoch in range(epochs):
    model.train()
    optimizer.zero_grad()
    output = model(Xtrain)
    loss = criterion(output, ytrain)
    loss.backward()
    optimizer.step()
    print(f'Epoch {epoch+1}/{epochs}, Loss: {loss.item()}')

# Evaluate the model on dev after each epoch
    with torch.no_grad():
        model.eval()
        op = model(Xdev)
        op = op.cpu()
        preds = [int(o.item()>0.5) for o in op]
        print(accuracy_score(ydev.cpu(),preds))



17089
Epoch 1/1000, Loss: 0.7417199611663818
0.5047454702329595
Epoch 2/1000, Loss: 0.7369950413703918
0.5079091170549324
Epoch 3/1000, Loss: 0.732903778553009
0.5087719298245614
Epoch 4/1000, Loss: 0.7288743853569031
0.5099223468507333
Epoch 5/1000, Loss: 0.725101888179779
0.5127983894161634
Epoch 6/1000, Loss: 0.7217933535575867
0.5090595340811044
Epoch 7/1000, Loss: 0.7189183831214905
0.5125107851596203
Epoch 8/1000, Loss: 0.7164000272750854
0.5087719298245614
Epoch 9/1000, Loss: 0.7141180634498596
0.5047454702329595
Epoch 10/1000, Loss: 0.7118733525276184
0.5021570319240725
Epoch 11/1000, Loss: 0.7102100849151611
0.5050330744895024
Epoch 12/1000, Loss: 0.7086242437362671
0.5030198446937014
Epoch 13/1000, Loss: 0.7072809934616089
0.5073339085418465
Epoch 14/1000, Loss: 0.7061548233032227
0.5070463042853034
Epoch 15/1000, Loss: 0.7049278020858765
0.5073339085418465
Epoch 16/1000, Loss: 0.7040411233901978
0.5061834915156744
Epoch 17/1000, Loss: 0.7031784653663635
0.5058958872591315
Ep

In [9]:
# Load in test data
Xtest = torch.load('preprocessed_data/Xtest.pt').to(device)
ytest= torch.load('preprocessed_data/ytest.pt').unsqueeze(1).to(device)

In [10]:
# Get f, recall, precision on Dev
with torch.no_grad():
        model.eval()
        op = model(Xdev)
        op = op.cpu()
        preds = [int(o.item()>0.5) for o in op]
        print(accuracy_score(ydev.cpu(),preds))
        print(classification_report(ydev.cpu(),preds))

0.5125107851596203
              precision    recall  f1-score   support

         0.0       0.50      0.50      0.50      1690
         1.0       0.53      0.52      0.52      1787

    accuracy                           0.51      3477
   macro avg       0.51      0.51      0.51      3477
weighted avg       0.51      0.51      0.51      3477



In [11]:
# Get f, recall, precision on Test

with torch.no_grad():
        model.eval()
        op = model(Xtest)
        op = op.cpu()
        preds = [int(o.item()>0.5) for o in op]
        print(accuracy_score(ytest.cpu(),preds))
        print(classification_report(ytest.cpu(),preds))

0.5069257406694883
              precision    recall  f1-score   support

         0.0       0.55      0.49      0.52      2816
         1.0       0.47      0.53      0.50      2382

    accuracy                           0.51      5198
   macro avg       0.51      0.51      0.51      5198
weighted avg       0.51      0.51      0.51      5198

