In [1]:
import torch
from diffusers import StableDiffusionPipeline
from PIL import Image
import os
from tqdm.auto import tqdm
import random

In [2]:
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

Using device: cuda


In [3]:
model_id = "CompVis/stable-diffusion-v1-4"
pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
pipe = pipe.to(device)

Loading pipeline components...:   0%|          | 0/7 [00:00<?, ?it/s]

In [4]:
classes = ["dog", "cat", "bird"]

In [5]:
prompts = {
    "dog": [
        "A high-quality photograph of a {breed} dog in {setting}",
        "Close-up portrait of a {breed} dog with {feature}",
        "{breed} dog playing in {environment}",
        "A {breed} dog {action} in {lighting} conditions"
    ],
    "cat": [
        "A detailed image of a {breed} cat in {setting}",
        "Close-up of a {breed} cat's face showing {feature}",
        "{breed} cat lounging in {environment}",
        "A {breed} cat {action} under {lighting} light"
    ],
    "bird": [
        "A clear photograph of a {species} bird perched on {perch}",
        "Close-up of a {species} bird's feathers showing {feature}",
        "{species} bird in flight against {background}",
        "A {species} bird {action} in its natural {environment}"
    ]
}


In [6]:
base_dir = "./data/synthetic/cifar10"
os.makedirs(base_dir, exist_ok=True)

In [7]:
num_images_per_class = 1000
batch_size = 10

In [8]:
def get_random_prompt(class_name):
    template = random.choice(prompts[class_name])
    if class_name == "dog":
        return template.format(
            breed=random.choice(["Golden Retriever", "German Shepherd", "Poodle", "Bulldog", "Labrador", "Husky"]),
            setting=random.choice(["a park", "a beach", "a snowy landscape", "a forest", "a city street"]),
            feature=random.choice(["expressive eyes", "fluffy fur", "floppy ears", "a wagging tail"]),
            environment=random.choice(["a backyard", "a dog park", "a lake", "mountains"]),
            action=random.choice(["running", "jumping", "sleeping", "playing fetch"]),
            lighting=random.choice(["natural", "golden hour", "overcast", "studio"])
        )
    elif class_name == "cat":
        return template.format(
            breed=random.choice(["Siamese", "Persian", "Maine Coon", "Tabby", "Sphynx", "Bengal"]),
            setting=random.choice(["a sunny windowsill", "a cozy living room", "a garden", "a bookshelf"]),
            feature=random.choice(["striking eyes", "soft fur", "long whiskers", "a fluffy tail"]),
            environment=random.choice(["a cat tree", "a cardboard box", "a cozy bed", "a garden"]),
            action=random.choice(["grooming", "stretching", "napping", "playing with a toy"]),
            lighting=random.choice(["soft", "warm", "dramatic", "natural"])
        )
    elif class_name == "bird":
        return template.format(
            species=random.choice(["Blue Jay", "Cardinal", "Hummingbird", "Eagle", "Parrot", "Owl"]),
            perch=random.choice(["a tree branch", "a bird feeder", "a telephone wire", "a rocky cliff"]),
            feature=random.choice(["vibrant colors", "intricate patterns", "sleek feathers", "a sharp beak"]),
            background=random.choice(["a clear sky", "a forest canopy", "a sunset", "snow-capped mountains"]),
            action=random.choice(["singing", "feeding", "building a nest", "taking off"]),
            environment=random.choice(["a tropical rainforest", "a desert landscape", "a wetland", "an urban park"])
        )
    else:
        raise ValueError(f"Unknown class: {class_name}")


In [9]:
from collections import deque

class_queue = deque(classes)
all_complete = False

while not all_complete:
    all_complete = True
    
    for _ in range(len(classes)):
        class_name = class_queue[0]
        class_dir = os.path.join(base_dir, class_name)
        os.makedirs(class_dir, exist_ok=True)

        existing_images = len([f for f in os.listdir(class_dir) if f.endswith('.png')])
        images_to_generate = max(0, num_images_per_class - existing_images)

        if images_to_generate > 0:
            all_complete = False
            print(f"\nGenerating images for class: {class_name}")
            print(f"Existing images: {existing_images}")
            print(f"Images to generate: {images_to_generate}")

            batch_size = min(batch_size, images_to_generate)
            prompt = [get_random_prompt(class_name) for _ in range(batch_size)]
            
            batch_images = pipe(prompt).images

            for j, image in enumerate(batch_images):
                image_index = existing_images + j
                image.save(f"{class_dir}/{class_name}_{image_index:04d}.png")

            print(f"Completed batch for {class_name}")

        class_queue.rotate(-1)  # Move to the next class

