# 05 - Model and Prediction - SMALL COLAB Version

Model selection and tuning of the model

## Data files needed to run this notebook:
- X_train.pkl.gz
- X_test.pkl.gz
- y_test_.pkl.gz
- y_train.pkl.gz

all the results from notebook 04

## Settings:
- set `COLAB = True` if you run this on Colab. Data can be placed in the root directory

In [29]:
# setup
import sys
import subprocess
import pkg_resources
from collections import Counter
import re
from numpy import log, mean, matmul


# required = {'spacy', 'scikit-learn', 'numpy', 
#             'pandas', 'torch', 'matplotlib',
#             'transformers', 'allennlp==0.9.0'}
# installed = {pkg.key for pkg in pkg_resources.working_set}
# missing = required - installed

# if missing:
#     python = sys.executable
#     subprocess.check_call([python, '-m', 'pip', 'install', *missing], stdout=subprocess.DEVNULL)

import spacy
import numpy as np
import pandas as pd

# SciKit Learn
from sklearn.feature_extraction.text import CountVectorizer, TfidfVectorizer
from sklearn.decomposition import NMF, LatentDirichletAllocation
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.preprocessing import normalize
from sklearn.metrics import accuracy_score
from sklearn.svm import LinearSVC
from sklearn.svm import SVC


# Spacy
from spacy.lang.en import English
en = English()

# !python -m spacy download en_core_web_md # includes GloVe Vectors
# !python -m spacy download en_core_web_sm
# !python -m spacy download en

# import en_core_web_sm
# import en_core_web_md


# PyTorch
import torch
# import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader

from sklearn.model_selection import train_test_split

# File managment
import os
from os import listdir
from pathlib import Path
import pickle
import gzip

In [47]:
LOAD_DATA = False # read save data or regenerate data
SAVE_DATA = False # overwrite generated data? 

COLAB = True

In [48]:
if COLAB:
  # Google Colab
  path = "./"
  device = torch.device("cuda:0") # use GPU, change 
else:
  # Laptop
  path = "./data/"
  device = torch.device("cpu")
#   !pip install ipywidgets
#   !jupyter nbextension enable --py widgetsnbextension

In [49]:
# df_total = pd.read_pickle(f'{path}df_total_cleaned.pkl.gz')

X_train = pd.read_pickle(f'{path}X_train.pkl.gz')
y_train = pd.read_pickle(f'{path}y_train.pkl.gz')



In [50]:
text_cols = ["SName", "Lyric", "Artist"]
genres = list(pd.DataFrame(y_train)["Genre"].unique())

In [51]:
y_train

19167     Hip Hop
88089     Hip Hop
41285     Hip Hop
12664        Rock
52249     Hip Hop
           ...   
127879      Metal
80491        Rock
66750        Rock
133337      Metal
99994        Rock
Name: Genre, Length: 8400, dtype: object

