In [None]:
import os
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
import pandas as pd
import torch
torch.set_num_threads(1)
import torchaudio
import torchaudio.transforms as T
from dataclasses import dataclass
from datasets import load_dataset, Dataset, Audio
from torch.utils.data import DataLoader
from transformers import TrainerCallback, TrainingArguments, Trainer, WhisperProcessor
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch.nn as nn
import torch.nn.functional as F
from utils import determine_functions, detect_encoding, initialise_repository, LossLoggerCallback
from transformers import WhisperForConditionalGeneration
from typing import Any, Dict, List, Union

# Device that this file moves models to via `.to(DEVICE)` and creates helper tensors on.
DEVICE = torch.device("cpu")
# Seed handed to determine_functions() and to both train/test splits.
SEED = 37
# Number of CSV rows kept by load_dataset() (`.head(...)`), i.e. a debugging-sized subset.
DEBUGGING_HEAD_AMOUNT = 3
# Transcripts are truncated to at most this many tokens when tokenised.
MAX_TOKEN_LENGTH = 10

# All paths are resolved relative to the current working directory.
ROOT_PATH = os.getcwd()
MODELS_PATH = os.path.join(ROOT_PATH, "models")
DATASET_PATH = os.path.join(ROOT_PATH, "dataset")
# Not referenced by any code visible in this file.
INPUTS_PATH = os.path.join(ROOT_PATH, "inputs")

# Dataset layout: a lexicon file plus audio, scripts and the generated CSV under "Channel0".
LOG_PATH = os.path.join(MODELS_PATH, "log")
LEXICON_PATH = os.path.join(DATASET_PATH, "lexicon.txt")
AUDIO_PATH = os.path.join(DATASET_PATH, "Channel0", "audio")
SCRIPTS_PATH = os.path.join(DATASET_PATH, "Channel0", "scripts")
TRANSCRIPTIONS_PATH = os.path.join(DATASET_PATH, "Channel0", "transcriptions.csv")

# Checkpoint identifier used for both the processor and the model weights.
MODEL_BRAND = "openai/whisper-tiny"
MODEL_PROCESSOR = WhisperProcessor.from_pretrained(MODEL_BRAND)
# Give the tokenizer a dedicated '<pad>' token when it has none or reuses EOS as
# padding, then round-trip the processor through MODELS_PATH so the reloaded
# object reflects the change.
# NOTE(review): this grows the tokenizer vocabulary; nothing visible in this
# file resizes the model to match — confirm label ids stay within
# model.config.vocab_size.
if MODEL_PROCESSOR.tokenizer.pad_token is None or MODEL_PROCESSOR.tokenizer.pad_token == MODEL_PROCESSOR.tokenizer.eos_token:
		MODEL_PROCESSOR.tokenizer.add_special_tokens({'pad_token': '<pad>'})
		MODEL_PROCESSOR.tokenizer.pad_token_id = MODEL_PROCESSOR.tokenizer.convert_tokens_to_ids('<pad>')
		MODEL_PROCESSOR.save_pretrained(MODELS_PATH)
		MODEL_PROCESSOR = WhisperProcessor.from_pretrained(MODELS_PATH)

