In [2]:
import pandas as pd
import tensorflow as tf
from sklearn.model_selection import train_test_split
import numpy as np
import tensorflow_recommenders as tfrs


In [7]:
songs_df = pd.read_csv(r'data/original/songs.csv')
ratings_df = pd.read_csv(r'data/original/ratings.csv')

In [6]:
songs_df.columns

Index(['song id', 'song title', 'release date', 'video release date',
       'IMDb URL', 'unknown', 'Pop', 'Rock', 'HipHop', 'Rap', 'Electronic',
       'Country', 'Dance', 'Jazz', 'Blues', 'Reggae', 'Classical', 'R&B',
       'Funk', 'Metal', 'Indie', 'Soul', 'WorldMusic', 'Western'],
      dtype='object')

In [4]:
ratings_df.dropna(inplace=True)

In [10]:
ratings_df.columns

Index(['user_id', 'song_id', 'rating', 'unix_timestamp'], dtype='object')

In [6]:
cols = [col.lower().replace(' ', '_').replace('[', '').replace(']', '') for col in songs_df.columns]
songs_df.columns = cols
songs_df.head(2)

Unnamed: 0,song_id,song_title,release_date,video_release_date,imdb_url,unknown,pop,rock,hiphop,rap,...,blues,reggae,classical,r&b,funk,metal,indie,soul,worldmusic,western
0,1,Call of the Mastodon,1-Jan-95,,http://us.imdb.com/M/title-exact?Toy%20Story%2...,0,0,0,1,1,...,0,0,0,0,0,0,0,0,0,0
1,2,Fear Itself,1-Jan-95,,http://us.imdb.com/M/title-exact?GoldenEye%20(...,0,1,1,0,0,...,0,0,0,0,0,0,0,1,0,0


In [7]:
#converting ids
ratings_df['user_id'] = 'id_' + ratings_df['user_id'].astype(str)
ratings_df['song_id'] = 'id_' + ratings_df['song_id'].astype(str)
songs_df['song_id'] = 'id_' + songs_df['song_id'].astype(str)

In [8]:

#deciding how to divide high activity and low activity users

user_counts = ratings_df['user_id'].value_counts()

# Filter user IDs with counts greater than 100
selected_user_ids = user_counts[user_counts <50]
for user_id, count in selected_user_ids.items():
    print(f"User ID: {user_id}, Count: {count}")
print(len(selected_user_ids))

#decided to keep low count as 50

User ID: id_483, Count: 49
User ID: id_893, Count: 49
User ID: id_546, Count: 49
User ID: id_190, Count: 49
User ID: id_100, Count: 49
User ID: id_8, Count: 49
User ID: id_161, Count: 48
User ID: id_633, Count: 48
User ID: id_411, Count: 48
User ID: id_227, Count: 48
User ID: id_507, Count: 48
User ID: id_365, Count: 48
User ID: id_647, Count: 48
User ID: id_348, Count: 48
User ID: id_81, Count: 48
User ID: id_367, Count: 48
User ID: id_787, Count: 47
User ID: id_839, Count: 47
User ID: id_395, Count: 47
User ID: id_257, Count: 47
User ID: id_753, Count: 47
User ID: id_470, Count: 47
User ID: id_490, Count: 47
User ID: id_37, Count: 47
User ID: id_539, Count: 46
User ID: id_248, Count: 46
User ID: id_187, Count: 46
User ID: id_481, Count: 46
User ID: id_52, Count: 46
User ID: id_96, Count: 46
User ID: id_793, Count: 45
User ID: id_780, Count: 45
User ID: id_428, Count: 45
User ID: id_79, Count: 45
User ID: id_492, Count: 45
User ID: id_135, Count: 45
User ID: id_123, Count: 44
User ID:

In [9]:
heavy_user_df = ratings_df['user_id'].value_counts()[:-446].index
light_user_df = ratings_df['user_id'].value_counts()[-446:].index
heavy_user_ratings_df = ratings_df.query('user_id in @heavy_user_df')
light_user_ratings_df = ratings_df.query('user_id in @light_user_df')

