In [1]:
import os
# Export the service-account key location before any Google Cloud client is created.
# NOTE(review): relative path — it resolves against the notebook's working directory.
os.environ['GOOGLE_APPLICATION_CREDENTIALS'] = 'mesolitica-tpu.json'

In [5]:
from glob import glob
import tensorflow as tf
from tqdm import tqdm
import malaya_speech
from malaya_speech.utils import subword
import numpy as np
import mp
from google.cloud import storage
from unidecode import unidecode

In [3]:
# Subword vocabulary; passed to subword.encode() when transcripts are turned into target ids.
subwords = subword.load('transducer-singlish.subword')

In [6]:
# Every transcript file; the file stem encodes speaker and session (see slicing below).
wave_texts = glob('WAVE-text/*.TXT')

# Accumulates (wav_path, transcript) pairs for utterances whose audio file exists.
singlishs = []
for f in tqdm(wave_texts):
    # Stem is sliced positionally: last char -> session/channel, chars [1:-1] -> speaker id
    # (the first character is dropped).
    # NOTE(review): assumes '/' separators and exactly one directory level in the glob result.
    speaker = f.split('/')[1].replace('.TXT', '')
    channel = speaker[-1]
    speaker = speaker[1:-1]
    
    with open(f) as fopen:
        # Drop empty lines, then group the remaining lines two by two.
        texts = list(filter(None, fopen.read().split('\n')))
        texts = [texts[i: i + 2] for i in range(0, len(texts), 2)]
    
    for text in texts:
        # First line of the pair: tab-separated, utterance id in column 0.
        splitted = text[0].split('\t')
        wav = unidecode(splitted[0])
        # Second line of the pair: the field after the first tab is kept as the transcript.
        # NOTE(review): raises IndexError if a file has an odd number of non-empty lines
        # or if the second line contains no tab.
        t = text[1].split('\t')[1]
        path = f'WAVE/SPEAKER{speaker}/SESSION{channel}/{wav}.WAV'
        
        if os.path.exists(path) and len(t):
            singlishs.append((path, t))
        else:
            # Missing audio or empty transcript: report the pair instead of keeping it.
            print(splitted, path)

100%|██████████| 2034/2034 [00:47<00:00, 43.10it/s]


In [7]:
# Number of (path, transcript) pairs collected before text cleaning.
len(singlishs)

756342

In [8]:
import unicodedata
import re
import itertools

vocabs = [" ", "a", "e", "n", "i", "t", "o", "u", "s", "k", "r", "l", "h", "d", "m", "g", "y", "b", "p", "w", "c", "f", "j", "v", "z", "0", "1", "x", "2", "q", "5", "3", "4", "6", "9", "8", "7"]

def preprocessing_text(string):
    """Normalise a transcript to the character set listed in `vocabs`.

    Steps, in order:
      1. lowercase and NFC-normalise;
      2. delete apostrophes (so "don't" becomes "dont");
      3. turn every character outside `vocabs` into a space;
      4. squeeze runs of spaces and trim both ends;
      5. cap every run of an identical character at two occurrences.

    Note that step 5 also applies to digits, e.g. "1000" becomes "100".

    Returns the cleaned string (possibly empty).
    """
    lowered = unicodedata.normalize('NFC', string.lower())
    without_apostrophes = lowered.replace('\'', '')

    mapped = []
    for ch in without_apostrophes:
        mapped.append(ch if ch in vocabs else ' ')

    squeezed = re.sub(r'[ ]+', ' ', ''.join(mapped)).strip()

    # groupby yields one group per run of identical characters; keep at most 2 of each run.
    capped_runs = []
    for _, run in itertools.groupby(squeezed):
        capped_runs.append(''.join(run)[:2])
    return ''.join(capped_runs)

In [9]:
def get_after_mandarin(word):
    """Strip an opening `<mandarin>` annotation from a single token.

    For a token containing '<mandarin>', returns the text that follows the
    first ':' after the tag, cut at any '</' closing marker. Tokens without
    the tag are returned unchanged.

    Raises IndexError if the tagged token contains no ':'.
    """
    if '<mandarin>' in word:
        w = word.split('>')[1].split(':')[1]
        return w.split('</')[0]
    else:
        return word


