In [None]:
import matplotlib as mpl
mpl.rcParams['font.family'] = 'Arial'
mpl.rcParams['text.usetex'] = False
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import sklearn.datasets as skd
import pandas as pd

from sklearn.utils import check_random_state
from ForestDiffusion import ForestDiffusionModel
from miceforest import ImputationKernel
from missforest import MissForest
from utrees import UnmaskingTrees

In [None]:
# --- Experiment configuration -------------------------------------------------
# rix is passed as the seed / random_state to the data generator and to every
# model constructed in this cell.
rix = 0
n = 200  # sample count handed to make_moons
nimp = 1 # number of multiple imputations needed
ngen = 200  # number of synthetic samples requested from each generative model

# NOTE(review): rng is not referenced again in the visible cells — confirm
# whether it is still needed.
rng = check_random_state(rix)
# NOTE(review): labels is not referenced again in the visible cells.
data, labels = skd.make_moons(n, shuffle=False, noise=0.1, random_state=rix)

# Build the training matrix: the fully observed points stacked on top of a copy
# whose second column is entirely NaN. The NaN rows are the ones to be imputed.
data4impute = data.copy()
data4impute[:, 1] = np.nan
X=np.concatenate([data, data4impute], axis=0)
# Boolean row mask: True for rows of X containing at least one NaN.
impute_samples = np.isnan(X).any(axis=1)

# Every model below is given X.copy() rather than X itself, so the shared
# training matrix is never handed to a fitter directly.

# --- MissForest (imputation only) ---------------------------------------------
missfer = MissForest(random_state=rix)
impute_missf = missfer.fit_transform(X.copy(), cat_vars=None)

# --- MICE-Forest (imputation only) --------------------------------------------
micer = ImputationKernel(X.copy(), random_state=rix)
micer.mice(5)  # presumably the number of MICE iterations — confirm against miceforest docs
impute_mice = micer.complete_data()

# --- UnmaskingTrees (generation + imputation) ---------------------------------
utreer = UnmaskingTrees(random_state=rix)
utreer.fit(X.copy())
gen_utrees = utreer.generate(n_generate=ngen);
# [0, :, :] keeps only the first entry along the leading axis of the result
# (presumably the imputation index — confirm against utrees docs).
impute_utrees = utreer.impute(n_impute=nimp)[0, :, :]

# --- UnmaskingTrees with tabpfn=True (generation + imputation) ----------------
utaber = UnmaskingTrees(tabpfn=True, random_state=rix)
utaber.fit(X.copy())
gen_utab = utaber.generate(n_generate=ngen);
impute_utab = utaber.impute(n_impute=nimp)[0, :, :]

# --- ForestDiffusion, diffusion_type='vp' (generation + imputation) -----------
forestvper = ForestDiffusionModel(
    X=X.copy(),
    n_t=50, duplicate_K=100, diffusion_type='vp',
    bin_indexes=[], cat_indexes=[], int_indexes=[], n_jobs=-1, seed=rix)
gen_forestvp = forestvper.generate(batch_size=ngen)
# NOTE(review): the imputation plotting cell indexes both results below as 2-D
# arrays; presumably that only holds while nimp == 1 — confirm before raising nimp.
impute_forestvp_fast = forestvper.impute(k=nimp) # regular (fast)
impute_forestvp_repaint = forestvper.impute(repaint=True, r=10, j=5, k=nimp) # REPAINT (slow, but better)

# --- ForestDiffusion, diffusion_type='flow' (generation only in this cell) ----
forestflower = ForestDiffusionModel(
    X=X.copy(),
    n_t=50, duplicate_K=100, diffusion_type='flow',
    bin_indexes=[], cat_indexes=[], int_indexes=[], n_jobs=-1, seed=rix)
gen_forestflow = forestflower.generate(batch_size=ngen)

In [None]:
# Overview figure (displayed only, never saved): reference data in the left
# column, samples generated by each model in the remaining two columns.
fig, axes = plt.subplots(nrows=2, ncols=3, figsize=(7, 5), squeeze=False, dpi=200, sharex=True, sharey=True)
markersize = 5
alpha = 0.8
color = 'blue'

# Top-left: the fully observed moons. This panel is titled and drawn exactly
# once so that each left-hand panel carries its own distinct title.
axes[0, 0].set_title('Original data')
axes[0, 0].scatter(data[:, 0], data[:, 1], s=markersize, alpha=alpha, color='black', marker='x', linewidth=0.5)

# Bottom-left: the training matrix. Rows whose second column is NaN are not
# drawn as points by scatter, so only the complete rows show up here.
axes[1, 0].set_title('Training data')
axes[1, 0].scatter(X[:, 0], X[:, 1], s=markersize, alpha=alpha, color='green', marker='x', linewidth=0.5)

