# Implémentation XGBoost

In [1]:
import feather
import xgboost as xgb
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error
import pandas as pd
import category_encoders as ce
import numpy as np



pd.options.display.max_columns = None
path_train = './input/train_final.feather'
path_test = './input/test_final.feather'

# Import des dataframes

In [2]:
df_train = feather.read_dataframe(path_train)
df_test = feather.read_dataframe(path_test)

In [3]:
df_train.columns

Index(['Artist', 'Track', 'User', 'Time', 'Rating', 'GENDER', 'AGE', 'WORKING',
       'REGION', 'MUSIC',
       ...
       'Intrusive', 'Unoriginal', 'Dated', 'Iconic', 'Unapproachable',
       'Classic', 'Playful', 'Arrogant', 'Warm', 'Soulful'],
      dtype='object', length=116)

# Split des dataframes

Nous choisissons empiriquement de diviser le jeu d'entrainement en deux parties pour pouvoir : 
* Réaliser une procédure de Cross validation
* Entraîner et faire du tuning sur notre modèle en local avant de le soumettre à Kaggle

Nous sacrifions une partie des données d'entraînement dans ce but, si le résutat n'est pas concluant, nous pourrons toujours revenir au dataframe original avaec les hyperparamètres correctement optimisés.

In [4]:
y_train = df_train.pop('Rating')

In [5]:
X_train = df_train

In [6]:
y_test = df_test.pop('Rating')

In [7]:
X_test = df_test

# Encoding des variables catégorielles

Nous commençons par encoder nos variables catégorielles avec un CatboostEncoder ( encoder inspiré du framework Catboost qui gère les valeurs catégorielles naturellement)

In [8]:
def encode_with_catboost(X_train, y_train, X_test):
    cat_cols = list(X_train.select_dtypes(include=['object']).columns)
    encoder = ce.CatBoostEncoder(verbose = 3, cols = cat_cols, return_df=True)
    encoder.fit(X_train, y_train)
    X_train_cat = encoder.transform(X_train)
    X_test_cat = encoder.transform(X_test)
    return X_train_cat, X_test_cat
    

In [9]:
X_train, X_test = encode_with_catboost(X_train, y_train, X_test)

In [10]:
X_train.head(10)

