In [None]:
!pip install openai

In [None]:
!pip install faiss-cpu

In [None]:
import json
import os
from openai import OpenAI
from getpass import getpass
import re
import faiss
import numpy as np

In [None]:
# Authenticate the OpenAI client. The key is prompted for only when it is not
# already present in the environment, and it is written straight into the
# environment rather than a notebook-level variable, so it cannot be echoed by
# `%whos`, variable inspectors, or an accidental bare-name cell output.
if not os.environ.get("OPENAI_API_KEY"):
    os.environ["OPENAI_API_KEY"] = getpass("Enter your OpenAI API key: ")
client = OpenAI()

In [None]:
# --- Generation settings -------------------------------------------------
# Fine-tuned chat model checkpoint used for inline NER.
OPENAI_MODEL = "ft:gpt-4.1-mini-2025-04-14:personal:gutbrain4:BUwx1WDM:ckpt-step-638"
# Temperature 0 with top_p 1: as deterministic as the API allows.
TEMPERATURE = 0.0
TOP_P = 1.0
# Upper bound on generated tokens per request (the output echoes the whole input).
MAX_TOKENS = 16_384

# --- Data locations (paths under /content, presumably a Colab runtime) ---
TRAIN_FILE_PATH = "/content/train.json"
DEV_FILE_PATH = "/content/dev.json"
OUTPUT_FILE_PATH = "/content/llm_ner_predictions.json"

In [None]:
# Entity labels as they appear after "##" in the inline annotation format.
# Order matters for anything that iterates over the list, so keep it stable.
UNDERSCORE_LABELS = """
    anatomical_location animal biomedical_technique bacteria
    chemical dietary_supplement DDF drug food gene
    human microbiome statistical_technique
""".split()

In [None]:
# Inline-format label -> label used in the dataset JSON. Every dataset label is
# the underscore label with "_" replaced by a space ("DDF" has no underscore).
UNDERSCORE_TO_ORIGINAL_LABEL = {
    underscore_name: underscore_name.replace("_", " ")
    for underscore_name in (
        "anatomical_location", "animal", "biomedical_technique", "bacteria",
        "chemical", "dietary_supplement", "DDF", "drug", "food", "gene",
        "human", "microbiome", "statistical_technique",
    )
}

In [None]:
# Dataset JSON label -> inline-format label (inverse of the mapping above):
# spaces become underscores.
ORIGINAL_TO_UNDERSCORE_LABEL = {
    dataset_name: dataset_name.replace(" ", "_")
    for dataset_name in (
        "anatomical location", "animal", "biomedical technique", "bacteria",
        "chemical", "dietary supplement", "DDF", "drug", "food", "gene",
        "human", "microbiome", "statistical technique",
    )
}

## Embeddings and vector DB (FAISS index over the training titles and abstracts)

In [None]:
# Embedding model used both to index the training texts and to embed queries.
EMBEDDING_MODEL = "text-embedding-3-large"
# On-disk cache of the vector DB: row-aligned metadata list plus the FAISS index.
FAISS_METADATA_FILE = "/content/gutbrain_train_metadata.json"
FAISS_INDEX_FILE = "/content/gutbrain_train.index"

In [None]:
def load_json_data(filepath):
    """Read a UTF-8 JSON file and return the decoded object.

    Returns None (after printing a message) when the file does not exist or
    does not contain valid JSON; other errors propagate to the caller.
    """
    parsed = None
    try:
        with open(filepath, encoding='utf-8') as handle:
            parsed = json.load(handle)
    except FileNotFoundError:
        print(f"File not found: {filepath}")
    except json.JSONDecodeError:
        print(f"Error decoding JSON from: {filepath}")
    return parsed

In [None]:
def get_embedding(text, model=EMBEDDING_MODEL):
    """Return the embedding vector for ``text`` from the OpenAI embeddings API.

    Newlines are replaced by spaces before the request. Uses the notebook-level
    ``client``.

    Args:
        text: The string to embed.
        model: Embedding model name (defaults to EMBEDDING_MODEL).

    Returns:
        The embedding as a list of floats, or None (after printing a message)
        when ``text`` is empty / not a string or the API call fails. Callers
        use None to fall back (e.g. to 0-shot prompting).
    """
    # Guard first: the previous version crashed inside its own `except`
    # handler (`None[:50]`) for None input, and empty input is rejected by the
    # API anyway, so don't spend a request on it.
    if not isinstance(text, str) or not text:
        print(f"Skipping embedding for empty or non-string input: {text!r}")
        return None
    single_line = text.replace("\n", " ")
    try:
        response = client.embeddings.create(input=[single_line], model=model)
        return response.data[0].embedding
    except Exception as e:
        print(f"Error getting embedding for text '{single_line[:50]}...': {e}")
        return None

