In [None]:
import tensorflow as tf
import numpy as np
from python_speech_features import mfcc, fbank, delta
from sklearn.preprocessing import StandardScaler
import scipy.io.wavfile as wav
import subprocess
import os, time

In [None]:
# The 61 phoneme labels of the original TIMIT transcriptions (.phn files).
# Label indices written to the TFRecords are positions in this list.
phn_61 = [
    'aa', 'ae', 'ah', 'ao', 'aw', 'ax', 'ax-h', 'axr', 'ay',
    'b', 'bcl', 'ch', 'd', 'dcl', 'dh', 'dx',
    'eh', 'el', 'em', 'en', 'eng', 'epi', 'er', 'ey',
    'f', 'g', 'gcl', 'h#', 'hh', 'hv',
    'ih', 'ix', 'iy', 'jh', 'k', 'kcl', 'l', 'm', 'n', 'ng', 'nx',
    'ow', 'oy', 'p', 'pau', 'pcl', 'q', 'r', 's', 'sh',
    't', 'tcl', 'th', 'uh', 'uw', 'ux', 'v', 'w', 'y', 'z', 'zh',
]

# Folding of 61 labels into 39 classes: each key is replaced by its value;
# labels not listed here map to themselves. Closures, pauses, epenthetic
# silence and glottal stops all collapse into 'h#'.
mapping = {
    'ah': 'ax', 'ax-h': 'ax', 'ux': 'uw', 'aa': 'ao', 'ih': 'ix',
    'axr': 'er', 'el': 'l', 'em': 'm', 'en': 'n', 'nx': 'n',
    'eng': 'ng', 'sh': 'zh', 'hv': 'hh', 'bcl': 'h#', 'pcl': 'h#',
    'dcl': 'h#', 'tcl': 'h#', 'gcl': 'h#', 'kcl': 'h#',
    'q': 'h#', 'epi': 'h#', 'pau': 'h#',
}

# The folded 39-phoneme set: every phn_61 label after applying `mapping`,
# deduplicated and sorted. Deriving it keeps it consistent with `mapping`.
phn_39 = sorted({mapping.get(label, label) for label in phn_61})

TIMIT_DIR = './timit'
DATA_DIR = './data'

In [None]:
def _read_wav_via_sox(wav_path):
    """Convert ``wav_path`` to a temporary WAV file with sox and read it back.

    The sox round-trip is presumably needed because the TIMIT ``.wav`` files are
    NIST SPHERE, which ``scipy.io.wavfile`` cannot read directly -- confirm.

    Returns:
        ``(rate, samples)`` exactly as returned by ``scipy.io.wavfile.read``.

    Raises:
        subprocess.CalledProcessError: if sox exits with a non-zero status.
    """
    tmp_path = os.path.splitext(wav_path)[0] + '_tmp.wav'
    try:
        # Pass the argument list WITHOUT shell=True. On POSIX, shell=True with a
        # list runs only 'sox' and hands the file names to the shell as $0/$1,
        # so no conversion would happen.
        subprocess.check_call(['sox', wav_path, tmp_path])
        return wav.read(tmp_path)
    finally:
        # Always remove the temporary file, even if sox or the read failed.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)


def _extract_features(sig, rate, feat_type):
    """Compute per-frame features followed by their deltas and delta-deltas.

    Args:
        sig: audio samples as returned by ``scipy.io.wavfile.read``.
        rate: sample rate of ``sig``.
        feat_type: 'mfcc' for MFCCs; any other value selects 40 log
            Mel-filterbank energies plus the total frame energy.

    Returns:
        Array of shape (num_frames, 3 * base_dim): static features, deltas and
        delta-deltas concatenated along the last axis.
    """
    if feat_type == 'mfcc':
        static = mfcc(sig, rate)
    else:
        filters, energy = fbank(sig, rate, nfilt=40)
        # NOTE(review): the total energy is appended without log compression,
        # unlike the filterbank energies -- confirm this is intended.
        static = np.concatenate((np.log(filters), energy.reshape(-1, 1)), axis=-1)
    first = delta(static, 2)
    second = delta(first, 2)
    return np.concatenate((static, first, second), axis=-1)


