In [None]:
# %% [markdown]
# Ontology-Aware LLM Agent (LangChain + Ollama, FOSS-only)
# Run this top-to-bottom in Jupyter. Tested with Python 3.10+.

# %%capture
%pip install -q langchain langchain-community regex pydantic

# %%
import re, math, json
from dataclasses import dataclass
from typing import List, Dict, Any, Tuple

from langchain_community.chat_models import ChatOllama
from langchain.prompts import ChatPromptTemplate

# -------------------------
# 1) Ontology-Lite Context
# -------------------------
# Fixed "current" year used by the LandmarkAge rule (years since built).
# It is not read from the clock, so update it manually when the year changes.
CURRENT_YEAR = 2025
LEGAL_MARRIAGE_AGE = 18
MAX_AGE = 110

# Reference build years for the LandmarkYear / LandmarkAge rules.
LANDMARKS = {
    "Eiffel Tower": {"built_year": 1889},
    "Burj Khalifa": {"built_year": 2010},
}

CITY_POPULATION_HINT = {
    # A city that is widely known or described as large being "quiet"/"empty" is suspicious.
    "Los Angeles": {"population_hint": "high"},
    "Paris": {"population_hint": "high"},
    "Dubai": {"population_hint": "high"},
}

QUIET_TERMS = {"quiet", "empty", "calm", "peaceful"}
BUSY_TERMS = {"busy", "noisy", "crowded"}

# Whitelist of traits. NOTE(review): not referenced by any extractor/check in
# this file as shown — confirm before relying on it.
TRAITS = {"resilient", "pragmatic", "sociable", "reserved", "curious", "patient"}
OCCUPATIONS = {"nurse", "engineer", "teacher", "artist", "chef", "driver"}

# Allergen category -> food terms. The checker tests these as substrings of
# what was eaten, so "nut" also covers e.g. "peanut".
ALLERGEN_MAP = {
    "nuts": {"walnut", "almond", "hazelnut", "nut", "cashew", "pecan", "pistachio"},
    "gluten": {"bread", "pasta", "cake", "noodles", "wheat", "flour"},
    "seafood": {"shrimp", "fish", "oyster", "crab", "lobster", "prawn", "salmon", "tuna"},
    "dairy": {"milk", "cheese", "yogurt", "butter", "cream"},
}

# Cooking keywords that imply a powered appliance (CausalCooking rule).
COOKING_REQUIRING_POWER = {"bake", "oven", "microwave"}
# Regexes (searched case-insensitively) signalling that there is no electricity.
NO_POWER_PATTERNS = {
    r"no electricity",
    r"without electricity",
    r"power outage",
    r"power cut",
    r"power failure",
}

# -------------------------
# 2) Data Structures
# -------------------------
@dataclass
class Assertion:
    """A single extracted fact of the form ``subj --pred--> obj``.

    ``span`` holds the (start, end) offsets of the supporting text; facts
    that are not tied to one match use ``(0, 0)``. ``confidence`` defaults
    to 1.0 and is lowered by the extractor for looser patterns.
    """

    subj: str  # entity name (person, city, landmark) or a pseudo-subject such as "world"
    pred: str  # predicate label, e.g. "hasAge", "consumed", "noPower"
    obj: Any  # value whose type depends on ``pred`` (int, str or bool)
    span: Tuple[int, int]  # (start, end) character offsets in the source text
    confidence: float = 1.0


@dataclass
class Violation:
    """A failed consistency rule together with a repair suggestion."""

    rule: str  # short rule id, e.g. "MaxAge" or "AllergyFood"
    description: str  # what is inconsistent, in plain language
    fix_hint: str  # how to fix it; forwarded to the LLM rewriter prompt

# -------------------------
# 3) IE-Lite Extraction
# -------------------------
# Name groups such as ``[A-Z][a-z]+`` only identify names when matching is
# case-sensitive. Under a global re.IGNORECASE they accept any word ("but",
# "she", "the"), and the captured subject stops being a name. Case-insensitivity
# is therefore scoped to the keyword parts with ``(?i:...)``.

