# In[1]:
import numpy as np
import matplotlib.pyplot as plt
from sklearn.mixture import GaussianMixture
from sklearn.datasets import fetch_openml
from sklearn.decomposition import PCA

# Load MNIST from OpenML. `mnist.data` is converted with `.to_numpy()`, so it
# is presumably a pandas DataFrame of flattened 28x28 pixel rows (the
# plotting helper reshapes rows to 28x28) — confirm against the sklearn
# version in use.
mnist = fetch_openml("mnist_784")
X = mnist.data.to_numpy()
y = mnist.target  # digit labels; not used in this cell

# Reduce to 50 principal components before clustering: a full-covariance
# Gaussian has O(d^2) parameters per component, so fitting in the raw pixel
# space would be far more expensive.
# random_state is fixed because PCA's default "auto" solver may pick the
# randomized SVD for data of this shape; without a seed the projection (and
# therefore the clusters below) would change between runs even though the
# GMM itself is seeded.
pca = PCA(n_components=50, random_state=42)
X_reduced = pca.fit_transform(X)

# One mixture component per hoped-for digit class.
num_clusters = 10
gmm = GaussianMixture(n_components=num_clusters, covariance_type="full", random_state=42)
gmm.fit(X_reduced)

# Hard assignment: index of the most likely component for every sample,
# aligned row-for-row with X.
clusters = gmm.predict(X_reduced)

def plot_cluster_images(cluster_number, num_samples=10, *, labels=None, images=None):
    """Show the first ``num_samples`` images assigned to one cluster.

    Parameters
    ----------
    cluster_number : int
        Cluster label to display; compared element-wise against ``labels``.
    num_samples : int, default 10
        Number of subplot slots to draw (must be >= 1). Images are taken in
        dataset order; if the cluster has fewer members, the remaining slots
        are left blank.
    labels : array-like, optional
        Per-sample cluster labels. Defaults to the module-level ``clusters``.
    images : array-like, optional
        Flattened 28x28 images, row-aligned with ``labels``. Defaults to the
        module-level ``X``.

    Raises
    ------
    ValueError
        If ``num_samples`` is less than 1.
    """
    if num_samples < 1:
        raise ValueError(f"num_samples must be >= 1, got {num_samples}")
    labels = np.asarray(clusters if labels is None else labels)
    images = np.asarray(X if images is None else images)

    indices = np.flatnonzero(labels == cluster_number)[:num_samples]

    # squeeze=False always yields a 2-D array of Axes; with the default
    # squeeze=True a single column would come back as a bare Axes object
    # and indexing it would fail.
    fig, axes = plt.subplots(1, num_samples, squeeze=False)
    row = axes[0]
    # Turn off every slot up front, so unused slots (cluster smaller than
    # num_samples) don't show empty frames and ticks.
    for ax in row:
        ax.axis("off")
    for ax, idx in zip(row, indices):
        # Digits are single-channel intensities; render them in grayscale
        # rather than the default colour map.
        ax.imshow(images[idx].reshape(28, 28), cmap="gray")
    fig.suptitle(f"Cluster {cluster_number}")
    plt.show()
    
# Preview clusters 0 through 4, one figure per cluster.
for i in range(0, 5):
    plot_cluster_images(cluster_number=i)