In [1]:
import os

# Hide every CUDA device from this process (empty device list) so torch runs
# the model on CPU. This must happen before torch is imported in a later cell.
os.environ['CUDA_VISIBLE_DEVICES'] = ''

In [9]:
# !wget https://huggingface.co/huseinzol05/language-model-bahasa-manglish-combined/resolve/main/model.klm

In [2]:
import torch

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
import transformers
from transformers import (
    HfArgumentParser,
    Trainer,
    TrainingArguments,
    Wav2Vec2CTCTokenizer,
    Wav2Vec2FeatureExtractor,
    Wav2Vec2ForCTC,
    Wav2Vec2Processor,
    is_apex_available,
    set_seed,
    AutoModelForCTC,
    TFWav2Vec2ForCTC,
    TFWav2Vec2PreTrainedModel,
    Wav2Vec2PreTrainedModel,
)

In [4]:
import string
import json

# Character vocabulary for CTC: index 0 is the empty string, then a-z, 0-9,
# and finally the space character (remapped to "|" in the next cell).
CTC_VOCAB = ['', *string.ascii_lowercase, *string.digits, ' ']

In [5]:
# Token -> id mapping in CTC_VOCAB order ('' is 0, a-z are 1-26, 0-9 are 27-36, ' ' is 37).
vocab_dict = {v: k for k, v in enumerate(CTC_VOCAB)}
# Use "|" as the word delimiter: it takes over the space's id and, being a new
# key, is appended at the END of the dict. The key order matters because
# `unique_vocab` is later built from list(vocab_dict.keys()).
vocab_dict["|"] = vocab_dict[" "]
del vocab_dict[" "]
# Special tokens take the next free ids (38 and 39), after "|".
vocab_dict["[UNK]"] = len(vocab_dict)
vocab_dict["[PAD]"] = len(vocab_dict)

# Persist the vocabulary so Wav2Vec2CTCTokenizer can load it from disk.
with open("ctc-vocab.json", "w") as vocab_file:
    json.dump(vocab_dict, vocab_file)

# [PAD] is later also used as the CTC blank for the beam-search decoder
# (ctc_token_idx=tokenizer.pad_token_id).
tokenizer = Wav2Vec2CTCTokenizer(
    "ctc-vocab.json",
    unk_token="[UNK]",
    pad_token="[PAD]",
    word_delimiter_token="|",
)

In [6]:
import os
from glob import glob


def sorted_wav_paths(directory):
    """Return the ``*.wav`` files in ``directory`` sorted by their integer file name.

    Clips are expected to be named ``<int>.wav`` (``0.wav``, ``1.wav``, ``10.wav``, ...);
    sorting numerically rather than lexically presumably lines clip ``i`` up with
    entry ``i`` of the matching ``<directory>.json`` transcript list (TODO confirm).

    The integer is taken from the file's basename, so the key works regardless of
    the path separator glob returns on the current OS.
    """
    return sorted(
        glob(f'{directory}/*.wav'),
        key=lambda path: int(os.path.splitext(os.path.basename(path))[0]),
    )


malay = sorted_wav_paths('malay-test')
singlish = sorted_wav_paths('singlish-test')
mandarin = sorted_wav_paths('mandarin-test')
len(malay), len(singlish), len(mandarin)

(765, 3579, 614)

In [11]:
def _read_json(path):
    """Load and return the JSON document stored at ``path``."""
    with open(path) as handle:
        return json.load(handle)


# Transcript lists for each language, read in the same order as the audio lists.
malay_label, singlish_label, mandarin_label = (
    _read_json(path)
    for path in ('malay-test.json', 'singlish-test.json', 'mandarin-test.json')
)

len(malay_label), len(singlish_label), len(mandarin_label)

(765, 3579, 614)

In [12]:
from sklearn.utils import shuffle

# Each language's audio list must pair 1:1 with its transcript list. A mismatch
# in one language could still sum to equal totals and silently misalign the pairs.
for lang_name, lang_audio, lang_labels in (
    ('malay', malay, malay_label),
    ('singlish', singlish, singlish_label),
    ('mandarin', mandarin, mandarin_label),
):
    assert len(lang_audio) == len(lang_labels), (
        f'{lang_name}: {len(lang_audio)} wav files vs {len(lang_labels)} labels'
    )

audio = malay + singlish + mandarin
labels = malay_label + singlish_label + mandarin_label
# Fixed seed so the evaluation order, and the preview batches below, are reproducible.
audio, labels = shuffle(audio, labels, random_state=0)
test_set = list(zip(audio, labels))
test_set[:10]

