## **shira**
#### a neural audio search engine
(for local usage)

In [None]:
! pip install -q faiss-cpu faiss-gpu

In [None]:
"""
testing ground :)
"""
import torch, gc
import faiss
import numpy as np
import librosa, pydub, os, time, glob
from datasets import load_dataset, Dataset
from transformers import ClapModel, ClapProcessor
from typing import Union
from IPython.display import Audio as idp_audio
from functools import wraps
from tqdm.auto import tqdm

In [None]:
# run on the GPU when CUDA is available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# checkpoint id passed to ClapModel / ClapProcessor.from_pretrained
model_id = "laion/larger_clap_music_and_speech"
# rate used by read_audio (loading) and trimpad_audio (clip sizing)
# NOTE(review): 22400 Hz is an unusual value, and embed_audio_batch declares
# 48000 to the CLAP processor — confirm which rate is intended
sample_rate = 22400
# clip length; multiplied by sample_rate in trimpad_audio (presumably seconds)
max_duration = 10
batch_size = 16 # for batched mapping
# root directory handed to audiofile_crawler
input_path = "."

In [None]:
def latency(func): # decorator to measure execution time
    """Wrap ``func`` so every call prints how long it took.

    The wrapped function keeps its name and docstring (via ``functools.wraps``)
    and its return value is passed through unchanged. Nothing is printed when
    ``func`` raises.
    """
    @wraps(func)
    def wrapper(*args, **kwargs):
        # perf_counter is monotonic and high-resolution; time.time() follows the
        # wall clock and can jump (NTP sync, manual changes) mid-measurement
        start_time = time.perf_counter()
        result = func(*args, **kwargs)
        elapsed = time.perf_counter() - start_time
        print(f"latency => {func.__name__}: {elapsed:.4f} seconds")
        return result

    return wrapper


# crawl all the local audio files and return a single flat list
@latency
def audiofile_crawler(root_dir: str, extensions: Union[list, tuple] = ("*.wav", "*.mp3")) -> list:
    """Recursively collect audio file paths under ``root_dir``.

    Args:
        root_dir: directory to walk (every sub-directory is visited).
        extensions: glob patterns matched inside each directory.

    Returns:
        A flat list of path strings, one entry per matching file.
    """
    audio_files = []

    for ext in tqdm(extensions):
        for directory, _, _ in os.walk(root_dir):
            # escape the directory so characters like "[" in folder names are
            # matched literally; only `ext` is treated as a pattern
            pattern = os.path.join(glob.escape(directory), ext)
            # extend (not append): glob returns a list, and callers expect a
            # flat list of paths rather than one sub-list per directory
            audio_files.extend(glob.glob(pattern))

    print(f"found {len(audio_files)} audio files in {root_dir}")

    return audio_files


def read_audio(audio_file: str) -> np.ndarray:
    """Load an audio file from disk as a fixed-length numpy waveform.

    Files that are not ``.wav`` are first converted with ``mp3_to_wav``. The
    audio is loaded with librosa at the module-level ``sample_rate`` and then
    trimmed/padded by ``trimpad_audio``.

    Args:
        audio_file: path to the audio file.

    Returns:
        The waveform as returned by ``trimpad_audio``.
    """
    # compare case-insensitively so "clip.WAV" is not sent through the mp3 decoder
    if not audio_file.lower().endswith(".wav"):
        audio_file = mp3_to_wav(audio_file)

    waveform, _ = librosa.load(audio_file, sr=sample_rate)

    return trimpad_audio(waveform)


# converting mp3 files to .wav for loading
def mp3_to_wav(file: str) -> str:
    """Decode an mp3 file and write it as a WAV file in the working directory.

    Args:
        file: path to the source mp3 file.

    Returns:
        The path of the written file: the source's base name with its final
        extension replaced by ``.wav`` (relative to the current directory).
    """
    # splitext drops only the last extension, so "my.song.mp3" -> "my.song"
    # (splitting on the first "." would collapse it to "my")
    stem, _ = os.path.splitext(os.path.basename(file))
    outpath = f"{stem}.wav"
    sound = pydub.AudioSegment.from_mp3(file)
    # export() encodes as mp3 unless a format is given, which would leave mp3
    # data behind a .wav name; it also returns the open output handle, so close
    # it to make sure everything is flushed before the file is read back
    sound.export(outpath, format="wav").close()

    return outpath

