In [22]:
from espnet2.tasks.asr_TruCLeS import ASRTask
from pathlib import Path
import yaml
import argparse
import torch
import torchaudio
import os

import jiwer
from espnet2.text.build_tokenizer import build_tokenizer
from espnet2.text.token_id_converter import TokenIDConverter

In [2]:
test_dataset = torchaudio.datasets.LIBRISPEECH("/home/pb_deployment/Downloads", url="test-clean", download=True)

In [3]:
model_root_dir = "/home/pb_deployment/espnet/asr_inference/model_files/e_branchformer_librispeech/"
asr_config_path = os.path.join(model_root_dir, "exp/asr_train_asr_e_branchformer_raw_en_bpe5000_sp/config.yaml")
model_path = os.path.join(model_root_dir, "exp/asr_train_asr_e_branchformer_raw_en_bpe5000_sp/valid.acc.ave_10best.pth")
bpe_model_path = os.path.join(model_root_dir,"data/en_token_list/bpe_unigram5000/bpe.model")

# Tokenizer definition

In [4]:
class TruCLeS_Tokenizer:
    def __init__(self, asr_config_path, bpe_model_path):
        self.asr_config_path = asr_config_path
        self.bpe_model_path = bpe_model_path
        self.tokenizer = build_tokenizer(
            token_type='bpe',
            bpemodel=bpe_model_path  
        )
        with open(self.asr_config_path, 'r') as file:
            config_data = yaml.safe_load(file)
        self.tokens_list = config_data.get('token_list', [])
        self.tokenIDConvertor = TokenIDConverter(token_list = self.tokens_list)

        self.ids = dict(zip(self.tokens_list, [i for i in range(len(self.tokens_list))]))

    def text2ids(self, text):
        tokenized = self.tokenizer.text2tokens(text)
        ids = self.tokenIDConvertor.tokens2ids(tokenized)
        return ids
    
    def ids2text(self, ids):
        tokenized = self.tokenIDConvertor.ids2tokens(ids)
        text = self.tokenizer.tokens2text(tokenized)
        return text

    def text2tokens(self, text):
        return self.tokenizer.text2tokens(text)
        
    def tokens2text(self, tokens):
        return self.tokenizer.tokens2text(tokens)

    def tokens2ids(self, tokens):
        return self.tokenIDConvertor.tokens2ids(tokens)

    def ids2tokens(self, ids):
        return self.tokenIDConvertor.ids2tokens(ids)

# Loading ASR model for inference

In [5]:
def buildAndLoadASRModel(asr_config, asr_model, device="cpu"):
    config_file = Path(asr_config)
    with config_file.open("r", encoding="utf-8") as f:
        args = yaml.safe_load(f)
    
    args = argparse.Namespace(**args)
    model = ASRTask.build_model(args)
    
    state_dict = torch.load(asr_model, map_location=device)
    
    
    model.to(device)
    use_lora = getattr(args, "use_lora", False)
    model.load_state_dict(state_dict, strict=not use_lora)
    if any(["frontend.upstream.model" in k for k in state_dict.keys()]):
        if any(
            [
                "frontend.upstream.upstream.model" in k
                for k in dict(model.named_parameters())
                            ]
        ):
            state_dict = {
                k.replace(
                    "frontend.upstream.model",
                    "frontend.upstream.upstream.model",
                ): v
                for k, v in state_dict.items()
            }
            model.load_state_dict(state_dict, strict=not use_lora)
    return model

In [6]:
ASRModel = buildAndLoadASRModel(asr_config_path, model_path, device="cuda")
tokenizer = TruCLeS_Tokenizer(asr_config_path, bpe_model_path)

In [15]:
speech = test_dataset[0][0]
speech_length = speech.new_full([1], dtype=torch.long, fill_value=speech.size(1))

#text = test_dataset[0][2]

text = "HE HOPED THERE WOULD BE STEW FOR DINNER CARROTS AND BRUISED POTATOES AND FAT BUTTON PIECES TO BE LABLED OUT IN THICK FLOWER FLATTENED SAUCE"
print(text)
text = torch.tensor(tokenizer.text2ids(text)).unsqueeze(0)
text_length = torch.tensor([text.shape[-1]])

print(speech.shape, speech_length, text, text_length)

tensor([[  10, 2668,   53,   60,   26,  996,  242,   20, 1002,  409,  251,   22,
            3,    4,  434,  389, 2117, 4840,    4, 1839,   27,  292, 1852,    6,
           26,  370,  721,   30,   65,    8, 1049, 1849, 2005,  833,   13, 3753]])
torch.Size([1, 166960]) tensor([166960]) tensor([[  10, 2668,   53,   60,   26,  996,  242,   20, 1002,  409,  251,   22,
            3,    4,  434,  389, 2117, 4840,    4, 1839,   27,  292, 1852,    6,
           26,  370,  721,   30,   65,    8, 1049, 1849, 2005,  833,   13, 3753]]) tensor([36])


