In [1]:
from glob import glob
import random

In [2]:
import os
# Hide every GPU from frameworks imported later in the notebook; this has to
# happen before TensorFlow initialises CUDA.
os.environ.update(CUDA_VISIBLE_DEVICES='')

In [3]:
import librosa
import numpy as np

def load_wav(vid_path, sr = 16000, mode='eval'):
    """
    Load an audio file with librosa and return it concatenated with a copy.

    Parameters
    ----------
    vid_path : str
        Path passed to ``librosa.load``.
    sr : int
        Sampling rate requested from ``librosa.load``; the returned rate is
        asserted to match it.
    mode : str
        ``'train'``: the signal followed by itself; with probability 0.3
        (drawn from ``np.random``) the whole doubled signal is then reversed.
        Any other value: the signal followed by its time-reversed copy.

    Returns
    -------
    np.ndarray
        Twice as many samples as the loaded signal.
    """
    wav, sr_ret = librosa.load(vid_path, sr=sr)
    assert sr_ret == sr
    if mode != 'train':
        # Deterministic: forward signal, then the same signal played backwards.
        return np.append(wav, wav[::-1])
    doubled = np.append(wav, wav)
    # Random augmentation: occasionally reverse the doubled signal.
    return doubled[::-1] if np.random.random() < 0.3 else doubled


def lin_spectogram_from_wav(wav, hop_length, win_length, n_fft=1024):
    """
    Complex STFT of ``wav``, returned time-major.

    ``librosa.stft`` yields ``(1 + n_fft // 2, n_frames)`` for a 1-D signal;
    the transpose turns that into ``(n_frames, 1 + n_fft // 2)``.
    """
    return librosa.stft(
        wav,
        n_fft=n_fft,
        win_length=win_length,
        hop_length=hop_length,
    ).T


def load_data(wav, win_length=400, sr=16000, hop_length=160, n_fft=512, spec_len=120, mode='train'):
    """
    Turn a waveform into a per-frame standardised magnitude spectrogram.

    Parameters
    ----------
    wav : np.ndarray
        Audio samples.
    win_length, hop_length, n_fft : int
        STFT settings forwarded to ``lin_spectogram_from_wav``.
    sr : int
        Not used by this function; kept for signature compatibility.
    spec_len : int
        In ``'train'`` mode, spectrograms shorter than this many frames are
        right-padded with zeros up to it (longer ones are left untouched).
    mode : str
        ``'train'`` enables the padding above; anything else disables it.

    Returns
    -------
    np.ndarray
        ``(freq, time)`` array; every time frame (column) is shifted to zero
        mean and divided by its standard deviation (+1e-5) over frequency.
    """
    stft_frames = lin_spectogram_from_wav(wav, hop_length, win_length, n_fft)
    magnitude, _ = librosa.magphase(stft_frames)
    spec_mag = magnitude.T  # (freq, time)
    _, n_frames = spec_mag.shape
    if mode == 'train' and n_frames < spec_len:
        spec_mag = np.pad(spec_mag, ((0, 0), (0, spec_len - n_frames)), 'constant')
    mu = np.mean(spec_mag, axis=0, keepdims=True)
    std = np.std(spec_mag, axis=0, keepdims=True)
    return (spec_mag - mu) / (std + 1e-5)

def padding_sequence_nd(
    seq, maxlen = None, padding: str = 'post', pad_val = 0.0, dim: int = 1
):
    """
    Pad N-d arrays along one axis to a common length and stack them.

    Parameters
    ----------
    seq : sequence of array-like
        Arrays of equal rank; only their size along ``dim`` may differ.
    maxlen : int, optional
        Target size along ``dim``. If falsy (``None`` or ``0``), the largest
        size along ``dim`` found in ``seq`` is used.
    padding : str
        ``'post'`` pads after the data, ``'pre'`` pads before it.
    pad_val : scalar
        Constant written into the padded positions.
    dim : int
        Axis to pad (negative values count from the end).

    Returns
    -------
    np.ndarray
        ``np.array`` of the padded elements, i.e. shape ``(len(seq), ...)``
        with every element having size ``maxlen`` along ``dim``. An empty
        ``seq`` yields an empty array.

    Raises
    ------
    ValueError
        If ``padding`` is not ``'post'``/``'pre'``, or an element is already
        longer than ``maxlen`` along ``dim``.
    """
    if padding not in ['post', 'pre']:
        raise ValueError('padding only supported [`post`, `pre`]')

    if len(seq) == 0:
        return np.array([])

    if not maxlen:
        maxlen = max(np.shape(s)[dim] for s in seq)

    # Index into the [before, after] pad-width pair of axis `dim`. Computed
    # once instead of overwriting the `padding` argument inside the loop.
    side = 0 if padding == 'pre' else 1

    padded_seqs = []
    for s in seq:
        length = np.shape(s)[dim]
        shortfall = maxlen - length
        if shortfall < 0:
            raise ValueError(
                'element of length {} along dim {} exceeds maxlen={}'.format(
                    length, dim, maxlen
                )
            )
        npad = [[0, 0] for _ in range(np.ndim(s))]
        npad[dim][side] = shortfall
        padded_seqs.append(
            np.pad(
                s,
                pad_width = npad,
                mode = 'constant',
                constant_values = pad_val,
            )
        )
    return np.array(padded_seqs)

