### Whisper STT model with fine-tuning

In [3]:
import warnings
warnings.filterwarnings("ignore")

In [3]:
from datasets import load_dataset, DatasetDict, Audio
from transformers import WhisperFeatureExtractor, WhisperTokenizer, WhisperProcessor, WhisperForConditionalGeneration
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments
from transformers import pipeline

import IPython.display as ipd
import numpy as np
import random

import torch
from dataclasses import dataclass
from typing import Any, Dict, List, Union
import evaluate

### Load Data

In [1]:
# Load into memory
dataset = DatasetDict()
dataset = load_dataset("data_gtts", data_dir="I:/Repos/STT_FineTune/nats/data_gtts")

print("Train: ", len(dataset['train']))
print("Test: ", len(dataset['test']))

# Resample to 16kHz
dataset = dataset.cast_column('audio', Audio(sampling_rate=16000))
dataset['train'][0]

Using custom data configuration default-9b41ac5bfd8c70c2
Found cached dataset data_gtts (I:/Repos/HFdatasets/data_gtts/default-9b41ac5bfd8c70c2/0.1.0/99611922a2fe30672e990db44b070dc747a16dd2cb691d0d2c33dc670a2e3b68)


  0%|          | 0/2 [00:00<?, ?it/s]

Train:  500
Test:  150


{'audio': {'path': 'data_gtts/train/RYA0QN_Knots_260.mp3',
  'array': array([-3.6807258e-14, -2.4243667e-15,  7.0424528e-14, ...,
         -4.9580830e-08, -1.6381831e-07, -7.4223811e-07], dtype=float32),
  'sampling_rate': 16000},
 'transcription': 'Ryanair Zero Quebec November fly speed Two Six Zero knots'}

In [5]:
# Play a data file
rand_int = random.randint(0, len(dataset["train"]))

print(dataset["train"][rand_int]["transcription"])
ipd.Audio(data=np.asarray(dataset["train"][rand_int]["audio"]["array"]), autoplay=True, rate=16000)

Speedbird Three Zero Three descend Flight Level Four Zero


### Preprocessing Pipeline

In [6]:
# Model path on Huggingface model hub
hf_model = "openai/whisper-tiny.en"

# Load pre-trained feature extractor
feature_extractor = WhisperFeatureExtractor.from_pretrained(hf_model)
# Load pre-trained tokenizer
tokenizer = WhisperTokenizer.from_pretrained(hf_model, language="English", task="transcribe")

# Generate the ASR pipeline
processor = WhisperProcessor.from_pretrained(hf_model, language="English", task="transcribe")

In [7]:
# Check that tokenizer is in the right language and contains processing tokens
input_str = dataset["train"][0]["transcription"]
labels = tokenizer(input_str).input_ids
decoded_with_special = tokenizer.decode(labels, skip_special_tokens=False)
decoded_str = tokenizer.decode(labels, skip_special_tokens=True)

print(f"Input:                 {input_str}")
print(f"Decoded w/ special:    {decoded_with_special}")
print(f"Decoded w/out special: {decoded_str}")
print(f"Are equal:             {input_str == decoded_str}")

Input:                 Ryanair Zero Quebec November fly speed Two Six Zero knots
Decoded w/ special:    <|startoftranscript|><|en|><|transcribe|><|notimestamps|>Ryanair Zero Quebec November fly speed Two Six Zero knots<|endoftext|>
Decoded w/out special: Ryanair Zero Quebec November fly speed Two Six Zero knots
Are equal:             True


### Preprocess data

In [8]:
def prepare_dataset(batch):
    # load and resample audio data from 24 to 16kHz
    audio = batch["audio"]
    # compute log-Mel input features from input audio array 
    batch["input_features"] = feature_extractor(audio["array"], sampling_rate=audio["sampling_rate"]).input_features[0]
    # encode target text to label ids 
    batch["labels"] = tokenizer(batch["transcription"]).input_ids
    return batch

In [9]:
# Apply preprocessing
dataset = dataset.map(prepare_dataset, remove_columns=dataset.column_names["train"])
dataset

