In [1]:
import os

# Hide all CUDA devices so this notebook runs TensorFlow on CPU only.
os.environ['CUDA_VISIBLE_DEVICES'] = ''
# Service-account key used for the gs:// reads below. setdefault lets a value
# already exported in the environment take precedence over this
# machine-specific default path, so the notebook need not be edited per host.
os.environ.setdefault(
    'GOOGLE_APPLICATION_CREDENTIALS',
    '/home/husein/t5/prepare/mesolitica-tpu.json',
)

In [2]:
import tensorflow as tf
import random

In [3]:
# Sort so the grouping and hold-out split below see a stable order;
# tf.io.gfile.glob does not document the ordering of its results.
files = sorted(tf.io.gfile.glob('gs://mesolitica-tpu-general/imda/*/*.tfrecord'))
len(files)

978

In [4]:
# Sanity-check the path layout: the first path component after the bucket
# prefix is the dataset part name (the same split is used for grouping below).
files[0].split('gs://mesolitica-tpu-general/imda/')[1].split('/')[0]

'part1'

In [5]:
from collections import defaultdict

# Bucket every shard path under its part folder, i.e. the first path
# component after the bucket prefix (e.g. 'part1', 'part5-debate').
parts = defaultdict(list)
for path in files:
    relative = path.split('gs://mesolitica-tpu-general/imda/')[1]
    parts[relative.split('/')[0]].append(path)

In [6]:
# Part names discovered under the bucket prefix.
parts.keys()

dict_keys(['part1', 'part2', 'part3', 'part4-diff-room', 'part4-same-room', 'part5-debate', 'part6-call-centre-1', 'part6-call-centre-2'])

In [7]:
# Hold out one shard from every part that has at least MIN_SHARDS_FOR_TEST
# shards; smaller parts go entirely to training. The seeded RNG, sorted
# iteration and order-preserving filter make the split reproducible across
# kernel restarts. This matters because the split is written to
# imda-tfrecords.json, and an unseeded re-run would silently move shards
# between train and test.
MIN_SHARDS_FOR_TEST = 100
SPLIT_SEED = 0
split_rng = random.Random(SPLIT_SEED)

test_set = []
train_set = []
for part in sorted(parts):
    shards = sorted(parts[part])
    if len(shards) >= MIN_SHARDS_FOR_TEST:
        held_out = split_rng.choice(shards)
        # A comprehension keeps a stable order; set difference would not
        # (string hashing is randomised per process).
        train_set.extend(path for path in shards if path != held_out)
        test_set.append(held_out)
    else:
        train_set.extend(shards)
len(test_set), len(train_set)

(4, 974)

In [8]:
# Shards held out for evaluation (one per part large enough to qualify).
test_set

['gs://mesolitica-tpu-general/imda/part1/0-29.tfrecord',
 'gs://mesolitica-tpu-general/imda/part2/5-30.tfrecord',
 'gs://mesolitica-tpu-general/imda/part3/4-16.tfrecord',
 'gs://mesolitica-tpu-general/imda/part4-same-room/3-2.tfrecord']

In [9]:
import json

# Persist the train/test shard lists so later runs can reuse the same split.
with open('imda-tfrecords.json', 'w') as fp:
    fp.write(json.dumps({'train': train_set, 'test': test_set}))

In [10]:
import numpy as np
import malaya_speech.train as train
import malaya_speech.config
import malaya_speech.train.model.transducer as transducer
import malaya_speech.train.model.conformer as conformer
import malaya_speech.augmentation.spectrogram as mask_augmentation
import malaya_speech.augmentation.waveform as augmentation
import malaya_speech
import tensorflow as tf






The TensorFlow contrib module will not be included in TensorFlow 2.0.
For more information, please see:
  * https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md
  * https://github.com/tensorflow/addons
  * https://github.com/tensorflow/io (for I/O related ops)
If you depend on functionality not listed there, please file an issue.



In [17]:
featurizer = malaya_speech.tf_featurization.STTFeaturizer(normalize_per_feature=True)
# Feature bins per frame; every reshape and padded shape below is keyed off it.
n_mels = featurizer.num_feature_bins

def mel_augmentation(features, width_freq_mask=12, time_mask_ratio=0.05):
    """Apply time warping, frequency masking and time masking to one example.

    Invoked through tf.compat.v1.numpy_function, so `features` is a numpy
    array, reshaped to (-1, n_mels) before the call.

    Args:
        features: numpy array of shape (-1, n_mels).
        width_freq_mask: width passed to mask_frequency (default 12, as before).
        time_mask_ratio: fraction of the (warped) first dimension used as the
            time-mask width (default 0.05, as before).

    Returns:
        The augmented array as produced by mask_augmentation.mask_time.
    """
    features = mask_augmentation.warp_time_pil(features)
    features = mask_augmentation.mask_frequency(
        features, width_freq_mask=width_freq_mask
    )
    # Time-mask width scales with the length of the (already warped) input.
    # NOTE(review): for inputs shorter than 1 / time_mask_ratio rows this
    # truncates to 0; confirm mask_time handles a zero width.
    features = mask_augmentation.mask_time(
        features, width_time_mask=int(features.shape[0] * time_mask_ratio)
    )
    return features


