Skip to content

Commit

Permalink
Complex tensor support for MixIT solver.
Browse files Browse the repository at this point in the history
  • Loading branch information
simpleoier committed Sep 7, 2022
1 parent 63c6fc5 commit 621a0e8
Show file tree
Hide file tree
Showing 2 changed files with 107 additions and 14 deletions.
45 changes: 37 additions & 8 deletions espnet2/enh/loss/wrappers/mixit_solver.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
import itertools
from typing import Dict, List, Union

import torch
from torch_complex.tensor import ComplexTensor

from espnet2.enh.layers.complex_utils import einsum as complex_einsum
from espnet2.enh.layers.complex_utils import stack as complex_stack
from espnet2.enh.loss.criterions.abs_loss import AbsEnhLoss
from espnet2.enh.loss.wrappers.abs_wrapper import AbsLossWrapper

Expand All @@ -27,7 +31,18 @@ def __init__(
def name(self):
return "mixit"

def forward(self, ref, inf, others={}):
def _complex_einsum(self, equation, *operands):
for op in operands:
if not isinstance(op, ComplexTensor):
op = ComplexTensor(op, torch.zeros_like(op))
return complex_einsum(equation, *operands)

def forward(
self,
ref: Union[List[torch.Tensor], List[ComplexTensor]],
inf: Union[List[torch.Tensor], List[ComplexTensor]],
others: Dict = {},
):
"""MixIT solver.
Args:
Expand All @@ -42,8 +57,19 @@ def forward(self, ref, inf, others={}):
num_ref = num_inf // 2
device = ref[0].device

ref_tensor = torch.stack(ref[:num_ref], dim=1) # (batch, num_ref, ...)
inf_tensor = torch.stack(inf, dim=1) # (batch, num_inf, ...)
is_complex = isinstance(ref[0], ComplexTensor)
assert is_complex == isinstance(inf[0], ComplexTensor)

if not is_complex:
ref_tensor = torch.stack(ref[:num_ref], dim=1) # (batch, num_ref, ...)
inf_tensor = torch.stack(inf, dim=1) # (batch, num_inf, ...)

einsum_fn = torch.einsum
else:
ref_tensor = complex_stack(ref[:num_ref], dim=1) # (batch, num_ref, ...)
inf_tensor = complex_stack(inf, dim=1) # (batch, num_inf, ...)

einsum_fn = self._complex_einsum

# all permutation assignments:
# [(0, 0, 0, 0), (0, 0, 0, 1), (0, 0, 1, 0), ..., (1, 1, 1, 1)]
Expand All @@ -57,15 +83,15 @@ def forward(self, ref, inf, others={}):
for asm in all_assignments
],
dim=0,
).float() # (num_ref ^ num_inf, num_ref, num_inf)
).to(
inf_tensor.dtype
) # (num_ref ^ num_inf, num_ref, num_inf)

# (num_ref ^ num_inf, batch, num_ref, seq_len, ...)
if inf_tensor.dim() == 3:
est_sum_mixture = torch.einsum(
"ari,bil->abrl", all_mixture_matrix, inf_tensor
)
est_sum_mixture = einsum_fn("ari,bil->abrl", all_mixture_matrix, inf_tensor)
elif inf_tensor.dim() > 3:
est_sum_mixture = torch.einsum(
est_sum_mixture = einsum_fn(
"ari,bil...->abrl...", all_mixture_matrix, inf_tensor
)

Expand All @@ -86,6 +112,9 @@ def forward(self, ref, inf, others={}):
loss = loss.mean()
perm = torch.index_select(all_mixture_matrix, 0, perm)

if perm.is_complex():
perm = perm.real

stats = dict()
stats[f"{self.criterion.name}_{self.name}"] = loss.detach()

Expand Down
76 changes: 70 additions & 6 deletions test/espnet2/enh/loss/wrappers/test_mixit_solver.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,27 @@
import pytest
import torch
import torch.nn.functional as F
from torch_complex.tensor import ComplexTensor

from espnet2.enh.loss.criterions.tf_domain import FrequencyDomainL1
from espnet2.enh.loss.criterions.time_domain import TimeDomainL1
from espnet2.enh.loss.wrappers.mixit_solver import MixITSolver


@pytest.mark.parametrize("num_spk, time_domain", [(4, True), (4, False)])
def test_MixITSolver_forward(num_spk, time_domain):
@pytest.mark.parametrize("inf_num, time_domain", [(4, True), (4, False)])
def test_MixITSolver_forward(inf_num, time_domain):

batch = 2
if time_domain:
solver = MixITSolver(TimeDomainL1())

inf = [torch.rand(batch, 100) for _ in range(num_spk)]
inf = [torch.rand(batch, 100) for _ in range(inf_num)]
# 2 speaker's reference
ref = [torch.zeros(batch, 100), torch.zeros(batch, 100)]
else:
solver = MixITSolver(FrequencyDomainL1())

inf = [torch.rand(batch, 100, 10, 10) for _ in range(num_spk)]
inf = [torch.rand(batch, 100, 10, 10) for _ in range(inf_num)]
# 2 speaker's reference
ref = [torch.zeros(batch, 100, 10, 10), torch.zeros(batch, 100, 10, 10)]

Expand All @@ -34,7 +35,7 @@ def test_MixITSolver_forward(num_spk, time_domain):
correct_perm1 = (
F.one_hot(
torch.tensor([1, 1, 0, 0], dtype=torch.int64),
num_classes=num_spk // 2,
num_classes=inf_num // 2,
)
.transpose(1, 0)
.float()
Expand All @@ -44,7 +45,70 @@ def test_MixITSolver_forward(num_spk, time_domain):
correct_perm2 = (
F.one_hot(
torch.tensor([0, 1, 1, 0], dtype=torch.int64),
num_classes=num_spk // 2,
num_classes=inf_num // 2,
)
.transpose(1, 0)
.float()
)
assert perm[1].equal(torch.tensor(correct_perm2))


@pytest.mark.parametrize(
"inf_num, torch_complex",
[(4, True), (4, False)],
)
def test_MixITSolver_complex_tensor_forward(inf_num, torch_complex):

batch = 2
solver = MixITSolver(FrequencyDomainL1())

if torch_complex:
inf = [
torch.rand(batch, 100, 10, 10, dtype=torch.cfloat) for _ in range(inf_num)
]
# 2 speaker's reference
ref = [
torch.zeros(batch, 100, 10, 10, dtype=torch.cfloat),
torch.zeros(batch, 100, 10, 10, dtype=torch.cfloat),
]
else:
inf = [
ComplexTensor(
torch.rand(batch, 100, 10, 10),
torch.rand(batch, 100, 10, 10),
)
for _ in range(inf_num)
]
# 2 speaker's reference
ref = [
ComplexTensor(
torch.zeros(batch, 100, 10, 10),
torch.zeros(batch, 100, 10, 10),
)
for _ in range(inf_num // 2)
]

ref[0][0] = inf[2][0] + inf[3][0] # sample1, speaker 1
ref[1][0] = inf[0][0] + inf[1][0] # sample1, speaker 2
ref[0][1] = inf[0][1] + inf[3][1] # sample2, speaker 1
ref[1][1] = inf[1][1] + inf[2][1] # sample2, speaker 2

loss, stats, others = solver(ref, inf)
perm = others["perm"]
correct_perm1 = (
F.one_hot(
torch.tensor([1, 1, 0, 0], dtype=torch.int64),
num_classes=inf_num // 2,
)
.transpose(1, 0)
.float()
)
assert perm[0].equal(torch.tensor(correct_perm1))

correct_perm2 = (
F.one_hot(
torch.tensor([0, 1, 1, 0], dtype=torch.int64),
num_classes=inf_num // 2,
)
.transpose(1, 0)
.float()
Expand Down

0 comments on commit 621a0e8

Please sign in to comment.