In [1]:
from typing import Iterable

import os
import numpy as np
import pandas as pd
import seaborn as sns
import time
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns
import datetime

import gdown

torch.manual_seed(42)

<torch._C.Generator at 0x7a47ed84aad0>

In [2]:
class ChronoNet(nn.Module):

    def __init__(self, inception_dropout_p=0.0, gru_dropout_p=0.0):
        super(ChronoNet, self).__init__()

        self.idp = inception_dropout_p
        self.gdp = gru_dropout_p

        self.cnn_layers = nn.Sequential(
            self.MultiscaleConv1D(64, 32),
            # nn.ReLU(),
            nn.Dropout(p=self.idp),
            self.MultiscaleConv1D(96, 32),
            # nn.ReLU(),
            nn.Dropout(p=self.idp),
            self.MultiscaleConv1D(96, 32),
            # nn.ReLU(),
        )

        # ModuleList should contain 4 Sequential dropout-then-GRU containers
        self.gru_layers = nn.ModuleList([
            nn.Sequential(
                nn.Dropout(p=self.gdp),
                nn.GRU(96, 32, batch_first=True)),
            nn.Sequential(
                nn.Dropout(p=self.gdp),
                nn.GRU(32, 32, batch_first=True)),
            nn.Sequential(
                nn.Dropout(p=self.gdp),
                nn.GRU(64, 32, batch_first=True)),
            nn.Sequential(
                nn.Dropout(p=self.gdp),
                nn.GRU(96, 32, batch_first=True)),
        ])

        self.linear = nn.Linear(32, 1)

        self.sig = nn.Sigmoid()

    def forward(self, batch):
        # Transpose back and forth because CNN modules expect time at last dimension instead of features.
        batch = torch.transpose(batch, 1, 2)
        cnn_out = self.cnn_layers(batch)
        cnn_out = torch.transpose(cnn_out, 1, 2)

        gru_out_0, _ = self.gru_layers[0](cnn_out)
        gru_out_1, _ = self.gru_layers[1](gru_out_0)
        gru_out_2, _ = self.gru_layers[2](torch.cat((gru_out_0, gru_out_1), dim=2))
        gru_out_3, _ = self.gru_layers[3](torch.cat((gru_out_0, gru_out_1, gru_out_2), dim=2))

        # maybe test concatenating with input
        out = self.linear(gru_out_3)
        score = self.sig(out)

        return score

    class MultiscaleConv1D(nn.Module):
        def __init__(self, in_channels: int, out_channels: int, kernel_sizes: Iterable[int] = (2, 4, 8), stride: int = 1):
            super(ChronoNet.MultiscaleConv1D, self).__init__()
            # iterate the list and create a ModuleList of single Conv1d blocks
            self.kernels = nn.ModuleList()
            for k in kernel_sizes:
                self.kernels.append(nn.Conv1d(in_channels, out_channels, k, stride=stride, padding=k//2 - 1))

        def forward(self, batch):
            # now you can build a single output from the list of convs
            out = [module(batch) for module in self.kernels]
            # concatenate at dim=1 since in convolutions features are at dim=1
            return torch.cat(out, dim=1)

In [3]:
class CustomNPZDataset(Dataset):
    def __init__(self, file_path):
        # Load the .npz file in 'mmap_mode' for memory-efficient access
        self.data = np.load(file_path, mmap_mode='r')

        # Assume the .npz file contains two arrays: 'inputs' and 'labels'
        self.inputs = self.data['epochs']
        self.labels = self.data['labels']

    def __len__(self):
        return self.inputs.shape[0]  # Return the number of samples (rows)

    def __getitem__(self, idx):
        # Load a single input and label
        input_data = self.inputs[idx]
        label = self.labels[idx]

        # Convert to PyTorch tensors and return
        return torch.tensor(input_data, dtype=torch.float32), torch.tensor(label, dtype=torch.float32)

In [4]:
def init_weights(m):
    if isinstance(m, nn.Linear):
        torch.nn.init.xavier_uniform_(m.weight)
    elif isinstance(m, nn.LSTM):
        for name, param in m.named_parameters():
            if 'weight_ih' in name:
                torch.nn.init.xavier_uniform_(param.data)
            elif 'weight_hh' in name:
                torch.nn.init.orthogonal_(param.data)
            elif 'bias' in name:
                param.data.fill_(0)

In [5]:
balance = True
scale = True  # if True, uses scaled data

if balance:
  file_id = '1BVrAZ5kg96Zqpwlfaea3WhGMDg8ovBU5'  # file containing balanced (and scaled) data
elif scale:
  file_id = '16CyXKsWCW4zkBM9CiSrleAoUi8gArZQm'  # file containing scaled data
else:
  file_id = '1ckbrLscgUmJHVR_yI4bdoSZyVEpgstD3'  # file containing unscaled data

local_path = 'content/drive/My Drive/Colab Notebooks/'  # Replace with your desired local path

gdown.download(
    f'https://drive.google.com/uc?id={file_id}',
    local_path,
    quiet=True
)

'content/drive/My Drive/Colab Notebooks/training_data_200Hz_scaled_BALANCED.npz'

In [6]:
if balance:
  filename = os.path.join(local_path, 'training_data_200Hz_scaled_BALANCED.npz')
elif scale:
  filename = os.path.join(local_path, 'training_data_200Hz_scaled.npz')
else:
  filename = os.path.join(local_path, 'training_data_200Hz.npz')
dataset = CustomNPZDataset(file_path=filename)

In [163]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

# HYPERPARAMETERS
downsample_factor = 5  # when 1 -> memory overload: can I save file in several steps?
washout_factor = 900 / 2250  # 'time in ms you want to washout' / 'EEG window length in ms'
learning_rate = 1e-4
num_epochs = 15

weight_decay = 1e-1
inception_dropout_p = 0.5
gru_dropout_p = 0.5

cuda


In [164]:
# Split lengths (e.g., 80% train, 20% test)
train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size

# Split the dataset
train_dataset, test_dataset = random_split(dataset, [train_size, test_size])

# Create DataLoaders
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)  # test num_workers = 1, 2, 4, ...
test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False)

