<a href="https://colab.research.google.com/github/mertserbes/wargame3d/blob/main/Wargame3.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install -q transformers accelerate einops


In [None]:
import json
import re
from typing import Any, Dict, Tuple, List

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# Official instruct model, larger than 4B parameters:
MODEL_NAME = "Qwen/Qwen2.5-14B-Instruct"
# Alternative:
# MODEL_NAME = "Qwen/Qwen3-14B"

# Load the tokenizer and the causal LM once at module level; every helper
# below reads these two globals.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
# NOTE(review): the captured run output for this cell warns that
# `torch_dtype` is deprecated in favour of `dtype` -- consider switching once
# the minimum supported transformers version is confirmed.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    trust_remote_code=True,
    device_map="auto",
    torch_dtype="auto",
)

# ---------------------------
# Helpers: robust JSON parsing
# ---------------------------

def extract_json(text: str) -> str:
    """
    Pull the JSON object out of raw model output.

    The model sometimes adds characters before or after the JSON, so this
    returns the most plausible ``{ ... }`` candidate:

    1. If a complete, valid JSON object starts at the first ``{``, exactly
       that object is returned (trailing chatter is dropped, even when it
       contains braces of its own).
    2. Otherwise the span from the first ``{`` to the last ``}`` is returned
       so that a later repair step still sees the whole broken payload.
    3. If there is no usable brace pair, the stripped text is returned
       unchanged.

    Args:
        text: Raw text produced by the model.

    Returns:
        The candidate JSON string (not guaranteed to be valid JSON).
    """
    text = text.strip()

    start = text.find("{")
    if start == -1:
        return text

    # Preferred path: let the JSON decoder tell us where the object that
    # begins at the first '{' really ends.
    try:
        _, stop = json.JSONDecoder().raw_decode(text, start)
    except (ValueError, RecursionError):
        pass
    else:
        return text[start:stop]

    # Fallback: outermost brace span. Requiring end > start avoids returning
    # an empty string when the only '}' appears before the first '{'.
    end = text.rfind("}")
    if end > start:
        return text[start:end + 1].strip()

    return text

def try_parse_json(text: str) -> Tuple[bool, Any, str]:
    """
    Attempt to decode the JSON candidate found in ``text``.

    Returns:
        A ``(ok, value, candidate)`` triple. ``candidate`` is whatever
        ``extract_json`` returned for ``text``; ``value`` is the decoded
        object when ``ok`` is True and ``None`` otherwise.
    """
    snippet = extract_json(text)
    try:
        parsed = json.loads(snippet)
    except Exception:
        # Any decoding failure is reported through the flag, never raised.
        return False, None, snippet
    return True, parsed, snippet

def llm_fix_json(bad_json_text: str, max_new_tokens: int = 1024) -> str:
    """
    Ask the LLM to repair malformed JSON, and nothing else.

    Args:
        bad_json_text: The broken JSON candidate to repair.
        max_new_tokens: Generation budget for the repaired copy. It should be
            at least as large as the budget used to produce the original
            text, otherwise the repaired JSON can be cut off mid-object.

    Returns:
        The decoded model reply, stripped of surrounding whitespace. The
        reply is not validated here; callers must parse it themselves.
    """
    fix_messages = [
        {
            "role": "system",
            "content": (
                "You are a strict JSON repair tool. "
                "Return ONLY valid JSON. No explanations."
            )
        },
        {
            "role": "user",
            "content": (
                "Fix this so it becomes valid JSON. "
                "Do not change meaning, only repair formatting.\n\n"
                f"{bad_json_text}"
            )
        }
    ]
    # return_dict=True yields input_ids together with the attention mask, so
    # generate() does not have to guess which positions are padding.
    inputs = tokenizer.apply_chat_template(
        fix_messages,
        add_generation_prompt=True,
        return_tensors="pt",
        return_dict=True,
    ).to(model.device)
    prompt_len = inputs["input_ids"].shape[-1]

    with torch.no_grad():
        # Greedy decoding is requested via do_sample=False. Passing
        # temperature=0.0 instead is rejected by transformers whenever
        # sampling is enabled in the model's default generation config.
        out = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=False,
        )

    # Drop the prompt tokens; keep only the newly generated continuation.
    gen = out[0][prompt_len:]
    return tokenizer.decode(gen, skip_special_tokens=True).strip()

