In [None]:
!pip install datasets faiss-cpu beir



In [None]:
from datasets import load_dataset
from beir.retrieval.search.dense import DenseRetrievalExactSearch as DRES
from beir.retrieval.evaluation import EvaluateRetrieval
from transformers import AutoTokenizer, AutoModelForMaskedLM, BertConfig
from collections import defaultdict
import faiss
import torch
import random
import numpy as np
import pandas as pd

In [None]:
# Mount Google Drive (Colab only) and change into the project folder so that the
# relative data paths used in later cells resolve against it.
from google.colab import drive
drive.mount('/content/drive')
%cd /content/drive/MyDrive/cs566

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
/content/drive/MyDrive/cs566


In [None]:
# Read the JSON Lines embeddings file (one record per line) into a DataFrame.
# NOTE(review): relative to the directory set by `%cd`, '../data' points outside
# the project folder, whereas the qrels are read from 'data/'; confirm intended.
embeddings_df = pd.read_json(path_or_buf='../data/embeddings.jsonl', lines=True, orient='records')

In [None]:
def normalize_embedding(x):
    """Return ``x`` as a numpy vector scaled to unit L2 norm.

    A 3-D input (e.g. shape (1, 64, 322)) is first averaged over its last axis
    and squeezed, so (1, 64, 322) becomes (64,). If the norm is not positive
    (all-zero input, or NaN), the array is returned without scaling.
    """
    vec = np.array(x)

    if vec.ndim == 3:
        vec = vec.mean(axis=2).squeeze()

    magnitude = np.linalg.norm(vec)
    # Guard against division by zero for all-zero vectors.
    return vec / magnitude if magnitude > 0 else vec

In [None]:
import pandas as pd
import numpy as np

# Embedding columns to convert from nested lists into unit-norm numpy vectors.
embedding_columns = ["audio", "q_audio", "q_audio_eq", "q_audio_pitch", "q_audio_back"]

for col in embedding_columns:
    # Columns absent from this file are skipped rather than raising.
    if col not in embeddings_df.columns:
        continue
    # Null cells stay None; everything else is reduced/normalized.
    embeddings_df[col] = embeddings_df[col].apply(
        lambda value: None if value is None else normalize_embedding(value)
    )

In [None]:
# Sanity check: show the shape of the first row's vector in each embedding column.
for col in embedding_columns:
    shape = embeddings_df[col].iloc[0].shape
    print(f"{col} {shape}")

embeddings_df.head()

audio (64,)
q_audio (64,)
q_audio_eq (64,)
q_audio_pitch (64,)
q_audio_back (64,)


Unnamed: 0,audio,q_audio,q_audio_eq,q_audio_pitch,q_audio_back,pid,title,genres,qid,q_audio_back_info
0,"[0.13033324234842722, 0.13417750204708795, 0.0...","[0.11587900333834293, 0.12401565122748352, 0.0...","[0.12361149213402842, 0.08569024377474944, -0....","[-0.022172729114872877, 0.0682367427756499, 0....","[0.08193488012902791, 0.10629642949916245, 0.0...",song_0,Food,[70],qid_0,3-117504-B-16.wav (wind)
1,"[-0.111927407013601, 0.23225223078903065, 0.03...","[-0.13729215783762497, 0.2233969509375424, -0....","[-0.1288448923880034, 0.13283360917081352, -0....","[0.025730517593575374, 0.20571289056285377, -0...","[-0.15110075111340357, 0.26071315684619345, -0...",song_1,This World,[70],qid_1,4-167063-C-11.wav (sea_waves)
2,"[-0.09153404775128864, -0.04951818587497843, 0...","[-0.13892336207469053, -0.0707909017870902, 0....","[-0.14701611537107098, -0.13739532556921727, 0...","[-0.16529951159921763, 0.06048774579416077, -0...","[-0.14310108234538432, -0.06965594175531208, 0...",song_2,Freeway,[116],qid_2,4-163609-B-16.wav (wind)
3,"[0.09102756489695284, -0.28139377965067186, -0...","[-0.015600688615388797, -0.35074546894268466, ...","[0.0010969938118120057, -0.33042127660696685, ...","[0.03251060224491349, -0.2063245191831978, -0....","[-0.008864936627962226, -0.3238854457726845, -...",song_3,Queen Of The Wires,[58],qid_3,2-39443-A-19.wav (thunderstorm)
4,"[0.007645800177216195, -0.25536686929924646, -...","[0.11166974030526089, -0.3819367547948332, 0.0...","[0.13474163327979052, -0.4215778169455545, 0.0...","[0.0462420689957895, -0.2502491246395871, 0.05...","[0.08413391425629083, -0.177270629188624, -0.1...",song_4,Ohio,[58],qid_4,1-137296-A-16.wav (wind)


