In [1]:
import pandas as pd
import numpy as np
import os
import pickle
import json

from sklearn.metrics import confusion_matrix
from sklearn.model_selection import train_test_split
from sklearn.model_selection import cross_val_score
from sklearn.metrics import classification_report
from sklearn.model_selection import RepeatedStratifiedKFold
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import GridSearchCV
from google.colab import drive

##Configurazione

In [2]:
BASE_PATH ='./drive/MyDrive/Visione e Percezione/'

##Importazione Dataset

In [4]:
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [28]:
dataset = pd.DataFrame()
dataset_path = BASE_PATH + 'dataset/'
for file_name in os.listdir(dataset_path):
    file_path = dataset_path + file_name
    dataset_by_sign = pd.read_csv(file_path)
    dataset = dataset.append(dataset_by_sign)

In [29]:
dataset = dataset.sample(frac=1)
X = dataset.iloc[:, 0:32]
y = dataset.iloc[:, 32]
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)

##Training

###Tuning

In [None]:
clf = LogisticRegression()
solvers = ['newton-cg', 'lbfgs', 'liblinear']
penalty = ['l2']
c_values = [100, 10, 1.0, 0.1, 0.01]
max_iter = [100, 200, 500, 1000]

grid = dict(solver=solvers, penalty=penalty, C=c_values, max_iter=max_iter)
cv = RepeatedStratifiedKFold(n_splits=10, n_repeats=3, random_state=1)
grid_search = GridSearchCV(estimator=clf, param_grid=grid, cv=cv, scoring='accuracy', error_score=0, return_train_score=True)
grid_result = grid_search.fit(X_train, y_train)

In [None]:
print(f"Miglior risultato: {grid_result.best_score_} con {grid_result.best_params_}")

Miglior risultato: 0.9368055555555554 con {'C': 100, 'max_iter': 100, 'penalty': 'l2', 'solver': 'newton-cg'}


BEST PARAMS: {'C': 100, 'max_iter': 100, 'penalty': 'l2', 'solver': 'newton-cg'}

###Modello

In [30]:
grid_search = LogisticRegression(C=100, max_iter=100, penalty='l2', solver='newton-cg')
grid_search.fit(X_train, y_train)

LogisticRegression(C=100, class_weight=None, dual=False, fit_intercept=True,
                   intercept_scaling=1, l1_ratio=None, max_iter=100,
                   multi_class='auto', n_jobs=None, penalty='l2',
                   random_state=None, solver='newton-cg', tol=0.0001, verbose=0,
                   warm_start=False)

##Training & CV

In [None]:
grid_result.cv_results_.pop('params', None)
pd.DataFrame(grid_result.cv_results_)

