<center> 
    <h1> Speech Transformer : A Speech to Text Transformer in TensorFlow 2 </h1>
    <h2> Training and Decoding </h2>
</center>

In [None]:
# Mount Google Drive inside the Colab runtime at /content/drive.
# force_remount=True asks Colab to remount even if the drive is already mounted.
# NOTE(review): presumably C.DATASET_PATH / C.CALLBACKS_PATH point below this mount — confirm.
from google.colab import drive
drive.mount('/content/drive', force_remount=True)

In [None]:
import tensorflow as tf
import os
import matplotlib.pyplot as plt
import sys
import pickle
from tensorflow_addons.image import sparse_image_warp
import tensorflow_datasets as tfds

In [None]:
class Config:
    """Hyper-parameters, data limits and file locations shared by the cells of this notebook.

    VOCAB_SIZE, START_TOKEN, END_TOKEN and MAX_POSITION_ENCODING start as None and are
    filled in by later cells once the tokenizer is available; the instance is then pickled
    to CALLBACKS_PATH/config.pkl.
    """

    def __init__(self):

        #training
        self.TPU = True  # build the model under a TPU distribution strategy
        self.TRAIN = True  # load the full training set and run model.fit
        self.EPOCHS = 15
        self.DTYPE = tf.bfloat16  # dtype the spectrograms are cast to
        self.BATCH_SIZE = 128
        self.LABEL_SMOOTHING = 0.2
        self.NEIGHBORHOOD_SMOOTHING = False  # use neighborhood_smoothing() instead of uniform smoothing in the loss
        self.WARMUP_STEPS = 4000  # warm-up steps of the learning-rate schedule
        self.K = 1.5  # multiplicative factor of the learning-rate schedule
        # Adam optimizer parameters
        self.BETA1 = 0.9
        self.BETA2 = 0.98
        self.EPS = 1e-9
        self.AUG = True  # apply augment() to the training dataset

        #model
        self.N_ENC = 3  # number of encoder layers
        self.N_DEC = 3  # number of decoder layers
        self.UNITS = 2048  # width of the feed-forward hidden Dense layer
        self.D_MODEL = 512
        self.NUM_HEADS = 8
        self.DROPOUT = 0.1
        self.CNN = True  # put the convolutional front-end before the encoder

        #samples
        self.ENCODING = 'subword' # subword / character 
        self.MAX_SAMPLE = float('inf')  # stop preprocessing a split after this many kept samples
        self.MAX_LENGTH = 75 # 75 subword / 250 character
        self.MAX_SPEC_LENGTH = 1600  # longer spectrograms are dropped, shorter ones are zero-padded to this length
        self.D_SPEC = 80  # number of features per spectrogram frame
        # Filled in later, after the tokenizer has been built or loaded.
        self.VOCAB_SIZE = None
        self.START_TOKEN = None
        self.END_TOKEN = None
        self.MAX_POSITION_ENCODING = None

        #saves
        self.DATASET_PATH = 'path/to/dataset'
        self.CALLBACKS_PATH = 'path/to/callback'  # tokenizer, config, logs and checkpoints are written here
        self.LOAD = True  # restore weights from SAVED_MODEL when the file exists
        self.SAVED_MODEL = 'checkpoints.h5'

        #decoding
        self.BEAM_SIZE = 10
        self.MAX_REP = 2 # 2 subword / 3 character 

C = Config()

In [None]:
if C.TPU:
    # Create distribution strategy
    # Resolve the TPU cluster, connect to it and initialise the TPU system, then wrap it
    # in the strategy object whose scope is used later when the model is created.
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
    tf.config.experimental_connect_to_cluster(tpu)
    tf.tpu.experimental.initialize_tpu_system(tpu)
    strategy = tf.distribute.experimental.TPUStrategy(tpu)

# Prepare Dataset

In [None]:
# Training mode loads the full training set (the "clean" and "other" pickles);
# otherwise only the 100-hour clean subset is loaded.
# Each pickle is indexed as data[0] -> spectrograms, data[1] -> transcripts.
# NOTE(review): pickle.load can execute arbitrary code — only load trusted dataset files.

if C.TRAIN:
    clean_path = os.path.join(C.DATASET_PATH, "pickled_LibriSpeech_clean_10ms.pkl")
    with open(clean_path, "br") as f:
        data = pickle.load(f)
    spec1, speech1 = data[0], data[1]

    other_path = os.path.join(C.DATASET_PATH, "pickled_LibriSpeech_other_10ms.pkl")
    with open(other_path, "br") as f:
        data = pickle.load(f)
    spec2, speech2 = data[0], data[1]

    # Concatenate the two training portions.
    spec_train = spec1 + spec2
    speech_train = speech1 + speech2
else:
    subset_path = os.path.join(C.DATASET_PATH, "pickled_LibriSpeech_clean100_10ms.pkl")
    with open(subset_path, "br") as f:
        data = pickle.load(f)
    spec_train, speech_train = data[0], data[1]

# Both branches report the same two counts.
print("{} training spec".format(len(spec_train)))
print("{} training speech".format(len(speech_train)))

# The validation split is loaded in both modes.
dev_path = os.path.join(C.DATASET_PATH, "pickled_LibriSpeech_dev_clean_10ms.pkl")
with open(dev_path, "br") as f:
    data = pickle.load(f)
spec_val, speech_val = data[0], data[1]
print("{} validation spec".format(len(spec_val)))
print("{} validation speech".format(len(speech_val)))

In [None]:
# Build the tokenizer once and cache it in CALLBACKS_PATH; later runs reload the cached copy.
# NOTE(review): the cached tokenizer.pkl is reused whatever C.ENCODING currently is — delete it
# when switching between 'subword' and 'character'.
# NOTE(review): pickle.load can execute arbitrary code — only load a tokenizer.pkl you created.
tokenizer_path = os.path.join(C.CALLBACKS_PATH, 'tokenizer.pkl')

if not os.path.isfile(tokenizer_path):
    if C.ENCODING == 'subword':
        tokenizer = tfds.features.text.SubwordTextEncoder.build_from_corpus(speech_train+speech_val, target_vocab_size=2**13)
    else:
        tokenizer = tf.keras.preprocessing.text.Tokenizer(char_level=True)
        tokenizer.fit_on_texts(speech_train)

    # makedirs also creates missing parent directories and tolerates an existing
    # directory, where os.mkdir would raise.
    os.makedirs(C.CALLBACKS_PATH, exist_ok=True)
    with open(tokenizer_path, 'bw') as f:
        pickle.dump(tokenizer, f, protocol=4)