# trimming audio to a fixed length for all tasks
def trimpad_audio(audio: np.ndarray) -> np.ndarray:
    """Force a waveform to exactly ``sample_rate * max_duration`` samples.

    Longer inputs are truncated; shorter inputs are padded at the end by
    reflecting the existing samples. An empty input has nothing to reflect,
    so it is padded with zeros instead of raising.

    Args:
        audio: waveform array (length is measured along the first axis).

    Returns:
        The trimmed or padded waveform.
    """
    samples = int(sample_rate * max_duration) # target number of samples

    if len(audio) > samples:
        return audio[:samples]

    pad_width = samples - len(audio)
    # np.pad refuses mode="reflect" on an empty axis, so fall back to zeros
    pad_mode = "reflect" if len(audio) > 0 else "constant"

    return np.pad(audio, (0, pad_width), mode=pad_mode)

# displays playable audio widget, for notebooks
def display_audio(audio: Union[np.ndarray, str], srate: int = 22400):
    """Render an audio player in the notebook output.

    Args:
        audio: either a waveform array or a path to an audio file.
        srate: sample rate passed to the IPython ``Audio`` widget.
    """
    # local import: only this helper needs it
    from IPython.display import display

    if isinstance(audio, np.ndarray):
        widget = idp_audio(data=audio, rate=srate)
    else:
        widget = idp_audio(filename=audio, rate=srate)

    # building the widget inside a function renders nothing on its own (only a
    # cell's last expression is auto-displayed), so show it explicitly
    display(widget)


In [None]:
# crawl input_path with the crawler's default patterns (*.wav, *.mp3)
audiofiles = audiofile_crawler(input_path)

In [None]:
# build the "train" split of an 'audiofolder' dataset from the crawled files
music_data = load_dataset('audiofolder', data_files=audiofiles, split="train")
# music_data = Dataset.from_dict({'audio': [audiofiles]})

# last expression of the cell, so the notebook prints the dataset summary
music_data

In [None]:
# load the pretrained CLAP checkpoint and move it to the selected device
clap_model = ClapModel.from_pretrained(model_id).to(device)
# matching processor; used below to encode both audio and text inputs
clap_processor = ClapProcessor.from_pretrained(model_id)

In [None]:
# force a garbage-collection pass before the embedding step
gc.collect()

In [None]:
%%time

def embed_audio_batch(batch):
    """Compute the CLAP audio embedding for one dataset example.

    Despite the name, this is mapped without ``batched=True`` in this
    notebook, so ``batch`` is a single example whose "audio" entry provides
    "array" and "sampling_rate".

    Args:
        batch: dataset example with a decoded "audio" entry.

    Returns:
        The same example with an added "audio_embeddings" numpy vector.
    """
    clap_rate = 48000  # rate declared to the CLAP processor

    audio = batch["audio"]
    sample = audio["array"]
    source_rate = audio["sampling_rate"]
    # files are decoded at their own rate; resample so the rate declared to
    # the processor matches the samples it actually receives
    if source_rate != clap_rate:
        sample = librosa.resample(
            np.asarray(sample, dtype=np.float32),
            orig_sr=source_rate,
            target_sr=clap_rate,
        )

    coded_audio = clap_processor(
        audios=sample,
        return_tensors="pt",
        sampling_rate=clap_rate,
    )["input_features"]
    # the model lives on `device`; inputs must be on the same device
    coded_audio = coded_audio.to(device)

    # inference only: no autograd graph, and the stored value must not be a
    # grad-tracking (possibly GPU) tensor
    with torch.no_grad():
        audio_embed = clap_model.get_audio_features(coded_audio)

    batch["audio_embeddings"] = audio_embed[0].cpu().numpy()

    return batch


# embed every example one at a time; the multi-process / batched options are
# left commented out at the end of the call
embedded_data = music_data.map(embed_audio_batch)#, num_proc=4)#, batched=True, batch_size=batch_size)

In [None]:
%%time

# index the embedding column with FAISS so get_nearest_examples can query it
embedded_data.add_faiss_index(column="audio_embeddings")

#### for audio-audio retrieval
##### like shazam, but slower

