diff --git a/examples/modeling/modeling_doge.py b/examples/modeling/modeling_doge.py index e67cca2..804e145 100644 --- a/examples/modeling/modeling_doge.py +++ b/examples/modeling/modeling_doge.py @@ -304,7 +304,7 @@ def forward( attention_mask=attention_mask, ) - attention_interface: Callable = flash_dmattn_func_auto(backend="flex") + attention_interface: Callable = flash_dmattn_func_auto(backend="cuda") query_states = query_states.transpose(1, 2).contiguous() # [B, H, Q_LEN, D] key_states = key_states.transpose(1, 2).contiguous() # [B, H, KV_LEN, D] value_states = value_states.transpose(1, 2).contiguous() # [B, H, KV_LEN, D]