In this notebook, we optimize the prompt directly to increase FLIPD, an estimate of the local intrinsic dimension of the generated images. Use the following cell to specify your prompt and generation parameters:

In [1]:
# Text prompt given to the pipeline and used as the starting point for prompt optimization.
prompt = "Air Conditioners & Parts"
# Passed to set_random_seed() before each sampling call.
seed = 0
# Forwarded to the pipeline (and to aug_prompt) as `num_inference_steps`.
num_inference_steps = 50
# Forwarded to the pipeline (and to aug_prompt) as `guidance_scale`.
guidance_scale = 7.5
# Forwarded to the pipeline as `num_images_per_prompt` when sampling.
num_images_per_prompt = 4
# Passed as both `height` and `width` to the pipeline.
image_size = 512

In [2]:
import torch
import mediapy as media
from diffusers import DDIMScheduler
import os

# Import the project-local modules. If they cannot be found from the current
# working directory, move one directory up and retry the same imports.
# NOTE: the fallback changes the process working directory for the rest of
# the session, so relative paths used later resolve against the new location.
try:
    from local_sd_pipeline import LocalStableDiffusionPipeline
    from optim_utils import *
except ModuleNotFoundError:
    # `os` is already imported at the top of this cell; no need to re-import it.
    os.chdir("..")
    from local_sd_pipeline import LocalStableDiffusionPipeline
    from optim_utils import *

%load_ext autoreload
%autoreload 2

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# load model
# Device selection: keep using the second GPU ("cuda:1") when the machine has
# at least two GPUs, fall back to the first GPU on single-GPU machines (where
# "cuda:1" is not a valid device), and to the CPU when CUDA is unavailable.
if not torch.cuda.is_available():
    device = torch.device("cpu")
elif torch.cuda.device_count() > 1:
    device = torch.device("cuda:1")
else:
    device = torch.device("cuda:0")

model_id = "CompVis/stable-diffusion-v1-4"

# Load the pretrained pipeline in bfloat16 with no safety checker attached.
pipe = LocalStableDiffusionPipeline.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    safety_checker=None,
    requires_safety_checker=False,
)
# Replace the pipeline's scheduler with a DDIM scheduler built from the
# existing scheduler's configuration.
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
print(device)
pipe = pipe.to(device)

Loading pipeline components...: 100%|██████████| 6/6 [00:00<00:00, 10.66it/s]


cuda:1


## Before Optimization

Here are four generations from the diffusion model using the original, unmodified prompt.

In [4]:
# Seed before sampling so the baseline images are reproducible.
set_random_seed(seed)

# Baseline: sample with the unmodified text prompt. `height`/`width` are set
# from `image_size`, exactly as in the optimized-prompt generation, so that
# changing `image_size` in the config cell affects both sets of images.
outputs, track_stats = pipe(
    prompt,
    num_inference_steps=num_inference_steps,
    guidance_scale=guidance_scale,
    num_images_per_prompt=num_images_per_prompt,
    track_noise_norm=True,
    height=image_size,
    width=image_size,
)
# Keep only the generated images for display.
outputs = outputs.images

print(f"prompt: {prompt}")
media.show_images(outputs, width=300)

100%|██████████| 50/50 [00:03<00:00, 13.24it/s]


prompt: Air Conditioners & Parts


## After Optimizing the Prompt to Increase FLIPD

Here we optimize the prompt to increase FLIPD (mirroring the results from Figure 9 of [the paper](https://arxiv.org/pdf/2411.00113)). The cell below repeats the optimization with 1, 5, 10, 20, and 25 optimization iterations and shows the resulting images for each setting, in that order. As the number of iterations grows, the generated images become more and more chaotic, until they converge to images with many textures and artificially high intrinsic dimensionality.

In [8]:
method = "flipd"
# Numbers of optimization iterations to compare, one row of images per entry.
optim_iters_schedule = [1, 5, 10, 20, 25]
# With method "flipd" the optimization is run with a single image per prompt;
# any other method uses the configured `num_images_per_prompt`.
optim_images_per_prompt = 1 if method == "flipd" else num_images_per_prompt

for num_steps in optim_iters_schedule:
    # Run `num_steps` optimization iterations on the prompt and get back the
    # resulting prompt embeddings.
    auged_prompt_embeds = pipe.aug_prompt(
        prompt,
        method=method,
        optim_iters=num_steps,
        lr=0.05,
        target_steps=[40],
        num_images_per_prompt=optim_images_per_prompt,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        print_optim=True,
    )

    # Reset the seed before sampling, using the same seed as the baseline
    # cell (presumably so the images are directly comparable).
    set_random_seed(seed)

    # Sample from the optimized embeddings instead of the text prompt.
    generation, track_stats = pipe(
        prompt_embeds=auged_prompt_embeds,
        height=image_size,
        width=image_size,
        num_images_per_prompt=num_images_per_prompt,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        track_noise_norm=True,
    )
    outputs = generation.images
    media.show_images(outputs, width=300)


 80%|████████  | 40/50 [00:01<00:00, 20.43it/s, loss=-7.79e+3]
100%|██████████| 50/50 [00:03<00:00, 13.34it/s]


 80%|████████  | 40/50 [00:04<00:01,  8.40it/s, loss=-5.45e+4]
100%|██████████| 50/50 [00:03<00:00, 13.10it/s]


 80%|████████  | 40/50 [00:08<00:02,  4.68it/s, loss=-1.19e+5]
100%|██████████| 50/50 [00:03<00:00, 12.73it/s]


 80%|████████  | 40/50 [00:15<00:03,  2.53it/s, loss=-4.43e+5]
100%|██████████| 50/50 [00:04<00:00, 12.35it/s]


 80%|████████  | 40/50 [00:19<00:04,  2.07it/s, loss=-7.19e+5]
100%|██████████| 50/50 [00:04<00:00, 12.15it/s]
