In [1]:
from collections import namedtuple
from datasets import Audio
from glob import glob
import numpy as np
import torch
import malaya_speech
import malaya_speech.augmentation.waveform as augmentation
import random
import os

`pyaudio` is not available, `malaya_speech.streaming.pyaudio` is not able to use.


In [2]:
# Container produced by the `batch` collate function: padded waveforms,
# their unpadded lengths, and integer class targets.
Batch = namedtuple("Batch", "features features_length targets")

# Class names for the speaker-count classifier. A sample's target is the
# index of its label string in this list (see `labels.index(label)` in Dataset).
labels = (
    ['0 speaker', '1 speaker']
    + [f'{k} speakers' for k in range(2, 6)]
    + ['more than 5 speakers']
)

In [15]:
class Dataset(torch.utils.data.IterableDataset):
    """Infinite stream of short audio frames labelled with a speaker count.

    Each training mixture is one of two kinds:

    * 1-5 background-noise clips, labelled ``'0 speaker'``.
    * ``count`` speech clips, labelled by ``count``. ``count`` is 1..5, or 6..10,
      and 6..10 are collapsed into ``'more than 5 speakers'``.

    The clips are summed and peak-normalised, then cut into frames with
    ``malaya_speech.generator.frames``. Frames are buffered (at least
    ``queue_size``), shuffled, and replayed ``repeat`` times before a new
    buffer is built.

    Yields ``(torch.float32 tensor, int)`` pairs. The int indexes the
    module-level ``labels`` list.
    """

    sr = 16000
    # Minimum number of frames buffered before shuffling and yielding.
    queue_size = 200

    def __init__(self, sample_size=10000, frame_size=300, repeat=2):
        """Collect and shuffle the speech and noise file lists.

        Args:
            sample_size: maximum number of files drawn from each large speech
                corpus. If a corpus has fewer files, all of them are used.
            frame_size: frame duration passed to
                ``malaya_speech.generator.frames``. It is presumably in
                milliseconds, since the VAD call uses 30 with the same helper.
                TODO confirm.
            repeat: how many shuffled passes are made over each frame buffer.
        """
        # Plain super() so that IterableDataset.__init__ actually runs.
        # The original super(Dataset) form was unbound and skipped it.
        super().__init__()

        speeches = (
            self._sample_paths('/home/husein/ssd2/LibriSpeech/*/*/*/*.flac', sample_size)
            + self._sample_paths('/home/husein/ssd2/*-tts-wav/*.wav', sample_size)
            + self._sample_paths('/home/husein/ssd2/ms-MY-Wavenet-*/*.mp3', sample_size)
            # musan speech is used in full (not subsampled).
            + glob('/home/husein/ssd2/noise/musan/speech/*/*')
            + self._sample_paths('/home/husein/ssd2/wav48_silence_trimmed/*/*.flac', sample_size)
            + self._sample_paths('/home/husein/ssd3/ST-CMDS-20170001_1-OS/*.wav', sample_size)
        )
        random.shuffle(speeches)
        self.speeches = speeches

        noises = (
            glob('/home/husein/ssd2/noise/mic-noise/*')
            + glob('/home/husein/ssd2/noise/Nonspeech/*')
            + glob('/home/husein/ssd2/noise/musan/noise/*/*.wav')
            + glob('/home/husein/ssd2/noise/musan/music/*/*.wav')
        )
        # Drop noise files of 10 MB (1e6-byte units) or more.
        noises = [f for f in noises if os.path.getsize(f) / 1e6 < 10]
        random.shuffle(noises)
        self.noises = noises

        self.audio = Audio(sampling_rate=self.sr)

        self.frame_size = frame_size
        self.repeat = repeat

    @staticmethod
    def _sample_paths(pattern, k):
        """Return up to ``k`` random paths matching ``pattern`` (all of them if fewer exist)."""
        paths = glob(pattern)
        return random.sample(paths, min(k, len(paths)))

    def random_sampling(self, s, length):
        """Randomly crop ``s`` via malaya_speech's ``random_sampling`` augmentation.

        ``length`` appears to be in milliseconds, because ``combine`` converts
        sample counts to ms before calling this.
        """
        return augmentation.random_sampling(s, sr=self.sr, length=length)

    def read_positive(self, f):
        """Decode a speech file, pitch-shift it, and keep only the voiced regions.

        Voicing is decided by malaya_speech's WebRTC VAD over 30-unit frames.
        The VAD's minimum amplitude is set to the 30th percentile of the
        absolute int signal. If the VAD keeps no frame at all, the whole
        pitch-shifted signal is returned instead of failing on an empty
        ``np.concatenate``.
        """
        y = self.audio.decode_example(self.audio.encode_example(f))['array']
        y = augmentation.random_pitch(y)
        y_int = malaya_speech.astype.float_to_int(y)
        vad = malaya_speech.vad.webrtc(
            minimum_amplitude=int(np.quantile(np.abs(y_int), 0.3))
        )
        # The VAD runs on the int copy. The float frames with the same index
        # are the ones that are kept.
        frames_int = malaya_speech.generator.frames(y_int, 30, self.sr, False)
        frames = malaya_speech.generator.frames(y, 30, self.sr, False)
        frames = [(frames[no], vad(frame)) for no, frame in enumerate(frames_int)]
        grouped = malaya_speech.group.group_frames(frames)
        voiced = [g[0].array for g in grouped if g[1]]
        if not voiced:
            return y
        return np.concatenate(voiced)

    def combine(self, w_samples):
        """Mix equal-duration crops of every sample and peak-normalise the sum.

        The crop duration is set by the shortest sample, converted to ms with
        ``self.sr``. An all-silent mixture is returned unscaled (zeros) rather
        than divided by a zero peak, which would give NaNs.
        """
        min_len = min(len(s) for s in w_samples)
        min_len_ms = int(min_len / self.sr * 1000)
        mixed = np.sum([self.random_sampling(s, min_len_ms) for s in w_samples], axis=0)
        peak = np.max(np.abs(mixed))
        if peak > 0:
            mixed = mixed / peak
        return mixed

    def _sample_mixture(self):
        """Draw one random mixture and return ``(waveform, index into labels)``."""
        count = random.randint(0, 6)
        if count == 0:
            # No speakers: mix 1-5 background noise clips.
            combined = random.sample(self.noises, random.randint(1, 5))
            ys = [self.audio.decode_example(self.audio.encode_example(f))['array'] for f in combined]
        else:
            if count == 6:
                # Bucket 6 stands for "more than 5". Draw the real count from 6-10.
                count = random.randint(6, 10)
            combined = random.sample(self.speeches, count)
            ys = [self.read_positive(f) for f in combined]

        if count > 5:
            label = 'more than 5 speakers'
        elif count > 1:
            label = f'{count} speakers'
        else:
            label = f'{count} speaker'

        # Fewer sources allow longer crops. The upper bound is 10000 // n,
        # but never below 5000.
        n = len(combined)
        w_samples = [
            self.random_sampling(y, length=random.randint(500, max(10000 // n, 5000)))
            for y in ys
        ]
        return self.combine(w_samples), labels.index(label)

    def __iter__(self):
        """Yield ``(float32 frame tensor, label index)`` pairs forever."""
        while True:
            queue = []
            while len(queue) < self.queue_size:
                X, label_index = self._sample_mixture()
                fs = malaya_speech.generator.frames(
                    X, self.frame_size, self.sr, append_ending_trail=False
                )
                queue.extend((frame.array, label_index) for frame in fs)

            for _ in range(self.repeat):
                random.shuffle(queue)
                for array, target in queue:
                    yield torch.tensor(array, dtype=torch.float32), target

In [16]:
# Build the dataset. The constructor globs every corpus path on disk and shuffles the file lists.
dataset = Dataset()

In [None]:
# Scratch check: a raw (un-batched) iterator over the dataset. This cell was not executed in the saved run.
i = iter(dataset)

In [None]:
# The first call fills a buffer of at least 200 frames, then yields one (float32 tensor, label index) pair.
next(i)

In [17]:
def batch(batches):
    """Collate ``(waveform, label_index)`` samples into a ``Batch``.

    Waveforms are right-padded with zeros to the longest one in the batch.
    ``features_length`` keeps each waveform's original length (int32), and
    ``targets`` holds the label indices (int64).
    """
    waveforms = [item[0] for item in batches]
    return Batch(
        features=torch.nn.utils.rnn.pad_sequence(waveforms, batch_first=True),
        features_length=torch.tensor([len(w) for w in waveforms], dtype=torch.int32),
        targets=torch.tensor([item[1] for item in batches], dtype=torch.int64),
    )

In [18]:
batch_size = 10
# Batches come straight from Dataset.__iter__ and are collated by `batch`.
loader = torch.utils.data.DataLoader(
    dataset=dataset,
    collate_fn=batch,
    batch_size=batch_size,
)

In [19]:
# Iterator over collated Batch objects.
ii = iter(loader)

In [21]:
# Fetch one Batch of `batch_size` frames.
o = next(ii)

In [24]:
o.features.size()  # (batch_size, padded frame length)

torch.Size([10, 4000])

In [34]:
import IPython.display as ipd

# Play one example from the batch. Slice it to its true length so that any
# zero right-padding added by `batch` is not played. Use the dataset's
# sample rate instead of a hard-coded 16000.
example_idx = 5
example_len = int(o.features_length[example_idx])
ipd.Audio(o.features[example_idx, :example_len].numpy(), rate=Dataset.sr)

In [30]:
# Display the whole Batch (features, features_length, targets).
o

Batch(features=tensor([[-0.2235, -0.1727, -0.1131,  ...,  0.1423,  0.1230,  0.1278],
        [-0.0396, -0.0998, -0.1763,  ...,  0.0115,  0.0500,  0.0866],
        [-0.0460, -0.0502, -0.0257,  ...,  0.3644,  0.3021,  0.2202],
        ...,
        [-0.3947, -0.4092, -0.4155,  ...,  0.1173,  0.1215,  0.1398],
        [-0.2395, -0.2020, -0.1823,  ...,  0.2857,  0.1889,  0.2729],
        [ 0.2923,  0.3157,  0.3311,  ..., -0.0214, -0.0319, -0.0342]]), features_length=tensor([4000, 4000, 4000, 4000, 4000, 4000, 4000, 4000, 4000, 4000],
       dtype=torch.int32), targets=tensor([0, 6, 1, 1, 1, 4, 1, 3, 4, 1]))