In [1]:
import re
import json
import os
import shutil

from datetime import datetime
import IPython.display as ipd

import numpy as np
import pandas as pd
pd.options.display.max_columns = 100
pd.options.display.max_colwidth = 200

import datasets as hfd

import transformers
print(f'transformers version: "{transformers.__version__}"')

from transformers import (
    Wav2Vec2CTCTokenizer, Wav2Vec2FeatureExtractor, 
    Wav2Vec2Processor, Wav2Vec2ProcessorWithLM,
    Wav2Vec2ForCTC,
    TrainingArguments, Trainer
)

import pyctcdecode

import torch
print(torch.cuda.is_available())

transformers version: "4.17.0"
True


In [2]:
from src.data_collator import DataCollatorCTCWithPadding

In [3]:
import logging
# Root logging only shows WARNING and above; the project logger 'STT' is opened up to DEBUG.
# NOTE(review): `logger` is not used in the cells visible here.
logging_format_str = '%(asctime)s:%(name)s:%(levelname)s:%(message)s'
logging.basicConfig(level=logging.WARNING, format=logging_format_str)
logger = logging.getLogger('STT')
logger.setLevel(logging.DEBUG)

In [4]:
# DATA_ROOT_DP = os.environ['DATA_HOME']
# Root of the SSD-backed data directory; os.environ[...] raises KeyError if SSD_DATA_HOME is unset.
# NOTE(review): the DATA_HOME alternative above is disabled (commented out).
SSD_DATA_ROOT_DP = os.environ['SSD_DATA_HOME']

In [5]:
# Directory of the preprocessed dataset (presumably Common Voice 8.0, Belarusian 'be', judging by the name),
# read below with hfd.load_from_disk.
CV_PROCESSED_2_DP = f'{SSD_DATA_ROOT_DP}/datasets/cv-corpus-8.0-2022-01-19__be__processed__2'

In [6]:
# Load the DatasetDict with train/dev/test splits, each holding 'input_values' and 'labels'.
ds = hfd.load_from_disk(CV_PROCESSED_2_DP)
ds

DatasetDict({
    train: Dataset({
        features: ['input_values', 'labels'],
        num_rows: 314305
    })
    dev: Dataset({
        features: ['input_values', 'labels'],
        num_rows: 15803
    })
    test: Dataset({
        features: ['input_values', 'labels'],
        num_rows: 15801
    })
})

## Load metric, processor and model

In [7]:
# Word error rate metric from `datasets`; used for both per-clip and corpus-level WER.
wer_metric = hfd.load_metric("wer")

In [8]:
# Plain (no-LM) processor: feature extractor + CTC tokenizer, loaded from a local directory.
processor = Wav2Vec2Processor.from_pretrained('artifacts/processor')

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [9]:
# Batches examples for the Trainer; the implementation lives in src.data_collator (not shown here).
# NOTE(review): presumably pads 'labels' with -100, which parse_predictions_with_lm maps back — confirm.
data_collator = DataCollatorCTCWithPadding(processor=processor, padding=True)

In [10]:
# Fine-tuned checkpoint to evaluate (run_1, step 6480 according to the directory name).
pretrained_checkpoint_dp = 'train/run_1/checkpoint-6480'
# List the checkpoint files to confirm the weights and config are present.
print(os.listdir(pretrained_checkpoint_dp))

['scheduler.pt', 'preprocessor_config.json', 'training_args.bin', 'pytorch_model.bin', 'scaler.pt', 'config.json', 'rng_state.pth', 'optimizer.pt', 'trainer_state.json']


In [11]:
# continue from checkpoint
# Only the model config/weights are loaded; this notebook evaluates and does not resume training,
# so the optimizer/scheduler state stored in the checkpoint directory is not used.
model = Wav2Vec2ForCTC.from_pretrained(pretrained_checkpoint_dp)

In [12]:
# Move the model to the GPU (the trailing ';' suppresses the cell output), then show its device.
model.to('cuda');
model.device

device(type='cuda', index=0)

## Prepare processor with language-model decoder

In [13]:
# Tokenizer vocabulary (token -> id), displayed as a single row ordered by token id.
vocab = processor.tokenizer.get_vocab()
pd.Series(vocab).sort_values().to_frame().T

Unnamed: 0,|,',i,а,б,в,г,д,е,ж,з,й,к,л,м,н,о,п,р,с,т,у,ф,х,ц,ч,ш,ы,ь,э,ю,я,ё,і,ў,[UNK],[PAD],<s>,</s>
0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38


