## Run DVC featurize stage

In [1]:
# Reproduce the `featurize` stage defined in ../dvc.yaml. DVC also checks the
# upstream `prepare` stage and skips any stage that has not changed.
!dvc repro ../dvc.yaml:featurize

Stage '../dvc.yaml:prepare' didn't change, skipping
Stage '../dvc.yaml:featurize' didn't change, skipping
Data and pipelines are up to date.
[0m

## Load data

In [2]:
import os
import numpy as np
import pandas as pd

# Location of the processed MNIST training archive (.npz with "data" and "labels"
# entries). Presumably written by the DVC pipeline run above — confirm in dvc.yaml.
filepath = "/workspace/data/processed/mnist_train.npz"

# np.load on an .npz archive returns an NpzFile that keeps the file open and reads
# each array on access. Reading both arrays inside a `with` block closes the
# handle as soon as they are in memory, and keeps the name `train` free for the
# combined array below instead of first binding it to the archive object.
with np.load(filepath) as train_npz:
    train_data = train_npz["data"]
    train_labels = train_npz["labels"]

# Prepend the labels as column 0 so features and target sit in a single 2-D array.
train = np.insert(train_data, 0, train_labels, axis=1)

# The frame gets default integer column names; column 0 (the labels) is renamed
# to "label", leaving the feature columns named 1..N.
df_train = pd.DataFrame(train).rename(columns={0: "label"})

## PyCaret setup

In [3]:
from pycaret.classification import setup

# Options for the PyCaret classification experiment, gathered in one place:
# the training frame, its target column, the train fraction, the session id
# and the GPU switch.
setup_config = {
    "data": df_train,
    "target": "label",
    "session_id": 42,
    "train_size": 0.8,
    "use_gpu": False,
}

# Initialise the experiment; later cells use the state this call sets up.
experiment = setup(**setup_config)

Unnamed: 0,Description,Value
0,session_id,42
1,Target,label
2,Target Type,Multiclass
3,Label Encoded,
4,Original Data,"(60000, 785)"
5,Missing Values,False
6,Numeric Features,784
7,Categorical Features,0
8,Ordinal Features,False
9,High Cardinality Features,False


## Show available models

In [4]:
from pycaret.classification import models

# Display the table of estimators PyCaret offers for classification
# (ID, name, reference class and "Turbo" flag, as shown in the output below).
models()

Unnamed: 0_level_0,Name,Reference,Turbo
ID,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
lr,Logistic Regression,sklearn.linear_model._logistic.LogisticRegression,True
knn,K Neighbors Classifier,sklearn.neighbors._classification.KNeighborsCl...,True
nb,Naive Bayes,sklearn.naive_bayes.GaussianNB,True
dt,Decision Tree Classifier,sklearn.tree._classes.DecisionTreeClassifier,True
svm,SVM - Linear Kernel,sklearn.linear_model._stochastic_gradient.SGDC...,True
rbfsvm,SVM - Radial Kernel,sklearn.svm._classes.SVC,False
gpc,Gaussian Process Classifier,sklearn.gaussian_process._gpc.GaussianProcessC...,False
mlp,MLP Classifier,sklearn.neural_network._multilayer_perceptron....,False
ridge,Ridge Classifier,sklearn.linear_model._ridge.RidgeClassifier,True
rf,Random Forest Classifier,sklearn.ensemble._forest.RandomForestClassifier,True


## Run training

In [5]:
from pycaret.classification import compare_models

# Train and score the candidate estimators, rank the results table by F1 and keep
# the best-ranked model. "gbc" is left out of the comparison — the reason is not
# recorded here (presumably run time; confirm).
top_model = compare_models(
    sort="F1",
    exclude=["gbc"],
)

Unnamed: 0,Model,Accuracy,AUC,Recall,Prec.,F1,Kappa,MCC,TT (Sec)
lightgbm,Light Gradient Boosting Machine,0.8997,0.9931,0.8996,0.8996,0.8994,0.8886,0.8886,31.988
rf,Random Forest Classifier,0.8809,0.9905,0.8807,0.8799,0.8792,0.8676,0.8679,1.396
et,Extra Trees Classifier,0.8808,0.9904,0.8807,0.8798,0.8792,0.8676,0.8678,1.959
knn,K Neighbors Classifier,0.8546,0.9695,0.8544,0.8574,0.8542,0.8384,0.8388,21.659
lr,Logistic Regression,0.8398,0.982,0.8397,0.8385,0.8389,0.822,0.8221,14.646
svm,SVM - Linear Kernel,0.833,0.0,0.8329,0.8334,0.8327,0.8144,0.8146,12.647
lda,Linear Discriminant Analysis,0.8231,0.9801,0.823,0.826,0.8241,0.8035,0.8036,1.021
ridge,Ridge Classifier,0.8221,0.0,0.8219,0.8191,0.8185,0.8023,0.8028,0.272
dt,Decision Tree Classifier,0.7922,0.8846,0.7921,0.7932,0.7926,0.7692,0.7692,3.338
nb,Naive Bayes,0.5692,0.8881,0.5687,0.6265,0.5373,0.5212,0.5349,0.45


## Inspect performance

In [6]:
from pycaret.classification import evaluate_model

# Open PyCaret's interactive evaluation widget for the selected model; the
# "Plot Type" toggle buttons switch between the available diagnostic views.
evaluate_model(top_model)

interactive(children=(ToggleButtons(description='Plot Type:', icons=('',), options=(('Hyperparameters', 'param…

## Evaluate on holdout

In [7]:
from pycaret.classification import predict_model

# Score the selected model. No `data` argument is passed, so this presumably uses
# the 20% hold-out split created by setup(train_size=0.8) — confirm against the
# PyCaret docs. The output shows a one-row metrics table followed by the scored
# frame with added "Label" (prediction) and "Score" columns.
predict_model(top_model)

Unnamed: 0,Model,Accuracy,AUC,Recall,Prec.,F1,Kappa,MCC
0,Light Gradient Boosting Machine,0.8991,0.9931,0.8995,0.8987,0.8985,0.8879,0.888


Unnamed: 0,1,2,3,4,5,6,7,8,9,10,...,778,779,780,781,782,783,784,label,Label,Score
0,-0.008644,-0.023223,-0.039178,-0.041322,-0.057646,-0.071167,-0.098878,-0.156653,-0.239080,-0.377827,...,-0.406094,-0.441359,-0.396626,-0.288156,-0.156811,-0.089673,-0.034147,8.0,8,0.9998
1,-0.008644,-0.023223,-0.039178,-0.041322,-0.057646,-0.071167,-0.098878,-0.156653,-0.239080,-0.377827,...,-0.406094,-0.441359,-0.396626,-0.288156,-0.156811,-0.089673,-0.034147,8.0,8,0.9999
2,-0.008644,-0.023223,-0.039178,-0.041322,-0.057646,-0.071167,-0.098878,-0.156653,-0.239080,-0.377827,...,-0.406094,-0.441359,-0.396626,-0.288156,-0.156811,-0.089673,-0.034147,7.0,7,0.9998
3,-0.008644,-0.023223,-0.039178,-0.041322,-0.057646,-0.071167,-0.098878,-0.086238,-0.239080,-0.351652,...,2.691085,1.580673,-0.396626,-0.288156,-0.156811,-0.089673,-0.034147,4.0,4,0.9564
4,-0.008644,-0.023223,-0.039178,-0.041322,-0.057646,-0.071167,-0.098878,-0.156653,-0.239080,-0.377827,...,-0.406094,-0.441359,-0.396626,-0.288156,-0.156811,-0.089673,-0.034147,8.0,8,0.9998
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
11995,-0.008644,-0.023223,-0.039178,-0.041322,-0.057646,-0.071167,-0.098878,-0.156653,-0.239080,-0.377827,...,-0.406094,-0.441359,-0.396626,-0.288156,-0.156811,-0.089673,-0.034147,7.0,7,0.9693
11996,-0.008644,-0.023223,-0.039178,-0.041322,-0.057646,-0.071167,-0.098878,-0.156653,-0.239080,-0.377827,...,-0.406094,0.887405,1.108689,2.054928,-0.156811,-0.089673,-0.034147,6.0,6,0.9783
11997,-0.008644,-0.023223,-0.039178,-0.041322,-0.057646,-0.071167,-0.098878,-0.156653,-0.239080,-0.377827,...,-0.406094,-0.441359,-0.396626,-0.288156,-0.156811,-0.089673,-0.034147,8.0,8,0.9997
11998,-0.008644,-0.023223,-0.039178,-0.041322,-0.057646,-0.071167,-0.098878,-0.156653,-0.239080,-0.325477,...,-0.360547,-0.383586,0.023977,-0.220240,-0.156811,-0.089673,-0.034147,8.0,8,0.9993
