In [13]:
import os
import logging
import sys
import json
from dotenv import load_dotenv

from llama_index.core import (
    SimpleDirectoryReader,
    KnowledgeGraphIndex,
    StorageContext,
    load_index_from_storage,
    Settings,
    Document,
    QueryBundle,
    VectorStoreIndex, # We might need this temporarily if we want node parsing
    PromptTemplate,
)
from llama_index.core.graph_stores import SimpleGraphStore
from llama_index.core.node_parser import SentenceSplitter
from llama_index.core.schema import TextNode
from llama_index.core.program import LLMTextCompletionProgram
from llama_index.core.output_parsers import PydanticOutputParser
from llama_index.llms.openai import OpenAI
# from llama_index.llms.ollama import Ollama

# --- Property Graph Imports ---
from llama_index.core.graph_stores.simple_labelled import SimplePropertyGraphStore
from llama_index.core.indices.property_graph import PropertyGraphIndex

# --- Query Engine and Retriever ---
from llama_index.core.query_engine import RetrieverQueryEngine
from llama_index.core.retrievers import BaseRetriever # For type hinting
from llama_index.core import get_response_synthesizer # Factory for response synthesizer

from pyvis.network import Network
from tqdm.notebook import tqdm # Or standard tqdm if not in notebook

# --- Pydantic Models (from step 2) ---
from pydantic import BaseModel, Field
from typing import List, Tuple, Any, Dict

In [3]:
import nest_asyncio
nest_asyncio.apply()


In [4]:
from dotenv import load_dotenv
from llama_index.llms.azure_openai import AzureOpenAI
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
load_dotenv()


endpoint = "https://d-ais-eus-ais-chatbots.openai.azure.com/"
model_name = "o1-mini"
deployment = "o1-mini"
subscription_key = os.getenv("AZURE_OPENAI_API_KEY")
api_version = "2024-12-01-preview" # Use a valid API version

llm = AzureOpenAI(
    azure_endpoint=endpoint,
    api_key=subscription_key,
    api_version=api_version,
    deployment_name=deployment,
    model_name=model_name,
    temperature=1.0
)

embed_model = HuggingFaceEmbedding(
    model_name="BAAI/bge-m3"
    )

Settings.embed_model = embed_model

llm2 = AzureOpenAI(
    azure_endpoint=endpoint,
    api_key=subscription_key,
    api_version="2024-05-01-preview",
    deployment_name="gpt-4o-mini-test",
    model_name="gpt-4o-mini-test",
    temperature=1.0
)

In [None]:
INPUT_DIR = "./kgdata" # Directory containing the text files to be indexed
# OUTPUT_DIR = "./storage_kg_custom_prompt" # Directory to store the index
PERSIST_DIR = "./storage_kg_custom_prompt" # Use a different dir for this version
GRAPH_OUTPUT_HTML = "ems_knowledge_graph_custom_prompt.html"
TRIPLETS_CACHE_FILE = "./extracted_triplets_cache.json"

logging.basicConfig(stream=sys.stdout, level=logging.INFO, force=True)
logging.getLogger().addHandler(logging.StreamHandler(stream=sys.stdout))

In [4]:
from pydantic import BaseModel, Field
from typing import List, Tuple

# Define the structure for a single triplet
class Triplet(BaseModel):
    """Represents a single knowledge graph triplet."""
    subject: str = Field(..., description="Subject entity of the relationship")
    predicate: str = Field(..., description="Predicate or relationship connecting the subject and object")
    object: str = Field(..., description="Object entity of the relationship")

# Define the overall output structure expected from the LLM for a given chunk
class KnowledgeGraphTriplets(BaseModel):
    """Container for multiple extracted triplets."""
    triplets: List[Triplet] = Field(..., description="List of extracted knowledge graph triplets")

In [69]:
from llama_index.core.prompts import PromptTemplate

# --- Custom Chain-of-Thought Prompt Template ---
# Adjust this prompt extensively based on desired output and LLM used.
# Provide clear instructions and examples relevant to EMS protocols.

# --- Modified Custom Chain-of-Thought Prompt Template ---

# KG_TRIPLET_EXTRACT_PROMPT_TMPL = """
# Given the following text from an Emergency Medical Services (EMS) protocol document, extract relevant knowledge graph triplets in the (Subject, Predicate, Object) format.

# **My Goal:** To build a knowledge graph representing EMS procedures, criteria, advice, symptoms, conditions, and responses.

# **Follow these Chain-of-Thought Instructions:**
# 1.  **Read the Text Carefully:** Understand the context ({text}).
# 2.  **Identify Key Entities:** Look for medical conditions, symptoms, criteria levels, actions/advice, patient types, equipment, codes/identifiers.
# 3.  **Determine Relationships:** Find connecting relationships (predicates) like 'has criteria', 'has code', 'requires advice', 'associated with', 'applies to', 'treated with'.
# 4.  **Filter for Relevance:** Only include clear, factual triplets implied in the text. Focus on actionable info, criteria, symptoms, procedures.
# 5.  **Format Output:** Structure the output strictly according to the format instructions below.

# **Format Instructions:**
# {format_instructions}

# **Now, analyze the following text chunk and provide the formatted output:**
# Text: ```{text}```
# **No Additional Text:** Do not include any additional text or explanations outside the specified format.
# Output:
# """ # Note: Removed the explicit JSON example as format_instructions will provide it

KG_TRIPLET_EXTRACT_PROMPT_TMPL = """
Given the following text from an Emergency Medical Services (EMS) protocol document, extract relevant knowledge graph triplets in the (Subject, Predicate, Object) format.
Focus on identifying distinct entities for subjects and objects. Normalize entity names where possible (e.g., 'RTA' to 'Road Traffic Accident').
Predicates should describe the relationship clearly (e.g., 'has criteria', 'requires advice for symptom', 'is symptom of').

**My Goal:** To build a knowledge graph representing EMS procedures, criteria, advice, symptoms, conditions, and responses.

**Follow these Chain-of-Thought Instructions:**
1.  **Read the Text Carefully:** Understand the context ({text}). Translate every language to English.
2.  **Identify Key Entities:** Look for medical conditions (e.g., 'Road traffic accident', 'Hypothermia', 'Anaphylactic shock'), symptoms (e.g., 'Unconscious', 'Shortness of breath', 'Pale, clammy skin'), criteria levels (e.g., 'Critical', 'Urgent'), actions/advice (e.g., 'Stop the bleeding', 'Help person sit up', 'Keep warm'), patient types (e.g., 'Pregnant', 'Diabetic', 'Infant'), equipment (e.g., 'EpiPen', 'AED'), and related codes/identifiers (e.g., '1.2', 'LVI'). Normalize these terms.
3.  **Determine Relationships:** For pairs of identified entities, find the connecting relationship (predicate). Examples:
    *   Condition 'has_criteria' Symptom/Level (e.g., 'Road Traffic Accident' 'has_criteria' 'Unconscious')
    *   Criteria/Symptom 'has_code' Identifier (e.g., 'Unconscious' 'has_code' '1.2')
    *   Symptom 'requires_advice' Action (e.g., 'Difficulty Breathing' 'requires_advice' 'Help person sit up')
    *   Condition 'associated_with_symptom' Symptom (e.g., 'Pre-eclampsia' 'associated_with_symptom' 'Headache')
    *   Action 'applies_to_patient_type' Patient Type (e.g., 'Keep head neutral' 'applies_to_patient_type' 'Infant')
    *   Condition 'treated_with_device' Device (e.g., 'Anaphylactic Shock' 'treated_with_device' 'EpiPen')
4.  **Filter for Relevance:** Only include clear, factual triplets. Avoid overly generic triplets.
5.  **Format Output:** Structure the output strictly according to the format instructions below.

**Format Instructions:**
{format_instructions}

**Now, analyze the following text chunk and provide the formatted output:**
Text: ```{text}```

**Translate every language to English.**
**No Additional Text:** Do not include any additional text or explanations outside the specified format.

Output:
"""

KG_TRIPLET_EXTRACT_PROMPT = PromptTemplate(KG_TRIPLET_EXTRACT_PROMPT_TMPL)

In [6]:
from llama_index.embeddings.huggingface import HuggingFaceEmbedding


Settings.embed_model = HuggingFaceEmbedding(
    model_name="BAAI/bge-small-en-v1.5"
)
Settings.llm = llm
# Settings.embed_model = embed_model
Settings.chunk_size = 512
Settings.chunk_overlap = 100

KG_TRIPLET_EXTRACT_PROMPT = PromptTemplate(KG_TRIPLET_EXTRACT_PROMPT_TMPL)

INFO:sentence_transformers.SentenceTransformer:Load pretrained SentenceTransformer: BAAI/bge-small-en-v1.5
Load pretrained SentenceTransformer: BAAI/bge-small-en-v1.5
INFO:sentence_transformers.SentenceTransformer:2 prompts are loaded, with the keys: ['query', 'text']
2 prompts are loaded, with the keys: ['query', 'text']


In [7]:
from llama_index.core import PromptTemplate

# Example: Adjust this significantly based on TreeSummarize's goal
# The idea is to avoid a structure that LlamaIndex interprets as needing a leading system message.
# Put all instructions/context directly into what would become the 'user' message.
CUSTOM_SUMMARY_PROMPT_TMPL = """
Based *only* on the following context information, synthesize a concise answer to the query.
Do not use any prior knowledge. If the context does not provide enough information, state that clearly.

Context Information:
---------------------
{context_str}
---------------------
Query: {query_str}

Answer:
"""
CUSTOM_SUMMARY_PROMPT = PromptTemplate(CUSTOM_SUMMARY_PROMPT_TMPL)

In [70]:
def load_and_parse_documents(input_dir):
    """Loads documents and parses them into nodes (chunks)."""
    logging.info(f"Loading documents from: {input_dir}")
    if not os.path.exists(input_dir):
        raise FileNotFoundError(f"Input directory '{input_dir}' not found.")
        
    reader = SimpleDirectoryReader(input_dir, required_exts=[".md"])
    docs = reader.load_data()
    
    if not docs:
        raise ValueError(f"No markdown documents found in '{input_dir}'.")
        
    logging.info(f"Loaded {len(docs)} documents.")

    # Using SentenceSplitter for parsing into nodes
    # Adjust chunk_size and chunk_overlap in Settings if needed
    parser = SentenceSplitter()
    nodes = parser.get_nodes_from_documents(docs, show_progress=True)
    logging.info(f"Parsed into {len(nodes)} nodes.")
    return nodes

# # Modified 'load_and_parse_documents' to return actual LlamaIndex Document objects as well for PropertyGraph
# def load_and_parse_documents(input_dir: str) -> Tuple[List[Document], List[TextNode]]:
#     """Loads documents and parses them into Document objects and TextNode objects (chunks)."""
#     logging.info(f"Loading documents from: {input_dir}")
#     if not os.path.exists(input_dir):
#         raise FileNotFoundError(f"Input directory '{input_dir}' not found.")
        
#     reader = SimpleDirectoryReader(input_dir, required_exts=[".md"])
#     docs = reader.load_data() # These are LlamaIndex Document objects
    
#     if not docs:
#         raise ValueError(f"No markdown documents found in '{input_dir}'.")
        
#     logging.info(f"Loaded {len(docs)} documents.")

#     parser = SentenceSplitter(chunk_size=Settings.chunk_size, chunk_overlap=Settings.chunk_overlap)
#     nodes = parser.get_nodes_from_documents(docs, show_progress=True) # These are TextNode objects
#     logging.info(f"Parsed into {len(nodes)} TextNode objects.")
#     return docs, nodes

