In [2]:
import os
import sys
import json
import numpy as np
import pandas as pd
from collections import defaultdict
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA
from sklearn.mixture import GaussianMixture
from sklearn.neighbors import NearestNeighbors

In [102]:
class module(object):
    """Song recommender: GMM clustering in PCA space plus nearest-neighbour search.

    On construction it reads four CSV files from ``data_dir``:
    ``X_pca.csv``, ``labels.csv``, ``scaled_songs.csv`` and
    ``filtered_id_name.csv``. It then refits a PCA on the scaled song features
    so a query song can be projected, and trains a Gaussian mixture model on
    ``X_pca``.

    ``predict`` indexes ``X_pca``, ``labels`` and ``filtered_id_name`` with the
    same positional indices, so those files are assumed to be row-aligned.
    NOTE(review): confirm they were written together.

    Args:
        data_dir: Directory holding the CSV files (default: current directory).
        n_pca_components: Number of PCA components. Presumably must equal the
            number of columns in ``X_pca.csv`` (PC1..PC5 in the recorded output).
        n_clusters: Number of GMM components.
        random_state: Seed passed to both the PCA and the GMM.

    Raises:
        FileNotFoundError: If any of the CSV files is missing.
    """

    def __init__(self, data_dir='.', n_pca_components=5, n_clusters=10, random_state=42):
        self.data_dir = data_dir

        # PCA-projected features, one row per song.
        self.X_pca = self._read_csv('X_pca.csv').to_numpy()
        # Stored cluster label per song (a single `label` column in the recorded output).
        self.labels = self._read_csv('labels.csv').to_numpy()
        # Scaled features: column 0 is the song id, the remaining columns are features.
        self.scaled_songs = self._read_csv('scaled_songs.csv').to_numpy()
        # Kept as a DataFrame because predict() selects result rows with .iloc.
        self.filtered_id_name = self._read_csv('filtered_id_name.csv')

        # Refit PCA on the feature columns so a query song can be projected.
        # NOTE(review): this only lands in the same space as X_pca.csv if that
        # file came from an identically configured PCA on the same data.
        self.pca = PCA(n_components=n_pca_components, random_state=random_state)
        self.pca.fit(self.scaled_songs[:, 1:].astype(float))

        # Train the GMM used to assign a query song to a cluster.
        self.gmm = GaussianMixture(n_components=n_clusters, covariance_type='full',
                                   random_state=random_state)
        self.training()

    def _read_csv(self, filename):
        """Read ``filename`` from ``self.data_dir``, print its head, and return the DataFrame.

        Raises:
            FileNotFoundError: If the file does not exist (after printing a message).
        """
        path = os.path.join(self.data_dir, filename)
        try:
            frame = pd.read_csv(path)
        except FileNotFoundError:
            print(f'Can not find {filename}.')
            # Re-raise rather than sys.exit(0): exit status 0 signals success, and
            # inside Jupyter sys.exit only surfaces as a SystemExit in the cell.
            raise
        print(f'Read {os.path.splitext(filename)[0]} successfully:\n', frame.head())
        return frame

    # training GMM model
    def training(self):
        """Fit ``self.gmm`` on the loaded PCA features ``self.X_pca``."""
        print('trainning ....')
        self.gmm.fit(self.X_pca)
        print('trainning accomplished.')

    def predict(self, song_id, n_neighbors=10):
        """Print and return the songs nearest to ``song_id`` inside its predicted cluster.

        Args:
            song_id: Id looked up in column 0 of ``self.scaled_songs``.
            n_neighbors: Maximum number of recommendations (default 10). If the
                searched set is smaller, all of its songs are returned. The query
                song itself is included when it lies in the searched set.

        Returns:
            The rows of ``self.filtered_id_name`` for the recommended songs,
            ordered nearest first.

        Raises:
            ValueError: If ``song_id`` does not occur in the database.
        """
        print('predicting ...\n')
        song_ids = self.scaled_songs[:, 0]
        indices = np.flatnonzero(song_ids == song_id)
        if indices.size == 0:
            print('Sorry, this song is not in our database.')
            raise ValueError(f'song id {song_id!r} is not in the database')

        # An id may occur on more than one row; use the first occurrence instead
        # of reporting the song as missing.
        query = self.scaled_songs[indices[0], 1:].astype(float).reshape(1, -1)
        query_pca = self.pca.transform(query)
        predicted_label = self.gmm.predict(query_pca)[0]

        # Candidates are the songs whose stored label equals the predicted component.
        # NOTE(review): labels.csv must number components exactly like self.gmm
        # (same config, seed and data); otherwise this selection is meaningless.
        cluster_indices = np.where(self.labels == predicted_label)[0]
        if cluster_indices.size == 0:
            # No stored song carries this label: search the whole dataset instead
            # of failing inside NearestNeighbors.fit on an empty array.
            cluster_indices = np.arange(len(self.X_pca))
        cluster_data = self.X_pca[cluster_indices]

        # kneighbors raises if asked for more neighbours than fitted samples.
        k = min(n_neighbors, len(cluster_indices))
        nbrs = NearestNeighbors(n_neighbors=k, metric='euclidean')
        nbrs.fit(cluster_data)
        _, local_indices = nbrs.kneighbors(query_pca)

        # Map positions within the cluster back to dataset row positions.
        recommended_indices = cluster_indices[local_indices[0]]
        recommended_rows = self.filtered_id_name.iloc[recommended_indices]
        print('What I recommend are: ', recommended_rows)
        return recommended_rows

