<a href="https://www.kaggle.com/code/neuralsrg/speech-reconstruction?scriptVersionId=135626281" target="_blank"><img align="left" alt="Kaggle" title="Open in Kaggle" src="https://kaggle.com/static/images/open-in-kaggle.svg"></a>

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os
import glob
import regex
from typing import Dict, List, Tuple, Union

import tqdm.notebook as tqdm

import numpy as np
import math
import pandas as pd

from sklearn.decomposition import PCA

import torch
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import random_split

import librosa
import torchaudio

import matplotlib_inline
import matplotlib.pyplot as plt

from IPython.display import display, Audio, Markdown

%matplotlib inline
matplotlib_inline.backend_inline.set_matplotlib_formats('pdf', 'svg')

caused by: ['/opt/conda/lib/python3.10/site-packages/tensorflow_io/python/ops/libtensorflow_io_plugins.so: undefined symbol: _ZN3tsl6StatusC1EN10tensorflow5error4CodeESt17basic_string_viewIcSt11char_traitsIcEENS_14SourceLocationE']
caused by: ['/opt/conda/lib/python3.10/site-packages/tensorflow_io/python/ops/libtensorflow_io.so: undefined symbol: _ZTVN10tensorflow13GcsFileSystemE']


In [3]:
# Section names; the same strings are used as the keys of audio_map.
sections = [
    "syllables",
    "phonemes_m3",
    "phonemes_m4",
    "words",
]

# Torch Dataset

In [4]:
class EEGDataset(Dataset):
    """Pairs fixed-length EEG fragments with the audio stimulus of their label.

    ``path`` must contain one sub-folder per section. Every file inside a
    section folder is read as a header-less CSV whose first column holds the
    stimulus label and whose remaining columns hold the EEG signal. Each file is
    treated as ``partition_size`` consecutive fragments of ``fragment_length``
    rows, and a fragment's label is read from its first row.
    """

    def __init__(self, path: str, audio_maps: dict, fragment_length: int = 503, partition_size: int = 32,
                 sample_rate: int = 44100, sound_channel: int = 1):
        '''
        path: directory containing one sub-folder per section
        audio_maps: two-level map: section name -> label -> audio file path
        fragment_length: number of CSV rows in a fragment (starting at its labelled row)
        partition_size: number of fragments (nonzero labels) in each CSV file
        sample_rate: rate every audio file is resampled to
        sound_channel: index of the audio channel returned for each item

        raises ValueError: if the section folders differ from the keys of audio_maps
        '''
        super().__init__()
        # os.listdir order is arbitrary and filesystem-dependent; sort so the
        # index -> (section, file) mapping is reproducible across machines.
        self.sections = sorted(os.listdir(path))
        if set(self.sections) != set(audio_maps.keys()):
            raise ValueError(
                f"Sections must be the same! folders: {self.sections}, "
                f"audio_maps keys: {list(audio_maps.keys())}"
            )
        self.audio_maps = audio_maps
        self.paths = [
            [os.path.join(path, sec, file) for file in sorted(os.listdir(os.path.join(path, sec)))]
            for sec in self.sections
        ]
        self.sec_num_files = [len(files) for files in self.paths]
        # Cumulative fragment count at the end of each section; the last entry equals len(self).
        self.sec_cumnum = np.cumsum(self.sec_num_files) * partition_size
        self.total_num_files = sum(self.sec_num_files)
        self.partition_size = partition_size
        self.fragment_length = fragment_length
        self.sr = sample_rate
        self.sound_channel = sound_channel

    def __len__(self) -> int:
        """Total number of fragments: every file contributes ``partition_size`` of them."""
        return self.total_num_files * self.partition_size

    def to_section(self, idx: int) -> Tuple[int, int]:
        """
        Maps a global fragment index to ``(section, index within that section)``.

        Negative indices count from the end, as with sequences.
        raises IndexError: if idx is outside [-len(self), len(self))
        """
        n_items = len(self)
        pos = idx + n_items if idx < 0 else idx
        if not 0 <= pos < n_items:
            raise IndexError(f"index {idx} is out of range for dataset of length {n_items}")
        # First section whose cumulative count exceeds pos (empty sections are skipped).
        section = int(np.searchsorted(self.sec_cumnum, pos, side='right'))
        section_idx = pos if section == 0 else pos - int(self.sec_cumnum[section - 1])
        return section, int(section_idx)

    def get_audio(self, section, label):
        """
        Loads the audio mapped to ``label`` in section ``section`` (an index into
        ``self.sections``), resamples it to ``self.sr`` and returns channel
        ``self.sound_channel``.
        """
        section_name = self.sections[section]
        audio, current_sr = torchaudio.load(self.audio_maps[section_name][label])
        audio = torchaudio.functional.resample(audio, orig_freq=current_sr, new_freq=self.sr)
        return audio[self.sound_channel]

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        '''
        idx: fragment index (not a file index); negative values count from the end
        return: (EEG fragment without the label column, audio of the fragment's label)
        '''
        section, section_idx = self.to_section(idx)
        file_idx, fragment_idx = divmod(section_idx, self.partition_size)
        file_path = self.paths[section][file_idx]
        data = pd.read_csv(file_path, header=None).to_numpy()
        start = fragment_idx * self.fragment_length
        end = start + self.fragment_length
        x = torch.tensor(data[start:end, 1:])
        # The label sits in column 0 of the fragment's first row.
        label = data[start, 0].astype(int)
        audio = self.get_audio(section, label)
        return x, audio

