Custom primitive + RoPE fat op #676
Merged
Changes from all commits (12 commits):
c1df31d  extensions start (awni)
d446d58  rope custom op (awni)
12034e9  fix build (awni)
ce8dba7  docs + rope benchmark (awni)
1f3fe12  fix test (awni)
b23090c  Add a Metal kernel for RoPE (angeloskath)
638ef68  Fix position of traditional (angeloskath)
1d47cf3  transform tests (awni)
721f36c  Move rope computation to float and fix tests (angeloskath)
f29acba  Fix the test and a typo (angeloskath)
252576b  change to fast (awni)
cc06e11  fix no metal build (awni)
New file: Python benchmark for RoPE (@@ -0,0 +1,35 @@):
# Copyright © 2023-2024 Apple Inc.

import mlx.core as mx
import mlx.nn as nn
from time_utils import time_fn


def time_rope():
    rope = nn.RoPE(4096)

    # vec
    x = mx.random.uniform(shape=(1, 4096)).astype(mx.float16)
    mx.eval(x)

    def rope_vec(x):
        for _ in range(32):
            x = rope(x)
        return x

    time_fn(rope_vec, x)

    # matrix
    x = mx.random.uniform(shape=(1024, 4096)).astype(mx.float16)
    mx.eval(x)

    def rope_mat(x):
        for _ in range(32):
            x = rope(x)
        return x

    time_fn(rope_mat, x)


if __name__ == "__main__":
    time_rope()
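The benchmark imports time_fn from time_utils, which is not part of this diff. A minimal sketch of a compatible helper, assuming only the call pattern above (the body here is illustrative, not the repo's actual implementation):

import time

import mlx.core as mx


def time_fn(fn, *args):
    # Warm up once so lazy graph construction and compilation are
    # excluded from the measurement.
    mx.eval(fn(*args))
    iters = 10
    tic = time.perf_counter()
    for _ in range(iters):
        mx.eval(fn(*args))
    toc = time.perf_counter()
    print(f"{fn.__name__}: {1e3 * (toc - tic) / iters:.3f} msec per call")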
New file: CPU stub for the RoPE primitive (@@ -0,0 +1,14 @@):
// Copyright © 2023-2024 Apple Inc.

#include "mlx/fast.h"
#include "mlx/primitives.h"

namespace mlx::core::fast {

void RoPE::eval_cpu(
    const std::vector<array>& inputs,
    std::vector<array>& outputs) {
  throw std::runtime_error("NYI");
}

} // namespace mlx::core::fast
Metal kernels CMakeLists: register the new kernel (@@ -23,6 +23,7 @@):

set(
    "quantized"
    "random"
    "reduce"
    "rope"        # added in this PR
    "scan"
    "softmax"
    "sort"
New file: Metal kernel implementing RoPE (@@ -0,0 +1,68 @@):
// Copyright © 2023-2024 Apple Inc.

#include <metal_math>

#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/utils.h"

template <typename T, bool traditional>
[[kernel]] void rope(
    const device T* in [[buffer(0)]],
    device T* out [[buffer(1)]],
    constant const size_t strides[3],
    constant const int& offset,
    constant const float& base,
    constant const float& scale,
    uint3 pos [[thread_position_in_grid]],
    uint3 grid [[threads_per_grid]]) {
  // Compute the input and output indices
  uint in_index_1, in_index_2;
  uint out_index_1, out_index_2;
  if (traditional) {
    out_index_1 = 2 * (pos.x + grid.x * (pos.y + grid.y * pos.z));
    out_index_2 = out_index_1 + 1;
    in_index_1 = 2 * pos.x * strides[2] + pos.y * strides[1] + pos.z * strides[0];
    in_index_2 = in_index_1 + strides[2];
  } else {
    out_index_1 = pos.x + 2 * (grid.x * (pos.y + grid.y * pos.z));
    out_index_2 = out_index_1 + grid.x;
    in_index_1 = pos.x * strides[2] + pos.y * strides[1] + pos.z * strides[0];
    in_index_2 = in_index_1 + grid.x * strides[2];
  }

  // Figure out L and d.
  float L = scale * static_cast<float>(pos.y + offset);
  float d = static_cast<float>(pos.x) / static_cast<float>(grid.x);

  // Compute costheta, sintheta
  float theta = L * metal::exp2(-d * base);
  float costheta = metal::fast::cos(theta);
  float sintheta = metal::fast::sin(theta);

  // Read and write the output
  float x1 = static_cast<float>(in[in_index_1]);
  float x2 = static_cast<float>(in[in_index_2]);
  float rx1 = x1 * costheta - x2 * sintheta;
  float rx2 = x1 * sintheta + x2 * costheta;
  out[out_index_1] = static_cast<T>(rx1);
  out[out_index_2] = static_cast<T>(rx2);
}

#define instantiate_rope(name, type, traditional)  \
  template [[host_name("rope_" #name)]]            \
  [[kernel]] void rope<type, traditional>(         \
      const device type* in [[buffer(0)]],         \
      device type* out [[buffer(1)]],              \
      constant const size_t strides[3],            \
      constant const int& offset,                  \
      constant const float& base,                  \
      constant const float& scale,                 \
      uint3 pos [[thread_position_in_grid]],       \
      uint3 grid [[threads_per_grid]]);

instantiate_rope(traditional_float16, half, true)
instantiate_rope(traditional_bfloat16, bfloat16_t, true)
instantiate_rope(traditional_float32, float, true)
instantiate_rope(float16, half, false)
instantiate_rope(bfloat16, bfloat16_t, false)
instantiate_rope(float32, float, false)
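Each thread in the kernel above rotates one pair of elements; the traditional and default layouts differ only in which two elements form a pair. As a hedged reference, here is a NumPy sketch of the same computation (rope_reference is an illustrative name, not part of this diff):

import numpy as np


def rope_reference(x, base=10000.0, scale=1.0, offset=0, traditional=False):
    # x: (batch, seq, dims); the kernel assigns one thread per
    # (batch, position, pair) triple, with pair index j in [0, dims // 2).
    _, seq, dims = x.shape
    half = dims // 2
    pos = scale * (np.arange(seq) + offset)       # "L" in the kernel
    d = np.arange(half) / half                    # "d" in the kernel
    theta = pos[:, None] * base ** (-d[None, :])  # L * exp2(-d * log2(base))
    cos, sin = np.cos(theta), np.sin(theta)
    out = np.empty_like(x, dtype=np.float32)
    if traditional:
        x1, x2 = x[..., 0::2], x[..., 1::2]       # interleaved pairs
        out[..., 0::2] = x1 * cos - x2 * sin
        out[..., 1::2] = x1 * sin + x2 * cos
    else:
        x1, x2 = x[..., :half], x[..., half:]     # split-half pairs
        out[..., :half] = x1 * cos - x2 * sin
        out[..., half:] = x1 * sin + x2 * cos
    return out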
New file: Metal backend implementation of RoPE::eval_gpu (@@ -0,0 +1,55 @@):
// Copyright © 2023-2024 Apple Inc.

#include "mlx/backend/metal/utils.h"
#include "mlx/fast.h"
#include "mlx/primitives.h"

namespace mlx::core::fast {

void RoPE::eval_gpu(
    const std::vector<array>& inputs,
    std::vector<array>& outputs) {
  assert(inputs.size() == 1);
  assert(outputs.size() == 1);
  auto& in = inputs[0];
  auto& out = outputs[0];

  if (in.ndim() != 3) {
    throw std::runtime_error(
        "[RoPE] Only 3 dimensions are supported (batch x sequence x dims)");
  }
  if (dims_ != in.shape(-1)) {
    throw std::runtime_error("[RoPE] Partial RoPE application not supported");
  }
  if (in.flags().row_contiguous && in.is_donatable()) {
    out.move_shared_buffer(in);
  } else {
    out.set_data(allocator::malloc_or_wait(out.nbytes()));
  }

  auto& s = out.primitive().stream();
  auto& d = metal::device(s.device);
  std::ostringstream kname;
  kname << "rope_" << (traditional_ ? "traditional_" : "") << type_to_name(in);
  auto kernel = d.get_kernel(kname.str());
  auto compute_encoder = d.get_command_encoder(s.index);

  bool donated = in.data_shared_ptr() == nullptr;
  float base = std::log2(base_);
  compute_encoder->setComputePipelineState(kernel);
  set_array_buffer(compute_encoder, donated ? out : in, 0);
  set_array_buffer(compute_encoder, out, 1);
  compute_encoder->setBytes(in.strides().data(), 3 * sizeof(size_t), 2);
  compute_encoder->setBytes(&offset_, sizeof(int), 3);
  compute_encoder->setBytes(&base, sizeof(float), 4);
  compute_encoder->setBytes(&scale_, sizeof(float), 5);

  int dim0 = in.shape(2) / 2;
  int dim1 = in.shape(1);
  int dim2 = in.shape(0);
  auto group_dims = get_block_dims(dim0, dim1, dim2);
  auto grid_dims = MTL::Size(dim0, dim1, dim2);
  compute_encoder->dispatchThreads(grid_dims, group_dims);
}

} // namespace mlx::core::fast
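Because the rope op routes non-GPU streams to the op-composed fallback (see fast.cpp below), the Metal path can be cross-checked against the fallback simply by switching streams. A hedged sketch, assuming the Python binding is exposed as mx.fast.rope and mirrors the C++ signature:

import mlx.core as mx

x = mx.random.uniform(shape=(2, 8, 64))

# GPU stream -> RoPE primitive -> the Metal kernel (needs a Metal device).
y_gpu = mx.fast.rope(
    x, 64, traditional=False, base=10000.0, scale=1.0, offset=0, stream=mx.gpu
)

# CPU stream -> the op-composed fallback in fast.cpp.
y_cpu = mx.fast.rope(
    x, 64, traditional=False, base=10000.0, scale=1.0, offset=0, stream=mx.cpu
)

print(mx.allclose(y_gpu, y_cpu, atol=1e-5))  # expected: True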
New file: fast namespace with the Custom primitive transforms and the rope op (@@ -0,0 +1,128 @@):
// Copyright © 2023-2024 Apple Inc.

#include "mlx/fast.h"
#include "mlx/transforms.h"

namespace mlx::core::fast {

std::vector<array> Custom::vjp(
    const std::vector<array>& primals,
    const std::vector<array>& cotangents,
    const std::vector<int>& argnums,
    const std::vector<array>& outputs) {
  auto [_, vjps] = mlx::core::vjp(fallback_, primals, cotangents);
  std::vector<array> vjp_outs;
  for (int i = 0, j = 0; i < vjps.size(); ++i) {
    if (i < argnums.size() && i == argnums[j]) {
      vjp_outs.push_back(vjps[i]);
      j++;
    }
  }
  return vjp_outs;
}

std::vector<array> Custom::jvp(
    const std::vector<array>& primals,
    const std::vector<array>& tangents,
    const std::vector<int>& argnums) {
  auto [_, jvps] = mlx::core::jvp(fallback_, primals, tangents);
  std::vector<array> jvp_outs;
  for (int i = 0, j = 0; i < jvps.size(); ++i) {
    if (i < argnums.size() && i == argnums[j]) {
      jvp_outs.push_back(jvps[i]);
      j++;
    }
  }
  return jvp_outs;
}

std::pair<std::vector<array>, std::vector<int>> Custom::vmap(
    const std::vector<array>& inputs,
    const std::vector<int>& axes) {
  auto outputs = mlx::core::vmap(fallback_, axes)(inputs);
  auto out_axes = std::vector<int>(outputs.size(), 0);
  return {outputs, out_axes};
}

array rope(
    const array& x,
    int dims,
    bool traditional,
    float base,
    float scale,
    int offset,
    StreamOrDevice s /* = {} */) {
  if (x.ndim() != 3) {
    std::ostringstream msg;
    msg << "[rope] Input must have 3 dimensions but got input with " << x.ndim()
        << " dimensions.";
    throw std::invalid_argument(msg.str());
  }
  if (traditional && x.shape(-1) != dims) {
    throw std::invalid_argument(
        "[rope] Does not support partial traditional application.");
  }

  auto fallback = [dims, traditional, base, scale, offset, s](
                      const std::vector<array>& inputs) {
    auto& x = inputs[0];
    auto t = x.dtype();
    auto N = x.shape(1) + offset;
    // Compute sines and cosines
    auto half_dims = dims / 2;
    auto positions = multiply(arange(offset, N, t, s), array(scale, t), s);
    auto freqs = negative(arange(0, half_dims, t, s), s);
    freqs = exp(multiply(freqs, array(std::log(base) / half_dims, t), s), s);
    auto theta =
        multiply(expand_dims(positions, 1, s), expand_dims(freqs, 0, s), s);
    auto coss = cos(theta, s);
    auto sins = sin(theta, s);

    if (traditional) {
      auto x1 = slice(x, {0, 0, 0}, x.shape(), {1, 1, 2}, s);
      auto x2 = slice(x, {0, 0, 1}, x.shape(), {1, 1, 2}, s);
      std::vector<array> outs;
      outs.push_back(subtract(multiply(x1, coss, s), multiply(x2, sins, s), s));
      outs.push_back(add(multiply(x1, sins, s), multiply(x2, coss, s), s));
      for (auto& o : outs) {
        o = expand_dims(o, 3, s);
      }
      return std::vector<array>{reshape(concatenate(outs, 3, s), x.shape(), s)};
    } else {
      auto out_s = x.shape();
      out_s.back() = half_dims;
      auto x1 = slice(x, {0, 0, 0}, out_s, s);
      out_s.back() = dims;
      auto x2 = slice(x, {0, 0, half_dims}, out_s, s);

      std::vector<array> outs;
      outs.push_back(subtract(multiply(x1, coss, s), multiply(x2, sins, s), s));
      outs.push_back(add(multiply(x1, sins, s), multiply(x2, coss, s), s));
      if (dims < x.shape(-1)) {
        outs.push_back(slice(x, {0, 0, dims}, x.shape(), s));
      }
      return std::vector<array>{concatenate(outs, 2, s)};
    }
  };
  // TODO change to condition for using custom prim
  auto stream = to_stream(s);
  if (stream.device == Device::gpu && x.shape(-1) == dims) {
    return array(
        x.shape(),
        x.dtype(),
        std::make_unique<RoPE>(
            stream, fallback, dims, traditional, base, scale, offset),
        {x});
  }
  return fallback({x})[0];
}

bool RoPE::is_equivalent(const Primitive& other) const {
  const RoPE& a_other = static_cast<const RoPE&>(other);
  return (
      dims_ == a_other.dims_ && base_ == a_other.base_ &&
      scale_ == a_other.scale_ && traditional_ == a_other.traditional_ &&
      offset_ == a_other.offset_);
}

} // namespace mlx::core::fast
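For reference, the Metal kernel and the op-composed fallback compute the same rotation. With sequence position p and pair index j in [0, dims/2):

$$\theta_{p,j} = \text{scale} \cdot (p + \text{offset}) \cdot \text{base}^{-2j/\text{dims}}$$

$$\begin{pmatrix} x_1' \\ x_2' \end{pmatrix} = \begin{pmatrix} \cos\theta_{p,j} & -\sin\theta_{p,j} \\ \sin\theta_{p,j} & \cos\theta_{p,j} \end{pmatrix} \begin{pmatrix} x_1 \\ x_2 \end{pmatrix}$$

where the pair (x_1, x_2) is (x_{2j}, x_{2j+1}) in the traditional layout and (x_j, x_{j+dims/2}) otherwise.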
Review discussion (on the donation check in RoPE::eval_gpu):

"We need a contig and copy check before this, right?"

"Not sure what the copy check is. Also, row_contiguous is stricter than contiguous, is it not? That is, all row_contiguous arrays are contiguous but not the other way around."

"I meant: if it's not contiguous, we should make a contiguous copy."

"It does not appear to me that your kernel handles non-contiguous inputs, but maybe I missed something."

"Actually, I think I missed it. I was looking for elem_to_loc, but you hardcoded the strides, so it should be OK."

"Do we need to check here, though, that the input has the same size as the output? If it's broadcasted, e.g. along the last axis, it would be incorrect to donate, right?"

"Yeah, I hardcoded the strides because the grid is launched with half the last dimension, so it can't be delegated to a simple elem_to_loc. I would have to do something like multiply pos.x by 2 and then pass to elem_to_loc, etc. I think this is equally readable, but I am open to suggestions :-) Regarding broadcasting, a broadcasted array wouldn't be row_contiguous, so this check should be fine donation-wise, right?"

"Oh of course! Let me quietly exit this thread before I say anything else incorrect."