In [1]:
# @title multi-threaded sample analysis system

import os
import librosa
import numpy as np
import matplotlib.pyplot as plt
from brian2 import *
from multiprocessing import Pool, cpu_count
from tqdm import tqdm
from concurrent.futures import ThreadPoolExecutor
import pickle
from concurrent.futures import ProcessPoolExecutor
from sklearn.utils import shuffle
import random

try:
    import ipywidgets as widgets
except ImportError:
    !pip3 install ipywidgets
    import ipywidgets as widgets

from IPython.display import display, clear_output

# Initial number of time frames features are cropped/padded to; this value is
# overwritten further down in this cell by the result of determine_fixed_length().
fixed_timesteps = 1001
# One sub-directory per time signature; a sample's label is its index in this list.
sub_dirs = ['1_4', '2_4', '3_4', '4_4', '5_4', '7_8']
# Upper bound on the number of files read from each sub-directory.
FILES_TO_LOAD = 20

def get_length(file_path):
    """Return the number of MFCC frames librosa produces for one audio file.

    The file is loaded with librosa's default sampling rate and the size of
    the second axis of the MFCC matrix is returned.
    """
    signal, sample_rate = librosa.load(file_path)
    return librosa.feature.mfcc(y=signal, sr=sample_rate).shape[1]

def determine_fixed_length(directory):
    """Return the smallest MFCC frame count among the sampled files.

    For every entry of ``sub_dirs`` the first ``FILES_TO_LOAD`` names returned
    by ``os.listdir`` are collected (with a tqdm progress bar per
    sub-directory); their lengths are then measured in worker processes.
    """
    audio_paths = []
    for subdir in sub_dirs:
        selected = os.listdir(os.path.join(directory, subdir))[:FILES_TO_LOAD]
        audio_paths.extend(
            os.path.join(directory, subdir, name) for name in tqdm(selected)
        )

    # Measuring each file means decoding it, so spread the work over processes.
    with ProcessPoolExecutor() as executor:
        frame_counts = list(executor.map(get_length, audio_paths))

    return min(frame_counts)

def parallel_data_loader(directories):
    """Load several dataset directories concurrently.

    Each directory is handed to ``parallel_load_and_preprocess`` on its own
    thread. The returned list holds one ``(data, labels)`` result per
    directory, in the same order as ``directories``.
    """
    with ThreadPoolExecutor() as executor:
        per_directory = executor.map(parallel_load_and_preprocess, directories)
        return list(tqdm(per_directory, total=len(directories)))

def load_and_preprocess_data_subdir(args):
    """Load the audio files of one class sub-directory.

    args: a ``(directory, subdir)`` pair (packed into one argument so the
        function can be used with ``Pool.imap_unordered``).
    Returns ``(data, labels)`` where ``data`` holds the ``load_audio`` result
    of each file and ``labels`` repeats the index of ``subdir`` in ``sub_dirs``.
    """
    directory, subdir = args

    # At most FILES_TO_LOAD entries of the listing are used.
    selected = os.listdir(os.path.join(directory, subdir))[:FILES_TO_LOAD]

    data = [load_audio(os.path.join(directory, subdir, name)) for name in selected]
    labels = [sub_dirs.index(subdir) for _ in data]

    return data, labels

def parallel_load_and_preprocess(directory):
    """Load every class sub-directory of ``directory`` in a process pool.

    One task per entry of ``sub_dirs`` is dispatched to
    ``load_and_preprocess_data_subdir``. Tasks complete in arbitrary order,
    but every task returns its own labels, so data and labels stay aligned.

    Returns ``(combined_data, combined_labels)``: the concatenation of all
    per-sub-directory results.
    """
    # One task per time-signature sub-directory.
    tasks = [(directory, time_sig) for time_sig in sub_dirs]

    # The context manager guarantees the worker processes are released even
    # when a task raises. All results are materialised inside the block, so
    # nothing is still pending when the pool is torn down.
    with Pool(cpu_count()) as pool:
        results = list(
            tqdm(
                pool.imap_unordered(load_and_preprocess_data_subdir, tasks),
                total=len(tasks),
                mininterval=0.01,
            )
        )

    # Flatten the per-sub-directory results.
    combined_data = []
    combined_labels = []
    for data, labels in results:
        combined_data.extend(data)
        combined_labels.extend(labels)

    return combined_data, combined_labels


def adjust_fixed_length(features, timesteps):
    """Crop or zero-pad ``features`` so its time axis has ``timesteps`` entries.

    A 1-D array is adjusted along its only axis; any other array is adjusted
    along axis 1. An array that already has the requested length is returned
    unchanged (the same object).
    """
    is_vector = len(features.shape) == 1
    time_axis = 0 if is_vector else 1
    current = features.shape[time_axis]

    if current == timesteps:
        return features

    if current > timesteps:
        # Too long: keep the leading part.
        return features[:timesteps] if is_vector else features[:, :timesteps]

    # Too short: append zeros at the end of the time axis.
    missing = timesteps - current
    pad_shape = missing if is_vector else (features.shape[0], missing)
    return np.hstack((features, np.zeros(pad_shape)))

# Turn real-valued features into a random 0/1 spike train
def poisson_spike_encoding(data, duration=10, dt=1*ms):
    """Draw one Bernoulli spike per element of ``data``.

    ``data`` is expected to be normalized to [0, 1]; each value is scaled to
    a rate (``data / dt``) and back to a per-step probability (``rate * dt``).
    ``duration`` is accepted but not used. Returns a float array of 0.0/1.0
    with the shape of ``data``.
    """
    firing_rates = data * (1.0/dt)
    spike_probability = firing_rates*dt
    uniform_draws = np.random.rand(*data.shape)
    return (uniform_draws < spike_probability).astype(float)

