In [None]:
!pip install dspy

In [None]:
import json
import string
from pathlib import Path
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.preprocessing import normalize
import numpy as np

In [None]:
# Separator found between the document id and the title on a title line.
TITLE_MARKER = "|t|"
# Separator found between the document id and the abstract on an abstract line.
ABSTRACT_MARKER = "|a|"
# Value of the second tab-separated column on relation lines, which the parser skips.
CID_MARKER = "CID"


def remove_duplicates(dict_list):
    """
    Drop repeated annotations while preserving the input order.

    Two entries count as duplicates when their "text", "type" and "identifier"
    values all match; any other key is ignored for the comparison. Only the
    first entry of each group is kept (the very same dict object is returned,
    not a copy).
    """
    kept = []
    fingerprints = set()
    for entry in dict_list:
        fingerprint = frozenset(
            (field, entry[field]) for field in ("text", "type", "identifier")
        )
        if fingerprint in fingerprints:
            continue
        fingerprints.add(fingerprint)
        kept.append(entry)
    return kept

def _split_header_line(line, marker):
    """
    Split a title/abstract line into ``(doc_id, content)``; return None otherwise.

    The marker only counts when what precedes it looks like a bare document id
    (no whitespace, no "|"). This keeps a marker that merely occurs inside an
    abstract, or inside the text column of an annotation line, from being
    mistaken for a header.
    """
    doc_id, found, content = line.partition(marker)
    if not found:
        return None
    if "|" in doc_id or any(ch.isspace() for ch in doc_id):
        return None
    return doc_id, content


def parse_dataset(file_path):
    """
    Parse the BioCreative dataset into structured dictionaries.

    Three kinds of non-blank lines are recognised:
      * ``<id>|t|<title>`` starts a new document;
      * ``<id>|a|<abstract>`` is appended to the current document text after a newline;
      * tab-separated lines with at least six columns are entity annotations,
        except those whose second column is ``CID`` (relation lines), which are skipped.
    Anything else is ignored, as are annotation lines that appear before the
    first title line.

    Args:
        file_path (str): Path to the dataset file.

    Returns:
        list[dict]: A list of documents, each with ``id``, ``text`` (title, newline,
        abstract) and ``annotations`` (deduplicated dicts with ``text``, ``type``
        and ``identifier``).
    """
    documents = []
    current_doc = None

    with open(file_path, encoding="utf-8") as f:
        for raw_line in f:
            line = raw_line.strip()
            if not line:
                continue  # skip blank lines

            title = _split_header_line(line, TITLE_MARKER)
            if title is not None:
                # A new title closes the previous document.
                if current_doc is not None:
                    documents.append(current_doc)
                doc_id, title_text = title
                current_doc = {"id": doc_id, "text": title_text, "annotations": []}
                continue

            abstract = _split_header_line(line, ABSTRACT_MARKER)
            if abstract is not None:
                if current_doc is not None:
                    current_doc["text"] += "\n" + abstract[1]
                continue

            # Annotation or CID line
            parts = line.split("\t")
            if len(parts) < 6:
                continue  # malformed line

            if parts[1] == CID_MARKER:
                continue  # skip chemical-disease relation lines

            if current_doc is None:
                continue  # annotation with no document to attach it to

            # Columns 2 and 3 (span offsets) are not kept.
            _, _, _, text, entity_type, identifier = parts[:6]
            current_doc["annotations"].append({
                "text": text,
                "type": entity_type,
                "identifier": identifier
            })

    # Add last document
    if current_doc is not None:
        documents.append(current_doc)

    # Deduplicate annotations
    for doc in documents:
        doc["annotations"] = remove_duplicates(doc["annotations"])

    return documents




In [None]:
# Location of the dataset file in the notebook runtime; adjust if the file lives elsewhere.
dataset_path = "/content/CDR_TestSet.PubTator.txt"
# List of document dicts (id, text, annotations) as returned by parse_dataset.
dataset = parse_dataset(dataset_path)

