In [None]:
import os
import boto3
import mlflow
import json
import shutil

## Verify AWS Role Assignment
Only the prod-user should have permissions to read directly from the S3 bucket for DOMINO_BLOBS. In this step, verify that the expected role is being assumed.

In [None]:
# Point the AWS SDK at the projected EKS service-account token so that the
# web-identity credential provider can assume the role named by AWS_ROLE_ARN.
os.environ['AWS_WEB_IDENTITY_TOKEN_FILE']='/var/run/secrets/eks.amazonaws.com/serviceaccount/token'

# Read the expected role first so a missing variable fails before any AWS call.
AWS_ROLE_ARN = os.environ['AWS_ROLE_ARN']

sts_client = boto3.client('sts')
identity = sts_client.get_caller_identity()
print(f"Verify identity correctly assumed as = {AWS_ROLE_ARN}")
print(identity)

# Actually check the result instead of relying on a visual comparison.
# An assumed-role caller ARN looks like
#   arn:aws:sts::<account>:assumed-role/<role-name>/<session-name>
# while AWS_ROLE_ARN looks like
#   arn:aws:iam::<account>:role/[<path>/]<role-name>
# so the two are compared on the role name (last path segment of the role ARN).
expected_role_name = AWS_ROLE_ARN.rsplit('/', 1)[-1]
caller_arn = identity.get('Arn', '')
if f":assumed-role/{expected_role_name}/" in caller_arn:
    print(f"OK: caller is an assumed-role session for '{expected_role_name}'")
else:
    print(f"WARNING: caller ARN '{caller_arn}' does not match expected role '{AWS_ROLE_ARN}'")



In [None]:
# Registered model (name and version) in the MLflow model registry that the
# remaining cells look up, download artifacts for, and load.
CLIENT_REGISTERED_MODEL_NAME = "BERT-BASED-CLIENT"
CLIENT_REGISTERED_MODEL_VERSION = 10

In [None]:
def get_all_model_version(registered_model_name):
    """Print version number, status and stage for every version of a registered model.

    Args:
        registered_model_name: Name of the model in the MLflow model registry.

    Returns:
        None. One line per model version is written to stdout.
    """
    # Registry search filter selecting only versions of this model.
    filter_string = f"name='{registered_model_name}'"

    for version in mlflow.tracking.MlflowClient().search_model_versions(filter_string):
        print(f"Version: {version.version}, Status: {version.status}, Stage: {version.current_stage}")


In [None]:
def get_parent_run_id_for_model_version(registered_model_name,registered_model_version):
    """Return the MLflow run that a registered model version was created from.

    Note: despite the function name, the return value is the full run object,
    not a run id. Callers read values such as ``run.data.params`` from it.

    Args:
        registered_model_name: Name of the model in the MLflow model registry.
        registered_model_version: Version of that registered model.

    Returns:
        The run object returned by ``MlflowClient.get_run`` for the run id
        recorded on the model version.
    """
    client = mlflow.tracking.MlflowClient()

    # A single registry lookup is enough; the result carries the source run id.
    version_info = client.get_model_version(name=registered_model_name, version=registered_model_version)

    return client.get_run(version_info.run_id)


In [None]:


def download_s3_folder(bucket_name, model_name,model_version,run_id, prod_ds_folder):
    """Download a run's model artifacts from S3 into a clean local folder.

    Every object under ``mlflow/{run_id}/artifacts/model/`` in ``bucket_name``
    is copied to ``{prod_ds_folder}/{model_name}/{model_version}/``, keeping
    the relative layout below the prefix. If the target folder already exists
    it is deleted first, so the result only contains the files just downloaded.

    Args:
        bucket_name: S3 bucket holding the MLflow artifacts.
        model_name: Sub-folder name (under ``prod_ds_folder``) for the model.
        model_version: Sub-folder name (under ``model_name``) for the version;
            may be a string or an int.
        run_id: MLflow run id whose ``artifacts/model/`` tree is downloaded.
        prod_ds_folder: Existing, writable base directory for the download.

    Raises:
        Exception: If ``prod_ds_folder`` does not exist or is not writable.
    """
    # Validate the base directory before touching anything on disk or in S3.
    if not os.path.isdir(prod_ds_folder):
        raise Exception(f"Base directory {prod_ds_folder} does not exist. Cannot proceed.")

    if not os.access(prod_ds_folder, os.W_OK):
        raise Exception(f"Base directory {prod_ds_folder} is not writable. Cannot proceed.")

    # Full path where we want to download. str() lets callers pass an integer
    # version without os.path.join raising TypeError.
    target_local_folder = os.path.join(prod_ds_folder, model_name, str(model_version))

    # Start from a clean {model_name}/{model_version} folder: remove any
    # previous download, then recreate the directory.
    if os.path.exists(target_local_folder):
        shutil.rmtree(target_local_folder)
    os.makedirs(target_local_folder, exist_ok=True)

    s3 = boto3.client('s3')

    # List all objects under mlflow/{run_id}/artifacts/model/ (paginated, so
    # prefixes with many objects are fully covered).
    prefix = f"mlflow/{run_id}/artifacts/model/"
    paginator = s3.get_paginator('list_objects_v2')
    pages = paginator.paginate(Bucket=bucket_name, Prefix=prefix)

    downloaded_count = 0
    for page in pages:
        for obj in page.get('Contents', []):
            key = obj['Key']
            if key.endswith('/'):
                continue  # Skip "folder" markers

            # Mirror the key's path below the prefix under the target folder.
            relative_path = key[len(prefix):]
            local_file_path = os.path.join(target_local_folder, relative_path)

            os.makedirs(os.path.dirname(local_file_path), exist_ok=True)
            print(f"Downloading {key} to {local_file_path}")
            s3.download_file(bucket_name, key, local_file_path)
            downloaded_count += 1

    # An empty result almost always means a wrong bucket or run id; say so
    # instead of finishing silently with an empty model folder.
    if downloaded_count == 0:
        print(f"WARNING: no objects found under s3://{bucket_name}/{prefix}; {target_local_folder} is empty")
    else:
        print(f"Downloaded {downloaded_count} file(s) to {target_local_folder}")


