Skip to content


Merge pull request #3 from createsafe/revert-1-jv/log_spect
Browse files Browse the repository at this point in the history
Revert "Torchified LOG_SPECT"
  • Loading branch information
yocontra committed Apr 18, 2024
2 parents 19b03d1 + 967ec9b commit 3de955e
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 227 deletions.
26 changes: 1 addition & 25 deletions src/BeatNet/
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def __init__(self, model, mode='online', inference_model='PF', plot=[], thread=F
self.log_spec_hop_length = int(20 * 0.001 * self.log_spec_sample_rate)
self.log_spec_win_length = int(64 * 0.001 * self.log_spec_sample_rate)
self.proc = LOG_SPECT(sample_rate=self.log_spec_sample_rate, win_length=self.log_spec_win_length,
hop_size=self.log_spec_hop_length, n_bands=[24])
hop_size=self.log_spec_hop_length, n_bands=[24], mode = self.mode)
if self.inference_model == "PF": # instantiating a Particle Filter decoder - Is Chosen for online inference
self.estimator = particle_filter_cascade(beats_per_bar=[], fps=50, plot=self.plot, mode=self.mode)
elif self.inference_model == "DBN": # instantiating an HMM decoder - Is chosen for offline inference
Expand Down Expand Up @@ -214,30 +214,6 @@ def activation_extractor_online(self, audio_path):
preds = preds.cpu().detach().numpy()
preds = np.transpose(preds[:2, :])
return preds

def process_offline(self, audio: torch.Tensor, sample_rate: int):
audio (torch.Tensor): audio signal where audio.shape = (1, N)
sample_rate (int): sampling frequency (32000, 44100, 48000, etc)

with torch.no_grad():
if sample_rate != self.sample_rate and isinstance(audio, np.ndarray):
audio = librosa.resample(y=audio, orig_sr=sample_rate, target_sr=self.sample_rate)
elif sample_rate != self.sample_rate and isinstance(audio, torch.Tensor):
audio = torchaudio.functional.resample(waveform=audio, orig_freq=sample_rate, new_freq=self.sample_rate)

feats = self.proc.process_audio(audio).T
feats = torch.permute(feats, (2, 0, 1))
# feats = torch.from_numpy(feats)
# feats = feats.unsqueeze(0).to(self.device)
feats =
preds = self.model(feats)[0] # extracting the activations by passing the feature through the NN
preds = self.model.final_pred(preds)
preds = preds.cpu().detach().numpy()
preds = np.transpose(preds[:2, :])
return self.estimator(preds)

def process_offline(self, audio: Iterable, sample_rate: int) -> np.ndarray:
with torch.no_grad():
Expand Down
238 changes: 36 additions & 202 deletions src/BeatNet/
Original file line number Diff line number Diff line change
@@ -1,209 +1,43 @@
# feature extractor that extracts magnitude spectrogram and its differences
from typing import Iterable
import pprint

import librosa
import torch
import torchaudio
import numpy as np
import matplotlib.pyplot as plt

# torch.set_printoptions(profile="full")

def log_frequencies(bands_per_octave: int, fmin: float, fmax: float, fref: float=440):
    """
    Returns frequencies aligned on a logarithmic frequency scale.

    Parameters
    ----------
    bands_per_octave : int
        Number of filter bands per octave.
    fmin : float
        Minimum frequency [Hz].
    fmax : float
        Maximum frequency [Hz].
    fref : float, optional
        Tuning frequency [Hz].

    Returns
    -------
    log_frequencies : torch.Tensor
        Logarithmically spaced frequencies [Hz], all within [fmin, fmax].

    Notes
    -----
    If `bands_per_octave` = 12 and `fref` = 440 are used, the frequencies are
    equivalent to MIDI notes.
    """
    # get the range (in band steps relative to fref)
    left = np.floor(np.log2(float(fmin) / fref) * bands_per_octave)
    right = np.ceil(np.log2(float(fmax) / fref) * bands_per_octave)
    # generate frequencies: one band step is 1 / bands_per_octave of an octave
    steps = torch.arange(float(left), float(right))
    frequencies = fref * 2. ** (steps / float(bands_per_octave))
    # filter frequencies
    # needed, because range might be bigger because of the use of floor/ceil
    frequencies = frequencies[torch.searchsorted(frequencies, fmin):]
    frequencies = frequencies[:torch.searchsorted(frequencies, fmax, right=True)]
    # return
    return frequencies

