In [None]:
# tiktoken is only needed for the token-count / cost estimates in the validation section.
# NOTE(review): pin a version (e.g. `%pip install -q tiktoken==<version>`) for reproducible runs.
%pip install tiktoken

In [None]:
import json
import os

In [None]:
# Input/output locations. The default "/content" looks like a Google Colab
# working directory; set GUTBRAIN_DATA_DIR / GUTBRAIN_OUTPUT_DIR to run elsewhere.
DATA_DIR = os.environ.get("GUTBRAIN_DATA_DIR", "/content")
OUTPUT_DIR = os.environ.get("GUTBRAIN_OUTPUT_DIR", DATA_DIR)

PATH_PLATINUM_TRAIN = os.path.join(DATA_DIR, "train_platinum.json")
PATH_GOLD_TRAIN = os.path.join(DATA_DIR, "train_gold.json")
PATH_DEV = os.path.join(DATA_DIR, "dev.json")
OUTPUT_TRAIN_FILE = os.path.join(OUTPUT_DIR, "gutbrain_finetune_train.jsonl")
OUTPUT_VALID_FILE = os.path.join(OUTPUT_DIR, "gutbrain_finetune_valid.jsonl")

In [None]:
# Dataset label -> label written after "##" in the inline annotations.
# Multi-word labels simply have their spaces replaced by underscores.
# Insertion order matters: UNDERSCORE_LABELS is built from the values in this order.
ORIGINAL_TO_UNDERSCORE_LABEL = {
    original_label: original_label.replace(" ", "_")
    for original_label in (
        "anatomical location",
        "animal",
        "biomedical technique",
        "bacteria",
        "chemical",
        "dietary supplement",
        "DDF",
        "drug",
        "food",
        "gene",
        "human",
        "microbiome",
        "statistical technique",
    )
}

In [None]:
# Target labels in mapping order.
# NOTE(review): not referenced by the visible cells; create_inline_prompt hard-codes the same list.
UNDERSCORE_LABELS = [*ORIGINAL_TO_UNDERSCORE_LABEL.values()]

In [None]:
def create_inline_prompt(text_to_annotate):
    """Build the user prompt asking the model to annotate ``text_to_annotate`` inline.

    The prompt explains the ``@@<span>##<label>`` markup, lists the allowed labels,
    and ends with the text to annotate followed by an empty ``Output:`` slot.
    """
    instructions = """
Task Description:
Your task is to identify and extract entities from the provided text. Perform Named Entity Recognition on the biomedical text provided in the "Text" section using a specific inline annotation style.
Identify all mentions of entities belonging to the predefined categories listed below. An entity can occur multiple times; treat each occurrence as a separate entity and mark them directly within the text.
Mark the beginning of an entity's span with "@@" and the end of the span with "##" followed immediately by the entity's category label. Note - the output text should be exactly identical to the input text i.e all spaces , special characters/ unicode characters,html/markdown tags,etc (or any other character) also should be exactly the same, except for the added "@@, ## and entity label" annotations for each entity detected.
Text span of an entity means the actual words/characters that form the entity in the text. An entity's span can contain single or multiple words but never partial words.
The format to follow while marking an entity with its label is  @@<entity_text_span>##<label>

Predefined Entity Categories (use these exact labels after ##):
[
    "anatomical_location", "animal", "biomedical_technique", "bacteria",
    "chemical", "dietary_supplement", "DDF", "drug", "food", "gene",
    "human", "microbiome", "statistical_technique"
]

Note - DDF stands for Disease, Disorder, or Finding. The remaining categories refer to their conventional or scientific meaning.
Also, If the first word or first set of words in output belong to an entity then ensure to start the output with @@ and follow rest of instructions.

### Text
"""
    # Only the tail depends on the input; everything above is fixed text.
    return instructions + f"Input: {text_to_annotate}\nOutput:"

