Description
Describe the bug
Slightly related to huggingface/accelerate#2494, in the sense that this report applies the workaround suggested there to avoid the big modeling feature (not supported by FakeTensorMode).
Once the model is initialized, loading hits an issue within EulerDiscreteScheduler when numpy.array is called to create the sigmas tensor. FakeTensorMode does not support numpy.array, so ideally the code should be rewritten to avoid this call wherever possible.
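To illustrate the kind of rewrite meant here, below is a minimal sketch that computes sigmas with torch operations only, so no numpy.array call is involved. The helper name `compute_sigmas` is hypothetical and is not the scheduler's actual code; the formula and the reversed ordering with a trailing zero are assumptions about what the scheduler computes and should be checked against the real implementation.

```python
import torch


def compute_sigmas(alphas_cumprod: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: derive sigmas from cumulative alphas using torch ops only."""
    # Assumed formula: sigma = sqrt((1 - alpha_bar) / alpha_bar)
    sigmas = ((1 - alphas_cumprod) / alphas_cumprod) ** 0.5
    # Assumed layout: reverse the order, then append a final zero.
    # new_zeros keeps the device (and fake-ness) of the input tensor.
    return torch.cat([sigmas.flip(0), sigmas.new_zeros(1)]).to(torch.float32)


# Works on regular tensors ...
sigmas = compute_sigmas(torch.tensor([0.5, 0.2]))

# ... and the same function can be called inside FakeTensorMode,
# because every step stays within torch.
with torch._subclasses.fake_tensor.FakeTensorMode():
    fake_sigmas = compute_sigmas(torch.linspace(0.9, 0.1, 10))
```

Because nothing leaves torch, the result under FakeTensorMode is itself a fake tensor that carries only shape and dtype, which is all that model initialization needs.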
Reproduction
For now, the repro also includes another workaround: a monkey-patch equivalent to huggingface/safetensors PR #318, which was rejected, although no better solution has been found since.
from huggingface_hub import hf_hub_download
import torch
from diffusers import DiffusionPipeline
import safetensors.torch  # makes sure the safetensors.torch submodule is loaded
import transformers

# Monkey-patch for https://github.com/huggingface/safetensors/pull/318
class ONNXTorchPatcher:
    def __init__(self):
        def safetensors_load_file_wrapper(filename, device="cpu"):
            result = {}
            with safetensors.torch.safe_open(  # type: ignore[attr-defined]
                filename, framework="pt", device=device
            ) as f:
                for k in f.keys():
                    fake_mode = torch._guards.detect_fake_mode()
                    if not fake_mode:
                        # No fake mode active: load the real tensor data.
                        result[k] = f.get_tensor(k)
                    else:
                        # Fake mode active: only read shape and dtype,
                        # and allocate a tensor without loading the data.
                        empty_tensor = f.get_slice(k)
                        result[k] = torch.empty(
                            tuple(empty_tensor.get_shape()),
                            dtype=safetensors.torch._getdtype(
                                empty_tensor.get_dtype()
                            ),
                        )
            return result

        # Keep the original functions so they can be restored in __exit__.
        self.safetensors_torch_load_file = safetensors.torch.load_file
        self.safetensors_torch_load_file_wrapper = safetensors_load_file_wrapper
        self.transformers_modeling_utils_safe_load_file = (
            transformers.modeling_utils.safe_load_file
        )

    def __enter__(self):
        safetensors.torch.load_file = self.safetensors_torch_load_file_wrapper
        transformers.modeling_utils.safe_load_file = (
            self.safetensors_torch_load_file_wrapper
        )

    def __exit__(self, exc_type, exc_value, traceback):
        safetensors.torch.load_file = self.safetensors_torch_load_file
        transformers.modeling_utils.safe_load_file = (
            self.transformers_modeling_utils_safe_load_file
        )

with torch._subclasses.FakeTensorMode():
    fake_model = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", low_cpu_mem_usage=False)

The expected behavior is that DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", low_cpu_mem_usage=False) should succeed under torch._subclasses.fake_tensor.FakeTensorMode.
Logs
No response
System Info
diffusers main branch
Who can help?
No response