In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os
import sys
sys.path.insert(0, '../')

In [3]:
import numpy as np
import pandas as pd

from dotenv import load_dotenv

# Pull the challenge credentials from ../upload.env into the process environment.
load_dotenv('../upload.env')

# The e-mail used at sign-up is the only value checked here: it must be set and non-empty.
EMAIL = os.getenv('EMAIL')
assert EMAIL

# The remaining values were sent by e-mail after registration; any of them may be
# None if the variable is missing from the environment.
BUCKET_NAME, PARTICIPANT_ID, AWS_ACCESS_KEY, AWS_SECRET_KEY = (
    os.getenv(var_name)
    for var_name in ('BUCKET_NAME', 'PARTICIPANT_ID', 'AWS_ACCESS_KEY', 'AWS_SECRET_KEY')
)

In [4]:
# Show every column when rendering wide frames (None removes the column limit).
pd.options.display.max_columns = None

In [5]:
from evaluation.EvalRSRunner import EvalRSRunner
from evaluation.EvalRSRunner import ChallengeDataset
from reclist.abstractions import RecModel

In [6]:
LIMIT = 10

In [7]:
class MyModel(RecModel):
    """Baseline recommender that suggests tracks drawn uniformly at random from ``items``."""

    def __init__(self, items: pd.DataFrame, top_k: int=100, **kwargs):
        """
        :param items: tracks table; its index values are what gets recommended.
        :param top_k: number of tracks returned for each user.
        :param kwargs: any extra settings (e.g. data-augmentation options); they are
            only echoed to stdout.
        """
        super().__init__()
        self.top_k = top_k
        self.items = items
        extras_msg = "Received additional arguments: {}".format(kwargs)
        print(extras_msg)

    def train(self, train_df: pd.DataFrame):
        """
        Placeholder training step.

        The random baseline learns nothing: it only prints the first training row
        and a completion message. A real model would fit its state here so that
        ``predict`` can use it afterwards.
        """
        first_row = train_df.head(1)
        print(first_row)
        print("Training completed!")

    def predict(self, user_ids: pd.DataFrame) -> pd.DataFrame:
        """
        Produce ``top_k`` recommendations for every row of ``user_ids`` in one batch.

        :param user_ids: frame with a ``user_id`` column, one row per target user.
        :return: frame indexed by ``user_id`` with columns ``'0'`` .. ``str(top_k - 1)``
            holding track ids sampled (with replacement) from the items index.
        """
        n_users = len(user_ids)
        n_slots = self.top_k

        # Draw every recommendation at once, then fold into one row per user.
        sampled_ids = self.items.sample(n=n_slots * n_users, replace=True).index.values
        per_user = sampled_ids.reshape(n_users, n_slots)

        id_column = user_ids[['user_id']].values
        table = np.concatenate((id_column, per_user), axis=1)

        header = ['user_id'] + [str(rank) for rank in range(n_slots)]
        return pd.DataFrame(table, columns=header).set_index('user_id')

In [8]:
dataset = ChallengeDataset(force_download=False)

LFM dataset already downloaded. Skipping download.
Loading dataset.
Generating folds.
Generating dataset hashes.


In [9]:
train, test = dataset.get_sample_train_test()

In [142]:
my_model = MyModel(
    items=dataset.df_tracks,
    my_custom_argument='my_custom_argument' 
)

Received additional arguments: {'my_custom_argument': 'my_custom_argument'}


