In [None]:
# Disable the notebook's periodic autosave (an interval of 0 turns it off).
%autosave 0

In [None]:
from cherche import retrieve
from sentence_transformers import SentenceTransformer, util
import pandas as pd
from tqdm.notebook import tqdm
tqdm.pandas()
import numpy as np
import os
import json

# evaluation variables

In [None]:
DB_PATH = "../dataset_ready/db_libraries.csv"
QUERIES_PATH = "../dataset_ready/queries_w_labels.csv"
# NOTE(review): GROUND_TRUTH is not referenced by the cells below — confirm it is still needed.
GROUND_TRUTH = "../"
KEYWORDS = "new_keywords.json"
MODEL_PATH = "./model/TripletLoss_uncased_iter5_sim_augmentation_codebert-2022-08-20_04-30-14"

# Report each local input by name so a bad path is spotted before the slow
# model-loading and encoding cells run. KEYWORDS and MODEL_PATH are also read
# from disk later, so they are checked too.
for path in (DB_PATH, QUERIES_PATH, KEYWORDS, MODEL_PATH):
    print(f"{path}: {'found' if os.path.exists(path) else 'MISSING'}")

# prepare index

In [None]:
# Keep only the two columns the index is built from: the library id and its directory name.
df = (
    pd.read_csv(DB_PATH)
    .loc[:, ["id", "dirname"]]
    .copy()
)

In [None]:
def generate_index(df):
    """Build the list of documents to index from the library table.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain ``id`` and ``dirname`` columns; any other columns are ignored.

    Returns
    -------
    list[dict]
        One ``{'id': ..., 'library': ...}`` record per row, in row order, where
        ``library`` is the lower-cased ``dirname``.
    """
    # Select the two columns by name instead of unpacking every column of
    # ``df.values``, so extra columns do not break the loop. ``df`` is only
    # read, so no defensive copy is needed.
    return [
        {'id': id_, 'library': dirname.lower()}
        for id_, dirname in zip(df['id'], df['dirname'])
    ]

In [None]:
# Turn the library table into retriever documents and preview the first five.
index_list = generate_index(df=df)
index_list[0:5]

# load model

In [None]:
# Load the sentence-encoder checkpoint stored locally at MODEL_PATH.
codebert = SentenceTransformer(model_name_or_path=MODEL_PATH)

In [None]:
# Name the embedding cache after the model folder so that pointing MODEL_PATH
# at another checkpoint never silently reuses embeddings computed by a
# different model. Make sure the cache directory exists before cherche writes to it.
os.makedirs("temp", exist_ok=True)
retriever = retrieve.Encoder(
    key="id",
    on="library",
    encoder=codebert.encode,
    # search_on_queries is later called with the same k and asserts that
    # every search returns exactly this many hits.
    k=10,
    path=os.path.join("temp", f"{os.path.basename(os.path.normpath(MODEL_PATH))}.pkl"),
)

In [None]:
# Add the library documents to the retriever; their `library` field is
# embedded with the encoder configured above.
# NOTE(review): re-running this cell may add the documents again — check cherche's Encoder.add.
retriever = retriever.add(
    documents=index_list,
)

# perform search on the queries

In [None]:
columns = ['truths_family', 'truths_serie']

# Read the ground-truth id columns as strings. Otherwise a column whose cells
# are all single ids (no "###") would be parsed as numbers, and `.split`
# below would fail on them.
df_queries = pd.read_csv(QUERIES_PATH, dtype={column: str for column in columns})
# Missing values become the "null" sentinel that the later cells test for.
df_queries = df_queries.fillna("null")

for column in columns:
    # "12###34" -> [12, 34]; the "null" sentinel is kept unchanged.
    df_queries[column] = df_queries[column].progress_apply(
        lambda x: [int(id_) for id_ in x.split("###")] if x != "null" else "null"
    )

In [None]:
# JSON text is UTF-8; set the encoding explicitly so the platform default
# (e.g. cp1252 on Windows) cannot garble non-ASCII keywords.
with open(KEYWORDS, "r", encoding="utf-8") as f:
    keywords = json.load(f)

# Assumes one entry per query row, in row order — TODO confirm against the
# script that produced KEYWORDS.
df_queries['keywords'] = keywords
# "a###b" -> ["a", "b"]. A "null" string, or a JSON null (None), marks a query
# without keywords and becomes the "null" sentinel.
df_queries['keywords'] = df_queries['keywords'].progress_apply(
    lambda x: x.split("###") if x is not None and x != "null" else "null"
)

