In [32]:
!pip install -q --upgrade pip setuptools wheel
!pip install --upgrade --no-cache-dir datasets



In [33]:
# Put this at the very top of your notebook (before any wandb or transformers imports)
import os
os.environ["WANDB_DISABLED"] = "true"


# Data Preprocessing

In [34]:
import pandas as pd
import re
import numpy as np
from sentence_transformers import SentenceTransformer
!pip install faiss-cpu
import torch
import torch.nn as nn
import math
import faiss
from sentence_transformers import SentenceTransformer
!pip install -U sentence-transformers peft faiss-cpu




In [35]:
!pip install -U gdown
import gdown

# 替换成你自己的文件 ID
file_id = '165wV72OUUHYDO3avmcrVOI2QGkoGTrL-'  # 来自你分享链接中的 d/XXX/
gdown.download(f"https://drive.google.com/uc?id={file_id}", "pubmed_metadata_sample_full.csv", quiet=False)




Downloading...
From (original): https://drive.google.com/uc?id=165wV72OUUHYDO3avmcrVOI2QGkoGTrL-
From (redirected): https://drive.google.com/uc?id=165wV72OUUHYDO3avmcrVOI2QGkoGTrL-&confirm=t&uuid=fc458172-9515-4d8b-8112-37459d26f4b7
To: /content/pubmed_metadata_sample_full.csv
100%|██████████| 294M/294M [00:02<00:00, 137MB/s]


'pubmed_metadata_sample_full.csv'

In [36]:
import json
import math
import numpy as np
import pandas as pd
import re
from sentence_transformers import SentenceTransformer
import faiss

##############################################
# 1. Preprocess the PubMed Metadata (Corpus)
##############################################

# Read the CSV file with columns: pmid, title, abstract, keywords
df = pd.read_csv("pubmed_metadata_sample_full.csv", usecols=[0, 1, 2, 3])
df.columns = ['pmid', 'title', 'abstract', 'keywords']
df = df.dropna(subset=['title', 'abstract'])

def clean_text(text):
    text = re.sub(r'\s+', ' ', text)  # merge whitespace
    text = re.sub(r'[^a-zA-Z0-9., ]', '', text)  # remove special characters
    return text.strip()

df['title'] = df['title'].apply(clean_text)
df['abstract'] = df['abstract'].apply(clean_text)
df['keywords'] = df['keywords'].fillna("").apply(lambda x: clean_text(x.lower()))
# Combine text fields as full text
df['full_text'] = df['title'] + " " + df['abstract'] + " " + df['keywords']
df.to_csv("cleaned_clinical_trials.csv", index=False)
print(f"✅ Cleaned dataset: {df.shape[0]} articles")

##############################################
# 2. Build Corpus from Cleaned CSV and Create Embeddings
##############################################
# We use the cleaned CSV file to create our corpus.
# Make sure 'pmid' is treated as integer.
df['pmid'] = df['pmid'].astype(int)

# Create a dictionary mapping pmid -> full_text
corpus_text = {row['pmid']: row['full_text'] for _, row in df.iterrows()}

# Build a list of PMIDs and texts (order matters for FAISS index)
all_pmids = list(corpus_text.keys())
all_texts = [corpus_text[pid] for pid in all_pmids]


✅ Cleaned dataset: 162360 articles


In [37]:
##############################################
# 3. Load RELISH Labels and Build Ground-Truth Mapping
##############################################
# The RELISH JSON file contains query PMIDs and their candidate relevance information.
# It is assumed that each entry has a 'pmid' and a 'response' field,
# where response contains lists under keys 'relevant', 'partial', and 'irrelevant'.

def load_labeled_data(json_file_path, num_entries=100):
    with open(json_file_path, 'r') as f:
        labeled_data = json.load(f)
    return labeled_data[:num_entries]

def extract_pmid_and_responses(labeled_data):
    queries = []
    for entry in labeled_data:
        pmid = entry['pmid']
        response = entry['response']
        queries.append({
            'pmid': pmid,
            'relevant': response.get('relevant', []),
            'partial': response.get('partial', []),
            'irrelevant': response.get('irrelevant', [])
        })
    return queries

# Update the file path as needed.
json_file_path = '/content/RELISH_v1.json'
labeled_data = load_labeled_data(json_file_path)
queries_list = extract_pmid_and_responses(labeled_data)

# Build ground_truth mapping: for each query pmid (as int), map candidate pmid -> relevance score
# We assign: fully relevant: 2, partial: 1, irrelevant: 0
ground_truth = {}  # {query_pmid: {candidate_pmid: score}}
for entry in queries_list:
    qid = int(entry['pmid'])
    ground_truth[qid] = {}
    for pmid in entry['relevant']:
        ground_truth[qid][int(pmid)] = 2
    for pmid in entry['partial']:
        # If a candidate already exists with score 2, keep it.
        ground_truth[qid][int(pmid)] = max(ground_truth[qid].get(int(pmid), 0), 1)
    for pmid in entry['irrelevant']:
        # irrelevant explicitly scored as 0 (optional, since absence can be treated as 0)
        ground_truth[qid][int(pmid)] = 0
total_pairs = sum(len(cand_dict) for cand_dict in ground_truth.values())
print(f"Total pairs：{total_pairs}")

positive_pairs = sum(
    sum(1 for score in cand_dict.values() if score >= 1)
    for cand_dict in ground_truth.values()
)
print(f"Total positive pairs（score>=1）：{positive_pairs}")



Total pairs：6000
Total positive pairs（score>=1）：4277


# Finetuning with LoRA

