# Clustering memes with ALS item embeddings and KMeans

In [1]:
from datetime import datetime

import polars as pl
from implicit.als import AlternatingLeastSquares
from scipy import sparse
from sklearn.cluster import KMeans
from sklearn.preprocessing import LabelEncoder

In [2]:
# Format used to parse the text `sent_at` column: full month name, day of
# month without zero padding (`%-d`), year, then a 12-hour clock with AM/PM.
dtm_fmt: str = '%B %-d, %Y, %I:%M %p'

In [65]:
# Users whose daily feature snapshot is dated 2024-03-31 and who have sent
# fewer than 1000 memes; only their interactions are kept further down.
fresh_users = pl.read_parquet('user_features_daily.pq').filter(
    (pl.col('date_dtm') == datetime(2024, 3, 31))
    & (pl.col('n_memes_sent') < 1000)
)

In [66]:
# (user, meme) interactions sent before 2024-04-01, restricted to the users
# selected in `fresh_users`. `reaction_id` is collapsed to a signed value:
# 1 stays 1, any other id or a missing reaction becomes -1.
user_meme_df = (
    pl.read_csv('user_meme_reaction_240301_240413.csv')
    .with_columns(pl.col('sent_at').str.to_datetime(dtm_fmt))
    .filter(pl.col('sent_at') < datetime(2024, 4, 1))
    # Vectorised instead of a per-row Python lambda (map_elements). A null
    # reaction_id makes the comparison null, which `when` treats as false,
    # so nulls fall through to -1 without needing a fill_null sentinel.
    .with_columns(
        pl.when(pl.col('reaction_id') == 1)
        .then(1)
        .otherwise(-1)
        .cast(pl.Int64)
        .alias('reaction_id')
    )
    .join(fresh_users.select('user_id'), on='user_id', how='inner')
)

In [67]:
# (rows, columns) of the filtered interaction table, shown as cell output.
(user_meme_df.height, user_meme_df.width)

(165060, 6)

In [68]:
# One LabelEncoder per id column. Each encoder's `classes_` holds the sorted
# distinct ids, so encoded index i maps back to classes_[i].
le_user, le_meme = (
    LabelEncoder().fit(user_meme_df.get_column(col).unique().to_list())
    for col in ('user_id', 'meme_id')
)

In [69]:
# Interaction-matrix dimensions: one row per encoded user, one column per
# encoded meme.
n_users, n_memes = len(le_user.classes_), len(le_meme.classes_)

In [70]:
# COO triplets for the interaction matrix: the reaction value of each row,
# plus its encoded user (row) and meme (column) index.
reactions = user_meme_df['reaction_id'].to_numpy()
users = le_user.transform(user_meme_df['user_id'])
memes = le_meme.transform(user_meme_df['meme_id'])

In [71]:
# Sparse users x memes matrix in CSR form (rows are users, columns are
# memes), passed to the ALS model's fit below.
# NOTE(review): repeated (user, meme) pairs are summed, so a +1 and a -1
# for the same pair cancel to an explicitly stored 0. Confirm that is intended.
user_meme = sparse.csr_array(
    (reactions, (users, memes)),
    shape=(n_users, n_memes),
)

In [87]:
# Implicit-feedback ALS with 32 latent factors per user/meme.
# random_state pins the factor initialisation so the item embeddings, and the
# cluster assignment written to disk later, are reproducible between runs.
# NOTE(review): regularization=1000 is very strong. Most memes end up in a
# single KMeans cluster (see the value counts), which may be caused by this.
# Worth tuning.
model = AlternatingLeastSquares(
    factors=32,
    regularization=1000,
    alpha=1,
    iterations=15,
    calculate_training_loss=True,
    random_state=42,
)

In [88]:
# Train on the users x memes matrix (rows = users, columns = memes).
model.fit(user_items=user_meme)



  0%|          | 0/15 [00:00<?, ?it/s]

In [89]:
# Meme embeddings: row i belongs to the meme encoded as i by le_meme.
# implicit may build a GPU model when CUDA is available. Its factors are not
# NumPy arrays and KMeans cannot consume them, so convert to the CPU model
# whenever the model offers `to_cpu`.
item_embeds = (
    model.to_cpu() if hasattr(model, 'to_cpu') else model
).item_factors

In [91]:
# KMeans with 10 clusters over the meme embeddings.
# random_state makes the cluster ids reproducible. n_init is pinned because
# its default differs across scikit-learn versions (10 before 1.4, 'auto'
# from 1.4 on).
cluster = KMeans(n_clusters=10, n_init=10, random_state=42)

In [92]:
# Fit KMeans and read back the cluster label of every meme embedding.
clusters = cluster.fit(item_embeds).labels_

In [93]:
# Print up to 30 rows of polars tables. set_tbl_rows is a classmethod that
# changes global display state, so no Config instance is needed.
pl.Config.set_tbl_rows(30)

polars.config.Config

In [94]:
# Number of memes per cluster, smallest cluster first.
pl.Series(values=clusters).value_counts().sort(by='count')

Unnamed: 0_level_0,count
i32,u32
8,508
7,517
9,531
3,539
4,576
1,725
2,756
6,1275
5,2458
0,10592


In [95]:
# Persist the meme -> cluster assignment. le_meme.classes_[i] is the meme
# encoded as column i of user_meme (i.e. row i of item_embeds), so it pairs
# with clusters[i]. Unlike zip(), a length mismatch is reported instead of
# being silently truncated.
if len(le_meme.classes_) != len(clusters):
    raise ValueError(
        f'{len(le_meme.classes_)} meme ids vs {len(clusters)} cluster labels'
    )
pl.DataFrame(
    {
        # Plain Python scalars let polars infer its standard dtypes instead of
        # carrying numpy's int32 for cluster_id.
        # NOTE(review): confirm this matches the existing meme_clusters.pq schema.
        'meme_id': le_meme.classes_.tolist(),
        'cluster_id': clusters.tolist(),
    }
).write_parquet('meme_clusters.pq')