## Set Dataset Splits

In [None]:
# Reproducible split of row positions: shuffle all positions with a fixed seed,
# then the first `train_test_split` fraction is train and the remainder is test.
train_test_split = 0.7
random.seed(42)
dataset_size = len(embeddings_df)
indices = list(range(dataset_size))
random.shuffle(indices)
train_size = int(dataset_size * train_test_split)
train_indices, test_indices = set(indices[:train_size]), set(indices[train_size:])

# Held-out rows; their query embeddings are used for retrieval evaluation.
test_df = embeddings_df.loc[embeddings_df.index.isin(test_indices)]

## Perform Retrieval

In [None]:
import pandas as pd

# Load the TSV file with the qrels data
def load_qrels_from_tsv(file_path):
    """
    Load qrels from a TSV file into a dictionary format required by EvaluateRetrieval.

    Args:
        file_path (str): Path to a headerless, tab-separated file with four
            columns: query_id, iteration, doc_id, relevance. The iteration
            column is read but ignored.

    Returns:
        dict: A nested dictionary of {query_id: {doc_id: relevance_score}} with
        string ids and int relevance. If a (query_id, doc_id) pair occurs more
        than once, the last row wins.
    """
    # Read id columns as strings: with type inference, numeric-looking ids such
    # as "007" would become 7 (or 7.0 if the column has gaps) and no longer match
    # the string ids used for the retrieval results.
    df = pd.read_csv(file_path, sep='\t', header=None,
                     names=['query_id', 'iteration', 'doc_id', 'relevance'],
                     dtype={'query_id': str, 'doc_id': str})

    qrels_dict = {}
    # Column-wise zip is far faster than DataFrame.iterrows and keeps column dtypes.
    for query_id, doc_id, relevance in zip(df['query_id'], df['doc_id'], df['relevance']):
        qrels_dict.setdefault(str(query_id), {})[str(doc_id)] = int(relevance)

    return qrels_dict

def format_retrievals_faiss(qids, retrieved_pids, scores):
    """
    Convert FAISS search output into the run format used by BEIR evaluation.

    Parameters:
    -----------
    qids : list or Series
        Query IDs; the i-th id corresponds to row i of ``retrieved_pids`` and ``scores``.
    retrieved_pids : list of lists
        Retrieved document IDs for each query.
    scores : numpy.ndarray
        Similarity scores from FAISS search; ``scores[i][j]`` is the score of
        ``retrieved_pids[i][j]``.

    Returns:
    --------
    dict
        {query_id: {doc_id: score}} with string ids and Python float scores.
        A repeated id keeps the value from its last occurrence.
    """
    # pandas Series / numpy arrays expose tolist(); plain lists are used directly.
    qid_list = qids.tolist() if hasattr(qids, 'tolist') else qids

    retrievals = {}
    for row, qid in enumerate(qid_list):
        retrievals[str(qid)] = {
            str(doc_id): float(scores[row][rank])
            for rank, doc_id in enumerate(retrieved_pids[row])
        }

    # Lightweight structural check: only the first query is inspected.
    if not retrievals:
        print("Warning: Empty retrievals dictionary")
    else:
        first_qid = next(iter(retrievals))
        if not retrievals[first_qid]:
            print(f"Warning: No documents for query {first_qid}")

    return retrievals