In [None]:
def annotate_text(original_text, entities):
    """Render gold entities inline as ``@@<span>##<underscore_label>`` markers.

    Args:
        original_text: The text the entity offsets refer to.
        entities: Iterable of dicts with ``start_idx``, ``end_idx`` (inclusive),
            ``text_span`` and ``label`` (dataset label, e.g. "anatomical location").

    Returns:
        ``original_text`` with each accepted entity wrapped in markers. An entity
        is reported and skipped when it lacks a required key, its offsets are
        not integers, its span does not equal the text at those offsets, its
        label is not in ORIGINAL_TO_UNDERSCORE_LABEL, or it starts inside an
        earlier (by start offset) accepted entity.
    """
    if not entities:
        return original_text

    required_keys = ('start_idx', 'end_idx', 'text_span', 'label')
    accepted = []  # (begin, stop_exclusive, surface_text, underscore_label)
    for ent in entities:
        if any(key not in ent for key in required_keys):
            print(f"Skipping invalid entity structure: {ent}")
            continue
        try:
            begin = int(ent['start_idx'])
            stop = int(ent['end_idx']) + 1  # dataset end is inclusive; slicing is not
            label = ent['label']
            surface = ent['text_span']

            actual = original_text[begin:stop]
            if actual != surface:
                print(f"Span mismatch! Expected '{actual}', found '{surface}' at {begin}:{stop}. Skipping entity: {ent}")
                continue

            tag = ORIGINAL_TO_UNDERSCORE_LABEL.get(label)
            if not tag:
                print(f"Label '{label}' not found in mapping. Skipping entity: {ent}")
                continue

            accepted.append((begin, stop, surface, tag))
        except (ValueError, TypeError) as ve:
            print(f"Error processing entity indices/label {ent}: {ve}. Skipping.")
            continue
        except IndexError:
            print(f"Index out of bounds for entity {ent} in text of length {len(original_text)}. Skipping.")
            continue

    # Stable sort on the start offset only, so ties keep their input order.
    accepted.sort(key=lambda item: item[0])

    pieces = []
    cursor = 0  # first character of original_text not yet emitted
    for begin, stop, surface, tag in accepted:
        if begin < cursor:
            print(f"Detected overlapping entity: '{surface}' at {begin} overlaps with previous entity ending at {cursor-1}. Skipping overlap.")
            continue
        pieces.append(original_text[cursor:begin])
        pieces.append(f"@@{surface}##{tag}")
        cursor = stop
    pieces.append(original_text[cursor:])
    return "".join(pieces)

