In [1]:
%cd StarGANv2-EmotionalVC

/notebooks/StarGANv2-EmotionalVC


In [2]:
%pip install SoundFile munch parallel_wavegan pydub pyyaml click librosa

Note: you may need to restart the kernel to use updated packages.


In [None]:
import os
import os.path as osp
import re
import sys
import yaml
import shutil
import numpy as np
import torch
import click
import warnings
warnings.simplefilter('ignore')

from functools import reduce
from munch import Munch

from meldataset import build_dataloader
from optimizers import build_optimizer
from models import build_model
from trainer import Trainer
from torch.utils.tensorboard import SummaryWriter

from Utils.ASR.models import ASRCNN
from Utils.JDC.model import JDCNet

import logging
from logging import StreamHandler

: 

In [7]:
# Module-level logger shared by this cell and by main().
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)

# logging.getLogger returns the same object every time this cell is re-run, so
# blindly calling addHandler would stack console handlers and print each message
# once per re-run. Reuse an existing console handler instead. FileHandler is a
# subclass of StreamHandler, hence the exact type comparison rather than isinstance.
handler = next((h for h in logger.handlers if type(h) is StreamHandler), None)
if handler is None:
    handler = StreamHandler()
    logger.addHandler(handler)
handler.setLevel(logging.DEBUG)

# Enable cuDNN's benchmark mode (auto-tuning of convolution algorithms).
torch.backends.cudnn.benchmark = True

def _load_pretrained_aux_models(config):
    """Instantiate the ASR and F0 networks and load their pretrained weights.

    Reads the ``ASR_config``, ``ASR_path`` and ``F0_path`` entries of ``config``.

    Args:
        config (dict): parsed training configuration.

    Returns:
        tuple: ``(F0_model, ASR_model)``.
    """
    # Pretrained ASR model. It is only switched to eval mode here; nothing in
    # this function disables its gradients.
    ASR_config = config.get('ASR_config', False)
    ASR_path = config.get('ASR_path', False)
    with open(ASR_config) as f:
        ASR_config = yaml.safe_load(f)
    ASR_model = ASRCNN(**ASR_config['model_params'])
    # NOTE(review): torch.load unpickles the checkpoint; only point the config
    # at checkpoint files you trust.
    params = torch.load(ASR_path, map_location='cpu')['model']
    ASR_model.load_state_dict(params)
    ASR_model.eval()

    # Pretrained F0 model.
    F0_path = config.get('F0_path', False)
    F0_model = JDCNet(num_class=1, seq_len=192)
    params = torch.load(F0_path, map_location='cpu')['net']
    F0_model.load_state_dict(params)

    return F0_model, ASR_model


def _run_training(config, log_dir, writer):
    """Build the data loaders, models, optimizer and trainer, then run the epoch loop.

    Args:
        config (dict): parsed training configuration.
        log_dir (str): directory that receives periodic checkpoints.
        writer: TensorBoard ``SummaryWriter`` receiving scalars and figures.
    """
    ### Get configuration
    batch_size = config.get('batch_size', 2)
    device = config.get('device', 'cpu')
    epochs = config.get('epochs', 1000)
    save_freq = config.get('save_freq', 20)
    dataset_configuration = config.get('dataset_configuration', None)
    fp16_run = config.get('fp16_run', False)
    ###

    train_set_path = "./Data/training_list.txt"
    validation_set_path = "./Data/validation_list.txt"
    # load dataloader
    train_dataloader = build_dataloader(train_set_path, dataset_configuration,
                                        batch_size=batch_size,
                                        num_workers=2,
                                        device=device)

    val_dataloader = build_dataloader(validation_set_path, dataset_configuration,
                                      batch_size=batch_size,
                                      num_workers=2,
                                      device=device)

    F0_model, ASR_model = _load_pretrained_aux_models(config)

    # build model
    model, model_ema = build_model(Munch(config['model_params']), F0_model, ASR_model)

    scheduler_params = {
        "max_lr": float(config['optimizer_params'].get('lr', 2e-4)),
        "pct_start": float(config['optimizer_params'].get('pct_start', 0.0)),
        "epochs": epochs,
        "steps_per_epoch": len(train_dataloader),
    }

    # Move every sub-network (and its EMA counterpart) to the target device.
    for key in model:
        model[key].to(device)
    for key in model_ema:
        model_ema[key].to(device)

    # Each sub-network gets its own copy of the scheduler settings.
    scheduler_params_dict = {key: scheduler_params.copy() for key in model}
    optimizer = build_optimizer({key: model[key].parameters() for key in model},
                                scheduler_params_dict=scheduler_params_dict)

    trainer = Trainer(args=Munch(config['loss_params']), model=model,
                      model_ema=model_ema,
                      optimizer=optimizer,
                      device=device,
                      train_dataloader=train_dataloader,
                      val_dataloader=val_dataloader,
                      logger=logger,
                      fp16_run=fp16_run)

    if config.get('pretrained_model', '') != '':
        trainer.load_checkpoint(config['pretrained_model'],
                                load_only_params=config.get('load_only_params', True))

    for _ in range(1, epochs+1):
        # The epoch number is read from the trainer *before* the epoch runs.
        # NOTE(review): whether that is the index of the epoch about to be
        # trained depends on Trainer, which is not visible here.
        epoch = trainer.epochs
        train_results = trainer._train_epoch()
        eval_results = trainer._eval_epoch()
        results = train_results.copy()
        results.update(eval_results)
        logger.info('--- epoch %d ---' % epoch)
        for key, value in results.items():
            if isinstance(value, float):
                logger.info('%-15s: %.4f' % (key, value))
                writer.add_scalar(key, value, epoch)
            else:
                # Non-float entries are iterated and each item is logged as a figure.
                for v in value:
                    writer.add_figure('eval_spec', v, epoch)
        if (epoch % save_freq) == 0:
            trainer.save_checkpoint(osp.join(log_dir, 'epoch_%05d.pth' % epoch))