def preprocess_inputs(example, augment=True):
    """Turn a parsed example's waveform into model-ready features.

    Args:
        example: dict of dense tensors containing 'waveforms', 'targets' and
            'targets_length' (as produced inside `parse`). Mutated in place.
        augment: if True (the default, preserving the original behaviour),
            run `mel_augmentation` on the features. Pass False for evaluation
            data so held-out examples are scored on clean features.

    Returns:
        The same dict with 'inputs' (float32, shape (-1, n_mels)),
        'inputs_length' (int32, shape (1,)) added, and 'targets' /
        'targets_length' cast to int32.
    """
    s = featurizer.vectorize(example['waveforms'])
    s = tf.reshape(s, (-1, n_mels))
    if augment:
        s = tf.compat.v1.numpy_function(mel_augmentation, [s], tf.float32)
    # numpy_function drops the static shape; restore the (-1, n_mels) layout.
    mel_fbanks = tf.reshape(s, (-1, n_mels))
    length = tf.cast(tf.shape(mel_fbanks)[0], tf.int32)
    length = tf.expand_dims(length, 0)
    example['inputs'] = mel_fbanks
    example['inputs_length'] = length
    example['targets'] = tf.cast(example['targets'], tf.int32)
    example['targets_length'] = tf.cast(example['targets_length'], tf.int32)
    return example

def parse(serialized_example, augment=True):
    """Decode one serialized example into the four model-facing tensors.

    Args:
        serialized_example: scalar string tensor holding one serialized
            tf.train.Example with 'waveforms', 'targets', 'targets_length'.
        augment: forwarded to `preprocess_inputs`; defaults to True so
            existing `d.map(parse)` callers behave exactly as before.

    Returns:
        dict with only 'inputs', 'inputs_length', 'targets', 'targets_length'.
    """

    data_fields = {
        'waveforms': tf.compat.v1.VarLenFeature(tf.float32),
        'targets': tf.compat.v1.VarLenFeature(tf.int64),
        'targets_length': tf.compat.v1.VarLenFeature(tf.int64),
    }
    features = tf.compat.v1.parse_single_example(
        serialized_example, features=data_fields
    )
    # VarLenFeature yields SparseTensors; keep only their flat values.
    for k in features.keys():
        features[k] = features[k].values

    features = preprocess_inputs(features, augment=augment)

    # Drop everything else (e.g. raw 'waveforms') in place, which keeps the
    # remaining keys in their existing insertion order.
    keys = list(features.keys())
    for k in keys:
        if k not in ['inputs', 'inputs_length', 'targets', 'targets_length']:
            features.pop(k, None)

    return features

In [18]:
def get_dataset(files, batch_size=20, shuffle_size=32, num_cpu_threads=4,
                thread_count=24, is_training=True):
    """Return a zero-argument factory that builds the batched TFRecord pipeline.

    Args:
        files: non-empty list of TFRecord paths.
        batch_size: number of examples per padded batch.
        shuffle_size: buffer size of the example-level shuffle applied after
            interleaving (training branch only).
        num_cpu_threads: upper bound on how many files parallel_interleave
            reads concurrently (training branch only).
        thread_count: num_parallel_calls for the `parse` map.
        is_training: if True, repeat and shuffle the file list, interleave
            records from several files and shuffle examples; if False, read
            the files in the given order.

    Returns:
        A callable with no arguments that returns a tf.data.Dataset of dicts
        with keys 'inputs', 'inputs_length', 'targets', 'targets_length',
        zero-padded per batch. Both branches repeat indefinitely.

    Raises:
        ValueError: if `files` is empty (shuffle/interleave cannot use a zero
            buffer or cycle length).
    """
    if not files:
        raise ValueError('files must be a non-empty list of TFRecord paths')

    def get():
        if is_training:
            d = tf.data.Dataset.from_tensor_slices(tf.constant(files))
            d = d.repeat()
            d = d.shuffle(buffer_size=len(files))
            cycle_length = min(num_cpu_threads, len(files))
            d = d.apply(
                tf.contrib.data.parallel_interleave(
                    tf.data.TFRecordDataset,
                    sloppy=is_training,
                    cycle_length=cycle_length))
            # Honour the caller's shuffle_size (a hard-coded 100 used to
            # silently override it).
            d = d.shuffle(buffer_size=shuffle_size)
        else:
            d = tf.data.TFRecordDataset(files)
            d = d.repeat()
        # NOTE(review): the eval branch maps the same `parse` as training, so
        # any augmentation inside it is applied to test data too; consider a
        # non-augmenting parse when is_training is False.
        d = d.map(parse, num_parallel_calls=thread_count)
        d = d.padded_batch(
            batch_size,
            padded_shapes={
                'inputs': tf.TensorShape([None, n_mels]),
                'inputs_length': tf.TensorShape([None]),
                'targets': tf.TensorShape([None]),
                'targets_length': tf.TensorShape([None]),
            },
            padding_values={
                'inputs': tf.constant(0, dtype=tf.float32),
                'inputs_length': tf.constant(0, dtype=tf.int32),
                'targets': tf.constant(0, dtype=tf.int32),
                'targets_length': tf.constant(0, dtype=tf.int32),
            },
        )
        return d

    return get

