Skip to content

Commit

Permalink
fix: cast to real numbers since pytorch 2.0+ supports complex
Browse files Browse the repository at this point in the history
  • Loading branch information
iyaja committed Feb 1, 2024
1 parent 784f87b commit b506a0f
Showing 1 changed file with 8 additions and 2 deletions.
10 changes: 8 additions & 2 deletions tests/core/test_grad.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,15 @@ def _test_audio_grad(attr: str, target=True, kwargs: dict = {}):
# If necessary, propagate spectrogram changes to waveform
if result.stft_data is not None:
result.istft()
result.audio_data.sum().backward()
if result.audio_data.dtype.is_complex:
result.real().sum().backward()
else:
result.audio_data.sum().backward()
else:
result.sum().backward()
if result.dtype.is_complex:
result.real.sum().backward()
else:
result.sum().backward()

assert signal.audio_data.grad is not None or not target
except RuntimeError:
Expand Down

0 comments on commit b506a0f

Please sign in to comment.