# Capture the limits before adding the rug marks, then restore them afterwards
# so the marks sit exactly on the lower / left edge of the panel.
xlim = axes[0, 0].get_xlim()
ylim = axes[0, 0].get_ylim()
onlyfirst = data[:, 0]
onlysecond = data[:, 1]
axes[1, 0].scatter(onlyfirst, ylim[0]*np.ones_like(onlyfirst), marker='|', s=100, linewidth=0.5, color='green')
axes[1, 0].scatter(xlim[0]*np.ones_like(onlysecond), onlysecond, marker='_', s=100, linewidth=0.5, color='green')
axes[1, 0].set_xlim(xlim)
axes[1, 0].set_ylim(ylim)

# One panel of generated samples per model.
generated_panels = [
    (axes[0, 1], 'Forest-VP (generate)', gen_forestvp),
    (axes[1, 1], 'Forest-Flow (generate)', gen_forestflow),
    (axes[0, 2], 'UnmaskingTrees (generate)', gen_utrees),
    (axes[1, 2], 'UnmaskingTabPFN (generate)', gen_utab),
]
for panel_ax, panel_title, panel_samples in generated_panels:
    panel_ax.set_title(panel_title)
    panel_ax.scatter(panel_samples[:, 0], panel_samples[:, 1], s=markersize, alpha=alpha, color=color)

fig.tight_layout()


# Saved figure: each model's generated samples (red dots) overlaid on the
# training data (green crosses), written to 'moons-generation.png'.
fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(4, 4), squeeze=False, dpi=200, sharex=True, sharey=True)
markersize = 5
alpha = 0.7
color = 'red'
datacolor = 'green'

overlay_panels = [
    (axes[0, 0], 'Forest-VP', gen_forestvp),
    (axes[1, 0], 'Forest-Flow', gen_forestflow),
    (axes[0, 1], 'UnmaskingTrees', gen_utrees),
    (axes[1, 1], 'UnmaskingTabPFN', gen_utab),
]
for panel_ax, panel_title, panel_samples in overlay_panels:
    panel_ax.set_title(panel_title)
    # zorder=5 keeps the training crosses drawn on top of the generated dots.
    panel_ax.scatter(
        X[:, 0], X[:, 1],
        s=markersize, alpha=alpha, color=datacolor, marker='x', linewidth=0.5, zorder=5, label='data')
    panel_ax.scatter(
        panel_samples[:, 0], panel_samples[:, 1],
        s=markersize, alpha=alpha, color=color, label='generated')

# A single legend on the bottom-right panel, addressed explicitly rather than
# through pyplot's "current axes".
axes[1, 1].legend(handlelength=0.4, borderpad=0.4, labelspacing=0.3, framealpha=0.9, handletextpad=0.3)
for curax in axes.flatten():
    curax.set_yticks([0, 1])
fig.tight_layout()
fig.savefig('moons-generation.png');

In [None]:
# Saved figure: for each imputation method, the imputed rows (blue dots)
# overlaid on the fully observed data (green crosses), written to
# 'moons-imputation.png'.
fig, axes = plt.subplots(nrows=2, ncols=3, figsize=(6, 4.3), squeeze=False, dpi=200, sharex=True, sharey=True)
markersize = 5
alpha = 0.8
color = 'blue'
datacolor = 'green'

imputation_panels = [
    (axes[0, 0], 'MissForest', impute_missf),
    (axes[1, 0], 'MICE-Forest', impute_mice),
    (axes[0, 1], 'Forest-VP', impute_forestvp_fast),
    (axes[1, 1], 'Forest-VP w/ RePaint', impute_forestvp_repaint),
    (axes[0, 2], 'UnmaskingTrees', impute_utrees),
    (axes[1, 2], 'UnmaskingTabPFN', impute_utab),
]
for panel_ax, panel_title, panel_imputed in imputation_panels:
    panel_ax.set_title(panel_title)
    # zorder=5 keeps the observed crosses drawn on top of the imputed dots.
    panel_ax.scatter(
        data[:, 0], data[:, 1],
        s=markersize, alpha=alpha, color=datacolor, marker='x', linewidth=0.5, zorder=5, label='data')
    # Only the rows flagged by impute_samples (those that contained NaN in X)
    # are shown from each method's completed matrix.
    panel_ax.scatter(
        panel_imputed[impute_samples, 0], panel_imputed[impute_samples, 1],
        s=markersize, alpha=alpha, color=color, label='imputed')

for curax in axes.flatten():
    curax.set_yticks([0, 1])
# A single legend on the bottom-right panel, addressed explicitly rather than
# through pyplot's "current axes".
axes[1, 2].legend(handlelength=0.4, borderpad=0.4, labelspacing=0.3, framealpha=0.9, handletextpad=0.3)
fig.tight_layout()
fig.savefig('moons-imputation.png');