In [None]:
def annotate_text(original_text, entities, label_map=None):
    """Insert inline ``@@<span>##<label>`` markers into ``original_text``.

    Args:
        original_text: The text (e.g. a title or an abstract) to annotate.
        entities: Iterable of dicts with ``start_idx``, ``end_idx`` (inclusive),
            ``text_span`` and ``label`` keys. Offsets are relative to
            ``original_text``.
        label_map: Optional mapping from dataset labels to the labels written
            after ``##``. Defaults to ``ORIGINAL_TO_UNDERSCORE_LABEL``.

    Returns:
        The annotated string. Entities are reported (printed) and skipped when
        they are malformed, empty, out of bounds, do not match the text at
        their offsets, carry an unknown label, or overlap an entity that
        starts earlier.
    """
    if not entities:
        return original_text
    if label_map is None:
        label_map = ORIGINAL_TO_UNDERSCORE_LABEL

    text_len = len(original_text)
    processed_entities = []
    for e in entities:
        if not all(k in e for k in ('start_idx', 'end_idx', 'text_span', 'label')):
            print(f"Skipping invalid entity structure: {e}")
            continue
        try:
            start = int(e['start_idx'])
            end = int(e['end_idx']) + 1  # end_idx is inclusive; convert to a slice end
            label = e['label']
            text_span = e['text_span']

            # Slicing never raises IndexError and negative indices wrap around,
            # so bounds have to be checked explicitly.
            if start < 0 or end > text_len:
                print(f"Index out of bounds for entity {e} in text of length {text_len}. Skipping.")
                continue
            if start >= end:
                print(f"Empty span for entity {e}. Skipping.")
                continue

            if original_text[start:end] != text_span:
                print(f"Span mismatch! Expected '{original_text[start:end]}', found '{text_span}' at {start}:{end}. Skipping entity: {e}")
                continue

            underscore_label = label_map.get(label)
            if not underscore_label:
                print(f"Label '{label}' not found in mapping. Skipping entity: {e}")
                continue

            processed_entities.append({
                'start': start,
                'end': end,
                'text_span': text_span,
                'underscore_label': underscore_label
            })
        except (ValueError, TypeError) as ve:
            print(f"Error processing entity indices/label {e}: {ve}. Skipping.")
            continue

    # Stable sort: among entities with the same start, input order decides.
    processed_entities.sort(key=lambda x: x['start'])
    pieces = []
    last_idx = 0

    for entity in processed_entities:
        start = entity['start']
        text_span = entity['text_span']

        if start < last_idx:
            print(f"Detected overlapping entity: '{text_span}' at {start} overlaps with previous entity ending at {last_idx-1}. Skipping overlap.")
            continue

        pieces.append(original_text[last_idx:start])
        pieces.append(f"@@{text_span}##{entity['underscore_label']}")
        last_idx = entity['end']

    pieces.append(original_text[last_idx:])
    return "".join(pieces)

In [None]:
def load_json_data(filepath):
    """Parse the JSON file at ``filepath``.

    Returns the decoded object, or None (after printing a message) when the
    file does not exist or is not valid JSON.
    """
    try:
        with open(filepath, 'r', encoding='utf-8') as handle:
            parsed = json.load(handle)
    except FileNotFoundError:
        print(f"File not found: {filepath}")
        parsed = None
    except json.JSONDecodeError:
        print(f"Error decoding JSON from: {filepath}")
        parsed = None
    return parsed

In [None]:
def _validated_entities(entities, text, pmid, location):
    """Return copies of ``entities`` with integer indices that fall inside ``text``; report and drop the rest."""
    # NOTE(review): indices are treated as offsets relative to ``text`` (title or
    # abstract) with an inclusive end_idx -- confirm against the dataset docs.
    valid = []
    for e in entities:
        try:
            rel_e = e.copy()
            rel_e['start_idx'] = int(e['start_idx'])
            rel_e['end_idx'] = int(e['end_idx'])
            if rel_e['start_idx'] < 0 or rel_e['end_idx'] >= len(text):
                print(f"Invalid relative index for {location} entity {e} (relative: {rel_e['start_idx']}:{rel_e['end_idx']}) in PMID {pmid}. Skipping entity.")
                continue
            valid.append(rel_e)
        except (KeyError, ValueError, TypeError) as err:
            print(f"Error adjusting {location} entity {e} for PMID {pmid}: {err}. Skipping entity.")
    return valid


def _chat_example(system_prompt, text, annotated_text):
    """Build one chat-format training record: system prompt, user prompt for ``text``, assistant target."""
    return {
        "messages": [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": create_inline_prompt(text)},
            {"role": "assistant", "content": annotated_text}
        ]
    }