In [52]:
test_size = 0.3
tmp = X_train.drop(text_cols,axis=1)
tmp

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,...,728,729,730,731,732,733,734,735,736,737,738,739,740,741,742,743,744,745,746,747,748,749,750,751,752,753,754,755,756,757,758,759,760,761,762,763,764,765,766,767
19167,0.001492,-0.273682,0.150513,-0.246460,-0.050476,-0.072327,0.447510,0.275391,-0.255859,-0.431396,-0.081238,0.441895,-0.335938,-0.404785,-0.085999,0.527344,0.418213,0.027115,0.174438,-0.009880,-0.281738,-0.009178,0.066895,-0.239380,0.115173,0.156494,-0.669434,-0.170898,-0.112000,-0.060516,-0.484375,0.589844,0.041840,-0.434814,0.671387,0.115479,0.577148,-0.383057,-0.043182,-0.361084,...,-0.031464,-0.180786,0.007504,0.675781,0.023895,-0.534668,0.239746,-0.245850,0.010201,-0.446045,-0.302002,0.046783,-0.059479,-0.473389,-0.381592,1.175781,-0.134888,0.250488,-0.156982,-0.307861,-0.015541,-0.601074,-0.520020,0.265625,-5.578125,-0.268555,0.041260,-0.620117,-0.388916,0.096130,0.408203,-1.025391,-0.107178,-0.021561,0.352051,0.488770,-0.528320,-0.062469,0.002211,0.423340
88089,0.037231,0.150635,0.496826,0.196777,-0.322998,-0.077271,0.991211,0.152954,-0.482910,-0.359619,-0.137451,-0.024475,-0.678223,0.299316,0.220581,0.517090,0.036346,-0.056274,0.417236,0.355957,0.022278,-0.016312,0.398193,-0.081360,0.247314,-0.013885,-0.188965,-0.506348,-0.417236,0.346924,-0.166260,0.611816,-0.352539,-0.379639,-0.019470,-0.540527,0.359131,-0.183228,0.541504,-0.265137,...,-0.192505,-0.093079,-0.156372,0.909668,-0.449951,-0.079224,-0.250977,0.151123,-0.406006,-0.051727,-0.394043,0.213257,-0.023788,-0.221436,-0.068909,0.601562,-0.221191,0.357666,-0.402588,-0.054291,-0.025070,0.028488,-0.106384,0.453857,-4.980469,-0.636230,-0.126953,-0.451660,-0.179565,-0.520508,0.160156,-0.338623,0.003363,0.037933,0.612305,0.417236,-0.303955,-0.520508,0.493164,0.340820
41285,-0.041718,-0.246704,-0.162720,0.137085,-0.043579,-0.229004,0.499023,0.824219,0.036102,-0.441650,-0.287842,0.302002,-0.357422,-0.162109,0.616211,0.209351,-0.001165,0.073853,0.220581,0.062866,-0.047211,-0.130005,0.078247,-0.201782,-0.113220,-0.219360,-0.531250,-0.626953,0.180786,-0.088928,-0.255615,0.347656,0.206543,-0.015564,0.802246,0.007030,-0.061890,-0.476074,0.155151,0.197388,...,-0.265869,-0.166504,-0.100159,0.726074,-0.086182,0.092102,0.175537,-0.158691,0.028442,-0.524902,0.104309,0.338623,-0.042816,-0.288574,-0.396729,0.647461,0.274902,0.750488,-0.162964,0.336182,-0.124146,-0.708496,-0.353516,-0.080688,-5.890625,-0.016678,-0.112183,-0.858398,-0.758789,-0.257568,0.738770,-0.558594,-0.090698,-0.180176,0.313232,0.296143,-0.604004,-0.382080,0.578613,0.207520
12664,-0.325684,0.143311,-0.226196,-0.002642,-0.749023,0.127930,0.955078,0.399902,-0.119263,-0.517090,-0.204956,-0.348389,-0.453613,0.671875,0.931641,0.144043,0.054565,0.833496,0.304688,0.268066,-0.174561,-0.203003,0.282715,-0.172363,0.033081,-0.107117,-0.324707,-0.375488,-0.130981,-0.239624,-0.355713,0.708984,-0.438232,-0.604492,0.538086,-0.252441,0.388916,0.180908,0.363770,-0.345215,...,-0.143433,-0.485840,-0.148071,0.555664,-0.346436,-0.685059,0.348145,0.134766,-0.212891,-0.307129,-0.403809,0.779297,0.333008,-0.410889,-0.045044,0.955078,0.443115,0.821777,-0.252441,-0.125244,0.316895,-0.520996,0.196411,0.341064,-5.328125,-0.554688,-0.337158,-0.794922,0.208862,0.528320,0.154175,-0.275879,-0.046143,-0.173706,0.242676,0.716797,-0.252197,-0.583008,0.321533,0.259033
52249,0.061066,0.076416,0.814453,-0.269531,-0.344238,0.150757,0.812500,0.367432,-0.025284,-0.302490,-0.051453,0.136597,-0.394287,-0.042206,0.085449,0.557129,0.353027,-0.274414,0.265625,0.221313,-0.042145,-0.318848,0.266602,-0.055389,0.083191,0.169922,-0.253662,-0.134888,0.048798,0.558105,-0.302979,0.624512,-0.003006,-0.409180,0.468506,-0.600098,0.716309,-0.145508,0.448242,-0.265137,...,-0.317871,0.006969,-0.120300,0.863770,-0.075195,-0.203613,-0.382568,-0.025375,-0.026672,-0.160645,-0.152100,0.079102,0.102905,-0.731934,-0.254395,1.055664,-0.647461,0.386719,-0.542969,0.124084,0.029068,-0.443359,-0.394043,0.487061,-4.847656,-0.353271,0.166138,-0.452637,-0.623535,-0.463867,-0.408936,-0.527344,-0.401855,0.095398,0.742676,0.281250,-0.496582,-0.437256,0.364258,0.228882
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
127879,-0.169678,0.041229,0.236816,-0.042542,-0.390625,-0.249268,0.797852,0.419189,-0.269531,-0.881348,0.283203,-0.315674,-0.225220,0.750977,-0.264160,0.145996,0.493408,0.444092,0.487305,0.300293,0.118835,-0.132568,0.640625,-0.134399,0.342773,-0.107422,-0.538086,0.051270,0.029022,0.494873,-0.059296,0.302734,0.121582,-0.775391,0.310303,-0.322754,0.295654,-0.295166,0.272217,0.430176,...,-0.672363,-0.103333,-0.054535,0.238525,0.277100,-0.812988,-0.215820,0.341553,-0.468506,-0.202148,-0.012787,-0.363525,-0.014145,0.385742,0.023468,0.827148,0.192261,-0.101196,-0.141602,-0.149170,0.308350,-0.067932,-0.267578,0.442139,-4.933594,-0.438965,-0.060242,-0.156250,0.106506,-0.392090,-0.206543,-0.823730,-0.051056,-0.105835,0.450439,0.190796,-0.207031,-0.497314,0.233032,-0.636230
80491,-0.017868,0.076111,0.343506,-0.726562,-0.599609,0.000035,0.419678,0.360107,0.517578,-0.833008,0.133667,0.288086,-0.236694,0.443359,0.341309,0.136475,0.438721,0.131470,0.052002,0.017960,-0.204834,0.119690,0.112793,0.211792,0.192627,0.125244,-0.288086,-0.123169,0.099731,0.099426,-0.162842,0.489990,-0.192139,-0.637207,0.658691,-0.134644,0.090088,-0.043518,0.305908,-0.226196,...,-0.236084,-0.075806,0.239258,0.950684,-0.218140,-0.249268,0.467041,0.086365,-0.368652,-0.419922,-0.633789,0.541016,0.234619,-0.080811,0.077332,0.668945,0.080688,0.064087,-0.240601,-0.270264,0.257324,-0.272949,-0.176147,0.298584,-5.941406,-0.155762,0.031799,-0.216797,-0.382324,-0.347412,0.453857,-0.279053,-0.021866,-0.483154,0.136719,0.191284,-0.330811,-0.092896,0.070190,0.769043
66750,-0.179810,-0.208984,-0.071289,0.151611,-0.094727,-0.214478,0.788574,0.704102,-0.184082,-0.365723,-0.155273,0.162354,-0.180176,0.273193,0.255371,0.219971,0.014046,0.687012,0.570801,0.197021,-0.217285,0.025986,0.299561,-0.076721,0.312012,0.030197,-0.500488,-0.373047,0.114258,-0.196899,-0.232788,0.399414,-0.200562,-0.518555,0.427979,-0.056824,0.647461,-0.203125,0.146729,0.037109,...,-0.063416,-0.127563,-0.278320,0.545410,-0.053894,-0.519531,-0.061707,0.084167,-0.330078,-0.374512,-0.229370,-0.082642,0.113281,0.044281,-0.229004,0.651367,0.415527,0.456787,-0.083984,0.013496,0.080872,-0.166382,-0.049713,0.364746,-6.593750,-0.250000,-0.008354,-0.287354,-0.019943,-0.253174,0.017578,-0.797852,0.142456,-0.049011,0.182617,0.599609,-0.137939,-0.099487,-0.040466,0.257568
133337,-0.553223,0.139282,0.312256,0.208862,-0.581055,-0.247070,0.706543,0.520508,0.349121,-0.726074,0.071350,0.096497,-0.344482,-0.035065,0.369629,0.245605,0.229370,0.522949,0.848145,0.458740,0.150635,-0.423096,0.501465,0.227295,0.607422,-0.039581,-0.427490,-0.255859,-0.123535,-0.170044,0.012985,0.593750,-0.204102,-0.582520,0.712891,-0.234009,0.781738,-0.579590,0.753906,0.154297,...,-0.228638,-0.085999,0.122498,0.687988,0.342773,-0.433105,-0.331299,-0.442871,-0.267578,-0.053253,0.038696,0.188599,0.171021,0.041107,-0.028442,1.208008,0.260254,0.818848,0.064819,0.040680,0.153931,-0.506348,-0.083496,0.056366,-5.609375,-0.537598,-0.261475,-0.466309,0.448730,-0.231323,-0.300049,-0.753906,-0.207153,-0.024521,0.374512,0.519043,0.003078,-0.461670,0.013908,-0.375977


In [53]:
X_train_set, X_val_set, y_train_set, y_val_set = train_test_split(tmp, y_train, test_size=test_size, random_state=0, shuffle = True, stratify = y_train)

In [54]:
from sklearn.svm import LinearSVC
from sklearn.svm import SVC

def train_SVC(x,y,x_val, y_val):
  model = LinearSVC(max_iter=2000)
  model.fit(x, y)
  val_preds = model.predict(x_val)
  acc = accuracy_score(y_val, val_preds)
  print(f"Accuracy: {round(acc,2)}")
  return (acc, model, val_preds)

In [55]:
(acc, model, predictions) = train_SVC(X_train_set, y_train_set, X_val_set, y_val_set)

# running rock,hiphop, metal: -> 92%
# running pop,hiphop, metal: -> 85%
# running 4 genres -> 75%
# running pop, rock, hiphop _> 0.67

Accuracy: 0.92




In [56]:
# confusiont matrix
from sklearn.metrics import confusion_matrix

In [57]:
def print_confustion_matrix(model, y_val_set, predictions):
  cm = confusion_matrix(y_val_set, predictions)
  df = pd.DataFrame(cm, columns = model.classes_, index= model.classes_)
  print(df)
  
  
  

In [58]:
print_confustion_matrix(model, y_val_set, predictions)
  

         Hip Hop  Metal  Rock
Hip Hop      747      0    93
Metal          0    840     0
Rock         101      4   735


In [59]:
def wrong_classifications(X_train, y, predictions, genres):
  print("Truth - predicted")
  predictions_df = pd.DataFrame(predictions, columns = ["Genre_Predicted"])
  truth_df = pd.DataFrame(y)
  truth_df.columns = ["Genre_Truth"]
  combined_df = pd.concat([truth_df.reset_index(drop=True), predictions_df.reset_index(drop=True)], axis=1)
  for i in genres:
    for j in genres:
      if i!=j:
        idx = combined_df.query(f"Genre_Truth =='{i}' != Genre_Predicted=='{j}'").index
        if len(idx)>0:
          print("------------------------------")
          print(f"{i} - {j}")
          print("------------------------------")
          print(X_train.iloc[idx]["Lyric"])
 

