Describe the bug
We have a neural network that we want to release for inference. On our full model, the latest version of mlx is around 15% slower than previous versions. #1950 was a first fix, but we still see a slowdown. I again reduced the network to track down the root cause and arrived at a minimal reproducible example that shows the slowdown starting in 0.23.1.
To Reproduce
```python
# /// script
# requires-python = "==3.12.9"
# dependencies = []
#
# ///
# mlx itself is supplied on the command line via `uv run --with ...`.
import time

import mlx.core as mx
import mlx.nn as nn


class LmGen(nn.Module):
    def __init__(self):
        super().__init__()
        # Fixed (1, 1) token input, reused on every step.
        self.gen_sequence = mx.full(
            shape=(1, 1),
            vals=-2,
            dtype=mx.int32,
        )
        self.text_emb = nn.Embedding(32001, 4096)
        # 32 bottleneck MLP blocks: 4096 -> 512 -> ReLU -> 4096.
        self.layers = [
            nn.Sequential(
                nn.Linear(4096, 512, bias=False),
                nn.ReLU(),
                nn.Linear(512, 4096, bias=False),
            )
            for _ in range(32)
        ]

    def step(self):
        xs = self.text_emb(self.gen_sequence)
        for layer in self.layers:
            xs = layer(xs)
        return xs


def main():
    WARMUP = 5
    TOTAL_STEPS = 100
    gen = LmGen()
    gen.set_dtype(mx.bfloat16)
    # Quantize weights to 4 bits with a group size of 32.
    nn.quantize(gen, bits=4, group_size=32)

    sum_times = 0.0
    for i in range(TOTAL_STEPS):
        # Extra allocations that are part of the repro, evaluated
        # outside the timed region.
        data = mx.arange(8, dtype=mx.uint32)
        uploaded_image_embeddings = mx.arange(1152000, dtype=mx.bfloat16)
        mx.eval((data, uploaded_image_embeddings))

        t1 = time.perf_counter()
        mx.eval(gen.step())  # MLX is lazy: eval forces the computation
        t2 = time.perf_counter()
        if i >= WARMUP:
            sum_times += t2 - t1
    print(f"average time per step: {(sum_times / (TOTAL_STEPS - WARMUP)) * 1000:1f} ms")


main()
```

```shell
$ uv run --with mlx==0.22.1 something.py
average time per step: 1.492164 ms
$ uv run --with mlx==0.23.1 something.py
average time per step: 2.658959 ms
$ CMAKE_BUILD_PARALLEL_LEVEL=16 uv run --with git+https://github.com/ml-explore/mlx#2770a1024082eb10cce6bc0ac589ad089e7be611 something.py
average time per step: 2.870063 ms
```
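The repro times each step after a warmup of 5 iterations and reports only the mean. A reusable sketch of that warmup-then-measure loop that also reports the median and minimum, which are less sensitive to one-off outliers, could look like the following. `time_steps` and `summarize` are hypothetical helper names, not part of MLX. The example assumes `step` blocks until its work is finished (for MLX, that means ending it with `mx.eval(...)`). The `setup` callable mirrors the repro's untimed per-iteration allocations.

```python
import statistics
import time
from typing import Callable, List, Optional


def time_steps(
    step: Callable[[], object],
    warmup: int = 5,
    total: int = 100,
    setup: Optional[Callable[[], object]] = None,
) -> List[float]:
    """Call `step` `total` times; return per-step wall times in ms,
    skipping the first `warmup` calls.

    `step` must block until its work is done. MLX is lazy, so an MLX step
    should end with mx.eval(...). `setup`, if given, runs before each
    step, outside the timed region.
    """
    if not 0 <= warmup < total:
        raise ValueError("expected 0 <= warmup < total")
    times_ms: List[float] = []
    for i in range(total):
        if setup is not None:
            setup()
        start = time.perf_counter()
        step()
        elapsed = time.perf_counter() - start
        if i >= warmup:
            times_ms.append(elapsed * 1000.0)
    return times_ms


def summarize(times_ms: List[float]) -> str:
    """Mean, median and min; median and min are less sensitive to outliers."""
    return (
        f"mean {statistics.mean(times_ms):.3f} ms, "
        f"median {statistics.median(times_ms):.3f} ms, "
        f"min {min(times_ms):.3f} ms"
    )


if __name__ == "__main__":
    # Pure-Python stand-in workload. With the repro script, this would be
    # something like step=lambda: mx.eval(gen.step()).
    samples = time_steps(lambda: sum(range(10_000)), warmup=5, total=100)
    print(summarize(samples))
```

Keeping the setup outside the timed region matches the original loop, where `data` and `uploaded_image_embeddings` are evaluated before `t1` is taken.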
Expected behavior
Performance should stay the same, or close to it, as mlx versions increase.
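To make "the same or similar" concrete, here is a small sketch that expresses each reported timing as a change relative to 0.22.1 and flags it when the slowdown exceeds a tolerance. The 5% tolerance is an arbitrary assumption, not an MLX policy. `relative_change` and `is_regression` are hypothetical helpers, and the timings are copied from the runs above.

```python
from typing import Dict

# Timings reported above (ms per step) for each MLX version tested.
REPORTED_MS: Dict[str, float] = {
    "0.22.1": 1.492164,
    "0.23.1": 2.658959,
    "2770a10": 2.870063,
}


def relative_change(baseline_ms: float, candidate_ms: float) -> float:
    """Fractional change of candidate vs. baseline (positive means slower)."""
    return (candidate_ms - baseline_ms) / baseline_ms


def is_regression(baseline_ms: float, candidate_ms: float, tolerance: float = 0.05) -> bool:
    """True if candidate is more than `tolerance` slower than baseline (assumed 5%)."""
    return relative_change(baseline_ms, candidate_ms) > tolerance


if __name__ == "__main__":
    base = REPORTED_MS["0.22.1"]
    for version, ms in REPORTED_MS.items():
        change = relative_change(base, ms)
        flag = "REGRESSION" if is_regression(base, ms) else "ok"
        print(f"{version:>8}: {ms:.3f} ms ({change:+.1%}) {flag}")
```

With these numbers, 0.23.1 is roughly 78% slower than 0.22.1 in this microbenchmark, and commit 2770a10 roughly 92% slower.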
Desktop:
ProductName: macOS
ProductVersion: 15.3.1
BuildVersion: 24D70
Model Name: MacBook Air
Model Identifier: Mac15,12
Model Number: MXCV3FN/A
Chip: Apple M3
Total Number of Cores: 8 (4 performance and 4 efficiency)
Memory: 16 GB
System Firmware Version: 11881.81.4
OS Loader Version: 11881.81.4
Serial Number (system): MW9GK71RY5
Hardware UUID: 810DA0DC-BEF2-5453-848C-AE07236C3260
Provisioning UDID: 00008122-001A089C2129001C
Activation Lock Status: Disabled