In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from sklearn.datasets import load_iris
from xgboost import XGBClassifier
from mlserve_sdk.client import MLServeClient
import pandas as pd
import numpy as np
import os
from dotenv import load_dotenv

load_dotenv()

True

In [3]:
def generate_churn_data(n_samples=1000, missing_frac=0.05, random_state=42):
    """
    Generate synthetic churn dataset for ML benchmarking.

    Parameters
    ----------
    n_samples : int
        Number of rows to generate.
    missing_frac : float
        Fraction of missing values to inject per column (0–1).
    random_state : int
        Seed for reproducibility.

    Returns
    -------
    X : pd.DataFrame
        Feature matrix with categorical & numerical features.
    y : pd.Series
        Binary churn target (0 = no churn, 1 = churn).
    """
    np.random.seed(random_state)

    # Generate synthetic features
    data = {
        "customer_id": np.arange(1, n_samples+1),
        "age": np.random.randint(18, 80, n_samples),
        "tenure_months": np.random.randint(1, 72, n_samples),
        "monthly_charges": np.round(np.random.uniform(20, 120, n_samples), 2),
        "total_charges": np.round(np.random.uniform(20, 8000, n_samples), 2),
        "contract_type": np.random.choice(
            ["Month-to-month", "One year", "Two year"], n_samples, p=[0.6, 0.25, 0.15]
        ),
        "payment_method": np.random.choice(
            ["Electronic check", "Mailed check", "Bank transfer", "Credit card"], n_samples
        ),
        "internet_service": np.random.choice(
            ["DSL", "Fiber optic", "No"], n_samples, p=[0.3, 0.5, 0.2]
        ),
        "gender": np.random.choice(["Male", "Female"], n_samples),
        "has_phone_service": np.random.choice(["Yes", "No"], n_samples, p=[0.9, 0.1]),
        "num_dependents": np.random.poisson(1, n_samples),  # ~0-4 mostly
    }

    X = pd.DataFrame(data)

    # Inject missing values
    if missing_frac > 0:
        for col in X.columns.drop("customer_id"):
            X.loc[X.sample(frac=missing_frac, random_state=random_state).index, col] = np.nan

    # Churn probability (synthetic rules + noise)
    prob_churn = (
        0.3 * (X["contract_type"] == "Month-to-month").astype(float) +
        0.25 * (X["internet_service"] == "Fiber optic").astype(float) +
        0.15 * (X["payment_method"] == "Electronic check").astype(float) +
        0.002 * (X["monthly_charges"].fillna(60)) +
        0.01 * (X["num_dependents"].fillna(0) == 0).astype(float) +
        np.random.normal(0, 0.1, n_samples)
    )
    prob_churn = 1 / (1 + np.exp(-prob_churn))  # sigmoid

    y = pd.Series(np.random.binomial(1, prob_churn), name="churn")

    return X, y

In [4]:
X, y = generate_churn_data(n_samples=1000, missing_frac=0.05)
X.drop(columns=["customer_id"], inplace=True)
for col in ["contract_type", "payment_method", "internet_service", "gender", "has_phone_service"]:
    X[col] = X[col].astype("category")

model = XGBClassifier(enable_categorical=True)
model.fit(X, y)

0,1,2
,objective,'binary:logistic'
,base_score,
,booster,
,callbacks,
,colsample_bylevel,
,colsample_bynode,
,colsample_bytree,
,device,
,early_stopping_rounds,
,enable_categorical,True


In [8]:
USERNAME = os.getenv("USERNAME")
TOKEN = os.getenv("TOKEN")

client = MLServeClient()
client.login(USERNAME, TOKEN)

In [6]:
try:
    lv=client.get_latest_version("fraud")
    next_version=lv["next_version"]
except:
    next_version="v1"

print(next_version)

v1


In [7]:
client.deploy(
    model=model,
    name="synth",
    version=next_version,
    features=list(X),
    background_df=X.sample(200),
    metrics={'accuracy':model.score(X, y)},
    task_type='classification'
)

{'predict_url': 'https://mlserve.com/api/v1/predict/synth/v1'}

In [14]:
%%time

TEST_DATA = {
    "features": X.columns.tolist(),
    "inputs": X.values.tolist()
}
preds = client.predict("synth", "v1", TEST_DATA, explain=True)
print("Explanations:", preds['explanations'][0])

Explanations: [{'feature': 'age', 'value': 56, 'shap_value': 0.554, 'impact': 'positive'}, {'feature': 'monthly_charges', 'value': 64.64, 'shap_value': 0.395, 'impact': 'positive'}, {'feature': 'payment_method', 'value': 'Mailed check', 'shap_value': 0.344, 'impact': 'positive'}]
CPU times: user 39.3 ms, sys: 3.27 ms, total: 42.5 ms
Wall time: 1.73 s


In [9]:
preds.keys()

dict_keys(['version', 'prediction_ids', 'predictions', 'explanations'])

