In [1]:
from utils import *

# DN3

In [2]:
from abc import ABC
import torch, mne
from torch.utils.data import Dataset as TorchDataset
import numpy as np

from collections import OrderedDict
_EXTRA_CHANNELS = 5

from mne.io.constants import FIFF

In [3]:
# Numeric suffixes used to expand each electrode row of the listing below: the "left" suffixes are the odd
# numbers in descending order (7, 5, 3, 1) and the "right" suffixes the even numbers ascending (2, 4, 6, 8).
_LEFT_NUMBERS = list(range(7, 0, -2))
_RIGHT_NUMBERS = list(range(2, 10, 2))


# Ordered channel layout of the Deep1010 scheme. The position of a label in this list is the row index of that
# channel in mapped data, so the order of entries must not change.
DEEP_1010_CHS_LISTING = [
    # EEG
    "NZ",
    "FP1", "FPZ", "FP2",
    "AF7", "AF3", "AFZ", "AF4", "AF8",
    "F9", *(f"F{n}" for n in _LEFT_NUMBERS), "FZ", *(f"F{n}" for n in _RIGHT_NUMBERS), "F10",

    # The FC/C/CP/P rows skip the outermost suffix on each side (7 and 8); those sites carry the
    # explicitly spelled-out labels (FT7, T7, TP7, P7, ...) instead.
    "FT9", "FT7", *(f"FC{n}" for n in _LEFT_NUMBERS[1:]), "FCZ",
    *(f"FC{n}" for n in _RIGHT_NUMBERS[:-1]), "FT8", "FT10",

    "T9", "T7", "T3", *(f"C{n}" for n in _LEFT_NUMBERS[1:]), "CZ",
    *(f"C{n}" for n in _RIGHT_NUMBERS[:-1]), "T4", "T8", "T10",

    "TP9", "TP7", *(f"CP{n}" for n in _LEFT_NUMBERS[1:]), "CPZ",
    *(f"CP{n}" for n in _RIGHT_NUMBERS[:-1]), "TP8", "TP10",

    "P9", "P7", "T5", *(f"P{n}" for n in _LEFT_NUMBERS[1:]), "PZ",
    *(f"P{n}" for n in _RIGHT_NUMBERS[:-1]), "T6", "P8", "P10",

    "PO7", "PO3", "POZ", "PO4", "PO8",
    "O1", "OZ", "O2",
    "IZ",
    # EOG
    "VEOGL", "VEOGR", "HEOGL", "HEOGR",

    # Ear clip references
    "A1", "A2", "REF",
    # SCALING
    "SCALE",
    # Extra
    *(f"EX{n}" for n in range(1, _EXTRA_CHANNELS + 1))
]


# Not crazy about this approach..
from mne.utils._bunch import NamedInt
from mne.io.constants import FIFF
# Custom channel "kinds" for the two Deep1010-specific channel groups (scale and extra).
# Careful this doesn't overlap with future additions to MNE, might have to coordinate
DEEP_1010_SCALE_CH = NamedInt('DN3_DEEP1010_SCALE_CH', 3000)
DEEP_1010_EXTRA_CH = NamedInt('DN3_DEEP1010_EXTRA_CH', 3001)

# Row indices into DEEP_1010_CHS_LISTING for each group of channels.
# Everything listed before the first EOG label is treated as EEG.
EEG_INDS = list(range(0, DEEP_1010_CHS_LISTING.index('VEOGL')))
EOG_INDS = [DEEP_1010_CHS_LISTING.index(ch) for ch in ["VEOGL", "VEOGR", "HEOGL", "HEOGR"]]
REF_INDS = [DEEP_1010_CHS_LISTING.index(ch) for ch in ["A1", "A2", "REF"]]
# The extra channels are the final `_EXTRA_CHANNELS` entries of the listing.
EXTRA_INDS = list(range(len(DEEP_1010_CHS_LISTING) - _EXTRA_CHANNELS, len(DEEP_1010_CHS_LISTING)))
# Look the scale slot up by name. The previous arithmetic (len(listing) - len(EXTRA_INDS)) evaluated to the index
# of the first extra channel ("EX1") instead of "SCALE", which disagreed with DEEP_1010_CH_TYPES below and meant
# the first extra channel was overwritten whenever the scale value was stored.
SCALE_IND = DEEP_1010_CHS_LISTING.index('SCALE')
# Whatever is left after removing EOG, reference, extra and the single scale channel.
_NUM_EEG_CHS = len(DEEP_1010_CHS_LISTING) - len(EOG_INDS) - len(REF_INDS) - len(EXTRA_INDS) - 1

# Channel kind for every entry of DEEP_1010_CHS_LISTING, in listing order. Reference channels are typed as EEG.
DEEP_1010_CH_TYPES = ([FIFF.FIFFV_EEG_CH] * _NUM_EEG_CHS) + ([FIFF.FIFFV_EOG_CH] * len(EOG_INDS)) + \
                     ([FIFF.FIFFV_EEG_CH] * len(REF_INDS)) + [DEEP_1010_SCALE_CH] + \
                     ([DEEP_1010_EXTRA_CH] * _EXTRA_CHANNELS)

## Raw Torch Reading

In [4]:
class Preprocessor:
    """
    Interface for preprocessing actions. Instances are callables that receive a recording (a subclass of
    `_Recording`) and may work on it in-place.

    Anything that changes the data itself should be expressed as a subclass of :any:`BaseTransform` and handed
    back through :meth:`get_transform()`.
    """
    def __call__(self, recording, **kwargs):
        """
        Preprocess a single recording. In-place modification of the recording is permitted, though not strictly
        advised.

        Parameters
        ----------
        recording :
                    The recording to preprocess.
        kwargs : dict
                 Room for additional arguments that new :any:`_Recording` subclasses may need to supply.

        Raises
        ------
        NotImplementedError
                 Always, in this base class; sub-classes provide the behaviour.
        """
        raise NotImplementedError

    def get_transform(self):
        """
        Build and return the transform associated with this preprocessor. Intended to be used once the
        preprocessor has been applied to a dataset, i.e. through :meth:`DN3ataset.preprocess`

        Returns
        -------
        transform : BaseTransform

        Raises
        ------
        NotImplementedError
                 Always, in this base class; sub-classes provide the behaviour.
        """
        raise NotImplementedError

