Top 10 to 17: retrieve hard-negative candidates (ranks 10–17) with BGE-large + FAISS

In [None]:
import json
import torch
from transformers import AutoTokenizer, AutoModel
import numpy as np
import h5py
import faiss
from tqdm import tqdm  # for progress bar

# ---------------------------
# Configuration and input loading
# ---------------------------
# Queries: JSON list whose entries carry 'summary', 'ref' and optionally 'ntcId'.
input_path_queries = '/n/data1/hsph/biostat/celehs/lab/jh537/Retrivial_task/DATA/LONG_CTG_id_text_refs_train.json'
# Article metadata: JSON list whose entries carry 'article_id'.
input_path_articles = '/home/jh537/Clinical_Trial_Embending/Clinical_Trial_data/Retrievial/v2_/NEW_PM_id_text.json'
# Precomputed article embeddings: one HDF5 dataset per key (keys are used as article ids).
embeddings_path = '/n/data1/hsph/biostat/celehs/lab/jh537/Retrivial_task/DATA/CLS_NEW_PM_id_text_W_BGE_L.h5'
output_path = '/n/data1/hsph/biostat/celehs/lab/jh537/Retrivial_task/DATA/top10_17_retrieved_articles.json'

# Set the variables
NUM_QUERIES = 21680  # cap on processed queries; set to None to keep all of them
# Neighbours requested from FAISS per query. The retrieval loop keeps only the 0-based
# positions 10..16, so any value >= 17 is enough; 27 simply leaves headroom.
TOP_K = 27

# Load queries (explicit UTF-8 so non-ASCII text does not depend on the locale).
with open(input_path_queries, 'r', encoding='utf-8') as f:
    queries = json.load(f)
# Slicing past the end (or with None) keeps every query.
queries = queries[:NUM_QUERIES]

# Load article metadata.
# NOTE: `article_ids` is NOT used to map FAISS rows to ids -- that mapping is built
# from the HDF5 keys when the index is created. It is kept only for optional
# cross-checking against the embedding file.
with open(input_path_articles, 'r', encoding='utf-8') as f:
    articles = json.load(f)
article_ids = [article['article_id'] for article in articles]

# ---------------------------
# Load the encoder in half precision on the GPU
# ---------------------------
model_name = "BAAI/bge-large-en-v1.5"
tokenizer = AutoTokenizer.from_pretrained(model_name)
# half(), to() and eval() each return the module itself, so they chain:
# cast the weights to fp16, move them to CUDA, then switch to evaluation mode.
model = AutoModel.from_pretrained(model_name).half().to('cuda').eval()

# ---------------------------
# Query encoder: first-token ([CLS]) pooling
# ---------------------------
def get_query_embedding(text, tokenizer, model):
    """Encode `text` and return its first-token ([CLS]) hidden state.

    The text is tokenized with truncation/padding to at most 512 tokens, moved to
    'cuda', and run through `model` without gradient tracking. The batch axis is
    squeezed away, so a single string yields a 1-D vector.

    Returns:
        NumPy array of dtype float16.
    """
    batch = tokenizer(
        text,
        max_length=512,
        padding=True,
        truncation=True,
        return_tensors='pt',
    )
    batch = batch.to('cuda')

    with torch.no_grad():
        hidden = model(**batch).last_hidden_state
        cls_vec = hidden[:, 0, :]  # position 0 along the sequence axis

    cls_vec = cls_vec.squeeze()
    return cls_vec.cpu().numpy().astype('float16')

# ---------------------------
# 1. Load article embeddings from HDF5
# 2. Build FAISS index
# ---------------------------
print("Loading article embeddings from HDF5...")
embeddings = []
article_id_list = []

with h5py.File(embeddings_path, 'r') as hf:
    # Each top-level key is treated as an article id and its dataset as that
    # article's embedding. Appending both in the same loop keeps row i of the
    # matrix paired with article_id_list[i]; the HDF5 iteration order itself
    # therefore does not matter.
    for article_id in hf.keys():
        embeddings.append(hf[article_id][:])
        article_id_list.append(article_id)

if not embeddings:
    raise ValueError(f"No embeddings found in {embeddings_path}")

# Stack into a (num_articles, hidden_size) matrix; IndexFlatIP needs float32.
embeddings = np.array(embeddings).astype('float32')
if embeddings.ndim != 2:
    # e.g. datasets stored as (1, dim) or token-level (seq, dim) would land here.
    raise ValueError(
        f"Expected one 1-D embedding per article, got an array of shape {embeddings.shape}"
    )

print("Normalizing article embeddings for cosine similarity...")
norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
embeddings = embeddings / (norms + 1e-9)  # epsilon avoids dividing all-zero rows by 0

dimension = embeddings.shape[1]
# Exact inner-product search; on unit-length rows this ranks by cosine similarity.
index = faiss.IndexFlatIP(dimension)
index.add(embeddings)
print(f"FAISS index created. Number of embeddings: {len(embeddings)}")