def get_before_mandarin(word):
    """Strip a closing `</mandarin>` tag from a single token.

    Returns everything before the first '</' when the token contains
    '</mandarin>' (an empty string if the token is just the closing tag);
    other tokens are returned unchanged.
    """
    if '</mandarin>' in word:
        return word.split('</')[0]
    else:
        return word


def replace_paralinguistic(string, replaces = ['(ppb)', '(ppc)', '(ppl)', '(ppo)', '<UNK>', '<MANDARIN>']):
    """Remove paralinguistic markers and annotation tags from a transcript.

    Args:
        string: raw transcript line.
        replaces: literal markers that are blanked out before tokenising.
            The default list is only read, never mutated.

    Returns:
        The remaining tokens joined by single spaces. A token is dropped when
        it is empty, starts with one of '<', '[', '(' or ends with one of
        '>', ']', ')'.
    """
    for r in replaces:
        string = string.replace(r, ' ')
    words = [get_before_mandarin(get_after_mandarin(w)) for w in string.split()]
    # A stand-alone closing tag (e.g. "... hao </mandarin>") collapses to '' above.
    # Indexing w[0] on it used to raise IndexError, which made the caller discard
    # the whole utterance; empty tokens are now simply skipped.
    words = [w for w in words if w and w[0] not in '<[(' and w[-1] not in '>])']
    return ' '.join(words)

In [10]:
# Peek at one raw (path, transcript) pair.
singlishs[0]

('WAVE/SPEAKER0882/SESSION1/008821401.WAV',
 'a smile can often lift up a weary spirit')

In [11]:
def loop(files):
    """Clean the transcripts of one chunk of (path, text) pairs.

    Args:
        files: a `(items, index)` tuple; `items` is a list of
            (wav_path, raw_transcript) pairs. `index` is accepted but unused.

    Returns:
        List of (wav_path, cleaned_transcript) for every item whose cleaned
        transcript is non-empty. Items are skipped when the raw transcript is
        shorter than 2 characters or is entirely wrapped in '<...>'.

    Items whose cleaning raises are skipped as before, but the failure is now
    printed instead of being silently discarded, so dropped data is visible.
    """
    items, _index = files
    results = []
    for i in tqdm(items):
        try:
            text = i[1]
            if len(text) < 2:
                continue
            if text[0] == '<' and text[-1] == '>':
                continue
            text = replace_paralinguistic(text)
            text = preprocessing_text(text)
            if len(text):
                results.append((i[0], text))
        except Exception as e:
            # Best-effort: keep going, but leave a trace of what was dropped and why.
            print(f'skipped {i!r}: {e!r}')
    return results

In [12]:
# Smoke-test the cleaning loop on the first 10 pairs before running it in parallel.
loop((singlishs[:10], 0))

100%|██████████| 10/10 [00:00<00:00, 4674.88it/s]


