[install] Upgrade dependencies to torch 2.x and lightning 2.x (#682)
* Update Lightning and Torch requirements to 2.x

* Remove optimizer_idx from `lr_scheduler_step` in System

* Replace torch.symeig with torch.linalg.eigh in beamforming

* Disable X-UMX model in torch 2.x (the complex STFT needs to be used)

* Skip XUMX tests until the model is fixed.

* Replace `on_x` hooks with `on_train_x` equivalents

* Replace torch.testing.assert_allclose with torch.testing.assert_close.

* Black it
mpariente committed Oct 12, 2023
1 parent 4272438 commit fc13514
Showing 22 changed files with 67 additions and 57 deletions.
4 changes: 2 additions & 2 deletions asteroid/dsp/beamforming.py
@@ -94,7 +94,7 @@ def forward(
"""
# TODO: Implement several RTF estimation strategies, and choose one here, or expose all.
# Get relative transfer function (1st PCA of Σss)
-e_val, e_vec = torch.symeig(target_scm.permute(0, 3, 1, 2), eigenvectors=True)
+e_val, e_vec = torch.linalg.eigh(target_scm.permute(0, 3, 1, 2))
rtf_vect = e_vec[..., -1] # bfm
return self.from_rtf_vect(mix=mix, rtf_vec=rtf_vect.transpose(-1, -2), noise_scm=noise_scm)

@@ -471,7 +471,7 @@ def _generalized_eigenvalue_decomposition(a, b):
# Compute C matrix L^-1 A L^-T
cmat = inv_cholesky @ a @ inv_cholesky.conj().transpose(-1, -2)
# Performing the eigenvalue decomposition
-e_val, e_vec = torch.symeig(cmat, eigenvectors=True)
+e_val, e_vec = torch.linalg.eigh(cmat)
# Collecting the eigenvectors
e_vec = torch.matmul(inv_cholesky.conj().transpose(-1, -2), e_vec)
return e_val, e_vec
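
torch.symeig was removed in torch 2.0, and torch.linalg.eigh is its replacement in both hunks above. A minimal sketch of the migration, assuming a batch of Hermitian matrices like the spatial covariance matrices here (tensor names are illustrative):

import torch

# Build a batch of Hermitian PSD matrices, standing in for the SCMs above.
scm = torch.randn(8, 4, 4, dtype=torch.complex64)
scm = scm @ scm.conj().transpose(-1, -2)

# torch < 2.0: e_val, e_vec = torch.symeig(scm, eigenvectors=True)
e_val, e_vec = torch.linalg.eigh(scm)

# Both APIs return eigenvalues in ascending order, so the principal
# eigenvector (largest eigenvalue) is still the last column, e_vec[..., -1],
# which is why the RTF extraction above is otherwise unchanged.
principal = e_vec[..., -1]
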
2 changes: 1 addition & 1 deletion asteroid/engine/schedulers.py
@@ -180,7 +180,7 @@ class SinkPITBetaScheduler(pl.callbacks.Callback):
def __init__(self, cooling_schedule=sinkpit_default_beta_schedule):
self.cooling_schedule = cooling_schedule

-def on_epoch_start(self, trainer, pl_module):
+def on_train_epoch_start(self, trainer, pl_module):
assert isinstance(pl_module.loss_func, SinkPITLossWrapper)
assert trainer.current_epoch == pl_module.current_epoch # same
epoch = pl_module.current_epoch
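
Lightning 2.x no longer provides the generic on_epoch_start callback hook, hence the rename above. A minimal sketch of the stage-specific hook, assuming a cooling schedule as in SinkPITBetaScheduler (attribute names are illustrative):

import pytorch_lightning as pl

class CoolingCallback(pl.callbacks.Callback):
    def __init__(self, cooling_schedule):
        self.cooling_schedule = cooling_schedule

    # Lightning 2.x: on_epoch_start -> on_train_epoch_start
    def on_train_epoch_start(self, trainer, pl_module):
        pl_module.loss_func.beta = self.cooling_schedule(pl_module.current_epoch)
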
2 changes: 1 addition & 1 deletion asteroid/engine/system.py
@@ -163,7 +163,7 @@ def configure_optimizers(self):
epoch_schedulers.append(sched)
return [self.optimizer], epoch_schedulers

-def lr_scheduler_step(self, scheduler, optimizer_idx, metric):
+def lr_scheduler_step(self, scheduler, metric):
if metric is None:
scheduler.step()
else:
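
Lightning 2.x dropped the optimizer_idx argument from this hook. A sketch of the full override under the new signature; the else branch is truncated in the hunk above, and the assumption here is that metric-driven schedulers receive the monitored value:

# Inside a LightningModule subclass such as System.
def lr_scheduler_step(self, scheduler, metric):
    if metric is None:
        scheduler.step()
    else:
        # Assumption: schedulers like ReduceLROnPlateau expect the metric.
        scheduler.step(metric)
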
7 changes: 7 additions & 0 deletions asteroid/models/x_umx.py
@@ -7,6 +7,13 @@


class XUMX(BaseModel):
+    def __init__(self, *args, **kwargs):
+        raise RuntimeError(
+            "XUMX is broken in torch 2.0, use torch<2.0 with asteroid<0.7 to use it until it's fixed."
+        )
+
+
+class BrokenXUMX(BaseModel):
r"""CrossNet-Open-Unmix (X-UMX) for Music Source Separation introduced in [1].
There are two notable contributions with no effect on inference:
a) Multi Domain Losses
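
The guard above exists because X-UMX was written against the real-valued STFT representation that torch 2.x drops in favor of complex tensors. A minimal sketch of the torch 2.x-compatible call a port would need (parameter values are illustrative):

import torch

wav = torch.randn(1, 16000)
window = torch.hann_window(512)
# torch 2.x expects an explicit complex STFT; the old (..., 2) real view is gone.
spec = torch.stft(wav, n_fft=512, hop_length=256, window=window, return_complex=True)
mag = spec.abs()  # magnitude features, as a ported X-UMX would consume
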
2 changes: 1 addition & 1 deletion notebooks/03_PITLossWrapper.ipynb
@@ -187,7 +187,7 @@
" return pw_loss.mean(dim=mean_over)\n",
"# Compute pairwise losses using broadcasting (+ unit test equality)\n",
"direct_pairwise_losses = pairwise_mse(estimate_sources, sources)\n",
"torch.testing.assert_allclose(pairwise_losses, direct_pairwise_losses)\n",
"torch.testing.assert_close(pairwise_losses, direct_pairwise_losses)\n",
"# Plot the pairwise losses\n",
"ax = plt.imshow(direct_pairwise_losses[0].data.numpy())"
]
2 changes: 1 addition & 1 deletion requirements/install.txt
@@ -2,7 +2,7 @@
-r ./torchhub.txt
PyYAML>=5.0
pandas>=0.23.4
-pytorch-lightning>=1.5.0,<=1.7.7
+pytorch-lightning>=2.0.0
torchmetrics<=0.11.4
torchaudio>=0.8.0
pb_bss_eval>=0.0.2
2 changes: 1 addition & 1 deletion requirements/torchhub.txt
@@ -2,7 +2,7 @@
# Note that Asteroid itself is not required to be installed.
numpy>=1.16.4
scipy>=1.10.1
-torch>=1.8.0,<2.0.0
+torch>=2.0.0
asteroid-filterbanks>=0.4.0
requests
filelock
4 changes: 2 additions & 2 deletions setup.py
@@ -39,14 +39,14 @@ def find_version(*file_paths):
# From requirements/torchhub.txt
"numpy>=1.16.4",
"scipy>=1.10.1",
"torch>=1.8.0,<2.0.0",
"torch>=2.0.0",
"asteroid-filterbanks>=0.4.0",
"SoundFile>=0.10.2",
"huggingface_hub>=0.0.2",
# From requirements/install.txt
"PyYAML>=5.0",
"pandas>=0.23.4",
"pytorch-lightning>=1.5.0,<=1.7.7",
"pytorch-lightning>=2.0.0",
"torchmetrics<=0.11.4",
"torchaudio>=0.5.0",
"pb_bss_eval>=0.0.2",
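
A hedged sketch, not part of this commit, for asserting the new version floors at runtime (assumes the packaging module is available):

from packaging.version import Version

import torch
import pytorch_lightning as pl

for name, version, floor in [
    ("torch", torch.__version__, "2.0.0"),
    ("pytorch-lightning", pl.__version__, "2.0.0"),
]:
    # Strip local build tags such as "+cu118" before comparing.
    if Version(version.split("+")[0]) < Version(floor):
        raise RuntimeError(f"asteroid>=0.7 needs {name}>={floor}, found {version}")
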
20 changes: 10 additions & 10 deletions tests/complex_nn_test.py
@@ -1,5 +1,5 @@
import torch
-from torch.testing import assert_allclose
+from torch.testing import assert_close
import pytest
import math

@@ -17,23 +17,23 @@ def test_torch_complex_from_magphase():
mag = torch.randn(shape).abs()
phase = torch.remainder(torch.randn(shape), math.pi)
out = cnn.torch_complex_from_magphase(mag, phase)
-assert_allclose(torch.abs(out), mag)
-assert_allclose(out.angle(), phase)
+assert_close(torch.abs(out), mag)
+assert_close(out.angle(), phase)


def test_torch_complex_from_reim():
comp = torch.randn(10, 12, dtype=torch.complex64)
-assert_allclose(cnn.torch_complex_from_reim(comp.real, comp.imag), comp)
+assert_close(cnn.torch_complex_from_reim(comp.real, comp.imag), comp)


def test_onreim():
inp = torch.randn(10, 10, dtype=torch.complex64)
# Identity
fn = cnn.on_reim(lambda x: x)
-assert_allclose(fn(inp), inp)
+assert_close(fn(inp), inp)
# Top right quadrant
fn = cnn.on_reim(lambda x: x.abs())
-assert_allclose(fn(inp), cnn.torch_complex_from_reim(inp.real.abs(), inp.imag.abs()))
+assert_close(fn(inp), cnn.torch_complex_from_reim(inp.real.abs(), inp.imag.abs()))


def test_on_reim_class():
@@ -48,16 +48,16 @@ def forward(self, x):
return x + self.a

fn = cnn.OnReIm(Identity, 0)
-assert_allclose(fn(inp), inp)
+assert_close(fn(inp), inp)
fn = cnn.OnReIm(Identity, 1)
-assert_allclose(fn(inp), cnn.torch_complex_from_reim(inp.real + 1, inp.imag + 1))
+assert_close(fn(inp), cnn.torch_complex_from_reim(inp.real + 1, inp.imag + 1))


def test_complex_mul_wrapper():
a = torch.randn(10, 10, dtype=torch.complex64)

fn = cnn.ComplexMultiplicationWrapper(torch.nn.ReLU)
-assert_allclose(
+assert_close(
fn(a),
cnn.torch_complex_from_reim(
torch.relu(a.real) - torch.relu(a.imag), torch.relu(a.real) + torch.relu(a.imag)
@@ -86,4 +86,4 @@ def test_complexsinglernn(n_layers):
reim = layer.re_module(inp.imag)
imre = layer.im_module(inp.real)
inp = cnn.torch_complex_from_reim(rere - imim, reim + imre)
-assert_allclose(out, inp)
+assert_close(out, inp)
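
torch.testing.assert_allclose was deprecated and then removed, and assert_close is the drop-in replacement used throughout these tests. A minimal sketch, including the tolerance knobs it still exposes:

import torch
from torch.testing import assert_close

a = torch.randn(10)
assert_close(a, a.clone())  # dtype-dependent default tolerances, as before
assert_close(a, a + 1e-9, rtol=0.0, atol=1e-6)  # rtol/atol must be passed together
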
6 changes: 3 additions & 3 deletions tests/dsp/consistency_test.py
@@ -1,5 +1,5 @@
import torch
-from torch.testing import assert_allclose
+from torch.testing import assert_close
import pytest

from asteroid.dsp.consistency import mixture_consistency
@@ -13,7 +13,7 @@ def test_consistency_noweight(mix_shape, dim, n_src):
est_shape = mix_shape[:dim] + [n_src] + mix_shape[dim:]
est_sources = torch.randn(est_shape)
consistent_est_sources = mixture_consistency(mix, est_sources, dim=dim)
-assert_allclose(mix, consistent_est_sources.sum(dim))
+assert_close(mix, consistent_est_sources.sum(dim))


@pytest.mark.parametrize("mix_shape", [[2, 1600], [2, 130, 10]])
@@ -30,7 +30,7 @@ def test_consistency_withweight(mix_shape, dim, n_src):
src_weights = torch.softmax(torch.randn(src_weights_shape), dim=dim)
# Apply mixture consistency
consistent_est_sources = mixture_consistency(mix, est_sources, src_weights=src_weights, dim=dim)
-assert_allclose(mix, consistent_est_sources.sum(dim))
+assert_close(mix, consistent_est_sources.sum(dim))


def test_consistency_raise():
4 changes: 2 additions & 2 deletions tests/dsp/overlap_add_test.py
@@ -1,5 +1,5 @@
import torch
-from torch.testing import assert_allclose
+from torch.testing import assert_close
import pytest

from asteroid.dsp.overlap_add import LambdaOverlapAdd
@@ -16,4 +16,4 @@ def test_overlap_add(length, batch_size, n_src, window, window_size, hop_size):
nnet = lambda x: x.unsqueeze(1).repeat(1, n_src, 1)
oladd = LambdaOverlapAdd(nnet, n_src, window_size, hop_size, window)
oladded = oladd(mix)
-assert_allclose(mix.repeat(1, n_src, 1), oladded)
+assert_close(mix.repeat(1, n_src, 1), oladded)
4 changes: 2 additions & 2 deletions tests/jit/jit_filterbanks_test.py
@@ -1,6 +1,6 @@
import torch
import pytest
-from torch.testing import assert_allclose
+from torch.testing import assert_close
from asteroid_filterbanks import make_enc_dec
from asteroid.models.base_models import BaseEncoderMaskerDecoder

@@ -30,7 +30,7 @@ def test_jit_filterbanks(filter_bank_name, inference_data):
with torch.no_grad():
res = model(inference_data)
out = traced(inference_data)
-assert_allclose(res, out)
+assert_close(res, out)


class DummyModel(BaseEncoderMaskerDecoder):
8 changes: 4 additions & 4 deletions tests/jit/jit_masknn_test.py
@@ -1,6 +1,6 @@
import pytest
import torch
-from torch.testing import assert_allclose
+from torch.testing import assert_close
from asteroid.masknn import norms


@@ -13,10 +13,10 @@ def test_lns(cls):
traced = torch.jit.trace(model, x)

y = torch.randn(3, chan_size, 18, 12)
-assert_allclose(traced(y), model(y))
+assert_close(traced(y), model(y))

y = torch.randn(2, chan_size, 10, 5, 4)
-assert_allclose(traced(y), model(y))
+assert_close(traced(y), model(y))


def test_cumln():
@@ -27,4 +27,4 @@ def test_cumln():
traced = torch.jit.trace(model, x)

y = torch.randn(3, chan_size, 100)
-assert_allclose(traced(y), model(y))
+assert_close(traced(y), model(y))
4 changes: 2 additions & 2 deletions tests/jit/jit_models_test.py
@@ -2,7 +2,7 @@

import torch
import pytest
-from torch.testing import assert_allclose
+from torch.testing import assert_close
from asteroid.models import (
DCCRNet,
DCUNet,
@@ -20,7 +20,7 @@
def assert_consistency(model, traced, tensor):
ref = model(tensor)
out = traced(tensor)
-assert_allclose(ref, out)
+assert_close(ref, out)


@pytest.fixture(scope="module")
12 changes: 6 additions & 6 deletions tests/losses/loss_functions_test.py
@@ -1,6 +1,6 @@
import pytest
import torch
-from torch.testing import assert_allclose
+from torch.testing import assert_close
import warnings

from asteroid_filterbanks import STFTFB, Encoder, transforms
@@ -68,15 +68,15 @@ def test_sisdr_and_mse(n_src, loss):
w_src_wrapper = PITLossWrapper(multisrc, pit_from="perm_avg")

# Circular tests on value
-assert_allclose(pw_wrapper(est_targets, targets), wo_src_wrapper(est_targets, targets))
-assert_allclose(wo_src_wrapper(est_targets, targets), w_src_wrapper(est_targets, targets))
+assert_close(pw_wrapper(est_targets, targets), wo_src_wrapper(est_targets, targets))
+assert_close(wo_src_wrapper(est_targets, targets), w_src_wrapper(est_targets, targets))

# Circular tests on returned estimates
-assert_allclose(
+assert_close(
pw_wrapper(est_targets, targets, return_est=True)[1],
wo_src_wrapper(est_targets, targets, return_est=True)[1],
)
-assert_allclose(
+assert_close(
wo_src_wrapper(est_targets, targets, return_est=True)[1],
w_src_wrapper(est_targets, targets, return_est=True)[1],
)
@@ -151,7 +151,7 @@ def test_pmsqe(sample_rate):
assert loss_value.shape[0] == ref.shape[0]
# Assert support for transposed inputs.
tr_loss_value = loss_func(est_spec.transpose(1, 2), ref_spec.transpose(1, 2))
-assert_allclose(loss_value, tr_loss_value)
+assert_close(loss_value, tr_loss_value)


@pytest.mark.parametrize("n_src", [2, 3])
12 changes: 6 additions & 6 deletions tests/losses/pit_wrapper_test.py
@@ -1,7 +1,7 @@
import pytest
import itertools
import torch
-from torch.testing import assert_allclose
+from torch.testing import assert_close

from asteroid.losses import PITLossWrapper, pairwise_mse

@@ -71,7 +71,7 @@ def test_permutation(perm):
loss_value, reordered = loss_func(est_sources, sources, return_est=True)

assert loss_value.item() == 0
-assert_allclose(sources, reordered)
+assert_close(sources, reordered)


def test_permreduce():
@@ -95,8 +95,8 @@ def test_permreduce():
w_mean = w_mean_reduce(est_sources, sources)
w_sum = w_sum_reduce(est_sources, sources)

-assert_allclose(wo, w_mean)
-assert_allclose(wo, w_sum / n_src)
+assert_close(wo, w_mean)
+assert_close(wo, w_sum / n_src)


def test_permreduce_args():
@@ -123,8 +123,8 @@ def test_best_perm_match(n_src):
min_loss, min_idx = PITLossWrapper.find_best_perm_factorial(pwl)
min_loss_hun, min_idx_hun = PITLossWrapper.find_best_perm_hungarian(pwl)

-assert_allclose(min_loss, min_loss_hun)
-assert_allclose(min_idx, min_idx_hun)
+assert_close(min_loss, min_loss_hun)
+assert_close(min_idx, min_idx_hun)


def test_raises_wrong_pit_from():
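
test_best_perm_match above checks that the Hungarian solver and exhaustive search agree on the best permutation. A standalone sketch of that equivalence with scipy, independent of PITLossWrapper internals (names are illustrative):

import itertools

import numpy as np
from scipy.optimize import linear_sum_assignment

pwl = np.random.rand(4, 4)  # pairwise losses, indexed [est_src, true_src]
brute = min(pwl[np.arange(4), list(perm)].sum() for perm in itertools.permutations(range(4)))
rows, cols = linear_sum_assignment(pwl)  # Hungarian: O(n^3) instead of O(n!)
assert np.isclose(pwl[rows, cols].sum(), brute)
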
6 changes: 3 additions & 3 deletions tests/losses/sinkpit_wrapper_test.py
@@ -3,7 +3,7 @@
import torch
from torch import nn, optim
from torch.utils import data
-from torch.testing import assert_allclose
+from torch.testing import assert_close

import pytorch_lightning as pl
from pytorch_lightning import Trainer
@@ -90,7 +90,7 @@ def test_proximity_sinkhorn_hungarian(batch_size, n_src, beta, n_iter, function_
mean_loss_hungarian = loss_hungarian(est_targets, targets, return_est=False)

# compare
-assert_allclose(mean_loss_sinkhorn, mean_loss_hungarian)
+assert_close(mean_loss_sinkhorn, mean_loss_hungarian)


class _TestCallback(pl.callbacks.Callback):
Expand All @@ -99,7 +99,7 @@ def __init__(self, function, total, batch_size):
self.epoch = 0
self.n_batch = total // batch_size

-def on_batch_end(self, trainer, pl_module):
+def on_train_batch_end(self, trainer, *args, **kwargs):
step = trainer.global_step
assert self.epoch * self.n_batch <= step
assert step <= (self.epoch + 1) * self.n_batch
6 changes: 3 additions & 3 deletions tests/masknn/activations_test.py
@@ -2,7 +2,7 @@
# list of strings / function pair in parametrize?
import pytest
import torch
-from torch.testing import assert_allclose
+from torch.testing import assert_close

from asteroid.masknn import activations
from torch import nn
@@ -27,14 +27,14 @@ def test_activations(activation_tuple):
asteroid_act = activations.get(asteroid_act)()

inp = torch.randn(10, 11, 12)
-assert_allclose(torch_act(inp), asteroid_act(inp))
+assert_close(torch_act(inp), asteroid_act(inp))


def test_softmax():
torch_softmax = nn.Softmax(dim=-1)
asteroid_softmax = activations.get("softmax")(dim=-1)
inp = torch.randn(10, 11, 12)
-assert_allclose(torch_softmax(inp), asteroid_softmax(inp))
+assert_close(torch_softmax(inp), asteroid_softmax(inp))
assert torch_softmax == activations.get(torch_softmax)

