In [None]:
# --- Imports ---
import os
import re
import shlex
import subprocess
from typing import List, Optional, Dict, Any, TypedDict

from dotenv import load_dotenv
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI
from langgraph.graph import StateGraph, END

# Load variables from a local .env file into the process environment
# (presumably the OpenAI API key that ChatOpenAI reads below — confirm against your .env).
load_dotenv()

In [None]:
# --- Step 1: Define State ---
class PortScanState(TypedDict):
    """Shared state passed between the nodes of the port-scan workflow.

    Each node returns a partial dict holding only the keys it updates.
    """

    target_ip: str  # host to scan; validation requires it to appear in the generated command
    scan_intent: str  # free-text description of the desired scan, sent to the LLM
    llm_command_raw: Optional[str]  # fast-scan command text as returned by the LLM
    validated_command: Optional[str]  # command accepted by validate_command, ready to execute
    scan_output: Optional[str]  # stdout of the fast scan
    open_ports: List[int]  # TCP port numbers parsed as open from scan_output
    web_ports: List[int]  # ports considered web-related for the detailed scan
    detailed_commands: List[str]  # follow-up nmap commands suggested by the LLM
    detailed_outputs: List[str]  # one output (or "Error: ..." string) per detailed command
    error_message: Optional[str]  # error text from a failing node; several nodes reset it to None on success

In [None]:
# --- Step 2: LLM and Prompt Setup ---
# temperature=0.0 keeps the generated commands as repeatable as possible.
llm = ChatOpenAI(model="gpt-4o-mini", temperature=0.0)

# Prompt for the initial fast scan. The reply is validated before it is executed,
# so ask explicitly for a bare single-line command (models often add Markdown fences).
FAST_SCAN_PROMPT = ChatPromptTemplate.from_messages([
    ("system",
     "You are a penetration tester and Nmap expert. Based on the user's scan intent and IP, generate a fast Nmap command. "
     "Use options like -T4 -Pn for speed. Do NOT include anything except the raw command: "
     "output a single line with no Markdown, code fences, backticks, or explanation. "
     "It must begin with 'nmap' and target the given IP. "
     "Do not use shell operators, sudo, or output-file options (-oN, -oX, -oG, -oA)."),
    ("human", "Target: {target_ip}\nIntent: {scan_intent}")
])

# Prompt for follow-up scans. Only the scan output is supplied as a variable, so the
# model must take the target host from the scan output itself. One command per line
# keeps the reply easy to split into individual nmap invocations.
DETAILED_SCAN_PROMPT = ChatPromptTemplate.from_messages([
    ("system",
     "You are a penetration tester. Based on the initial Nmap scan output below, "
     "generate Nmap commands for deeper analysis of any web-related services (e.g. HTTP/S, proxy, admin panels). "
     "Use options like -sCV or --script=http-* where appropriate. "
     "Every command must target the same host that appears in the scan output. "
     "Output one command per line, each starting with 'nmap', with no numbering, Markdown, code fences, or explanation. "
     "Do not use shell operators, sudo, or output-file options (-oN, -oX, -oG, -oA). "
     "If no web-related services are open, output nothing."),
    ("human", "Scan output:\n{scan_output}")
])

# prompt -> model -> plain string reply
fast_scan_chain = FAST_SCAN_PROMPT | llm | StrOutputParser()
detailed_scan_chain = DETAILED_SCAN_PROMPT | llm | StrOutputParser()

In [None]:
# --- Step 3: Node Definitions ---
def generate_scan_command(state: PortScanState) -> PortScanState:
    """Ask the LLM for a fast Nmap command tailored to the target and intent.

    Returns a partial state update with the model's whitespace-stripped reply
    under ``llm_command_raw`` and ``error_message`` cleared.
    """
    print("🔧 Generating fast scan command...")
    prompt_vars = {key: state[key] for key in ("target_ip", "scan_intent")}
    reply = fast_scan_chain.invoke(prompt_vars)
    return {"llm_command_raw": reply.strip(), "error_message": None}