Unnamed: 0,Artist,Track,User,Time,GENDER,AGE,WORKING,REGION,MUSIC,LIST_OWN,LIST_BACK,Q1,Q2,Q3,Q4,Q5,Q6,Q7,Q8,Q9,Q10,Q11,Q12,Q13,Q14,Q15,Q16,Q17,Q18,Q19,HEARD_OF,OWN_ARTIST_MUSIC,LIKE_ARTIST,Uninspired,Sophisticated,Aggressive,Edgy,Sociable,Laid back,Wholesome,Uplifting,Intriguing,Legendary,Free,Thoughtful,Outspoken,Serious,Good lyrics,Unattractive,Confident,Old,Youthful,Boring,Current,Colourful,Stylish,Cheap,Irrelevant,Heartfelt,Calm,Pioneer,Outgoing,Inspiring,Beautiful,Fun,Authentic,Credible,Way out,Cool,Catchy,Sensitive,Mainstream,Superficial,Annoying,Dark,Passionate,Not authentic,Good Lyrics,Background,Timeless,Depressing,Original,Talented,Worldly,Distinctive,Approachable,Genius,Trendsetter,Noisy,Upbeat,Relatable,Energetic,Exciting,Emotional,Nostalgic,None of these,Progressive,Sexy,Over,Rebellious,Fake,Cheesy,Popular,Superstar,Relaxed,Intrusive,Unoriginal,Dated,Iconic,Unapproachable,Classic,Playful,Arrogant,Warm,Soulful
0,33,85,34406,12,36.974113,58.0,36.431736,35.203279,38.278503,0.0,1.0,52.0,67.0,68.0,11.0,51.0,9.0,52.0,49.0,31.0,73.0,96.0,72.0,73.0,96.0,85.0,73.0,81.0,51.0,70.0,35.605901,32.43835,49.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,1.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
1,2,174,47314,17,35.928711,80.0,32.595883,36.409322,33.295282,1.0,2.0,53.0,28.0,25.0,48.0,69.0,50.0,33.0,34.0,35.0,58.0,34.0,35.0,48.0,49.0,8.0,28.0,29.0,50.0,32.0,31.094945,32.43835,49.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
2,40,177,50440,17,35.928711,18.0,35.881433,36.409322,39.129741,16.0,0.0,46.0,60.0,85.0,60.0,79.0,61.0,45.0,72.0,72.0,57.0,72.0,46.0,60.0,65.0,52.0,71.0,78.0,51.0,68.0,35.605901,32.43835,49.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,1.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0
3,33,85,25762,11,36.974113,20.0,36.431736,35.203279,39.129741,0.0,1.0,54.0,54.0,52.0,53.0,53.0,54.0,54.0,54.0,53.0,52.0,53.0,56.0,57.0,56.0,56.0,56.0,56.0,56.0,57.0,35.605901,32.43835,49.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
4,21,48,22720,22,36.623893,39.0,39.107005,37.936712,36.623893,1.0,2.0,51.0,53.0,52.0,34.0,32.0,35.0,30.0,23.0,50.0,53.0,64.0,53.0,50.0,53.0,36.0,32.0,56.0,47.0,44.0,31.094945,32.43835,49.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
5,13,31,11172,19,36.974113,20.0,35.91286,35.203279,39.129741,3.0,1.0,50.0,51.0,51.0,51.0,51.0,51.0,7.0,11.0,52.0,92.0,55.0,55.0,32.0,34.0,9.0,49.0,50.0,51.0,50.0,35.605901,32.43835,49.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
6,0,2,1461,6,36.974113,38.0,39.107005,36.409153,38.278503,1.0,1.0,28.0,43.0,51.0,51.0,35.0,43.0,22.0,17.0,43.0,41.0,70.0,72.0,49.0,52.0,23.0,8.0,16.0,47.0,44.0,31.094945,32.43835,49.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0
7,14,95,31563,23,35.928711,33.0,36.431736,35.203279,33.491819,3.0,7.0,100.0,100.0,100.0,1.0,1.0,2.0,1.0,1.0,1.0,90.0,27.0,8.0,100.0,79.0,66.0,32.0,44.0,75.0,97.0,35.605901,32.43835,49.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
8,20,46,18011,21,35.928711,63.0,31.605633,36.409322,33.295282,0.0,2.0,19.0,59.0,18.0,16.0,77.0,41.0,9.0,46.0,31.0,19.0,65.0,58.0,33.0,30.0,24.0,23.0,22.0,20.0,21.0,31.094945,32.43835,49.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
9,40,147,36852,13,36.974113,57.0,34.429171,35.203279,33.295282,0.0,3.0,49.0,45.0,42.0,49.0,41.0,52.0,50.0,49.0,51.0,53.0,53.0,51.0,44.0,52.0,54.0,52.0,45.0,49.0,52.0,31.094945,32.43835,49.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0


In [11]:
X_test.head(10)

