In [0]:
from pyspark.sql import functions as F
from pyspark.sql.window import Window

# Load the per-reading feature table and derive, for every (deviceId, color)
# series, a running total of delta_print_count ordered by reading timestamp.
df_prints = spark.table("toner_regression_features")

# Spelled out explicitly, this is exactly the frame Spark applies by default once
# an orderBy is present: RANGE from the partition's first row to the current row.
# Because it is a RANGE (not ROWS) frame, readings sharing a timestamp are peers
# and all receive the same running total.
window_pc = (
    Window.partitionBy("deviceId", "color")
    .orderBy("timestamp")
    .rangeBetween(Window.unboundedPreceding, Window.currentRow)
)

df_with_prints = df_prints.withColumn(
    "cumulative_print_count",
    F.sum(F.col("delta_print_count")).over(window_pc),
)

# Narrow to the regression inputs: x = cumulative_print_count and
# y = toner_pct_remaining, carried under the column name "typical".
df_math_prints = df_with_prints.select(
    F.col("deviceId"),
    F.col("color"),
    F.col("cumulative_print_count"),
    F.col("toner_pct_remaining").alias("typical"),
)

# Per-series sufficient statistics for an ordinary-least-squares fit of
# y = typical (toner % remaining) on x = cumulative_print_count.
#
# Rows where x or y is null (or NaN) are removed first. F.count("*") counts every
# row while F.sum silently skips nulls, so without this filter `n` could disagree
# with sum_x / sum_y, and sum_xy (null when either factor is null) would span a
# different row set than sum_x and sum_y -- silently skewing the slope/intercept.
df_stats_prints = (
    df_math_prints
    .dropna(subset=["cumulative_print_count", "typical"])
    .groupBy("deviceId", "color")
    .agg(
        F.count("*").alias("n"),
        F.sum("cumulative_print_count").alias("sum_x"),
        F.sum("typical").alias("sum_y"),
        F.sum(F.col("cumulative_print_count") * F.col("typical")).alias("sum_xy"),
        F.sum(F.col("cumulative_print_count") ** 2).alias("sum_x2"),
    )
)

# Closed-form OLS slope (m) and intercept (c) from the sufficient statistics:
#   m = (n*Σxy - Σx*Σy) / (n*Σx² - (Σx)²)
#   c = (Σy - m*Σx) / n
# The slope denominator is zero when every x in a series is identical (this
# includes single-row series), where no slope is defined. m is then left null
# rather than divided: with spark.sql.ansi.enabled (default in Spark 4.x) a
# division by zero raises DIVIDE_BY_ZERO and aborts the whole job, while without
# ANSI mode Spark already returns null -- so non-ANSI results are unchanged.
# n is at least 1 for every group produced by groupBy, so dividing by n is safe.
_slope_den = F.col("n") * F.col("sum_x2") - F.col("sum_x") ** 2

df_regression_prints = (
    df_stats_prints
    .withColumn(
        "m",
        F.when(
            _slope_den != 0,
            (F.col("n") * F.col("sum_xy") - F.col("sum_x") * F.col("sum_y"))
            / _slope_den,
        )
    )
    .withColumn(
        "c",
        (F.col("sum_y") - F.col("m") * F.col("sum_x")) / F.col("n")
    )
)
# Extrapolate each fitted line y = m*x + c to y = 0 (% toner): the x-intercept
# -c/m is the running print total (cumulative_print_count, summed from the
# series' first reading) at which the fit predicts the toner is used up.
# Only falling lines (m < 0) with a positive crossing are kept as forecasts.
# The division itself is gated on m < 0 so a flat (m == 0) or undefined (null)
# slope produces null instead of a DIVIDE_BY_ZERO error under ANSI mode; Spark
# gives no guarantee that the m < 0 filter is evaluated before this column.
df_prediction_prints = (
    df_regression_prints
    .withColumn(
        "predicted_print_count",
        F.when(F.col("m") < 0, -F.col("c") / F.col("m"))
    )
    .filter(
        (F.col("m") < 0) &
        (F.col("predicted_print_count") > 0)
    )
)
df_prediction_prints.select(
    "deviceId",
    "color",
    "predicted_print_count"
).display()


deviceId,color,predicted_print_count
mn=QlA1MEM1NQ==:sn=NDMwMDY5MzcwMA==,black,15355.558776606418
mn=QlA1MEM1NQ==:sn=NDMwMDY5MzcwMA==,cyan,18326.339534580333
mn=QlA1MEM1NQ==:sn=NDMwMDY5MzcwMA==,magenta,20830.184457505555
mn=QlA1MEM1NQ==:sn=NDMwMDY5MzcwMA==,yellow,21867.754384084627
mn=QlA1MEM1NQ==:sn=NDMwMDY5OTYwMA==,black,29210.774879009048
mn=QlA1MEM1NQ==:sn=NDMwMDY5OTYwMA==,cyan,15174.796875000002
mn=QlA1MEM1NQ==:sn=NDMwMDY5OTYwMA==,magenta,9601.023622047243
mn=QlA1MEM1NQ==:sn=NDMwMDY5OTYwMA==,yellow,9231.731343283584
mn=QlA1MEM1NQ==:sn=NDMwMDYwOTcwMA==,black,22803.515427375347
mn=QlA1MEM1NQ==:sn=NDMwMDYwOTcwMA==,cyan,35255.33241111863


In [0]:
from pyspark.sql import functions as F
from pyspark.sql.window import Window