# Gap between a subject and its predicate. It stays inside one sentence and
# contains no other capital letter, so the captured subject is the nearest
# capitalised word before the predicate ("Amelia ... but ate walnut cake").
_NEAR_GAP = r"[^A-Z.!?\n]*?"
# Gap between two parts of one sentence (may contain capitalised words).
_SENTENCE_GAP = r"[^.!?\n]*?"
# Possibly multi-word proper name ("Los Angeles", "Eiffel Tower").
_PROPER_NAME = r"[A-Z][a-z]+(?:\s[A-Z][a-z]+)*"

AGE_RE = re.compile(r"\b([A-Z][a-z]+)\s+(?:is|was|aged)\s+(\d{1,3})(?:\s*years?\s*old)?\b")
MARRIED_RE = re.compile(r"\b([A-Z][a-z]+)\b" + _NEAR_GAP + r"\b(?:is|got|was)\s+married\b")
# The article is consumed separately so "is a teacher" yields "teacher", not "a".
OCCUP_RE = re.compile(
    r"\b([A-Z][a-z]+)\b" + _NEAR_GAP
    + r"\b(?:is|was|works\s+as|worked\s+as|became)\s+(?:(?:an?|the)\s+)?([A-Za-z]+)\b"
)
CITY_DESC_RE = re.compile(
    r"\b(?i:in|at|from)\s+(" + _PROPER_NAME + r")\b" + _SENTENCE_GAP
    + r"\b((?i:quiet|empty|calm|busy|noisy|crowded))\b"
)
ALLERGY_RE = re.compile(
    r"\b([A-Z][a-z]+)\b" + _NEAR_GAP
    + r"\b(?i:(?:is|was)\s+allergic\s+to)\s+([A-Za-z]+)\b"
)
ATE_RE = re.compile(
    r"\b([A-Z][a-z]+)\b" + _NEAR_GAP
    + r"\b(?i:ate|eats|eating)\s+([A-Za-z]+(?:\s+[A-Za-z]+)?)\b"
)
# A leading "The" is skipped so "The Eiffel Tower" yields "Eiffel Tower".
BUILT_IN_RE = re.compile(
    r"\b(?:The\s+)?(" + _PROPER_NAME + r")\b" + _SENTENCE_GAP
    + r"\b(?i:(?:built|constructed)\s+in)\s+(\d{3,4})\b"
)
YEARS_SINCE_RE = re.compile(
    r"\b(?:The\s+)?(" + _PROPER_NAME + r")\b" + _SENTENCE_GAP
    + r"\b(?i:been\s+(?:there|standing)\s+for|age\s+of)\s+(\d{1,3})\s+(?i:years?)\b"
)
# The claimed share must appear in the same sentence as the split.
ARITH_50_3_RE = re.compile(
    r"\b(?:50\s*/\s*3|(?i:split)\s+50\s+(?i:between|among)\s+3)\b"
    + _SENTENCE_GAP + r"\b(\d{1,3})\b"
)

# Continuation of an occupation list right after a primary match, e.g. the
# " and an engineer" in "works as a nurse and an engineer" or the ", a chef"
# in "is a nurse, a chef and a driver".
_COORD_OCCUP_RE = re.compile(r"(?:\s*,\s*(?:and\s+)?|\s+and\s+)(?:(?:an?|the)\s+)?([A-Za-z]+)\b")


