# Imputation Plugins

### Setup

In [1]:
import sys
import warnings
import time
from tqdm import tqdm
from math import sqrt

import numpy as np
import pandas as pd

from sklearn.datasets import load_breast_cancer
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error


from sklearn.experimental import enable_iterative_imputer
from sklearn.impute import IterativeImputer

from hyperimpute.plugins.utils.metrics import RMSE
from hyperimpute.plugins.utils.simulate import simulate_nan



import xgboost as xgb

from IPython.display import HTML, display
import tabulate

if not sys.warnoptions:
    warnings.simplefilter("ignore")

### Loading the Imputation plugins


In [2]:
from hyperimpute.plugins.imputers import Imputers, ImputerPlugin

imputers = Imputers()

### List the existing plugins

In [3]:
imputers.list()

['median',
 'missforest',
 'nop',
 'mean',
 'most_frequent',
 'hyperimpute',
 'sinkhorn',
 'miwae',
 'gain',
 'EM',
 'softimpute',
 'mice',
 'ice',
 'sklearn_ice',
 'sklearn_missforest',
 'miracle']

### Adding a new Imputation plugin

By default, HyperImpute automatically loads the imputation plugins with the pattern `hyperimpute/plugins/imputers/plugin_*`. 

Alternatively, you can call `Imputers().add(<name>, <ImputerPlugin derived class>)` at runtime.

Next, we show two examples of custom Imputation plugins.

In [3]:
custom_ice_plugin = "custom_ice"


class NewPlugin(ImputerPlugin):
    def __init__(self):
        super().__init__()
        lr = LinearRegression()
        self._model = IterativeImputer(
            estimator=lr, max_iter=500, tol=1e-10, imputation_order="roman"
        )

    @staticmethod
    def name():
        return custom_ice_plugin

    @staticmethod
    def hyperparameter_space():
        return []

    def _fit(self, *args, **kwargs) -> "NewPlugin":
        self._model.fit(*args, **kwargs)
        return self

    def _transform(self, *args, **kwargs):
        return self._model.transform(*args, **kwargs)

    def save(self) -> bytes:
        raise NotImplemented("placeholder")

    @classmethod
    def load(cls, buff: bytes) -> "NewPlugin":
        raise NotImplemented("placeholder")


imputers.add(custom_ice_plugin, NewPlugin)

assert imputers.get(custom_ice_plugin) is not None

### List the existing plugins

Now we should see the new plugins loaded.

In [4]:
imputers.list()

['mice',
 'median',
 'EM',
 'missforest',
 'nop',
 'custom_ice',
 'sinkhorn',
 'sklearn_ice',
 'miwae',
 'most_frequent',
 'mean',
 'miracle',
 'gain',
 'sklearn_missforest',
 'softimpute',
 'ice',
 'hyperimpute']

### Testing the performance

We simulate some testing datasets using 3 amputation strategies:
- **Missing Completely At Random** (MCAR) if the probability of being missing is the same for all observations
- **Missing At Random** (MAR) if the probability of being missing only depends on observed values.
- **Missing Not At Random** (MNAR) if the unavailability of the data depends on both observed and unobserved data such as its value itself.

#### Load the dataset

In [5]:
from sklearn.preprocessing import MinMaxScaler
import pandas as pd

preproc = MinMaxScaler()


def dataset():
    X, y = load_breast_cancer(return_X_y=True)
    X = pd.DataFrame(preproc.fit_transform(X, y))
    y = pd.Series(y)

    return train_test_split(X, y, test_size=0.2)


def ampute(x, mechanism, p_miss):
    x_simulated = simulate_nan(np.asarray(x), p_miss, mechanism)

    mask = x_simulated["mask"]
    x_miss = x_simulated["X_incomp"]

    return pd.DataFrame(x), pd.DataFrame(x_miss), pd.DataFrame(mask)

In [6]:
datasets = {}
headers = ["Plugin"]

pct = 0.3

mechanisms = ["MAR", "MNAR", "MCAR"]
percentages = [pct]

plugins = ["mean"]  # imputers.list()  # default plugins

X_train, X_test, y_train, y_test = dataset()

for ampute_mechanism in mechanisms:
    for p_miss in percentages:
        if ampute_mechanism not in datasets:
            datasets[ampute_mechanism] = {}

        headers.append(ampute_mechanism + "-" + str(p_miss))
        datasets[ampute_mechanism][p_miss] = ampute(X_train, ampute_mechanism, p_miss)

In [9]:
import pprint

# 在这里添加打印datasets的代码
for key, value in datasets.items():
    display(f"--- {key} ---") # 使用分隔符和标题
    for sub_key, sub_value in value.items():
        display(f"{sub_key}:")
        display(sub_value)
    display("------") # 使用分隔符

'--- MAR ---'

'0.3:'