def frequencies2bins(frequencies, bin_frequencies, unique_bins=False):
    """
    Map frequencies to the closest corresponding bins.

    Parameters
    ----------
    frequencies : numpy array
        Input frequencies [Hz].
    bin_frequencies : numpy array
        Frequencies of the (FFT) bins [Hz]. Looked up with `searchsorted`,
        which assumes the values are sorted in ascending order.
    unique_bins : bool, optional
        Return only unique bins, i.e. remove all duplicate bins resulting from
        insufficient resolution at low frequencies.

    Returns
    -------
    bins : numpy array
        Corresponding (unique) bins.

    Notes
    -----
    It can be important to return only unique bins, otherwise the lower
    frequency bins can be given too much weight if all bins are simply summed
    up (as in the spectral flux onset detection).
    """
    # cast as numpy arrays
    frequencies = np.asarray(frequencies)
    bin_frequencies = np.asarray(bin_frequencies)
    # map the frequencies to the closest bins:
    # find the insertion point, then clip it so that both the left and the
    # right neighbour are valid indices
    indices = bin_frequencies.searchsorted(frequencies)
    indices = np.clip(indices, 1, len(bin_frequencies) - 1)
    left = bin_frequencies[indices - 1]
    right = bin_frequencies[indices]
    # step back by one wherever the left neighbour is strictly closer
    indices -= frequencies - left < right - frequencies
    # only keep unique bins if requested
    if unique_bins:
        indices = np.unique(indices)
    # return the (unique) bin indices of the closest matches
    return indices
# Author: Mojtaba Heydari <>

def triangular_filter(channels, bins, fft_size, overlap=True, normalize=True):

num_filters = len(bins) - 2
filters = torch.zeros(size=[num_filters, fft_size])