class WhisperCTC(nn.Module):
	"""Whisper encoder followed by a linear CTC classification head.

	The wrapped checkpoint is loaded as a full ``WhisperForConditionalGeneration``
	(so its parameters and state-dict layout are unchanged), but only the
	encoder takes part in the forward pass.

	NOTE(review): the head has ``config.vocab_size`` outputs and uses
	``config.pad_token_id`` as the CTC blank. The tokenizer is extended with
	extra tokens elsewhere in this file; confirm that every label id is below
	``config.vocab_size`` and differs from the blank id.
	"""

	def __init__(self, base_model_name):
		"""Load ``base_model_name`` and attach a freshly initialised CTC head.

		Args:
			base_model_name: Identifier or path accepted by
				``WhisperForConditionalGeneration.from_pretrained``.
		"""
		super().__init__()
		self.model = WhisperForConditionalGeneration.from_pretrained(base_model_name)
		self.model.config.use_cache = False
		self.ctc_head = nn.Linear(self.model.config.d_model, self.model.config.vocab_size)
		self.ctc_loss_fn = nn.CTCLoss(blank=self.model.config.pad_token_id, zero_infinity=True)

	def forward(self, input_features, labels, input_lengths=None, label_lengths=None):
		"""Compute the CTC loss and per-frame log-probabilities.

		Args:
			input_features: Batched encoder input; the first axis is the batch.
			labels: 2-D tensor of target ids, padded with -100.
			input_lengths: Optional number of valid frames per example along
				the time axis of the returned log-probabilities. Defaults to
				the full length for every example.
			label_lengths: Optional number of valid targets per example.
				Defaults to the count of entries different from -100.

		Returns:
			Tuple ``(loss, ctc_logits)`` where ``ctc_logits`` holds
			log-softmax scores laid out as (batch, frames, classes).
		"""
		# Only the encoder output feeds the CTC head, so run the encoder alone
		# rather than the whole seq2seq model with dummy decoder inputs.
		# NOTE(review): this bypasses any input masking the full model applies
		# in its own forward (relevant only if config.apply_spec_augment is on).
		encoder_outputs = self.model.get_encoder()(input_features).last_hidden_state
		ctc_logits = self.ctc_head(encoder_outputs).log_softmax(2)

		batch_size, num_frames = ctc_logits.shape[0], ctc_logits.shape[1]
		if input_lengths is None:
			input_lengths = torch.full((batch_size,), num_frames, dtype=torch.long, device=ctc_logits.device)
		if label_lengths is None:
			label_lengths = labels.ne(-100).sum(dim=1)

		# -100 is not a valid class index. CTCLoss only reads the first
		# label_lengths[i] entries of each padded row, so the fill value is
		# never scored; it just keeps the target tensor well-formed.
		targets = labels.masked_fill(labels.eq(-100), self.ctc_loss_fn.blank)

		# nn.CTCLoss expects (frames, batch, classes).
		loss = self.ctc_loss_fn(ctc_logits.permute(1, 0, 2), targets, input_lengths, label_lengths)
		return loss, ctc_logits

def add_new_tokens():
	"""Add lexicon words that the tokenizer does not know yet.

	Reads LEXICON_PATH as UTF-8, takes the first tab-separated field of every
	line that has at least two fields, and registers the words missing from
	the tokenizer vocabulary. When anything was added, the processor is saved
	to MODELS_PATH and the module-level MODEL_PROCESSOR is reloaded from there.
	"""
	global MODEL_PROCESSOR
	with open(LEXICON_PATH, "r", encoding="utf-8") as lexicon_file:
		lexicon_words = {
			fields[0]
			for fields in (line.strip().split("\t") for line in lexicon_file)
			if len(fields) >= 2
		}
	# Token ids are handed out in list order. Set iteration order varies
	# between interpreter runs, so sort to keep the word -> id mapping stable.
	new_tokens_list = sorted(lexicon_words.difference(MODEL_PROCESSOR.tokenizer.get_vocab()))
	if new_tokens_list:
		MODEL_PROCESSOR.tokenizer.add_tokens(new_tokens_list)
		MODEL_PROCESSOR.save_pretrained(MODELS_PATH)
		MODEL_PROCESSOR = WhisperProcessor.from_pretrained(MODELS_PATH)

def write_transcriptions():
	"""Build TRANSCRIPTIONS_PATH from the script files under SCRIPTS_PATH.

	Every script line of the form ``<identifier>\\t<transcript>`` whose audio
	file ``<identifier>.WAV`` exists under AUDIO_PATH becomes one CSV row with
	the columns ``audio_path`` and ``transcript``. Scripts whose encoding
	cannot be detected, or that cannot be read, are skipped.
	"""
	transcriptions_list = []
	# os.listdir() returns entries in arbitrary order; sort so the row order
	# (and anything derived from it, such as head() or seeded splits) is
	# reproducible.
	for script_name in sorted(os.listdir(SCRIPTS_PATH)):
		script_path = os.path.join(SCRIPTS_PATH, script_name)
		script_encoding = detect_encoding(script_path)
		if not script_encoding:
			continue
		try:
			with open(script_path, "r", encoding=script_encoding) as f:
				script_lines = f.readlines()
		except (OSError, LookupError, ValueError):
			# Best effort: unreadable file, unknown codec name, or undecodable
			# bytes (UnicodeError is a ValueError) -> skip this script.
			continue
		for script_line in script_lines:
			fields = script_line.strip().split("\t")
			if len(fields) != 2:
				continue
			identifier, transcript = fields
			audio_path = os.path.join(AUDIO_PATH, f"{identifier}.WAV")
			if os.path.exists(audio_path):
				transcriptions_list.append({"audio_path": audio_path, "transcript": transcript})
	# Pin the columns so that an empty result still writes a header row that
	# pd.read_csv can parse.
	df = pd.DataFrame(transcriptions_list, columns=["audio_path", "transcript"])
	df.to_csv(TRANSCRIPTIONS_PATH, index=False)