In [None]:
def validate_command(state: PortScanState) -> PortScanState:
    """Check that an LLM-generated command is a single, safe Nmap invocation.

    Markdown code fences or surrounding backticks, which models commonly add,
    are removed first. The command is accepted only when it:

    * has ``nmap`` as its first word,
    * contains no shell metacharacters (``; & | < > ` $`` or line breaks) and
      no ``rm``/``sudo`` arguments,
    * uses none of the options that pull in other targets or touch files
      (``-iL``, ``-iR``, ``-o*``, ``--resume``, ``--datadir``),
    * names ``state["target_ip"]`` as a standalone argument (exact match, so
      ``10.0.0.12`` does not satisfy a target of ``10.0.0.1``).

    Args:
        state: Graph state; reads ``llm_command_raw`` and ``target_ip``.

    Returns:
        ``{"validated_command": <normalized command>, "error_message": None}``
        on success, otherwise ``{"error_message": <reason>}``.
    """
    print("🔒 Validating generated command...")
    cmd = (state.get("llm_command_raw") or "").strip()

    # Unwrap a ```lang\n...\n``` fence or `inline` backticks around the reply.
    fence = re.fullmatch(r"```[\w+-]*[ \t]*\r?\n(.*?)```", cmd, flags=re.DOTALL)
    if fence:
        cmd = fence.group(1).strip()
    elif len(cmd) > 1 and cmd[0] == "`" and cmd[-1] == "`":
        cmd = cmd.strip("`").strip()

    if cmd.split(maxsplit=1)[:1] != ["nmap"]:
        return {"error_message": "Invalid command (does not start with nmap)"}
    # The command may be handed to a shell, so reject anything that can chain,
    # background, redirect or substitute commands: a bare "&" or a newline is
    # as dangerous as "&&" or ";".
    if any(ch in cmd for ch in ";&|<>`$\n\r"):
        return {"error_message": "Command contains forbidden characters"}
    try:
        tokens = shlex.split(cmd)
    except ValueError as exc:  # e.g. unbalanced quotes
        return {"error_message": f"Command could not be parsed: {exc}"}
    # Compare whole arguments rather than substrings so that scripts such as
    # "http-form-brute" (which contains "rm") are not rejected.
    if any(tok in ("rm", "sudo") for tok in tokens):
        return {"error_message": "Command contains forbidden characters"}
    # -iL/-iR read targets from a file / pick random hosts; -o* and friends
    # read or write arbitrary paths.
    blocked_prefixes = ("-iL", "-iR", "-o", "--resume", "--datadir")
    blocked = next((tok for tok in tokens[1:] if tok.startswith(blocked_prefixes)), None)
    if blocked is not None:
        return {"error_message": f"Command uses forbidden nmap option: {blocked}"}
    if state["target_ip"] not in tokens:
        return {"error_message": "Target IP missing in command"}
    return {"validated_command": cmd, "error_message": None}

In [None]:
def execute_scan(state: PortScanState, *, timeout: float = 60) -> PortScanState:
    """Run the validated fast-scan command and capture its standard output.

    The command string is tokenized with ``shlex.split`` and executed without a
    shell, so shell syntax that slipped through validation (``;``, ``&``,
    redirections, globs) is passed to the program as literal arguments instead
    of being interpreted.

    Args:
        state: Graph state; reads ``validated_command`` and ``error_message``.
        timeout: Seconds to wait before the scan is aborted (keyword-only).

    Returns:
        ``{"scan_output": <stdout>, "error_message": None}`` on success,
        otherwise ``{"error_message": <reason>}``. When there is no validated
        command (e.g. validation failed) nothing is executed and the existing
        error, or a generic one, is returned.
    """
    command = state.get("validated_command")
    if not command:
        return {"error_message": state.get("error_message") or "No validated command to execute"}
    print(f"🚀 Executing: {command}")
    try:
        completed = subprocess.run(
            shlex.split(command),
            capture_output=True,
            text=True,
            timeout=timeout,
            check=True,
        )
    except subprocess.CalledProcessError as e:
        # Non-zero exit: include stderr, which usually explains the failure.
        details = f"\n{e.stderr.strip()}" if e.stderr else ""
        return {"error_message": f"Scan failed: {e}{details}"}
    except (subprocess.SubprocessError, OSError, ValueError) as e:
        # Timeout, missing/non-executable binary, or unparsable quoting.
        return {"error_message": f"Scan failed: {e}"}
    return {"scan_output": completed.stdout, "error_message": None}