def create_finetuning_file(input_data, output_filepath, system_prompt):
    """Write a chat-format JSONL fine-tuning file from annotated documents.

    One example is written for each non-empty title and each non-empty
    abstract. The user turn holds the inline-annotation prompt; the assistant
    turn holds the same text with ``@@span##label`` markers inserted.

    Args:
        input_data: Mapping of PMID -> document dict. Each document needs
            ``metadata.title`` and ``metadata.abstract``. It may carry an
            ``entities`` list whose items have a ``location`` of "title" or
            "abstract".
        output_filepath: Destination JSONL path (overwritten).
        system_prompt: Content of the system message in every example.

    Documents with missing metadata, or that raise while being processed, are
    reported and counted as skipped. A document's examples are written only
    after the whole document succeeded, so a skipped document never leaves
    partial output behind.
    """
    processed_count = 0
    skipped_docs = 0
    with open(output_filepath, 'w', encoding='utf-8') as outfile:
        for pmid, data in input_data.items():
            try:
                if "metadata" not in data or "title" not in data["metadata"] or "abstract" not in data["metadata"]:
                    print(f"Skipping PMID {pmid} due to missing metadata.")
                    skipped_docs += 1
                    continue

                all_entities = data.get("entities") or []
                doc_lines = []
                # Title first, then abstract -- same order as before.
                for location in ("title", "abstract"):
                    text = data["metadata"][location]
                    if not text:
                        continue
                    section_entities = [e for e in all_entities if e.get("location") == location]
                    entities = _validated_entities(section_entities, text, pmid, location)
                    example = _chat_example(system_prompt, text, annotate_text(text, entities))
                    doc_lines.append(json.dumps(example, ensure_ascii=False) + "\n")

                outfile.writelines(doc_lines)
                processed_count += 1
                if processed_count % 100 == 0:
                    print(f"Processed {processed_count} documents...")

            except Exception as e:
                # Best-effort conversion: report the document and keep going.
                print(f"Unexpected error processing PMID {pmid}: {e}")
                skipped_docs += 1

    print(f"Finished writing to {output_filepath}.")
    print(f"Total documents processed: {processed_count}")
    print(f"Total documents skipped: {skipped_docs}")

In [None]:
print("Loading datasets...")
train_platinum_data = load_json_data(PATH_PLATINUM_TRAIN)
train_gold_data = load_json_data(PATH_GOLD_TRAIN)
valid_data = load_json_data(PATH_DEV)

# load_json_data() returns None for a missing or malformed file; stop here with a
# clear message instead of failing later with a TypeError in dict.update()/len().
_failed_paths = [
    path
    for path, loaded in (
        (PATH_PLATINUM_TRAIN, train_platinum_data),
        (PATH_GOLD_TRAIN, train_gold_data),
        (PATH_DEV, valid_data),
    )
    if loaded is None
]
if _failed_paths:
    raise RuntimeError(f"Could not load dataset file(s): {', '.join(_failed_paths)}")

# The later update wins on duplicate PMIDs, so gold records replace platinum ones.
# NOTE(review): confirm this precedence is intended.
_dup_train_pmids = train_platinum_data.keys() & train_gold_data.keys()
if _dup_train_pmids:
    print(f"Warning: {len(_dup_train_pmids)} PMIDs appear in both platinum and gold training data; keeping the gold record.")
train_data = {}
train_data.update(train_platinum_data)
train_data.update(train_gold_data)

# Documents present in both splits would make validation results look better than they are.
_leaked_pmids = train_data.keys() & valid_data.keys()
if _leaked_pmids:
    print(f"Warning: {len(_leaked_pmids)} PMIDs appear in both training and validation data.")

print(f"Combined training data: {len(train_data)} documents.")
print(f"Validation data: {len(valid_data)} documents.")

In [None]:
# System message shared by every training and validation example.
SYSTEM_PROMPT = "You are an expert Named Entity Recognition (NER) system specializing in biomedical texts related to the gut-brain axis."

# Training split first, then validation split.
create_finetuning_file(input_data=train_data, output_filepath=OUTPUT_TRAIN_FILE, system_prompt=SYSTEM_PROMPT)
create_finetuning_file(input_data=valid_data, output_filepath=OUTPUT_VALID_FILE, system_prompt=SYSTEM_PROMPT)

Validate the generated JSONL files: check the chat-message format, then estimate token counts and training cost.

In [None]:
import tiktoken
import numpy as np
from collections import defaultdict