Unnamed: 0,Artist,Track,User,Time,GENDER,AGE,WORKING,REGION,MUSIC,LIST_OWN,LIST_BACK,Q1,Q2,Q3,Q4,Q5,Q6,Q7,Q8,Q9,Q10,Q11,Q12,Q13,Q14,Q15,Q16,Q17,Q18,Q19,HEARD_OF,OWN_ARTIST_MUSIC,LIKE_ARTIST,Uninspired,Sophisticated,Aggressive,Edgy,Sociable,Laid back,Wholesome,Uplifting,Intriguing,Legendary,Free,Thoughtful,Outspoken,Serious,Good lyrics,Unattractive,Confident,Old,Youthful,Boring,Current,Colourful,Stylish,Cheap,Irrelevant,Heartfelt,Calm,Pioneer,Outgoing,Inspiring,Beautiful,Fun,Authentic,Credible,Way out,Cool,Catchy,Sensitive,Mainstream,Superficial,Annoying,Dark,Passionate,Not authentic,Good Lyrics,Background,Timeless,Depressing,Original,Talented,Worldly,Distinctive,Approachable,Genius,Trendsetter,Noisy,Upbeat,Relatable,Energetic,Exciting,Emotional,Nostalgic,None of these,Progressive,Sexy,Over,Rebellious,Fake,Cheesy,Popular,Superstar,Relaxed,Intrusive,Unoriginal,Dated,Iconic,Unapproachable,Classic,Playful,Arrogant,Warm,Soulful
0,25,59,18161,21,35.928711,44.0,35.881433,35.203279,33.295282,0.0,0.0,52.0,53.0,48.0,6.0,7.0,48.0,47.0,47.0,49.0,49.0,50.0,69.0,49.0,49.0,9.0,10.0,71.0,52.0,52.0,31.094945,32.43835,49.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,1.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
1,35,88,28759,23,35.928711,36.0,36.431736,36.409322,33.491819,2.0,0.0,52.0,51.0,67.0,52.0,34.0,53.0,32.0,32.0,53.0,72.0,68.0,68.0,69.0,70.0,70.0,32.0,71.0,52.0,47.0,31.094945,32.43835,49.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
2,21,48,20142,21,36.974113,61.0,32.595883,35.203279,33.295282,0.0,1.0,5.0,53.0,24.0,9.0,5.0,85.0,3.0,4.0,69.0,56.0,79.0,51.0,4.0,4.0,4.0,6.0,7.0,3.0,6.0,31.094945,32.43835,49.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
3,41,155,41201,16,36.974113,14.0,35.91286,36.409322,38.278503,1.0,2.0,54.0,55.0,52.0,57.0,57.0,53.0,55.0,50.0,53.0,51.0,63.0,65.0,52.0,57.0,57.0,58.0,60.0,54.0,54.0,35.605901,32.43835,49.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
4,10,141,34253,12,36.974113,63.0,32.595883,35.203279,33.295282,0.0,2.0,35.0,35.0,51.0,77.0,73.0,31.0,32.0,33.0,98.0,49.0,75.0,75.0,23.0,31.0,12.0,33.0,75.0,34.0,34.0,31.094945,32.43835,49.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
5,43,159,42337,16,36.974113,50.0,34.593639,35.203279,33.295282,0.0,0.0,12.0,30.0,16.0,33.0,32.0,52.0,10.0,8.0,88.0,29.0,33.0,29.0,12.0,14.0,14.0,13.0,73.0,10.0,11.0,31.094945,32.43835,49.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
6,39,103,29329,23,35.928711,43.0,36.431736,36.409153,33.295282,1.0,0.0,61.0,66.0,44.0,51.0,62.0,36.0,5.0,73.0,61.0,27.0,30.0,50.0,62.0,60.0,39.0,32.0,72.0,34.0,69.0,42.215189,36.681174,48.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
7,21,48,22412,22,36.623893,39.0,39.107005,37.936712,36.623893,1.0,2.0,51.0,53.0,51.0,34.0,32.0,35.0,30.0,22.0,50.0,53.0,64.0,53.0,50.0,53.0,36.0,32.0,56.0,47.0,44.0,31.094945,32.43835,49.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
8,1,5,2134,18,35.928711,51.0,36.035632,36.409153,39.129741,16.0,8.0,54.0,100.0,100.0,53.0,27.0,51.0,14.0,10.0,49.0,82.0,85.0,88.0,100.0,100.0,53.0,51.0,95.0,71.0,73.0,35.605901,32.43835,49.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0
9,1,4,4003,18,36.974113,21.0,35.881433,36.409322,38.278503,4.0,2.0,64.0,66.0,69.0,51.0,59.0,27.0,44.0,43.0,30.0,66.0,69.0,78.0,63.0,54.0,65.0,60.0,51.0,60.0,66.0,35.605901,32.43835,49.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,1.0,0.0,0.0,0.0,0.0,1.0,1.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0


# Root Mean Squared Error

In [28]:
def compute_RMSE(y_true, y_pred):
    #la prédiction doit avoir un format entier
    tmp = []
    for el in y_pred:
        tmp.append(round(el))
    RMSE = np.sqrt(np.mean((y_true-tmp)**2))
    print("Root mean squared error: {}".format(RMSE))