In [10]:
light_user_ratings_df

Unnamed: 0,user_id,song_id,rating,unix_timestamp
314,id_3,id_181,4,889237482
315,id_3,id_258,2,889237026
316,id_3,id_260,4,889237455
317,id_3,id_264,2,889237297
318,id_3,id_268,3,889236961
...,...,...,...,...
90338,id_941,id_408,5,875048886
90339,id_941,id_455,4,875049038
90340,id_941,id_763,3,875048996
90341,id_941,id_919,5,875048887


In [11]:
heavy_user_ratings_df =heavy_user_ratings_df.sort_values(['user_id'])
heavy_user_ratings_df

Unnamed: 0,user_id,song_id,rating,unix_timestamp
0,id_1,id_1,5,874965758
166,id_1,id_174,5,875073198
167,id_1,id_175,5,875072547
168,id_1,id_176,5,876892468
169,id_1,id_177,5,876892701
...,...,...,...,...
9877,id_99,id_201,3,885680348
9876,id_99,id_196,4,885680578
9875,id_99,id_182,4,886518810
9888,id_99,id_265,3,885679833


In [12]:
train_df = pd.DataFrame()
test_df = pd.DataFrame()
val_df = pd.DataFrame()

In [13]:
trains =[]
vals=[]
tests=[]
grouped = heavy_user_ratings_df.groupby('user_id')
for _,group in grouped:
    train, remaining = train_test_split(group, test_size=0.2, random_state=42)
    val, test = train_test_split(remaining, test_size=0.5, random_state = 42)
    trains.append(train)
    tests.append(test)
    vals.append(val)
train_df = pd.concat(trains,axis=0)
test_df = pd.concat(tests,axis=0)
val_df = pd.concat(vals,axis=0)

In [14]:
val_df.shape

(7931, 4)

In [15]:
#preprocessing genre cols
songs_df.columns

Index(['song_id', 'song_title', 'release_date', 'video_release_date',
       'imdb_url', 'unknown', 'pop', 'rock', 'hiphop', 'rap', 'electronic',
       'country', 'dance', 'jazz', 'blues', 'reggae', 'classical', 'r&b',
       'funk', 'metal', 'indie', 'soul', 'worldmusic', 'western'],
      dtype='object')

In [16]:
genre_cols = ['unknown', 'pop', 'rock', 'hiphop', 'rap', 'electronic',
       'country', 'dance', 'jazz', 'blues', 'reggae', 'classical', 'r&b',
       'funk', 'metal', 'indie', 'soul', 'worldmusic', 'western']
songs_df['song_genre'] = songs_df[genre_cols].apply(
    lambda x: ','.join([col for col,val in zip(genre_cols,x) if val ==1]),
    axis=1
)
songs_df

Unnamed: 0,song_id,song_title,release_date,video_release_date,imdb_url,unknown,pop,rock,hiphop,rap,...,reggae,classical,r&b,funk,metal,indie,soul,worldmusic,western,song_genre
0,id_1,Call of the Mastodon,1-Jan-95,,http://us.imdb.com/M/title-exact?Toy%20Story%2...,0,0,0,1,1,...,0,0,0,0,0,0,0,0,0,"hiphop,rap,electronic"
1,id_2,Fear Itself,1-Jan-95,,http://us.imdb.com/M/title-exact?GoldenEye%20(...,0,1,1,0,0,...,0,0,0,0,0,0,1,0,0,"pop,rock,soul"
2,id_3,Dimensions,1-Jan-95,,http://us.imdb.com/M/title-exact?Four%20Rooms%...,0,0,0,0,0,...,0,0,0,0,0,0,1,0,0,soul
3,id_4,Las Numero 1 De La Sonora Santanera,1-Jan-95,,http://us.imdb.com/M/title-exact?Get%20Shorty%...,0,1,0,0,0,...,0,0,0,0,0,0,0,0,0,"pop,electronic,jazz"
4,id_5,Friend Or Foe,1-Jan-95,,http://us.imdb.com/M/title-exact?Copycat%20(1995),0,0,0,0,0,...,0,0,0,0,0,0,1,0,0,"country,jazz,soul"
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1677,id_1678,Maurice Larcange Au Zenith,6-Feb-98,,http://us.imdb.com/M/title-exact?Mat%27+i+syn+...,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,jazz
1678,id_1679,Dangerously In Love,6-Feb-98,,http://us.imdb.com/M/title-exact?B%2E+Monkey+(...,0,0,0,0,0,...,0,0,0,0,1,0,1,0,0,"metal,soul"
1679,id_1680,Soulcrusher,1-Jan-98,,http://us.imdb.com/Title?Sliding+Doors+(1998),0,0,0,0,0,...,0,0,0,0,1,0,0,0,0,"jazz,metal"
1680,id_1681,War Of Aggression,1-Jan-94,,http://us.imdb.com/M/title-exact?You%20So%20Cr...,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,electronic


