In [None]:
%pip install transformers mlflow torch
# Restart the Python process via the Databricks `dbutils` utility right after the install
# (presumably so the freshly installed packages are the ones imported by the cells below).
# NOTE(review): the three packages above are installed without version pins — consider
# pinning them so the notebook is reproducible.
dbutils.library.restartPython()

In [0]:
import pandas as pd
import requests
import json
import mlflow
from mlflow.models import infer_signature
from mlflow.transformers import generate_signature_output
from mlflow.tracking import MlflowClient

In [0]:
# Point the MLflow model registry at "databricks-uc" (presumably Unity Catalog, which matches
# the three-part catalog.schema.model name this notebook registers models under).
mlflow.set_registry_uri("databricks-uc")

In [None]:
# Catalog, schema and model name that together form the fully qualified
# (catalog.schema.model) name the model is registered under.
uc_catalog, uc_schema, uc_model_name = "dbdemos", "default", "clinicalbert"
registered_model_name = ".".join((uc_catalog, uc_schema, uc_model_name))

In [0]:
# One masked sentence serves as the example input for signature inference.
input_example = ["The patient is suffering from a mild [MASK]"]

# NOTE(review): `result` is not assigned in any cell visible above this one — confirm the
# cell that produces it still exists, otherwise this cell fails on a fresh kernel.
signature = infer_signature(input_example, result)

# Visualize the signature
signature


In [None]:
# Example input for the MLflow signature: the first three ICD10 strings.
# NOTE(review): `icd10_examples` is first assigned in a later cell of this notebook, and
# `extract_embeddings` is not defined in any visible cell — verify cell order on a fresh kernel.
input_example = icd10_examples[:3]

# Run the example through the embedding helper to obtain a matching sample output.
sample_output = extract_embeddings(input_example)

# Derive the model signature from the example input/output pair.
signature = infer_signature(input_example, sample_output)

# Show the signature together with a short description of the input/output types.
print(
    "MLflow Model Signature:",
    signature,
    "\nInput type: List of strings",
    "Output type: numpy.ndarray of shape [batch_size, 768]",
    sep="\n",
)

In [0]:
# Log the transformers pipeline inside a new MLflow run and register it in one step.
# NOTE(review): `fillmaskpipeline` is not defined in any visible cell, and this registers under
# the same `registered_model_name` as the pyfunc embeddings model logged further down —
# confirm this cell is still meant to run.
with mlflow.start_run():
    model_info = mlflow.transformers.log_model(
        transformers_model=fillmaskpipeline,
        artifact_path="fill_mask_generator",
        input_example=input_example,
        signature=signature,
        registered_model_name=registered_model_name,
    )

# Keep the URI and the registered version of what was just logged for later cells.
model_uri, registered_model_version = (
    model_info.model_uri,
    model_info.registered_model_version,
)

print(f"Model URI is : {model_uri}")
print(f"Registered Model version: {registered_model_version}")


In [None]:
# Sample ICD10 codes, each paired with its textual description.
icd10_examples = [
    "ICD10: E11.9 - Type 2 diabetes mellitus without complications",
    "ICD10: I10 - Essential (primary) hypertension",
    "ICD10: J44.1 - Chronic obstructive pulmonary disease with (acute) exacerbation",
    "ICD10: I50.9 - Heart failure, unspecified",
    "ICD10: N18.3 - Chronic kidney disease, stage 3 (moderate)",
    "ICD10: J45.909 - Unspecified asthma, uncomplicated",
    "ICD10: I21.9 - Acute myocardial infarction, unspecified",
]

# Embed every example in one call; the result is indexed per example below.
icd10_embeddings = extract_embeddings(icd10_examples)

separator = "=" * 80

print(f"Generated embeddings for {len(icd10_examples)} ICD10 codes")
print(f"Embeddings shape: {icd10_embeddings.shape}")
print(f"Embedding dimensions: {icd10_embeddings.shape[1]}")
print("\n" + separator)
print("ICD10 Code Examples and Their Embeddings:")
print(separator)

# Per-example summary: shape, basic statistics and the leading values of each vector.
for position, (text, embedding) in enumerate(zip(icd10_examples, icd10_embeddings), start=1):
    print(f"\n{position}. {text}")
    print(f"   Embedding shape: {embedding.shape}")
    print(
        f"   Embedding stats - Mean: {embedding.mean():.4f}, Std: {embedding.std():.4f}, "
        f"Min: {embedding.min():.4f}, Max: {embedding.max():.4f}"
    )
    print(f"   First 10 values: {embedding[:10]}")

