# In [None]:
"""
RL Pentest Agent — nmap PoC (Core-Only)
---------------------------------------
핵심만 남긴 최소 구현: 스키마→행동 생성→상상→실행→보상(외재/내재)→신념 업데이트.
- 관측 제한: 문법(스키마), 자기 히스토리, 마지막 원시 출력만 사용.
- 계획: 정책이 top-k 후보 → 월드모델 상상 점수로 1개 선택.
- 보상: 외재(FLAG=+1, 오류=-0.2), 내재(예측오차 기반).

제거: 상세 Validator/KnowledgeDB(저장 없이 단순 히스토리), 잡기능.
테스트: 간단 실행/멀티라인/정책크레딧/템플릿 렌더.
"""
from __future__ import annotations

import json
import random
import re
import time
import zlib
from dataclasses import dataclass
from typing import Any, Dict, List, Tuple

# -----------------------------
# Schema (문법만 전달)
# -----------------------------
@dataclass
class ToolSchema:
    """Grammar-only description of a command-line tool.

    Only the tool name and its command template are read from the schema
    document; every other key is ignored.
    """

    name: str  # value of ``tool.name`` in the schema document
    command_template: str  # value of ``tool.command_template``

    @classmethod
    def from_yaml_text(cls, yaml_text: str) -> "ToolSchema":
        """Build a schema from YAML text, falling back to JSON.

        PyYAML is used when it can be imported.  If it is missing, or if it
        rejects the text, the text is parsed as JSON instead.

        Args:
            yaml_text: Document with a top-level ``tool`` mapping holding
                ``name`` and ``command_template``.

        Returns:
            A new schema instance.

        Raises:
            ValueError: The document (or its ``tool`` entry) is not a
                mapping, or the JSON fallback could not parse the text.
            KeyError: ``tool.name`` or ``tool.command_template`` is missing.
        """
        try:
            import yaml  # type: ignore
        except ImportError:
            yaml = None

        if yaml is None:
            obj = json.loads(yaml_text)
        else:
            try:
                obj = yaml.safe_load(yaml_text)
            except yaml.YAMLError:
                # Some JSON documents (e.g. tab-indented) are rejected by the
                # YAML parser; give the JSON parser a chance before failing.
                obj = json.loads(yaml_text)

        # An empty document parses to None and a bare scalar to str/int;
        # report that explicitly instead of failing on ``.get`` below.
        if not isinstance(obj, dict):
            raise ValueError(
                "schema document must be a mapping with a 'tool' key, "
                f"got {type(obj).__name__}"
            )
        tool = obj.get("tool") or {}
        if not isinstance(tool, dict):
            raise ValueError(
                f"'tool' must be a mapping, got {type(tool).__name__}"
            )
        return cls(name=tool["name"], command_template=tool["command_template"])

# -----------------------------
# Template (아주 작은 서브셋)
# -----------------------------
class Template:
    """Tiny template renderer supporting three constructs.

    * ``{{ name }}`` - variable substitution
    * ``{% if name %}...{% endif %}`` - non-nested conditional block
    * ``{{ name | join('sep') }}`` - join a list/tuple with ``sep``

    Names may be dotted (``tool.name``).  A key spelled exactly like the
    dotted name wins; otherwise the dots are followed through nested dicts.
    """

    _var = re.compile(r"{{\s*([a-zA-Z0-9_\.]+)\s*}}")
    _if = re.compile(r"{%\s*if\s+([a-zA-Z0-9_\.]+)\s*%}(.*?){%\s*endif\s*%}", re.S)
    _join = re.compile(r"{{\s*([a-zA-Z0-9_\.]+)\s*\|\s*join\('([^']*)'\)\s*}}")

    @staticmethod
    def _lookup(ctx: Dict[str, Any], name: str, default: Any = None) -> Any:
        """Resolve ``name`` in ``ctx``; return ``default`` when it is absent."""
        if name in ctx:
            return ctx[name]
        node: Any = ctx
        for part in name.split("."):
            if isinstance(node, dict) and part in node:
                node = node[part]
            else:
                return default
        return node

    @staticmethod
    def render(s: str, ctx: Dict[str, Any]) -> str:
        """Render template ``s`` against ``ctx``.

        Conditionals are expanded first, then join expressions, then plain
        variables.  Finally every whitespace run is collapsed to one space
        and the result is stripped.

        Args:
            s: Template text.
            ctx: Values available to the template.

        Returns:
            The rendered single-line string.  Unknown variables render as
            an empty string; a join over a non-list/tuple renders as empty.
        """

        def expand_if(m: "re.Match[str]") -> str:
            # Keep the block body only when the referenced value is truthy.
            return m.group(2) if Template._lookup(ctx, m.group(1)) else ""

        def expand_join(m: "re.Match[str]") -> str:
            value = Template._lookup(ctx, m.group(1))
            if isinstance(value, (list, tuple)):
                return m.group(2).join(map(str, value))
            return ""

        def expand_var(m: "re.Match[str]") -> str:
            return str(Template._lookup(ctx, m.group(1), ""))

        s = Template._if.sub(expand_if, s)
        # Joins must run before plain variables so ``{{ x | join(..) }}`` is
        # consumed as a whole.
        s = Template._join.sub(expand_join, s)
        s = Template._var.sub(expand_var, s)
        return re.sub(r"\s+", " ", s).strip()

