Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions src/diffusers/loaders/single_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,13 +63,20 @@ def build_sub_model_components(
num_in_channels=num_in_channels,
image_size=image_size,
torch_dtype=torch_dtype,
model_type=model_type,
)
return unet_components

if component_name == "vae":
scaling_factor = kwargs.get("scaling_factor", None)
vae_components = create_diffusers_vae_model_from_ldm(
pipeline_class_name, original_config, checkpoint, image_size, scaling_factor, torch_dtype
pipeline_class_name,
original_config,
checkpoint,
image_size,
scaling_factor,
torch_dtype,
model_type=model_type,
)
return vae_components

Expand Down Expand Up @@ -124,11 +131,12 @@ def build_sub_model_components(
def set_additional_components(
pipeline_class_name,
original_config,
checkpoint=None,
model_type=None,
):
components = {}
if pipeline_class_name in REFINER_PIPELINES:
model_type = infer_model_type(original_config, model_type=model_type)
model_type = infer_model_type(original_config, checkpoint=checkpoint, model_type=model_type)
is_refiner = model_type == "SDXL-Refiner"
components.update(
{
Expand Down
86 changes: 72 additions & 14 deletions src/diffusers/loaders/single_file_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
DDIMScheduler,
DDPMScheduler,
DPMSolverMultistepScheduler,
EDMDPMSolverMultistepScheduler,
EulerAncestralDiscreteScheduler,
EulerDiscreteScheduler,
HeunDiscreteScheduler,
Expand Down Expand Up @@ -175,6 +176,7 @@

LDM_VAE_KEY = "first_stage_model."
LDM_VAE_DEFAULT_SCALING_FACTOR = 0.18215
PLAYGROUND_VAE_SCALING_FACTOR = 0.5
LDM_UNET_KEY = "model.diffusion_model."
LDM_CONTROLNET_KEY = "control_model."
LDM_CLIP_PREFIX_TO_REMOVE = ["cond_stage_model.transformer.", "conditioner.embedders.0.transformer."]
Expand Down Expand Up @@ -305,7 +307,7 @@ def is_valid_url(url):
return original_config


def infer_model_type(original_config, model_type=None):
def infer_model_type(original_config, checkpoint=None, model_type=None):
if model_type is not None:
return model_type

Expand All @@ -323,7 +325,9 @@ def infer_model_type(original_config, model_type=None):

elif has_network_config:
context_dim = original_config["model"]["params"]["network_config"]["params"]["context_dim"]
if context_dim == 2048:
if "edm_mean" in checkpoint and "edm_std" in checkpoint:
model_type = "Playground"
elif context_dim == 2048:
model_type = "SDXL"
else:
model_type = "SDXL-Refiner"
Expand All @@ -344,13 +348,13 @@ def set_image_size(pipeline_class_name, original_config, checkpoint, image_size=
return image_size

global_step = checkpoint["global_step"] if "global_step" in checkpoint else None
model_type = infer_model_type(original_config, model_type)
model_type = infer_model_type(original_config, checkpoint, model_type)

if pipeline_class_name == "StableDiffusionUpscalePipeline":
image_size = original_config["model"]["params"]["unet_config"]["params"]["image_size"]
return image_size

elif model_type in ["SDXL", "SDXL-Refiner"]:
elif model_type in ["SDXL", "SDXL-Refiner", "Playground"]:
image_size = 1024
return image_size

Expand Down Expand Up @@ -506,12 +510,14 @@ def create_controlnet_diffusers_config(original_config, image_size: int):
return controlnet_config


def create_vae_diffusers_config(original_config, image_size, scaling_factor=None):
def create_vae_diffusers_config(original_config, image_size, scaling_factor=None, latents_mean=None, latents_std=None):
"""
Creates a config for the diffusers based on the config of the LDM model.
"""
vae_params = original_config["model"]["params"]["first_stage_config"]["params"]["ddconfig"]
if scaling_factor is None and "scale_factor" in original_config["model"]["params"]:
if (scaling_factor is None) and (latents_mean is not None) and (latents_std is not None):
scaling_factor = PLAYGROUND_VAE_SCALING_FACTOR
elif (scaling_factor is None) and ("scale_factor" in original_config["model"]["params"]):
scaling_factor = original_config["model"]["params"]["scale_factor"]
elif scaling_factor is None:
scaling_factor = LDM_VAE_DEFAULT_SCALING_FACTOR
Expand All @@ -531,6 +537,8 @@ def create_vae_diffusers_config(original_config, image_size, scaling_factor=None
"layers_per_block": vae_params["num_res_blocks"],
"scaling_factor": scaling_factor,
}
if latents_mean is not None and latents_std is not None:
config.update({"latents_mean": latents_mean, "latents_std": latents_std})

return config

Expand Down Expand Up @@ -1172,6 +1180,7 @@ def create_diffusers_unet_model_from_ldm(
extract_ema=False,
image_size=None,
torch_dtype=None,
model_type=None,
):
from ..models import UNet2DConditionModel

Expand All @@ -1190,7 +1199,9 @@ def create_diffusers_unet_model_from_ldm(
else:
num_in_channels = 4

image_size = set_image_size(pipeline_class_name, original_config, checkpoint, image_size=image_size)
image_size = set_image_size(
pipeline_class_name, original_config, checkpoint, image_size=image_size, model_type=model_type
)
unet_config = create_unet_diffusers_config(original_config, image_size=image_size)
unet_config["in_channels"] = num_in_channels
unet_config["upcast_attention"] = upcast_attention
Expand Down Expand Up @@ -1223,14 +1234,40 @@ def create_diffusers_unet_model_from_ldm(


def create_diffusers_vae_model_from_ldm(
pipeline_class_name, original_config, checkpoint, image_size=None, scaling_factor=None, torch_dtype=None
pipeline_class_name,
original_config,
checkpoint,
image_size=None,
scaling_factor=None,
torch_dtype=None,
model_type=None,
):
# import here to avoid circular imports
from ..models import AutoencoderKL

image_size = set_image_size(pipeline_class_name, original_config, checkpoint, image_size=image_size)
image_size = set_image_size(
pipeline_class_name, original_config, checkpoint, image_size=image_size, model_type=model_type
)
model_type = infer_model_type(original_config, checkpoint, model_type)

vae_config = create_vae_diffusers_config(original_config, image_size=image_size, scaling_factor=scaling_factor)
if model_type == "Playground":
edm_mean = (
checkpoint["edm_mean"].to(dtype=torch_dtype).tolist() if torch_dtype else checkpoint["edm_mean"].tolist()
)
edm_std = (
checkpoint["edm_std"].to(dtype=torch_dtype).tolist() if torch_dtype else checkpoint["edm_std"].tolist()
)
else:
edm_mean = None
edm_std = None

vae_config = create_vae_diffusers_config(
original_config,
image_size=image_size,
scaling_factor=scaling_factor,
latents_mean=edm_mean,
latents_std=edm_std,
)
diffusers_format_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config)
ctx = init_empty_weights if is_accelerate_available() else nullcontext

Expand Down Expand Up @@ -1265,7 +1302,7 @@ def create_text_encoders_and_tokenizers_from_ldm(
local_files_only=False,
torch_dtype=None,
):
model_type = infer_model_type(original_config, model_type=model_type)
model_type = infer_model_type(original_config, checkpoint=checkpoint, model_type=model_type)

if model_type == "FrozenOpenCLIPEmbedder":
config_name = "stabilityai/stable-diffusion-2"
Expand Down Expand Up @@ -1332,7 +1369,7 @@ def create_text_encoders_and_tokenizers_from_ldm(
"text_encoder_2": text_encoder_2,
}

elif model_type == "SDXL":
elif model_type in ["SDXL", "Playground"]:
try:
config_name = "openai/clip-vit-large-patch14"
tokenizer = CLIPTokenizer.from_pretrained(config_name, local_files_only=local_files_only)
Expand Down Expand Up @@ -1383,7 +1420,7 @@ def create_scheduler_from_ldm(
model_type=None,
):
scheduler_config = get_default_scheduler_config()
model_type = infer_model_type(original_config, model_type=model_type)
model_type = infer_model_type(original_config, checkpoint=checkpoint, model_type=model_type)

global_step = checkpoint["global_step"] if "global_step" in checkpoint else None

Expand All @@ -1406,7 +1443,8 @@ def create_scheduler_from_ldm(

if model_type in ["SDXL", "SDXL-Refiner"]:
scheduler_type = "euler"

elif model_type == "Playground":
scheduler_type = "edm_dpm_solver_multistep"
else:
beta_start = original_config["model"]["params"].get("linear_start", 0.02)
beta_end = original_config["model"]["params"].get("linear_end", 0.085)
Expand Down Expand Up @@ -1438,6 +1476,26 @@ def create_scheduler_from_ldm(
elif scheduler_type == "ddim":
scheduler = DDIMScheduler.from_config(scheduler_config)

elif scheduler_type == "edm_dpm_solver_multistep":
scheduler_config = {
"algorithm_type": "dpmsolver++",
"dynamic_thresholding_ratio": 0.995,
"euler_at_final": False,
"final_sigmas_type": "zero",
"lower_order_final": True,
"num_train_timesteps": 1000,
"prediction_type": "epsilon",
"rho": 7.0,
"sample_max_value": 1.0,
"sigma_data": 0.5,
"sigma_max": 80.0,
"sigma_min": 0.002,
"solver_order": 2,
"solver_type": "midpoint",
"thresholding": False,
}
scheduler = EDMDPMSolverMultistepScheduler(**scheduler_config)

else:
raise ValueError(f"Scheduler of type {scheduler_type} doesn't exist!")

Expand Down