<a href="https://colab.research.google.com/github/buwituze/pre-consultation-agent/blob/main/dialogue_policy_model_c.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# 💬 Model C - Question-Flow / Dialogue Policy
## Clinical Question Sequencing Model

**Purpose:** Select the single most appropriate next question to ask a patient during a voice-based pre-consultation, based on what is already known and what is still missing.

| | |
|---|---|
| **Input** | Patient state (from Model B) + conversation history |
| **Output** | One next question: plain text, voice-ready |
| **Model** | Google Gemini AI (rule-guided prompting) |
| **Mode** | Question selection only: no diagnosis, no advice |

### Pipeline Position
```
Model B Output                  Model C Output
(Structured State)
       │
       ▼
┌─────────────────────┐
│  Patient State      │
│  Missing Fields     │──────► "On a scale from 1 to 10,
│  Asked Questions    │         how severe is the pain?"
│  Conversation Stage │
│  Safety Rules       │              │
└─────────────────────┘              ▼
                              → Spoken to patient
                              → Answer fed back to Model B
                              → Loop continues
```

> ⚠️ **Hard Rules:** One question at a time. No diagnosis. No treatment advice. No medical interpretation.

---
## 📦 Section 1 - Install & Imports

In [None]:
!pip install -q -U google-generativeai

import json
import re
from dataclasses import dataclass, field, asdict
from typing import List, Optional
from enum import Enum

import google.generativeai as genai
from google.colab import userdata

print("✅ Dependencies ready.")

---
## 🔑 Section 2 - API Key

In [None]:
# Add your key via: left sidebar → 🔑 Secrets → "GEMINI_API_KEY"
try:
    genai.configure(api_key=userdata.get("GEMINI_API_KEY"))
    print("✅ API key loaded from Colab Secrets.")
except Exception:
    genai.configure(api_key="YOUR_API_KEY_HERE")  # fallback only
    print("⚠️  API key set manually. Use Colab Secrets for security.")
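
Outside Colab, `google.colab.userdata` is not available, so the hard-coded fallback above would be the only path. A minimal sketch of an alternative, assuming the key has been exported as an environment variable under the same name (`resolve_api_key` is a hypothetical helper, not part of this notebook):

```python
import os
from typing import Mapping, Optional


def resolve_api_key(env: Optional[Mapping[str, str]] = None,
                    name: str = "GEMINI_API_KEY") -> str:
    """Return the API key from the environment, or raise if it is missing.

    Hypothetical helper: `env` defaults to os.environ and exists only so the
    function can be exercised without touching the real environment.
    """
    env = os.environ if env is None else env
    key = env.get(name, "").strip()
    if not key:
        raise RuntimeError(f"{name} is not set; export it before running.")
    return key


# Usage (assumption: the variable was exported in the shell beforehand):
# genai.configure(api_key=resolve_api_key())
```

Failing loudly when the key is missing avoids silently configuring the client with a placeholder string.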

---
## 🤖 Section 3 - Model Initialisation

In [None]:
GEMINI_MODEL_NAME = "gemini-1.5-flash"

gemini_model = genai.GenerativeModel(
    model_name        = GEMINI_MODEL_NAME,
    generation_config = genai.types.GenerationConfig(
        temperature       = 0.2,   # Slight variability for natural phrasing
        max_output_tokens = 80,    # One short question only
    ),
    safety_settings = [
        {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_ONLY_HIGH"},
        {"category": "HARM_CATEGORY_HARASSMENT",        "threshold": "BLOCK_ONLY_HIGH"},
        {"category": "HARM_CATEGORY_HATE_SPEECH",       "threshold": "BLOCK_ONLY_HIGH"},
    ]
)

print(f"✅ Gemini model ready: {GEMINI_MODEL_NAME}")

---
## 🗂️ Section 4 - Data Structures

Three simple dataclasses carry all the state needed to select the next question.

In [None]:
class ConversationStage(str, Enum):
    EARLY      = "early"       # First 1–3 questions
    MID        = "mid"         # Filling in known gaps
    ESCALATION = "escalation"  # Red flag follow-up


@dataclass
class PatientState:
    """
    Structured patient information ‚Äî produced by Model B.
    Populate only what is known; leave unknown fields empty.
    """
    age                     : Optional[int]   = None
    chief_complaint         : str             = ""
    duration                : str             = ""
    severity                : str             = ""   # patient's own words
    body_part               : str             = ""
    associated_symptoms     : List[str]       = field(default_factory=list)
    red_flags_present       : Optional[bool]  = None
    additional_observations : str             = ""


@dataclass
class ConversationContext:
    """
    Running record of the dialogue so far.
    """
    questions_asked  : List[str]             = field(default_factory=list)
    patient_answers  : List[str]             = field(default_factory=list)
    stage            : ConversationStage     = ConversationStage.EARLY

    def add_turn(self, question: str, answer: str):
        """Record one Q&A exchange and advance the conversation stage."""
        self.questions_asked.append(question)
        self.patient_answers.append(answer)
        n = len(self.questions_asked)
        if n <= 2:
            self.stage = ConversationStage.EARLY
        elif n <= 6:
            self.stage = ConversationStage.MID
        else:
            self.stage = ConversationStage.ESCALATION


@dataclass
class QuestionResult:
    """Output of one Model C call."""
    question         : str    # The question to speak to the patient
    stage            : str    # Conversation stage at time of selection
    red_flag_active  : bool   # Whether red flag mode influenced selection


print("✅ Data structures defined.")
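
The stage thresholds in `add_turn()` are easy to misread, so here is a standalone sketch of the same rule. `Stage` and `stage_after` are stand-in names used only for this illustration. Note that the stage is driven purely by the number of recorded turns; it does not look at `red_flags_present`.

```python
from enum import Enum


class Stage(str, Enum):
    """Stand-in for ConversationStage, repeated so the sketch runs on its own."""
    EARLY = "early"
    MID = "mid"
    ESCALATION = "escalation"


def stage_after(n_turns: int) -> Stage:
    """Stage reached once `n_turns` Q&A exchanges have been recorded."""
    if n_turns <= 2:
        return Stage.EARLY
    if n_turns <= 6:
        return Stage.MID
    return Stage.ESCALATION


for n in (0, 2, 3, 6, 7):
    print(n, stage_after(n).value)
# 0 early
# 2 early
# 3 mid
# 6 mid
# 7 escalation
```

Because a question is selected before its turn is recorded, the first three questions are chosen while the stage is still `early`.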

---
## 📋 Section 5 - Coverage Checklist & Clinical Rules

The checklist defines what **must eventually be covered**. The rules guide **prioritisation order**; they are soft constraints, not rigid scripts.

In [None]:
# ── Required coverage categories ──────────────────────────────────────────────
# These topics must be addressed before the consultation is considered complete.
# Order is flexible; coverage is not.
COVERAGE_CHECKLIST = [
    "severity or intensity of the main symptom",
    "when the symptom started or how long it has been present",
    "whether the symptom is getting better, worse, or staying the same",
    "any other symptoms alongside the main one",
    "whether the symptom affects the patient's daily activities",
    "any relevant medical history or known conditions",
]

# ── Red flag follow-up questions (activated when red_flags_present = True) ────
RED_FLAG_FOLLOWUPS = [
    "Is the patient currently able to breathe comfortably?",
    "Has the patient lost consciousness or felt faint?",
    "Is there any unusual bleeding?",
    "Is the patient able to move all limbs normally?",
    "Is the patient in severe pain right now?",
]

# ── Soft prioritisation rules (passed to the model as guidance) ───────────────
PRIORITISATION_RULES = """\
- If the main symptom involves pain and severity is unknown → ask about severity first.
- If severity is high or a red flag is present → immediately cover red-flag screening questions.
- If the symptom is acute (sudden onset) → prioritise onset time and progression.
- If the symptom is respiratory or cardiac → prioritise breathing and consciousness.
- If the patient is elderly (65+) or very young (under 5) → treat all gaps as higher priority.
- Otherwise → follow the coverage checklist in a natural, conversational order.
"""

print("✅ Coverage checklist and clinical rules defined.")
print(f"   Coverage categories : {len(COVERAGE_CHECKLIST)}")
print(f"   Red flag follow-ups : {len(RED_FLAG_FOLLOWUPS)}")
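
In this notebook the red-flag list is only injected into the prompt; the model picks from it. If a deterministic fallback is wanted (for instance when the API call fails), one possible sketch is below. `next_red_flag_followup` is a hypothetical helper, and exact string matching is a simplification, since the model may rephrase a question before it is recorded.

```python
from typing import List, Optional

# Same list as in the cell above, repeated so the sketch runs on its own.
RED_FLAG_FOLLOWUPS = [
    "Is the patient currently able to breathe comfortably?",
    "Has the patient lost consciousness or felt faint?",
    "Is there any unusual bleeding?",
    "Is the patient able to move all limbs normally?",
    "Is the patient in severe pain right now?",
]


def next_red_flag_followup(already_asked: List[str]) -> Optional[str]:
    """Return the first red-flag follow-up not yet asked, or None if all were."""
    asked = {q.strip().lower() for q in already_asked}
    for question in RED_FLAG_FOLLOWUPS:
        if question.lower() not in asked:
            return question
    return None


print(next_red_flag_followup([]))
# Is the patient currently able to breathe comfortably?
```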

---
## 📝 Section 6 - Prompt

In [None]:
SYSTEM_PROMPT = """\
You are a clinical question-selection assistant in a voice-based hospital pre-consultation system.
Your only job is to choose the single best next question to ask the patient.

=== HARD RULES (NEVER VIOLATE) ===
- Output ONE question only, nothing else.
- Do NOT diagnose, label, or suggest any medical condition.
- Do NOT offer reassurance, advice, or interpretation.
- Do NOT repeat a question already asked.
- Do NOT ask multiple questions at once.
- The question must be short, clear, and natural to say out loud.
- Write in plain English (or Kinyarwanda if that is the patient's language).
- Output the question text only: no explanation, no prefix, no punctuation other than the question mark.

=== YOUR GOAL ===
Select the question that best reduces clinical uncertainty given what is already known,
following the prioritisation rules and coverage checklist provided.
"""


def build_prompt(state: PatientState, context: ConversationContext) -> str:
    """Build the user-turn prompt from current patient state and conversation context."""

    # Format known patient information
    known = []
    if state.age is not None:       known.append(f"Age: {state.age}")  # keeps age 0 (infants)
    if state.chief_complaint:       known.append(f"Main symptom: {state.chief_complaint}")
    if state.duration:              known.append(f"Duration: {state.duration}")
    if state.severity:              known.append(f"Severity: {state.severity}")
    if state.body_part:             known.append(f"Body part: {state.body_part}")
    if state.associated_symptoms:   known.append(f"Other symptoms: {', '.join(state.associated_symptoms)}")
    if state.additional_observations: known.append(f"Notes: {state.additional_observations}")

    # Identify what is still missing from the coverage checklist.
    # Only three checklist topics map to a PatientState field; the others stay
    # listed, and the model judges them from the conversation history.
    field_for_topic = {
        "severity":       state.severity,
        "started":        state.duration,
        "other symptoms": state.associated_symptoms,
    }
    missing = [
        item for item in COVERAGE_CHECKLIST
        if not any(key in item and value for key, value in field_for_topic.items())
    ]

    # Format conversation history
    history = ""
    if context.questions_asked:
        pairs = [
            f"  Q: {q}\n  A: {a}"
            for q, a in zip(context.questions_asked, context.patient_answers)
        ]
        history = "\n".join(pairs)
    else:
        history = "  (none yet)"

    # Red flag instructions
    red_flag_block = ""
    if state.red_flags_present:
        red_flag_block = f"""
=== ⚠️ RED FLAG ACTIVE ===
A red flag has been detected. Prioritise these follow-up questions next (pick the most relevant one not yet asked):
{chr(10).join(f'- {q}' for q in RED_FLAG_FOLLOWUPS)}
"""

    return f"""\
=== PATIENT STATE (known so far) ===
{chr(10).join(f'- {k}' for k in known) or '- (nothing known yet)'}

=== MISSING INFORMATION (still to cover) ===
{chr(10).join(f'- {m}' for m in missing) or '- (all core topics covered)'}

=== CONVERSATION HISTORY ===
{history}

=== CONVERSATION STAGE ===
{context.stage.value}
{red_flag_block}
=== PRIORITISATION RULES ===
{PRIORITISATION_RULES}

=== TASK ===
Output the single best next question to ask the patient. Question only. Nothing else.
"""


print("✅ Prompt builder defined.")

---
## 🚀 Section 7 - Core Question Selection Function

In [None]:
def select_next_question(
    state   : PatientState,
    context : ConversationContext,
    verbose : bool = True,
) -> QuestionResult:
    """
    Select the single best next question to ask the patient.

    Args:
        state   : Current structured patient state (from Model B).
        context : Running conversation context.
        verbose : Print the selected question.

    Returns:
        QuestionResult with the question and metadata.
    """
    prompt = build_prompt(state, context)

    # Fresh chat per call: the system prompt is seeded as the first exchange, and the
    # dialogue so far travels inside `prompt`, so each call is self-contained.
    chat = gemini_model.start_chat(history=[
        {"role": "user",  "parts": [SYSTEM_PROMPT]},
        {"role": "model", "parts": ["Understood. I will output one question only, with no diagnosis, advice, or extra text."]},
    ])
    response = chat.send_message(prompt)
    raw      = response.text.strip()

    # Clean up: strip a leading label such as "Question:" if Gemini adds one.
    # The separator is required so a question that starts with "Q" is left intact.
    question = re.sub(r"^(?:next question|question|q)\s*[:\-]\s*", "", raw, flags=re.IGNORECASE).strip()
    # Ensure it ends with a question mark
    if question and not question.endswith("?"):
        question += "?"

    result = QuestionResult(
        question        = question,
        stage           = context.stage.value,
        red_flag_active = bool(state.red_flags_present),
    )

    if verbose:
        flag_marker = "🚨" if result.red_flag_active else "💬"
        print(f"\n{flag_marker} [{result.stage.upper()}] Next question:")
        print(f"   ▶ {result.question}")

    return result


print("✅ select_next_question() defined.")
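
The hard rules are enforced only through the prompt, so a reply can still break them. A rough post-check, sketched here with plain string tests, could flag a reply before it is spoken. `violates_hard_rules` is a hypothetical helper; counting question marks is a heuristic and will not catch every compound question.

```python
from typing import List


def violates_hard_rules(question: str, already_asked: List[str]) -> List[str]:
    """Return the names of hard-rule violations found by simple string checks."""
    problems = []
    if not question.strip():
        problems.append("empty")
    # More than one '?' suggests several questions were bundled together.
    if question.count("?") > 1:
        problems.append("multiple questions")
    # Case-insensitive exact match against earlier questions.
    if question.strip().lower() in {q.strip().lower() for q in already_asked}:
        problems.append("repeated")
    return problems


print(violates_hard_rules("Where is the pain? Does it spread?", []))
# ['multiple questions']
```

A caller might retry the model call, or fall back to a fixed question, whenever the returned list is non-empty.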

---
## 🔗 Section 8 - Full Pipeline: Model B → Model C

This function accepts Model B's output dict directly and runs one question-selection step.

In [None]:
def run_from_model_b(
    model_b_output : dict,
    context        : ConversationContext,
    verbose        : bool = True,
) -> QuestionResult:
    """
    Accepts Model B's extraction dict and runs Model C question selection.

    Args:
        model_b_output : Dict returned by Model B's extract_clinical_information().
        context        : Running ConversationContext (maintained by the caller).
        verbose        : Print progress.

    Returns:
        QuestionResult.
    """
    ext = model_b_output.get("extraction_dict", {})

    state = PatientState(
        age                     = ext.get("age"),  # assumption: key name; None if Model B omits it
        chief_complaint         = ext.get("chief_complaint", ""),
        duration                = ext.get("duration", ""),
        severity                = ext.get("severity", ""),
        body_part               = ext.get("body_part", ""),
        associated_symptoms     = ext.get("associated_symptoms", []),
        red_flags_present       = ext.get("red_flags_present"),
        additional_observations = ext.get("additional_observations", ""),
    )

    if verbose:
        print("🔗 Model B → Model C")
        print(f"   Complaint : {state.chief_complaint or '(unknown)'}")
        print(f"   Red flag  : {state.red_flags_present}")
        print(f"   Stage     : {context.stage.value}")

    return select_next_question(state, context, verbose=verbose)


print("✅ run_from_model_b() defined.")

---
## 🔄 Section 9 - Simulated Conversation Loop

This simulates a full pre-consultation session. In production, each patient answer comes from Model A (voice → text). Here we provide answers manually to show the loop working.

In [None]:
def run_conversation(
    state             : PatientState,
    simulated_answers : List[str],   # In production: replaced by Model A output
    max_questions     : int = 6,
) -> List[dict]:
    """
    Simulate a full Model C conversation loop.

    Args:
        state             : Initial PatientState (from Model B).
        simulated_answers : Pre-written patient answers (for testing).
        max_questions     : Stop after this many questions.

    Returns:
        List of turn dicts {turn, question, answer, stage}.
    """
    context = ConversationContext()
    log     = []

    print("\n" + "═" * 60)
    print(" 🏥  SIMULATED PRE-CONSULTATION SESSION")
    print("═" * 60)
    print(f"  Chief complaint : {state.chief_complaint or '(not yet known)'}")
    print(f"  Red flag active : {state.red_flags_present}")
    print()

    for turn in range(1, max_questions + 1):
        print(f"  ── Turn {turn} {'─' * 46}")

        result = select_next_question(state, context, verbose=True)

        # In production: answer = Model A transcription of patient voice response
        answer = simulated_answers[turn - 1] if turn <= len(simulated_answers) else "I don't know."
        print(f"   Patient: {answer}")

        context.add_turn(result.question, answer)
        log.append({
            "turn"     : turn,
            "question" : result.question,
            "answer"   : answer,
            "stage"    : result.stage,
        })

    print("\n" + "═" * 60)
    print(f" ✅  Session complete: {len(log)} questions asked.")
    print("═" * 60)
    return log


print("✅ Conversation loop defined.")
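
`run_conversation()` calls Gemini on every turn, which makes the loop itself awkward to test. The sketch below shows the same turn logic with the selector passed in as a function, so a scripted stub can stand in for the model. `run_loop`, `scripted_selector`, and `SCRIPT` are hypothetical names for this illustration, and the stub questions are invented test data.

```python
from typing import Callable, Dict, List

SCRIPT = [
    "How severe is the pain?",
    "When did it start?",
    "Is it getting better or worse?",
]


def scripted_selector(asked: List[str]) -> str:
    """Stub standing in for select_next_question(): next unused scripted question."""
    remaining = [q for q in SCRIPT if q not in asked]
    return remaining[0] if remaining else "Is there anything else you want to mention?"


def run_loop(select: Callable[[List[str]], str],
             answers: List[str],
             max_questions: int = 6) -> List[Dict[str, object]]:
    """Same turn structure as run_conversation(), without any API call."""
    asked: List[str] = []
    log: List[Dict[str, object]] = []
    for turn in range(1, max_questions + 1):
        question = select(asked)
        # Fall back to a default answer once the scripted answers run out.
        answer = answers[turn - 1] if turn <= len(answers) else "I don't know."
        asked.append(question)
        log.append({"turn": turn, "question": question, "answer": answer})
    return log


log = run_loop(scripted_selector, ["Seven out of ten."], max_questions=2)
for entry in log:
    print(entry["turn"], entry["question"], "->", entry["answer"])
# 1 How severe is the pain? -> Seven out of ten.
# 2 When did it start? -> I don't know.
```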

---
## 🧪 Section 10 - Test Cases

Three scenarios: standard English, Kinyarwanda, and a red flag case.

In [None]:
# ── Test Case 1: English - headache ───────────────────────────────────────────
state_1 = PatientState(
    age             = 34,
    chief_complaint = "headache",
    body_part       = "head",
    # severity, duration, associated_symptoms not yet known
)

answers_1 = [
    "It started this morning.",
    "About a seven out of ten.",
    "It's getting worse.",
    "I feel a bit nauseous.",
    "I can't really focus on work.",
    "No, I don't have any known conditions.",
]

log_1 = run_conversation(state_1, answers_1, max_questions=5)

In [None]:
# ── Test Case 2: Kinyarwanda - abdominal pain ─────────────────────────────────
state_2 = PatientState(
    chief_complaint = "ububabare mu nda",   # stomach pain in Kinyarwanda
    body_part       = "nda",               # abdomen
    # other fields unknown
)

answers_2 = [
    "Kuva ejo.",                     # Since yesterday
    "Ni uburemere cyane.",           # It is very heavy/severe
    "Nshaka kuruka ariko sinabikora.",# Feel like vomiting but haven't
    "Sinashye neza.",                # Didn't sleep well
    "Oya, nta ndwara nsanzwe.",      # No known conditions
]

log_2 = run_conversation(state_2, answers_2, max_questions=4)

In [None]:
# ── Test Case 3: English - red flag (chest pain) ──────────────────────────────
state_3 = PatientState(
    age               = 58,
    chief_complaint   = "chest pain",
    duration          = "2 hours",
    body_part         = "chest",
    red_flags_present = True,           # Already flagged by Model B
)

answers_3 = [
    "Yes, it's hard to breathe deeply.",
    "No, I haven't fainted but I feel dizzy.",
    "The pain goes to my left arm.",
    "Eight out of ten.",
]

log_3 = run_conversation(state_3, answers_3, max_questions=4)

---
## 📤 Section 11 - Export Session Log

In [None]:
import datetime
from google.colab import files

def export_session(log: list, label: str = "session"):
    """Save a conversation log to JSON and download it."""
    export = {
        "timestamp"      : datetime.datetime.now().isoformat(),
        "model"          : GEMINI_MODEL_NAME,
        "total_questions": len(log),
        "turns"          : log,
    }
    path = f"model_c_{label}.json"
    with open(path, "w", encoding="utf-8") as f:
        json.dump(export, f, ensure_ascii=False, indent=2)
    print(f"✅ Saved: {path}")
    files.download(path)


# Export whichever session log you want:
# export_session(log_1, "english_headache")
# export_session(log_2, "kinyarwanda_abdomen")
# export_session(log_3, "redflag_chest_pain")

print("✅ Export function ready. Uncomment the line for the session you want to download.")
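
`files.download()` only works inside Colab. For local runs, a sketch of the same export without the download step, plus a loader to check the round trip, might look like this. `save_session` and `load_session` are hypothetical names, and the sample log entry is invented test data.

```python
import json
import os
import tempfile
from typing import Dict, List


def save_session(log: List[Dict], path: str, model_name: str = "gemini-1.5-flash") -> str:
    """Write the conversation log to `path` as UTF-8 JSON and return the path."""
    export = {"model": model_name, "total_questions": len(log), "turns": log}
    with open(path, "w", encoding="utf-8") as f:
        # ensure_ascii=False keeps any non-ASCII text readable in the file.
        json.dump(export, f, ensure_ascii=False, indent=2)
    return path


def load_session(path: str) -> Dict:
    """Read a session file written by save_session()."""
    with open(path, encoding="utf-8") as f:
        return json.load(f)


sample_log = [{"turn": 1, "question": "When did it start?", "answer": "Kuva ejo.", "stage": "early"}]

with tempfile.TemporaryDirectory() as tmp:
    saved = save_session(sample_log, os.path.join(tmp, "model_c_sample.json"))
    restored = load_session(saved)

print(restored["total_questions"], restored["turns"][0]["answer"])
# 1 Kuva ejo.
```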

---
## 📝 Notes

| Topic | Detail |
|---|---|
| **Temperature = 0.2** | Allows natural phrasing variation while keeping question focus consistent. Lower to 0.0 for more repeatable output; identical replies are still not guaranteed. |
| **max_output_tokens = 80** | Caps the reply length so a stray long answer is cut off. The single-question constraint itself comes from the prompt, not from this limit. |
| **Red flag mode** | When `red_flags_present = True`, the red flag follow-up list is injected into the prompt and takes priority over the standard checklist. |
| **Kinyarwanda** | The model is instructed to match the patient's language. For consistent Kinyarwanda output, verify with native speakers and refine phrasings in the prompt or simulated answers. |
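| **Coverage vs. order** | The checklist lists the topics that should eventually be covered; the model chooses the most appropriate order per context. Coverage is not enforced in code, so a session capped by `max_questions` can end with topics still open. |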
| **Production wiring** | Replace `simulated_answers` in `run_conversation()` with live Model A transcriptions to go real-time. |
| **Data path** | Update file paths for training transcripts and annotated flows once finalised. |