# -----------------------------
# Belief & World Model (핵심)
# -----------------------------
class Belief:
    """Fixed-size belief vector maintained as an exponential moving average.

    Text is embedded with the hashing trick: each token increments one of
    ``dim`` buckets and the result is L2-normalised.  Buckets are chosen with
    CRC-32 rather than the built-in ``hash`` because string hashes are salted
    per interpreter process, which would make embeddings differ between runs.
    """

    def __init__(self, dim: int = 64):
        """Create an all-zero belief of ``dim`` components.

        Raises:
            ValueError: ``dim`` is not positive.
        """
        if dim <= 0:
            raise ValueError("dim must be a positive integer")
        self.dim = dim
        self.h = [0.0] * dim

    def _embed(self, text: str) -> List[float]:
        """Return the L2-normalised bag-of-tokens embedding of ``text``.

        Tokens are runs of ``[a-zA-Z0-9_:-]`` taken from the lowercased text.
        Text without tokens yields the zero vector.
        """
        v = [0.0] * self.dim
        for tok in re.findall(r"[a-zA-Z0-9_:-]+", text.lower()):
            v[zlib.crc32(tok.encode("utf-8")) % self.dim] += 1.0
        # ``or 1.0`` avoids dividing by zero for the all-zero vector.
        n = (sum(x * x for x in v) ** 0.5) or 1.0
        return [x / n for x in v]

    def update(self, cmd: str, out_text: str):
        """Blend the embedding of ``cmd`` plus ``out_text`` into the belief."""
        x = self._embed(cmd + " " + out_text)
        a = 0.2  # weight given to the new observation
        self.h = [(1 - a) * h + a * x_i for h, x_i in zip(self.h, x)]

    def vec(self) -> List[float]:
        """Return a copy of the current belief vector."""
        return list(self.h)

class WorldModel:
    """Heuristic one-step model that scores an action against the belief."""

    # Context fields whose values are option flags.
    _FLAG_KEYS = ("discovery", "scans", "service_version", "os_detect", "script")

    def __init__(self, belief: Belief):
        """Keep a reference to the shared ``belief`` (not a copy)."""
        self.belief = belief

    def _act_vec(self, ctx: Dict[str, Any]) -> List[float]:
        """Embed the action described by ``ctx`` in the belief's vector space.

        Flag values and a ``p:<ports>`` token are joined into one string and
        passed through ``belief._embed``.  Using the belief's own tokenizer
        and bucket function keeps action and belief vectors comparable; an
        independent hash would place the same flag in unrelated buckets.
        """
        toks: List[str] = []
        for key in self._FLAG_KEYS:
            value = ctx.get(key) or []
            if isinstance(value, (list, tuple)):
                toks.extend(str(item) for item in value)
            else:
                toks.append(str(value))
        if ctx.get("ports"):
            toks.append("p:" + str(ctx["ports"]))
        return self.belief._embed(" ".join(toks))

    def imagine(self, obs: Dict[str, Any], ctx: Dict[str, Any], H: int = 3) -> Dict[str, Any]:
        """Predict the outcome of taking action ``ctx``.

        Args:
            obs: Current observation (accepted for interface symmetry; the
                computation only reads the belief).
            ctx: Candidate action context.
            H: Number of entries in the predicted intrinsic-reward list.

        Returns:
            Dict with ``pred_next`` (mean of belief and action vectors),
            ``pred_int`` (``H`` copies of one predicted intrinsic reward)
            and ``value`` (scalar value estimate).
        """
        b = self.belief.vec()
        a = self._act_vec(ctx)
        # Belief/action similarity, clipped to [-0.8, 0.8].
        val = max(min(sum(bi * ai for bi, ai in zip(b, a)), 0.8), -0.8)
        pred_next = [(bi + ai) / 2 for bi, ai in zip(b, a)]
        return {
            "pred_next": pred_next,
            "pred_int": [0.1 + abs(val) * 0.2] * H,
            "value": 0.2 + 0.5 * max(0.0, val),
        }

    def intrinsic(self, pred_vec: List[float], actual_text: str) -> float:
        """Return the prediction error against ``actual_text``, in [0, 0.5].

        The error is the Euclidean distance between ``pred_vec`` and the
        belief embedding of ``actual_text``.
        """
        act = self.belief._embed(actual_text)
        err = sum((p - a) ** 2 for p, a in zip(pred_vec, act)) ** 0.5
        return max(0.0, min(0.5, err))

