In [None]:
%pip install mlflow pymongo tqdm

Python interpreter will be restarted.
Collecting mlflow
  Using cached mlflow-2.4.1-py3-none-any.whl (18.1 MB)
Collecting pymongo
  Using cached pymongo-4.3.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (492 kB)
Collecting tqdm
  Using cached tqdm-4.65.0-py3-none-any.whl (77 kB)
Collecting gunicorn<21
  Using cached gunicorn-20.1.0-py3-none-any.whl (79 kB)
Collecting docker<7,>=4.0.0
  Using cached docker-6.1.3-py3-none-any.whl (148 kB)
Collecting pyyaml<7,>=5.1
  Using cached PyYAML-6.0-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (661 kB)
Collecting gitpython<4,>=2.1.0
  Using cached GitPython-3.1.31-py3-none-any.whl (184 kB)
Collecting cloudpickle<3
  Using cached cloudpickle-2.2.1-py3-none-any.whl (25 kB)
Collecting alembic!=1.10.0,<2
  Using cached alembic-1.11.1-py3-none-any.whl (224 kB)
Collecting Flask<3
  Using cached Flask-2.3.2-py3-none-any.whl (96 kB)
Collecting markdown<4,>=3.3
  Using cached Markdown-3.4.3-py

In [None]:
import os

from pymongo import MongoClient

# MongoDB connection string shared by the pymongo client below and by the Spark
# connector reads/writes later in the notebook.
# The real URI (with credentials) should be supplied through the MONGO_CONN
# environment variable so secrets are never saved in the notebook; the literal
# is only a placeholder template used when the variable is not set.
MONGO_CONN = os.environ.get(
    "MONGO_CONN",
    'mongodb+srv://<username>:<password>@retail-demo.2wqno.mongodb.net/?retryWrites=true&w=majority',
)
client = MongoClient(MONGO_CONN)

In [None]:
import numpy as np
import pandas as pd
from datetime import datetime
import pyspark
from pyspark.sql import functions as F
from pyspark.sql import types as T
from pyspark.sql import Window, WindowSpec
from pyspark.sql.functions import struct

import mlflow.pyfunc
from tqdm import tqdm
# Enable tqdm's pandas integration (registers progress_apply / progress_map on pandas objects).
# NOTE(review): no progress_* call is visible in this part of the notebook — confirm it is still needed.
tqdm.pandas()

import warnings
# Silence every Python warning for the rest of the session.
# NOTE(review): this also hides deprecation warnings from all libraries; consider narrowing the filter.
warnings.filterwarnings("ignore")

# Load the `processed_clogs_myn` collection of the `search` database into a
# Spark DataFrame through the MongoDB Spark connector.
sales = (
    spark.read.format("mongodb")
    .option('spark.mongodb.connection.uri', MONGO_CONN)
    .option('spark.mongodb.database', "search")
    .option('spark.mongodb.collection', "processed_clogs_myn")
    .load()
)



In [None]:
# Number of distinct ids in the loaded data (displayed as the cell output).
(
    sales
    .select("id")
    .distinct()
    .count()
)

Out[3]: 44446

In [None]:
# Missing sales figures become 0.0; present ones are cast to float.
for sales_col in ("old_sales", "total_sales"):
    sales = sales.withColumn(
        sales_col,
        F.when(F.col(sales_col).isNull(), F.lit(0.0)).otherwise(F.col(sales_col).cast("float")),
    )

# Collapse to one row per id: sales columns are summed, price columns averaged.
aggregations = [
    F.sum("total_sales").alias("total_sales"),
    F.avg("avg_price").alias("avg_price"),
    F.avg("max_price").alias("max_price"),
    F.avg("min_price").alias("min_price"),
    F.avg("old_avg_price").alias("old_avg_price"),
    F.sum("old_sales").alias("old_sales"),
]
sales = sales.groupby("id").agg(*aggregations)

# Row count after aggregation (displayed as the cell output).
sales.count()

In [None]:
# Wrap the "staging" version of the registered MLflow model as a Spark UDF.
model_name = "retail_competitive_pricing_model_1"
model_uri = f"models:/{model_name}/staging"
apply_model_udf = mlflow.pyfunc.spark_udf(spark, model_uri)

