-
Notifications
You must be signed in to change notification settings - Fork 13.7k
metal: accelerated conv2d #17175
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
bghira
wants to merge
1
commit into
ggml-org:master
Choose a base branch
from
bghira:feature/metal-conv2d
base: master
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+274
−13
Open
metal: accelerated conv2d #17175
Changes from all commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -4146,6 +4146,130 @@ template [[host_name("kernel_im2col_f16")]] kernel im2col_t kernel_im2col<half>; | |
| //template [[host_name("kernel_im2col_ext_f32")]] kernel im2col_ext_t kernel_im2col_ext<float>; | ||
| //template [[host_name("kernel_im2col_ext_f16")]] kernel im2col_ext_t kernel_im2col_ext<half>; | ||
|
|
||
// Returns the half-open range [x, y) of kernel taps along one spatial axis whose
// sampled input coordinate `base + k*dilation` lies inside [0, in_size).
//
//   base     - input coordinate sampled by tap 0 (output index * stride - padding)
//   ksize    - number of kernel taps along this axis
//   dilation - spacing between consecutive taps, in input elements
//   in_size  - input extent along this axis
//
// The range is empty (x >= y) when no tap lands inside the input.
static inline int2 conv_2d_tap_range(int32_t base, int32_t ksize, int32_t dilation, int32_t in_size) {
    // First tap that is not left of / above the input: ceil(-base / dilation).
    int32_t k_start = 0;
    if (base < 0) {
        k_start = (-base + dilation - 1)/dilation;
    }

    // One past the last tap that is still inside the input.
    int32_t k_end = ksize;
    const int32_t last = in_size - 1 - base; // distance from tap 0 to the last valid input element
    if (last < 0) {
        // Even tap 0 is already past the end of the input: force an empty range.
        k_end = k_start;
    } else if (base + (ksize - 1)*dilation >= in_size) {
        k_end = min(k_end, last/dilation + 1);
    }

    return int2(k_start, k_end);
}

