In [158]:
import os
import json
import ast
import re
from langchain_neo4j import Neo4jGraph, GraphCypherQAChain
from langchain_ollama import ChatOllama
from langchain_experimental.graph_transformers import LLMGraphTransformer
from langchain_community.graphs.graph_document import GraphDocument, Node, Relationship
from langchain_core.documents import Document
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate

In [15]:
uri = "bolt://localhost:7687"
user = "neo4j"
password = "neo4j_password" 
graph = Neo4jGraph(url=uri, username=user, password=password)

In [135]:
# Set LLM
llm = ChatOllama(model="llama3.2", temperature=0)
relation_llm = llm.with_structured_output(schema=GraphDocument)

In [50]:
# Load data
with open("data/diabetes_subset.json", "r") as f:
    json_data = json.load(f)

questions = [entry.get("QUESTION", "N/A") for entry in json_data.values()]
contexts = [entry.get("CONTEXTS", "N/A") for entry in json_data.values()]
conclusions = [entry.get("LONG_ANSWER", "N/A") for entry in json_data.values()]
ground_truth = [entry.get("final_decision", "N/A") for entry in json_data.values()]


In [51]:
# Convert abstract to a list of Documents
documents = []
for i in range(len(questions)):
    # Builds the abstract text from context and conclusion
    abstract = " ".join(contexts[i]) + conclusions[i]
    
    # Create a LangChain Document
    doc = Document(page_content=abstract)
    # Append to documents list
    documents.append(doc)

### Use LLMGraphTransformer to create a preliminary KG ###

In [None]:
# Use LLMGraphTransformer
llm_transformer = LLMGraphTransformer(llm=llm, additional_instructions="Remove all special characters such as parentheses or apostrophes from entities.")

graph_documents = []
for doc in documents:
    graph_doc = llm_transformer.process_response(doc)
    graph_documents.append(graph_doc)

print(f"Nodes:{graph_documents[0].nodes}")
print(f"Relationships:{graph_documents[0].relationships}")