# Columns passed to the model, packed into a single struct column.
# NOTE(review): presumably this must match the feature set the model was trained on — confirm.
columns = [
    'old_sales',
    'total_sales',
    'min_price',
    'max_price',
    'avg_price',
    'old_avg_price',
]
udf_inputs = struct(*columns)
udf_inputs

2023/06/13 07:02:40 INFO mlflow.models.flavor_backend_registry: Selected backend for flavor 'python_function'
Out[12]: Column<'struct(old_sales, total_sales, min_price, max_price, avg_price, old_avg_price)'>

In [None]:
# SQL expression for the elasticity column, kept exactly as authored.
# NOTE(review): rows whose price term evaluates to zero show an empty
# price_elasticity in the recorded output — confirm that nulls are acceptable downstream.
elasticity_sql = "((old_sales - total_sales)/(old_sales + total_sales))/(((old_avg_price - avg_price)+1)/(old_avg_price + avg_price))"

sales = (
    sales.fillna(0.0)
    # Model prediction for every aggregated row.
    .withColumn("pred_price", apply_model_udf(udf_inputs))
    .withColumn("price_elasticity", F.expr(elasticity_sql))
    # Discount = (1 - pred_price) * 100, rounded up to a whole number.
    .withColumn("discount", F.ceil((F.lit(1) - F.col("pred_price")) * F.lit(100)))
)
display(sales)

id,total_sales,avg_price,max_price,min_price,old_avg_price,old_sales,pred_price,price_elasticity,discount
59467,5.0,1.0,1.0,1.0,1.0,4.0,0.9454153832823006,-0.2222222222222222,6
38986,7.0,1.0,1.0,1.0,1.0,6.0,0.9454153832823006,-0.1538461538461538,6
30966,209.0,1.0,1.0,1.0,1.000161144764629,207.0,0.9425635852911128,-0.0096146100057635,6
20868,189.0,1.0,1.013157894736842,0.986842105263158,1.0077120372904638,188.0,0.9427028481310558,-0.0052847400081208,6
9586,48.0,1.0,1.0,1.0,1.0,46.0,0.9399949980978312,-0.0425531914893617,7
39590,135.0,1.0,1.0,1.0,1.007535083235467,131.0,0.9443862177000628,-0.0299627258589539,6
40740,1.0,1.0,1.0,1.0,0.0,0.0,0.9387022415531036,,7
30923,7.0,1.0,1.0,1.0,1.0,4.0,0.9454153832823006,-0.5454545454545454,6
46538,5.0,1.0,1.0,1.0,1.0,4.0,0.9454153832823006,-0.2222222222222222,6
15269,2.0,1.0,1.0,1.0,0.0,0.0,0.9388575473933594,,7


In [None]:
# Sanity check: after the groupby on "id" there is one row per id, so this
# should equal the distinct-id count computed right after loading.
sales.count()

Out[14]: 44446

In [None]:
# Persist id, predicted price and elasticity to the `price_myn` collection of
# the `search` database, using `id` as the document id field and overwriting
# whatever the collection currently holds.
(
    sales.select("id", "pred_price", "price_elasticity")
    .write.format("mongodb")
    .option('spark.mongodb.connection.uri', MONGO_CONN)
    .option('spark.mongodb.database', "search")
    .option('spark.mongodb.collection', "price_myn")
    .option('spark.mongodb.idFieldList', 'id')
    .mode('overwrite')
    .save()
)

In [None]:
# Show the 10 rows with the largest discount, without truncating column values.
sales.orderBy(F.desc("discount")).show(10, truncate=False)

+-----------+-----------+-------------------+-------------------+-------------------+------------------+---------+--------------------+------------------+------------------+
|product_uid|total_sales|avg_price          |max_price          |min_price          |old_avg_price     |old_sales|pred_price          |price_elasticity  |discount          |
+-----------+-----------+-------------------+-------------------+-------------------+------------------+---------+--------------------+------------------+------------------+
|133161     |0          |50925.9            |50925.9            |50925.9            |0.0               |0        |11386.767400002409  |0.0               |77.64051808607721 |
|129649     |4          |64771.06           |64771.06           |64771.06           |48578.295         |3        |21785.076399999707  |1.0000617597896215|66.36603384289263 |
|124778     |0          |0.16000000000000011|0.16000000000000011|0.16000000000000011|0.1588321167883213|0        |0.06200000003814