```


encoder_out, encoder_out_lens = ASRModel.encode(speech: torch.Tensor, speech_lengths: torch.Tensor)

ys_hat, decoder_out, softmax_out = ASRModel.decode_TruCLeS(
                                        encoder_out: torch.Tensor,
                                        encoder_out_lens: torch.Tensor,
                                        text: torch.Tensor,
                                        text_lengths: torch.Tensor,
                                    )
```    

In [16]:
DEVICE = "cuda"
ASRModel.eval()
encoder_out, encoder_out_lens = ASRModel.encode(speech.to(DEVICE), speech_length.to(DEVICE))

In [17]:
hyps, hyp_lens, decoder_outs= ASRModel.decode_TruCLeS(encoder_out, encoder_out_lens, text.to(DEVICE), text_length.to(DEVICE))

In [18]:
print(hyps)
print(hyp_lens)

tensor([[  10, 2668,   53,   60,   26,  996,  242,   20, 1002,  552,  251,   22,
            3,    4,  409,  389, 2117, 4840,    4, 1839,  354,  292, 1852,    6,
           26, 1218,  341,   30,   65,    8, 1049, 3779, 1839,  833,   13, 3753,
         4999]], device='cuda:0')
tensor([36], device='cuda:0')


In [45]:
hyp = tokenizer.ids2text(hyps[0][:hyp_lens[0]])
ref = test_dataset[0][2]
hyp = "HE HOPED THERE WOULD BE STEW FOR DINNER TURNIPS AND CARROTS POTATOES AND FAT CHICKEN MUTTON PIECES TO BE LADLED OUT IN THICK PAPERED FLOUR FATTENED SAUCE GRAVY"
print(hyp)
print(hyp_lens)


HE HOPED THERE WOULD BE STEW FOR DINNER TURNIPS AND CARROTS POTATOES AND FAT CHICKEN MUTTON PIECES TO BE LADLED OUT IN THICK PAPERED FLOUR FATTENED SAUCE GRAVY
tensor([36], device='cuda:0')


In [46]:
print(jiwer.visualize_alignment(jiwer.process_words(ref, hyp), show_measures=False, skip_correct=False))
print(tokenizer.text2tokens(ref))
print(tokenizer.text2tokens(hyp))

sentence 1
REF: HE HOPED THERE WOULD BE STEW FOR DINNER TURNIPS AND CARROTS AND BRUISED POTATOES AND FAT ******* MUTTON PIECES TO BE LADLED OUT IN THICK PEPPERED FLOUR FATTENED SAUCE *****
HYP: HE HOPED THERE WOULD BE STEW FOR DINNER TURNIPS AND CARROTS *** ******* POTATOES AND FAT CHICKEN MUTTON PIECES TO BE LADLED OUT IN THICK  PAPERED FLOUR FATTENED SAUCE GRAVY
                                                                   D       D                        I                                                S                          I

