In [None]:
import numpy as np
from sklearn.datasets import fetch_openml
from sklearn.model_selection import train_test_split
from xgboost import XGBRegressor, XGBClassifier
from sklearn.preprocessing import StandardScaler
import matplotlib.pyplot as plt

# Download MNIST from OpenML. Labels are cast to uint8 so they can be compared
# against the integer digit 5 below.
mnist = fetch_openml('mnist_784', version=1)
X, y = mnist.data, mnist.target.astype(np.uint8)

# Keep only the rows whose label is the digit 5.
X_5 = X[y == 5]

# Keep just 5% of the fives for training (the other 95% goes to X_test_5) so the
# per-pixel experiment stays cheap; a real run would train on far more data.
X_train_5, X_test_5 = train_test_split(X_5, test_size=0.95, random_state=42)


In [None]:
# Convert the training split to a plain NumPy array so pixels can be indexed
# positionally ([:, i]). np.asarray is used instead of DataFrame.to_numpy() so the
# cell is idempotent: re-running it on an already-converted array is a no-op rather
# than an AttributeError (ndarrays have no .to_numpy()).
X_train_5 = np.asarray(X_train_5)


In [None]:
# Binarize: rescale intensities by 1/255 and round to the nearest integer, stored as
# uint8. Assuming raw intensities lie in [0, 255], every pixel becomes 0 or 1.
X_train_5_bw = np.round(np.divide(X_train_5, 255.0), decimals=0).astype(np.uint8)

In [None]:
# Sanity-check the preprocessing by showing the first binarized *training* image
# (this is real data, not a generated sample, so it is titled accordingly).
fig, ax = plt.subplots()
ax.imshow(X_train_5_bw[0].reshape(28, 28), cmap='gray')
ax.set_title("Binarized Training Image of Digit 5")
plt.show()

In [None]:
# Train one model per pixel: models[i] predicts P(pixel i == 1) from all other pixels.
# The objective is 'binary:logistic' so predict_proba returns probabilities in [0, 1];
# a regression objective such as 'reg:squarederror' makes XGBClassifier.predict_proba
# return raw, unbounded scores instead.
DEVICE = 'cuda'  # set to 'cpu' when no GPU is available


class ConstantPixelModel:
    """Stand-in model for a pixel that takes a single value across the training set.

    Fitting a classifier on a single-class target is meaningless (and may be rejected
    by XGBoost for logistic objectives), so such pixels simply reproduce the observed
    frequency of 1s. ``predict_proba`` mirrors XGBClassifier's output shape: one row
    of ``[P(0), P(1)]`` per input row.
    """

    def __init__(self, p_one):
        self.p_one = float(p_one)

    def predict_proba(self, X):
        n_rows = np.shape(X)[0]
        return np.tile([1.0 - self.p_one, self.p_one], (n_rows, 1))


n_pixels = X_train_5_bw.shape[1]
models = []
for i in range(n_pixels):
    print(f"Training model for pixel {i+1}/{n_pixels}\r", end="")
    y_train = X_train_5_bw[:, i]
    if np.unique(y_train).size < 2:
        # Constant pixel in this subset (e.g. an always-black border pixel).
        models.append(ConstantPixelModel(y_train.mean() if y_train.size else 0.0))
        continue
    X_train = np.delete(X_train_5_bw, i, axis=1)  # use all other pixels as features
    model = XGBClassifier(objective='binary:logistic', n_estimators=10, max_depth=26, device=DEVICE)  # Simplified model
    model.fit(X_train, y_train)
    models.append(model)

In [None]:

# Generate a new image by Gibbs-style sampling: on every sweep, visit all pixels in a
# random order and resample each one from its per-pixel model, conditioned on the
# current values of all the other pixels.
N_SWEEPS = 1000   # number of full passes over all pixels
PLOT_EVERY = 100  # show an intermediate image every PLOT_EVERY sweeps (and after the last)
SEED = 42         # makes the generated image reproducible

rng = np.random.default_rng(SEED)
n_pixels = X_train_5_bw.shape[1]

# Start from uniform random 0/1 noise; np.zeros((1, n_pixels), dtype=int) would
# start from an all-black image instead.
new_image = rng.integers(0, 2, size=(1, n_pixels))
for sweep in range(1, N_SWEEPS + 1):
    for i in rng.permutation(n_pixels):
        # Drop pixel i so the feature layout matches what models[i] was trained on.
        X_gen = np.delete(new_image, i, axis=1)
        # Clip defensively in case a model's output is not a valid probability.
        p = float(np.clip(models[i].predict_proba(X_gen)[0, 1], 0.0, 1.0))
        new_image[0, i] = int(rng.random() < p)  # Bernoulli(p) draw

    # Plot only periodically: plotting inside every sweep would render N_SWEEPS figures.
    if sweep % PLOT_EVERY == 0 or sweep == N_SWEEPS:
        fig, ax = plt.subplots()
        ax.imshow(new_image.reshape(28, 28), cmap='gray')
        ax.set_title(f"Generated Image of Digit 5 (sweep {sweep}/{N_SWEEPS})")
        ax.axis('off')
        plt.show()
        plt.close(fig)  # release figure memory in this long loop
