In [1]:
import os
import sys

# Put the parent of the current working directory on the import path.
# NOTE(review): the later `from src.utils import load_json` cells presumably
# rely on this, i.e. the notebook is expected to be started from a folder
# directly below the repository root — confirm.
path_root = os.path.dirname(os.getcwd())

# Only append once so re-running the cell does not add duplicate entries.
if path_root not in sys.path:
    sys.path.append(path_root)

## Example of going from Document to Graph

### Text2Graph

In [2]:
import time
import os
import torch
from transformers import T5Tokenizer, T5ForConditionalGeneration
import os
import re
import networkx as nx
import matplotlib.pyplot as plt
import matplotlib.patches as patches

In [None]:
# Tokenizer of the base T5 checkpoint, extended with the markers that delimit
# head / relation / tail spans in the model's linearised graph output.
t5_model_name = "t5-base"
tokenizer = T5Tokenizer.from_pretrained(t5_model_name)
kg_tokens_dict = ["<H>", "<R>", "<T>"]
num_added_toks = tokenizer.add_tokens(kg_tokens_dict)
text_prefix = "TEXT: "
graph_prefix = "GRAPH: "

model_path = "../models/webNLG_model.pkl"

# Prefer Apple's Metal backend (the original hard-coded choice), but fall back
# to CUDA or CPU so the notebook also runs on machines without MPS support.
_mps_backend = getattr(torch.backends, "mps", None)
if _mps_backend is not None and _mps_backend.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# SECURITY: torch.load unpickles the file, which can execute arbitrary code.
# Only load checkpoints from a trusted source.
model = torch.load(model_path, map_location=device)
# Switch to inference mode so dropout does not make generation stochastic.
# NOTE(review): assumes the pickle holds a full nn.Module (it is later used
# via `model.generate`) — confirm. Assignment form avoids echoing the model
# repr as cell output.
model = model.eval()

In [None]:
def parse_triple(content):
    """Parse linearised ``<H> head <R> relation <T> tail`` text.

    Args:
        content: Text containing zero or more ``<H>...<R>...<T>...`` groups.

    Returns:
        A tuple ``(entity_ls, hrt_ls)``:

        * ``entity_ls`` - set of whitespace-stripped head and tail strings.
        * ``hrt_ls`` - set of whitespace-stripped ``(head, relation, tail)``
          tuples, one per complete ``<H>...<R>...<T>...`` group.
    """
    # Entities are the spans following <H> and <T>. The previous version
    # scanned <H> and <R>, which reported relation names as entities and never
    # picked up tails. No trailing "<" is consumed, so a tag that directly
    # follows an entity is still seen by the scan.
    entity_ls = {
        span.strip()
        for span in re.findall(r"<(?:H|T)>([\s\w\.\/\-]+)", content)
    }

    # A triple needs all three markers in order; each part runs up to the
    # next "<".
    hrt_ls = {
        (head.strip(), relation.strip(), tail.strip())
        for head, relation, tail in re.findall(
            r"<H>([^<]+)<R>([^<]+)<T>([^<]+)", content
        )
    }
    return entity_ls, hrt_ls


def gen_json_response(hrt_ls):
    """Build a node/edge payload from ``(head, relation, tail)`` triples.

    Every distinct head or tail becomes one node; ids are handed out as
    0, 1, 2, ... in order of first appearance (head before tail within a
    triple). Every triple becomes one edge.

    Shape of the result:
    {"graph": {"nodes": [{"id": 0, "label": "Bob", "color": "#ffffff"}, ...],
               "edges": [{"from": 0, "to": 1, "label": "roommate"}, ...]}}
    """
    nodes = []
    edges = []
    ids_by_label = {}

    def node_id_for(label):
        # Register the label the first time it is seen, then reuse its id.
        if label not in ids_by_label:
            new_id = len(ids_by_label)
            ids_by_label[label] = new_id
            nodes.append({"id": new_id, "label": label, "color": "#ffffff"})
        return ids_by_label[label]

    for head, relation, tail in hrt_ls:
        source = node_id_for(head)
        target = node_id_for(tail)
        edges.append({"from": source, "to": target, "label": relation})

    return {"graph": {"nodes": nodes, "edges": edges}}