In [None]:
import tempfile
import os

# Write the model and tokenizer to a throwaway directory and log them as an MLflow pyfunc model.
# The directory is removed when the outer `with` exits, so logging happens inside it.
with tempfile.TemporaryDirectory() as tmp_dir:
    model_save_path = os.path.join(tmp_dir, "clinicalbert")

    # Both the model and the tokenizer are saved into the same folder.
    for component in (model, tokenizer):
        component.save_pretrained(model_save_path)

    # Log the pyfunc wrapper, pointing its "model_path" artifact at the saved folder,
    # and register it under the Unity Catalog name in the same call.
    with mlflow.start_run() as run:
        model_info = mlflow.pyfunc.log_model(
            artifact_path="clinicalbert_embeddings",
            python_model=ClinicalBERTEmbeddings(),
            artifacts={"model_path": model_save_path},
            input_example=input_example,
            signature=signature,
            registered_model_name=registered_model_name,
            pip_requirements=["transformers", "torch", "numpy"],
        )

# Keep the URI and the registered version for the loading / serving cells.
model_uri, registered_model_version = (
    model_info.model_uri,
    model_info.registered_model_version,
)

print(f"Model URI: {model_uri}")
print(f"Registered Model Version: {registered_model_version}")
print(f"Model registered in Unity Catalog as: {registered_model_name}")

In [None]:
import pandas as pd

# Load the registered model version back through the `models:/` URI scheme.
model_version_uri = f"models:/{registered_model_name}/{registered_model_version}"
loaded_model = mlflow.pyfunc.load_model(model_version_uri)

# ICD10 strings used to exercise the loaded model.
test_icd10_codes = [
    "ICD10: E11.9 - Type 2 diabetes mellitus without complications",
    "ICD10: I10 - Essential (primary) hypertension",
    "ICD10: J44.1 - Chronic obstructive pulmonary disease with (acute) exacerbation",
    "ICD10: I50.9 - Heart failure, unspecified",
    "ICD10: N18.3 - Chronic kidney disease, stage 3 (moderate)",
]

# Run the loaded model on the test strings.
embeddings = loaded_model.predict(test_icd10_codes)

rule = "=" * 80

print(f"Generated embeddings for {len(test_icd10_codes)} ICD10 codes")
print(f"Embeddings shape: {embeddings.shape}")
print(f"Embedding dimensions per code: {embeddings.shape[1]}")
print("\n" + rule)
print("ICD10 Code Embeddings from Unity Catalog Model:")
print(rule)

# Per-code summary of the returned vectors.
for position, (code, embedding) in enumerate(zip(test_icd10_codes, embeddings), start=1):
    print(f"\n{position}. {code}")
    print(f"   Shape: {embedding.shape}")
    print(f"   Stats - Mean: {embedding.mean():.4f}, Std: {embedding.std():.4f}")
    print(f"   Range: [{embedding.min():.4f}, {embedding.max():.4f}]")
    print(f"   First 5 values: {embedding[:5]}")

# Tabular view of the embeddings: one row per test code, labelled ICD10_1..ICD10_n.
row_labels = [f"ICD10_{n}" for n in range(1, len(test_icd10_codes) + 1)]
embeddings_df = pd.DataFrame(embeddings, index=row_labels)
print(f"\nEmbeddings DataFrame shape: {embeddings_df.shape}")
embeddings_df.head()

In [0]:
from mlflow.deployments import get_deploy_client

# Create a deployment client
client = get_deploy_client("databricks")

endpoint_name = "clinicalbert_endpoint"

# Give the served entity an explicit name and reuse it in the traffic route.
# The route's `served_model_name` has to match a served entity's name; setting it
# explicitly avoids depending on how the platform derives a default name from the
# three-part Unity Catalog model name.
served_entity_name = f"{uc_model_name}-{registered_model_version}"