In [None]:
def get_stats(path):
    """Load a chat-format JSONL file, print its size and first example, and return the records.

    Blank lines are ignored. An empty file yields an empty list instead of an
    IndexError.
    """
    with open(path, 'r', encoding='utf-8') as f:
        dataset = [json.loads(line) for line in f if line.strip()]

    # Initial dataset stats
    print("Num examples:", len(dataset))
    if not dataset:
        print("Dataset is empty; nothing to show.")
        return dataset

    print("First example:")
    # .get() so a malformed first record is reported by check_for_errors rather than crashing here.
    for message in dataset[0].get("messages", []):
        print(message)
    return dataset

# Reload both generated files and print a summary of each.
train_set = get_stats(path=OUTPUT_TRAIN_FILE)
val_set = get_stats(path=OUTPUT_VALID_FILE)

In [None]:
def check_for_errors(dataset):
    """Check chat-format fine-tuning records for structural problems.

    Counts, per error kind:
    - records that are not dicts;
    - records with a missing, empty, or non-list ``messages`` field;
    - messages that are not dicts, lack ``role``/``content``, or carry
      unrecognized keys or roles;
    - messages without usable content;
    - examples that have no assistant message.

    Args:
        dataset: List of parsed JSONL records.

    Returns:
        dict mapping error name -> count (empty when nothing was found). The
        same counts are printed.
    """
    format_errors = defaultdict(int)
    for ex in dataset:
        if not isinstance(ex, dict):
            format_errors["data_type"] += 1
            continue

        messages = ex.get("messages", None)
        if not messages:
            format_errors["missing_messages_list"] += 1
            continue
        if not isinstance(messages, list):
            format_errors["messages_not_a_list"] += 1
            continue

        for message in messages:
            if not isinstance(message, dict):
                format_errors["message_data_type"] += 1
                continue

            if "role" not in message or "content" not in message:
                format_errors["message_missing_key"] += 1

            if any(k not in ("role", "content", "name", "function_call", "weight") for k in message):
                format_errors["message_unrecognized_key"] += 1

            if message.get("role", None) not in ("system", "user", "assistant", "function"):
                format_errors["unrecognized_role"] += 1

            content = message.get("content", None)
            function_call = message.get("function_call", None)

            if (not content and not function_call) or not isinstance(content, str):
                format_errors["missing_content"] += 1

        if not any(isinstance(m, dict) and m.get("role", None) == "assistant" for m in messages):
            format_errors["example_missing_assistant_message"] += 1

    if format_errors:
        print("Found errors:")
        for k, v in format_errors.items():
            print(f"{k}: {v}")
    else:
        print("No errors found")
    return dict(format_errors)

# Report format problems for each split; the trailing ';' keeps Jupyter from echoing a return value.
check_for_errors(dataset=train_set);
check_for_errors(dataset=val_set);

In [None]:
# Tokenizer used by the token-count helpers below for length and cost estimates.
# NOTE(review): confirm that "cl100k_base" matches the tokenizer of the model actually being fine-tuned.
encoding = tiktoken.get_encoding("cl100k_base")
def num_tokens_from_messages(messages, tokens_per_message=3, tokens_per_name=1, tokenizer=None):
    """Estimate the total token count of one chat example.

    Each message costs ``tokens_per_message`` overhead plus the encoded length
    of every field value. A ``name`` field costs an extra ``tokens_per_name``.
    A fixed 3 tokens are added once per example. The result is an estimate,
    not an exact billing figure.

    Args:
        messages: List of message dicts.
        tokens_per_message: Per-message overhead.
        tokens_per_name: Extra cost of a ``name`` field.
        tokenizer: Object with an ``encode(str)`` method. Defaults to the
            module-level ``encoding``.

    Returns:
        int token estimate.
    """
    tok = encoding if tokenizer is None else tokenizer
    num_tokens = 0
    for message in messages:
        num_tokens += tokens_per_message
        for key, value in message.items():
            if not isinstance(value, str):
                # Non-string fields (e.g. a function_call object) would crash
                # encode(); count their JSON form instead.
                value = json.dumps(value, ensure_ascii=False)
            num_tokens += len(tok.encode(value))
            if key == "name":
                num_tokens += tokens_per_name
    num_tokens += 3
    return num_tokens

def num_assistant_tokens_from_messages(messages, tokenizer=None):
    """Count the encoded tokens of assistant-role message contents only.

    Args:
        messages: List of message dicts. Messages without a ``role`` are ignored.
        tokenizer: Object with an ``encode(str)`` method. Defaults to the
            module-level ``encoding``.

    Returns:
        int token count.
    """
    tok = encoding if tokenizer is None else tokenizer
    return sum(
        len(tok.encode(message["content"]))
        for message in messages
        if message.get("role") == "assistant"
    )

