diff --git a/benchmarks/python/rope_bench.py b/benchmarks/python/rope_bench.py
new file mode 100644
index 000000000..62f01648e
--- /dev/null
+++ b/benchmarks/python/rope_bench.py
@@ -0,0 +1,35 @@
+# Copyright © 2023-2024 Apple Inc.
+
+import mlx.core as mx
+import mlx.nn as nn
+from time_utils import time_fn
+
+
+def time_rope():
+    rope = nn.RoPE(4096)
+
+    # vec
+    x = mx.random.uniform(shape=(1, 4096)).astype(mx.float16)
+    mx.eval(x)
+
+    def rope_vec(x):
+        for _ in range(32):
+            x = rope(x)
+        return x
+
+    time_fn(rope_vec, x)
+
+    # matrix
+    x = mx.random.uniform(shape=(1024, 4096)).astype(mx.float16)
+    mx.eval(x)
+
+    def rope_mat(x):
+        for _ in range(32):
+            x = rope(x)
+        return x
+
+    time_fn(rope_mat, x)
+
+
+if __name__ == "__main__":
+    time_rope()
diff --git a/mlx/CMakeLists.txt b/mlx/CMakeLists.txt
index cb65e518b..d2f021af5 100644
--- a/mlx/CMakeLists.txt
+++ b/mlx/CMakeLists.txt
@@ -3,9 +3,10 @@ target_sources(
   PRIVATE
   ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/array.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/compile.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/dtype.cpp
-  ${CMAKE_CURRENT_SOURCE_DIR}/compile.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/fast.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/ops.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/graph_utils.cpp
diff --git a/mlx/backend/common/CMakeLists.txt b/mlx/backend/common/CMakeLists.txt
index 0263dff9b..b25001f2c 100644
--- a/mlx/backend/common/CMakeLists.txt
+++ b/mlx/backend/common/CMakeLists.txt
@@ -11,6 +11,7 @@ target_sources(
   ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/quantized.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/rope.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/scan.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/sort.cpp
diff --git a/mlx/backend/common/rope.cpp b/mlx/backend/common/rope.cpp
new file mode 100644
index 000000000..c0c2bba8e
--- /dev/null
+++ b/mlx/backend/common/rope.cpp
@@ -0,0 +1,14 @@
+// Copyright © 2023-2024 Apple Inc.
+
+#include "mlx/fast.h"
+#include "mlx/primitives.h"
+
+namespace mlx::core::fast {
+
+void RoPE::eval_cpu(
+    const std::vector<array>& inputs,
+    std::vector<array>& outputs) {
+  throw std::runtime_error("NYI");
+}
+
+} // namespace mlx::core::fast
diff --git a/mlx/backend/metal/CMakeLists.txt b/mlx/backend/metal/CMakeLists.txt
index 93a25434f..063c283fe 100644
--- a/mlx/backend/metal/CMakeLists.txt
+++ b/mlx/backend/metal/CMakeLists.txt
@@ -32,6 +32,7 @@ target_sources(
   ${CMAKE_CURRENT_SOURCE_DIR}/metal.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/quantized.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/rope.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/scan.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/sort.cpp
diff --git a/mlx/backend/metal/kernels/CMakeLists.txt b/mlx/backend/metal/kernels/CMakeLists.txt
index 12e09deaa..afd2fbc8a 100644
--- a/mlx/backend/metal/kernels/CMakeLists.txt
+++ b/mlx/backend/metal/kernels/CMakeLists.txt
@@ -23,6 +23,7 @@ set(
   "quantized"
   "random"
   "reduce"
+  "rope"
   "scan"
   "softmax"
   "sort"
diff --git a/mlx/backend/metal/kernels/rope.metal b/mlx/backend/metal/kernels/rope.metal
new file mode 100644
index 000000000..484697b6d
--- /dev/null
+++ b/mlx/backend/metal/kernels/rope.metal
@@ -0,0 +1,68 @@
+// Copyright © 2023-2024 Apple Inc.
+
+#include <metal_math>
+
+#include "mlx/backend/metal/kernels/bf16.h"
+#include "mlx/backend/metal/kernels/utils.h"
+
+template <typename T, bool traditional>
+[[kernel]] void rope(
+    const device T *in [[buffer(0)]],
+    device T * out [[buffer(1)]],
+    constant const size_t strides[3],
+    constant const int& offset,
+    constant const float& base,
+    constant const float& scale,
+    uint3 pos [[thread_position_in_grid]],
+    uint3 grid [[threads_per_grid]]) {
+  // Compute the input and output indices
+  uint in_index_1, in_index_2;
+  uint out_index_1, out_index_2;
+  if (traditional) {
+    out_index_1 = 2 * (pos.x + grid.x * (pos.y + grid.y * pos.z));
+    out_index_2 = out_index_1 + 1;
+    in_index_1 = 2 * pos.x * strides[2] + pos.y * strides[1] + pos.z * strides[0];
+    in_index_2 = in_index_1 + strides[2];
+  } else {
+    out_index_1 = pos.x + 2*(grid.x * (pos.y + grid.y * pos.z));
+    out_index_2 = out_index_1 + grid.x;
+    in_index_1 = pos.x * strides[2] + pos.y * strides[1] + pos.z * strides[0];
+    in_index_2 = in_index_1 + grid.x * strides[2];
+  }
+
+  // Figure out L and d.
+  float L = scale * static_cast<float>(pos.y + offset);
+  float d = static_cast<float>(pos.x) / static_cast<float>(grid.x);
+
+  // Compute costheta, sintheta
+  float theta = L * metal::exp2(-d * base);
+  float costheta = metal::fast::cos(theta);
+  float sintheta = metal::fast::sin(theta);
+
+  // Read the input pair, rotate it, and write the output
+  float x1 = static_cast<float>(in[in_index_1]);
+  float x2 = static_cast<float>(in[in_index_2]);
+  float rx1 = x1 * costheta - x2 * sintheta;
+  float rx2 = x1 * sintheta + x2 * costheta;
+  out[out_index_1] = static_cast<T>(rx1);
+  out[out_index_2] = static_cast<T>(rx2);
+}
+
+#define instantiate_rope(name, type, traditional) \
+  template [[host_name("rope_" #name)]] \
+  [[kernel]] void rope<type, traditional>( \
+      const device type* in [[buffer(0)]], \
+      device type* out [[buffer(1)]], \
+      constant const size_t strides[3], \
+      constant const int& offset, \
+      constant const float& base, \
+      constant const float& scale, \
+      uint3 pos [[thread_position_in_grid]], \
+      uint3 grid [[threads_per_grid]]);
+
+instantiate_rope(traditional_float16, half, true)
+instantiate_rope(traditional_bfloat16, bfloat16_t, true)
+instantiate_rope(traditional_float32, float, true)
+instantiate_rope(float16, half, false)
+instantiate_rope(bfloat16, bfloat16_t, false)
+instantiate_rope(float32, float, false)
diff --git a/mlx/backend/metal/rope.cpp b/mlx/backend/metal/rope.cpp
new file mode 100644
index 000000000..29295f3ac
--- /dev/null
+++ b/mlx/backend/metal/rope.cpp
@@ -0,0 +1,55 @@
+// Copyright © 2023-2024 Apple Inc.
+
+#include "mlx/backend/metal/utils.h"
+#include "mlx/fast.h"
+#include "mlx/primitives.h"
+
+namespace mlx::core::fast {
+
+void RoPE::eval_gpu(
+    const std::vector<array>& inputs,
+    std::vector<array>& outputs) {
+  assert(inputs.size() == 1);
+  assert(outputs.size() == 1);
+  auto& in = inputs[0];
+  auto& out = outputs[0];
+
+  if (in.ndim() != 3) {
+    throw std::runtime_error(
+        "[RoPE] Only 3 dimensions are supported (batch x sequence x dims)");
+  }
+  if (dims_ != in.shape(-1)) {
+    throw std::runtime_error("[RoPE] Partial RoPE application not supported");
+  }
+  if (in.flags().row_contiguous && in.is_donatable()) {
+    out.move_shared_buffer(in);
+  } else {
+    out.set_data(allocator::malloc_or_wait(out.nbytes()));
+  }
+
+  auto& s = out.primitive().stream();
+  auto& d = metal::device(s.device);
+  std::ostringstream kname;
+  kname << "rope_" << (traditional_ ? "traditional_" : "") << type_to_name(in);
+  auto kernel = d.get_kernel(kname.str());
+  auto compute_encoder = d.get_command_encoder(s.index);
+
+  bool donated = in.data_shared_ptr() == nullptr;
+  float base = std::log2(base_);
+  compute_encoder->setComputePipelineState(kernel);
+  set_array_buffer(compute_encoder, donated ? out : in, 0);
+  set_array_buffer(compute_encoder, out, 1);
+  compute_encoder->setBytes(in.strides().data(), 3 * sizeof(size_t), 2);
+  compute_encoder->setBytes(&offset_, sizeof(int), 3);
+  compute_encoder->setBytes(&base, sizeof(float), 4);
+  compute_encoder->setBytes(&scale_, sizeof(float), 5);
+
+  int dim0 = in.shape(2) / 2;
+  int dim1 = in.shape(1);
+  int dim2 = in.shape(0);
+  auto group_dims = get_block_dims(dim0, dim1, dim2);
+  auto grid_dims = MTL::Size(dim0, dim1, dim2);
+  compute_encoder->dispatchThreads(grid_dims, group_dims);
+}
+
+} // namespace mlx::core::fast
diff --git a/mlx/backend/no_metal/primitives.cpp b/mlx/backend/no_metal/primitives.cpp
index dd4edc2ed..bd4026e2c 100644
--- a/mlx/backend/no_metal/primitives.cpp
+++ b/mlx/backend/no_metal/primitives.cpp
@@ -1,6 +1,7 @@
 // Copyright © 2023-2024 Apple Inc.
 
 #include "mlx/primitives.h"
+#include "mlx/fast.h"
 
 #define NO_GPU_MULTI(func) \
   void func::eval_gpu(     \
@@ -95,4 +96,8 @@ NO_GPU(Tan)
 NO_GPU(Tanh)
 NO_GPU(Transpose)
 
+namespace fast {
+NO_GPU_MULTI(RoPE)
+} // namespace fast
+
 } // namespace mlx::core
diff --git a/mlx/fast.cpp b/mlx/fast.cpp
new file mode 100644
index 000000000..96d4f03ce
--- /dev/null
+++ b/mlx/fast.cpp
@@ -0,0 +1,128 @@
+// Copyright © 2023-2024 Apple Inc.
+
+#include "mlx/fast.h"
+#include "mlx/transforms.h"
+
+namespace mlx::core::fast {
+
+std::vector<array> Custom::vjp(
+    const std::vector<array>& primals,
+    const std::vector<array>& cotangents,
+    const std::vector<int>& argnums,
+    const std::vector<array>& outputs) {
+  auto [_, vjps] = mlx::core::vjp(fallback_, primals, cotangents);
+  std::vector<array> vjp_outs;
+  for (int i = 0, j = 0; i < vjps.size(); ++i) {
+    if (i < argnums.size() && i == argnums[j]) {
+      vjp_outs.push_back(vjps[i]);
+      j++;
+    }
+  }
+  return vjp_outs;
+}
+
+std::vector<array> Custom::jvp(
+    const std::vector<array>& primals,
+    const std::vector<array>& tangents,
+    const std::vector<int>& argnums) {
+  auto [_, jvps] = mlx::core::jvp(fallback_, primals, tangents);
+  std::vector<array> jvp_outs;
+  for (int i = 0, j = 0; i < jvps.size(); ++i) {
+    if (i < argnums.size() && i == argnums[j]) {
+      jvp_outs.push_back(jvps[i]);
+      j++;
+    }
+  }
+  return jvp_outs;
+}
+
+std::pair<std::vector<array>, std::vector<int>> Custom::vmap(
+    const std::vector<array>& inputs,
+    const std::vector<int>& axes) {
+  auto outputs = mlx::core::vmap(fallback_, axes)(inputs);
+  auto out_axes = std::vector<int>(outputs.size(), 0);
+  return {outputs, out_axes};
+}
+
+array rope(
+    const array& x,
+    int dims,
+    bool traditional,
+    float base,
+    float scale,
+    int offset,
+    StreamOrDevice s /* = {} */) {
+  if (x.ndim() != 3) {
+    std::ostringstream msg;
+    msg << "[rope] Input must have 3 dimensions but got input with " << x.ndim()
+        << " dimensions.";
+    throw std::invalid_argument(msg.str());
+  }
+  if (traditional && x.shape(-1) != dims) {
+    throw std::invalid_argument(
+        "[rope] Does not support partial traditional application.");
+  }
+
+  auto fallback = [dims, traditional, base, scale, offset, s](
+                      const std::vector<array>& inputs) {
+    auto& x = inputs[0];
+    auto t = x.dtype();
+    auto N = x.shape(1) + offset;
+    // Compute sines and cosines
+    auto half_dims = dims / 2;
+    auto positions = multiply(arange(offset, N, t, s), array(scale, t), s);
+    auto freqs = negative(arange(0, half_dims, t, s), s);
+    freqs =
+        exp(multiply(freqs, array(std::log(base) / half_dims, t), s), s);
+    auto theta =
+        multiply(expand_dims(positions, 1, s), expand_dims(freqs, 0, s), s);
+    auto coss = cos(theta, s);
+    auto sins = sin(theta, s);
+
+    if (traditional) {
+      auto x1 = slice(x, {0, 0, 0}, x.shape(), {1, 1, 2}, s);
+      auto x2 = slice(x, {0, 0, 1}, x.shape(), {1, 1, 2}, s);
+      std::vector<array> outs;
+      outs.push_back(subtract(multiply(x1, coss, s), multiply(x2, sins, s), s));
+      outs.push_back(add(multiply(x1, sins, s), multiply(x2, coss, s), s));
+      for (auto& o : outs) {
+        o = expand_dims(o, 3, s);
+      }
+      return std::vector<array>{reshape(concatenate(outs, 3, s), x.shape(), s)};
+    } else {
+      auto out_s = x.shape();
+      out_s.back() = half_dims;
+      auto x1 = slice(x, {0, 0, 0}, out_s, s);
+      out_s.back() = dims;
+      auto x2 = slice(x, {0, 0, half_dims}, out_s, s);
+
+      std::vector<array> outs;
+      outs.push_back(subtract(multiply(x1, coss, s), multiply(x2, sins, s), s));
+      outs.push_back(add(multiply(x1, sins, s), multiply(x2, coss, s), s));
+      if (dims < x.shape(-1)) {
+        outs.push_back(slice(x, {0, 0, dims}, x.shape(), s));
+      }
+      return std::vector<array>{concatenate(outs, 2, s)};
+    }
+  };
+  // TODO change to condition for using custom prim
+  auto stream = to_stream(s);
+  if (stream.device == Device::gpu && x.shape(-1) == dims) {
+    return array(
+        x.shape(),
+        x.dtype(),
+        std::make_unique<RoPE>(
+            stream, fallback, dims, traditional, base, scale, offset),
+        {x});
+  }
+  return fallback({x})[0];
+}
+
+bool RoPE::is_equivalent(const Primitive& other) const {
+  const RoPE& a_other = static_cast<const RoPE&>(other);
+  return (
+      dims_ == a_other.dims_ && base_ == a_other.base_ &&
+      scale_ == a_other.scale_ && traditional_ == a_other.traditional_ &&
+      offset_ == a_other.offset_);
+}
+
+} // namespace mlx::core::fast
diff --git a/mlx/fast.h b/mlx/fast.h
new file mode 100644
index 000000000..5deac0cdb
--- /dev/null
+++ b/mlx/fast.h
@@ -0,0 +1,82 @@
+// Copyright © 2023-2024 Apple Inc.
+
+#pragma once
+
+#include "mlx/ops.h"
+#include "mlx/primitives.h"
+
+namespace mlx::core::fast {
+
+// Custom primitive accepts a fallback function which it uses for
+// transformations.
+// Transformations are virtual so that derived classes may override the
+// default behavior.
+class Custom : public Primitive {
+ public:
+  explicit Custom(
+      Stream stream,
+      std::function<std::vector<array>(std::vector<array>)> fallback)
+      : Primitive(stream), fallback_(fallback){};
+
+  virtual std::pair<std::vector<array>, std::vector<int>> vmap(
+      const std::vector<array>& inputs,
+      const std::vector<int>& axes) override;
+
+  virtual std::vector<array> jvp(
+      const std::vector<array>& primals,
+      const std::vector<array>& tangents,
+      const std::vector<int>& argnums) override;
+
+  virtual std::vector<array> vjp(
+      const std::vector<array>& primals,
+      const std::vector<array>& cotangents,
+      const std::vector<int>& argnums,
+      const std::vector<array>& outputs) override;
+
+ private:
+  std::function<std::vector<array>(std::vector<array>)> fallback_;
+};
+
+array rope(
+    const array& x,
+    int dims,
+    bool traditional,
+    float base,
+    float scale,
+    int offset,
+    StreamOrDevice s /* = {} */);
+
+class RoPE : public Custom {
+ public:
+  RoPE(
+      Stream stream,
+      std::function<std::vector<array>(std::vector<array>)> fallback,
+      int dims,
+      bool traditional,
+      float base,
+      float scale,
+      int offset)
+      : Custom(stream, fallback),
+        dims_(dims),
+        traditional_(traditional),
+        base_(base),
+        scale_(scale),
+        offset_(offset){};
+
+  void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
+      override;
+  void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
+      override;
+
+  DEFINE_PRINT(RoPE)
+  bool is_equivalent(const Primitive& other) const override;
+
+ private:
+  std::function<std::vector<array>(std::vector<array>)> fallback_;
+  int dims_;
+  bool traditional_;
+  float base_;
+  float scale_;
+  int offset_;
+};
+
+} // namespace mlx::core::fast
diff --git a/mlx/mlx.h b/mlx/mlx.h
index 7b33faba7..1963a4c50 100644
--- a/mlx/mlx.h
+++ b/mlx/mlx.h
@@ -6,6 +6,7 @@
 #include "mlx/backend/metal/metal.h"
 #include "mlx/compile.h"
 #include "mlx/device.h"
+#include "mlx/fast.h"
 #include "mlx/fft.h"
 #include "mlx/io.h"
 #include "mlx/linalg.h"
diff --git a/python/mlx/nn/layers/positional_encoding.py b/python/mlx/nn/layers/positional_encoding.py
index a8024f0a4..f0bb92863 100644
--- a/python/mlx/nn/layers/positional_encoding.py
+++ b/python/mlx/nn/layers/positional_encoding.py
@@ -1,4 +1,4 @@
-# Copyright © 2023 Apple Inc.
+# Copyright © 2023-2024 Apple Inc.
 
 import math
 from typing import Optional
@@ -20,20 +20,13 @@ class RoPE(Module):
     Args:
         dims (int): The feature dimensions to be rotated. If the input feature
             is larger than dims then the rest is left unchanged.
-        traditional (bool, optional): If set to True choose the traditional
+        traditional (bool, optional): If set to ``True`` choose the traditional
            implementation which is slightly less efficient. Default: ``False``.
         base (float, optional): The base used to compute angular frequency for
            each dimension in the positional encodings. Default: ``10000``.
         scale (float, optional): The scale used to scale the positions. Default:
            ``1.0``.
-
-    Attributes:
-        _cos_sin_theta_key (tuple): Cached key for the precomputed cosine and sine values.
-        _cos_sin_theta_value (tuple): Cached cosine and sine values.
""" - _cos_sin_theta_key = None - _cos_sin_theta_value = None - def __init__( self, dims: int, @@ -50,69 +43,18 @@ def __init__( def _extra_repr(self): return f"{self.dims}, traditional={self.traditional}" - def _compute_rope(self, costheta, sintheta, x): - x1 = x[..., : self.dims // 2] - x2 = x[..., self.dims // 2 : self.dims] - rx1 = x1 * costheta - x2 * sintheta - rx2 = x1 * sintheta + x2 * costheta - - if self.dims < x.shape[-1]: - rx = mx.concatenate([rx1, rx2, x[..., self.dims :]], axis=-1) - else: - rx = mx.concatenate([rx1, rx2], axis=-1) - - return rx - - def _compute_traditional_rope(self, costheta, sintheta, x): - x1 = x[..., ::2] - x2 = x[..., 1::2] - rx1 = x1 * costheta - x2 * sintheta - rx2 = x1 * sintheta + x2 * costheta - - if self.dims < x.shape[-1]: - raise NotImplementedError( - "RoPE doesn't implement partial traditional application" - ) - - rx = mx.concatenate([rx1[..., None], rx2[..., None]], axis=-1) - - return rx - def __call__(self, x, offset: int = 0): shape = x.shape x = mx.reshape(x, (-1, shape[-2], shape[-1])) - N = x.shape[1] + offset - costheta, sintheta = RoPE.create_cos_sin_theta( - N, self.dims, offset=offset, base=self.base, scale=self.scale, dtype=x.dtype - ) - - rope = ( - self._compute_traditional_rope if self.traditional else self._compute_rope + x = mx.fast.rope( + x, + self.dims, + traditional=self.traditional, + base=self.base, + scale=self.scale, + offset=offset, ) - rx = rope(costheta, sintheta, x) - - return mx.reshape(rx, shape) - - @classmethod - def create_cos_sin_theta( - cls, - N: int, - D: int, - offset: int = 0, - base: float = 10000, - scale: float = 1.0, - dtype=mx.float32, - ): - if (N, D, offset, base, scale, dtype) != cls._cos_sin_theta_key: - half_D = D // 2 - positions = mx.arange(offset, N, dtype=dtype) * scale - freqs = mx.exp( - -mx.arange(0.0, half_D, dtype=dtype) * (math.log(base) / half_D) - ) - theta = mx.reshape(positions, (-1, 1)) * mx.reshape(freqs, (1, -1)) - cls._cos_sin_theta_key = (N, D, offset, base, scale, dtype) - cls._cos_sin_theta_value = (mx.cos(theta), mx.sin(theta)) - return cls._cos_sin_theta_value + return mx.reshape(x, shape) class SinusoidalPositionalEncoding(Module): diff --git a/python/src/CMakeLists.txt b/python/src/CMakeLists.txt index 1ba037fdc..7dd862033 100644 --- a/python/src/CMakeLists.txt +++ b/python/src/CMakeLists.txt @@ -3,6 +3,7 @@ pybind11_add_module( ${CMAKE_CURRENT_SOURCE_DIR}/mlx.cpp ${CMAKE_CURRENT_SOURCE_DIR}/array.cpp ${CMAKE_CURRENT_SOURCE_DIR}/device.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/fast.cpp ${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp ${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp ${CMAKE_CURRENT_SOURCE_DIR}/load.cpp diff --git a/python/src/fast.cpp b/python/src/fast.cpp new file mode 100644 index 000000000..115ea37ec --- /dev/null +++ b/python/src/fast.cpp @@ -0,0 +1,59 @@ +// Copyright © 2023-2024 Apple Inc. 
+
+#include <pybind11/pybind11.h>
+#include <pybind11/stl.h>
+
+#include "mlx/fast.h"
+#include "mlx/ops.h"
+#include "python/src/utils.h"
+
+namespace py = pybind11;
+using namespace py::literals;
+using namespace mlx::core;
+
+void init_extensions(py::module_& parent_module) {
+  py::options options;
+  options.disable_function_signatures();
+
+  auto m =
+      parent_module.def_submodule("fast", "mlx.core.fast: fast operations");
+
+  m.def(
+      "rope",
+      [](const array& a,
+         int dims,
+         bool traditional,
+         float base,
+         float scale,
+         int offset,
+         const StreamOrDevice& s /* = {} */) {
+        return fast::rope(a, dims, traditional, base, scale, offset, s);
+      },
+      "a"_a,
+      "dims"_a,
+      py::kw_only(),
+      "traditional"_a,
+      "base"_a,
+      "scale"_a,
+      "offset"_a,
+      "stream"_a = none,
+      R"pbdoc(
+        rope(a: array, dims: int, *, traditional: bool, base: float, scale: float, offset: int, stream: Union[None, Stream, Device] = None) -> array
+
+        Apply rotary positional encoding to the input.
+
+        Args:
+            a (array): Input array.
+            dims (int): The feature dimensions to be rotated. If the input feature
+                is larger than dims then the rest is left unchanged.
+            traditional (bool): If set to ``True`` choose the traditional
+                implementation which rotates consecutive dimensions.
+            base (float): The base used to compute angular frequency for
+                each dimension in the positional encodings.
+            scale (float): The scale used to scale the positions.
+            offset (int): The position offset to start at.
+
+        Returns:
+            array: The output array.
+      )pbdoc");
+}
diff --git a/python/src/mlx.cpp b/python/src/mlx.cpp
index 81626e565..ee0f469f9 100644
--- a/python/src/mlx.cpp
+++ b/python/src/mlx.cpp
@@ -17,6 +17,7 @@ void init_random(py::module_&);
 void init_fft(py::module_&);
 void init_linalg(py::module_&);
 void init_constants(py::module_&);
+void init_extensions(py::module_&);
 
 PYBIND11_MODULE(core, m) {
   m.doc() = "mlx: A framework for machine learning on Apple silicon.";
@@ -33,5 +34,6 @@ PYBIND11_MODULE(core, m) {
   init_fft(m);
   init_linalg(m);
   init_constants(m);
+  init_extensions(m);
   m.attr("__version__") = TOSTRING(_VERSION_);
 }
diff --git a/python/src/random.cpp b/python/src/random.cpp
index bbcb7a2c8..442d81fee 100644
--- a/python/src/random.cpp
+++ b/python/src/random.cpp
@@ -133,7 +133,7 @@ void init_random(py::module_& parent_module) {
       low (scalar or array, optional): Lower bound of the distribution. Default is ``0``.
       high (scalar or array, optional): Upper bound of the distribution. Default is ``1``.
       shape (list(int), optional): Shape of the output. Default is ``()``.
-      key (array, optional): A PRNG key. Default: None.
+      key (array, optional): A PRNG key. Default: ``None``.
       dtype (Dtype, optional): Type of the output. Default is ``float32``.
 
     Returns:
diff --git a/python/tests/test_fast.py b/python/tests/test_fast.py
new file mode 100644
index 000000000..1cb4ddcca
--- /dev/null
+++ b/python/tests/test_fast.py
@@ -0,0 +1,158 @@
+# Copyright © 2023-2024 Apple Inc.
+
+import math
+import unittest
+
+import mlx.core as mx
+import mlx_tests
+
+
+def rope_orig(x, dims, traditional, base, scale, offset):
+    N = x.shape[1] + offset
+    dtype = x.dtype
+    half_D = dims // 2
+    positions = mx.arange(offset, N, dtype=dtype) * scale
+    freqs = mx.exp(-mx.arange(0.0, half_D, dtype=dtype) * (math.log(base) / half_D))
+    theta = mx.reshape(positions, (-1, 1)) * mx.reshape(freqs, (1, -1))
+    costheta, sintheta = mx.cos(theta), mx.sin(theta)
+    if traditional:
+        x1 = x[..., ::2]
+        x2 = x[..., 1::2]
+        rx1 = x1 * costheta - x2 * sintheta
+        rx2 = x1 * sintheta + x2 * costheta
+        rx = mx.concatenate([rx1[..., None], rx2[..., None]], axis=-1)
+        return mx.reshape(rx, x.shape)
+    else:
+        x1 = x[..., : dims // 2]
+        x2 = x[..., dims // 2 : dims]
+        rx1 = x1 * costheta - x2 * sintheta
+        rx2 = x1 * sintheta + x2 * costheta
+        if dims < x.shape[-1]:
+            rx = mx.concatenate([rx1, rx2, x[..., dims:]], axis=-1)
+        else:
+            rx = mx.concatenate([rx1, rx2], axis=-1)
+        return rx
+
+
+class TestFast(mlx_tests.MLXTestCase):
+    def test_rope(self):
+        T = 4
+
+        # Defaults: dims, dtype, base, scale, offset, traditional
+        defaults = (8, mx.float32, 10000.0, 1.0, 0, False)
+
+        # Per dtype absolute tolerance
+        tolerances = {mx.float32: 1e-6, mx.float16: 1e-3, mx.bfloat16: 1e-2}
+
+        # Test cases:
+        dtypes = [mx.float32, mx.float16, mx.bfloat16]
+        bases = [10000.0, 1000000.0]
+        scales = [1.0, 2.0]
+        offsets = [0, 3]
+        traditional = [True, False]
+
+        for traditional in [True, False]:
+            dims, dtype, _, scale, offset, _ = defaults
+            for base in bases:
+                x = mx.random.uniform(shape=(2, T, dims)).astype(dtype)
+                rx = rope_orig(x, dims, traditional, base, scale, offset)
+                rx_fast = mx.fast.rope(
+                    x,
+                    dims,
+                    traditional=traditional,
+                    base=base,
+                    scale=scale,
+                    offset=offset,
+                )
+                self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])
+
+            dims, _, base, scale, offset, _ = defaults
+            for dtype in dtypes:
+                x = mx.random.uniform(shape=(2, T, dims)).astype(dtype)
+                ry = rope_orig(
+                    x.astype(mx.float32), dims, traditional, base, scale, offset
+                )
+                rx = rope_orig(x, dims, traditional, base, scale, offset)
+                rx_fast = mx.fast.rope(
+                    x,
+                    dims,
+                    traditional=traditional,
+                    base=base,
+                    scale=scale,
+                    offset=offset,
+                )
+                if dtype != mx.float32:
+                    self.assertLessEqual(
+                        mx.abs(ry - rx_fast).max(), mx.abs(ry - rx).max()
+                    )
+                self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])
+
+            dims, dtype, base, scale, _, _ = defaults
+            for offset in offsets:
+                x = mx.random.uniform(shape=(2, T, dims)).astype(dtype)
+                rx = rope_orig(x, dims, traditional, base, scale, offset)
+                rx_fast = mx.fast.rope(
+                    x,
+                    dims,
+                    traditional=traditional,
+                    base=base,
+                    scale=scale,
+                    offset=offset,
+                )
+                self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])
+
+            dims, dtype, base, _, offset, _ = defaults
+            for scale in scales:
+                x = mx.random.uniform(shape=(2, T, dims)).astype(dtype)
+                rx = rope_orig(x, dims, traditional, base, scale, offset)
+                rx_fast = mx.fast.rope(
+                    x,
+                    dims,
+                    traditional=traditional,
+                    base=base,
+                    scale=scale,
+                    offset=offset,
+                )
+                self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])
+
+    def test_fast_transforms(self):
+        x = mx.random.uniform(shape=(2, 2, 8))
+
+        defaults = (8, False, 10000.0, 1.0, 0)
+        dims, traditional, base, scale, offset = defaults
+
+        # VJP
+        _, vjp_out = mx.vjp(lambda x: rope_orig(x, *defaults), (x,), (mx.ones_like(x),))
+        _, vjp_fast_out = mx.vjp(
+            lambda x: mx.fast.rope(
+                x, dims, traditional=traditional, base=base, scale=scale, offset=offset
+            ),
+            (x,),
+            (mx.ones_like(x),),
+        )
+        self.assertTrue(mx.allclose(vjp_out[0], vjp_fast_out[0]))
+
+        # JVP
+        _, jvp_out = mx.jvp(lambda x: rope_orig(x, *defaults), (x,), (mx.ones_like(x),))
+        _, jvp_fast_out = mx.jvp(
+            lambda x: mx.fast.rope(
+                x, dims, traditional=traditional, base=base, scale=scale, offset=offset
+            ),
+            (x,),
+            (mx.ones_like(x),),
+        )
+        self.assertTrue(mx.allclose(jvp_out[0], jvp_fast_out[0]))
+
+        # VMAP
+        x = mx.random.uniform(shape=(2, 2, 2, 8))
+        vmap_out = mx.vmap(lambda x: rope_orig(x, *defaults))(x)
+        vmap_fast_out = mx.vmap(
+            lambda x: mx.fast.rope(
+                x, dims, traditional=traditional, base=base, scale=scale, offset=offset
+            )
+        )(x)
+        self.assertTrue(mx.allclose(vmap_out, vmap_fast_out))
+
+
+if __name__ == "__main__":
+    unittest.main()
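Not part of the diff: a minimal usage sketch of the op introduced above, mirroring what the updated nn.RoPE layer and the tests do. The input shape here is an illustrative assumption; on non-GPU streams the op takes the reference fallback path, so both calls below should agree either way.

# Minimal usage sketch (assumes the diff above is applied and MLX rebuilt).
import mlx.core as mx
import mlx.nn as nn

# (batch, sequence, feature) input; the fast path expects exactly 3 dimensions.
x = mx.random.uniform(shape=(2, 8, 64)).astype(mx.float16)

# Call the new fused primitive directly; arguments after `dims` are keyword-only.
y_fast = mx.fast.rope(x, 64, traditional=False, base=10000.0, scale=1.0, offset=0)

# The nn.RoPE layer now routes through mx.fast.rope internally, so both paths match.
y_layer = nn.RoPE(64)(x)
assert mx.allclose(y_fast, y_layer).item()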