In [1]:
import os
import numpy as np
import torch
import sys


In [2]:
alphabet = ' абвгдеёжзийклмнопрстуфхцчшщъыьэюяАОУЫЭЯЁЮИЕ,.!?;:()'

i2a = {a:i for i,a in enumerate(alphabet)}
print(i2a)

{' ': 0, 'а': 1, 'б': 2, 'в': 3, 'г': 4, 'д': 5, 'е': 6, 'ё': 7, 'ж': 8, 'з': 9, 'и': 10, 'й': 11, 'к': 12, 'л': 13, 'м': 14, 'н': 15, 'о': 16, 'п': 17, 'р': 18, 'с': 19, 'т': 20, 'у': 21, 'ф': 22, 'х': 23, 'ц': 24, 'ч': 25, 'ш': 26, 'щ': 27, 'ъ': 28, 'ы': 29, 'ь': 30, 'э': 31, 'ю': 32, 'я': 33, 'А': 34, 'О': 35, 'У': 36, 'Ы': 37, 'Э': 38, 'Я': 39, 'Ё': 40, 'Ю': 41, 'И': 42, 'Е': 43, ',': 44, '.': 45, '!': 46, '?': 47, ';': 48, ':': 49, '(': 50, ')': 51}


In [4]:
class BatchSampler():
    def __init__(self, data, batch_size):
        self.data = data
        self.batch_size = batch_size
        self.i = 0
        self.n = len(self.data)
        self.max_x = 500
        self.max_y = 2000

    def reset(self):
        self.i = 0
        np.random.shuffle(self.data)

    def pad(self, a, l, value):
        if len(a.shape) == 1:
            z = value + np.zeros(l)
            m = np.zeros(l)
            z[:len(a)] = a
            m[:len(a)] = 1
        elif len(a.shape) == 2:
            z = np.zeros((l, a.shape[1]))
            m = np.zeros(l)
            z[:len(a)] = a
            z[len(a):] = value
            m[:len(a)] = 1
        return z, m
        
    def next_batch(self):
        if self.i + self.batch_size >= self.n:
            self.reset()

        batch_x = []
        batch_mask_x = []
        batch_y = []
        batch_mask_y = []
        max_x = 0
        max_y = 0
        for j in range(self.batch_size):
            indices = np.array([i2a[l] for l in self.data[self.i][0]])
            spec = self.data[self.i][1]

            text_padded, text_mask = self.pad(indices, self.max_x, 0)
            mel_padded, mel_mask = self.pad(spec, self.max_y, 0)

            batch_x.append(text_padded)
            batch_mask_x.append(text_mask)
            batch_y.append(mel_padded)
            batch_mask_y.append(mel_mask)
            if indices.shape[0] > max_x:
                max_x = indices.shape[0]
            if spec.shape[0] > max_y:
                max_y = spec.shape[0]
            self.i += 1

        batch_x = np.array(batch_x)[:, :max_x]
        batch_mask_x = np.array(batch_mask_x)[:, :max_x]
        batch_y = np.array(batch_y)[:, :max_y]
        batch_mask_y = np.array(batch_mask_y)[:, :max_y]
        
        return batch_x, batch_mask_x, batch_y, batch_mask_y


            
            
        

In [5]:
# gen mock data
def sample_random(min_x=1, max_x=300, min_y=10, max_y=1000):
    n = np.random.randint(min_x, max_x)
    k = np.random.randint(min_y, max_y)
    text = ''.join([np.random.choice(list(alphabet)) for i in range(n)])
    spec = np.array([np.random.random(size=80) for i in range(k)])
    return text, spec

n = 1024
batch_size = 32

data = np.array([sample_random() for i in range(n)])
print(len(data))

sampler = BatchSampler(data, batch_size)


1024


  data = np.array([sample_random() for i in range(n)])


In [7]:
batch = sampler.next_batch()
# print(batch[0])
for b in batch:
    print(b.shape)

(32, 292)
(32, 292)
(32, 969, 80)
(32, 969)


In [8]:
zero_count = 0
all_count = 0
for i in range(sampler.n // batch_size):
    batch = sampler.next_batch()
    mask = batch[1]
    for j in range(batch_size):
        zero_count += np.sum(1 - mask[j])
        all_count += mask[j].shape[0]
print(zero_count / all_count)

0.4937190327445067


In [13]:
class SmartSampler(BatchSampler):
    def __init__(self, data, batch_size):
        super().__init__(data, batch_size)

    def reset(self):
        d = np.array(list(map(lambda x: len(x[0]), self.data)))
        idx = np.argsort(d + np.random.randint(-2, 2, size=len(d)))
        self.data = self.data[idx]

        b = np.arange(len(self.data)).reshape(len(self.data) // self.batch_size, self.batch_size)
        np.random.shuffle(b)
        self.data = self.data[b.reshape(-1)]
        self.i = 0

In [14]:
smart_sampler = SmartSampler(data, batch_size)
smart_sampler.reset()



In [15]:
zero_count = 0
all_count = 0
for i in range(smart_sampler.n // batch_size):
    batch = smart_sampler.next_batch()
    mask = batch[1]
    for j in range(batch_size):
        zero_count += np.sum(1 - mask[j])
        all_count += mask[j].shape[0]
print(zero_count / all_count)

0.03260337056633353
