Resources:

[https://pytorch.org/audio/stable/tutorials/audio_io_tutorial.html]

[https://www.kaggle.com/code/enrcdamn/tempo-estimation-and-beat-tracking-pipeline]

[https://lo.calho.st/posts/numpy-spectrogram/]


Install the necessary packages with:

In [None]:
!pip install -r requirements.txt

In [None]:
import torch
import torchaudio
import torchaudio.transforms as T
import matplotlib
import matplotlib.pyplot as plt
import numpy as np

# Print the installed torch and torchaudio versions.
print(torch.__version__)
print(torchaudio.__version__)

In [None]:
# FFT size handed to the spectrogram transforms as n_fft.
fft_window_size = 256 # Should be a power of two (2^x)

In [None]:
# Load the audio file; torchaudio.load returns the waveform and its sample rate.
waveform, sample_rate = torchaudio.load('superhero_64kbps.mp3')
print("Sample Rate = ", sample_rate)

In [None]:
# Duration in seconds: size of the waveform's second axis divided by the sample rate.
audio_length = waveform.shape[1] / sample_rate
print("audio length (seconds) = ", audio_length)
# Same duration split into whole minutes and the remaining seconds.
print("audio length (mins and secs) = ", f"{audio_length//60:.0f}", "m", f"{audio_length%60:.2f}", "s")

In [None]:
def plot_waveform(waveform, sample_rate):
    """Plot every channel of ``waveform`` against time, one subplot per channel.

    Args:
        waveform: Tensor whose shape unpacks as (channels, frames); it is
            converted with ``.numpy()`` before plotting.
        sample_rate: Samples per second, used to turn frame indices into
            seconds on the x axis.
    """
    samples = waveform.numpy()
    channel_count, frame_count = samples.shape
    seconds = torch.arange(0, frame_count) / sample_rate

    # squeeze=False always yields a 2-D grid of axes, so the single-channel
    # case needs no special handling.
    fig, axes_grid = plt.subplots(channel_count, 1, squeeze=False)
    for number, (ax, channel) in enumerate(zip(axes_grid[:, 0], samples), start=1):
        ax.plot(seconds, channel, linewidth=1)
        ax.grid(True)
        if channel_count > 1:
            ax.set_ylabel(f"Channel {number}")
    fig.suptitle("waveform")

In [None]:
# Plot the waveform of the loaded audio.
plot_waveform(waveform, sample_rate)

In [None]:
def plot_specgram(waveform, sample_rate, title="Spectrogram", start_time=0, end_time=None):
    """Plot a spectrogram of each channel, optionally limited to a time window.

    Args:
        waveform: Tensor whose shape unpacks as (channels, frames); it is
            converted with ``.numpy()`` before plotting.
        sample_rate: Samples per second; passed to ``specgram`` as ``Fs`` and
            used to convert the window times into sample indices.
        title: Figure title.
        start_time: Window start in seconds (may be fractional); ``None`` is
            treated as the first sample.
        end_time: Window end in seconds (may be fractional); ``None`` plots
            through the last sample.
    """
    waveform = waveform.numpy()

    num_channels, num_frames = waveform.shape
    print("Num Channels = ", num_channels)
    print("Num Frames = ", num_frames)

    # The window is the same for every channel, so compute it once.  Slice
    # bounds must be integers: int() keeps fractional times (e.g. 1.5 s) from
    # producing float indices, which would raise a TypeError when slicing.
    first_sample = 0 if start_time is None else int(sample_rate * start_time)
    last_sample = num_frames if end_time is None else int(sample_rate * end_time)

    figure, axes = plt.subplots(num_channels, 1)
    if num_channels == 1:
        axes = [axes]
    for c in range(num_channels):
        axes[c].specgram(waveform[c][first_sample:last_sample], Fs=sample_rate)
        if num_channels > 1:
            axes[c].set_ylabel(f"Channel {c+1}")
    figure.suptitle(title)

In [None]:
# Spectrogram of the window from 15 s to 20 s.
plot_specgram(waveform, sample_rate, start_time=15, end_time=20)

Thoughts:

Idea 1: Compute the average frequency of the whole song.  Time frames at which the average frequency is significantly lower than the song-wide average are considered 'drum beats'.

Idea 2: Compute the average power of the whole song.  Time frames at higher power are 'drum beats'.

Idea 3: Simple power threshold as a % of the max power.  (This is done below 17/07/2024)

In [None]:
def compute_power(waveform, sample_rate, fft_window_size):
    """Sum the spectrogram over frequency to get one value per time frame.

    Only the first channel is analysed.  As a side effect the result is
    plotted with ``plt.plot`` and its maximum is printed.

    Args:
        waveform: Audio tensor passed directly to ``T.Spectrogram``.
        sample_rate: Not used by the computation; kept so existing callers
            keep working.
        fft_window_size: Value handed to ``T.Spectrogram`` as ``n_fft``.

    Returns:
        A list with one summed spectrogram value per time frame.
    """
    spectrogram_machine = T.Spectrogram(n_fft=fft_window_size)
    spec = spectrogram_machine(waveform).numpy()
    print("(Channel, freq, time) = ", spec.shape)

    # Reduce over the frequency axis with a single vectorised call rather
    # than a nested Python loop over every (freq, time) cell.
    # NOTE(review): the last frequency bin is excluded from the sum —
    # confirm this is intended.
    power = list(spec[0, :-1, :].sum(axis=0))

    plt.plot(power)
    print(max(power))
    return power

In [None]:
# Per-frame summed power of the loaded audio (the call also plots it and prints the maximum).
power = compute_power(waveform, sample_rate, fft_window_size)

In [None]:
def plot_power(power, audio_length, start_time=None, end_time=None):
    """Plot per-frame power against time, optionally limited to a time window.

    Args:
        power: Sequence of per-frame power values; its frames are mapped
            evenly onto ``audio_length``.
        audio_length: Total audio duration in seconds.
        start_time: Window start in seconds; ``None`` starts at the first frame.
        end_time: Window end in seconds; ``None`` runs to the last frame.
    """
    num_frames = len(power)

    def frame_index(time, default):
        # Convert a time in seconds to a frame index.  Clamping to
        # [0, num_frames] keeps the time axis and the sliced data the same
        # length even when the requested window reaches outside the audio.
        if time is None:
            return default
        return min(max(int((time / audio_length) * num_frames), 0), num_frames)

    start_index = frame_index(start_time, 0)
    end_index = frame_index(end_time, num_frames)

    print(start_index, end_index)

    t = np.arange(start_index, end_index) * audio_length / num_frames
    plt.plot(t, power[start_index:end_index])

    plt.xlabel("Time (sec)")
    plt.ylabel("Integrated Power")

In [None]:
# Plot the power curve for the first 10 seconds.
plot_power(power, audio_length, 0, 10)

A very simple beat detector that flags time frames whose power exceeds a given fraction of the maximum power.  It does not work well if the song has phases/sections that differ significantly from each other.

In [None]:
def detect_beats_from_power(power, audio_length, power_threshold_factor, cooldown_time=0.6):
    """Return the times at which the power rises above a fraction of its maximum.

    A frame counts as a beat when its power is strictly greater than
    ``power_threshold_factor * max(power)``.  After each detected beat, the
    following ``cooldown_time`` seconds worth of frames are skipped so a
    single peak is not reported several times.

    Args:
        power: Sequence of per-frame power values; its frames are mapped
            evenly onto ``audio_length``.
        audio_length: Total audio duration in seconds.
        power_threshold_factor: Fraction of the maximum power used as the
            detection threshold.
        cooldown_time: Seconds to ignore after a detected beat.

    Returns:
        A list of beat times in seconds; empty when ``power`` is empty.
    """
    num_frames = len(power)
    if num_frames == 0:
        # Nothing to analyse; max() and the per-frame duration are undefined.
        return []

    power_threshold = power_threshold_factor * max(power)
    time_per_fft_frame = audio_length / num_frames
    cooldown_frames = int(cooldown_time / time_per_fft_frame)
    print("cooldown frames = ", cooldown_frames)

    beat_times = []
    cooldown_ticker = 0

    for i, frame_power in enumerate(power):
        if cooldown_ticker > 0:
            # Still inside the dead time that follows the previous beat.
            cooldown_ticker -= 1
            continue

        if frame_power > power_threshold:
            beat_times.append(time_per_fft_frame * i)
            cooldown_ticker = cooldown_frames

    return beat_times

In [None]:
# Detect beats using a threshold of 80% of the maximum power.
beat_times = detect_beats_from_power(power, audio_length, 0.80)

print("Number of beats detected = ", len(beat_times))
print(beat_times)

Try working with the frequencies in the spectrogram instead; maybe a better beat detector can be made that way.

In [None]:
def compute_avg_freq(waveform, sample_rate, fft_window_size):
    """Compute the spectrogram-weighted mean frequency of each time frame.

    Only the first channel is analysed.  Each frequency bin ``k`` is assigned
    the frequency ``k * sample_rate / fft_window_size`` and weighted by the
    spectrogram value in that bin.

    Args:
        waveform: Audio tensor passed directly to ``T.Spectrogram``.
        sample_rate: Samples per second, used to convert bin indices to Hz.
        fft_window_size: Value handed to ``T.Spectrogram`` as ``n_fft``.

    Returns:
        A list with one mean frequency per analysed time frame.  Frames whose
        weights sum to zero yield ``nan``.
    """
    spectrogram_machine = T.Spectrogram(n_fft=fft_window_size)
    spec = spectrogram_machine(waveform).numpy()
    print("(Channel, freq, time) = ", spec.shape)

    # NOTE(review): both the last frequency bin and the last time frame are
    # excluded from the analysis — confirm this is intended.
    # Accumulate in float64 so long sums do not lose precision.
    weights = spec[0, :-1, :-1].astype(np.float64)
    bin_freqs = np.arange(weights.shape[0]) * sample_rate / fft_window_size

    # Vectorised replacement for a nested Python loop over (time, freq).
    weighted_sum = bin_freqs @ weights
    norm_factor = weights.sum(axis=0)

    # Divide only where the total weight is non-zero; other frames stay nan
    # instead of triggering a 0/0 runtime warning.
    avg_freq = np.full(norm_factor.shape, np.nan)
    np.divide(weighted_sum, norm_factor, out=avg_freq, where=norm_factor != 0)

    return list(avg_freq)

    

In [None]:
def plot_avg_freq(avg_freq, audio_length, start_time=None, end_time=None):
    """Plot per-frame average frequency against time, optionally windowed.

    Args:
        avg_freq: Sequence of per-frame average frequencies; its frames are
            mapped evenly onto ``audio_length``.
        audio_length: Total audio duration in seconds.
        start_time: Window start in seconds; ``None`` starts at the first frame.
        end_time: Window end in seconds; ``None`` runs to the last frame.
    """
    num_frames = len(avg_freq)

    def frame_index(time, default):
        # Convert a time in seconds to a frame index.  Clamping to
        # [0, num_frames] keeps the time axis and the sliced data the same
        # length even when the requested window reaches outside the audio.
        if time is None:
            return default
        return min(max(int((time / audio_length) * num_frames), 0), num_frames)

    start_index = frame_index(start_time, 0)
    end_index = frame_index(end_time, num_frames)

    print(start_index, end_index)

    t = np.arange(start_index, end_index) * audio_length / num_frames
    plt.plot(t, avg_freq[start_index:end_index])

    plt.xlabel("Time (sec)")
    plt.ylabel("Average Frequency (Hz)")

In [None]:
# Per-frame average frequency of the loaded audio.
avg_freq = compute_avg_freq(waveform, sample_rate, fft_window_size)

In [None]:
# Plot the average frequency over the whole track.
plot_avg_freq(avg_freq, audio_length)

In [None]:
# Plot the average frequency for the first 10 seconds.
plot_avg_freq(avg_freq, audio_length, start_time=0, end_time=10)