In [None]:
# Copyright 2021 Dialpad, Inc. (Shreekantha Nadig, Riqiang Wang)
#  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)

In [1]:
from espnet.asr.pytorch_backend.asr_init import load_trained_model
from espnet.bin.asr_recog import get_parser
from espnet.asr.asr_utils import parse_hypothesis
from espnet.asr.asr_utils import get_model_conf
import espnet.nets.pytorch_backend.lm.default as lm_pytorch
from espnet.asr.asr_utils import torch_load
import espnet.lm.pytorch_backend.extlm as extlm_pytorch
from espnet.utils.deterministic_utils import set_deterministic_pytorch
from espnet.transform.cmvn import CMVN
import torchaudio
import torch
import numpy as np
import resampy
import math
import kaldiio
import numpy as np
import scipy.io.wavfile as wav
import wave
import array
import time
import matplotlib.pyplot as plt
from glob import glob

In [2]:
cmvn_stats_file = "mucs_2021_models/b0/is21_challenge/data/task1/train_combined/cmvn_default.ark"
cmvn = CMVN(cmvn_stats_file, norm_vars=True)

# Load model

In [3]:
model_path = "mucs_2021_models/b0/exp/train_enc_dec_multilingual_default_large/results/model.acc.best"
decode_config = "mucs_2021_models/b0/conf/decode.yaml"
model, train_args = load_trained_model(model_path)
model.eval()



E2E(
  (enc): Encoder(
    (enc): ModuleList(
      (0): VGG2L(
        (conv1_1): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (conv1_2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (conv2_1): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (conv2_2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
      (1): RNNP(
        (birnn0): LSTM(2560, 1024, batch_first=True, bidirectional=True)
        (bt0): Linear(in_features=2048, out_features=1024, bias=True)
        (birnn1): LSTM(1024, 1024, batch_first=True, bidirectional=True)
        (bt1): Linear(in_features=2048, out_features=1024, bias=True)
        (birnn2): LSTM(1024, 1024, batch_first=True, bidirectional=True)
        (bt2): Linear(in_features=2048, out_features=1024, bias=True)
        (birnn3): LSTM(1024, 1024, batch_first=True, bidirectional=True)
        (bt3): Linear(in_features=2048, out_features=1024, bias=Tru

# Load RNNLM

In [4]:
rnnlm_path = "mucs_2021_models/rnnlms/combined/rnnlm/rnnlm.model.best"
rnnlm_opts = f"--rnnlm {rnnlm_path}"
rnnlm_args = get_model_conf(rnnlm_path)
rnnlm = lm_pytorch.ClassifierWithState(
    lm_pytorch.RNNLM(
        len(rnnlm_args.char_list_dict),
        rnnlm_args.layer,
        rnnlm_args.unit,
        getattr(rnnlm_args, "embed_unit", None),  # for backward compatibility
    )
)
torch_load(rnnlm_path, rnnlm)
rnnlm.eval()

ClassifierWithState(
  (lossfun): CrossEntropyLoss()
  (predictor): RNNLM(
    (embed): Embedding(304, 1024)
    (rnn): ModuleList(
      (0): LSTMCell(1024, 1024)
      (1): LSTMCell(1024, 1024)
    )
    (dropout): ModuleList(
      (0): Dropout(p=0.5, inplace=False)
      (1): Dropout(p=0.5, inplace=False)
      (2): Dropout(p=0.5, inplace=False)
    )
    (lo): Linear(in_features=1024, out_features=304, bias=True)
  )
)

In [5]:
parser = get_parser()
rnnlm_opts = ""
rnnlm=rnnlm
args = parser.parse_args(f'--config {decode_config} \
                          --ngpu 1 --backend pytorch --batchsize 1 --result-label results.json \
                          --model {model_path} \
                          {rnnlm_opts} \
                          --api v1')
set_deterministic_pytorch(args)
model.recog_args = args

In [6]:
args.beam_size = 10
args.lm_weight = 0.2
args.ctc_weight = 0.5
args.nbest = 1
args.verbose = 4
args.debugmode = 2

In [7]:
audio_files = glob("downloads/hindi/test/audio/*.wav")
audio_file = audio_files[0]

# Extract features

In [8]:
(signal, rate) = torchaudio.load(audio_file)
current_signal = resampy.resample(signal[0].numpy(), rate, 8000, axis=0)
lmspc = torchaudio.compliance.kaldi.fbank(
            waveform=torch.unsqueeze(torch.tensor(current_signal), axis=0),
            sample_frequency=8000,
            dither=1e-32,
            energy_floor=0,
            num_mel_bins=80,
        )

# Apply CMVN

In [9]:
normed_feats = cmvn(lmspc)
normed_feats = torch.from_numpy(normed_feats.numpy())

# Perform inference

In [10]:
with torch.no_grad():
    feat = (
    [normed_feats]
    )
    nbest_hyps = model.recognize_batch(feat, args, char_list=train_args.char_list, rnnlm=rnnlm)

# Print hypothesis

In [11]:
hypothesis = ''.join([train_args.char_list[ele] for ele in nbest_hyps[0][0]['yseq']])
print(f'HYPO for {audio_file}: ', hypothesis)

HYPO for downloads/hindi/test/audio/4602_088.wav:  <eos>▁सेठ▁जी▁ने▁समझ▁लिया▁कि▁इस▁समय▁समझाने▁बुझाने▁से▁कुछ▁काम▁न▁चलेगा<eos>