In [5]:
# Root folder of the stimulus recordings (sub-folders: phonemes, syllables, words).
base = '/kaggle/input/eeg-mixed-shortened/Vartanov/audios'

# Phoneme recordings.
A, B, F, G, M, R, U = "A.wav", "B.wav", "F.wav", "G.wav", "M.wav", "R.wav", "U.wav"

# Syllable recordings.
Ba, Bu, Fa, Fu = "Ba.wav", "Bu.wav", "Fa.wav", "Fu.wav"
Ga, Gu, Ma, Mu = "Ga.wav", "Gu.wav", "Ma.wav", "Mu.wav"
Ra, Ru = "Ra.wav", "Ru.wav"

# Word recordings (stored on disk as St1..St5).
Biblioteka, Raketa, Kurier, Ograda, Haketa = "St1.wav", "St2.wav", "St3.wav", "St4.wav", "St5.wav"

# Each table maps a numeric label to an audio path. Every recording is reachable
# under two labels that differ by 10: the n-th file listed gets (low + n, low + 10 + n),
# inserted in that order.
phonemes_m3_labels = {
    code: os.path.join(base, "phonemes", wav)
    for n, wav in enumerate((A, B, F, G, M, R, U))
    for code in (12 + n, 22 + n)
}

phonemes_m4_labels = {
    code: os.path.join(base, "phonemes", wav)
    for n, wav in enumerate((A, B, F, G, M, R, U))
    for code in (1 + n, 11 + n)
}

syllables_labels = {
    code: os.path.join(base, "syllables", wav)
    for n, wav in enumerate((Ba, Fa, Ga, Ma, Ra, Bu, Ru, Mu, Fu, Gu))
    for code in (1 + n, 11 + n)
}

words_labels = {
    code: os.path.join(base, "words", wav)
    for n, wav in enumerate((Biblioteka, Raketa, Kurier, Ograda, Haketa))
    for code in (11 + n, 21 + n)
}

# Section name -> label table.
audio_map = dict(
    syllables=syllables_labels,
    phonemes_m3=phonemes_m3_labels,
    phonemes_m4=phonemes_m4_labels,
    words=words_labels,
)

In [6]:
# Index every section folder under 'shortened', pairing labels with the audio_map
# tables defined above (default fragment_length, partition_size, sample_rate, sound_channel).
dataset = EEGDataset('/kaggle/input/eeg-mixed-shortened/Vartanov/shortened', audio_map)

In [7]:
# Cumulative fragment count at the end of each section, in dataset.sections order;
# the last value equals len(dataset).
dataset.sec_cumnum

array([ 3904,  8096, 17344, 24544])