In [38]:
# 1) load SBERT and grab the HF backbone
from sentence_transformers import SentenceTransformer
sbert = SentenceTransformer('all-mpnet-base-v2')
transformer_mod = sbert._first_module()       # the sentence_transformers.models.Transformer
hf_model       = transformer_mod.auto_model   # the actual transformers.MPNetModel

# 2) freeze backbone (optional—get_peft_model does this for you)
hf_model.requires_grad_(False)

# 3) configure LoRA to target q/k/v/o projections
from peft import LoraConfig, TaskType
lora_cfg = LoraConfig(
    r=32,
    lora_alpha=16,
    target_modules=["attn.q", "attn.k", "attn.v", "attn.o"],
    lora_dropout=0.1,
    bias="none",
    task_type=TaskType.FEATURE_EXTRACTION,
)

# 4) inject and re-insert
from peft import get_peft_model
peft_model = get_peft_model(hf_model, lora_cfg)
transformer_mod.auto_model = peft_model

# 5) verify
peft_model.print_trainable_parameters()


trainable params: 2,359,296 || all params: 111,845,760 || trainable%: 2.1094


In [39]:
# 0. Install required libraries (run once)
!pip install -U peft transformers accelerate sentence-transformers

# 1. Imports
import torch
from peft import LoraConfig, get_peft_model, TaskType
from sentence_transformers import SentenceTransformer, InputExample, losses
from torch.utils.data import DataLoader

# 2. Load your SBERT model and grab its HuggingFace backbone
sbert_model     = SentenceTransformer('all-mpnet-base-v2')
transformer_mod = sbert_model._first_module()   # the sentence_transformers.models.Transformer
hf_model        = transformer_mod.auto_model    # the underlying transformers.MPNetModel

# 3. Freeze all original MPNet parameters (PEFT would do this automatically, but we enforce it)
hf_model.requires_grad_(False)

# 4. Configure and inject LoRA into the four self-attention projections
lora_config = LoraConfig(
    r=32,                             # LoRA rank
    lora_alpha=16,                   # LoRA scaling
    target_modules=["attn.q",        # matches encoder.layer.*.attention.attn.q
                    "attn.k",        # matches encoder.layer.*.attention.attn.k
                    "attn.v",        # matches encoder.layer.*.attention.attn.v
                    "attn.o"],       # matches encoder.layer.*.attention.attn.o
    lora_dropout=0.1,
    bias="none",
    task_type=TaskType.FEATURE_EXTRACTION
)
peft_model = get_peft_model(hf_model, lora_config)
peft_model.print_trainable_parameters()  # should report ~589,824 trainable params

# 5. Re-insert the LoRA-wrapped MPNetModel back into your SentenceTransformer
transformer_mod.auto_model = peft_model
sbert_model._modules["0"] = transformer_mod

# 6. Build your contrastive training examples (query vs. relevant document)
train_examples = []
for qid, cand_dict in ground_truth.items():
    if qid not in corpus_text:
        continue
    for pid, score in cand_dict.items():
        if score >= 1 and pid in corpus_text:
            train_examples.append(
                InputExample(texts=[corpus_text[qid], corpus_text[pid]])
            )
from sklearn.model_selection import train_test_split

# ratio： train:val:test = 8:2:1

train_and_val, test_examples = train_test_split(
    train_examples,
    test_size=1/11,
    random_state=42,
    shuffle=True
)

train_examples, val_examples = train_test_split(
    train_and_val,
    test_size=0.2,
    random_state=42,
    shuffle=True
)

print(f"▶️ Train examples: {len(train_examples)}")
print(f"▶️ Validation examples: {len(val_examples)}")
print(f"▶️ Test examples: {len(test_examples)}")

# 7. Create DataLoader and loss
train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=16)
train_loss       = losses.MultipleNegativesRankingLoss(sbert_model)
print(f"▶️ Total train examples: {len(train_examples)}")
print(f"▶️ Samples in DataLoader: {len(train_dataloader.dataset)}")
print(f"▶️ Number of batches: {len(train_dataloader)}")

# 8. Fine-tune your LoRA-augmented embedder
sbert_model.fit(
    train_objectives=[(train_dataloader, train_loss)],
    epochs=3,
    warmup_steps=100,
    output_path="./mpnet-lora-finetuned",
    use_amp=True,              # mixed-precision if supported
    show_progress_bar=True
)




Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).
Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).


trainable params: 2,359,296 || all params: 111,845,760 || trainable%: 2.1094
▶️ Train examples: 2904
▶️ Validation examples: 726
▶️ Test examples: 363
▶️ Total train examples: 2904
▶️ Samples in DataLoader: 2904
▶️ Number of batches: 182


Computing widget examples:   0%|          | 0/1 [00:00<?, ?example/s]

Step,Training Loss
500,0.2532


# Calculating MAP, MRR and NDCG

In [40]:
# after training:
corpus_embeddings = sbert_model.encode(all_texts, convert_to_numpy=True, show_progress_bar=True)
faiss.normalize_L2(corpus_embeddings)
index = faiss.IndexFlatIP(corpus_embeddings.shape[1])
index.add(corpus_embeddings)
index_to_pmid = {i: pmid for i, pmid in enumerate(all_pmids)}


Batches:   0%|          | 0/5074 [00:00<?, ?it/s]

