Quantize embedding #994

Merged · 4 commits · Apr 15, 2024

20 changes: 20 additions & 0 deletions docs/src/_templates/nn-module-template.rst
@@ -0,0 +1,20 @@
{{ fullname | escape | underline}}

.. currentmodule:: {{ module }}

.. autoclass:: {{ objname }}

{% block methods %}

{% if methods %}
.. rubric:: {{ _('Methods') }}

.. autosummary::
{% for item in methods %}
{%- if item not in inherited_members and item != "__init__" %}
~{{ name }}.{{ item }}
{%- endif %}
{%- endfor %}
{% endif %}
{% endblock %}

1 change: 1 addition & 0 deletions docs/src/python/nn.rst
@@ -173,6 +173,7 @@ In detail:
:toctree: _autosummary

value_and_grad
quantize

.. toctree::

3 changes: 2 additions & 1 deletion docs/src/python/nn/layers.rst
@@ -31,6 +31,7 @@ Layers
Mish
MultiHeadAttention
PReLU
QuantizedEmbedding
QuantizedLinear
RMSNorm
ReLU
@@ -43,4 +44,4 @@ Layers
Softshrink
Step
Transformer
Upsample
Upsample
1 change: 1 addition & 0 deletions docs/src/python/tree_utils.rst
@@ -19,3 +19,4 @@ return python trees will be using the default python ``dict``, ``list`` and
tree_flatten
tree_unflatten
tree_map
tree_map_with_path
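
For context, `tree_map_with_path` (newly documented above) behaves like `tree_map` but also passes each leaf's path to the callback, which is what the new `nn.quantize` relies on below. A minimal sketch; the example tree is illustrative and the dotted path format is an assumption:

```python
from mlx.utils import tree_map_with_path

tree = {"layers": [{"w": 1.0, "b": 0.5}, {"w": 2.0}]}

# The callback receives (path, leaf); paths are assumed to be dotted strings
# such as "layers.0.w". Here each leaf is simply tagged with its path.
labeled = tree_map_with_path(lambda path, x: (path, x), tree)
print(labeled)
```
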
2 changes: 1 addition & 1 deletion python/mlx/nn/layers/__init__.py
@@ -61,7 +61,7 @@
)
from mlx.nn.layers.pooling import AvgPool1d, AvgPool2d, MaxPool1d, MaxPool2d
from mlx.nn.layers.positional_encoding import ALiBi, RoPE, SinusoidalPositionalEncoding
from mlx.nn.layers.quantized import QuantizedLinear
from mlx.nn.layers.quantized import QuantizedEmbedding, QuantizedLinear, quantize
from mlx.nn.layers.recurrent import GRU, LSTM, RNN
from mlx.nn.layers.transformer import (
MultiHeadAttention,
13 changes: 11 additions & 2 deletions python/mlx/nn/layers/embedding.py
@@ -1,4 +1,4 @@
# Copyright © 2023 Apple Inc.
# Copyright © 2023-2024 Apple Inc.

import math

@@ -14,7 +14,7 @@ class Embedding(Module):

Args:
num_embeddings (int): How many possible discrete tokens can we embed.
Usually called the vocabulary size.
Usually called the vocabulary size.
dims (int): The dimensionality of the embeddings.
"""

@@ -28,3 +28,12 @@ def _extra_repr(self):

def __call__(self, x):
return self.weight[x]

def as_linear(self, x):
"""
Call the embedding layer as a linear layer.

Use this for example when input embedding and output projection
weights are tied.
"""
return x @ self.weight.T
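
The new `Embedding.as_linear` enables weight tying between the input embedding and the output projection. A minimal sketch of that pattern; the toy model and dimensions are illustrative, not part of this diff:

```python
import mlx.core as mx
import mlx.nn as nn

class TiedLM(nn.Module):
    """Toy LM whose output projection reuses the embedding weights."""

    def __init__(self, vocab_size: int, dims: int):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, dims)

    def __call__(self, tokens):
        h = self.embedding(tokens)          # (..., dims) token embeddings
        # ... transformer blocks would normally go here ...
        return self.embedding.as_linear(h)  # (..., vocab_size) logits, tied weights

model = TiedLM(vocab_size=100, dims=32)
logits = model(mx.array([[1, 2, 3]]))
print(logits.shape)  # expected: (1, 3, 100)
```
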
176 changes: 142 additions & 34 deletions python/mlx/nn/layers/quantized.py
@@ -1,11 +1,143 @@
# Copyright © 2023 Apple Inc.
# Copyright © 2023-2024 Apple Inc.

import math
from typing import Callable, Optional

import mlx.core as mx
from mlx.nn.layers.base import Module
from mlx.nn.layers.embedding import Embedding
from mlx.nn.layers.linear import Linear
from mlx.utils import tree_flatten, tree_map
from mlx.utils import tree_map_with_path


def quantize(
model: Module,
group_size: int = 64,
bits: int = 4,
class_predicate: Optional[callable] = None,
):
"""Quantize the sub-modules of a module according to a predicate.

By default all :obj:`Linear` and :obj:`Embedding` layers will be
quantized. Note also, the module is updated in-place.

Args:
model (mlx.nn.Module): The model whose leaf modules may be quantized.
group_size (int): The quantization group size (see
:func:`mlx.core.quantize`). Default: ``64``.
bits (int): The number of bits per parameter (see
:func:`mlx.core.quantize`). Default: ``4``.
class_predicate (Optional[Callable]): A callable which receives the
:obj:`Module` path and :obj:`Module` itself and returns ``True`` if
it should be quantized and ``False`` otherwise. If ``None``, then
all linear and embedding layers are quantized. Default: ``None``.
"""
class_predicate = class_predicate or (
lambda _, m: isinstance(m, (Linear, Embedding))
)

def _maybe_quantize(path, m):
if class_predicate(path, m):
if isinstance(m, Linear):
return QuantizedLinear.from_linear(m, group_size, bits)
elif isinstance(m, Embedding):
return QuantizedEmbedding.from_embedding(m, group_size, bits)
else:
raise ValueError(f"Unable to quantize model of type {type(m)}")
else:
return m

leaves = model.leaf_modules()
leaves = tree_map_with_path(_maybe_quantize, leaves, is_leaf=Module.is_module)
model.update_modules(leaves)
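
A minimal usage sketch of the new `nn.quantize` helper on a toy model (the model itself is illustrative, not part of this diff):

```python
import mlx.nn as nn

class TinyLM(nn.Module):
    def __init__(self, vocab_size: int = 1024, dims: int = 64):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, dims)
        self.proj = nn.Linear(dims, vocab_size)

    def __call__(self, tokens):
        return self.proj(self.embedding(tokens))

model = TinyLM()

# Quantize every Linear and Embedding leaf in place (the default predicate).
nn.quantize(model, group_size=64, bits=4)

print(type(model.embedding).__name__, type(model.proj).__name__)
# expected: QuantizedEmbedding QuantizedLinear
```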


class QuantizedEmbedding(Module):
"""The same as :obj:`Embedding` but with a quantized weight matrix.

:obj:`QuantizedEmbedding` also provides a :meth:`from_embedding`
classmethod to convert embedding layers to :obj:`QuantizedEmbedding`
layers.

Args:
num_embeddings (int): How many possible discrete tokens can we embed.
Usually called the vocabulary size.
dims (int): The dimensionality of the embeddings.
group_size (int, optional): The group size to use for the quantized
weight. See :func:`~mlx.core.quantize`. Default: ``64``.
bits (int, optional): The bit width to use for the quantized weight.
See :func:`~mlx.core.quantize`. Default: ``4``.
"""

def __init__(
self,
num_embeddings: int,
dims: int,
group_size: int = 64,
bits: int = 4,
):
super().__init__()

# Quantization config
self.group_size = group_size
self.bits = bits

# Initialize the quantized weight
scale = math.sqrt(1 / dims)
weight = mx.random.normal(shape=(num_embeddings, dims), scale=scale)
self.weight, self.scales, self.biases = mx.quantize(weight, group_size, bits)
self.num_embeddings = num_embeddings
self.dims = dims

# Freeze this model's parameters
self.freeze()

def __call__(self, x):
s = x.shape
x = x.flatten()
out = mx.dequantize(
self["weight"][x],
scales=self["scales"][x],
biases=self["biases"][x],
group_size=self.group_size,
bits=self.bits,
)
return out.reshape(*s, -1)

def as_linear(self, x):
"""
Call the quantized embedding layer as a quantized linear layer.

Use this for example when input embedding and output projection
weights are tied.
"""
return mx.quantized_matmul(
x,
self["weight"],
scales=self["scales"],
biases=self["biases"],
transpose=True,
group_size=self.group_size,
bits=self.bits,
)

def _extra_repr(self):
return (
f"{self.num_embeddings}, {self.dims}, "
f"group_size={self.group_size}, bits={self.bits}"
)

@classmethod
def from_embedding(
cls, embedding_layer: Module, group_size: int = 64, bits: int = 4
):
"""Create a :obj:`QuantizedEmbedding` layer from an :obj:`Embedding` layer."""
embedding_dims, dims = embedding_layer.weight.shape
ql = cls(embedding_dims, dims, group_size, bits)
ql.weight, ql.scales, ql.biases = mx.quantize(
embedding_layer.weight, group_size, bits
)
return ql
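
A short sketch of converting an existing embedding with `from_embedding` and using the quantized layer both ways, as a lookup and as a tied output projection (the shapes below are illustrative):

```python
import mlx.core as mx
import mlx.nn as nn

emb = nn.Embedding(num_embeddings=1024, dims=64)
qemb = nn.QuantizedEmbedding.from_embedding(emb, group_size=64, bits=4)

tokens = mx.array([1, 5, 42])
vectors = qemb(tokens)            # (3, 64): dequantized embedding lookup
logits = qemb.as_linear(vectors)  # (3, 1024): tied output projection
```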


class QuantizedLinear(Module):
@@ -15,23 +147,18 @@ class QuantizedLinear(Module):
parameters are frozen and will not be included in any gradient computation
but this will probably change in the future.

QuantizedLinear also provides two useful classmethods to convert linear
layers to QuantizedLinear layers.

- :meth:`from_linear` returns a QuantizedLinear layer that applies the same
linear transformation up to the quantization error.
- :meth:`quantize_module` swaps all the linear layers of the passed module
with QuantizedLinear ones.
:obj:`QuantizedLinear` also provides a classmethod :meth:`from_linear` to
convert linear layers to :obj:`QuantizedLinear` layers.

Args:
input_dims (int): The dimensionality of the input features
output_dims (int): The dimensionality of the output features
input_dims (int): The dimensionality of the input features.
output_dims (int): The dimensionality of the output features.
bias (bool, optional): If set to ``False`` then the layer will not use
a bias. (default: True).
a bias. Default: ``True``.
group_size (int, optional): The group size to use for the quantized
weight. See :func:`~mlx.core.quantize`. (default: 64)
weight. See :func:`~mlx.core.quantize`. Default: ``64``.
bits (int, optional): The bit width to use for the quantized weight.
See :func:`~mlx.core.quantize`. (default: 4)
See :func:`~mlx.core.quantize`. Default: ``4``.
"""

def __init__(
@@ -94,8 +221,7 @@ def __call__(self, x):

@classmethod
def from_linear(cls, linear_layer: Module, group_size: int = 64, bits: int = 4):
"""Create a QuantizedLinear layer from the parameters of a provided
linear layer."""
"""Create a :obj:`QuantizedLinear` layer from a :obj:`Linear` layer."""
output_dims, input_dims = linear_layer.weight.shape
ql = cls(input_dims, output_dims, False, group_size, bits)
ql.weight, ql.scales, ql.biases = mx.quantize(
@@ -105,21 +231,3 @@ def from_linear(cls, linear_layer: Module, group_size: int = 64, bits: int = 4):
ql.bias = linear_layer.bias

return ql

@classmethod
def quantize_module(
cls,
model: Module,
group_size: int = 64,
bits: int = 4,
linear_class_predicate=lambda m: isinstance(m, Linear),
):
def _quantize_if_linear(m):
if linear_class_predicate(m):
return cls.from_linear(m, group_size, bits)
else:
return m

leaves = model.leaf_modules()
leaves = tree_map(_quantize_if_linear, leaves, is_leaf=Module.is_module)
model.update_modules(leaves)
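
The removed `QuantizedLinear.quantize_module` classmethod is superseded by the module-level `nn.quantize` added above. A hedged migration sketch (the two-layer model is illustrative); note the predicate now receives the module path as well as the module:

```python
import mlx.nn as nn

model = nn.Sequential(nn.Linear(64, 64), nn.Linear(64, 64))

# Before this PR (classmethod removed in this diff):
#     nn.QuantizedLinear.quantize_module(model, group_size=64, bits=4)

# After this PR: one helper covers Linear and Embedding layers, and the
# optional predicate takes (path, module) instead of just the module.
nn.quantize(
    model,
    group_size=64,
    bits=4,
    class_predicate=lambda path, m: isinstance(m, nn.Linear),
)
```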