# -----------------------------
# Policy (분포→top-k)
# -----------------------------
class Policy:
    """Random candidate generator biased by running per-flag preferences."""

    def __init__(self):
        # Preference score per tracked flag; all start at zero.
        self.pref = dict.fromkeys(("-sS", "-sU", "-sT", "-sV", "-O", "-sC"), 0.0)

    def credit(self, ctx: Dict[str, Any], r: float):
        """Move the preference of each tracked flag in ``ctx`` towards ``r``.

        Flags that are not keys of ``self.pref`` are ignored.
        """
        for field in ("discovery", "scans", "service_version", "os_detect", "script"):
            value = ctx.get(field) or []
            flags = value if isinstance(value, list) else [value]
            for flag in flags:
                if flag in self.pref:
                    self.pref[flag] = 0.9 * self.pref[flag] + 0.1 * r

    def propose(self, obs: Dict[str, Any], k: int = 5) -> List[Dict[str, Any]]:
        """Return ``k`` randomly drawn action contexts.

        ``2 * k`` contexts are drawn and the first ``k`` are returned.  Each
        context always contains the currently most-preferred flag in its
        ``scans`` list; fields whose value is ``None`` or empty are left out.
        """

        def draw(pool: List[str], lo: int, hi: int) -> List[str]:
            # The sample size is drawn first, then the sample itself.
            return random.sample(pool, k=random.randint(lo, hi))

        targets = obs.get("targets", ["127.0.0.1"])
        drawn: List[Dict[str, Any]] = []
        for _ in range(k * 2):
            scans = draw(["-sS", "-sU", "-sT", "-sA", "-sN", "-sF", "-sX"], 1, 2)
            favourite = max(self.pref, key=self.pref.get)
            if favourite not in scans:
                scans[0] = favourite
            discovery = draw(["-Pn", "-PE", "-PS"], 0, 1)
            ports = random.choice([None, "22", "80,443", "1-1024"])
            service_version = draw(["-sV", "--version-light"], 0, 1)
            os_detect = draw(["-O"], 0, 1)
            script = draw(["-sC", "--script http-*"], 0, 1)
            timing = random.choice([None, 3, 4])
            fields = {
                "targets": targets,
                "discovery": discovery,
                "scans": scans,
                "ports": ports,
                "service_version": service_version,
                "os_detect": os_detect,
                "script": script,
                "timing": timing,
            }
            drawn.append(
                {name: val for name, val in fields.items() if val not in (None, [], {})}
            )
        return drawn[:k]

# -----------------------------
# Env (관측 제한 + 외/내재 보상 트리거)
# -----------------------------
class Env:
    """Simulated target that returns canned output for any command."""

    def __init__(self):
        # (command, raw output) pairs in execution order.
        self.history: List[Tuple[str, str]] = []

    def reset(self) -> Dict[str, Any]:
        """Forget all past commands and return the initial observation."""
        self.history.clear()
        initial = {"targets": ["127.0.0.1"], "history": []}
        return initial

    def step(self, cmd: str) -> Tuple[Dict[str, Any], float, bool, Dict[str, Any]]:
        """Pretend to run ``cmd``.

        Roughly one call in ten reports a timeout with extrinsic reward -0.2;
        otherwise a random hint is emitted and the reward is 1.0 only when
        the hint contains ``FLAG{``.

        Returns:
            ``(observation, extrinsic_reward, done, info)`` where ``done`` is
            always ``False`` and ``info`` is an empty dict.
        """
        time.sleep(0.002)
        succeeded = random.random() >= 0.1
        if succeeded:
            hint = random.choice(["open port 80", "service: ssh", "dir: /admin", "no response", ""])
            raw = f"EXECUTED: {cmd}\nOUTPUT: {hint}"
            ext = 1.0 if "FLAG{" in hint else 0.0
        else:
            raw = f"EXECUTED: {cmd}\nERROR: timeout"
            ext = -0.2
        self.history.append((cmd, raw))
        observation = {
            "last_output": raw,
            "targets": ["127.0.0.1"],
            "hist_len": len(self.history),
        }
        return observation, ext, False, {}

# -----------------------------
# Agent (상상→실행→사후보상→신념/정책 업데이트)
# -----------------------------
@dataclass
class Cfg:
    """Tunable knobs read by the agent."""

    horizon: int = 3  # length of the imagined intrinsic-reward list
    topk: int = 5  # number of candidates requested from the policy
    alpha: float = 0.5  # weight on intrinsic reward terms
    beta: float = 1.0  # weight on the imagined value term

