In [1]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

In [2]:
from sklearn.model_selection import train_test_split
from sklearn.preprocessing  import StandardScaler
from sklearn.preprocessing import OneHotEncoder
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import confusion_matrix, f1_score, accuracy_score
from sklearn.preprocessing import LabelEncoder

In [4]:
#On importe le Jeu de données
pokemon = pd.read_csv("https://gist.githubusercontent.com/armgilles/194bcff35001e7eb53a2a8b441e8b2c6/raw/92200bc0a673d5ce2110aaad4544ed6c4010f687/pokemon.csv")

In [16]:
#On enlève les colonnes qui ne nous serviront pas / pas de suite
pokemon_cleaned = pokemon.drop(['Type 2','Generation', 'Legendary','#'], axis = 1)
pokemon_cleaned.isna().sum()


Name       0
Type 1     0
Total      0
HP         0
Attack     0
Defense    0
Sp. Atk    0
Sp. Def    0
Speed      0
dtype: int64

In [None]:
#On indexe les noms de pokemon
pokemon_indexed_name = pokemon_cleaned.set_index('Name')

In [26]:
#création des dummies pour le type et on joint les deux tables

dummies_type = pd.get_dummies(pokemon_indexed_name['Type 1'])
pokemon_concat = pd.concat([pokemon_indexed_name,dummies_type], axis = 1)
pokemon_concat_drop = pokemon_concat.drop(['Type 1'], axis = 1)

In [29]:
#fonction pour normaliser les données
scaler = StandardScaler()

In [32]:
#On normalise nos données et on les consignes dans un DataFrame
pokemon_scaled = pd.DataFrame(scaler.fit_transform(pokemon_concat_drop))

In [51]:
#On indexe les noms de pokemon dans le DF normalisé
pokemon_scaled.index = pokemon['Name']
pokemon_scaled

Unnamed: 0_level_0,0,1,2,3,4,5,6,7,8,9,...,15,16,17,18,19,20,21,22,23,24
Name,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
Bulbasaur,-0.976765,-0.950626,-0.924906,-0.797154,-0.239130,-0.248189,-0.801503,-0.307232,-0.200779,-0.204124,...,-0.204124,3.229330,-0.204124,-0.175863,-0.373632,-0.190445,-0.276977,-0.241249,-0.186893,-0.403473
Ivysaur,-0.251088,-0.362822,-0.524130,-0.347917,0.219560,0.291156,-0.285015,-0.307232,-0.200779,-0.204124,...,-0.204124,3.229330,-0.204124,-0.175863,-0.373632,-0.190445,-0.276977,-0.241249,-0.186893,-0.403473
Venusaur,0.749845,0.420917,0.092448,0.293849,0.831146,1.010283,0.403635,-0.307232,-0.200779,-0.204124,...,-0.204124,3.229330,-0.204124,-0.175863,-0.373632,-0.190445,-0.276977,-0.241249,-0.186893,-0.403473
VenusaurMega Venusaur,1.583957,0.420917,0.647369,1.577381,1.503891,1.729409,0.403635,-0.307232,-0.200779,-0.204124,...,-0.204124,3.229330,-0.204124,-0.175863,-0.373632,-0.190445,-0.276977,-0.241249,-0.186893,-0.403473
Charmander,-1.051836,-1.185748,-0.832419,-0.989683,-0.392027,-0.787533,-0.112853,-0.307232,-0.200779,-0.204124,...,-0.204124,-0.309662,-0.204124,-0.175863,-0.373632,-0.190445,-0.276977,-0.241249,-0.186893,-0.403473
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
Diancie,1.375429,-0.754692,0.647369,2.443765,0.831146,2.808099,-0.629341,-0.307232,-0.200779,-0.204124,...,-0.204124,-0.309662,-0.204124,-0.175863,-0.373632,-0.190445,-0.276977,4.145096,-0.186893,-0.403473
DiancieMega Diancie,2.209541,-0.754692,2.497104,1.160233,2.665905,1.369846,1.436611,-0.307232,-0.200779,-0.204124,...,-0.204124,-0.309662,-0.204124,-0.175863,-0.373632,-0.190445,-0.276977,4.145096,-0.186893,-0.403473
HoopaHoopa Confined,1.375429,0.420917,0.955658,-0.444182,2.360112,2.088973,0.059310,-0.307232,-0.200779,-0.204124,...,-0.204124,-0.309662,-0.204124,-0.175863,-0.373632,-0.190445,3.610414,-0.241249,-0.186893,-0.403473
HoopaHoopa Unbound,2.042718,0.420917,2.497104,-0.444182,2.971699,2.088973,0.403635,-0.307232,-0.200779,-0.204124,...,-0.204124,-0.309662,-0.204124,-0.175863,-0.373632,-0.190445,3.610414,-0.241249,-0.186893,-0.403473


