diff --git a/gpt_oss/metal/source/context.c b/gpt_oss/metal/source/context.c
index 5cdaee7f..b507619f 100644
--- a/gpt_oss/metal/source/context.c
+++ b/gpt_oss/metal/source/context.c
@@ -72,6 +72,18 @@ enum gptoss_status GPTOSS_ABI gptoss_context_create(
         if (status != gptoss_status_success) {
             goto cleanup;
         }
+        status = gptoss_metal_buffer_create(&model->device, model->num_experts * sizeof(uint32_t), NULL, &context->expert_offset_buffer);
+        if (status != gptoss_status_success) {
+            goto cleanup;
+        }
+        status = gptoss_metal_buffer_create(&model->device, model->max_batch_tokens * model->num_active_experts * sizeof(uint32_t), NULL, &context->token_to_expert_routing_buffer);
+        if (status != gptoss_status_success) {
+            goto cleanup;
+        }
+        status = gptoss_metal_buffer_create(&model->device, model->max_batch_tokens * model->num_active_experts * model->embedding_dim * sizeof(float), NULL, &context->swiglu_input_buffer);
+        if (status != gptoss_status_success) {
+            goto cleanup;
+        }
         status = gptoss_metal_buffer_create(&model->device, model->max_batch_tokens * model->num_active_experts * model->mlp_dim * sizeof(float), NULL, &context->swiglu_activation_buffer);
         if (status != gptoss_status_success) {
             goto cleanup;
@@ -115,7 +127,9 @@ enum gptoss_status GPTOSS_ABI gptoss_context_create(
         context->allocation_size =
             context->residual_activation_buffer.size + context->rmsnorm_activation_buffer.size +
             context->qkv_activation_buffer.size + context->sdpa_activation_buffer.size +
-            context->gate_activation_buffer.size + context->expert_activation_buffer.size + context->swiglu_activation_buffer.size + context->moe_activation_buffer.size +
+            context->gate_activation_buffer.size + context->expert_activation_buffer.size +
+            context->expert_offset_buffer.size + context->token_to_expert_routing_buffer.size + context->swiglu_input_buffer.size +
+            context->swiglu_activation_buffer.size + context->moe_activation_buffer.size +
             context->token_buffer.size + context->kvcache_buffer.size + context->score_buffer.size + context->argmax_buffer.size;

         context->model = model;
@@ -176,6 +190,7 @@ static enum gptoss_status process_tokens(
     assert(num_output_tokens <= context->max_batch_tokens);
     assert(num_input_tokens >= num_output_tokens);
     const size_t dense_matmul_kernel_token_multiple_constraint = 64;
+    const size_t min_tokens_for_dense_moe_kernels = 64;

     enum gptoss_status status = gptoss_status_success;
     const struct gptoss_model* model = context->model;
@@ -496,82 +511,233 @@ static enum gptoss_status process_tokens(
                 return status;
             }
-            status = gptoss_metal_command_buffer_encode_launch_f32_mf4w_moe_matmul_swiglu(
-                command_buffer,
-                &model->f32_mf4w_moe_matmul_swiglu_fn,
-                model->mlp_swiglu_threadgroup_size,
-                &context->rmsnorm_activation_buffer,
-                /*input_offset=*/0,
-                &context->expert_activation_buffer,
-                /*expert_offset=*/0,
-                &model->block_weight_buffers[n],
-                /*weight_block_offset=*/0,
-                &model->block_weight_buffers[n],
-                /*weight_scale_offset=*/model->mlp_swiglu_scale_offset,
-                &model->block_weight_buffers[n],
-                /*bias_offset=*/model->mlp_swiglu_bias_offset,
-                &context->swiglu_activation_buffer,
-                /*output_offset=*/0,
-                &context->control_buffer,
-                /*control_offset=*/0,
-                model->swiglu_limit,
-                model->per_expert_block_weight_size,
-                num_block_output_tokens,
-                model->num_active_experts,
-                model->embedding_dim,
-                model->mlp_dim);
-            if (status != gptoss_status_success) {
-                GPTOSS_LOG_ERROR("failed to encode f32_mf4w_moe_matmul_swiglu kernel launch");
-                return status;
-            }
+            // If we have enough tokens in prefill, we will pick the prefill-optimized kernels.
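+            // Prefill path overview: read the top-k expert predictions back on the CPU and build
+            // routing metadata, scatter token activations so that each expert sees a contiguous
+            // block of rows, run one dense GEMM per expert for the SwiGLU and output projections,
+            // then gather the per-expert outputs and accumulate them into the residual stream.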
+            if (num_block_output_tokens >= min_tokens_for_dense_moe_kernels) {
+                // Commit and wait for the command buffer to complete, as the top-k output is
+                // needed on the CPU to compute the routing metadata.
+                status = gptoss_metal_command_buffer_commit(command_buffer);
+                if (status != gptoss_status_success) {
+                    return status;
+                }
+
+                status = gptoss_metal_command_buffer_wait_completion(command_buffer, NULL);
+                if (status != gptoss_status_success) {
+                    return status;
+                }
+                const size_t E = model->num_experts;
+                const size_t T = num_block_output_tokens * model->num_active_experts;
-            status = gptoss_metal_command_buffer_encode_launch_f32_mf4w_moe_matmul(
-                command_buffer,
-                &model->f32_mf4w_moe_matmul_fn,
-                model->mlp_out_threadgroup_size,
-                &context->swiglu_activation_buffer,
-                /*input_offset=*/0,
-                &context->expert_activation_buffer,
-                /*expert_offset=*/0,
-                &model->block_weight_buffers[n],
-                /*weight_block_offset=*/model->mlp_out_block_offset,
-                &model->block_weight_buffers[n],
-                /*weight_scale_offset=*/model->mlp_out_scale_offset,
-                &model->block_weight_buffers[n],
-                /*bias_offset=*/model->mlp_out_bias_offset,
-                &context->moe_activation_buffer,
-                /*output_offset=*/0,
-                &context->control_buffer,
-                /*control_offset=*/0,
-                model->per_expert_block_weight_size,
-                num_block_output_tokens,
-                model->num_active_experts,
-                model->mlp_dim,
-                model->embedding_dim);
-            if (status != gptoss_status_success) {
-                GPTOSS_LOG_ERROR("failed to encode f32_mf4w_moe_matmul kernel launch");
-                return status;
-            }
+                const struct gptoss_expert_prediction* preds =
+                    (const struct gptoss_expert_prediction*) context->expert_activation_buffer.ptr;
-            status = gptoss_metal_command_buffer_encode_launch_f32_accumulate(
-                command_buffer,
-                &model->f32_accumulate_e4_fn,
-                model->mlp_acc_threadgroup_size,
-                model->max_threadgroups,
-                &context->moe_activation_buffer,
-                /*input_offset=*/0,
-                &context->expert_activation_buffer,
-                /*expert_offset=*/0,
-                &context->residual_activation_buffer,
-                /*output_offset=*/model->embedding_dim * (input_batch_size - num_block_output_tokens) * sizeof(float),
-                &context->control_buffer,
-                /*control_offset=*/0,
-                model->embedding_dim,
-                num_block_output_tokens,
-                model->num_active_experts);
-            if (status != gptoss_status_success) {
-                GPTOSS_LOG_ERROR("failed to encode f32_accumulate kernel launch");
-                return status;
+                uint32_t* token_to_expert_routing = (uint32_t*) context->token_to_expert_routing_buffer.ptr;
+                uint32_t* expert_offset = (uint32_t*) context->expert_offset_buffer.ptr;
+                // Zero out the expert offset buffer.
+                memset(expert_offset, 0, E * sizeof(uint32_t));
+
+                for (size_t i = 0; i < T; i++) {
+                    const uint32_t expert_id = preds[i].expert_id;
+                    token_to_expert_routing[i] = expert_offset[expert_id];
+                    expert_offset[expert_id]++;
+                }
+
+                uint32_t total = 0;
+                // Exclusive prefix sum over the per-expert bin sizes.
+                for (size_t i = 0; i < model->num_experts; i++) {
+                    const uint32_t bin_size = expert_offset[i];
+                    expert_offset[i] = total;
+                    total += bin_size;
+                }
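+                // Illustrative example (hypothetical numbers): with 4 experts and per-expert bin
+                // sizes {3, 0, 5, 2}, expert_offset becomes {0, 3, 3, 8}. A (token, expert) pair i
+                // routed to expert e then lands at row expert_offset[e] + token_to_expert_routing[i]
+                // of the scattered buffer, which is the layout the scatter kernel below is expected
+                // to produce.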
+
+                // Create a new command buffer.
+                status = gptoss_metal_command_buffer_create(&context->model->command_queue, command_buffer);
+                if (status != gptoss_status_success) {
+                    return status;
+                }
+                status = gptoss_metal_command_buffer_encode_launch_f32_scatter(
+                    command_buffer,
+                    &model->f32_scatter_e4_fn,
+                    &context->rmsnorm_activation_buffer,
+                    /*input_offset=*/0,
+                    &context->expert_activation_buffer,
+                    /*expert_predictions_offset=*/0,
+                    &context->expert_offset_buffer,
+                    /*expert_offsets_offset=*/0,
+                    &context->token_to_expert_routing_buffer,
+                    /*intra_expert_offsets_offset=*/0,
+                    &context->swiglu_input_buffer,
+                    /*output_offset=*/0,
+                    model->embedding_dim,
+                    num_block_output_tokens,
+                    model->num_active_experts);
+                if (status != gptoss_status_success) {
+                    GPTOSS_LOG_ERROR("failed to encode f32_scatter kernel launch");
+                    return status;
+                }
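+                // Each expert e now owns a contiguous slab of rows starting at expert_offset[e] in
+                // the scattered buffer, so the MLP can run as one dense GEMM per expert. Expert e's
+                // quantized weights are addressed through its expert_id and the per-expert stride,
+                // and the SwiGLU projection computes 2 * mlp_dim pre-activations per row, which the
+                // fused SwiGLU epilogue is expected to reduce to mlp_dim activations (matching the
+                // per-row size of swiglu_activation_buffer).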
+                // Dense MoE SwiGLU matmul -- iterate over all experts.
+                const size_t total_tokens = num_block_output_tokens * model->num_active_experts;
+                for (size_t e = 0; e < model->num_experts; e++) {
+                    bool last_expert = e == model->num_experts - 1;
+                    uint32_t expert_tokens = last_expert ? total_tokens - expert_offset[e] : expert_offset[e + 1] - expert_offset[e];
+                    if (expert_tokens == 0) {
+                        continue;
+                    }
+                    status = gptoss_metal_command_buffer_encode_launch_f32_mf4w_moe_dense_matmul_swiglu(
+                        command_buffer,
+                        &model->f32_mf4w_moe_dense_matmul_swiglu_fn,
+                        &context->swiglu_input_buffer,
+                        /*input_offset=*/0,
+                        &model->block_weight_buffers[n],
+                        /*weight_block_offset=*/0,
+                        &model->block_weight_buffers[n],
+                        /*weight_scale_offset=*/model->mlp_swiglu_scale_offset,
+                        &model->block_weight_buffers[n],
+                        /*bias_offset=*/model->mlp_swiglu_bias_offset,
+                        &context->swiglu_activation_buffer,
+                        /*output_offset=*/0,
+                        model->swiglu_limit,
+                        /*expert_stride_bytes=*/model->per_expert_block_weight_size,
+                        expert_tokens,
+                        expert_offset[e],
+                        e,
+                        model->embedding_dim,
+                        2 * model->mlp_dim);
+                    if (status != gptoss_status_success) {
+                        GPTOSS_LOG_ERROR("failed to encode f32_mf4w_moe_dense_matmul_swiglu kernel launch");
+                        return status;
+                    }
+                }
+                // Dense MoE proj matmul -- again iterate over all experts.
+                for (size_t e = 0; e < model->num_experts; e++) {
+                    bool last_expert = e == model->num_experts - 1;
+                    uint32_t expert_tokens = last_expert ? total_tokens - expert_offset[e] : expert_offset[e + 1] - expert_offset[e];
+                    if (expert_tokens == 0) {
+                        continue;
+                    }
+                    status = gptoss_metal_command_buffer_encode_launch_f32_mf4w_moe_dense_matmul(
+                        command_buffer,
+                        &model->f32_mf4w_moe_dense_matmul_fn,
+                        &context->swiglu_activation_buffer,
+                        /*input_offset=*/0,
+                        &model->block_weight_buffers[n],
+                        /*weight_block_offset=*/model->mlp_out_block_offset,
+                        &model->block_weight_buffers[n],
+                        /*weight_scale_offset=*/model->mlp_out_scale_offset,
+                        &model->block_weight_buffers[n],
+                        /*bias_offset=*/model->mlp_out_bias_offset,
+                        &context->moe_activation_buffer,
+                        /*output_offset=*/0,
+                        /*expert_stride_bytes=*/model->per_expert_block_weight_size,
+                        expert_tokens,
+                        expert_offset[e],
+                        e,
+                        model->mlp_dim,
+                        model->embedding_dim);
+                    if (status != gptoss_status_success) {
+                        GPTOSS_LOG_ERROR("failed to encode f32_mf4w_moe_dense_matmul kernel launch");
+                        return status;
+                    }
+                }
+
+                status = gptoss_metal_command_buffer_encode_launch_f32_gather_and_accumulate_e4(
+                    command_buffer,
+                    &model->f32_gather_and_accumulate_e4_fn,
+                    &context->moe_activation_buffer,
+                    /*input_offset=*/0,
+                    &context->expert_activation_buffer,
+                    /*expert_predictions_offset=*/0,
+                    &context->expert_offset_buffer,
+                    /*expert_offsets_offset=*/0,
+                    &context->token_to_expert_routing_buffer,
+                    /*intra_expert_offsets_offset=*/0,
+                    &context->residual_activation_buffer,
+                    /*output_offset=*/model->embedding_dim * (input_batch_size - num_block_output_tokens) * sizeof(float),
+                    model->embedding_dim,
+                    num_block_output_tokens,
+                    model->num_active_experts);
+                if (status != gptoss_status_success) {
+                    GPTOSS_LOG_ERROR("failed to encode f32_gather_and_accumulate_e4 kernel launch");
+                    return status;
+                }
+
+            } else {
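+                // Fewer than min_tokens_for_dense_moe_kernels output tokens: keep using the
+                // existing per-token MoE kernels.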
GPTOSS_LOG_ERROR("failed to encode f32_mf4w_moe_matmul kernel launch"); + return status; + } + + status = gptoss_metal_command_buffer_encode_launch_f32_accumulate( + command_buffer, + &model->f32_accumulate_e4_fn, + model->mlp_acc_threadgroup_size, + model->max_threadgroups, + &context->moe_activation_buffer, + /*input_offset=*/0, + &context->expert_activation_buffer, + /*expert_offset=*/0, + &context->residual_activation_buffer, + /*output_offset=*/model->embedding_dim * (input_batch_size - num_block_output_tokens) * sizeof(float), + &context->control_buffer, + /*control_offset=*/0, + model->embedding_dim, + num_block_output_tokens, + model->num_active_experts); + if (status != gptoss_status_success) { + GPTOSS_LOG_ERROR("failed to encode f32_accumulate kernel launch"); + return status; + } } } } @@ -946,6 +1112,9 @@ enum gptoss_status GPTOSS_ABI gptoss_context_release( gptoss_metal_buffer_release(&context->expert_activation_buffer); gptoss_metal_buffer_release(&context->swiglu_activation_buffer); gptoss_metal_buffer_release(&context->moe_activation_buffer); + gptoss_metal_buffer_release(&context->expert_offset_buffer); + gptoss_metal_buffer_release(&context->token_to_expert_routing_buffer); + gptoss_metal_buffer_release(&context->swiglu_input_buffer); // Input/output buffers gptoss_metal_buffer_release(&context->control_buffer); diff --git a/gpt_oss/metal/source/include/internal/metal-kernels.h b/gpt_oss/metal/source/include/internal/metal-kernels.h index c12a834d..3addb3cc 100644 --- a/gpt_oss/metal/source/include/internal/metal-kernels.h +++ b/gpt_oss/metal/source/include/internal/metal-kernels.h @@ -317,6 +317,82 @@ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_accumulate( uint32_t num_tokens, uint32_t num_experts); +enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_scatter( + const struct gptoss_metal_command_buffer* command_buffer, + const struct gptoss_metal_function* f32_scatter_fn, + const struct gptoss_metal_buffer* input_buffer, + size_t input_offset, + const struct gptoss_metal_buffer* expert_predictions_buffer, + size_t expert_predictions_offset, + const struct gptoss_metal_buffer* expert_offsets_buffer, + size_t expert_offsets_offset, + const struct gptoss_metal_buffer* intra_expert_offsets_buffer, + size_t intra_expert_offsets_offset, + const struct gptoss_metal_buffer* output_buffer, + size_t output_offset, + uint32_t num_channels, + uint32_t num_tokens, + uint32_t num_active_experts); + +enum gptoss_status +gptoss_metal_command_buffer_encode_launch_f32_mf4w_moe_dense_matmul_swiglu( + const struct gptoss_metal_command_buffer* command_buffer, + const struct gptoss_metal_function* f32_mf4w_moe_dense_matmul_swiglu_fn, + const struct gptoss_metal_buffer* input_buffer, + size_t input_offset, + const struct gptoss_metal_buffer* weight_block_buffer, + size_t weight_block_offset, + const struct gptoss_metal_buffer* weight_scale_buffer, + size_t weight_scale_offset, + const struct gptoss_metal_buffer* bias_buffer, + size_t bias_offset, + const struct gptoss_metal_buffer* output_buffer, + size_t output_offset, + float swiglu_limit, + uint32_t expert_stride_bytes, + uint32_t num_tokens, + uint32_t expert_token_offset, + uint32_t expert_id, + uint32_t num_cols, + uint32_t num_rows); + +enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_mf4w_moe_dense_matmul( + const struct gptoss_metal_command_buffer* command_buffer, + const struct gptoss_metal_function* f32_mf4w_moe_dense_matmul_fn, + const struct gptoss_metal_buffer* input_buffer, + 
+    size_t input_offset,
+    const struct gptoss_metal_buffer* weight_block_buffer,
+    size_t weight_block_offset,
+    const struct gptoss_metal_buffer* weight_scale_buffer,
+    size_t weight_scale_offset,
+    const struct gptoss_metal_buffer* bias_buffer,
+    size_t bias_offset,
+    const struct gptoss_metal_buffer* output_buffer,
+    size_t output_offset,
+    uint32_t expert_stride_bytes,
+    uint32_t num_tokens,
+    uint32_t expert_token_offset,
+    uint32_t expert_id,
+    uint32_t num_cols,
+    uint32_t num_rows);
+
+enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_gather_and_accumulate_e4(
+    const struct gptoss_metal_command_buffer* command_buffer,
+    const struct gptoss_metal_function* f32_gather_and_accumulate_e4_fn,
+    const struct gptoss_metal_buffer* input_buffer,
+    size_t input_offset,
+    const struct gptoss_metal_buffer* expert_predictions_buffer,
+    size_t expert_predictions_offset,
+    const struct gptoss_metal_buffer* expert_offsets_buffer,
+    size_t expert_offsets_offset,
+    const struct gptoss_metal_buffer* intra_expert_offsets_buffer,
+    size_t intra_expert_offsets_offset,
+    const struct gptoss_metal_buffer* output_buffer,
+    size_t output_offset,
+    uint32_t num_channels,
+    uint32_t num_tokens,
+    uint32_t num_active_experts);
+
 enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_topk(
     const struct gptoss_metal_command_buffer* command_buffer,
     const struct gptoss_metal_function* f32_topk_fn,
diff --git a/gpt_oss/metal/source/include/internal/model.h b/gpt_oss/metal/source/include/internal/model.h
index c63578a7..574b07ff 100644
--- a/gpt_oss/metal/source/include/internal/model.h
+++ b/gpt_oss/metal/source/include/internal/model.h
@@ -87,6 +87,10 @@ struct gptoss_model {
     struct gptoss_metal_function f32_mf4w_moe_matmul_swiglu_fn;
     struct gptoss_metal_function f32_mf4w_moe_matmul_fn;
     struct gptoss_metal_function f32_accumulate_e4_fn;
+    struct gptoss_metal_function f32_scatter_e4_fn;
+    struct gptoss_metal_function f32_mf4w_moe_dense_matmul_swiglu_fn;
+    struct gptoss_metal_function f32_mf4w_moe_dense_matmul_fn;
+    struct gptoss_metal_function f32_gather_and_accumulate_e4_fn;
     struct gptoss_metal_function f32_topk_softmax_e32_k4_fn;
     struct gptoss_metal_function f32_topk_softmax_e128_k4_fn;
     struct gptoss_metal_function f32_sdpa_q8_d64_fn;
@@ -156,6 +160,9 @@ struct gptoss_context {
     struct gptoss_metal_buffer sdpa_activation_buffer; // SDPA output
     struct gptoss_metal_buffer gate_activation_buffer; // MoE gating output
     struct gptoss_metal_buffer expert_activation_buffer; // MoE expert predictions
+    struct gptoss_metal_buffer expert_offset_buffer; // MoE expert histogram cumulative sums
+    struct gptoss_metal_buffer token_to_expert_routing_buffer; // MoE token-to-expert routing
+    struct gptoss_metal_buffer swiglu_input_buffer; // MLP+SwiGLU input for prefill
     struct gptoss_metal_buffer swiglu_activation_buffer; // MLP+SwiGLU output
     struct gptoss_metal_buffer moe_activation_buffer; // MoE MLP output (per-active expert)
diff --git a/gpt_oss/metal/source/metal-kernels.c b/gpt_oss/metal/source/metal-kernels.c
index 3aaeb32f..c6d7fbc9 100644
--- a/gpt_oss/metal/source/metal-kernels.c
+++ b/gpt_oss/metal/source/metal-kernels.c
@@ -936,6 +936,305 @@ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_rope(
         /*threadgroup_buffer_size=*/0);
 }

+enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_scatter(
+    const struct gptoss_metal_command_buffer* command_buffer,
+    const struct gptoss_metal_function* f32_scatter_fn,
+    const struct gptoss_metal_buffer* input_buffer,
+    size_t input_offset,
+    const struct gptoss_metal_buffer* expert_predictions_buffer,
+    size_t expert_predictions_offset,
+    const struct gptoss_metal_buffer* expert_offsets_buffer,
+    size_t expert_offsets_offset,
+    const struct gptoss_metal_buffer* intra_expert_offsets_buffer,
+    size_t intra_expert_offsets_offset,
+    const struct gptoss_metal_buffer* output_buffer,
+    size_t output_offset,
+    uint32_t num_channels,
+    uint32_t num_tokens,
+    uint32_t num_active_experts)
+{
+    if (command_buffer->object == NULL || f32_scatter_fn->pipeline_state_object == NULL) {
+        return gptoss_status_invalid_state;
+    }
+
+    if (num_channels % 4 != 0) {
+        return gptoss_status_invalid_argument;
+    }
+
+    const size_t num_vecs = num_channels / 4;
+    const size_t tgx = math_min(num_vecs, 64);
+    const size_t tgy = 1;
+    const size_t tgz = 1;
+    const size_t grid_x = math_ceil_div(num_vecs, tgx);
+    const size_t grid_y = num_tokens;
+    const size_t grid_z = 1;
+    const size_t total_threadgroup_size = tgx * tgy * tgz;
+    if (total_threadgroup_size > f32_scatter_fn->max_threadgroup_threads) {
+        return gptoss_status_invalid_argument;
+    }
+    const struct gptoss_scatter_args args = {
+        .tokens = num_tokens,
+        .active_experts_per_token = num_active_experts,
+        .token_stride = num_channels,
+    };
+
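+    // Launch geometry: grid_y spans num_tokens and threads along x cover num_channels / 4
+    // float4 vectors per row; the expert offsets and intra-expert offsets passed above are
+    // presumably what the kernel uses to pick each row's destination. The gather/accumulate
+    // launcher below uses the same geometry.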
+    return gptoss_metal_command_buffer_encode_launch_kernel(
+        command_buffer, f32_scatter_fn,
+        tgx, tgy, tgz,
+        grid_x, grid_y, grid_z,
+        sizeof(args), &args,
+        5,
+        (const struct gptoss_metal_buffer *[]) {input_buffer, expert_predictions_buffer, expert_offsets_buffer, intra_expert_offsets_buffer, output_buffer},
+        (const size_t[]) {input_offset, expert_predictions_offset, expert_offsets_offset, intra_expert_offsets_offset, output_offset},
+        /*threadgroup_buffer_size=*/0);
+}
+
+enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_gather_and_accumulate_e4(
+    const struct gptoss_metal_command_buffer* command_buffer,
+    const struct gptoss_metal_function* f32_gather_and_accumulate_e4_fn,
+    const struct gptoss_metal_buffer* input_buffer,
+    size_t input_offset,
+    const struct gptoss_metal_buffer* expert_predictions_buffer,
+    size_t expert_predictions_offset,
+    const struct gptoss_metal_buffer* expert_offsets_buffer,
+    size_t expert_offsets_offset,
+    const struct gptoss_metal_buffer* intra_expert_offsets_buffer,
+    size_t intra_expert_offsets_offset,
+    const struct gptoss_metal_buffer* output_buffer,
+    size_t output_offset,
+    uint32_t num_channels,
+    uint32_t num_tokens,
+    uint32_t num_active_experts)
+{
+    if (command_buffer->object == NULL || f32_gather_and_accumulate_e4_fn->pipeline_state_object == NULL) {
+        return gptoss_status_invalid_state;
+    }
+
+    if (num_channels % 4 != 0) {
+        return gptoss_status_invalid_argument;
+    }
+
+    const size_t num_vecs = num_channels / 4;
+    const size_t tgx = math_min(num_vecs, 64);
+    const size_t tgy = 1;
+    const size_t tgz = 1;
+    const size_t grid_x = math_ceil_div(num_vecs, tgx);
+    const size_t grid_y = num_tokens;
+    const size_t grid_z = 1;
+    const size_t total_threadgroup_size = tgx * tgy * tgz;
+    if (total_threadgroup_size > f32_gather_and_accumulate_e4_fn->max_threadgroup_threads) {
+        return gptoss_status_invalid_argument;
+    }
+    const struct gptoss_gather_args args = {
+        .tokens = num_tokens,
+        .active_experts_per_token = num_active_experts,
+        .token_stride = num_channels,
+    };
+
+    return gptoss_metal_command_buffer_encode_launch_kernel(
+        command_buffer, f32_gather_and_accumulate_e4_fn,
+        tgx, tgy, tgz,
+        grid_x, grid_y, grid_z,
+        sizeof(args), &args,
+        5,
+        (const struct gptoss_metal_buffer *[]) {input_buffer, expert_predictions_buffer, expert_offsets_buffer, intra_expert_offsets_buffer, output_buffer},
+        (const size_t[]) {input_offset, expert_predictions_offset, expert_offsets_offset, intra_expert_offsets_offset, output_offset},
+        /*threadgroup_buffer_size=*/0);
+}
+
+enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_mf4w_moe_dense_matmul_swiglu(
+    const struct gptoss_metal_command_buffer* command_buffer,
+    const struct gptoss_metal_function* f32_mf4w_moe_dense_matmul_swiglu_fn,
+    const struct gptoss_metal_buffer* input_buffer,
+    size_t input_offset,
+    const struct gptoss_metal_buffer* weight_block_buffer,
+    size_t weight_block_offset,
+    const struct gptoss_metal_buffer* weight_scale_buffer,
+    size_t weight_scale_offset,
+    const struct gptoss_metal_buffer* bias_buffer,
+    size_t bias_offset,
+    const struct gptoss_metal_buffer* output_buffer,
+    size_t output_offset,
+    float swiglu_limit,
+    uint32_t expert_stride_bytes,
+    uint32_t num_tokens,
+    uint32_t expert_token_offset,
+    uint32_t expert_id,
+    uint32_t num_cols,
+    uint32_t num_rows)
+{
+    if (command_buffer->object == NULL || f32_mf4w_moe_dense_matmul_swiglu_fn->pipeline_state_object == NULL) {
+        return gptoss_status_invalid_state;
+    }
+
+    if (num_cols % 32 != 0) {
+        GPTOSS_LOG_ERROR("failed to encode f32_mf4w_moe_dense_matmul_swiglu kernel launch: number of columns (%" PRIu32 ") is not divisible by 32",
+            num_cols);
+        return gptoss_status_invalid_argument;
+    }
+
+    const struct gptoss_moe_dense_matmul_swiglu_args args = {
+        .expert_token_count = num_tokens,
+        .n = num_rows,
+        .k = num_cols,
+        .expert_id = expert_id,
+        .expert_token_offset = expert_token_offset,
+        .weight_blocks_expert_stride_bytes = expert_stride_bytes,
+        .weight_scales_expert_stride_bytes = expert_stride_bytes,
+        .bias_expert_stride_bytes = expert_stride_bytes,
+        .swiglu_min = -swiglu_limit,
+        .swiglu_max = swiglu_limit,
+    };
+    const size_t threads_per_simdgroup = f32_mf4w_moe_dense_matmul_swiglu_fn->simdgroup_threads;
+    const uint32_t m = args.expert_token_count;
+    const uint32_t n = args.n;
+    const uint32_t k = args.k;
+    const uint32_t Bm = MOE_DENSE_MATMUL_SWIGLU_Bm;
+    const uint32_t Bn = MOE_DENSE_MATMUL_SWIGLU_Bn;
+    const uint32_t Bk = MOE_DENSE_MATMUL_SWIGLU_Bk;
+    const uint32_t Sg_Bm = MOE_DENSE_MATMUL_SWIGLU_Sg_Bm;
+    const uint32_t Sg_Bn = MOE_DENSE_MATMUL_SWIGLU_Sg_Bn;
+    if (Bm % Sg_Bm != 0) {
+        GPTOSS_LOG_ERROR("failed to encode f32_mf4w_moe_dense_matmul_swiglu kernel launch: Bm (%" PRIu32 ") is not divisible by Sg_Bm (%" PRIu32 ")",
+            Bm, Sg_Bm);
+        return gptoss_status_invalid_argument;
+    }
+    if (Bn % Sg_Bn != 0) {
+        GPTOSS_LOG_ERROR("failed to encode f32_mf4w_moe_dense_matmul_swiglu kernel launch: Bn (%" PRIu32 ") is not divisible by Sg_Bn (%" PRIu32 ")",
+            Bn, Sg_Bn);
+        return gptoss_status_invalid_argument;
+    }
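+    // Each simdgroup is sized to cover an Sg_Bm x Sg_Bn sub-tile of a Bm x Bn threadgroup tile,
+    // so the grid below walks n / Bn tiles in x and ceil(m / Bm) tiles in y.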
+    const size_t threadgroup_size_x = (Bm / Sg_Bm) * (Bn / Sg_Bn) * threads_per_simdgroup;
+    const size_t threadgroup_size_y = 1;
+    const size_t threadgroup_size_z = 1;
+    const size_t total_threadgroup_size = threadgroup_size_x * threadgroup_size_y * threadgroup_size_z;
+    if (total_threadgroup_size > f32_mf4w_moe_dense_matmul_swiglu_fn->max_threadgroup_threads) {
+        GPTOSS_LOG_ERROR("failed to encode f32_mf4w_moe_dense_matmul_swiglu kernel launch: total threadgroup size (%zu) exceeds supported maximum (%zu)",
+            total_threadgroup_size, f32_mf4w_moe_dense_matmul_swiglu_fn->max_threadgroup_threads);
+        return gptoss_status_invalid_argument;
+    }
+    if (n % Bn != 0) {
+        GPTOSS_LOG_ERROR("failed to encode f32_mf4w_moe_dense_matmul_swiglu kernel launch: n (%" PRIu32 ") is not divisible by Bn (%" PRIu32 ")",
+            n, Bn);
+        return gptoss_status_invalid_argument;
+    }
+    if (k % Bk != 0) {
+        GPTOSS_LOG_ERROR("failed to encode f32_mf4w_moe_dense_matmul_swiglu kernel launch: k (%" PRIu32 ") is not divisible by Bk (%" PRIu32 ")",
+            k, Bk);
+        return gptoss_status_invalid_argument;
+    }
+    const size_t grid_x = n / Bn;
+    const size_t grid_y = math_ceil_div(m, Bm);
+    const size_t grid_z = 1;
+
+    return gptoss_metal_command_buffer_encode_launch_kernel(
+        command_buffer, f32_mf4w_moe_dense_matmul_swiglu_fn,
+        threadgroup_size_x, threadgroup_size_y, threadgroup_size_z,
+        grid_x, grid_y, grid_z,
+        sizeof(args), &args,
+        5,
+        (const struct gptoss_metal_buffer *[]) {input_buffer, weight_block_buffer, weight_scale_buffer, bias_buffer, output_buffer},
+        (const size_t[]) {input_offset, weight_block_offset, weight_scale_offset, bias_offset, output_offset},
+        /*threadgroup_buffer_size=*/0);
+}
+
+enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_mf4w_moe_dense_matmul(
+    const struct gptoss_metal_command_buffer* command_buffer,
+    const struct gptoss_metal_function* f32_mf4w_moe_dense_matmul_fn,
+    const struct gptoss_metal_buffer* input_buffer,
+    size_t input_offset,
+    const struct gptoss_metal_buffer* weight_block_buffer,
+    size_t weight_block_offset,
+    const struct gptoss_metal_buffer* weight_scale_buffer,
+    size_t weight_scale_offset,
+    const struct gptoss_metal_buffer* bias_buffer,
+    size_t bias_offset,
+    const struct gptoss_metal_buffer* output_buffer,
+    size_t output_offset,
+    uint32_t expert_stride_bytes,
+    uint32_t num_tokens,
+    uint32_t expert_token_offset,
+    uint32_t expert_id,
+    uint32_t num_cols,
+    uint32_t num_rows)
+{
+    if (command_buffer->object == NULL || f32_mf4w_moe_dense_matmul_fn->pipeline_state_object == NULL) {
+        return gptoss_status_invalid_state;
+    }
+
+    if (num_cols % 32 != 0) {
+        GPTOSS_LOG_ERROR("failed to encode f32_mf4w_moe_dense_matmul kernel launch: number of columns (%" PRIu32 ") is not divisible by 32",
+            num_cols);
+        return gptoss_status_invalid_argument;
+    }
+    const struct gptoss_moe_dense_matmul_args args = {
+        .expert_token_count = num_tokens,
+        .k = num_cols,
+        .n = num_rows,
+        .expert_id = expert_id,
+        .expert_token_offset = expert_token_offset,
+        .weight_blocks_expert_stride_bytes = expert_stride_bytes,
+        .weight_scales_expert_stride_bytes = expert_stride_bytes,
+        .bias_expert_stride_bytes = expert_stride_bytes,
+    };
+
+    const size_t threads_per_simdgroup = f32_mf4w_moe_dense_matmul_fn->simdgroup_threads;
+    const uint32_t m = args.expert_token_count;
+    const uint32_t n = args.n;
+    const uint32_t k = args.k;
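+    // Same tiling scheme as the SwiGLU variant above: n and k must be exact multiples of the
+    // Bn / Bk tile sizes, while the per-expert token count m may be ragged (grid_y uses ceil_div).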
+    const uint32_t Bm = MOE_DENSE_MATMUL_Bm;
+    const uint32_t Bn = MOE_DENSE_MATMUL_Bn;
+    const uint32_t Bk = MOE_DENSE_MATMUL_Bk;
+    const uint32_t Sg_Bm = MOE_DENSE_MATMUL_Sg_Bm;
+    const uint32_t Sg_Bn = MOE_DENSE_MATMUL_Sg_Bn;
+    if (Bm % Sg_Bm != 0) {
+        GPTOSS_LOG_ERROR("failed to encode f32_mf4w_moe_dense_matmul kernel launch: Bm (%" PRIu32 ") is not divisible by Sg_Bm (%" PRIu32 ")",
+            Bm, Sg_Bm);
+        return gptoss_status_invalid_argument;
+    }
+    if (Bn % Sg_Bn != 0) {
+        GPTOSS_LOG_ERROR("failed to encode f32_mf4w_moe_dense_matmul kernel launch: Bn (%" PRIu32 ") is not divisible by Sg_Bn (%" PRIu32 ")",
+            Bn, Sg_Bn);
+        return gptoss_status_invalid_argument;
+    }
+
+    const size_t threadgroup_size_x = (Bm / Sg_Bm) * (Bn / Sg_Bn) * threads_per_simdgroup;
+    const size_t threadgroup_size_y = 1;
+    const size_t threadgroup_size_z = 1;
+    const size_t total_threadgroup_size = threadgroup_size_x * threadgroup_size_y * threadgroup_size_z;
+    if (total_threadgroup_size > f32_mf4w_moe_dense_matmul_fn->max_threadgroup_threads) {
+        GPTOSS_LOG_ERROR("failed to encode f32_mf4w_moe_dense_matmul kernel launch: total threadgroup size (%zu) exceeds supported maximum (%zu)",
+            total_threadgroup_size, f32_mf4w_moe_dense_matmul_fn->max_threadgroup_threads);
+        return gptoss_status_invalid_argument;
+    }
+    if (n % Bn != 0) {
+        GPTOSS_LOG_ERROR("failed to encode f32_mf4w_moe_dense_matmul kernel launch: n (%" PRIu32 ") is not divisible by Bn (%" PRIu32 ")",
+            n, Bn);
+        return gptoss_status_invalid_argument;
+    }
+    if (k % Bk != 0) {
+        GPTOSS_LOG_ERROR("failed to encode f32_mf4w_moe_dense_matmul kernel launch: k (%" PRIu32 ") is not divisible by Bk (%" PRIu32 ")",
+            k, Bk);
+        return gptoss_status_invalid_argument;
+    }
+
+    const size_t grid_y = math_ceil_div(m, Bm);
+    const size_t grid_x = n / Bn;
+    const size_t grid_z = 1;
+
+    return gptoss_metal_command_buffer_encode_launch_kernel(
+        command_buffer, f32_mf4w_moe_dense_matmul_fn,
+        threadgroup_size_x, threadgroup_size_y, threadgroup_size_z,
+        grid_x, grid_y, grid_z,
+        sizeof(args), &args,
+        5,
+        (const struct gptoss_metal_buffer *[]) {input_buffer, weight_block_buffer, weight_scale_buffer, bias_buffer, output_buffer},
+        (const size_t[]) {input_offset, weight_block_offset, weight_scale_offset, bias_offset, output_offset},
+        /*threadgroup_buffer_size=*/0);
+}
+
 enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_accumulate(
     const struct gptoss_metal_command_buffer* command_buffer,
     const struct gptoss_metal_function* f32_accumulate_fn,
diff --git a/gpt_oss/metal/source/model.c b/gpt_oss/metal/source/model.c
index 469ef232..f4a12b3b 100644
--- a/gpt_oss/metal/source/model.c
+++ b/gpt_oss/metal/source/model.c
@@ -349,6 +349,22 @@ enum gptoss_status GPTOSS_ABI gptoss_model_create_from_file(
     if (status != gptoss_status_success) {
         goto cleanup;
     }
+    status = gptoss_metal_function_create(&model->library, "gptoss_f32_scatter_e4", &model->f32_scatter_e4_fn);
+    if (status != gptoss_status_success) {
+        goto cleanup;
+    }
+    status = gptoss_metal_function_create(&model->library, "gptoss_f32_mf4w_moe_dense_matmul_swiglu", &model->f32_mf4w_moe_dense_matmul_swiglu_fn);
+    if (status != gptoss_status_success) {
+        goto cleanup;
+    }
+    status = gptoss_metal_function_create(&model->library, "gptoss_f32_mf4w_moe_dense_matmul", &model->f32_mf4w_moe_dense_matmul_fn);
+    if (status != gptoss_status_success) {
+        goto cleanup;
+    }
+    status = gptoss_metal_function_create(&model->library, "gptoss_f32_gather_and_accumulate_e4", &model->f32_gather_and_accumulate_e4_fn);
+    if (status != gptoss_status_success) {
+        goto cleanup;
+    }
&model->f32_mf4w_moe_matmul_swiglu_fn); if (status != gptoss_status_success) { goto cleanup; @@ -524,6 +540,10 @@ enum gptoss_status GPTOSS_ABI gptoss_model_release( gptoss_metal_function_release(&model->f32_bf16w_dense_matmul_mlp_gate_fn); gptoss_metal_function_release(&model->f32_bf16w_unembedding_fn); gptoss_metal_function_release(&model->f32_rope_fn); + gptoss_metal_function_release(&model->f32_scatter_e4_fn); + gptoss_metal_function_release(&model->f32_mf4w_moe_dense_matmul_swiglu_fn); + gptoss_metal_function_release(&model->f32_mf4w_moe_dense_matmul_fn); + gptoss_metal_function_release(&model->f32_gather_and_accumulate_e4_fn); gptoss_metal_function_release(&model->f32_mf4w_moe_matmul_swiglu_fn); gptoss_metal_function_release(&model->f32_mf4w_moe_matmul_fn); gptoss_metal_function_release(&model->f32_accumulate_e4_fn);