In [1]:
!git clone https://github.com/szeighami/nudge.git
%cd nudge

fatal: destination path 'nudge' already exists and is not an empty directory.
/mnt/e/ResiloSync/notebook/embedding_finetuning/finetune/nudge


In [2]:
import sys
sys.path.append('../../')

In [3]:
from eval.dataset import RAGDataset


In [4]:
rag_ds = RAGDataset.from_file('../../data/infgrad_retrieval_data_llm.json')

In [5]:
from nudge import NUDGEM, NUDGEN
from util.knnretriever import kNNRetriever
from util.utils import calc_metrics_batch, load_hf_datasets, embed_data_and_query_sets

  from tqdm.autonotebook import tqdm, trange


In [8]:
import pandas as pd
corpus_keys = list(rag_ds.corpus.keys())
corpus_key_to_id = {}
corpus = []
for id, key in enumerate(corpus_keys[:10000]):
    corpus_key_to_id[key] = id
    corpus.append(dict(
        doc_id=key,
        text=rag_ds.corpus[key],
        passage_id=0,
        record_id=id
    ))

dataset = pd.DataFrame.from_records(corpus)

In [10]:
query_sets = {}
for split in rag_ds.queries_split:
    query_sets[split] = {}
    q_df = []
    q_ans_indx = []
    for id, query_key in enumerate(rag_ds.queries_split[split]):
        rel_docs = rag_ds.relevant_docs[query_key]
        rel_doc_ids = [corpus_key_to_id[doc_id] for doc_id in rel_docs if doc_id in corpus_key_to_id]
        if not rel_doc_ids:
            continue
        q_ans_indx.append(rel_doc_ids)
        q_df.append(dict(
            q_id=id,
            input=rag_ds.queries[query_key],
        ))
    query_sets[split]['q_df'] = pd.DataFrame.from_records(q_df)
    query_sets[split]['q_ans_indx'] = q_ans_indx

query_sets['test'] = query_sets['val']
del query_sets['val']

# split train to train and dev
# 将训练集分割为训练集和开发集
from sklearn.model_selection import train_test_split

train_q_df, dev_q_df, train_ans_indx, dev_ans_indx = train_test_split(
    query_sets['train']['q_df'], 
    query_sets['train']['q_ans_indx'],
    test_size=0.5,  # 50% 作为开发集
    random_state=42  # 设置随机种子以确保可重复性
)


query_sets['train'] = {
    'q_df': train_q_df,
    'q_ans_indx': train_ans_indx
}

query_sets['dev'] = {
    'q_df': dev_q_df,
    'q_ans_indx': dev_ans_indx
}

In [14]:
len(query_sets['dev']['q_ans_indx'])

2480

In [15]:
data_emb, query_sets = embed_data_and_query_sets(dataset, query_sets, "BAAI/bge-small-zh-v1.5")

embedding data


Batches: 100%|██████████| 313/313 [00:02<00:00, 130.61it/s]


embedding qs train


Batches: 100%|██████████| 78/78 [00:00<00:00, 229.66it/s]


embedding qs dev


Batches: 100%|██████████| 78/78 [00:00<00:00, 222.43it/s]


embedding qs test


Batches: 100%|██████████| 2/2 [00:00<00:00, 198.78it/s]


In [16]:
nudgen =  NUDGEN()
new_embs_nudgen = nudgen.finetune_embeddings(data_emb, query_sets['train'], query_sets['dev'])
nudge_nret = kNNRetriever(new_embs_nudgen)
nudge_n_res = nudge_nret.retrieve_topk_from_emb_batch(k=10, q_embeds=query_sets['test']['q_embs'])

Calculating G
Finding gamma


In [17]:
nudgem =  NUDGEM()
new_embs_nudgem = nudgem.finetune_embeddings(data_emb, query_sets['train'], query_sets['dev'])
nudge_mret = kNNRetriever(new_embs_nudgem, dist_metric='dot')
nudge_m_res = nudge_mret.retrieve_topk_from_emb_batch(k=10, q_embeds=query_sets['test']['q_embs'])

Calculating G
Finding gamma


In [18]:
no_ft_ret = kNNRetriever(data_emb)
no_ft_res = no_ft_ret.retrieve_topk_from_emb_batch(k=10, q_embeds=query_sets['test']['q_embs'])

In [23]:
metrics = [('recall',10), ('ndcg',10)]
no_ft_accs = calc_metrics_batch(metrics,no_ft_res, query_sets['test']['q_ans_indx'])
nudgem_accs = calc_metrics_batch(metrics,nudge_m_res, query_sets['test']['q_ans_indx'])
nudgen_accs = calc_metrics_batch(metrics,nudge_n_res, query_sets['test']['q_ans_indx'])
print(f"No Fine-Tuning {metrics[0][0]}@{metrics[0][1]}: {no_ft_accs[0]*100:.1f}, {metrics[1][0]}@{metrics[1][1]}: {no_ft_accs[1]*100:.1f}")
print(f"NUDGE-M {metrics[0][0]}@{metrics[0][1]}: {nudgem_accs[0]*100:.1f}, {metrics[1][0]}@{metrics[1][1]}: {nudgem_accs[1]*100:.1f}")
print(f"NUDGE-N {metrics[0][0]}@{metrics[0][1]}: {nudgen_accs[0]*100:.1f}, {metrics[1][0]}@{metrics[1][1]}: {nudgen_accs[1]*100:.1f}")

No Fine-Tuning recall@10: 95.6, ndcg@10: 90.7
NUDGE-M recall@10: 95.6, ndcg@10: 90.7
NUDGE-N recall@10: 95.6, ndcg@10: 90.7