# Create model
model = ChronoNet(inception_dropout_p=inception_dropout_p, gru_dropout_p=gru_dropout_p)
model = model.to(device)
model.apply(init_weights)

loss_function = nn.BCELoss()
optimizer = optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)

In [165]:
epochs_train_loss = np.zeros(num_epochs)
epochs_validation_loss = np.zeros(num_epochs)

for epoch in range(num_epochs):
    start = time.time()

    # training loop
    for i, (inputs, labels) in enumerate(train_loader):
        # inputs have shape (batch_size, sequence_length, num_features)
        model.zero_grad()

        inputs, labels = inputs.to(device), labels.to(device)
        if i == 0:
            outputs = model(inputs)
            washout = int(outputs.shape[1] * washout_factor)

        outputs = model(inputs)[:, washout:, :]

        # reshape labels to match output
        labels = labels.unsqueeze(-1).unsqueeze(-1).expand(-1, outputs.shape[1], -1)
        loss = loss_function(outputs, labels)
        epochs_train_loss[epoch] += loss.item()

        loss.backward()
        optimizer.step()

    epochs_train_loss[epoch] /= i+1

    # validation loop
    with torch.no_grad():
        model.eval()

        for i, (inputs, labels) in enumerate(test_loader):
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)[:, washout:, :]

            # reshape labels to match output
            labels = labels.unsqueeze(-1).unsqueeze(-1).expand(-1, outputs.shape[1], -1)
            loss = loss_function(outputs, labels)
            epochs_validation_loss[epoch] += loss.item()

        epochs_validation_loss[epoch] /= i + 1
        model.train()

    end = time.time()
    print('Epoch [{}/{}]:\n{:>17}: {:8.7f}\n{:>17}: {:8.7f}\n{:>17}: {:8.2f}\n'.format(epoch + 1,
                                                                                        num_epochs,
                                                                                        'Train Loss',
                                                                                        epochs_train_loss[epoch],
                                                                                        'Validation Loss',
                                                                                        epochs_validation_loss[
                                                                                            epoch],
                                                                                        'Elapsed Time',
                                                                                        end - start))


