Skip to content

Commit

Permalink
handle distributed model as saving
Browse files Browse the repository at this point in the history
  • Loading branch information
erogol committed Oct 29, 2020
1 parent 9d0ae2b commit e723b99
Show file tree
Hide file tree
Showing 4 changed files with 11 additions and 4 deletions.
2 changes: 1 addition & 1 deletion TTS/bin/train_glow_tts.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ def train(model, criterion, optimizer, scheduler,
loss_dict = criterion(z, y_mean, y_log_scale, logdet, mel_lengths,
o_dur_log, o_total_dur, text_lengths)

# backward pass
# backward pass - DISTRIBUTED
if amp is not None:
with amp.scale_loss(loss_dict['loss'], optimizer) as scaled_loss:
scaled_loss.backward()
Expand Down
8 changes: 6 additions & 2 deletions TTS/tts/utils/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from TTS.utils.io import RenamingUnpickler



def load_checkpoint(model, checkpoint_path, amp=None, use_cuda=False):
try:
state = torch.load(checkpoint_path, map_location=torch.device('cpu'))
Expand All @@ -25,9 +26,12 @@ def load_checkpoint(model, checkpoint_path, amp=None, use_cuda=False):


def save_model(model, optimizer, current_step, epoch, r, output_path, amp_state_dict=None, **kwargs):
new_state_dict = model.state_dict()
if hasattr(model, 'module'):
model_state = model.module.state_dict()
else:
model_state = model.state_dict()
state = {
'model': new_state_dict,
'model': model_state,
'optimizer': optimizer.state_dict() if optimizer is not None else None,
'step': current_step,
'epoch': epoch,
Expand Down
File renamed without changes.
5 changes: 4 additions & 1 deletion TTS/vocoder/utils/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,10 @@ def load_checkpoint(model, checkpoint_path, use_cuda=False):

def save_model(model, optimizer, scheduler, model_disc, optimizer_disc,
scheduler_disc, current_step, epoch, output_path, **kwargs):
model_state = model.state_dict()
if hasattr(model, 'module'):
model_state = model.module.state_dict()
else:
model_state = model.state_dict()
model_disc_state = model_disc.state_dict()\
if model_disc is not None else None
optimizer_state = optimizer.state_dict()\
Expand Down

0 comments on commit e723b99

Please sign in to comment.