In [None]:
# A port-table line such as "80/tcp   open  http". The state must be exactly
# "open": "open|filtered" means Nmap could not confirm the port is open.
_OPEN_TCP_PORT_RE = re.compile(r"(\d+)/tcp\s+open(?:\s|$)", re.ASCII)


def parse_open_ports(state: PortScanState) -> PortScanState:
    """Extract TCP ports reported as ``open`` from the fast-scan output.

    Only lines that start with ``<port>/tcp`` followed by the state ``open``
    are counted; other states and non-TCP protocols are ignored.

    Args:
        state: Graph state; reads ``scan_output`` (missing or ``None`` is
            treated as empty output).

    Returns:
        ``{"open_ports": <sorted, de-duplicated port numbers>}``. Any
        ``error_message`` already in the state is deliberately left untouched
        so that an earlier scan failure is not hidden.
    """
    print("🧠 Parsing open ports...")
    output = state.get("scan_output") or ""
    open_ports = {
        int(match.group(1))
        for match in (_OPEN_TCP_PORT_RE.match(line) for line in output.splitlines())
        if match
    }
    return {"open_ports": sorted(open_ports)}

In [None]:
# Port list given with -p, e.g. "-p 80,443" or "-p8000-8100" (case-sensitive, so -Pn is not matched).
_PORT_SPEC_RE = re.compile(r"(?:^|\s)-p\s*([\d,-]+)", re.ASCII)
# A reply line holding a command, optionally prefixed by a list marker and/or
# wrapped in backticks, e.g. "1. `nmap -sCV -p 80 host`".
_COMMAND_LINE_RE = re.compile(r"^(?:[-*]|\d+[.)])?\s*`*(nmap\s.*?)`*$")


def _ports_from_commands(commands: List[str], open_ports: List[int]) -> List[int]:
    """Return the sorted ports explicitly targeted via ``-p`` in ``commands``.

    Ranges (``a-b``) are resolved against ``open_ports``. When ``open_ports``
    is non-empty the result is restricted to ports that were found open.
    """
    open_set = set(open_ports)
    found = set()
    for cmd in commands:
        for spec in _PORT_SPEC_RE.findall(cmd):
            for chunk in spec.split(","):
                if chunk.isdigit():
                    found.add(int(chunk))
                elif re.fullmatch(r"\d+-\d+", chunk):
                    low, high = (int(part) for part in chunk.split("-"))
                    found.update(p for p in open_set if low <= p <= high)
    return sorted(found & open_set if open_set else found)


def identify_web_ports_via_llm(state: PortScanState) -> PortScanState:
    """Ask the LLM for follow-up Nmap commands targeting web-related services.

    Args:
        state: Graph state; reads ``scan_output`` and ``open_ports``.

    Returns:
        ``detailed_commands`` (reply lines that contain an ``nmap`` command,
        with list markers/backticks removed), ``web_ports`` (open ports those
        commands target explicitly via ``-p``) and ``error_message=None``.
        If there is no scan output, the LLM is not called and both lists are
        empty; an existing ``error_message`` is kept. If the LLM call fails,
        ``{"error_message": "Web port analysis failed: ..."}`` is returned.
    """
    print("🤖 Using LLM to identify web-related ports...")
    scan_output = state.get("scan_output")
    if not scan_output:
        # Nothing was scanned (e.g. the fast scan failed); don't let the model
        # invent follow-up commands.
        return {"detailed_commands": [], "web_ports": []}
    try:
        result = detailed_scan_chain.invoke({"scan_output": scan_output})
    except Exception as e:  # network/auth/model errors vary by provider
        return {"error_message": f"Web port analysis failed: {e}"}

    commands = []
    for line in result.splitlines():
        match = _COMMAND_LINE_RE.match(line.strip())
        if match:
            commands.append(match.group(1).strip())
    return {
        "detailed_commands": commands,
        "web_ports": _ports_from_commands(commands, state.get("open_ports") or []),
        "error_message": None,
    }

