Support for bf16 Beta #12

@icavan

Description

Add support for bfloat16 (BF16) beta input in the KDA kernels. Currently, beta must be float32.

Context

In the current implementation, beta and initial_state must be provided in float32. Supporting BF16 beta input would halve the memory needed to store beta (2 bytes per element instead of 4) and reduce the memory traffic spent loading it. This matters most in inference, where mixed precision is common.
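
To put the saving in numbers, here is a small sketch. It assumes beta holds one scalar per token and head, i.e. a hypothetical [batch, seq_len, heads] layout that is not confirmed from the KDA code; the function name beta_bytes and the example shape are illustrative only.

```python
# Assumed layout: one beta scalar per (batch, token, head) -> shape [B, T, H].
# This layout and the shape below are illustrative, not taken from the KDA code.
BYTES_PER_ELEMENT = {"float32": 4, "bfloat16": 2}


def beta_bytes(batch: int, seq_len: int, heads: int, dtype: str) -> int:
    """Bytes needed to store a beta tensor of shape [batch, seq_len, heads]."""
    return batch * seq_len * heads * BYTES_PER_ELEMENT[dtype]


if __name__ == "__main__":
    shape = (1, 32768, 32)  # hypothetical: batch 1, 32k tokens, 32 heads
    fp32 = beta_bytes(*shape, "float32")
    bf16 = beta_bytes(*shape, "bfloat16")
    print(f"float32: {fp32 / 2**20:.1f} MiB, bfloat16: {bf16 / 2**20:.1f} MiB")
```

Under this assumed layout the absolute saving is modest next to q, k and v, which carry a full feature dimension per head. The relative saving on beta itself is always 2x.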

Tasks

  • Modify the KDA kernels to accept BF16 beta input
  • Upcast beta to float32 inside the kernel where needed for numerical stability
  • Update tests to cover BF16 beta
  • Update documentation and usage notes
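
The upcast task could look roughly like the following stdlib-only Python sketch. It emulates BF16 storage by keeping the top 16 bits of a float32 (round-to-nearest-even, NaN handling omitted) and shows why a kernel that loads BF16 beta would typically upcast it to float32 before doing arithmetic. In this toy, repeatedly adding beta to a running sum stalls once the sum's BF16 spacing exceeds the increment, while upcasting first keeps the sum exact. The helpers (to_bf16, bf16_to_f32, accumulate) are hypothetical and not part of the KDA code. If the kernel is written in Triton, the corresponding step would be a cast right after loading, e.g. `tl.load(...).to(tl.float32)`.

```python
import struct


def f32_bits(x: float) -> int:
    """IEEE-754 float32 bit pattern of x (x is first rounded to float32)."""
    return struct.unpack("<I", struct.pack("<f", x))[0]


def bits_to_f32(bits: int) -> float:
    """Interpret a 32-bit pattern as a float32 value."""
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def to_bf16(x: float) -> int:
    """Round x to bfloat16 (round-to-nearest-even); returns the 16-bit pattern.
    NaN handling is omitted to keep the sketch short."""
    bits = f32_bits(x)
    rounding_bias = 0x7FFF + ((bits >> 16) & 1)
    return ((bits + rounding_bias) >> 16) & 0xFFFF


def bf16_to_f32(b: int) -> float:
    """Upcast BF16 -> float32. Exact, since BF16 is the top half of a float32."""
    return bits_to_f32(b << 16)


def accumulate(beta_bf16: int, steps: int, upcast: bool) -> float:
    """Toy workload: add beta to a running sum `steps` times.
    upcast=True : keep the sum in float32 (store BF16, compute in FP32).
    upcast=False: round the sum back to BF16 after every step."""
    term = bf16_to_f32(beta_bf16)
    acc = 0.0
    for _ in range(steps):
        if upcast:
            acc = bits_to_f32(f32_bits(acc + term))
        else:
            acc = bf16_to_f32(to_bf16(acc + term))
    return acc


if __name__ == "__main__":
    beta = to_bf16(0.01)  # beta as it would arrive in BF16
    exact = 1000 * bf16_to_f32(beta)
    print("exact          :", exact)
    print("FP32 accumulate:", accumulate(beta, 1000, upcast=True))
    print("BF16 accumulate:", accumulate(beta, 1000, upcast=False))
```

The same helpers can back the test task: storing beta in BF16 costs at most about 2^-8 relative error per element, so BF16-beta outputs can be compared against the float32-beta path with a tolerance rather than for exact equality.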

Metadata

Labels

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests