In [516]:
from sklearn.model_selection import train_test_split
from tqdm import tqdm
import pandas as pd
import numpy as np
import json
import os
import sklearn
import sklearn.neighbors
import warnings
from pathlib import Path
warnings.filterwarnings('ignore')
from sklearn.preprocessing import LabelEncoder
from pandas.api.types import CategoricalDtype
from operator import itemgetter

# constants
RAW_DATA_PATH = Path('raw_data/')
DATAFRAME_PATH = Path('dataframes/')
TOTAL_TRACKS = 50
NUM_WITHHELD = 25
SELECTED_TRACK_FEATURES = ['danceability', 'energy', 'key', 'loudness',
                     'speechiness', 'acousticness', 'instrumentalness',
                     'liveness', 'valence', 'tempo']

In [476]:
tracks_features_df = pd.read_hdf(DATAFRAME_PATH / 'tracks.h5', 'tracks')
cat_type = CategoricalDtype(categories=tracks_features_df.track_uri, ordered=True)
tracks_features_df.track_uri = tracks_features_df.track_uri.astype(cat_type)
tracks_features_df = tracks_features_df.rename(columns={'track_uri':'tid'})
tracks_features_df = tracks_features_df[['tid',*SELECTED_TRACK_FEATURES]].sort_values('tid').reset_index(drop=True)

In [517]:
def make_playlist_dfs(path):
    playlists = []
    playlist_tracks = []
    for file in tqdm(os.listdir(path)):
        if not file.endswith(".json"):
            continue
        with open(path/file) as f:
            js_slice = json.load(f)
            for playlist in js_slice['playlists']:
                if playlist['num_tracks'] > TOTAL_TRACKS:
                    sorted_tracks = sorted(playlist['tracks'], key=itemgetter('pos')) 
                    for track in sorted_tracks[:TOTAL_TRACKS]:
                        yield track['track_uri'], playlist['pid'], track['pos']

playlist_tracks_df = pd.DataFrame(make_playlist_dfs(RAW_DATA_PATH),columns=['tid','pid','pos'])
playlist_tracks_df.tid = playlist_tracks_df.tid.astype(cat_type)

100%|██████████████████████████████████████████████████████████████████████████████████| 11/11 [00:03<00:00,  2.91it/s]


In [521]:
all_playlist_ids = playlist_tracks_df.pid.unique()
train_pids, test_pids = train_test_split(all_playlist_ids,random_state=0, test_size=0.20)

# 1. Get tracks that are only from training playlists
# 2. Get the track features for each of these training tracks
# 3. Make test set only include tracks from training set
train_playlist_tracks_df = playlist_tracks_df.query('pid in @train_pids')
train_tracks_features_df = tracks_features_df.query('tid in @train_playlist_tracks_df.tid')
test_playlist_tracks_df = playlist_tracks_df.query('pid in @test_pids and tid in @train_playlist_tracks_df.tid')

# 1. Get the first NUM_WITHHELD tracks for each playlist in test
# 2. Get the track features for these with held tracks
# 3. Compute mean features by grouping incomplete playlists
test_playlist_tracks_incomplete_df = test_playlist_tracks_df.groupby('pid').head(NUM_WITHHELD)
test_tracks_incomplete_features_df = test_playlist_tracks_incomplete_df.merge(tracks_features_df,how='inner',on='tid')
test_playlist_incomplete_features = test_tracks_incomplete_features_df[['pid',*SELECTED_TRACK_FEATURES]].groupby('pid',as_index=False).mean()

In [526]:
# KNN model that will find 25 nearest neighbors to the current playlist
knn_clf = sklearn.neighbors.NearestNeighbors(n_neighbors=NUM_WITHHELD)
knn_clf.fit(train_tracks_features_df[SELECTED_TRACK_FEATURES])
distances, indices = knn_clf.kneighbors(test_playlist_incomplete_features[SELECTED_TRACK_FEATURES])

In [539]:
# for each test playlist, get the 25 next nearest predicted tracks and add them to a table for evaluation
def get_predicted_playlist_tracks():
    for index, row in test_playlist_incomplete_features.iterrows():
            predicted_tracks = train_tracks_features_df['tid'].iloc[indices[index]]
            for pos, predicted_track in enumerate(predicted_tracks):
                yield predicted_track, int(row['pid']),pos

test_predicted_playlist_tracks_df = pd.DataFrame(get_predicted_playlist_tracks(), columns =['tid', 'pid', 'pos'])
test_predicted_playlist_tracks_df

Unnamed: 0,tid,pid,pos
0,spotify:track:1mnPNAxWN8MvWYsjHWpnnc,17,0
1,spotify:track:0FZ7yrTAHEPBtVhxgpsgYf,17,1
2,spotify:track:2UAWCPSBKXi33Q0YMoDRRS,17,2
3,spotify:track:5dyUA8LLCsPPD9xTtE4XKz,17,3
4,spotify:track:7vFoFDWqTX0mHzLfrF1Cfy,17,4
...,...,...,...
24270,spotify:track:4MbS4oNEbGOaiCi1rfZ2Pw,9993,20
24271,spotify:track:2TQrgzmsmQhniZjHYO9Uzk,9993,21
24272,spotify:track:0dqvwSAycKNbTsn19A8RMF,9993,22
24273,spotify:track:2Guz1b911CbpG8L92cnglI,9993,23
