In [2]:
import torch
import random
import os
from diffusers import DiffusionPipeline
from diffusers import StableDiffusionPipeline, EulerDiscreteScheduler
import torch
from pipeline_rf import RectifiedFlowPipeline
from diffusers import AutoPipelineForText2Image
from dataclasses import dataclass
from generate import generate_single_image, load_model

# config

In [3]:
from dataclasses import field
from typing import List


@dataclass
class TestConfig:
    """Settings shared by the sampling cells of this notebook.

    Every attribute is a real dataclass field, so a run can override any of
    them at construction time, e.g. ``TestConfig(sd_size=512,
    start_sample_num=0, end_sapmle_num=1000)``. List-valued settings use
    ``default_factory`` so each instance owns its own list.

    Raises:
        ValueError: if ``start_sample_num`` is greater than ``end_sapmle_num``.
    """

    # Device string passed to ``pipe.to()`` and ``load_model``.
    device: str = "cuda"
    # Root output directory; each "<item><step>" pair writes into a sub-folder.
    save_path: str = "/home/liutao/workspace/data/diversity/"
    # Local checkpoint locations of the baseline models.
    sd_id: str = "/data/model/stable-diffusion-2-1"
    instaflow_id: str = "/data/model/instaflow_0_9B_from_sd_1_5"
    lcm_id: str = "/data/model/LCM_Dreamshaper_v7"
    sdxl_turbo_id: str = "/data/model/sdxl-turbo"
    # Arguments forwarded to ``load_model`` for our own model.
    ours_path: str = "/data/"
    ours_model_id: str = "/data/20231212/SwiftBrush_reproduce_se_parallel/checkpoints_20240228/vsd_global_step9000_4nis.pth"
    # Prompts to sample; each prompt string is also used in the output file names.
    test_prompt: List[str] = field(default_factory=lambda: [
        "A small waterfall in the middle of rocks, an airbrush painting",
        "A oil painting of red roses in a blue vase",
        "A brown and white dog running through water",
        "Reflection of the glass cube building and water surface",
        "A butterfly on flowers",
        "Milk and a sandwich with knife on a table",
        "A cute cat",
        "A cute dog",
    ])
    # Baselines to sample, in order.
    items: List[str] = field(default_factory=lambda: ["sd", "lcm", "sdxl_turbo", "instaflow"])
    # Inference-step counts sampled for each model.
    sd_steps: List[int] = field(default_factory=lambda: [50])
    instaflow_steps: List[int] = field(default_factory=lambda: [1])
    lcm_steps: List[int] = field(default_factory=lambda: [1, 4])
    sdxl_turbo_steps: List[int] = field(default_factory=lambda: [1, 4])
    ours_steps: List[int] = field(default_factory=lambda: [4])
    # Output height == width for Stable Diffusion and for the other baselines.
    sd_size: int = 768
    distill_size: int = 512
    # Not read by the sampling cells shown here; kept for compatibility.
    sample_nums: int = 1000
    # Half-open range [start_sample_num, end_sapmle_num) of sample indices.
    # The index is part of every file name. ``end_sapmle_num`` keeps its
    # historical spelling because the sampling cells read it by that name.
    start_sample_num: int = 1000
    end_sapmle_num: int = 2000
    # NOTE(review): passed to every baseline pipeline. diffusers' Stable
    # Diffusion pipeline disables classifier-free guidance when this is <= 1,
    # so SD 2.1 at 50 steps runs unguided — confirm this is intended.
    guidance_scale: float = 0.0
    # Image format is chosen by PIL from this extension.
    file_extension: str = ".jpg"

    def __post_init__(self) -> None:
        # Catch swapped bounds early instead of silently producing no images.
        if self.start_sample_num > self.end_sapmle_num:
            raise ValueError(
                f"start_sample_num ({self.start_sample_num}) must not exceed "
                f"end_sapmle_num ({self.end_sapmle_num})"
            )

    @property
    def end_sample_num(self) -> int:
        """Correctly spelled alias of ``end_sapmle_num``."""
        return self.end_sapmle_num


config = TestConfig()

# Baselines: sample every model listed in `config.items`

