Skip to content

Commit

Permalink
add test for torch.complex in test_mixit_solver.py
Browse files Browse the repository at this point in the history
  • Loading branch information
simpleoier committed Sep 12, 2022
1 parent f705a58 commit d7f8047
Showing 1 changed file with 37 additions and 17 deletions.
54 changes: 37 additions & 17 deletions test/espnet2/enh/loss/wrappers/test_mixit_solver.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
import pytest
import torch
import torch.nn.functional as F
from packaging.version import parse as V
from torch_complex.tensor import ComplexTensor

from espnet2.enh.loss.criterions.tf_domain import FrequencyDomainL1
from espnet2.enh.loss.criterions.time_domain import TimeDomainL1
from espnet2.enh.loss.wrappers.mixit_solver import MixITSolver

is_torch_1_9_plus = V(torch.__version__) >= V("1.9.0")


@pytest.mark.parametrize("inf_num, time_domain", [(4, True), (4, False)])
def test_MixITSolver_forward(inf_num, time_domain):
Expand Down Expand Up @@ -53,27 +56,44 @@ def test_MixITSolver_forward(inf_num, time_domain):
assert perm[1].equal(torch.tensor(correct_perm2))


@pytest.mark.parametrize("inf_num", [4])
def test_MixITSolver_complex_tensor_forward(inf_num):
@pytest.mark.parametrize(
"inf_num, torch_complex",
[(4, True), (4, False)],
)
def test_MixITSolver_complex_forward(inf_num, torch_complex):

batch = 2
solver = MixITSolver(FrequencyDomainL1())

inf = [
ComplexTensor(
torch.rand(batch, 100, 10, 10),
torch.rand(batch, 100, 10, 10),
)
for _ in range(inf_num)
]
# 2 speaker's reference
ref = [
ComplexTensor(
torch.zeros(batch, 100, 10, 10),
torch.zeros(batch, 100, 10, 10),
)
for _ in range(inf_num // 2)
]
if torch_complex:
if is_torch_1_9_plus:
inf = [
torch.rand(batch, 100, 10, 10, dtype=torch.cfloat)
for _ in range(inf_num)
]
# 2 speaker's reference
ref = [
torch.zeros(batch, 100, 10, 10, dtype=torch.cfloat),
torch.zeros(batch, 100, 10, 10, dtype=torch.cfloat),
]
else:
return
else:
inf = [
ComplexTensor(
torch.rand(batch, 100, 10, 10),
torch.rand(batch, 100, 10, 10),
)
for _ in range(inf_num)
]
# 2 speaker's reference
ref = [
ComplexTensor(
torch.zeros(batch, 100, 10, 10),
torch.zeros(batch, 100, 10, 10),
)
for _ in range(inf_num // 2)
]

ref[0][0] = inf[2][0] + inf[3][0] # sample1, speaker 1
ref[1][0] = inf[0][0] + inf[1][0] # sample1, speaker 2
Expand Down

0 comments on commit d7f8047

Please sign in to comment.