# Model inference

This notebook shows how to download the trained model and its feature scaler, apply the same preprocessing used during training to new data, and run inference with the model.

> This process will likely be captured in a Docker Repo asset on Highwind

In [1]:
# Config

# Inference input: for this example, the test data stands in for new data
TEST_DATA_PATH = "../data/test.csv"
# Label column, separated from the features before inference
TARGET_COLUMN = "target"

# Local directory that receives the downloaded model artifacts
ARTIFACT_SAVE_DIR = "../saved_model/"

In [2]:
import os
import pandas as pd
import joblib
from huggingface_hub import hf_hub_download

  from .autonotebook import tqdm as notebook_tqdm


## Load data

In [3]:
# Read the inference input and preview its first rows
test_df = pd.read_csv(TEST_DATA_PATH)
test_df.head(n=3)

Unnamed: 0,sepal length (cm),sepal width (cm),petal length (cm),petal width (cm),target
0,6.1,2.8,4.7,1.2,1
1,5.7,3.8,1.7,0.3,0
2,7.7,2.6,6.9,2.3,2


In [4]:
# Separate features and labels: the label column becomes `y_test`, all other
# columns form `X_test`; `test_df` itself is left unmodified.
y_test = test_df[TARGET_COLUMN].copy()
X_test = test_df.drop(columns=[TARGET_COLUMN])

In [5]:
# Check shapes: features and labels should have the same number of rows
print(f"X_test: {X_test.shape}", f"y_test: {y_test.shape}", sep="\n")

X_test: (30, 4)
y_test: (30,)


## Download trained model and feature scaler

In [9]:
# Download the trained model and the fitted feature scaler from Hugging Face.
# SECURITY: these are joblib (pickle-based) files; loading them can execute
# arbitrary code, so only download from a repository you trust.
HF_REPO_ID = "MelioAI/iris-classifier"
# Pin to a specific commit hash for reproducible inference. `None` uses the
# repo's default branch, so results can change whenever the repo is updated.
HF_REVISION = None
download_files = ["model.joblib", "scaler.joblib"]

# hf_hub_download returns the local path of each downloaded file; keep them so
# the exact written locations are available to later cells.
artifact_paths = {}
for file_name in download_files:
    artifact_paths[file_name] = hf_hub_download(
        repo_id=HF_REPO_ID,
        filename=file_name,
        revision=HF_REVISION,
        local_dir=ARTIFACT_SAVE_DIR,
    )

## Process data

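Repeat all processing steps applied during training so the new data matches what the model expects. Here, that means scaling the features with the downloaded scaler.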

In [10]:
# Feature scaling: apply the downloaded scaler to the new features.
# NOTE: joblib.load unpickles the file; only load artifacts you trust.
scaler = joblib.load(
    os.path.join(ARTIFACT_SAVE_DIR, "scaler.joblib")
)
X_test_scaled = scaler.transform(X_test)

## Make predictions

In [11]:
# Load the trained model and display its repr.
# NOTE: joblib.load unpickles the file; only load artifacts you trust.
model = joblib.load(
    os.path.join(ARTIFACT_SAVE_DIR, "model.joblib")
)
model

In [12]:
# Predicted class label for every row of the scaled test features
y_pred = model.predict(X_test_scaled)
y_pred

array([1, 0, 2, 1, 1, 0, 1, 2, 1, 1, 2, 0, 0, 0, 0, 1, 2, 1, 1, 2, 0, 2,
       0, 2, 2, 2, 2, 2, 0, 0])

In [13]:
# Class probabilities for every row of the scaled test features: one row per
# sample, one column per class. If `model` is a scikit-learn classifier, the
# column order follows `model.classes_` (TODO confirm the model type).
# Stored under its own name so the class labels in `y_pred` are not overwritten.
y_pred_proba = model.predict_proba(X_test_scaled)
y_pred_proba

array([[1.14475073e-02, 8.76014266e-01, 1.12538227e-01],
       [9.64369004e-01, 3.56305853e-02, 4.10786879e-07],
       [3.76768338e-08, 2.88064673e-03, 9.97119316e-01],
       [1.32028749e-02, 7.59597068e-01, 2.27200057e-01],
       [1.88736913e-03, 7.52191840e-01, 2.45920791e-01],
       [9.32163787e-01, 6.78346862e-02, 1.52679556e-06],
       [8.92013671e-02, 8.78517730e-01, 3.22809028e-02],
       [8.42008232e-05, 6.42658457e-02, 9.35649953e-01],
       [7.39930569e-04, 5.77409899e-01, 4.21850170e-01],
       [3.02188127e-02, 9.25802172e-01, 4.39790155e-02],
       [1.18150194e-03, 2.10500206e-01, 7.88318292e-01],
       [9.49649175e-01, 5.03503361e-02, 4.88813101e-07],
       [9.60117238e-01, 3.98825066e-02, 2.55213774e-07],
       [9.51904539e-01, 4.80949904e-02, 4.70254636e-07],
       [9.89916261e-01, 1.00836487e-02, 9.07712609e-08],
       [1.91396506e-02, 7.24696375e-01, 2.56163974e-01],
       [4.03061386e-05, 3.18044910e-02, 9.68155203e-01],
       [2.66553386e-02, 9.34977