In [None]:
def _to_numpy(x):
    """Return *x* as a NumPy array, detaching and moving it to CPU if it is a tensor."""
    # Duck-typed so both torch tensors and plain NumPy arrays are accepted;
    # detach() avoids the "can't call numpy() on Tensor that requires grad" error.
    if hasattr(x, "detach"):
        x = x.detach().cpu()
    return np.asarray(x)


def plot_batch_results(output, save_path, step):
    """Save one figure per batch element showing images and slot masks.

    For every batch element ``i`` a 2-row figure is written to
    ``<save_path>/results_step_{step}_batch_{i}.png``:

    * top row: the original image and the reconstructed image;
    * bottom row: all slot masks blended into one RGB image (one color per
      slot, shared by every figure of this call), followed by each slot mask
      shown on its own.

    Args:
        output: Mapping with keys ``"images"``, ``"generated"`` and
            ``"masks"``. Images are assumed channels-first (B, C, H, W) since
            they are transposed to (H, W, C) for display; masks are assumed
            (B, num_slots, H, W) with the same H, W as the images — TODO
            confirm against the model. Torch tensors (any device, with or
            without grad) and NumPy arrays are both accepted.
        save_path: Directory for the PNG files; created if it does not exist.
        step: Training step, used only in the output filenames.
    """
    import os

    images = _to_numpy(output["images"])  # Original images
    recons_images = _to_numpy(output["generated"])  # Reconstructed images
    masks = _to_numpy(output["masks"])  # Masks for each slot
    batch_size = images.shape[0]
    num_slots = masks.shape[1]
    # At least two columns so the original and reconstruction always fit.
    num_cols = max(num_slots + 1, 2)

    os.makedirs(save_path, exist_ok=True)

    # Draw one color per slot once, so slot j looks the same in every figure.
    slot_colors = np.random.rand(num_slots, 3)

    for i in range(batch_size):
        orig_image = images[i].transpose(1, 2, 0)
        recons_image = recons_images[i].transpose(1, 2, 0)
        batch_masks = masks[i]  # Shape: (num_slots, H, W)

        # squeeze=False keeps `axes` 2-D even when there is a single column.
        fig, axes = plt.subplots(
            2, num_cols, figsize=(3 * num_cols, 6), squeeze=False
        )
        # Hide every panel's frame up front, including unused top-row panels.
        for ax in axes.flat:
            ax.axis("off")

        # Plot original and reconstructed images
        axes[0, 0].imshow(orig_image)
        axes[0, 0].set_title("Original Image")
        axes[0, 1].imshow(recons_image)
        axes[0, 1].set_title("Reconstructed Image")

        # Blend all slots: each pixel is the mask-weighted sum of slot colors.
        # Built as float (H, W, 3) independent of the image dtype/channels, and
        # clipped because masks need not sum to exactly 1 per pixel.
        combined_slots = np.einsum("shw,sc->hwc", batch_masks, slot_colors)
        axes[1, 0].imshow(np.clip(combined_slots, 0.0, 1.0))
        axes[1, 0].set_title("All Slots Combined")

        # Plot individual slot masks
        for j in range(num_slots):
            axes[1, j + 1].imshow(batch_masks[j], cmap="viridis")
            axes[1, j + 1].set_title(f"Slot {j + 1}")

        fig.tight_layout()
        fig.savefig(os.path.join(save_path, f"results_step_{step}_batch_{i}.png"))
        plt.close(fig)
