The standard way in MLX is to write model.set_dtype(mx.bfloat16), which is syntactic sugar over model.apply(lambda x: x.astype(mx.bfloat16) if mx.issubdtype(x.dtype, mx.floating) else x).
Describe the bug
I want to create a linear layer with dtype=bfloat16.
In MLX, the initializers of the mlx.nn modules (Linear, Embedding, LayerNorm, etc.) don't accept a dtype argument, so I end up casting the parameters manually after construction.
Expected behavior
I expect dtype support in layer initialization, similar to PyTorch (e.g. torch.nn.Linear accepts a dtype keyword argument).