In [23]:
import dspy
from dspy.predict import Retry
from dspy.primitives.assertions import assert_transform_module, backtrack_handler
import importlib
from tqdm import tqdm
import json
from datasets import load_dataset
from dspy.evaluate import Evaluate
from typing import List
from dspy.teleprompt import BootstrapFewShot

importlib.reload(dspy)

llama = dspy.HFClientVLLM(model="meta-llama/Meta-Llama-3-8B", port=8000, url="http://gpub001.delta.ncsa.illinois.edu", tokens=1024)
dspy.settings.configure(lm=llama)

In [31]:
class ProcessTokens(dspy.Signature):
    """You are an expert in biomedical natural language processing. Your task is to identify disease names in a given list of tokens.
    For each token, assign one of the following tags:
    0 - if the token is not part of a disease name
    1 - if the token is the beginning of a disease name
    2 - if the token is inside (continuation) of a disease name
    Provide your answer as a space-separated list of numbers (0, 1, or 2) corresponding to each input token.
    Make sure your list is enclosed by brackets."""

    tokens = dspy.InputField(desc="A string representation of a list of tokens from a biomedical text")
    ner_tags = dspy.OutputField(desc="A list of NER tags (0, 1, or 2) for each token")

class NCBIDiseaseNER(dspy.Module):
    def __init__(self):
        super().__init__()
        self.process_tokens = dspy.ChainOfThought(ProcessTokens)
    
    def forward(self, tokens):
        tokens_str = json.dumps(tokens)  # Convert list to string representation
        response = self.process_tokens(tokens=tokens_str)
        
        try:
            ner_tags = json.loads(response.ner_tags)
            if not isinstance(ner_tags, list):
                ner_tags = [int(tag) for tag in response.ner_tags.split() if tag in ['0', '1', '2']]
            else:
                ner_tags = [int(tag) for tag in ner_tags if tag in [0, 1, 2]]
        except (json.JSONDecodeError, AttributeError):
            ner_tags = []

        # Handle mismatch in length
        if len(ner_tags) < len(tokens):
            print(f"Warning: Fewer tags ({len(ner_tags)}) than tokens ({len(tokens)}). Padding with 0s.")
            ner_tags.extend([0] * (len(tokens) - len(ner_tags)))
        elif len(ner_tags) > len(tokens):
            print(f"Warning: More tags ({len(ner_tags)}) than tokens ({len(tokens)}). Truncating.")
            ner_tags = ner_tags[:len(tokens)]
        
        # Ensure the output is enclosed in brackets
        if not (ner_tags.startswith('[') and ner_tags.endswith(']')):
            ner_tags = f"[{','.join(map(str, ner_tags))}]"
        
        return dspy.Prediction(ner_tags=ner_tags)
        

def eval_metric(true, prediction, trace=None):
    try:
        return [int(t) for t in true.ner_tags] == prediction.ner_tags
    except:
        return False


In [32]:
dataset = load_dataset("ncbi_disease")

def generate_dspy_examples(dataset_split):
    examples = []
    for item in dataset_split:
        example = dspy.Example({
            "tokens": item["tokens"],
            "ner_tags": item["ner_tags"]
        }).with_inputs("tokens")
        examples.append(example)
    return examples

# Generate DSPy examples for each split
train_dspy_examples = generate_dspy_examples(dataset["train"])
val_dspy_examples = generate_dspy_examples(dataset["validation"])
test_dspy_examples = generate_dspy_examples(dataset["test"]) 


In [33]:
test_tokens = test_dspy_examples[1].tokens
test_tokens = json.dumps(test_tokens)
test_tokens

'["Ataxia", "-", "telangiectasia", "(", "A", "-", "T", ")", "is", "a", "recessive", "multi", "-", "system", "disorder", "caused", "by", "mutations", "in", "the", "ATM", "gene", "at", "11q22", "-", "q23", "(", "ref", ".", "3", ")", "."]'

In [34]:
type(test_tokens)

str

In [35]:
# Define the predictor.
generate_answer = dspy.Predict(ProcessTokens)
pred = generate_answer(tokens = test_tokens)

In [36]:
pred

Prediction(
    ner_tags='[1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,'
)

In [37]:
llama.inspect_history(n=1)




You are an expert in biomedical natural language processing. Your task is to identify disease names in a given list of tokens.
For each token, assign one of the following tags:
0 - if the token is not part of a disease name
1 - if the token is the beginning of a disease name
2 - if the token is inside (continuation) of a disease name
Provide your answer as a space-separated list of numbers (0, 1, or 2) corresponding to each input token.
Make sure your list is enclosed by brackets.

---

Follow the following format.

Tokens: A string representation of a list of tokens from a biomedical text
Ner Tags: A list of NER tags (0, 1, or 2) for each token

---

