In [3]:
tauri_onnx_models_directory = "../SonicSearch/src-tauri/onnx_models/"

In [1]:
# Transformers Export to TorchScript
# https://huggingface.co/docs/transformers/v4.27.2/en/model_doc/clap

from datasets import load_dataset
from transformers import AutoProcessor, ClapModel

dataset = load_dataset("ashraq/esc50")
audio_sample = dataset["train"]["audio"][0]["array"]

model = ClapModel.from_pretrained("laion/clap-htsat-unfused", torchscript=True)
model.eval()
processor = AutoProcessor.from_pretrained("laion/clap-htsat-unfused", torchscript=True)

input_text = ["The sound of a moderate-length input string"]

inputs = processor(text=input_text, audios=audio_sample, return_tensors="pt", padding=True)

outputs = model(**inputs)

Repo card metadata block was not found. Setting CardData to empty.
It is strongly recommended to pass the `sampling_rate` argument to this function. Failing to do so can result in silent errors that might be hard to debug.


In [4]:
# Onnx Export

from torch import onnx
import time

onnx_inputs = (inputs["input_ids"], inputs["input_features"], False, inputs["attention_mask"])
onnx_input_names = ["input_ids", "input_features", "is_longer", "attention_mask"]

print("Exporting model to ONNX...")
start = time.time()
onnx.export(
    model,
    onnx_inputs,
    tauri_onnx_models_directory + "laion_clap_htsat_unfused.onnx",
    export_params=True,
    input_names=onnx_input_names,
    output_names=model(**inputs, return_dict=True).keys(),
)
print("Exporting model to ONNX took: ", time.time() - start)

Exporting model to ONNX...


  if time_length > spec_width or freq_length > spec_heigth:
  if time_length < spec_width:
  if freq_length < spec_heigth:
  if height != self.img_size[0] or width != self.img_size[1]:
  batch_size = int(windows.shape[0] / (height * width / window_size / window_size))
  _C._jit_pass_onnx_node_shape_type_inference(node, params_dict, opset_version)
  _C._jit_pass_onnx_graph_shape_type_inference(
  _C._jit_pass_onnx_graph_shape_type_inference(


Exporting model to ONNX took:  24.787407875061035


In [None]:
# TorchScript export via Jit trace
# Probably don't run thisâ€”if this works at all, it takes an unknown amount of time

from torch import Tensor, jit
import time

text_features_func = lambda input_ids, attention_mask: model.get_text_features(input_ids=input_ids, attention_mask=attention_mask)
tokenized_inputs = processor.tokenizer(input_text[0], return_tensors="pt", padding=True)
text_features_dummy_input = (tokenized_inputs["input_ids"], tokenized_inputs["attention_mask"])

audio_features_func = lambda audio_tensor: model.get_audio_features(**processor.feature_extractor(audio_tensor, return_tensors="pt", padding=True))
audio_features_dummy_input = Tensor(audio_sample)

start = time.time()
text_features_dummy_output = text_features_func(*text_features_dummy_input)
print("Text features test shape: ", text_features_dummy_output.shape)
print("Text features test time: ", time.time() - start)

start = time.time()
audio_features_dummy_output = audio_features_func(audio_features_dummy_input)
print("Audio features test shape: ", audio_features_dummy_output.shape)
print("Audio features test time: ", time.time() - start)

print("Tracing text features model")
start = time.time()
jit.trace(text_features_func, text_features_dummy_input).save("laion_clap_htsat_unfused_get_text_features.pt")
print("Tracing text features model took: ", time.time() - start)

print("Tracing audio features model")
start = time.time()
jit.trace(audio_features_func, audio_features_dummy_input).save("laion_clap_htsat_unfused_get_audio_features.pt")
print("Tracing audio features model took: ", time.time() - start)