Basic Graph Retrieval

In [None]:
from typing import List, Set, Dict, Any, Optional, Tuple
from pydantic import BaseModel, Field
from collections import defaultdict
import numpy as np
from dataclasses import dataclass

class Neo4jGraphRetriever:
    """Retrieve a k-hop neighbourhood around query entities from a Neo4j graph.

    The wrapped graph object only needs a ``query(cypher, params=...)`` method
    returning an iterable of dict-like records (``record.get(...)`` is used
    throughout). Nodes are identified by their ``name`` property.
    """

    def __init__(self, neo4j_graph):
        """
        Initialize the Neo4j graph retriever

        Args:
            neo4j_graph: Graph object exposing ``query(cypher, params=...)``
        """
        self.graph = neo4j_graph
        # Lowercase text -> node names, built once so that the first matching
        # pass needs no database round trip.
        self.node_mapping = self._create_node_mapping()

    def _create_node_mapping(self) -> Dict[str, List[str]]:
        """Build a lowercase-text -> node-names index for entity matching.

        Every node is indexed under its lowercased name, its lowercased
        labels, and each individual word (longer than two characters) of
        those texts.

        Returns:
            The mapping. It is empty or partial when the graph query fails;
            the error is printed rather than raised.
        """
        mapping = defaultdict(list)

        query = """
        MATCH (n)
        RETURN n.name as name, labels(n) as labels, id(n) as id
        """

        try:
            result = self.graph.query(query)
            for record in result:
                node_name = record.get('name', '')
                node_labels = record.get('labels') or []
                node_id = record.get('id', '')
                # Nodes without a name are indexed under a synthetic identifier
                node_ref = node_name or f"node_{node_id}"

                # Searchable text comes from the name and the labels
                searchable_texts = []
                if node_name:
                    searchable_texts.append(str(node_name))
                searchable_texts.extend(str(label) for label in node_labels)

                for text in searchable_texts:
                    if not text:
                        continue
                    text_lower = text.lower()
                    mapping[text_lower].append(node_ref)

                    # Also index single words (split on spaces, hyphens,
                    # underscores); very short words are too ambiguous.
                    for word in text_lower.replace('-', ' ').replace('_', ' ').split():
                        if len(word) > 2:
                            mapping[word].append(node_ref)

        except Exception as e:
            print(f"Error creating node mapping: {e}")
            # Diagnostic fallback: show a few raw nodes to help debugging
            try:
                sample = self.graph.query("MATCH (n) RETURN n LIMIT 10")
                print("Sample nodes:", sample)
            except Exception as fallback_error:
                print(f"Fallback sample query failed: {fallback_error}")

        return mapping

    def find_matching_nodes(self, entity: str) -> List[str]:
        """
        Find nodes in the graph whose name or label matches the entity.

        Matching is case-insensitive substring matching, first against the
        pre-built mapping and then directly in Neo4j (limited to 10 rows per
        query).

        Args:
            entity: Entity name to search for

        Returns:
            Matching node names/identifiers, sorted for deterministic output.
            Empty for a blank or non-string entity.
        """
        if not isinstance(entity, str):
            return []
        entity_clean = entity.strip()
        if not entity_clean:
            # An empty needle is a substring of every key and `CONTAINS ''`
            # is true for every node, so a blank entity would match the
            # whole graph.
            return []

        entity_lower = entity_clean.lower()
        matching_nodes = set()

        # Method 1: pre-built mapping (substring match in either direction;
        # this also covers the exact-key case)
        for node_key, node_list in self.node_mapping.items():
            if entity_lower in node_key or node_key in entity_lower:
                matching_nodes.update(node_list)

        # Method 2: direct Neo4j search on the name property and on labels
        name_query = """
        MATCH (n)
        WHERE toLower(n.name) CONTAINS toLower($entity)
        RETURN n.name as name LIMIT 10
        """
        label_query = """
        MATCH (n)
        WHERE any(label IN labels(n) WHERE toLower(label) CONTAINS toLower($entity))
        RETURN n.name as name LIMIT 10
        """
        try:
            for cypher in (name_query, label_query):
                for record in self.graph.query(cypher, params={"entity": entity_clean}):
                    name = record.get('name')
                    if name:
                        matching_nodes.add(name)
        except Exception as e:
            print(f"Error in Neo4j search: {e}")

        return sorted(matching_nodes, key=str)

    def k_hop_expansion(self, seed_nodes: List[str], k_hops: int = 2) -> Set[str]:
        """
        Expand k hops from the seed nodes using Neo4j.

        Args:
            seed_nodes: Names of the starting nodes
            k_hops: Number of hops to expand; values below 1 return the seeds

        Returns:
            Set of node names within k hops, always including the seeds. On a
            query error the seeds alone are returned.

        Raises:
            ValueError: if ``k_hops`` cannot be converted to an integer
        """
        if not seed_nodes:
            return set()

        # The hop bound has to be written into the query text (see the Cypher
        # manual on variable-length patterns), so it must be a real integer
        # before it is interpolated.
        try:
            hops = int(k_hops)
        except (TypeError, ValueError) as exc:
            raise ValueError(f"k_hops must be an integer, got {k_hops!r}") from exc

        seeds = set(seed_nodes)
        if hops < 1:
            return seeds

        query = f"""
        MATCH (start)-[*1..{hops}]-(connected)
        WHERE start.name IN $seed_names
        RETURN DISTINCT connected.name as name
        """

        try:
            result = self.graph.query(query, params={"seed_names": list(seed_nodes)})
            expanded_nodes = set(seeds)
            for record in result:
                node_name = record.get('name')
                if node_name:
                    expanded_nodes.add(node_name)
            return expanded_nodes

        except Exception as e:
            print(f"Error in k-hop expansion: {e}")
            return seeds

    def retrieve_subgraph_info(self, entities: List[str], k_hops: int = 2) -> Dict[str, Any]:
        """
        Main retrieval function: get relevant subgraph information from Neo4j

        Args:
            entities: List of extracted entities from query
            k_hops: Number of hops to expand

        Returns:
            Dictionary with the keys ``nodes``, ``relationships``,
            ``nodes_info``, ``entity_matches``, ``seed_nodes``,
            ``total_nodes`` and ``total_relationships``; plus ``message``
            when nothing matched or ``error`` when the subgraph query failed.
        """
        # Step 1: Find matching nodes for each entity
        entity_matches = {}
        all_seed_nodes = []

        print("Searching for entity matches...")
        for entity in entities:
            matching_nodes = self.find_matching_nodes(entity)
            entity_matches[entity] = matching_nodes
            all_seed_nodes.extend(matching_nodes)

        # Several entities often resolve to the same node; keep first-seen order
        all_seed_nodes = list(dict.fromkeys(all_seed_nodes))

        print(f"Entity matches found: {entity_matches}")

        if not all_seed_nodes:
            return {
                "nodes": [],
                "relationships": [],
                "nodes_info": {},
                "entity_matches": entity_matches,
                "seed_nodes": [],
                "total_nodes": 0,
                "total_relationships": 0,
                "message": "No matching nodes found in graph"
            }

        # Step 2: K-hop expansion
        expanded_nodes = self.k_hop_expansion(all_seed_nodes, k_hops)
        node_names = sorted(expanded_nodes, key=str)
        print(f"Expanded to {len(expanded_nodes)} nodes after {k_hops} hops")

        # Step 3: Get subgraph with relationships.
        # The pattern is directed: an undirected match returns every
        # relationship twice, once with source and target swapped, which
        # would report e.g. both "A CAUSES B" and "B CAUSES A".
        subgraph_query = """
        MATCH (n)-[r]->(m)
        WHERE n.name IN $node_names AND m.name IN $node_names
        RETURN n.name as source, type(r) as relationship, m.name as target,
               n as source_node, m as target_node
        """

        try:
            result = self.graph.query(subgraph_query, params={"node_names": node_names})

            relationships = []
            nodes_info = {}

            for record in result:
                source = record.get('source')
                target = record.get('target')
                rel_type = record.get('relationship')

                if source and target and rel_type:
                    relationships.append({
                        'source': source,
                        'target': target,
                        'relationship': rel_type
                    })

                # Store node information
                if source:
                    nodes_info[source] = record.get('source_node', {})
                if target:
                    nodes_info[target] = record.get('target_node', {})

            return {
                "nodes": node_names,
                "relationships": relationships,
                "nodes_info": nodes_info,
                "entity_matches": entity_matches,
                "seed_nodes": all_seed_nodes,
                "total_nodes": len(node_names),
                "total_relationships": len(relationships)
            }

        except Exception as e:
            print(f"Error retrieving subgraph: {e}")
            return {
                "nodes": node_names,
                "relationships": [],
                "nodes_info": {},
                "entity_matches": entity_matches,
                "seed_nodes": all_seed_nodes,
                "total_nodes": len(node_names),
                "total_relationships": 0,
                "error": str(e)
            }

Answer generation

In [None]:
# Required imports
from typing import List, Dict, Any
from langchain_core.prompts import ChatPromptTemplate
from langchain_groq import ChatGroq
from dataclasses import dataclass

@dataclass
class AnswerResult:
    """Generated answer plus the metadata produced alongside it."""
    # Text returned by the LLM, or an error message if generation failed
    answer: str
    # Heuristic confidence score for the answer
    confidence: float
    # Counts/scores describing the retrieved knowledge that was used
    knowledge_used: Dict[str, Any]
    # Sentences extracted from the answer as reasoning steps
    reasoning_steps: List[str]