[('singlish-test/3460.wav',
  'the mother tongue language collections in libraries will also be enhanced'),
 ('singlish-test/44.wav',
  'doing an ankle rotation with the tree because he can'),
 ('singlish-test/2807.wav',
  'a massive field of trees designed in the shape of a giant q r code'),
 ('malay-test/247.wav',
  'agar bisa segera keluar dari ruangan maut pak dadi mulai keluar kelas'),
 ('singlish-test/1700.wav', 'it smelt like petroleum and that was disturbing'),
 ('mandarin-test/289.wav', 'nan dao shi wo pu tong hua bu biao zhun'),
 ('mandarin-test/435.wav', 'jiao wo jin tian xia wu san dian qu ji chang'),
 ('singlish-test/857.wav',
  'apart from new ideas singapore also needs a customise model to meets its social needs'),
 ('malay-test/154.wav', 'selepas lebih kurang tiga'),
 ('singlish-test/1370.wav', 'after all google search can only get you so far')]

In [22]:
import soundfile as sf
import numpy as np

def norm_audio(x):
    """Standardise a whole waveform to zero mean and (approximately) unit variance.

    ``1e-7`` is added to the variance so a constant (e.g. silent) signal does not
    divide by zero.
    """
    centred = x - x.mean()
    scale = np.sqrt(x.var() + 1e-7)
    return centred / scale

def sequence_1d(
    seq, maxlen=None, padding: str = 'post', pad_int=0, return_len=False
):
    """Pad a batch of 1-D sequences to a common length and stack them into an array.

    Args:
        seq: iterable of lists or numpy arrays (arrays are converted with ``tolist``).
        maxlen: target length; a falsy value (None or 0) means "longest item in seq".
        padding: ``'post'`` appends ``pad_int`` after each item, ``'pre'`` prepends it.
        pad_int: value used for padding.
        return_len: when True, also return the list of original item lengths.

    Returns:
        ``np.array`` of padded rows, or ``(array, lengths)`` when ``return_len``.

    Raises:
        ValueError: if ``padding`` is not ``'post'`` or ``'pre'``.
    """
    if padding not in ['post', 'pre']:
        raise ValueError('padding only supported [`post`, `pre`]')

    target = maxlen if maxlen else max(len(item) for item in seq)

    rows, lengths = [], []
    for item in seq:
        values = item.tolist() if isinstance(item, np.ndarray) else item
        fill = [pad_int] * (target - len(values))
        rows.append(values + fill if padding == 'post' else fill + values)
        lengths.append(len(values))

    padded = np.array(rows)
    return (padded, lengths) if return_len else padded

def batching(audios):
    """Load, zero-pad and normalise a batch of waveforms for the wav2vec2 model.

    Each waveform is standardised over its real samples only (zero mean, unit
    variance, with ``1e-7`` added to the variance). The padded tail stays at 0,
    matching the original per-row normalisation followed by re-zeroing the padding.

    Args:
        audios: sequence whose items are either numpy arrays of samples (used as-is)
            or anything ``soundfile.read`` accepts, e.g. file paths. Assumes mono
            audio (1-D sample arrays) -- TODO confirm for all test sets.

    Returns:
        ``(normed_input_values, attentions)``: a float32 array of shape
        ``(batch, max_len)`` and an int64 0/1 mask of the same shape, with 1 on
        real samples and 0 on padding.
    """
    waves = [a if isinstance(a, np.ndarray) else sf.read(a)[0] for a in audios]
    lens = np.array([len(w) for w in waves], dtype=np.int64)
    maxlen = int(lens.max()) if len(lens) else 0

    # Vectorised mask: position t of row b is valid iff t < lens[b].
    attentions = (np.arange(maxlen)[None, :] < lens[:, None]).astype(np.int64)

    normed_input_values = np.zeros((len(waves), maxlen), dtype=np.float32)
    for row, (wave, length) in enumerate(zip(waves, lens)):
        if length == 0:
            # Empty clip: keep the all-zero row instead of a NaN mean/var.
            continue
        valid = np.asarray(wave[:length], dtype=np.float64)
        normed_input_values[row, :length] = (valid - valid.mean()) / np.sqrt(valid.var() + 1e-7)

    return normed_input_values, attentions