Epoch [1/15]:
       Train Loss: 0.7018759
  Validation Loss: 0.6957938
     Elapsed Time:    17.63

Epoch [2/15]:
       Train Loss: 0.6942053
  Validation Loss: 0.6949447
     Elapsed Time:    17.92

Epoch [3/15]:
       Train Loss: 0.6911980
  Validation Loss: 0.6972646
     Elapsed Time:    17.69

Epoch [4/15]:
       Train Loss: 0.6892131
  Validation Loss: 0.6964092
     Elapsed Time:    17.92

Epoch [5/15]:
       Train Loss: 0.6871048
  Validation Loss: 0.6955361
     Elapsed Time:    17.88

Epoch [6/15]:
       Train Loss: 0.6850672
  Validation Loss: 0.6960019
     Elapsed Time:    18.19



KeyboardInterrupt: 

In [None]:
sns.set_context("paper", font_scale=1.25)
plt.figure(figsize=(10, 6))

sns.set_palette(sns.color_palette("deep")[4::2])

sns.lineplot(y=epochs_train_loss, x=range(1, num_epochs + 1), label='Training Loss')
sns.lineplot(y=epochs_validation_loss, x=range(1, num_epochs + 1), label='Validation Loss')

sns.despine()
plt.xlabel("Epoch")
plt.ylabel("Loss")
# plt.title("Loss over epochs")

# TODO: save the data (both loss and accuracies) to be able to change the plot.
np.save('chrononet_train_loss_{}ep_{}wd_{}idp_{}gdp.npy'.format(num_epochs, weight_decay, inception_dropout_p, gru_dropout_p), epochs_train_loss)
np.save('chrononet_val_loss_{}ep_{}wd_{}idp_{}gdp.npy'.format(num_epochs, weight_decay, inception_dropout_p, gru_dropout_p), epochs_validation_loss)

In [None]:
def plot_accuracies(data: np.ndarray = None, title: str = "",
                    savefile: str = None, washout: int = None) -> None:
    """
    Plots the mean accuracy over time with confidence band over subjects.
    :param data: 2D numpy array, where each row is the decoding accuracy for one subject over all timesteps.
    :param title: title of the plot.
    :param savefile: file name to save the plot under. If None, no plot is saved.
    :param washout:
    :return: None
    """

    df = pd.DataFrame(data=data.T)
    df = df.reset_index().rename(columns={'index': 'Time'})
    df = df.melt(id_vars=['Time'], value_name='Mean_Accuracy', var_name='Subject')

    sns.set_context("paper", font_scale=1.25)

    # Create a seaborn lineplot, passing the matrix directly to seaborn
    plt.figure(figsize=(10, 6))  # Optional: Set the figure size

    # Create the lineplot, seaborn will automatically calculate confidence intervals
    sns.lineplot(data=df, x=(df['Time'] + washout) * 5 - 1000, y='Mean_Accuracy',
                 errorbar='ci', label='Accuracy')  # BUT confidence band gets much larger with 'sd'
    # Also, it is important to note that MVPA computes CIs over subjects, while the
    # neural nets compute CIs over trials.Higher n makes for narrower CIs, i.e. neural
    # nets will have much narrower CIs without this implying higher certainty.
    sns.despine()

    plt.axhline(y=0.5, color='orange', linestyle='dashdot', linewidth=1, label='Random Chance')
    plt.axvline(x=0, ymin=0, ymax=0.05, color='black', linewidth=1, label='Stimulus Onset')

    # Set plot labels and title
    plt.xlabel('Time (ms)')
    plt.ylabel('Accuracy')
    plt.legend()
    plt.title(title)

    if savefile is not None:
        plt.savefig('results/{}.png'.format(savefile))

    # Show the plot
    plt.show()