In [17]:
genre_df = songs_df['song_genre'].str.split(',',expand=True).rename(columns={0:'genre1',1:'genre2',2:'genre3',3:'genre4',4:'genre5',5:'genre6'})

In [18]:
songs_df = pd.concat(
    [songs_df, genre_df],
    axis=1
)
songs_df

Unnamed: 0,song_id,song_title,release_date,video_release_date,imdb_url,unknown,pop,rock,hiphop,rap,...,soul,worldmusic,western,song_genre,genre1,genre2,genre3,genre4,genre5,genre6
0,id_1,Call of the Mastodon,1-Jan-95,,http://us.imdb.com/M/title-exact?Toy%20Story%2...,0,0,0,1,1,...,0,0,0,"hiphop,rap,electronic",hiphop,rap,electronic,,,
1,id_2,Fear Itself,1-Jan-95,,http://us.imdb.com/M/title-exact?GoldenEye%20(...,0,1,1,0,0,...,1,0,0,"pop,rock,soul",pop,rock,soul,,,
2,id_3,Dimensions,1-Jan-95,,http://us.imdb.com/M/title-exact?Four%20Rooms%...,0,0,0,0,0,...,1,0,0,soul,soul,,,,,
3,id_4,Las Numero 1 De La Sonora Santanera,1-Jan-95,,http://us.imdb.com/M/title-exact?Get%20Shorty%...,0,1,0,0,0,...,0,0,0,"pop,electronic,jazz",pop,electronic,jazz,,,
4,id_5,Friend Or Foe,1-Jan-95,,http://us.imdb.com/M/title-exact?Copycat%20(1995),0,0,0,0,0,...,1,0,0,"country,jazz,soul",country,jazz,soul,,,
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1677,id_1678,Maurice Larcange Au Zenith,6-Feb-98,,http://us.imdb.com/M/title-exact?Mat%27+i+syn+...,0,0,0,0,0,...,0,0,0,jazz,jazz,,,,,
1678,id_1679,Dangerously In Love,6-Feb-98,,http://us.imdb.com/M/title-exact?B%2E+Monkey+(...,0,0,0,0,0,...,1,0,0,"metal,soul",metal,soul,,,,
1679,id_1680,Soulcrusher,1-Jan-98,,http://us.imdb.com/Title?Sliding+Doors+(1998),0,0,0,0,0,...,0,0,0,"jazz,metal",jazz,metal,,,,
1680,id_1681,War Of Aggression,1-Jan-94,,http://us.imdb.com/M/title-exact?You%20So%20Cr...,0,0,0,0,0,...,0,0,0,electronic,electronic,,,,,


In [19]:
train_df.groupby('song_id').aggregate({'user_id':'count','rating':'mean'})

Unnamed: 0_level_0,user_id,rating
song_id,Unnamed: 1_level_1,Unnamed: 2_level_1
id_1,251,3.908367
id_10,55,3.763636
id_100,271,4.184502
id_1000,9,3.000000
id_1001,12,1.416667
...,...,...
id_995,15,3.200000
id_996,10,2.200000
id_997,12,2.083333
id_998,12,2.666667