In [10]:
class CustomRunner(EvalRSRunner):
    """Thin subclass of ``EvalRSRunner`` that delegates everything to the parent.

    It exists as an extension point: override the methods below to customise
    how the evaluation is run.
    """

    def __init__(self,
                 dataset: ChallengeDataset,
                 email: str = None,
                 participant_id: str = None,
                 aws_access_key_id: str = None,
                 aws_secret_access_key: str = None,
                 bucket_name: str = None):
        """Forward the dataset and the upload credentials to ``EvalRSRunner`` unchanged."""
        super().__init__(
                 dataset,
                 email,
                 participant_id,
                 aws_access_key_id,
                 aws_secret_access_key,
                 bucket_name)

    def evaluate(
        self,
        model,
        seed: int = None,
        upload: bool = True,
        limit: int = 0,
        custom_RecList = None,
        debug=True,
        **kwargs):
        """
        Run the parent evaluation and return its result.

        All arguments are passed through to ``EvalRSRunner.evaluate`` in the same
        positional order. Extra keyword arguments are forwarded as well instead of
        being silently discarded.

        NOTE(review): forwarding ``**kwargs`` assumes the parent ``evaluate``
        accepts additional keyword arguments — confirm against the installed
        ``EvalRSRunner``. Calls that pass no extra keywords are unaffected.
        """
        return super().evaluate(model,
                seed,
                upload,
                limit,
                custom_RecList,
                debug,
                **kwargs)

In [11]:
custom_runner = CustomRunner(
    dataset=dataset,
    aws_access_key_id=AWS_ACCESS_KEY,
    aws_secret_access_key=AWS_SECRET_KEY,
    participant_id=PARTICIPANT_ID,
    bucket_name=BUCKET_NAME,
    email=EMAIL
    )

In [12]:
r = custom_runner.evaluate(model=my_model, limit=LIMIT)

NameError: name 'my_model' is not defined

## Create complete dataset

In [13]:
dataset.df_tracks

Unnamed: 0_level_0,track,artist_id,artist,albums_id,albums
track_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
1,A Matter of Time,3,Foo Fighters,"[1, 67655, 605875, 682938, 889936, 2797722, 34...","['Wasting Light', 'Wasting Light (Deluxe Versi..."
2,Hangar 18,1,Megadeth,"[2, 693, 821, 2113, 12071, 13309, 17004, 83531...","['Rust In Peace', 'Countdown To Extinction', '..."
3,Up the Downstair,4,Porcupine Tree,"[84, 116, 13570, 283314, 302086, 303154, 35956...","['Up the Downstair', 'Coma Divine', 'Up The Do..."
5,Mixtaped,5,No-Man,"[5, 7654681]","['Schoolyard Ghosts', 'Schoolyard Ghosts Disc 1']"
6,Sounds that I Hear,6,Airbag,"[6, 340, 1139848]","['Sounds That I Hear', 'Identity', nan]"
...,...,...,...,...,...
32244852,Cotton Eye Joe,10641,Rednex,"[93769, 941156, 10112520]","[""100 Hits Of The '90s"", 'Sex & Violins', 'Q M..."
32250955,Hot Stuff,2949,Donna Summer,"[120851, 127382, 217400, 256358, 712567]","['On the Radio (Greatest Hits)', 'Endless Summ..."
32253071,Candyman,326,Christina Aguilera,"[36723, 37494]","['Back to Basics', 'Back to Basics']"
32275309,One (Radio Edit),2322,Swedish House Mafia,[32033],['One (Your Name) [feat. Pharrell]']


## Create Complete Dataset

In [15]:
train.head()

Unnamed: 0,user_id,artist_id,album_id,track_id,timestamp,user_track_count
43142,7438644,1312,2944,8507,1354982620,1
43143,7438644,1312,2944,8508,1354982427,1
43144,7438644,1312,2944,8509,1354982287,1
43145,7438644,1312,2944,8510,1354981997,1
43149,7438644,1313,2946,8514,1354981396,1


In [16]:
dataset.df_users

