Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
160 changes: 102 additions & 58 deletions tests/models/transformers/test_models_transformer_bria.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,23 +13,31 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
from typing import Any

import torch

from diffusers import BriaTransformer2DModel
from diffusers.models.attention_processor import FluxIPAdapterJointAttnProcessor2_0
from diffusers.models.embeddings import ImageProjection
from diffusers.utils.torch_utils import randn_tensor

from ...testing_utils import enable_full_determinism, torch_device
from ..test_modeling_common import LoraHotSwappingForModelTesterMixin, ModelTesterMixin, TorchCompileTesterMixin
from ..testing_utils import (
BaseModelTesterConfig,
IPAdapterTesterMixin,
LoraHotSwappingForModelTesterMixin,
LoraTesterMixin,
ModelTesterMixin,
TorchCompileTesterMixin,
Comment on lines +29 to +32
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess we can remove the expensive test suites from here provided the popularity of the model?

TrainingTesterMixin,
)


enable_full_determinism()


def create_bria_ip_adapter_state_dict(model):
# "ip_adapter" (cross-attention weights)
def create_bria_ip_adapter_state_dict(model) -> dict[str, dict[str, Any]]:
ip_cross_attn_state_dict = {}
key_id = 0

Expand All @@ -50,11 +58,8 @@ def create_bria_ip_adapter_state_dict(model):
f"{key_id}.to_v_ip.bias": sd["to_v_ip.0.bias"],
}
)

key_id += 1

# "image_proj" (ImageProjection layer weights)

image_projection = ImageProjection(
cross_attention_dim=model.config["joint_attention_dim"],
image_embed_dim=model.config["pooled_projection_dim"],
Expand All @@ -73,53 +78,36 @@ def create_bria_ip_adapter_state_dict(model):
)

del sd
ip_state_dict = {}
ip_state_dict.update({"image_proj": ip_image_projection_state_dict, "ip_adapter": ip_cross_attn_state_dict})
return ip_state_dict

return {"image_proj": ip_image_projection_state_dict, "ip_adapter": ip_cross_attn_state_dict}

class BriaTransformerTests(ModelTesterMixin, unittest.TestCase):
model_class = BriaTransformer2DModel
main_input_name = "hidden_states"
# We override the items here because the transformer under consideration is small.
model_split_percents = [0.8, 0.7, 0.7]

# Skip setting testing with default: AttnProcessor
uses_custom_attn_processor = True

class BriaTransformerTesterConfig(BaseModelTesterConfig):
@property
def dummy_input(self):
batch_size = 1
num_latent_channels = 4
num_image_channels = 3
height = width = 4
sequence_length = 48
embedding_dim = 32
def model_class(self):
return BriaTransformer2DModel

hidden_states = torch.randn((batch_size, height * width, num_latent_channels)).to(torch_device)
encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
text_ids = torch.randn((sequence_length, num_image_channels)).to(torch_device)
image_ids = torch.randn((height * width, num_image_channels)).to(torch_device)
timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size)
@property
def main_input_name(self) -> str:
return "hidden_states"

return {
"hidden_states": hidden_states,
"encoder_hidden_states": encoder_hidden_states,
"img_ids": image_ids,
"txt_ids": text_ids,
"timestep": timestep,
}
@property
def model_split_percents(self) -> list:
return [0.8, 0.7, 0.7]

@property
def input_shape(self):
def output_shape(self) -> tuple:
return (16, 4)

@property
def output_shape(self):
def input_shape(self) -> tuple:
return (16, 4)

