In [39]:
import os, re
from glob import glob
import numpy as np
import tensorflow as tf
from tensorflow import keras
from keras import layers
import soundfile as sf
import jiwer

os.environ["KERAS_BACKEND"] = "tensorflow"
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"

In [62]:
USE_EXTRA_DATA = True
CACHE_DIR = "."
DATA_SUBDIR = "datasets"

BATCH_SIZE = 8
VAL_BATCH  = 8
EPOCHS = 30
MAX_TARGET_LEN = 200
AUDIO_PAD_LEN = 2754            # ~10s
FFT_LENGTH = 256
HOP = 80
WIN = 200
FEAT_DIM = FFT_LENGTH // 2 + 1  # 129 for fft_length=256
START_TOKEN_IDX = 2             # '<'
END_TOKEN_IDX   = 3             # '>'

In [2]:
pip install tensorflow-io


Note: you may need to restart the kernel to use updated packages.


ERROR: Could not find a version that satisfies the requirement tensorflow-io (from versions: none)
ERROR: No matching distribution found for tensorflow-io


In [1]:
pip install -U tensorflow


Collecting tensorflow
  Using cached tensorflow-2.20.0-cp312-cp312-win_amd64.whl.metadata (4.6 kB)
Using cached tensorflow-2.20.0-cp312-cp312-win_amd64.whl (331.9 MB)
Installing collected packages: tensorflow
  Attempting uninstall: tensorflow
    Found existing installation: tensorflow 2.19.0
    Uninstalling tensorflow-2.19.0:
      Successfully uninstalled tensorflow-2.19.0
Successfully installed tensorflow-2.20.0
Note: you may need to restart the kernel to use updated packages.


In [2]:
pip install soundfile


Note: you may need to restart the kernel to use updated packages.


# Download the dataset & extract the dataset

In [None]:
train100_dir = f"./{DATA_SUBDIR}/train-clean-100/LibriSpeech/train-clean-100"
train360_url = "https://www.openslr.org/resources/12/train-clean-360.tar.gz"
train500_url = "https://www.openslr.org/resources/12/train-other-500.tar.gz"


if USE_EXTRA_DATA:
    keras.utils.get_file(fname="train-clean-360.tar.gz", origin=train360_url,
                         extract=True, cache_dir=CACHE_DIR, cache_subdir=DATA_SUBDIR)
    keras.utils.get_file(fname="train-other-500.tar.gz", origin=train500_url,
                         extract=True, cache_dir=CACHE_DIR, cache_subdir=DATA_SUBDIR)



In [42]:
candidates = glob(f"./{DATA_SUBDIR}/**/LibriSpeech/**/train-*", recursive=True)
all_data_dirs = []
if os.path.isdir(train100_dir):
    all_data_dirs.append(train100_dir)
for c in candidates:
    if c not in all_data_dirs:
        all_data_dirs.append(c)
print("Using data folders:", all_data_dirs)

# ---------------- Transcripts -----------
pattern_wav_name = re.compile(r"([^/\\\.]+)")
id_to_text = {}
for folder in all_data_dirs:
    trans_files = glob(f"{folder}/**/*.trans.txt", recursive=True)
    for trans_path in trans_files:
        with open(trans_path, "r", encoding="utf-8") as f:
            for line in f:
                parts = line.strip().split(" ", 1)
                if len(parts) == 2:
                    utt_id, text = parts
                    id_to_text[utt_id] = text.lower()


Using data folders: ['./datasets\\train-clean-100_extracted\\LibriSpeech\\train-clean-100', './datasets\\train-clean-360_extracted\\LibriSpeech\\train-clean-360', './datasets\\train-other-500_extracted\\LibriSpeech\\train-other-500']


# 3. MAP AUDIO FILES TO TRANSCRIPTS

In [43]:
def get_data(audio_files, id_to_text, maxlen=200):
    data = []
    for f in audio_files:
        utt_id = os.path.splitext(os.path.basename(f))[0]
        if utt_id in id_to_text and len(id_to_text[utt_id]) < maxlen:
            data.append({"audio": f, "text": id_to_text[utt_id]})
    return data

flacs = glob(f"{DATA_SUBDIR}/**/*.flac", recursive=True)
print("Total FLAC files:", len(flacs))
data = get_data(flacs, id_to_text, maxlen=MAX_TARGET_LEN)
print("Usable examples:", len(data))


