## Finetune a model with ESPnet-Easy!
In this notebook, we will finetune a pretrained model on the LibriSpeech-100 dataset.

We will download the pretrained model from the Hugging Face Hub and apply LoRA to reduce the number of trainable parameters!
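
To see why LoRA reduces the number of trainable parameters: for a linear layer with a `d_out x d_in` weight, full finetuning updates all `d_out * d_in` weights, while LoRA freezes that weight and trains two low-rank factors of shapes `d_out x r` and `r x d_in`, i.e. `r * (d_out + d_in)` parameters. The arithmetic sketch below uses an assumed 256 x 256 projection and rank 8; these sizes are illustrative and are not read from the pretrained model.

```python
def full_params(d_out: int, d_in: int) -> int:
    """Trainable weights when finetuning a linear layer directly (bias ignored)."""
    return d_out * d_in


def lora_params(d_out: int, d_in: int, rank: int) -> int:
    """Trainable weights of the LoRA factors B (d_out x rank) and A (rank x d_in)."""
    return d_out * rank + rank * d_in


# Assumed sizes for illustration only: a 256 x 256 projection with LoRA rank 8.
d_model, rank = 256, 8
full = full_params(d_model, d_model)
lora = lora_params(d_model, d_model, rank)
print(f"full: {full}, lora: {lora}, ratio: {lora / full:.4f}")
# prints "full: 65536, lora: 4096, ratio: 0.0625"
```

For a square `d x d` layer the ratio is `2r / d`, so the relative saving grows with the layer width.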

In this notebook, we assume that the dump files have already been created.
You can prepare them by following the `libri100.ipynb` notebook.

In [None]:
!pip install ../../ -U
!pip install loralib

In [1]:
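# The dump files were prepared beforehand (see `libri100.ipynb`)
DUMP_DIR = "./dump/libri100"

# Map each data key to [file name in the dump directory, type used to load it]
data_info = {
    "speech": ["wav.scp", "sound"],
    "text": ["text", "text"],
}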
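
Each entry in `data_info` pairs a data key with a file inside the dump directory and the type used to load it. Assuming the dump files use the common Kaldi-style layout of one `<utterance-id> <value>` pair per line (an assumption about how `libri100.ipynb` writes them), the hypothetical helper `read_kaldi_style_file` below shows how such a file maps utterance IDs to values. It is only an illustration and is not part of ESPnet-Easy; the file contents are made up.

```python
import os
import tempfile


def read_kaldi_style_file(path: str) -> dict:
    """Parse '<utt_id> <value>' lines into a dict (hypothetical helper)."""
    entries = {}
    with open(path, encoding="utf-8") as f:
        for line in f:
            line = line.rstrip("\n")
            if not line:
                continue
            # Split only on the first whitespace run: transcripts contain spaces.
            utt_id, value = line.split(maxsplit=1)
            entries[utt_id] = value
    return entries


data_info = {
    "speech": ["wav.scp", "sound"],
    "text": ["text", "text"],
}

# Write a tiny made-up dump directory, then read every file listed in data_info.
with tempfile.TemporaryDirectory() as dump_dir:
    with open(os.path.join(dump_dir, "wav.scp"), "w", encoding="utf-8") as f:
        f.write("utt1 /path/to/utt1.wav\n")
    with open(os.path.join(dump_dir, "text"), "w", encoding="utf-8") as f:
        f.write("utt1 HELLO WORLD\n")
    loaded = {
        key: read_kaldi_style_file(os.path.join(dump_dir, fname))
        for key, (fname, _loader_type) in data_info.items()
    }

print(loaded)
```

The loader type (`"sound"`, `"text"`) is ignored here; in ESPnet it decides how each value is turned into model input, e.g. reading audio from the path versus tokenizing the transcript.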


### Load a pretrained model

Next, we will load a pretrained model with the `espnet_model_zoo` library.
The `from_pretrained` method downloads and initializes the model in a single step.

In [2]:
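from espnet2.bin.asr_inference import Speech2Text

# Download (if needed) and build the pretrained model
pretrained_model = Speech2Text.from_pretrained('pyf98/librispeech_conformer_hop_length160')

# Training configuration the pretrained model was originally trained with
training_config = pretrained_model.asr_train_args
# The underlying ASR model (a torch.nn.Module) that we will finetune
model = pretrained_model.asr_model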

Fetching 32 files: 100%|██████████| 32/32 [00:00<00:00, 237133.80it/s]


Then, apply LoRA to the model to reduce the number of trainable parameters.

In [3]:
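from espnet2.layers.create_lora_adapter import create_lora_adapter

# Add LoRA adapters to the attention query projections (`linear_q`);
# only the LoRA parameters are trained, the original weights stay frozen
create_lora_adapter(model, target_modules=['linear_q'])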
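
Conceptually, each adapted layer keeps its frozen pretrained weight `W` and adds a scaled low-rank update, `y = W x + (alpha / r) * B (A x)`, where only `A` and `B` are trained. In the LoRA formulation (and in `loralib`), `B` is initialized to zero, so right after the adapters are added the model still produces the pretrained outputs. The pure-Python sketch below illustrates this with tiny made-up matrices and ignores dropout; it is not the internal implementation of `create_lora_adapter`.

```python
def matvec(m, v):
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(a * b for a, b in zip(row, v)) for row in m]


def lora_forward(W, A, B, x, alpha, rank):
    """y = W x + (alpha / rank) * B (A x); W is frozen, only A and B are trained."""
    base = matvec(W, x)
    update = matvec(B, matvec(A, x))
    scale = alpha / rank
    return [b + scale * u for b, u in zip(base, update)]


# Toy sizes (made-up values): d_out = d_in = 3, rank = 1.
W = [[1.0, 0.0, 2.0], [0.0, 1.0, 0.0], [3.0, 0.0, 1.0]]
A = [[0.5, -0.5, 1.0]]             # rank x d_in, randomly initialized in practice
B_init = [[0.0], [0.0], [0.0]]     # d_out x rank, zero at initialization
B_trained = [[0.5], [0.0], [-0.25]]  # pretend values after some training steps
x = [1.0, 2.0, 3.0]

print(matvec(W, x))                                   # pretrained output
print(lora_forward(W, A, B_init, x, alpha=8, rank=1))  # identical at initialization
print(lora_forward(W, A, B_trained, x, alpha=8, rank=1))  # shifted by the low-rank update
```

Starting from `B = 0` means finetuning begins exactly at the pretrained model, and the update is learned from there.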

Finally, let's define the training config and start training.
Note that the model-definition part of the configuration is ignored when you provide a pretrained model.

In [4]:
import espnetez as ez

EXP_DIR = "exp/finetune"
STATS_DIR = "exp/stats_all"

# Load the finetuning settings and override the pretrained training config
finetune_config = ez.utils.load_yaml(
    "finetune_with_lora.yaml",
)
for key, value in finetune_config.items():
    setattr(training_config, key, value)

trainer = ez.Trainer(
    task='asr',
    train_config=training_config,
    train_dump_dir=f"{DUMP_DIR}/train",
    valid_dump_dir=f"{DUMP_DIR}/dev",
    model=model,  # provide the pretrained model (with LoRA adapters)
    data_info=data_info,
    output_dir=EXP_DIR,
    stats_dir=STATS_DIR,
)
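
The loop that applies `finetune_with_lora.yaml` overwrites attributes of the pretrained model's training configuration, an `argparse.Namespace`; keys that are not in the YAML file keep their pretrained values. The self-contained sketch below reproduces that merge with a plain dict standing in for the YAML file. The keys and values are made up for illustration (the actual contents of `finetune_with_lora.yaml` are not shown here), and the variables are renamed so that running it does not overwrite `training_config`.

```python
import argparse

# Stand-in for the pretrained model's training config (hypothetical values).
example_config = argparse.Namespace(max_epoch=70, batch_bins=16000000, accum_grad=1)

# Stand-in for the dict returned by loading the YAML file (hypothetical values).
example_overrides = {"max_epoch": 10, "accum_grad": 4}

# Same merge as in the notebook: only keys present in the overrides change.
for key, value in example_overrides.items():
    setattr(example_config, key, value)

print(vars(example_config))
```

Here `batch_bins` keeps its pretrained value because the overrides do not mention it.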



In [None]:
trainer.collect_stats()
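
`collect_stats()` passes over the training and validation data once before training. In ESPnet's stats stage this typically writes shape files used for batching and accumulates per-dimension feature statistics (a frame count, a sum, and a sum of squares), from which a global mean and standard deviation can be derived for normalization. The pure-Python sketch below shows only that accumulation on made-up feature frames; `accumulate_stats` and `mean_std` are hypothetical helpers, not ESPnet's implementation.

```python
import math


def accumulate_stats(utterances):
    """Accumulate the frame count, per-dimension sum and sum of squares."""
    count, total, total_sq = 0, None, None
    for frames in utterances:          # each utterance: a list of feature frames
        for frame in frames:           # each frame: a list of floats
            if total is None:
                total = [0.0] * len(frame)
                total_sq = [0.0] * len(frame)
            for i, v in enumerate(frame):
                total[i] += v
                total_sq[i] += v * v
            count += 1
    return count, total, total_sq


def mean_std(count, total, total_sq):
    """Derive per-dimension mean and standard deviation from the accumulators."""
    mean = [s / count for s in total]
    var = [sq / count - m * m for sq, m in zip(total_sq, mean)]
    # Clamp tiny negative variances caused by floating-point rounding.
    return mean, [math.sqrt(max(v, 0.0)) for v in var]


# Two toy "utterances" with 2-dimensional features (made-up numbers).
utts = [
    [[1.0, 2.0], [3.0, 2.0]],
    [[5.0, 2.0]],
]
count, total, total_sq = accumulate_stats(utts)
mean, std = mean_std(count, total, total_sq)
print(count, mean, std)
```

Keeping only running sums lets the statistics be computed in a single streaming pass without holding all features in memory.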

In [None]:
# Start finetuning; checkpoints and logs are written to EXP_DIR
trainer.train()