In [None]:
def setup_vector_db(train_data_path, index_path, metadata_path, force_recreate=False):
    """Load the cached FAISS index of training texts, or build and cache a new one.

    Each training document with a ``metadata`` entry contributes up to two
    vectors, its title and then its abstract. Row ``i`` of the index
    corresponds to ``metadata[i]``, a dict with keys ``id`` (== i), ``pmid``,
    ``location`` ("title" or "abstract") and ``text``.

    Args:
        train_data_path: JSON file mapping PMID -> document.
        index_path: Where the FAISS index is cached.
        metadata_path: Where the row-aligned metadata list is cached (JSON).
        force_recreate: Rebuild even when both cache files exist.

    Returns:
        (index, metadata), or (None, None) if no embedding could be computed.

    Raises:
        ValueError: If the training data cannot be loaded.
    """
    cache_present = os.path.exists(index_path) and os.path.exists(metadata_path)
    if cache_present and not force_recreate:
        print(f"Loading existing FAISS index from {index_path} and metadata from {metadata_path}")
        cached_index = faiss.read_index(index_path)
        with open(metadata_path, 'r', encoding='utf-8') as fh:
            cached_records = json.load(fh)
        return cached_index, cached_records

    print("Creating new FAISS index and metadata...")
    train_data = load_json_data(train_data_path)
    if not train_data:
        raise ValueError("Failed to load training data for vector DB setup.")

    vectors = []
    records = []  # one dict per vector: {'id', 'pmid', 'location', 'text'}
    for pmid, doc in train_data.items():
        if "metadata" not in doc:
            print(f"Skipping PMID {pmid} in DB setup: no metadata.")
            continue
        doc_meta = doc["metadata"]
        for section in ("title", "abstract"):
            section_text = doc_meta.get(section)
            if not section_text:
                continue
            vector = get_embedding(section_text)
            if not vector:
                continue  # embedding failed; keep index and metadata aligned by skipping both
            vectors.append(vector)
            records.append({
                "id": len(records),
                "pmid": pmid,
                "location": section,
                "text": section_text
            })

    if not vectors:
        print("No embeddings were generated. Cannot create FAISS index.")
        return None, None

    matrix = np.array(vectors).astype('float32')
    dim = matrix.shape[1]
    index = faiss.IndexFlatL2(dim)
    index.add(matrix)

    print(f"FAISS index created with {index.ntotal} vectors of dimension {dim}.")
    faiss.write_index(index, index_path)
    with open(metadata_path, 'w', encoding='utf-8') as fh:
        json.dump(records, fh, ensure_ascii=False, indent=2)
    print(f"FAISS index saved to {index_path}, metadata to {metadata_path}")

    return index, records

In [None]:
def base_inline_prompt():
    """Return the fixed zero-shot instruction block for inline (@@span##label) NER.

    The returned text ends with a newline; callers append the few-shot
    examples (if any) and the "### Text" section after it.
    """
    # A plain (non-f) string: the template has no placeholders. Keep the text
    # byte-for-byte stable -- NOTE(review): presumably the fine-tuned model was
    # trained with exactly this wording, so edits may change its behaviour.
    instructions = """Perform Named Entity Recognition on the biomedical text provided in the "Text" section using a specific inline annotation style.

Task Description:
You are an expert Named Entity Recognition (NER) system specializing in biomedical texts related to the gut-brain axis. Your task is to identify and extract entities from the provided text.
Identify all mentions of entities belonging to the predefined categories listed below. An entity can occur multiple times; treat each occurrence as a separate entity and mark them directly within the text.
Mark the beginning of an entity's span with "@@" and the end of the span with "##" followed immediately by the entity's category label. Note - the output text should be exactly identical to the input text i.e all spaces , special characters/ unicode characters,html/markdown tags,etc (any character) also should be exactly the same, except for the added "@@, ## and entity label" annotations for each entity detected.
Text span of an entity means the actual words/characters that form the entity in the text. An entity's span can contain single or multiple words but never partial words.
The format to follow while marking an entity with its label is  @@<entity_text_span>##<label>

Predefined Entity Categories (use these exact labels after ##):
[
    "anatomical_location", "animal", "biomedical_technique", "bacteria",
    "chemical", "dietary_supplement", "DDF", "drug", "food", "gene",
    "human", "microbiome", "statistical_technique"
]

Note - DDF stands for Disease, Disorder, or Finding. The remaining categories refer to their conventional or scientific meaning.
Also, If the first word or first set of words in output belong to an entity then ensure to start the output with @@ and follow rest of instructions.
"""
    return instructions

In [None]:
def _shot_split(num_shots):
    """Map a shot count to (num_abstract_examples, num_title_examples), or None if invalid."""
    # Abstracts get the larger half: reproduces the original 1->(1,0), 3->(2,1),
    # 5->(3,2) split and extends it to any non-negative integer.
    try:
        n = int(num_shots)
    except (TypeError, ValueError):
        return None
    if n != num_shots or n < 0:
        return None
    return (n + 1) // 2, n // 2


def _example_entities(doc_entities, location, example_text, pmid):
    """Return the gold entities of one section whose offsets lie inside ``example_text``."""
    # Offsets are used as-is (no shifting) for both sections, as before; they are
    # presumably already relative to their own section -- entities that fall
    # outside the section text are reported and dropped.
    selected = []
    for ent in doc_entities:
        if ent.get("location") != location:
            continue
        try:
            start, end = int(ent['start_idx']), int(ent['end_idx'])
        except (KeyError, ValueError, TypeError) as err:
            print(f"Error reading entity offsets for example {pmid}/{location}: {err}")
            continue
        if 0 <= start and end < len(example_text):
            selected.append(dict(ent, start_idx=start, end_idx=end))
        else:
            print(f"Entity out of bounds for example {pmid}/{location}: {ent}")
    return selected


