In [1]:
from espnet2.tasks.asr_TruCLeS import ASRTask
from pathlib import Path
import yaml
import argparse
import torch
import torchaudio
import os
from espnet2.text.build_tokenizer import build_tokenizer
from espnet2.text.token_id_converter import TokenIDConverter

In [2]:
test_dataset = torchaudio.datasets.LIBRISPEECH("/home/pb_deployment/Downloads", url="test-clean", download=True)

# Loading ASR model for inference

In [3]:
model_root_dir = "/home/pb_deployment/espnet/asr_inference/model_files/e_branchformer_librispeech/"
asr_config_path = os.path.join(model_root_dir, "exp/asr_train_asr_e_branchformer_raw_en_bpe5000_sp/config.yaml")
model_path = os.path.join(model_root_dir, "exp/asr_train_asr_e_branchformer_raw_en_bpe5000_sp/valid.acc.ave_10best.pth")
bpe_model_path = os.path.join(model_root_dir,"data/en_token_list/bpe_unigram5000/bpe.model")

In [4]:
def buildAndLoadASRModel(asr_config, asr_model, device="cpu"):
    config_file = Path(asr_config)
    with config_file.open("r", encoding="utf-8") as f:
        args = yaml.safe_load(f)
    
    args = argparse.Namespace(**args)
    model = ASRTask.build_model(args)
    
    state_dict = torch.load(asr_model, map_location=device)
    
    
    model.to(device)
    use_lora = getattr(args, "use_lora", False)
    model.load_state_dict(state_dict, strict=not use_lora)
    if any(["frontend.upstream.model" in k for k in state_dict.keys()]):
        if any(
            [
                "frontend.upstream.upstream.model" in k
                for k in dict(model.named_parameters())
                            ]
        ):
            state_dict = {
                k.replace(
                    "frontend.upstream.model",
                    "frontend.upstream.upstream.model",
                ): v
                for k, v in state_dict.items()
            }
            model.load_state_dict(state_dict, strict=not use_lora)
    return model

In [5]:
class TruCLeS_Tokenizer:
    def __init__(self, asr_config_path, bpe_model_path):
        self.asr_config_path = asr_config_path
        self.bpe_model_path = bpe_model_path
        self.tokenizer = build_tokenizer(
            token_type='bpe',
            bpemodel=bpe_model_path  
        )
        with open(self.asr_config_path, 'r') as file:
            config_data = yaml.safe_load(file)
        self.tokens_list = config_data.get('token_list', [])
        self.tokenIDConvertor = TokenIDConverter(token_list = self.tokens_list)

        self.ids = dict(zip(self.tokens_list, [i for i in range(len(self.tokens_list))]))

    def text2ids(self, text):
        tokenized = self.tokenizer.text2tokens(text)
        ids = self.tokenIDConvertor.tokens2ids(tokenized)
        return ids
    
    def ids2text(self, ids):
        tokenized = self.tokenIDConvertor.ids2tokens(ids)
        text = self.tokenizer.tokens2text(tokenized)
        return text

    def text2tokens(self, text):
        return self.tokenizer.text2tokens(text)
        
    def tokens2text(self, tokens):
        return self.tokenizer.tokens2text(tokens)

    def tokens2ids(self, tokens):
        return self.tokenIDConvertor.tokens2ids(tokens)

    def ids2tokens(self, ids):
        return self.tokenIDConvertor.ids2tokens(ids)

In [6]:
ASRModel = buildAndLoadASRModel(asr_config_path, model_path, device="cuda")
tokenizer = TruCLeS_Tokenizer(asr_config_path, bpe_model_path)

In [7]:
speech = test_dataset[0][0]
speech_length = speech.new_full([1], dtype=torch.long, fill_value=speech.size(1))

text = test_dataset[0][2]
print(text)
text = torch.tensor(tokenizer.text2ids(text)).unsqueeze(0)
text_length = torch.tensor([text.shape[-1]])

print(speech.shape, speech_length, text, text_length)

HE HOPED THERE WOULD BE STEW FOR DINNER TURNIPS AND CARROTS AND BRUISED POTATOES AND FAT MUTTON PIECES TO BE LADLED OUT IN THICK PEPPERED FLOUR FATTENED SAUCE
torch.Size([1, 166960]) tensor([166960]) tensor([[  10, 2668,   53,   60,   26,  996,  242,   20, 1002,  552, 1118,    3,
            4,  409,  251,   22,    3,    4,  434,  389, 2117, 4840,    4, 1839,
          354,  334,  292, 1852,    6,   26, 1218,  341,   65,    8, 1049, 3779,
           13, 2551, 1839,  833,   13, 3753]]) tensor([42])


```


encoder_out, encoder_out_lens = ASRModel.encode(speech: torch.Tensor, speech_lengths: torch.Tensor)

ys_hat, decoder_out, softmax_out = ASRModel.decode_TruCLeS(
                                        encoder_out: torch.Tensor,
                                        encoder_out_lens: torch.Tensor,
                                        text: torch.Tensor,
                                        text_lengths: torch.Tensor,
                                    )
```    

In [8]:
DEVICE = "cuda"
ASRModel.eval()
encoder_out, encoder_out_lens = ASRModel.encode(speech.to(DEVICE), speech_length.to(DEVICE))

In [9]:
hyp, decoder_out= ASRModel.decode_TruCLeS(encoder_out, encoder_out_lens, text.to(DEVICE), text_length.to(DEVICE))

In [12]:
print(hyp)

tensor([[  10, 2668,   53,   60,   26,  996,  242,   20, 1002,  552, 1118,    3,
            4,  409,  251,   22,    3,    4,  434,  389, 2117, 4840,    4, 1839,
          354,  334,  292, 1852,    6,   26, 1218,  341,   65,    8, 1049, 3779,
           13, 2551, 1839,  833,   13, 3753, 4999]], device='cuda:0')


In [13]:
print(tokenizer.ids2text(ys_hat[0]))

HE HOPED THERE WOULD BE STEW FOR DINNER TURNIPS AND CARROTS AND BRUISED POTATOES AND FAT MUTTON PIECES TO BE LADLED OUT IN THICK PEPPERED FLOUR FATTENED SAUCE<sos/eos>
