In [1]:
import pickle
import numpy as np
from sklearn.neighbors import KNeighborsClassifier
from sklearn.model_selection import train_test_split, cross_val_score
from sklearn.metrics import  classification_report, confusion_matrix
from sklearn.model_selection import GridSearchCV
from sklearn.model_selection import StratifiedShuffleSplit
from sklearn import cluster, metrics

# Load the recorded game log.
# NOTE(review): pickle.load can execute arbitrary code while unpickling; only
# open log files that come from a trusted source.
# The context manager closes the file even if unpickling raises.
with open("../log/ml.pickle", "rb") as file:
    data = pickle.load(file)

# NOTE: this prints the entire loaded object, which can be very large.
print(data)

In [2]:
# Feature extraction: snake-head coordinates and integer-encoded commands.

scene_info = data['scene_info']
scene_command = data['command']

# Integer label for each direction string; every other command is labelled 0.
COMMAND_CODES = {"UP": 1, "LEFT": 2, "DOWN": 3, "RIGHT": 4}

# The first entry and the last two entries of each sequence are skipped.
frames = scene_info[1:-2]
SnakeHead_x = [frame['snake_head'][0] for frame in frames]
SnakeHead_y = [frame['snake_head'][1] for frame in frames]

# The isinstance guard keeps non-string (possibly unhashable) commands on the
# fallback label 0 instead of attempting a dictionary lookup with them.
command = [
    COMMAND_CODES.get(cmd, 0) if isinstance(cmd, str) else 0
    for cmd in scene_command[1:-2]
]

In [3]:
import numpy as np

# Stack the two coordinate lists as rows, then flip so each row is one sample.
numpy_data = np.array([SnakeHead_x, SnakeHead_y])
feature = numpy_data.T
# Labels as a NumPy array (transposing a 1-D sequence leaves its values unchanged).
answer = np.asarray(command)

In [4]:
# Display the feature matrix; each row is one (x, y) pair taken from the snake-head lists.
feature

array([[ 40,  50],
       [ 40,  60],
       [ 40,  70],
       ...,
       [  0, 250],
       [  0, 260],
       [  0, 270]])

In [5]:
# Display the label vector of integer command codes.
answer

array([0, 0, 0, ..., 0, 0, 0])

In [6]:
# Report the dimensions of the feature matrix and the label vector, one per line.
print(feature.shape, answer.shape, sep="\n")

(202639, 2)
(202639,)


In [7]:
# Hold out 30% of the samples as a test set.
x_train, x_test, y_train, y_test = train_test_split(
    feature,
    answer,
    test_size=0.3,
    random_state=9,
)

# Candidate values of k for the nearest-neighbour classifier.
param_grid = {'n_neighbors': [1, 2, 3]}

# Cross-validation: 3 stratified shuffle splits, each validating on 30% of the training set.
cv = StratifiedShuffleSplit(n_splits=3, test_size=0.3, random_state=9)

# n_jobs is the number of jobs run in parallel.
grid = GridSearchCV(
    estimator=KNeighborsClassifier(),
    param_grid=param_grid,
    cv=cv,
    verbose=10,
    n_jobs=6,
)
grid.fit(x_train, y_train)
grid_predictions = grid.predict(x_test)

# Save the fitted grid-search object to disk.
# The context manager closes the file even if pickling raises.
with open('model.pickle', 'wb') as file:
    pickle.dump(grid, file)

Fitting 3 folds for each of 3 candidates, totalling 9 fits


[Parallel(n_jobs=6)]: Using backend LokyBackend with 6 concurrent workers.
[Parallel(n_jobs=6)]: Done   2 out of   9 | elapsed:    7.3s remaining:   25.9s
[Parallel(n_jobs=6)]: Done   3 out of   9 | elapsed:    7.4s remaining:   14.9s
[Parallel(n_jobs=6)]: Done   4 out of   9 | elapsed:    7.4s remaining:    9.3s
[Parallel(n_jobs=6)]: Done   5 out of   9 | elapsed:    7.7s remaining:    6.1s
[Parallel(n_jobs=6)]: Done   6 out of   9 | elapsed:    7.7s remaining:    3.8s
[Parallel(n_jobs=6)]: Done   7 out of   9 | elapsed:   11.9s remaining:    3.3s
[Parallel(n_jobs=6)]: Done   9 out of   9 | elapsed:   12.5s remaining:    0.0s
[Parallel(n_jobs=6)]: Done   9 out of   9 | elapsed:   12.5s finished


In [8]:
# Best hyper-parameters found by the grid search.
print(grid.best_params_)
# Raw predictions; uncomment to inspect.
# print(grid_predictions)
# Confusion matrix on the held-out test set.
print(confusion_matrix(y_true=y_test, y_pred=grid_predictions))
# Per-class precision / recall / F1 report.
print(classification_report(y_true=y_test, y_pred=grid_predictions))

{'n_neighbors': 1}
[[54778     0     0     0     0]
 [    0  1067     0     0     0]
 [    0     0  1920     0     0]
 [    0     0     0  1084     0]
 [    0     0     0     0  1943]]
              precision    recall  f1-score   support

           0       1.00      1.00      1.00     54778
           1       1.00      1.00      1.00      1067
           2       1.00      1.00      1.00      1920
           3       1.00      1.00      1.00      1084
           4       1.00      1.00      1.00      1943

    accuracy                           1.00     60792
   macro avg       1.00      1.00      1.00     60792
weighted avg       1.00      1.00      1.00     60792

