Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Solves bss permutation efficiently #318

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 45 additions & 12 deletions mir_eval/separation.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
import scipy.fftpack
from scipy.linalg import toeplitz
from scipy.signal import fftconvolve
from scipy.optimize import linear_sum_assignment
import collections
import itertools
import warnings
Expand Down Expand Up @@ -214,12 +215,10 @@ def bss_eval_sources(reference_sources, estimated_sources,
_bss_source_crit(s_true, e_spat, e_interf, e_artif)

# select the best ordering
perms = list(itertools.permutations(list(range(nsrc))))
mean_sir = np.empty(len(perms))
dum = np.arange(nsrc)
for (i, perm) in enumerate(perms):
mean_sir[i] = np.mean(sir[perm, dum])
popt = perms[np.argmax(mean_sir)]
if sir.shape[0] == 1:
dum = popt = np.arange(nsrc)
else:
dum, popt = _linear_sum_assignment_with_inf(-sir.T)
idx = (popt, dum)
return (sdr[idx], sir[idx], sar[idx], np.asarray(popt))
else:
Expand Down Expand Up @@ -456,12 +455,11 @@ def bss_eval_images(reference_sources, estimated_sources,
_bss_image_crit(s_true, e_spat, e_interf, e_artif)

# select the best ordering
perms = list(itertools.permutations(list(range(nsrc))))
mean_sir = np.empty(len(perms))
dum = np.arange(nsrc)
for (i, perm) in enumerate(perms):
mean_sir[i] = np.mean(sir[perm, dum])
popt = perms[np.argmax(mean_sir)]
if sir.shape[0] == 1:
dum = popt = np.arange(nsrc)
else:
dum, popt = _linear_sum_assignment_with_inf(-sir.T)

idx = (popt, dum)
return (sdr[idx], isr[idx], sir[idx], sar[idx], np.asarray(popt))
else:
Expand Down Expand Up @@ -919,3 +917,38 @@ def evaluate(reference_sources, estimated_sources, **kwargs):
scores['Sources - Source permutation'] = perm.tolist()

return scores


def _linear_sum_assignment_with_inf(cost_matrix):
'''
Solves the permutation problem efficiently via the linear sum
assignment problem.
https://docs.scipy.org/doc/scipy/reference/generated/scipy.optimize.linear_sum_assignment.html

This implementation was proposed by @louisabraham in
https://github.com/scipy/scipy/issues/6900
to handle infinite entries in the cost matrix.
'''
cost_matrix = np.asarray(cost_matrix)
min_inf = np.isneginf(cost_matrix).any()
max_inf = np.isposinf(cost_matrix).any()
if min_inf and max_inf:
raise ValueError("matrix contains both inf and -inf")

if min_inf or max_inf:
cost_matrix = cost_matrix.copy()
values = cost_matrix[~np.isinf(cost_matrix)]
m = values.min()
M = values.max()
n = min(cost_matrix.shape)
# strictly positive constant even when added
# to elements of the cost matrix
positive = n * (M - m + np.abs(M) + np.abs(m) + 1)
if max_inf:
place_holder = (M + (n - 1) * (M - m)) + positive
if min_inf:
place_holder = (m + (n - 1) * (m - M)) - positive

cost_matrix[np.isinf(cost_matrix)] = place_holder

return linear_sum_assignment(cost_matrix)
46 changes: 46 additions & 0 deletions tests/test_separation.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,49 @@ def __unit_test_framewise_small_window(metric):
atol=A_TOL)


def __unit_test_linear_sum_assignment():

cost = np.array([[4, 1, 3], [2, 0, 5], [3, 2, 2]], dtype=np.float)
cost_inf = cost

# test a normal case
_, col_ind = mir_eval.separation._linear_sum_assignment_with_inf(
cost
)
assert np.allclose([1, 0, 2], col_ind)

# test a case with one negative infinity
cost[0, 1] = -np.inf
_, col_ind = mir_eval.separation._linear_sum_assignment_with_inf(
cost
)
assert np.allclose([1, 0, 2], col_ind)

# test a case with one positive infinity
cost[0, 1] = 1
cost[0, 0] = np.inf
_, col_ind = mir_eval.separation._linear_sum_assignment_with_inf(
cost
)
assert np.allclose([1, 0, 2], col_ind)

# make sure the exception due to both pos and neg
# infinity is caught
pos_neg_inf = False
try:
# make it fail ...
cost[0, 0] = -np.inf
cost[1, 1] = np.inf
_, col_ind = mir_eval.separation._linear_sum_assignment_with_inf(
cost
)
except ValueError:
# ... and catch the exception
pos_neg_inf = True

assert pos_neg_inf


def test_separation_functions():
# Load in all files in the same order
ref_files = sorted(glob.glob(REF_GLOB))
Expand All @@ -260,6 +303,9 @@ def test_separation_functions():
mir_eval.separation.bss_eval_images_framewise]:
yield (__unit_test_framewise_small_window, metric)
yield (__unit_test_partial_silence, metric)

yield(__unit_test_linear_sum_assignment)

# Regression tests
for ref_f, est_f, sco_f in zip(ref_files, est_files, sco_files):
with open(sco_f, 'r') as f:
Expand Down