[('WAVE/SPEAKER0882/SESSION1/008821401.WAV',
  'a smile can often lift up a weary spirit'),
 ('WAVE/SPEAKER0882/SESSION1/008821402.WAV',
  'i was so tired from work i could not even bother to brush my teeth'),
 ('WAVE/SPEAKER0882/SESSION1/008821403.WAV',
  'a comma can change the meaning of a sentence entirely'),
 ('WAVE/SPEAKER0882/SESSION1/008821404.WAV',
  'before the internet we wrote letters to our pen pals and read magazines'),
 ('WAVE/SPEAKER0882/SESSION1/008821405.WAV',
  'it is easy to book flights and hotels on the computer'),
 ('WAVE/SPEAKER0882/SESSION1/008821406.WAV',
  'heavy rains caused a flood in the village'),
 ('WAVE/SPEAKER0882/SESSION1/008821407.WAV',
  'i get free snacks whenever i go to the supermarket'),
 ('WAVE/SPEAKER0882/SESSION1/008821408.WAV',
  'it is not safe to freeze something again after it has thawed'),
 ('WAVE/SPEAKER0882/SESSION1/008821409.WAV', 'we visited the persian gulf'),
 ('WAVE/SPEAKER0882/SESSION1/008821410.WAV',
  'the crowd guffawed at the

In [13]:
# Run the cleaning loop over the whole corpus with 12 workers.
# NOTE(review): this rebinds `singlishs` to the cleaned pairs, so re-running the cell
# feeds already-cleaned text back through the cleaner; the raw list is no longer available.
# Presumably mp.multiprocessing hands each worker a (chunk, index) tuple — TODO confirm.
singlishs = mp.multiprocessing(singlishs, loop, cores = 12)

100%|██████████| 63028/63028 [01:02<00:00, 1004.43it/s]
100%|██████████| 6/6 [00:00<00:00, 39.48it/s]1.78it/s] 
100%|██████████| 63028/63028 [01:04<00:00, 978.65it/s]]
100%|██████████| 63028/63028 [01:04<00:00, 972.88it/s] 
100%|██████████| 63028/63028 [01:06<00:00, 944.33it/s]]
100%|██████████| 63028/63028 [01:04<00:00, 982.84it/s] 
100%|██████████| 63028/63028 [01:05<00:00, 967.33it/s] 
100%|██████████| 63028/63028 [01:06<00:00, 942.44it/s] 
100%|██████████| 63028/63028 [01:08<00:00, 913.49it/s] 
100%|██████████| 63028/63028 [01:08<00:00, 918.91it/s] 
100%|██████████| 63028/63028 [01:10<00:00, 895.03it/s] 
100%|██████████| 63028/63028 [01:08<00:00, 920.04it/s] 
100%|██████████| 63028/63028 [01:07<00:00, 931.54it/s] 


In [14]:
# Number of pairs left after cleaning (compare with the raw count printed earlier).
len(singlishs)

755913

In [15]:
import six

def to_example(dictionary):
    """Convert a {name: list of int/float/str/bytes} mapping into a tf.train.Example.

    The feature kind is chosen from the type of the first element of each list.

    Raises:
        ValueError: if a field is empty, or if its first element is not an
            int, float, str or bytes.
    """
    features = {}
    for key, value in six.iteritems(dictionary):
        if not value:
            raise ValueError('Empty generated field: %s' % str((key, value)))
        # On Python 3 a map object is lazy and cannot be indexed, so
        # materialise it before looking at its first element.
        if six.PY3 and isinstance(value, map):
            value = list(value)

        first = value[0]
        is_integer = isinstance(first, six.integer_types) or np.issubdtype(
            type(first), np.integer
        )
        if is_integer:
            feature = tf.train.Feature(
                int64_list=tf.train.Int64List(value=value)
            )
        elif isinstance(first, float):
            feature = tf.train.Feature(
                float_list=tf.train.FloatList(value=value)
            )
        elif isinstance(first, six.string_types):
            if not six.PY2:
                # Text must be encoded before it can go into a BytesList on Python 3.
                value = [bytes(x, 'utf-8') for x in value]
            feature = tf.train.Feature(
                bytes_list=tf.train.BytesList(value=value)
            )
        elif isinstance(first, bytes):
            feature = tf.train.Feature(
                bytes_list=tf.train.BytesList(value=value)
            )
        else:
            raise ValueError(
                'Value for %s is not a recognized type; v: %s type: %s'
                % (key, str(first), str(type(first)))
            )
        features[key] = feature
    return tf.train.Example(features=tf.train.Features(feature=features))

In [16]:
# Divisor that turns a waveform's sample count into a duration (len(y) / sr) for the length filter.
sr = 16000
# Upper bound on that duration; longer clips are skipped by the conversion loop.
maxlen = 18
# Transcripts with fewer characters than this are skipped by the conversion loop.
minlen_text = 1
# Counter embedded in every shard file name; the batch driver increments it after each batch.
global_count = 0

