In [121]:
import mlflow
from sklearn.ensemble import RandomForestRegressor
from search_run.ranking.ranking import Ranking
from pyspark.sql.session import SparkSession
import pyspark.sql.functions as F
from pyspark.sql.functions import udf
from pyspark.sql.types import FloatType


spark = SparkSession.builder.getOrCreate()

features = ['position', 'key_lenght' ]
label = 'input_lenght'


## Entries

In [84]:
entries_df = Ranking().load_entries_df(spark)
#entries_df = entries_df.drop("content")
entries_df = entries_df.withColumn('key_lenght', F.length('key'))
entries_df.select('position', 'key_lenght').show()

+--------+----------+
|position|key_lenght|
+--------+----------+
|       1|        41|
|       2|        16|
|       3|        36|
|       4|        33|
|       5|        53|
|       6|        21|
|       7|         7|
|       8|        28|
|       9|        20|
|      10|        26|
|      11|        30|
|      12|        16|
|      13|        23|
|      14|        38|
|      15|        15|
|      16|        32|
|      17|        24|
|      18|        22|
|      19|        14|
|      20|        34|
+--------+----------+
only showing top 20 rows



## Commands performed

In [87]:
dataset = Ranking().load_commands_performed_df()
schema = '`key` STRING,  `generated_date` STRING, `uuid` STRING, `given_input` STRING'
original_df = spark.createDataFrame(dataset, schema=schema)
performed_df = original_df.withColumn("input_lenght", F.length("given_input"))
performed_df = performed_df.filter('given_input != "NaN"')
performed_df = performed_df.drop('uuid')

performed_df.select('input_lenght', 'generated_date').show()


+------------+-------------------+
|input_lenght|     generated_date|
+------------+-------------------+
|          15|2021-08-28 05:04:19|
|          16|2021-08-28 05:01:59|
|          14|2021-08-27 21:32:54|
|           4|2021-08-27 21:22:42|
|           7|2021-08-27 21:18:39|
|           9|2021-08-27 21:00:34|
|          14|2021-08-27 21:00:17|
|           9|2021-08-27 20:55:48|
|           8|2021-08-27 20:53:04|
|           5|2021-08-27 20:46:58|
|           5|2021-08-27 18:07:43|
|           7|2021-08-27 18:01:32|
|           9|2021-08-27 17:56:18|
|          18|2021-08-27 17:56:00|
|          18|2021-08-27 17:55:55|
|           9|2021-08-27 17:55:44|
|           3|2021-08-27 17:48:18|
|          10|2021-08-27 17:47:43|
|          10|2021-08-27 17:47:39|
|          19|2021-08-27 17:08:03|
+------------+-------------------+
only showing top 20 rows



## Final Dataset

In [107]:

df = performed_df.join(entries_df, on='key', how='left')
df = df.filter('position is not null')

df.select(*features, label).show()
#df.show()

+--------+----------+------------+
|position|key_lenght|input_lenght|
+--------+----------+------------+
|    3873|        19|         134|
|    3873|        19|         134|
|     630|        27|          10|
|     630|        27|          27|
|     630|        27|          27|
|     630|        27|          27|
|     630|        27|          27|
|     630|        27|          27|
|     630|        27|          27|
|     630|        27|          27|
|     630|        27|          27|
|     630|        27|          27|
|     630|        27|          27|
|     630|        27|          27|
|     630|        27|          27|
|     630|        27|          27|
|     630|        27|          27|
|     630|        27|          27|
|     630|        27|          27|
|     630|        27|          27|
+--------+----------+------------+
only showing top 20 rows



In [109]:
import numpy as np

X = np.array(df.select(*features).collect())
Y = np.array(df.select(label).collect())

X

array([[3873,   19],
       [3873,   19],
       [ 630,   27],
       ...,
       [ 988,   15],
       [ 988,   15],
       [1943,   24]])

## Train model

In [122]:


with mlflow.start_run():
    regr = RandomForestRegressor(max_depth=2, random_state=0)
    model = regr.fit(X, Y)
    mlflow.sklearn.log_model(model)

  model = regr.fit(X, Y)


TypeError: log_model() missing 1 required positional argument: 'artifact_path'

## Predict ranking

In [112]:
def evaluate_model(position, key_lenght):
    """ function to predict used at the udf """
    return float(model.predict(np.array([[position, key_lenght]]))[0])

model_udf = udf(evaluate_model, FloatType())

result_df = entries_df.withColumn("predicted_key_lenght", model_udf(F.col('position'), F.col('key_lenght')))
#result_df.show()

## Analyse results - easiest to find

In [118]:

result_df.select(*features, 'predicted_key_lenght').orderBy(F.col('predicted_key_lenght').asc()).show()


+--------+----------+--------------------+
|position|key_lenght|predicted_key_lenght|
+--------+----------+--------------------+
|    2460|        15|           23.089348|
|    2455|        24|           23.089348|
|    2446|        22|           23.089348|
|    2451|        19|           23.089348|
|    2434|        16|           23.089348|
|    2441|        22|           23.089348|
|    2459|        26|           23.089348|
|    2448|        13|           23.089348|
|    2429|        10|           23.089348|
|    2468|        12|           23.089348|
|    2454|        18|           23.089348|
|    2438|        27|           23.089348|
|    2457|        19|           23.089348|
|    2458|        20|           23.089348|
|    2444|        18|           23.089348|
|    2447|        24|           23.089348|
|    2469|        17|           23.089348|
|    2427|        13|           23.089348|
|    2449|        23|           23.089348|
|    2453|        13|           23.089348|
+--------+-



## Analyse results - hardest to find

In [120]:

result_df.select(*features,'predicted_key_lenght',).orderBy(F.col('predicted_key_lenght').desc
                                                                                  ()).show()


+--------+----------+--------------------+
|position|key_lenght|predicted_key_lenght|
+--------+----------+--------------------+
|     119|        43|           150.99907|
|     105|        47|           150.99907|
|     117|        44|           150.99907|
|      26|        52|           150.99907|
|     134|        41|           150.99907|
|      93|        55|           150.99907|
|     102|        43|           150.99907|
|      24|        41|           150.99907|
|      53|        48|           150.99907|
|       5|        53|           150.99907|
|       1|        41|           150.99907|
|      50|        50|           150.99907|
|      59|        54|           150.99907|
|      99|        56|           150.99907|
|      94|        45|           150.99907|
|      31|        45|           150.99907|
|      71|        41|           150.99907|
|      47|        48|           150.99907|
|      44|        43|           150.99907|
|      77|        44|           150.99907|
+--------+-

# Save output


In [137]:

output = result_df.select('key', 'predicted_key_lenght').orderBy(F.col('predicted_key_lenght').desc())
output.show()
import shutil
shutil.rmtree('/data/search_run/predict_input_lenght/latest')
output.repartition(1).write.csv('/data/search_run/predict_input_lenght/latest', header=True)
print("Finished")

+--------------------+--------------------+
|                 key|predicted_key_lenght|
+--------------------+--------------------+
|engineering categ...|           150.99907|
|growth budget per...|           150.99907|
|troubleshooting s...|           150.99907|
|data platform dat...|           150.99907|
|how i want to beh...|           150.99907|
|accompliments per...|           150.99907|
|remove us from fa...|           150.99907|
|buy train ticket ...|           150.99907|
|databricks repres...|           150.99907|
|become the ml pla...|           150.99907|
|filter available ...|           150.99907|
|champions league ...|           150.99907|
|360 review cycle ...|           150.99907|
|context aware ran...|           150.99907|
|vertex ai feature...|           150.99907|
|inspect python ob...|           150.99907|
|data integration ...|           150.99907|
|growth framework ...|           150.99907|
|deliverables q4 t...|           150.99907|
|dont know how and...|          