class GraphAnswerGenerator:
    """Turn a graph retrieval result (dict) into an LLM-generated answer."""

    def __init__(self):
        """Initialize the answer generator with the qwen/qwen3-32b model via ChatGroq"""
        self.llm = ChatGroq(
            model="qwen/qwen3-32b",
            temperature=0,  # Low temperature for factual accuracy
        )

    def format_graph_knowledge(self, retrieval_result) -> str:
        """
        Convert graph retrieval results into structured text format

        Sections are emitted only when they have content: entities (first
        50), relationships (top 20 by ``importance``), entity context and
        connection paths (first 5).

        Args:
            retrieval_result: Output from your Neo4jGraphRetriever (dict format)

        Returns:
            Formatted knowledge string (empty when nothing was retrieved)
        """
        knowledge_sections = []

        # Section 1: Key Medical Entities Found
        nodes = retrieval_result.get('nodes')
        if nodes:
            # Limit the number of entities to keep the prompt context small
            entity_lines = "".join(f"- {node}\n" for node in nodes[:50])
            knowledge_sections.append("MEDICAL ENTITIES:\n" + entity_lines)

        # Section 2: Medical Relationships
        relationships = retrieval_result.get('relationships')
        if relationships:
            # Sort by importance if available (stable for equal importance)
            sorted_rels = sorted(
                relationships,
                key=lambda rel: rel.get('importance', 0.5),
                reverse=True
            )

            rel_lines = []
            for rel in sorted_rels[:20]:  # Top 20 relationships
                source = rel.get('source', 'Unknown')
                target = rel.get('target', 'Unknown')
                rel_type = rel.get('relationship', 'RELATED_TO')

                # Convert relationship types to natural language
                rel_description = self._relationship_to_text(rel_type)
                rel_lines.append(f"- {source} {rel_description} {target}\n")

            knowledge_sections.append("MEDICAL RELATIONSHIPS:\n" + "".join(rel_lines))

        # Section 3: Entity Matching Context
        entity_matches = retrieval_result.get('entity_matches')
        if entity_matches:
            match_lines = [
                f"- Query term '{entity}' relates to: "
                f"{', '.join(str(node) for node in matched_nodes[:3])}\n"
                for entity, matched_nodes in entity_matches.items()
                if matched_nodes
            ]
            # Skip the section entirely when no entity matched anything, so
            # the prompt never contains a header with nothing under it.
            if match_lines:
                knowledge_sections.append("ENTITY CONTEXT:\n" + "".join(match_lines))

        # Section 4: Connection Paths (if available)
        path_info = retrieval_result.get('path_info')
        if path_info and path_info.get('paths'):
            path_lines = []
            for path in path_info['paths'][:5]:
                start = path.get('start', '')
                end = path.get('end', '')
                path_rels = path.get('relationships', [])
                path_desc = ' → '.join(path_rels) if path_rels else 'connected to'
                path_lines.append(f"- {start} is {path_desc} {end}\n")
            knowledge_sections.append("KEY CONNECTIONS:\n" + "".join(path_lines))

        return "\n".join(knowledge_sections)

    def _relationship_to_text(self, rel_type: str) -> str:
        """
        Convert a raw (Neo4j) relationship type into natural text:
        - Lowercases everything
        - Replaces underscores '_' with spaces
        """

        return rel_type.lower().replace('_', ' ')

    def create_answer_prompt(self) -> "ChatPromptTemplate":
        """Create the prompt template for answer generation"""

        prompt = ChatPromptTemplate.from_messages([
            ("system", """You are a medical expert specializing in Type 2 Diabetes. Your role is to provide accurate, helpful answers based on the provided medical knowledge graph information.

INSTRUCTIONS:
1. Answer the question using ONLY the provided medical knowledge
2. Be precise , factual and explicative - avoid speculation
3. Organize your answer logically (causes → mechanisms → treatments → outcomes)
4. If the knowledge is insufficient, clearly state what information is missing
5. Use medical terminology appropriately but explain complex concepts
6. Focus on practical, clinically relevant information

RESPONSE FORMAT:
- Direct answer first
- Supporting details with mechanisms/relationships
- Practical implications or recommendations
- Note any limitations in the available knowledge"""),

            ("human", """Based on the following medical knowledge from a Type 2 Diabetes knowledge graph, please answer this question:

QUESTION: {question}

AVAILABLE MEDICAL KNOWLEDGE:
{formatted_knowledge}

RETRIEVAL QUALITY: {confidence_score}/1.0
REASONING CONTEXT: {retrieval_reasoning}

Please provide a comprehensive answer based on this knowledge.""")
        ])

        return prompt

    def generate_answer(self, question: str, retrieval_result) -> "AnswerResult":
        """
        Generate answer using structured knowledge from graph retrieval

        Args:
            question: Original user question
            retrieval_result: Output from Neo4jGraphRetriever (dict format)

        Returns:
            AnswerResult with generated answer and metadata. If the LLM call
            fails, the error text is returned as the answer with confidence 0.
        """

        # Step 1: Format the knowledge
        formatted_knowledge = self.format_graph_knowledge(retrieval_result)

        # Step 2: Create the prompt
        prompt = self.create_answer_prompt()

        # Step 3: Prepare prompt variables
        prompt_vars = {
            "question": question,
            "formatted_knowledge": formatted_knowledge,
            "confidence_score": retrieval_result.get('subgraph_score', 0.5),
            "retrieval_reasoning": retrieval_result.get('reasoning', 'Knowledge retrieved from graph database')
        }

        # Step 4: Generate answer
        try:
            chain = prompt | self.llm
            response = chain.invoke(prompt_vars)
            answer_text = response.content

            # Step 5: Estimate answer confidence based on retrieval quality
            answer_confidence = self._calculate_answer_confidence(
                retrieval_result, formatted_knowledge, answer_text
            )

            # Step 6: Extract reasoning steps (basic implementation)
            reasoning_steps = self._extract_reasoning_steps(answer_text)

            return AnswerResult(
                answer=answer_text,
                confidence=answer_confidence,
                knowledge_used={
                    "entities_count": len(retrieval_result.get('nodes', [])),
                    "relationships_count": len(retrieval_result.get('relationships', [])),
                    "retrieval_score": retrieval_result.get('subgraph_score', 0.5),
                    "entity_matches": len(retrieval_result.get('entity_matches', {}))
                },
                reasoning_steps=reasoning_steps
            )

        except Exception as e:
            # Best-effort: report the failure through the result object
            return AnswerResult(
                answer=f"Error generating answer: {str(e)}",
                confidence=0.0,
                knowledge_used={},
                reasoning_steps=["Error occurred during generation"]
            )

    def _calculate_answer_confidence(self, retrieval_result, formatted_knowledge: str, answer: str) -> float:
        """Calculate a heuristic confidence score in [0, 1] for the answer.

        Weighted average of retrieval score (0.4), knowledge length (0.3),
        answer length (0.2) and the fraction of query entities that matched
        at least one graph node (0.1).
        """

        # Factor 1: Retrieval quality
        retrieval_confidence = retrieval_result.get('subgraph_score', 0.5)

        # Factor 2: Knowledge richness
        knowledge_richness = min(len(formatted_knowledge) / 1000, 1.0)  # Normalize to 0-1

        # Factor 3: Answer completeness (basic heuristic)
        answer_completeness = min(len(answer) / 500, 1.0) if len(answer) > 50 else 0.3

        # Factor 4: Entity coverage - share of entities with at least one match
        entity_matches = retrieval_result.get('entity_matches') or {}
        if entity_matches:
            matched = sum(1 for nodes in entity_matches.values() if nodes)
            entities_covered = matched / len(entity_matches)
        else:
            entities_covered = 0.0

        # Weighted average
        total_confidence = (
            0.4 * retrieval_confidence +
            0.3 * knowledge_richness +
            0.2 * answer_completeness +
            0.1 * entities_covered
        )

        return max(0.0, min(total_confidence, 1.0))

    def _extract_reasoning_steps(self, answer: str) -> List[str]:
        """Extract basic reasoning steps from the answer.

        Takes the first five '. '-separated sentences and keeps those longer
        than 20 characters.
        """
        candidates = (sentence.strip() for sentence in answer.split('. ')[:5])
        return [sentence for sentence in candidates if len(sentence) > 20]

# ==================== COMPLETE PIPELINE TEST ====================

def test_complete_pipeline(neo4j_graph):
    """Run the end-to-end pipeline on one sample question.

    Extracts entities with the module-level ``entity_chain``, retrieves a
    2-hop subgraph and generates an answer, printing each stage.

    Args:
        neo4j_graph: Graph object handed to ``Neo4jGraphRetriever``

    Returns:
        The ``AnswerResult`` produced by ``GraphAnswerGenerator``
    """

    # Use the graph passed in by the caller rather than a notebook global
    retriever = Neo4jGraphRetriever(neo4j_graph)
    answer_generator = GraphAnswerGenerator()

    # Test question
    test_question = "What are the most effective medications for controlling HbA1c levels in newly diagnosed Type 2 diabetes patients?"

    print(f"QUESTION: {test_question}")
    print("="*80)

    # Step 1: Extract entities (your existing code)
    entity_result = entity_chain.invoke({"question": test_question})
    entities = entity_result.names
    print(f"ENTITIES: {entities}")

    # Step 2: Retrieve relevant subgraph info
    retrieval_result = retriever.retrieve_subgraph_info(entities, k_hops=2)

    # Step 3: Generate answer
    answer_result = answer_generator.generate_answer(test_question, retrieval_result)

    print(f"\n GENERATED ANSWER:")
    print("-"*50)
    print(answer_result.answer)
    print("-"*50)

    return answer_result

Tests

In [None]:
# Manual end-to-end run: entity extraction -> graph retrieval -> answer.
# Relies on `graph` and `entity_chain` being defined by other cells.
# Initialize
retriever = Neo4jGraphRetriever(graph)
answer_generator = GraphAnswerGenerator()

question = "What are the first warning signs and symptoms of type 2 diabetes?"

# 1. Extract entities
entities = entity_chain.invoke({"question": question}).names

# 2. Retrieve graph knowledge
retrieval_result = retriever.retrieve_subgraph_info(entities, k_hops=2)

# 3. Generate answer
answer_result = answer_generator.generate_answer(question, retrieval_result)

print(f"Answer: {answer_result.answer}")
print(f"Confidence: {answer_result.confidence:.2f}")

In [None]:
# Repeat of the manual end-to-end run with the same question.
# Relies on `graph` and `entity_chain` being defined by other cells.
retriever = Neo4jGraphRetriever(graph)
answer_generator = GraphAnswerGenerator()

question = "What are the first warning signs and symptoms of type 2 diabetes?"

# 1. Extract entities
entities = entity_chain.invoke({"question": question}).names

# 2. Retrieve graph knowledge
retrieval_result = retriever.retrieve_subgraph_info(entities, k_hops=2)

# 3. Generate answer
answer_result = answer_generator.generate_answer(question, retrieval_result)

print(f"Answer: {answer_result.answer}")
print(f"Confidence: {answer_result.confidence:.2f}")

In [None]:
# Manual end-to-end run for a prevention question.
# Relies on `graph` and `entity_chain` being defined by other cells.
retriever = Neo4jGraphRetriever(graph)
answer_generator = GraphAnswerGenerator()

question = "What are three ways to prevent type 2 diabetes? "

# 1. Extract entities
entities = entity_chain.invoke({"question": question}).names

# 2. Retrieve graph knowledge
retrieval_result = retriever.retrieve_subgraph_info(entities, k_hops=2)

# 3. Generate answer
answer_result = answer_generator.generate_answer(question, retrieval_result)

print(f"Answer: {answer_result.answer}")
print(f"Confidence: {answer_result.confidence:.2f}")

<h4>Final Graph Retriever</h4>

Find the relevant relations to explore for each question

In [None]:
from typing import List, Dict
import json
from groq import Groq 

# Candidate relation-type names passed to classify_relations() as the
# "predefined relations" the LLM chooses from. Presumably these mirror the
# relationship types present in the knowledge graph - TODO confirm.
# NOTE(review): "CO-OCCURRING_WITH" is the only entry containing a hyphen.
RELATIONS = [
    "ACTIVATED_BY", "ACTIVATES", "ACTS_ON", "ACTS_THROUGH", "ADVERSELY_AFFECTS",
    "AFFECTED_BY", "AFFECTS", "ALLEVIATES", "ALTERNATIVE_TO", "ASSESSED_BY",
    "ASSESSES", "ASSOCIATED_WITH", "ASSOCIATED_WITH_IMPROVED_MANAGEMENT",
    "BARRIER_TO", "BARRIERS_TO", "BASED_ON", "BENEFICIAL_FOR", "BENEFITS",
    "BENEFITS_FROM", "BINDS_TO", "BIOMARKER_FOR", "BIOMARKER_OF", "CAN_LEAD_TO",
    "CAN_PROGRESS_TO", "CAUSED_BY", "CAUSES", "CHARACTERISTIC_OF", "CHARACTERIZED_BY",
    "CHARACTERIZES", "CO_LOCALIZED_WITH", "CO_OCCURS_WITH", "CO-OCCURRING_WITH",
    "COEXISTS_WITH", "COMBINATION_THERAPY_WITH", "COMORBID_WITH", "COMPARABLE_TO",
    "COMPARED_TO", "COMPARED_WITH", "COMPARES_TO", "COMPLEMENTS", "COMPLICATED_BY",
    "COMPLICATED_WITH", "COMPLICATES", "COMPLICATES_TREATMENT_OF", "COMPLICATION_OF",
    "COMPONENT_OF", "COMPOSED_OF", "COMPRISES", "COMPROMISES", "CONSIDER", "CONTAINS",
    "CONTRAINDICATED_WITH", "CONTRIBUTES_TO", "CONTROLS", "CORRELATED_WITH",
    "COUNTERACTS", "DAMAGES", "DECREASES", "DECREASES_EXPRESSION_OF", "DECREASES_IN",
    "DECREASES_RISK_OF", "DECREASES_WITH", "DEFINES", "DEGRADED_BY", "DEGRADES",
    "DELIVERS", "DIAGNOSED_BY", "DIAGNOSES", "DIFFERENTIATED_FROM", "DISRUPTS",
    "DISTINGUISHED_FROM", "DOES_NOT_AFFECT", "DOES_NOT_REDUCE_RISK_OF",
    "DOES_NOT_SLOW_PROGRESSION_OF", "DOES_NOT_TREAT", "ENABLES", "ENCAPSULATES",
    "ENHANCES", "ENHANCES_EFFECTIVENESS", "EQUIVALENT_TO", "EVALUATES", "EXPLAINS",
    "EXPRESSED", "EXPRESSED_IN", "EXPRESSES", "FACILITATES", "FOCUSES_ON", "FUNDED",
    "FUNDED_BY", "FUNDS", "GUIDES", "HAS", "HAS_ANABOLIC_EFFECT_ON", "HAS_BIOMARKER",
    "HAS_COMORBIDITY", "HAS_COMPLICATION", "HAS_COMPONENT", "HAS_EFFECT_ON",
    "HAS_IMMUNOMODULATORY_EFFECTS_ON", "HAS_IMPACT_ON", "HAS_NEGATIVE_EFFECT_ON",
    "HAS_NO_EFFECT", "HAS_NO_EFFECT_ON", "HAS_NO_INFERIOR_RISK_OF",
    "HAS_POSITIVE_EFFECT_ON", "HAS_SIDE_EFFECT", "HAS_SUBTYPE", "HAS_SYMPTOM",
    "HINDERS", "IMPAIRS", "IMPLEMENTS", "IMPORTANT_FOR", "IMPROVED_BY", "IMPROVES",
    "IMPROVES_OUTCOME", "IMPROVES_OUTCOME_OF", "IMPROVES_OUTCOMES", "INCLUDES",
    "INCREASED_BY", "INCREASES", "INCREASES_ABUNDANCE_OF", "INCREASES_ACTIVITY_OF",
    "INCREASES_ADHERENCE_TO", "INCREASES_EXPRESSION_OF", "INCREASES_LEVEL_OF",
    "INCREASES_PRODUCTION_OF", "INCREASES_RISK_OF", "INCREASES_SECRETION_OF",
    "INCREASES_UTILIZATION_OF", "INCREASES_WITH", "INDICATES", "INDICATIVE_OF",
    "INDICATOR_OF", "INDUCES", "INFECTS", "INFLUENCED_BY", "INFLUENCES", "INFORMS",
    "INHIBITED_BY", "INHIBITS", "INTERACTS_WITH", "INVERSELY_ASSOCIATED_WITH",
    "INVERSELY_CORRELATED_WITH", "INVERSELY_RELATED_TO", "INVOLVED_IN", "IS", "IS_A",
    "IS_COMPLICATION_OF", "IS_TYPE_OF", "LEADS_TO", "LINKED_TO", "LOCATED_IN",
    "MAINTAINS", "MANAGED_BY", "MANAGED_WITH", "MANAGES", "MANUFACTURES",
    "MARKER_OF", "MAY_AFFECT", "MAY_CAUSE", "MAY_REDUCE_RISK_OF", "MEASURED_AS",
    "MEASURED_BY", "MEASURES", "MEDIATES", "MENTIONED_IN", "METABOLIC_PRODUCT_OF",
    "MIMICS", "MODEL_OF", "MODELS", "MODERATES", "MODIFIES", "MODULATES", "MONITORS",
    "NEGATIVELY_ASSOCIATED_WITH", "NEGATIVELY_CORRELATED_WITH", "NEGATIVELY_IMPACT",
    "NO_EFFECT_ON", "NOT_ASSOCIATED_WITH", "PART_OF", "PHOSPHORYLATES",
    "POSITIVELY_ASSOCIATED_WITH", "POSITIVELY_CORRELATED_WITH", "PREDICTIVE_OF",
    "PREDICTS", "PRESERVES", "PREVENTS", "PREVENTS_COMPLICATIONS", "PRODUCED_BY",
    "PRODUCES", "PROGRESSES_TO", "PROMOTES", "PROTECTIVE_AGAINST", "PROTECTS",
    "PROTECTS_AGAINST", "PROVIDES", "RECOMMENDED_FOR", "RECOMMENDS", "REDUCES",
    "REDUCES_EXPRESSION_OF", "REDUCES_RISK_OF", "REDUCES_SEVERITY_OF",
    "REDUCES_SYMPTOMS_OF", "REGULATED_BY", "REGULATES", "RELATED_TO", "REPLACES",
    "REQUIRES", "RESEARCHES", "RESPONDS_TO", "RISK_FACTOR_FOR", "SIDE_EFFECT_OF",
    "SIMILAR_TO", "SLOWS_PROGRESSION_OF", "STIMULATES", "SUBTYPE_OF", "SUPPORT",
    "SUPPORTED", "SUPPORTED_BY", "SUPPORTS", "SUPPRESSES", "TARGETS", "TESTED_IN",
    "TREATED_WITH", "TREATS", "TREATS_WITH", "TRIGGERS", "USED_FOR",
    "USED_FOR_DIAGNOSIS_OF", "USED_IN", "USED_IN_COMBINATION_WITH", "USED_TO_ASSESS",
    "USED_TO_DERIVE", "USED_TO_DIAGNOSE", "USED_TO_MONITOR", "USED_TO_PREDICT",
    "USED_TO_UNDERSTAND", "USES", "WORSENS"
]