In [None]:
def execute_detailed_scans(state: PortScanState, *, timeout: float = 90) -> PortScanState:
    """Validate and run each LLM-suggested follow-up Nmap command.

    The commands come straight from the model, so each one is first checked
    with ``validate_command`` (the same checks applied to the fast scan,
    against ``state["target_ip"]``). It is then executed without a shell via
    ``shlex.split``. A rejected or failed command contributes an
    ``"Error: ..."`` entry instead of output, so ``detailed_outputs`` stays
    index-aligned with ``detailed_commands``.

    Args:
        state: Graph state; reads ``detailed_commands`` and ``target_ip``.
        timeout: Per-command timeout in seconds (keyword-only).

    Returns:
        ``{"detailed_outputs": [...]}``.
    """
    print("🔍 Running detailed scans...")
    results = []
    for cmd in state.get("detailed_commands") or []:
        print(f"→ {cmd}")
        verdict = validate_command({"llm_command_raw": cmd, "target_ip": state["target_ip"]})
        if verdict.get("error_message"):
            results.append(f"Error: {verdict['error_message']}")
            continue
        try:
            completed = subprocess.run(
                shlex.split(verdict["validated_command"]),
                capture_output=True,
                text=True,
                timeout=timeout,
                check=True,
            )
            results.append(completed.stdout)
        except (subprocess.SubprocessError, OSError, ValueError) as e:
            results.append(f"Error: {e}")
    return {"detailed_outputs": results}

In [None]:
# --- Step 4: Build LangGraph Workflow ---
def _stop_on_error(next_node: str):
    """Build a router that ends the run once a node has set ``error_message``."""
    def route(state: PortScanState) -> str:
        return END if state.get("error_message") else next_node
    return route


workflow = StateGraph(PortScanState)

workflow.add_node("generate_scan_command", generate_scan_command)
workflow.add_node("validate_command", validate_command)
workflow.add_node("execute_scan", execute_scan)
workflow.add_node("parse_open_ports", parse_open_ports)
workflow.add_node("identify_web_ports", identify_web_ports_via_llm)
workflow.add_node("execute_detailed_scans", execute_detailed_scans)

workflow.set_entry_point("generate_scan_command")

workflow.add_edge("generate_scan_command", "validate_command")
# Guard the steps that can fail: a rejected command must never reach
# execute_scan, and a failed scan or LLM call should end the run instead of
# feeding missing state into later nodes.
workflow.add_conditional_edges(
    "validate_command", _stop_on_error("execute_scan"), ["execute_scan", END]
)
workflow.add_conditional_edges(
    "execute_scan", _stop_on_error("parse_open_ports"), ["parse_open_ports", END]
)
workflow.add_edge("parse_open_ports", "identify_web_ports")
workflow.add_conditional_edges(
    "identify_web_ports", _stop_on_error("execute_detailed_scans"), ["execute_detailed_scans", END]
)
workflow.add_edge("execute_detailed_scans", END)

app = workflow.compile()

In [None]:
# --- Step 5: Run Example ---
if __name__ == "__main__":
    # Provide every PortScanState key up front so nodes never hit a missing key.
    test_input: PortScanState = {
        "target_ip": "127.0.0.1",
        "scan_intent": "Perform fast scan to find open ports, and perform detailed scan on identified web ports.",
        "llm_command_raw": None,
        "validated_command": None,
        "scan_output": None,
        "open_ports": [],
        "web_ports": [],
        "detailed_commands": [],
        "detailed_outputs": [],
        "error_message": None
    }

    result = app.invoke(test_input)
    print("\n🧾 Final Results:")
    print(f"Target: {result['target_ip']}")
    print("Open Ports:", result.get("open_ports", []))
    print("Web Ports:", result.get("web_ports", []))
    print("Detailed Commands:", result.get("detailed_commands", []))
    print("Detailed Scan Outputs:", result.get("detailed_outputs", []))
    print("Errors:", result.get("error_message"))