In [1]:
import os
import numpy as np
import tensorflow as tf
import torch
import torchaudio
import torchaudio.functional as F
import torchaudio.transforms as T
import librosa
from collections import defaultdict

In [2]:
SAMPLE_RATE = 44100
DURATION = 1.0  # seconds per segment
STEP_SIZE = 0.5  # seconds per step (use < DURATION for overlap)
N_MELS = 128
HOP_LENGTH = 1024
N_FFT = 2048

In [9]:
MODEL_PATH = "classifiers/instrument_classifier.tflite"
AUDIO_PATH = "slakh2100_flac_redux/reduced_test/Track01876/mix.flac"
INSTRUMENTS = ['Piano', 'Guitar', 'Bass', 'Strings', 'Drums']
CLASS_THRESHOLDS = np.array([0.5, 0.6, 0.5, 0.3, 0.5])

In [5]:
interpreter = tf.lite.Interpreter(model_path=MODEL_PATH)
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

In [10]:
def generate_mel_spect_gpu(audio_path):
    """
    Loads audio using torchaudio, resamples and normalizes it,
    and generates mel spectrograms on the GPU.
    """
    # Load with torchaudio (auto handles most formats)
    waveform, sr = torchaudio.load(audio_path)  # shape: [channels, samples]

    # Convert to mono if needed
    if waveform.shape[0] > 1:
        waveform = torch.mean(waveform, dim=0, keepdim=True)

    # Resample if sample rate doesn't match
    if sr != SAMPLE_RATE:
        resampler = T.Resample(orig_freq=sr, new_freq=SAMPLE_RATE)
        waveform = resampler(waveform)

    waveform = waveform.squeeze()  # shape: [samples]
    waveform = waveform.to(torch.float32).to("cuda")

    segment_len = int(DURATION * SAMPLE_RATE)
    step_len = int(STEP_SIZE * SAMPLE_RATE)

    # Hann window and mel filterbank (cached)
    hann_window = torch.hann_window(N_FFT, device="cuda")
    mel_filters = F.melscale_fbanks(
        n_freqs=N_FFT // 2 + 1,
        f_min=0,
        f_max=16000,
        n_mels=N_MELS,
        sample_rate=SAMPLE_RATE,
        norm=None,
        mel_scale="htk",
    ).T.to("cuda")

    mel_spects = []

    for start in range(0, waveform.shape[0] - segment_len + 1, step_len):
        segment = waveform[start:start + segment_len]

        # Normalize to [-1, 1]
        if torch.max(torch.abs(segment)) > 0:
            segment = segment / torch.max(torch.abs(segment))

        stft = torch.stft(
            segment,
            n_fft=N_FFT,
            hop_length=HOP_LENGTH,
            win_length=N_FFT,
            window=hann_window,
            return_complex=True,
            center=False
        )

        power_spectrum = stft.abs().pow(2)
        mel = torch.matmul(mel_filters, power_spectrum)
        mel_db = T.AmplitudeToDB()(mel)

        mel_spects.append(mel_db.cpu().numpy())

    return mel_spects

In [11]:
def run_inference(mel_spects):
    """
    Runs model inference on list of mel spectrograms.
    Prints time-stamped predictions.
    """
    results = defaultdict(list)

    for idx, mel in enumerate(mel_spects):
        mel_input = np.expand_dims(mel, axis=(0, -1)).astype(np.float32)

        time_sec = idx * STEP_SIZE
        timestamp = f"{int(time_sec // 60):01d}:{int(time_sec % 60):02d}"

        interpreter.set_tensor(input_details[0]['index'], mel_input)
        interpreter.invoke()
        output = interpreter.get_tensor(output_details[0]['index'])

        binary_output = (output > CLASS_THRESHOLDS).astype(np.int32)[0]

        prediction_strs = [
            f"\033[92m{name}\033[0m" if pred else name
            for name, pred in zip(INSTRUMENTS, binary_output)
        ]

        print(f"{timestamp} | " + " | ".join(prediction_strs))

In [12]:
mel_spects = generate_mel_spect_gpu(AUDIO_PATH)
run_inference(mel_spects)

0:00 | Piano | Guitar | Bass | Strings | Drums
0:00 | Piano | Guitar | Bass | Strings | Drums
0:01 | Piano | Guitar | Bass | Strings | Drums
0:01 | Piano | Guitar | Bass | Strings | Drums
0:02 | Piano | Guitar | Bass | Strings | Drums
0:02 | Piano | Guitar | Bass | Strings | Drums
0:03 | Piano | Guitar | Bass | Strings | Drums
0:03 | Piano | Guitar | Bass | Strings | Drums
0:04 | Piano | Guitar | Bass | Strings | Drums
0:04 | Piano | Guitar | Bass | Strings | Drums
0:05 | Piano | Guitar | Bass | [92mStrings[0m | Drums
0:05 | Piano | Guitar | Bass | [92mStrings[0m | Drums
0:06 | Piano | Guitar | Bass | [92mStrings[0m | Drums
0:06 | Piano | Guitar | Bass | [92mStrings[0m | Drums
0:07 | Piano | Guitar | Bass | [92mStrings[0m | Drums
0:07 | Piano | Guitar | Bass | [92mStrings[0m | Drums
0:08 | Piano | Guitar | Bass | [92mStrings[0m | Drums
0:08 | Piano | [92mGuitar[0m | Bass | [92mStrings[0m | Drums
0:09 | Piano | [92mGuitar[0m | Bass | [92mStrings[0m | Drums
0:09 | P