In [52]:
#on indexe les noms de pokemon dans la table originale
pokemon.index = pokemon['Name']

In [53]:
#on rajoute la colonne légendaire à notre tableau normalisé
pokemon_scaled_legendary = pd.concat([pokemon_scaled,pokemon['Legendary']], axis = 1)

In [54]:
pokemon_scaled_legendary

Unnamed: 0_level_0,0,1,2,3,4,5,6,7,8,9,...,16,17,18,19,20,21,22,23,24,Legendary
Name,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
Bulbasaur,-0.976765,-0.950626,-0.924906,-0.797154,-0.239130,-0.248189,-0.801503,-0.307232,-0.200779,-0.204124,...,3.229330,-0.204124,-0.175863,-0.373632,-0.190445,-0.276977,-0.241249,-0.186893,-0.403473,False
Ivysaur,-0.251088,-0.362822,-0.524130,-0.347917,0.219560,0.291156,-0.285015,-0.307232,-0.200779,-0.204124,...,3.229330,-0.204124,-0.175863,-0.373632,-0.190445,-0.276977,-0.241249,-0.186893,-0.403473,False
Venusaur,0.749845,0.420917,0.092448,0.293849,0.831146,1.010283,0.403635,-0.307232,-0.200779,-0.204124,...,3.229330,-0.204124,-0.175863,-0.373632,-0.190445,-0.276977,-0.241249,-0.186893,-0.403473,False
VenusaurMega Venusaur,1.583957,0.420917,0.647369,1.577381,1.503891,1.729409,0.403635,-0.307232,-0.200779,-0.204124,...,3.229330,-0.204124,-0.175863,-0.373632,-0.190445,-0.276977,-0.241249,-0.186893,-0.403473,False
Charmander,-1.051836,-1.185748,-0.832419,-0.989683,-0.392027,-0.787533,-0.112853,-0.307232,-0.200779,-0.204124,...,-0.309662,-0.204124,-0.175863,-0.373632,-0.190445,-0.276977,-0.241249,-0.186893,-0.403473,False
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
Diancie,1.375429,-0.754692,0.647369,2.443765,0.831146,2.808099,-0.629341,-0.307232,-0.200779,-0.204124,...,-0.309662,-0.204124,-0.175863,-0.373632,-0.190445,-0.276977,4.145096,-0.186893,-0.403473,True
DiancieMega Diancie,2.209541,-0.754692,2.497104,1.160233,2.665905,1.369846,1.436611,-0.307232,-0.200779,-0.204124,...,-0.309662,-0.204124,-0.175863,-0.373632,-0.190445,-0.276977,4.145096,-0.186893,-0.403473,True
HoopaHoopa Confined,1.375429,0.420917,0.955658,-0.444182,2.360112,2.088973,0.059310,-0.307232,-0.200779,-0.204124,...,-0.309662,-0.204124,-0.175863,-0.373632,-0.190445,3.610414,-0.241249,-0.186893,-0.403473,True
HoopaHoopa Unbound,2.042718,0.420917,2.497104,-0.444182,2.971699,2.088973,0.403635,-0.307232,-0.200779,-0.204124,...,-0.309662,-0.204124,-0.175863,-0.373632,-0.190445,3.610414,-0.241249,-0.186893,-0.403473,True


