## Load model

In [39]:
import mlflow.sklearn
import numpy as np
from pyspark.sql import SparkSession

from mlflow.tracking import MlflowClient
from pyspark.sql.functions import udf
from pyspark.sql.types import FloatType
from search_run.ranking.ranking import Ranking
import pyspark.sql.functions as F

ranking = Ranking()
spark = SparkSession.builder.getOrCreate()

# Pick the model to load: the first run MLflow lists for the 'Default' experiment.
client = MlflowClient()
experiment = client.get_experiment_by_name('Default')
if experiment is None:
    # An unknown experiment name yields None instead of an exception, so fail
    # here with a clear message rather than on the attribute access below.
    raise RuntimeError("MLflow experiment 'Default' was not found; check the tracking URI.")

runs = client.list_run_infos(experiment_id=experiment.experiment_id)
if not runs:
    raise RuntimeError("MLflow experiment 'Default' has no runs; log a trained model first.")

# NOTE(review): runs[0] is presumably the most recent run — confirm the
# ordering returned by list_run_infos for the MLflow version in use.
run_path = f"runs:/{runs[0].run_uuid}/model/"
model = mlflow.sklearn.load_model(run_path)

# Extra columns to display next to the model features in the analysis cells.
#show_columns = ['key']
show_columns = []

## Predict ranking

In [23]:


def evaluate_model(position, key_lenght):
    """Predict a single value with the loaded model for one entry.

    The function is wrapped as a Spark UDF, so it is invoked once per row.

    Args:
        position: value of the ``position`` feature for the row.
        key_lenght: value of the ``key_lenght`` feature for the row.

    Returns:
        The model prediction as a Python ``float``, or ``None`` when either
        feature is missing.
    """
    # Spark passes null column values to Python UDFs as None. Feeding them to
    # the model would most likely raise and fail the whole job, so the null is
    # propagated to the result column instead.
    if position is None or key_lenght is None:
        return None
    features = np.array([[position, key_lenght]])
    return float(model.predict(features)[0])

# Wrap the prediction function as a Spark UDF that yields a float column.
model_udf = udf(evaluate_model, returnType=FloatType())

# Load the entries and attach the model prediction to every row.
entries_df = ranking.load_entries_df(spark)
result_df = entries_df.withColumn(
    "predicted_key_lenght",
    model_udf(F.col('position'), F.col('key_lenght')),
)
# Uncomment to inspect the scored rows:
# result_df.show()

## Analyse results - easiest to find

In [29]:

# Ascending order: the rows with the smallest prediction are printed first.
(
    result_df
    .select(*ranking.model_info.features, *show_columns, 'predicted_key_lenght')
    .orderBy(F.col('predicted_key_lenght').asc())
    .show()
)


+--------+----------+--------------------+
|position|key_lenght|predicted_key_lenght|
+--------+----------+--------------------+
|    2973|        16|           22.474922|
|    2966|        18|           22.474922|
|    2953|        17|           22.474922|
|    2939|        18|           22.474922|
|    2943|        22|           22.474922|
|    2949|         9|           22.474922|
|    2972|        18|           22.474922|
|    2957|        17|           22.474922|
|    2935|        13|           22.474922|
|    2941|        21|           22.474922|
|    2963|        21|           22.474922|
|    2947|        20|           22.474922|
|    2969|        24|           22.474922|
|    2971|        24|           22.474922|
|    2951|        11|           22.474922|
|    2956|        25|           22.474922|
|    2981|        12|           22.474922|
|    2955|        16|           22.474922|
|    2980|        20|           22.474922|
|    2958|        22|           22.474922|
+--------+----------+--------------------+



## Analyse results - hardest to find

In [28]:

# Descending order: the rows with the largest prediction are printed first.
(
    result_df
    .select(*ranking.model_info.features, *show_columns, 'predicted_key_lenght')
    .orderBy(F.col('predicted_key_lenght').desc())
    .show()
)


+--------+----------+--------------------+
|position|key_lenght|predicted_key_lenght|
+--------+----------+--------------------+
|     256|        41|          102.572296|
|     169|        41|          102.572296|
|     270|        41|          102.572296|
|     288|        41|           101.62816|
|     259|        42|           101.48737|
|     226|        40|           100.85271|
|     292|        40|            99.90857|
|     245|        44|            99.06883|
|     272|        43|            99.06883|
|     181|        43|            99.06883|
|     152|        44|            99.06883|
|     154|        43|            99.06883|
|     165|        39|            98.83337|
|     207|        39|            98.83337|
|     187|        46|            98.59564|
|     242|        53|            98.59564|
|     156|        45|            98.59564|
|     198|        63|            98.59564|
|     178|        54|            98.59564|
|     279|        49|            98.59564|
+--------+----------+--------------------+

## Save output to be consumed


In [27]:
import shutil

# Directory the export is written to.
output_path = '/data/search_run/predict_input_lenght/latest'

output = (
    result_df
    .select('key', 'predicted_key_lenght')
    # Collapse to a single partition first and sort inside it. A repartition
    # placed after orderBy shuffles the rows again, so the written file would
    # not be guaranteed to keep the descending order.
    .repartition(1)
    .sortWithinPartitions(F.col('predicted_key_lenght').desc())
)
# mode='overwrite' replaces a previous export and also succeeds when the
# directory does not exist yet (a plain rmtree raises on a missing directory).
output.write.csv(output_path, header=True, mode='overwrite')
print("Finished")

Finished


## Visualize output

In [38]:
#!cat /data/search_run/predict_input_lenght/latest/**

