<a href="https://www.kaggle.com/code/neuralsrg/speech-reconstruction-lstm?scriptVersionId=138770538" target="_blank"><img align="left" alt="Kaggle" title="Open in Kaggle" src="https://kaggle.com/static/images/open-in-kaggle.svg"></a>

# Ideas

### Artefact removal

* https://www.intechopen.com/chapters/74032
* https://medium.com/@nikeshbajaj/artifacts-in-eeg-and-how-to-remove-them-atar-algorithm-ica-fbb91ea8485a

In [1]:
# Enable the IPython autoreload extension in mode 2 so edited modules are re-imported before each cell runs
%load_ext autoreload
%autoreload 2

In [2]:
# Drop any previous checkout, then clone the PyTorchWavelets repository into the working directory
! rm -rf /kaggle/working/PyTorchWavelets
! git clone https://github.com/neuralsrg/PyTorchWavelets.git

Cloning into 'PyTorchWavelets'...
remote: Enumerating objects: 123, done.[K
remote: Counting objects: 100% (21/21), done.[K
remote: Compressing objects: 100% (21/21), done.[K
remote: Total 123 (delta 11), reused 2 (delta 0), pack-reused 102[K
Receiving objects: 100% (123/123), 1.08 MiB | 23.00 MiB/s, done.
Resolving deltas: 100% (60/60), done.


In [58]:
import sys
sys.path.append('/kaggle/working/PyTorchWavelets/')

import os
import glob
import regex
import random
import pickle
from typing import Dict, List, Tuple, Union

from tqdm.notebook import tqdm, trange

import numpy as np
import math
import pandas as pd

from sklearn.decomposition import PCA

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import random_split

from wavelets_pytorch.transform import WaveletTransformTorch
from wavelets_pytorch.wavelets import Morlet

import librosa
import torchaudio

import matplotlib_inline
import matplotlib.pyplot as plt; plt.style.use('ggplot')
import seaborn as sns

from IPython.display import display, Audio, Markdown

%matplotlib inline
matplotlib_inline.backend_inline.set_matplotlib_formats('pdf', 'svg')

# Configuration

In [4]:
base = '/kaggle/input/internal-speech-recognition/Vartanov/audios'

# Phoneme recordings
A = "A.wav"
B = "B.wav"
F = "F.wav"
G = "G.wav"
M = "M.wav"
R = "R.wav"
U = "U.wav"

# Syllable recordings
Ba = "Ba.wav"
Bu = "Bu.wav"
Fa = "Fa.wav"
Fu = "Fu.wav"
Ga = "Ga.wav"
Gu = "Gu.wav"
Ma = "Ma.wav"
Mu = "Mu.wav"
Ra = "Ra.wav"
Ru = "Ru.wav"

# Word recordings
Biblioteka = "St1.wav"
Raketa = "St2.wav"
Kurier = "St3.wav"
Ograda = "St4.wav"
Haketa = "St5.wav"


def _stimulus_paths(folder, file_names, first_label, repeat_offset=10):
    """
    Build a label -> audio path map for one stimulus folder.

    The i-th file is registered under two labels: ``first_label + i`` and
    ``first_label + i + repeat_offset``. Entries are inserted file by file, in that order.
    """
    mapping = {}
    for position, file_name in enumerate(file_names):
        wav_path = os.path.join(base, folder, file_name)
        mapping[first_label + position] = wav_path
        mapping[first_label + position + repeat_offset] = wav_path
    return mapping


_phoneme_files = (A, B, F, G, M, R, U)

# Same phoneme files, two different label numbering schemes
phonemes_m3_labels = _stimulus_paths("phonemes", _phoneme_files, first_label=12)
phonemes_m4_labels = _stimulus_paths("phonemes", _phoneme_files, first_label=1)

syllables_labels = _stimulus_paths("syllables", (Ba, Fa, Ga, Ma, Ra, Bu, Ru, Mu, Fu, Gu), first_label=1)

words_labels = _stimulus_paths("words", (Biblioteka, Raketa, Kurier, Ograda, Haketa), first_label=11)

# Section name -> (label -> audio path)
audio_map = {
    "syllables": syllables_labels,
    "phonemes_m3": phonemes_m3_labels,
    "phonemes_m4": phonemes_m4_labels,
    "words": words_labels
}

In [5]:
# Global experiment configuration shared by the dataset, the model and the loss
config = {

    'path': '/kaggle/input/internal-speech-recognition/Vartanov/feather',

    # EEG
    'eeg_sr': 1006,
    'fragment_length': 2012,

    'n_channels': 63,
    'in_seq_len': 1145,

    # Wavelet Transform
    'dj': 0.8,
    'wavelet': Morlet(),
    'n_wvt_bins': 12,

    # STFT Transform
    'center': False,
    'n_fft': 2048,
    'hop_size': 512,

    # Audio
    'audio_maps': audio_map,
    'audio_sr': 44100,
    'sound_size': 50176,
    'audio_channel': 1,
    'n_components': 128,

    # Audio Augmentation
    'shift_bounds': (-5, 5),
    'sigma': 0.005,

    # Model Parameters
    'encoder_hidden_size': 512,
    'decoder_hidden_size': 512,
    'n_layers': 2,
    'encoder_dropout': .2,
    'decoder_dropout': .2
}

# Number of STFT frames for a waveform of `sound_size` samples, derived from the STFT settings
# instead of being hard-coded so it stays consistent if any of them changes.
# With the values above this is 99 when centered and 95 otherwise.
if config['center']:
    config['n_audio_frames'] = 1 + config['sound_size'] // config['hop_size']
else:
    config['n_audio_frames'] = 1 + (config['sound_size'] - config['n_fft']) // config['hop_size']

In [6]:
# Section names; the same strings are used as keys of `audio_map`
sections = ["syllables", "phonemes_m3", "phonemes_m4", "words"]

In [None]:
# spectrogram to sound

def restore(D: np.array, frame_size=config['n_fft'], hop_length=config['hop_size'], epochs: int = 2, window: str = 'hann'):
    """
    Turn a spectrogram into a waveform by iterative phase estimation.

    One zero frame is added on each side of ``D``, a random initial phase is drawn, and the
    phase is then re-estimated ``epochs`` times from the STFT of the current signal while the
    magnitude stays fixed. One hop of samples is trimmed from both ends of the result.

    :param D: spectrogram of shape (n_freq_bins, n_frames)
    :param frame_size: FFT size used when re-analysing the signal
    :param hop_length: hop between frames, in samples
    :param epochs: number of phase refinement iterations
    :param window: window name passed to librosa
    :return: reconstructed waveform
    """
    n_bins = D.shape[0]
    # Length of the padded signal: (n_frames + 2 - 1) hops
    target_len = (D.shape[1] + 1) * hop_length

    padded = np.concatenate((np.zeros((n_bins, 1)), D, np.zeros((n_bins, 1))), axis=1)
    magnitude, _unused = librosa.magphase(padded)

    def synthesize(phase_estimate):
        # Inverse STFT of the fixed magnitude combined with the given phase
        return librosa.istft(magnitude * phase_estimate, hop_length=hop_length, center=True,
                             window=window, length=target_len)

    random_phase = np.exp(1.j * np.random.uniform(0., 2*np.pi, size=magnitude.shape))
    signal = synthesize(random_phase)

    for _step in range(epochs):
        spectrum = librosa.stft(signal, n_fft=frame_size, hop_length=hop_length, center=True, window=window)
        _unused, refined_phase = librosa.magphase(spectrum)
        signal = synthesize(refined_phase)

    return signal[hop_length:-hop_length]

# Torch Dataset

In [62]:
class EEGDataset(Dataset):
    """
    Map-style dataset pairing fixed-length EEG fragments with the audio recording of their label.

    Files are split per section into a validation part and a training part; `set_val_mode`
    selects which part `__len__` and `__getitem__` serve. Every file is expected to hold
    `partition_size` consecutive fragments of `fragment_length` rows each, with the label in
    column 0 and the EEG channels in the remaining columns.
    """

    def __init__(self, path: str, audio_maps: dict, fragment_length: int = 2012, partition_size: int = 32,
                 sample_rate: int = 44100, sound_channel: int = 1, val_ratio: float = 0.15, seed: int = 1337):
        '''
        path: path to sections (folders)
        audio_maps: two-level map: section names -> labels -> audio_paths
        fragment_length: length of fragment after label
        partition_size: number of nonzero labels in each csv file
        sample_rate: sample rate the audio is resampled to
        sound_channel: index of the audio channel to return
        val_ratio: fraction of files of every section reserved for validation
        seed: seed of the shuffles that define section order and the train/val split
        '''
        super().__init__()
        rnd = random.Random(seed)

        # os.listdir returns entries in arbitrary order; sort first so that the seeded shuffle
        # (and therefore the train/val split) is the same on every run and for every instance
        self.sections = sorted(os.listdir(path))
        rnd.shuffle(self.sections)
        assert set(self.sections) == set(audio_maps.keys()), "Sections must be the same!"
        self.audio_maps = audio_maps

        all_paths = []
        for sec in self.sections:
            file_names = sorted(os.listdir(os.path.join(path, sec)))
            rnd.shuffle(file_names)
            all_paths.append([os.path.join(path, sec, file) for file in file_names])

        num_all_files = [len(elem) for elem in all_paths]
        splits = [int(elem * val_ratio) for elem in num_all_files]

        # The first `split` files of every section go to validation, the rest to training
        self.val_paths = [sec_paths[:split] for sec_paths, split in zip(all_paths, splits)]
        self.paths = [sec_paths[split:] for sec_paths, split in zip(all_paths, splits)]

        # Cumulative number of samples per section, used to translate a flat index into a section
        self.sec_num_files = [len(elem) for elem in self.paths]
        self.sec_cumnum = np.cumsum(self.sec_num_files) * partition_size
        self.total_num_files = sum(self.sec_num_files)

        self.sec_num_val_files = [len(elem) for elem in self.val_paths]
        self.sec_val_cumnum = np.cumsum(self.sec_num_val_files) * partition_size
        self.total_num_val_files = sum(self.sec_num_val_files)

        self.partition_size = partition_size
        self.fragment_length = fragment_length
        self.sr = sample_rate
        self.sound_channel = sound_channel
        self.val_mode = False

    def __len__(self) -> int:
        '''
        Number of fragments in the active (train or val) subset
        '''
        num = self.total_num_val_files if self.val_mode else self.total_num_files
        return num * self.partition_size

    def set_val_mode(self, mode: bool):
        '''
        Switch between train/val subsets
        '''
        assert mode in [True, False], "Incorrect mode type!"
        self.val_mode = mode
        return self

    def to_section(self, idx: int) -> Tuple[int, int]:
        '''
        Get file section and inner index by its absolute index
        '''
        cumnum = self.sec_val_cumnum if self.val_mode else self.sec_cumnum
        # First section whose cumulative sample count exceeds idx
        section = np.where(idx < cumnum)[0][0]
        section_idx = idx if (section == 0) else (idx - cumnum[section - 1])
        return section, section_idx

    def get_audio(self, section: int, label: int) -> torch.Tensor:
        '''
        Get audio by section index (position in self.sections) and corresponding label
        '''
        section_name = self.sections[section]
        audio, current_sr = torchaudio.load(self.audio_maps[section_name][label])
        audio = torchaudio.functional.resample(audio, orig_freq=current_sr, new_freq=self.sr)
        return audio[self.sound_channel]

    @staticmethod
    def _fit_first_dim(tensor: torch.Tensor, size: int) -> torch.Tensor:
        '''
        Crop or zero-pad `tensor` along its first dimension so that it has exactly `size` entries
        '''
        missing = size - tensor.size(0)
        if missing <= 0:
            return tensor[:size]
        # F.pad lists (left, right) pairs starting from the LAST dimension, so every trailing
        # dimension gets an explicit (0, 0) and only the first one is extended
        pad = [0, 0] * (tensor.dim() - 1) + [0, missing]
        return nn.functional.pad(tensor, pad, value=0)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        '''
        int idx: flat sample index within the active subset
        return: EEG fragment of shape (n_channels, in_seq_len) with its corresponding audio

        Note: the whole feather file is read on every call, and the target sizes come from
        the module-level `config` ('in_seq_len' and 'sound_size').
        '''
        section, section_idx = self.to_section(idx)
        paths_source = self.val_paths if self.val_mode else self.paths
        file_path = paths_source[section][section_idx // self.partition_size]

        start = (section_idx % self.partition_size) * self.fragment_length
        end = start + self.fragment_length

        data = pd.read_feather(file_path).to_numpy()
        # Column 0 holds the label, the remaining columns hold the EEG channels
        x, label = torch.tensor(data[start:end, 1:]), data[start, 0].astype(int)

        audio = self.get_audio(section, label)

        # Cut model inputs so that they match desirable sizes (time is the first axis of both)
        E, S = config['in_seq_len'], config['sound_size']
        x = self._fit_first_dim(x, E)
        audio = self._fit_first_dim(audio, S)

        return x.t(), audio

In [101]:
# Three independent instances: the split is seeded, so all of them see the same train/val partition.
# `dataset` is a scratch instance for interactive exploration; it must NOT alias `val_ds`,
# otherwise a later `dataset.set_val_mode(False)` would silently turn the validation set into
# the training set.
dataset = EEGDataset(path=config['path'], audio_maps=config['audio_maps'])
train_ds = EEGDataset(path=config['path'], audio_maps=config['audio_maps'])
val_ds = EEGDataset(path=config['path'], audio_maps=config['audio_maps']).set_val_mode(True)

In [102]:
# Number of fragments served by the training and validation datasets
len(train_ds), len(val_ds)

(22368, 3872)

In [33]:
# Cumulative number of training fragments per section (upper index bound of each section)
dataset.sec_cumnum

array([ 3584,  6912, 13056, 22368])

In [34]:
# Serve the training subset from the exploration instance and listen to the audio of its first fragment
dataset.set_val_mode(False)
x, audio = dataset[0]
Audio(audio, rate=dataset.sr)

In [36]:
# Spot check: with the cumulative counts recorded above, index 4192 falls into the second section
x, audio = dataset[4192]
Audio(audio, rate=dataset.sr)

In [37]:
# Spot check: with the cumulative counts recorded above, index 8096 falls into the third section
x, audio = dataset[8096]
Audio(audio, rate=dataset.sr)

In [38]:
# Spot check: with the cumulative counts recorded above, index 15296 falls into the fourth section
x, audio = dataset[15296]
Audio(audio, rate=dataset.sr)

# LSTM Net

In [50]:
class LSTMNet(nn.Module):
    """
    LSTM encoder-decoder that predicts a PCA-compressed audio spectrogram from EEG.

    The EEG is wavelet-transformed (without gradients), its channels are collapsed by a learned
    linear projection, and the resulting sequence is encoded by an LSTM. A second LSTM is then
    unrolled for `n_audio_frames` steps and a small MLP maps every step to `n_components`
    PCA coefficients. The target audio is converted to the same representation inside `forward`.
    """

    def __init__(self, config: dict, audio_paths: List[str]):
        """
        :param dict config: experiment configuration (audio, STFT, wavelet and model settings)
        :param List[str] audio_paths: audio files used to fit the PCA basis
        """
        super().__init__()

        self.audio_sr = config['audio_sr']
        self.audio_channel = config['audio_channel']
        self.n_fft = config['n_fft']
        self.hop_size = config['hop_size']
        # The target STFT must use the same framing that config['n_audio_frames'] was derived
        # from; otherwise prediction and target have different numbers of frames
        self.center = config.get('center', True)
        self.n_components = config['n_components']
        self.shift_bounds = config['shift_bounds']
        self.sigma = config['sigma']
        self.register_buffer('window', torch.hann_window(config['n_fft']).double())

        # Computing Principal Components
        self.compute_pca_components(audio_paths, n_components=config['n_components'])

        # Wavelet Transform
        self.wvt_transformer = WaveletTransformTorch(
            dt=1/config['eeg_sr'],
            dj=config['dj'],
            wavelet=config['wavelet'],
            cuda=torch.cuda.is_available()
        )

        # Learned weighted sum over EEG channels (n_channels -> 1), trailing unit axis flattened away
        self.projector = torch.nn.Sequential(
            nn.Linear(config['n_channels'], 1),
            nn.Flatten(-2, -1)
        )
        self.encoder = nn.LSTM(input_size=config['n_wvt_bins'], hidden_size=config['encoder_hidden_size'],
                               batch_first=True, dropout=config['encoder_dropout'], num_layers=config['n_layers'])
        self.decoder = nn.LSTM(input_size=config['encoder_hidden_size'], hidden_size=config['decoder_hidden_size'],
                               batch_first=True, dropout=config['decoder_dropout'], num_layers=config['n_layers'])
        self.n_audio_frames = config['n_audio_frames']
        self.decoder_hidden_size = config['decoder_hidden_size']
        self.fc = torch.nn.Sequential(
            nn.Linear(config['decoder_hidden_size'], config['decoder_hidden_size'] // 2),
            nn.ReLU(),
            nn.Linear(config['decoder_hidden_size'] // 2, config['decoder_hidden_size'] // 4),
            nn.ReLU(),
            nn.Linear(config['decoder_hidden_size'] // 4, config['decoder_hidden_size'] // 8),
            nn.ReLU(),
            nn.Linear(config['decoder_hidden_size'] // 8, config['n_components']),
        )

    def compute_pca_components(self, audio_paths: List[str], n_components: int):
        """
        Fit PCA on the magnitude STFT frames of the given recordings and store the basis
        in the `components` and `mean` buffers.

        :param List[str] audio_paths: list of audio file paths to fit PCA on
        :param int n_components: number of principal components to keep
        """
        audios_srs = [torchaudio.load(path) for path in audio_paths]
        all_audios = []
        for audio, sr in audios_srs:
            audio = audio[self.audio_channel]
            if sr != self.audio_sr:
                audio = torchaudio.functional.resample(waveform=audio, orig_freq=sr, new_freq=self.audio_sr)
            all_audios.append(audio)

        # All recordings are concatenated into one long waveform before the STFT
        all_audios = torch.cat(all_audios)
        all_audios = torch.stft(all_audios, n_fft=self.n_fft, hop_length=self.hop_size,
                                return_complex=True, window=self.window)  # (n_freq_bins, n_frames)
        all_audios = torch.abs(all_audios).t().numpy()  # one row per frame

        pca = PCA(n_components=n_components)
        pca.fit(all_audios)

        components = torch.tensor(pca.components_)  # (d_model, n_freq_bins)
        mean = torch.tensor(pca.mean_)  # (n_freq_bins)

        self.register_buffer('components', components)
        self.register_buffer('mean', mean)

    def restore_spec(self, x: torch.tensor):
        """
        Map PCA coefficients back to frequency bins.

        :param torch.tensor x: Encoded spectrogram of shape
        (n_audio_frames, n_components) for unbatched data or
        (batch_size, n_audio_frames, n_components) for batched data
        :return torch.tensor out: Decoded spectrogram of shape ([batch_size], n_freq_bins, n_audio_frames)
        """
        return (x @ self.components + self.mean).transpose(-1, -2)

    def forward(self, eeg: torch.tensor, audio: torch.tensor):
        """
        :param torch.tensor eeg: Tensor of shape (batch_size, n_channels, in_seq_len)
        :param torch.tensor audio: Tensor of shape (batch, sound_size)
        :rtype Tuple[torch.tensor, torch.tensor]
        :return Tuple[
            torch.tensor out: model output of shape (batch_size, n_audio_frames, n_components),
            torch.tensor audio: encoded audio time-frequency representation of shape (batch_size, n_audio_frames, n_components)
        ]
        """
        # EEG
        with torch.no_grad():

            # Reshaping for wavelet transform which takes as input tensor of shape (N, in_seq_len)
            batch_size, n_channels, in_seq_len = eeg.size()
            out = eeg.view(batch_size * n_channels, in_seq_len)  # (N, in_seq_len)

            # Wavelet transform
            out = self.wvt_transformer.cwt(out)  # (N, n_wvt_bins, in_seq_len)
            n_wvt_bins = out.size(1)
            out = out.view(batch_size, n_channels, n_wvt_bins, in_seq_len)  # (batch_size, n_channels, n_wvt_bins, in_seq_len)

        # Channel downsampling
        out = out.permute(0, 2, 3, 1)  # (batch_size, n_wvt_bins, in_seq_len, n_channels)
        out = self.projector(out)  # (batch_size, n_wvt_bins, in_seq_len)
        out = out.permute(0, 2, 1)  # (batch_size, in_seq_len, n_wvt_bins)

        out, (h, c) = self.encoder(out)   # out: (batch_size, in_seq_len, encoder_hidden_size)
        out = out[:, -1, :].unsqueeze(1)  # (batch_size, 1, encoder_hidden_size)

        # Unroll the decoder one frame at a time, feeding its own output back as the next input
        hs = torch.zeros((out.size(0), self.n_audio_frames, self.decoder_hidden_size), device=eeg.device)
        for i in range(self.n_audio_frames):
            out, (h, c) = self.decoder(out, (h, c))  # out: (batch, 1, decoder_hidden_size)
            # Drop only the length-1 time axis so a batch of size 1 keeps its batch axis
            hs[:, i, :] = out.squeeze(1)

        out = self.fc(hs)  # (batch_size, n_audio_frames, n_components)

        # Audio
        # Random integer time shift: positive values prepend zeros and drop the tail,
        # negative values drop the head and append zeros, so the length is unchanged
        shift = torch.FloatTensor(1).uniform_(self.shift_bounds[0], self.shift_bounds[1]+1).int().item()
        audio = nn.functional.pad(audio, (shift, -shift), value=0)
        # STFT, framed exactly like config['n_audio_frames'] assumes
        audio = torch.stft(audio, n_fft=self.n_fft, hop_length=self.hop_size, return_complex=True,
                           window=self.window, center=self.center)  # (batch_size, n_freq_bins, n_audio_frames)
        audio = torch.abs(audio.permute(0, 2, 1))  # (batch_size, n_audio_frames, n_freq_bins)

        # PCA
        audio = audio - self.mean
        audio = audio @ self.components.t()  # (batch_size, n_audio_frames, n_components)

        return out, audio

In [51]:
def criterion(out, audio, norm_size: Union[int, None] = None, eps: float = 1e-7):
    """
    Custom loss function from https://www.isca-speech.org/archive/pdfs/interspeech_2021/kim21h_interspeech.pdf

    Sum of a spectral-convergence term ``||out - audio|| / ||audio||`` and a log-magnitude term
    ``||log(audio) - log(out)||_1 / T``.

    The logarithm is only defined for positive values, so both tensors are clamped to ``eps``
    before taking it, and the norm in the denominator is clamped as well; inputs that are
    already above ``eps`` give the same result as the unclamped formula.
    NOTE(review): the model in this notebook passes PCA coefficients, which can be negative;
    clamping keeps the loss finite, but the log term is only meaningful for magnitudes.

    :param torch.tensor out: predicted spectrogram of shape (batch_size, n_audio_frames, n_components)
    :param torch.tensor audio: real spectrogram of shape (batch_size, n_audio_frames, n_components)
    :param norm_size: divisor T of the log-magnitude term; defaults to config['sound_size']
    :param float eps: lower bound applied before the logarithm and to the denominator
    :return: scalar loss tensor
    """
    T = config['sound_size'] if norm_size is None else norm_size
    L_sc = torch.norm(out - audio) / torch.norm(audio).clamp_min(eps)
    log_diff = torch.log(audio.clamp_min(eps)) - torch.log(out.clamp_min(eps))
    L_mag = 1/T * torch.norm(log_diff, p=1)

    return L_sc + L_mag

In [52]:
# Recordings used to fit the PCA basis: every syllable file followed by every word file
syllable_wavs = glob.glob('/kaggle/input/internal-speech-recognition/Vartanov/audios/syllables/*.wav')
word_wavs = glob.glob('/kaggle/input/internal-speech-recognition/Vartanov/audios/words/*.wav')
audio_paths = syllable_wavs + word_wavs

model = LSTMNet(config, audio_paths)

In [48]:
# Fetch one fragment, print the length of its audio and play it
x, audio = dataset[15296]
print(audio.shape)
Audio(audio, rate=dataset.sr)

torch.Size([50176])


In [53]:
# Round trip of the TARGET audio through the PCA bottleneck (the prediction `out` is not used here):
# forward() returns the PCA-encoded target as its second value, restore_spec() maps it back to
# frequency bins and restore() estimates a phase to synthesise a waveform
out, spec = model(x.unsqueeze(0), audio.unsqueeze(0))
restored = model.restore_spec(spec).squeeze().detach().numpy()
restored = restore(restored)
Audio(restored, rate=dataset.sr)

In [59]:
%%time

# Time 256 consecutive __getitem__ calls to gauge the cost of loading samples
for i in trange(256):
    dataset[i]

  0%|          | 0/256 [00:00<?, ?it/s]

CPU times: user 12.5 s, sys: 4.46 s, total: 17 s
Wall time: 7.51 s


In [103]:
# Batches of 128 fragments, loaded by 2 worker processes; only the training loader shuffles
train_dl = DataLoader(train_ds, 128, shuffle=True, num_workers=2)
val_dl = DataLoader(val_ds, 128, shuffle=False, num_workers=2)

In [104]:
# Number of batches per epoch for each loader
len(train_dl), len(val_dl)

(175, 31)

# Train

In [None]:
# Loss history; the training loop appends (step, loss, description, 'train' | 'val') tuples
hist = []

In [None]:
# Directory for the saved weights and loss history
model_checkpoint_path = os.path.join(os.getcwd(), 'model_checkpoints')

# Creates the directory if needed; no separate existence check, so no race and no error on re-run
os.makedirs(model_checkpoint_path, exist_ok=True)

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model = LSTMNet(config, audio_paths).to(device).train()

optimizer = optim.Adam(model.parameters(), lr=0.01)

# The scheduler is stepped once per batch in the training loop, so the learning rate
# is halved every 200 optimisation steps
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=200, gamma=0.5)

current_lr = scheduler.get_last_lr()[-1]
# Upper bound for the total gradient norm (gradient clipping)
max_norm = 10.0

# Run name: stored in the history and used as the checkpoint file name
description = 'LSTM-512'

In [None]:
n_epochs = 2

total_step = 0
train_loss = []   # ring buffer with the losses of the last 100 steps
val_loss = 0
grad_norm = []    # ring buffer with the pre-clipping gradient norms of the last 100 steps

# `tqdm` and `trange` are the notebook progress bars imported from tqdm.notebook
for epoch in trange(n_epochs):

    for eeg, audio in (pbar := tqdm(train_dl, total=len(train_dl))):

        total_step += 1

        # Move tensors to device
        eeg, audio = eeg.to(device), audio.to(device)

        # Clear gradients
        optimizer.zero_grad()

        # Perform forward pass: the model returns its prediction together with the
        # PCA-encoded target spectrogram computed from the raw audio
        out, spec = model(eeg, audio)
        loss = criterion(out, spec)

        # Perform backward pass
        loss.backward()

        # Perform Gradient Clipping (returns the total norm before clipping)
        step_grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm).item()

        # Averaging over sliding window
        if len(train_loss) < 100:
            train_loss.append(loss.item())
            grad_norm.append(step_grad_norm)
        else:
            slot = (total_step - 1) % 100
            train_loss[slot] = loss.item()
            grad_norm[slot] = step_grad_norm

        # Perform Optimization Step
        optimizer.step()

        # Perform Scheduler Step
        scheduler.step()
        current_lr = scheduler.get_last_lr()[-1]

        # saving history
        hist.append((total_step, np.mean(train_loss), description, 'train'))

        # update tqdm info
        pbar.set_description(
            'Step: {0}|\tLr: {1:.2e}|\tGradNorm: {2:.2e}|\tTrain Loss: {3:.2e}|\tVal Loss: {4:.3f}'.format(
                total_step, current_lr, np.mean(grad_norm), np.mean(train_loss), val_loss
            )
        )

        # validation
        if total_step % 200 == 0:

            val_losses = []
            model.eval()

            with torch.no_grad():
                for val_eeg, val_audio in (val_pbar := tqdm(val_dl, total=len(val_dl))):
                    val_out, val_spec = model(val_eeg.to(device), val_audio.to(device))
                    val_losses.append(criterion(val_out, val_spec).item())

                    # update tqdm info
                    val_pbar.set_description(
                        'Validation|\tVal Loss: {0:.3f}'.format(np.mean(val_losses))
                    )

                val_loss = np.mean(val_losses)

                # saving val history
                hist.append((total_step, val_loss, description, 'val'))

            model.train()

# saving model
torch.save(model.state_dict(), os.path.join(model_checkpoint_path, f'{description}.pth'))

with open(os.path.join(model_checkpoint_path, 'hist.pickle'), 'wb') as handle:
    pickle.dump(hist, handle)

In [None]:
# One training sample as an (eeg, audio) pair; play its target audio
elem = train_ds[8096]

Audio(elem[1], rate=train_ds.sr)

In [None]:
# Predict the PCA-encoded spectrogram for the sample, decode it back to frequency bins
# and synthesise a waveform from it
model.eval()
with torch.no_grad():
    y_pred, _ = model(elem[0].unsqueeze(0).to(device), elem[1].unsqueeze(0).to(device))
    spec_pred = model.restore_spec(y_pred).squeeze(0).cpu().numpy()  # (n_freq_bins, n_audio_frames)
model.train()

audio = restore(spec_pred)
Audio(audio, rate=train_ds.sr)

In [None]:
from scipy.io import wavfile
# np.asarray accepts both a NumPy array and a CPU tensor; store the samples as 32-bit floats
wavfile.write('/kaggle/working/mu.wav', train_ds.sr, np.asarray(audio).astype(np.float32))

In [None]:
# Loss of the model on the single sample `elem`, evaluated without dropout and without gradients
model.eval()
with torch.no_grad():
    sample_out, sample_spec = model(elem[0].unsqueeze(0).to(device), elem[1].unsqueeze(0).to(device))
model.train()

criterion(sample_out, sample_spec)

In [None]:
# Store the model weights under the run name
torch.save(model.state_dict(), os.path.join(model_checkpoint_path, f'{description}.pth'))

# Store the accumulated loss history next to the weights
with open(os.path.join(model_checkpoint_path, 'hist.pickle'), 'wb') as handle:
    pickle.dump(hist, handle)

# Experiments

In [None]:
# Fresh clone of the PyTorchWavelets repository (same commands as the clone cell at the top of the notebook)
! rm -rf /kaggle/working/PyTorchWavelets
! git clone https://github.com/neuralsrg/PyTorchWavelets.git
# ! pip install -r /kaggle/working/PyTorchWavelets/requirements.txt
# ! python /kaggle/working/PyTorchWavelets/setup.py install

In [None]:
# List the contents of the cloned repository
! ls /kaggle/working/PyTorchWavelets

In [None]:
import sys
sys.path.append('/kaggle/working/PyTorchWavelets/')
from wavelets_pytorch.transform import WaveletTransformTorch

In [None]:
import numpy as np
from wavelets_pytorch.transform import WaveletTransform        # SciPy version
from wavelets_pytorch.transform import WaveletTransformTorch   # PyTorch version

from wavelets_pytorch.wavelets import Morlet, Ricker

"""
Example script to demonstrate the CWT on a batch of random sinusoidal signals. 
We compare both the SciPy implementation and the PyTorch implementation. 
"""

dt = 1 / 10            # sampling period (time between two samples)
dj = 0.8               # scale distribution parameter
batch_size = 32        # how many signals to process in parallel
cuda = torch.cuda.is_available()            # enable GPU

t = np.linspace(0., 10., int(10./dt))

wavelet = Morlet()

# Sinusoidals with random frequency
frequencies = np.random.uniform(-0.5, 2.0, size=batch_size)
batch = np.asarray([np.sin(2*np.pi*f*t) for f in frequencies])

# Initialize wavelet filter banks (scipy and torch implementation)
wa_scipy = WaveletTransform(dt, dj, wavelet)
wa_torch = WaveletTransformTorch(dt, dj, wavelet, cuda=cuda)

# Performing wavelet transform (and compute scalogram)
print(f'batch shape: {batch.shape}')
# The SciPy result is printed below, so it has to be computed here
cwt_scipy = wa_scipy.cwt(batch)
cwt_torch = wa_torch.cwt(torch.tensor(batch).float())

print(cwt_scipy.shape)
print(cwt_torch.shape)

In [None]:
# Scales of the SciPy-based transform
wa_scipy.scales

In [None]:
# NOTE(review): presumably the smallest scale multiplied by 2**dj (dj = 0.8), i.e. the expected
# second scale — confirm against the values of wa_scipy.scales
0.19360266 * 2 ** 0.8

In [None]:
# Shape of the first entry of the (private) filter list of the SciPy-based transform
wa_scipy._filters[0].shape

In [None]:
# Overlay the real part of every filter on a single figure
for f in wa_scipy._filters:
    plt.plot(f.real)

In [None]:

# Alias of the time step `dt` used by the expression in the next cell
sampling_period = dt


In [None]:
# One value per scale: 10 * scale / dt
Ms = [10 * scale / wa_scipy.dt for scale in wa_scipy.scales]
w = 6
# Since M = 10 * s / wa_scipy.dt, the scale s cancels out and every entry equals
# 2 * w * sampling_period * wa_scipy.dt / 10
[2*s*w*sampling_period / M for s, M in zip(wa_scipy.scales, Ms)]