# 学習方法の例

In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
import pickle

In [2]:
import os
import sqlite3
import configparser
import mlflow
import mlflow.sklearn

In [3]:
cfg = configparser.ConfigParser()
cfg.read('./config.ini', encoding='utf-8')
# 各種パスを指定
DB_PATH = cfg['Path']['db_path']


In [4]:
# !echo $DB_PATH

In [5]:
# バックエンド用DBを作成
os.makedirs(os.path.dirname(DB_PATH), exist_ok=True)  # 親ディレクトリなければ作成
conn = sqlite3.connect(DB_PATH)  # バックエンド用DBを作成


In [6]:
# !echo $DB_PATH

In [7]:
# トラッキングサーバの場所を指定
tracking_uri = f'sqlite:///{DB_PATH}'
mlflow.set_tracking_uri(tracking_uri)


In [8]:
# !echo $tracking_uri

In [9]:
# !echo $ARTIFACT_LOCATION

In [10]:
# %% 手順3 エクスペリメントの作成
# Artifactストレージの場所を指定
ARTIFACT_LOCATION = cfg['Path']['artifact_location']
# Experimentの生成
EXPERIMENT_NAME = 'experiment_tuning'
experiment = mlflow.get_experiment_by_name(EXPERIMENT_NAME)
if experiment is None:  # 当該Experiment存在しないとき、新たに作成
    experiment_id = mlflow.create_experiment(
                            name=EXPERIMENT_NAME,
                            artifact_location=ARTIFACT_LOCATION)
else: # 当該Experiment存在するとき、IDを取得
    experiment_id = experiment.experiment_id


2022/09/14 01:33:31 INFO mlflow.store.db.utils: Creating initial MLflow database tables...
2022/09/14 01:33:31 INFO mlflow.store.db.utils: Updating database tables
INFO  [alembic.runtime.migration] Context impl SQLiteImpl.
INFO  [alembic.runtime.migration] Will assume non-transactional DDL.
INFO  [alembic.runtime.migration] Running upgrade  -> 451aebb31d03, add metric step
INFO  [alembic.runtime.migration] Running upgrade 451aebb31d03 -> 90e64c465722, migrate user column to tags
INFO  [alembic.runtime.migration] Running upgrade 90e64c465722 -> 181f10493468, allow nulls for metric values
INFO  [alembic.runtime.migration] Running upgrade 181f10493468 -> df50e92ffc5e, Add Experiment Tags Table
INFO  [alembic.runtime.migration] Running upgrade df50e92ffc5e -> 7ac759974ad8, Update run tags with larger limit
INFO  [alembic.runtime.migration] Running upgrade 7ac759974ad8 -> 89d4b8295536, create latest metrics table
INFO  [89d4b8295536_create_latest_metrics_table_py] Migration complete!
INFO  

In [11]:
# %% 手順4 実験結果のロギング
import seaborn as sns
from sklearn.model_selection import GridSearchCV, KFold
from sklearn.svm import SVC

# データの読込とチューニング条件の指定
iris = sns.load_dataset("iris")  # irisデータセット取得
OBJECTIVE_VARIALBLE = 'species'  # 目的変数の指定
USE_EXPLANATORY = ['petal_width', 'petal_length', 'sepal_width', 'sepal_length']  # 説明変数の指定
y = iris[OBJECTIVE_VARIALBLE].values # 目的変数
X = iris[USE_EXPLANATORY].values  # 説明変数
estimator = SVC()  # 学習器（サポートベクターマシン）
cv = KFold(n_splits=3, shuffle=True, random_state=42)  # クロスバリデーション（KFold）
scoring = 'f1_micro'  # チューニングに使用するスコア（F1 Micro）
cv_params = {'gamma': [0.001, 0.003, 0.01, 0.03, 0.1, 0.3, 1, 3, 10, 30, 100],
             'C': [0.01, 0.03, 0.1, 0.3, 1, 3, 10, 30, 100]}  # チューニング用のパラメータ


In [12]:
# MLflowによるロギング開始
mlflow.sklearn.autolog()
with mlflow.start_run(experiment_id=experiment_id) as run:
    # グリッドサーチのインスタンス作成
    gridcv = GridSearchCV(estimator, cv_params, cv=cv,
                      scoring=scoring, n_jobs=-1)
    # グリッドサーチ実行
    gridcv.fit(X, y)
    # 最適パラメータの表示
    best_params = gridcv.best_params_
    best_score = gridcv.best_score_
    print(f'最適パラメータ {best_params}\nスコア {best_score}')
# %%


2022/09/14 01:34:29 INFO mlflow.sklearn.utils: Logging the 5 best runs, 94 runs will be omitted.


最適パラメータ {'C': 100, 'gamma': 0.01}
スコア 0.9866666666666667


In [9]:
# mlflow.end_run()

In [13]:
import mlflow
# トラッキングサーバの場所
tracking_uri = mlflow.get_tracking_uri()
print('Current tracking uri: {}'.format(tracking_uri))
# レジストリサーバの場所
mr_uri = mlflow.get_registry_uri()
print('Current model registry uri: {}'.format(mr_uri))
# Artifactストレージの場所
artifact_uri = mlflow.get_artifact_uri()
print('Current artifact uri: {}'.format(artifact_uri))


Current tracking uri: sqlite:///./mlflow/mlruns.db
Current model registry uri: sqlite:///./mlflow/mlruns.db
Current artifact uri: ./mlruns/0/8db46ebcfffa412abe263292c69361d8/artifacts