def temporal_binning(data, bin_size):
    """
    Average consecutive chunks of ``bin_size`` samples.

    The last chunk may be shorter than ``bin_size``; it is averaged over the
    samples it actually contains. Returns a NumPy array with one mean per chunk.
    """
    chunk_means = []
    for start in range(0, len(data), bin_size):
        chunk = data[start:start + bin_size]
        chunk_means.append(np.mean(chunk))
    return np.array(chunk_means)

def rate_based_encoding(data, min_freq, max_freq):
    """
    Map normalized values linearly onto a spike-frequency range.

    data: input values, expected in [0, 1]
    min_freq: frequency produced for a data value of 0
    max_freq: frequency produced for a data value of 1
    Returns: ``min_freq + data * (max_freq - min_freq)``
    """
    frequency_span = max_freq - min_freq
    return min_freq + data * frequency_span

def extract_bpm_and_instrument(file_path):
    """Parse the generation parameters encoded in a sample's file name.

    Looks for ``instrument_<int>_bpm_<int>_rotation_<int>_duration_<int>_noise_<number>``
    anywhere in ``file_path``.

    Returns ``(instrument, bpm, duration, noise)`` as strings, or
    ``(None, None, None, None)`` when the pattern is not present.
    """
    # Imported here because the module-level imports do not bring in `re` explicitly.
    import re

    # The noise value is an integer or a decimal number. Matching it as a
    # number (rather than any run of digits and dots) keeps the dot that
    # starts the file extension out of the captured value.
    match = re.search(
        r"instrument_(\d+)_bpm_(\d+)_rotation_\d+_duration_(\d+)_noise_(\d*\.?\d+)",
        file_path,
    )
    if match is None:
        return None, None, None, None

    instrument, bpm, duration, noise = match.groups()
    return instrument, bpm, duration, noise

def moving_average(data, window_size):
    """Smooth ``data`` with a uniform window of ``window_size`` samples.

    Uses a 'valid' convolution, so only positions where the window fully
    overlaps the data are returned.
    """
    uniform_kernel = np.ones(window_size) / window_size
    return np.convolve(data, uniform_kernel, mode='valid')


def load_audio(file_path):
    """Load one audio file and return ``[samples, sampling_rate, file_path]``.

    The file is loaded at a fixed rate of 22050 Hz so that every file ends up
    at the same sampling rate.
    """
    samples, sampling_rate = librosa.load(file_path, sr=22050)
    return [samples, sampling_rate, file_path]

# Process the audio file into desired features
def preprocess_audio(file_path):
    """Extract the normalized onset strength of a file and save a diagnostic plot.

    The file is loaded at 22050 Hz. A six-panel figure (raw audio, fixed-size
    onset strength, normalized/cropped onset strength, its moving average, and
    two tempograms) is written to ``output_processing_noise_avg/``.

    Returns the min-max normalized onset strength as a column array with its
    first 20 frames dropped.
    """
    # Make sure the plotting submodule is loaded; it is used below.
    import librosa.display

    y, sr = librosa.load(file_path, sr=22050)  # setting sr ensures all files are resampled to this rate
    time_signature = file_path.split('/')[-2].replace('_', '/')
    # The parser returns four fields; duration and noise are not needed here.
    instrument, bpm, _duration, _noise = extract_bpm_and_instrument(file_path)

    # Extracting onset strength
    onset_strength = librosa.onset.onset_strength(y=y, sr=sr)

    # Tempogram of the full envelope and of the envelope without its first 20 frames
    tempogram = librosa.feature.tempogram(onset_envelope=onset_strength, sr=sr)
    tempogram_cropped = librosa.feature.tempogram(onset_envelope=onset_strength[20:], sr=sr)

    # Adjust the time axis of each feature to fixed_timesteps
    onset_strength_fixed = adjust_fixed_length(onset_strength, fixed_timesteps)
    tempogram_fixed = adjust_fixed_length(tempogram, fixed_timesteps)

    # Turn the 1-D onset envelope into a column array
    combined_features = np.vstack(onset_strength)

    # Normalize to range [0, 1]
    encoded_features = (combined_features - np.min(combined_features)) / (np.max(combined_features) - np.min(combined_features))

    # Plotting
    plt.figure(figsize=(12, 14))
    # A figure-level title: it must not create an axes of its own underneath
    # the subplot grid, and the placeholders have to be interpolated.
    plt.suptitle(f'audio  with {time_signature} time signature, {bpm} bpm, and instrument {instrument}')

    rows = 6
    # 1. Raw audio
    plt.subplot(rows, 1, 1)
    librosa.display.waveshow(y, sr=sr)
    plt.title('Raw Audio')

    # 2. Onset strength, cropped/padded to the fixed length
    plt.subplot(rows, 1, 2)
    plt.plot(onset_strength_fixed)
    plt.title('Onset Strength fixed size')

    # 3. Onset strength without the first 20 frames, min-max normalized
    plt.subplot(rows, 1, 3)
    cropped_onset = onset_strength[20:]
    onset_strength_normalized = (cropped_onset - np.min(cropped_onset)) / (np.max(cropped_onset) - np.min(cropped_onset))
    plt.plot(onset_strength_normalized)
    plt.title('Onset Strength normalized and cropped')

    # 4. Moving average (window of 5 frames) of the normalized onset strength
    plt.subplot(rows, 1, 4)
    averaged_onset = moving_average(onset_strength_normalized, window_size=5)
    plt.plot(averaged_onset)
    plt.title('Averaged Onset Strength')

    # 5. Tempogram at the fixed length
    plt.subplot(rows, 1, 5)
    librosa.display.specshow(tempogram_fixed, sr=sr, x_axis='time', y_axis='tempo')
    plt.title('Tempogram fixed')

    # 6. Tempogram of the cropped envelope
    plt.subplot(rows, 1, 6)
    librosa.display.specshow(tempogram_cropped, sr=sr, x_axis='time', y_axis='tempo')
    plt.title('Tempogram cropped')

    plt.tight_layout()
    # savefig does not create missing directories.
    os.makedirs('output_processing_noise_avg', exist_ok=True)
    plt.savefig(f'output_processing_noise_avg/{time_signature.replace("/", "_")}_BPM{bpm}_noise.png')

    return encoded_features[20:]