In [25]:
# Build the training pipeline with a tiny batch and grab the next-element ops.
train_dataset = (
    get_dataset(train_set, batch_size=2)()
    .make_one_shot_iterator()
    .get_next()
)
train_dataset

{'targets': <tf.Tensor 'IteratorGetNext_1:2' shape=(?, ?) dtype=int32>,
 'targets_length': <tf.Tensor 'IteratorGetNext_1:3' shape=(?, ?) dtype=int32>,
 'inputs': <tf.Tensor 'IteratorGetNext_1:0' shape=(?, ?, 80) dtype=float32>,
 'inputs_length': <tf.Tensor 'IteratorGetNext_1:1' shape=(?, ?) dtype=int32>}

In [24]:
# TF1 graph-mode session used to pull sample batches below.
sess = tf.compat.v1.Session()

In [26]:
# Pull one training batch to eyeball padded shapes and values.
sess.run(train_dataset)

{'targets': array([[ 66,  14,   3,  20,   5,  20,   5,   7, 272,  11, 378, 147, 148,
           0,   0,   0,   0,   0,   0],
        [103,  12, 795,  20, 114,  14,   3,  67, 136,  71, 795,  52,  21,
          34,  16,  87,  62, 382, 875]], dtype=int32),
 'targets_length': array([[13],
        [19]], dtype=int32),
 'inputs': array([[[-0.38590848, -0.85252655, -0.6197265 , ..., -2.2818308 ,
          -2.0294595 , -2.3721614 ],
         [-0.59135187, -1.0706898 , -0.7697544 , ..., -2.206452  ,
          -2.0517805 , -1.6242082 ],
         [-0.46245736, -0.9674944 , -0.719194  , ..., -1.5186559 ,
          -1.6715912 , -2.067461  ],
         ...,
         [ 0.        ,  0.        ,  0.        , ...,  0.        ,
           0.        ,  0.        ],
         [ 0.        ,  0.        ,  0.        , ...,  0.        ,
           0.        ,  0.        ],
         [ 0.        ,  0.        ,  0.        , ...,  0.        ,
           0.        ,  0.        ]],
 
        [[ 0.        ,  0.        

In [27]:
# Build the evaluation pipeline (no file shuffling/interleave) for inspection.
test_dataset = (
    get_dataset(test_set, batch_size=2, is_training=False)()
    .make_one_shot_iterator()
    .get_next()
)
test_dataset

{'targets': <tf.Tensor 'IteratorGetNext_2:2' shape=(?, ?) dtype=int32>,
 'targets_length': <tf.Tensor 'IteratorGetNext_2:3' shape=(?, ?) dtype=int32>,
 'inputs': <tf.Tensor 'IteratorGetNext_2:0' shape=(?, ?, 80) dtype=float32>,
 'inputs_length': <tf.Tensor 'IteratorGetNext_2:1' shape=(?, ?) dtype=int32>}

In [28]:
# Pull one evaluation batch to eyeball padded shapes and values.
# NOTE(review): the zeroed feature columns in the captured output suggest
# the mel augmentation is also being applied to test data — confirm.
sess.run(test_dataset)

{'targets': array([[111, 161,   2, 112,   1, 795, 592,   3, 102, 235, 795, 544, 663,
           3, 103,   7, 272,  11, 271, 533,  48,  12, 795,  58, 795,  12,
          22, 516,  17,  54, 300, 795,  34,  31, 226,   0,   0],
        [ 59, 225, 135,  41, 795, 397, 213, 365,   2,   7,  92, 352, 218,
         594,  25,  26, 795,  76,   1,  22,  75,  29, 795,   1, 795, 551,
         164, 795, 237,   9, 200,  15, 534,   1,  22, 330,  81]],
       dtype=int32), 'targets_length': array([[35],
        [37]], dtype=int32), 'inputs': array([[[ 6.23412669e-01,  2.11047888e-01,  1.18641414e-01, ...,
           0.00000000e+00,  0.00000000e+00, -1.73207760e+00],
         [ 7.89191663e-01,  1.65041089e-01, -4.14103091e-01, ...,
           0.00000000e+00,  0.00000000e+00, -1.46829021e+00],
         [ 6.07244968e-01, -2.01284041e-04, -3.59744161e-01, ...,
           0.00000000e+00,  0.00000000e+00, -1.21086931e+00],
         ...,
         [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00, ...,
        