In [None]:
from pathlib import Path
from typing import Dict, Any, List
import json, re
from llama_cpp import Llama
# Absolute path (as a string) to the GGUF model file, resolved relative to the
# parent of the directory that contains this file: <parent>/models/llm/<file>.
# NOTE(review): `__file__` is normally undefined inside a notebook cell, and
# this source begins with an "In [None]:" cell marker — confirm it is run as a
# module/script. The model file name also looks like a placeholder; verify.
LLM_PATH = str(Path(__file__).resolve().parents[1] / "models" / "llm" / "your-7b-instruct-q5_k_m.gguf")
# Single module-level model handle, constructed once when this module is
# imported and used by _chat(). The numeric settings (context size, GPU layer
# count, thread count) are hard-coded here — presumably tuned for one target
# machine; see the llama-cpp-python `Llama` reference before changing them.
llm = Llama(
    model_path=LLM_PATH,
    n_ctx=4096,
    n_gpu_layers=50,
    n_threads=8,
    verbose=False,
)

# System message placed first in the chat that answer() sends to the model.
# It fixes the advisor persona and the rules the model is asked to follow:
# simple stepwise language, non-chemical remedies first, PPE/PHI whenever
# chemicals are mentioned, and at most one follow-up question.
# This text is runtime behaviour: any edit changes what the model is told.
SYSTEM_PROMPT = (
    "You are KISAAN-SATHI, an offline farm advisor for smallholders.\n"
    "Rules:\n"
    "- Use simple, step-by-step language.\n"
    "- Prefer non-chemical remedies first. If chemicals appear, include PPE and PHI.\n"
    "- Ask at most one concise follow-up if key info (crop/symptom) is missing.\n"
    "- Do not hallucinate; keep advice practical and safe.\n"
)

def _chat(messages: List[Dict[str, str]], max_tokens: int = 512) -> str:
    """Run one chat completion on the module-level ``llm`` and return its text.

    Args:
        messages: Chat messages forwarded unchanged to
            ``llm.create_chat_completion``.
        max_tokens: Generation limit forwarded as ``max_tokens``.

    Returns:
        The ``content`` field of the first choice's message.
    """
    # Temperature is fixed at 0.3 for every request made through this helper.
    completion = llm.create_chat_completion(
        messages=messages,
        max_tokens=max_tokens,
        temperature=0.3,
    )
    first_choice = completion["choices"][0]
    return first_choice["message"]["content"]
# Directory holding the markdown crop cards, resolved relative to the parent
# of the directory that contains this file: <parent>/kb/crop_cards.
KB_DIR = Path(__file__).resolve().parents[1].joinpath("kb", "crop_cards")


def _load_docs() -> List[Dict[str, Any]]:
    """Read every ``*.md`` file directly inside ``KB_DIR``.

    Returns:
        One ``{"path": <file name>, "text": <UTF-8 contents>}`` dict per
        file, ordered by sorted path.
    """
    card_paths = sorted(KB_DIR.glob("*.md"))
    return [
        {"path": str(card.name), "text": card.read_text(encoding="utf-8")}
        for card in card_paths
    ]

# Holds the loaded crop cards after the first _get_docs() call; None until then.
_DOCS_CACHE = None


def _get_docs():
    """Return the crop cards, loading them from disk only on first use.

    The list produced by ``_load_docs()`` is stored in ``_DOCS_CACHE`` and the
    same list object is handed back on every later call. An empty list is a
    valid cached value, so a missing or empty knowledge base is not re-read.
    """
    global _DOCS_CACHE
    if _DOCS_CACHE is not None:
        return _DOCS_CACHE
    _DOCS_CACHE = _load_docs()
    return _DOCS_CACHE

# Canonical crop name -> alias substrings used to recognise that crop.
# Aliases are matched with a plain `in` substring test against lowercased
# text (see simple_retrieve, need_clarification and answer), so they must be
# lowercase. The aliases appear to mix English, transliterated Hindi and
# genus names. Dict order matters: answer() picks the first crop that matches.
CROP_KEYWORDS = {
    "tomato": ["tomato", "tamatar", "solanum"],
    "chili": ["chili", "mirchi", "capsicum"],
    "brinjal": ["brinjal", "baingan", "eggplant"],
    "rice": ["rice", "dhaan", "paddy"],
    "wheat": ["wheat", "gehun"],
}

