In [2]:
# Logging setup: records at DEBUG level and above are appended to logs.txt
# (path is relative to the notebook's working directory).
import logging
logging.basicConfig(filename='logs.txt',
                    filemode='a',  # append, so earlier runs are kept
                    format='%(asctime)s,%(msecs)d %(name)s %(levelname)s %(message)s',
                    datefmt='%H:%M:%S',  # time of day only; the date is not written
                    level=logging.DEBUG)
# Root logger; the helper functions in later cells log through this object
# or through the `logging` module functions, which also target the root logger.
logger = logging.getLogger()
logger.setLevel(level=logging.DEBUG)

# imports
import ast
from fhirclient.client import FHIRClient
import configparser
import gradio_client
import fhirclient.models.observation as o
import fhirclient.models.annotation as a
from linkml_runtime import SchemaView

# Read the project configuration. ConfigParser.read() skips files it cannot
# open, so a missing ../config.ini does not raise here.
# NOTE(review): none of the cells visible in this notebook read `config` afterwards.
config = configparser.ConfigParser()
config.read('../config.ini')

# connection settings to FHIR server (passed to FHIRClient in the annotation loop)
settings = {
    'app_id': 'my_web_app',
    'api_base': 'http://localhost:8080/fhir/'
}

# connection to Gradio server; run_llm sends its prompts through this client
HOST_URL = "http://localhost:7860"
llm_client = gradio_client.Client(HOST_URL)

# load LinkML template; `sv` is the schema view used to build prompts and to
# parse and ground the LLM answers
path_to_template = 'test.yaml'
sv = SchemaView(path_to_template)

# llm call function
def run_llm(prompt):
    """Send ``prompt`` to the LLM server and return the text of its answer.

    The request is passed to the ``/submit_nochat_api`` endpoint as the string
    form of a dict, and the answer is parsed back from a dict literal.

    Relies on the module-level ``llm_client``.
    """
    # the request travels as the string form of a dict
    payload = str({'instruction_nochat': prompt})
    raw_reply = llm_client.predict(payload, api_name='/submit_nochat_api')

    # the reply is a dict literal too; only its 'response' entry is returned
    reply = ast.literal_eval(raw_reply)
    return reply['response']

# get the prompt for the given template
def get_completion_prompt(cls_name, sv, text):
    """Build the entity-extraction prompt for one class of the template.

    The prompt consists of a fixed instruction, one ``name: <hint>`` line per
    induced slot of ``cls_name``, and finally the text to annotate.

    Args:
        cls_name: name of the class whose slots describe the entities.
        sv: schema view providing ``class_induced_slots``, ``all_enums`` and
            ``get_enum``.
        text: the note the entities should be extracted from.

    Returns:
        The complete prompt string.
    """

    # system prompt
    prompt = (
        "Do not provide any explanations and just extract the entities (if available) from the text below in the following format:\n\n"
    )

    # looked up once; the set of enums does not change while the prompt is built
    enums = sv.all_enums()

    # schema prompt is concatenated from prompts for each entity group
    for slot in sv.class_induced_slots(cls_name):

        # description of the entity group; a slot without a description gets
        # an empty hint instead of the literal text "None"
        slot_prompt = slot.description or ""

        # custom instruction for categorical entities
        if slot.range in enums:
            enum_def = sv.get_enum(slot.range)
            pvs = [str(k) for k in enum_def.permissible_values.keys()]
            # keep the description and the value list apart
            if slot_prompt and not slot_prompt.endswith((" ", "\n")):
                slot_prompt += " "
            slot_prompt += f"Must be one of: {', '.join(pvs)}"

        # entity group prompt
        prompt += f"{slot.name}: <{slot_prompt}>\n"

    # concatenate the prompt with doctor's note
    prompt = f"{prompt}\n\nText:\n{text}\n\n===\n\n"
    return prompt

# preprocess annotation
def preprocessing(ann_text):
    """Normalise a raw LLM answer into one lower-case line per section.

    Steps: strip and lower-case the text, drop any introduction preceding the
    first occurrence of ``'age'``, remove ``*`` characters, flatten existing
    line breaks to spaces, and start a new line before each of the
    ``gender:``, ``conditions:``, ``observations:``, ``medications:`` and
    ``procedures:`` section labels.

    Args:
        ann_text: the answer text returned by the LLM.

    Returns:
        The cleaned text. If ``'age'`` does not occur, nothing is cut off.
    """
    # strip and lower
    ann_text = ann_text.strip().lower()

    # delete intro -- str.find returns -1 when 'age' is missing, and slicing
    # from -1 would throw away everything except the last character
    start = ann_text.find('age')
    if start != -1:
        ann_text = ann_text[start:]

    # remove stars and new lines
    ann_text = ann_text.replace('*', '').replace('\n', ' ')

    # add new lines to sections
    for section in ('gender:', 'conditions:', 'observations:', 'medications:', 'procedures:'):
        ann_text = ann_text.replace(section, '\n' + section)
    return ann_text