In [20]:
add_features = train_df.groupby('song_id').aggregate({'user_id':'count','rating':'mean'}).rename(columns={'user_id':'total_views','rating':'avg_rating'}).reset_index()

In [21]:
add_features['avg_rating'] = add_features['avg_rating'].round(2)
add_features

Unnamed: 0,song_id,total_views,avg_rating
0,id_1,251,3.91
1,id_10,55,3.76
2,id_100,271,4.18
3,id_1000,9,3.00
4,id_1001,12,1.42
...,...,...,...
1623,id_995,15,3.20
1624,id_996,10,2.20
1625,id_997,12,2.08
1626,id_998,12,2.67


In [22]:
#merging these data into songs_df
songs_df = songs_df.merge(add_features,left_on='song_id',right_on='song_id',how='left')
songs_df

Unnamed: 0,song_id,song_title,release_date,video_release_date,imdb_url,unknown,pop,rock,hiphop,rap,...,western,song_genre,genre1,genre2,genre3,genre4,genre5,genre6,total_views,avg_rating
0,id_1,Call of the Mastodon,1-Jan-95,,http://us.imdb.com/M/title-exact?Toy%20Story%2...,0,0,0,1,1,...,0,"hiphop,rap,electronic",hiphop,rap,electronic,,,,251.0,3.91
1,id_2,Fear Itself,1-Jan-95,,http://us.imdb.com/M/title-exact?GoldenEye%20(...,0,1,1,0,0,...,0,"pop,rock,soul",pop,rock,soul,,,,97.0,3.20
2,id_3,Dimensions,1-Jan-95,,http://us.imdb.com/M/title-exact?Four%20Rooms%...,0,0,0,0,0,...,0,soul,soul,,,,,,64.0,2.98
3,id_4,Las Numero 1 De La Sonora Santanera,1-Jan-95,,http://us.imdb.com/M/title-exact?Get%20Shorty%...,0,1,0,0,0,...,0,"pop,electronic,jazz",pop,electronic,jazz,,,,150.0,3.53
4,id_5,Friend Or Foe,1-Jan-95,,http://us.imdb.com/M/title-exact?Copycat%20(1995),0,0,0,0,0,...,0,"country,jazz,soul",country,jazz,soul,,,,51.0,3.16
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1677,id_1678,Maurice Larcange Au Zenith,6-Feb-98,,http://us.imdb.com/M/title-exact?Mat%27+i+syn+...,0,0,0,0,0,...,0,jazz,jazz,,,,,,1.0,1.00
1678,id_1679,Dangerously In Love,6-Feb-98,,http://us.imdb.com/M/title-exact?B%2E+Monkey+(...,0,0,0,0,0,...,0,"metal,soul",metal,soul,,,,,,
1679,id_1680,Soulcrusher,1-Jan-98,,http://us.imdb.com/Title?Sliding+Doors+(1998),0,0,0,0,0,...,0,"jazz,metal",jazz,metal,,,,,,
1680,id_1681,War Of Aggression,1-Jan-94,,http://us.imdb.com/M/title-exact?You%20So%20Cr...,0,0,0,0,0,...,0,electronic,electronic,,,,,,1.0,3.00


In [23]:
song_df.to_csv('data/original/somgs_cleaned_merged.csv', index=False)

NameError: name 'song_df' is not defined

In [24]:
columns = [
    'song_id',
    'song_title',
    'song_genre',
    'total_views',
    'avg_rating', 
    'genre1',
]

In [25]:
train_df = train_df.merge(songs_df[columns],left_on='song_id',right_on='song_id',how='left')

val_df = val_df.merge(songs_df[columns],left_on='song_id',right_on='song_id',how='left')

test_df = test_df.merge(songs_df[columns],left_on='song_id',right_on='song_id',how='left')

