### Globals

In [None]:
import datetime

# Hugging Face checkpoint that is loaded and exported below
model_name = "laion/clap-htsat-unfused"
# Where exported ONNX graphs and pre-processor configs are written (relative to this notebook)
tauri_onnx_models_directory = "../SonicSearch/src-tauri/onnx_models/"
# Local wall-clock timestamp of this run, formatted YYYYmmdd-HHMMSS
file_timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")

### Utilities

In [None]:
# Utilities
import time

def make_filename_or_dirname(filename, extension=None, directory=None):
    """Build a path for an exported artifact inside the ONNX models directory.

    Args:
        filename: file or directory name. Leading/trailing '.' characters and
            leading '/' characters are stripped before joining.
        extension: optional extension, with or without a leading dot. ``None``
            or an empty string means "no extension".
        directory: base directory; defaults to the notebook-level
            ``tauri_onnx_models_directory``.

    Returns:
        The joined path as a string.
    """
    import os

    base = tauri_onnx_models_directory if directory is None else directory
    name = filename.strip('.').lstrip('/')
    # Treat "" like None so we never produce a dangling "name."
    ext = (extension or "").strip('.')
    return os.path.join(base, f"{name}.{ext}" if ext else name)

# Inspect inputs and outputs
def get_shapes_in_nests(node, count=0):
    """Render the shapes inside a (possibly nested) structure as indented text.

    - anything with a ``.shape`` attribute (e.g. tensors, numpy arrays) -> ``str(node.shape)``
    - anything with a callable ``.items()`` (dicts and dict-like outputs) ->
      one ``"key: <rendered value>"`` entry per item
    - lists and tuples -> one rendered entry per element
    - anything else -> ``str(node)``

    Args:
        node: the object to describe.
        count: current nesting depth. Entries after the first at each level
            are placed on new lines indented by ``count + 1`` tabs.

    Returns:
        A (possibly multi-line) string.
    """
    if hasattr(node, "shape"):
        return str(node.shape)

    count += 1
    separator = "\n" + "\t" * count

    items = getattr(node, "items", None)
    if callable(items):
        # Pass the depth down so nested containers are indented further than their parent.
        return separator.join(f"{key}: {get_shapes_in_nests(value, count)}" for key, value in items())

    # Tuples are handled like lists; str(tuple) would dump full tensor contents.
    if isinstance(node, (list, tuple)):
        return separator.join(get_shapes_in_nests(element, count) for element in node)

    return str(node)
        
class QuickTimer:
    """Minimal, class-level stopwatch for timing notebook cells.

    Usage: ``QuickTimer.start()`` ... ``elapsed = QuickTimer.stop()``.
    The start time lives on the class, so there is one shared timer.
    ``stop()`` does not reset it: calling ``stop()`` again measures from the same start.
    Calling ``stop()`` before any ``start()`` returns a meaningless number.
    """

    # time.perf_counter() value recorded by the most recent start()
    _start = 0.0

    @staticmethod
    def start():
        """Record the current time as the timer's start."""
        # perf_counter is monotonic, unlike time.time(), which can jump with clock changes.
        QuickTimer._start = time.perf_counter()

    @staticmethod
    def stop():
        """Return the seconds elapsed since the last start()."""
        return time.perf_counter() - QuickTimer._start

## Embedders: Audio + Text Model with Projection

### Text

In [None]:
from transformers import AutoTokenizer, ClapTextModelWithProjection


# Load the tokenizer and the text tower (with projection head) of the CLAP checkpoint.
tokenizer = AutoTokenizer.from_pretrained(model_name)
text_model = ClapTextModelWithProjection.from_pretrained(model_name)

# A deliberately long example sentence; its encoding is reused below as the ONNX tracing input.
tokenized_inputs = tokenizer(
    ["the longest input one would reasonably use. Truly just one loooooooong input :)."],
    padding=True,
    return_tensors="pt",
)

# Forward pass; keep the projection output separately for inspection.
text_model_outputs = text_model(**tokenized_inputs)
text_embeds = text_model_outputs.text_embeds

In [None]:
# Text inputs and outputs
# Print the shape of every tensor in the tokenizer output and in the text model output
# (get_shapes_in_nests walks nested containers and prints `.shape` where one exists).

print("Inputs: ", get_shapes_in_nests(tokenized_inputs))
print("Outputs: ", get_shapes_in_nests(text_model_outputs))

In [None]:
# Onnx Export - Text Model with projection

from torch import onnx

# Save the tokenizer files next to the ONNX graph so the app can tokenize identically.
print("Exporting tokenizer config...")
tokenizer.save_pretrained(make_filename_or_dirname("tokenizer"))

print("Exporting model to ONNX...")
QuickTimer.start()
onnx.export(
    text_model,
    # Traced with the sample sentence encoded above; dynamic_axes below frees batch/length.
    (tokenized_inputs["input_ids"], tokenized_inputs["attention_mask"]),
    make_filename_or_dirname(f"{model_name.split('/')[-1]}_text_with_projection", "onnx"),
    export_params=True,
    input_names=["input_ids", "attention_mask"],
    output_names=["text_embeds", "last_hidden_state"],
    dynamic_axes={
        "input_ids": {0: "batch_size", 1: "sequence_length"},
        "attention_mask": {0: "batch_size", 1: "sequence_length"},
        "text_embeds": {0: "batch_size"},
        # last_hidden_state is per token, (batch, sequence_length, hidden), so axis 1 must follow
        # the input length too; otherwise the graph declares the traced length as fixed.
        "last_hidden_state": {0: "batch_size", 1: "sequence_length"},
    },
)
print("Exporting model to ONNX took: ", QuickTimer.stop())

### Audio

In [None]:
from datasets import load_dataset
from transformers import ClapAudioModelWithProjection, ClapProcessor
import numpy as np

audio_model = ClapAudioModelWithProjection.from_pretrained(model_name)
processor = ClapProcessor.from_pretrained(model_name)

dataset = load_dataset("ashraq/esc50")
# Slice rows first, then take the column: only the requested clips get decoded,
# instead of decoding the entire "audio" column and slicing afterwards.
audio_sample = [datum["array"] for datum in dataset["train"][0:31]["audio"]]
# One extra clip made of rows 32 and 33 played back to back.
# NOTE(review): row 31 is skipped by the two slices; confirm that is intentional.
# NOTE(review): if ESC-50 clips are 5 s at 44.1 kHz, this is 441000 samples. That is still
# below the 480000-sample (10 s at 48 kHz) limit, so it probably never sets is_longer.
longer_sample = np.concatenate([datum["array"] for datum in dataset["train"][32:34]["audio"]])
audio_sample.append(longer_sample)

# NOTE(review): the clips are played back at 44100 Hz elsewhere in this notebook, but the
# processor is told 48000 Hz. As far as I know the feature extractor only validates
# `sampling_rate` and does not resample. If so, the audio is treated as ~8.8% faster than it is.
# TODO: confirm via dataset["train"][0]["audio"]["sampling_rate"] and resample if needed.
audio_inputs = processor(audios=audio_sample, return_tensors="pt", sampling_rate=48000)
audio_outputs = audio_model(**audio_inputs)
audio_embeds = audio_outputs.audio_embeds

In [None]:
# Audio inputs and outputs
# Print the shape of every tensor in the processor output and in the audio model output
# (get_shapes_in_nests walks nested containers and prints `.shape` where one exists).

print("Inputs: ", get_shapes_in_nests(audio_inputs))
print("Outputs: ", get_shapes_in_nests(audio_outputs))

In [None]:
# Onnx Export - Audio Model with projection

from torch import onnx

# Save the feature-extractor config next to the ONNX graph for reference by the app.
print("Exporting feature extractor config...")
processor.feature_extractor.save_pretrained(make_filename_or_dirname("feature_extractor"))

print("Exporting model to ONNX...")
QuickTimer.start()
onnx.export(
    audio_model,
    (audio_inputs["input_features"], audio_inputs["is_longer"]),
    make_filename_or_dirname(f"{model_name.split('/')[-1]}_audio_with_projection", "onnx"),
    export_params=True,
    input_names=["input_features", "is_longer"],
    output_names=["audio_embeds", "last_hidden_state"],
    # Every named tensor gets the same single dynamic axis: the batch dimension.
    dynamic_axes={
        tensor_name: {0: "batch_size"}
        for tensor_name in ("input_features", "is_longer", "audio_embeds", "last_hidden_state")
    },
)
print("Exporting model to ONNX took: ", QuickTimer.stop())

# SCRATCH

## Audio Processor Understanding
My hypothesis: when a clip is shorter than 10 seconds and fusion is disabled, the preprocessing simply repeat-pads the waveform to 10 seconds and converts it to a mel spectrogram in decibels.

In [None]:
# For the last five clips (the final one is the concatenated clip appended above), print:
# the raw waveform shapes, the input_features shape, and the is_longer shape.
# NOTE: get_shapes_in_nests prints only the *shape* of is_longer, not the flag values.
print(get_shapes_in_nests(audio_sample[-5:]))
print(get_shapes_in_nests(audio_inputs["input_features"][-5:]))
print(get_shapes_in_nests(audio_inputs["is_longer"][-5:]))

In [None]:
# An attempt to manually reverse-engineer the preprocessing

import librosa
from IPython.display import Audio, display
import torch
import numpy as np

# Which clip to inspect (index into audio_sample)
sample_num = 1

# 2x2 grid, filled in this order:
#   panel 1: raw waveform
#   panel 3: HF input_features
#   panel 2: one spectrogram frame
# Panels 4 and the red overlay on panel 2 are drawn further down in this cell.
# NOTE(review): `librosa.display.plt` is presumably matplotlib.pyplot as imported inside
# librosa.display; importing matplotlib.pyplot directly would avoid relying on that name.
librosa.display.plt.figure(figsize=(20, 20))
librosa.display.plt.subplot(2, 2, 1)
librosa.display.plt.title('Preprocessed audio_sample')
print("Preprocessed audio_sample shape: ", audio_sample[sample_num].shape)
# Raw clips are played and drawn at 44100 Hz (presumably the dataset's native rate; confirm).
display(Audio(data=audio_sample[sample_num], rate=44100))
librosa.display.waveshow(audio_sample[sample_num], sr=44100, color='blue')

librosa.display.plt.subplot(2, 2, 3)
librosa.display.plt.title('Processed input_features')
print("Processed input_features shape: ", audio_inputs["input_features"][sample_num].shape)
# input_features[sample_num, 0] is presumably (frames, mel_bins); transpose to the
# (mel_bins, frames) layout specshow draws.
specshow_tensor = torch.transpose(audio_inputs["input_features"][sample_num,0], 0, 1).numpy()
# NOTE(review): the features were computed with sampling_rate=48000 passed to the processor,
# so sr=44100 here slightly mislabels the time and mel axes; confirm which rate is intended.
librosa.display.specshow(specshow_tensor, sr=44100, x_axis='time', y_axis='mel', hop_length=480, cmap='coolwarm', fmax=14000, fmin=50, n_fft=1024, win_length=1024)

librosa.display.plt.subplot(2, 2, 2)
librosa.display.plt.title('Slice')
# Mel profile of frame 100 from the HF features (blue); the manual features are overlaid later.
librosa.display.plt.plot(specshow_tensor[:, 100], color="blue")


def _repeat_pad(waveform, max_length):
    """Tile `waveform` as many whole times as fit in `max_length` samples, then zero-pad to exactly `max_length`."""
    if len(waveform) < max_length:
        waveform = np.tile(waveform, max_length // len(waveform))
        # Pad relative to the *tiled* length so the result is exactly max_length samples.
        # Padding by (max_length - original length) overshot the target, which is why the
        # spectrogram previously had to be truncated to 1001 frames.
        waveform = np.pad(waveform, (0, max_length - len(waveform)))
    return waveform


def pad_mel_dbifier(manual_processing_inputs, sampling_rate=48000, max_length_s=10, fmin=50, fmax=14000):
    """Manual re-implementation of CLAP's non-fusion audio preprocessing.

    Processing steps for each waveform:
      1. Tile it as many whole times as fit in ``max_length_s`` seconds.
      2. Zero-pad it to exactly that length.
      3. Convert it to a 64-band power mel spectrogram (n_fft=1024, hop=480).
      4. Convert that spectrogram to decibels.

    With the defaults this gives 1 + 480000 // 480 = 1001 frames per clip. That is the
    frame count the previous version had to truncate to.

    Args:
        manual_processing_inputs: sequence of 1-D waveforms. They are interpreted at
            ``sampling_rate``; no resampling happens here.
        sampling_rate: rate the waveforms are interpreted at.
        max_length_s: target clip length in seconds.
        fmin, fmax: mel filter-bank frequency range. The defaults match the values used by
            the filter-bank comparison in this notebook and are assumed to match the
            checkpoint's feature-extractor config. TODO confirm.

    Returns:
        dict with two entries:
          - ``"input_features"``: float32 tensor of shape (batch, 1, frames, 64).
          - ``"is_longer"``: bool tensor of shape (batch, 1), all False.

    Raises:
        ValueError: if no waveforms are given, or any waveform is empty or longer
            than ``max_length_s`` seconds.
    """
    if len(manual_processing_inputs) == 0:
        raise ValueError("pad_mel_dbifier needs at least one waveform")

    max_length = int(sampling_rate * max_length_s)
    # Validate everything up front so nothing is computed for an invalid batch.
    for waveform in manual_processing_inputs:
        if len(waveform) == 0:
            raise ValueError("Input is empty")
        if len(waveform) > max_length:
            raise ValueError("Input is too long")

    manual_processed_input_features = []
    for waveform in manual_processing_inputs:
        padded = _repeat_pad(np.asarray(waveform), max_length)
        # Power mel spectrogram with a slaney-scale, slaney-normalised filter bank.
        # centre=True with reflect padding mirrors transformers' `spectrogram` helper;
        # librosa >= 0.10 would otherwise zero-pad the edges.
        sample_melled = librosa.feature.melspectrogram(
            y=padded, sr=sampling_rate, n_fft=1024, hop_length=480, win_length=1024,
            window='hann', center=True, pad_mode='reflect', power=2.0,
            n_mels=64, fmin=fmin, fmax=fmax, norm='slaney',
        )
        # top_db=None keeps the full dynamic range; librosa's default clips at (max - 80 dB).
        # NOTE(review): an earlier TODO said values looked lower than the HF processor's.
        # fmin/fmax/top_db/pad_mode were the suspected causes; re-check against audio_inputs.
        sample_melled_and_dbed = librosa.power_to_db(sample_melled, top_db=None)
        # (n_mels, frames) -> (1, frames, n_mels), the layout fed to the audio model
        manual_processed_input_features.append(np.expand_dims(sample_melled_and_dbed.T, 0))

    input_features = torch.from_numpy(np.stack(manual_processed_input_features, axis=0)).to(torch.float32)
    # Inputs longer than max_length are rejected above, so nothing is ever "longer" here.
    is_longer = torch.zeros((len(manual_processed_input_features), 1), dtype=torch.bool)
    return {"input_features": input_features, "is_longer": is_longer}

# Run the manual preprocessing on just the inspected clip; a one-element batch keeps this fast.
# Its features are therefore at batch index 0, whatever sample_num is.
manual_processed_input_features = pad_mel_dbifier(audio_sample[sample_num:sample_num + 1])

librosa.display.plt.subplot(2, 2, 4)
librosa.display.plt.title('Manually Processed audio_sample')
print("Manual processed input_features shape: ", manual_processed_input_features["input_features"][0].shape)
manually_processed_specshow_tensor = torch.transpose(manual_processed_input_features["input_features"][0, 0], 0, 1).numpy()
librosa.display.specshow(manually_processed_specshow_tensor, sr=44100, x_axis='time', y_axis='mel', hop_length=480, cmap='coolwarm', fmax=14000, fmin=50, n_fft=1024, win_length=1024)

# Re-select the "Slice" panel and overlay the same frame from the manual features in red.
librosa.display.plt.subplot(2, 2, 2)
librosa.display.plt.plot(manually_processed_specshow_tensor[:, 100], color="red")


librosa.display.plt.show()

In [None]:
import numpy as np
import librosa
import librosa.display
import matplotlib.pyplot as plt
from transformers.audio_utils import mel_filter_bank

# Configuration parameters
n_fft = 1024
n_mels = 64
sr = 48000
fmin = 50
fmax = 14000

# Build the same slaney-scale, slaney-normalised mel filter bank with librosa and with
# transformers' `mel_filter_bank`, from the single set of parameters above.
librosa_filters = librosa.filters.mel(sr=sr, n_fft=n_fft, n_mels=n_mels, fmin=fmin, fmax=fmax)
your_mel_filters = mel_filter_bank(
    num_frequency_bins=1 + n_fft // 2,  # one-sided FFT bins (513 for n_fft=1024)
    num_mel_filters=n_mels,
    max_frequency=fmax,
    min_frequency=fmin,
    sampling_rate=sr,
    mel_scale='slaney',
    norm='slaney',
).transpose()  # transformers returns (frequency_bins, mel_filters); librosa uses (mel_filters, frequency_bins)

# Side-by-side plots. The values are linear filter weights, not decibels, so the colour bars
# use the default numeric format.
plt.figure(figsize=(20, 10))

plt.subplot(1, 2, 1)
librosa.display.specshow(librosa_filters, sr=sr, hop_length=512, x_axis='log')
plt.title('Librosa Mel Filter Bank')
plt.colorbar()

plt.subplot(1, 2, 2)
librosa.display.specshow(your_mel_filters, sr=sr, hop_length=512, x_axis='log')
plt.title('Transformers Mel Filter Bank')
plt.colorbar()

plt.tight_layout()
plt.show()

# Summarise the element-wise difference as a single number, then plot it.
print("Max absolute difference between filter banks: ", np.abs(librosa_filters - your_mel_filters).max())
librosa.display.specshow(librosa_filters - your_mel_filters, sr=sr, hop_length=512, x_axis='log')
plt.title('Difference between filter banks')
plt.colorbar()
plt.tight_layout()
plt.show()


...oh well

In [None]:
# Compare shapes and dtypes of the HF processor output with those of the manual re-implementation.
# The manual features are fed to the same models below, so these should agree.
print(get_shapes_in_nests(audio_inputs))
print(audio_inputs["input_features"].dtype)
print(audio_inputs["is_longer"].dtype)
print(get_shapes_in_nests(manual_processed_input_features))
print(manual_processed_input_features["input_features"].dtype)
print(manual_processed_input_features["is_longer"].dtype)

### Actual Model Assessment
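After all this trouble, does the model actually work well? The helper below embeds a few text queries and audio clips, plots the embeddings, and lists the best- and worst-matching clips for each query by cosine similarity.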

In [None]:
import matplotlib.pyplot as plt
import torch.nn.functional as F

def assess_model(text_input, text_preprocessor, text_model, audio_input, audio_preprocessor, audio_model,
                 sample_rate=44100, n_audio_plots=5):
    """Embed texts and clips, plot the embeddings, and rank clips per text by cosine similarity.

    Args:
        text_input: list of query strings.
        text_preprocessor: callable mapping ``text_input`` to what ``text_model`` accepts.
        text_model: callable returning a 2-D tensor of text embeddings, one row per text.
        audio_input: list of waveforms.
        audio_preprocessor: callable mapping ``audio_input`` to what ``audio_model`` accepts.
        audio_model: callable returning a 2-D tensor of audio embeddings, one row per clip.
        sample_rate: playback rate for the audio widgets.
        n_audio_plots: maximum number of audio embeddings to plot.

    Returns:
        ``(text_embeds, audio_embeds)`` exactly as returned by the two models.
    """
    # Imported here so the function does not depend on another cell having imported them.
    from IPython.display import Audio, display

    text_embeds = text_model(text_preprocessor(text_input))
    audio_embeds = audio_model(audio_preprocessor(audio_input))

    # Each embedding is drawn as an 8-row image; requires the embedding size to be divisible by 8.
    for i, text in enumerate(text_input):
        plt.subplot(len(text_input), 1, i + 1)
        plt.title(f"{i}: {text}")
        plt.imshow(text_embeds[i].reshape(8, -1).detach().numpy())
    plt.show()

    n_plotted = min(n_audio_plots, len(audio_input))
    for i in range(n_plotted):
        plt.subplot(n_plotted, 1, i + 1)
        plt.title(f"Audio {i}")
        plt.imshow(audio_embeds[i].reshape(8, -1).detach().numpy())
    plt.show()

    # Cosine similarity of every text against every clip -> (n_texts, n_clips).
    # F.cosine_similarity normalises internally, so no explicit L2 normalisation is needed.
    cosine_similarities = F.cosine_similarity(text_embeds.unsqueeze(1), audio_embeds.unsqueeze(0), dim=2)
    plt.title("Cosine Similarities")
    plt.imshow(cosine_similarities.detach().numpy())
    plt.xlabel("Audio")
    plt.ylabel("Text")
    plt.show()

    # Top-3 and Bottom-3 clips for each text; play the clips that were actually assessed.
    for i, text in enumerate(text_input):
        print(text)
        print("Top 3")
        top_3_indices = cosine_similarities[i].argsort(descending=True)[0:3].tolist()
        print(top_3_indices)
        for j in top_3_indices:
            display(Audio(data=audio_input[j], rate=sample_rate))
        print("Bottom 3")
        bottom_3_indices = cosine_similarities[i].argsort(descending=False)[0:3].tolist()
        print(bottom_3_indices)
        for j in bottom_3_indices:
            display(Audio(data=audio_input[j], rate=sample_rate))
        print()

    return text_embeds, audio_embeds

## Assess ONNX Models

In [None]:
# Reverse-engineered processor + Onnx model

import onnxruntime

# One ONNX Runtime session per exported tower
text_ort_session = onnxruntime.InferenceSession(
    make_filename_or_dirname(f"{model_name.rsplit('/', 1)[-1]}_text_with_projection", "onnx")
)
audio_ort_session = onnxruntime.InferenceSession(
    make_filename_or_dirname(f"{model_name.rsplit('/', 1)[-1]}_audio_with_projection", "onnx")
)

# Graph input/output names; these are the keys `session.run` expects.
print([arg.name for arg in text_ort_session.get_inputs()])
print([arg.name for arg in text_ort_session.get_outputs()])
print([arg.name for arg in audio_ort_session.get_inputs()])
print([arg.name for arg in audio_ort_session.get_outputs()])

In [None]:
def np_dict_to_ortvalue(dict):
    """Convert a mapping of torch tensors / numpy arrays into ONNX Runtime ``OrtValue``s.

    Args:
        dict: mapping from graph input name to a torch tensor or numpy array. The name
            shadows the builtin inside this function only; it is kept for compatibility.

    Returns:
        A new dict with the same keys and ``OrtValue`` values.
    """
    def to_numpy(value):
        # `.numpy()` refuses tensors that require grad or live off-CPU, so detach and move them first.
        if isinstance(value, torch.Tensor):
            return value.detach().cpu().numpy()
        return value

    return {key: onnxruntime.OrtValue.ortvalue_from_numpy(to_numpy(value)) for key, value in dict.items()}

In [None]:
# Compare preprocessors
text_input = ["the sound of a gunshot", "the sound of a crowd", "the sound of a dog", "the sound of a bird", "the sound of applause", "the sound of a car"]
# Slice rows before selecting the column, so only these 31 clips are decoded.
# NOTE(review): the assessment cell below passes `audio_sample`, not this list.
audio_input = [datum["array"] for datum in dataset["train"][0:31]["audio"]]

In [None]:
# from_pretrained tokenizer and processor
# Reference run: HF tokenizer + HF audio processor feeding the PyTorch models.
QuickTimer.start()
print("from_pretrained tokenizer and audio processor")
hf_text_embeds, hf_audio_embeds = assess_model(
    text_input=text_input,
    text_preprocessor=lambda texts: tokenizer(texts, padding=True, return_tensors="pt"),
    text_model=lambda encoded: text_model(**encoded).text_embeds,
    audio_input=audio_sample,
    audio_preprocessor=lambda clips: processor(audios=clips, return_tensors="pt", sampling_rate=48000),
    audio_model=lambda features: audio_model(**features).audio_embeds,
)
print("from_pretrained tokenizer and audio processor took: ", QuickTimer.stop())

# Reverse-engineered processor
# Same PyTorch models; the audio goes through the manual pad_mel_dbifier preprocessing.
QuickTimer.start()
print("Reverse-engineered audio processor")
re_text_embeds, re_audio_embeds = assess_model(
    text_input=text_input,
    text_preprocessor=lambda texts: tokenizer(texts, padding=True, return_tensors="pt"),
    text_model=lambda encoded: text_model(**encoded).text_embeds,
    audio_input=audio_sample,
    audio_preprocessor=pad_mel_dbifier,
    audio_model=lambda features: audio_model(**features).audio_embeds,
)
print("Reverse-engineered audio processor took: ", QuickTimer.stop())

# Onnx processor
# ONNX Runtime sessions for both towers; the audio graph is fed only input_features.
# NOTE(review): the reason for pad_to_multiple_of=20 on the text side is not documented.
QuickTimer.start()
print("Onnx audio processor")
ort_text_embeds, ort_audio_embeds = assess_model(
    text_input=text_input,
    text_preprocessor=lambda texts: np_dict_to_ortvalue(
        tokenizer(texts, padding=True, return_tensors="np", pad_to_multiple_of=20)
    ),
    text_model=lambda ort_inputs: torch.from_numpy(
        text_ort_session.run(['text_embeds', 'last_hidden_state'], ort_inputs)[0]
    ),
    audio_input=audio_sample,
    audio_preprocessor=lambda clips: np_dict_to_ortvalue(
        {'input_features': pad_mel_dbifier(clips)['input_features']}
    ),
    audio_model=lambda ort_inputs: torch.from_numpy(
        audio_ort_session.run(['audio_embeds', 'last_hidden_state'], ort_inputs)[0]
    ),
)
print("Onnx audio processor took: ", QuickTimer.stop())


In [None]:
# Check results

# PyTorch vs ONNX Runtime, both using the reverse-engineered audio preprocessing.
# This does not compare against hf_*_embeds (the HF processor path).
np.testing.assert_allclose(
    actual=re_text_embeds.detach().numpy(),
    desired=ort_text_embeds.detach().numpy(),
    rtol=1e-03,
    atol=1e-05,
)
np.testing.assert_allclose(
    actual=re_audio_embeds.detach().numpy(),
    desired=ort_audio_embeds.detach().numpy(),
    rtol=1e-03,
    atol=1e-05,
)

# All good :)