In [9]:
from glob import glob
import malaya_speech
import random
import numpy as np
from sklearn.utils import shuffle
from datasets import Audio
import os

In [51]:
import torch
from collections import namedtuple

# Container returned by the `batch` collate function: padded feature tensor,
# per-item lengths, and integer targets.
Batch = namedtuple("Batch", "features features_length targets")

class Dataset(torch.utils.data.IterableDataset):
    """Infinite stream of ``(frame, label)`` pairs for frame-level VAD training.

    Two sources are interleaved:

    * speech corpora (LibriSpeech, edge-tts, Wavenet, MUSAN speech, VCTK),
      labelled with webrtc VAD and occasionally mixed with noise/music clips;
    * AMI meetings, labelled from the ``MixHeadset.train.rttm`` annotations.

    Each yielded item is ``(torch.float32 tensor, int label)``. The label is 1
    when the VAD (speech corpora) or the annotation (AMI) marks the frame as
    speech, and 0 otherwise. Corpus locations are hard-coded absolute paths.

    NOTE(review): ``__iter__`` does not consult
    ``torch.utils.data.get_worker_info()``, so with ``num_workers > 0`` every
    DataLoader worker runs its own full stream.
    """

    sr = 16000

    def __init__(self, samples_per_source=10000):
        """Collect, sample and shuffle the file lists of every source.

        Args:
            samples_per_source: maximum number of files drawn at random from
                each large speech corpus (LibriSpeech, edge-tts, Wavenet,
                VCTK). A corpus with fewer files contributes all of them
                instead of raising ``ValueError``.
        """
        super().__init__()

        files = self._sample_paths('/home/husein/ssd2/LibriSpeech/*/*/*/*.flac', samples_per_source)
        edge_tts = self._sample_paths('/home/husein/ssd2/*-tts-wav/*.wav', samples_per_source)
        wavenet = self._sample_paths('/home/husein/ssd2/ms-MY-Wavenet-*/*.mp3', samples_per_source)
        musan_speech = glob('/home/husein/ssd2/noise/musan/speech/*/*')
        vctk = self._sample_paths('/home/husein/ssd2/wav48_silence_trimmed/*/*.flac', samples_per_source)

        speeches = files + edge_tts + wavenet + musan_speech + vctk
        random.shuffle(speeches)
        self.speeches = speeches

        mic_noise = glob('/home/husein/ssd2/noise/mic-noise/*')
        non_speech = glob('/home/husein/ssd2/noise/Nonspeech/*')
        musan_noise = glob('/home/husein/ssd2/noise/musan/noise/*/*.wav')
        musan_music = glob('/home/husein/ssd2/noise/musan/music/*/*.wav')
        noises = mic_noise + non_speech + musan_noise + musan_music
        # Drop clips of 10 MB or more: a chosen noise clip is decoded in full.
        noises = [f for f in noises if os.path.getsize(f) / 1e6 < 10]
        random.shuffle(noises)
        self.noises = noises

        ami = glob('/home/husein/speech-bahasa/ami/amicorpus/*/*/*.wav')
        # Meeting id (file name without ".wav") -> audio path.
        self.ami = {os.path.split(f)[1].replace('.wav', ''): f for f in ami}
        self.annotations = malaya_speech.extra.rttm.load('/home/husein/speech-bahasa/MixHeadset.train.rttm')
        # Only draw meetings whose audio was found, so `self.ami[mix]` cannot
        # raise KeyError in the middle of an epoch.
        self.annotations_keys = [k for k in self.annotations.keys() if k in self.ami]

        self.audio = Audio(sampling_rate=self.sr)

        # Candidate frame lengths passed to malaya_speech.generator.frames
        # (presumably milliseconds, matching the 30 used for webrtc VAD).
        self.frame_sizes = [30, 50, 63]

    @staticmethod
    def _sample_paths(pattern, k):
        """Return up to ``k`` distinct random paths matching glob ``pattern``."""
        paths = glob(pattern)
        return random.sample(paths, min(k, len(paths)))

    def __iter__(self):
        """Yield ``(frame_tensor, label)`` forever.

        For every speech file, its VAD-labelled frames are yielded first,
        followed by the annotated frames of one randomly chosen AMI meeting.
        """
        while True:
            for path in self.speeches:
                yield from self._speech_examples(path)
                yield from self._ami_examples()

    def _speech_examples(self, path):
        """Yield shuffled, VAD-labelled (optionally noise-mixed) frames of one file."""
        wav = self.audio.decode_example(self.audio.encode_example(path))['array']
        if random.random() > 0.6:
            wav = malaya_speech.augmentation.waveform.random_pitch(wav)

        # webrtc VAD runs on an integer-PCM copy; the 30th-percentile absolute
        # amplitude is used as its minimum-amplitude threshold.
        wav_int = malaya_speech.astype.float_to_int(wav)
        vad = malaya_speech.vad.webrtc(minimum_amplitude=int(np.quantile(np.abs(wav_int), 0.3)))
        frames_int = malaya_speech.generator.frames(wav_int, 30, self.sr, False)
        frames_float = malaya_speech.generator.frames(wav, 30, self.sr, False)
        # Pair each float frame with the VAD decision on its int twin.
        labelled = [(float_frame, vad(int_frame)) for float_frame, int_frame in zip(frames_float, frames_int)]
        grouped = malaya_speech.group.group_frames(labelled)

        features, labels = [], []
        for g in grouped:
            segment, is_speech = g[0], g[1]
            if random.random() > 0.8:
                # Speech segments get lighter noise than non-speech segments.
                if is_speech:
                    factor = random.uniform(0.1, 0.4)
                else:
                    factor = random.uniform(0.4, 0.9)
                noise = self.audio.decode_example(self.audio.encode_example(random.choice(self.noises)))['array']
                segment.array = malaya_speech.augmentation.waveform.add_noise(segment.array, noise,
                                                                              factor=factor)

            # Re-split each segment with a random frame length.
            frame_size = random.choice(self.frame_sizes)
            chunks = malaya_speech.generator.frames(segment.array, frame_size, self.sr, False)
            features.extend(chunk.array for chunk in chunks)
            labels.extend([int(is_speech)] * len(chunks))

        if features:
            features, labels = shuffle(features, labels)
        for feature, label in zip(features, labels):
            yield torch.tensor(feature, dtype=torch.float32), label

    def _ami_examples(self):
        """Yield annotation-labelled frames of one randomly chosen AMI meeting."""
        mix = random.choice(self.annotations_keys)
        annotation = self.annotations[mix]
        # NOTE(review): assumes malaya_speech.load returns audio at self.sr — confirm.
        wav, _ = malaya_speech.load(self.ami[mix])
        if random.random() > 0.6:
            wav = malaya_speech.augmentation.waveform.random_pitch(wav)

        frame_size = random.choice(self.frame_sizes)
        for frame in malaya_speech.generator.frames(wav, frame_size, self.sr, False):
            cropped = annotation.crop(frame.timestamp, frame.timestamp + frame.duration)
            # NOTE(review): `_labelNeedsUpdate` is a private attribute of the
            # annotation object; a non-empty value is taken to mean "some
            # speaker label overlaps this frame". Verify against the rttm API.
            label = 1 if len(cropped._labelNeedsUpdate) else 0
            yield torch.tensor(frame.array, dtype=torch.float32), label

