Skip to content

Commit

Permalink
fix amp
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Sep 4, 2023
1 parent deb97db commit bac049a
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 5 deletions.
6 changes: 2 additions & 4 deletions naturalspeech2_pytorch/naturalspeech2_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -1649,7 +1649,7 @@ def __init__(
num_samples = 1,
results_folder = './results',
amp = False,
fp16 = False,
mixed_precision_type = 'fp16',
use_ema = True,
split_batches = True,
dataloader = None,
Expand All @@ -1663,11 +1663,9 @@ def __init__(

self.accelerator = Accelerator(
split_batches = split_batches,
mixed_precision = 'fp16' if fp16 else 'no'
mixed_precision = mixed_precision_type if amp else 'no'
)

self.accelerator.native_amp = amp

# model

self.model = diffusion_model
Expand Down
2 changes: 1 addition & 1 deletion naturalspeech2_pytorch/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '0.1.4'
__version__ = '0.1.5'

0 comments on commit bac049a

Please sign in to comment.