# Imports

In [None]:
import numpy as np
import time
import torch
import os
from typing import List, Dict, Union, Set, Any
from torch import nn
from tqdm.auto import tqdm
from torch.utils.data import Dataset, DataLoader
from collections import defaultdict, Counter
from pathlib import Path
import pandas as pd
import soundfile as sf
import torchaudio
import warnings

import matplotlib.pyplot as plt
%matplotlib inline
from IPython import display


In [None]:
!pip install transformers
!pip install datasets
!pip install evaluate
!pip install jiwer

Collecting jiwer
  Downloading jiwer-3.0.5-py3-none-any.whl.metadata (2.7 kB)
Collecting rapidfuzz<4,>=3 (from jiwer)
  Downloading rapidfuzz-3.10.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (11 kB)
Downloading jiwer-3.0.5-py3-none-any.whl (21 kB)
Downloading rapidfuzz-3.10.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (3.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.1/3.1 MB[0m [31m26.9 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: rapidfuzz, jiwer
Successfully installed jiwer-3.0.5 rapidfuzz-3.10.1


# Download data

In [None]:
# !pip install kaggle

In [None]:
# Kaggle CLI docs: https://github.com/Kaggle/kaggle-api
# Requires a Kaggle API token: on https://www.kaggle.com/settings click "Create new token"
# and move the downloaded token file into "~/.kaggle".
# Downloads the TIMIT corpus as a single zip into the current working directory.

!kaggle datasets download -d mfekadu/darpa-timit-acousticphonetic-continuous-speech

Dataset URL: https://www.kaggle.com/datasets/mfekadu/darpa-timit-acousticphonetic-continuous-speech
License(s): copyright-authors
Downloading darpa-timit-acousticphonetic-continuous-speech.zip to /content
 99% 819M/829M [00:05<00:00, 118MB/s]
100% 829M/829M [00:05<00:00, 148MB/s]


In [None]:
# Extract into timit/ (-o: overwrite without prompting, -q: quiet); the dataset
# cells later read the splits from timit/data/TRAIN/ and timit/data/TEST/.
!unzip -o -q darpa-timit-acousticphonetic-continuous-speech.zip -d timit/

# Dataset

In [None]:
import librosa
class TimitDataset(Dataset):
    """Index a TIMIT split on disk and load its utterances.

    ``data_path`` is walked recursively and files are grouped by utterance URI
    ``"<directory>/<stem>"``:

    * ``*.wav`` -- audio. Names look like ``SA1.WAV.wav``, so both suffixes are
      stripped to get the stem. Upper-case ``*.WAV`` files are ignored in favour
      of the ``.wav`` copies.
    * ``*.TXT`` -- transcript, ``*.WRD`` -- word alignment, ``*.PHN`` -- phone
      alignment.

    Any other file triggers a warning. Only URIs that have all four file kinds
    are kept, sorted, in ``self.uris``. Every accessor accepts either an integer
    index into ``self.uris`` or a URI string.
    """

    def __init__(self, data_path):
        self.data_path = data_path
        self.uri2wav = {}
        self.uri2text = {}
        self.uri2word_ali = {}
        self.uri2phone_ali = {}
        for d, _, fs in os.walk(data_path):
            for f in fs:
                full_path = f'{d}/{f}'
                if f.endswith('.WAV'):
                    # Skip it: the converted .wav copy is used instead.
                    continue
                elif f.endswith('.wav'):
                    stem = Path(f[:-4]).stem  # "SA1.WAV.wav" -> "SA1"
                    self.uri2wav[f'{d}/{stem}'] = full_path
                elif f.endswith('.TXT'):
                    self.uri2text[f'{d}/{Path(f).stem}'] = full_path
                elif f.endswith('.WRD'):
                    self.uri2word_ali[f'{d}/{Path(f).stem}'] = full_path
                elif f.endswith('.PHN'):
                    self.uri2phone_ali[f'{d}/{Path(f).stem}'] = full_path
                else:
                    warnings.warn(f"Unknown file type {full_path} . Skip it.")

        # Keep only utterances for which every kind of file exists.
        complete = (set(self.uri2wav) & set(self.uri2text)
                    & set(self.uri2word_ali) & set(self.uri2phone_ali))
        self.uris = sorted(complete)
        print(f"Found {len(self.uris)} utterances in {self.data_path}. ",
              f"{len(self.uri2wav)} wavs, ",
              f"{len(self.uri2text)} texts, ",
              f"{len(self.uri2word_ali)} word alinments, ",
              f"{len(self.uri2phone_ali)} phone alignments")

    def get_uri(self, index_or_uri: Union[str, int]):
        """Return the URI for an integer index, or the argument itself if it is already a URI string."""
        if isinstance(index_or_uri, str):
            return index_or_uri
        return self.uris[index_or_uri]

    def get_audio(self, index_or_uri: Union[str, int]):
        """Return ``(waveform, sample_rate)`` loaded at the file's native rate (no resampling)."""
        wav_path = self.uri2wav[self.get_uri(index_or_uri)]
        # wav_channels, sr = torchaudio.load(wav_path)
        wav_channels, sr = librosa.load(wav_path, sr = None)
        return wav_channels, sr

    def get_text(self, index_or_uri: Union[str, int]):
        """Return ``(start_sample, stop_sample, text)`` from the ``.TXT`` file.

        Raises:
            AssertionError: if the transcript does not start at sample 0.
        """
        txt_path = self.uri2text[self.get_uri(index_or_uri)]
        with open(txt_path) as f:
            start, stop, text = f.read().strip().split(maxsplit=2)
            start, stop = int(start), int(stop)
            assert start == 0, f"{txt_path}"
        return start, stop, text

    @staticmethod
    def _read_alignment(path):
        """Parse a ``.WRD``/``.PHN`` file into ``[(start_sample, stop_sample, label), ...]``.

        Blank lines (e.g. trailing empty lines) are skipped instead of crashing.

        Raises:
            ValueError: if a non-blank line does not have exactly three fields.
        """
        segments = []
        with open(path) as f:
            for line in f:
                fields = line.split()
                if not fields:
                    continue
                if len(fields) != 3:
                    raise ValueError(f"{path}: expected 'start stop label', got {line!r}")
                start, stop, label = fields
                segments.append((int(start), int(stop), label))
        return segments

    def get_word_ali(self, index_or_uri):
        """ Return [(start_sample, stop_sample, word), ...]"""
        return self._read_alignment(self.uri2word_ali[self.get_uri(index_or_uri)])

    def get_phone_ali(self, index_or_uri):
        """ Return [(start_sample, stop_sample, phone), ...]"""
        return self._read_alignment(self.uri2phone_ali[self.get_uri(index_or_uri)])

    def __getitem__(self, index):
        return {"uri": self.get_uri(index),
                "audio": self.get_audio(index),
                "text": self.get_text(index),
                "word_ali": self.get_word_ali(index),
                "phone_ali": self.get_phone_ali(index)}

    def __len__(self):
        return len(self.uris)

    def _iter_labels(self, get_ali):
        """Yield the lower-cased, stripped label of every segment ``get_ali(uri)`` returns, over all utterances."""
        for uri in self.uris:
            for _, _, label in get_ali(uri):
                yield label.lower().strip()

    def total_audio_samples(self) -> int:
        """Total number of audio samples over all utterances (loads every wav)."""
        return sum(len(self.get_audio(uri)[0]) for uri in self.uris)

    def total_num_words(self) -> int:
        """Total number of word-alignment segments over all utterances."""
        return sum(len(self.get_word_ali(uri)) for uri in self.uris)

    def total_num_phones(self) -> int:
        """Total number of phone-alignment segments over all utterances."""
        return sum(len(self.get_phone_ali(uri)) for uri in self.uris)

    def get_vocab(self) -> Set[str]:
        """Set of distinct (lower-cased) words in the word alignments."""
        return set(self._iter_labels(self.get_word_ali))

    def get_phones(self) -> Set[str]:
        """Set of distinct (lower-cased) phone labels in the phone alignments."""
        return set(self._iter_labels(self.get_phone_ali))

    def phones_prior(self) -> Dict[str, float]:
        """Relative frequency of each (lower-cased) phone label, in first-seen order."""
        counts = Counter(self._iter_labels(self.get_phone_ali))
        total_count = sum(counts.values())
        return {phone: n / total_count for phone, n in counts.items()}


In [None]:
class FeatsDataset(TimitDataset):
    """TIMIT utterances as raw waveform + transcript, batched by ``collate_pad``.

    Args:
        data_path: root of a TIMIT split (indexed by ``TimitDataset``).
        feature_extractor: callable mapping a list of waveforms to an object with
            a padded ``input_values`` tensor.
        tokenizer: callable mapping a list of strings to an object with padded
            ``input_ids``.
        sr: sampling rate handed to the feature extractor; every waveform in a
            batch must already be at this rate.
    """
    def __init__(self, data_path, feature_extractor, tokenizer, sr = 16000):
        super().__init__(data_path)
        self.feature_extractor = feature_extractor
        self.tokenizer = tokenizer
        self.sr = sr

    def __getitem__(self, index):
        """Return ``{"wav": waveform, "sr": its native sampling rate, "text": transcript}``."""
        orig_item = super().__getitem__(index)
        wav, sr = orig_item['audio']
        text = orig_item['text'][2]  # (start, stop, text) -> text

        return {
                "wav": wav,
                "sr": sr,
                "text": text
                }

    def collate_pad(self, batch: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
        """Tokenize transcripts and extract padded features for a batch of items.

        Returns:
            ``{'input_values': padded features, 'labels': padded token ids}``.

        Raises:
            ValueError: if any waveform's native rate differs from ``self.sr``.
        """
        batch_texts = [d['text'] for d in batch]
        batch_wavs = [d['wav'] for d in batch]
        sr = self.sr

        # Audio is loaded at its native rate (no resampling), but the extractor is
        # told `self.sr`; fail loudly rather than feed it audio at the wrong rate.
        wrong_rates = sorted({d['sr'] for d in batch} - {sr})
        if wrong_rates:
            raise ValueError(f"Expected audio sampled at {sr} Hz, got {wrong_rates}")

        # NOTE(review): padded label positions keep tokenizer.pad_token_id. Hugging Face
        # seq2seq losses ignore only -100, so padding is presumably trained on. TODO:
        # consider masking padding with -100; compute_metrics must then map -100 back
        # to the pad id before decoding.
        labels = self.tokenizer(batch_texts, return_tensors="pt", padding=True).input_ids
        feats = self.feature_extractor(batch_wavs, sampling_rate=sr, return_tensors="pt", padding='longest').input_values

        return {'input_values': feats,
               'labels': labels,
               }


# Training

In [None]:
from transformers import AutoTokenizer, AutoFeatureExtractor, SpeechEncoderDecoderModel
from datasets import load_dataset

# Warm-start checkpoints: wav2vec2 is the acoustic encoder and BERT the text decoder.
encoder_id, decoder_id = (
    "facebook/wav2vec2-base-960h",
    "google-bert/bert-base-uncased",
)

In [None]:
# Audio preprocessor matching the encoder checkpoint.
feature_extractor = AutoFeatureExtractor.from_pretrained(pretrained_model_name_or_path=encoder_id)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


In [None]:
# Text tokenizer matching the decoder checkpoint.
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=decoder_id)

In [None]:
# Both splits share the same feature extractor and tokenizer; TEST is built first.
test_ds, train_ds = (
    FeatsDataset(f'timit/data/{split}/', feature_extractor=feature_extractor, tokenizer=tokenizer)
    for split in ('TEST', 'TRAIN')
)

Found 1680 utterances in timit/data/TEST/.  1680 wavs,  1680 texts,  1680 word alinments,  1680 phone alignments
Found 4620 utterances in timit/data/TRAIN/.  4620 wavs,  4620 texts,  4620 word alinments,  4620 phone alignments


In [None]:
# Seed *before* building the model: weights missing from the checkpoints
# (wav2vec2.masked_spec_embed and the decoder's cross-attention layers, per the
# log below) are randomly initialised here. Without a seed, those initial values
# change on every run. 42 matches set_seed(42) in the training cell.
torch.manual_seed(42)
model = SpeechEncoderDecoderModel.from_encoder_decoder_pretrained(encoder_id, decoder_id)

# The tokenizer's [CLS] token starts decoding; padding uses its [PAD] id.
model.config.decoder_start_token_id = tokenizer.cls_token_id
model.config.pad_token_id = tokenizer.pad_token_id

Some weights of Wav2Vec2Model were not initialized from the model checkpoint at facebook/wav2vec2-base-960h and are newly initialized: ['wav2vec2.masked_spec_embed']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of BertLMHeadModel were not initialized from the model checkpoint at google-bert/bert-base-uncased and are newly initialized: ['bert.encoder.layer.0.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.0.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.0.crossattention.output.dense.bias', 'bert.encoder.layer.0.crossattention.output.dense.weight', 'bert.encoder.layer.0.crossattention.self.key.bias', 'bert.encoder.layer.0.crossattention.self.key.weight', 'bert.encoder.layer.0.crossattention.self.query.bias', 'bert.encoder.layer.0.crossattention.self.query.weight', 'bert.encoder.layer.0.crossattention.self.value.bias', 'bert.encoder.layer.0.crossattention.self.value.weight', 'bert.encod

In [None]:
#set random seed for reproducibility
import random
import torch
from transformers.file_utils import is_torch_available

def set_seed(seed: int):
    """Seed Python's ``random``, NumPy's global RNG and, when torch is available,
    torch on CPU and on every CUDA device."""
    for seeder in (random.seed, np.random.seed):
        seeder(seed)
    if not is_torch_available():
        return
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

In [None]:
from evaluate import load

# Word-error-rate metric from Hugging Face `evaluate`; the WER callable uses this global.
wer_metric = load(path="wer")

class WER():
    """Word error rate callable for ``Trainer(compute_metrics=..., batch_eval_metrics=True)``.

    The Trainer calls it once per evaluation batch. The final batch arrives with
    ``compute_result=True``. Decoded hypothesis/reference strings from *every*
    batch, including the final one, are accumulated. On ``compute_result=True``
    a single corpus-level WER is computed over all of them, and the buffers are
    reset for the next evaluation. Intermediate calls return the current
    batch's WER.

    Hypotheses are the argmax of the model's logits. These are presumably
    teacher-forced (Trainer prediction step), so this is not free-running
    decoding WER -- TODO confirm.

    Args:
        tokenizer: object with ``batch_decode`` and ``pad_token_id``. Defaults to
            the notebook-level ``tokenizer``, looked up at call time.
        metric: object with ``compute(predictions=..., references=...)``.
            Defaults to the notebook-level ``wer_metric``, looked up at call time.
    """

    def __init__(self, tokenizer=None, metric=None):
        self.tokenizer = tokenizer
        self.metric = metric
        self._predictions = []
        self._references = []

    def _get_tokenizer(self):
        return self.tokenizer if self.tokenizer is not None else tokenizer

    def _get_metric(self):
        return self.metric if self.metric is not None else wer_metric

    def _decode(self, pred):
        """Return ``(hypothesis strings, reference strings)`` for one batch."""
        tok = self._get_tokenizer()
        logits = pred.predictions
        if isinstance(logits, (tuple, list)):
            # The model returns several outputs; the first one is the LM logits.
            logits = logits[0]
        pred_ids = torch.as_tensor(logits).argmax(dim=-1)

        pad_id = tok.pad_token_id
        label_ids = torch.as_tensor(pred.label_ids)
        # -100 marks ignored label positions (also used by the Trainer when padding
        # across batches); it is not a valid token id and would break decoding.
        label_ids = label_ids.masked_fill(label_ids == -100, pad_id)
        # Score only positions that carry a real target token.
        not_target = (label_ids == pad_id).to(pred_ids.device)
        pred_ids = pred_ids.masked_fill(not_target, pad_id)

        pred_str = tok.batch_decode(pred_ids, skip_special_tokens=True)
        label_str = tok.batch_decode(label_ids, skip_special_tokens=True)
        return pred_str, label_str

    def compute_metrics(self, pred):
        """Return the WER of a single batch (no accumulation)."""
        pred_str, label_str = self._decode(pred)
        return self._get_metric().compute(predictions=pred_str, references=label_str)

    def __call__(self, pred, compute_result=True):
        """Accumulate this batch. Return the corpus WER if ``compute_result``, else the batch WER."""
        pred_str, label_str = self._decode(pred)
        self._predictions.extend(pred_str)
        self._references.extend(label_str)
        metric = self._get_metric()
        if compute_result:
            wer = metric.compute(predictions=self._predictions, references=self._references)
            self._predictions, self._references = [], []
        else:
            wer = metric.compute(predictions=pred_str, references=label_str)
        return {'wer': wer}


In [None]:
from transformers import EarlyStoppingCallback, TrainingArguments, Trainer

# Seed Python / NumPy / torch before building the Trainer.
set_seed(42)

# NOTE(review): eval_dataset is the TIMIT TEST split, and it drives both
# EarlyStoppingCallback and load_best_model_at_end. Model selection therefore
# happens on the test set, which makes the reported WER optimistic. Consider a
# held-out validation split carved from TRAIN.
# NOTE(review): the WER comes from argmax over model logits (see WER class), not
# from model.generate -- presumably teacher-forced; confirm before quoting it.
training_args = TrainingArguments(
    output_dir="AED_from_pretrained",
    # optimisation
    learning_rate=3e-5,
    weight_decay=0.01,
    lr_scheduler_type="cosine",
    warmup_steps=100,
    num_train_epochs=10,
    # batching
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    dataloader_drop_last=True,
    # evaluate + checkpoint every epoch; keep the lowest-WER model
    eval_strategy="epoch",
    save_strategy="epoch",
    save_total_limit=1,
    load_best_model_at_end=True,
    metric_for_best_model='wer',
    greater_is_better=False,
    batch_eval_metrics=True,
    # misc
    logging_steps=50,
    remove_unused_columns=False,  # keep the raw 'wav'/'text' keys that collate_pad reads
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_ds,
    eval_dataset=test_ds,
    data_collator=train_ds.collate_pad,
    compute_metrics=WER(),
    callbacks=[EarlyStoppingCallback(early_stopping_patience=3)],
)

trainer.train()

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33msaltandmatches[0m. Use [1m`wandb login --relogin`[0m to force relogin


We strongly recommend passing in an `attention_mask` since your input_ids may be padded. See https://huggingface.co/docs/transformers/troubleshooting#incorrect-output-when-padding-tokens-arent-masked.
Could not estimate the number of tokens of the input, floating-point operations will not be computed


Epoch,Training Loss,Validation Loss,Wer
1,2.3258,2.836602,0.681021
2,1.2708,2.878866,0.65716
3,1.0131,2.849786,0.641864
4,0.8572,2.846958,0.637151


There were missing keys in the checkpoint model loaded: ['decoder.cls.predictions.decoder.weight', 'decoder.cls.predictions.decoder.bias'].


TrainOutput(global_step=2308, training_loss=1.8007373619740716, metrics={'train_runtime': 1854.2275, 'train_samples_per_second': 24.916, 'train_steps_per_second': 3.112, 'total_flos': 0.0, 'train_loss': 1.8007373619740716, 'epoch': 4.0})