-
Notifications
You must be signed in to change notification settings - Fork 47
/
mbstoi.py
330 lines (293 loc) · 12.1 KB
/
mbstoi.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
"""Modified Binaural Short-Time Objective Intelligibility (MBSTOI) Measure"""
import importlib.resources as pkg_resources
import logging
import math
import numpy as np
import yaml # type: ignore
from numpy import ndarray
from scipy.signal import resample
from clarity.evaluator.mbstoi.mbstoi_utils import (
equalisation_cancellation,
remove_silent_frames,
stft,
thirdoct,
)
# pylint: disable=too-many-locals
# Basic STOI parameters, loaded once at import time from the ``parameters.yaml``
# resource shipped in this package. The ``with`` block guarantees the resource
# handle is closed after reading.
with pkg_resources.open_text(__package__, "parameters.yaml") as params_file:
    basic_stoi_parameters = yaml.safe_load(params_file.read())
def _norm_ratio_exceeds(clean: ndarray, noisy: ndarray, max_log10_ratio: float = 5.0) -> bool:
    """Return True if the norms of ``clean`` and ``noisy`` differ by more than 10**max_log10_ratio.

    An all-zero signal on either side gives an infinite log-ratio and therefore
    returns True; if both are all-zero the ratio is NaN and this returns False.
    """
    # Division by a zero norm / log10 of zero are expected here and handled by
    # the comparison, so their RuntimeWarnings are suppressed.
    with np.errstate(divide="ignore", invalid="ignore"):
        ratio = np.linalg.norm(clean) / np.linalg.norm(noisy)
        return bool(abs(np.log10(ratio)) > max_log10_ratio)


def _third_octave_envelopes(octave_band_matrix: ndarray, spectrum: ndarray) -> ndarray:
    """Apply the 1/3-octave band matrix to the power of a (bin, frame) spectrum.

    This is Eq. (1) of the STOI article, evaluated for all frames at once:
    the result has one row per band and one column per frame of ``spectrum``.
    """
    return np.dot(octave_band_matrix, np.abs(spectrum) ** 2)


def _better_ear_correlations(
    left_clean: ndarray,
    right_clean: ndarray,
    left_noisy: ndarray,
    right_noisy: ndarray,
    n_frames: int,
) -> tuple:
    """Compute better-ear intermediate correlations over sliding segments.

    Args:
        left_clean, right_clean, left_noisy, right_noisy (ndarray): 1/3-octave
            band envelopes with shape (bands, total_frames).
        n_frames (int): Number of frames per analysis segment.

    Returns:
        tuple: ``(dbe_interm, p_be_max)``, both with shape
        (bands, total_frames - n_frames + 1). ``dbe_interm`` holds the
        correlation of the ear whose clean/noisy power ratio is larger;
        ``p_be_max`` holds that larger ratio (NaN where it is undefined).
    """
    n_bands, n_total_frames = np.shape(left_clean)
    n_segments = n_total_frames - n_frames + 1

    dl_interm = np.zeros((n_bands, n_segments))
    dr_interm = np.zeros((n_bands, n_segments))
    left_improved = np.zeros((n_bands, n_segments))
    right_improved = np.zeros((n_bands, n_segments))

    def _centred(envelope: ndarray, stop: int) -> ndarray:
        # Segment [stop - n_frames, stop) with each band's mean removed.
        segment = envelope[:, stop - n_frames : stop]
        return segment - np.sum(segment, axis=1, keepdims=True) / n_frames

    # Zero-energy segments produce 0/0 or x/0 here; the resulting non-finite
    # values are dealt with below, so the division warnings are suppressed.
    with np.errstate(divide="ignore", invalid="ignore"):
        # ``stop`` runs up to and including ``n_total_frames`` so that every one
        # of the ``n_segments`` columns (including the last) is filled.
        for stop in range(n_frames, n_total_frames + 1):
            column = stop - n_frames
            left_clean_n = _centred(left_clean, stop)
            right_clean_n = _centred(right_clean, stop)
            left_noisy_n = _centred(left_noisy, stop)
            right_noisy_n = _centred(right_noisy, stop)

            left_improved[:, column] = np.sum(
                left_clean_n * left_clean_n, axis=1
            ) / np.sum(left_noisy_n * left_noisy_n, axis=1)
            right_improved[:, column] = np.sum(
                right_clean_n * right_clean_n, axis=1
            ) / np.sum(right_noisy_n * right_noisy_n, axis=1)

            dl_interm[:, column] = np.sum(left_clean_n * left_noisy_n, axis=1) / (
                np.linalg.norm(left_clean_n, axis=1)
                * np.linalg.norm(left_noisy_n, axis=1)
            )
            dr_interm[:, column] = np.sum(right_clean_n * right_noisy_n, axis=1) / (
                np.linalg.norm(right_clean_n, axis=1)
                * np.linalg.norm(right_noisy_n, axis=1)
            )

    # Undefined correlations (zero-variance segments) count as zero.
    dl_interm[~np.isfinite(dl_interm)] = 0
    dr_interm[~np.isfinite(dr_interm)] = 0

    p_be_max = np.maximum(left_improved, right_improved)
    # Use the left ear only where it is strictly better; ties/NaNs fall to the right.
    dbe_interm = np.where(left_improved > right_improved, dl_interm, dr_interm)
    return dbe_interm, p_be_max