In [10]:
%%time

TEST_DATA = {
    "features": X.columns.tolist(),
    "inputs": X.values.tolist()
}
preds = client.predict_weighted("synth", TEST_DATA)
#preds = client.predict_weighted("synth", TEST_DATA, entity_ids=["user-133"]*len(X))
#print("Predictions:", preds['predictions'])

CPU times: user 16.5 ms, sys: 1.96 ms, total: 18.5 ms
Wall time: 914 ms


In [11]:
%%time

# Predict using a Redis DB as a feature store for quick lookups

TEST_DATA = {
    "inputs": ['1', '2', '3']
}
preds = client.predict_weighted("synth", TEST_DATA, fs_url="redis://redis:6379/0", fs_entity_name='entity')
print("Predictions:", preds['predictions'])

Predictions: [1, 1, 0]
CPU times: user 13.3 ms, sys: 1.77 ms, total: 15 ms
Wall time: 158 ms


In [12]:
metrics = client.get_metrics("synth", "v1", as_dataframe=True)
metrics

Unnamed: 0_level_0,requests,predictions,throughput_rps,prediction_rps,avg_latency_ms,p50_latency_ms,p95_latency_ms,p99_latency_ms,avg_latency_per_element_ms,p50_latency_per_element_ms,p95_latency_per_element_ms,p99_latency_per_element_ms,error_rate
timestamp,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1
2025-09-11 20:00:00,3,2003,0.000833,0.556389,96.906352,104.533932,121.496375,123.004148,7.054197,0.123381,18.853548,20.518451,0.0


In [13]:
d = client.get_data_quality("synth", "v1", as_dataframe=True)

In [14]:
d['missingness']

Unnamed: 0,feature,missing_fraction,invalid_fraction
0,age,0.049925,0.0
1,tenure_months,0.049925,0.0
2,monthly_charges,0.049925,0.0
3,total_charges,0.049925,0.0
4,contract_type,0.049925,0.0
5,payment_method,0.049925,0.0
6,internet_service,0.049925,0.0
7,gender,0.049925,0.0
8,has_phone_service,0.049925,0.0
9,num_dependents,0.049925,0.0


In [15]:
d['drift']

Unnamed: 0,feature,psi,ks,wasserstein,pct_mean_diff,status
0,age,0.053623,0.096349,2.650181,5.564853,alert
1,tenure_months,0.015402,0.045161,0.853266,-1.913036,alert
2,monthly_charges,0.047004,0.080678,2.964646,4.046313,alert
3,total_charges,0.048329,0.059442,158.27585,-3.275684,alert
4,num_dependents,0.006651,0.015239,0.054712,-0.28755,warning


In [16]:
d['outliers']

Unnamed: 0,feature,iqr_fraction,zscore_fraction,status
0,age,0.0,0.0,ok
1,tenure_months,0.0,0.0,ok
2,monthly_charges,0.0,0.0,ok
3,total_charges,0.0,0.0,ok
4,num_dependents,0.001997,0.005255,ok


In [17]:
client.list_models()

