In [None]:
!pip install wandb torchsummaryX mne transformers -q

In [None]:
### Torch
import  torch
import  torch.nn as nn
import  torch.nn.functional as F
from    torch.optim import lr_scheduler
from    torchsummaryX import summary
from    torch.utils.data import Dataset, DataLoader
import  torchaudio
import  torchaudio.transforms as tat

### General
import  random
import  math
import  warnings
import  typing as tp
from    typing import Union
from    functools import partial
from    pathlib import Path
import  numpy as np
import  pandas as pd
import  scipy
from    scipy.io import loadmat
import  gc
from    tqdm.auto import tqdm
import  os
import  datetime
import  time
import  wandb
import  matplotlib.pyplot as plt
import  seaborn as sns

# wav2vec2 and EEG processing
from    transformers import (AutoProcessor, AutoModelForPreTraining,
                             CLIPProcessor, CLIPModel)
import  mne

# data [QA]

## studies/brennan.py

In [1]:
def _read_meta(fname):
    """Load the word/sound events of one Brennan & Hale (2019) subject.

    Reads the ``proc`` structure stored in ``fname`` (per-trial metadata), joins
    it row by row with the word annotations shipped with the dataset
    (``AliceChapterOne-EEG.csv``), appends one ``sound`` event per audio
    segment, and returns the validated events.

    Args:
        fname: path to the subject's metadata ``.mat`` file.

    Returns:
        A DataFrame of events with (at least) the columns ``start``,
        ``duration``, ``kind``, ``word``, ``word_id``, ``sequence_id``,
        ``condition``, ``filepath``, ``language`` and ``modality``, further
        processed by ``extract_sequence_info`` / ``create_blocks`` /
        ``validate``.

    Raises:
        AssertionError: if the metadata does not have 61 channels or exactly
            2129 trials.
    """
    proc = loadmat(
        fname,
        squeeze_me=True,
        chars_as_strings=True,
        struct_as_record=True,
        simplify_cells=True,
    )["proc"]

    # ref = proc["implicitref"]
    # ref_channels = proc["refchannels"]

    # subject_id = proc["subject"]
    meta = proc["trl"]

    # TODO artefacts, ica, rejected components etc
    assert len(meta) == proc["tot_trials"]
    assert proc["tot_chans"] == 61
    # NOTE(review): bad channels are collected but not used yet (see TODO above).
    bads = list(proc["impedence"]["bads"])
    bads += list(proc["rejections"]["badchans"])

    columns = list(proc["varnames"])
    if len(columns) != meta.shape[1]:
        # varnames does not name every column: assume the first three are the
        # standard trial start / stop / offset columns.
        columns = ["start_sample", "stop_sample", "offset"] + columns
        assert len(columns) == meta.shape[1]
    # "_" prefix keeps these from clashing with the story CSV columns (e.g. "offset").
    meta = pd.DataFrame(meta, columns=["_" + i for i in columns])
    assert len(meta) == 2129  # FIXME retrieve subjects who have less trials?

    # Add Brennan's annotations (index-aligned join: one CSV row per trial)
    paths = get_paths()
    story = pd.read_csv(paths.download / "AliceChapterOne-EEG.csv")
    events = meta.join(story)

    events["kind"] = "word"
    events["condition"] = "sentence"
    events["duration"] = events.offset - events.onset
    events = events.rename(
        columns=dict(Word="word", Position="word_id", Sentence="sequence_id"))
    events["start"] = events["_start_sample"] / SFREQ

    # add audio events: one "sound" event per audio segment
    wav_file = (
        paths.download / "audio" / "DownTheRabbitHoleFinal_SoundFile%i.wav"
    )
    sounds = []
    for segment, d in events.groupby("Segment"):
        # Some wav files start BEFORE the onset of eeg recording...
        start = d.iloc[0].start - d.iloc[0].onset
        sounds.append(
            dict(kind="sound", start=start, filepath=str(wav_file) % segment))
    events = pd.concat([events, pd.DataFrame(sounds)], ignore_index=True)
    # drop=True: the old index is not needed (it was discarded by `keep` anyway)
    events = events.sort_values("start").reset_index(drop=True)

    # clean up
    keep = [
        "start",
        "duration",
        "kind",
        "word",
        "word_id",
        "sequence_id",
        "condition",
        "filepath",
    ]
    # .copy(): make `events` an independent frame so that adding columns below
    # does not raise pandas' SettingWithCopyWarning.
    events = events[keep].copy()
    events[['language', 'modality']] = 'english', 'audio'
    events = extract_sequence_info(events)
    events = events.event.create_blocks(groupby='sentence')
    events = events.event.validate()

    return events


def _read_eeg(fname):
    """Load one subject's continuous EEG from a ``.mat`` file as an MNE Raw object.

    The file must contain a single continuous trial sampled at 500 Hz with:
    - 60 EEG channels labelled "1".."61" (no "29");
    - a vertical EOG channel ``VEOG`` at index 60;
    - an audio channel (``AUD`` or ``Aux5``) at index 61.
    If the audio channel is missing, a zero-filled ``AUD`` channel is appended so
    every recording has the same 62 channels.

    Args:
        fname: path (str or Path) to the ``.mat`` file. The file stem is the
            subject id, e.g. ``S01``.

    Returns:
        ``mne.io.RawArray`` with 62 channels (60 ``eeg``, 1 ``eog``, 1 ``misc``).
        Data are converted from µV to V, the ``easycap-M10`` montage is set, and
        ``info["subject_info"]`` is filled from the file name.

    Raises:
        AssertionError: if any structural check on the file fails.
    """
    fname = Path(fname)
    assert fname.exists()
    assert str(fname).endswith(".mat")
    mat = loadmat(
        fname,
        squeeze_me=True,
        chars_as_strings=True,
        struct_as_record=True,
        simplify_cells=True,
    )
    mat = mat["raw"]

    # sampling frequency
    sfreq = mat["hdr"]["Fs"]
    assert sfreq == 500.0
    assert mat["fsample"] == sfreq

    # channels
    n_chans = mat["hdr"]["nChans"]
    n_samples = mat["hdr"]["nSamples"]
    ch_names = list(mat["hdr"]["label"])
    assert len(ch_names) == n_chans

    # vertical EOG
    assert ch_names[60] == "VEOG"

    # audio channel: absent in some files, a zero channel is appended below
    add_audio_chan = len(ch_names) == 61
    if add_audio_chan:
        ch_names += ["AUD"]
    assert ch_names[61] in ("AUD", "Aux5")

    # EEG channel labels are "1".."28", "30".."61"
    for i, ch in enumerate(ch_names[:-2]):
        assert ch == str(i + 1 + (i >= 28))

    # channel type
    assert set(mat["hdr"]["chantype"]) == {"eeg"}
    ch_types = ["eeg"] * 60 + ["eog", "misc"]
    assert set(mat["hdr"]["chanunit"]) == {"uV"}

    # time: consecutive timestamps must be 1 / sfreq apart (within tolerance)
    tol = 1e-5
    assert np.all(np.abs(np.diff(mat["time"]) - 1 / sfreq) < tol)
    start, stop = mat["sampleinfo"]
    assert start == 1
    assert stop == n_samples
    assert mat["hdr"]["nSamplesPre"] == 0
    assert mat["hdr"]["nTrials"] == 1

    # data
    data = mat["trial"]
    assert data.shape[0] == n_chans
    assert data.shape[1] == n_samples
    if add_audio_chan:
        data = np.vstack((data, np.zeros_like(data[0])))

    # create mne objects. Build the Info exactly once: re-creating it after
    # setting subject_info would silently drop the subject information.
    info = mne.create_info(ch_names, sfreq, ch_types, verbose=False)
    subject_id = fname.name.split(".mat")[0]
    info["subject_info"] = dict(his_id=subject_id, id=int(subject_id[1:]))
    raw = mne.io.RawArray(data * 1e-6, info, verbose=False)  # uV -> V
    montage = mne.channels.make_standard_montage("easycap-M10")
    raw.set_montage(montage)

    assert raw.info["sfreq"] == SFREQ
    assert len(raw.ch_names) == 62

    return raw


class Brennan2019Recording(api.Recording):
    """One subject's EEG session from Brennan & Hale (2019), "Alice in Wonderland".

    Each subject has a single recording, so ``recording_uid == subject_uid``.
    """

    data_url = "https://deepblue.lib.umich.edu/data/concern/data_sets/"
    data_url += "bg257f92t"
    paper_url = "https://journals.plos.org/plosone/"
    paper_url += "article?id=10.1371/journal.pone.0207741"
    doi = "https://doi.org/10.1371/journal.pone.0207741"
    licence = "CC BY 4.0"
    modality = "audio"
    language = "english"
    device = "eeg"
    description = """EEG of Alice in WonderLand, By Brennan and Hale 2019.
    The eeg data was bandpassed between 0.1 and 200. Hz"""

    @classmethod
    def iter(cls) -> tp.Iterator["Brennan2019Recording"]:  # type: ignore
        """Yield one recording per usable subject, in sorted subject order.

        Downloads/extracts the dataset if needed, lists the ``proc/S*.mat``
        files, and skips the subjects listed as bad below.
        """
        # download, extract, organize
        paths = get_paths()
        _prepare()

        # iterdir() order depends on the filesystem; sort so that recordings
        # are always yielded in the same (reproducible) order.
        subjects = sorted(
            f.name
            for f in (paths.download / "proc").iterdir()
            if f.name.startswith("S") and f.name.endswith(".mat")
        )
        assert len(subjects) == 42
        # Bad subjects, e.g. S24 (metadata does not have enough trials).
        # FIXME retrieve these subjects?
        bads = {
            "S24.mat",
            "S26.mat",
            "S27.mat",
            "S30.mat",
            "S32.mat",
            "S34.mat",
            "S35.mat",
            "S36.mat",
            "S02.mat",  # bad proc.trl?
        }
        subjects = [s.split(".")[0] for s in subjects if s not in bads]

        for subject in subjects:
            yield cls(subject_uid=str(subject))

    def __init__(self, subject_uid: str) -> None:
        super().__init__(subject_uid=subject_uid, recording_uid=subject_uid)

    def _load_raw(self) -> mne.io.RawArray:
        """Load the continuous EEG (``<download>/<subject>.mat``)."""
        paths = get_paths()
        raw = _read_eeg(paths.download / f"{self.subject_uid}.mat")
        return raw

    def _load_events(self) -> pd.DataFrame:
        """Load the word/sound events from ``<download>/proc/<subject>.mat``."""
        file = get_paths().download / "proc" / f"{self.subject_uid}.mat"
        events = _read_meta(file)
        return events

NameError: ignored

## features/audio.py

In [None]:
class MelSpectrum(base.Feature, CaptureInit):
    """Mel spectrogram of the audio of a ``sound`` event.

    Processing steps:
    1. Mix the waveform down to mono.
    2. Optionally standardize it.
    3. Resample to ``in_sampling`` Hz.
    4. Project onto ``n_mels`` mel bands, optionally on a log10 scale.
    5. Interpolate to the feature ``sample_rate``.
    """
    event_kind = "sound"

    def __init__(self, sample_rate: Frequency, n_mels=40, n_fft=512, in_sampling=16_000,
                 normalized=True, use_log_scale=True, log_scale_eps=1e-5,
                 norm_audio: bool = True) -> None:
        super().__init__(sample_rate)
        self.dimension = n_mels
        # Cache key: all constructor arguments except the output sample rate,
        # which is only used when interpolating in `get`.
        cache_kwargs = self._init_kwargs
        cache_kwargs.pop('sample_rate')
        self.cache = Cache(self.__class__.__name__, cache_kwargs)

        self.in_sampling = in_sampling
        self.n_mels = n_mels
        self.n_fft = n_fft
        self.hop_length = n_fft // 4  # 75% overlap between STFT frames
        self.use_log_scale = use_log_scale
        self.log_scale_eps = log_scale_eps
        self.normalized = normalized
        self.norm_audio = norm_audio
        self.trans = torchaudio.transforms.MelSpectrogram(
            sample_rate=in_sampling,
            n_mels=n_mels,
            n_fft=n_fft,
            hop_length=self.hop_length,
            normalized=normalized,
        )

        if use_log_scale:
            # value of log10(0 + eps), i.e. the log of a silent bin
            self.default_value = math.log10(log_scale_eps)

    def _compute(self, filepath: Path, start: float, stop: float) -> torch.Tensor:
        """Mel spectrogram (n_mels, frames) of ``filepath`` between ``start`` and ``stop`` s."""
        audio, audio_sr = _extract_wav_part(filepath, start, stop)
        mono = torch.mean(audio, dim=0)  # stereo to mono
        if self.norm_audio:
            mono = (mono - mono.mean()) / (1e-8 + mono.std())
        resample = julius.resample.ResampleFrac(old_sr=int(audio_sr), new_sr=self.in_sampling)
        mono = resample(mono)

        # torch emits deprecation UserWarnings from inside the STFT ("stft will require
        # the return_complex parameter be explicitly" / "torch.rfft is deprecated");
        # silence them until the library is updated.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            spec = self.trans(mono)
        if not self.use_log_scale:
            return spec
        return torch.log10(spec + self.log_scale_eps)

    def get(self, event: events.Sound) -> torch.Tensor:
        """Cached spectrogram of ``event``, interpolated to the feature sample rate."""
        # presumably event.offset is the event position inside its wav file
        wav_start = event.offset
        spec = self.cache.get(
            self._compute, filepath=event.filepath,
            start=wav_start, stop=wav_start + event.duration)
        n_samples = self.sample_rate.to_ind(event.stop - event.start)
        # interpolate works on (batch, channels, time): add then drop a batch dim
        return F.interpolate(spec[None], n_samples)[0]


class Pitch(base.Feature, CaptureInit):
    """Pitch (f0) track of a ``sound`` event, estimated with ``compute_yin``.

    The audio is mixed to mono and resampled to 16 kHz. The resulting per-frame
    pitch values are interpolated to the feature sample rate.
    """

    event_kind = "sound"

    def __init__(self, sample_rate: Frequency, min_f0=100.0, max_f0=350.0, harmonic_thresh=0.1,
                 frame_length_in_samples=256, frame_space_in_samples=64) -> None:
        super().__init__(sample_rate)
        # Cache key: every constructor argument except the output sample rate.
        cache_kwargs = self._init_kwargs
        cache_kwargs.pop('sample_rate')
        self.cache = Cache(self.__class__.__name__, cache_kwargs)

        # analysis window / hop, in samples at `in_sampling`
        self.frame_length_in_samples = frame_length_in_samples
        self.frame_space_in_samples = frame_space_in_samples
        self.harmonic_thresh = harmonic_thresh
        # search range of the fundamental frequency, in Hz
        self.min_f0 = min_f0
        self.max_f0 = max_f0
        self.in_sampling = 16_000

    @property
    def _cache_params(self):
        return self._init_args_kwargs

    def _compute(self, filepath: Path, start: float, stop: float) -> torch.Tensor:
        """Per-frame pitch values of ``filepath`` between ``start`` and ``stop`` seconds."""
        stereo, file_sr = _extract_wav_part(filepath, start, stop)
        mono = torch.mean(stereo, axis=0)  # Stereo to mono
        to_model_rate = julius.resample.ResampleFrac(old_sr=int(file_sr), new_sr=self.in_sampling)
        mono = to_model_rate(mono)

        pitches, _harmonic_rates, _argmins, _times = compute_yin(
            sig=mono.numpy(),
            sr=self.in_sampling,
            w_len=self.frame_length_in_samples,
            w_step=self.frame_space_in_samples,
            harmo_thresh=self.harmonic_thresh,
            f0_min=self.min_f0,
            f0_max=self.max_f0,
        )
        return torch.FloatTensor(pitches)

    def get(self, event: events.Sound) -> torch.Tensor:
        """Pitch track of ``event`` as a (1, n_samples) tensor at the feature sample rate."""
        wav_start = event.offset
        track = self.cache.get(
            self._compute, filepath=event.filepath,
            start=wav_start, stop=wav_start + event.duration)
        n_samples = self.sample_rate.to_ind(event.stop - event.start)
        # interpolate expects (batch, channels, time); keep the channel dim in the output
        resampled = F.interpolate(track[None, None], n_samples)
        return resampled[0]


class _BaseWav2Vec(base.Feature, CaptureInit):
    """
    Parent class for Wav2VecTransformer, Wav2VecConvolution and Wav2VecChunk.

    Shared responsibilities:
    - load the ``Wav2Vec2Model`` (pretrained, or randomly initialised when
      ``random``) and its feature extractor through in-memory caches;
    - preprocess a slice of a wav file;
    - run the model and cache its outputs;
    - slice those outputs to a requested time window.
    """

    event_kind = "sound"
    model_name = "facebook/wav2vec2-large-xlsr-53"

    def __init__(self, sample_rate: Frequency,
                 normalized: bool = True, random: bool = False,
                 device: str = "cpu") -> None:
        super().__init__(sample_rate)
        # Cache key: the model name, plus the `random` flag for untrained models
        args: tp.Any = self.model_name
        if random:
            args = (self.model_name, random)
        self.cache = Cache("Wav2VecEmbedding", args, mode="memmap")
        self.normalized = normalized
        self.device = device
        self.random = random
        # Huggingface logging
        os.environ["TOKENIZERS_PARALLELISM"] = "false"
        os.environ["TRANSFORMERS_VERBOSITY"] = "critical"
        self._model_cache = MemoryCache("Wav2VecEmbedding", "model")
        self._extractor_cache = MemoryCache("Wav2VecEmbedding", "extractor")

    @property
    def model(self) -> tp.Any:
        """``Wav2Vec2Model`` (random init if ``self.random``), fetched via the memory cache."""
        from transformers import Wav2Vec2Model
        if self.random:
            return self._model_cache.get(self._get_random_model)
        return self._model_cache.get(Wav2Vec2Model.from_pretrained, self.model_name)

    def _get_random_model(self):
        """Untrained ``Wav2Vec2Model`` with the architecture of ``model_name``."""
        from transformers import Wav2Vec2Model, Wav2Vec2Config
        config = Wav2Vec2Config.from_pretrained(self.model_name)
        return Wav2Vec2Model(config)

    @property
    def feature_extractor(self) -> tp.Any:
        """``Wav2Vec2FeatureExtractor`` of ``model_name``, fetched via the memory cache."""
        from transformers import Wav2Vec2FeatureExtractor
        return self._extractor_cache.get(Wav2Vec2FeatureExtractor.from_pretrained, self.model_name)

    def _preprocess_wav(self, filepath: Union[Path, str],
                        start: float, stop: float) -> torch.Tensor:
        """Load ``[start, stop)`` s of ``filepath`` as model input values of shape [1, T]."""
        wav, sr = _extract_wav_part(filepath, start, stop)
        logger.debug(
            "Preprocessing Wav on %s, start %.1f, stop %.1f, duration %.1f",
            filepath, start, stop, stop - start)
        wav = torch.mean(wav, dim=0)  # stereo to mono
        extractor = self.feature_extractor  # resolve the cached extractor once
        model_sr = extractor.sampling_rate
        wav = julius.resample.ResampleFrac(old_sr=int(sr), new_sr=model_sr)(wav)

        # [1, T]
        out = extractor(wav,
                        return_tensors="pt",
                        sampling_rate=model_sr,
                        do_normalize=self.normalized).input_values
        return out

    def _compute_hidden_states(
            self, name: str, filepath: Path, start: float, stop: float,
            layers: tp.Optional[tp.List[int]] = None) -> torch.Tensor:
        """Run the model on a wav slice and return one of its outputs as a numpy array.

        Args:
            name: key of the model output to return (e.g. "hidden_states" or
                "extract_features"). Tuple outputs are stacked along a new first axis.
            filepath, start, stop: wav file and time window in seconds.
            layers: if given, indices along the stacked first axis to average together.
        """
        input_values = self._preprocess_wav(filepath=filepath, start=start, stop=stop)

        # Resolve the model once. Every `self.model` access goes through the cache,
        # so the instance moved to `device` and put in eval mode must be the very
        # instance that runs the forward pass.
        model = self.model
        model.to(self.device)
        model.eval()  # needs to be in eval mode
        with torch.no_grad():
            outputs = model(input_values.to(self.device), output_hidden_states=True)
        out: tp.Any = outputs.get(name)
        if isinstance(out, tuple):
            out = torch.stack(out)
        if layers is not None:
            out = out[layers].mean(0)
        return out.detach().cpu().clone().numpy()

    def _get_cached_tensor(
            self, event: events.Sound, overlap: events.DataSlice, name: str,
            layers: tp.Optional[tp.List[int]] = None,
    ) -> torch.Tensor:
        """Cached model output for the whole ``event``, sliced to ``overlap``.

        The time axis (second to last) is mapped to seconds using the observed
        frame rate ``n_frames / event.duration``. It is then sliced to
        ``[overlap.start, overlap.stop)`` relative to ``event.start``. At least one
        frame is always returned.
        """
        outputs = self.cache.get(
            self._compute_hidden_states, start=event.offset, stop=event.offset + event.duration,
            filepath=event.filepath, name=name, layers=layers)
        embd_sr = outputs.shape[-2] / event.duration
        # safety, to make sure we extract the right dim (expects ~50 frames/s)... but maybe slow
        if event.duration >= 0.5:
            assert 42 < embd_sr < 52, (f"Unexpected sampling rate for embedding {embd_sr}",
                                       event.duration, outputs.shape[-2])
        # If the above assert fails, the event duration may be inconsistent with the
        # actual wav duration, or the wav2vec output sampling rate has changed.
        # We would then need another way to get either the embedding sampling rate
        # or the duration.
        sr = Frequency(embd_sr)
        start, stop = [sr.to_ind(x - event.start) for x in (overlap.start, overlap.stop)]
        start = min(start, outputs.shape[-2] - 1)
        stop = max(start + 1, stop)
        chunk = outputs[..., start: stop, :]
        # load into memory (probably unnecessary, but lets avoid weird issues)
        chunk = np.array(chunk, copy=True)
        return torch.from_numpy(chunk)

    def get(self, event: events.Sound) -> torch.Tensor:
        raise RuntimeError(f"Only get_on_overlap is available for {self.__class__.__name__}")


class Wav2VecTransformer(_BaseWav2Vec):
    """wav2vec 2.0 transformer hidden states, averaged over ``layers``.

    Output is (1024, overlap.duration_ind): one row per hidden feature.
    """
    event_kind = "sound"
    dimension = 1024

    def __init__(self, sample_rate: Frequency,
                 normalized: bool = True,
                 layers: tp.Tuple[int, ...] = (14, 15, 16, 17, 18),
                 random: bool = False,
                 device: str = "cpu") -> None:
        super().__init__(sample_rate=sample_rate, normalized=normalized,
                         device=device, random=random)
        # indices into the stacked `hidden_states` that get averaged together
        self.layers = layers

    def get_on_overlap(self, event: events.Sound, overlap: events.DataSlice) -> torch.Tensor:
        """Averaged hidden states of ``event`` over ``overlap``, resampled to its length."""
        hidden = self._get_cached_tensor(
            event,
            overlap=overlap,
            name="hidden_states",
            layers=list(self.layers),
        )
        features_by_time = hidden[0].transpose(0, 1)  # [1, T, D] -> [D, T]
        resampled = F.interpolate(features_by_time[None], overlap.duration_ind)
        return resampled[0]


class Wav2VecConvolution(_BaseWav2Vec):
    """wav2vec 2.0 convolutional feature-encoder output (``extract_features``).

    Output is (512, overlap.duration_ind).
    """
    event_kind = "sound"
    dimension = 512

    def get_on_overlap(self, event: events.Sound, overlap: events.DataSlice) -> torch.Tensor:
        """Conv features of ``event`` over ``overlap``, resampled to its length."""
        conv_feats = self._get_cached_tensor(event, overlap=overlap, name="extract_features")
        channels_first = conv_feats[0].transpose(0, 1)  # [1, T, D] -> [T, D] -> [D, T]
        return F.interpolate(channels_first[None], overlap.duration_ind)[0]


class Wav2VecChunk(_BaseWav2Vec):
    """Preprocessed waveform chunk, ready to be fed to a Wav2Vec model.

    The ``sample_rate`` argument is ignored: this feature always runs at 16 kHz.
    """

    event_kind = "sound"
    dimension = 1
    model_name = "facebook/wav2vec2-large-xlsr-53"
    normalizable = False

    def __init__(self, sample_rate: Frequency,
                 normalized: bool = True,
                 random: bool = False,
                 device: str = "cpu") -> None:
        # The SR is forced to 16k for this feature: base::FeaturesBuilder()
        # does not handle multiple SRs.
        forced_rate = Frequency(16000)
        super().__init__(sample_rate=forced_rate, normalized=normalized,
                         device=device, random=random)

    @property
    def feature_extractor(self) -> tp.Any:
        """``Wav2Vec2FeatureExtractor`` of ``model_name``, fetched via the memory cache."""
        from transformers import Wav2Vec2FeatureExtractor

        loader = Wav2Vec2FeatureExtractor.from_pretrained
        return self._extractor_cache.get(loader, self.model_name)

    def get(self, event: events.Sound) -> torch.Tensor:
        """Model input values [1, T] for the audio of ``event``."""
        # Possible improv.: cache the full .wav so it is read once (small time reduction expected)
        wav_start = event.offset
        return self._preprocess_wav(
            filepath=event.filepath,
            start=wav_start,
            stop=wav_start + event.duration,
        )


def _extract_wav_part(
    filepath: Union[Path, str], onset: float, offset: float
) -> tp.Tuple[torch.Tensor, Frequency]:
    """Load the slice of a wav file between ``onset`` and ``offset`` (in seconds).

    Returns:
        ``(wav, sr)``: ``wav`` is the waveform returned by ``torchaudio.load``
        (channels first), and ``sr`` is the file's sample rate.

    Raises:
        AssertionError: if the loaded duration differs from ``offset - onset``
            by more than 0.1 s (e.g. the slice runs past the end of the file).
    """
    sr = Frequency(torchaudio.info(str(filepath)).sample_rate)
    first_frame = sr.to_ind(onset)
    n_frames = sr.to_ind(offset - onset)
    loaded = torchaudio.load(filepath, frame_offset=first_frame, num_frames=n_frames)
    wav = loaded[0]
    delta = abs(wav.shape[-1] / sr - offset + onset)
    assert delta <= 0.1, (delta, filepath, onset, offset, onset - offset)
    return wav, sr

## scaling

## train / test split

## dataset definition

In [None]:
class BrainAudioDataset(torch.utils.data.Dataset):
    """Paired EEG / audio-feature dataset (work in progress).

    Only a skeleton for now. The methods have the signatures expected by
    ``torch.utils.data.Dataset`` / ``DataLoader``, but loading is not implemented:
    the accessors raise ``NotImplementedError`` instead of silently returning None.
    """

    def __init__(self):
        super().__init__()
        # TODO: index recordings and their audio features.

    def __len__(self):
        raise NotImplementedError("BrainAudioDataset.__len__ is not implemented yet")

    def __getitem__(self, idx):
        """
        inputs:
        - idx: sample index

        outputs (planned):
        - eeg:   (1, C, T)
        - mfcc:  (1, C, T)
        - words: ??
        """
        raise NotImplementedError("BrainAudioDataset.__getitem__ is not implemented yet")

    def collate_fn(self, batch):
        """Merge a list of samples into a batch (for ``DataLoader(collate_fn=...)``)."""
        raise NotImplementedError("BrainAudioDataset.collate_fn is not implemented yet")


## dataloader definition

In [None]:
#TODO

# models [KS]

## common.py

In [None]:
class ScaledEmbedding(nn.Module):
    """Embedding whose learning rate is effectively scaled up by ``scale``.

    The table is stored divided by ``scale`` and multiplied back on every lookup.
    At init the outputs therefore match a plain ``nn.Embedding``, while updates to
    the stored weights are amplified in the output. Without this, the embedding
    can move too slowly.
    """
    def __init__(self, num_embeddings: int, embedding_dim: int, scale: float = 10.):
        super().__init__()
        table = nn.Embedding(num_embeddings, embedding_dim)
        table.weight.data /= scale
        self.embedding = table
        self.scale = scale

    @property
    def weight(self):
        """Effective (rescaled) embedding matrix."""
        raw = self.embedding.weight
        return raw * self.scale

    def forward(self, x):
        looked_up = self.embedding(x)
        return looked_up * self.scale


class SubjectLayers(nn.Module):
    """Per-subject linear projection of the channel dimension.

    ``x`` is (batch, in_channels, time) and ``subjects`` holds one subject index per
    batch element. Each sample is multiplied by its subject's own
    (in_channels, out_channels) matrix, giving (batch, out_channels, time).
    """
    def __init__(self, in_channels: int, out_channels: int, n_subjects: int, init_id: bool = False):
        super().__init__()
        self.weights = nn.Parameter(torch.randn(n_subjects, in_channels, out_channels))
        if init_id:
            assert in_channels == out_channels
            self.weights.data[:] = torch.eye(in_channels)[None]
        # 1/sqrt(fan_in) scaling; note it is also applied to the identity init.
        self.weights.data *= 1 / in_channels**0.5

    def forward(self, x, subjects):
        _, n_in, n_out = self.weights.shape
        gather_index = subjects.view(-1, 1, 1).expand(-1, n_in, n_out)
        per_sample = self.weights.gather(0, gather_index)  # (batch, in, out)
        return torch.einsum("bct,bcd->bdt", x, per_sample)

    def __repr__(self):
        n_subjects, n_in, n_out = self.weights.shape
        return f"SubjectLayers({n_in}, {n_out}, {n_subjects})"


class LayerScale(nn.Module):
    """Layer scale from [Touvron et al 2021] (https://arxiv.org/pdf/2103.17239.pdf).

    Multiplies each channel of a residual branch by a learnt factor that starts
    at ``init``, keeping residual outputs close to 0 at first. The parameter is
    stored divided by ``boost`` and multiplied back in ``forward``, presumably to
    speed up its learning.
    """
    def __init__(self, channels: int, init: float = 0.1, boost: float = 5.):
        super().__init__()
        self.scale = nn.Parameter(torch.zeros(channels, requires_grad=True))
        self.scale.data[:] = init / boost
        self.boost = boost

    def forward(self, x):
        # per-channel factor broadcast over the last axis: x is (..., channels, time)
        factor = self.boost * self.scale[:, None]
        return factor * x


class ConvSequence(nn.Module):
    """Stack of 1D (transposed) convolutions with optional residuals, norm, dropout and gates.

    Layer ``k`` maps ``channels[k]`` to ``channels[k + 1]`` channels, so the stack
    has ``len(channels) - 1`` layers. Each layer may be followed by an optional
    gating block (see ``glu``).
    """

    def __init__(self, channels: tp.Sequence[int], kernel: int = 4, dilation_growth: int = 1,
                 dilation_period: tp.Optional[int] = None, stride: int = 2,
                 dropout: float = 0.0, leakiness: float = 0.0, groups: int = 1,
                 decode: bool = False, batch_norm: bool = False, dropout_input: float = 0,
                 skip: bool = False, scale: tp.Optional[float] = None, rewrite: bool = False,
                 activation_on_last: bool = True, post_skip: bool = False, glu: int = 0,
                 glu_context: int = 0, glu_glu: bool = True, activation: tp.Any = None) -> None:
        """
        Args:
            channels: channel counts; consecutive pairs define the conv layers.
            kernel: conv kernel size (must be odd when ``dilation_growth > 1``).
            dilation_growth: factor the dilation is multiplied by after each layer.
            dilation_period: if set, the dilation is reset to 1 every
                ``dilation_period`` layers.
            stride: stride of every main convolution.
            dropout: dropout probability after each activation (0 disables it).
            leakiness: negative slope of the default LeakyReLU (also used by ``rewrite``).
            groups: conv groups for every layer except the first.
            decode: use ``nn.ConvTranspose1d`` instead of ``nn.Conv1d``.
            batch_norm: add ``BatchNorm1d`` before each activation.
            dropout_input: dropout probability on the input of the first layer
                (0 disables it).
            skip: add a residual connection around a layer whenever its output
                has the same shape as its input.
            scale: if set, append ``LayerScale(chout, scale)`` to layers with
                ``skip`` and ``chin == chout``.
            rewrite: append an extra 1x1 conv + LeakyReLU after the activation.
            activation_on_last: whether the last layer also gets the activation block.
            post_skip: with ``skip`` and ``chin == chout``, append a per-channel
                (``groups=chout``) 1x1 conv without bias.
            glu: if > 0, add a gating block after every ``glu``-th layer.
            glu_context: the gating conv has kernel ``1 + 2 * glu_context`` and
                matching padding.
            glu_glu: gate with ``nn.GLU`` over a doubled-channel conv; otherwise
                use ``activation()``.
            activation: zero-argument factory for the non-linearity. Defaults to
                ``LeakyReLU(leakiness)``.
        """
        super().__init__()
        dilation = 1
        channels = tuple(channels)
        self.skip = skip
        self.sequence = nn.ModuleList()
        self.glus = nn.ModuleList()
        if activation is None:
            activation = partial(nn.LeakyReLU, leakiness)
        Conv = nn.Conv1d if not decode else nn.ConvTranspose1d
        # build layers
        for k, (chin, chout) in enumerate(zip(channels[:-1], channels[1:])):
            layers: tp.List[nn.Module] = []
            is_last = k == len(channels) - 2

            # Set dropout for the input of the conv sequence if defined
            if k == 0 and dropout_input:
                assert 0 < dropout_input < 1
                layers.append(nn.Dropout(dropout_input))

            # conv layer
            if dilation_growth > 1:
                assert kernel % 2 != 0, "Supports only odd kernel with dilation for now"
            if dilation_period and (k % dilation_period) == 0:
                dilation = 1
            # padding of (kernel // 2) * dilation on each side
            pad = kernel // 2 * dilation
            layers.append(Conv(chin, chout, kernel, stride, pad,
                               dilation=dilation, groups=groups if k > 0 else 1))
            dilation *= dilation_growth
            # non-linearity
            if activation_on_last or not is_last:
                if batch_norm:
                    layers.append(nn.BatchNorm1d(num_features=chout))
                layers.append(activation())
                if dropout:
                    layers.append(nn.Dropout(dropout))
                if rewrite:
                    layers += [nn.Conv1d(chout, chout, 1), nn.LeakyReLU(leakiness)]
                    # layers += [nn.Conv1d(chout, 2 * chout, 1), nn.GLU(dim=1)]
            if chin == chout and skip:
                if scale is not None:
                    layers.append(LayerScale(chout, scale))
                if post_skip:
                    layers.append(Conv(chout, chout, 1, groups=chout, bias=False))

            self.sequence.append(nn.Sequential(*layers))
            if glu and (k + 1) % glu == 0:
                ch = 2 * chout if glu_glu else chout
                act = nn.GLU(dim=1) if glu_glu else activation()
                self.glus.append(
                    nn.Sequential(
                        nn.Conv1d(chout, ch, 1 + 2 * glu_context, padding=glu_context), act))
            else:
                # placeholder keeps `glus` index-aligned with `sequence` (see forward)
                self.glus.append(None)

    def forward(self, x: tp.Any) -> tp.Any:
        """Apply each layer in turn. Add the residual when ``skip`` is set and the
        shape is unchanged, then apply that layer's gating block if there is one."""
        for module_idx, module in enumerate(self.sequence):
            old_x = x
            x = module(x)
            if self.skip and x.shape == old_x.shape:
                x = x + old_x
            glu = self.glus[module_idx]
            if glu is not None:
                x = glu(x)
        return x


class DualPathRNN(nn.Module):
    """Dual-path LSTM stack over a 1D sequence.

    The sequence is padded (via ``pad_multiple``) and cut into chunks of
    ``inner_length`` steps.
    - Even-indexed LSTMs run *within* each chunk (intra-chunk path).
    - Odd-indexed LSTMs run *across* chunks at a fixed position inside the chunk
      (inter-chunk path).
    Each LSTM output is added back residually. The time axis is flipped after
    every odd layer; the number of flips is even, so the original order is
    restored at the end.

    Args:
        channels: number of channels (LSTM input and hidden size).
        depth: the module holds ``4 * depth`` LSTMs.
        inner_length: chunk length of the intra-chunk path.
    """
    def __init__(self, channels: int, depth: int, inner_length: int = 10):
        super().__init__()
        self.lstms = nn.ModuleList([nn.LSTM(channels, channels, 1) for _ in range(depth * 4)])
        self.inner_length = inner_length

    def forward(self, x: torch.Tensor):
        """
        Args:
            x: tensor of shape (B, C, L).

        Returns:
            tensor of shape (B, C, L).
        """
        B, C, L = x.shape
        IL = self.inner_length
        # assumes pad_multiple pads the last axis to a multiple of IL
        x = pad_multiple(x, self.inner_length)
        # (B, C, L') -> (L', B, C): nn.LSTM defaults to sequence-first inputs
        x = x.permute(2, 0, 1).contiguous()
        for idx, lstm in enumerate(self.lstms):
            y = x.reshape(-1, IL, B, C)  # (n_chunks, IL, B, C)
            if idx % 2 == 0:
                # intra-chunk: sequence over IL, batch over (n_chunks, B)
                y = y.transpose(0, 1).reshape(IL, -1, C)
            else:
                # inter-chunk: sequence over n_chunks, batch over (IL, B)
                y = y.reshape(-1, IL * B, C)
            # Run the LSTM on the chunked view `y`. Feeding `x` here would skip the
            # chunking above and scramble time/batch in the reshapes below.
            y, _ = lstm(y)
            if idx % 2 == 0:
                y = y.reshape(IL, -1, B, C).transpose(0, 1).reshape(-1, B, C)
            else:
                y = y.reshape(-1, B, C)
            x = x + y

            if idx % 2 == 1:
                x = x.flip(dims=(0,))
        return x[:L].permute(1, 2, 0).contiguous()


class PositionGetter:
    """Normalized 2D sensor positions for the channels of recordings.

    Positions come from ``mne.find_layout`` on the recording's ``mne_info``. Over
    the channels found in the layout, x and y are each min-max normalized to
    [0, 1]. Channels absent from the layout get ``INVALID`` in both coordinates.
    Results are cached per ``recording.recording_index``.
    """
    INVALID = -0.1  # sentinel outside [0, 1], so it never collides with a real position

    def __init__(self) -> None:
        self._cache: tp.Dict[int, torch.Tensor] = {}
        # names already warned about, so each missing channel is logged only once
        self._invalid_names: tp.Set[str] = set()

    def get_recording_layout(self, recording: Recording) -> torch.Tensor:
        """Return a (n_channels, 2) float tensor of positions for ``recording``.

        NOTE(review): if no channel is found in the layout, ``x.min()`` on the
        empty array raises. If all matched channels share the same x (or y), the
        normalization divides by zero.
        """
        index = recording.recording_index
        if index in self._cache:
            return self._cache[index]
        else:
            info = recording.mne_info
            layout = mne.find_layout(info)
            # indexes: row in the layout; valid_indexes: matching channel index
            indexes: tp.List[int] = []
            valid_indexes: tp.List[int] = []
            for meg_index, name in enumerate(info.ch_names):
                # drop any suffix after the last "-" before looking the name up
                name = name.rsplit("-", 1)[0]
                try:
                    indexes.append(layout.names.index(name))
                except ValueError:
                    if name not in self._invalid_names:
                        logger.warning(
                            "Channels %s not in layout for recording %s of %s.",
                            name,
                            recording.study_name(),
                            recording.recording_uid)
                        self._invalid_names.add(name)
                else:
                    valid_indexes.append(meg_index)

            positions = torch.full((len(info.ch_names), 2), self.INVALID)
            x, y = layout.pos[indexes, :2].T
            # min-max normalize each axis to [0, 1] over the matched channels
            x = (x - x.min()) / (x.max() - x.min())
            y = (y - y.min()) / (y.max() - y.min())
            x = torch.from_numpy(x).float()
            y = torch.from_numpy(y).float()
            positions[valid_indexes, 0] = x
            positions[valid_indexes, 1] = y
            self._cache[index] = positions
            return positions

    def get_positions(self, batch):
        """Return (B, C, 2) positions for the channels of ``batch.meg`` (B, C, T).

        Channels beyond a recording's own channel count stay ``INVALID``.
        """
        meg = batch.meg
        B, C, T = meg.shape
        positions = torch.full((B, C, 2), self.INVALID, device=meg.device)
        for idx in range(len(batch)):
            recording = batch._recordings[idx]
            rec_pos = self.get_recording_layout(recording)
            positions[idx, :len(rec_pos)] = rec_pos.to(meg.device)
        return positions

    def is_invalid(self, positions):
        """Boolean mask over ``positions[..., 0]``: True where both coordinates are INVALID."""
        return (positions == self.INVALID).all(dim=-1)


class FourierEmb(nn.Module):
    """
    Fourier positional embedding.

    Unlike the traditional embedding, this does not use exponential periods for
    the cosines and sines. It uses the typical `2 pi k` frequencies, which can
    represent any function over [0, 1]. Since such a function is necessarily
    periodic, we take some margin and work over [-margin, 1 + margin]
    (default [-0.2, 1.2]).

    The embedding concatenates cosines and sines over an ``n_freqs x n_freqs`` grid
    of (x, y) frequencies. ``dimension`` must therefore be exactly ``2 * k**2``
    (e.g. 128, 288, 512, 2048). The check uses integer arithmetic: a float check
    would accept e.g. 256 and then produce 288 features.
    """
    def __init__(self, dimension: int = 256, margin: float = 0.2):
        super().__init__()
        n_freqs = math.isqrt(dimension // 2)
        assert 2 * n_freqs ** 2 == dimension, (
            f"FourierEmb dimension must be of the form 2 * k**2, got {dimension}")
        self.dimension = dimension
        self.margin = margin

    def forward(self, positions):
        """
        Args:
            positions: tensor of shape (..., 2) holding (x, y), nominally in [0, 1].

        Returns:
            tensor of shape (..., dimension): the cosines followed by the sines.
        """
        *O, D = positions.shape
        assert D == 2
        n_freqs = math.isqrt(self.dimension // 2)
        freqs_y = torch.arange(n_freqs).to(positions)
        freqs_x = freqs_y[:, None]
        width = 1 + 2 * self.margin
        positions = positions + self.margin
        p_x = 2 * math.pi * freqs_x / width
        p_y = 2 * math.pi * freqs_y / width
        positions = positions[..., None, None, :]
        # (..., n_freqs, n_freqs) phases, flattened to (..., n_freqs ** 2)
        loc = (positions[..., 0] * p_x + positions[..., 1] * p_y).view(*O, -1)
        emb = torch.cat([
            torch.cos(loc),
            torch.sin(loc),
        ], dim=-1)
        return emb


class ChannelDropout(nn.Module):
    """Spatial channel dropout.

    Channels without a known sensor position are always zeroed. During training,
    every channel within ``dropout`` of a random point of the unit square is
    zeroed as well.
    """
    def __init__(self, dropout: float = 0.1, rescale: bool = True):
        """
        Args:
            dropout: dropout radius in normalized [0, 1] coordinates (0 disables
                the module entirely).
            rescale: during training, divide each channel by its Monte-Carlo
                estimate (100 draws) of the probability of being kept.
        """
        super().__init__()
        self.dropout = dropout
        self.rescale = rescale
        self.position_getter = PositionGetter()

    def forward(self, meg, batch):
        if not self.dropout:
            return meg

        B, C, _ = meg.shape
        meg = meg.clone()
        positions = self.position_getter.get_positions(batch)
        has_position = (~self.position_getter.is_invalid(positions)).float()
        meg = meg * has_position[:, :, None]

        if not self.training:
            return meg

        ban_center = torch.rand(2, device=meg.device)
        keep_mask = (positions - ban_center).norm(dim=-1) > self.dropout
        meg = meg * keep_mask.float()[:, :, None]
        if self.rescale:
            n_draws = 100
            keep_prob = torch.zeros(B, C, device=meg.device)
            for _ in range(n_draws):
                draw_center = torch.rand(2, device=meg.device)
                kept_in_draw = (positions - draw_center).norm(dim=-1) > self.dropout
                keep_prob += kept_in_draw.float() / n_draws
            meg = meg / (1e-8 + keep_prob[:, :, None])

        return meg


class ChannelMerger(nn.Module):
    """Spatial attention that merges a variable set of sensors into ``chout`` channels.

    Each output channel owns a learnt query (a "head") in the Fourier-embedding
    space of sensor positions. Its weights over input channels are a softmax of
    ``head . embedding(position)``. Channels without a position, and (in training
    with ``dropout``) channels inside a random disc, are masked out with -inf.

    NOTE(review): the default ``pos_dim=256`` is not of the form ``2 * k**2``, which
    the ``n_freqs x n_freqs`` Fourier grid needs. Check that the embedding size
    matches ``pos_dim``.
    """
    def __init__(self, chout: int, pos_dim: int = 256,
                 dropout: float = 0, usage_penalty: float = 0.,
                 n_subjects: int = 200, per_subject: bool = False):
        super().__init__()
        assert pos_dim % 4 == 0
        self.position_getter = PositionGetter()
        self.per_subject = per_subject
        if per_subject:
            head_shape = (n_subjects, chout, pos_dim)
        else:
            head_shape = (chout, pos_dim)
        self.heads = nn.Parameter(torch.randn(*head_shape, requires_grad=True))
        self.heads.data /= pos_dim ** 0.5
        self.dropout = dropout
        self.embedding = FourierEmb(pos_dim)
        self.usage_penalty = usage_penalty
        self._penalty = torch.tensor(0.)

    @property
    def training_penalty(self):
        """Last computed usage penalty, on the module's device."""
        device = next(self.parameters()).device
        return self._penalty.to(device)

    def forward(self, meg, batch):
        B, C, T = meg.shape
        meg = meg.clone()
        positions = self.position_getter.get_positions(batch)
        embedding = self.embedding(positions)
        # additive mask on the attention logits: -inf removes a channel from the softmax
        logit_mask = torch.zeros(B, C, device=meg.device)
        logit_mask[self.position_getter.is_invalid(positions)] = float('-inf')

        if self.training and self.dropout:
            ban_center = torch.rand(2, device=meg.device)
            in_ban_zone = (positions - ban_center).norm(dim=-1) <= self.dropout
            logit_mask[in_ban_zone] = float('-inf')

        if self.per_subject:
            _, n_out, pos_dim = self.heads.shape
            subject_idx = batch.subject_index.view(-1, 1, 1).expand(-1, n_out, pos_dim)
            heads = self.heads.gather(0, subject_idx)
        else:
            heads = self.heads[None].expand(B, -1, -1)

        logits = torch.einsum("bcd,bod->boc", embedding, heads)
        logits += logit_mask[:, None]
        weights = torch.softmax(logits, dim=2)  # per output channel: distribution over inputs
        out = torch.einsum("bct,boc->bot", meg, weights)
        if self.training and self.usage_penalty > 0.:
            # NOTE(review): `weights` sums to 1 over dim 2 for every (b, o), so this
            # mean-over-(b, o)-then-sum is always 1 (unless a row is fully masked).
            # The penalty is therefore a constant with zero gradient, likely not intended.
            usage = weights.mean(dim=(0, 1)).sum()
            self._penalty = self.usage_penalty * usage
        return out

## convrnn.py

## features.py

## simpleconv.py

## new model [CM]

## decoder

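Pipeline: audio → wav2vec2 → shared embedding space ← our model ← EEG

Once both modalities are in the shared embedding space, align them with a CLIP-style contrastive objective.

From the CLIP-aligned embeddings, decode to words.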

# losses and metrics

## clip [DS]

## wer [CM]

### general

### vocab-specific

# trainer


# experiments

# evaluation

## viz