In [60]:
 
wrong_classifications(X_train, y_val_set,predictions  , genres)

Truth - predicted
------------------------------
Hip Hop - Rock
------------------------------
6121      previously on ashanti "always there when you c...
19946     Before I get started. polo this beat is retart...
126819    Patrz   w siebie i widz   coraz mniej Powoli p...
33512     Fatjoe: TS. Thalía: Hey baby. Fatjoe: Yeah. Th...
126051    I felt the ground start to shake  Oh God  oh G...
                                ...                        
30991     Its 2002, everything was totally new. We were ...
21676     It's time to make a difference. I know he's ba...
112568    Cross me once more fool me cross me twice and ...
6192       why dont you take me tonight. take me away wh...
66616     she said that she'd take it off right here. ta...
Name: Lyric, Length: 93, dtype: object
------------------------------
Rock - Hip Hop
------------------------------
26589     I didn't hear you leave,. I wonder how am I st...
59383     Your breath is sweet. Your eyes are like two j...
49296    

In [61]:
from sklearn.metrics import classification_report
print(classification_report(y_val_set, predictions, target_names=genres))

              precision    recall  f1-score   support

     Hip Hop       0.88      0.89      0.89       840
        Rock       1.00      1.00      1.00       840
       Metal       0.89      0.88      0.88       840

    accuracy                           0.92      2520
   macro avg       0.92      0.92      0.92      2520
weighted avg       0.92      0.92      0.92      2520



In [62]:
# Try different kernels and values of hyper parameters to see if we can improve the score
# this method uses cross validation so we could use the whole data set (we do ot so we can use the val to get the test score of the fitted model, with the parameters. )

In [63]:
%%time
from sklearn import svm, datasets
from sklearn.model_selection import GridSearchCV
parameters = {'kernel':('linear', 'rbf', 'poly', 'sigmoid'), 'C':[1, 10]}
svc = svm.SVC()
clf = GridSearchCV(svc, parameters)
clf.fit(X_train_set, y_train_set)


CPU times: user 5min 8s, sys: 103 ms, total: 5min 8s
Wall time: 5min 9s


In [64]:
print(clf.best_params_)

{'C': 10, 'kernel': 'poly'}


In [65]:
list(zip(clf.cv_results_['params'], clf.cv_results_['mean_test_score']))

[({'C': 1, 'kernel': 'linear'}, 0.9285714285714285),
 ({'C': 1, 'kernel': 'rbf'}, 0.9396258503401361),
 ({'C': 1, 'kernel': 'poly'}, 0.9430272108843537),
 ({'C': 1, 'kernel': 'sigmoid'}, 0.9333333333333333),
 ({'C': 10, 'kernel': 'linear'}, 0.9171768707482993),
 ({'C': 10, 'kernel': 'rbf'}, 0.9430272108843537),
 ({'C': 10, 'kernel': 'poly'}, 0.94421768707483),
 ({'C': 10, 'kernel': 'sigmoid'}, 0.9047619047619048)]

In [66]:
pd.DataFrame(clf.cv_results_)

Unnamed: 0,mean_fit_time,std_fit_time,mean_score_time,std_score_time,param_C,param_kernel,params,split0_test_score,split1_test_score,split2_test_score,split3_test_score,split4_test_score,mean_test_score,std_test_score,rank_test_score
0,5.504284,0.181683,1.11517,0.023845,1,linear,"{'C': 1, 'kernel': 'linear'}",0.933673,0.931973,0.928571,0.918367,0.930272,0.928571,0.005378,6
1,6.460202,0.041378,1.812262,0.074424,1,rbf,"{'C': 1, 'kernel': 'rbf'}",0.942177,0.941327,0.938776,0.926871,0.94898,0.939626,0.007215,4
2,6.054304,0.154642,1.622746,0.027551,1,poly,"{'C': 1, 'kernel': 'poly'}",0.943878,0.946429,0.942177,0.932823,0.94983,0.943027,0.005717,2
3,7.570076,0.223949,2.100247,0.084721,1,sigmoid,"{'C': 1, 'kernel': 'sigmoid'}",0.937925,0.931122,0.929422,0.92432,0.943878,0.933333,0.006837,5
4,6.106257,0.298831,1.013373,0.015316,10,linear,"{'C': 10, 'kernel': 'linear'}",0.914116,0.92432,0.92517,0.911565,0.910714,0.917177,0.006286,7
5,5.525364,0.156075,1.528579,0.045745,10,rbf,"{'C': 10, 'kernel': 'rbf'}",0.946429,0.94898,0.941327,0.930272,0.948129,0.943027,0.006908,2
6,5.5297,0.331987,1.491638,0.071126,10,poly,"{'C': 10, 'kernel': 'poly'}",0.94898,0.95068,0.939626,0.933673,0.948129,0.944218,0.006512,1
7,5.759659,0.294981,1.047126,0.01688,10,sigmoid,"{'C': 10, 'kernel': 'sigmoid'}",0.907313,0.906463,0.907313,0.905612,0.897109,0.904762,0.003878,8


Sigmoid is doing the worst
- Linear betwee 0.91 and 0.92
- rbf: 0.93, 0.94
- Poly: 0.94 (may be overfitting) but ar the best scores
- Sigmoid, 0.93, 0.90

# SVC: 
Choose the best kernel and optimize that one.
We have vectors in a many dimensional space, so we don't really know what is the valid choice. We just try to optimize the problem, not trying to explain what happens under the hood. 

## Radial Kernel - RBF

In [103]:
%%time
from sklearn import svm, datasets
from sklearn.model_selection import GridSearchCV
parameters = {'kernel':(['rbf']), 'C':[0.5, 1, 5, 10,20]}
svc = svm.SVC()
clf = GridSearchCV(svc, parameters)
clf.fit(X_train_set, y_train_set)


CPU times: user 3min 9s, sys: 55.1 ms, total: 3min 9s
Wall time: 3min 10s


In [104]:
print(clf.best_params_)

{'C': 5, 'kernel': 'rbf'}


In [105]:
list(zip(clf.cv_results_['params'], clf.cv_results_['mean_test_score']))

[({'C': 0.5, 'kernel': 'rbf'}, 0.936734693877551),
 ({'C': 1, 'kernel': 'rbf'}, 0.9396258503401361),
 ({'C': 5, 'kernel': 'rbf'}, 0.9438775510204082),
 ({'C': 10, 'kernel': 'rbf'}, 0.9430272108843537),
 ({'C': 20, 'kernel': 'rbf'}, 0.9431972789115648)]

- For 3 genres: Best value for RBF: {'C': 5, 'kernel': 'rbf'} 0.94
- For 4 genres: ({'C': 5, 'kernel': 'rbf'}, 0.7687074829931972),

## Polynomial Kernel
Optimize _C_ value and degrees of the polynomial approximation.

In [106]:
%%time
parameters = {'kernel':(['poly']), 'C':[0.5, 1, 5, 10,20], 'degree':[2,3,4]}
svc = svm.SVC()
clf = GridSearchCV(svc, parameters)
clf.fit(X_train_set, y_train_set)

CPU times: user 8min 40s, sys: 129 ms, total: 8min 40s
Wall time: 8min 40s


In [107]:
print(clf.best_params_)

{'C': 5, 'degree': 4, 'kernel': 'poly'}


In [108]:
pd.DataFrame(clf.cv_results_)

