In [10]:
import torch
import torch.nn.functional as F
from utils import seed_everything, load_config
from models import EEGEncoder, EEGAE
from dataset import get_eeg_loader
from glob import glob
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

In [2]:
# Read the trainer YAML once and keep a handle on each top-level section
# ("data", "EEG", "train") so later cells can pass them around directly.
config = load_config("./configs/EEG_trainer.yaml")
data_config, eeg_config, train_config = config["data"], config["EEG"], config["train"]

# Fix the seed before any loader or model is built.
# NOTE(review): presumably seeds every RNG in use (python/numpy/torch) — confirm in utils.seed_everything.
seed_everything(42)

In [4]:
# Collect every EEG CSV under ./data/split_eeg/<subdir>/ in sorted order so the
# file order (and therefore the first batch) does not depend on filesystem listing order.
eeg_list = sorted(glob("./data/split_eeg/*/*.csv"))

# glob() returns an empty list instead of raising when nothing matches (e.g. the
# notebook was started from a different working directory). Fail here with a clear
# message rather than later, inside the loader or on the first next(iter(...)).
if not eeg_list:
    raise FileNotFoundError("No EEG CSV files matched ./data/split_eeg/*/*.csv")

# NOTE(review): the positional arguments 16 and False are presumably the batch size
# and the shuffle flag — confirm against dataset.get_eeg_loader.
eeg_loader = get_eeg_loader(eeg_list, 16, False)

In [6]:
# Build the autoencoder from the "EEG" section of the config.
eeg_ae = EEGAE(eeg_config)

# Restore the trained weights. map_location="cpu" makes the load work on machines
# without CUDA even if the checkpoint was saved from GPU tensors; the model is
# used on CPU in this notebook (its output is converted with .numpy() directly).
# NOTE: torch.load unpickles the file, so only load checkpoints from a trusted source.
state_dict = torch.load("./ckpt_temp/EEG_AE/epoch_300.pt", map_location="cpu")

# Left as the cell's last expression so the key-matching report is displayed.
eeg_ae.load_state_dict(state_dict)

<All keys matched successfully>

In [7]:
# Switch to evaluation mode before running inference: the model contains
# BatchNorm2d and Dropout layers (see the repr below), which behave differently
# in training mode. The call is the cell's last expression, so the module repr is displayed.
eeg_ae.eval()

EEGAE(
  (encoder): EEGEncoder(
    (shallow_net): ShallowEncoderNet(
      (temporal_conv): Conv2d(1, 40, kernel_size=(1, 25), stride=(1, 1))
      (spatial_conv): Conv2d(40, 40, kernel_size=(60, 1), stride=(1, 1))
      (bn): BatchNorm2d(40, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (elu): ELU(alpha=1.0)
      (avgpool): AvgPool2d(kernel_size=(1, 25), stride=(1, 5), padding=0)
      (dr): Dropout(p=0.5, inplace=False)
      (enhance_conv): Conv2d(40, 40, kernel_size=(1, 1), stride=(1, 1))
    )
    (trns_enc_blk): TransformerEncoderBlock(
      (multi_head_attention): MultiheadAttention(
        (out_proj): NonDynamicallyQuantizableLinear(in_features=40, out_features=40, bias=True)
      )
      (feed_foward): Sequential(
        (0): Linear(in_features=40, out_features=160, bias=True)
        (1): GELU(approximate='none')
        (2): Dropout(p=0.5, inplace=False)
        (3): Linear(in_features=160, out_features=40, bias=True)
      )
      (norm): Layer

In [9]:
# Pull only the first batch from the loader and cast it to float32 in one step.
# NOTE(review): assumes the loader yields a single tensor per batch (not a tuple/dict),
# presumably in a wider dtype than the model weights — confirm in dataset.get_eeg_loader.
eeg = next(iter(eeg_loader)).float()

In [16]:
# Inference only: run the forward pass under no_grad so no autograd graph is
# recorded for the batch. The output then needs no .detach() before .numpy().
with torch.no_grad():
    recon_eeg = eeg_ae(eeg).numpy()

# Mean absolute error between the reconstruction and the input batch.
print(np.mean(np.abs(recon_eeg - eeg.numpy())))

0.4126677