# annotate observation
def annotate_observation(observation):
    """Annotate one FHIR observation with entities extracted by the LLM.

    The text in ``observation.code.text`` is turned into a ``ClinicalNote``
    prompt, sent to the LLM, and the cleaned answer is stored as the
    observation's only note. The status is set to ``'registered'`` and the
    observation is updated on the server.

    Relies on the module-level ``sv`` and ``fhir_client``.

    Returns:
        The preprocessed annotation text.
    """
    # prompt -> LLM -> cleaned annotation text
    note_text = observation.code.text
    llm_answer = run_llm(
        get_completion_prompt(cls_name='ClinicalNote', sv=sv, text=note_text)
    )
    annotation_text = preprocessing(llm_answer)

    # store the annotation on the observation and push it to the server
    annotation = a.Annotation({
        "authorString" : "Raw annotations from LLama2",
        "text": annotation_text
    })
    observation.note = [annotation]
    observation.status = 'registered'
    observation.update(fhir_client.server)

    return annotation_text

# annotate observations
# Keep fetching batches of 'preliminary' observations until the server returns
# none; annotate_observation moves each one to 'registered'.
job_not_done = True
while job_not_done:

    # fetch the next batch through a fresh client
    fhir_client = FHIRClient(settings=settings)
    search = o.Observation.where(struct={'status': 'preliminary'})
    observations = search.perform_resources(fhir_client.server)

    # an empty batch means the job is finished
    job_not_done = len(observations) != 0

    # annotate the batch (no-op when it is empty)
    for observation in observations:
        annotate_observation(observation)

Loaded as API: http://localhost:7860/ ✔


In [16]:
from oaklib import BasicOntologyInterface, get_adapter
from oaklib.datamodels.text_annotator import TextAnnotationConfiguration
from oaklib.implementations import OntoPortalImplementationBase
import importlib
import inflection
import re

from core import ExtractionResult, NamedEntity

# Python module holding the classes for the template; ground_annotation_object
# looks up the result class in it by name.
# NOTE(review): "test" is also the name of a standard-library package -- confirm
# that the local test.py is the module that actually gets imported.
mod = importlib.import_module("test")
template_module = mod
# annotation key under which a class lists its prompt examples
# (read by normalize_named_entity)
ANNOTATION_KEY_EXAMPLES = "prompt.examples"
# default annotators for groundings()
annotators = [get_adapter('sqlite:obo:ncit')]
# identifier mappers consulted by map_identifier()
mappers = [get_adapter("translator:")]
# entities grounded so far; normalize_named_entity appends to this list
named_entities = []

# parse completion payload
def parse_completion_payload(results, cls_name, sv):
    """Parse an LLM answer and ground it into an instance of ``cls_name``.

    NOTE(review): the two helpers called here are defined in a later cell, and
    that cell defines this function again; whichever definition ran last is
    the one in effect.
    """
    parsed_fields = _parse_response_to_dict(results, cls_name, sv)
    logging.debug(f"RAW: {parsed_fields}")
    grounded = ground_annotation_object(parsed_fields, cls_name, sv)
    return grounded

In [17]:
# test on 1 record
# The sample has the shape of a preprocessed LLM answer. The \n escapes inside
# the triple-quoted string become real line breaks, so every section is parsed
# as a line of its own.
result = """
age: not provided \ngender: male \nconditions: diabetes \nobservations:  blood glucose level: 175 mg/dl \nmedications:  medication name: not provided  dose: not provided  frequency: not provided \nprocedures:  none mentioned scores:  none mentioned
"""

# parse the result and ground it against the ClinicalNote class
ann_dict = parse_completion_payload(
    results = result,
    cls_name = 'ClinicalNote', 
    sv = sv
)

# display the grounded object as the cell output
ann_dict

ClinicalNote(age=None, gender='NCIT:C20197', conditions=['NCIT:C2985'], observations=['blood glucose level: 175 mg/dl'], medications=[], procedures=[])

