-
Notifications
You must be signed in to change notification settings - Fork 4
/
c_log_tuning_local.py
54 lines (50 loc) · 2.4 KB
/
c_log_tuning_local.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
# %% 手順2 トラッキングサーバの構築
import mlflow
import configparser
import os
cfg = configparser.ConfigParser()
cfg.read('./config.ini', encoding='utf-8')
# 各種パスを指定
TRACKING_URI = cfg['Path']['tracking_uri']
# トラッキングサーバの場所を指定
mlflow.set_tracking_uri(TRACKING_URI)
# %% 手順3 エクスペリメントの作成
# Artifactストレージの場所を指定(Experimentの生成が必要)
ARTIFACT_LOCATION = cfg['Path']['artifact_location']
# Experimentの生成
EXPERIMENT_NAME = 'experiment_tuning'
experiment = mlflow.get_experiment_by_name(EXPERIMENT_NAME)
if experiment is None: # 当該Experiment存在しないとき、新たに作成
experiment_id = mlflow.create_experiment(
name=EXPERIMENT_NAME,
artifact_location=ARTIFACT_LOCATION)
else: # 当該Experiment存在するとき、IDを取得
experiment_id = experiment.experiment_id
# %% 手順4 実験結果のロギング
import seaborn as sns
from sklearn.model_selection import GridSearchCV, KFold
from sklearn.svm import SVC
# データの読込とチューニング条件の指定
iris = sns.load_dataset("iris") # irisデータセット取得
OBJECTIVE_VARIALBLE = 'species' # 目的変数の指定
USE_EXPLANATORY = ['petal_width', 'petal_length', 'sepal_width', 'sepal_length'] # 説明変数の指定
y = iris[OBJECTIVE_VARIALBLE].values # 目的変数
X = iris[USE_EXPLANATORY].values # 説明変数
estimator = SVC() # 学習器(サポートベクターマシン)
cv = KFold(n_splits=3, shuffle=True, random_state=42) # クロスバリデーション(KFold)
scoring = 'f1_micro' # チューニングに使用するスコア(F1 Micro)
cv_params = {'gamma': [0.001, 0.003, 0.01, 0.03, 0.1, 0.3, 1, 3, 10, 30, 100],
'C': [0.01, 0.03, 0.1, 0.3, 1, 3, 10, 30, 100]} # チューニング用のパラメータ
# MLflowによるロギング開始
mlflow.sklearn.autolog()
with mlflow.start_run(experiment_id=experiment_id) as run:
# グリッドサーチのインスタンス作成
gridcv = GridSearchCV(estimator, cv_params, cv=cv,
scoring=scoring, n_jobs=-1)
# グリッドサーチ実行
gridcv.fit(X, y)
# 最適パラメータの表示
best_params = gridcv.best_params_
best_score = gridcv.best_score_
print(f'最適パラメータ {best_params}\nスコア {best_score}')
# %%