def extract_triplets_from_nodes(nodes):
    """Extracts triplets from each node using PydanticOutputParser."""

    # Create the PydanticOutputParser
    parser = PydanticOutputParser(output_cls=KnowledgeGraphTriplets)

    # Get format instructions once
    format_instructions = parser.get_format_string(escape_json=False)

    all_triplets = []
    logging.info(f"Extracting triplets from {len(nodes)} nodes using PydanticOutputParser...")

    for node in tqdm(nodes):
        node_text = node.get_content()
        if not node_text.strip():
            continue

        try:
            # Format the prompt
            formatted_prompt = KG_TRIPLET_EXTRACT_PROMPT.format(
                text=node_text,
                format_instructions=format_instructions
            )

            # Call the LLM directly
            response = Settings.llm.complete(formatted_prompt)
            raw_output = response.text

            # Parse the raw output string
            parsed_result = parser.parse(raw_output) # type: KnowledgeGraphTriplets

            logging.debug(f"Extracted {len(parsed_result.triplets)} triplets from node {node.node_id[:8]}...")
            for triplet in parsed_result.triplets:
                 if triplet.subject and triplet.predicate and triplet.object:
                    all_triplets.append((triplet.subject, triplet.predicate, triplet.object))
                 else:
                    logging.warning(f"Skipping incomplete triplet from node {node.node_id[:8]}: {triplet}")

        except Exception as e:
            # Handle potential parsing errors or LLM errors
            logging.error(f"Error processing node {node.node_id[:8]}: {e}")
            logging.debug(f"Failed prompt:\n{formatted_prompt}")
            logging.debug(f"Failed raw output:\n{raw_output if 'raw_output' in locals() else 'N/A'}")


    logging.info(f"Extracted a total of {len(all_triplets)} triplets using custom prompt and PydanticOutputParser.")
    return list(set(all_triplets)) # Return unique triplets

# def extract_triplets_from_text_nodes(text_nodes: List[TextNode]) -> List[Tuple[str, str, str, str]]: # Ensure return type hint is correct
#     """Extracts triplets from each TextNode using PydanticOutputParser, including the source TextNode ID."""
#     parser = PydanticOutputParser(output_cls=KnowledgeGraphTriplets)
#     format_instructions = parser.get_format_string(escape_json=False)
#     all_triplets_extracted_with_source = [] # Changed variable name for clarity
#     logging.info(f"Extracting triplets from {len(text_nodes)} TextNode objects...")

#     for node in tqdm(text_nodes, desc="Extracting Triplets"): # node here is a TextNode
#         node_text = node.get_content()
#         current_text_node_id = node.id_ # Get the ID of the current TextNode

#         if not node_text.strip():
#             continue
#         try:
#             formatted_prompt = KG_TRIPLET_EXTRACT_PROMPT.format(
#                 text=node_text, format_instructions=format_instructions
#             )
#             response = Settings.llm.complete(formatted_prompt)
#             raw_output = response.text
#             parsed_result: KnowledgeGraphTriplets = parser.parse(raw_output)
            
#             for triplet in parsed_result.triplets:
#                 if triplet.subject and triplet.predicate and triplet.object:
#                     s = str(triplet.subject).strip().lower()
#                     p = str(triplet.predicate).strip().lower().replace(" ", "_")
#                     o = str(triplet.object).strip().lower()
#                     if s and p and o:
#                          # --- CRITICAL CHANGE HERE ---
#                          all_triplets_extracted_with_source.append((s, p, o, current_text_node_id))
#                          # --- END CRITICAL CHANGE ---
#                 else:
#                     logging.warning(f"Skipping incomplete triplet from node {current_text_node_id[:8]}: {triplet}") # Use current_text_node_id for logging
#         except Exception as e:
#             logging.error(f"Error processing node {current_text_node_id[:8]}: {e}") # Use current_text_node_id for logging
#             logging.debug(f"Failed content for node {current_text_node_id[:8]}:\n{node_text[:500]}...")
#             if 'raw_output' in locals(): logging.debug(f"Failed raw LLM output:\n{raw_output}")

#     logging.info(f"Extracted a total of {len(all_triplets_extracted_with_source)} triplets (with source IDs) using custom prompt.")
#     # Return unique triplets if necessary, ensuring uniqueness considers all 4 elements
#     # For simplicity, let's assume duplicates are okay or handled later if truly identical.
#     # If you need unique (s,p,o,source_id) tuples:
#     # return list(set(all_triplets_extracted_with_source))
#     return all_triplets_extracted_with_source # Return the list of 4-element tuples

def get_all_triplets_from_simple_store(graph_store: SimpleGraphStore) -> List[Tuple[str, str, str]]:
    """
    Retrieves all triplets by iterating through the internal dictionary
    of a SimpleGraphStore instance.
    """
    all_triplets = []
    if hasattr(graph_store, '_data') and hasattr(graph_store._data, 'graph_dict'):
        for subj, rel_obj_list in graph_store._data.graph_dict.items():
            for rel, obj in rel_obj_list:
                all_triplets.append((subj, rel, obj))
        logging.info(f"Retrieved {len(all_triplets)} triplets via internal iteration.")
    else:
        logging.warning("Could not access internal graph_dict to retrieve all triplets.")
    return all_triplets

def populate_graph_store(triplets):
    """Populates a SimpleGraphStore with the extracted triplets."""
    logging.info("Populating graph store...")
    graph_store = SimpleGraphStore()
    count_added = 0
    for subj, pred, obj in triplets:
        # Ensure components are strings
        subj_str = str(subj).strip()
        pred_str = str(pred).strip()
        obj_str = str(obj).strip()
        # Basic check to avoid adding empty strings or overly short strings (adjust threshold if needed)
        if subj_str and pred_str and obj_str and len(subj_str) > 1 and len(obj_str) > 1:
            # Check if triplet already exists before incrementing count (upsert might not indicate new vs update)
            # This check is a bit inefficient but ensures we count unique triplets added *by this process*
            existing_rels = graph_store._data.graph_dict.get(subj_str, [])
            if [pred_str, obj_str] not in existing_rels:
                 count_added += 1 # Increment count only if it's a truly new addition to this subject
            graph_store.upsert_triplet(subj_str, pred_str, obj_str)

        else:
            logging.warning(f"Skipping triplet with empty or short component(s): ('{subj_str}', '{pred_str}', '{obj_str}')")

    # --- CORRECTED LOGGING LINE ---
    # Calculate final count by iterating the internal structure directly
    final_count = 0
    if hasattr(graph_store, '_data') and hasattr(graph_store._data, 'graph_dict'):
        for subj, rel_obj_list in graph_store._data.graph_dict.items():
            final_count += len(rel_obj_list)
    # Log the final count obtained from the store's internal state
    logging.info(f"Graph store populated. Final triplet count in store: {final_count}")
    # --- END CORRECTION ---

    return graph_store


def build_and_persist_index_from_store(graph_store, persist_dir):
    """Builds the KG index wrapper from the populated store and persists."""

    storage_context = StorageContext.from_defaults(graph_store=graph_store)

    # Create the index object which acts as a wrapper around the graph store
    # We don't need 'from_documents' here as we populated the store manually
    index = KnowledgeGraphIndex(
        nodes=[], # No new nodes to process here
        storage_context=storage_context,
        max_triplets_per_chunk=0, # Extraction already done
        include_embeddings=True, # Or True if you want embeddings later
        embed_model=embed_model, # Optional, but can be set if needed,
    )

    logging.info(f"Persisting index and graph store to {persist_dir}...")
    # Create directory if it doesn't exist
    os.makedirs(persist_dir, exist_ok=True)
    index.storage_context.persist(persist_dir=persist_dir)
    logging.info("Index persisted.")
    return index

def load_index_with_store(persist_dir):
     """Loads the index and its associated graph store."""
     logging.info(f"Loading Knowledge Graph Index from {persist_dir}...")
     graph_store = SimpleGraphStore() # Re-initialize, it will be loaded by StorageContext
     storage_context = StorageContext.from_defaults(persist_dir=persist_dir, graph_store=graph_store)
     # Check if graph store was actually loaded (SimpleGraphStore doesn't save separately easily,
     # persistence relies on the index saving its config which includes the store type)
     # A more robust way is to use a persistent graph store like Neo4j from the start.
     # For SimpleGraphStore, we might need to re-extract if just loading this way.
     # Let's try loading the index normally and see if it reconstructs the SimpleGraphStore implicitly
     try:
        index = load_index_from_storage(
             storage_context=storage_context,
             show_progress=True,
         )
        # Verify store population
        loaded_triplets_count = len(index.graph_store.get_all_triples())
        logging.info(f"Index loaded. Graph store contains {loaded_triplets_count} triplets.")
        if loaded_triplets_count == 0:
             logging.warning("Loaded graph store is empty. SimpleGraphStore might not persist standalone.")
             # In this case, you'd need to re-run the extraction (steps 1-4) if you restart the script
             # Or save/load the extracted triplets list separately (e.g., to a JSON file).
        return index
     except FileNotFoundError:
         logging.warning(f"Persistence directory {persist_dir} not found. Cannot load index.")
         return None
     except Exception as e:
        logging.error(f"Error loading index from {persist_dir}: {e}")
        return None


# # --- Visualization and Querying Functions (Keep from previous example) ---


def visualize_graph(index, output_html):
    """Generates an interactive HTML visualization of the knowledge graph."""
    logging.info("Generating graph visualization...")
    try:
        graph_store = index.graph_store
        triplets_to_visualize = get_all_triplets_from_simple_store(graph_store)

        if not triplets_to_visualize:
            logging.warning("No triplets found in the store for visualization.")
            return False

        net = Network(notebook=False, cdn_resources="in_line", directed=True, height="800px", width="100%") # Added size params
        added_nodes = set()

        for subj, pred, obj in triplets_to_visualize:
            subj_str = str(subj)
            obj_str = str(obj)
            pred_str = str(pred)

            if subj_str not in added_nodes:
                net.add_node(subj_str, label=subj_str)
                added_nodes.add(subj_str)
            if obj_str not in added_nodes:
                net.add_node(obj_str, label=obj_str)
                added_nodes.add(obj_str)
            net.add_edge(subj_str, obj_str, label=pred_str)

        logging.info(f"Saving graph visualization with {len(net.nodes)} nodes and {len(net.edges)} edges to: {output_html}")

        # --- REVISED FIX for UnicodeEncodeError ---
        # Try saving directly but ensure the underlying file operation uses UTF-8
        # While save_graph might not take encoding directly, we can ensure the environment helps.
        # Often, the issue is less about save_graph itself and more about how Python determines
        # the default encoding for file I/O. Let's force it within the 'with open' context.
        # However, pyvis might do its own internal writing.
        # A more direct approach: Write to a BytesIO buffer then write that buffer to file.

        # 1. Generate HTML content to a string (as before)
        html_content = net.generate_html()

        # 2. Write the string to the file using UTF-8 explicitly
        with open(output_html, "w", encoding="utf-8") as f:
             f.write(html_content)
        # --- END REVISED FIX ---


        logging.info("Visualization saved.")
        return True
    except Exception as e:
        logging.error(f"Failed to generate visualization: {e}")
        # Add more detailed error logging if needed
        import traceback
        logging.error(traceback.format_exc())
        return False

# --- Visualization Function (Keep and Adapt if Needed for PropertyGraph Structure) ---
def visualize_property_graph(graph_store: SimplePropertyGraphStore, output_html: str):
    """Generates an interactive HTML visualization of the Property Graph."""
    logging.info("Generating Property Graph visualization...")
    net = Network(notebook=False, cdn_resources="in_line", directed=True, height="800px", width="100%")
    
    # Get all nodes and relationships from the property graph store
    # SimplePropertyGraphStore might not have a direct "get_all_nodes/rels" easily
    # We'll build it from the triplets if that's how we primarily interact for visualization
    # This is a simplified visualization based on the triplets we know we inserted.
    # A full PropertyGraph might have nodes without explicit triplet relations.
    
    # For a simple visualization from triplets:
    added_nodes_viz = set()
    all_graph_triplets = []
    # This assumes the PropertyGraphStore was populated in a way we can retrieve triplets for viz
    # A more direct way would be to use the 'extracted_triplets' list if the store doesn't offer easy iteration

    # The SimplePropertyGraphStore uses `graph_store.get()` which needs a label.
    # This visualization part needs a rethink for SimplePropertyGraphStore if we want to show *all* content.
    # For now, let's visualize based on the `extracted_triplets` list that we have.
    # This requires passing `extracted_triplets` to this function.
    
    # Let's assume 'extracted_triplets' is passed to this function for now for simplicity.
    # In a real scenario, you'd query the graph store for its contents.
    # We'll adapt this if we can't get all triplets from the store easily.

    # We will use the 'extracted_triplets' list that we have globally for this visualization.
    # This is a simplification. Ideally, you'd iterate the store's contents.
    global extracted_triplets # Accessing global for simplicity here
    if not extracted_triplets:
        logging.warning("No extracted_triplets available for visualization.")
        return False

    for subj, pred, obj in extracted_triplets:
        subj_node_name = get_entity_node_name(subj)
        obj_node_name = get_entity_node_name(obj)

        if subj_node_name not in added_nodes_viz:
            net.add_node(subj_node_name, label=subj) # Use original case for label
            added_nodes_viz.add(subj_node_name)
        if obj_node_name not in added_nodes_viz:
            net.add_node(obj_node_name, label=obj)   # Use original case for label
            added_nodes_viz.add(obj_node_name)
        net.add_edge(subj_node_name, obj_node_name, label=pred)

    if not net.nodes:
        logging.warning("No nodes to visualize.")
        return False

    logging.info(f"Saving graph visualization with {len(net.nodes)} nodes and {len(net.edges)} edges to: {output_html}")
    try:
        html_content = net.generate_html()
        with open(output_html, "w", encoding="utf-8") as f:
            f.write(html_content)
        logging.info("Visualization saved.")
        return True
    except Exception as e:
        logging.error(f"Failed to generate visualization: {e}")
        import traceback
        logging.error(traceback.format_exc())
        return False

