In [0]:
import mlflow
from mlflow.pyfunc.model import PythonModelContext
from typing import Optional, Dict, List, Any, Union

import pandas as pd
from pyspark.sql.functions import pandas_udf
from pyspark.sql.types import ArrayType, FloatType, StringType
import numpy as np
import mlflow.deployments


In [0]:
mlflow.set_registry_uri("databricks-uc")
spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", 5000)
deploy_client = mlflow.deployments.get_deploy_client("databricks")

CATALOG = "jun_demo"
SCHEMA = "test"
REGISTRED_MODEL_NAME = "test_model"

class BGERerankerUdfModel(mlflow.pyfunc.PythonModel):

    def load_context(self, context: PythonModelContext):
      from pyspark.sql import SparkSession
      self.spark = SparkSession.builder.getOrCreate()     
      self.model_name = REGISTRED_MODEL_NAME 
    
    def predict(self, context: PythonModelContext, model_input: pd.DataFrame, params: Optional[Dict[str, Any]] = None):

      @pandas_udf(FloatType())
      def query_embeddings_udf(sentence1: pd.Series, sentence2: pd.Series) -> pd.Series:
        inp = {
          "inputs": {
              "sentence1": sentence1.to_list(),
              "sentence2": sentence2.to_list(),
              }
        }
        response = deploy_client.predict(endpoint="bge-reranker", inputs=inp)
        return pd.Series(response["predictions"], index=sentence1.index)
    
      df = pd.DataFrame({
        "sentence1": model_input["sentence1"],
        "sentence2": model_input["sentence2"]
      })

      no_of_concurrency = params.get("no_of_concurrency", 4)
      df_sp = self.spark.createDataFrame(df)

      rerank_results_df_sp = df_sp.repartition(no_of_concurrency).withColumn("scores", query_embeddings_udf(df_sp["sentence1"], df_sp["sentence2"]))
      rerank_results_df = rerank_results_df_sp.toPandas()
      return rerank_results_df
    
  

In [0]:

input_example = {
  "sentence1": ["What is Python?", "What is Python?"], 
  "sentence2": ["My first paragraph. That contains information", "Python is a programming language"]
}

model = BGERerankerUdfModel()

context = PythonModelContext({
  'model_name': REGISTRED_MODEL_NAME
}, None)
model.load_context(context)

result = model.predict(None, pd.DataFrame(input_example), {"no_of_concurrency": 4})
display(result)

sentence1,sentence2,scores
What is Python?,My first paragraph. That contains information,-9.28367
What is Python?,Python is a programming language,6.1343727


In [0]:
# from mlflow.models.signature import ModelSignature
# from mlflow.types.schema import ColSpec, Schema, ParamSpec, ParamSchema
# import numpy as np

# input_schema = Schema([ColSpec("string", "sentence1"), ColSpec("string", "sentence2")])

# output_schema = Schema([ColSpec("string", "predictions")])

# param_schema = ParamSchema([ParamSpec("max_lenghth", "long", 512)])

# signature = ModelSignature(inputs=input_schema, outputs=output_schema, params=param_schema)

from mlflow.models.signature import infer_signature

signature = infer_signature(model_input=pd.DataFrame(input_example), model_output=result, params={"max_lenghth": 512})

with mlflow.start_run():
  model_info = mlflow.pyfunc.log_model(
    REGISTRED_MODEL_NAME,
    python_model=BGERerankerUdfModel(),
    input_example=input_example,
    signature=signature,
    registered_model_name=f"{CATALOG}.{SCHEMA}.{REGISTRED_MODEL_NAME}"
  )



Uploading artifacts:   0%|          | 0/6 [00:00<?, ?it/s]

Registered model 'jun_demo.test.test_model' already exists. Creating a new version of this model...


Uploading artifacts:   0%|          | 0/6 [00:00<?, ?it/s]

Created version '3' of model 'jun_demo.test.test_model'.


In [0]:
reranker = mlflow.pyfunc.load_model(model_info.model_uri)
result = reranker.predict(pd.DataFrame(input_example), params={"no_of_concurrency": 4})
display(result)


Downloading artifacts:   0%|          | 0/6 [00:00<?, ?it/s]



sentence1,sentence2,scores
What is Python?,Python is a programming language,6.134373
What is Python?,My first paragraph. That contains information,-9.28367
