In [3]:
import os
import pandas as pd
import torch
import torchaudio
import torchaudio.transforms as T
from dataclasses import dataclass
from datasets import load_dataset, Dataset, Audio
from torch.utils.data import DataLoader
from transformers import TrainerCallback, TrainingArguments, Trainer, WhisperForConditionalGeneration, WhisperProcessor
from typing import Any, Dict, List, Union
from utils import determine_functions, detect_encoding, LossLoggerCallback

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
SEED = 37

MODEL_BRAND = "openai/whisper-tiny"
MODEL_PROCESSOR = WhisperProcessor.from_pretrained(MODEL_BRAND)

ROOT_PATH = os.getcwd()
MODELS_PATH = os.path.join(ROOT_PATH, "models")
DATASET_PATH = os.path.join(ROOT_PATH, "dataset")
INPUTS_PATH = os.path.join(ROOT_PATH, "inputs")

LOG_PATH = os.path.join(MODELS_PATH, "log")
LEXICON_PATH = os.path.join(DATASET_PATH, "lexicon.txt")
AUDIO_PATH = os.path.join(DATASET_PATH, "Channel0", "audio")
SCRIPTS_PATH = os.path.join(DATASET_PATH, "Channel0", "scripts")
TRANSCRIPTIONS_PATH = os.path.join(DATASET_PATH, "Channel0", "transcriptions.csv")

def add_new_tokens():
	global MODEL_PROCESSOR
	dataset_vocabulary_set = set()
	with open(LEXICON_PATH, "r", encoding="utf-8") as f:
		for lexicon_line in f:
			p = lexicon_line.strip().split("\t")
			if len(p) < 2:
				continue
			lexicon_word = p[0]
			dataset_vocabulary_set.add(lexicon_word)
	new_tokens_list = list(dataset_vocabulary_set - set(MODEL_PROCESSOR.tokenizer.get_vocab().keys()))
	if new_tokens_list:
		print(f"Adding {len(new_tokens_list)} new tokens")
		MODEL_PROCESSOR.tokenizer.add_tokens(new_tokens_list)
		MODEL_PROCESSOR.save_pretrained(MODELS_PATH)
		MODEL_PROCESSOR = WhisperProcessor.from_pretrained(MODELS_PATH)
	else:
		print("No new tokens to add")

def write_transcriptions():
	transcriptions_list = []
	for s in os.listdir(SCRIPTS_PATH):
		script_path = os.path.join(SCRIPTS_PATH, s)
		script_encoding = detect_encoding(script_path)
		if not script_encoding: continue
		try:
			with open(script_path, "r", encoding=script_encoding) as f:
				script_lines = f.readlines()
		except Exception as e:
			continue
		for script_line in script_lines:
			x = script_line.strip().split("\t")
			if len(x) != 2: continue
			identifier, transcript = x
			audio_path = os.path.join(AUDIO_PATH, f"{identifier}.WAV")
			if os.path.exists(audio_path):
				transcriptions_list.append({"audio_path": audio_path, "transcript": transcript})
	df = pd.DataFrame(transcriptions_list)
	df.to_csv(TRANSCRIPTIONS_PATH, index=False)

def load_dataset():
	transcriptions_df = pd.read_csv(TRANSCRIPTIONS_PATH)
	transcriptions_df = transcriptions_df.head(120)
	X = Dataset.from_pandas(transcriptions_df)
	X = X.cast_column("audio_path", Audio(sampling_rate=16000))
	time_mask = T.TimeMasking(time_mask_param=80)
	frequency_mask = T.FrequencyMasking(freq_mask_param=30)

	def preprocess_transcriptions_batch(X):
		X["input_features"] = frequency_mask(time_mask(
			MODEL_PROCESSOR(
				X["audio_path"]["array"],
				sampling_rate=16000,
				return_tensors="pt"
			).input_features[0]
		))
		X["labels"] = MODEL_PROCESSOR.tokenizer(
			X["transcript"],
			truncation=True,
			max_length=448
		).input_ids
		return X

	X = X.map(preprocess_transcriptions_batch, remove_columns=["audio_path"])
	X1 = X.train_test_split(test_size=0.2, seed=SEED)
	X2 = X1["train"].train_test_split(test_size=0.125, seed=SEED)
	return X2["train"], X2["test"], X1["test"]

@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
	processor: Any
	def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
		B = self.processor.feature_extractor.pad([{"input_features": feature["input_features"]} for feature in features], return_tensors="pt")
		L = self.processor.tokenizer.pad([{"input_ids": feature["labels"]} for feature in features], return_tensors="pt", padding=True)
		B["labels"] = L["input_ids"].masked_fill(L.attention_mask.ne(1), -100)
		return B