def add_noise(samples, noise, random_sample = True, factor = 0.1):
    """
    Return ``samples`` with a scaled noise clip added (inputs are not modified).

    Parameters
    ----------
    samples : np.ndarray
        Clean signal.
    noise : np.ndarray
        Noise clip. If shorter than ``samples`` it is tiled to cover it;
        otherwise, when ``random_sample`` is truthy, a random window of
        matching length is taken (via ``np.random.randint``), else its start.
    random_sample : bool
        See ``noise``.
    factor : float
        Multiplier applied to the noise before adding it.
    """
    target_len = len(samples)
    if target_len > len(noise):
        repeats = int(np.ceil(target_len / len(noise)))
        noise = np.tile(noise, repeats)
    elif random_sample:
        start = np.random.randint(0, len(noise) - target_len + 1)
        noise = noise[start:]
    return samples + noise[:target_len] * factor

def frames(
    audio,
    frame_duration_ms: int = 30,
    sample_rate: int = 16000,
):
    """
    Split ``audio`` into consecutive, non-overlapping fixed-duration frames.

    Parameters
    ----------
    audio : sliceable sequence (e.g. 1-D np.ndarray of samples)
    frame_duration_ms : int
        Frame length in milliseconds.
    sample_rate : int
        Samples per second of ``audio``.

    Returns
    -------
    list
        Slices of ``audio``, each exactly
        ``int(sample_rate * frame_duration_ms / 1000)`` samples long. A
        trailing remainder shorter than one frame is dropped; a final frame
        that ends exactly at the end of ``audio`` is kept.

    Raises
    ------
    ValueError
        If the computed frame length is not positive (a zero-length frame
        would otherwise never advance through ``audio``).
    """
    n = int(sample_rate * (frame_duration_ms / 1000.0))
    if n <= 0:
        raise ValueError(
            'frame length must be positive, got {} samples '
            '(frame_duration_ms={}, sample_rate={})'.format(n, frame_duration_ms, sample_rate)
        )
    # Last valid start is len(audio) - n, so the final complete frame is
    # included even when len(audio) is an exact multiple of n.
    return [audio[offset : offset + n] for offset in range(0, len(audio) - n + 1, n)]

In [4]:
import json

with open('indices.json') as fopen:
    data = json.load(fopen)

# The index file provides the audio paths and the speaker mapping.
files, speakers = data['files'], data['speakers']

In [5]:
def get_id(file):
    """Return the second '-'-separated field of the path's last '/' component.

    Used as the speaker key (``unique_speakers.index(get_id(file))``).
    Raises IndexError if the basename contains no '-'.
    """
    basename = file.rsplit('/', 1)[-1]
    return basename.split('-')[1]

In [6]:
# Sorted speaker keys; a speaker's position in this list is its class index.
unique_speakers = sorted(speakers)
unique_speakers.index(get_id(files[1]))

5368

In [7]:
import pickle

# Background-noise clips that `generate` mixes in via `add_noise`.
# NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
fopen = open('../noise/noise.pkl', 'rb')
with fopen:
    noises = pickle.load(fopen)

In [20]:
from sklearn.utils import shuffle
import itertools

batch_size = 32
# One shared, endless iterator over `files`: `generate` keeps consuming it,
# and re-running this cell restarts the pass from files[0].
cycle_files = itertools.cycle(files)

