## Load Data

In [2]:
from datasets import load_from_disk

dataset = load_from_disk("../Data/kham_asr_dataset")['test']

dataset

Loading dataset from disk:   0%|          | 0/34 [00:00<?, ?it/s]

Dataset({
    features: ['file_name', 'uni', 'wylie', 'url', 'dept', 'grade', 'char_len', 'audio_len', '__index_level_0__', 'audio', 'transcript'],
    num_rows: 4000
})

## Load Model

In [3]:
from transformers import Wav2Vec2FeatureExtractor, Wav2Vec2CTCTokenizer, Wav2Vec2Processor, Wav2Vec2ForCTC, Wav2Vec2Config

feature_extractor = Wav2Vec2FeatureExtractor(feature_size=1, sampling_rate=16000, padding_value=0.0, do_normalize=True, return_attention_mask=False)
tokenizer = Wav2Vec2CTCTokenizer("../Models/vocab.json", unk_token="[UNK]", pad_token="[PAD]", word_delimiter_token="་")
processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)

model = Wav2Vec2ForCTC.from_pretrained("../Models/baseline-fine-tuned")

model.to('cuda:0')

Wav2Vec2ForCTC(
  (wav2vec2): Wav2Vec2Model(
    (feature_extractor): Wav2Vec2FeatureEncoder(
      (conv_layers): ModuleList(
        (0): Wav2Vec2GroupNormConvLayer(
          (conv): Conv1d(1, 512, kernel_size=(10,), stride=(5,), bias=False)
          (activation): GELUActivation()
          (layer_norm): GroupNorm(512, 512, eps=1e-05, affine=True)
        )
        (1-4): 4 x Wav2Vec2NoLayerNormConvLayer(
          (conv): Conv1d(512, 512, kernel_size=(3,), stride=(2,), bias=False)
          (activation): GELUActivation()
        )
        (5-6): 2 x Wav2Vec2NoLayerNormConvLayer(
          (conv): Conv1d(512, 512, kernel_size=(2,), stride=(2,), bias=False)
          (activation): GELUActivation()
        )
      )
    )
    (feature_projection): Wav2Vec2FeatureProjection(
      (layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
      (projection): Linear(in_features=512, out_features=768, bias=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder)

## Run Inference on Test Data

In [4]:
import torch

def generate_predictions(batch):
    # Load and resample the audio
    audio = batch["audio"]
    inputs = processor(
        audio["array"], 
        sampling_rate=audio["sampling_rate"], 
        return_tensors="pt",
        padding=True,
    ).input_values.to("cuda")

    # Generate logits and get argmax predictions
    with torch.no_grad():
        logits = model(inputs).logits
        predicted_ids = torch.argmax(logits, dim=-1)

    # Decode predictions to text
    batch["prediction"] = processor.batch_decode(predicted_ids, skip_special_tokens=True)
    return batch

# Apply the function to the test dataset
processed_test_dataset = dataset.map(generate_predictions)



Map:   0%|          | 0/4000 [00:00<?, ? examples/s]

## Compute Metrics

In [6]:
import jiwer
from tibetan_wer.metrics import wer, ser

# Extract predictions and references
predictions = [elt[0].replace(' ', '་') for elt in processed_test_dataset["prediction"]]
references = [elt for elt in processed_test_dataset["uni"]]

# Compute metrics
cer = jiwer.cer(predictions, references)
ser = ser(predictions, references)['micro_ser']
wer = wer(predictions, references)['micro_wer']

print(f"Character Error Rate (CER): {cer}")
print(f"Syllable Error Rate (SER): {ser}")
print(f"Word Error Rate (WER): {wer}")

  warn(


Character Error Rate (CER): 0.24396886693369646
Syllable Error Rate (SER): 0.5381160584158711
Word Error Rate (WER): 0.5547484822202949


In [None]:
import pickle

with open('baseline_preds.pickle', 'wb') as f:
    pickle.dump(f, predictions)