In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os
import glob
import regex
from typing import Dict, List, Tuple, Union

import tqdm.notebook as tqdm

import numpy as np
import math
import pandas as pd

import torch
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import random_split

import librosa
import torchaudio

import matplotlib_inline
import matplotlib.pyplot as plt

from IPython.display import display, Audio, Markdown

%matplotlib inline
matplotlib_inline.backend_inline.set_matplotlib_formats('pdf', 'svg')

In [3]:
sections = ["syllables", "phonemes_m3", "phonemes_m4", "words"]

# Dataset (del this section)

In [None]:
decoded_path = os.path.join(".", "Data", "Vartanov", "decoded")

In [4]:
def get_dataset(path, sr=1006.24, delta=0.5, delimiter=';'):
    dataset = []
    for filename in os.listdir(path):
        file_path = os.path.join(path, filename)
        contents = pd.read_csv(file_path, delimiter=delimiter).to_numpy()
        labels_indices = np.where(contents[:, 1] != 0)[0]
        num_samples = int(delta * sr)
        samples_indices = [np.arange(i, i + num_samples) for i in labels_indices]
        data = contents[samples_indices]
        dataset.append(data)
    return np.vstack(dataset)[..., 1:]

In [40]:
partition_size = 32

for sec in range(4):
    data = get_dataset(os.path.join(decoded_path, sections[sec]))
    print(f"data.shape at {sec}: {data.shape}")
    n_chunks = math.ceil(len(data) / partition_size)
    for i in tqdm.tqdm(range(n_chunks)):
        path_ = f"./Data/Vartanov/shortened/{sections[sec]}/{i}.csv"
        start = i * partition_size
        end = (i + 1) * partition_size
        if len(data[start:end]) < partition_size:
            # easiest way
            continue
        np.savetxt(path_, data[start:end].reshape(-1, data.shape[-1]), delimiter=",")

data.shape at 0: (4197, 503, 64)


  0%|          | 0/132 [00:00<?, ?it/s]

data.shape at 1: (7201, 503, 64)


  0%|          | 0/226 [00:00<?, ?it/s]

data.shape at 2: (3915, 503, 64)


  0%|          | 0/123 [00:00<?, ?it/s]

data.shape at 3: (9261, 503, 64)


  0%|          | 0/290 [00:00<?, ?it/s]

# Torch Dataset

