
[Feature Request] support dtype in mlx.core's module initialization #1232

Closed
pshishodia-kgp opened this issue Jun 25, 2024 · 1 comment
Labels: wontfix (This will not be worked on)

Comments

@pshishodia-kgp

Describe the bug
I want to create a linear layer with dtype=bfloat16.

In MLX, the constructors of the mlx.nn modules (Linear, Embedding, LayerNorm, etc.) don't accept a dtype argument, so I end up doing the following:

import mlx.nn as nn
import mlx.core as mx

bf16_layer = nn.Linear(10, 10)
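# The constructor has no dtype argument, so cast each parameter manually.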
bf16_layer.weight = bf16_layer.weight.astype(mx.bfloat16)
bf16_layer.bias = bf16_layer.bias.astype(mx.bfloat16)

Expected behavior
Expecting dtype support at initialization, similar to PyTorch:

import torch
import torch.nn as nn

bf16_layer = nn.Linear(10, 10, dtype=torch.bfloat16)

Desktop (please complete the following information):

  • OS Version: macOS 14.5
@angeloskath
Member

The standard way in MLX would be to write model.set_dtype(mx.bfloat16), which is syntactic sugar over model.apply(lambda x: x.astype(mx.bfloat16) if mx.issubdtype(x.dtype, mx.floating) else x).

See more at https://ml-explore.github.io/mlx/build/html/python/nn/_autosummary/mlx.nn.Module.set_dtype.html.
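
A minimal sketch of that approach applied to the layer from the issue (assuming the same shapes as above):

import mlx.core as mx
import mlx.nn as nn

# Build the layer in the default dtype, then cast every floating-point
# parameter to bfloat16 in a single call.
bf16_layer = nn.Linear(10, 10)
bf16_layer.set_dtype(mx.bfloat16)

print(bf16_layer.weight.dtype)  # expected: bfloat16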

awni closed this as completed Jun 25, 2024
awni added the wontfix label Jun 25, 2024