def _read_phoneme_indices(phn_path):
    """Return the ``phn_61`` index of every label listed in a TIMIT ``.phn`` file.

    The label is the third whitespace-separated field of each line
    (presumably '<start> <end> <label>'). Blank lines are skipped.

    Raises:
        ValueError: if a label is not in ``phn_61``.
    """
    with open(phn_path, 'r') as f:
        return [phn_61.index(line.split()[2])
                for line in f.read().splitlines() if line.strip()]


def _make_sequence_example(feats, phoneme):
    """Pack one utterance into a ``tf.train.SequenceExample``.

    Context holds 'feats_dim', 'feats_seq_len' and 'labels_seq_len'. The
    'features' feature list holds one float row per frame, and 'labels' holds
    one int64 phoneme index per entry.
    """
    seq_exam = tf.train.SequenceExample()
    context = seq_exam.context.feature
    context['feats_dim'].int64_list.value.append(feats.shape[1])
    context['feats_seq_len'].int64_list.value.append(feats.shape[0])
    context['labels_seq_len'].int64_list.value.append(len(phoneme))
    feature_lists = seq_exam.feature_lists.feature_list
    for feat in feats:
        feature_lists['features'].feature.add().float_list.value[:] = feat
    for p in phoneme:
        feature_lists['labels'].feature.add().int64_list.value.append(p)
    return seq_exam


def generate_tfrecords_from_timit(feat_type='mfcc'):
    """Write ``train.tfrecords`` and ``test.tfrecords`` under DATA_DIR from TIMIT.

    For every '.wav' file under ``TIMIT_DIR/<split>`` (except 'sa*' files),
    compute acoustic features and read the matching '.phn' labels as indices
    into ``phn_61``. Features of both splits are standardized with a
    ``StandardScaler`` fitted on the training split only. Each utterance is
    then written as one ``tf.train.SequenceExample``.

    Args:
        feat_type: 'mfcc' (default) or any other value for log Mel-filterbank
            features.

    Raises:
        ValueError: if no training utterances are found (the scaler cannot be fit).
        subprocess.CalledProcessError: if sox fails on a file.
    """
    if not os.path.isdir(DATA_DIR):
        os.makedirs(DATA_DIR)

    scaler = None
    # 'train' must come first: the scaler fitted on it is reused for 'test'.
    for data_type in ['train', 'test']:
        timit_dir = os.path.join(TIMIT_DIR, data_type)
        start = time.time()
        feats_list = []
        phoneme_list = []
        for path, _dirs, files in os.walk(timit_dir):
            for file in files:
                # exclude all 'SA' files according to 'https://github.com/zzw922cn/Automatic_Speech_Recognition'
                if file.startswith('sa') or not file.endswith('wav'):
                    continue
                # Skip temporary conversions left behind by an interrupted run;
                # they have no matching .phn file.
                if file.endswith('_tmp.wav'):
                    continue
                full_name = os.path.join(path, file)
                rate, sig = _read_wav_via_sox(full_name)
                feats_list.append(_extract_features(sig, rate, feat_type))
                phoneme_list.append(
                    _read_phoneme_indices(os.path.splitext(full_name)[0] + '.phn'))

        if data_type == 'train':
            if not feats_list:
                raise ValueError('no training utterances found under {}'.format(timit_dir))
            scaler = StandardScaler()
            scaler.fit(np.concatenate(feats_list, axis=0))

        out_path = os.path.join(DATA_DIR, data_type + '.tfrecords')
        with tf.python_io.TFRecordWriter(out_path) as writer:
            for feats, phoneme in zip(feats_list, phoneme_list):
                # transform() returns a new array rather than modifying `feats`
                # in place, so its result must be used for normalization to apply.
                normalized = scaler.transform(feats)
                writer.write(_make_sequence_example(normalized, phoneme).SerializeToString())

        print('{}-{}: {} utterances - {:.0f}s'.format(data_type, feat_type, len(feats_list), (time.time()-start)))

In [None]:
# Build train/test TFRecords with the default MFCC (+delta, +delta-delta) features.
generate_tfrecords_from_timit(feat_type='mfcc')

