In [None]:
import os
from copy import deepcopy

import matplotlib.pyplot as plt
import numpy as np
import torch
from sklearn.decomposition import PCA
from sklearn.metrics import accuracy_score, balanced_accuracy_score
from sklearn.model_selection import KFold
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier
from torch import nn
from torch.utils.data import DataLoader, TensorDataset
from torchvision import models
from tqdm.auto import tqdm

from model import BendrEncoder
from model.model import Flatten
from net1d import Net1D
from src.data.conf.eeg_annotations import braincapture_annotations

# Seed every random number generator used below (Python, PyTorch, NumPy) with
# the same fixed value so that shuffling and weight initialisation are
# repeatable from a fresh kernel.
import random

random.seed(1)
torch.manual_seed(1)
# Kept last on purpose: it returns None, so the cell displays no output.
np.random.seed(1)

In [None]:
import logging

# Raise the threshold of the 'mne' logger to ERROR so that its informational
# and warning messages do not flood the notebook output.
mne_logger = logging.getLogger("mne")
mne_logger.setLevel(logging.ERROR)

In [None]:
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Instantiate the encoder with its default configuration.
encoder = BendrEncoder()

# Restore the pretrained weights. `load_state_dict` copies the loaded values
# into the encoder's own parameters, so deep-copying the state dict first is
# unnecessary and only doubled the peak memory use. `map_location` remaps the
# stored tensors onto `device` while loading.
# NOTE(review): torch.load unpickles the file -- only load checkpoints that
# come from a trusted source.
encoder.load_state_dict(torch.load("encoder.pt", map_location=device))
encoder = encoder.to(device)

In [None]:
def evaluate_model(model, test_loader, device, optimizer=None):
    """Return the classification accuracy (in percent) of `model` on `test_loader`.

    The model is switched to eval mode for the pass and its previous
    train/eval mode is restored afterwards, so calling this function between
    training epochs no longer leaves Dropout/BatchNorm layers in eval mode.

    Args:
        model: Module mapping a batch `X` to logits with the classes along
            dimension 1.
        test_loader: Sized iterable yielding `(X, y)` batches.
        device: Device the batches are moved to before the forward pass.
        optimizer: Unused. Accepted only so existing four-argument calls keep
            working; the pass runs under `torch.no_grad()`, so there are no
            gradients to reset.

    Returns:
        float: `100 * correct / total` over the evaluated samples, or NaN when
        no sample was evaluated (empty loader, or only single-sample batches).
    """
    was_training = model.training
    model.eval()
    total = correct = 0
    try:
        with torch.no_grad(), tqdm(total=len(test_loader), desc="Testing...") as pbar:
            for X, y in test_loader:
                pbar.update(1)
                # Batches with fewer than two samples are skipped (unchanged
                # from the previous behaviour).
                if len(X) < 2:
                    continue

                X, y = X.to(device), y.to(device)
                predicted = model(X).argmax(dim=1)

                total += y.size(0)
                correct += (predicted == y).sum().item()
    finally:
        # Hand the model back in the mode the caller had it in.
        model.train(was_training)

    if total == 0:
        # Nothing was evaluated; avoid a ZeroDivisionError.
        return float("nan")
    return 100 * correct / total

In [None]:
# Assemble the binary data set: samples from "X_augment_data.pt" get label 1,
# samples from "Z_data.pt" get label 0.
X_true = torch.load("X_augment_data.pt").to(device)
# Easy fix / hack: the positive feature vectors have length 1537 while the
# negative ones have 1536, so the trailing element is dropped.
X_true = X_true[:, :, :-1]
y_true = torch.ones(len(X_true))
print(X_true.shape, y_true.shape)

X_false = torch.load("Z_data.pt").to(device)
# Shuffle first, then keep only the second half: discarding half of the
# negatives helps with the class imbalance, and the shuffle makes the kept
# half a random subset.
X_false = X_false[torch.randperm(len(X_false))][len(X_false) // 2:]
y_false = torch.zeros(len(X_false))
print(X_false.shape, y_false.shape)

# Stack both classes along the sample dimension.
X = torch.cat([X_true, X_false])
y = torch.cat([y_true, y_false])
print(X.shape, y.shape)

In [None]:
batch_size = 16

# Separately stored samples and labels used as a validation set.
X_val = torch.load("X_data.pt").to(device)
y_val = torch.load("y_data.pt").to(device)

train_dataset = TensorDataset(X, y)
val_dataset = TensorDataset(X_val, y_val)

# Hold out 20% of the binary data set as a test split.
train_dataset, test_dataset = train_test_split(train_dataset, test_size=0.2)

# A DataLoader must wrap a data set, never another DataLoader: a wrapped loader
# cannot be indexed (iteration fails) and reports a wrong number of batches.
# Only the training loader shuffles; evaluation does not depend on sample order.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)


In [None]:
learning_rate = 0.0001
weight_decay = 0.01
n_epochs = 5
n_splits = 5


def create_model(out_features=2):
    """Build the classifier: a copy of the pretrained encoder plus a small head.

    The encoder is deep-copied so that training the returned model does not
    modify the module-level `encoder`.

    Args:
        out_features: Number of output logits (2 for the binary task).

    Returns:
        nn.Sequential: The model, already moved to `device`.
    """
    return nn.Sequential(
        deepcopy(encoder),
        Flatten(),
        nn.Linear(in_features=3 * 512 * 4, out_features=512 * 4, bias=True),
        nn.Dropout(p=0.5, inplace=False),
        nn.ReLU(),
        nn.BatchNorm1d(512 * 4),
        nn.Linear(512 * 4, out_features, bias=True),
    ).to(device)


model = create_model()

optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
# The per-class weights must live on the same device as the logits, otherwise
# the loss computation fails with a device mismatch on GPU.
criterion = nn.BCEWithLogitsLoss(weight=torch.tensor([0.1, 0.9], device=device))

# One scheduler step per batch: `steps_per_epoch` is the number of training batches.
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer, max_lr=learning_rate, epochs=n_epochs, steps_per_epoch=len(train_loader), pct_start=0.1, last_epoch=-1
)