Tokens: ["Ataxia", "-", "telangiectasia", "(", "A", "-", "T", ")", "is", "a", "recessive", "multi", "-", "system", "disorder", "caused", "by", "mutations", "in", "the", "ATM", "gene", "at", "11q22", "-", "q23", "(", "ref", ".", "3", ")", "."]
Ner Tags:[32m [1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 

'\n\n\nYou are an expert in biomedical natural language processing. Your task is to identify disease names in a given list of tokens.\nFor each token, assign one of the following tags:\n0 - if the token is not part of a disease name\n1 - if the token is the beginning of a disease name\n2 - if the token is inside (continuation) of a disease name\nProvide your answer as a space-separated list of numbers (0, 1, or 2) corresponding to each input token.\nMake sure your list is enclosed by brackets.\n\n---\n\nFollow the following format.\n\nTokens: A string representation of a list of tokens from a biomedical text\nNer Tags: A list of NER tags (0, 1, or 2) for each token\n\n---\n\nTokens: ["Ataxia", "-", "telangiectasia", "(", "A", "-", "T", ")", "is", "a", "recessive", "multi", "-", "system", "disorder", "caused", "by", "mutations", "in", "the", "ATM", "gene", "at", "11q22", "-", "q23", "(", "ref", ".", "3", ")", "."]\nNer Tags:\x1b[32m [1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,

In [38]:

# Set up the evaluation
evaluate_test = Evaluate(
    devset=test_dspy_examples[:20],  # Using a subset for faster testing
    metric=eval_metric, 
    num_threads=3, 
    display_progress=True, 
    display_table=True
)

In [39]:
initial_model = NCBIDiseaseNER()
initial_performance = evaluate_test(initial_model)
print(f"Initial Performance: {initial_performance}")

  0%|          | 0/20 [00:00<?, ?it/s]

[2m2024-08-31T11:26:46.406204Z[0m [[31m[1merror    [0m] [1mError for example in dev set: 		 'list' object has no attribute 'startswith'[0m [[0m[1m[34mdspy.evaluate.evaluate[0m][0m [36mfilename[0m=[35mevaluate.py[0m [36mlineno[0m=[35m183[0m
[2m2024-08-31T11:26:46.407048Z[0m [[31m[1merror    [0m] [1mError for example in dev set: 		 'list' object has no attribute 'startswith'[0m [[0m[1m[34mdspy.evaluate.evaluate[0m][0m [36mfilename[0m=[35mevaluate.py[0m [36mlineno[0m=[35m183[0m
Average Metric: 0.0 / 1  (0.0):   0%|          | 0/20 [00:04<?, ?it/s][2m2024-08-31T11:26:46.410223Z[0m [[31m[1merror    [0m] [1mError for example in dev set: 		 'list' object has no attribute 'startswith'[0m [[0m[1m[34mdspy.evaluate.evaluate[0m][0m [36mfilename[0m=[35mevaluate.py[0m [36mlineno[0m=[35m183[0m
Average Metric: 0.0 / 3  (0.0):  10%|█         | 2/20 [00:04<01:26,  4.82s/it]



[2m2024-08-31T11:26:51.236560Z[0m [[31m[1merror    [0m] [1mError for example in dev set: 		 'list' object has no attribute 'startswith'[0m [[0m[1m[34mdspy.evaluate.evaluate[0m][0m [36mfilename[0m=[35mevaluate.py[0m [36mlineno[0m=[35m183[0m






AttributeError: 'list' object has no attribute 'startswith'

In [None]:
# 3. Performance after prompt optimization with OSS model
teleprompter = BootstrapFewShot(metric=eval_metric)
compiled_model = teleprompter.compile(NCBIDiseaseNER(), trainset=train_dspy_examples[:100])  # Using a subset for faster compilation
optimized_performance = evaluate_test(compiled_model)
print(f"Optimized Performance: {optimized_performance}")

  0%|          | 0/100 [00:00<?, ?it/s][2m2024-08-31T11:16:52.070520Z[0m [[31m[1merror    [0m] [1mFailed to run or to evaluate example Example({'tokens': ['Identification', 'of', 'APC2', ',', 'a', 'homologue', 'of', 'the', 'adenomatous', 'polyposis', 'coli', 'tumour', 'suppressor', '.'], 'ner_tags': [0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 2, 2, 0, 0]}) (input_keys={'tokens'}) with <function eval_metric at 0x7ff0076d8f40> due to Need format_handler for tokens of type <class 'list'>.[0m [[0m[1m[34mdspy.teleprompt.bootstrap[0m][0m [36mfilename[0m=[35mbootstrap.py[0m [36mlineno[0m=[35m211[0m
[2m2024-08-31T11:16:52.071760Z[0m [[31m[1merror    [0m] [1mFailed to run or to evaluate example Example({'tokens': ['The', 'adenomatous', 'polyposis', 'coli', '(', 'APC', ')', 'tumour', '-', 'suppressor', 'protein', 'controls', 'the', 'Wnt', 'signalling', 'pathway', 'by', 'forming', 'a', 'complex', 'with', 'glycogen', 'synthase', 'kinase', '3beta', '(', 'GSK', '-', '3beta', ')', ',', '

AssertionError: Need format_handler for tokens of type <class 'list'>

In [None]:
# evaluate_test(NCBIDiseaseNER())