In [None]:
@latency
def audio_search(input_audio, embedded_data, k_count: int=2, device: torch.device=device):
    """Find the indexed audio clips closest to a query clip.

    Args:
        input_audio: path to an audio file, or a waveform array that is taken
            to be sampled at the module-level ``sample_rate``.
        embedded_data: dataset carrying a FAISS index on "audio_embeddings".
        k_count: number of neighbours to retrieve.
        device: device the query features are moved to before the forward pass.

    Returns:
        ``(retrieved_audio, scores)`` as produced by ``get_nearest_examples``.
    """
    clap_rate = 48000  # same rate embed_audio_batch declares when indexing

    if not isinstance(input_audio, np.ndarray):
        input_audio = read_audio(input_audio)  # loads audio file from wav to ndarray

    # the query is at `sample_rate`; bring it to the rate the CLAP processor is
    # told about, otherwise the declared and actual rates disagree (and the
    # feature extractor rejects a rate other than the one it was trained with)
    if sample_rate != clap_rate:
        input_audio = librosa.resample(input_audio, orig_sr=sample_rate, target_sr=clap_rate)

    audio_values = clap_processor(audios=input_audio, return_tensors="pt", sampling_rate=clap_rate)["input_features"] # type: ignore
    audio_values = audio_values.to(device)

    # inference only: skip autograd bookkeeping
    with torch.no_grad():
        wav_embed = clap_model.get_audio_features(audio_values)[0]
    wav_embed = wav_embed.cpu().numpy()

    scores, retrieved_audio = embedded_data.get_nearest_examples(
        "audio_embeddings", wav_embed, k=k_count
    )

    return retrieved_audio, scores


# NOTE(review): hard-coded Kaggle input path — point this at a local file when
# running outside that environment
audiofile = "/kaggle/input/sample-music/beethoven_sonata.mp3"
similar_audio, scores = audio_search(audiofile, embedded_data)  # search for similar audio files

# similar_audio[0]

In [None]:
# show the retrieved audio entries next to their scores as the cell output
similar_audio["audio"], scores

In [None]:
# the retrieved examples are column -> list-of-values, so "audio" is a list
# with one decoded-audio dict per hit: pick the hit first, then its path
top_file = similar_audio["audio"][0]["path"]

display_audio(top_file)

#### for text retrieval

In [None]:
@latency
def text_search(
    text_query: str, embedded_data: Dataset, k_count: int = 4, device: torch.device = device
):
    """Find the indexed audio clips closest to a text description.

    Args:
        text_query: free-text description of the audio to look for.
        embedded_data: dataset carrying a FAISS index on "audio_embeddings".
        k_count: number of neighbours to retrieve.
        device: device the token ids are moved to before the forward pass.

    Returns:
        ``(retrieved_audio, scores)`` as produced by ``get_nearest_examples``.
    """
    encoded_text = clap_processor(text=text_query, return_tensors="pt")["input_ids"]  # type: ignore
    encoded_text = encoded_text.to(device)

    # inference only: avoid building an autograd graph for the text encoder
    with torch.no_grad():
        text_embed = clap_model.get_text_features(encoded_text)[0]
    text_embed = text_embed.cpu().numpy()

    scores, retrieved_audio = embedded_data.get_nearest_examples("audio_embeddings", text_embed, k=k_count)

    return retrieved_audio, scores


# example text query run against the indexed dataset
text_q = "classical music"
similar_samples, t_scores = text_search(text_q, embedded_data)  # search for similar audio files

In [None]:
# displays playable audio widget, for notebooks
# def display_audio(audio: Union[np.ndarray, str], srate: int = 22400):
#     if isinstance(audio, np.ndarray):
#         idp_audio(data=audio, rate=srate)
#     else:
#         idp_audio(filename=audio, rate=srate)

### library code sample

In [None]:
from shira import AudioSearch, AudioEmbedding

In [None]:
# packaged-library usage: point the embedder at the directory to index
embedder = AudioEmbedding(data_path='.') # init embedder class

# result is passed to the search call in the next cell
audio_data_embeds = embedder.index_files() # create embeddings and index audio files

In [None]:
neural_search = AudioSearch() # init semantic search class

text_query = 'classical music' # text description for search

# get k similar audio samples paired with their scores
# NOTE(review): whether these scores are probabilities or index distances
# depends on the library — confirm
matching_samples, scores = neural_search.text_search(text_query, audio_data_embeds, k_count=5)

matching_samples[0]['audio']['path'] # get file path for top sample

In [None]:
# cli/terminal usage

