Skip to content

Commit

Permalink
Dtype inference and control in CQT (#1171)
Browse files Browse the repository at this point in the history
* added dtype inference helpers

* added dtype control to CQT module

* rewrote trim_stack

* fixed a bug in trim stack

* fixed a bug in trim stack

* hacking on hybrid_cqt

* fixed bug in trim_stack

* more bugfixing trim_stack

* cut down test fixture durations for cqt

* added test case for cqt dtype
  • Loading branch information
bmcfee committed Jun 22, 2020
1 parent 3cab0b6 commit 86cd2c1
Show file tree
Hide file tree
Showing 6 changed files with 275 additions and 53 deletions.
109 changes: 75 additions & 34 deletions librosa/core/constantq.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
def cqt(y, sr=22050, hop_length=512, fmin=None, n_bins=84,
bins_per_octave=12, tuning=0.0, filter_scale=1,
norm=1, sparsity=0.01, window='hann',
scale=True, pad_mode='reflect', res_type=None):
scale=True, pad_mode='reflect', res_type=None, dtype=None):
'''Compute the constant-Q transform of an audio signal.
This implementation is based on the recursive sub-sampling method
Expand Down Expand Up @@ -106,9 +106,13 @@ def cqt(y, sr=22050, hop_length=512, fmin=None, n_bins=84,
but potentially slow FFT-based down-sampling, while `res_type='polyphase'` will
use a fast, but potentially inaccurate down-sampling.
dtype : np.dtype
The (complex) data type of the output array. By default, this is inferred to match
the numerical precision of the input signal.
Returns
-------
CQT : np.ndarray [shape=(n_bins, t), dtype=np.complex or np.float]
CQT : np.ndarray [shape=(n_bins, t)]
Constant-Q value each frequency at each time.
Raises
Expand Down Expand Up @@ -173,14 +177,14 @@ def cqt(y, sr=22050, hop_length=512, fmin=None, n_bins=84,
bins_per_octave=bins_per_octave,
tuning=tuning, filter_scale=filter_scale,
norm=norm, sparsity=sparsity, window=window, scale=scale,
pad_mode=pad_mode, res_type=res_type)
pad_mode=pad_mode, res_type=res_type, dtype=dtype)


@cache(level=20)
def hybrid_cqt(y, sr=22050, hop_length=512, fmin=None, n_bins=84,
bins_per_octave=12, tuning=0.0, filter_scale=1,
norm=1, sparsity=0.01, window='hann', scale=True,
pad_mode='reflect', res_type=None):
pad_mode='reflect', res_type=None, dtype=None):
'''Compute the hybrid constant-Q transform of an audio signal.
Here, the hybrid CQT uses the pseudo CQT for higher frequencies where
Expand Down Expand Up @@ -236,6 +240,11 @@ def hybrid_cqt(y, sr=22050, hop_length=512, fmin=None, n_bins=84,
res_type : string
Resampling mode. See `librosa.core.cqt` for details.
dtype : np.dtype, optional
The complex dtype to use for computing the CQT.
By default, this is inferred to match the precision of
the input signal.
Returns
-------
CQT : np.ndarray [shape=(n_bins, t), dtype=np.float]
Expand Down Expand Up @@ -303,7 +312,8 @@ def hybrid_cqt(y, sr=22050, hop_length=512, fmin=None, n_bins=84,
sparsity=sparsity,
window=window,
scale=scale,
pad_mode=pad_mode))
pad_mode=pad_mode,
dtype=dtype))

if n_bins_full > 0:
cqt_resp.append(np.abs(cqt(y, sr,
Expand All @@ -317,16 +327,18 @@ def hybrid_cqt(y, sr=22050, hop_length=512, fmin=None, n_bins=84,
window=window,
scale=scale,
pad_mode=pad_mode,
res_type=res_type)))
res_type=res_type,
dtype=dtype)))

return __trim_stack(cqt_resp, n_bins)
# Propagate dtype from the last component
return __trim_stack(cqt_resp, n_bins, cqt_resp[-1].dtype)