Unnamed: 0,mean_fit_time,std_fit_time,mean_score_time,std_score_time,param_C,param_degree,param_kernel,params,split0_test_score,split1_test_score,split2_test_score,split3_test_score,split4_test_score,mean_test_score,std_test_score,rank_test_score
0,6.113123,0.076298,1.723581,0.015218,0.5,2,poly,"{'C': 0.5, 'degree': 2, 'kernel': 'poly'}",0.941327,0.937925,0.937075,0.92602,0.948129,0.938095,0.007183,15
1,5.991432,0.055215,1.695894,0.015683,0.5,3,poly,"{'C': 0.5, 'degree': 3, 'kernel': 'poly'}",0.938776,0.940476,0.937925,0.92602,0.948129,0.938265,0.007106,14
2,6.079806,0.065062,1.704202,0.022001,0.5,4,poly,"{'C': 0.5, 'degree': 4, 'kernel': 'poly'}",0.936224,0.940476,0.937925,0.929422,0.95068,0.938946,0.006917,13
3,5.386266,0.081984,1.528348,0.017709,1.0,2,poly,"{'C': 1, 'degree': 2, 'kernel': 'poly'}",0.938776,0.944728,0.940476,0.927721,0.94898,0.940136,0.007147,11
4,5.385731,0.08007,1.528152,0.026072,1.0,3,poly,"{'C': 1, 'degree': 3, 'kernel': 'poly'}",0.943878,0.946429,0.942177,0.932823,0.94983,0.943027,0.005717,6
5,5.579892,0.07534,1.583597,0.030671,1.0,4,poly,"{'C': 1, 'degree': 4, 'kernel': 'poly'}",0.943878,0.948129,0.943878,0.934524,0.953231,0.944728,0.006155,2
6,4.665254,0.038387,1.347175,0.013235,5.0,2,poly,"{'C': 5, 'degree': 2, 'kernel': 'poly'}",0.943878,0.947279,0.940476,0.930272,0.947279,0.941837,0.006309,8
7,4.964838,0.127233,1.414579,0.020016,5.0,3,poly,"{'C': 5, 'degree': 3, 'kernel': 'poly'}",0.948129,0.945578,0.942177,0.932823,0.94983,0.943707,0.006022,4
8,5.241765,0.081807,1.499978,0.029474,5.0,4,poly,"{'C': 5, 'degree': 4, 'kernel': 'poly'}",0.94983,0.948129,0.943027,0.936224,0.94898,0.945238,0.005091,1
9,4.658473,0.070518,1.329366,0.021036,10.0,2,poly,"{'C': 10, 'degree': 2, 'kernel': 'poly'}",0.94983,0.94898,0.940476,0.931973,0.942177,0.942687,0.006489,7


A lot of calculations, but the differences are very small. So the poly nomial with 2 degrees and C=5 will do just fine. Better to have less degrees, than more to prevent overfitting.

For 4 categories:
- {'C': 1, 'degree': 3, 'kernel': 'poly'}


# Stochastic gradient descent

In [109]:
# from sklearn.preprocessing import LabelEncoder
# le = LabelEncoder()
# y_train_num = le.fit_transform(y_train_set)
# y_val_num = le.transform(y_val_set)


In [110]:
import numpy as np
from sklearn.linear_model import SGDClassifier
from sklearn.preprocessing import StandardScaler
from sklearn.pipeline import make_pipeline

# Always scale the input. The most convenient way is to use a pipeline.
clf = make_pipeline(StandardScaler(),
                    SGDClassifier(max_iter=1000, tol=1e-3))

clf.fit(X_train_set, y_train_set)

Pipeline(memory=None,
         steps=[('standardscaler',
                 StandardScaler(copy=True, with_mean=True, with_std=True)),
                ('sgdclassifier',
                 SGDClassifier(alpha=0.0001, average=False, class_weight=None,
                               early_stopping=False, epsilon=0.1, eta0=0.0,
                               fit_intercept=True, l1_ratio=0.15,
                               learning_rate='optimal', loss='hinge',
                               max_iter=1000, n_iter_no_change=5, n_jobs=None,
                               penalty='l2', power_t=0.5, random_state=None,
                               shuffle=True, tol=0.001, validation_fraction=0.1,
                               verbose=0, warm_start=False))],
         verbose=False)

In [111]:
predictions = clf.predict(X_val_set)

In [112]:
accuracy_score(y_val_set, predictions)

0.928968253968254

The performance here is worse.

# KNeigbors

In [113]:
from sklearn.neighbors import KNeighborsClassifier
neigh = KNeighborsClassifier(n_neighbors=3)
neigh.fit(X_train_set, y_train_set)

predictions = neigh.predict(X_val_set)
accuracy_score(y_val_set, predictions)
# print(neigh.predict_proba([[0.9]]))

0.888095238095238

So this one in less accurate. But let's see if we can improve the results. 

In [114]:
# Tuning of the results

In [115]:
%%time 
parameters = {'n_neighbors':[3,4,5,6,7]}
clf = GridSearchCV(neigh, parameters)
clf.fit(X_train_set, y_train_set)

CPU times: user 4min 30s, sys: 58.4 ms, total: 4min 30s
Wall time: 4min 30s


In [116]:
print(clf.best_params_)

{'n_neighbors': 5}


In [117]:
pd.DataFrame(clf.cv_results_)

Unnamed: 0,mean_fit_time,std_fit_time,mean_score_time,std_score_time,param_n_neighbors,params,split0_test_score,split1_test_score,split2_test_score,split3_test_score,split4_test_score,mean_test_score,std_test_score,rank_test_score
0,0.453238,0.01289,10.259557,0.00651,3,{'n_neighbors': 3},0.911565,0.897109,0.914116,0.893707,0.909014,0.905102,0.008149,3
1,0.445331,0.002572,10.301003,0.023142,4,{'n_neighbors': 4},0.903912,0.880952,0.90051,0.887755,0.891156,0.892857,0.008384,5
2,0.450969,0.010486,10.383743,0.053917,5,{'n_neighbors': 5},0.92517,0.904762,0.912415,0.897959,0.920918,0.912245,0.010024,1
3,0.446359,0.003466,10.3965,0.023241,6,{'n_neighbors': 6},0.911565,0.893707,0.903912,0.894558,0.917517,0.904252,0.009324,4
4,0.445768,0.003286,10.459926,0.029637,7,{'n_neighbors': 7},0.917517,0.903061,0.907313,0.905612,0.92517,0.911735,0.008321,2


In [118]:
neigh = KNeighborsClassifier(n_neighbors=clf.best_params_["n_neighbors"])
neigh.fit(X_train_set, y_train_set)

predictions = neigh.predict(X_val_set)
accuracy_score(y_val_set, predictions)
# print(neigh.predict_proba([[0.9]]))

0.8952380952380953


A little improvement over the 3 neighbors, but not so much

# Naive Bayes

In [119]:
from sklearn.naive_bayes import GaussianNB
clf = GaussianNB()
clf.fit(X_train_set, y_train_set)

predictions = clf.predict(X_val_set)
accuracy_score(y_val_set, predictions)

0.9087301587301587

So this option also performs less. The assumption for Guassian distribution is probably als not valid

# Neural Network

In [120]:
from sklearn.neural_network import MLPClassifier
from sklearn.datasets import make_classification

clf = MLPClassifier(random_state=1, hidden_layer_sizes=(500,500), max_iter=500).fit(X_train_set, y_train_set)

predictions = clf.predict(X_val_set)
accuracy_score(y_val_set, predictions)

0.9444444444444444

In [121]:
clf.predict_proba(X_val_set[0:10])