(           0         1         2         3         4         5         6   \
 338  0.145251  0.264457  0.142492  0.070965  0.433962  0.165266  0.058833   
 427  0.180747  0.414948  0.172759  0.091792  0.319401  0.116711  0.084677   
 406  0.433480  0.174163  0.418147  0.278473  0.382053  0.201307  0.128866   
 96   0.246060  0.274941  0.234953  0.130477  0.468268  0.157015  0.058341   
 490  0.249373  0.430504  0.237648  0.137010  0.264422  0.100055  0.040159   
 ..        ...       ...       ...       ...       ...       ...       ...   
 277  0.559847  0.347311  0.532859  0.406575  0.330414  0.121036  0.187910   
 9    0.259312  0.484613  0.277659  0.140997  0.595558  0.675480  0.532568   
 359  0.116191  0.291173  0.110773  0.057306  0.435768  0.123244  0.063496   
 192  0.129632  0.287792  0.117062  0.061336  0.152298  0.012453  0.000000   
 559  0.214350  0.480893  0.212356  0.110286  0.360928  0.253727  0.260544   
 
            7         8         9   ...        20        21   

'------'

'--- MNAR ---'

'0.3:'

(           0         1         2         3         4         5         6   \
 338  0.145251  0.264457  0.142492  0.070965  0.433962  0.165266  0.058833   
 427  0.180747  0.414948  0.172759  0.091792  0.319401  0.116711  0.084677   
 406  0.433480  0.174163  0.418147  0.278473  0.382053  0.201307  0.128866   
 96   0.246060  0.274941  0.234953  0.130477  0.468268  0.157015  0.058341   
 490  0.249373  0.430504  0.237648  0.137010  0.264422  0.100055  0.040159   
 ..        ...       ...       ...       ...       ...       ...       ...   
 277  0.559847  0.347311  0.532859  0.406575  0.330414  0.121036  0.187910   
 9    0.259312  0.484613  0.277659  0.140997  0.595558  0.675480  0.532568   
 359  0.116191  0.291173  0.110773  0.057306  0.435768  0.123244  0.063496   
 192  0.129632  0.287792  0.117062  0.061336  0.152298  0.012453  0.000000   
 559  0.214350  0.480893  0.212356  0.110286  0.360928  0.253727  0.260544   
 
            7         8         9   ...        20        21   

'------'

'--- MCAR ---'

'0.3:'

(           0         1         2         3         4         5         6   \
 338  0.145251  0.264457  0.142492  0.070965  0.433962  0.165266  0.058833   
 427  0.180747  0.414948  0.172759  0.091792  0.319401  0.116711  0.084677   
 406  0.433480  0.174163  0.418147  0.278473  0.382053  0.201307  0.128866   
 96   0.246060  0.274941  0.234953  0.130477  0.468268  0.157015  0.058341   
 490  0.249373  0.430504  0.237648  0.137010  0.264422  0.100055  0.040159   
 ..        ...       ...       ...       ...       ...       ...       ...   
 277  0.559847  0.347311  0.532859  0.406575  0.330414  0.121036  0.187910   
 9    0.259312  0.484613  0.277659  0.140997  0.595558  0.675480  0.532568   
 359  0.116191  0.291173  0.110773  0.057306  0.435768  0.123244  0.063496   
 192  0.129632  0.287792  0.117062  0.061336  0.152298  0.012453  0.000000   
 559  0.214350  0.480893  0.212356  0.110286  0.360928  0.253727  0.260544   
 
            7         8         9   ...        20        21   

'------'

#### Evaluation

We compare the methods in terms of root mean squared error (RMSE) to the initial dataset.

In [10]:
results = []
duration = []

for plugin in tqdm(plugins):
    plugin_results = [plugin]
    plugin_duration = [plugin]

    for ampute_mechanism in mechanisms:
        for p_miss in percentages:
            ctx = imputers.get(plugin)
            x, x_miss, mask = datasets[ampute_mechanism][p_miss]

            start = time.time() * 1000
            x_imp = ctx.fit_transform(x_miss)

            plugin_duration.append(round(time.time() * 1000 - start, 4))
            plugin_results.append(RMSE(x_imp.values, x.values, mask.values))

    results.append(plugin_results)
    duration.append(plugin_duration)

100%|██████████| 1/1 [00:00<00:00, 20.83it/s]


### Reconstruction error(RMSE)

__Interpretation__ : The following table shows the reconstruction error -  the __Root Mean Square Error(RMSE)__ for each method applied on the original full dataset and the imputed dataset.

In [11]:
display(HTML(tabulate.tabulate(results, headers=headers, tablefmt="html")))

Plugin,MAR-0.3,MNAR-0.3,MCAR-0.3
mean,0.195269,0.170658,0.146133


### XGBoost test score after imputation

__Interpretation__ The following table shows different metrics on the test set for an XGBoost classifier, after imputing the dataset with each method.
Metrics:
 - accuracy

In [12]:
from sklearn import metrics


def get_metrics(X_train, y_train, X_test, y_test):
    xgb_clf = xgb.XGBClassifier(verbosity=0)
    xgb_clf = xgb_clf.fit(X_train, y_train)

    y_pred = xgb_clf.predict(X_test)

    score = xgb_clf.score(X_test, y_test)

    fpr, tpr, thresholds = metrics.roc_curve(y_test, y_pred)
    auroc = metrics.auc(fpr, tpr)

    prec, recall, thresholds = metrics.precision_recall_curve(y_test, y_pred)
    aurpc = metrics.auc(recall, prec)

    return score, auroc, aurpc


metrics_headers = ["Plugin", "Accuracy", "AUROC", "AURPC"]
xgboost_test_score = []


x, x_miss, mask = datasets["MAR"][pct]

xgboost_test_score.append(
    ["original dataset", *get_metrics(X_train, y_train, X_test, y_test)]
)

for plugin in plugins:
    X_train_imp = imputers.get(plugin).fit_transform(x_miss.copy())

    score, auroc, aurpc = get_metrics(X_train_imp, y_train, X_test, y_test)

    xgboost_test_score.append([plugin, score, auroc, aurpc])

In [13]:
display(
    HTML(
        tabulate.tabulate(xgboost_test_score, headers=metrics_headers, tablefmt="html")
    )
)

Plugin,Accuracy,AUROC,AURPC
original dataset,0.95614,0.956335,0.975618
mean,0.964912,0.963798,0.978921


### Duration(ms) results

__Info__ : Here we measure the duration of imputing the dataset with each method.

In [14]:
display(HTML(tabulate.tabulate(duration, headers=headers, tablefmt="html")))

Plugin,MAR-0.3,MNAR-0.3,MCAR-0.3
mean,3.999,3.999,3.0015


## Debugging

HyperImpute supports **debug** logging. __WARNING__: Don't use it for release builds. 

In [15]:
from hyperimpute import logger

imputers = Imputers()

logger.add(sink=sys.stderr, level="DEBUG")

x, x_miss, mask = datasets["MAR"][pct]

x_imp = imputers.get("EM").fit_transform(x)

imputers.get("softimpute").fit_transform(x_miss)

[2024-03-18T11:54:21.359450+0800][11772][DEBUG] Loaded plugin imputer - EM


[2024-03-18T11:54:21.371450+0800][11772][DEBUG] EMPlugin._fit took 0.0 seconds
[2024-03-18T11:54:21.391451+0800][11772][DEBUG] EM converged after 1 iterations.
[2024-03-18T11:54:21.392450+0800][11772][DEBUG] EMPlugin._transform took 0.02 seconds
[2024-03-18T11:54:21.444626+0800][11772][DEBUG] Loaded plugin imputer - softimpute
[2024-03-18T11:54:31.387479+0800][11772][DEBUG] SoftImputePlugin._fit took 9.929 seconds
[2024-03-18T11:54:34.533415+0800][11772][DEBUG] SoftImputePlugin._transform took 3.1449 seconds


Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,20,21,22,23,24,25,26,27,28,29
0,0.145251,0.264457,0.142492,0.070965,0.433962,0.165266,0.058833,0.088221,0.419192,0.281171,...,0.114906,0.394989,0.107426,0.048860,0.455854,0.170970,0.084265,0.223333,0.261975,0.141677
1,0.180747,0.414948,0.172759,0.091792,0.319401,0.116711,0.084677,0.069781,0.482828,0.206613,...,0.163807,0.533582,0.165745,0.074789,0.390477,0.138070,0.153914,0.267668,0.275971,0.141545
2,0.433480,0.174163,0.418147,0.278473,0.382053,0.201307,0.219663,0.225050,0.340909,0.185131,...,0.347919,0.301980,0.326162,0.187451,0.321793,0.140592,0.184505,0.387973,0.239109,0.098911
3,0.246060,0.274941,0.234953,0.130477,0.468268,0.157015,0.137689,0.146173,0.424242,0.345198,...,0.177191,0.237207,0.158026,0.084460,0.282837,0.172993,0.039776,0.269679,0.130495,0.122786
4,0.249373,0.430504,0.237648,0.137010,0.264422,0.100055,0.040159,0.062674,0.244444,0.206403,...,0.221985,0.532249,0.210817,0.107575,0.359440,0.165969,0.098243,0.217698,0.302582,0.177030
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
450,0.559847,0.347311,0.532859,0.406575,0.330414,0.121036,0.273228,0.290408,0.247475,0.106015,...,0.427962,0.323242,0.391404,0.258258,0.350855,0.224745,0.176518,0.444674,0.197516,0.015283
451,0.259312,0.484613,0.277659,0.140997,0.595558,0.432627,0.278825,0.424602,0.489899,0.618937,...,0.348724,0.656335,0.235271,0.129326,0.817058,0.364655,0.882588,0.759450,0.552139,1.000000
452,0.116191,0.291173,0.110773,0.057306,0.435768,0.123244,0.063496,0.069881,0.225253,0.413437,...,0.145500,0.346482,0.126401,0.062525,0.410289,0.075298,0.091374,0.173608,0.175241,0.172635
453,0.129632,0.287792,0.117062,0.061336,0.152298,0.012453,0.000000,0.000000,0.299495,0.305602,...,0.018884,0.234808,0.058967,0.029149,0.000000,0.000000,0.000000,0.000000,0.067810,0.069198
