# Named Entity Recognition for Process Extraction with GoLLIE

## Import requirements

In [1]:
import sys
sys.path.append("../")

In [2]:
import rich
import logging
from src.model.load_model import load_model
import black
import inspect
from jinja2 import Template
import tempfile
from src.tasks.utils_typing import AnnotationList
logging.basicConfig(level=logging.INFO)
from typing import Dict, List, Type

### Load Model from HuggingFace

In [4]:
# Load GoLLIE-34B for inference with 4-bit quantization, together with its tokenizer.
model, tokenizer = load_model(
    model_weights_name_or_path="HiTZ/GoLLIE-34B",
    cache_dir="/work3/s213709",  # machine-specific path; adjust for your environment
    inference=True,
    use_lora=False,
    quantization=4,
    torch_dtype="bfloat16",
    use_flash_attention=True,
    force_auto_device_map=True,
)

INFO:root:Loading model model from HiTZ/GoLLIE-34B
INFO:root:We will load the model using the following device map: auto and max_memory: None
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
INFO:root:Bits and Bytes config: {
    "quant_method": "bitsandbytes",
    "load_in_8bit": false,
    "load_in_4bit": true,
    "llm_int8_threshold": 6.0,
    "llm_int8_skip_modules": null,
    "llm_int8_enable_fp32_cpu_offload": false,
    "llm_int8_has_fp16_weight": false,
    "bnb_4bit_quant_type": "nf4",
    "bnb_4bit_use_double_quant": true,
    "bnb_4bit_compute_dtype": "bfloat16"
}
INFO:accelerate.utils.modeling:We will use 90% of the memory on device 0 for storing the model, and 10% for the buffer to avoid OOM. You can set `max_memory` in to a higher value to use more memory (at your own risk).


Loading checkpoint shards:   0%|          | 0/7 [00:00<?, ?it/s]

INFO:root:Model dtype: torch.bfloat16
INFO:root:Total model memory footprint: 17659.621666 MB


### Define Guidelines for the Extraction of Process Elements from Medical Guidelines 

#### Simple Annotation Schema

In [5]:
from typing import List

from src.tasks.utils_typing import Entity, dataclass

"""
Entity definitions
"""

@dataclass
class Observation(Entity):
    """An Observation refers to any piece of information or data that is noted or recorded about a patient's health status, this includes
    symptoms, diagnoses, test results, risk factors like smoking, or contextual information like the age or conditions of the patient."""

    span: str  # Such as: not possible to administer, distribution of 18-F-FDG, severe hypoxemia, reduced eGFR, Diabetes, hypercapnia, postoperative pain, prostate cancer, elderly


# Guideline class for actions and events. The class source (docstring and example comment) is read
# with inspect.getsource and rendered into the prompt, so its wording is part of the model input.
@dataclass
class Activity(Entity):
    """An Activity refers to any action performed by a patient or a healthcare professional, this includes
    tasks, procedures, surgeries, treatments, medication, or other types of interventions. Activities can also be events like start of the scan."""
    
    span: str # Such as: eat, fasting, start, monitored, assessed, reviewed, performed, administered, prescribed, recommended, catheterization, filled, referred, continued, anesthesia

@dataclass
class Input(Entity):
    """An Input Entity refers to any word or phrase that denotes a specific type of clinical measurement, score, or value.
    These include, but are not limited to, physiological measurements, lab test scores, and specific clinical indices.
    Unlike a Guard, an Input does not represent specific numerical values or thresholds but the category of the measurement or value, for example
    Blood Pressure, Heart Rate, PEEP, Pmean or PaO2/FiO2.
    """

    span: str  # Such as: pH, Blood Pressure, Heart Rate, PEEP, Pmean or PaO2/FiO2, retention, distribution, eGFR, se-creatinine, sizes

# NOTE(review): apart from the entity name, this guideline is still a copy of Input's definition and
# examples. The annotation-scheme owner needs to state what distinguishes an Output (presumably a
# value produced by an activity - confirm) so the model can tell the two labels apart.
@dataclass
class Output(Entity):
    """An Output Entity refers to any word or phrase that denotes a specific type of clinical measurement, score, or value.
    These include, but are not limited to, physiological measurements, lab test scores, and specific clinical indices.
    Unlike a Guard, an Output does not represent specific numerical values or thresholds but the category of the measurement or value, for example
    Blood Pressure, Heart Rate, PEEP, Pmean or PaO2/FiO2.
    """

    span: str  # Such as: pH, Blood Pressure, Heart Rate, PEEP, Pmean or PaO2/FiO2, retention, distribution, eGFR, se-creatinine, sizes