In [41]:
##############################################
# 4. Recommendation Function
##############################################
def recommend_articles(query_title, query_abstract, query_keywords, top_n=5):
    """
    Compute query embedding from title, abstract, keywords and search FAISS index.
    Excludes the query itself if present.
    Returns a list of recommended PMIDs.
    """
    query_text = " ".join((query_title + " " + query_abstract + " " + query_keywords).split())
    # query_embedding = sbert_model.encode([query_text], convert_to_numpy=True)
    # query_embedding = encode_texts([query_text])  # returns a numpy array
    query_embedding = sbert_model.encode([query_text], convert_to_numpy=True)
    faiss.normalize_L2(query_embedding)
    # Retrieve more than top_n to allow filtering.
    D, I = index.search(query_embedding, top_n + 5)
    candidate_pmids = [index_to_pmid[int(idx)] for idx in I[0]]

    # Optionally, if the query article's PMID is known, filter it out.
    # Here, we do a simple heuristic: if the query text is very similar to a candidate's text, skip it.
    filtered = []
    for pid in candidate_pmids:
        # If the query is already in the corpus and the candidate text contains similar words, skip.
        # (Alternatively, if you know the query pmid, you can pass it in and filter exactly.)
        if query_title.lower() in corpus_text.get(pid, "").lower():
            continue
        filtered.append(pid)
        if len(filtered) == top_n:
            break
    return filtered


##############################################
# 5. Ranking Metrics Functions
##############################################
def average_precision_at_k(relevant_pmids, recommended_pmids, k):
    """
    Compute Average Precision at k.
    Treat any candidate with a relevance score >= 1 as relevant.
    """
    if not relevant_pmids:
        return 0.0
    relevant_set = set(relevant_pmids)
    num_relevant = 0.0
    ap_sum = 0.0
    for i, pid in enumerate(recommended_pmids[:k], start=1):
        # binary relevance: score >= 1 is relevant
        if pid in relevant_set:
            num_relevant += 1
            ap_sum += num_relevant / i
    return ap_sum / min(len(relevant_set), k)

def mean_average_precision(all_relevant_list, all_recommended_list, k):
    ap_scores = []
    for rels, recs in zip(all_relevant_list, all_recommended_list):
        ap = average_precision_at_k(rels, recs, k)
        ap_scores.append(ap)
    return np.mean(ap_scores) if ap_scores else 0.0

def reciprocal_rank(recommended_pmids, relevant_set):
    for i, pid in enumerate(recommended_pmids, start=1):
        if pid in relevant_set:
            return 1.0 / i
    return 0.0

def dcg_at_k(recommended_pmids, ground_truth_dict, k):
    dcg = 0.0
    for i, pid in enumerate(recommended_pmids[:k], start=1):
        # Use the graded relevance score (if missing, 0)
        score = ground_truth_dict.get(pid, 0)
        dcg += score / math.log2(i + 1)
    return dcg

def ndcg_at_k(recommended_pmids, ground_truth_dict, k):
    dcg = dcg_at_k(recommended_pmids, ground_truth_dict, k)
    # Ideal DCG: sort the relevance scores of the candidates in descending order.
    ideal_scores = sorted(ground_truth_dict.values(), reverse=True)[:k]
    idcg = sum(score / math.log2(i + 1) for i, score in enumerate(ideal_scores, start=1))
    return dcg / idcg if idcg > 0 else 0.0

In [42]:
##############################################
# 6. Evaluation Over Multiple Queries
##############################################
# Evaluate only queries that are in our ground_truth and also appear in our corpus.
query_ids = [qid for qid in ground_truth if qid in corpus_text]
K = 5

all_AP = []
all_RR = []
all_NDCG = []
per_query_results = {}

for qid in query_ids:
    # Get query text from corpus_text
    query_text = corpus_text[qid]
    # Here, we assume that the query's title, abstract and keywords can be recovered
    # by splitting or using the df if available. Otherwise, we use the full text.
    # For simplicity, we use the full_text from corpus.
    # In a real system, you would retrieve the original title, abstract, keywords.
    # Below, we simply split the full_text assuming the title is the first sentence.
    parts = query_text.split(".")
    query_title = parts[0] if parts else query_text
    # Use the remainder for abstract (keywords might be embedded)
    query_abstract = " ".join(parts[1:]) if len(parts) > 1 else ""
    query_keywords = ""  # If not separately available

    recommended_pmids = recommend_articles(query_title, query_abstract, query_keywords, top_n=K)
    per_query_results[qid] = recommended_pmids

    # For binary metrics (AP and RR), consider candidates with score>=1 as relevant.
    true_relevant_set = {pid for pid, score in ground_truth[qid].items() if score >= 1}

    ap = average_precision_at_k(list(true_relevant_set), recommended_pmids, K)
    rr = reciprocal_rank(recommended_pmids, true_relevant_set)
    ndcg = ndcg_at_k(recommended_pmids, ground_truth[qid], K)

    all_AP.append(ap)
    all_RR.append(rr)
    all_NDCG.append(ndcg)

MAP5 = np.mean(all_AP) * 100
MRR = np.mean(all_RR) * 100
NDCG5 = np.mean(all_NDCG) * 100

print(f"Overall MAP@5: {MAP5:.2f}%")
print(f"Overall MRR@5: {MRR:.2f}%")
print(f"Overall NDCG@5: {NDCG5:.2f}%")

# Optionally, print some per-query results.
for i, qid in enumerate(query_ids[:5]):
    print(f"\nQuery PMID: {qid}")
    print(f"Recommended PMIDs: {per_query_results[qid]}")
    binary_truth = [pid for pid, score in ground_truth[qid].items() if score >= 1]
    print(f"Ground truth relevant PMIDs: {binary_truth}")

Overall MAP@5: 77.85%
Overall MRR: 93.88%
Overall NDCG@5: 74.57%