Unnamed: 0,mean_fit_time,std_fit_time,mean_score_time,std_score_time,param_C,param_max_iter,param_penalty,param_solver,split0_test_score,split1_test_score,split2_test_score,split3_test_score,split4_test_score,split5_test_score,split6_test_score,split7_test_score,split8_test_score,split9_test_score,split10_test_score,split11_test_score,split12_test_score,split13_test_score,split14_test_score,split15_test_score,split16_test_score,split17_test_score,split18_test_score,split19_test_score,split20_test_score,split21_test_score,split22_test_score,split23_test_score,split24_test_score,split25_test_score,split26_test_score,split27_test_score,split28_test_score,split29_test_score,mean_test_score,std_test_score,rank_test_score,split0_train_score,split1_train_score,split2_train_score,split3_train_score,split4_train_score,split5_train_score,split6_train_score,split7_train_score,split8_train_score,split9_train_score,split10_train_score,split11_train_score,split12_train_score,split13_train_score,split14_train_score,split15_train_score,split16_train_score,split17_train_score,split18_train_score,split19_train_score,split20_train_score,split21_train_score,split22_train_score,split23_train_score,split24_train_score,split25_train_score,split26_train_score,split27_train_score,split28_train_score,split29_train_score,mean_train_score,std_train_score
0,4.942288,0.267024,0.002206,9.3e-05,100.0,100,l2,newton-cg,0.940625,0.926042,0.946875,0.930208,0.934375,0.939583,0.933333,0.947917,0.932292,0.933333,0.939583,0.939583,0.928125,0.940625,0.946875,0.946875,0.914583,0.952083,0.932292,0.927083,0.945833,0.933333,0.946875,0.939583,0.927083,0.934375,0.938542,0.935417,0.928125,0.942708,0.936806,0.008217,1,0.940625,0.941898,0.939815,0.940625,0.940509,0.940741,0.941782,0.939699,0.940972,0.941667,0.942014,0.939815,0.941551,0.940625,0.939815,0.939236,0.942708,0.939005,0.941319,0.942477,0.939236,0.940972,0.940741,0.939931,0.940625,0.94213,0.939931,0.940972,0.941551,0.941435,0.940814,0.000974
1,5.000543,0.360942,0.002568,0.00134,100.0,200,l2,newton-cg,0.940625,0.926042,0.946875,0.930208,0.934375,0.939583,0.933333,0.947917,0.932292,0.933333,0.939583,0.939583,0.928125,0.940625,0.946875,0.946875,0.914583,0.952083,0.932292,0.927083,0.945833,0.933333,0.946875,0.939583,0.927083,0.934375,0.938542,0.935417,0.928125,0.942708,0.936806,0.008217,1,0.940625,0.941898,0.939815,0.940625,0.940509,0.940741,0.941782,0.939699,0.940972,0.941667,0.942014,0.939815,0.941551,0.940625,0.939815,0.939236,0.942708,0.939005,0.941319,0.942477,0.939236,0.940972,0.940741,0.939931,0.940625,0.94213,0.939931,0.940972,0.941551,0.941435,0.940814,0.000974
2,4.893547,0.253855,0.002414,0.001195,100.0,500,l2,newton-cg,0.940625,0.926042,0.946875,0.930208,0.934375,0.939583,0.933333,0.947917,0.932292,0.933333,0.939583,0.939583,0.928125,0.940625,0.946875,0.946875,0.914583,0.952083,0.932292,0.927083,0.945833,0.933333,0.946875,0.939583,0.927083,0.934375,0.938542,0.935417,0.928125,0.942708,0.936806,0.008217,1,0.940625,0.941898,0.939815,0.940625,0.940509,0.940741,0.941782,0.939699,0.940972,0.941667,0.942014,0.939815,0.941551,0.940625,0.939815,0.939236,0.942708,0.939005,0.941319,0.942477,0.939236,0.940972,0.940741,0.939931,0.940625,0.94213,0.939931,0.940972,0.941551,0.941435,0.940814,0.000974
3,4.92645,0.273498,0.002208,0.0001,100.0,1000,l2,newton-cg,0.940625,0.926042,0.946875,0.930208,0.934375,0.939583,0.933333,0.947917,0.932292,0.933333,0.939583,0.939583,0.928125,0.940625,0.946875,0.946875,0.914583,0.952083,0.932292,0.927083,0.945833,0.933333,0.946875,0.939583,0.927083,0.934375,0.938542,0.935417,0.928125,0.942708,0.936806,0.008217,1,0.940625,0.941898,0.939815,0.940625,0.940509,0.940741,0.941782,0.939699,0.940972,0.941667,0.942014,0.939815,0.941551,0.940625,0.939815,0.939236,0.942708,0.939005,0.941319,0.942477,0.939236,0.940972,0.940741,0.939931,0.940625,0.94213,0.939931,0.940972,0.941551,0.941435,0.940814,0.000974
4,3.199071,0.135377,0.002204,8e-05,10.0,100,l2,newton-cg,0.894792,0.890625,0.90625,0.890625,0.892708,0.905208,0.895833,0.902083,0.89375,0.902083,0.904167,0.904167,0.892708,0.89375,0.901042,0.901042,0.879167,0.908333,0.889583,0.8875,0.90625,0.890625,0.908333,0.898958,0.8875,0.895833,0.89375,0.897917,0.886458,0.9,0.896701,0.007202,5,0.899074,0.898727,0.89838,0.899421,0.900463,0.899074,0.898958,0.897685,0.899306,0.899421,0.899653,0.897569,0.900231,0.899306,0.898495,0.898495,0.901273,0.896875,0.900694,0.900694,0.898032,0.899884,0.897917,0.899884,0.899769,0.899653,0.898495,0.899074,0.899769,0.898727,0.899167,0.000993
5,3.205614,0.141606,0.002181,7.7e-05,10.0,200,l2,newton-cg,0.894792,0.890625,0.90625,0.890625,0.892708,0.905208,0.895833,0.902083,0.89375,0.902083,0.904167,0.904167,0.892708,0.89375,0.901042,0.901042,0.879167,0.908333,0.889583,0.8875,0.90625,0.890625,0.908333,0.898958,0.8875,0.895833,0.89375,0.897917,0.886458,0.9,0.896701,0.007202,5,0.899074,0.898727,0.89838,0.899421,0.900463,0.899074,0.898958,0.897685,0.899306,0.899421,0.899653,0.897569,0.900231,0.899306,0.898495,0.898495,0.901273,0.896875,0.900694,0.900694,0.898032,0.899884,0.897917,0.899884,0.899769,0.899653,0.898495,0.899074,0.899769,0.898727,0.899167,0.000993
6,3.361418,0.458434,0.002397,0.000853,10.0,500,l2,newton-cg,0.894792,0.890625,0.90625,0.890625,0.892708,0.905208,0.895833,0.902083,0.89375,0.902083,0.904167,0.904167,0.892708,0.89375,0.901042,0.901042,0.879167,0.908333,0.889583,0.8875,0.90625,0.890625,0.908333,0.898958,0.8875,0.895833,0.89375,0.897917,0.886458,0.9,0.896701,0.007202,5,0.899074,0.898727,0.89838,0.899421,0.900463,0.899074,0.898958,0.897685,0.899306,0.899421,0.899653,0.897569,0.900231,0.899306,0.898495,0.898495,0.901273,0.896875,0.900694,0.900694,0.898032,0.899884,0.897917,0.899884,0.899769,0.899653,0.898495,0.899074,0.899769,0.898727,0.899167,0.000993
7,3.170903,0.127627,0.002218,0.000254,10.0,1000,l2,newton-cg,0.894792,0.890625,0.90625,0.890625,0.892708,0.905208,0.895833,0.902083,0.89375,0.902083,0.904167,0.904167,0.892708,0.89375,0.901042,0.901042,0.879167,0.908333,0.889583,0.8875,0.90625,0.890625,0.908333,0.898958,0.8875,0.895833,0.89375,0.897917,0.886458,0.9,0.896701,0.007202,5,0.899074,0.898727,0.89838,0.899421,0.900463,0.899074,0.898958,0.897685,0.899306,0.899421,0.899653,0.897569,0.900231,0.899306,0.898495,0.898495,0.901273,0.896875,0.900694,0.900694,0.898032,0.899884,0.897917,0.899884,0.899769,0.899653,0.898495,0.899074,0.899769,0.898727,0.899167,0.000993
8,2.001298,0.11592,0.002175,7.9e-05,1.0,100,l2,newton-cg,0.80625,0.813542,0.827083,0.811458,0.825,0.811458,0.810417,0.809375,0.8125,0.811458,0.835417,0.811458,0.810417,0.817708,0.821875,0.8125,0.810417,0.804167,0.802083,0.811458,0.826042,0.811458,0.830208,0.819792,0.805208,0.808333,0.808333,0.825,0.807292,0.816667,0.814479,0.008119,9,0.817014,0.816667,0.816319,0.818403,0.816667,0.815972,0.816204,0.815741,0.817593,0.818287,0.816782,0.815046,0.817824,0.819329,0.813889,0.815394,0.816088,0.817708,0.819213,0.817593,0.815856,0.816088,0.817477,0.817593,0.818519,0.817477,0.817361,0.816667,0.81713,0.816551,0.816948,0.001189
9,2.019764,0.106657,0.002221,0.000205,1.0,200,l2,newton-cg,0.80625,0.813542,0.827083,0.811458,0.825,0.811458,0.810417,0.809375,0.8125,0.811458,0.835417,0.811458,0.810417,0.817708,0.821875,0.8125,0.810417,0.804167,0.802083,0.811458,0.826042,0.811458,0.830208,0.819792,0.805208,0.808333,0.808333,0.825,0.807292,0.816667,0.814479,0.008119,9,0.817014,0.816667,0.816319,0.818403,0.816667,0.815972,0.816204,0.815741,0.817593,0.818287,0.816782,0.815046,0.817824,0.819329,0.813889,0.815394,0.816088,0.817708,0.819213,0.817593,0.815856,0.816088,0.817477,0.817593,0.818519,0.817477,0.817361,0.816667,0.81713,0.816551,0.816948,0.001189