else:
    with open(tokenizer_path, 'br') as f:
        tokenizer = pickle.load(f)

# Reserve two extra ids, placed right after the tokenizer's own ids, for the start and end tokens.
if C.ENCODING == 'subword':
    C.START_TOKEN, C.END_TOKEN = [tokenizer.vocab_size], [tokenizer.vocab_size + 1]
    C.VOCAB_SIZE = tokenizer.vocab_size + 2
else:
    C.START_TOKEN, C.END_TOKEN = [len(tokenizer.word_counts) + 1], [len(tokenizer.word_counts) + 2]
    C.VOCAB_SIZE = len(tokenizer.word_counts) + 3
print("VOCAB_SIZE = {}".format(C.VOCAB_SIZE))

In [None]:
# Derive any length limit left as None from the training data, then persist the config.
# The limits are stored as plain Python ints (not tensors) so the pickled config stays
# simple, and the builtin `max` is used directly instead of being rebound.

if C.MAX_SPEC_LENGTH is None:
    # Longest spectrogram (number of frames) in the training set; 0 when the set is empty.
    C.MAX_SPEC_LENGTH = int(max((S.shape[0] for S in spec_train), default=0))
    print('MAX_SPEC_LENGTH :', C.MAX_SPEC_LENGTH)

if C.MAX_LENGTH is None:
    if C.ENCODING == 'subword':
        token_counts = (len(tokenizer.encode(text)) for text in speech_train)
    else:
        token_counts = (len(tokenizer.texts_to_sequences([text])[0]) for text in speech_train)
    # +2 accounts for the start and end tokens added around every transcript;
    # default=-2 makes an empty training set yield 0.
    C.MAX_LENGTH = int(max(token_counts, default=-2)) + 2
    print('MAX_LENGTH :', C.MAX_LENGTH)

if C.MAX_POSITION_ENCODING is None:
    # One table size large enough for both the spectrogram and the token sequence.
    C.MAX_POSITION_ENCODING = max(int(C.MAX_LENGTH), int(C.MAX_SPEC_LENGTH))
    print('MAX_POSITION_ENCODING :', C.MAX_POSITION_ENCODING)


# makedirs also creates missing parent directories and tolerates an existing directory.
os.makedirs(C.CALLBACKS_PATH, exist_ok=True)
with open(os.path.join(C.CALLBACKS_PATH,'config.pkl'), 'bw') as f:
    pickle.dump(C, f, protocol=4)

In [None]:
def _tokenize_and_pad(specs, transcripts):
    """Tokenise the transcripts and zero-pad the spectrograms of one dataset split.

    A (spectrogram, transcript) pair is dropped when its token sequence (start and end
    tokens included) is longer than C.MAX_LENGTH or its spectrogram has more than
    C.MAX_SPEC_LENGTH frames. Processing stops once C.MAX_SAMPLE pairs have been kept.

    Args:
        specs: iterable of spectrograms; each one supports len() and division by a scalar.
        transcripts: iterable of transcripts, aligned with `specs`.

    Returns:
        (tokens, padded_specs): `tokens` is the array returned by `pad_sequences`
        (post-padded to C.MAX_LENGTH); `padded_specs` is a list of tensors of dtype
        C.DTYPE, each padded with zero frames up to C.MAX_SPEC_LENGTH.
    """
    kept_tokens = []
    kept_specs = []
    for spec, text in zip(specs, transcripts):
        if C.ENCODING == 'subword':
            body = tokenizer.encode(text)
        else:
            body = tokenizer.texts_to_sequences([text])[0]
        tokens = C.START_TOKEN + body + C.END_TOKEN

        if len(tokens) > C.MAX_LENGTH or len(spec) > C.MAX_SPEC_LENGTH:
            continue

        kept_tokens.append(tokens)
        # Values are divided by 25 before the cast (presumably a normalisation constant — confirm).
        spec = tf.cast(spec/25, C.DTYPE)
        padding = tf.zeros((C.MAX_SPEC_LENGTH-len(spec), C.D_SPEC), dtype=C.DTYPE)
        kept_specs.append(tf.concat([spec, padding], axis=0))

        # Progress counter, rewritten in place on a single line.
        sys.stdout.write("\r{}/{}".format(len(kept_specs), C.MAX_SAMPLE))
        if len(kept_specs) == C.MAX_SAMPLE:
            break

    kept_tokens = tf.keras.preprocessing.sequence.pad_sequences(kept_tokens, maxlen=C.MAX_LENGTH, padding='post')
    return kept_tokens, kept_specs


tokenized_speech_train, spec_audio_train = _tokenize_and_pad(spec_train, speech_train)

print()
print("{} training tokenized speech".format(len(tokenized_speech_train)))
print("{} training spectrogram".format(len(spec_audio_train)))


tokenized_speech_val, spec_audio_val = _tokenize_and_pad(spec_val, speech_val)

print()
print("{} validation tokenized speech".format(len(tokenized_speech_val)))
print("{} validation spectrogram".format(len(spec_audio_val)))

In [None]:
def freq_mask(input, param, name=None):
    """Zero out one randomly placed band of consecutive bins along axis 1 of `input`.

    Args:
        input: 2-D tensor; axis 1 is the masked axis (the frequency bins of a spectrogram).
        param: exclusive upper bound of the random band width.
        name: unused.

    Returns:
        `input` with the selected band replaced by zeros of dtype C.DTYPE.
    """
    n_bins = tf.shape(input)[1]
    # Band width first, then a start position that keeps the band inside the axis.
    width = tf.random.uniform(shape=(), minval=0, maxval=param, dtype=tf.dtypes.int32)
    start = tf.random.uniform(shape=(), minval=0, maxval=n_bins - width, dtype=tf.dtypes.int32)

    bin_ids = tf.reshape(tf.range(n_bins), (1, -1))
    in_band = tf.math.logical_and(bin_ids >= start, bin_ids < start + width)
    zero = tf.cast(0.0, C.DTYPE)
    return tf.where(in_band, zero, input)