def classify_relations(question: str, relations: List[str]) -> Dict[str, List[str]]:
    """Ask the LLM which relation types matter for answering a question.

    Args:
        question: Natural-language question about type 2 diabetes
        relations: Relation type names the model may choose from

    Returns:
        ``{"primary_relations": [...], "secondary_relations": [...]}``.
        Both keys are always present. When ``relations`` is non-empty, names
        that are not in it are dropped. Both lists are empty when the model
        output cannot be parsed as a JSON object.
    """
    # With no explicit key the Groq client reads GROQ_API_KEY from the
    # environment, which keeps credentials out of the notebook.
    client = Groq()

    system_prompt = (
        "You are an expert in medical knowledge graphs for type 2 diabetes. "
        "Given a question and a large list of biomedical relations, "
        "identify:\n"
        "- PRIMARY_RELATIONS: Essential relations to answer the question directly.\n"
        "- SECONDARY_RELATIONS: Useful but non-essential.\n"
        "Output strictly in JSON with the format:\n"
        '{"primary_relations": [...], "secondary_relations": [...]}.'
    )

    examples = [
        {
            "question": "How does obesity lead to insulin resistance?",
            "answer": {
                "primary_relations": ["CAUSES","CAUSED_BY","LEADS_TO","CONTRIBUTES_TO","PROMOTES","INDUCES","INCREASES_RISK_OF","INVOLVED_IN"],
                "secondary_relations": ["ASSOCIATED_WITH","LINKED_TO","CORRELATED_WITH","INFLUENCES","AFFECTS","RELATED_TO","RISK_FACTOR_FOR","INCREASES","MEDIATES","MODULATES","TRIGGERS"]
            }
        },
        {
            "question": "What are the mechanisms by which GLP-1 receptor agonists improve cardiovascular outcomes in type 2 diabetes?",
            "answer": {
                "primary_relations": ["IMPROVES_OUTCOME","IMPROVES_OUTCOMES","IMPROVES_OUTCOME_OF","BENEFICIAL_FOR","REDUCES_RISK_OF","PROTECTS","PROTECTS_AGAINST","ENHANCES_EFFECTIVENESS","MECHANISM_OF_ACTION","MEDIATES","MODULATES"],
                "secondary_relations": ["ASSOCIATED_WITH","LINKED_TO","AFFECTS","HAS_POSITIVE_EFFECT_ON","INFLUENCES","INVOLVED_IN","SUPPORTS","PROMOTES","CONTRIBUTES_TO","ENHANCES"]
            }
        },
        {
            "question": "what are the symptoms of type 2 diabetes?",
            "answer": {
                "primary_relations": ["HAS_SYMPTOM","CHARACTERIZED_BY","INDICATES","INDICATIVE_OF","MARKER_OF"],
                "secondary_relations": ["ASSOCIATED_WITH","LINKED_TO","RELATED_TO","PREDICTIVE_OF","SUGGESTS"]
            }
        }
    ]

    few_shot_prompt = "Examples:\n" + "".join(
        f"Q: {ex['question']}\nA: {json.dumps(ex['answer'], indent=2)}\n\n"
        for ex in examples
    )

    user_prompt = f"""
    {few_shot_prompt}
    Question: {question}
    Predefined relations: {relations}

    Return ONLY JSON.
    """

    response = client.chat.completions.create(
        model="llama-3.3-70b-versatile",
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt}
        ],
        temperature=0.0,
        response_format={"type": "json_object"}
    )

    empty_result = {"primary_relations": [], "secondary_relations": []}
    output_text = (response.choices[0].message.content or "").strip()

    # Parse the text as-is first: blindly swapping quote characters corrupts
    # valid JSON whose strings contain an apostrophe. The quote swap is kept
    # only as a fallback for single-quoted pseudo-JSON.
    try:
        parsed = json.loads(output_text)
    except json.JSONDecodeError:
        try:
            parsed = json.loads(output_text.replace("'", '"'))
        except json.JSONDecodeError:
            print("⚠️ JSON parsing failed:", output_text)
            return empty_result

    if not isinstance(parsed, dict):
        print("⚠️ JSON parsing failed:", output_text)
        return empty_result

    # Normalise the shape and drop names outside the supplied list (the
    # few-shot examples above include some, e.g. MECHANISM_OF_ACTION).
    allowed = set(relations)
    result = {}
    for key in ("primary_relations", "secondary_relations"):
        values = parsed.get(key)
        if not isinstance(values, list):
            values = []
        result[key] = [
            value for value in values
            if isinstance(value, str) and (not allowed or value in allowed)
        ]
    return result


# Example usage
q1 = "What factors contribute to the development of diabetic neuropathy in patients with type 2 diabetes?"
print(classify_relations(q1, RELATIONS))


In [None]:
# Required imports
from typing import List
from pydantic import BaseModel, Field
from langchain_core.prompts import ChatPromptTemplate
from langchain_groq import ChatGroq  

# Define the structured output schema
# (passed to llm.with_structured_output when the entity chain is built)
class Entities(BaseModel):
    """Identifying information about entities."""
    # Required field: the entity names extracted from the question
    names: List[str] = Field(
        ...,
        description="All medical relevant entities related to Type 2 Diabetes",
    )

# Initialize your LLM
llm = ChatGroq(
    model="llama-3.3-70b-versatile",
    temperature=0,
)

# Few-shot prompt with examples.
# Adjacent string literals are concatenated with no separator, so each
# sentence fragment carries its own trailing space.
prompt = ChatPromptTemplate.from_messages(
    [
        ("system", "You are a medical expert specializing in Type 2 Diabetes. "
         "Your task is to extract structured knowledge from biomedical questions. "
         "Always extract all relevant entities from the text. "
        ),

        # --- FEW-SHOT EXAMPLES ---
        # Double braces escape literal JSON braces inside the template
        ("human", "Extract structured knowledge about Type 2 Diabetes from the following question: "
                  "How do epigenetic modifications and gut microbiome dysbiosis interact to influence insulin resistance and β-cell dysfunction in the progression of Type 2 Diabetes?"),
        ("ai", '{{"names": ["epigenetic modifications", "gut microbiome dysbiosis", "insulin resistance","β-cell dysfunction"]}}'),

        ("human", "Extract structured knowledge about Type 2 Diabetes from the following question: "
                  "What are common complications of Type 2 Diabetes?"),
        ("ai", '{{"names": ["Type 2 Diabetes"]}}'),

        # --- ACTUAL QUESTION ---
        ("human", "Extract structured knowledge about Type 2 Diabetes from the following question: {question}")
    ]
)

# Create chain with structured output (results are Entities instances)
entity_chain = prompt | llm.with_structured_output(Entities)


# Smoke test: the chain returns entity names, so label the output accordingly
result = entity_chain.invoke({"question": "Why might individuals with central obesity have a higher likelihood of developing complications related to Type 2 Diabetes?"})
print("Extracted entities:", result.names)

In [None]:
from typing import List, Set, Dict, Any, Optional, Tuple
from pydantic import BaseModel, Field
from collections import defaultdict
import numpy as np
from dataclasses import dataclass
import math

@dataclass
class EntityMatch:
    """Graph nodes matched for one query entity, with per-node score and match type."""
    # The query entity that was searched for
    entity: str
    # Names of the graph nodes matched for this entity
    matched_nodes: List[str]
    # Match scores; presumably parallel to matched_nodes - TODO confirm
    match_scores: List[float]
    match_types: List[str]  # 'exact', 'partial', 'semantic'

@dataclass
class RetrievalResult:
    """Structured result of a graph retrieval."""
    # Names of the retrieved nodes
    nodes: List[str]
    # Retrieved relationships, one dict per relationship
    relationships: List[Dict[str, Any]]
    # One EntityMatch per query entity
    entity_matches: List[EntityMatch]
    # Score for the retrieved subgraph
    subgraph_score: float
    # Path-related details; schema not visible here - TODO confirm
    path_info: Dict[str, Any]
    # Free-text explanation of the retrieval
    reasoning: str