class Agent:
    """Propose, imagine, execute, then update belief and policy."""

    def __init__(self, schema: ToolSchema, env: Env, cfg: "Cfg | None" = None):
        """Wire up the components and reset the environment.

        Args:
            schema: Tool schema whose command template is rendered.
            env: Environment the commands are sent to.
            cfg: Settings; a fresh ``Cfg()`` is created per agent when
                omitted, so agents never share one mutable default.
        """
        self.schema = schema
        self.env = env
        self.cfg = Cfg() if cfg is None else cfg
        self.tpl = Template()
        self.belief = Belief()
        self.world = WorldModel(self.belief)
        self.pi = Policy()
        self.state = self.env.reset()
        self._last_pred = None  # initialised here; nothing in this class updates it

    def _render(self, ctx: Dict[str, Any]) -> str:
        """Render the schema's command template with ``ctx``."""
        return self.tpl.render(self.schema.command_template, ctx)

    def _score(self, obs: Dict[str, Any], ctx: Dict[str, Any]) -> Tuple[float, Dict[str, Any]]:
        """Return ``(score, imagination)`` for candidate ``ctx``.

        The score is ``alpha * sum(pred_int) + beta * value`` computed from
        the world model's imagined outcome.
        """
        imag = self.world.imagine(obs, ctx, H=self.cfg.horizon)
        score = self.cfg.alpha * sum(imag["pred_int"]) + self.cfg.beta * imag["value"]
        return score, imag

    def step(self) -> Tuple[str, Dict[str, Any]]:
        """Run one decision cycle.

        Returns:
            ``(command, observation)`` for the executed command.

        Raises:
            ValueError: The policy returned no candidates.
        """
        cands = self.pi.propose(self.state, k=self.cfg.topk)
        if not cands:
            raise ValueError("policy proposed no candidates")

        # Score each candidate exactly once; strict '>' keeps the first of
        # equally scored candidates, matching max().
        best = None
        for cand in cands:
            score, imagined = self._score(self.state, cand)
            if best is None or score > best[0]:
                best = (score, cand, imagined)
        _, ctx, imag = best

        cmd = self._render(ctx)
        next_obs, r_ext, done, _ = self.env.step(cmd)
        last_output = next_obs.get("last_output", "")
        r_int = self.world.intrinsic(imag["pred_next"], last_output)
        total = r_ext + self.cfg.alpha * r_int + self.cfg.beta * imag.get("value", 0.0)
        self.belief.update(cmd, last_output)
        self.pi.credit(ctx, total)
        self.state = {**self.state, **next_obs}
        return cmd, next_obs

# -----------------------------
# Demo & Tests
# -----------------------------
if __name__ == "__main__":
    # Minimal schema.  ``targets`` is a list, so it goes through the join
    # filter; a bare ``{{ targets }}`` would render the list's Python repr
    # (``['127.0.0.1']``) into the command line.
    YAML_MIN = """
tool:
  name: nmap
  command_template: |
    nmap {{ targets | join(' ') }}{% if discovery %} {{ discovery | join(' ') }}{% endif %}{% if scans %} {{ scans | join(' ') }}{% endif %}{% if ports %} -p {{ ports }}{% endif %}{% if service_version %} {{ service_version | join(' ') }}{% endif %}{% if os_detect %} {{ os_detect | join(' ') }}{% endif %}{% if script %} {{ script | join(' ') }}{% endif %}{% if timing %} -T{{ timing }}{% endif %}
"""
    schema = ToolSchema.from_yaml_text(YAML_MIN)
    env = Env()
    agent = Agent(schema, env)

    # Demo: run two agent steps and show the command and raw output.
    for _ in range(2):
        cmd, obs = agent.step()
        print(cmd)
        print(obs.get("last_output"))

    # Tests (core behaviour only).  They share the env/agent created above.
    def test_env_multiline_ok():
        obs, r, d, _ = env.step("nmap 127.0.0.1 -sS")
        assert "EXECUTED: " in obs["last_output"] and "\n" in obs["last_output"]

    def test_step_runs_once():
        c, o = agent.step()
        assert c.startswith("nmap ") and "last_output" in o
        # Regression: targets must be rendered as plain hosts, not a list repr.
        assert "[" not in c

    def test_policy_credit_updates():
        before = dict(agent.pi.pref)
        agent.pi.credit({"scans": ["-sS"]}, 1.0)
        assert agent.pi.pref["-sS"] >= before["-sS"]

    def test_template_render_simple():
        t = Template()
        out = t.render("X {{ a }}{% if b %} {{ b | join(' ') }}{% endif %}", {"a": "1", "b": ["2", "3"]})
        assert out == "X 1 2 3"

    test_env_multiline_ok()
    test_step_runs_once()
    test_policy_credit_updates()
    test_template_render_simple()
    print("Core tests passed.")