def get_graph(text: str):
    """Generate a graph from free text with the Text2Graph model.

    The text is prefixed with ``text_prefix``, tokenised (padded to 500
    tokens), run through ``model.generate`` with 4-beam search, and the decoded
    output is parsed into triples when it contains an ``<H>`` marker.

    Args:
        text: Input document.

    Returns:
        ``{"time": <elapsed seconds>, "data": {"graph": {"nodes": [...],
        "edges": [...]}}}``. The graph is empty when the decoded output has no
        ``<H>`` marker; in that case the raw output is printed instead.
    """
    start = time.time()

    encoded = tokenizer(
        text_prefix + text,
        return_tensors="pt",
        padding="max_length",
        max_length=500,
    )
    # Send the inputs to the device the model was loaded on instead of a
    # hard-coded "mps" string, so both always agree.
    input_ids = encoded.input_ids.to(device)
    attention_mask = encoded.attention_mask.to(device)

    # `temperature=0` was dropped: sampling is not enabled, so beam search
    # never used it (and recent transformers releases may warn about it).
    model_outputs = model.generate(
        input_ids=input_ids,
        attention_mask=attention_mask,
        num_beams=4,
        length_penalty=2.0,
        max_length=500,
    )

    out_content = tokenizer.decode(model_outputs[0], skip_special_tokens=True)
    if "<H>" in out_content:
        _, hrt_pool = parse_triple(out_content)
        print("-----Graph-----")
        # parse_triple returns a set; sort it so node ids and edge order are
        # reproducible across kernel restarts.
        data = gen_json_response(sorted(hrt_pool))
    else:
        print(out_content)
        data = {"graph": {"nodes": [], "edges": []}}

    return {"time": time.time() - start, "data": data}


def get_text_size(text, fontsize=10):
    """Estimate the size of the text in display coordinates.

    A throw-away figure is created, ``text`` is placed at (0, 0) and the
    extent of the resulting text artist is read back from the canvas renderer.

    Args:
        text: String to measure.
        fontsize: Font size passed to ``plt.text``.

    Returns:
        ``(width, height)`` of the text's window extent.
    """
    fig = plt.figure()
    try:
        text_artist = plt.text(0, 0, text, fontsize=fontsize)
        renderer = fig.canvas.get_renderer()
        bbox = text_artist.get_window_extent(renderer=renderer)
        return bbox.width, bbox.height
    finally:
        # Always release the temporary figure, even if rendering fails, so
        # repeated calls do not accumulate open figures.
        plt.close(fig)


