# Custom MLflow Python Function 
### `ESM2 Protein Embedding Transformer`


<!-- # external ref https://huggingface.co/blog/AmelieSchreiber/protein-optimization-and-design -->

In [0]:
import sys
print(sys.version)

In [0]:
%sh python --version

In [0]:
# Unity Catalog coordinates of the volume that stores the pinned requirements.
# They are assigned in this cell because it runs before any other cell defines
# them; without these assignments a fresh top-to-bottom run raises NameError.
catalog_name = "mmt_demos"
schema_name = "dependencies"
volume_name = "esm_artifacts"

# Pinned package versions, written verbatim to requirements.txt
requirements = """
torch==2.1.0
transformers==4.34.0
accelerate==0.23.0
cloudpickle==3.1.1
"""

requirements_path = f"/Volumes/{catalog_name}/{schema_name}/{volume_name}/requirements.txt"

# Save the requirements.txt file to the UC volume
with open(requirements_path, "w", encoding="utf-8") as f:
    f.write(requirements)

In [0]:
%pip install -r {requirements_path}

In [0]:
dbutils.library.restartPython() 

In [0]:
# Unity Catalog location (catalog / schema / volume) used by the rest of the
# notebook. Declared in this cell, which follows the interpreter restart.
catalog_name, schema_name, volume_name = "mmt_demos", "dependencies", "esm_artifacts"

# Full path of the requirements file inside the volume.
requirements_path = "/".join(
    ["/Volumes", catalog_name, schema_name, volume_name, "requirements.txt"]
)

In [0]:
- `ESMWrapper`: custom wrapper class for the transformer model.
- `load_context`: loads the model and tokenizer from the provided artifacts and sets up the device (GPU if available, otherwise CPU).
- `predict`: takes the input sequences, tokenizes them, runs them through the model, and returns the embeddings.

In [0]:
# Custom PyFunc wrapper for a large transformer model
import mlflow
import torch
from transformers import AutoModelForMaskedLM, AutoTokenizer, logging
import shutil
import pandas as pd
from mlflow.models.signature import infer_signature

# Set the logging level to ERROR to disable verbose messages
# logging.set_verbosity_error()

# Define a Custom wrapper class for the transformer model
class ESMWrapper(mlflow.pyfunc.PythonModel):
    """MLflow pyfunc wrapper that turns protein sequences into embeddings.

    Every sequence is tokenized on its own, passed through the masked-LM
    model, and summarised as the mean over dimension 1 of the final hidden
    state.
    """

    def load_context(self, context):
        """Load the tokenizer and model from the logged artifacts.

        Args:
            context: MLflow model context; ``context.artifacts`` must contain
                the ``"tokenizer"`` and ``"model"`` entries.
        """
        artifacts = context.artifacts

        # Use the GPU when one is available, otherwise fall back to the CPU.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.tokenizer = AutoTokenizer.from_pretrained(artifacts["tokenizer"])

        # Load the weights, place them on the chosen device and switch to
        # evaluation mode before exposing the model on the instance.
        loaded_model = AutoModelForMaskedLM.from_pretrained(artifacts["model"])
        loaded_model.to(self.device)
        loaded_model.eval()
        self.model = loaded_model

        # When the tokenizer has no bos/sep token, reuse its cls token.
        for token_attr in ("bos_token", "sep_token"):
            if getattr(self.tokenizer, token_attr) is None:
                setattr(self.tokenizer, token_attr, self.tokenizer.cls_token)

    def predict(self, context, model_input):
        """Embed every entry of ``model_input["sequences"]``.

        Args:
            context: Unused; part of the pyfunc ``predict`` signature.
            model_input: Mapping/DataFrame with a ``"sequences"`` entry that
                iterates over the sequences to embed.

        Returns:
            A list with one NumPy array per input sequence, in input order
            (presumably of shape ``(1, hidden_size)`` — TODO confirm).
        """
        return [self._embed(sequence) for sequence in model_input["sequences"]]

    def _embed(self, sequence):
        """Return the mean-pooled last hidden state for a single sequence."""
        encoded = self.tokenizer(sequence, return_tensors="pt")
        on_device = {name: tensor.to(self.device) for name, tensor in encoded.items()}

        # Inference only: no gradients are needed.
        with torch.no_grad():
            outputs = self.model(**on_device, output_hidden_states=True)

        last_hidden = outputs.hidden_states[-1]
        return last_hidden.mean(dim=1).cpu().numpy()

