In [1]:
import polars as pl
from sklearn.cluster import KMeans

In [2]:
lf = pl.scan_parquet("data/processed/amazon-2023.parquet")

In [3]:
lf: pl.DataFrame = lf.group_by("parent_asin").agg([
    pl.len().alias("total_reviews"),
    pl.col("rating").mean().alias("mean_rating"),
    pl.col("brand").first().fill_null("Unknown").str.to_lowercase().alias("brand_id"),
    pl.col("main_category").first().fill_null("Unknown").str.to_lowercase().alias("category_id"),
])

In [4]:
lf: pl.LazyFrame = lf.with_columns([
    pl.col("brand_id").cast(pl.Categorical).to_physical().alias("brand_id"),
    pl.col("category_id").cast(pl.Categorical).to_physical().alias("category_id")
])

columns: list[str] = ["mean_rating", "total_reviews", "brand_id", "category_id", "parent_asin"]
lf = lf.select(columns)
lf=lf.drop("parent_asin")

In [5]:
df: pl.DataFrame = lf.collect(engine="streaming")
X: pl.DataFrame = df.to_numpy()

In [6]:
kmeans = KMeans(n_clusters=5, random_state=42, n_init=10)
labels = kmeans.fit_predict(X)

df: pl.DataFrame = df.with_columns(pl.Series(name="cluster", values=labels))

In [7]:
summary: pl.DataFrame = (
df.group_by("cluster")
    .agg([
        pl.len().alias("cluster_size"),
        pl.col("mean_rating").mean().alias("avg_mean_rating"),
        pl.col("total_reviews").mean().alias("avg_total_reviews"),
        pl.col("brand_id").mean().alias("avg_brand_id"),
        pl.col("category_id").mean().alias("avg_category_id"),
        pl.col("brand_id").mode().alias("top_brand_id"),
        pl.col("category_id").mode().alias("top_category_id"),
    ])
    .sort("cluster")
)

summary

cluster,cluster_size,avg_mean_rating,avg_total_reviews,avg_brand_id,avg_category_id,top_brand_id,top_category_id
i32,u32,f64,f64,f64,f64,list[u32],list[u32]
0,23317743,4.08529,16.288415,104126.549329,588703.963657,[31],[588849]
1,2466456,4.258269,9.041701,2303200.0,555780.659087,[1682507],[588839]
2,1895918,4.319212,6.382607,3901200.0,566696.411559,[3159577],[588839]
3,4397631,4.149755,12.030138,991741.20988,547324.914706,[638029],[588839]
4,3287580,4.00405,10.910735,120418.439924,52730.792819,[31],[31]
