In [0]:
import mlflow
from mlflow.pyfunc.model import PythonModelContext
from typing import Optional, Dict, List, Any, Union

import pandas as pd
import numpy as np
import mlflow.deployments
import multiprocessing as mp


In [0]:
# Unity Catalog coordinates under which the model is registered.
CATALOG = "jun_demo"
SCHEMA = "test"
REGISTRED_MODEL_NAME = "test_model"

# Point the MLflow registry at Unity Catalog, then build the Databricks
# deployments client that the reranker model uses to call the endpoint.
mlflow.set_registry_uri("databricks-uc")
deploy_client = mlflow.deployments.get_deploy_client("databricks")


class BGERerankerUdfModel(mlflow.pyfunc.PythonModel):
    """Pyfunc wrapper that scores sentence pairs through the ``bge-reranker`` endpoint.

    ``predict`` splits the input frame into contiguous row chunks, sends each
    chunk to the endpoint from its own worker process, and returns the input
    frame with an added ``scores`` column.
    """

    @staticmethod
    def query_reranker(sentence1: pd.Series, sentence2: pd.Series) -> pd.Series:
        """Score aligned sentence pairs with a single endpoint request.

        Args:
            sentence1: First sentence of each pair.
            sentence2: Second sentence of each pair, in the same row order.

        Returns:
            The ``predictions`` field of the endpoint response as a Series
            carrying ``sentence1``'s index.
        """
        payload = {
            "inputs": {
                "sentence1": sentence1.to_list(),
                "sentence2": sentence2.to_list(),
            }
        }
        # Relies on the notebook-level ``deploy_client``.
        response = deploy_client.predict(endpoint="bge-reranker", inputs=payload)
        return pd.Series(response["predictions"], index=sentence1.index)

    @staticmethod
    def func(df: pd.DataFrame) -> pd.Series:
        """Score one chunk of rows; an empty chunk yields an empty float32 Series.

        Args:
            df: Frame with ``sentence1`` and ``sentence2`` columns.

        Returns:
            Scores for the chunk (no endpoint call is made for an empty chunk).
        """
        if df.empty:
            return pd.Series([], dtype=np.float32)
        return BGERerankerUdfModel.query_reranker(df["sentence1"], df["sentence2"])

    def load_context(self, context: PythonModelContext):
        """Attach a Spark session and the registered model name to the instance.

        Args:
            context: MLflow model context; not read by this implementation.
        """
        # Imported lazily so that pyspark is only needed once the model is loaded.
        from pyspark.sql import SparkSession

        self.spark = SparkSession.builder.getOrCreate()
        self.model_name = REGISTRED_MODEL_NAME

    def predict(
        self,
        context: PythonModelContext,
        model_input: pd.DataFrame,
        params: Optional[Dict[str, Any]] = None,
    ) -> pd.DataFrame:
        """Score every row of ``model_input`` using parallel endpoint requests.

        Args:
            context: MLflow model context; not read by this implementation.
            model_input: Frame with ``sentence1`` and ``sentence2`` columns.
            params: Optional inference parameters. ``no_of_concurrency``
                (default 4) sets the maximum number of worker processes.

        Returns:
            A copy of ``model_input`` with a ``scores`` column appended.

        Raises:
            ValueError: If ``no_of_concurrency`` is smaller than 1.
        """
        # ``params`` defaults to None, so guard before reading from it.
        no_of_concurrency = int((params or {}).get("no_of_concurrency", 4))
        if no_of_concurrency < 1:
            raise ValueError(
                f"no_of_concurrency must be >= 1, got {no_of_concurrency}"
            )

        n_rows = len(model_input)
        if n_rows == 0:
            # Nothing to score: skip the pool and return an empty scores column.
            return model_input.assign(scores=BGERerankerUdfModel.func(model_input))

        # Never start more workers than there are rows, so no chunk is empty.
        n_workers = min(no_of_concurrency, n_rows)

        # Split row *positions* rather than the DataFrame itself: passing a
        # DataFrame to np.array_split is deprecated in recent pandas/NumPy.
        position_groups = np.array_split(np.arange(n_rows), n_workers)
        chunks = [model_input.iloc[positions] for positions in position_groups]

        with mp.Pool(n_workers) as pool:
            parts = pool.map(BGERerankerUdfModel.func, chunks)

        # Chunks are contiguous and in row order, so attach the scores by
        # position; this also works when the input index has duplicate labels.
        scores = pd.concat(parts).to_numpy()
        return model_input.assign(scores=scores)
    
  

In [0]:

# One query repeated against eight candidate passages.
input_example = {
    "sentence1": ["What is Python?"] * 8,
    "sentence2": [
        "My first paragraph",
        "That contains information",
        "information",
        "Python is a programming language",
        "Python",
        "is",
        "program",
        "language",
    ],
}

# Exercise the model in-process before logging it.
model = BGERerankerUdfModel()

context = PythonModelContext({"model_name": REGISTRED_MODEL_NAME}, None)
model.load_context(context)

result = model.predict(
    None,
    pd.DataFrame(input_example),
    {"no_of_concurrency": 4},
)
display(result)

sentence1,sentence2,scores
What is Python?,My first paragraph,-9.444510459899902
What is Python?,That contains information,-1.8268609046936035
What is Python?,information,-4.175684928894043
What is Python?,Python is a programming language,6.134361267089844
What is Python?,Python,2.210157632827759
What is Python?,is,-3.809581995010376
What is Python?,program,0.7984319925308228
What is Python?,language,-0.0479622147977352


In [0]:
from mlflow.models.signature import infer_signature

# Derive the model signature from the example input, the locally computed
# output frame, and the supported inference parameter.
signature = infer_signature(
    model_input=pd.DataFrame(input_example),
    model_output=result,
    params={"no_of_concurrency": 4},
)

# Log the model inside a run and register it under catalog.schema.name.
with mlflow.start_run():
    model_info = mlflow.pyfunc.log_model(
        REGISTRED_MODEL_NAME,
        python_model=BGERerankerUdfModel(),
        registered_model_name=f"{CATALOG}.{SCHEMA}.{REGISTRED_MODEL_NAME}",
        signature=signature,
        input_example=input_example,
    )



Uploading artifacts:   0%|          | 0/6 [00:00<?, ?it/s]

Registered model 'jun_demo.test.test_model' already exists. Creating a new version of this model...


Uploading artifacts:   0%|          | 0/6 [00:00<?, ?it/s]

Created version '6' of model 'jun_demo.test.test_model'.


In [0]:
# Load the logged model back through the pyfunc interface and re-score the
# example to confirm the round trip gives the same table as the local run.
reranker = mlflow.pyfunc.load_model(model_info.model_uri)
result = reranker.predict(
    pd.DataFrame(input_example),
    params={"no_of_concurrency": 4},
)
display(result)


Downloading artifacts:   0%|          | 0/6 [00:00<?, ?it/s]

sentence1,sentence2,scores
What is Python?,My first paragraph,-9.444510459899902
What is Python?,That contains information,-1.8268609046936035
What is Python?,information,-4.175684928894043
What is Python?,Python is a programming language,6.134361267089844
What is Python?,Python,2.210157632827759
What is Python?,is,-3.809581995010376
What is Python?,program,0.7984319925308228
What is Python?,language,-0.0479622147977352
