## OWSM finetuning with a custom dataset
This Jupyter notebook is a step-by-step guide to finetuning an OWSM model with the ESPnet-Easy module (`espnetez`). In this demonstration, we finetune an OWSM model for the ASR task on a custom dataset.

### Data Preparation

For this tutorial, we assume a custom dataset of 654 audio files with the following directory structure:
```
audio
├── 001 [420 entries exceeds filelimit, not opening dir]
└── 002 [234 entries exceeds filelimit, not opening dir]
transcription
└── owsm_v3.1
      ├── 001.csv
      └── 002.csv
```
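Each CSV file starts with a header row, followed by one row per utterance with three fields, all in Japanese: the audio path, the transcript (`text`), and the transcript used as the CTC target (`text_ctc`). For example, a CSV file contains rows like these: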
```
audio/001/00014.wav,しゃべるたびに追いかけてくるんですけど,なんかしゃべるたびにおいかけてくるんですけど
audio/001/00015.wav,え、どうしよう,えどうしよう
audio/001/00017.wav,え、何どうしたらなおるの、これ,えなな何どうしたらなおるのこれ
```
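The transcripts use the Japanese ideographic comma (、) rather than an ASCII comma, so a standard CSV reader splits each row into exactly three fields. The sketch below parses the sample rows above with Python's `csv` module. The header line `audio_path,text,text_ctc` is a hypothetical stand-in, since the actual header is not shown.

```python
import csv
import io

# Sample rows copied from above, preceded by a hypothetical header line.
SAMPLE_CSV = """audio_path,text,text_ctc
audio/001/00014.wav,しゃべるたびに追いかけてくるんですけど,なんかしゃべるたびにおいかけてくるんですけど
audio/001/00015.wav,え、どうしよう,えどうしよう
audio/001/00017.wav,え、何どうしたらなおるの、これ,えなな何どうしたらなおるのこれ
"""


def parse_rows(fileobj):
    """Return one dict per data row, skipping the header line."""
    reader = csv.reader(fileobj)
    next(reader)  # skip the header row
    return [
        {"audio_path": path, "transcript": text, "text_ctc": text_ctc}
        for path, text, text_ctc in reader
    ]


rows = parse_rows(io.StringIO(SAMPLE_CSV))
for row in rows:
    print(row["audio_path"], row["transcript"])
```

Note that the 、 inside `え、何どうしたらなおるの、これ` stays part of the transcript instead of starting a new field.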

In [None]:
import os
from glob import glob
import torch
import numpy as np
import librosa

from espnet2.bin.s2t_inference import Speech2Text
from espnet2.layers.create_lora_adapter import create_lora_adapter
import espnetez as ez

# Paths and finetuning settings
DUMP_DIR = "./dump"
CSV_DIR = "./transcription/owsm_v3.1"  # directory containing 001.csv and 002.csv
EXP_DIR = "./exp/finetune"
STATS_DIR = "./exp/stats_finetune"

FINETUNE_MODEL = "espnet/owsm_v3.1_ebf"
LORA_TARGET = [
    "w_1", "w_2", "merge_proj"
]
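`LORA_TARGET` lists module names rather than full parameter paths. For illustration, we assume the adapter selects every module whose dotted name ends with one of the target strings. The module names below are hypothetical examples of what an E-Branchformer encoder and a Transformer decoder might expose; inspect `model.named_modules()` for the real ones.

```python
# Hypothetical module names; the real names depend on the OWSM checkpoint.
MODULE_NAMES = [
    "encoder.encoders.0.feed_forward.w_1",
    "encoder.encoders.0.feed_forward.w_2",
    "encoder.encoders.0.merge_proj",
    "encoder.encoders.0.attn.linear_q",
    "decoder.decoders.0.feed_forward.w_1",
    "decoder.decoders.0.self_attn.linear_out",
]
LORA_TARGET = ["w_1", "w_2", "merge_proj"]


def select_lora_modules(names, targets):
    """Keep module names ending with any target string (assumed matching rule)."""
    return [n for n in names if any(n.endswith(t) for t in targets)]


selected = select_lora_modules(MODULE_NAMES, LORA_TARGET)
print(selected)
```

Under this rule, the attention projections stay untouched, and only the feed-forward and branch-merging layers receive adapters.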

Next, let's define the custom dataset class. OWSM finetuning requires `audio`, `text`, `text_prev`, and `text_ctc` data for each utterance.

In [None]:
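import csv

# dataset class
class CustomDataset(torch.utils.data.Dataset):
    def __init__(self, data_list):
        # data_list is a list of dicts with keys "audio_path", "transcript", "text_ctc"
        self.data = data_list

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self._parse_single_data(self.data[idx])

    def _parse_single_data(self, d):
        # OWSM prompt: language <jpn>, task <asr>, no timestamp prediction
        text = f"<jpn><asr><notimestamps> {d['transcript']}"
        return {
            "audio_path": d["audio_path"],
            "text": text,
            "text_prev": "<na>",  # no previous-context text
            "text_ctc": d["text_ctc"],
        }


data_list = []
for csv_file in sorted(glob(os.path.join(CSV_DIR, "*.csv"))):
    with open(csv_file, "r", encoding="utf-8", newline="") as f:
        rows = [row for row in csv.reader(f) if row][1:]  # drop blank lines, skip header
    # Each row: audio path, transcript, CTC transcript
    for audio_path, transcript, text_ctc in rows:
        data_list.append(
            {"audio_path": audio_path, "transcript": transcript, "text_ctc": text_ctc}
        )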

validation_examples = 20
train_dataset = CustomDataset(data_list[:-validation_examples])
valid_dataset = CustomDataset(data_list[-validation_examples:])
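To check what each training item looks like without loading audio or importing torch, the sketch below mirrors `_parse_single_data` as a plain function and reproduces the train/validation split. The input row is taken from the sample CSV above, and the list of 654 integers is a placeholder for the parsed rows.

```python
def parse_single_data(d):
    """Torch-free copy of CustomDataset._parse_single_data, for inspection only."""
    return {
        "audio_path": d["audio_path"],
        # Prompt tokens: language <jpn>, task <asr>, no timestamp prediction.
        "text": f"<jpn><asr><notimestamps> {d['transcript']}",
        "text_prev": "<na>",  # placeholder used when there is no previous text
        "text_ctc": d["text_ctc"],
    }


row = {"audio_path": "audio/001/00015.wav", "transcript": "え、どうしよう", "text_ctc": "えどうしよう"}
item = parse_single_data(row)
print(item["text"])

# 654 utterances with 20 held out for validation.
data_list = list(range(654))  # placeholders for the parsed rows
validation_examples = 20
train = data_list[:-validation_examples]
valid = data_list[-validation_examples:]
print(len(train), len(valid))
```

Only `text` carries the prompt tokens; `text_ctc` stays a plain transcript.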


### Set up training configs and model

Since we are finetuning an OWSM model for the ASR task, we reuse the OWSM model's tokenizer and `TokenIDConverter`. We also take the pretrained model's training config as the default parameter set and override it with the finetuning configuration in `./config/finetune_with_lora.yaml`.
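Conceptually, the finetuning YAML overrides selected entries of the pretraining config, and everything else is inherited. The sketch below models this as a recursive dictionary merge. Whether `ez.config.update_finetune_config` merges nested sections exactly this way is an assumption, and the keys and values are hypothetical, not the real OWSM or `finetune_with_lora.yaml` settings.

```python
import copy


def override_config(base, overrides):
    """Return a copy of `base` in which values from `overrides` take precedence.

    Nested dicts are merged key by key; any other value replaces the base value.
    """
    merged = copy.deepcopy(base)
    for key, value in overrides.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = override_config(merged[key], value)
        else:
            merged[key] = value
    return merged


# Hypothetical values for illustration only.
pretrain = {"max_epoch": 45, "optim": "adamw", "optim_conf": {"lr": 2e-4, "weight_decay": 1e-6}}
finetune = {"max_epoch": 10, "optim_conf": {"lr": 1e-4}}

config = override_config(pretrain, finetune)
print(config)
```

The deep copy leaves the pretraining config untouched, so it can be reused for other finetuning runs.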

In [None]:
pretrained_model = Speech2Text.from_pretrained(
    FINETUNE_MODEL,
    category_sym="<jpn>",
    beam_size=10,
) # Load model to extract configs.
pretrain_config = vars(pretrained_model.s2t_train_args)
tokenizer = pretrained_model.tokenizer
converter = pretrained_model.converter
del pretrained_model

finetune_config = ez.config.update_finetune_config(
    "s2t",
    pretrain_config,
    "./config/finetune_with_lora.yaml",
)

# define model loading function
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

def build_model_fn(args):
    pretrained_model = Speech2Text.from_pretrained(
        FINETUNE_MODEL,
        category_sym="<jpn>",
        beam_size=10,
    )
    model = pretrained_model.s2t_model
    model.train()
    print(f'Trainable parameters: {count_parameters(model)}')
    # wrap the LORA_TARGET modules with LoRA adapters
    create_lora_adapter(model, target_modules=LORA_TARGET)
    print(f'Trainable parameters after LoRA: {count_parameters(model)}')
    return model
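The drop in trainable parameters printed by `build_model_fn` follows from how LoRA works: a frozen linear layer mapping `d_in` to `d_out` features gets two trainable low-rank factors of shapes `r × d_in` and `d_out × r`, adding `r * (d_in + d_out)` parameters. The layer sizes and rank below are hypothetical; check the actual model dimensions and the rank used by `create_lora_adapter`.

```python
def linear_params(d_in, d_out, bias=True):
    """Parameters in a dense layer computing y = W x + b."""
    return d_in * d_out + (d_out if bias else 0)


def lora_params(d_in, d_out, rank):
    """Trainable parameters LoRA adds: A (rank x d_in) and B (d_out x rank)."""
    return rank * (d_in + d_out)


# Hypothetical feed-forward sizes and LoRA rank.
d_model, d_ff, rank = 1024, 4096, 8
full = linear_params(d_model, d_ff)
adapter = lora_params(d_model, d_ff, rank)
print(full, adapter, f"{adapter / full:.2%}")
```

With these sizes, the adapter trains under 1% of the parameters of the layer it wraps.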

### Wrap with ESPnetEasyDataset

Before starting training, we need to convert the dataset to the ESPnet format: the dataset should output tokenized text and audio as `np.ndarray` instances.

Our `CustomDataset` returns a dictionary with `audio_path`, `text`, `text_prev`, and `text_ctc` entries. To align it with the ESPnet format, we apply the following preprocessing steps:

- Load the audio file as a single-channel `np.ndarray`, resampled to 16 kHz.
- Tokenize `text`, `text_prev`, and `text_ctc`, and convert the token IDs to `np.ndarray` instances.
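`librosa.load` returns a mono signal by default, averaging the channels, and resamples it to the requested `sr`. The channel-averaging step is simple enough to show without any audio library; the stereo samples below are made up, and resampling is left to librosa.

```python
def to_mono(channels):
    """Average a list of equal-length channels into a single channel."""
    n_channels = len(channels)
    return [sum(frame) / n_channels for frame in zip(*channels)]


# Made-up stereo samples.
left = [0.5, 0.25, -0.5]
right = [0.5, -0.25, 0.0]
mono = to_mono([left, right])
print(mono)  # [0.5, 0.0, -0.25]
```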

The `data_info` dictionary below specifies these preprocessing steps and is passed to the `ESPnetEasyDataset` constructor.
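Conceptually, the wrapper applies each function in `data_info` to the raw item and collects the results under the same keys. The sketch below mimics that with a toy whitespace tokenizer. The `MappedDataset` class and the vocabulary are hypothetical stand-ins for `ESPnetEasyDataset` and the OWSM `tokenizer`/`converter` pair, the real class may differ in details, and plain lists replace `np.ndarray` to keep the example dependency-free.

```python
class MappedDataset:
    """Hypothetical stand-in: applies each data_info function to a raw item."""

    def __init__(self, dataset, data_info):
        self.dataset = dataset
        self.data_info = data_info

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        raw = self.dataset[idx]
        return {key: fn(raw) for key, fn in self.data_info.items()}


# Toy vocabulary and whitespace tokenizer (not the OWSM BPE model).
VOCAB = {"<unk>": 0, "<na>": 1, "hello": 2, "world": 3}


def tokenize(text):
    return [VOCAB.get(tok, VOCAB["<unk>"]) for tok in text.split()]


raw_items = [{"text": "hello world", "text_prev": "<na>"}]
data_info = {
    "text": lambda d: tokenize(d["text"]),
    "text_prev": lambda d: tokenize(d["text_prev"]),
}
ds = MappedDataset(raw_items, data_info)
print(ds[0])  # {'text': [2, 3], 'text_prev': [1]}
```

Because the functions run per item, audio loading and tokenization happen lazily when the trainer requests a sample.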

In [None]:
def tokenize(text):
    return np.array(converter.tokens2ids(tokenizer.text2tokens(text)))

# Each function in data_info receives the output of CustomDataset[idx]
# and returns an np.ndarray for the corresponding key.
data_info = {
    "speech": lambda d: librosa.load(d["audio_path"], sr=16000)[0],
    "text": lambda d: tokenize(d["text"]),
    "text_prev": lambda d: tokenize(d["text_prev"]),
    "text_ctc": lambda d: tokenize(d["text_ctc"]),
}

# Convert into ESPnet-Easy dataset format
train_dataset = ez.dataset.ESPnetEasyDataset(train_dataset, data_info=data_info)
valid_dataset = ez.dataset.ESPnetEasyDataset(valid_dataset, data_info=data_info)

### Training

The configuration is set up as in the other notebooks, but the `Trainer` arguments differ: since we have not generated dump files, we skip the dump-file arguments and pass the train and validation dataset instances directly.

```
trainer = Trainer(
    ...
    train_dataset=your_train_dataset_instance,
    valid_dataset=your_valid_dataset_instance,
    ...
)
```

In [None]:
trainer = ez.Trainer(
    task='s2t',
    train_config=finetune_config,
    train_dataset=train_dataset,
    valid_dataset=valid_dataset,
    build_model_fn=build_model_fn,  # builds the pretrained OWSM model with LoRA adapters
    data_info=data_info,
    output_dir=EXP_DIR,
    stats_dir=STATS_DIR,
    ngpu=1
)
trainer.collect_stats()
trainer.train()

### Inference
When training is done, we can use the inference API to generate transcriptions. Don't forget to apply LoRA to the model before loading the finetuned weights!

In [None]:
DEVICE = "cuda"

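model = Speech2Text.from_pretrained(
    FINETUNE_MODEL,
    category_sym="<jpn>",
    beam_size=10,
    device=DEVICE
)
# Add LoRA adapters first so the checkpoint's LoRA weights have matching parameters
create_lora_adapter(model.s2t_model, target_modules=LORA_TARGET)
model.s2t_model.eval()
# Load the finetuned weights (here, the checkpoint saved after epoch 1)
d = torch.load(os.path.join(EXP_DIR, "1epoch.pth"), map_location=DEVICE)
model.s2t_model.load_state_dict(d)

# Transcribe one utterance from the dataset
speech, _ = librosa.load("audio/001/00014.wav", sr=16000)
text, *_ = model(speech)[0]
print(text)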
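The order matters because `load_state_dict` is strict by default: every key in the checkpoint must exist in the model, and vice versa. The LoRA factors saved during training only exist once the adapters have been created, so loading into an unadapted model fails with unexpected keys. The sketch below reproduces that key check in plain Python; the parameter names follow the `lora_A`/`lora_B` naming of the `loralib` package but are otherwise hypothetical.

```python
def check_strict_load(model_keys, checkpoint_keys):
    """Mimic strict state-dict loading: return (missing, unexpected) keys."""
    missing = sorted(set(model_keys) - set(checkpoint_keys))
    unexpected = sorted(set(checkpoint_keys) - set(model_keys))
    return missing, unexpected


# Hypothetical parameter names.
base_keys = ["decoder.w_1.weight", "decoder.w_1.bias"]
lora_keys = ["decoder.w_1.lora_A", "decoder.w_1.lora_B"]
checkpoint = base_keys + lora_keys  # saved from the LoRA-adapted model

# Without adapters, the LoRA factors are unexpected, so strict loading fails.
print(check_strict_load(base_keys, checkpoint))
# With adapters created first, the key sets match.
print(check_strict_load(base_keys + lora_keys, checkpoint))
```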

### Results
After finetuning, OWSM v3.1 transcribes the audio files much more accurately.

**Example**
- before finetune: 出してこの時間二のどりを。  
- after finetune: ダンスでこの世界に彩りを。
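A single example is only anecdotal. To quantify the improvement, for instance on the 20 held-out validation utterances, one could compute the character error rate (CER), a common metric for Japanese ASR: the character-level edit distance between hypothesis and reference, divided by the reference length. The sketch below is a plain-Python implementation; the test strings are made up.

```python
def edit_distance(ref, hyp):
    """Levenshtein distance (substitutions, insertions, deletions) between sequences."""
    prev = list(range(len(hyp) + 1))
    for i, r in enumerate(ref, start=1):
        curr = [i]
        for j, h in enumerate(hyp, start=1):
            curr.append(min(
                prev[j] + 1,             # deletion
                curr[j - 1] + 1,         # insertion
                prev[j - 1] + (r != h),  # substitution or match
            ))
        prev = curr
    return prev[-1]


def cer(ref, hyp):
    """Character error rate: edit distance divided by the reference length."""
    return edit_distance(ref, hyp) / len(ref)


print(cer("えどうしよう", "えどうしよ"))  # one deletion over 6 characters
```

Averaging over a whole split is usually done by summing edit distances and reference lengths before dividing, so that long utterances weigh more than short ones.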