def draw_graph(data, seed=None):
    """Draw the graph produced by ``get_graph`` with NetworkX and Matplotlib.

    Nodes are rendered as translucent blue ellipses sized to their label,
    edges as labelled arrows.

    Args:
        data: Mapping shaped like the return value of ``get_graph``. Read are
            ``data["data"]["graph"]["nodes"]`` (dicts with ``"id"``,
            ``"label"`` and optional ``"color"``) and
            ``data["data"]["graph"]["edges"]`` (dicts with ``"from"``,
            ``"to"`` and ``"label"``).
        seed: Optional seed forwarded to ``nx.spring_layout`` for a
            reproducible layout. ``None`` (default) keeps the random layout.

    Returns:
        None. The figure is shown with ``plt.show()``.
    """
    nodes = data["data"]["graph"]["nodes"]
    edges = data["data"]["graph"]["edges"]

    # Directed graph: edges point from "from" to "to".
    G = nx.DiGraph()
    for node in nodes:
        # The per-node "color" is stored on the graph but not used for
        # drawing; every node is painted with `node_fill` below.
        G.add_node(
            node["id"], label=node["label"], color=node.get("color", "#0000FF")
        )
    for edge in edges:
        G.add_edge(edge["from"], edge["to"], label=edge["label"])

    pos = nx.spring_layout(G, seed=seed)

    # RGB (0, 0, 1) with 0.5 transparency, shared by all nodes. A single
    # constant replaces the former list indexed by node id, which only worked
    # when ids happened to be 0..n-1.
    node_fill = (0, 0, 1, 0.5)

    # Nodes that exist only because an edge mentions them have no "label"
    # attribute; fall back to the node id so drawing does not fail.
    labels = {node: G.nodes[node].get("label", str(node)) for node in G.nodes}

    # Measure labels before creating the main figure: get_text_size opens and
    # closes its own figure, which would otherwise change pyplot's current
    # figure in the middle of drawing.
    label_widths = {
        node: get_text_size(label, fontsize=10)[0] / 100  # scale to data units
        for node, label in labels.items()
    }

    _, ax = plt.subplots(figsize=(12, 8))

    # Draw nodes as ellipses with centred labels.
    for node, (x, y) in pos.items():
        ax.add_patch(
            patches.Ellipse(
                (x, y),
                width=label_widths[node] + 0.2,
                height=0.1,
                color=node_fill,
                alpha=0.5,
            )
        )
        ax.text(
            x,
            y,
            labels[node],
            horizontalalignment="center",
            verticalalignment="center",
        )

    # Draw edges with increased width.
    nx.draw_networkx_edges(
        G, pos, arrowstyle="->", arrowsize=10, ax=ax, width=2
    )

    # Edge labels are drawn on the same axes explicitly instead of relying on
    # whatever pyplot considers the current axes.
    edge_labels = nx.get_edge_attributes(G, "label")
    nx.draw_networkx_edge_labels(G, pos, edge_labels=edge_labels, ax=ax)

    ax.set_axis_off()
    plt.show()

In [3]:
from src.utils import load_json

# Load the example LLM output from disk. The cells below index it by position
# (llm_output[0], [1], [2]) and pass each item to get_graph as `text`.
llm_path = "../example_output/example_pipeline_14_05_24/llm.json"
llm_output = load_json(llm_path)

In [None]:
# Run Text2Graph on the first example; the result is only displayed, not kept.
get_graph(text=llm_output[0])

In [None]:
# Run Text2Graph on the second example; the result is only displayed, not kept.
get_graph(text=llm_output[1])

In [None]:
# Run Text2Graph on the third example and keep the result in `data` so the
# next cell can draw it.
data = get_graph(text=llm_output[2])
data

In [None]:
# Plot the graph held in `data` (the get_graph result for llm_output[2]).
draw_graph(data)

### REBEL


In [8]:
from transformers import pipeline

# Text-to-text generation pipeline using the Babelscape/rebel-large checkpoint
# for both model and tokenizer. Its generated token ids are decoded manually
# and parsed by extract_triplets in the next cell.
triplet_extractor = pipeline(
    "text2text-generation",
    model="Babelscape/rebel-large",
    tokenizer="Babelscape/rebel-large",
)

In [12]:
# Same example document that was used for the Text2Graph drawing above.
clinical_text = llm_output[2]

# We need to use the tokenizer manually since we need special tokens:
# the pipeline is asked for raw token ids (return_tensors=True,
# return_text=False) and the ids of the first result are decoded with
# batch_decode, which yields a one-element list of strings.
extracted_text = triplet_extractor.tokenizer.batch_decode(
    [
        triplet_extractor(
            clinical_text, return_tensors=True, return_text=False
        )[0]["generated_token_ids"]
    ]
)