Total FLAC files: 281241
Usable examples: 162743


# 4. VECTORIZER

In [63]:
class VectorizeChar:
    def __init__(self, max_len=MAX_TARGET_LEN):
        self.vocab = (
            ["-", "#", "<", ">"]
            + [chr(i + 96) for i in range(1, 27)]
            + [" ", ".", ",", "?", "'"]
        )
        self.max_len = max_len
        self.char_to_idx = {ch: i for i, ch in enumerate(self.vocab)}

    def __call__(self, text):
        text = text.lower()[: self.max_len - 2]
        text = "<" + text + ">"
        pad_len = self.max_len - len(text)
        return [self.char_to_idx.get(ch, 1) for ch in text] + [0] * pad_len

    def get_vocabulary(self):
        return self.vocab

vectorizer = VectorizeChar(MAX_TARGET_LEN)
idx_to_char = vectorizer.get_vocabulary()
vocab_size = len(idx_to_char)
print("Vocab size:", vocab_size)


Vocab size: 35


# 5. AUDIO PROCESSING


In [64]:
def decode_flac_py(path):
    path = path.numpy().decode("utf-8")
    audio, sr = sf.read(path, dtype="float32")
    if audio.ndim > 1:
        audio = audio.mean(axis=1)
    return audio

@tf.function(input_signature=[tf.TensorSpec(shape=(), dtype=tf.string)])
def path_to_audio_graph(path):
    audio = tf.py_function(func=decode_flac_py, inp=[path], Tout=tf.float32)
    audio.set_shape([None])
    stfts = tf.signal.stft(audio, frame_length=WIN, frame_step=HOP, fft_length=FFT_LENGTH)
    x = tf.math.pow(tf.abs(stfts), 0.5)   # [T, FEAT_DIM]
    x.set_shape([None, FEAT_DIM])
    means = tf.reduce_mean(x, axis=1, keepdims=True)
    stds = tf.math.reduce_std(x, axis=1, keepdims=True)
    x = (x - means) / (stds + 1e-9)
    paddings = tf.constant([[0, AUDIO_PAD_LEN], [0, 0]])
    x = tf.pad(x, paddings)[:AUDIO_PAD_LEN, :]
    x.set_shape([AUDIO_PAD_LEN, FEAT_DIM])
    return x

def create_tf_dataset(samples, bs=BATCH_SIZE, shuffle=True, augment=False):
    flist = [ex["audio"] for ex in samples]
    texts = [vectorizer(ex["text"]) for ex in samples]
    ds_paths = tf.data.Dataset.from_tensor_slices(tf.constant(flist, dtype=tf.string))
    ds_texts = tf.data.Dataset.from_tensor_slices(tf.constant(texts, dtype=tf.int32))
    ds_audio = ds_paths.map(path_to_audio_graph, num_parallel_calls=tf.data.AUTOTUNE)
    ds = tf.data.Dataset.zip((ds_audio, ds_texts))
    ds = ds.map(lambda a, t: {"source": a, "target": t}, num_parallel_calls=tf.data.AUTOTUNE)
    if shuffle:
        ds = ds.shuffle(4096, reshuffle_each_iteration=True)
    ds = ds.batch(bs).prefetch(tf.data.AUTOTUNE)
    return ds



# 6. TRAIN/VAL SPLIT


In [65]:
split = int(len(data) * 0.9)
train_data, test_data = data[:split], data[split:]
print("Train examples:", len(train_data), "Val examples:", len(test_data))

train_ds = create_tf_dataset(train_data, bs=BATCH_SIZE, shuffle=True)
val_ds = create_tf_dataset(test_data, bs=VAL_BATCH, shuffle=False)

# Sanity shapes
for b in train_ds.take(1):
    print("Batch 'source' shape:", b["source"].shape)  # (B, AUDIO_PAD_LEN, FEAT_DIM)
    print("Batch 'target' shape:", b["target"].shape)  # (B, MAX_TARGET_LEN)


Train examples: 146468 Val examples: 16275
Batch 'source' shape: (8, 2754, 129)
Batch 'target' shape: (8, 200)


# Define the Transformer input layer

