In [1]:
import pickle
from sklearn import svm
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score

In [2]:
def get_data(test_size, random_state=None):
    """Load the iris dataset and split it into train and test sets.

    Parameters
    ----------
    test_size : float or int
        Forwarded unchanged to ``train_test_split``: a float is the fraction of
        rows held out for testing, an int is the absolute number of test rows.
    random_state : int or None, optional
        Seed for the shuffle done by ``train_test_split``. The default ``None``
        keeps the previous behaviour (a different split on every call); pass an
        int to make the split, and therefore downstream metrics, reproducible.

    Returns
    -------
    tuple
        ``(X_train, X_test, y_train, y_test)``. Because ``as_frame=True`` is
        used, features are pandas DataFrames and targets are pandas Series.
    """
    X, y = load_iris(return_X_y=True, as_frame=True)
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=test_size, random_state=random_state
    )
    return X_train, X_test, y_train, y_test

In [3]:
def build_model(gamma, C):
    """Create an untrained support-vector classifier.

    Parameters
    ----------
    gamma : float or str
        Kernel coefficient, forwarded to ``svm.SVC``.
    C : float
        Regularisation strength, forwarded to ``svm.SVC``.

    Returns
    -------
    sklearn.svm.SVC
        A fresh, unfitted estimator; the caller is responsible for ``fit``.
    """
    classifier = svm.SVC(C=C, gamma=gamma)
    return classifier

In [4]:
def main(context, test_size, gamma, C):
    """Train and evaluate an SVM on iris, then log metrics, datasets and model.

    Parameters
    ----------
    context :
        Run context exposing ``logger``, ``log_result``, ``log_dataset``,
        ``log_model``, ``artifact_path``, ``artifact_subpath`` and ``uid``
        (presumably an MLRun run context -- TODO confirm).
    test_size : float or int
        Forwarded to ``get_data`` to control the train/test split.
    gamma, C :
        SVC hyperparameters forwarded to ``build_model``.

    Side effects
    ------------
    Writes ``model.pkl`` to the current working directory before passing that
    file name to ``context.log_model``.
    """
    # Get data
    context.logger.info("Getting data")
    X_train, X_test, y_train, y_test = get_data(test_size)

    # Train model
    context.logger.info("Training model")
    model = build_model(gamma, C)
    model.fit(X_train, y_train)

    # Evaluate model
    context.logger.info("Evaluating model")
    y_pred = model.predict(X_test)
    accuracy = accuracy_score(y_test, y_pred)
    context.log_result("accuracy", accuracy)

    # Log datasets. The targets are Series, so they are converted to
    # single-column frames; dict order keeps the original logging order.
    context.logger.info("Logging datasets")
    datasets = {
        "X_train": X_train,
        "X_test": X_test,
        "y_train": y_train.to_frame(),
        "y_test": y_test.to_frame(),
    }
    for key, df in datasets.items():
        context.log_dataset(key=key, df=df, format="csv", artifact_path=context.artifact_path)

    # Log model. The `with` block guarantees the pickle file is flushed and
    # closed before it is handed to log_model; the previous bare open() relied
    # on CPython refcounting to close it, which is an implementation detail.
    context.logger.info("Logging model")
    with open("model.pkl", "wb") as model_file:
        pickle.dump(model, model_file)
    context.log_model(
        key="notebook_model",
        artifact_path=context.artifact_subpath(context.uid),
        model_file="model.pkl",
        metrics={"accuracy": accuracy},
        tag="latest",
        parameters=model.get_params(),
        framework="sklearn",
    )