def sparse_warp(mel_spectrogram, param):
    """Warp a spectrogram along its first (time) axis with `sparse_image_warp`.

    A single control point on the middle frequency row is picked at a random time index
    in [param, T - param) and moved by a random offset in [-param, param).

    Args:
        mel_spectrogram: 2-D tensor (T, F).
        param: bound used both for the control point position and for its displacement.

    Returns:
        The warped (T, F) tensor, cast to C.DTYPE.
    """
    # The warp works on a float32 image batch: (T, F) -> (1, T, F, 1).
    image = tf.expand_dims(mel_spectrogram, axis=0)
    image = tf.expand_dims(image, axis=-1)
    image = tf.cast(image, tf.float32)

    image_shape = tf.shape(image)
    n_frames = image_shape[1]
    n_bins = image_shape[2]

    # Source control point: random time index, middle frequency bin.
    anchor_time = tf.random.uniform([], param, n_frames-param, tf.int32)
    anchor_freq = [n_bins//2]
    source_time = tf.ones_like(anchor_freq) * anchor_time
    source_points = tf.cast(tf.stack((source_time, anchor_freq), -1), dtype=tf.float32)

    # Destination control point: same frequency, time moved by a random offset.
    time_shift = tf.random.uniform([], -param, param, tf.int32)
    dest_points = tf.cast(tf.stack((source_time + time_shift, anchor_freq), -1), dtype=tf.float32)

    warped, _ = sparse_image_warp(image,
                                  tf.expand_dims(source_points, 0),
                                  tf.expand_dims(dest_points, 0),
                                  num_boundary_points=1)
    # Drop the batch and channel axes again.
    return tf.cast(warped, C.DTYPE)[0,:,:,0]

def augment(inputs, outputs):
    """
    (LD) Librispeech Double policy augmentation without time masking
    F = 27
    W = 80
    p = 100%

    Applies one time warp followed by two frequency masks to inputs['inputs'];
    the (updated) `inputs` dict and `outputs` are returned.
    """
    spectrogram = sparse_warp(inputs['inputs'], param=80)
    spectrogram = freq_mask(spectrogram, param=27)
    spectrogram = freq_mask(spectrogram, param=27)
    inputs['inputs'] = spectrogram
    return inputs, outputs

In [None]:
autotune = tf.data.experimental.AUTOTUNE

# Teacher forcing: the decoder input is the token sequence without its last position,
# the target is the same sequence shifted left by one.
train_features = {
    'inputs': spec_audio_train,
    'dec_inputs': tokenized_speech_train[:, :-1]
}
train_targets = {
    'outputs': tokenized_speech_train[:, 1:]
}
dataset_train = (
    tf.data.Dataset.from_tensor_slices((train_features, train_targets))
    .cache()
    .shuffle(buffer_size=len(spec_audio_train))
)
# Augmentation is applied after the cache, so it is re-sampled on every pass.
if C.AUG:
    dataset_train = dataset_train.map(augment, num_parallel_calls=autotune)
dataset_train = dataset_train.batch(C.BATCH_SIZE, drop_remainder=True).prefetch(autotune)
print(dataset_train)

val_features = {
    'inputs': spec_audio_val,
    'dec_inputs': tokenized_speech_val[:, :-1]
}
val_targets = {
    'outputs': tokenized_speech_val[:, 1:]
}
dataset_val = (
    tf.data.Dataset.from_tensor_slices((val_features, val_targets))
    .cache()
    .shuffle(buffer_size=len(spec_audio_val))
    .batch(C.BATCH_SIZE, drop_remainder=True)
    .prefetch(autotune)
)
print(dataset_val)

# Model

In [None]:
def scaled_dot_product_attention(query, key, value, mask):
  """Scaled dot-product attention.

  The query/key scores are divided by sqrt(depth of key), masked positions
  (mask == 1) receive a -1e9 penalty, and the softmax over the last axis
  weights `value`.

  Args:
    query, key, value: attention operands; the last axis of `key` is the depth.
    mask: tensor added (times -1e9) to the scores, or None for no masking.

  Returns:
    The attention-weighted sum of `value`.
  """
  key_depth = tf.cast(tf.shape(key)[-1], tf.float32)
  scores = tf.matmul(query, key, transpose_b=True) / tf.math.sqrt(key_depth)

  # Masked positions get a large negative score so the softmax sends them to ~0.
  if mask is not None:
    scores += (mask * -1e9)

  # Normalise over the key axis.
  weights = tf.nn.softmax(scores, axis=-1)
  return tf.matmul(weights, value)

class MultiHeadAttention(tf.keras.layers.Layer):
  """Multi-head attention built from four Dense projections of width `d_model`.

  The layer is called with a dict holding 'query', 'key', 'value' and 'mask'.
  """

  def __init__(self, d_model, num_heads, name="multi_head_attention"):
    super().__init__(name=name)
    # The model width must split evenly across the heads.
    assert d_model % num_heads == 0

    self.num_heads = num_heads
    self.d_model = d_model
    self.depth = d_model // num_heads

    # Input projections, then the output projection applied to the merged heads.
    self.query_dense = tf.keras.layers.Dense(units=d_model)
    self.key_dense = tf.keras.layers.Dense(units=d_model)
    self.value_dense = tf.keras.layers.Dense(units=d_model)

    self.dense = tf.keras.layers.Dense(units=d_model)

  def split_heads(self, inputs, batch_size):
    """Reshape (batch, seq, d_model) into (batch, num_heads, seq, depth)."""
    per_head = tf.reshape(
        inputs, shape=(batch_size, -1, self.num_heads, self.depth))
    return tf.transpose(per_head, perm=[0, 2, 1, 3])

  def call(self, inputs):
    """Project, split into heads, attend, merge the heads and project again."""
    mask = inputs['mask']
    batch_size = tf.shape(inputs['query'])[0]

    # Linear projection followed by the per-head split.
    query = self.split_heads(self.query_dense(inputs['query']), batch_size)
    key = self.split_heads(self.key_dense(inputs['key']), batch_size)
    value = self.split_heads(self.value_dense(inputs['value']), batch_size)

    attended = scaled_dot_product_attention(query, key, value, mask)

    # (batch, heads, seq, depth) -> (batch, seq, heads, depth) -> (batch, seq, d_model)
    attended = tf.transpose(attended, perm=[0, 2, 1, 3])
    merged = tf.reshape(attended, (batch_size, -1, self.d_model))

    return self.dense(merged)

In [None]:
def create_padding_mask(x):
  """Return a float mask that is 1.0 where `x` equals 0 and 0.0 elsewhere.

  Two singleton axes are inserted so a (batch_size, sequence length) input
  yields a (batch_size, 1, 1, sequence length) mask.
  """
  is_padding = tf.math.equal(x, 0)
  mask = tf.cast(is_padding, tf.float32)
  return mask[:, tf.newaxis, tf.newaxis, :]

def create_look_ahead_mask(x):
  """Combine the "no peeking at later positions" mask with the padding mask of `x`.

  A position is masked (1.0) when it lies after the querying position or when
  the corresponding entry of `x` is 0.
  """
  seq_len = tf.shape(x)[1]
  # Lower triangle (diagonal included) marks the positions that may be attended to.
  visible = tf.linalg.band_part(tf.ones((seq_len, seq_len)), -1, 0)
  future_mask = 1 - visible
  return tf.maximum(future_mask, create_padding_mask(x))

In [None]:
class PositionalEncoding(tf.keras.layers.Layer):
  """Adds a fixed sinusoidal position table to its input.

  The table has shape (1, position, d_model) and is computed once in the constructor.
  """

  def __init__(self, position, d_model):
    super().__init__()
    self.pos_encoding = self.positional_encoding(position, d_model)

  def get_angles(self, position, i, d_model):
    """Return position / 10000^(2*(i//2)/d_model)."""
    exponent = (2 * (i // 2)) / tf.cast(d_model, tf.float32)
    rates = 1 / tf.pow(10000, exponent)
    return position * rates

  def positional_encoding(self, position, d_model):
    """Build the (1, position, d_model) table: sine on even columns, cosine on odd ones."""
    positions = tf.range(position, dtype=tf.float32)[:, tf.newaxis]
    columns = tf.range(d_model, dtype=tf.float32)[tf.newaxis, :]
    # Converted to a NumPy array so the columns can be overwritten in place.
    table = self.get_angles(position=positions, i=columns, d_model=d_model).numpy()

    # even columns (2i) -> sin, odd columns (2i+1) -> cos
    table[:, 0::2] = tf.math.sin(table[:, 0::2])
    table[:, 1::2] = tf.math.cos(table[:, 1::2])

    return tf.cast(table[tf.newaxis, ...], tf.float32)

  def call(self, inputs):
    """Add the first `sequence length` rows of the table to `inputs`."""
    seq_len = tf.shape(inputs)[1]
    return inputs + self.pos_encoding[:, :seq_len, :]

In [None]:
def encoder_layer(units, d_model, num_heads, dropout, name="encoder_layer"):
  """Build one encoder block as a Keras functional model.

  Structure: LayerNorm -> self-attention -> dropout -> residual add, followed by
  LayerNorm -> Dense(units, relu) -> Dense(d_model) -> dropout -> residual add.

  Args:
    units: width of the hidden feed-forward Dense layer.
    d_model: feature width of the block's input and output.
    num_heads: number of attention heads.
    dropout: dropout rate applied after the attention and after the feed-forward part.
    name: name given to the returned model.

  Returns:
    A tf.keras.Model taking [inputs, padding_mask] and returning a tensor of width d_model.
  """
  inputs = tf.keras.Input(shape=(None, d_model), name="inputs")
  padding_mask = tf.keras.Input(shape=(1, 1, None), name="padding_mask")

  # Self-attention sub-block; normalisation is applied before the attention.
  attention = tf.keras.layers.LayerNormalization(epsilon=1e-6)(inputs)
  attention = MultiHeadAttention(
      d_model, num_heads, name="attention")({
          'query': attention,
          'key': attention,
          'value': attention,
          'mask': padding_mask
      })
  attention = tf.keras.layers.Dropout(rate=dropout)(attention)
  # Residual connection around the attention sub-block.
  attention = attention + inputs

  # Feed-forward sub-block, again normalised first.
  outputs = tf.keras.layers.LayerNormalization(epsilon=1e-6)(attention)
  outputs = tf.keras.layers.Dense(units=units, activation='relu')(outputs)
  outputs = tf.keras.layers.Dense(units=d_model)(outputs)
  outputs = tf.keras.layers.Dropout(rate=dropout)(outputs)
  # Residual connection around the feed-forward sub-block.
  outputs = outputs + attention

  return tf.keras.Model(inputs=[inputs, padding_mask],
                        outputs=outputs,
                        name=name)

def encoder(maximum_position_encoding,
            num_layers,
            units,
            d_model,
            num_heads,
            dropout,
            name="encoder"):
  """Build the encoder stack as a Keras functional model.

  The input features (already of width d_model) are scaled by sqrt(d_model), given a
  positional encoding and dropout, passed through `num_layers` encoder layers and
  finally layer-normalised.

  Args:
    maximum_position_encoding: number of positions in the positional-encoding table.
    num_layers: number of stacked encoder layers.
    units, d_model, num_heads, dropout: forwarded to every `encoder_layer`.
    name: name given to the returned model.

  Returns:
    A tf.keras.Model taking [inputs, padding_mask].
  """
  inputs = tf.keras.Input(shape=(None,d_model,), name="inputs")
  padding_mask = tf.keras.Input(shape=(1, 1, None), name="padding_mask")

  # No embedding layer here: the inputs are used directly as the "embeddings".
  embeddings = inputs

  embeddings *= tf.math.sqrt(tf.cast(d_model, tf.float32))
  embeddings = PositionalEncoding(maximum_position_encoding, d_model)(embeddings)

  outputs = tf.keras.layers.Dropout(rate=dropout)(embeddings)

  for i in range(num_layers):
    outputs = encoder_layer(
        units=units,
        d_model=d_model,
        num_heads=num_heads,
        dropout=dropout,
        name="encoder_layer_{}".format(i),
    )([outputs, padding_mask])

  # Final normalisation of the stack output.
  outputs = tf.keras.layers.LayerNormalization(epsilon=1e-6)(outputs)

  return tf.keras.Model(inputs=[inputs, padding_mask],
                        outputs=outputs,
                        name=name)

In [None]:
def decoder_layer(units, d_model, num_heads, dropout, name="decoder_layer"):
  """Build one decoder block as a Keras functional model.

  Structure (each sub-block is LayerNorm -> op -> dropout -> residual add):
  masked self-attention, attention over the encoder outputs, then a
  Dense(units, relu) -> Dense(d_model) feed-forward part.

  Args:
    units: width of the hidden feed-forward Dense layer.
    d_model: feature width of the block's input and output.
    num_heads: number of attention heads.
    dropout: dropout rate applied at the end of every sub-block.
    name: name given to the returned model.

  Returns:
    A tf.keras.Model taking [inputs, enc_outputs, look_ahead_mask, padding_mask].
  """
  inputs = tf.keras.Input(shape=(None, d_model), name="inputs")
  enc_outputs = tf.keras.Input(shape=(None, d_model), name="encoder_outputs")
  look_ahead_mask = tf.keras.Input(shape=(1, None, None), name="look_ahead_mask")
  padding_mask = tf.keras.Input(shape=(1, 1, None), name='padding_mask')

  # Self-attention over the decoder inputs, restricted by the look-ahead mask.
  attention1 = tf.keras.layers.LayerNormalization(epsilon=1e-6)(inputs)
  attention1 = MultiHeadAttention(
      d_model, num_heads, name="attention_1")(inputs={
          'query': attention1,
          'key': attention1,
          'value': attention1,
          'mask': look_ahead_mask
      })
  attention1 = tf.keras.layers.Dropout(rate=dropout)(attention1)
  attention1 = attention1 + inputs

  # Attention over the encoder outputs: queries come from the decoder, keys/values
  # from the encoder; the padding mask is applied to the encoder positions.
  attention2 = tf.keras.layers.LayerNormalization(epsilon=1e-6)(attention1)
  attention2 = MultiHeadAttention(
      d_model, num_heads, name="attention_2")(inputs={
          'query': attention2,
          'key': enc_outputs,
          'value': enc_outputs,
          'mask': padding_mask
      })
  attention2 = tf.keras.layers.Dropout(rate=dropout)(attention2)
  attention2 = attention2 + attention1

  # Feed-forward sub-block.
  outputs = tf.keras.layers.LayerNormalization(epsilon=1e-6)(attention2)
  outputs = tf.keras.layers.Dense(units=units, activation='relu')(outputs)
  outputs = tf.keras.layers.Dense(units=d_model)(outputs)
  outputs = tf.keras.layers.Dropout(rate=dropout)(outputs)
  outputs = outputs + attention2

  return tf.keras.Model(
      inputs=[inputs, enc_outputs, look_ahead_mask, padding_mask],
      outputs=outputs,
      name=name)
  
def decoder(vocab_size,
            maximum_position_encoding,
            num_layers,
            units,
            d_model,
            num_heads,
            dropout,
            name='decoder'):
  """Build the decoder stack as a Keras functional model.

  Token ids are embedded, scaled by sqrt(d_model), given a positional encoding and
  dropout, passed through `num_layers` decoder layers and finally layer-normalised.

  Args:
    vocab_size: size of the embedding table.
    maximum_position_encoding: number of positions in the positional-encoding table.
    num_layers: number of stacked decoder layers.
    units, d_model, num_heads, dropout: forwarded to every `decoder_layer`.
    name: name given to the returned model.

  Returns:
    A tf.keras.Model taking [inputs, enc_outputs, look_ahead_mask, padding_mask].
  """
  inputs = tf.keras.Input(shape=(None,), name='inputs')
  enc_outputs = tf.keras.Input(shape=(None, d_model), name='encoder_outputs')
  look_ahead_mask = tf.keras.Input(shape=(1, None, None), name='look_ahead_mask')
  padding_mask = tf.keras.Input(shape=(1, 1, None), name='padding_mask')

  embeddings = tf.keras.layers.Embedding(vocab_size, d_model)(inputs)
  embeddings *= tf.math.sqrt(tf.cast(d_model, tf.float32))
  embeddings = PositionalEncoding(maximum_position_encoding, d_model)(embeddings)

  outputs = tf.keras.layers.Dropout(rate=dropout)(embeddings)

  for i in range(num_layers):
    outputs = decoder_layer(
        units=units,
        d_model=d_model,
        num_heads=num_heads,
        dropout=dropout,
        name='decoder_layer_{}'.format(i),
    )(inputs=[outputs, enc_outputs, look_ahead_mask, padding_mask])

  # Final normalisation of the stack output.
  outputs = tf.keras.layers.LayerNormalization(epsilon=1e-6)(outputs)

  return tf.keras.Model(inputs=[inputs, enc_outputs, look_ahead_mask, padding_mask],
                        outputs=outputs,
                        name=name)

In [None]:
def transformer(vocab_size,
                maximum_position_encoding,
                num_layers_enc,
                num_layers_dec,
                units,
                d_spec,
                d_model,
                num_heads,
                dropout,
                cnn,
                name="transformer"):
  """Build the full speech-to-text model as a Keras functional model.

  Args:
    vocab_size: size of the output vocabulary (and of the decoder embedding table).
    maximum_position_encoding: number of positions in the positional-encoding tables.
    num_layers_enc: number of encoder layers.
    num_layers_dec: number of decoder layers.
    units: width of the hidden feed-forward Dense layers.
    d_spec: number of features per spectrogram frame.
    d_model: model width.
    num_heads: number of attention heads.
    dropout: dropout rate.
    cnn: when true, two Conv2D/Conv2D/MaxPool2D blocks are applied to the spectrogram
      before the encoder (each pooling halves both the time and the feature axes).
    name: name given to the returned model.

  Returns:
    A tf.keras.Model taking [inputs, dec_inputs] and returning per-position logits
    over the vocabulary.
  """
  inputs = tf.keras.Input(shape=(None,d_spec), name="inputs")
  dec_inputs = tf.keras.Input(shape=(None,), name="dec_inputs")

  if cnn:
      x = tf.expand_dims(inputs, axis=-1)

      #block 1
      x = tf.keras.layers.Conv2D(32, (3,3), padding='same', activation='relu')(x)
      x = tf.keras.layers.Conv2D(32, (3,3), padding='same', activation='relu')(x)
      x = tf.keras.layers.MaxPool2D(pool_size=(2,2), strides=(2,2))(x)

      #block2
      x = tf.keras.layers.Conv2D(64, (3,3), padding='same', activation='relu')(x)
      x = tf.keras.layers.Conv2D(64, (3,3), padding='same', activation='relu')(x)
      x = tf.keras.layers.MaxPool2D(pool_size=(2,2), strides=(2,2))(x)

      # Merge the pooled feature axis and the 64 channels into one feature vector per step.
      x = tf.keras.layers.Reshape((-1,(d_spec//4)*64))(x)
      # Pool the raw spectrogram by the same factor of 4 along time so the padding
      # mask lines up with the CNN output.
      inputs_strided = tf.keras.layers.MaxPool1D(pool_size=4, strides=4)(inputs)
  else:
      x = inputs
      inputs_strided = inputs

  x = tf.keras.layers.Dense(d_model)(x)

  # 1 for time steps that contain at least one non-zero bin, 0 for all-zero (padding)
  # steps. Testing for "any non-zero bin" avoids truncating a float sum to int, which
  # would also flag real frames whose bins sum to a value in (-1, 1) as padding.
  inputs_masks = tf.dtypes.cast(
      tf.math.reduce_any(
          tf.math.not_equal(inputs_strided, 0),
          axis=2,
          keepdims=False,
      ), tf.int32)

  #creating padding mask
  enc_padding_mask = tf.keras.layers.Lambda(create_padding_mask, output_shape=(1, 1, None),name='enc_padding_mask')(inputs_masks)

  # mask the future tokens for decoder inputs at the 1st attention block
  look_ahead_mask = tf.keras.layers.Lambda(create_look_ahead_mask,output_shape=(1, None, None),name='look_ahead_mask')(dec_inputs)

  # mask the encoder outputs for the 2nd attention block
  dec_padding_mask = tf.keras.layers.Lambda(create_padding_mask, output_shape=(1, 1, None),name='dec_padding_mask')(inputs_masks)

  enc_outputs = encoder(
      maximum_position_encoding=maximum_position_encoding,
      num_layers=num_layers_enc,
      units=units,
      d_model=d_model,
      num_heads=num_heads,
      dropout=dropout,
  )(inputs=[x, enc_padding_mask])

  dec_outputs = decoder(
      vocab_size=vocab_size,
      maximum_position_encoding=maximum_position_encoding,
      num_layers=num_layers_dec,
      units=units,
      d_model=d_model,
      num_heads=num_heads,
      dropout=dropout,
  )(inputs=[dec_inputs, enc_outputs, look_ahead_mask, dec_padding_mask])

  # Project the decoder states to vocabulary logits.
  outputs = tf.keras.layers.Dense(units=vocab_size, name="outputs")(dec_outputs)

  return tf.keras.Model(inputs=[inputs, dec_inputs],
                        outputs=outputs,
                        name=name)

In [None]:
def loss_function(y_true, y_pred):
  """Label-smoothed cross-entropy averaged over the non-padding target tokens.

  Args:
    y_true: target token ids; id 0 marks padding and is excluded from the loss.
    y_pred: logits over the vocabulary.

  Returns:
    Sum of the per-token losses at non-padding positions divided by their count.
  """
  # Soft targets: neighbourhood smoothing or plain label smoothing, depending on the config.
  if C.NEIGHBORHOOD_SMOOTHING:
      soft_targets = neighborhood_smoothing(y_true)
  else:
      soft_targets = tf.one_hot(tf.cast(y_true, tf.int32), C.VOCAB_SIZE,
                                on_value=1-C.LABEL_SMOOTHING,
                                off_value=C.LABEL_SMOOTHING/C.VOCAB_SIZE)

  cross_entropy = tf.keras.losses.CategoricalCrossentropy(from_logits=True, reduction='none')
  token_loss = cross_entropy(soft_targets , y_pred)

  # 1.0 at real tokens, 0.0 at padding.
  not_padding = tf.cast(tf.not_equal(y_true, 0), tf.float32)
  n_tokens = tf.math.count_nonzero(not_padding, dtype=tf.float32)

  masked_loss = tf.multiply(token_loss, not_padding)
  return tf.math.reduce_sum(masked_loss)/n_tokens

def neighborhood_smoothing(y_true):
    """Build soft targets that give part of the probability mass to neighbouring labels.

    At every position the true label gets 1 - LABEL_SMOOTHING, the labels one step
    away along axis 1 get LABEL_SMOOTHING/3 and the labels two steps away get
    LABEL_SMOOTHING/6; where several of them hit the same class the largest value
    is kept. Padding labels (0) contribute nothing. Shifts use tf.roll, so they
    wrap around the ends of the sequence.
    """
    labels = tf.cast(y_true, tf.int32)
    # -1 is outside the one-hot range, so padding positions become all-zero rows.
    labels = tf.where(labels!=0, labels, -1)

    # (shift along axis 1, mass given to the label found at that shift)
    neighbours = [
        (-2, C.LABEL_SMOOTHING/6),
        (-1, C.LABEL_SMOOTHING/3),
        (1, C.LABEL_SMOOTHING/3),
        (2, C.LABEL_SMOOTHING/6),
    ]

    candidates = [tf.one_hot(labels, C.VOCAB_SIZE, on_value=1-C.LABEL_SMOOTHING)]
    for shift, mass in neighbours:
        shifted = tf.roll(labels, shift, axis=1)
        candidates.append(tf.one_hot(shifted, C.VOCAB_SIZE, on_value=mass))

    # Element-wise maximum over all candidate distributions.
    return tf.math.reduce_max(candidates, axis=0)

def accuracy(y_true, y_pred):
  """Token accuracy over the non-padding positions.

  Args:
    y_true: target token ids; id 0 marks padding.
    y_pred: logits over the vocabulary.

  Returns:
    Number of non-padding positions whose argmax prediction equals the target,
    divided by the number of non-padding positions (0 when there are none).
  """
  y_true = tf.cast(y_true, tf.float32)
  predicted_ids = tf.cast(tf.argmax(y_pred, axis=-1), tf.float32)

  not_padding = tf.cast(tf.not_equal(y_true, 0), tf.float32)
  matches = tf.cast(tf.math.equal(y_true, predicted_ids), tf.float32)

  # Mask the matches as well: otherwise a padded position where the model happens to
  # predict id 0 is counted as correct although it is excluded from the denominator.
  n_correct = tf.math.reduce_sum(matches * not_padding)
  n_tokens = tf.math.reduce_sum(not_padding)

  return tf.math.divide_no_nan(n_correct, n_tokens)

class CustomSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
  """Learning-rate schedule with linear warm-up followed by inverse-square-root decay.

  lr(step) = k * d_model^-0.5 * min(step^-0.5, step * warmup_steps^-1.5)
  """

  def __init__(self, d_model, warmup_steps,k):
    super(CustomSchedule, self).__init__()

    # Plain values kept for get_config().
    self._d_model_value = d_model

    self.d_model = tf.cast(d_model, tf.float32)
    self.warmup_steps = warmup_steps
    self.k = k

  def __call__(self, step):
    # The optimizer may pass an integer step; rsqrt needs a floating-point input.
    step = tf.cast(step, tf.float32)

    arg1 = tf.math.rsqrt(step)
    arg2 = step * (self.warmup_steps**-1.5)

    return self.k*tf.math.rsqrt(self.d_model) * tf.math.minimum(arg1, arg2)

  def get_config(self):
    """Return the constructor arguments so the schedule can be serialised."""
    return {
        'd_model': self._d_model_value,
        'warmup_steps': self.warmup_steps,
        'k': self.k,
    }


In [None]:
def create_model():
    """Build and compile the transformer described by the global config `C`.

    Returns:
        The compiled model: Adam driven by `CustomSchedule`, `loss_function` as the
        loss and `accuracy` as the metric.
    """
    hyperparameters = dict(
        vocab_size=C.VOCAB_SIZE,
        maximum_position_encoding=C.MAX_POSITION_ENCODING,
        num_layers_enc=C.N_ENC,
        num_layers_dec=C.N_DEC,
        units=C.UNITS,
        d_spec=C.D_SPEC,
        d_model=C.D_MODEL,
        num_heads=C.NUM_HEADS,
        dropout=C.DROPOUT,
        cnn=C.CNN,
    )
    net = transformer(**hyperparameters)

    schedule = CustomSchedule(C.D_MODEL, C.WARMUP_STEPS, C.K)
    adam = tf.keras.optimizers.Adam(schedule, beta_1=C.BETA1, beta_2=C.BETA2, epsilon=C.EPS)
    net.compile(optimizer=adam, loss=loss_function, metrics=[accuracy])

    return net


# Training

In [None]:
# Start from a clean Keras state before (re)building the model.
tf.keras.backend.clear_session()

# With a TPU the model is created inside the distribution strategy scope.
if not C.TPU:
    model = create_model()
else:
    with strategy.scope():
        model = create_model()

model.summary()

In [None]:
path = C.CALLBACKS_PATH

# makedirs also creates missing parent directories and tolerates an existing directory.
os.makedirs(path, exist_ok=True)

# Training log, best-only weights ('checkpoints.h5') and one weights file per epoch.
callbacks = [tf.keras.callbacks.CSVLogger(os.path.join(path,'logs.csv'),append=True),
            tf.keras.callbacks.ModelCheckpoint(os.path.join(path,'checkpoints.h5'), save_weights_only=True, save_best_only=True),
            tf.keras.callbacks.ModelCheckpoint(os.path.join(path,'checkpoints_{epoch:04d}.h5'), save_weights_only=True)]

# `epoch` is always defined, so the training cell can use it as initial_epoch.
epoch = 0
saved_model_path = os.path.join(path,C.SAVED_MODEL)

if os.path.isfile(saved_model_path) and C.LOAD:
    model.load_weights(saved_model_path)
    print ('Checkpoint restored')
    if C.TRAIN:
        # Per-epoch checkpoints are named 'checkpoints_<epoch>.h5'. The best-only file
        # 'checkpoints.h5' carries no epoch number, so training restarts its epoch and
        # step counters at 0 in that case instead of failing on the missing suffix.
        stem = os.path.splitext(C.SAVED_MODEL)[0]
        suffix = stem.rpartition('_')[2] if '_' in stem else ''
        if suffix.isdigit():
            epoch = int(suffix)
        else:
            print("No epoch number in '{}': resuming from epoch 0".format(C.SAVED_MODEL))

        # Batches are built with drop_remainder=True, so an epoch is exactly
        # len(spec_audio_train)//BATCH_SIZE optimizer steps.
        steps_per_epoch = len(spec_audio_train)//C.BATCH_SIZE
        model.optimizer.iterations.assign(epoch*steps_per_epoch)
        print("{} iterations trained".format(model.optimizer.iterations.numpy()))
        print("{} epochs trained".format(epoch))
else:
    print('first training')


In [None]:
# Train only when C.TRAIN is set. `epoch` comes from the checkpoint-restore cell and is
# passed as initial_epoch so that epoch numbering continues from it.
if C.TRAIN:
    model.fit(dataset_train, epochs=C.EPOCHS, validation_data=dataset_val, callbacks=callbacks, initial_epoch=epoch)

# Evaluation

In [None]:
def gready_search_decoding(input_encoder, model, tokenizer, config, verbose=0):
    """Greedy decoding: repeatedly append the most probable next token.

    Args:
        input_encoder: one spectrogram (no batch axis).
        model: model called as model(inputs=[encoder_input, decoder_input]).
        tokenizer: tokenizer used to turn the token ids back into text.
        config: configuration providing START_TOKEN, END_TOKEN, MAX_LENGTH,
            VOCAB_SIZE and ENCODING.
        verbose: when truthy, the partial transcript is rewritten on stdout at every step.

    Returns:
        The decoded transcript as a string.
    """
    # The two largest ids are the start and end tokens; everything below is text.
    first_special_id = config.VOCAB_SIZE-2
    is_subword = config.ENCODING == 'subword'

    def text_ids(ids):
        return [j for j in ids if j < first_special_id]

    # Add the batch axis expected by the model.
    input_encoder = tf.expand_dims(input_encoder, axis=0)
    input_decoder = tf.expand_dims(config.START_TOKEN, axis=0)

    for _ in range(config.MAX_LENGTH):
        if verbose:
            partial = text_ids(input_decoder.numpy()[0])
            if is_subword:
                sys.stdout.write("\r{}".format( tokenizer.decode(partial) ))
            else:
                sys.stdout.write("\r{}".format( tokenizer.sequences_to_texts([partial]) ))

        logits = model(inputs=[input_encoder, input_decoder], training=False)
        # Only the prediction for the last position is needed.
        predicted_id = tf.cast(tf.argmax(logits[:, -1:, :], axis=-1), tf.int32)

        if tf.equal(predicted_id, config.END_TOKEN):
            break

        input_decoder = tf.concat([input_decoder, predicted_id], axis=-1)

    if verbose:
        print()

    decoded_ids = text_ids(tf.squeeze(input_decoder, axis=0).numpy())
    if is_subword:
        return tokenizer.decode(decoded_ids)
    # Keep every second character of the joined character tokens.
    return tokenizer.sequences_to_texts([decoded_ids])[0][::2]


def beam_search_decoding(input_encoder, model, tokenizer, config, verbose=0):
    """Beam-search decoding with a repetition constraint.

    At every step each hypothesis is extended with its BEAM_SIZE most probable next
    tokens and the BEAM_SIZE best extensions (by summed log-probability) are kept.
    An extension whose last MAX_REP tokens are all identical gets a score of -inf.
    Decoding stops when the best hypothesis ends with the end token or after
    MAX_LENGTH steps.

    Args:
        input_encoder: one spectrogram (no batch axis).
        model: model called as model(inputs=[encoder_input, decoder_input]).
        tokenizer: tokenizer used to turn the token ids back into text.
        config: configuration providing START_TOKEN, END_TOKEN, MAX_LENGTH,
            VOCAB_SIZE, ENCODING, BEAM_SIZE and MAX_REP.
        verbose: when truthy, all current hypotheses are printed at every step.

    Returns:
        The text of the best hypothesis, cut at its first end token.
    """
    # The two largest ids are the start and end tokens; everything below is text.
    first_special_id = config.VOCAB_SIZE-2

    input_encoder = tf.expand_dims(input_encoder, axis=0)
    input_decoder = tf.expand_dims(config.START_TOKEN, axis=0)

    k_scores = [0.0]

    for i in range(config.MAX_LENGTH):
        if verbose:
            print('\nStep', i)
            for k in range(input_decoder.shape[0]):
                hypothesis_ids = [j for j in input_decoder.numpy()[k] if j < first_special_id]
                if config.ENCODING == 'subword':
                    print( tokenizer.decode(hypothesis_ids) )
                else:
                    print( tokenizer.sequences_to_texts([hypothesis_ids]) )
        predictions = model(inputs=[input_encoder, input_decoder], training=False)
        predictions = predictions[:, -1:, :]
        # log_softmax is the numerically stable form of log(softmax(x)): it does not
        # produce -inf when a probability underflows to 0.
        values, indices = tf.math.top_k(tf.nn.log_softmax(predictions), config.BEAM_SIZE)

        sequences = []
        scores = []

        for k in range(input_decoder.shape[0]):
            for b in range(config.BEAM_SIZE):
                sequences.append(tf.concat([input_decoder[k], [indices[k,0,b]]], axis=0))
                # Rule out extensions whose last MAX_REP tokens are all the same.
                if i>=config.MAX_REP and len(tf.unique(sequences[-1][-config.MAX_REP:])[0])==1:
                    scores.append(k_scores[k] - float('inf'))
                else:
                    scores.append(k_scores[k] + values[k,0,b])

        # Keep the BEAM_SIZE best extensions, best first.
        values, indices = tf.math.top_k(scores, config.BEAM_SIZE)

        k_scores = []
        input_decoder = []
        for k in range(config.BEAM_SIZE):
            k_scores.append(values[k])
            input_decoder.append(sequences[indices[k]])
        input_decoder = tf.stack(input_decoder)

        # After the first step there are BEAM_SIZE hypotheses: replicate the encoder input once.
        if input_encoder.shape[0] == 1:
            input_encoder = tf.repeat(input_encoder, config.BEAM_SIZE, axis=0)

        if tf.equal(input_decoder[0,-1], config.END_TOKEN):
            break

    if verbose:
        print()

    # Hypotheses that emitted the end token earlier keep being extended by the loop
    # above, so the best one may contain tokens after an end token; those are not part
    # of the transcript and are cut off here.
    end_id = config.END_TOKEN[0]
    best_ids = []
    for token in input_decoder[0].numpy():
        if token == end_id:
            break
        if token < first_special_id:
            best_ids.append(token)

    if config.ENCODING == 'subword':
        return tokenizer.decode(best_ids)
    else:
        return tokenizer.sequences_to_texts([best_ids])[0][::2]




In [None]:
# Decode one randomly chosen training sample with beam search and compare it with its reference.
s = tf.random.uniform([], 0, len(spec_audio_train), dtype=tf.int32)
sample_spec = spec_audio_train[s]

print('Input Spec : {}'.format(sample_spec.shape))
plt.matshow(tf.transpose(tf.cast(sample_spec, tf.float32)), origin='lower')
plt.show()

pred = beam_search_decoding(sample_spec, model, tokenizer, C, verbose=1)

# Reference transcript: drop the start/end tokens before converting back to text.
reference_ids = [i for i in tokenized_speech_train[s] if i < C.VOCAB_SIZE-2]
if C.ENCODING == 'subword':
    reference = tokenizer.decode(reference_ids)
else:
    reference = tokenizer.sequences_to_texts([reference_ids])[0][::2]
print('Speech : \n{}'.format( reference ))
print('Prediction : \n{}'.format( pred ))

In [None]:
# Decode one randomly chosen validation sample with beam search and compare it with its reference.
s = tf.random.uniform([], 0, len(spec_audio_val), dtype=tf.int32)
sample_spec = spec_audio_val[s]

print('Input Spec : {}'.format(sample_spec.shape))
plt.matshow(tf.transpose(tf.cast(sample_spec, tf.float32)), origin='lower')
plt.show()

pred = beam_search_decoding(sample_spec, model, tokenizer, C, verbose=0)

# Reference transcript: drop the start/end tokens before converting back to text.
reference_ids = [i for i in tokenized_speech_val[s] if i < C.VOCAB_SIZE-2]
if C.ENCODING == 'subword':
    reference = tokenizer.decode(reference_ids)
else:
    reference = tokenizer.sequences_to_texts([reference_ids])[0][::2]
print('Speech : \n{}'.format( reference ))
print('Prediction : \n{}'.format( pred ))