In [0]:
%pip install transformers mlflow torch mlflow[databricks]
# Install the libraries this notebook needs: transformers (model/tokenizer), torch (inference),
# and mlflow with its Databricks extra (logging / Unity Catalog registration).
# NOTE(review): versions are unpinned, so a re-run may pick up newer, possibly incompatible releases;
# `mlflow` is also listed twice (plain and with the [databricks] extra) — the extra alone suffices.
# Restart the Python process so later cells import the freshly installed packages.
dbutils.library.restartPython()

In [0]:
import pandas as pd
import torch
import numpy as np
import mlflow
from mlflow.models import infer_signature
from transformers import AutoModel, AutoTokenizer

In [0]:
# Point the MLflow model registry at Unity Catalog so registered models live under <catalog>.<schema>.<name>
REGISTRY_URI = "databricks-uc"
mlflow.set_registry_uri(REGISTRY_URI)

In [0]:
# Unity Catalog location of the registered model, assembled as "<catalog>.<schema>.<model>"
uc_catalog, uc_schema, uc_model_name = "marcin_demo", "default", "clinicalbert_embeddings"
registered_model_name = ".".join((uc_catalog, uc_schema, uc_model_name))

In [0]:
# Fetch the ClinicalBERT checkpoint and its matching tokenizer from the Hugging Face Hub.
# `.eval()` switches the model to evaluation mode (dropout off) and returns the same model object.
model_name = "medicalai/ClinicalBERT"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name).eval()

print(
    f"Loaded {model_name}",
    f"Model hidden size (embedding dimensions): {model.config.hidden_size}",
    sep="\n",
)

In [0]:
# Function to extract embeddings using CLS token approach
# Background reading: https://www.researchgate.net/publication/332342631_ClinicalBERT_Modeling_Clinical_Notes_and_Predicting_Hospital_Readmission
# NOTE(review): that paper describes *a* ClinicalBERT model; confirm it is the same checkpoint as the
# `medicalai/ClinicalBERT` model loaded above before relying on it for details.
def extract_embeddings(texts, batch_size=32, max_length=512):
    """
    Extract CLS-token embeddings using the notebook-level ClinicalBERT `model` and `tokenizer`.

    Args:
        texts: List of strings or single string.
        batch_size: Number of texts per forward pass; bounds peak memory for long lists.
        max_length: Maximum tokens per text; longer texts are truncated.

    Returns:
        numpy array of shape [len(texts), model.config.hidden_size] containing one
        embedding per input text, in input order. An empty input yields a
        (0, hidden_size) float32 array.

    Raises:
        ValueError: if batch_size is smaller than 1.
    """
    if isinstance(texts, str):
        texts = [texts]
    else:
        texts = list(texts)

    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    # The tokenizer/model cannot process an empty batch; return an empty, correctly-shaped result.
    if not texts:
        return np.empty((0, model.config.hidden_size), dtype=np.float32)

    # Send inputs to wherever the model lives (CPU or GPU) to avoid a device-mismatch error.
    device = model.device

    batch_embeddings = []
    for start in range(0, len(texts), batch_size):
        # Tokenize one batch of input texts
        inputs = tokenizer(
            texts[start:start + batch_size],
            padding=True,
            truncation=True,
            max_length=max_length,
            return_tensors="pt",
        )
        inputs = {name: tensor.to(device) for name, tensor in inputs.items()}

        # Get model outputs
        with torch.no_grad():
            outputs = model(**inputs)

        # Position 0 of the last hidden state is the [CLS] token
        batch_embeddings.append(outputs.last_hidden_state[:, 0, :].cpu().numpy())

    return np.concatenate(batch_embeddings, axis=0)

# Smoke-test with a simple clinical text: one input string should yield exactly one embedding row
test_text = "The patient was diagnosed with type 2 diabetes and prescribed medication."
test_embedding = extract_embeddings(test_text)

print(f"Test embedding shape: {test_embedding.shape}")
print(f"Embedding dimensions: {test_embedding.shape[1]}")

# Fail fast (instead of only printing) if CLS extraction does not give one row of hidden_size values
assert test_embedding.shape == (1, model.config.hidden_size), test_embedding.shape

