In [16]:
import chardet
import os
import pandas as pd
import tgt
import torch

from datasets import load_dataset, Dataset, Audio
from torch.utils.data import DataLoader
from transformers import EncoderDecoderCache, TrainingArguments, Trainer, WhisperForConditionalGeneration, WhisperProcessor
from u import DataCollatorSpeechSeq2SeqWithPadding, detect_encoding

# Select the compute backend in order of preference: CUDA, then MPS, then CPU.
if torch.cuda.is_available():
	DEVICE = torch.device("cuda")
elif torch.backends.mps.is_available():
	DEVICE = torch.device("mps")
else:
	DEVICE = torch.device("cpu")

# Pretrained checkpoint name; the processor below and the model used for
# fine-tuning are both loaded from it.
MODEL_BRAND = "openai/whisper-tiny"
PROCESSOR = WhisperProcessor.from_pretrained(MODEL_BRAND)

# Every path is anchored at the working directory the process was started in.
ROOT_PATH = os.getcwd()
_CLOSE_MIC_PARTS = ("dataset", "CloseMic")
AUDIO_PATH = os.path.join(ROOT_PATH, *_CLOSE_MIC_PARTS, "audio")
SCRIPTS_PATH = os.path.join(ROOT_PATH, *_CLOSE_MIC_PARTS, "scripts")
TRANSCRIPTIONS_PATH = os.path.join(ROOT_PATH, *_CLOSE_MIC_PARTS, "transcriptions.csv")
MODELS_PATH = os.path.join(ROOT_PATH, "models")
INPUTS_PATH = os.path.join(ROOT_PATH, "inputs")

def add_slang_tokens():
	"""Register the slang token "[lah]" with the shared processor's tokenizer.

	The token is added through ``add_tokens``, i.e. as an additional
	vocabulary entry on ``PROCESSOR.tokenizer``. The tokenizer is mutated in
	place; nothing is returned.
	"""
	PROCESSOR.tokenizer.add_tokens(["[lah]"])


# Extend the tokenizer once at import time so every later use of PROCESSOR
# sees the enlarged vocabulary.
add_slang_tokens()

def _read_tier_transcript(script_path, tier_name):
	"""Return the space-joined text of tier ``tier_name`` in a TextGrid file.

	The file is decoded with the encoding reported by ``detect_encoding``,
	re-written as a temporary UTF-8 sibling (``<script_path>.utf8``) and parsed
	with ``tgt``. The temporary file is always removed, whatever the outcome.

	Returns:
		The transcript string (possibly empty), or ``None`` when the encoding
		cannot be detected, the file cannot be read or parsed, or the tier is
		not present.
	"""
	script_encoding = detect_encoding(script_path)
	if not script_encoding:
		return None

	try:
		with open(script_path, "r", encoding=script_encoding) as f:
			script_content = f.read()
	except (OSError, ValueError, LookupError):
		# Unreadable file, undecodable bytes (UnicodeError is a ValueError) or
		# an encoding name Python does not know: skip this script.
		return None

	utf8_script_path = script_path + ".utf8"
	try:
		with open(utf8_script_path, "w", encoding="utf-8") as f:
			f.write(script_content)
		try:
			textgrid = tgt.io.read_textgrid(utf8_script_path, encoding="utf-8")
		except Exception:
			# Best effort: a malformed TextGrid is skipped rather than
			# aborting the whole export.
			return None
	finally:
		# Clean up on every path so stale ".utf8" files never accumulate in
		# the scripts directory.
		if os.path.exists(utf8_script_path):
			os.remove(utf8_script_path)

	if not any(t.name == tier_name for t in textgrid.tiers):
		return None
	tier = textgrid.get_tier_by_name(tier_name)

	# Markers that are dropped from the transcript.
	ignored_markers = {"<S>", "<SIL>", "<Z>"}
	return " ".join(
		interval.text
		for interval in tier.intervals
		if interval.text.strip() and interval.text not in ignored_markers
	)


def write_transcriptions():
	"""Build ``TRANSCRIPTIONS_PATH`` from the TextGrid scripts in ``SCRIPTS_PATH``.

	For every ``*.TextGrid`` file, the tier named after the file (without the
	extension) is read and its non-empty, non-marker intervals are joined into
	one transcript. Each usable script yields one CSV row holding the path of
	the same-named ``.wav`` file under ``AUDIO_PATH`` and the transcript.
	Scripts that cannot be read or parsed, lack the expected tier, or produce
	an empty transcript are skipped.
	"""
	transcriptions_list = []

	# Sorted so the CSV row order does not depend on directory listing order.
	for s in sorted(os.listdir(SCRIPTS_PATH)):
		# Ignore anything that is not a TextGrid script (e.g. stray temp files).
		if not s.endswith(".TextGrid"):
			continue

		script_path = os.path.join(SCRIPTS_PATH, s)
		audio_path = os.path.join(AUDIO_PATH, s.replace(".TextGrid", ".wav"))

		transcript = _read_tier_transcript(script_path, s.replace(".TextGrid", ""))
		if not transcript:
			continue
		transcriptions_list.append({"audio_path": audio_path, "transcript": transcript})

	# Explicit columns keep the CSV header present even when no script was
	# usable, so the file stays readable by pd.read_csv.
	df = pd.DataFrame(transcriptions_list, columns=["audio_path", "transcript"])
	df.to_csv(TRANSCRIPTIONS_PATH, index=False)

