Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion infra/parameters.example.json
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
"value": "YOUR_DATABASE_NAME"
},
"pyritInitializer": {
"value": "targets airt"
"value": "target airt"
},
"envSecretName": {
"value": "env-global"
Expand Down
54 changes: 34 additions & 20 deletions pyrit/setup/initializers/airt.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
"""

import json
import logging
import os
from collections.abc import Callable

Expand Down Expand Up @@ -38,6 +39,8 @@
from pyrit.score.float_scale.self_ask_scale_scorer import SelfAskScaleScorer
from pyrit.setup.initializers.pyrit_initializer import PyRITInitializer

logger = logging.getLogger(__name__)


class AIRTInitializer(PyRITInitializer):
"""
Expand Down Expand Up @@ -275,32 +278,43 @@ def _setup_adversarial_targets(self, *, endpoint: str, api_key: str, model_name:

def _validate_operation_fields(self) -> None:
"""
Check that mandatory global memory labels (operation, operator)
are populated.
Ensure operator and operation are populated in GLOBAL_MEMORY_LABELS.

Reads operator/operation from .pyrit_conf if it exists, then merges
Comment thread
romanlutz marked this conversation as resolved.
them into GLOBAL_MEMORY_LABELS. In container/GUI deployments where
.pyrit_conf is not present, the labels are set per-user by the GUI
at runtime, so this method is a no-op.

Raises:
ValueError: If mandatory global memory labels are missing.
ValueError: If .pyrit_conf exists but is missing operator or operation.
"""
with open(DEFAULT_CONFIG_PATH) as f:
data = yaml.load(f, Loader=yaml.SafeLoader)
raw_labels = os.environ.get("GLOBAL_MEMORY_LABELS")
labels = dict(json.loads(raw_labels)) if raw_labels else {}

if "operator" not in data:
raise ValueError(
"Error: `operator` was not set in .pyrit_conf. This is a required value for the AIRTInitializer."
)
if DEFAULT_CONFIG_PATH.exists():
with open(DEFAULT_CONFIG_PATH) as f:
data = yaml.load(f, Loader=yaml.SafeLoader) or {}

if "operation" not in data:
raise ValueError(
"Error: `operation` was not set in .pyrit_conf. This is a required value for the AIRTInitializer."
)
if "operator" not in data:
raise ValueError(
"Error: `operator` was not set in .pyrit_conf. This is a required value for the AIRTInitializer."
)

raw_labels = os.environ.get("GLOBAL_MEMORY_LABELS")
labels = dict(json.loads(raw_labels)) if raw_labels else {}
if "operation" not in data:
raise ValueError(
"Error: `operation` was not set in .pyrit_conf. This is a required value for the AIRTInitializer."
)

if "operator" not in labels:
labels["operator"] = data["operator"]
if "operator" not in labels:
labels["operator"] = data["operator"]

if "operation" not in labels:
labels["operation"] = data["operation"]
if "operation" not in labels:
labels["operation"] = data["operation"]

os.environ["GLOBAL_MEMORY_LABELS"] = json.dumps(labels)
os.environ["GLOBAL_MEMORY_LABELS"] = json.dumps(labels)
else:
logger.info(
"No .pyrit_conf found at %s — skipping operator/operation validation. "
"In GUI mode, these labels are set per-user at runtime.",
DEFAULT_CONFIG_PATH,
)
60 changes: 60 additions & 0 deletions tests/unit/setup/test_airt_initializer.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import json
import os
import sys
from unittest.mock import patch
Expand Down Expand Up @@ -214,6 +215,65 @@ def test_validate_missing_operation_raises_error(self, tmp_path):
):
init._validate_operation_fields()

def test_validate_operation_fields_skips_when_pyrit_conf_missing(self, tmp_path):
"""Test that _validate_operation_fields does not crash when .pyrit_conf is missing.

In container/GUI deployments, .pyrit_conf does not exist. The method should
skip validation gracefully instead of raising FileNotFoundError.
"""
nonexistent_path = tmp_path / "nonexistent" / ".pyrit_conf"
init = AIRTInitializer()
with patch("pyrit.setup.initializers.airt.DEFAULT_CONFIG_PATH", nonexistent_path):
# Should not raise
init._validate_operation_fields()

def test_validate_operation_fields_preserves_existing_labels_when_pyrit_conf_missing(self, tmp_path):
"""Test that existing GLOBAL_MEMORY_LABELS are preserved when .pyrit_conf is missing."""
nonexistent_path = tmp_path / "nonexistent" / ".pyrit_conf"
init = AIRTInitializer()
with (
patch("pyrit.setup.initializers.airt.DEFAULT_CONFIG_PATH", nonexistent_path),
patch.dict("os.environ", {"GLOBAL_MEMORY_LABELS": '{"operator": "gui_user", "operation": "gui_op"}'}),
):
init._validate_operation_fields()
# Existing labels should remain untouched
labels = json.loads(os.environ["GLOBAL_MEMORY_LABELS"])
assert labels["operator"] == "gui_user"
assert labels["operation"] == "gui_op"

def test_validate_operation_fields_merges_conf_into_labels(self, tmp_path):
"""Test that .pyrit_conf values are merged into GLOBAL_MEMORY_LABELS when labels are missing."""
conf_file = tmp_path / ".pyrit_conf"
conf_file.write_text(yaml.dump({"operator": "conf_user", "operation": "conf_op"}))
init = AIRTInitializer()
with (
patch("pyrit.setup.initializers.airt.DEFAULT_CONFIG_PATH", conf_file),
patch.dict("os.environ", {}, clear=False),
):
# Remove GLOBAL_MEMORY_LABELS if present
os.environ.pop("GLOBAL_MEMORY_LABELS", None)
init._validate_operation_fields()
labels = json.loads(os.environ["GLOBAL_MEMORY_LABELS"])
assert labels["operator"] == "conf_user"
assert labels["operation"] == "conf_op"

def test_validate_operation_fields_does_not_overwrite_existing_labels(self, tmp_path):
"""Test that .pyrit_conf values do not overwrite existing GLOBAL_MEMORY_LABELS entries."""
conf_file = tmp_path / ".pyrit_conf"
conf_file.write_text(yaml.dump({"operator": "conf_user", "operation": "conf_op"}))
init = AIRTInitializer()
with (
patch("pyrit.setup.initializers.airt.DEFAULT_CONFIG_PATH", conf_file),
patch.dict(
"os.environ",
{"GLOBAL_MEMORY_LABELS": '{"operator": "existing_user", "operation": "existing_op"}'},
),
):
init._validate_operation_fields()
labels = json.loads(os.environ["GLOBAL_MEMORY_LABELS"])
assert labels["operator"] == "existing_user"
assert labels["operation"] == "existing_op"

def test_validate_db_connection_raises_error(self):
"""Test that validate raises error when AZURE_SQL_DB_CONNECTION_STRING is missing."""
del os.environ["AZURE_SQL_DB_CONNECTION_STRING"]
Expand Down
Loading