In [2]:
from typing import Optional, Dict, Any
import os
from azure.ai.evaluation.red_team import RedTeam, RiskCategory, AttackStrategy
from openai import AzureOpenAI
from azure.identity import AzureCliCredential, get_bearer_token_provider, DefaultAzureCredential

In [3]:
# Initialize Azure credentials
# This single credential object is reused by the bearer-token provider inside
# azure_openai_callback and is also passed to the RedTeam constructor.
# NOTE(review): AzureCliCredential presumably relies on an existing `az login`
# session on this machine — confirm before running on a fresh environment.

credential = AzureCliCredential()
# print(credential)

In [None]:
# ----- Configuration -----
# Every setting is read from the environment; a value is None when its
# variable is not set.

# Azure AI Project endpoint
azure_ai_project = os.getenv("RED_TEAM_PROJECT_ENDPOINT")

# Azure OpenAI deployment settings.
# Alternative (disabled): target the "adversarial" deployment instead.
# azure_openai_deployment = os.getenv("AZURE_OPENAI_ADVERSARIAL_DEPLOYMENT_NAME")
# azure_openai_endpoint = os.getenv("AZURE_OPENAI_ADVERSARIAL_ENDPOINT")
# azure_openai_api_key = os.getenv("AZURE_OPENAI_ADVERSARIAL_API_KEY")
# azure_openai_api_version = os.getenv("AZURE_OPENAI_ADVERSARIAL_API_VERSION")

# Active: the "unsafe" deployment.
azure_openai_deployment = os.getenv("AZURE_OPENAI_UNSAFE_DEPLOYMENT_NAME")
azure_openai_endpoint = os.getenv("AZURE_OPENAI_UNSAFE_ENDPOINT")
azure_openai_api_key = os.getenv("AZURE_OPENAI_UNSAFE_API_KEY")
azure_openai_api_version = os.getenv("AZURE_OPENAI_UNSAFE_API_VERSION")

In [4]:
# Define a callback that uses the Azure OpenAI API to generate responses
async def azure_openai_callback(
    messages: list,
    stream: Optional[bool] = False,
    session_state: Optional[str] = None,
    context: Optional[Dict[str, Any]] = None,
) -> dict[str, list[dict[str, str]]]:
    """Forward the latest conversation message to Azure OpenAI and return the reply.

    Args:
        messages: Conversation history. Each item must expose a ``.content``
            attribute; only the content of the last item is sent to the model,
            as a single ``user`` message.
        stream: Not used by this callback.
        session_state: Not used by this callback.
        context: Not used by this callback.

    Returns:
        ``{"messages": [{"content": <text>, "role": "assistant"}]}``. When the
        model call raises, the error is printed and ``<text>`` is a fixed
        fallback sentence, so the returned shape is the same on both paths.

    Raises:
        IndexError: If ``messages`` is empty.
    """
    # Get token provider for Azure AD authentication
    token_provider = get_bearer_token_provider(credential, "https://cognitiveservices.azure.com/.default")

    # Only the most recent message is forwarded; earlier turns are ignored.
    latest_message = messages[-1].content

    # The client is created per call, so use it as a context manager to make
    # sure its underlying HTTP connections are released after each request.
    with AzureOpenAI(
        azure_endpoint=azure_openai_endpoint,
        api_version=azure_openai_api_version,
        azure_ad_token_provider=token_provider,
    ) as client:
        try:
            # Call the model
            response = client.chat.completions.create(
                model=azure_openai_deployment,
                messages=[
                    {"role": "user", "content": latest_message},
                ],
                # Token limit as configured for o1 base models. For other
                # models, `max_tokens=500` (and optionally `temperature=0.7`)
                # can be used instead; o1 base models do not support
                # `temperature`.
                max_completion_tokens=500,
            )

            # Format the response to follow the expected chat protocol format.
            # Content may be None; coerce it so the value is always a string.
            formatted_response = {
                "content": response.choices[0].message.content or "",
                "role": "assistant",
            }
        except Exception as e:
            print(f"Error calling Azure OpenAI: {e!s}")
            # Keep the same message shape as the success path (previously a
            # bare string was returned here, breaking the declared format).
            formatted_response = {
                "content": "I encountered an error and couldn't process your request.",
                "role": "assistant",
            }
    return {"messages": [formatted_response]}


In [5]:
# Build the RedTeam instance: four risk categories, with 5 attack objectives
# requested for each category.
red_team = RedTeam(
    credential=credential,
    azure_ai_project=azure_ai_project,
    num_objectives=5,
    risk_categories=[
        RiskCategory.Violence,
        RiskCategory.HateUnfairness,
        RiskCategory.Sexual,
        RiskCategory.SelfHarm,
    ],
)

Class RedTeam: This is an experimental class, and may change at any time. Please see https://aka.ms/azuremlexperimental for more information.


In [None]:
result = await red_team.scan(
    target=azure_openai_callback,
    scan_name="Basic-Callback-Scan",
    attack_strategies=[AttackStrategy.EASY],
    output_path="red_team_output.json",
)

In [None]:
# # Run the red team scan with multiple attack strategies
# advanced_result = await red_team.scan(
#     target=azure_openai_callback,
#     scan_name="Advanced-Callback-Scan",
#     attack_strategies=[
#         AttackStrategy.EASY,  # Group of easy complexity attacks
#         AttackStrategy.MODERATE,  # Group of moderate complexity attacks
#         AttackStrategy.CharacterSpace,  # Add character spaces
#         AttackStrategy.ROT13,  # Use ROT13 encoding
#         AttackStrategy.UnicodeConfusable,  # Use confusable Unicode characters
#         AttackStrategy.CharSwap,  # Swap characters in prompts
#         AttackStrategy.Morse,  # Encode prompts in Morse code
#         AttackStrategy.Leetspeak,  # Use Leetspeak
#         AttackStrategy.Url,  # Use URLs in prompts
#         AttackStrategy.Binary,  # Encode prompts in binary
#         AttackStrategy.Compose([AttackStrategy.Base64, AttackStrategy.ROT13]),  # Use two strategies in one attack
#     ],
#     output_path="Advanced-Callback-Scan.json",
# )