# Compare Candidate Search Approaches

Code authored by: Shaw Talebi <br>

Video link: https://youtu.be/6qCrvlHRhcM <br>


### imports

In [81]:
import polars as pl

from sentence_transformers import SentenceTransformer, util

from sklearn.metrics import DistanceMetric
import numpy as np

import matplotlib.pyplot as plt

### load data

In [82]:
df = pl.read_parquet('data/video-transcripts.parquet')
df_eval = pl.read_csv('data/eval-raw.csv')
df.head()

video_id,datetime,title,transcript
str,datetime[μs],str,str
"""03x2oYg9oME""",2024-04-25 15:16:00,"""Data Science Project Managemen…","""this video is part of a larger…"
"""O5i_mMUM94c""",2024-04-19 14:05:54,"""How I’d learned #datascience (…","""here's how I'd learn data scie…"
"""xm9devSQEqU""",2024-04-18 15:59:02,"""4 Skills You Need to Be a Full…","""although it is common to deleg…"
"""Z6CmuVEi7QY""",2024-04-11 10:00:27,"""How I'd Learn Data Science (if…","""when I was first learning data…"
"""INlCLmWlojY""",2024-04-04 18:45:00,"""I Was Wrong About AI Consultin…","""last year I quit my corporate …"


### embed titles and transcripts

In [83]:
# define "parameters"
column_to_embed_list = ['title', 'transcript']
model_name_list = ["all-MiniLM-L6-v2", "all-distilroberta-v1", "multi-qa-distilbert-cos-v1", "multi-qa-mpnet-base-dot-v1"]

In [84]:
# generate embeddings for each combination of column and model

# initialize dict to keep track of all text embeddings
text_embedding_dict = {}

for model_name in model_name_list:

    #define embedding model
    model = SentenceTransformer(model_name) 

    for column_name in column_to_embed_list:

        # define text embedding identifier
        key_name = model_name + "_" + column_name
        print(key_name)

        # generate embeddings for text under column_name
        %time embedding_arr = model.encode(df[column_name].to_list())
        print('')

        # append embeddings to dict
        text_embedding_dict[key_name] = embedding_arr

all-MiniLM-L6-v2_title
CPU times: total: 1.77 s
Wall time: 225 ms

all-MiniLM-L6-v2_transcript
CPU times: total: 19.9 s
Wall time: 2.38 s



README.md:   0%|          | 0.00/10.1k [00:00<?, ?B/s]

