In [None]:
import os
import torch

from pipeline_apldm_sdxl import RepLDMSDXLPipeline


# ---------------------------------------------------------------------------
# Run configuration. Every name below is read by the __main__ section.
# ---------------------------------------------------------------------------

# Hardware: CUDA device index; None selects the CPU.
gpu_id = 0

# Reproducibility: seeds torch's global RNG, from which the per-image seeds
# are drawn.
random_seed = 523

# Output: one batch of images is generated per (height, width) entry.
image_sizes = [(3072, 3072)]
num_images_per_prompt = 1
show_image = True   # forwarded to the pipeline call
save_image = True   # write the last returned image of each call to disk

# Text conditioning.
prompt = (
    'A Renaissance noblewoman, portrayed in an elegant gown with intricate embroidery. '
    'Her expression is thoughtful, and her eyes are deep and insightful. '
    'The background is a lush Italian garden, reflecting the artistic style of the High Renaissance.'
)

# Sampling schedule (both values are forwarded to the pipeline call).
num_inference_steps = 50
num_resample_timesteps = 50

# Pipeline-specific knobs, forwarded unchanged; their exact semantics are
# defined by the pipeline. init_rates and attn_guidance_scale also appear in
# the output directory name.
init_rates = [0.8, 0.8]
attn_guidance_scale = 0.005
# 50 on/off flags -- presumably one per inference step (TODO confirm against
# the pipeline). Alternative setting: [1]*47 + [0]*3
attn_guidance_density = [0] * 31 + [1] * 15 + [0] * 4
# Alternative setting: ('cosine', 0, 3)
attn_guidance_decay = None

# Pipeline toggles, forwarded unchanged to the pipeline call.
multi_encoder = True
multi_decoder = True
models_to_cpu = True


def resolve_device(gpu_id, cuda_available):
    """Return the torch device string for this run.

    Args:
        gpu_id: CUDA device index, or ``None`` to run on the CPU.
        cuda_available: Result of ``torch.cuda.is_available()``.

    Returns:
        ``'cpu'`` when ``gpu_id`` is ``None``, otherwise ``f'cuda:{gpu_id}'``.

    Raises:
        RuntimeError: If a GPU was requested but CUDA is not available.
    """
    if gpu_id is None:
        return 'cpu'
    if not cuda_available:
        raise RuntimeError(
            f'gpu_id={gpu_id} was requested but CUDA is not available; '
            'set gpu_id = None to run on the CPU.'
        )
    return f'cuda:{gpu_id}'


def build_save_dir(init_rates, attn_guidance_scale, image_size):
    """Return the output directory for one configuration and image size.

    Args:
        init_rates: Value embedded verbatim (via ``str``) in the path.
        attn_guidance_scale: Value embedded verbatim (via ``str``) in the path.
        image_size: Size entry whose ``str`` form names the leaf directory.

    Returns:
        ``./cases/init_rates_<init_rates>/guidance_scale_<scale>/<image_size>``.
    """
    base_dir = f'./cases/init_rates_{init_rates}/guidance_scale_{attn_guidance_scale}'
    return os.path.join(base_dir, str(image_size))


if __name__ == '__main__':
    cuda_available = torch.cuda.is_available()
    print(f'CUDA available: {cuda_available}')
    # Validate the device before the checkpoint is loaded, so a missing GPU is
    # reported immediately instead of failing later inside pipe.to().
    device = resolve_device(gpu_id, cuda_available)

    negative_prompt = "blurry, ugly, duplicate, poorly drawn, deformed, mosaic"

    model_ckpt = "stabilityai/stable-diffusion-xl-base-1.0"
    # local_files_only=True: the fp16 weights must already be present under
    # cache_dir; nothing is downloaded.
    pipe = RepLDMSDXLPipeline.from_pretrained(
        model_ckpt,
        torch_dtype=torch.float16,
        variant="fp16",
        cache_dir='../huggingface_models',
        local_files_only=True,
    )
    pipe = pipe.to(device)
    # The pipeline is handed an immutable tuple rather than the config list.
    attn_guidance_density = tuple(attn_guidance_density)

    with torch.no_grad():
        for image_size in image_sizes:
            # Re-seed the global RNG for every size so each size draws the
            # same sequence of per-image seeds.
            torch.manual_seed(random_seed)
            generator = torch.Generator(device)
            save_dir = build_save_dir(init_rates, attn_guidance_scale, image_size)
            if save_image:
                # Only create the directory tree when something is written.
                os.makedirs(save_dir, exist_ok=True)
            for idx2 in range(num_images_per_prompt):
                torch.cuda.empty_cache()
                seed = torch.randint(0, 10000, (1,)).item()
                generator = generator.manual_seed(seed)
                images = pipe(
                    prompt,
                    negative_prompt=negative_prompt,
                    generator=generator,
                    height=image_size[0],
                    width=image_size[1],
                    num_inference_steps=num_inference_steps,
                    guidance_scale=7.5,
                    show_image=show_image,
                    models_to_cpu=models_to_cpu,
                    multi_decoder=multi_decoder,
                    multi_encoder=multi_encoder,
                    num_resample_timesteps=num_resample_timesteps,
                    init_rates=init_rates,
                    attn_type='vanilla',
                    attn_guidance_scale=attn_guidance_scale,
                    attn_guidance_density=attn_guidance_density,
                    attn_guidance_decay=attn_guidance_decay,
                    power_calibrate=0,
                    attn_guidance_filter=None,
                )
                if save_image:
                    # Only the last element of the returned sequence is saved.
                    file_name = f'{idx2}.png'
                    images[-1].save(os.path.join(save_dir, file_name))
                    del images
    print('END')