# ---------------------------
# State -> features (heuristic)
# ---------------------------

def parse_state(state: Dict[str, Any]) -> Dict[str, Any]:
    """
    Split a raw simulation state into per-team entity lists and bounding boxes.

    Field keys read from each entity record (meanings as described by the
    state format's author -- TODO confirm against the simulator):
    - "103" -> team_id
    - "6"   -> lat
    - "7"   -> lon
    - "0"   -> altitude_ft (for aircraft)
    - "3"   -> heading_deg (for aircraft)
    - "98"  -> something like an active/alive flag

    Top-level items are skipped when the key is "scene_step", when the value
    is not a dict, or when the dict has no "103" entry.

    Args:
        state: Mapping of entity id -> entity record, plus an optional
            "scene_step" entry.

    Returns:
        A dict with "scene_step", "team1_entities", "team2_entities",
        "team1_bbox" and "team2_bbox". A bbox is ``None`` when the team has
        no usable latitude or no usable longitude.
    """
    import math  # local import: math is not among this file's top-level imports

    def is_coordinate(value: Any) -> bool:
        # Some state records carry NaN (as a float or as a string). Only real,
        # finite numbers may reach min()/max(): a NaN makes their result
        # depend on element order. bool is excluded because it subclasses int.
        if isinstance(value, bool):
            return False
        if isinstance(value, int):
            return True
        return isinstance(value, float) and math.isfinite(value)

    entities = []
    for k, v in state.items():
        if k == "scene_step" or not isinstance(v, dict) or "103" not in v:
            continue
        entities.append({
            "id": str(k),
            "team_id": v.get("103"),
            "lat": v.get("6"),
            "lon": v.get("7"),
            "alt_ft": v.get("0"),
            "heading_deg": v.get("3"),
            "active_flag_98": v.get("98"),
            # keep the untouched record for callers that need other fields
            "raw": v,
        })

    team1 = [e for e in entities if e["team_id"] == 1]
    team2 = [e for e in entities if e["team_id"] == 2]

    # Simple spread measure for a team: the bounding box of its positions.
    def bbox(ents):
        lats = [e["lat"] for e in ents if is_coordinate(e["lat"])]
        lons = [e["lon"] for e in ents if is_coordinate(e["lon"])]
        if not lats or not lons:
            return None
        return {
            "lat_min": min(lats), "lat_max": max(lats),
            "lon_min": min(lons), "lon_max": max(lons),
        }

    return {
        "scene_step": state.get("scene_step"),
        "team1_entities": team1,
        "team2_entities": team2,
        "team1_bbox": bbox(team1),
        "team2_bbox": bbox(team2),
    }

def compact_team_summary(features: Dict[str, Any]) -> Dict[str, Any]:
    """
    Condense parsed features so the LLM prompt does not grow too large.

    Keeps the scene step, per-team entity counts and bounding boxes, and a
    sample of at most 12 entities per team reduced to a handful of fields.
    """
    team1_entities = features["team1_entities"]
    team2_entities = features["team2_entities"]

    # (key in the summary, key in the parsed entity), in output order.
    field_map = (
        ("id", "id"),
        ("lat", "lat"),
        ("lon", "lon"),
        ("alt_ft", "alt_ft"),
        ("heading_deg", "heading_deg"),
        ("active", "active_flag_98"),
    )

    def sample(ents, limit=12):
        return [
            {target: ent[source] for target, source in field_map}
            for ent in ents[:limit]
        ]

    summary = {"scene_step": features["scene_step"]}
    summary["team1_count"] = len(team1_entities)
    summary["team2_count"] = len(team2_entities)
    summary["team1_bbox"] = features["team1_bbox"]
    summary["team2_bbox"] = features["team2_bbox"]
    summary["team1_entities_sample"] = sample(team1_entities)
    summary["team2_entities_sample"] = sample(team2_entities)
    return summary

