In [2]:
import os

import joblib
from sklearn import datasets
from sklearn import neighbors
from sklearn import svm
from sklearn import tree
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split


def download_data():
    """Load the Iris dataset through scikit-learn.

    Returns:
        tuple: ``(X, y)`` -- the feature array and the class-label array,
        i.e. the ``data`` and ``target`` fields of ``datasets.load_iris()``.
    """
    features, labels = datasets.load_iris(return_X_y=True)
    return features, labels


def split_data(X, y, test_size=0.25, random_state=None):
    """Shuffle and split features/labels into train and test partitions.

    Args:
        X: Feature array.
        y: Labels aligned with the rows of ``X``.
        test_size: Float fraction (or int count) of samples held out for
            testing; forwarded to ``train_test_split``.
        random_state: Seed forwarded to ``train_test_split``. Pass an int to
            make the split reproducible across runs; the default ``None``
            keeps the previous unseeded behaviour.

    Returns:
        ``(X_train, X_test, y_train, y_test)`` as produced by
        ``train_test_split``.
    """
    return train_test_split(
        X, y, test_size=test_size, random_state=random_state
    )


def build_model(model="", model_dir="models", accuracy_dir="/tmp"):
    """Train one iris classifier, report its hold-out accuracy and persist it.

    Args:
        model: ``"knn"`` trains a ``KNeighborsClassifier`` and ``"svm"`` an
            ``SVC``. Any other value (including the default ``""``) falls
            back to a ``DecisionTreeClassifier`` and is recorded as ``"tree"``.
        model_dir: Directory the fitted classifier is pickled into as
            ``<model_dir>/<model>.pkl``. Created if it does not exist.
        accuracy_dir: Directory the accuracy score is written to as
            ``<accuracy_dir>/accuracy_<model>.txt``. The default matches the
            ``file_outputs`` paths declared in ``iris_classification_pipeline``.
            Pass ``None`` to skip writing this file.

    Returns:
        The fitted scikit-learn classifier.
    """
    X, y = download_data()
    X_train, X_test, y_train, y_test = split_data(X, y)

    if model == "knn":
        classifier = neighbors.KNeighborsClassifier()
    elif model == "svm":
        classifier = svm.SVC()
    else:
        # Unrecognised names fall back to a decision tree; normalise the name
        # so the saved artefacts are labelled after the model actually trained.
        model = "tree"
        classifier = tree.DecisionTreeClassifier()

    classifier.fit(X_train, y_train)

    predictions = classifier.predict(X_test)
    accuracy = accuracy_score(y_test, predictions)
    print("Accuracy for {}: {}".format(model, accuracy))

    # joblib.dump opens the target path directly, so the directory must exist.
    os.makedirs(model_dir, exist_ok=True)
    joblib.dump(classifier, os.path.join(model_dir, "{}.pkl".format(model)))

    if accuracy_dir is not None:
        os.makedirs(accuracy_dir, exist_ok=True)
        accuracy_path = os.path.join(
            accuracy_dir, "accuracy_{}.txt".format(model)
        )
        with open(accuracy_path, "w") as accuracy_file:
            accuracy_file.write(str(accuracy))

    return classifier


In [3]:
import kfp
from kfp import dsl


# Container image every iris training step runs in.
_IRIS_IMAGE = "annajung/iris:latest"


def _iris_train_op(model, display_name, image=_IRIS_IMAGE):
    """Create a ContainerOp that trains a single iris classifier.

    The step runs ``python iris_classification.py build_model <model>`` via
    ``sh -c`` inside ``image`` and exposes ``/tmp/accuracy_<model>.txt`` as
    its ``'output'`` file output.

    Args:
        model: Model key passed to ``build_model`` (e.g. ``"tree"``).
        display_name: Human-readable suffix for the step name.
        image: Container image to run the step in.

    Returns:
        The created ``dsl.ContainerOp``.
    """
    return dsl.ContainerOp(
        name="Train using {}".format(display_name),
        image=image,
        command=["sh", "-c"],
        arguments=["python iris_classification.py build_model {}".format(model)],
        file_outputs={'output': '/tmp/accuracy_{}.txt'.format(model)},
    )


@dsl.pipeline(
    name='iris-classification',
    description='A basic pipeline example for iris classification'
)
def iris_classification_pipeline():
    """Declare one training step each for a decision tree, KNN and an SVM."""
    # NOTE: `tree` and `svm` shadow the sklearn modules of the same name,
    # but only within this function's scope.
    tree = _iris_train_op("tree", "Decision Tree")
    knn = _iris_train_op("knn", "K Nearest Neighbors")
    svm = _iris_train_op("svm", "Support Vector Machine")