[{'model': 'synth',
  'created_at': '2025-09-11T20:45:09.094906',
  'versions': [{'version': 'v1',
    'weight': 0.0,
    'internal_url': 'http://synth-v1:8080/predict',
    'url': 'http://localhost:54441/predict',
    'active': True,
    'deployed_at': '2025-09-11T20:45:17.481528',
    'features': ['age',
     'tenure_months',
     'monthly_charges',
     'total_charges',
     'contract_type',
     'payment_method',
     'internet_service',
     'gender',
     'has_phone_service',
     'num_dependents'],
    'model_metadata': {'model_class': 'XGBClassifier',
     'module': 'xgboost.sklearn',
     'top_level': {'class': 'XGBClassifier',
      'module': 'xgboost.sklearn',
      'structure': {'name': None,
       'class': 'XGBClassifier',
       'module': 'xgboost.sklearn',
       'type': 'xgboost',
       'params': {'objective': 'binary:logistic',
        'base_score': None,
        'booster': None,
        'callbacks': None,
        'colsample_bylevel': None,
        'colsample_bynode'

In [18]:
# let's go back to predictions we made above
# each prediction carries a prediction id
# you can use this id to send feedback about the prediction
test_id=preds["prediction_ids"][:1][0]

# we provide the true value of the prediction which we receive after the prediction is made
# we also provide the associated reward of the prediction, i.e. the business value
feedback=[
    {"prediction_id":test_id, "true_value":1, "reward":10}
]
client.send_feedback(feedback)

{'status': 'ok', 'updated': 1, 'not_found': []}

In [19]:
client.get_online_metrics("synth", "v1", window_hours=24, as_dataframe=True)

Unnamed: 0,model,version,window_hours,n,n_supervised,mean_reward,n_rewards
0,synth,v1,24,2003,1,10.0,1


In [20]:
# now let's give feedback for 10 more predictions
test_ids=preds["prediction_ids"][1:10]

feedback=[]
for tid in test_ids:
    val=np.random.randint(0, 2)
    r=np.random.normal(10, 7)
    feedback.append({"prediction_id":tid, "true_value":val, "reward":r})

client.send_feedback(feedback)

{'status': 'ok', 'updated': 2, 'not_found': []}

In [21]:
client.get_online_metrics("synth", "v1", window_hours=24, as_dataframe=True)

Unnamed: 0,model,version,window_hours,n,n_supervised,mean_reward,n_rewards
0,synth,v1,24,2003,3,11.490675,3


## Comparing with a newer version

In [22]:
lv=client.get_latest_version("synth")
next_version=lv["next_version"]

feats=list(X)

client.deploy(
    model=model,
    name="synth",
    version=next_version,
    features=feats,
    background_df=X.sample(200),
    metrics={'accuracy':0.2},
    task_type='classification'
)

{'predict_url': 'https://mlserve.com/api/v1/predict/synth/v2'}

In [23]:
%%time

TEST_DATA = {
    "features": X.columns.tolist(),
    "inputs": X.values.tolist()
}
preds = client.predict("synth", "v2", TEST_DATA)
#print("Predictions:", preds['predictions'])

CPU times: user 18.2 ms, sys: 3.78 ms, total: 22 ms
Wall time: 729 ms


In [24]:
# now let's give feedback for a few predictions but as true values we will put the predictions (accuracy 100%)
test_ids=preds["prediction_ids"][1:30]
true_values=preds['predictions'][1:30]

feedback=[]
for tid, val in zip(test_ids,true_values):
    r=np.random.normal(10, 7)
    feedback.append({"prediction_id":tid, "true_value":val, "reward":r})

client.send_feedback(feedback)

{'status': 'ok', 'updated': 29, 'not_found': []}

In [25]:
client.get_online_metrics("synth", "v2", window_hours=24, as_dataframe=True)

Unnamed: 0,model,version,window_hours,n,n_supervised,accuracy,f1,brier,mean_reward,n_rewards
0,synth,v2,24,1000,29,1.0,1.0,0.0,10.016758,29


In [26]:
client.get_model_evolution("synth", as_dataframe=True)

Unnamed: 0,version,deployed_at,accuracy,f1,brier
0,v1,2025-09-11T20:45:17.481528,,,
1,v2,2025-09-11T20:45:26.963827,1.0,1.0,0.0


## AB testing model versions

In [27]:
client.get_abtests("synth")

[]

In [28]:
# Let's create an ab test between versions 1 and 2
# observe how the first version gets a zero weight
# now when you use predict_weighted you will randomly get a prediction from each model based on these probabilities
client.configure_abtest("synth", weights={"v1":0.5, "v2":0.5})

{'status': 'ok', 'model': 'synth', 'weights': {'v1': 0.5, 'v2': 0.5}}

In [29]:
# only 1 ab test so far
client.get_abtests("synth")

[{'id': 1,
  'created_at': '2025-09-11T20:45:28.881556',
  'weights': {'v1': 0.5, 'v2': 0.5}}]

In [30]:
%%time

# this endpoint will return predictions based on probabilities assigned in the ab test configuration
preds=[]
for i in range(100):
    TEST_DATA = {
        "features": X.columns.tolist(),
        "inputs": [X.values[i,:].tolist()]
    }
    pred = client.predict_weighted("synth", TEST_DATA)
    preds.append(pred)

CPU times: user 2.75 s, sys: 179 ms, total: 2.93 s
Wall time: 19.6 s


In [31]:
feedback=[]
for pred in preds:
    if pred['versions']=='v1':
        val=1
        r=1
    else:
        val=pred['predictions'][0]
        r=10
    feedback.append({"prediction_id":pred['prediction_ids'][0], "true_value":val, "reward":r})

client.send_feedback(feedback)

{'status': 'ok', 'updated': 100, 'not_found': []}

In [32]:
client.get_model_evolution("synth", as_dataframe=True)

Unnamed: 0,version,deployed_at,accuracy,f1,brier,accuracy_delta_pct,f1_delta_pct,brier_delta_pct
0,v1,2025-09-11T20:45:17.481528,0.965517,0.970588,0.034483,,,
1,v2,2025-09-11T20:45:26.963827,1.0,1.0,0.0,3.571429,3.030303,-100.0


In [33]:
# now I am sure I want v2 in prod. Let's adjust the ab test
client.configure_abtest("synth", weights={"v2":1})

{'status': 'ok', 'model': 'synth', 'weights': {'v1': 0.0, 'v2': 1.0}}

In [34]:
# I don't need v1 anymore. Let's stop it
client.stop_model("synth", "v1", remove=True)

{'status': 'ok',
 'message': 'Successfully stopped synth:v1 and removed its container and image'}