# Install missing packages

In [0]:
import importlib

if importlib.util.find_spec('mlflow') is None:
  !pip install mlflow

# if importlib.util.find_spec('papermill') is None:
#   !pip install papermill
# !pip install -r requirements.txt

# Parameters (the cell below needs the `parameters` metadata tag so that papermill can inject values)

In [0]:
# Placeholder defaults; papermill is expected to override each of these at
# execution time (the next cell asserts that it did).
model_name =  None  # key looked up in model_dict to choose the classifier
notebook_out = None  # path of the notebook file logged as an MLflow artifact in the final cell
artefacts_temp_dir = None  # directory the model pickle and confusion-matrix PNG are written to

# Assertions to check papermill has set the parameters

In [0]:
# Fail fast if any papermill parameter is still at its None placeholder.
for _param_name, _param_value in (
    ('model_name', model_name),
    ('notebook_out', notebook_out),
    ('artefacts_temp_dir', artefacts_temp_dir),
):
    assert _param_value is not None, f'The name of {_param_name} should have been set by papermill'

# Start notebook

In [0]:
# import packages
import os
import pickle

import matplotlib.pyplot as plt
import mlflow
from sklearn import datasets
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.svm import SVC
from sklearn.tree import DecisionTreeClassifier

from util import plot_confusion_matrix

# Set MLflow Experiment

In [0]:
# Select the MLflow experiment by name; every log_param / log_metric /
# log_artifact call in the cells below is made after this call.
mlflow.set_experiment('Iris Classification')

## Load Data

In [0]:
# Load the iris dataset shipped with scikit-learn and unpack the parts used later.
iris = datasets.load_iris()
X, y = iris.data, iris.target
class_names = iris.target_names  # passed to plot_confusion_matrix further down

# Record which dataset this run was trained on.
mlflow.log_param('dataset', 'iris')

In [0]:
# Hold out part of the data for evaluation. No test_size is passed, so
# scikit-learn's default split proportion applies; the fixed random_state
# makes the shuffle reproducible from run to run.
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)

In [0]:
# Sanity check: print the X and y shapes of the training split, then of the test split.
print(*(part.shape for part in (X_train, y_train)))
print(*(part.shape for part in (X_test, y_test)))

## Train Model

In [0]:
# Maps each supported `model_name` value to the scikit-learn estimator class
# (the class itself, not an instance) that is instantiated for it.
model_dict = dict(
    decision_tree=DecisionTreeClassifier,
    logistic_regression=LogisticRegression,
    svm=SVC,
)

In [0]:
# Fail with an actionable message when the injected model_name has no entry in
# model_dict; a plain lookup would raise a bare KeyError showing only the key.
# The exception type stays KeyError so existing handlers keep working.
if model_name not in model_dict:
    raise KeyError(
        f'Unknown model_name {model_name!r}; expected one of {sorted(model_dict)}'
    )

ModelClass = model_dict[model_name]
model = ModelClass()  # instantiated with no arguments, i.e. the estimator's defaults

# Record which classifier this run trains.
mlflow.log_param('model_name', model_name)

In [0]:
# Train the selected estimator on the training split only; the test split is
# used by the evaluation cells below.
model.fit(X_train, y_train)

In [0]:
# Serialise the fitted model into the artefacts directory and attach the file
# to the MLflow run.
# The path is built with os.path.join instead of a hard-coded './' prefix: with
# the prefix, an absolute artefacts_temp_dir such as '/tmp/run' became
# './/tmp/run/...', i.e. a path relative to the working directory.
model_file_name = os.path.join(artefacts_temp_dir, f'{model_name}.pkl')
with open(model_file_name, 'wb') as f:
    pickle.dump(model, f)

mlflow.log_artifact(model_file_name)

## Evaluate Model

In [0]:
# Score the fitted model on the held-out test split. The value is multiplied
# by 100 only for the printed percentage.
acc = model.score(X_test, y_test)
print(f"Accuracy: {(acc * 100):.2f}%")

# The unscaled score is what gets logged to MLflow.
mlflow.log_metric('accuracy', acc)

In [0]:
# Predict on the held-out split and draw the confusion matrix with the project
# helper from util.
y_pred = model.predict(X_test)
figure = plot_confusion_matrix(y_test, y_pred, class_names)
# NOTE(review): presumably `figure` is a matplotlib Figure (it is used via
# .show() and .savefig()) - confirm against util.plot_confusion_matrix.
figure.show()

# Build the path with os.path.join, the same way the model pickle path is
# built, so both artefacts land in the same directory on every platform.
confusion_matrix_file_name = os.path.join(
    artefacts_temp_dir, f'{model_name}_confusion_matrix.png'
)
figure.savefig(confusion_matrix_file_name)
mlflow.log_artifact(confusion_matrix_file_name)

# After the notebook has finished, save it as an MLflow artifact

In [0]:
# Attach the notebook file at `notebook_out` to the MLflow run.
# NOTE(review): this cell runs while the notebook is still executing, so the
# file on disk may not yet contain this final cell's result - confirm how and
# when papermill writes notebook_out.
mlflow.log_artifact(notebook_out)