In [None]:
def load_kb(file_path):
    """
    Load KB from a JSONL file into alias list and concept mapping.

    Every non-blank line must be a JSON object holding at least a
    "concept_id" and a list of "aliases". Blank or whitespace-only lines
    (e.g. a trailing newline at the end of the file) are skipped.

    Args:
        file_path (str): Path to the JSONL knowledge-base file.

    Returns:
        tuple: ``(aliases, alias_to_concept, concept_meta)`` where ``aliases`` is
        the flat list of all alias strings, ``alias_to_concept[i]`` is the concept
        id owning ``aliases[i]``, and ``concept_meta`` maps each concept id to its
        full JSON record.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
        KeyError: If a record lacks "concept_id" or "aliases".
    """
    aliases = []
    alias_to_concept = []
    concept_meta = {}

    with open(file_path, encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue  # tolerate blank lines instead of failing in json.loads

            concept = json.loads(line)
            cid = concept["concept_id"]
            concept_meta[cid] = concept

            # Keep the two lists aligned: one concept id per alias.
            concept_aliases = concept["aliases"]
            aliases.extend(concept_aliases)
            alias_to_concept.extend([cid] * len(concept_aliases))

    return aliases, alias_to_concept, concept_meta

# Translation table that deletes every character of string.punctuation.
# Built once here rather than on each call, because preprocess runs once per
# KB alias and once per mention.
_PUNCTUATION_TABLE = str.maketrans("", "", string.punctuation)


def preprocess(text):
    """
    Lowercase, remove punctuation for consistency.

    Args:
        text (str): Raw alias or mention string.

    Returns:
        str: The lowercased text with all ``string.punctuation`` characters
        removed (whitespace is left untouched).
    """
    return text.lower().translate(_PUNCTUATION_TABLE)

def build_tfidf_index(aliases):
    """
    Fit a character n-gram TF-IDF model on the preprocessed KB aliases.

    Returns:
        tuple: ``(vectorizer, alias_vecs)`` - the fitted vectorizer and the
        normalised alias matrix, one row per alias in input order.
    """
    tfidf = TfidfVectorizer(analyzer="char_wb", ngram_range=(3,5))
    cleaned_aliases = (preprocess(alias) for alias in aliases)
    # Normalised rows let a plain dot product act as cosine similarity.
    alias_matrix = normalize(tfidf.fit_transform(cleaned_aliases))
    return tfidf, alias_matrix

def batch_link_entities(mentions, vectorizer, alias_vecs, aliases, alias_to_concept, top_k=5):
    """
    Retrieve the ``top_k`` most similar KB aliases for every mention.

    Args:
        mentions: Iterable of mention strings (may be empty).
        vectorizer: Fitted vectorizer returned by ``build_tfidf_index``.
        alias_vecs: Normalised alias matrix returned by ``build_tfidf_index``.
        aliases: Alias strings, aligned with the rows of ``alias_vecs``.
        alias_to_concept: Concept id of each alias, aligned with ``aliases``.
        top_k (int): Number of candidates to keep per mention.

    Returns:
        list[list[dict]]: One list per mention, in input order. Each candidate
        dict has the keys "mention", "alias", "concept_id" and "score", and the
        candidates are ordered by decreasing score. Several candidates of one
        mention may point to the same concept, since ranking is done per alias.
    """
    # Materialise once: the mentions are walked twice below, and a generator
    # would be exhausted after the vectorising pass.
    mentions = list(mentions)
    if not mentions:
        # Nothing to link. Return early rather than hand scikit-learn a
        # zero-sample batch, which its input validation rejects.
        return []

    mention_vecs = vectorizer.transform(preprocess(m) for m in mentions)
    mention_vecs = normalize(mention_vecs)
    sims = mention_vecs @ alias_vecs.T  # Matrix multiplication, very fast

    results_batch = []
    for i, mention in enumerate(mentions):
        # Densify one row at a time to keep memory bounded by the alias count.
        row = sims[i].toarray().ravel()
        top_idx = np.argsort(-row)[:top_k]
        results_batch.append([
            {
                "mention": mention,
                "alias": aliases[idx],
                "concept_id": alias_to_concept[idx],
                "score": float(row[idx])
            }
            for idx in top_idx
        ])
    return results_batch


# Location of the knowledge-base JSONL file in the notebook runtime.
kb_file = "/content/mesh_2020.jsonl"

print("Loading KB...")
aliases, alias_to_concept, concept_meta = load_kb(kb_file)
print(f"Loaded {len(aliases)} aliases from KB.")

print("Building TF-IDF index...")
vectorizer, alias_vecs = build_tfidf_index(aliases)

# Quick sanity check of the retriever on two hand-picked mentions.
mentions = ["delirium", "hypertension"]
mention_results = batch_link_entities(mentions, vectorizer, alias_vecs, aliases, alias_to_concept, top_k=5)

# Iterate with a distinct loop variable so that mention_results still holds the
# whole batch afterwards, and take the mention from the input list so an empty
# candidate list cannot raise an IndexError.
for mention, candidates in zip(mentions, mention_results):
    print(f"\nMention: {mention}")
    for r in candidates:
        meta = concept_meta[r["concept_id"]]
        print(f"  {r['alias']} → {r['concept_id']} ({meta['canonical_name']}) score={r['score']:.3f}")



In [None]:
# Number of candidate aliases retrieved per mention in the recall evaluation below.
TOP_K=5

In [None]:
from tqdm import tqdm

# When True, print every mention whose gold identifier is absent from its candidates.
display_error = True
TOTAL_FOUND = 0
TOTAL = 0
# We only process the first 100 articles.

for item in tqdm(dataset[:100], desc="process all the dataset"):

    mentions = [annotation["text"] for annotation in item["annotations"]]
    ground_truths = [annotation["identifier"] for annotation in item["annotations"]]
    if not mentions:
        # A document without annotations contributes nothing to the metric;
        # skipping it also avoids vectorising an empty batch.
        continue

    results = batch_link_entities(mentions, vectorizer, alias_vecs, aliases, alias_to_concept, top_k=TOP_K)
    for mention, result, ground_truth in zip(mentions, results, ground_truths):
        # NOTE(review): identifiers are compared as whole strings; if a gold
        # identifier ever packs several ids into one field it can never match a
        # single concept_id - confirm against the dataset.
        candidate_ids = {entity["concept_id"] for entity in result}
        if ground_truth in candidate_ids:
            TOTAL_FOUND += 1
        elif display_error:
            matched = ', '.join([f"{entity['alias']} ({entity['concept_id']})" for entity in result])
            print(f"{mention}: matched {matched}.\nGround truth was {ground_truth}\n")
        TOTAL += 1

if TOTAL:
    print(f"Recall @{TOP_K} {(TOTAL_FOUND/TOTAL) * 100}")
else:
    # No mention was scored, so the ratio is undefined rather than a ZeroDivisionError.
    print(f"Recall @{TOP_K} undefined (no annotated mentions)")

In [None]:
# Re-print the metric from the TOTAL_FOUND / TOTAL counters filled by the evaluation loop.
print(f"Recall @{TOP_K} {(TOTAL_FOUND/TOTAL) * 100}")

Recall @5 72.00424178154825


## Using an LLM

In [None]:
import dspy
from typing import List

from typing import List, Optional
from pydantic import BaseModel

class LLMConfig:
    """Factory helpers that build a dspy language model and install it as the global default."""

    @staticmethod
    def _configure(lm):
        """Register ``lm`` as dspy's default LM together with the JSON adapter, and return it."""
        dspy.configure(lm=lm, adapter=dspy.JSONAdapter())
        return lm

    @staticmethod
    def setup_openai(api_key: str, model: str = "gpt-4o-mini", max_tokens=2048, temperature: float = 0.0):
        """
        Setup OpenAI API

        Args:
            api_key: OpenAI API key.
            model: Model name handed to ``dspy.LM``.
            max_tokens: Maximum number of tokens generated per call.
            temperature: Sampling temperature (defaults to 0.0, the value this
                helper previously forced after construction).

        Returns:
            The configured ``dspy.LM`` instance.
        """
        # Kept for backward compatibility: also exposes the key on the openai module.
        import openai
        openai.api_key = api_key
        # Temperature goes through the constructor instead of being patched
        # into dspy.settings.lm.kwargs after the fact.
        lm = dspy.LM(model, max_tokens=max_tokens, temperature=temperature, api_key=api_key)
        return LLMConfig._configure(lm)

    @staticmethod
    def setup_ollama(api_base: str = "http://127.0.0.1:11434", model: str = "llama3.1:8b", max_tokens=2048,
                     temperature: float = 0.0):
        """
        Setup Ollama local model

        Args:
            api_base: Base URL of the Ollama server.
            model: Ollama model name, with or without the ``ollama/`` or
                ``ollama_chat/`` provider prefix.
            max_tokens: Maximum number of tokens generated per call.
            temperature: Sampling temperature (defaults to 0.0).

        Returns:
            The configured ``dspy.LM`` instance.
        """
        # A bare name such as "llama3.1:8b" carries no provider; add the chat
        # provider prefix so the request is routed to the Ollama backend.
        if not model.startswith(("ollama/", "ollama_chat/")):
            model = f"ollama_chat/{model}"
        lm = dspy.LM(model, api_base=api_base, max_tokens=max_tokens, temperature=temperature)
        return LLMConfig._configure(lm)

# Entity mention handed to the LLM: its surface text and its entity type.
# (Documented with comments rather than a docstring so the pydantic schema stays unchanged.)
class NamedEntity(BaseModel):
    text: str  # mention string
    type: str  # entity type label

# Bundle of a document text and the entities found in it.
# NOTE(review): never instantiated in the visible code - confirm whether it is still needed.
class EntityLinkingInput(BaseModel):
    text: str  # document text
    entities: List[NamedEntity]  # entities extracted from that text

# One prediction of the LLM linker: a mention text and the MeSH id predicted for it.
class LinkedEntity(BaseModel):
    text: str  # mention string, used as the lookup key during evaluation
    mesh_id: str  # predicted identifier, compared with the gold identifier


# Input/output contract of the LLM linker.
# The class docstring and the field descriptions below are prompt text, so
# their wording is deliberately left untouched.
class EntityLinkingSignature(dspy.Signature):
    """
    Given a scientific text and a list of extracted named entities,
    Link each entity to the MeSH ID from the Medical Subject Headings (MeSH) thesaurus.
    """
    text: str = dspy.InputField(desc="The full document text")
    # Declared as a list of entities: this matches both the field description and
    # the value EntityLinkerModule.forward passes (a list of text/type records).
    entities: List[NamedEntity] = dspy.InputField(desc="List of extracted entities, each with text and type")
    linked_entities: List[LinkedEntity] = dspy.OutputField(desc="List of the prediceted MeSH ID")

class EntityLinkerModule(dspy.Module):
    """dspy module that asks the LLM to link the annotated entities of a document."""

    def __init__(self):
        super().__init__()
        self.linker_llm = dspy.Predict(EntityLinkingSignature)

    def forward(self, item):
        """
        Run the linker on one parsed document.

        ``item`` must provide "text" and "annotations"; each annotation
        contributes its "text" and "type" (empty string when a key is absent).
        Returns whatever the underlying predictor returns.
        """
        document_text = item["text"]
        entity_payload = []
        for annotation in item['annotations']:
            entity_payload.append({
                "text": annotation.get('text', ''),
                "type": annotation.get('type', ''),
            })
        return self.linker_llm(text=document_text, entities=entity_payload)


In [None]:
# -y answers apt's confirmation prompt, which cannot be answered from a notebook cell.
!sudo apt update && sudo apt install -y pciutils lshw
# Download and run the Ollama install script.
!curl -fsSL https://ollama.com/install.sh | sh

In [None]:
# Start the Ollama server in the background; stdout and stderr go to ollama.log.
!nohup ollama serve > ollama.log 2>&1 &

In [None]:
# Straight quotes so the shell passes the question as a single argument; typographic
# quotes are not shell quoting and would leave the trailing "?" open to glob expansion.
!ollama run vicuna:13b "What is the capital of the Netherlands?"

In [None]:
import os
from dotenv import load_dotenv, find_dotenv

# Load variables from the nearest .env file (if any) into the process environment.
load_dotenv(find_dotenv())
api_key = os.getenv("OPENAI_API_KEY")
if not api_key:
    # Fail here with an actionable message rather than deep inside the first LLM call.
    # Remove this check when switching to the Ollama setup below.
    raise RuntimeError("OPENAI_API_KEY is not set; define it in the environment or in a .env file.")

llm = LLMConfig.setup_openai(model="gpt-4.1-mini", api_key=api_key)
# Local alternative:
# llm = LLMConfig.setup_ollama(model="ollama_chat/llama3.1:8b", api_base="http://127.0.0.1:11434")

In [None]:
# Smoke test: send a trivial prompt to the configured model before running the evaluation.
llm("Hello !")

In [None]:
# Single linker instance, called once per document in the evaluation loop below.
LLMLinker = EntityLinkerModule()

In [None]:
from tqdm import tqdm

TOTAL_FOUND = 0
TOTAL = 0
# Raw LLM answers per document id, kept so they can be written to a cache file.
all_responses = {}

# Only the first 100 articles are processed.
for item in tqdm(dataset[:100], desc="process all the dataset"):
    # Gold standard: mention text -> identifier (a later annotation with the
    # same text overwrites an earlier one).
    ground_truth = {annot["text"]: annot["identifier"] for annot in item["annotations"]}

    # Query the LLM and keep a JSON-serialisable copy of its answer.
    response = LLMLinker(item)
    all_responses[item['id']] = [le.model_dump() for le in response.linked_entities]
    predicted = {linked_ent.text: linked_ent.mesh_id for linked_ent in response.linked_entities}

    # A mention counts as found only when the LLM returned exactly the gold
    # identifier for exactly that mention text.
    TOTAL_FOUND += sum(
        1 for text, gt_mesh_id in ground_truth.items()
        if predicted.get(text, '') == gt_mesh_id
    )
    TOTAL += len(ground_truth)


In [None]:
# Percentage of distinct gold mention texts for which the LLM returned exactly the gold identifier.
print(f"LLM - Recall @1 {(TOTAL_FOUND/TOTAL) * 100}")

In [None]:
# Persist the raw LLM answers so the evaluation can be replayed without calling the model again.
with open("all_gpt-4.1-mini_responses.json", "w") as cache_file:
    cache_file.write(json.dumps(all_responses, indent=4))

In [None]:
# Replay the evaluation from the cache file instead of calling the model.
cached_response = {}
with open("all_gpt-4.1-mini_responses.json", "r") as cache_file:
    cached_response = json.load(cache_file)


TOTAL_FOUND = 0
TOTAL = 0

for item in tqdm(dataset[:100], desc="process all the dataset"):
    # Gold standard: mention text -> identifier.
    ground_truth = {annot["text"]: annot["identifier"] for annot in item["annotations"]}

    # Rebuild the typed predictions of this document from the cached records.
    response = [LinkedEntity(**entity) for entity in cached_response[item['id']]]
    predicted = {linked_ent.text: linked_ent.mesh_id for linked_ent in response}

    # Same scoring rule as the live run: exact identifier match per mention text.
    TOTAL_FOUND += sum(
        1 for text, gt_mesh_id in ground_truth.items()
        if predicted.get(text, '') == gt_mesh_id
    )
    TOTAL += len(ground_truth)


In [None]:
# Same metric as the live LLM run, computed from the cached responses.
print(f"LLM (from cache) - Recall @1 {(TOTAL_FOUND/TOTAL) * 100}")

This is way slower and, unsurprisingly, way worse!

So LLMs may not be the right tool for the linking step itself... but how else could they help?

[*LLM as Entity Disambiguator for Biomedical Entity-Linking*](https://doi.org/10.18653/v1/2025.acl-short.25)

[*Retrieve Then Rerank: An End-to-End Learning Paradigm for Biomedical Entity Linking*](https://doi.org/10.1111/jebm.70053)

[*BioLinkerAI: Leveraging LLMs to Improve Biomedical Entity Linking and Knowledge Capture*](https://doi.org/10.1145/3701551.3708812)

[*Improving biomedical entity linking for complex entity mentions with LLM-based text simplification*](https://doi.org/10.1093/database/baae067)

[*Accelerating Cross-Encoders in Biomedical Entity Linking*](https://doi.org/10.18653/v1/2025.bionlp-1.13)

[*Contextual Augmentation for Entity Linking using Large Language Models*](https://aclanthology.org/2025.coling-main.570/)

Can an LLM help filter the initial candidate list?