def create_few_shot_inline_prompt(
    text_to_annotate,
    num_shots,
    faiss_index,
    faiss_metadata,
    full_train_data,
    openai_client,
    current_pmid,
    current_location
):
    """Build the NER prompt, optionally with retrieved few-shot examples.

    Examples are the training titles/abstracts nearest to ``text_to_annotate``
    in the FAISS index. They are rendered with their gold annotations through
    ``annotate_text``. For ``num_shots`` = n, the prompt takes (n+1)//2
    abstracts followed by n//2 titles (1 -> 1+0, 3 -> 2+1, 5 -> 3+2).

    Args:
        text_to_annotate: Title or abstract to annotate.
        num_shots: Number of examples; 0 gives the plain zero-shot prompt.
        faiss_index: Index searched with the query embedding (may be None).
        faiss_metadata: Row-aligned list of {'pmid', 'location', 'text', ...}.
        full_train_data: PMID -> training document with an ``entities`` list.
        openai_client: Currently unused; the query is embedded via ``get_embedding``.
        current_pmid, current_location: Identify the text being annotated, so
            that exact section is never used as its own example.

    Returns:
        The prompt string. Any problem (invalid ``num_shots``, missing index,
        failed embedding, no usable example) falls back to the 0-shot prompt.
    """
    base_prompt = base_inline_prompt()
    examples_section = "Follow the format shown in the detailed examples below precisely. \n ### Examples\n"
    final_text_section = f"### Text\nInput: {text_to_annotate}\nOutput:"
    zero_shot_prompt = base_prompt + "\n" + final_text_section

    if num_shots == 0:
        return zero_shot_prompt

    # Validate before any API call so invalid input doesn't cost an embedding request.
    split = _shot_split(num_shots)
    if split is None:
        print(f"Invalid num_shots: {num_shots}. Defaulting to 0-shot.")
        return zero_shot_prompt
    wanted = {"abstract": split[0], "title": split[1]}

    if not faiss_index or not faiss_metadata:
        print("FAISS index or metadata not available for few-shot example selection. Falling back to 0-shot.")
        return zero_shot_prompt

    current_text_embedding = get_embedding(text_to_annotate, model=EMBEDDING_MODEL)
    if current_text_embedding is None:
        print("Could not generate embedding for current text. Falling back to 0-shot.")
        return zero_shot_prompt

    query_embedding = np.array([current_text_embedding]).astype('float32')
    k_to_fetch = 500  # FAISS pads with -1 when the index holds fewer vectors
    _, indices = faiss_index.search(query_embedding, k_to_fetch)

    # Walk neighbours nearest-first, filling each section's quota.
    chosen = {"abstract": [], "title": []}
    seen = set()  # (pmid, location) pairs already considered
    for retrieved_idx in indices[0]:
        if retrieved_idx < 0 or retrieved_idx >= len(faiss_metadata):
            continue
        example_meta = faiss_metadata[retrieved_idx]
        example_location = example_meta["location"]
        example_key = (example_meta["pmid"], example_location)
        if example_key == (current_pmid, current_location) or example_key in seen:
            continue
        seen.add(example_key)
        bucket = chosen.get(example_location)
        if bucket is not None and len(bucket) < wanted[example_location]:
            bucket.append(example_meta)
        if all(len(chosen[loc]) >= wanted[loc] for loc in chosen):
            break

    train_docs = full_train_data or {}
    selected_examples_formatted = []
    for example_meta in chosen["abstract"] + chosen["title"]:
        example_pmid = example_meta["pmid"]
        example_location = example_meta["location"]
        example_text = example_meta["text"]

        doc_data = train_docs.get(example_pmid)
        if not doc_data or "entities" not in doc_data:
            print(f"Could not find entity data for example PMID {example_pmid}. Skipping example.")
            continue

        example_entities = _example_entities(doc_data["entities"], example_location, example_text, example_pmid)
        annotated_example_output = annotate_text(example_text, example_entities)
        selected_examples_formatted.append(f"Input: {example_text}\nOutput: {annotated_example_output}")

    if not selected_examples_formatted:
        print("No valid few-shot examples could be generated. Falling back to 0-shot.")
        return zero_shot_prompt

    examples_str = "\n".join(selected_examples_formatted)
    return base_prompt + "\n" + examples_section + examples_str + "\n\n" + final_text_section