In [8]:
# Spot-check: fetch a few fragments and play the stimulus audio paired with each.
# With the sec_cumnum output shown above (3904, 8096, 17344, 24544), index 0 falls in
# section 0, 4192 in section 1, and both 8096 and 15296 in section 2.
# NOTE(review): no index from section 3 (17344-24543) is sampled — confirm whether
# 15296 was meant to hit the last section.
x, audio = dataset[0]
Audio(audio, rate=dataset.sr)

In [9]:
x, audio = dataset[4192]
Audio(audio, rate=dataset.sr)

In [10]:
# First fragment of section 2 (8096 == end of section 1 in sec_cumnum above).
x, audio = dataset[8096]
Audio(audio, rate=dataset.sr)

In [11]:
x, audio = dataset[15296]
Audio(audio, rate=dataset.sr)

In [12]:
class STFTTransformer:
    """Magnitude STFT of sounds plus Griffin-Lim style inversion.

    ``transform`` and ``restore`` both use a periodic Hann analysis window, so a
    magnitude produced by ``transform`` can be inverted by ``inverse_transform``.
    """

    def __init__(self, sound_frame_size: int, sound_hop: int):
        """
        :param int sound_frame_size: frame size (n_fft) for the sound STFT
        :param int sound_hop: hop_length parameter for the sound STFT
        """
        self.sound_frame_size = sound_frame_size
        self.sound_hop = sound_hop

    def transform(self, audio: torch.Tensor) -> torch.Tensor:
        """
        Computes the magnitude STFT (non-centered, Hann window) of ``audio``.

        :param torch.Tensor audio: 1-D signal, or 2-D batch of signals (batch, time)
        :return: magnitude spectrogram of shape (..., sound_frame_size // 2 + 1, n_frames)
        :rtype: torch.Tensor
        """
        # torch.stft treats window=None as a rectangular window; use the same periodic
        # Hann window that ``restore`` (librosa, window='hann') assumes when inverting.
        window = torch.hann_window(self.sound_frame_size, dtype=audio.dtype, device=audio.device)
        spec = torch.stft(audio, n_fft=self.sound_frame_size, hop_length=self.sound_hop,
                          window=window, return_complex=True, center=False)
        return torch.abs(spec)

    @staticmethod
    def restore(D: np.ndarray, frame_size: int, hop_length: int, epochs: int = 10, window: str = 'hann',
                rng=None) -> np.ndarray:
        """
        Reconstructs a waveform from a spectrogram by iterative phase estimation.

        :param np.ndarray D: spectrogram of shape (freq_bins, n_frames); only |D| is used
        :param int frame_size: n_fft used when re-analysing the estimate
        :param int hop_length: hop length of D
        :param int epochs: number of STFT/ISTFT refinement iterations
        :param str window: window name passed to librosa
        :param rng: object with a ``uniform(low, high, size)`` method used to draw the
                    initial phase (e.g. ``np.random.default_rng(0)``); defaults to the
                    global ``np.random`` state
        :return: reconstructed waveform
        """
        # One all-zero frame is added on each side; the matching hop_length samples
        # are trimmed from both ends of the output below.
        pad = np.zeros((D.shape[0], 1))
        D = np.concatenate((pad, D, pad), axis=1)
        mag = np.abs(D)

        if rng is None:
            rng = np.random
        phase = np.exp(1.j * rng.uniform(0., 2 * np.pi, size=mag.shape))
        x_ = librosa.istft(mag * phase, hop_length=hop_length, center=False, window=window)

        for _ in range(epochs):
            # Keep the target magnitude, take the phase of the current estimate.
            _, phase = librosa.magphase(librosa.stft(x_, n_fft=frame_size, hop_length=hop_length, center=False,
                                                     window=window))
            x_ = librosa.istft(mag * phase, hop_length=hop_length, center=False, window=window)

        return x_[hop_length:-hop_length]

    def inverse_transform(self, spec: torch.Tensor, epochs: int = 10) -> torch.Tensor:
        """
        Transforms a magnitude spectrogram back to sound using ``restore``.

        :param torch.Tensor spec: spectrogram of shape (freq_bins, n_frames)
        :param int epochs: number of refinement iterations passed to ``restore``
        :return: Sound
        :rtype: torch.Tensor
        """
        restored = STFTTransformer.restore(spec.detach().cpu().numpy(),
                                           frame_size=self.sound_frame_size,
                                           hop_length=self.sound_hop,
                                           epochs=epochs)
        return torch.tensor(restored)
        