light_user_ratings_df = light_user_ratings_df.merge(songs_df[columns],left_on='song_id',right_on='song_id',how='left')



In [26]:
light_user_ratings_df

Unnamed: 0,user_id,song_id,rating,unix_timestamp,song_title,song_genre,total_views,avg_rating,genre1
0,id_3,id_181,4,889237482,Bitter Suites to Succubi,"pop,rock,metal,indie,worldmusic",289.0,4.02,pop
1,id_3,id_258,2,889237026,Miscelanea,"jazz,indie",214.0,3.80,jazz
2,id_3,id_260,4,889237455,Virtuous Woman,"pop,funk,indie,soul",70.0,2.63,pop
3,id_3,id_264,2,889237297,Music For Christmas Lovers,"indie,soul",50.0,2.78,indie
4,id_3,id_268,3,889236961,Window In The Skies,"jazz,metal",116.0,3.76,jazz
...,...,...,...,...,...,...,...,...,...
10939,id_941,id_408,5,875048886,Compact Disc,"hiphop,electronic,soul",77.0,4.44,hiphop
10940,id_941,id_455,4,875049038,"Don't Sleep""""",pop,87.0,3.24,pop
10941,id_941,id_763,3,875048996,In Your Arms Again,electronic,93.0,3.38,electronic
10942,id_941,id_919,5,875048887,"Aren't Women Wonderful""""","rock,indie",64.0,3.73,rock


In [27]:
#dropping unnecessary cols
train_df = train_df.drop(['rating','unix_timestamp'],axis=1)
val_df = val_df.drop(['rating','unix_timestamp'],axis=1)
test_df = test_df.drop(['rating','unix_timestamp'],axis=1)
light_user_ratings_df = light_user_ratings_df.drop(['rating','unix_timestamp'],axis=1)

In [28]:
unique_user_ids = train_df['user_id'].unique().tolist()
unique_song_ids = songs_df['song_id'].unique().tolist()
unique_genre_ids = songs_df['genre1'].unique().tolist()
unique_title_song_ids =  songs_df['song_title'].unique().tolist()
unique_total_views = songs_df['total_views'].unique().tolist()
total_views_buckets = np.linspace(min(unique_total_views),max(unique_total_views),20,dtype=int)

In [29]:
max(unique_total_views)

300.0

In [30]:
train_df

Unnamed: 0,user_id,song_id,song_title,song_genre,total_views,avg_rating,genre1
0,id_1,id_14,Placer & Castigo,"jazz,metal",103.0,3.84,jazz
1,id_1,id_224,The Way West,jazz,28.0,3.61,jazz
2,id_1,id_242,I Need You Now,electronic,38.0,3.71,electronic
3,id_1,id_72,One World One People,"electronic,country,blues",99.0,3.19,electronic
4,id_1,id_166,Behind The Blue,jazz,45.0,4.20,jazz
...,...,...,...,...,...,...,...
63494,id_99,id_282,Coda,jazz,149.0,3.70,jazz
63495,id_99,id_409,This Is Noise,"electronic,jazz",45.0,3.04,electronic
63496,id_99,id_92,What You Thought You Heard,"pop,country,metal",71.0,3.61,pop
63497,id_99,id_751,15 Grandes \xc3\x89xitos,"pop,metal,soul",78.0,3.51,pop


In [31]:
songs_tf = tf.data.Dataset.from_tensor_slices(dict(songs_df[columns].to_dict('list')))#.prefetch(2)
songs = songs_tf.map(lambda x: {
    "song_id": x["song_id"],
    "song_title": x["song_title"],
    "total_views": x["total_views"],
    "genre1": x["genre1"],
    })
train_tf = tf.data.Dataset.from_tensor_slices(train_df.to_dict('list'))
train = train_tf.map(lambda x: {
    "user_id": x["user_id"],
    "song_id": x["song_id"],
    "song_title": x["song_title"],
    "total_views": x["total_views"],
    "genre1": x["genre1"],
    })

