In [None]:
!pip install langchain neo4j transformers langchain_community


Collecting neo4j
  Downloading neo4j-5.27.0-py3-none-any.whl.metadata (5.9 kB)
Collecting langchain_community
  Downloading langchain_community-0.3.16-py3-none-any.whl.metadata (2.9 kB)
Collecting dataclasses-json<0.7,>=0.5.7 (from langchain_community)
  Downloading dataclasses_json-0.6.7-py3-none-any.whl.metadata (25 kB)
Collecting httpx-sse<0.5.0,>=0.4.0 (from langchain_community)
  Downloading httpx_sse-0.4.0-py3-none-any.whl.metadata (9.0 kB)
Collecting langchain
  Downloading langchain-0.3.16-py3-none-any.whl.metadata (7.1 kB)
Collecting langchain-core<0.4.0,>=0.3.31 (from langchain)
  Downloading langchain_core-0.3.32-py3-none-any.whl.metadata (6.3 kB)
Collecting pydantic-settings<3.0.0,>=2.4.0 (from langchain_community)
  Downloading pydantic_settings-2.7.1-py3-none-any.whl.metadata (3.5 kB)
Collecting marshmallow<4.0.0,>=3.18.0 (from dataclasses-json<0.7,>=0.5.7->langchain_community)
  Downloading marshmallow-3.26.0-py3-none-any.whl.metadata (7.3 kB)
Collecting typing-inspect<1

In [None]:
import os
import re
from getpass import getpass

from langchain.chains import GraphCypherQAChain
from langchain.graphs import Neo4jGraph
from neo4j import GraphDatabase
from transformers import pipeline
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

In [None]:
# Neo4j connection settings.
# Credentials come from environment variables (or an interactive prompt) instead of
# being hard-coded: anything stored in a notebook is visible to everyone with access
# to the file or its revision history.
# NOTE(review): a password was previously committed in this cell -- rotate it in the
# Neo4j console.
URI = os.environ.get("NEO4J_URI", "neo4j+s://08398cf2.databases.neo4j.io")
USERNAME = os.environ.get("NEO4J_USERNAME", "neo4j")
PASSWORD = os.environ.get("NEO4J_PASSWORD") or getpass("Neo4j password: ")

In [None]:
################################################################################
# 1) DATABASE CONNECTION
################################################################################

# Single Neo4j driver for this kernel, shared by every run_cypher_query() call below.
# NOTE(review): nothing in this notebook calls driver.close(); close it when finished.
driver = GraphDatabase.driver(
    URI,
    auth=(USERNAME, PASSWORD),
)

def run_cypher_query(driver, query, parameters=None):
    """
    Execute a Cypher query with the official Neo4j Python driver.

    Args:
        driver: A Neo4j driver, e.g. the one returned by ``GraphDatabase.driver``.
        query: Cypher text to run.
        parameters: Optional mapping for ``$name`` placeholders in ``query``.
            Binding values here, instead of formatting them into the query text,
            avoids quoting bugs and Cypher injection. Defaults to no parameters.

    Returns:
        list[dict]: One dict per result record (``record.data()``).
    """
    with driver.session() as session:
        results = session.run(query, parameters)
        # Consume the records while the session is still open.
        return [record.data() for record in results]

################################################################################
# 2) MODEL LOADING (SEQUENCE-TO-SEQUENCE)
################################################################################

# Seq2seq model used to phrase the Neo4j results as prose.
# "google/flan-t5-large" is a drop-in alternative (larger download, slower).
model_name = "google/flan-t5-base"
tokenizer, model = (
    AutoTokenizer.from_pretrained(model_name),
    AutoModelForSeq2SeqLM.from_pretrained(model_name),
)

################################################################################
# 3) QUESTION HANDLING & QUERY GENERATION
################################################################################

