In [None]:
#Install necessary libraries
!pip install --upgrade torch transformers accelerate optimum[onxx] 

In [None]:
!pip install onnx onnxruntime

In [1]:
import os
import json
import mlflow
import requests
from mlflow.exceptions import MlflowException
from mlflow.tracking import MlflowClient
from mlflow.store.artifact.runs_artifact_repo import RunsArtifactRepository
from pathlib import Path
import shutil
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from optimum.exporters.tasks import TasksManager
from optimum.exporters.onnx import export


### Download models from hugging face
In this step we download two models from Hugging Face
1. Bert-based
2. Distilbert-based

In [2]:
DEVICE = "cpu"
# ONNX opset version (14+ needed for scaled_dot_product_attention)
OPSET = 14

def download_and_export_onxx(model_name: str, output_dir: Path, task: str = "text-classification"):
    if not os.path.exists(output_dir):       
        os.makedirs(output_dir)
    """
    Export a Hugging Face model to ONNX using Optimum.
    """
    print(f"Exporting {model_name} to ONNX for task={task} -> {output_dir}")

    # 1. Load model & tokenizer
    print("Loading model...")
    model = AutoModelForSequenceClassification.from_pretrained(
        model_name,
        num_labels=2,  # Adjust to your actual classification labels
    )
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    # 2. Determine model_type (e.g., "bert", "distilbert")
    model_type = model.config.model_type
    

    # 3. Build the exporter config
    exporter_config_constructor = TasksManager.get_exporter_config_constructor(
        model_type=model_type,
        exporter="onnx",
        library_name="transformers"  # Explicitly specify "transformers"
    )
    exporter_config = exporter_config_constructor(
        model.config,
        task=task
    )

    print("Running export...")
    file_path = os.path.join(output_dir,"model.onnx")
    # 4. Run export
    onnx_paths = export(
        model=model,
        config=exporter_config,
        output=Path(file_path),
        device=DEVICE,
        opset=OPSET
    )
    print("Saving tokenizer...")
    tokenizer.save_pretrained(output_dir)

In [3]:
BERT = "bert-base-uncased"
DISTILBERT = "distilbert-base-uncased"
TASK = "text-classification"
BERT_LOCAL_FOLDER = f"/home/ubuntu/mymodels/{BERT}"
DISTILBERT_LOCAL_FOLDER = f"/home/ubuntu/mymodels/{DISTILBERT}"



In [4]:
download_and_export_onxx(BERT, BERT_LOCAL_FOLDER,TASK)   
download_and_export_onxx(BERT, DISTILBERT_LOCAL_FOLDER,TASK)    

Exporting bert-base-uncased to ONNX for task=text-classification -> /home/ubuntu/mymodels/bert-base-uncased
Loading model...


Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Running export...


Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Saving tokenizer...
Exporting bert-base-uncased to ONNX for task=text-classification -> /home/ubuntu/mymodels/distilbert-base-uncased
Loading model...
Running export...
Saving tokenizer...


## Prepare to register to MLflow

`EmptyTritonModel` only exists so we can register the binaries of large model. This model will not be deployed. The dataset hosting the model are stored in the folder location 
`<DATASET>/models/<registered_model_name>/<registered_model_version>`

The values for the `registered_model_name` and `registered_model_version` are the values for the `EmptyTritonModel`

In [5]:
class EmptyTritonModel(mlflow.pyfunc.PythonModel):
    def load_context(self, context):
        return
    def predict(self, context, model_input, params=None):
        return

The `TritonModel` is the actual model which gets deployed. No model binaries are registered with it. Instead it is merely passed the following parameters in the configuration
1. `inference_server_name` - This is the Triton Inference Server where the models are deployed.
2. `model_name` - The `registered model name` for the associated `EmptyTritonModel`
3. `model_name` - The `registered model version` for the associated `EmptyTritonModel`