val_tf = tf.data.Dataset.from_tensor_slices(val_df.to_dict('list'))
val = val_tf.map(lambda x: {
    "user_id": x["user_id"],
    "song_id": x["song_id"],
    "song_title": x["song_title"],
    "total_views": x["total_views"],
    "genre1": x["genre1"],
    })

test_tf = tf.data.Dataset.from_tensor_slices(test_df.to_dict('list'))
test = test_tf.map(lambda x: {
    "user_id": x["user_id"],
    "song_id": x["song_id"],
    "song_title": x["song_title"],
    "total_views": x["total_views"],
    "genre1": x["genre1"],
    })



In [32]:
songs_tf

<_TensorSliceDataset element_spec={'song_id': TensorSpec(shape=(), dtype=tf.string, name=None), 'song_title': TensorSpec(shape=(), dtype=tf.string, name=None), 'song_genre': TensorSpec(shape=(), dtype=tf.string, name=None), 'total_views': TensorSpec(shape=(), dtype=tf.float32, name=None), 'avg_rating': TensorSpec(shape=(), dtype=tf.float32, name=None), 'genre1': TensorSpec(shape=(), dtype=tf.string, name=None)}>

In [33]:
class UserModel(tf.keras.Model):
    def __init__(self):
        super().__init__()
        
        self.user_query_model = tf.keras.Sequential([
            tf.keras.layers.InputLayer(input_shape=(1,), name= 'user_input', dtype=tf.string),
            tf.keras.layers.StringLookup(vocabulary=unique_user_ids,mask_token=None),
            tf.keras.layers.Embedding(
            input_dim = len(unique_user_ids)+1,
            output_dim=32),
            tf.keras.layers.Flatten(name='FlattenUser')
        ])
        
        self.song_query_model = tf.keras.Sequential([
            tf.keras.layers.InputLayer(input_shape=(1,), name= 'song_input', dtype=tf.string),
            tf.keras.layers.StringLookup(vocabulary=unique_song_ids,mask_token=None),
            tf.keras.layers.Embedding(
            input_dim = len(unique_song_ids)+1,
            output_dim=32),
            tf.keras.layers.Flatten(name='FlattenUserSong')
        ])
        
    def call(self,inputs):
            return tf.concat([
                self.user_query_model(inputs['user_id']),
                self.song_query_model(inputs['song_id'])
            ],axis=1)

In [34]:
UserModel()

<__main__.UserModel at 0x1eb30d8bd60>

In [35]:
class SongModel(tf.keras.Model):
    def __init__(self):
        super().__init__()
        
        self.song_candidate_model = tf.keras.Sequential([
            tf.keras.layers.InputLayer(input_shape=(1,), name = 'song_id', dtype = tf.string),
            tf.keras.layers.StringLookup(vocabulary=unique_song_ids, mask_token=None),
            tf.keras.layers.Embedding(len(unique_song_ids) + 1, 32),
            tf.keras.layers.Flatten(name='FlattenSongId')
        ])
        
        self.genre_candidate_model = tf.keras.Sequential([
            tf.keras.layers.InputLayer(input_shape=(1,), name = 'genre_id', dtype = tf.string),
            tf.keras.layers.StringLookup(vocabulary=unique_genre_ids, mask_token=None),
            tf.keras.layers.Embedding(len(unique_genre_ids) + 1, 32),
            tf.keras.layers.Flatten(name='FlattenSongGenre')
        ])

        self.title_text_embedding = tf.keras.Sequential([
            tf.keras.layers.InputLayer(input_shape=(1,), name='title_text', dtype = tf.string),
            tf.keras.layers.StringLookup(vocabulary=unique_title_song_ids, mask_token=None),
            tf.keras.layers.Embedding(len(unique_title_song_ids)+1, 32, mask_zero=True),
            tf.keras.layers.Flatten(name='FlattenCandidatesTitle'),
    ])
        self.total_views = tf.keras.Sequential([
            tf.keras.layers.InputLayer(input_shape=(1,), name = 'total_views_buckets_input', dtype = tf.int64),
            tf.keras.layers.Discretization(total_views_buckets.tolist()),
            tf.keras.layers.Embedding(len(total_views_buckets) + 1, 32),
            tf.keras.layers.Flatten(name='FlattenMerchantMonthlySales')
            ])
        
    def call(self, titles):
            return tf.concat([
                self.song_candidate_model(titles['song_id']),
                self.title_text_embedding(titles['song_title']),
                self.total_views(titles['total_views']),
                self.genre_candidate_model(titles['genre1']),
                ], axis = 1)