def extract(text: str) -> List[Assertion]:
    """Extract ontology assertions from ``text`` using the module-level regexes.

    Each regex match becomes an ``Assertion`` whose ``span`` is the supporting
    match span. Looser patterns get a confidence below 1.0. Two whole-text
    flags (``noPower`` and ``cookingRequiresPower``, subject ``"world"``) use a
    ``(0, 0)`` span because they are not tied to a single match.

    Args:
        text: The story text to analyse.

    Returns:
        The assertions, grouped by predicate in extraction order.
    """
    assertions: List[Assertion] = []

    for m in AGE_RE.finditer(text):
        assertions.append(Assertion(m.group(1), "hasAge", int(m.group(2)), m.span()))
    for m in MARRIED_RE.finditer(text):
        assertions.append(Assertion(m.group(1), "isMarried", True, m.span()))
    for m in OCCUP_RE.finditer(text):
        person, job = m.group(1), m.group(2).lower()
        if job not in OCCUPATIONS:
            continue
        assertions.append(Assertion(person, "hasOccupation", job, m.span(), 0.9))
        # Also record coordinated occupations ("a nurse and an engineer").
        # Otherwise only the first one would reach the SingleOccupation rule.
        pos = m.end()
        while (c := _COORD_OCCUP_RE.match(text, pos)) is not None:
            extra = c.group(1).lower()
            if extra not in OCCUPATIONS:
                break
            assertions.append(Assertion(person, "hasOccupation", extra, (m.start(), c.end()), 0.9))
            pos = c.end()
    for m in CITY_DESC_RE.finditer(text):
        city, desc = m.group(1), m.group(2).lower()
        assertions.append(Assertion(city, "hasDescriptor", desc, m.span(), 0.9))
    for m in ALLERGY_RE.finditer(text):
        assertions.append(Assertion(m.group(1), "hasAllergy", m.group(2).lower(), m.span()))
    for m in ATE_RE.finditer(text):
        # Only records what was eaten. The allergy comparison and the literal
        # "ate the table" check both happen in check().
        food = m.group(2).lower()
        assertions.append(Assertion(m.group(1), "consumed", food, m.span(), 0.7))
    for m in BUILT_IN_RE.finditer(text):
        landmark, year = m.group(1), int(m.group(2))
        assertions.append(Assertion(landmark, "builtYear", year, m.span(), 0.9))
    for m in YEARS_SINCE_RE.finditer(text):
        landmark, years = m.group(1), int(m.group(2))
        assertions.append(Assertion(landmark, "yearsSince", years, m.span(), 0.8))
    for m in ARITH_50_3_RE.finditer(text):
        claimed = int(m.group(1))
        assertions.append(Assertion("split_50_by_3", "claim", claimed, m.span()))

    # Power context and cooking verbs (lightweight, whole-text flags).
    if any(re.search(pat, text, re.IGNORECASE) for pat in NO_POWER_PATTERNS):
        assertions.append(Assertion("world", "noPower", True, (0, 0)))
    lowered = text.lower()
    # Keywords must start a word, so "woven"/"proven" do not count as "oven".
    # The -ing form drops a final "e", so "baking" also counts as "bake".
    if any(
        re.search(rf"\b(?:{re.escape(kw)}|{re.escape(kw.removesuffix('e'))}ing)", lowered)
        for kw in COOKING_REQUIRING_POWER
    ):
        assertions.append(Assertion("world", "cookingRequiresPower", True, (0, 0)))

    return assertions

# -------------------------
# 4) Deterministic Checker
# -------------------------
def _canonical_landmark(name: str) -> str:
    """Map a captured landmark name onto its ``LANDMARKS`` key when possible.

    Matching ignores case and a leading "the" ("The Eiffel Tower" ->
    "Eiffel Tower"). Unknown names are returned unchanged.
    """
    key = name.strip()
    if key.lower().startswith("the "):
        key = key[4:].lstrip()
    for known in LANDMARKS:
        if known.lower() == key.lower():
            return known
    return name


def _allergen_terms(allergen: str) -> set:
    """Return the food terms that count as exposure to ``allergen``.

    Resolution order:
    1. An ``ALLERGEN_MAP`` category ("nuts"), or its plural form ("nut").
    2. The word itself, or its singular form, when it is a known food term
       ("shrimp", "walnuts").
    3. Otherwise an empty set.
    """
    for category in (allergen, allergen + "s"):
        if category in ALLERGEN_MAP:
            return ALLERGEN_MAP[category]
    for term in (allergen, allergen.removesuffix("s")):
        if any(term in terms for terms in ALLERGEN_MAP.values()):
            return {term}
    return set()


