In [2]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [3]:
import networkx as nx

In [5]:
from pydantic import BaseModel, Field
from typing import List, Dict, Any, Optional, Set
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
from transformers import AutoTokenizer
import json

  from .autonotebook import tqdm as notebook_tqdm


vLLM Client: Pydantic schemas for structured LLM output

In [6]:
from pydantic import BaseModel, Field
from typing import List

# One step of the LLM-generated query plan, plus the kinds of information that step needs.
# Nested inside QueryPlan.reasoning_steps. find_initial_nodes() uses required_info as search keywords.
class ReasoningStep(BaseModel):
    step: str = Field(..., description="A reasoning step in the planning process")
    required_info: List[str] = Field(..., description="Types of information needed for this step")

# Structured output of IterativeKnowledgeGraphAgent.create_query_plan().
# key_concepts and each step's required_info drive the keyword search in find_initial_nodes().
# Keep in sync with the schema string embedded in create_query_plan()'s prompt.
class QueryPlan(BaseModel):
    reasoning_steps: List[ReasoningStep] = Field(..., description="Step-by-step plan to answer the query")
    key_concepts: List[str] = Field(..., description="Key concepts that need to be found in the knowledge graph")
    search_strategy: str = Field(..., description="Strategy for searching the knowledge graph")
    expected_answer_type: str = Field(..., description="What type of answer is expected (causal, descriptive, comparative, etc.)")

# One piece of information extracted from a single graph node by extract_information_from_node().
# answer_question() keeps entries only when relevance_score > 0.3.
# Keep in sync with the schema string embedded in extract_information_from_node()'s prompt.
class NotebookEntry(BaseModel):
    source_node_id: str = Field(..., description="ID of the node this information came from")
    information: str = Field(..., description="Key information extracted from the node")
    # Range 0-1 is requested in the description only; it is not enforced by the model.
    relevance_score: float = Field(..., description="How relevant this information is (0-1)")
    information_type: str = Field(..., description="Type of information (causal, descriptive, statistical, etc.)")

# Structured output of decide_next_exploration(): whether to keep exploring and where to go next.
# answer_question() reads should_continue, next_nodes_to_explore and exploration_strategy.
# Keep in sync with the schema string embedded in decide_next_exploration()'s prompt.
class ExplorationDecision(BaseModel):
    should_continue: bool = Field(..., description="Whether to continue exploring")
    reasoning: str = Field(..., description="Reasoning for the decision")
    next_nodes_to_explore: List[str] = Field(default=[], description="Specific node IDs to explore next")
    # Free-form text; the caller branches on whether it mentions "neighbors".
    exploration_strategy: str = Field(..., description="How to explore next (neighbors, keywords, specific_nodes)")
    information_gaps: List[str] = Field(default=[], description="What information is still needed")

# Structured output of generate_final_answer(), synthesized from the exploration notebook.
# Keep in sync with the schema string embedded in generate_final_answer()'s prompt.
class FinalAnswer(BaseModel):
    reasoning_steps: List[str] = Field(..., description="Step-by-step reasoning using gathered information")
    answer: str = Field(..., description="Final comprehensive answer to the question")
    # Range 0-1 is requested in the description only; it is not enforced by the model.
    confidence: float = Field(..., description="Confidence score (0-1) in the answer")
    sources: List[str] = Field(..., description="Node IDs used as sources for the answer")
    information_completeness: float = Field(..., description="How complete the gathered information is (0-1)")

# Schema text for a simpler Reasoning_Step/Answer format.
# NOTE(review): not referenced by any visible cell; possibly left over from an earlier prompt.
# Confirm before removing. Also note that `conclusion` is typed bool although its
# description says "final conclusion or answer".
answerString = \
"""
class Reasoning_Step(BaseModel):
    reasoning_step: str = Field(..., description="An intermediate reasoning step for breaking down the given context and query")

class Answer(BaseModel):
    reasoning: List[Reasoning_Step] = Field(..., description="List of reasoning steps")
    conclusion: bool = Field(..., description="The culminating final conclusion or answer to the question")
"""    

In [7]:
# from sentence_transformers import SentenceTransformer

