# Train and Deploy AML Model

This notebook trains a LogisticRegression model to predict high-value donors (total_donations > $1000) using `gold/donor_activity` and deploys it to an Azure Machine Learning endpoint.

## Inputs
- `gold/donor_activity` (from `silver_to_gold`)

## Outputs
- Model saved to `models/high_value_donor_model`
- AML endpoint for real-time predictions

## Dependencies
- `pyspark.ml`, `azure-ai-ml`, `azure-identity`, `python-dotenv`

## Environment
- Uses `.env` for Blob Storage and AML credentials
- Runs monthly in Databricks Workflow

In [None]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, when
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.classification import LogisticRegression
from azure.ai.ml import MLClient
from azure.ai.ml.entities import ManagedOnlineEndpoint, ManagedOnlineDeployment, Model
from azure.identity import DefaultAzureCredential
from dotenv import load_dotenv
import os

# Load environment variables from .env (if present) into the process environment,
# then read every setting this notebook needs for Blob Storage and Azure ML.
load_dotenv()
AZURE_CONN_STR = os.getenv("AZURE_STORAGE_CONNECTION_STRING")
CONTAINER_NAME = os.getenv("CONTAINER_NAME")
AZURE_STORAGE_ACCOUNT_NAME = os.getenv("AZURE_STORAGE_ACCOUNT_NAME")
AZURE_SUBSCRIPTION_ID = os.getenv("AZURE_SUBSCRIPTION_ID")
AZURE_RESOURCE_GROUP = os.getenv("AZURE_RESOURCE_GROUP")
AZURE_ML_WORKSPACE = os.getenv("AZURE_ML_WORKSPACE")

# Fail fast, and name exactly which settings are unset or empty so the
# operator does not have to diff the .env file by hand.
_required_env = {
    "AZURE_STORAGE_CONNECTION_STRING": AZURE_CONN_STR,
    "CONTAINER_NAME": CONTAINER_NAME,
    "AZURE_STORAGE_ACCOUNT_NAME": AZURE_STORAGE_ACCOUNT_NAME,
    "AZURE_SUBSCRIPTION_ID": AZURE_SUBSCRIPTION_ID,
    "AZURE_RESOURCE_GROUP": AZURE_RESOURCE_GROUP,
    "AZURE_ML_WORKSPACE": AZURE_ML_WORKSPACE,
}
_missing_env = [name for name, value in _required_env.items() if not value]
if _missing_env:
    raise ValueError(
        f"Missing environment variables: {', '.join(_missing_env)}. Check .env file."
    )

# Initialize Spark session.
#
# The data is addressed through wasbs:// URLs, so the WASB driver needs either
# the storage account key or a container-scoped SAS token. A connection string
# is neither of those: it is a ';'-separated list of key=value settings that
# *contains* the credential, so extract the credential from it first.
_conn_settings = {
    key.strip(): value.strip()
    # partition() splits on the first '=' only, so base64 padding in
    # AccountKey and '=' inside a SAS query string are preserved.
    for key, _, value in (segment.partition("=") for segment in AZURE_CONN_STR.split(";"))
    if key.strip()
}
_blob_host = f"{AZURE_STORAGE_ACCOUNT_NAME}.blob.core.windows.net"

if _conn_settings.get("AccountKey"):
    # Per-account shared key property used by the WASB driver.
    _storage_conf_key = f"spark.hadoop.fs.azure.account.key.{_blob_host}"
    _storage_conf_value = _conn_settings["AccountKey"]
elif _conn_settings.get("SharedAccessSignature"):
    # Per-container SAS property used by the WASB driver; the token is passed
    # without a leading '?'.
    _storage_conf_key = f"spark.hadoop.fs.azure.sas.{CONTAINER_NAME}.{_blob_host}"
    _storage_conf_value = _conn_settings["SharedAccessSignature"].lstrip("?")
else:
    raise ValueError(
        "AZURE_STORAGE_CONNECTION_STRING must contain AccountKey or SharedAccessSignature."
    )

# NOTE(review): if a session already exists (e.g. on a Databricks cluster),
# getOrCreate() returns it and builder-level spark.hadoop.* options may not be
# applied to the Hadoop configuration - confirm in the target environment.
spark = (
    SparkSession.builder
    .appName("TrainAMLModel")
    .config(_storage_conf_key, _storage_conf_value)
    .getOrCreate()
)