In [None]:
# Gold annotations are needed to render few-shot examples. The FAISS index is
# loaded from its cache files when both exist; otherwise it is built from the
# training titles/abstracts and cached.
full_train_data = load_json_data(TRAIN_FILE_PATH)
faiss_index_global, faiss_metadata_global = setup_vector_db(
    train_data_path=TRAIN_FILE_PATH,
    index_path=FAISS_INDEX_FILE,
    metadata_path=FAISS_METADATA_FILE,
    force_recreate=False,
)
if faiss_index_global and faiss_metadata_global:
    print("FAISS Index and metadata are ready.")

## Evaluation: run inline NER on the test articles and parse the LLM output into entity offsets

In [None]:
def parse_llm_output(llm_generated_text, original_text, location, threshold = 10):
    """Convert inline-annotated LLM output into character-offset entities.

    The model is asked to echo ``original_text`` with every entity wrapped as
    ``@@span##label``. This function strips malformed markers, finds each
    well-formed annotation, and maps it back onto ``original_text``.

    Args:
        llm_generated_text: Raw model output containing ``@@span##label`` markers.
        original_text: The exact text (title or abstract) sent to the model.
        location: Section name stored on each entity (e.g. "title", "abstract").
        threshold: Maximum number of characters the estimated start may be
            shifted right/left to find the span when the echo has drifted.

    Returns:
        List of dicts with ``start_idx``/``end_idx`` (inclusive), ``location``,
        ``text_span`` and the dataset ``label``. Spans that cannot be aligned
        with ``original_text`` are reported and dropped.
    """
    label_pattern = "|".join(re.escape(lbl) for lbl in UNDERSCORE_TO_ORIGINAL_LABEL.keys())

    # 1. Remove stray markers: '@@' not followed by '<span>##<label>', and '##'
    #    not followed by a known label.
    stray_at_pattern = r"@@(?!(?:[^#]+)##(?:" + label_pattern + r"))"
    stray_hash_pattern = r"##(?!((?:" + label_pattern + r")))"
    cleaned_text = re.sub(stray_at_pattern, "", llm_generated_text)
    cleaned_text = re.sub(stray_hash_pattern, "", cleaned_text)

    # 2. Well-formed annotations, plus marker tokens that can still sit between
    #    them (e.g. an orphan '##label' with no opening '@@'). Those leftovers are
    #    not text of original_text and must not be counted when estimating offsets.
    pattern = re.compile(r"@@([^#]+?)##(" + label_pattern + r")")
    leftover_marker = re.compile(r"@@|##(?:" + label_pattern + r")")

    def slice_matches(s, e, span_text):
        return 0 <= s <= e < len(original_text) and original_text[s:e+1] == span_text

    def align(raw_start, span_text):
        """Return the start nearest ``raw_start`` (right shift tried first) where span_text occurs, or None."""
        width = len(span_text) - 1
        if slice_matches(raw_start, raw_start + width, span_text):
            return raw_start
        for shift in range(1, threshold + 1):
            for candidate in (raw_start + shift, raw_start - shift):
                if slice_matches(candidate, candidate + width, span_text):
                    return candidate
        return None

    entities = []
    plain_pos = 0  # estimated offset in original_text where the current gap begins
    prev_end = 0   # end of the previous annotation in cleaned_text

    for match in pattern.finditer(cleaned_text):
        span_text = match.group(1)
        underscore_label = match.group(2)
        gap = leftover_marker.sub("", cleaned_text[prev_end:match.start()])
        raw_start = plain_pos + len(gap)
        prev_end = match.end()
        # Default continuation when this span cannot be placed: trust the estimate.
        plain_pos = raw_start + len(span_text)

        original_label = UNDERSCORE_TO_ORIGINAL_LABEL.get(underscore_label)
        if not original_label:
            print(f"Unknown label '{underscore_label}' in LLM output for {location}. Skipping.")
            continue

        start_idx = align(raw_start, span_text)
        if start_idx is None:
            raw_end = raw_start + len(span_text) - 1
            print(
                f"Could not align span '{span_text}' at raw indices [{raw_start}:{raw_end+1}] "
                f"within ±{threshold} chars in original {location}. Skipping."
            )
            continue
        end_idx = start_idx + len(span_text) - 1
        # Re-anchor on the confirmed position so characters the model dropped or
        # inserted do not accumulate into drift beyond `threshold` for later spans.
        plain_pos = end_idx + 1

        entities.append({
            "start_idx": start_idx,
            "end_idx": end_idx,
            "location": location,
            "text_span": span_text,
            "label": original_label,
        })

    return entities