In [0]:
# Sample ICD10 codes paired with their human-readable descriptions
icd10_examples = [
    "ICD10: E11.9 - Type 2 diabetes mellitus without complications",
    "ICD10: I10 - Essential (primary) hypertension",
    "ICD10: J44.1 - Chronic obstructive pulmonary disease with (acute) exacerbation",
]

# Embed all examples in a single call; row k of the result corresponds to icd10_examples[k]
icd10_embeddings = extract_embeddings(icd10_examples)

divider = "=" * 80
print(f"Generated embeddings for {len(icd10_examples)} ICD10 codes")
print(f"Embeddings shape: {icd10_embeddings.shape}")
print(f"Embedding dimensions: {icd10_embeddings.shape[1]}")
print("\n" + divider)
print("ICD10 Code Examples and Their Embeddings:")
print(divider)

# Per-example summary: shape, basic statistics, and a peek at the leading values
for rank, (text, embedding) in enumerate(zip(icd10_examples, icd10_embeddings), start=1):
    summary = (embedding.mean(), embedding.std(), embedding.min(), embedding.max())
    print(f"\n{rank}. {text}")
    print(f"   Embedding shape: {embedding.shape}")
    print("   Stats - Mean: {:.4f}, Std: {:.4f}, Min: {:.4f}, Max: {:.4f}".format(*summary))
    print(f"   First 10 values: {embedding[:10]}")

In [0]:
# Define custom MLflow PyFunc model for ClinicalBERT embeddings
class ClinicalBERTEmbeddings(mlflow.pyfunc.PythonModel):
    """
    Custom PyFunc wrapper for ClinicalBERT embedding extraction.
    Compatible with Databricks Model Serving.

    Each input text is tokenized, run through the saved model, and represented by
    the last-layer hidden state of its first ([CLS]) token — the same pooling the
    notebook's `extract_embeddings` helper uses.
    """

    # Maximum tokens per text; longer texts are truncated.
    max_length = 512
    # Texts per forward pass; bounds peak memory for large scoring requests.
    batch_size = 32

    def load_context(self, context):
        """
        Load model and tokenizer from artifacts when the serving endpoint initializes.

        The model is placed on a CUDA device when one is available, otherwise on CPU.
        """
        import torch
        from transformers import AutoModel, AutoTokenizer

        model_path = context.artifacts["model_path"]
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        self.model = AutoModel.from_pretrained(model_path).to(self.device)
        self.model.eval()

    @staticmethod
    def _to_texts(model_input):
        """Normalize a DataFrame / Series / ndarray / list / tuple / scalar input into a list of texts."""
        import numpy as np
        import pandas as pd

        if isinstance(model_input, pd.DataFrame):
            # Column-based inputs: texts are taken from the first column.
            return model_input.iloc[:, 0].tolist()
        if isinstance(model_input, pd.Series):
            return model_input.tolist()
        if isinstance(model_input, np.ndarray):
            # A 2-D array mirrors the DataFrame case (first column); anything else is flattened.
            column = model_input[:, 0] if model_input.ndim == 2 else model_input.ravel()
            return column.tolist()
        if isinstance(model_input, (list, tuple)):
            return list(model_input)
        # Any other single value is treated as one text.
        return [str(model_input)]

    def predict(self, context, model_input, params=None):
        """
        Extract embeddings from input texts.

        Args:
            context: MLflow context (unused; the model is loaded in load_context).
            model_input: pandas DataFrame (texts in the first column), pandas Series,
                numpy array, list/tuple of strings, or a single value.
            params: Optional inference parameters (accepted for MLflow compatibility, unused).

        Returns:
            numpy array of shape [n_texts, hidden_size] containing one embedding per
            input text, in input order; an empty input yields a (0, hidden_size) array.
        """
        import numpy as np
        import torch

        texts = self._to_texts(model_input)

        # The tokenizer/model cannot process an empty batch; answer with an empty, correctly-shaped array.
        if not texts:
            return np.empty((0, self.model.config.hidden_size), dtype=np.float32)

        batch_embeddings = []
        for start in range(0, len(texts), self.batch_size):
            # Tokenize one batch of inputs
            inputs = self.tokenizer(
                texts[start:start + self.batch_size],
                padding=True,
                truncation=True,
                max_length=self.max_length,
                return_tensors="pt",
            )
            inputs = {name: tensor.to(self.device) for name, tensor in inputs.items()}

            # Generate embeddings
            with torch.no_grad():
                outputs = self.model(**inputs)

            # Position 0 of the last hidden state is the [CLS] token
            batch_embeddings.append(outputs.last_hidden_state[:, 0, :].cpu().numpy())

        return np.concatenate(batch_embeddings, axis=0)