array([[1.06823831e-18, 1.00000000e+00, 4.39361605e-23],
       [1.00000000e+00, 5.75843456e-16, 1.89994343e-15],
       [8.51008301e-06, 3.02647639e-08, 9.99991460e-01],
       [8.09631690e-20, 1.00000000e+00, 2.84926628e-15],
       [2.19418652e-11, 7.23780273e-13, 1.00000000e+00],
       [1.00000000e+00, 1.21130592e-12, 1.50231330e-10],
       [5.89613746e-08, 2.35962731e-09, 9.99999939e-01],
       [1.00000000e+00, 2.32063741e-16, 1.82761602e-15],
       [9.74750658e-01, 2.44990193e-08, 2.52493172e-02],
       [2.53563641e-09, 1.10371443e-11, 9.99999997e-01]])

In [122]:
predictions = clf.predict(X_val_set)
accuracy_score(y_val_set, predictions)

0.9444444444444444

In [123]:
clf.score(X_val_set, y_val_set)

0.9444444444444444

In [124]:
# All the results are in the range of 92%-94%
print_confustion_matrix(clf, y_val_set, predictions)

         Hip Hop  Metal  Rock
Hip Hop      762      0    78
Metal          1    839     0
Rock          58      3   779


- different classifier have marginal effect
- There must be something in the dataset in the Metal, it's classified correctly almost too much compared to the other genres. But is not clear what its. 
    - Some different languages. 

Best model: RBC for SVC
({'C': 5, 'kernel': 'rbf'}, 0.7687074829931972) for 4 category classifier.

Neural net: 500x500 hidden layers.
 - Pop and rock are the genres most closely related.

# Result of the chosen model

- For the front end we'll set up the system with SVC, Radial kernel, with C=5
- We fit the model on the complete training set
- We evaluatoin on the test set

In [125]:
X_train.head()

Unnamed: 0,SName,Lyric,Artist,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,...,728,729,730,731,732,733,734,735,736,737,738,739,740,741,742,743,744,745,746,747,748,749,750,751,752,753,754,755,756,757,758,759,760,761,762,763,764,765,766,767
19167,Back To Sleep,"I know it's late, I know it's late. And baby I...",Chris Brown,0.001492,-0.273682,0.150513,-0.24646,-0.050476,-0.072327,0.44751,0.275391,-0.255859,-0.431396,-0.081238,0.441895,-0.335938,-0.404785,-0.085999,0.527344,0.418213,0.027115,0.174438,-0.00988,-0.281738,-0.009178,0.066895,-0.23938,0.115173,0.156494,-0.669434,-0.170898,-0.112,-0.060516,-0.484375,0.589844,0.04184,-0.434814,0.671387,0.115479,0.577148,...,-0.031464,-0.180786,0.007504,0.675781,0.023895,-0.534668,0.239746,-0.24585,0.010201,-0.446045,-0.302002,0.046783,-0.059479,-0.473389,-0.381592,1.175781,-0.134888,0.250488,-0.156982,-0.307861,-0.015541,-0.601074,-0.52002,0.265625,-5.578125,-0.268555,0.04126,-0.620117,-0.388916,0.09613,0.408203,-1.025391,-0.107178,-0.021561,0.352051,0.48877,-0.52832,-0.062469,0.002211,0.42334
88089,The Best,The Best. Soulja Boy. Soulja! Soulja! Soulja! ...,Soulja Boy,0.037231,0.150635,0.496826,0.196777,-0.322998,-0.077271,0.991211,0.152954,-0.48291,-0.359619,-0.137451,-0.024475,-0.678223,0.299316,0.220581,0.51709,0.036346,-0.056274,0.417236,0.355957,0.022278,-0.016312,0.398193,-0.08136,0.247314,-0.013885,-0.188965,-0.506348,-0.417236,0.346924,-0.16626,0.611816,-0.352539,-0.379639,-0.01947,-0.540527,0.359131,...,-0.192505,-0.093079,-0.156372,0.909668,-0.449951,-0.079224,-0.250977,0.151123,-0.406006,-0.051727,-0.394043,0.213257,-0.023788,-0.221436,-0.068909,0.601562,-0.221191,0.357666,-0.402588,-0.054291,-0.02507,0.028488,-0.106384,0.453857,-4.980469,-0.63623,-0.126953,-0.45166,-0.179565,-0.520508,0.160156,-0.338623,0.003363,0.037933,0.612305,0.417236,-0.303955,-0.520508,0.493164,0.34082
41285,Just Askin',"Wassup, in your world?. And are you still cool...",Iggy Azalea,-0.041718,-0.246704,-0.16272,0.137085,-0.043579,-0.229004,0.499023,0.824219,0.036102,-0.44165,-0.287842,0.302002,-0.357422,-0.162109,0.616211,0.209351,-0.001165,0.073853,0.220581,0.062866,-0.047211,-0.130005,0.078247,-0.201782,-0.11322,-0.21936,-0.53125,-0.626953,0.180786,-0.088928,-0.255615,0.347656,0.206543,-0.015564,0.802246,0.00703,-0.06189,...,-0.265869,-0.166504,-0.100159,0.726074,-0.086182,0.092102,0.175537,-0.158691,0.028442,-0.524902,0.104309,0.338623,-0.042816,-0.288574,-0.396729,0.647461,0.274902,0.750488,-0.162964,0.336182,-0.124146,-0.708496,-0.353516,-0.080688,-5.890625,-0.016678,-0.112183,-0.858398,-0.758789,-0.257568,0.73877,-0.558594,-0.090698,-0.180176,0.313232,0.296143,-0.604004,-0.38208,0.578613,0.20752
12664,You Wear A Crown But You're No King,You'll never stop 'til you get what you want. ...,Blessthefall,-0.325684,0.143311,-0.226196,-0.002642,-0.749023,0.12793,0.955078,0.399902,-0.119263,-0.51709,-0.204956,-0.348389,-0.453613,0.671875,0.931641,0.144043,0.054565,0.833496,0.304688,0.268066,-0.174561,-0.203003,0.282715,-0.172363,0.033081,-0.107117,-0.324707,-0.375488,-0.130981,-0.239624,-0.355713,0.708984,-0.438232,-0.604492,0.538086,-0.252441,0.388916,...,-0.143433,-0.48584,-0.148071,0.555664,-0.346436,-0.685059,0.348145,0.134766,-0.212891,-0.307129,-0.403809,0.779297,0.333008,-0.410889,-0.045044,0.955078,0.443115,0.821777,-0.252441,-0.125244,0.316895,-0.520996,0.196411,0.341064,-5.328125,-0.554688,-0.337158,-0.794922,0.208862,0.52832,0.154175,-0.275879,-0.046143,-0.173706,0.242676,0.716797,-0.252197,-0.583008,0.321533,0.259033
52249,Kevin Gates,Workout. Tell. Workout. Tell. Gates. Gates. Ga...,Kevin Gates,0.061066,0.076416,0.814453,-0.269531,-0.344238,0.150757,0.8125,0.367432,-0.025284,-0.30249,-0.051453,0.136597,-0.394287,-0.042206,0.085449,0.557129,0.353027,-0.274414,0.265625,0.221313,-0.042145,-0.318848,0.266602,-0.055389,0.083191,0.169922,-0.253662,-0.134888,0.048798,0.558105,-0.302979,0.624512,-0.003006,-0.40918,0.468506,-0.600098,0.716309,...,-0.317871,0.006969,-0.1203,0.86377,-0.075195,-0.203613,-0.382568,-0.025375,-0.026672,-0.160645,-0.1521,0.079102,0.102905,-0.731934,-0.254395,1.055664,-0.647461,0.386719,-0.542969,0.124084,0.029068,-0.443359,-0.394043,0.487061,-4.847656,-0.353271,0.166138,-0.452637,-0.623535,-0.463867,-0.408936,-0.527344,-0.401855,0.095398,0.742676,0.28125,-0.496582,-0.437256,0.364258,0.228882