# ---------------------------
# LLM planner
# ---------------------------

# System prompt for the planner. It is sent verbatim (after .strip()) as the
# "system" message in generate_with_llm, so every character below is runtime
# text: the schema is an illustration for the model, not strict JSON.
PLANNER_SYSTEM = """
You are planning for a fictional simulation called Wargame3D (NOT real-world).
You will receive a compact snapshot of Team 1 (blue) and Team 2 (red).
Goal: Create a RED TEAM (team_id=2) strategy plan against BLUE (team_id=1).

CRITICAL OUTPUT RULES:
- Return ONLY valid JSON (no markdown, no extra text).
- Include BOTH: (a) strategy/actions and (b) reasons/justifications INSIDE JSON.
- Assign roles to EACH Team 2 entity (role must be one of: leader, strike, flank, decoy, support, scout).
- Provide an action plan per role (high-level simulation actions; do not mention real-world instructions).
- All Team 2 aircraft must have: "mode": "mert_demo".
- Be consistent, coherent, and make the plan depend on Team 1 distribution (bbox / density / positions).

JSON SCHEMA (must follow):
{
  "name": string,
  "desc": string,
  "jsbsim_dt_hz": 50,
  "entity_sim_freq": 5,
  "entities": { ... }  // can be minimal, but must include Team 2 entries with mode=mert_demo
  "red_plan": {
    "intent": string,
    "roles": {
      "<team2_entity_id>": {
        "role": "leader|strike|flank|decoy|support|scout",
        "objectives": [string, ...],
        "actions": [
          {"t": int, "action": string, "params": object, "why": string},
          ...
        ],
        "risk_controls": [string, ...]
      },
      ...
    },
    "team_level_rationale": [string, ...]
  }
}

Constraints:
- Actions should be time-indexed with small integers like t=0,1,2... (simulation steps, not seconds).
- Use concise "why" explanations.
"""

def generate_with_llm(prompt: str, max_new_tokens=900, temperature=0.2, top_p=0.9) -> str:
    """
    Run the planner model on ``prompt`` and return its decoded reply.

    Args:
        prompt: User message; PLANNER_SYSTEM is prepended as the system message.
        max_new_tokens: Upper bound on the number of generated tokens.
        temperature: Sampling temperature. A value of ``None`` or ``<= 0``
            selects deterministic greedy decoding instead of sampling.
        top_p: Nucleus-sampling threshold; only used when sampling.

    Returns:
        The generated continuation (prompt tokens removed), stripped.
    """
    messages = [
        {"role": "system", "content": PLANNER_SYSTEM.strip()},
        {"role": "user", "content": prompt},
    ]
    # return_dict=True yields input_ids together with the attention mask, so
    # generate() does not have to guess which positions are padding.
    inputs = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt",
        return_dict=True,
    ).to(model.device)
    prompt_len = inputs["input_ids"].shape[-1]

    # Be explicit about sampling: temperature/top_p are only honoured when
    # do_sample is on, and a non-positive temperature is invalid for sampling.
    if temperature is not None and temperature > 0:
        sampling_kwargs = {
            "do_sample": True,
            "temperature": temperature,
            "top_p": top_p,
        }
    else:
        sampling_kwargs = {"do_sample": False}

    with torch.no_grad():
        output_ids = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            **sampling_kwargs,
        )

    # Drop the prompt tokens; keep only the newly generated continuation.
    generated_ids = output_ids[0][prompt_len:]
    text = tokenizer.decode(generated_ids, skip_special_tokens=True)
    return text.strip()