Loading cached processed dataset at I:/Repos/HFdatasets/data_gtts/default-9b41ac5bfd8c70c2/0.1.0/99611922a2fe30672e990db44b070dc747a16dd2cb691d0d2c33dc670a2e3b68\cache-071ab40b728590cb.arrow
Loading cached processed dataset at I:/Repos/HFdatasets/data_gtts/default-9b41ac5bfd8c70c2/0.1.0/99611922a2fe30672e990db44b070dc747a16dd2cb691d0d2c33dc670a2e3b68\cache-c9de11d73b5af1d6.arrow


DatasetDict({
    train: Dataset({
        features: ['input_features', 'labels'],
        num_rows: 500
    })
    test: Dataset({
        features: ['input_features', 'labels'],
        num_rows: 150
    })
})

In [10]:
# Define a data collator, which applies padding and converts the data to PyTorch tensors
class DataCollatorSpeechSeq2SeqWithPadding:
    
    def __init__(self, processor: Any):
        self.processor = processor

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # split inputs and labels since they have to be of different lengths and need different padding methods
        # first treat the audio inputs by simply returning torch tensors
        input_features = [{"input_features": feature["input_features"]} for feature in features]
        batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")

        # get the tokenized label sequences
        label_features = [{"input_ids": feature["labels"]} for feature in features]
        # pad the labels to max length
        labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")

        # replace padding with -100 to ignore loss correctly
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)

        # if bos token is appended in previous tokenization step,
        # cut bos token here as it's append later anyways
        if (labels[:, 0] == self.processor.tokenizer.bos_token_id).all().cpu().item():
            labels = labels[:, 1:]

        batch["labels"] = labels

        return batch

data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)

In [11]:
# Define the evaluation metric
metric = evaluate.load("wer")

def compute_metrics(pred):
    pred_ids = pred.predictions
    label_ids = pred.label_ids

    # replace -100 with the pad_token_id
    label_ids[label_ids == -100] = tokenizer.pad_token_id

    # we do not want to group tokens when computing the metrics
    pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=True)

    wer = 100 * metric.compute(predictions=pred_str, references=label_str)

    return {"wer": wer}

### Model Fine-tuning

In [12]:
# Load model
model = WhisperForConditionalGeneration.from_pretrained(hf_model, use_cache=False)

# Set up specialised tokens
model.config.forced_decoder_ids = None#[50257, 50258, 50358, 50362]
model.config.suppress_tokens = []

In [13]:
# Training arguments
repo_name = "checkpoints"

training_args = Seq2SeqTrainingArguments(
    output_dir=repo_name,  # change to a repo name of your choice
    per_device_train_batch_size=8,
    gradient_accumulation_steps=2,  # increase by 2x for every 2x decrease in batch size
    learning_rate=1e-5,
    warmup_steps=10,
    max_steps=300,
    gradient_checkpointing=True,
    fp16=True,
    evaluation_strategy="steps",
    per_device_eval_batch_size=8,
    predict_with_generate=True,
    generation_max_length=50,
    save_steps=100,
    eval_steps=100,
    # logging_steps=100,
    # report_to=["tensorboard"],
    load_best_model_at_end=True,
    metric_for_best_model="wer",
    greater_is_better=False,
    save_total_limit=2,
    push_to_hub=False,
)

In [14]:
# Initialise trainer
trainer = Seq2SeqTrainer(
    args=training_args,
    model=model,
    train_dataset=dataset["train"],
    eval_dataset=dataset["test"],
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    tokenizer=processor.feature_extractor,
)

max_steps is given, it will override any value given in num_train_epochs
Using cuda_amp half precision backend


In [15]:
trainer.train()

***** Running training *****
  Num examples = 500
  Num Epochs = 10
  Instantaneous batch size per device = 8
  Total train batch size (w. parallel, distributed & accumulation) = 16
  Gradient Accumulation steps = 2
  Total optimization steps = 300
  Number of trainable parameters = 37760256


Step,Training Loss,Validation Loss,Wer
100,No log,0.613631,0.15083
200,No log,0.17133,0.075415
300,No log,0.091583,0.075415


***** Running Evaluation *****
  Num examples = 150
  Batch size = 8
