# Imports

In [2]:
import warnings
import numpy as np
import pandas as pd

from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split, StratifiedKFold

from dstoolkit.feature.monitoring import psi, ks_test_drift, jensen_shannon_divergence, chi_squared_monitoring

## Functions

In [3]:
def generate_synthetic_binary_data(
    n_samples=10000,
    n_features=30,
    n_informative=5,
    n_redundant=10,
    n_repeated=0,
    n_classes=2,
    class_sep=0.5,
    flip_y=0.15,
    weights=[0.8, 0.2],
    test_size=0.2,
    valid_size=0.2,
    random_state=42
):
    """
    Gera uma base sintética binária com splits train/valid/test.
    Retorna: (X_train, y_train, X_valid, y_valid, X_test, y_test)
    """

    # ---------------------------
    # 1. Gerar base sintética
    # ---------------------------
    X, y = make_classification(
        n_samples=n_samples,
        n_features=n_features,
        n_informative=n_informative,
        n_redundant=n_redundant,
        n_repeated=n_repeated,
        n_classes=n_classes,
        class_sep=class_sep,
        flip_y=flip_y,
        weights=weights,
        random_state=random_state
    )

    # ---------------------------
    # 2. Train/Test
    # ---------------------------
    X_train, X_test, y_train, y_test = train_test_split(
        X, y,
        test_size=test_size,
        stratify=y,
        random_state=random_state
    )

    # ---------------------------
    # 3. Train/Validation
    # ---------------------------
    valid_relative = valid_size / (1 - test_size)

    X_train, X_valid, y_train, y_valid = train_test_split(
        X_train, y_train,
        test_size=valid_relative,
        stratify=y_train,
        random_state=random_state
    )

    y_train = pd.DataFrame(y_train, columns=['target'])
    y_valid = pd.DataFrame(y_valid, columns=['target'])
    y_test = pd.DataFrame(y_test, columns=['target'])
    
    return pd.DataFrame(X_train), y_train, pd.DataFrame(X_valid), y_valid, pd.DataFrame(X_test), y_test

# Tests

In [4]:
X_train, y_train, X_valid, y_valid, X_test, y_test = generate_synthetic_binary_data(
    n_samples=5000,
    n_features=30,
    n_informative=5,
    n_redundant=10,
    n_repeated=0,
    n_classes=2,
    class_sep=0.5,   # classes mais misturadas
    flip_y=0.15,     # ruído nos rótulos
    weights=[0.8, 0.2],  # desbalanceamento
    test_size=0.2,
    valid_size=0.2,
    random_state=42
)

In [5]:
X_train.head()

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,20,21,22,23,24,25,26,27,28,29
0,0.660424,0.370553,-0.157266,0.763497,0.291066,0.031432,0.269237,1.068864,0.106731,2.140277,...,-2.03332,0.228399,-0.886704,-0.176886,-0.213807,0.250023,-0.130981,2.060301,0.6157,1.099804
1,-1.332327,-1.531204,-0.768285,0.823467,-0.504803,-0.891358,-1.173892,0.61027,-0.127102,-0.738951,...,2.745836,1.374807,-1.812954,-1.929477,1.686287,1.374261,0.599295,0.961674,-0.668937,-0.00986
2,-0.411217,-1.182707,1.48846,0.452974,1.215235,1.83206,0.139225,0.171208,0.413916,0.530517,...,-1.074884,0.741791,-1.394438,-1.329335,0.439295,-1.511498,1.691682,1.035581,1.252118,-1.33989
3,-0.961314,2.79371,3.721175,-1.254595,2.116411,4.164159,3.264257,-0.149023,1.007158,0.991212,...,-0.036316,-1.537561,-1.527298,4.31549,-0.549563,-0.684245,-0.067385,-0.730745,2.883616,1.313439
4,0.374048,-0.036991,-0.108653,-0.128614,-0.03976,0.146764,-0.122993,-1.256825,0.297515,0.820744,...,-0.514466,-0.215456,-1.477026,-0.440279,-0.959748,0.786763,-0.244742,0.429517,1.167471,1.183787


In [6]:
psi(X_train[0], X_train[0])

0.0

In [11]:
ks_test_drift(X_train[0], X_train[2])

{'ks_statistic': 0.08366666666666667,
 'p_value': 1.4844168722388856e-09,
 'drift_detected': True}

In [12]:
ks_test_drift(X_train[0], X_train[0])

{'ks_statistic': 0.0, 'p_value': 1.0, 'drift_detected': False}

In [7]:
psi(X_train[0], X_test[1])

0.38435111747329526

In [8]:
reference = pd.Series(
    np.random.choice(["A", "B", "C"], size=1000, p=[0.5, 0.3, 0.2])
)

# Distribuição atual (ex: produção)
current = pd.Series(
    np.random.choice(["A", "B", "C"], size=400, p=[0.5, 0.3, 0.2])
)

In [9]:
chi_squared_monitoring(reference, current)

{'chi2': 5.20991304170532,
 'p_value': 0.07390635051280263,
 'drift_detected': False}

In [10]:
jensen_shannon_divergence(reference, current, bins=10)

{'js_divergence': 0.0022469957331387805, 'js_distance': 0.04740248657126312}