In [25]:
# Load the fine-tuned wav2vec2 CTC checkpoint. pad_token_id and vocab_size are
# taken from the tokenizer built above so the output ids line up with
# ctc-vocab.json. NOTE(review): ctc_loss_reduction presumably has no effect here,
# since this notebook never passes labels and only reads the logits.
model = AutoModelForCTC.from_pretrained(
    './wav2vec2-mixed/checkpoint-60000',
    ctc_loss_reduction="mean",
    pad_token_id=tokenizer.pad_token_id,
    vocab_size=len(tokenizer),
)

In [26]:
# Put the model in evaluation mode for inference; assigning to `_` keeps the
# returned module from being displayed as cell output.
_ = model.eval()

In [23]:
# Smoke test on the first `batch_size` clips of the shuffled set before the full
# evaluation loop (which reuses batch_size).
batch_size = 4
batch_x = audio[:batch_size]
normed_input_values, attentions = batching(batch_x)

In [70]:
# Imported here as well: the notebook's other log_softmax import lives in the
# later decoder cell, so this cell failed on Restart & Run All.
from scipy.special import log_softmax

# Forward pass on the preview batch. No autograd graph is needed for inference.
# batching() already returns float32 inputs.
with torch.no_grad():
    o_pt = model(
        torch.from_numpy(normed_input_values),
        attention_mask=torch.from_numpy(attentions),
    )
# Per-frame log-probabilities over the vocabulary, as expected by the CTC decoders below.
o_pt = log_softmax(o_pt.logits.numpy(), axis=-1)

In [71]:
# Greedy CTC decoding (best id per frame, no language model) of the preview batch.
pred_ids = o_pt.argmax(axis=-1)
tokenizer.batch_decode(pred_ids)

['te madatown language collections in librarys will also be enhanced',
 'doing an imped anchor rotation with the tree because he can',
 'a massive field of trees designed in the shape of a giant q r cod',
 'agra bisa segera keluar dari keroangan maut pak dadi melai keluar kelas']

In [48]:
# Decoder labels, index-aligned with the tokenizer ids. The last three keys of
# vocab_dict are "|", "[UNK]" and "[PAD]"; they are replaced by " ", "?" and "_".
# NOTE(review): presumably done so pyctcdecode sees single-character labels -- TODO confirm.
unique_vocab = list(vocab_dict.keys())[:-3] + [' ', '?', '_']

In [69]:
from pyctcdecode import build_ctcdecoder
from scipy.special import log_softmax
import kenlm

# KenLM n-gram model; presumably the file fetched by the commented-out wget
# cell at the top (huseinzol05/language-model-bahasa-manglish-combined).
kenlm_model = kenlm.Model('model.klm')
# Beam-search CTC decoder combining the acoustic log-probs with the LM.
decoder = build_ctcdecoder(
    unique_vocab,  # labels, index-aligned with the model's output ids
    kenlm_model,
    # NOTE(review): in pyctcdecode, alpha weights the LM score and beta is a
    # word-insertion bonus; how these values were chosen is not shown here.
    alpha=0.2,
    beta=1.0,
    # The tokenizer's [PAD] id doubles as the CTC blank.
    ctc_token_idx=tokenizer.pad_token_id
)

In [72]:
# Beam-search decode the preview batch with the KenLM-backed decoder.
# (Previously iterated over the undefined name `o_pt_`.)
# The first beam returned by decode_beams is taken as the result; each beam is
# unpacked as (text, lm_state, timesteps, logit_score, lm_score).
for k, log_probs in enumerate(o_pt):
    d_lm2, _lm_state, _timesteps, _logit_score, _lm_score = decoder.decode_beams(
        log_probs, prune_history=True
    )[0]
    print(k, d_lm2)

0 the madatown language collections in librarys will also be enhanced
1 doing an impede anchor rotation with the tree because he can
2 a massive field of trees designed in the shape of a giant q r cod
3 agra bisa segera keluar dari ruangan maut pak dadi mulai keluar kelas


In [73]:
# Reference transcripts for the same preview batch, to compare against the
# greedy and beam-search outputs above.
labels[:batch_size]

['the mother tongue language collections in libraries will also be enhanced',
 'doing an ankle rotation with the tree because he can',
 'a massive field of trees designed in the shape of a giant q r code',
 'agar bisa segera keluar dari ruangan maut pak dadi mulai keluar kelas']

