Merge pull request #3243 from idiap/checkpoints

Remove duplicate/unused code

erogol committed Nov 22, 2023
2 parents 29dede2 + 0fb0d67 commit b47d9c6
Showing 8 changed files with 25 additions and 241 deletions.
23 changes: 18 additions & 5 deletions TTS/bin/train_encoder.py
@@ -8,17 +8,17 @@
 
 import torch
 from torch.utils.data import DataLoader
+from trainer.io import copy_model_files, save_best_model, save_checkpoint
 from trainer.torch import NoamLR
 from trainer.trainer_utils import get_optimizer
 
 from TTS.encoder.dataset import EncoderDataset
-from TTS.encoder.utils.generic_utils import save_best_model, save_checkpoint, setup_encoder_model
+from TTS.encoder.utils.generic_utils import setup_encoder_model
 from TTS.encoder.utils.training import init_training
 from TTS.encoder.utils.visual import plot_embeddings
 from TTS.tts.datasets import load_tts_samples
 from TTS.utils.audio import AudioProcessor
 from TTS.utils.generic_utils import count_parameters, remove_experiment_folder
-from TTS.utils.io import copy_model_files
 from TTS.utils.samplers import PerfectBatchSampler
 from TTS.utils.training import check_update

@@ -222,7 +222,9 @@ def train(model, optimizer, scheduler, criterion, data_loader, eval_data_loader,

         if global_step % c.save_step == 0:
             # save model
-            save_checkpoint(model, optimizer, criterion, loss.item(), OUT_PATH, global_step, epoch)
+            save_checkpoint(
+                c, model, optimizer, None, global_step, epoch, OUT_PATH, criterion=criterion.state_dict()
+            )
 
         end_time = time.time()

@@ -245,7 +247,18 @@ def train(model, optimizer, scheduler, criterion, data_loader, eval_data_loader,
                 flush=True,
             )
             # save the best checkpoint
-            best_loss = save_best_model(model, optimizer, criterion, eval_loss, best_loss, OUT_PATH, global_step, epoch)
+            best_loss = save_best_model(
+                eval_loss,
+                best_loss,
+                c,
+                model,
+                optimizer,
+                None,
+                global_step,
+                epoch,
+                OUT_PATH,
+                criterion=criterion.state_dict(),
+            )
             model.train()
 
     return best_loss, global_step
@@ -276,7 +289,7 @@ def main(args):  # pylint: disable=redefined-outer-name

     if c.loss == "softmaxproto" and c.model != "speaker_encoder":
         c.map_classid_to_classname = map_classid_to_classname
-    copy_model_files(c, OUT_PATH)
+    copy_model_files(c, OUT_PATH, new_fields={})
 
     if args.restore_path:
         criterion, args.restore_step = model.load_checkpoint(
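
Note: the calls above now target trainer.io rather than the helpers deleted from TTS/utils/io.py further down in this diff. A minimal sketch of the migrated pattern, assuming trainer.io keeps the (config, model, optimizer, scaler, current_step, epoch, output_folder, **kwargs) interface shown in that deleted code; the model, optimizer, config, and path below are illustrative stand-ins, not values from this commit:

import os

import torch
from trainer.io import save_checkpoint

out_path = "/tmp/runs"  # illustrative output folder
os.makedirs(out_path, exist_ok=True)
model = torch.nn.Linear(4, 2)  # stand-in for the speaker encoder
optimizer = torch.optim.Adam(model.parameters())
config = {"model": "speaker_encoder"}  # hypothetical minimal config

# Mirrors the migrated call above: scaler is None (no AMP scaling in
# train_encoder.py), and the helper writes checkpoint_1000.pth into out_path.
save_checkpoint(config, model, optimizer, None, 1000, 1, out_path)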
46 changes: 0 additions & 46 deletions TTS/encoder/utils/generic_utils.py
@@ -1,15 +1,12 @@
-import datetime
 import glob
 import os
 import random
-import re
 
 import numpy as np
 from scipy import signal
 
 from TTS.encoder.models.lstm import LSTMSpeakerEncoder
 from TTS.encoder.models.resnet import ResNetSpeakerEncoder
-from TTS.utils.io import save_fsspec
 
 
 class AugmentWAV(object):
@@ -118,11 +115,6 @@ def apply_one(self, audio):
         return self.additive_noise(noise_type, audio)
 
 
-def to_camel(text):
-    text = text.capitalize()
-    return re.sub(r"(?!^)_([a-zA-Z])", lambda m: m.group(1).upper(), text)
-
-
 def setup_encoder_model(config: "Coqpit"):
     if config.model_params["model_name"].lower() == "lstm":
         model = LSTMSpeakerEncoder(
@@ -142,41 +134,3 @@ def setup_encoder_model(config: "Coqpit"):
             audio_config=config.audio,
         )
     return model


-def save_checkpoint(model, optimizer, criterion, model_loss, out_path, current_step, epoch):
-    checkpoint_path = "checkpoint_{}.pth".format(current_step)
-    checkpoint_path = os.path.join(out_path, checkpoint_path)
-    print(" | | > Checkpoint saving : {}".format(checkpoint_path))
-
-    new_state_dict = model.state_dict()
-    state = {
-        "model": new_state_dict,
-        "optimizer": optimizer.state_dict() if optimizer is not None else None,
-        "criterion": criterion.state_dict(),
-        "step": current_step,
-        "epoch": epoch,
-        "loss": model_loss,
-        "date": datetime.date.today().strftime("%B %d, %Y"),
-    }
-    save_fsspec(state, checkpoint_path)
-
-
-def save_best_model(model, optimizer, criterion, model_loss, best_loss, out_path, current_step, epoch):
-    if model_loss < best_loss:
-        new_state_dict = model.state_dict()
-        state = {
-            "model": new_state_dict,
-            "optimizer": optimizer.state_dict(),
-            "criterion": criterion.state_dict(),
-            "step": current_step,
-            "epoch": epoch,
-            "loss": model_loss,
-            "date": datetime.date.today().strftime("%B %d, %Y"),
-        }
-        best_loss = model_loss
-        bestmodel_path = "best_model.pth"
-        bestmodel_path = os.path.join(out_path, bestmodel_path)
-        print("\n > BEST MODEL ({0:.5f}) : {1:}".format(model_loss, bestmodel_path))
-        save_fsspec(state, bestmodel_path)
-    return best_loss
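
Note: the deleted helpers above wrote a "criterion" entry into the checkpoint dict. The replacement calls in train_encoder.py pass criterion=criterion.state_dict() as a keyword argument instead; under the save_model shown later in this diff, extra kwargs are merged into the saved state via state.update(kwargs), so the key survives the migration, assuming the trainer.io counterpart behaves the same way. A hedged sketch of reading it back; the path is illustrative:

import torch

state = torch.load("checkpoint_1000.pth", map_location="cpu")
# "criterion" is present because it was passed as a kwarg at save time
criterion_state = state["criterion"]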
38 changes: 0 additions & 38 deletions TTS/encoder/utils/io.py

This file was deleted.

2 changes: 1 addition & 1 deletion TTS/encoder/utils/training.py
@@ -3,13 +3,13 @@
 
 from coqpit import Coqpit
 from trainer import TrainerArgs, get_last_checkpoint
+from trainer.io import copy_model_files
 from trainer.logging import logger_factory
 from trainer.logging.console_logger import ConsoleLogger
 
 from TTS.config import load_config, register_config
 from TTS.tts.utils.text.characters import parse_symbols
 from TTS.utils.generic_utils import get_experiment_folder_path, get_git_branch
-from TTS.utils.io import copy_model_files
 
 
 @dataclass
146 changes: 0 additions & 146 deletions TTS/utils/io.py
@@ -1,13 +1,9 @@
-import datetime
-import json
 import os
 import pickle as pickle_tts
-import shutil
 from typing import Any, Callable, Dict, Union
 
 import fsspec
 import torch
-from coqpit import Coqpit
 
 from TTS.utils.generic_utils import get_user_data_dir

@@ -28,34 +24,6 @@ def __init__(self, *args, **kwargs):
         self.__dict__ = self
 
 
-def copy_model_files(config: Coqpit, out_path, new_fields=None):
-    """Copy config.json and other model files to the training folder and add
-    new fields.
-
-    Args:
-        config (Coqpit): Coqpit config defining the training run.
-        out_path (str): output path to copy the files to.
-        new_fields (dict): new fields to be added to or edited
-            in the config file.
-    """
-    copy_config_path = os.path.join(out_path, "config.json")
-    # add extra information fields
-    if new_fields:
-        config.update(new_fields, allow_new=True)
-    # TODO: Revert to config.save_json() once Coqpit supports arbitrary paths.
-    with fsspec.open(copy_config_path, "w", encoding="utf8") as f:
-        json.dump(config.to_dict(), f, indent=4)
-
-    # copy model stats file if available
-    if config.audio.stats_path is not None:
-        copy_stats_path = os.path.join(out_path, "scale_stats.npy")
-        filesystem = fsspec.get_mapper(copy_stats_path).fs
-        if not filesystem.exists(copy_stats_path):
-            with fsspec.open(config.audio.stats_path, "rb") as source_file:
-                with fsspec.open(copy_stats_path, "wb") as target_file:
-                    shutil.copyfileobj(source_file, target_file)
-
-
 def load_fsspec(
     path: str,
     map_location: Union[str, Callable, torch.device, Dict[Union[str, torch.device], Union[str, torch.device]]] = None,
@@ -100,117 +68,3 @@ def load_checkpoint(
     if eval:
         model.eval()
     return model, state
-
-
-def save_fsspec(state: Any, path: str, **kwargs):
-    """Like torch.save but can save to other locations (e.g. s3://, gs://).
-
-    Args:
-        state: State object to save
-        path: Any path or url supported by fsspec.
-        **kwargs: Keyword arguments forwarded to torch.save.
-    """
-    with fsspec.open(path, "wb") as f:
-        torch.save(state, f, **kwargs)
-
-
-def save_model(config, model, optimizer, scaler, current_step, epoch, output_path, **kwargs):
-    if hasattr(model, "module"):
-        model_state = model.module.state_dict()
-    else:
-        model_state = model.state_dict()
-    if isinstance(optimizer, list):
-        optimizer_state = [optim.state_dict() for optim in optimizer]
-    elif optimizer.__class__.__name__ == "CapacitronOptimizer":
-        optimizer_state = [optimizer.primary_optimizer.state_dict(), optimizer.secondary_optimizer.state_dict()]
-    else:
-        optimizer_state = optimizer.state_dict() if optimizer is not None else None
-
-    if isinstance(scaler, list):
-        scaler_state = [s.state_dict() for s in scaler]
-    else:
-        scaler_state = scaler.state_dict() if scaler is not None else None
-
-    if isinstance(config, Coqpit):
-        config = config.to_dict()
-
-    state = {
-        "config": config,
-        "model": model_state,
-        "optimizer": optimizer_state,
-        "scaler": scaler_state,
-        "step": current_step,
-        "epoch": epoch,
-        "date": datetime.date.today().strftime("%B %d, %Y"),
-    }
-    state.update(kwargs)
-    save_fsspec(state, output_path)
-
-
-def save_checkpoint(
-    config,
-    model,
-    optimizer,
-    scaler,
-    current_step,
-    epoch,
-    output_folder,
-    **kwargs,
-):
-    file_name = "checkpoint_{}.pth".format(current_step)
-    checkpoint_path = os.path.join(output_folder, file_name)
-    print("\n > CHECKPOINT : {}".format(checkpoint_path))
-    save_model(
-        config,
-        model,
-        optimizer,
-        scaler,
-        current_step,
-        epoch,
-        checkpoint_path,
-        **kwargs,
-    )
-
-
-def save_best_model(
-    current_loss,
-    best_loss,
-    config,
-    model,
-    optimizer,
-    scaler,
-    current_step,
-    epoch,
-    out_path,
-    keep_all_best=False,
-    keep_after=10000,
-    **kwargs,
-):
-    if current_loss < best_loss:
-        best_model_name = f"best_model_{current_step}.pth"
-        checkpoint_path = os.path.join(out_path, best_model_name)
-        print(" > BEST MODEL : {}".format(checkpoint_path))
-        save_model(
-            config,
-            model,
-            optimizer,
-            scaler,
-            current_step,
-            epoch,
-            checkpoint_path,
-            model_loss=current_loss,
-            **kwargs,
-        )
-        fs = fsspec.get_mapper(out_path).fs
-        # only delete previous if current is saved successfully
-        if not keep_all_best or (current_step < keep_after):
-            model_names = fs.glob(os.path.join(out_path, "best_model*.pth"))
-            for model_name in model_names:
-                if os.path.basename(model_name) != best_model_name:
-                    fs.rm(model_name)
-        # create a shortcut which always points to the currently best model
-        shortcut_name = "best_model.pth"
-        shortcut_path = os.path.join(out_path, shortcut_name)
-        fs.copy(checkpoint_path, shortcut_path)
-        best_loss = current_loss
-    return best_loss
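
Note: the rotation logic in the deleted save_best_model is the part worth remembering: it writes best_model_{step}.pth, prunes older best checkpoints unless keep_all_best is set and the step is past keep_after, and refreshes a stable best_model.pth alias. A compact local-filesystem restatement of that rule (the original used fsspec, so s3:// and gs:// paths also worked); names follow the deleted code, and this is a sketch, not the trainer.io implementation:

import glob
import os
import shutil

def rotate_best(out_path, checkpoint_path, best_model_name, current_step,
                keep_all_best=False, keep_after=10000):
    # prune previous best checkpoints unless we are keeping all of them
    if not keep_all_best or current_step < keep_after:
        for name in glob.glob(os.path.join(out_path, "best_model*.pth")):
            if os.path.basename(name) != best_model_name:
                os.remove(name)
    # refresh the alias that always points at the current best model
    shutil.copy(checkpoint_path, os.path.join(out_path, "best_model.pth"))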
4 changes: 2 additions & 2 deletions tests/aux_tests/test_embedding_manager.py
@@ -3,11 +3,11 @@
 
 import numpy as np
 import torch
+from trainer.io import save_checkpoint
 
 from tests import get_tests_input_path
 from TTS.config import load_config
 from TTS.encoder.utils.generic_utils import setup_encoder_model
-from TTS.encoder.utils.io import save_checkpoint
 from TTS.tts.utils.managers import EmbeddingManager
 from TTS.utils.audio import AudioProcessor
 
@@ -31,7 +31,7 @@ def test_speaker_embedding():

     # create a dummy speaker encoder
     model = setup_encoder_model(config)
-    save_checkpoint(model, None, None, get_tests_input_path(), 0)
+    save_checkpoint(config, model, None, None, 0, 0, get_tests_input_path())
 
     # load audio processor and speaker encoder
     manager = EmbeddingManager(encoder_model_path=encoder_model_path, encoder_config_path=encoder_config_path)
4 changes: 2 additions & 2 deletions tests/aux_tests/test_speaker_manager.py
@@ -3,11 +3,11 @@
 
 import numpy as np
 import torch
+from trainer.io import save_checkpoint
 
 from tests import get_tests_input_path
 from TTS.config import load_config
 from TTS.encoder.utils.generic_utils import setup_encoder_model
-from TTS.encoder.utils.io import save_checkpoint
 from TTS.tts.utils.speakers import SpeakerManager
 from TTS.utils.audio import AudioProcessor
 
@@ -30,7 +30,7 @@ def test_speaker_embedding():

     # create a dummy speaker encoder
     model = setup_encoder_model(config)
-    save_checkpoint(model, None, None, get_tests_input_path(), 0)
+    save_checkpoint(config, model, None, None, 0, 0, get_tests_input_path())
 
     # load audio processor and speaker encoder
     ap = AudioProcessor(**config.audio)
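
Note: in both tests, the new positional arguments line up with the (config, model, optimizer, scaler, current_step, epoch, output_folder) signature of the save_checkpoint deleted from TTS/utils/io.py above. The same call, annotated as a reading aid under the assumption that trainer.io keeps that argument order:

save_checkpoint(
    config,                  # training config, stored under the "config" key
    model,                   # dummy speaker encoder
    None,                    # optimizer: not needed for the test
    None,                    # scaler: no AMP in the test
    0,                       # current_step -> writes checkpoint_0.pth
    0,                       # epoch
    get_tests_input_path(),  # output folder
)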