In [1]:
import matplotlib.pyplot as plt
import numpy as np
import sklearn.datasets as skd

In [38]:
import numpy as np
import pandas as pd
import xgboost as xgb

from sklearn.utils import check_random_state
from sklearn.preprocessing import KBinsDiscretizer
from sklearn.utils.deprecation import _deprecate_Xt_in_inverse_transform
from sklearn.utils.validation import (
    _check_feature_names_in,
    _check_sample_weight,
    check_array,
    check_is_fitted,
)
class KBinsDiscretizerSampler(KBinsDiscretizer):
    def inverse_transform_sample(self, X=None, random_state=None):
        rng = check_random_state(random_state)
        check_is_fitted(self)

        if "onehot" in self.encode:
            X = self._encoder.inverse_transform(X)

        Xinv = check_array(X, copy=True, dtype=(np.float64, np.float32))
        n_features = self.n_bins_.shape[0]
        if Xinv.shape[1] != n_features:
            raise ValueError(
                "Incorrect number of features. Expecting {}, received {}.".format(
                    n_features, Xinv.shape[1]
                )
            )
        n = X.shape[0]
        for jj in range(n_features):
            jitter = rng.uniform(0., 1., size=n)
            bin_edges = self.bin_edges_[jj]
            bin_centers = (bin_edges[1:] + bin_edges[:-1]) * 0.5
            bin_lefts = bin_edges[1:][(Xinv[:, jj]).astype(np.int64)]
            bin_rights = bin_edges[:-1][(Xinv[:, jj]).astype(np.int64)]
            Xinv[:, jj] = bin_lefts * jitter + bin_rights * (1 - jitter)

        return Xinv

def top_p_sampling(n_bins, probs, rng, top_p):
    """ This implements a modified version of nucleus sampling.
    It discards probability mass beyond the boundary of the class straddles the top_p boundary,
    but it does not discard the probability mass of this class below the boundary.
    """
    probs = probs.ravel()  # currently assumes only one sample
    sort_indices = np.argsort(probs)[::-1]
    sort_probs = probs[sort_indices]
    cumsum_probs = np.cumsum(sort_probs)
    unnorm_probs = np.diff(np.minimum(cumsum_probs, top_p), prepend=0.)
    unnorm_probs = unnorm_probs[np.argsort(sort_indices)]  # undo the sort
    norm_probs = unnorm_probs / np.sum(unnorm_probs)
    
    chosen = np.array(rng.choice(n_bins, p=norm_probs))
    return chosen

