In [1]:
# from venn_abers import VennAbersCV
import sys
sys.path.append('../../')
from src.modeling.venn_abers import VennAbersCV

from sklearn.datasets import load_breast_cancer
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression
from sklearn.feature_selection import SelectKBest, VarianceThreshold
from sklearn.model_selection import train_test_split

In [2]:
# Load the breast-cancer dataset and hold out 20% of the rows for testing.
X, y = load_breast_cancer(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42
)

# Base learner handed to VennAbersCV: standard scaling followed by a
# logistic regression with max_iter=1000.
estimator = Pipeline(
    [
        ("scaler", StandardScaler()),
        ("model", LogisticRegression(max_iter=1000)),
    ]
)

# Full model: a variance-threshold feature selector (threshold 0.1) feeding
# VennAbersCV, which wraps the base learner with 10 shuffled, seeded splits.
clf = Pipeline(
    [
        ("selector", VarianceThreshold(threshold=0.1)),
        (
            "estimator",
            VennAbersCV(
                estimator=estimator,
                n_splits=10,
                shuffle=True,
                random_state=42,
            ),
        ),
    ]
)

clf.fit(X_train, y_train)

# With p0_p1_output=True the call returns two values: the probability array
# and a second array, kept here as p0_p1 (presumably the per-sample p0/p1
# values from VennAbersCV -- confirm against its implementation).
probs, p0_p1 = clf.predict_proba(X_test, p0_p1_output=True)

In [3]:
# Display the second array returned by predict_proba(..., p0_p1_output=True).
p0_p1

array([[0.71398834, 0.98470314],
       [0.        , 0.08040551],
       [0.        , 0.19021377],
       [0.93077152, 1.        ],
       [0.95395785, 1.        ],
       [0.        , 0.07058549],
       [0.        , 0.06856994],
       [0.        , 0.18798651],
       [0.86095234, 0.9926166 ],
       [0.8320527 , 0.9926166 ],
       [0.61141766, 0.98470314],
       [0.        , 0.16808118],
       [0.80068956, 0.98470314],
       [0.24370276, 0.67932865],
       [0.9298409 , 1.        ],
       [0.        , 0.15458163],
       [0.9083347 , 0.9926166 ],
       [0.95734237, 1.        ],
       [0.95900784, 1.        ],
       [0.        , 0.09678201],
       [0.55363611, 0.98470314],
       [0.91019286, 0.9926166 ],
       [0.        , 0.08096217],
       [0.93645241, 1.        ],
       [0.90596568, 0.9926166 ],
       [0.94644241, 1.        ],
       [0.92453249, 0.9926166 ],
       [0.93733675, 1.        ],
       [0.90596568, 0.9926166 ],
       [0.        , 0.08552756],
       [0.

In [4]:
# Keep column 1 of the probability output (presumably the positive-class
# probability -- confirm against VennAbersCV's column ordering).
y_pred = probs[:, 1]

In [5]:
# Sanity check: column 1 of p0_p1 is at least y_pred for every test sample.
(p0_p1[:, 1] >= y_pred).all()

True

In [6]:
# Sanity check: column 0 of p0_p1 is at most y_pred for every test sample.
(p0_p1[:, 0] <= y_pred).all()

True