## Example 1: read `test.tfrecords` with the queue-based input pipeline

In [None]:
# Queue-based reader (TF 1.x): make a single pass over the test records.
filename_queue = tf.train.string_input_producer(
    [os.path.join(DATA_DIR, 'test.tfrecords')], num_epochs=1)
reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue)

# Schema written by generate_tfrecords_from_timit.
context_features = {
    'feats_dim': tf.FixedLenFeature([], dtype=tf.int64),
    'feats_seq_len': tf.FixedLenFeature([], dtype=tf.int64),
    'labels_seq_len': tf.FixedLenFeature([], dtype=tf.int64),
}
# Width 39 matches feat_type='mfcc' (default cepstra x static/delta/delta-delta);
# the filterbank variant produces a different width.
sequence_features = {
    'features': tf.FixedLenSequenceFeature([39], dtype=tf.float32),
    'labels': tf.FixedLenSequenceFeature([], dtype=tf.int64),
}
context_parsed, sequence_parsed = tf.parse_single_sequence_example(
    serialized_example,
    context_features=context_features,
    sequence_features=sequence_features)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # num_epochs keeps its counter in a local variable, so this init is required.
    sess.run(tf.local_variables_initializer())
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    fetches = [context_parsed['feats_dim'],
               context_parsed['feats_seq_len'],
               context_parsed['labels_seq_len'],
               sequence_parsed['features'],
               sequence_parsed['labels']]
    try:
        step = 0
        while not coord.should_stop():
            print('step: {}'.format(step))
            (feats_dim, feats_seq_len, labels_seq_len,
             features, labels) = sess.run(fetches)
            print(feats_seq_len)
            print(labels_seq_len)
            print()
            step += 1
            # Only peek at the first utterance.
            if step >= 1:
                break
    except tf.errors.OutOfRangeError:
        print('Done')
    finally:
        coord.request_stop()

    coord.join(threads)

In [None]:
# Feature width plus (frames, dim) and (num_labels,) shapes of the first test utterance.
(feats_dim, features.shape, labels.shape)

## Example 2: read `test.tfrecords` with the `Dataset` API

In [None]:
# Dataset-API reader (tf.contrib.data, TF 1.x) over the same test records.
test_dataset = tf.contrib.data.TFRecordDataset(os.path.join(DATA_DIR, 'test.tfrecords'))
context_features = {'feats_dim': tf.FixedLenFeature([], dtype=tf.int64),
                    'feats_seq_len': tf.FixedLenFeature([], dtype=tf.int64),
                    'labels_seq_len': tf.FixedLenFeature([], dtype=tf.int64)}
# Width 39 matches the default feat_type='mfcc'; adjust it for the filterbank variant.
sequence_features = {'features': tf.FixedLenSequenceFeature([39], dtype=tf.float32),
                     'labels': tf.FixedLenSequenceFeature([], dtype=tf.int64)}
test_dataset = test_dataset.map(lambda serialized_example: tf.parse_single_sequence_example(
    serialized_example,
    context_features=context_features,
    sequence_features=sequence_features))
# Flatten the (context, sequence) dicts into a 5-tuple of tensors.
test_dataset = test_dataset.map(lambda context, sequence: (
    context['feats_dim'], context['feats_seq_len'], context['labels_seq_len'],
    sequence['features'], sequence['labels']))
test_iterator = test_dataset.make_initializable_iterator()
feats_dim, feats_seq_len, labels_seq_len, features, labels = test_iterator.get_next()

with tf.Session() as sess:
    sess.run(test_iterator.initializer)

    # Initialize the counter once, before the loop (it must not be reset each pass).
    step = 0
    while True:
        try:
            r = sess.run([feats_dim, feats_seq_len, labels_seq_len, features, labels])
        except tf.errors.OutOfRangeError:
            # Iterator exhausted: leave the loop instead of retrying forever.
            print('finish epoch')
            break
        step += 1
        # Only peek at the first utterance.
        if step >= 1:
            break

In [None]:
# feats_dim, feats_seq_len, labels_seq_len, then the features and labels array shapes.
tuple(r[:3]) + (r[3].shape, r[4].shape)