# Download the ESM-2 checkpoint, stage it on the UC volume, then log and
# register it as a custom pyfunc model with explicitly pinned dependencies.
# Relies on `catalog_name`, `schema_name`, `volume_name` and `requirements_path`
# having been defined by an earlier cell.
with mlflow.start_run():
    # Hugging Face checkpoint to package.
    model_name = "facebook/esm2_t33_650M_UR50D" #https://huggingface.co/facebook/esm2_t33_650M_UR50D
    
    # Directories on the UC volume where the model and tokenizer are saved;
    # the same paths are passed to `log_model` as artifacts below.
    model_path = f"/Volumes/{catalog_name}/{schema_name}/{volume_name}/tmp_model"
    tokenizer_path = f"/Volumes/{catalog_name}/{schema_name}/{volume_name}/tmp_tokenizer"
    
    # Download the pre-trained model and tokenizer from Hugging Face.
    model = AutoModelForMaskedLM.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    
    # Save the model and tokenizer to the paths above. Safe serialization is
    # explicitly disabled for the model weights.
    model.save_pretrained(model_path, safe_serialization=False)
    tokenizer.save_pretrained(tokenizer_path)
    
    # Conda environment logged with the model. The pip pins repeat the versions
    # written to requirements.txt; keep the two lists in sync.
    conda_env = {
        "channels": ["defaults", "conda-forge", "pytorch"],
        "dependencies": [
            "python=3.11", # original note: compute 15.4LTSMLR (presumably the cluster runtime; confirm)
            "pip>=22.0.4",
            {"pip": [
                "torch==2.1.0", 
                "transformers==4.34.0",
                "accelerate==0.23.0",
                "cloudpickle==3.1.1", # original note: compute 15.4LTSMLR
            ]}
        ],
        "name": "esm_env"
    }
    
    # Single-row sample input, used both for signature inference and as the
    # logged input example.
    sample_input = pd.DataFrame({"sequences": ["MKTAYIAKQRQISFVKSHFSRQDILDLWIYHTQGYFPDWQNYG"]})
    
    ## Initialize the wrapper and populate it manually for signature inference
    ## (`load_context` is not called here, so its special-token fallback is
    ## not applied to this instance).

    # Initialize an instance of the ESMWrapper class.
    esm_wrapper = ESMWrapper()
    # Reuse the tokenizer and model that were just downloaded.
    esm_wrapper.tokenizer = tokenizer
    esm_wrapper.model = model
    # Pick the device (GPU if available, otherwise CPU) and move the model to it.
    esm_wrapper.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    esm_wrapper.model.to(esm_wrapper.device)
    # Set the model to evaluation mode.
    esm_wrapper.model.eval()
    
    # Run the wrapper on the sample input; `context` is not used by `predict`.
    sample_output = esm_wrapper.predict(None, sample_input)
    # Infer the input and output signature from the sample input and output.
    signature = infer_signature(sample_input, sample_output)
    
    # Log a fresh (unloaded) wrapper together with the saved model/tokenizer
    # directories and the requirements file as artifacts, plus the Conda
    # environment, signature, and input example.
    mlflow.pyfunc.log_model(
        artifact_path="esm_model",
        python_model=ESMWrapper(),
        artifacts={
            "model": model_path,
            "tokenizer": tokenizer_path,
            "requirements": requirements_path  
        },
        conda_env=conda_env,
        signature=signature,
        input_example=sample_input,
        # Register the model under this three-level name.
        registered_model_name=f"{catalog_name}.{schema_name}.esm_protein_model"
    )

In [0]:
sample_output

In [0]:
# [9xz5z] [2025-04-07 16:15:05 +0000] 2025/04/07 16:15:05 WARNING mlflow.pyfunc: The version of CloudPickle that was used to save the model, `CloudPickle 2.2.1`, differs from the version of CloudPickle that is currently running, `CloudPickle 3.1.1`, and may be incompatible

