In [2]:
from PIL import Image
import torchaudio
import matplotlib.pyplot as plt
from transformers import CLIPProcessor, CLIPModel
import os
from pathlib import Path

  from .autonotebook import tqdm as notebook_tqdm
2025-05-18 18:51:40.188640: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-05-18 18:51:40.329880: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [3]:
# Load CLIP model
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

def audio_to_spectrogram(audio_path, save_path=None):
    # Load audio
    waveform, sample_rate = torchaudio.load(audio_path)
    
    # Create spectrogram
    spectrogram = torchaudio.transforms.Spectrogram()(waveform)
    
    # Convert to image
    plt.figure(figsize=(10, 4))
    plt.imshow(spectrogram[0].log2().numpy(), aspect='auto')
    plt.axis('off')
    
    if save_path:
        plt.savefig(save_path, bbox_inches='tight', pad_inches=0)
        plt.close()
        return save_path
    else:
        # Convert plot to image
        plt.tight_layout(pad=0)
        plt.close()
        return plt

def classify_drum_sound(audio_path, drum_categories, spectrogram_file="temp_spectrogram.jpg"):
    # Convert audio to spectrogram image
    spectrogram_path = audio_to_spectrogram(audio_path, spectrogram_file)
    
    # Load image
    image = Image.open(spectrogram_path)
    
    # Prepare text inputs
    texts = [f"A spectrogram of a {category} sound." for category in drum_categories]
    
    # Process inputs
    inputs = processor(text=texts, images=image, return_tensors="pt", padding=True)
    
    # Get similarity scores
    outputs = model(**inputs)
    logits_per_image = outputs.logits_per_image
    probs = logits_per_image.softmax(dim=1)
    
    # Return prediction
    return drum_categories[probs.argmax().item()], probs.tolist()[0]

def wav_to_jpg_filename(path_with_filename):
    base = os.path.basename(path_with_filename)         # "snare_1.wav"
    name, _ = os.path.splitext(base)                    # ("snare_1", ".wav")
    return f"{name}.jpg"                                # "snare_1.jpg"


def get_wav_files_and_labels(root_dir):
    """
    Recursively fetch all .wav files and assign labels based on the parent directory name.

    Args:
        root_dir (str or Path): Root directory to search for .wav files.

    Returns:
        List of tuples: (filepath, label)
    """
    root = Path(root_dir)
    files_with_labels = []

    for wav_file in root.rglob('*.wav'):
        label = wav_file.parent.name
        files_with_labels.append((str(wav_file), label))

    return files_with_labels

def find_tuple_string(array, target_string):
    """
    Searches for a target string in the first element of tuples in an array.
    If found, returns the second element of the tuple.

    Args:
        array (list): A list of tuples.
        target_string (str): The string to search for.

    Returns:
        any: The second element of the tuple if the string is found, otherwise None.
    """
    for tup in array:
        if tup[0] == target_string:
            return tup[1]
    return None

Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


In [4]:
# Creating names for CLIP + mapping categories
drum_mapping_categories = [
                    ("snare drum",      "snare"), 
                    ("bass drum",       "bass"),
                    ("hi-hat cymbal",   "hihat_closed"),
                    ("hi-hat cymbal",   "hihat_open"),
                    ("tom-tom drum",    "tomtom_lo"),
                    ("tom-tom drum",    "tomtom_mid"),
                    ("tom-tom drum",    "tomtom_hi"),
                    ("cymbal",          "cymbal")
                ]

drum_categories_for_clip = sorted({item[0] for item in drum_mapping_categories})

In [None]:
correct = 0
incorrect = 0

# Create output directory if it doesn't exist
os.makedirs("../out/clip_test_hug", exist_ok=True)

files_and_labels = get_wav_files_and_labels('../data/samples/')

for drum_sample_wav, label in files_and_labels:
    drum_sample_jpg = "../out/clip_test_hug/" + wav_to_jpg_filename(drum_sample_wav)
    classification = classify_drum_sound(drum_sample_wav, drum_categories_for_clip, drum_sample_jpg)

    mapped_value = find_tuple_string(drum_mapping_categories, classification[0])

    if mapped_value == label:
        correct += 1
        print(f"Correct --> {label}: {drum_sample_wav} -> {classification[0]}")
    else:
        incorrect += 1
        print(f"Error --> {label}: {drum_sample_wav} -> {classification[0]} (expected {mapped_value})")

print(f"Correct: {correct}, Incorrect: {incorrect}")

# Validation - How about dogs and cats?
dogs_cats = ['../data/pictures/dog/dog_1.png',
              '../data/pictures/dog/dog_2.png', 
              '../data/pictures/dog/dog_3.png', 
              '../data/pictures/cat/cat_1.png', 
              '../data/pictures/cat/cat_2.png',
              '../data/pictures/cat/cat_3.png']

for dog_cat in dogs_cats:
    image = Image.open(dog_cat)
    texts = ["a photo of a dog", "a photo of a cat", "a photo of a cat and dog", "a photo of Novia"]
    
    inputs = processor(text=texts, images=image, return_tensors="pt", padding=True)
    
    outputs = model(**inputs)
    logits_per_image = outputs.logits_per_image
    probs = logits_per_image.softmax(dim=1)
    
    print(f"Dog/Cat: {dog_cat} -> {probs.tolist()[0]}")

    predicted_index = probs.argmax().item()
    predicted_text = texts[predicted_index]
    print(f"Predicted: {predicted_text}")


Error --> tomtom: ../data/samples/tomtom/tomtom_hi_3.wav -> snare drum (expected snare)
Error --> tomtom: ../data/samples/tomtom/tomtom_low_1.wav -> bass drum (expected bass)
Error --> tomtom: ../data/samples/tomtom/tomtom_low_3.wav -> snare drum (expected snare)
Error --> tomtom: ../data/samples/tomtom/tomtom_hi_1.wav -> bass drum (expected bass)
Error --> tomtom: ../data/samples/tomtom/tomtom_low_2.wav -> bass drum (expected bass)
Error --> tomtom: ../data/samples/tomtom/tomtom_mid_2.wav -> snare drum (expected snare)
Error --> tomtom: ../data/samples/tomtom/tomtom_mid_3.wav -> tom-tom drum (expected tomtom_lo)
Error --> tomtom: ../data/samples/tomtom/tomtom_mid_1.wav -> bass drum (expected bass)
Error --> tomtom: ../data/samples/tomtom/tomtom_hi_2.wav -> bass drum (expected bass)
Error --> kick: ../data/samples/kick/kick_2.wav -> tom-tom drum (expected tomtom_lo)
Error --> kick: ../data/samples/kick/kick_3.wav -> bass drum (expected bass)
Error --> kick: ../data/samples/kick/kick_5.