In [14]:
# (token, id) pairs ordered by token id, so that position i holds the label for id i,
# which is the order the CTC decoder below receives its labels in.
# Despite the name, this is a list of tuples, not a dict.
sorted_vocab_dict = sorted(vocab.items(), key=lambda token_and_id: token_and_id[1])
print(sorted_vocab_dict)

[('|', 0), ("'", 1), ('i', 2), ('а', 3), ('б', 4), ('в', 5), ('г', 6), ('д', 7), ('е', 8), ('ж', 9), ('з', 10), ('й', 11), ('к', 12), ('л', 13), ('м', 14), ('н', 15), ('о', 16), ('п', 17), ('р', 18), ('с', 19), ('т', 20), ('у', 21), ('ф', 22), ('х', 23), ('ц', 24), ('ч', 25), ('ш', 26), ('ы', 27), ('ь', 28), ('э', 29), ('ю', 30), ('я', 31), ('ё', 32), ('і', 33), ('ў', 34), ('[UNK]', 35), ('[PAD]', 36), ('<s>', 37), ('</s>', 38)]


In [15]:
# Binary KenLM language model (a 5-gram model per the file name) used for beam-search decoding.
lm_fp = 'artifacts/lm/cv8be_5gram.bin'

In [16]:
from pyctcdecode import build_ctcdecoder

# Labels must be in token-id order so that label i matches logit column i.
# pyctcdecode normalizes some of them: '|' becomes ' ' (word delimiter), '[PAD]' becomes ''
# (the CTC blank) and '[UNK]' becomes '⁇' (see the decoder labels printed further down).
decoder = build_ctcdecoder(
    labels=[x[0] for x in sorted_vocab_dict],
    kenlm_model_path=lm_fp,
)



In [17]:
# Reuse the existing feature extractor and tokenizer, but decode with the LM-backed beam search.
processor_with_lm = Wav2Vec2ProcessorWithLM(
    feature_extractor=processor.feature_extractor,
    tokenizer=processor.tokenizer,
    decoder=decoder
)

In [18]:
# Sanity check: the LM processor must expose the same tokenizer vocabulary as the plain processor.
print(len(processor.tokenizer.get_vocab()))
print(sorted(processor.tokenizer.get_vocab()) == sorted(processor_with_lm.tokenizer.get_vocab()))
print(" ".join(sorted(processor_with_lm.tokenizer.get_vocab())))

39
True
' </s> <s> [PAD] [UNK] i | а б в г д е ж з й к л м н о п р с т у ф х ц ч ш ы ь э ю я ё і ў


In [19]:
# Inspect the decoder's normalized labels. `_alphabet` is a private pyctcdecode attribute and
# may change between versions; use it only for debugging.
print(len(processor_with_lm.decoder._alphabet.labels))
print(processor_with_lm.decoder._alphabet.labels)

39
[' ', "'", 'i', 'а', 'б', 'в', 'г', 'д', 'е', 'ж', 'з', 'й', 'к', 'л', 'м', 'н', 'о', 'п', 'р', 'с', 'т', 'у', 'ф', 'х', 'ц', 'ч', 'ш', 'ы', 'ь', 'э', 'ю', 'я', 'ё', 'і', 'ў', '⁇', '', '<s>', '</s>']


In [20]:
# Index -> label mapping used by the decoder. `_idx2vocab` is a private pyctcdecode attribute
# (debugging only); its indices should line up with the tokenizer ids.
print(len(processor_with_lm.decoder._idx2vocab))
print(processor_with_lm.decoder._idx2vocab)

39
{0: ' ', 1: "'", 2: 'i', 3: 'а', 4: 'б', 5: 'в', 6: 'г', 7: 'д', 8: 'е', 9: 'ж', 10: 'з', 11: 'й', 12: 'к', 13: 'л', 14: 'м', 15: 'н', 16: 'о', 17: 'п', 18: 'р', 19: 'с', 20: 'т', 21: 'у', 22: 'ф', 23: 'х', 24: 'ц', 25: 'ч', 26: 'ш', 27: 'ы', 28: 'ь', 29: 'э', 30: 'ю', 31: 'я', 32: 'ё', 33: 'і', 34: 'ў', 35: '⁇', 36: '', 37: '<s>', 38: '</s>'}


## evaluate model

### one example at a time