def mbstoi(
    left_ear_clean: ndarray,
    right_ear_clean: ndarray,
    left_ear_noisy: ndarray,
    right_ear_noisy: ndarray,
    sr_signal: float,
    gridcoarseness: int = 1,
    sample_rate: float = 10000.0,
    n_frame: int = 256,
    fft_size_in_samples: int = 512,
    n_third_octave_bands: int = 15,
    centre_freq_first_third_octave_hz: int = 150,
    n_frames: int = 30,
    dyn_range: int = 40,
    tau_min: float = -0.001,
    tau_max: float = 0.001,
    gamma_min: int = -20,
    gamma_max: int = 20,
    sigma_delta_0: float = 65e-6,
    sigma_epsilon_0: float = 1.5,
    alpha_0_db: int = 13,
    tau_0: float = 1.6e-3,
    level_shift_deviation: float = 1.6,
) -> float:
    """The Modified Binaural Short-Time Objective Intelligibility (mbstoi) measure.

    Args:
        left_ear_clean (ndarray): Clean speech signal from left ear.
        right_ear_clean (ndarray): Clean speech signal from right ear.
        left_ear_noisy (ndarray): Noisy/processed speech signal from left ear.
        right_ear_noisy (ndarray): Noisy/processed speech signal from right ear.
        sr_signal (float): Sample rate of the input signals in Hz.
        gridcoarseness (int): Divisor for the size of the EC search grid; the
            search uses ``ceil(100 / gridcoarseness)`` tau values and
            ``ceil(40 / gridcoarseness)`` gamma values. Defaults to 1.
        sample_rate (float): Internal processing sample rate in Hz. Inputs are
            resampled to this rate when ``sr_signal`` differs. Defaults to 10000.
        n_frame (int): STFT frame length in samples; also passed, with half of
            it, to silent-frame removal. Defaults to 256.
        fft_size_in_samples (int): FFT size of the STFT; only the first
            ``fft_size_in_samples / 2 + 1`` bins are kept. Defaults to 512.
        n_third_octave_bands (int): Number of 1/3-octave bands. Defaults to 15.
        centre_freq_first_third_octave_hz (int): Centre frequency in Hz of the
            lowest 1/3-octave band. Defaults to 150.
        n_frames (int): Number of STFT frames in each intermediate analysis
            segment. Defaults to 30.
        dyn_range (int): Dynamic range passed to silent-frame removal
            (presumably in dB). Defaults to 40.
        tau_min (float): Smallest interaural time compensation (s) in the grid.
        tau_max (float): Largest interaural time compensation (s) in the grid.
        gamma_min (int): Smallest interaural level compensation (dB) in the grid.
        gamma_max (int): Largest interaural level compensation (dB) in the grid.
        sigma_delta_0 (float): Base standard deviation of the time jitter (s);
            scaled by ``1 + |tau| / tau_0``.
        sigma_epsilon_0 (float): Base standard deviation of the level jitter
            (dB); scaled by ``1 + (|gamma| / alpha_0_db) ** level_shift_deviation``.
        alpha_0_db (int): Level scaling constant of the level jitter (dB).
        tau_0 (float): Time scaling constant of the time jitter (s).
        level_shift_deviation (float): Exponent of the level-jitter scaling.
            The jitter model follows Eqs. 5 and 6 of Andersen et al. (2018).

    Returns:
        float: mbstoi index d, the mean of the intermediate intelligibility
        measure over bands and segments. Returns 0.0 when, in either ear, the
        clean and noisy signal norms differ by more than a factor of 10**5
        (e.g. one of them is all zeros).

    Notes:
        All title, copyrights and pending patents pertaining to mbstoi [1]_ in and
        to the original Matlab software are owned by oticon a/s and/or Aalborg
        University. Please see `http://ah-andersen.net/code/`.

    .. [1] A. H. Andersen, J. M. de Haan, Z.-H. Tan, and J. Jensen (2018) Refinement and
        validation of the binaural short time objective intelligibility measure for
        spatially diverse conditions. Speech Communication vol. 102, pp. 1-13
        doi:10.1016/j.specom.2018.06.001 <https://doi.org/10.1016/j.specom.2018.06.001>
    """
    n_taus = math.ceil(100 / gridcoarseness)  # number of tau values to try out
    n_gammas = math.ceil(40 / gridcoarseness)  # number of gamma values to try out

    # Prepare signals, ensuring that inputs are 1-D vectors
    left_ear_clean = left_ear_clean.flatten()
    right_ear_clean = right_ear_clean.flatten()
    left_ear_noisy = left_ear_noisy.flatten()
    right_ear_noisy = right_ear_noisy.flatten()

    # Resample signals to the internal processing rate
    if sr_signal != sample_rate:
        logging.debug(
            "Resampling signals with sr=%s for MBSTOI calculation.", sample_rate
        )
        # All four signals are resampled to the length derived from the left
        # clean signal so that they stay sample-aligned.
        n_resampled = int(len(left_ear_clean) * (sample_rate / sr_signal) + 1)
        left_ear_clean, right_ear_clean, left_ear_noisy, right_ear_noisy = [
            resample(signal, n_resampled)
            for signal in (
                left_ear_clean,
                right_ear_clean,
                left_ear_noisy,
                right_ear_noisy,
            )
        ]

    # Remove silent frames
    (
        left_ear_clean,
        right_ear_clean,
        left_ear_noisy,
        right_ear_noisy,
    ) = remove_silent_frames(
        left_ear_clean,
        right_ear_clean,
        left_ear_noisy,
        right_ear_noisy,
        dyn_range,
        n_frame,
        n_frame / 2,
    )

    # Handle case when signals are zeros (or differ in level by more than 10^5)
    if _norm_ratio_exceeds(left_ear_clean, left_ear_noisy) or _norm_ratio_exceeds(
        right_ear_clean, right_ear_noisy
    ):
        return 0.0

    # STDFT and filtering
    # Get 1/3 octave band matrix
    [
        octave_band_matrix,
        centre_frequencies,
        frequency_band_edges_indices,
        _freq_low,
        _freq_high,
    ] = thirdoct(
        sample_rate,
        fft_size_in_samples,
        n_third_octave_bands,
        centre_freq_first_third_octave_hz,
    )
    # This is now the angular frequency in radians per sec
    centre_frequencies = 2 * math.pi * centre_frequencies

    # Apply the short-time DFT, transpose to (frequency bin, frame) and keep
    # only the single-sided spectrum.
    idx_upper = int(fft_size_in_samples / 2 + 1)
    (
        left_ear_clean_hat,
        right_ear_clean_hat,
        left_ear_noisy_hat,
        right_ear_noisy_hat,
    ) = [
        stft(signal, n_frame, fft_size_in_samples).transpose()[0:idx_upper, :]
        for signal in (left_ear_clean, right_ear_clean, left_ear_noisy, right_ear_noisy)
    ]
    n_segments = np.shape(left_ear_clean_hat)[1] - n_frames + 1

    # Compute intermediate correlation via EC search
    logging.info("Starting EC evaluation")
    # Here intermediate correlation coefficients are evaluated for a discrete set of
    # gamma and tau values (a "grid") and the highest value is chosen.
    intermediate_intelligibility_measure_grid = np.zeros(
        (n_third_octave_bands, n_segments)
    )
    p_ec_max = np.zeros((n_third_octave_bands, n_segments))

    # Interaural compensation time and level values
    taus = np.linspace(tau_min, tau_max, n_taus)
    gammas = np.linspace(gamma_min, gamma_max, n_gammas)

    # Jitter incorporated below - Equations 5 and 6 in Andersen et al. 2018
    sigma_epsilon = (
        np.sqrt(2)
        * sigma_epsilon_0
        * (1 + (abs(gammas) / alpha_0_db) ** level_shift_deviation)
        / 20
    )
    # dB -> log10 amplitude, matching the /20 applied to sigma_epsilon above
    gammas = gammas / 20
    sigma_delta = np.sqrt(2) * sigma_delta_0 * (1 + (abs(taus) / tau_0))

    logging.info("Processing Equalisation Cancellation stage")
    updated_intermediate_intelligibility_measure, p_ec_max = equalisation_cancellation(
        left_ear_clean_hat,
        right_ear_clean_hat,
        left_ear_noisy_hat,
        right_ear_noisy_hat,
        n_third_octave_bands,
        n_frames,
        frequency_band_edges_indices,
        centre_frequencies.flatten(),
        taus,
        n_taus,
        gammas,
        n_gammas,
        intermediate_intelligibility_measure_grid,
        p_ec_max,
        sigma_epsilon,
        sigma_delta,
    )

    # Compute the better ear STOI
    logging.info("Computing better ear intermediate correlation coefficients")
    dbe_interm, p_be_max = _better_ear_correlations(
        _third_octave_envelopes(octave_band_matrix, left_ear_clean_hat),
        _third_octave_envelopes(octave_band_matrix, right_ear_clean_hat),
        _third_octave_envelopes(octave_band_matrix, left_ear_noisy_hat),
        _third_octave_envelopes(octave_band_matrix, right_ear_noisy_hat),
        n_frames,
    )

    # Compute STOI measure
    # Whenever a single ear provides a higher correlation than the corresponding EC
    # processed alternative, the better-ear correlation is used.
    idx_use_be = p_be_max > p_ec_max
    updated_intermediate_intelligibility_measure[idx_use_be] = dbe_interm[idx_use_be]
    sii = np.mean(updated_intermediate_intelligibility_measure)
    logging.info("MBSTOI processing complete")
    return sii