In [28]:
import wandb
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score
import pandas as pd
from preprocess import preprocess

# wandb.login()

# wandb.init(project="ium-projekt", entity="ium-team",
#            group="production")


# Numerical feature columns: passed to `preprocess` and later used as the
# clustering input.
numerical_columns = [
    'popularity',
    'duration_ms',
    'danceability',
    'energy',
    'key',
    'loudness',
    'speechiness',
    'acousticness',
    'instrumentalness',
    'liveness',
    'valence',
    'tempo',
    'popularity_from_sessions',
    'release_date_numeric',
]

DATA_DIR = "../data_v2/data/"
data = preprocess(DATA_DIR, numerical_columns)

In [95]:
# Cluster tracks on the features listed in `numerical_columns`.
# `k` (number of clusters) is a tuning choice. KMEANS_SEED fixes the centroid
# initialisation so the labels -- and the playlists built from them later --
# are reproducible across kernel restarts.
k = 11
KMEANS_SEED = 42

data_numerical = data[numerical_columns]

kmeans = KMeans(n_clusters=k, n_init=10, random_state=KMEANS_SEED)
kmeans.fit(data_numerical)
labels = kmeans.labels_    # cluster index for each row of `data`
inertia = kmeans.inertia_  # sum of squared distances of samples to their closest centroid
# wandb.log({"num_clusters": k})
# wandb.log({"inertia": inertia})
# wandb.log({"silhouette_score": silhouette_score(data_numerical, labels)})

In [96]:
# Attach each track's KMeans cluster index as a `labels` column.
data = data.assign(labels=labels)

In [97]:
# Group tracks by their cluster label (one group per cluster).
grouped = data.groupby(by="labels")

In [98]:
# `duration_unscaled` is divided by milliseconds-per-hour, i.e. it is treated as
# a duration in milliseconds; the result is the track length in hours.
MS_PER_HOUR = 60 * 60 * 1000
data["duration_hours"] = data["duration_unscaled"] / MS_PER_HOUR

In [99]:
# Show the columns available after preprocessing and the derived columns.
data.keys()

Index(['track_id', 'track_name', 'popularity', 'duration_ms', 'explicit',
       'release_date', 'danceability', 'energy', 'key', 'loudness',
       'speechiness', 'acousticness', 'instrumentalness', 'liveness',
       'valence', 'tempo', 'popularity_from_sessions', 'artist_id',
       'artist_name', 'genres', 'release_date_numeric', 'duration_unscaled',
       'labels', 'duration_hours'],
      dtype='object')

In [100]:
# Materialise the groupby into {cluster label: DataFrame of that cluster's tracks}.
grouped_dataframes = dict(iter(grouped))

In [101]:
# Build one candidate playlist per cluster, then keep the most popular ones.
PLAYLIST_MIN_HOURS = 1  # a playlist is filled until it lasts at least this long
N_PLAYLISTS = 10        # number of playlists kept after ranking
PADDING_SEED = 42       # makes the random padding of short clusters reproducible


def take_until_duration(tracks, min_hours):
    """Return the leading rows of `tracks` up to and including the first row at
    which the running total of `duration_hours` reaches `min_hours`.

    All rows are returned if the total never reaches `min_hours`. Rows are
    sliced with `iloc`, so column dtypes and the index are preserved (unlike
    rebuilding a frame from `iterrows()` rows).
    """
    running_total = 0.0
    n_rows = 0
    for duration in tracks["duration_hours"]:
        n_rows += 1
        running_total += duration
        if running_total >= min_hours:
            break
    return tracks.iloc[:n_rows]


def pad_to_min_duration(group_df, pool_df, min_hours, random_state=None):
    """Extend `group_df` with random tracks from `pool_df` until it lasts at
    least `min_hours`.

    Only tracks whose `track_id` is not already in `group_df` are added, each
    at most once, in a random order controlled by `random_state`. `group_df`
    is returned unchanged when it is already long enough.

    Raises:
        ValueError: if `pool_df` has too few new tracks to reach `min_hours`.
    """
    missing_hours = min_hours - group_df["duration_hours"].sum()
    if missing_hours <= 0:
        return group_df

    # Shuffling the candidate pool once and taking from the front is the same
    # as repeatedly drawing one unused track at random, but O(n) instead of
    # re-filtering the whole dataset on every draw.
    candidates = (
        pool_df[~pool_df["track_id"].isin(group_df["track_id"])]
        .drop_duplicates(subset="track_id")
        .sample(frac=1, random_state=random_state)
    )
    padding = take_until_duration(candidates, missing_hours)
    if padding["duration_hours"].sum() < missing_hours:
        raise ValueError(
            f"not enough distinct tracks to pad a playlist to {min_hours} hour(s)"
        )
    return pd.concat([group_df, padding], ignore_index=True)


