
reduction=none broken #69

Open
turian opened this issue Dec 20, 2023 · 7 comments

Comments

@turian
Contributor

turian commented Dec 20, 2023

import torch
import auraloss

mrstft = auraloss.freq.MultiResolutionSTFTLoss(reduction="none")

input = torch.rand(8, 1, 44100)
target = torch.rand(8, 1, 44100)

loss = mrstft(input, target)
print(loss)
print(loss.shape)

gives

Traceback (most recent call last):
  File "/private/tmp/testaura.py", line 9, in <module>
    loss = mrstft(input, target)
  File "/opt/miniconda3/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/miniconda3/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/miniconda3/lib/python3.9/site-packages/auraloss/freq.py", line 410, in forward
    mrstft_loss += f(x, y)
RuntimeError: The size of tensor a (368) must match the size of tensor b (184) at non-singleton dimension 2

What I'm really looking for is a per-instance reduction, which I can compute from reduction=none.
Anyway, reduction=none is not working out of the box, which is unfortunately a showstopper for me :(

@mcherep

mcherep commented Feb 19, 2024

I have the same issue

@csteinmetz1
Owner

The reason this currently doesn't work is that the desired behavior of reduction="none" with the multi-resolution loss is a bit ambiguous. Since each STFTLoss produces a different shape in both the frequency and time axes, the per-resolution losses cannot be combined into a single loss tensor, as in the normal behavior.

We could consider returning a list of tensors, one per STFTLoss output, but this would give a different return type than the normal behavior, which is a single tensor. Would this work for your applications @turian, @mcherep?

@mcherep

mcherep commented Feb 19, 2024

I think this would work for me. What I need is one aggregated loss per batch item instead of aggregating over all dimensions. Let me know if I'm missing something!

@csteinmetz1
Owner

Perhaps the best way to achieve the desired behavior right now is to manage a set of STFTLoss instances yourself. That way you can define how they are aggregated, in this case not at all. I am hesitant to add more complexity to the possible return types of the MultiResolutionSTFTLoss class if we can avoid it; this class is really just a convenience wrapper around STFTLoss. Here is a potential example.

import torch
import auraloss

fft_sizes = [512, 1024, 2048]
win_lengths = [512, 1024, 2048]
hop_sizes = [256, 512, 1024]
reduction = "none"

loss_fns = torch.nn.ModuleList()

for fft_size, win_length, hop_size in zip(fft_sizes, win_lengths, hop_sizes):
    loss_fns.append(
        auraloss.freq.STFTLoss(fft_size, hop_size, win_length, reduction=reduction)
    )

bs = 4
chs = 1
seq_len = 131072

x = torch.randn(bs, chs, seq_len)
y = torch.randn(bs, chs, seq_len)

for loss_fn in loss_fns:
    loss = loss_fn(x, y)
    print(loss.shape)

outputs

torch.Size([4, 257, 513])
torch.Size([4, 513, 257])
torch.Size([4, 1025, 129])
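If what you need from this is one loss value per batch item, each unreduced output can be averaged over its own frequency/time axes before summing across resolutions. A minimal sketch, using random tensors with the shapes printed above in place of the actual STFTLoss outputs:

```python
import torch

# stand-ins for the three unreduced STFTLoss outputs above;
# each resolution has its own (freq_bins, frames) shape
losses = [
    torch.rand(4, 257, 513),
    torch.rand(4, 513, 257),
    torch.rand(4, 1025, 129),
]

# average each resolution over its own last two axes, then sum
# across resolutions, leaving one loss value per batch item
per_item = sum(loss.mean(dim=(-2, -1)) for loss in losses)
print(per_item.shape)  # torch.Size([4])
```

The per-resolution means are comparable even though the raw tensors have incompatible shapes, which is why the sum works only after the reduction.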

@mcherep

mcherep commented Feb 21, 2024

Great, I will do this instead. Thanks so much!

@egaznep

egaznep commented Apr 17, 2024

I came up with another hacky way to do this, while keeping all the convenience of the MultiResolutionSTFTLoss class: monkey-patch the auraloss.freq.apply_reduction function in the following way:

auraloss.freq.apply_reduction = lambda losses, reduction: losses.mean(dim=(-1, -2))

The last two dims are the STFT bins and frames; this reduction averages over them while keeping every other dimension.

@egaznep

egaznep commented Apr 23, 2024

@csteinmetz1 Also, I noticed that for the spectral convergence loss, despite the option `reduction="none"`, we nevertheless get a reduced result, because it has been implemented like this:

        return torch.norm(y_mag - x_mag, p="fro") / torch.norm(y_mag, p="fro")

instead of

        return torch.norm(y_mag - x_mag, p="fro", dim=(-1, -2), keepdim=True) / torch.norm(y_mag, p="fro", dim=(-1, -2), keepdim=True)

Noticing this also made me wonder whether the denominator of the current implementation makes sense. I know a few other works, like ParallelWaveGAN, implement this loss the same way, but I still find it counterintuitive that the norm is computed globally over the whole batch instead of per instance.
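The difference is easy to see on random data. A sketch, where `x_mag` and `y_mag` stand in for the batched STFT magnitudes inside the loss:

```python
import torch

# stand-ins for batched STFT magnitudes, shape (batch, freq_bins, frames)
x_mag = torch.rand(4, 257, 513)
y_mag = torch.rand(4, 257, 513)

# current implementation: one global Frobenius norm over the whole
# batch, so the result is a single scalar even with reduction="none"
sc_global = torch.norm(y_mag - x_mag, p="fro") / torch.norm(y_mag, p="fro")

# per-instance variant: norms taken over the (freq, time) axes only,
# giving one spectral convergence value per batch item
sc_per_item = torch.norm(y_mag - x_mag, p="fro", dim=(-1, -2)) / torch.norm(
    y_mag, p="fro", dim=(-1, -2)
)

print(sc_global.shape)    # torch.Size([])
print(sc_per_item.shape)  # torch.Size([4])
```

The global version also couples batch items through the shared denominator, which is the counterintuitive behavior described above.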

renared pushed a commit to renared/auraloss that referenced this issue May 22, 2024
the denominator was averaged over all dimensions including the batch dimension, see comment by @egaznep in csteinmetz1#69