print("ClinicalBERTEmbeddings PyFunc model defined successfully")

In [0]:
# Build the MLflow signature from a real round trip: a list of ICD10 description strings in,
# the matching embedding matrix out.
input_example = icd10_examples[:3]
sample_output = extract_embeddings(input_example)
signature = infer_signature(input_example, sample_output)

print("MLflow Model Signature:", signature, sep="\n")

In [0]:
import tempfile
import os

import transformers

# Pin the serving environment to the library versions used in this session, so the endpoint
# loads the saved weights/tokenizer with the same code that produced them. Unpinned names let
# the serving container resolve whatever release is newest.
# torch may report a local build label (e.g. "2.x.y+cu121"); pip cannot resolve local labels
# from PyPI, so only the public version part is kept.
serving_requirements = [
    f"transformers=={transformers.__version__}",
    f"torch=={torch.__version__.split('+')[0]}",
    f"numpy=={np.__version__}",
]

# Save model and tokenizer to a temporary directory for MLflow artifact logging.
# log_model runs inside the `with` block so the files still exist while MLflow copies them.
with tempfile.TemporaryDirectory() as tmp_dir:
    model_save_path = os.path.join(tmp_dir, "clinicalbert")

    # Save model and tokenizer
    model.save_pretrained(model_save_path)
    tokenizer.save_pretrained(model_save_path)

    # Log the model with MLflow and register it in Unity Catalog in the same call.
    # NOTE(review): newer MLflow releases prefer `name=` over `artifact_path=`; kept for compatibility.
    with mlflow.start_run() as run:
        model_info = mlflow.pyfunc.log_model(
            artifact_path="clinicalbert_embeddings",
            python_model=ClinicalBERTEmbeddings(),
            artifacts={"model_path": model_save_path},
            input_example=input_example,
            signature=signature,
            registered_model_name=registered_model_name,
            pip_requirements=serving_requirements,
        )

model_uri = model_info.model_uri
registered_model_version = model_info.registered_model_version

print(f"Model URI: {model_uri}")
print(f"Registered Model Version: {registered_model_version}")
print(f"Model registered in Unity Catalog as: {registered_model_name}")
print(f"Pinned serving requirements: {serving_requirements}")

In [0]:
from mlflow.models import validate_serving_input

# Check that a JSON request body of the shape a serving endpoint would receive
# ({"inputs": [<text>, ...]}) is accepted by the logged model at `model_uri` before deploying.
serving_payload = """{
  "inputs": [
    "ICD10: E11.9 - Type 2 diabetes mellitus without complications",
    "ICD10: I10 - Essential (primary) hypertension"
  ]
}"""

validation_result = validate_serving_input(model_uri, serving_payload)
print("Serving payload validation successful", validation_result, sep="\n")

In [0]:
# Load and test the model from Unity Catalog
loaded_model = mlflow.pyfunc.load_model(f"models:/{registered_model_name}/{registered_model_version}")

# Test with ICD10 code examples; the first three are the same examples embedded locally earlier
test_icd10_codes = [
    "ICD10: E11.9 - Type 2 diabetes mellitus without complications",
    "ICD10: I10 - Essential (primary) hypertension",
    "ICD10: J44.1 - Chronic obstructive pulmonary disease with (acute) exacerbation",
    "ICD10: I50.9 - Heart failure, unspecified",
    "ICD10: N18.3 - Chronic kidney disease, stage 3 (moderate)"
]

embeddings = loaded_model.predict(test_icd10_codes)

# The registered model must reproduce the in-notebook embeddings for the shared examples.
# The tolerance absorbs floating-point differences from different batch/padding shapes
# (or a different device) while still catching a wrong model or pooling strategy.
n_shared = len(icd10_examples)
assert test_icd10_codes[:n_shared] == icd10_examples, "shared examples drifted out of sync"
assert embeddings.shape == (len(test_icd10_codes), icd10_embeddings.shape[1]), embeddings.shape
np.testing.assert_allclose(embeddings[:n_shared], icd10_embeddings, rtol=1e-3, atol=1e-3)

