In [1]:
from diffusers import StableDiffusionXLPipeline
import torch
import os

# Load the SDXL 1.0 base pipeline in half precision and move it to the GPU.
model_id = "stabilityai/stable-diffusion-xl-base-1.0"  # SDXL 1.0 model on Hugging Face

# Fail fast: without this check a machine with no usable GPU only errors out at
# `.to("cuda")`, i.e. after the whole checkpoint has been downloaded and loaded.
if not torch.cuda.is_available():
    raise RuntimeError("A CUDA-capable GPU is required to run this notebook.")

pipe = StableDiffusionXLPipeline.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    # Fetch the published fp16 weight files rather than the default-precision
    # ones, which would only be cast down to float16 at load time anyway.
    variant="fp16",
    use_safetensors=True,
)
pipe = pipe.to("cuda")

  from .autonotebook import tqdm as notebook_tqdm
Cannot initialize model with low cpu memory usage because `accelerate` was not found in the environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install `accelerate` for faster and less memory-intense model loading. You can do so with: 
```
pip install accelerate
```
.
Loading pipeline components...:  14%|█▍        | 1/7 [00:00<00:02,  2.40it/s]

In [None]:
import os
from PIL import Image
import io
import json
from tqdm import tqdm
import ipywidgets as widgets
from IPython.display import display, clear_output

# Load the prompt lists for every evaluation category from prompts.json.
# The encoding is pinned to UTF-8 so that prompts containing non-ASCII
# characters decode the same way regardless of the platform's default encoding.
with open("prompts.json", "r", encoding="utf-8") as file:
    data = json.load(file)

# The lookups do not need the file handle, so they run after it is closed.
# A category that is absent from the file falls back to an empty list.
existing_prompts_1 = data.get("Counting Accuracy", [])
existing_prompts_2 = data.get("Size Proportionality", [])
existing_prompts_3 = data.get("Fractional Representation", [])
existing_prompts_4 = data.get("Geometric Shape Understanding", [])
existing_prompts_5 = data.get("Numerical Sequencing", [])

# Root folder for generated images, plus one sub-folder per prompt category.
output_dir = "./sdxl_generated_images/"
output_dir_1, output_dir_2, output_dir_3, output_dir_4, output_dir_5 = (
    os.path.join(output_dir, subfolder)
    for subfolder in (
        "Counting Accuracy/",
        "Size Proportionality/",
        "Fractional Representation/",
        "Geometric Shape Understanding/",
        "Numerical Sequencing/",
    )
)

# Create the root first, then each category folder; folders that already
# exist are left as they are.
for _folder in (
    output_dir,
    output_dir_1,
    output_dir_2,
    output_dir_3,
    output_dir_4,
    output_dir_5,
):
    os.makedirs(_folder, exist_ok=True)

# Generation settings shared by every prompt.
num_images_per_prompt = 5  # candidate images rendered for each prompt
num_inference_steps = 50   # denoising steps passed to the pipeline per image
# NOTE(review): 20 is a very strong guidance value (diffusers' SDXL pipeline
# defaults to 5.0) - confirm this is intentional.
guidance_scale = 20

# Render the candidate images for a single prompt
def generate_images(prompt, output_dir, idx):
    """Run the pipeline `num_images_per_prompt` times for `prompt` and save each result.

    Args:
        prompt: Text prompt handed to the module-level `pipe`.
        output_dir: Directory the PNG files are written into.
        idx: Prompt index, embedded in every file name.

    Returns:
        The saved file paths, in generation order. Files are named
        `image_prompt_{idx}_{k}.png` with `k` counting from 1.
    """
    saved_paths = []
    for variant in range(1, num_images_per_prompt + 1):
        result = pipe(
            prompt,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
        )
        picture = result.images[0]

        # Write the image out explicitly as PNG.
        target = os.path.join(output_dir, f"image_prompt_{idx}_{variant}.png")
        picture.save(target, format="PNG")
        saved_paths.append(target)
    return saved_paths

# Function to display and select the best image
def select_best_image(prompt, output_dir, idx):
    """Generate candidate images for one prompt and show a picker to keep the best.

    The candidates are rendered by `generate_images`, shown side by side with a
    dropdown and a save button, and the function then returns immediately: the
    actual save happens later, in the button's click callback.

    When the button is clicked, the chosen candidate is re-saved as
    `counting_{idx}.png` inside `output_dir` and every candidate file is deleted.
    NOTE(review): the `counting_` prefix is used whatever category `output_dir`
    belongs to - confirm before using this for the other prompt sets.

    Args:
        prompt: Text prompt to generate images for.
        output_dir: Directory holding both the candidates and the kept image.
        idx: Prompt index, used in the file names and in the status message.

    Returns:
        None.
    """
    images = generate_images(prompt, output_dir, idx)
    if not images:
        # With no candidates the dropdown would have no value and the save
        # callback would fail on `images[None]`, so skip building the picker.
        print(f"No images were generated for prompt {idx}")
        return

    # Load images as binary data for widgets
    image_widgets = []
    for img_path in images:
        with open(img_path, "rb") as file:
            image_widgets.append(
                widgets.Image(value=file.read(), format="png", width=200, height=200)
            )

    # Dropdown to select the best image
    dropdown = widgets.Dropdown(
        options=[(f"Image {i + 1}", i) for i in range(len(images))],
        description="Select Best:",
    )

    # Save button
    save_button = widgets.Button(description="Save Selected Image")

    # Give this prompt its own output area. Clearing the shared cell output
    # from the callback would also remove the pickers of every other prompt
    # that is still waiting for a choice.
    panel = widgets.Output()
    display(panel)
    with panel:
        display(widgets.HBox(image_widgets))
        display(dropdown, save_button)

    # Define save and cleanup action
    def save_and_cleanup(b):
        # A repeated click would try to reopen and delete files that are
        # already gone, so only the first click is acted upon.
        if save_button.disabled:
            return
        save_button.disabled = True

        selected_image = images[dropdown.value]
        selected_filename = os.path.join(output_dir, f"counting_{idx}.png")
        try:
            # Close the source file before it is deleted below.
            with Image.open(selected_image) as chosen:
                chosen.save(selected_filename)
        except Exception:
            # Nothing has been deleted yet, so let the user try again.
            save_button.disabled = False
            raise

        # Delete all generated images after selection
        for img_path in images:
            os.remove(img_path)

        # Replace this prompt's picker with the confirmation message; clearing
        # first keeps the message visible instead of wiping it straight away.
        panel.clear_output()
        with panel:
            print(f"Saved best image for prompt {idx} in {selected_filename}")

    # Attach the save action to the button
    save_button.on_click(save_and_cleanup)

# Generate images and show a picker for every prompt of each listed prompt set.
# Only the first set is processed here; to cover more sets, append the matching
# prompt list and directory to the two lists below.
# select_best_image returns as soon as its picker is displayed, so this loop
# does not wait for a selection before moving on to the next prompt.
# The loop variable is deliberately not called `output_dir`: that would rebind
# the module-level root directory to a category folder once the loop finishes.
for prompt_set, category_dir in zip(
    [existing_prompts_1],
    [output_dir_1],
):
    for idx, prompt in enumerate(prompt_set):
        print(f"\nPrompt {idx + 1}: {prompt}")
        select_best_image(prompt, category_dir, idx)
