In [5]:
from tqdm.notebook import tqdm
from openai import OpenAI
import re, os, json, pysbd
import prompts

# Configuration shared by the cells below.
# setdefault keeps a key that is already exported in the environment; the
# placeholder is only a fallback and must be replaced (or the variable
# exported) before any request is sent.
os.environ.setdefault("OPENAI_API_KEY", "YOUR_API")
client = OpenAI()
MODEL_NAME = "gpt-4o-2024-08-06"
# clean=False is passed so pysbd does not rewrite the text it segments.
seg = pysbd.Segmenter(language="en", clean=False)

In [3]:
def recalculate_spans(sentences, entities, location):
    """Re-anchor document-level entity offsets onto individual sentences.

    Sentence offsets are reconstructed under the assumption that consecutive
    sentences were separated by exactly one character in the source text
    (hence the ``+ 1`` when advancing the running position).

    Args:
        sentences: List of sentence strings in document order.
        entities: Iterable of dicts with keys ``"start_idx"``, ``"end_idx"``
            (treated as inclusive: 1 is added to obtain an exclusive end),
            ``"text_span"``, ``"location"`` and optionally ``"label"``.
        location: Value stored under ``"location"`` in every sentence record.

    Returns:
        A list with one dict per sentence holding ``"location"``,
        ``"start_sent"`` (sentence start offset), ``"text"`` and
        ``"annotations"``. Each annotation carries the original offsets
        (``"start"``/``"end"``) and the sentence-relative ones
        (``"new_start"``/``"new_end"``). An entity is attached to the sentence
        in which it starts.
    """
    # Character span (start, exclusive end) of every sentence in the document.
    sentence_spans = []
    current_position = 0
    for sentence in sentences:
        sentence_length = len(sentence)
        sentence_spans.append((current_position, current_position + sentence_length))
        current_position += sentence_length + 1

    abstract_data = []
    for i, (sentence_start, sentence_end) in enumerate(sentence_spans):
        sentence_text = sentences[i]

        annotations = []
        for entity in entities:
            entity_type = entity.get("label")
            entity_start = int(entity["start_idx"])
            entity_end = int(entity["end_idx"]) + 1
            entity_mention = entity["text_span"]

            if not (sentence_start <= entity_start < sentence_end):
                continue

            new_start = entity_start - sentence_start
            new_end = entity_end - sentence_start

            # For the diagnostic check only: when the entity runs past this
            # sentence, compare against this sentence plus the next one. The
            # bounds guard avoids an IndexError when the overrun happens in
            # the last sentence.
            overruns = new_end > sentence_end - sentence_start
            if overruns and i + 1 < len(sentences):
                esent = sentence_text + " " + sentences[i + 1]
            else:
                esent = sentence_text

            if esent[new_start:new_end] != entity_mention:
                # Offsets and mention disagree: print the details but still
                # keep the annotation, as callers re-validate spans later.
                print(esent)
                print(entity_start, entity_end)
                print(new_start, new_end, current_position)
                print(sentence_start, sentence_end)
                print(esent[new_start:new_end], entity_mention)

            annotations.append({
                "label": entity_type,
                "start": entity_start,
                "end": entity_end,
                "new_start": new_start,
                "new_end": new_end,
                "mention": entity_mention,
                "location": entity["location"]
            })

        abstract_data.append({
            "location": location,
            "start_sent": sentence_start,
            "text": sentence_text,
            "annotations": annotations
        })
    return abstract_data

def prompt_llm(system, sentence, model=None):
    """Send one sentence to the chat-completions endpoint and return the reply text.

    Args:
        system (str): System prompt.
        sentence (str): Text to analyse; it is wrapped in a fenced block and
            prefixed with ``Text:`` before being sent as the user message.
        model (str | None): Model name forwarded unchanged to the API.

    Returns:
        The ``content`` of the first choice's message.

    Raises:
        TypeError: Re-raised after printing the response when the response
            cannot be indexed as expected.
    """
    # Instantiate the client: binding the bare class (``OpenAI``) has no
    # ``chat`` attribute and fails on the first request.
    client = OpenAI()
    user = "\nText:\n```\n" + sentence + "\n```\n"
    response = client.chat.completions.create(
        model=model,
        messages=[
          {"role": "system", "content": system},
          {"role": "user", "content": user}
        ],
    )
    try:
        return response.choices[0].message.content
    except TypeError:
        print(response)
        # Bare raise keeps the original message and traceback.
        raise