In [17]:
def loop(files):
    """Write one chunk of utterances to a TFRecord shard and upload it to GCS.

    Args:
        files: a `(items, index)` tuple; `items` is a list of
            (wav_path, cleaned_transcript) pairs and `index` identifies the
            chunk in the shard name.

    The shard is written locally as '{index}-{global_count}.tfrecord',
    uploaded to bucket 'mesolitica-tpu-general' under 'imda/part1/', and the
    local copy is then deleted. Utterances with a transcript shorter than
    `minlen_text` or audio longer than `maxlen` (len(y) / sr) are skipped.
    An utterance that fails to load or encode is reported and skipped.

    Returns None.
    """
    client = storage.Client()
    bucket = client.bucket('mesolitica-tpu-general')
    files, index = files
    output_file = f'{index}-{global_count}.tfrecord'
    # Context manager guarantees the writer is flushed and closed even if the
    # loop is interrupted, so a partially written shard is never left open.
    with tf.io.TFRecordWriter(output_file) as writer:
        for s in tqdm(files):
            try:
                if len(s[1]) < minlen_text:
                    continue
                # NOTE(review): the sample rate returned by load() is ignored and the
                # global `sr` is used for the duration check — confirm they agree.
                y, _ = malaya_speech.load(s[0])
                if (len(y) / sr) > maxlen:
                    continue
                t = subword.encode(subwords, s[1], add_blank=False)
                example = to_example({'waveforms': y.tolist(),
                                      'targets': t,
                                      'targets_length': [len(t)]})
                writer.write(example.SerializeToString())
            except Exception as e:
                # Best-effort: one bad utterance must not abort the whole shard.
                print(f'skipped {s!r}: {e!r}')
    blob = bucket.blob(f'imda/part1/{output_file}')
    blob.upload_from_filename(output_file)
    # Delete the local shard without spawning a shell.
    os.remove(output_file)

In [18]:
# Smoke-test the conversion loop on 10 utterances.
# NOTE(review): this is not a dry run — with index 0 and global_count 0 it uploads
# 'imda/part1/0-0.tfrecord' to the bucket. Presumably the first full batch later writes
# the same name and replaces it — TODO confirm the index numbering used by mp.
loop((singlishs[:10], 0))

100%|██████████| 10/10 [00:00<00:00, 22.34it/s]


In [19]:
# Convert and upload the corpus in slices of `batch_size` utterances, each slice spread over 6 workers.
batch_size = 25000
for i in range(0, len(singlishs), batch_size):
    batch = singlishs[i: i + batch_size]
    # returned = False: presumably tells the helper not to collect worker results — TODO confirm.
    mp.multiprocessing(batch, loop, cores = 6, returned = False)
    # Part of the shard name ('{index}-{global_count}.tfrecord'), so bump it per slice.
    # NOTE(review): mutates a notebook global; re-running this cell continues from the last value.
    global_count += 1

100%|██████████| 4166/4166 [06:06<00:00, 11.36it/s]
100%|██████████| 4166/4166 [06:16<00:00, 11.07it/s]
100%|██████████| 4166/4166 [06:17<00:00, 11.04it/s]
100%|██████████| 4166/4166 [06:17<00:00, 11.03it/s]
100%|██████████| 4166/4166 [06:20<00:00, 10.94it/s]
100%|██████████| 4166/4166 [06:22<00:00, 10.90it/s]
100%|██████████| 4/4 [00:00<00:00, 62.77it/s]
100%|██████████| 4166/4166 [04:37<00:00, 15.02it/s]
100%|██████████| 4166/4166 [04:43<00:00, 14.67it/s]
100%|██████████| 4166/4166 [04:44<00:00, 14.63it/s]
100%|██████████| 4166/4166 [04:44<00:00, 14.63it/s]
100%|██████████| 4166/4166 [04:49<00:00, 14.38it/s]
100%|██████████| 4166/4166 [04:50<00:00, 14.35it/s]
100%|██████████| 4/4 [00:00<00:00, 43.29it/s]
100%|██████████| 4166/4166 [05:35<00:00, 12.42it/s]
100%|██████████| 4166/4166 [05:38<00:00, 12.30it/s]
100%|██████████| 4166/4166 [05:40<00:00, 12.22it/s]
100%|██████████| 4166/4166 [05:41<00:00, 12.21it/s]
100%|██████████| 4166/4166 [05:47<00:00, 11.99it/s]
100%|██████████| 4166/41