def check(text: str, assertions: List[Assertion]) -> List[Violation]:
    """Apply the deterministic ontology rules to the extracted assertions.

    Args:
        text: The original story text (used for the literal "ate the table" check).
        assertions: Output of ``extract`` for the same text.

    Returns:
        Violations in this order: ArithmeticSplit (emitted while collecting),
        then per person MaxAge / Age&Marriage, then SingleOccupation,
        CityDescriptor, LandmarkYear, LandmarkAge, CausalCooking, AllergyFood
        and SemanticSelection.
    """
    viol: List[Violation] = []
    ages: Dict[str, int] = {}
    married: Dict[str, bool] = {}
    occs: Dict[str, List[str]] = {}

    # A person may have several allergies. A plain str value would keep only
    # the last one mentioned.
    allergies: Dict[str, List[str]] = {}
    consumed: Dict[str, List[str]] = {}

    city_desc: Dict[str, List[str]] = {}
    landmark_years: Dict[str, int] = {}
    landmark_since: Dict[str, int] = {}

    no_power = any(a.pred == "noPower" and a.obj is True for a in assertions)
    cooking = any(a.pred == "cookingRequiresPower" and a.obj is True for a in assertions)

    for a in assertions:
        if a.pred == "hasAge":
            ages[a.subj] = a.obj
        elif a.pred == "isMarried":
            married[a.subj] = True
        elif a.pred == "hasOccupation":
            occs.setdefault(a.subj, []).append(a.obj)
        elif a.pred == "hasAllergy":
            known = allergies.setdefault(a.subj, [])
            if a.obj not in known:
                known.append(a.obj)
        elif a.pred == "consumed":
            consumed.setdefault(a.subj, []).append(a.obj)
        elif a.pred == "hasDescriptor":
            city_desc.setdefault(a.subj, []).append(a.obj)
        elif a.pred == "builtYear":
            landmark_years[_canonical_landmark(a.subj)] = a.obj
        elif a.pred == "yearsSince":
            landmark_since[_canonical_landmark(a.subj)] = a.obj
        elif a.subj == "split_50_by_3" and a.pred == "claim":
            # 50 / 3 = 16.67 per person. 16 each (with a remainder) or 17
            # (rounded) is acceptable; any other claimed share is wrong.
            if a.obj not in (16, 17):
                viol.append(Violation(
                    "ArithmeticSplit",
                    f"Claimed 50/3 = {a.obj}.",
                    "Correct to 16 or 16–17 with remainder; or rephrase as an approximate split."
                ))

    # Age & marriage; max age
    for p, age in ages.items():
        if age > MAX_AGE:
            viol.append(Violation("MaxAge", f"{p} age {age} exceeds {MAX_AGE}.", f"Set age ≤ {MAX_AGE}."))
        if married.get(p, False) and age < LEGAL_MARRIAGE_AGE:
            viol.append(Violation("Age&Marriage", f"{p} is married at {age}.", f"Increase age to ≥ {LEGAL_MARRIAGE_AGE} or remove marriage."))

    # Single occupation
    for p, jobs in occs.items():
        if len(set(jobs)) > 1:
            viol.append(Violation("SingleOccupation", f"{p} has multiple occupations: {jobs}.",
                                  "Keep one occupation consistent with the narrative."))

    # City descriptor vs. population hint
    for city, descs in city_desc.items():
        if CITY_POPULATION_HINT.get(city, {}).get("population_hint") == "high":
            if any(d in QUIET_TERMS for d in descs) and not any(d in BUSY_TERMS for d in descs):
                viol.append(Violation("CityDescriptor",
                                      f"{city} described as {descs} despite high population.",
                                      "Harmonize descriptors (e.g., ‘busy’/‘crowded’) or contextualize a quiet setting (time/place)."))

    # Landmark arithmetic
    for lm, y in landmark_years.items():
        if lm in LANDMARKS and LANDMARKS[lm]["built_year"] != y:
            viol.append(Violation("LandmarkYear",
                                  f"{lm} built year stated {y} but known {LANDMARKS[lm]['built_year']}.",
                                  f"Use built in {LANDMARKS[lm]['built_year']}."))
    for lm, years in landmark_since.items():
        if lm in LANDMARKS:
            expected = CURRENT_YEAR - LANDMARKS[lm]["built_year"]
            # Allow ±1 year of slack for phrasing such as "over 135 years".
            if abs(expected - years) > 1:
                viol.append(Violation("LandmarkAge",
                                      f"{lm} age stated {years} years; expected ≈ {expected}.",
                                      f"Update to {expected} years (in {CURRENT_YEAR})."))

    # Causal cooking (power)
    if no_power and cooking:
        viol.append(Violation("CausalCooking",
                              "Baking/cooking requiring power while no electricity.",
                              "Add a power source, or change to a no-bake method."))

    # Allergy vs. consumption (terms are matched as substrings of the eaten phrase)
    for p, person_allergens in allergies.items():
        foods = consumed.get(p, [])
        for allergen in person_allergens:
            bad_set = _allergen_terms(allergen)
            if any(term in f for f in foods for term in bad_set):
                viol.append(Violation("AllergyFood",
                                      f"{p} allergic to {allergen} but consumed {foods}.",
                                      "Replace dish/ingredient with a safe alternative."))
    # Semantic selection: the literal phrase "ate the table"
    if re.search(r"\bate\s+the\s+table\b", text.lower()):
        viol.append(Violation("SemanticSelection",
                              "Literal ‘ate the table’.",
                              "Change to ‘ate at the table’ or similar."))

    return viol

