# installs and imports

In [1]:
!pip install wandb torchsummaryX mne transformers Levenshtein adamp jiwer -q

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.1/2.1 MB[0m [31m15.1 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m8.3/8.3 MB[0m [31m56.0 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m169.4/169.4 kB[0m [31m24.6 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m190.6/190.6 kB[0m [31m25.0 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m252.8/252.8 kB[0m [31m35.4 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.3/3.3 MB[0m [31m99.0 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m62.7/62.7 kB[0m [31m10.2 MB/s[0m eta [36m0:00:00[0m
[?25h  Building wheel for adamp (setup.py) ... [?25l[?25hdone


In [None]:
### Torch
import  torch
import  torch.nn as nn
import  torch.nn.functional as F
from    torch.optim import lr_scheduler
from    torchsummaryX import summary
from    torch.utils.data import Dataset, DataLoader, random_split
import  torchaudio
import  torchaudio.transforms as tat
from    torch.nn.utils.rnn import (pad_sequence, pack_padded_sequence,
                                   pad_packed_sequence)
from torch.optim import AdamW
from adamp import AdamP

### General
import  random
import  numpy as np
import  pandas as pd
import  scipy
import  gc
from    tqdm.auto import tqdm
import  os
import  datetime
import  time
import  wandb
import  matplotlib.pyplot as plt
import  seaborn as sns
import re
import seaborn
import Levenshtein as lev

### Other

from functools import partial
import logging
import math
import typing as tp
import warnings
warnings.filterwarnings('ignore')

import torch.multiprocessing as mp
mp.set_start_method('spawn', force=True)

# wav2vec2 and EEG processing
from    transformers import (AutoProcessor, AutoModelForPreTraining,
                             CLIPProcessor, CLIPModel, Wav2Vec2ForCTC, Wav2Vec2Processor)
import  mne
from transformers import Wav2Vec2Model

# working directory

In [None]:
from google.colab import drive
drive.mount('/content/gdrive')

In [None]:
DATA_PATH  = '/content/gdrive/MyDrive/11785-IDLf23/Final_project/0_Data/'

BRAIN_PATH = os.path.join(DATA_PATH, 'brennan_and_hale_v2')
AUDIO_PATH = os.path.join(DATA_PATH, 'brennan_and_hale_v2/audio')
PPROC_PATH = os.path.join(DATA_PATH, 'brennan_and_hale_v2/proc/timelock-preprocessing')

EEG_PATH   = os.path.join(DATA_PATH, 'eeg-segments')
EMBED_PATH = os.path.join(DATA_PATH, 'brennan_wav2vec2_embeddings')

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# configs

In [None]:
config = dict(

    # dataloaders ----------------------------------------------------------------
    batch_size = 4,
    transforms = None,
    reduce_data_ratio = 1, # 1 if we want to use the entire dataset

    # model ----------------------------------------------------------------

    ###

    ###

    ###

    # optimizer ----------------------------------------------------------------
    learning_rate = 0.0001,
    weight_decay = 5e-3,

    # scheduler ----------------------------------------------------------------
    factor = 0.9,
    patience = 3,

    # trainer ----------------------------------------------------------------
    epochs = 30)

# dataset and dataloaders [QA]

## dataset definition

In [6]:
class BrainAudioDataset(torch.utils.data.Dataset):
    """
    Hybrid Memory Efficient Dataset.
    Loads all audio embeddings in __init__(). ~1-2min.
    Loads eeg in __getitem__().
    Loading everything in __init__() crashes.
    Loading everything in __getitem__() is super slow.

    Load eeg
    Pad (add AUD of zeros) to be 62 chan
    Mask the bad chans
    """

    def __init__(self, eeg_path=EEG_PATH, embed_path=EMBED_PATH, transforms=None):
        super().__init__()

        self.eeg_root = eeg_path
        self.embed_root = embed_path
        self.transforms = transforms

        eeg_fnames = sorted(os.listdir(self.eeg_root))
        embed_fnames = sorted(os.listdir(self.embed_root))

        gc.collect()

        self.audio_embeddings = {}
        for segment_idx, fname in enumerate(tqdm(embed_fnames), start=1):
            audio_fpath = os.path.join(self.embed_root, fname)
            audio = torch.load(audio_fpath)
            audio_embed = audio.hidden_states[-1]
            audio_embed = audio_embed.squeeze(0)
            audio_embed = audio_embed.to(DEVICE)
            self.audio_embeddings[segment_idx] = audio_embed
            del audio, audio_embed
            gc.collect()

        self.eegs = {}
        for idx, fname in enumerate(tqdm(eeg_fnames)):
            # Extract subject id and segment (to be returned in __getitem__())
            subject_idx, segment_idx = self.extract_info(fname)

            # Load eeg filepath
            eeg_fpath = os.path.join(self.eeg_root, fname)
            self.eegs[idx] = (subject_idx, int(segment_idx), eeg_fpath)

        self.length = len(self.eegs)

        # montage = mne.channels.make_standard_montage("easycap-M10")

    def __len__(self):

        return self.length

    def __getitem__(self, idx):
        """
        - eeg           : (1, T_eeg, C_eeg)
        - audio_embed   : (1, T_audio, 1024)
        """
        # Get eeg
        _, segment_idx, eeg_fpath = self.eegs[idx]
        eeg = np.load(eeg_fpath)
        eeg = torch.tensor(eeg.transpose(), device=DEVICE)

        # If nchan < 62, zero pad for the audio channel (should be done in preproc)

        # Remove bad chan (should be done in prepoc)

        # Retrieve pre-loaded audio embedding
        audio_embed = self.audio_embeddings[segment_idx]
        return eeg, audio_embed, len(eeg), len(audio_embed)

    def collate_fn(self, batch):

        eeg, audio_embed, len_eeg, len_audio_embed = zip(*batch)

        # pack the mfccs and transcripts using the pad_sequence function from pytorch
        batch_eeg_pad = pad_sequence(eeg, batch_first=True)
        batch_audio_embed_pad = pad_sequence(audio_embed, batch_first=True)
        del eeg, audio_embed

        # Apply transformations
        if self.transforms is not None:
            batch_eeg_pad = self.transforms(batch_eeg_pad)

        return (batch_eeg_pad,
                batch_audio_embed_pad,
                torch.tensor(len_eeg, dtype=torch.int64),
                torch.tensor(len_audio_embed, dtype=torch.int64))

    def read_sfp(self, file_path):
        """
        Reads a BESA SFP (Surface Point) file (locations of sensors)
        """

        electrodes = {}
        with open(file_path, 'r') as file:
            for line in file:
                parts = line.strip().split()  # Split by whitespace
                if len(parts) == 4:
                    # Parse the electrode name and coordinates
                    name = parts[0]
                    x, y, z = map(float, parts[1:])  # Convert strings to floats
                    electrodes[name] = (x, y, z)

        return electrodes

    def extract_info(self, fname):
        # This pattern looks for any text (non-digits) followed by digits, a hyphen, and more digits
        match = re.match(r'([A-Za-z]+[0-9]+)-([0-9]+).npy', fname)
        if match:
            return match.group(1), match.group(2)
        else:
            return None


## dataloader definition

In [7]:
full_dataset = BrainAudioDataset()

  0%|          | 0/12 [00:00<?, ?it/s]

  0%|          | 0/348 [00:00<?, ?it/s]

In [14]:
TRAIN_PORTION = 0.8
VAL_PORTION   = 0.1
TEST_PORTION  = 0.1

# Splitting the dataset /// train_len = int(len(full_dataset) * TRAIN_PORTION)
torch.manual_seed(1)
train_data, val_data, test_data = random_split(full_dataset, [TRAIN_PORTION, VAL_PORTION, TEST_PORTION])

train_loader = torch.utils.data.DataLoader(
    dataset      = train_data,
    batch_size   = config['batch_size'],
    shuffle      = True,
    drop_last    = False,
    num_workers  = 4,
    pin_memory   = True,
    collate_fn   = full_dataset.collate_fn
)

val_loader = torch.utils.data.DataLoader(
    dataset      = val_data,
    batch_size   = config['batch_size'],
    shuffle      = False,
    drop_last    = False,
    num_workers  = 4,
    pin_memory   = True,
    collate_fn   = full_dataset.collate_fn
)

test_loader = torch.utils.data.DataLoader(
    dataset      = test_data,
    batch_size   = config['batch_size'],
    shuffle      = False,
    drop_last    = False,
    num_workers  = 4,
    pin_memory   = True,
    collate_fn   = full_dataset.collate_fn
)

In [15]:
print('-'*80)
print(f'Len full data:      {len(full_dataset)}')
print(f'Len train data:     {len(train_data)}')
print(f'Len val data:       {len(val_data)}')
print(f'Len test data:      {len(test_data)}')
print('-'*80)
print(f'Len Train Loader:   {train_loader.__len__()}')
print(f'Len Val Loader:     {val_loader.__len__()}')
print(f'Len Test Loader:    {test_loader.__len__()}')

gc.collect()

for batch in train_loader:
    eeg, audio_embedding, l_eeg, l_audio = batch
    print(eeg.shape, audio_embedding.shape, l_eeg.shape, l_audio.shape)
    print(eeg.dtype, audio_embedding.dtype, l_eeg.dtype, l_audio.dtype)
    del eeg, audio_embedding, l_eeg, l_audio
    gc.collect()
    break

for batch in val_loader:
    eeg, audio_embedding, l_eeg, l_audio = batch
    print(eeg.shape, audio_embedding.shape, l_eeg.shape, l_audio.shape)
    print(eeg.dtype, audio_embedding.dtype, l_eeg.dtype, l_audio.dtype)
    del eeg, audio_embedding, l_eeg, l_audio
    gc.collect()
    break

for batch in test_loader:
    eeg, audio_embedding, l_eeg, l_audio = batch
    print(eeg.shape, audio_embedding.shape, l_eeg.shape, l_audio.shape)
    print(eeg.dtype, audio_embedding.dtype, l_eeg.dtype, l_audio.dtype)
    del eeg, audio_embedding, l_eeg, l_audio
    gc.collect()
    break

--------------------------------------------------------------------------------
Len full data:      348
Len train data:     279
Len val data:       35
Len test data:      34
--------------------------------------------------------------------------------
Len Train Loader:   70
Len Val Loader:     9
Len Test Loader:    9


RuntimeError: ignored

# meta models desc

Relevant Meta folders: [Features folder](https://github.com/facebookresearch/brainmagick/tree/main/bm/features) and [Model folder](https://github.com/facebookresearch/brainmagick/tree/main/bm/models).


**Shared model components and models:**

1. **[common.py](https://github.com/facebookresearch/brainmagick/blob/main/bm/models/common.py)** - collection of components used by other models

  - ScaledEmbedding

  - SubjectLayers

  - LayerScale

  - ConvSequence

  - DualPathRNN

  - PositionGetter

  - FourierEmb

  - ChannelDropout

  - ChannelMerger

2. **[convrnn.py](https://github.com/facebookresearch/brainmagick/blob/main/bm/models/convrnn.py)** - a model used both as an a encoder and a decoder.

  - LSTM

  - Attention

  - ConvRNN
  
    - SubjectLayers (subject layer)

    - ScaledEmbedding (subject embedding)
    
    - LSTM (bidirectional)

    - Attention (multi-head dot product)

    - ConvSequence (decoder)
  
    - Conv1d or Conv1d + ReLU + Conv1d (final)

3. **[simpleconv.py](https://github.com/facebookresearch/brainmagick/blob/main/bm/models/simpleconv.py)**

  - SimpleConv

    - takes a sample of channels (subsampled_meg_channels)
    
    - ChannelDropout

    - ChannelMerger

    - Conv1d + activations (initial layer)

    - SubjectLayers (subject layer)

    - ta.transforms.Spectrogram (short-time fourier transform)

    - ScaledEmbedding (subject embedding)

    - ConvSequence (encoder)

    - DualPathRNN
  
    - Conv1d or Conv1d + ReLU + Conv1d (final)

4. **[features.py](https://github.com/facebookresearch/brainmagick/blob/main/bm/models/features.py)** - model to extract features

  - DeepMel(ConvSequence)

**Combined encoders:**

1. **[deep_mel.yaml](https://github.com/facebookresearch/brainmagick/blob/main/bm/conf/feature_model/deep_mel.yaml)** - calls **features.py** - feature model

2. **[convrnn.yaml](https://github.com/facebookresearch/brainmagick/blob/main/bm/conf/model/convrnn.yaml)** - calls **convrnn.py**

**Combined decoders:**

1. **[clip_conv.yaml](https://github.com/facebookresearch/brainmagick/blob/main/bm/conf/model/clip_conv.yaml/)** - calls **simpleconv.py** - default model

2. **[decoder_convrnn.yaml](https://github.com/facebookresearch/brainmagick/blob/main/bm/conf/model/decoder_convrnn.yaml)** - calls **convrnn.py**


Other: All the imports https://github.com/facebookresearch/brainmagick/blob/main/requirements.txt

# models [KS]

## replicated meta models

### brain encoder (replicated meta's convrnn)

In [None]:
class SubjectLayers(nn.Module): # learn how different are the subjects and based on that normalize the EEG
    def __init__(self, in_channels, out_channels, n_subjects):
        super().__init__()
        self.weights = nn.Parameter(torch.randn(n_subjects, in_channels, out_channels)) # initialize weights for each subject
        self.weights.data *= 1 / in_channels**0.5 # normalize weights

    def forward(self, x, n_subjects):
        _, C, D = self.weights.shape
        subject_weights = self.weights.gather(0, n_subjects.view(-1, 1, 1).expand(-1, C, D)) # select the appropriate weights for each subject in the batch
        transformed_eeg = torch.einsum("bct,bcd->bdt", x, subject_weights) # apply the subject-specific transformations

        return transformed_eeg

class ScaledEmbedding(nn.Module): # assign a unique vector to each subject (similar to positional embedding)

    def __init__(self, n_subjects, embedding_dim, scale):

        super().__init__()
        self.embedding = nn.Embedding(n_subjects, embedding_dim)
        self.embedding.weight.data /= scale
        self.scale = scale

    def forward(self, x):
        scaled_embedding = self.embedding(x) * self.scale
        return scaled_embedding

class LayerScale(nn.Module):
    def __init__(self, channels, init = 0.1, boost = 5.):
        super().__init__()
        self.scale = nn.Parameter(torch.zeros(channels, requires_grad=True))
        self.scale.data[:] = init / boost
        self.boost = boost

    def forward(self, x):
        return (self.boost * self.scale[:, None]) * x

class ConvSequence(nn.Module):

    def __init__(self,
                 channels = [16, 32, 64, 128],
                 kernel = 4, dilation_growth = 1, dilation_period = None,
                 stride = 2, dropout = 0.0, leakiness = 0.0,
                 groups = 1, decode = False, batch_norm = False,
                 dropout_input = 0, skip = False, scale = None,
                 rewrite = False, activation_on_last = True,
                 post_skip = False, glu = 0, glu_context = 0,
                 glu_glu = True, activation = None):

        super().__init__()
        dilation = 1
        channels = tuple(channels)
        self.skip = skip
        self.sequence = nn.ModuleList()
        self.glus = nn.ModuleList()
        if activation is None:
            activation = partial(nn.LeakyReLU, leakiness)
        Conv = nn.Conv1d if not decode else nn.ConvTranspose1d
        # build layers
        for k, (chin, chout) in enumerate(zip(channels[:-1], channels[1:])):
            layers: tp.List[nn.Module] = []
            is_last = k == len(channels) - 2

            # Set dropout for the input of the conv sequence if defined
            if k == 0 and dropout_input:
                assert 0 < dropout_input < 1
                layers.append(nn.Dropout(dropout_input))

            # conv layer
            if dilation_growth > 1:
                assert kernel % 2 != 0 # supports only odd kernel with dilation
            if dilation_period and (k % dilation_period) == 0:
                dilation = 1
            pad = kernel // 2 * dilation
            layers.append(Conv(chin, chout, kernel, stride, pad,
                               dilation=dilation, groups=groups if k > 0 else 1))
            dilation *= dilation_growth
            # non-linearity
            if activation_on_last or not is_last:
                if batch_norm:
                    layers.append(nn.BatchNorm1d(num_features=chout))
                layers.append(activation())
                if dropout:
                    layers.append(nn.Dropout(dropout))
                if rewrite:
                    layers += [nn.Conv1d(chout, chout, 1), nn.LeakyReLU(leakiness)]
            if chin == chout and skip:
                if scale is not None:
                    layers.append(LayerScale(chout, scale))
                if post_skip:
                    layers.append(Conv(chout, chout, 1, groups=chout, bias=False))

            self.sequence.append(nn.Sequential(*layers))
            if glu and (k + 1) % glu == 0:
                ch = 2 * chout if glu_glu else chout
                act = nn.GLU(dim=1) if glu_glu else activation()
                self.glus.append(
                    nn.Sequential(
                        nn.Conv1d(chout, ch, 1 + 2 * glu_context, padding=glu_context), act))
            else:
                self.glus.append(None)

    def forward(self, x):
        for module_idx, module in enumerate(self.sequence):
            old_x = x
            x = module(x)
            if self.skip and x.shape == old_x.shape:
                x = x + old_x
            glu = self.glus[module_idx]
            if glu is not None:
                x = glu(x)
        return x

class Attention(nn.Module): # scaled dot product with relative position encoding and local attention
    def __init__(self, channels, radius = 50, heads = 4):
        super().__init__()
        assert channels % heads == 0
        self.content = nn.Conv1d(channels, channels, 1)
        self.query = nn.Conv1d(channels, channels, 1)
        self.key = nn.Conv1d(channels, channels, 1)
        self.embedding = nn.Embedding(radius * 2 + 1, channels // heads)
        weight = self.embedding.weight.data
        weight[:] = weight.cumsum(0) / torch.arange(1, len(weight) + 1).float().view(-1, 1).sqrt()
        self.heads = heads
        self.radius = radius
        self.bn = nn.BatchNorm1d(channels)
        self.fc = nn.Conv1d(channels, channels, 1)
        self.scale = nn.Parameter(torch.full([channels], 0.1))

    def forward(self, x):

        def _split(y):
            return y.view(y.shape[0], self.heads, -1, y.shape[2])

        content = _split(self.content(x))
        query = _split(self.query(x))
        key = _split(self.key(x))

        batch_size, _, dim, length = content.shape

        dots = torch.einsum("bhct,bhcs->bhts", query, key) # first index `t` is query, second index `s` is key.

        steps = torch.arange(length, device=x.device)
        relative = (steps[:, None] - steps[None, :])
        embs = self.embedding.weight.gather(0, self.radius + relative.clamp_(-self.radius, self.radius).view(-1, 1).expand(-1, dim))
        embs = embs.view(length, length, -1)
        dots += 0.3 * torch.einsum("bhct,tsc->bhts", query, embs)
        dots = torch.where(
            relative.abs() <= self.radius, dots, torch.tensor(-float('inf')).to(embs))

        weights = torch.softmax(dots, dim=-1)
        out = torch.einsum("bhts,bhcs->bhct", weights, content)
        out += 0.3 * torch.einsum("bhts,tsc->bhct", weights, embs)
        out = out.reshape(batch_size, -1, length)
        out = F.relu(self.bn(self.fc(out))) * self.scale.view(1, -1, 1)
        return out

"""
# TEST 1
def test_subject_layers():
    batch_size = 2
    in_channels = 3
    out_channels = 2
    n_subjects = 1
    time_steps = 2
    eeg_data = torch.randn(batch_size, in_channels, time_steps)
    print(eeg_data)
    subjects = torch.randint(0, n_subjects, (batch_size,))
    print(subjects)
    subject_layers = SubjectLayers(in_channels, out_channels, n_subjects)
    output = subject_layers(eeg_data, subjects)
    expected_shape = (batch_size, out_channels, time_steps)
    assert output.shape == expected_shape, f"Output shape mismatch: expected {expected_shape}, got {output.shape}"
test_subject_layers()

# TEST 2
def scaled_embedding():
  n_subjects = 10
  embedding_dim = 5
  scale = 10
  scaled_embedding = ScaledEmbedding(n_subjects, embedding_dim, scale)
  subject_indices = torch.tensor([1, 2])
  embeddings = scaled_embedding(subject_indices)
  print(embeddings)
scaled_embedding()

# TEST 3
conv_sequence = ConvSequence(channels=[16, 32, 64, 128])
input_tensor = torch.randn(1, 16, 50) # (Batch Size, Channels, Length)
output = conv_sequence(input_tensor)
output.shape
#summary(conv_sequence, input_size=(16, 50))
"""

'\n# TEST 1\ndef test_subject_layers():\n    batch_size = 2\n    in_channels = 3\n    out_channels = 2\n    n_subjects = 1\n    time_steps = 2\n    eeg_data = torch.randn(batch_size, in_channels, time_steps) \n    print(eeg_data)\n    subjects = torch.randint(0, n_subjects, (batch_size,))\n    print(subjects)\n    subject_layers = SubjectLayers(in_channels, out_channels, n_subjects)\n    output = subject_layers(eeg_data, subjects)\n    expected_shape = (batch_size, out_channels, time_steps)\n    assert output.shape == expected_shape, f"Output shape mismatch: expected {expected_shape}, got {output.shape}"\ntest_subject_layers()\n\n# TEST 2\ndef scaled_embedding():\n  n_subjects = 10\n  embedding_dim = 5    \n  scale = 10      \n  scaled_embedding = ScaledEmbedding(n_subjects, embedding_dim, scale)\n  subject_indices = torch.tensor([1, 2]) \n  embeddings = scaled_embedding(subject_indices)\n  print(embeddings)\nscaled_embedding()\n\n# TEST 3\nconv_sequence = ConvSequence(channels=[16, 32

In [None]:
class EEG_Encoder(nn.Module):

    def __init__(self,

                 in_channels = 64, n_subjects = 30,

                 # subject layers
                 out_channels = 128,

                 # conv
                 conv_chanels = [16, 32, 64, 128],
                 kernel = 4, stride = 2, conv_dropout = 0,
                 batch_norm = False, dropout_input = 0,
                 leakiness = 0,

                 # lstm
                 hidden_size = 128, lstm_layers = 4, lstm_dropout = 0.1,

                 # attention
                 attention_heads = 4, subject_dim = 64, embedding_scale = 1.0):

        super().__init__()

        # subject layers (normalization across subjects)
        self.subject_layers = SubjectLayers(in_channels, out_channels, n_subjects)

        # scaled embedding (optional)
        # self.subject_embedding = ScaledEmbedding(n_subjects, subject_dim, embedding_scale)

        # convsequence
        self.convs = ConvSequence(channels=conv_chanels, kernel = kernel_size,
                                  stride = stride, dropout = conv_dropout,
                                  batch_norm = batch_norm, leakiness = leakiness,
                                  dropout_input = dropout_input)

        # lstm
        self.lstm = nn.LSTM(input_size    = out_channels,
                            hidden_size   = hidden_size//2,
                            num_layers    = lstm_layers,
                            dropout       = lstm_dropout,
                            bidirectional = True,
                            batch_first    = True)

        # attention
        self.attention = Attention(hidden_size, heads=attention_heads)

        # final linear layer
        self.finalconv1 = nn.Conv1d(hidden_size, out_channels, kernel_size = 1)

    def forward(self, eeg_inputs): # to pass as an additional paramater n_subjects=10

        # subject layers
        # subject_indices = torch.arange(n_subjects).to(eeg_inputs.device)
        # normalized_eeg_inputs = self.subject_layers(eeg_inputs, subject_indices)

        # scaled embedding // scaled_embedding = self.subject_embedding(subject_indices)

        # convsequence
        print('eeg_inputs before convs', eeg_inputs.shape)
        out = self.convs(eeg_inputs)
        print('eeg_inputs after convs', out.shape)
        print("After convs:", torch.isnan(out).any())

        # lstm
        out = out.permute(2, 0, 1)
        out, _ = self.lstm(out)
        out = out.permute(1, 2, 0)
        print("After LSTM:", torch.isnan(out).any())

        # attention
        out = out + self.attention(out)
        print("After Attention:", torch.isnan(out).any())

        # decoder
        # out = self.decoder(out)

        # final
        out = self.finalconv1(out)
        print("Final Output:", torch.isnan(out).any())

        return out

# TEST
in_channels = 64
n_subjects = 30
out_channels = 128
conv_channels = [64, 32, 64, 128]
kernel_size = 4
stride = 2
conv_dropout = 0
batch_norm = False
dropout_input = 0
leakiness = 0
hidden_size = 128
lstm_layers = 4
lstm_dropout = 0.1
attention_heads = 4
subject_dim = 64
embedding_scale = 1.0

eeg_encoder = EEG_Encoder(in_channels, n_subjects, out_channels, conv_channels, kernel_size, stride,
                          conv_dropout, batch_norm, dropout_input, leakiness, hidden_size, lstm_layers,
                          lstm_dropout, attention_heads, subject_dim, embedding_scale)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
eeg_encoder.to(device)
batch_size = 2
in_channels = 64
sequence_length = 128
x_sample = torch.rand(batch_size, in_channels, sequence_length)
summary(eeg_encoder, x_sample.to(device))

eeg_inputs before convs torch.Size([2, 64, 128])
eeg_inputs after convs torch.Size([2, 128, 17])
After convs: tensor(False)
After LSTM: tensor(False)
After Attention: tensor(False)
Final Output: tensor(False)
                                 Kernel Shape  Output Shape    Params  \
Layer                                                                   
0_convs.sequence.0.Conv1d_0       [64, 32, 4]   [2, 32, 65]    8.224k   
1_convs.sequence.0.LeakyReLU_1              -   [2, 32, 65]         -   
2_convs.sequence.1.Conv1d_0       [32, 64, 4]   [2, 64, 33]    8.256k   
3_convs.sequence.1.LeakyReLU_1              -   [2, 64, 33]         -   
4_convs.sequence.2.Conv1d_0      [64, 128, 4]  [2, 128, 17]   32.896k   
5_convs.sequence.2.LeakyReLU_1              -  [2, 128, 17]         -   
6_lstm                                      -  [17, 2, 128]  397.312k   
7_attention.Conv1d_content      [128, 128, 1]  [2, 128, 17]   16.512k   
8_attention.Conv1d_query        [128, 128, 1]  [2, 128, 17]  

  df_sum = df.sum()


Unnamed: 0_level_0,Kernel Shape,Output Shape,Params,Mult-Adds
Layer,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
0_convs.sequence.0.Conv1d_0,"[64, 32, 4]","[2, 32, 65]",8224.0,532480.0
1_convs.sequence.0.LeakyReLU_1,-,"[2, 32, 65]",,
2_convs.sequence.1.Conv1d_0,"[32, 64, 4]","[2, 64, 33]",8256.0,270336.0
3_convs.sequence.1.LeakyReLU_1,-,"[2, 64, 33]",,
4_convs.sequence.2.Conv1d_0,"[64, 128, 4]","[2, 128, 17]",32896.0,557056.0
5_convs.sequence.2.LeakyReLU_1,-,"[2, 128, 17]",,
6_lstm,-,"[17, 2, 128]",397312.0,393216.0
7_attention.Conv1d_content,"[128, 128, 1]","[2, 128, 17]",16512.0,278528.0
8_attention.Conv1d_query,"[128, 128, 1]","[2, 128, 17]",16512.0,278528.0
9_attention.Conv1d_key,"[128, 128, 1]","[2, 128, 17]",16512.0,278528.0


### audio encoder (replicated meta's deepmel)

In [None]:
# Speech model

class LayerScale(nn.Module):
    def __init__(self, channels: int, init: float = 0.1, boost: float = 5.):
        super().__init__()
        self.scale = nn.Parameter(torch.zeros(channels, requires_grad=True))
        self.scale.data[:] = init / boost
        self.boost = boost

    def forward(self, x):
        return (self.boost * self.scale[:, None]) * x

class ConvSequence(nn.Module):

    def __init__(self,
                 channels: tp.Sequence[int],
                 kernel: int = 4,
                 dilation_growth: int = 1,
                 dilation_period: tp.Optional[int] = None,
                 stride: int = 2,
                 dropout: float = 0.0,
                 leakiness: float = 0.0,
                 groups: int = 1,
                 decode: bool = False,
                 batch_norm: bool = False,
                 dropout_input: float = 0,
                 skip: bool = False,
                 scale: tp.Optional[float] = None,
                 rewrite: bool = False,
                 activation_on_last: bool = True,
                 post_skip: bool = False,
                 glu: int = 0,
                 glu_context: int = 0,
                 glu_glu: bool = True,
                 activation: tp.Any = None) -> None:
        super().__init__()
        dilation = 1
        channels = tuple(channels)
        self.skip = skip
        self.sequence = nn.ModuleList()
        self.glus = nn.ModuleList()
        if activation is None:
            activation = partial(nn.LeakyReLU, leakiness)
        Conv = nn.Conv1d if not decode else nn.ConvTranspose1d
        # build layers
        for k, (chin, chout) in enumerate(zip(channels[:-1], channels[1:])):
            layers: tp.List[nn.Module] = []
            is_last = k == len(channels) - 2

            # Set dropout for the input of the conv sequence if defined
            if k == 0 and dropout_input:
                assert 0 < dropout_input < 1
                layers.append(nn.Dropout(dropout_input))

            # conv layer
            if dilation_growth > 1:
                assert kernel % 2 != 0, "Supports only odd kernel with dilation for now"
            if dilation_period and (k % dilation_period) == 0:
                dilation = 1
            pad = kernel // 2 * dilation
            layers.append(Conv(chin, chout, kernel, stride, pad,
                               dilation=dilation, groups=groups if k > 0 else 1))
            dilation *= dilation_growth
            # non-linearity
            if activation_on_last or not is_last:
                if batch_norm:
                    layers.append(nn.BatchNorm1d(num_features=chout))
                layers.append(activation())
                if dropout:
                    layers.append(nn.Dropout(dropout))
                if rewrite:
                    layers += [nn.Conv1d(chout, chout, 1), nn.LeakyReLU(leakiness)]
                    # layers += [nn.Conv1d(chout, 2 * chout, 1), nn.GLU(dim=1)]
            if chin == chout and skip:
                if scale is not None:
                    layers.append(LayerScale(chout, scale))
                if post_skip:
                    layers.append(Conv(chout, chout, 1, groups=chout, bias=False))

            self.sequence.append(nn.Sequential(*layers))
            if glu and (k + 1) % glu == 0:
                ch = 2 * chout if glu_glu else chout
                act = nn.GLU(dim=1) if glu_glu else activation()
                self.glus.append(
                    nn.Sequential(
                        nn.Conv1d(chout, ch, 1 + 2 * glu_context, padding=glu_context), act))
            else:
                self.glus.append(None)

    def forward(self, x: tp.Any) -> tp.Any:
        for module_idx, module in enumerate(self.sequence):
            old_x = x
            x = module(x)
            if self.skip and x.shape == old_x.shape:
                x = x + old_x
            glu = self.glus[module_idx]
            if glu is not None:
                x = glu(x)
        return x

class DeepMel(ConvSequence):
    """DeepMel model that extracts features from the Mel spectrogram.

    Parameters
    ----------
    n_in_channels :
        Number of input channels.
    n_hidden_channels :
        Number of channels in hidden layers.
    n_hidden_layers :
        Number of hidden layers.
    n_out_channels :
        Number of output channels.
    kwargs:
        Additional keyword arguments to pass to ConvSequence.
    """
    def __init__(self,
                 n_in_channels: int,
                 n_hidden_channels: int,
                 n_hidden_layers: int,
                 n_out_channels: int, **kwargs):

        channels = \
            [n_in_channels] + [n_hidden_channels] * (n_hidden_layers - 1) + [n_out_channels]

        super().__init__(channels, **kwargs)

# Test the model
n_in_channels = 16
n_hidden_channels = 32
n_hidden_layers = 3
n_out_channels = 64
model = DeepMel(n_in_channels, n_hidden_channels, n_hidden_layers, n_out_channels)
input_size = (n_in_channels, 128)  # Example input size (channels, sequence length)
summary(model, input_size=input_size, device="cpu")

### brain-audio model (replicated meta)

In [None]:
class BrainAudioModel(torch.nn.Module):

  def __init__(# brain encoding
               self, in_channels, model_chout

               # audio encoding

               # brain to audio probabilities

               ):

    super().__init__()

    # brain encoding
    self.brain_encoding = EEG_Encoder(in_channels=in_channels,
                                      n_subjects=30,
                                      out_channels=model_chout,
                                      conv_chanels=[64, 32, 64, 128],
                                      kernel=4,
                                      stride=2,
                                      conv_dropout=0,
                                      batch_norm=False,
                                      dropout_input=0,
                                      leakiness=0,
                                      hidden_size=128,
                                      lstm_layers=4,
                                      lstm_dropout=0.1,
                                      attention_heads=4,
                                      subject_dim=64,
                                      embedding_scale=1.0)

    # audio encoding
    #self.audio_encoding = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base-960h") # https://huggingface.co/facebook/wav2vec2-base-960h

    # brain to audio probabilities
    self.clip = ClipLoss(linear=None, twin=True, pool=False, tmin=None, tmax=None,
                         tmin_train=None, tmax_train=None, dset_args=None, center=False)

    # linear layers to match the dimensions
    self.linear1 = nn.Linear(128, model_chout)

  """
  def freeze_audio_embedding(self): # we need to call it after initializing the model model.freeze_audio_embedding()
    for param in self.audio_embedding.parameters():
      param.requires_grad = False
  """

  def forward(self, raw_eeg_inputs, audio_hidden_states):  # raw_audio_inputs

    # brain encoding
    brain_encoder_outputs = self.brain_encoding(raw_eeg_inputs)
    print('brain_encoder_outputs', brain_encoder_outputs.shape)

    print('audio_hidden_states', audio_hidden_states.shape)

    # audio encoding (optional, could be passed as a parameter to forward)
    #audio_encoder_outputs = self.audio_encoding(raw_audio_inputs).last_hidden_state
    #audio_encoder_outputs = self.linear(audio_encoder_outputs)
    audio_hidden_states_processed = self.linear1(audio_hidden_states)
    print('audio_hidden_states_processed', audio_hidden_states_processed.shape)

    # brain to audio probabilities
    brain_encoder_outputs_reshaped = brain_encoder_outputs.permute(0, 2, 1) # reshaping to [B, C, T]
    audio_hidden_states_reshaped = audio_hidden_states_processed.permute(0, 2, 1) # reshaping to [B, C, T]
    probabilities = self.clip.get_probabilities(brain_encoder_outputs_reshaped, audio_hidden_states_reshaped)
    print('probabilities', probabilities)
    return probabilities

# TEST
in_channels = 64  # EEG input channels
model_chout = 17  # output channels
model = BrainAudioModel(in_channels, model_chout)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
eeg_input_size = (2, in_channels, 128)  # Batch size, channels, sequence length
audio_hidden_state_size = (2, model_chout, 128)  # Assuming the same sequence length and batch size
eeg_dummy_input = torch.randn(eeg_input_size).to(device)
audio_hidden_state_dummy = torch.randn(audio_hidden_state_size).to(device)
summary(model, eeg_dummy_input, audio_hidden_state_dummy)

eeg_inputs before convs torch.Size([2, 64, 128])
eeg_inputs after convs torch.Size([2, 128, 17])
After convs: tensor(False)
After LSTM: tensor(True)
After Attention: tensor(True)
Final Output: tensor(True)
brain_encoder_outputs torch.Size([2, 17, 17])
audio_hidden_states torch.Size([2, 17, 128])
audio_hidden_states_processed torch.Size([2, 17, 17])
probabilities tensor([[nan, nan],
        [nan, nan]])
                                                Kernel Shape  Output Shape  \
Layer                                                                        
0_brain_encoding.convs.sequence.0.Conv1d_0       [64, 32, 4]   [2, 32, 65]   
1_brain_encoding.convs.sequence.0.LeakyReLU_1              -   [2, 32, 65]   
2_brain_encoding.convs.sequence.1.Conv1d_0       [32, 64, 4]   [2, 64, 33]   
3_brain_encoding.convs.sequence.1.LeakyReLU_1              -   [2, 64, 33]   
4_brain_encoding.convs.sequence.2.Conv1d_0      [64, 128, 4]  [2, 128, 17]   
5_brain_encoding.convs.sequence.2.LeakyReLU_1   

  df_sum = df.sum()


Unnamed: 0_level_0,Kernel Shape,Output Shape,Params,Mult-Adds
Layer,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
0_brain_encoding.convs.sequence.0.Conv1d_0,"[64, 32, 4]","[2, 32, 65]",8224.0,532480.0
1_brain_encoding.convs.sequence.0.LeakyReLU_1,-,"[2, 32, 65]",,
2_brain_encoding.convs.sequence.1.Conv1d_0,"[32, 64, 4]","[2, 64, 33]",8256.0,270336.0
3_brain_encoding.convs.sequence.1.LeakyReLU_1,-,"[2, 64, 33]",,
4_brain_encoding.convs.sequence.2.Conv1d_0,"[64, 128, 4]","[2, 128, 17]",32896.0,557056.0
5_brain_encoding.convs.sequence.2.LeakyReLU_1,-,"[2, 128, 17]",,
6_brain_encoding.LSTM_lstm,-,"[17, 2, 128]",340480.0,336384.0
7_brain_encoding.attention.Conv1d_content,"[128, 128, 1]","[2, 128, 17]",16512.0,278528.0
8_brain_encoding.attention.Conv1d_query,"[128, 128, 1]","[2, 128, 17]",16512.0,278528.0
9_brain_encoding.attention.Conv1d_key,"[128, 128, 1]","[2, 128, 17]",16512.0,278528.0


## brain encoder (our: cnn + lstm + transformer)

In [243]:
class PermuteBlock(torch.nn.Module):

    def forward(self, x):
        return x.transpose(1, 2)

class PositionalEncoding(nn.Module):

    def __init__(self, input_size, max_input_seq_len, dropout):

        super(PositionalEncoding, self).__init__()

        self.dropout = nn.Dropout(dropout)

        pe = torch.zeros(max_input_seq_len, input_size)
        position = torch.arange(0, max_input_seq_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, input_size, 2).float() * (-torch.log(torch.tensor(10000.0)) / input_size))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)

    def forward(self, x):

        x = x + self.pe[:, :x.size(1), :]
        x = self.dropout(x)

        return x

# Transformer  --------------------------------------------------------------------

class TransformerEncoder(torch.nn.Module):

    def __init__(self, tf_input_size, num_heads, dropout):

        super().__init__()

        self.self_attn    = torch.nn.MultiheadAttention(tf_input_size, num_heads)

        self.ln           = nn.LayerNorm(tf_input_size)

        self.feed_forward = nn.Sequential(nn.Linear(tf_input_size, tf_input_size * 4),
                                          nn.ReLU(),
                                          nn.Dropout(dropout),
                                          nn.Linear(tf_input_size * 4, tf_input_size))

        self.feed_forward_simple = nn.Linear(tf_input_size, tf_input_size)

    def forward(self, x, mask=None):

        # attention
        attention_context, attention_weights = self.self_attn(x, x, x, mask)

        # residual connection 1
        res1 = x + attention_context

        # layer norm 1
        ln1 = self.ln(res1)

        # feed forward
        ff = self.feed_forward(ln1)

        # residual connection 2
        res2 = ln1 + ff

        # layer norm 2
        ln2 = self.ln(res2)

        return ln2

# Listener: CNN -> BLSTM -> positional -> transformer -------------------------------------------

class EEG_Transformer_Encoder(torch.nn.Module):

    def __init__(self, input_size, lstm_layers, listener_hidden_size,
                 num_heads_listener, tf_blocks_listener, dropout):

        super().__init__()

        # embedding 1: cnn
        self.cnn = torch.nn.Sequential(

            PermuteBlock(),

            # conv 1
            nn.Conv1d(input_size, listener_hidden_size // 4, kernel_size=2, stride=2),
            nn.BatchNorm1d(listener_hidden_size // 4),
            nn.GELU(),

            # conv 2
            nn.Conv1d(listener_hidden_size // 4, listener_hidden_size // 2, kernel_size=2, stride=2),
            nn.BatchNorm1d(listener_hidden_size // 2),
            nn.GELU(),

            # conv 3
            nn.Conv1d(listener_hidden_size // 2, listener_hidden_size, kernel_size=2, stride=2),
            nn.BatchNorm1d(listener_hidden_size),
            nn.GELU(),

            PermuteBlock())

        # embedding 2: blstm
        self.blstm          = nn.LSTM(input_size    = listener_hidden_size,
                                      hidden_size   = listener_hidden_size // 2,
                                      num_layers    = lstm_layers,
                                      batch_first   = True,
                                      bidirectional = True,
                                      dropout=0)

        # embedding 3: positional
        self.max_input_seq_len   = 31914 // num_heads_listener
        self.positional_encoding = PositionalEncoding(input_size=listener_hidden_size,
                                                      max_input_seq_len=self.max_input_seq_len,
                                                      dropout=dropout)

        # transformer (same input-output shape: tf_input_size --> tf_input_size)
        transformer_blocks = []
        for i in range(tf_blocks_listener):
            transformer_blocks.append(
                TransformerEncoder(
                    tf_input_size=listener_hidden_size,
                    num_heads=num_heads_listener,
                    dropout=dropout))
        self.transformer_encoder = nn.Sequential(*transformer_blocks)

    def forward(self, x, x_len): # x = raw_eeg, x_len = raw_eeg_len

        # embedding 1: cnn
        x = self.cnn(x)
        x_len = torch.clamp(x_len, max=x.shape[1])

        # embedding 2: blstm
        x_packed                = pack_padded_sequence(x, x_len.cpu(), batch_first=True, enforce_sorted=False)
        lstm_out, _             = self.blstm(x_packed)
        output, output_lengths  = pad_packed_sequence(lstm_out, batch_first=True)

        # embedding 3: positional
        output  = self.positional_encoding(output)

        # transformer
        output  = self.transformer_encoder(output)

        return output, output_lengths


# TEST
def test_transformer_listener():
    model = EEG_Transformer_Encoder(
        input_size=62,
        lstm_layers=1,
        listener_hidden_size=512,
        num_heads_listener=2,
        tf_blocks_listener=2,
        dropout=0.0).to(DEVICE)

    # Sample input
    batch_size = 2
    seq_len = 1000
    input_dim = 62
    x_sample = torch.rand(batch_size, seq_len, input_dim).to(DEVICE)
    x_lengths = torch.full((batch_size,), seq_len, dtype=torch.int64).to(DEVICE)

    # Forward pass
    output, output_lengths = model(x_sample, x_lengths)

    # Check output dimensions
    assert output.size(0) == batch_size
    assert output.size(2) == 512
    assert all(length <= seq_len for length in output_lengths)

    summary(model, x_sample, x_lengths)
    print("All tests passed.")

test_transformer_listener()

                                                  Kernel Shape  \
Layer                                                            
0_cnn.PermuteBlock_0                                         -   
1_cnn.Conv1d_1                                    [62, 128, 2]   
2_cnn.BatchNorm1d_2                                      [128]   
3_cnn.GELU_3                                                 -   
4_cnn.Conv1d_4                                   [128, 256, 2]   
5_cnn.BatchNorm1d_5                                      [256]   
6_cnn.GELU_6                                                 -   
7_cnn.Conv1d_7                                   [256, 512, 2]   
8_cnn.BatchNorm1d_8                                      [512]   
9_cnn.GELU_9                                                 -   
10_cnn.PermuteBlock_10                                       -   
11_blstm                                                     -   
12_positional_encoding.Dropout_dropout                       -   
13_transfo

## clip loss (copied from meta)

In [389]:
class ClipLoss(torch.nn.Module):

    def __init__(self, linear=None, twin=True, pool=False, tmin=None, tmax=None,
                 tmin_train=None, tmax_train=None, dset_args=None, center=False):
        super().__init__()
        self.linear = None
        self.pool = pool
        self.center = center
        if linear is not None:
            self.linear_est = torch.nn.LazyLinear(linear)
            if twin:
                self.linear_gt = self.linear_est
            else:
                self.linear_gt = torch.nn.LazyLinear(linear)
        self.tmin = tmin
        self.tmax = tmax
        self.tmin_train = tmin_train
        self.tmax_train = tmax_train
        self.dset_args = dset_args

    def trim_samples(self, estimates, candidates):
        """Given estimates that is [B1, C, T] and candidates
        which is [B2, C, T], return estimates_trim of size [B1, C, T']
        and candidates_trim of size [B2, C, T'], such that T'
        corresponds to the samples between [self.tmin, self.tmax]
        """
        if self.training and (self.tmin_train is not None or self.tmax_train is not None):
            tmin, tmax = self.tmin_train, self.tmax_train
        else:
            tmin, tmax = self.tmin, self.tmax
        if (tmin is not None) or (tmax is not None):
            assert self.dset_args is not None
            assert self.dset_args.tmin is not None
            dset_tmin = self.dset_args.tmin
        if tmin is None:
            trim_min = 0
        else:
            assert tmin >= dset_tmin, 'clip.tmin should be above dset.tmin'
            trim_min = int((-dset_tmin + tmin) * self.dset_args.sample_rate)
        if tmax is None:
            trim_max = estimates.shape[-1]
        else:
            trim_max = int((-dset_tmin + tmax) * self.dset_args.sample_rate)
        estimates_trim = estimates[..., trim_min:trim_max]
        candidates_trim = candidates[..., trim_min:trim_max]
        return estimates_trim, candidates_trim

    def get_scores(self, estimates: torch.Tensor, candidates: torch.Tensor):
        """Given estimates that is [B, C, T] and candidates
        which is [B', C, T], return a [B, B'] matrix of scores of matching.
        """
        estimates, candidates = self.trim_samples(estimates, candidates)
        if self.linear:
            estimates = self.linear_est(estimates)
            candidates = self.linear_gt(candidates)
        if self.pool:
            estimates = estimates.mean(dim=2, keepdim=True)
            candidates = candidates.mean(dim=2, keepdim=True)
        if self.center:
            estimates = estimates - estimates.mean(dim=(1, 2), keepdim=True)
            candidates = candidates - candidates.mean(dim=(1, 2), keepdim=True)
        inv_norms = 1 / (1e-8 + candidates.norm(dim=(1, 2), p=2))
        scores = torch.einsum("bct,oct,o->bo", estimates, candidates, inv_norms)
        return scores

    def get_probabilities(self, estimates, candidates):
        """Given estimates that is [B, C, T] and candidates
        which is [B', C, T], return a [B, B'] matrix of probabilities of matching.
        """
        scores = self.get_scores(estimates, candidates)
        probabilities = F.softmax(scores, dim=1)
        return probabilities

    def forward(self, estimate, candidate, mask=None):
        """Warning: estimate and candidate are not symmetrical.
        If estimate of shape [B, C, T] and candidate of size [B', C, T]
        with B'>=B, the first B samples of candidate are targets, while
        the remaining B'-B samples of candidate are only used as negatives.
        """
        assert estimate.size(0) <= candidate.size(0), "need at least as many targets as estimates"
        scores = self.get_scores(estimate, candidate)
        target = torch.arange(len(scores), device=estimate.device)
        loss = F.cross_entropy(scores, target)
        return loss

# TEST
batch_size = 3
num_channels = 4
time_steps = 5
estimates = torch.randn(batch_size, num_channels, time_steps)
candidates = torch.randn(batch_size, num_channels, time_steps)
clip_loss = ClipLoss()

probabilities = clip_loss.get_probabilities(estimates, candidates) # --> output in the model's forward function
loss = clip_loss(estimates, candidates) # --> use a loss

print(probabilities)
print(loss)

### Each row in the probability matrix corresponds to a set of EEG data,
#   and each column corresponds to a set of audio data.
#   The values in the matrix are probabilities that indicate
#   how likely it is that a given set of EEG data matches a given set of audio data.

tensor([[0.8012, 0.1432, 0.0556],
        [0.3515, 0.3947, 0.2538],
        [0.2379, 0.5154, 0.2468]])
tensor(0.8502)


## brain-audio model (our)

In [390]:
class BrainAudioModel(torch.nn.Module):

  def __init__(self, input_size, lstm_layers, listener_hidden_size,
               num_heads_listener, tf_blocks_listener, dropout,
               linear_output_size):

    super().__init__()

    # brain encoding
    self.brain_encoding = EEG_Transformer_Encoder(input_size, lstm_layers,
                                                  listener_hidden_size,
                                                  num_heads_listener, tf_blocks_listener,
                                                  dropout)

    # brain to audio probabilities
    self.clip = ClipLoss(linear=linear_output_size)

    # linear layer to match the dimensions if needed
    self.linear1 = nn.Linear(listener_hidden_size, linear_output_size)

  def forward(self, raw_eeg_inputs, raw_eeg_inputs_len, embedded_audio_inputs = None):

    print('\nraw_eeg_inputs.shape', raw_eeg_inputs.shape) # [B, T, C]
    print('embedded_audio_inputs.shape', embedded_audio_inputs.shape) # [B, T, C]

    # brain encoding
    brain_encoder_outputs, _ = self.brain_encoding(raw_eeg_inputs, raw_eeg_inputs_len)
    print('\nbrain_encoder_outputs.shape', brain_encoder_outputs.shape)

    brain_encoder_outputs = self.linear1(brain_encoder_outputs)

    # brain to audio probabilities
    brain_encoder_outputs_reshaped = brain_encoder_outputs.permute(0, 2, 1)  # reshaping to [B, C, T]
    audio_hidden_states_reshaped = embedded_audio_inputs.permute(0, 2, 1)  # reshaping to [B, C, T]

    print('\nbrain_encoder_outputs_reshaped.shape', brain_encoder_outputs_reshaped.shape)
    print('audio_hidden_states_reshaped.shape', audio_hidden_states_reshaped.shape)

    probabilities = self.clip.get_probabilities(brain_encoder_outputs_reshaped, audio_hidden_states_reshaped)
    print('\nprobabilities', probabilities)
    print('\nprobabilities.shape', probabilities.shape)

    return probabilities, brain_encoder_outputs_reshaped, audio_hidden_states_reshaped

# TEST
def test_combined():
    model = BrainAudioModel(input_size=62, lstm_layers=2,
                            listener_hidden_size=256,
                            num_heads_listener=8,
                            tf_blocks_listener=4, dropout=0.1,
                            linear_output_size=1024).to(DEVICE)
    batch_size = 2
    seq_len_eeg = 1000
    input_dim_eeg = 62

    seq_len_audio = 2000
    input_dim_audio = 1024

    x_sample = torch.rand(batch_size, seq_len_eeg, input_dim_eeg).to(DEVICE)
    x_lengths = torch.full((batch_size,), seq_len_eeg, dtype=torch.int64).to(DEVICE)
    y_sample = torch.rand(batch_size, seq_len_audio, input_dim_audio).to(DEVICE)

    # Forward pass
    probabilities = model(x_sample, x_lengths, y_sample)

    # Check output dimensions

    summary(model, x_sample, x_lengths, y_sample)

test_combined()


raw_eeg_inputs.shape torch.Size([2, 1000, 62])
embedded_audio_inputs.shape torch.Size([2, 2000, 1024])

brain_encoder_outputs.shape torch.Size([2, 125, 256])

brain_encoder_outputs_reshaped.shape torch.Size([2, 1024, 125])
audio_hidden_states_reshaped.shape torch.Size([2, 1024, 2000])

probabilities tensor([[0.6911, 0.3089],
        [0.6605, 0.3395]], grad_fn=<SoftmaxBackward0>)

probabilities.shape torch.Size([2, 2])

raw_eeg_inputs.shape torch.Size([2, 1000, 62])
embedded_audio_inputs.shape torch.Size([2, 2000, 1024])

brain_encoder_outputs.shape torch.Size([2, 125, 256])

brain_encoder_outputs_reshaped.shape torch.Size([2, 1024, 125])
audio_hidden_states_reshaped.shape torch.Size([2, 1024, 2000])

probabilities tensor([[0.6723, 0.3277],
        [0.7041, 0.2959]])

probabilities.shape torch.Size([2, 2])
                                                     Kernel Shape  \
Layer                                                               
0_brain_encoding.cnn.PermuteBlock_0         

## audio-to-word decoder & wer metrics

In [391]:
# Load the pre-trained model
wav2vec_path = '/content/gdrive/MyDrive/11785-IDLf23/Final_project/pretrained-models/'
wav2vec_final_layer = torch.load(wav2vec_path + 'wav2vec2-final-layer.pkl').to(DEVICE)
wav2vec_processor = torch.load(wav2vec_path + 'wav2vec2-processor.pkl')

# Decode predictions
def decode_predictions(predictions):

    logits = wav2vec_final_layer(predictions)
    predicted_ids = logits.argmax(dim=-1)
    transcriptions = wav2vec_processor.batch_decode(predicted_ids)

    return transcriptions

# Calculate the evaluation metric
def decode_predictions_and_evaluate(predictions, targets):

    batch_size = predictions.size(0)
    total_wer_lev = 0.0
    total_correct_general = 0
    total_words_general = 0
    total_correct_vocab = 0
    total_words_vocab = 0

    for i in range(batch_size):

        # Decode logits of predictions and targets into textual transcriptions
        pred_words = decode_predictions(predictions[i].unsqueeze(0))
        target_words = decode_predictions(targets[i].unsqueeze(0))

        # Ensure pred_words and target_words are lists of strings
        if not isinstance(pred_words, list):
            pred_words = [pred_words]
        if not isinstance(target_words, list):
            target_words = [target_words]

        # METRIC 1: WER Levenshtein -------------------------------------------------------------------------------
        """
        Average Levenshtein distance-based WER across the batch.
        It indicates the average # of single-character edits (insertions, deletions, substitutions)
        required to change the predicted sentences into the target sentences,
        normalized by the number of words in the target sentences.
        A value of 2.25 suggests that, on average, about 2 to 3 edits are needed per word to correct the predictions.
        """
        wer_lev = lev.distance(' '.join(pred_words), ' '.join(target_words)) / max(len(target_words), 1)
        total_wer_lev += wer_lev

        # METRIC 2: WER General -------------------------------------------------------------------------------
        """
        WER General is calculated as the simple proportion of correctly identified words,
        where both position and order are crucial.
        WER General of 75% indicates that 75% of the words in the predictions were incorrect
        when compared to the target words, taking into account the exact sequence in which they appear.
        """
        correct_general = sum(pw == tw for pw, tw in zip(pred_words, target_words))
        total_correct_general += correct_general
        total_words_general += len(target_words)

        # METRIC 3: WER Vocab -------------------------------------------------------------------------------
        """
        This metric calculates the proportion of words in the predictions that are present in the target vocabulary,
        regardless of their position or order.
        A WER of 75% here indicates that 75% of the words in the predictions were not found in the target vocabulary.
        This metric is more lenient, focusing on the presence of predicted words within the overall pool of words
        used in the targets, without considering their specific sequence or placement.
        """
        vocabulary = set(target_words)
        correct_vocab = sum(pw in vocabulary for pw in pred_words)
        total_correct_vocab += correct_vocab
        total_words_vocab += len(pred_words)

    avg_wer_lev = total_wer_lev / batch_size
    wer_general = (1 - total_correct_general / total_words_general) * 100
    accuracy_general = 100 - wer_general
    wer_vocab = (1 - total_correct_vocab / total_words_vocab) * 100
    accuracy_vocab = 100 - wer_vocab

    return avg_wer_lev, wer_general, accuracy_general, wer_vocab, accuracy_vocab

# TEST -------------------------------------------------------------------------------
dummy_targets = torch.randn(10, 1024)
dummy_predictions = torch.randn(10, 1024)
decoded_targets = decode_predictions(dummy_targets)
decoded_predictions = decode_predictions(dummy_predictions)
avg_wer_lev, wer_general, accuracy_general, wer_vocab, accuracy_vocab = decode_predictions_and_evaluate(dummy_predictions, dummy_targets)
#target_words = ["hello", "world", "test", "laptop"]
#pred_words = ["hello", "test", "deep learning", "11765"]
#print("dummy_targets\n", dummy_targets)
#print("dummy_predictions\n", dummy_predictions)
#print("decoded_targets", decoded_targets)
#print("decoded_predictions", decoded_predictions)
#print("avg_wer_lev", avg_wer_lev)
#print("wer_general", wer_general)
#print("accuracy_general", accuracy_general)
#print("wer_vocab", wer_vocab)
#print("accuracy_vocab", accuracy_vocab)

# training and validation

In [392]:
def plot_attention(attention):
    plt.clf()
    seaborn.heatmap(attention, cmap='GnBu')
    plt.show()

def save_model(model, optimizer, scheduler, metric, epoch, path):
    torch.save(
        {'model_state_dict'         : model.state_dict(),
         'optimizer_state_dict'     : optimizer.state_dict(),
         'scheduler_state_dict'     : scheduler.state_dict(),
         metric[0]                  : metric[1],
         'epoch'                    : epoch},
         path)

def load_model(best_path, epoch_path, model, mode= 'best', metric= 'valid_acc', optimizer= None, scheduler= None):

    if mode == 'best':
        checkpoint  = torch.load(best_path)
        print("Loading best checkpoint: ", checkpoint[metric])
    else:
        checkpoint  = torch.load(epoch_path)
        print("Loading epoch checkpoint: ", checkpoint[metric])

    model.load_state_dict(checkpoint['model_state_dict'], strict= False)

    if optimizer != None:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        #optimizer.param_groups[0]['lr'] = 1.5e-3
        optimizer.param_groups[0]['weight_decay'] = 1e-5
    if scheduler != None:
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

    epoch   = checkpoint['epoch']
    metric  = torch.load(best_path)[metric]

    return [model, optimizer, scheduler, epoch, metric]

class TimeElapsed():
    def __init__(self):
        self.start  = -1

    def time_elapsed(self):
        if self.start == -1:
            self.start = time.time()
        else:
            end = time.time() - self.start
            hrs, rem    = divmod(end, 3600)
            min, sec    = divmod(rem, 60)
            min         = min + 60*hrs
            print("Time Elapsed: {:0>2}:{:02}".format(int(min),int(sec)))
            self.start  = -1

def weights_init_kaiming(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        torch.nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
    elif classname.find('Linear') != -1:
        torch.nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
        if m.bias is not None:
            torch.nn.init.constant_(m.bias.data, 0.0)
    elif classname.find('BatchNorm') != -1:
        torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
        torch.nn.init.constant_(m.bias.data, 0.0)

In [393]:
def train(model, dataloader, criterion, optimizer):

    model.train()

    total_loss = 0.0

    batch_bar = tqdm(total=len(dataloader), dynamic_ncols=True, leave=False, position=0, desc='Train')

    for i, (eeg, audio, eeg_lengths, audio_lengths) in enumerate(dataloader): # should start from 0

        eeg, audio = eeg.to(DEVICE), audio.to(DEVICE)

        optimizer.zero_grad()

        probabilities, brain_encoder_outputs_reshaped, audio_hidden_states_reshaped = model(eeg, eeg_lengths, audio)
        loss = criterion(brain_encoder_outputs_reshaped, audio_hidden_states_reshaped)
        total_loss += loss.item()
        loss.backward()
        optimizer.step()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=3) # optional

        batch_bar.set_postfix(loss="{:.04f}".format(total_loss/(i+1)),
                              lr="{:.04f}".format(float(optimizer.param_groups[0]['lr'])))
        batch_bar.update()

        del eeg, audio, eeg_lengths, audio_lengths
        torch.cuda.empty_cache()

    loss = total_loss / len(dataloader)
    batch_bar.close()

    return loss

def validate(model, dataloader):

    model.eval()

    batch_bar = tqdm(total=len(dataloader), dynamic_ncols=True, position=0, leave=False, desc="Val")

    total_loss = 0.0
    total_avg_wer_lev = 0.0
    total_wer_general = 0.0
    total_accuracy_general = 0.0
    total_wer_vocab = 0.0
    total_accuracy_vocab = 0.0

    with torch.no_grad():

        for i, (eeg, audio, eeg_lengths, audio_lengths) in enumerate(dataloader):

            eeg = eeg.to(DEVICE)

            # Get model predictions
            probabilities, brain_encoder_outputs_reshaped, audio_hidden_states_reshaped = model(eeg, eeg_lengths, audio)

            # Calculate loss
            loss = criterion(brain_encoder_outputs_reshaped, audio_hidden_states_reshaped)
            total_loss += loss.item()

            # Decode predictions and targets, then calculate WER
            #decoded_predictions = decode_predictions(brain_encoder_outputs_reshaped)
            #decoded_targets = decode_predictions(audio_hidden_states_reshaped)
            avg_wer_lev, wer_general, accuracy_general, wer_vocab, accuracy_vocab = decode_predictions_and_evaluate(brain_encoder_outputs_reshaped,
                                                                                                         audio_hidden_states_reshaped)
            total_avg_wer_lev += avg_wer_lev
            total_wer_general += wer_general
            total_accuracy_general += accuracy_general
            total_wer_vocab += wer_vocab
            total_accuracy_vocab += accuracy_vocab

            batch_bar.set_postfix(loss="{:.04f}".format(total_loss / (i + 1)),
                                  wer_lev="{:.02f}".format(total_avg_wer_lev / (i + 1)),
                                  wer_general="{:.02f}".format(total_wer_general / (i + 1)),
                                  accuracy_general="{:.02f}".format(total_accuracy_general / (i + 1)),
                                  wer_vocab="{:.02f}".format(total_wer_vocab / (i + 1)),
                                  accuracy_vocab="{:.02f}".format(total_accuracy_vocab / (i + 1)))
            batch_bar.update()

            del eeg, audio, eeg_lengths, audio_lengths
            torch.cuda.empty_cache()

    batch_bar.close()
    loss = total_loss / len(dataloader)
    wer_lev = total_avg_wer_lev / len(dataloader)
    wer_general = total_wer_general / len(dataloader)
    accuracy_general = total_accuracy_general / len(dataloader)
    wer_vocab = total_wer_vocab / len(dataloader)
    accuracy_vocab = total_accuracy_vocab / len(dataloader)

    return loss, wer_lev, wer_general, accuracy_general, wer_vocab, accuracy_vocab

# wandb


In [399]:
import wandb
wandb.login(key="3e9cb37b9f485bc61ca01d7f5ac0130ff698fb6a")

run = wandb.init(
    name = "test", # wandb creates random run names if you skip this field
    reinit = True, # allows reinitalizing runs when you re-run this cell
    #id = "niwgbub6", # insert specific run id here if you want to resume a previous run
    #resume = "allow", # you need this to resume previous runs, but comment out reinit = True when using this
    project = "final_project", # project should be created in your wandb account
    config = config) # wandb config for your run

model_arch_file = "model_architecture.txt" # save your model architecture in a txt file, and save the file to Wandb

with open(model_arch_file, "w") as f:
    f.write(str(model))
wandb.save(model_arch_file)

[34m[1mwandb[0m: W&B API key is configured. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mkshapova[0m ([33midl2023[0m). Use [1m`wandb login --relogin`[0m to force relogin


['/content/wandb/run-20231210_155407-rynti3eh/files/model_architecture.txt']

# experiments

In [None]:
torch.cuda.empty_cache()
gc.collect()

model = BrainAudioModel(input_size           = 62,
                        lstm_layers          = 2,
                        listener_hidden_size = 256,
                        num_heads_listener   = 8,
                        tf_blocks_listener   = 4,
                        dropout              = 0.1,
                        linear_output_size   = 1024)

model.apply(weights_init_kaiming)

model = model.to(DEVICE)

optimizer = AdamW(model.parameters(), lr=config['learning_rate'], betas=(0.9, 0.999),
                  eps=1e-8, weight_decay=config['weight_decay'], amsgrad=False) #optimizer = AdamP(model.parameters(), lr=config['learning_rate'], betas=(0.9, 0.999), weight_decay=config['weight_decay'])

criterion = ClipLoss(linear=None, twin=True, pool=False, tmin=None, tmax=None,
                     tmin_train=None, tmax_train=None, dset_args=None, center=False)

#scaler = torch.cuda.amp.GradScaler() # optional

scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=config['factor'],
                                                       patience=config['patience'], verbose=True)

# TEST ------------------------------------------------------------------------------------------------------------
def test_combined():
    seq_len_eeg = 1000
    input_dim_eeg = 62

    seq_len_audio = 2000
    input_dim_audio = 1024

    eeg_sample = torch.rand(config['batch_size'], seq_len_eeg, input_dim_eeg).to(DEVICE)
    eeg_lengths = torch.full((config['batch_size'],), seq_len_eeg, dtype=torch.int64).to(DEVICE)
    audio_sample = torch.rand(config['batch_size'], seq_len_audio, input_dim_audio).to(DEVICE)

    # Forward pass
    probabilities = model(eeg_sample, eeg_lengths, audio_sample)

    # Check output dimensions
    summary(model, eeg_sample, eeg_lengths, audio_sample)

test_combined()

In [None]:
# PLEASE RUN THIS CELL ONLY IF YOU NEED TO RESUME TRAINING FROM A CERTAIN CHECKPOINT
"""
# Load the most recent checkpoint if needed
checkpoint_path = "/content/drive/MyDrive/HW4_P2_Checkpoints/checkpoint_epoch11_valid_dist8.8932.pth"
model, optimizer, scheduler, last_epoch_completed, best_wer = load_model(checkpoint_path, checkpoint_path, model, metric='valid_dist', optimizer=optimizer, scheduler=scheduler)

# Introduce changes if needeed
new_lr = 0.00005
for param_group in optimizer.param_groups:
    param_group['lr'] = new_lr
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=1, verbose=True)

# Print diagnostic information if needed
print(f"Model: {model}")
print(f"Optimizer: {optimizer}")
print(f"Current learning rate from optimizer: {optimizer.param_groups[0]['lr']}")
print(f"Scheduler: {scheduler}")
print(f"Last completed epoch from checkpoint: {last_epoch_completed}")
print(f"Best Levenshtein distance from checkpoint: {best_wer}")
"""

In [396]:
start = 0
end = config["epochs"]
best_wer = float("inf")

checkpoint_dir = '/content/gdrive/MyDrive/11785-IDLf23/Final_project/6_Checkpoints/'
epoch_model_path = os.path.join(checkpoint_dir, 'checkpoint_epoch{epoch}_valid_dist{valid_dist:.4f}.pth')
best_model_path = os.path.join(checkpoint_dir, 'best_model.pth')

In [None]:
torch.cuda.empty_cache()
gc.collect()

for epoch in range(start, config['epochs']):

    print("\nEpoch: {}/{}".format(epoch+1, config['epochs']))

    curr_lr = optimizer.param_groups[0]['lr']

    # Call train and validate, get attention weights from training
    train_loss = train(model, train_loader, criterion, optimizer)
    valid_loss, wer_lev, wer_general, accuracy_general, wer_vocab, accuracy_vocab = validate(model, val_loader)

    # Print your metrics
    print("\tTrain Loss {:.04f}\t Learning Rate {:.07f}".format(train_loss, curr_lr))
    print("\tVal Loss {:.04f}%".format(valid_loss))

    # Plot attention for a single item in the batch
    #plot_attention(attention_plot[0].cpu().detach().numpy())

    # Log metrics to Wandb
    wandb.log({"train_loss": train_loss,
               "valid_loss": valid_loss,
               "wer_lev": wer_lev,
               "wer_general": wer_general,
               "accuracy_general": accuracy_general,
               "wer_vocab": wer_vocab,
               "accuracy_vocab": accuracy_vocab,
               "epoch": epoch+1,
               "lr": curr_lr})
    save_model(model, optimizer, scheduler, ['wer_vocab', wer_vocab], epoch, epoch_model_path.format(epoch=epoch,
                                                                                                     wer_vocab=wer_vocab))

    wandb.save(epoch_model_path)
    print("Saved epoch model")

    # Scheduler
    scheduler.step(wer_vocab)

    if wer_vocab <= best_wer:
      best_lev_dist = wer_vocab
      save_model(model, optimizer, scheduler, None, ['wer_vocab', wer_vocab], epoch, best_model_path)
      wandb.save(best_model_path)
      print("Saved best model")

run.finish()

# viz