In [6]:
class TritonModel(mlflow.pyfunc.PythonModel):
    import json
    import requests
    def load_context(self, context):
        with open(context.artifacts["model_context"], "r") as f:
            cfg = json.load(f)
        #self.proxy_service = cfg["proxy_service"]
        self.proxy_service = os.getenv("inference-proxy-service","")
        self.inference_server_name = cfg["inference_server_name"]
        self.model_name = cfg["model_name"]
        self.model_version = cfg["model_version"]

    def predict(self, context, model_input, params=None):
        print("Called predict")
        triton_model_input = {
            "inference_server_name": self.inference_server_name,
            "model_name": self.model_name,
            "model_version": self.model_version,
            "payload": model_input["payload"]
        }
        url = f"{self.proxy_service}/predict"
        if not self.proxy_service:
            return {
                "status_code": 400,
                "body": "Proxy Service Variable not set"
            }            
        else:
            resp = requests.post(url, json=triton_model_input)
            return {
                "status_code": resp.status_code,
                "body": resp.json()
        }

MLflow code to create experiement and register model

In [7]:
def get_or_create_experiment(name: str) -> str:
    # Try to get the experiment
    experiment = mlflow.get_experiment_by_name(name)
    if experiment:
        return experiment.experiment_id

    # Otherwise, create it
    experiment_id = mlflow.create_experiment(name)
    return experiment_id

In [8]:
def get_or_create_registered_model(name: str):
    client = MlflowClient()
    try:
        # Try to get the model details
        model = client.get_registered_model(name)
        print(f"Model '{name}' already registered.")
    except MlflowException as e:
        if "RESOURCE_DOES_NOT_EXIST" in str(e):
            # Model doesn't exist — create it
            model = client.create_registered_model(name)
            print(f"Model '{name}' created.")
        else:
            raise e
    return model

Specify three variables:
1. Experiment Name where the MLflow runs are registered
2. Registered Model Name - for the `EmptyTritonModel` which stores the model binaries
3. Client Registered Model Name - for the `TritonModel` which serves as a client model

In [9]:
experiment_name="TRITON-INFERENCE-SERVER-MODELS"
registered_model_name="BERT-BASED"
client_registered_model_name="BERT-BASED-CLIENT"


In [10]:
## Create Experiment and Models
experiment_id = get_or_create_experiment(experiment_name)
registered_model = get_or_create_registered_model(registered_model_name)
client_registered_model = get_or_create_registered_model(client_registered_model_name)

Model 'BERT-BASED' already registered.
Model 'BERT-BASED-CLIENT' already registered.


Convert the downloaded models to onxx format and store them to the target folder location - <DEV_DATASET>

In [11]:
DATA_TYPES = {
    "bert-base-uncased": "TYPE_INT64",
    "distilbert-base-uncased": "TYPE_INT64"  # DistilBERT (smaller)
}
def make_models_triton_ready(registered_model_name,registered_model_version,model_type,
                                    source_folder,target_folder):   
    if os.path.exists(target_folder):
        shutil.rmtree(target_folder)    
    model_dir = target_folder
    data_type = DATA_TYPES[model_type]

    """
    Creates a minimal Triton config.pbtxt for an ONNX model with BERT-like inputs.
    Adjust if your model uses different inputs or outputs.
    """
    config_distilbert_base_uncased = f"""
name: "{registered_model_name}"
platform: "onnxruntime_onnx"
max_batch_size: 0

input [
  {{
    name: "input_ids"
    data_type: {data_type}
    dims: [ -1, -1 ]
  }},
  {{
    name: "attention_mask"
    data_type: {data_type}
    dims: [ -1, -1 ]
  }}
]

output [
  {{
    name: "logits"
    data_type: TYPE_FP32
    dims: [ -1, 2 ]
  }}
version_policy {{
      specific {{
         versions: [{registered_model_version}]
      }}
    }}
]
"""
    config_bert_base_uncased = f"""
    name: "{registered_model_name}"
    platform: "onnxruntime_onnx"
    max_batch_size: 0

    input [
      {{
        name: "input_ids"
        data_type: {data_type}
        dims: [ -1, -1 ]
      }},
      {{
        name: "attention_mask"
        data_type: {data_type}
        dims: [ -1, -1 ]
      }},
      {{
       name: "token_type_ids"
       data_type: {data_type}
       dims: [ -1, -1 ]
      }}

    ]

    output [
      {{
        name: "logits"
        data_type: TYPE_FP32
        dims: [ -1, 2 ]
      }}
    ]
    version_policy {{
      specific {{
         versions: [{registered_model_version}]
      }}
    }}
    """
    CONFIG_MAP = {
        "bert-base-uncased": config_bert_base_uncased,
        "distilbert-base-uncased": config_distilbert_base_uncased
    }


        
    config_text = CONFIG_MAP[model_type]
    config_path = os.path.join(source_folder, "config.pbtxt")
    with open(config_path, "w") as f:
            f.write(config_text.strip() + "\n")
    if not os.path.exists(model_dir):       
        for item in os.listdir(source_folder):
            src_path = os.path.join(source_folder, item)
            dst_path = os.path.join(target_folder, item)
    
            if os.path.isdir(source_folder):
                shutil.copytree(source_folder, model_dir, dirs_exist_ok=True)
            else:
                shutil.copy2(source_folder, model_dir)        
    else:
        print(f"{model_dir} already exists")


