# Standard data

# Preprocess

In [None]:
import json
import math
import random
import re

import numpy as np
import pandas as pd
from sentence_transformers import SentenceTransformer


In [None]:
from google.colab import drive
# Mount Google Drive so later cells can read files under /content/drive/MyDrive/CPSC577
# (the RELISH label JSON and the fine-tuned model directory).
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
!pip install -U gdown
import gdown


file_id = '165wV72OUUHYDO3avmcrVOI2QGkoGTrL-'
gdown.download(f"https://drive.google.com/uc?id={file_id}", "pubmed_metadata_sample_full.csv", quiet=False)




Downloading...
From (original): https://drive.google.com/uc?id=165wV72OUUHYDO3avmcrVOI2QGkoGTrL-
From (redirected): https://drive.google.com/uc?id=165wV72OUUHYDO3avmcrVOI2QGkoGTrL-&confirm=t&uuid=50312bcb-02e4-45a2-804a-2e6864f98b4a
To: /content/pubmed_metadata_sample_full.csv
100%|██████████| 294M/294M [00:09<00:00, 32.1MB/s]


'pubmed_metadata_sample_full.csv'

In [None]:
import pandas as pd
import re

# Keep the first four CSV columns, give them stable names, and drop rows that
# lack a title or an abstract (both are required to build the document text).
df = (
    pd.read_csv("pubmed_metadata_sample_full.csv", usecols=[0, 1, 2, 3])
    .set_axis(['pmid', 'title', 'abstract', 'keywords'], axis=1)
    .dropna(subset=['title', 'abstract'])
)

def clean_text(text):
    """Normalize a free-text field before embedding.

    Steps:
      1. Collapse every whitespace run (spaces, newlines, tabs) to one space.
      2. Drop every character other than ASCII letters, digits, '.', ',' and space.
      3. Collapse space runs again, because step 2 can leave them behind
         (e.g. "a - b" -> "a  b").

    Args:
        text: Input string.

    Returns:
        The cleaned string without leading or trailing spaces.
    """
    text = re.sub(r'\s+', ' ', text)
    text = re.sub(r'[^a-zA-Z0-9., ]', '', text)
    text = re.sub(r' {2,}', ' ', text)
    return text.strip()


# Clean each text field. Missing keywords become "" and keywords are lowercased before cleaning.
df['title'] = df['title'].apply(clean_text)
df['abstract'] = df['abstract'].apply(clean_text)
df['keywords'] = df['keywords'].fillna("").apply(lambda kw: clean_text(kw.lower()))

# Document text fed to the encoder: "title abstract keywords", space-separated.
df['full_text'] = df['title'].str.cat([df['abstract'], df['keywords']], sep=" ")

df.to_csv("cleaned_clinical_trials.csv", index=False)
print(f"✅ Cleaned dataset: {len(df)} articles")


✅ Cleaned dataset: 162360 articles


In [None]:
# Preview the cleaned frame. Index gaps (e.g. no row 1) come from the dropna in the load step.
df.head()

Unnamed: 0,pmid,title,abstract,keywords,full_text
0,19082600,"The ornamental variety, Japanese striped corn,...","Phenylalanine ammonialyase PAL, EC 4.3.1.24 fo...",anthocyanins enzyme stability freeze drying ge...,"The ornamental variety, Japanese striped corn,..."
2,23790829,Toxicological characterization of the landfill...,"In this research, toxicological safety of two ...",ao apdc allium cepa bod cbmn cod chemical trea...,Toxicological characterization of the landfill...
3,25174527,Geographic differences in the distribution of ...,To compare the distribution of the intrinsic m...,"adolescent adult age factors aged aged, 80 and...",Geographic differences in the distribution of ...
4,18493761,Phase I dose escalation study of docetaxel wit...,The primary objectives of this study were to e...,adult aged antineoplastic combined chemotherap...,Phase I dose escalation study of docetaxel wit...
5,29643479,Use of statins and the risk of dementia and mi...,We conducted a systematic review and metaanaly...,cognitive dysfunction dementia humans hydroxym...,Use of statins and the risk of dementia and mi...


