Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion ares-cli/src/orchestrator/automation/gpo.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
//! auto_gpo_abuse -- exploit GPO write access for code execution.
//!
//! When a controlled user has write access to a Group Policy Object
//! (e.g., samwell.tarly has write on a GPO linked to contoso.local),
//! (e.g., a user has write on a GPO linked to contoso.local),
//! this automation dispatches `pyGPOAbuse` to inject a scheduled task that
//! runs as SYSTEM on all hosts where the GPO applies.
//!
Expand Down
2 changes: 1 addition & 1 deletion ares-cli/src/orchestrator/automation/rbcd.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
//! auto_rbcd_exploitation -- exploit GenericAll/GenericWrite on computer objects via RBCD.
//!
//! When a controlled user has GenericAll or GenericWrite on a computer object
//! (e.g., stanniskingslanding$), this automation dispatches the full RBCD
//! (e.g., userDC$), this automation dispatches the full RBCD
//! chain: addcomputer → rbcd_write → S4U → secretsdump.
//!
//! This is separate from s4u.rs which handles pre-existing delegation vulns.
Expand Down
4 changes: 2 additions & 2 deletions ares-cli/src/orchestrator/automation/unconstrained.rs
Original file line number Diff line number Diff line change
Expand Up @@ -111,8 +111,8 @@ pub async fn auto_unconstrained_exploitation(
}

// Machine accounts: resolve hostname → IP for coerce+dump chain.
// User accounts (sansa.stark): dispatch LLM exploit task since we
// can't determine which host to coerce from just the account name.
// User accounts: dispatch LLM exploit task since we can't determine
// which host to coerce from just the account name.
let is_machine = account_name.ends_with('$');

