In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os
import glob
import regex
from typing import Dict, List, Tuple, Union

import tqdm.notebook as tqdm

import numpy as np
import math
import pandas as pd

import torch
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import random_split

import librosa
import torchaudio

from scipy.io import wavfile
from sklearn.decomposition import PCA

import matplotlib_inline
import matplotlib.pyplot as plt

from IPython.display import display, Audio, Markdown

%matplotlib inline
matplotlib_inline.backend_inline.set_matplotlib_formats('pdf', 'svg')

In [3]:
sections = ["syllables", "phonemes_m3", "phonemes_m4", "words"]

# Dataset (del this section)

In [41]:
decoded_path = os.path.join(".", "Data", "Vartanov", "decoded")

In [42]:
def get_dataset(path, sr=1006.24, delta=2.0, delimiter=';'):
    for filename in os.listdir(path):
        file_path = os.path.join(path, filename)
        contents = pd.read_csv(file_path, delimiter=delimiter).to_numpy()
        labels_indices = np.where(contents[:, 1] != 0)[0]
        num_samples = int(delta * sr)
        samples_indices = [np.arange(i, i + num_samples) for i in labels_indices if (i + num_samples < len(contents))]
        data = contents[samples_indices]
        yield filename, data[..., 1:]

In [43]:
partition_size = 32

for sec in range(4):
    tail = np.empty(0)
    path_ = os.path.join(decoded_path, sections[sec])
    for file_num, (filename, data) in enumerate(get_dataset(path_)):
        if tail.size:
            data = np.vstack((tail, data))
        else:
            tail = np.empty(0)
        print(f"data.shape at {sec}: {data.shape} with tail of shape {tail.shape}")
        n_chunks = math.ceil(len(data) / partition_size)
        for i in (pbar := tqdm.tqdm(range(n_chunks))):
            pbar.set_description(filename)
            path_ = f"./Data/Vartanov/shortened/{sections[sec]}/{file_num}_{i}.csv"
            start = i * partition_size
            end = (i + 1) * partition_size
            if len(data[start:end]) < partition_size:
                # easiest way
                tail = data[start:end]
                continue
            np.savetxt(path_, data[start:end].reshape(-1, data.shape[-1]), delimiter=",")

data.shape at 0: (434, 2012, 64) with tail of shape (0,)


  0%|          | 0/14 [00:00<?, ?it/s]

data.shape at 0: (553, 2012, 64) with tail of shape (18, 2012, 64)


  0%|          | 0/18 [00:00<?, ?it/s]

data.shape at 0: (547, 2012, 64) with tail of shape (9, 2012, 64)


  0%|          | 0/18 [00:00<?, ?it/s]

data.shape at 0: (519, 2012, 64) with tail of shape (3, 2012, 64)


  0%|          | 0/17 [00:00<?, ?it/s]

data.shape at 0: (541, 2012, 64) with tail of shape (7, 2012, 64)


  0%|          | 0/17 [00:00<?, ?it/s]

data.shape at 0: (533, 2012, 64) with tail of shape (29, 2012, 64)


  0%|          | 0/17 [00:00<?, ?it/s]

data.shape at 0: (569, 2012, 64) with tail of shape (21, 2012, 64)


  0%|          | 0/18 [00:00<?, ?it/s]

data.shape at 0: (504, 2012, 64) with tail of shape (25, 2012, 64)


  0%|          | 0/16 [00:00<?, ?it/s]

data.shape at 0: (564, 2012, 64) with tail of shape (24, 2012, 64)


  0%|          | 0/18 [00:00<?, ?it/s]

data.shape at 0: (575, 2012, 64) with tail of shape (20, 2012, 64)


  0%|          | 0/18 [00:00<?, ?it/s]

data.shape at 0: (522, 2012, 64) with tail of shape (31, 2012, 64)


  0%|          | 0/17 [00:00<?, ?it/s]

data.shape at 0: (465, 2012, 64) with tail of shape (10, 2012, 64)


  0%|          | 0/15 [00:00<?, ?it/s]

data.shape at 0: (550, 2012, 64) with tail of shape (17, 2012, 64)


  0%|          | 0/18 [00:00<?, ?it/s]

data.shape at 0: (545, 2012, 64) with tail of shape (6, 2012, 64)


  0%|          | 0/18 [00:00<?, ?it/s]

data.shape at 1: (465, 2012, 64) with tail of shape (0,)


  0%|          | 0/15 [00:00<?, ?it/s]