class IterativeKnowledgeGraphAgent:
    """Answer questions by iteratively exploring a knowledge graph with an LLM in the loop.

    Per question (see ``answer_question``):
      1. ask the LLM for a ``QueryPlan``;
      2. rank graph nodes by keyword overlap with the plan to seed exploration;
      3. repeatedly extract ``NotebookEntry`` objects from nodes and let the LLM decide
         (``ExplorationDecision``) whether and where to keep exploring;
      4. synthesize a ``FinalAnswer`` from the collected notebook.

    Every LLM call temporarily assigns the expected Pydantic model to
    ``vllm_client.schema`` and restores the previous value afterwards (see ``_call_llm``).
    """

    def __init__(self, gml_file_path: str, vllm_client, tokenizer_name: str = "Qwen/Qwen2.5-7B-Instruct-AWQ", max_iterations: int = 5):
        """
        Initialize the iterative Knowledge Graph QA agent.

        Args:
            gml_file_path: Path to the GML file, loaded with ``nx.read_gml``.
            vllm_client: Callable LLM client with a mutable ``schema`` attribute. It is called
                as ``vllm_client(prompt, sampling_params=...)``. Its return value is used as an
                instance of the model currently assigned to ``schema`` (assumed to be parsed
                by the client; verify against VLLMClient).
            tokenizer_name: Hugging Face tokenizer whose chat template formats the prompts.
            max_iterations: Maximum number of exploration iterations per question.
        """
        self.graph = nx.read_gml(gml_file_path)
        self.vllm_client = vllm_client
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
        self.max_iterations = max_iterations

        # Per-question state; cleared by reset_agent_state().
        self.notebook: List[NotebookEntry] = []
        self.explored_nodes: Set[str] = set()
        self.current_iteration = 0

        self.entities = {}
        self.documents = {}
        self._index_nodes()
        self.query_embedding = None
        # Call the builder; the original stored the bound method itself, not the index.
        self.community = self._build_community_index()
        # self.embedder = SentenceTransformer('all-MiniLM-L6-v2')

    # def embed(self, text: str):
    #     return self.embedder.encode(text)

    @staticmethod
    def _node_labels(node_data: Dict[str, Any]) -> List[str]:
        """Return a node's ``labels`` attribute as a list.

        GML has no native list type, so ``nx.read_gml`` may produce a bare scalar for a
        node with a single label. A membership test on a bare string is a substring test
        (``'Person' in 'Personality'`` is True), so scalars are wrapped in a list.
        """
        labels = node_data.get('labels', [])
        if isinstance(labels, (list, tuple, set)):
            return list(labels)
        return [labels]

    def _index_nodes(self):
        """Index nodes by their types (entity vs. document) for efficient retrieval."""
        for node_id, node_data in self.graph.nodes(data=True):
            labels = self._node_labels(node_data)

            if '__Entity__' in labels or 'Person' in labels:
                self.entities[node_id] = node_data
            elif 'Document' in labels:
                self.documents[node_id] = node_data

    def _build_community_index(self) -> Dict[str, List[str]]:
        """Map each community node (edge target) to the nodes linked to it by ``IN_COMMUNITY`` edges.

        Assumes the relationship type is stored in the edge attribute ``type``; verify
        against the GML export.
        """
        community_map = {}
        for source, target, data in self.graph.edges(data=True):
            if data.get("type") == "IN_COMMUNITY":
                community_map.setdefault(target, []).append(source)
        return community_map

    def _all_neighbors(self, node_id: str) -> List[str]:
        """Return neighbors of ``node_id`` in both edge directions, deduplicated, in discovery order.

        On a directed graph ``graph.neighbors`` yields successors only.
        ``nx.all_neighbors`` also yields predecessors (and equals ``neighbors`` on
        undirected graphs).
        """
        return list(dict.fromkeys(nx.all_neighbors(self.graph, node_id)))

    def _create_prompt(self, system_message: str, user_message: str, schema: str) -> str:
        """Create a chat-formatted prompt string for the LLM, appending the schema to the system message."""
        return self.tokenizer.apply_chat_template(
            [
                {"role": "system", "content": f"{system_message}\n\nYou MUST adhere to this schema:\n{schema}"},
                {"role": "user", "content": user_message},
            ],
            tokenize=False,
            add_bos=True,
            add_generation_prompt=True,
        )

    def _call_llm(self, schema_cls, prompt: str, sampling_params: Dict[str, Any]):
        """Call the LLM with ``schema_cls`` as the client's schema and return the client's result.

        The client's previous schema is restored in ``finally``, so an exception raised
        by the client no longer leaves it pointing at the wrong model.
        """
        original_schema = self.vllm_client.schema
        self.vllm_client.schema = schema_cls
        try:
            return self.vllm_client(prompt, sampling_params=sampling_params)
        finally:
            self.vllm_client.schema = original_schema

    def reset_agent_state(self):
        """Reset the agent's per-question state (notebook, explored nodes, iteration counter)."""
        self.notebook = []
        self.explored_nodes = set()
        self.current_iteration = 0
        # self.query_embedding = None

    def create_query_plan(self, question: str) -> QueryPlan:
        """Ask the LLM for a strategic ``QueryPlan`` for answering the given question."""
        system_message = """You are an expert knowledge graph exploration agent. Create a systematic plan 
        for answering questions using iterative graph exploration. Focus on what information you need to find 
        and how to search for it effectively."""
        
        user_message = f"""
        Question: {question}
        
        Create a detailed exploration plan for this question. The knowledge graph contains:
        - Entities: Specific concepts, people, conditions, treatments, etc.
        - Documents: Research papers and larger text chunks  
        - Relationships: CAUSES, IS_ASSOCIATED_WITH, MENTIONS, IN_COMMUNITY
        
        Your plan should guide iterative exploration to gather comprehensive information.
        """
        
        schema = """
        class ReasoningStep(BaseModel):
            step: str = Field(..., description="A reasoning step in the planning process")
            required_info: List[str] = Field(..., description="Types of information needed for this step")

        class QueryPlan(BaseModel):
            reasoning_steps: List[ReasoningStep] = Field(..., description="Step-by-step plan to answer the query")
            key_concepts: List[str] = Field(..., description="Key concepts that need to be found in the knowledge graph")
            search_strategy: str = Field(..., description="Strategy for searching the knowledge graph")
            expected_answer_type: str = Field(..., description="What type of answer is expected (causal, descriptive, comparative, etc.)")
        """
        
        prompt = self._create_prompt(system_message, user_message, schema)
        return self._call_llm(QueryPlan, prompt, {
            "n": 1, "min_tokens": 100, "max_tokens": 800, "temperature": 0.1
        })

    def find_initial_nodes(self, plan: QueryPlan, top_k: int = 10,
                           extra_keywords: Optional[List[str]] = None,
                           exclude: Optional[Set[str]] = None) -> List[str]:
        """Rank graph nodes by case-insensitive keyword occurrences and return the top IDs.

        Keywords are the plan's key concepts plus every step's ``required_info``, plus
        ``extra_keywords`` if given (e.g. an ExplorationDecision's information gaps).
        Each keyword's occurrence count in the node's description/text/summary/
        full_content is added to the node's score.

        Args:
            plan: Query plan supplying the base keywords.
            top_k: Maximum number of node IDs to return.
            extra_keywords: Optional additional search keywords.
            exclude: Optional node IDs to skip (e.g. already-explored nodes).

        Returns:
            Up to ``top_k`` node IDs with a positive score, highest score first.
        """
        keywords = list(plan.key_concepts)
        for step in plan.reasoning_steps:
            keywords.extend(step.required_info)
        if extra_keywords:
            keywords.extend(extra_keywords)

        # Drop blank keywords: ''.count('') == 1 and text.count('') == len(text) + 1,
        # which would otherwise score every node by its text length.
        keywords_lower = [kw.lower() for kw in keywords if kw and kw.strip()]
        if not keywords_lower:
            return []
        # if self.query_embedding is None:
        #     query_text = " ".join(plan.key_concepts + sum((step.required_info for step in plan.reasoning_steps), []))
        #     self.query_embedding = np.array(self.embed(query_text)).reshape(1, -1)

        relevant_nodes = []
        for node_id, node_data in self.graph.nodes(data=True):
            if exclude and node_id in exclude:
                continue

            searchable_text = " ".join(
                str(node_data[field])
                for field in ('description', 'text', 'summary', 'full_content')
                if field in node_data
            ).lower()

            score = sum(searchable_text.count(keyword) for keyword in keywords_lower)

            # # add cosine similarity if embeddings are available
            # if 'embedding' in node_data:
            #     embedding = np.array(node_data['embedding']).reshape(1, -1)
            #     sim = cosine_similarity(self.query_embedding, embedding)[0][0]
            #     score += sim * 5

            if score > 0:
                relevant_nodes.append((node_id, score))

        relevant_nodes.sort(key=lambda x: x[1], reverse=True)
        return [node_id for node_id, _ in relevant_nodes[:top_k]]

    def extract_information_from_node(self, node_id: str, question: str, plan: QueryPlan) -> Optional[NotebookEntry]:
        """Have the LLM extract question-relevant information from one graph node.

        Returns:
            A ``NotebookEntry`` whose ``source_node_id`` is forced to ``node_id``. Returns
            None if the node is not in the graph or the LLM call raised (the error is printed).
        """
        if node_id not in self.graph:
            return None

        node_data = self.graph.nodes[node_id]

        node_info = f"Node ID: {node_id}\n"
        node_info += f"Labels: {node_data.get('labels', [])}\n"

        if 'description' in node_data:
            node_info += f"Description: {node_data['description']}\n"
        if 'text' in node_data:
            # Stringify once so slicing and the length check agree for non-str values.
            text = str(node_data['text'])
            node_info += f"Text: {text[:1000]}{'...' if len(text) > 1000 else ''}\n"
        if 'summary' in node_data:
            node_info += f"Summary: {node_data['summary']}\n"
        if 'full_content' in node_data:
            content = str(node_data['full_content'])
            node_info += f"Full Content: {content[:500]}{'...' if len(content) > 500 else ''}\n"

        # Count neighbors in both directions (successors only would undercount on a DiGraph).
        neighbors = self._all_neighbors(node_id)
        if neighbors:
            node_info += f"Connected to {len(neighbors)} other nodes\n"

        system_message = """You are an expert information extractor. Extract the most relevant and useful 
        information from the given node that helps answer the question. Focus on key facts, relationships, 
        and insights."""
        
        user_message = f"""
        Question: {question}
        Query Plan: {plan.model_dump_json(indent=2)}
        
        Node Information:
        {node_info}
        
        Extract the most relevant information from this node. Determine its relevance score and information type.
        """
        
        schema = """
        class NotebookEntry(BaseModel):
            source_node_id: str = Field(..., description="ID of the node this information came from")
            information: str = Field(..., description="Key information extracted from the node")
            relevance_score: float = Field(..., description="How relevant this information is (0-1)")
            information_type: str = Field(..., description="Type of information (causal, descriptive, statistical, etc.)")
        """
        
        prompt = self._create_prompt(system_message, user_message, schema)

        try:
            result = self._call_llm(NotebookEntry, prompt, {
                "n": 1, "min_tokens": 50, "max_tokens": 400, "temperature": 0.1
            })
        except Exception as e:
            print(f"Error extracting from node {node_id}: {e}")
            return None

        if result is not None:
            # The LLM may echo a wrong or invented ID; the entry must cite the node actually read.
            result.source_node_id = node_id
        return result

    def decide_next_exploration(self, question: str, plan: QueryPlan) -> ExplorationDecision:
        """Ask the LLM whether to continue exploring and what to explore next."""
        
        notebook_summary = "\n".join([
            f"- {entry.information} (relevance: {entry.relevance_score:.2f}, type: {entry.information_type})"
            for entry in self.notebook
        ])
        
        system_message = """You are an expert research agent. Based on the information gathered so far, 
        decide whether you have enough information to answer the question or if you need to explore more. 
        If exploring more, specify what nodes or areas to focus on next."""
        
        user_message = f"""
        Question: {question}
        Query Plan: {plan.model_dump_json(indent=2)}
        Current Iteration: {self.current_iteration + 1}/{self.max_iterations}
        
        Information Gathered So Far:
        {notebook_summary if notebook_summary else "No information gathered yet"}
        
        Explored Nodes: {list(self.explored_nodes)}
        
        Should you continue exploring? If yes, what should you explore next?
        """
        
        schema = """
        class ExplorationDecision(BaseModel):
            should_continue: bool = Field(..., description="Whether to continue exploring")
            reasoning: str = Field(..., description="Reasoning for the decision")
            next_nodes_to_explore: List[str] = Field(default=[], description="Specific node IDs to explore next")
            exploration_strategy: str = Field(..., description="How to explore next (neighbors, keywords, specific_nodes)")
            information_gaps: List[str] = Field(default=[], description="What information is still needed")
        """
        
        prompt = self._create_prompt(system_message, user_message, schema)
        return self._call_llm(ExplorationDecision, prompt, {
            "n": 1, "min_tokens": 100, "max_tokens": 600, "temperature": 0.2
        })

    def get_neighbor_nodes(self, node_ids: List[str], max_neighbors: int = 15) -> List[str]:
        """Return up to ``max_neighbors`` unexplored neighbors of ``node_ids``.

        Edges are followed in both directions (so this also works on a directed GML
        graph). Results keep discovery order, which makes the truncation deterministic
        (the original truncated an unordered set). IDs not in the graph are ignored.
        """
        candidates: Dict[str, None] = {}
        for node_id in node_ids:
            if node_id not in self.graph:
                continue
            for neighbor in self._all_neighbors(node_id):
                if neighbor not in self.explored_nodes:
                    candidates.setdefault(neighbor, None)
        return list(candidates)[:max_neighbors]

    def generate_final_answer(self, question: str, plan: QueryPlan) -> FinalAnswer:
        """Generate the final answer from all notebook entries, listed by descending relevance."""
        
        sorted_entries = sorted(self.notebook, key=lambda x: x.relevance_score, reverse=True)
        
        notebook_content = ""
        for i, entry in enumerate(sorted_entries, 1):
            notebook_content += f"{i}. Source: Node {entry.source_node_id}\n"
            notebook_content += f"   Information: {entry.information}\n"
            notebook_content += f"   Type: {entry.information_type}, Relevance: {entry.relevance_score:.2f}\n\n"
        
        system_message = """You are an expert researcher synthesizing information to provide a comprehensive answer. 
        Use all the gathered information from your notebook to construct a well-reasoned, complete response."""
        
        user_message = f"""
        Question: {question}
        Query Plan: {plan.model_dump_json(indent=2)}
        
        Information from Knowledge Graph Exploration:
        {notebook_content}
        
        Total iterations completed: {self.current_iteration}
        Total nodes explored: {len(self.explored_nodes)}
        
        Provide a comprehensive answer with clear reasoning steps, confidence assessment, and completeness evaluation.
        """
        
        schema = """
        class FinalAnswer(BaseModel):
            reasoning_steps: List[str] = Field(..., description="Step-by-step reasoning using gathered information")
            answer: str = Field(..., description="Final comprehensive answer to the question")
            confidence: float = Field(..., description="Confidence score (0-1) in the answer")
            sources: List[str] = Field(..., description="Node IDs used as sources for the answer")
            information_completeness: float = Field(..., description="How complete the gathered information is (0-1)")
        """
        
        prompt = self._create_prompt(system_message, user_message, schema)
        return self._call_llm(FinalAnswer, prompt, {
            "n": 1, "min_tokens": 300, "max_tokens": 2000, "temperature": 0.1
        })

    def _select_next_nodes(self, decision: ExplorationDecision, plan: QueryPlan) -> List[str]:
        """Turn an ExplorationDecision into concrete candidate node IDs.

        Priority:
          1. LLM-suggested IDs that actually exist in the graph;
          2. if the strategy mentions neighbors, neighbors of high-relevance notebook
             sources (or of the last three sources);
          3. otherwise a fresh keyword search that adds the decision's information gaps
             and skips already-explored nodes.
        """
        # The prompt lists only explored IDs, so suggestions may be invented; keep real nodes only.
        suggested = [n for n in decision.next_nodes_to_explore if n in self.graph]
        if suggested:
            return suggested

        # Accept free-form variants such as "Neighbors" or "explore neighbors".
        strategy = (decision.exploration_strategy or "").lower()
        if "neighbor" in strategy:
            high_relevance_nodes = [entry.source_node_id for entry in self.notebook
                                    if entry.relevance_score > 0.7]
            seeds = high_relevance_nodes or [entry.source_node_id for entry in self.notebook[-3:]]
            return self.get_neighbor_nodes(seeds)

        # Previously the same top-5 plan keywords were re-ranked, returning only explored nodes.
        return self.find_initial_nodes(
            plan,
            top_k=5,
            extra_keywords=decision.information_gaps,
            exclude=self.explored_nodes,
        )

    def answer_question(self, question: str) -> Dict[str, Any]:
        """
        Complete iterative pipeline to answer a question using the knowledge graph.

        Args:
            question: Natural-language question.

        Returns:
            Dict with keys ``question``, ``plan`` (QueryPlan), ``exploration_log``
            (list of ExplorationDecision), ``notebook`` (list of NotebookEntry),
            ``explored_nodes``, ``iterations_completed`` and ``final_answer`` (FinalAnswer).
        """
        print(f"Starting iterative exploration for: {question}")

        self.reset_agent_state()

        # Step 1: Create query plan
        print("Creating query plan...")
        plan = self.create_query_plan(question)
        print(f"Plan created with {len(plan.reasoning_steps)} steps")

        # Step 2: Find initial nodes
        print("  Finding initial nodes...")
        initial_nodes = self.find_initial_nodes(plan, top_k=8)
        print(f"   Found {len(initial_nodes)} initial nodes to explore")

        exploration_log = []

        # Step 3: Iterative exploration
        while self.current_iteration < self.max_iterations:
            print(f"\nIteration {self.current_iteration + 1}/{self.max_iterations}")

            # Determine nodes to explore this iteration
            if self.current_iteration == 0:
                nodes_to_explore = initial_nodes
            else:
                decision = self.decide_next_exploration(question, plan)
                exploration_log.append(decision)

                if not decision.should_continue:
                    print(f"Agent decided to stop: {decision.reasoning}")
                    break

                nodes_to_explore = self._select_next_nodes(decision, plan)

            nodes_to_explore = [n for n in nodes_to_explore if n not in self.explored_nodes]

            if not nodes_to_explore:
                print("No new nodes to explore")
                break

            print(f"Exploring {len(nodes_to_explore)} nodes...")

            # Extract information from nodes
            for node_id in nodes_to_explore[:5]:  # Limit to 5 nodes per iter
                # Re-check: the candidate list may contain duplicates.
                if node_id not in self.explored_nodes:
                    entry = self.extract_information_from_node(node_id, question, plan)
                    if entry and entry.relevance_score > 0.3:
                        self.notebook.append(entry)
                        print(f"     Added info from node {node_id} (relevance: {entry.relevance_score:.2f})")

                    self.explored_nodes.add(node_id)

            self.current_iteration += 1

        print("\n Generating final answer...")
        final_answer = self.generate_final_answer(question, plan)

        print(f" Exploration Summary:")
        print(f"   - Total iterations: {self.current_iteration}")
        print(f"   - Nodes explored: {len(self.explored_nodes)}")
        print(f"   - Information gathered: {len(self.notebook)} entries")
        print(f"   - Final confidence: {final_answer.confidence:.2f}")

        return {
            "question": question,
            "plan": plan,
            "exploration_log": exploration_log,
            "notebook": self.notebook,
            "explored_nodes": list(self.explored_nodes),
            "iterations_completed": self.current_iteration,
            "final_answer": final_answer
        }




