In [71]:
import numpy as np
from sklearn.ensemble import BaggingClassifier
from sklearn.model_selection import train_test_split, ParameterGrid
from sklearn.neighbors import KNeighborsClassifier
from sklearn.utils import resample
from multi_imbalance.datasets import load_datasets
from multi_imbalance.resampling.SOUP import SOUP


# Bagging-style ensemble: each member is trained on a stratified bootstrap of the
# training split, rebalanced with SOUP, and fitted as a k-NN classifier. Member
# class-probability estimates are summed (soft voting) and the arg-max class per
# test sample is recorded as a one-hot row of ``decision_matrix``.
datasets = load_datasets()['new_ecoli']
X, y = datasets.data, datasets.target
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.5, random_state=0)

n_classifiers = 30
n_samples = X_test.shape[0]
# Global, sorted label set over both splits; column j of ``results`` and
# ``decision_matrix`` refers to the label ``classes[j]``.
classes = np.unique(np.concatenate((y_train, y_test)))
n_classes = classes.shape[0]

# Seeded generator so the bootstrap draws (and therefore the decisions) are
# reproducible, consistent with the fixed random_state of the split above.
rng = np.random.RandomState(0)

results = np.zeros(shape=(n_classifiers, n_samples, n_classes))
decision_matrix = np.zeros(shape=(n_samples, n_classes))

for i in range(n_classifiers):
    x_sampled, y_sampled = resample(X_train, y_train, stratify=y_train, random_state=rng)
    x_resampled, y_resampled = SOUP().fit_transform(x_sampled, y_sampled)
    clf = KNeighborsClassifier().fit(x_resampled, y_resampled)
    # predict_proba yields one column per label in clf.classes_, which may be a
    # strict subset of ``classes`` (e.g. a label that occurs only in y_test).
    # Scatter the columns into their global positions; absent classes stay 0.
    columns = np.searchsorted(classes, clf.classes_)
    results[i][:, columns] = clf.predict_proba(X_test)

# Soft voting: sum member probabilities and pick the highest-scoring class per
# sample (np.argmax resolves ties in favour of the lowest column index).
weights_sum = np.sum(results, axis=0)
decisions_indices = np.argmax(weights_sum, axis=1)
decision_matrix[np.arange(n_samples), decisions_indices] = 1

decision_matrix

array([[1., 0., 0., 0., 0.],
       [1., 0., 0., 0., 0.],
       [0., 0., 0., 1., 0.],
       [0., 0., 0., 1., 0.],
       [0., 1., 0., 0., 0.],
       [0., 1., 0., 0., 0.],
       [1., 0., 0., 0., 0.],
       [0., 0., 0., 1., 0.],
       [1., 0., 0., 0., 0.],
       [1., 0., 0., 0., 0.],
       [1., 0., 0., 0., 0.],
       [1., 0., 0., 0., 0.],
       [1., 0., 0., 0., 0.],
       [0., 1., 0., 0., 0.],
       [1., 0., 0., 0., 0.],
       [1., 0., 0., 0., 0.],
       [1., 0., 0., 0., 0.],
       [0., 0., 0., 0., 1.],
       [0., 0., 1., 0., 0.],
       [1., 0., 0., 0., 0.],
       [1., 0., 0., 0., 0.],
       [0., 1., 0., 0., 0.],
       [0., 0., 0., 0., 1.],
       [0., 1., 0., 0., 0.],
       [0., 0., 0., 1., 0.],
       [0., 0., 0., 1., 0.],
       [0., 0., 0., 1., 0.],
       [1., 0., 0., 0., 0.],
       [0., 0., 0., 0., 1.],
       [0., 0., 1., 0., 0.],
       [1., 0., 0., 0., 0.],
       [1., 0., 0., 0., 0.],
       [1., 0., 0., 0., 0.],
       [0., 0., 1., 0., 0.],
       [0., 0.