### Fetch Model Artifacts from Experiment Manager

Select your model name and model version, then download the model artifacts from the mlflow_run_id associated with that model name and model version.

Note that we enable IRSA for this workspace and this prod deployer user has permissions to:

1. List/Read Domino BLOBS bucket
2. Fetch the parent_run_id from the `registered_model_name/registered_model_version`. The function is `get_parent_run_id_for_model_version`
- Fetch the mlflow_run_id associated with this registered model version
- The artifact artifacts/model_context.json associated with this mlflow_run_id
- The parent_run_id is contained in this json in this attribute run_id
3. Now use the function `download_s3_folder` to download artifacts from the bucket location `mlflow/{run_id}/artifacts/model/` into the location `{PROD_DS_FOLDER}/{model_name}/{model_version}/`.

In [None]:

# Tip: to list the versions available for this model first, run
#   get_all_model_version(CLIENT_REGISTERED_MODEL_NAME)

# Look up the MLflow run behind the selected registered model version.
run = get_parent_run_id_for_model_version(
    CLIENT_REGISTERED_MODEL_NAME,
    CLIENT_REGISTERED_MODEL_VERSION,
)

# The run's logged parameters carry everything the download step needs.
params = run.data.params
print(params)

# parent_run_id selects the S3 artifact prefix; the Triton model name and
# version select the local destination folder. A missing key raises KeyError.
parent_run_id = params['parent_run_id']
model_name = params['triton_model_name']
model_version = params['triton_model_version']

In [None]:

# Bucket holding the MLflow artifacts; raises KeyError if the variable is unset.
BUCKET_NAME = os.environ['MLFLOW_S3_BUCKET']

# Base directory that receives {model_name}/{model_version}/ sub-folders.
PROD_DS_FOLDER = "/mnt/imported/data/triton-prod-ds/models/pre-load/"

print(f"BUCKET_NAME={BUCKET_NAME}, LLM_ARTIFACTS_RUN_ID={parent_run_id}")
download_s3_folder(
    bucket_name=BUCKET_NAME,
    model_name=model_name,
    model_version=model_version,
    run_id=parent_run_id,
    prod_ds_folder=PROD_DS_FOLDER,
)


### Test TritonModel class locally but via the production inference server
- Download the model from the model registry.
- `load_context` is called automatically, and it sees the same mount that is shared between the workspace and the Model API.
- The `predict` call interprets the input payload.

In [None]:
import mlflow.pyfunc

# Show a progress bar while MLflow downloads the model artifacts.
os.environ['MLFLOW_ENABLE_ARTIFACTS_PROGRESS_BAR'] = "true"
# Endpoint of the production inference proxy.
# NOTE(review): presumably read by the loaded model's own code - confirm.
os.environ['inference-proxy-service'] = "https://inference-proxy-service.domino-inference-prod.svc.cluster.local:8443"

# Model registry URI of the form models:/<name>/<version>.
# For a model stored in a specific run, "runs:/your_run_id/model" could be used instead.
model_uri = f"models:/{CLIENT_REGISTERED_MODEL_NAME}/{CLIENT_REGISTERED_MODEL_VERSION}"
print(model_uri)

# Load the MLflow model as a generic pyfunc model.
model = mlflow.pyfunc.load_model(model_uri)

In [None]:
# Example inference request: three INT64 tensors, each declared with shape
# [1, 8] and carrying 8 values. The last position has attention_mask 0.
payload = {
    "payload": {
        "inputs": [
            {
                "name": "input_ids",
                "shape": [1, 8],
                "datatype": "INT64",
                "data": [101, 1045, 2293, 2023, 3185, 999, 102, 0],
            },
            {
                "name": "attention_mask",
                "shape": [1, 8],
                "datatype": "INT64",
                "data": [1, 1, 1, 1, 1, 1, 1, 0],
            },
            {
                "name": "token_type_ids",
                "shape": [1, 8],
                "datatype": "INT64",
                "data": [0, 0, 0, 0, 0, 0, 0, 0],
            },
        ],
    },
}
  



In [None]:
# Give the model on the Triton side a few seconds to load before running this
# cell; the call itself does not wait.
prediction = model.predict(payload)
prediction