## Load model

In [1]:
import mlflow.sklearn
import numpy as np
from pyspark.sql import SparkSession

from mlflow.tracking import MlflowClient
from pyspark.sql.functions import udf
from pyspark.sql.types import FloatType
from search_run.ranking.ranking import Ranking
import pyspark.sql.functions as F

ranking = Ranking()

spark = SparkSession.builder.getOrCreate()
client = MlflowClient()

# Load the model artifact of the first run returned for the 'Default' experiment.
# NOTE(review): relies on MLflow's default ordering for list_run_infos
# (newest start_time first) to pick the latest run — confirm for your MLflow version.
experiment = client.get_experiment_by_name('Default')
if experiment is None:
    raise ValueError("MLflow experiment 'Default' not found; train and log a model first")
runs = client.list_run_infos(experiment_id=experiment.experiment_id)
if not runs:
    raise ValueError(
        f"MLflow experiment 'Default' (id={experiment.experiment_id}) has no runs to load a model from"
    )
# run_id replaces the deprecated RunInfo.run_uuid attribute.
run_path = f"runs:/{runs[0].run_id}/model/"
model = mlflow.sklearn.load_model(run_path)

# Extra columns to include in the result tables below; set to ['key'] to also print the key.
show_columns = []

# Kept as the cell's last expression so Jupyter actually renders the loaded model.
model

## Predict ranking

In [2]:


def evaluate_model(position, key_lenght):
    """Score a single (position, key_lenght) pair with the loaded model.

    Used as the body of a row-wise Spark UDF; returns a Python float so it
    matches the UDF's declared FloatType.
    """
    feature_row = np.array([[position, key_lenght]])
    prediction = model.predict(feature_row)[0]
    return float(prediction)


model_udf = udf(evaluate_model, FloatType())
entries_df = ranking.load_entries_df(spark)
result_df = entries_df.withColumn(
    "predicted_key_lenght",
    model_udf(F.col('position'), F.col('key_lenght')),
)

## Analyse results - easiest to find

In [3]:

# Lowest predicted values first: the entries the model ranks as easiest to find.
(
    result_df
    .select(*ranking.model_info.features, *show_columns, 'predicted_key_lenght')
    .orderBy(F.asc('predicted_key_lenght'))
    .show()
)


+--------+----------+--------------------+
|position|key_lenght|predicted_key_lenght|
+--------+----------+--------------------+
|     201|        32|           1.3371203|
|     148|        31|           1.8000659|
|     203|        30|           1.8142977|
|     149|        30|           1.9202597|
|     190|        30|           1.9402977|
|     271|        34|           2.1100752|
|      50|        30|            2.162003|
|       5|        30|            2.162003|
|      92|        30|            2.162003|
|      11|        30|            2.162003|
|     153|        31|            2.193936|
|     110|        31|           2.2073877|
|     100|        31|           2.2073877|
|      37|        31|           2.2073877|
|      77|        31|           2.2073877|
|       8|        32|           3.2765884|
|     132|        32|           3.2765884|
|      93|        32|           3.2765884|
|     128|        29|           3.5030427|
|     106|        29|           3.5030427|
+--------+----------+--------------------+



## Analyse results - hardest to find

In [4]:

# Highest predicted values first: the entries the model ranks as hardest to find.
(
    result_df
    .select(*ranking.model_info.features, *show_columns, 'predicted_key_lenght')
    .orderBy(F.desc('predicted_key_lenght'))
    .show()
)


+--------+----------+--------------------+
|position|key_lenght|predicted_key_lenght|
+--------+----------+--------------------+
|     558|        41|           448.29694|
|     572|        41|            345.6516|
|     496|        52|             166.485|
|    2342|        39|           164.00435|
|     590|        41|           163.08333|
|    3715|        20|           149.57974|
|    3702|        20|           149.57974|
|    3701|        20|           149.57974|
|     503|        63|           147.04309|
|    1312|        60|           143.74069|
|    3636|        20|            139.0517|
|    3633|        20|            139.0517|
|    3587|        20|           135.52727|
|    1158|        42|            133.6041|
|     483|        54|           131.63066|
|    1219|        47|           129.90463|
|    1396|        47|           128.48994|
|     544|        53|           128.42384|
|    1156|        47|          127.912964|
|    1408|        46|          127.273315|
+--------+----------+--------------------+

## Save output to be consumed


In [18]:
import os
import shutil
from search_run.data_paths import DataPaths

# Write one CSV with the highest predicted values first.
# Collapse to a single partition *before* sorting: repartition() is a shuffle and
# does not preserve a preceding orderBy, so sorting within the one remaining
# partition is what guarantees the rows are written in order.
output = (
    result_df
    .select('key', 'predicted_key_lenght')
    .repartition(1)
    .sortWithinPartitions(F.col('predicted_key_lenght').desc())
)

# The output directory doesn't exist on the very first run; only clear it if present.
if os.path.exists(DataPaths.prediction_batch_location):
    shutil.rmtree(DataPaths.prediction_batch_location)

output.write.csv(DataPaths.prediction_batch_location, header=True)
print("Finished")

Finished
Finished


## Visualize output

In [29]:
# Read the CSV back the same way a consumer would and preview its first rows.
spark.read.csv(DataPaths.prediction_batch_location, header=True).show()


