Custom primitive + RoPE fat op #676

Merged · 12 commits · Feb 14, 2024
35 changes: 35 additions & 0 deletions benchmarks/python/rope_bench.py
@@ -0,0 +1,35 @@
# Copyright © 2023-2024 Apple Inc.

import mlx.core as mx
import mlx.nn as nn
from time_utils import time_fn


def time_rope():
rope = nn.RoPE(4096)

# vec
x = mx.random.uniform(shape=(1, 4096)).astype(mx.float16)
mx.eval(x)

def rope_vec(x):
for _ in range(32):
x = rope(x)
return x

time_fn(rope_vec, x)

# matrix
x = mx.random.uniform(shape=(1024, 4096)).astype(mx.float16)
mx.eval(x)

def rope_mat(x):
for _ in range(32):
x = rope(x)
return x

time_fn(rope_mat, x)


if __name__ == "__main__":
time_rope()
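
time_fn comes from the repo-local benchmarks/python/time_utils.py. For a standalone sanity check outside the benchmark harness, a minimal sketch using only mlx and the standard library (illustrative, not part of the PR):

```python
# Standalone timing sketch; assumes only mlx and the standard library.
import time

import mlx.core as mx
import mlx.nn as nn

rope = nn.RoPE(4096)
x = mx.random.uniform(shape=(1024, 4096)).astype(mx.float16)
mx.eval(x)

tic = time.perf_counter()
y = x
for _ in range(32):
    y = rope(y)
mx.eval(y)  # force the lazy graph to run before stopping the clock
print(f"32 RoPE applications: {(time.perf_counter() - tic) * 1e3:.3f} ms")
```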
3 changes: 2 additions & 1 deletion mlx/CMakeLists.txt
@@ -3,9 +3,10 @@ target_sources(
PRIVATE
${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp
${CMAKE_CURRENT_SOURCE_DIR}/array.cpp
+ ${CMAKE_CURRENT_SOURCE_DIR}/compile.cpp
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
${CMAKE_CURRENT_SOURCE_DIR}/dtype.cpp
- ${CMAKE_CURRENT_SOURCE_DIR}/compile.cpp
+ ${CMAKE_CURRENT_SOURCE_DIR}/fast.cpp
${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
${CMAKE_CURRENT_SOURCE_DIR}/ops.cpp
${CMAKE_CURRENT_SOURCE_DIR}/graph_utils.cpp
1 change: 1 addition & 0 deletions mlx/backend/common/CMakeLists.txt
@@ -11,6 +11,7 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
${CMAKE_CURRENT_SOURCE_DIR}/quantized.cpp
${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp
+ ${CMAKE_CURRENT_SOURCE_DIR}/rope.cpp
${CMAKE_CURRENT_SOURCE_DIR}/scan.cpp
${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp
${CMAKE_CURRENT_SOURCE_DIR}/sort.cpp
14 changes: 14 additions & 0 deletions mlx/backend/common/rope.cpp
@@ -0,0 +1,14 @@
// Copyright © 2023-2024 Apple Inc.

#include "mlx/fast.h"
#include "mlx/primitives.h"

namespace mlx::core::fast {

void RoPE::eval_cpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
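// Unreachable via rope() in mlx/fast.cpp, which only constructs this
// primitive on GPU streams; CPU streams evaluate the composite fallback.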
throw std::runtime_error("NYI");
}

} // namespace mlx::core::fast
1 change: 1 addition & 0 deletions mlx/backend/metal/CMakeLists.txt
@@ -32,6 +32,7 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/metal.cpp
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
${CMAKE_CURRENT_SOURCE_DIR}/quantized.cpp
+ ${CMAKE_CURRENT_SOURCE_DIR}/rope.cpp
${CMAKE_CURRENT_SOURCE_DIR}/scan.cpp
${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp
${CMAKE_CURRENT_SOURCE_DIR}/sort.cpp
1 change: 1 addition & 0 deletions mlx/backend/metal/kernels/CMakeLists.txt
@@ -23,6 +23,7 @@ set(
"quantized"
"random"
"reduce"
+ "rope"
"scan"
"softmax"
"sort"
68 changes: 68 additions & 0 deletions mlx/backend/metal/kernels/rope.metal
@@ -0,0 +1,68 @@
// Copyright © 2023-2024 Apple Inc.

#include <metal_math>

#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/utils.h"

template <typename T, bool traditional>
[[kernel]] void rope(
const device T *in [[buffer(0)]],
device T * out [[buffer(1)]],
constant const size_t strides[3],
constant const int& offset,
constant const float& base,
constant const float& scale,
uint3 pos [[thread_position_in_grid]],
uint3 grid [[threads_per_grid]]) {
// Compute the input and output indices
uint in_index_1, in_index_2;
uint out_index_1, out_index_2;
if (traditional) {
out_index_1 = 2 * (pos.x + grid.x * (pos.y + grid.y * pos.z));
out_index_2 = out_index_1 + 1;
in_index_1 = 2 * pos.x * strides[2] + pos.y * strides[1] + pos.z * strides[0];
in_index_2 = in_index_1 + strides[2];
} else {
out_index_1 = pos.x + 2*(grid.x * (pos.y + grid.y * pos.z));
out_index_2 = out_index_1 + grid.x;
in_index_1 = pos.x * strides[2] + pos.y * strides[1] + pos.z * strides[0];
in_index_2 = in_index_1 + grid.x * strides[2];
}

// Compute the scaled position L and the normalized frequency index d (grid.x == dims / 2).
float L = scale * static_cast<float>(pos.y + offset);
float d = static_cast<float>(pos.x) / static_cast<float>(grid.x);

// Compute costheta, sintheta
float theta = L * metal::exp2(-d * base);
float costheta = metal::fast::cos(theta);
float sintheta = metal::fast::sin(theta);

// Read the input pair, rotate it, and write the output
float x1 = static_cast<float>(in[in_index_1]);
float x2 = static_cast<float>(in[in_index_2]);
float rx1 = x1 * costheta - x2 * sintheta;
float rx2 = x1 * sintheta + x2 * costheta;
out[out_index_1] = static_cast<T>(rx1);
out[out_index_2] = static_cast<T>(rx2);
}

#define instantiate_rope(name, type, traditional) \
template [[host_name("rope_" #name)]] \
[[kernel]] void rope<type, traditional>( \
const device type* in [[buffer(0)]], \
device type* out [[buffer(1)]], \
constant const size_t strides[3], \
constant const int& offset, \
constant const float& base, \
constant const float& scale, \
uint3 pos [[thread_position_in_grid]], \
uint3 grid [[threads_per_grid]]);

instantiate_rope(traditional_float16, half, true)
instantiate_rope(traditional_bfloat16, bfloat16_t, true)
instantiate_rope(traditional_float32, float, true)
instantiate_rope(float16, half, false)
instantiate_rope(bfloat16, bfloat16_t, false)
instantiate_rope(float32, float, false)
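
For reference, a sketch of the rotation the kernel applies to each frequency pair, written in the kernel's own variables ($j = {}$ pos.x, $D/2 = {}$ grid.x, and $b$ the original base before the host takes its log):

$$
\theta \;=\; \underbrace{\mathrm{scale}\cdot(\mathrm{pos.y} + \mathrm{offset})}_{L}\, \cdot\, b^{-j/(D/2)},
\qquad
\begin{pmatrix} x_1' \\ x_2' \end{pmatrix}
=
\begin{pmatrix} \cos\theta & -\sin\theta \\ \sin\theta & \cos\theta \end{pmatrix}
\begin{pmatrix} x_1 \\ x_2 \end{pmatrix}.
$$

Since $b^{-d} = \mathrm{exp2}(-d \log_2 b)$ with $d = j/(D/2)$, the kernel can use the metal::exp2 intrinsic, which is why the host code in mlx/backend/metal/rope.cpp passes std::log2(base_) rather than base_ itself.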
55 changes: 55 additions & 0 deletions mlx/backend/metal/rope.cpp
@@ -0,0 +1,55 @@
// Copyright © 2023-2024 Apple Inc.

#include "mlx/backend/metal/utils.h"
#include "mlx/fast.h"
#include "mlx/primitives.h"

namespace mlx::core::fast {

void RoPE::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
assert(inputs.size() == 1);
assert(outputs.size() == 1);
auto& in = inputs[0];
auto& out = outputs[0];

if (in.ndim() != 3) {
throw std::runtime_error(
"[RoPE] Only 3 dimensions are supported (batch x sequence x dims)");
}
if (dims_ != in.shape(-1)) {
throw std::runtime_error("[RoPE] Partial RoPE application not supported");
}
if (in.flags().row_contiguous && in.is_donatable()) {
[Inline review thread on this check]

Member Author: We need a contig-and-copy check before this, right?

Member: Not sure what the copy check is. Also, row_contiguous is stricter than contiguous, is it not? I.e., all row_contiguous arrays are contiguous, but not the other way around.

Member Author: I meant, if it's not contiguous, we should make a contiguous copy.

Member Author: It does not appear to me that your kernel handles non-contiguous inputs, but maybe I missed something.

Member Author (@awni, Feb 13, 2024): Actually, I think I missed it. I was looking for elem_to_loc, but you hardcoded the strides, so it should be OK.

Member Author: Do we need to check here, though, that the input has the same size as the output? If it's broadcast, e.g. along the last axis, it would be incorrect to donate, right?

Member: Yeah, I hardcoded the strides because the grid is launched with half the last dimension, so the lookup can't be delegated to a simple elem_to_loc. I would have to do something like multiply pos.x by 2 and then pass it to elem_to_loc, etc. I think this is equally readable, but I am open to suggestions :-)

Regarding broadcasting, a broadcast array wouldn't be row_contiguous, so this check should be fine donation-wise, right?

Member Author: Oh, of course! Let me quietly exit this thread before I say anything else incorrect.

out.move_shared_buffer(in);
} else {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
}

auto& s = out.primitive().stream();
auto& d = metal::device(s.device);
std::ostringstream kname;
kname << "rope_" << (traditional_ ? "traditional_" : "") << type_to_name(in);
auto kernel = d.get_kernel(kname.str());
auto compute_encoder = d.get_command_encoder(s.index);

bool donated = in.data_shared_ptr() == nullptr;
float base = std::log2(base_);
compute_encoder->setComputePipelineState(kernel);
set_array_buffer(compute_encoder, donated ? out : in, 0);
set_array_buffer(compute_encoder, out, 1);
compute_encoder->setBytes(in.strides().data(), 3 * sizeof(size_t), 2);
compute_encoder->setBytes(&offset_, sizeof(int), 3);
compute_encoder->setBytes(&base, sizeof(float), 4);
compute_encoder->setBytes(&scale_, sizeof(float), 5);

int dim0 = in.shape(2) / 2;
int dim1 = in.shape(1);
int dim2 = in.shape(0);
auto group_dims = get_block_dims(dim0, dim1, dim2);
auto grid_dims = MTL::Size(dim0, dim1, dim2);
compute_encoder->dispatchThreads(grid_dims, group_dims);
}

} // namespace mlx::core::fast
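
To make the review thread above concrete, a tiny pure-Python check of the non-traditional index math in rope.metal, assuming a row-contiguous (B, L, D) input (all names illustrative):

```python
# Verifies that, for a row-contiguous input, the hardcoded-stride indexing in
# rope.metal reads and writes the same flat offsets, so in-place donation is safe.
B, L, D = 2, 3, 8
half = D // 2                  # grid.x: the grid covers half the last dimension
strides = (L * D, D, 1)        # row-contiguous strides for (B, L, D)

for z in range(B):             # pos.z: batch index
    for y in range(L):         # pos.y: sequence position
        for x in range(half):  # pos.x: frequency pair index
            out1 = x + 2 * (half * (y + L * z))
            out2 = out1 + half
            in1 = x * strides[2] + y * strides[1] + z * strides[0]
            in2 = in1 + half * strides[2]
            assert (out1, out2) == (in1, in2)
print("non-traditional index math checks out")
```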
5 changes: 5 additions & 0 deletions mlx/backend/no_metal/primitives.cpp
@@ -1,6 +1,7 @@
// Copyright © 2023-2024 Apple Inc.

#include "mlx/primitives.h"
+#include "mlx/fast.h"

#define NO_GPU_MULTI(func) \
void func::eval_gpu( \
@@ -95,4 +96,8 @@ NO_GPU(Tan)
NO_GPU(Tanh)
NO_GPU(Transpose)

+namespace fast {
+NO_GPU_MULTI(RoPE)
+} // namespace fast
+
} // namespace mlx::core
128 changes: 128 additions & 0 deletions mlx/fast.cpp
@@ -0,0 +1,128 @@
// Copyright © 2023-2024 Apple Inc.

#include "mlx/fast.h"
#include "mlx/transforms.h"

namespace mlx::core::fast {

std::vector<array> Custom::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums,
const std::vector<array>& outputs) {
auto [_, vjps] = mlx::core::vjp(fallback_, primals, cotangents);
std::vector<array> vjp_outs;
// Keep only the gradients for the inputs requested in argnums (sorted).
for (int i = 0, j = 0; i < vjps.size(); ++i) {
if (j < argnums.size() && i == argnums[j]) {
vjp_outs.push_back(vjps[i]);
j++;
}
}
return vjp_outs;
}

std::vector<array> Custom::jvp(
const std::vector<array>& primals,
const std::vector<array>& tangents,
const std::vector<int>& argnums) {
auto [_, jvps] = mlx::core::jvp(fallback_, primals, tangents);
std::vector<array> jvp_outs;
for (int i = 0, j = 0; i < jvps.size(); ++i) {
if (j < argnums.size() && i == argnums[j]) {
jvp_outs.push_back(jvps[i]);
j++;
}
}
return jvp_outs;
}

std::pair<std::vector<array>, std::vector<int>> Custom::vmap(
const std::vector<array>& inputs,
const std::vector<int>& axes) {
auto outputs = mlx::core::vmap(fallback_, axes)(inputs);
auto out_axes = std::vector<int>(outputs.size(), 0);
return {outputs, out_axes};
}

array rope(
const array& x,
int dims,
bool traditional,
float base,
float scale,
int offset,
StreamOrDevice s /* = {} */) {
if (x.ndim() != 3) {
std::ostringstream msg;
msg << "[rope] Input must have 3 dimensions but got input with " << x.ndim()
<< " dimensions.";
throw std::invalid_argument(msg.str());
}
if (traditional && x.shape(-1) != dims) {
throw std::invalid_argument(
"[rope] Does not support partial traditional application.");
}

auto fallback = [dims, traditional, base, scale, offset, s](
const std::vector<array>& inputs) {
auto& x = inputs[0];
auto t = x.dtype();
auto N = x.shape(1) + offset;
// Compute sines and cosines
auto half_dims = dims / 2;
auto positions = multiply(arange(offset, N, t, s), array(scale, t), s);
auto freqs = negative(arange(0, half_dims, t, s), s);
freqs = exp(multiply(freqs, array(std::log(base) / half_dims, t), s), s);
auto theta =
multiply(expand_dims(positions, 1, s), expand_dims(freqs, 0, s), s);
auto coss = cos(theta, s);
auto sins = sin(theta, s);

if (traditional) {
auto x1 = slice(x, {0, 0, 0}, x.shape(), {1, 1, 2}, s);
auto x2 = slice(x, {0, 0, 1}, x.shape(), {1, 1, 2}, s);
std::vector<array> outs;
outs.push_back(subtract(multiply(x1, coss, s), multiply(x2, sins, s), s));
outs.push_back(add(multiply(x1, sins, s), multiply(x2, coss, s), s));
for (auto& o : outs) {
o = expand_dims(o, 3, s);
}
return std::vector<array>{reshape(concatenate(outs, 3, s), x.shape(), s)};
} else {
auto out_s = x.shape();
out_s.back() = half_dims;
auto x1 = slice(x, {0, 0, 0}, out_s, s);
out_s.back() = dims;
auto x2 = slice(x, {0, 0, half_dims}, out_s, s);

std::vector<array> outs;
outs.push_back(subtract(multiply(x1, coss, s), multiply(x2, sins, s), s));
outs.push_back(add(multiply(x1, sins, s), multiply(x2, coss, s), s));
if (dims < x.shape(-1)) {
outs.push_back(slice(x, {0, 0, dims}, x.shape(), s));
}
return std::vector<array>{concatenate(outs, 2, s)};
}
};
// TODO: refine the condition for choosing the custom primitive
auto stream = to_stream(s);
if (stream.device == Device::gpu && x.shape(-1) == dims) {
return array(
x.shape(),
x.dtype(),
std::make_unique<RoPE>(
stream, fallback, dims, traditional, base, scale, offset),
{x});
}
return fallback({x})[0];
}

bool RoPE::is_equivalent(const Primitive& other) const {
const RoPE& a_other = static_cast<const RoPE&>(other);
return (
dims_ == a_other.dims_ && base_ == a_other.base_ &&
scale_ == a_other.scale_ && traditional_ == a_other.traditional_ &&
offset_ == a_other.offset_);
}

} // namespace mlx::core::fast
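
For readers who want to poke at the fallback from Python, a minimal sketch mirroring the non-traditional branch above, assuming only public mlx ops (rope_reference is an illustrative name, not part of this PR; angles are computed in float32, so it differs slightly from the float16 path):

```python
import math

import mlx.core as mx

def rope_reference(x, dims, base=10000.0, scale=1.0, offset=0):
    # Mirrors the composite fallback in mlx/fast.cpp (non-traditional case).
    N = x.shape[1] + offset
    half = dims // 2
    positions = scale * mx.arange(offset, N)
    freqs = mx.exp(-mx.arange(0.0, float(half)) * (math.log(base) / half))
    theta = mx.expand_dims(positions, 1) * mx.expand_dims(freqs, 0)
    coss, sins = mx.cos(theta), mx.sin(theta)
    x1, x2 = x[..., :half], x[..., half:dims]
    out = [x1 * coss - x2 * sins, x1 * sins + x2 * coss]
    if dims < x.shape[-1]:
        out.append(x[..., dims:])  # partial application passes the rest through
    return mx.concatenate(out, axis=-1)
```

For a full-width float32 input x of shape (B, L, D), mx.allclose(rope_reference(x, D), nn.RoPE(D)(x), atol=1e-4) should hold, since nn.RoPE defaults to the same non-traditional variant, base, and scale.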