def prepare_init_args_and_inputs_for_common(self):
init_dict = {
@property
def generator(self):
return torch.Generator("cpu").manual_seed(0)

def get_init_dict(self) -> dict:
return {
"patch_size": 1,
"in_channels": 4,
"num_layers": 1,
Expand All @@ -131,19 +119,42 @@ def prepare_init_args_and_inputs_for_common(self):
"axes_dims_rope": [0, 4, 4],
}

inputs_dict = self.dummy_input
return init_dict, inputs_dict
def get_dummy_inputs(self, batch_size: int = 1) -> dict[str, torch.Tensor]:
num_latent_channels = 4
num_image_channels = 3
height = width = 4
sequence_length = 48
embedding_dim = 32

return {
"hidden_states": randn_tensor(
(batch_size, height * width, num_latent_channels), generator=self.generator, device=torch_device
),
"encoder_hidden_states": randn_tensor(
(batch_size, sequence_length, embedding_dim), generator=self.generator, device=torch_device
),
"img_ids": randn_tensor(
(height * width, num_image_channels), generator=self.generator, device=torch_device
),
"txt_ids": randn_tensor(
(sequence_length, num_image_channels), generator=self.generator, device=torch_device
),
"timestep": torch.tensor([1.0]).to(torch_device).expand(batch_size),
}


class TestBriaTransformer(BriaTransformerTesterConfig, ModelTesterMixin):
def test_deprecated_inputs_img_txt_ids_3d(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()

model = self.model_class(**init_dict)
model.to(torch_device)
model.eval()

with torch.no_grad():
output_1 = model(**inputs_dict).to_tuple()[0]

# update inputs_dict with txt_ids and img_ids as 3d tensors (deprecated)
text_ids_3d = inputs_dict["txt_ids"].unsqueeze(0)
image_ids_3d = inputs_dict["img_ids"].unsqueeze(0)

Expand All @@ -156,26 +167,59 @@ def test_deprecated_inputs_img_txt_ids_3d(self):
with torch.no_grad():
output_2 = model(**inputs_dict).to_tuple()[0]

self.assertEqual(output_1.shape, output_2.shape)
self.assertTrue(
torch.allclose(output_1, output_2, atol=1e-5),
msg="output with deprecated inputs (img_ids and txt_ids as 3d torch tensors) are not equal as them as 2d inputs",
assert output_1.shape == output_2.shape
assert torch.allclose(output_1, output_2, atol=1e-5), (
"output with deprecated inputs (img_ids and txt_ids as 3d torch tensors) "
"are not equal as them as 2d inputs"
)


class TestBriaTransformerTraining(BriaTransformerTesterConfig, TrainingTesterMixin):
def test_gradient_checkpointing_is_applied(self):
expected_set = {"BriaTransformer2DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)


class BriaTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
model_class = BriaTransformer2DModel
class TestBriaTransformerCompile(BriaTransformerTesterConfig, TorchCompileTesterMixin):
pass


class TestBriaTransformerIPAdapter(BriaTransformerTesterConfig, IPAdapterTesterMixin):
@property
def ip_adapter_processor_cls(self):
return FluxIPAdapterJointAttnProcessor2_0

def modify_inputs_for_ip_adapter(self, model, inputs_dict):
torch.manual_seed(0)
cross_attention_dim = getattr(model.config, "joint_attention_dim", 32)
image_embeds = torch.randn(1, 1, cross_attention_dim).to(torch_device)
inputs_dict.update({"joint_attention_kwargs": {"ip_adapter_image_embeds": image_embeds}})
return inputs_dict

def prepare_init_args_and_inputs_for_common(self):
return BriaTransformerTests().prepare_init_args_and_inputs_for_common()
def create_ip_adapter_state_dict(self, model: Any) -> dict[str, dict[str, Any]]:
return create_bria_ip_adapter_state_dict(model)


class BriaTransformerLoRAHotSwapTests(LoraHotSwappingForModelTesterMixin, unittest.TestCase):
model_class = BriaTransformer2DModel
class TestBriaTransformerLoRA(BriaTransformerTesterConfig, LoraTesterMixin):
pass

def prepare_init_args_and_inputs_for_common(self):
return BriaTransformerTests().prepare_init_args_and_inputs_for_common()

class TestBriaTransformerLoRAHotSwap(BriaTransformerTesterConfig, LoraHotSwappingForModelTesterMixin):
@property
def different_shapes_for_compilation(self):
return [(4, 4), (4, 8), (8, 8)]

def get_dummy_inputs(self, height: int = 4, width: int = 4) -> dict[str, torch.Tensor]:
batch_size = 1
num_latent_channels = 4
num_image_channels = 3
sequence_length = 24
embedding_dim = 32

return {
"hidden_states": randn_tensor((batch_size, height * width, num_latent_channels), device=torch_device),
"encoder_hidden_states": randn_tensor((batch_size, sequence_length, embedding_dim), device=torch_device),
"img_ids": randn_tensor((height * width, num_image_channels), device=torch_device),
"txt_ids": randn_tensor((sequence_length, num_image_channels), device=torch_device),
"timestep": torch.tensor([1.0]).to(torch_device).expand(batch_size),
}
96 changes: 58 additions & 38 deletions tests/models/transformers/test_models_transformer_bria_fibo.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,62 +13,50 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import torch

from diffusers import BriaFiboTransformer2DModel
from diffusers.utils.torch_utils import randn_tensor

from ...testing_utils import enable_full_determinism, torch_device
from ..test_modeling_common import ModelTesterMixin
from ..testing_utils import (
BaseModelTesterConfig,
ModelTesterMixin,
TorchCompileTesterMixin,
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same comment as above.

TrainingTesterMixin,
)


enable_full_determinism()


class BriaFiboTransformerTests(ModelTesterMixin, unittest.TestCase):
model_class = BriaFiboTransformer2DModel
main_input_name = "hidden_states"
# We override the items here because the transformer under consideration is small.
model_split_percents = [0.8, 0.7, 0.7]

# Skip setting testing with default: AttnProcessor
uses_custom_attn_processor = True
class BriaFiboTransformerTesterConfig(BaseModelTesterConfig):
@property
def model_class(self):
return BriaFiboTransformer2DModel

@property
def dummy_input(self):
batch_size = 1
num_latent_channels = 48
num_image_channels = 3
height = width = 16
sequence_length = 32
embedding_dim = 64
def main_input_name(self) -> str:
return "hidden_states"

hidden_states = torch.randn((batch_size, height * width, num_latent_channels)).to(torch_device)
encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
text_ids = torch.randn((sequence_length, num_image_channels)).to(torch_device)
image_ids = torch.randn((height * width, num_image_channels)).to(torch_device)
timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size)
@property
def model_split_percents(self) -> list:
return [0.8, 0.7, 0.7]

return {
"hidden_states": hidden_states,
"encoder_hidden_states": encoder_hidden_states,
"img_ids": image_ids,
"txt_ids": text_ids,
"timestep": timestep,
"text_encoder_layers": [encoder_hidden_states[:, :, :32], encoder_hidden_states[:, :, :32]],
}
@property
def output_shape(self) -> tuple:
return (256, 48)

@property
def input_shape(self):
def input_shape(self) -> tuple:
return (16, 16)

@property
def output_shape(self):
return (256, 48)
def generator(self):
return torch.Generator("cpu").manual_seed(0)

def prepare_init_args_and_inputs_for_common(self):
init_dict = {
def get_init_dict(self) -> dict:
return {
"patch_size": 1,
"in_channels": 48,
"num_layers": 1,
Expand All @@ -81,9 +69,41 @@ def prepare_init_args_and_inputs_for_common(self):
"axes_dims_rope": [0, 4, 4],
}

inputs_dict = self.dummy_input
return init_dict, inputs_dict
def get_dummy_inputs(self, batch_size: int = 1) -> dict[str, torch.Tensor]:
num_latent_channels = 48
num_image_channels = 3
height = width = 16
sequence_length = 32
embedding_dim = 64

encoder_hidden_states = randn_tensor(
(batch_size, sequence_length, embedding_dim), generator=self.generator, device=torch_device
)
return {
"hidden_states": randn_tensor(
(batch_size, height * width, num_latent_channels), generator=self.generator, device=torch_device
),
"encoder_hidden_states": encoder_hidden_states,
"img_ids": randn_tensor(
(height * width, num_image_channels), generator=self.generator, device=torch_device
),
"txt_ids": randn_tensor(
(sequence_length, num_image_channels), generator=self.generator, device=torch_device
),
"timestep": torch.tensor([1.0]).to(torch_device).expand(batch_size),
"text_encoder_layers": [encoder_hidden_states[:, :, :32], encoder_hidden_states[:, :, :32]],
}


class TestBriaFiboTransformer(BriaFiboTransformerTesterConfig, ModelTesterMixin):
pass


class TestBriaFiboTransformerTraining(BriaFiboTransformerTesterConfig, TrainingTesterMixin):
def test_gradient_checkpointing_is_applied(self):
expected_set = {"BriaFiboTransformer2DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)


class TestBriaFiboTransformerCompile(BriaFiboTransformerTesterConfig, TorchCompileTesterMixin):
pass
Loading