def generate(partition = 100, batch_size = batch_size, sample_rate = 16000, max_length = 5):
    """
    Endless generator of single (spectrogram, speaker-index) examples.

    Each round takes ``partition`` paths from the module-level ``cycle_files``
    iterator and, for every waveform returned by ``load_wav``:

    1. if it is longer than ``max_length`` seconds, adds its first
       ``max_length`` seconds as one segment;
    2. 1-3 times, cuts it with ``frames`` into frames of a random duration
       in [500 ms, ``max_length`` s], each frame becoming a segment.

    Every clean segment is then, with probability 1/2, copied 1-5 times with
    a random clip from ``noises`` mixed in at a random factor in [0.1, 0.6].
    Segments are converted with ``load_data`` (ones it fails on are skipped),
    shuffled, and yielded one by one.

    Parameters
    ----------
    partition : int
        Files consumed from ``cycle_files`` per round.
    batch_size : int
        Unused; kept for backward compatibility (batching is done by tf.data).
    sample_rate : int
        Sampling rate used to load audio and to size crops and frames.
    max_length : int
        Maximum segment duration in seconds.

    Yields
    ------
    dict
        ``{'inputs': spectrogram with a trailing channel axis,
        'targets': [speaker index]}``.
    """
    # Position of each speaker in `unique_speakers`. The dict lookup replaces
    # a linear `unique_speakers.index(...)` scan per file.
    speaker_index = {speaker: i for i, speaker in enumerate(unique_speakers)}

    while True:
        batch_files = [next(cycle_files) for _ in range(partition)]
        segments, labels = [], []
        for file in batch_files:
            y = speaker_index[get_id(file)]
            w = load_wav(file, sr = sample_rate)
            if len(w) / sample_rate > max_length:
                segments.append(w[:sample_rate * max_length])
                labels.append(y)
            for _ in range(random.randint(1, 3)):
                f = frames(
                    w,
                    random.randint(500, max_length * 1000),
                    sample_rate = sample_rate,
                )
                segments.extend(f)
                labels.extend([y] * len(f))

        # Only the clean segments collected above are augmented: range() is
        # evaluated once, before any noisy copy is appended.
        for k in range(len(segments)):
            if random.randint(0, 1):
                for _ in range(random.randint(1, 5)):
                    # `factor` must be a keyword: passed positionally the
                    # random level lands in `random_sample` and the mix
                    # silently stays at the 0.1 default.
                    noisy = add_noise(
                        segments[k],
                        random.choice(noises),
                        factor = random.uniform(0.1, 0.6),
                    )
                    segments.append(noisy)
                    labels.append(labels[k])

        features, targets = [], []
        for segment, label in zip(segments, labels):
            try:
                spectrogram = load_data(segment)
            except Exception:
                # Best-effort: skip segments load_data cannot featurise
                # (presumably e.g. too short for the STFT) instead of ending
                # the stream. KeyboardInterrupt still propagates.
                continue
            features.append(spectrogram)
            targets.append(label)

        if not features:
            continue

        features, targets = shuffle(features, targets)
        for spectrogram, label in zip(features, targets):
            # Add a trailing channel axis to match the tf.data output_shapes.
            yield {'inputs': np.expand_dims(spectrogram, -1), 'targets': [label]}

In [21]:
import tensorflow as tf

In [22]:
def reshape(example):
    """Debug helper: print ``example`` and return it unchanged.

    Despite the name, no reshaping happens, and no visible cell uses it.
    NOTE(review): if mapped over a tf.data pipeline, Python ``print`` runs when
    the function is traced rather than once per element; ``tf.print`` would be
    the per-element alternative.
    """
    print(example)
    return example

# Stream single examples from `generate`, then zero-pad the variable time
# axis into batches of `batch_size`. 257 = 512 // 2 + 1 frequency bins, i.e.
# load_data's default n_fft; keep the two in sync.
dataset = tf.data.Dataset.from_generator(
    generate,
    output_types = {'inputs': tf.float32, 'targets': tf.int32},
    output_shapes = {
        'inputs': tf.TensorShape([257, None, 1]),
        'targets': tf.TensorShape([1]),
    },
).padded_batch(
    batch_size,
    padded_shapes = {
        'inputs': tf.TensorShape([257, None, 1]),
        'targets': tf.TensorShape([None]),
    },
    padding_values = {
        'inputs': tf.constant(0, dtype = tf.float32),
        'targets': tf.constant(0, dtype = tf.int32),
    },
)
# Despite its name, this holds the dict of next-element tensors, not an iterator object.
iterator = dataset.make_one_shot_iterator().get_next()

In [23]:
# Show the structure (dtype/shape) of the next-element tensors built above.
iterator

{'inputs': <tf.Tensor 'IteratorGetNext_2:0' shape=(?, 257, ?, 1) dtype=float32>,
 'targets': <tf.Tensor 'IteratorGetNext_2:1' shape=(?, ?) dtype=int32>}

In [12]:
# TF1-style session used below to pull concrete batches from the pipeline.
sess = tf.InteractiveSession()

In [33]:
# Pull one padded batch and check its shapes.
r = sess.run(iterator)
r['inputs'].shape, r['targets'].shape

((32, 257, 501, 1), (32, 1))

In [34]:
# Speaker indices of the sampled batch (one label per example).
r['targets']

array([[4622],
       [ 178],
       [2491],
       [  71],
       [5550],
       [ 178],
       [2302],
       [3193],
       [3193],
       [4712],
       [5507],
       [5386],
       [2491],
       [5487],
       [5606],
       [5419],
       [4661],
       [4114],
       [1728],
       [5536],
       [5507],
       [4565],
       [2639],
       [4622],
       [3750],
       [5982],
       [5789],
       [5386],
       [4545],
       [ 178],
       [3754],
       [5507]], dtype=int32)