# Real data

In this notebook we will test active learning on real world imbalanced data.

In [1]:
## Imports ##

# numpy
import numpy as np

# matplotlib
import matplotlib as mlp
import matplotlib.pyplot as plt

# sklearn
from sklearn.linear_model import LogisticRegression
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split

# skactiveml
from skactiveml.classifier import SklearnClassifier
from skactiveml.pool import UncertaintySampling
from skactiveml.pool import RandomSampling
from skactiveml.utils import MISSING_LABEL

# plot function
from plot_accuracy import plot_accuracy

# balanced accuracy
from balanced_accuracy import balanced_accuracy

# set warnings
import warnings
mlp.rcParams["figure.facecolor"] = "white"
warnings.filterwarnings("ignore")

## [Breast cancer wisconsin](https://archive.ics.uci.edu/dataset/17/breast+cancer+wisconsin+diagnostic)

Wolberg,William, Mangasarian,Olvi, Street,Nick, and Street,W.. (1995). Breast Cancer Wisconsin (Diagnostic). UCI Machine Learning Repository. https://doi.org/10.24432/C5DW2B.

The breast cancer data set is publicly available, we can just import it.

In [2]:
# import the data set
from ucimlrepo import fetch_ucirepo 
breast_cancer_wisconsin_diagnostic = fetch_ucirepo(id=17) 

In [3]:
X = breast_cancer_wisconsin_diagnostic.data.features
y = list(breast_cancer_wisconsin_diagnostic.data.targets['Diagnosis'])
y = [0 if x == 'B' else 1 for x in y]

# label 'B' means good
# label 'M' means bad
Xf, Xt, yf, yt = train_test_split(X, y, random_state=1)

In [13]:
# create the qs and clf
qs = UncertaintySampling(random_state=1)
clf = SklearnClassifier(LogisticRegression(), classes=np.unique(yf))
y = np.full(shape=len(yf), fill_value=MISSING_LABEL)

# perform active learning
out = []
clf.fit(Xf,y)
for _ in range(50):
    i = qs.query(Xf, y, clf)[0]
    y[i] = yf[i]
    clf.fit(Xf,y)
    counts = balanced_accuracy(yt,clf.predict(Xt), as_counters=True)
    print(counts)
    print(counts[0])
    print((counts[1][0]/counts[0][0],counts[1][1]/counts[0][1]))
out

(Counter({0: 88, 1: 55}), Counter({0: 88}))
Counter({0: 88, 1: 55})
(1.0, 0.0)
(Counter({0: 88, 1: 55}), Counter({0: 71, 1: 48}))
Counter({0: 88, 1: 55})
(0.8068181818181818, 0.8727272727272727)
(Counter({0: 88, 1: 55}), Counter({0: 63, 1: 51}))
Counter({0: 88, 1: 55})
(0.7159090909090909, 0.9272727272727272)
(Counter({0: 88, 1: 55}), Counter({0: 68, 1: 50}))
Counter({0: 88, 1: 55})
(0.7727272727272727, 0.9090909090909091)
(Counter({0: 88, 1: 55}), Counter({0: 70, 1: 49}))
Counter({0: 88, 1: 55})
(0.7954545454545454, 0.8909090909090909)
(Counter({0: 88, 1: 55}), Counter({0: 69, 1: 53}))
Counter({0: 88, 1: 55})
(0.7840909090909091, 0.9636363636363636)
(Counter({0: 88, 1: 55}), Counter({0: 71, 1: 53}))
Counter({0: 88, 1: 55})
(0.8068181818181818, 0.9636363636363636)
(Counter({0: 88, 1: 55}), Counter({0: 70, 1: 50}))
Counter({0: 88, 1: 55})
(0.7954545454545454, 0.9090909090909091)
(Counter({0: 88, 1: 55}), Counter({0: 75, 1: 52}))
Counter({0: 88, 1: 55})
(0.8522727272727273, 0.94545454545

[]