class GraphRetriever:
    def __init__(self, neo4j_graph, embedding_model=None):
        """
        Neo4j graph retriever intégré avec extraction d'entités et classification de relations

        Args:
            neo4j_graph: Votre objet Neo4j graph
            embedding_model: Modèle d'embedding optionnel pour similarité sémantique
        """
        self.graph = graph
        self.embedding_model = embedding_model
        self.node_cache = {}
        self.relationship_cache = {}

    def _get_node_properties(self) -> Dict[str, List[str]]:
        """Get all available node properties to search across"""
        query = """
        CALL db.schema.nodeTypeProperties()
        YIELD nodeType, propertyName, propertyTypes
        RETURN nodeType, collect(propertyName) as properties
        """

        try:
            result = self.graph.query(query)
            prop_mapping = {}
            for record in result:
                node_type = record.get('nodeType')
                properties = record.get('properties', [])
                prop_mapping[node_type] = properties
            return prop_mapping
        except:
            # Fallback: assume common properties
            return {"*": ["name", "title", "description", "type", "label"]}

    def find_matching_nodes_advanced(self, entity: str) -> EntityMatch:
        """
        Match one entity against graph node names using four successive strategies.

        The strategies run in this order and append to the same result lists:
        exact name equality (score 1.0), token containment (score from the
        query), core medical term containment (score 0.9) and partial text
        containment on name, description or label (score from the query).
        Strategies 2-4 skip names that are already in the result; strategy 1
        does not de-duplicate. A failing query is reported with ``print`` and
        the remaining strategies still run.

        Args:
            entity: Entity text extracted from the question.

        Returns:
            EntityMatch holding the parallel name/score/type lists.
        """
        hit_names = []
        hit_scores = []
        hit_kinds = []

        # Normalise once: the token list feeds strategy 2, the cleaned text strategy 3.
        normalized = entity.lower().strip()
        tokens = normalized.replace('-', ' ').replace('_', ' ').split()

        # ---- Strategy 1: case-insensitive exact name equality ----
        exact_query = """
        MATCH (n)
        WHERE toLower(n.name) = toLower($entity)
        RETURN n.name as name, labels(n) as labels, 1.0 as score
        """

        try:
            for row in self.graph.query(exact_query, params={"entity": entity}):
                exact_name = row.get('name')
                if not exact_name:
                    continue
                hit_names.append(exact_name)
                hit_scores.append(1.0)
                hit_kinds.append('exact')
        except Exception as e:
            print(f"Error in exact matching: {e}")

        # ---- Strategy 2: token-based matching ----
        token_query = """
        MATCH (n)
        WHERE n.name IS NOT NULL
        WITH n, toLower(n.name) as node_name_lower
        WHERE any(token IN $entity_tokens WHERE node_name_lower CONTAINS token AND size(token) > 2)
          OR any(token IN $entity_tokens WHERE node_name_lower = token)
        RETURN n.name as name,
              CASE
                  WHEN any(token IN $entity_tokens WHERE toLower(n.name) = token) THEN 0.9
                  WHEN any(token IN $entity_tokens WHERE toLower(n.name) CONTAINS token AND size(token) > 3) THEN 0.8
                  WHEN any(token IN $entity_tokens WHERE toLower(n.name) CONTAINS token) THEN 0.7
                  ELSE 0.5
              END as score
        ORDER BY score DESC
        LIMIT 5
        """

        try:
            # Drop filler words and tokens of one or two characters before querying.
            ignored = {'levels', 'with', 'in', 'of', 'the', 'and', 'or'}
            useful_tokens = [tok for tok in tokens if tok not in ignored and len(tok) > 2]

            # The query is skipped entirely when no meaningful token is left.
            token_rows = (
                self.graph.query(token_query, params={"entity_tokens": useful_tokens})
                if useful_tokens else []
            )
            for row in token_rows:
                token_name = row.get('name')
                if not token_name or token_name in hit_names:
                    continue
                hit_names.append(token_name)
                hit_scores.append(row.get('score', 0.75))
                hit_kinds.append('token_match')
        except Exception as e:
            print(f"Error in token matching: {e}")

        # ---- Strategy 3: core term extraction ----
        # A canonical term is selected when any of its variants occurs as a
        # substring of the cleaned entity text.
        medical_patterns = {
            'hba1c': ['hba1c', 'hemoglobin a1c', 'glycated hemoglobin', 'a1c'],
            'blood pressure': ['blood pressure', 'bp', 'hypertension', 'systolic', 'diastolic'],
            'blood sugar': ['blood sugar', 'glucose', 'glycemia', 'blood glucose'],
            'cholesterol': ['cholesterol', 'ldl', 'hdl', 'lipids'],
            'insulin': ['insulin', 'insulin resistance', 'insulin sensitivity'],
            'obesity': ['obesity', 'bmi', 'body mass index', 'weight'],
            'neuropathy': ['neuropathy', 'nerve damage', 'diabetic neuropathy']
        }
        core_terms = [
            canonical
            for canonical, variants in medical_patterns.items()
            if any(variant in normalized for variant in variants)
        ]

        if core_terms:
            core_query = """
            MATCH (n)
            WHERE any(term IN $core_terms WHERE toLower(n.name) CONTAINS term)
            RETURN n.name as name, 0.9 as score
            LIMIT 5
            """

            try:
                for row in self.graph.query(core_query, params={"core_terms": core_terms}):
                    core_name = row.get('name')
                    if not core_name or core_name in hit_names:
                        continue
                    hit_names.append(core_name)
                    hit_scores.append(0.9)
                    hit_kinds.append('core_term')
            except Exception as e:
                print(f"Error in core term matching: {e}")

        # ---- Strategy 4: partial text matching on name, description or label ----
        partial_query = """
        MATCH (n)
        WHERE toLower(n.name) CONTAINS toLower($entity)
           OR toLower(coalesce(n.description, '')) CONTAINS toLower($entity)
           OR any(label IN labels(n) WHERE toLower(label) CONTAINS toLower($entity))
        RETURN n.name as name,
               CASE
                   WHEN toLower(n.name) CONTAINS toLower($entity) THEN 0.8
                   WHEN toLower(coalesce(n.description, '')) CONTAINS toLower($entity) THEN 0.5
                   ELSE 0.3
               END as score,
               labels(n) as labels
        ORDER BY score DESC
        LIMIT 5
        """

        try:
            for row in self.graph.query(partial_query, params={"entity": entity}):
                partial_name = row.get('name')
                if not partial_name or partial_name in hit_names:
                    continue
                hit_names.append(partial_name)
                hit_scores.append(row.get('score', 0.5))
                hit_kinds.append('partial')
        except Exception as e:
            print(f"Error in partial matching: {e}")

        return EntityMatch(
            entity=entity,
            matched_nodes=hit_names,
            match_scores=hit_scores,
            match_types=hit_kinds
        )

    def adaptive_hop_expansion(
        self,
        seed_nodes: List[str],
        entities: List[str],
        primary_relations: List[str],
        secondary_relations: List[str]
    ) -> Tuple[Set[str], str]:
        """
        Expansion adaptative utilisant les relations identifiées par le classificateur LLM
        """
        if not seed_nodes:
            return set(), "No seed nodes found"

        # Analyse densité autour des seed nodes
        density_query = """
        MATCH (start)-[r]-(connected)
        WHERE start.name IN $seed_nodes
        RETURN start.name AS node, count(r) AS degree
        """

        node_degrees = {}
        try:
            result = self.graph.query(density_query, {"seed_nodes": seed_nodes})
            for record in result:
                node_degrees[record.get("node")] = record.get("degree", 0)
        except Exception as e:
            print(f"Warning: Density query failed - {e}")

        avg_degree = np.mean(list(node_degrees.values())) if node_degrees else 5

        # Choix dynamique du nombre de hops
        if avg_degree > 20:
            k_hops = 1
            reasoning = f"High graph density ({avg_degree:.1f}) → 1 hop with strong relation filtering"
        elif avg_degree > 10:
            k_hops = 2
            reasoning = f"Medium density ({avg_degree:.1f}) → 2 hops with mixed filtering"
        else:
            k_hops = 3
            reasoning = f"Low density ({avg_degree:.1f}) → 3 hops with broader filtering"

        # Pondération dynamique basée sur primary/secondary
        def get_relation_score(rel_type: str) -> float:
            if rel_type in primary_relations:
                return 0.9
            elif rel_type in secondary_relations:
                return 0.8
            else:
                return 0.5

        expanded_nodes: Set[str] = set(seed_nodes)
        current_nodes = set(seed_nodes)

        for hop in range(k_hops):
            if not current_nodes:
                break

            expansion_query = """
            MATCH (current)-[r]-(next)
            WHERE current.name IN $current_nodes
            RETURN DISTINCT next.name AS name, type(r) AS rel_type
            """

            new_nodes = set()
            try:
                result = self.graph.query(expansion_query, {"current_nodes": list(current_nodes)})
                for record in result:
                    rel_type = record.get("rel_type")
                    score = get_relation_score(rel_type)
                    node_name = record.get("name")

                    if score >= 0.6 and node_name not in expanded_nodes:
                        new_nodes.add(node_name)
                        expanded_nodes.add(node_name)

            except Exception as e:
                print(f"Expansion error at hop {hop}: {e}")
                break

            current_nodes = new_nodes

        return expanded_nodes, reasoning

    def find_connecting_paths(self, nodes: List[str], primary_relations: List[str],
                              secondary_relations: List[str], max_length: int = 3) -> List[Dict]:
        """
        FIX: Correction de la requête de chemins
        """
        if len(nodes) < 2:
            return []

        paths = []

        # FIX: Construction correcte des conditions CASE
        primary_conditions = [f"type(r) = '{rel}'" for rel in primary_relations]
        secondary_conditions = [f"type(r) = '{rel}'" for rel in secondary_relations]

        primary_case = " OR ".join(primary_conditions) if primary_conditions else "false"
        secondary_case = " OR ".join(secondary_conditions) if secondary_conditions else "false"

        path_query = f"""
        MATCH path = (start)-[*1..{max_length}]-(end)
        WHERE start.name IN $nodes AND end.name IN $nodes AND start.name <> end.name
        WITH path, start.name as start_name, end.name as end_name, length(path) as path_length,
             [r in relationships(path) | type(r)] as rel_types
        RETURN start_name, end_name, path_length, rel_types,
               reduce(score = 1.0, r in relationships(path) |
                   score * CASE
                       WHEN ({primary_case}) THEN 0.9
                       WHEN ({secondary_case}) THEN 0.8
                       ELSE 0.6
                   END
               ) as path_score
        ORDER BY path_score DESC, path_length ASC
        """

        try:
            result = self.graph.query(path_query, params={"nodes": nodes})
            for record in result:
                paths.append({
                    'start': record.get('start_name'),
                    'end': record.get('end_name'),
                    'length': record.get('path_length'),
                    'relationships': record.get('rel_types', []),
                    'score': record.get('path_score', 0.5)
                })
        except Exception as e:
            print(f"Error finding paths: {e}")

        return paths

    def calculate_subgraph_relevance(self, nodes: List[str], relationships: List[Dict],
                                   entity_matches: List[EntityMatch]) -> float:
        """Calculate relevance score for the retrieved subgraph"""
        # Factor 1: Entity match quality
        match_score = 0.0
        if entity_matches:
            total_matches = sum(len(em.matched_nodes) for em in entity_matches)
            if total_matches > 0:
                weighted_score = sum(
                    max(em.match_scores) * len(em.matched_nodes)
                    for em in entity_matches if em.match_scores
                )
                match_score = weighted_score / total_matches

        # Factor 2: Graph connectivity
        connectivity_score = 0.0
        if len(nodes) > 1 and relationships:
            connectivity_score = math.log1p(len(relationships)) / math.log1p(len(nodes)**2)

        # Factor 3: Subgraph size appropriateness
        size_score = 0.0
        if 3 <= len(nodes) <= 50:
            size_score = 1.0
        elif len(nodes) < 3:
            size_score = len(nodes) / 3.0
        else:
            size_score = max(0.1, 50.0 / len(nodes))

        # Weighted combination
        total_score = (0.6 * match_score + 0.3 * connectivity_score + 0.1 * size_score)
        return min(total_score, 1.0)

    def retrieve_subgraph_advanced(self, entities: List[str], question: str,
                                 min_relevance: float = 0.3) -> RetrievalResult:
        """
        Retrieve a scored subgraph for already-extracted entities.

        Pipeline: classify relations for the question, match each entity to
        graph nodes, find paths between the seeds, expand around the seeds,
        collect the relationships inside the expanded node set, then score
        the result.

        Args:
            entities: Entity strings extracted from the question.
            question: Original question, used for relation classification.
            min_relevance: Score at or above which the reasoning text says
                "Good relevance"; it does not filter the result.

        Returns:
            RetrievalResult for the expanded subgraph, or an empty result
            with score 0.0 when no entity matched any node.
        """
        # Step 1: classify relations based on the question
        classification = classify_relations(question, RELATIONS)
        primary_relations = classification.get("primary_relations", [])
        secondary_relations = classification.get("secondary_relations", [])

        print(f"[DEBUG] Primary Relations: {primary_relations}")
        print(f"[DEBUG] Secondary Relations: {secondary_relations}")

        # Step 2: match entities; each entity contributes its first five matches as seeds
        entity_matches = []
        seeds = []
        for entity in entities:
            matched = self.find_matching_nodes_advanced(entity)
            entity_matches.append(matched)
            seeds.extend(matched.matched_nodes[:5])

        if not seeds:
            return RetrievalResult(
                nodes=[],
                relationships=[],
                entity_matches=entity_matches,
                subgraph_score=0.0,
                path_info={},
                reasoning="No matching nodes found in graph"
            )

        # Step 3: paths connecting the seeds
        connecting_paths = self.find_connecting_paths(
            seeds, primary_relations, secondary_relations, max_length=3
        )
        path_info = {'paths': connecting_paths, 'seed_nodes': seeds}

        # Step 4: adaptive expansion around the seeds
        expanded_nodes, expansion_reasoning = self.adaptive_hop_expansion(
            seeds, entities, primary_relations, secondary_relations
        )

        # Step 5: relationships whose endpoints both lie in the expanded set
        relationships = []
        if len(expanded_nodes) > 1:
            subgraph_query = """
            MATCH (n)-[r]-(m)
            WHERE n.name IN $expanded_nodes AND m.name IN $expanded_nodes
            RETURN DISTINCT n.name as source, type(r) as relationship, m.name as target,
                  CASE
                      WHEN type(r) IN $primary_relations THEN 1.0
                      WHEN type(r) IN $secondary_relations THEN 0.9
                      ELSE 0.7
                  END as rel_importance
            ORDER BY rel_importance DESC
            LIMIT 100
            """
            subgraph_params = {
                "expanded_nodes": list(expanded_nodes),
                "primary_relations": primary_relations,
                "secondary_relations": secondary_relations
            }

            try:
                for row in self.graph.query(subgraph_query, params=subgraph_params):
                    relationships.append({
                        'source': row.get('source'),
                        'target': row.get('target'),
                        'relationship': row.get('relationship'),
                        'importance': row.get('rel_importance', 0.5)
                    })
            except Exception as e:
                print(f"Error getting relationships: {e}")

        # Step 6: final scoring of the subgraph
        node_list = list(expanded_nodes)
        subgraph_score = self.calculate_subgraph_relevance(
            node_list, relationships, entity_matches
        )

        if subgraph_score >= min_relevance:
            reasoning = f"Good relevance ({subgraph_score:.2f}). {expansion_reasoning}"
        else:
            reasoning = f"Low relevance ({subgraph_score:.2f}). {expansion_reasoning}"

        return RetrievalResult(
            nodes=list(expanded_nodes),
            relationships=relationships,
            entity_matches=entity_matches,
            subgraph_score=subgraph_score,
            path_info=path_info,
            reasoning=reasoning
        )

    # Main interface method that ties the retriever to the other pipeline components
    def retrieve_for_question(self, question: str, min_relevance: float = 0.3) -> RetrievalResult:
        """
        Run entity extraction and subgraph retrieval for a raw question.

        Args:
            question: The user's question.
            min_relevance: Minimum relevance score, forwarded to
                ``retrieve_subgraph_advanced``.

        Returns:
            RetrievalResult describing the relevant subgraph.
        """
        # Step 1: extract entities with the module-level `entity_chain`
        extraction = entity_chain.invoke({"question": question})

        # Step 2: hand the entity names to the advanced retrieval method
        return self.retrieve_subgraph_advanced(extraction.names, question, min_relevance)

In [None]:
# ==================== COMPLETE TEST PIPELINE ====================
# End-to-end smoke test: entity extraction -> graph retrieval -> printed report.
# The loop variables below stay bound at module level after the cell runs.

# Step 1: Build the retriever on the module-level `graph` handle
retriever = GraphRetriever(graph)

# Step 2: Questions to run through the pipeline
test_questions = [
    "What are the most effective medications for controlling HbA1c levels in newly diagnosed Type 2 diabetes patients?"
]

