# Demo for "Toward Universal Text-to-Music Retrieval"

- arXiv: https://arxiv.org/abs/2211.14558
- pretrained model: https://zenodo.org/record/7322135
- github repo: https://github.com/seungheondoh/music-text-representation
- demo site: https://seungheondoh.github.io/text-music-representation-demo

In [1]:
import os
import json
import pickle
import torch
from torch import nn
import numpy as np
import pandas as pd
import IPython.display as ipd
from IPython.display import Audio, HTML

import argparse
from mtr.utils.demo_utils import get_model
from mtr.utils.eval_utils import _text_representation
import warnings
warnings.filterwarnings(action='ignore')

## Load Pretrained Model & Metadata

In [2]:
# Root directory of the MSD subset data; see
# https://github.com/seungheondoh/msd-subsets for how to obtain it.
# Set the MSD_PATH environment variable to point at your own copy, e.g. the
# result of os.path.join("dataset", "msd-subsets/dataset").
# When MSD_PATH is unset, the machine-specific path from the original
# notebook is used so existing runs keep working.
msd_path = os.environ.get(
    "MSD_PATH",
    '/home/minhee/userdata/music-text-representation-minigb/dataset',
)

In [3]:
# NOTE: a `global` statement at module (notebook-cell) scope has no effect;
# names assigned at this level are already module-level. Kept as in the
# original notebook.
global msd_to_id
global id_to_path

In [4]:
# Load the lookup tables and the annotation file found under `msd_path`.
# Each file is opened in a context manager so its handle is closed as soon
# as the data has been read.
# NOTE: pickle.load can execute arbitrary code while unpickling; only point
# `msd_path` at a trusted copy of the dataset.
with open(os.path.join(msd_path, "lastfm_annotation", "MSD_id_to_7D_id.pkl"), 'rb') as _fh:
    msd_to_id = pickle.load(_fh)
with open(os.path.join(msd_path, "lastfm_annotation", "7D_id_to_path.pkl"), 'rb') as _fh:
    id_to_path = pickle.load(_fh)
# `annotation` is later indexed by track id and then by the 'tag' key.
with open(os.path.join(msd_path, "ecals_annotation/annotation.json"), 'r') as _fh:
    annotation = json.load(_fh)

In [5]:
def pre_extract_audio_embedding(framework, text_type, text_rep):
    """Load the pre-extracted audio embeddings of one experiment.

    Reads ``audio_embs.pt`` from
    ``../mtr/{framework}/exp/transformer_cnn_cf_mel/{text_type}_{text_rep}/``
    (relative to the current working directory). The stored object is used
    as a mapping from an id to a tensor.

    Args:
        framework: Name of the framework directory under ``../mtr``.
        text_type: First half of the experiment directory name.
        text_rep: Second half of the experiment directory name.

    Returns:
        A tuple ``(audio_embs, msdid)``: the per-id tensors stacked along a
        new first dimension, and the list of ids in the same order.
    """
    emb_path = f"../mtr/{framework}/exp/transformer_cnn_cf_mel/{text_type}_{text_rep}/audio_embs.pt"
    # Map storages to CPU so that embeddings saved from a GPU run still load
    # on a CPU-only machine; the retrieval code converts scores to numpy,
    # which requires CPU tensors.
    ecals_test = torch.load(emb_path, map_location="cpu")
    msdid = list(ecals_test.keys())
    audio_embs = torch.stack([ecals_test[k] for k in msdid])
    return audio_embs, msdid

def model_load(framework, text_type, text_rep):
    """Gather everything needed to run text-to-music retrieval.

    Loads the pre-extracted audio embeddings first, then the model and
    tokenizer via ``get_model`` with the same three settings.

    Returns:
        A tuple ``(model, audio_embs, tokenizer, msdid)``.
    """
    embeddings, track_ids = pre_extract_audio_embedding(framework, text_type, text_rep)
    # get_model also hands back a config object, which this demo does not use.
    model, tokenizer, _config = get_model(
        framework=framework,
        text_type=text_type,
        text_rep=text_rep,
    )
    return model, embeddings, tokenizer, track_ids

## Query