In [36]:
SongModel()

<__main__.SongModel at 0x1eb30e06820>

In [37]:
class CandidatesModel(tfrs.models.Model):

    def __init__(self):
        super().__init__()

        self.query_model = tf.keras.Sequential([
            UserModel(),
            tf.keras.layers.Dense(32)
        ])
        
        self.candidate_model = tf.keras.Sequential([
            SongModel(),
            tf.keras.layers.Dense(32)
        ])
        
        self.task = tfrs.tasks.Retrieval(
            metrics=tfrs.metrics.FactorizedTopK(
                candidates=songs.batch(128).map(self.candidate_model),
            ),
        )

    def compute_loss(self, features, training=False):
        query_embeddings = self.query_model({
            "user_id": features["user_id"],
            'song_id': features["song_id"]
        })

        candidates_embeddings = self.candidate_model({
            'song_id': features["song_id"],
            'song_title': features["song_title"],
            'total_views': features["total_views"],
            'genre1': features["genre1"]
        })

        return self.task(query_embeddings, candidates_embeddings)


In [38]:
candidates_model = CandidatesModel()

In [39]:
tf.random.set_seed(42)

cached_train = train.batch(1000).cache()
cached_val = val.batch(1000).cache()
cached_test = test.batch(1000).cache()
callbacks = [
    tf.keras.callbacks.EarlyStopping(monitor='loss', patience=7),
    tf.keras.callbacks.ModelCheckpoint(filepath='./logs/models', save_weights_only=True, save_best_only=True, save_freq="epoch",),
    tf.keras.callbacks.TensorBoard(log_dir='./logs'),
]

In [358]:
candidates_model.compile(
    run_eagerly=True,
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.0001),
)
history = candidates_model.fit(
    x=cached_train,
    epochs=30,
    verbose=True,
    batch_size=16,
    validation_data=cached_val,
    use_multiprocessing=True,
    workers=12
)

Epoch 1/30
Epoch 2/30
Epoch 3/30
Epoch 4/30
Epoch 5/30
Epoch 6/30
Epoch 7/30
Epoch 8/30
Epoch 9/30
Epoch 10/30
Epoch 11/30
Epoch 12/30
Epoch 13/30
Epoch 14/30
Epoch 15/30
Epoch 16/30
Epoch 17/30
Epoch 18/30
Epoch 19/30
Epoch 20/30
Epoch 21/30


Epoch 22/30
Epoch 23/30
Epoch 24/30
Epoch 25/30
Epoch 26/30
Epoch 27/30
Epoch 28/30
Epoch 29/30
Epoch 30/30


In [359]:
 candidates_model.save_weights(filepath='./weights/trained_n_personalized_model', overwrite=True)

In [360]:
history.history['loss']

[3096.035400390625,
 3088.57470703125,
 3071.652587890625,
 3032.447998046875,
 2950.173828125,
 2817.48876953125,
 2668.82275390625,
 2525.59228515625,
 2390.900390625,
 2264.34716796875,
 2133.671875,
 1991.560791015625,
 1839.298095703125,
 1682.5262451171875,
 1526.9586181640625,
 1376.6016845703125,
 1233.9571533203125,
 1100.30419921875,
 976.2415771484375,
 862.3427124023438,
 759.310791015625,
 667.7662963867188,
 587.9752197265625,
 519.6881103515625,
 462.1507263183594,
 414.24725341796875,
 374.7011413574219,
 342.2392272949219,
 315.6819152832031,
 293.97894287109375]