def get_all_triplets(index):
    """Retrieves all extracted triplets from the index's graph store."""
    try:
        graph_store = index.graph_store
        all_triples = graph_store.get_all_triples()
        logging.info(f"Retrieved {len(all_triples)} triplets from the index's graph store.")
        return all_triples
    except Exception as e:
        logging.error(f"Failed to retrieve triplets from graph store: {e}")
        return []


In [23]:
def save_triplets_to_cache(triplets, cache_file):
    """Saves the extracted triplets list to a JSON file."""
    logging.info(f"Saving {len(triplets)} triplets to cache file: {cache_file}")
    try:
        with open(cache_file, "w", encoding="utf-8") as f:
            # Convert tuples to lists for JSON compatibility
            json.dump([list(t) for t in triplets], f, indent=4)
        logging.info("Triplets saved successfully.")
    except Exception as e:
        logging.error(f"Failed to save triplets to cache: {e}")

def load_triplets_from_cache(cache_file):
    """Loads triplets from a JSON cache file."""
    if os.path.exists(cache_file):
        logging.info(f"Loading triplets from cache file: {cache_file}")
        try:
            with open(cache_file, "r", encoding="utf-8") as f:
                # Convert lists back to tuples
                triplets_list = json.load(f)
                triplets = [tuple(t) for t in triplets_list]
            logging.info(f"Loaded {len(triplets)} triplets from cache.")
            return triplets
        except Exception as e:
            logging.error(f"Failed to load triplets from cache: {e}")
            return None
    else:
        logging.info("Cache file not found.")
        return None
    
def get_entity_node_name(entity_name: str) -> str:
    """Helper to create a consistent node name for entities."""
    return entity_name.strip().lower() # Using lowercase for node names

In [10]:
from llama_index.core.query_engine import RetrieverQueryEngine
from llama_index.core.response_synthesizers import TreeSummarize
from llama_index.core import get_response_synthesizer
from llama_index.core.retrievers import BaseRetriever

In [39]:
import io # Import the io module
from llama_index.core.graph_stores.types import LabelledNode, Relation


In [None]:
# --- Main Execution (Modified for PropertyGraphIndex and Embeddings) ---

if __name__ == "__main__":

    Settings.embed_model = embed_model
    Settings.llm = llm
    # Settings.embed_model = embed_model
    Settings.chunk_size = 512
    Settings.chunk_overlap = 100
    Settings.transformations = []

    include_embeddings_flag = True # Set this based on your config or requirements
    # Settings.chunk_overlap = 64
    # --- End Global Configuration ---

    property_graph_index = None # Initialize
    extracted_triplets = None # Initialize

    # 1. Try loading triplets from cache
    extracted_triplets = load_triplets_from_cache(TRIPLETS_CACHE_FILE)

    # 2. If cache miss, extract triplets from documents
    # We also need the document nodes for PropertyGraphIndex
    original_documents, parsed_text_nodes = load_and_parse_documents(INPUT_DIR)

    if extracted_triplets is None:
        logging.info("Cache miss. Extracting triplets from parsed text nodes...")
        # Pass the parsed TextNode objects to the extraction function
        extracted_triplets = extract_triplets_from_text_nodes(parsed_text_nodes)

        if extracted_triplets:
            save_triplets_to_cache(extracted_triplets, TRIPLETS_CACHE_FILE)
        else:
            logging.error("No triplets extracted. Cannot build PropertyGraphIndex.")
            sys.exit(1) # Exit if extraction is critical and failed

    if extracted_triplets:
        # 3. Initialize PropertyGraphStore
        logging.info("Initializing SimplePropertyGraphStore...")
        property_graph_store = SimplePropertyGraphStore()

        # 4. Build PropertyGraphIndex
        # The PropertyGraphIndex constructor handles populating the store
        # from nodes and relationships (triplets).
        logging.info("Building PropertyGraphIndex with use_async=False and no internal kg_extractors...")
        # --- CHANGE THIS ---
        # property_graph_index = PropertyGraphIndex.from_existing(
        #     property_graph_store=property_graph_store,
        #     nodes=parsed_text_nodes, # This was causing the issue with from_existing
        # )
        # --- TO THIS (using the direct constructor) ---
        # --- TEMPORARY TEST ---
        original_llm = Settings.llm
        Settings.llm = None
        # --- END TEMPORARY TEST ---

        try:
            property_graph_index = PropertyGraphIndex(
                nodes=parsed_text_nodes,
                property_graph_store=property_graph_store,
                llm=None, # Explicitly None for PGI
                # embed_kg_nodes=False,
                embed_model=Settings.embed_model, # Keep your embed_model
                use_async=False,
                kg_extractors=[],
                transformations=[],
                show_progress=True, # Keep show_progress for now
            )
        finally:
            # --- RESTORE SETTINGS.LLM ---
            Settings.llm = original_llm
            # --- END RESTORE ---
        # --- END CHANGE ---
        logging.info(f"PropertyGraphIndex created and processed {len(parsed_text_nodes)} base nodes.")

        # Add relationships (triplets) to the index/store
        # Inside your loop for adding custom extracted relationships:

        from dataclasses import dataclass, field
        from typing import Dict, Any, Optional, List
        import hashlib # For creating a somewhat unique ID
        from llama_index.core.graph_stores.types import TRIPLET_SOURCE_KEY

        # Minimal implementations if the library types are abstract
        @dataclass
        class MyLabelledNode:
            label: str
            properties: Dict[str, Any] = field(default_factory=dict)
            id_str: str = "" # Using a different name to avoid conflict if 'id' is a property

            @property
            def id(self) -> str: # Implement the 'id' property
                return self.id_str

            def __str__(self) -> str: # Implement __str__
                return f"Node(id={self.id}, label='{self.label}', properties={self.properties})"

            # Add other methods/properties if SimplePropertyGraphStore expects them
            # e.g., embedding
            embedding: Optional[List[float]] = None


        @dataclass
        class MyRelation:
            source_id: str
            target_id: str
            label: str
            id_str: Optional[str] = None # Optional ID for the relation itself
            properties: Dict[str, Any] = field(default_factory=dict)
            # embedding: Optional[List[float]] = None # If relations can be embedded

            @property
            def id(self) -> Optional[str]:
                return self.id_str

            def __str__(self) -> str:
                return f"Relation(source_id='{self.source_id}', target_id='{self.target_id}', label='{self.label}')"

        # --- In your triplet adding loop ---

        logging.info(f"Adding {len(extracted_triplets)} custom extracted relationships...")
        for s_str, p_str, o_str, source_node_id_for_this_triplet  in tqdm(extracted_triplets, desc="Adding Custom Relationships"):
            try:
                s_entity_id = get_entity_node_name(s_str) # e.g., "medical charcoal"
                o_entity_id = get_entity_node_name(o_str) # e.g., "seizures"

                # Use your custom minimal node and relation types
                s_node_obj = MyLabelledNode(
                    id_str=s_entity_id,
                    label=s_str,
                    properties={"name": s_str, TRIPLET_SOURCE_KEY: source_node_id_for_this_triplet}
                )
                o_node_obj = MyLabelledNode(
                    id_str=o_entity_id,
                    label=o_str,
                    properties={"name": o_str, TRIPLET_SOURCE_KEY: source_node_id_for_this_triplet}
                )

                property_graph_index.property_graph_store.upsert_nodes([s_node_obj, o_node_obj])

                relation = MyRelation(
                    source_id=s_entity_id, # Relation is FROM "medical charcoal"
                    target_id=o_entity_id, # Relation is TO "seizures"
                    label=p_str,
                    properties={"type": p_str, TRIPLET_SOURCE_KEY: source_node_id_for_this_triplet} # CRITICAL
                )
                property_graph_index.property_graph_store.upsert_relations([relation])

            except Exception as e_rel:
                logging.error(f"Failed to add relationship ({s_str}, {p_str}, {o_str}): {e_rel}")
                import traceback
                logging.error(traceback.format_exc())

    # --- Post-Build Operations ---
    if property_graph_index:
        # Visualization: Needs access to the triplets for the current simple viz
        vis_success = visualize_property_graph(property_graph_index.property_graph_store, GRAPH_OUTPUT_HTML)
        if vis_success:
            print(f"\nGraph visualization saved to ems_knowledge_graph_pg.html")
        else:
            print("\nSkipping visualization for Property Graph.")

        # Querying
        # Inside your main execution block, after property_graph_index is built

        # Querying
        logging.info("Setting up query engine for PropertyGraphIndex...")
        try:
            # 1. Get the retriever from the PropertyGraphIndex
            # The default retriever for PGI is usually quite good.
            retriever_pg = property_graph_index.as_retriever(
                # You can experiment with retriever_mode if default isn't optimal
                # e.g., retriever_mode="hybrid" or "vector" or "keyword"
                # For PGI, often the default handles this well.
                # include_text=True # Usually good to include text for synthesis
                retriever_mode="hybrid",
                include_text=True, # Include text for synthesis

            )
            logging.info(f"Retriever for PGI: {type(retriever_pg)}")

            # 2. Create the ResponseSynthesizer explicitly with the desired mode and custom prompt
            # If you want TreeSummarize behavior specifically:
            response_synthesizer_pg = get_response_synthesizer(
                response_mode="tree_summarize", # Explicitly choose tree_summarize
                llm=Settings.llm,
                summary_template=CUSTOM_SUMMARY_PROMPT, # Your custom summary prompt
                use_async=False,
                # verbose=True # For debugging the synthesizer
            )
            # If the default was CompactAndRefine and you want to customize its prompts:
            # You would need to find the specific prompt parameters for CompactAndRefine.
            # For example (hypothetical, check LlamaIndex docs for actual param names):
            # response_synthesizer_pg = get_response_synthesizer(
            #     response_mode="compact_and_refine", # Or just "refine"
            #     llm=Settings.llm,
            #     # CompactAndRefine might use refine_template and/or text_qa_template
            #     refine_template=YOUR_CUSTOM_REFINE_PROMPT, # If you have one
            #     text_qa_template=YOUR_CUSTOM_TEXT_QA_PROMPT, # If you have one
            #     use_async=False,
            # )
            # For now, let's stick to forcing TreeSummarize as it's usually good for Q&A over KG.

            logging.info(f"Response synthesizer for PGI: {type(response_synthesizer_pg)}")
            # Check the internal builder if needed for debugging
            if hasattr(response_synthesizer_pg, '_response_builder'):
                logging.info(f"  Internal builder: {type(response_synthesizer_pg._response_builder)}")


            # 3. Create the RetrieverQueryEngine
            query_engine_pg = RetrieverQueryEngine(
                retriever=retriever_pg,
                response_synthesizer=response_synthesizer_pg,
            )
            logging.info("Query engine for PGI created.")


            # 4. Execute Queries
            print("\n--- Querying the Property Graph ---")
            queries = [
                "What advice is given for major bleeding in RTA?",
                "List critical criteria for Violence / abuse.",
                "What are the symptoms of Anaphylactic shock?",
                "What is the treatment for Anaphylactic shock?",
                "What is the treatment for seizures?",
                "What are advices for unconscious patients?",
            ]
            for q_text in queries:
                print(f"\nExecuting Query: {q_text}")
                response = query_engine_pg.query(q_text) # Use sync query
                # Or if you made main_async_pipeline and want to use await:
                # response = await query_engine_pg.aquery(q_text)
                print(f"Q: {q_text}\nA: {response}")
            print("-" * 30)

        except Exception as e:
            logging.error(f"An error occurred during PGI query engine setup or execution: {e}")
            import traceback
            logging.error(traceback.format_exc())
    else:
        print("\nPropertyGraphIndex could not be built. Exiting.")

    logging.info("Process finished.")

