# EEG to Speech Transformer with Conformer as encoder

In [2]:
import os
import sys
import random
import warnings
import subprocess

import math
import pickle
from glob import glob
from functools import partial
from tqdm.notebook import tqdm, trange
from typing import List, Tuple, Optional
from IPython.display import Audio, FileLink, display, clear_output

import librosa
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import torch
import torchaudio
import torch.nn as nn
torch.manual_seed(1337)
from torch.utils.data import Dataset, DataLoader

from wavelets_pytorch.transform import WaveletTransformTorch
from wavelets_pytorch.wavelets import Morlet

from sklearn.decomposition import PCA
# clear_output()

# Configuration

In [None]:
base = '/kaggle/input/internal-speech-recognition/Vartanov/audios'

# Reference audio file names (one variable per stimulus)

# Single phonemes
A = "A.wav"
B = "B.wav"
F = "F.wav"
G = "G.wav"
M = "M.wav"
R = "R.wav"
U = "U.wav"

# Syllables
Ba = "Ba.wav"
Bu = "Bu.wav"
Fa = "Fa.wav"
Fu = "Fu.wav"
Ga = "Ga.wav"
Gu = "Gu.wav"
Ma = "Ma.wav"
Mu = "Mu.wav"
Ra = "Ra.wav"
Ru = "Ru.wav"

# Words (files are named St1..St5)
Biblioteka = "St1.wav"
Raketa = "St2.wav"
Kurier = "St3.wav"
Ograda = "St4.wav"
Haketa = "St5.wav"

# Stimulus order used by the label codes of each experiment
_PHONEMES = (A, B, F, G, M, R, U)
_SYLLABLES = (Ba, Fa, Ga, Ma, Ra, Bu, Ru, Mu, Fu, Gu)
_WORDS = (Biblioteka, Raketa, Kurier, Ograda, Haketa)


def _label_map(folder, files, offsets):
    '''
    Build a {label code -> audio path} dict for one audio sub-folder.
    The i-th file is registered under the code `offset + i` for every offset in `offsets`
    (codes for the same file are inserted next to each other).
    '''
    return {
        offset + i: os.path.join(base, folder, name)
        for i, name in enumerate(files)
        for offset in offsets
    }


phonemes_m3_labels = _label_map("phonemes", _PHONEMES, (12, 22))
phonemes_m4_labels = _label_map("phonemes", _PHONEMES, (1, 11))
syllables_labels = _label_map("syllables", _SYLLABLES, (1, 11))
words_labels = _label_map("words", _WORDS, (11, 21))

sections = ["syllables", "phonemes_m3", "phonemes_m4", "words"]

# section name -> (label code -> audio path)
audio_map = dict(zip(sections, (syllables_labels, phonemes_m3_labels, phonemes_m4_labels, words_labels)))

config = dict(
    path='/kaggle/input/internal-speech-recognition/Vartanov/feather',
    audio_maps=audio_map,

    # Dataset
    fragment_length=2012,
    partition_size=32,
    val_ratio=0.15,
    seed=1337,
    sound_channel=1,

    # EEG
    eeg_sr=1006,
    n_channels=63,
    in_seq_len=1145,

    # Audio
    audio_sr=44100,
    sound_size=50176,

    # STFT parameters
    n_fft=2048,
    hop_size=512,

    # Model: convolution module
    kernel_size=31,
    conv_module_dropout=0.1,

    # Model: positional encoding
    emb_dropout=0.1,

    # Model: transformer
    d_model=512,
    nhead=8,
    num_encoder_layers=6,
    num_decoder_layers=6,
    dim_feedforward=2048,
    dropout=0.1,
    activation='relu',

    # Scheduler
    base_lr=0.2,
    min_lr=1e-5,
)


In [None]:
# Spectrogram -> waveform: Griffin-Lim style phase reconstruction from a magnitude spectrogram

def restore(D: np.ndarray, frame_size=config['n_fft'], hop_length=config['hop_size'], epochs: int = 2,
            window: str = 'hann', rng: Optional[np.random.Generator] = None) -> np.ndarray:
    '''
    Reconstruct a waveform from a spectrogram with Griffin-Lim iterations.

    D: spectrogram of shape (freq_bins, n_frames); only its magnitude is used.
       Presumably freq_bins == 1 + frame_size // 2, so that istft (which infers the FFT size
       from D) and stft (which uses frame_size) agree -- TODO confirm for callers.
    frame_size: FFT size used when re-analysing the intermediate waveform
    hop_length: hop between consecutive STFT frames
    epochs: number of Griffin-Lim refinement iterations (0 = random phase only)
    window: window name passed to librosa stft/istft
    rng: optional numpy Generator for the random initial phase; when None the global
         np.random state is used (the previous behaviour)
    return: 1-D waveform with (n_frames - 1) * hop_length samples
    '''
    # Pad one empty frame on each side; the padded signal covers (n_frames + 1) hops,
    # and the extra hop at each end is trimmed before returning.
    length = (D.shape[1] + 1) * hop_length  # (D.shape[1] - 1 + 2) * hop_length
    empty_frame = np.zeros((D.shape[0], 1))
    D = np.concatenate((empty_frame, D, empty_frame), axis=1)
    mag, _ = librosa.magphase(D)

    # Start from a uniformly random phase
    uniform = np.random.uniform if rng is None else rng.uniform
    phase = np.exp(1.j * uniform(0., 2 * np.pi, size=mag.shape))
    x_ = librosa.istft(mag * phase, hop_length=hop_length, center=True, window=window, length=length)

    # Keep the target magnitude and re-estimate the phase from the current signal
    for _ in range(epochs):
        _, phase = librosa.magphase(librosa.stft(x_, n_fft=frame_size, hop_length=hop_length, center=True,
                                                 window=window))
        x_ = librosa.istft(mag * phase, hop_length=hop_length, center=True, window=window, length=length)

    return x_[hop_length:-hop_length]

# Dataset