In [0]:
import mlflow

# Ask MLflow for the dependencies recorded with registered version 8 of the
# model and print whatever it returns.
model_uri = "models:/{}.{}.esm_protein_model/8".format(catalog_name, schema_name)
dependencies = mlflow.pyfunc.get_model_dependencies(model_uri)
print(dependencies)

In [0]:
!cat /local_disk0/repl_tmp_data/ReplId-19610-f2e2d-7/tmpg3ihn9ch/requirements.txt

In [0]:
from mlflow.tracking import MlflowClient

def get_latest_model_version(model_name):
    """Return the highest version number registered for *model_name*.

    Args:
        model_name: Registered model name used in the ``name='...'`` filter.

    Returns:
        The largest version as an ``int``; ``1`` when the search returns no
        versions.
    """
    registry = MlflowClient(registry_uri="databricks-uc")
    matches = registry.search_model_versions(f"name='{model_name}'")
    # Seed the comparison with 1 so an empty result still yields 1.
    return max([1, *(int(entry.version) for entry in matches)])

def get_model_uri(model_name):
    """Return the ``models:/<name>/<version>`` URI of the latest version of *model_name*."""
    latest = get_latest_model_version(model_name)
    return f"models:/{model_name}/{latest}"

In [0]:
# Workspace API URL and API token taken from the notebook context.
host = dbutils.notebook.entry_point.getDbutils().notebook().getContext().apiUrl().get()
token = dbutils.notebook.entry_point.getDbutils().notebook().getContext().apiToken().get()

# Serving compute settings.
workload_type = "GPU_SMALL"
workload_size = "Small"

# Registered model, its latest version, and the "<model>-<version>" name used
# for the served model in the endpoint configuration.
registered_model_name = f"{catalog_name}.{schema_name}.esm_protein_model"
latest_model_version = get_latest_model_version(registered_model_name)
general_model_name = f"esm_protein_model-{latest_model_version}"

endpoint_base_name = "esm_protein_model_endpoint"
endpoint_name = f"{endpoint_base_name}-mmt-mlflowsdk" # endpoint name

# Summary of the settings. The token is a credential and is deliberately not
# printed, so it cannot leak through saved notebook output; the schema is
# printed as its value rather than wrapped in a set literal.
print(
    f"catalog: {catalog_name}",
    "schema:", schema_name,
    "host:", host,
    "token:", "<redacted>",
    workload_type,
    workload_size,
    "general_model_name:", general_model_name,
    "registered_model_name:", registered_model_name,
    "latest_model_version:", latest_model_version,
    "endpoint_name:", endpoint_name,
)

In [0]:
# client.delete_endpoint(endpoint_name)

In [0]:
from mlflow.deployments import get_deploy_client

client = get_deploy_client("databricks")

# Endpoint configuration payload; it is also reused by the REST request in a
# later cell.
# NOTE(review): the top-level "name" is the served-model name
# (`general_model_name`), not `endpoint_name`, and "tags" sits inside the
# config payload — confirm the deployment client accepts both as intended.
endpoint_config = {
    "name": general_model_name,
    "served_models": [
        {                
            "model_name": registered_model_name,
            "model_version": latest_model_version,
            "workload_size": workload_size,  # defines concurrency: Small/Medium/Large
            "workload_type": workload_type,  # defines compute: GPU_SMALL/GPU_MEDIUM/GPU_LARGE
            "scale_to_zero_enabled": True
        }
    ],
    "traffic_config": {
        "routes": [
            {
                "served_model_name": general_model_name,
                "traffic_percentage": 100
            }
        ]
    },
    "auto_capture_config": {
        "catalog_name": catalog_name,
        "schema_name": schema_name,
        # "table_name_prefix": endpoint_base_name,
        "enabled": True
    },
    "tags": {
        "project": "esm_protein_model",
        "team": "ml-team",
        "removeAfter": "2025-12-31",
    }
}

# Create or update the endpoint.
# Only the existence check is guarded: an error raised by the update itself
# must surface as-is instead of being mistaken for "endpoint does not exist"
# and triggering a create.
try:
    client.get_endpoint(endpoint_name)