# Create the serving endpoint with a single served entity that receives all traffic.
# NOTE(review): this call is not idempotent — re-running the cell while the endpoint
# already exists is expected to fail; confirm whether an update path is wanted.
endpoint = client.create_endpoint(
    name=endpoint_name,
    config={
        "served_entities": [
            {
                "name": served_entity_name,
                "entity_name": registered_model_name,
                "entity_version": str(registered_model_version),
                "workload_size": "Small",
                "scale_to_zero_enabled": True
            }
        ],
        "traffic_config": {
            "routes": [
                {
                    "served_model_name": served_entity_name,
                    "traffic_percentage": 100
                }
            ]
        }
    }
)




In [0]:
import time

# Function to wait until the endpoint is ready
def wait_for_endpoint_ready(client, endpoint_name, timeout=1200, interval=60):
    """Poll a serving endpoint until it reports ready, fails, or the timeout expires.

    Args:
        client: Object exposing ``get_endpoint(name)`` that returns a dict-like
            description of the endpoint containing a ``state`` mapping.
        endpoint_name: Name of the endpoint to poll.
        timeout: Maximum number of seconds to wait overall.
        interval: Number of seconds to sleep between polls.

    Returns:
        None, once ``state['ready']`` equals ``'READY'``.

    Raises:
        RuntimeError: If the endpoint reports a failed deployment, i.e.
            ``state['ready'] == 'FAILED'`` or
            ``state['config_update'] == 'UPDATE_FAILED'``.
        TimeoutError: If the endpoint is not ready within ``timeout`` seconds.
    """
    # A monotonic clock keeps the deadline immune to wall-clock adjustments.
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        endpoint_info = client.get_endpoint(endpoint_name)
        # A missing or null `state` is treated as "not ready yet" rather than
        # crashing on an attempt to index a placeholder string.
        state = endpoint_info.get('state') or {}
        endpoint_state = state.get('ready', 'UNKNOWN')
        config_update = state.get('config_update', 'UNKNOWN')

        if endpoint_state == 'READY':
            print(f"Endpoint {endpoint_name} is ready.")
            return
        # Failure can be reported through the config-update status while `ready`
        # still says not-ready; check both so a failed deployment surfaces
        # immediately instead of spinning until the timeout.
        if endpoint_state == 'FAILED' or config_update == 'UPDATE_FAILED':
            raise RuntimeError(f"Endpoint {endpoint_name} creation failed.")

        print(f"Endpoint {endpoint_name} is in state {endpoint_state}. Waiting...")
        # Never sleep past the deadline.
        time.sleep(max(0.0, min(interval, deadline - time.monotonic())))
    raise TimeoutError(f"Timeout while waiting for endpoint {endpoint_name} to be ready.")

# Wait for the endpoint to be ready
# This blocks the cell until the helper returns or raises; no overrides are passed,
# so the helper's default timeout and polling interval apply.
wait_for_endpoint_ready(client, endpoint_name)

In [None]:
import numpy as np

# ICD10 strings sent to the serving endpoint for embedding generation.
icd10_queries = [
    "ICD10: E11.9 - Type 2 diabetes mellitus without complications",
    "ICD10: I10 - Essential (primary) hypertension",
    "ICD10: J44.1 - Chronic obstructive pulmonary disease with (acute) exacerbation",
    "ICD10: I50.9 - Heart failure, unspecified",
    "ICD10: N18.3 - Chronic kidney disease, stage 3 (moderate)",
    "ICD10: J45.909 - Unspecified asthma, uncomplicated",
    "ICD10: I21.9 - Acute myocardial infarction, unspecified",
]

# Query the endpoint; the strings travel under the "inputs" key of the payload.
response = client.predict(endpoint=endpoint_name, inputs={"inputs": icd10_queries})

# Turn the "predictions" field of the response into an array for inspection.
embeddings_from_endpoint = np.array(response['predictions'])

divider = "=" * 80

print("Received embeddings from serving endpoint")
print(f"Embeddings shape: {embeddings_from_endpoint.shape}")
print(f"Embedding dimensions: {embeddings_from_endpoint.shape[1]}")
print("\n" + divider)
print("Serving Endpoint Response - ICD10 Embeddings:")
print(divider)

# Per-vector summary of what the endpoint returned.
for position, embedding in enumerate(embeddings_from_endpoint, start=1):
    print(f"\nICD10 Code {position}:")
    print(f"   Embedding shape: {embedding.shape}")
    print(f"   Stats - Mean: {embedding.mean():.4f}, Std: {embedding.std():.4f}")
    print(f"   Range: [{embedding.min():.4f}, {embedding.max():.4f}]")