In [145]:
class EEGDataset(Dataset):
    def __init__(self, path: str, audio_maps: dict, fragment_length: int = 503, partition_size: int = 32):
        '''
        path: path to sections (folders)
        audios: two-level map: section names -> labels -> audio_paths
        fragment_lengtht: length of fragment after label
        partition_size: number of nonzero labels in each csv file
        '''
        super().__init__()
        self.sections = os.listdir(path)
        assert set(sections) == set(audio_maps.keys()), "Sections must be the same!"
        self.audio_maps = audio_maps 
        self.paths = [[os.path.join(path, sec, file) for file in os.listdir(os.path.join(path, sec))] for sec in self.sections]
        self.sec_num_files = [len(elem) for elem in self.paths]
        self.sec_cumnum = np.cumsum(self.sec_num_files) * partition_size
        self.total_num_files = sum(self.sec_num_files)
        self.partition_size = partition_size
        self.fragment_length = fragment_length
        
    def __len__(self) -> int:
        return self.total_num_files * self.partition_size
    
    def to_section(self, idx):
        section = np.where(idx <  self.sec_cumnum)[0][0]
        section_idx = idx if (section == 0) else (idx - self.sec_cumnum[section - 1])
        return section, section_idx
    
    def get_audio(self, section, label):
        section_name = self.sections[section]
        return torchaudio.load(self.audio_maps[section_name][label])
    
    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        '''
        int idx: file ID
        return: EEG fragment with its corresponding audio
        '''
        section, section_idx = self.to_section(idx)
        file_path = self.paths[section][section_idx // self.partition_size]
        data = pd.read_csv(file_path, header=None).to_numpy()
        start = (section_idx % self.partition_size) * self.fragment_length
        end = start + self.fragment_length
        x, label = torch.tensor(data[start:end, 1:]), data[start, 0].astype(int)
        audio = self.get_audio(section, label)
        return x, audio

In [146]:
base = os.path.join("Data", "audios")

A = "A.wav"
B = "B.wav"
F = "F.wav"
G = "G.wav"
M = "M.wav"
R = "R.wav"
U = "U.wav"

Ba = "Ba.wav"
Bu = "Bu.wav"
Fa = "Fa.wav"
Fu = "Fu.wav"
Ga = "Ga.wav"
Gu = "Gu.wav"
Ma = "Ma.wav"
Mu = "Mu.wav"
Ra = "Ra.wav"
Ru = "Ru.wav"

Biblioteka = "St1.wav"
Raketa = "St2.wav"
Kurier = "St3.wav"
Ograda = "St4.wav"
Haketa = "St5.wav"

phonemes_m3_labels = {
    12: os.path.join(base, "phonemes", A),
    22: os.path.join(base, "phonemes", A),
    13: os.path.join(base, "phonemes", B),
    23: os.path.join(base, "phonemes", B),
    14: os.path.join(base, "phonemes", F),
    24: os.path.join(base, "phonemes", F),
    15: os.path.join(base, "phonemes", G),
    25: os.path.join(base, "phonemes", G),
    16: os.path.join(base, "phonemes", M),
    26: os.path.join(base, "phonemes", M),
    17: os.path.join(base, "phonemes", R),
    27: os.path.join(base, "phonemes", R),
    18: os.path.join(base, "phonemes", U),
    28: os.path.join(base, "phonemes", U)
}

phonemes_m4_labels = {
    1: os.path.join(base, "phonemes", A),
    11: os.path.join(base, "phonemes", A),
    2: os.path.join(base, "phonemes", B),
    12: os.path.join(base, "phonemes", B),
    3: os.path.join(base, "phonemes", F),
    13: os.path.join(base, "phonemes", F),
    4: os.path.join(base, "phonemes", G),
    14: os.path.join(base, "phonemes", G),
    5: os.path.join(base, "phonemes", M),
    15: os.path.join(base, "phonemes", M),
    6: os.path.join(base, "phonemes", R),
    16: os.path.join(base, "phonemes", R),
    7: os.path.join(base, "phonemes", U),
    17: os.path.join(base, "phonemes", U)
}

syllables_labels = {
    1: os.path.join(base, "syllables", Ba),
    11: os.path.join(base, "syllables", Ba),
    2: os.path.join(base, "syllables", Fa),
    12: os.path.join(base, "syllables", Fa),
    3: os.path.join(base, "syllables", Ga),
    13: os.path.join(base, "syllables", Ga),
    4: os.path.join(base, "syllables", Ma),
    14: os.path.join(base, "syllables", Ma),
    5: os.path.join(base, "syllables", Ra),
    15: os.path.join(base, "syllables", Ra),
    6: os.path.join(base, "syllables", Bu),
    16: os.path.join(base, "syllables", Bu),
    7: os.path.join(base, "syllables", Ru),
    17: os.path.join(base, "syllables", Ru),
    8: os.path.join(base, "syllables", Mu),
    18: os.path.join(base, "syllables", Mu),
    9: os.path.join(base, "syllables", Fu),
    19: os.path.join(base, "syllables", Fu),
    10: os.path.join(base, "syllables", Gu),
    20: os.path.join(base, "syllables", Gu)
}

words_labels = {
    11: os.path.join(base, "words", Biblioteka),
    21: os.path.join(base, "words", Biblioteka),
    12: os.path.join(base, "words", Raketa),
    22: os.path.join(base, "words", Raketa),
    13: os.path.join(base, "words", Kurier),
    23: os.path.join(base, "words", Kurier),
    14: os.path.join(base, "words", Ograda),
    24: os.path.join(base, "words", Ograda),
    15: os.path.join(base, "words", Haketa),
    25: os.path.join(base, "words", Haketa)
}

audio_map = {
    "syllables": syllables_labels,
    "phonemes_m3": phonemes_m3_labels,
    "phonemes_m4": phonemes_m4_labels,
    "words": words_labels
}

In [147]:
dataset = EEGDataset(os.path.join(".", "Data", "Vartanov", "shortened"), audio_map)

In [148]:
dataset.sec_cumnum

array([ 4192,  8096, 15296, 24544])

In [149]:
x, audio = dataset[0]
Audio(audio[0], rate=audio[1])

In [154]:
x, audio = dataset[4192]
Audio(audio[0], rate=audio[1])

In [157]:
x, audio = dataset[8096]
Audio(audio[0], rate=audio[1])

In [158]:
x, audio = dataset[15296]
Audio(audio[0], rate=audio[1])