Description
Add support for bfloat16 (BF16) beta input in KDA kernels. Currently beta must be float32.
Context
In the current implementation, beta and initial_state must be provided in float32. Supporting BF16 beta input would halve the memory footprint of beta and reduce memory traffic, which matters most in inference scenarios where mixed precision is common.
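To make the trade-off concrete, here is a minimal, standard-library-only sketch of what accepting BF16 beta implies: the value is stored in 16 bits and upcast to float32 before use. This is not the KDA kernel code. The helper names, the tensor shape, and the assumption that beta lies in [0, 1] (as it would for a sigmoid output) are all hypothetical and only illustrate the storage saving and the rounding error involved.

```python
import struct


def float32_to_bf16_bits(x: float) -> int:
    """Round a finite float to BF16 (round-to-nearest-even); return the 16-bit pattern.

    NaN and infinity are not handled; this sketch only targets finite values.
    """
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    # Adding 0x7FFF plus the lowest kept bit implements round-to-nearest-even.
    bits += 0x7FFF + ((bits >> 16) & 1)
    return (bits >> 16) & 0xFFFF


def bf16_bits_to_float32(b: int) -> float:
    """Upcast a BF16 bit pattern to float32 by zero-filling the low 16 bits."""
    return struct.unpack("<f", struct.pack("<I", b << 16))[0]


def bf16_roundtrip(x: float) -> float:
    """Value a kernel would see after loading BF16 beta and upcasting it."""
    return bf16_bits_to_float32(float32_to_bf16_bits(x))


# Hypothetical beta shape: one scalar per (token, head).
n_tokens, n_heads = 4096, 16
bytes_fp32 = n_tokens * n_heads * 4  # float32 storage
bytes_bf16 = n_tokens * n_heads * 2  # BF16 storage, half the size

# Assumed beta values in [0, 1]; the relative rounding error stays below 2**-8.
betas = [0.1, 0.3, 0.5, 0.7, 0.9, 1.0]
max_rel_err = max(abs(bf16_roundtrip(b) - b) / b for b in betas)
```

Under these assumptions the upcast itself is exact, so the only precision cost is the one-time rounding of beta to BF16; whether that is acceptable for KDA would need to be checked against the existing float32 reference outputs.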
Tasks