## Load Data

In [None]:
from datasets import load_from_disk

# Read the saved dataset from disk, then keep only its 'test' split for evaluation.
dataset_dict = load_from_disk("Data/kham_asr_dataset")
dataset = dataset_dict['test']

# Bare last expression so the notebook shows the split's summary.
dataset

## Load Model

In [None]:
from transformers import Wav2Vec2FeatureExtractor, Wav2Vec2CTCTokenizer, Wav2Vec2Processor, Wav2Vec2ForCTC, Wav2Vec2Config

# Feature extractor: single-channel input at 16 kHz, normalised, no attention mask.
feature_extractor = Wav2Vec2FeatureExtractor(
    feature_size=1,
    sampling_rate=16000,
    padding_value=0.0,
    do_normalize=True,
    return_attention_mask=False,
)

# CTC tokenizer built from the local vocabulary; the tsheg (་) is the word delimiter.
tokenizer = Wav2Vec2CTCTokenizer(
    "./vocab.json",
    unk_token="[UNK]",
    pad_token="[PAD]",
    word_delimiter_token="་",
)

# Bundle both so one object handles audio preprocessing and text decoding.
processor = Wav2Vec2Processor(
    feature_extractor=feature_extractor,
    tokenizer=tokenizer,
)

# Load the fine-tuned checkpoint from the local models directory.
model = Wav2Vec2ForCTC.from_pretrained("Models/baseline-fine-tuned")

# Move the model to the first GPU (last expression, so its repr is displayed).
model.to('cuda:0')

## Run Inference on Test Data

In [None]:
import torch

def generate_predictions(batch):
    """Transcribe one dataset example with greedy (argmax) CTC decoding.

    Relies on the notebook-level ``processor`` and ``model`` objects.

    Args:
        batch: An example whose ``"audio"`` entry provides the waveform under
            ``"array"`` and its rate under ``"sampling_rate"``.

    Returns:
        The same ``batch`` with a ``"prediction"`` key added, holding whatever
        ``processor.batch_decode`` returns for the predicted ids.
    """
    audio = batch["audio"]

    # No resampling is done here: the audio's own sampling rate is passed
    # straight through to the processor.
    # NOTE(review): the feature extractor is configured for 16 kHz; confirm the
    # dataset audio is already at that rate.
    inputs = processor(
        audio["array"],
        sampling_rate=audio["sampling_rate"],
        return_tensors="pt",
        padding=True,
    ).input_values.to(model.device)  # follow the model instead of hard-coding "cuda"

    # Forward pass without gradient tracking, then take the highest-scoring
    # token id for every frame.
    with torch.no_grad():
        logits = model(inputs).logits
        predicted_ids = torch.argmax(logits, dim=-1)

    # Decode the predicted ids to text.
    batch["prediction"] = processor.batch_decode(predicted_ids, skip_special_tokens=True)
    return batch

# Run inference over every example of the test split, one example per call.
processed_test_dataset = dataset.map(function=generate_predictions)

## Compute Metrics

In [None]:
import jiwer
# Import the metric functions under aliases so the result variables below
# (`ser`, `wer`) do not overwrite them; otherwise re-running this cell fails
# because the names refer to floats by then.
from tibetan_wer.metrics import wer as compute_wer, ser as compute_ser

# Extract predictions and references.
# Each stored prediction is indexed with [0] to get the decoded string, and
# spaces are mapped back to the tsheg (་) that the tokenizer uses as its word
# delimiter.
predictions = [elt[0].replace(' ', '་') for elt in processed_test_dataset["prediction"]]
references = list(processed_test_dataset["transcription"])

# Compute metrics.
# jiwer expects (reference, hypothesis): the error rate is normalised by the
# reference length, so the ground truth must come first.
cer = jiwer.cer(references, predictions)
# NOTE(review): the tibetan_wer calls keep the (predictions, references) order
# used originally; confirm against that package's signature.
ser = compute_ser(predictions, references)['micro_ser']
wer = compute_wer(predictions, references)['micro_wer']

print(f"Character Error Rate (CER): {cer}")
print(f"Syllable Error Rate (SER): {ser}")
print(f"Word Error Rate (WER): {wer}")

In [None]:
import pickle

# Save the predictions list so it can be reloaded later without re-running inference.
with open('baseline_preds.pickle', 'wb') as f:
    # pickle.dump takes the object first and the open binary file second.
    pickle.dump(predictions, f)