In [5]:
class DN3ataset(TorchDataset):

    def __init__(self):
        """
        Base class that specifies the interface for DN3 datasets.
        """
        # Transforms applied, in insertion order, to every fetched item (see `_execute_transforms`).
        self._transforms = list()
        # When True, every transform output is checked for NaN values.
        self._safe_mode = False
        self._mutli_proc_start = None
        self._mutli_proc_end = None

    def __getitem__(self, item):
        raise NotImplementedError

    def __len__(self):
        raise NotImplementedError

    @property
    def sfreq(self):
        """
        Returns
        -------
        sampling_frequency: float, list
                            The sampling frequencies employed by the dataset.
        """
        raise NotImplementedError

    @property
    def channels(self):
        """
        Returns
        -------
        channels: list
                  The channel sets used by the dataset.
        """
        raise NotImplementedError

    @property
    def sequence_length(self):
        """
        Returns
        -------
        sequence_length: int, list
                         The length of each instance in number of samples
        """
        raise NotImplementedError

    def clone(self):
        """
        A copy of this object to allow the repetition of recordings, thinkers, etc. that load data from
        the same memory/files but have their own tracking of ids.

        Returns
        -------
        cloned : DN3ataset
                 New copy of this object.
        """
        # Imported locally so that cloning does not rely on `copy` being injected by a star-import elsewhere.
        import copy
        return copy.deepcopy(self)

    def add_transform(self, transform):
        """
        Add a transformation that is applied to every fetched item in the dataset

        Parameters
        ----------
        transform : BaseTransform
                    For each item retrieved by __getitem__, transform is called to modify that item. Anything that
                    is not an `InstanceTransform` (including `None`) is silently ignored.
        """
        if isinstance(transform, InstanceTransform):
            self._transforms.append(transform)

    def _execute_transforms(self, *x):
        """
        Run every registered transform, in order, over the fetched tensors.

        Transforms flagged `only_trial_data` only receive the first tensor; whatever they return replaces it (a
        returned list/tuple is spliced in ahead of the remaining tensors). All other transforms receive, and must
        return, the full sequence of tensors.

        Raises
        ------
        DN3atasetNanFound
                 If safe mode is enabled and a transform produced a NaN in any tensor.
        """
        for transform in self._transforms:
            assert isinstance(transform, InstanceTransform)
            if transform.only_trial_data:
                new_x = transform(x[0])
                if isinstance(new_x, (list, tuple)):
                    x = (*new_x, *x[1:])
                else:
                    x = (new_x, *x[1:])
            else:
                x = transform(*x)

            if self._safe_mode:
                for i in range(len(x)):
                    if torch.any(torch.isnan(x[i])):
                        # Report the transform that produced the NaN (previously the dataset itself was named).
                        raise DN3atasetNanFound("NaN generated by transform {} for {}'th tensor".format(
                            transform, i))
        return x

    def clear_transforms(self):
        """
        Remove all added transforms from dataset.
        """
        self._transforms = list()

    def preprocess(self, preprocessor: Preprocessor, apply_transform=True):
        """
        Applies a preprocessor to the dataset

        Parameters
        ----------
        preprocessor : Preprocessor
                       A preprocessor to be applied
        apply_transform : bool
                          Whether to apply the transform to this dataset (and all members e.g thinkers or sessions)
                          after preprocessing them. Alternatively, the preprocessor is returned for manual application
                          of its transform through :meth:`Preprocessor.get_transform()`

        Returns
        ---------
        processed_data : ndarray
                         Data that has been modified by the preprocessor, should be in the shape of [*, C, T], with C
                         and T according with the `channels` and `sequence_length` properties respectively.
        """
        raise NotImplementedError

    def to_numpy(self, batch_size=64, batch_transforms: list = None, num_workers=4, **dataloader_kwargs):
        """
        Commits the dataset to numpy-formatted arrays. Useful for saving dataset to disk, or preparing for tools that
        expect numpy-formatted data rather than iteratable.

        Notes
        -----
        A pytorch :any:`DataLoader` is used to fetch the data, to conveniently leverage multiprocessing and batching.
        `DataLoader` and `tqdm` are expected to already be available in the enclosing namespace.

        Parameters
        ----------
        batch_size: int
                   The number of items to fetch per worker. This probably doesn't need much tuning.
        num_workers: int
                     The number of spawned processes to fetch and transform data.
        batch_transforms: list
                         Optional callables that are applied, in order, to every fetched batch before it is
                         converted to numpy.
        dataloader_kwargs: dict
                          Keyword arguments for the pytorch :any:`DataLoader` that underpins the fetched data.
                          Explicit entries here take precedence over `batch_size` and `num_workers`.

        Returns
        -------
        data: list
              A list of numpy arrays, one per element of a fetched batch, concatenated along the first axis.
              `None` if the loader produced no batches.
        """
        dataloader_kwargs.setdefault('batch_size', batch_size)
        dataloader_kwargs.setdefault('num_workers', num_workers)
        dataloader_kwargs.setdefault('shuffle', False)
        dataloader_kwargs.setdefault('drop_last', False)

        batch_transforms = list() if batch_transforms is None else batch_transforms

        # One list of per-batch arrays for each element of a batch; joined once at the end rather than
        # re-concatenating the entire accumulated array for every new batch.
        collected = None
        loader = DataLoader(self, **dataloader_kwargs)
        for batch in tqdm.tqdm(loader, desc="Loading Batches"):
            for xform in batch_transforms:
                assert callable(xform)
                batch = xform(batch)
            # cpu just to be certain, shouldn't affect things otherwise
            batch = [b.cpu().numpy() for b in batch]
            if collected is None:
                collected = [[b] for b in batch]
            else:
                for i in range(len(batch)):
                    collected[i].append(batch[i])

        if collected is None:
            return None
        return [np.concatenate(parts, axis=0) for parts in collected]
    