for n in range(num_filters):
# get start, center and stop bins
start, center, stop = bins[n:n+3]
if not overlap:
start = int(np.floor((center + start)) / 2)
stop = int(np.ceil((center + stop)) / 2)
from import SignalProcessor, FramedSignalProcessor
from import ShortTimeFourierTransformProcessor
from import (
FilteredSpectrogramProcessor, LogarithmicSpectrogramProcessor,
from madmom.processors import ParallelProcessor, SequentialProcessor
from BeatNet.common import *

if stop - start < 2:
center = start
stop = start + 1

filters[n, start:center] = torch.linspace(start=0, end=(1 - (1 / (center-start))), steps=center-start)
filters[n, center:stop] = torch.linspace(start=1, end=(0 + (1 / (center-start))), steps=stop-center)

if normalize:
filters = torch.div(filters.T, filters.sum(dim=1)).T

filters = filters.repeat(channels, 1, 1)

return filters

def log_magnitude(spectrogram: torch.Tensor,
                  mul: float,
                  addend: float):
    """
    Scale a spectrogram logarithmically.

    Parameters
    ----------
    spectrogram : torch.Tensor
        Input values.
    mul : float
        Factor the values are multiplied with before taking the logarithm.
    addend : float
        Value added after the multiplication and before the logarithm.

    Returns
    -------
    torch.Tensor
        Element-wise `log10(spectrogram * mul + addend)`.
    """
    return torch.log10((spectrogram * mul) + addend)
# feature extractor that extracts magnitude spectrogram and its differences

class LOG_SPECT():
def __init__(self, *,
sample_rate: int=48000,
win_length: int=2048,
hop_size: int=512,
n_bands: Iterable[int]=12,
fmin: float=30,
fmax: float=17000,
channels: int=1,
unique_bins: bool=True):

class LOG_SPECT(FeatureModule):
def __init__(self, num_channels=1, sample_rate=22050, win_length=2048, hop_size=512, n_bands=[12], mode='online'):
sig = SignalProcessor(num_channels=num_channels, win_length=win_length, sample_rate=sample_rate)
self.sample_rate = sample_rate
self.fft_size = win_length
self.hop_size = hop_size
self.fmin = fmin
self.fmax = fmax
self.channels = channels
if isinstance(n_bands, Iterable):
self.num_bands_per_octave = n_bands[0]
self.num_bands_per_octave = n_bands

# get log spaced frequencies
self.freqs = log_frequencies(bands_per_octave=self.num_bands_per_octave,

# use double fft_size so that dims match when negative
self._spectrogram_processor = lambda signal : torch.stft(signal,
self._fft_freqs = np.linspace(0, self.sample_rate/2, self.fft_size//2)
self._bins = frequencies2bins(self.freqs, self._fft_freqs, unique_bins)
self._filters = triangular_filter(self.channels, self._bins, self.fft_size//2)

def process_audio(self, signal: torch.Tensor):
assert len(signal.shape) == 2, "signal must have dimensions [num_channels, num_samples]"
assert signal.shape[0] == self.channels, f"signal has {signal.shape[0]} channels but this object has {self.channels}"
spectrogram = self._spectrogram_processor(signal).abs()
spectrogram = spectrogram[:, :self.fft_size//2, :]
filtered = torch.matmul(self._filters, spectrogram)
result = log_magnitude(filtered, 1, 1)
diff = torch.diff(result, dim=2, prepend=torch.zeros((result.shape[0], result.shape[1], 1)))
diff *= (diff > 0).to(diff.dtype)
result =, diff), dim=1)
return result

if __name__ == '__main__':
# test
import matplotlib.pyplot as plt

def square(t: torch.Tensor,
           period_ms: float) -> torch.Tensor:
    """
    Build a 0/1 gate that is on for `period_ms`, then off for `period_ms`.

    Parameters
    ----------
    t : torch.Tensor
        1-D, uniformly spaced time stamps in seconds (at least two samples).
    period_ms : float
        Length of each on (and each off) segment in milliseconds.

    Returns
    -------
    torch.Tensor
        Tensor shaped like `t` holding 1 inside the on segments, 0 elsewhere.

    Raises
    ------
    ValueError
        If `t` is not increasing or `period_ms` is shorter than one sample.
    """
    # the sample rate is the reciprocal of the time step; the subtraction has
    # to happen before the division, otherwise the result is only right for
    # t[0] == 0
    dt = float(t[1] - t[0])
    if dt <= 0:
        raise ValueError("t must be strictly increasing")
    # round instead of truncating so that float error in dt cannot shave a
    # whole sample off the rate
    sample_rate = round(1.0 / dt)
    sample_period = int((period_ms / 1000) * sample_rate)
    if sample_period < 1:
        # a zero-length segment would never advance through the signal
        raise ValueError("period_ms must span at least one sample")

    result = torch.zeros_like(t)
    # iterate over segment starts; slicing clamps the end, so a final partial
    # on segment is filled as well
    for start in range(0, len(t), 2 * sample_period):
        result[start:start + sample_period] = 1
    return result

sample_rate = 22050
t = torch.linspace(0, 3, sample_rate*3)
signal = torch.cos(t * 440 * 2 * np.pi)
audio = signal * square(t, 500)

# plt.plot(t, audio)

audio = audio.unsqueeze(dim=0)
spec = LOG_SPECT(channels=1, win_length=4096, hop_size=256)
spectrogram = spec.process_audio(audio)
self.hop_length = hop_size
self.num_channels = num_channels
multi = ParallelProcessor([])
frame_sizes = [win_length]
num_bands = n_bands
for frame_size, num_bands in zip(frame_sizes, num_bands):
if mode == 'online' or mode == 'offline':
frames = FramedSignalProcessor(frame_size=frame_size, hop_size=hop_size)
else: # for real-time and streaming modes
frames = FramedSignalProcessor(frame_size=frame_size, hop_size=hop_size, num_frames=4)
stft = ShortTimeFourierTransformProcessor() # caching FFT window
filt = FilteredSpectrogramProcessor(
num_bands=num_bands, fmin=30, fmax=17000, norm_filters=True)
spec = LogarithmicSpectrogramProcessor(mul=1, add=1)
diff = SpectrogramDifferenceProcessor(
diff_ratio=0.5, positive_diffs=True, stack_diffs=np.hstack)
# process each frame size with spec and diff sequentially
multi.append(SequentialProcessor((frames, stft, filt, spec, diff)))
# stack the features and processes everything sequentially
self.pipe = SequentialProcessor((sig, multi, np.hstack))

def process_audio(self, audio):
feats = self.pipe(audio)
return feats.T

plt.pcolormesh(spectrogram[0, :])

0 comments on commit 3de955e

Please sign in to comment.