def tag_entities(text, entities):
    """Wrap every entity span of ``text`` in ``<entity type="...">...</entity>`` markup.

    Args:
        text (str): Sentence to annotate.
        entities (list[dict]): Items with ``"start"``, ``"end"`` (exclusive)
            and ``"label"`` keys, offsets relative to ``text``.

    Returns:
        str: ``text`` with the markup inserted. Labels ``"DDF"``/``"ddf"`` are
        rendered as ``"disease or finding"``.
    """
    # Work from the rightmost span to the leftmost so that inserting markup
    # never shifts the offsets of spans that are still to be processed.
    for span in sorted(entities, key=lambda item: item['start'], reverse=True):
        span_start, span_end = span['start'], span['end']
        label = span['label']
        if label in ("DDF", "ddf"):
            label = "disease or finding"

        opening = "<entity type=\"" + str(label) + "\">"
        # Closing tag first, then the opening tag, each spliced by offset.
        text = text[:span_end] + "</entity>" + text[span_end:]
        text = text[:span_start] + opening + text[span_start:]

    return text

def extract_and_parse_json(text):
    """Pull the outermost ``{...}`` object out of ``text`` and decode it as JSON.

    Before decoding, ``;}`` and ``,}`` are rewritten to ``}``. A string value
    under ``"Relational triples"`` is wrapped in a one-element list.

    Args:
        text (str): Raw model output that may embed a JSON object.

    Returns:
        dict: The decoded object, or ``{"Relational triples": []}`` when no
        object is found or decoding fails.
    """
    found = re.search(r'(\{.*\})', text, re.DOTALL)
    if not found:
        return {"Relational triples": []}

    # The match always starts with "{" and ends with "}", so plain substring
    # replacement is enough to drop a stray separator before a closing brace.
    json_str = found.group(0).strip().replace(";}", "}").strip().replace(",}", "}")

    try:
        parsed = json.loads(json_str)
    except json.JSONDecodeError:
        print("Error in decoding json string: ", json_str)
        return {"Relational triples": []}

    print(parsed)
    triples = parsed.get("Relational triples")
    if isinstance(triples, str):
        return {"Relational triples": [triples]}
    return parsed
    
def extract_spo(triple_str):
    """
    Split a triple written as "[subject] PREDICATE [object]" into its parts.

    Args:
        triple_str (str): The input string; the pattern is anchored at its start.

    Returns:
        tuple: (subject, predicate, object), each stripped of surrounding
        whitespace, or three empty strings when the input does not match.
    """
    found = re.match(r"\[(.*?)\]\s+(.+?)\s+\[(.*?)\]", triple_str)
    if found is None:
        return "", "", ""

    head, relation, tail = (part.strip() for part in found.groups())
    return head, relation, tail

1. Load GLiNER predictions and split the test documents into sentences.

In [6]:
# Split titles and abstracts into sentences and re-anchor entity spans per sentence.

file ="data/predictions_GLINER.json"
with open(file) as fin:
    data = json.load(fin)

for pid in data:
    title = data[pid]["metadata"]["title"]
    abstract = data[pid]["metadata"]["abstract"]
    entities = data[pid]["entities"]
    sent_data = []
    for (t, loc) in [(title, "title"), (abstract, "abstract")]:
        # Split on ". " followed by an uppercase letter. The single-space
        # separator matters: recalculate_spans rebuilds offsets assuming one
        # character between consecutive sentences.
        sentences = re.split(r"(?<=\.) (?=[A-Z])", t)
        ents = [e|{"location":loc} for e in entities if e["location"]==loc]
        # Called once per text. Calling it inside a loop over the sentences
        # would append every sentence record once per sentence.
        # NOTE(review): pysbd re-segmentation of sentences longer than 400
        # characters was computed here but its result was never used; it is
        # omitted. Wiring it in needs offset-aware span handling — confirm intent.
        sent_data.extend(recalculate_spans(sentences, ents, loc))
    # Replace the raw record with its sentence-level view.
    data[pid] = {"sent_data":sent_data}

2. Tag sentences with extracted entities.

In [7]:
# Attach a "tagged_text" field to every sentence, marking only the entities
# whose sentence-relative offsets reproduce their mention exactly.
for pid, document in data.items():
    for record in document["sent_data"]:
        sentence_text = record["text"]
        verified = []

        for ann in record["annotations"]:
            covered = sentence_text[ann["new_start"]:ann["new_end"]]
            if covered != ann["mention"]:
                # Offsets do not match the mention: report and leave untagged.
                print("Wrong spans: ", covered, ann["mention"])
                continue
            verified.append({
                "start": ann["new_start"],
                "end": ann["new_end"],
                "location": ann["location"],
                "mention": ann["mention"],
                "label": ann["label"],
            })

        record["tagged_text"] = tag_entities(sentence_text, verified)