In [361]:
train_accuracy = candidates_model.evaluate(
        cached_train, return_dict=True)["factorized_top_k/top_100_categorical_accuracy"]

test_accuracy = candidates_model.evaluate(
        cached_test, return_dict=True)["factorized_top_k/top_100_categorical_accuracy"]

print(f"Top-100 accuracy (train): {train_accuracy:.2f}.")
print(f"Top-100 accuracy (test): {test_accuracy:.2f}.")

Top-100 accuracy (train): 1.00.
Top-100 accuracy (test): 0.99.


In [40]:
candidates_model.load_weights(filepath='./weights/trained_n_personalized_model')

<tensorflow.python.checkpoint.checkpoint.CheckpointLoadStatus at 0x1eb31008fd0>

In [47]:
index = tfrs.layers.factorized_top_k.BruteForce(
    query_model=candidates_model.query_model,
    k=100
)

# trained_index
trained_index = index.index_from_dataset(
    tf.data.Dataset.zip((
        songs.batch(100).map(lambda x: x["song_id"]),
        songs.batch(100).map(lambda x: {
            'song_id': x["song_id"],
            'total_views': x["total_views"],
            'song_title': x["song_title"],
            'genre1': x["genre1"],
        }).map(candidates_model.candidate_model)))
)

index_songs = dict(songs_df[['song_id','song_title']].to_numpy())
index_songs

{'id_1': 'Call of the Mastodon',
 'id_2': 'Fear Itself',
 'id_3': 'Dimensions',
 'id_4': 'Las Numero 1 De La Sonora Santanera',
 'id_5': 'Friend Or Foe',
 'id_6': 'Muertos Vivos',
 'id_7': 'Ordinary Day',
 'id_8': 'Da Ghetto Psychic',
 'id_9': 'Gin & Phonic',
 'id_10': 'Pink World',
 'id_11': 'Superinstrumental',
 'id_12': 'I Need You',
 'id_13': 'The Way Home',
 'id_14': 'Placer & Castigo',
 'id_15': 'Arrivederci',
 'id_16': 'Dancing On The Ceiling',
 'id_17': 'Outskirts',
 'id_18': 'Cross Currents',
 'id_19': 'Panjtan Ka Ghulam',
 'id_20': 'Gold',
 'id_21': 'Whatever Happened To Boredom?',
 'id_22': '1942-1952 Jimmy Wakely',
 'id_23': 'All The Good Times',
 'id_24': 'Un cafe_ setanta matins',
 'id_25': 'Miss Machine',
 'id_26': 'Sue Thompson - Her Very Best',
 'id_27': 'Live',
 'id_28': 'Strictly Confidential',
 'id_29': 'The Real Twang Thang',
 'id_30': 'Willie Bobo\'s Finest Hour""',
 'id_31': 'Shake A Hand',
 'id_32': 'Occasional Rain',
 'id_33': 'The Emperor Falls',
 'id_34': 'Be

In [53]:
scores , titles = trained_index(queries = {
    'user_id':tf.constant(['id_103']),
    'song_id':tf.constant(['id_181']),
    }, k = 10
)
print("Recommendations for user id_103 who have listened to the song Bitter Suites to Succubi" )

titles = [index_songs[idx.decode('utf-8')] for idx in titles[0, :].numpy()]
dict(zip(titles, scores.numpy()[0]))

Recommendations for user id_103 who have listened to the song Bitter Suites to Succubi


{'Bitter Suites to Succubi': 28.014503,
 'That Which Remains': 25.185734,
 'Louie Bluie Film Soundtrack': 24.183182,
 'Pink Chokolate!': 24.181765,
 'Live Radio Sessions': 23.875671,
 'Chicago Blues Festival 1974 With Jimmy Dawkins': 23.564648,
 'We Sing.  We Dance.  We Steal Things.': 23.36516,
 'Pat Travers': 23.123913,
 'Italo Dance Collection_ Vol. 1': 22.93893,
 'Live From Dublin': 22.867775}