## Load model

In [1]:
import mlflow.sklearn
import numpy as np
from pyspark.sql import SparkSession

from mlflow.tracking import MlflowClient
from pyspark.sql.functions import udf
from pyspark.sql.types import FloatType
from search_run.ranking.ranking import Ranking
import pyspark.sql.functions as F

# Project helper used below to load the entries and to list the model features.
ranking = Ranking()

spark = SparkSession.builder.getOrCreate()
client = MlflowClient()

# get_experiment_by_name returns None when no experiment has that name, so fail
# early with a clear message instead of an AttributeError on the next line.
experiment = client.get_experiment_by_name('Default')
if experiment is None:
    raise RuntimeError(
        "MLflow experiment 'Default' was not found; check the tracking URI."
    )

# NOTE(review): list_run_infos is deprecated in recent MLflow releases in favour
# of search_runs -- confirm against the installed version before upgrading.
runs = client.list_run_infos(experiment_id=experiment.experiment_id)
if not runs:
    raise RuntimeError(
        "MLflow experiment 'Default' has no runs; train and log a model first."
    )

# Load the sklearn model logged under the "model" artifact path of the first
# run returned by the tracking server.
run_path = f"runs:/{runs[0].run_uuid}/model/"
model = mlflow.sklearn.load_model(run_path)


# Bare reference to the loaded model (presumably meant as a quick inspection;
# it is not the last expression of the cell).
model


# Extra columns selected alongside the model features when results are shown.
show_columns = []

# To also print the key of each entry, use this instead:
# show_columns = ['key']

## Predict ranking

In [2]:


def evaluate_model(position, key_lenght):
    """Predict a single value with the module-level ``model``.

    Called once per row through the Spark UDF. The two inputs are packed
    into a one-row, two-column array in the order (position, key_lenght),
    and the first element of the prediction is returned as a Python float.
    """
    features = np.array([[position, key_lenght]])
    prediction = model.predict(features)
    return float(prediction[0])

# Entries to score, loaded through the project Ranking helper.
entries_df = ranking.load_entries_df(spark)

# Expose the Python predictor to Spark as a UDF returning a float column.
model_udf = udf(evaluate_model, returnType=FloatType())

# Append the model output as a new column; input columns are left untouched.
result_df = entries_df.withColumn(
    "predicted_key_lenght",
    model_udf(F.col('position'), F.col('key_lenght')),
)
# result_df.show()

## Analyse results - easiest to find

In [3]:

# Model features (plus any optional columns) with the prediction, sorted from
# the lowest predicted value upwards; show() prints the first rows only.
(
    result_df
    .select(*ranking.model_info.features, *show_columns, 'predicted_key_lenght')
    .orderBy(F.col('predicted_key_lenght').asc())
    .show()
)


+--------+----------+--------------------+
|position|key_lenght|predicted_key_lenght|
+--------+----------+--------------------+
|      58|        32|           1.2697285|
|     127|        32|           1.3133777|
|      19|        32|           1.3197286|
|      18|        30|           1.6375389|
|       2|        31|           1.7776817|
|      26|        31|           1.7776817|
|       3|        31|           1.7776817|
|      36|        31|           1.7776817|
|      75|        30|           1.8103861|
|      74|        31|           1.9343946|
|      79|        31|           2.0561326|
|      32|        29|            2.139599|
|      34|        29|            2.139599|
|      54|        29|            2.139599|
|      27|        28|           2.3564785|
|      61|        28|           2.3564785|
|      39|        28|           2.3564785|
|      71|        30|           2.4556108|
|      47|        27|           2.6417165|
|      23|        27|           2.6417165|
+--------+----------+--------------------+



## Analyse results - hardest to find

In [4]:

# Same projection as the previous table, sorted from the highest predicted
# value downwards; show() prints the first rows only.
(
    result_df
    .select(*ranking.model_info.features, *show_columns, 'predicted_key_lenght')
    .orderBy(F.col('predicted_key_lenght').desc())
    .show()
)


+--------+----------+--------------------+
|position|key_lenght|predicted_key_lenght|
+--------+----------+--------------------+
|     484|        41|           413.92636|
|    2270|        39|           245.19867|
|     486|        35|            221.0795|
|    2294|        39|           220.93867|
|    2325|        39|           220.93867|
|    2301|        39|           220.93867|
|     498|        41|           196.96376|
|    2283|        38|           191.49321|
|    2583|        39|           168.09265|
|     422|        52|           167.30533|
|    2367|        38|            167.0532|
|    2100|        39|           163.08955|
|    2201|        39|           162.45686|
|    1933|        39|           162.38455|
|    1926|        39|           161.44955|
|    2478|        40|           157.63939|
|     470|        53|           153.86404|
|     489|        38|           150.95296|
|    2268|        41|            148.6637|
|     487|        42|           146.16835|
+--------+----------+--------------------+

## Save output to be consumed


In [5]:
import shutil

# Single definition of the output location, used for both cleanup and write.
output_path = '/data/search_run/predict_input_lenght/latest'

output = result_df.select('key', 'predicted_key_lenght').orderBy(F.col('predicted_key_lenght').desc())

# Remove the previous export so the write below does not fail on an existing
# path. On the first run there is nothing to delete, which is fine; any other
# error (e.g. permissions) still propagates.
try:
    shutil.rmtree(output_path)
except FileNotFoundError:
    pass

# NOTE(review): repartition(1) shuffles the rows, so the descending order set
# above may not survive into the CSV -- confirm whether consumers rely on it.
output.repartition(1).write.csv(output_path, header=True)
print("Finished")

Finished


## Visualize output

In [38]:
#!cat /data/search_run/predict_input_lenght/latest/**

