Skip to content
18 changes: 15 additions & 3 deletions src/diffusers/loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,17 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict
revision = kwargs.pop("revision", None)
subfolder = kwargs.pop("subfolder", None)
weight_name = kwargs.pop("weight_name", None)
use_safetensors = kwargs.pop("use_safetensors", None)

if use_safetensors and not is_safetensors_available():
raise ValueError(
"`use_safetensors`=True but safetensors is not installed. Please install safetensors with `pip install safetenstors"
)

allow_pickle = False
if use_safetensors is None:
use_safetensors = is_safetensors_available()
allow_pickle = True

user_agent = {
"file_type": "attn_procs_weights",
Expand All @@ -151,7 +162,7 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict
model_file = None
if not isinstance(pretrained_model_name_or_path_or_dict, dict):
# Let's first try to load .safetensors weights
if (is_safetensors_available() and weight_name is None) or (
if (use_safetensors and weight_name is None) or (
weight_name is not None and weight_name.endswith(".safetensors")
):
try:
Expand All @@ -169,10 +180,11 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict
user_agent=user_agent,
)
state_dict = safetensors.torch.load_file(model_file, device="cpu")
except EnvironmentError:
except IOError as e:
if not allow_pickle:
raise e
Comment on lines -172 to +185
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@patrickvonplaten before you merge can you double check this error change is ok?

# try loading non-safetensors weights
pass

if model_file is None:
model_file = _get_model_file(
pretrained_model_name_or_path_or_dict,
Expand Down
21 changes: 19 additions & 2 deletions src/diffusers/models/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,6 +392,10 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
variant (`str`, *optional*):
If specified load weights from `variant` filename, *e.g.* pytorch_model.<variant>.bin. `variant` is
ignored when using `from_flax`.
use_safetensors (`bool`, *optional* ):
If set to `True`, the pipeline will forcibly load the models from `safetensors` weights. If set to
`None` (the default). The pipeline will load using `safetensors` if safetensors weights are available
*and* if `safetensors` is installed. If the to `False` the pipeline will *not* use `safetensors`.
Comment on lines +395 to +398
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Cool!


<Tip>

Expand Down Expand Up @@ -423,6 +427,17 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
device_map = kwargs.pop("device_map", None)
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
variant = kwargs.pop("variant", None)
use_safetensors = kwargs.pop("use_safetensors", None)

if use_safetensors and not is_safetensors_available():
raise ValueError(
"`use_safetensors`=True but safetensors is not installed. Please install safetensors with `pip install safetenstors"
)

allow_pickle = False
if use_safetensors is None:
use_safetensors = is_safetensors_available()
allow_pickle = True

if low_cpu_mem_usage and not is_accelerate_available():
low_cpu_mem_usage = False
Expand Down Expand Up @@ -509,7 +524,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P

model = load_flax_checkpoint_in_pytorch_model(model, model_file)
else:
if is_safetensors_available():
if use_safetensors:
try:
model_file = _get_model_file(
pretrained_model_name_or_path,
Expand All @@ -525,7 +540,9 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
user_agent=user_agent,
commit_hash=commit_hash,
)
except: # noqa: E722
except IOError as e:
if not allow_pickle:
raise e
pass
if model_file is None:
model_file = _get_model_file(
Expand Down
26 changes: 25 additions & 1 deletion src/diffusers/pipelines/pipeline_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -694,6 +694,10 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the
model. This is only supported when torch version >= 1.9.0. If you are using an older version of torch,
setting this argument to `True` will raise an error.
use_safetensors (`bool`, *optional* ):
If set to `True`, the pipeline will be loaded from `safetensors` weights. If set to `None` (the
default). The pipeline will load using `safetensors` if the safetensors weights are available *and* if
`safetensors` is installed. If the to `False` the pipeline will *not* use `safetensors`.
kwargs (remaining dictionary of keyword arguments, *optional*):
Can be used to overwrite load - and saveable variables - *i.e.* the pipeline components - of the
specific pipeline class. The overwritten components are then directly passed to the pipelines
Expand Down Expand Up @@ -752,6 +756,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
device_map = kwargs.pop("device_map", None)
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
variant = kwargs.pop("variant", None)
kwargs.pop("use_safetensors", None if is_safetensors_available() else False)

# 1. Download the checkpoints and configs
# use snapshot download here to get it working from from_pretrained
Expand Down Expand Up @@ -1068,6 +1073,17 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
from_flax = kwargs.pop("from_flax", False)
custom_pipeline = kwargs.pop("custom_pipeline", None)
variant = kwargs.pop("variant", None)
use_safetensors = kwargs.pop("use_safetensors", None)

if use_safetensors and not is_safetensors_available():
raise ValueError(
"`use_safetensors`=True but safetensors is not installed. Please install safetensors with `pip install safetenstors"
)

allow_pickle = False
if use_safetensors is None:
use_safetensors = is_safetensors_available()
allow_pickle = True

pipeline_is_cached = False
allow_patterns = None
Expand Down Expand Up @@ -1123,9 +1139,17 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
CUSTOM_PIPELINE_FILE_NAME,
]

if (
use_safetensors
and not allow_pickle
and not is_safetensors_compatible(model_filenames, variant=variant)
):
raise EnvironmentError(
f"Could not found the necessary `safetensors` weights in {model_filenames} (variant={variant})"
)
if from_flax:
ignore_patterns = ["*.bin", "*.safetensors", "*.onnx", "*.pb"]
elif is_safetensors_available() and is_safetensors_compatible(model_filenames, variant=variant):
elif use_safetensors and is_safetensors_compatible(model_filenames, variant=variant):
ignore_patterns = ["*.bin", "*.msgpack"]

safetensors_variant_filenames = set([f for f in variant_filenames if f.endswith(".safetensors")])
Expand Down
39 changes: 38 additions & 1 deletion tests/models/test_models_unet_2d_condition.py
Original file line number Diff line number Diff line change
Expand Up @@ -440,7 +440,7 @@ def test_lora_save_load_safetensors(self):
# LoRA and no LoRA should NOT be the same
assert (sample - old_sample).abs().max() > 1e-4

def test_lora_save_load_safetensors_load_torch(self):
def test_lora_save_safetensors_load_torch(self):
# enable deterministic behavior for gradient checkpointing
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

Expand Down Expand Up @@ -475,6 +475,43 @@ def test_lora_save_load_safetensors_load_torch(self):
new_model.to(torch_device)
new_model.load_attn_procs(tmpdirname, weight_name="pytorch_lora_weights.bin")

def test_lora_save_torch_force_load_safetensors_error(self):
# enable deterministic behavior for gradient checkpointing
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

init_dict["attention_head_dim"] = (8, 16)

torch.manual_seed(0)
model = self.model_class(**init_dict)
model.to(torch_device)

lora_attn_procs = {}
for name in model.attn_processors.keys():
cross_attention_dim = None if name.endswith("attn1.processor") else model.config.cross_attention_dim
if name.startswith("mid_block"):
hidden_size = model.config.block_out_channels[-1]
elif name.startswith("up_blocks"):
block_id = int(name[len("up_blocks.")])
hidden_size = list(reversed(model.config.block_out_channels))[block_id]
elif name.startswith("down_blocks"):
block_id = int(name[len("down_blocks.")])
hidden_size = model.config.block_out_channels[block_id]

lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)
lora_attn_procs[name] = lora_attn_procs[name].to(model.device)

model.set_attn_processor(lora_attn_procs)
# Saving as torch, properly reloads with directly filename
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_attn_procs(tmpdirname)
self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")))
torch.manual_seed(0)
new_model = self.model_class(**init_dict)
new_model.to(torch_device)
with self.assertRaises(IOError) as e:
new_model.load_attn_procs(tmpdirname, use_safetensors=True)
self.assertIn("Error no file named pytorch_lora_weights.safetensors", str(e.exception))

def test_lora_on_off(self):
# enable deterministic behavior for gradient checkpointing
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
Expand Down
11 changes: 11 additions & 0 deletions tests/test_pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,17 @@ def test_download_only_pytorch(self):
# We need to never convert this tiny model to safetensors for this test to pass
assert not any(f.endswith(".safetensors") for f in files)

def test_force_safetensors_error(self):
with tempfile.TemporaryDirectory() as tmpdirname:
# pipeline has Flax weights
with self.assertRaises(EnvironmentError):
tmpdirname = DiffusionPipeline.download(
"hf-internal-testing/tiny-stable-diffusion-pipe-no-safetensors",
safety_checker=None,
cache_dir=tmpdirname,
use_safetensors=True,
)

def test_returned_cached_folder(self):
prompt = "hello"
pipe = StableDiffusionPipeline.from_pretrained(
Expand Down