In [None]:
kf = KFold(n_splits=n_splits)
accuracies = []  # validation accuracy (%) of each fold after its last epoch
all_preds = []
all_labels = []

# Optional warm start, applied only when the file exists so that the cell also
# runs from a clean checkout.
# NOTE(review): every fold starts from this one checkpoint, so it must not have
# been trained on data a fold later validates on -- confirm that this is intended.
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
warm_start_path = "models/binary_model_fold1_epoch5.pt"
os.makedirs("models", exist_ok=True)  # torch.save does not create directories

for fold, (train_index, val_index) in enumerate(kf.split(train_dataset)):
    print(f'Fold {fold}')
    train_subsampler = torch.utils.data.SubsetRandomSampler(train_index)
    val_subsampler = torch.utils.data.SubsetRandomSampler(val_index)

    # Both index sets come from kf.split(train_dataset), so both samplers must
    # draw from train_dataset.
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, sampler=train_subsampler)
    val_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, sampler=val_subsampler)

    out_features = 2  # binary
    model = nn.Sequential(
        # Each fold trains its own copy of the pretrained encoder; sharing one
        # instance would carry the weights learned in one fold into the next.
        deepcopy(encoder),
        Flatten(),
        nn.Linear(in_features=3 * 512 * 4, out_features=512 * 4, bias=True),
        nn.Dropout(p=0.5, inplace=False),
        nn.ReLU(),
        nn.BatchNorm1d(512 * 4),
        nn.Linear(512 * 4, out_features, bias=True),
    ).to(device)

    if os.path.exists(warm_start_path):
        checkpoint = torch.load(warm_start_path, map_location=device)
        if isinstance(checkpoint, nn.Module):
            # The file holds a whole pickled model.
            model = checkpoint.to(device)
        else:
            # The file holds a state dict, which is what this loop saves below.
            model.load_state_dict(checkpoint)

    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    # The class weights must be on the same device as the logits.
    criterion = nn.BCEWithLogitsLoss(weight=torch.tensor([0.1, 0.9], device=device))
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer, max_lr=learning_rate, epochs=n_epochs, steps_per_epoch=len(train_loader), pct_start=0.1, last_epoch=-1
    )

    val_accuracy = float("nan")  # stays NaN only if n_epochs is 0
    for epoch in range(1, n_epochs + 1):
        # Evaluation may leave the model in eval mode; re-enable Dropout and
        # BatchNorm updates at the start of every epoch.
        model.train()
        total = correct = 0
        pbar = tqdm(total=len(train_loader), desc=f"Epoch {epoch}, train")

        for vector, label in train_loader:
            pbar.update(1)
            # BatchNorm1d raises in training mode on a single-sample batch.
            if len(vector) < 2:
                continue

            optimizer.zero_grad()

            vector = vector.to(device)
            label = label.to(device)

            # BCEWithLogitsLoss expects float targets shaped like the logits.
            onehot_label = nn.functional.one_hot(label.to(torch.int64), out_features).float()

            logits = model(vector)
            loss = criterion(logits, onehot_label)
            loss.backward()

            optimizer.step()
            scheduler.step()

            predicted = logits.detach().argmax(dim=1)
            total += label.size(0)
            correct += (predicted == label).sum().item()

        train_accuracy = 100 * correct / total if total else float("nan")
        val_accuracy = evaluate_model(model, val_loader, device, optimizer)
        test_accuracy = evaluate_model(model, test_loader, device, optimizer)
        pbar.set_description(
            f"Epoch {epoch}, train: {train_accuracy:.2f}%, val: {val_accuracy:.2f}%, test: {test_accuracy:.2f}%"
        )
        pbar.close()
        # NOTE(review): `epoch` is already 1-based, so `epoch+1` numbers the files
        # 2..n_epochs+1. Kept unchanged because other code may rely on these names.
        torch.save(model.state_dict(), f"models/binary_model_fold{fold}_epoch{epoch+1}.pt")

    accuracies.append(val_accuracy)