In [21]:
def map_to_result_with_lm(example, device, processor_with_lm, ctc_model=None):
    """
    Transcribe one dataset example using CTC beam search with the language model.

    Meant for ``Dataset.map`` without batching, so each call handles a single clip.

    Args:
        example: dataset row with ``'input_values'`` (model input) and ``'labels'``
            (reference token ids).
        device: device the input tensor is created on; it should be the device the
            model is on.
        processor_with_lm: ``Wav2Vec2ProcessorWithLM``. Its ``batch_decode`` decodes the
            logits, and its ``tokenizer`` decodes the reference labels.
        ctc_model: model to run. Defaults to the notebook-level ``model``, so existing
            callers that pass only ``device`` and ``processor_with_lm`` keep working.

    Returns:
        The same ``example`` dict, with ``'pred_text'`` (LM-decoded prediction) and
        ``'text'`` (decoded reference) added.
    """
    # Read the global only when no model is given; the conditional expression does not
    # evaluate `model` if ctc_model was passed.
    net = model if ctc_model is None else ctc_model

    with torch.no_grad():
        # unsqueeze(0) adds the batch dimension, since the model sees one clip at a time.
        t = torch.tensor(example['input_values'], device=device).unsqueeze(0)
        logits = net(t)['logits']

    example['pred_text'] = processor_with_lm.batch_decode(logits.cpu().numpy())['text'][0]

    # The labels are already the target sequence. Don't group repeated tokens, otherwise
    # genuine double letters would be merged.
    example['text'] = processor_with_lm.tokenizer.decode(example['labels'], group_tokens=False)

    return example

In [22]:
# Decode four hand-picked test clips one at a time. The original input columns are dropped,
# so only the new 'pred_text' / 'text' columns remain.
results = ds['test'].select([12, 9532, 9533, 12305]).map(
    map_to_result_with_lm, fn_kwargs=dict(device='cuda', processor_with_lm=processor_with_lm),
    remove_columns=ds['test'].column_names
)

0ex [00:00, ?ex/s]

In [23]:
# Per-clip comparison for the hand-picked examples: an exact-match flag and the WER of each clip.
df_results = results.to_pandas()[['text', 'pred_text']]
df_results['success'] = (df_results['text'] == df_results['pred_text']).astype('int')
# Compute one WER per clip from aligned (reference, prediction) pairs. A row-wise
# DataFrame.apply(axis=1) returns a DataFrame on an empty frame and the column assignment
# then fails; the comprehension simply yields an empty column.
df_results['wer'] = [
    wer_metric.compute(predictions=[pred_text], references=[text])
    for text, pred_text in zip(df_results['text'], df_results['pred_text'])
]

In [24]:
df_results

Unnamed: 0,text,pred_text,success,wer
0,толькі ў раёне таксама нашы людзі вясковыя,толькі ў раёне таксама наша людзі вясковыя,0,0.142857
1,у яго нават ёсць сольныя выступы,у яго нават есцэльныя выступы,0,0.333333
2,на ўзвышшы што ў цэнтры пляцоўкі стаяла абарончая вежа,на ўзвышшы што ў цэнтры пляцоўкі стаяла абарончае вежы,0,0.222222
3,шлюбам папярэднічалі перагаворы паміж сем'ямі,шлюбам папярэднічалі перагаворы паміж сем'ямі,1,0.0


In [25]:
# Corpus-level WER over the hand-picked clips: errors are pooled over all words,
# so this is not the mean of the per-clip WER values.
wer = wer_metric.compute(
    references=df_results['text'].tolist(),
    predictions=df_results['pred_text'].tolist(),
)
wer

0.18518518518518517

### batched

In [26]:
from transformers.trainer_utils import PredictionOutput

In [27]:
def parse_predictions_with_lm(pred: "PredictionOutput", lm_processor=None):
    """
    Parse output of trainer.predict, i.e. predictions for whole dataset.

    Decodes the logits with LM beam search and the reference label ids with the
    plain tokenizer.

    Args:
        pred: output of ``trainer.predict`` (or the object passed to ``compute_metrics``).
            ``pred.predictions`` holds the logits. ``pred.label_ids`` holds the reference
            token ids, with ``-100`` marking positions ignored by the loss.
        lm_processor: ``Wav2Vec2ProcessorWithLM`` to decode with. Defaults to the
            notebook-level ``processor_with_lm``.

    Returns:
        dict with ``'pred_text'`` (predicted transcripts) and ``'text'`` (reference
        transcripts), as lists aligned by position.
    """
    proc = processor_with_lm if lm_processor is None else lm_processor

    pred_text = proc.batch_decode(pred.predictions)['text']

    # Convert the loss-ignore token (-100) to the pad token so the tokenizer can decode it.
    # np.where builds a new array and leaves pred.label_ids untouched, so the same prediction
    # output can be parsed again (once inside compute_metrics, and again later) without hidden state.
    label_ids = np.where(pred.label_ids == -100, proc.tokenizer.pad_token_id, pred.label_ids)
    # do not group tokens for ground truth texts
    text = proc.tokenizer.batch_decode(label_ids, group_tokens=False)

    return {'pred_text': pred_text, 'text': text}

