319 changes: 244 additions & 75 deletions gpt_oss/metal/source/context.c
@@ -72,6 +72,18 @@ enum gptoss_status GPTOSS_ABI gptoss_context_create(
if (status != gptoss_status_success) {
goto cleanup;
}
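// Host-visible buffers for the dense (prefill-optimized) MoE path below:
// per-expert token offsets, the intra-expert slot of each (token, expert)
// pair, and the activations scattered so each expert's tokens are contiguous.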
status = gptoss_metal_buffer_create(&model->device, model->num_experts * sizeof(uint32_t), NULL, &context->expert_offset_buffer);
if (status != gptoss_status_success) {
goto cleanup;
}
status = gptoss_metal_buffer_create(&model->device, model->max_batch_tokens * model->num_active_experts * sizeof(uint32_t), NULL, &context->token_to_expert_routing_buffer);
if (status != gptoss_status_success) {
goto cleanup;
}
status = gptoss_metal_buffer_create(&model->device, model->max_batch_tokens * model->num_active_experts * model->embedding_dim * sizeof(float), NULL, &context->swiglu_input_buffer);
if (status != gptoss_status_success) {
goto cleanup;
}
status = gptoss_metal_buffer_create(&model->device, model->max_batch_tokens * model->num_active_experts * model->mlp_dim * sizeof(float), NULL, &context->swiglu_activation_buffer);
if (status != gptoss_status_success) {
goto cleanup;
@@ -115,7 +127,9 @@ enum gptoss_status GPTOSS_ABI gptoss_context_create(
context->allocation_size =
context->residual_activation_buffer.size + context->rmsnorm_activation_buffer.size +
context->qkv_activation_buffer.size + context->sdpa_activation_buffer.size +
context->gate_activation_buffer.size + context->expert_activation_buffer.size + context->swiglu_activation_buffer.size + context->moe_activation_buffer.size +
context->gate_activation_buffer.size + context->expert_activation_buffer.size +
context->expert_offset_buffer.size + context->token_to_expert_routing_buffer.size + context->swiglu_input_buffer.size +
context->swiglu_activation_buffer.size + context->moe_activation_buffer.size +
context->token_buffer.size + context->kvcache_buffer.size + context->score_buffer.size + context->argmax_buffer.size;

context->model = model;
@@ -176,6 +190,7 @@ static enum gptoss_status process_tokens(
assert(num_output_tokens <= context->max_batch_tokens);
assert(num_input_tokens >= num_output_tokens);
const size_t dense_matmul_kernel_token_multiple_constraint = 64;
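// Heuristic: with at least this many tokens in the batch (i.e. prefill),
// grouping tokens by expert and running dense per-expert matmuls is preferred
// over the token-parallel MoE kernels used for decode.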
const size_t min_tokens_for_dense_moe_kernels = 64;

enum gptoss_status status = gptoss_status_success;
const struct gptoss_model* model = context->model;
@@ -496,82 +511,233 @@ static enum gptoss_status process_tokens(
return status;
}

// If the batch has enough tokens (e.g. during prefill), pick the prefill-optimized dense MoE kernels.
if (num_block_output_tokens >= min_tokens_for_dense_moe_kernels) {
// Commit and wait for the command buffer to complete: the top-k routing
// output must be visible on the CPU before we can build the routing metadata.
status = gptoss_metal_command_buffer_commit(command_buffer);
if (status != gptoss_status_success) {
return status;
}

status = gptoss_metal_command_buffer_wait_completion(command_buffer, NULL);
if (status != gptoss_status_success) {
return status;
}
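// E: total number of experts; T: number of (token, active-expert) routing pairs.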
const size_t E = model->num_experts;
const size_t T = num_block_output_tokens * model->num_active_experts;

const struct gptoss_expert_prediction* preds =
(const struct gptoss_expert_prediction*) context->expert_activation_buffer.ptr;
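// Reading the predictions on the CPU assumes the buffer is host-visible
// (unified memory) and relies on the wait above to ensure the GPU writes are done.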

uint32_t* token_to_expert_routing = (uint32_t*) context->token_to_expert_routing_buffer.ptr;
uint32_t* expert_offset = (uint32_t*) context->expert_offset_buffer.ptr;
// Zero the expert offsets; during the first pass they serve as per-expert bin counters.
memset(expert_offset, 0, E * sizeof(uint32_t));

for (size_t i = 0; i < T; i++) {
const uint32_t expert_id = preds[i].expert_id;
token_to_expert_routing[i] = expert_offset[expert_id];
expert_offset[expert_id]++;
}

uint32_t total = 0;
// Exclusive prefix sum: turn per-expert bin sizes into start offsets.
for (size_t i = 0; i < E; i++) {
const uint32_t bin_size = expert_offset[i];
expert_offset[i] = total;
total += bin_size;
}
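
// The two passes form a counting sort. Hypothetical example with E = 3 and
// expert ids [1, 0, 1, 2]: bin sizes are [1, 2, 1], intra-expert slots are
// [0, 0, 1, 0], and after the prefix sum expert_offset = [0, 1, 3], so pair i
// lands at expert_offset[expert_id] + token_to_expert_routing[i].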

// Create a new command buffer; the previous one was committed and waited on above.
status = gptoss_metal_command_buffer_create(&context->model->command_queue, command_buffer);
if (status != gptoss_status_success) {
return status;
}
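// Scatter the RMSNorm activations so that tokens routed to the same expert
// are contiguous in swiglu_input_buffer.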
status = gptoss_metal_command_buffer_encode_launch_f32_scatter(
command_buffer,
&model->f32_scatter_e4_fn,
&context->rmsnorm_activation_buffer,
/*input_offset=*/0,
&context->expert_activation_buffer,
/*expert_predictions_offset=*/0,
&context->expert_offset_buffer,
/*expert_offsets_offset=*/0,
&context->token_to_expert_routing_buffer,
/*intra_expert_offsets_offset=*/0,
&context->swiglu_input_buffer,
/*output_offset=*/0,
model->embedding_dim,
num_block_output_tokens,
model->num_active_experts);
if (status != gptoss_status_success) {
GPTOSS_LOG_ERROR("failed to encode f32_scatter kernel launch");
return status;
}
// Dense MoE SwiGLU matmul -- iterate over all experts.
const size_t total_tokens = num_block_output_tokens * model->num_active_experts;
for (size_t e = 0; e < model->num_experts; e++) {
bool last_expert = e == model->num_experts - 1;
uint32_t expert_tokens = last_expert ? total_tokens - expert_offset[e] : expert_offset[e + 1] - expert_offset[e];
if (expert_tokens == 0) {
continue;
}
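// The SwiGLU weight has 2 * mlp_dim output columns (gate and linear halves);
// the kernel applies the fused SwiGLU, writing mlp_dim values per routed token.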
status = gptoss_metal_command_buffer_encode_launch_f32_mf4w_moe_dense_matmul_swiglu(
command_buffer,
&model->f32_mf4w_moe_dense_matmul_swiglu_fn,
&context->swiglu_input_buffer,
/*input_offset=*/0,
&model->block_weight_buffers[n],
/*weight_block_offset=*/0,
&model->block_weight_buffers[n],
/*weight_scale_offset=*/model->mlp_swiglu_scale_offset,
&model->block_weight_buffers[n],
/*bias_offset=*/model->mlp_swiglu_bias_offset,
&context->swiglu_activation_buffer,
/*output_offset=*/0,
model->swiglu_limit,
/*expert_stride_bytes=*/model->per_expert_block_weight_size,
expert_tokens,
expert_offset[e],
e,
model->embedding_dim,
2 * model->mlp_dim);
if (status != gptoss_status_success) {
GPTOSS_LOG_ERROR("failed to encode f32_mf4w_moe_dense_matmul_swiglu kernel launch");
return status;
}
}
// Dense MoE proj matmul -- again iterate over all experts.
for (size_t e = 0; e < model->num_experts; e++) {
bool last_expert = e == model->num_experts - 1;
uint32_t expert_tokens = last_expert ? total_tokens - expert_offset[e] : expert_offset[e + 1] - expert_offset[e];
if (expert_tokens == 0) {
continue;
}
status = gptoss_metal_command_buffer_encode_launch_f32_mf4w_moe_dense_matmul(
command_buffer,
&model->f32_mf4w_moe_dense_matmul_fn,
&context->swiglu_activation_buffer,
/*input_offset=*/0,
&model->block_weight_buffers[n],
/*weight_block_offset=*/model->mlp_out_block_offset,
&model->block_weight_buffers[n],
/*weight_scale_offset=*/model->mlp_out_scale_offset,
&model->block_weight_buffers[n],
/*bias_offset=*/model->mlp_out_bias_offset,
&context->moe_activation_buffer,
/*output_offset=*/0,
/*expert_stride_bytes=*/model->per_expert_block_weight_size,
expert_tokens,
expert_offset[e],
e,
model->mlp_dim,
model->embedding_dim);
if (status != gptoss_status_success) {
GPTOSS_LOG_ERROR("failed to encode f32_mf4w_moe_dense_matmul_swiglu kernel launch");
return status;
}
}

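// Undo the earlier scatter: gather each token's per-expert outputs and
// accumulate them into the residual stream.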
status = gptoss_metal_command_buffer_encode_launch_f32_gather_and_accumulate_e4(
command_buffer,
&model->f32_gather_and_accumulate_e4_fn,
&context->moe_activation_buffer,
/*input_offset=*/0,
&context->expert_activation_buffer,
/*expert_predictions_offset=*/0,
&context->expert_offset_buffer,
/*expert_offsets_offset=*/0,
&context->token_to_expert_routing_buffer,
/*intra_expert_offsets_offset=*/0,
&context->residual_activation_buffer,
/*output_offset=*/model->embedding_dim * (input_batch_size - num_block_output_tokens) * sizeof(float),
model->embedding_dim,
num_block_output_tokens,
model->num_active_experts);
if (status != gptoss_status_success) {
GPTOSS_LOG_ERROR("failed to encode f32_gather_and_accumulate_e4 kernel launch");
return status;
}

} else {
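// Decode / small-batch path: use the original token-parallel MoE kernels,
// which avoid the CPU round trip for routing metadata.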
status = gptoss_metal_command_buffer_encode_launch_f32_mf4w_moe_matmul_swiglu(
command_buffer,
&model->f32_mf4w_moe_matmul_swiglu_fn,
model->mlp_swiglu_threadgroup_size,
&context->rmsnorm_activation_buffer,
/*input_offset=*/0,
&context->expert_activation_buffer,
/*expert_offset=*/0,
&model->block_weight_buffers[n],
/*weight_block_offset=*/0,
&model->block_weight_buffers[n],
/*weight_scale_offset=*/model->mlp_swiglu_scale_offset,
&model->block_weight_buffers[n],
/*bias_offset=*/model->mlp_swiglu_bias_offset,
&context->swiglu_activation_buffer,
/*output_offset=*/0,
&context->control_buffer,
/*control_offset=*/0,
model->swiglu_limit,
model->per_expert_block_weight_size,
num_block_output_tokens,
model->num_active_experts,
model->embedding_dim,
model->mlp_dim);
if (status != gptoss_status_success) {
GPTOSS_LOG_ERROR("failed to encode f32_mf4w_moe_matmul_swiglu kernel launch");
return status;
}

status = gptoss_metal_command_buffer_encode_launch_f32_mf4w_moe_matmul(
command_buffer,
&model->f32_mf4w_moe_matmul_fn,
model->mlp_out_threadgroup_size,
&context->swiglu_activation_buffer,
/*input_offset=*/0,
&context->expert_activation_buffer,
/*expert_offset=*/0,
&model->block_weight_buffers[n],
/*weight_block_offset=*/model->mlp_out_block_offset,
&model->block_weight_buffers[n],
/*weight_scale_offset=*/model->mlp_out_scale_offset,
&model->block_weight_buffers[n],
/*bias_offset=*/model->mlp_out_bias_offset,
&context->moe_activation_buffer,
/*output_offset=*/0,
&context->control_buffer,
/*control_offset=*/0,
model->per_expert_block_weight_size,
num_block_output_tokens,
model->num_active_experts,
model->mlp_dim,
model->embedding_dim);
if (status != gptoss_status_success) {
GPTOSS_LOG_ERROR("failed to encode f32_mf4w_moe_matmul kernel launch");
return status;
}

status = gptoss_metal_command_buffer_encode_launch_f32_accumulate(
command_buffer,
&model->f32_accumulate_e4_fn,
model->mlp_acc_threadgroup_size,
model->max_threadgroups,
&context->moe_activation_buffer,
/*input_offset=*/0,
&context->expert_activation_buffer,
/*expert_offset=*/0,
&context->residual_activation_buffer,
/*output_offset=*/model->embedding_dim * (input_batch_size - num_block_output_tokens) * sizeof(float),
&context->control_buffer,
/*control_offset=*/0,
model->embedding_dim,
num_block_output_tokens,
model->num_active_experts);
if (status != gptoss_status_success) {
GPTOSS_LOG_ERROR("failed to encode f32_accumulate kernel launch");
return status;
}
}
}
}
@@ -946,6 +1112,9 @@ enum gptoss_status GPTOSS_ABI gptoss_context_release(
gptoss_metal_buffer_release(&context->expert_activation_buffer);
gptoss_metal_buffer_release(&context->swiglu_activation_buffer);
gptoss_metal_buffer_release(&context->moe_activation_buffer);
gptoss_metal_buffer_release(&context->expert_offset_buffer);
gptoss_metal_buffer_release(&context->token_to_expert_routing_buffer);
gptoss_metal_buffer_release(&context->swiglu_input_buffer);

// Input/output buffers
gptoss_metal_buffer_release(&context->control_buffer);