In [71]:
import os

import numpy as np
import pandas as pd
import torch
from sklearn.metrics import f1_score
from torch.utils.data import DataLoader
from tqdm import tqdm

from transformer import Transformer

In [72]:
import sys
sys.path.append("../process_data/")
from load_datasets import load_datasets

In [73]:
# Directory handed to the project's load_datasets helper, which returns the
# train / validation / test datasets in that order.
load_dir = "../../data/datasets/discrete/"
train_ds, val_ds, test_ds = load_datasets(load_dir)

# Batches of 32 everywhere; only the training loader reshuffles each epoch so
# that validation and test batches are visited in a fixed order.
train_dl = DataLoader(dataset=train_ds, batch_size=32, shuffle=True)
val_dl = DataLoader(dataset=val_ds, batch_size=32, shuffle=False)
test_dl = DataLoader(dataset=test_ds, batch_size=32, shuffle=False)

In [74]:
# Hyperparameters forwarded unchanged to the project's Transformer constructor.
# NOTE(review): the meaning of q / v / h / N is defined in transformer.py and is
# not visible from this notebook -- confirm there before tuning.
model_kwargs = dict(
    d_model=256,
    d_hidden=512,
    d_feature=4,
    d_timestep=501,
    q=8,
    v=8,
    h=8,
    N=8,
    class_num=1,
)
model = Transformer(**model_kwargs).float()

# Count the elements of every parameter that will receive gradients and report
# the total in (floored) millions.
model_parameters = (p for p in model.parameters() if p.requires_grad)
params = sum(np.prod(p.size()) for p in model_parameters)
print(f"Number of params: {int(params//1e6)}M")

Number of params: 5M


In [75]:
def evaluate(model, val_dl, criterion, device):
    """Run one evaluation pass and return mean per-batch loss, accuracy and macro-F1.

    The model is switched to eval mode and the whole pass runs under
    ``torch.no_grad()`` so no autograd graph is built for validation batches.
    The model is left in eval mode on return; callers that continue training
    must call ``model.train()`` again.

    Args:
        model: Module invoked as ``model(x=..., stage="train")``. Its output is
            squeezed on dim 1 and passed through a sigmoid to obtain
            probabilities.
        val_dl: Iterable of batches. ``batch[0]`` is the input tensor (permuted
            with ``(0, 2, 1)`` before the forward pass) and ``batch[2]`` holds
            the labels; ``batch[1]`` is ignored here.
        criterion: Callable ``criterion(probabilities, labels)`` returning a
            scalar tensor (probabilities, not logits, are passed in).
        device: Device the inputs and labels are moved to.

    Returns:
        Tuple ``(loss, accuracy, f1)``. Each value is the unweighted mean of the
        per-batch metric, so a smaller final batch weighs as much as a full one.
        Predictions are thresholded at 0.5.

    Raises:
        ValueError: If ``val_dl`` yields no batches.
    """
    model.eval()

    loss_sum = 0.0
    acc_sum = 0.0
    f1_sum = 0.0
    tot_batches = 0

    # Evaluation never calls backward(), so skip gradient bookkeeping entirely.
    with torch.no_grad():
        for batch in val_dl:
            eeg = batch[0].permute(0, 2, 1).float().to(device)
            labels = batch[2].float().to(device)

            # todo: need to remove "train"
            preds = model(x=eeg, stage="train").squeeze(1)
            preds_probs = torch.sigmoid(preds)

            loss_sum += criterion(preds_probs, labels).item()

            predicted_labels = (preds_probs >= 0.5).float()
            batch_correct = (predicted_labels == labels).sum().item()
            acc_sum += batch_correct / predicted_labels.shape[0]

            f1_sum += f1_score(
                labels.cpu().numpy(),
                predicted_labels.cpu().numpy(),
                average='macro',
            )

            tot_batches += 1

    if tot_batches == 0:
        raise ValueError("evaluate() received a data loader that yielded no batches")

    return loss_sum / tot_batches, acc_sum / tot_batches, f1_sum / tot_batches

    


In [70]:
# Pick the compute device: the Apple-silicon GPU this notebook was run on when
# available, otherwise CUDA, otherwise CPU, so the cell runs on any machine.
if torch.backends.mps.is_available():
    device = "mps:0"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
model = model.to(device)

num_epochs = 15
lr = 0.00005
wd = 1e-4

# Create the output directory before training starts; otherwise a missing
# directory only surfaces at the first torch.save, after a full epoch.
checkpoint_dir = "../../checkpoints/gtn2/discrete"
os.makedirs(checkpoint_dir, exist_ok=True)

optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=wd)
steps = 0

# The loss is applied to sigmoid probabilities (not logits), both here and in
# evaluate(), which receives this same criterion.
criterion = torch.nn.BCELoss()

# Per-epoch history, written to progress.csv after every epoch.
training_losses = []
training_accuracy = []
training_f1 = []