In [52]:
# Building the dataset globs, samples and shuffles every corpus listed in
# Dataset.__init__, so this cell touches the filesystem and takes a while.
dataset = Dataset()

In [53]:
def batch(batches):
    """Collate a list of ``(waveform, label)`` pairs into a ``Batch``.

    Waveforms are zero-padded on the right to the longest item. Their
    original lengths go into an int32 tensor, and the labels into an int64
    tensor.
    """
    waveforms = [sample[0] for sample in batches]
    padded = torch.nn.utils.rnn.pad_sequence(waveforms, batch_first=True)
    lengths = torch.tensor([len(sample[0]) for sample in batches], dtype=torch.int32)
    labels = torch.tensor([sample[1] for sample in batches], dtype=torch.int64)
    return Batch(padded, lengths, labels)

# Mini-batch size for streaming from the infinite iterable dataset.
BATCH_SIZE = 4

# Keep the DataLoader under its own name so it is not shadowed by its
# iterator; `loader` is the iterator consumed with `next(...)` below.
data_loader = torch.utils.data.DataLoader(dataset, batch_size=BATCH_SIZE, collate_fn=batch)
loader = iter(data_loader)

In [87]:
# Pull batches until one contains at least one positive (speech) label.
# Targets are 0/1 labels, so "any label > 0" states the intent directly
# without a numpy round-trip and float mean.
while True:
    b = next(loader)
    if (b.targets > 0).any():
        break