In [None]:
# Load the relevance judgments (qrels) used for evaluation.
# NOTE(review): this path is relative to the directory set by `%cd`, while the
# embeddings were read from '../data/...'; confirm both locations are intended.
qrels_file_path = "data/qrels.tsv"
qrels = load_qrels_from_tsv(qrels_file_path)

# Stack the document ("audio") embeddings into an (n_docs, d) matrix. FAISS
# operates on contiguous float32 data, so convert explicitly instead of relying
# on an implicit cast.
doc_embeddings = np.ascontiguousarray(np.vstack(embeddings_df["audio"]), dtype=np.float32)
d = doc_embeddings.shape[1]

# Build the exact inner-product index once: the document set is identical for
# every query variant, so rebuilding it inside the loop was redundant work.
# normalize_embedding scales (non-zero) vectors to unit norm, so inner product
# here equals cosine similarity.
index = faiss.IndexFlatIP(d)
index.add(doc_embeddings)

# Never request more neighbours than indexed documents: FAISS pads missing
# results with id -1, which does not correspond to any pid.
k = min(10, index.ntotal)
k_values = [1, 3, 5, 10]

# Query ids and the pid lookup do not depend on the query variant.
qids = test_df["qid"]
pids = embeddings_df["pid"]
# FAISS returns row *positions* into doc_embeddings (built from embeddings_df in
# row order), so map them to pids positionally rather than by index label.
pid_array = pids.to_numpy()

# Evaluate each query-embedding variant against the same document index.
query_columns = ["q_audio", "q_audio_eq", "q_audio_pitch", "q_audio_back"]
for query in query_columns:
    query_embeddings = np.ascontiguousarray(np.vstack(test_df[query]), dtype=np.float32)
    D, I = index.search(query_embeddings, k)

    retrieved_pids = [[pid_array[idx] for idx in row] for row in I]
    retrievals = format_retrievals_faiss(qids, retrieved_pids, D)

    # NOTE: `map` shadows the builtin; the name is kept in case later cells read it.
    ndcg, map, recall, precision = EvaluateRetrieval.evaluate(qrels, retrievals, k_values)
    print(f"\nResults for {query}:")
    print(f"NDCG: {ndcg}")
    print(f"MAP: {map}")
    print(f"Recall: {recall}")
    print(f"Precision: {precision}")


Results for q_audio:
NDCG: {'NDCG@1': 0.96084, 'NDCG@3': 0.97486, 'NDCG@5': 0.97573, 'NDCG@10': 0.97722}
MAP: {'MAP@1': 0.96084, 'MAP@3': 0.97179, 'MAP@5': 0.97227, 'MAP@10': 0.97289}
Recall: {'Recall@1': 0.96084, 'Recall@3': 0.98358, 'Recall@5': 0.98568, 'Recall@10': 0.99032}
Precision: {'P@1': 0.96084, 'P@3': 0.32786, 'P@5': 0.19714, 'P@10': 0.09903}

Results for q_audio_eq:
NDCG: {'NDCG@1': 0.90316, 'NDCG@3': 0.93326, 'NDCG@5': 0.93698, 'NDCG@10': 0.9405}
MAP: {'MAP@1': 0.90316, 'MAP@3': 0.92604, 'MAP@5': 0.92814, 'MAP@10': 0.92958}
Recall: {'Recall@1': 0.90316, 'Recall@3': 0.95411, 'Recall@5': 0.96295, 'Recall@10': 0.97389}
Precision: {'P@1': 0.90316, 'P@3': 0.31804, 'P@5': 0.19259, 'P@10': 0.09739}

Results for q_audio_pitch:
NDCG: {'NDCG@1': 0.30358, 'NDCG@3': 0.34924, 'NDCG@5': 0.36379, 'NDCG@10': 0.38108}
MAP: {'MAP@1': 0.30358, 'MAP@3': 0.33796, 'MAP@5': 0.34603, 'MAP@10': 0.35316}
Recall: {'Recall@1': 0.30358, 'Recall@3': 0.38189, 'Recall@5': 0.41726, 'Recall@10': 0.47074}
P