class MaskingTreesModel:
    def __init__(
        self,
        n_bins=5,
        duplicate_K=50,
        top_p=0.9,
        random_state = None,
    ):
        self.n_bins = n_bins
        self.duplicate_K = duplicate_K
        self.top_p = top_p
        self.random_state = random_state

        assert 2 <= n_bins
        assert 1 <= duplicate_K
        assert 0 < top_p <= 1
    
        self.xgbers_ = None
        self.quantize_cols_ = None
        self.quantizers_ = None
        self.X_ = None

    def fit(
        self,
        X,
        quantize_cols='floating',
    ):
        # TODO - handle categorical columns
        # TODO - xgboost iterator - generate batches on the fly
        # TODO - xgboost kwargs
        # TODO - sample_weight from OADM formula
        # TODO - KDITransformer
        rng = check_random_state(self.random_state)
        n_samples, n_dims = X.shape
        if isinstance(quantize_cols, list):
            assert len(quantize_cols) == n_dims
            self.quantize_cols_ = quantize_cols
        elif quantize_cols == 'floating':
            assert isinstance(X, pd.DataFrame)
            q_cols = list(X.select_dtypes(include='floating').columns)
            self.quantize_cols_ = [col in q_cols for col in list(X.columns)]
        elif quantize_cols == 'number':
            assert isinstance(X, pd.DataFrame)
            q_cols = list(X.select_dtypes(include='number').columns)
            self.quantize_cols_ = [col in q_cols for col in list(X.columns)]    
        elif quantize_cols == 'none':
            self.quantize_cols_ = [False] * n_dims
        elif quantize_cols == 'all':
            self.quantize_cols_ = [True] * n_dims
        else:
            raise ValueError(f'unexpected quantize_cols: {quantize_cols}')

        if isinstance(X, pd.DataFrame):
            X = X.values
            self.X_ = X.copy()
        elif isinstance(X, np.ndarray):
            self.X_ = X.copy()
        else:
            raise ValueError(f'X must be np.ndarray or pd.DataFrame: {type(X)}')
        
        self.quantizers_ = []       
        for d in range(n_dims):
            if self.quantize_cols_[d]:
                curq = KBinsDiscretizerSampler(
                    n_bins=self.n_bins, encode='ordinal', strategy='quantile')
                print(X[:, d])
                curq.fit(X[~np.isnan(X[:, d]), d:d+1])
            else:
                curq = None
            self.quantizers_.append(curq)

        X_train = []
        Y_train = []
        for dupix in range(self.duplicate_K):
            mask_ixs = np.repeat(np.arange(n_dims)[np.newaxis, :], n_samples, axis=0)
            mask_ixs = np.apply_along_axis(rng.permutation, axis=1, arr=mask_ixs) # n_samples, n_dims
            for n in range(n_samples):
                fuller_X = X[n, :]
                for d in range(n_dims):
                    victim_ix = mask_ixs[n, d]
                    if fuller_X[victim_ix] != np.nan:
                        emptier_X = fuller_X.copy()
                        emptier_X[mask_ixs[n, d]] = np.nan
                        X_train.append(emptier_X.reshape(1, -1))
                        Y_train.append(fuller_X.reshape(1, -1))
                        fuller_X = emptier_X
        X_train = np.concatenate(X_train, axis=0)
        Y_train = np.concatenate(Y_train, axis=0)
        self.trees_ = []
        for d in range(n_dims):
            xgber = xgb.XGBClassifier(tree_method="hist") # TODO: early_stopping_rounds=2)
            train_ixs = ~np.isnan(Y_train[:, d])
            if self.quantize_cols_[d]:
                curY_train = self.quantizers_[d].transform(Y_train[train_ixs, d:d+1])
            else:
                curY_train = Y_train[train_ixs, d:d+1]
            curX_train = X_train[train_ixs, :] 
            xgber.fit(curX_train, curY_train)  # TODO: sample_weight
            self.trees_.append(xgber)
        return self

    def generate(
        self,
        n_samples=1,
    ):
        n_samples, n_dims = self.X_.shape
        rng = check_random_state(self.random_state)

        X = np.full(fill_value=np.nan, shape=(n_samples, n_dims))
        unmask_ixs = np.repeat(np.arange(n_dims)[np.newaxis, :], n_samples, axis=0)  # (n_samples, n_dims)
        unmask_ixs = np.apply_along_axis(rng.permutation, axis=1, arr=unmask_ixs) # (n_samples, n_dims)
        for n in range(n_samples):
            for dix in range(n_dims):
                unmask_ix = unmask_ixs[n, dix]
                pred_probas = self.trees_[unmask_ix].predict_proba(X[[n], :])
                pred_quant = top_p_sampling(self.n_bins, pred_probas, rng, self.top_p)

                if self.quantize_cols_[unmask_ix]:
                    pred_val = self.quantizers_[unmask_ix].inverse_transform_sample(pred_quant.reshape(1, 1))
                else:
                    pred_val = pred_quant
                X[n, unmask_ix] = pred_val.item()
        return X

    def impute(
        self,
        k=1,
    ):
        (n_samples, n_dims) = self.X_.shape
        rng = check_random_state(self.random_state)

        imputedX = np.repeat(self.X_[np.newaxis, :, :], repeats=k, axis=0) # (k, n_samples, n_dims)           
        for n in range(n_samples):
            to_unmask = np.where(np.isnan(self.X_[n, :]))[0] # (n_to_unmask,)
            unmask_ixs = np.repeat(to_unmask[np.newaxis, :], k, axis=0)  # (k, n_to_unmask)
            unmask_ixs = np.apply_along_axis(rng.permutation, axis=1, arr=unmask_ixs) # (k, n_to_unmask)
            n_to_unmask = unmask_ixs.shape[1]
            for kix in range(k):
                for dix in range(n_to_unmask):
                    unmask_ix = unmask_ixs[kix, dix]
                    pred_probas = self.trees_[unmask_ix].predict_proba(imputedX[kix,[n], :])
                    pred_quant = top_p_sampling(self.n_bins, pred_probas, rng, self.top_p)
                    pred_val = self.quantizers_[unmask_ix].inverse_transform_sample(pred_quant.reshape(1, 1))
                    imputedX[kix, n, unmask_ix] = pred_val.item()
        return imputedX