@dataclass
class Actor(Entity):
    """An Actor refers to any person or entity that is involved in an activity, this includes
    patients, doctors, nurses, or other healthcare professionals. Actors can perform activities or be the target of activities."""

    span: str  # Such as: patients, patient, doctor, pt., anesthesiologist

@dataclass
class ActivityData(Entity):
    """An ActivityData entity refers to the data or object directly used by an activity, this includes devices, medications, or other objects.
    This could be for example an injection or a scan or insulin."""

    # The example spans are kept as given, since they may be verbatim spans from the annotated corpus.
    span: str  # Such as: flow and residual urine, insuline, bladder, an appointment, tablet paracetamol, antidiabetic medication

@dataclass
class Specification(Entity):
    """A Specification entity refers to any information that further describes an activity, this includes
    the time, the location, the dosage, the quantity, the frequency, the duration, additional information or the type of the activity.
    Specifications are often linked with prepositions such as at, to, in, into, on, for, with, within, while, as, according to, across, after, by, during, over, when, where."""

    span: str  # Such as: between the first and second tracheal, long-term, acute phase, outpatient clinic, saline, elsewhere in the body, following afternoon

@dataclass
class Guard(Entity):
    """A Guard refers to a specific type of information that sets conditions, limits, or thresholds in the clinical context.
    These entities often represent critical values or timeframes that impact clinical decisions, such as dosage limits, duration of treatment, or thresholds for test results.
    This can include measurements (like volume or concentration), timeframes (like durations or frequencies), or any other quantifiable condition that affects clinical decisions."""

    span: str  # Such as: after 1 week, >1000 ml, <1000 ml, at least 2 hours, for 6 hours, < 45

@dataclass
class PurposeOutcome(Entity):
    """A PurposeOutcome entity captures the underlying reason, goal, objective, or anticipated result of a clinical action, procedure, or recommendation.
    It addresses the "why" or the intended effect of a medical intervention or guideline. Examples would be: to reduce the risk of stroke or so that the patient can sleep better.
    The PurposeOutcome entity is often connected with prepositions like for, to, in order to, so that, to ensure, because maybe or because of."""

    span: str  # Such as: adequate bladder volume

# Guideline class for "and" connectors. The class source (docstring and example comment) is read
# with inspect.getsource and rendered into the prompt, so its wording is part of the model input.
@dataclass
class And(Entity):
    """An And entity connects two or more activities that are linked by the conjunction "and." This entity indicates that all linked activities are required or occur in conjunction.
       Primarily used in scenarios where multiple steps or conditions are simultaneously necessary. For instance, in a treatment plan, if multiple treatment activities need to be executed together.
    """

    span: str # Such as: and, &, +, as well as

@dataclass
class Or(Entity):
    """An Or entity links two or more activities or options, using the conjunction "or."
    It signifies that any one of the linked activities or observations may be chosen or is applicable, but not necessarily all.
    Useful in cases where multiple options are available and choosing one does not rule out the others. Often seen in treatment plans where alternative activities are viable."""

    span: str  # Such as: or, /

# Guideline class for mutually exclusive connectors. The class source is read with
# inspect.getsource and rendered into the prompt, so its wording is part of the model input.
# NOTE(review): unlike most sibling classes, no "Such as:" example spans are given here.
@dataclass
class Xor(Entity):
    """An Xor (exclusive or) entity connects two or more mutually exclusive activities, actions, tasks or observations, using the concept of "xor." 
    It implies that only one of the linked activities can be chosen or applies, and selecting one excludes the others.
    Applied in situations where two or more options are available but are mutually exclusive. It's critical in scenarios where the selection of one option inherently rules out the others.
    """

    span: str 

@dataclass
class RelationResponse(Entity):
    """A RelationResponse entity captures the relationship between two activities or an observation and an activity which need to be executed. So after executing the first activity the second activity must be executed.
    The RelationResponse entity can therefore be for example action A and requires action B, observation A: administer drug B.
    """

    span: str  # Such as: can be repeated, and requires, whether, during this period, in case of, must