def finetune_model(X_train, X_val):
	model = WhisperForConditionalGeneration.from_pretrained(MODEL_BRAND).to(DEVICE)
	model.resize_token_embeddings(len(MODEL_PROCESSOR.tokenizer))
	model.config.use_cache = False
	model.config.forced_decoder_ids = MODEL_PROCESSOR.tokenizer.get_decoder_prompt_ids(language="en", task="transcribe", no_timestamps=True)
	trainer = Trainer(
		args=TrainingArguments(
			eval_strategy="epoch",
			logging_dir=MODELS_PATH,
			logging_strategy="epoch",
			max_steps=-1,
			num_train_epochs=1,
			output_dir=MODELS_PATH,
			per_device_eval_batch_size=2,
			per_device_train_batch_size=2,
			report_to="none",
			save_strategy="epoch",
		),
		data_collator=DataCollatorSpeechSeq2SeqWithPadding(processor=MODEL_PROCESSOR),
		eval_dataset=X_val,
		model=model,
		train_dataset=X_train,
	)
	trainer.add_callback(LossLoggerCallback(LOG_PATH))
	trainer.train()
	trainer.save_model(MODELS_PATH)
	MODEL_PROCESSOR.tokenizer.save_pretrained(MODELS_PATH)

def evaluate_model(X_test):
	model = WhisperForConditionalGeneration.from_pretrained(MODELS_PATH).to(DEVICE)
	trainer = Trainer(
		args=TrainingArguments(
			output_dir=MODELS_PATH,
			per_device_eval_batch_size=2,
			report_to="none",
		),
		data_collator=DataCollatorSpeechSeq2SeqWithPadding(processor=MODEL_PROCESSOR),
		eval_dataset=X_test,
		model=model,
	)
	print(trainer.evaluate())

determine_functions(SEED)
add_new_tokens()
write_transcriptions()
X_train, X_val, X_test = load_dataset()
finetune_model(X_train, X_val)
evaluate_model(X_test)

def transcribe_audio():
	global MODEL_PROCESSOR

	model = WhisperForConditionalGeneration.from_pretrained(MODELS_PATH).to(DEVICE)
	MODEL_PROCESSOR.tokenizer.pad_token = MODEL_PROCESSOR.tokenizer.eos_token

	for a in [f for f in os.listdir(INPUTS_PATH) if f.lower().endswith((".wav", ".mp3", ".flac"))]:
		input_audio_path = os.path.join(INPUTS_PATH, a)
		input_waveform, input_sample_rate = torchaudio.load(input_audio_path)
		if input_sample_rate != 16000:
			transform = torchaudio.transforms.Resample(orig_freq=input_sample_rate, new_freq=16000)
			input_waveform = transform(input_waveform)
		input_audio_features = MODEL_PROCESSOR(
			input_waveform.numpy().squeeze(), sampling_rate=16000, return_tensors="pt"
		).input_features.to(DEVICE)
		with torch.no_grad():
			predicted_token_ids_tensor = model.generate(input_audio_features, forced_decoder_ids=MODEL_PROCESSOR.tokenizer.get_decoder_prompt_ids(language="en"))

		print(f"{input_audio_path}: {MODEL_PROCESSOR.tokenizer.batch_decode(predicted_token_ids_tensor, skip_special_tokens=True)[0]}")

transcribe_audio()
print(not(MODEL_PROCESSOR.tokenizer.get_vocab().get("lah") is None))

Adding 54863 new tokens


Map:   0%|          | 0/120 [00:00<?, ? examples/s]

The new embeddings will be initialized from a multivariate normal distribution that has old embeddings' mean and covariance. As described in this article: https://nlp.stanford.edu/~johnhew/vocab-expansion.html. To disable this, use `mean_resizing=False`


Epoch,Training Loss,Validation Loss
1,5.3528,4.435731




The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.


{'eval_loss': 4.587049961090088, 'eval_model_preparation_time': 0.0008, 'eval_runtime': 2.714, 'eval_samples_per_second': 8.843, 'eval_steps_per_second': 4.422}
/Users/gregory/Code/SingaScribe/inputs/there_were_barrels_of_wine_in_the_huge_cellar.WAV:  There were barrels of white in the  sselow
/Users/gregory/Code/SingaScribe/inputs/i_was_so_tired_from_work_i_could_not_even_bother_to_brush_my_teeth.WAV:  I was so tired from work i could not even bother to brush my teeth
True