# Lowercase disease phrases; simple_retrieve() adds one point to a document's
# score for each phrase found in both the query and the document text.
DISEASE_HINTS = [
    "early blight", "septoria", "powdery mildew", "bacterial spot", "leaf curl",
]


def simple_retrieve(query: str, k: int = 4) -> List[Dict[str, Any]]:
    """Rank the cached crop cards against ``query`` by keyword overlap.

    A card earns one point for:
      * each crop alias in ``CROP_KEYWORDS`` present in both query and card,
      * each phrase in ``DISEASE_HINTS`` present in both query and card,
      * each whitespace-separated query word found in the card (repeated
        query words count once per repetition).
    All comparisons are lowercase substring tests.

    Args:
        query: Free-text farmer message.
        k: Slice bound applied to the ranked list.

    Returns:
        The first ``k`` cards, highest score first. Ties keep the cards'
        original order, and zero-score cards are not filtered out.
    """
    needle = query.lower()
    cards = _get_docs()
    query_words = needle.split()
    crop_aliases = [alias for aliases in CROP_KEYWORDS.values() for alias in aliases]

    def overlap(card: Dict[str, Any]) -> int:
        body = card["text"].lower()
        crop_points = sum(1 for alias in crop_aliases if alias in needle and alias in body)
        disease_points = sum(1 for phrase in DISEASE_HINTS if phrase in needle and phrase in body)
        word_points = sum(1 for word in query_words if word in body)
        return crop_points + disease_points + word_points

    # sorted() is stable, also with reverse=True, so equal scores keep load order.
    ranked = sorted(cards, key=overlap, reverse=True)
    return ranked[:k]
# Hand-written rule rows: for a crop, a list of lowercase symptom phrases and
# the diseases they point to. "conditions" is carried as data but is not read
# by symptom_rules().
SYMPTOM_TABLE = [
    {
        "crop": "tomato",
        "pattern": ["small brown spots", "yellow halo", "older leaves", "after rain"],
        "candidates": ["early blight", "septoria"],
        "conditions": "warm + humid, rain splash",
    },
    {
        "crop": "chili",
        "pattern": ["white powder", "leaf surface", "dry weather"],
        "candidates": ["powdery mildew"],
        "conditions": "dry days + cool nights",
    },
]


def symptom_rules(crop: str, symptoms: str) -> List[Dict[str, Any]]:
    """Suggest candidate diseases for ``crop`` from ``SYMPTOM_TABLE``.

    Rows whose crop equals ``crop`` (case-insensitive) are scored by how many
    of their pattern phrases occur as substrings of the lowercased
    ``symptoms``. Only the best-scoring row is used; on a tie the row that
    appears first in the table wins.

    Args:
        crop: Crop name to look up.
        symptoms: Free-text symptom description.

    Returns:
        ``[]`` when no row exists for the crop. Otherwise one
        ``{"disease": ..., "confidence": ...}`` dict per candidate of the
        best row, with confidence ``"medium"`` if at least one phrase matched
        and ``"low"`` if none did.
    """
    description = symptoms.lower()
    wanted_crop = crop.lower()

    crop_rows = [row for row in SYMPTOM_TABLE if row["crop"] == wanted_crop]
    if not crop_rows:
        return []

    def phrase_hits(row: Dict[str, Any]) -> int:
        return sum(1 for phrase in row["pattern"] if phrase in description)

    # max() returns the first maximal row, matching a stable descending sort.
    best_row = max(crop_rows, key=phrase_hits)
    confidence = "medium" if phrase_hits(best_row) >= 1 else "low"
    return [{"disease": name, "confidence": confidence} for name in best_row["candidates"]]
# Lowercase substrings that mark a reply as mentioning a chemical product.
CHEM_HINTS = ["copper", "chlorothalonil", "mancozeb", "carbendazim"]


