In [None]:
import tensorflow as tf
import torch
import torchaudio.functional as F
import torchaudio.transforms as T
import numpy as np
import librosa

In [None]:
# Load TFLite model
interpreter = tf.lite.Interpreter(model_path="instrument_classifier.tflite")
interpreter.allocate_tensors()

# Get model input and output details
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

print("Input Details:", input_details)
print("Output Details:", output_details)

In [None]:
# Constants
SAMPLE_RATE = 44100
DURATION = 0.5  # Full segment size in seconds
STEP_SIZE = 0.25  # Sliding window step size (0.25 sec)
N_MELS = 32
HOP_LENGTH = 512
N_FFT = 2048

In [None]:
def generate_mel_spect_gpu(audio_path, save_path=None):
    y, _ = librosa.load(audio_path, sr=SAMPLE_RATE)
    samples_per_segment = int(DURATION * SAMPLE_RATE)
    step_samples = int(STEP_SIZE * SAMPLE_RATE)  # Now sliding instead of jumping
    mel_spects = []

    # Precompute Hann window and Mel filterbank on GPU
    hann_window = torch.hann_window(N_FFT, device="cuda")
    mel_filters = F.melscale_fbanks(
        n_freqs=N_FFT // 2 + 1,
        f_min=0,
        f_max=16000,
        n_mels=N_MELS,
        sample_rate=SAMPLE_RATE,
        norm=None,
        mel_scale="htk",
    ).T.to("cuda")  # (n_mels, n_freq_bins)

    # Compile spectrograms with overlap
    for start in range(0, len(y) - samples_per_segment + 1, step_samples):  
        end = start + samples_per_segment
        segment = y[start:end]

        segment = segment / np.max(np.abs(segment)) if np.max(np.abs(segment)) > 0 else segment  
        segment_torch = torch.tensor(segment, dtype=torch.float32, device="cuda")

        stft = torch.stft(
            segment_torch,
            n_fft=N_FFT,
            hop_length=HOP_LENGTH,
            win_length=N_FFT,
            window=hann_window,
            return_complex=True,
            center=False
        )  

        power_spectrum = stft.abs().pow(2) 
        mel_spectrogram_manual = torch.matmul(mel_filters, power_spectrum)
        mel_spectrogram_db = T.AmplitudeToDB()(mel_spectrogram_manual)
        mel_spects.append(mel_spectrogram_db.cpu().numpy())

    if save_path:
        np.save(save_path, np.array(mel_spects))

    return mel_spects

In [81]:
# Example usage of generating mel spectrograms and running inference on the TFLite model
audio_path = "slakh2100_flac_redux/reduced_train/Track00018/mix.flac"
mel_spects = generate_mel_spect_gpu(audio_path)

# Instrument labels
instruments = ['Piano', 'Guitar', 'Bass', 'String', 'Drums']
THRESHOLD = 0.6  # Lowered threshold for detection
STEP_SIZE = 0.25  # Sliding window step size (0.25 sec instead of 0.5)

# Loop through all mel spectrograms with sliding window
for idx, mel_spectrogram_db in enumerate(mel_spects):
    # Ensure mel_spectrogram_db has the correct shape for the model: [1, 32, 40, 1]
    mel_spectrogram_db = np.expand_dims(mel_spectrogram_db, axis=0)  # Add batch dimension
    mel_spectrogram_db = np.expand_dims(mel_spectrogram_db, axis=-1)  # Add channel dimension

    # Calculate timestamp using sliding window
    time = idx * STEP_SIZE  # Sliding by 0.25 sec instead of 0.5
    minutes = int(time // 60)
    seconds = int(time % 60)
    time_str = f"{minutes:01d}:{seconds:02d}"

    # Run inference on the TFLite model
    input_data = np.array(mel_spectrogram_db, dtype=np.float32)
    interpreter.set_tensor(input_details[0]['index'], input_data)

    # Run inference
    interpreter.invoke()

    # Get the output tensor
    output_data = interpreter.get_tensor(output_details[0]['index'])
    
    # Convert predictions to binary 
    binary_output = (output_data > THRESHOLD).astype(np.int32)
    
    # Collect predictions for this segment
    predictions = [time_str]  # Start with time string
    for i, pred in enumerate(binary_output[0]):
        instrument_name = instruments[i]
        if pred == 1:
            predictions.append(f"\033[92m{instrument_name}\033[0m")  # Lime green for > 0.4
        else:
            predictions.append(instrument_name)  # Default color for <= 0.4

    # Print predictions for this segment on the same line
    print(" | ".join(predictions))


0:00 | Piano | Guitar | [92mBass[0m | String | Drums
0:00 | Piano | Guitar | [92mBass[0m | String | Drums
0:00 | Piano | Guitar | [92mBass[0m | String | Drums
0:00 | Piano | Guitar | [92mBass[0m | String | Drums
0:01 | Piano | Guitar | [92mBass[0m | String | Drums
0:01 | Piano | [92mGuitar[0m | [92mBass[0m | String | Drums
0:01 | Piano | Guitar | [92mBass[0m | String | Drums
0:01 | Piano | Guitar | [92mBass[0m | String | Drums
0:02 | Piano | Guitar | [92mBass[0m | String | Drums
0:02 | Piano | Guitar | [92mBass[0m | String | Drums
0:02 | Piano | Guitar | [92mBass[0m | String | Drums
0:02 | [92mPiano[0m | Guitar | [92mBass[0m | String | Drums
0:03 | [92mPiano[0m | Guitar | [92mBass[0m | String | Drums
0:03 | [92mPiano[0m | Guitar | Bass | String | Drums
0:03 | Piano | Guitar | [92mBass[0m | String | Drums
0:03 | Piano | Guitar | [92mBass[0m | String | Drums
0:04 | Piano | Guitar | Bass | String | Drums
0:04 | Piano | Guitar | Bass | String | Drums
0