In [None]:
def _section_fields(data):
    """Return the mapping that holds ``title``/``abstract`` for one document."""
    # Test articles keep these at the top level (the original code read
    # data["title"]); train data (and, per the old commented-out code, dev data)
    # nests them under "metadata". Top-level fields win when both exist.
    if "title" in data or "metadata" not in data:
        return data
    return data["metadata"]


def _annotate_section(pmid, section_text, location, shot):
    """Prompt the model for one section and return its parsed entities ([] on failure)."""
    if not section_text:
        print(f"PMID {pmid} has an empty {location}. Skipping it.")
        return []
    prompt = create_few_shot_inline_prompt(
        section_text, shot, faiss_index_global, faiss_metadata_global,
        full_train_data, client, pmid, location)
    try:
        response = client.chat.completions.create(
            model=OPENAI_MODEL,
            messages=[{"role": "user", "content": prompt}],
            temperature=TEMPERATURE,
            top_p=TOP_P,
            max_tokens=MAX_TOKENS
        )
        llm_output = (response.choices[0].message.content or "").strip()
    except Exception as e:
        print(f"API call failed for {location} PMID {pmid}: {e}")
        return []
    print(section_text)
    print(llm_output)
    if not llm_output:
        return []
    return parse_llm_output(llm_output, section_text, location)


def process_documents(input_path, output_path, shot):
    """Annotate every document in ``input_path`` with the LLM and save the predictions.

    For each PMID the title and the abstract are sent separately (as
    ``shot``-shot prompts). The inline-annotated outputs are parsed into entity
    offsets and written as ``{pmid: {"entities": [...]}}`` JSON to ``output_path``.

    Args:
        input_path: JSON file mapping PMID -> document with ``title``/``abstract``
            either at the top level or under ``metadata``.
        output_path: Destination JSON file for the predictions.
        shot: Number of few-shot examples per prompt (0 = zero-shot).

    A failed API call yields no entities for that section. A document whose
    fields cannot be read is reported and left out of the output. Nothing is
    written when the input cannot be loaded, so a previous output file is
    not overwritten.
    """
    dev_data = load_json_data(input_path)
    if not isinstance(dev_data, dict):
        print(f"No documents loaded from {input_path}. Nothing to process.")
        return

    all_predictions = {}
    processed_count = 0
    total = len(dev_data)

    # `position` counts attempts, so progress stays correct even after failures.
    for position, (pmid, data) in enumerate(dev_data.items(), start=1):
        print(f"Processing PMID: {pmid} ({position}/{total})...")
        try:
            fields = _section_fields(data)
            title = fields["title"]
            abstract = fields["abstract"]

            combined_entities = _annotate_section(pmid, title, "title", shot)
            combined_entities.extend(_annotate_section(pmid, abstract, "abstract", shot))

            all_predictions[pmid] = {"entities": combined_entities}
            processed_count += 1
        except Exception as e:
            print(f"An unexpected error occurred outside API calls while processing PMID {pmid}: {e}")

    try:
        with open(output_path, 'w', encoding='utf-8') as f:
            json.dump(all_predictions, f, ensure_ascii=False, indent=4)
        print(f"\nSuccessfully processed {processed_count} documents.")
        print(f"Predictions saved to {output_path}")
    except IOError as e:
        print(f"Error writing predictions to {output_path}: {e}")

In [None]:
# Zero-shot run of the fine-tuned model over the test articles; predictions are
# written to OUTPUT_FILE_PATH.
process_documents(
    input_path="/content/articles_test.json",
    output_path=OUTPUT_FILE_PATH,
    shot=0,
)