Unnamed: 0_level_0,country,age,gender,playcount,registered_unixtime,country_id,gender_id,novelty_artist_avg_month,novelty_artist_avg_6months,novelty_artist_avg_year,mainstreaminess_avg_month,mainstreaminess_avg_6months,mainstreaminess_avg_year,mainstreaminess_global,cnt_listeningevents,cnt_distinct_tracks,cnt_distinct_artists,cnt_listeningevents_per_week,relative_le_per_weekday1,relative_le_per_weekday2,relative_le_per_weekday3,relative_le_per_weekday4,relative_le_per_weekday5,relative_le_per_weekday6,relative_le_per_weekday7,relative_le_per_hour0,relative_le_per_hour1,relative_le_per_hour2,relative_le_per_hour3,relative_le_per_hour4,relative_le_per_hour5,relative_le_per_hour6,relative_le_per_hour7,relative_le_per_hour8,relative_le_per_hour9,relative_le_per_hour10,relative_le_per_hour11,relative_le_per_hour12,relative_le_per_hour13,relative_le_per_hour14,relative_le_per_hour15,relative_le_per_hour16,relative_le_per_hour17,relative_le_per_hour18,relative_le_per_hour19,relative_le_per_hour20,relative_le_per_hour21,relative_le_per_hour22,relative_le_per_hour23
user_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1,Unnamed: 23_level_1,Unnamed: 24_level_1,Unnamed: 25_level_1,Unnamed: 26_level_1,Unnamed: 27_level_1,Unnamed: 28_level_1,Unnamed: 29_level_1,Unnamed: 30_level_1,Unnamed: 31_level_1,Unnamed: 32_level_1,Unnamed: 33_level_1,Unnamed: 34_level_1,Unnamed: 35_level_1,Unnamed: 36_level_1,Unnamed: 37_level_1,Unnamed: 38_level_1,Unnamed: 39_level_1,Unnamed: 40_level_1,Unnamed: 41_level_1,Unnamed: 42_level_1,Unnamed: 43_level_1,Unnamed: 44_level_1,Unnamed: 45_level_1,Unnamed: 46_level_1,Unnamed: 47_level_1,Unnamed: 48_level_1,Unnamed: 49_level_1
384,UK,35,m,42139,1035849600,0,0,0.276629,0.044439,0.309429,0.024655,0.367343,0.000000,0.124903,17198.0,3601.0,330.0,196.2290,0.1828,0.1513,0.1361,0.1252,0.1318,0.1286,0.1443,0.0850,0.1169,0.1101,0.0984,0.0781,0.0500,0.0145,0.0051,0.0006,0.0012,0.0079,0.0051,0.0091,0.0209,0.0279,0.0410,0.0811,0.1506,0.0377,0.0298,0.0132,0.0007,0.0001,0.0151
1206,,-1,n,33103,1035849600,-1,1,0.437224,0.109671,0.513787,0.181991,0.391304,0.000000,0.343855,17967.0,10990.0,1693.0,265.1490,0.1679,0.1850,0.1815,0.1640,0.1542,0.0700,0.0773,0.0539,0.0371,0.0321,0.0245,0.0322,0.0368,0.0366,0.0460,0.0257,0.0041,0.0014,,,0.0003,0.0016,0.0115,0.0276,0.0662,0.0790,0.0838,0.0995,0.1195,0.0968,0.0839
2622,,-1,,2030,1037404800,-1,-1,0.604828,0.043923,0.698983,0.052310,0.780064,0.079669,0.245980,3939.0,3084.0,1176.0,22.8009,0.1813,0.1318,0.1181,0.1442,0.1775,0.1084,0.1389,0.0074,0.0053,0.0036,0.0005,0.0003,0.0008,0.0003,0.0003,0.0008,0.0025,0.0279,0.0678,0.0675,0.1285,0.1320,0.0493,0.0317,0.0470,0.0658,0.1127,0.1112,0.0680,0.0437,0.0254
2732,,-1,n,147,1037577600,-1,1,0.756973,0.020071,0.882801,0.005092,0.886364,0.032614,0.077512,234.0,202.0,112.0,1.1455,0.0214,0.0726,0.0940,0.2778,0.0171,0.2564,0.2607,,,,,,,,,0.0128,0.0342,0.0342,0.1410,0.1154,0.1368,0.0598,0.0726,0.0171,0.0342,0.0769,0.1453,0.0470,0.0513,0.0085,0.0128
3653,UK,31,m,18504,1041033600,0,0,0.380005,0.045207,0.424411,0.042821,0.491756,0.077731,0.207567,18238.0,9839.0,2151.0,41.8051,0.1164,0.1545,0.1316,0.1380,0.1580,0.1745,0.1270,0.0612,0.0392,0.0249,0.0250,0.0169,0.0105,0.0088,0.0063,0.0086,0.0206,0.0117,0.0198,0.0434,0.0495,0.0541,0.0518,0.0564,0.0554,0.0849,0.0954,0.0833,0.0657,0.0471,0.0595
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
50871714,BY,19,f,569,1342728447,55,2,0.417087,0.035417,0.549888,0.084304,0.662491,0.075544,0.160736,26932.0,6837.0,1154.0,244.3810,0.1765,0.1437,0.1091,0.1272,0.1750,0.1444,0.1241,0.0568,0.0440,0.0292,0.0197,0.0138,0.0111,0.0118,0.0140,0.0176,0.0265,0.0314,0.0420,0.0554,0.0567,0.0542,0.0529,0.0575,0.0558,0.0578,0.0573,0.0515,0.0551,0.0639,0.0641
50900118,RU,19,m,96,1342887305,12,0,0.566328,0.032624,0.680317,0.049010,0.827854,0.103613,0.145418,7174.0,3970.0,547.0,76.4347,0.1398,0.1280,0.1310,0.1469,0.1450,0.1415,0.1678,0.0375,0.0162,0.0057,0.0163,0.0114,0.0121,0.0107,0.0110,0.0360,0.0407,0.0514,0.0489,0.0473,0.0415,0.0397,0.0431,0.0587,0.0733,0.0843,0.0764,0.0728,0.0643,0.0555,0.0452
50931921,,-1,m,221,1343064308,-1,0,0.483171,0.032946,0.351648,0.038075,0.382353,0.012505,0.136102,998.0,676.0,155.0,12.9569,0.2695,0.0341,0.0932,0.1453,0.0611,0.0992,0.2976,,0.0010,,,,,,,,0.0130,0.0190,0.0641,0.0832,0.0681,0.0561,0.1172,0.1723,0.1062,0.0892,0.1092,0.0621,0.0120,0.0210,0.0060
50933471,,-1,n,49,1343071062,-1,1,0.559632,0.026570,0.599199,0.040033,0.656787,0.039335,0.175908,4152.0,2845.0,475.0,38.5342,0.1419,0.1202,0.1525,0.1185,0.1373,0.1168,0.2129,0.0539,0.0349,0.0248,0.0084,0.0007,0.0010,0.0022,,,0.0026,0.0079,0.0282,0.0537,0.0354,0.0417,0.0347,0.0294,0.0236,0.0508,0.0763,0.1086,0.1419,0.1387,0.1004


