# How to create an Azure AI Content Safety enabled LLaMA online endpoint
### This notebook walks you through the steps to create an __Azure AI Content Safety__ enabled __LLaMA__ online endpoint.
### The steps are:
1. Create an __Azure AI Content Safety__ resource to moderate both the request from the user and the response from the __LLaMA__ online endpoint.
2. Create a new __Azure AI Content Safety__ enabled __LLaMA__ online endpoint with a custom score.py that calls the __Azure AI Content Safety__ resource to moderate the request from the user and the response from the __LLaMA__ model. For the custom score.py to authenticate successfully to the __Azure AI Content Safety__ resource, there are two options:
    1. __UAI__ (recommended, but more complex): create a __User Assigned Identity (UAI)__ and assign the appropriate roles to it. The custom score.py can then obtain an access token for the __UAI__ from the AAD server and use it to access the Azure AI Content Safety resource.
    2. __Environment variable__ (simpler, but less secure): pass the access key of the __Azure AI Content Safety__ resource to the custom score.py through an environment variable, so that score.py can use the key directly. This is less secure because anyone in your organization who has access to the endpoint can read the key from the environment variable and use it to access the Azure AI Content Safety resource.
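
From the scoring script's point of view, the two options differ mainly in which HTTP header is sent to the Azure AI Content Safety resource. The sketch below is illustrative only and is not the score.py used by this notebook: the function name `build_auth_headers` and its `token` argument are hypothetical, while `CONTENT_SAFETY_KEY` matches the environment variable set on the deployment later in this notebook.

```python
import os
from typing import Dict, Optional


def build_auth_headers(token: Optional[str] = None) -> Dict[str, str]:
    """Return request headers for either authentication option.

    token: an AAD access token obtained for the UAI (option 1).
    If no token is given, fall back to the access key stored in the
    CONTENT_SAFETY_KEY environment variable (option 2).
    """
    if token:
        # Option 1: bearer token issued for the User Assigned Identity
        return {"Authorization": f"Bearer {token}"}
    key = os.environ.get("CONTENT_SAFETY_KEY")
    if not key:
        raise RuntimeError("Neither an access token nor CONTENT_SAFETY_KEY is available.")
    # Option 2: resource access key passed in through the environment
    return {"Ocp-Apim-Subscription-Key": key}
```

With option 2 the key is readable by anything that can inspect the deployment's environment, which is why the UAI approach is the recommended one.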
  

### 1. Prerequisites
#### 1.1 Check List:
- [x] You have created a new Python virtual environment for this notebook.
- [x] The identity you are using to execute this notebook (yourself or your VM) needs to have the __Contributor__ role on the resource group that contains the AML Workspace you specified, because this notebook will create an Azure AI Content Safety resource using that identity.
- [x] Required only if you choose the UAI approach: the identity executing this notebook (yourself or your VM) needs to have the __Owner__ role on the resource group that contains the specified AML Workspace. This is because the notebook will create a new UAI and assign it the roles required to create the Azure AI Content Safety enabled LLaMA endpoint.

#### 1.2 Install Dependencies

In [None]:
# Install the required packages (comment these lines out if they are already installed)
%pip install azure-identity==1.13.0
%pip install azure-mgmt-cognitiveservices==13.4.0
%pip install azure-ai-ml==1.8.0
%pip install azure-mgmt-msi==7.0.0
%pip install azure-mgmt-authorization==3.0.0

#### 1.3 Assign variables for the workspace and deployment

In [None]:
# NOTE: Update the following workspace information with
#       your subscription ID, resource group name, and workspace name
subscription_id = ""
resource_group = ""
workspace_name = ""

# Name of the public registry that contains the LLaMA models
registry_name="azureml-preview-test1"

# Name of the LLaMA model to be deployed
# available_llama_models_pre_trained = ["Llama-2-7b", "Llama-2-13b"]
# available_llama_models_fine_tuned = ["Llama-2-7b-chat", "Llama-2-13b-chat"]
model_name="Llama-2-7b"

endpoint_name="llama-large" # Replace with your endpoint name
deployment_name="llama" # Replace with your deployment name
sku_name="Standard_NC24s_v3" # Name of the SKU (instance type). Check the model list in the parent folder (inference) to find the most suitable SKU for your model (Default: Standard_DS2_v2)


# settings for the Azure AI Content Safety resource
aacs_name = f"{endpoint_name}-aacs-1" # name of azure ai content safety resource, has to be unique
available_aacs_locations = ['east us', 'west europe']
aacs_location = available_aacs_locations[0]


### 2. Connect to your AML Workspace

In [None]:
from azure.ai.ml import MLClient
from azure.identity import DefaultAzureCredential, InteractiveBrowserCredential

try:
    credential = DefaultAzureCredential()
    # Check whether the credential can obtain a token.
    credential.get_token("https://management.azure.com/.default")
except Exception:
    # Fall back to InteractiveBrowserCredential if DefaultAzureCredential does not work
    credential = InteractiveBrowserCredential()

try:
    # Use the workspace described in a local config.json, if one exists
    ml_client = MLClient.from_config(credential=credential)
except Exception:
    # Otherwise connect with the values assigned in section 1.3
    ml_client = MLClient(
        credential=credential,
        subscription_id=subscription_id,
        resource_group_name=resource_group,
        workspace_name=workspace_name,
    )



# Fetch the workspace once and reuse it for both values
workspace = ml_client.workspaces.get(ml_client.workspace_name)
workspace_location = workspace.location
workspace_resource_id = workspace.id
subscription_id = ml_client.subscription_id
resource_group_name = ml_client.resource_group_name
workspace_name = ml_client.workspace_name

reg_client = MLClient(credential, subscription_id=subscription_id, resource_group_name=resource_group_name, registry_name=registry_name)
print(f"Connected to workspace {workspace_resource_id}")
print(f"Workspace location is {workspace_location}") 

### 4. Create Azure AI Content Safety

#### 4.1 Choose a region for your Azure AI Content Safety
Currently, Azure AI Content Safety is only available in the following regions:
- East US
- West Europe

__NOTE__: Before you choose a region for the Azure AI Content Safety resource, be aware that your data will be transferred to that region. By selecting a region outside your current location, you may be allowing your data to be transmitted to regions outside your jurisdiction. Data protection and privacy laws may vary between jurisdictions. Before proceeding, we strongly advise you to familiarize yourself with the local laws and regulations governing data transfer and to ensure that you are legally permitted to transmit your data to an overseas location for processing. By continuing with the selection of a different region, you acknowledge that you have understood and accepted any potential risks associated with such data transmission. Please proceed with caution.
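
One way to limit cross-region data transfer is to prefer the workspace's own region whenever Azure AI Content Safety is available there. The following is a minimal sketch of that idea; the helper name `pick_aacs_location` is hypothetical, and it assumes region names may appear either with spaces ("east us") or without ("eastus").

```python
from typing import Sequence


def pick_aacs_location(workspace_location: str, available: Sequence[str]) -> str:
    """Prefer the workspace's own region; otherwise fall back to the first available one."""

    def normalize(name: str) -> str:
        # Treat "East US" and "eastus" as the same region name
        return name.replace(" ", "").lower()

    for location in available:
        if normalize(location) == normalize(workspace_location):
            return location
    # No match: the data will leave the workspace region, see the note above
    return available[0]


print(pick_aacs_location("eastus", ["east us", "west europe"]))
```

In this notebook the result could be assigned to `aacs_location` in place of `available_aacs_locations[0]`.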

In [None]:
from azure.mgmt.cognitiveservices import CognitiveServicesManagementClient
from azure.identity import DefaultAzureCredential, InteractiveBrowserCredential
from azure.mgmt.cognitiveservices.models import Account, Sku, AccountProperties
import time

try:
    credential = DefaultAzureCredential()
    # Check if given credential can get token successfully.
    credential.get_token("https://management.azure.com/.default")
except Exception as ex:
    # Fall back to InteractiveBrowserCredential if DefaultAzureCredential does not work
    credential = InteractiveBrowserCredential()

client = CognitiveServicesManagementClient(credential, subscription_id)

# create a new Cognitive Services Account
kind = "ContentSafety"
aacs_sku_name = "S0"
parameters = Account(sku=Sku(name=aacs_sku_name), kind=kind, location=aacs_location, properties= AccountProperties(custom_sub_domain_name=aacs_name, public_network_access="Enabled"))
# How many seconds to wait between checking the status of an async operation.
wait_time = 10


try:
    client.accounts.get(resource_group_name, aacs_name)
    print(f"Found existing Azure AI content safety Account {aacs_name}.")
except Exception:
    # The lookup failed, so assume the account does not exist yet and create it
    print(f"Creating Azure AI content safety Account {aacs_name}.")
    poller = client.accounts.begin_create(resource_group_name, aacs_name, parameters)
    while not poller.done():
        print(f"Waiting {wait_time} seconds for operation to finish.")
        time.sleep(wait_time)
    # This will raise an exception if the server responded with an error.
    result = poller.result()
    print("Resource created.")



aacs=client.accounts.get(resource_group_name, aacs_name)
aacs_endpoint = aacs.properties.endpoint
aacs_resource_id = aacs.id
print(f"AACS endpoint is {aacs_endpoint}")
print(f"AACS ResourceId is {aacs_resource_id}")

aacs_access_key = client.accounts.list_keys(resource_group_name=resource_group_name, account_name=aacs_name).key1
# The access key is a secret, so avoid printing it in the notebook output
print("AACS access key retrieved.")
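
The endpoint and key retrieved above are what the custom score.py needs to call the moderation service. The sketch below shows roughly what such a call and the blocking decision could look like; it is not the score.py shipped with this notebook. It assumes the REST shape of the `2023-10-01` API version (a `text:analyze` route and a `categoriesAnalysis` list in the response), so verify both against the API version your resource exposes. The function names and the `threshold` value are hypothetical.

```python
import json
import urllib.request
from typing import Any, Dict

API_VERSION = "2023-10-01"  # assumption: adjust to the version your resource supports


def build_analyze_request(endpoint: str, key: str, text: str) -> urllib.request.Request:
    """Build (but do not send) a text moderation request."""
    url = f"{endpoint.rstrip('/')}/contentsafety/text:analyze?api-version={API_VERSION}"
    body = json.dumps({"text": text}).encode("utf-8")
    headers = {"Ocp-Apim-Subscription-Key": key, "Content-Type": "application/json"}
    return urllib.request.Request(url, data=body, headers=headers, method="POST")


def is_blocked(analysis: Dict[str, Any], threshold: int = 2) -> bool:
    """Return True if any category reaches the (hypothetical) severity threshold."""
    return any(
        item.get("severity", 0) >= threshold
        for item in analysis.get("categoriesAnalysis", [])
    )
```

A scoring script would typically send the request with `urllib.request.urlopen`, parse the JSON response, and apply `is_blocked` once to the incoming prompt and once to the model's output.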

### 5. Create LLaMA online endpoint

#### 5.1 Decide on SKU and instance count for the LLaMA online endpoint.

In [None]:
print(f"Will create LLaMA endpoint {endpoint_name} using {sku_name} compute instance(s)")

#### 5.2 Check if LLaMA model is available in the AML registry.

In [None]:

version_list = list(reg_client.models.list(model_name)) # list available versions of the model
llama_model = None
if len(version_list) == 0:
    print("Model not found in registry")
else:
    model_version = version_list[0].version
    llama_model = reg_client.models.get(model_name, model_version)
    print(
        f"Using model name: {llama_model.name}, version: {llama_model.version}, id: {llama_model.id} for inferencing"
    )
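
The cell above takes the first entry returned by the registry, which is not guaranteed to be the newest version. If you want the highest version explicitly, a small helper such as the one below could be applied to `version_list`. This is a sketch that assumes version strings are plain integers such as "1" or "10"; the helper name `latest_version` is hypothetical, and `SimpleNamespace` objects stand in for the registry's model entries.

```python
from types import SimpleNamespace


def latest_version(models):
    """Return the entry with the highest numeric version, or None if the list is empty."""
    return max(models, key=lambda m: int(m.version), default=None)


# Stand-in objects; in the notebook you would pass version_list instead
candidates = [
    SimpleNamespace(version="2"),
    SimpleNamespace(version="10"),
    SimpleNamespace(version="9"),
]
print(latest_version(candidates).version)  # prints 10
```

Comparing the versions as integers matters here: compared as strings, "9" would sort after "10".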

#### 5.3 Create LLaMA online endpoint
This step may take a few minutes.

In [None]:
from azure.ai.ml.entities import ManagedOnlineEndpoint
# Check if the endpoint already exists in the workspace
try:
    endpoint = ml_client.online_endpoints.get(endpoint_name)
    print("---Endpoint already exists---")
except Exception:
    # Create an online endpoint if it doesn't exist

    # Define the endpoint
    endpoint = ManagedOnlineEndpoint(name=endpoint_name, description="Test endpoint for model")

    # Trigger the endpoint creation
    try:
        ml_client.begin_create_or_update(endpoint).wait()
        print("\n---Endpoint created successfully---\n")
    except Exception as err:
        raise RuntimeError(f"Endpoint creation failed. Detailed Response:\n{err}") from err

### 6. Create the Azure AI Content Safety enabled LLaMA online endpoint

##### 6.1 Create environment for LLaMA endpoint


In [None]:
from azure.ai.ml.entities import (
    Environment,
    BuildContext
)
try:
    env = ml_client.environments.get("LLaMA-ENVIRONMENT", label="latest")
    print("---Environment already exists---")
except Exception:
    print("---Creating environment---")
    env = Environment(name="LLaMA-ENVIRONMENT", build=BuildContext(path="./docker_env"))
    # Keep the registered environment returned by the service
    env = ml_client.environments.create_or_update(env)


# TODO: Add a check to see if the job finished
    

##### 6.2 Create the Safety-Enabled LLaMA Online Endpoint
This step may take a few minutes.

In [None]:
from azure.ai.ml.entities import (
    CodeConfiguration,
    OnlineRequestSettings,
    ManagedOnlineDeployment
)

# Define the deployment
# Update the model version as necessary
deployment = ManagedOnlineDeployment(
    name=deployment_name,
    endpoint_name=endpoint_name,
    model=llama_model.id,
    instance_type=sku_name,
    instance_count=1,
    code_configuration=CodeConfiguration(
        code="./score", scoring_script="score.py"
    ),
    environment = env,
    environment_variables= {
        "MLFLOW_MODEL_FOLDER":"mlflow_model_folder",
        "CONTENT_SAFETY_ENDPOINT": aacs_endpoint,
        "CONTENT_SAFETY_KEY": aacs_access_key
    },
    request_settings= OnlineRequestSettings(request_timeout_ms=90000)
)

# Trigger the deployment creation
try:
    ml_client.begin_create_or_update(deployment).wait()
    print("\n---Deployment created successfully---\n")
except Exception as err:
    raise RuntimeError(f"Deployment creation failed. Detailed Response:\n{err}") from err


### 8. Test the Safety Enabled LLaMA online endpoint.

#### 8.1 Test endpoint with normal sample request

In [None]:
import os

test_src_dir = "./safety-llama-test"
os.makedirs(test_src_dir, exist_ok=True)
print(f"test script directory: {test_src_dir}")

In [None]:
%%writefile {test_src_dir}/sample-request.json
{
    "input_data": {
          "columns": ["input_string"],
          "index": [0],
          "data": [ "My name is John and I " ]
        }
}

In [None]:
sample_data = f"{test_src_dir}/sample-request.json"
ml_client.online_endpoints.invoke(
        endpoint_name=endpoint_name, 
        deployment_name=deployment_name,
        request_file=sample_data)
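
The sample request follows a columns/index/data layout. If you want to test several prompts, the same structure can be generated programmatically, as in the sketch below. The helper name `build_request` is hypothetical, and whether the scoring script accepts more than one row per request is an assumption you should check against your score.py.

```python
import json
from typing import Any, Dict, List


def build_request(prompts: List[str]) -> Dict[str, Any]:
    """Build a payload in the same layout as sample-request.json."""
    return {
        "input_data": {
            "columns": ["input_string"],
            "index": list(range(len(prompts))),
            "data": list(prompts),
        }
    }


payload = build_request(["My name is John and I "])
print(json.dumps(payload, indent=4))
```

The resulting dictionary can be written to a file with `json.dump` and passed to `ml_client.online_endpoints.invoke` through `request_file`, exactly as above.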

#### 8.2 Test endpoint with harmful sample request

In [None]:
%%writefile {test_src_dir}/sample-harmful-request.json
{
    "input_data": {
          "columns": ["input_string"],
          "index": [0],
          "data": [ "I wanna kill you" ]
        }
}

In [None]:
sample_data = f"{test_src_dir}/sample-harmful-request.json"
ml_client.online_endpoints.invoke(
        endpoint_name=endpoint_name, 
        deployment_name=deployment_name,
        request_file=sample_data)