// Direct 2D convolution: every thread produces whole output elements, one at a time.
//
// The N*OC*OH*OW outputs are numbered linearly (OW varies fastest, then OH, OC, N).
// Each thread starts at its flattened global thread index and advances by the total
// number of launched threads, so any grid size covers all outputs.
//
//   args    - shapes (N, IC, IH, IW, OC, OH, OW, KH, KW), strides s0/s1, paddings p0/p1,
//             dilations d0/d1 and the tensor strides nb0x (weights), nb1x (src), nbx (dst).
//             The nb* values are divided by the element size below, i.e. they are treated
//             as byte strides that are multiples of the element size.
//   weights - convolution kernel, element type TK (converted to float per element)
//   src     - input, float
//   dst     - output, float
//
// Taps that would read outside the input are skipped, which matches zero padding.
//
// NOTE(review): a maintainer comment attached to the stride setup suggests keeping the
// input buffers as a different (generic) pointer type instead of typed pointers; the
// comment text is truncated in this view - confirm against the PR discussion.
template <typename TK>
kernel void kernel_conv_2d(
        constant ggml_metal_kargs_conv_2d & args,
        device const TK * weights,
        device const float * src,
        device float * dst,
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint3 tgpg[[threadgroups_per_grid]],
        uint3 tpitg[[thread_position_in_threadgroup]],
        uint3 ntg[[threads_per_threadgroup]]) {

    // Flatten the 3D grid position into a single global thread index. The index is kept
    // in 64 bits, like the loop bounds, so it cannot wrap for very large dispatches.
    const uint     threads_per_tg = ntg.x * ntg.y * ntg.z;
    const uint64_t tg_index       = ((uint64_t) tgpig.z * tgpg.y + tgpig.y) * tgpg.x + tgpig.x;
    const uint     local_thread   = tpitg.z * (ntg.x * ntg.y) + tpitg.y * ntg.x + tpitg.x;
    const uint64_t thread_index   = tg_index * threads_per_tg + local_thread;
    const uint64_t total_threads  = (uint64_t) threads_per_tg * tgpg.x * tgpg.y * tgpg.z;
    const uint64_t total_outputs  = (uint64_t) args.N * args.OC * args.OH * args.OW;

    // Element strides of the input ...
    const ulong stride_w = args.nb10 / sizeof(float);
    const ulong stride_h = args.nb11 / sizeof(float);
    const ulong stride_c = args.nb12 / sizeof(float);
    const ulong stride_n = args.nb13 / sizeof(float);

    // ... of the output ...
    const ulong dst_stride_w = args.nb0 / sizeof(float);
    const ulong dst_stride_h = args.nb1 / sizeof(float);
    const ulong dst_stride_c = args.nb2 / sizeof(float);
    const ulong dst_stride_n = args.nb3 / sizeof(float);

    // ... and of the weights.
    const ulong k_stride_w = args.nb00 / sizeof(TK);
    const ulong k_stride_h = args.nb01 / sizeof(TK);
    const ulong k_stride_c = args.nb02 / sizeof(TK);
    const ulong k_stride_o = args.nb03 / sizeof(TK);

    for (uint64_t index = thread_index; index < total_outputs; index += total_threads) {
        // Decompose the linear output index into (n, oc, oh, ow).
        uint64_t tmp = index;

        const int32_t ow = tmp % args.OW; tmp /= args.OW;
        const int32_t oh = tmp % args.OH; tmp /= args.OH;
        const int32_t oc = tmp % args.OC; tmp /= args.OC;
        const int32_t n  = tmp;

        float acc = 0.0f;

        // Input coordinate sampled by kernel tap (0, 0); negative inside the padding.
        const int32_t base_x = ow*args.s0 - args.p0;
        const int32_t base_y = oh*args.s1 - args.p1;

        // Restrict both tap loops to taps that land inside the input, so the inner
        // loops need no per-element bounds checks.
        const int2 ky_range = conv_2d_tap_range(base_y, args.KH, args.d1, args.IH);
        const int2 kx_range = conv_2d_tap_range(base_x, args.KW, args.d0, args.IW);

        const int32_t ky_start = ky_range.x;
        const int32_t ky_end   = ky_range.y;
        const int32_t kx_start = kx_range.x;
        const int32_t kx_end   = kx_range.y;

        if (ky_start < ky_end && kx_start < kx_end) {
            device const float * src_n = src     + (ulong) n  * stride_n;
            device const TK    * w_oc  = weights + (ulong) oc * k_stride_o;

            for (int32_t ic = 0; ic < args.IC; ++ic) {
                device const float * src_c = src_n + (ulong) ic * stride_c;
                device const TK    * w_c   = w_oc  + (ulong) ic * k_stride_c;

                for (int32_t ky = ky_start; ky < ky_end; ++ky) {
                    const int32_t iy = base_y + ky*args.d1;
                    device const float * src_row = src_c + (ulong) iy * stride_h;
                    device const TK    * w_row   = w_c   + (ulong) ky * k_stride_h;

                    for (int32_t kx = kx_start; kx < kx_end; ++kx) {
                        const int32_t ix = base_x + kx*args.d0;
                        device const float * src_elem = src_row + (ulong) ix * stride_w;
                        device const TK    * w_elem   = w_row   + (ulong) kx * k_stride_w;

                        acc += (*src_elem) * (float) (*w_elem);
                    }
                }
            }
        }

        // Outputs with no valid tap are still written (as 0).
        device float * dst_ptr = dst +
            (ulong) n  * dst_stride_n +
            (ulong) oc * dst_stride_c +
            (ulong) oh * dst_stride_h +
            (ulong) ow * dst_stride_w;
        *dst_ptr = acc;
    }
}
|
|
||
// Explicit instantiation of kernel_conv_2d for float weights with float input and
// output, exposed to the host under the name "kernel_conv_2d_f32_f32".
template [[host_name("kernel_conv_2d_f32_f32")]]
kernel void kernel_conv_2d<float>(
        constant ggml_metal_kargs_conv_2d & args,
        device const float * weights,
        device const float * src,
        device       float * dst,
        uint3 tgpig [[threadgroup_position_in_grid]],
        uint3 tgpg  [[threadgroups_per_grid]],
        uint3 tpitg [[thread_position_in_threadgroup]],
        uint3 ntg   [[threads_per_threadgroup]]);
|
|
||
// Explicit instantiation of kernel_conv_2d for half-precision weights with float
// input and output, exposed to the host under the name "kernel_conv_2d_f16_f32".
template [[host_name("kernel_conv_2d_f16_f32")]]
kernel void kernel_conv_2d<half>(
        constant ggml_metal_kargs_conv_2d & args,
        device const half  * weights,
        device const float * src,
        device       float * dst,
        uint3 tgpig [[threadgroup_position_in_grid]],
        uint3 tgpg  [[threadgroups_per_grid]],
        uint3 tpitg [[thread_position_in_threadgroup]],
        uint3 ntg   [[threads_per_threadgroup]]);
|
|
||
| typedef void (conv_transpose_1d_t)( | ||
| constant ggml_metal_kargs_conv_transpose_1d & args, | ||
| device const float * src0, | ||
|
|
||
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No need to change this