In [None]:
# Querying
logging.info("Setting up query engine for PropertyGraphIndex...")
try:
    # 1. Get the retriever from the PropertyGraphIndex
    # The default retriever for PGI is usually quite good.
    retriever_pg = property_graph_index.as_retriever(
        # You can experiment with retriever_mode if default isn't optimal
        # e.g., retriever_mode="hybrid" or "vector" or "keyword"
        # For PGI, often the default handles this well.
        # include_text=True # Usually good to include text for synthesis
        retriever_mode="hybrid",
        include_text=True, # Include text for synthesis

    )
    logging.info(f"Retriever for PGI: {type(retriever_pg)}")

    # 2. Create the ResponseSynthesizer explicitly with the desired mode and custom prompt
    # If you want TreeSummarize behavior specifically:
    response_synthesizer_pg = get_response_synthesizer(
        response_mode="tree_summarize", # Explicitly choose tree_summarize
        llm=Settings.llm,
        summary_template=CUSTOM_SUMMARY_PROMPT, # Your custom summary prompt
        use_async=False,
        # verbose=True # For debugging the synthesizer
    )
    # If the default was CompactAndRefine and you want to customize its prompts:
    # You would need to find the specific prompt parameters for CompactAndRefine.
    # For example (hypothetical, check LlamaIndex docs for actual param names):
    # response_synthesizer_pg = get_response_synthesizer(
    #     response_mode="compact_and_refine", # Or just "refine"
    #     llm=Settings.llm,
    #     # CompactAndRefine might use refine_template and/or text_qa_template
    #     refine_template=YOUR_CUSTOM_REFINE_PROMPT, # If you have one
    #     text_qa_template=YOUR_CUSTOM_TEXT_QA_PROMPT, # If you have one
    #     use_async=False,
    # )
    # For now, let's stick to forcing TreeSummarize as it's usually good for Q&A over KG.

    logging.info(f"Response synthesizer for PGI: {type(response_synthesizer_pg)}")
    # Check the internal builder if needed for debugging
    if hasattr(response_synthesizer_pg, '_response_builder'):
        logging.info(f"  Internal builder: {type(response_synthesizer_pg._response_builder)}")


    # 3. Create the RetrieverQueryEngine
    query_engine_pg = RetrieverQueryEngine(
        retriever=retriever_pg,
        response_synthesizer=response_synthesizer_pg,
    )
    logging.info("Query engine for PGI created.")


    # 4. Execute Queries
    print("\n--- Querying the Property Graph ---")
    queries = [
        "What are advices for unconscious patients?",
        "List critical criteria for Violence / abuse.",
        "What are the symptoms of Anaphylactic shock?",
        "What is the treatment for Anaphylactic shock?",
        "What is the treatment for seizures?",
        
    ]
    for q_text in queries:
        print(f"\nExecuting Query: {q_text}")
        response = query_engine_pg.query(q_text) # Use sync query
        # Or if you made main_async_pipeline and want to use await:
        # response = await query_engine_pg.aquery(q_text)
        print(f"Q: {q_text}\nA: {response}")
    print("-" * 30)

except Exception as e:
    logging.error(f"An error occurred during PGI query engine setup or execution: {e}")
    import traceback
    logging.error(traceback.format_exc())

In [None]:
# # --- Main Execution (Modified with Caching Logic) ---
# if __name__ == "__main__":

#     Settings.embed_model = HuggingFaceEmbedding(
#     model_name="BAAI/bge-m3"
#     )
#     Settings.llm = llm
#     # Settings.embed_model = embed_model
#     Settings.chunk_size = 512
#     Settings.chunk_overlap = 100

#     include_embeddings_flag = True # Set this based on your config or requirements

#     # 1. Try loading triplets from cache
#     extracted_triplets = load_triplets_from_cache(TRIPLETS_CACHE_FILE)

#     # 2. If cache miss or loading failed, extract triplets
#     if extracted_triplets is None:
#         logging.info("Cache miss or error. Extracting triplets from documents...")
#         doc_nodes = load_and_parse_documents(INPUT_DIR)
#         extracted_triplets = extract_triplets_from_nodes(doc_nodes) # Uses custom prompt + PydanticOutputParser

#         # 3. Save newly extracted triplets to cache
#         if extracted_triplets:
#             save_triplets_to_cache(extracted_triplets, TRIPLETS_CACHE_FILE)
#         else:
#             logging.error("No triplets extracted. Cannot build index.")
#             # Optionally exit if extraction is critical and failed
#             # sys.exit(1)

#     # Proceed only if we have triplets (either from cache or fresh extraction)
#     if extracted_triplets:
#         # 4. Populate Graph Store (Always do this step if triplets exist)
#         logging.info("Populating graph store from loaded/extracted triplets...")
#         graph_store_populated = populate_graph_store(extracted_triplets) # Uses SimpleGraphStore

#         # 5. Build Index Wrapper (Always do this step if triplets exist)
#         logging.info("Building KnowledgeGraphIndex wrapper...")
#         storage_context = StorageContext.from_defaults(graph_store=graph_store_populated)
#         # Note the DeprecationWarning when this line runs
#         kg_index = KnowledgeGraphIndex(
#             nodes=[], # Pass empty list as nodes are implicitly in the graph store
#             storage_context=storage_context,
#             max_triplets_per_chunk=0, # Extraction already done
#             include_embeddings=include_embeddings_flag, # Use flag set during config
#             # service_context=service_context, # Deprecated, use Settings
#         )
#         logging.info("KnowledgeGraphIndex wrapper built.")

#     # --- Post-Build Operations (Visualization, Querying) ---
#     if kg_index:
#         # Get triplets again *from the store* for consistency check / use
#         triplets_from_store = get_all_triplets_from_simple_store(kg_index.graph_store)
#         print(f"\n--- Store contains {len(triplets_from_store)} triplets ---")
#         print("--- Sample Extracted Triplets (Custom Prompt) ---")
#         for i, (s, p, o) in enumerate(triplets_from_store[:20]): # Print first 20
#             print(f"{s} --[{p}]--> {o}")
#         if not triplets_from_store:
#             print("No triplets found in the graph store.")
#         print("-" * 30)

#         # Visualize the graph
#         vis_success = visualize_graph(kg_index, GRAPH_OUTPUT_HTML) # Uses helper, saves with UTF-8
#         if vis_success:
#             print(f"\nGraph visualization saved to {GRAPH_OUTPUT_HTML}")
#             print("Open this HTML file in your browser to view the graph.")
#         else:
#             print("\nSkipping visualization due to errors or empty graph.")

#         # Query the graph (using manually constructed engine)
#         if triplets_from_store:
#             logging.info("Setting up query engine...")
#             try:
#                 # 1. Get the retriever from the index
#                 # Specify retriever type hint for clarity
#                 from llama_index.core.retrievers import BaseRetriever
#                 retriever: BaseRetriever = kg_index.as_retriever(
#                      # Use 'keyword' if no embeddings, 'hybrid' if embeddings included
#                     retriever_mode="keyword" if not include_embeddings_flag else "hybrid",
#                     graph_traversal_depth=2,
#                     # include_embeddings=include_embeddings_flag # Pass flag if retriever needs it
#                 )
#                 logging.info(f"Retriever created with mode: {'keyword' if not include_embeddings_flag else 'hybrid'}")


#                 # 2. Create the ResponseSynthesizer using the factory function
#                 response_synthesizer = get_response_synthesizer(
#                     response_mode="tree_summarize",
#                     llm=Settings.llm,
#                     summary_template=CUSTOM_SUMMARY_PROMPT, # Use the refined custom prompt
#                     use_async=False,
#                     # verbose=True # Enable for debugging
#                 )
#                 logging.info("Response synthesizer created.")

#                 # 3. Create the RetrieverQueryEngine manually
#                 query_engine = RetrieverQueryEngine(
#                     retriever=retriever,
#                     response_synthesizer=response_synthesizer,
#                 )
#                 logging.info("Query engine created.")

#                 # 4. Execute Queries
#                 print("\n--- Querying the Knowledge Graph (Custom Prompt, Factory Synthesizer) ---")

#                 query1 = "What advice is given for major bleeding in RTA?"
#                 print(f"\nExecuting Query 1: {query1}")
#                 response1 = query_engine.query(query1)
#                 print(f"\nQ: {query1}\nA: {response1}")

#                 query2 = "List critical criteria for Violence / abuse."
#                 print(f"\nExecuting Query 2: {query2}")
#                 response2 = query_engine.query(query2)
#                 print(f"\nQ: {query2}\nA: {response2}")

#                 query3 = "What are the symptoms of Anaphylactic shock?"
#                 print(f"\nExecuting Query 3: {query3}")
#                 response3 = query_engine.query(query3)
#                 print(f"\nQ: {query3}\nA: {response3}")

#                 print("-" * 30)

#             except Exception as e:
#                 logging.error(f"An error occurred during query engine setup or execution: {e}")
#                 import traceback
#                 logging.error(traceback.format_exc())

#         else:
#             print("\nSkipping querying as no triplets were found in the graph store.")
#     else:
#         print("\nIndex could not be built or loaded. Exiting.")

#     logging.info("Process finished.")

In [None]:
###### without caching triplets ######

if __name__ == "__main__":

    Settings.embed_model = HuggingFaceEmbedding(
    model_name="BAAI/bge-m3"
    )
    Settings.llm = llm
    # Settings.embed_model = embed_model
    Settings.chunk_size = 512
    Settings.chunk_overlap = 100

    # Check if index needs to be built or can be loaded
    if not os.path.exists(PERSIST_DIR):
         logging.info("Building new index as persistence directory not found.")
         # 1. Load and Parse Documents into Nodes
         doc_nodes = load_and_parse_documents(INPUT_DIR)

         # 2. Extract Triplets using Custom Prompt & PydanticProgram
         extracted_triplets = extract_triplets_from_nodes(doc_nodes)

         # 3. Populate Graph Store
         if extracted_triplets:
            graph_store_populated = populate_graph_store(extracted_triplets)

            # 4. Build Index Wrapper and Persist
            kg_index = build_and_persist_index_from_store(graph_store_populated, PERSIST_DIR)
         else:
            logging.error("No triplets extracted. Cannot build or save index.")
            kg_index = None
    else:
         logging.info("Attempting to load index from persistence directory.")
         # Try loading the index (which should include the graph store configuration)
         # Note: SimpleGraphStore persistence relies on index saving it. May need re-extraction.
         kg_index = load_index_with_store(PERSIST_DIR)
         if kg_index is None or len(kg_index.graph_store.get_all_triples()) == 0:
             logging.warning("Failed to load a populated index. Re-building...")
             # Fallback to re-building if loading fails or store is empty
             doc_nodes = load_and_parse_documents(INPUT_DIR)
             extracted_triplets = extract_triplets_from_nodes(doc_nodes)
             if extracted_triplets:
                 graph_store_populated = populate_graph_store(extracted_triplets)
                 kg_index = build_and_persist_index_from_store(graph_store_populated, PERSIST_DIR)
             else:
                 logging.error("No triplets extracted during re-build.")
                 kg_index = None


    if kg_index:
        # 5. (Optional) Retrieve and print triplets
        triplets = get_all_triplets_from_simple_store(kg_index.graph_store)
        print("\n--- Sample Extracted Triplets (Custom Prompt) ---")
        for i, (s, p, o) in enumerate(triplets[:20]): # Print first 20
            print(f"{s} --[{p}]--> {o}")
        if not triplets:
            print("No triplets found in the graph store.")
        print("-" * 30)

        # 6. (Optional) Visualize the graph
        vis_success = visualize_graph(kg_index, GRAPH_OUTPUT_HTML)
        if vis_success:
            print(f"\nGraph visualization saved to {GRAPH_OUTPUT_HTML}")
            print("Open this HTML file in your browser to view the graph.")
        else:
            print("\nSkipping visualization due to errors or empty graph.")

        # 7. (Optional) Query the graph
        if kg_index and triplets: # Ensure index and triplets exist
            # 1. Get the retriever from the index
            retriever: BaseRetriever = kg_index.as_retriever(
                retriever_mode="keyword", # Or other modes as needed
                # include_embeddings=False # Match how index was built
            )

            # 2. Create the TreeSummarize synthesizer with the custom prompt
            summarizer = TreeSummarize(
                summary_template=CUSTOM_SUMMARY_PROMPT, # Use the custom prompt
                llm=Settings.llm,
                # verbose=True # Enable for debugging synthesizer prompts
            )

            # 3. Create the ResponseSynthesizer
            response_synthesizer = get_response_synthesizer(
             response_mode="tree_summarize", # Specify the mode
             llm=Settings.llm, # Pass the LLM
             summary_template=CUSTOM_SUMMARY_PROMPT, # Pass the custom prompt here as well
             # Or potentially pass the summarizer instance if the factory supports it:
             # response_builder=summarizer, # Check documentation for this exact usage
             use_async=False, # Set based on your needs
            )

            # 4. Create the RetrieverQueryEngine manually
            query_engine = RetrieverQueryEngine(
                retriever=retriever,
                response_synthesizer=response_synthesizer,
            )

            # Now use this query_engine for querying
            print("\n--- Querying the Knowledge Graph (Custom Prompt, Custom Synthesizer) ---")
            query1 = "What advice is given for major bleeding in RTA?"
            try:
                response1 = query_engine.query(query1)
                print(f"\nQ: {query1}\nA: {response1}")
            except Exception as e:
                print(f"Error during query 1: {e}")

            query2 = "List critical criteria for Violence / abuse."
            try:
                response2 = query_engine.query(query2)
                print(f"\nQ: {query2}\nA: {response2}")
            except Exception as e:
                print(f"Error during query 2: {e}")

            query3 = "What are the symptoms of Anaphylactic shock?"
            try:
                response3 = query_engine.query(query3)
                print(f"\nQ: {query3}\nA: {response3}")
            except Exception as e:
                print(f"Error during query 3: {e}")

            print("-" * 30)

            
        else:
            print("\nSkipping querying as no triplets were extracted or found.")
    else:
        print("\nIndex could not be built or loaded. Exiting.")

    logging.info("Process finished.")