@cache(level=20)
def pseudo_cqt(y, sr=22050, hop_length=512, fmin=None, n_bins=84,
bins_per_octave=12, tuning=0.0, filter_scale=1,
norm=1, sparsity=0.01, window='hann', scale=True,
pad_mode='reflect'):
pad_mode='reflect', dtype=None):
'''Compute the pseudo constant-Q transform of an audio signal.
This uses a single fft size that is the smallest power of 2 that is greater
Expand Down Expand Up @@ -381,6 +393,10 @@ def pseudo_cqt(y, sr=22050, hop_length=512, fmin=None, n_bins=84,
See also: `librosa.core.stft` and `np.pad`.
dtype : np.dtype, optional
The complex data type for CQT calculations.
By default, this is inferred to match the precision of the input signal.
Returns
-------
CQT : np.ndarray [shape=(n_bins, t), dtype=np.float]
Expand All @@ -407,6 +423,9 @@ def pseudo_cqt(y, sr=22050, hop_length=512, fmin=None, n_bins=84,
if tuning is None:
tuning = estimate_tuning(y=y, sr=sr, bins_per_octave=bins_per_octave)

if dtype is None:
dtype = util.dtype_r2c(y.dtype)

# Apply tuning correction
fmin = fmin * 2.0**(tuning / bins_per_octave)

Expand All @@ -415,12 +434,13 @@ def pseudo_cqt(y, sr=22050, hop_length=512, fmin=None, n_bins=84,
filter_scale,
norm, sparsity,
hop_length=hop_length,
window=window)
window=window,
dtype=dtype)

fft_basis = np.abs(fft_basis)

# Compute the magnitude STFT with Hann window
D = np.abs(stft(y, n_fft=n_fft, hop_length=hop_length, pad_mode=pad_mode))
D = np.abs(stft(y, n_fft=n_fft, hop_length=hop_length, pad_mode=pad_mode, dtype=dtype))

# Project onto the pseudo-cqt basis
C = fft_basis.dot(D)
Expand All @@ -442,7 +462,7 @@ def pseudo_cqt(y, sr=22050, hop_length=512, fmin=None, n_bins=84,
@cache(level=40)
def icqt(C, sr=22050, hop_length=512, fmin=None, bins_per_octave=12,
tuning=0.0, filter_scale=1, norm=1, sparsity=0.01, window='hann',
scale=True, length=None, res_type='fft', dtype=np.float32):
scale=True, length=None, res_type='fft', dtype=None):
'''Compute the inverse constant-Q transform.
Given a constant-Q transform representation `C` of an audio signal `y`,
Expand Down Expand Up @@ -501,7 +521,8 @@ def icqt(C, sr=22050, hop_length=512, fmin=None, bins_per_octave=12,
See `librosa.resample` for supported modes.
dtype : numeric type
Real numeric type for `y`. Default is 32-bit float.
Real numeric type for `y`. Default is inferred to match the numerical
precision of the input CQT.
Returns
-------
Expand Down Expand Up @@ -620,7 +641,7 @@ def icqt(C, sr=22050, hop_length=512, fmin=None, bins_per_octave=12,
def vqt(y, sr=22050, hop_length=512, fmin=None, n_bins=84, gamma=None,
bins_per_octave=12, tuning=0.0, filter_scale=1,
norm=1, sparsity=0.01, window='hann',
scale=True, pad_mode='reflect', res_type=None):
scale=True, pad_mode='reflect', res_type=None, dtype=None):
'''Compute the variable-Q transform of an audio signal.
This implementation is based on the recursive sub-sampling method
Expand Down Expand Up @@ -717,6 +738,10 @@ def vqt(y, sr=22050, hop_length=512, fmin=None, n_bins=84, gamma=None,
but potentially slow FFT-based down-sampling, while `res_type='polyphase'` will
use a fast, but potentially inaccurate down-sampling.
dtype : np.dtype
The dtype of the output array. By default, this is inferred to match the
numerical precision of the input signal.
Returns
-------
VQT : np.ndarray [shape=(n_bins, t), dtype=np.complex or np.float]
Expand Down Expand Up @@ -772,6 +797,9 @@ def vqt(y, sr=22050, hop_length=512, fmin=None, n_bins=84, gamma=None,
if gamma is None:
gamma = 24.7 * alpha / 0.108

if dtype is None:
dtype = util.dtype_r2c(y.dtype)

# Apply tuning correction
fmin = fmin * 2.0**(tuning / bins_per_octave)

Expand Down Expand Up @@ -813,10 +841,11 @@ def vqt(y, sr=22050, hop_length=512, fmin=None, n_bins=84, gamma=None,
norm,
sparsity,
window=window,
gamma=gamma)
gamma=gamma,
dtype=dtype)

# Compute the VQT filter response and append it to the stack
vqt_resp.append(__cqt_response(y, n_fft, hop_length, fft_basis, pad_mode))
vqt_resp.append(__cqt_response(y, n_fft, hop_length, fft_basis, pad_mode, dtype=dtype))

fmin_t /= 2
fmax_t /= 2
Expand All @@ -842,8 +871,7 @@ def vqt(y, sr=22050, hop_length=512, fmin=None, n_bins=84, gamma=None,
if i > 0:
if len(my_y) < 2:
raise ParameterError('Input signal length={} is too short for '
'{:d}-octave CQT/VQT'.format(len_orig,
n_octaves))
'{:d}-octave CQT/VQT'.format(len_orig, n_octaves))

my_y = audio.resample(my_y, 2, 1,
res_type=res_type,
Expand All @@ -859,14 +887,15 @@ def vqt(y, sr=22050, hop_length=512, fmin=None, n_bins=84, gamma=None,
norm,
sparsity,
window=window,
gamma=gamma)
gamma=gamma,
dtype=dtype)
# Re-scale the filters to compensate for downsampling
fft_basis[:] *= np.sqrt(2**i)

# Compute the vqt filter response and append to the stack
vqt_resp.append(__cqt_response(my_y, n_fft, my_hop, fft_basis, pad_mode))
vqt_resp.append(__cqt_response(my_y, n_fft, my_hop, fft_basis, pad_mode, dtype=dtype))

V = __trim_stack(vqt_resp, n_bins)
V = __trim_stack(vqt_resp, n_bins, dtype)

if scale:
lengths = filters.constant_q_lengths(sr, fmin,
Expand All @@ -883,7 +912,7 @@ def vqt(y, sr=22050, hop_length=512, fmin=None, n_bins=84, gamma=None,
@cache(level=10)
def __cqt_filter_fft(sr, fmin, n_bins, bins_per_octave,
filter_scale, norm, sparsity, hop_length=None,
window='hann', gamma=0.):
window='hann', gamma=0., dtype=np.complex):
'''Generate the frequency domain constant-Q filter basis.'''

basis, lengths = filters.constant_q(sr,
Expand Down Expand Up @@ -912,31 +941,42 @@ def __cqt_filter_fft(sr, fmin, n_bins, bins_per_octave,
fft_basis = fft.fft(basis, n=n_fft, axis=1)[:, :(n_fft // 2)+1]

# sparsify the basis
fft_basis = util.sparsify_rows(fft_basis, quantile=sparsity)
fft_basis = util.sparsify_rows(fft_basis, quantile=sparsity, dtype=dtype)

return fft_basis, n_fft, lengths


def __trim_stack(cqt_resp, n_bins):
def __trim_stack(cqt_resp, n_bins, dtype):
'''Helper function to trim and stack a collection of CQT responses'''

# cleanup any framing errors at the boundaries
max_col = min(x.shape[1] for x in cqt_resp)
max_col = min(c_i.shape[-1] for c_i in cqt_resp)
cqt_out = np.empty((n_bins, max_col), dtype=dtype, order='F')

# Copy per-octave data into output array
end = n_bins
for c_i in cqt_resp:
# By default, take the whole octave
n_oct = c_i.shape[0]
# If the whole octave is more than we can fit,
# take the highest bins from c_i
if end < n_oct:
cqt_out[:end] = c_i[-end:, :max_col]
else:
cqt_out[end - n_oct:end] = c_i[:, :max_col]

cqt_resp = np.vstack([x[:, :max_col] for x in cqt_resp][::-1])
end -= n_oct

# Finally, clip out any bottom frequencies that we don't really want
# Transpose magic here to ensure column-contiguity
return np.asfortranarray(cqt_resp[-n_bins:])
return cqt_out


def __cqt_response(y, n_fft, hop_length, fft_basis, mode):
def __cqt_response(y, n_fft, hop_length, fft_basis, mode, dtype=None):
'''Compute the filter response with a target STFT hop.'''

# Compute the STFT matrix
D = stft(y, n_fft=n_fft, hop_length=hop_length,
window='ones',
pad_mode=mode)
pad_mode=mode,
dtype=dtype)

# And filter response energy
return fft_basis.dot(D)
Expand Down Expand Up @@ -1003,7 +1043,7 @@ def __num_two_factors(x):

def griffinlim_cqt(C, n_iter=32, sr=22050, hop_length=512, fmin=None, bins_per_octave=12, tuning=0.0,
filter_scale=1, norm=1, sparsity=0.01, window='hann', scale=True,
pad_mode='reflect', res_type='kaiser_fast', dtype=np.float32,
pad_mode='reflect', res_type='kaiser_fast', dtype=None,
length=None, momentum=0.99, init='random', random_state=None):
'''Approximate constant-Q magnitude spectrogram inversion using the "fast" Griffin-Lim
algorithm [1]_ [2]_.
Expand Down Expand Up @@ -1091,7 +1131,8 @@ def griffinlim_cqt(C, n_iter=32, sr=22050, hop_length=512, fmin=None, bins_per_o
See `librosa.core.resample` for a list of available options.
dtype : numeric type
Real numeric type for `y`. Default is 32-bit float.
Real numeric type for `y`. Default is inferred to match the precision
of the input CQT.
length : int > 0, optional
If provided, the output `y` is zero-padded or clipped to exactly
Expand Down Expand Up @@ -1207,7 +1248,7 @@ def griffinlim_cqt(C, n_iter=32, sr=22050, hop_length=512, fmin=None, bins_per_o
# Rebuild the spectrogram
rebuilt = cqt(inverse, sr=sr, bins_per_octave=bins_per_octave, n_bins=C.shape[0],
hop_length=hop_length, fmin=fmin, tuning=tuning, filter_scale=filter_scale,
window=window, res_type=res_type)
window=window, res_type=res_type, dtype=C.dtype)

# Update our phase estimates
angles[:] = rebuilt - (momentum / (1 + momentum)) * tprev
Expand Down
Loading

0 comments on commit 86cd2c1

Please sign in to comment.