def print_distribution(values, name):
    """Print min/max, mean/median and 5th/95th percentiles of ``values``.

    Empty input prints a note instead of raising.
    """
    print(f"\n#### Distribution of {name}:")
    if len(values) == 0:
        print("(no values)")
        return
    print(f"min / max: {min(values)}, {max(values)}")
    print(f"mean / median: {np.mean(values)}, {np.median(values)}")
    # The label says p5/p95, so compute the 0.05/0.95 quantiles (previously 0.1/0.9).
    print(f"p5 / p95: {np.quantile(values, 0.05)}, {np.quantile(values, 0.95)}")

In [None]:
def get_counts(dataset, token_limit=10000):
    """Print message-count and token-count statistics for a chat dataset.

    Args:
        dataset: List of records, each with a ``"messages"`` list.
        token_limit: Examples whose estimated total token count exceeds this
            are reported as likely to be truncated.
            NOTE(review): the pricing cell uses MAX_TOKENS_PER_EXAMPLE = 16385;
            confirm which limit applies to the target model.

    Returns:
        List of estimated total token counts, one per example (empty for an
        empty dataset).
    """
    n_missing_system = 0
    n_missing_user = 0
    n_messages = []
    convo_lens = []
    assistant_message_lens = []

    if not dataset:
        print("Dataset is empty; no statistics to report.")
        return convo_lens

    for ex in dataset:
        messages = ex["messages"]
        roles = [message.get("role") for message in messages]
        if "system" not in roles:
            n_missing_system += 1
        if "user" not in roles:
            n_missing_user += 1
        n_messages.append(len(messages))
        convo_lens.append(num_tokens_from_messages(messages))
        assistant_message_lens.append(num_assistant_tokens_from_messages(messages))

    print("Num examples missing system message:", n_missing_system)
    print("Num examples missing user message:", n_missing_user)
    print_distribution(n_messages, "num_messages_per_example")
    print_distribution(convo_lens, "num_total_tokens_per_example")
    print_distribution(assistant_message_lens, "num_assistant_tokens_per_example")
    n_too_long = sum(length > token_limit for length in convo_lens)
    print(f"\n{n_too_long} examples may be over the {token_limit:,} token limit, they will be truncated during fine-tuning")
    return convo_lens

# Token statistics per split; the per-example lengths feed the cost estimate below.
train_convo_lens = get_counts(dataset=train_set)
val_convo_lens = get_counts(dataset=val_set)

In [None]:
# Pricing and default n_epochs estimate
MAX_TOKENS_PER_EXAMPLE = 16385  # per-example cap applied when summing billed tokens below
TARGET_EPOCHS = 5
MIN_TARGET_EXAMPLES = 100
MAX_TARGET_EXAMPLES = 400
# NOTE: with MIN/MAX_DEFAULT_EPOCHS both equal to TARGET_EPOCHS, every branch
# below ends up at 5 epochs; widen these bounds to let the heuristic adapt.
MIN_DEFAULT_EPOCHS = 5
MAX_DEFAULT_EPOCHS = 5

n_epochs = TARGET_EPOCHS
n_train_examples = len(train_set)
if n_train_examples == 0:
    # Avoid ZeroDivisionError in the small-dataset branch.
    print("Training set is empty; keeping the target number of epochs.")
elif n_train_examples * TARGET_EPOCHS < MIN_TARGET_EXAMPLES:
    n_epochs = min(MAX_DEFAULT_EPOCHS, MIN_TARGET_EXAMPLES // n_train_examples)
elif n_train_examples * TARGET_EPOCHS > MAX_TARGET_EXAMPLES:
    n_epochs = max(MIN_DEFAULT_EPOCHS, MAX_TARGET_EXAMPLES // n_train_examples)

n_billing_tokens_in_dataset = sum(min(MAX_TOKENS_PER_EXAMPLE, length) for length in train_convo_lens)
print(f"Dataset has ~{n_billing_tokens_in_dataset} tokens that will be charged for during training")
print(f"By default, you'll train for {n_epochs} epochs on this dataset")
print(f"By default, you'll be charged for ~{n_epochs * n_billing_tokens_in_dataset} tokens")