# Reload the feature table for the MLlib comparison and attach the running
# print total per (deviceId, color) series, ordered by timestamp. With an
# orderBy and no explicit frame, the window spans the partition start up to the
# current row (RANGE semantics, so equal timestamps share a total).
df1 = spark.table("toner_regression_features")

window_pc = Window.partitionBy(
    "deviceId", "color"
).orderBy(
    "timestamp"
)

cumulative_prints_col = F.sum("delta_print_count").over(window_pc)
df_pc = df1.withColumn("cumulative_print_count", cumulative_prints_col)

from pyspark.ml.feature import VectorAssembler

# MLlib inputs: label = toner % remaining; a single feature, the running print
# total, packed into a vector column named "features".
#
# Rows with a null/NaN label or feature are dropped before assembling.
# VectorAssembler's default handleInvalid="error" would otherwise fail the whole
# job on a null/NaN feature, and a row without a label cannot contribute to a
# least-squares fit anyway.
df_ml_base = (
    df_pc
    .select(
        "deviceId",
        "color",
        F.col("toner_pct_remaining").alias("label"),
        "cumulative_print_count"
    )
    .dropna(subset=["label", "cumulative_print_count"])
)

assembler1 = VectorAssembler(
    inputCols=["cumulative_print_count"],
    outputCol="features"
)

df_ml1 = assembler1.transform(df_ml_base).select(
    "deviceId",
    "color",
    "label",
    "features"
)

from pyspark.ml.regression import LinearRegression

# Fit one MLlib LinearRegression per (deviceId, color) series and record, for
# every falling line, its x-intercept -c/m: the running print total at which the
# fit predicts 0 % toner.
#
# Compared with a naive per-pair loop:
#   * pairs and their row counts come from one groupBy().count() job instead of
#     a distinct() job plus a separate count() job for every pair;
#   * the estimator is configured once, outside the loop;
#   * df_ml1 is cached while the loop runs, because otherwise every filter/fit
#     re-executes the full lineage (table scan + window) from scratch. A cache
#     that already existed before this cell is left in place.
results = []

lr = LinearRegression(
    featuresCol="features",
    labelCol="label",
    fitIntercept=True
)

pair_sizes = df_ml1.groupBy("deviceId", "color").count().collect()

cached_here = not df_ml1.is_cached
if cached_here:
    df_ml1.cache()

try:
    for r in pair_sizes:
        device = r["deviceId"]
        color = r["color"]

        # An equality filter never matches a null key, so such a group has no
        # rows to fit; skip it explicitly.
        if device is None or color is None:
            continue

        # Need at least 2 points to fit a line
        if r["count"] < 2:
            continue

        df_group = df_ml1.filter(
            (F.col("deviceId") == device) &
            (F.col("color") == color)
        )

        model = lr.fit(df_group)

        m = float(model.coefficients[0])
        c = float(model.intercept)

        # Valid depletion model: toner must decrease as prints accumulate
        if m >= 0:
            continue

        predicted_print_count = -c / m

        results.append(
            (device, color, m, c, predicted_print_count)
        )
finally:
    if cached_here:
        df_ml1.unpersist()

from pyspark.sql.types import StructType, StructField, StringType, DoubleType
# Schema for the per-series tuples gathered in `results`: two string keys
# followed by three double-valued fit outputs, all nullable.
schema1 = StructType(
    [
        StructField("deviceId", StringType(), True),
        StructField("color", StringType(), True),
    ]
    + [
        StructField(field_name, DoubleType(), True)
        for field_name in ("slope_m", "intercept_c", "predicted_print_count")
    ]
)

# Turn the collected tuples into a DataFrame and show only forecasts whose
# depletion point lies at a positive print count.
df_predictions = spark.createDataFrame(results, schema1)

positive_forecast = F.col("predicted_print_count") > 0
df_predictions.where(positive_forecast).display()



deviceId,color,slope_m,intercept_c,predicted_print_count
mn=QlA1MEM1NQ==:sn=NDMwMDY5MzcwMA==,black,-0.0062738597632555,96.3386223508573,15355.55877660641
mn=QlA1MEM1NQ==:sn=NDMwMDY5MzcwMA==,cyan,-0.0050220344386555,92.03550827715635,18326.33953458036
mn=QlA1MEM1NQ==:sn=NDMwMDY5MzcwMA==,magenta,-0.0044163414500378,91.99320703161555,20830.184457505515
mn=QlA1MEM1NQ==:sn=NDMwMDY5MzcwMA==,yellow,-0.0041247456533441,90.19892484415048,21867.7543840846
mn=QlA1MEM1NQ==:sn=NDMwMDY5OTYwMA==,black,-0.0032572852099922,95.1478249858095,29210.77487900894
mn=QlA1MEM1NQ==:sn=NDMwMDY5OTYwMA==,cyan,-0.0050278890721973,76.29719538062692,15174.796874999953
mn=QlA1MEM1NQ==:sn=NDMwMDY5OTYwMA==,magenta,-0.0077888625398641,74.7810532341156,9601.023622047393
mn=QlA1MEM1NQ==:sn=NDMwMDY5OTYwMA==,yellow,-0.0080499819776524,74.31527093596063,9231.731343283407
mn=QlA1MEM1NQ==:sn=NDMwMDYwOTcwMA==,black,-0.0021268613537503,48.49991569213579,22803.51542737552
mn=QlA1MEM1NQ==:sn=NDMwMDYwOTcwMA==,cyan,-0.0019167346261895,67.57511639021479,35255.332411118616
