# Evaluating the VLAAI network proposed in [Accurate decoding of the speech envelope using the VLAAI deep neural network](./#) on the DTU dataset.

In this example, a pre-trained VLAAI network is evaluated on the publicly available dataset of [Fuglsang et al.](https://zenodo.org/record/1199011). In this dataset, 18 subjects listened to one of two competing speech streams in 50-second trials with different levels of reverberation. Here, we only consider the single-speaker trials (approximately 10 trials per subject, i.e. about 500 seconds of data per subject).

The preprocessing used in this notebook is the same as described in the [paper](./#):
* For __EEG__:
  1. High-pass filtering using a 1st-order Butterworth filter with a cutoff frequency of 0.5 Hz
  2. Downsampling to 1024 Hz
  3. Eyeblink artefact removal using a multichannel Wiener filter
  4. Common average re-referencing
  5. Downsampling to 64 Hz

* For __Speech__:
  1. Envelope extraction using a gammatone filterbank
  2. Downsampling to 1024 Hz
  3. Downsampling to 64 Hz
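The two pipelines above can be sketched with SciPy as follows. This is a minimal, illustrative sketch under several assumptions, not the code used for the paper: the function names `preprocess_eeg_sketch` and `gammatone_envelope_sketch` are hypothetical, the multichannel Wiener filter step is omitted entirely, the high-pass filter is applied causally in a single pass, and the gammatone filterbank uses log-spaced centre frequencies whose rectified band outputs are simply averaged. The paper's filterbank (and any compression it applies) may differ.

```python
from math import gcd

import numpy as np
from scipy.signal import butter, gammatone, lfilter, resample_poly, sosfilt


def _resample(x, fs_in, fs_out):
    """Polyphase resampling along the time axis (axis 0) between integer rates."""
    g = gcd(int(fs_in), int(fs_out))
    return resample_poly(x, int(fs_out) // g, int(fs_in) // g, axis=0)


def preprocess_eeg_sketch(eeg, fs):
    """eeg: array of shape (time, channels) sampled at an integer rate fs (Hz)."""
    # 1. High-pass filter: 1st-order Butterworth, 0.5 Hz cutoff (causal, single pass)
    sos = butter(1, 0.5, btype="highpass", fs=fs, output="sos")
    eeg = sosfilt(sos, eeg, axis=0)
    # 2. Resample to 1024 Hz
    eeg = _resample(eeg, fs, 1024)
    # 3. Eyeblink removal with a multichannel Wiener filter is OMITTED in this sketch
    # 4. Common average re-referencing: subtract the mean over channels at each sample
    eeg = eeg - eeg.mean(axis=1, keepdims=True)
    # 5. Resample to 64 Hz
    return _resample(eeg, 1024, 64)


def gammatone_envelope_sketch(audio, fs, n_bands=16, f_lo=100.0, f_hi=4000.0):
    """audio: 1-D array at an integer rate fs (Hz); returns an envelope at 64 Hz."""
    # Log-spaced centre frequencies (an assumption, not the paper's exact filterbank)
    centre_freqs = np.geomspace(f_lo, f_hi, n_bands)
    bands = []
    for fc in centre_freqs:
        b, a = gammatone(fc, "iir", fs=fs)
        # Rectify each band output to obtain a crude sub-band envelope
        bands.append(np.abs(lfilter(b, a, audio)))
    # Combine the sub-band envelopes by averaging across bands
    envelope = np.mean(bands, axis=0)
    # Resample to 1024 Hz, then to 64 Hz
    return _resample(_resample(envelope, fs, 1024), 1024, 64)


# Usage on random placeholder data (not real recordings)
rng = np.random.default_rng(0)
eeg_64 = preprocess_eeg_sketch(rng.standard_normal((4096, 8)), fs=2048)  # 2 s, 8 channels
envelope_64 = gammatone_envelope_sketch(rng.standard_normal(32000), fs=16000)  # 2 s of audio
print(eeg_64.shape, envelope_64.shape)  # (128, 8) (128,)
```

Resampling in two stages mirrors the listed steps; the intermediate 1024 Hz rate is where the (omitted) artefact removal would operate.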

Preprocessed versions of the data are included in the [GitHub repository](./#); code to run the preprocessing yourself is coming soon.

# Getting started

Clone the repository and install its requirements:

In [None]:
# Clone the repository
!git clone https://github.com/berndie/vlaai
%cd vlaai

In [None]:
# Install the requirements
!pip3 install -r requirements.txt

# Evaluating the pre-trained VLAAI network on the (already preprocessed) DTU dataset

In [None]:
# General imports
import glob
import os.path
import numpy as np

from model import vlaai
from examples.utils import window_data
from scipy.stats import pearsonr

In [None]:
# Load the model
vlaai_model = vlaai()
vlaai_model.load_weights("pretrained_models/vlaai.h5")
vlaai_model.summary()

In [None]:
# Load the dataset
paths = glob.glob("evaluation_datasets/DTU/*.npz")
print("Found {} paths for evaluation".format(len(paths)))
# Subject IDs are the first two underscore-separated fields of each file name
subjects = sorted({"_".join(os.path.basename(x).split("_")[:2]) for x in paths})
print("Found {} subjects for evaluation".format(len(subjects)))
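Subject IDs are derived from the file names by keeping the first two underscore-separated fields. The snippet below illustrates that rule on hypothetical file names (the actual names in `evaluation_datasets/DTU` may follow a different pattern); `subject_id` is a helper introduced only for this illustration.

```python
import os.path


def subject_id(path):
    """Return the first two underscore-separated fields of a file's base name."""
    return "_".join(os.path.basename(path).split("_")[:2])


# Hypothetical file names, used only to illustrate the parsing rule
example_paths = [
    "evaluation_datasets/DTU/sub_01_trial_01.npz",
    "evaluation_datasets/DTU/sub_01_trial_02.npz",
    "evaluation_datasets/DTU/sub_02_trial_01.npz",
]
print(sorted({subject_id(p) for p in example_paths}))  # ['sub_01', 'sub_02']
```

Because the evaluation loop then globs `{subject}_*.npz`, the underscore after the ID keeps a subject such as `sub_1` from also matching the files of `sub_10`.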

In [None]:
# Set the number of trials that should be evaluated on for each subject
# If None, it will evaluate on all trials
# You can set this to a lower number to speed up the next code cell
nb_evaluation_trials = None

In [None]:
## Run the model evaluation
subject_scores = {}
boxplot_data = []

# Iterate over the subjects in the DTU dataset
for subject in subjects:
    print("Evaluating subject {}".format(subject))
    # Sort the trial files so that nb_evaluation_trials selects a reproducible subset
    for index, p in enumerate(
        sorted(glob.glob("evaluation_datasets/DTU/{}_*.npz".format(subject)))
    ):
        print("Gathering scores for {}...".format(p))
        # Load the data
        # Data is stored in .npz format with two keys: 'eeg' and 'envelope'
        # containing preprocessed EEG and corresponding speech stimulus
        # envelope.
        data = np.load(p)
        eeg = data["eeg"]
        envelope = data["envelope"]

        # Standardize EEG and envelope
        eeg = (eeg - eeg.mean(axis=0, keepdims=True)) / eeg.std(
            axis=0, keepdims=True
        )
        envelope = (
            envelope - envelope.mean(axis=0, keepdims=True)
        ) / envelope.std(axis=0, keepdims=True)

        # Cut the data into 5-second windows (320 samples at 64 Hz) with a
        # 1-second hop (64 samples), i.e. 80% overlap between windows
        windowed_eeg = window_data(eeg, 320, 64)
        windowed_envelope = window_data(envelope, 320, 64)

        # Evaluate the model on the overlapping windows
        if subject not in subject_scores:
            subject_scores[subject] = []
        predictions = vlaai_model.predict(windowed_eeg)
        for pred, true in zip(predictions, windowed_envelope):
            r = pearsonr(pred.reshape(-1), true.reshape(-1))
            subject_scores[subject] += [r[0]]
        if (
            nb_evaluation_trials is not None
            and index == nb_evaluation_trials - 1
        ):
            # Stop at this trial for the current subject
            break
    # Report the mean score for each subject
    mean_scores = np.mean(subject_scores[subject])
    boxplot_data += [mean_scores]
    print("Subject {}: {}".format(subject, mean_scores))
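The per-window scoring in the loop above can be reproduced without the model on synthetic data. In this sketch, `window_sketch` and `window_scores` are hypothetical helpers: `window_sketch` stands in for `examples.utils.window_data`, assuming it cuts a `(time, channels)` array into windows of 320 samples with a hop of 64 samples (how the real helper treats a final partial window is not shown here), and a noisy copy of a random signal plays the role of the model's prediction.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view
from scipy.stats import pearsonr

FS = 64          # sampling rate after preprocessing (Hz)
WINDOW = 5 * FS  # 320 samples = 5 s
HOP = 1 * FS     # 64 samples = 1 s, so overlap = (320 - 64) / 320 = 80%


def zscore(x):
    """Standardize each column (channel) to zero mean and unit variance."""
    return (x - x.mean(axis=0, keepdims=True)) / x.std(axis=0, keepdims=True)


def window_sketch(x, window=WINDOW, hop=HOP):
    """Cut a (time, channels) array into an array of shape (n_windows, window, channels)."""
    # sliding_window_view gives shape (time - window + 1, channels, window)
    windows = sliding_window_view(x, window, axis=0)
    # Keep every hop-th start position and move the window axis next to the batch axis
    return np.moveaxis(windows[::hop], -1, 1)


def window_scores(pred_windows, true_windows):
    """Pearson correlation between prediction and target for every window."""
    return [
        pearsonr(p.reshape(-1), t.reshape(-1))[0]
        for p, t in zip(pred_windows, true_windows)
    ]


rng = np.random.default_rng(0)
envelope = zscore(rng.standard_normal((60 * FS, 1)))  # one minute of synthetic "envelope"
prediction = envelope + 0.5 * rng.standard_normal(envelope.shape)  # stand-in for the model output
wins_true = window_sketch(envelope)
wins_pred = window_sketch(prediction)
print(wins_true.shape)  # (56, 320, 1)
scores = window_scores(wins_pred, wins_true)
print("Mean window score: {:.2f}".format(np.mean(scores)))
```

Because consecutive windows overlap by 80%, the per-window correlations are not independent; averaging them per subject, as the loop does, yields one summary score per subject.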

In [None]:
# Plot the results
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd

plt.figure(figsize=(16, 9))
df = pd.DataFrame.from_dict({"VLAAI network": boxplot_data})
sns.violinplot(data=df, orient="v")
plt.ylabel("Reconstruction score (Pearson correlation)")
plt.xlabel("Models")
plt.title("Evaluation of the pre-trained VLAAI model on the DTU dataset")
plt.grid(True)
plt.show()
print("Median score = {:.2f}".format(np.median(boxplot_data)))

# COMING SOON: Code to do the preprocessing from scratch