print("Finished generating all images for all classes")



Generating images for class: dog
Existing images: 510
Images to generate: 490


  0%|          | 0/50 [00:00<?, ?it/s]

Completed batch for dog

Generating images for class: cat
Existing images: 480
Images to generate: 520


  0%|          | 0/50 [00:00<?, ?it/s]

Potential NSFW content was detected in one or more images. A black image will be returned instead. Try again with a different prompt and/or seed.


Completed batch for cat

Generating images for class: bird
Existing images: 450
Images to generate: 550


  0%|          | 0/50 [00:00<?, ?it/s]

Completed batch for bird

Generating images for class: dog
Existing images: 520
Images to generate: 480


  0%|          | 0/50 [00:00<?, ?it/s]

Completed batch for dog

Generating images for class: cat
Existing images: 490
Images to generate: 510


  0%|          | 0/50 [00:00<?, ?it/s]

Completed batch for cat

Generating images for class: bird
Existing images: 460
Images to generate: 540


  0%|          | 0/50 [00:00<?, ?it/s]

Completed batch for bird

Generating images for class: dog
Existing images: 530
Images to generate: 470


  0%|          | 0/50 [00:00<?, ?it/s]

Completed batch for dog

Generating images for class: cat
Existing images: 500
Images to generate: 500


  0%|          | 0/50 [00:00<?, ?it/s]

Potential NSFW content was detected in one or more images. A black image will be returned instead. Try again with a different prompt and/or seed.


Completed batch for cat

Generating images for class: bird
Existing images: 470
Images to generate: 530


  0%|          | 0/50 [00:00<?, ?it/s]

Completed batch for bird

Generating images for class: dog
Existing images: 540
Images to generate: 460


  0%|          | 0/50 [00:00<?, ?it/s]

Completed batch for dog

Generating images for class: cat
Existing images: 510
Images to generate: 490


  0%|          | 0/50 [00:00<?, ?it/s]

Potential NSFW content was detected in one or more images. A black image will be returned instead. Try again with a different prompt and/or seed.


Completed batch for cat

Generating images for class: bird
Existing images: 480
Images to generate: 520


  0%|          | 0/50 [00:00<?, ?it/s]

Completed batch for bird

Generating images for class: dog
Existing images: 550
Images to generate: 450


  0%|          | 0/50 [00:00<?, ?it/s]

Completed batch for dog

Generating images for class: cat
Existing images: 520
Images to generate: 480


  0%|          | 0/50 [00:00<?, ?it/s]

Potential NSFW content was detected in one or more images. A black image will be returned instead. Try again with a different prompt and/or seed.


Completed batch for cat

Generating images for class: bird
Existing images: 490
Images to generate: 510


  0%|          | 0/50 [00:00<?, ?it/s]

Completed batch for bird

Generating images for class: dog
Existing images: 560
Images to generate: 440


  0%|          | 0/50 [00:00<?, ?it/s]

Completed batch for dog

Generating images for class: cat
Existing images: 530
Images to generate: 470


  0%|          | 0/50 [00:00<?, ?it/s]

Potential NSFW content was detected in one or more images. A black image will be returned instead. Try again with a different prompt and/or seed.


Completed batch for cat

Generating images for class: bird
Existing images: 500
Images to generate: 500


  0%|          | 0/50 [00:00<?, ?it/s]

Completed batch for bird

Generating images for class: dog
Existing images: 570
Images to generate: 430


  0%|          | 0/50 [00:00<?, ?it/s]

Completed batch for dog

Generating images for class: cat
Existing images: 540
Images to generate: 460


  0%|          | 0/50 [00:00<?, ?it/s]

Potential NSFW content was detected in one or more images. A black image will be returned instead. Try again with a different prompt and/or seed.


Completed batch for cat

Generating images for class: bird
Existing images: 510
Images to generate: 490


  0%|          | 0/50 [00:00<?, ?it/s]

Completed batch for bird

Generating images for class: dog
Existing images: 580
Images to generate: 420


  0%|          | 0/50 [00:00<?, ?it/s]

Completed batch for dog

Generating images for class: cat
Existing images: 550
Images to generate: 450


  0%|          | 0/50 [00:00<?, ?it/s]

Completed batch for cat

Generating images for class: bird
Existing images: 520
Images to generate: 480


  0%|          | 0/50 [00:00<?, ?it/s]