def load_dataset():
	"""Load the transcription CSV and return model-ready dataset splits.

	Only the first DEBUGGING_HEAD_AMOUNT rows of TRANSCRIPTIONS_PATH are used.
	Audio is decoded at 16 kHz, converted to input features by MODEL_PROCESSOR,
	and transcripts are tokenised (truncated to MAX_TOKEN_LENGTH tokens).

	Returns:
		A 3-tuple of datasets: 20% of the rows are split off first (third
		element); the remainder is split again with a 12.5% hold-out (second
		element), leaving the rest as the first element.

	NOTE(review): this function shadows ``load_dataset`` imported from
	``datasets`` at the top of the file.
	"""

	def to_model_inputs(example):
		# Log-mel style features for a single clip; [0] drops the batch axis.
		example["input_features"] = MODEL_PROCESSOR(
			example["audio_path"]["array"], sampling_rate=16000, return_tensors="pt"
		).input_features[0]

		encoding = MODEL_PROCESSOR.tokenizer(
			example["transcript"],
			truncation=True,
			max_length=MAX_TOKEN_LENGTH,
			padding=True,
			return_attention_mask=True,
		)
		token_ids = encoding.input_ids

		example["labels"] = token_ids
		# Length bookkeeping: second axis of the feature matrix, and the
		# number of label ids.
		example["input_lengths"] = example["input_features"].shape[1]
		example["label_lengths"] = len(token_ids)
		return example

	frame = pd.read_csv(TRANSCRIPTIONS_PATH).head(DEBUGGING_HEAD_AMOUNT)
	dataset = Dataset.from_pandas(frame)
	dataset = dataset.cast_column("audio_path", Audio(sampling_rate=16000))
	dataset = dataset.map(to_model_inputs, remove_columns=["audio_path"])

	outer_split = dataset.train_test_split(test_size=0.2, seed=SEED)
	inner_split = outer_split["train"].train_test_split(test_size=0.125, seed=SEED)
	return inner_split["train"], inner_split["test"], outer_split["test"]

@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
	"""Collate per-example features into one padded batch.

	Only the ``input_features`` and ``labels`` entries of each example are
	used; label positions that the tokenizer marks as padding are set to -100.
	"""

	# Object exposing `feature_extractor.pad(...)` and `tokenizer.pad(...)`.
	processor: Any

	def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
		"""Return the padded batch with an added ``labels`` entry."""
		audio_items = [{"input_features": feature["input_features"]} for feature in features]
		batch = self.processor.feature_extractor.pad(audio_items, return_tensors="pt")

		label_items = [{"input_ids": feature["labels"]} for feature in features]
		padded_labels = self.processor.tokenizer.pad(label_items, return_tensors="pt", padding=True)

		token_ids = padded_labels["input_ids"]
		# Without an attention mask every position counts as a real token.
		attention = padded_labels.get("attention_mask", torch.ones_like(token_ids))
		is_padding = attention.ne(1)
		batch["labels"] = token_ids.masked_fill(is_padding, -100)
		return batch

class WhisperTrainer(Trainer):
	"""Trainer that takes its loss from the model's ``(loss, outputs)`` return."""

	def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
		"""Run the model on a collated batch and return its CTC loss.

		Args:
			model: Module called as ``model(input_features, labels, ...)`` and
				returning ``(loss, outputs)``.
			inputs: Batch dict with ``input_features`` and ``labels`` (labels
				padded with -100).
			return_outputs: When true, return ``(loss, outputs)`` instead of
				only the loss.
			num_items_in_batch: Accepted for Trainer compatibility; unused.
		"""
		input_features = inputs["input_features"]
		labels = inputs["labels"]
		# Count the real (non -100) targets per example, on the same device as
		# the batch instead of a hard-coded one.
		label_lengths = labels.ne(-100).sum(dim=1)
		# input_lengths is left to the model: the valid frame count refers to
		# the time axis of the CTC log-probabilities, which is only known after
		# the encoder ran. Taking shape[0] of each feature matrix (as before)
		# measures a different axis and truncated the CTC alignment.
		loss, outputs = model(input_features, labels, label_lengths=label_lengths)
		return (loss, outputs) if return_outputs else loss