except Exception as e:
    # The client reports a missing endpoint through the error text.
    if "RESOURCE_DOES_NOT_EXIST" not in str(e):
        raise
    print(f"Creating new endpoint {endpoint_name}...")
    client.create_endpoint(endpoint_name, endpoint_config)
else:
    print(f"Endpoint {endpoint_name} exists, updating configuration...")
    client.update_endpoint_config(endpoint_name, endpoint_config)

In [0]:
import requests
import json
import time

# `workspace_url` is only assigned by a later cell; derive it here from `host`
# (the workspace API URL read from the notebook context earlier) so this cell
# also works on a top-to-bottom run.
workspace_url = host

# NOTE(review): the endpoint is already created/updated through the MLflow
# deployment client in the previous cell, and this payload's "name" is the
# served-model name rather than `endpoint_name` — confirm this second request
# is still wanted.
response = requests.post(
    f"{workspace_url}/api/2.0/serving-endpoints",
    headers={"Authorization": f"Bearer {token}"},
    data=json.dumps(endpoint_config)
)

if response.status_code == 200:
    print("Serving endpoint created successfully.")
else:
    print(f"Failed to create serving endpoint: {response.text}")

# Check the status of the endpoint
endpoint_status_url = f"{workspace_url}/api/2.0/serving-endpoints/{endpoint_name}"

max_retries = 6  # Number of times to check (6 times for 1 hour)
retry_interval = 600  # 10 minutes in seconds

for attempt in range(max_retries):

    status_response = requests.get(
        endpoint_status_url,
        headers={"Authorization": f"Bearer {token}"}
    )
    if status_response.status_code != 200:
        print(f"Failed to get endpoint status: {status_response.text}")
        break

    status_payload = status_response.json()
    state = status_payload.get("state", {})
    # The API reports these as strings ("READY"/"NOT_READY",
    # "NOT_UPDATING"/"IN_PROGRESS"/"UPDATE_FAILED"/...), so they must be
    # compared by value: a bare truthiness test treats "NOT_READY" as ready.
    ready = state.get("ready", False)
    config_update = state.get("config_update", False)
    message = state.get("message", "No status message available")
    pending_config = status_payload.get('pending_config', {})
    # `or [{}]` also covers a present-but-empty served_entities list.
    served_entities = pending_config.get('served_entities') or [{}]
    served_entity_state = served_entities[0].get('state', 'Unknown')

    if ready == "READY" and config_update == "NOT_UPDATING":
        print("Serving endpoint is ready.")
        break

    if config_update in ("UPDATE_FAILED", "UPDATE_CANCELED"):
        # Terminal state: waiting longer cannot make the update succeed.
        print(state)
        print(f"Serving endpoint config update did not complete ({config_update}): {message}")
        break

    print(state)
    print(f"Serving endpoint state: {message}")
    print(f"Served entity state: {served_entity_state}")
    print("Waiting for the serving endpoint config_updates to be ready...")
    # No point sleeping after the final check.
    if attempt < max_retries - 1:
        time.sleep(retry_interval)  # Wait for 10 minutes before checking again
else:
    print("Endpoint config_update is still not ready after the maximum number of check retries. Please review the Serving UI to see if there are errors or issues.")

In [0]:
status_response = requests.get(
        endpoint_status_url,
        headers={"Authorization": f"Bearer {token}"}
    )

status_response.json()

In [0]:
state = status_response.json().get("state", {})
state
# ready = state.get("ready", False)

In [0]:
(status_response.json()['state'], 
 status_response.json()['state']['ready'],
 status_response.json()['pending_config']['served_entities'][0]['state']
 )

In [0]:
import json
import requests

# Workspace URL and API token, both read from the notebook context.
workspace_url = (
    dbutils.notebook.entry_point.getDbutils().notebook().getContext().apiUrl().get()
)
token = (
    dbutils.notebook.entry_point.getDbutils().notebook().getContext().apiToken().get()
)

# Invocation URL of the serving endpoint.
endpoint_url = "/".join([workspace_url, "serving-endpoints", endpoint_name, "invocations"])

