In [7]:
# pip install clip doesn't work, so we need to install from the repo
# pip install git+https://github.com/openai/CLIP.git

from PIL import Image
import torch
import torchaudio
import matplotlib.pyplot as plt
import os
from pathlib import Path
import clip


In [8]:
def audio_to_spectrogram(audio_path, save_path=None):
    # Load audio
    waveform, sample_rate = torchaudio.load(audio_path)
    
    # Create spectrogram
    spectrogram = torchaudio.transforms.Spectrogram()(waveform)
    
    # Convert to image
    plt.figure(figsize=(10, 4))
    plt.imshow(spectrogram[0].log2().numpy(), aspect='auto')
    plt.axis('off')
    
    if save_path:
        plt.savefig(save_path, bbox_inches='tight', pad_inches=0)
        plt.close()
        return save_path
    else:
        # Convert plot to image
        plt.tight_layout(pad=0)
        plt.close()
        return plt

def wav_to_jpg_filename(path_with_filename):
    base = os.path.basename(path_with_filename)         # "snare_1.wav"
    name, _ = os.path.splitext(base)                    # ("snare_1", ".wav")
    return f"{name}.jpg"                                # "snare_1.jpg"


def get_wav_files_and_labels(root_dir):
    """
    Recursively fetch all .wav files and assign labels based on the parent directory name.

    Args:
        root_dir (str or Path): Root directory to search for .wav files.

    Returns:
        List of tuples: (filepath, label)
    """
    root = Path(root_dir)
    files_with_labels = []

    for wav_file in root.rglob('*.wav'):
        label = wav_file.parent.name
        files_with_labels.append((str(wav_file), label))

    return files_with_labels

def find_tuple_string(array, target_string):
    """
    Searches for a target string in the first element of tuples in an array.
    If found, returns the second element of the tuple.

    Args:
        array (list): A list of tuples.
        target_string (str): The string to search for.

    Returns:
        any: The second element of the tuple if the string is found, otherwise None.
    """
    for tup in array:
        if tup[1] == target_string:
            return tup[0]
    return None


def read_spectogram_files(folder_name):
    """
    Reads all files in a directory and returns them as an array of tuples.
    The first part of the tuple is the filename (including the path), 
    the second part is a label created from the filename.

    Args:
        folder_name (str): The path to the directory.

    Returns:
        list: A list of tuples containing the filename and label.
    """
    files = []
    for filename in os.listdir(folder_name):
        filepath = os.path.join(folder_name, filename)
        if os.path.isfile(filepath):
            label = "_".join(filename.split("_")[:-1]).split(".")[0]
            files.append((filepath, label))
    return files

In [9]:
# Creating names for CLIP + mapping categories
drum_mapping_categories = [
                    ("snare drum",      "snare"), 
                    ("bass drum",       "bass"),
                    ("hi-hat drum",     "hihat_closed"),
                    ("hi-hat drum",     "hihat_open"),
                    ("tom-tom drum",    "tomtom_low"),
                    ("tom-tom drum",    "tomtom_mid"),
                    ("tom-tom drum",    "tomtom_hi"),
                    ("cymbal drum",     "cymbal")
                ]

drum_categories_for_clip = sorted({item[0] for item in drum_mapping_categories})

In [None]:
# Create output directory if it doesn't exist
os.makedirs("../out/clip_test_openai", exist_ok=True)


# Create spectograms for all drum samples
for drum_sample_wav, label in get_wav_files_and_labels('../data/samples/'):
    drum_sample_jpg = "../out/clip_test_openai/" + wav_to_jpg_filename(drum_sample_wav)
    classification = audio_to_spectrogram(drum_sample_wav, drum_sample_jpg)

correct = 0
incorrect = 0

device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)

for spectogram_jpg, label in read_spectogram_files('../out/clip_test_openai/'):

    # Load spectrogram image
    image = preprocess(Image.open(spectogram_jpg)).unsqueeze(0).to(device)

    # Correct: 10, Incorrect: 26
    texts = [f"A spectrogram of a {category} sound." for category in drum_categories_for_clip]

    # Correct: 6, Incorrect: 30
    #texts = [f"a {category}" for category in drum_categories_for_clip]

    # Prepare text inputs
    text = clip.tokenize(texts).to(device)

    # Get features and compute similarity
    with torch.no_grad():
        image_features = model.encode_image(image)
        text_features = model.encode_text(text)
        logits_per_image = image_features @ text_features.T
        probs = logits_per_image.softmax(dim=-1).cpu().numpy()

    classification = drum_categories_for_clip[probs.argmax().item()], probs.tolist()[0]
    mapped_value = find_tuple_string(drum_mapping_categories, label)

    if mapped_value == classification[0]:
        correct += 1
        print(f"Correct --> Have this: {label} (path {spectogram_jpg}) gets classified as: {classification[0]}")
    else:
        incorrect += 1
        print(f"Error --> Have this: {label} (path {spectogram_jpg}) gets classified as: {classification[0]} (expected {mapped_value})")

print(f"Correct: {correct}, Incorrect: {incorrect}")


# Vaildation - How about dogs and cats?
dogs_cats = ['../data/pictures/dog/dog_1.png',
              '../data/pictures/dog/dog_2.png', 
              '../data/pictures/dog/dog_3.png', 
              '../data/pictures/cat/cat_1.png', 
              '../data/pictures/cat/cat_2.png',
              '../data/pictures/cat/cat_3.png']

for dog_cat in dogs_cats:
    image = Image.open(dog_cat)
    texts = ["a photo of a dog", "a photo of a cat", "a photo of a cat and dog", "a photo of Novia"]

    # Preprocess the image for CLIP
    image_input = preprocess(image).unsqueeze(0).to(device)

    # Tokenize the text prompts
    text_inputs = clip.tokenize(texts).to(device)

    # Get image and text features, compute similarity
    with torch.no_grad():
        image_features = model.encode_image(image_input)
        text_features = model.encode_text(text_inputs)
        logits_per_image = image_features @ text_features.T
        probs = logits_per_image.softmax(dim=-1).cpu().numpy()

    # Print the most likely label and probabilities
    predicted_label = texts[probs.argmax()]
    print(f"{dog_cat} is classified as: {predicted_label} (probs: {probs[0]})")



Correct --> Have this: tomtom_low (path ../out/clip_test_openai/tomtom_low_3.jpg) gets classified as: tom-tom drum
Error --> Have this: hihat_closed (path ../out/clip_test_openai/hihat_closed_1.jpg) gets classified as: tom-tom drum (expected hi-hat drum)
Correct --> Have this: tomtom_hi (path ../out/clip_test_openai/tomtom_hi_2.jpg) gets classified as: tom-tom drum
Error --> Have this: hihat_open (path ../out/clip_test_openai/hihat_open_2.jpg) gets classified as: tom-tom drum (expected hi-hat drum)
Error --> Have this: kick (path ../out/clip_test_openai/kick_2.jpg) gets classified as: tom-tom drum (expected None)
Error --> Have this: hihat_open (path ../out/clip_test_openai/hihat_open_3.jpg) gets classified as: tom-tom drum (expected hi-hat drum)
Correct --> Have this: tomtom_mid (path ../out/clip_test_openai/tomtom_mid_3.jpg) gets classified as: tom-tom drum
Correct --> Have this: tomtom_low (path ../out/clip_test_openai/tomtom_low_2.jpg) gets classified as: tom-tom drum
Error --> Hav