def enforce_guardrails(text: str) -> str:
    """Append missing safety lines to a drafted reply.

    Nothing is removed or rewritten; lines are only appended, in this order:
      * if any ``CHEM_HINTS`` term occurs (case-insensitive): a protective
        equipment line unless the text already has ``"PPE"`` (case-sensitive)
        or ``"gloves"`` (case-insensitive), then a pre-harvest interval line
        unless the text already has ``"PHI"`` (case-sensitive);
      * if the text mentions bleach / sodium hypochlorite together with
        acid / vinegar: a warning against mixing them.

    Args:
        text: Drafted reply.

    Returns:
        ``text`` followed by whichever safety lines apply.
    """
    lowered = text.lower()
    additions = []

    if any(chemical in lowered for chemical in CHEM_HINTS):
        already_has_ppe = "PPE" in text or "gloves" in lowered
        if not already_has_ppe:
            additions.append("\n- Safety: Wear gloves, mask; avoid spray drift; keep kids/animals away.")
        if "PHI" not in text:
            additions.append("\n- PHI: Follow product label; typically 3–7 days before harvest for many copper products.")

    # The appended lines contain none of the trigger words below, so checking
    # the original text gives the same answer as checking the extended text.
    mentions_bleach = "bleach" in lowered or "sodium hypochlorite" in lowered
    mentions_acid = "acid" in lowered or "vinegar" in lowered
    if mentions_bleach and mentions_acid:
        additions.append("\n- Warning: Never mix bleach with acids/vinegar—dangerous chlorine gas.")

    return text + "".join(additions)
NEEDED_FIELDS = ["crop", "symptoms"]

# Lowercase substrings that indicate the farmer described a symptom.
# Every example offered in the symptom follow-up prompt below must be
# recognised here, otherwise a farmer who answers exactly as prompted is asked
# the same question again. "murjh" covers "murjhna"/"murjha" from that prompt.
_SYMPTOM_KEYWORDS = ("spot", "daag", "halo", "powder", "curl", "wilt", "sookh", "murjh")


def need_clarification(user_text: str) -> str:
    """Return one follow-up question if the message lacks a crop or a symptom.

    Detection is a case-insensitive substring test: a crop is present when
    any alias in ``CROP_KEYWORDS`` occurs in the message, a symptom when any
    entry of ``_SYMPTOM_KEYWORDS`` does.

    Args:
        user_text: Raw farmer message.

    Returns:
        A question asking for both crop and symptoms, for the crop only, or
        for the symptoms only, depending on what is missing; the empty string
        when both are present.
    """
    q = user_text.lower()
    has_crop = any(alias in q for aliases in CROP_KEYWORDS.values() for alias in aliases)
    has_symptom = any(keyword in q for keyword in _SYMPTOM_KEYWORDS)
    if not has_crop and not has_symptom:
        return "Kaunsi fasal (crop) hai aur patton/phal par kya nishaan (symptoms) dikh rahe hain?"
    if not has_crop:
        return "Kaunsi fasal (crop) hai? Tomato/Chili/Brinjal/Rice/Wheat?"
    if not has_symptom:
        return "Symptoms bataiye—daag, safed powder, murjhna, patta curl, etc.?"
    return ""

def answer(user_text: str) -> str:
    """Produce the advisor's reply to one farmer message.

    Steps:
      1. If ``need_clarification`` yields a question, return it unchanged.
      2. Pick the first crop in ``CROP_KEYWORDS`` with an alias contained in
         the lowercased message (falling back to ``"general"``).
      3. Retrieve up to three crop cards and keep the first 600 characters of
         each as evidence; compute rule-based disease candidates.
      4. Ask the model for a plan and pass the draft through
         ``enforce_guardrails``.

    Args:
        user_text: Raw farmer message.

    Returns:
        Either a follow-up question or the guarded model reply.
    """
    follow_up = need_clarification(user_text)
    if follow_up:
        return follow_up

    lowered = user_text.lower()
    detected = next(
        (name for name, aliases in CROP_KEYWORDS.items() if any(alias in lowered for alias in aliases)),
        None,
    )
    crop = detected or "general"

    snippets = []
    for card in simple_retrieve(user_text, k=3):
        excerpt = card["text"][:600]
        snippets.append(f"[KB:{card['path']}]\n{excerpt}")
    evidence = "\n\n".join(snippets)

    rules = symptom_rules(crop, user_text)

    # Adjacent literals concatenate into exactly the single prompt string.
    user_prompt = (
        f"Farmer message: {user_text}\n"
        f"Crop: {crop}\n"
        f"RuleCandidates: {json.dumps(rules, ensure_ascii=False)}\n"
        f"EVIDENCE:\n{evidence}\n\n"
        "Write a short, stepwise plan. Start with organic actions. "
        "If chemicals are mentioned, include PPE and PHI."
    )
    messages = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": user_prompt},
    ]

    return enforce_guardrails(_chat(messages, max_tokens=600))