Closed
Labels
bug: Something isn't working
Description
Describe the bug
I've come across a load_textual_inversion issue while testing EasyNegative: EasyNegative.pt loads fine, but EasyNegative.safetensors causes a crash. I'm not sure whether that's a problem with the embedding itself or a potential bug.
Reproduction
from diffusers import StableDiffusionPipeline
import torch
model_id = "runwayml/stable-diffusion-v1-5"
pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")
pipe.load_textual_inversion("/content/EasyNegative.safetensors", token="EasyNegative")
prompt = "multiple views of the same character in the same outfit, a character turnaround of a woman wearing a black jacket and red shirt, best quality, intricate details."
negative_prompt = "EasyNegative"
generator = torch.Generator("cuda").manual_seed(1)
image = pipe(prompt, negative_prompt=negative_prompt, num_inference_steps=50, generator=generator).images[0]
display(image)
Logs
in <cell line: 10>:10 │
│ │
│ /usr/local/lib/python3.9/dist-packages/diffusers/loaders.py:546 in load_textual_inversion │
│ │
│ 543 │ │ │ │ subfolder=subfolder, │
│ 544 │ │ │ │ user_agent=user_agent, │
│ 545 │ │ │ ) │
│ ❱ 546 │ │ │ state_dict = torch.load(model_file, map_location="cpu") │
│ 547 │ │ │
│ 548 │ │ # 2. Load token and embedding correcly from file │
│ 549 │ │ if isinstance(state_dict, torch.Tensor): │
│ │
│ /usr/local/lib/python3.9/dist-packages/torch/serialization.py:815 in load │
│ │
│ 812 │ │ │ │ return _legacy_load(opened_file, map_location, _weights_only_unpickler, │
│ 813 │ │ │ except RuntimeError as e: │
│ 814 │ │ │ │ raise pickle.UnpicklingError(UNSAFE_MESSAGE + str(e)) from None │
│ ❱ 815 │ │ return _legacy_load(opened_file, map_location, pickle_module, **pickle_load_args │
│ 816 │
│ 817 │
│ 818 # Register pickling support for layout instances such as │
│ │
│ /usr/local/lib/python3.9/dist-packages/torch/serialization.py:1033 in _legacy_load │
│ │
│ 1030 │ │ │ f"Received object of type \"{type(f)}\". Please update to Python 3.8.2 or ne │
│ 1031 │ │ │ "functionality.") │
│ 1032 │ │
│ ❱ 1033 │ magic_number = pickle_module.load(f, **pickle_load_args) │
│ 1034 │ if magic_number != MAGIC_NUMBER: │
│ 1035 │ │ raise RuntimeError("Invalid magic number; corrupt file?") │
│ 1036 │ protocol_version = pickle_module.load(f, **pickle_load_args) │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
UnpicklingError: invalid load key, '"'.
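The traceback shows torch.load, a pickle-based loader, being applied to the .safetensors file. As a rough way to confirm which format a file really is, here is a minimal sketch that checks for the published safetensors layout (an 8-byte little-endian header length followed by a JSON header). The helper name looks_like_safetensors and the max_header_bytes limit are hypothetical and not part of diffusers or safetensors; the sketch uses only the standard library and does not load any tensors.

```python
import json
import struct


def looks_like_safetensors(path, max_header_bytes=100_000_000):
    """Return True if the file starts with a plausible safetensors header.

    Hypothetical helper for illustration only. It assumes the safetensors
    layout: 8 bytes (unsigned little-endian 64-bit header length), then
    that many bytes of UTF-8 JSON describing the tensors.
    """
    with open(path, "rb") as f:
        prefix = f.read(8)
        if len(prefix) < 8:
            return False
        (header_len,) = struct.unpack("<Q", prefix)
        # Reject empty or implausibly large headers (the limit is an arbitrary choice).
        if header_len == 0 or header_len > max_header_bytes:
            return False
        header = f.read(header_len)
    if len(header) < header_len:
        return False
    try:
        parsed = json.loads(header.decode("utf-8"))
    except (UnicodeDecodeError, json.JSONDecodeError):
        return False
    # The header is expected to be a JSON object mapping tensor names to metadata.
    return isinstance(parsed, dict)
```

If this returns True for EasyNegative.safetensors, the file is probably not corrupt and the failure more likely comes from a pickle-based loader being used on a non-pickle file.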
System Info
- diffusers version: 0.15.1
- Platform: Linux-5.10.147+-x86_64-with-glibc2.31
- Python version: 3.9.16
- PyTorch version (GPU?): 1.13.1+cu116 (True)
- Transformers version: 4.28.0
- Accelerate version: 0.18.0
- Running on Google Colab with NVIDIA Tesla T4 GPU