In [None]:
# Boxplot grid of per-data-point accuracy.
#   rows    -> D_values
#   columns -> K_values
#   x-axis  -> one group of three boxes per value in N_values
# Each acc_data_points_* dict is looked up with keys "N_{N}_K_{K}_D_{D}".
# A missing key yields an empty array, which is drawn as an empty box.

GROUP_SPACING = 1.8   # x-distance between neighbouring N groups
BOX_WIDTH = 0.3       # width of one box; also the offset between boxes, so a group's boxes abut
Y_LIMITS = (0.3, 1.05)  # shared y-range for every subplot; accuracies below 0.3 fall outside the view
SHOW_POINTS = False   # set True to overlay jittered individual data points on the boxes
GRID_STYLE = dict(color='lightgrey', linestyle='--', linewidth=0.5, alpha=0.8)

# (data dict, box colour, x-offset from the group centre, legend label), left -> right.
# NOTE(review): red (acc_data_points_gmm) is labelled "Inference (GMM init)" while
# green (acc_data_points_gmm_init) is labelled "GMM"; this looks swapped relative to
# the variable names -- confirm the intended pairing.
series = [
    (acc_data_points_gmm, 'red', -BOX_WIDTH, 'Inference (GMM init)'),
    (acc_data_points_kmeans, 'blue', 0.0, 'Inference (KMeans init)'),
    (acc_data_points_gmm_init, 'green', +BOX_WIDTH, 'GMM'),
]


def _collect_metrics(data, K, D):
    """Return one accuracy array per N in N_values for the (K, D) subplot."""
    return [np.asarray(data.get(f"N_{N}_K_{K}_D_{D}", [])) for N in N_values]


def _draw_boxes(ax, metrics, positions, color):
    """Draw one semi-transparent box per array in *metrics* at *positions* on *ax*."""
    boxes = ax.boxplot(
        metrics,
        positions=positions,
        notch=False,
        patch_artist=True,
        widths=BOX_WIDTH,
    )
    for patch in boxes['boxes']:
        patch.set_facecolor(color)
        patch.set_alpha(0.5)
        patch.set_edgecolor('black')
        patch.set_linewidth(1)
    for median in boxes['medians']:
        median.set_color('black')
        median.set_linewidth(1)


# Rows follow D_values (the loop below indexes rows by D). squeeze=False keeps
# `axes` 2-D even when there is a single row or column.
fig, axes = plt.subplots(
    len(D_values), len(K_values), figsize=(10, 12), sharey=True, squeeze=False
)
base_positions = np.arange(len(N_values)) * GROUP_SPACING  # centre of each N group
rng = np.random.default_rng(0)  # private RNG so point jitter is reproducible

for i, D in enumerate(D_values):
    for j, K in enumerate(K_values):
        ax = axes[i, j]
        ax.grid(True, which='both', **GRID_STYLE)
        ax.tick_params(axis='y', labelsize=12)

        cell = [
            (_collect_metrics(data, K, D), color, base_positions + offset)
            for data, color, offset, _label in series
        ]
        # Skip the subplot only when none of the three series has any data.
        if not any(values.size > 0 for metrics, _, _ in cell for values in metrics):
            ax.set_xticks([])
            continue

        for metrics, color, positions in cell:
            _draw_boxes(ax, metrics, positions, color)
            if SHOW_POINTS:
                for values, pos in zip(metrics, positions):
                    jitter = rng.normal(0, 0.05, size=values.size)
                    ax.scatter(pos + jitter, values, alpha=0.7, color='black', s=20)

        # N tick labels only on the bottom row, centred under each group.
        if i == len(D_values) - 1:
            ax.set_xticks(base_positions)
            ax.set_xticklabels([str(N) for N in N_values], fontsize=12)
        else:
            ax.set_xticks([])

# sharey=True propagates this limit to every subplot.
axes[0, 0].set_ylim(*Y_LIMITS)

# Lay out the subplots first, so the column/row labels below can be centred on
# the subplots' final positions (tight_layout does not move fig.text artists).
plt.tight_layout(rect=[0.1, 0.1, 0.9, 0.9])

for j, K in enumerate(K_values):
    box = axes[0, j].get_position()
    fig.text((box.x0 + box.x1) / 2, 0.90, f'{K}', va='center', ha='center', fontsize=12)

fig.text(0.5, 0.95, 'Distribution type accuracy per single data point', ha='center', va='center', fontsize=16)
fig.text(0.5, 0.92, 'Number of clusters', va='center', ha='center', fontsize=15)

# NOTE(review): rows are indexed by D_values but titled "Number of samples";
# confirm this matches what D represents.
fig.text(0.95, 0.5, 'Number of samples', va='center', ha='center', rotation=-90, fontsize=15)
for i, D in enumerate(D_values):
    box = axes[i, 0].get_position()
    fig.text(0.91, (box.y0 + box.y1) / 2, f'{D}', va='center', ha='center', rotation=-90, fontsize=12)

fig.text(0.08, 0.5, 'Accuracy per data point', va='center', rotation='vertical', fontsize=12)
fig.text(0.5, 0.09, 'Number of data points', ha='center', fontsize=12)

# Legend entries are derived from `series`, so colours and labels are defined in one place.
legend_elements = [
    Patch(facecolor=color, edgecolor='black', alpha=0.5, label=label)
    for _data, color, _offset, label in series
]

fig.legend(
    handles=legend_elements,
    loc='lower center',          # anchor the legend's bottom-centre...
    bbox_to_anchor=(0.5, 0.05),  # ...just below the subplot grid
    ncol=len(legend_elements),   # all entries on one horizontal row
    fontsize=10,
)

plt.show()