[Relationship(source=Node(id='Positron Emission Tomography', type='Concept', properties={}), target=Node(id='Cervical Cancer', type='Concept', properties={}), type='ASSESSES', properties={}), Relationship(source=Node(id='Computed Tomography', type='Concept', properties={}), target=Node(id='Cervical Cancer', type='Concept', properties={}), type='COMPARES_TO', properties={}), Relationship(source=Node(id='Magnetic Resonance Imaging', type='Concept', properties={}), target=Node(id='Cervical Cancer', type='Concept', properties={}), type='COMPARES_TO', properties={}), Relationship(source=Node(id='Diabetes Mellitus', type='Concept', properties={}), target=Node(id='Fdg-Pet', type='Concept', properties={}), type='INFLUENCES', properties={}), Relationship(source=Node(id='Fdg-Pet', type='Concept', properties={}), target=Node(id='Hyperglycaemic Dm', type='Concept', properties={}), type='USED_FOR', properties={}), Relationship(source=Node(id='Mri/Ct Scans', type='Concept', properties={}), target=No

In [104]:
# Print all the unique relations
unique_relations = set()

# Iterate over all GraphDocuments and collect unique relationships
for graph_doc in graph_documents:
    for relation in graph_doc.relationships:
        relation.type = relation.type.replace("-", "_")
        # Add the relation (the relation is typically a tuple like (subject, relation, object))
        unique_relations.add(relation.type)  # We are interested in the relationship (second element)

# Print unique relations
print("Unique Relations:")
for relation in unique_relations:
    print(relation)
print("Total: ", len(unique_relations))

Unique Relations:
AFFECTED_BY
POSITIVE
ASSOCIATION
CAUSE
CONFER_EFFECTS
INCREASED_BY
EXPRESSED_IN
COMPARES_TO
IS_A_TYPE_OF
EQUIVALENT_TO
DECREASED_BY
USED_FOR
REDUCE
PRESENT
MEASURED
STIMULATE_SECRETION
HAS_CONDITION
STRONGLY_RELATED_TO
RELATION
WITH
TREATED_WITH
USED_TOGETHER
POTENTIAL_TARGET
IMPACT_ON
INCREASES_RISK_OF
IS_SIMILAR_TO
FOCUS
TREATMENT_OPTION
ASSESSES
WITHOUT
RISK_FACTOR_FOR
INFLUENCES
NOT_EXPRESSED_IN
GOAL_OF
HAS
ASSOCIATED_WITH
BLOCKING
PREVENTS
COMPARED_TO
IS_OUTCOME_OF
RISK_FACTOR
INCREASE_SECRETION
REGULATION
CO_EXPOSURE
PREVENTION
EXPOSURE
IMPROVED_BY
COMPLICATION
INDUCTION
DEVELOPED
NON_ASSOCIATION
PART_OF
HAS_BEHAVIOR
AFFECTS
GUIDELINE
TREATMENT
RECOMMENDATION
PERFORMED
AFFILIATION
MEASURED_FOR
USED_IN
REDUCE_SECRETION
MEASURED_BY
CHARACTERISTIC_OF
RESULT
CAUSES
CONSUMPTION
USED
DIAGNOSES
ASSOCIATED
RELATED_TO
BASED_ON
RISK
Total:  73


In [146]:
# Simplify relations
simplify_prompt = PromptTemplate(
    input_variables=["relations"],
    template="""
You are cleaning up a knowledge graph. Given this list of relationship types:

{relations}

Group similar or redundant ones and propose standardized names for them. 
Output only a mapping in JSON format where each original relation is mapped to a simplified version of that relation.
IMPORTANT NOTES:\n
        - Don't add any explanation and text.
""",
)

simplify_chain = simplify_prompt | llm

In [160]:
simplified_relations = simplify_chain.invoke({"relations": unique_relations})

In [161]:
# Update the relations in the graph_documents
relation_mapping = ast.literal_eval(simplified_relations.content)
updated_graph_documents = []

for doc in graph_documents:
    updated_relationships = []
    for rel in doc.relationships:
        original_type = rel.type
        new_type = relation_mapping.get(original_type, original_type)  # fallback to original
        if new_type != original_type:
            # Create a new relationship object with updated type
            rel.type = new_type
        updated_relationships.append(rel)

    doc.relationships = updated_relationships
    updated_graph_documents.append(doc)

In [162]:
# Add to graph
graph.add_graph_documents(updated_graph_documents, baseEntityLabel=True)

### Extract as many relations as possible ###
Extract more relations based on the existing nodes in the KG to improve density.

In [None]:
# Prompts model to extract as many entities as possible
entity_prompt = PromptTemplate(
    input_variables=["question"],
    template="""
        Attempt to extract as many entities as you can. Maintain
        Entity Consistency: When extracting entities, it's vital to ensure
        consistency. If an entity, such as "John Doe", is mentioned multiple
        times in the text but is referred to by different names or pronouns
        (e.g., "Joe", "he"), always use the most complete identifier for
        that entity. The knowledge graph should be coherent and easily
        understandable, so maintaining consistency in entity references is
        crucial.
        IMPORTANT NOTES:\n
        - Don't add any explanation and text.
        - Exclude special characters like parentheses or apostrophes.
Question: "{question}"
Entities (comma-separated):"""
)

entity_chain = entity_prompt | llm

In [None]:
# Prompts model to extract as many relations as possible between given entities
relation_prompt = PromptTemplate(
    input_variables = ["abstract", "entities"],
    template="""
        You are a top-tier algorithm designed for extracting information in 
        structured formats to build a knowledge graph. You must generate the 
        output as GraphDocument objects. Each object contains Relationships, which
        should have a source, target, and type key. The source and target keys are nodes from the following list: {entities}.
        The type key must contain the directional relationship between the source
        and target nodes. The type key is a relationship that can be found in the 
        the following text: {abstract}. Keep relationship types as succinct and simple as possible.
        IMPORTANT NOTES:\n- Don't add any explanation and text."""
)

relation_chain = relation_prompt | relation_llm

In [None]:
# Extract additional relations for the nodes already in the graph by prompting with those nodes and the abstracts again
# The graph needs to be more dense
abstract = documents[0]
#entity_response = entity_chain.invoke(abstract.page_content)
#entities = [e.strip() for e in entity_response.content.split(",") if e.strip()]
entities = [node.id for node in graph_documents[0].nodes]
print(entities)
response = relation_chain.invoke({"abstract": abstract.page_content, "entities": entities}) # BUG HERE: extracts the same relation multiple times?

['Positron Emission Tomography', 'Computed Tomography', 'Magnetic Resonance Imaging', 'Cervical Cancer', 'Diabetes Mellitus', 'Fdg-Pet', 'Mri/Ct Scans', 'Hyperglycaemic Dm', 'Euglycaemic Dm', 'Non-Dm']


In [175]:
for relation in response.relationships:
    print(relation.type)

may have additional value in the assessment of primary and recurrent cervical cancer
degree of tumour uptake is sometimes influenced by
in the assessment of primary and recurrent cervical cancer
were performed within 2 weeks
diagnostic ability in patients with cervical cancer complicated by DM
influence on the degree of tumour uptake
were performed within 2 weeks
diagnostic ability in patients with cervical cancer complicated by DM
influence on the degree of tumour uptake
were performed within 2 weeks
diagnostic ability in patients with cervical cancer complicated by DM
influence on the degree of tumour uptake
were performed within 2 weeks


### Querying ###
Now that we have a dense graph, we can query it by passing the possible relations to our querying model, llama-3.3.