def process_patient_profile(profile: str) -> dict:
    """
    Parse a line-oriented patient profile into a dict and flag trial eligibility.

    Each line is stripped and matched by prefix; unrecognised lines are ignored
    and a field repeated later overrides the earlier value:

        Age: <int>                    -> "age"  (None if not a plain integer)
        Gender: <text>                -> "gender"
        Diagnosis: <text>             -> "diagnosis"
        ECOGPerformanceStatus: <int>  -> "ecog" (None if not a plain integer)
        TreatmentHistory: <text>      -> "treatment_history": True only when the
                                         text contains "no prior" (case-insensitive)

    "eligible" is always set. It is True only when ECOG < 2, 18 <= age <= 70 and
    the diagnosis is exactly "Head and Neck Squamous Cell Carcinoma".
    Treatment history is parsed but NOT used for eligibility.

    Args:
        profile: Multi-line profile text.

    Returns:
        dict: The parsed fields present in ``profile`` plus "eligible".
    """

    def _parse_int(raw: str):
        # isdecimal() (unlike isdigit()) is True only for characters int() accepts,
        # so inputs such as "²" give None instead of raising ValueError.
        return int(raw) if raw.isdecimal() else None

    patient_data = {}

    for raw_line in profile.split("\n"):
        line = raw_line.strip()
        if line.startswith("Age:"):
            patient_data["age"] = _parse_int(line.split(":", 1)[1].strip())
        elif line.startswith("Gender:"):
            patient_data["gender"] = line.split(":", 1)[1].strip()
        elif line.startswith("Diagnosis:"):
            patient_data["diagnosis"] = line.split(":", 1)[1].strip()
        elif line.startswith("ECOGPerformanceStatus:"):
            patient_data["ecog"] = _parse_int(line.split(":", 1)[1].strip())
        elif line.startswith("TreatmentHistory:"):
            treatment_history = line.split(":", 1)[1].strip().lower()
            patient_data["treatment_history"] = "no prior" in treatment_history

    has_required_fields = all(
        patient_data.get(key) is not None for key in ("ecog", "age", "diagnosis")
    )

    if has_required_fields:
        patient_data["eligible"] = (
            patient_data["ecog"] < 2
            and 18 <= patient_data["age"] <= 70
            and patient_data["diagnosis"] == "Head and Neck Squamous Cell Carcinoma"
        )
    else:
        patient_data["eligible"] = False

    return patient_data

def generate_recommendation_query(profile_data: dict) -> str:
    """
    Build the Cypher query that fetches treatment regimens for a parsed patient profile.

    Args:
        profile_data: Dict produced by ``process_patient_profile``. Reads the
            "age", "ecog", "diagnosis" and "eligible" keys; a missing "eligible"
            is treated as not eligible.

    Returns:
        str: Either
            * a ``RETURN '<text>' AS message`` query when the patient is not
              eligible or the diagnosis is missing, or
            * a ``MATCH ... RETURN`` query over
              Population -> StudyGroup -> Regimen paths, filtered by diagnosis
              name and by the age / ECOG inclusion-criteria text the patient meets.
    """
    age = profile_data.get("age")
    ecog_status = profile_data.get("ecog", 2)  # Default to ineligible if not parsed
    diagnosis = profile_data.get("diagnosis")

    if not profile_data.get("eligible", False):
        return "RETURN 'Patient does not meet trial inclusion criteria or is not eligible for treatment recommendations' AS message"

    if not diagnosis:
        return "RETURN 'Diagnosis information is missing' AS message"

    # Only filter on inclusion criteria the patient actually meets. The numeric
    # check keeps a hand-built dict (eligible=True but no age) from raising
    # TypeError on the comparison.
    age_ok = isinstance(age, (int, float)) and 18 <= age <= 70
    age_condition = (
        "ANY(inclusion IN pop.inclusionCriteria WHERE inclusion CONTAINS 'Age 18-70 years') AND "
        if age_ok
        else ""
    )
    ecog_condition = (
        "ANY(inclusion IN pop.inclusionCriteria WHERE inclusion CONTAINS 'ECOG performance status 0 or 1') AND "
        if ecog_status in (0, 1)
        else ""
    )

    # Escape backslashes first, then single quotes, so the value cannot terminate
    # the single-quoted Cypher string literal early (e.g. an apostrophe).
    diagnosis_literal = str(diagnosis).replace("\\", "\\\\").replace("'", "\\'")

    query = (
        "MATCH (pop:Population)-[:HAS_GROUP]->(group:StudyGroup)-[:HAS_REGIMEN]->(regimen:Regimen), "
        "(pop)-[:HAS_DIAGNOSIS]->(diagnosis:Diagnosis) "
        "WHERE "
        f"{age_condition}"
        f"{ecog_condition}"
        f"diagnosis.diagnosisName = '{diagnosis_literal}' "
        "RETURN "
        "regimen.regimenId AS regimenId, "
        "regimen.includesAgents AS recommendedAgents, "
        "regimen.cycleLength AS cycleLength, "
        "regimen.numberOfCycles AS numberOfCycles, "
        "regimen.requiresGCSF AS requiresGCSF, "
        "regimen.maintenanceTherapy AS maintenanceTherapy"
    )
    return query

