-
Notifications
You must be signed in to change notification settings - Fork 883
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Custom primitive + RoPE fat op (#676)
* extensions start * rope custom op * fix build * docs + rope benchmark * fix test * Add a Metal kernel for RoPE * Fix position of traditional * transform tests * Move rope computation to float and fix tests * Fix the test and a typo * change to fast * fix no metal build --------- Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
- Loading branch information
1 parent
523b2d1
commit ba06521
Showing
18 changed files
with
624 additions
and
70 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
# Copyright © 2023-2024 Apple Inc. | ||
|
||
import mlx.core as mx | ||
import mlx.nn as nn | ||
from time_utils import time_fn | ||
|
||
|
||
def time_rope():
    """Time repeated ``nn.RoPE`` applications on a vector and on a matrix.

    Two float16 inputs are benchmarked: one of shape ``(1, 4096)`` and one of
    shape ``(1024, 4096)``. Each timed function chains 32 applications of the
    same ``nn.RoPE(4096)`` layer and is handed to ``time_fn``.
    """
    layer = nn.RoPE(4096)
    num_applications = 32

    # Single-row input.
    vec = mx.random.uniform(shape=(1, 4096)).astype(mx.float16)
    mx.eval(vec)

    # NOTE(review): the nested function names are kept as-is in case
    # time_fn reports them in its output -- confirm against time_utils.
    def rope_vec(x):
        out = x
        for _ in range(num_applications):
            out = layer(out)
        return out

    time_fn(rope_vec, vec)

    # 1024-row input.
    mat = mx.random.uniform(shape=(1024, 4096)).astype(mx.float16)
    mx.eval(mat)

    def rope_mat(x):
        out = x
        for _ in range(num_applications):
            out = layer(out)
        return out

    time_fn(rope_mat, mat)


if __name__ == "__main__":
    time_rope()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
// Copyright © 2023-2024 Apple Inc. | ||
|
||
#include "mlx/fast.h" | ||
#include "mlx/primitives.h" | ||
|
||
namespace mlx::core::fast { | ||
|
||
// CPU evaluation of the RoPE primitive is not yet implemented: every call
// throws std::runtime_error. The arguments are intentionally unused.
void RoPE::eval_cpu(
    const std::vector<array>& /* inputs */,
    std::vector<array>& /* outputs */) {
  throw std::runtime_error("NYI");
}
|
||
} // namespace mlx::core::fast |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -23,6 +23,7 @@ set( | |
"quantized" | ||
"random" | ||
"reduce" | ||
"rope" | ||
"scan" | ||
"softmax" | ||
"sort" | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
// Copyright © 2023-2024 Apple Inc. | ||
|
||
#include <metal_math> | ||
|
||
#include "mlx/backend/metal/kernels/bf16.h" | ||
#include "mlx/backend/metal/kernels/utils.h" | ||
|
||
// Rotary position embedding kernel.
//
// Each thread rotates one pair of elements (x1, x2) by an angle theta:
//   out1 = x1 * cos(theta) - x2 * sin(theta)
//   out2 = x1 * sin(theta) + x2 * cos(theta)
//
// Grid layout: pos.x indexes the pair within a row (grid.x pairs per row),
// pos.y and pos.z index the two outer dimensions. The input is addressed
// through `strides` (strides[0] for pos.z, strides[1] for pos.y, strides[2]
// for the innermost dimension); the output is written densely.
//
// When `traditional` is true the pair is two adjacent elements
// (2 * pos.x, 2 * pos.x + 1); otherwise the pair is (pos.x, pos.x + grid.x),
// i.e. the two halves of the row.
//
// theta = scale * (pos.y + offset) * exp2(-(pos.x / grid.x) * base), so `base`
// is used as a base-2 exponent factor.
template <typename T, bool traditional>
[[kernel]] void rope(
    const device T *in [[buffer(0)]],
    device T * out [[buffer(1)]],
    constant const size_t strides[3],
    constant const int& offset,
    constant const float& base,
    constant const float& scale,
    uint3 pos [[thread_position_in_grid]],
    uint3 grid [[threads_per_grid]]) {
  // Parts of the indices that are shared by both layouts.
  const uint out_row = grid.x * (pos.y + grid.y * pos.z);
  const size_t in_row = pos.y * strides[1] + pos.z * strides[0];

  // Locate the two elements of the pair in the input and in the output.
  uint src_a, src_b;
  uint dst_a, dst_b;
  if (traditional) {
    dst_a = 2 * (pos.x + out_row);
    dst_b = dst_a + 1;
    src_a = 2 * pos.x * strides[2] + in_row;
    src_b = src_a + strides[2];
  } else {
    dst_a = pos.x + 2 * out_row;
    dst_b = dst_a + grid.x;
    src_a = pos.x * strides[2] + in_row;
    src_b = src_a + grid.x * strides[2];
  }

  // Scaled position and the fractional index of this pair within the row.
  const float position = scale * static_cast<float>(pos.y + offset);
  const float fraction = static_cast<float>(pos.x) / static_cast<float>(grid.x);

  // Rotation angle and its cosine / sine.
  const float angle = position * metal::exp2(-fraction * base);
  const float cos_angle = metal::fast::cos(angle);
  const float sin_angle = metal::fast::sin(angle);

  // Load in float, rotate, and store back in the element type.
  const float a = static_cast<float>(in[src_a]);
  const float b = static_cast<float>(in[src_b]);
  const float rotated_a = a * cos_angle - b * sin_angle;
  const float rotated_b = a * sin_angle + b * cos_angle;
  out[dst_a] = static_cast<T>(rotated_a);
  out[dst_b] = static_cast<T>(rotated_b);
}
|
||
// Explicitly instantiates rope<dtype, is_traditional> and exports it to the
// host under the kernel name "rope_<suffix>".
#define instantiate_rope(suffix, dtype, is_traditional) \
  template [[host_name("rope_" #suffix)]] \
  [[kernel]] void rope<dtype, is_traditional>( \
      const device dtype* in [[buffer(0)]], \
      device dtype* out [[buffer(1)]], \
      constant const size_t strides[3], \
      constant const int& offset, \
      constant const float& base, \
      constant const float& scale, \
      uint3 pos [[thread_position_in_grid]], \
      uint3 grid [[threads_per_grid]]);

// Split-halves layout.
instantiate_rope(float16, half, false)
instantiate_rope(bfloat16, bfloat16_t, false)
instantiate_rope(float32, float, false)

// Traditional (adjacent pairs) layout.
instantiate_rope(traditional_float16, half, true)
instantiate_rope(traditional_bfloat16, bfloat16_t, true)
instantiate_rope(traditional_float32, float, true)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
// Copyright © 2023-2024 Apple Inc. | ||
|
||
#include "mlx/backend/metal/utils.h" | ||
#include "mlx/fast.h" | ||
#include "mlx/primitives.h" | ||
|
||
namespace mlx::core::fast { | ||
|
||
void RoPE::eval_gpu( | ||
const std::vector<array>& inputs, | ||
std::vector<array>& outputs) { | ||
assert(inputs.size() == 1); | ||
assert(outputs.size() == 1); | ||
auto& in = inputs[0]; | ||
auto& out = outputs[0]; | ||
|
||
if (in.ndim() != 3) { | ||
throw std::runtime_error( | ||
"[RoPE] Only 3 dimensions are supported (batch x sequence x dims)"); | ||
} | ||
if (dims_ != in.shape(-1)) { | ||
throw std::runtime_error("[RoPE] Partial RoPE application not supported"); | ||
} | ||
if (in.flags().row_contiguous && in.is_donatable()) { | ||
out.move_shared_buffer(in); | ||
} else { | ||
out.set_data(allocator::malloc_or_wait(out.nbytes())); | ||
} | ||
|
||
auto& s = out.primitive().stream(); | ||
auto& d = metal::device(s.device); | ||
std::ostringstream kname; | ||
kname << "rope_" << (traditional_ ? "traditional_" : "") << type_to_name(in); | ||
auto kernel = d.get_kernel(kname.str()); | ||
auto compute_encoder = d.get_command_encoder(s.index); | ||
|
||
bool donated = in.data_shared_ptr() == nullptr; | ||
float base = std::log2(base_); | ||
compute_encoder->setComputePipelineState(kernel); | ||
set_array_buffer(compute_encoder, donated ? out : in, 0); | ||
set_array_buffer(compute_encoder, out, 1); | ||
compute_encoder->setBytes(in.strides().data(), 3 * sizeof(size_t), 2); | ||
compute_encoder->setBytes(&offset_, sizeof(int), 3); | ||
compute_encoder->setBytes(&base, sizeof(float), 4); | ||
compute_encoder->setBytes(&scale_, sizeof(float), 5); | ||
|
||
int dim0 = in.shape(2) / 2; | ||
int dim1 = in.shape(1); | ||
int dim2 = in.shape(0); | ||
auto group_dims = get_block_dims(dim0, dim1, dim2); | ||
auto grid_dims = MTL::Size(dim0, dim1, dim2); | ||
compute_encoder->dispatchThreads(grid_dims, group_dims); | ||
} | ||
|
||
} // namespace mlx::core::fast |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,128 @@ | ||
// Copyright © 2023-2024 Apple Inc. | ||
|
||
#include "mlx/fast.h" | ||
#include "mlx/transforms.h" | ||
|
||
namespace mlx::core::fast { | ||
|
||
std::vector<array> Custom::vjp( | ||
const std::vector<array>& primals, | ||
const std::vector<array>& cotangents, | ||
const std::vector<int>& argnums, | ||
const std::vector<array>& outputs) { | ||
auto [_, vjps] = mlx::core::vjp(fallback_, primals, cotangents); | ||
std::vector<array> vjp_outs; | ||
for (int i = 0, j = 0; i < vjps.size(); ++i) { | ||
if (i < argnums.size() && i == argnums[j]) { | ||
vjp_outs.push_back(vjps[i]); | ||
j++; | ||
} | ||
} | ||
return vjp_outs; | ||
} | ||
|
||
std::vector<array> Custom::jvp( | ||
const std::vector<array>& primals, | ||
const std::vector<array>& tangents, | ||
const std::vector<int>& argnums) { | ||
auto [_, jvps] = mlx::core::jvp(fallback_, primals, tangents); | ||
std::vector<array> jvp_outs; | ||
for (int i = 0, j = 0; i < jvps.size(); ++i) { | ||
if (i < argnums.size() && i == argnums[j]) { | ||
jvp_outs.push_back(jvps[i]); | ||
j++; | ||
} | ||
} | ||
return jvp_outs; | ||
} | ||
|
||
std::pair<std::vector<array>, std::vector<int>> Custom::vmap( | ||
const std::vector<array>& inputs, | ||
const std::vector<int>& axes) { | ||
auto outputs = mlx::core::vmap(fallback_, axes)(inputs); | ||
auto out_axes = std::vector<int>(outputs.size(), 0); | ||
return {outputs, out_axes}; | ||
} | ||
|
||
// Applies rotary position embeddings (RoPE) to a 3-dimensional input.
//
// x:           input with exactly 3 dimensions; positions run along axis 1
//              and the rotation is applied along the last axis.
// dims:        number of leading features of the last axis that are rotated.
// traditional: if true, adjacent elements (2i, 2i + 1) form a pair; otherwise
//              element i is paired with element i + dims / 2.
// base:        base of the geometric frequency progression.
// scale:       multiplier applied to the positions.
// offset:      value added to the position indices.
// s:           stream or device to run on.
//
// Throws std::invalid_argument if x is not 3-dimensional, or if `traditional`
// is set and `dims` differs from the last dimension of x.
//
// On a GPU stream with dims equal to the last dimension, the result is a lazy
// array backed by the RoPE primitive; otherwise it is built from regular ops.
array rope(
    const array& x,
    int dims,
    bool traditional,
    float base,
    float scale,
    int offset,
    StreamOrDevice s /* = {} */) {
  if (x.ndim() != 3) {
    std::ostringstream msg;
    msg << "[rope] Input must have 3 dimensions but got input with " << x.ndim()
        << " dimensions.";
    throw std::invalid_argument(msg.str());
  }
  if (traditional && x.shape(-1) != dims) {
    throw std::invalid_argument(
        "[rope] Does not support partial traditional application.");
  }

  // Reference implementation composed of regular ops. It is both the non-GPU
  // path and the function the primitive hands to the transforms.
  auto fallback = [dims, traditional, base, scale, offset, s](
                      const std::vector<array>& inputs) {
    auto& in = inputs[0];
    auto dtype = in.dtype();
    auto half = dims / 2;
    auto last_position = in.shape(1) + offset;

    // theta[p, i] = scale * (offset + p) * exp(-i * log(base) / half)
    auto positions =
        multiply(arange(offset, last_position, dtype, s), array(scale, dtype), s);
    auto freqs = exp(
        multiply(
            negative(arange(0, half, dtype, s), s),
            array(std::log(base) / half, dtype),
            s),
        s);
    auto theta =
        multiply(expand_dims(positions, 1, s), expand_dims(freqs, 0, s), s);
    auto coss = cos(theta, s);
    auto sins = sin(theta, s);

    // Rotates the pair (a, b) by theta: {a*cos - b*sin, a*sin + b*cos}.
    auto rotate = [&coss, &sins, &s](const array& a, const array& b) {
      std::vector<array> rotated;
      rotated.push_back(
          subtract(multiply(a, coss, s), multiply(b, sins, s), s));
      rotated.push_back(add(multiply(a, sins, s), multiply(b, coss, s), s));
      return rotated;
    };

    if (traditional) {
      // Pairs are the even / odd elements of the last axis.
      auto evens = slice(in, {0, 0, 0}, in.shape(), {1, 1, 2}, s);
      auto odds = slice(in, {0, 0, 1}, in.shape(), {1, 1, 2}, s);
      auto parts = rotate(evens, odds);
      // Interleave the two halves again: stack on a new trailing axis and
      // flatten it back into the last dimension.
      for (auto& part : parts) {
        part = expand_dims(part, 3, s);
      }
      return std::vector<array>{
          reshape(concatenate(parts, 3, s), in.shape(), s)};
    }

    // Pairs are element i and element i + half of the last axis.
    auto stop = in.shape();
    stop.back() = half;
    auto lower = slice(in, {0, 0, 0}, stop, s);
    stop.back() = dims;
    auto upper = slice(in, {0, 0, half}, stop, s);

    auto parts = rotate(lower, upper);
    if (dims < in.shape(-1)) {
      // Features past `dims` pass through unchanged.
      parts.push_back(slice(in, {0, 0, dims}, in.shape(), s));
    }
    return std::vector<array>{concatenate(parts, 2, s)};
  };

  // TODO change to condition for using custom prim
  auto stream = to_stream(s);
  const bool use_primitive =
      stream.device == Device::gpu && x.shape(-1) == dims;
  if (!use_primitive) {
    return fallback({x})[0];
  }
  return array(
      x.shape(),
      x.dtype(),
      std::make_unique<RoPE>(
          stream, fallback, dims, traditional, base, scale, offset),
      {x});
}
|
||
bool RoPE::is_equivalent(const Primitive& other) const { | ||
const RoPE& a_other = static_cast<const RoPE&>(other); | ||
return ( | ||
dims_ == a_other.dims_ && base_ == a_other.base_ && | ||
scale_ == a_other.scale_ && traditional_ == a_other.traditional_ && | ||
offset_ == a_other.offset_); | ||
} | ||
|
||
} // namespace mlx::core::fast |
Oops, something went wrong.