In [None]:
with torch.no_grad():
    model.eval()

    trainset_accuracies = torch.Tensor(0).to(device)
    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        outputs = model(inputs)[:, washout:, :]

        outputs[labels == 0] = 1 - outputs[labels == 0]  # invert scores if label is 0 (to represent accuracy)

        trainset_accuracies = torch.cat((trainset_accuracies, outputs), dim=0)

    testset_accuracies = torch.Tensor(0).to(device)
    for inputs, labels in test_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        outputs = model(inputs)[:, washout:, :]

        outputs[labels == 0] = 1 - outputs[labels == 0]  # invert scores if label is 0 (to represent accuracy)

        testset_accuracies = torch.cat((testset_accuracies, outputs), dim=0)

In [None]:
    with torch.no_grad():
        model.eval()

        trainset_scores = torch.empty(0).to(device)
        trainset_accuracies = torch.empty(0).to(device)
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)[:, washout:, :]

            predictions = outputs >= 0.5
            accuracy = predictions == labels.unsqueeze(-1).unsqueeze(-1).expand(-1, outputs.shape[1], -1)

            outputs[labels == 0] = 1 - outputs[labels == 0]  # invert scores if label is 0 (to represent accuracy)

            trainset_accuracies = torch.cat((trainset_accuracies, accuracy), dim=0)
            trainset_scores = torch.cat((trainset_scores, outputs), dim=0)

        testset_scores = torch.empty(0).to(device)
        testset_accuracies = torch.empty(0).to(device)
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)[:, washout:, :]

            predictions = outputs >= 0.5
            accuracy = predictions == labels.unsqueeze(-1).unsqueeze(-1).expand(-1, outputs.shape[1], -1)

            outputs[labels == 0] = 1 - outputs[labels == 0]  # invert scores if label is 0 (to represent accuracy)

            testset_accuracies = torch.cat((testset_accuracies, accuracy), dim=0)
            testset_scores = torch.cat((testset_scores, outputs), dim=0)

In [None]:
# plot_accuracies(data=trainset_scores.squeeze().cpu().numpy(), title='Training Scores', savefile=None, washout=washout)

# plot_accuracies(data=testset_scores.squeeze().cpu().numpy(), title='Validation Scores', savefile=None, washout=washout)

np.save('chrononet_train_sco_{}ep_{}wd_{}idp_{}gdp.npy'.format(num_epochs, weight_decay, inception_dropout_p, gru_dropout_p),
        trainset_scores.squeeze().cpu().numpy())
np.save('chrononet_val_sco_{}ep_{}wd_{}idp_{}gdp.npy'.format(num_epochs, weight_decay, inception_dropout_p, gru_dropout_p),
        testset_scores.squeeze().cpu().numpy())

In [None]:
# TODO: adapt axes of plot (with downsample factor / washout or in method directly)
# plot_accuracies(data=trainset_accuracies.squeeze().cpu().numpy(), title='Training Accuracy', savefile=None, washout=washout)

# plot_accuracies(data=testset_accuracies.squeeze().cpu().numpy(), title='Validation Accuracy', savefile=None, washout=washout)

# TODO: save accuracies as data
np.save('chrononet_train_acc_{}ep_{}wd_{}idp_{}gdp.npy'.format(num_epochs, weight_decay, inception_dropout_p, gru_dropout_p),
        trainset_accuracies.squeeze().cpu().numpy())
np.save('chrononet_val_acc_{}ep_{}wd_{}idp_{}gdp.npy'.format(num_epochs, weight_decay, inception_dropout_p, gru_dropout_p),
        testset_accuracies.squeeze().cpu().numpy())