In [19]:
# inspect the induced slots of ClinicalNote (the list get_completion_prompt iterates over)
sv.class_induced_slots('ClinicalNote')

[SlotDefinition(name='age', id_prefixes=[], definition_uri=None, local_names=JsonObj(), conforms_to=None, implements=[], extensions=JsonObj(), annotations=JsonObj(), description='age of the patient', alt_descriptions=JsonObj(), title=None, deprecated=None, todos=[], notes=[], comments=[], examples=[], in_subset=[], from_schema='test', imported_from=None, source=None, in_language=None, see_also=[], deprecated_element_has_exact_replacement=None, deprecated_element_has_possible_replacement=None, aliases=[], structured_aliases=JsonObj(), mappings=[], exact_mappings=[], close_mappings=[], related_mappings=[], narrow_mappings=[], broad_mappings=[], created_by=None, created_on=None, last_updated_on=None, modified_by=None, status=None, rank=None, is_a=None, abstract=None, mixin=None, mixins=[], apply_to=[], values_from=[], string_serialization=None, singular_name=None, domain=None, slot_uri=None, multivalued=None, inherited=None, readonly=None, ifabsent=None, list_elements_unique=None, list_el

In [None]:
def groundings(text, cls, annotators = annotators):
    """Yield candidate ontology identifiers for ``text``.

    Candidates are produced lazily, in this order:

    1. ``text`` itself when it looks like ``PREFIX:digits`` and the prefix
       matches (case-insensitively) one of ``cls.id_prefixes``;
    2. groundings of the singularized text, when that differs from the
       lower-cased text;
    3. groundings of every ``[...]`` component (or, if there are none, every
       ``(...)`` component), followed by groundings of the text with those
       components removed;
    4. object ids returned by each annotator's ``annotate_text``: a whole-text
       pass first, then a partial-match pass that is only attempted for
       ``OntoPortalImplementationBase`` annotators.

    No de-duplication is performed.

    Args:
        text: the string to ground.
        cls: class definition; its ``name`` and ``id_prefixes`` are read.
        annotators: annotators to query; defaults to the module-level
            ``annotators``. If it contains ``cls.name``, the entry stored
            under that key is used instead.

    Yields:
        Identifier strings.

    Raises:
        AssertionError: if removing the bracketed components did not shorten
            the text.
    """
    logger.info(f"GROUNDING {text} using {cls.name}")

    # text that already is a CURIE with a known prefix is passed through,
    # using the prefix spelling declared on the class
    id_matches = re.match(r"^(\S+):(\d+)$", text)
    if id_matches:
        obj_prefix = id_matches.group(1)
        matching_prefixes = [x for x in cls.id_prefixes if x.upper() == obj_prefix.upper()]
        if matching_prefixes:
            yield matching_prefixes[0] + ":" + id_matches.group(2)

    # Recursive calls forward `annotators` so that a caller-supplied value is
    # honoured at every level instead of silently reverting to the default.
    text_lower = text.lower()
    text_singularized = inflection.singularize(text_lower)
    if text_singularized != text_lower:
        logger.info(f"Singularized {text} to {text_singularized}")
        yield from groundings(text_singularized, cls, annotators)

    # square brackets take precedence; round ones are only considered when
    # there are no square-bracketed components
    paren_char = "["
    parenthetical_components = re.findall(r"\[(.*?)\]", text_lower)
    if not parenthetical_components:
        paren_char = "("
        parenthetical_components = re.findall(r"\((.*?)\)", text_lower)
    if parenthetical_components:
        logger.info(f"{text_lower} =>paren=> {parenthetical_components}")
        trimmed_text = text_lower
        for component in parenthetical_components:
            if component:
                logger.debug(
                    f"RECURSIVE GROUNDING OF {component} from {parenthetical_components}"
                )
                yield from groundings(component, cls, annotators)
            if paren_char == "(":
                trimmed_text = trimmed_text.replace(f"({component})", "")
            elif paren_char == "[":
                trimmed_text = trimmed_text.replace(f"[{component}]", "")
            else:
                raise AssertionError(f"Unknown paren char {paren_char}")
        trimmed_text = trimmed_text.strip().replace("  ", " ")
        if trimmed_text:
            # guards the recursion below against looping on the same text
            if len(trimmed_text) >= len(text_lower):
                raise AssertionError(
                    f"Trimmed text {trimmed_text} is not shorter than {text_lower}"
                )
            logger.debug(
                f"{text_lower} =>trimmed=> {trimmed_text}; in {parenthetical_components}"
            )
            yield from groundings(trimmed_text, cls, annotators)

    # class-specific annotators, when the container is keyed by class name
    if annotators and cls.name in annotators:
        annotators = annotators[cls.name]
    # prioritize whole matches by running these first
    for matches_whole_text in [True, False]:
        config = TextAnnotationConfiguration(matches_whole_text=matches_whole_text)
        for annotator in annotators:
            # NOTE(review): this branch indexes `annotators` by the selector
            # string, which presumably expects a dict of adapters -- confirm
            # the intended container type before passing selector strings.
            if isinstance(annotator, str):
                if annotator not in annotators:
                    logger.info(f"Loading annotator {annotator}")
                    annotators[annotator] = get_adapter(annotator)
                annotator = annotators[annotator]
            if not matches_whole_text and not isinstance(
                annotator, OntoPortalImplementationBase
            ):
                # TODO: allow more fine-grained control
                logger.info(
                    f"Skipping {type(annotator)} as it does not support partial matches"
                )
                continue
            # best effort: a failing annotator is logged and skipped
            try:
                results = annotator.annotate_text(text, config)
                for result in results:
                    yield result.object_id
            except Exception as e:
                logger.error(f"Error with {annotator} for {text}: {e}")

def parse_completion_payload(results, cls_name, sv):
    """Parse an LLM answer and ground it into an instance of ``cls_name``.

    The answer is first split into a field -> value dict, which is logged at
    DEBUG level, and then grounded field by field.

    NOTE(review): an earlier cell defines this function as well; this
    definition replaces it once the cell has run.
    """
    parsed_fields = _parse_response_to_dict(results, cls_name, sv)
    logging.debug(f"RAW: {parsed_fields}")
    grounded = ground_annotation_object(parsed_fields, cls_name, sv)
    return grounded

def ground_annotation_object(ann, cls_name, sv):
    """Ground the values of a parsed annotation and build the result object.

    Every value is normalised according to the range of its slot: tuples are
    treated as positional values for the slots of the range class, dicts are
    grounded recursively as an object of the range class, anything else goes
    through ``normalize_named_entity``. Values of enum-ranged slots are
    matched case-insensitively against the permissible values and become
    ``None`` when nothing matches.

    Args:
        ann: dict mapping field name to a value, a list of values, or None.
        cls_name: name of the class the fields belong to; also the name of
            the class looked up in ``template_module``.
        sv: schema view used to resolve slots, classes and enums.

    Returns:
        An instance of the ``cls_name`` class from ``template_module``,
        created from the grounded fields.
    """
    logging.debug(f"Grounding annotation object {ann}")
    new_ann = {}
    for field, vals in ann.items():
        if vals is None:
            new_ann[field] = None
            continue

        # treat single values as a one-element list; unwrapped again below
        multivalued = isinstance(vals, list)
        if not multivalued:
            vals = [vals]

        slot = sv.induced_slot(field, cls_name)
        rng_cls = sv.get_class(slot.range)
        enum_def = None
        if slot.range and slot.range in sv.all_enums():
            enum_def = sv.get_enum(slot.range)

        new_ann[field] = []
        for val in vals:
            if not val:
                continue
            if isinstance(val, tuple):
                # special case for pairs: the i-th element belongs to the
                # i-th slot of the range class
                sub_slots = sv.class_induced_slots(rng_cls.name)
                obj = {}
                for i in range(0, len(val)):
                    sub_slot = sub_slots[i]
                    sub_rng = sv.get_class(sub_slot.range)
                    if not sub_rng:
                        logging.error(f"Cannot find range for {sub_slot.name}")
                    # `sv` is a required argument of normalize_named_entity
                    result = normalize_named_entity(val[i], sub_slot.range, sv)
                    obj[sub_slot.name] = result
            elif isinstance(val, dict):
                # recurse; this function takes the class *name*, not the
                # class definition object
                obj = ground_annotation_object(val, rng_cls.name, sv)
            else:
                obj = normalize_named_entity(val, slot.range, sv)
            if enum_def:
                found = False
                logging.info(f"Looking for {obj} in {enum_def.name}")
                for k, _pv in enum_def.permissible_values.items():
                    if obj.lower() == k.lower():
                        # use the spelling declared in the schema
                        obj = k
                        found = True
                        break
                if not found:
                    logging.info(f"Cannot find enum value for {obj} in {enum_def.name}")
                    obj = None
            if multivalued:
                new_ann[field].append(obj)
            else:
                new_ann[field] = obj
    logging.debug(f"Creating object from dict {new_ann}")
    logging.info(new_ann)
    py_cls = template_module.__dict__[cls_name]
    return py_cls(**new_ann)

def is_valid_identifier(input_id, cls, sv):
    """Tell whether ``input_id`` is an acceptable identifier for ``cls``.

    When the class declares ``id_prefixes``, the identifier must contain a
    colon and the part before the first colon must be one of them (exact,
    case-sensitive match). When the identifier slot of the class defines a
    pattern, the identifier must match it from the start of the string.

    Returns:
        True when every applicable check passes, False otherwise.
    """
    # prefix check -- only when the class restricts its prefixes
    if cls.id_prefixes:
        head, colon, _tail = input_id.partition(":")
        if not colon:
            return False
        if head not in cls.id_prefixes:
            logger.debug(f"ID {input_id} not in prefixes {cls.id_prefixes}")
            return False

    # pattern check -- only when the identifier slot declares a pattern
    id_slot = sv.get_identifier_slot(cls.name)
    if id_slot and id_slot.pattern:
        if re.compile(id_slot.pattern).match(input_id) is None:
            logger.debug(f"ID {input_id} does not match pattern {id_slot.pattern}")
            return False

    return True

def normalize_identifier(input_id, cls, sv):
    """Yield the valid forms of ``input_id`` for ``cls``.

    ``input_id`` itself comes first when it is valid, followed by every
    identifier produced by ``map_identifier`` that differs from ``input_id``
    and is valid for the class.
    """
    if is_valid_identifier(input_id, cls, sv):
        yield input_id
    # lazily filter the mapped candidates, skipping the input itself
    yield from (
        candidate
        for candidate in map_identifier(input_id, cls)
        if candidate != input_id and is_valid_identifier(candidate, cls, sv)
    )

def map_identifier(input_id, cls):
    """Yield ``input_id`` in normalised form, followed by its mapped ids.

    BioPortal and MeSH URLs are shortened by removing the base URL and
    turning the remaining ``/`` into ``:``; a ``drugbank:`` prefix is
    upper-cased. The normalised identifier is always yielded first. If the
    class declares ``id_prefixes`` and the module-level ``mappers`` list is
    not empty, the object id of every mapping returned by each mapper's
    ``sssom_mappings`` is yielded as a string afterwards.

    Args:
        input_id: identifier or URL to normalise.
        cls: class definition; only ``id_prefixes`` is read.

    Raises:
        ValueError: if a mapper is not a ``MappingProviderInterface``.
    """
    if input_id.startswith("http://purl.bioontology.org/ontology"):
        logging.info(f"Normalizing BioPortal id {input_id}")
        input_id = input_id.replace("http://purl.bioontology.org/ontology/", "").replace(
            "/", ":"
        )
    if input_id.startswith("http://id.nlm.nih.gov/mesh/"):
        logging.info(f"Normalizing MESH id {input_id}")
        input_id = input_id.replace("http://id.nlm.nih.gov/mesh/", "").replace("/", ":")
    if input_id.startswith("drugbank:"):
        input_id = input_id.replace("drugbank:", "DRUGBANK:")
    yield input_id

    # mappings are only looked up for classes that restrict their prefixes
    if not cls.id_prefixes:
        return
    if not mappers:
        return

    # The notebook's import cells never bring this name into scope, so the
    # isinstance check below used to fail with NameError. It is imported here,
    # at the point of use, to keep the function self-contained.
    from oaklib.interfaces import MappingProviderInterface

    for mapper in mappers:
        if isinstance(mapper, MappingProviderInterface):
            for mapping in mapper.sssom_mappings([input_id]):
                yield str(mapping.object_id)
        else:
            raise ValueError(f"Unknown mapper type {mapper}")

def normalize_named_entity(text, range, sv):
    """Ground ``text`` to an identifier valid for the class named ``range``.

    Returns ``text`` unchanged when ``range`` is not a class of the schema or
    when no grounding can be normalised. Text that equals one of the prompt
    examples of the class is flagged as a likely hallucination. Otherwise the
    first normalised identifier is returned, and it is recorded in the
    module-level ``named_entities`` list unless already present.
    """
    target_cls = sv.get_class(range)
    if target_cls is None:
        return text

    # answers that merely repeat a prompt example are not trusted
    if ANNOTATION_KEY_EXAMPLES in target_cls.annotations:
        raw_examples = target_cls.annotations[ANNOTATION_KEY_EXAMPLES].value
        examples = [example.lower() for example in raw_examples.split(", ")]
        logger.debug(f"Will exclude if in list of examples: {examples}")
        if text.lower() in examples:
            logger.warning(f"Likely a hallucination as it is the example set: {text}")
            return f"LIKELY HALLUCINATION: {text}"

    # the first grounding that normalises successfully wins
    for obj_id in groundings(text, target_cls):
        logger.info(f"Grounding {text} to {obj_id}; next step is to normalize")
        for normalized_id in normalize_identifier(obj_id, target_cls, sv):
            already_known = any(
                entity for entity in named_entities if entity.id == normalized_id
            )
            if not already_known:
                named_entities.append(NamedEntity(id=normalized_id, label=text))
            logger.info(f"Normalized {text} with {obj_id} to {normalized_id}")
            return normalized_id

    logger.info(f"Could not ground and normalize {text} to {target_cls.name}")
    return text

# parse a parameter annotation (1 line)
def _parse_line_to_dict(line, sv, cls_name):
    
    # log
    logging.info(f"PARSING LINE: {line}")
    field, val = line.split(":", 1)
    
    # test if value is valid
    value_valid = True
    val = val.strip()

    # phrases to ignore
    ignore_phrases = [
        'not available',
        'not provided',
        'not applicable',
        'not done',
        'not known',
        'not performed',
        'not reported',
        'not specified',
        'not tested',
        'not treated',
        'none provided',
        'none reported',
        'none specified',
        'none',
        'unknown',
        'unspecified',
    ]

    for ignore_phrase in ignore_phrases:
        
        # ignore if phrase is found
        if val.find(ignore_phrase) != -1:
            value_valid = False
            logging.info(f"LINE {line} has value {ignore_phrase} and will be ignored")

    # process value only if valid
    if value_valid:

        # Field nornalization:
        # The LLML may mutate the output format somewhat,
        # randomly pluralizing or replacing spaces with underscores
        field = field.lower().replace(" ", "_")
        cls_slots = sv.class_slots(cls_name)
        slot = None
        if field in cls_slots:
            slot = sv.induced_slot(field, cls_name)
        else:
            if field.endswith("s"):
                field = field[:-1]
            if field in cls_slots:
                slot = sv.induced_slot(field, cls_name)
        if not slot:
            logging.error(f"Cannot find slot for {field} in {line}")
            # raise ValueError(f"Cannot find slot for {field} in {line}")
            return
        if not val:
            msg = f"Empty value in key-value line: {line}"
            if slot.required:
                raise ValueError(msg)
            if slot.recommended:
                logging.warning(msg)
            return
        inlined = slot.inlined
        slot_range = sv.get_class(slot.range)
        if not inlined:
            if slot.range in sv.all_classes():
                inlined = sv.get_identifier_slot(slot_range.name) is None
        val = val.strip()
        if slot.multivalued:
            vals = [v.strip() for v in val.split(";")]
        else:
            vals = [val]
        vals = [val for val in vals if val]
        logging.debug(f"SLOT: {slot.name} INL: {inlined} VALS: {vals}")
        # transform back from list to single value if not multivalued
        if slot.multivalued:
            final_val = vals
        else:
            if len(vals) != 1:
                logging.error(f"Expected 1 value for {slot.name} in '{line}' but got {vals}")
            final_val = vals[0]
        return field, final_val
    return None

# parse a record LLM annotation result line by line
def _parse_response_to_dict(result, cls_name, sv):

    # process line by line
    lines = result.splitlines()
    ann = {}
    promptable_slots = sv.class_induced_slots(cls_name)
    for line in lines:
        line = line.strip()
        if not line:
            continue
        if ":" not in line:
            if len(promptable_slots) == 1:
                slot = promptable_slots[0]
                logging.warning(
                    f"Coercing to YAML-like with key {slot.name}: Original line: {line}"
                )
                line = f"{slot.name}: {line}"
            else:
                logging.error(f"Line '{line}' does not contain a colon; ignoring")
                return
        r = _parse_line_to_dict(line, sv, cls_name)
        if r is not None:
            field, val = r
            ann[field] = val
            logging.info(f"FIELD: {field} VAL: {val}")
    return ann