Query PMID: 22569528
Recommended PMIDs: [18562239, 19282669, 22177953, 21730285, 19242111]
Ground truth relevant PMIDs: [17928366, 18562239, 19052640, 19060905, 19242111, 19244124, 19414607, 19805545, 19816936, 20079430, 20811985, 22028468, 22177953, 23549785, 23712012, 24089523, 25350931, 26235619, 27376062, 28474232, 29454854]

Query PMID: 23613754
Recommended PMIDs: [27924572, 25533345, 29304842, 26224636, 20675210]
Ground truth relevant PMIDs: [18818436, 20022960, 20675210, 22085933, 25533345, 25690936, 29061959, 29304842, 22307056]

Query PMID: 29409062
Recommended PMIDs: [23281855, 26355502, 21103052, 20487513, 16447990]
Ground truth relevant PMIDs: [18443018, 19772615, 22916718, 23281855, 24931993, 26355502, 28570104, 18593717, 19087303, 19237334, 20637083, 21609501, 21846404, 22080466, 22761950, 22927994, 22962469, 23229795, 23514199, 23868775, 24726865, 26455801, 27153661, 27506132, 27571416, 28113697, 28937982]

# Finetuning of paraphrase-mpnet-base-v2

In [43]:
# 1) load SBERT and grab the HF backbone
from sentence_transformers import SentenceTransformer
sbert = SentenceTransformer('paraphrase-mpnet-base-v2')
transformer_mod = sbert._first_module()       # the sentence_transformers.models.Transformer
hf_model       = transformer_mod.auto_model   # the actual transformers.MPNetModel

# 2) freeze backbone (optional—get_peft_model does this for you)
hf_model.requires_grad_(False)

# 3) configure LoRA to target q/k/v/o projections
from peft import LoraConfig, TaskType
lora_cfg = LoraConfig(
    r=32,
    lora_alpha=16,
    target_modules=["attn.q", "attn.k", "attn.v", "attn.o"],
    lora_dropout=0.1,
    bias="none",
    task_type=TaskType.FEATURE_EXTRACTION,
)

# 4) inject and re-insert
from peft import get_peft_model
peft_model = get_peft_model(hf_model, lora_cfg)
transformer_mod.auto_model = peft_model

# 5) verify
peft_model.print_trainable_parameters()


modules.json:   0%|          | 0.00/229 [00:00<?, ?B/s]

config_sentence_transformers.json:   0%|          | 0.00/122 [00:00<?, ?B/s]

README.md:   0%|          | 0.00/3.52k [00:00<?, ?B/s]

