8 changes: 4 additions & 4 deletions src/diffusers/models/attention_processor.py
@@ -324,12 +324,12 @@ def set_use_xla_flash_attention(
Specify the partition specification if using SPMD. Otherwise None.
"""
if use_xla_flash_attention:
if not is_torch_xla_available:
raise "torch_xla is not available"
if not is_torch_xla_available():
raise ImportError("torch_xla is not available")
elif is_torch_xla_version("<", "2.3"):
raise "flash attention pallas kernel is supported from torch_xla version 2.3"
raise ImportError("flash attention pallas kernel is supported from torch_xla version 2.3")
elif is_spmd() and is_torch_xla_version("<", "2.4"):
raise "flash attention pallas kernel using SPMD is supported from torch_xla version 2.4"
raise ImportError("flash attention pallas kernel using SPMD is supported from torch_xla version 2.4")
else:
if is_flux:
processor = XLAFluxFlashAttnProcessor2_0(partition_spec)
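For context, a minimal standalone sketch (not from the repository) of the two pitfalls this hunk fixes: referencing `is_torch_xla_available` without calling it is always truthy, so the guard never fired, and `raise` with a bare string is itself an error in Python 3.

```python
def is_torch_xla_available():  # stand-in for the real availability check
    return False

# Pitfall 1: the bare function object is truthy, so `not is_torch_xla_available`
# is always False and the availability guard can never trigger.
assert bool(is_torch_xla_available) is True

# Pitfall 2: raising a plain string is rejected by the interpreter itself.
try:
    raise "torch_xla is not available"
except TypeError as err:
    print(err)  # exceptions must derive from BaseException

# The corrected pattern: call the check and raise a real exception type.
try:
    if not is_torch_xla_available():
        raise ImportError("torch_xla is not available")
except ImportError as err:
    print(err)  # torch_xla is not available
```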
18 changes: 13 additions & 5 deletions src/diffusers/models/auto_model.py
@@ -275,10 +275,17 @@ def from_pretrained(cls, pretrained_model_or_path: str | os.PathLike | None = No
library = None
orig_class_name = None

def load_config_with_name(config_name, *args, **kwargs):
original_config_name = cls.config_name
try:
cls.config_name = config_name
return cls.load_config(*args, **kwargs)
finally:
cls.config_name = original_config_name

# Always attempt to fetch model_index.json first
try:
cls.config_name = "model_index.json"
config = cls.load_config(pretrained_model_or_path, **load_config_kwargs)
config = load_config_with_name("model_index.json", pretrained_model_or_path, **load_config_kwargs)

if subfolder is not None and subfolder in config:
library, orig_class_name = config[subfolder]
@@ -289,8 +296,9 @@ def from_pretrained(cls, pretrained_model_or_path: str | os.PathLike | None = No

# Unable to load from model_index.json so fallback to loading from config
if library is None and orig_class_name is None:
cls.config_name = "config.json"
config = cls.load_config(pretrained_model_or_path, subfolder=subfolder, **load_config_kwargs)
config = load_config_with_name(
"config.json", pretrained_model_or_path, subfolder=subfolder, **load_config_kwargs
)

if "_class_name" in config:
# If we find a class name in the config, we can try to load the model as a diffusers model
@@ -342,7 +350,7 @@ def from_pretrained(cls, pretrained_model_or_path: str | os.PathLike | None = No

load_id_kwargs = {"pretrained_model_name_or_path": pretrained_model_or_path, **kwargs}
parts = [load_id_kwargs.get(field, "null") for field in DIFFUSERS_LOAD_ID_FIELDS]
load_id = "|".join("null" if p is None else p for p in parts)
load_id = "|".join("null" if p is None else str(p) for p in parts)
model._diffusers_load_id = load_id

return model
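For context, a small standalone sketch (hypothetical values, not `DIFFUSERS_LOAD_ID_FIELDS` itself) of why the added `str(p)` matters: `str.join` rejects non-string items, so any non-string kwarg among the load-id parts would raise before the fix.

```python
# Hypothetical parts list containing a non-string entry.
parts = ["runwayml/stable-diffusion-v1-5", None, 16]

try:
    "|".join("null" if p is None else p for p in parts)
except TypeError as err:
    print(err)  # sequence item 2: expected str instance, int found

# The corrected form stringifies every non-None part first.
load_id = "|".join("null" if p is None else str(p) for p in parts)
print(load_id)  # runwayml/stable-diffusion-v1-5|null|16
```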
7 changes: 4 additions & 3 deletions src/diffusers/models/cache_utils.py
@@ -159,6 +159,7 @@ def cache_context(self, name: str):
registry = HookRegistry.check_if_exists_or_initialize(self)
registry._set_context(name)

yield

registry._set_context(None)
try:
yield
finally:
registry._set_context(None)
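For context, a minimal standalone sketch (assumed names, not the diffusers hook registry) of why wrapping `yield` in `try/finally` matters: without it, an exception raised inside the `with` block skips the cleanup and leaves a stale context behind.

```python
from contextlib import contextmanager

current_context = None

@contextmanager
def cache_context(name):
    global current_context
    current_context = name
    try:
        yield
    finally:
        # Runs even when the caller's block raises, mirroring the fix above.
        current_context = None

try:
    with cache_context("failed-call"):
        raise RuntimeError("interrupted")
except RuntimeError:
    pass

print(current_context)  # None -- cleared despite the exception
```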
41 changes: 18 additions & 23 deletions src/diffusers/models/downsampling.py
@@ -18,7 +18,7 @@

from ..utils import deprecate
from .normalization import RMSNorm
from .upsampling import upfirdn2d_native
from .upsampling import _prepare_fir_kernel, upfirdn2d_native


class Downsample1D(nn.Module):
@@ -210,32 +210,29 @@ def _downsample_2d(
"""

assert isinstance(factor, int) and factor >= 1
if kernel is None:
kernel = [1] * factor

# setup kernel
kernel = torch.tensor(kernel, dtype=torch.float32)
if kernel.ndim == 1:
kernel = torch.outer(kernel, kernel)
kernel /= torch.sum(kernel)

kernel = kernel * gain
kernel = _prepare_fir_kernel(
kernel,
factor=factor,
gain=gain,
device=hidden_states.device,
dtype=hidden_states.dtype,
)

if self.use_conv:
_, _, convH, convW = weight.shape
pad_value = (kernel.shape[0] - factor) + (convW - 1)
stride_value = [factor, factor]
upfirdn_input = upfirdn2d_native(
hidden_states,
torch.tensor(kernel, device=hidden_states.device),
kernel,
pad=((pad_value + 1) // 2, pad_value // 2),
)
output = F.conv2d(upfirdn_input, weight, stride=stride_value, padding=0)
else:
pad_value = kernel.shape[0] - factor
output = upfirdn2d_native(
hidden_states,
torch.tensor(kernel, device=hidden_states.device),
kernel,
down=factor,
pad=((pad_value + 1) // 2, pad_value // 2),
)
@@ -380,19 +377,17 @@ def downsample_2d(
"""

assert isinstance(factor, int) and factor >= 1
if kernel is None:
kernel = [1] * factor

kernel = torch.tensor(kernel, dtype=torch.float32)
if kernel.ndim == 1:
kernel = torch.outer(kernel, kernel)
kernel /= torch.sum(kernel)

kernel = kernel * gain
kernel = _prepare_fir_kernel(
kernel,
factor=factor,
gain=gain,
device=hidden_states.device,
dtype=hidden_states.dtype,
)
pad_value = kernel.shape[0] - factor
output = upfirdn2d_native(
hidden_states,
kernel.to(device=hidden_states.device),
kernel,
down=factor,
pad=((pad_value + 1) // 2, pad_value // 2),
)
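For context, a short standalone check of what the shared `_prepare_fir_kernel` helper produces for the default box kernel in the downsampling path: the 1-D `[1] * factor` kernel is expanded to a separable 2-D kernel, normalized, then scaled by `gain`. (In the upsampling path the helper additionally multiplies by `factor**2`; see the sketch after the `upsampling.py` hunks below.)

```python
import torch

factor, gain = 2, 1.0
kernel = torch.tensor([1.0] * factor)     # default 1-D box kernel
kernel = torch.outer(kernel, kernel)      # separable 2-D kernel
kernel = kernel / kernel.sum() * gain     # normalize, then apply gain

print(kernel)  # tensor([[0.2500, 0.2500], [0.2500, 0.2500]])
```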
4 changes: 2 additions & 2 deletions src/diffusers/models/embeddings.py
@@ -328,7 +328,7 @@ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos, output_type="np", flip_sin
output_type (`str`, *optional*, defaults to `"np"`): Output type. Use `"pt"` for PyTorch tensors.
flip_sin_to_cos (`bool`, *optional*, defaults to `False`): Whether to flip sine and cosine embeddings.
dtype (`torch.dtype`, *optional*): Data type for frequency calculations. If `None`, defaults to
`torch.float32` on MPS devices (which don't support `torch.float64`) and `torch.float64` on other devices.
`torch.float32`.

Returns:
`torch.Tensor`: Sinusoidal positional embeddings of shape `(M, D)`.
@@ -346,7 +346,7 @@ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos, output_type="np", flip_sin

# Auto-detect appropriate dtype if not specified
if dtype is None:
dtype = torch.float32 if pos.device.type == "mps" else torch.float64
dtype = torch.float32

omega = torch.arange(embed_dim // 2, device=pos.device, dtype=dtype)
omega /= embed_dim / 2.0
2 changes: 2 additions & 0 deletions src/diffusers/models/modeling_flax_utils.py
@@ -318,6 +318,8 @@ def from_pretrained(
subfolder=subfolder,
**kwargs,
)
else:
unused_kwargs = kwargs

model, model_kwargs = cls.from_config(config, dtype=dtype, return_unused_kwargs=True, **unused_kwargs)

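For context, a minimal sketch (assumed control flow, not the Flax loader itself) of the failure mode the added `else` branch guards against: when the config-loading `if` is skipped because a config was passed in, the later use of `unused_kwargs` raises `UnboundLocalError` unless it is assigned in an `else` branch.

```python
def from_pretrained(config=None, **kwargs):
    if config is None:
        config, unused_kwargs = {"source": "hub"}, kwargs  # normal loading path
    else:
        unused_kwargs = kwargs  # the added branch; without it the next line fails
    return config, unused_kwargs

print(from_pretrained(config={"already": "loaded"}, dtype="bfloat16"))
```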
2 changes: 1 addition & 1 deletion src/diffusers/models/modeling_utils.py
@@ -1592,7 +1592,7 @@ def enable_parallelism(
"`enable_parallelism` is an experimental feature. The API may change in the future and breaking changes may be introduced at any time without warning."
)

if not torch.distributed.is_available() and not torch.distributed.is_initialized():
if not torch.distributed.is_available() or not torch.distributed.is_initialized():
raise RuntimeError(
"torch.distributed must be available and initialized before calling `enable_parallelism`."
)
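For context on the `and` → `or` change: the error message requires torch.distributed to be both available and initialized, so the guard must trip when either check fails. A tiny standalone check with plain booleans (not torch.distributed):

```python
available, initialized = True, False  # built with distributed, but never initialized

# Old guard: only trips when both checks fail, so this broken state slipped through.
assert not (not available and not initialized)

# Fixed guard: trips when either check fails, matching the error message.
assert not available or not initialized
```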
63 changes: 40 additions & 23 deletions src/diffusers/models/upsampling.py
@@ -21,6 +21,27 @@
from .normalization import RMSNorm


def _prepare_fir_kernel(
kernel: torch.Tensor | None,
*,
factor: int,
gain: float,
device: torch.device,
dtype: torch.dtype,
upsample: bool = False,
) -> torch.Tensor:
if kernel is None:
kernel = [1] * factor

kernel = torch.as_tensor(kernel, device=device, dtype=torch.float32)
if kernel.ndim == 1:
kernel = torch.outer(kernel, kernel)
kernel = kernel / torch.sum(kernel)

scale = gain * (factor**2) if upsample else gain
return (kernel * scale).to(device=device, dtype=dtype)


class Upsample1D(nn.Module):
"""A 1D upsampling layer with an optional convolution.

@@ -253,17 +274,14 @@ def _upsample_2d(

assert isinstance(factor, int) and factor >= 1

# Setup filter kernel.
if kernel is None:
kernel = [1] * factor

# setup kernel
kernel = torch.tensor(kernel, dtype=torch.float32)
if kernel.ndim == 1:
kernel = torch.outer(kernel, kernel)
kernel /= torch.sum(kernel)

kernel = kernel * (gain * (factor**2))
kernel = _prepare_fir_kernel(
kernel,
factor=factor,
gain=gain,
device=hidden_states.device,
dtype=hidden_states.dtype,
upsample=True,
)

if self.use_conv:
convH = weight.shape[2]
@@ -300,14 +318,14 @@ def _upsample_2d(

output = upfirdn2d_native(
inverse_conv,
torch.tensor(kernel, device=inverse_conv.device),
kernel.to(device=inverse_conv.device, dtype=inverse_conv.dtype),
pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2 + 1),
)
else:
pad_value = kernel.shape[0] - factor
output = upfirdn2d_native(
hidden_states,
torch.tensor(kernel, device=hidden_states.device),
kernel,
up=factor,
pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2),
)
@@ -496,19 +514,18 @@ def upsample_2d(
Tensor of the shape `[N, C, H * factor, W * factor]`
"""
assert isinstance(factor, int) and factor >= 1
if kernel is None:
kernel = [1] * factor

kernel = torch.tensor(kernel, dtype=torch.float32)
if kernel.ndim == 1:
kernel = torch.outer(kernel, kernel)
kernel /= torch.sum(kernel)

kernel = kernel * (gain * (factor**2))
kernel = _prepare_fir_kernel(
kernel,
factor=factor,
gain=gain,
device=hidden_states.device,
dtype=hidden_states.dtype,
upsample=True,
)
pad_value = kernel.shape[0] - factor
output = upfirdn2d_native(
hidden_states,
kernel.to(device=hidden_states.device),
kernel,
up=factor,
pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2),
)
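For context on the `upsample=True` scaling in `_prepare_fir_kernel`: FIR upsampling first zero-stuffs the input, leaving one non-zero sample per `factor**2` output positions, so the kernel is scaled by `gain * factor**2` to keep the output's mean level equal to the input's. A standalone sketch (plain `conv2d`, not `upfirdn2d_native`):

```python
import torch
import torch.nn.functional as F

factor = 2
x = torch.ones(1, 1, 4, 4)

# Zero-insertion upsampling: each input pixel occupies 1 of factor**2 positions.
up = torch.zeros(1, 1, 4 * factor, 4 * factor)
up[..., ::factor, ::factor] = x

# Normalized 2x2 box kernel without the factor**2 gain: intensity drops by factor**2.
k = torch.full((1, 1, factor, factor), 1.0 / factor**2)
print(F.conv2d(up, k).mean())              # tensor(0.2500)

# With the gain * factor**2 scaling, the mean level is preserved.
print(F.conv2d(up, k * factor**2).mean())  # tensor(1.)
```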
2 changes: 1 addition & 1 deletion src/diffusers/utils/import_utils.py
@@ -708,7 +708,7 @@ def is_torch_xla_version(operation: str, version: str):
version (`str`):
A string version of torch_xla
"""
if not is_torch_xla_available:
if not is_torch_xla_available():
return False
return compare_versions(parse(_torch_xla_version), operation, version)

37 changes: 37 additions & 0 deletions tests/hooks/test_hooks.py
@@ -18,6 +18,8 @@
import torch

from diffusers.hooks import HookRegistry, ModelHook
from diffusers.hooks.hooks import BaseState, StateManager
from diffusers.models.cache_utils import CacheMixin
from diffusers.training_utils import free_memory
from diffusers.utils.logging import get_logger

@@ -114,6 +116,27 @@ def reset_state(self, module):
self.increment = 0


class CacheTestState(BaseState):
def reset(self):
pass


class CacheTestHook(ModelHook):
_is_stateful = True

def __init__(self):
super().__init__()
self.state_manager = StateManager(CacheTestState)

def reset_state(self, module):
self.state_manager.reset()


class CacheContextModel(torch.nn.Module, CacheMixin):
def forward(self, x: torch.Tensor) -> torch.Tensor:
return x


class SkipLayerHook(ModelHook):
def __init__(self, skip_layer: bool):
super().__init__()
@@ -338,6 +361,20 @@ def test_invocation_order_stateful_middle(self):
)
self.assertEqual(output, expected_invocation_order_log)

def test_cache_context_clears_stateful_hook_context_after_exception(self):
model = CacheContextModel()
hook = CacheTestHook()
HookRegistry.check_if_exists_or_initialize(model).register_hook(hook, "cache_test")

with self.assertRaisesRegex(RuntimeError, "interrupted"):
with model.cache_context("failed-call"):
hook.state_manager.get_state()
raise RuntimeError("interrupted")

self.assertIsNone(hook.state_manager._current_context)
with self.assertRaisesRegex(ValueError, "No context is set"):
hook.state_manager.get_state()

def test_invocation_order_stateful_last(self):
registry = HookRegistry.check_if_exists_or_initialize(self.model)
registry.register_hook(AddHook(1), "add_hook")
19 changes: 19 additions & 0 deletions tests/models/test_attention_processor.py
@@ -1,6 +1,7 @@
import importlib.metadata
import tempfile
import unittest
from unittest.mock import patch

import numpy as np
import pytest
@@ -9,6 +10,7 @@

from diffusers import DiffusionPipeline
from diffusers.models.attention_processor import Attention, AttnAddedKVProcessor
from diffusers.utils import import_utils

from ..testing_utils import torch_device

@@ -83,6 +85,23 @@ def test_only_cross_attention(self):
self.assertTrue((only_cross_attn_out != self_and_cross_attn_out).all())


class AttentionXLAFlashAttentionTests(unittest.TestCase):
def test_set_use_xla_flash_attention_raises_import_error_without_torch_xla(self):
attn = Attention(query_dim=4, heads=1, dim_head=4)

with patch("diffusers.models.attention_processor.is_torch_xla_available", return_value=False):
with self.assertRaisesRegex(ImportError, "torch_xla is not available"):
attn.set_use_xla_flash_attention(True)

def test_is_torch_xla_version_returns_false_without_torch_xla(self):
import_utils.is_torch_xla_version.cache_clear()
try:
with patch("diffusers.utils.import_utils.is_torch_xla_available", return_value=False):
self.assertFalse(import_utils.is_torch_xla_version("<", "2.3"))
finally:
import_utils.is_torch_xla_version.cache_clear()


class DeprecatedAttentionBlockTests(unittest.TestCase):
@pytest.fixture(scope="session")
def is_dist_enabled(pytestconfig):