class EncodedDataset(Dataset):
    
    def __init__(self, sound_size: int, sound_frame_size: int, sound_hop: int,
                 shift_bounds: Tuple[int, int], sigma: float,
                 n_components: int, **kwargs):
        """
        Wraps an EEGDataset, sets up the sound STFT and fits a PCA on the stimulus spectra.

        Sound STFT:
        :param int sound_size: Size of the entire sound signal that is mapped to EEG
        :param int sound_frame_size: frame_size (n_fft) parameter for sound STFT
        :param int sound_hop: hop_length parameter for sound STFT

        Adding noise to the sound signal:
        :param Tuple[int, int] shift_bounds: Bounds for the uniform distribution of a random shift
        :param float sigma: std of a random noise

        PCA parameters:
        :param int n_components: Number of components used in PCA

        EEGDataset parameters:
        :param dict kwargs: forwarded unchanged to EEGDataset (path, audio_maps, ...)

        Note: construction calls ``fit_pca``, which loads every stimulus audio file.
        """
        super().__init__()

        # main dataset
        self.eeg_dataset = EEGDataset(**kwargs)

        # STFT params
        self.sound_size = sound_size
        self.stft_transformer = STFTTransformer(sound_frame_size, sound_hop)

        # Audio params: read them back from the wrapped dataset instead of repeating
        # its defaults, so PCA fitting always loads audio exactly like the dataset does.
        self.sound_channel = self.eeg_dataset.sound_channel
        self.sr = self.eeg_dataset.sr

        # Noise params
        self.shift_bounds = shift_bounds
        self.sigma = sigma

        # PCA
        self.pca_transformer = PCA(n_components)
        self.fit_pca()
        
    def fit_pca(self):
        """
        Fits the PCA transformer on the STFT frames of every distinct stimulus sound.

        The sounds are taken from the wrapped dataset's ``audio_maps`` (not a global),
        loaded, resampled to ``self.sr`` and reduced to channel ``self.sound_channel``.
        PCA is fitted with frames as samples and frequency bins as features, matching
        ``pca_transformer.transform(spectrum.t())`` in ``__getitem__``.
        Each sound must be at least ``sound_frame_size`` samples long (non-centered STFT).

        raises ValueError: if ``audio_maps`` references no audio files
        """
        # Several labels share one file: load each file once, in a deterministic order.
        files = sorted({
            filename
            for labels in self.eeg_dataset.audio_maps.values()
            for filename in labels.values()
        })
        if not files:
            raise ValueError("audio_maps does not reference any audio files to fit PCA on")

        spectra = []
        for filename in files:
            audio, current_sr = torchaudio.load(filename)
            audio = torchaudio.functional.resample(audio, orig_freq=current_sr, new_freq=self.sr)
            # STFT each sound on its own so no frame straddles two unrelated recordings.
            spectra.append(self.stft_transformer.transform(audio[self.sound_channel]))

        # (freq, frames) per sound -> (total_frames, freq): one PCA sample per frame.
        frames = torch.cat(spectra, dim=1).t()
        self.pca_transformer.fit(frames.numpy())
        
    def __len__(self):
        return len(self.eeg_dataset)
    
    def __getitem__(self, idx: int) -> Dict:
        """
        :param int idx: Index
        :return: Dictionary:
            'eeg': torch.tensor
            'sound': torch.tensor
            'spectrum': torch.tensor
        :rtype: Dict
        """
        
        signal, audio = self.eeg_dataset.__getitem__(idx)
        spectrum = self.stft_transformer.transform(audio)
        projected_spectrum = self.pca_transformer.transform(spectrum.t())
        
        # how should I transform eeg signal???