# Define blob storage paths
blob_base_path = f"wasbs://{CONTAINER_NAME}@{AZURE_STORAGE_ACCOUNT_NAME}.blob.core.windows.net"
gold_path = f"{blob_base_path}/gold"
model_path = f"{blob_base_path}/models/high_value_donor_model"

# Read donor_activity table
donor_features = spark.read.parquet(f"{gold_path}/donor_activity")

# Prepare features and label: a donor is "high value" (label 1) when
# total_donations exceeds 1000, otherwise label 0.
features = donor_features.withColumn("label", when(col("total_donations") > 1000, 1).otherwise(0))
# NOTE(review): total_donations is used both to derive the label and as an
# input feature, so the model can recover the label directly from its input
# (target leakage). Kept as-is to preserve the feature schema; revisit the
# feature set before relying on this model's predictions.
assembler = VectorAssembler(inputCols=["total_donations", "donation_count"], outputCol="features")
data = assembler.transform(features)

# Train LogisticRegression model (reads the "features" and "label" columns,
# which are the estimator's default column names).
lr = LogisticRegression(maxIter=10, regParam=0.01)
model = lr.fit(data)

# Save model to Blob Storage.
# write() takes no arguments and returns a writer; the destination goes to
# save(). overwrite() lets the scheduled re-run replace the previous model
# instead of failing because the path already exists.
model.write().overwrite().save(model_path)
print(f"Model saved to {model_path}")

# Initialize AML client
credential = DefaultAzureCredential()
ml_client = MLClient(credential, AZURE_SUBSCRIPTION_ID, AZURE_RESOURCE_GROUP, AZURE_ML_WORKSPACE)

# Create or update AML endpoint. .result() blocks until the long-running
# operation finishes.
endpoint_name = "high-value-donor-endpoint"
endpoint = ManagedOnlineEndpoint(name=endpoint_name)
try:
    ml_client.online_endpoints.begin_create_or_update(endpoint).result()
except Exception as e:
    # Log for the notebook output, then re-raise: the deployment step that
    # follows cannot succeed without the endpoint, and a scheduled run must
    # be reported as failed rather than silently continuing.
    print(f"Endpoint creation error: {e}")
    raise

# Deploy model to AML endpoint
deployment_name = "high-value-donor-deployment"
# NOTE(review): model_path is a wasbs:// URL to a Spark ML model directory, and
# no environment or scoring script is supplied. Confirm that Azure ML accepts
# this path form and that this deployment does not also need `environment` and
# `code_configuration` (typically required for non-MLflow models).
model_entity = Model(path=model_path)
deployment = ManagedOnlineDeployment(
    name=deployment_name,
    endpoint_name=endpoint_name,
    model=model_entity,
    instance_type="Standard_DS2_v2",
    instance_count=1
)
try:
    ml_client.online_deployments.begin_create_or_update(deployment).result()

    # Route 100% of traffic to the new deployment. Traffic allocation is a
    # property of the endpoint entity and is applied by updating the endpoint.
    endpoint.traffic = {deployment_name: 100}
    ml_client.online_endpoints.begin_create_or_update(endpoint).result()

    # Get endpoint details
    scoring_uri = ml_client.online_endpoints.get(endpoint_name).scoring_uri
    print(f"Model deployed to AML endpoint: {scoring_uri}")
    # The endpoint key is a secret and this notebook runs as a scheduled job
    # whose output is retained, so it is deliberately not printed. Fetch it on
    # demand with ml_client.online_endpoints.get_keys(endpoint_name).primary_key.
    print(
        f"Update .env with:\nAZURE_ML_ENDPOINT_URL={scoring_uri}\n"
        "AZURE_ML_ENDPOINT_KEY=<primary key of the endpoint; retrieve it from Azure ML, not from job logs>"
    )
except Exception as e:
    # Log, then re-raise so a failed deployment fails the run instead of
    # falling through to the success message.
    print(f"Deployment error: {e}")
    raise
finally:
    # Always release the Spark session, whether deployment succeeded or not.
    spark.stop()

# Only reached when the deployment and traffic update succeeded.
print("AML model training and deployment completed.")