In [1]:
from matplotlib import pyplot as plt
import numpy as np
import pandas as pd
import os

from data.Dataset import MaskedDataset
from data.DataModules import SSLDataModule
from models import SSL_EEG
from modules.decoders import MaskedDecoder
from modules.loss import MaskedMSELoss

In this notebook we plot the reconstruction produced by the pretrained model.

# Load the model from the checkpoint

In [8]:
# Folder with the saved checkpoints, resolved against the current working directory.
checkpoints = os.path.join(os.getcwd(), 'checkpoints')
# Checkpoint file that is loaded in the next cell.
prueba = os.path.join(
    checkpoints,
    'SSL-4s-val_loss=0.12-epoch=96.ckpt',
)
# Alternative checkpoint, currently disabled:
# prueba2 = os.path.join(checkpoints, 'SSL-4s-v7.ckpt')

In [9]:
# Restore the model from the checkpoint, supplying the decoder and the loss it is built with.
# NOTE(review): loss_fn receives the MaskedMSELoss class itself, not an instance —
# confirm this is what SSL_EEG expects.
model = SSL_EEG.load_from_checkpoint(
    prueba,
    decoder=MaskedDecoder(d_model=128, feat_dim=8),
    loss_fn=MaskedMSELoss,
)
# Last expression of the cell, so its return value is what the cell displays.
model.eval()
# Second model for comparison, currently disabled:
# model2 = SSL_EEG.load_from_checkpoint(prueba2, decoder = MaskedDecoder(d_model=64, feat_dim=8), loss_fn= MaskedMSELoss)
# model2.eval()