In [66]:
class TokenEmbedding(layers.Layer):
    def __init__(self, num_vocab, maxlen, num_hid):
        super().__init__()
        self.emb = layers.Embedding(num_vocab, num_hid)
        self.pos_emb = layers.Embedding(input_dim=maxlen, output_dim=num_hid)

    def call(self, x):
        L = tf.shape(x)[-1]
        x = self.emb(x)
        pos = self.pos_emb(tf.range(L))
        return x + pos

In [67]:
class SpeechFeatureEmbedding(layers.Layer):
    def __init__(self, num_hid=128):
        super().__init__()
        self.conv1 = layers.Conv1D(num_hid, 3, strides=2, padding='same', activation='relu')
        self.conv2 = layers.Conv1D(num_hid, 3, strides=2, padding='same', activation='relu')
        self.conv3 = layers.Conv1D(num_hid, 3, strides=2, padding='same', activation='relu')

    def call(self, x):
        # x: [B, T, F]
        return self.conv3(self.conv2(self.conv1(x)))


# Transformer Encoder Layer

In [68]:
class TransformerEncoder(layers.Layer):
    def __init__(self, embed_dim, num_heads, feed_forward_dim, rate=0.1):
        super().__init__()
        self.att = layers.MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim)
        self.ffn = keras.Sequential([layers.Dense(feed_forward_dim, activation='relu'), layers.Dense(embed_dim)])
        self.ln1 = layers.LayerNormalization(epsilon=1e-6)
        self.ln2 = layers.LayerNormalization(epsilon=1e-6)
        self.do1 = layers.Dropout(rate)
        self.do2 = layers.Dropout(rate)

    def call(self, x, training=False):
        attn = self.att(x, x)
        attn = self.do1(attn, training=training)
        out1 = self.ln1(x + attn)
        ffn = self.ffn(out1)
        ffn = self.do2(ffn, training=training)
        return self.ln2(out1 + ffn)


# Transformer Decoder Layer

In [69]:
class TransformerDecoder(layers.Layer):
    def __init__(self, embed_dim, num_heads, feed_forward_dim, rate=0.1):
        super().__init__()
        self.ln1 = layers.LayerNormalization(epsilon=1e-6)
        self.ln2 = layers.LayerNormalization(epsilon=1e-6)
        self.ln3 = layers.LayerNormalization(epsilon=1e-6)
        self.self_att = layers.MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim)
        self.enc_att  = layers.MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim)
        self.do1 = layers.Dropout(rate)
        self.do2 = layers.Dropout(rate)
        self.do3 = layers.Dropout(rate)
        self.ffn = keras.Sequential([layers.Dense(feed_forward_dim, activation='relu'), layers.Dense(embed_dim)])

    def _causal_mask(self, bs, L, dtype):
        i = tf.range(L)[:, None]
        j = tf.range(L)[None, :]
        mask = tf.cast(i >= j, dtype)
        mask = tf.reshape(mask, [1, L, L])
        return tf.tile(mask, [bs, 1, 1])

    def call(self, enc_out, target, training=False):
        bs = tf.shape(target)[0]
        L = tf.shape(target)[1]
        causal = self._causal_mask(bs, L, tf.bool)
        tgt_att = self.self_att(target, target, attention_mask=causal)
        tgt_att = self.do1(tgt_att, training=training)
        y = self.ln1(target + tgt_att)
        enc_att = self.enc_att(y, enc_out)
        enc_att = self.do2(enc_att, training=training)
        y2 = self.ln2(y + enc_att)
        ffn = self.ffn(y2)
        ffn = self.do3(ffn, training=training)
        return self.ln3(y2 + ffn)


# Complete the Transformer model


