Compatibility problem with torch >= 1.8.0 when torch_complex package is not installed #8

Closed
adriengossesonos opened this issue Jun 4, 2022 · 2 comments

adriengossesonos commented Jun 4, 2022

Hello,
I noticed that with version 0.1.3 of the package, calling bss_eval_sources with torch.Tensor inputs fails because I do not have the torch_complex package installed. However, torch_complex should not be required in this case: I use torch 1.10.2, and torch_complex is only needed for torch < 1.8.

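This happens because, in the __init__.py file, the variable has_torch ends up False. It is set to True right after "import torch as pt", but the next import raises an ImportError, so the except branch resets has_torch to False and replaces the torch submodule with a dummy: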

try:
    import torch as pt
    has_torch = True

    from . import torch as torch     # --> this line fails
    from .torch import sdr_pit_loss, si_sdr_pit_loss   
except ImportError:
    has_torch = False

    # dummy pytorch module
    class pt:
        class Tensor:
            def __init__(self):
                pass

    # dummy torch submodule
    class torch:
        bss_eval_sources = None
        sdr = None
        sdr_loss = None

from . import numpy as numpy
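The failure is easy to miss because the flag is assigned before the submodule import: any later ImportError in the same try block jumps to the except branch and overwrites the flag, even though torch itself imported fine. The minimal sketch below reproduces that pattern with the standard library only. It is an illustration, not the package's code; detect_backend and the module name _surely_missing_optional_dep_ are hypothetical, the latter standing in for an uninstalled torch_complex.

```python
import importlib


def detect_backend(required: str, optional_dep: str) -> bool:
    """Mimic the flag-then-import pattern of the __init__.py quoted above."""
    try:
        importlib.import_module(required)      # stands in for "import torch as pt"
        has_backend = True                     # flag is set optimistically...
        importlib.import_module(optional_dep)  # ...before this import can fail
    except ImportError:
        # Any ImportError raised in the block, including one coming from the
        # optional dependency, ends up here and overwrites the flag.
        has_backend = False
    return has_backend


# "json" always imports; "_surely_missing_optional_dep_" is a hypothetical
# name standing in for an uninstalled torch_complex.
print(detect_backend("json", "json"))                           # True
print(detect_backend("json", "_surely_missing_optional_dep_"))  # False
```

The flag therefore reports whether every import in the block succeeded, not whether torch itself is available.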

The failing line imports the file torch/compatibility.py, which imports torch_complex unconditionally at the top, before the torch version is even checked:

try:
    from packaging.version import Version
except [ImportError, ModuleNotFoundError]:
    from distutils.version import LooseVersion as Version

from torch_complex import ComplexTensor # --> this line causes the problem when torch_complex is not installed 

import torch

is_torch_1_8_plus = Version(torch.__version__) >= Version("1.8.0")

if not is_torch_1_8_plus:
    try:
        import torch_complex
    except ImportError:
        raise ImportError(
            "When using torch<=1.7, the package torch_complex is required."
            " Install it as `pip install torch_complex`"
        )
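The is_torch_1_8_plus check needs a real version comparison: comparing the strings "1.10.2" and "1.8.0" directly gives the wrong answer, because string comparison is lexicographic and "1" sorts before "8". The sketch below is a deliberately simplified, stdlib-only stand-in for packaging.version.Version (not a replacement for it). It compares only the leading numeric release segment and ignores suffixes such as the "+cu113" local tag that CUDA builds of torch can carry; release_tuple and is_at_least are hypothetical helper names.

```python
import re


def release_tuple(version: str) -> tuple:
    """Return the leading numeric release segment, e.g. "1.10.2+cu113" -> (1, 10, 2).

    Simplified on purpose: pre-release, dev and local tags are ignored,
    unlike packaging.version.Version.
    """
    match = re.match(r"\d+(?:\.\d+)*", version)
    if match is None:
        raise ValueError(f"no numeric release segment in {version!r}")
    parts = [int(p) for p in match.group(0).split(".")]
    parts += [0] * (3 - len(parts))  # pad so that "1.8" equals "1.8.0"
    return tuple(parts)


def is_at_least(version: str, minimum: str) -> bool:
    """Compare numeric release segments as integer tuples."""
    return release_tuple(version) >= release_tuple(minimum)


print("1.10.2" >= "1.8.0")                   # False: lexicographic string comparison
print(is_at_least("1.10.2", "1.8.0"))        # True
print(is_at_least("1.10.2+cu113", "1.8.0"))  # True
print(is_at_least("1.7.1", "1.8.0"))         # False
```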

If I understand correctly, the fix is simply to import torch_complex only when torch is older than 1.8:

try:
    from packaging.version import Version
except ImportError:  # also covers ModuleNotFoundError, which subclasses ImportError
    from distutils.version import LooseVersion as Version

import torch

is_torch_1_8_plus = Version(torch.__version__) >= Version("1.8.0")

if not is_torch_1_8_plus:
    # torch_complex is only needed (and only imported) for torch < 1.8,
    # so ComplexTensor is defined only in that case.
    try:
        from torch_complex import ComplexTensor
    except ImportError:
        raise ImportError(
            "When using torch<=1.7, the package torch_complex is required."
            " Install it as `pip install torch_complex`"
        )
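The quoted compatibility.py catches the packaging import error with "except [ImportError, ModuleNotFoundError]:". Python only accepts an exception class or a tuple of classes in an except clause, so if that import ever failed, the list would raise a TypeError instead of falling back to distutils. Since ModuleNotFoundError is a subclass of ImportError, "except ImportError:" alone is enough. The stdlib-only sketch below shows the difference; the function names and the module name _surely_missing_module_ are hypothetical.

```python
import importlib


def import_with_list_clause(name: str) -> str:
    try:
        importlib.import_module(name)
        return "imported"
    except [ImportError, ModuleNotFoundError]:  # wrong: a list, not a tuple
        return "fallback"


def import_with_class_clause(name: str) -> str:
    try:
        importlib.import_module(name)
        return "imported"
    except ImportError:  # also catches ModuleNotFoundError (a subclass)
        return "fallback"


def run(func, name: str) -> str:
    """Call func(name), reporting a TypeError instead of propagating it."""
    try:
        return func(name)
    except TypeError as err:
        return f"TypeError: {err}"


missing = "_surely_missing_module_"  # hypothetical module that does not exist

print(run(import_with_list_clause, missing))   # a TypeError message, not "fallback"
print(run(import_with_class_clause, missing))  # fallback
print(run(import_with_list_clause, "json"))    # imported (the clause is never evaluated)
```

The list form only misbehaves when an exception is actually raised, which is why it can go unnoticed on systems where packaging is installed.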

@fakufaku (Owner)

fakufaku commented Jun 7, 2022

Thank you very much for the detailed bug report! I also had some issues with the torch_complex import failing, but had not looked into it. The bugfix is simple enough; I'll try it right now!

fakufaku mentioned this issue Jun 7, 2022

@fakufaku (Owner)

fakufaku commented Jun 7, 2022

@adriengossesonos, thanks to you I fixed the bug and bumped the version to 1.4! Please let me know if there are any more problems!

fakufaku closed this as completed Jun 7, 2022