def generate_llm_response(query_results: list) -> str:
    """
    Summarise Neo4j regimen rows in natural language using the seq2seq model.

    Args:
        query_results: Rows from ``run_cypher_query``. Regimen rows must carry
            regimenId, recommendedAgents, cycleLength, numberOfCycles,
            requiresGCSF and maintenanceTherapy. Message-only rows (produced by
            the ``RETURN '...' AS message`` queries that
            ``generate_recommendation_query`` emits for ineligible or incomplete
            profiles) are returned verbatim instead of being sent to the model.

    Returns:
        str: A fixed fallback string for no rows, the joined message(s) for
        message-only rows, otherwise the model's decoded output.

    Uses the module-level ``tokenizer`` and ``model``.
    """
    if not query_results:
        return "No results found or the patient is ineligible for treatment recommendations."

    # Message-only rows have no regimen fields: formatting them below would raise
    # KeyError, and there is nothing for the model to summarise.
    if all("message" in row and "regimenId" not in row for row in query_results):
        return "\n".join(str(row["message"]) for row in query_results)

    result_str = "\n".join(
        f"- Regimen ID: {result['regimenId']}, Agents: {result['recommendedAgents']}, "
        f"Cycle Length: {result['cycleLength']} days, Cycles: {result['numberOfCycles']}, "
        f"GCSF Required: {'Yes' if result['requiresGCSF'] else 'No'}, "
        f"Maintenance Therapy: {'Yes' if result['maintenanceTherapy'] else 'No'}"
        for result in query_results
    )

    prompt = f"""
    You are a clinical data expert. Based on the following structured query results, write a natural language response
    summarizing the recommended treatment regimens for a patient:

    Results:
    {result_str}

    Response:
    """
    # NOTE(review): truncation=True with max_length=512 drops prompt tokens past the
    # limit (from the end by default), so with many regimens the later rows and the
    # "Response:" cue can be silently lost.
    inputs = tokenizer.encode(prompt, return_tensors="pt", max_length=512, truncation=True)
    outputs = model.generate(inputs, max_length=200, num_beams=5, early_stopping=True)
    response = tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
    return response

################################################################################
# 4) QUERY EXECUTION
################################################################################

def fix_cypher_query_syntax(query: str) -> str:
    """
    Heuristically repair two common Cypher mistakes.

    1. A bare ``key: "value"`` property that is not inside a map literal is
       wrapped in braces, e.g. ``(n:Person name: "Alice")`` becomes
       ``(n:Person {name: "Alice"})``. Properties already inside ``{...}`` are
       left alone, so valid queries are not corrupted into ``{{...}}``.
    2. If the stripped query does not start with MATCH, CREATE, MERGE, RETURN
       or WITH (case-insensitive prefix check), ``"MATCH "`` is prepended.

    Only double-quoted values are rewritten. Brace depth is counted naively,
    so braces inside string literals are counted too.

    Args:
        query: Cypher text; surrounding whitespace is removed.

    Returns:
        str: The repaired query.
    """
    stripped = query.strip()

    def _wrap_bare_property(match):
        # More "{" than "}" before the match means we are already inside a map.
        preceding = stripped[: match.start()]
        if preceding.count("{") > preceding.count("}"):
            return match.group(0)
        return f'{{{match.group(1)}: "{match.group(2)}"}}'

    query_fixed = re.sub(r'(\w+):\s*"([^"]+)"', _wrap_bare_property, stripped)
    if not query_fixed.upper().startswith(("MATCH", "CREATE", "MERGE", "RETURN", "WITH")):
        query_fixed = "MATCH " + query_fixed
    return query_fixed

