In [21]:
# In this project we build a model that predicts which music genre a user is likely to enjoy from their profile (age and gender); an app could collect these details when the user signs in and then show the predicted genre.

In [36]:
import pandas as pd
# Choose our Machine Learning Algorithm. We will be using a decision tree
from sklearn.tree import DecisionTreeClassifier
# Import Module that automatically splits the data into a training and test set
from sklearn.model_selection import train_test_split
# Import Accuracy Score
from sklearn.metrics import accuracy_score


# Import Data
music_data = pd.read_csv('music.csv')

# Split the data into an input set and an output set.
# The inputs are every column except 'genre' (age and gender); the output is the genre.
# NOTE: no cleaning (de-duplication / missing-value handling) is applied here —
# music.csv is assumed to be clean already.
X = music_data.drop(columns=['genre'])
y = music_data['genre']

# Hold out a fraction of the rows for evaluation. A fixed seed makes the split, the fitted
# tree (whose tie-breaking between equally good splits is randomized) and the accuracy
# below reproducible on Restart & Run All.
TEST_SIZE = 0.2
SEED = 42
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=TEST_SIZE, random_state=SEED
)

# Choose our Machine Learning Algorithm: a decision tree.
model = DecisionTreeClassifier(random_state=SEED)

# Train on plain NumPy arrays (.values): the model then records no feature names, so later
# predictions can be made from plain lists of [age, gender] rows without a feature-name warning.
model.fit(X_train.values, y_train.values)

# Predict the genre for every held-out row.
predictions = model.predict(X_test.values)

# Fraction of held-out rows whose genre was predicted correctly.
score = accuracy_score(y_test.values, predictions)
score

1.0

In [40]:
# Model persistence: save the trained model to disk once, then load it back and reuse it
# for predictions without retraining on the data.
import joblib

model_path = 'Music-Recommender.joblib'
joblib.dump(model, model_path)

# Reload from disk to show the saved file alone is enough to make predictions.
trained_model = joblib.load(model_path)

# A single query row in the training column order: age 21, gender code 1.
# NOTE(review): the meaning of the gender codes comes from music.csv — confirm before exposing in an app.
new_user = [[21, 1]]
predictions = trained_model.predict(new_user)
predictions

array(['HipHop'], dtype=object)

In [41]:
# Visualizing Decision Trees
# Export the persisted model's tree to a Graphviz .dot file.
import joblib
import pandas as pd
from sklearn import tree

# Load the data and the saved model first, so everything below is derived from them
# rather than from variables left over in the kernel by earlier cells.
music_data = pd.read_csv('music.csv')
trained_model = joblib.load('Music-Recommender.joblib')

X = music_data.drop(columns=['genre'])
y = music_data['genre']

# Label tree nodes with the input column names, in the same order as the columns the model
# was trained on (every column except 'genre').
feature_names = list(X.columns)

# Use the model's own class order: classes_ is learned from the training labels only, so if a
# genre never appeared in the training split, sorted(y.unique()) would be misaligned with it
# and nodes would be labelled with the wrong genres.
class_names = list(trained_model.classes_)

tree.export_graphviz(trained_model,
                     out_file='music-recommender.dot',
                     feature_names=feature_names,
                     class_names=class_names,
                     label='all',
                     rounded=True,
                     filled=True)