# -------------------------
# 5) Minimal-Edit Rewriter (Ollama)
# -------------------------
# System message for the rewriter. It is passed to ChatPromptTemplate as a
# template, so it must not contain literal { } (they would be read as
# template variables). New facts are allowed only when a fix hint asks for
# one (e.g. "Add a power source"); a blanket ban would contradict such hints.
SYSTEM_PROMPT = """You are a careful editor. Apply ONLY the minimal textual edits needed
to resolve the listed violations. Preserve style, plot, and names. Do not add new facts
unless a violation's fix_hint explicitly calls for one (e.g. adding a power source).
Output only the revised story text: no code fences, headings, or commentary.
"""

# A reply wrapped in one Markdown code fence (optionally tagged as a text-like
# language). Group 1 is the content inside the fence.
_CODE_FENCE_RE = re.compile(
    r"\A```(?:(?:text|markdown|md|plaintext)?[ \t]*\n)?(.*?)\n?```\s*\Z",
    re.DOTALL,
)


def _strip_code_fence(reply: str) -> str:
    """Return ``reply`` without one surrounding Markdown code fence, if present."""
    m = _CODE_FENCE_RE.match(reply)
    return m.group(1).strip() if m else reply


def rewrite_with_llm(text: str, violations: List[Violation], model: str = "llama3.1:8b") -> str:
    """Ask a local Ollama model to minimally edit ``text`` so the violations go away.

    Args:
        text: Story text to revise.
        violations: Violations from ``check``; they are serialized to JSON for the prompt.
        model: Ollama model tag passed to ``ChatOllama``.

    Returns:
        The revised text. ``text`` is returned unchanged when there is nothing
        to fix, or when the model reply is empty after cleanup.
    """
    # Nothing to fix, so skip the model call entirely.
    if not violations:
        return text
    chat = ChatOllama(model=model, temperature=0)
    vlist = [{"rule": v.rule, "description": v.description, "fix_hint": v.fix_hint} for v in violations]
    prompt = ChatPromptTemplate.from_messages([
        ("system", SYSTEM_PROMPT),
        ("human", "Original text:\n```{text}```\n\nViolations:\n```{violations}```\n\nReturn ONLY the minimally revised text.")
    ])
    msg = prompt.format_messages(text=text, violations=json.dumps(vlist, ensure_ascii=False, indent=2))
    resp = chat.invoke(msg)
    # The prompt wraps the story in ``` fences, and models often echo them back.
    revised = _strip_code_fence(resp.content.strip())
    # An empty reply would yield an "empty story" with no assertions, which
    # scores as fully consistent. Keep the unrevised text instead.
    return revised or text