In [70]:
@tf.keras.utils.register_keras_serializable(package="custom")
class TransformerASR(tf.keras.Model):
    def __init__(self,
                 num_hid=128, num_heads=4, num_feed_forward=512,
                 source_maxlen=AUDIO_PAD_LEN, target_maxlen=MAX_TARGET_LEN,
                 num_layers_enc=4, num_layers_dec=1, num_classes=vocab_size, dropout_rate=0.1):
        super().__init__()
        self.loss_metric = keras.metrics.Mean(name="loss")
        self.num_layers_dec = num_layers_dec
        self.target_maxlen = target_maxlen
        self.num_classes = num_classes

        self.enc_input = SpeechFeatureEmbedding(num_hid=num_hid)
        self.dec_input = TokenEmbedding(num_vocab=num_classes, maxlen=target_maxlen, num_hid=num_hid)

        enc_blocks = [TransformerEncoder(num_hid, num_heads, num_feed_forward, dropout_rate) for _ in range(num_layers_enc)]
        self.encoder = keras.Sequential([self.enc_input] + enc_blocks)

        for i in range(num_layers_dec):
            setattr(self, f"dec_layer_{i}", TransformerDecoder(num_hid, num_heads, num_feed_forward, dropout_rate))

        # classifier: if mixed_precision, use float32 for final logits via dtype
        self.classifier = layers.Dense(num_classes, dtype='float32')

    def decode(self, enc_out, target, training=False):
        y = self.dec_input(target)
        for i in range(self.num_layers_dec):
            y = getattr(self, f"dec_layer_{i}")(enc_out, y, training=training)
        return y

    def call(self, inputs, training=False):
        source, target = inputs
        x = self.encoder(source, training=training)
        y = self.decode(x, target, training=training)
        return self.classifier(y)

    def compute_loss(self, y, y_pred, sample_weight=None):
        loss_fn = tf.keras.losses.CategoricalCrossentropy(from_logits=True)
        return loss_fn(y, y_pred, sample_weight=sample_weight)

    def train_step(self, batch):
        source = batch["source"]
        target = batch["target"]
        dec_in = target[:, :-1]
        dec_tgt = target[:, 1:]
        with tf.GradientTape() as tape:
            logits = self([source, dec_in], training=True)
            one_hot = tf.one_hot(dec_tgt, depth=self.num_classes)
            mask = tf.cast(tf.not_equal(dec_tgt, 0), tf.float32)
            loss = self.compute_loss(one_hot, logits, sample_weight=mask)
        grads = tape.gradient(loss, self.trainable_variables)
        # gradient clipping already in optimizer via clipnorm
        self.optimizer.apply_gradients(zip(grads, self.trainable_variables))
        self.loss_metric.update_state(loss)
        return {"loss": self.loss_metric.result()}

    def test_step(self, batch):
        source = batch["source"]
        target = batch["target"]
        dec_in = target[:, :-1]
        dec_tgt = target[:, 1:]
        logits = self([source, dec_in], training=False)
        one_hot = tf.one_hot(dec_tgt, depth=self.num_classes)
        mask = tf.cast(tf.not_equal(dec_tgt, 0), tf.float32)
        loss = self.compute_loss(one_hot, logits, sample_weight=mask)
        self.loss_metric.update_state(loss)
        return {"loss": self.loss_metric.result()}

    def generate_greedy(self, source, start_idx):
        bs = tf.shape(source)[0]
        enc = self.encoder(source, training=False)
        dec = tf.ones((bs, 1), dtype=tf.int32) * start_idx
        for _ in range(self.target_maxlen - 1):
            y = self.decode(enc, dec, training=False)
            logits = self.classifier(y)
            next_tok = tf.argmax(logits[:, -1, :], axis=-1, output_type=tf.int32)[:, None]
            dec = tf.concat([dec, next_tok], axis=1)
        return dec


# Callbacks to display predictions


In [71]:
import jiwer
import tensorflow as tf
from keras.callbacks import Callback
class DisplayOutputs(keras.callbacks.Callback):
    def __init__(self, sample_batch, idx_to_char, start_token_idx=START_TOKEN_IDX, end_token_idx=END_TOKEN_IDX):
        super().__init__()
        self.batch = sample_batch
        self.idx_to_char = idx_to_char
        self.start = start_token_idx
        self.end = end_token_idx

    def ids_to_text(self, ids):
        s = []
        for i in ids:
            i = int(i)
            if i == self.end:
                break
            ch = self.idx_to_char[i]
            if ch in ["<", ">", "-", "#"]:
                continue
            s.append(ch)
        return "".join(s)

    def on_epoch_end(self, epoch, logs=None):
        # print only every 5 epochs
        if epoch % 5 != 0:
            return
        batch = self.batch
        src = batch["source"]
        tgt = batch["target"].numpy()
        preds_ids = self.model.generate_greedy(src, self.start).numpy()
        refs_txt, preds_txt = [], []
        for i in range(preds_ids.shape[0]):
            rt = self.ids_to_text(tgt[i])
            pt = self.ids_to_text(preds_ids[i])
            refs_txt.append(rt); preds_txt.append(pt)
            print("REF :", rt)
            print("PRED:", pt)
        cer = jiwer.cer(refs_txt, preds_txt)
        wer = jiwer.wer(refs_txt, preds_txt)
        print(f"Epoch {epoch} - CER: {cer:.4f}, WER: {wer:.4f}")