In [126]:
y_train

19167     Hip Hop
88089     Hip Hop
41285     Hip Hop
12664        Rock
52249     Hip Hop
           ...   
127879      Metal
80491        Rock
66750        Rock
133337      Metal
99994        Rock
Name: Genre, Length: 8400, dtype: object

In [127]:
X_train_final = X_train.drop(text_cols, axis=1)
X_train_final

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,...,728,729,730,731,732,733,734,735,736,737,738,739,740,741,742,743,744,745,746,747,748,749,750,751,752,753,754,755,756,757,758,759,760,761,762,763,764,765,766,767
19167,0.001492,-0.273682,0.150513,-0.246460,-0.050476,-0.072327,0.447510,0.275391,-0.255859,-0.431396,-0.081238,0.441895,-0.335938,-0.404785,-0.085999,0.527344,0.418213,0.027115,0.174438,-0.009880,-0.281738,-0.009178,0.066895,-0.239380,0.115173,0.156494,-0.669434,-0.170898,-0.112000,-0.060516,-0.484375,0.589844,0.041840,-0.434814,0.671387,0.115479,0.577148,-0.383057,-0.043182,-0.361084,...,-0.031464,-0.180786,0.007504,0.675781,0.023895,-0.534668,0.239746,-0.245850,0.010201,-0.446045,-0.302002,0.046783,-0.059479,-0.473389,-0.381592,1.175781,-0.134888,0.250488,-0.156982,-0.307861,-0.015541,-0.601074,-0.520020,0.265625,-5.578125,-0.268555,0.041260,-0.620117,-0.388916,0.096130,0.408203,-1.025391,-0.107178,-0.021561,0.352051,0.488770,-0.528320,-0.062469,0.002211,0.423340
88089,0.037231,0.150635,0.496826,0.196777,-0.322998,-0.077271,0.991211,0.152954,-0.482910,-0.359619,-0.137451,-0.024475,-0.678223,0.299316,0.220581,0.517090,0.036346,-0.056274,0.417236,0.355957,0.022278,-0.016312,0.398193,-0.081360,0.247314,-0.013885,-0.188965,-0.506348,-0.417236,0.346924,-0.166260,0.611816,-0.352539,-0.379639,-0.019470,-0.540527,0.359131,-0.183228,0.541504,-0.265137,...,-0.192505,-0.093079,-0.156372,0.909668,-0.449951,-0.079224,-0.250977,0.151123,-0.406006,-0.051727,-0.394043,0.213257,-0.023788,-0.221436,-0.068909,0.601562,-0.221191,0.357666,-0.402588,-0.054291,-0.025070,0.028488,-0.106384,0.453857,-4.980469,-0.636230,-0.126953,-0.451660,-0.179565,-0.520508,0.160156,-0.338623,0.003363,0.037933,0.612305,0.417236,-0.303955,-0.520508,0.493164,0.340820
41285,-0.041718,-0.246704,-0.162720,0.137085,-0.043579,-0.229004,0.499023,0.824219,0.036102,-0.441650,-0.287842,0.302002,-0.357422,-0.162109,0.616211,0.209351,-0.001165,0.073853,0.220581,0.062866,-0.047211,-0.130005,0.078247,-0.201782,-0.113220,-0.219360,-0.531250,-0.626953,0.180786,-0.088928,-0.255615,0.347656,0.206543,-0.015564,0.802246,0.007030,-0.061890,-0.476074,0.155151,0.197388,...,-0.265869,-0.166504,-0.100159,0.726074,-0.086182,0.092102,0.175537,-0.158691,0.028442,-0.524902,0.104309,0.338623,-0.042816,-0.288574,-0.396729,0.647461,0.274902,0.750488,-0.162964,0.336182,-0.124146,-0.708496,-0.353516,-0.080688,-5.890625,-0.016678,-0.112183,-0.858398,-0.758789,-0.257568,0.738770,-0.558594,-0.090698,-0.180176,0.313232,0.296143,-0.604004,-0.382080,0.578613,0.207520
12664,-0.325684,0.143311,-0.226196,-0.002642,-0.749023,0.127930,0.955078,0.399902,-0.119263,-0.517090,-0.204956,-0.348389,-0.453613,0.671875,0.931641,0.144043,0.054565,0.833496,0.304688,0.268066,-0.174561,-0.203003,0.282715,-0.172363,0.033081,-0.107117,-0.324707,-0.375488,-0.130981,-0.239624,-0.355713,0.708984,-0.438232,-0.604492,0.538086,-0.252441,0.388916,0.180908,0.363770,-0.345215,...,-0.143433,-0.485840,-0.148071,0.555664,-0.346436,-0.685059,0.348145,0.134766,-0.212891,-0.307129,-0.403809,0.779297,0.333008,-0.410889,-0.045044,0.955078,0.443115,0.821777,-0.252441,-0.125244,0.316895,-0.520996,0.196411,0.341064,-5.328125,-0.554688,-0.337158,-0.794922,0.208862,0.528320,0.154175,-0.275879,-0.046143,-0.173706,0.242676,0.716797,-0.252197,-0.583008,0.321533,0.259033
52249,0.061066,0.076416,0.814453,-0.269531,-0.344238,0.150757,0.812500,0.367432,-0.025284,-0.302490,-0.051453,0.136597,-0.394287,-0.042206,0.085449,0.557129,0.353027,-0.274414,0.265625,0.221313,-0.042145,-0.318848,0.266602,-0.055389,0.083191,0.169922,-0.253662,-0.134888,0.048798,0.558105,-0.302979,0.624512,-0.003006,-0.409180,0.468506,-0.600098,0.716309,-0.145508,0.448242,-0.265137,...,-0.317871,0.006969,-0.120300,0.863770,-0.075195,-0.203613,-0.382568,-0.025375,-0.026672,-0.160645,-0.152100,0.079102,0.102905,-0.731934,-0.254395,1.055664,-0.647461,0.386719,-0.542969,0.124084,0.029068,-0.443359,-0.394043,0.487061,-4.847656,-0.353271,0.166138,-0.452637,-0.623535,-0.463867,-0.408936,-0.527344,-0.401855,0.095398,0.742676,0.281250,-0.496582,-0.437256,0.364258,0.228882
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
127879,-0.169678,0.041229,0.236816,-0.042542,-0.390625,-0.249268,0.797852,0.419189,-0.269531,-0.881348,0.283203,-0.315674,-0.225220,0.750977,-0.264160,0.145996,0.493408,0.444092,0.487305,0.300293,0.118835,-0.132568,0.640625,-0.134399,0.342773,-0.107422,-0.538086,0.051270,0.029022,0.494873,-0.059296,0.302734,0.121582,-0.775391,0.310303,-0.322754,0.295654,-0.295166,0.272217,0.430176,...,-0.672363,-0.103333,-0.054535,0.238525,0.277100,-0.812988,-0.215820,0.341553,-0.468506,-0.202148,-0.012787,-0.363525,-0.014145,0.385742,0.023468,0.827148,0.192261,-0.101196,-0.141602,-0.149170,0.308350,-0.067932,-0.267578,0.442139,-4.933594,-0.438965,-0.060242,-0.156250,0.106506,-0.392090,-0.206543,-0.823730,-0.051056,-0.105835,0.450439,0.190796,-0.207031,-0.497314,0.233032,-0.636230
80491,-0.017868,0.076111,0.343506,-0.726562,-0.599609,0.000035,0.419678,0.360107,0.517578,-0.833008,0.133667,0.288086,-0.236694,0.443359,0.341309,0.136475,0.438721,0.131470,0.052002,0.017960,-0.204834,0.119690,0.112793,0.211792,0.192627,0.125244,-0.288086,-0.123169,0.099731,0.099426,-0.162842,0.489990,-0.192139,-0.637207,0.658691,-0.134644,0.090088,-0.043518,0.305908,-0.226196,...,-0.236084,-0.075806,0.239258,0.950684,-0.218140,-0.249268,0.467041,0.086365,-0.368652,-0.419922,-0.633789,0.541016,0.234619,-0.080811,0.077332,0.668945,0.080688,0.064087,-0.240601,-0.270264,0.257324,-0.272949,-0.176147,0.298584,-5.941406,-0.155762,0.031799,-0.216797,-0.382324,-0.347412,0.453857,-0.279053,-0.021866,-0.483154,0.136719,0.191284,-0.330811,-0.092896,0.070190,0.769043
66750,-0.179810,-0.208984,-0.071289,0.151611,-0.094727,-0.214478,0.788574,0.704102,-0.184082,-0.365723,-0.155273,0.162354,-0.180176,0.273193,0.255371,0.219971,0.014046,0.687012,0.570801,0.197021,-0.217285,0.025986,0.299561,-0.076721,0.312012,0.030197,-0.500488,-0.373047,0.114258,-0.196899,-0.232788,0.399414,-0.200562,-0.518555,0.427979,-0.056824,0.647461,-0.203125,0.146729,0.037109,...,-0.063416,-0.127563,-0.278320,0.545410,-0.053894,-0.519531,-0.061707,0.084167,-0.330078,-0.374512,-0.229370,-0.082642,0.113281,0.044281,-0.229004,0.651367,0.415527,0.456787,-0.083984,0.013496,0.080872,-0.166382,-0.049713,0.364746,-6.593750,-0.250000,-0.008354,-0.287354,-0.019943,-0.253174,0.017578,-0.797852,0.142456,-0.049011,0.182617,0.599609,-0.137939,-0.099487,-0.040466,0.257568
133337,-0.553223,0.139282,0.312256,0.208862,-0.581055,-0.247070,0.706543,0.520508,0.349121,-0.726074,0.071350,0.096497,-0.344482,-0.035065,0.369629,0.245605,0.229370,0.522949,0.848145,0.458740,0.150635,-0.423096,0.501465,0.227295,0.607422,-0.039581,-0.427490,-0.255859,-0.123535,-0.170044,0.012985,0.593750,-0.204102,-0.582520,0.712891,-0.234009,0.781738,-0.579590,0.753906,0.154297,...,-0.228638,-0.085999,0.122498,0.687988,0.342773,-0.433105,-0.331299,-0.442871,-0.267578,-0.053253,0.038696,0.188599,0.171021,0.041107,-0.028442,1.208008,0.260254,0.818848,0.064819,0.040680,0.153931,-0.506348,-0.083496,0.056366,-5.609375,-0.537598,-0.261475,-0.466309,0.448730,-0.231323,-0.300049,-0.753906,-0.207153,-0.024521,0.374512,0.519043,0.003078,-0.461670,0.013908,-0.375977


