In [None]:
import torch
import random
import os
import shutil
from tqdm import tqdm
from diffusers import StableDiffusionXLPipeline

# GPU that the diffusion pipeline is moved to; it is also made the current
# CUDA device at the end of this cell.
DEVICE = "cuda:0"

# How many times the whole prompt list is rendered (one fresh seed per pass).
repeat_times = 30

# Root output folder; it is wiped and recreated by the generation cell.
save_dir = "regular_teapot"

# Text query handed to the segmentation model to locate the object.
object_name = "teapot"

# Prompts rendered on every pass; each prompt gets its own sub-folder.
regular_prompts_list = [
    ...
]

torch.cuda.set_device(DEVICE)

In [None]:
# Load the SDXL base 1.0 checkpoint from its fp16 safetensors weights.
pipeline = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    variant="fp16",
    use_safetensors=True,
    torch_dtype=torch.float16,
)
# Move the whole pipeline onto the configured GPU.
pipeline = pipeline.to(DEVICE)
# Turn off the pipeline's own per-call progress bar.
pipeline.set_progress_bar_config(disable=True)

In [None]:
# Segmentation is done with lang-sam. Install it into a Python 3.10 environment:
# git clone https://github.com/mycfhs/lang-segment-anything && cd lang-segment-anything
# python -m pip install -e . --ignore-installed
from lang_sam import LangSAM

# Segmenter used by the generation cell to predict masks from a text query.
# Author's note on sam_type choices: b, l, h (presumably vit_b / vit_l / vit_h).
model = LangSAM(sam_type="vit_h")  # b, l, h

In [None]:
from torchvision.transforms import ToPILImage
import gc

to_pil_image = ToPILImage()


def _free_gpu_memory():
    """Run the garbage collector and release cached CUDA memory, if CUDA is available."""
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()


# Start from a clean output tree: any previous content of save_dir is deleted.
if os.path.exists(save_dir):
    shutil.rmtree(save_dir)
os.makedirs(save_dir)

# One sub-folder per prompt, named after the prompt with spaces replaced by "_".
# Computed once so the directory-creation loop and the save loop cannot disagree.
prompt_dirs = [
    f"{save_dir}/{prompt.replace(' ', '_')}" for prompt in regular_prompts_list
]
for prompt_dir in prompt_dirs:
    os.makedirs(prompt_dir)

for _ in tqdm(range(repeat_times)):
    random_seed = random.randint(0, 1000000)

    # The pipeline draws its noise from `generator`; it has no `seed` argument.
    # Seeding the generator makes the seed in the file names meaningful:
    # re-running the same prompt list with this seed reproduces the batch.
    generator = torch.Generator(device=DEVICE).manual_seed(random_seed)
    images = pipeline(regular_prompts_list, generator=generator).images

    # Release generation-time memory before running the segmentation model.
    _free_gpu_memory()

    for image, prompt_dir in zip(images, prompt_dirs):
        masks, boxes, phrases, logits = model.predict(image, object_name)
        # Assumes `masks` is a tensor of per-detection masks with 0/1 values,
        # so the scaling yields a 0/255 image -- TODO confirm against lang-sam.
        mask = masks.to(torch.uint8) * 255

        try:
            # Only the first mask is kept. Indexing fails when no mask was
            # returned; that sample is then skipped without writing any file.
            mask_img = to_pil_image(mask[0])
            mask_img.save(f"{prompt_dir}/{random_seed}-mask.png")
            image.save(f"{prompt_dir}/{random_seed}-image.png")
        except Exception as exc:
            # Best-effort: report the cause and move on to the next image.
            # `Exception` (not a bare except) lets a kernel interrupt stop the run.
            print(f"Error img, ignore ({prompt_dir}, seed {random_seed}): {exc!r}")

    _free_gpu_memory()