def count_files(directory):
    """Return the number of files found under ``directory``, recursively."""
    total = 0
    for _root, _subdirs, file_names in os.walk(directory):
        total += len(file_names)
    return total

# Remove spike-train plots left in the working directory
directory = '.'

for filename in os.listdir(directory):
    # Only PNG files whose name mentions 'spike_train' are deleted.
    is_spike_train_plot = 'spike_train' in filename and filename.endswith('.png')
    if not is_spike_train_plot:
        continue

    filepath = os.path.join(directory, filename)
    os.remove(filepath)
    # '\r' keeps the progress message on a single console line.
    print(f"Deleted: {filename}", end='\r')
        

# Determine the common feature length: the smallest MFCC frame count among the
# sampled training files replaces the initial fixed_timesteps value.
print("Checking shapes...")
fixed_timesteps = determine_fixed_length('training_data_dirty_bpm')
print(fixed_timesteps)
# Disabled: also take the validation set into account.
# fixed_timesteps2 = determine_fixed_length('validation_data_dirty_bpm')
# print(fixed_timesteps2)
# fixed_timesteps = max(fixed_timesteps, fixed_timesteps2)


# 1. Load the training and validation directories concurrently.
# parallel_data_loader returns one (data, labels) pair per directory, in the
# order of `directories`.
print("Loading and preprocessing training data...")
directories = ['training_data_dirty_bpm', 'validation_data_dirty_bpm']
training_data_results, validation_data_results = parallel_data_loader(directories)

# Each data item is the [y, sr, file_path] list produced by load_audio.
training_data, training_labels = training_data_results
validation_data, validation_labels = validation_data_results
print("\nDone with preprocessing!")


Checking shapes...


100%|██████████| 20/20 [00:00<00:00, 235635.06it/s]
100%|██████████| 20/20 [00:00<00:00, 203606.99it/s]
100%|██████████| 20/20 [00:00<00:00, 307275.02it/s]
100%|██████████| 20/20 [00:00<00:00, 109655.01it/s]
100%|██████████| 20/20 [00:00<00:00, 294337.12it/s]
100%|██████████| 20/20 [00:00<00:00, 264624.86it/s]


497
Loading and preprocessing training data...


  0%|          | 0/2 [00:00<?, ?it/s]
