In [1]:
import chardet
import os
import pandas as pd
import torch
import torchaudio.transforms as T
from dataclasses import dataclass
from datasets import load_dataset, Dataset, Audio
from torch.utils.data import DataLoader
from transformers import TrainerCallback, TrainingArguments, Trainer, WhisperForConditionalGeneration, WhisperProcessor
from typing import Any, Dict, List, Union

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

MODEL_BRAND = "openai/whisper-tiny"
MODEL_PROCESSOR = WhisperProcessor.from_pretrained(MODEL_BRAND)

ROOT_PATH = os.getcwd()

MODELS_PATH = os.path.join(ROOT_PATH, "models")
DATASET_PATH = os.path.join(ROOT_PATH, "dataset")
INPUTS_PATH = os.path.join(ROOT_PATH, "inputs")

LOG_PATH = os.path.join(MODELS_PATH, "log")
LEXICON_PATH = os.path.join(DATASET_PATH, "lexicon.txt")
AUDIO_PATH = os.path.join(DATASET_PATH, "Channel0", "audio")
SCRIPTS_PATH = os.path.join(DATASET_PATH, "Channel0", "scripts")
TRANSCRIPTIONS_PATH = os.path.join(DATASET_PATH, "Channel0", "transcriptions.csv")

def add_new_tokens():
	global MODEL_PROCESSOR
	dataset_vocabulary_set = set()
	with open(LEXICON_PATH, "r", encoding="utf-8") as f:
		for lexicon_line in f:
			p = lexicon_line.strip().split("\t")
			if len(p) < 2:
				continue
			lexicon_word = p[0]
			dataset_vocabulary_set.add(lexicon_word)
	new_tokens_list = list(dataset_vocabulary_set - set(MODEL_PROCESSOR.tokenizer.get_vocab().keys()))
	if new_tokens_list:
		print(f"Adding {len(new_tokens_list)} new tokens")
		MODEL_PROCESSOR.tokenizer.add_tokens(new_tokens_list)
		MODEL_PROCESSOR.save_pretrained(MODELS_PATH)
		MODEL_PROCESSOR = WhisperProcessor.from_pretrained(MODELS_PATH)
	else:
		print("No new tokens to add")

add_new_tokens()

def detect_encoding(path):
	with open(path, "rb") as f:
		return chardet.detect(f.read(100000))["encoding"]

def write_transcriptions():
	transcriptions_list = []
	for s in os.listdir(SCRIPTS_PATH):
		script_path = os.path.join(SCRIPTS_PATH, s)
		script_encoding = detect_encoding(script_path)
		if not script_encoding: continue
		try:
			with open(script_path, "r", encoding=script_encoding) as f:
				script_lines = f.readlines()
		except Exception as e:
			continue
		for script_line in script_lines:
			x = script_line.strip().split("\t")
			if len(x) != 2: continue
			identifier, transcript = x
			audio_path = os.path.join(AUDIO_PATH, f"{identifier}.WAV")
			if os.path.exists(audio_path):
				transcriptions_list.append({"audio_path": audio_path, "transcript": transcript})
	df = pd.DataFrame(transcriptions_list)
	df.to_csv(TRANSCRIPTIONS_PATH, index=False)

def load_dataset():
	transcriptions_df = pd.read_csv(TRANSCRIPTIONS_PATH)
	transcriptions_df = transcriptions_df.head(360)
	X = Dataset.from_pandas(transcriptions_df)
	X = X.cast_column("audio_path", Audio(sampling_rate=16000))
	time_mask = T.TimeMasking(time_mask_param=80)
	frequency_mask = T.FrequencyMasking(freq_mask_param=30)

	def preprocess_transcriptions_batch(X):
		X["input_features"] = frequency_mask(time_mask(
			MODEL_PROCESSOR(
				X["audio_path"]["array"],
				sampling_rate=16000,
				return_tensors="pt"
			).input_features[0]
		))
		X["labels"] = MODEL_PROCESSOR.tokenizer(
			X["transcript"],
			truncation=True,
			max_length=448
		).input_ids
		return X

	X = X.map(preprocess_transcriptions_batch, remove_columns=["audio_path"])
	X1 = X.train_test_split(test_size=0.2, seed=37)
	X2 = X1["train"].train_test_split(test_size=0.125, seed=37)
	return X2["train"], X2["test"], X1["test"]