In [28]:
def compute_metrics_with_lm(pred: PredictionOutput):
    """
    ``compute_metrics`` hook for the Trainer: WER of the LM-decoded predictions.

    Args:
        pred: prediction output handed over by the Trainer (logits and label ids).

    Returns:
        ``{"wer": <corpus-level word error rate>}``.
    """
    decoded = parse_predictions_with_lm(pred)
    score = wer_metric.compute(
        references=decoded['text'],
        predictions=decoded['pred_text'],
    )
    return {"wer": score}

In [29]:
# This Trainer is used only for batched evaluation (trainer.predict below); no training is configured.
training_args = TrainingArguments(
    output_dir='eval_output',
    per_device_eval_batch_size=32,
)

trainer = Trainer(
    model=model,
    data_collator=data_collator,
    args=training_args,
    # WER computed with LM beam-search decoding over the full prediction output
    compute_metrics=compute_metrics_with_lm,
    # NOTE(review): passes the feature extractor rather than the tokenizer; presumably
    # kept for saving alongside checkpoints. Confirm it is needed for evaluation.
    tokenizer=processor.feature_extractor,
)

In [None]:
# Full test-set inference plus metrics. This is slow: the recorded run took about 580 s (see test_runtime).
pred_out = trainer.predict(ds['test'])

In [31]:
# Metrics reported by trainer.predict, with keys prefixed by 'test_'.
pred_out.metrics

{'test_loss': 0.20659717917442322,
 'test_wer': 0.1868199340918983,
 'test_runtime': 578.4481,
 'test_samples_per_second': 27.316,
 'test_steps_per_second': 0.285}

In [32]:
# Headline number: corpus-level test WER with LM decoding.
print(f"Test WER using LM: {pred_out.metrics['test_wer'] :.3f}")

Test WER using LM: 0.187


## Compare predictions with targets

In [33]:
# Decode the full test-set predictions and references to text for per-clip inspection.
parsed_preds = parse_predictions_with_lm(pred_out)

In [34]:
# Per-clip view of the whole test set: reference, LM-decoded prediction and exact-match flag.
df_results = (
    pd.DataFrame(parsed_preds)
    .loc[:, ['text', 'pred_text']]
    .assign(success=lambda frame: frame['text'].eq(frame['pred_text']).astype('int'))
)
df_results.head()

Unnamed: 0,text,pred_text,success
0,і на гэтым кірунку нас чакаюць вялікія складанасці,і на гэтым кірунку нас чакаюць вялікія складанасці,1
1,быў разбураны пры вызваленні горада ад нямецкафашысцкіх захопнікаў,быў разбураны пры вызваленні горада ад нямецкай фашыцкіх захопнікаў,0
2,далейшае супрацоўніцтва паміж нямецкім і савецкім прыняло форму абмену польскіх ваеннапалонных,далейшы супразоніцтва мічнямецкім савецкім пыняфорамо меннуўпольскіх ваеннаогам,0
3,цвіценне працягваецца з мая да позняй восені,цвіценне працягваецца з мая да позняй восені,1
4,таксама пройдзе канферэнцыя фестываль і мастацкая выстава,таксама пройдзе канферэнцыя фестываль і мастацкая выстава,1


In [35]:
# Number of clips whose prediction matches the reference exactly (1) vs. differs (0).
df_results['success'].value_counts(dropna=False)

0    9674
1    6127
Name: success, dtype: int64

In [37]:
# The same split expressed as shares of the test set.
stats = df_results['success'].value_counts(dropna=False, normalize=True)
stats

0    0.61224
1    0.38776
Name: success, dtype: float64

In [41]:
# Share of test clips whose LM-decoded transcript matches the reference exactly.
# value_counts has no label 1 when no clip matches, so stats[1] would raise KeyError; .get defaults to 0.
print(f'Rate of fully recognized clips from Test set: {stats.get(1, 0.0) :.3f}')

Rate of fully recognized clips from Test set: 0.388