# Learning rate schedule


In [72]:
class CustomSchedule(keras.optimizers.schedules.LearningRateSchedule):
    def __init__(self, init_lr=1e-5, lr_after_warmup=5e-4, final_lr=1e-5, warmup_epochs=10, decay_epochs=3, steps_per_epoch=27):
        super().__init__()
        self.init_lr = float(init_lr)
        self.lr_after_warmup = float(lr_after_warmup)
        self.final_lr = float(final_lr)
        self.warmup_epochs = int(warmup_epochs)
        self.decay_epochs = int(decay_epochs)
        self.steps_per_epoch = int(steps_per_epoch)

    def calculate_lr(self, epoch):
        warmup_lr = self.init_lr + ((self.lr_after_warmup - self.init_lr) / max(1, (self.warmup_epochs - 1))) * epoch
        decay_lr = tf.math.maximum(self.final_lr,
                                  self.lr_after_warmup - (epoch - self.warmup_epochs) * (self.lr_after_warmup - self.final_lr) / self.decay_epochs)
        return tf.math.minimum(warmup_lr, decay_lr)

    def __call__(self, step):
        epoch = step // self.steps_per_epoch
        epoch = tf.cast(epoch, tf.float32)
        return self.calculate_lr(epoch)

    def get_config(self):
        return {
            "init_lr": self.init_lr,
            "lr_after_warmup": self.lr_after_warmup,
            "final_lr": self.final_lr,
            "warmup_epochs": self.warmup_epochs,
            "decay_epochs": self.decay_epochs,
            "steps_per_epoch": self.steps_per_epoch,
        }


In [73]:
steps = int(tf.data.experimental.cardinality(train_ds).numpy())
lr_schedule = CustomSchedule(steps_per_epoch=max(1, steps))

# try AdamW; fallback to Adam if not available
try:
    optimizer = tf.keras.optimizers.AdamW(learning_rate=lr_schedule, weight_decay=1e-4, clipnorm=1.0)
except Exception:
    optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule, clipnorm=1.0)

loss_fn = tf.keras.losses.CategoricalCrossentropy(from_logits=True, label_smoothing=0.1)

model = TransformerASR(
    num_hid=128, num_heads=4, num_feed_forward=512,
    num_layers_enc=4, num_layers_dec=1, num_classes=vocab_size, dropout_rate=0.1
)

model.compile(optimizer=optimizer, loss=loss_fn)

# sample batch for callback
sample_batch = next(iter(val_ds))
display_cb = DisplayOutputs(sample_batch, idx_to_char, start_token_idx=START_TOKEN_IDX, end_token_idx=END_TOKEN_IDX)

early_stop = keras.callbacks.EarlyStopping(monitor='loss', patience=8, restore_best_weights=True)
ckpt = keras.callbacks.ModelCheckpoint("best_asr_model.keras", monitor='loss', save_best_only=True)

history = model.fit(train_ds, validation_data=val_ds, epochs=EPOCHS, callbacks=[display_cb, early_stop, ckpt])

Epoch 1/30
[1m18309/18309[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 295ms/step - loss: 1.7692REF : so you want me to ride with you i replied yes
PRED: and the the the the the the the an the an the the an an the the the an the the the the the the anoure the an an an
REF : i said nothing however and after a time jane spoke the dance was one thing and riding with you is another i did not wish to dance with you but i do wish to ride with you
PRED: and the the the the the the the an an an the the the an the the the an an the the the an the an the the the an the an an anoure the the the the an the an an the an the the the the the anore the
REF : it meant that she cared for me and would some day be mine
PRED: and the the the the the the the an an an the the the an the the the an the the the the the an the
REF : this was comforting if not satisfying and loosened my tongue jane you know my heart is full of love for you
PRED: and the the the the the the the an the the the an an an t