Example

In [10]:
from langchain_community.graphs.graph_document import GraphDocument
from langchain_core.documents import Document
from retry import retry
from concurrent.futures import ThreadPoolExecutor, as_completed
from tqdm import tqdm
from prompts import *
from vllm_client import VLLMClient

# llm_transformer = llm.LLMGraphTransformer(
#     model=None,
#     prompt=ontology_prompt, # declared up top
#     allowed_relationships=["CAUSES", "IS_ASSOCIATED_WITH"],
#     node_properties=["description"],
#     relationship_properties=["description"]
# )


In [11]:
# Shared LLM client. It starts with no schema; IterativeKnowledgeGraphAgent assigns the
# Pydantic model it expects before each call and restores the previous schema afterwards.
client = VLLMClient(schema=None)

In [12]:
# Build the agent over the exported graph and run one question end to end.
agent = IterativeKnowledgeGraphAgent(
    gml_file_path="graph_dump.gml",
    vllm_client=client,
    max_iterations=5,
)

question = "What is the causal relationship between insomnia and chronic pain?"
result = agent.answer_question(question)
final_answer = result["final_answer"]

# Show the synthesized answer first, then every notebook entry that supported it.
print("=== FINAL ANSWER ===")
print(final_answer.answer)
print(f"Confidence: {final_answer.confidence}")

print("\n=== EXPLORATION NOTEBOOK ===")
for note in result["notebook"]:
    print(f"Node {note.source_node_id}: {note.information} (score: {note.relevance_score})")


Starting iterative exploration for: What is the causal relationship between insomnia and chronic pain?
Creating query plan...
Plan created with 4 steps
  Finding initial nodes...
   Found 8 initial nodes to explore

Iteration 1/5
Exploring 8 nodes...
     Added info from node 93 (relevance: 0.85)
     Added info from node 335 (relevance: 0.85)
     Added info from node 336 (relevance: 0.85)
     Added info from node 337 (relevance: 0.85)
     Added info from node 338 (relevance: 0.85)

Iteration 2/5
Exploring 3 nodes...

Iteration 3/5
No new nodes to explore

 Generating final answer...
 Exploration Summary:
   - Total iterations: 2
   - Nodes explored: 8
   - Information gathered: 5 entries
   - Final confidence: 0.85
=== FINAL ANSWER ===
The causal relationship between insomnia and chronic pain is well-established through multiple research studies and clinical trials. Insomnia is a significant risk factor for various pain conditions, including back pain, headaches, and stomach pain. 