In [5]:
# For every baseline in config.items and every configured step count, generate
# one image per (prompt, sample index) and save it as
# <save_path>/<item><step>/<prompt>_<index><file_extension>.
for item in config.items:
    # Select the pipeline, step schedule and resolution for this baseline.
    if item == "sd":
        steps = config.sd_steps
        scheduler = EulerDiscreteScheduler.from_pretrained(config.sd_id, subfolder="scheduler")
        pipe = StableDiffusionPipeline.from_pretrained(config.sd_id, scheduler=scheduler, torch_dtype=torch.float16)
        height = width = config.sd_size
    elif item == "lcm":
        steps = config.lcm_steps
        pipe = DiffusionPipeline.from_pretrained(config.lcm_id, safety_checker=None, requires_safety_checker=False, torch_dtype=torch.float16)
        height = width = config.distill_size
    elif item == "sdxl_turbo":
        steps = config.sdxl_turbo_steps
        pipe = AutoPipelineForText2Image.from_pretrained(config.sdxl_turbo_id, torch_dtype=torch.float16, variant="fp16")
        height = width = config.distill_size
    elif item == "instaflow":
        steps = config.instaflow_steps
        pipe = RectifiedFlowPipeline.from_pretrained(config.instaflow_id, torch_dtype=torch.float16, safety_checker=None, requires_safety_checker=False)
        height = width = config.distill_size
    else:
        # Without this, an unknown item would silently reuse the previous
        # iteration's pipeline/steps/size (or hit a NameError on the first one).
        raise ValueError(f"Unknown baseline item: {item!r}")

    pipe = pipe.to(config.device)
    pipe.set_progress_bar_config(disable=True)

    for step in steps:
        print(f"{item}:{step}")
        directory = os.path.join(config.save_path, f"{item}{step}")
        os.makedirs(directory, exist_ok=True)

        for prompt in config.test_prompt:
            # A "/" in a prompt would otherwise be read as a sub-directory.
            safe_prompt = prompt.replace("/", "_")
            for i in range(config.start_sample_num, config.end_sapmle_num):
                # Seed from the sample index (seeds all devices) so each image
                # is reproducible and re-running a chunk regenerates it exactly.
                torch.manual_seed(i)
                # NOTE(review): one guidance_scale is shared by all baselines;
                # diffusers' SD pipeline disables classifier-free guidance when
                # it is <= 1 — confirm that is intended for "sd".
                image = pipe(prompt=prompt, num_inference_steps=step, guidance_scale=config.guidance_scale, height=height, width=width).images[0]
                image.save(os.path.join(directory, f"{safe_prompt}_{i}{config.file_extension}"))
            print(prompt)

    # Drop the last reference to the pipeline first; otherwise empty_cache()
    # cannot return its weights' memory to the device.
    del pipe
    torch.cuda.empty_cache()
        

Loading pipeline components...: 100%|██████████| 6/6 [00:00<00:00,  8.99it/s]


sd:50


Loading pipeline components...: 100%|██████████| 6/6 [00:00<00:00, 11.89it/s]


lcm:1
lcm:4


Loading pipeline components...: 100%|██████████| 7/7 [00:00<00:00, 15.24it/s]


sdxl_turbo:1
sdxl_turbo:4


Loading pipeline components...: 100%|██████████| 6/6 [00:00<00:00, 11.15it/s]


instaflow:1


# Ours: sample our model for each step count in `config.ours_steps`

In [None]:
# Generate images from our model with the same prompt / sample-index layout as
# the baselines: <save_path>/ours<step>/<prompt>_<index><file_extension>.
item = "ours"
steps = config.ours_steps
# `alphas` is returned by load_model but not used for sampling in this cell.
vae, tokenizer, text_encoder, unet, scheduler, alphas = load_model(config.ours_path, config.ours_model_id, config.device)
# The components generate_single_image expects; loop-invariant, so built once.
network = (vae, tokenizer, text_encoder, unet, scheduler)

for step in steps:
    print(f"{item}:{step}")
    directory = os.path.join(config.save_path, f"{item}{step}")
    os.makedirs(directory, exist_ok=True)

    for prompt in config.test_prompt:
        # A "/" in a prompt would otherwise be read as a sub-directory.
        safe_prompt = prompt.replace("/", "_")
        for i in range(config.start_sample_num, config.end_sapmle_num):
            # Integer upper bound: random.randint rejects float arguments such
            # as 1e11 on Python 3.12+ (same range as before: 1 .. 10**11).
            seed = random.randint(1, 10**11)
            image = generate_single_image(network=network, prompt=prompt, seed=seed, num_inference_steps=step)
            image.save(os.path.join(directory, f"{safe_prompt}_{i}{config.file_extension}"))
        print(prompt)

torch.cuda.empty_cache()