@dataclass
class RelationCondition(Entity):
    """A RelationCondition entity captures the relationship between two activities or an observation and an activity which need to be executed in a specific order. So action B can only be executed after action A.
    Activity B could for instance be “Prescribe medicine”. For that to happen a medical examination has to take place, which could be activity A.
    The RelationCondition entity can therefore be for example "in cases where" so "use infusion fluid only in cases where the patient is dehydrated" or "before" so "before prescribing medication, perform a medical examination".
    """

    span: str  # Such as: and finally, when, followed by, after, during this period, :, before, until

@dataclass
class RelationExclusion(Entity):
    """A RelationExclusion entity captures the relationship between two activities or an observation and an activity where one is excluding the other.
    Examples for the RelationExclusion entity would be not routinely recommended, or should not be, if observation activity is not possible."""

    span: str  # Such as: must not, not suitable, should not

# Guideline class for inclusion relations. The class source is read with inspect.getsource
# and rendered into the prompt, so its wording is part of the model input.
# NOTE(review): unlike most sibling classes, no "Such as:" example spans are given here.
@dataclass
class RelationInclusion(Entity):
    """A RelationInclusion entity captures the relationship between two activities or an observation and an activity where one is including the other. For example blood tests are not required unless
    observation A is true. The RelationInclusion entity can therefore be for example "unless" so "blood tests are not required unless observation A is true" 
    or "if" so "if observation A is true, then blood tests are required"."""

    span: str

# Imported here because this cell is written out as its own module, which only sees
# the imports made inside the cell.
from typing import Type

# Registry of every guideline class (the classes themselves, not instances). The source of each
# entry is extracted with inspect.getsource to build the prompt, and the list also serves as the
# scorer's valid_types.
ENTITY_DEFINITIONS: List[Type[Entity]] = [
    Observation,
    Activity,
    Input,
    Output,
    Actor,
    ActivityData,
    Specification,
    Guard,
    PurposeOutcome,
    And,
    Or,
    Xor,
    RelationResponse,
    RelationCondition,
    RelationExclusion,
    RelationInclusion,
]

# Runs only when the cell is executed directly (module name "__main__"), not when the
# file written from this cell is imported as a module.
if __name__ == "__main__":
    # Capture the latest entry of the notebook's input history `In`; this is the text that the
    # next cell writes to guidelines.py (presumably this cell's own source - confirm in IPython).
    cell_text = In[-1]


Due to IPython limitations, we must write the content of the previous cell to a file and then import the content from that file.

In [6]:
import importlib

# Persist the guideline cell as a real module so inspect.getsource can read the class sources.
with open("guidelines.py", "w", encoding="utf8") as python_guidelines:
    print(cell_text, file=python_guidelines)

# The file was (re)written after the interpreter started: refresh the import finders and drop any
# previously imported copy, otherwise re-running this cell keeps serving the stale definitions.
importlib.invalidate_caches()
sys.modules.pop("guidelines", None)

from guidelines import *

We use inspect.getsource to get the guidelines as a string

In [7]:
# One source-code string per guideline class, in ENTITY_DEFINITIONS order.
guidelines = list(map(inspect.getsource, ENTITY_DEFINITIONS))

### Load input sentences

In [8]:
# Input passage to annotate.
text = (
    "During this period, the patient should only drink tap water "
    "Non-insulin-dependent diabetes: * The patient must not take their "
    "antidiabetic medication in the morning and follow the usual guidelines, "
    "i.e. fasting for 6 hours before the start of the scan, "
    "after the scan the patient can eat and take their antidiabetic medication "
    "as usual"
)

# Reference annotations for the passage above.
gold = [
    RelationResponse(span="During this period"),
    Actor(span="patient"),
    Activity(span="drink"),
    ActivityData(span="tap water"),
    Observation(span="Non-insulin-dependent diabetes"),
    Actor(span="patient"),
    RelationExclusion(span="must not"),
    Activity(span="take"),
    ActivityData(span="antidiabetic medication"),
    Specification(span="in the morning"),
    Activity(span="follow"),
    ActivityData(span="usual guidelines"),
    Activity(span="fasting"),
    Guard(span="for 6 hours"),
    Activity(span="start"),
    ActivityData(span="scan"),
    RelationCondition(span="after"),
    Actor(span="patient"),
    Activity(span="scan"),
    Activity(span="eat"),
    Activity(span="take"),
    ActivityData(span="antidiabetic medication"),
    Specification(span="as usual"),
]


In [9]:
import os

# Show the directory against which the relative paths used in this notebook are resolved.
working_dir = os.getcwd()
print(working_dir)

/zhome/06/4/166098/GoLLIEProcessExtraction