In [85]:
dataset.df_users

Unnamed: 0_level_0,country,age,gender,playcount,registered_unixtime,country_id,gender_id,novelty_artist_avg_month,novelty_artist_avg_6months,novelty_artist_avg_year,mainstreaminess_avg_month,mainstreaminess_avg_6months,mainstreaminess_avg_year,mainstreaminess_global,cnt_listeningevents,cnt_distinct_tracks,cnt_distinct_artists,cnt_listeningevents_per_week,relative_le_per_weekday1,relative_le_per_weekday2,relative_le_per_weekday3,relative_le_per_weekday4,relative_le_per_weekday5,relative_le_per_weekday6,relative_le_per_weekday7,relative_le_per_hour0,relative_le_per_hour1,relative_le_per_hour2,relative_le_per_hour3,relative_le_per_hour4,relative_le_per_hour5,relative_le_per_hour6,relative_le_per_hour7,relative_le_per_hour8,relative_le_per_hour9,relative_le_per_hour10,relative_le_per_hour11,relative_le_per_hour12,relative_le_per_hour13,relative_le_per_hour14,relative_le_per_hour15,relative_le_per_hour16,relative_le_per_hour17,relative_le_per_hour18,relative_le_per_hour19,relative_le_per_hour20,relative_le_per_hour21,relative_le_per_hour22,relative_le_per_hour23
user_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1,Unnamed: 23_level_1,Unnamed: 24_level_1,Unnamed: 25_level_1,Unnamed: 26_level_1,Unnamed: 27_level_1,Unnamed: 28_level_1,Unnamed: 29_level_1,Unnamed: 30_level_1,Unnamed: 31_level_1,Unnamed: 32_level_1,Unnamed: 33_level_1,Unnamed: 34_level_1,Unnamed: 35_level_1,Unnamed: 36_level_1,Unnamed: 37_level_1,Unnamed: 38_level_1,Unnamed: 39_level_1,Unnamed: 40_level_1,Unnamed: 41_level_1,Unnamed: 42_level_1,Unnamed: 43_level_1,Unnamed: 44_level_1,Unnamed: 45_level_1,Unnamed: 46_level_1,Unnamed: 47_level_1,Unnamed: 48_level_1,Unnamed: 49_level_1
384,UK,35,m,42139,1035849600,0,0,0.276629,0.044439,0.309429,0.024655,0.367343,0.000000,0.124903,17198.0,3601.0,330.0,196.2290,0.1828,0.1513,0.1361,0.1252,0.1318,0.1286,0.1443,0.0850,0.1169,0.1101,0.0984,0.0781,0.0500,0.0145,0.0051,0.0006,0.0012,0.0079,0.0051,0.0091,0.0209,0.0279,0.0410,0.0811,0.1506,0.0377,0.0298,0.0132,0.0007,0.0001,0.0151
1206,,-1,n,33103,1035849600,-1,1,0.437224,0.109671,0.513787,0.181991,0.391304,0.000000,0.343855,17967.0,10990.0,1693.0,265.1490,0.1679,0.1850,0.1815,0.1640,0.1542,0.0700,0.0773,0.0539,0.0371,0.0321,0.0245,0.0322,0.0368,0.0366,0.0460,0.0257,0.0041,0.0014,,,0.0003,0.0016,0.0115,0.0276,0.0662,0.0790,0.0838,0.0995,0.1195,0.0968,0.0839
2622,,-1,,2030,1037404800,-1,-1,0.604828,0.043923,0.698983,0.052310,0.780064,0.079669,0.245980,3939.0,3084.0,1176.0,22.8009,0.1813,0.1318,0.1181,0.1442,0.1775,0.1084,0.1389,0.0074,0.0053,0.0036,0.0005,0.0003,0.0008,0.0003,0.0003,0.0008,0.0025,0.0279,0.0678,0.0675,0.1285,0.1320,0.0493,0.0317,0.0470,0.0658,0.1127,0.1112,0.0680,0.0437,0.0254
2732,,-1,n,147,1037577600,-1,1,0.756973,0.020071,0.882801,0.005092,0.886364,0.032614,0.077512,234.0,202.0,112.0,1.1455,0.0214,0.0726,0.0940,0.2778,0.0171,0.2564,0.2607,,,,,,,,,0.0128,0.0342,0.0342,0.1410,0.1154,0.1368,0.0598,0.0726,0.0171,0.0342,0.0769,0.1453,0.0470,0.0513,0.0085,0.0128
3653,UK,31,m,18504,1041033600,0,0,0.380005,0.045207,0.424411,0.042821,0.491756,0.077731,0.207567,18238.0,9839.0,2151.0,41.8051,0.1164,0.1545,0.1316,0.1380,0.1580,0.1745,0.1270,0.0612,0.0392,0.0249,0.0250,0.0169,0.0105,0.0088,0.0063,0.0086,0.0206,0.0117,0.0198,0.0434,0.0495,0.0541,0.0518,0.0564,0.0554,0.0849,0.0954,0.0833,0.0657,0.0471,0.0595
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
50871714,BY,19,f,569,1342728447,55,2,0.417087,0.035417,0.549888,0.084304,0.662491,0.075544,0.160736,26932.0,6837.0,1154.0,244.3810,0.1765,0.1437,0.1091,0.1272,0.1750,0.1444,0.1241,0.0568,0.0440,0.0292,0.0197,0.0138,0.0111,0.0118,0.0140,0.0176,0.0265,0.0314,0.0420,0.0554,0.0567,0.0542,0.0529,0.0575,0.0558,0.0578,0.0573,0.0515,0.0551,0.0639,0.0641
50900118,RU,19,m,96,1342887305,12,0,0.566328,0.032624,0.680317,0.049010,0.827854,0.103613,0.145418,7174.0,3970.0,547.0,76.4347,0.1398,0.1280,0.1310,0.1469,0.1450,0.1415,0.1678,0.0375,0.0162,0.0057,0.0163,0.0114,0.0121,0.0107,0.0110,0.0360,0.0407,0.0514,0.0489,0.0473,0.0415,0.0397,0.0431,0.0587,0.0733,0.0843,0.0764,0.0728,0.0643,0.0555,0.0452
50931921,,-1,m,221,1343064308,-1,0,0.483171,0.032946,0.351648,0.038075,0.382353,0.012505,0.136102,998.0,676.0,155.0,12.9569,0.2695,0.0341,0.0932,0.1453,0.0611,0.0992,0.2976,,0.0010,,,,,,,,0.0130,0.0190,0.0641,0.0832,0.0681,0.0561,0.1172,0.1723,0.1062,0.0892,0.1092,0.0621,0.0120,0.0210,0.0060
50933471,,-1,n,49,1343071062,-1,1,0.559632,0.026570,0.599199,0.040033,0.656787,0.039335,0.175908,4152.0,2845.0,475.0,38.5342,0.1419,0.1202,0.1525,0.1185,0.1373,0.1168,0.2129,0.0539,0.0349,0.0248,0.0084,0.0007,0.0010,0.0022,,,0.0026,0.0079,0.0282,0.0537,0.0354,0.0417,0.0347,0.0294,0.0236,0.0508,0.0763,0.1086,0.1419,0.1387,0.1004


