# Bug Report

Streaming TTS inference crashes immediately on any device using SDPA attention — this includes:

- Apple Silicon (MPS) — all Mac users
- CPU — any machine without a GPU
- CUDA without `flash_attn` — e.g. fresh installs, Docker images without flash-attn

The model loads successfully but crashes on the first text-window forward pass during generation.
## Error

```text
RuntimeError: The expanded size of the tensor (228) must match the existing size (223)
at non-singleton dimension 3.
Target sizes: [1, 14, 5, 228]. Tensor sizes: [1, 1, 5, 223]
```
## Steps to reproduce

```bash
# On any Mac, or any machine without flash_attn installed
python demo/vibevoice_realtime_demo.py \
    --model_path microsoft/VibeVoice-Realtime-0.5B \
    --device mps  # or cpu

# Open browser → type any text → crash on first generation
```
## Root cause

`MockCacheLayer.get_mask_sizes()` does not include `query_length` in the returned `kv_length`, violating the `DynamicLayer` contract in transformers 4.57.
The canonical implementation (`transformers.cache_utils.DynamicLayer`, L123-128):

```python
def get_mask_sizes(self, cache_position):
    query_length = cache_position.shape[0]
    kv_length = self.get_seq_length() + query_length  # ← includes query_length
    return kv_length, 0
```
Current `MockCacheLayer`:

```python
def get_mask_sizes(self, cache_position):
    kv_length = self.key_cache.shape[2]  # ← missing query_length
    return kv_length, 0
```
This makes the causal mask `query_length` tokens too short. SDPA strictly requires the mask to match the KV tensor dimensions, so it crashes. Flash Attention 2 computes causality internally and never calls `get_mask_sizes`, which is why the bug stays latent on CUDA with `flash_attn`.
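The mismatch can be reproduced with plain Python — no torch required. The numbers below mirror the traceback (223 tokens already cached, 5 new query tokens, so the KV tensors are 228 wide); `_FakeTensor` and `BrokenMockCacheLayer` are illustrative stand-ins, not the real transformers types:

```python
class _FakeTensor:
    """Minimal stand-in for a torch tensor; only .shape is exercised here."""
    def __init__(self, shape):
        self.shape = shape

class BrokenMockCacheLayer:
    def __init__(self, cached_tokens):
        # key_cache shape: [batch, heads, seq, head_dim]
        self.key_cache = _FakeTensor((1, 14, cached_tokens, 64))

    def get_mask_sizes(self, cache_position):
        kv_length = self.key_cache.shape[2]  # forgets the incoming query tokens
        return kv_length, 0

cache_position = _FakeTensor((5,))            # 5 new text-window tokens
layer = BrokenMockCacheLayer(cached_tokens=223)
mask_kv, _ = layer.get_mask_sizes(cache_position)

kv_tensor_len = 223 + 5                       # what the KV tensors actually hold
print(mask_kv, kv_tensor_len)                 # 223 vs 228 — the mismatch SDPA rejects
```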
## Environment

| Component    | Version      |
|--------------|--------------|
| macOS        | 26.4         |
| Hardware     | Apple M4 Max |
| Python       | 3.11.15      |
| PyTorch      | 2.11.0       |
| Transformers | 4.57.6       |
## Fix

A one-line fix is available in #303.
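For reference, a runnable sketch of the corrected method following the `DynamicLayer` contract above (the authoritative one-line patch is in #303; `_FakeTensor` is a stand-in so the snippet runs without torch):

```python
class _FakeTensor:
    """Minimal stand-in for a torch tensor; only .shape is used here."""
    def __init__(self, shape):
        self.shape = shape

class MockCacheLayer:
    def __init__(self, cached_tokens):
        # key_cache shape: [batch, heads, seq, head_dim]
        self.key_cache = _FakeTensor((1, 14, cached_tokens, 64))

    def get_mask_sizes(self, cache_position):
        query_length = cache_position.shape[0]
        # the fix: count the incoming query tokens, as DynamicLayer does
        kv_length = self.key_cache.shape[2] + query_length
        return kv_length, 0

layer = MockCacheLayer(cached_tokens=223)
kv_length, offset = layer.get_mask_sizes(_FakeTensor((5,)))
print(kv_length)  # 228 — now matches dimension 3 of the KV tensors
```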