# Step 3: Run the complete pipeline for each question (numbered from 1)
for i, question in enumerate(test_questions, 1):
    print(f"\n TEST {i}: {question}")
    print("="*80)

    # Entity extraction through the module-level `entity_chain`
    entity_result = entity_chain.invoke({"question": question})
    extracted_entities = entity_result.names
    print(f"📋 Extracted Entities: {extracted_entities}")

    # Graph retrieval for the extracted entities
    retrieval_result = retriever.retrieve_subgraph_advanced(extracted_entities, question)

    # Summary of the retrieved subgraph
    print(f"RETRIEVAL RESULTS:")
    print(f"   • Relevance Score: {retrieval_result.subgraph_score:.2f}/1.0")
    print(f"   • Strategy: {retrieval_result.reasoning}")
    print(f"   • Nodes Retrieved: {len(retrieval_result.nodes)}")
    print(f"   • Relationships Found: {len(retrieval_result.relationships)}")

    # Per-entity matches: at most the first five nodes with their score and strategy
    print(f"\n ENTITY MATCHING QUALITY:")
    for match in retrieval_result.entity_matches:
        print(f"   '{match.entity}' → {len(match.matched_nodes)} matches")
        for node, score, match_type in zip(match.matched_nodes[:5], match.match_scores[:5], match.match_types[:5]):
            print(f"      {node} (score: {score:.2f}, type: {match_type})")

    # Ten most important relationships (importance defaults to 0 when missing)
    print(f"\n KEY RELATIONSHIPS:")
    top_rels = sorted(retrieval_result.relationships, key=lambda x: x.get('importance', 0), reverse=True)[:10]
    for rel in top_rels:
        print(f"   {rel['source']} --[{rel['relationship']}]--> {rel['target']}")

    # First five connecting paths, shown only when any were found
    if retrieval_result.path_info['paths']:
        print(f"\n DISCOVERED PATHS:")
        for path in retrieval_result.path_info['paths'][:5]:
            print(f"   {path['start']} → {path['end']} (length: {path['length']}, score: {path['score']:.2f})")

    # Readiness flag uses a fixed 0.5 cutoff on the subgraph score
    print(f"\n✅ READY FOR ANSWER GENERATION: {retrieval_result.subgraph_score >= 0.5}")
    print("-"*80)

In [None]:
Answer generation

In [None]:
# Required imports
from typing import List, Dict, Any
from langchain_core.prompts import ChatPromptTemplate
from langchain_groq import ChatGroq
from dataclasses import dataclass

@dataclass
class AnswerResult:
    """Generated answer together with its confidence and provenance metadata."""
    answer: str  # answer text, or an "Error generating answer: ..." message on failure
    confidence: float  # heuristic confidence; 0.0 when generation failed
    knowledge_used: Dict[str, Any]  # retrieval counts and score; empty dict on failure
    reasoning_steps: List[str]  # leading sentences of the answer kept as reasoning steps

class GraphAnswerGenerator:
    def __init__(self):
        """Create the answer generator backed by Llama 3.3 through ChatGroq at temperature 0."""
        groq_settings = {
            "model": "llama-3.3-70b-versatile",
            "temperature": 0,
        }
        self.llm = ChatGroq(**groq_settings)

    def format_graph_knowledge(self, retrieval_result) -> str:
        """
        Convert graph retrieval results into structured text format

        Args:
            retrieval_result: Output from your ImprovedNeo4jGraphRetriever

        Returns:
            Formatted knowledge string
        """
        knowledge_sections = []

        # Section 1: Key Medical Entities Found
        if retrieval_result.nodes:
            entities_text = "MEDICAL ENTITIES:\n"
            # Group nodes by likely type (you could enhance this with node properties)
            for i, node in enumerate(retrieval_result.nodes[:], 1):  # Limit for context
                entities_text += f"- {node}\n"
            knowledge_sections.append(entities_text)

        # Section 2: Medical Relationships
        if retrieval_result.relationships:
            relationships_text = "MEDICAL RELATIONSHIPS:\n"
            # Sort by importance if available
            sorted_rels = sorted(
                retrieval_result.relationships,
                key=lambda x: x.get('importance', 0.5),
                reverse=True
            )

            for rel in sorted_rels[:20]:  # Top 20 relationships
                source = rel.get('source', 'Unknown')
                target = rel.get('target', 'Unknown')
                rel_type = rel.get('relationship', 'RELATED_TO')

                # Convert relationship types to natural language
                rel_description = self._relationship_to_text(rel_type)
                relationships_text += f"- {source} {rel_description} {target}\n"

            knowledge_sections.append(relationships_text)

        # Section 3: Entity Matching Context
        if retrieval_result.entity_matches:
            matching_text = "ENTITY CONTEXT:\n"
            for match in retrieval_result.entity_matches:
                if match.matched_nodes:
                    matching_text += f"- Query term '{match.entity}' relates to: {', '.join(match.matched_nodes[:3])}\n"
            knowledge_sections.append(matching_text)

        # Section 4: Connection Paths 
        if retrieval_result.path_info and retrieval_result.path_info.get('paths'):
            paths_text = "KEY CONNECTIONS:\n"
            for path in retrieval_result.path_info['paths'][:5]:
                start = path.get('start', '')
                end = path.get('end', '')
                relationships = path.get('relationships', [])
                path_desc = ' → '.join(relationships) if relationships else 'connected to'
                paths_text += f"- {start} is {path_desc} {end}\n"
            knowledge_sections.append(paths_text)

        return "\n".join(knowledge_sections)

    def _relationship_to_text(self, rel_type: str) -> str:
        """
        Convertit un type de relation brut (Neo4j) en texte naturel :
        - Met tout en minuscules
        - Remplace les underscores '_' par des espaces
        """

        return rel_type.lower().replace('_', ' ')

    def create_answer_prompt(self) -> ChatPromptTemplate:
        """Create the prompt template for answer generation"""

        prompt = ChatPromptTemplate.from_messages([
            ("system", """You are a medical expert specializing in Type 2 Diabetes. Your role is to provide accurate, helpful answers based on the provided medical knowledge graph information.

INSTRUCTIONS:
1. Answer the question using ONLY the provided medical knowledge
2. Be precise , factual and explicative - avoid speculation
3. Organize your answer logically (causes → mechanisms → treatments → outcomes)
4. If the knowledge is insufficient, clearly state what information is missing
5. Use medical terminology appropriately but explain complex concepts
6. Focus on practical, clinically relevant information

RESPONSE FORMAT:
- Direct answer first
- Supporting details with mechanisms/relationships
- Practical implications or recommendations
- Note any limitations in the available knowledge"""),

            ("human", """Based on the following medical knowledge from a Type 2 Diabetes knowledge graph, please answer this question:

QUESTION: {question}

AVAILABLE MEDICAL KNOWLEDGE:
{formatted_knowledge}

RETRIEVAL QUALITY: {confidence_score}/1.0
REASONING CONTEXT: {retrieval_reasoning}

Please provide a comprehensive answer based on this knowledge.""")
        ])

        return prompt

    def generate_answer(self, question: str, retrieval_result) -> AnswerResult:
        """
        Generate an answer from the structured knowledge of a graph retrieval.

        Args:
            question: Original user question.
            retrieval_result: Output of ``GraphRetriever``.

        Returns:
            AnswerResult with the generated answer and metadata. If the LLM
            call or the post-processing raises, an AnswerResult carrying the
            error message, confidence 0.0 and empty ``knowledge_used`` is
            returned instead.
        """
        # Formatting the knowledge and building the prompt happen outside the
        # try block, so errors raised there propagate to the caller.
        knowledge_text = self.format_graph_knowledge(retrieval_result)
        answer_prompt = self.create_answer_prompt()

        llm_inputs = {
            "question": question,
            "formatted_knowledge": knowledge_text,
            "confidence_score": retrieval_result.subgraph_score,
            "retrieval_reasoning": retrieval_result.reasoning
        }

        try:
            # Pipe the prompt into the model and read the reply text
            llm_reply = (answer_prompt | self.llm).invoke(llm_inputs)
            generated = llm_reply.content

            # Heuristic confidence from retrieval quality and text lengths
            confidence = self._calculate_answer_confidence(
                retrieval_result, knowledge_text, generated
            )

            # Leading sentences of the answer serve as reasoning steps
            steps = self._extract_reasoning_steps(generated)

            usage_summary = {
                "entities_count": len(retrieval_result.nodes),
                "relationships_count": len(retrieval_result.relationships),
                "retrieval_score": retrieval_result.subgraph_score,
                "entity_matches": len(retrieval_result.entity_matches)
            }
            return AnswerResult(
                answer=generated,
                confidence=confidence,
                knowledge_used=usage_summary,
                reasoning_steps=steps
            )

        except Exception as e:
            return AnswerResult(
                answer=f"Error generating answer: {str(e)}",
                confidence=0.0,
                knowledge_used={},
                reasoning_steps=["Error occurred during generation"]
            )

    def _calculate_answer_confidence(self, retrieval_result, formatted_knowledge: str, answer: str) -> float:
        """Calculate confidence score for the generated answer"""

        # Factor 1: Retrieval quality
        retrieval_confidence = retrieval_result.subgraph_score

        # Factor 2: Knowledge richness
        knowledge_richness = min(len(formatted_knowledge) / 1000, 1.0)  # Normalize to 0-1

        # Factor 3: Answer completeness (basic heuristic)
        answer_completeness = min(len(answer) / 500, 1.0) if len(answer) > 50 else 0.3

        # Factor 4: Entity coverage
        entities_covered = len(retrieval_result.entity_matches) / max(len(retrieval_result.entity_matches), 1)

        # Weighted average
        total_confidence = (
            0.4 * retrieval_confidence +
            0.3 * knowledge_richness +
            0.2 * answer_completeness +
            0.1 * entities_covered
        )

        return min(total_confidence, 1.0)

    def _extract_reasoning_steps(self, answer: str) -> List[str]:
        """Extract basic reasoning steps from the answer"""
        # Simple implementation - you could make this more sophisticated
        sentences = answer.split('. ')
        reasoning_steps = []

        for sentence in sentences[:5]:  # Take first 5 sentences as reasoning steps
            if len(sentence.strip()) > 20:
                reasoning_steps.append(sentence.strip())

        return reasoning_steps

# ==================== COMPLETE PIPELINE TEST ====================

def test_complete_pipeline(neo4j_graph):
    """
    Run the end-to-end pipeline with the basic ``Neo4jGraphRetriever``.

    Extracts entities from a fixed question, retrieves subgraph information
    with a 2-hop budget, generates an answer and prints it.

    Args:
        neo4j_graph: Graph handle passed to ``Neo4jGraphRetriever``.

    Returns:
        The result of ``GraphAnswerGenerator.generate_answer``.
    """
    # NOTE(review): generate_answer reads RetrievalResult-style attributes
    # (nodes, relationships, entity_matches, ...). Confirm that
    # Neo4jGraphRetriever.retrieve_subgraph_info returns a compatible object.
    basic_retriever = Neo4jGraphRetriever(neo4j_graph)
    generator = GraphAnswerGenerator()

    test_question = "What are the most effective medications for controlling HbA1c levels in newly diagnosed Type 2 diabetes patients?"
    answer_rule = "-" * 50

    print(f"QUESTION: {test_question}")
    print("=" * 80)

    # Step 1: extract entities with the module-level `entity_chain`
    entities = entity_chain.invoke({"question": test_question}).names
    print(f"ENTITIES: {entities}")

    # Step 2: retrieve the relevant subgraph information
    subgraph_info = basic_retriever.retrieve_subgraph_info(entities, k_hops=2)

    # Step 3: generate and display the answer
    outcome = generator.generate_answer(test_question, subgraph_info)

    print("\n GENERATED ANSWER:")
    print(answer_rule)
    print(outcome.answer)
    print(answer_rule)

    return outcome


def test_complet_pipeline(neo4j_graph):
    """
    Run the end-to-end pipeline with the advanced ``GraphRetriever``.

    Extracts entities from a fixed question, retrieves the relevant
    subgraph, generates an answer and prints it.

    Args:
        neo4j_graph: Graph handle passed to ``GraphRetriever``.

    Returns:
        The AnswerResult produced by ``GraphAnswerGenerator.generate_answer``.
    """
    # Initialize components
    retriever = GraphRetriever(neo4j_graph)
    answer_generator = GraphAnswerGenerator()

    # Test question
    test_question = "what are the symptoms of type 2 diabetes?"

    print(f"🔍 QUESTION: {test_question}")
    print("="*80)

    # Step 1: Extract entities from the test question itself. The previous
    # code read a global `question`, i.e. whatever an earlier cell left
    # behind (or a NameError when nothing had).
    entities = entity_chain.invoke({"question": test_question}).names

    print(f" ENTITIES: {entities}")

    # Step 2: Retrieve the knowledge graph for those entities.
    # retrieve_subgraph_advanced takes no `k_hops` argument (the hop budget
    # is chosen adaptively), so none is passed.
    retrieval_result = retriever.retrieve_subgraph_advanced(entities, test_question)

    # Step 3: Generate answer
    answer_result = answer_generator.generate_answer(test_question, retrieval_result)

    print(f"\n GENERATED ANSWER:")
    print("-"*50)
    print(answer_result.answer)
    print("-"*50)

    return answer_result


# Run the end-to-end pipeline on the module-level `graph`. The bare `result`
# expression displays the returned value when this runs as a notebook cell.
result = test_complete_pipeline(graph)
result

In [None]:
Vector retriever

