Skip to content

Commit

Permalink
Merge pull request #19 from eginhard/fix-vits-comments
Browse files Browse the repository at this point in the history
docs(tts.models.vits): clarify use of discriminator/generator
  • Loading branch information
eginhard committed Mar 12, 2024
2 parents 0c6c20f + 89a061f commit eaa7283
Showing 1 changed file with 6 additions and 3 deletions.
9 changes: 6 additions & 3 deletions TTS/tts/models/vits.py
Original file line number Diff line number Diff line change
Expand Up @@ -1233,7 +1233,7 @@ def train_step(self, batch: dict, criterion: nn.Module, optimizer_idx: int) -> T
Args:
batch (Dict): Input tensors.
criterion (nn.Module): Loss layer designed for the model.
optimizer_idx (int): Index of optimizer to use. 0 for the generator and 1 for the discriminator networks.
optimizer_idx (int): Index of optimizer to use. 0 for the discriminator and 1 for the generator networks.
Returns:
Tuple[Dict, Dict]: Model ouputs and computed losses.
Expand Down Expand Up @@ -1651,13 +1651,16 @@ def get_data_loader(

def get_optimizer(self) -> List:
"""Initiate and return the GAN optimizers based on the config parameters.
It returnes 2 optimizers in a list. First one is for the generator and the second one is for the discriminator.
It returns 2 optimizers in a list. First one is for the discriminator
and the second one is for the generator.
Returns:
List: optimizers.
"""
# select generator parameters
optimizer0 = get_optimizer(self.config.optimizer, self.config.optimizer_params, self.config.lr_disc, self.disc)

# select generator parameters
gen_parameters = chain(params for k, params in self.named_parameters() if not k.startswith("disc."))
optimizer1 = get_optimizer(
self.config.optimizer, self.config.optimizer_params, self.config.lr_gen, parameters=gen_parameters
Expand Down

0 comments on commit eaa7283

Please sign in to comment.