# FAISS row position -> article id (row i was built from article_id_list[i]).
idx_to_article_id = dict(enumerate(article_id_list))

# ---------------------------
# Retrieval Loop
# ---------------------------
# Keep the 0-based ranking positions [RANK_START, RANK_END), i.e. the 11th..17th
# nearest neighbours. The output key is derived from these bounds and must match
# the key the downstream cells read ('top_10_17_retrieved_ids').
RANK_START, RANK_END = 10, 17
RESULT_KEY = f'top_{RANK_START}_{RANK_END}_retrieved_ids'
if TOP_K < RANK_END:
    raise ValueError(f"TOP_K={TOP_K} is too small to cover positions [{RANK_START}, {RANK_END})")

retrieval_results = []

print("Processing queries (CLS pooling for each query, FAISS retrieval)...")
for query_data in tqdm(queries, desc="Processing Queries"):
    query_text = query_data['summary']
    query_ref = query_data['ref']
    ntc_id = query_data.get('ntcId', 'N/A')

    # The encoder returns fp16; normalise in float32 so summing squares over the
    # hidden dimension neither loses precision nor risks exceeding fp16's range.
    q_emb = get_query_embedding(query_text, tokenizer, model).astype('float32')
    q_emb = q_emb / (np.linalg.norm(q_emb) + 1e-9)

    # FAISS expects a (n_queries, dim) float32 matrix.
    q_emb_2d = np.expand_dims(q_emb, axis=0)

    # I has shape (1, TOP_K); FAISS pads with -1 when fewer than TOP_K vectors exist.
    _, I = index.search(q_emb_2d, TOP_K)
    top_indices = I[0]

    top_article_ids = [
        idx_to_article_id[i]
        for i in top_indices[RANK_START:RANK_END]
        if i != -1
    ]

    retrieval_results.append({
        'ntcId': ntc_id,
        'ref': query_ref,
        RESULT_KEY: top_article_ids,
    })

# ---------------------------
# Save results
# ---------------------------
with open(output_path, 'w') as f:
    json.dump(retrieval_results, f, indent=4)

print(f"Done! Retrieval results saved to {output_path}")

Remove the positive (ref) article from the retrieved candidates

In [None]:
import json
import random

# Load the JSON file with retrieval results (written by the retrieval cell above).
input_path = '/n/data1/hsph/biostat/celehs/lab/jh537/Retrivial_task/DATA/top10_17_retrieved_articles.json'

# Seeded, private RNG so the random fallback removal is reproducible across runs.
SEED = 42
rng = random.Random(SEED)

with open(input_path, 'r') as f:
    retrieval_results = json.load(f)

# Drop exactly one candidate per query so every list shrinks by one element:
# the positive article (ref) if it was retrieved, otherwise a random candidate.
# The file is rewritten in place, so processed entries are flagged and skipped on
# re-runs; without this, every re-run would silently drop another candidate.
n_pos_removed = 0
n_random_removed = 0
for result in retrieval_results:
    if result.get('ref_filter_applied'):
        continue

    ref = str(result['ref'])
    candidate_ids = result['top_10_17_retrieved_ids']

    if ref in candidate_ids:
        candidate_ids.remove(ref)
        n_pos_removed += 1
    elif candidate_ids:
        candidate_ids.pop(rng.randrange(len(candidate_ids)))
        n_random_removed += 1

    result['top_10_17_retrieved_ids'] = candidate_ids
    result['ref_filter_applied'] = True

# Save the modified retrieval results back to the same JSON file
with open(input_path, 'w') as f:
    json.dump(retrieval_results, f, indent=4)

print(f"Removed the positive from {n_pos_removed} lists and a random candidate from {n_random_removed}.")
print(f"Modified results saved to {input_path}")


Create reranker training file with hard negatives (HN)

In [None]:
import json

# Define file paths
long_ctg_file = '/n/data1/hsph/biostat/celehs/lab/jh537/Retrivial_task/DATA/LONG_CTG_id_text_refs_train.json'
# Retrieved candidates for ranks 10-17 (variable name kept from an earlier top-5 version).
top5_articles_file = '/n/data1/hsph/biostat/celehs/lab/jh537/Retrivial_task/DATA/top10_17_retrieved_articles.json'
pm_id_text_file = '/home/jh537/Clinical_Trial_Embending/Clinical_Trial_data/Retrievial/v2_/NEW_PM_id_text.json'
output_file = '/n/data1/hsph/biostat/celehs/lab/jh537/Retrivial_task/DATA/rerank_training_HN_10_17.jsonl'

# Load data from files
with open(long_ctg_file, 'r', encoding='utf-8') as f:
    long_ctg_data = json.load(f)

with open(top5_articles_file, 'r', encoding='utf-8') as f:
    top5_articles_data = json.load(f)

