## Finetune Model with ESPnet-Easy

In this notebook, we will explore the process of finetuning a pretrained model using the LJSpeech dataset. We'll start by downloading a pretrained model from the Hugging Face model hub.

We assume that the dump files have already been created. For guidance on creating them, refer to the `tacotron2.ipynb` notebook.

As in the `tacotron2.ipynb` notebook, we provide a dictionary that specifies the file name and data type of each input.

In [None]:
DUMP_DIR = "./dump/ljspeech"
data_info = {
    "speech": ["wav.scp", "sound"],
    "text": ["text", "text"],
}
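
Before building the trainer, it can help to confirm that every file named in `data_info` is present in a dump directory. The helper below is a minimal sketch, not part of ESPnet-Easy: `find_missing_dump_files` is a hypothetical name, and it assumes each file name in `data_info` is resolved relative to the split directory (for example `./dump/ljspeech/train`).

```python
from pathlib import Path


def find_missing_dump_files(dump_dir, data_info):
    """Return the files listed in data_info that are absent from dump_dir.

    Hypothetical helper: assumes each entry is [file_name, file_type] and
    that file_name is looked up directly under dump_dir.
    """
    missing = []
    for _key, (file_name, _file_type) in data_info.items():
        path = Path(dump_dir) / file_name
        if not path.is_file():
            missing.append(str(path))
    return missing


# Same structure as the data_info dictionary used in this notebook.
data_info = {
    "speech": ["wav.scp", "sound"],
    "text": ["text", "text"],
}

# An empty list means every expected file was found in the split directory.
print(find_missing_dump_files("./dump/ljspeech/train", data_info))
```

Running the same check on the validation directory before training catches a misplaced or missing file early.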

### Load a pretrained model

In ESPnet-Easy, you can define a custom model by passing a `build_model_fn` function to the trainer. Here, the function loads a pretrained model and returns it in training mode.

In [None]:
from espnet2.bin.tts_inference import Text2Speech

def build_model_fn(args):
    pretrained_model = Text2Speech.from_pretrained('espnet/kan-bayashi_ljspeech_tacotron2')
    model = pretrained_model.model
    model.train()
    return model

When finetuning a pretrained model, we start from the configuration the model was trained with and change only the settings that need to differ. The `update_finetune_config` function applies the settings in `finetune.yaml` to the pretrained configuration.

In [None]:
import espnetez as ez


pretrained_model = Text2Speech.from_pretrained('espnet/kan-bayashi_ljspeech_tacotron2')
pretrain_config = vars(pretrained_model.train_args)
if pretrain_config['pretrain_path'] is not None:
    pretrain_config['pretrain_path'] = None

del pretrained_model

finetune_config = ez.config.update_finetune_config(
    'tts',
    pretrain_config,
    'finetune.yaml'
)
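
The idea of overriding selected settings can be illustrated with plain dictionaries. This is only a conceptual sketch, not the implementation of `update_finetune_config`, whose exact merge rules (for example, for nested sections) are not shown in this notebook. The function `overlay_config` and all keys and values below are hypothetical.

```python
def overlay_config(base, overrides):
    """Return a copy of base with top-level keys replaced by overrides.

    Illustrative only: a shallow overlay that leaves `base` unmodified.
    """
    merged = dict(base)
    merged.update(overrides)
    return merged


# Made-up stand-ins for a pretrained configuration and finetuning overrides.
pretrain_like = {"max_epoch": 200, "batch_size": 32, "pretrain_path": None}
finetune_overrides = {"max_epoch": 10}

print(overlay_config(pretrain_like, finetune_overrides))
```

Settings that are not overridden keep their pretrained values, which is why `finetune.yaml` only needs to list the entries that change.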

### Training

Before training, we create the trainer and collect statistics over the training and validation data.

In [None]:
EXP_DIR = "exp/finetune"
STATS_DIR = "exp/stats_finetune"

trainer = ez.Trainer(
    task='tts',
    train_config=finetune_config,
    train_dump_dir=f"{DUMP_DIR}/train",
    valid_dump_dir=f"{DUMP_DIR}/test",
    build_model_fn=build_model_fn, # provide the pre-trained model
    data_info=data_info,
    output_dir=EXP_DIR,
    stats_dir=STATS_DIR,
    ngpu=1
)
trainer.collect_stats()

We are now ready to start finetuning.

In [None]:
trainer.train()