class _Recording(DN3ataset, ABC):
    """
    Abstract base class for any supported recording
    """
    def __init__(self, info, session_id, person_id, tlen, ch_ind_picks=None):
        """
        Parameters
        ----------
        info :
               Recording metadata; must provide `info['chs']` (entries with 'ch_name' and 'kind') and
               `info['sfreq']`.
        session_id :
               Identifier stored on the instance as `session_id`.
        person_id :
               Identifier stored on the instance as `person_id`.
        tlen : float
               Multiplied by the sampling frequency to give the length, in samples, of each instance.
        ch_ind_picks : list[int], None
               Indices into `info['chs']` of the channels to keep. All channels are kept when `None`.
        """
        super().__init__()
        self.info = info
        self.picks = ch_ind_picks if ch_ind_picks is not None else list(range(len(info['chs'])))
        self._recording_channels = [(ch['ch_name'], int(ch['kind'])) for idx, ch in enumerate(info['chs'])
                                    if idx in self.picks]
        self._recording_sfreq = info['sfreq']
        # Validate before the frequency is used; the check previously ran after the multiplication below and so
        # could never fire.
        assert self._recording_sfreq is not None
        self._recording_len = int(self._recording_sfreq * tlen)
        self.session_id = session_id
        self.person_id = person_id

    def get_all(self):
        """
        Fetch every instance of this recording and stack them.

        Returns
        -------
        stacked : list
                  One stacked tensor per element returned by `__getitem__`.
        """
        all_recordings = [self[i] for i in range(len(self))]
        return [torch.stack(t) for t in zip(*all_recordings)]

    @property
    def sfreq(self):
        """The recording's sampling frequency after each added transform has had the chance to revise it."""
        sfreq = self._recording_sfreq
        for xform in self._transforms:
            sfreq = xform.new_sfreq(sfreq)
        return sfreq

    @property
    def channels(self):
        """The (name, kind) channel array after each added transform has had the chance to revise it."""
        channels = np.array(self._recording_channels)
        for xform in self._transforms:
            channels = xform.new_channels(channels)
        return channels

    @property
    def sequence_length(self):
        """The instance length in samples after each added transform has had the chance to revise it."""
        sequence_length = self._recording_len
        for xform in self._transforms:
            sequence_length = xform.new_sequence_length(sequence_length)
        return sequence_length