In [74]:
def calculate_cer(actual, hyp):
    """
    Calculate CER using `python-Levenshtein`.

    Spaces are removed from both strings first, so the rate is the character
    edit distance divided by the number of non-space reference characters.

    If the reference is empty after removing spaces, the denominator is clamped
    to 1 (every hypothesis character counts as an insertion) instead of raising
    ZeroDivisionError. The result is then 0.0 for an empty hypothesis and
    float(len(hyp)) otherwise.
    """
    actual = actual.replace(' ', '')
    hyp = hyp.replace(' ', '')
    if not actual:
        return float(len(hyp))

    import Levenshtein as Lev

    return Lev.distance(actual, hyp) / len(actual)


def calculate_wer(actual, hyp):
    """
    Calculate WER using `python-Levenshtein`.

    Each distinct word is mapped to a single code point, so the character-level
    Levenshtein distance on the mapped strings equals the word-level edit
    distance. That distance is divided by the number of reference words.

    If the reference has no words, the denominator is clamped to 1 (every
    hypothesis word counts as an insertion) instead of raising ZeroDivisionError.
    """
    ref_words = actual.split()
    hyp_words = hyp.split()
    if not ref_words:
        return float(len(hyp_words))

    import Levenshtein as Lev

    # dict.fromkeys gives a deterministic, de-duplicated word order.
    word2char = {word: chr(i) for i, word in enumerate(dict.fromkeys(ref_words + hyp_words))}
    ref_seq = ''.join(word2char[w] for w in ref_words)
    hyp_seq = ''.join(word2char[w] for w in hyp_words)

    return Lev.distance(ref_seq, hyp_seq) / len(ref_words)


In [None]:
from tqdm import tqdm

# Per-clip scores: greedy decoding (wer/cer) vs. KenLM beam search (wer_lm/cer_lm).
wer, cer = [], []
wer_lm, cer_lm = [], []

for i in tqdm(range(0, len(audio), batch_size)):
    batch_x = audio[i: i + batch_size]
    batch_y = labels[i: i + batch_size]
    normed_input_values, attentions = batching(batch_x)
    # Inference only: skip building an autograd graph for every batch.
    with torch.no_grad():
        logits = model(
            torch.from_numpy(normed_input_values),
            attention_mask=torch.from_numpy(attentions),
        ).logits
    o_pt = log_softmax(logits.numpy(), axis=-1)
    pred = tokenizer.batch_decode(np.argmax(o_pt, axis=-1))
    for k, (reference, greedy) in enumerate(zip(batch_y, pred)):
        # Text of the first beam returned by the LM decoder.
        d_lm2 = decoder.decode_beams(o_pt[k], prune_history=True)[0][0]

        wer.append(calculate_wer(reference, greedy))
        cer.append(calculate_cer(reference, greedy))

        wer_lm.append(calculate_wer(reference, d_lm2))
        cer_lm.append(calculate_cer(reference, d_lm2))

 78%|███████▊  | 962/1240 [6:15:18<2:28:48, 32.12s/it]

In [81]:
# Overall (WER, CER, WER with LM, CER with LM) across all three languages.
tuple(np.mean(scores) for scores in (wer, cer, wer_lm, cer_lm))

(0.14151468058308714,
 0.048555454439612775,
 0.09809135311921899,
 0.03977501945111893)

In [83]:
# Sanity check: one score entry per evaluated clip.
len(audio), len(wer)

(4958, 4958)

In [89]:
# Positions of each language's clips inside the shuffled `audio` list; used to
# slice the per-clip score lists below.
index_malay, index_singlish, index_mandarin = (
    [position for position, path in enumerate(audio) if marker in path]
    for marker in ('malay-test/', 'singlish-test/', 'mandarin-test/')
)

In [90]:
# Malay subset: (WER, CER, WER with LM, CER with LM).
tuple(np.mean(np.array(scores)[index_malay]) for scores in (wer, cer, wer_lm, cer_lm))

(0.23714922876687583,
 0.05372605571018908,
 0.1294898148329521,
 0.03508559320616622)

In [91]:
# Singlish subset: (WER, CER, WER with LM, CER with LM).
tuple(np.mean(np.array(scores)[index_singlish]) for scores in (wer, cer, wer_lm, cer_lm))

(0.12941144843784677,
 0.04883661835898531,
 0.09411106530063956,
 0.04119293317615638)

In [92]:
# Mandarin (pinyin) subset: (WER, CER, WER with LM, CER with LM).
tuple(np.mean(np.array(scores)[index_mandarin]) for scores in (wer, cer, wer_lm, cer_lm))

(0.09291050873816364,
 0.04047435404966954,
 0.08217217867571727,
 0.037352703254831865)