Describe the bug
Tried the new SANA 4k model; it fails during VAE decode when running in bfloat16 with tiled decode enabled.
The issue is that `SanaMultiscaleAttnProcessor2_0` only performs the upcast conditionally:

```python
if use_linear_attention:
    # for linear attention upcast hidden_states to float32
    hidden_states = hidden_states.to(dtype=torch.float32)
```

Without tiled decode, `use_linear_attention` remains true for the entire execution, so this works. With tiled decode, however, some tiles may fall through to quadratic attention instead of linear, and that path ALSO needs the upcast.
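A minimal sketch of one possible fix, assuming the mismatch is that `apply_quadratic_attention` upcasts `scores` to float32 for the normalization while `value` stays in bfloat16 (the RuntimeError below shows the final matmul receiving one Float and one BFloat16 operand). This is a standalone illustration, not the upstream method; `eps` stands in for the original's `self.eps`:

```python
import torch

def apply_quadratic_attention(query, key, value, eps=1e-15):
    scores = torch.matmul(key.transpose(-1, -2), query)
    # Upcast for a numerically stable normalization, mirroring what the
    # linear-attention path already does with hidden_states.
    scores = scores.to(dtype=torch.float32)
    scores = scores / (torch.sum(scores, dim=2, keepdim=True) + eps)
    # Upcast value as well so both matmul operands are float32, then cast
    # back, instead of multiplying a bfloat16 value by float32 scores.
    hidden_states = torch.matmul(value.to(scores.dtype), scores)
    return hidden_states.to(value.dtype)
```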
Reproduction
Run the SANA 4k model in bf16 with tiled VAE decode enabled.
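A minimal reproduction sketch; the checkpoint ID is an assumption, substitute whichever SANA 4k checkpoint you are using:

```python
import torch
from diffusers import SanaPipeline

# Checkpoint ID is assumed for illustration; any SANA 4k checkpoint should do.
pipe = SanaPipeline.from_pretrained(
    "Efficient-Large-Model/Sana_1600M_4Kpx_BF16_diffusers",
    torch_dtype=torch.bfloat16,
).to("cuda")

# Tiled decode splits the latents into tiles small enough that some
# attention blocks take the quadratic path instead of the linear one.
pipe.vae.enable_tiling()

image = pipe(prompt="a photo of a cat", height=4096, width=4096).images[0]
```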
Logs
```
/home/vlado/dev/sdnext/venv/lib/python3.12/site-packages/diffusers/models/autoencoders/autoencoder_dc.py:93 in forward

     92     def forward(self, x: torch.Tensor) -> torch.Tensor:
❱    93         x = self.attn(x)
     94         x = self.conv_out(x)

/home/vlado/dev/sdnext/venv/lib/python3.12/site-packages/torch/nn/modules/module.py:1736 in _wrapped_call_impl

   1735         else:
❱  1736             return self._call_impl(*args, **kwargs)
   1737

/home/vlado/dev/sdnext/venv/lib/python3.12/site-packages/torch/nn/modules/module.py:1747 in _call_impl

   1746                 or _global_forward_hooks or _global_forward_pre_hooks):
❱  1747             return forward_call(*args, **kwargs)
   1748

/home/vlado/dev/sdnext/venv/lib/python3.12/site-packages/diffusers/models/attention_processor.py:906 in forward

    905     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
❱   906         return self.processor(self, hidden_states)
    907

/home/vlado/dev/sdnext/venv/lib/python3.12/site-packages/diffusers/models/attention_processor.py:5810 in __call__

   5809         else:
❱  5810             hidden_states = attn.apply_quadratic_attention(query, key, value)
   5811

/home/vlado/dev/sdnext/venv/lib/python3.12/site-packages/diffusers/models/attention_processor.py:902 in apply_quadratic_attention

    901         scores = scores / (torch.sum(scores, dim=2, keepdim=True) + self.eps)
❱   902         hidden_states = torch.matmul(value, scores)
    903         return hidden_states

RuntimeError: expected scalar type Float but found BFloat16
```

System Info
diffusers==c944f06
python==3.12.3
Who can help?