In [56]:
import os, json, re
import requests
from bs4 import BeautifulSoup
from dotenv import load_dotenv
from neo4j import GraphDatabase
from pydantic import BaseModel
from typing import List

In [90]:
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_neo4j import Neo4jGraph
from langchain.chains.graph_qa.cypher import GraphCypherQAChain

In [68]:
load_dotenv()
GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")
NEO4J_URI = os.getenv("NEO4J_URI")
NEO4J_USERNAME = os.getenv("NEO4J_USER")
NEO4J_PASSWORD = os.getenv("NEO4J_PW")

In [72]:
assert GOOGLE_API_KEY, "Set GOOGLE_API_KEY"
assert NEO4J_URI and NEO4J_USERNAME and NEO4J_PASSWORD, "Set Neo4j credentials"


In [73]:
class Triple(BaseModel):
    subject: str
    relation: str
    object: str
    subject_label: str = "Entity"
    object_label: str = "Entity"

In [74]:
llm = ChatGoogleGenerativeAI(
    model="gemini-1.5-flash",
    google_api_key=GOOGLE_API_KEY,
    temperature=0.0,
)

In [75]:
extraction_prompt = ChatPromptTemplate.from_messages([
    ("system",
     "Extract biomedical knowledge triples from the given text. "
     "Return ONLY a JSON array of triples with keys: subject, relation, object, subject_label, object_label. "
     "Relations must be UPPER_SNAKE_CASE."),
    ("user", "Text:\n{input}\n\nReturn JSON array of triples.")
])

In [76]:
parser = StrOutputParser()

In [None]:
def clean_json_str(s: str) -> str:
    s = s.strip()
    if s.startswith("```"):
        s = re.sub(r"^```[a-zA-Z0-9]*\n?", "", s)
        s = re.sub(r"```$", "", s)
    return s.strip()

def extract_triples(text: str) -> List[Triple]:
    if len(text) > 3000:
        text = text[:3000]
    msgs = extraction_prompt.format_messages(input=text)
    raw = llm.invoke(msgs)
    content = getattr(raw, "content", raw)
    content = clean_json_str(content)
    try:
        data = json.loads(content)
        return [Triple(**t) for t in data]
    except Exception as e:
        print("Parse error:", e)
        print("LLM raw output:", content)
        return []

In [86]:
driver = GraphDatabase.driver(NEO4J_URI, auth=(NEO4J_USERNAME, NEO4J_PASSWORD))

In [87]:
def init_constraints():
    queries = [
        "CREATE CONSTRAINT IF NOT EXISTS FOR (n:Entity) REQUIRE n.name IS UNIQUE",
        "CREATE CONSTRAINT IF NOT EXISTS FOR (n:Drug) REQUIRE n.name IS UNIQUE",
        "CREATE CONSTRAINT IF NOT EXISTS FOR (n:Disease) REQUIRE n.name IS UNIQUE",
    ]
    with driver.session() as s:
        for q in queries:
            s.run(q)

def upsert_triples(triples: List[Triple]):
    with driver.session() as s:
        for t in triples:
            s_label = re.sub(r"[^A-Za-z0-9_]", "", t.subject_label) or "Entity"
            o_label = re.sub(r"[^A-Za-z0-9_]", "", t.object_label) or "Entity"
            relation = re.sub(r"[^A-Za-z0-9_]", "", t.relation) or "RELATED_TO"
            cypher = (
                f"MERGE (a:{s_label} {{name: $sname}}) "
                f"MERGE (b:{o_label} {{name: $oname}}) "
                f"MERGE (a)-[r:{relation}]->(b)"
            )
            s.run(cypher, sname=t.subject, oname=t.object)

In [98]:
def build_graph_qa():
    graph = Neo4jGraph(
        url=NEO4J_URI,
        username=NEO4J_USERNAME,
        password=NEO4J_PASSWORD,
        enhanced_schema=False,
        refresh_schema=False
    )

    schema = """
    (:Drug {name})-[:IS_A]->(:Category {name})
    (:Drug {name})-[:ENCODED_BY]->(:Protein {name})
    (:Drug {name})-[:ELICITS]->(:Response {name})
    (:Response {name})-[:PROTECTS_AGAINST]->(:Disease {name})
    (:Drug {name})-[:CAUSES_ADVERSE_REACTION]->(:SideEffect {name})
    """

    return GraphCypherQAChain.from_llm(
        llm=llm,
        graph=graph,
        validate_cypher=True,
        return_intermediate_steps=True,
        cypher_schema=schema  
    )


In [92]:
#url = "https://go.drugbank.com/drugs/DB15656"
text = """
AZD1222 (AstraZeneca COVID-19 vaccine) is a replication-deficient chimpanzee adenovirus vector
encoding the SARS-CoV-2 spike glycoprotein. It elicits an immune response to protect against COVID-19.
Adverse reactions include injection site tenderness, fatigue, headache, myalgia, malaise, etc.
"""

triples = extract_triples(text)
print("Extracted triples:", triples[:5])

init_constraints()
upsert_triples(triples)

Extracted triples: [Triple(subject='AZD1222', relation='IS_A', object='COVID-19 vaccine', subject_label='AZD1222', object_label='COVID-19 vaccine'), Triple(subject='AZD1222', relation='IS_A', object='replication-deficient chimpanzee adenovirus vector', subject_label='AZD1222', object_label='replication-deficient chimpanzee adenovirus vector'), Triple(subject='AZD1222', relation='ENCODED_BY', object='SARS-CoV-2 spike glycoprotein', subject_label='AZD1222', object_label='SARS-CoV-2 spike glycoprotein'), Triple(subject='AZD1222', relation='ELICITS', object='immune response', subject_label='AZD1222', object_label='immune response'), Triple(subject='immune response', relation='PROTECTS_AGAINST', object='COVID-19', subject_label='immune response', object_label='COVID-19')]