data.shape at 1: (548, 2012, 64) with tail of shape (17, 2012, 64)


  0%|          | 0/18 [00:00<?, ?it/s]

data.shape at 1: (348, 2012, 64) with tail of shape (4, 2012, 64)


  0%|          | 0/11 [00:00<?, ?it/s]

data.shape at 1: (564, 2012, 64) with tail of shape (28, 2012, 64)


  0%|          | 0/18 [00:00<?, ?it/s]

data.shape at 1: (273, 2012, 64) with tail of shape (20, 2012, 64)


  0%|          | 0/9 [00:00<?, ?it/s]

data.shape at 1: (542, 2012, 64) with tail of shape (17, 2012, 64)


  0%|          | 0/17 [00:00<?, ?it/s]

data.shape at 1: (362, 2012, 64) with tail of shape (30, 2012, 64)


  0%|          | 0/12 [00:00<?, ?it/s]

data.shape at 1: (409, 2012, 64) with tail of shape (10, 2012, 64)


  0%|          | 0/13 [00:00<?, ?it/s]

data.shape at 1: (408, 2012, 64) with tail of shape (25, 2012, 64)


  0%|          | 0/13 [00:00<?, ?it/s]

data.shape at 1: (452, 2012, 64) with tail of shape (24, 2012, 64)


  0%|          | 0/15 [00:00<?, ?it/s]

data.shape at 2: (299, 2012, 64) with tail of shape (0,)


  0%|          | 0/10 [00:00<?, ?it/s]

data.shape at 2: (317, 2012, 64) with tail of shape (11, 2012, 64)


  0%|          | 0/10 [00:00<?, ?it/s]

data.shape at 2: (374, 2012, 64) with tail of shape (29, 2012, 64)


  0%|          | 0/12 [00:00<?, ?it/s]

data.shape at 2: (393, 2012, 64) with tail of shape (22, 2012, 64)


  0%|          | 0/13 [00:00<?, ?it/s]

data.shape at 2: (299, 2012, 64) with tail of shape (9, 2012, 64)


  0%|          | 0/10 [00:00<?, ?it/s]

data.shape at 2: (380, 2012, 64) with tail of shape (11, 2012, 64)


  0%|          | 0/12 [00:00<?, ?it/s]

data.shape at 2: (372, 2012, 64) with tail of shape (28, 2012, 64)


  0%|          | 0/12 [00:00<?, ?it/s]

data.shape at 2: (352, 2012, 64) with tail of shape (20, 2012, 64)


  0%|          | 0/11 [00:00<?, ?it/s]

data.shape at 2: (356, 2012, 64) with tail of shape (20, 2012, 64)


  0%|          | 0/12 [00:00<?, ?it/s]

data.shape at 2: (317, 2012, 64) with tail of shape (4, 2012, 64)


  0%|          | 0/10 [00:00<?, ?it/s]

data.shape at 2: (295, 2012, 64) with tail of shape (29, 2012, 64)


  0%|          | 0/10 [00:00<?, ?it/s]

data.shape at 2: (351, 2012, 64) with tail of shape (7, 2012, 64)


  0%|          | 0/11 [00:00<?, ?it/s]

data.shape at 3: (279, 2012, 64) with tail of shape (0,)


  0%|          | 0/9 [00:00<?, ?it/s]

data.shape at 3: (123, 2012, 64) with tail of shape (23, 2012, 64)


  0%|          | 0/4 [00:00<?, ?it/s]

data.shape at 3: (495, 2012, 64) with tail of shape (27, 2012, 64)


  0%|          | 0/16 [00:00<?, ?it/s]

data.shape at 3: (309, 2012, 64) with tail of shape (15, 2012, 64)


  0%|          | 0/10 [00:00<?, ?it/s]

data.shape at 3: (215, 2012, 64) with tail of shape (21, 2012, 64)


  0%|          | 0/7 [00:00<?, ?it/s]

data.shape at 3: (297, 2012, 64) with tail of shape (23, 2012, 64)


  0%|          | 0/10 [00:00<?, ?it/s]

data.shape at 3: (188, 2012, 64) with tail of shape (9, 2012, 64)


  0%|          | 0/6 [00:00<?, ?it/s]

data.shape at 3: (393, 2012, 64) with tail of shape (28, 2012, 64)


  0%|          | 0/13 [00:00<?, ?it/s]

data.shape at 3: (450, 2012, 64) with tail of shape (9, 2012, 64)


  0%|          | 0/15 [00:00<?, ?it/s]