# Function to parse the generated text and extract the triplets
def extract_triplets(text):
    """Parse REBEL's linearised generation into triplet dictionaries.

    The text is stripped of ``<s>``, ``<pad>`` and ``</s>`` and scanned word by
    word. The markers ``<triplet>``, ``<subj>`` and ``<obj>`` select which
    field the following words are appended to: the head, the tail and the
    relation type respectively. A triplet is emitted whenever a new
    ``<triplet>`` or ``<subj>`` marker arrives while a relation is pending, and
    once more at the end if all three fields are non-empty.

    Args:
        text: Decoded model output.

    Returns:
        List of ``{"head": ..., "type": ..., "tail": ...}`` dictionaries in
        the order they were completed.
    """

    def as_triplet(head_text, type_text, tail_text):
        # Fields are accumulated with a leading space; trim before emitting.
        return {
            "head": head_text.strip(),
            "type": type_text.strip(),
            "tail": tail_text.strip(),
        }

    results = []
    head = rel = tail = ""
    mode = "x"  # no field selected until the first marker is seen

    cleaned = text.strip()
    for filler in ("<s>", "<pad>", "</s>"):
        cleaned = cleaned.replace(filler, "")

    for piece in cleaned.split():
        if piece == "<triplet>":
            # New head: flush the pending triplet (if any) and start over.
            mode = "t"
            if rel:
                results.append(as_triplet(head, rel, tail))
                rel = ""
            head = ""
            continue
        if piece == "<subj>":
            # New tail for the current head: flush, but keep the relation
            # until the following <obj> marker clears it.
            mode = "s"
            if rel:
                results.append(as_triplet(head, rel, tail))
            tail = ""
            continue
        if piece == "<obj>":
            mode = "o"
            rel = ""
            continue

        # Ordinary word: append to whichever field is currently selected.
        if mode == "t":
            head = f"{head} {piece}"
        elif mode == "s":
            tail = f"{tail} {piece}"
        elif mode == "o":
            rel = f"{rel} {piece}"

    if head and rel and tail:
        results.append(as_triplet(head, rel, tail))
    return results


# Parse the single decoded string (batch_decode returned a one-element list).
extracted_triplets = extract_triplets(extracted_text[0])

# Show the input next to the extracted triplets for a quick visual check.
print("Text", clinical_text)
print("Triplets", extracted_triplets)

Text  Clinical Note for Patient Huey Orn:

Patient presents with viral sinusitis, manifested by persistent nasal congestion, facial pain, and yellow discharge from the nose. History of fever and cough in the past 48 hours. No significant medical history or allergies. Current medications include acetaminophen for fever and ibuprofen for pain.
Triplets [{'head': 'acetaminophen', 'type': 'medical condition treated', 'tail': 'fever'}, {'head': 'fever', 'type': 'drug used for treatment', 'tail': 'acetaminophen'}, {'head': 'ibuprofen', 'type': 'medical condition treated', 'tail': 'pain'}, {'head': 'pain', 'type': 'drug used for treatment', 'tail': 'ibuprofen'}]


## Example of going from Extracted Entities to Graph

In [None]:
import networkx as nx
from pyvis.network import Network
import networkx as nx

In [None]:
from src.utils import load_json

# Load the example entity-extraction output. Below, the first record's
# "Entities" list is turned into a graph.
extraction_path = "../example_output/example_pipeline_14_05_24/extraction.json"
extraction_output = load_json(extraction_path)

In [None]:
# NOTE(review): presumably the entity labels the extraction step was run with;
# this list is not referenced by any other cell shown here — confirm it is
# still needed.
entity_list = ["person", "diagnosis", "nhs number", "date of birth"]

In [None]:
# Entities of the first extraction record. Each entry is read below for its
# "label", "text" and "score" keys.
# NOTE(review): this rebinds `data`, which earlier held a get_graph result.
data = extraction_output[0]["Entities"]
# Initialize a graph
G = nx.Graph()

# Add a starter node (document ID)
doc_id = "doc_1"
G.add_node(doc_id, label="document")

# Add nodes and edges based on the data: one node per entity text, connected
# to the document node by an edge labelled with the entity's label. The entity
# text is the node key, so entries with the same text share one node and the
# last entry's attributes win.
for entry in data:
    node_label = entry["label"]
    node_text = entry["text"]
    node_score = entry["score"]

    G.add_node(node_text, label=node_label, score=node_score)
    G.add_edge(doc_id, node_text, label=node_label)

# Static Matplotlib rendering of the graph.
nx.draw(G, with_labels=True)
# Interactive pyvis rendering in a 500px x 500px canvas, written to "nx.html".
nt = Network("500px", "500px")
nt.from_nx(G)
nt.show("nx.html", notebook=False)