SSL_EEG(
  (encoder): TSTransformerEncoder(
    (pos_enc): FixedPositionalEncoding(
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (transformer_encoder): TransformerEncoder(
      (layers): ModuleList(
        (0-2): 3 x TransformerBatchNormEncoderLayer(
          (self_attn): MultiheadAttention(
            (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True)
          )
          (linear1): Linear(in_features=128, out_features=256, bias=True)
          (dropout): Dropout(p=0.1, inplace=False)
          (linear2): Linear(in_features=256, out_features=128, bias=True)
          (norm1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (norm2): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (dropout1): Dropout(p=0.1, inplace=False)
          (dropout2): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (dropout1): Dropout(p=0.1, inplace=False)
  )
 

# Load the data

In [10]:
# HDF5 file with the LEMON recordings, resolved against the current working directory.
eeg = os.path.join(
    os.getcwd(),
    'data/LEMON-mask-4s-8channel.h5',
)
masked_dataset = MaskedDataset(hdf5_file=eeg, normalize='normalization')

# Wrap the dataset in the data module; the batch size is fixed to 1 in this notebook.
datamodule = SSLDataModule(dataset=masked_dataset, batch_size=1)
# Called before the splits are read below — presumably this is what builds them;
# verify against SSLDataModule.
datamodule.prepare_data()

# Keep handles to both splits and to the validation dataloader.
val_dataset = datamodule.val_dataset
train_dataset = datamodule.train_dataset
val_data = datamodule.val_dataloader()

The length of the dataset is: 49182


Interactively plot the raw signal and the model's reconstruction for samples of the training split loaded above (the LEMON file).

In [11]:
%matplotlib widget
import matplotlib.pyplot as plt
from ipywidgets import interact, IntSlider


# continuous_update=False is set on the sliders themselves so the model is only
# run when a slider is released, not for every intermediate value while dragging.
@interact(
    i=IntSlider(min=0, max=7, step=1, value=0, continuous_update=False),
    j=IntSlider(min=0, max=599, step=1, value=0, continuous_update=False),
)
def plote(i, j):
    """Plot one channel of a training sample together with the model's prediction.

    Args:
        i: Channel index, taken on the last axis of the sample.
        j: Index of the sample in ``train_dataset``.
    """
    sample_uva = train_dataset[j]
    # Item 0 is fed to the model and item 1 is plotted as the reference signal
    # (presumably masked input / raw EEG — confirm against MaskedDataset).
    sample_uva_masked = sample_uva[0].unsqueeze(0)  # add a leading axis of size 1
    sample_uva_raw = sample_uva[1]
    output_encoder, _ = model(sample_uva_masked.to(model.device))

    raw_trace = sample_uva_raw[:, i].cpu().detach().numpy()
    predicted_trace = output_encoder[0, :, i].cpu().detach().numpy()

    # Integer time steps derived from the data instead of a hard-coded length.
    x = np.arange(raw_trace.shape[0])

    plt.figure(figsize=(10, 6))  # Create a new figure for each plot
    plt.plot(x, raw_trace, label='Raw EEG')
    plt.plot(x, predicted_trace, label='Predicted')

    plt.legend()
    plt.grid(True)
    plt.xlabel('Time Step')
    plt.ylabel('Amplitude')
    plt.title(f'EEG Data Visualization for Channel {i+1} and Sample {j}')
    plt.show()

interactive(children=(IntSlider(value=0, description='i', max=7), IntSlider(value=0, description='j', max=599)…

Test the model on a dataset different from the one used for training, to see how well the model generalizes.

In [None]:
# HDF5 file for the DEAP dataset, resolved against the current working directory.
deap_data = os.path.join(
    os.getcwd(),
    'preprocess_data/deap_mask.hdf5',
)

In [None]:
# Build the DEAP dataset with the same normalization option used for the training data.
deap_masked_dataset = MaskedDataset(
    hdf5_file=deap_data,
    normalize='normalization',
)
# Wrap it in a data module with a fixed batch size of 64.
# Note: this rebinds `datamodule`, replacing the one created for the training data.
datamodule = SSLDataModule(dataset=deap_masked_dataset, batch_size=64)

In [None]:
%matplotlib widget
import matplotlib.pyplot as plt
from ipywidgets import interact, IntSlider


# continuous_update=False is set on the sliders themselves so the model is only
# run when a slider is released, not for every intermediate value while dragging.
@interact(
    i=IntSlider(min=0, max=7, step=1, value=0, continuous_update=False),
    j=IntSlider(min=0, max=99, step=1, value=0, continuous_update=False),
)
def plote(i, j):
    """Plot one channel of a DEAP sample with its mask and the model's prediction.

    Args:
        i: Channel index, taken on the last axis of the sample.
        j: Index of the sample in ``deap_masked_dataset``.
    """
    sample_uva = deap_masked_dataset[j]
    # Items 0, 1 and 2 are used as model input, reference signal and mask
    # (presumably masked input / raw EEG / mask — confirm against MaskedDataset).
    # Each one gets a leading axis of size 1.
    sample_uva_masked = sample_uva[0].unsqueeze(0)
    sample_uva_raw = sample_uva[1].unsqueeze(0)
    sample_uva_mask = sample_uva[2].unsqueeze(0)
    output_encoder, _ = model(sample_uva_masked.to(model.device))

    # After unsqueeze(0) the channel is on the last axis, so index [0, :, i];
    # indexing [:, i] would pick position i on the second axis instead.
    raw_trace = sample_uva_raw[0, :, i].cpu().detach().numpy()
    mask_trace = sample_uva_mask[0, :, i].cpu().detach().numpy()
    predicted_trace = output_encoder[0, :, i].cpu().detach().numpy()

    # Integer time steps derived from the data instead of a hard-coded length.
    x = np.arange(raw_trace.shape[0])

    plt.figure(figsize=(10, 6))  # Create a new figure for each plot
    plt.plot(x, raw_trace, label='Raw EEG')
    plt.plot(x, mask_trace, label='Mask')
    plt.plot(x, predicted_trace, label='Predicted')

    plt.legend()
    plt.grid(True)
    plt.xlabel('Time Step')
    plt.ylabel('Amplitude')
    plt.title(f'EEG Data Visualization for Channel {i+1} and Sample {j}')
    plt.show()

interactive(children=(IntSlider(value=0, description='i', max=7), IntSlider(value=0, description='j', max=99),…

In [None]:
# HDF5 file for the UVA dataset, resolved against the current working directory.
uva_data = os.path.join(
    os.getcwd(),
    'data/UVA-mask.h5',
)

In [None]:
# Build the UVA dataset with the same normalization option used for the training data.
uva_masked = MaskedDataset(hdf5_file=uva_data, normalize='normalization')
# Wrap the UVA dataset (not the DEAP one from the previous section) in a data
# module with a fixed batch size of 64.
# Note: this rebinds `datamodule`, replacing the one created earlier.
datamodule = SSLDataModule(
    dataset=uva_masked,
    batch_size=64,
)

In [None]:
%matplotlib widget
import matplotlib.pyplot as plt
from ipywidgets import interact, IntSlider


# continuous_update=False is set on the sliders themselves so the model is only
# run when a slider is released, not for every intermediate value while dragging.
@interact(
    i=IntSlider(min=0, max=7, step=1, value=0, continuous_update=False),
    j=IntSlider(min=0, max=99, step=1, value=0, continuous_update=False),
)
def plote(i, j):
    """Plot one channel of a UVA sample with its mask and the model's prediction.

    Args:
        i: Channel index, taken on the last axis of the sample.
        j: Index of the sample in ``uva_masked``.
    """
    # Read from the UVA dataset built in this section (previously this cell
    # indexed the DEAP dataset, so the UVA data was never plotted).
    sample_uva = uva_masked[j]
    # Items 0, 1 and 2 are used as model input, reference signal and mask
    # (presumably masked input / raw EEG / mask — confirm against MaskedDataset).
    # Each one gets a leading axis of size 1.
    sample_uva_masked = sample_uva[0].unsqueeze(0)
    sample_uva_raw = sample_uva[1].unsqueeze(0)
    sample_uva_mask = sample_uva[2].unsqueeze(0)
    output_encoder, _ = model(sample_uva_masked.to(model.device))

    raw_trace = sample_uva_raw[0, :, i].cpu().detach().numpy()
    mask_trace = sample_uva_mask[0, :, i].cpu().detach().numpy()
    predicted_trace = output_encoder[0, :, i].cpu().detach().numpy()

    # Integer time steps derived from the data instead of a hard-coded length.
    x = np.arange(raw_trace.shape[0])

    plt.figure(figsize=(10, 6))  # Create a new figure for each plot
    plt.plot(x, raw_trace, label='Raw EEG')
    plt.plot(x, mask_trace, label='Mask')
    plt.plot(x, predicted_trace, label='Predicted')

    plt.legend()
    plt.grid(True)
    plt.xlabel('Time Step')
    plt.ylabel('Amplitude')
    plt.title(f'EEG Data Visualization for Channel {i+1} and Sample {j}')
    plt.show()

interactive(children=(IntSlider(value=0, description='i', max=7), IntSlider(value=0, description='j', max=99),…