# Request body with a single record in the `dataframe_records` format.
input_data = {
    "dataframe_records": [
        {"sequences": sequence}
        for sequence in ("MKTAYIAKQRQISFVKSHFSRQDILDLWIYHTQGYFPDWQNYG",)
    ]
}

response = requests.post(
    endpoint_url,
    data=json.dumps(input_data),
    headers={"Authorization": f"Bearer {token}", "Content-Type": "application/json"},
)

# Report the failure text for any non-200 answer, otherwise the prediction.
if response.status_code != 200:
    print(f"Failed to get prediction: {response.text}")
else:
    print("Prediction:", response.json())

In [0]:
input_data = {
    "dataframe_records": [
        {"sequences": "MKTAYIAKQRQISFVKSHFSRQDILDLWIYHTQGYFPDWQNYG"}
    ]
}

input_data

In [0]:
json.dumps(input_data)

In [0]:
# https://huggingface.co/blog/AmelieSchreiber/protein-optimization-and-design 

# MAPLRKTYVLKLYVAGNTPNSVRALKTLNNILEKEFKGVYALKVIDVLKNPQLAEEDKILATPTLAKVLPPPVRRIIGDLSNREKVLIGLDLLYEEIGDQAEDDLGLE

In [0]:
input_data2 = {
    "dataframe_records": [
        {"sequences": "MKTAYIAKQRQISFVKSHFSRQDILDLWIYHTQGYFPDWQNYG"},
        {"sequences": "MAPLRKTYVLKLYVAGNTPNSVRALKTLNNILEKEFKGVYALKVIDVLKNPQLAEEDKILATPTLAKVLPPPVRRIIGDLSNREKVLIGLDLLYEEIGDQAEDDLGLE"}
    ]
}

input_data2, json.dumps(input_data2)

In [0]:
import json
import requests

# Workspace URL and API token, both read from the notebook context.
workspace_url = (
    dbutils.notebook.entry_point.getDbutils().notebook().getContext().apiUrl().get()
)
token = (
    dbutils.notebook.entry_point.getDbutils().notebook().getContext().apiToken().get()
)

# Invocation URL of the serving endpoint.
endpoint_url = "/".join([workspace_url, "serving-endpoints", endpoint_name, "invocations"])

# Request body with one record per sequence in the `dataframe_records` format.
input_data = {
    "dataframe_records": [
        {"sequences": sequence}
        for sequence in (
            "MKTAYIAKQRQISFVKSHFSRQDILDLWIYHTQGYFPDWQNYG",
            "MAPLRKTYVLKLYVAGNTPNSVRALKTLNNILEKEFKGVYALKVIDVLKNPQLAEEDKILATPTLAKVLPPPVRRIIGDLSNREKVLIGLDLLYEEIGDQAEDDLGLE",
        )
    ]
}

response = requests.post(
    endpoint_url,
    data=json.dumps(input_data),
    headers={"Authorization": f"Bearer {token}", "Content-Type": "application/json"},
)

# Report the failure text for any non-200 answer, otherwise the predictions.
if response.status_code != 200:
    print(f"Failed to get prediction: {response.text}")
else:
    print("Prediction:", response.json())

In [0]:
# NOTE (author's observation): this route was actually slower.

import mlflow.pyfunc
from pyspark.sql.functions import col, flatten

# Wrap a fixed version of the registered model as a Spark UDF.
model_name = "mmt_demos.dependencies.esm_protein_model"
model_version = 8
model_uri = "models:/{}/{}".format(model_name, model_version)
model_udf = mlflow.pyfunc.spark_udf(spark, model_uri=model_uri)

# Two example sequences, one per row, in a single "sequences" column.
df = spark.createDataFrame(
    [
        (sequence,)
        for sequence in (
            "MKTAYIAKQRQISFVKSHFSRQDILDLWIYHTQGYFPDWQNYG",
            "MAPLRKTYVLKLYVAGNTPNSVRALKTLNNILEKEFKGVYALKVIDVLKNPQLAEEDKILATPTLAKVLPPPVRRIIGDLSNREKVLIGLDLLYEEIGDQAEDDLGLE",
        )
    ],
    schema=["sequences"],
)

