In [None]:
# azure-data-lake-sentinel/src/ml_models/sensitive_data_classifier/notebooks/custom_nlp_training.ipynb

# This is a conceptual Jupyter Notebook for Azure ML experimentation.
# You would run this on an Azure ML Compute Instance or Compute Cluster.

# --- 1. Setup Azure ML Workspace ---
import json
import os
import uuid

import azure.ai.ml as ml
from azure.ai.ml.constants import AssetTypes, InputOutputModes
from azure.ai.ml.entities import (
    AmlCompute,
    Code,
    Compute,
    Data,
    Environment,
    Job,
    ManagedOnlineDeployment,
    ManagedOnlineEndpoint,
    Model,
    OnlineDeployment,
    OnlineEndpoint,
)
from azure.identity import DefaultAzureCredential

# Authenticate and get MLClient.
# The workspace coordinates are read from AZUREML_ARM_SUBSCRIPTION,
# AZUREML_ARM_RESOURCEGROUP and AZUREML_ARM_WORKSPACE_NAME (typically set
# automatically on AML compute). If any of them is missing or empty, fall back
# to MLClient.from_config(), which looks up a workspace config.json, rather
# than constructing a client with None coordinates.
credential = DefaultAzureCredential()
workspace_settings = {
    "subscription_id": os.environ.get("AZUREML_ARM_SUBSCRIPTION"),
    "resource_group_name": os.environ.get("AZUREML_ARM_RESOURCEGROUP"),
    "workspace_name": os.environ.get("AZUREML_ARM_WORKSPACE_NAME"),
}
if all(workspace_settings.values()):
    ml_client = ml.MLClient(credential=credential, **workspace_settings)
else:
    # Report exactly which settings were absent so the fallback is not silent.
    missing_settings = sorted(
        key for key, value in workspace_settings.items() if not value
    )
    print(
        "Workspace settings missing from environment "
        f"({', '.join(missing_settings)}); falling back to config.json."
    )
    ml_client = ml.MLClient.from_config(credential=credential)

print(f"Connected to Azure ML Workspace: {ml_client.workspace_name}")

# --- 2. Register Data ---
# The data asset points at a folder on the workspace's default blob datastore
# that is expected to contain 'training_data.csv' (upload a dummy file there
# for testing if needed). Expected columns are 'text' and 'label', e.g.:
# text,label
# "John Doe, 123 Main St, Anytown, 12345",PII
# "This is a public document about cloud computing.",Public
# "Meeting minutes from Q4 strategy session. Highly confidential.",Confidential

# Asset name; the training job refers to the registered asset by this name.
data_asset_name = "sensitive-data-training-data"
data_path = "azureml://datastores/workspaceblobstore/paths/data/sensitive_data_training/"
my_data = Data(
    type=AssetTypes.URI_FOLDER,
    path=data_path,
    name=data_asset_name,
    description="Training data for sensitive data classification.",
    # The training data itself is tagged as non-sensitive.
    tags={"format": "csv", "sensitive": "false"},
)
# Registration is left commented out; the asset is assumed to already exist
# (manual upload or a previous run). Uncomment to create/update it:
# data_asset = ml_client.data.create_or_update(my_data)
# print(f"Data asset URI: {data_asset.path}")


# --- 3. Create Environment ---
# Custom environment: a base Docker image plus the conda specification kept in
# conda_env.yml one directory above this notebook.
custom_env_name = "sensitive-data-classifier-env"
my_env = Environment(
    image="mcr.microsoft.com/azureml/openmpi4.1.0-ubuntu20.04:latest",  # Base image
    conda_file=os.path.join(os.pardir, "conda_env.yml"),  # Resolved relative to the working directory
    name=custom_env_name,
    description="Custom environment for sensitive data classifier.",
)
# Registration is left commented out. Uncomment to create/update the environment:
# ml_client.environments.create_or_update(my_env)
# print(f"Environment '{custom_env_name}' created/updated.")


# --- 4. Create Compute Cluster (if not already provisioned by Terraform) ---
# The AML Compute Cluster should be attached to the VNet (ml-subnet).
compute_name = "ml-compute-cluster"
try:
    compute = ml_client.compute.get(compute_name)
except Exception:
    # Any failure of the lookup is treated as "the cluster does not exist yet".
    print(f"Creating compute '{compute_name}'...")
    compute = AmlCompute(
        name=compute_name,
        type="amlcompute",
        size="Standard_DS3_v2",  # Matches your Terraform VM size
        min_instances=0,
        max_instances=4,
        idle_time_before_scale_down=30 * 60,  # 30 minutes, in seconds
        # network_settings={ # Uncomment if your Terraform doesn't handle VNet for compute
        #     "vnet_name": "<YOUR_VNET_NAME>",
        #     "subnet": "<YOUR_ML_SUBNET_NAME>"
        # }
    )
    # Provisioning is left commented out. Uncomment to create/update compute:
    # ml_client.compute.begin_create_or_update(compute).wait()
    print(f"Compute '{compute_name}' created/updated.")
else:
    print(f"Compute '{compute_name}' already exists.")


# --- 5. Create Training Job ---
# Command job that runs train.py against the registered data asset and writes
# the trained model to a named MLflow-model output.
code_folder = os.pardir  # Folder containing train.py and conda_env.yml
job = ml.command(
    # Job names must be unique within the workspace, so a short random suffix
    # is appended; a fixed name would collide when the notebook is re-run.
    name=f"sensitive-data-classifier-training_{uuid.uuid4().hex[:8]}",
    display_name="Train Sensitive Data Classifier",
    description="Trains an NLP model to classify sensitive data in text.",
    inputs={
        "data_path": ml.Input(
            type=AssetTypes.URI_FOLDER,
            # "@latest" selects the newest registered version. The previous
            # ":latest" form asks for a version literally named "latest".
            path=f"azureml:{data_asset_name}@latest",
            mode=InputOutputModes.RO_MOUNT,  # Read-only mount
        )
    },
    code=code_folder,
    command="python train.py --data_path ${{inputs.data_path}} --model_output_path ${{outputs.model_output}}",
    environment=f"{custom_env_name}@latest",  # Latest version of the custom environment
    compute=compute_name,
    outputs={
        "model_output": ml.Output(type=AssetTypes.MLFLOW_MODEL)  # Save model in MLflow format (recommended)
    },
    experiment_name="sensitive-data-classification",
)

# --- 6. Submit Training Job ---
print("Submitting training job...")
# Submission is left commented out. Uncomment to submit the job.
# NOTE(review): in the azure-ai-ml (v2) SDK, jobs.stream() is documented to take
# the job *name* and to block until the job finishes; the v1-style
# wait_for_completion() does not appear to exist on v2 Job entities — confirm
# against the SDK reference before relying on this.
# returned_job = ml_client.jobs.create_or_update(job)
# print(f"Job name: {returned_job.name}")
# ml_client.jobs.stream(returned_job.name) # Stream logs until the job finishes
# returned_job = ml_client.jobs.get(returned_job.name) # Refresh to read the final status
# print(f"Training job completed with status: {returned_job.status}")

# Placeholder used while the submission above is commented out.
# Replace it with the name of a successfully completed job (returned_job.name
# when running interactively); it is interpolated into an
# azureml://jobs/<name>/outputs/... path when the model is registered.
job_id = "your-completed-training-job-id" # e.g., "sensitive-data-classifier-training_12345678"


# --- 7. Register Model (after job completion) ---
# Registers the training job's "model_output" output as a workspace model.
print("Registering model...")
model_path_from_job = f"azureml://jobs/{job_id}/outputs/model_output"
model_to_register = Model(
    path=model_path_from_job,
    name="sensitive-data-classifier-model",
    type=AssetTypes.MLFLOW_MODEL,  # Or AssetTypes.CUSTOM_MODEL if not MLflow compatible
    description="NLP model for classifying PII, PCI, PHI, Confidential data.",
    tags={"task": "text-classification", "sensitivity": "high"},
)
registered_model = ml_client.models.create_or_update(model_to_register)
print(
    f"Model '{registered_model.name}' registered with ID: {registered_model.id}, "
    f"Version: {registered_model.version}"
)


# --- 8. Deploy Endpoint (Real-time Online Endpoint) ---
# Define an Azure ML managed online endpoint for real-time inference.
# ManagedOnlineEndpoint is the concrete entity the SDK documents for managed
# endpoints; the generic OnlineEndpoint is its base class.
endpoint_name = "sensitive-data-classifier-endpoint"
endpoint = ManagedOnlineEndpoint(
    name=endpoint_name,
    description="Real-time endpoint for sensitive data classification",
    auth_mode="key",  # or "aml_token" - depends on how you authenticate
)
print(f"Creating/updating endpoint '{endpoint.name}'...")
# Provisioning is left commented out. Uncomment to create/update the endpoint:
# ml_client.online_endpoints.begin_create_or_update(endpoint).wait()
# print(f"Endpoint '{endpoint.name}' is ready.")


# --- 9. Deploy Deployment to Endpoint ---
# Define a deployment of the registered model on the endpoint.
# ManagedOnlineDeployment is the concrete entity the SDK documents for managed
# online endpoints; the generic OnlineDeployment is its base class.
deployment_name = "v1-deployment"  # A name for this specific deployment version
deployment = ManagedOnlineDeployment(
    name=deployment_name,
    endpoint_name=endpoint_name,
    model=registered_model,  # The model registered previously
    environment=f"{custom_env_name}@latest",  # Latest version of the custom environment
    instance_type="Standard_DS3_v2",  # Matches your Terraform instance type for inference
    instance_count=1,
)
print(f"Creating/updating deployment '{deployment.name}' for endpoint '{endpoint.name}'...")
# Provisioning is left commented out. Uncomment to create/update the deployment:
# ml_client.online_deployments.begin_create_or_update(deployment).wait()
# print(f"Deployment '{deployment.name}' is ready.")


# --- 10. Test Inference (Optional) ---
# Assuming the endpoint is deployed and healthy
# You would use the endpoint URL and key (from AML workspace) in your Azure Function
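# NOTE: online_endpoints.invoke() reads the request payload from a JSON file
# passed as request_file; it does not accept an in-memory dictionary.
# sample_text = {"input_data": {"text": "My name is Alice and my SSN is 999-88-7777 and I live at 123 Pine St."}}
# with open("sample_request.json", "w", encoding="utf-8") as request_fh:
#     json.dump(sample_text, request_fh)
# test_result = ml_client.online_endpoints.invoke(
#     endpoint_name=endpoint_name,
#     deployment_name=deployment_name, # Target this deployment directly, even before traffic is assigned
#     request_file="sample_request.json",
# )
# print(f"Test inference result: {test_result}")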

# --- 11. Cleanup (Optional, for temporary resources) ---
# ml_client.online_endpoints.begin_delete(name=endpoint_name).wait()
# ml_client.compute.begin_delete(name=compute_name).wait()