In [72]:
#On selctionne toutes les lignes de pokémon legendaires
to_drop = pokemon_scaled_legendary[pokemon_scaled_legendary['Legendary']== True]
pokemon_scaled_OU = pokemon_scaled_legendary.drop(to_drop.index)


In [121]:
#On retire la colonne légendaire qui ne nous sert plus
pokemon_scaled_OU = pokemon_scaled_OU.drop(['Legendary'], axis = 1)

In [77]:
#On crée un DataFrame qui contient les pokemons à remplacer
a_trouver = pokemon_scaled.loc[['Mewtwo', 'Lugia', 'Rayquaza', 'GiratinaOrigin Forme', 'Dialga', 'Palkia'], :]
a_trouver

Unnamed: 0_level_0,0,1,2,3,4,5,6,7,8,9,...,15,16,17,18,19,20,21,22,23,24
Name,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
Mewtwo,2.042718,1.439777,0.955658,0.518467,2.482429,0.65072,2.125262,-0.307232,-0.200779,-0.204124,...,-0.204124,-0.309662,-0.204124,-0.175863,-0.373632,-0.190445,3.610414,-0.241249,-0.186893,-0.403473
Lugia,2.042718,1.439777,0.33908,1.801999,0.525353,2.951925,1.436611,-0.307232,-0.200779,-0.204124,...,-0.204124,-0.309662,-0.204124,-0.175863,-0.373632,-0.190445,3.610414,-0.241249,-0.186893,-0.403473
Rayquaza,2.042718,1.40059,2.188815,0.518467,2.360112,0.65072,0.920123,-0.307232,-0.200779,4.898979,...,-0.204124,-0.309662,-0.204124,-0.175863,-0.373632,-0.190445,-0.276977,-0.241249,-0.186893,-0.403473
GiratinaOrigin Forme,2.042718,3.164003,1.263947,0.83935,1.442732,1.010283,0.747961,-0.307232,-0.200779,-0.204124,...,4.898979,-0.309662,-0.204124,-0.175863,-0.373632,-0.190445,-0.276977,-0.241249,-0.186893,-0.403473
Dialga,2.042718,1.204656,1.263947,1.481116,2.360112,1.010283,0.747961,-0.307232,-0.200779,-0.204124,...,-0.204124,-0.309662,-0.204124,-0.175863,-0.373632,-0.190445,-0.276977,-0.241249,5.350666,-0.403473
Palkia,2.042718,0.812786,1.263947,0.83935,2.360112,1.729409,1.092286,-0.307232,-0.200779,-0.204124,...,-0.204124,-0.309662,-0.204124,-0.175863,-0.373632,-0.190445,-0.276977,-0.241249,-0.186893,2.478479


In [150]:
#On créée notre X et notre y

x = np.array(pokemon_scaled_OU)
y = np.array(pokemon_scaled_OU.index).reshape(-1,1)

In [151]:
#premier train test split pour avoir le test set
X_train, X_test, y_train, y_test = train_test_split(x,
                                                    y,test_size = 0.2)

ValueError: Found input variables with inconsistent numbers of samples: [18375, 735]

In [124]:
#Second train test pour avoir le validation set
X_train, X_validation, y_train, y_validation = train_test_split(X_train, y_train, test_size = 0.07)

In [159]:
#Réalisation de notre classifier avec k=sqrt(nombre d'observations) 
from math import floor
knn = KNeighborsClassifier(n_neighbors = (floor(np.sqrt((len(pokemon_scaled_OU.index))))))
knn.fit(X_train,y_train)

  after removing the cwd from sys.path.


KNeighborsClassifier(algorithm='auto', leaf_size=30, metric='minkowski',
                     metric_params=None, n_jobs=None, n_neighbors=27, p=2,
                     weights='uniform')

In [161]:
knn.predict(a_trouver)

array(['AlakazamMega Alakazam', 'AlakazamMega Alakazam', 'Altaria',
       'Banette', 'AegislashBlade Forme', 'Blastoise'], dtype=object)

In [118]:
a_trouver.shape

(6, 25)