In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.metrics import accuracy_score
from sklearn.metrics import classification_report

In [2]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [3]:
#%cd /content/drive/MyDrive/Georgetown/BrainDecoding

/content/drive/MyDrive/Georgetown/BrainDecoding


In [4]:
class LSTMModel(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, output_size):
        super(LSTMModel, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True,bidirectional=True)
        self.fc = nn.Linear(hidden_size*2, output_size)
        self.activation = nn.Sigmoid()
    def forward(self, x):
        h0 = torch.zeros(self.num_layers*2, x.size(0), self.hidden_size).to(x.device)
        # Initialize cell state with zeros
        c0 = torch.zeros(self.num_layers*2, x.size(0), self.hidden_size).to(x.device)
        out, _ = self.lstm(x, (h0, c0))  # shape (batch_size, seq_length, hidden_size)
        out = self.fc(out[:, -1, :])  # shape (batch_size, output_size)
        out = self.activation(out)
        return out


In [5]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [6]:
input_size = 200  # Dimensionality of each image vector
hidden_size = 128
num_layers = 8
output_size = 1
lr = 0.001
num_epochs = 100

In [7]:
model = LSTMModel(input_size, hidden_size, num_layers, output_size).to(device)
pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(pytorch_total_params)

criterion = nn.BCELoss()
optimizer = optim.Adam(model.parameters(), lr=lr)


3105025


In [8]:
Xtrain = torch.load('preprocessed_data/Xtrain.pt').to(device)
ytrain = torch.load('preprocessed_data/ytrain.pt').unsqueeze(1).to(device)
Xdev = torch.load('preprocessed_data/Xdev.pt').to(device)
ydev= torch.load('preprocessed_data/ydev.pt').unsqueeze(1).to(device)
Xtest = torch.load('preprocessed_data/Xtest.pt').to(device)
ytest= torch.load('preprocessed_data/ytest.pt').unsqueeze(1).to(device)

In [9]:

for epoch in range(num_epochs):
    model.train()
    outputs = model(Xtrain)
    loss = criterion(outputs, ytrain)
    #optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')

    with torch.no_grad():
        model.eval()
        op = model(Xdev).squeeze()
        op_cpu = op.cpu()
        ydev_cpu = ydev.cpu()
        preds = [int(o.item()>0.5) for o in op_cpu]
        print(accuracy_score(ydev_cpu,preds))

Epoch [1/100], Loss: 0.6931
0.48605119355766463
Epoch [2/100], Loss: 0.6933
0.48605119355766463
Epoch [3/100], Loss: 0.6931
0.5139488064423353
Epoch [4/100], Loss: 0.6930
0.5139488064423353
Epoch [5/100], Loss: 0.6931
0.5139488064423353
Epoch [6/100], Loss: 0.6933
0.5139488064423353
Epoch [7/100], Loss: 0.6930
0.5139488064423353
Epoch [8/100], Loss: 0.6921
0.5139488064423353
Epoch [9/100], Loss: 0.6902
0.5148116192119644
Epoch [10/100], Loss: 0.6871
0.5257405809605982
Epoch [11/100], Loss: 0.6826
0.5349439171699741
Epoch [12/100], Loss: 0.6785
0.5329306873741732
Epoch [13/100], Loss: 0.6793
0.541558815070463
Epoch [14/100], Loss: 0.6805
0.5524877768190969
Epoch [15/100], Loss: 0.6776
0.5536381938452689
Epoch [16/100], Loss: 0.6811
0.5579522576934138
Epoch [17/100], Loss: 0.6869
0.5616911130284729
Epoch [18/100], Loss: 0.6847
0.5634167385677308
Epoch [19/100], Loss: 0.6788
0.5634167385677308
Epoch [20/100], Loss: 0.6778
0.5631291343111878
Epoch [21/100], Loss: 0.6823
0.48605119355766463

In [12]:
with torch.no_grad():
      model.eval()
      op = model(Xdev).squeeze()
      op_cpu = op.cpu()
      ydev_cpu = ydev.cpu()
      preds = [int(o.item()>0.5) for o in op_cpu]
      print(accuracy_score(ydev_cpu,preds))
      print(classification_report(ydev_cpu,preds))

0.542134023583549
              precision    recall  f1-score   support

         0.0       0.54      0.42      0.47      1690
         1.0       0.55      0.66      0.60      1787

    accuracy                           0.54      3477
   macro avg       0.54      0.54      0.53      3477
weighted avg       0.54      0.54      0.54      3477



In [17]:
with torch.no_grad():
      model.eval()
      op = model(Xtest).squeeze()
      op_cpu = op.cpu()
      ytest_cpu = ytest.cpu()
      preds = [int(o.item()>0.5) for o in op_cpu]
      print(accuracy_score(ytest_cpu,preds))
      print(classification_report(ytest_cpu,preds))

0.5602154674874952
              precision    recall  f1-score   support

         0.0       0.61      0.52      0.56      2816
         1.0       0.52      0.61      0.56      2382

    accuracy                           0.56      5198
   macro avg       0.56      0.56      0.56      5198
weighted avg       0.57      0.56      0.56      5198