# SBERT for embeddings

# FAISS (fast similarity search)

In [None]:
pip install faiss-cpu

Collecting faiss-cpu
  Downloading faiss_cpu-1.10.0-cp311-cp311-manylinux_2_28_x86_64.whl.metadata (4.4 kB)
Downloading faiss_cpu-1.10.0-cp311-cp311-manylinux_2_28_x86_64.whl (30.7 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m30.7/30.7 MB[0m [31m57.4 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: faiss-cpu
Successfully installed faiss-cpu-1.10.0


# Fine-tuned SBERT: retrieval evaluation on RELISH

In [None]:
# Load RELISH labels and make a seeded 80/10/10 train/val/test split of the queries

def load_labeled_data(json_file_path):
    """Load the RELISH annotation file.

    The file is decoded explicitly as UTF-8, so loading does not depend on the
    machine's locale encoding.

    Args:
        json_file_path: Path to the RELISH JSON file.

    Returns:
        The parsed JSON document. The cells below treat it as a list of entries,
        each carrying a 'pmid' and a 'response' mapping.
    """
    with open(json_file_path, 'r', encoding='utf-8') as f:
        return json.load(f)

def extract_pmid_and_responses(labeled_data):
    """Flatten RELISH entries into one judgement record per query.

    Args:
        labeled_data: Iterable of entries, each with a 'pmid' and a 'response'
            mapping that may hold 'relevant', 'partial' and 'irrelevant' lists.

    Returns:
        List of dicts with keys 'pmid', 'relevant', 'partial', 'irrelevant'.
        Any judgement list absent from the response becomes [].
    """
    def _record(item):
        # Read the query id before the response, matching the original access order.
        qid = item['pmid']
        judgements = item['response']
        return {
            'pmid': qid,
            'relevant': judgements.get('relevant', []),
            'partial': judgements.get('partial', []),
            'irrelevant': judgements.get('irrelevant', []),
        }

    return [_record(item) for item in labeled_data]

# Update the file path as needed.
json_file_path = '/content/drive/MyDrive/CPSC577/RELISH_v1.json'
labeled_data = load_labeled_data(json_file_path)
# Flattened records are taken before shuffling, so they keep the file order.
queries_list = extract_pmid_and_responses(labeled_data)

# Seeded in-place shuffle, then an 80/10/10 train/val/test split over queries.
random.seed(42)
random.shuffle(labeled_data)
n = len(labeled_data)
train_end, val_end = int(0.8 * n), int(0.9 * n)
train_data = labeled_data[:train_end]
val_data = labeled_data[train_end:val_end]
test_data = labeled_data[val_end:]

In [None]:
from sentence_transformers import SentenceTransformer

import math

import faiss

# Fine-tuned SBERT checkpoint stored on Drive.
model_path = '/content/drive/MyDrive/CPSC577/Finetune'
model = SentenceTransformer(model_path)

# Corpus order defines the FAISS row ids: row i corresponds to all_pmids[i].
all_pmids = df['pmid'].astype(int).tolist()
all_texts = df['full_text'].tolist()
pmid_to_idx = dict(zip(all_pmids, range(len(all_pmids))))

# Unit-normalize the embeddings in place, so inner product equals cosine similarity.
embeddings = model.encode(all_texts, show_progress_bar=True, convert_to_numpy=True)
faiss.normalize_L2(embeddings)

# Brute-force inner-product index over every corpus embedding.
dim = embeddings.shape[1]
index = faiss.IndexFlatIP(dim)
index.add(embeddings)


def recommend(pmid, top_k):
    """Return up to ``top_k`` corpus PMIDs most similar to ``pmid``.

    The FAISS index is searched with the query's own normalized embedding,
    asking for one extra hit so the query itself can be dropped.

    Args:
        pmid: Query PMID (anything ``int()`` accepts).
        top_k: Number of recommendations wanted.

    Returns:
        List of int PMIDs ordered by descending similarity. Returns [] when the
        query is not in the corpus or when ``top_k`` <= 0.
    """
    pmid = int(pmid)
    if top_k <= 0 or pmid not in pmid_to_idx:
        return []
    idx = pmid_to_idx[pmid]
    _, neighbors = index.search(embeddings[idx:idx + 1], top_k + 1)
    # FAISS fills unused result slots with -1 when the index holds fewer than
    # top_k + 1 vectors. Without the i >= 0 check, all_pmids[-1] would silently
    # add the last corpus PMID.
    recs = [all_pmids[i] for i in neighbors[0] if i >= 0 and all_pmids[i] != pmid]
    return recs[:top_k]

def average_precision_at_k(true_list, pred_list, k):
    """AP@k with binary relevance.

    Sums precision@i at every rank i <= k that holds a relevant id, then divides
    by min(#relevant, k), the largest number of hits that fit in k slots.

    Args:
        true_list: Iterable of relevant ids (cast to int).
        pred_list: Ranked predicted ids (cast to int).
        k: Cut-off rank.

    Returns:
        AP@k in [0, 1]. Returns 0.0 when there are no relevant ids or k <= 0.
    """
    # Materialize first, so an empty iterator is detected too. `not iterator`
    # is always False, which previously led to a division by zero.
    true_set = {int(x) for x in true_list}
    if not true_set or k <= 0:
        return 0.0
    num_rel, score = 0, 0.0
    for i, p in enumerate(pred_list[:k], start=1):
        if int(p) in true_set:
            num_rel += 1
            score += num_rel / i
    return score / min(len(true_set), k)

def dcg_at_k(pred_list, truth_scores, k):
    """Discounted cumulative gain over the first ``k`` predictions.

    Args:
        pred_list: Ranked predicted ids (cast to int for lookup).
        truth_scores: Mapping id -> graded relevance. Unlisted ids count as 0.
        k: Cut-off rank.

    Returns:
        The sum of gain / log2(rank + 1) over ranks 1..k.
    """
    total = 0.0
    for position, pid in enumerate(pred_list[:k]):
        gain = truth_scores.get(int(pid), 0)
        # position is 0-based, so rank + 1 == position + 2.
        total += gain / math.log2(position + 2)
    return total

def ndcg_at_k(pred_list, truth_scores, k):
    """Normalized DCG@k: ``dcg_at_k`` divided by the best achievable DCG@k.

    The ideal ranking sorts every graded value in ``truth_scores`` in
    descending order, including ids that may not appear in ``pred_list``.

    Returns:
        DCG / IDCG, or 0.0 when the ideal DCG is not positive.
    """
    achieved = dcg_at_k(pred_list, truth_scores, k)
    top_gains = sorted(truth_scores.values(), reverse=True)[:k]
    ideal = sum(g / math.log2(r + 1) for r, g in enumerate(top_gains, start=1))
    if not ideal > 0:
        return 0.0
    return achieved / ideal







Batches:   0%|          | 0/5074 [00:00<?, ?it/s]

In [None]:
import math
from tqdm import tqdm
import numpy as np

def evaluate_split_with_mrr(split, k_list=(5, 10, 15)):
    """Score ``recommend`` on RELISH entries at several cut-offs.

    Graded ground truth per query: 'relevant' -> 2, 'partial' -> 1 and
    'irrelevant' -> 0. The irrelevant list is applied last, so it overrides the
    other two for a PMID that appears in several lists. MAP and MRR treat
    grade >= 1 as relevant. NDCG uses the grades directly.

    Args:
        split: Iterable of entries with 'pmid' and 'response' keys.
        k_list: Cut-off ranks to evaluate.

    Returns:
        Dict mapping 'MAP@k', 'NDCG@k' and 'MRR@k' to the mean over the
        evaluated queries. MAP and NDCG are scaled to percentages. MRR stays in [0, 1].
    """
    k_list = list(k_list)
    scores = {f"MAP@{k}": [] for k in k_list}
    scores.update({f"NDCG@{k}": [] for k in k_list})
    scores.update({f"MRR@{k}": [] for k in k_list})
    max_k = max(k_list, default=0)
    skipped = 0

    for entry in tqdm(split, desc="Eval split"):
        q = int(entry['pmid'])
        if q not in pmid_to_idx:
            skipped += 1
            continue

        resp = entry['response']
        gt = {}
        for p in resp.get('relevant', []):
            gt[int(p)] = 2
        for p in resp.get('partial', []):
            gt[int(p)] = max(gt.get(int(p), 0), 1)
        for p in resp.get('irrelevant', []):
            gt[int(p)] = 0
        rel_bin = {pid for pid, sc in gt.items() if sc >= 1}

        # Run one nearest-neighbour search at the largest cut-off and slice it
        # per k. Top-k of an exact search is a prefix of top-max_k (up to the
        # order of exactly tied scores), so this gives the same predictions as
        # a separate search per k, at roughly 1/len(k_list) of the cost.
        ranked = recommend(q, max_k)
        first_hit = next(
            (rank for rank, pid in enumerate(ranked, start=1) if int(pid) in rel_bin),
            None,
        )

        for k in k_list:
            preds = ranked[:k]
            scores[f"MAP@{k}"].append(average_precision_at_k(rel_bin, preds, k))
            scores[f"NDCG@{k}"].append(ndcg_at_k(preds, gt, k))
            # Reciprocal rank of the first relevant hit within the top k, else 0.
            rr = 1.0 / first_hit if first_hit is not None and first_hit <= k else 0.0
            scores[f"MRR@{k}"].append(rr)

    print(f"Skipped {skipped} queries not in corpus.")
    return {metric: np.mean(vals) * (100 if metric.startswith(("MAP","NDCG")) else 1)
            for metric, vals in scores.items()}

# Evaluate each split. MAP/NDCG are printed as percentages, MRR as a fraction.
splits_to_report = {
    'Full': labeled_data,
    'Train': train_data,
    'Val': val_data,
    'Test': test_data,
}
for split_name, split_entries in splits_to_report.items():
    results = evaluate_split_with_mrr(split_entries, k_list=[5,10,15])
    print(f"\n==== {split_name} Set Results ====")
    for metric_name, metric_value in results.items():
        unit = "%" if metric_name.startswith(("MAP","NDCG")) else ""
        print(f"{metric_name}: {metric_value:.4f}{unit}")


Eval split: 100%|██████████| 3278/3278 [07:47<00:00,  7.01it/s]


Skipped 54 queries not in corpus.

==== Full Set Results ====
MAP@5: 75.4695%
MAP@10: 66.2318%
MAP@15: 59.5028%
NDCG@5: 78.3727%
NDCG@10: 73.0925%
NDCG@15: 69.4876%
MRR@5: 0.9097
MRR@10: 0.9112
MRR@15: 0.9114


Eval split: 100%|██████████| 2622/2622 [06:11<00:00,  7.05it/s]


Skipped 45 queries not in corpus.

==== Train Set Results ====
MAP@5: 78.3498%
MAP@10: 69.2460%
MAP@15: 62.4707%
NDCG@5: 81.8676%
NDCG@10: 76.5874%
NDCG@15: 72.9323%
MRR@5: 0.9233
MRR@10: 0.9245
MRR@15: 0.9247


Eval split: 100%|██████████| 328/328 [00:46<00:00,  7.10it/s]


Skipped 5 queries not in corpus.

==== Val Set Results ====
MAP@5: 65.8844%
MAP@10: 56.4892%
MAP@15: 49.8314%
NDCG@5: 66.5361%
NDCG@10: 61.5131%
NDCG@15: 57.9866%
MRR@5: 0.8523
MRR@10: 0.8557
MRR@15: 0.8562


Eval split: 100%|██████████| 328/328 [00:45<00:00,  7.17it/s]

Skipped 4 queries not in corpus.

==== Test Set Results ====
MAP@5: 62.1163%
MAP@10: 51.9706%
MAP@15: 45.5386%
NDCG@5: 62.3758%
NDCG@10: 56.8386%
NDCG@15: 53.5548%
MRR@5: 0.8588
MRR@10: 0.8609
MRR@15: 0.8611