Completed batch for bird

Generating images for class: dog
Existing images: 590
Images to generate: 410


  0%|          | 0/50 [00:00<?, ?it/s]

Completed batch for dog

Generating images for class: cat
Existing images: 560
Images to generate: 440


  0%|          | 0/50 [00:00<?, ?it/s]

Potential NSFW content was detected in one or more images. A black image will be returned instead. Try again with a different prompt and/or seed.


Completed batch for cat

Generating images for class: bird
Existing images: 530
Images to generate: 470


  0%|          | 0/50 [00:00<?, ?it/s]

Completed batch for bird

Generating images for class: dog
Existing images: 600
Images to generate: 400


  0%|          | 0/50 [00:00<?, ?it/s]

Completed batch for dog

Generating images for class: cat
Existing images: 570
Images to generate: 430


  0%|          | 0/50 [00:00<?, ?it/s]

Completed batch for cat

Generating images for class: bird
Existing images: 540
Images to generate: 460


  0%|          | 0/50 [00:00<?, ?it/s]

Completed batch for bird

Generating images for class: dog
Existing images: 610
Images to generate: 390


  0%|          | 0/50 [00:00<?, ?it/s]

Completed batch for dog

Generating images for class: cat
Existing images: 580
Images to generate: 420


  0%|          | 0/50 [00:00<?, ?it/s]

Completed batch for cat

Generating images for class: bird
Existing images: 550
Images to generate: 450


  0%|          | 0/50 [00:00<?, ?it/s]

Completed batch for bird

Generating images for class: dog
Existing images: 620
Images to generate: 380


  0%|          | 0/50 [00:00<?, ?it/s]

Completed batch for dog

Generating images for class: cat
Existing images: 590
Images to generate: 410


  0%|          | 0/50 [00:00<?, ?it/s]

Potential NSFW content was detected in one or more images. A black image will be returned instead. Try again with a different prompt and/or seed.


Completed batch for cat

Generating images for class: bird
Existing images: 560
Images to generate: 440


  0%|          | 0/50 [00:00<?, ?it/s]

Completed batch for bird

Generating images for class: dog
Existing images: 630
Images to generate: 370


  0%|          | 0/50 [00:00<?, ?it/s]

Completed batch for dog

Generating images for class: cat
Existing images: 600
Images to generate: 400


  0%|          | 0/50 [00:00<?, ?it/s]

Potential NSFW content was detected in one or more images. A black image will be returned instead. Try again with a different prompt and/or seed.


Completed batch for cat

Generating images for class: bird
Existing images: 570
Images to generate: 430


  0%|          | 0/50 [00:00<?, ?it/s]

Completed batch for bird

Generating images for class: dog
Existing images: 640
Images to generate: 360


  0%|          | 0/50 [00:00<?, ?it/s]

Completed batch for dog

Generating images for class: cat
Existing images: 610
Images to generate: 390


  0%|          | 0/50 [00:00<?, ?it/s]

Potential NSFW content was detected in one or more images. A black image will be returned instead. Try again with a different prompt and/or seed.


Completed batch for cat

Generating images for class: bird
Existing images: 950
Images to generate: 50


  0%|          | 0/50 [00:00<?, ?it/s]

Completed batch for bird

Generating images for class: cat
Existing images: 990
Images to generate: 10


  0%|          | 0/50 [00:00<?, ?it/s]

Potential NSFW content was detected in one or more images. A black image will be returned instead. Try again with a different prompt and/or seed.


Completed batch for cat

Generating images for class: bird
Existing images: 960
Images to generate: 40


  0%|          | 0/50 [00:00<?, ?it/s]

Completed batch for bird

Generating images for class: bird
Existing images: 970
Images to generate: 30


  0%|          | 0/50 [00:00<?, ?it/s]

Completed batch for bird

Generating images for class: bird
Existing images: 980
Images to generate: 20


  0%|          | 0/50 [00:00<?, ?it/s]

Completed batch for bird

Generating images for class: bird
Existing images: 990
Images to generate: 10


  0%|          | 0/50 [00:00<?, ?it/s]

Completed batch for bird
Finished generating all images for all classes


In [None]:
plt.figure(figsize=(15, 5))
for i, class_name in enumerate(classes):
    img_path = os.path.join(base_dir, class_name, f"{class_name}_0000.png")
    img = Image.open(img_path)
    plt.subplot(1, 3, i+1)
    plt.imshow(img)
    plt.title(class_name)
    plt.axis('off')
plt.tight_layout()
plt.show()