In [None]:
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import Neo4jVector
# Biomedical sentence-embedding model; embeddings are L2-normalised on encode.
embedding_model = HuggingFaceEmbeddings(
    model_name="pritamdeka/BioBERT-mnli-snli-scinli-scitail-mednli-stsb",
    encode_kwargs={"normalize_embeddings": True}
)

# Vector store over `Document` nodes: text is read from the `text` property
# and vectors from the `embedding` property. `url`, `username` and `password`
# are expected to be defined earlier in the notebook.
vector_store = Neo4jVector(
    embedding=embedding_model,
    url=url,
    username=username,
    password=password,
    node_label="Document",
    text_node_property="text",
    embedding_node_property="embedding"
)

# `score_threshold` is only applied by the "similarity_score_threshold"
# search type; with plain "similarity" the 0.7 cutoff was silently ignored.
# NOTE: this rebinds `retriever`, which an earlier cell bound to a
# GraphRetriever instance.
retriever = vector_store.as_retriever(
    search_type="similarity_score_threshold",
    search_kwargs={
        "k": 10,
        "score_threshold": 0.7
    }
)

In [None]:
from langchain.chains import RetrievalQA

# --- Custom prompt ---
# NOTE(review): PromptTemplate and ChatGroq are not imported in this cell;
# they must already be in scope from another cell — confirm execution order.
prompt_template = """You are a medical assistant specialized in type 2 diabetes.
Use the retrieved context to answer the question factually.
If the answer is not in the context, say you don't know.

Context: {context}
Question: {question}
Answer:"""

prompt = PromptTemplate(
    template=prompt_template,
    input_variables=["context", "question"]
)

# "api_key" is a placeholder; a real Groq key has to be supplied here.
llm = ChatGroq(
    groq_api_key="api_key",   
    model="llama-3.3-70b-versatile",       
    temperature=0
)

# --- QA chain ---
# RetrievalQA chain wired to the module-level `retriever` and the custom prompt.
qa_chain = RetrievalQA.from_chain_type(
    llm=llm,
    retriever=retriever,
    chain_type="stuff",
    chain_type_kwargs={"prompt": prompt}
)


# --- Test: first warning signs and symptoms ---
query = "What are the first warning signs and symptoms of type 2 diabetes?"
answer = qa_chain.run(query)
print(answer)


In [None]:
# --- Test: ask `qa_chain` about risk factors and print the answer ---
query = "What are the risk factors for type 2 diabetes?"
answer = qa_chain.run(query)
print(answer)

In [None]:
# --- Test: ask `qa_chain` about organs affected by insulin resistance ---
# The query is sent without the pasted list prefix ("5." + tab), which was
# noise for retrieval; this matches the wording used elsewhere in the file.
query = "Which organs are most affected by insulin resistance in T2D?"
answer = qa_chain.run(query)
print(answer)

In [None]:
# --- Test: ask `qa_chain` about medications for HbA1c control ---
query = "What are the most effective medications for controlling HbA1c levels in newly diagnosed Type 2 diabetes patients?"
answer = qa_chain.run(query)
print(answer)

In [None]:
# --- Test: ask `qa_chain` about GLP-1 receptor agonist mechanisms ---
query = "What are the mechanisms by which GLP-1 receptor agonists improve cardiovascular outcomes in patients with type 2 diabetes?"
answer = qa_chain.run(query)
print(answer)

In [None]:
# --- Test: ask `qa_chain` why lifestyle modification is recommended ---
query = "Why is lifestyle modification (diet and exercise) often recommended before or alongside medication for managing type 2 diabetes?"
answer = qa_chain.run(query)
print(answer)

In [None]:
# --- Test: ask `qa_chain` about the recommended diet ---
query = "What diet is recommended for someone with type 2 diabetes?"
answer = qa_chain.run(query)
print(answer)

In [None]:
# --- Test: ask `qa_chain` about blood glucose monitoring ---
query = "How should blood glucose be monitored?"
answer = qa_chain.run(query)
print(answer)

In [None]:
Advanced vector retriever

In [None]:
import numpy as np
from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import LLMChainExtractor
from langchain.retrievers.multi_query import MultiQueryRetriever
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.retrievers.document_compressors import EmbeddingsFilter
from langchain_community.document_transformers import EmbeddingsRedundantFilter
from langchain.retrievers.document_compressors import DocumentCompressorPipeline
from langchain.document_transformers import LongContextReorder
from langchain.prompts import PromptTemplate
from langchain.chains import RetrievalQA
from langchain_groq import ChatGroq
from sentence_transformers import CrossEncoder
import logging

class AdvancedRAGRetriever:
    """Retrieval toolkit layered on top of a vector store.

    Offers retriever factories (MMR search, multi-query expansion,
    contextual compression), cross-encoder re-ranking, a re-ranked
    similarity search and a few retrieval-quality metrics.
    """

    def __init__(self, vector_store, llm, embedding_model):
        """Keep the collaborators and load the cross-encoder re-ranker.

        Args:
            vector_store: Vector store queried by the retrievers.
            llm: Language model used for query expansion and extraction.
            embedding_model: Embedding model used by the filters and the
                diversity metric.
        """
        self.vector_store = vector_store
        self.llm = llm
        self.embedding_model = embedding_model
        self.cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')

        # Logging setup (configures the root logger at INFO level)
        logging.basicConfig(level=logging.INFO)
        self.logger = logging.getLogger(__name__)

    def create_vector_retriever(self, k=10):
        """Build an MMR (Maximum Marginal Relevance) retriever.

        Args:
            k: Base document count; the retriever is configured with
                ``2 * k`` results drawn from ``4 * k`` candidates.

        Returns:
            The retriever produced by ``vector_store.as_retriever``.
        """
        mmr_settings = {
            "k": k * 2,          # over-fetch so later re-ranking has material
            "lambda_mult": 0.7,  # relevance vs. diversity trade-off
            "fetch_k": k * 4,    # candidate pool considered before MMR
        }
        return self.vector_store.as_retriever(
            search_type="mmr",
            search_kwargs=mmr_settings,
        )

    def create_compression_retriever(self, base_retriever, k=10):
        """Wrap ``base_retriever`` with a contextual-compression pipeline.

        The pipeline runs, in order: an embedding-similarity filter
        (threshold 0.76, ``2 * k`` documents), a redundancy filter
        (threshold 0.95) and an LLM-based passage extractor.

        Args:
            base_retriever: Retriever whose results are compressed.
            k: Base document count for the similarity filter.

        Returns:
            A ``ContextualCompressionRetriever``.
        """
        stages = [
            # Keep documents similar enough to the query
            EmbeddingsFilter(
                embeddings=self.embedding_model,
                similarity_threshold=0.76,
                k=k * 2,
            ),
            # Drop near-duplicates
            EmbeddingsRedundantFilter(
                embeddings=self.embedding_model,
                similarity_threshold=0.95,
            ),
            # Let the LLM extract the relevant passages
            LLMChainExtractor.from_llm(self.llm),
        ]
        return ContextualCompressionRetriever(
            base_compressor=DocumentCompressorPipeline(transformers=stages),
            base_retriever=base_retriever,
        )

    def create_multi_query_retriever(self, base_retriever):
        """Wrap ``base_retriever`` so several rewordings of the question are searched.

        Args:
            base_retriever: Retriever queried with each generated variant.

        Returns:
            A ``MultiQueryRetriever`` driven by ``self.llm``.
        """
        expansion_prompt = self._get_multi_query_prompt()
        return MultiQueryRetriever.from_llm(
            retriever=base_retriever,
            llm=self.llm,
            prompt=expansion_prompt,
        )

    def rerank_with_cross_encoder(self, query, documents, top_k=5):
        """Order ``documents`` by cross-encoder score and keep the best ``top_k``.

        Args:
            query: Query text paired with every document.
            documents: Objects exposing ``page_content``.
            top_k: Maximum number of documents returned.

        Returns:
            ``documents`` unchanged when it is empty; otherwise a list of the
            highest-scoring documents, best first.
        """
        if not documents:
            return documents

        # One (query, text) pair per document, scored in a single call
        relevance = self.cross_encoder.predict(
            [(query, candidate.page_content) for candidate in documents]
        )

        # Highest score first
        ranked = sorted(
            zip(documents, relevance),
            key=lambda pair: pair[1],
            reverse=True,
        )
        return [candidate for candidate, _ in ranked[:top_k]]

    def create_complete_advanced_retriever(self, k=10):
        """Chain MMR search, multi-query expansion and contextual compression.

        Args:
            k: Base document count; the vector retriever is created with
                ``k * 2`` and the compression stage with ``k``.

        Returns:
            The compression retriever at the end of the chain.
        """
        # 1. base vector retriever -> 2. multi-query -> 3. compression
        mmr_retriever = self.create_vector_retriever(k=k*2)
        expanded_retriever = self.create_multi_query_retriever(mmr_retriever)
        return self.create_compression_retriever(expanded_retriever, k=k)

    def advanced_similarity_search(self, query, k=5, include_reranking=True):
        """Similarity search with optional cross-encoder re-ranking.

        Fetches ``3 * k`` scored hits, narrows them to ``k`` (by re-ranking or
        by truncation) and passes the result through ``LongContextReorder``.

        Args:
            query: Search text.
            k: Number of documents kept.
            include_reranking: Re-rank with the cross-encoder when true.

        Returns:
            The reordered documents.
        """
        # Over-fetch so the re-ranker has a larger pool to choose from
        scored_hits = self.vector_store.similarity_search_with_score(
            query, k=k*3, score_threshold=0.7
        )
        candidates = [hit for hit, _ in scored_hits]

        if include_reranking and candidates:
            selected = self.rerank_with_cross_encoder(query, candidates, k)
        else:
            selected = candidates[:k]

        # Rearrange for a better context layout
        return LongContextReorder().transform_documents(selected)

    def evaluate_retrieval_quality(self, query, retrieved_docs, ground_truth=None):
        """Compute simple quality metrics for a set of retrieved documents.

        Args:
            query: Query the documents were retrieved for.
            retrieved_docs: Objects exposing ``page_content``.
            ground_truth: Accepted but not used.

        Returns:
            ``{"error": "No documents retrieved"}`` when ``retrieved_docs`` is
            empty. Otherwise a dict with ``avg_relevance``, ``min_relevance``,
            ``max_relevance`` (cross-encoder scores), ``avg_doc_length``,
            ``total_content_length`` and, when more than one document is
            given, ``diversity_score`` (1 minus the mean pairwise cosine
            similarity of the document embeddings).
        """
        if not retrieved_docs:
            return {"error": "No documents retrieved"}

        metrics = {}

        # Diversity: only meaningful with at least two documents
        if len(retrieved_docs) > 1:
            vectors = [
                self.embedding_model.embed_query(item.page_content)
                for item in retrieved_docs
            ]
            total = len(vectors)
            pairwise = [
                np.dot(vectors[a], vectors[b]) / (
                    np.linalg.norm(vectors[a]) * np.linalg.norm(vectors[b])
                )
                for a in range(total)
                for b in range(a + 1, total)
            ]
            metrics['diversity_score'] = 1 - np.mean(pairwise) if pairwise else 0

        # Relevance according to the cross-encoder
        relevance = self.cross_encoder.predict(
            [(query, item.page_content) for item in retrieved_docs]
        )
        metrics['avg_relevance'] = np.mean(relevance)
        metrics['min_relevance'] = np.min(relevance)
        metrics['max_relevance'] = np.max(relevance)

        # Document sizes in characters
        lengths = [len(item.page_content) for item in retrieved_docs]
        metrics['avg_doc_length'] = np.mean(lengths)
        metrics['total_content_length'] = sum(lengths)

        return metrics

    def _get_multi_query_prompt(self):
        """Return the prompt asking the LLM for three alternative phrasings of a question."""
        return PromptTemplate(
            input_variables=["question"],
            template="""You are an expert medical assistant. Given a medical question about diabetes,
            generate 3 different versions of this question that could help retrieve relevant information.
            Make the questions more specific and include relevant medical terminology when appropriate.

            Original question: {question}

            Alternative questions:
            1.
            2.
            3.
            """
        )


# === UTILISATION AVANCÉE ===
def setup_advanced_rag_system(vector_store, llm, embedding_model):
    """Assemble the advanced RAG question-answering chain (no BM25/TF-IDF).

    Builds an ``AdvancedRAGRetriever``, asks it for its complete retriever
    (k=5) and, if that fails for any reason, falls back to a plain
    ``vector_store.as_retriever`` with k=5. The chosen retriever is then
    wired into a ``RetrievalQA`` chain that also returns source documents.

    Args:
        vector_store: Vector store to search.
        llm: Language model used by the retriever and the QA chain.
        embedding_model: Embedding model handed to the retriever toolkit.

    Returns:
        Tuple ``(advanced_qa_chain, advanced_retriever)``.
    """
    toolkit = AdvancedRAGRetriever(vector_store, llm, embedding_model)

    # Best effort: prefer the full pipeline, degrade to the basic retriever
    try:
        chain_retriever = toolkit.create_complete_advanced_retriever(k=5)
    except Exception as e:
        print(f" Erreur lors de la création du retriever complet: {e}")
        print("Utilisation du retriever de base...")
        chain_retriever = vector_store.as_retriever(search_kwargs={"k": 5})

    # Answer prompt; expects `context` and `question`
    answer_template = """
    You are an expert medical assistant specialized in type 2 diabetes.
    Use ONLY and STRICTLY the provided medical context to answer questions accurately and comprehensively.

    GUIDELINES:
    - Prioritize information from the most relevant sources
    - Provide detailed explanations with medical reasoning
    - If information is incomplete, clearly state what's missing
    - Include relevant context and background information
    - Use professional medical language but keep it accessible
    - If an information is not stated in the provided context stated it clearly

    RELEVANT MEDICAL CONTEXT:
    {context}

    PATIENT QUESTION: {question}

    COMPREHENSIVE MEDICAL ANSWER:
    """
    answer_prompt = PromptTemplate(
        input_variables=["context", "question"],
        template=answer_template,
    )

    # QA chain over the selected retriever
    qa_chain_with_sources = RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=chain_retriever,
        chain_type_kwargs={"prompt": answer_prompt},
        return_source_documents=True,
    )

    return qa_chain_with_sources, toolkit