Saving model checkpoint to whisper-nats-syn\checkpoint-100
Configuration saved in whisper-nats-syn\checkpoint-100\config.json
Model weights saved in whisper-nats-syn\checkpoint-100\pytorch_model.bin
Feature extractor saved in whisper-nats-syn\checkpoint-100\preprocessor_config.json
***** Running Evaluation *****
  Num examples = 150
  Batch size = 8
Saving model checkpoint to whisper-nats-syn\checkpoint-200
Configuration saved in whisper-nats-syn\checkpoint-200\config.json
Model weights saved in whisper-nats-syn\checkpoint-200\pytorch_model.bin
Feature extractor saved in whisper-nats-syn\checkpoint-200\preprocessor_config.json
***** Running Evaluation *****
  Num examples = 150
  Batch size = 8
Saving model checkpoint to whisper-nats-syn\checkpoint-300
Configuration saved in whisper-nats-syn\checkpoint-300\config.json
Model weights saved in whisper-nats-syn\checkpoint-300\pytorch_model.bin
Feature extractor saved in w

TrainOutput(global_step=300, training_loss=0.6890928649902344, metrics={'train_runtime': 839.7534, 'train_samples_per_second': 5.716, 'train_steps_per_second': 0.357, 'total_flos': 1.1905692844032e+17, 'train_loss': 0.6890928649902344, 'epoch': 9.67})

In [17]:
# save model
trainer.save_model("./model_whisper/")

Saving model checkpoint to ./model_local/
Configuration saved in ./model_local/config.json
Model weights saved in ./model_local/pytorch_model.bin
Feature extractor saved in ./model_local/preprocessor_config.json


### Inference

In [22]:
# Load a checkpoint for inference
inference_interface = pipeline(model="I:/Repos/STT_FineTune/nats/model_whisper/", task='automatic-speech-recognition', device=0)

In [5]:
# Load into memory
dataset_tmp = DatasetDict()
dataset_tmp = load_dataset("data_gtts", data_dir="I:/Repos/STT_FineTune/nats/data_gtts")
# Resample to 16kHz
dataset_tmp = dataset_tmp.cast_column('audio', Audio(sampling_rate=16000))

Using custom data configuration default-9b41ac5bfd8c70c2
Found cached dataset data_gtts (I:/Repos/HFdatasets/data_gtts/default-9b41ac5bfd8c70c2/0.1.0/99611922a2fe30672e990db44b070dc747a16dd2cb691d0d2c33dc670a2e3b68)


  0%|          | 0/2 [00:00<?, ?it/s]

In [47]:
metric = evaluate.load("wer")

idx = np.random.randint(0, len(dataset_tmp['test']))
pred = inference_interface(dataset_tmp["test"][idx]['audio']['array'])['text']
label = dataset_tmp["test"][idx]['transcription']

print(pred)
print(label)
print(pred == label)
print(f"WER: {metric.compute(predictions=[pred], references=[label])}")

Astraeus Fife Six Two fly heading One Seven Fife degrees
Astraeus Fife Six Two fly heading One Seven Fife degrees
True
WER: 0.0


In [48]:
from tqdm.auto import tqdm

preds = []
labels = []

for i in tqdm(range(len(dataset_tmp['test']))):
    preds.append(inference_interface(dataset_tmp["test"][i]['audio']['array'])['text'])
    labels.append(dataset_tmp["test"][i]['transcription'])
    if preds[-1] != labels[-1]:
        print(preds[-1])
        print(labels[-1])

print(f"WER: {metric.compute(predictions=preds, references=labels)}")

  0%|          | 0/150 [00:00<?, ?it/s]



Ryanair Two Xray Three resume own navigation Two Oscar Tango Sierra India Delta
Ryanair Two Xray Three resume own navigation to Oscar Tango Sierra India Delta
WER: 0.0007541478129713424


In [49]:
import gradio as gr

def transcribe(audio):
    text = inference_interface(audio)["text"]
    return text

iface = gr.Interface(
    fn=transcribe, 
    inputs=gr.Audio(source="microphone", type="filepath"), 
    outputs="text",
    title="Whisper tiny",
    description="Realtime demo for speech recognition using a fine-tuned Whisper small model.",
)

# iface.launch()