bria model/pipeline review #13625

@hlky

Description


Commit tested: 0f1abc4ae8b0eb2a3b40e82a310507281144c423

Review performed against the repository review rules.

Files/categories reviewed: target pipeline/model files, public/lazy imports, docs, fast/slow tests, dtype/device/offload hooks, scheduler/custom timestep paths, attention processor paths, and related Flux precedent.

Duplicate search: checked GitHub issues/PRs for bria, BriaPipeline, BriaTransformer2DModel, attention_kwargs, controlnet_single_block_samples, BriaPipelineOutput, guidance_embeds, timesteps sigmas, normalize, and LoRA. Found broad integration/refactor items like #11978, #12214, #13341, and the separate bria_fibo review #13618, but no exact duplicates for the issues below.

Test coverage status: fast transformer and pipeline tests exist, and a slow Bria pipeline test exists. The current fast tests do not cover the failure modes below. Targeted pytest collection under .venv failed before running tests because this Windows torch build lacks torch._C._distributed_c10d.

Issue 1: attention_kwargs are ignored during normal transformer inference

Affected code:

for index_block, block in enumerate(self.transformer_blocks):
    if torch.is_grad_enabled() and self.gradient_checkpointing:
        encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
            block,
            hidden_states,
            encoder_hidden_states,
            temb,
            image_rotary_emb,
            attention_kwargs,
        )
    else:
        encoder_hidden_states, hidden_states = block(
            hidden_states=hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            temb=temb,
            image_rotary_emb=image_rotary_emb,
        )

for index_block, block in enumerate(self.single_transformer_blocks):
    if torch.is_grad_enabled() and self.gradient_checkpointing:
        encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
            block,
            hidden_states,
            encoder_hidden_states,
            temb,
            image_rotary_emb,
            attention_kwargs,
        )
    else:
        encoder_hidden_states, hidden_states = block(
            hidden_states=hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            temb=temb,
            image_rotary_emb=image_rotary_emb,
        )

Problem:
BriaTransformer2DModel.forward() passes attention_kwargs to blocks only in the gradient-checkpointing branch. Normal inference drops them, so masks and processor kwargs supplied by BriaPipeline(..., attention_kwargs=...) have no effect.

Impact:
Attention masks, custom processor kwargs, and related conditioning paths silently do nothing in the default inference path.

Reproduction:

import torch
from diffusers import BriaTransformer2DModel

torch.manual_seed(0)
model = BriaTransformer2DModel(
    patch_size=1, in_channels=4, num_layers=1, num_single_layers=1,
    attention_head_dim=8, num_attention_heads=2, joint_attention_dim=32,
    axes_dims_rope=[0, 4, 4],
).eval()

inputs = dict(
    hidden_states=torch.randn(1, 4, 4),
    encoder_hidden_states=torch.randn(1, 4, 32),
    img_ids=torch.randn(4, 3),
    txt_ids=torch.randn(4, 3),
    timestep=torch.ones(1),
    return_dict=False,
)
mask = torch.ones(1, 8, dtype=torch.bool)
mask[:, -1] = False

with torch.no_grad():
    a = model(**inputs)[0]
    b = model(**inputs, attention_kwargs={"attention_mask": mask})[0]

print((a - b).abs().max().item())  # 0.0: mask was dropped

Relevant precedent:

for index_block, block in enumerate(self.transformer_blocks):
    if torch.is_grad_enabled() and self.gradient_checkpointing:
        encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
            block,
            hidden_states,
            encoder_hidden_states,
            temb,
            image_rotary_emb,
            joint_attention_kwargs,
        )
    else:
        encoder_hidden_states, hidden_states = block(
            hidden_states=hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            temb=temb,
            image_rotary_emb=image_rotary_emb,
            joint_attention_kwargs=joint_attention_kwargs,
        )

    # controlnet residual
    if controlnet_block_samples is not None:
        interval_control = len(self.transformer_blocks) / len(controlnet_block_samples)
        interval_control = int(np.ceil(interval_control))
        # For Xlabs ControlNet.
        if controlnet_blocks_repeat:
            hidden_states = (
                hidden_states + controlnet_block_samples[index_block % len(controlnet_block_samples)]
            )
        else:
            hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]

for index_block, block in enumerate(self.single_transformer_blocks):
    if torch.is_grad_enabled() and self.gradient_checkpointing:
        encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
            block,
            hidden_states,
            encoder_hidden_states,
            temb,
            image_rotary_emb,
            joint_attention_kwargs,
        )
    else:
        encoder_hidden_states, hidden_states = block(
            hidden_states=hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            temb=temb,
            image_rotary_emb=image_rotary_emb,
            joint_attention_kwargs=joint_attention_kwargs,
        )

Suggested fix:

encoder_hidden_states, hidden_states = block(
    hidden_states=hidden_states,
    encoder_hidden_states=encoder_hidden_states,
    temb=temb,
    image_rotary_emb=image_rotary_emb,
    attention_kwargs=attention_kwargs,
)

Issue 2: controlnet_single_block_samples residuals crash

Affected code:

if controlnet_single_block_samples is not None:
    interval_control = len(self.single_transformer_blocks) / len(controlnet_single_block_samples)
    interval_control = int(np.ceil(interval_control))
    hidden_states[:, encoder_hidden_states.shape[1] :, ...] = (
        hidden_states[:, encoder_hidden_states.shape[1] :, ...]
        + controlnet_single_block_samples[index_block // interval_control]
    )

Problem:
After each BriaSingleTransformerBlock, hidden_states already contains image tokens only. The code slices it with encoder_hidden_states.shape[1], producing an empty or wrong-sized slice before adding ControlNet residuals.

Impact:
Any ControlNet path using controlnet_single_block_samples fails at runtime.

Reproduction:

import torch
from diffusers import BriaTransformer2DModel

model = BriaTransformer2DModel(
    patch_size=1, in_channels=4, num_layers=1, num_single_layers=1,
    attention_head_dim=8, num_attention_heads=2, joint_attention_dim=32,
    axes_dims_rope=[0, 4, 4],
).eval()

control = [torch.zeros(1, 4, 16)]
model(
    hidden_states=torch.randn(1, 4, 4),
    encoder_hidden_states=torch.randn(1, 4, 32),
    img_ids=torch.randn(4, 3),
    txt_ids=torch.randn(4, 3),
    timestep=torch.ones(1),
    controlnet_single_block_samples=control,
)

Relevant precedent:

# controlnet residual
if controlnet_single_block_samples is not None:
    interval_control = len(self.single_transformer_blocks) / len(controlnet_single_block_samples)
    interval_control = int(np.ceil(interval_control))
    hidden_states = hidden_states + controlnet_single_block_samples[index_block // interval_control]

Suggested fix:

hidden_states = hidden_states + controlnet_single_block_samples[index_block // interval_control]
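
A regression sketch for this fix, reusing the reproduction's tiny config, could assert that the residual is actually applied rather than merely not crashing (assumes the one-line fix above is in place):

import torch
from diffusers import BriaTransformer2DModel

model = BriaTransformer2DModel(
    patch_size=1, in_channels=4, num_layers=1, num_single_layers=1,
    attention_head_dim=8, num_attention_heads=2, joint_attention_dim=32,
    axes_dims_rope=[0, 4, 4],
).eval()

def run(residual):
    # Same seed for both calls so the only difference is the ControlNet residual.
    torch.manual_seed(0)
    return model(
        hidden_states=torch.randn(1, 4, 4),
        encoder_hidden_states=torch.randn(1, 4, 32),
        img_ids=torch.randn(4, 3),
        txt_ids=torch.randn(4, 3),
        timestep=torch.ones(1),
        controlnet_single_block_samples=[residual],
        return_dict=False,
    )[0]

with torch.no_grad():
    assert not torch.allclose(run(torch.zeros(1, 4, 16)), run(torch.ones(1, 4, 16)))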

Issue 3: Custom timesteps are unusable

Affected code:

if (
    isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler)
    and self.scheduler.config["use_dynamic_shifting"]
):
    sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
    image_seq_len = latents.shape[1]
    mu = calculate_shift(
        image_seq_len,
        self.scheduler.config.base_image_seq_len,
        self.scheduler.config.max_image_seq_len,
        self.scheduler.config.base_shift,
        self.scheduler.config.max_shift,
    )
    timesteps, num_inference_steps = retrieve_timesteps(
        self.scheduler,
        num_inference_steps,
        device,
        timesteps,
        sigmas,
        mu=mu,
    )
else:
    # 4. Prepare timesteps
    # Sample from training sigmas
    if isinstance(self.scheduler, DDIMScheduler) or isinstance(
        self.scheduler, EulerAncestralDiscreteScheduler
    ):
        timesteps, num_inference_steps = retrieve_timesteps(
            self.scheduler, num_inference_steps, device, None, None
        )
    else:
        sigmas = get_original_sigmas(
            num_train_timesteps=self.scheduler.config.num_train_timesteps,
            num_inference_steps=num_inference_steps,
        )
        timesteps, num_inference_steps = retrieve_timesteps(
            self.scheduler, num_inference_steps, device, timesteps, sigmas=sigmas
        )

Problem:
timesteps is public, but the pipeline still builds sigmas and passes both timesteps and sigmas to retrieve_timesteps(), which rejects that combination. In the DDIM/Euler ancestral branch, custom timesteps are dropped by passing None.

Impact:
Users cannot run custom timestep schedules despite the documented argument.

Reproduction:

import numpy as np
from diffusers import FlowMatchEulerDiscreteScheduler
from diffusers.pipelines.flux.pipeline_flux import retrieve_timesteps

scheduler = FlowMatchEulerDiscreteScheduler()
sigmas = np.linspace(1.0, 0.5, 2)
retrieve_timesteps(scheduler, 2, "cpu", timesteps=[999, 500], sigmas=sigmas)

Relevant precedent:

if timesteps is not None and sigmas is not None:
    raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
if timesteps is not None:
    accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
    if not accepts_timesteps:
        raise ValueError(
            f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
            f" timestep schedules. Please check whether you are using the correct scheduler."
        )
    scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
    timesteps = scheduler.timesteps
    num_inference_steps = len(timesteps)
elif sigmas is not None:
    accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
    if not accept_sigmas:
        raise ValueError(
            f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
            f" sigmas schedules. Please check whether you are using the correct scheduler."
        )
    scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
    timesteps = scheduler.timesteps
    num_inference_steps = len(timesteps)
else:
    scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
    timesteps = scheduler.timesteps

Suggested fix:

sigmas = None if timesteps is not None else get_original_sigmas(
    num_train_timesteps=self.scheduler.config.num_train_timesteps,
    num_inference_steps=num_inference_steps,
)
timesteps, num_inference_steps = retrieve_timesteps(
    self.scheduler,
    num_inference_steps,
    device,
    timesteps=timesteps,
    sigmas=sigmas,
)

Issue 4: normalize=True crashes when CFG is disabled

Affected code:

# perform guidance
if self.do_classifier_free_guidance:
    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
    cfg_noise_pred_text = noise_pred_text.std()
    noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)

if normalize:
    noise_pred = noise_pred * (0.7 * (cfg_noise_pred_text / noise_pred.std())) + 0.3 * noise_pred

Problem:
cfg_noise_pred_text is defined only inside the classifier-free guidance branch, but normalize=True reads it unconditionally.

Impact:
BriaPipeline(..., guidance_scale<=1, normalize=True) fails with UnboundLocalError.

Reproduction:

import torch
from diffusers import AutoencoderKL, BriaPipeline, BriaTransformer2DModel, FlowMatchEulerDiscreteScheduler

transformer = BriaTransformer2DModel(
    patch_size=1, in_channels=16, num_layers=1, num_single_layers=1,
    attention_head_dim=8, num_attention_heads=2, joint_attention_dim=32,
    axes_dims_rope=[0, 4, 4],
)
vae = AutoencoderKL(
    block_out_channels=(32,), in_channels=3, out_channels=3,
    down_block_types=["DownEncoderBlock2D"], up_block_types=["UpDecoderBlock2D"],
    latent_channels=4, sample_size=32, shift_factor=0, scaling_factor=0.13025,
)
pipe = BriaPipeline(transformer, FlowMatchEulerDiscreteScheduler(), vae, None, None)
pipe(prompt_embeds=torch.randn(1, 4, 32), height=16, width=16, num_inference_steps=1,
     guidance_scale=1.0, normalize=True, output_type="latent")

Relevant precedent:

if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
    # Based on 3.4. in https://huggingface.co/papers/2305.08891
    noise_pred = rescale_noise_cfg(
        noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale
    )

Suggested fix:

if normalize:
    if not self.do_classifier_free_guidance:
        raise ValueError("`normalize=True` requires `guidance_scale > 1`.")
    noise_pred = noise_pred * (0.7 * (cfg_noise_pred_text / noise_pred.std())) + 0.3 * noise_pred
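
A minimal fast-test sketch for the guard, assuming `pipe` is assembled exactly as in the reproduction above:

import pytest

# With the guard in place, normalize=True without CFG fails with a clear ValueError
# instead of an UnboundLocalError deep in the denoising loop.
with pytest.raises(ValueError, match="normalize"):
    pipe(
        prompt_embeds=torch.randn(1, 4, 32), height=16, width=16, num_inference_steps=1,
        guidance_scale=1.0, normalize=True, output_type="latent",
    )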

Issue 5: Negative prompt embedding validation is not wired into __call__

Affected code:

if negative_prompt is not None and negative_prompt_embeds is not None:
    raise ValueError(
        f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
        f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
    )
if prompt_embeds is not None and negative_prompt_embeds is not None:
    if prompt_embeds.shape != negative_prompt_embeds.shape:
        raise ValueError(
            "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
            f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
            f" {negative_prompt_embeds.shape}."
        )

# 1. Check inputs. Raise error if not correct
self.check_inputs(
    prompt=prompt,
    height=height,
    width=width,
    prompt_embeds=prompt_embeds,
    callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
    max_sequence_length=max_sequence_length,
)

Problem:
check_inputs() can validate negative_prompt and negative_prompt_embeds, but __call__() does not pass them. Shape mismatches reach torch.cat() and fail with a low-level tensor error.

Impact:
Invalid user inputs are not caught with the intended clear ValueError.

Reproduction:

import torch
from diffusers import AutoencoderKL, BriaPipeline, BriaTransformer2DModel, FlowMatchEulerDiscreteScheduler

transformer = BriaTransformer2DModel(
    patch_size=1, in_channels=16, num_layers=1, num_single_layers=1,
    attention_head_dim=8, num_attention_heads=2, joint_attention_dim=32,
    axes_dims_rope=[0, 4, 4],
)
vae = AutoencoderKL(
    block_out_channels=(32,), in_channels=3, out_channels=3,
    down_block_types=["DownEncoderBlock2D"], up_block_types=["UpDecoderBlock2D"],
    latent_channels=4, sample_size=32, shift_factor=0, scaling_factor=0.13025,
)
pipe = BriaPipeline(transformer, FlowMatchEulerDiscreteScheduler(), vae, None, None)
pipe(prompt_embeds=torch.randn(1, 4, 32), negative_prompt_embeds=torch.randn(1, 5, 32),
     height=16, width=16, num_inference_steps=1, guidance_scale=5.0, output_type="latent")

Relevant precedent:

if prompt_embeds is not None and pooled_prompt_embeds is None:
    raise ValueError(
        "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
    )
if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
    raise ValueError(
        "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
    )
Suggested fix:

self.check_inputs(
    prompt=prompt,
    height=height,
    width=width,
    negative_prompt=negative_prompt,
    prompt_embeds=prompt_embeds,
    negative_prompt_embeds=negative_prompt_embeds,
    callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
    max_sequence_length=max_sequence_length,
)

Issue 6: guidance_embeds=True cannot construct or run correctly

Affected code:

if guidance_embeds:
    self.guidance_embed = BriaTimestepProjEmbeddings(embedding_dim=self.inner_dim)

timestep = timestep.to(hidden_states.dtype)
if guidance is not None:
    guidance = guidance.to(hidden_states.dtype)
else:
    guidance = None
temb = self.time_embed(timestep, dtype=hidden_states.dtype)
if guidance:
    temb += self.guidance_embed(guidance, dtype=hidden_states.dtype)

Problem:
BriaTimestepProjEmbeddings requires time_theta, but guidance_embed is constructed without it, so the model cannot even be instantiated with guidance_embeds=True. The forward path then evaluates `if guidance:` on a tensor, which raises for batch sizes above 1 and silently skips the embedding whenever guidance is zero.

Impact:
Any config or future checkpoint with guidance_embeds=True fails.

Reproduction:

from diffusers import BriaTransformer2DModel

BriaTransformer2DModel(
    patch_size=1, in_channels=4, num_layers=1, num_single_layers=1,
    attention_head_dim=8, num_attention_heads=2, joint_attention_dim=32,
    axes_dims_rope=[0, 4, 4], guidance_embeds=True,
)

Relevant precedent:

timestep = timestep.to(hidden_states.dtype) * 1000
if guidance is not None:
    guidance = guidance.to(hidden_states.dtype) * 1000
temb = (
    self.time_text_embed(timestep, pooled_projections)
    if guidance is None
    else self.time_text_embed(timestep, guidance, pooled_projections)
)

Suggested fix:

if guidance_embeds:
    self.guidance_embed = BriaTimestepProjEmbeddings(embedding_dim=self.inner_dim, time_theta=time_theta)

...
if guidance is not None:
    temb += self.guidance_embed(guidance, dtype=hidden_states.dtype)
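
A smoke-test sketch for the fixed path, assuming the constructor exposes a `time_theta` config with a default and that forward() accepts a per-sample `guidance` tensor, as the affected code suggests:

import torch
from diffusers import BriaTransformer2DModel

# Build the guidance-embedded variant and run a batch of 2, where `if guidance:` on a
# tensor was previously ambiguous; `if guidance is not None:` handles it cleanly.
model = BriaTransformer2DModel(
    patch_size=1, in_channels=4, num_layers=1, num_single_layers=1,
    attention_head_dim=8, num_attention_heads=2, joint_attention_dim=32,
    axes_dims_rope=[0, 4, 4], guidance_embeds=True,
).eval()

with torch.no_grad():
    out = model(
        hidden_states=torch.randn(2, 4, 4),
        encoder_hidden_states=torch.randn(2, 4, 32),
        img_ids=torch.randn(4, 3),
        txt_ids=torch.randn(4, 3),
        timestep=torch.ones(2),
        guidance=torch.tensor([3.5, 3.5]),
        return_dict=False,
    )[0]
print(out.shape)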

Issue 7: Pipeline has dead LoRA code but no LoRA loader mixin

Affected code:

from ...loaders import FluxLoraLoaderMixin

class BriaPipeline(DiffusionPipeline):

# set lora scale so that monkey patched LoRA
# function of text encoder can correctly access it
if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin):
    self._lora_scale = lora_scale
    # dynamically adjust the LoRA scale
    if self.text_encoder is not None and USE_PEFT_BACKEND:
        scale_lora_layers(self.text_encoder, lora_scale)
prompt = [prompt] if isinstance(prompt, str) else prompt
if prompt is not None:
    batch_size = len(prompt)
else:
    batch_size = prompt_embeds.shape[0]
if prompt_embeds is None:
    prompt_embeds = self._get_t5_prompt_embeds(
        prompt=prompt,
        num_images_per_prompt=num_images_per_prompt,
        max_sequence_length=max_sequence_length,
        device=device,
    )
if do_classifier_free_guidance and negative_prompt_embeds is None:
    if not is_ng_none(negative_prompt):
        negative_prompt = (
            batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
        )
        if prompt is not None and type(prompt) is not type(negative_prompt):
            raise TypeError(
                f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                f" {type(prompt)}."
            )
        elif batch_size != len(negative_prompt):
            raise ValueError(
                f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                " the batch size of `prompt`."
            )
        negative_prompt_embeds = self._get_t5_prompt_embeds(
            prompt=negative_prompt,
            num_images_per_prompt=num_images_per_prompt,
            max_sequence_length=max_sequence_length,
            device=device,
        )
    else:
        negative_prompt_embeds = torch.zeros_like(prompt_embeds)
if self.text_encoder is not None:
    if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
        # Retrieve the original scale by scaling back the LoRA layers
        unscale_lora_layers(self.text_encoder, lora_scale)

Problem:
BriaPipeline imports and checks FluxLoraLoaderMixin, but the class does not inherit it. Therefore isinstance(self, FluxLoraLoaderMixin) is always false and public LoRA loader methods are absent.

Impact:
Users cannot call pipe.load_lora_weights(), and attention_kwargs={"scale": ...} cannot scale text-encoder LoRA layers through this pipeline path.

Reproduction:

from diffusers import BriaPipeline, FluxPipeline

print(hasattr(BriaPipeline, "load_lora_weights"))  # False
print(hasattr(FluxPipeline, "load_lora_weights"))  # True

Relevant precedent:

class FluxPipeline(
    DiffusionPipeline,
    FluxLoraLoaderMixin,
    FromSingleFileMixin,
    TextualInversionLoaderMixin,
    FluxIPAdapterMixin,
):

Suggested fix:

class BriaPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
    ...

If Bria LoRA loading is intentionally unsupported, remove the unused import, isinstance branches, and lora_scale plumbing instead.
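
If the mixin route is taken, the expected user flow would be roughly the sketch below; the repo id is a placeholder and this assumes Bria LoRA checkpoints follow the key layout FluxLoraLoaderMixin expects, which should be verified against a real Bria LoRA:

# Hypothetical usage once BriaPipeline inherits FluxLoraLoaderMixin; the repo id is a placeholder.
pipe.load_lora_weights("some-org/some-bria-lora")
# The `scale` entry is what the pipeline's lora_scale plumbing is meant to consume.
image = pipe(prompt="a photo of a cat", attention_kwargs={"scale": 0.8}).images[0]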

Issue 8: BriaPipelineOutput is not exported by the lazy package

Affected code:

_dummy_objects = {}
_import_structure = {}
try:
    if not (is_transformers_available() and is_torch_available()):
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    from ...utils import dummy_torch_and_transformers_objects  # noqa F403

    _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
    _import_structure["pipeline_bria"] = ["BriaPipeline"]

@dataclass
class BriaPipelineOutput(BaseOutput):
    """
    Output class for Bria pipelines.

    Args:
        images (`list[PIL.Image.Image]` or `np.ndarray`)
            list of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
            num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
    """

    images: list[PIL.Image.Image] | np.ndarray

Problem:
pipeline_output.py defines BriaPipelineOutput, but src/diffusers/pipelines/bria/__init__.py does not include it in _import_structure.

Impact:
The output type is public in the pipeline return docs but cannot be imported from diffusers.pipelines.bria.

Reproduction:

from diffusers.pipelines.bria.pipeline_output import BriaPipelineOutput  # works: direct module import
print(BriaPipelineOutput.__name__)

from diffusers.pipelines.bria import BriaPipelineOutput  # fails: name not exported by the lazy module

Relevant precedent:

_import_structure = {"pipeline_output": ["FluxPipelineOutput", "FluxPriorReduxPipelineOutput"]}

Suggested fix:

_import_structure = {"pipeline_output": ["BriaPipelineOutput"]}

...
else:
    from .pipeline_bria import BriaPipeline
    from .pipeline_output import BriaPipelineOutput
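
A quick check once the export is wired, confirming the lazy name and the direct module attribute resolve to the same class:

from diffusers.pipelines.bria import BriaPipelineOutput as lazy_cls
from diffusers.pipelines.bria.pipeline_output import BriaPipelineOutput as direct_cls

# Both import paths should point at the same dataclass.
assert lazy_cls is direct_cls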

Issue 9: Transformer is missing attention/offload integration metadata

Affected code:

class BriaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin):
    """
    The Transformer model introduced in Flux. Based on FluxPipeline with several changes:
    - no pooled embeddings
    - We use zero padding for prompts
    - No guidance embedding since this is not a distilled version

    Reference: https://blackforestlabs.ai/announcing-black-forest-labs/

    Parameters:
        patch_size (`int`): Patch size to turn the input data into small patches.
        in_channels (`int`, *optional*, defaults to 16): The number of channels in the input.
        num_layers (`int`, *optional*, defaults to 18): The number of layers of MMDiT blocks to use.
        num_single_layers (`int`, *optional*, defaults to 18): The number of layers of single DiT blocks to use.
        attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head.
        num_attention_heads (`int`, *optional*, defaults to 18): The number of heads to use for multi-head attention.
        joint_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
        pooled_projection_dim (`int`): Number of dimensions to use when projecting the `pooled_projections`.
        guidance_embeds (`bool`, defaults to False): Whether to use guidance embeddings.
    """

    _supports_gradient_checkpointing = True
Problem:
BriaTransformer2DModel does not inherit AttentionMixin and does not define _no_split_modules. Comparable transformer models expose model-level attention processor APIs and protect transformer blocks from unsafe device_map splitting.

Impact:
Users cannot call standard APIs like set_attn_processor() / fuse_qkv_projections(), and automatic placement/offload can split residual transformer blocks.

Reproduction:

from diffusers import BriaTransformer2DModel, FluxTransformer2DModel

print(hasattr(BriaTransformer2DModel, "set_attn_processor"))  # False
print(getattr(BriaTransformer2DModel, "_no_split_modules", None))  # None
print(FluxTransformer2DModel._no_split_modules)

Relevant precedent:

class FluxTransformer2DModel(
    ModelMixin,
    ConfigMixin,
    PeftAdapterMixin,
    FromOriginalModelMixin,
    FluxTransformer2DLoadersMixin,
    CacheMixin,
    AttentionMixin,
):
    """
    The Transformer model introduced in Flux.

    Reference: https://blackforestlabs.ai/announcing-black-forest-labs/

    Args:
        patch_size (`int`, defaults to `1`):
            Patch size to turn the input data into small patches.
        in_channels (`int`, defaults to `64`):
            The number of channels in the input.
        out_channels (`int`, *optional*, defaults to `None`):
            The number of channels in the output. If not specified, it defaults to `in_channels`.
        num_layers (`int`, defaults to `19`):
            The number of layers of dual stream DiT blocks to use.
        num_single_layers (`int`, defaults to `38`):
            The number of layers of single stream DiT blocks to use.
        attention_head_dim (`int`, defaults to `128`):
            The number of dimensions to use for each attention head.
        num_attention_heads (`int`, defaults to `24`):
            The number of attention heads to use.
        joint_attention_dim (`int`, defaults to `4096`):
            The number of dimensions to use for the joint attention (embedding/channel dimension of
            `encoder_hidden_states`).
        pooled_projection_dim (`int`, defaults to `768`):
            The number of dimensions to use for the pooled projection.
        guidance_embeds (`bool`, defaults to `False`):
            Whether to use guidance embeddings for guidance-distilled variant of the model.
        axes_dims_rope (`tuple[int]`, defaults to `(16, 56, 56)`):
            The dimensions to use for the rotary positional embeddings.
    """

    _supports_gradient_checkpointing = True
    _no_split_modules = ["FluxTransformerBlock", "FluxSingleTransformerBlock"]
    _skip_layerwise_casting_patterns = ["pos_embed", "norm"]
    _repeated_blocks = ["FluxTransformerBlock", "FluxSingleTransformerBlock"]

Suggested fix:

from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward

class BriaTransformer2DModel(..., CacheMixin, AttentionMixin):
    _no_split_modules = ["BriaTransformerBlock", "BriaSingleTransformerBlock"]
    _skip_layerwise_casting_patterns = ["pos_embed", "norm"]
    _repeated_blocks = ["BriaTransformerBlock", "BriaSingleTransformerBlock"]

Issue 10: RoPE still selects float64 on NPU

Affected code:

pos = ids.float()
is_mps = ids.device.type == "mps"
freqs_dtype = torch.float32 if is_mps else torch.float64

pos = ids.float()
is_mps = ids.device.type == "mps"
freqs_dtype = torch.float32 if is_mps else torch.float64

Problem:
The RoPE helpers gate float64 only for MPS. The review rules call out NPU as another backend where float64 is unsupported.

Impact:
NPU users can hit unsupported dtype failures in positional embedding creation.

Reproduction:

import torch
from diffusers.models.transformers.transformer_bria import BriaEmbedND

if not hasattr(torch, "npu") or not torch.npu.is_available():
    print("Requires an NPU runtime")
else:
    ids = torch.zeros(4, 3, device="npu")
    BriaEmbedND(theta=10000, axes_dim=[0, 4, 4])(ids)

Relevant precedent:

pos = ids.float()
is_mps = ids.device.type == "mps"
is_npu = ids.device.type == "npu"
freqs_dtype = torch.float32 if (is_mps or is_npu) else torch.float64

Suggested fix:

is_mps = ids.device.type == "mps"
is_npu = ids.device.type == "npu"
freqs_dtype = torch.float32 if (is_mps or is_npu) else torch.float64

Issue 11: Public __call__ docstring contains unresolved conflict markers

Affected code:

guidance_scale (`float`, *optional*, defaults to 5.0):
<<<<<<< HEAD
Guidance scale as defined in [Classifier-Free Diffusion
Guidance](https://arxiv.org/abs/2207.12598). `guidance_scale` is defined as `w` of equation 2.
of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting
`guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely
linked to the text `prompt`, usually at the expense of lower image quality.
negative_prompt (`str` or `list[str]`, *optional*):
=======
Guidance scale as defined in [Classifier-Free Diffusion
Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of
equation 2. of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is
enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images
that are closely linked to the text `prompt`, usually at the expense of lower image quality.
negative_prompt (`str` or `list[str]`, *optional*):
>>>>>>> main

Problem:
The public BriaPipeline.__call__ docstring still contains <<<<<<< HEAD, =======, and >>>>>>> main.

Impact:
Generated API docs expose merge-conflict text, and the file violates the review rule against ephemeral PR-context artifacts.

Reproduction:

from diffusers import BriaPipeline

doc = BriaPipeline.__call__.__doc__ or ""
print("<<<<<<<" in doc, "=======" in doc, ">>>>>>>" in doc)

Relevant precedent:
N/A.

Suggested fix:

# Delete the conflict markers and keep one resolved guidance_scale paragraph.
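
For concreteness, keeping the `main` side of the conflict (the huggingface.co paper links), the resolved docstring entry would read:

guidance_scale (`float`, *optional*, defaults to 5.0):
    Guidance scale as defined in [Classifier-Free Diffusion
    Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of
    equation 2. of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is
    enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images
    that are closely linked to the text `prompt`, usually at the expense of lower image quality.
negative_prompt (`str` or `list[str]`, *optional*):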
