In [3]:
import numpy as np
import librosa
from scipy.io import wavfile
from glob import glob
import re
import warnings
from numpy.random import random, uniform, randint, choice
import os
import shutil
from multiprocessing import Pool

In [4]:
# The ten target command words, followed by the two catch-all classes:
# 'unknown' (any other word) and 'silence' (no spoken word).
LABELS = 'yes no up down left right on off stop go'.split() + [
    'unknown', 'silence'
]

In [27]:
class DataGenerator(object):
    """Generates augmented log-mel-spectrogram ("msg") samples from .wav files.

    ``input_dir`` is expected to contain one sub-directory per word; the
    sub-directory name is used as the label, and names not in ``labels`` are
    mapped to ``'unknown'``. ``'silence'`` samples are random segments of the
    recordings in ``silence_dir``. Background recordings are randomly mixed
    into samples as augmentation.
    """

    def __init__(self,
                 input_dir,
                 labels=None,
                 bg_noise_dir=None,
                 silence_dir=None,
                 extra_silence_dir=None):
        """Configure paths, spectrogram geometry and augmentation settings.

        Args:
            input_dir: root directory containing per-word sub-directories of
                .wav files.
            labels: sequence of class labels; ``None`` (default) means the
                module-level ``LABELS``. The sequence is copied, so mutating
                ``self.labels`` never alters the shared default.
            bg_noise_dir: directory with the noise recordings mixed into
                samples; defaults to ``input_dir + '/_background_noise_'``.
            silence_dir: directory whose .wav files are sampled for the
                'silence' label; same default as ``bg_noise_dir``.
            extra_silence_dir: optional directory of additional recordings,
                used both as silence samples and as mixing noise.
        """
        self.debug = False  # print chosen files / transforms
        self._input_files = None  # built lazily by the input_files property
        self._cached_waves = {}  # path -> (sample_rate, float32 wave)
        self._label_onehots = None  # built lazily by _init_onehots
        # validation file -> label; these files are never drawn for training
        self.val_files = {}

        # msg normalization params
        # inspired by https://github.com/fchollet/keras/blob/master/keras/preprocessing/image.py#L529
        self.samplewise_norm = True
        # dataset-wide statistics of raw msgs (see compute_msg_norm_params)
        self.msg_mean = 116.536
        self.msg_std = 21.5913

        self.input_dir = input_dir
        self.sample_rate = 16000
        self.duration = 1.  # sec
        self.n_fft = 512
        self.n_mels = 64
        self.msg_w = 64  # spectrogram width in frames

        # valid labels (copied to avoid aliasing the caller's / default list)
        self.labels = list(LABELS if labels is None else labels)

        # dirs for bg noise/silence
        default_noise_dir = self.input_dir + '/_background_noise_'
        self.bg_noise_dir = default_noise_dir if bg_noise_dir is None else bg_noise_dir
        self.silence_dir = default_noise_dir if silence_dir is None else silence_dir

        # silence files
        self.silence_files = glob(self.silence_dir + '/*.wav')

        # extra silence files
        if extra_silence_dir:
            self.extra_silence_files = glob(extra_silence_dir + '/*.wav')
            self.silence_files += self.extra_silence_files
        else:
            self.extra_silence_files = None

        # random transforms: each is applied with `probability`, its
        # parameter drawn uniformly from `range`
        self.transforms = {
            'pitch': {
                'probability': 0.5,
                'range': [-10., 10.]  # semitones
            },
            'speed': {
                'probability': 0.33,
                'range': [0.5, 2.]
            },
            'volume': {
                'probability': 0.25,
                'range': [0.5, 1.25]
            }
        }

        # silence-specific transforms (replace the general volume transform)
        self.silence_transforms = {
            'volume': {
                'probability': 1.,
                'range': [0.01, 1.]
            }
        }

        self._init_mixing_params()

    def _init_mixing_params(self):
        """Build ``self.mix_with``: noise file -> {'volume', 'probability'}.

        During augmentation each file is mixed in independently with its
        ``probability`` at a volume drawn uniformly from its ``volume`` range.
        """
        # (file name inside bg_noise_dir, volume range); each mixed in 20% of
        # the time. List order fixes the order in which random numbers are drawn.
        bg_noise = [
            ('doing_the_dishes.wav', [0.05, 0.75]),
            ('exercise_bike.wav', [0.05, 0.75]),
            ('white_noise.wav', [0.001, 0.06]),
            ('dude_miaowing.wav', [0.05, 0.75]),
            ('pink_noise.wav', [0.001, 0.06]),
            ('running_tap.wav', [0.25, 0.75]),
        ]
        self.mix_with = {
            self.bg_noise_dir + '/' + name: {
                'volume': volume,
                'probability': 0.2
            }
            for name, volume in bg_noise
        }

        # extra silence files: on average one of them is mixed into a sample
        if self.extra_silence_files:
            probability = 1. / len(self.extra_silence_files)
            for f in self.extra_silence_files:
                self.mix_with[f] = {
                    'volume': [0.25, 0.75],
                    'probability': probability
                }

    def msg(self, wave):
        """Compute the log-scaled mel spectrogram ("msg") of ``wave``.

        The wave is padded at the end (with its median) or trimmed so the
        result has exactly ``self.msg_w`` frames.

        Returns:
            float32 array of shape (n_mels, msg_w).
        """
        # hop length such that msg_w frames cover ~duration seconds
        hop_length = int(1 + (self.duration * self.sample_rate) //
                         (self.msg_w - 1))
        desired_wave_len = int(hop_length * (self.msg_w - 1))

        # pad wave if necessary to get the desired msg width
        if desired_wave_len > len(wave):
            wave = np.pad(wave, (0, desired_wave_len - len(wave)), 'median')

        # trim wave if it's too long
        elif len(wave) > desired_wave_len:
            wave = wave[:desired_wave_len]

        msg = librosa.feature.melspectrogram(
            y=wave,
            sr=self.sample_rate,
            hop_length=hop_length,
            n_fft=self.n_fft,
            n_mels=self.n_mels)

        # melspectrogram's default output is already a power spectrogram, so
        # this squares it again; kept as-is because msg_mean/msg_std above
        # were presumably computed with this exact pipeline.
        power = msg**2
        if hasattr(librosa, 'power_to_db'):
            # same defaults (amin=1e-10, top_db=80) as the older logamplitude,
            # which newer librosa versions no longer provide
            msg = librosa.power_to_db(power, ref=1.)
        else:
            msg = librosa.logamplitude(power, ref_power=1.)
        assert msg.shape[1] == self.msg_w

        return msg.astype(np.float32)

    @property
    def input_files(self):
        """Dict label -> list of .wav paths relative to ``input_dir``.

        The label is the first directory component of the relative path
        (``'unknown'`` if it's not in ``self.labels``; ``None`` for files
        directly in ``input_dir``). Files under a ``_background_noise_``
        directory are excluded. Computed once and cached.
        """
        if self._input_files is None:
            known_labels = set(self.labels)
            grouped = {}

            pattern = os.path.join(self.input_dir, '**', '*.wav')
            for path in glob(pattern, recursive=True):
                rel_path = os.path.relpath(path, self.input_dir)
                parts = rel_path.split(os.sep)

                # background recordings are only used for silence / mixing
                if '_background_noise_' in parts[:-1]:
                    continue

                if len(parts) < 2:
                    label = None
                elif parts[0] in known_labels:
                    label = parts[0]
                else:
                    label = 'unknown'

                grouped.setdefault(label, []).append(rel_path)

            self._input_files = grouped

        return self._input_files

    @staticmethod
    def _read_wav(path):
        """Read ``path``; return (sample_rate, samples as float32 array).

        Warnings emitted while reading are suppressed.
        """
        with warnings.catch_warnings():
            warnings.simplefilter('ignore')
            sr, wave = wavfile.read(path)
        return sr, wave.astype(np.float32)

    def _normalize_wave_len(self, wave, min_samples=None, max_samples=None):
        """Center-pad (with the median value) and/or center-trim ``wave``.

        Args:
            wave: 1-D sequence of samples.
            min_samples: pad to at least this many samples (if given).
            max_samples: trim to at most this many samples (if given).
        """
        if min_samples is not None and len(wave) < min_samples:
            len_to_add = min_samples - len(wave)
            # pad both sides by ceil(len_to_add / 2), then cut the extra sample
            wave = np.pad(wave, (len_to_add + 1) // 2, 'median')[:min_samples]

        if max_samples is not None and len(wave) > max_samples:
            len_to_cut = len(wave) - max_samples
            wave = wave[len_to_cut // 2:max_samples + len_to_cut // 2]

        return wave

    def _load_random_segment(self, file):
        """Return a random ``duration``-second segment of ``file``.

        The decoded wave (padded up to ``duration`` if shorter) is cached per
        path, so each file is read from disk only once per instance.
        """
        if file not in self._cached_waves:
            sr, wave = self._read_wav(file)

            # pad if too short
            wave = self._normalize_wave_len(
                wave, min_samples=int(self.duration * self.sample_rate))

            self._cached_waves[file] = (sr, wave)

        sr, wave = self._cached_waves[file]
        assert sr == self.sample_rate
        desired_len = int(sr * self.duration)

        # randint's upper bound is exclusive: +1 so the last full segment
        # (start == len(wave) - desired_len) can be drawn as well
        start = randint(0, max(1, len(wave) - desired_len + 1))
        return wave[start:start + desired_len]

    def generate_silence_audio(self):
        """Return a random, augmented segment of a random silence file.

        Raises:
            ValueError: if no silence files were found.
        """
        if not self.silence_files:
            raise ValueError('no silence .wav files found in %s' %
                             self.silence_dir)
        silence_file = choice(self.silence_files)
        if self.debug: print('silence file', silence_file)
        wave = self._load_random_segment(silence_file)
        return self.apply_transforms(wave, silence_transforms=True)

    def generate_audio(self, label=None, file=None, transform=True):
        """Return ``(wave, label)`` for one (optionally augmented) sample.

        Args:
            label: label to generate; random from ``self.labels`` if None.
            file: path relative to ``input_dir``; if None, a random file for
                ``label`` that is not in ``self.val_files`` is chosen.
                Ignored for 'silence'.
            transform: apply random augmentation to non-silence samples
                (silence samples are always augmented).
        """
        if label is None:
            label = choice(list(self.labels))

        if label == 'silence':
            return self.generate_silence_audio(), label

        if file is None:
            # rejection-sample until a non-validation file is drawn
            while file is None or file in self.val_files:
                file = choice(self.input_files[label])

        if self.debug: print('file', file)

        sr, wave = self._read_wav(os.path.join(self.input_dir, file))
        assert sr == self.sample_rate

        # pad too short ones / trim too long ones
        desired_len = int(self.duration * self.sample_rate)
        wave = self._normalize_wave_len(
            wave, min_samples=desired_len, max_samples=desired_len)
        assert len(wave) == desired_len

        if transform:
            wave = self.apply_transforms(wave)

        return wave, label

    def apply_transforms(self, wave, silence_transforms=False):
        """Apply random augmentation to ``wave`` and return the result.

        Order: volume, background-noise mixing, pitch shift, speed change.
        Silence samples use ``self.silence_transforms['volume']``; all other
        samples use ``self.transforms['volume']``.
        """
        volume = (self.silence_transforms
                  if silence_transforms else self.transforms)['volume']
        if random() < volume['probability']:
            wave = self.transform_volume(wave, uniform(*volume['range']))

        # mix
        for file, options in self.mix_with.items():
            if random() < options['probability']:
                wave2 = self._load_random_segment(file)
                volume2 = uniform(*options['volume'])

                if self.debug:
                    print('mixing with %s at %.2f vol' % (file, volume2))

                wave = self.mix(wave, 1., wave2, volume2)

        # pitch
        t = self.transforms['pitch']
        if random() < t['probability']:
            wave = self.transform_pitch(wave, uniform(*t['range']))

        # speed
        t = self.transforms['speed']
        if random() < t['probability']:
            wave = self.transform_speed(wave, uniform(*t['range']))

        return wave

    def transform_pitch(self, wave, factor):
        """Shift the pitch of ``wave`` by ``factor`` semitones."""
        if self.debug: print('transforming pitch', factor)
        # keywords: newer librosa makes sr / n_steps keyword-only
        return librosa.effects.pitch_shift(
            wave, sr=self.sample_rate, n_steps=factor)

    def transform_volume(self, wave, factor):
        """Scale the amplitude of ``wave`` by ``factor``."""
        if self.debug: print('transforming volume', factor)
        return np.multiply(wave, factor)

    def transform_speed(self, wave, factor):
        """Time-stretch ``wave`` by ``factor``, keeping its original length."""
        if self.debug: print('transforming speed', factor)

        orig_len = len(wave)
        # keyword: newer librosa makes rate keyword-only
        wave = librosa.effects.time_stretch(wave, rate=factor)

        # pad/trim from the center
        wave = self._normalize_wave_len(
            wave, min_samples=orig_len, max_samples=orig_len)

        assert len(wave) == orig_len
        return wave

    def mix(self, wave1, volume1, wave2, volume2):
        """Return the weighted sum of two equal-length waves.

        Volumes are normalized to sum to 1 when their total is positive;
        otherwise they are used as given (e.g. two zero volumes -> silence).
        """
        total = volume1 + volume2
        if total > 0:
            volume1, volume2 = volume1 / total, volume2 / total
        return wave1 * volume1 + wave2 * volume2

    def normalize_msg(self, msg):
        """Standardize ``msg`` in place and return it.

        First with the dataset-wide ``msg_mean`` / ``msg_std`` (each skipped
        when None), then, if ``samplewise_norm``, to zero mean / unit std per
        sample.
        """
        if self.msg_mean is not None:
            msg -= self.msg_mean

        if self.msg_std is not None:
            msg /= self.msg_std + 1e-7

        if self.samplewise_norm:
            msg -= np.mean(msg)
            msg /= np.std(msg) + 1e-7

        return msg

    def compute_msg_norm_params(self, n_steps=100):
        """Estimate ``msg_mean`` / ``msg_std`` from ``n_steps`` random,
        augmented, unnormalized samples."""
        msgs = np.zeros((n_steps, self.n_mels, self.msg_w), dtype=np.float32)

        for i in range(n_steps):
            msgs[i] = self.msg(self.generate_audio()[0])

        self.msg_mean = np.mean(msgs)
        self.msg_std = np.std(msgs)

    def _init_onehots(self):
        """Precompute a float32 one-hot vector for every label."""
        def _onehot(l, o):
            z = np.zeros(l, dtype=np.float32)
            z[o] = 1.
            return z

        self._label_onehots = {
            label: _onehot(len(self.labels), i)
            for i, label in enumerate(self.labels)
        }

    def label_to_onehot(self, label):
        """Return the one-hot vector for ``label`` (KeyError if unknown)."""
        if self._label_onehots is None:
            self._init_onehots()
        return self._label_onehots[label]

    def onehot_to_label(self, onehot):
        """Return the label with the highest score in ``onehot``."""
        return self.labels[np.argmax(onehot)]

    def generate_val_set(self, n=1000):
        """Build a validation set of ``n`` normalized samples.

        If ``self.val_files`` is empty, ``n`` labels are drawn at random and
        each non-silence draw picks a distinct file; those files are recorded
        in ``self.val_files`` so ``generate_audio`` never samples them for
        training. Validation files are not augmented. Remaining slots are
        filled with freshly generated (augmented) silence samples.

        Returns:
            (X, Y): X float32 of shape (n, n_mels, msg_w, 1), Y float32
            one-hot of shape (n, len(labels)).

        Raises:
            ValueError: if ``self.val_files`` holds more than ``n`` files.
        """
        # pick files for non-silence labels
        if not self.val_files:
            self.val_files = {}

            for _ in range(n):
                label = choice(self.labels)

                if label != 'silence':
                    file = None
                    while file is None or file in self.val_files:
                        file = choice(self.input_files[label])
                    self.val_files[file] = label

        if len(self.val_files) > n:
            raise ValueError('val_files holds %d files, more than n=%d' %
                             (len(self.val_files), n))

        val_X = np.zeros((n, self.n_mels, self.msg_w), dtype=np.float32)
        # float32 to match the training targets
        val_Y = np.zeros((n, len(self.labels)), dtype=np.float32)

        # non-silence samples
        val_items = list(self.val_files.items())
        for i, (file, label) in enumerate(val_items):
            wave, _ = self.generate_audio(
                label=label, file=file, transform=False)
            val_X[i] = self.normalize_msg(self.msg(wave))
            val_Y[i] = self.label_to_onehot(label)

        # silence samples
        for j in range(len(val_items), n):
            wave, label = self.generate_audio('silence', transform=True)
            val_X[j] = self.normalize_msg(self.msg(wave))
            val_Y[j] = self.label_to_onehot(label)

        return np.expand_dims(val_X, 3), val_Y

    def _gen_training_samples(self, n, start_i, tmp_dir, seed=None):
        """Generate ``n`` random training samples and save them to ``tmp_dir``.

        Writes ``X_<start>-<end>.npy`` (float32, (n, n_mels, msg_w, 1)) and
        ``Y_<start>-<end>.npy`` (float32 one-hot, (n, len(labels))).

        Args:
            seed: if given, reseed numpy's global RNG first. With the 'fork'
                start method every Pool worker inherits an identical copy of
                the parent's RNG state, so without distinct seeds parallel
                jobs would produce duplicate samples.
        """
        if seed is not None:
            np.random.seed(seed)

        X = np.zeros((n, self.n_mels, self.msg_w, 1), dtype=np.float32)
        Y = np.zeros((n, len(self.labels)), dtype=np.float32)
        for i in range(n):
            wave, label = self.generate_audio()
            msg = self.normalize_msg(self.msg(wave))
            X[i] = np.expand_dims(msg, 2)
            Y[i] = self.label_to_onehot(label)
        np.save('%s/X_%07d-%07d' % (tmp_dir, start_i, n + start_i), X)
        np.save('%s/Y_%07d-%07d' % (tmp_dir, start_i, n + start_i), Y)

    def generate_train_set(self,
                           n_total=100,
                           n_per_job=10,
                           n_pools=16,
                           X_file='out/train_X.mem',
                           Y_file='out/train_Y.mem',
                           tmp_dir='out/train'):
        """Generate ``n_total`` training samples in parallel into memmap files.

        ``X_file`` receives float32 data of shape (n_total, n_mels, msg_w, 1)
        and ``Y_file`` float32 one-hot targets of shape
        (n_total, len(labels)). Existing files and ``tmp_dir`` are replaced.

        Raises:
            ValueError: if ``n_per_job`` is not positive or does not divide
                ``n_total``.
        """
        if n_per_job <= 0 or n_total % n_per_job != 0:
            raise ValueError(
                'n_per_job must be positive and divide n_total '
                '(got n_total=%d, n_per_job=%d)' % (n_total, n_per_job))

        # make sure output directories exist
        for path in (X_file, Y_file):
            parent = os.path.dirname(path)
            if parent:
                os.makedirs(parent, exist_ok=True)

        # init memory mapped files
        if os.path.isfile(X_file): os.unlink(X_file)
        train_X = np.memmap(
            X_file,
            np.float32,
            'w+',
            shape=(n_total, self.n_mels, self.msg_w, 1))

        if os.path.isfile(Y_file): os.unlink(Y_file)
        train_Y = np.memmap(
            Y_file, np.float32, 'w+', shape=(n_total, len(self.labels)))

        # cleanup tmp dir
        if os.path.isdir(tmp_dir):
            shutil.rmtree(tmp_dir)
        os.makedirs(tmp_dir)

        # launch generation in a pool of workers; per-job seeds come from the
        # parent RNG so jobs differ, yet stay reproducible if numpy was seeded
        n_jobs = n_total // n_per_job
        job_seeds = randint(0, 2**31 - 1, size=n_jobs)
        jobs = [(n_per_job, job * n_per_job, tmp_dir, int(job_seeds[job]))
                for job in range(n_jobs)]

        with Pool(n_pools) as pool:
            pool.starmap(self._gen_training_samples, jobs)

        # glue generated chunks together (distinct names: don't shadow the
        # X_file / Y_file arguments)
        for start in range(0, n_total, n_per_job):
            stop = start + n_per_job
            chunk_X = np.load('%s/X_%07d-%07d.npy' % (tmp_dir, start, stop))
            chunk_Y = np.load('%s/Y_%07d-%07d.npy' % (tmp_dir, start, stop))
            train_X[start:stop] = chunk_X
            train_Y[start:stop] = chunk_Y

        train_X.flush()
        train_Y.flush()

        # cleanup
        shutil.rmtree(tmp_dir)

In [16]:
# dg = DataGenerator('/d2/caches/tf-speech/train/audio')

# dg.generate_train_set(
#     n_total=100,
#     n_per_job=10,
#     n_pools=16,
#     X_file='out/__train_X.npy',
#     Y_file='out/__train_Y.npy',
#     tmp_dir='out/__train')

In [None]:
# dg = DataGenerator('/d2/caches/tf-speech/train/audio')
# print(dg._normalize_wave_len([1,2,3,4], min_samples=8))
# print(dg._normalize_wave_len([1,2,3,4], max_samples=8, min_samples=8))