In [1]:
!pip install transformers datasets SoundFile
!pip install https://github.com/kpu/kenlm/archive/master.zip
!pip install pyctcdecode
!pip install jiwer

Collecting transformers
  Downloading transformers-4.16.2-py3-none-any.whl (3.5 MB)
[K     |████████████████████████████████| 3.5 MB 8.2 MB/s 
[?25hCollecting datasets
  Downloading datasets-1.18.3-py3-none-any.whl (311 kB)
[K     |████████████████████████████████| 311 kB 50.0 MB/s 
Collecting pyyaml>=5.1
  Downloading PyYAML-6.0-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (596 kB)
[K     |████████████████████████████████| 596 kB 62.0 MB/s 
Collecting huggingface-hub<1.0,>=0.1.0
  Downloading huggingface_hub-0.4.0-py3-none-any.whl (67 kB)
[K     |████████████████████████████████| 67 kB 7.4 MB/s 
[?25hCollecting tokenizers!=0.11.3,>=0.10.1
  Downloading tokenizers-0.11.4-cp37-cp37m-manylinux_2_12_x86_64.manylinux2010_x86_64.whl (6.8 MB)
[K     |████████████████████████████████| 6.8 MB 71.0 MB/s 
[?25hCollecting sacremoses
  Downloading sacremoses-0.0.47-py2.py3-none-any.whl (895 kB)
[K     |████████████████████████████████| 8

In [2]:
import kenlm
import argparse
import pandas as pd
import numpy as np
import random
import torch
import os
import datasets

from pathlib import Path
from sklearn.model_selection import ParameterGrid

from datasets import load_metric, load_dataset
from datasets import Dataset, DatasetDict, Metric, IterableDatasetDict, IterableDataset
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC, Wav2Vec2ProcessorWithLM, PreTrainedTokenizer

from pyctcdecode import build_ctcdecoder, BeamSearchDecoderCTC
from pyctcdecode.language_model import load_unigram_set_from_arpa, LanguageModel, AbstractLanguageModel
from pyctcdecode.alphabet import Alphabet, verify_alphabet_coverage

from pathlib import Path
from torch.utils.data import DataLoader
from collections import defaultdict
from typing import Union, Dict, List, Tuple, Optional, Collection
from functools import partial
# from src.decoding.decode import build_decoder, grid_search_decoder

In [3]:
def get_kenlm_model_unigrams(kenlm_model_path: str,
                             return_alphabet: bool=False,
                             labels=None) -> Tuple[kenlm.Model, Collection[str]]:
  kenlm_model = kenlm.Model(kenlm_model_path)
  if kenlm_model_path.endswith(".arpa"):
    unigrams = load_unigram_set_from_arpa(kenlm_model_path)
  else:
    print(
        "Unigrams not provided and cannot be automatically determined from LM file (only "
        "arpa format). Decoding accuracy might be reduced."
    )
    unigrams = None
  if not return_alphabet:
    return kenlm_model, unigrams
  else:
    return kenlm_model, unigrams, Alphabet.build_alphabet(labels)

def tokenize(sample, tokenizer, feature_extractor):
  if isinstance(sample, datasets.arrow_dataset.Batch):
    sentence_inputs = [s.lower() for s in sample['sentence']]
    audio_inputs = [s['array'] for s in sample['audio']]
    sampling_rate = sample['audio'][0]['sampling_rate']
  else:
    sentence_inputs = sample['sentence'].lower()
    audio_inputs = sample['audio']['array']
    sampling_rate = sample['audio']['sampling_rate']

  
  if tokenizer:
    token = tokenizer(sentence_inputs,
                      padding='longest')
    token['sentence_attention_mask'] = token.pop('attention_mask')
  else:
    token = {}

  audio = feature_extractor(audio_inputs,
                            sampling_rate=sampling_rate,
                            padding='longest')
  audio['audio_attention_mask'] = audio.pop('attention_mask')
  
  return dict(**audio, **token)

def my_build_ctc_decoder(
    labels: List[str],
    kenlm_model: kenlm.Model,
    unigrams: Collection[str],
    alpha: float = 0.5,
    beta: float = 1.5,
    unk_score_offset: float = -10.0,
    lm_score_boundary: bool = True,
    alphabet: Optional[Alphabet]=None) -> BeamSearchDecoderCTC:
  if alphabet is None:
    alphabet = Alphabet.build_alphabet(labels)
  if unigrams is not None:
      verify_alphabet_coverage(alphabet, unigrams)
  if kenlm_model is not None:
      language_model: Optional[AbstractLanguageModel] = LanguageModel(
          kenlm_model,
          unigrams,
          alpha=alpha,
          beta=beta,
          unk_score_offset=unk_score_offset,
          score_boundary=lm_score_boundary,
      )
  else:
      language_model = None
  return BeamSearchDecoderCTC(alphabet, language_model)

def build_decoder(asr_processor,
                  kenlm_path: Union[str, Path, kenlm.Model],
                  alpha: float,
                  beta: float,
                  return_decoder: bool=True,
                  unigrams: Optional[Collection[str]]=None,
                  alphabet=None,
                  **kwargs) -> Union[BeamSearchDecoderCTC, Wav2Vec2ProcessorWithLM]:
    """ Build the decoder and return either the decoder itself or the processor with LM.
    
    Parameters
    ----------
    asr_processor: Wav2Vec2Processor
        Wav2Vec2Processor instance
    kenlm_path: str or Path
        Path to trained KenLM
    alpha: float
        Alpha parameter for Decoder
    beta: float
        Beta parameter for Decoder
    return_decoder: bool
        If True, returns the decoder obtained from build_ctcdecoder method, otherwise returns Wav2Vec2ProcessorWithLM instance
        
    Returns
    ----------
        decoder or Wav2Vec2ProcessorWithLM
    """
    labels = generate_labels(asr_processor.tokenizer.get_vocab(), sort=True)

    if not unigrams is None and isinstance(kenlm_path, kenlm.Model):
      decoder = my_build_ctc_decoder(
          labels,
          kenlm_path,
          unigrams,
          alpha,
          beta,
          alphabet=alphabet,
          **kwargs
      )
    else:
      assert isinstance(kenlm_path, (str, Path))
      decoder = build_ctcdecoder(
          labels,
          kenlm_path,
          alpha=alpha,
          beta=beta,
          **kwargs
      )
    
    if return_decoder:
        return decoder
    
    processor_with_lm = Wav2Vec2ProcessorWithLM(
        feature_extractor=asr_processor.feature_extractor,
        tokenizer=asr_processor.tokenizer,
        decoder=decoder
    )
    
    return processor_with_lm

def generate_labels(vocab_dict, sort=True):
  if sort:
    sorted_vocab_dict = {k.lower(): v for k, v in sorted(vocab_dict.items(), key=lambda item: item[1])}
  else:
    sorted_vocab_dict = vocab_dict

  return list(sorted_vocab_dict.keys())

def compute_total_len(dataset: Union[DatasetDict, Dataset, IterableDatasetDict, IterableDataset, DataLoader],
                      gt_text_name: str='sentence',
                      tokenizer: Optional[PreTrainedTokenizer]=None) -> int:
  res = 0
  if not isinstance(dataset, DataLoader):
    for audio_sample in dataset:
      res += len(audio_sample[gt_text_name])
  else:
    for b in dataset:
      decoded = tokenizer.batch_decode(b['input_ids'], skip_special_tokens=True)
      for phrase in decoded:
        res += len(phrase)

  return res

def _generate_key(args: dict) -> Tuple:
  k = tuple(v for _, v in sorted(args.items(), key=lambda it: it[0]))
  return k

def grid_search_decoder(asr_processor: Wav2Vec2Processor,
                        asr_model: Wav2Vec2ForCTC,
                        kenlm_path: Union[str, Path],
                        decoder_param_generator,
                        loader: DataLoader,
                        metric: Optional[Union[Metric, Dict]]=None,
                        gt_text_name: str='sentence',
                        device=None,
                        store_all_results: bool=False) -> Dict[Tuple[float, float], Dict[str, np.ndarray]]:
  if metric is None:
    metric = {
        'wer': load_metric('wer'),
        'cer': load_metric('cer')
    }
  elif isinstance(metric, datasets.Metric):
    metric = {
        metric.name: metric
    }
  print(f'Metric dict: {metric}')
  if device is None:
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
  asr_model = asr_model.to(device)
  result_dict = defaultdict(lambda: defaultdict(list if store_all_results else float))
  if not store_all_results:
    # pre-compute total length
    total_len = compute_total_len(loader, tokenizer=asr_processor.tokenizer)
    print(f'Total length: {total_len}')

  kenlm_model, unigrams, alphabet = get_kenlm_model_unigrams(kenlm_path,
                                                             return_alphabet=True,
                                                             labels=generate_labels(asr_processor.tokenizer.get_vocab()))

  best_results = defaultdict(lambda: -1)

  decoders = {
        _generate_key(decoder_params): build_decoder(asr_processor,
                                                     kenlm_model,
                                                     return_decoder=False,
                                                     unigrams=unigrams,
                                                     **decoder_params)
        for decoder_params in decoder_param_generator
    }
  for idy, audio_batch in enumerate(loader):
    if idy % 500 == 0:
        print(f'Evaluating sample {idy + 1}/{len(loader)}')

    true_text = [t.lower() for t in asr_processor.batch_decode(audio_batch['input_ids'],
                                                             skip_special_tokens=True)]
    true_len = [len(t) for t in true_text]
    sum_true_len = 0
    for tl in true_len:
      sum_true_len += tl

    with torch.no_grad():
      audio_dev = audio_batch['input_values'].to(device)
      audio_att_mask = audio_batch['audio_attention_mask'].to(device)
      logits = asr_model(input_values=audio_dev,
                         attention_mask=audio_att_mask).logits.to(device)
    for idx, decoder_params in enumerate(decoder_param_generator):
      if idx % 10 == 0:
        print(f'Evaluating config {idx + 1}/{len(decoder_param_generator)}')
      k = _generate_key(decoder_params)
      # decoder = build_decoder(asr_processor,
      #                         kenlm_model,
      #                         return_decoder=False,
      #                         unigrams=unigrams,
      #                         **decoder_params)
      decoder = decoders[k]
      transcription = decoder.batch_decode(logits.detach().cpu().numpy()).text
      
      for m in metric:
        metric_score = metric[m].compute(predictions=transcription, references=true_text)
        if store_all_results:
          result_dict[k][m].append(metric_score)
          result_dict[k]['weight'].append(sum_true_len)
        else:
          result_dict[k][m] += metric_score * sum_true_len / total_len

  for k in result_dict:
    result_dict[k] = dict(result_dict[k])
    if store_all_results:
      for m in result_dict[k]:
        result_dict[k][m] = np.ndarray(result_dict[k][m])
  return dict(result_dict)

def compute_best_config(result_dict: Dict[Tuple, Dict[str, np.ndarray]],
                        metric: str='wer',
                        weighted: bool=True) -> Tuple:
  best_result = -1
  best_k = None
  for k in result_dict:
    store_all_results = 'weight' in result_dict[k]
    if store_all_results:
      weights = result_dict[k]['weight']
      scaled_weights = weights / weights.sum()
      mean_metric = (result_dict[k][metric] * scaled_weights).sum()
    else:
      mean_metric = result_dict[k][metric]
    if mean_metric < best_result or best_result == -1:
      best_result = mean_metric
      best_k = k
  return best_k, mean_metric

In [4]:
!huggingface-cli login


        _|    _|  _|    _|    _|_|_|    _|_|_|  _|_|_|  _|      _|    _|_|_|      _|_|_|_|    _|_|      _|_|_|  _|_|_|_|
        _|    _|  _|    _|  _|        _|          _|    _|_|    _|  _|            _|        _|    _|  _|        _|
        _|_|_|_|  _|    _|  _|  _|_|  _|  _|_|    _|    _|  _|  _|  _|  _|_|      _|_|_|    _|_|_|_|  _|        _|_|_|
        _|    _|  _|    _|  _|    _|  _|    _|    _|    _|    _|_|  _|    _|      _|        _|    _|  _|        _|
        _|    _|    _|_|      _|_|_|    _|_|_|  _|_|_|  _|      _|    _|_|_|      _|        _|    _|    _|_|_|  _|_|_|_|

        To login, `huggingface_hub` now requires a token generated from https://huggingface.co/settings/token.
        (Deprecated, will be removed in v0.3.0) To login with username and password instead, interrupt with Ctrl+C.
        
Token: 
Login successful
Your token has been saved to /root/.huggingface/token
[1m[31mAuthenticated through git-credential store but this isn't the helper defined on you

In [5]:
from google.colab import drive
drive.mount('/content/gdrive')

Mounted at /content/gdrive


In [6]:
   
KENLM_MODEL_LOC = '/content/gdrive/MyDrive/Colab Notebooks/HF Community week - ASR/5gram-it-multi-ds-eos.arpa'

# SPGI_VAL_DIR = args.datadir
# SPGI_VAL_CSV = args.datalist
MODEL_NAME = 'dbdmg/wav2vec2-xls-r-300m-italian-augmented'
DATASET_NAME = 'mozilla-foundation/common_voice_7_0'
DATASET_CONFIG_NAME = 'it'
TRAIN_SPLIT_NAME = 'test[50%:51%]'
USE_AUTH_TOKEN = True

# val_df = pd.read_csv(SPGI_VAL_CSV, sep='|')

# > val_df.dtypes
# wav_filename    object
# wav_filesize     int64
# transcript      object
# dtype: object

device = 'cuda' if torch.cuda.is_available() else 'cpu'

asr_processor = Wav2Vec2Processor.from_pretrained(MODEL_NAME)
asr_model = Wav2Vec2ForCTC.from_pretrained(MODEL_NAME).to(device)
print("Vocab: ", asr_processor.tokenizer.get_vocab())
print(f'Vocab shape: {asr_processor.tokenizer.get_vocab()}')
print(f'Loading dataset: {DATASET_NAME} - config: {DATASET_CONFIG_NAME}')
print(f'Split: {TRAIN_SPLIT_NAME}')
print(f'Use auth token: {USE_AUTH_TOKEN}')

raw_dataset = load_dataset(
    DATASET_NAME,
    DATASET_CONFIG_NAME,
    split=TRAIN_SPLIT_NAME,
    use_auth_token=USE_AUTH_TOKEN
)

print(raw_dataset)
kenlm_model, unigrams = get_kenlm_model_unigrams(KENLM_MODEL_LOC)

wer_metric = load_metric('wer')
cer_metric = load_metric('cer')

processor_with_lm = build_decoder(asr_processor,
                                  kenlm_model,
                                  unigrams=unigrams,
                                  alpha=.6,
                                  beta=2.0,
                                  return_decoder=False)


Downloading:   0%|          | 0.00/212 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/297 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.80k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/25.0 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.06k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/2.00k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.18G [00:00<?, ?B/s]

Vocab:  {'<': 1, '=': 2, '>': 3, '[': 4, ']': 5, '_': 6, '`': 7, 'a': 8, 'b': 9, 'c': 10, 'd': 11, 'e': 12, 'f': 13, 'g': 14, 'h': 15, 'i': 16, 'j': 17, 'k': 18, 'l': 19, 'm': 20, 'n': 21, 'o': 22, 'p': 23, 'q': 24, 'r': 25, 's': 26, 't': 27, 'u': 28, 'v': 29, 'w': 30, 'x': 31, 'y': 32, 'z': 33, '{': 34, '}': 35, '~': 36, '¡': 37, '«': 38, '°': 39, '´': 40, 'µ': 41, 'º': 42, '»': 43, 'ß': 44, 'à': 45, 'á': 46, 'ã': 47, 'ä': 48, 'å': 49, 'æ': 50, 'è': 51, 'é': 52, 'ê': 53, 'ë': 54, 'ì': 55, 'í': 56, 'î': 57, 'ï': 58, 'ð': 59, 'ñ': 60, 'ò': 61, 'ó': 62, 'ô': 63, 'ö': 64, 'ø': 65, 'ù': 66, 'ú': 67, 'û': 68, 'ü': 69, 'þ': 70, 'ÿ': 71, 'ā': 72, 'ą': 73, 'ć': 74, 'č': 75, 'đ': 76, 'ė': 77, 'ę': 78, 'ě': 79, 'ğ': 80, 'ħ': 81, 'ī': 82, 'ı': 83, 'ľ': 84, 'ł': 85, 'ń': 86, 'ň': 87, 'ō': 88, 'ő': 89, 'œ': 90, 'ř': 91, 'ś': 92, 'ş': 93, 'š': 94, 'ū': 95, 'ŭ': 96, 'ź': 97, 'ż': 98, 'ž': 99, 'ș': 100, 'ț': 101, 'ə': 102, 'ʹ': 103, 'ʻ': 104, 'ʼ': 105, 'ʾ': 106, 'ʿ': 107, 'ː': 108, '̇': 109, '̨': 110,

Downloading:   0%|          | 0.00/9.88k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/2.98k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/49.7k [00:00<?, ?B/s]

Downloading and preparing dataset common_voice/it to /root/.cache/huggingface/datasets/mozilla-foundation___common_voice/it/7.0.0/fe20cac47c166e25b1f096ab661832e3da7cf298ed4a91dcaa1343ad972d175b...


Downloading:   0%|          | 0.00/8.05G [00:00<?, ?B/s]

0 examples [00:00, ? examples/s]

0 examples [00:00, ? examples/s]

0 examples [00:00, ? examples/s]

0 examples [00:00, ? examples/s]

0 examples [00:00, ? examples/s]

Dataset common_voice downloaded and prepared to /root/.cache/huggingface/datasets/mozilla-foundation___common_voice/it/7.0.0/fe20cac47c166e25b1f096ab661832e3da7cf298ed4a91dcaa1343ad972d175b. Subsequent calls will reuse this data.
Dataset({
    features: ['client_id', 'path', 'audio', 'sentence', 'up_votes', 'down_votes', 'age', 'gender', 'accent', 'locale', 'segment'],
    num_rows: 148
})


Downloading:   0%|          | 0.00/1.90k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/2.16k [00:00<?, ?B/s]

Found entries of length > 1 in alphabet. This is unusual unless style is BPE, but the alphabet was not recognized as BPE type. Is this correct?
Unigrams and labels don't seem to agree.


In [7]:
feature_extractor = asr_processor.feature_extractor
dataset_sampling_rate = raw_dataset[0]['audio']['sampling_rate']
if dataset_sampling_rate != feature_extractor.sampling_rate:
    raw_dataset = raw_dataset.cast_column(
        'audio', datasets.features.Audio(sampling_rate=feature_extractor.sampling_rate)
    )

In [8]:
vocab_dict = asr_processor.tokenizer.get_vocab().copy()
sorted_vocab_dict = {k.lower(): v for k, v in sorted(vocab_dict.items(), key=lambda item: item[1])}
# sorted_vocab_dict['_'] = sorted_vocab_dict.pop('[pad]')
# sorted_vocab_dict[' '] = sorted_vocab_dict.pop('|')
# sorted_vocab_dict['^'] = sorted_vocab_dict.pop('<s>')
# sorted_vocab_dict['$'] = sorted_vocab_dict.pop('</s>')

In [9]:
vocab_dict

{'<': 1,
 '</s>': 178,
 '<s>': 177,
 '=': 2,
 '>': 3,
 '[': 4,
 '[PAD]': 176,
 '[UNK]': 175,
 ']': 5,
 '_': 6,
 '`': 7,
 'a': 8,
 'b': 9,
 'c': 10,
 'd': 11,
 'e': 12,
 'f': 13,
 'g': 14,
 'h': 15,
 'i': 16,
 'j': 17,
 'k': 18,
 'l': 19,
 'm': 20,
 'n': 21,
 'o': 22,
 'p': 23,
 'q': 24,
 'r': 25,
 's': 26,
 't': 27,
 'u': 28,
 'v': 29,
 'w': 30,
 'x': 31,
 'y': 32,
 'z': 33,
 '{': 34,
 '|': 0,
 '}': 35,
 '~': 36,
 '¡': 37,
 '«': 38,
 '°': 39,
 '´': 40,
 'µ': 41,
 'º': 42,
 '»': 43,
 'ß': 44,
 'à': 45,
 'á': 46,
 'ã': 47,
 'ä': 48,
 'å': 49,
 'æ': 50,
 'è': 51,
 'é': 52,
 'ê': 53,
 'ë': 54,
 'ì': 55,
 'í': 56,
 'î': 57,
 'ï': 58,
 'ð': 59,
 'ñ': 60,
 'ò': 61,
 'ó': 62,
 'ô': 63,
 'ö': 64,
 'ø': 65,
 'ù': 66,
 'ú': 67,
 'û': 68,
 'ü': 69,
 'þ': 70,
 'ÿ': 71,
 'ā': 72,
 'ą': 73,
 'ć': 74,
 'č': 75,
 'đ': 76,
 'ė': 77,
 'ę': 78,
 'ě': 79,
 'ğ': 80,
 'ħ': 81,
 'ī': 82,
 'ı': 83,
 'ľ': 84,
 'ł': 85,
 'ń': 86,
 'ň': 87,
 'ō': 88,
 'ő': 89,
 'œ': 90,
 'ř': 91,
 'ś': 92,
 'ş': 93,
 'š': 94,
 'ū

In [10]:
raw_dataset

Dataset({
    features: ['client_id', 'path', 'audio', 'sentence', 'up_votes', 'down_votes', 'age', 'gender', 'accent', 'locale', 'segment'],
    num_rows: 148
})

In [11]:
def greedy_decode(logits, labels, ignore_set=None):
    """Decode argmax of logits and squash in CTC fashion."""
    label_dict = {n: c for n, c in enumerate(labels)}
    prev_c = None
    out = []
    for n in logits.argmax(axis=1):
        c = label_dict.get(n, "")  # if not in labels, then assume it's ctc blank char
        if not ignore_set is None and c in ignore_set:
          continue
        if c != prev_c:
            out.append(c)
        prev_c = c
    return "".join(out)

In [12]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# processor_with_lm = processor_with_lm.to(device)
print(f'Vocab shape: {len(asr_processor.tokenizer.get_vocab())}')


Vocab shape: 179


In [13]:
asr_processor.feature_extractor(
    raw_dataset[0]['audio']['array'],
    **{
        'return_tensors': 'pt',
        'sampling_rate': raw_dataset[0]['audio']['sampling_rate']
    },
    padding='longest'
  )

{'input_values': tensor([[ 4.3900e-05,  4.3900e-05,  4.3900e-05,  ..., -2.2098e-03,
          5.7065e-03,  2.7297e-03]]), 'attention_mask': tensor([[1, 1, 1,  ..., 1, 1, 1]], dtype=torch.int32)}

In [14]:
BATCH_SIZE = 32
processed_dataset = raw_dataset.map(partial(tokenize,
                              tokenizer=asr_processor.tokenizer,
                              feature_extractor=asr_processor.feature_extractor),
                      remove_columns=['accent', 'age', 'path', 'client_id', 'down_votes', 'up_votes', 'gender', 'locale', 'segment', 'audio', 'sentence'],
                      batched=True,
                      batch_size=BATCH_SIZE)
processed_dataset.set_format(type='torch', columns=['audio_attention_mask', 'input_values', 'sentence_attention_mask', 'input_ids'])


loader = DataLoader(processed_dataset, batch_size=BATCH_SIZE)

  0%|          | 0/5 [00:00<?, ?ba/s]

In [29]:
# Single execution without DataLoader
import time

for idx in range(5):
  # select random sample
  # sample_number = random.randint(0, len(val_df))
  # sample_name = val_df.loc[sample_number, "wav_filename"]
  # true_text = val_df.loc[sample_number, 'transcript']
  # sample_loc = SPGI_VAL_DIR + sample_name

  arr = torch.tensor(raw_dataset[idx]['audio']['array']).to(device)
  true_text = raw_dataset[idx]['sentence'].lower()

  inputs = {
    'return_tensors': "pt",
    'sampling_rate': raw_dataset[idx]['audio']['sampling_rate']
  }

  with torch.no_grad():
    s = time.time()
    inputs = processor_with_lm(arr, **inputs).to(device)
    logits = asr_model(**inputs).logits.to(device)
    e = time.time()
    # logits = asr_model(**asr_processor(arr, **inputs)).logits.to(device)
  print(logits.shape)

  transcription_no_lm = greedy_decode(logits[0].cpu().numpy(), sorted_vocab_dict, ignore_set={'_', '[pad]', '<s>', '</s>'})
  transcription_no_lm = ("".join(c for c in transcription_no_lm if c not in ["_", '^', '$'])).replace('|', ' ')

  print('_' * 60)
  # transcription_lm = processor_with_lm.batch_decode(logits.cpu().numpy()).text
  sd = time.time()
  transcription_lm = processor_with_lm.decoder.decode(logits.cpu().numpy()[0])
  ed = time.time()

  print(f'model time: {e - s}')
  print(f'decode time: {ed - sd}')

  print(f'Transcription LM: {transcription_lm}')
  print(f'Transcription NO-LM: {transcription_no_lm}')
  print(f'True text: {true_text}')

  wer_lm = wer_metric.compute(predictions=[transcription_lm], references=[true_text])
  wer_no_lm = wer_metric.compute(predictions=[transcription_no_lm], references=[true_text])
  print(f'LM WER: {wer_lm}')
  print(f'NO-LM WER: {wer_no_lm}')

  cer_lm = cer_metric.compute(predictions=[transcription_lm], references=[true_text])
  cer_no_lm = cer_metric.compute(predictions=[transcription_no_lm], references=[true_text])
  print(f'LM CER: {cer_lm}')
  print(f'NO-LM CER: {cer_no_lm}')

torch.Size([1, 194, 179])
____________________________________________________________
model time: 0.06877756118774414
decode time: 0.05829787254333496
Transcription LM: sua madre has le ha donato il rene
Transcription NO-LM: sua madre hasel le ha donato in reme 
True text: sua madre, hazel, le ha donato il rene.
LM WER: 0.375
NO-LM WER: 0.5
LM CER: 0.15384615384615385
NO-LM CER: 0.15384615384615385
torch.Size([1, 296, 179])
____________________________________________________________
model time: 0.0827324390411377
decode time: 0.17313456535339355
Transcription LM: di denti litare questopera venne però inizialmente adele resto post
Transcription NO-LM: di dede utiliztare questopera vene verò inizialmente adelore stopos 
True text: "l'idea di utilizzare quest'opera venne però inizialmente da taylor e da stokowski."
LM WER: 0.75
NO-LM WER: 0.9166666666666666
LM CER: 0.36904761904761907
NO-LM CER: 0.30952380952380953
torch.Size([1, 291, 179])
______________________________________________

In [15]:
alpha_beta_gen = ParameterGrid({
    'alpha': [0.5, 0.6, 0.7, 0.8],
    'beta': [1.0, 2.0, 3.0, 4.0]
})

result_dict = grid_search_decoder(asr_processor,
                                  asr_model,
                                  KENLM_MODEL_LOC,
                                  alpha_beta_gen,
                                  loader,
                                  store_all_results=False)

Metric dict: {'wer': Metric(name: "wer", features: {'predictions': Value(dtype='string', id='sequence'), 'references': Value(dtype='string', id='sequence')}, usage: """
Compute WER score of transcribed segments against references.

Args:
    references: List of references for each speech input.
    predictions: List of transcriptions to score.
    concatenate_texts (bool, default=False): Whether to concatenate all input texts or compute WER iteratively.

Returns:
    (float): the word error rate

Examples:

    >>> predictions = ["this is the prediction", "there is an other sample"]
    >>> references = ["this is the reference", "there is another one"]
    >>> wer = datasets.load_metric("wer")
    >>> wer_score = wer.compute(predictions=predictions, references=references)
    >>> print(wer_score)
    0.5
""", stored examples: 0), 'cer': Metric(name: "cer", features: {'predictions': Value(dtype='string', id='sequence'), 'references': Value(dtype='string', id='sequence')}, usage: """
Com

Found entries of length > 1 in alphabet. This is unusual unless style is BPE, but the alphabet was not recognized as BPE type. Is this correct?
Found entries of length > 1 in alphabet. This is unusual unless style is BPE, but the alphabet was not recognized as BPE type. Is this correct?
Unigrams and labels don't seem to agree.
Found entries of length > 1 in alphabet. This is unusual unless style is BPE, but the alphabet was not recognized as BPE type. Is this correct?
Unigrams and labels don't seem to agree.
Found entries of length > 1 in alphabet. This is unusual unless style is BPE, but the alphabet was not recognized as BPE type. Is this correct?
Unigrams and labels don't seem to agree.
Found entries of length > 1 in alphabet. This is unusual unless style is BPE, but the alphabet was not recognized as BPE type. Is this correct?
Unigrams and labels don't seem to agree.
Found entries of length > 1 in alphabet. This is unusual unless style is BPE, but the alphabet was not recognized as

Evaluating sample 1/5
Evaluating config 1/16
Evaluating config 11/16
Evaluating config 1/16
Evaluating config 11/16
Evaluating config 1/16
Evaluating config 11/16
Evaluating config 1/16
Evaluating config 11/16
Evaluating config 1/16
Evaluating config 11/16


In [31]:
!nvidia-smi

Fri Feb  4 11:59:47 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.32.03    Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla K80           Off  | 00000000:00:04.0 Off |                    0 |
| N/A   73C    P0    74W / 149W |   3247MiB / 11441MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [17]:
import pickle as pkl

with open('/content/gdrive/MyDrive/Colab Notebooks/HF Community week - ASR/result_dict.pkl', 'wb') as fp:
  pkl.dump(result_dict, fp)

In [5]:
result_dict

{(0.5, 1.0): {'cer': 0.10878332959892448},
 (0.5, 2.0): {'cer': 0.10755097468070804},
 (0.5, 3.0): {'cer': 0.10799910374187766},
 (0.5, 4.0): {'cer': 0.10855926506833967},
 (0.6, 1.0): {'cer': 0.11079991037418777},
 (0.6, 2.0): {'cer': 0.10799910374187767},
 (0.6, 3.0): {'cer': 0.1066547165583688},
 (0.6, 4.0): {'cer': 0.1090073941295093},
 (0.7, 1.0): {'cer': 0.11259242661886623},
 (0.7, 2.0): {'cer': 0.10923145866009411},
 (0.7, 3.0): {'cer': 0.1091194263948017},
 (0.7, 4.0): {'cer': 0.10777503921129285},
 (0.8, 1.0): {'cer': 0.11416087833295989},
 (0.8, 2.0): {'cer': 0.11315258794532827},
 (0.8, 3.0): {'cer': 0.11102397490477256},
 (0.8, 4.0): {'cer': 0.10990365225184852}}