INFO:sentence_transformers.SentenceTransformer:Load pretrained SentenceTransformer: BAAI/bge-m3
Load pretrained SentenceTransformer: BAAI/bge-m3
INFO:sentence_transformers.SentenceTransformer:2 prompts are loaded, with the keys: ['query', 'text']
2 prompts are loaded, with the keys: ['query', 'text']
INFO:root:Building new index as persistence directory not found.
Building new index as persistence directory not found.
INFO:root:Loading documents from: ./kgdata
Loading documents from: ./kgdata
INFO:root:Loaded 37 documents.
Loaded 37 documents.


Parsing nodes:   0%|          | 0/37 [00:00<?, ?it/s]

INFO:root:Parsed into 145 nodes.
Parsed into 145 nodes.
INFO:root:Extracting triplets from 145 nodes using PydanticOutputParser...
Extracting triplets from 145 nodes using PydanticOutputParser...


  0%|          | 0/145 [00:00<?, ?it/s]

INFO:httpx:HTTP Request: POST https://d-ais-eus-ais-chatbots.openai.azure.com/openai/deployments/o1-mini/chat/completions?api-version=2024-12-01-preview "HTTP/1.1 200 OK"
HTTP Request: POST https://d-ais-eus-ais-chatbots.openai.azure.com/openai/deployments/o1-mini/chat/completions?api-version=2024-12-01-preview "HTTP/1.1 200 OK"
INFO:httpx:HTTP Request: POST https://d-ais-eus-ais-chatbots.openai.azure.com/openai/deployments/o1-mini/chat/completions?api-version=2024-12-01-preview "HTTP/1.1 200 OK"
HTTP Request: POST https://d-ais-eus-ais-chatbots.openai.azure.com/openai/deployments/o1-mini/chat/completions?api-version=2024-12-01-preview "HTTP/1.1 200 OK"
INFO:httpx:HTTP Request: POST https://d-ais-eus-ais-chatbots.openai.azure.com/openai/deployments/o1-mini/chat/completions?api-version=2024-12-01-preview "HTTP/1.1 200 OK"
HTTP Request: POST https://d-ais-eus-ais-chatbots.openai.azure.com/openai/deployments/o1-mini/chat/completions?api-version=2024-12-01-preview "HTTP/1.1 200 OK"
INFO:ht

  index = KnowledgeGraphIndex(


INFO:root:Index persisted.
Index persisted.
INFO:root:Retrieved 6496 triplets via internal iteration.
Retrieved 6496 triplets via internal iteration.

--- Sample Extracted Triplets (Custom Prompt) ---
Important Information to the Caller --[requires_advice]--> Tell me immediately if anything changes
Important Information to the Caller --[requires_advice]--> Watch the person all the time
Important Information to the Caller --[requires_advice]--> Help is on the way
Important Information to the Caller --[includes_advice]--> Inform immediately if anything changes
Important Information to the Caller --[includes_advice]--> Help is on the way
Important Information to the Caller --[includes_advice]--> Watch the casualty at all times
Important Information to the Caller --[requires_advice]--> Keep this phone free until the medics arrive
Important Information to the Caller --[includes_advice]--> Keep the phone free until medics arrive
Important Information to the Caller --[provides_instruction]-->

In [14]:
response2 = query_engine.query("unconcious")
print(f"\nQ: {query2}\nA: {response2}")

INFO:httpx:HTTP Request: POST https://d-ais-eus-ais-chatbots.openai.azure.com/openai/deployments/o1-mini/chat/completions?api-version=2024-12-01-preview "HTTP/1.1 200 OK"
HTTP Request: POST https://d-ais-eus-ais-chatbots.openai.azure.com/openai/deployments/o1-mini/chat/completions?api-version=2024-12-01-preview "HTTP/1.1 200 OK"
Index was not constructed with embeddings, skipping embedding usage...
INFO:llama_index.core.indices.knowledge_graph.retrievers:> No relationships found, returning nodes found by keywords.
> No relationships found, returning nodes found by keywords.
INFO:llama_index.core.indices.knowledge_graph.retrievers:> No nodes found by keywords, returning empty response.
> No nodes found by keywords, returning empty response.
INFO:httpx:HTTP Request: POST https://d-ais-eus-ais-chatbots.openai.azure.com/openai/deployments/o1-mini/chat/completions?api-version=2024-12-01-preview "HTTP/1.1 200 OK"
HTTP Request: POST https://d-ais-eus-ais-chatbots.openai.azure.com/openai/deplo

In [None]:
#### to do: how to build the knowledge graph from graph_store.json and query it? either using llama or langchain
#### also,  probably run previous experiment again to store embeddings as well.

In [None]:


# 2. Define paths to your persisted files
GRAPH_STORE_PATH = ".\storage_kg_custom_prompt\graph_store.json"
# DOC_STORE_PATH = "docstore.json" # Empty in your case
# VECTOR_STORE_PATH = "default__vector_store.json" # Empty in your case


# --- Load the Graph Store ---
print(f"Loading graph data from {GRAPH_STORE_PATH}...")
try:
    with open(GRAPH_STORE_PATH, "r") as f:
        graph_data = json.load(f)
except FileNotFoundError:
    print(f"Error: {GRAPH_STORE_PATH} not found.")
    exit()
except json.JSONDecodeError:
    print(f"Error: Could not decode JSON from {GRAPH_STORE_PATH}.")
    exit()

# Create a SimpleGraphStore and populate it
graph_store = SimpleGraphStore()

if 'graph_dict' in graph_data and isinstance(graph_data['graph_dict'], dict):
    graph_store.graph_dict = graph_data['graph_dict']
    print(f"Successfully loaded graph_dict with {len(graph_store.graph_dict)} subjects.")
    # Optional: Print a few subjects to verify
    # count = 0
    # for subj in graph_store.graph_dict:
    #     print(f"- {subj}")
    #     count += 1
    #     if count >= 3:
    #         break
else:
    print("Error: 'graph_dict' not found or not a dictionary in graph_store.json")
    exit()


# --- Create the KnowledgeGraphIndex ---
# Since we are loading an existing graph_store, we don't need to provide documents (nodes=[])
# We are also not loading from a persisted index directory, but rebuilding from the graph_store
index = KnowledgeGraphIndex(
    nodes=[],  # No new documents to process
    graph_store=graph_store,
    # You can optionally set index_id if you plan to persist it again
    # index_id="my_restored_kg_index"
)
print("KnowledgeGraphIndex created from loaded graph store.")

# --- Querying the Graph ---
# When querying, include_text=False is important if docstore is empty,
# as it forces the query to rely on the graph structure.
# If you had text in docstore and wanted to retrieve it, you'd set include_text=True.
query_engine = index.as_query_engine(
    include_text=False, # Set to True if you had a populated docstore and wanted text summaries
    response_mode="tree_summarize", # Good for KG, consolidates info from multiple paths
    # You can also try other modes like "compact"
    # For more complex KG queries, you might explore sub_graph_query_engine_configs
)
print("Query engine created.")

# --- Example Queries ---
queries = [
    "What are the symptoms of Subarachnoidalblødning (SAB)?",
    "What actions are included in 'Advice 3. DIFFICULTY BREATHING'?",
    "What does 'Sentralstimulerende midler' omfatter?",
    "What are the criteria for 'Urgent' severity?",
    "What should a Responder do for an infant under 1 year?",
    "What are the causes of Eclampsia?",
    "What is Medisinsk kull NOT given for?"
]

for query_text in queries:
    print(f"\n📝 Querying: {query_text}")
    try:
        response = query_engine.query(query_text)
        print(f"💡 Response: {response}")
        # print("\n--- Source Nodes ---")
        # for node in response.source_nodes:
        #     print(f"Node ID: {node.node_id}, Score: {node.score}")
        #     # print(f"Metadata: {node.metadata}") # Might be empty if not set
        #     # print(f"Text: {node.get_content(metadata_mode='llm')}") # If include_text=True
    except Exception as e:
        print(f"Error during query: {e}")

print("\nDone.")

In [74]:
import json
import os

from llama_index.core import (
    KnowledgeGraphIndex,
    SimpleDirectoryReader,
    StorageContext,
    Settings,
    VectorStoreIndex, # We'll use its vector store component
)
from llama_index.core.graph_stores import SimpleGraphStore
from llama_index.core.vector_stores import SimpleVectorStore # Or other vector stores
from llama_index.llms.openai import OpenAI
from llama_index.embeddings.openai import OpenAIEmbedding # For embeddings
from llama_index.core.schema import TextNode
from llama_index.core.node_parser import SimpleNodeParser

In [None]:
# GRAPH_STORE_PATH = "graph_store.json"
PERSIST_DIR = "./storage_with_embeddings"
GRAPH_STORE_PATH = ".\storage_kg_custom_prompt\graph_store.json"

Settings.llm = llm
Settings.embed_model = embed_model
Settings.chunk_size = 512
Settings.chunk_overlap = 100

Settings.node_parser = SimpleNodeParser.from_defaults(chunk_size=Settings.chunk_size)
# --- Load the Graph Store ---
print(f"Loading graph data from {GRAPH_STORE_PATH}...")
try:
    with open(GRAPH_STORE_PATH, "r") as f:
        graph_data_loaded = json.load(f)
except FileNotFoundError:
    print(f"Error: {GRAPH_STORE_PATH} not found.")
    exit()
except json.JSONDecodeError:
    print(f"Error: Could not decode JSON from {GRAPH_STORE_PATH}.")
    exit()

graph_store = SimpleGraphStore()
if 'graph_dict' in graph_data_loaded and isinstance(graph_data_loaded['graph_dict'], dict):
    graph_store.graph_dict = graph_data_loaded['graph_dict']
    print(f"Successfully loaded graph_dict with {len(graph_store.graph_dict)} subjects.")
else:
    print("Error: 'graph_dict' not found or not a dictionary in graph_store.json")
    exit()

# --- 2. Extract Entities and Create Nodes for Embedding ---
# We'll treat the subjects of your graph as the text to be embedded.
entities_to_embed = list(graph_store.graph_dict.keys())
print(f"Extracted {len(entities_to_embed)} entities (subjects) for embedding.")

entity_nodes = []
for i, entity_text in enumerate(entities_to_embed):
    # We use the entity text as the node_id to potentially link it back
    # if the KG index's internal mechanisms support it during retrieval.
    # For simple embedding and vector store population, a unique ID is sufficient.
    node = TextNode(text=entity_text, id_=f"entity_node_{i}")
    entity_nodes.append(node)

print(f"Created {len(entity_nodes)} TextNodes for embedding.")

# --- 3. Setup Vector Store and Storage Context ---
vector_store = SimpleVectorStore()
storage_context = StorageContext.from_defaults(
    graph_store=graph_store,
    vector_store=vector_store
)
print("StorageContext created with graph_store and vector_store.")


# --- 4. Build/Rebuild KnowledgeGraphIndex (This will generate and store embeddings) ---
# By providing 'nodes' (our entity_nodes), the index will process them,
# generate embeddings, and store them in the vector_store provided in storage_context.
# The existing graph_store will be used for the graph structure.
index = KnowledgeGraphIndex(
    nodes=entity_nodes, # These nodes will be embedded
    storage_context=storage_context,
    max_triplets_per_chunk=2, # This is more for building from docs, less critical here
    include_embeddings=True, # Explicitly enable embedding storage within the KG index logic
)
print("KnowledgeGraphIndex built. Embeddings for entities should now be in the vector_store.")

# --- 5. Persist the Index (including Graph Store and Vector Store with Embeddings) ---
if not os.path.exists(PERSIST_DIR):
    os.makedirs(PERSIST_DIR)
index.storage_context.persist(persist_dir=PERSIST_DIR)
print(f"Index, graph store, and vector store (with embeddings) persisted to {PERSIST_DIR}")

# --- 6. Querying with Embeddings ---
# The query engine will now use embeddings for semantic similarity if relevant,
# and then traverse the graph.
query_engine = index.as_query_engine(
    include_text=True, # Set to True to see the text of the matched entity nodes
    response_mode="tree_summarize",
    embedding_mode="hybrid", # Explicitly try hybrid (semantic + keyword)
    # similarity_top_k=3 # How many semantically similar nodes to consider as entry points
)
print("Query engine created.")

# --- Example Queries ---
queries = [
    "What are the symptoms of Subarachnoidalblødning (SAB)?", # Graph traversal
    "Tell me about conditions similar to 'headache'.", #graph_store = SimpleGraphStore()
]
if 'graph_dict' in graph_data and isinstance(graph_data['graph_dict'], dict):
    graph_store.graph_dict = graph_data['graph_dict']
    print(f"Successfully loaded graph_dict with {len(graph_store.graph_dict)} subjects.")
else:
    print("Error: 'graph_dict' not found or not a dictionary in graph_store.json")
    exit()

# --- Extract Unique Node Labels from the Graph to Embed ---
print("Extracting unique node labels for embedding...")
unique_node_labels = set()
for subj, rel_obj_list in graph_store.graph_dict.items():
    unique_node_labels.add(str(subj)) # Ensure subject is string
    for rel, obj in rel_obj_list:
        unique_node_labels.add(str(obj)) # Ensure object is string

print(f"Found {len(unique_node_labels)} unique node labels to embed.")

# Create TextNode objects for each unique label
# These nodes will be embedded and stored in the vector store.
# The KG index will then link its graph structure to these embedded nodes.
text_nodes_for_embedding = [TextNode(text=label, id_=label) for label in unique_node_labels]
print(f"Created {len(text_nodes_for_embedding)} TextNodes for embedding.")


# --- Set up Storage Context with Graph Store and Vector Store ---
# We will use a SimpleVectorStore here. For production, consider FAISS, Pinecone, etc.
vector_store = SimpleVectorStore()

storage_context = StorageContext.from_defaults(
    graph_store=graph_store,
    vector_store=vector_store
)
print("StorageContext created.")

# --- Create/Build the KnowledgeGraphIndex ---
# This step will embed the 'text_nodes_for_embedding' and store them in the vector_store
# It will also use the provided graph_store.
index = KnowledgeGraphIndex(
    nodes=text_nodes_for_embedding, # Provide the nodes to be embedded
    storage_context=storage_context,
    # kg_triplet_extract_fn can be used if building from text, not needed here
    # index_id="my_embedded_kg_index" # Optional
    # embed_kg_nodes=True # This is often on by default if an embed_model is in Settings
                        # and nodes are provided or built from documents.
)
print("KnowledgeGraphIndex built. Node labels should now be embedded.")

# --- Persist the Index (including embeddings in the vector store) ---
print(f"Persisting index to {PERSIST_DIR}...")
os.makedirs(PERSIST_DIR, exist_ok=True)
index.storage_context.persist(persist_dir=PERSIST_DIR)
# Alternatively: index.persist(persist_dir=PERSIST_DIR)
print(f"Index persisted. Graph store and vector store (with embeddings) are saved.")


# --- Querying the Graph with Embeddings ---
# similarity_top_k will retrieve nodes based on embedding similarity first,
# then the KG query engine will use these as starting points or context.
query_engine = index.as_query_engine(
    include_text=True, # Now we want the text of the nodes (their labels)
    response_mode="tree_summarize",
    similarity_top_k=5, # Retrieve top 5 similar nodes using embeddings
    # You can also explore graph_query_synthesis_prompt for more control
    # and explore different retriever_modes for how KG and vector search combine.
    # For example, you might look into KnowledgeGraphQueryEngine and its configurations.
)
print("Query engine created (embedding-aware).")

# --- Example Queries ---
queries = [
    "Tell me about conditions related to breathing difficulties.", # More semantic query
    "What medical advice is given for difficulty breathing?",
    "What are the symptoms of SAB?", # SAB is Subarachnoidalblødning
    "What are treatments or actions for medicinal poisoning?",
    "Find information about substances that stimulate the central nervous system.",
    "What are the urgent criteria related to health issues?",
    ]

for query_text in queries:
    print(f"\n📝 Querying: {query_text}")
    try:
        response = query_engine.query(query_text)
        print(f"💡 Response: {response}")
        print("\n--- Source Nodes (from vector search + KG traversal) ---")
        for node_with_score in response.source_nodes:
            node = node_with_score.node
            print(f"Node ID: {node.id_}, Score: {node_with_score.score:.4f}")
            # print(f"Text: {node.get_content(metadata_mode='llm')}")
    except Exception as e:
        print(f"Error during query: {e}")

print("\nDone.")

In [None]:
# --- To load the persisted index later ---
# from llama_index.core import load_index_from_storage
# print("\n--- Example: Loading persisted index ---")
# # Re-initialize Settings if in a new session
# Settings.llm = llm
# Settings.embed_model = embed_model

# loaded_storage_context = StorageContext.from_defaults(persist_dir=PERSIST_DIR)
# loaded_index = load_index_from_storage(
#     storage_context=loaded_storage_context,
#     # If you used SimpleGraphStore, LlamaIndex should auto-detect it.
#     # If you used a custom graph_store, you might need to pass it:
#     # graph_store=SimpleGraphStore.from_persist_dir(persist_dir=PERSIST_DIR),
# )
# print("Index loaded from persistence.")
# loaded_query_engine = loaded_index.as_query_engine(
#     include_text=True,
#     response_mode="tree_summarize",
#     similarity_top_k=5
# )
# test_query_response = loaded_query_engine.query("What is Epiglottitis?")
# print(f"Test query on loaded index: {test_query_response}")

In [None]:


# --- Configuration ---
GRAPH_STORE_FILE = "graph_store.json"
DOCSTORE_FILE = "docstore.json"
INDEX_STORE_FILE = "index_store.json"
VECTOR_STORE_FILE = "default__vector_store.json" # LlamaIndex default naming convention
IMAGE_VECTOR_STORE_FILE = "image__vector_store.json" # Not used here, but good to have defined

PERSIST_DIR = "./persisted_kg" # Persist files in the current directory

# --- Main Script ---
def main():

    # 1. Load Knowledge Graph Data from graph_store.json
    print(f"\n--- Loading Knowledge Graph from {GRAPH_STORE_PATH} ---")
    # --- Load the Graph Store ---
    print(f"Loading graph data from {GRAPH_STORE_PATH}...")
    try:
        with open(GRAPH_STORE_PATH, "r") as f:
            graph_data_loaded = json.load(f)
    except FileNotFoundError:
        print(f"Error: {GRAPH_STORE_PATH} not found.")
        exit()
    except json.JSONDecodeError:
        print(f"Error: Could not decode JSON from {GRAPH_STORE_PATH}.")
        exit()


    # Convert to a list of (subject, relation, object) triples
    triples = []
    all_nodes = set()
    for subj, rel_obj_list in graph_triples_dict.items():
        all_nodes.add(subj)
        for rel, obj in rel_obj_list:
            triples.append((subj, rel, obj))
            all_nodes.add(obj)
    
    print(f"Loaded {len(triples)} triples from the graph store.")
    print(f"Found {len(all_nodes)} unique nodes/entities.")

    # 2. Initialize Stores and Storage Context
    # We will try to load existing stores if they have data, otherwise, they'll be new.
    
    # Graph Store
    graph_store = SimpleGraphStore()
    graph_store.upsert_triples(triples) # Populate with our loaded triples
    print("Graph store populated with triples.")

    # Vector Store (will be populated with embeddings)
    vector_store = SimpleVectorStore.from_persist_path(
        os.path.join(PERSIST_DIR, VECTOR_STORE_FILE)
    ) if os.path.exists(os.path.join(PERSIST_DIR, VECTOR_STORE_FILE)) and os.path.getsize(os.path.join(PERSIST_DIR, VECTOR_STORE_FILE)) > 2 else SimpleVectorStore()
    
    # Document Store (will store text of nodes)
    docstore = SimpleDocumentStore.from_persist_path(
        os.path.join(PERSIST_DIR, DOCSTORE_FILE)
    ) if os.path.exists(os.path.join(PERSIST_DIR, DOCSTORE_FILE)) and os.path.getsize(os.path.join(PERSIST_DIR, DOCSTORE_FILE)) > 2 else SimpleDocumentStore()

    storage_context = StorageContext.from_defaults(
        graph_store=graph_store,
        vector_store=vector_store,
        docstore=docstore,
        # index_store will be handled by persist/load_index_from_storage
    )
    print("Initialized stores and storage context.")

    # 3. Create LlamaIndex Documents from KG Nodes for Embedding
    # The KnowledgeGraphIndex itself doesn't directly embed the nodes from the graph_store
    # for vector search. It relies on Documents being embedded.
    # We create one Document per unique node in our KG.
    
    # Check if docs for nodes already exist to avoid re-embedding (simple check)
    existing_doc_texts = set()
    if not docstore.is_empty():
        for doc_id in docstore.get_all_document_hashes().values():
             try:
                existing_doc_texts.add(docstore.get_document(doc_id).text)
             except: # Handle potential errors if doc_id is somehow invalid
                pass
    
    documents_to_index = []
    for node_name in all_nodes:
        if node_name not in existing_doc_texts:
            documents_to_index.append(Document(text=node_name, metadata={"is_kg_node": True}))

    print(f"Created {len(documents_to_index)} new LlamaIndex Documents for KG nodes to be embedded.")
    if not documents_to_index and not vector_store.is_empty():
        print("All nodes seem to be already embedded and present in docstore/vector_store.")


    # 4. Build or Load the Knowledge Graph Index (and create embeddings)
    print("\n--- Building/Loading Knowledge Graph Index ---")
    
    # Try to load the index if an index_store.json exists and is valid
    index = None
    if os.path.exists(INDEX_STORE_FILE):
        try:
            # We need to load the specific index_id from the index_store.json
            with open(INDEX_STORE_FILE, "r") as f:
                index_store_data = json.load(f)
            index_ids = list(index_store_data.get("index_store/data", {}).keys())
            
            if index_ids:
                index_id = index_ids[0] # Assuming one main KG index
                print(f"Attempting to load index with ID: {index_id}")
                index = load_index_from_storage(
                    storage_context,
                    index_id=index_id
                )
                print("Successfully loaded existing KG Index from storage.")
                # If we loaded the index, and there are new documents, we might need to insert them
                if documents_to_index:
                    print(f"Inserting {len(documents_to_index)} new documents into existing index.")
                    for doc in documents_to_index:
                        index.insert(doc) # This will also handle embedding
            else:
                print("index_store.json found but no index IDs within. Will build a new index.")
        except Exception as e:
            print(f"Failed to load index from storage: {e}. Will build a new index.")
            # If loading fails, reset storage context for a clean build for certain components
            storage_context.docstore.delete_docs(list(storage_context.docstore.docs.keys()), raise_error=False) # Clear docstore if index load failed
            storage_context.vector_store = SimpleVectorStore() # Reset vector store

    if index is None:
        print("Building a new Knowledge Graph Index.")
        # When building from scratch, use all nodes as documents
        # This ensures embeddings are created for all entities in the KG
        # Note: `kg_triples_extract_fn` is set to lambda x: [] because we are providing triples manually
        # to the graph_store. We don't want the index to try and extract them from node names.
        # `include_embeddings=True` ensures embeddings are generated for the provided documents (nodes)
        # and stored in the vector_store.
        all_node_documents = [Document(text=node_name, metadata={"is_kg_node": True}) for node_name in all_nodes]
        index = KnowledgeGraphIndex(
            nodes=all_node_documents, # Documents to be embedded
            storage_context=storage_context,
            max_triplets_per_chunk=10, # Not strictly needed as we provide triples
            include_embeddings=True,   # This is key for creating embeddings for nodes
            kg_triple_extract_fn=lambda x: [], # Don't extract, use graph_store
            show_progress=True
        )
        print("New KG Index built. Embeddings for nodes should now be in the vector store.")

    # 5. Persist the updated stores
    # The storage_context.persist() method saves the docstore, vector_store, and index_store.
    # The graph_store (SimpleGraphStore) persists its data if a persist_path is given at init,
    # or you can call graph_store.persist() explicitly.
    # Here, we are primarily concerned with vector_store and index_store persistence.
    print("\n--- Persisting Stores ---")
    # index.storage_context.persist(persist_dir=PERSIST_DIR)
    # More robust: persist components individually if they have changes or need specific paths
    
    if hasattr(storage_context.docstore, 'persist'):
        storage_context.docstore.persist(persist_path=os.path.join(PERSIST_DIR, DOCSTORE_FILE))
        print(f"Persisted docstore to {DOCSTORE_FILE}")

    if hasattr(storage_context.vector_store, 'persist'):
        storage_context.vector_store.persist(persist_path=os.path.join(PERSIST_DIR, VECTOR_STORE_FILE))
        print(f"Persisted vector_store to {VECTOR_STORE_FILE}")
    
    if hasattr(storage_context.index_store, 'persist'):
         storage_context.index_store.persist(persist_path=os.path.join(PERSIST_DIR, INDEX_STORE_FILE))
         print(f"Persisted index_store to {INDEX_STORE_FILE}")
    else: # For KGIndex, the index_store is part of the index object itself often
        index.storage_context.persist(persist_dir=PERSIST_DIR) # This will save index_store.json
        print(f"Persisted index_store via index.storage_context.persist() to {INDEX_STORE_FILE}")

    # Verify vector store content
    if os.path.exists(VECTOR_STORE_FILE):
        with open(VECTOR_STORE_FILE, 'r') as f:
            vs_data = json.load(f)
        if vs_data.get("embedding_dict"):
            print(f"Vector store ({VECTOR_STORE_FILE}) contains {len(vs_data['embedding_dict'])} embeddings.")
        else:
            print(f"Vector store ({VECTOR_STORE_FILE}) appears to be empty or has no embeddings.")
    else:
        print(f"Vector store file ({VECTOR_STORE_FILE}) not found after persist.")


    # 6. Query the Knowledge Graph
    print("\n--- Querying the Knowledge Graph ---")
    
    # We can use a retriever that fetches keyword matches first, then expands via KG
    # For embedding-based retrieval of starting nodes, the default retriever behavior is often sufficient
    # if include_embeddings was True during indexing.
    
    # Example using the more explicit KnowledgeGraphQueryEngine
    # This engine can use embeddings if the retriever is configured to do so,
    # or use keyword matching.
    
    # Define a custom prompt for better KG interaction if needed
    # (using default for simplicity now)
    # graph_rag_retrieval_query_pt = PromptTemplate(
    #     "Use the knowledge graph to answer the query. \n"
    #     "Query: {query_str}\n"
    #     "Knowledge Graph Triples: \n {kg_feed_str}"
    # )

    query_engine = index.as_query_engine(
        include_text=False, # We want to see the graph traversal, not just node text
        response_mode="tree_summarize", # Good for KG
        embedding_mode="hybrid", # Uses both embeddings and keywords if available
        similarity_top_k=3, # How many embedding-similar nodes to start with
        # query_template=graph_rag_retrieval_query_pt # If using custom prompt
    )
    
    # For more direct control over KG part of retrieval:
    # from llama_index.core.retrievers import KnowledgeGraphRAGRetriever
    # kg_retriever = KnowledgeGraphRAGRetriever(
    #     storage_context=storage_context,
    #     verbose=True,
    #     graph_traversal_depth=2,
    #     max_knowledge_sequence=500,
    #     # You can also specify llm and embed_model here if not using Settings
    # )
    # You could then use this retriever within a RetrieverQueryEngine

    queries = [
        "What are the symptoms of a Seizure?",
        "What can cause Unconsciousness?",
        "Tell me about Low blood sugar.",
        "What is Headache a symptom of?",
    ]

    for query_str in queries:
        print(f"\nQuery: {query_str}")
        try:
            response = query_engine.query(query_str)
            print("Response:")
            # The response object can be complex.
            # If it's a string, print directly. If it has source_nodes, explore them.
            if hasattr(response, 'response'):
                print(response.response)
            else:
                print(str(response))

            if hasattr(response, 'source_nodes') and response.source_nodes:
                print("\nSource Triples/Nodes (if any):")
                for sn in response.source_nodes:
                    # KG source nodes might have metadata containing the triples
                    if "kg_rel_map" in sn.node.metadata: # For older LlamaIndex versions
                         print(f"  Node: {sn.node.text}, Score: {sn.score}, Relations: {sn.node.metadata['kg_rel_map']}")
                    elif "triplets" in sn.node.metadata: # More common for KG specific retrievers
                         print(f"  Node: {sn.node.text}, Score: {sn.score}, Triplets: {sn.node.metadata['triplets']}")
                    else:
                         print(f"  Node: {sn.node.text}, Score: {sn.score}, Metadata: {sn.node.metadata}")
            print("-" * 30)
        except Exception as e:
            print(f"Error querying '{query_str}': {e}")
            import traceback
            traceback.print_exc()

if __name__ == "__main__":
    main()

In [None]:
import os
import json
from llama_index.core import VectorStoreIndex, SimpleDirectoryReader

from llama_index.core.graph_stores.simple import SimpleGraphStore
import networkx as nx
import matplotlib.pyplot as plt

# Load your existing knowledge graph data
def load_graph_store():
    print("Loading existing knowledge graph...")
    try:
        with open(r'./storage_kg_custom_prompt/graph_store.json', 'r') as f:
            graph_data = json.load(f)
        
        # Create a SimpleGraphStore from the loaded data
        graph_store = SimpleGraphStore()
        graph_store.graph_dict = graph_data.get("graph_dict", {})
        print(f"Loaded graph with {len(graph_store.graph_dict)} primary nodes.")
        return graph_store
    except Exception as e:
        print(f"Error loading graph: {e}")
        return None

# Extract all nodes from the graph
def extract_all_nodes(graph_store):
    """Extract all unique nodes from the graph."""
    nodes = set()
    
    # Iterate through the graph and extract all nodes
    for source, rel_obj_list in graph_store.graph_dict.items(): # rel_obj_list is a list of [predicate, object]
        nodes.add(source)
        
        for predicate, obj in rel_obj_list: # Iterate through each [predicate, object] pair
            nodes.add(obj) # Add the object to the set of nodes
    
    print(f"Found {len(nodes)} unique nodes in the graph.")
    return nodes

# Generate embeddings for all nodes in the knowledge graph
def generate_embeddings(graph_store):
    print("Generating embeddings for knowledge graph nodes...")
    
    # Create an embedding model (you'll need to set OPENAI_API_KEY env variable)
    embed_model = embed_model = HuggingFaceEmbedding(
    model_name="BAAI/bge-m3"
    )
    
    # Extract the nodes from the graph
    nodes = extract_all_nodes(graph_store)
    
    # Generate embeddings for each node
    embedding_dict = {}
    for node in nodes:
        try:
            embedding = embed_model.get_text_embedding(node)
            embedding_dict[node] = embedding
            if len(embedding_dict) % 50 == 0:
                print(f"Generated embeddings for {len(embedding_dict)} nodes.")
        except Exception as e:
            print(f"Error generating embedding for '{node}': {e}")
    
    print(f"Successfully generated embeddings for {len(embedding_dict)} nodes.")
    return embedding_dict

# Save the embeddings to the vector store file
def save_embeddings(embedding_dict):
    print("Saving embeddings to vector_store.json...")
    try:
        # Initialize the vector store dictionary structure
        vector_store = {
            "embedding_dict": embedding_dict,
            "text_id_to_ref_doc_id": {},
            "metadata_dict": {}
        }
        
        # Save the vector store to the output file
        with open('vector_store.json', 'w') as f:
            json.dump(vector_store, f)
        
        print(f"Successfully saved embeddings for {len(embedding_dict)} nodes.")
        return True
    except Exception as e:
        print(f"Error saving embeddings: {e}")
        return False

# Create a knowledge graph index with embeddings for querying
def create_kg_index(graph_store, embedding_dict):
    
    print("Creating queryable knowledge graph index...")
    embed_model = HuggingFaceEmbedding(
        model_name="BAAI/bge-m3"
    )
    try:
        storage_context = StorageContext.from_defaults(graph_store=graph_store)
        
        # Create a KnowledgeGraphIndex
        kg_index = KnowledgeGraphIndex(
            nodes=[],  # Nodes are already in the graph_store
            storage_context=storage_context,
            embed_model=embed_model,
            include_embeddings=True,
        )
        
        print("Knowledge graph index created successfully.")
        return kg_index
    except Exception as e:
        print(f"Error creating knowledge graph index: {e}")
        return None

# Query the knowledge graph
def query_knowledge_graph(kg_index, query_text):
    print(f"Querying knowledge graph with: '{query_text}'")
    try:
        # Create a query engine
        query_engine = kg_index.as_query_engine(
            include_text=True,
            response_mode="tree_summarize"
        )
        
        # Execute the query
        response = query_engine.query(query_text)
        
        return response
    except Exception as e:
        print(f"Error querying knowledge graph: {e}")
        return None


# Main function
def main():
    # Set your OpenAI API key (or set it as an environment variable)
    # os.environ["OPENAI_API_KEY"] = "your-api-key"
    
    # Load the existing graph store
    graph_store = load_graph_store()
    if not graph_store:
        print("Failed to load graph store. Exiting.")
        return
    
    # Generate embeddings for the nodes
    print("Generating embeddings...")
    embedding_dict = generate_embeddings(graph_store)
    
    # Save the embeddings
    print("Saving embeddings...")
    success = save_embeddings(embedding_dict)
    if not success:
        print("Failed to save embeddings. Exiting.")
        return
    
    # Create a knowledge graph index for querying
    print("Creating knowledge graph index...")
    kg_index = create_kg_index(graph_store, embedding_dict)
    if not kg_index:
        print("Failed to create knowledge graph index. Exiting.")
        return
    print("Knowledge graph index created successfully.")
    
    # Example: Query the knowledge graph
    example_queries = [
        "What are the symptoms of a seizure?",
        "What requires immediate medical attention?",
        "What advice should be given for difficulty breathing?"
    ]
    
    for query in example_queries:
        response = query_knowledge_graph(kg_index, query)
        if response:
            print(f"\nQuery: {query}")
            print(f"Response: {response}")
            print("-" * 50)

if __name__ == "__main__":
    main()

Loading existing knowledge graph...
Loaded graph with 2462 primary nodes.
Generating embeddings...
Generating embeddings for knowledge graph nodes...
INFO:sentence_transformers.SentenceTransformer:Load pretrained SentenceTransformer: BAAI/bge-m3
Load pretrained SentenceTransformer: BAAI/bge-m3
INFO:sentence_transformers.SentenceTransformer:2 prompts are loaded, with the keys: ['query', 'text']
2 prompts are loaded, with the keys: ['query', 'text']
Found 6007 unique nodes in the graph.
Generated embeddings for 50 nodes.
Generated embeddings for 100 nodes.
Generated embeddings for 150 nodes.
Generated embeddings for 200 nodes.
Generated embeddings for 250 nodes.
Generated embeddings for 300 nodes.
Generated embeddings for 350 nodes.
Generated embeddings for 400 nodes.
Generated embeddings for 450 nodes.
Generated embeddings for 500 nodes.
Generated embeddings for 550 nodes.
Generated embeddings for 600 nodes.
Generated embeddings for 650 nodes.
Generated embeddings for 700 nodes.
Generat

In [None]:
#TODO:
"""
search node vector embeddings->convert to text->search graph store for text->return graph nodes
"""

In [84]:
import json
from collections import defaultdict
from llama_index.core.graph_stores.simple_labelled import SimplePropertyGraphStore
from llama_index.core.indices.property_graph import PropertyGraphIndex
from llama_index.core.schema import TextNode
from llama_index.core.vector_stores.simple import SimpleVectorStore
from llama_index.core import VectorStoreIndex
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.core.graph_stores.types import EntityNode, Relation

embed_model = HuggingFaceEmbedding(
    model_name="BAAI/bge-m3"
    )
# === Step 1: Load and Parse Graph ===
with open("storage_kg_custom_prompt\graph_store.json") as f:
    graph_data = json.load(f)["graph_dict"]

INFO:sentence_transformers.SentenceTransformer:Load pretrained SentenceTransformer: BAAI/bge-m3
Load pretrained SentenceTransformer: BAAI/bge-m3


  with open("storage_kg_custom_prompt\graph_store.json") as f:


INFO:sentence_transformers.SentenceTransformer:2 prompts are loaded, with the keys: ['query', 'text']
2 prompts are loaded, with the keys: ['query', 'text']


In [None]:


graph = SimplePropertyGraphStore()
node_set = set()

print(node_set)
# === Step 2: Build Graph ===
for subject, triples in graph_data.items():
    for predicate, obj in triples:
        print(f" {subject} -> {predicate} -> {obj}")
        if subject not in node_set:
            graph.upsert_nodes([EntityNode(name=subject, label="subject")])
            node_set.add(subject)
        if obj not in node_set:
            graph.upsert_nodes([EntityNode(name=obj, label="object")])
            node_set.add(obj)
        graph.upsert_relations([Relation(label=predicate, source_id=subject, target_id=obj)])
        
        
print(len(node_set))


In [None]:
for subject, triples in graph_data.items():
    for predicate, obj in triples:
        object = str(obj)
        subject = str(subject)
        predicate = str(predicate)
        
        graph.upsert_nodes([EntityNode(name=subject, label="subject")])
        graph.upsert_nodes([EntityNode(name=object, label="object")])

        relation = Relation(label=predicate, source_id=subject, target_id=object)
        print(f"Adding relation: {relation}")
            
        graph.upsert_relations([Relation(label=predicate, source_id=subject, target_id=object)])

In [112]:
graph = SimplePropertyGraphStore()
for subject, triples in graph_data.items():
    # Create subject node first
    subject_str = str(subject)
    # Generate the embeddings for the subject node
    subject_embedding = embed_model.get_text_embedding(subject_str)
    # Create the subject node with the embedding
    graph.upsert_nodes([EntityNode(id=subject_str, name=subject_str, label="subject", embedding=subject_embedding)])   
    for predicate, obj in triples:
        object_str = str(obj)
        predicate_str = str(predicate)
        
        # Create object node
        object_embedding = embed_model.get_text_embedding(object_str)
        graph.upsert_nodes([EntityNode(id=object_str, name=object_str, label="object",embedding=object_embedding)])
        
        # Create relation using the IDs
        relation = Relation(label=predicate_str, source_id=subject_str, target_id=object_str )
        graph.upsert_relations([relation])

# When querying, pass parameters to get_triplets
# For example, to get all triplets:
# triplets = graph.get_triplets(entity_names=[str(s) for s in graph_data.keys()])
# triplets


In [None]:
# Persist the storage context
graph.persist(persist_path="./storage_with_embeddings_in_kg.json")


In [121]:
# Create a storage context
storage_context = StorageContext.from_defaults(graph_store=graph)
# Persist the storage context
storage_context.persist(persist_dir="./storage")

In [None]:
# Load the storage context
storage_context = StorageContext.from_defaults(persist_dir="./storage")

# Access the graph store
graph_store = storage_context.graph_store


In [None]:
#### TODO: how to match query embeddings to graph nodes (embeddings) and return the graph nodes?


# --- Example Querying ---
# Assuming you have a query embedding
# You can use the vector store to find similar nodes based on embeddings
# For example, if you have a query embedding:
query_embedding = embed_model.get_text_embedding("What are the symptoms of a seizure?")
# You can use the vector store to find similar nodes based on embeddings
vector_store = SimpleVectorStore()
# Assuming you have a vector store with the embeddings of your nodes
similar_nodes = vector_store.similarity_search(query_embedding, k=5)  # Get top 5 similar nodes
# This will return the nodes that are most similar to the query embedding
# You can then use these nodes to find related information in your graph
# For example, if you have a function to get related nodes:
def get_related_nodes(node_id):
    # This function should return the related nodes from your graph
    return graph.get_related_nodes(node_id)
# You can then use the similar nodes to get related information


In [None]:
### check this: https://docs.llamaindex.ai/en/stable/examples/property_graph/property_graph_custom_retriever/

In [114]:
# First, get the nodes you want to explore relationships for
subject_nodes = graph.get(ids=["Important Information to the Caller"])  # or use properties to filter


# Then pass these nodes to get_rel_map
if subject_nodes:  # Make sure you have nodes before calling get_rel_map
    relations_map = graph.get_rel_map(
        graph_nodes=subject_nodes,
        depth=2,  # How many hops to explore
        limit=30  # Maximum number of triplets to return
    )
    print(f"Found {len(relations_map)} relationships")
    print("Relationships:")
    for rel in relations_map:
        print(rel)
else:
    print("No nodes found to map relationships")

Found 30 relationships
Relationships:
(EntityNode(label='subject', embedding=[0.016924625262618065, 0.03575928136706352, -0.05892634019255638, 0.004378217272460461, -0.018147598952054977, -0.04115688428282738, 0.006678886711597443, -0.025580881163477898, -0.0076911235228180885, 0.01850573904812336, 0.014300176873803139, 0.0041771866381168365, -0.010296151041984558, -0.0063529484905302525, -0.00010613004269544035, -0.022726163268089294, -0.005216141231358051, -0.02079581469297409, -0.01584470644593239, 0.005753336939960718, -0.01436888612806797, 0.03183547034859657, -0.025949664413928986, 0.003342696465551853, -0.017934715375304222, 0.021911881864070892, -0.012286962009966373, -0.007130986545234919, -0.019872350618243217, -0.008588476106524467, -0.0066472003236413, 0.00037115486338734627, 0.013172182254493237, -0.009251204319298267, -0.009280216880142689, -0.04178086668252945, 0.003872290253639221, 0.0049677323549985886, -0.06649777293205261, 0.006946249399334192, 0.00397747615352273, -

In [12]:
import json

json_file_path ="storage_with_embeddings_in_kg.json"
def print_json_without_embeddings(json_file_path):
    # Read the JSON file
    with open(json_file_path, 'r') as file:
        data = json.load(file)
    
    # Function to process the JSON and hide embeddings
    def process_object(obj):
        if isinstance(obj, dict):
            result = {}
            for key, value in obj.items():
                if key == "embedding":
                    result[key] = "[... embedding array hidden ...]"
                else:
                    result[key] = process_object(value)
            return result
        elif isinstance(obj, list):
            return [process_object(item) for item in obj]
        else:
            return obj
    
    # Process and print the data
    processed_data = process_object(data)
    print(json.dumps(processed_data, indent=2))

# Example usage
print_json_without_embeddings(json_file_path)

{
  "nodes": {
    "Important Information to the Caller": {
      "label": "subject",
      "embedding": "[... embedding array hidden ...]",
      "properties": {},
      "name": "Important Information to the Caller"
    },
    "Tell me immediately if anything changes": {
      "label": "subject",
      "embedding": "[... embedding array hidden ...]",
      "properties": {},
      "name": "Tell me immediately if anything changes"
    },
    "Watch the person all the time": {
      "label": "object",
      "embedding": "[... embedding array hidden ...]",
      "properties": {},
      "name": "Watch the person all the time"
    },
    "Help is on the way": {
      "label": "object",
      "embedding": "[... embedding array hidden ...]",
      "properties": {},
      "name": "Help is on the way"
    },
    "Inform immediately if anything changes": {
      "label": "object",
      "embedding": "[... embedding array hidden ...]",
      "properties": {},
      "name": "Inform immediately if 

In [14]:
from dotenv import load_dotenv
from llama_index.llms.azure_openai import AzureOpenAI
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
load_dotenv()


endpoint = "https://d-ais-eus-ais-chatbots.openai.azure.com/"
model_name = "o1-mini"
deployment = "o1-mini"
subscription_key = os.getenv("AZURE_OPENAI_API_KEY")
api_version = "2024-12-01-preview" # Use a valid API version

llm = AzureOpenAI(
    azure_endpoint=endpoint,
    api_key=subscription_key,
    api_version=api_version,
    deployment_name=deployment,
    model_name=model_name,
    temperature=1.0
)

embed_model = HuggingFaceEmbedding(
    model_name="BAAI/bge-m3"
    )

Settings.embed_model = embed_model
Settings.llm = None
Settings.chunk_size = 512
Settings.chunk_overlap = 100

llm2 = AzureOpenAI(
    azure_endpoint=endpoint,
    api_key=subscription_key,
    api_version="2024-05-01-preview",
    deployment_name="gpt-4o-mini-test",
    model_name="gpt-4o-mini-test",
    temperature=1.0
)

INFO:sentence_transformers.SentenceTransformer:Load pretrained SentenceTransformer: BAAI/bge-m3
Load pretrained SentenceTransformer: BAAI/bge-m3


INFO:sentence_transformers.SentenceTransformer:2 prompts are loaded, with the keys: ['query', 'text']
2 prompts are loaded, with the keys: ['query', 'text']
LLM is explicitly disabled. Using MockLLM.


In [104]:
# === Step 3: Create PropertyGraphIndex ===
pg_index = PropertyGraphIndex.from_existing(
    property_graph_store=graph,
    llm=None,           # Optional: Specify an LLM for path extraction
    vector_store=None,  # Optional: Add a vector store if needed
    embed_model=embed_model,   # Optional: Specify an embedding model
    show_progress=True, # Optional: Enable progress bars
)


In [105]:
# use
retriever = pg_index.as_retriever(
    include_text=True,  # include source chunk with matching paths
    # similarity_top_k=2,  # top k for vector kg node retrieval
)
nodes = retriever.retrieve("mental issue")
for node in nodes:
    print(node.text)

Given birth previously, more than 5 minutes between contractions -> is_sub_criteria_of -> Full-term (week 37-40)
Acute Coronary Disease -> has_less_characteristic_symptoms -> Vague Symptoms
Acute Functional Impairment -> is characterized by -> Sequelae after diseases and injuries
Strong contractions -> has_code -> 1.2
AMK -> can assist -> Informant in notifying police
Low threshold for red response -> applies_to -> Severe conditions
Epileptics -> develop -> Cerebral Failure over Years
Assistance request -> must include -> Patient's condition
BRØSET Violence Checklist -> has_symptom -> Physical threats
Suspect criminal act -> has_code -> LVI
Emergency Medical Response -> includes_action -> Triple alerting / SAR alerting
Sudden Vision Impairment / Loss of Vision -> requires_advice -> This must be checked immediately by a doctor
Hernia -> can lead to -> Intestine trapped and compromised blood supply
Continuing Fever -> requires_advice -> Avoid overheating the person. Remove clothes if nec

In [51]:
# === Step 4: Convert nodes and edges to text for embedding ===
# Build a list of Document objects for embedding
documents = []
# for graph nodes
for node in graph.graph.nodes.values():
    doc_text = f" {node.name}"
    documents.append(Document(text=doc_text, doc_id=node.id))
# for graph edges
for rel in graph.graph.relations.values():
    edge_text = f"Edge from {rel.source_id} to {rel.target_id} with type: {rel.label}"
    edge_id = f"{rel.source_id}_{rel.label}_{rel.target_id}"
    documents.append(Document(text=edge_text, doc_id=edge_id))


In [55]:
documents[7550].text

'Edge from Normal blood sugar level to 4–10 mmol/L with type: is'

In [56]:
# === Step 5: Embed using OpenAI ===
vector_index = VectorStoreIndex.from_documents(documents, embed_model=embed_model)

# === Step 7: Save the Vector Store ===
vector_index.storage_context.persist(persist_dir="vector_store_all.json")

In [59]:


# === Step 6: Query Example ===
query_engine = vector_index.as_query_engine()
response = query_engine.query("fall from ladder")
print(response)


Context information is below.
---------------------
fall from standing height

falls from own height
---------------------
Given the context information and not prior knowledge, answer the query.
Query: fall from ladder
Answer: 