['▁HE', '▁HOPED', '▁THERE', '▁WOULD', '▁BE', '▁STE', 'W', '▁FOR', '▁DINNER', '▁TURN', 'IP', 'S', '▁AND', '▁CAR', 'RO', 'T', 'S', '▁AND', '▁B', 'RU', 'ISED', '▁POTATOES', '▁AND', '▁FAT', '▁M', 'UT', 'TON', '▁PIECES', '▁TO', '▁BE', '▁LAD', 'LED', '▁OUT', '▁IN', '▁THICK', '▁PEPPER', 'ED', '▁FLOUR', '▁FAT', 'TEN', 'ED', '▁SAUCE']
['▁HE', '▁HOPED', '▁THERE', '▁WOULD', '▁BE', '▁STE', 'W', '▁FOR', '▁DINNER', '▁TURN', 'IP', 'S', '▁AND', '▁CAR', 'RO', 'T', 'S

In [47]:
def getAlignedHyp(ref, hyp, tokenizer=tokenizer):
    wOut = jiwer.process_words(ref, hyp)
    wRefList = wOut.references[0]
    wHypList = wOut.hypotheses[0]
    wAligns = wOut.alignments[0]
    tokenizedWRefList = tokenizer.text2tokens(wRefList)
    tokenizedWHypList = tokenizer.text2tokens(wHypList)
    
    align_list = []
    for wAlign in wAligns:
        if wAlign.type == "equal":
            align_list = align_list + tokenizedWHypList[wAlign.hyp_start_idx:wAlign.hyp_end_idx]
        elif wAlign.type == "insert":
            align_list.append([])
        elif wAlign.type == "substitute":
            hypWords = tokenizedWHypList[wAlign.hyp_start_idx:wAlign.hyp_end_idx]
            refWords = tokenizedWRefList[wAlign.ref_start_idx:wAlign.ref_end_idx]
            word_list = []
            for hypWord, refWord in zip(hypWords, refWords):
                tOut = jiwer.process_words(" ".join(refWord), " ".join(hypWord))
                tRefList = tOut.references[0]
                tHypList = tOut.hypotheses[0]
                tAligns = tOut.alignments[0]
                for tAlign in tAligns:
                    if tAlign.type == "equal":
                        word_list = word_list + tHypList[tAlign.hyp_start_idx:tAlign.hyp_end_idx]
                    elif tAlign.type == "substitute":
                        word_list = word_list + tRefList[tAlign.ref_start_idx:tAlign.ref_end_idx]
                    elif tAlign.type == "insert":
                        word_list.append("")
            align_list = align_list + [word_list]
    return align_list

In [48]:
align_list = getAlignedHyp(ref, hyp)
print(align_list)

[['▁HE'], ['▁HOPED'], ['▁THERE'], ['▁WOULD'], ['▁BE'], ['▁STE', 'W'], ['▁FOR'], ['▁DINNER'], ['▁TURN', 'IP', 'S'], ['▁AND'], ['▁CAR', 'RO', 'T', 'S'], ['▁POTATOES'], ['▁AND'], ['▁FAT'], [], ['▁M', 'UT', 'TON'], ['▁PIECES'], ['▁TO'], ['▁BE'], ['▁LAD', 'LED'], ['▁OUT'], ['▁IN'], ['▁THICK'], ['▁PEPPER', 'ED'], ['▁FLOUR'], ['▁FAT', 'TEN', 'ED'], ['▁SAUCE'], []]


In [38]:
i = 0
for word in align_list:
    if len(word) != 0:
        for tok in word:
            if tok!="":
                print(i,tok, 1)
            else:
                print(i,"no_tok", 0)
            i += 1
    else:
        print("no_word", 0)

0 ▁HE 1
1 ▁HOPED 1
2 ▁THERE 1
3 ▁WOULD 1
4 ▁BE 1
5 ▁STE 1
6 W 1
7 ▁FOR 1
8 ▁DINNER 1
9 ▁TURN 1
10 no_tok 0
11 IP 1
12 S 1
13 ▁AND 1
14 ▁CAR 1
15 RO 1
16 T 1
17 ▁POTATOES 1
18 ▁AND 1
19 ▁FAT 1
20 ▁M 1
21 TON 1
22 ▁PIECES 1
23 ▁TO 1
24 ▁BE 1
25 ▁LAD 1
26 LED 1
27 no_tok 0
28 ▁OUT 1
29 ▁IN 1
30 ▁THICK 1
31 ▁PEPPER 1
32 ▁FAT 1
33 TEN 1
34 ED 1
35 ▁SAUCE 1


In [39]:
final_toks = []
for word in align_list:
    for tok in word:
        final_toks.append(tok)
print(final_toks)

['▁HE', '▁HOPED', '▁THERE', '▁WOULD', '▁BE', '▁STE', 'W', '▁FOR', '▁DINNER', '▁TURN', '', 'IP', 'S', '▁AND', '▁CAR', 'RO', 'T', '▁POTATOES', '▁AND', '▁FAT', '▁M', 'TON', '▁PIECES', '▁TO', '▁BE', '▁LAD', 'LED', '', '▁OUT', '▁IN', '▁THICK', '▁PEPPER', '▁FAT', 'TEN', 'ED', '▁SAUCE']


In [23]:
def getAlignedScore(ref, hyp, softmax, tokenizer=tokenizer):
    wOut = jiwer.process_words(ref, hyp)
    wRefList = wOut.references[0]
    wHypList = wOut.hypotheses[0]
    wAligns = wOut.alignments[0]
    tokenizedWRefList = tokenizer.text2tokens(wRefList)
    tokenizedWHypList = tokenizer.text2tokens(wHypList)
    
    align_list = []
    for wAlign in wAligns:
        if wAlign.type == "equal":
            align_list = align_list + tokenizedWHypList[wAlign.hyp_start_idx:wAlign.hyp_end_idx]
        elif wAlign.type == "insert":
            align_list.append([])
        elif wAlign.type == "substitute":
            hypWords = tokenizedWHypList[wAlign.hyp_start_idx:wAlign.hyp_end_idx]
            refWords = tokenizedWRefList[wAlign.ref_start_idx:wAlign.ref_end_idx]
            word_list = []
            for hypWord, refWord in zip(hypWords, refWords):
                tOut = jiwer.process_words(" ".join(refWord), " ".join(hypWord))
                tRefList = tOut.references[0]
                tHypList = tOut.hypotheses[0]
                tAligns = tOut.alignments[0]
                for tAlign in tAligns:
                    if tAlign.type == "equal":
                        word_list = word_list + tHypList[tAlign.hyp_start_idx:tAlign.hyp_end_idx]
                    elif tAlign.type == "substitute":
                        word_list = word_list + tRefList[tAlign.ref_start_idx:tAlign.ref_end_idx]
                    elif tAlign.type == "insert":
                        word_list.append("")
            align_list = align_list + [word_list]
    return align_list