In [128]:
model = SVC(max_iter=4000, kernel='rbf', C=5)

In [129]:
model.fit(X_train_final, y_train)

SVC(C=5, break_ties=False, cache_size=200, class_weight=None, coef0=0.0,
    decision_function_shape='ovr', degree=3, gamma='scale', kernel='rbf',
    max_iter=4000, probability=False, random_state=None, shrinking=True,
    tol=0.001, verbose=False)

In [130]:
pickle.dump(model, gzip.open(f'{path}{"final_lyrics_model.pkl.gz"}', 'wb'))

In [131]:
# load the test set
X_test = pd.read_pickle(f'{path}X_test.pkl.gz')
y_test = pd.read_pickle(f'{path}y_test.pkl.gz')

print(X_test.shape, y_test.shape)

(3600, 771) (3600,)


In [132]:
X_test_final = X_test.drop(text_cols, axis=1)
X_test_final

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,...,728,729,730,731,732,733,734,735,736,737,738,739,740,741,742,743,744,745,746,747,748,749,750,751,752,753,754,755,756,757,758,759,760,761,762,763,764,765,766,767
89956,-0.244995,-0.127075,0.097412,-0.187256,-0.259521,-0.194092,0.568359,0.562012,-0.143677,-0.538086,-0.088745,0.024734,-0.390625,-0.017960,0.118286,0.623535,0.367432,0.048218,0.535645,0.066223,-0.355225,0.094849,0.306641,-0.430664,-0.182251,0.109436,-0.612305,-0.441895,-0.364258,-0.038361,-0.676270,0.712891,-0.276611,-0.604492,0.422852,0.084900,0.187744,-0.110535,0.458008,-0.352051,...,-0.166016,-0.221558,-0.136230,0.709473,-0.186157,-0.458496,0.064331,-0.041840,-0.337891,-0.149536,-0.403076,0.602539,0.272461,-0.464355,0.050476,1.202148,0.396729,0.179443,-0.330078,-0.309570,-0.083679,-0.282959,-0.058685,0.698242,-4.859375,-0.615723,0.002083,-0.978516,-0.359619,0.036774,0.078186,-0.501465,-0.123047,-0.450195,0.400635,0.800293,-0.675293,0.069397,0.398193,0.586426
33634,-0.034454,-0.149658,0.664062,-0.213867,-0.482178,-0.253418,0.167114,0.550293,-0.196533,-0.573242,0.078186,-0.334961,-0.467529,-0.104736,0.178833,0.468750,0.051727,0.114929,0.150391,0.172241,-0.564941,-0.071411,0.366699,-0.521484,0.185425,0.025467,-0.388428,-0.493408,-0.468994,0.368408,-0.503906,0.593750,-0.203247,-0.558594,0.223633,-0.498291,0.469727,0.020569,0.582031,-0.208008,...,-0.717285,-0.250977,-0.345947,0.979492,0.085693,-0.156616,0.327393,0.002342,-0.353760,-0.480469,-0.537598,0.524414,-0.000020,-0.166382,-0.060822,0.953613,-0.458984,0.691895,-0.375977,-0.637695,0.044434,-0.244629,-0.293457,0.191528,-2.894531,-0.087585,-0.099121,-0.363770,-0.378418,0.028946,-0.312744,-0.751465,0.074219,-0.265137,-0.074219,0.669434,-0.548828,-0.318848,0.070435,0.294678
16861,0.141235,-0.208374,0.282715,-0.068359,-0.137085,-0.065552,0.525391,0.498535,-0.515137,-0.383545,0.304688,0.435791,-0.042938,0.311279,0.381348,0.785156,0.101868,-0.090088,0.291992,-0.042633,-0.469971,-0.387451,0.134399,-0.456787,-0.262207,-0.133301,-0.509766,-0.301758,-0.158203,0.171143,-0.478516,0.063721,-0.177734,-0.388916,0.388672,-0.337402,0.508301,0.145996,0.368652,-0.479980,...,-0.190186,-0.176025,0.038330,0.946777,-0.001671,0.091003,0.121887,-0.061127,-0.186035,-0.399170,-0.108215,0.005074,0.155151,-0.246460,-0.373291,0.813477,-0.269775,0.663574,-0.240601,-0.087708,0.295654,-0.014618,-0.708008,0.238647,-5.714844,-0.000671,0.015602,-0.653809,-0.694824,0.268555,0.154785,-0.604492,-0.254639,-0.252930,0.430420,0.540527,-0.599121,-0.480957,0.131470,0.092529
92314,0.345215,0.182495,0.065552,-0.339355,-0.525391,0.010788,0.584473,0.826660,-0.712402,-0.538574,0.109009,-0.178223,-0.596680,0.368896,0.419922,0.850586,-0.028397,0.223633,0.150757,0.229614,-0.493652,0.366455,0.123840,-0.393311,-0.048401,-0.068359,-0.155518,-0.427246,-0.048859,-0.028427,-0.234131,0.716309,-0.188354,-0.593750,0.766113,-0.404541,0.144043,-0.005875,0.049347,-0.344238,...,-0.024292,-0.551758,-0.195801,0.839844,-0.116760,-0.339600,0.472656,0.344971,-0.448730,-0.785156,-0.571289,0.630371,-0.128784,-0.294189,0.050598,0.420654,0.197998,0.874023,-0.305908,-0.744629,0.103455,-0.208618,0.011574,0.280518,-5.167969,0.191895,0.194702,-0.813477,-0.541016,0.362549,-0.051819,-0.761719,-0.501953,-0.466064,0.818848,0.695801,-0.710449,-0.107666,0.366455,0.580566
95289,-0.081360,0.009384,0.453613,-0.090454,-0.411133,0.203369,0.451904,0.871094,-0.225098,-1.244141,-0.017792,-0.135620,-0.295898,0.045471,0.503418,0.187866,0.167114,0.519043,0.439453,0.208984,-0.469971,0.458740,0.253174,-0.220825,-0.047302,-0.024475,-0.172729,-0.107727,-0.244385,0.267822,-0.441895,0.378906,-0.059235,-0.762207,0.257568,-0.119385,-0.118591,0.007145,0.069214,-0.089111,...,-0.675781,-0.275391,-0.181641,1.320312,-0.018234,-0.672852,0.470459,0.563965,-0.422363,-0.788086,-0.146118,0.209106,0.091858,-0.143799,0.082275,0.804199,-0.213135,0.225952,-0.654297,-0.669434,0.318359,0.105225,-0.309814,0.249146,-3.816406,0.048828,-0.147217,-0.587402,-0.786133,0.136353,-0.079285,-0.294922,-0.215332,-0.411865,0.302002,0.394043,-0.223267,-0.057770,0.276611,0.663574
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
118197,-0.474121,0.379395,-0.052002,0.240479,-0.168701,0.211792,0.625488,0.508789,-0.714844,-0.219116,-0.187134,-0.291992,-0.764160,0.201904,-0.070068,0.383057,-0.073120,-0.262939,0.510254,0.208496,0.022400,-0.293457,0.044281,0.089294,0.389648,-0.167725,-0.342041,-0.090637,-0.073059,0.273926,0.147949,0.629395,0.197144,-0.213257,0.491211,-0.441650,0.452148,-0.509766,0.857910,0.165405,...,0.086792,-0.148926,-0.284180,0.072815,-0.329834,-0.763184,-0.062195,0.254150,0.365967,0.147461,0.222656,0.542480,0.253418,-0.112244,-0.307373,0.500977,0.200317,0.323730,-0.289551,-0.300781,0.059021,0.241455,0.907715,0.074585,-5.378906,-0.533203,0.071472,-0.314453,0.080017,-0.415283,0.605469,-0.621094,-0.462158,-0.062866,0.413330,0.089294,-0.097717,-0.487793,0.321045,0.369385
152448,-0.527832,0.535156,0.091675,-0.161011,-0.460449,-0.021591,0.848633,0.563477,-0.372070,-0.105286,-0.095764,0.054474,-0.324219,0.648926,0.033295,0.210815,-0.049896,-0.031433,0.483154,-0.362793,0.007835,-0.541016,0.053436,0.031082,0.171387,-0.033051,-0.073730,-0.099915,0.159912,0.178223,-0.069092,0.056763,0.104797,0.236450,0.194702,0.152344,0.492188,-0.337891,0.555664,0.259766,...,0.077332,0.041779,0.049347,0.058441,-0.298340,-0.380615,0.026901,0.023392,0.129761,0.099609,0.035065,0.656250,0.517090,0.205078,-0.088257,0.421875,0.392822,-0.019638,-0.115906,-0.399658,0.188965,-0.165771,0.704590,0.141235,-6.832031,-0.245483,0.236328,-0.343506,0.346924,-0.240601,0.512207,-0.675781,-0.113892,0.237061,-0.147217,0.264648,-0.349609,-0.316895,0.336182,0.521973
66130,-0.265137,-0.281738,0.449219,-0.094727,-0.453857,-0.122620,0.360352,0.526367,0.146240,-0.671875,0.157715,-0.103882,-0.191528,-0.009659,0.166870,0.308105,0.552734,0.059296,-0.107422,0.063416,0.045807,-0.037750,0.403564,-0.136108,0.140381,0.232910,-0.180420,-0.308838,-0.250244,-0.050720,-0.196167,0.431641,-0.350586,-0.322266,0.322998,-0.338623,0.396729,0.231445,0.126099,-0.139771,...,-0.845703,-0.590820,-0.225342,0.878418,0.101624,-0.260498,-0.054779,-0.053162,-0.373779,-0.307617,-0.412109,0.481445,0.344727,-0.326904,-0.116760,0.778320,-0.437012,0.377441,-0.506836,-0.149658,-0.125610,-0.174683,-0.139282,0.496094,-1.553711,-0.479248,0.038605,-0.571777,-0.587891,-0.391113,-0.164673,-0.479248,0.071411,-0.120239,-0.066345,0.391846,-0.296387,-0.319336,0.227051,0.458984
104573,-0.182373,-0.352783,0.146484,-0.426514,-0.583984,0.259277,0.880859,0.958008,-0.179932,-0.843750,0.117981,-0.204468,-0.210449,0.616699,0.566895,0.610352,0.577637,0.606934,0.336426,0.035370,-0.222778,0.224609,0.296143,-0.053894,-0.168579,-0.194458,-0.437256,-0.220581,-0.007778,0.259766,-0.567383,0.785156,-0.199097,-0.494629,0.778320,-0.018723,0.214111,-0.013786,0.110840,-0.068542,...,-0.276611,-0.428955,0.088135,0.812500,-0.210449,-0.693359,0.272461,0.509766,-0.556641,-0.514648,-0.452393,1.005859,0.266846,-0.367188,-0.049927,1.100586,-0.688477,0.698242,-0.288818,-0.405273,0.583496,-0.598145,0.056702,0.674316,-2.144531,-0.521973,-0.213623,-1.016602,-0.419922,-0.104858,0.082947,-0.488770,-0.828125,-0.507324,0.489746,0.223267,-0.271973,-0.354004,0.679199,0.349854


In [133]:
test_predictions = model.predict(X_test_final)

In [134]:
test_predictions.shape

(3600,)

In [135]:
y_test.shape

(3600,)

In [136]:
acc = accuracy_score(y_test, test_predictions)
print(f"Accuracy: {round(acc,2)}")

Accuracy: 0.95


In [137]:
print_confustion_matrix(model, y_test, test_predictions)

         Hip Hop  Metal  Rock
Hip Hop     1063      0   137
Metal          0   1200     0
Rock          57      1  1142


In [138]:
print(classification_report(y_test, test_predictions, target_names=genres))

              precision    recall  f1-score   support

     Hip Hop       0.95      0.89      0.92      1200
        Rock       1.00      1.00      1.00      1200
       Metal       0.89      0.95      0.92      1200

    accuracy                           0.95      3600
   macro avg       0.95      0.95      0.95      3600
weighted avg       0.95      0.95      0.95      3600