In [20]:
from malaya_speech.utils import tf_featurization

# Build the speech featurizer from malaya_speech's transducer featurizer config;
# `featurizer` is used by preprocess_inputs() to turn waveforms into model inputs.
config = malaya_speech.config.transducer_featurizer_config
featurizer = tf_featurization.STTFeaturizer(**config)

In [21]:
n_mels = 80

def preprocess_inputs(example):
    """Attach featurizer output to a parsed example under the 'inputs' key.

    The result of `featurizer.vectorize(example['waveforms'])` is reshaped to
    (-1, n_mels). The dict is modified in place and also returned.
    """
    vectorized = featurizer.vectorize(example['waveforms'])
    example['inputs'] = tf.reshape(vectorized, (-1, n_mels))
    return example

def parse(serialized_example):
    """Decode one serialized tf.Example and add the featurized 'inputs'.

    All three stored fields are read as variable-length features and reduced
    to their `.values`. The result holds 'waveforms', 'targets',
    'targets_length' and 'inputs'; any other key is removed.
    """
    wanted = ['waveforms', 'inputs', 'targets', 'targets_length']
    data_fields = {
        'waveforms': tf.compat.v1.VarLenFeature(tf.float32),
        'targets': tf.compat.v1.VarLenFeature(tf.int64),
        'targets_length': tf.compat.v1.VarLenFeature(tf.int64),
    }
    features = tf.compat.v1.parse_single_example(
        serialized_example, features = data_fields
    )
    # Keep only the dense values of each sparse feature.
    for name in features.keys():
        features[name] = features[name].values

    features = preprocess_inputs(features)

    # Snapshot the keys first: the dict is modified while filtering.
    for name in list(features.keys()):
        if name not in wanted:
            features.pop(name, None)

    return features

def get_dataset(files, batch_size = 2, shuffle_size = 32, thread_count = 24):
    """Return a zero-argument factory that builds the TFRecord input pipeline.

    The pipeline reads `files`, shuffles serialized records with a buffer of
    `shuffle_size`, parses them with `parse` using `thread_count` parallel
    calls, and repeats indefinitely. `batch_size` is accepted but not used.
    """
    def get():
        return (
            tf.data.TFRecordDataset(files)
            .shuffle(shuffle_size)
            .map(parse, num_parallel_calls = thread_count)
            .repeat()
        )

    return get

In [22]:
# Read back every shard uploaded under imda/part1/ to verify the written records.
files = tf.io.gfile.glob('gs://mesolitica-tpu-general/imda/part1/*.tfrecord')
d = get_dataset(files)()
# NOTE(review): `d` is rebound from the dataset to its numpy iterator.
d = d.as_numpy_iterator()

In [23]:
# Pull one parsed example to inspect its fields.
next(d)

{'targets': array([  7, 568, 151,  13, 105, 299, 795,  20,   5, 168, 444,  19,   4,
        236,   2,   7, 403, 133, 278,  56, 177, 389, 884]),
 'targets_length': array([23]),
 'waveforms': array([-0.00017538, -0.00026307, -0.00026307, ..., -0.00017538,
        -0.00017538, -0.00026307], dtype=float32),
 'inputs': array([[-2.2310197, -2.2621276, -2.3545487, ..., -1.1028278, -1.209693 ,
         -1.302464 ],
        [-2.756736 , -2.1352987, -1.8970875, ..., -1.2289646, -1.2374092,
         -1.4299338],
        [-2.2006068, -2.2834134, -2.8744638, ..., -1.0061882, -1.1931208,
         -1.4570584],
        ...,
        [-1.6727058, -1.6292751, -1.5672526, ..., -1.251422 , -1.2502371,
         -1.2261595],
        [-2.3043866, -2.2051554, -2.0876248, ..., -1.1450766, -1.1473254,
         -1.1928668],
        [-2.6993673, -2.6790707, -2.650311 , ..., -1.1693419, -1.1078019,
         -1.2259248]], dtype=float32)}