In [12]:
'''
registered_model_name="test-model"
registered_model_version=20
model_type="bert-base-uncased"
source_folder=BERT_LOCAL_FOLDER
target_folder=f"/mnt/imported/data/triton-dev-ds/models/pre-load/{registered_model_name}/{registered_model_version}"

if os.path.exists(target_folder):
    shutil.rmtree(target_folder)
    
make_models_triton_ready(model_name,model_version,model_type,BERT_LOCAL_FOLDER,target_folder)
print(f"Target folder is {target_folder}")
'''

'\nregistered_model_name="test-model"\nregistered_model_version=20\nmodel_type="bert-base-uncased"\nsource_folder=BERT_LOCAL_FOLDER\ntarget_folder=f"/mnt/imported/data/triton-dev-ds/models/pre-load/{registered_model_name}/{registered_model_version}"\n\nif os.path.exists(target_folder):\n    shutil.rmtree(target_folder)\n    \nmake_models_triton_ready(model_name,model_version,model_type,BERT_LOCAL_FOLDER,target_folder)\nprint(f"Target folder is {target_folder}")\n'

In [14]:
# Start an MLflow run context and log the llama2-7B model wrapper along with the param-included signature to
# allow for overriding parameters at inference time
source_folder=BERT_LOCAL_FOLDER
DEV_TRITON_BASE_FOLDER = "/mnt/imported/data/triton-dev-ds/models/pre-load/"
os.environ['MLFLOW_ENABLE_PROXY_MULTIPART_UPLOAD'] = "true"
LOG_LLM_ARTIFACTS_TO_MLFLOW=True
TRITON_INFERENCE_SERVER="triton-domino-pre-load-inference"
model_type="bert-base-uncased"

client = MlflowClient()

with mlflow.start_run(experiment_id=experiment_id) as parent_run:
    parent_run_id = parent_run.info.run_id
    print(f"Parent Run Id {parent_run_id}")
    # Save parent model
    model_info = mlflow.pyfunc.log_model(
        artifact_path="",
        python_model=EmptyTritonModel(),
        artifacts={}
    )
    
    runs_uri = model_info.model_uri
    print("runs_uri:", runs_uri)
    model_src = RunsArtifactRepository.get_underlying_uri(runs_uri)
    mv = client.create_model_version(registered_model_name, model_src, parent_run_id,tags={"is_parent":"true"})

    
    target_folder=os.path.join(DEV_TRITON_BASE_FOLDER,mv.name,mv.version)
    
    make_models_triton_ready(mv.name,mv.version,model_type,source_folder,target_folder)
    
    if LOG_LLM_ARTIFACTS_TO_MLFLOW:
        mlflow.log_artifacts(source_folder,artifact_path="model")
        
    # Start child run
    with mlflow.start_run(experiment_id=experiment_id,parent_run_id=parent_run_id, nested=True) as child_run:
        child_run_id = child_run.info.run_id
        mlflow.log_param("parent_run_id",parent_run_id)
        mlflow.log_param("triton_model_name",mv.name)
        mlflow.log_param("triton_model_version",mv.version)
        print(f"Child Run Id {parent_run_id}")
        model_context = {
            "parent_run_id":parent_run_id,
            "inference_server_name":TRITON_INFERENCE_SERVER,
            "model_name":mv.name,
            "model_version":mv.version
        }
        config_path = "/tmp/model_context.json"
        with open(config_path, "w") as f:
            json.dump(model_context, f)
            
        model_info = mlflow.pyfunc.log_model(
            artifact_path="triton_model_artifacts",
            python_model=TritonModel(),
            artifacts={"model_context": config_path}
        )
    
        runs_uri = model_info.model_uri
        print("runs_uri:", runs_uri)

        model_src = RunsArtifactRepository.get_underlying_uri(runs_uri)
        mv = client.create_model_version(client_registered_model_name, model_src, child_run_id,tags={"is_triton_client":"true"})

        print("Name:", mv.name)
        print("Version:", mv.version)
        print("Status:", mv.status)

