In [2]:
import os

import numpy as np
import pandas as pd
import torch
from ipywidgets import IntSlider, interact
from matplotlib import pyplot as plt

from data.Dataset import MaskedDataset, MaskedDataset1, MaskedDataset2
from data.DataModules import SSLDataModule
from models import SSL_EEG
from modules.decoders import MaskedDecoder
from modules.loss import MaskedMSELoss
from modules.loss import MaskedMSELoss

In this notebook we plot the reconstructions produced by the pretrained model on masked EEG windows.

# Load the model from the checkpoint

In [3]:
# Lightning checkpoints are kept in a `checkpoints/` folder under the current working directory.
checkpoints = os.path.join(
    os.getcwd(),
    'checkpoints',
)

# `prueba` ("test" in Spanish): the checkpoint file of the pretrained model inspected below.
prueba = os.path.join(
    checkpoints,
    'SSL-1s-v5.ckpt',
)

In [4]:
# Restore the pretrained model from the checkpoint path above. The decoder and
# loss classes are handed to the loader explicitly — presumably because they
# are not recoverable from the checkpoint's saved hyperparameters (TODO confirm).
# NOTE(review): the checkpoint name says "1s" while some of the data below uses
# longer windows — confirm the encoder's positional encoding covers that length.
model = SSL_EEG.load_from_checkpoint(
    prueba,
    decoder=MaskedDecoder,
    loss_fn=MaskedMSELoss,
)

# Switch to evaluation mode so the encoder's Dropout layers are inactive while plotting.
model.eval()

SSL_EEG(
  (encoder): TSTransformerEncoder(
    (project_inp): Linear(in_features=8, out_features=64, bias=True)
    (pos_enc): FixedPositionalEncoding(
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (transformer_encoder): TransformerEncoder(
      (layers): ModuleList(
        (0): TransformerEncoderLayer(
          (self_attn): MultiheadAttention(
            (out_proj): NonDynamicallyQuantizableLinear(in_features=64, out_features=64, bias=True)
          )
          (linear1): Linear(in_features=64, out_features=256, bias=True)
          (dropout): Dropout(p=0.1, inplace=False)
          (linear2): Linear(in_features=256, out_features=64, bias=True)
          (norm1): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
          (norm2): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
          (dropout1): Dropout(p=0.1, inplace=False)
          (dropout2): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (output_layer): Linear(in_features=64, out_features=8

# Load the data

In [6]:
# HDF5 file with the preprocessed, masked EEG windows
# (the name suggests TUH EEG with 4 s windows — TODO confirm).
eeg = os.path.join(
    os.getcwd(),
    'preprocess_data/tueh_mask_4s.h5',
)
masked_dataset = MaskedDataset2(
    hdf5_file=eeg,
    normalize='normalization',
)

# Data module over the dataset. The batch size is fixed to 64 here; it is not
# read from any wandb config. The plotting code indexes `masked_dataset` directly.
datamodule = SSLDataModule(
    dataset=masked_dataset,
    batch_size=64,
)

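Plot the model's reconstruction on samples from the UVA dataset loaded above: use the sliders to pick the channel (`i`) and the sample (`j`).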

In [13]:
%matplotlib widget
import matplotlib.pyplot as plt
from ipywidgets import interact, IntSlider


@interact(i=IntSlider(min=0, max=7, step=1, value=0), j=IntSlider(min=0, max=99, step=1, value=0), continuous_update=False)
def plote(i, j):
    sample_uva = masked_dataset[j]
    sample_uva_masked = sample_uva[0].unsqueeze(0)
    sample_uva_raw = sample_uva[1].unsqueeze(0)
    sample_uva_mask = sample_uva[2].unsqueeze(0)
    output_encoder, _= model(sample_uva_masked.to(model.device))

    # Define x, assuming 128 data points for each sample
    x = np.linspace(0, 511, 512)
    
    plt.figure(figsize=(10, 6))  # Create a new figure for each plot

    # Plot the raw data
    plt.plot(x, sample_uva_raw[0, :, i].cpu().detach().numpy(), label='Raw EEG')
    # Plot the mask
    plt.plot(x, sample_uva_mask[0, :, i].cpu().detach().numpy(), label='Mask')
    # Plot the prediction
    plt.plot(x, output_encoder[0, :, i].cpu().detach().numpy(), label='Predicted')

    plt.legend()
    plt.grid(True)
    plt.xlabel('Time Step')
    plt.ylabel('Amplitude')
    plt.title(f'EEG Data Visualization for Channel {i+1} and Sample {j}')
    plt.show()

interactive(children=(IntSlider(value=0, description='i', max=7), IntSlider(value=0, description='j', max=99),…

Test the model on a dataset different from the one used for training (DEAP) to see how well it generalizes.

In [None]:
# HDF5 file with the preprocessed, masked DEAP windows (relative to the working directory).
deap_data = os.path.join(
    os.getcwd(),
    'preprocess_data/deap_mask.hdf5',
)

In [None]:
# DEAP data goes through MaskedDataset1 (the earlier data uses MaskedDataset2).
# NOTE(review): in the saved output, construction prints an array shape and a
# warning pointing at the dataset's `(apply_signal - means) / stds` normalization;
# a zero std in some channel would yield inf/NaN — TODO check.
deap_masked_dataset = MaskedDataset1(
    hdf5_file=deap_data,
    normalize='normalization',
)

# Rebinds `datamodule`, replacing the one built for the earlier dataset.
# The batch size is fixed to 64 here; it is not read from any wandb config.
datamodule = SSLDataModule(
    dataset=deap_masked_dataset,
    batch_size=64,
)

(80640, 8, 128)


  apply_signal[:, :,j] = (apply_signal[:, :, j] - means) / (stds)


In [None]:
%matplotlib widget
import matplotlib.pyplot as plt
from ipywidgets import interact, IntSlider


@interact(i=IntSlider(min=0, max=7, step=1, value=0), j=IntSlider(min=0, max=99, step=1, value=0), continuous_update=False)
def plote(i, j):
    sample_uva = deap_masked_dataset[j]
    sample_uva_masked = sample_uva[0].unsqueeze(0)
    sample_uva_raw = sample_uva[1].unsqueeze(0)
    sample_uva_mask = sample_uva[2].unsqueeze(0)
    output_encoder, _= model(sample_uva_masked.to(model.device))

    # Define x, assuming 128 data points for each sample
    x = np.linspace(0, 127, 128)

    plt.figure(figsize=(10, 6))  # Create a new figure for each plot

    # Plot the raw data
    plt.plot(x, sample_uva_raw[0,:,i].cpu().detach().numpy(), label='Raw EEG')
    # Plot the mask
    #plt.plot(x, sample_uva_mask[0,:,i].cpu().detach().numpy(), label='Mask')
    # Plot the prediction
    plt.plot(x, output_encoder[0,:,i].cpu().detach().numpy(), label='Predicted')

    plt.legend()
    plt.grid(True)
    plt.xlabel('Time Step')
    plt.ylabel('Amplitude')
    plt.title(f'EEG Data Visualization for Channel {i+1} and Sample {j}')
    plt.show()

interactive(children=(IntSlider(value=0, description='i', max=7), IntSlider(value=0, description='j', max=99),…