# Score each row and flatten the nested prediction array by one level.
raw_predictions = model_udf(col("sequences"))
df_result = df.withColumn("predictions", flatten(raw_predictions))

display(df_result)

In [0]:
endpoint_name
# esm_protein_model_endpoint-mmt-mlflowsdk

In [0]:
import os
import requests
import numpy as np
import pandas as pd
import json

def get_databricks_token():
    """Return the API token exposed by the current notebook context."""
    notebook_context = dbutils.notebook.entry_point.getDbutils().notebook().getContext()
    return notebook_context.apiToken().get()

def create_tf_serving_json(data):
    """Wrap *data* in an ``{'inputs': ...}`` request body.

    A ``dict`` has every value converted with ``.tolist()`` under its original
    key; any other object is converted with its own ``.tolist()``.
    """
    if not isinstance(data, dict):
        return {'inputs': data.tolist()}
    return {'inputs': {key: value.tolist() for key, value in data.items()}}

def score_model(dataset, workspace_url='https://e2-demo-field-eng.cloud.databricks.com'):
    """Send *dataset* to the serving endpoint and return the decoded response.

    Args:
        dataset: A pandas DataFrame (sent as ``dataframe_split``) or any input
            accepted by ``create_tf_serving_json``.
        workspace_url: Base URL of the Databricks workspace hosting the
            endpoint. Defaults to the workspace this notebook was written
            against; pass another URL to score against a different workspace.

    Returns:
        The parsed JSON body of the response.

    Raises:
        Exception: If the endpoint answers with a non-200 status code.
    """
    # Tolerate a trailing slash on the base URL.
    base_url = workspace_url.rstrip('/')
    url = f'{base_url}/serving-endpoints/{endpoint_name}/invocations'
    token = get_databricks_token()
    headers = {'Authorization': f'Bearer {token}', 'Content-Type': 'application/json'}
    if isinstance(dataset, pd.DataFrame):
        ds_dict = {'dataframe_split': dataset.to_dict(orient='split')}
    else:
        ds_dict = create_tf_serving_json(dataset)
    data_json = json.dumps(ds_dict, allow_nan=True)
    response = requests.request(method='POST', headers=headers, url=url, data=data_json)
    if response.status_code != 200:
        raise Exception(f'Request failed with status {response.status_code}, {response.text}')
    return response.json()

In [0]:
import pandas as pd
from pyspark.sql.functions import col

# Two example sequences, one per row, in a single "sequences" column.
df = spark.createDataFrame(
    [
        (sequence,)
        for sequence in (
            "MKTAYIAKQRQISFVKSHFSRQDILDLWIYHTQGYFPDWQNYG",
            "MAPLRKTYVLKLYVAGNTPNSVRALKTLNNILEKEFKGVYALKVIDVLKNPQLAEEDKILATPTLAKVLPPPVRRIIGDLSNREKVLIGLDLLYEEIGDQAEDDLGLE",
        )
    ],
    schema=["sequences"],
)

# Collect to pandas so the whole frame is scored in one endpoint request.
df_pd = df.toPandas()
model_score = score_model(df_pd)

# Turn the returned predictions back into a Spark DataFrame and show it.
predictions_df = pd.DataFrame(model_score['predictions'])
scored_df = spark.createDataFrame(predictions_df)
display(scored_df)

In [0]:
df_pd

In [0]:
ds_dict = {'dataframe_split': df_pd.to_dict(orient='split')} if isinstance(df_pd, pd.DataFrame) else create_tf_serving_json(df_pd)

In [0]:
ds_dict

In [0]:
data_json = json.dumps(ds_dict, allow_nan=True)

In [0]:
data_json

In [0]:
import os
import requests
import numpy as np
import pandas as pd
import json
from pyspark.sql.functions import pandas_udf, col, explode
from pyspark.sql.types import ArrayType, FloatType, StringType, StructType, StructField

# Set your Personal Access Token (PAT) here
PAT_TOKEN = dbutils.secrets.get(scope="mmt", key="databricks_token")

