In [1]:
# Importing all the packages

import os
import logging
import numpy as np
from sklearn.base import ClassifierMixin
from zenml.integrations.sklearn.helpers.digits import get_digits, get_digits_model
from zenml.pipelines import pipeline
from zenml.steps import step
from zenml.steps.step_output import Output
from zenml.core.repo import Repository
from zenml.integrations.graphviz.visualizers.pipeline_run_dag_visualizer import (
    PipelineRunDagVisualizer,
)

In [2]:
@step
def importer() -> Output(
    X_train=np.ndarray, X_test=np.ndarray, y_train=np.ndarray, y_test=np.ndarray
):
    """Loads the digits array as normal numpy arrays."""
    X_train, X_test, y_train, y_test = get_digits()
    return X_train, X_test, y_train, y_test

In [3]:
@step
def normalizer(
    X_train: np.ndarray, X_test: np.ndarray
) -> Output(X_train_normed=np.ndarray, X_test_normed=np.ndarray):
    """Normalize digits dataset with mean and standard deviation."""
    X_train_normed = (X_train - np.mean(X_train)) / np.std(X_train)
    X_test_normed = (X_test - np.mean(X_test)) / np.std(X_test)
    return X_train_normed, X_test_normed

In [4]:
@step(enable_cache=False)
def trainer(
    X_train: np.ndarray,
    y_train: np.ndarray,
) -> ClassifierMixin:
    """Train a simple sklearn classifier for the digits dataset."""
    model = get_digits_model()
    model.fit(X_train, y_train)
    return model

In [5]:
@step
def evaluator(
    X_test: np.ndarray,
    y_test: np.ndarray,
    model: ClassifierMixin,
) -> float:
    """Calculate the accuracy on the test set"""
    test_acc = model.score(X_test, y_test)
    logging.info(f"Test accuracy: {test_acc}")
    return test_acc

In [6]:
@pipeline
def mnist_pipeline(
    importer,
    normalizer,
    trainer,
    evaluator,
):
    # Link all the steps together
    X_train, X_test, y_train, y_test = importer()
    X_trained_normed, X_test_normed = normalizer(X_train=X_train, X_test=X_test)
    model = trainer(X_train=X_trained_normed, y_train=y_train)
    evaluator(X_test=X_test_normed, y_test=y_test, model=model)

In [7]:
def visualizer_graph():
    repo = Repository()
    pipe = repo.get_pipelines()[-1]
    latest_run = pipe.runs[-1]
    PipelineRunDagVisualizer().visualize(latest_run)


if __name__ == "__main__":
    # Run the pipeline
    first_pipeline = mnist_pipeline(
        importer=importer(),
        normalizer=normalizer(),
        trainer=trainer(),
        evaluator=evaluator(),
    )
    first_pipeline.run()
    visualizer_graph()

[1;35mCreating pipeline: mnist_pipeline[0m
[1;35mCache enabled for pipeline `[0m[33;21mmnist_pipeline`[1;35m[0m
[1;35mUsing orchestrator `[0m[33;21mlocal_orchestrator`[1;35m for pipeline `[0m[33;21mmnist_pipeline`[1;35m. Running pipeline..[0m
[1;35mStep `[0m[33;21mimporter`[1;35m has started.[0m
[1;35mStep `[0m[33;21mimporter`[1;35m has finished in 0.029s.[0m
[1;35mStep `[0m[33;21mnormalizer`[1;35m has started.[0m
[1;35mStep `[0m[33;21mnormalizer`[1;35m has finished in 0.031s.[0m
[1;35mStep `[0m[33;21mtrainer`[1;35m has started.[0m
[1;35mStep `[0m[33;21mtrainer`[1;35m has finished in 0.100s.[0m
[1;35mStep `[0m[33;21mevaluator`[1;35m has started.[0m


INFO:root:Test accuracy: 0.9154616240266963


[1;35mStep `[0m[33;21mevaluator`[1;35m has finished in 0.121s.[0m
[1;35mThis integration is not completed yet. Results might be unexpected.[0m


Error: no "view" rule for type "image/png" passed its test case
       (for more information, add "--debug=1" on the command line)


?PNG

   
IHDR  ?     SlW?   bKGD ? ? ?????    IDATx???{\??????n?t/I%??`HB2??RHJ)M?h?q??3?33f\s??©0#?Kh??ܫʵT?ݽ????1??W?QٻϾ????~x????֫ͪ?????x?1!?B?D?? ?!D?Q?!?B??
N!?"U?\ ?eRYY???2???A @(???L?|]]JKK?FWW|>_????TUU???mmm???@SS?ݾi-*8	!?TWW#;;??Ƌ/PPP???\ ???????H$?J>?]]]???A__???022???LMMѥK???ރ??:v?(????R'???WVV???ܹs<@vv6rrr??????\?y????ҥLLL`ll###?????? zzz??Յ????????
?x<???7??????????M~]\\???r?????????(..F^^????EoAAjjj?m?????????????kkk?????ڒ~	!J?
N??c????իWq??mܹs?????G?  ;v????KKK???????{?BQ?????₹?h?????????j?x<t??}?????5????!C??W?^??x\?9E'!Di"11W?^???W???
?@ uuuX[[?o߾??~??ְ??lp/?"???Cvv6????{??5????????????!C0d?6???\?&??	*8	!J??????8?<Ο????4?D"???b???6l???0x?`????
???????ב?????ddd@$?G?puu???+ƌ??? ?G'!D?=x? 111???AJJ
???`cc?ѣGc???1b?????)WJJJp??%\?x/^ĝ;w?????ooox{{?G?\?$??*8	!
??ݻ8t??9?[?n?S?N?0a<<<???###?#*???|???????8y?$???`kkooo????o߾\G$?p?
NB?B???FDDadd??c?bʔ)pww??