def build_planner_prompt(state: Dict[str, Any]) -> str:
    """
    Serialize the planner's user message for a raw simulation state.

    The message is a JSON string holding the compact team summary, the list
    of Team 2 entity ids (each of which must be assigned a role), and a few
    fixed guidance notes.
    """
    features = parse_state(state)

    payload = {
        "input_summary": compact_team_summary(features),
        # every one of these ids needs a role in the model's answer
        "required_team2_entity_ids": [
            ent["id"] for ent in features["team2_entities"]
        ],
        "notes": [
            "Team 2 is the player side (red).",
            "Make a sensible plan based on Team 1 positions and density.",
            "Keep it simulation-oriented and high-level."
        ],
    }
    return json.dumps(payload, ensure_ascii=False)

def plan_red_strategy_from_state(state: Dict[str, Any], retries: int = 3) -> Dict[str, Any]:
    """
    Produce a red-team plan for ``state`` as a parsed JSON object.

    Each attempt generates a plan, tries to parse it, and on failure asks the
    model to repair the JSON once. The first attempt uses the plain prompt;
    later attempts append a stricter "ONLY JSON" reminder and lower the
    temperature to 0.1.

    Args:
        state: Raw simulation state, as accepted by ``parse_state``.
        retries: Total number of generation attempts. Values below 1 still
            perform one attempt.

    Returns:
        The decoded plan. Only JSON objects (dicts) are accepted; any other
        valid JSON value (list, number, null, ...) counts as a failed attempt.

    Raises:
        RuntimeError: If no attempt yields a parseable JSON object.
    """
    prompt = build_planner_prompt(state)
    strict_prompt = (
        prompt + "\n\nREMINDER: Output MUST be ONLY valid JSON. No extra text."
    )

    for attempt in range(max(1, retries)):
        if attempt == 0:
            raw = generate_with_llm(prompt)
        else:
            # Retry with a firmer reminder and less randomness.
            raw = generate_with_llm(strict_prompt, temperature=0.1)

        ok, obj, candidate = try_parse_json(raw)
        if ok and isinstance(obj, dict):
            return obj

        # Parsing failed: give the model one chance to repair the candidate.
        repaired = llm_fix_json(candidate)
        ok, obj, _ = try_parse_json(repaired)
        if ok and isinstance(obj, dict):
            return obj

    raise RuntimeError("LLM output could not be parsed as valid JSON after retries.")

# ---------------------------
# Example usage
# ---------------------------

# 1) Put the state JSON here (the first JSON from your message).
#    With STATE left empty, the planner is called with no entities at all.
STATE = {
  # ... paste the state dict here ...
}

# Runs the full generate / parse / repair pipeline; raises RuntimeError if no
# valid JSON plan could be obtained.
final_scenario = plan_red_strategy_from_state(STATE)

# Save to file
with open("final_scenario.json", "w", encoding="utf-8") as f:
    json.dump(final_scenario, f, ensure_ascii=False, indent=2)

print("Saved -> final_scenario.json")
# Preview only the first 2000 characters of the pretty-printed plan.
print(json.dumps(final_scenario, ensure_ascii=False, indent=2)[:2000])


tokenizer_config.json: 0.00B [00:00, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

config.json:   0%|          | 0.00/663 [00:00<?, ?B/s]

`torch_dtype` is deprecated! Use `dtype` instead!


model.safetensors.index.json: 0.00B [00:00, ?B/s]

Fetching 8 files:   0%|          | 0/8 [00:00<?, ?it/s]

model-00001-of-00008.safetensors:   0%|          | 0.00/3.89G [00:00<?, ?B/s]

model-00007-of-00008.safetensors:   0%|          | 0.00/4.00G [00:00<?, ?B/s]

model-00008-of-00008.safetensors:   0%|          | 0.00/1.70G [00:00<?, ?B/s]

model-00004-of-00008.safetensors:   0%|          | 0.00/4.00G [00:00<?, ?B/s]

model-00006-of-00008.safetensors:   0%|          | 0.00/4.00G [00:00<?, ?B/s]

model-00005-of-00008.safetensors:   0%|          | 0.00/3.98G [00:00<?, ?B/s]

model-00003-of-00008.safetensors:   0%|          | 0.00/4.00G [00:00<?, ?B/s]

model-00002-of-00008.safetensors:   0%|          | 0.00/4.00G [00:00<?, ?B/s]