Add dynamic batch size for VALLE #19
Conversation
models/tts/valle/valle_dataset.py
Outdated
@@ -82,6 +93,10 @@ def __getitem__(self, index):
        return single_feature

    def get_num_frames(self, index):
        utt_info = self.metadata[index]
        return int(utt_info['Duration'] * 75)
What's the meaning of 75? Is it a fixed parameter or a variable parameter?
It is a fixed parameter: 75 means one second corresponds to 75 tokens for EnCodec.
It would be better to move this parameter to the config file because it relies on the codec.
done
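A minimal sketch of the agreed-upon change, with the codec token rate read from configuration instead of hard-coded. Class and field names here (`ValleDatasetSketch`, `codec_frame_rate`) are illustrative assumptions, not the actual Amphion code:

```python
class ValleDatasetSketch:
    """Illustrative stand-in for the VALL-E dataset (not the real class)."""

    def __init__(self, metadata, codec_frame_rate):
        # metadata: list of dicts with a 'Duration' field in seconds.
        self.metadata = metadata
        # Tokens per second produced by the codec; 75 for EnCodec at 24 kHz.
        # Kept configurable because the value depends on the codec used.
        self.codec_frame_rate = codec_frame_rate

    def get_num_frames(self, index):
        # Convert utterance duration (seconds) into a codec token count.
        utt_info = self.metadata[index]
        return int(utt_info["Duration"] * self.codec_frame_rate)
```

Swapping codecs then only requires changing the configured rate, not the dataset code.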
models/tts/valle/valle_trainer.py
Outdated
if not self.cfg.train.use_dynamic_batchsize:
    return super()._build_dataloader()
Dataset, Collator = self._build_dataset()
train_dataset = Dataset(self.cfg, self.cfg.dataset[0], is_valid=False)
Does it only work on dataset[0] instead of all elements of the dataset list?
This path is only taken for dynamic batch-size training of VALL-E; otherwise it falls back to super()._build_dataloader().
cfg.dataset specifies a list of datasets to be processed, but your code seems to only process the first one (self.cfg.dataset[0]). What about the other datasets?
done
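Independent of how the multi-dataset issue is resolved, the core of dynamic batching is grouping utterances so that each batch stays under a total-frame budget rather than a fixed utterance count. A hedged sketch of that idea (function name and budget parameter are assumptions, not the PR's actual implementation):

```python
def build_dynamic_batches(num_frames, max_frames_per_batch):
    """Group utterance indices so each batch stays under a frame budget.

    num_frames: per-utterance token counts (e.g. from get_num_frames).
    Sorting by length first keeps utterances of similar length together,
    which reduces padding waste within each batch.
    """
    order = sorted(range(len(num_frames)), key=lambda i: num_frames[i])
    batches, current, current_frames = [], [], 0
    for i in order:
        # Start a new batch once adding this utterance would exceed the budget.
        if current and current_frames + num_frames[i] > max_frames_per_batch:
            batches.append(current)
            current, current_frames = [], 0
        current.append(i)
        current_frames += num_frames[i]
    if current:
        batches.append(current)
    return batches
```

In practice such a function would back a custom PyTorch batch sampler, which is why the trainer skips the default dataloader construction when dynamic batching is enabled.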
models/tts/base/tts_trainer.py
Outdated
if (self.model_type == "VALLE") and (not self.cfg.train.use_dynamic_batchsize):
    (
        self.train_dataloader,
        self.valid_dataloader,
    ) = self.accelerator.prepare(
        self.train_dataloader,
        self.valid_dataloader,
    )
This condition means that if the model_type is VITS or FastSpeech2, the dataloader will never be prepared. That is wrong!
Rewrote the accelerator.prepare call in the VALL-E trainer instead.
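The described fix amounts to moving the model-specific guard out of the shared base trainer and into a VALL-E override, so other models keep the default preparation. A sketch of that inheritance pattern under assumed names (`prepare_dataloaders` stands in for the real accelerator.prepare call):

```python
class BaseTTSTrainerSketch:
    """Illustrative base trainer: always prepares its dataloaders."""

    def __init__(self):
        self.prepared = False

    def prepare_dataloaders(self):
        # In the real code this would be self.accelerator.prepare(...).
        self.prepared = True


class ValleTrainerSketch(BaseTTSTrainerSketch):
    """VALL-E override: skip accelerator preparation when dynamic
    batching supplies its own sampler-driven dataloaders."""

    def __init__(self, use_dynamic_batchsize):
        super().__init__()
        self.use_dynamic_batchsize = use_dynamic_batchsize

    def prepare_dataloaders(self):
        if self.use_dynamic_batchsize:
            return  # the dynamic batch sampler handles batching itself
        super().prepare_dataloaders()
```

This keeps the base class model-agnostic: VITS or FastSpeech2 trainers inherit the default behavior unchanged.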