Parent Run Id 0037760ffbe74a258c0b7ee8982f2ece


2025/05/08 20:24:02 INFO mlflow.store.model_registry.abstract_store: Waiting up to 300 seconds for model version to finish creation. Model name: BERT-BASED, version 10


runs_uri: runs:/0037760ffbe74a258c0b7ee8982f2ece/
Child Run Id 0037760ffbe74a258c0b7ee8982f2ece


Downloading artifacts:   0%|          | 0/1 [00:00<?, ?it/s]

2025/05/08 20:24:42 INFO mlflow.store.model_registry.abstract_store: Waiting up to 300 seconds for model version to finish creation. Model name: BERT-BASED-CLIENT, version 7


runs_uri: runs:/8b67cfc67e4f4417810412ab16f9a867/triton_model_artifacts
Name: BERT-BASED-CLIENT
Version: 7
Status: READY
🏃 View run suave-rat-564 at: http://127.0.0.1:8768/#/experiments/1/runs/8b67cfc67e4f4417810412ab16f9a867
🧪 View experiment at: http://127.0.0.1:8768/#/experiments/1
🏃 View run luxuriant-wolf-763 at: http://127.0.0.1:8768/#/experiments/1/runs/0037760ffbe74a258c0b7ee8982f2ece
🧪 View experiment at: http://127.0.0.1:8768/#/experiments/1


### Test TritonModel class locally
Download it from model registry
load_context called automatically and it sees the same mount that is shared between wks and model api
predict call will interpret the input

In [15]:
import mlflow.pyfunc
os.environ['MLFLOW_ENABLE_ARTIFACTS_PROGRESS_BAR']="true"
os.environ['inference-proxy-service']="http://inference-proxy-service.domino-inference-dev.svc.cluster.local:8000"
# Set model URI (update with your MLflow model registry path)
model_uri = f"models:/{client_registered_model_name}/latest"  # Example for a registry model
print(model_uri)
# model_uri = "runs:/your_run_id/model"  # If stored in a specific run
# Load the MLflow model
model = mlflow.pyfunc.load_model(model_uri)

models:/BERT-BASED-CLIENT/latest


  latest = client.get_latest_versions(name, None if stage is None else [stage])


Downloading artifacts:   0%|          | 0/6 [00:00<?, ?it/s]

In [16]:

payload={  
    "payload": {
       "inputs": [
            {
                "name": "input_ids",
                "shape": [1, 8],
                "datatype": "INT64",
                "data": [101, 1045, 2293, 2023, 3185, 999, 102, 0]
            },
            {
                "name": "attention_mask",
                "shape": [1, 8],
                "datatype": "INT64",
                "data": [1, 1, 1, 1, 1, 1, 1, 0]
            },
            {
                "name": "token_type_ids",
                "shape": [1, 8],
                "datatype": "INT64",
                "data": [0, 0, 0, 0, 0, 0, 0, 0]
            }
      ]
    }
    
  }
  



In [17]:
model.predict(payload)

Called predict


{'status_code': 200,
 'body': {'status_code': 200,
  'result': {'model_name': 'BERT-BASED',
   'model_version': '10',
   'outputs': [{'name': 'logits',
     'datatype': 'FP32',
     'shape': [1, 2],
     'data': [0.07767994701862335, 0.16845941543579102]}]}}}