def validate_query(query: str) -> str:
    """
    Check that a query begins with an accepted Cypher clause keyword.

    Accepted first keywords are MATCH, CREATE, MERGE, RETURN and WITH. The check
    is case-insensitive and tolerates leading whitespace. The keyword must be a
    whole word, so inputs such as "MATCHING ..." or "WITHOUT ..." are rejected.

    Args:
        query: Cypher text to check.

    Returns:
        str: ``query`` unchanged.

    Raises:
        ValueError: If ``query`` is not a string or does not start with an
            accepted keyword.
    """
    valid_start_keywords = ("MATCH", "CREATE", "MERGE", "RETURN", "WITH")
    # \b requires the keyword to end at a word boundary, not merely be a prefix.
    pattern = r"\s*(?:%s)\b" % "|".join(valid_start_keywords)
    if not isinstance(query, str) or not re.match(pattern, query, flags=re.IGNORECASE):
        raise ValueError(f"Invalid Cypher query: {query}")
    return query

def main():
    """
    Run the end-to-end recommendation demo for one hard-coded sample patient.

    Pipeline: parse the profile -> build a Cypher query -> repair and validate it
    -> run it on Neo4j through the module-level ``driver`` -> summarise the rows
    with the seq2seq model. A validation failure is printed and ends the run.
    """
    sample_profile = """
    Age: 58
    Gender: Female
    Diagnosis: Head and Neck Squamous Cell Carcinoma
    ECOGPerformanceStatus: 1
    TreatmentHistory: Oral Tongue Tumor (Stage II): Partial glossectomy; Adjuvant radiation (54 Gy) ended 14 months ago; No chemotherapy; Never received Cetuximab or IO therapy
    """

    parsed = process_patient_profile(sample_profile)
    print("Processed Patient Profile:", parsed)

    cypher = generate_recommendation_query(parsed)
    print("Generated Query:", cypher)

    try:
        checked_cypher = validate_query(fix_cypher_query_syntax(cypher))
    except ValueError as err:
        print("Query validation error:", err)
        return

    rows = run_cypher_query(driver, checked_cypher)
    summary = generate_llm_response(rows)
    print("Here is the treatment recommendation:", summary)


if __name__ == "__main__":
    main()


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/2.54k [00:00<?, ?B/s]

spiece.model:   0%|          | 0.00/792k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/2.42M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/2.20k [00:00<?, ?B/s]

config.json:   0%|          | 0.00/1.40k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/990M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/147 [00:00<?, ?B/s]

Processed Patient Profile: {'age': 58, 'gender': 'Female', 'diagnosis': 'Head and Neck Squamous Cell Carcinoma', 'ecog': 1, 'treatment_history': False, 'eligible': True}
Generated Query: MATCH (pop:Population)-[:HAS_GROUP]->(group:StudyGroup)-[:HAS_REGIMEN]->(regimen:Regimen), (pop)-[:HAS_DIAGNOSIS]->(diagnosis:Diagnosis) WHERE ANY(inclusion IN pop.inclusionCriteria WHERE inclusion CONTAINS 'Age 18-70 years') AND ANY(inclusion IN pop.inclusionCriteria WHERE inclusion CONTAINS 'ECOG performance status 0 or 1') AND diagnosis.diagnosisName = 'Head and Neck Squamous Cell Carcinoma' RETURN regimen.regimenId AS regimenId, regimen.includesAgents AS recommendedAgents, regimen.cycleLength AS cycleLength, regimen.numberOfCycles AS numberOfCycles, regimen.requiresGCSF AS requiresGCSF, regimen.maintenanceTherapy AS maintenanceTherapy
Here is the treatment recommendation: Regimen ID: TPEx, Agents: ['Docetaxel 75 mg/m2 (day 1)', 'Cisplatin 75 mg/m2 (day 1)', 'Cetuximab 400 mg/m2 (day 1, day 1, then 