In [1]:
from missforest import MissForest

import numpy as np
import pandas as pd

In [2]:
# Build a synthetic mixed-type dataset and blank out a share of its cells.
# The legacy global seed and the exact order of np.random calls are kept so
# the generated values stay reproducible from run to run.
np.random.seed(0)

# Dataset dimensions: n_features numeric columns plus n_features categorical ones
n_samples = 100
n_features = 5

# Complete (ground-truth) values
X_num = np.random.randn(n_samples, n_features)
X_cat = np.random.choice(['A', 'B', 'C'], size=(n_samples, n_features))

df_num = pd.DataFrame(X_num, columns=[f'num_feature_{i}' for i in range(n_features)])
df_cat = pd.DataFrame(X_cat, columns=[f'cat_feature_{i}' for i in range(n_features)])

# Snapshot of the complete data, taken before any value is removed
df_true = pd.concat([df_num, df_cat], axis=1)

# Draw the (row, column) positions to blank out
missing_rate = 0.1
n_missing_samples = int(np.floor(n_samples * n_features * missing_rate))

missing_samples_row = np.random.randint(0, n_samples, n_missing_samples)
missing_samples_col = np.random.randint(0, n_features, n_missing_samples)

# Boolean mask of the blanked cells. Rows and columns are drawn independently,
# so a repeated (row, column) pair marks the same cell twice and the realised
# missing fraction can end up slightly below missing_rate.
missing_mask = np.zeros((n_samples, n_features), dtype=bool)
missing_mask[missing_samples_row, missing_samples_col] = True

# The same mask is applied to the numeric and the categorical halves.
# DataFrame.mask returns new frames, so X_num, X_cat and df_true keep the
# complete values.
df_num = df_num.mask(missing_mask)
df_cat = df_cat.mask(missing_mask).astype('category')
df_missed = pd.concat([df_num, df_cat], axis=1)

df_missed.head()


Unnamed: 0,num_feature_0,num_feature_1,num_feature_2,num_feature_3,num_feature_4,cat_feature_0,cat_feature_1,cat_feature_2,cat_feature_3,cat_feature_4
0,1.764052,0.400157,0.978738,2.240893,1.867558,C,B,B,A,B
1,-0.977278,0.950088,-0.151357,-0.103219,0.410599,C,C,C,B,A
2,0.144044,1.454274,0.761038,0.121675,0.443863,C,C,B,B,B
3,0.333674,1.494079,-0.205158,0.313068,-0.854096,B,A,B,A,A
4,-2.55299,0.653619,0.864436,-0.742165,,C,C,B,C,


In [7]:
# Impute the incomplete frame with keep_categorical=True.
miss_forest = MissForest(
    n_imputations=5,
    max_iter=10,
    keep_categorical=True,
)
df_imputed = miss_forest.fit_transform(df_missed)

# Preview the first element of the result
# (presumably one frame per imputation; confirm against the MissForest API).
df_imputed[0].head()

Unnamed: 0,num_feature_0,num_feature_1,num_feature_2,num_feature_3,num_feature_4,cat_feature_0,cat_feature_1,cat_feature_2,cat_feature_3,cat_feature_4
0,1.764052,0.400157,0.978738,2.240893,1.867558,C,B,B,A,B
1,-0.977278,0.950088,-0.151357,-0.103219,0.410599,C,C,C,B,A
2,0.144044,1.454274,0.761038,0.121675,0.443863,C,C,B,B,B
3,0.333674,1.494079,-0.205158,0.313068,-0.854096,B,A,B,A,A
4,-2.55299,0.653619,0.864436,-0.742165,-0.363995,C,C,B,C,B


In [8]:
# Impute the incomplete frame again, this time with keep_categorical=False.
miss_forest = MissForest(
    n_imputations=5,
    max_iter=10,
    keep_categorical=False,
)
df_imputed = miss_forest.fit_transform(df_missed)

# Preview the first element of the result
# (presumably one frame per imputation; confirm against the MissForest API).
df_imputed[0].head()

Unnamed: 0,num_feature_0,num_feature_1,num_feature_2,num_feature_3,num_feature_4,cat_feature_0,cat_feature_1,cat_feature_2,cat_feature_3,cat_feature_4
0,1.764052,0.400157,0.978738,2.240893,1.867558,2.0,1.0,1.0,0.0,1.0
1,-0.977278,0.950088,-0.151357,-0.103219,0.410599,2.0,2.0,2.0,1.0,0.0
2,0.144044,1.454274,0.761038,0.121675,0.443863,2.0,2.0,1.0,1.0,1.0
3,0.333674,1.494079,-0.205158,0.313068,-0.854096,1.0,0.0,1.0,0.0,0.0
4,-2.55299,0.653619,0.864436,-0.742165,0.107621,2.0,2.0,1.0,2.0,1.0
