In [1]:
import librosa
import torch
import os
import argparse

from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC, Wav2Vec2ProcessorWithLM
from tqdm import tqdm
from pyctcdecode import build_ctcdecoder

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
class Inferencer:
    """Wav2Vec2 CTC speech recognizer with KenLM n-gram beam-search decoding.

    Audio is resampled to ``SAMPLE_RATE`` Hz on load and fed through a
    ``Wav2Vec2ForCTC`` acoustic model. The per-frame logits are decoded by a
    ``pyctcdecode`` beam-search decoder backed by a KenLM language model.
    """

    # Rate used both when loading audio with librosa and when calling the feature extractor.
    SAMPLE_RATE = 16000

    def __init__(self, device, huggingface_folder, w2v_model_path, kenlm_model_path, alpha=1.0):
        """
        Args:
            device: torch device string the model runs on (e.g. "cuda:0" or "cpu").
            huggingface_folder: directory (or hub id) holding the processor and model files
                loaded with ``from_pretrained``.
            w2v_model_path: optional path to a state dict that overwrites the weights loaded
                from ``huggingface_folder``; pass ``None`` to skip.
            kenlm_model_path: path to the KenLM language model (e.g. an ``.arpa`` file).
            alpha: language-model weight forwarded to ``build_ctcdecoder``.
        """
        self.device = device
        self.processor = Wav2Vec2Processor.from_pretrained(huggingface_folder)
        self.decoder = build_ctcdecoder(
            labels=self._decoder_labels(self.processor.tokenizer.get_vocab()),
            kenlm_model_path=kenlm_model_path,
            alpha=alpha,
        )
        self.model = Wav2Vec2ForCTC.from_pretrained(huggingface_folder).to(self.device)
        if w2v_model_path is not None:
            self.preload_model(w2v_model_path)
        # Inference only: make sure dropout and similar training-time layers are disabled.
        self.model.eval()

    @staticmethod
    def _decoder_labels(vocab):
        """Build the CTC label list for pyctcdecode from a tokenizer vocabulary.

        Tokens are ordered by id, lower-cased, and "<s>"/"</s>" are dropped. If two tokens
        collide after lower-casing, only the first position is kept.

        NOTE(review): pyctcdecode maps labels to logit columns by position. Dropping
        "<s>"/"</s>" is only correct if they are the trailing ids (or absent from the model
        output). TODO confirm against the model's vocab_size.
        """
        ordered = sorted(vocab.items(), key=lambda item: item[1])
        return list({token.lower(): idx for token, idx in ordered if token not in ("<s>", "</s>")})

    def preload_model(self, model_path) -> None:
        """Overwrite the model's parameters with a state dict saved via ``torch.save``.

        Args:
            model_path: path to a file holding a plain ``state_dict`` (e.g. ``pytorch_model.bin``).

        Raises:
            FileNotFoundError: if ``model_path`` does not exist.
        """
        # An explicit check instead of `assert`, which is stripped when Python runs with -O.
        if not os.path.exists(model_path):
            raise FileNotFoundError(f"The file {model_path} is not exist. please check path.")
        # torch.load unpickles the file: only load checkpoints from trusted sources.
        checkpoint = torch.load(model_path, map_location=self.device)
        self.model.load_state_dict(checkpoint, strict=True)
        print(f"Model preloaded successfully from {model_path}.")

    def transcribe(self, wav) -> str:
        """Transcribe one waveform sampled at ``SAMPLE_RATE`` Hz.

        Args:
            wav: 1-D audio samples, presumably the mono float array returned by
                ``librosa.load`` — see ``run``.

        Returns:
            The beam-search (KenLM) decoded transcript.
        """
        input_values = self.processor(
            wav, sampling_rate=self.SAMPLE_RATE, return_tensors="pt"
        ).input_values
        # No gradients are needed for inference; skipping autograd saves memory and time.
        with torch.no_grad():
            logits = self.model(input_values.to(self.device)).logits
        # Decode the first (and only) item of the batch with the KenLM-backed beam search.
        return self.decoder.decode(logits[0].cpu().numpy())

    @staticmethod
    def _transcript_path(test_filepath):
        """Return ``transcript_<basename>`` located in the same directory as ``test_filepath``."""
        directory, basename = os.path.split(test_filepath)
        return os.path.join(directory, "transcript_" + basename)

    def run(self, test_filepath):
        """Transcribe one audio file, or every audio path listed in a ``.txt`` file.

        For a ``.txt`` input, each non-blank line is treated as an audio path. Results are
        written as ``"<path> <transcript>"`` lines to ``transcript_<name>.txt`` next to the
        input file. Any other input is loaded as audio and its transcript is printed.
        """
        if os.path.splitext(test_filepath)[1].lower() == ".txt":
            with open(test_filepath, "r", encoding="utf-8") as src:
                # Skip blank lines (e.g. a trailing newline) instead of crashing in librosa.load.
                audio_paths = [line for line in src.read().splitlines() if line.strip()]

            with open(self._transcript_path(test_filepath), "w", encoding="utf-8") as out:
                for audio_path in tqdm(audio_paths):
                    wav, _ = librosa.load(audio_path, sr=self.SAMPLE_RATE)
                    out.write(audio_path + " " + self.transcribe(wav) + "\n")
        else:
            wav, _ = librosa.load(test_filepath, sr=self.SAMPLE_RATE)
            print(f"transcript: {self.transcribe(wav)}")

In [None]:
# Index of the CUDA GPU to use; the CPU is chosen instead when CUDA is unavailable.
device_id = 0
device = "cpu"
if torch.cuda.is_available():
    device = f"cuda:{device_id}"

In [None]:
# Build the recognizer: processor and model come from the local "custom_model" folder,
# and beam search is rescored with a 5-gram KenLM model.
# NOTE(review): from_pretrained already loads weights from "custom_model"; passing
# w2v_model_path then overwrites them from custom_model/pytorch_model.bin. This is
# redundant if that is the same file from_pretrained picked — confirm this is intended.
inferencer = Inferencer(
    device=device,
    huggingface_folder="custom_model",
    w2v_model_path="custom_model/pytorch_model.bin",
    kenlm_model_path="5gram_correct.arpa",
)