In [1]:
# Move up from the notebook's starting directory to the repository root so that
# `bcikit` imports and the relative data path in `config["root"]` resolve.
# NOTE(review): not idempotent — running this cell again moves up one more level;
# restart the kernel before re-running the notebook.
%cd ../

/Users/jingles/github/bcikit


In [2]:
import numpy as np

In [3]:
# Sampling rate (Hz) assumed for the loaded EEG. It is defined once and shared by
# segmentation and band-pass filtering so the two settings cannot drift apart.
SAMPLE_RATE = 250

config = {
  "seed": 12,  # random seed
  "segment_config": {
    "window_len": 1,             # window length; 1 -> 250 samples per segment at 250 Hz (see output below)
    "shift_len": 250,            # NOTE(review): units unclear (seconds vs samples); only the first segment is kept downstream
    "sample_rate": SAMPLE_RATE,
  },
  "bandpass_config": {
      "sample_rate": SAMPLE_RATE,
      "lowcut": 7,               # lower cutoff (Hz)
      "highcut": 70,             # upper cutoff (Hz); must stay below SAMPLE_RATE / 2
      "order": 6
  },
  # inclusive range of subject IDs to load
  "subject_ids": {
    "low": 1,
    "high": 3
  },
  "root": "_data/hsssvep",
  "selected_channels": ['PZ', 'PO5', 'PO3', 'POz', 'PO4', 'PO6', 'O1', 'Oz', 'O2', 'PO7', 'PO8'],
  "num_classes": 40,
  "batchsize": 64,
}

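## Define the preprocessing function

The function keeps the 11 parieto-occipital channels listed in `config["selected_channels"]` and keeps only the first 1-second window of each trial. It then band-pass filters that window (7–70 Hz) and merges the session and trial axes.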

In [4]:
from bcikit.transforms.channels import pick_channels
from bcikit.transforms.segment_time import segment_data_time_domain
from bcikit.transforms.bandpass import butter_bandpass_filter


def preprocessing(data, targets, channel_names, sample_rate, selected_channels, segment_config, bandpass_config, verbose, **kwargs):
    """Pick channels, keep the first time segment, band-pass filter, and merge sessions.

    Passed to ``EEGDataloader`` as ``preprocessing_fn``. Any extra keyword
    arguments the loader forwards are accepted and ignored via ``**kwargs``.

    Args:
        data: EEG array of shape (subject, session, trial, channel, time).
        targets: label array; presumably (subject, session, trial), since it is
            reshaped to merge its axes 1 and 2.
        channel_names: names of the channels along ``data``'s channel axis.
        sample_rate: sampling rate of ``data``. Used only as the fallback when
            ``segment_config`` / ``bandpass_config`` omit ``"sample_rate"``.
        selected_channels: channel names to keep.
        segment_config: dict with ``window_len``, ``shift_len`` and optionally
            ``sample_rate``.
        bandpass_config: dict with ``lowcut``, ``highcut``, ``order`` and
            optionally ``sample_rate``.
        verbose: when true, print the array shape after each step.

    Returns:
        ``(data, targets)`` with shapes (subject, session*trial, channel, time)
        and (subject, session*trial).
    """
    def _log(*args):
        # Honour the caller's `verbose` flag instead of always printing.
        if verbose:
            print(*args)

    _log()
    _log("preprocessing data shape", data.shape)  # (subject,session,trial,channel,time)

    # filter channels
    data = pick_channels(
        data=data,
        channel_names=channel_names,
        selected_channels=selected_channels,
        verbose=False
    )
    _log("after pick_channels", data.shape)

    # Segment the signal. With add_segment_axis=True the result is 6-D with the
    # segment axis just before time: (subject, session, trial, channel, segment, time).
    data = segment_data_time_domain(
        data=data,
        window_len=segment_config['window_len'],
        shift_len=segment_config['shift_len'],
        sample_rate=segment_config.get('sample_rate', sample_rate),
        add_segment_axis=True,
    )
    data = data[:, :, :, :, 0, :]  # keep only the first segment and drop the segment axis
    _log("after segment_data_time_domain", data.shape)

    # Band-pass filter the kept segment.
    # NOTE(review): filtering a short segment can introduce edge transients;
    # filtering the full trial before segmenting is usually preferred. Confirm
    # against how `butter_bandpass_filter` is implemented (lfilter vs filtfilt).
    data = butter_bandpass_filter(
        signal=data,
        lowcut=bandpass_config["lowcut"],
        highcut=bandpass_config["highcut"],
        sample_rate=bandpass_config.get("sample_rate", sample_rate),
        order=bandpass_config["order"]
    )
    _log("after butter_bandpass_filter", data.shape)

    # For leave-one-subject-out the session axis carries no information, so
    # fold it into the trial axis: (subject, session*trial, channel, time).
    n_subjects, n_sessions, n_trials, n_channels, n_times = data.shape
    data = data.reshape((n_subjects, n_sessions * n_trials, n_channels, n_times))
    targets = targets.reshape((targets.shape[0], targets.shape[1] * targets.shape[2]))

    return data, targets

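## Load data

Load subjects 1–3 through `EEGDataloader` and apply `preprocessing`. Then build leave-one-subject-out data loaders, holding out subject 1 as the test set.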

In [5]:
from bcikit.datasets.ssvep import HSSSVEP
from bcikit.datasets import EEGDataloader
from bcikit.datasets.data_selection_methods import leave_one_subject_out


import random

# Seed the global RNGs right before loading so re-running this cell reproduces the
# same train/val split. `config["seed"]` was previously never used.
# NOTE(review): assumes the split inside `get_dataloaders` draws from numpy / `random`;
# if it (or DataLoader shuffling) uses torch, also call `torch.manual_seed(config["seed"])`.
random.seed(config["seed"])
np.random.seed(config["seed"])

# Inclusive range of subject IDs as plain Python ints. `np.arange` would yield numpy
# integers, which print as `np.int64(...)` under NumPy 2.
subject_ids = list(range(config['subject_ids']['low'], config['subject_ids']['high'] + 1))
print("Load subject IDs", subject_ids)

data = EEGDataloader(
    dataset=HSSSVEP,
    root=config["root"],
    subject_ids=subject_ids,
    preprocessing_fn=preprocessing,  # customize your own preprocessing
    data_selection_fn=leave_one_subject_out,  # customize your data selection function or use common ones from `data_selection_methods`
    verbose=True,
    selected_channels=config["selected_channels"],
    segment_config=config["segment_config"],
    bandpass_config=config["bandpass_config"],
)

print('Final data shape', data.data.shape)
print()

test_subject_id = 1  # subject held out as the test set for leave-one-subject-out
train_loader, val_loader, test_loader = data.get_dataloaders(
    test_subject_id=test_subject_id,
    batchsize=config["batchsize"],  # was hard-coded to 64, ignoring the config
)

# Report the shape of each split's data and targets.
for split_name, loader in (("train_loader", train_loader), ("val_loader", val_loader), ("test_loader", test_loader)):
    print(f"{split_name}:")
    print(loader.dataset.data.shape)
    print(loader.dataset.targets.shape)

Load subject IDs [1, 2, 3]
Load subject: 1
Load subject: 2
Load subject: 3

preprocessing data shape (3, 1, 240, 64, 1000)
after pick_channels (3, 1, 240, 11, 1000)
after segment_data_time_domain (3, 1, 240, 11, 250)
after butter_bandpass_filter (3, 1, 240, 11, 250)
Final data shape (3, 240, 11, 250)

train_loader:
(320, 11, 250)
(320,)
val_loader:
(160, 11, 250)
(160,)
test_loader:
(240, 11, 250)
(240,)