val_losses = []
val_accuracy = []
val_f1 = []

for epoch in range(num_epochs):
    # evaluate() leaves the model in eval mode, so re-enable training mode at
    # the start of every epoch.
    model.train()

    epoch_training_loss = 0
    epoch_training_acc = 0
    epoch_training_f1 = 0
    tot_batches = 0

    pbar = tqdm(train_dl, desc=f"Epoch {epoch+1}/{num_epochs}")
    for batch in pbar:
        optimizer.zero_grad()

        # batch[0] is the input, batch[2] the labels; batch[1] is not used for
        # training.
        eeg = batch[0].permute(0, 2, 1).float().to(device)
        labels = batch[2].float().to(device)

        preds = model(x=eeg, stage="train").squeeze(1)
        preds_probs = torch.sigmoid(preds)

        batch_loss = criterion(preds_probs, labels)
        batch_loss.backward()
        optimizer.step()
        epoch_training_loss += batch_loss.item()

        # Threshold at 0.5 to obtain hard predictions for accuracy / F1.
        predicted_labels = (preds_probs >= 0.5).float()
        batch_correct = (predicted_labels == labels).sum().item()

        batch_accuracy = batch_correct / predicted_labels.shape[0]
        epoch_training_acc += batch_accuracy

        batch_f1 = f1_score(labels.cpu().numpy(), predicted_labels.cpu().numpy(), average='macro')
        epoch_training_f1 += batch_f1

        # The bar shows the metrics of the current batch only, so the value left
        # on screen at the end of an epoch is that of the last batch, not the
        # epoch mean (the epoch means are in progress.csv).
        pbar.set_description("Epoch {}/{} [t_loss={:.3f}] [t_acc={:.3f}] [t_f1={:.3f}]".format(epoch+1, num_epochs, batch_loss.item(), batch_accuracy, batch_f1))
        tot_batches += 1

    # Unweighted means over batches.
    epoch_training_loss /= tot_batches
    epoch_training_acc /= tot_batches
    epoch_training_f1 /= tot_batches

    training_losses.append(epoch_training_loss)
    training_accuracy.append(epoch_training_acc)
    training_f1.append(epoch_training_f1)

    epoch_val_loss, epoch_val_acc, epoch_val_f1 = evaluate(model, val_dl, criterion, device)

    val_losses.append(epoch_val_loss)
    val_accuracy.append(epoch_val_acc)
    val_f1.append(epoch_val_f1)

    # One checkpoint per epoch, plus the full metric history so far (the CSV is
    # rewritten each epoch so an interrupted run still leaves usable logs).
    torch.save(model.state_dict(), os.path.join(checkpoint_dir, f"model_epoch{epoch+1}.pt"))

    progress_df = pd.DataFrame({
        "epoch": np.arange(1, epoch+2),
        "t_loss": training_losses,
        "t_acc": training_accuracy,
        "t_f1": training_f1,
        "v_loss": val_losses,
        "v_acc": val_accuracy,
        "v_f1": val_f1,
    })
    progress_df.to_csv(os.path.join(checkpoint_dir, "progress.csv"))



Epoch 1/15 [t_loss=0.651] [t_acc=0.750] [t_f1=0.733]: 100%|██████████| 1410/1410 [08:32<00:00,  2.75it/s]
Epoch 2/15 [t_loss=0.691] [t_acc=0.500] [t_f1=0.333]: 100%|██████████| 1410/1410 [08:33<00:00,  2.74it/s]
Epoch 3/15 [t_loss=0.815] [t_acc=0.000] [t_f1=0.000]: 100%|██████████| 1410/1410 [08:37<00:00,  2.73it/s]
Epoch 4/15 [t_loss=0.891] [t_acc=0.000] [t_f1=0.000]: 100%|██████████| 1410/1410 [08:28<00:00,  2.77it/s]
Epoch 5/15 [t_loss=0.441] [t_acc=1.000] [t_f1=1.000]: 100%|██████████| 1410/1410 [08:25<00:00,  2.79it/s]
Epoch 6/15 [t_loss=0.553] [t_acc=0.500] [t_f1=0.333]: 100%|██████████| 1410/1410 [08:25<00:00,  2.79it/s]
Epoch 7/15 [t_loss=0.541] [t_acc=0.750] [t_f1=0.733]: 100%|██████████| 1410/1410 [08:23<00:00,  2.80it/s]
Epoch 8/15 [t_loss=0.978] [t_acc=0.500] [t_f1=0.500]: 100%|██████████| 1410/1410 [08:24<00:00,  2.79it/s]
Epoch 9/15 [t_loss=1.188] [t_acc=0.500] [t_f1=0.333]: 100%|██████████| 1410/1410 [08:25<00:00,  2.79it/s]
Epoch 10/15 [t_loss=1.295] [t_acc=0.500] [t_f1