playlists = []
for label, group_df in grouped_dataframes.items():
    # If the cluster is too short, add random tracks that aren't already in it.
    group_df = pad_to_min_duration(
        group_df, data, PLAYLIST_MIN_HOURS, random_state=PADDING_SEED
    )

    # Take the most popular tracks until the playlist reaches the target length.
    # Sorting returns a new frame, so `grouped_dataframes` is left untouched.
    ranked = group_df.sort_values(by="popularity_from_sessions", ascending=False)
    playlist = take_until_duration(ranked, PLAYLIST_MIN_HOURS)
    playlists.append((label, playlist, playlist["popularity_from_sessions"].mean()))

# Keep the N_PLAYLISTS playlists with the highest mean popularity_from_sessions.
playlists = sorted(playlists, key=lambda p: p[2], reverse=True)[:N_PLAYLISTS]


In [102]:
# One summary line per selected playlist:
# cluster label, number of tracks, total hours, mean popularity_from_sessions.
for label, playlist, popularity in playlists:
    n_tracks = len(playlist)
    total_hours = playlist["duration_hours"].sum()
    print(label, n_tracks, total_hours, popularity)

7 13 1.0421927777777777 0.7964352720450283
9 16 1.0264475 0.78125
6 14 1.104151111111111 0.7735191637630663
3 16 1.042471388888889 0.760670731707317
10 16 1.011691111111111 0.759908536585366
1 17 1.0521749999999999 0.7496413199426112
4 15 1.0089238888888892 0.732520325203252
5 15 1.0001513888888889 0.7308943089430895
2 16 1.0161333333333333 0.7179878048780488
0 13 1.013041111111111 0.6763602251407129


In [106]:
# Track-id lists of the selected playlists, in the order of `playlists`.
# A comprehension keeps the unpacking names local, so IPython's `_`
# (last-output variable) and `playlist` are not rebound at top level --
# assigning `_` makes IPython stop updating `_`, `__` and `___`.
playlists_list = [
    tracks["track_id"].to_list() for _label, tracks, _popularity in playlists
]

In [107]:
# Display the track-id lists of the selected playlists (one inner list per playlist).
playlists_list

[['4V84hb0KLUwGgLoXfX0YMa',
  '1q8E1FfFuhd12c5JcJwPxQ',
  '0yzr0zrYvv5oyNoCMBFMAa',
  '14uvyd51Ha7FihKHlOtUig',
  '0eXz8pS25MoeUguNPR9VvD',
  '6vRkYTrWDzzVrZTqBJFR0u',
  '0X6coWomPnfGLXQ6gdi3cI',
  '0Qrb3L8JgreLBW8g4qyan9',
  '6FVYwnVrnAEIRnY3bHJb46',
  '2OXo0vKbu3ilgz4S5EOn60',
  '4Tp4gRuDo1sIMP6gH9LwuH',
  '2DeK0E3KxFzuDbwKQo1A49',
  '2EoOZnxNgtmZaD8uUmz2nD'],
 ['3NrfU7FvNF0w6mGcFVnlUj',
  '5LME7YULt0enp6UAB8VoDn',
  '15k1TDabqSEmyXOwMq9RM7',
  '2IT0T0EqPaUxasjl2o8J2G',
  '4odiyU3myG29Ld0wurMfE8',
  '6FI3RJ58Ztl0X1VtA6pVs9',
  '4tqBQD2QG0IYLU08rpkU6X',
  '2aadUB2b4hugbdIPU8Aypk',
  '5cTsXX5qwa6zmG80OCz4hR',
  '0i5el041vd6nxrGEU8QRxy',
  '7kyK2NSDfRE1612vdYuqIx',
  '3G0yz3DZn3lfraledmBCT0',
  '0CaBBQsaAiRHhiLmzi7ZRp',
  '1YP719l4JjsOmyU4PGv3c0',
  '5OOxMbmz5txzE78oZbGQhY',
  '7cdy4PbCdDZZNjyxoZyE0c'],
 ['4nehxVflg443IcjhsqpfEG',
  '2TDqa2yNF9qhSmP8gqcleE',
  '3La01jjk5XfvpdyOCMlV1A',
  '1Pctt56GoC6Xn8AxwFKGWE',
  '6Kkt27YmFyIFrcX3QXFi2o',
  '5xRP5iyVdGglqlY4Vcjhkx',
  '7dn6WQzScfRGpp7