# --- Test Function ---

def ask_advanced_question(advanced_qa_chain, advanced_retriever, question):
    """Send one question to the advanced RAG chain and print the outcome.

    Args:
        advanced_qa_chain: Callable taking ``{"query": question}`` and
            returning a mapping with a ``result`` entry (and optionally
            ``source_documents``).
        advanced_retriever: Object offering ``evaluate_retrieval_quality``;
            when falsy, no relevance score is computed.
        question: Question text.

    Returns:
        The chain's result, or ``None`` if anything raised along the way
        (the error is printed, not propagated).
    """
    print(f"\n Question: {question}")

    try:
        response = advanced_qa_chain({"query": question})
        print(f" Réponse améliorée: {response['result']}")

        # Score the retrieval only when a retriever and sources are available
        wants_metrics = advanced_retriever and response.get('source_documents')
        if wants_metrics:
            quality = advanced_retriever.evaluate_retrieval_quality(
                question, response['source_documents']
            )
            relevance = quality.get('avg_relevance', 0)
            print(f" Score de pertinence: {relevance:.3f}")
    except Exception as e:
        print(f" Erreur lors de la réponse: {e}")
        return None

    return response


# === USAGE EXAMPLE ===

if __name__ == "__main__":
    # Base configuration ("api_key" is a placeholder for a real Groq key)
    GROQ_API_KEY = "api_key"

    llm = ChatGroq(
        groq_api_key=GROQ_API_KEY,
        # Same model id as the other cells of this file, instead of the
        # older "llama3-70b-8192" identifier.
        model_name="llama-3.3-70b-versatile",
        temperature=0,
        top_p=0.9,
        streaming=False
    )

    # Self-assignment: raises NameError here if no earlier cell created
    # `vector_store`.
    vector_store = vector_store
    embedding_model = HuggingFaceEmbeddings(
        model_name="pritamdeka/BioBERT-mnli-snli-scinli-scitail-mednli-stsb",
        encode_kwargs={"normalize_embeddings": True}
    )

    advanced_qa_chain, advanced_retriever = setup_advanced_rag_system(
        vector_store, llm, embedding_model
    )

    # Tests
    if advanced_qa_chain and advanced_retriever:
        ask_advanced_question(
            advanced_qa_chain,
            advanced_retriever,
            "What are the most effective medications for controlling HbA1c levels in newly diagnosed Type 2 diabetes patients?",
        )


In [None]:
# === USAGE EXAMPLE ===
if __name__ == "__main__":
    # Base configuration ("API_Key" is a placeholder for a real Groq key)
    GROQ_API_KEY = "API_Key"

    llm = ChatGroq(
        groq_api_key=GROQ_API_KEY,
        model_name="llama-3.3-70b-versatile",
        temperature=0,
        top_p=0.9,
        streaming=False
    )

    # Self-assignment: raises NameError here if no earlier cell created
    # `vector_store`.
    vector_store = vector_store
    embedding_model = HuggingFaceEmbeddings(
      model_name="pritamdeka/BioBERT-mnli-snli-scinli-scitail-mednli-stsb",
      encode_kwargs={"normalize_embeddings": True}
      )

    # Build the QA chain and the retriever toolkit
    advanced_qa_chain, advanced_retriever = setup_advanced_rag_system(
        vector_store, llm, embedding_model
    )

    # Tests
    if advanced_qa_chain and advanced_retriever:

         ask_advanced_question(advanced_qa_chain, advanced_retriever,
                              "Which organs are most affected by insulin resistance in T2D?")


In [None]:
Hybrid retriever

In [None]:
import numpy as np
from typing import Dict, Any, List

class HybridRetriever:
    """Combine vector-based and graph-based retrieval into one scored result."""

    def __init__(self, vector_retriever: AdvancedRAGRetriever, graph_retriever: GraphRetriever,
                 alpha: float = 0.7):
        """Store both retrievers and the fusion weight.

        Args:
            vector_retriever: ``AdvancedRAGRetriever`` instance.
            graph_retriever: ``GraphRetriever`` instance.
            alpha: Weight of the vector score in the fusion (documented range
                0.0 to 1.0; higher favours the vector side). The graph score
                is weighted by ``1 - alpha``.
        """
        self.vector_retriever = vector_retriever
        self.graph_retriever = graph_retriever
        self.alpha = alpha

    def retrieve(self, question: str, k: int = 5, min_graph_relevance: float = 0.3) -> Dict[str, Any]:
        """Run both retrievers and fuse their scores.

        Args:
            question: Question to retrieve context for.
            k: Number of documents requested from the vector retriever.
            min_graph_relevance: Passed to the graph retriever as
                ``min_relevance``.

        Returns:
            Dict with the keys:
              - ``hybrid_score``: ``alpha * vector_score + (1 - alpha) * graph_score``
              - ``vector_score``: ``avg_relevance`` reported for the vector
                documents, or 0.5 when that metric is absent
              - ``graph_score``: ``subgraph_score`` of the graph result, or 0.0
                when the graph result is falsy
              - ``vector_docs``: documents from the vector retriever
              - ``graph_result``: result object from the graph retriever
              - ``reasoning``: human-readable summary of the fusion
        """
        # 1. Vector side
        text_hits = self.vector_retriever.advanced_similarity_search(question, k=k)
        text_metrics = self.vector_retriever.evaluate_retrieval_quality(question, text_hits)
        text_score = text_metrics.get("avg_relevance", 0.5)

        # 2. Graph side
        subgraph = self.graph_retriever.retrieve_for_question(
            question, min_relevance=min_graph_relevance
        )
        if subgraph:
            subgraph_score = subgraph.subgraph_score
        else:
            subgraph_score = 0.0

        # 3. Weighted fusion
        # NOTE(review): assumes both scores live on comparable scales — confirm.
        blended = self.alpha * text_score + (1 - self.alpha) * subgraph_score

        explanation = (
            f"Fusion avec alpha={self.alpha:.2f}: "
            f"vector={text_score:.2f}, graph={subgraph_score:.2f}, "
            f"final={blended:.2f}"
        )

        # 4. Final result
        return {
            "hybrid_score": blended,
            "vector_score": text_score,
            "graph_score": subgraph_score,
            "vector_docs": text_hits,
            "graph_result": subgraph,
            "reasoning": explanation,
        }


# === USAGE EXAMPLE ===
# The example below is wrapped in a string literal, so none of it executes.
# NOTE(review): inside it, the lines creating `hybrid` and `result` are
# commented out while later lines read `result`; it would need those lines
# restored before being re-enabled.
'''if __name__ == "__main__":
    advanced_retriever = AdvancedRAGRetriever(
        vector_store=vector_store,
        llm=llm,
        embedding_model=embedding_model
    )
    graph_retriever = GraphRetriever(graph)

    #hybrid = HybridRetriever(advanced_retriever, graph_retriever, alpha=0.7)

    #question = "What are the most effective medications for lowering HbA1c in Type 2 diabetes?"
    #result = hybrid.retrieve(question, k=5)

    print("\n=== RESULTAT HYBRIDE ===")
    print(f"Hybrid Score: {result['hybrid_score']:.3f}")
    print(f"Vector Score: {result['vector_score']:.3f}")
    print(f"Graph Score: {result['graph_score']:.3f}")
    print("Reasoning:", result["reasoning"])

    # Afficher un résumé des documents vectoriels
    for i, doc in enumerate(result["vector_docs"], 1):
        print(f"\nDoc {i}:", doc.page_content[:200], "...")

    # Afficher un résumé du sous-graphe
    print("\nGraph Subgraph Score:", result["graph_result"].subgraph_score)
    print("Graph Nodes:", result["graph_result"].nodes[:10])
    print("Graph Reasoning:", result["graph_result"].reasoning)
'''

In [None]:
# Build the two retrievers from objects created in earlier cells
# (`vector_store`, `llm`, `embedding_model`, `graph`).
advanced_retriever = AdvancedRAGRetriever(
        vector_store=vector_store,
        llm=llm,
        embedding_model=embedding_model
    )
graph_retriever = GraphRetriever(graph)

# Hybrid retriever: alpha=0.7 weights the vector score at 0.7 and the graph score at 0.3.
hybrid = HybridRetriever(advanced_retriever, graph_retriever, alpha=0.7)

In [None]:
class HybridAnswerGenerator:
    """Answer questions by prompting an LLM with hybrid-retrieval context."""

    def __init__(self, hybrid_retriever, llm):
        """
        Args:
            hybrid_retriever: Object whose ``retrieve(question, k=...)`` returns
                a mapping with ``vector_docs`` and ``graph_result`` entries.
            llm: Callable taking the prompt string and returning the answer.
        """
        self.hybrid_retriever = hybrid_retriever
        self.llm = llm

    def __call__(self, question: str, k: int = 5) -> str:
        """Retrieve context for ``question``, build the prompt and return the LLM output.

        Args:
            question: Question to answer.
            k: Forwarded to the retriever as ``k``.

        Returns:
            Whatever ``self.llm`` returns for the assembled prompt.
        """
        # 1. Gather context from both retrieval channels
        bundle = self.hybrid_retriever.retrieve(question, k=k)

        context_text = "\n".join(doc.page_content for doc in bundle["vector_docs"])
        if bundle["graph_result"]:
            graph_notes = bundle["graph_result"].reasoning
        else:
            graph_notes = ""

        # 2. Assemble the prompt
        llm_input = f"""
        You are an expert medical assistant specializing in Type 2 Diabetes (T2D). Your task is to answer the patient's question accurately, using both textual medical context and graph-based knowledge.

GUIDELINES:
1. Prioritize information from the most reliable and relevant sources.
2. Provide clear, detailed explanations with medical reasoning.
3. Mention if information is missing or uncertain.
4. Include important background or context for the patient to understand.
5. Use professional yet accessible language for healthcare communication.
6. When applicable, reference medications, lifestyle interventions, or clinical guidelines.

QUESTION:
{question}

TEXTUAL MEDICAL CONTEXT:
{context_text}

GRAPH-BASED KNOWLEDGE SUMMARY:
{graph_notes}

INSTRUCTIONS:
- Combine insights from textual context and graph knowledge.
- Focus on actionable, evidence-based recommendations.
- Avoid speculation; clearly state any uncertainties.

FINAL ANSWER:
        """

        # 3. Hand the prompt to the LLM and return its output
        return self.llm(llm_input)


In [None]:
from groq import Groq

# Groq API client; "API_Key" is a placeholder for a real key.
client = Groq(api_key="API_Key")

def llama3_llm(prompt: str) -> str:
    """Send ``prompt`` as a single user message through the module-level Groq ``client``.

    Uses the ``llama-3.3-70b-versatile`` model with temperature 0.3.

    Args:
        prompt: Full prompt text.

    Returns:
        The message content of the first choice in the response.
    """
    user_turn = {"role": "user", "content": prompt}
    completion = client.chat.completions.create(
        model="llama-3.3-70b-versatile",
        messages=[user_turn],
        temperature=0.3,
    )
    first_choice = completion.choices[0]
    return first_choice.message.content



In [None]:
# Hybrid (vector + graph) answer generator using the Groq-backed `llama3_llm`.
answer_generator = HybridAnswerGenerator(hybrid, llama3_llm)

# Ask about early warning signs and print the generated answer.
question = "What are the first warning signs and symptoms of type 2 diabetes?"
answer = answer_generator(question)

print("\n=== FINAL ANSWER ===")
print(answer)

In [None]:
# Ask the hybrid generator about ways to prevent type 2 diabetes.
answer_generator = HybridAnswerGenerator(hybrid, llama3_llm)

question = "What are three ways to prevent type 2 diabetes? "
answer = answer_generator(question)
print(answer)

In [None]:
# Ask the hybrid generator about ways to prevent type 2 diabetes.
# NOTE(review): this cell repeats the previous one verbatim.
answer_generator = HybridAnswerGenerator(hybrid, llama3_llm)

question = "What are three ways to prevent type 2 diabetes? "
answer = answer_generator(question)
print(answer)

In [None]:
# Ask the hybrid generator about the ADA first-line medication.
answer_generator = HybridAnswerGenerator(hybrid, llama3_llm)

question = "What is the first-line medication recommended by the ADA for most people with newly diagnosed type 2 diabetes? "
answer = answer_generator(question)
print(answer)

In [None]:
# Ask the hybrid generator about risk factors.
answer_generator = HybridAnswerGenerator(hybrid, llama3_llm)

question = "What are the risk factors for type 2 diabetes?"
answer = answer_generator(question)
print(answer)

In [None]:
# Ask the hybrid generator about organs affected by insulin resistance.
answer_generator = HybridAnswerGenerator(hybrid, llama3_llm)