3. Prompt OpenAI model to extract specified relation types.

In [None]:
relation_types = ["LOCATED_IN", "INTERACT", "IMPACT", "LINKED_TO", "AFFECT", "USED_BY"]

# One pass over the whole corpus per relation type; each pass appends one raw
# model response to the "pred_relations" list of every sentence with entities.
for sem_type in relation_types:
    # The system prompt is the attribute of `prompts` named after the relation type.
    system_prompt = getattr(prompts, sem_type)

    for pid in tqdm(data):
        for record in data[pid]["sent_data"]:
            sentence_entities = record["annotations"]
            tagged_sentence = record["tagged_text"]

            # Every sentence gets the list, even when it is never prompted.
            record.setdefault("pred_relations", [])

            if not sentence_entities:
                continue

            record["pred_relations"].append(
                prompt_llm(system_prompt, tagged_sentence, model=MODEL_NAME)
            )

4. Filter extracted relations based on legal relational patterns.

In [None]:
# Allowed (subject label, object label, predicate) combinations, all lowercase.
# A predicted relation is kept only if its triple appears in this list.
legal_relations = [
 ("anatomical location", "human", "located in"),
 ("anatomical location", "animal", "located in"),
 ("bacteria", "bacteria", "interact"),
 ("bacteria", "chemical", "interact"),
 ("bacteria", "ddf", "influence"),
 ("bacteria", "gene", "change expression"),
 ("bacteria", "human", "located in"),
 ("bacteria", "animal", "located in"),
 ("bacteria", "microbiome", "part of"),
 # "anatomical location" (was misspelled "animal location", which matches no
 # label used elsewhere in this list and so could never pass the filter).
 ("chemical", "anatomical location", "located in"),
 ("chemical", "human", "located in"),
 ("chemical", "animal", "located in"),
 ("chemical", "chemical", "interact"),
 ("chemical", "chemical", "part of"),
 ("chemical", "microbiome", "impact"),
 ("chemical", "microbiome", "produced by"),
 ("chemical", "bacteria", "impact"),
 ("dietary supplement", "bacteria", "impact"),
 ("drug", "bacteria", "impact"),
 ("food", "bacteria", "impact"),
 ("dietary supplement", "microbiome", "impact"),
 ("drug", "microbiome", "impact"),
 ("food", "microbiome", "impact"),
 ("chemical", "ddf", "influence"),
 ("dietary supplement", "ddf", "influence"),
 ("food", "ddf", "influence"),
 ("chemical", "gene", "change expression"),
 ("dietary supplement", "gene", "change expression"),
 ("drug", "gene", "change expression"),
 ("food", "gene", "change expression"),
 ("chemical", "human", "administered"),
 ("dietary supplement", "human", "administered"),
 ("drug", "human", "administered"),
 ("food", "human", "administered"),
 ("chemical", "animal", "administered"),
 ("dietary supplement", "animal", "administered"),
 ("drug", "animal", "administered"),
 ("food", "animal", "administered"),
 ("ddf", "anatomical location", "strike"),
 ("ddf", "bacteria", "change abundance"),
 ("ddf", "microbiome", "change abundance"),
 ("ddf", "chemical", "interact"),
 ("ddf", "ddf", "affect"),
 ("ddf", "ddf", "is a"),
 ("ddf", "human", "target"),
 ("ddf", "animal", "target"),
 ("drug", "chemical", "interact"),
 ("drug", "drug", "interact"),
 ("drug", "ddf", "change effect"),
 ("human", "biomedical technique", "used by"),
 ("animal", "biomedical technique", "used by"),
 ("microbiome", "biomedical technique", "used by"),
 ("microbiome", "anatomical location", "located in"),
 ("microbiome", "human", "located in"),
 ("microbiome", "animal", "located in"),
 ("microbiome", "gene", "change expression"),
 ("microbiome", "ddf", "is linked to"),
 ("microbiome", "microbiome", "compared to")
 ]

# Lowercase mentions that are excluded as relation arguments: a triple whose
# subject or object is exactly one of these words is not matched to an entity.
unspecified_entities = [
    "bacteria",
    "metabolites", "metabolite",
    "genes", "gene",
    "disease", "diseases",
    "disorder", "disorders",
    "microbiome", "microbiota",
    "human",
    "animal",
    "food",
    "chemical", "chemicals",
    "drug", "drugs",
    "neurotransmitters", "neurotransmitter",
]