data.shape at 3: (441, 2012, 64) with tail of shape (2, 2012, 64)


  0%|          | 0/14 [00:00<?, ?it/s]

data.shape at 3: (481, 2012, 64) with tail of shape (25, 2012, 64)


  0%|          | 0/16 [00:00<?, ?it/s]

data.shape at 3: (321, 2012, 64) with tail of shape (1, 2012, 64)


  0%|          | 0/11 [00:00<?, ?it/s]

data.shape at 3: (223, 2012, 64) with tail of shape (1, 2012, 64)


  0%|          | 0/7 [00:00<?, ?it/s]

data.shape at 3: (238, 2012, 64) with tail of shape (31, 2012, 64)


  0%|          | 0/8 [00:00<?, ?it/s]

data.shape at 3: (431, 2012, 64) with tail of shape (14, 2012, 64)


  0%|          | 0/14 [00:00<?, ?it/s]

data.shape at 3: (432, 2012, 64) with tail of shape (15, 2012, 64)


  0%|          | 0/14 [00:00<?, ?it/s]

data.shape at 3: (248, 2012, 64) with tail of shape (16, 2012, 64)


  0%|          | 0/8 [00:00<?, ?it/s]

data.shape at 3: (215, 2012, 64) with tail of shape (24, 2012, 64)


  0%|          | 0/7 [00:00<?, ?it/s]

data.shape at 3: (476, 2012, 64) with tail of shape (23, 2012, 64)


  0%|          | 0/15 [00:00<?, ?it/s]

data.shape at 3: (263, 2012, 64) with tail of shape (28, 2012, 64)


  0%|          | 0/9 [00:00<?, ?it/s]

data.shape at 3: (459, 2012, 64) with tail of shape (7, 2012, 64)


  0%|          | 0/15 [00:00<?, ?it/s]

data.shape at 3: (246, 2012, 64) with tail of shape (11, 2012, 64)


  0%|          | 0/8 [00:00<?, ?it/s]

data.shape at 3: (374, 2012, 64) with tail of shape (22, 2012, 64)


  0%|          | 0/12 [00:00<?, ?it/s]

data.shape at 3: (345, 2012, 64) with tail of shape (22, 2012, 64)


  0%|          | 0/11 [00:00<?, ?it/s]

data.shape at 3: (273, 2012, 64) with tail of shape (25, 2012, 64)


  0%|          | 0/9 [00:00<?, ?it/s]

data.shape at 3: (353, 2012, 64) with tail of shape (17, 2012, 64)


  0%|          | 0/12 [00:00<?, ?it/s]

data.shape at 3: (249, 2012, 64) with tail of shape (1, 2012, 64)


  0%|          | 0/8 [00:00<?, ?it/s]

data.shape at 3: (115, 2012, 64) with tail of shape (25, 2012, 64)


  0%|          | 0/4 [00:00<?, ?it/s]

data.shape at 3: (426, 2012, 64) with tail of shape (19, 2012, 64)


  0%|          | 0/14 [00:00<?, ?it/s]

data.shape at 3: (391, 2012, 64) with tail of shape (10, 2012, 64)


  0%|          | 0/13 [00:00<?, ?it/s]

data.shape at 3: (409, 2012, 64) with tail of shape (7, 2012, 64)


  0%|          | 0/13 [00:00<?, ?it/s]

data.shape at 3: (198, 2012, 64) with tail of shape (25, 2012, 64)


  0%|          | 0/7 [00:00<?, ?it/s]

data.shape at 3: (390, 2012, 64) with tail of shape (6, 2012, 64)


  0%|          | 0/13 [00:00<?, ?it/s]

data.shape at 3: (406, 2012, 64) with tail of shape (6, 2012, 64)


  0%|          | 0/13 [00:00<?, ?it/s]

data.shape at 3: (364, 2012, 64) with tail of shape (22, 2012, 64)


  0%|          | 0/12 [00:00<?, ?it/s]

NOTE: csv's were converted to .feather format later

# Torch Dataset