[A
[A
[A
[A
[A
[A
100%|██████████| 6/6 [00:02<00:00,  2.59it/s]
100%|██████████| 6/6 [00:02<00:00,  2.52it/s]
100%|██████████| 2/2 [00:02<00:00,  1.33s/it]


Done with preprocessing!





In [2]:
class LIFNeuron:
    """Discrete-time leaky integrate-and-fire neuron with values clamped to [0, 1].

    Each call to ``step`` integrates one input sample, emits 0 or 1, and
    optionally adapts the firing threshold: it rises by ``adaptive_increase``
    on a spike and decays by the same amount on every non-spiking step, never
    dropping below the threshold the neuron was created with.
    """

    def __init__(self,
                 tau_m=6.0,  # Membrane Time Constant: Determines the rate at which the membrane potential decays towards its resting value.
                 v_rest=0.0,  # Resting Potential: The stable value of the membrane potential when no external input is present.
                 v_threshold=0.7,  # Firing Threshold: The value of the membrane potential at which the neuron generates a spike.
                 v_reset=0.2,  # Reset Potential: The value to which the membrane potential is reset after a spike is generated.
                 r_m=0.9,  # Membrane Resistance: The effective resistance of the neuron's membrane, modulating the influence of incoming spikes.
                 dt=10.0,  # Time Step: Determines the granularity of the simulation time, influencing the speed of all dynamical variables.
                 adaptive_increase=0.0,  # Adaptive Increase: The value by which the firing threshold increases after each spike.
                 refractory_period=1  # Refractory Period: The number of time steps for which the neuron cannot fire after generating a spike.
                 ):
        self.tau_m = tau_m
        self.v_rest = v_rest
        self.v_threshold = v_threshold
        # Baseline the adaptive threshold decays back to.
        self.v_threshold_base = v_threshold
        self.v_reset = v_reset
        self.r_m = r_m
        self.dt = dt
        self.v = v_rest
        self.adaptive_increase = adaptive_increase
        self.refractory_period = refractory_period
        self.refractory_counter = 0

    def constrain_value(self, value):
        """Clamp ``value`` to the interval [0.0, 1.0]."""
        return min(max(value, 0.0), 1.0)

    def update_voltage(self, i):
        """Advance the membrane potential by one Euler step for input ``i``."""
        dv = (-self.v + self.v_rest + self.r_m * i) / self.tau_m * self.dt
        self.v += dv
        self.v = self.constrain_value(self.v)

    def check_for_spike(self):
        """Return 1 and reset the neuron if the threshold is reached, else 0.

        Also applies the threshold adaptation when ``adaptive_increase`` > 0.
        """
        spike = 0
        if self.v >= self.v_threshold:
            spike = 1
            self.v = self.v_reset

            if self.refractory_period > 0:
                self.refractory_counter = self.refractory_period

            if self.adaptive_increase > 0.0:
                self.v_threshold += self.adaptive_increase
                self.v_threshold = self.constrain_value(self.v_threshold)
        else:
            if self.adaptive_increase > 0.0:
                # Decay towards the baseline. A floor of 1.0 here would pin
                # the threshold at the top of the clamped range after the
                # first quiet step.
                self.v_threshold = max(self.v_threshold - self.adaptive_increase,
                                       self.v_threshold_base)
                self.v_threshold = self.constrain_value(self.v_threshold)

        return spike

    def step(self, i):
        """Process one input sample and return 1 if the neuron spikes, else 0.

        While the refractory counter is positive the input is ignored, the
        counter is decremented and 0 is returned.
        """
        if self.refractory_counter > 0:
            self.refractory_counter -= 1
            return 0

        self.update_voltage(i)

        # Constrain all relevant variables
        self.v_threshold = self.constrain_value(self.v_threshold)
        self.v_reset = self.constrain_value(self.v_reset)
        self.r_m = self.constrain_value(self.r_m)
        self.v_rest = self.constrain_value(self.v_rest)
        self.adaptive_increase = self.constrain_value(self.adaptive_increase)

        return self.check_for_spike()
        
    
def generate_lif_spikes(data, neuron):
    """Drive ``neuron`` with every sample of ``data``.

    Returns two arrays of equal length: the value returned by ``neuron.step``
    for each sample, and ``neuron.v`` as read right after that step.
    """
    # The tuple is evaluated left to right, so `v` is read after the step.
    trace = [(neuron.step(sample), neuron.v) for sample in data]
    spike_train = np.array([spike for spike, _ in trace])
    membrane_potentials = np.array([potential for _, potential in trace])
    return spike_train, membrane_potentials

# Split the training data by time signature. The signature is taken from the
# name of the directory that contains the file ('3_4' -> '3/4').
_signature_groups = {'1/4': [], '2/4': [], '3/4': [], '4/4': [], '5/4': [], '7/8': []}
for _item in training_data:
    _signature = _item[2].split('/')[-2].replace('_', '/')
    if _signature in _signature_groups:
        _signature_groups[_signature].append(_item)

training_data_1_4 = _signature_groups['1/4']
training_data_2_4 = _signature_groups['2/4']
training_data_3_4 = _signature_groups['3/4']
training_data_4_4 = _signature_groups['4/4']
training_data_5_4 = _signature_groups['5/4']
training_data_7_8 = _signature_groups['7/8']

# Report how many samples each time signature has.
for _signature, _items in _signature_groups.items():
    print(f"{_signature}: {len(_items)}")

1/4: 20
2/4: 20
3/4: 20
4/4: 20
5/4: 20
7/8: 20


In [3]:
import scipy.signal
from scipy.signal import savgol_filter


def smooth_using_savgol(data, window_size, polynomial_order=3):
    """Smooth ``data`` with a Savitzky-Golay filter.

    ``window_size`` is the filter window length and ``polynomial_order`` the
    order of the fitted polynomial; both are passed straight to
    ``scipy.signal.savgol_filter``.
    """
    return savgol_filter(
        data,
        window_length=window_size,
        polyorder=polynomial_order,
    )

# Turn real-valued features into a random 0/1 spike train
def poisson_spike_encoding(data, duration=10, dt=1*ms):
    """Draw one Bernoulli spike per element of ``data``.

    ``data`` is expected to be normalized to [0, 1]; each value is scaled to
    a rate (``data / dt``) and back to a per-step probability (``rate * dt``).
    ``duration`` is accepted but not used. Returns a float array of 0.0/1.0
    with the shape of ``data``.
    """
    firing_rates = data * (1.0/dt)
    spike_probability = firing_rates*dt
    uniform_draws = np.random.rand(*data.shape)
    return (uniform_draws < spike_probability).astype(float)

def low_pass_filter(y, sr, cutoff_freq):
    """Apply a zero-phase 6th-order Butterworth low-pass filter to ``y``.

    ``cutoff_freq`` is given in the same unit as ``sr`` and is normalized by
    the Nyquist frequency (``sr / 2``) before the filter is designed.
    """
    nyquist = 0.5 * sr
    numerator, denominator = scipy.signal.butter(
        6, cutoff_freq / nyquist, btype='low', analog=False
    )
    # filtfilt runs the filter forwards and backwards.
    return scipy.signal.filtfilt(numerator, denominator, y)

def high_pass_filter(y, sr, cutoff_freq):
    """Apply a zero-phase 6th-order Butterworth high-pass filter to ``y``.

    ``cutoff_freq`` is given in the same unit as ``sr`` and is normalized by
    the Nyquist frequency (``sr / 2``) before the filter is designed.
    """
    nyquist = 0.5 * sr
    numerator, denominator = scipy.signal.butter(
        6, cutoff_freq / nyquist, btype='high', analog=False
    )
    # filtfilt runs the filter forwards and backwards.
    return scipy.signal.filtfilt(numerator, denominator, y)

def normalize_data(data):
    """Min-max normalize a sequence of values to the range [0, 1].

    Returns a list with one normalized value per input value. An empty input
    yields an empty list. When all values are equal there is no range to
    scale by, so every value maps to 0.0 instead of dividing by zero.
    """
    if len(data) == 0:
        return []

    min_val = min(data)
    max_val = max(data)
    value_range = max_val - min_val
    if value_range == 0:
        return [0.0 for _ in data]

    return [(val - min_val) / value_range for val in data]

def plot_pixel_spectra_norm(item_no, window_size, tau_m, v_rest, v_threshold, v_reset, r_m, dt, high_pass_cutoff):
    """Plot a seven-panel analysis of one training sample and save it.

    Panels: (1) raw audio with spectral flux, (2) MFCCs, (3) mel spectrogram,
    (4) mel spectrogram limited to its first 80 bands, (5) magnitude spectrum
    after high-pass filtering, (6) onset strength with a Poisson spike train,
    (7) onset strength with an LIF spike train.

    item_no: index into ``training_data`` (items are ``[y, sr, file_path]``).
    window_size: only reported in the figure caption; kept so the signature
        matches the interactive controls.
    tau_m, v_rest, v_threshold, v_reset, r_m, dt: LIF neuron parameters;
        ``dt`` (in ms) is also the time step of the Poisson encoding.
    high_pass_cutoff: cutoff frequency in Hz for panel 5.

    The figure is written to ``plot_with_params.png`` and shown.
    """
    # Make sure the plotting submodule is loaded; it is used below.
    import librosa.display

    y = training_data[item_no][0]
    sr = training_data[item_no][1]
    file_path = training_data[item_no][2]

    time_signature = file_path.split('/')[-2].replace('_', '/')
    instrument, bpm, duration, noise = extract_bpm_and_instrument(file_path)
    print(f"time signature: {time_signature} BPM:{bpm} Noise:{noise} sampling rate: {sr} ", end='\r')

    # Features shown in the panels
    mfccs = librosa.feature.mfcc(y=y, sr=sr)
    onset_strength = librosa.onset.onset_strength(y=y, sr=sr)
    mel_spectrogram = librosa.feature.melspectrogram(y=y, sr=sr)
    # Only the shape of this one is reported below.
    mel_spectrogram_8 = librosa.feature.melspectrogram(y=y, sr=8000)

    onset_strength_normalized = (onset_strength - np.min(onset_strength)) / (np.max(onset_strength) - np.min(onset_strength))

    global_tempo = librosa.feature.rhythm.tempo(onset_envelope=onset_strength, sr=sr)[0]

    # Plotting
    plt.figure(figsize=(12, 14))
    # A figure-level title: it must not create an axes of its own underneath
    # the subplot grid, and the placeholders have to be interpolated.
    plt.suptitle(f'audio  with {time_signature} time signature, {bpm} bpm, and instrument {instrument}')

    rows = 7

    # 1. Raw audio and spectral flux on a shared time axis, both scaled to [0, 1]
    plt.subplot(rows, 1, 1)
    time_audio = np.linspace(0, len(y) / sr, len(y))
    time_onset = np.linspace(0, len(y) / sr, len(onset_strength))
    # Vectorized min-max scaling: the raw signal has far too many samples for
    # an element-by-element Python loop.
    y_scaled = (y - np.min(y)) / (np.max(y) - np.min(y))
    plt.plot(time_audio, y_scaled, label='Raw Audio', alpha=0.7)
    plt.plot(time_onset, onset_strength_normalized, label='spectral flux (scaled)', alpha=0.7, color='r')
    plt.xlabel('Time (s)')
    plt.title(f'Raw Audio and Spectral Flux, predicted BPM = {global_tempo}, actual BPM = {bpm}, instrument {instrument} and time signature {time_signature}')
    plt.legend()

    # 2. MFCCs
    plt.subplot(rows, 1, 2)
    librosa.display.specshow(mfccs, x_axis='time')
    plt.title('MFCC')
    plt.xlabel('Time (s)')
    plt.ylabel('MFCC Coefficients')

    # 3. Mel spectrogram
    plt.subplot(rows, 1, 3)
    librosa.display.specshow(librosa.power_to_db(mel_spectrogram, ref=np.max), y_axis='mel', x_axis='time')
    plt.colorbar(format='%+2.0f dB')
    plt.title('Mel spectrogram')

    # 4. Mel spectrogram restricted to its lowest bands
    n_bands_to_keep = 80
    mel_spectrogram_reduced = mel_spectrogram[:n_bands_to_keep]
    plt.subplot(rows, 1, 4)
    librosa.display.specshow(librosa.power_to_db(mel_spectrogram_reduced, ref=np.max), y_axis='mel', x_axis='time')
    plt.colorbar(format='%+2.0f dB')
    plt.title('Mel spectrogram reduced')

    print("Dimensions of mel_spectrogram:", mel_spectrogram.shape)
    print("Dimensions of mel_spectrogram_reduced:", mel_spectrogram_reduced.shape)
    print("Dimensions of spec 8 khz:", mel_spectrogram_8.shape)

    # 5. Frequency spectrum after high-pass filtering (positive frequencies only)
    plt.subplot(rows, 1, 5)
    y_high_pass = high_pass_filter(y, sr, high_pass_cutoff)
    fourier_high_pass = np.fft.fft(y_high_pass)
    n_high_pass = len(fourier_high_pass)
    frequencies_high_pass = np.fft.fftfreq(n_high_pass, 1/sr)
    plt.plot(frequencies_high_pass[:n_high_pass//2], np.abs(fourier_high_pass)[:n_high_pass//2])
    plt.title(f'Frequency Spectrum after high-pass filter at {int(high_pass_cutoff)} Hz')
    plt.xlabel('Frequency (Hz)')
    plt.ylabel('Magnitude')

    # 6. Onset strength with a Poisson spike train
    plt.subplot(rows, 1, 6)
    plt.plot(onset_strength_normalized, label='Onset Strength', alpha=0.7)
    poisson_encoded = poisson_spike_encoding(onset_strength_normalized.reshape(1,-1), dt=dt*ms)
    spike_count = np.sum(poisson_encoded)
    plt.plot(poisson_encoded[0], label='Poisson Spike Train', linestyle=':', color='g')
    plt.title(f'Onset Strength fixed size with Poisson Spike Train - Spike Count: {int(spike_count)}')
    plt.legend()

    # 7. Onset strength with an LIF spike train
    plt.subplot(rows, 1, 7)
    lif_neuron = LIFNeuron(tau_m=tau_m, v_rest=v_rest, v_threshold=v_threshold, v_reset=v_reset, r_m=r_m, dt=dt)
    lif_spikes, _ = generate_lif_spikes(onset_strength_normalized, lif_neuron)
    spike_count = np.sum(lif_spikes)
    plt.plot(onset_strength_normalized, label='Onset Strength', alpha=0.7)
    plt.plot(lif_spikes, label='LIF Spike Train', linestyle='--', color='r')
    plt.title(f'Onset Strength with LIF Spike Train, Spike Count: {int(spike_count)}')

    # Caption listing the parameters the figure was produced with
    plt.figtext(0.15, 0.14, f"Item No: {item_no}, Window Size: {window_size}, tau_m: {tau_m}, v_rest: {v_rest}, "
                        f"v_threshold: {v_threshold}, v_reset: {v_reset}, r_m: {r_m}, dt: {dt}, "
                        f"high_pass_cutoff: {high_pass_cutoff}", ha="left", fontsize=10)
    plt.tight_layout()
    plt.savefig('plot_with_params.png')
    plt.show()
    

def interactive_plot_spec_norm(item_no, window_size, tau_m, v_rest, v_threshold, v_reset, r_m, dt, high_pass_cutoff,):
    """Widget callback: forward all slider values to ``plot_pixel_spectra_norm``."""
    plot_pixel_spectra_norm(
        item_no=item_no,
        window_size=window_size,
        tau_m=tau_m,
        v_rest=v_rest,
        v_threshold=v_threshold,
        v_reset=v_reset,
        r_m=r_m,
        dt=dt,
        high_pass_cutoff=high_pass_cutoff,
    )


if widgets is not None:
    # One slider per parameter of interactive_plot_spec_norm. None of them
    # updates continuously, so the plot is only redrawn when a slider is released.
    _spec_controls = {
        'item_no': widgets.IntSlider(min=0, max=len(training_data)-1, value=0, step=1, continuous_update=False, description="Item No."),
        'window_size': widgets.IntSlider(min=1, max=400, value=10, step=1, continuous_update=False, description="avg window size"),
        'tau_m': widgets.FloatSlider(min=1.0, max=100.0, value=6.0, step=0.5, continuous_update=False, description="tau_m"),
        'v_rest': widgets.FloatSlider(min=0.0, max=1.0, value=0.0, step=0.1, continuous_update=False, description="v_rest"),
        'v_threshold': widgets.FloatSlider(min=0.0, max=1.0, value=0.7, step=0.1, continuous_update=False, description="v_threshold"),
        'v_reset': widgets.FloatSlider(min=0.0, max=1.0, value=0.2, step=0.1, continuous_update=False, description="v_reset"),
        'r_m': widgets.FloatSlider(min=0.1, max=1.0, value=0.9, step=0.1, continuous_update=False, description="r_m"),
        'dt': widgets.FloatSlider(min=0.1, max=100.0, value=10.0, step=1.0, continuous_update=False, description="dt"),
        'high_pass_cutoff': widgets.FloatSlider(min=10.0, max=1000.0, value=20.0, step=10.0, continuous_update=False, description="high_pass_cutoff [Hz]"),
    }
    widgets.interact(interactive_plot_spec_norm, **_spec_controls);

interactive(children=(IntSlider(value=0, continuous_update=False, description='Item No.', max=119), IntSlider(…

In [4]:
import librosa.display
import ipywidgets as widgets
import numpy as np
import matplotlib.pyplot as plt
import librosa

def autocorrelate_tempogram(onset_strength, max_lag=None):
    """Return the autocorrelation of ``onset_strength`` for lags 0..max_lag-1.

    When ``max_lag`` is omitted, all non-negative lags are returned.
    """
    n_frames = len(onset_strength)
    if max_lag is None:
        max_lag = n_frames

    full_correlation = np.correlate(onset_strength, onset_strength, mode='full')
    # In 'full' mode lag 0 is located at index n_frames - 1.
    zero_lag = n_frames - 1
    return full_correlation[zero_lag:zero_lag + max_lag]

def compute_onset_strength(y, sr):
    """Compute a spectral-flux onset envelope of ``y``.

    The STFT magnitude is differenced along time, negative changes are
    discarded and the remainder is summed over frequency. Because of the
    difference, the result has one frame fewer than the STFT.

    No frame or hop size is passed to ``librosa.stft``, so its defaults apply.
    ``sr`` is not used by the computation; it is kept so existing callers
    keep working.
    """
    # Using a Hann window explicitly
    magnitude = np.abs(librosa.stft(y, window='hann'))

    # First-order difference along the time axis
    flux = np.diff(magnitude, axis=1)

    # Keep only increases in energy (half-wave rectification)
    flux = np.maximum(0, flux)

    # Sum across frequencies
    return np.sum(flux, axis=0)



def plot_tempo_extraction(y, sr, file_path):
    """Plot six views of the tempo-extraction pipeline for one audio sample.

    The signal is resampled to 12288 Hz first. Panels: raw audio, STFT
    magnitude, mel spectrogram, spectral flux with detected onsets, an
    autocorrelation-based tempo curve, and librosa's tempogram.

    y, sr: audio samples and their sampling rate.
    file_path: used to read the time signature (parent directory name) and
        the parameters encoded in the file name.
    """
    fig, axs = plt.subplots(6, 1, figsize=(15, 30))

    tsr = 1024*12
    y = librosa.resample(y=y, orig_sr=sr, target_sr=tsr)
    sr = tsr
    # Hop size assumed for every frame-to-time conversion below.
    hop_length = 512
    # Tempo range (BPM) in which estimates are searched.
    min_bpm, max_bpm = 30, 150

    time_signature = file_path.split('/')[-2].replace('_', '/')
    instrument, bpm_true, duration, noise = extract_bpm_and_instrument(file_path)
    print(f"time signature: {time_signature} BPM:{bpm_true} Noise:{noise} sampling rate: {sr} ", end='\r')
    # Figure-level title; an axes title set here would be replaced by the
    # title of the last panel.
    fig.suptitle(f'audio with {time_signature} time signature, {bpm_true} bpm, and instrument {instrument}, sampling rate: {sr}')

    # The file name may not carry the expected parameters.
    bpm_true = int(bpm_true) if bpm_true is not None else None
    bpm_true_text = f'{bpm_true:.2f}' if bpm_true is not None else 'n/a'

    # 1. Raw Audio
    librosa.display.waveshow(y, sr=sr, ax=axs[0])
    axs[0].set_title(f'raw audio with {time_signature} time signature, {bpm_true} bpm, and instrument {instrument}')
    axs[0].set_xlim([0, len(y) / sr])

    # 2. STFT Magnitude
    D = librosa.amplitude_to_db(np.abs(librosa.stft(y)), ref=np.max)
    librosa.display.specshow(D, sr=sr, x_axis='time', y_axis='log', ax=axs[1])
    axs[1].set_title("STFT Magnitude Spectrum")
    axs[1].set_xlim([0, len(y) / sr])

    # 3. Mel Spectrogram
    mel_spectrogram = librosa.feature.melspectrogram(y=y, sr=sr)
    t = librosa.frames_to_time(np.arange(mel_spectrogram.shape[1]), sr=sr)  # based on the resampled sr
    librosa.display.specshow(librosa.power_to_db(mel_spectrogram, ref=np.max), y_axis='mel', x_axis='time', x_coords=t, ax=axs[2])
    axs[2].set_title('Mel spectrogram')
    axs[2].set_xlim([0, len(y) / sr])

    # 4. Spectral Flux (onset strength)
    skip = 1
    onset_strength = compute_onset_strength(y=y, sr=sr)
    onset_frames = librosa.onset.onset_detect(onset_envelope=onset_strength, sr=sr)
    # Convert with the resampled rate directly, so the markers are placed
    # correctly whatever the rate of the incoming audio was.
    onset_times = librosa.frames_to_time(onset_frames, sr=sr, hop_length=hop_length)
    axs[3].vlines(onset_times, 0, np.max(onset_strength), color='r', alpha=0.9, linestyle='--')

    t_onset = librosa.frames_to_time(np.arange(len(onset_strength)), sr=sr)  # based on the resampled sr
    axs[3].plot(t_onset, onset_strength, label="Onset Strength")
    axs[3].set_title("Spectral Flux with Onset Times")
    axs[3].set_xlim([0, len(y) / sr])

    # 5. Autocorrelation Tempogram
    autocorr_tempogram = autocorrelate_tempogram(onset_strength)
    # normalize to range [0, 1]
    autocorr_tempogram = (autocorr_tempogram - np.min(autocorr_tempogram)) / (np.max(autocorr_tempogram) - np.min(autocorr_tempogram))
    lag_times = np.arange(len(autocorr_tempogram)) * hop_length / sr
    # Lag 0 is skipped: it has no finite tempo.
    bpm_values = 60.0 / lag_times[skip:]
    valid_bpm_range = (bpm_values > min_bpm) & (bpm_values < max_bpm)
    valid_bpm_values = bpm_values[valid_bpm_range]
    valid_tempogram = autocorr_tempogram[skip:][valid_bpm_range]
    estimated_bpm = valid_bpm_values[np.argmax(valid_tempogram)]
    axs[4].plot(valid_bpm_values, valid_tempogram, label="Tempogram", color='g')
    axs[4].axvline(estimated_bpm, color='r', linestyle='--', label=f'Estimated BPM: {estimated_bpm:.2f}')
    axs[4].set_title("Autocorrelation Tempogram" + f'Estimated BPM: {estimated_bpm:.2f}')

    # 6. Librosa's Tempogram
    tempogram = librosa.feature.tempogram(onset_envelope=onset_strength, sr=sr)
    librosa.display.specshow(tempogram, sr=sr, x_axis='time', y_axis='tempo', cmap='magma', ax=axs[5])
    bpm = librosa.core.tempo_frequencies(tempogram.shape[0], hop_length=hop_length, sr=sr)
    mean_tempogram = np.mean(tempogram, axis=1)
    # The zero-lag bin maps to an infinite tempo and carries the largest
    # autocorrelation, so an unrestricted argmax would always report "inf".
    # Search the same tempo range as panel 5 instead.
    candidate_bins = (bpm > min_bpm) & (bpm < max_bpm)
    estimated_bpm_tempogram = bpm[candidate_bins][np.argmax(mean_tempogram[candidate_bins])]
    tempo, _ = librosa.beat.beat_track(y=y, sr=sr)
    axs[5].axhline(estimated_bpm_tempogram, color='r', linestyle='--', label=f'Estimated BPM: {estimated_bpm_tempogram:.2f}, librosa={tempo}')
    if bpm_true is not None:
        axs[5].axhline(bpm_true, color='g', linestyle='--', label=f'Correct BPM: {bpm_true_text}')
    axs[5].set_title(f'Tempogram Estimated BPM: {estimated_bpm_tempogram:.2f}, librosa={tempo}, true={bpm_true_text}')
    axs[5].set_xlim([0, len(y) / sr])

    # Leave a little room at the top for the figure title.
    plt.tight_layout(rect=[0, 0, 1, 0.98])
    plt.show()
    
    
# Slider selecting which training sample to analyse
item_no_widget = widgets.IntSlider(
    description="Item No.",
    value=0,
    min=0,
    max=len(training_data)-1,
    step=1,
    continuous_update=False,
)

def wrapper(item_no):
    """Widget callback: run the tempo-extraction plot for one training sample."""
    sample = training_data[item_no]
    plot_tempo_extraction(sample[0], sample[1], sample[2])

widgets.interactive(wrapper, item_no=item_no_widget)

interactive(children=(IntSlider(value=0, continuous_update=False, description='Item No.', max=119), Output()),…

In [11]:
# Cache shared between calls of interactive_LIF, so that moving a parameter
# slider does not recompute the onset strength of an unchanged item.
# Item index used by the previous call (-1: no call yet).
last_item_index = -1
# Onset strength computed by the previous call, per time signature
# (None: nothing cached yet).
last_onset_strengths = {'1/4': None, '2/4': None, '3/4': None, '4/4': None, '5/4': None, '7/8': None}

def interactive_LIF(item_index, tau_m=6.0, v_rest=0.0, v_threshold=0.7, v_reset=0.2, r_m=0.9, dt=10.0, adaptive_increase=0.0, refractory_period=1):
    """Compare LIF spike trains across the six time signatures.

    For each time signature the sample at ``item_index`` is taken, its onset
    strength (first 20 frames dropped, min-max normalized) is fed to a fresh
    ``LIFNeuron`` built from the given parameters, and the envelope, spikes
    and membrane potential are drawn into one cell of a 3x2 grid. The figure
    is saved to ``LIF_spiketrain_compare_adaptive.png``.

    Onset strengths are cached in the module-level ``last_onset_strengths``
    and reused as long as ``item_index`` equals ``last_item_index``.
    """
    global last_item_index, last_onset_strengths

    print(f"item index={item_index}, last_item_index={last_item_index}", end='\r')

    fig, axs = plt.subplots(3, 2, figsize=(16, 12))

    # Figure title listing the LIF parameters in use
    title_str = (f"LIF Parameters:\n"
                 f"tau_m={tau_m}, v_rest={v_rest}, v_threshold={v_threshold}, v_reset={v_reset},\n"
                 f"r_m={r_m}, dt={dt}, adaptive_increase={adaptive_increase}, refractory_period={refractory_period}")
    fig.suptitle(title_str, fontsize=12)

    datasets_by_signature = [
        ('1/4', training_data_1_4),
        ('2/4', training_data_2_4),
        ('3/4', training_data_3_4),
        ('4/4', training_data_4_4),
        ('5/4', training_data_5_4),
        ('7/8', training_data_7_8),
    ]

    for ax, (time_sig, data) in zip(axs.flatten(), datasets_by_signature):
        cached = last_onset_strengths[time_sig]
        if cached is not None and last_item_index == item_index:
            # Same item as in the previous call: reuse the stored envelope.
            onset_strength = cached
        else:
            y, sr, file_path = data[item_index]
            print("calculating onset strength", end='\r')
            onset_strength = librosa.onset.onset_strength(y=y, sr=sr)
            last_onset_strengths[time_sig] = onset_strength

        # Drop the first 20 frames, then scale to [0, 1].
        trimmed = onset_strength[20:]
        lowest, highest = np.min(trimmed), np.max(trimmed)
        onset_strength_normalized = (trimmed - lowest) / (highest - lowest)

        # A fresh neuron per panel, so no state carries over between signatures.
        lif_neuron = LIFNeuron(
            tau_m=tau_m,
            v_rest=v_rest,
            v_threshold=v_threshold,
            v_reset=v_reset,
            r_m=r_m,
            dt=dt,
            adaptive_increase=adaptive_increase,
            refractory_period=refractory_period,
        )
        lif_spikes, lif_potentials = generate_lif_spikes(onset_strength_normalized, lif_neuron)
        spike_count = np.sum(lif_spikes)

        # Spike positions are marked just above the [0, 1] value range.
        spike_indices = np.where(lif_spikes > 0)[0]
        ax.vlines(spike_indices, ymin=1.05, ymax=1.1, color='r', label='LIF Spike Train')
        ax.plot(onset_strength_normalized, label='Onset Strength', alpha=0.7)
        ax.plot(lif_spikes, label='Onset spikes', alpha=0.7)
        ax.plot(lif_potentials, label='Membrane Potential', linestyle='--', color='g')
        ax.set_title(f'Onset Strength {time_sig} with LIF Spike Train, Spike Count: {int(spike_count)}')
        ax.legend(loc='center right')

    # Remember which item the cache now belongs to.
    last_item_index = item_index

    plt.tight_layout(rect=[0, 0.03, 1, 0.95])  # leave room for the suptitle
    plt.savefig('LIF_spiketrain_compare_adaptive.png')


# Largest item index that is valid for every time signature. interactive_LIF
# indexes all six lists, so all six have to be taken into account.
max_index = min(
    len(training_data_1_4),
    len(training_data_2_4),
    len(training_data_3_4),
    len(training_data_4_4),
    len(training_data_5_4),
    len(training_data_7_8),
) - 1

if widgets is not None:
    # One control per parameter of interactive_LIF. Only the item slider waits
    # for release; the neuron parameters redraw the plot while being dragged.
    _lif_controls = {
        'item_index': widgets.IntSlider(min=0, max=max_index, value=0, step=1, continuous_update=False, description="Item Index"),
        'tau_m': widgets.FloatSlider(min=1.0, max=10.0, value=6.0, step=0.1, continuous_update=True, description="tau_m"),
        'v_rest': widgets.FloatSlider(min=0.0, max=1.0, value=0.0, step=0.1, continuous_update=True, description="v_rest"),
        'v_threshold': widgets.FloatSlider(min=0.0, max=1.0, value=0.7, step=0.1, continuous_update=True, description="v_threshold"),
        'v_reset': widgets.FloatSlider(min=0.0, max=1.0, value=0.2, step=0.1, continuous_update=True, description="v_reset"),
        'r_m': widgets.FloatSlider(min=0.1, max=1.0, value=0.9, step=0.1, continuous_update=True, description="r_m"),
        'dt': widgets.FloatSlider(min=0.1, max=100.0, value=10.0, step=0.1, continuous_update=True, description="dt"),
        'adaptive_increase': widgets.FloatSlider(min=0.0, max=1.0, value=0.0, step=0.1, continuous_update=True, description="adaptive threshold increase"),
        'refractory_period': widgets.IntSlider(min=0, max=10, value=1, step=1, continuous_update=True, description="refractory period"),
    }
    widgets.interact(interactive_LIF, **_lif_controls);

interactive(children=(IntSlider(value=0, continuous_update=False, description='Item Index', max=19), FloatSlid…