
[BUG] Neural network got ~65% slower since MLX v0.23.1 #1963

@gabrieldemarmiesse

Description


Describe the bug
We have a neural network that we want to release for inference. The latest version of mlx is around 15% slower than previous versions on our whole model. #1950 was a first fix, but a slowdown remains. I again tried reducing the network to track down the root cause, and arrived at a minimal reproducible example that shows the slowdown starting in 0.23.1.

To Reproduce

Include code snippet

```python
# /// script
# requires-python = "==3.12.9"
# dependencies = []
#
# ///
import time

import mlx.core as mx
import mlx.nn as nn


class LmGen(nn.Module):
    def __init__(self):
        super().__init__()  # nn.Module subclasses should call the base __init__
        # A single token id, embedded at every step.
        self.gen_sequence = mx.full(
            shape=(1, 1),
            vals=-2,
            dtype=mx.int32
        )

        self.text_emb = nn.Embedding(32001, 4096)
        # 32 bottleneck MLP blocks: 4096 -> 512 -> 4096, no biases.
        self.layers = [
            nn.Sequential(
                nn.Linear(4096, 512, bias=False),
                nn.ReLU(),
                nn.Linear(512, 4096, bias=False),
            )
            for _ in range(32)
        ]

    def step(self):
        xs = self.text_emb(self.gen_sequence)
        for layer in self.layers:
            xs = layer(xs)
        return xs


def main():
    WARMUP = 5
    TOTAL_STEPS = 100
    gen = LmGen()
    gen.set_dtype(mx.bfloat16)
    nn.quantize(gen, bits=4, group_size=32)

    sum_times = 0
    for i in range(TOTAL_STEPS):
        # Extra allocations between steps; kept because they are part of this reproduction.
        data = mx.arange(8, dtype=mx.uint32)
        uploaded_image_embeddings = mx.arange(1152000, dtype=mx.bfloat16)
        mx.eval((data, uploaded_image_embeddings))

        # Time only the evaluation of one model step.
        t1 = time.perf_counter()
        mx.eval(gen.step())
        t2 = time.perf_counter()

        if i >= WARMUP:  # skip warm-up steps
            sum_times += t2 - t1

    print(f"average time per step: {(sum_times / (TOTAL_STEPS - WARMUP)) * 1000:f} ms")


if __name__ == "__main__":
    main()
```
```console
$  uv run --with mlx==0.22.1 something.py
average time per step: 1.492164 ms
$  uv run --with mlx==0.23.1 something.py
average time per step: 2.658959 ms
$ CMAKE_BUILD_PARALLEL_LEVEL=16 uv run --with git+https://github.com/ml-explore/mlx#2770a1024082eb10cce6bc0ac589ad089e7be611 something.py
average time per step: 2.870063 ms
```

Expected behavior
Speed should stay the same, or close to it, as mlx versions increase.
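To make "the same or similar" concrete, here is a minimal sketch. It computes the relative change between two mean step times and flags a regression above a tolerance. The 5% tolerance and the `relative_slowdown`/`is_regression` names are illustrative assumptions. The inputs are the averages reported above, and `2770a10` abbreviates the tested commit.

```python
def relative_slowdown(baseline_ms: float, candidate_ms: float) -> float:
    """Fractional increase in step time: 0.10 means 10% slower than baseline."""
    if baseline_ms <= 0:
        raise ValueError("baseline must be positive")
    return candidate_ms / baseline_ms - 1.0


def is_regression(baseline_ms: float, candidate_ms: float, tolerance: float = 0.05) -> bool:
    """True if candidate is more than `tolerance` slower (hypothetical threshold)."""
    return relative_slowdown(baseline_ms, candidate_ms) > tolerance


# Average ms per step, as reported above.
reported = {"0.22.1": 1.492164, "0.23.1": 2.658959, "2770a10": 2.870063}
for version in ("0.23.1", "2770a10"):
    change = relative_slowdown(reported["0.22.1"], reported[version])
    flagged = is_regression(reported["0.22.1"], reported[version])
    print(f"{version}: {change:+.1%} vs 0.22.1, regression={flagged}")
```

With these numbers, the minimal example is about +78.2% slower on 0.23.1 and +92.3% slower on the tested commit, both relative to 0.22.1.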

Desktop (please complete the following information):

ProductName:            macOS
ProductVersion:         15.3.1
BuildVersion:           24D70
      Model Name: MacBook Air
      Model Identifier: Mac15,12
      Model Number: MXCV3FN/A
      Chip: Apple M3
      Total Number of Cores: 8 (4 performance and 4 efficiency)
      Memory: 16 GB
      System Firmware Version: 11881.81.4
      OS Loader Version: 11881.81.4
      Serial Number (system): MW9GK71RY5
      Hardware UUID: 810DA0DC-BEF2-5453-848C-AE07236C3260
      Provisioning UDID: 00008122-001A089C2129001C
      Activation Lock Status: Disabled
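For context on the per-step numbers on this machine, here is a back-of-the-envelope sketch of the bytes held by the 32 quantized MLP blocks. The embedding is left out because only one row is looked up per step. The sketch assumes that MLX's affine quantization stores one scale and one bias per group in the model dtype (bfloat16 here), i.e. 2 × 16 bits per group of 32 weights. It also assumes each step reads every MLP weight exactly once. Treat the results as estimates, not measurements.

```python
# Hypothetical back-of-the-envelope estimate; see the assumptions in the lead-in.
NUM_LAYERS = 32
D_MODEL, D_HIDDEN = 4096, 512
BITS = 4                  # nn.quantize(..., bits=4)
GROUP_SIZE = 32           # nn.quantize(..., group_size=32)
SCALE_BIAS_BITS = 2 * 16  # assumed: one bfloat16 scale + one bfloat16 bias per group


def quantized_mlp_bytes() -> int:
    """Packed 4-bit weights plus per-group scales/biases for all MLP blocks."""
    weights = NUM_LAYERS * 2 * D_MODEL * D_HIDDEN  # two bias-free Linear layers per block
    groups = weights // GROUP_SIZE
    total_bits = weights * BITS + groups * SCALE_BIAS_BITS
    return total_bits // 8


def implied_bandwidth_gb_s(step_ms: float) -> float:
    """GB/s implied if each step reads all MLP weights once (an assumption)."""
    return quantized_mlp_bytes() / (step_ms / 1000) / 1e9


print(f"MLP weights: {quantized_mlp_bytes() / 2**20:.0f} MiB")
for label, ms in (("0.22.1", 1.492164), ("0.23.1", 2.658959)):
    print(f"{label}: ~{implied_bandwidth_gb_s(ms):.1f} GB/s")
```

Under these assumptions, the MLP weights come to 80 MiB. The estimate implies roughly 56 GB/s on 0.22.1 versus roughly 32 GB/s on 0.23.1.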