In [194]:
class EEGDataset(Dataset):
    def __init__(self, path: str, audio_maps: dict, fragment_length: int = 2012,
                 partition_size: int = 32, sample_rate: int = 44100, sound_channel: int = 1,
                 preloaded_audios: bool = False, val_ratio: float = 0.15):
        '''
        path: path to sections (folders)
        audio_maps: two-level map: section names -> labels -> audio_paths
        fragment_lengtht: length of fragment after label
        partition_size: number of nonzero labels in each csv file
        sample_rate: audios' SR
        sound_channel: mono audio channel
        preloaded_audios: if audio_maps lead to paths (False) or directly to data like raw signal or FT (True)
        val_ratio: float in range [0, 1], N_val / N
        '''
        super().__init__()
        self.sections = sorted(os.listdir(path))
        assert set(self.sections) == set(audio_maps.keys()), "Sections must be the same!"
        self.audio_maps = audio_maps 
        self.preloaded_audios = preloaded_audios
        
        all_paths = [[os.path.join(path, sec, file) for file in sorted(os.listdir(os.path.join(path, sec)))] for sec in self.sections]
        num_all_files = [len(elem) for elem in all_paths]
        splits = [int(elem * val_ratio) for elem in num_all_files]
        
        self.val_paths = [sec_paths[:split] for sec_paths, split in zip(all_paths, splits)]
        self.paths = [sec_paths[split:] for sec_paths, split in zip(all_paths, splits)]
        
        self.sec_num_files = [len(elem) for elem in self.paths]
        self.sec_cumnum = np.cumsum(self.sec_num_files) * partition_size
        self.total_num_files = sum(self.sec_num_files)
        
        self.sec_num_val_files = [len(elem) for elem in self.val_paths]
        self.sec_val_cumnum = np.cumsum(self.sec_num_val_files) * partition_size
        self.total_num_val_files = sum(self.sec_num_val_files)
        
        self.partition_size = partition_size
        self.fragment_length = fragment_length
        self.sr = sample_rate
        self.sound_channel = sound_channel
        self.val_mode = False
        
    def __len__(self) -> int:
        num = self.total_num_val_files if self.val_mode else self.total_num_files
        return num * self.partition_size
    
    def set_val_mode(self, mode: bool):
        '''
        Switch between train/val subsets
        mode: 0 -> train, 1 -> val
        '''
        assert mode in [True, False], "Incorrect mode type!"
        self.val_mode = mode
        return self
    
    def to_section(self, idx: int) -> Tuple[int, int]:
        '''
        Get file section and inner index by its absolute index
        idx: absolute index
        return inner index, section number (idx)
        '''
        cumnum = self.sec_val_cumnum if self.val_mode else self.sec_cumnum
        section = np.where(idx < cumnum)[0][0]
        section_idx = idx if (section == 0) else (idx - cumnum[section - 1])
        return section, section_idx
    
    def get_audio(self, section: str, label: int) -> torch.Tensor:
        '''
        Get audio by section and corresponding label
        section: number of section
        label: one of label values in given section
        return: the audio
        '''
        section_name = self.sections[section]
        
        if self.preloaded_audios:
            return self.audio_maps[section_name][label]
        
        audio, current_sr = torchaudio.load(self.audio_maps[section_name][label])
        audio = torchaudio.functional.resample(audio, orig_freq=current_sr, new_freq=self.sr)
        return audio[self.sound_channel]
    
    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        '''
        int idx: file ID
        return: EEG fragment with its corresponding audio
        '''
        section, section_idx = self.to_section(idx)
        paths_source = self.val_paths if self.val_mode else self.paths
        file_path = paths_source[section][section_idx // self.partition_size]
        
        start = (section_idx % self.partition_size) * self.fragment_length
        end = start + self.fragment_length
        
        data = pd.read_feather(file_path).to_numpy()
        x, label = torch.tensor(data[start:end, 1:]), data[start, 0].astype(int)
        
        audio = self.get_audio(section, label)
        
        return x, audio

In [170]:
base = os.path.join("Data", "Vartanov", "audios")

A = "A.wav"
B = "B.wav"
F = "F.wav"
G = "G.wav"
M = "M.wav"
R = "R.wav"
U = "U.wav"

Ba = "Ba.wav"
Bu = "Bu.wav"
Fa = "Fa.wav"
Fu = "Fu.wav"
Ga = "Ga.wav"
Gu = "Gu.wav"
Ma = "Ma.wav"
Mu = "Mu.wav"
Ra = "Ra.wav"
Ru = "Ru.wav"

Biblioteka = "St1.wav"
Raketa = "St2.wav"
Kurier = "St3.wav"
Ograda = "St4.wav"
Haketa = "St5.wav"

phonemes_m3_labels = {
    12: os.path.join(base, "phonemes", A),
    22: os.path.join(base, "phonemes", A),
    13: os.path.join(base, "phonemes", B),
    23: os.path.join(base, "phonemes", B),
    14: os.path.join(base, "phonemes", F),
    24: os.path.join(base, "phonemes", F),
    15: os.path.join(base, "phonemes", G),
    25: os.path.join(base, "phonemes", G),
    16: os.path.join(base, "phonemes", M),
    26: os.path.join(base, "phonemes", M),
    17: os.path.join(base, "phonemes", R),
    27: os.path.join(base, "phonemes", R),
    18: os.path.join(base, "phonemes", U),
    28: os.path.join(base, "phonemes", U)
}

phonemes_m4_labels = {
    1: os.path.join(base, "phonemes", A),
    11: os.path.join(base, "phonemes", A),
    2: os.path.join(base, "phonemes", B),
    12: os.path.join(base, "phonemes", B),
    3: os.path.join(base, "phonemes", F),
    13: os.path.join(base, "phonemes", F),
    4: os.path.join(base, "phonemes", G),
    14: os.path.join(base, "phonemes", G),
    5: os.path.join(base, "phonemes", M),
    15: os.path.join(base, "phonemes", M),
    6: os.path.join(base, "phonemes", R),
    16: os.path.join(base, "phonemes", R),
    7: os.path.join(base, "phonemes", U),
    17: os.path.join(base, "phonemes", U)
}

syllables_labels = {
    1: os.path.join(base, "syllables", Ba),
    11: os.path.join(base, "syllables", Ba),
    2: os.path.join(base, "syllables", Fa),
    12: os.path.join(base, "syllables", Fa),
    3: os.path.join(base, "syllables", Ga),
    13: os.path.join(base, "syllables", Ga),
    4: os.path.join(base, "syllables", Ma),
    14: os.path.join(base, "syllables", Ma),
    5: os.path.join(base, "syllables", Ra),
    15: os.path.join(base, "syllables", Ra),
    6: os.path.join(base, "syllables", Bu),
    16: os.path.join(base, "syllables", Bu),
    7: os.path.join(base, "syllables", Ru),
    17: os.path.join(base, "syllables", Ru),
    8: os.path.join(base, "syllables", Mu),
    18: os.path.join(base, "syllables", Mu),
    9: os.path.join(base, "syllables", Fu),
    19: os.path.join(base, "syllables", Fu),
    10: os.path.join(base, "syllables", Gu),
    20: os.path.join(base, "syllables", Gu)
}

words_labels = {
    11: os.path.join(base, "words", Biblioteka),
    21: os.path.join(base, "words", Biblioteka),
    12: os.path.join(base, "words", Raketa),
    22: os.path.join(base, "words", Raketa),
    13: os.path.join(base, "words", Kurier),
    23: os.path.join(base, "words", Kurier),
    14: os.path.join(base, "words", Ograda),
    24: os.path.join(base, "words", Ograda),
    15: os.path.join(base, "words", Haketa),
    25: os.path.join(base, "words", Haketa)
}

audio_map = {
    "syllables": syllables_labels,
    "phonemes_m3": phonemes_m3_labels,
    "phonemes_m4": phonemes_m4_labels,
    "words": words_labels
}

In [189]:
dataset = EEGDataset(os.path.join(".", "Data", "Vartanov", "feather"), audio_map)

Train mode:

In [192]:
dataset.val_mode = False
dataset.sec_cumnum

array([ 3584,  6912, 13056, 22368])

In [193]:
x, audio = dataset[0]
Audio(audio, rate=dataset.sr)

In [173]:
x, audio = dataset[3584]
Audio(audio, rate=dataset.sr)

In [174]:
x, audio = dataset[6912]
Audio(audio, rate=dataset.sr)

In [177]:
x, audio = dataset[13056]
Audio(audio, rate=dataset.sr)

Val mode:

In [191]:
dataset.val_mode = True
dataset.sec_val_cumnum

array([ 608, 1184, 2240, 3872])

In [179]:
x, audio = dataset[0]
Audio(audio, rate=dataset.sr)

In [180]:
x, audio = dataset[608]
Audio(audio, rate=dataset.sr)

In [181]:
x, audio = dataset[1184]
Audio(audio, rate=dataset.sr)

In [184]:
x, audio = dataset[2240]
Audio(audio, rate=dataset.sr)

In [185]:
dataset.val_mode = True
a = len(dataset)
dataset.val_mode = False
b = len(dataset)
a / (a + b)

0.1475609756097561