-
Notifications
You must be signed in to change notification settings - Fork 152
/
feature.py
364 lines (288 loc) · 10.9 KB
/
feature.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
import os
import librosa
import numpy as np
import torch
import torch.nn as nn
def stft(y, n_fft, hop_length, win_length):
"""Wrapper of the official torch.stft for single-channel and multi-channel.
Args:
y: single- or multi-channel speech with shape of [B, C, T] or [B, T]
n_fft: number of FFT
hop_length: hop length
win_length: hanning window size
Shapes:
mag: [B, F, T] if dims of input is [B, T], whereas [B, C, F, T] if dims of input is [B, C, T]
Returns:
mag, phase, real and imag with the same shape of [B, F, T] (**complex-valued** STFT coefficients)
"""
num_dims = y.dim()
assert num_dims == 2 or num_dims == 3, "Only support 2D or 3D Input"
batch_size = y.shape[0]
num_samples = y.shape[-1]
if num_dims == 3:
y = y.reshape(-1, num_samples) # [B * C ,T]
complex_stft = torch.stft(
y,
n_fft,
hop_length,
win_length,
window=torch.hann_window(n_fft, device=y.device),
return_complex=True,
)
_, num_freqs, num_frames = complex_stft.shape
if num_dims == 3:
complex_stft = complex_stft.reshape(batch_size, -1, num_freqs, num_frames)
mag = torch.abs(complex_stft)
phase = torch.angle(complex_stft)
real = complex_stft.real
imag = complex_stft.imag
return mag, phase, real, imag
def istft(features, n_fft, hop_length, win_length, length=None, input_type="complex"):
"""Wrapper of the official torch.istft.
Args:
features: [B, F, T] (complex) or ([B, F, T], [B, F, T]) (mag and phase)
n_fft: num of FFT
hop_length: hop length
win_length: hanning window size
length: expected length of istft
use_mag_phase: use mag and phase as the input ("features")
Returns:
single-channel speech of shape [B, T]
"""
if input_type == "real_imag":
# the feature is (real, imag) or [real, imag]
assert isinstance(features, tuple) or isinstance(features, list)
real, imag = features
features = torch.complex(real, imag)
elif input_type == "complex":
assert torch.is_complex(features), "The input feature is not complex."
elif input_type == "mag_phase":
# the feature is (mag, phase) or [mag, phase]
assert isinstance(features, tuple) or isinstance(features, list)
mag, phase = features
features = torch.complex(mag * torch.cos(phase), mag * torch.sin(phase))
else:
raise NotImplementedError(
"Only 'real_imag', 'complex', and 'mag_phase' are supported."
)
return torch.istft(
features,
n_fft,
hop_length,
win_length,
window=torch.hann_window(n_fft, device=features.device),
length=length,
)
def mag_phase(complex_tensor):
mag, phase = torch.abs(complex_tensor), torch.angle(complex_tensor)
return mag, phase
def norm_amplitude(y, scalar=None, eps=1e-6):
if not scalar:
scalar = np.max(np.abs(y)) + eps
return y / scalar, scalar
def tailor_dB_FS(y, target_dB_FS=-25, eps=1e-6):
rms = np.sqrt(np.mean(y**2))
scalar = 10 ** (target_dB_FS / 20) / (rms + eps)
y *= scalar
return y, rms, scalar
def is_clipped(y, clipping_threshold=0.999):
return any(np.abs(y) > clipping_threshold)
def load_wav(file, sr=16000):
if len(file) == 2:
return file[-1]
else:
return librosa.load(os.path.abspath(os.path.expanduser(file)), mono=False, sr=sr)[0]
def aligned_subsample(data_a, data_b, sub_sample_length):
"""
Start from a random position and take a fixed-length segment from two speech samples
Notes
Only support one-dimensional speech signal (T,) and two-dimensional spectrogram signal (F, T)
Only support subsample in the last axis.
"""
assert data_a.shape[-1] == data_b.shape[-1], "Inconsistent dataset size."
if data_a.shape[-1] > sub_sample_length:
length = data_a.shape[-1]
start = np.random.randint(length - sub_sample_length + 1)
end = start + sub_sample_length
# data_a = data_a[..., start: end]
return data_a[..., start:end], data_b[..., start:end]
elif data_a.shape[-1] < sub_sample_length:
length = data_a.shape[-1]
pad_size = sub_sample_length - length
pad_width = [(0, 0)] * (data_a.ndim - 1) + [(0, pad_size)]
data_a = np.pad(data_a, pad_width=pad_width, mode="constant", constant_values=0)
data_b = np.pad(data_b, pad_width=pad_width, mode="constant", constant_values=0)
return data_a, data_b
else:
return data_a, data_b
def subsample(
data, sub_sample_length, start_position: int = -1, return_start_position=False
):
"""Randomly select fixed-length data from.
Args:
data: **one-dimensional data**
sub_sample_length: how long
start_position: If start index smaller than 0, randomly generate one index
"""
assert np.ndim(data) == 1, f"Only support 1D data. The dim is {np.ndim(data)}"
length = len(data)
if length > sub_sample_length:
if start_position < 0:
start_position = np.random.randint(length - sub_sample_length)
end = start_position + sub_sample_length
data = data[start_position:end]
elif length < sub_sample_length:
data = np.append(data, np.zeros(sub_sample_length - length, dtype=np.float32))
else:
pass
assert len(data) == sub_sample_length
if return_start_position:
return data, start_position
else:
return data
def overlap_cat(chunk_list, dim=-1):
"""Overlap concatenate (50%) a list of tensors.
Args:
dim: which dimension to concatenate
chunk_list(list): [[B, T], [B, T], ...]
Returns:
concatenated tensor of shape [B, T]
"""
overlap_output = []
for i, chunk in enumerate(chunk_list):
first_half, last_half = torch.split(chunk, chunk.size(-1) // 2, dim=dim)
if i == 0:
overlap_output += [first_half, last_half]
else:
overlap_output[-1] = (overlap_output[-1] + first_half) / 2
overlap_output.append(last_half)
overlap_output = torch.cat(overlap_output, dim=dim)
return overlap_output
def activity_detector(audio, fs=16000, activity_threshold=0.13, target_level=-25, eps=1e-6):
"""Return the percentage of the time the audio signal is above an energy threshold
Args:
audio:
fs:
activity_threshold:
target_level:
eps:
Returns:
"""
audio, _, _ = tailor_dB_FS(audio, target_level)
window_size = 50 # ms
window_samples = int(fs * window_size / 1000)
sample_start = 0
cnt = 0
prev_energy_prob = 0
active_frames = 0
a = -1
b = 0.2
alpha_rel = 0.05
alpha_att = 0.8
while sample_start < len(audio):
sample_end = min(sample_start + window_samples, len(audio))
audio_win = audio[sample_start:sample_end]
frame_rms = 20 * np.log10(sum(audio_win**2) + eps)
frame_energy_prob = 1.0 / (1 + np.exp(-(a + b * frame_rms)))
if frame_energy_prob > prev_energy_prob:
smoothed_energy_prob = frame_energy_prob * alpha_att + prev_energy_prob * (
1 - alpha_att
)
else:
smoothed_energy_prob = frame_energy_prob * alpha_rel + prev_energy_prob * (
1 - alpha_rel
)
if smoothed_energy_prob > activity_threshold:
active_frames += 1
prev_energy_prob = frame_energy_prob
sample_start += window_samples
cnt += 1
perc_active = active_frames / cnt
return perc_active
def batch_shuffle_frequency(tensor, indices=None):
"""Randomly shuffle frequency of a spectrogram and return shuffle indices.
Args:
tensor: input tensor with batch dim
indices:
Examples:
input =
tensor([[[[1., 1., 1.],
[2., 2., 2.],
[3., 3., 3.],
[4., 4., 4.]]],
[[[1., 1., 1.],
[2., 2., 2.],
[3., 3., 3.],
[4., 4., 4.]]]])
output =
tensor([[[[3., 3., 3.],
[4., 4., 4.],
[2., 2., 2.],
[1., 1., 1.]]],
[[[3., 3., 3.],
[2., 2., 2.],
[1., 1., 1.],
[4., 4., 4.]]]])
Shapes:
tensor: [B, C, F, T]
out: [B, C, F T]
indices: [B, C, F, T]
Returns:
out: after frequency shuffle
indices: shuffle matrix
"""
assert tensor.ndim == 4
batch_size, num_channels, num_freqs, num_frames = tensor.shape
if not torch.is_tensor(indices):
indices = torch.stack(
[torch.randperm(num_freqs, device=tensor.device) for _ in range(batch_size)],
dim=0,
)
indices = indices[:, None, :, None].repeat(1, num_channels, 1, num_frames)
out = torch.gather(tensor, dim=2, index=indices)
return out, indices
def drop_band(input, num_groups=2):
"""Reduce computational complexity of the sub-band part in the FullSubNet model.
Shapes:
input: [B, C, F, T]
return: [B, C, F // num_groups, T]
"""
batch_size, _, num_freqs, _ = input.shape
assert (
batch_size > num_groups
), f"Batch size = {batch_size}, num_groups = {num_groups}. The batch size should larger than the num_groups."
if num_groups <= 1:
# No demand for grouping
return input
# Each sample must has the same number of the frequencies for parallel training.
# Therefore, we need to drop those remaining frequencies in the high frequency part.
if num_freqs % num_groups != 0:
input = input[..., : (num_freqs - (num_freqs % num_groups)), :]
num_freqs = input.shape[2]
output = []
for group_idx in range(num_groups):
samples_indices = torch.arange(
group_idx, batch_size, num_groups, device=input.device
)
freqs_indices = torch.arange(group_idx, num_freqs, num_groups, device=input.device)
selected_samples = torch.index_select(input, dim=0, index=samples_indices)
selected = torch.index_select(
selected_samples, dim=2, index=freqs_indices
) # [B, C, F // num_groups, T]
output.append(selected)
return torch.cat(output, dim=0)
class ChannelWiseLayerNorm(nn.LayerNorm):
"""Channel wise layer normalization"""
def __init__(self, *args, **kwargs):
super(ChannelWiseLayerNorm, self).__init__(*args, **kwargs)
def forward(self, x):
"""
x: BS x N x K
"""
if x.dim() != 3:
raise RuntimeError("{} accept 3D tensor as input".format(self.__name__))
# BS x N x K => BS x K x N
x = torch.transpose(x, 1, 2)
x = super(ChannelWiseLayerNorm, self).forward(x)
x = torch.transpose(x, 1, 2)
return x