In [99]:
# Build the recommender: the constructor reads X_pca.csv, labels.csv,
# scaled_songs.csv and filtered_id_name.csv from the working directory and
# trains the GMM.
# NOTE(review): this cell is saved as In [99] but the class cell as In [102];
# the outputs below predate the last class edit - Restart & Run All to refresh.
modul = module()

Read X_pca successfully:
         PC1       PC2       PC3       PC4       PC5
0 -1.962109 -0.760090  0.113974  1.127371  2.272101
1 -2.129417  0.377469  0.198823  0.376483  3.298361
2 -2.011111 -0.936744  1.793915 -0.199390  0.868910
3 -1.696710  0.058786  0.787465 -1.253576  1.069855
4 -1.462294 -0.707375 -0.909267  1.365771  0.470650
Read labels successfully:
    label
0      6
1      7
2      6
3      4
4      6
Read scaled_songs successfully:
                        id  danceability    energy       key  loudness  \
0  7lmeHLHBe4nmXzuXc0HDjk     -0.239643  1.545439  0.504262  0.855004   
1  1wsRitfRRtWyEapl0q22o8      0.448997  1.470214  1.632828  0.794662   
2  1hR0fIFK2qRG3f3RF70pb7     -1.067078  1.516782  0.504262  0.850871   
3  2lbASgTSoDO7MTuLAXlTW0     -0.399791  1.506035  1.632828  0.783751   
4  1MQTmpYOZ6fcMQc56Hdo7T     -0.474528  1.369913 -0.906446  0.635127   

   speechiness  acousticness  instrumentalness  liveness   valence     tempo  \
0    -0.084450     -1.018675 

In [100]:
# Recommend songs similar to this one ("Testify" in the recorded output);
# predict() prints the recommended id/name rows.
query_song_id = '7lmeHLHBe4nmXzuXc0HDjk'
modul.predict(query_song_id)

predicting ...

What I recommend are:                              id  \
0       7lmeHLHBe4nmXzuXc0HDjk   
134606  0uukw2CgEIApv4IWAjXrBC   
90295   4JJTel4WTX7VeyN3fe5CbS   
63048   5qFL2uwfnGU8FccwLMgPNQ   
62472   44w63XqGr3sATAzOnOySgF   
42085   3XorCFmcupSm5QS6hA9g4N   
81540   4ATmY1hv93ehw77LrIdbEh   
126859  6fybp4N6eW3bsFAvARxyVe   
96608   0HDaKOlVAfUWXdFR2RhBtN   
129531  3EWQsarNCItwHn9hE2MHTn   

                                                     name  \
0                                                 Testify   
134606                                              Dead!   
90295   Take Ü There (feat. Kiesza) - Missy Elliott Remix   
63048                                    Internet Friends   
62472                                   If Everyone Cared   
42085                                          I Want You   
81540                    Finale (feat. Nicholas Petricca)   
126859           Ain't Talkin' 'Bout Love - 2015 Remaster   
96608                                