In [91]:
complete_train_df = pd.merge(train, dataset.df_users.reset_index(), how='left', on='user_id')

X_cols = [c for c in complete_train_df.columns if c.startswith('relative')]

In [120]:
# User-profile feature columns: the novelty and mainstreaminess statistics plus
# age and gender id.
# A previous 'relative_per_week' prefix was dropped because it matched no column
# (the weekday shares are named 'relative_le_per_weekday*'); add that prefix to the
# tuple below if the weekday shares should become features.
X_cols = [c for c in complete_train_df.columns
          if c.startswith(('mainstreaminess', 'novelty'))] + ['age', 'gender_id']



In [133]:
X_cols

['novelty_artist_avg_month',
 'novelty_artist_avg_6months',
 'novelty_artist_avg_year',
 'mainstreaminess_avg_month',
 'mainstreaminess_avg_6months',
 'mainstreaminess_avg_year',
 'mainstreaminess_global',
 'age',
 'gender_id']

In [121]:
dataset.df_users.index.nunique()

119555

In [122]:
complete_train_df['user_id'].nunique()

29730

In [123]:
complete_train_df[X_cols].drop_duplicates().shape

(29730, 9)

## DBSCAN user clustering

In [17]:
from sklearn.cluster import DBSCAN

In [44]:
class DBSCANModel(RecModel):
    """
    Cluster-popularity recommender.

    Users are clustered with DBSCAN on the profile columns in ``self.X_cols``;
    each user is then recommended tracks sampled in proportion to how often
    they were played by training users of the same cluster.
    """

    def __init__(self, items: pd.DataFrame,
                 users: pd.DataFrame, top_k: int=100, **kwargs):
        """
        :param items: tracks table; its index is only used as a last-resort
            candidate pool when the training data offers no tracks at all.
        :param users: users table indexed by ``user_id`` (assumed unique) that
            contains every column listed in ``self.X_cols``.
        :param top_k: number of tracks to recommend per user.
        :param kwargs: extra arguments; they are only printed.
        """
        super(DBSCANModel, self).__init__()
        self.items = items
        self.users = users
        self.top_k = top_k
        self.X_cols = ['novelty_artist_avg_month',
                 'novelty_artist_avg_6months',
                 'novelty_artist_avg_year',
                 'mainstreaminess_avg_month',
                 'mainstreaminess_avg_6months',
                 'mainstreaminess_avg_year',
                 'mainstreaminess_global',
                 'age',
                 'gender_id'
                 ]
        print("Received additional arguments: {}".format(kwargs))
        return

    def train(self, train_df: pd.DataFrame):
        """
        Cluster all users and attach the cluster label to the training events.

        After this call ``self.users`` carries a ``pred_cluster`` column and
        ``self.complete_df`` is ``train_df`` left-joined with that label on
        ``user_id``.

        :param train_df: listening events with ``user_id``, ``track_id`` and
            ``user_track_count`` columns.
        """
        # Missing feature values are replaced by the column mean before clustering.
        # NOTE(review): features are fed to DBSCAN unscaled with its default
        # parameters — confirm this is intended, since 'age' is on a different
        # scale from the other columns.
        features = self.users[self.X_cols]
        features = features.fillna(features.mean())

        self.clustering = DBSCAN()
        labels = self.clustering.fit_predict(features)

        # assign() returns a new frame, so the DataFrame handed in by the caller
        # is not modified as a side effect of training.
        self.users = self.users.assign(pred_cluster=labels)

        self.complete_df = pd.merge(train_df, self.users.reset_index()[['user_id', 'pred_cluster']], on='user_id', how='left')

        print("Training completed!")
        return

    @staticmethod
    def _sample_tracks(pool, k):
        """Draw ``k`` track ids from ``pool = (track_ids, weights)``; None if the pool is unusable."""
        track_ids, weights = pool
        total = weights.sum() if len(weights) else 0.0
        if not total > 0:
            return None
        probs = weights / total
        # Sampling without replacement needs at least k tracks with non-zero
        # probability; smaller pools fall back to sampling with replacement.
        distinct = np.count_nonzero(probs) >= k
        return np.random.choice(track_ids, size=k, replace=not distinct, p=probs)

    def predict(self, user_ids: pd.DataFrame) -> pd.DataFrame:
        """
        Recommend ``top_k`` tracks for every user in ``user_ids``.

        Tracks are sampled with probability proportional to their summed
        ``user_track_count`` inside the user's cluster. Users that are unknown,
        or whose cluster has no training tracks, are served from the overall
        track popularity instead; if the training data has no tracks at all,
        tracks are drawn uniformly from the items index.

        :param user_ids: frame with a ``user_id`` column, one row per target user.
        :return: frame indexed by ``user_id`` with columns ``'0'`` .. ``str(top_k - 1)``.
        """
        k = self.top_k
        rank_cols = [str(i) for i in range(k)]

        # Aggregate play counts once per call instead of re-filtering for each user.
        cluster_counts = self.complete_df.groupby(['pred_cluster', 'track_id'])['user_track_count'].sum()
        cluster_pools = {
            cluster: (counts.index.get_level_values('track_id').to_numpy(),
                      counts.to_numpy(dtype=float))
            for cluster, counts in cluster_counts.groupby(level='pred_cluster')
        }
        overall_counts = self.complete_df.groupby('track_id')['user_track_count'].sum()
        global_pool = (overall_counts.index.to_numpy(), overall_counts.to_numpy(dtype=float))

        user_cluster = self.users['pred_cluster']
        users = user_ids['user_id'].tolist()

        rows = []
        for user in users:
            # .get() yields None for unknown users, which is never a pool key,
            # so those users fall through to the global pool.
            pool = cluster_pools.get(user_cluster.get(user), global_pool)
            picks = self._sample_tracks(pool, k)
            if picks is None:
                picks = self._sample_tracks(global_pool, k)
            if picks is None:
                catalogue = self.items.index.to_numpy()
                picks = np.random.choice(catalogue, size=k, replace=len(catalogue) < k)
            rows.append(picks)

        return pd.DataFrame(rows, index=pd.Index(users, name='user_id'), columns=rank_cols)

In [45]:
dbscan_runner = CustomRunner(
    dataset=dataset,
    aws_access_key_id=AWS_ACCESS_KEY,
    aws_secret_access_key=AWS_SECRET_KEY,
    participant_id=PARTICIPANT_ID,
    bucket_name=BUCKET_NAME,
    email=EMAIL
    )

In [46]:
my_model = DBSCANModel(
    items=dataset.df_tracks,
    users=dataset.df_users
)

Received additional arguments: {}


In [49]:
# r = dbscan_runner.evaluate(model=my_model, limit=0)