with open(pm_id_text_file, 'r', encoding='utf-8') as f:
    pm_id_text_data = json.load(f)

# Lookup tables. Article ids are normalised to str: positives are looked up via
# str(ref), and the candidate ids written by the retrieval cell come from HDF5 keys.
# Without this, integer ids in the JSON would make every lookup miss silently.
ref_to_summary = {entry['ref']: entry['summary'] for entry in long_ctg_data}
article_id_to_text = {str(entry['article_id']): entry['text'] for entry in pm_id_text_data}

# Build one {"query", "pos", "neg"} training example per retrieved query.
output_data = []
n_missing_query = 0
n_missing_pos = 0
for entry in top5_articles_data:
    ref = entry['ref']
    pos_id = str(ref)
    candidate_ids = entry['top_10_17_retrieved_ids']

    # Get query (summary)
    query = ref_to_summary.get(ref)
    if query is None:
        n_missing_query += 1
        continue

    # Get positive text
    pos = article_id_to_text.get(pos_id)
    if pos is None:
        n_missing_pos += 1
        continue

    # Hard negatives: candidates with known text, never the positive itself
    # (guards against label leakage if the "remove the same ref" step was skipped).
    neg = [
        article_id_to_text[str(article_id)]
        for article_id in candidate_ids
        if str(article_id) != pos_id and str(article_id) in article_id_to_text
    ]

    output_data.append({"query": query, "pos": [pos], "neg": neg})

# Write output to a JSONL file (one JSON object per line)
with open(output_file, 'w', encoding='utf-8') as f:
    for entry in output_data:
        json.dump(entry, f)
        f.write('\n')

print(f"Skipped {n_missing_query} entries without a query summary and {n_missing_pos} without positive text.")
print(f"Output written to {output_file}")


Create reranker training file without hard negatives (HN)

In [None]:
import json

# Define file paths
# NOTE(review): unlike the hard-negative cell, these point at the Complete_PP folder
# and PM_id_text.json (not NEW_PM_id_text.json) -- confirm this is intended.
long_ctg_file = '/home/jh537/Clinical_Trial_Embending/Clinical_Trial_data/Retrievial/Complete_PP/LONG_CTG_id_text_refs_train.json'
top5_articles_file = '/home/jh537/Clinical_Trial_Embending/Clinical_Trial_data/Retrievial/Complete_PP/top5_retrieved_articles.json'
pm_id_text_file = '/home/jh537/Clinical_Trial_Embending/Clinical_Trial_data/Retrievial/v2_/PM_id_text.json'
output_file = '/home/jh537/Clinical_Trial_Embending/Clinical_Trial_data/Retrievial/Complete_PP/rerank_training_no_HN.jsonl'

# Load data from files
with open(long_ctg_file, 'r', encoding='utf-8') as f:
    long_ctg_data = json.load(f)

with open(top5_articles_file, 'r', encoding='utf-8') as f:
    top5_articles_data = json.load(f)

with open(pm_id_text_file, 'r', encoding='utf-8') as f:
    pm_id_text_data = json.load(f)

# Lookup tables; article ids are normalised to str because positives are looked
# up via str(ref) (integer ids in the JSON would otherwise never match).
ref_to_summary = {entry['ref']: entry['summary'] for entry in long_ctg_data}
article_id_to_text = {str(entry['article_id']): entry['text'] for entry in pm_id_text_data}

# One training example per entry of the retrieval file. Only 'ref' is read from it:
# this variant deliberately ignores the retrieved candidates.
output_data = []
n_skipped = 0
for entry in top5_articles_data:
    ref = entry['ref']

    query = ref_to_summary.get(ref)
    pos = article_id_to_text.get(str(ref))
    if query is None or pos is None:
        n_skipped += 1
        continue

    # No hard negatives: a single empty-string placeholder is written as 'neg'.
    # NOTE(review): presumably required by the training data format -- confirm the
    # trainer accepts an empty negative.
    output_data.append({"query": query, "pos": [pos], "neg": [""]})

# Write output to a JSONL file (one JSON object per line)
with open(output_file, 'w', encoding='utf-8') as f:
    for entry in output_data:
        json.dump(entry, f)
        f.write('\n')

print(f"Skipped {n_skipped} entries without a query summary or positive text.")
print(f"Output written to {output_file}")


In [None]:
##########################################################

import json

output_path = '/n/data1/hsph/biostat/celehs/lab/jh537/Retrivial_task/DATA/unique_ct_terms_weight.json'

with open(output_path, 'r') as json_file:
    data = json.load(json_file)

# Print the number of top-level entries in the file
print(f"Number of entries in the file: {len(data)}")

# Print the first 5 entries (head). Slicing only works on a list, so a top-level
# dict is previewed as its first 5 (key, value) pairs instead of raising TypeError.
preview = list(data.items())[:5] if isinstance(data, dict) else data[:5]
print(preview)