In [1]:
from datasets import load_from_disk

ds = load_from_disk("../Data/garchen_dataset")['validation']
ds

Dataset({
    features: ['file_name', 'uni', 'wylie', 'url', 'dept', 'grade', 'char_len', 'audio_len', '__index_level_0__', 'audio', 'transcript'],
    num_rows: 3131
})

In [3]:
from transformers import Wav2Vec2FeatureExtractor, AutoTokenizer, Wav2Vec2Processor, Wav2Vec2ForCTC, Wav2Vec2Config

feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("openpecha/Garchen_Rinpoche_stt")
tokenizer = AutoTokenizer.from_pretrained("openpecha/Garchen_Rinpoche_stt")
processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)

config = Wav2Vec2Config.from_pretrained("openpecha/Garchen_Rinpoche_stt")

model = Wav2Vec2ForCTC.from_pretrained(
    "../Checkpoints/nicttib1-pre-ft/model", 
    config=config
)

model.to('cuda:0')

Wav2Vec2ForCTC(
  (wav2vec2): Wav2Vec2Model(
    (feature_extractor): Wav2Vec2FeatureEncoder(
      (conv_layers): ModuleList(
        (0): Wav2Vec2LayerNormConvLayer(
          (conv): Conv1d(1, 512, kernel_size=(10,), stride=(5,))
          (layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (activation): GELUActivation()
        )
        (1-4): 4 x Wav2Vec2LayerNormConvLayer(
          (conv): Conv1d(512, 512, kernel_size=(3,), stride=(2,))
          (layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (activation): GELUActivation()
        )
        (5-6): 2 x Wav2Vec2LayerNormConvLayer(
          (conv): Conv1d(512, 512, kernel_size=(2,), stride=(2,))
          (layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (activation): GELUActivation()
        )
      )
    )
    (feature_projection): Wav2Vec2FeatureProjection(
      (layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
      (projec

In [4]:
import torch
import numpy as np

def generate_predictions(batch):
    # Load and resample the audio
    audio = batch["audio"]
    inputs = processor(
        audio["array"], 
        sampling_rate=audio["sampling_rate"], 
        return_tensors="pt",
        padding=True,
    ).input_values.to("cuda")

    # Generate logits and get argmax predictions
    with torch.no_grad():
        logits = model(inputs).logits
        predicted_ids = torch.argmax(logits, dim=-1)

    # Decode predictions to text
    batch["prediction"] = processor.batch_decode(predicted_ids, skip_special_tokens=True)
    return batch

# Apply the function to the test dataset
processed_test_dataset = ds.map(generate_predictions)



Map:   0%|          | 0/3131 [00:00<?, ? examples/s]

In [5]:
import jiwer
from tibetan_wer.metrics import wer, ser

# Extract predictions and references
predictions = [elt[0].replace(' ', '་') for elt in processed_test_dataset["prediction"]]
references = [elt for elt in processed_test_dataset["uni"]]

# Compute metrics
cer = jiwer.cer(predictions, references)
ser = ser(predictions, references)['micro_ser']
wer = wer(predictions, references)['micro_wer']

print(f"Character Error Rate (CER): {cer}")
print(f"Syllable Error Rate (SER): {ser}")
print(f"Word Error Rate (WER): {wer}")

Character Error Rate (CER): 0.29316734677785006
Syllable Error Rate (SER): 0.5647187026707546
Word Error Rate (WER): 0.5978872139342342


In [6]:
import pickle

with open('1_ep_nicttib1_ft_val_preds.pickle', 'wb') as f:
    pickle.dump(predictions, f)