def main(config_path):
    """Train the model described by a YAML config file.

    Loads the config, creates ``config['log_dir']``, copies the config file into
    it, attaches a ``train.log`` file handler to the module logger, opens a
    TensorBoard writer under ``<log_dir>/tensorboard`` and runs training.

    The file handler and the TensorBoard writer are released when training ends
    for any reason (including ``KeyboardInterrupt``), so repeated calls from a
    notebook neither leak file handles nor duplicate lines in ``train.log``.

    Args:
        config_path (str): path to the YAML configuration file.

    Returns:
        int: ``0`` once the epoch loop has completed.
    """
    # Use a context manager so the config file handle is closed right away.
    with open(config_path) as config_file:
        config = yaml.safe_load(config_file)
    print(config)
    log_dir = config['log_dir']
    os.makedirs(log_dir, exist_ok=True)
    shutil.copy(config_path, osp.join(log_dir, osp.basename(config_path)))
    writer = SummaryWriter(log_dir + "/tensorboard")

    file_handler = None
    try:
        # write logs
        file_handler = logging.FileHandler(osp.join(log_dir, 'train.log'))
        file_handler.setLevel(logging.DEBUG)
        file_handler.setFormatter(logging.Formatter('%(levelname)s:%(asctime)s: %(message)s'))
        logger.addHandler(file_handler)

        _run_training(config, log_dir, writer)
    finally:
        # Detach and close per-run resources even when training is interrupted.
        if file_handler is not None:
            logger.removeHandler(file_handler)
            file_handler.close()
        writer.close()
    return 0

# Start training with the config at Configs/config.yml (resolved relative to the
# current working directory); main returns 0 once its epoch loop has finished.
main("Configs/config.yml")

{'log_dir': 'Models/Experiment-3', 'save_freq': 10, 'device': 'cuda', 'epochs': 150, 'batch_size': 5, 'pretrained_model': '', 'load_only_params': False, 'fp16_run': True, 'dataset_configuration': {'data_separetor': '|', 'data_header': ['actor_id', 'statement_id', 'source_path', 'source_emotion', 'reference_path', 'reference_emotion']}, 'F0_path': 'Utils/JDC/bst.t7', 'ASR_config': 'Utils/ASR/config.yml', 'ASR_path': 'Utils/ASR/epoch_00100.pth', 'preprocess_params': {'sr': 24000, 'spect_params': {'n_fft': 2048, 'win_length': 1200, 'hop_length': 300}}, 'model_params': {'dim_in': 64, 'style_dim': 64, 'latent_dim': 16, 'num_domains': 4, 'max_conv_dim': 512, 'n_repeat': 4, 'w_hpf': 0, 'F0_channel': 256}, 'loss_params': {'g_loss': {'lambda_sty': 1.0, 'lambda_cyc': 5.0, 'lambda_ds': 1.0, 'lambda_norm': 1.0, 'lambda_asr': 10.0, 'lambda_f0': 5.0, 'lambda_f0_sty': 0.1, 'lambda_adv': 2.0, 'lambda_adv_cls': 0.5, 'norm_bias': 0.5}, 'd_loss': {'lambda_reg': 1.0, 'lambda_adv_cls': 0.1, 'lambda_con_reg

[train]:   0%|          | 0/3632 [00:16<?, ?it/s]


KeyboardInterrupt: 