In [9]:
import chardet
import os
import pandas as pd
import tgt
import torch

from torch.utils.data import DataLoader
from dataclasses import dataclass
from typing import Any, Dict, List, Union

from datasets import load_dataset, Dataset, Audio
from transformers import WhisperForConditionalGeneration, WhisperProcessor, TrainingArguments, Trainer

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")
ROOT_PATH = os.getcwd()
AUDIO_PATH = os.path.join(ROOT_PATH, "dataset", "CloseMic", "audio")
SCRIPT_PATH = os.path.join(ROOT_PATH, "dataset", "CloseMic", "scripts")
TRANSCRIPTION_DICT_PATH = os.path.join(ROOT_PATH, "dataset", "CloseMic", "transcriptions.csv")
MODEL_BRAND = "openai/whisper-tiny"
MODEL_PATH = os.path.join(ROOT_PATH, "models")

def detect_encoding(file_path):
	with open(file_path, "rb") as f:
		raw_data = f.read(100000)
		return chardet.detect(raw_data)["encoding"]

def extract_textgrid_transcriptions():
	data = []
	for filename in os.listdir(SCRIPT_PATH):
		if not filename.endswith(".TextGrid"): continue

		textgrid_file = os.path.join(SCRIPT_PATH, filename)
		audio_file = filename.replace(".TextGrid", ".wav")
		audio_path = os.path.join(AUDIO_PATH, audio_file)

		detected_encoding = detect_encoding(textgrid_file)
		if not detected_encoding:
			print(f"Warning: Could not detect encoding for {filename}. Skipping.")
			continue

		try:
			with open(textgrid_file, "r", encoding=detected_encoding) as f:
				content = f.read()
		except Exception as e:
			print(f"Error reading {filename} with detected encoding '{detected_encoding}': {e}")
			continue

		utf8_textgrid_file = textgrid_file + ".utf8"
		with open(utf8_textgrid_file, "w", encoding="utf-8") as f:
			f.write(content)

		try:
			tg = tgt.io.read_textgrid(utf8_textgrid_file, encoding="utf-8")
		except Exception as e:
			print(f"Error parsing {filename}: {e}")
			continue

		# Extract the correct tier name (matches filename)
		tier_name = filename.replace(".TextGrid", "")
		if tier_name not in [t.name for t in tg.tiers]:
			print(f"Warning: Tier '{tier_name}' not found in {filename}. Skipping.")
			continue

		tier = tg.get_tier_by_name(tier_name)

		# Extract transcription (ignoring special tokens like <S>, <SIL>, <Z>)
		transcription = " ".join(
			[interval.text for interval in tier.intervals if interval.text.strip() and interval.text not in {"<S>", "<SIL>", "<Z>"}]
		)
		if transcription: data.append({"audio": audio_path, "text": transcription})
		os.remove(utf8_textgrid_file)

	df = pd.DataFrame(data)
	df.to_csv(TRANSCRIPTION_DICT_PATH, index=False)
	print(f"Saved dataset to {TRANSCRIPTION_DICT_PATH}")

def load_and_prepare_dataset():
	df = pd.read_csv(TRANSCRIPTION_DICT_PATH)
	df = df.head(2)  # Select only 2 samples for a quick test
	dataset = Dataset.from_pandas(df)
	dataset = dataset.cast_column("audio", Audio(sampling_rate=16000))
	processor = WhisperProcessor.from_pretrained(MODEL_BRAND)

	def preprocess(batch):
		audio = batch["audio"]
		batch["input_features"] = processor(audio["array"], sampling_rate=16000, return_tensors="pt").input_features[0]
		tokenized = processor.tokenizer(batch["text"], truncation=True, max_length=448)
		batch["labels"] = tokenized.input_ids
		return batch

	dataset = dataset.map(preprocess, remove_columns=["audio"])

	# Split dataset
	dataset = dataset.train_test_split(test_size=0.2)  # 80% train, 20% eval
	train_dataset = dataset["train"]
	eval_dataset = dataset["test"]

	return train_dataset, eval_dataset, processor

@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
	processor: Any
	def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
		input_features = [{"input_features": feature["input_features"]} for feature in features]
		label_features = [{"input_ids": feature["labels"]} for feature in features]
		batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")
		labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt", padding=True)
		labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)
		batch["labels"] = labels
		return batch

def fine_tune_whisper(train_dataset, eval_dataset, processor):
	model = WhisperForConditionalGeneration.from_pretrained(MODEL_BRAND).to(DEVICE)
	training_args = TrainingArguments(
		output_dir=MODEL_PATH,
		per_device_train_batch_size=2,
		max_steps=2,
		evaluation_strategy="epoch",
		save_strategy="epoch",
		logging_dir=os.devnull,
		report_to="none",
		logging_strategy="no",
	)
	data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)
	trainer = Trainer(
		model=model,
		args=training_args,
		train_dataset=train_dataset,
		eval_dataset=eval_dataset,
		processing_class=processor,
		data_collator=data_collator,
	)
	trainer.train()
	trainer.save_model(MODEL_PATH)

if __name__ == "__main__":
	extract_textgrid_transcriptions()
	train_dataset, eval_dataset, processor = load_and_prepare_dataset()
	fine_tune_whisper(train_dataset, eval_dataset, processor)

Saved dataset to /Users/gregory/Code/TranscribeLeh/dataset/CloseMic/transcriptions.csv


Map:   0%|          | 0/2 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss
1,No log,4.523684
2,No log,4.473498