def finetune_model(X_train, X_val):
	"""Fine-tune a WhisperCTC model and persist it under MODELS_PATH.

	Args:
		X_train: Dataset used for training.
		X_val: Dataset evaluated at the end of every epoch.

	Trains for one epoch, then saves the model weights and the tokenizer to
	MODELS_PATH. Progress markers are printed after each stage.
	"""
	ctc_model = WhisperCTC(MODEL_BRAND).to(DEVICE)

	training_args = TrainingArguments(
		output_dir=MODELS_PATH,
		logging_dir=MODELS_PATH,
		num_train_epochs=1,
		# One example per step; gradients are accumulated over 8 steps.
		per_device_train_batch_size=1,
		per_device_eval_batch_size=1,
		gradient_accumulation_steps=8,
		# Evaluate, log and checkpoint once per epoch.
		eval_strategy="epoch",
		logging_strategy="epoch",
		save_strategy="epoch",
		save_safetensors=False,
		report_to="none",
	)
	collator = DataCollatorSpeechSeq2SeqWithPadding(processor=MODEL_PROCESSOR)

	trainer = WhisperTrainer(
		model=ctc_model,
		args=training_args,
		data_collator=collator,
		train_dataset=X_train,
		eval_dataset=X_val,
	)
	trainer.add_callback(LossLoggerCallback(LOG_PATH))
	print("after callback")

	trainer.train()
	print("after train()")

	trainer.save_model(MODELS_PATH)
	print("after save_model")

	MODEL_PROCESSOR.tokenizer.save_pretrained(MODELS_PATH)
def evaluate_model(X_test):
	"""Evaluate the fine-tuned WhisperCTC weights on ``X_test`` and print the metrics.

	Args:
		X_test: Dataset to evaluate on.

	The model is rebuilt from MODEL_BRAND and the fine-tuned parameters are
	then restored from the state dict that ``trainer.save_model(MODELS_PATH)``
	wrote. Constructing ``WhisperCTC(MODELS_PATH)`` instead would create a new,
	randomly initialised ``ctc_head`` and never load the trained one.
	"""
	model = WhisperCTC(MODEL_BRAND)
	# NOTE(review): for a plain nn.Module saved with save_safetensors=False the
	# Trainer is expected to write "pytorch_model.bin" — confirm the file name
	# against the installed transformers version.
	weights_path = os.path.join(MODELS_PATH, "pytorch_model.bin")
	state_dict = torch.load(weights_path, map_location=DEVICE, weights_only=True)
	model.load_state_dict(state_dict)
	model = model.to(DEVICE)
	trainer = Trainer(
		args=TrainingArguments(
			output_dir=MODELS_PATH,
			per_device_eval_batch_size=1,
			report_to="none",
			save_safetensors=False,
		),
		data_collator=DataCollatorSpeechSeq2SeqWithPadding(processor=MODEL_PROCESSOR),
		eval_dataset=X_test,
		model=model,
	)
	print(trainer.evaluate())

# Pipeline entry point: the calls below run top to bottom when the cell executes.
# determine_functions / initialise_repository come from utils and their
# behaviour is not visible here — presumably seeding and directory setup
# (TODO confirm).
determine_functions(SEED)
initialise_repository(ROOT_PATH)
# Extend the tokenizer with lexicon words, then (re)generate the transcription CSV.
add_new_tokens()
write_transcriptions()
# The three splits are unpacked as train / validation / test.
X_train, X_val, X_test = load_dataset()
finetune_model(X_train, X_val)
# Evaluation on the held-out split is currently disabled.
# evaluate_model(X_test)

UserWarning: The operator 'aten::_ctc_loss' is not currently supported on the MPS backend and will fall back to run on the CPU. This may have performance implications. (Triggered internally at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/mps/MPSFallback.mm:14.)
  return torch.ctc_loss(