# imports

In [1]:
!pip install wandb torchsummaryX mne transformers -q

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.1/2.1 MB[0m [31m12.3 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m8.3/8.3 MB[0m [31m61.8 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m190.6/190.6 kB[0m [31m20.5 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m252.8/252.8 kB[0m [31m25.8 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m62.7/62.7 kB[0m [31m6.6 MB/s[0m eta [36m0:00:00[0m
[?25h

In [14]:
### Torch
import  torch
import  torch.nn as nn
import  torch.nn.functional as F
from    torch.optim import lr_scheduler
from    torchsummaryX import summary
from    torch.utils.data import Dataset, DataLoader
import  torchaudio
import  torchaudio.transforms as tat
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

### General
import  random
import  numpy as np
import  pandas as pd
import  scipy
import  gc
from    tqdm.auto import tqdm
import  os
import  datetime
import  time
import  wandb
import  matplotlib.pyplot as plt
import  seaborn as sns

### Other

from functools import partial
import logging
import math
import typing as tp

# wav2vec2 and EEG processing
from    transformers import (AutoProcessor, AutoModelForPreTraining,
                             CLIPProcessor, CLIPModel)
import  mne
from transformers import Wav2Vec2Model

# working directory

In [3]:
from google.colab import drive
drive.mount('/content/gdrive')

Mounted at /content/gdrive


In [4]:
DATA_PATH  = '/content/gdrive/MyDrive/11785-IDLf23/Final_project/0_Data/'

BRAIN_PATH = os.path.join(DATA_PATH, 'brennan_and_hale_v2')
AUDIO_PATH = os.path.join(DATA_PATH, 'brennan_and_hale_v2/audio')
PPROC_PATH = os.path.join(DATA_PATH, 'brennan_and_hale_v2/proc/timelock-preprocessing')

EEG_PATH   = os.path.join(DATA_PATH, 'eeg-segments')
EMBED_PATH = os.path.join(DATA_PATH, 'brennan_wav2vec2_embeddings')

# data [QA]

## dataset definition

In [5]:
print(os.listdir(EEG_PATH))

# for fname in os.listdir(EEG_PATH):
#     print(fname[:3])

import re

def extract_info(s):
    # This pattern looks for any text (non-digits) followed by digits, a hyphen, and more digits
    match = re.match(r'([A-Za-z]+[0-9]+)-([0-9]+).pkl', s)
    if match:
        return match.group(1), match.group(2)
    else:
        return None

# for fname in os.listdir(EEG_PATH):
#     print(extract_info(fname))

['S01-1.npy', 'S01-2.npy', 'S01-3.npy', 'S01-4.npy', 'S01-5.npy', 'S01-6.npy', 'S01-7.npy', 'S01-8.npy', 'S01-9.npy', 'S01-10.npy', 'S01-11.npy', 'S01-12.npy', 'S03-1.npy', 'S03-2.npy', 'S03-3.npy', 'S03-4.npy', 'S03-5.npy', 'S03-6.npy', 'S03-7.npy', 'S03-8.npy', 'S03-9.npy', 'S03-10.npy', 'S03-11.npy', 'S03-12.npy', 'S04-1.npy', 'S04-2.npy', 'S04-3.npy', 'S04-4.npy', 'S04-5.npy', 'S04-6.npy', 'S04-7.npy', 'S04-8.npy', 'S04-9.npy', 'S04-10.npy', 'S04-11.npy', 'S04-12.npy', 'S05-1.npy', 'S05-2.npy', 'S05-3.npy', 'S05-4.npy', 'S05-5.npy', 'S05-6.npy', 'S05-7.npy', 'S05-8.npy', 'S05-9.npy', 'S05-10.npy', 'S05-11.npy', 'S05-12.npy', 'S06-1.npy', 'S06-2.npy', 'S06-3.npy', 'S06-4.npy', 'S06-5.npy', 'S06-6.npy', 'S06-7.npy', 'S06-8.npy', 'S06-9.npy', 'S06-10.npy', 'S06-11.npy', 'S06-12.npy', 'S08-1.npy', 'S08-2.npy', 'S08-3.npy', 'S08-4.npy', 'S08-5.npy', 'S08-6.npy', 'S08-7.npy', 'S08-8.npy', 'S08-9.npy', 'S08-10.npy', 'S08-11.npy', 'S08-12.npy', 'S10-1.npy', 'S10-2.npy', 'S10-3.npy', 'S10-4

In [16]:
class BrainAudioDataset(torch.utils.data.Dataset):

    def __init__(self, eeg_path=EEG_PATH, embed_path=EMBED_PATH, transforms=None):
        super().__init__()

        self.eeg_root = eeg_path
        self.embed_root = embed_path
        self.transforms = transforms

        eeg_fnames = sorted(os.listdir(self.eeg_root))
        print(eeg_fnames)

        self.data = {}
        for idx, fname in enumerate(eeg_fnames):
            subject_id, segment = self.extract_info(fname)
            eeg_fpath = os.path.join(self.eeg_root, fname)
            audio_fpath = os.path.join(self.embed_root, f'audio-{segment}.pkl')
            self.data[idx] = (eeg_fpath, audio_fpath)

        self.length = len(self.data)

    def __len__(self):

        return self.length

    def __getitem__(self, idx):
        """
        - eeg               : (1, T_eeg, C_eeg)
        - audio_embedding   : (1, T_audio, 1024)
        - audio_logits      : (1, T_audio, 33)
        """
        eeg_path, audio_path = self.data[idx]

        eeg = np.load(eeg_path)
        eeg = torch.tensor(eeg.transpose())

        audio = torch.load(audio_path)

        audio_embed = audio.hidden_states[-1]

        return eeg, audio_embed.squeeze(0), len(eeg), len(audio_embed)

    def collate_fn(self, batch):

        eeg, audio_embed, len_eeg, len_audio_embed = zip(*batch)

        # pack the mfccs and transcripts using the pad_sequence function from pytorch
        batch_eeg_pad = torch.nn.utils.rnn.pad_sequence(eeg, batch_first=True)
        batch_audio_embed_pad = torch.nn.utils.rnn.pad_sequence(audio_embed, batch_first=True)
        del eeg, audio_embed

        # Apply transformations
        if self.transforms is not None:
            batch_eeg_pad = self.transforms(batch_eeg_pad)

        return (batch_eeg_pad,
                batch_audio_embed_pad,
                torch.tensor(len_eeg, dtype=torch.int64),
                torch.tensor(len_audio_embed, dtype=torch.int64))

    def read_sfp(self, file_path):
        """
        Reads a BESA SFP (Surface Point) file (locations of sensors)
        """

        electrodes = {}
        with open(file_path, 'r') as file:
            for line in file:
                parts = line.strip().split()  # Split by whitespace
                if len(parts) == 4:
                    # Parse the electrode name and coordinates
                    name = parts[0]
                    x, y, z = map(float, parts[1:])  # Convert strings to floats
                    electrodes[name] = (x, y, z)

        return electrodes

    def extract_info(self, fname):
        # This pattern looks for any text (non-digits) followed by digits, a hyphen, and more digits
        match = re.match(r'([A-Za-z]+[0-9]+)-([0-9]+).npy', fname)
        if match:
            return match.group(1), match.group(2)
        else:
            return None

In [17]:
dataset = BrainAudioDataset()

for i in range(3):
    eeg, audio_embedding, l_eeg, l_audio = dataset[i]
    print(eeg.shape, audio_embedding.shape, l_eeg, l_audio)

['S01-1.npy', 'S01-10.npy', 'S01-11.npy', 'S01-12.npy', 'S01-2.npy', 'S01-3.npy', 'S01-4.npy', 'S01-5.npy', 'S01-6.npy', 'S01-7.npy', 'S01-8.npy', 'S01-9.npy', 'S03-1.npy', 'S03-10.npy', 'S03-11.npy', 'S03-12.npy', 'S03-2.npy', 'S03-3.npy', 'S03-4.npy', 'S03-5.npy', 'S03-6.npy', 'S03-7.npy', 'S03-8.npy', 'S03-9.npy', 'S04-1.npy', 'S04-10.npy', 'S04-11.npy', 'S04-12.npy', 'S04-2.npy', 'S04-3.npy', 'S04-4.npy', 'S04-5.npy', 'S04-6.npy', 'S04-7.npy', 'S04-8.npy', 'S04-9.npy', 'S05-1.npy', 'S05-10.npy', 'S05-11.npy', 'S05-12.npy', 'S05-2.npy', 'S05-3.npy', 'S05-4.npy', 'S05-5.npy', 'S05-6.npy', 'S05-7.npy', 'S05-8.npy', 'S05-9.npy', 'S06-1.npy', 'S06-10.npy', 'S06-11.npy', 'S06-12.npy', 'S06-2.npy', 'S06-3.npy', 'S06-4.npy', 'S06-5.npy', 'S06-6.npy', 'S06-7.npy', 'S06-8.npy', 'S06-9.npy', 'S08-1.npy', 'S08-10.npy', 'S08-11.npy', 'S08-12.npy', 'S08-2.npy', 'S08-3.npy', 'S08-4.npy', 'S08-5.npy', 'S08-6.npy', 'S08-7.npy', 'S08-8.npy', 'S08-9.npy', 'S10-1.npy', 'S10-10.npy', 'S10-11.npy', 'S10

## dataloader definition

In [18]:
train_loader = torch.utils.data.DataLoader(
    dataset      = dataset,
    batch_size   = 2,
    shuffle      = True,
    drop_last    = False,
    num_workers  = 8,
    pin_memory   = True,
    collate_fn   = dataset.collate_fn)

print(len(dataset))
print(train_loader.__len__())

348
174


In [19]:
for batch in train_loader:
    x, y, x_len, y_len = batch
    print(x.shape, y.shape, x_len.shape, y_len.shape)
    print(x.dtype, y.dtype, x_len.dtype, y_len.dtype)
    del x, y, x_len, y_len
    break

# x: raw eeg [batch_size, sequence length, number of features]
# y: audio embedding [batch_size, sequence length, number of features]

torch.Size([2, 31914, 62]) torch.Size([2, 3188, 1024]) torch.Size([2]) torch.Size([2])
torch.float64 torch.float32 torch.int64 torch.int64


# models [KS]

Relevant Meta folders: [Features folder](https://github.com/facebookresearch/brainmagick/tree/main/bm/features) and [Model folder](https://github.com/facebookresearch/brainmagick/tree/main/bm/models).


**Shared model components and models:**

1. **[common.py](https://github.com/facebookresearch/brainmagick/blob/main/bm/models/common.py)** - collection of components used by other models

  - ScaledEmbedding

  - SubjectLayers

  - LayerScale

  - ConvSequence

  - DualPathRNN

  - PositionGetter

  - FourierEmb

  - ChannelDropout

  - ChannelMerger

2. **[convrnn.py](https://github.com/facebookresearch/brainmagick/blob/main/bm/models/convrnn.py)** - a model used both as an a encoder and a decoder.

  - LSTM

  - Attention

  - ConvRNN
  
    - SubjectLayers (subject layer)

    - ScaledEmbedding (subject embedding)
    
    - LSTM (bidirectional)

    - Attention (multi-head dot product)

    - ConvSequence (decoder)
  
    - Conv1d or Conv1d + ReLU + Conv1d (final)

3. **[simpleconv.py](https://github.com/facebookresearch/brainmagick/blob/main/bm/models/simpleconv.py)**

  - SimpleConv

    - takes a sample of channels (subsampled_meg_channels)
    
    - ChannelDropout

    - ChannelMerger

    - Conv1d + activations (initial layer)

    - SubjectLayers (subject layer)

    - ta.transforms.Spectrogram (short-time fourier transform)

    - ScaledEmbedding (subject embedding)

    - ConvSequence (encoder)

    - DualPathRNN
  
    - Conv1d or Conv1d + ReLU + Conv1d (final)

4. **[features.py](https://github.com/facebookresearch/brainmagick/blob/main/bm/models/features.py)** - model to extract features

  - DeepMel(ConvSequence)

**Combined encoders:**

1. **[deep_mel.yaml](https://github.com/facebookresearch/brainmagick/blob/main/bm/conf/feature_model/deep_mel.yaml)** - calls **features.py** - feature model

2. **[convrnn.yaml](https://github.com/facebookresearch/brainmagick/blob/main/bm/conf/model/convrnn.yaml)** - calls **convrnn.py**

**Combined decoders:**

1. **[clip_conv.yaml](https://github.com/facebookresearch/brainmagick/blob/main/bm/conf/model/clip_conv.yaml/)** - calls **simpleconv.py** - default model

2. **[decoder_convrnn.yaml](https://github.com/facebookresearch/brainmagick/blob/main/bm/conf/model/decoder_convrnn.yaml)** - calls **convrnn.py**


Other: All the imports https://github.com/facebookresearch/brainmagick/blob/main/requirements.txt

## toy data

In [49]:
batch_size = 2
eeg_num_channels = 4
eeg_time_steps = 5
audio_time_steps = 6

eeg_data = torch.randn(batch_size, eeg_time_steps, eeg_num_channels)
audio_data = torch.randn(batch_size, audio_time_steps)
vocabulary = ["word1", "word2", "word3", "word4", "word5"]

print('eeg_data.shape', eeg_data.shape)
print('eeg_data\n', eeg_data)
print('\naudio_data.shape', audio_data.shape)
print('audio_data\n', audio_data)

eeg_data.shape torch.Size([2, 5, 4])
eeg_data
 tensor([[[-1.1552,  0.5698, -0.0762, -0.2001],
         [-0.7471, -0.7212, -0.3877,  1.8738],
         [-0.6147,  1.2159,  0.3479,  1.5585],
         [-1.1880, -0.4776, -2.1453,  0.4058],
         [-2.0641,  0.4924,  1.5100,  0.7131]],

        [[ 0.2494, -1.3765,  0.1368,  0.4436],
         [ 0.0899,  2.5809,  0.1485, -0.2769],
         [-0.7275,  0.1031, -1.8410,  0.4143],
         [ 0.5914,  0.3040,  0.0953,  1.6982],
         [-0.9059,  0.0356, -0.3355, -1.1054]]])

audio_data.shape torch.Size([2, 6])
audio_data
 tensor([[-0.8494, -0.0652,  0.0875, -0.6615,  1.1431,  1.2393],
        [-1.0426, -1.0902,  0.7864, -0.8260, -0.1956,  1.2193]])


In [50]:
"""
eeg_num_channels = 2
eeg_time_steps = 3
audio_time_steps = 4

eeg_data_subject_0_segment_0 = torch.randn(eeg_time_steps, eeg_num_channels)
eeg_data_subject_0_segment_1 = torch.randn(eeg_time_steps, eeg_num_channels)
audio_data_subject_0_segment_0 = torch.randn(audio_time_steps)
audio_data_subject_0_segment_1 = torch.randn(audio_time_steps)

eeg_data_subject_1_segment_0 = torch.randn(eeg_time_steps, eeg_num_channels)
audio_data_subject_1_segment_0 = torch.randn(audio_time_steps)

data = {
    (0, 0): (eeg_data_subject_0_segment_0, audio_data_subject_0_segment_0),
    (0, 1): (eeg_data_subject_0_segment_1, audio_data_subject_0_segment_0),

    (1, 0): (eeg_data_subject_1_segment_0, audio_data_subject_1_segment_0)}


for key, (eeg, audio) in data.items():
    print(f"Subject ID: {key[0]}, Segment ID: {key[1]}")
    print("EEG Data Shape:", eeg.shape)
    print(eeg)
    print("Audio Data Shape:", audio.shape)
    print(audio)
    print()

batch_size = 2
data = {}
for subject_id in range(batch_size):
    for segment_id in range(batch_size):
        eeg_data = torch.randn(eeg_time_steps, eeg_num_channels)
        audio_data = torch.randn(audio_time_steps)
        data[(subject_id, segment_id)] = (eeg_data, audio_data)
data
"""

'\neeg_num_channels = 2\neeg_time_steps = 3\naudio_time_steps = 4\n\neeg_data_subject_0_segment_0 = torch.randn(eeg_time_steps, eeg_num_channels)\neeg_data_subject_0_segment_1 = torch.randn(eeg_time_steps, eeg_num_channels)\naudio_data_subject_0_segment_0 = torch.randn(audio_time_steps)\naudio_data_subject_0_segment_1 = torch.randn(audio_time_steps)\n\neeg_data_subject_1_segment_0 = torch.randn(eeg_time_steps, eeg_num_channels)\naudio_data_subject_1_segment_0 = torch.randn(audio_time_steps)\n\ndata = {\n    (0, 0): (eeg_data_subject_0_segment_0, audio_data_subject_0_segment_0),\n    (0, 1): (eeg_data_subject_0_segment_1, audio_data_subject_0_segment_0), \n\n    (1, 0): (eeg_data_subject_1_segment_0, audio_data_subject_1_segment_0)}\n\n\nfor key, (eeg, audio) in data.items():\n    print(f"Subject ID: {key[0]}, Segment ID: {key[1]}")\n    print("EEG Data Shape:", eeg.shape)\n    print(eeg)\n    print("Audio Data Shape:", audio.shape)\n    print(audio)\n    print()\n\nbatch_size = 2\ndata

## brain encoder 1

In [73]:
class SubjectLayers(nn.Module): # learn how different are the subjects and based on that normalize the EEG
    def __init__(self, in_channels, out_channels, n_subjects):
        super().__init__()
        self.weights = nn.Parameter(torch.randn(n_subjects, in_channels, out_channels)) # initialize weights for each subject
        self.weights.data *= 1 / in_channels**0.5 # normalize weights

    def forward(self, x, n_subjects):
        _, C, D = self.weights.shape
        subject_weights = self.weights.gather(0, n_subjects.view(-1, 1, 1).expand(-1, C, D)) # select the appropriate weights for each subject in the batch
        transformed_eeg = torch.einsum("bct,bcd->bdt", x, subject_weights) # apply the subject-specific transformations

        return transformed_eeg

class ScaledEmbedding(nn.Module): # assign a unique vector to each subject (similar to positional embedding)

    def __init__(self, n_subjects, embedding_dim, scale):

        super().__init__()
        self.embedding = nn.Embedding(n_subjects, embedding_dim)
        self.embedding.weight.data /= scale
        self.scale = scale

    def forward(self, x):
        scaled_embedding = self.embedding(x) * self.scale
        return scaled_embedding

class LayerScale(nn.Module):
    def __init__(self, channels, init = 0.1, boost = 5.):
        super().__init__()
        self.scale = nn.Parameter(torch.zeros(channels, requires_grad=True))
        self.scale.data[:] = init / boost
        self.boost = boost

    def forward(self, x):
        return (self.boost * self.scale[:, None]) * x

class ConvSequence(nn.Module):

    def __init__(self,
                 channels = [16, 32, 64, 128],
                 kernel = 4, dilation_growth = 1, dilation_period = None,
                 stride = 2, dropout = 0.0, leakiness = 0.0,
                 groups = 1, decode = False, batch_norm = False,
                 dropout_input = 0, skip = False, scale = None,
                 rewrite = False, activation_on_last = True,
                 post_skip = False, glu = 0, glu_context = 0,
                 glu_glu = True, activation = None):

        super().__init__()
        dilation = 1
        channels = tuple(channels)
        self.skip = skip
        self.sequence = nn.ModuleList()
        self.glus = nn.ModuleList()
        if activation is None:
            activation = partial(nn.LeakyReLU, leakiness)
        Conv = nn.Conv1d if not decode else nn.ConvTranspose1d
        # build layers
        for k, (chin, chout) in enumerate(zip(channels[:-1], channels[1:])):
            layers: tp.List[nn.Module] = []
            is_last = k == len(channels) - 2

            # Set dropout for the input of the conv sequence if defined
            if k == 0 and dropout_input:
                assert 0 < dropout_input < 1
                layers.append(nn.Dropout(dropout_input))

            # conv layer
            if dilation_growth > 1:
                assert kernel % 2 != 0 # supports only odd kernel with dilation
            if dilation_period and (k % dilation_period) == 0:
                dilation = 1
            pad = kernel // 2 * dilation
            layers.append(Conv(chin, chout, kernel, stride, pad,
                               dilation=dilation, groups=groups if k > 0 else 1))
            dilation *= dilation_growth
            # non-linearity
            if activation_on_last or not is_last:
                if batch_norm:
                    layers.append(nn.BatchNorm1d(num_features=chout))
                layers.append(activation())
                if dropout:
                    layers.append(nn.Dropout(dropout))
                if rewrite:
                    layers += [nn.Conv1d(chout, chout, 1), nn.LeakyReLU(leakiness)]
            if chin == chout and skip:
                if scale is not None:
                    layers.append(LayerScale(chout, scale))
                if post_skip:
                    layers.append(Conv(chout, chout, 1, groups=chout, bias=False))

            self.sequence.append(nn.Sequential(*layers))
            if glu and (k + 1) % glu == 0:
                ch = 2 * chout if glu_glu else chout
                act = nn.GLU(dim=1) if glu_glu else activation()
                self.glus.append(
                    nn.Sequential(
                        nn.Conv1d(chout, ch, 1 + 2 * glu_context, padding=glu_context), act))
            else:
                self.glus.append(None)

    def forward(self, x):
        for module_idx, module in enumerate(self.sequence):
            old_x = x
            x = module(x)
            if self.skip and x.shape == old_x.shape:
                x = x + old_x
            glu = self.glus[module_idx]
            if glu is not None:
                x = glu(x)
        return x

class Attention(nn.Module): # scaled dot product with relative position encoding and local attention
    def __init__(self, channels, radius = 50, heads = 4):
        super().__init__()
        assert channels % heads == 0
        self.content = nn.Conv1d(channels, channels, 1)
        self.query = nn.Conv1d(channels, channels, 1)
        self.key = nn.Conv1d(channels, channels, 1)
        self.embedding = nn.Embedding(radius * 2 + 1, channels // heads)
        weight = self.embedding.weight.data
        weight[:] = weight.cumsum(0) / torch.arange(1, len(weight) + 1).float().view(-1, 1).sqrt()
        self.heads = heads
        self.radius = radius
        self.bn = nn.BatchNorm1d(channels)
        self.fc = nn.Conv1d(channels, channels, 1)
        self.scale = nn.Parameter(torch.full([channels], 0.1))

    def forward(self, x):

        def _split(y):
            return y.view(y.shape[0], self.heads, -1, y.shape[2])

        content = _split(self.content(x))
        query = _split(self.query(x))
        key = _split(self.key(x))

        batch_size, _, dim, length = content.shape

        dots = torch.einsum("bhct,bhcs->bhts", query, key) # first index `t` is query, second index `s` is key.

        steps = torch.arange(length, device=x.device)
        relative = (steps[:, None] - steps[None, :])
        embs = self.embedding.weight.gather(0, self.radius + relative.clamp_(-self.radius, self.radius).view(-1, 1).expand(-1, dim))
        embs = embs.view(length, length, -1)
        dots += 0.3 * torch.einsum("bhct,tsc->bhts", query, embs)
        dots = torch.where(
            relative.abs() <= self.radius, dots, torch.tensor(-float('inf')).to(embs))

        weights = torch.softmax(dots, dim=-1)
        out = torch.einsum("bhts,bhcs->bhct", weights, content)
        out += 0.3 * torch.einsum("bhts,tsc->bhct", weights, embs)
        out = out.reshape(batch_size, -1, length)
        out = F.relu(self.bn(self.fc(out))) * self.scale.view(1, -1, 1)
        return out

"""
# TEST 1
def test_subject_layers():
    batch_size = 2
    in_channels = 3
    out_channels = 2
    n_subjects = 1
    time_steps = 2
    eeg_data = torch.randn(batch_size, in_channels, time_steps)
    print(eeg_data)
    subjects = torch.randint(0, n_subjects, (batch_size,))
    print(subjects)
    subject_layers = SubjectLayers(in_channels, out_channels, n_subjects)
    output = subject_layers(eeg_data, subjects)
    expected_shape = (batch_size, out_channels, time_steps)
    assert output.shape == expected_shape, f"Output shape mismatch: expected {expected_shape}, got {output.shape}"
test_subject_layers()

# TEST 2
def scaled_embedding():
  n_subjects = 10
  embedding_dim = 5
  scale = 10
  scaled_embedding = ScaledEmbedding(n_subjects, embedding_dim, scale)
  subject_indices = torch.tensor([1, 2])
  embeddings = scaled_embedding(subject_indices)
  print(embeddings)
scaled_embedding()

# TEST 3
conv_sequence = ConvSequence(channels=[16, 32, 64, 128])
input_tensor = torch.randn(1, 16, 50) # (Batch Size, Channels, Length)
output = conv_sequence(input_tensor)
output.shape
#summary(conv_sequence, input_size=(16, 50))
"""

'\n# TEST 1\ndef test_subject_layers():\n    batch_size = 2\n    in_channels = 3\n    out_channels = 2\n    n_subjects = 1\n    time_steps = 2\n    eeg_data = torch.randn(batch_size, in_channels, time_steps) \n    print(eeg_data)\n    subjects = torch.randint(0, n_subjects, (batch_size,))\n    print(subjects)\n    subject_layers = SubjectLayers(in_channels, out_channels, n_subjects)\n    output = subject_layers(eeg_data, subjects)\n    expected_shape = (batch_size, out_channels, time_steps)\n    assert output.shape == expected_shape, f"Output shape mismatch: expected {expected_shape}, got {output.shape}"\ntest_subject_layers()\n\n# TEST 2\ndef scaled_embedding():\n  n_subjects = 10\n  embedding_dim = 5    \n  scale = 10      \n  scaled_embedding = ScaledEmbedding(n_subjects, embedding_dim, scale)\n  subject_indices = torch.tensor([1, 2]) \n  embeddings = scaled_embedding(subject_indices)\n  print(embeddings)\nscaled_embedding()\n\n# TEST 3\nconv_sequence = ConvSequence(channels=[16, 32

In [74]:
class EEG_Encoder(nn.Module):

    def __init__(self,

                 in_channels = 64, n_subjects = 30,

                 # subject layers
                 out_channels = 128,

                 # conv
                 conv_chanels = [16, 32, 64, 128],
                 kernel = 4, stride = 2, conv_dropout = 0,
                 batch_norm = False, dropout_input = 0,
                 leakiness = 0,

                 # lstm
                 hidden_size = 128, lstm_layers = 4, lstm_dropout = 0.1,

                 # attention
                 attention_heads = 4, subject_dim = 64, embedding_scale = 1.0):

        super().__init__()

        # subject layers (normalization across subjects)
        self.subject_layers = SubjectLayers(in_channels, out_channels, n_subjects)

        # scaled embedding (optional)
        # self.subject_embedding = ScaledEmbedding(n_subjects, subject_dim, embedding_scale)

        # convsequence
        self.convs = ConvSequence(channels=conv_chanels, kernel = kernel_size,
                                  stride = stride, dropout = conv_dropout,
                                  batch_norm = batch_norm, leakiness = leakiness,
                                  dropout_input = dropout_input)

        # lstm
        self.lstm = nn.LSTM(input_size    = out_channels,
                            hidden_size   = hidden_size//2,
                            num_layers    = lstm_layers,
                            dropout       = lstm_dropout,
                            bidirectional = True,
                            batch_first    = True)

        # attention
        self.attention = Attention(hidden_size, heads=attention_heads)

        # final linear layer
        self.finalconv1 = nn.Conv1d(hidden_size, out_channels, kernel_size = 1)

    def forward(self, eeg_inputs): # to pass as an additional paramater n_subjects=10

        # subject layers
        # subject_indices = torch.arange(n_subjects).to(eeg_inputs.device)
        # normalized_eeg_inputs = self.subject_layers(eeg_inputs, subject_indices)

        # scaled embedding // scaled_embedding = self.subject_embedding(subject_indices)

        # convsequence
        print('eeg_inputs before convs', eeg_inputs.shape)
        out = self.convs(eeg_inputs)
        print('eeg_inputs after convs', out.shape)
        print("After convs:", torch.isnan(out).any())

        # lstm
        out = out.permute(2, 0, 1)
        out, _ = self.lstm(out)
        out = out.permute(1, 2, 0)
        print("After LSTM:", torch.isnan(out).any())

        # attention
        out = out + self.attention(out)
        print("After Attention:", torch.isnan(out).any())

        # decoder
        # out = self.decoder(out)

        # final
        out = self.finalconv1(out)
        print("Final Output:", torch.isnan(out).any())

        return out

# TEST
in_channels = 64
n_subjects = 30
out_channels = 128
conv_channels = [64, 32, 64, 128]
kernel_size = 4
stride = 2
conv_dropout = 0
batch_norm = False
dropout_input = 0
leakiness = 0
hidden_size = 128
lstm_layers = 4
lstm_dropout = 0.1
attention_heads = 4
subject_dim = 64
embedding_scale = 1.0

eeg_encoder = EEG_Encoder(in_channels, n_subjects, out_channels, conv_channels, kernel_size, stride,
                          conv_dropout, batch_norm, dropout_input, leakiness, hidden_size, lstm_layers,
                          lstm_dropout, attention_heads, subject_dim, embedding_scale)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
eeg_encoder.to(device)
batch_size = 2
in_channels = 64
sequence_length = 128
x_sample = torch.rand(batch_size, in_channels, sequence_length)
summary(eeg_encoder, x_sample.to(device))

eeg_inputs before convs torch.Size([2, 64, 128])
eeg_inputs after convs torch.Size([2, 128, 17])
After convs: tensor(False)
After LSTM: tensor(False)
After Attention: tensor(False)
Final Output: tensor(False)
                                 Kernel Shape  Output Shape    Params  \
Layer                                                                   
0_convs.sequence.0.Conv1d_0       [64, 32, 4]   [2, 32, 65]    8.224k   
1_convs.sequence.0.LeakyReLU_1              -   [2, 32, 65]         -   
2_convs.sequence.1.Conv1d_0       [32, 64, 4]   [2, 64, 33]    8.256k   
3_convs.sequence.1.LeakyReLU_1              -   [2, 64, 33]         -   
4_convs.sequence.2.Conv1d_0      [64, 128, 4]  [2, 128, 17]   32.896k   
5_convs.sequence.2.LeakyReLU_1              -  [2, 128, 17]         -   
6_lstm                                      -  [17, 2, 128]  397.312k   
7_attention.Conv1d_content      [128, 128, 1]  [2, 128, 17]   16.512k   
8_attention.Conv1d_query        [128, 128, 1]  [2, 128, 17]  

  df_sum = df.sum()


Unnamed: 0_level_0,Kernel Shape,Output Shape,Params,Mult-Adds
Layer,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
0_convs.sequence.0.Conv1d_0,"[64, 32, 4]","[2, 32, 65]",8224.0,532480.0
1_convs.sequence.0.LeakyReLU_1,-,"[2, 32, 65]",,
2_convs.sequence.1.Conv1d_0,"[32, 64, 4]","[2, 64, 33]",8256.0,270336.0
3_convs.sequence.1.LeakyReLU_1,-,"[2, 64, 33]",,
4_convs.sequence.2.Conv1d_0,"[64, 128, 4]","[2, 128, 17]",32896.0,557056.0
5_convs.sequence.2.LeakyReLU_1,-,"[2, 128, 17]",,
6_lstm,-,"[17, 2, 128]",397312.0,393216.0
7_attention.Conv1d_content,"[128, 128, 1]","[2, 128, 17]",16512.0,278528.0
8_attention.Conv1d_query,"[128, 128, 1]","[2, 128, 17]",16512.0,278528.0
9_attention.Conv1d_key,"[128, 128, 1]","[2, 128, 17]",16512.0,278528.0


## brain encoder 2

## audio encoder

In [None]:
# Speech model

class LayerScale(nn.Module):
    def __init__(self, channels: int, init: float = 0.1, boost: float = 5.):
        super().__init__()
        self.scale = nn.Parameter(torch.zeros(channels, requires_grad=True))
        self.scale.data[:] = init / boost
        self.boost = boost

    def forward(self, x):
        return (self.boost * self.scale[:, None]) * x

class ConvSequence(nn.Module):

    def __init__(self,
                 channels: tp.Sequence[int],
                 kernel: int = 4,
                 dilation_growth: int = 1,
                 dilation_period: tp.Optional[int] = None,
                 stride: int = 2,
                 dropout: float = 0.0,
                 leakiness: float = 0.0,
                 groups: int = 1,
                 decode: bool = False,
                 batch_norm: bool = False,
                 dropout_input: float = 0,
                 skip: bool = False,
                 scale: tp.Optional[float] = None,
                 rewrite: bool = False,
                 activation_on_last: bool = True,
                 post_skip: bool = False,
                 glu: int = 0,
                 glu_context: int = 0,
                 glu_glu: bool = True,
                 activation: tp.Any = None) -> None:
        super().__init__()
        dilation = 1
        channels = tuple(channels)
        self.skip = skip
        self.sequence = nn.ModuleList()
        self.glus = nn.ModuleList()
        if activation is None:
            activation = partial(nn.LeakyReLU, leakiness)
        Conv = nn.Conv1d if not decode else nn.ConvTranspose1d
        # build layers
        for k, (chin, chout) in enumerate(zip(channels[:-1], channels[1:])):
            layers: tp.List[nn.Module] = []
            is_last = k == len(channels) - 2

            # Set dropout for the input of the conv sequence if defined
            if k == 0 and dropout_input:
                assert 0 < dropout_input < 1
                layers.append(nn.Dropout(dropout_input))

            # conv layer
            if dilation_growth > 1:
                assert kernel % 2 != 0, "Supports only odd kernel with dilation for now"
            if dilation_period and (k % dilation_period) == 0:
                dilation = 1
            pad = kernel // 2 * dilation
            layers.append(Conv(chin, chout, kernel, stride, pad,
                               dilation=dilation, groups=groups if k > 0 else 1))
            dilation *= dilation_growth
            # non-linearity
            if activation_on_last or not is_last:
                if batch_norm:
                    layers.append(nn.BatchNorm1d(num_features=chout))
                layers.append(activation())
                if dropout:
                    layers.append(nn.Dropout(dropout))
                if rewrite:
                    layers += [nn.Conv1d(chout, chout, 1), nn.LeakyReLU(leakiness)]
                    # layers += [nn.Conv1d(chout, 2 * chout, 1), nn.GLU(dim=1)]
            if chin == chout and skip:
                if scale is not None:
                    layers.append(LayerScale(chout, scale))
                if post_skip:
                    layers.append(Conv(chout, chout, 1, groups=chout, bias=False))

            self.sequence.append(nn.Sequential(*layers))
            if glu and (k + 1) % glu == 0:
                ch = 2 * chout if glu_glu else chout
                act = nn.GLU(dim=1) if glu_glu else activation()
                self.glus.append(
                    nn.Sequential(
                        nn.Conv1d(chout, ch, 1 + 2 * glu_context, padding=glu_context), act))
            else:
                self.glus.append(None)

    def forward(self, x: tp.Any) -> tp.Any:
        for module_idx, module in enumerate(self.sequence):
            old_x = x
            x = module(x)
            if self.skip and x.shape == old_x.shape:
                x = x + old_x
            glu = self.glus[module_idx]
            if glu is not None:
                x = glu(x)
        return x

class DeepMel(ConvSequence):
    """DeepMel model that extracts features from the Mel spectrogram.

    Parameters
    ----------
    n_in_channels :
        Number of input channels.
    n_hidden_channels :
        Number of channels in hidden layers.
    n_hidden_layers :
        Number of hidden layers.
    n_out_channels :
        Number of output channels.
    kwargs:
        Additional keyword arguments to pass to ConvSequence.
    """
    def __init__(self,
                 n_in_channels: int,
                 n_hidden_channels: int,
                 n_hidden_layers: int,
                 n_out_channels: int, **kwargs):

        channels = \
            [n_in_channels] + [n_hidden_channels] * (n_hidden_layers - 1) + [n_out_channels]

        super().__init__(channels, **kwargs)

# Test the model
n_in_channels = 16
n_hidden_channels = 32
n_hidden_layers = 3
n_out_channels = 64
model = DeepMel(n_in_channels, n_hidden_channels, n_hidden_layers, n_out_channels)
input_size = (n_in_channels, 128)  # Example input size (channels, sequence length)
summary(model, input_size=input_size, device="cpu")

## clip loss

In [75]:
class ClipLoss(torch.nn.Module): # CLIP (See Open AI CLIP) constrastive loss

    def __init__(self, linear=None, twin=True, pool=False, tmin=None, tmax=None,
                 tmin_train=None, tmax_train=None, dset_args=None, center=False):
        super().__init__()
        self.linear = None
        self.pool = pool
        self.center = center
        if linear is not None:
            self.linear_est = torch.nn.LazyLinear(linear)
            if twin:
                self.linear_gt = self.linear_est
            else:
                self.linear_gt = torch.nn.LazyLinear(linear)
        self.tmin = tmin
        self.tmax = tmax
        self.tmin_train = tmin_train
        self.tmax_train = tmax_train
        self.dset_args = dset_args

    def trim_samples(self, estimates, candidates):
        """Given estimates that is [B1, C, T] and candidates
        which is [B2, C, T], return estimates_trim of size [B1, C, T']
        and candidates_trim of size [B2, C, T'], such that T'
        corresponds to the samples between [self.tmin, self.tmax]
        """
        if self.training and (self.tmin_train is not None or self.tmax_train is not None):
            tmin, tmax = self.tmin_train, self.tmax_train
        else:
            tmin, tmax = self.tmin, self.tmax
        if (tmin is not None) or (tmax is not None):
            assert self.dset_args is not None
            assert self.dset_args.tmin is not None
            dset_tmin = self.dset_args.tmin
        if tmin is None:
            trim_min = 0
        else:
            assert tmin >= dset_tmin, 'clip.tmin should be above dset.tmin'
            trim_min = int((-dset_tmin + tmin) * self.dset_args.sample_rate)
        if tmax is None:
            trim_max = estimates.shape[-1]
        else:
            trim_max = int((-dset_tmin + tmax) * self.dset_args.sample_rate)
        estimates_trim = estimates[..., trim_min:trim_max]
        candidates_trim = candidates[..., trim_min:trim_max]
        return estimates_trim, candidates_trim

    def get_scores(self, estimates: torch.Tensor, candidates: torch.Tensor):
        """Given estimates that is [B, C, T] and candidates
        which is [B', C, T], return a [B, B'] matrix of scores of matching.
        """
        estimates, candidates = self.trim_samples(estimates, candidates)
        if self.linear:
            estimates = self.linear_est(estimates)
            candidates = self.linear_gt(candidates)
        if self.pool:
            estimates = estimates.mean(dim=2, keepdim=True)
            candidates = candidates.mean(dim=2, keepdim=True)
        if self.center:
            estimates = estimates - estimates.mean(dim=(1, 2), keepdim=True)
            candidates = candidates - candidates.mean(dim=(1, 2), keepdim=True)
        inv_norms = 1 / (1e-8 + candidates.norm(dim=(1, 2), p=2))
        scores = torch.einsum("bct,oct,o->bo", estimates, candidates, inv_norms)
        return scores

    def get_probabilities(self, estimates, candidates):
        """Given estimates that is [B, C, T] and candidates
        which is [B', C, T], return a [B, B'] matrix of probabilities of matching.
        """
        scores = self.get_scores(estimates, candidates)
        probabilities = F.softmax(scores, dim=1)
        return probabilities

    def forward(self, estimate, candidate, mask=None):
        """Warning: estimate and candidate are not symmetrical.
        If estimate of shape [B, C, T] and candidate of size [B', C, T]
        with B'>=B, the first B samples of candidate are targets, while
        the remaining B'-B samples of candidate are only used as negatives.
        """
        assert estimate.size(0) <= candidate.size(0), "need at least as many targets as estimates"
        scores = self.get_scores(estimate, candidate)
        target = torch.arange(len(scores), device=estimate.device)
        loss = F.cross_entropy(scores, target)
        return loss

# TEST
batch_size = 3
num_channels = 4
time_steps = 5
estimates = torch.randn(batch_size, num_channels, time_steps)
candidates = torch.randn(batch_size, num_channels, time_steps)
clip_loss = ClipLoss()

probabilities = clip_loss.get_probabilities(estimates, candidates) # --> output in the model's forward function
loss = clip_loss(estimates, candidates) # --> use a loss

print(probabilities)
print(loss)

### Each row in the probability matrix corresponds to a set of EEG data,
#   and each column corresponds to a set of audio data.
#   The values in the matrix are probabilities that indicate
#   how likely it is that a given set of EEG data matches a given set of audio data.

tensor([[0.3387, 0.1995, 0.4617],
        [0.7710, 0.1033, 0.1257],
        [0.1687, 0.6365, 0.1948]])
tensor(1.6630)


In [None]:
# estimates = egg[0], eeg[1], eeg[2]
# candidates = audio[0], audio[1], audio[2]

# egg[0] vs audio[0] - positive pair
# egg[0] vs audio[1] - negative pair
# egg[0] vs audio[2] - negative pair

# egg[0] vs audio[1] - positive pair
# egg[0] vs audio[0] - negative pair
# egg[0] vs audio[2] - negative pair

## brain-to-audio model

In [77]:
class BrainAudioModel(torch.nn.Module):

  def __init__(# brain encoding
               self, in_channels, model_chout

               # audio encoding

               # brain to audio probabilities

               ):

    super().__init__()

    # brain encoding
    self.brain_encoding = EEG_Encoder(in_channels=in_channels,
                                      n_subjects=30,
                                      out_channels=model_chout,
                                      conv_chanels=[64, 32, 64, 128],
                                      kernel=4,
                                      stride=2,
                                      conv_dropout=0,
                                      batch_norm=False,
                                      dropout_input=0,
                                      leakiness=0,
                                      hidden_size=128,
                                      lstm_layers=4,
                                      lstm_dropout=0.1,
                                      attention_heads=4,
                                      subject_dim=64,
                                      embedding_scale=1.0)

    # audio encoding
    #self.audio_encoding = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base-960h") # https://huggingface.co/facebook/wav2vec2-base-960h

    # brain to audio probabilities
    self.clip = ClipLoss(linear=None, twin=True, pool=False, tmin=None, tmax=None,
                         tmin_train=None, tmax_train=None, dset_args=None, center=False)

    # linear layers to match the dimensions
    self.linear1 = nn.Linear(128, model_chout)

  """
  def freeze_audio_embedding(self): # we need to call it after initializing the model model.freeze_audio_embedding()
    for param in self.audio_embedding.parameters():
      param.requires_grad = False
  """

  def forward(self, raw_eeg_inputs, audio_hidden_states):  # raw_audio_inputs

    # brain encoding
    brain_encoder_outputs = self.brain_encoding(raw_eeg_inputs)
    print('brain_encoder_outputs', brain_encoder_outputs.shape)

    print('audio_hidden_states', audio_hidden_states.shape)

    # audio encoding (optional, could be passed as a parameter to forward)
    #audio_encoder_outputs = self.audio_encoding(raw_audio_inputs).last_hidden_state
    #audio_encoder_outputs = self.linear(audio_encoder_outputs)
    audio_hidden_states_processed = self.linear1(audio_hidden_states)
    print('audio_hidden_states_processed', audio_hidden_states_processed.shape)

    # brain to audio probabilities
    brain_encoder_outputs_reshaped = brain_encoder_outputs.permute(0, 2, 1) # reshaping to [B, C, T]
    audio_hidden_states_reshaped = audio_hidden_states_processed.permute(0, 2, 1) # reshaping to [B, C, T]
    probabilities = self.clip.get_probabilities(brain_encoder_outputs_reshaped, audio_hidden_states_reshaped)
    print('probabilities', probabilities)
    return probabilities

# TEST
in_channels = 64  # EEG input channels
model_chout = 17  # output channels
model = BrainAudioModel(in_channels, model_chout)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
eeg_input_size = (2, in_channels, 128)  # Batch size, channels, sequence length
audio_hidden_state_size = (2, model_chout, 128)  # Assuming the same sequence length and batch size
eeg_dummy_input = torch.randn(eeg_input_size).to(device)
audio_hidden_state_dummy = torch.randn(audio_hidden_state_size).to(device)
summary(model, eeg_dummy_input, audio_hidden_state_dummy)

eeg_inputs before convs torch.Size([2, 64, 128])
eeg_inputs after convs torch.Size([2, 128, 17])
After convs: tensor(False)
After LSTM: tensor(True)
After Attention: tensor(True)
Final Output: tensor(True)
brain_encoder_outputs torch.Size([2, 17, 17])
audio_hidden_states torch.Size([2, 17, 128])
audio_hidden_states_processed torch.Size([2, 17, 17])
probabilities tensor([[nan, nan],
        [nan, nan]])
                                                Kernel Shape  Output Shape  \
Layer                                                                        
0_brain_encoding.convs.sequence.0.Conv1d_0       [64, 32, 4]   [2, 32, 65]   
1_brain_encoding.convs.sequence.0.LeakyReLU_1              -   [2, 32, 65]   
2_brain_encoding.convs.sequence.1.Conv1d_0       [32, 64, 4]   [2, 64, 33]   
3_brain_encoding.convs.sequence.1.LeakyReLU_1              -   [2, 64, 33]   
4_brain_encoding.convs.sequence.2.Conv1d_0      [64, 128, 4]  [2, 128, 17]   
5_brain_encoding.convs.sequence.2.LeakyReLU_1   

  df_sum = df.sum()


Unnamed: 0_level_0,Kernel Shape,Output Shape,Params,Mult-Adds
Layer,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
0_brain_encoding.convs.sequence.0.Conv1d_0,"[64, 32, 4]","[2, 32, 65]",8224.0,532480.0
1_brain_encoding.convs.sequence.0.LeakyReLU_1,-,"[2, 32, 65]",,
2_brain_encoding.convs.sequence.1.Conv1d_0,"[32, 64, 4]","[2, 64, 33]",8256.0,270336.0
3_brain_encoding.convs.sequence.1.LeakyReLU_1,-,"[2, 64, 33]",,
4_brain_encoding.convs.sequence.2.Conv1d_0,"[64, 128, 4]","[2, 128, 17]",32896.0,557056.0
5_brain_encoding.convs.sequence.2.LeakyReLU_1,-,"[2, 128, 17]",,
6_brain_encoding.LSTM_lstm,-,"[17, 2, 128]",340480.0,336384.0
7_brain_encoding.attention.Conv1d_content,"[128, 128, 1]","[2, 128, 17]",16512.0,278528.0
8_brain_encoding.attention.Conv1d_query,"[128, 128, 1]","[2, 128, 17]",16512.0,278528.0
9_brain_encoding.attention.Conv1d_key,"[128, 128, 1]","[2, 128, 17]",16512.0,278528.0


# losses and metrics

# trainer


In [None]:
class Trainer:
    def __init__(self, model, loader, optimizer, criterion, scheduler, max_epochs= 1, run_id= 'exp'):

        self.model      = model
        self.loader     = loader
        self.optimizer  = optimizer
        self.criterion  = criterion
        self.scheduler = scheduler

        self.train_losses           = []
        self.val_losses             = []
        self.prediction_probs       = []
        self.prediction_probs_test  = []
        self.generated_texts_test   = []
        self.epochs                 = 0
        self.max_epochs             = max_epochs
        self.run_id                 = run_id


    def calculate_loss(self, out, target):
        # output: (B, T, Vocab_size) - probability distributions
        # target: (B, T)
        # Read the documentation of CrossEntropyLoss and try to understand how it takes inputs

        # Tip: If your target is of shape (B, T) it means that you have B batches with T words.
        # Tip: What is the total number of words in this batch?
        # Tip: Crossentropy calculates the loss between a label and its probability distribution.

        out     = out.view(-1, out.size(-1)) # TODO
        targets = target.view(-1)  # TODO
        loss    = self.criterion(out, targets)

        return loss


    def train(self):

        self.model.train() # set to training mode
        self.model.to(DEVICE)
        epoch_loss  = 0
        num_batches = 0

        for batch_num, (inputs, targets) in enumerate(tqdm(self.loader)):

            # TODO: Complete the loop. You should be able to complete this without any helper comments after 3 HWs
            # Tip: Mixed precision training
            # For loss calculation, use the calculate_loss function. You need to complete it before using.

            inputs, targets = inputs.to(DEVICE), targets.to(DEVICE)

            self.optimizer.zero_grad()

            output, _ = self.model(inputs)
            loss = self.calculate_loss(output, targets)
            loss.backward()
            self.optimizer.step()

            loss_val = loss.item()
            epoch_loss += loss_val
            num_batches += 1

        #epoch_loss = epoch_loss / (batch_num + 1)
        epoch_loss = epoch_loss / num_batches
        self.epochs += 1
        print('[TRAIN] \tEpoch [%d/%d] \tLoss: %.4f \tLr: %.6f'
                      % (self.epochs, self.max_epochs, epoch_loss, self.optimizer.param_groups[0]['lr']))
        self.train_losses.append(epoch_loss)


    def test(self): # Don't change this function

        self.model.eval() # set to eval mode
        prediction_probs     = self.model.predict(fixtures_pred['inp']).detach().cpu().numpy() # get predictions
        self.prediction_probs.append(prediction_probs)

        generated_indexes_test   = self.model.generate(fixtures_gen_test, 10).detach().cpu().numpy() # generated predictions for 10 words

        nll                   = get_prediction_nll(prediction_probs, fixtures_pred['out'])
        generated_texts_test  = make_generation_text(fixtures_gen_test, generated_indexes_test, VOCAB)
        self.val_losses.append(nll)

        self.generated_texts_test.append(generated_texts_test)

        # generate predictions for test data
        prediction_probs_test = self.model.predict(fixtures_pred_test['inp']).detach().cpu().numpy() # get predictions
        self.prediction_probs_test.append(prediction_probs_test)

        print('[VAL] \tEpoch [%d/%d] \tLoss: %.4f'
                      % (self.epochs, self.max_epochs, nll))
        return nll


    def save(self): # Don't change this function

        model_path = os.path.join('hw4/experiments', self.run_id, 'model-{}.pkl'.format(self.epochs))
        torch.save({'state_dict': self.model.state_dict()}, model_path)
        np.save(os.path.join('hw4/experiments', self.run_id, 'prediction-probs-{}.npy'.format(self.epochs)), self.prediction_probs[-1])
        np.save(os.path.join('hw4/experiments', self.run_id, 'prediction-probs-test-{}.npy'.format(self.epochs)), self.prediction_probs_test[-1])

        with open(os.path.join('hw4/experiments', self.run_id, 'generated-texts-{}-test.txt'.format(self.epochs)), 'w') as fw:
            fw.write(self.generated_texts_test[-1])

# experiments

In [None]:
configs = dict(

    # model
    embedding_dim=600,
    hidden_dim=600,
    embedding_dropout=0.3,
    locked_dropout=0.5,

    # loader
    batch_size = 128,
    sequence_length = 30,

    # optimizer
    init_lr = 0.002,
    weight_decay = 5e-3,

    # scheduler
    factor = 0.9,
    patience = 1,

    # trainer
    num_epochs = 40)

In [None]:
model       = LanguageModel(vocab_size=len(VOCAB), embedding_dim=configs['embedding_dim'], hidden_dim=configs['hidden_dim'],
                            embedding_dropout=configs['embedding_dropout'], locked_dropout=configs['locked_dropout'])

loader      = DataLoaderForLanguageModeling(dataset, batch_size=configs['batch_size'], shuffle=True, drop_last=True, sequence_length=configs['sequence_length'])
inputs, targets = next(iter(loader))

criterion   = torch.nn.CrossEntropyLoss()

optimizer   = AdamW(model.parameters(), lr=configs['init_lr'], betas=(0.9, 0.999), eps=1e-8, weight_decay=configs['weight_decay'], amsgrad=False)

scheduler   = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=configs['factor'], patience=configs['patience'], verbose=True)

torchsummaryX.summary(model, inputs) #print(model)

In [89]:
trainer = Trainer(
    model       = model,
    loader      = loader,

    optimizer   = optimizer,
    criterion   = criterion,
    scheduler   = scheduler,

    max_epochs  = configs['num_epochs'],
    run_id      = run_id)

NameError: ignored

# evaluation

## viz