print(f"Generated embeddings for {len(test_icd10_codes)} ICD10 codes")
print(f"Embeddings shape: {embeddings.shape}")
print(f"Embedding dimensions: {embeddings.shape[1]}")
print("\n" + "="*80)
print("ICD10 Code Embeddings from Unity Catalog Model:")
print("="*80)

for i, (code, embedding) in enumerate(zip(test_icd10_codes, embeddings)):
    print(f"\n{i+1}. {code}")
    print(f"   Shape: {embedding.shape}")
    print(f"   Stats - Mean: {embedding.mean():.4f}, Std: {embedding.std():.4f}")
    print(f"   Range: [{embedding.min():.4f}, {embedding.max():.4f}]")
    print(f"   First 5 values: {embedding[:5]}")

# # Display embeddings as DataFrame
# embeddings_df = pd.DataFrame(
#     embeddings,
#     index=[f"ICD10_{i+1}" for i in range(len(test_icd10_codes))]
# )
# print(f"\nEmbeddings DataFrame shape: {embeddings_df.shape}")
# embeddings_df.head()

In [0]:
# from mlflow.deployments import get_deploy_client

# # Create a deployment client
# client = get_deploy_client("databricks")
# endpoint_name = "clinicalbert_embeddings_endpoint"

# endpoint = client.create_endpoint(
#     name=endpoint_name,
#     config={
#         "served_entities": [
#             {
#                 "entity_name": f"{registered_model_name}",
#                 "entity_version": f"{registered_model_version}",
#                 "workload_size": "Small",
#                 "scale_to_zero_enabled": True
#             }
#         ],
#         "traffic_config": {
#             "routes": [
#                 {
#                     "served_model_name": f"{uc_model_name}-{registered_model_version}",
#                     "traffic_percentage": 100
#                 }
#             ]
#         }
#     }
# )

# print(f"Endpoint {endpoint_name} created")

In [0]:
# import time

# # Wait for endpoint to be ready
# def wait_for_endpoint_ready(client, endpoint_name, timeout=1200, interval=60):
#     start_time = time.time()
#     while time.time() - start_time < timeout:
#         endpoint_info = client.get_endpoint(endpoint_name)
#         endpoint_state = endpoint_info.get('state', 'UNKNOWN')['ready']
#         if endpoint_state == 'READY':
#             print(f"Endpoint {endpoint_name} is ready.")
#             return
#         elif endpoint_state == 'FAILED':
#             raise Exception(f"Endpoint {endpoint_name} creation failed.")
#         else:
#             print(f"Endpoint {endpoint_name} is in state {endpoint_state}. Waiting...")
#             time.sleep(interval)
#     raise TimeoutError(f"Timeout while waiting for endpoint {endpoint_name} to be ready.")

# wait_for_endpoint_ready(client, endpoint_name)

In [0]:
# # Test the serving endpoint
# response = client.predict(
#     endpoint=endpoint_name,
#     inputs={"inputs": [
#         "ICD10: E11.9 - Type 2 diabetes mellitus without complications",
#         "ICD10: I10 - Essential (primary) hypertension",
#         "ICD10: J44.1 - Chronic obstructive pulmonary disease with (acute) exacerbation",
#         "ICD10: I50.9 - Heart failure, unspecified",
#         "ICD10: N18.3 - Chronic kidney disease, stage 3 (moderate)",
#         "ICD10: J45.909 - Unspecified asthma, uncomplicated",
#         "ICD10: I21.9 - Acute myocardial infarction, unspecified"
#     ]}
# )

# # Process endpoint response
# embeddings_from_endpoint = np.array(response['predictions'])

# print(f"Received embeddings from serving endpoint")
# print(f"Embeddings shape: {embeddings_from_endpoint.shape}")
# print(f"Embedding dimensions: {embeddings_from_endpoint.shape[1]}")
# print("\n" + "="*80)
# print("Serving Endpoint Response - ICD10 Embeddings:")
# print("="*80)

# for i, embedding in enumerate(embeddings_from_endpoint):
#     print(f"\nICD10 Code {i+1}:")
#     print(f"   Embedding shape: {embedding.shape}")
#     print(f"   Stats - Mean: {embedding.mean():.4f}, Std: {embedding.std():.4f}")
#     print(f"   Range: [{embedding.min():.4f}, {embedding.max():.4f}]")