In [6]:
def retrieval_fn(query, tokenizer, model, audio_embs, msdid, annotation, top_k=3):
    """Return the tags of the tracks most similar to a text query.

    The query is tokenized and encoded with ``model.encode_bert_text``; the
    text and audio embeddings are L2-normalized, so their dot product is the
    cosine similarity used for ranking.

    Args:
        query: Text query handed to ``tokenizer``.
        tokenizer: Callable invoked as ``tokenizer(query, return_tensors="pt")``
            whose result provides ``'input_ids'``.
        model: Object exposing ``encode_bert_text(input_ids, mask)``.
        audio_embs: 2-D tensor with one embedding per row, aligned with
            ``msdid``.
        msdid: Track ids, one per row of ``audio_embs``.
        annotation: Mapping from track id to a record with a ``'tag'`` entry.
        top_k: Number of best matches to return (default 3, the value that
            was previously hard-coded).

    Returns:
        Dict mapping ``'top1 music'``, ``'top2 music'``, ... to the ``'tag'``
        entry of the corresponding track, best match first.

    Raises:
        KeyError: If a retrieved id is missing from ``annotation``.
    """
    text_input = tokenizer(query, return_tensors="pt")['input_ids']
    # All tensor math stays under no_grad so the scores can be converted to
    # numpy even when the supplied embeddings track gradients.
    with torch.no_grad():
        # The second argument is passed as None, as in the original demo
        # (presumably an optional mask -- confirm against the model code).
        text_embs = model.encode_bert_text(text_input, None)
        audio_embs = nn.functional.normalize(audio_embs, dim=1)
        text_embs = nn.functional.normalize(text_embs, dim=1)
        logits = text_embs @ audio_embs.T
    # squeeze(0) drops the single-query dimension, leaving one score per track.
    scores = pd.Series(logits.squeeze(0).numpy(), index=msdid)
    best_ids = scores.sort_values(ascending=False).head(top_k).index
    return {
        f'top{rank} music': annotation[_id]['tag']
        for rank, _id in enumerate(best_ids, start=1)
    }

def retrieval_show(framework, text_type, text_rep, annotation, query, is_audio=False):
    """Run retrieval for each query and display the results as an HTML table.

    Loads the model, tokenizer and audio embeddings for the given settings,
    calls ``retrieval_fn`` once per query, and renders one table row per
    query in the notebook.

    Args:
        framework, text_type, text_rep: Settings forwarded to ``model_load``.
        annotation: Track metadata forwarded to ``retrieval_fn``.
        query: A list of query strings; a single string is treated as a
            one-element list.
        is_audio: When true, the table is built from the audio results
            instead of the tag metadata (see the note below).
    """
    # A bare string would otherwise be iterated character by character.
    if isinstance(query, str):
        query = [query]
    model, audio_embs, tokenizer, msdid = model_load(framework, text_type, text_rep)
    meta_results = [
        retrieval_fn(text, tokenizer, model, audio_embs, msdid, annotation)
        for text in query
    ]
    # NOTE(review): nothing in this function fills the audio results, so
    # is_audio=True currently renders a table without columns. Behavior is
    # kept as in the original until audio rendering is implemented.
    retrieval_results = []
    rows = retrieval_results if is_audio else meta_results
    inference = pd.DataFrame(rows, index=query)
    # escape=False leaves cell contents unescaped in the generated HTML.
    html = inference.to_html(escape=False)
    ipd.display(HTML(html))


In [7]:
# Example query text; `query` is the list of queries passed to retrieval_show.
tag_query = "banjo"
# Alternative example queries, left disabled in the original notebook:
# caption_query = "fusion jazz with synth, bass, drums, saxophone"
# unseen_query = "music for meditation or listen to in the forest"
query = [tag_query]

In [8]:
# Settings selecting which experiment's model and embeddings are loaded.
framework = "contrastive"  # alternative noted by the author: triplet
text_type = "bert"  # alternatives noted by the author: tag, caption
text_rep = "stochastic"


In [9]:
# Run retrieval for every entry in `query` and display the top matches' tags
# as an HTML table.
retrieval_show(framework, text_type, text_rep, annotation, query, is_audio=False)

Unnamed: 0,top1 music,top2 music,top3 music
banjo,"[weary, alternative indie rock, long walk, maverick, searching, pop rock, passionate, reserved, reflection, contemporary folk, delicate, gentle, drinking, light, cathartic, banjo, somber, earthy, the great outdoors, soothing, comfort, alternative singer songwriter, autumnal, rustic, intimate, folk, autumn, bittersweet, acoustic]","[weary, alternative indie rock, long walk, maverick, searching, pop rock, california, passionate, reserved, reflection, contemporary folk, delicate, gentle, drinking, light, cathartic, banjo, somber, earthy, the great outdoors, soothing, comfort, alternative singer songwriter, autumnal, rustic, intimate, folk, autumn, bittersweet, acoustic]","[weary, alternative indie rock, long walk, maverick, searching, pop rock, passionate, reserved, reflection, contemporary folk, delicate, gentle, drinking, light, cathartic, banjo, somber, earthy, the great outdoors, soothing, comfort, alternative singer songwriter, autumnal, rustic, intimate, folk, autumn, bittersweet, acoustic]"
