In [None]:
import polars as pl
from sklearn.cluster import KMeans

In [64]:
lf: pl.DataFrame = pl.scan_parquet("data/processed/amazon-2023.parquet")

In [65]:
lf: pl.DataFrame = lf.group_by("parent_asin").agg([
    pl.len().alias("total_reviews"),
    pl.col("rating").mean().alias("mean_rating"),
    pl.col("brand").first().fill_null("Unknown").str.to_lowercase().alias("brand_id"),
    pl.col("main_category").first().fill_null("Unknown").str.to_lowercase().alias("category_id"),
])

In [66]:
lf: pl.LazyFrame = lf.with_columns([
    pl.col("brand_id").cast(pl.Categorical).to_physical().alias("brand_id"),
    pl.col("category_id").cast(pl.Categorical).to_physical().alias("category_id")
])

columns: list[str] = ["mean_rating", "total_reviews", "brand_id", "category_id", "parent_asin"]
lf = lf.select(columns)

In [67]:
df: pl.DataFrame = lf.collect()
X: pl.DataFrame = df.drop("parent_asin").to_numpy()

In [None]:
kmeans = KMeans(n_clusters=5, random_state=42, n_init=10)
labels = kmeans.fit_predict(X)

df: pl.DataFrame = df.with_columns(pl.Series(name="cluster", values=labels))

In [70]:
summary: pl.DataFrame = (
df.group_by("cluster")
    .agg([
        pl.len().alias("cluster_size"),
        pl.col("mean_rating").mean().alias("avg_mean_rating"),
        pl.col("total_reviews").mean().alias("avg_total_reviews"),
        pl.col("brand_id").mean().alias("avg_brand_id"),
        pl.col("category_id").mean().alias("avg_category_id"),
        pl.col("brand_id").mode().alias("top_brand_id"),
        pl.col("category_id").mode().alias("top_category_id"),
    ])
    .sort("cluster")
)

summary

cluster,cluster_size,avg_mean_rating,avg_total_reviews,avg_brand_id,avg_category_id,top_brand_id,top_category_id
i32,u32,f64,f64,f64,f64,list[u32],list[u32]
0,173759,4.237223,2.963662,125641.190033,6.88716,[104360],[5]
1,701879,4.207828,4.109774,10434.166891,6.749344,[21],[7]
2,127859,4.255761,2.449628,205077.01346,7.223582,[172250],[5]
3,274584,4.2127,3.42583,60019.671882,6.58373,[50580],[7]
4,103789,4.301812,2.157396,294546.480263,7.538265,[262774],[5]