sentence_bert_config.json:   0%|          | 0.00/53.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/653 [00:00<?, ?B/s]

Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`


model.safetensors:   0%|          | 0.00/328M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/333 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/798k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/239 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/190 [00:00<?, ?B/s]

all-distilroberta-v1_title
CPU times: total: 4.61 s
Wall time: 590 ms

all-distilroberta-v1_transcript
CPU times: total: 1min 32s
Wall time: 11.5 s

multi-qa-distilbert-cos-v1_title
CPU times: total: 4.5 s
Wall time: 546 ms

multi-qa-distilbert-cos-v1_transcript
CPU times: total: 1min 36s
Wall time: 12 s

multi-qa-mpnet-base-dot-v1_title
CPU times: total: 7.27 s
Wall time: 915 ms

multi-qa-mpnet-base-dot-v1_transcript
CPU times: total: 4min 45s
Wall time: 35.8 s



### embed queries

In [85]:
query_embedding_dict = {}

for model_name in model_name_list:

    #define embedding model
    model = SentenceTransformer(model_name)
    print(model_name)

    # embed query text
    %time embedding_arr = model.encode(df_eval['query'].to_list())
    print('')

    # append embedding to dict
    query_embedding_dict[model_name] = embedding_arr

all-MiniLM-L6-v2
CPU times: total: 2.25 s
Wall time: 282 ms

all-distilroberta-v1
CPU times: total: 6.66 s
Wall time: 835 ms

multi-qa-distilbert-cos-v1
CPU times: total: 6.12 s
Wall time: 760 ms

multi-qa-mpnet-base-dot-v1
CPU times: total: 14.2 s
Wall time: 1.79 s



### Evaluate search methods

In [86]:
def returnVideoID_index(df: pl.dataframe.frame.DataFrame, df_eval: pl.dataframe.frame.DataFrame, query_n: int) -> int:
    """
        Function to return the index of a dataframe corresponding to the nth row in evaluation dataframe
    """

    return [i for i in range(len(df)) if df['video_id'][i]==df_eval['video_id'][query_n]][0]

In [87]:
def evalTrueRankings(dist_arr_isorted: np.ndarray, df: pl.dataframe.frame.DataFrame, df_eval: pl.dataframe.frame.DataFrame) -> np.ndarray:
    """
        Function to return "true" video ID rankings for each evaluation query
    """
    
    # intialize array to store rankings of "correct" search result
    true_rank_arr = np.empty((1, dist_arr_isorted.shape[1]))
    
    # evaluate ranking of correct result for each query
    for query_n in range(dist_arr_isorted.shape[1]):
    
        # return "true" video ID's in df
        video_id_idx = returnVideoID_index(df, df_eval, query_n)
        
        # evaluate the ranking of the "true" video ID
        true_rank = np.argwhere(dist_arr_isorted[:,query_n]==video_id_idx)[0][0]
        
        # store the "true" video ID's ranking in array
        true_rank_arr[0,query_n] = true_rank

    return true_rank_arr

In [88]:
# initialize distance metrics to experiment
dist_name_list = ['euclidean', 'manhattan', 'chebyshev']
sim_name_list = ['cos_sim', 'dot_score']

In [89]:
# evaluate all possible combinations of model, columns to embed, and distance metrics

# initialize list to store results
eval_results = []

# loop through all models
for model_name in model_name_list:

    # generate query embedding
    query_embedding = query_embedding_dict[model_name]
    
    # loop through text columns
    for column_name in column_to_embed_list:

        # generate column embedding
        embedding_arr = text_embedding_dict[model_name+'_'+column_name]

        # loop through distance metrics
        for dist_name in dist_name_list:

            # compute distance between video text and query
            dist = DistanceMetric.get_metric(dist_name)
            dist_arr = dist.pairwise(embedding_arr, query_embedding)

            # sort indexes of distance array
            dist_arr_isorted = np.argsort(dist_arr, axis=0)

            # define label for search method
            method_name = "_".join([model_name, column_name, dist_name])

            # evaluate the ranking of the ground truth
            true_rank_arr = evalTrueRankings(dist_arr_isorted, df, df_eval)

            # store results
            eval_list = [method_name] + true_rank_arr.tolist()[0]
            eval_results.append(eval_list)

        # loop through sbert similarity scores
        for sim_name in sim_name_list:
            # apply similarity score from sbert
            cmd = "dist_arr = -util." + sim_name + "(embedding_arr, query_embedding)"
            exec(cmd)
    
            # sort indexes of distance array (notice minus sign in front of cosine similarity)
            dist_arr_isorted = np.argsort(dist_arr, axis=0)
    
            # define label for search method
            method_name = "_".join([model_name, column_name, sim_name.replace("_","-")])
    
            # evaluate the ranking of the ground truth
            true_rank_arr = evalTrueRankings(dist_arr_isorted, df, df_eval)
    
            # store results
            eval_list = [method_name] + true_rank_arr.tolist()[0]
            eval_results.append(eval_list)

In [90]:
cmd

'dist_arr = -util.dot_score(embedding_arr, query_embedding)'

In [91]:
# compute rankings for title + transcripts embedding
for model_name in model_name_list:
    
    # generate embeddings
    embedding_arr1 = text_embedding_dict[model_name+'_title']
    embedding_arr2 = text_embedding_dict[model_name+'_transcript']
    query_embedding = query_embedding_dict[model_name]

    for dist_name in dist_name_list:

        # compute distance between video text and query
        dist = DistanceMetric.get_metric(dist_name)
        dist_arr = dist.pairwise(embedding_arr1, query_embedding) + dist.pairwise(embedding_arr2, query_embedding)

        # sort indexes of distance array
        dist_arr_isorted = np.argsort(dist_arr, axis=0)

         # define label for search method
        method_name = "_".join([model_name, "title-transcript", dist_name])

        # evaluate the ranking of the ground truth
        true_rank_arr = evalTrueRankings(dist_arr_isorted, df, df_eval)

        # store results
        eval_list = [method_name] + true_rank_arr.tolist()[0]
        eval_results.append(eval_list)

    # loop through sbert similarity scores
    for sim_name in sim_name_list:
        # apply similarity score from sbert
        cmd = "dist_arr = -util." + sim_name + "(embedding_arr1, query_embedding) - util."+ sim_name + "(embedding_arr2, query_embedding)"
        exec(cmd)

        # sort indexes of distance array (notice minus sign in front of cosine similarity)
        dist_arr_isorted = np.argsort(dist_arr, axis=0)

        # define label for search method
        method_name = "_".join([model_name, "title-transcript", sim_name.replace("_","-")])

        # evaluate the ranking of the ground truth
        true_rank_arr = evalTrueRankings(dist_arr_isorted, df, df_eval)

        # store results
        eval_list = [method_name] + true_rank_arr.tolist()[0]
        eval_results.append(eval_list)

In [92]:
len(eval_results)

60

In [93]:
# define schema for results dataframe
schema_dict = {'method_name':str}
for i in range(len(eval_results[0])-1):
    schema_dict['rank_query-'+str(i)] = float

# store results in dataframe
df_results = pl.DataFrame(eval_results, schema=schema_dict)
df_results.head()

  df_results = pl.DataFrame(eval_results, schema=schema_dict)


method_name,rank_query-0,rank_query-1,rank_query-2,rank_query-3,rank_query-4,rank_query-5,rank_query-6,rank_query-7,rank_query-8,rank_query-9,rank_query-10,rank_query-11,rank_query-12,rank_query-13,rank_query-14,rank_query-15,rank_query-16,rank_query-17,rank_query-18,rank_query-19,rank_query-20,rank_query-21,rank_query-22,rank_query-23,rank_query-24,rank_query-25,rank_query-26,rank_query-27,rank_query-28,rank_query-29,rank_query-30,rank_query-31,rank_query-32,rank_query-33,rank_query-34,rank_query-35,rank_query-36,rank_query-37,rank_query-38,rank_query-39,rank_query-40,rank_query-41,rank_query-42,rank_query-43,rank_query-44,rank_query-45,rank_query-46,rank_query-47,rank_query-48,rank_query-49,rank_query-50,rank_query-51,rank_query-52,rank_query-53,rank_query-54,rank_query-55,rank_query-56,rank_query-57,rank_query-58,rank_query-59,rank_query-60,rank_query-61,rank_query-62,rank_query-63
str,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64
"""all-MiniLM-L6-v2_title_euclide…",0.0,0.0,16.0,0.0,7.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,3.0,0.0,1.0,0.0,2.0,0.0,0.0,0.0,1.0,3.0,1.0,0.0,0.0,0.0,0.0,0.0,2.0,0.0,8.0,1.0,0.0,0.0,1.0,0.0,6.0,1.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,9.0,5.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,1.0,0.0
"""all-MiniLM-L6-v2_title_manhatt…",0.0,0.0,9.0,0.0,7.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,2.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0,3.0,1.0,0.0,0.0,0.0,0.0,0.0,2.0,0.0,7.0,1.0,0.0,0.0,1.0,0.0,3.0,1.0,0.0,1.0,1.0,0.0,0.0,0.0,0.0,0.0,1.0,10.0,5.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,1.0,0.0
"""all-MiniLM-L6-v2_title_chebysh…",0.0,2.0,46.0,0.0,60.0,0.0,0.0,0.0,0.0,0.0,1.0,3.0,0.0,30.0,0.0,0.0,4.0,57.0,0.0,3.0,0.0,24.0,0.0,0.0,0.0,8.0,6.0,2.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,43.0,1.0,0.0,0.0,1.0,0.0,6.0,8.0,0.0,1.0,1.0,0.0,3.0,0.0,0.0,0.0,0.0,5.0,5.0,1.0,70.0,11.0,0.0,0.0,0.0,1.0,1.0,1.0,0.0
"""all-MiniLM-L6-v2_title_cos-sim""",0.0,0.0,16.0,0.0,7.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,3.0,0.0,1.0,0.0,2.0,0.0,0.0,0.0,1.0,3.0,1.0,0.0,0.0,0.0,0.0,0.0,2.0,0.0,8.0,1.0,0.0,0.0,1.0,0.0,6.0,1.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,9.0,5.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,1.0,0.0
"""all-MiniLM-L6-v2_title_dot-sco…",0.0,0.0,16.0,0.0,7.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,3.0,0.0,1.0,0.0,2.0,0.0,0.0,0.0,1.0,3.0,1.0,0.0,0.0,0.0,0.0,0.0,2.0,0.0,8.0,1.0,0.0,0.0,1.0,0.0,6.0,1.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,9.0,5.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,1.0,0.0


In [94]:
# compute mean rankings of ground truth search result
df_results = df_results.with_columns(new_col=pl.mean_horizontal(df_results.columns[1:])).rename({"new_col": "rank_query-mean"})

In [95]:
# compute number of ground truth results which appear in top 3
for i in [1,3]:
    df_results = df_results.with_columns(new_col=pl.sum_horizontal(df_results[:,1:-1]<i)).rename({"new_col": "num_in_top-"+str(i)})

### Look at top results

In [96]:
df_summary = df_results[['method_name', "rank_query-mean", "num_in_top-1", "num_in_top-3"]]

In [97]:
print(df_summary.sort('rank_query-mean').head())

shape: (5, 4)
┌─────────────────────────────────┬─────────────────┬──────────────┬──────────────┐
│ method_name                     ┆ rank_query-mean ┆ num_in_top-1 ┆ num_in_top-3 │
│ ---                             ┆ ---             ┆ ---          ┆ ---          │
│ str                             ┆ f64             ┆ u32          ┆ u32          │
╞═════════════════════════════════╪═════════════════╪══════════════╪══════════════╡
│ all-MiniLM-L6-v2_title-transcr… ┆ 0.875           ┆ 41           ┆ 60           │
│ all-MiniLM-L6-v2_title_manhatt… ┆ 0.921875        ┆ 44           ┆ 58           │
│ all-MiniLM-L6-v2_title-transcr… ┆ 0.96875         ┆ 41           ┆ 61           │
│ all-MiniLM-L6-v2_title-transcr… ┆ 0.984375        ┆ 41           ┆ 60           │
│ all-MiniLM-L6-v2_title-transcr… ┆ 0.984375        ┆ 41           ┆ 60           │
└─────────────────────────────────┴─────────────────┴──────────────┴──────────────┘


In [98]:
df_summary.sort('rank_query-mean').head()[0,0]

'all-MiniLM-L6-v2_title-transcript_manhattan'

In [99]:
print(df_summary.sort("num_in_top-1", descending=True).head())

shape: (5, 4)
┌─────────────────────────────────┬─────────────────┬──────────────┬──────────────┐
│ method_name                     ┆ rank_query-mean ┆ num_in_top-1 ┆ num_in_top-3 │
│ ---                             ┆ ---             ┆ ---          ┆ ---          │
│ str                             ┆ f64             ┆ u32          ┆ u32          │
╞═════════════════════════════════╪═════════════════╪══════════════╪══════════════╡
│ all-MiniLM-L6-v2_title_euclide… ┆ 1.09375         ┆ 45           ┆ 57           │
│ all-MiniLM-L6-v2_title_cos-sim  ┆ 1.09375         ┆ 45           ┆ 57           │
│ all-MiniLM-L6-v2_title_dot-sco… ┆ 1.09375         ┆ 45           ┆ 57           │
│ multi-qa-mpnet-base-dot-v1_tit… ┆ 1.8125          ┆ 45           ┆ 57           │
│ all-MiniLM-L6-v2_title_manhatt… ┆ 0.921875        ┆ 44           ┆ 58           │
└─────────────────────────────────┴─────────────────┴──────────────┴──────────────┘


In [100]:
df_summary.sort("num_in_top-1", descending=True).head()[0,0]

'all-MiniLM-L6-v2_title_euclidean'

In [101]:
print(df_summary.sort("num_in_top-3", descending=True).head())

shape: (5, 4)
┌─────────────────────────────────┬─────────────────┬──────────────┬──────────────┐
│ method_name                     ┆ rank_query-mean ┆ num_in_top-1 ┆ num_in_top-3 │
│ ---                             ┆ ---             ┆ ---          ┆ ---          │
│ str                             ┆ f64             ┆ u32          ┆ u32          │
╞═════════════════════════════════╪═════════════════╪══════════════╪══════════════╡
│ all-MiniLM-L6-v2_title-transcr… ┆ 0.96875         ┆ 41           ┆ 61           │
│ multi-qa-distilbert-cos-v1_tit… ┆ 1.59375         ┆ 43           ┆ 61           │
│ multi-qa-distilbert-cos-v1_tit… ┆ 1.625           ┆ 42           ┆ 61           │
│ multi-qa-distilbert-cos-v1_tit… ┆ 1.625           ┆ 42           ┆ 61           │
│ all-distilroberta-v1_title_man… ┆ 1.09375         ┆ 40           ┆ 60           │
└─────────────────────────────────┴─────────────────┴──────────────┴──────────────┘


In [102]:
df_summary.sort("num_in_top-3", descending=True).head()[0,0]

'all-MiniLM-L6-v2_title-transcript_euclidean'

In [103]:
for i in range(4):
    print(df_summary.sort("num_in_top-3", descending=True)['method_name'][i])

all-MiniLM-L6-v2_title-transcript_euclidean
multi-qa-distilbert-cos-v1_title-transcript_euclidean
multi-qa-distilbert-cos-v1_title-transcript_cos-sim
multi-qa-distilbert-cos-v1_title-transcript_dot-score
