Security patch: Scan all pickle files, including VAEs; default to safetensor loading #3011

Merged 3 commits on Mar 24, 2023
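At a high level, the patch makes every legacy pickle-format checkpoint (.ckpt/.pt) pass through picklescan before it is unpickled, and loads .safetensors files directly, since that format cannot embed executable pickle code. Below is a minimal sketch of that loading policy, assuming the picklescan and safetensors packages are installed; the helper name `safe_load_checkpoint` is illustrative and not part of the codebase.

```python
# Minimal sketch of the loading policy this PR enforces (illustrative only;
# `safe_load_checkpoint` is not a real function in the InvokeAI codebase).
from pathlib import Path

import torch
import safetensors.torch
from picklescan.scanner import scan_file_path


def safe_load_checkpoint(path: str) -> dict:
    p = Path(path)
    if p.suffix == ".safetensors":
        # safetensors files cannot carry arbitrary pickled code, so no scan is needed
        return safetensors.torch.load_file(p)
    # legacy pickle-based formats (.ckpt / .pt) are scanned before unpickling
    scan_result = scan_file_path(str(p))
    if scan_result.infected_files != 0:
        raise RuntimeError(f"{p} contains malicious pickle code; refusing to load")
    return torch.load(p, map_location="cpu")
```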
20 changes: 13 additions & 7 deletions ldm/invoke/ckpt_to_diffuser.py
@@ -327,18 +327,18 @@ def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False
unet_key = "model.diffusion_model."
# at least a 100 parameters have to start with `model_ema` in order for the checkpoint to be EMA
if sum(k.startswith("model_ema") for k in keys) > 100:
print(f" | Checkpoint {path} has both EMA and non-EMA weights.")
print(f" | Checkpoint {path} has both EMA and non-EMA weights.")
if extract_ema:
print(
' | Extracting EMA weights (usually better for inference)'
' | Extracting EMA weights (usually better for inference)'
)
for key in keys:
if key.startswith("model.diffusion_model"):
flat_ema_key = "model_ema." + "".join(key.split(".")[1:])
unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(flat_ema_key)
else:
print(
' | Extracting only the non-EMA weights (usually better for fine-tuning)'
' | Extracting only the non-EMA weights (usually better for fine-tuning)'
)

for key in keys:
@@ -809,6 +809,7 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
vae:AutoencoderKL=None,
precision:torch.dtype=torch.float32,
return_generator_pipeline:bool=False,
scan_needed:bool=True,
)->Union[StableDiffusionPipeline,StableDiffusionGeneratorPipeline]:
'''
Load a Stable Diffusion pipeline object from a CompVis-style `.ckpt`/`.safetensors` file and (ideally) a `.yaml`
@@ -843,15 +844,20 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
verbosity = dlogging.get_verbosity()
dlogging.set_verbosity_error()

checkpoint = load_file(checkpoint_path) if Path(checkpoint_path).suffix == '.safetensors' else torch.load(checkpoint_path)
if Path(checkpoint_path).suffix == '.ckpt':
if scan_needed:
ModelManager.scan_model(checkpoint_path,checkpoint_path)
checkpoint = torch.load(checkpoint_path)
else:
checkpoint = load_file(checkpoint_path)
cache_dir = global_cache_dir('hub')
pipeline_class = StableDiffusionGeneratorPipeline if return_generator_pipeline else StableDiffusionPipeline

# Sometimes models don't have the global_step item
if "global_step" in checkpoint:
global_step = checkpoint["global_step"]
else:
print(" | global_step key not found in model")
print(" | global_step key not found in model")
global_step = None

# sometimes there is a state_dict key and sometimes not
@@ -953,14 +959,14 @@ def load_pipeline_from_original_stable_diffusion_ckpt(

# Convert the VAE model, or use the one passed
if not vae:
print(' | Using checkpoint model\'s original VAE')
print(' | Using checkpoint model\'s original VAE')
vae_config = create_vae_diffusers_config(original_config, image_size=image_size)
converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config)

vae = AutoencoderKL(**vae_config)
vae.load_state_dict(converted_vae_checkpoint)
else:
print(' | Using external VAE specified in config')
print(' | Using VAE specified in config')

# Convert the text model.
model_type = pipeline_type
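In model_manager.py below, heuristic_import scans a legacy checkpoint once while probing it and then passes scan_needed=False through convert_and_import, so load_pipeline_from_original_stable_diffusion_ckpt does not rescan the same file; scan_model also becomes a @classmethod so that ckpt_to_diffuser.py can call it on the class. The sketch below is heavily abbreviated: the real methods take many more arguments, and the argument names are assumed from the diff.

```python
# Abbreviated sketch of how the new scan_needed flag threads through conversion;
# the real methods in ldm/invoke/model_manager.py take many more arguments.
from pathlib import Path

import torch
import safetensors.torch
from ldm.invoke.ckpt_to_diffuser import load_pipeline_from_original_stable_diffusion_ckpt
from ldm.invoke.model_manager import ModelManager


def import_and_convert_sketch(model_path: Path):
    if model_path.suffix in (".ckpt", ".pt"):
        # scanned once, here; scan_model is now a classmethod, so no instance is needed
        ModelManager.scan_model(model_path, model_path)
        checkpoint = torch.load(model_path, map_location="cpu")
    else:
        checkpoint = safetensors.torch.load_file(model_path)
    # ... probing of `checkpoint` to pick a config file is elided ...
    # scan_needed=False tells the converter the file was already checked,
    # so load_pipeline_from_original_stable_diffusion_ckpt skips the rescan
    return load_pipeline_from_original_stable_diffusion_ckpt(
        model_path, scan_needed=False
    )
```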
100 changes: 49 additions & 51 deletions ldm/invoke/model_manager.py
@@ -282,13 +282,13 @@ def del_model(self, model_name: str, delete_files: bool = False) -> None:
self.stack.remove(model_name)
if delete_files:
if weights:
print(f"** deleting file {weights}")
print(f"** Deleting file {weights}")
Path(weights).unlink(missing_ok=True)
elif path:
print(f"** deleting directory {path}")
print(f"** Deleting directory {path}")
rmtree(path, ignore_errors=True)
elif repo_id:
print(f"** deleting the cached model directory for {repo_id}")
print(f"** Deleting the cached model directory for {repo_id}")
self._delete_model_from_cache(repo_id)

def add_model(
@@ -420,11 +420,6 @@ def _load_ckpt_model(self, model_name, mconfig):
"NOHASH",
)

# scan model
self.scan_model(model_name, weights)

print(f">> Loading {model_name} from {weights}")

# for usage statistics
if self._has_cuda():
torch.cuda.reset_peak_memory_stats()
@@ -438,10 +433,13 @@ def _load_ckpt_model(self, model_name, mconfig):
weight_bytes = f.read()
model_hash = self._cached_sha256(weights, weight_bytes)
sd = None
if weights.endswith(".safetensors"):
sd = safetensors.torch.load(weight_bytes)
else:

if weights.endswith(".ckpt"):
self.scan_model(model_name, weights)
sd = torch.load(io.BytesIO(weight_bytes), map_location="cpu")
else:
sd = safetensors.torch.load(weight_bytes)

del weight_bytes
# merged models from auto11 merge board are flat for some reason
if "state_dict" in sd:
@@ -464,18 +462,12 @@ def _load_ckpt_model(self, model_name, mconfig):
vae = os.path.normpath(os.path.join(Globals.root, vae))
if os.path.exists(vae):
print(f" | Loading VAE weights from: {vae}")
vae_ckpt = None
vae_dict = None
if vae.endswith(".safetensors"):
vae_ckpt = safetensors.torch.load_file(vae)
vae_dict = {k: v for k, v in vae_ckpt.items() if k[0:4] != "loss"}
else:
if vae.endswith((".ckpt",".pt")):
self.scan_model(vae,vae)
vae_ckpt = torch.load(vae, map_location="cpu")
vae_dict = {
k: v
for k, v in vae_ckpt["state_dict"].items()
if k[0:4] != "loss"
}
else:
vae_ckpt = safetensors.torch.load_file(vae)
vae_dict = {k: v for k, v in vae_ckpt.items() if k[0:4] != "loss"}
model.first_stage_model.load_state_dict(vae_dict, strict=False)
else:
print(f" | VAE file {vae} not found. Skipping.")
@@ -497,9 +489,9 @@ def _load_diffusers_model(self, mconfig):

print(f">> Loading diffusers model from {name_or_path}")
if using_fp16:
print(" | Using faster float16 precision")
print(" | Using faster float16 precision")
else:
print(" | Using more accurate float32 precision")
print(" | Using more accurate float32 precision")

# TODO: scan weights maybe?
pipeline_args: dict[str, Any] = dict(
@@ -551,7 +543,7 @@ def _load_diffusers_model(self, mconfig):
width = pipeline.unet.config.sample_size * pipeline.vae_scale_factor
height = width

print(f" | Default image dimensions = {width} x {height}")
print(f" | Default image dimensions = {width} x {height}")

return pipeline, width, height, model_hash

@@ -591,13 +583,14 @@ def offload_model(self, model_name: str) -> None:
if self._has_cuda():
torch.cuda.empty_cache()

@classmethod
def scan_model(self, model_name, checkpoint):
"""
Apply picklescanner to the indicated checkpoint and issue a warning
and option to exit if an infected file is identified.
"""
# scan model
print(f">> Scanning Model: {model_name}")
print(f" | Scanning Model: {model_name}")
scan_result = scan_file_path(checkpoint)
if scan_result.infected_files != 0:
if scan_result.infected_files == 1:
@@ -620,7 +613,7 @@ def scan_model(self, model_name, checkpoint):
print("### Exiting InvokeAI")
sys.exit()
else:
print(">> Model scanned ok")
print(" | Model scanned ok")

def import_diffuser_model(
self,
@@ -800,19 +793,20 @@ def heuristic_import(
print(f">> Probing {thing} for import")

if thing.startswith(("http:", "https:", "ftp:")):
print(f" | {thing} appears to be a URL")
print(f" | {thing} appears to be a URL")
model_path = self._resolve_path(
thing, "models/ldm/stable-diffusion-v1"
) # _resolve_path does a download if needed
is_temporary = True

elif Path(thing).is_file() and thing.endswith((".ckpt", ".safetensors")):
if Path(thing).stem in ["model", "diffusion_pytorch_model"]:
print(
f" | {Path(thing).name} appears to be part of a diffusers model. Skipping import"
f" | {Path(thing).name} appears to be part of a diffusers model. Skipping import"
)
return
else:
print(f" | {thing} appears to be a checkpoint file on disk")
print(f" | {thing} appears to be a checkpoint file on disk")
model_path = self._resolve_path(thing, "models/ldm/stable-diffusion-v1")

elif Path(thing).is_dir() and Path(thing, "model_index.json").exists():
@@ -869,11 +863,12 @@ def heuristic_import(
return model_path.stem

# another round of heuristics to guess the correct config file.
checkpoint = (
safetensors.torch.load_file(model_path)
if model_path.suffix == ".safetensors"
else torch.load(model_path)
)
checkpoint = None
if model_path.suffix.endswith((".ckpt",".pt")):
self.scan_model(model_path,model_path)
checkpoint = torch.load(model_path)
else:
checkpoint = safetensors.torch.load_file(model_path)
# additional probing needed if no config file provided
if model_config_file is None:
model_type = self.probe_model_type(checkpoint)
@@ -918,7 +913,7 @@ def heuristic_import(
if model_config_file.name.startswith('v2'):
convert = True
print(
" | This SD-v2 model will be converted to diffusers format for use"
" | This SD-v2 model will be converted to diffusers format for use"
)

if convert:
@@ -933,6 +928,7 @@
model_description=description,
original_config_file=model_config_file,
commit_to_conf=commit_to_conf,
scan_needed=False,
)
# in the event that this file was downloaded automatically prior to conversion
# we do not keep the original .ckpt/.safetensors around
@@ -957,14 +953,15 @@ def heuristic_import(
return model_name

def convert_and_import(
self,
ckpt_path: Path,
diffusers_path: Path,
model_name=None,
model_description=None,
vae=None,
original_config_file: Path = None,
commit_to_conf: Path = None,
self,
ckpt_path: Path,
diffusers_path: Path,
model_name=None,
model_description=None,
vae=None,
original_config_file: Path = None,
commit_to_conf: Path = None,
scan_needed: bool=True,
) -> str:
"""
Convert a legacy ckpt weights file to diffuser model and import
@@ -999,11 +996,12 @@ def convert_and_import(
extract_ema=True,
original_config_file=original_config_file,
vae=vae_model,
scan_needed=scan_needed,
)
print(
f" | Success. Optimized model is now located at {str(diffusers_path)}"
f" | Success. Optimized model is now located at {str(diffusers_path)}"
)
print(f" | Writing new config file entry for {model_name}")
print(f" | Writing new config file entry for {model_name}")
new_config = dict(
path=str(diffusers_path),
description=model_description,
@@ -1293,7 +1291,7 @@ def _diffuser_sha256(
with open(hashpath) as f:
hash = f.read()
return hash
print(" | Calculating sha256 hash of model files")
print(" | Calculating sha256 hash of model files")
tic = time.time()
sha = hashlib.sha256()
count = 0
@@ -1305,7 +1303,7 @@ def _diffuser_sha256(
sha.update(chunk)
hash = sha.hexdigest()
toc = time.time()
print(f" | sha256 = {hash} ({count} files hashed in", "%4.2fs)" % (toc - tic))
print(f" | sha256 = {hash} ({count} files hashed in", "%4.2fs)" % (toc - tic))
with open(hashpath, "w") as f:
f.write(hash)
return hash
@@ -1350,12 +1348,12 @@ def _load_vae(self, vae_config) -> AutoencoderKL:
local_files_only=not Globals.internet_available,
)

print(f" | Loading diffusers VAE from {name_or_path}")
print(f" | Loading diffusers VAE from {name_or_path}")
if using_fp16:
vae_args.update(torch_dtype=torch.float16)
fp_args_list = [{"revision": "fp16"}, {}]
else:
print(" | Using more accurate float32 precision")
print(" | Using more accurate float32 precision")
fp_args_list = [{}]

vae = None
@@ -1396,7 +1394,7 @@ def _delete_model_from_cache(repo_id):
hashes_to_delete.add(revision.commit_hash)
strategy = cache_info.delete_revisions(*hashes_to_delete)
print(
f"** deletion of this model is expected to free {strategy.expected_freed_size_str}"
f"** Deletion of this model is expected to free {strategy.expected_freed_size_str}"
)
strategy.execute()

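The same policy now covers standalone VAE weights as well: a .ckpt/.pt VAE is scanned before torch.load, a .safetensors VAE is loaded directly, and in both cases the training-time loss tensors are stripped before the weights are applied with strict=False. The sketch below mirrors, rather than reproduces, the logic in ModelManager._load_ckpt_model.

```python
# Hedged sketch of the VAE-loading path covered by this patch; it mirrors the
# logic in ModelManager._load_ckpt_model rather than reproducing it verbatim.
import torch
import safetensors.torch
from picklescan.scanner import scan_file_path


def load_vae_state_dict(vae_path: str) -> dict:
    """Return a VAE state dict with training-time loss tensors stripped."""
    if vae_path.endswith((".ckpt", ".pt")):
        # legacy pickle formats are scanned before unpickling
        if scan_file_path(vae_path).infected_files != 0:
            raise RuntimeError(f"{vae_path} failed the pickle scan")
        ckpt = torch.load(vae_path, map_location="cpu")["state_dict"]
    else:
        ckpt = safetensors.torch.load_file(vae_path)
    # keys beginning with "loss" belong to the VAE's training loss module, not the VAE itself
    return {k: v for k, v in ckpt.items() if not k.startswith("loss")}


# Usage (hypothetical):
#   model.first_stage_model.load_state_dict(load_vae_state_dict(vae_path), strict=False)
```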