In [None]:
class EEGDataset(Dataset):
    '''
    Pairs fixed-length EEG fragments with the reference audio of their stimulus label.

    Every file under `path/<section>/` is read with pandas.read_feather; column 0 holds the
    label and the remaining columns are EEG channels. Each file is split into `partition_size`
    consecutive fragments of `fragment_length` rows, and a fragment's label is read from its
    first row.
    '''
    def __init__(self, path: str, audio_maps: dict, fragment_length: int = 2012, partition_size: int = 32,
                 sample_rate: int = 44100, sound_channel: int = 1, val_ratio: float = 0.15, seed: int = 1337,
                 in_seq_len: Optional[int] = None, sound_size: Optional[int] = None):
        '''
        path: path to sections (folders), one sub-folder per key of `audio_maps`
        audio_maps: two-level map: section names -> labels -> audio_paths
        fragment_length: number of rows taken for one sample, starting at a label row
        partition_size: number of nonzero labels (samples) in each file
        sample_rate: sampling rate the reference audio is resampled to
        sound_channel: index of the audio channel returned
        val_ratio: fraction of each section's files held out for validation
        seed: seed for shuffling sections and files (train/val instances built with the
              same arguments get complementary subsets)
        in_seq_len: EEG time steps per sample; None falls back to config['in_seq_len']
        sound_size: audio samples per sample; None falls back to config['sound_size']
        '''
        super().__init__()
        rnd = random.Random(seed)

        # NOTE(review): os.listdir order is filesystem-dependent, so the seeded shuffle is only
        # reproducible where listdir returns the same order. Sorting first would make it
        # portable, but would also change the current train/val split.
        self.sections = os.listdir(path)
        rnd.shuffle(self.sections)
        assert set(self.sections) == set(audio_maps.keys()), "Sections must be the same!"
        self.audio_maps = audio_maps

        all_paths = []
        for sec in self.sections:
            files = os.listdir(os.path.join(path, sec))
            rnd.shuffle(files)
            all_paths.append([os.path.join(path, sec, file) for file in files])

        # The first `val_ratio` share of every shuffled section is used for validation
        splits = [int(len(sec_paths) * val_ratio) for sec_paths in all_paths]

        self.val_paths = [sec_paths[:split] for sec_paths, split in zip(all_paths, splits)]
        self.paths = [sec_paths[split:] for sec_paths, split in zip(all_paths, splits)]

        # Cumulative sample counts per section, used to map a flat index to a section
        self.sec_num_files = [len(elem) for elem in self.paths]
        self.sec_cumnum = np.cumsum(self.sec_num_files) * partition_size
        self.total_num_files = sum(self.sec_num_files)

        self.sec_num_val_files = [len(elem) for elem in self.val_paths]
        self.sec_val_cumnum = np.cumsum(self.sec_num_val_files) * partition_size
        self.total_num_val_files = sum(self.sec_num_val_files)

        self.partition_size = partition_size
        self.fragment_length = fragment_length
        self.sr = sample_rate
        self.sound_channel = sound_channel
        self.val_mode = False
        self.in_seq_len = in_seq_len
        self.sound_size = sound_size

    def __len__(self) -> int:
        '''Number of samples (fragments) in the active subset.'''
        num = self.total_num_val_files if self.val_mode else self.total_num_files
        return num * self.partition_size

    def set_val_mode(self, mode: bool):
        '''
        Switch between train/val subsets; returns self for chaining.
        '''
        assert mode in [True, False], "Incorrect mode type!"
        self.val_mode = mode
        return self

    def to_section(self, idx: int) -> Tuple[int, int]:
        '''
        Map a flat sample index to (section index, sample index inside that section).
        '''
        cumnum = self.sec_val_cumnum if self.val_mode else self.sec_cumnum
        section = np.where(idx < cumnum)[0][0]
        section_idx = idx if (section == 0) else (idx - cumnum[section - 1])
        return section, section_idx

    def get_audio(self, section: int, label: int) -> torch.Tensor:
        '''
        Load the reference audio for a section index and label, resampled to `self.sr`,
        and return channel `self.sound_channel` of it.
        '''
        section_name = self.sections[section]
        audio, current_sr = torchaudio.load(self.audio_maps[section_name][label])
        audio = torchaudio.functional.resample(audio, orig_freq=current_sr, new_freq=self.sr)
        return audio[self.sound_channel]

    @staticmethod
    def _fit_length(t: torch.Tensor, size: int) -> torch.Tensor:
        '''
        Crop or right-zero-pad `t` along its FIRST dimension to exactly `size` entries.
        '''
        n = t.size(0)
        if n >= size:
            return t[:size]
        # F.pad lists pad amounts starting from the LAST dimension: leave every trailing
        # dimension untouched and pad only the end of dimension 0.
        pad = [0, 0] * (t.dim() - 1) + [0, size - n]
        return nn.functional.pad(t, pad, value=0)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        '''
        idx: flat sample index in [0, len(self))
        return: (EEG fragment of shape (n_channels, in_seq_len), audio of shape (sound_size,))
        '''
        section, section_idx = self.to_section(idx)
        paths_source = self.val_paths if self.val_mode else self.paths
        file_path = paths_source[section][section_idx // self.partition_size]

        start = (section_idx % self.partition_size) * self.fragment_length
        end = start + self.fragment_length

        data = pd.read_feather(file_path).to_numpy()
        x, label = torch.tensor(data[start:end, 1:]), data[start, 0].astype(int)

        audio = self.get_audio(section, label)

        # Cut/pad model inputs along time so that they match the desired sizes
        E = self.in_seq_len if self.in_seq_len is not None else config['in_seq_len']
        S = self.sound_size if self.sound_size is not None else config['sound_size']
        x = self._fit_length(x, E)          # (in_seq_len, n_channels)
        audio = self._fit_length(audio, S)  # (sound_size,)

        x = x.t()  # (n_channels, in_seq_len)

        return x.float(), audio.float()

# Dataloaders

In [None]:
# Both datasets use the same seeded split; the second one is switched to the held-out files
train_ds = EEGDataset(path=config['path'], audio_maps=config['audio_maps'])
val_ds = EEGDataset(path=config['path'], audio_maps=config['audio_maps']).set_val_mode(True)

# DataLoader factories: call with batch_size=... to obtain a loader
train_dl = partial(DataLoader, dataset=train_ds, shuffle=True, num_workers=2)
val_dl = partial(DataLoader, dataset=val_ds, num_workers=2)

batch_size = 8
print(f'Batch size: {batch_size}')

train_len, val_len = len(train_ds), len(val_ds)
print(f'{"Train dataset len:": <20} {train_len};\t{"Validation datset len:": <25} {val_len};')

n_train_batches = len(train_dl(batch_size=batch_size))
n_val_batches = len(val_dl(batch_size=batch_size))
print(f'{"Num train batches:": <20} {n_train_batches};\t{"Num validation batches:": <25} {n_val_batches};')

# Model

In [None]:
class ConformerFeedForward(torch.nn.Module):
    '''
    Conformer feed-forward module:
    LayerNorm -> Linear(d_model, d_ff) -> SiLU -> Dropout -> Linear(d_ff, d_model) -> Dropout.
    No residual connection is applied here.
    '''
    def __init__(self, d_model: int, d_ff: int, dropout: float):
        '''
        :param int d_model: Input dimension
        :param int d_ff: Hidden dimension
        :param float dropout: Dropout probability for linear layers
        '''
        super().__init__()

        self.d_model, self.d_ff = d_model, d_ff

        # Submodules are created in a fixed order: it determines the parameter
        # order and how the global RNG is consumed during initialisation.
        self.layer_norm = nn.LayerNorm(d_model)
        self.linear_1 = nn.Linear(d_model, d_ff)
        self.activation = nn.SiLU()
        self.dropout_1 = nn.Dropout(dropout)
        self.linear_2 = nn.Linear(d_ff, d_model)
        self.dropout_2 = nn.Dropout(dropout)

        self.reset_parameters()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        '''
        :param torch.Tensor x: (batch, time, d_model)
        :return: (batch, time, d_model)
        :rtype: torch.Tensor
        '''
        hidden = self.dropout_1(self.activation(self.linear_1(self.layer_norm(x))))
        return self.dropout_2(self.linear_2(hidden))

    def reset_parameters(self):
        '''Uniform init in [-fan_in^-0.5, fan_in^-0.5]: all weights first, then all biases.'''
        bounded = ((self.linear_1, self.d_model ** -0.5), (self.linear_2, self.d_ff ** -0.5))
        with torch.no_grad():
            for attr in ('weight', 'bias'):
                for layer, limit in bounded:
                    torch.nn.init.uniform_(getattr(layer, attr), -limit, limit)

In [None]:
class ConformerConvolution(torch.nn.Module):
    '''
    Conformer convolution module:
    LayerNorm -> pointwise Conv1d (2x channels) -> GLU -> [zero padded steps] ->
    depthwise Conv1d ('same' padding) -> BatchNorm -> SiLU -> pointwise Conv1d -> Dropout.
    No residual connection is applied here.
    '''
    def __init__(self, d_model: int, kernel_size: int, dropout: float):
        '''
        :param int d_model: Input dimension
        :param int kernel_size: Kernel size of Depthwise Convolution
        :param float dropout: Dropout probability
        '''
        super().__init__()

        self.d_model = d_model
        self.kernel_size = kernel_size

        # Creation order fixes parameter order and RNG consumption
        self.layer_norm = nn.LayerNorm(d_model)
        self.pointwise_conv_1 = nn.Conv1d(d_model, 2 * d_model, kernel_size=1)
        self.activation_1 = nn.GLU()
        self.depthwise_conv = nn.Conv1d(d_model, d_model, kernel_size=kernel_size, groups=d_model, padding='same')
        self.batch_norm = nn.BatchNorm1d(d_model)
        self.activation_2 = nn.SiLU()
        self.pointwise_conv_2 = nn.Conv1d(d_model, d_model, kernel_size=1)
        self.dropout = nn.Dropout(dropout)

        self.reset_parameters()

    def forward(self, x: torch.Tensor, pad_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        '''
        :param torch.Tensor x: (batch, time, d_model)
        :param torch.Tensor pad_mask: (batch, time), True at padded positions
        :return: (batch, time, d_model)
        :rtype: torch.Tensor
        '''
        # Conv1d expects (batch, channels, time); GLU halves the last (channel) axis
        h = self.pointwise_conv_1(self.layer_norm(x).transpose(1, 2)).transpose(1, 2)
        h = self.activation_1(h)

        if pad_mask is not None:
            h = h.masked_fill(pad_mask.unsqueeze(-1), 0.0)

        h = self.activation_2(self.batch_norm(self.depthwise_conv(h.transpose(1, 2))))
        return self.dropout(self.pointwise_conv_2(h).transpose(1, 2))

    def reset_parameters(self):
        '''Uniform init: pointwise convs use d_model^-0.5, depthwise uses kernel_size^-0.5; weights before biases.'''
        pw_limit = self.d_model ** -0.5
        dw_limit = self.kernel_size ** -0.5
        convs = (
            (self.pointwise_conv_1, pw_limit),
            (self.pointwise_conv_2, pw_limit),
            (self.depthwise_conv, dw_limit),
        )
        with torch.no_grad():
            for attr in ('weight', 'bias'):
                for conv, limit in convs:
                    torch.nn.init.uniform_(getattr(conv, attr), -limit, limit)

In [None]:
class RelPositionMultiHeadAttention(torch.nn.Module):
    '''
    Multi-head attention with Transformer-XL style relative positional scores
    (https://arxiv.org/abs/1901.02860, Section 3.3).
    '''
    def __init__(self, d_model: int, n_head: int, dropout: float):
        '''
        :param int d_model: Input dimension (must be divisible by n_head)
        :param int n_head: Number of MHSA heads
        :param float dropout: Dropout probability for attention probabilities
        '''
        super().__init__()

        assert d_model % n_head == 0

        self.n_head = n_head
        self.d_model = d_model
        self.d_k = d_model // n_head

        # Query/key/value projections W_q, W_k, W_v
        self.linear_q = nn.Linear(d_model, d_model)
        self.linear_k = nn.Linear(d_model, d_model)
        self.linear_v = nn.Linear(d_model, d_model)

        # Projection W_{k,R} of the relative positional embeddings (no bias)
        self.linear_pos = nn.Linear(d_model, d_model, bias=False)

        # Learnable global content (u) and position (v) biases used by terms (c) and (d)
        self.pos_bias_u = nn.Parameter(torch.randn(d_model))
        self.pos_bias_v = nn.Parameter(torch.randn(d_model))

        # Dropout on attention probabilities
        self.dropout = nn.Dropout(dropout)

        # Output projection
        self.linear_out = nn.Linear(d_model, d_model)

        self.reset_parameters()

    @staticmethod
    def rel_shift(x: torch.Tensor) -> torch.Tensor:
        '''Relative shift of positional scores.

        When time_y == 2 * time_x - 1, the first time_x columns of output row i equal
        x[..., i, j + time_x - 1 - i]; the remaining columns are filler.
        :param torch.Tensor x: (batch, head, time_x, time_y)
        :return: (batch, head, time_x, time_y)
        :rtype: torch.Tensor
        '''
        batch, head, time_x, _ = x.shape

        # Prepend a zero column, refold into rows of length time_x, drop the first row
        # and unfold back to time_x rows.
        padded = nn.functional.pad(x, (1, 0))
        folded = padded.view(batch, head, -1, time_x)
        return folded[:, :, 1:, :].reshape(batch, head, time_x, -1)

    def forward_qkv(
        self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Project query, key and value and split them into heads.
        :param torch.Tensor query: (batch, time_1, d_model)
        :param torch.Tensor key:   (batch, time_2, d_model)
        :param torch.Tensor value: (batch, time_2, d_model)

        :return: (q, k, v) with shapes (batch, head, time_1, d_k), (batch, head, time_2, d_k)
            and (batch, head, time_2, d_k)
        :rtype: Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
        """
        def split_heads(projection, t):
            # (batch, time, d_model) -> (batch, head, time, d_k)
            return projection(t).view(t.shape[0], t.shape[1], self.n_head, self.d_k).transpose(1, 2)

        return (
            split_heads(self.linear_q, query),
            split_heads(self.linear_k, key),
            split_heads(self.linear_v, value),
        )

    def forward_attention(self, value: torch.Tensor, scores: torch.Tensor, mask: Optional[torch.Tensor]) -> torch.Tensor:
        """Turn scores into probabilities and mix the values.
        :param torch.Tensor value:    (batch, head, time_2, d_k)
        :param torch.Tensor scores:   (batch, head, time_1, time_2)
        :param Optional[torch.Tensor] mask: (batch, time_1, time_2) or (time_1, time_2);
            True marks key positions a query must NOT attend to

        :return: (batch, time_1, d_model) after the output projection
        :rtype: torch.Tensor
        """
        if mask is None:
            attn = torch.softmax(scores, dim=-1)
        else:
            if mask.ndim == 2:
                mask = mask.unsqueeze(0)
            head_mask = mask.unsqueeze(1)  # broadcast over heads
            # A large finite negative instead of -inf; -1e4 stays representable in half precision
            fill_value = -1e+30 if scores.dtype == torch.float32 else -1e+4
            scores = scores.masked_fill(head_mask, fill_value)
            # Zero masked probabilities explicitly (rows that are fully masked would otherwise be uniform)
            attn = torch.softmax(scores, dim=-1).masked_fill(head_mask, 0)

        attn = self.dropout(attn)

        # (batch, head, time_1, d_k) -> (batch, time_1, head * d_k)
        context = (attn @ value).transpose(1, 2)
        return self.linear_out(context.reshape(scores.shape[0], scores.shape[2], -1))

    def forward(
        self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor,
        mask: Optional[torch.Tensor], pos_emb: torch.Tensor
    ) -> torch.Tensor:
        '''Scaled dot-product attention with relative positional scores.
        :param torch.Tensor query:          (batch, time_1, d_model)
        :param torch.Tensor key:            (batch, time_2, d_model)
        :param torch.Tensor value:          (batch, time_2, d_model)
        :param Optional[torch.Tensor] mask: (batch, time_1, time_2); True marks key positions
            a query must NOT attend to
        :param torch.Tensor pos_emb:        (batch or 1, 2*time_2-1, d_model) relative positional
            embeddings for all possible values of i - j

        :return: (batch, time_1, d_model)
        :rtype: torch.Tensor
        '''
        pos_emb = self.linear_pos(pos_emb)
        q, k, v = self.forward_qkv(query, key, value)

        # (W_q x_i + u) for the content term, (W_q x_i + v) for the position term
        bias_shape = (1, -1, 1, self.d_k)
        q_with_u = q + self.pos_bias_u.view(bias_shape)
        q_with_v = q + self.pos_bias_v.view(bias_shape)

        # Terms (a) + (c): (batch, head, time_1, time_2)
        matrix_ac = q_with_u @ k.transpose(-2, -1)

        # Terms (b) + (d): (batch, head, time_1, 2*time_2 - 1), shifted and cropped to time_2
        pos_heads = pos_emb.view(pos_emb.shape[0], pos_emb.shape[1], -1, self.d_k).permute(0, 2, 3, 1)
        matrix_bd = self.rel_shift(q_with_v @ pos_heads)[..., :matrix_ac.shape[-1]]

        scores = (matrix_ac + matrix_bd) / math.sqrt(self.d_k)  # (batch, head, time_1, time_2)

        return self.forward_attention(v, scores, mask)

    def reset_parameters(self):
        '''Uniform init in [-d_model^-0.5, d_model^-0.5]: weights first, then biases.
        pos_bias_u / pos_bias_v keep their torch.randn initialisation.'''
        limit = self.d_model ** -0.5
        with_weight = (self.linear_q, self.linear_k, self.linear_v, self.linear_out, self.linear_pos)
        with_bias = (self.linear_q, self.linear_k, self.linear_v, self.linear_out)
        with torch.no_grad():
            for layer in with_weight:
                torch.nn.init.uniform_(layer.weight, -limit, limit)
            for layer in with_bias:
                torch.nn.init.uniform_(layer.bias, -limit, limit)

In [None]:
class RelPositionalEncoding(torch.nn.Module):
    '''Relative positional encoding for TransformerXL's layers
    See : Appendix B in https://arxiv.org/abs/1901.02860
    '''

    def __init__(self, d_model, dropout, max_len=5000, xscale=False, dropout_emb=0.0):
        '''Construct an RelPositionalEncoding object.
        :param int d_model: Embedding dim (odd values are supported)
        :param float dropout: Dropout probability for input embeddings
        :param int max_len: Maximum input length
        :param bool xscale: Whether to scale the input by sqrt(d_model)
        :param float dropout_emb: Dropout probability for positional embeddings
        '''
        super().__init__()

        self.d_model = d_model
        self.xscale = xscale

        # Create Dropout layer for input embeddings
        self.dropout = nn.Dropout(dropout)

        # Create Dropout layer for positional embeddings
        self.dropout_emb = nn.Dropout(dropout_emb)

        # Positions go from positive to negative: (max_len - 1), ..., 0, ..., -(max_len - 1).
        # Positive positions are used for left context, negative for right context.
        positions = torch.arange(max_len-1, -max_len, -1)
        self.create_pe(positions)

    def create_pe(self, positions: torch.Tensor):
        '''Compute sinusoidal encodings for the given positions and store them in buffer `pe`.
        :attr torch.Tensor pe: (1, pos_length, d_model)
        :param torch.Tensor positions: (pos_length)
        '''
        pos_length = positions.size(0)

        # Sinusoidal encoding as described in https://arxiv.org/abs/1706.03762 Section 3.5
        pe = torch.zeros(pos_length, self.d_model, requires_grad=False)
        div_term = torch.exp(torch.arange(0, self.d_model, 2).float() * -(math.log(10000.0) / self.d_model))
        angles = positions.unsqueeze(1) * div_term
        pe[:, 0::2] = torch.sin(angles)
        # For an odd d_model there is one fewer cosine column than sine column
        pe[:, 1::2] = torch.cos(angles[:, : self.d_model // 2])
        pe = pe.unsqueeze(0)

        # Save precomputed positional embeddings
        self.register_buffer('pe', pe)

    def forward(self, x):
        '''Compute positional encoding.
        :param torch.Tensor x: Input of size(batch, time, feature_size)

        :return Tuple[torch.Tensor, torch.Tensor]: (x, pos_emb):
            torch.Tensor x: (batch, time, feature_size)
            torch.Tensor pos_emb: (1, 2*time-1, feature_size)
        :raises ValueError: if time exceeds the max_len the encoding was built for
        '''

        # Rescale input
        if self.xscale:
            x = x * (self.d_model ** 0.5)

        # Apply embeddings dropout
        x = self.dropout(x)

        # center_pos is the index of position 0. For an input of length L, the 2*L-1
        # positions (L-1) .. -(L-1) are needed, i.e. L-1 entries on each side of the centre.
        time = x.size(1)
        center_pos = self.pe.size(1) // 2
        if time - 1 > center_pos:
            # Otherwise start_pos would be negative and slicing would silently wrap around
            raise ValueError(
                f"Input length {time} exceeds the maximum length {center_pos + 1} "
                f"supported by RelPositionalEncoding (max_len)"
            )
        start_pos = center_pos - (time-1)

        pos_emb = self.pe[:, start_pos:start_pos+2*time-1, :]

        # Apply positional embeddings dropout
        pos_emb = self.dropout_emb(pos_emb.to(x.device))

        return x, pos_emb

In [None]:
class ConformerEncoderBlock(torch.nn.Module):
    '''
    One Conformer block (macaron style):
    x + FFN/2 -> x + MHSA -> x + Conv -> x + FFN/2 -> LayerNorm.
    '''
    def __init__(self, d_model: int, d_ff: int, n_heads: int, kernel_size: int, dropout: float, dropout_att: float):
        """
        :param int d_model: Input dimension
        :param int d_ff: Hidden dimension for Feed Forward Module
        :param int n_heads: Number of MHSA heads
        :param int kernel_size: Kernel size of Depthwise Convolution
        :param float dropout: Dropout probability for Feed Forward and Convolution Modules
        :param float dropout_att: Dropout probability for attention probabilities
            (also applied to the attention output before the residual sum)
        """
        super().__init__()

        # Half-step weight for both feed-forward residuals
        self.fc_factor = 0.5

        # Creation order fixes parameter order and RNG consumption
        self.feed_forward_1 = ConformerFeedForward(d_model=d_model, d_ff=d_ff, dropout=dropout)

        self.layer_norm_attn = nn.LayerNorm(d_model)
        self.self_attn = RelPositionMultiHeadAttention(d_model=d_model, n_head=n_heads, dropout=dropout_att)
        self.dropout_attn = nn.Dropout(dropout_att)

        self.conv = ConformerConvolution(d_model=d_model, kernel_size=kernel_size, dropout=dropout)

        self.feed_forward_2 = ConformerFeedForward(d_model=d_model, d_ff=d_ff, dropout=dropout)

        self.layer_norm_out = nn.LayerNorm(d_model)

    def forward(
        self, x: torch.Tensor, pos_emb: torch.Tensor,
        att_mask: Optional[torch.Tensor] = None, pad_mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        :param torch.Tensor x: (batch, time, d_model) input features
        :param torch.Tensor pos_emb: (batch or 1, 2*time-1, d_model) relative positional embeddings
            for all possible values of i - j
        :param Optional[torch.Tensor] att_mask: (batch, time, time); True marks key positions
            a query must NOT attend to
        :param Optional[torch.Tensor] pad_mask: (batch, time); True marks padded positions
        :return: (batch, time, d_model)
        :rtype: torch.Tensor
        """
        # Feed-forward half-step
        x = x + self.feed_forward_1(x) * self.fc_factor

        # Self-attention (pre-norm) with residual
        normed = self.layer_norm_attn(x)
        attended = self.self_attn(normed, normed, normed, att_mask, pos_emb)
        x = x + self.dropout_attn(attended)

        # Convolution module (normalises internally) with residual
        x = x + self.conv(x, pad_mask)

        # Second feed-forward half-step, then output normalisation
        x = x + self.feed_forward_2(x) * self.fc_factor
        return self.layer_norm_out(x)

In [None]:
class ConformerEncoder(torch.nn.Module):
    '''
    Relative positional encoding followed by a stack of Conformer blocks.
    '''
    def __init__(
        self, n_layers: int, d_model: int, d_ff: int, n_heads: int,  kernel_size: int,
        max_len: int, xscale: bool, dropout_emb: float, dropout: float, dropout_att: float
    ):
        '''
        :param int n_layers: Number of Conformer Blocks
        :param int d_model: Input dimension
        :param int d_ff: Hidden dimension for Feed Forward Module
        :param int n_heads: Number of MHSA heads
        :param int kernel_size: Kernel size of Depthwise Convolution
        :param int max_len: Maximum input length
        :param bool xscale: Whether to scale the input by sqrt(d_model)
        :param float dropout_emb: Dropout probability for positional embeddings
        :param float dropout: Dropout probability for input embeddings, Feed Forward and Convolution Modules
        :param float dropout_att: Dropout probability for attention probabilities
        '''
        super().__init__()

        self.encoding = RelPositionalEncoding(d_model=d_model, dropout=dropout, max_len=max_len, xscale=xscale, dropout_emb=dropout_emb)
        self.layers = nn.ModuleList([ConformerEncoderBlock(d_model=d_model, d_ff=d_ff, n_heads=n_heads, kernel_size=kernel_size,
                                                           dropout=dropout, dropout_att=dropout_att) for _ in range(n_layers)])

    @staticmethod
    def _create_masks(max_length: int, length: torch.Tensor) -> Tuple[torch.BoolTensor, torch.BoolTensor]:
        '''
        :param int max_length: Maximum size of time dimension in the batch
        :param torch.Tensor length: (batch) length of sequences in batch
        :return: (pad_mask, att_mask):
            torch.BoolTensor pad_mask: (batch, max_length)
                Takes True value for the positions corresponding to the padding
            torch.BoolTensor att_mask: (batch, max_length, max_length)
                Takes True value for the positions corresponding to which tokens should NOT be attended to
            Where max_length is a size of time dimension of the batch
        :rtype: Tuple[torch.BoolTensor, torch.BoolTensor]
        '''
        # Position j of sequence i is padding iff j >= length[i]
        positions = torch.arange(max_length, device=length.device)
        pad_mask = positions.unsqueeze(0) >= length.unsqueeze(1)

        # A (query, key) pair is masked when either side is padding
        att_mask = pad_mask.unsqueeze(1) | pad_mask.unsqueeze(2)

        return pad_mask, att_mask

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        '''
        :param torch.Tensor x: (batch, time, d_model) input features
        :return: (batch, time, d_model) encoded features
        :rtype: torch.Tensor
        '''
        # We do not use masks since all inputs are of the same size;
        # _create_masks(x.size(1), length) would build them for variable-length batches.

        # Encode input features
        x, enc = self.encoding(x)

        # Apply Conformer Blocks
        for conf_block in self.layers:
            x = conf_block(x, enc)

        return x

In [None]:
class ConformerDecoderBlock(torch.nn.Module):
    """Decoder block used by ConformerDecoder.

    Sub-layers, each added back through a residual connection:
      1. Feed Forward, scaled by ``fc_factor``
      2. LayerNorm -> causal (look-ahead masked) relative-position self-attention over the target -> dropout
      3. LayerNorm -> attention with the target as queries and ``memory`` as keys/values -> dropout
      4. Feed Forward, scaled by ``fc_factor``
    followed by a final LayerNorm.
    """
    def __init__(self, d_model: int, d_ff: int, n_heads: int, dropout: float, dropout_att: float):
        """
        :param int d_model: Input dimension
        :param int d_ff: Hidden dimension for Feed Forward Module
        :param int n_heads: Number of MHSA heads
        :param float dropout: Dropout probability for the Feed Forward Modules
        :param float dropout_att: Dropout probability used inside both attention modules and for the
            dropout applied to their outputs before the residual addition
        """
        super().__init__()

        # Half-step weight for the residual branches of the two Feed Forward modules
        self.fc_factor = 0.5

        self.feed_forward_1 = ConformerFeedForward(d_model=d_model, d_ff=d_ff, dropout=dropout)

        self.layer_norm_attn1 = nn.LayerNorm(d_model)
        self.self_attn1 = RelPositionMultiHeadAttention(d_model=d_model, n_head=n_heads, dropout=dropout_att)
        self.dropout_attn1 = nn.Dropout(dropout_att)

        self.layer_norm_attn2 = nn.LayerNorm(d_model)
        self.self_attn2 = RelPositionMultiHeadAttention(d_model=d_model, n_head=n_heads, dropout=dropout_att)
        self.dropout_attn2 = nn.Dropout(dropout_att)

        self.feed_forward_2 = ConformerFeedForward(d_model=d_model, d_ff=d_ff, dropout=dropout)

        self.layer_norm_out = nn.LayerNorm(d_model)

    def forward(
        self, memory: torch.Tensor, target_input: torch.Tensor, pos_emb: torch.Tensor
    ) -> torch.Tensor:
        """
        :param torch.Tensor memory: (batch, time_x, d_model) encoder output
        :param torch.Tensor target_input: (batch, time_y, d_model) input features
        :param torch.Tensor pos_emb: relative positional embeddings produced for the target sequence
            (presumably (batch, 2*time_y-1, d_model) — TODO confirm); passed to both attention sub-blocks
        :return: (batch, time_y, d_model)
        :rtype: torch.Tensor
        """
        time_y = target_input.size(1)

        # Apply first Feed Forward Block with residual connection
        x = target_input + self.feed_forward_1(target_input) * self.fc_factor

        # Causal mask (time_y, time_y): True strictly above the diagonal, i.e. position i must NOT
        # attend to any j > i. Built directly as bool instead of converting a float -inf mask.
        lookahead_mask = torch.triu(
            torch.ones(time_y, time_y, dtype=torch.bool, device=target_input.device), diagonal=1
        )

        # Apply masked self-attention over the target with residual connection
        mhsa = self.layer_norm_attn1(x)
        mhsa = self.self_attn1(mhsa, mhsa, mhsa, mask=lookahead_mask, pos_emb=pos_emb)
        mhsa = self.dropout_attn1(mhsa)
        x = x + mhsa

        # Apply attention to the encoder output with residual connection.
        # NOTE(review): pos_emb is derived from the target length, while keys/values here have the
        # memory length time_x; verify RelPositionMultiHeadAttention handles time_x != time_y.
        mhsa = self.layer_norm_attn2(x)
        mhsa = self.self_attn2(mhsa, memory, memory, mask=None, pos_emb=pos_emb)
        mhsa = self.dropout_attn2(mhsa)
        x = x + mhsa

        # Apply second Feed Forward Block with residual connection
        x = x + self.feed_forward_2(x) * self.fc_factor
        x = self.layer_norm_out(x)

        return x


In [None]:
class ConformerDecoder(torch.nn.Module):
    """Relative positional encoding of the target followed by a stack of ConformerDecoderBlocks."""

    def __init__(
        self, n_layers: int, d_model: int, d_ff: int, n_heads: int, max_len: int,
        xscale: bool, dropout_emb: float, dropout: float, dropout_att: float
    ):
        '''
        :param int n_layers: number of ConformerDecoderBlocks to stack
        :param int d_model: model / input dimension
        :param int d_ff: hidden size of the Feed Forward modules
        :param int n_heads: number of attention heads per block
        :param int max_len: maximum target length supported by the positional encoding
        :param bool xscale: whether the positional encoding scales its input by sqrt(d_model)
        :param float dropout_emb: dropout on the positional embeddings
        :param float dropout: dropout passed to the positional encoding and to the blocks
        :param float dropout_att: attention dropout passed to the blocks
        '''
        super().__init__()

        self.encoding = RelPositionalEncoding(
            d_model=d_model,
            dropout=dropout,
            max_len=max_len,
            xscale=xscale,
            dropout_emb=dropout_emb,
        )

        blocks = [
            ConformerDecoderBlock(
                d_model=d_model,
                d_ff=d_ff,
                n_heads=n_heads,
                dropout=dropout,
                dropout_att=dropout_att,
            )
            for _ in range(n_layers)
        ]
        self.layers = nn.ModuleList(blocks)

    def forward(self, memory: torch.Tensor, target_input: torch.Tensor) -> torch.Tensor:
        '''
        Positionally encode the target, then run it through every decoder block against ``memory``.

        :param torch.Tensor memory: (batch, time_x, d_model) encoder output
        :param torch.Tensor target_input: (batch, time_y, d_model) input features
        :return: output of the last decoder block
        :rtype: torch.Tensor
        '''
        hidden, pos_emb = self.encoding(target_input)

        for block in self.layers:
            hidden = block(memory, hidden, pos_emb)

        return hidden

In [None]:
class E2SConformer(nn.Module):
    """EEG-to-speech encoder-decoder built from a ConformerEncoder and a ConformerDecoder.

    * Source: EEG (batch_size, n_channels, in_seq_len) is normalized over channels with a LayerNorm
      and projected to ``d_model`` with a 1x1 convolution.
    * Target: audio is turned into magnitude-STFT frames and projected onto ``d_model`` PCA
      components fitted on ``audio_paths`` when the model is constructed.
    * Learned <sos>/<eos> vectors are prepended/appended to both sequences.
    """

    def __init__(self,
                 n_fft: int, hop_size: int, d_model: int,
                 eeg_sr: int, audio_sr: int, n_channels: int,

                 # Conformer Encoder
                 num_encoder_layers: int, d_ff: int,
                 n_heads: int, kernel_size: int,
                 in_seq_len: int, xscale: bool,
                 dropout_emb: float, dropout: float, dropout_att: float,

                 # Conformer Decoder
                 num_decoder_layers: int, out_seq_len: int,

                 audio_paths: List[str]):
        """
        :param int n_fft: FFT size of the STFT applied to audio
        :param int hop_size: hop length of the STFT applied to audio
        :param int d_model: model dimension; also the number of PCA components kept
        :param int eeg_sr: EEG sampling rate (stored on the instance; not used inside this class)
        :param int audio_sr: audio sampling rate; PCA audio with a different rate is resampled to it
        :param int n_channels: number of EEG channels
        :param int num_encoder_layers: number of encoder blocks
        :param int d_ff: hidden dimension of the Feed Forward modules (encoder and decoder)
        :param int n_heads: number of attention heads (encoder and decoder)
        :param int kernel_size: kernel size passed to ConformerEncoder
        :param int in_seq_len: EEG length; the encoder's positional encoding is sized for in_seq_len + 2 (<sos>, <eos>)
        :param bool xscale: whether the positional encodings scale the input by sqrt(d_model)
        :param float dropout_emb: dropout probability for positional embeddings
        :param float dropout: dropout probability for Feed Forward and Convolution modules
        :param float dropout_att: dropout probability for attention
        :param int num_decoder_layers: number of decoder blocks
        :param int out_seq_len: number of target frames; the decoder's positional encoding is sized for out_seq_len + 2
        :param List[str] audio_paths: list of audio file paths to fit PCA on
        """
        super().__init__()

        self.n_fft = n_fft
        self.hop_size = hop_size
        self.d_model = d_model
        self.eeg_sr = eeg_sr
        self.audio_sr = audio_sr

        self.ln = nn.LayerNorm(n_channels)
        self.pointwise = nn.Conv1d(n_channels, d_model, kernel_size=1)
        self.encoder = ConformerEncoder(
            n_layers=num_encoder_layers,
            d_model=d_model,
            d_ff=d_ff,
            n_heads=n_heads,
            kernel_size=kernel_size,
            max_len=in_seq_len+2,
            xscale=xscale,
            dropout_emb=dropout_emb,
            dropout=dropout,
            dropout_att=dropout_att
        )
        self.decoder = ConformerDecoder(
            n_layers=num_decoder_layers,
            d_model=d_model,
            d_ff=d_ff,
            n_heads=n_heads,
            max_len=out_seq_len+2,
            xscale=xscale,
            dropout_emb=dropout_emb,
            dropout=dropout,
            dropout_att=dropout_att
        )

        self.compute_pca_components(audio_paths)

        # Specials: learned <sos>/<eos> vectors for the source and target sequences
        self.src_sos = nn.Parameter(torch.Tensor(1, 1, self.d_model))
        self.src_eos = nn.Parameter(torch.Tensor(1, 1, self.d_model))
        self.tgt_sos = nn.Parameter(torch.Tensor(1, 1, self.d_model))
        self.tgt_eos = nn.Parameter(torch.Tensor(1, 1, self.d_model))

        self.reset_parameters()

    def reset_parameters(self):
        """Initialize the special tokens uniformly in [-d_model**-0.5, d_model**-0.5]."""
        pw_max = self.d_model ** -0.5
        with torch.no_grad():
            torch.nn.init.uniform_(self.src_sos, -pw_max, pw_max)
            torch.nn.init.uniform_(self.src_eos, -pw_max, pw_max)
            torch.nn.init.uniform_(self.tgt_sos, -pw_max, pw_max)
            torch.nn.init.uniform_(self.tgt_eos, -pw_max, pw_max)

    def compute_pca_components(self, audio_paths):
        """
        Fit PCA on the magnitude-STFT frames of the given audio files and register the result as
        buffers ``components`` (d_model, n_freq_bins) and ``mean`` (n_freq_bins,).
        Only the first channel of every file is used; all files are concatenated before the STFT.

        :param List[str] audio_paths: list of audio file paths to fit PCA on
        """
        all_audios = []
        for path in audio_paths:
            audio, sr = torchaudio.load(path)  # (n_audio_channels, n_samples)
            if sr != self.audio_sr:
                audio = torchaudio.functional.resample(waveform=audio, orig_freq=sr, new_freq=self.audio_sr)
            # Take the first channel in both branches so every file contributes a 1-D signal;
            # otherwise files already at audio_sr stay 2-D and cannot be concatenated with the rest.
            all_audios.append(audio[0])

        all_audios = torch.cat(all_audios)  # (total_samples,)
        spec = torch.stft(all_audios, n_fft=self.n_fft, hop_length=self.hop_size, return_complex=True)  # (n_freq_bins, n_frames)
        frames = torch.abs(spec).t().numpy()  # (n_frames, n_freq_bins): one PCA sample per frame

        pca = PCA(n_components=self.d_model)
        pca.fit(frames)

        # float32 so the projection in prepare_tgt matches the dtype of float32 STFT magnitudes
        components = torch.tensor(pca.components_, dtype=torch.float32)  # (d_model, n_freq_bins)
        mean = torch.tensor(pca.mean_, dtype=torch.float32)  # (n_freq_bins)

        self.register_buffer('components', components)
        self.register_buffer('mean', mean)

    def prepare_src(self, x):
        """
        :param torch.tensor x: input of shape (batch_size, n_channels, in_seq_len)
        :rtype torch.tensor
        :return out of shape (batch_size, in_seq_len, d_model)
        """

        # LayerNorm over channels, then a 1x1 convolution channels -> d_model
        out = x.permute(0, 2, 1)  # (batch_size, in_seq_len, n_channels)
        out = self.ln(out)  # (batch_size, in_seq_len, n_channels)
        out = out.permute(0, 2, 1)  # (batch_size, n_channels, in_seq_len)
        out = self.pointwise(out)  # (batch_size, d_model, in_seq_len)

        return out.permute(0, 2, 1)

    def prepare_tgt(self, x):  # Add some audio normalization???
        """
        Project the magnitude STFT of ``x`` onto the fitted PCA components.

        :param torch.tensor x: input of shape (batch_size, audio_len)
        :rtype torch.tensor
        :return out of shape (batch_size, out_seq_len, d_model)
        """
        # STFT (no window argument is passed, same as in compute_pca_components)
        out = torch.stft(x, n_fft=self.n_fft, hop_length=self.hop_size, return_complex=True)  # (batch_size, n_freq_bins, out_seq_len)
        out = torch.abs(out.permute(0, 2, 1))  # (batch_size, out_seq_len, n_freq_bins)

        # PCA
        out = out - self.mean
        out = out @ self.components.t()  # (batch_size, out_seq_len, d_model)
        return out

    def forward(self, eeg, audio):
        """
        Teacher-forced forward pass.

        :param torch.tensor eeg: input of shape (batch_size, n_channels, in_seq_len)
        :param torch.tensor audio: waveform of shape (batch_size, audio_len)
        :rtype Tuple[torch.tensor, torch.tensor]
        :return (out, tgt_output):
            out: decoder output for the inputs <sos>, frame_1, ..., frame_n
                (presumably (batch_size, out_seq_len + 1, d_model))
            tgt_output: PCA-projected frame_1, ..., frame_n, <eos> of shape (batch_size, out_seq_len + 1, d_model)
            where out_seq_len is the number of STFT frames of ``audio``
        """
        batch_size = eeg.size(0)
        src = self.prepare_src(eeg)  # (batch_size, in_seq_len, d_model)
        tgt = self.prepare_tgt(audio)  # (batch_size, out_seq_len, d_model)

        # Add <sos> and <eos>
        src = torch.cat((self.src_sos.repeat(batch_size, 1, 1), src, self.src_eos.repeat(batch_size, 1, 1)),
                        dim=1)  # (batch_size, 1 + in_seq_len + 1, d_model)
        tgt = torch.cat((self.tgt_sos.repeat(batch_size, 1, 1), tgt, self.tgt_eos.repeat(batch_size, 1, 1)),
                        dim=1)  # (batch_size, 1 + out_seq_len + 1, d_model)

        # tgt_input <sos>, token_1, token_2, ..., token_n
        tgt_input = tgt[:, :-1, :]  # (batch_size, 1 + out_seq_len, d_model)

        # tgt_output token_1, token_2, ..., token_n, <eos>
        tgt_output = tgt[:, 1:, :]  # (batch_size, out_seq_len + 1, d_model)

        memory = self.encoder(src)
        out = self.decoder(memory, tgt_input)

        return out, tgt_output

    def predict(self, eeg, out_seq_len):
        """
        Autoregressively generate ``out_seq_len`` target vectors, starting from <sos>.

        The model is moved to ``eeg``'s device (and stays there). It runs in eval mode during
        generation, and its previous train/eval mode is restored afterwards. The caller's ``eeg``
        tensor is not modified.

        :param torch.tensor eeg: input of shape ([batch_size], n_channels, in_seq_len)
        :param int out_seq_len: output sequence length
        :rtype torch.tensor
        :return predicted_encoding of shape (batch_size, out_seq_len, d_model)
        """
        device = eeg.device
        if eeg.ndim == 2:
            # Out-of-place, so the caller's tensor keeps its shape.
            eeg = eeg.unsqueeze(0)

        was_training = self.training
        self.eval().to(device)
        try:
            with torch.no_grad():
                src = self.prepare_src(eeg)  # (batch_size, in_seq_len, d_model)
                src = torch.cat((self.src_sos.repeat(src.size(0), 1, 1), src, self.src_eos.repeat(src.size(0), 1, 1)),
                                dim=1)  # (batch_size, 1 + in_seq_len + 1, d_model)
                memory = self.encoder(src)

                pred = self.tgt_sos.repeat(eeg.size(0), 1, 1).to(device)  # (batch_size, 1, d_model)
                for _ in range(out_seq_len):
                    # Last decoder position is the prediction for the next step: (batch_size, d_model)
                    new_window = self.decoder(memory, pred)[:, -1, :]
                    pred = torch.cat((pred, new_window.unsqueeze(1)), dim=1)
        finally:
            # Restore the caller's mode instead of unconditionally switching to training mode.
            self.train(was_training)

        # Drop the leading <sos>
        return pred[:, 1:, :]

# Noam Annealing

In [None]:

class NoamAnnealing(torch.optim.lr_scheduler._LRScheduler):
    """Noam learning-rate schedule (https://arxiv.org/abs/1706.03762, Section 5.3).

    lr = d_model**-0.5 * min(step**-0.5, step * warmup_steps**-1.5) * initial_lr

    The rate grows linearly for ``warmup_steps`` steps, then decays as step**-0.5. After warmup it is
    bounded below by ``min_lr``.
    """

    def __init__(
        self, optimizer: torch.optim.Optimizer, *,
        d_model: int, warmup_steps: int, min_lr: float = 0.0, last_epoch: int = -1
    ):
        """
        :param torch.optim.Optimizer optimizer: optimizer whose param groups are scheduled
        :param int d_model: Model input dimension (gives the d_model**-0.5 normalization)
        :param int warmup_steps: number of linear-warmup steps; must be non-zero
        :param float min_lr: Lower bound for learning rate after warmup
        :param int last_epoch: index of the last step, -1 to start from scratch
        """
        assert warmup_steps

        # Every attribute used by get_lr() is set before the parent constructor runs, since the
        # base scheduler may compute learning rates while it is being constructed.
        self.warmup_steps = warmup_steps
        self.min_lr = min_lr
        self.normalization = d_model ** (-0.5)

        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        """Return the scheduled learning rate for every param group at the current step."""
        if not self._get_lr_called_within_step:
            warnings.warn(
                "To get the last learning rate computed by the scheduler, please use `get_last_lr()`.", UserWarning
            )

        # Step 0 is treated as step 1 to avoid 0 ** -0.5
        current_step = max(1, self.last_epoch)
        return [self._noam_annealing(initial_lr=base_lr, step=current_step) for base_lr in self.base_lrs]

    def _noam_annealing(self, initial_lr: float, step: int) -> float:
        """Compute the Noam learning rate for one param group.

        :param float initial_lr: Additional multiplicative factor for learning rate
        :param int step: Current optimization step
        :return: Learning rate at given step, clipped from below by min_lr once warmup is over
        :rtype: float
        """
        decay_term = step ** (-0.5)
        warmup_term = step * self.warmup_steps ** (-1.5)
        lrate = self.normalization * min(decay_term, warmup_term) * initial_lr

        still_warming_up = step <= self.warmup_steps
        return lrate if still_warming_up else max(self.min_lr, lrate)

## Sanity checks

In [None]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Build the model; audio_paths are the files the target PCA basis is fitted on.
model = E2SConformer(
    n_fft=2048,
    hop_size=512,
    d_model=512,
    eeg_sr=1006,
    audio_sr=44100,
    n_channels=63,
    num_encoder_layers=6,
    num_decoder_layers=6,
    d_ff=1024,
    n_heads=4,
    kernel_size=31,
    in_seq_len=1145,
    xscale=True,
    dropout_emb=0.1,
    dropout=0.1,
    dropout_att=0.1,
    out_seq_len=99,
    audio_paths=[
        '/kaggle/input/internal-speech-recognition/Vartanov/audios/syllables/Bu.wav',
        '/kaggle/input/internal-speech-recognition/Vartanov/audios/syllables/Fa.wav',
        '/kaggle/input/internal-speech-recognition/Vartanov/audios/syllables/Mu.wav',
        '/kaggle/input/internal-speech-recognition/Vartanov/audios/syllables/Ga.wav',
        '/kaggle/input/internal-speech-recognition/Vartanov/audios/syllables/Ba.wav',
        '/kaggle/input/internal-speech-recognition/Vartanov/audios/syllables/Ra.wav',
        '/kaggle/input/internal-speech-recognition/Vartanov/audios/syllables/Ma.wav',
        '/kaggle/input/internal-speech-recognition/Vartanov/audios/syllables/Ru.wav',
        '/kaggle/input/internal-speech-recognition/Vartanov/audios/syllables/Gu.wav',
        '/kaggle/input/internal-speech-recognition/Vartanov/audios/syllables/Fu.wav',
        '/kaggle/input/internal-speech-recognition/Vartanov/audios/words/St2.wav',
        '/kaggle/input/internal-speech-recognition/Vartanov/audios/words/St5.wav',
        '/kaggle/input/internal-speech-recognition/Vartanov/audios/words/St4.wav',
        '/kaggle/input/internal-speech-recognition/Vartanov/audios/words/St3.wav',
        '/kaggle/input/internal-speech-recognition/Vartanov/audios/words/St1.wav'
    ]
).to(device)

# One training example with a batch dimension added
eeg, audio = train_ds[15296]
eeg, audio = eeg.unsqueeze(0).to(device), audio.unsqueeze(0).to(device)

with torch.no_grad():
    # No gradients are needed for this check, so no autograd graph is built through the model.
    out, tgt = model(eeg, audio)

    # Invert the PCA projection of the target (dropping the trailing <eos>)
    restored = tgt[:, :-1, :].squeeze()  # (out_seq_len, d_model)
    restored = restored @ model.components  # (out_seq_len, n_freq_bins)
    restored = (restored + model.mean).t().cpu().numpy()  # (n_freq_bins, out_seq_len)

# `restore` and `train_ds` are defined elsewhere in the notebook
restored = restore(restored)
Audio(restored, rate=train_ds.sr)