#### Filling a template

In [27]:
# Read template
# Read the prompt template. The encoding is explicit (matching how guidelines.py is written) so the
# result does not depend on the platform's default locale encoding.
with open("./templates/prompt.txt", "rt", encoding="utf8") as f:
    template = Template(f.read())
# Fill the template with the guideline sources, the input text and the gold annotations.
formated_text = template.render(guidelines=guidelines, text=text, annotations=gold, gold=gold)


### Black Code Formatter

In [28]:
# Normalise the rendered prompt with black's default formatting mode.
black_mode = black.Mode()
formated_text = black.format_str(src_contents=formated_text, mode=black_mode)

#### Print the filled and formatted template

In [29]:
# Display the rendered and black-formatted prompt.
rich.print(formated_text)

### Prepare model inputs

In [30]:
# Keep everything up to and including the last "result =" so the model completes the annotation
# list. rpartition also copes with the marker occurring earlier in the prompt, where unpacking
# split() into two names would raise; the later cells likewise read after the last marker.
prompt, _marker, _ = formated_text.rpartition("result =")
if not _marker:
    raise ValueError('The rendered prompt does not contain the "result =" marker.')
prompt = prompt + _marker

Tokenize the input sentences

In [31]:
# Encode the prompt as PyTorch tensors, letting the tokenizer add its special tokens.
model_input = tokenizer(
    prompt,
    return_tensors="pt",
    add_special_tokens=True,
)

Remove the eos token from the input

In [32]:
# The prompt must end right after "result =", so drop a trailing EOS token if the tokenizer
# appended one. Checking the actual last token avoids cutting off a real prompt token when the
# tokenizer is configured not to add EOS.
_input_ids = model_input["input_ids"]
if _input_ids.shape[1] > 0 and _input_ids[0, -1].item() == tokenizer.eos_token_id:
    for _key in ("input_ids", "attention_mask"):
        model_input[_key] = model_input[_key][:, :-1]

## Run GoLLIE

Now we generate the predictions with GoLLIE.
We use num_beams=1 and do_sample=False in our experiments.

In [40]:
%%time

model_ouput = model.generate(
    **model_input.to(model.device),
    max_new_tokens=128,
    do_sample=True,
    min_new_tokens=0,
    num_beams=2,
    num_return_sequences=2,
)

Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


OutOfMemoryError: CUDA out of memory. Tried to allocate 142.00 MiB. GPU 0 has a total capacty of 39.39 GiB of which 94.38 MiB is free. Including non-PyTorch memory, this process has 39.29 GiB memory in use. Of the allocated memory 34.38 GiB is allocated by PyTorch, and 4.41 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

### Print the results

In [34]:
# Show, for every returned sequence, the text generated after the last "result = ".
for answer_idx, token_ids in enumerate(model_ouput):
    print(f"Answer {answer_idx}")
    decoded_answer = tokenizer.decode(token_ids, skip_special_tokens=True)
    rich.print(decoded_answer.split("result = ")[-1])

Answer 0


### Parse the output

In [35]:
# Decode the first returned sequence and parse the text after the last "result = "
# into annotation objects, resolving class names in the "guidelines" module.
first_answer = tokenizer.decode(model_ouput[0], skip_special_tokens=True)
result = AnnotationList.from_output(
    first_answer.split("result = ")[-1],
    task_module="guidelines",
)
rich.print(result)

## Evaluate the results
First, we define a Scorer. For Named Entity Recognition, we will use the SpanScorer class.

We need to define the valid_types for the scorer, which will be the labels that we have defined

In [36]:
from src.tasks.utils_scorer import SpanScorer

class MyEntityScorer(SpanScorer):
    """Scorer for Named Entity Recognition restricted to the classes in ENTITY_DEFINITIONS."""

    valid_types: List[Type] = ENTITY_DEFINITIONS

    def __call__(self, reference: List[Entity], predictions: List[Entity]) -> Dict[str, Dict[str, float]]:
        """Run the parent scorer and return its "spans" result under the "entities" key."""
        span_scores = super().__call__(reference, predictions)["spans"]
        return {"entities": span_scores}

#### Initialize the scorer

In [37]:
# Create the scorer; it takes no constructor arguments.
scorer = MyEntityScorer()

#### Compute F1

In [38]:

# Score the single example: gold and predicted annotations are each wrapped in a one-element list.
scorer_results = scorer(
    predictions=[result],
    reference=[gold],
)
rich.print(scorer_results)