question = "Which organs are most affected by insulin resistance in T2D?"
answer = answer_generator(question)
print(answer)

In [None]:
# Ask the hybrid generator about medications for HbA1c control.
answer_generator = HybridAnswerGenerator(hybrid, llama3_llm)

question = "What are the most effective medications for controlling HbA1c levels in newly diagnosed Type 2 diabetes patients?"
answer = answer_generator(question)
print(answer)

In [None]:
# Ask the hybrid generator about the recommended diet.
answer_generator = HybridAnswerGenerator(hybrid, llama3_llm)

question = "What diet is recommended for someone with type 2 diabetes?"
answer = answer_generator(question)
print(answer)

In [None]:
# Ask the hybrid generator about blood glucose monitoring.
answer_generator = HybridAnswerGenerator(hybrid, llama3_llm)

question = "How should blood glucose be monitored?"
answer = answer_generator(question)
print(answer)

In [None]:
# Ask the hybrid generator whether type 2 diabetes can be prevented.
answer_generator = HybridAnswerGenerator(hybrid, llama3_llm)

question = "Can type 2 diabetes be prevented?"
answer = answer_generator(question)
print(answer)

In [None]:
Gradio interface

In [None]:
import gradio as gr
import time
from typing import Optional

class HybridAnswerGenerator:
    """Answer questions by prompting an LLM with hybrid-retrieval context.

    The LLM may be either a plain callable ``prompt -> str`` or an object
    exposing ``invoke(prompt)`` whose result carries the text in a
    ``content`` attribute (the shape of a chat-model message).
    """

    def __init__(self, hybrid_retriever, llm):
        """
        Args:
            hybrid_retriever: Object whose ``retrieve(question, k=...)`` returns
                a mapping with ``vector_docs`` and ``graph_result`` entries.
            llm: Callable taking the prompt, or object with an ``invoke`` method.
        """
        self.hybrid_retriever = hybrid_retriever
        self.llm = llm

    @staticmethod
    def _extract_content(doc) -> str:
        """Return the text of ``doc``: ``page_content``, else ``content``, else ``str(doc)``."""
        for attribute in ("page_content", "content"):
            if hasattr(doc, attribute):
                return getattr(doc, attribute)
        return str(doc)

    def _run_llm(self, prompt: str) -> str:
        """Call the LLM with ``prompt`` and return the answer text."""
        # Prefer `invoke` when available so chat-model objects work; fall
        # back to calling the object directly for plain functions.
        invoke = getattr(self.llm, "invoke", None)
        reply = invoke(prompt) if callable(invoke) else self.llm(prompt)
        # Message-like replies keep their text in `.content`; plain strings
        # have no such attribute and are returned unchanged.
        return getattr(reply, "content", reply)

    def __call__(self, question: str, k: int = 5) -> str:
        """Retrieve context for ``question``, build the prompt and return the answer text.

        Args:
            question: Question to answer.
            k: Forwarded to the retriever as ``k``.

        Returns:
            The answer produced by the LLM for the assembled prompt.
        """
        # 1. Retrieve best context (vector + graph)
        retrieval_result = self.hybrid_retriever.retrieve(question, k=k)

        docs_text = "\n".join(
            self._extract_content(doc) for doc in retrieval_result["vector_docs"]
        )
        graph_result = retrieval_result["graph_result"]
        graph_summary = graph_result.reasoning if graph_result else ""

        # 2. Build prompt for LLM
        prompt = f"""
        You are an expert medical assistant specializing in Type 2 Diabetes (T2D). Your task is to answer the patient's question accurately, using both textual medical context and graph-based knowledge.

GUIDELINES:
1. Prioritize information from the most reliable and relevant sources.
2. Provide clear, detailed explanations with medical reasoning.
3. Mention if information is missing or uncertain.
4. Include important background or context for the patient to understand.
5. Use professional yet accessible language for healthcare communication.
6. When applicable, reference medications, lifestyle interventions, or clinical guidelines.

QUESTION: {question}

TEXTUAL MEDICAL CONTEXT:
{docs_text}

GRAPH-BASED KNOWLEDGE SUMMARY:
{graph_summary}

INSTRUCTIONS:
- Combine insights from textual context and graph knowledge.
- Focus on actionable, evidence-based recommendations.
- Avoid speculation; clearly state any uncertainties.

FINAL ANSWER:
        """

        # 3. Generate & return final answer
        return self._run_llm(prompt)

def create_minimal_interface(hybrid_generator: HybridAnswerGenerator):
    """
    Creates a minimalist medical interface for HybridAnswerGenerator

    Args:
        hybrid_generator: Callable invoked as ``hybrid_generator(question, k=...)``
            to produce the answer text.

    Returns:
        The ``gr.Blocks`` interface (built but not launched).
    """

    def process_question(question: str, k_value: int, progress=gr.Progress()):
        """
        Process medical question with progress tracking

        Returns:
            A 3-tuple ``(answer_text, debug_markdown_or_None, status_text)``
            matching the outputs wired to the submit handlers.
        """
        # Reject blank input without calling the generator
        if not question.strip():
            return "Please enter a question before submitting.", None, "⚠ Empty question"

        # The two sleeps below happen before any real work starts; they only
        # pace the progress indicator.
        progress(0.3, desc="Analyzing question...")
        time.sleep(0.5)

        progress(0.7, desc="Retrieving knowledge...")
        time.sleep(1)

        progress(0.9, desc="Generating response...")

        try:
            answer = hybrid_generator(question, k=k_value)
            progress(1.0, desc="Complete")

            debug_info = f"**Query parameters:** k={k_value} | Length: {len(question)} chars"

            return answer, debug_info, "✓ Response generated"

        except Exception as e:
            # UI boundary: report the failure in the answer box instead of raising
            return f"Error: {str(e)}", None, "✗ Generation failed"

    # Minimal CSS with only 3 colors: white (#ffffff), dark blue (#2c3e50), light gray (#f8f9fa)
    minimal_css = """
    @import url('https://fonts.googleapis.com/css2?family=Fira+Code:wght@300;400;500;600&display=swap');

    * {
        font-family: 'Fira Code', monospace !important;
    }

    .gradio-container {
        background: #ffffff !important;
        color: #2c3e50 !important;
    }

    .main-container {
        max-width: 1000px;
        margin: 0 auto;
        padding: 40px 20px;
    }

    h1, h2, h3 {
        color: #2c3e50 !important;
        font-weight: 500 !important;
        margin: 0 0 20px 0 !important;
    }

    .header {
        text-align: center;
        margin-bottom: 40px;
        padding-bottom: 20px;
        border-bottom: 1px solid #f8f9fa;
    }

    .warning-box {
        background: #f8f9fa !important;
        border: 1px solid #2c3e50 !important;
        border-radius: 4px !important;
        padding: 15px !important;
        margin: 20px 0 !important;
        color: #2c3e50 !important;
        font-size: 0.9em !important;
    }

    .input-group, .output-group {
        background: #ffffff !important;
        border: 1px solid #f8f9fa !important;
        border-radius: 4px !important;
        padding: 20px !important;
        margin: 15px 0 !important;
    }

    button {
        background: #2c3e50 !important;
        color: #ffffff !important;
        border: 1px solid #2c3e50 !important;
        border-radius: 4px !important;
        padding: 10px 20px !important;
        font-weight: 400 !important;
        transition: all 0.2s ease !important;
    }

    button:hover {
        background: #ffffff !important;
        color: #2c3e50 !important;
        border: 1px solid #2c3e50 !important;
    }

    .secondary-button {
        background: #ffffff !important;
        color: #2c3e50 !important;
        border: 1px solid #f8f9fa !important;
    }

    .secondary-button:hover {
        background: #f8f9fa !important;
        color: #2c3e50 !important;
    }

    input, textarea {
        border: 1px solid #f8f9fa !important;
        background: #ffffff !important;
        color: #2c3e50 !important;
        border-radius: 4px !important;
    }

    input:focus, textarea:focus {
        border-color: #2c3e50 !important;
        outline: none !important;
    }

    .examples-grid {
        display: grid;
        grid-template-columns: 1fr;
        gap: 8px;
        margin-top: 15px;
    }

    .example-btn {
        background: #f8f9fa !important;
        color: #2c3e50 !important;
        border: 1px solid #f8f9fa !important;
        text-align: left !important;
        padding: 12px 16px !important;
        font-size: 0.9em !important;
    }

    .example-btn:hover {
        border-color: #2c3e50 !important;
    }

    .status-text {
        font-size: 0.9em;
        font-weight: 500;
    }
    """

    # Components are declared inside nested context managers; the order of
    # the statements below is the order in which they are created.
    with gr.Blocks(css=minimal_css, title="T2D Medical Assistant", theme=gr.themes.Default()) as interface:

        with gr.Column(elem_classes=["main-container"]):

            # Simple header
            with gr.Row(elem_classes=["header"]):
                gr.HTML("""
                    <h1>T2D Medical Assistant</h1>
                    <p>Hybrid AI system combining vector search and graph knowledge</p>
                """)

            # Medical disclaimer
            gr.HTML("""
                <div class="warning-box">
                    <strong>Medical Disclaimer:</strong> This AI assistant provides general information about Type 2 Diabetes for educational purposes only.
                    It does not replace professional medical advice. Always consult your healthcare provider for personalized medical guidance.
                </div>
            """)

            # Input section
            with gr.Group(elem_classes=["input-group"]):
                gr.Markdown("### Ask your question")

                question_input = gr.Textbox(
                    label="Medical question about Type 2 Diabetes",
                    placeholder="Example: What are the side effects of metformin? How does exercise affect blood glucose?",
                    lines=3,
                    max_lines=6
                )

                with gr.Row():
                    # Slider value is forwarded to the generator as `k`
                    k_slider = gr.Slider(
                        minimum=1,
                        maximum=10,
                        value=5,
                        step=1,
                        label="Number of sources (k)",
                        info="More sources = comprehensive but slower response"
                    )

                    submit_btn = gr.Button(
                        "Generate Response",
                        variant="primary",
                        size="lg"
                    )

            # Output section
            with gr.Group(elem_classes=["output-group"]):
                gr.Markdown("### Medical Response")

                status_output = gr.Textbox(
                    label="Status",
                    interactive=False,
                    show_label=False,
                    elem_classes=["status-text"]
                )

                answer_output = gr.Textbox(
                    label="AI Response",
                    lines=12,
                    max_lines=20,
                    show_copy_button=True,
                    interactive=False,
                    placeholder="The AI response will appear here..."
                )

                with gr.Accordion("Technical Details", open=False):
                    debug_output = gr.Markdown("Processing details will appear after response generation.")

            # Example questions
            with gr.Group():
                gr.Markdown("### Example Questions")

                examples = [
                    "What are the side effects of metformin and how to manage them?",
                    "How does physical exercise affect blood glucose in T2D patients?",
                    "Which foods should be avoided with Type 2 Diabetes?",
                    "How to recognize hypoglycemia symptoms and what to do?",
                    "What is the difference between Type 1 and Type 2 diabetes?"
                ]

                with gr.Column(elem_classes=["examples-grid"]):
                    for example in examples:
                        # `x=example` binds the current example as a default
                        # argument, so each button fills in its own text
                        # rather than the last value of the loop variable.
                        gr.Button(
                            example,
                            elem_classes=["example-btn", "secondary-button"],
                            size="sm"
                        ).click(
                            fn=lambda x=example: x,
                            outputs=question_input
                        )

        # Event handlers: the button click and pressing Enter in the textbox
        # both run `process_question` with the same inputs and outputs.
        submit_btn.click(
            fn=process_question,
            inputs=[question_input, k_slider],
            outputs=[answer_output, debug_output, status_output],
            show_progress=True
        )

        question_input.submit(
            fn=process_question,
            inputs=[question_input, k_slider],
            outputs=[answer_output, debug_output, status_output],
            show_progress=True
        )

        # Footer
        gr.HTML("""
            <div style="text-align: center; margin-top: 40px; padding: 20px 0;
                        border-top: 1px solid #f8f9fa; color: #2c3e50;">
                <p>Hybrid AI System | Vector Search + Graph Knowledge | Medical Assistant for T2D</p>
            </div>
        """)

    return interface

# Usage example
def demo_usage():
    """
    Example usage of the interface

    Builds a HybridRetriever from the module-level ``advanced_retriever`` and
    ``graph_retriever``, wraps a ChatGroq model as a ``prompt -> str``
    callable, and launches the Gradio interface on port 7866.
    """
    hybrid_retriever = HybridRetriever(advanced_retriever, graph_retriever, alpha=0.7)
    chat_model = ChatGroq(
        model="llama-3.3-70b-versatile",
        temperature=0,
    )

    def llm(prompt: str) -> str:
        # The answer generator hands its `llm` a plain prompt string and
        # shows the return value as the answer text. A chat model returns a
        # message object, so unwrap its `.content` here.
        return chat_model.invoke(prompt).content

    hybrid_generator = HybridAnswerGenerator(hybrid_retriever, llm)
    interface = create_minimal_interface(hybrid_generator)

    # Announce before launching so the message matches what happens next.
    print("Interface ready to launch!")

    # Launch the interface
    # NOTE(review): server_name="0.0.0.0" binds to all network interfaces;
    # confirm that exposing this assistant beyond localhost is intended.
    interface.launch(
        server_name="0.0.0.0",
        server_port=7866,
        share=False,
        inbrowser=True
    )

# Launch the demo interface when this file is executed as the main module.
if __name__ == "__main__":
    demo_usage()