def load_dataset():
	"""Load the transcription CSV and turn it into train/test feature datasets.

	Reads ``TRANSCRIPTIONS_PATH``, keeps only the first two rows (debug limit),
	decodes the ``audio_path`` column as 16 kHz audio, and maps every example
	to Whisper ``input_features`` plus tokenized ``labels``.

	Returns:
		A ``(train, test, PROCESSOR)`` tuple, where train/test come from an
		80/20 ``train_test_split``.
	"""

	def preprocess_batch(batch):
		# Called once per example (the map below is not batched).
		decoded = batch["audio_path"]
		features = PROCESSOR(decoded["array"], sampling_rate=16000, return_tensors="pt")
		batch["input_features"] = features.input_features[0]
		encoded = PROCESSOR.tokenizer(batch["transcript"], truncation=True, max_length=448)
		batch["labels"] = encoded.input_ids
		return batch

	frame = pd.read_csv(TRANSCRIPTIONS_PATH).head(2)  # Remove only after team debugging and testing
	dataset = Dataset.from_pandas(frame)
	dataset = dataset.cast_column("audio_path", Audio(sampling_rate=16000))
	dataset = dataset.map(preprocess_batch, remove_columns=["audio_path"])
	splits = dataset.train_test_split(test_size=0.2)
	return splits["train"], splits["test"], PROCESSOR

def finetune_model(X_train, X_test):
	"""Fine-tune the ``MODEL_BRAND`` Whisper checkpoint and save it to ``MODELS_PATH``.

	Args:
		X_train: Training dataset, as returned by ``load_dataset``.
		X_test: Evaluation dataset, as returned by ``load_dataset``.

	The model's token embeddings are resized to the (extended) tokenizer of
	``PROCESSOR`` before training. Checkpoints and the final model are written
	under ``MODELS_PATH``. Nothing is returned.
	"""
	model = WhisperForConditionalGeneration.from_pretrained(MODEL_BRAND).to(DEVICE)
	# The tokenizer was extended with extra tokens, so the embedding matrix
	# must grow to match its length.
	model.resize_token_embeddings(len(PROCESSOR.tokenizer))
	model.config.use_cache = False
	training_args = TrainingArguments(
		output_dir=MODELS_PATH,
		per_device_train_batch_size=2,
		max_steps=2,  # Replace as needed
		eval_strategy="epoch",
		save_strategy="epoch",
		logging_dir=os.devnull,
		report_to="none",
		logging_strategy="no",
	)
	# NOTE(review): the collator is defined in the local module `u`; the
	# upper-case keyword `PROCESSOR` is kept as-is — confirm it matches the
	# collator's parameter name.
	data_collator = DataCollatorSpeechSeq2SeqWithPadding(PROCESSOR=PROCESSOR)
	trainer = Trainer(
		model=model,
		args=training_args,
		train_dataset=X_train,
		eval_dataset=X_test,
		processing_class=PROCESSOR,
		data_collator=data_collator,
	)
	trainer.train()
	# Save into the same directory used as output_dir (MODELS_PATH); the
	# previously referenced name MODEL_PATH does not exist in this file.
	trainer.save_model(MODELS_PATH)

# write_transcriptions()
# X_train, X_test, _ = load_dataset()  # load_dataset() also returns PROCESSOR
# finetune_model(X_train, X_test)

def transcribe_audio(path, model_path="models/checkpoint-2"):
	"""Transcribe one audio file with a fine-tuned Whisper checkpoint.

	Args:
		path: Path of the audio file to transcribe.
		model_path: Directory of the checkpoint to load. Defaults to
			"models/checkpoint-2"; replace as needed.

	Returns:
		The decoded text of the first (only) generated sequence, with special
		tokens removed.
	"""
	model = WhisperForConditionalGeneration.from_pretrained(model_path).to(DEVICE)
	# decode_example reads both "path" and "bytes" from the dict, so "bytes"
	# must be present (None means: load from the path).
	audio = Audio(sampling_rate=16000).decode_example({"path": path, "bytes": None})
	input_features = PROCESSOR(
		audio["array"],
		sampling_rate=16000,
		return_tensors="pt",
	).input_features.to(DEVICE)

	# Inference only: no gradients are needed.
	with torch.no_grad():
		predicted_token_ids_tensor = model.generate(input_features)
	return PROCESSOR.tokenizer.batch_decode(predicted_token_ids_tensor, skip_special_tokens=True)[0]

def in_vocabulary(word):
	"""Return True when ``word`` is encoded as exactly one known token id.

	The word is tokenized without special tokens; it counts as in-vocabulary
	only if that yields a single id and the id is among the values of the
	tokenizer's vocabulary mapping.
	"""
	token_ids = PROCESSOR.tokenizer(word, add_special_tokens=False).input_ids
	if len(token_ids) != 1:
		return False
	known_ids = PROCESSOR.tokenizer.get_vocab().values()
	return token_ids[0] in known_ids

# Sanity check: the slang token added above should map to a single id.
print(in_vocabulary("[lah]"))

True