class RawTorchRecording(_Recording):
    """
    Interface for bridging mne Raw instances as PyTorch compatible "Dataset".

    Parameters
    ----------
    raw : mne.io.Raw
          Raw data, data does not need to be preloaded.
    tlen : float
          Length, in seconds, of each retrieved portion of the recording.
    session_id : (int, str, optional)
          A unique (with respect to a thinker within an eventual dataset) identifier for the current recording
          session. If not specified, defaults to 0.
    person_id : (int, str, optional)
          A unique (with respect to an eventual dataset) identifier for the particular person being recorded.
    stride : int
          The number of samples to skip between each starting offset of loaded samples.
    """

    def __init__(self, raw: mne.io.Raw, tlen, session_id=0, person_id=0, stride=1, ch_ind_picks=None, decimate=1,
                 bad_spans=None, **kwargs):

        """
        Interface for bridging mne Raw instances as PyTorch compatible "Dataset".

        Parameters
        ----------
        raw : mne.io.Raw
              Raw data, data does not need to be preloaded.
        tlen : float
              Length of each retrieved portion of the recording.
        session_id : (int, str, optional)
              A unique (with respect to a thinker within an eventual dataset) identifier for the current recording
              session. If not specified, defaults to 0.
        person_id : (int, str, optional)
              A unique (with respect to an eventual dataset) identifier for the particular person being recorded.
        stride : int
              The number of (decimated) samples to skip between each starting offset of loaded samples.
        ch_ind_picks : list[int]
                       A list of channel indices that have been selected for.
        decimate : int
                   The number of samples to move before taking the next sample, in other words take every decimate'th
                   sample.
        bad_spans: List[tuple], None
                   These are tuples of (start_seconds, end_seconds) of times that should be avoided. Any sequences that
                   would overlap with these sections will be excluded.
        kwargs : dict
                 `max` and `min` are read for the scale check in `__getitem__`; every keyword is also stored as an
                 attribute of the same name on the instance.
        """
        super().__init__(raw.info, session_id, person_id, tlen, ch_ind_picks)
        self.filename = raw.filenames[0]
        self.decimate = int(decimate)
        self._recording_sfreq /= self.decimate
        self._recording_len = int(tlen * self._recording_sfreq)
        self.stride = stride
        # Implement my own (rather than mne's) in-memory buffer when there are savings
        self._stride_load = self.decimate > 1 and raw.preload
        self.max = kwargs.get('max', None)
        self.min = kwargs.get('min', 0)
        bad_spans = list() if bad_spans is None else bad_spans
        self.__dict__.update(kwargs)

        # Window start offsets, expressed in *decimated* samples.
        # NOTE(review): the exclusive upper bound leaves out the final window that would end exactly at the last
        # decimated sample; kept as-is so dataset lengths do not change.
        self._decimated_sequence_starts = list(
            range(0, raw.n_times // self.decimate - self._recording_len, self.stride)
        )
        for start, stop in bad_spans:
            # Convert the span from seconds into decimated samples
            start = int(self._recording_sfreq * start)
            stop = int(stop * self._recording_sfreq)
            # Drop every window that touches the span. Besides windows beginning or ending inside the span, this
            # also covers windows that completely contain it, which the previous two-sided check missed.
            self._decimated_sequence_starts = [
                span_start for span_start in self._decimated_sequence_starts
                if not (span_start < stop and span_start + self._recording_len >= start)
            ]

        # When the stride is greater than the sequence length, preload savings can be found by chopping the
        # sequence into subsequences of length: sequence length. Also, if decimating, can significantly reduce memory
        # requirements not otherwise addressed with the Raw object.
        if self._stride_load and len(self._decimated_sequence_starts) > 0:
            x = raw.get_data(self.picks)
            # pre-decimate this data for more preload savings (and for the stride factors to be valid)
            # `self.decimate` is the integer-coerced factor; slicing with the raw argument fails for e.g. 2.0
            x = x[:, ::self.decimate]
            self._x = np.empty([x.shape[0], self._recording_len, len(self._decimated_sequence_starts)], dtype=x.dtype)
            for i, start in enumerate(self._decimated_sequence_starts):
                self._x[..., i] = x[:, start:start + self._recording_len]
        else:
            self._raw_workaround(raw)

    def _raw_workaround(self, raw):
        """Keep a handle on the raw object so that windows can be read from it on demand."""
        self.raw = raw

    def __getitem__(self, index):
        """
        Load the `index`'th window of the recording and pass it through the added transforms.

        Returns
        -------
        x : tuple
            The result of `_execute_transforms` on the loaded float tensor.
        """
        if index < 0:
            index += len(self)

        if self._stride_load:
            x = self._x[:, :, index]
        else:
            # Stored starts are in decimated samples, while the raw object is addressed in original samples, so
            # scale the offset up before reading (previously the decimated offset was used directly, which read
            # the wrong region whenever decimate > 1).
            start = self._decimated_sequence_starts[index] * self.decimate
            x = self.raw.get_data(self.picks, start=start, stop=start + self._recording_len * self.decimate)
            if self.decimate > 1:
                x = x[:, ::self.decimate]

        # Only a diagnostic: the computed scale is not applied to the data here.
        scale = 1 if self.max is None else (x.max() - x.min()) / (self.max - self.min)
        if scale > 1 or np.isnan(scale):
            print('Warning: scale exeeding 1')

        x = torch.from_numpy(x).float()

        if torch.any(torch.isnan(x)):
            print("Nan found: raw {}, index {}".format(self.filename, index))
            print("Replacing with random values with same shape for now...")
            x = torch.rand_like(x)

        return self._execute_transforms(x)

    def __len__(self):
        return len(self._decimated_sequence_starts)

    def preprocess(self, preprocessor: Preprocessor, apply_transform=True):
        """
        Call the preprocessor with this recording, store what it returns as `self.raw` and, if requested, add the
        preprocessor's transform to this recording.
        """
        self.raw = preprocessor(recording=self)
        if apply_transform:
            self.add_transform(preprocessor.get_transform())

## _DumbNamespace

In [6]:
class _DumbNamespace:
    def __init__(self, d: dict):
        self._d = d.copy()
        for k in d:
            if isinstance(d[k], dict):
                d[k] = _DumbNamespace(d[k])
            if isinstance(d[k], list):
                d[k] = [_DumbNamespace(d[k][i]) if isinstance(d[k][i], dict) else d[k][i] for i in range(len(d[k]))]
        self.__dict__.update(d)

    def keys(self):
        return list(self.__dict__.keys())

    def __getitem__(self, item):
        return self.__dict__[item]

    def as_dict(self):
        return self._d


def _adopt_auxiliaries(obj, remaining):
    def namespaceify(v):
        if isinstance(v, dict):
            return _DumbNamespace(v)
        elif isinstance(v, list):
            return [namespaceify(v[i]) for i in range(len(v))]
        else:
            return v

    obj.__dict__.update({k: namespaceify(v) for k, v in remaining.items()})

## Mapping

In [7]:
def map_named_channels_deep_1010(channel_names: list, EOG=None, ear_ref=None, extra_channels=None):
    """
    Maps channel names to the Deep1010 format, will automatically map EOG and extra channels if they have been
    named according to standard convention. Otherwise provide as keyword arguments.

    Parameters
    ----------
    channel_names : list
                   List of channel names from dataset
    EOG : list, str, None
         Must be a single channel name, or left and right EOG channels, optionally vertical L/R then horizontal
         L/R for four channels. `None` (the default) or an empty list assigns no EOG channels.
    ear_ref : Optional, str, list
               One or two channels to be used as references. If two, should be left and right in that order.
    extra_channels : list, str, None
                     Up to `_EXTRA_CHANNELS` extra channels to include. Currently not standardized, but could
                     include ECG, respiration, EMG, etc. `None` entries leave the corresponding slot empty.

    Returns
    -------
    mapping : torch.Tensor
              Mapping matrix from previous channel sequence to Deep1010.

    Raises
    ------
    ValueError
              If a named EOG, reference or extra channel is not present in `channel_names`.
    """
    mapping = np.zeros((len(channel_names), len(DEEP_1010_CHS_LISTING)))

    # `None` has to be ruled out before any `len()` call, otherwise the default argument raises a TypeError.
    if EOG is None:
        EOG = []
    elif isinstance(EOG, str):
        EOG = [EOG] * 4
    elif len(EOG) == 0:
        EOG = []
    elif len(EOG) == 1:
        EOG = list(EOG) * 4
    elif len(EOG) == 2:
        EOG = list(EOG) * 2
    else:
        assert len(EOG) == 4
    for eog_map, eog_std in zip(EOG, EOG_INDS):
        try:
            mapping[channel_names.index(eog_map), eog_std] = 1.0
        except ValueError:
            raise ValueError("EOG channel {} not found in provided channels.".format(eog_map))

    if isinstance(ear_ref, str):
        ear_ref = [ear_ref] * 2
    elif ear_ref is None:
        ear_ref = []
    else:
        assert len(ear_ref) <= len(REF_INDS)
    for ref_map, ref_std in zip(ear_ref, REF_INDS):
        try:
            mapping[channel_names.index(ref_map), ref_std] = 1.0
        except ValueError:
            raise ValueError("Reference channel {} not found in provided channels.".format(ref_map))

    if isinstance(extra_channels, str):
        extra_channels = [extra_channels]
    elif extra_channels is None:
        extra_channels = []
    assert len(extra_channels) <= _EXTRA_CHANNELS
    for ch, place in zip(extra_channels, EXTRA_INDS):
        if ch is not None:
            try:
                mapping[channel_names.index(ch), place] = 1.0
            except ValueError:
                raise ValueError("Extra channel {} not found in provided channels.".format(ch))

    # Everything not explicitly assigned above is matched against the standard listing by name
    return _deep_1010(mapping, channel_names, EOG, ear_ref, extra_channels)


In [8]:
def _likely_eeg_channel(name):
    if name is not None:
        for ch in DEEP_1010_CHS_LISTING[:_NUM_EEG_CHS]:
            if ch in name.upper():
                return True
    return False


def _deep_1010(map, names, eog, ear_ref, extra):

    for i, ch in enumerate(names):
        if ch not in eog and ch not in ear_ref and ch not in extra:
            try:
                map[i, DEEP_1010_CHS_LISTING.index(str(ch).upper())] = 1.0
            except ValueError:
                print("Warning: channel {} not found in standard layout. Skipping...".format(ch))
                continue

    # Normalize for when multiple values are mapped to single location
    summed = map.sum(axis=0)[np.newaxis, :]
    mapping = torch.from_numpy(np.divide(map, summed, out=np.zeros_like(map), where=summed != 0)).float()
    mapping.requires_grad_(False)
    return mapping


def _valid_character_heuristics(name, informative_characters):
    possible = ''.join(c for c in name.upper() if c in informative_characters).replace(' ', '')
    if possible == "":
        print("Could not use channel {}. Could not resolve its true label, rename first.".format(name))
        return None
    return possible


def _check_num_and_get_types(type_dict: OrderedDict):
    type_lists = list()
    for ch_type, max_num in zip(('eog', 'ref'), (len(EOG_INDS), len(REF_INDS))):
        channels = [ch_name for ch_name, _type in type_dict.items() if _type == ch_type]

        for name in channels[max_num:]:
            print("Losing assumed {} channel {} because there are too many.".format(ch_type, name))
            type_dict[name] = None
        type_lists.append(channels[:max_num])
    return type_lists[0], type_lists[1]

def _heuristic_eog_resolution(eog_channel_name):
    return _valid_character_heuristics(eog_channel_name, "VHEOGLR")


def _heuristic_ref_resolution(ref_channel_name: str):
    ref_channel_name = ref_channel_name.replace('EAR', '')
    ref_channel_name = ref_channel_name.replace('REF', '')
    if ref_channel_name.find('A1') != -1:
        return 'A1'
    elif ref_channel_name.find('A2') != -1:
        return 'A2'

    if ref_channel_name.find('L') != -1:
        return 'A1'
    elif ref_channel_name.find('R') != -1:
        return 'A2'
    return "REF"


def _heuristic_eeg_resolution(eeg_ch_name: str):
    eeg_ch_name = eeg_ch_name.upper()
    # remove some common garbage
    eeg_ch_name = eeg_ch_name.replace('EEG', '')
    eeg_ch_name = eeg_ch_name.replace('REF', '')
    informative_characters = set([c for name in DEEP_1010_CHS_LISTING[:_NUM_EEG_CHS] for c in name])
    return _valid_character_heuristics(eeg_ch_name, informative_characters)


def _heuristic_resolution(old_type_dict: OrderedDict):
    resolver = {'eeg': _heuristic_eeg_resolution, 'eog': _heuristic_eog_resolution, 'ref': _heuristic_ref_resolution,
                'extra': lambda x: x, None: lambda x: x}

    new_type_dict = OrderedDict()

    for old_name, ch_type in old_type_dict.items():
        if ch_type is None:
            new_type_dict[old_name] = None
            continue

        new_name = resolver[ch_type](old_name)
        if new_name is None:
            new_type_dict[old_name] = None
        else:
            while new_name in new_type_dict.keys():
                print('Deep1010 Heuristics resulted in duplicate entries for {}, incrementing name, but will be lost '
                      'in mapping'.format(new_name))
                new_name = new_name + '-COPY'
            new_type_dict[new_name] = old_type_dict[old_name]

    assert len(new_type_dict) == len(old_type_dict)
    return new_type_dict


In [9]:
def map_dataset_channels_deep_1010(channels: np.ndarray, exclude_stim=True):
    r"""
    Maps channels as stored by a :any:`DN3ataset` to the Deep1010 format, will automatically map EOG and extra channels
    by type.

    Parameters
    ----------
    channels : np.ndarray
               Channels that remain a 1D sequence (they should not have been projected into 2 or 3D grids) of name and
               type. This means the array has 2 dimensions:

               .. math:: N_{channels} \times 2

               With the latter dimension containing name and type respectively, as is constructed by default in most
               cases.
    exclude_stim : bool
                   This option allows the stim channel to be added as an *extra* channel. The default (True) will not do
                   this, and it is very rare if ever where this would be needed.

    Warnings
    --------
    If for some reason the stim channel is labelled with a label from the `DEEP_1010_CHS_LISTING` it will be included
    in that location and result in labels bleeding into the observed data.

    Returns
    -------
    mapping : torch.Tensor
              Mapping matrix from previous channel sequence to Deep1010. It has one row for every entry of
              `channels`, including those that end up unused.

    Raises
    ------
    ValueError
              If `channels` is not a 2 dimensional array whose second dimension has size 2.
    """
    if len(channels.shape) != 2 or channels.shape[1] != 2:
        raise ValueError("Deep1010 Mapping: channels must be a 2 dimensional array with dim0 = num_channels, dim1 = 2."
                         " Got {}".format(channels.shape))
    channel_types = OrderedDict()

    # Use this for some semblance of order in the "extras"
    extra = [None for _ in range(_EXTRA_CHANNELS)]
    extra_idx = 0

    # Standard EOG labels, looked up once rather than for every channel
    eog_labels = [DEEP_1010_CHS_LISTING[idx] for idx in EOG_INDS]

    for name, ch_type in channels:
        # Annoyingly numpy converts them to strings...
        ch_type = int(ch_type)
        if ch_type == FIFF.FIFFV_EEG_CH and _likely_eeg_channel(name):
            channel_types[name] = 'eeg'
        elif ch_type == FIFF.FIFFV_EOG_CH or name in eog_labels:
            channel_types[name] = 'eog'
        elif ch_type == FIFF.FIFFV_STIM_CH:
            if exclude_stim:
                channel_types[name] = None
                continue
            # if stim, always set as last extra
            channel_types[name] = 'extra'
            extra[-1] = name
        elif any(tag in name.upper() for tag in ('REF', 'A1', 'A2', 'EAR')):
            channel_types[name] = 'ref'
        else:
            if extra_idx == _EXTRA_CHANNELS - 1 and not exclude_stim:
                print("Stim channel overwritten by {} in Deep1010 mapping.".format(name))
            elif extra_idx == _EXTRA_CHANNELS:
                print("No more room in extra channels for {}".format(name))
                # Still record the channel (untyped). Dropping it entirely removed its row from the mapping
                # matrix, so the matrix no longer lined up with the channels of the data it is applied to.
                channel_types[name] = None
                continue
            channel_types[name] = 'extra'
            extra[extra_idx] = name
            extra_idx += 1

    revised_channel_types = _heuristic_resolution(channel_types)
    eog, ref = _check_num_and_get_types(revised_channel_types)

    return map_named_channels_deep_1010(list(revised_channel_types.keys()), eog, ref, extra)

In [10]:
class InstanceTransform:

    def __init__(self, only_trial_data=True):
        """
        Instance transforms are, for the most part, operations performed on loaded tensors at the moment they are
        fetched, by way of the :meth:`__call__` method. Ideally these are written with pytorch operations for ease
        of execution graph integration.

        Parameters
        ----------
        only_trial_data : bool
                          Stored on the instance; when `True` (default) only the trial tensor is meant to be passed
                          to the transform, rather than every tensor being propagated.
        """
        self.only_trial_data = only_trial_data

    def __str__(self):
        # Transforms are identified by the name of their class
        return type(self).__name__

    def __call__(self, *x):
        """
        Modifies the tensor(s) of a single trial. Must be implemented by sub-classes.

        Parameters
        ----------
        x : torch.Tensor, tuple
            The trial tensor, not including a batch-dimension. If initialized with `only_trial_data=False`, then this
            is a tuple of all ids, labels, etc. being propagated.

        Returns
        -------
        x : torch.Tensor, tuple
            The modified trial tensor, or tensors if not `only_trial_data`
        """
        raise NotImplementedError

    def new_channels(self, old_channels):
        """
        Optional hook for transforms that alter which channels are present, or how they are represented. The
        default implementation reports no change.

        Parameters
        ----------
        old_channels : ndarray
                       An array whose last two dimensions are channel names and channel types.

        Returns
        -------
        new_channels : ndarray
                      An array with the channel names and types after this transformation. Supports the addition of
                      dimensions e.g. a list of channels into a rectangular grid, but the *final two dimensions* must
                      remain the channel names, and types respectively.
        """
        return old_channels

    def new_sfreq(self, old_sfreq):
        """
        Optional hook for transforms that alter the sampling frequency of the underlying time-series. The default
        implementation reports no change.

        Parameters
        ----------
        old_sfreq : float

        Returns
        -------
        new_sfreq : float
        """
        return old_sfreq

    def new_sequence_length(self, old_sequence_length):
        """
        Optional hook for transforms that alter the length of the acquired extracts, specified in number of
        samples. The default implementation reports no change.

        Parameters
        ----------
        old_sequence_length : int

        Returns
        -------
        new_sequence_length : int
        """
        return old_sequence_length
    

def min_max_normalize(x: torch.Tensor, low=-1, high=1):
    """
    Linearly rescale `x` so that its minimum maps to `low` and its maximum to `high`.

    Parameters
    ----------
    x : torch.Tensor
        Either 2 dimensional, in which case a single minimum and maximum are taken over the whole tensor, or
        3 dimensional, in which case each entry along the first dimension is normalized separately.
    low : float
          Value the minimum is mapped to.
    high : float
           Value the maximum is mapped to.

    Returns
    -------
    x : torch.Tensor
        The rescaled values in a new tensor; the input is not modified. A constant 2 dimensional input yields an
        all-zero tensor of the same shape. Constant entries of a 3 dimensional input are stabilized with a small
        epsilon and therefore map to `low`.

    Raises
    ------
    ValueError
        If `x` is neither 2 nor 3 dimensional.
    """
    if len(x.shape) == 2:
        xmin = x.min()
        xmax = x.max()
        if xmax - xmin == 0:
            # Nothing to scale; return a tensor (rather than a bare int) so the return type stays consistent
            return torch.zeros_like(x)
    elif len(x.shape) == 3:
        xmin = torch.min(torch.min(x, keepdim=True, dim=1)[0], keepdim=True, dim=-1)[0]
        xmax = torch.max(torch.max(x, keepdim=True, dim=1)[0], keepdim=True, dim=-1)[0]
        constant_trials = (xmax - xmin) == 0
        if torch.any(constant_trials):
            # If normalizing multiple trials, stabilize the normalization
            xmax[constant_trials] = xmax[constant_trials] + 1e-6
    else:
        raise ValueError("min_max_normalize expects a 2 or 3 dimensional tensor, got {} dimensions".format(
            len(x.shape)))

    # Scale to 0 -> 1
    x = (x - xmin) / (xmax - xmin)

    # Centre on zero, stretch to the requested width and only then shift to the requested midpoint. Adding the
    # midpoint before stretching is only right when the midpoint is 0 or the width is 1.
    return (high - low) * (x - 0.5) + (high + low) / 2

    

class MappingDeep1010(InstanceTransform):
    """
    Maps various channel sets into the Deep10-10 scheme, and normalizes data between [-1, 1] with an additional scaling
    parameter to describe the relative scale of a trial with respect to the entire dataset.

    See https://doi.org/10.1101/2020.12.17.423197  for description.
    """
    def __init__(self, dataset, max_scale=None, return_mask=False):
        """
        Creates a Deep10-10 mapping for the provided dataset.

        Parameters
        ----------
        dataset : Dataset
                  Must provide `channels`. Its `info.data_max` and `info.data_min` are used, when available, if no
                  `max_scale` is given.
        max_scale : float
                    If specified, the scale ind is filled with the relative scale of the trial with respect
                    to this, otherwise uses dataset.info.data_max - dataset.info.data_min.
        return_mask : bool
                      If `True` (`False` by default), an additional tensor is returned after this transform that
                      says which channels of the mapping are in fact in use.
        """
        super().__init__()
        self.mapping = map_dataset_channels_deep_1010(dataset.channels)

        # Read the scale information defensively: a dataset may have no `info`, or an `info` that carries no
        # `data_max`/`data_min` attributes; both cases should end in the warning below, not an AttributeError.
        info = getattr(dataset, 'info', None)
        data_max = getattr(info, 'data_max', None)
        data_min = getattr(info, 'data_min', None)

        if max_scale is not None:
            self.max_scale = max_scale
        elif data_max is None or data_min is None:
            print(f"Warning: Did not find data scale information for {dataset}")
            self.max_scale = None
        else:
            self.max_scale = data_max - data_min
        self.return_mask = return_mask

    @staticmethod
    def channel_listing():
        """The ordered list of Deep1010 channel labels."""
        return DEEP_1010_CHS_LISTING

    def __call__(self, x):
        """
        Project a trial into the Deep1010 layout and normalize it.

        Parameters
        ----------
        x : torch.Tensor
            A two dimensional trial whose first dimension matches the rows of the mapping (the original channels).

        Returns
        -------
        x : torch.Tensor, tuple
            The mapped trial, with one row per Deep1010 channel. If `return_mask`, a tuple of the trial and a
            boolean mask of the Deep1010 channels that are in use.
        """
        # Relative scale of this trial, expressed in [-1, 1]; 0 when no reference scale is known
        if self.max_scale is not None:
            scale = 2 * (torch.clamp_max((x.max() - x.min()) / self.max_scale, 1.0) - 0.5)
        else:
            scale = 0

        x = (x.transpose(1, 0) @ self.mapping).transpose(1, 0)

        # Each group of channels is normalized on its own
        for ch_type_inds in (EEG_INDS, EOG_INDS, REF_INDS, EXTRA_INDS):
            x[ch_type_inds, :] = min_max_normalize(x[ch_type_inds, :])

        # Normalization shifts the unused (all-zero) channels as well, so reset them
        used_channel_mask = self.mapping.sum(dim=0).bool()
        x[~used_channel_mask, :] = 0

        x[SCALE_IND, :] = scale

        if self.return_mask:
            return (x, used_channel_mask)
        else:
            return x

    def new_channels(self, old_channels: np.ndarray):
        """
        Channel names and types after mapping: one entry per Deep1010 channel, named by joining (with '-') the
        original channels mapped onto it, or `None` if nothing maps there.
        """
        channels = list()
        for column in range(self.mapping.shape[1]):
            active = self.mapping[:, column].nonzero().flatten().tolist()
            if len(active) > 0:
                channels.append("-".join([old_channels[i, 0] for i in active]))
            else:
                channels.append(None)
        return np.array(list(zip(channels, DEEP_1010_CH_TYPES)))


## 1020

In [11]:
class To1020(InstanceTransform):
    """
    Transform that reduces Deep1010-ordered data to the channels named in `EEG_20_div`, optionally followed by
    the 'A1'/'A2' reference channels and the scale channel.
    """

    # Channels kept from DEEP_1010_CHS_LISTING, in output order.
    EEG_20_div = [
               'FP1', 'FP2',
        'F7', 'F3', 'FZ', 'F4', 'F8',
        'T7', 'C3', 'CZ', 'C4', 'T8',
        'T5', 'P3', 'PZ', 'P4', 'T6',
                'O1', 'O2'
    ]

    def __init__(self, only_trial_data=True, include_scale_ch=True, include_ref_chs=False):
        """
        Transforms incoming Deep1010 data into exclusively the more limited 1020 channel set.

        Parameters
        ----------
        only_trial_data : bool
                          Forwarded unchanged to the `InstanceTransform` constructor.
        include_scale_ch : bool
                           If `True`, the channel at `SCALE_IND` is kept as the last output channel.
        include_ref_chs : bool
                          If `True`, the 'A1' and 'A2' channels are kept directly after the EEG channels.
        """
        super(To1020, self).__init__(only_trial_data=only_trial_data)
        self._inds_20_div = [DEEP_1010_CHS_LISTING.index(ch) for ch in self.EEG_20_div]
        if include_ref_chs:
            # extend (not append): the index list must stay flat, a nested list breaks the fancy indexing
            # performed in new_channels() and __call__().
            self._inds_20_div.extend(DEEP_1010_CHS_LISTING.index(ch) for ch in ['A1', 'A2'])
        if include_scale_ch:
            self._inds_20_div.append(SCALE_IND)

    def new_channels(self, old_channels):
        """Return the entries of `old_channels` at the retained channel indices, in output order."""
        return old_channels[self._inds_20_div]

    def __call__(self, *x):
        """
        Reduce every argument whose first dimension has the Deep1010 channel count to the retained channels.

        Arguments with any other leading dimension (or with no dimensions) are passed through untouched.
        The (possibly modified) arguments are returned as a list, in their original order.
        """
        x = list(x)
        for i in range(len(x)):
            # Assume every tensor that has deep1010 length should be modified
            if len(x[i].shape) > 0 and x[i].shape[0] == len(DEEP_1010_CHS_LISTING):
                x[i] = x[i][self._inds_20_div, ...]
        return x
    

# Test

In [12]:
def min_max_normalize(x: torch.Tensor, low=-1, high=1):
    """
    Linearly rescale `x` so that its minimum maps to `low` and its maximum maps to `high`.

    Parameters
    ----------
    x : torch.Tensor
        A 2D tensor, normalized with a single min/max over all of its values, or a 3D tensor, where every entry
        along the first dimension is normalized independently using the min/max over its last two dimensions.
    low : float
          Value the minimum is mapped to.
    high : float
           Value the maximum is mapped to.

    Returns
    -------
    torch.Tensor or int
        A new tensor of the same shape as `x`; the input itself is not modified. A constant 2D input returns the
        plain number `0` (kept for backward compatibility), while constant entries of a 3D input are mapped
        to `low`.

    Raises
    ------
    ValueError
        If `x` is neither 2D nor 3D.
    """
    if len(x.shape) == 2:
        xmin = x.min()
        span = x.max() - xmin
        if span == 0:
            # Existing contract: a constant 2D input collapses to the scalar 0 rather than a tensor.
            return 0
    elif len(x.shape) == 3:
        xmin = torch.min(torch.min(x, keepdim=True, dim=1)[0], keepdim=True, dim=-1)[0]
        xmax = torch.max(torch.max(x, keepdim=True, dim=1)[0], keepdim=True, dim=-1)[0]
        span = xmax - xmin
        constant_trials = span == 0
        if torch.any(constant_trials):
            # Stabilize constant trials by fixing the denominator itself. Adding a small epsilon to xmax is
            # lost to rounding for large magnitudes and then still divides 0 by 0. The numerator is 0 for these
            # trials, so any non-zero denominator gives the same result.
            span[constant_trials] = 1
    else:
        raise ValueError("min_max_normalize expects a 2D or 3D tensor, got {} dimension(s)".format(len(x.shape)))

    # Scale to 0 -> 1, then remove the 0.5 bias so the values are centred on 0
    x = (x - xmin) / span
    x = x - 0.5
    # Stretch to the requested width first, then shift to the requested centre. Shifting before stretching
    # would multiply the offset by (high - low) as well and miss the target range.
    return (high - low) * x + (high + low) / 2

In [13]:
# Load one raw EEG recording from the eegmmidb files (subject S001, run R03)
DATA_PATH_RAW = '../../data/eegmmidb (raw)/files/'

patient = 'S001'
run = 'R03'
FILE = DATA_PATH_RAW+f'{patient}/{patient}{run}.edf'

# get_raw / get_annotations / get_window_dict are not defined in this section
# (presumably provided by the `utils` star import - TODO confirm)
raw_full = get_raw(FILE)

annotations = get_annotations(FILE)

# Indexed later as annotation_dict['T0'][0], i.e. keyed by annotation label
annotation_dict = get_window_dict(raw_full, annotations)

In [14]:
# Take the first 'T0' window and run it through the channel pick/rename helper
raw = pick_and_rename_channels(annotation_dict['T0'][0])

In [15]:
# Wrap the raw window as a recording; the positional 4 is presumably the window length in seconds - TODO confirm
recording = RawTorchRecording(
    raw,
    4,
    stride=1,
    decimate=1,
    ch_ind_picks=None,
    bad_spans=None,
)

In [16]:
# Parameters for the stand-in dataset object used to build the Deep1010 mapping
# NOTE(review): sfreq / new_sfreq are not referenced by any other visible cell - confirm they are still needed
sfreq = 250
new_sfreq = 256
# Data range read by MappingDeep1010 through `dataset.info.data_max` / `dataset.info.data_min`
data_max = 3276.7
data_min = -1583.9258304722666

# Minimal stand-in exposing the `channels` and `info` attributes that MappingDeep1010 reads from its dataset
_dum = _DumbNamespace(dict(channels=recording.channels, info=dict(data_max=data_max, data_min=data_min)))

In [17]:
# First transform: map the recording's channels onto Deep1010 (also returning the used-channel mask)
xform = MappingDeep1010(_dum, return_mask = True)
recording.add_transform(xform)

# Second transform: reduce the Deep1010 layout to the 10-20 channel subset (scale channel kept by default)
recording.add_transform(To1020())

In [18]:
# First element of the first item of the recording, with a leading batch axis added
output1020 = recording[0][0]
final_example = output1020[None, ...]

In [19]:
# Shape check: the recorded output is (1, 20, 1024)
final_example.shape

torch.Size([1, 20, 1024])

In [20]:
# Display the transformed example; the last channel row shows as constant -1 in the recorded output
final_example

tensor([[[-0.5255, -0.4970, -0.4779,  ..., -0.6889, -0.7120, -0.7183],
         [-0.5255, -0.5352, -0.5274,  ..., -0.7079, -0.7190, -0.7219],
         [-0.5255, -0.5220, -0.5161,  ..., -0.5466, -0.5705, -0.5705],
         ...,
         [-0.5255, -0.5404, -0.5507,  ..., -0.5279, -0.5082, -0.4976],
         [-0.5255, -0.5751, -0.6067,  ..., -0.5267, -0.5119, -0.4942],
         [-1.0000, -1.0000, -1.0000,  ..., -1.0000, -1.0000, -1.0000]]])

In [21]:
NUMBER_CHANNELS = 20
NUMBER_SAMPLES = 1024

# Reference tensor built by hand: the raw channels plus one all-zero channel, normalized in a single pass
x = torch.zeros((1, NUMBER_CHANNELS, NUMBER_SAMPLES))
x[:, :19, :] = torch.from_numpy(
    raw.copy().get_data()[:, :NUMBER_SAMPLES].reshape(1, NUMBER_CHANNELS - 1, NUMBER_SAMPLES)
)
x = min_max_normalize(x)

# The last channel is overwritten with a constant -1 after normalization
x[:, 19, :] = -1

x

tensor([[[-0.5255, -0.4970, -0.4779,  ..., -0.6889, -0.7120, -0.7183],
         [-0.5255, -0.5352, -0.5274,  ..., -0.7079, -0.7190, -0.7219],
         [-0.5255, -0.5220, -0.5161,  ..., -0.5466, -0.5705, -0.5705],
         ...,
         [-0.5255, -0.5404, -0.5507,  ..., -0.5279, -0.5082, -0.4976],
         [-0.5255, -0.5751, -0.6067,  ..., -0.5267, -0.5119, -0.4942],
         [-1.0000, -1.0000, -1.0000,  ..., -1.0000, -1.0000, -1.0000]]])

In [33]:
# Element-wise comparison of the transform output with the hand-built reference, last channel excluded
(final_example[:, :-1] == x[:, :-1]).all()

tensor(True)