# Train with Lhotse

This notebook provides the training script used in our research paper.

It uses Lhotse for data preprocessing and ESPnet2 for training.
This notebook also supports online fine-tuning.

In this notebook, we use a LoRA adapter to fine-tune the pre-trained OWSM model.
Apart from the custom dataset class, this notebook is nearly identical to `train.ipynb`.

In [None]:
# Import necessary libraries
import argparse
import random  # used for the on-the-fly augmentation in CustomDataset
from pathlib import Path

import librosa
import numpy as np
import torch
from espnet2.layers.create_adapter_fn import create_lora_adapter

# import lhotse related modules
from lhotse import CutSet
from lhotse.dataset import AudioSamples
from lhotse.recipes import prepare_librispeech

import espnetez as ez

Now, let's define the experiment directories and the pre-trained model to fine-tune.

In [None]:
# Define paths to project directories and pre-trained model
BASE_DIR = Path("path/to/egs2/librispeech_100/ez1")

STATS_DIR = f"{BASE_DIR}/exp/stats_owsm_base_finetune"
EXP_DIR = f"{BASE_DIR}/exp/owsm_base_finetune"

FINETUNE_MODEL = "espnet/owsm_v3.1_ebf_base"

We'll define a custom dataset class that reads audio and transcripts from Lhotse manifests and can apply augmentation on the fly.

In [None]:
# Define a custom dataset class that loads data from Lhotse manifests
class CustomDataset:
    def __init__(self, manifest, apply_augmentation=False, p=0.3):
        self.apply_augmentation = apply_augmentation
        # Assumes one supervision per recording (as in LibriSpeech), so the
        # number of trimmed cuts equals the number of recordings
        self.length = len(manifest['recordings'])
        self.cuts = CutSet.from_manifests(**manifest).trim_to_supervisions()
        self.p = p  # probability of applying each augmentation
        self.audio_samples = AudioSamples()
    
    def __len__(self):
        return self.length
    
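    def __getitem__(self, idx):
        monocut = self.cuts[idx]
        text = monocut.supervisions[0].text
        if self.apply_augmentation:
            # Each perturbation is applied independently with probability p
            if random.random() < self.p:
                # Speed perturbation changes both duration and pitch
                monocut = monocut.perturb_speed(
                    factor=random.choice([0.9])
                )
            if random.random() < self.p:
                # Scale the amplitude by a random gain
                monocut = monocut.perturb_volume(
                    factor=random.uniform(0.125, 2.0)
                )
            if random.random() < self.p:
                # Tempo perturbation changes duration while preserving pitch
                monocut = monocut.perturb_tempo(
                    factor=random.choice([0.9])
                )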
        
        return {
            'speech': self.audio_samples([monocut])[0][0],
            'text': text,
        }

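The augmentation gating in `CustomDataset` can be illustrated without any audio. The sketch below is a standalone simulation, not part of the training pipeline: `choose_perturbations` is a hypothetical helper that mirrors the three independent `random.random() < self.p` checks, and the seed and trial count are arbitrary choices for illustration.

```python
import random


def choose_perturbations(p, rng):
    """Return the names of the perturbations that would be applied to one cut.

    Each perturbation is gated by its own independent draw, mirroring the
    three `if random.random() < self.p` checks in `CustomDataset`.
    """
    names = ("speed", "volume", "tempo")
    return [name for name in names if rng.random() < p]


rng = random.Random(0)  # fixed seed only so the sketch is repeatable
trials = 10_000
counts = [len(choose_perturbations(0.3, rng)) for _ in range(trials)]

# With p = 0.3, a cut escapes all three perturbations with probability
# 0.7 ** 3 = 0.343, so roughly a third of the samples stay unaugmented.
untouched = counts.count(0) / trials
print(f"fraction of cuts left unaugmented: {untouched:.3f}")
```

Because the draws are independent, a single cut may receive several perturbations at once; raising `p` increases both how often and how heavily samples are augmented.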

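Since the OWSM models were trained on lowercase text, all transcripts must be lowercased. The dataset class returns the text unchanged; the conversion is applied later by the `text` and `text_ctc` entries of `data_info`.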

Next, we'll define the training configuration. This involves using the configuration from a pre-trained OWSM model as a base and then adding our own custom configurations.

In [None]:
# Load pre-trained OWSM model configuration for tokenizer, converter, and base training configuration
from espnet2.bin.s2t_inference import Speech2Text

pretrained_model = Speech2Text.from_pretrained(
    FINETUNE_MODEL,
    # category_sym="<en>",  # Comment out if not used
    beam_size=10,
    device="cpu"
)
tokenizer = pretrained_model.tokenizer
converter = pretrained_model.converter
training_config = vars(pretrained_model.s2t_train_args)
del pretrained_model

# Override the base configuration with the settings in your fine-tuning YAML file
finetune_config = ez.config.update_finetune_config(
    "s2t",
    training_config,
    "path/to/your/finetune/config.yaml"
)
finetune_config['multiple_iterator'] = False  # Disable ESPnet's multiple-iterator mode

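The idea of layering custom settings over a base configuration can be sketched with plain dictionaries. This is only an illustration of the concept: `apply_overrides` is a hypothetical helper, the values are made up, and the actual merge logic inside `ez.config.update_finetune_config` may differ (for example in how nested sections are handled).

```python
import copy


def apply_overrides(base_config, overrides):
    """Return a copy of base_config with top-level keys replaced by overrides."""
    merged = copy.deepcopy(base_config)  # leave the base configuration untouched
    merged.update(overrides)             # shallow merge: nested dicts are replaced whole
    return merged


# Hypothetical values standing in for the pre-trained model's training arguments
base = {"max_epoch": 100, "batch_size": 32, "optim_conf": {"lr": 0.001}}
# Hypothetical values standing in for the contents of the fine-tuning YAML file
overrides = {"max_epoch": 3, "optim_conf": {"lr": 0.0001}}

merged = apply_overrides(base, overrides)
merged["multiple_iterator"] = False  # individual keys can still be set afterwards
print(merged)
```

Keys absent from the overrides, such as `batch_size` here, keep the value inherited from the base configuration.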

We'll define `data_info` to connect our custom dataset with the ESPnet dataloader. This helps the dataloader understand how to process the data.

In [None]:
# Define data_info to connect custom dataset with ESPnet dataloader
def tokenize(text):
    return np.array(converter.tokens2ids(tokenizer.text2tokens(text)))


data_info = {
    "speech": lambda d: d['speech'].numpy(),  # Lhotse has already loaded the audio as a tensor
    "text": lambda d: tokenize(d['text'].lower()),
    "text_prev": lambda d: tokenize("<na>"),  # no previous-text context
    "text_ctc": lambda d: tokenize(d['text'].lower()),
}

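To see what `data_info` produces for a single item, here is a minimal sketch that uses stand-ins instead of the real model assets. `VOCAB`, `toy_tokenize`, and the sample item are hypothetical, and applying every callable to the raw item is an assumption about how the dataset wrapper uses this mapping.

```python
# Hypothetical vocabulary standing in for the OWSM tokenizer and converter
VOCAB = {"<unk>": 0, "<na>": 1, "hello": 2, "world": 3}


def toy_tokenize(text):
    """Whitespace tokenizer mapping words to IDs; unknown words map to <unk>."""
    return [VOCAB.get(token, VOCAB["<unk>"]) for token in text.split()]


toy_data_info = {
    "speech": lambda d: d["speech"],  # the real notebook converts a tensor to NumPy here
    "text": lambda d: toy_tokenize(d["text"].lower()),
    "text_prev": lambda d: toy_tokenize("<na>"),
    "text_ctc": lambda d: toy_tokenize(d["text"].lower()),
}

# A made-up item in the shape returned by CustomDataset.__getitem__
item = {"speech": [0.0, 0.1, -0.1], "text": "Hello World"}

# Apply each callable to the raw item to obtain the named model inputs
example = {name: fn(item) for name, fn in toy_data_info.items()}
print(example)
```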
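We'll define a function to prepare the model for fine-tuning. While you can define a custom model here, this notebook uses the pre-trained OWSM model directly. Note that, as written, `build_model_fn` does not attach an adapter or freeze any parameters: `create_lora_adapter` is imported above but never called, so the adapter has to be applied inside this function before LoRA fine-tuning actually takes place.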

In [None]:

def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

def build_model_fn(args):
    pretrained_model = Speech2Text.from_pretrained(
        FINETUNE_MODEL,
        beam_size=10,
    )
    model = pretrained_model.s2t_model
    model.train()
    print(f'Trainable parameters: {count_parameters(model)}')
    return model

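`count_parameters` counts only parameters with `requires_grad=True`, so its output shows directly how much a parameter-efficient method such as LoRA shrinks the trainable set. As a rough sketch of the arithmetic, assume a rank-`r` adapter adds two low-rank matrices (one `r x in_features`, one `out_features x r`) next to a frozen linear layer. The layer size and rank below are hypothetical and do not describe the actual OWSM architecture.

```python
def lora_extra_params(in_features, out_features, rank):
    """Parameters added by one LoRA adapter: A is (rank x in), B is (out x rank)."""
    return rank * (in_features + out_features)


def dense_params(in_features, out_features):
    """Weights of the frozen linear layer being adapted (bias ignored)."""
    return in_features * out_features


d_model, rank = 512, 8  # hypothetical sizes chosen for illustration
added = lora_extra_params(d_model, d_model, rank)
full = dense_params(d_model, d_model)
print(f"adapter: {added} params, frozen layer: {full} params, ratio: {added / full:.1%}")
```

For these sizes the adapter trains 8,192 parameters against 262,144 in the frozen layer, which is why the trainable count should drop sharply once the base weights are frozen and only adapters are updated.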

We're almost ready!
Now, let's prepare the datasets and wrap them in `ESPnetEZDataset`, the format that the ESPnet dataloader uses to process and feed data during training.

We use Lhotse's `prepare_librispeech` recipe to build the LibriSpeech manifests, but you can supply your own manifests instead.

In [None]:
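librispeech_100_path = "path/to/LibriSpeech"
# Build Lhotse manifests; the result is keyed by split name, and each split
# holds the "recordings" and "supervisions" manifests used by CustomDataset
libri = prepare_librispeech(librispeech_100_path)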

train_dataset = CustomDataset(libri["train-clean-100"], apply_augmentation=True)
dev_dataset = CustomDataset(libri["dev-clean"], apply_augmentation=False)

train_dataset = ez.dataset.ESPnetEZDataset(train_dataset, data_info=data_info)
dev_dataset = ez.dataset.ESPnetEZDataset(dev_dataset, data_info=data_info)

Now that everything is set up, let's start the training process!

In [None]:
trainer = ez.Trainer(
    task="s2t",
    train_config=finetune_config,
    train_dataset=train_dataset,
    valid_dataset=dev_dataset,
    data_info=data_info,
    build_model_fn=build_model_fn,
    output_dir=EXP_DIR,
    stats_dir=STATS_DIR,
    ngpu=1,
)
trainer.collect_stats()  # collect shape statistics into STATS_DIR before training
trainer.train()