Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion gpt_oss/metal/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,14 @@ set(METAL_SOURCES
${CMAKE_CURRENT_SOURCE_DIR}/source/accumulate.metal
${CMAKE_CURRENT_SOURCE_DIR}/source/convert.metal
${CMAKE_CURRENT_SOURCE_DIR}/source/embeddings.metal
${CMAKE_CURRENT_SOURCE_DIR}/source/gather_and_accumulate.metal
${CMAKE_CURRENT_SOURCE_DIR}/source/matmul.metal
${CMAKE_CURRENT_SOURCE_DIR}/source/moematmul.metal
${CMAKE_CURRENT_SOURCE_DIR}/source/random.metal
${CMAKE_CURRENT_SOURCE_DIR}/source/rmsnorm.metal
${CMAKE_CURRENT_SOURCE_DIR}/source/rope.metal
${CMAKE_CURRENT_SOURCE_DIR}/source/sample.metal
${CMAKE_CURRENT_SOURCE_DIR}/source/scatter.metal
${CMAKE_CURRENT_SOURCE_DIR}/source/sdpa.metal
${CMAKE_CURRENT_SOURCE_DIR}/source/topk.metal
)
Expand All @@ -38,13 +40,15 @@ add_custom_command(
COMMAND xcrun -sdk macosx metal -g "-I${CMAKE_CURRENT_SOURCE_DIR}/source/include" -c "${CMAKE_CURRENT_SOURCE_DIR}/source/embeddings.metal" -o "${CMAKE_CURRENT_BINARY_DIR}/source/embeddings.air"
COMMAND xcrun -sdk macosx metal -g "-I${CMAKE_CURRENT_SOURCE_DIR}/source/include" -c "${CMAKE_CURRENT_SOURCE_DIR}/source/matmul.metal" -o "${CMAKE_CURRENT_BINARY_DIR}/source/matmul.air"
COMMAND xcrun -sdk macosx metal -g "-I${CMAKE_CURRENT_SOURCE_DIR}/source/include" -c "${CMAKE_CURRENT_SOURCE_DIR}/source/moematmul.metal" -o "${CMAKE_CURRENT_BINARY_DIR}/source/moematmul.air"
COMMAND xcrun -sdk macosx metal -g "-I${CMAKE_CURRENT_SOURCE_DIR}/source/include" -c "${CMAKE_CURRENT_SOURCE_DIR}/source/gather_and_accumulate.metal" -o "${CMAKE_CURRENT_BINARY_DIR}/source/gather_and_accumulate.air"
COMMAND xcrun -sdk macosx metal -g "-I${CMAKE_CURRENT_SOURCE_DIR}/source/include" -c "${CMAKE_CURRENT_SOURCE_DIR}/source/random.metal" -o "${CMAKE_CURRENT_BINARY_DIR}/source/random.air"
COMMAND xcrun -sdk macosx metal -g "-I${CMAKE_CURRENT_SOURCE_DIR}/source/include" -c "${CMAKE_CURRENT_SOURCE_DIR}/source/rmsnorm.metal" -o "${CMAKE_CURRENT_BINARY_DIR}/source/rmsnorm.air"
COMMAND xcrun -sdk macosx metal -g "-I${CMAKE_CURRENT_SOURCE_DIR}/source/include" -c "${CMAKE_CURRENT_SOURCE_DIR}/source/rope.metal" -o "${CMAKE_CURRENT_BINARY_DIR}/source/rope.air"
COMMAND xcrun -sdk macosx metal -g "-I${CMAKE_CURRENT_SOURCE_DIR}/source/include" -c "${CMAKE_CURRENT_SOURCE_DIR}/source/sample.metal" -o "${CMAKE_CURRENT_BINARY_DIR}/source/sample.air"
COMMAND xcrun -sdk macosx metal -g "-I${CMAKE_CURRENT_SOURCE_DIR}/source/include" -c "${CMAKE_CURRENT_SOURCE_DIR}/source/scatter.metal" -o "${CMAKE_CURRENT_BINARY_DIR}/source/scatter.air"
COMMAND xcrun -sdk macosx metal -g "-I${CMAKE_CURRENT_SOURCE_DIR}/source/include" -c "${CMAKE_CURRENT_SOURCE_DIR}/source/sdpa.metal" -o "${CMAKE_CURRENT_BINARY_DIR}/source/sdpa.air"
COMMAND xcrun -sdk macosx metal -g "-I${CMAKE_CURRENT_SOURCE_DIR}/source/include" -c "${CMAKE_CURRENT_SOURCE_DIR}/source/topk.metal" -o "${CMAKE_CURRENT_BINARY_DIR}/source/topk.air"
COMMAND xcrun -sdk macosx metallib "${CMAKE_CURRENT_BINARY_DIR}/source/accumulate.air" "${CMAKE_CURRENT_BINARY_DIR}/source/convert.air" "${CMAKE_CURRENT_BINARY_DIR}/source/embeddings.air" "${CMAKE_CURRENT_BINARY_DIR}/source/matmul.air" "${CMAKE_CURRENT_BINARY_DIR}/source/moematmul.air" "${CMAKE_CURRENT_BINARY_DIR}/source/random.air" "${CMAKE_CURRENT_BINARY_DIR}/source/rmsnorm.air" "${CMAKE_CURRENT_BINARY_DIR}/source/rope.air" "${CMAKE_CURRENT_BINARY_DIR}/source/sample.air" "${CMAKE_CURRENT_BINARY_DIR}/source/sdpa.air" "${CMAKE_CURRENT_BINARY_DIR}/source/topk.air" -o "${METAL_LIB}"
COMMAND xcrun -sdk macosx metallib "${CMAKE_CURRENT_BINARY_DIR}/source/accumulate.air" "${CMAKE_CURRENT_BINARY_DIR}/source/convert.air" "${CMAKE_CURRENT_BINARY_DIR}/source/embeddings.air" "${CMAKE_CURRENT_BINARY_DIR}/source/gather_and_accumulate.air" "${CMAKE_CURRENT_BINARY_DIR}/source/matmul.air" "${CMAKE_CURRENT_BINARY_DIR}/source/moematmul.air" "${CMAKE_CURRENT_BINARY_DIR}/source/random.air" "${CMAKE_CURRENT_BINARY_DIR}/source/rmsnorm.air" "${CMAKE_CURRENT_BINARY_DIR}/source/rope.air" "${CMAKE_CURRENT_BINARY_DIR}/source/sample.air" "${CMAKE_CURRENT_BINARY_DIR}/source/scatter.air" "${CMAKE_CURRENT_BINARY_DIR}/source/sdpa.air" "${CMAKE_CURRENT_BINARY_DIR}/source/topk.air" -o "${METAL_LIB}"
DEPENDS ${METAL_SOURCES}
COMMENT "Compiling Metal compute library"
)
Expand Down
5 changes: 5 additions & 0 deletions gpt_oss/metal/benchmark/end-to-end.cc
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,11 @@ static void end2end_prefill(benchmark::State& state,
assert(context_length <= num_tokens);
context->num_tokens = context_length;
}
status = gptoss_context_get_num_tokens(context.get(), &num_tokens);
if (status != gptoss_status_success) {
state.SkipWithError("failed to get number of tokens");
return;
}
// Prefill
for (auto _ : state) {
status = gptoss_context_process(context.get());
Expand Down
74 changes: 74 additions & 0 deletions gpt_oss/metal/source/gather_and_accumulate.metal
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
#include <internal/kernel-args.h>
#include <metal_integer>
#include <metal_math>
#include <metal_stdlib>

// Gather-and-accumulate for exactly four active experts per token.
//
// Each thread owns a single float4 of the output: token row gid.y, starting at float column gid.x * 4.
// For that token it reads the four expert predictions (expert id + score), resolves the input row of
// each one as expert_offsets[expert_id] + intra_expert_offsets[token * k + i], and adds the four
// score-weighted input float4s onto the float4 already stored in `out`.
//
// TODO(ibrahim): This is not optimal as each thread only gathers and accumulates a single float4. To amortize the
// cost of reading the expert, offset and scales for a token, we should let each thread gather and accumulate several
// float4s.
kernel void gptoss_f32_gather_and_accumulate_e4(
    constant gptoss_gather_args& args [[ buffer(0) ]],
    const device float* in [[ buffer(1) ]],
    const device gptoss_expert_prediction* __restrict__ expert_predictions [[ buffer(2) ]],
    const device uint* expert_offsets [[ buffer(3) ]],
    const device uint* intra_expert_offsets [[ buffer(4) ]],
    device float* out [[ buffer(5) ]],
    uint3 gid [[thread_position_in_grid]])
{
    const uint T = args.tokens;
    const uint k = args.active_experts_per_token;
    const uint D = args.token_stride;

    // Rows are accessed as float4 and the body below is unrolled for four experts.
    assert((D & 3u) == 0);
    assert(k == 4);

    // Threads outside the T x D output have nothing to do.
    const uint token = gid.y;
    const uint elem = gid.x * 4u;
    if (token >= T || elem >= D) {
        return;
    }

    // First of this token's four consecutive entries in the per-token arrays.
    const uint slot = token * k;

    const gptoss_expert_prediction pick0 = expert_predictions[slot];
    const gptoss_expert_prediction pick1 = expert_predictions[slot + 1];
    const gptoss_expert_prediction pick2 = expert_predictions[slot + 2];
    const gptoss_expert_prediction pick3 = expert_predictions[slot + 3];

    // The four intra-expert offsets of this token are fetched with one vector load.
    const uint4 within_expert = *reinterpret_cast<const device uint4*>(&intra_expert_offsets[slot]);

    // Input row of each selected expert's result for this token.
    const uint src_row0 = expert_offsets[pick0.expert_id] + within_expert.x;
    const uint src_row1 = expert_offsets[pick1.expert_id] + within_expert.y;
    const uint src_row2 = expert_offsets[pick2.expert_id] + within_expert.z;
    const uint src_row3 = expert_offsets[pick3.expert_id] + within_expert.w;

    const float4 value0 = *reinterpret_cast<const device float4*>(in + src_row0 * D + elem);
    const float4 value1 = *reinterpret_cast<const device float4*>(in + src_row1 * D + elem);
    const float4 value2 = *reinterpret_cast<const device float4*>(in + src_row2 * D + elem);
    const float4 value3 = *reinterpret_cast<const device float4*>(in + src_row3 * D + elem);

    const float weight0 = pick0.score;
    const float weight1 = pick1.score;
    const float weight2 = pick2.score;
    const float weight3 = pick3.score;

    // Accumulate on top of the existing output, experts 0..3 in order.
    device float4* out_vec = reinterpret_cast<device float4*>(out + token * D + elem);
    float4 sum = *out_vec;
    sum = metal::fma(value0, weight0, sum);
    sum = metal::fma(value1, weight1, sum);
    sum = metal::fma(value2, weight2, sum);
    sum = metal::fma(value3, weight3, sum);
    *out_vec = sum;
}
47 changes: 47 additions & 0 deletions gpt_oss/metal/source/include/internal/kernel-args.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,18 @@
#define MLP_GATE_Sg_Bm 16
#define MLP_GATE_Sg_Bn 16

#define MOE_DENSE_MATMUL_SWIGLU_Bm 32
#define MOE_DENSE_MATMUL_SWIGLU_Bn 64
#define MOE_DENSE_MATMUL_SWIGLU_Bk 16
#define MOE_DENSE_MATMUL_SWIGLU_Sg_Bm 32
#define MOE_DENSE_MATMUL_SWIGLU_Sg_Bn 16

#define MOE_DENSE_MATMUL_Bm 32
#define MOE_DENSE_MATMUL_Bn 64
#define MOE_DENSE_MATMUL_Bk 16
#define MOE_DENSE_MATMUL_Sg_Bm 32
#define MOE_DENSE_MATMUL_Sg_Bn 16

struct gptoss_expert_prediction {
uint32_t expert_id;
float score;
Expand Down Expand Up @@ -92,6 +104,41 @@ struct gptoss_dense_matmul_args {
uint32_t k;
};

// Argument block named for the scatter kernel; it has the same three fields as gptoss_gather_args.
// NOTE(review): the consuming kernel is not visible here; the field notes below are inferred from the
// names and from how the identically named gather arguments are used — confirm against scatter.metal.
struct gptoss_scatter_args {
// Presumably the number of tokens being processed.
uint32_t tokens;
// Presumably the number of experts selected for each token.
uint32_t active_experts_per_token;
// Presumably the row length of a token, in elements.
uint32_t token_stride;
};

// Argument block named for the MoE dense matmul + SwiGLU kernel.
// NOTE(review): the consuming kernel is not visible here; every field note below is inferred from the
// field name only — confirm against the kernel before relying on it.
struct gptoss_moe_dense_matmul_swiglu_args {
// Presumably the number of tokens routed to this expert.
uint32_t expert_token_count;
// Presumably the reduction (inner) dimension of the matmul.
uint32_t k;
// Presumably the output dimension of the matmul.
uint32_t n;
// Presumably the index of the expert whose weights are used.
uint32_t expert_id;
// Presumably the first row belonging to this expert in the expert-sorted token buffer.
uint32_t expert_token_offset;
// Presumably per-expert strides, in bytes, of the weight blocks, weight scales and bias buffers.
uint32_t weight_blocks_expert_stride_bytes;
uint32_t weight_scales_expert_stride_bytes;
uint32_t bias_expert_stride_bytes;
// Presumably the lower/upper clamp bounds used by the SwiGLU activation.
float swiglu_min;
float swiglu_max;
};
// Argument block named for the MoE dense matmul kernel. It has the same fields as
// gptoss_moe_dense_matmul_swiglu_args minus the two swiglu_* clamp values.
// NOTE(review): the consuming kernel is not visible here; every field note below is inferred from the
// field name only — confirm against the kernel before relying on it.
struct gptoss_moe_dense_matmul_args {
// Presumably the number of tokens routed to this expert.
uint32_t expert_token_count;
// Presumably the reduction (inner) dimension of the matmul.
uint32_t k;
// Presumably the output dimension of the matmul.
uint32_t n;
// Presumably the index of the expert whose weights are used.
uint32_t expert_id;
// Presumably the first row belonging to this expert in the expert-sorted token buffer.
uint32_t expert_token_offset;
// Presumably per-expert strides, in bytes, of the weight blocks, weight scales and bias buffers.
uint32_t weight_blocks_expert_stride_bytes;
uint32_t weight_scales_expert_stride_bytes;
uint32_t bias_expert_stride_bytes;
};

// Arguments read by gptoss_f32_gather_and_accumulate_e4 (bound at buffer index 0).
struct gptoss_gather_args {
// Number of output rows; threads whose gid.y is >= tokens return without doing any work.
uint32_t tokens;
// Multiplier used to locate a token's entries in the expert-prediction and intra-expert-offset
// arrays (index = row * active_experts_per_token). The e4 kernel asserts that this equals 4.
uint32_t active_experts_per_token;
// Row length, in floats, used to address both the input and the output buffer. The e4 kernel
// asserts that it is a multiple of 4 because rows are accessed as float4.
uint32_t token_stride;
};

struct gptoss_unembedding_args {
uint32_t num_column_vecs;
uint32_t num_rows_per_threadgroup;
Expand Down
Loading