# Simple XGBRegressor

In [13]:
#Instanciation
xg_reg = xgb.XGBRegressor(objective ='reg:squarederror',
                          colsample_bytree = 0.3,
                          learning_rate = 0.05,
                          max_depth = 10,
                          n_estimators = 200,
                          random_state=42,
                          n_jobs=8,
                          booster='gbtree')

In [14]:
#Fit sur l'entraînement
xg_reg.fit(X_train,y_train,eval_metric='rmse')

  if getattr(data, 'base', None) is not None and \
  data.base is not None and isinstance(data, np.ndarray) \


XGBRegressor(base_score=0.5, booster='gbtree', colsample_bylevel=1,
             colsample_bynode=1, colsample_bytree=0.3, gamma=0,
             importance_type='gain', learning_rate=0.05, max_delta_step=0,
             max_depth=10, min_child_weight=1, missing=None, n_estimators=200,
             n_jobs=8, nthread=None, objective='reg:squarederror',
             random_state=42, reg_alpha=0, reg_lambda=1, scale_pos_weight=1,
             seed=None, silent=None, subsample=1, verbosity=1)

In [15]:
# Prediction sur le jeu de test
y_pred = xg_reg.predict(X_test)

In [29]:
compute_RMSE(y_test, y_pred)

Root mean squared error: 14.404265183522476


La première version de l'algorithme donne une RMSE sur notre propre jeu de test de 14,40, après divers tests successifs.

# XgBoost optimisé avec GridSearch et Cross Validation - Tuning du modèle

Afin de tester toutes les combinaisons d'hyperparamètres sans avoir à relancer le modèle à chaque fois, nous utilisons une GridSearch. Après divers itérations, cela nous permet de trouver les hyperparamètres optimaux.
Nous rajoutons une procédure de validation croisée (ici sur 2 folds pour diminuer les temps de calculs) afin de lutter contre l'overfitting du modèle.

In [55]:
from sklearn.model_selection  import GridSearchCV
from sklearn.model_selection import KFold
from sklearn.model_selection import cross_val_score

In [56]:
xgb_model = xgb.XGBRegressor()

In [57]:
parameters = {'n_jobs':[8], 
              'objective':['reg:squarederror'],
              'booster' : ['gbtree'],
              'learning_rate': [0.06], 
              'max_depth': [12],
              'subsample': [1],
              'eval_metric ' : ['rmse'],
              'colsample_bytree': [0.5],
              'n_estimators': [300]}

In [58]:
clf = GridSearchCV(xgb_model, parameters, n_jobs=8, 
                   cv=KFold(n_splits=2, shuffle=True), 
                   scoring='neg_mean_squared_error',
                   verbose=3, refit=True)

In [76]:
clf.fit(X_train,y_train, eval_metric='rmse')

Fitting 2 folds for each of 1 candidates, totalling 2 fits


[Parallel(n_jobs=8)]: Using backend LokyBackend with 8 concurrent workers.
[Parallel(n_jobs=8)]: Done   2 out of   2 | elapsed:  2.1min remaining:    0.0s
[Parallel(n_jobs=8)]: Done   2 out of   2 | elapsed:  2.1min finished
  if getattr(data, 'base', None) is not None and \
  data.base is not None and isinstance(data, np.ndarray) \