In [None]:
def _match_entity(sentence_entities, mention):
    """Return the first annotation whose mention equals ``mention`` (case-insensitive).

    Returns None when nothing matches or when ``mention`` is one of the
    generic words listed in ``unspecified_entities``.
    """
    needle = mention.lower()
    if needle in unspecified_entities:
        return None
    for candidate in sentence_entities:
        if candidate["mention"].lower() == needle:
            return candidate
    return None


def _normalise_label(label):
    """Lowercase a label, mapping "ddf"/"disease or finding" to "DDF"."""
    lowered = label.lower()
    return "DDF" if lowered in ["ddf", "disease or finding"] else lowered


# Turn the raw model responses into filtered relation lists, per document.
for pid in tqdm(list(data)):

    relations = []
    for sent in data[pid]["sent_data"]:
        entities = sent["annotations"]
        sent_spos = []

        for raw_rel in sent.get("pred_relations", []):
            # raw_rel is whatever prompt_llm returned; skip None or blank replies.
            if not raw_rel or not raw_rel.strip():
                continue

            triples = extract_and_parse_json(raw_rel.strip()).get("Relational triples", [])
            if not isinstance(triples, list):
                continue

            for triple in triples:
                # extract_spo runs a regex, so anything but a string would raise.
                if not isinstance(triple, str):
                    continue
                subj, predicate, obj = extract_spo(triple)

                # Subject and object are resolved per triple. Accumulating
                # them across triples would pair entities from earlier
                # triples with this triple's predicate.
                subj_entity = _match_entity(entities, subj)
                obj_entity = _match_entity(entities, obj)
                if subj_entity is None or obj_entity is None:
                    continue

                rel_type = predicate.lower()
                # housekeeping
                if rel_type == "associated with":
                    rel_type = "affect"

                lab1 = _normalise_label(subj_entity["label"])
                lab2 = _normalise_label(obj_entity["label"])

                if (lab1.lower(), lab2.lower(), rel_type) in legal_relations:
                    key = (subj_entity["mention"], lab1, rel_type, obj_entity["mention"], lab2)
                    if key not in sent_spos:
                        sent_spos.append(key)

        relations.extend(sent_spos)

    # Three views of the same relations, from most to least specific.
    ternary_mention_based_relations = [
        {
            'subject_text_span': item[0],
            'subject_label': item[1],
            'predicate': item[2],
            'object_text_span': item[3],
            'object_label': item[4]
        }
        for item in relations
    ]
    ternary_tag_based_relations = [
        {
            'subject_label': item[1],
            'predicate': item[2],
            'object_label': item[4]
        }
        for item in relations
    ]
    binary_tag_based_relations = [
        {
            'subject_label': item[1],
            'object_label': item[4]
        }
        for item in relations
    ]

    data[pid]["ternary_mention_based_relations"] = ternary_mention_based_relations
    data[pid]["ternary_tag_based_relations"] = ternary_tag_based_relations
    data[pid]["binary_tag_based_relations"] = binary_tag_based_relations

In [None]:
# One dict per relation view, keyed by document id; each value keeps only the
# corresponding relation list of that document.
tag_re = {
    doc_id: {"ternary_tag_based_relations": doc["ternary_tag_based_relations"]}
    for doc_id, doc in data.items()
}
binary_re = {
    doc_id: {"binary_tag_based_relations": doc["binary_tag_based_relations"]}
    for doc_id, doc in data.items()
}
mention_re = {
    doc_id: {"ternary_mention_based_relations": doc["ternary_mention_based_relations"]}
    for doc_id, doc in data.items()
}

In [None]:
# Write each relation view to its own JSON file.
# NOTE(review): the file names spell "Unsepervised" while the directories
# spell "Unsupervised" — confirm which spelling is expected before renaming.
submission_outputs = (
    ("outputs/Unsupervised_T622_R1_gpt4re/Unsepervised_T622_R1_gpt4re.json", binary_re),
    ("outputs/Unsupervised_T623_R1_gpt4re/Unsepervised_T623_R1_gpt4re.json", tag_re),
    ("outputs/Unsupervised_T624_R1_gpt4re/Unsepervised_T624_R1_gpt4re.json", mention_re),
)

for out_path, payload in submission_outputs:
    with open(out_path, "w") as file:
        json.dump(payload, file, indent=4)

In [None]:
# Save the complete per-document structure (sentences, tagged text, raw model
# responses and filtered relations) as a single JSON file.
with open("outputs/llm_results.json", "w") as results_file:
    json.dump(data, results_file, indent=4)