# Train a Cascaded ST Model with Hugging Face Transformers (HF)

This notebook provides the training script used in our research paper.

It uses pre-processed data generated by ESPnet to train a cascaded speech translation (ST) model. The data include audio file paths, English transcripts, and German translations. The audio file paths may point to either the original audio files or copies of them.

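In this notebook, we build a cascaded speech translation system: a pre-trained ESPnet speech recognition model transcribes the English audio, and a Hugging Face Transformers (HF) sequence-to-sequence model (T5) is fine-tuned to translate the transcripts into German. Everything else is the same as in `train.ipynb`, except for the custom fine-tuning model.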

In [None]:
import torch
import torch.nn as nn
import numpy as np
import librosa
from pathlib import Path
from espnet2.layers.create_adapter_fn import create_lora_adapter
from espnet2.asr.espnet_model import ESPnetASRModel
from espnet2.train.dataset import kaldi_loader
from espnet2.train.abs_espnet_model import AbsESPnetModel
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

import espnetez as ez

Now, let's define the paths to your dumped files and other important training parameters.

In [None]:
# Root of the ESPnet recipe directory that holds the prepared MuST-C data.
BASE_DIR = Path("path/to/egs2/must_c_v2/ez1")

CONFIG = "owsm_finetune_base"
# Pretrained ESPnet model: transcribes the speech (first stage of the cascade)
# and supplies the base training configuration.
FINETUNE_MODEL = "pyf98/librispeech_100_e_branchformer"
# Hugging Face seq2seq model that is fine-tuned to translate the transcripts.
HF_MODEL = "google-t5/t5-base"
# Backward-compatible alias for the original (misspelled) name, which the
# other cells of this notebook refer to.
HF_MODEl = HF_MODEL

DATA_PATH = f"{BASE_DIR}/data"
DUMP_DIR = f"{BASE_DIR}/dump/raw"
STATS_DIR = f"{BASE_DIR}/exp/stats_huggingface_cascade"
EXP_DIR = f"{BASE_DIR}/exp/train_huggingface_cascade"

We'll define a custom dataset class to retrieve data from the dumped files.

In [None]:
# Define a custom dataset class to load data from preprocessed files
class CustomDataset:
    """Speech / English transcript / German translation triples from an ESPnet dump.

    ``data_path`` must contain a ``train.en-de`` (``is_train=True``) or
    ``dev.en-de`` (``is_train=False``) directory with:

    * ``text.tc.de`` -- one ``<utt_id> <German translation>`` per line,
    * ``text``       -- one ``<utt_id> <English transcript>`` per line,
    * ``wav.scp``    -- audio entries, read with ESPnet's ``kaldi_loader``.

    Items are dicts with ``speech`` (float32 array), ``text`` and
    ``translated``. They can be fetched by integer position (into
    ``self.keys``) or directly by utterance id.
    """

    def __init__(self, data_path, is_train=True):
        self.data_path = data_path
        if is_train:
            split_dir = f"{data_path}/train.en-de"
        else:
            split_dir = f"{data_path}/dev.en-de"

        # The translation file defines the set of utterances.
        self.data = {
            audio_id: {"translated": translated}
            for audio_id, translated in self._read_text_file(f"{split_dir}/text.tc.de")
        }
        # Every transcript id must already have a translation (KeyError otherwise).
        for audio_id, text in self._read_text_file(f"{split_dir}/text"):
            self.data[audio_id]["text"] = text

        # NOTE(review): the first utterance id is dropped here; the reason is
        # not visible in this file -- confirm it is intentional.
        self.keys = list(self.data)[1:]
        self.loader = kaldi_loader(f"{split_dir}/wav.scp")

    @staticmethod
    def _unescape(text):
        """Undo the ``&apos;``/``&quot;``/``&amp;`` escapes together with the space before each."""
        return (
            text.replace(" &apos;", "'")
            .replace(" &quot;", '"')
            .replace(" &amp;", "&")
        )

    @classmethod
    def _read_text_file(cls, path):
        """Return ``[(utt_id, unescaped_text), ...]`` from a Kaldi-style text file.

        Blank lines are skipped; a line holding only an id yields empty text.
        """
        rows = []
        # Explicit UTF-8: the German side contains non-ASCII characters.
        with open(path, "r", encoding="utf-8") as f:
            for line in f:
                parts = line.strip().split(maxsplit=1)
                if not parts:
                    continue
                text = parts[1] if len(parts) > 1 else ""
                rows.append((parts[0], cls._unescape(text)))
        return rows

    def _load_speech(self, key):
        """Load the waveform stored under utterance id ``key`` as float32.

        Handles both a ``(rate, array)`` tuple and a bare array from the loader.
        """
        audio = self.loader[key]
        if isinstance(audio, tuple):
            audio = audio[1]
        return np.asarray(audio).astype(np.float32)

    def __len__(self):
        return len(self.keys)

    def __getitem__(self, idx):
        # Integer positions index into ``self.keys``; anything else is treated
        # as an utterance id. The loader is keyed by utterance id, so the
        # position must be resolved to a key before loading audio.
        key = self.keys[idx] if isinstance(idx, (int, np.integer)) else idx
        entry = self.data[key]
        return {
            'speech': self._load_speech(key),
            'text': entry['text'],
            'translated': entry['translated'],
        }


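Next, we define a custom model class for fine-tuning. For each utterance, it decodes the speech with the pre-trained ASR model, then computes the HF model's translation loss on the resulting transcript.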

In [None]:
class CustomFinetuneModel(AbsESPnetModel):
    """Cascaded speech translation model: ESPnet ASR followed by an HF seq2seq model.

    Each utterance is decoded with the pre-trained ESPnet ``Speech2Text``
    model (``FINETUNE_MODEL``). The best hypothesis is capitalized, prefixed
    with the prompt ``"translate English to German: "`` and fed to the
    Hugging Face model (``HF_MODEl``). That model's loss against the target
    token ids is the training loss.

    Args:
        nbest: ``nbest`` passed to ``Speech2Text.from_pretrained``.
        beam_size: ``beam_size`` passed to ``Speech2Text.from_pretrained``.
        log_every: Print the mean loss every ``log_every`` forward calls.
    """

    def __init__(self, nbest=5, beam_size=10, log_every=500):
        super().__init__()
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.log_every = log_every
        self.asr_model = Speech2Text.from_pretrained(
            FINETUNE_MODEL,
            nbest=nbest,
            beam_size=beam_size,
            device=device,
        )
        self.lm = AutoModelForSeq2SeqLM.from_pretrained(
            HF_MODEl,
            device_map=device,
        )
        self.lm_tokenizer = AutoTokenizer.from_pretrained(HF_MODEl)
        # Running loss sum since the last print; reset to 0.0 after printing.
        self.log_stats = {
            'loss': 0.0,
        }
        self.iter_count = 0

    def collect_feats(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        *args,
        **kwargs,
    ):
        """Return the raw speech as the "features" (no feature extraction)."""
        return {"feats": speech, "feats_lengths": speech_lengths}

    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        **kwargs,
    ):
        """Transcribe ``speech`` with ASR and return the HF translation loss.

        Args:
            speech: Padded batch of waveforms; row ``i`` is valid up to
                ``speech_lengths[i]``.
            speech_lengths: Valid length of each waveform.
            text: Padded batch of target token ids for the HF model.
            text_lengths: Valid length of each target sequence.

        Returns:
            ``(loss, stats, weight)`` where ``stats`` holds the detached loss.
        """
        # 1. ASR: decode each utterance (trimmed to its valid length) and keep
        #    the text of the best hypothesis.
        asr_texts = []
        for utt, utt_len in zip(speech, speech_lengths):
            hyp_text = self.asr_model(utt[:utt_len])[0][0]
            asr_texts.append("translate English to German: " + hyp_text.capitalize())

        # 2. Tokenize the prompts. Padding is required to stack transcripts of
        #    different lengths into one tensor; the attention mask hides it.
        lm_inputs = self.lm_tokenizer(
            asr_texts, return_tensors="pt", padding=True
        ).to(speech.device)

        # 3. Replace label positions beyond each target's length with -100,
        #    the index HF's cross-entropy ignores, so padding is not learned.
        positions = torch.arange(text.size(1), device=text.device)
        pad_mask = positions[None, :] >= text_lengths.to(text.device)[:, None]
        labels = text.long().masked_fill(pad_mask, -100)

        lm_output = self.lm(
            input_ids=lm_inputs["input_ids"],
            attention_mask=lm_inputs["attention_mask"],
            labels=labels,
        )

        # The HF model's loss is the only training objective.
        loss = lm_output.loss
        self.log_stats['loss'] += loss.item()
        stats = {
            'loss': loss.detach()
        }

        self.iter_count += 1
        if self.iter_count % self.log_every == 0:
            _loss = self.log_stats['loss'] / self.log_every
            print(f"[{self.iter_count}] - loss: {_loss:.3f}")
            self.log_stats['loss'] = 0.0

        # NOTE(review): the weight is None here; ESPnet models usually return
        # a batch-size weight -- confirm this is fine for multi-GPU training.
        return loss, stats, None

Next, we'll define the training configuration. We use the configuration of the pre-trained model (`FINETUNE_MODEL`) as a base and then add our own custom settings.

In [None]:
# Load the pre-trained model once to borrow its tokenizer, token converter and
# training arguments; the model object itself is discarded right after.
# NOTE(review): FINETUNE_MODEL looks like an ASR model, while this loader is
# the s2t `Speech2Text`; confirm it exposes `s2t_train_args`.
from espnet2.bin.s2t_inference import Speech2Text

_base_model = Speech2Text.from_pretrained(FINETUNE_MODEL)
tokenizer, converter = _base_model.tokenizer, _base_model.converter
training_config = vars(_base_model.s2t_train_args)
del _base_model

# Update the base training config with settings from a user-provided YAML file.
finetune_config = ez.config.update_finetune_config(
    "s2t", training_config, "path/to/your/finetune/config.yaml"
)
# Turn off the `multiple_iterator` training option.
finetune_config['multiple_iterator'] = False


We'll define `data_info` to connect our custom dataset with the ESPnet dataloader. This helps the dataloader understand how to process the data.

In [None]:
# Define data_info to connect custom dataset with ESPnet dataloader.
# Each entry maps an output field to a function that extracts it from one
# CustomDataset item.
lm_tokenizer = AutoTokenizer.from_pretrained(HF_MODEl)
data_info = {
    # Raw float32 waveform.
    "speech": lambda d: d['speech'],
    # Target token ids for the HF model. `return_tensors="np"` yields a
    # (1, length) array for a single string, so take row 0 to get a 1-D
    # token sequence per utterance (presumably what ESPnet's collate pads --
    # confirm).
    # NOTE(review): the German target is upper-cased before tokenization;
    # confirm this is intended for the cased T5 vocabulary.
    "text": lambda d: lm_tokenizer(
        d['translated'].upper(), return_tensors="np"
    ).input_ids[0],
}

We'll define a function that builds the model for fine-tuning. Here it returns the `CustomFinetuneModel` defined above instead of using a pre-trained model directly.

Note that if no preprocessing is needed during the `collect_stats` process, `build_preprocess_fn` can simply return `None`.

In [None]:
def build_model_fn(args):
    """Build a fresh ``CustomFinetuneModel`` that logs its loss every 20 steps.

    ``args`` (presumably the training arguments from the trainer) is unused.
    """
    return CustomFinetuneModel(log_every=20)

def build_preprocess_fn(*args, **kwargs):
    """Disable preprocessing: ignore every argument and return ``None``."""
    return

We're almost ready!
Now, let's prepare the datasets and convert them into the format expected by ESPnet's dataloader. The dataloader relies on `ESPnetEZDataset` objects to process and feed data during training.

In [None]:
# Read the dumped data from the configured DUMP_DIR rather than a path
# relative to the current working directory.
train_dataset = CustomDataset(data_path=DUMP_DIR, is_train=True)
dev_dataset = CustomDataset(data_path=DUMP_DIR, is_train=False)

# Wrap both datasets so ESPnet's dataloader can extract fields via data_info.
train_dataset = ez.dataset.ESPnetEZDataset(train_dataset, data_info=data_info)
dev_dataset = ez.dataset.ESPnetEZDataset(dev_dataset, data_info=data_info)

We slightly modify the training configuration to prevent ESPnet from searching for tokenizer-related files.

In [None]:
# Empty character-level token list so ESPnet does not look for tokenizer
# files of its own; target tokens come from the HF tokenizer in data_info.
finetune_config.update(token_list=[], token_type='char')

Now that everything is set up, let's start the training process!

We set the task to `asr` because the model takes the same inputs as an ASR model (speech paired with a token sequence), so ESPnet's ASR training setup can be reused. The loss itself is the translation loss computed by the custom model.

In [None]:
# Assemble the ESPnet-EZ trainer on one GPU using the ASR task setup.
trainer = ez.Trainer(
    task="asr",
    train_config=finetune_config,
    ngpu=1,
    # Datasets and how to extract fields from each item.
    train_dataset=train_dataset,
    valid_dataset=dev_dataset,
    data_info=data_info,
    # Factories for the model and the (disabled) preprocessing.
    build_model_fn=build_model_fn,
    build_preprocess_fn=build_preprocess_fn,
    # Where statistics and experiment outputs are written.
    stats_dir=STATS_DIR,
    output_dir=EXP_DIR,
)

# Collect statistics first, then train.
trainer.collect_stats()
trainer.train()