sentence_bert_config.json:   0%|          | 0.00/53.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/594 [00:00<?, ?B/s]

Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`


model.safetensors:   0%|          | 0.00/438M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/1.19k [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/239 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/190 [00:00<?, ?B/s]

trainable params: 2,359,296 || all params: 111,845,760 || trainable%: 2.1094


In [48]:
# 0. Install required libraries (run once)
!pip install -U peft transformers accelerate sentence-transformers

# 1. Imports
import torch
from peft import LoraConfig, get_peft_model, TaskType
from sentence_transformers import SentenceTransformer, InputExample, losses
from torch.utils.data import DataLoader

# 2. Load your SBERT model and grab its HuggingFace backbone
sbert_model     = SentenceTransformer('paraphrase-mpnet-base-v2')
transformer_mod = sbert_model._first_module()   # the sentence_transformers.models.Transformer
hf_model        = transformer_mod.auto_model    # the underlying transformers.MPNetModel

# 3. Freeze all original MPNet parameters (PEFT would do this automatically, but we enforce it)
hf_model.requires_grad_(False)

# 4. Configure and inject LoRA into the four self-attention projections
lora_config = LoraConfig(
    r=32,                             # LoRA rank
    lora_alpha=16,                   # LoRA scaling
    target_modules=["attn.q",        # matches encoder.layer.*.attention.attn.q
                    "attn.k",        # matches encoder.layer.*.attention.attn.k
                    "attn.v",        # matches encoder.layer.*.attention.attn.v
                    "attn.o"],       # matches encoder.layer.*.attention.attn.o
    lora_dropout=0.1,
    bias="none",
    task_type=TaskType.FEATURE_EXTRACTION
)
peft_model = get_peft_model(hf_model, lora_config)
peft_model.print_trainable_parameters()  # should report ~589,824 trainable params

# 5. Re-insert the LoRA-wrapped MPNetModel back into your SentenceTransformer
transformer_mod.auto_model = peft_model
sbert_model._modules["0"] = transformer_mod

# 6. Build your contrastive training examples (query vs. relevant document)
train_examples = []
for qid, cand_dict in ground_truth.items():
    if qid not in corpus_text:
        continue
    for pid, score in cand_dict.items():
        if score >= 1 and pid in corpus_text:
            train_examples.append(
                InputExample(texts=[corpus_text[qid], corpus_text[pid]])
            )
from sklearn.model_selection import train_test_split

# ratio： train:val:test = 8:2:1 ⇒ 总共 11 份

train_and_val, test_examples = train_test_split(
    train_examples,
    test_size=1/11,
    random_state=42,
    shuffle=True
)

train_examples, val_examples = train_test_split(
    train_and_val,
    test_size=0.2,
    random_state=42,
    shuffle=True
)

print(f"▶️ Train examples: {len(train_examples)}")
print(f"▶️ Validation examples: {len(val_examples)}")
print(f"▶️ Test examples: {len(test_examples)}")

# 7. Create DataLoader and loss
train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=16)
train_loss       = losses.MultipleNegativesRankingLoss(sbert_model)
print(f"▶️ Total train examples: {len(train_examples)}")
print(f"▶️ Samples in DataLoader: {len(train_dataloader.dataset)}")
print(f"▶️ Number of batches: {len(train_dataloader)}")

# 8. Fine-tune your LoRA-augmented embedder
sbert_model.fit(
    train_objectives=[(train_dataloader, train_loss)],
    epochs=3,
    warmup_steps=100,
    output_path="./mpnet-lora-finetuned",
    use_amp=True,              # mixed-precision if supported
    show_progress_bar=True
)




Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).
Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).


trainable params: 2,359,296 || all params: 111,845,760 || trainable%: 2.1094
▶️ Train examples: 2904
▶️ Validation examples: 726
▶️ Test examples: 363
▶️ Total train examples: 2904
▶️ Samples in DataLoader: 2904
▶️ Number of batches: 182


Computing widget examples:   0%|          | 0/1 [00:00<?, ?example/s]

Step,Training Loss
500,0.2887


In [49]:
# after training:
corpus_embeddings = sbert_model.encode(all_texts, convert_to_numpy=True, show_progress_bar=True)
faiss.normalize_L2(corpus_embeddings)
index = faiss.IndexFlatIP(corpus_embeddings.shape[1])
index.add(corpus_embeddings)
index_to_pmid = {i: pmid for i, pmid in enumerate(all_pmids)}


Batches:   0%|          | 0/5074 [00:00<?, ?it/s]

In [50]:
##############################################
# 4. Recommendation Function
##############################################
def recommend_articles(query_title, query_abstract, query_keywords, top_n=5):
    """
    Compute query embedding from title, abstract, keywords and search FAISS index.
    Excludes the query itself if present.
    Returns a list of recommended PMIDs.
    """
    query_text = " ".join((query_title + " " + query_abstract + " " + query_keywords).split())
    # query_embedding = sbert_model.encode([query_text], convert_to_numpy=True)
    # query_embedding = encode_texts([query_text])  # returns a numpy array
    query_embedding = sbert_model.encode([query_text], convert_to_numpy=True)
    faiss.normalize_L2(query_embedding)
    # Retrieve more than top_n to allow filtering.
    D, I = index.search(query_embedding, top_n + 5)
    candidate_pmids = [index_to_pmid[int(idx)] for idx in I[0]]

    # Optionally, if the query article's PMID is known, filter it out.
    # Here, we do a simple heuristic: if the query text is very similar to a candidate's text, skip it.
    filtered = []
    for pid in candidate_pmids:
        # If the query is already in the corpus and the candidate text contains similar words, skip.
        # (Alternatively, if you know the query pmid, you can pass it in and filter exactly.)
        if query_title.lower() in corpus_text.get(pid, "").lower():
            continue
        filtered.append(pid)
        if len(filtered) == top_n:
            break
    return filtered


##############################################
# 5. Ranking Metrics Functions
##############################################
def average_precision_at_k(relevant_pmids, recommended_pmids, k):
    """
    Compute Average Precision at k.
    Treat any candidate with a relevance score >= 1 as relevant.
    """
    if not relevant_pmids:
        return 0.0
    relevant_set = set(relevant_pmids)
    num_relevant = 0.0
    ap_sum = 0.0
    for i, pid in enumerate(recommended_pmids[:k], start=1):
        # binary relevance: score >= 1 is relevant
        if pid in relevant_set:
            num_relevant += 1
            ap_sum += num_relevant / i
    return ap_sum / min(len(relevant_set), k)

def mean_average_precision(all_relevant_list, all_recommended_list, k):
    ap_scores = []
    for rels, recs in zip(all_relevant_list, all_recommended_list):
        ap = average_precision_at_k(rels, recs, k)
        ap_scores.append(ap)
    return np.mean(ap_scores) if ap_scores else 0.0

def reciprocal_rank(recommended_pmids, relevant_set):
    for i, pid in enumerate(recommended_pmids, start=1):
        if pid in relevant_set:
            return 1.0 / i
    return 0.0

def dcg_at_k(recommended_pmids, ground_truth_dict, k):
    dcg = 0.0
    for i, pid in enumerate(recommended_pmids[:k], start=1):
        # Use the graded relevance score (if missing, 0)
        score = ground_truth_dict.get(pid, 0)
        dcg += score / math.log2(i + 1)
    return dcg

def ndcg_at_k(recommended_pmids, ground_truth_dict, k):
    dcg = dcg_at_k(recommended_pmids, ground_truth_dict, k)
    # Ideal DCG: sort the relevance scores of the candidates in descending order.
    ideal_scores = sorted(ground_truth_dict.values(), reverse=True)[:k]
    idcg = sum(score / math.log2(i + 1) for i, score in enumerate(ideal_scores, start=1))
    return dcg / idcg if idcg > 0 else 0.0

In [51]:
##############################################
# 6. Evaluation Over Multiple Queries
##############################################
# Evaluate only queries that are in our ground_truth and also appear in our corpus.
query_ids = [qid for qid in ground_truth if qid in corpus_text]
K = 5

all_AP = []
all_RR = []
all_NDCG = []
per_query_results = {}

for qid in query_ids:
    # Get query text from corpus_text
    query_text = corpus_text[qid]
    # Here, we assume that the query's title, abstract and keywords can be recovered
    # by splitting or using the df if available. Otherwise, we use the full text.
    # For simplicity, we use the full_text from corpus.
    # In a real system, you would retrieve the original title, abstract, keywords.
    # Below, we simply split the full_text assuming the title is the first sentence.
    parts = query_text.split(".")
    query_title = parts[0] if parts else query_text
    # Use the remainder for abstract (keywords might be embedded)
    query_abstract = " ".join(parts[1:]) if len(parts) > 1 else ""
    query_keywords = ""  # If not separately available

    recommended_pmids = recommend_articles(query_title, query_abstract, query_keywords, top_n=K)
    per_query_results[qid] = recommended_pmids

    # For binary metrics (AP and RR), consider candidates with score>=1 as relevant.
    true_relevant_set = {pid for pid, score in ground_truth[qid].items() if score >= 1}

    ap = average_precision_at_k(list(true_relevant_set), recommended_pmids, K)
    rr = reciprocal_rank(recommended_pmids, true_relevant_set)
    ndcg = ndcg_at_k(recommended_pmids, ground_truth[qid], K)

    all_AP.append(ap)
    all_RR.append(rr)
    all_NDCG.append(ndcg)

MAP5 = np.mean(all_AP) * 100
MRR = np.mean(all_RR) * 100
NDCG5 = np.mean(all_NDCG) * 100

print(f"Overall MAP@5 of paraphrase-mpnet-base-v2: {MAP5:.2f}%")
print(f"Overall MRR@5 of paraphrase-mpnet-base-v2: {MRR:.2f}%")
print(f"Overall NDCG@5 of paraphrase-mpnet-base-v2: {NDCG5:.2f}%")

# Optionally, print some per-query results.
for i, qid in enumerate(query_ids[:5]):
    print(f"\nQuery PMID: {qid}")
    print(f"Recommended PMIDs: {per_query_results[qid]}")
    binary_truth = [pid for pid, score in ground_truth[qid].items() if score >= 1]
    print(f"Ground truth relevant PMIDs: {binary_truth}")

Overall MAP@5: 79.10%
Overall MRR@5: 93.09%
Overall NDCG@5: 74.75%

Query PMID: 22569528
Recommended PMIDs: [22177953, 19805545, 22028468, 18562239, 24434059]
Ground truth relevant PMIDs: [17928366, 18562239, 19052640, 19060905, 19242111, 19244124, 19414607, 19805545, 19816936, 20079430, 20811985, 22028468, 22177953, 23549785, 23712012, 24089523, 25350931, 26235619, 27376062, 28474232, 29454854]

Query PMID: 23613754
Recommended PMIDs: [25533345, 22307056, 20675210, 27924572, 29304842]
Ground truth relevant PMIDs: [18818436, 20022960, 20675210, 22085933, 25533345, 25690936, 29061959, 29304842, 22307056]

Query PMID: 29409062
Recommended PMIDs: [23281855, 25045691, 26455801, 25705652, 18947876]
Ground truth relevant PMIDs: [18443018, 19772615, 22916718, 23281855, 24931993, 26355502, 28570104, 18593717, 19087303, 19237334, 20637083, 21609501, 21846404, 22080466, 22761950, 22927994, 22962469, 23229795, 23514199, 23868775, 24726865, 26455801, 27153661, 27506132, 27571416, 28113697, 2893798

In [52]:
print(f"Overall MAP@5 of paraphrase-mpnet-base-v2: {MAP5:.2f}%")
print(f"Overall MRR@5 of paraphrase-mpnet-base-v2: {MRR:.2f}%")
print(f"Overall NDCG@5 of paraphrase-mpnet-base-v2: {NDCG5:.2f}%")

Overall MAP@5 of paraphrase-mpnet-base-v2: 79.10%
Overall MRR@5 of paraphrase-mpnet-base-v2: 93.09%
Overall NDCG@5 of paraphrase-mpnet-base-v2: 74.75%


# Finetuning with all-MiniLM-L6-v2

In [55]:
# 1) load SBERT and grab the HF backbone
from sentence_transformers import SentenceTransformer
sbert = SentenceTransformer('all-MiniLM-L6-v2')
transformer_mod = sbert._first_module()       # the sentence_transformers.models.Transformer
hf_model       = transformer_mod.auto_model   # the actual transformers.MPNetModel

# 2) freeze backbone (optional—get_peft_model does this for you)
hf_model.requires_grad_(False)

from peft import LoraConfig, TaskType

lora_cfg = LoraConfig(
    r=8,
    lora_alpha=16,
    target_modules=["self.query", "self.key", "self.value", "output.dense"],
    lora_dropout=0.1,
    bias="none",
    task_type=TaskType.FEATURE_EXTRACTION,
)


# 4) inject and re-insert
from peft import get_peft_model
peft_model = get_peft_model(hf_model, lora_cfg)
transformer_mod.auto_model = peft_model

# 5) verify
peft_model.print_trainable_parameters()


trainable params: 958,464 || all params: 23,671,680 || trainable%: 4.0490


In [54]:
for name, module in hf_model.named_modules():
    if isinstance(module, torch.nn.Linear) and "attention" in name:
        print(name)


encoder.layer.0.attention.self.query
encoder.layer.0.attention.self.key
encoder.layer.0.attention.self.value
encoder.layer.0.attention.output.dense
encoder.layer.1.attention.self.query
encoder.layer.1.attention.self.key
encoder.layer.1.attention.self.value
encoder.layer.1.attention.output.dense
encoder.layer.2.attention.self.query
encoder.layer.2.attention.self.key
encoder.layer.2.attention.self.value
encoder.layer.2.attention.output.dense
encoder.layer.3.attention.self.query
encoder.layer.3.attention.self.key
encoder.layer.3.attention.self.value
encoder.layer.3.attention.output.dense
encoder.layer.4.attention.self.query
encoder.layer.4.attention.self.key
encoder.layer.4.attention.self.value
encoder.layer.4.attention.output.dense
encoder.layer.5.attention.self.query
encoder.layer.5.attention.self.key
encoder.layer.5.attention.self.value
encoder.layer.5.attention.output.dense


In [59]:
# 0. Install required libraries (run once)
!pip install -U peft transformers accelerate sentence-transformers
# 1. Imports
import torch
from peft import LoraConfig, get_peft_model, TaskType
from sentence_transformers import SentenceTransformer, InputExample, losses
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split

# 2. Load SBERT + HF backbone
sbert_model     = SentenceTransformer('all-MiniLM-L6-v2')
transformer_mod = sbert_model._first_module()   # sentence_transformers.models.Transformer
hf_model        = transformer_mod.auto_model    # transformers.MPNetModel

# 3. Freeze original MPNet parameters
hf_model.requires_grad_(False)

# 4. Configure LoRA and inject
lora_cfg = LoraConfig(
    r=8,
    lora_alpha=16,                     # scaling
    target_modules=[
        "attention.self.query",        # matches .attention.self.query
        "attention.self.key",          # matches .attention.self.key
        "attention.self.value",        # matches .attention.self.value
        "attention.output.dense"       # matches .attention.output.dense
    ],
    lora_dropout=0.1,
    bias="none",
    task_type=TaskType.FEATURE_EXTRACTION,
)

# ⚠️ Pass the correct config object here:
peft_model = get_peft_model(hf_model, lora_cfg)
peft_model.print_trainable_parameters()  # ≈589,824 parameters

# 5. Put it back into your SentenceTransformer
transformer_mod.auto_model      = peft_model
sbert_model._modules["0"]       = transformer_mod

# 6. Build contrastive examples
train_examples = []
for qid, cand_dict in ground_truth.items():
    if qid not in corpus_text:
        continue
    for pid, score in cand_dict.items():
        if score >= 1 and pid in corpus_text:
            train_examples.append(InputExample(
                texts=[corpus_text[qid], corpus_text[pid]]
            ))

# 8:2:1 split
train_and_val, test_examples = train_test_split(train_examples, test_size=1/11, random_state=42)
train_examples, val_examples = train_test_split(train_and_val, test_size=0.2, random_state=42)

print(f"▶️ Train: {len(train_examples)}, Val: {len(val_examples)}, Test: {len(test_examples)}")

# 7. DataLoader + loss
train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=16)
train_loss = losses.MultipleNegativesRankingLoss(sbert_model)

print(f"▶️ Total train examples: {len(train_examples)}")
print(f"▶️ Batches: {len(train_dataloader)}")

# 8. Fine‐tune
sbert_model.fit(
    train_objectives=[(train_dataloader, train_loss)],
    epochs=3,
    warmup_steps=100,
    output_path="./mpnet-lora-finetuned",
    use_amp=True,
    show_progress_bar=True
)




Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).
Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).


trainable params: 147,456 || all params: 22,860,672 || trainable%: 0.6450
▶️ Train: 2904, Val: 726, Test: 363
▶️ Total train examples: 2904
▶️ Batches: 182


Computing widget examples:   0%|          | 0/1 [00:00<?, ?example/s]

Step,Training Loss
500,0.2587


In [60]:
# after training:
corpus_embeddings = sbert_model.encode(all_texts, convert_to_numpy=True, show_progress_bar=True)
faiss.normalize_L2(corpus_embeddings)
index = faiss.IndexFlatIP(corpus_embeddings.shape[1])
index.add(corpus_embeddings)
index_to_pmid = {i: pmid for i, pmid in enumerate(all_pmids)}


Batches:   0%|          | 0/5074 [00:00<?, ?it/s]

In [61]:
##############################################
# 4. Recommendation Function
##############################################
def recommend_articles(query_title, query_abstract, query_keywords, top_n=5):
    """
    Compute query embedding from title, abstract, keywords and search FAISS index.
    Excludes the query itself if present.
    Returns a list of recommended PMIDs.
    """
    query_text = " ".join((query_title + " " + query_abstract + " " + query_keywords).split())
    # query_embedding = sbert_model.encode([query_text], convert_to_numpy=True)
    # query_embedding = encode_texts([query_text])  # returns a numpy array
    query_embedding = sbert_model.encode([query_text], convert_to_numpy=True)
    faiss.normalize_L2(query_embedding)
    # Retrieve more than top_n to allow filtering.
    D, I = index.search(query_embedding, top_n + 5)
    candidate_pmids = [index_to_pmid[int(idx)] for idx in I[0]]

    # Optionally, if the query article's PMID is known, filter it out.
    # Here, we do a simple heuristic: if the query text is very similar to a candidate's text, skip it.
    filtered = []
    for pid in candidate_pmids:
        # If the query is already in the corpus and the candidate text contains similar words, skip.
        # (Alternatively, if you know the query pmid, you can pass it in and filter exactly.)
        if query_title.lower() in corpus_text.get(pid, "").lower():
            continue
        filtered.append(pid)
        if len(filtered) == top_n:
            break
    return filtered


##############################################
# 5. Ranking Metrics Functions
##############################################
def average_precision_at_k(relevant_pmids, recommended_pmids, k):
    """
    Compute Average Precision at k.
    Treat any candidate with a relevance score >= 1 as relevant.
    """
    if not relevant_pmids:
        return 0.0
    relevant_set = set(relevant_pmids)
    num_relevant = 0.0
    ap_sum = 0.0
    for i, pid in enumerate(recommended_pmids[:k], start=1):
        # binary relevance: score >= 1 is relevant
        if pid in relevant_set:
            num_relevant += 1
            ap_sum += num_relevant / i
    return ap_sum / min(len(relevant_set), k)

def mean_average_precision(all_relevant_list, all_recommended_list, k):
    ap_scores = []
    for rels, recs in zip(all_relevant_list, all_recommended_list):
        ap = average_precision_at_k(rels, recs, k)
        ap_scores.append(ap)
    return np.mean(ap_scores) if ap_scores else 0.0

def reciprocal_rank(recommended_pmids, relevant_set):
    for i, pid in enumerate(recommended_pmids, start=1):
        if pid in relevant_set:
            return 1.0 / i
    return 0.0

def dcg_at_k(recommended_pmids, ground_truth_dict, k):
    dcg = 0.0
    for i, pid in enumerate(recommended_pmids[:k], start=1):
        # Use the graded relevance score (if missing, 0)
        score = ground_truth_dict.get(pid, 0)
        dcg += score / math.log2(i + 1)
    return dcg

def ndcg_at_k(recommended_pmids, ground_truth_dict, k):
    dcg = dcg_at_k(recommended_pmids, ground_truth_dict, k)
    # Ideal DCG: sort the relevance scores of the candidates in descending order.
    ideal_scores = sorted(ground_truth_dict.values(), reverse=True)[:k]
    idcg = sum(score / math.log2(i + 1) for i, score in enumerate(ideal_scores, start=1))
    return dcg / idcg if idcg > 0 else 0.0

In [62]:
##############################################
# 6. Evaluation Over Multiple Queries
##############################################
# Evaluate only queries that are in our ground_truth and also appear in our corpus.
query_ids = [qid for qid in ground_truth if qid in corpus_text]
K = 5

all_AP = []
all_RR = []
all_NDCG = []
per_query_results = {}

for qid in query_ids:
    # Get query text from corpus_text
    query_text = corpus_text[qid]
    # Here, we assume that the query's title, abstract and keywords can be recovered
    # by splitting or using the df if available. Otherwise, we use the full text.
    # For simplicity, we use the full_text from corpus.
    # In a real system, you would retrieve the original title, abstract, keywords.
    # Below, we simply split the full_text assuming the title is the first sentence.
    parts = query_text.split(".")
    query_title = parts[0] if parts else query_text
    # Use the remainder for abstract (keywords might be embedded)
    query_abstract = " ".join(parts[1:]) if len(parts) > 1 else ""
    query_keywords = ""  # If not separately available

    recommended_pmids = recommend_articles(query_title, query_abstract, query_keywords, top_n=K)
    per_query_results[qid] = recommended_pmids

    # For binary metrics (AP and RR), consider candidates with score>=1 as relevant.
    true_relevant_set = {pid for pid, score in ground_truth[qid].items() if score >= 1}

    ap = average_precision_at_k(list(true_relevant_set), recommended_pmids, K)
    rr = reciprocal_rank(recommended_pmids, true_relevant_set)
    ndcg = ndcg_at_k(recommended_pmids, ground_truth[qid], K)

    all_AP.append(ap)
    all_RR.append(rr)
    all_NDCG.append(ndcg)

MAP5 = np.mean(all_AP) * 100
MRR = np.mean(all_RR) * 100
NDCG5 = np.mean(all_NDCG) * 100

print(f"Overall MAP@5 of all-MiniLM-L6-v2: {MAP5:.2f}%")
print(f"Overall MRR@5 of all-MiniLM-L6-v2: {MRR:.2f}%")
print(f"Overall NDCG@5 of all-MiniLM-L6-v2: {NDCG5:.2f}%")

# Optionally, print some per-query results.
for i, qid in enumerate(query_ids[:5]):
    print(f"\nQuery PMID: {qid}")
    print(f"Recommended PMIDs: {per_query_results[qid]}")
    binary_truth = [pid for pid, score in ground_truth[qid].items() if score >= 1]
    print(f"Ground truth relevant PMIDs: {binary_truth}")

Overall MAP@5 of paraphrase-mpnet-base-v2: 79.94%
Overall MRR@5 of paraphrase-mpnet-base-v2: 96.10%
Overall NDCG@5 of paraphrase-mpnet-base-v2: 75.11%

Query PMID: 22569528
Recommended PMIDs: [18562239, 23712012, 19805545, 24434059, 22065579]
Ground truth relevant PMIDs: [17928366, 18562239, 19052640, 19060905, 19242111, 19244124, 19414607, 19805545, 19816936, 20079430, 20811985, 22028468, 22177953, 23549785, 23712012, 24089523, 25350931, 26235619, 27376062, 28474232, 29454854]

Query PMID: 23613754
Recommended PMIDs: [25533345, 27924572, 29304842, 20675210, 18818436]
Ground truth relevant PMIDs: [18818436, 20022960, 20675210, 22085933, 25533345, 25690936, 29061959, 29304842, 22307056]

Query PMID: 29409062
Recommended PMIDs: [23281855, 26455801, 20487513, 25045691, 26342231]
Ground truth relevant PMIDs: [18443018, 19772615, 22916718, 23281855, 24931993, 26355502, 28570104, 18593717, 19087303, 19237334, 20637083, 21609501, 21846404, 22080466, 22761950, 22927994, 22962469, 23229795, 235

In [63]:
print(f"Overall MAP@5 of all-MiniLM-L6-v2: {MAP5:.2f}%")
print(f"Overall MRR@5 of all-MiniLM-L6-v2: {MRR:.2f}%")
print(f"Overall NDCG@5 of all-MiniLM-L6-v2: {NDCG5:.2f}%")

Overall MAP@5 of all-MiniLM-L6-v2: 79.94%
Overall MRR@5 of all-MiniLM-L6-v2: 96.10%
Overall NDCG@5 of all-MiniLM-L6-v2: 75.11%