// Find a DC in the same domain — this is what we coerce FROM.
Expand Down
32 changes: 31 additions & 1 deletion ares-cli/src/orchestrator/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -496,9 +496,39 @@ async fn run_inner() -> Result<()> {
info!(
requeued = recovered.requeued_task_ids.len(),
failed = recovered.failed_task_ids.len(),
"Recovery: re-enqueued interrupted tasks"
"Recovery: re-dispatching interrupted tasks via LLM submission"
);
}
for task in recovered.tasks_to_redispatch {
match dispatcher
.do_submit(&task.task_type, &task.target_role, task.payload, 1)
.await
{
Ok(Some(tid)) => {
info!(
task_id = %tid,
task_type = %task.task_type,
role = %task.target_role,
retry = task.retry_count,
"Recovery: re-dispatched task via LLM runner"
);
}
Ok(None) => {
warn!(
task_type = %task.task_type,
role = %task.target_role,
"Recovery: task deferred or dropped during re-dispatch"
);
}
Err(e) => {
warn!(
task_type = %task.task_type,
err = %e,
"Recovery: failed to re-dispatch task"
);
}
}
}
}
Err(e) => {
// Recovery failure is non-fatal — we already loaded state above
Expand Down
41 changes: 22 additions & 19 deletions ares-cli/src/orchestrator/recovery/manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ use crate::orchestrator::task_queue::TaskQueue;

use super::dedup::dedupe_hashes;
use super::normalize::{normalize_credential_domains, normalize_hash_domains};
use super::requeue::requeue_task;
use super::types::{
is_connection_error, RecoveredState, INTERRUPTED_STATUSES, MAX_CONNECTION_RETRIES, MAX_RETRIES,
};
Expand Down Expand Up @@ -174,6 +173,7 @@ impl OperationRecoveryManager {

let mut requeued_task_ids = Vec::new();
let mut failed_task_ids = Vec::new();
let mut tasks_to_redispatch = Vec::new();

for (task_id, task) in &mut pending_tasks {
if !INTERRUPTED_STATUSES.contains(&task.status) {
Expand All @@ -198,24 +198,26 @@ impl OperationRecoveryManager {
task.error = Some("Requeued after pod restart (task was pending)".to_string());
}

match requeue_task(queue, task_id, task).await {
Ok(()) => {
requeued_task_ids.push(task_id.clone());
info!(
task_id = %task_id,
retry_count = task.retry_count,
max_retries = max_retries,
"Task requeued for recovery"
);
}
Err(e) => {
warn!(
task_id = %task_id,
err = %e,
"Failed to requeue task"
);
}
}
let payload = task
.params
.iter()
.map(|(k, v)| (k.clone(), v.clone()))
.collect::<serde_json::Map<String, serde_json::Value>>();

tasks_to_redispatch.push(super::types::RecoveryTask {
task_type: task.task_type.clone(),
target_role: task.assigned_agent.clone(),
payload: serde_json::Value::Object(payload),
retry_count: task.retry_count,
});

requeued_task_ids.push(task_id.clone());
info!(
task_id = %task_id,
retry_count = task.retry_count,
max_retries = max_retries,
"Task collected for re-dispatch via LLM submission"
);
} else {
// Exceeded max retries
task.status = TaskStatus::Failed;
Expand Down Expand Up @@ -249,6 +251,7 @@ impl OperationRecoveryManager {

Ok(RecoveredState {
state: loaded_state,
tasks_to_redispatch,
requeued_task_ids,
failed_task_ids,
})
Expand Down
1 change: 0 additions & 1 deletion ares-cli/src/orchestrator/recovery/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
mod dedup;
mod manager;
mod normalize;
mod requeue;
mod types;

pub use manager::OperationRecoveryManager;
Expand Down
59 changes: 0 additions & 59 deletions ares-cli/src/orchestrator/recovery/requeue.rs

This file was deleted.

33 changes: 32 additions & 1 deletion ares-cli/src/orchestrator/recovery/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,25 @@ pub fn is_connection_error(err: &anyhow::Error) -> bool {
CONNECTION_ERROR_KEYWORDS.iter().any(|kw| msg.contains(kw))
}

/// A task that needs to be re-dispatched through the normal LLM submission
/// flow after recovery.
#[derive(Debug, Clone)]
pub struct RecoveryTask {
pub task_type: String,
pub target_role: String,
pub payload: serde_json::Value,
pub retry_count: i32,
}

/// Result of a recovery operation.
#[derive(Debug)]
pub struct RecoveredState {
/// The full shared state loaded from Redis.
#[allow(dead_code)]
pub state: SharedRedTeamState,
/// Task IDs that were re-enqueued for retry.
/// Tasks that need re-dispatch through the normal submission flow.
pub tasks_to_redispatch: Vec<RecoveryTask>,
/// Task IDs that were prepared for re-dispatch.
pub requeued_task_ids: Vec<String>,
/// Task IDs that exceeded max retries and were marked failed.
pub failed_task_ids: Vec<String>,
Expand Down Expand Up @@ -102,4 +114,23 @@ mod tests {
assert_eq!(MAX_CONNECTION_RETRIES, 3);
assert_eq!(INTERRUPTED_STATUSES.len(), 3);
}

#[test]
fn recovery_task_carries_payload_for_redispatch() {
let task = RecoveryTask {
task_type: "credential_access".to_string(),
target_role: "credential_access".to_string(),
payload: serde_json::json!({"target": "192.168.58.1"}),
retry_count: 2,
};
assert_eq!(task.task_type, "credential_access");
assert_eq!(task.target_role, "credential_access");
assert_eq!(task.payload["target"], "192.168.58.1");
assert_eq!(task.retry_count, 2);

let cloned = task.clone();
assert_eq!(cloned.task_type, task.task_type);
let dbg = format!("{task:?}");
assert!(dbg.contains("credential_access"));
}
}
80 changes: 74 additions & 6 deletions ares-cli/src/orchestrator/state/publishing/hosts.rs
Original file line number Diff line number Diff line change
Expand Up @@ -252,13 +252,15 @@ impl SharedState {
queue: &TaskQueueCore<impl ConnectionLike + Clone + Send + Sync + 'static>,
host: &Host,
) -> Result<()> {
// Extract domain from hostname — prefer a real FQDN
// Require at least 3 dot-separated parts (e.g. dc03.contoso.local)
// so 2-part hostnames like "HOSTNAME.local" don't yield "local" as the domain.
let raw_domain = if !host.hostname.is_empty() {
host.hostname
.split('.')
.skip(1)
.collect::<Vec<_>>()
.join(".")
let parts: Vec<&str> = host.hostname.split('.').collect();
if parts.len() >= 3 {
parts[1..].join(".")
} else {
String::new()
}
} else {
String::new()
};
Expand Down Expand Up @@ -557,6 +559,72 @@ mod tests {
assert!(s.domain_controllers.is_empty());
}

#[tokio::test]
async fn register_dc_two_part_hostname_uses_fallback() {
// Hostname with only 2 parts (e.g. "DC01.local") must NOT register
// "local" as the domain. With a fallback domain already in state,
// register_dc should use the fallback instead.
let state = SharedState::new("op-1".to_string());
let q = mock_queue();
{
let mut s = state.inner.write().await;
s.domains.push("contoso.local".to_string());
}

let host = make_host("192.168.58.1", "DC01.local", true);
state.register_dc(&q, &host).await.unwrap();

let s = state.inner.read().await;
// Must NOT have registered just "local" as a domain
assert!(
!s.domain_controllers.contains_key("local"),
"two-part hostname leaked 'local' as a domain"
);
assert_eq!(
s.domain_controllers.get("contoso.local"),
Some(&"192.168.58.1".to_string()),
"expected fallback to existing domain"
);
}

#[tokio::test]
async fn register_dc_two_part_hostname_no_fallback_skips() {
// 2-part hostname AND no fallback domain → skip entirely instead of
// registering a TLD as the AD domain.
let state = SharedState::new("op-1".to_string());
let q = mock_queue();

let host = make_host("192.168.58.1", "DC01.local", true);
state.register_dc(&q, &host).await.unwrap();

let s = state.inner.read().await;
assert!(
s.domain_controllers.is_empty(),
"2-part hostname with no fallback must skip DC registration"
);
assert!(
!s.domains.iter().any(|d| d == "local"),
"2-part hostname leaked 'local' into domains"
);
}

#[tokio::test]
async fn register_dc_three_part_hostname_extracts_full_domain() {
// Sanity check the >=3 parts branch with a deeper FQDN to make sure
// the parts[1..].join(".") slice is right (not just the last label).
let state = SharedState::new("op-1".to_string());
let q = mock_queue();

let host = make_host("192.168.58.1", "dc.eu.contoso.local", true);
state.register_dc(&q, &host).await.unwrap();

let s = state.inner.read().await;
assert_eq!(
s.domain_controllers.get("eu.contoso.local"),
Some(&"192.168.58.1".to_string())
);
}

#[tokio::test]
async fn publish_host_strips_trailing_dot() {
let state = SharedState::new("op-1".to_string());
Expand Down
Loading
Loading