# -------------------------
# 6) Orchestrator
# -------------------------
def run_once(text: str) -> Tuple[str, List[Violation]]:
    """Run one extract-then-check pass over ``text``.

    Returns:
        ``(text, violations)``: the input echoed back unchanged, plus the
        violations that ``check`` reports for it.
    """
    found = check(text, extract(text))
    return text, found

def repair(text: str, max_loops: int = 2, model: str = "llama3.1:8b") -> Dict[str, Any]:
    """Check ``text`` and iteratively ask the LLM to fix the reported violations.

    At most ``max_loops`` rewrite rounds run, and the loop stops as soon as a
    version has no violations. Every rewrite is re-checked. The version with
    the fewest violations is returned (the latest one on ties), so a rewrite
    that introduces new problems cannot make the result worse than an
    earlier version.

    Returns:
        A dict with three keys:
        - ``final_text``: the selected version.
        - ``final_violations``: its violations.
        - ``history``: one ``{"pass": n, "violations": [...]}`` entry per
          round, holding the violations of the text that round started from.
    """
    history: List[Dict[str, Any]] = []
    current = text
    _, violations = run_once(current)
    best_text, best_violations = current, violations
    for i in range(max_loops):
        history.append({"pass": i + 1, "violations": violations})
        if not violations:
            break
        current = rewrite_with_llm(current, violations, model=model)
        _, violations = run_once(current)
        # Ties go to the later text. That matches "return the last rewrite"
        # whenever the rewrites do not regress.
        if len(violations) <= len(best_violations):
            best_text, best_violations = current, violations
    return {"final_text": best_text, "final_violations": best_violations, "history": history}

# -------------------------
# 7) Evaluation helper
# -------------------------
# Number of distinct rule types that check() can emit: Age&Marriage, MaxAge,
# SingleOccupation, CityDescriptor, LandmarkYear, LandmarkAge, CausalCooking,
# AllergyFood, ArithmeticSplit, SemanticSelection. Adjust this if you
# enable/disable checks.
RULES_COUNT = 10


def consistency_score(violations: List["Violation"], rules_count: int = RULES_COUNT) -> float:
    """Return the fraction of rules with no violation, clamped to ``[0.0, 1.0]``.

    Each rule counts once, however many times it fires. The numerator
    (violated rules) therefore has the same unit as ``rules_count``.
    A non-positive ``rules_count`` is treated as 1.
    """
    violated_rules = {v.rule for v in violations}
    return max(0.0, 1.0 - (len(violated_rules) / max(1, rules_count)))

def summarize_report(result: Dict[str, Any]) -> str:
    """Render a ``repair`` result as a multi-line, human-readable report.

    Lists each pass with its violations and fix hints, then the final check
    with its consistency score and the remaining violations.
    """
    out: List[str] = []
    for entry in result["history"]:
        found = entry["violations"]
        out.append(f"Pass {entry['pass']}: {len(found)} violation(s).")
        out.extend(f"  - [{v.rule}] {v.description} | Fix: {v.fix_hint}" for v in found)

    remaining = result["final_violations"]
    score = consistency_score(remaining)
    out.append(f"\nFinal check: {len(remaining)} violation(s). Consistency score = {score:.2f}")
    out.extend(f"  - [{v.rule}] {v.description}" for v in remaining)
    return "\n".join(out)

# -------------------------
# 8) Demo
# -------------------------
# Demo story seeded with inconsistencies for the checker to find.
demo_story = """
Amelia is 17 years old and is married. In Paris, the streets felt empty and quiet.
She works as a nurse and an engineer. With a power outage, she baked a cake.
Amelia is allergic to nuts but ate walnut cake. The Eiffel Tower has been standing for 100 years.
They split 50 between 3 friends and each got 20. He ate the table.
"""

result = repair(demo_story, max_loops=2, model="llama3.1:8b")
# One print with sep="\n" emits exactly the same bytes as three separate prints.
print(summarize_report(result), "\n--- Revised Story ---\n", result["final_text"], sep="\n")