GridSearchCV(cv=KFold(n_splits=2, random_state=None, shuffle=True),
             error_score='raise-deprecating',
             estimator=XGBRegressor(base_score=0.5, booster='gbtree',
                                    colsample_bylevel=1, colsample_bynode=1,
                                    colsample_bytree=1, gamma=0,
                                    importance_type='gain', learning_rate=0.1,
                                    max_delta_step=0, max_depth=3,
                                    min_child_weight=1, missing=None,
                                    n_estimators=100, n_jobs=1, nthr...
                                    subsample=1, verbosity=1),
             iid='warn', n_jobs=8,
             param_grid={'booster': ['gbtree'], 'colsample_bytree': [0.5],
                         'eval_metric ': ['rmse'], 'learning_rate': [0.06],
                         'max_depth': [12], 'n_estimators': [300],
                         'n_jobs': [8], 'objective': ['reg:squarederro

In [79]:
clf.best_estimator_ 

XGBRegressor(base_score=0.5, booster='gbtree', colsample_bylevel=1,
             colsample_bynode=1, colsample_bytree=0.5, eval_metric ='rmse',
             gamma=0, importance_type='gain', learning_rate=0.06,
             max_delta_step=0, max_depth=12, min_child_weight=1, missing=None,
             n_estimators=300, n_jobs=8, nthread=None,
             objective='reg:squarederror', random_state=0, reg_alpha=0,
             reg_lambda=1, scale_pos_weight=1, seed=None, silent=None,
             subsample=1, verbosity=1)

In [80]:
clf.best_score_ 

-211.60805746318724

In [81]:
clf.best_params_ 

{'booster': 'gbtree',
 'colsample_bytree': 0.5,
 'eval_metric ': 'rmse',
 'learning_rate': 0.06,
 'max_depth': 12,
 'n_estimators': 300,
 'n_jobs': 8,
 'objective': 'reg:squarederror',
 'subsample': 1}

In [82]:
# Prediction sur le jeu de test
y_pred = clf.predict(X_test)

In [84]:
compute_RMSE(y_test, y_pred)

Root mean squared error: 13.8395990199659


Le modèle le plus performant sur la base de la Mean Square Error : 

{'booster': 'gbtree', 'colsample_bytree': 0.5, 'learning_rate': 0.06, 'max_depth': 12, 'n_estimators': 300, 'n_jobs': 8, 'objective': 'reg:squarederror', 'subsample': 1}

MSE sur le training set: 
-211.60805746318724

Root mean squared error sur le test set: 13.8395990199659

Cette valeur correspond à la 15 ème place du challenge Kaggle. Kaggle ne fournissant pas la possibilité d'effectuer les challenges une fois qu'ils sont cloturés,ce que nous n'avions pas prévu, cette mesure reste indicative. De plus, le jeu de données n'étant plus disponible, il nous est impossible de tester sur un dataset de plus grande envergure.

In [75]:
#Export du model et chargement
import pickle
pickle.dump(xg_reg, open("./output/xgboost.dat", "wb"))
#loaded_model = pickle.load(open("./output/xgboost.dat", "rb"))

# Kaggle Challenge

Cette partie n'est plus d'acualité car nous ne pouvons pas tester sur Kaggle. Nous avions fait l'export du résultat sousle bon format.

Nous nous servons du meilleur modèle précédent pour s'entraîner sur le dataset d'entraînement d'origine. Nous testons ensuite le modèle sur le jeu de test et exportons le résultat en CSV.

In [None]:
df_train = feather.read_dataframe(path_train)

In [None]:
df_test = feather.read_dataframe(path_test)

In [None]:
y_train = df_train.pop('Rating')

In [None]:
X_train, X_test = encode_with_catboost(df_train, y_train, df_test)

In [None]:
#Instanciation
xg_reg = xgb.XGBRegressor(objective ='reg:squarederror',
                          colsample_bytree = 0.5,
                          learning_rate = 0.06,
                          max_depth = 12,
                          n_estimators = 300,
                          random_state=42,
                          n_jobs=8,
                          booster='gbtree')

#Fit sur l'entraînement
xg_reg.fit(X_train,y_train, eval_metric='rmse')

# Prediction sur le jeu de test
y_pred = xg_reg.predict(X_test)

In [None]:
y_pred

In [None]:
#Export du model et chargement
import pickle
pickle.dump(xg_reg, open("./output/xgboost.dat", "wb"))
#loaded_model = pickle.load(open("./output/xgboost.dat", "rb"))

# Export du dataframe sous le bon format

In [None]:
test = pd.read_csv('./input/test.csv')


In [None]:
ratings = pd.DataFrame(y_pred, columns=['Rating']) 

In [None]:
pd.concat([test[['Artist', 'Track', 'User']].reset_index(drop=True), ratings.reset_index(drop=True), test[['Time']].reset_index(drop=True)], axis=1)

In [None]:
export_df

In [None]:
pd.to_csv('./output/submit.csv')