
Sana fails with BFloat16 and tiled VAE decode #10590

@vladmandic

Description


Describe the bug

Tried the new Sana 4K model; it fails during VAE decode when running in bfloat16 with tiled decode enabled.
The issue is that SanaMultiscaleAttnProcessor2_0 performs the float32 upcast only conditionally:

        if use_linear_attention:
            # for linear attention upcast hidden_states to float32
            hidden_states = hidden_states.to(dtype=torch.float32)

Without tiled decode this condition holds for the entire execution, so the upcast always happens.
But with tiled decode, each tile's smaller spatial size may select quadratic attention instead of linear attention, and that path ALSO needs upcasting.
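
A minimal sketch of one possible fix (an illustration, not the upstream patch): make the quadratic path dtype-safe by casting `scores` back to `value`'s dtype before the final matmul. The method body above the lines shown in the traceback below is reconstructed from the diffusers source and should be checked against the actual code; the `eps` default is an assumption.

import torch

# Illustrative sketch of a dtype-safe quadratic attention step. The
# normalization stays in float32 (matching the linear path's upcast), and
# `scores` is cast back to `value`'s dtype so both matmul operands agree
# when the model runs in bfloat16.
def quadratic_attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, eps: float = 1e-15) -> torch.Tensor:
    scores = torch.matmul(key.transpose(-1, -2), query)
    scores = scores.to(dtype=torch.float32)
    scores = scores / (torch.sum(scores, dim=2, keepdim=True) + eps)
    hidden_states = torch.matmul(value, scores.to(value.dtype))  # the added cast
    return hidden_states

# bf16 inputs no longer trip the dtype check in the final matmul
q = torch.randn(1, 32, 64, dtype=torch.bfloat16)
k = torch.randn(1, 32, 64, dtype=torch.bfloat16)
v = torch.randn(1, 32, 64, dtype=torch.bfloat16)
out = quadratic_attention(q, k, v)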

Reproduction

Run the Sana 4K model in bf16 with tiled VAE decode enabled, for example as in the sketch below.
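
A hedged reproduction sketch; the checkpoint id is an assumption for illustration, so substitute whichever Sana 4K checkpoint you use:

import torch
from diffusers import SanaPipeline

# checkpoint id below is assumed; replace with your local Sana 4K checkpoint
pipe = SanaPipeline.from_pretrained(
    "Efficient-Large-Model/Sana_1600M_4Kpx_BF16_diffusers",
    torch_dtype=torch.bfloat16,
).to("cuda")
pipe.vae.enable_tiling()  # tiled VAE decode: small tiles can take the quadratic attention path

image = pipe("a photo of a cat", height=4096, width=4096).images[0]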

Logs

/home/vlado/dev/sdnext/venv/lib/python3.12/site-packages/diffusers/models/autoencoders/autoencoder_dc.py:93 in forward

     92     def forward(self, x: torch.Tensor) -> torch.Tensor:
 ❱   93         x = self.attn(x)
     94         x = self.conv_out(x)

/home/vlado/dev/sdnext/venv/lib/python3.12/site-packages/torch/nn/modules/module.py:1736 in _wrapped_call_impl

   1735         else:
 ❱ 1736             return self._call_impl(*args, **kwargs)
   1737

/home/vlado/dev/sdnext/venv/lib/python3.12/site-packages/torch/nn/modules/module.py:1747 in _call_impl

   1746                 or _global_forward_hooks or _global_forward_pre_hooks):
 ❱ 1747             return forward_call(*args, **kwargs)
   1748

/home/vlado/dev/sdnext/venv/lib/python3.12/site-packages/diffusers/models/attention_processor.py:906 in forward

    905     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
 ❱  906         return self.processor(self, hidden_states)
    907

/home/vlado/dev/sdnext/venv/lib/python3.12/site-packages/diffusers/models/attention_processor.py:5810 in __call__

   5809         else:
 ❱ 5810             hidden_states = attn.apply_quadratic_attention(query, key, value)
   5811

/home/vlado/dev/sdnext/venv/lib/python3.12/site-packages/diffusers/models/attention_processor.py:902 in apply_quadratic_attention

    901         scores = scores / (torch.sum(scores, dim=2, keepdim=True) + self.eps)
 ❱  902         hidden_states = torch.matmul(value, scores)
    903         return hidden_states

RuntimeError: expected scalar type Float but found BFloat16
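
The failure reproduces in isolation: torch.matmul requires both operands to share a dtype, so a bf16 `value` against float32 `scores` raises the same class of error (exact wording may vary by device):

import torch

value = torch.randn(1, 32, 64, dtype=torch.bfloat16)  # stays in model dtype
scores = torch.randn(1, 64, 64, dtype=torch.float32)  # upcast for normalization

# raises a dtype-mismatch RuntimeError like the one above
torch.matmul(value, scores)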

System Info

diffusers==c944f06
python==3.12.3

Who can help?

cc: @yiyixuxu @sayakpaul @DN6 @lawrence-cj