@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
	processor: Any
	def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
		B = self.processor.feature_extractor.pad([{"input_features": feature["input_features"]} for feature in features], return_tensors="pt")
		L = self.processor.tokenizer.pad([{"input_ids": feature["labels"]} for feature in features], return_tensors="pt", padding=True)
		B["labels"] = L["input_ids"].masked_fill(L.attention_mask.ne(1), -100)
		return B

class LossLoggerCallback(TrainerCallback):
	def __init__(self, log_path):
		self.log_path = log_path
	def on_log(self, args, state, control, logs=None, **kwargs):
		if not logs: return
		loss_info = f"Step {state.global_step}: "
		if "loss" in logs:
			loss_info += f"Training Loss = {logs['loss']} "
		if "eval_loss" in logs:
			loss_info += f"Validation Loss = {logs['eval_loss']} "
		if "loss" in logs or "eval_loss" in logs:
			with open(self.log_path, "a") as f:
				f.write(loss_info.strip() + "\n")

def finetune_model(X_train, X_val):
	model = WhisperForConditionalGeneration.from_pretrained(MODEL_BRAND).to(DEVICE)
	model.resize_token_embeddings(len(MODEL_PROCESSOR.tokenizer))
	model.config.use_cache = False
	data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=MODEL_PROCESSOR)
	trainer = Trainer(
		model=model,
		args=TrainingArguments(
			output_dir=MODELS_PATH,
			per_device_train_batch_size=2,
			per_device_eval_batch_size=2,
			num_train_epochs=2,
			max_steps=-1,
			eval_strategy="epoch",
			save_strategy="epoch",
			logging_dir=MODELS_PATH,
			report_to="none",
			logging_strategy="epoch",
		),
		train_dataset=X_train,
		eval_dataset=X_val,
		data_collator=data_collator,
	)
	trainer.add_callback(LossLoggerCallback(LOG_PATH))
	trainer.train()
	trainer.save_model(MODELS_PATH)

def evaluate_model(X_test):
	model = WhisperForConditionalGeneration.from_pretrained(MODELS_PATH).to(DEVICE)
	data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=MODEL_PROCESSOR)
	trainer = Trainer(
		model=model,
		args=TrainingArguments(
			output_dir=MODELS_PATH,
			per_device_eval_batch_size=2,
			report_to="none",
		),
		eval_dataset=X_test,
		data_collator=data_collator,
	)
	metrics = trainer.evaluate()
	print(metrics)

write_transcriptions()
X_train, X_valid, X_test = load_dataset()
finetune_model(X_train, X_valid)
evaluate_model(X_test)

def transcribe_audio():
	model = WhisperForConditionalGeneration.from_pretrained(os.path.join(MODELS_PATH, "checkpoint-2")).to(DEVICE)
	for a in [f for f in os.listdir(INPUTS_PATH) if f.lower().endswith((".wav", ".mp3", ".flac"))]:
		input_audio_path = os.path.join(INPUTS_PATH, a)
		input_audio = Audio(sampling_rate=16000).decode_example({"path": input_audio_path})
		input_audio_features = MODEL_PROCESSOR(audio["array"], sampling_rate=16000, return_tensors="pt").input_features.to(DEVICE)
		with torch.no_grad():
			predicted_token_ids_tensor = model.generate(input_audio_features)
		transcription = MODEL_PROCESSOR.tokenizer.batch_decode(predicted_token_ids_tensor, skip_special_tokens=True)[0]
		print(f"{input_audio_path}: {transcription}")

Adding 54863 new tokens


Map:   0%|          | 0/360 [00:00<?, ? examples/s]

The new embeddings will be initialized from a multivariate normal distribution that has old embeddings' mean and covariance. As described in this article: https://nlp.stanford.edu/~johnhew/vocab-expansion.html. To disable this, use `mean_resizing=False`


Epoch,Training Loss,Validation Loss
0,9.979,9.490836




{'eval_loss': 9.634716033935547, 'eval_model_preparation_time': 0.0012, 'eval_runtime': 15.9193, 'eval_samples_per_second': 4.523, 'eval_steps_per_second': 2.261}


In [3]:
def in_vocabulary(word):
	t = MODEL_PROCESSOR.tokenizer(word, add_special_tokens=False).input_ids
	return len(t) == 1 and t[0] in MODEL_PROCESSOR.tokenizer.get_vocab().values()

print(in_vocabulary("lah"))

True