In [None]:
def extract_series(x):
    """Return the mixed letter-and-digit tokens of a name.

    ``x`` is split into tokens on whitespace, ``-`` and ``_``. A token is kept
    when it is alphanumeric but neither purely alphabetic nor purely numeric:
    ``"x11"`` is kept, while ``"foo"`` and ``"42"`` are not.

    Returns the kept tokens in their original order. If no token qualifies,
    returns ``[x]`` (the unchanged input).
    """
    tokens = x.replace("-", " ").replace("_", " ").split()
    mixed = [
        tok for tok in tokens
        if tok.isalnum() and not (tok.isalpha() or tok.isdigit())
    ]
    return mixed or [x]

def search_on_queries(df, model, k):
    """Retrieve the top-``k`` library ids for every query row of ``df``.

    Each keyword of a query is searched independently with ``model``. The
    pooled hits are ranked by descending similarity, repeated ids are
    collapsed to their best-ranked occurrence, and the first ``k`` unique ids
    are kept.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain a ``keywords`` column whose cells are lists of keyword
        strings, or the sentinel ``"null"`` for a query without keywords.
        Other columns are ignored.
    model : callable
        Called as ``model(keyword)``. It must return exactly ``k`` dicts, each
        with ``'id'`` and ``'similarity'`` keys.
    k : int
        Number of hits expected per keyword search, and number of ids kept
        per query.

    Returns
    -------
    list[list]
        One ranked list of at most ``k`` distinct ids per row of ``df``, in
        row order. Rows without keywords get an empty list.

    Raises
    ------
    AssertionError
        If a keyword search does not return exactly ``k`` results.
    """
    preds = []
    # Read the keywords column by name rather than unpacking every column
    # positionally, so the function still works after extra columns
    # (e.g. ``preds``) have been added to the frame.
    for keywords in df["keywords"]:
        # Without this check, the "no keywords" sentinel would be iterated
        # character by character and "n", "u", "l", "l" sent as queries.
        if isinstance(keywords, str) and keywords == "null":
            keywords = []
        hits = []
        for keyword in keywords:
            results = model(keyword)
            assert len(results) == k, f"expected {k} results for {keyword!r}, got {len(results)}"
            hits.extend(results)
        # Stable sort: hits with equal similarity keep their retrieval order.
        hits.sort(key=lambda d: d['similarity'], reverse=True)
        # The same library can be retrieved for several keywords. Keep only
        # its best-ranked occurrence so the top-k holds k distinct libraries;
        # otherwise duplicates would take up slots and deflate precision@k.
        ranked_ids = list(dict.fromkeys(hit.get('id') for hit in hits))
        preds.append(ranked_ids[:k])
    return preds

In [None]:
# k must equal the retriever's k=10: search_on_queries asserts that every
# keyword search returns exactly k hits.
preds = search_on_queries(df=df_queries, model=retriever, k=10)

In [None]:
# Attach the ranked predictions (one list of library ids per query) to the query table.
df_queries = df_queries.assign(preds=preds)

# evaluate precision

In [None]:
def get_precision_family(x, k):
    """Precision@k of ``x.preds`` against family and series ground truth combined.

    ``x`` is a row holding ``preds`` (ranked ids) plus ``truths_family`` and
    ``truths_serie``. Each truth field is a list of ids, or the ``"null"``
    sentinel, which counts as no ids.

    Returns the number of distinct ids among the first ``k`` predictions that
    appear in either truth list, divided by ``k``.
    """
    truths = []
    for column in ("truths_family", "truths_serie"):
        value = x[column]
        if value != "null":
            truths = truths + value
    top_k = set(x.preds[:k])
    return len(top_k & set(truths)) / k

def get_precision_serie(x, k):
    """Precision@k of ``x.preds`` against the series-level ground truth only.

    ``x`` is a row holding ``preds`` (ranked ids) and ``truths_serie``, which
    is a list of ids or the ``"null"`` sentinel (treated as no ids).

    Returns the number of distinct ids among the first ``k`` predictions that
    are in ``truths_serie``, divided by ``k``.
    """
    relevant = [] if x["truths_serie"] == "null" else x["truths_serie"]
    top_k = set(x.preds[:k])
    return len(top_k.intersection(relevant)) / k

In [None]:
k_list = [1, 5, 10]
for k in k_list:
    # The column name must match what the reporting cell reads
    # (`precision_family_{k}`); storing it as `precision_{k}` made that cell
    # fail with a KeyError. `k=k` binds the current k to the lambda explicitly.
    df_queries[f"precision_family_{k}"] = df_queries.progress_apply(
        lambda row, k=k: get_precision_family(row, k), axis=1
    )

In [None]:
# Score only the queries that have family-level ground truth.
df_cp = df_queries.loc[df_queries["truths_family"] != "null"].copy()
n_scored = len(df_cp)
print(n_scored)
for k in k_list:
    # Mean precision@k over the scored queries.
    mean_precision = df_cp[f"precision_family_{k}"].sum() / n_scored
    print(f'precision@{k} family: {mean_precision}')