##Testing

In [24]:
map_labels = json.load(open(BASE_PATH + '/json/alphabet_mapping.json', 'rb'))

In [31]:
y_pred = grid_search.predict(X_test)
labels = np.unique(y_test)
conf_mat = confusion_matrix(y_test, y_pred, labels=labels)
labels = list(map(lambda x: map_labels[str(x)], labels))
pd.DataFrame(conf_mat, index=labels, columns=labels)

Unnamed: 0,a,b,c,d,e,f,g,h,i,k,l,m,n,o,p,q,r,s,t,u,v,w,x,y,Unnamed: 25
a,87,0,1,1,0,0,0,0,0,0,0,0,0,3,0,1,0,1,0,0,0,0,0,1,0
b,0,88,0,0,0,0,0,0,0,0,0,1,0,0,0,0,1,0,0,1,0,0,0,0,0
c,0,0,98,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
d,0,0,0,99,0,0,0,0,1,1,0,0,0,3,0,0,0,0,0,0,0,0,1,0,0
e,0,0,0,0,107,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
f,0,1,0,1,0,103,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0
g,0,0,0,0,0,0,93,5,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
h,1,0,0,0,0,0,1,67,0,0,0,0,0,1,11,3,0,0,0,0,0,0,0,1,0
i,0,0,0,0,0,1,0,0,103,0,0,0,0,5,0,0,0,0,0,0,0,0,0,1,0
k,0,2,0,1,0,0,0,0,1,73,0,0,0,0,1,0,1,0,0,3,0,0,3,0,0


In [32]:
print(classification_report(y_test, y_pred))

              precision    recall  f1-score   support

           0       0.90      0.92      0.91        95
           1       0.92      0.97      0.94        91
           2       0.98      0.99      0.98        99
           3       0.93      0.94      0.93       105
           4       0.99      1.00      1.00       107
           5       0.99      0.97      0.98       106
           6       0.90      0.95      0.93        98
           7       0.73      0.79      0.76        85
           8       0.96      0.94      0.95       110
           9       0.94      0.86      0.90        85
          10       1.00      1.00      1.00        98
          11       0.91      1.00      0.95       107
          12       0.94      0.87      0.90        98
          13       0.85      0.89      0.87        91
          14       0.84      0.78      0.81       105
          15       0.84      0.61      0.70        94
          16       0.93      0.96      0.95        99
          17       0.95    

##Salvataggio

In [34]:
pickle.dump(grid_search, open(BASE_PATH + '/model/model.sav', 'wb'))