def create_tf_serving_json(data):
    """Build the ``{'inputs': ...}`` request body for *data*.

    For a ``dict`` each value is replaced by its ``.tolist()`` result; any
    other object is converted as a whole with ``.tolist()``.
    """
    if isinstance(data, dict):
        payload = {}
        for name in data:
            payload[name] = data[name].tolist()
    else:
        payload = data.tolist()
    return {'inputs': payload}

def score_model(dataset, workspace_url='https://e2-demo-field-eng.cloud.databricks.com'):
    """Send *dataset* to the serving endpoint and return its predictions.

    Authenticates with the module-level ``PAT_TOKEN``.

    Args:
        dataset: A pandas DataFrame (sent as ``dataframe_split``) or any input
            accepted by ``create_tf_serving_json``.
        workspace_url: Base URL of the Databricks workspace hosting the
            endpoint. Defaults to the workspace this notebook was written
            against; pass another URL to score against a different workspace.

    Returns:
        The ``'predictions'`` entry of the parsed JSON response.

    Raises:
        Exception: If the endpoint answers with a non-200 status code.
    """
    # Tolerate a trailing slash on the base URL.
    base_url = workspace_url.rstrip('/')
    url = f'{base_url}/serving-endpoints/{endpoint_name}/invocations'
    token = PAT_TOKEN
    headers = {'Authorization': f'Bearer {token}', 'Content-Type': 'application/json'}
    if isinstance(dataset, pd.DataFrame):
        ds_dict = {'dataframe_split': dataset.to_dict(orient='split')}
    else:
        ds_dict = create_tf_serving_json(dataset)
    data_json = json.dumps(ds_dict, allow_nan=True)
    response = requests.request(method='POST', headers=headers, url=url, data=data_json)
    if response.status_code != 200:
        raise Exception(f'Request failed with status {response.status_code}, {response.text}')
    return response.json()['predictions']

# Define the schema for the output
schema = StructType([
    StructField("sequences", StringType(), True),
    StructField("predictions", ArrayType(FloatType()), True)
])

# Define the Pandas UDF
# Define the Pandas UDF
@pandas_udf(schema)
def score_model_udf(sequences: pd.Series) -> pd.DataFrame:
    """Score each sequence through the serving endpoint, one request per row.

    Args:
        sequences: Batch of sequence strings.

    Returns:
        A DataFrame with a ``sequences`` column (the input) and a
        ``predictions`` column (a flat list of float32-rounded values).
    """
    results = []
    for sequence in sequences:
        # One-row request; keep the first (only) prediction of the response.
        predictions = score_model(pd.DataFrame({"sequences": [sequence]}))[0]
        # Flatten one level when the prediction is a list of lists. Otherwise
        # use it as-is: without this branch `scores` would be unbound on the
        # first row, or silently reuse the previous row's values afterwards.
        # The emptiness check avoids indexing into an empty prediction.
        if predictions and isinstance(predictions[0], list):
            scores = [item for sublist in predictions for item in sublist]
        else:
            scores = predictions
        scores = np.array(scores).astype(np.float32).tolist()  # Convert to float32 and then to list
        results.append({"sequences": sequence, "predictions": scores})
    # Name the columns explicitly so an empty batch still has both of them.
    return pd.DataFrame(results, columns=["sequences", "predictions"])

# Two example sequences, one per row, in a single "sequences" column.
df = spark.createDataFrame(
    [
        (sequence,)
        for sequence in (
            "MKTAYIAKQRQISFVKSHFSRQDILDLWIYHTQGYFPDWQNYG",
            "MAPLRKTYVLKLYVAGNTPNSVRALKTLNNILEKEFKGVYALKVIDVLKNPQLAEEDKILATPTLAKVLPPPVRRIIGDLSNREKVLIGLDLLYEEIGDQAEDDLGLE",
        )
    ],
    schema=["sequences"],
)

# Attach the UDF output as a new "scores" column.
udf_output = score_model_udf(col("sequences"))
scored_df = df.withColumn("scores", udf_output)


In [0]:
## quite fast

display(scored_df)

In [0]:
display(scored_df.select('scores.*'))