In [22]:
import sys
import os
import torch
from torch.utils.data import DataLoader
from braindecode.datautil import load_concat_dataset

In [50]:
dir_current = os.getcwd()
dir_repo = os.path.dirname(dir_current)
# sys.path.append(dir_current)
sys.path.append(dir_repo)

from models import masking, myTemporal_Imputer, ShallowFBCSPFeatureExtractor
from loss import CrossEntropyLabelSmooth

In [11]:
dir_preprocessed = os.path.join(dir_repo, 'data', 'Schirrmeister2017_preprocessed')

subject_id = 3
subject_ids_lst = [subject_id, ]
# If a preprocessed dataset exists
if os.path.exists(dir_preprocessed) and os.listdir(dir_preprocessed):
    print('Preprocessed dataset exists')
    windows_dataset = load_concat_dataset(
        path = dir_preprocessed,
        preload = True,
        ids_to_load = list(range(2 * (subject_ids_lst[0] - 1), 2 * subject_ids_lst[-1])),
        target_name = None,
    )
    sfreq = windows_dataset.datasets[0].raw.info['sfreq']
    print('Preprocessed dataset loaded')

Preprocessed dataset exists
Reading 0 ... 3347499  =      0.000 ...  6694.998 secs...
Reading 0 ... 609499  =      0.000 ...  1218.998 secs...
Preprocessed dataset loaded


In [94]:
splitted = windows_dataset.split('run')
pretrain_set = splitted['0train']  
valid_set = splitted['1test'] 

batch_size = 32
pretrain_loader = DataLoader(pretrain_set, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(valid_set, batch_size=batch_size)

data_dimension = 128
num_classes = 4
sample_shape = torch.Size([data_dimension, 2250])
feature_extractor = ShallowFBCSPFeatureExtractor(sample_shape, 'drop', num_classes)
feature_dimension = 40
temporal_verifier = myTemporal_Imputer(feature_dimension, feature_dimension)

# losses
mse_loss = torch.nn.MSELoss()
cross_entropy = CrossEntropyLabelSmooth(4, 'cpu', epsilon=0.1)



In [100]:
windows_dataset.split('run')

{'0train': <braindecode.datasets.base.BaseConcatDataset at 0x1a4a9a9d390>,
 '1test': <braindecode.datasets.base.BaseConcatDataset at 0x1a4a9a9cb90>}

In [107]:
for step, (src_x, src_y, _) in enumerate(pretrain_loader):
    print(step)
    src_features, src_prediction = feature_extractor(src_x)
    src_features = src_features.squeeze(-1)
    # print(src_features.shape)

    masked_x, mask = masking(src_x, num_splits=10, num_masked=2)
    # print(masked_x.shape)
    masked_features, masked_prediction = feature_extractor(masked_x)
    masked_features = masked_features.squeeze(-1)
    # print(masked_features.shape)
    tov_predictions = temporal_verifier(masked_features.detach())
    # print(tov_predictions.shape)
    tov_loss = mse_loss(tov_predictions, src_features)
    print(tov_loss)
    # # # print(src_y)
    break

0
torch.Size([32, 40, 144])
torch.Size([32, 144, 40])
torch.Size([32, 144, 40])
torch.Size([32, 40, 144])
tensor(11.6131, grad_fn=<MseLossBackward0>)


In [106]:
feature_extractor

ShallowFBCSPFeatureExtractor(
  (model): ShallowFBCSPNet(
    (ensuredims): Ensure4d()
    (dimshuffle): Rearrange('batch C T 1 -> batch 1 T C')
    (conv_time_spat): CombinedConv(
      (conv_time): Conv2d(1, 40, kernel_size=(25, 1), stride=(1, 1))
      (conv_spat): Conv2d(40, 40, kernel_size=(1, 128), stride=(1, 1), bias=False)
    )
    (bnorm): BatchNorm2d(40, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv_nonlin_exp): Expression(expression=square) 
    (pool): AvgPool2d(kernel_size=(75, 1), stride=(15, 1), padding=0)
    (pool_nonlin_exp): Expression(expression=safe_log) 
    (drop): Dropout(p=0.5, inplace=False)
    (final_layer): Sequential(
      (conv_classifier): Conv2d(40, 4, kernel_size=(144, 1), stride=(1, 1))
      (logsoftmax): LogSoftmax(dim=1)
      (squeeze): Expression(expression=squeeze_final_output) 
    )
  )
)