rix = 0
rng = check_random_state(rix)
n_upper = 100
n_lower = 100
n = n_upper + n_lower
data, labels = skd.make_moons(
    (n_upper, n_lower), shuffle=False, noise=0.1, random_state=rix)
data4impute = data.copy()
data4impute[:, 1] = np.nan
X=np.concatenate([data, data4impute], axis=0)

model = MaskingTreesModel(n_bins=20)
model.fit(X)
data_fake = model.generate(n_samples=200);



nimp = 1 # number of imputations needed
data_impute = model.impute(k=nimp)[0, :, :]

fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(7, 5));
axes[0, 0].scatter(data[:, 0], data[:, 1]);
axes[0, 0].set_title('original');
axes[0, 1].scatter(data_fake[:, 0], data_fake[:, 1]);
axes[0, 1].set_title('generated');
axes[1, 0].scatter(data_impute[200:, 0], data_impute[200:, 1]);
axes[1, 0].set_title('imputed');
"""
axes[1, 1].scatter(data_impute[200:, 0], data_impute[200:, 1]);
axes[1, 1].set_title('imputed - repainted');
"""
plt.tight_layout();



AssertionError: 

In [35]:
# X["cat_feature"].astype("category")
# clf = xgb.XGBClassifier(tree_method="hist", enable_categorical=True, device="cuda")

In [42]:
from sklearn.datasets import fetch_openml
X, y = fetch_openml("titanic", version=1, as_frame=True, return_X_y=True)
X = X[list(X.select_dtypes(include=['number']).columns)]# + ['sex']]

In [43]:
model = MaskingTreesModel(n_bins=20)
model.fit(X)

[29.      0.9167  2.     ... 26.5    27.     29.    ]
[211.3375 151.55   151.55   ...   7.225    7.225    7.875 ]
[ nan  nan  nan ... 304.  nan  nan]


ValueError: Invalid classes inferred from unique values of `y`.  Expected: [0 1 2], got [1. 2. 3.]

In [44]:
X

Unnamed: 0,pclass,age,sibsp,parch,fare,body
0,1,29.0000,0,0,211.3375,
1,1,0.9167,1,2,151.5500,
2,1,2.0000,1,2,151.5500,
3,1,30.0000,1,2,151.5500,135.0
4,1,25.0000,1,2,151.5500,
...,...,...,...,...,...,...
1304,3,14.5000,1,0,14.4542,328.0
1305,3,,1,0,14.4542,
1306,3,26.5000,0,0,7.2250,304.0
1307,3,27.0000,0,0,7.2250,


In [45]:
X.values

array([[  1.    ,  29.    ,   0.    ,   0.    , 211.3375,      nan],
       [  1.    ,   0.9167,   1.    ,   2.    , 151.55  ,      nan],
       [  1.    ,   2.    ,   1.    ,   2.    , 151.55  ,      nan],
       ...,
       [  3.    ,  26.5   ,   0.    ,   0.    ,   7.225 , 304.    ],
       [  3.    ,  27.    ,   0.    ,   0.    ,   7.225 ,      nan],
       [  3.    ,  29.    ,   0.    ,   0.    ,   7.875 ,      nan]])