# Train with SLURP

This notebook provides the training script used in our research paper.

It uses data pre-processed by ESPnet to fine-tune a pre-trained OWSM speech-to-text model for intent classification. The data consists of audio file paths, the corresponding text transcripts, and intent labels. The audio file paths may point to either the original or copied versions of the audio files.

In [None]:
# Import necessary libraries
import torch
import numpy as np
import librosa
from pathlib import Path
from espnet2.layers.create_adapter_fn import create_lora_adapter
import argparse

import espnetez as ez

Now, let's define the paths to your dumped files and other important training parameters.

In [None]:
# Define paths to project directories and pre-trained model
BASE_DIR = Path("path/to/egs2/slurp/ez1")

DATA_PATH = f"{BASE_DIR}/data"
DUMP_DIR = f"{BASE_DIR}/dump/raw"
STATS_DIR = f"{BASE_DIR}/exp/stats_owsm_base_finetune"
EXP_DIR = f"{BASE_DIR}/exp/owsm_base_finetune"

FINETUNE_MODEL = "espnet/owsm_v3.1_ebf_base"
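
Before going further, it can help to confirm that the dump directory contains the files the dataset class in this notebook reads. The following is a minimal sketch: the helper name `find_missing_files` is hypothetical, and the file list (`wav.scp`, `transcript`, and `text` under `train` and `devel`) is taken from that dataset class.

```python
from pathlib import Path

# Files that the dataset class in this notebook opens for each split.
EXPECTED_SPLITS = ("train", "devel")
EXPECTED_FILES = ("wav.scp", "transcript", "text")


def find_missing_files(dump_dir):
    """Return the expected files that do not exist under dump_dir (hypothetical helper)."""
    dump_dir = Path(dump_dir)
    return [
        dump_dir / split / name
        for split in EXPECTED_SPLITS
        for name in EXPECTED_FILES
        if not (dump_dir / split / name).is_file()
    ]
```

Calling `find_missing_files(DUMP_DIR)` returns an empty list when every expected file is present.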

We'll define a custom dataset class to retrieve data from the dumped files.

In [None]:
# Define a custom dataset class to load data from preprocessed files
class CustomDataset:
    def __init__(self, data_path, is_train=True):
        self.data_path = data_path
        if is_train:
            data_path = f"{data_path}/train"
        else:
            data_path = f"{data_path}/devel"

        self.data = {}
        with open(f"{data_path}/wav.scp", "r") as f:
            for line in f.readlines():
                audio_id, audio_path = line.strip().split(maxsplit=1)
                self.data[audio_id] = {
                    'audio_path': audio_path
                }

        with open(f"{data_path}/transcript", "r") as f:
            for line in f.readlines():
                audio_id, transcript = line.strip().split(maxsplit=1)
                self.data[audio_id]['transcript'] = transcript

        with open(f"{data_path}/text", "r") as f:
            for line in f.readlines():
                audio_id, intent, _ = line.strip().split(maxsplit=2)
                self.data[audio_id]['intent'] = intent

        self.keys = list(self.data.keys())

    def __len__(self):
        return len(self.keys)

    def __getitem__(self, idx):
        # Accept either an integer position or an utterance ID.
        key = self.keys[idx] if isinstance(idx, int) else idx
        item = self.data[key]
        return {
            'audio_path': item['audio_path'],
            'intent': item['intent'],
            'transcript': item['transcript']
        }
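
The three files read above hold one utterance per line, with the utterance ID as the first whitespace-separated field. The sketch below shows how such lines split. The helper `parse_scp_line` is hypothetical, and the sample lines are made up for illustration rather than taken from the actual SLURP dump.

```python
def parse_scp_line(line, n_fields=2):
    """Split a line into n_fields whitespace-separated fields.

    The last field keeps any internal spaces, which matters for paths and
    transcripts that contain whitespace.
    """
    fields = line.strip().split(maxsplit=n_fields - 1)
    if len(fields) != n_fields:
        raise ValueError(f"expected {n_fields} fields, got {len(fields)}: {line!r}")
    return fields


# Made-up sample lines, for illustration only.
wav_line = "utt-0001 /data/audio/utt 0001.flac\n"
transcript_line = "utt-0001 Wake me up at five\n"
text_line = "utt-0001 alarm_set wake me up at five\n"

audio_id, audio_path = parse_scp_line(wav_line)
_, transcript = parse_scp_line(transcript_line)
_, intent, _ = parse_scp_line(text_line, n_fields=3)
```

Raising an error on short lines makes a malformed entry visible at load time instead of surfacing later as a missing key.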

Since the OWSM models were trained on lowercase text, we convert the transcripts to lowercase; this happens in the `text_ctc` entry of `data_info` further down.

Next, we'll define the training configuration. This involves using the configuration from a pre-trained OWSM model as a base and then adding our own custom configurations.

In [None]:
ADDITIONAL_SPECIAL_TOKENS = [
    "<intent>"
]

# Load pre-trained OWSM model configuration for tokenizer, converter, and base training configuration
from espnet2.bin.s2t_inference import Speech2Text
pretrained_model = Speech2Text.from_pretrained(
    FINETUNE_MODEL,
)
tokenizer = pretrained_model.tokenizer
converter = pretrained_model.converter

# Add the new <intent> token after the ST-related tokens
tokenizer, converter, _ = ez.preprocess.add_special_tokens(
    tokenizer, converter, pretrained_model.s2t_model.decoder.embed[0],
    ADDITIONAL_SPECIAL_TOKENS, insert_after="<st_zho>"
)

training_config = vars(pretrained_model.s2t_train_args)
del pretrained_model

# Update the fine-tuning configuration from a user-defined YAML file
finetune_config = ez.config.update_finetune_config(
    "s2t",
    training_config,
    "path/to/your/finetune/config.yaml"
)
# Override individual training parameters
finetune_config['multiple_iterator'] = False
finetune_config['resume'] = False
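
Conceptually, the fine-tuning configuration is the pre-trained model's configuration with user-supplied values laid on top. The sketch below illustrates that idea with plain dictionaries. It is not the implementation of `ez.config.update_finetune_config`; the helper `overlay_config`, the keys and values shown, and the key-by-key merging of nested sections are all assumptions made for illustration.

```python
import copy


def overlay_config(base, overrides):
    """Return a copy of base where values from overrides take precedence.

    Nested dictionaries are merged key by key; everything else is replaced.
    """
    merged = copy.deepcopy(base)
    for key, value in overrides.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = overlay_config(merged[key], value)
        else:
            merged[key] = copy.deepcopy(value)
    return merged


# Hypothetical values, for illustration only.
base_config = {
    "max_epoch": 50,
    "optim": "adamw",
    "optim_conf": {"lr": 1e-3, "weight_decay": 0.0},
}
user_config = {"max_epoch": 10, "optim_conf": {"lr": 1e-4}}

example_config = overlay_config(base_config, user_config)
```

Working on a deep copy leaves the base configuration untouched, so it can be reused for other experiments.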


We'll define `data_info` to connect our custom dataset with the ESPnet dataloader. This helps the dataloader understand how to process the data.

In [None]:
# Define data_info to connect custom dataset with ESPnet dataloader
def tokenize(text):
    return np.array(converter.tokens2ids(tokenizer.text2tokens(text)))


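data_info = {
    # Load the waveform, resampled to 16 kHz
    "speech": lambda d: librosa.load(d['audio_path'], sr=16000)[0],
    # Decoder target: language token, <intent> task token, then the intent label
    "text": lambda d: tokenize(f"<eng><intent><notimestamps>{d['intent']}"),
    # No previous-text context
    "text_prev": lambda d: tokenize("<na>"),
    # CTC target: the lowercased transcript
    "text_ctc": lambda d: tokenize(d['transcript'].lower()),
}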
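
The `text` entry places a language token, the new `<intent>` token, and `<notimestamps>` in front of the intent label, so the label always follows a fixed prefix of special tokens. The sketch below uses a toy tokenizer (not the OWSM tokenizer) to show how such a string splits into special tokens followed by ordinary text. The regular expression, the helper names, and the `alarm_set` label are assumptions made for illustration.

```python
import re

# Matches anything of the form <...>; a stand-in for a real special-token list.
SPECIAL_TOKEN_PATTERN = re.compile(r"(<[^<>]+>)")


def build_intent_prompt(intent, lang="<eng>"):
    """Build a target string with the same layout as the `text` entry above."""
    return f"{lang}<intent><notimestamps>{intent}"


def toy_tokenize(text):
    """Keep each <...> token whole and split the remaining text on whitespace."""
    tokens = []
    for piece in SPECIAL_TOKEN_PATTERN.split(text):
        if SPECIAL_TOKEN_PATTERN.fullmatch(piece):
            tokens.append(piece)
        else:
            tokens.extend(piece.split())
    return tokens


# Made-up intent label, for illustration only.
prompt = build_intent_prompt("alarm_set")
tokens = toy_tokenize(prompt)
```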

We'll define a function to prepare the model for fine-tuning. While you can define a custom model here, this notebook uses the pre-trained OWSM model directly.

In [None]:

def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


def build_model_fn(args):
    pretrained_model = Speech2Text.from_pretrained(
        FINETUNE_MODEL,
        beam_size=10,
    )
    model = pretrained_model.s2t_model
    model.train()
    print(f'Trainable parameters: {count_parameters(model)}')
    # Add new <intent> token
    _, _, new_embedding = ez.preprocess.add_special_tokens(
        pretrained_model.tokenizer,
        pretrained_model.converter,
        model.decoder.embed[0],
        ADDITIONAL_SPECIAL_TOKENS,
        insert_after="<st_zho>"
    )
    new_embedding.weight.requires_grad = True
    model.decoder.embed[0] = new_embedding
    return model
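
Inserting a token in the middle of an ordered vocabulary shifts the ID of every token after it, so the embedding table needs a new row at the same position. The sketch below shows that bookkeeping with plain Python lists. It illustrates the idea only and is not the implementation of `ez.preprocess.add_special_tokens`; the helper `insert_token`, the tiny vocabulary, and the zero initialization of the new row are assumptions.

```python
def insert_token(vocab, embeddings, new_token, insert_after, init_row=None):
    """Insert new_token right after insert_after and add a matching embedding row.

    vocab is a list of token strings and embeddings a list of equal-length
    rows, one per token. Both inputs are left unmodified.
    """
    if new_token in vocab:
        raise ValueError(f"{new_token!r} is already in the vocabulary")
    position = vocab.index(insert_after) + 1
    dim = len(embeddings[0])
    row = list(init_row) if init_row is not None else [0.0] * dim
    if len(row) != dim:
        raise ValueError("init_row must match the embedding dimension")
    new_vocab = vocab[:position] + [new_token] + vocab[position:]
    new_embeddings = (
        [list(r) for r in embeddings[:position]]
        + [row]
        + [list(r) for r in embeddings[position:]]
    )
    return new_vocab, new_embeddings


# Tiny made-up vocabulary and 2-dimensional embeddings.
vocab = ["<blank>", "<st_zho>", "<notimestamps>", "hello"]
embeddings = [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6], [0.7, 0.8]]

new_vocab, new_embeddings = insert_token(
    vocab, embeddings, "<intent>", insert_after="<st_zho>"
)
```

In this toy example `<notimestamps>` moves from ID 2 to ID 3, which is why the tokenizer, the converter, and the embedding have to be updated together.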

We're almost ready!
Now, let's prepare the datasets and convert them into the format expected by ESPnet's dataloader. The dataloader relies on `ESPnetEZDataset` objects to process and feed data during training.

In [None]:
# Read the train and devel splits from the dump directory defined earlier
train_dataset = CustomDataset(data_path=DUMP_DIR, is_train=True)
dev_dataset = CustomDataset(data_path=DUMP_DIR, is_train=False)

train_dataset = ez.dataset.ESPnetEZDataset(train_dataset, data_info=data_info)
dev_dataset = ez.dataset.ESPnetEZDataset(dev_dataset, data_info=data_info)
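
The wrapper pairs each raw item with the `data_info` callables, so every key in `data_info` becomes one field handed to the dataloader. The toy class below sketches that idea. It is a simplified stand-in, not the actual `ESPnetEZDataset` code, and the sample item and the stand-in callables are hypothetical.

```python
class ToyMappedDataset:
    """Apply a dict of callables to each raw item (simplified stand-in)."""

    def __init__(self, dataset, data_info):
        self.dataset = dataset
        self.data_info = data_info

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        item = self.dataset[idx]
        # Each key maps to the result of its callable on the raw item.
        return {name: fn(item) for name, fn in self.data_info.items()}


# Hypothetical raw item shaped like the output of CustomDataset.
raw_items = [
    {"audio_path": "a.wav", "intent": "alarm_set", "transcript": "Wake me up"},
]
# Stand-in callables that return strings instead of token IDs.
toy_info = {
    "text": lambda d: f"<eng><intent><notimestamps>{d['intent']}",
    "text_ctc": lambda d: d["transcript"].lower(),
}

mapped = ToyMappedDataset(raw_items, toy_info)
sample = mapped[0]
```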

Now that everything is set up, let's start the training process!

In [None]:
trainer = ez.Trainer(
    task="s2t",
    train_config=finetune_config,
    train_dataset=train_dataset,
    valid_dataset=dev_dataset,
    data_info=data_info,
    build_model_fn=build_model_fn,
    output_dir=EXP_DIR,
    stats_dir=STATS_DIR,
    ngpu=1,
)
# Collect dataset statistics first, then run fine-tuning
trainer.collect_stats()
trainer.train()