# Querying Audio with CLAP embeddings

In this walkthrough, we will take a dataset of audio files and embed them with the CLAP (Contrastive Language-Audio Pretraining) model (https://huggingface.co/docs/transformers/v4.30.0/en/model_doc/clap#transformers.ClapModel), which maps audio and text into a shared embedding space.
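
Because audio and text embeddings live in the same space, one common way to query them is to rank audio embeddings by cosine similarity to a text embedding. The NumPy sketch below shows that ranking on toy 4-dimensional vectors standing in for real CLAP embeddings. The helpers `cosine_similarity` and `top_k` are hypothetical, and this is not necessarily how vexpresso scores queries internally.

```python
import numpy as np


def cosine_similarity(queries: np.ndarray, items: np.ndarray) -> np.ndarray:
    """Return a (num_queries, num_items) matrix of cosine similarities."""
    # L2-normalize each row so that a dot product equals the cosine of the angle
    q = queries / np.linalg.norm(queries, axis=1, keepdims=True)
    x = items / np.linalg.norm(items, axis=1, keepdims=True)
    return q @ x.T


def top_k(queries: np.ndarray, items: np.ndarray, k: int) -> np.ndarray:
    """Return the indices of the k most similar items for each query."""
    sims = cosine_similarity(queries, items)
    # argsort sorts ascending, so negate to put the highest similarity first
    return np.argsort(-sims, axis=1)[:, :k]


# Toy vectors (hypothetical) in place of CLAP audio and text embeddings
audio_embeddings = np.array([
    [1.0, 0.0, 0.0, 0.0],
    [0.0, 1.0, 0.0, 0.0],
    [0.7, 0.7, 0.0, 0.0],
])
text_embedding = np.array([[0.9, 0.1, 0.0, 0.0]])

print(top_k(text_embedding, audio_embeddings, k=2))  # [[0 2]]
```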

## Installation Requirements

In [None]:
!pip install librosa
!pip install datasets
!pip install transformers
!pip install torch

In [None]:
from datasets import load_dataset
from transformers import AutoProcessor, ClapModel, AutoTokenizer
import numpy as np
import torch
import vexpresso
from vexpresso.utils import ResourceRequest, DataType

## Load Data

Here we load ESC-50, a dataset of 2,000 labeled environmental sound clips, from https://huggingface.co/datasets/ashraq/esc50

In [None]:
dataset = load_dataset("ashraq/esc50")

Next, convert the train split to a Python dictionary of columns:

In [None]:
dictionary = dataset['train'].to_dict()
# Replace the audio column with the decoded audio (dicts holding "array" and "sampling_rate")
audios = dataset['train']['audio']
dictionary['audio'] = audios

## Create Collection

Let's create a collection from the audio clips we downloaded, using the Ray backend:

In [None]:
collection = vexpresso.create(data=dictionary, backend="ray")

In [None]:
collection.show(5)

Let's keep only the A takes and drop the rest:

In [None]:
collection = collection.filter({"take":{"eq":"A"}}).execute()
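
The filter argument is a nested dict of the form `{column: {operator: value}}`. As a plain-Python picture of which rows `{"take": {"eq": "A"}}` is expected to keep, here is a sketch over a small columnar dict. `filter_eq` and the file names are hypothetical; this is not vexpresso's implementation.

```python
def filter_eq(columns: dict, field: str, value) -> dict:
    """Keep only the rows whose `field` equals `value` (columnar dict in, columnar dict out)."""
    keep = [i for i, v in enumerate(columns[field]) if v == value]
    return {name: [col[i] for i in keep] for name, col in columns.items()}


# Toy columnar data (hypothetical file names)
data = {"filename": ["clip_0.wav", "clip_1.wav", "clip_2.wav"], "take": ["A", "B", "A"]}
print(filter_eq(data, "take", "A"))
# {'filename': ['clip_0.wav', 'clip_2.wav'], 'take': ['A', 'A']}
```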

In [None]:
collection.show(5)

Let's take a look at the different categories:

In [None]:
np.unique(collection["category"].to_list())

Because this is a demo, let's keep only one sound from each category:

In [None]:
def unique_filter(category):
    """Mark the first row of each category as "valid" and every repeat as None."""
    unique_set = set()
    out = []
    for c in category:
        if c not in unique_set:
            out.append("valid")
            unique_set.add(c)
        else:
            out.append(None)
    return out

In [None]:
collection = (
    collection.apply(unique_filter, collection["category"], to="filter_valid")
    .filter({"filter_valid": {"eq": "valid"}})
    .execute()
)
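
`unique_filter` keeps the first clip it sees in each category. If you prefer a vectorized version, NumPy's `np.unique` with `return_index=True` returns the position of the first occurrence of each unique value, which gives the same selection as a boolean mask. A sketch on a toy list of category names (hypothetical values):

```python
import numpy as np

categories = ["dog", "rain", "dog", "crow", "rain"]

# return_index gives the index of the first occurrence of each unique value
_, first_idx = np.unique(categories, return_index=True)

# True at the first row of each category, False for every repeat
mask = np.zeros(len(categories), dtype=bool)
mask[first_idx] = True
print(mask.tolist())  # [True, True, False, True, False]
```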

In [None]:
collection.show(5)

## Multimodal CLAP Embedding function

In [None]:
class ClAPEmbeddingsFunction:
    """Embeds audio arrays or text strings into CLAP's shared embedding space."""

    def __init__(self):
        self.model = ClapModel.from_pretrained("laion/clap-htsat-unfused")
        self.processor = AutoProcessor.from_pretrained("laion/clap-htsat-unfused")
        self.tokenizer = AutoTokenizer.from_pretrained("laion/clap-htsat-unfused")
        # Use the GPU when one is available, otherwise fall back to the CPU
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = self.model.to(self.device)
        self.model.eval()

    def __call__(self, inp, inp_type):
        if inp_type == "audio":
            # Convert raw waveforms into the model's input features
            inputs = self.processor(audios=inp, return_tensors="pt", padding=True)
        elif inp_type == "text":
            inputs = self.tokenizer(inp, padding=True, return_tensors="pt")
        else:
            raise ValueError(f"Unsupported inp_type: {inp_type!r}")
        # Move every input tensor to the model's device
        inputs = {k: v.to(self.device) for k, v in inputs.items()}
        # Inference only, so skip gradient tracking
        with torch.no_grad():
            if inp_type == "audio":
                features = self.model.get_audio_features(**inputs)
            else:
                features = self.model.get_text_features(**inputs)
        return features.cpu().numpy()
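
`ClAPEmbeddingsFunction` pads and embeds everything it receives in a single batch, which can run out of memory on long inputs. Whether vexpresso already splits the column into batches before calling the function is not shown here; if it does not, a wrapper like the hypothetical `embed_in_batches` below could do it. `fake_embed` is a stand-in for the CLAP function so the sketch runs without downloading a model.

```python
import numpy as np


def embed_in_batches(embed_fn, items, batch_size=16, **kwargs):
    """Call `embed_fn` on consecutive slices of `items` and stack the results row-wise."""
    if len(items) == 0:
        raise ValueError("items must not be empty")
    outputs = []
    for start in range(0, len(items), batch_size):
        batch = items[start:start + batch_size]
        outputs.append(np.asarray(embed_fn(batch, **kwargs)))
    return np.concatenate(outputs, axis=0)


# Hypothetical stand-in for ClAPEmbeddingsFunction: returns one 3-dim row per input,
# filled with the size of the batch it came from
def fake_embed(batch, inp_type):
    return np.full((len(batch), 3), len(batch), dtype=float)


embeddings = embed_in_batches(fake_embed, list(range(5)), batch_size=2, inp_type="audio")
print(embeddings.shape)  # (5, 3)
```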

## Now let's embed the audio arrays!

This may take a while, since every remaining clip (one per category after the filtering above) is passed through the CLAP audio encoder.

In [None]:
collection = collection.embed(
    collection["audio.array"],
    inp_type="audio",
    embedding_fn=ClAPEmbeddingsFunction,
    to="audio_embeddings",
    resource_request=ResourceRequest(num_gpus=1),
).execute()
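
`collection["audio.array"]` selects the `array` field nested inside each decoded `audio` dict. The dotted lookup can be pictured with the plain-Python sketch below; `get_path` is a hypothetical helper rather than vexpresso's implementation, and the records hold toy values.

```python
def get_path(record, dotted_key):
    """Follow a dotted key like "audio.array" through nested dicts."""
    value = record
    for part in dotted_key.split("."):
        value = value[part]
    return value


# Toy records shaped like decoded audio entries
records = [
    {"audio": {"array": [0.0, 0.1], "sampling_rate": 44100}},
    {"audio": {"array": [0.2, 0.3], "sampling_rate": 44100}},
]
arrays = [get_path(r, "audio.array") for r in records]
print(arrays)  # [[0.0, 0.1], [0.2, 0.3]]
```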

In [None]:
collection.show(5)