Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion src/ares/__main__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
"""Package entry point for the Ares command-line interface."""

from .main import app


def run() -> None:
"""CLI entry point for ares."""
"""Run the package CLI entry point."""
app()


Expand Down
4 changes: 2 additions & 2 deletions src/ares/cli_blue_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

# Suppress DEBUG/INFO logs from noisy modules in CLI output.
def _cli_log_filter(record):
"""Filter out DEBUG/INFO from noisy modules, keep all from cli_blue_ops."""
"""Return whether a log record should be shown in CLI output."""
module = record["name"]
level = record["level"].no
if module in {"ares.cli_blue_ops", "__main__"}:
Expand Down Expand Up @@ -1378,7 +1378,7 @@ async def triage_status(


def main():
"""Entry point."""
"""Run the ares-blue-ops CLI application."""
app()


Expand Down
9 changes: 3 additions & 6 deletions src/ares/cli_history.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

# Suppress DEBUG/INFO logs from noisy modules in CLI output
def _cli_log_filter(record):
"""Filter out DEBUG/INFO from noisy modules."""
"""Return whether a log record should be shown in CLI output."""
module = record["name"]
level = record["level"].no
if module in {"ares.cli_history", "__main__"}:
Expand All @@ -38,10 +38,7 @@ def _cli_log_filter(record):


def _check_enabled():
"""Check if persistent store is enabled.

Exits with code 1 if ARES_DATABASE_URL is not set.
"""
"""Exit if the persistent store is not configured."""
from ares.core.persistent_store import get_persistent_store_config

config = get_persistent_store_config()
Expand Down Expand Up @@ -578,7 +575,7 @@ async def _apply_retention():


def main():
"""Entry point for the CLI."""
"""Run the ares-history CLI application."""
app()


Expand Down
4 changes: 2 additions & 2 deletions src/ares/cli_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
# Suppress DEBUG/INFO logs from noisy modules (Redis client, config) in CLI output.
# Keep all logs from cli_ops itself and show WARNING+ from other modules.
def _cli_log_filter(record):
"""Filter out DEBUG/INFO from noisy modules, keep all from cli_ops."""
"""Return whether a log record should be shown in CLI output."""
module = record["name"]
level = record["level"].no
# Allow all logs from this module
Expand Down Expand Up @@ -2616,7 +2616,7 @@ async def watch(


def main() -> None:
"""Entry point for ares-ops CLI."""
"""Run the ares-ops CLI application."""
try:
app()
except Exception as e:
Expand Down
2 changes: 1 addition & 1 deletion src/ares/core/alert_correlation.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ class AlertCorrelator:
CLUSTER_THRESHOLD = 0.3 # Minimum similarity to join a cluster

def __init__(self):
"""Initialize the correlator."""
"""Initialize empty alert cluster state."""
self.clusters: list[AlertCluster] = []
self._cluster_counter = 0
self._alert_to_cluster: dict[str, str] = {} # fingerprint -> cluster_id
Expand Down
12 changes: 7 additions & 5 deletions src/ares/core/blue_dispatcher/publishing.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,14 +79,16 @@ async def publish_lateral_connection(
user: User account used.
mitre_technique: Associated MITRE technique.
"""
source_norm = source.strip().lower()
destination_norm = destination.strip().lower()
connection = {
"source": source,
"destination": destination,
"source": source_norm,
"destination": destination_norm,
"connection_type": connection_type,
"user": user,
"mitre_technique": mitre_technique,
}
await self._backend.add_lateral_connection(connection)
await self._backend.track_host(source.strip().lower())
await self._backend.track_host(destination.strip().lower())
logger.debug(f"Published lateral: {source} -> {destination} ({connection_type})")
await self._backend.track_host(source_norm)
await self._backend.track_host(destination_norm)
logger.debug(f"Published lateral: {source_norm} -> {destination_norm} ({connection_type})")
4 changes: 3 additions & 1 deletion src/ares/core/blue_orchestrator_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@


def _configure_dreadnode():
"""Configure runtime integrations before returning the Dreadnode SDK."""
configure_litellm_env()

# Configure OTEL tracing to export to OTLP endpoint (e.g., Alloy/Tempo)
Expand Down Expand Up @@ -143,6 +144,7 @@ def __init__(

@staticmethod
def _decode_redis_value(value: str | bytes) -> str:
"""Return a Redis value as a decoded string."""
return value.decode() if isinstance(value, bytes) else str(value)

async def start(self) -> None:
Expand Down Expand Up @@ -454,7 +456,7 @@ async def _pop_investigation_request(self, max_retries: int = 2) -> dict[str, An
return None

def _should_use_multi_agent(self, request: InvestigationRequest) -> bool:
"""Determine if multi-agent should be used for this investigation."""
"""Return whether an investigation should use the multi-agent path."""
if request.multi_agent:
return True
if not request.auto_route:
Expand Down
4 changes: 2 additions & 2 deletions src/ares/core/correlation.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ class RedTeamActivity:

@property
def key(self) -> str:
"""Generate a unique key for this activity."""
"""Return a unique correlation key for this activity."""
return f"{self.timestamp.isoformat()}:{self.technique_id}:{self.target_ip}"


Expand All @@ -52,7 +52,7 @@ class BlueTeamDetection:

@property
def key(self) -> str:
"""Generate a unique key for this detection."""
"""Return a unique correlation key for this detection."""
return f"{self.timestamp.isoformat()}:{self.technique_id}:{self.alert_name}"


Expand Down
38 changes: 24 additions & 14 deletions src/ares/core/k8s_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,13 +70,12 @@ def __init__(
kubeconfig: str | None = None,
in_cluster: bool = False,
):
"""
Initialize the Kubernetes pod executor.
"""Initialize the executor for a Kubernetes namespace.

Args:
namespace: Kubernetes namespace to operate in.
kubeconfig: Path to kubeconfig file. If None, uses default.
in_cluster: If True, use in-cluster configuration.
kubeconfig: Path to the kubeconfig file, if any.
in_cluster: Whether to use in-cluster Kubernetes configuration.
"""
self.namespace = namespace
self._kubeconfig = kubeconfig
Expand Down Expand Up @@ -219,13 +218,18 @@ async def execute(
logger.warning(f"Pod execution failed, retrying with fresh pod discovery: {e}")
pod_name = await self.get_pod_for_role(role)
if pod_name:
return await self._execute_in_pod(
pod_name=pod_name,
command=command,
container=container,
timeout=timeout_seconds,
stdin_data=stdin_data,
)
try:
return await self._execute_in_pod(
pod_name=pod_name,
command=command,
container=container,
timeout=timeout_seconds,
stdin_data=stdin_data,
)
except Exception as retry_exc:
raise PodExecutionError(
f"Command execution failed: {retry_exc}"
) from retry_exc
raise PodExecutionError(f"Command execution failed: {e}") from e

async def _execute_in_pod(
Expand Down Expand Up @@ -309,9 +313,10 @@ async def wait_for_pod(self, role: str, timeout: int = 60) -> bool:
"""
await self._ensure_initialized()

start = asyncio.get_event_loop().time()
loop = asyncio.get_running_loop()
start = loop.time()

while asyncio.get_event_loop().time() - start < timeout:
while loop.time() - start < timeout:
pod_name = await self.get_pod_for_role(role)
if pod_name:
logger.info(f"Pod ready for role {role}: {pod_name}")
Expand Down Expand Up @@ -481,8 +486,13 @@ async def copy_from_pod(

# Decode and write locally
import base64
import binascii

data = base64.b64decode(stdout.strip())
try:
data = base64.b64decode(stdout.strip())
except (binascii.Error, ValueError) as e:
logger.error(f"Failed to decode pod file payload: {e}")
return False

with open(local_path, "wb") as f: # noqa: ASYNC230
f.write(data)
Expand Down
15 changes: 12 additions & 3 deletions src/ares/core/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,12 +51,12 @@ class MessageType(Enum):


def generate_message_id() -> str:
"""Generate a unique message ID."""
"""Return a unique message identifier."""
return f"msg-{uuid.uuid4().hex[:12]}"


def generate_task_id() -> str:
"""Generate a unique task ID."""
"""Return a unique task identifier."""
return f"task-{uuid.uuid4().hex[:12]}"


Expand Down Expand Up @@ -332,7 +332,16 @@ class OperationAbort(AgentMessage):


def create_message(message_type: MessageType, source_agent: str, **kwargs) -> AgentMessage:
"""Factory function to create appropriate message type."""
"""Create an agent message instance for the requested message type.

Args:
message_type: The message type to instantiate.
source_agent: The agent that is sending the message.
**kwargs: Additional fields passed to the concrete message model.

Returns:
An initialized agent message model.
"""
message_classes = {
MessageType.CREDENTIAL_DISCOVERED: CredentialDiscovered,
MessageType.HASH_DISCOVERED: HashDiscovered,
Expand Down
6 changes: 3 additions & 3 deletions src/ares/core/orchestrator_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,11 +200,11 @@ async def get_operation_status(
"""Get the current status of an operation.

Args:
operation_id: Operation ID
redis_url: Redis connection URL (default: from config)
operation_id: Operation ID.
redis_url: Redis connection URL. Defaults to the configured value.

Returns:
Operation status dict or None if not found
Operation status data, or None if no status has been stored yet.
"""
redis_url = redis_url or get_redis_url()
task_queue = RedisTaskQueue(redis_url)
Expand Down
6 changes: 4 additions & 2 deletions src/ares/core/orchestrator_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@


def _configure_dreadnode():
"""Configure runtime integrations before returning the Dreadnode SDK."""
configure_litellm_env()

# Configure OTEL tracing to export to OTLP endpoint (e.g., Alloy/Tempo)
Expand Down Expand Up @@ -122,6 +123,7 @@ def __init__(

@staticmethod
def _decode_redis_value(value: str | bytes) -> str:
"""Return a Redis value as a decoded string."""
return value.decode() if isinstance(value, bytes) else str(value)

@staticmethod
Expand Down Expand Up @@ -507,7 +509,7 @@ async def _pop_operation_request(self) -> dict[str, Any] | None:
return None

def _log_env_vars(self, raw_env_vars: Any) -> None:
"""Log environment variables from request."""
"""Log which request environment variables were supplied."""
if raw_env_vars is None:
logger.warning("Request missing env_vars")
elif isinstance(raw_env_vars, dict):
Expand All @@ -520,7 +522,7 @@ def _log_env_vars(self, raw_env_vars: Any) -> None:
logger.warning(f"Request env_vars not a dict: {type(raw_env_vars)}")

def _resolve_openai_api_key(self, request_env_vars: dict[str, str] | None) -> str | None:
"""Resolve OpenAI API key from request or environment."""
"""Resolve the OpenAI API key from request data or process environment."""
openai_api_key = request_env_vars.get("OPENAI_API_KEY") if request_env_vars else None
if request_env_vars:
present_keys = sorted(k for k, v in request_env_vars.items() if v)
Expand Down
Loading
Loading