24 changes: 24 additions & 0 deletions ggml/src/ggml-metal/ggml-metal-device.cpp
@@ -1438,6 +1438,30 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_conv_transpose_2d(ggml_met
return res;
}

ggml_metal_pipeline_t ggml_metal_library_get_pipeline_conv_2d(ggml_metal_library_t lib, const ggml_tensor * op) {
assert(op->op == GGML_OP_CONV_2D);

GGML_ASSERT(ggml_is_contiguous(op->src[0]));
GGML_ASSERT(op->src[0]->type == GGML_TYPE_F16 || op->src[0]->type == GGML_TYPE_F32);
GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32);
GGML_ASSERT(op->type == GGML_TYPE_F32);

char base[256];
char name[256];

snprintf(base, 256, "kernel_conv_2d_%s_%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->src[1]->type));
snprintf(name, 256, "%s", base);

ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
if (res) {
return res;
}

res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);

return res;
}

ggml_metal_pipeline_t ggml_metal_library_get_pipeline_upscale(ggml_metal_library_t lib, const ggml_tensor * op) {
assert(op->op == GGML_OP_UPSCALE);

1 change: 1 addition & 0 deletions ggml/src/ggml-metal/ggml-metal-device.h
@@ -133,6 +133,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_rope (ggml_me
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_im2col (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_conv_transpose_1d (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_conv_transpose_2d (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_conv_2d (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_upscale (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pad (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pad_reflect_1d (ggml_metal_library_t lib, const struct ggml_tensor * op);
11 changes: 7 additions & 4 deletions ggml/src/ggml-metal/ggml-metal-device.m
@@ -564,10 +564,8 @@ ggml_metal_device_t ggml_metal_device_init(void) {
// TODO: try to update the tensor API kernels to at least match the simdgroup performance
if (getenv("GGML_METAL_TENSOR_ENABLE") == NULL &&
![[dev->mtl_device name] containsString:@"M5"] &&
![[dev->mtl_device name] containsString:@"M6"] &&
![[dev->mtl_device name] containsString:@"A19"] &&
![[dev->mtl_device name] containsString:@"A20"]) {
GGML_LOG_WARN("%s: tensor API disabled for pre-M5 and pre-A19 devices\n", __func__);
![[dev->mtl_device name] containsString:@"M6"]) {
GGML_LOG_WARN("%s: tensor API disabled for pre-M5 device\n", __func__);
Review comment from a Member on lines +567 to +568: No need to change this

dev->props.has_tensor = false;
}

@@ -885,6 +883,11 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
return true;
case GGML_OP_IM2COL:
return ggml_is_contiguous(op->src[1]) && op->src[1]->type == GGML_TYPE_F32 && (op->type == GGML_TYPE_F16 || op->type == GGML_TYPE_F32);
case GGML_OP_CONV_2D:
return ggml_is_contiguous(op->src[0]) &&
op->src[1]->type == GGML_TYPE_F32 &&
op->type == GGML_TYPE_F32 &&
(op->src[0]->type == GGML_TYPE_F16 || op->src[0]->type == GGML_TYPE_F32);
case GGML_OP_POOL_1D:
return false;
case GGML_OP_UPSCALE:
30 changes: 30 additions & 0 deletions ggml/src/ggml-metal/ggml-metal-impl.h
@@ -528,6 +528,36 @@ typedef struct {
uint64_t nb2;
} ggml_metal_kargs_conv_transpose_2d;

typedef struct {
uint64_t nb00;
uint64_t nb01;
uint64_t nb02;
uint64_t nb03;
uint64_t nb10;
uint64_t nb11;
uint64_t nb12;
uint64_t nb13;
uint64_t nb0;
uint64_t nb1;
uint64_t nb2;
uint64_t nb3;
int32_t IW;
int32_t IH;
int32_t KW;
int32_t KH;
int32_t IC;
int32_t OC;
int32_t OW;
int32_t OH;
int32_t N;
int32_t s0;
int32_t s1;
int32_t p0;
int32_t p1;
int32_t d0;
int32_t d1;
} ggml_metal_kargs_conv_2d;

typedef struct {
uint64_t ofs0;
uint64_t ofs1;
96 changes: 87 additions & 9 deletions ggml/src/ggml-metal/ggml-metal-ops.cpp
@@ -10,6 +10,7 @@

#include <cassert>
#include <algorithm>
#include <limits>

static ggml_metal_buffer_id ggml_metal_get_buffer_id(const ggml_tensor * t) {
if (!t) {
@@ -364,6 +365,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
{
n_fuse = ggml_metal_op_im2col(ctx, idx);
} break;
case GGML_OP_CONV_2D:
{
n_fuse = ggml_metal_op_conv_2d(ctx, idx);
} break;
case GGML_OP_CONV_TRANSPOSE_1D:
{
n_fuse = ggml_metal_op_conv_transpose_1d(ctx, idx);
@@ -503,10 +508,10 @@ int ggml_metal_op_concat(ggml_metal_op_t ctx, int idx) {
/*.ne1 =*/ ne1,
/*.ne2 =*/ ne2,
/*.ne3 =*/ ne3,
/*.nb0 =*/ nb0,
/*.nb1 =*/ nb1,
/*.nb2 =*/ nb2,
/*.nb3 =*/ nb3,
/*.nb0 =*/ (uint64_t) nb0,
/*.nb1 =*/ (uint64_t) nb1,
/*.nb2 =*/ (uint64_t) nb2,
/*.nb3 =*/ (uint64_t) nb3,
/*.dim =*/ dim,
};

@@ -1036,11 +1041,6 @@ int ggml_metal_op_set_rows(ggml_metal_op_t ctx, int idx) {

nth = std::min(nth, nk0);

if (nth*nrptg > ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
nth = ggml_metal_pipeline_max_theads_per_threadgroup(pipeline);
nrptg = 1;
}

ggml_metal_kargs_set_rows args = {
/*.nk0 =*/ nk0,
/*.ne01 =*/ ne01,
@@ -3082,6 +3082,84 @@ int ggml_metal_op_im2col(ggml_metal_op_t ctx, int idx) {
return 1;
}

int ggml_metal_op_conv_2d(ggml_metal_op_t ctx, int idx) {
ggml_tensor * op = ctx->node(idx);

ggml_metal_library_t lib = ctx->lib;
ggml_metal_encoder_t enc = ctx->enc;

GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);

GGML_ASSERT(ggml_is_contiguous(op->src[0]));
GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32);
GGML_ASSERT(op->type == GGML_TYPE_F32);
GGML_ASSERT(op->src[0]->type == GGML_TYPE_F16 || op->src[0]->type == GGML_TYPE_F32);

const int32_t s0 = ((const int32_t *) op->op_params)[0];
const int32_t s1 = ((const int32_t *) op->op_params)[1];
const int32_t p0 = ((const int32_t *) op->op_params)[2];
const int32_t p1 = ((const int32_t *) op->op_params)[3];
const int32_t d0 = ((const int32_t *) op->op_params)[4];
const int32_t d1 = ((const int32_t *) op->op_params)[5];

ggml_metal_kargs_conv_2d args = {
/*.nb00 =*/ nb00,
/*.nb01 =*/ nb01,
/*.nb02 =*/ nb02,
/*.nb03 =*/ nb03,
/*.nb10 =*/ nb10,
/*.nb11 =*/ nb11,
/*.nb12 =*/ nb12,
/*.nb13 =*/ nb13,
/*.nb0 =*/ nb0,
/*.nb1 =*/ nb1,
/*.nb2 =*/ nb2,
/*.nb3 =*/ nb3,
/*.IW =*/ ne10,
/*.IH =*/ ne11,
/*.KW =*/ ne00,
/*.KH =*/ ne01,
/*.IC =*/ ne02,
/*.OC =*/ ne03,
/*.OW =*/ ne0,
/*.OH =*/ ne1,
/*.N =*/ ne3,
/*.s0 =*/ s0,
/*.s1 =*/ s1,
/*.p0 =*/ p0,
/*.p1 =*/ p1,
/*.d0 =*/ d0,
/*.d1 =*/ d1,
};

ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_conv_2d(lib, op);

int nth = ggml_metal_pipeline_max_theads_per_threadgroup(pipeline);
nth = std::min(nth, 256);
nth = std::max(nth, 1);

const uint64_t n_out = (uint64_t) ne0 * ne1 * ne2 * ne3;
uint64_t tg64 = (n_out + nth - 1)/nth;
tg64 = std::max<uint64_t>(tg64, 1);
tg64 = std::min<uint64_t>(tg64, (uint64_t) std::numeric_limits<int>::max());
const int tg = (int) tg64;

ggml_metal_encoder_set_pipeline(enc, pipeline);
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3);

ggml_metal_encoder_dispatch_threadgroups(enc, tg, 1, 1, nth, 1, 1);

return 1;
}

int ggml_metal_op_conv_transpose_1d(ggml_metal_op_t ctx, int idx) {
ggml_tensor * op = ctx->node(idx);

1 change: 1 addition & 0 deletions ggml/src/ggml-metal/ggml-metal-ops.h
@@ -70,6 +70,7 @@ int ggml_metal_op_group_norm (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_norm (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_rope (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_im2col (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_conv_2d (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_conv_transpose_1d (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_conv_transpose_2d (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_upscale (ggml_metal_op_t ctx, int idx);
124 changes: 124 additions & 0 deletions ggml/src/ggml-metal/ggml-metal.metal
@@ -4146,6 +4146,130 @@ template [[host_name("kernel_im2col_f16")]] kernel im2col_t kernel_im2col<half>;
//template [[host_name("kernel_im2col_ext_f32")]] kernel im2col_ext_t kernel_im2col_ext<float>;
//template [[host_name("kernel_im2col_ext_f16")]] kernel im2col_ext_t kernel_im2col_ext<half>;

template <typename TK>
kernel void kernel_conv_2d(
constant ggml_metal_kargs_conv_2d & args,
device const TK * weights,
device const float * src,
device float * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tgpg[[threadgroups_per_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]) {

const uint threads_per_tg = ntg.x * ntg.y * ntg.z;
const uint tg_index = (tgpig.z * tgpg.y + tgpig.y) * tgpg.x + tgpig.x;
const uint local_thread = tpitg.z * (ntg.x * ntg.y) + tpitg.y * ntg.x + tpitg.x;
const uint thread_index = tg_index * threads_per_tg + local_thread;
const uint64_t total_threads = (uint64_t) threads_per_tg * tgpg.x * tgpg.y * tgpg.z;
const uint64_t total_outputs = (uint64_t) args.N * args.OC * args.OH * args.OW;

const ulong stride_w = args.nb10 / sizeof(float);
const ulong stride_h = args.nb11 / sizeof(float);
const ulong stride_c = args.nb12 / sizeof(float);
const ulong stride_n = args.nb13 / sizeof(float);
const ulong dst_stride_w = args.nb0 / sizeof(float);
const ulong dst_stride_h = args.nb1 / sizeof(float);
const ulong dst_stride_c = args.nb2 / sizeof(float);
const ulong dst_stride_n = args.nb3 / sizeof(float);

const ulong k_stride_w = args.nb00 / sizeof(TK);
const ulong k_stride_h = args.nb01 / sizeof(TK);
const ulong k_stride_c = args.nb02 / sizeof(TK);
const ulong k_stride_o = args.nb03 / sizeof(TK);
Review comment from a Member on lines +4152 to +4179: Generally we keep the input buffers as type char * and then perform the offset computation using the strides (nb), and only then cast to the element type (e.g. float *). This would avoid the sizeof() calls and would be more consistent with the rest of the code here.
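// Illustrative sketch, not part of this PR: the byte-offset addressing style the
// review comment above describes, applied to the inner loop of this kernel.
// Buffers are taken as device const char *, offsets are accumulated from the byte
// strides (args.nb00..nb03 for the weights, args.nb10..nb13 for the source), and
// the pointer is cast to the element type only at the access site, removing the
// sizeof()-based stride_* variables. Variable names (n, ic, iy, ix, oc, ky, kx,
// acc, TK) are the ones already used by the kernel below.
//
//     device const char * src_base = (device const char *) src;
//     device const char * wgt_base = (device const char *) weights;
//
//     const float v = *(device const float *)(src_base
//         + (ulong) n *args.nb13 + (ulong) ic*args.nb12
//         + (ulong) iy*args.nb11 + (ulong) ix*args.nb10);
//
//     const float w = (float) *(device const TK *)(wgt_base
//         + (ulong) oc*args.nb03 + (ulong) ic*args.nb02
//         + (ulong) ky*args.nb01 + (ulong) kx*args.nb00);
//
//     acc += v * w;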


for (uint64_t index = thread_index; index < total_outputs; index += total_threads) {
uint64_t tmp = index;

const int32_t ow = tmp % args.OW; tmp /= args.OW;
const int32_t oh = tmp % args.OH; tmp /= args.OH;
const int32_t oc = tmp % args.OC; tmp /= args.OC;
const int32_t n = tmp;

float acc = 0.0f;

const int32_t base_x = ow*args.s0 - args.p0;
const int32_t base_y = oh*args.s1 - args.p1;

int32_t ky_start = 0;
if (base_y < 0) {
ky_start = (-base_y + args.d1 - 1)/args.d1;
}
int32_t ky_end = args.KH;
const int32_t y_max = args.IH - 1 - base_y;
if (y_max < 0) {
ky_end = ky_start;
} else if (base_y + (args.KH - 1)*args.d1 >= args.IH) {
ky_end = min(ky_end, y_max/args.d1 + 1);
}

int32_t kx_start = 0;
if (base_x < 0) {
kx_start = (-base_x + args.d0 - 1)/args.d0;
}
int32_t kx_end = args.KW;
const int32_t x_max = args.IW - 1 - base_x;
if (x_max < 0) {
kx_end = kx_start;
} else if (base_x + (args.KW - 1)*args.d0 >= args.IW) {
kx_end = min(kx_end, x_max/args.d0 + 1);
}

if (ky_start < ky_end && kx_start < kx_end) {
const device const float * src_n = src + (ulong) n * stride_n;
const device const TK * w_oc = weights + (ulong) oc * k_stride_o;

for (int32_t ic = 0; ic < args.IC; ++ic) {
const device const float * src_c = src_n + (ulong) ic * stride_c;
const device const TK * w_c = w_oc + (ulong) ic * k_stride_c;

for (int32_t ky = ky_start; ky < ky_end; ++ky) {
const int32_t iy = base_y + ky*args.d1;
const device const float * src_row = src_c + (ulong) iy * stride_h;
const device const TK * w_row = w_c + (ulong) ky * k_stride_h;

for (int32_t kx = kx_start; kx < kx_end; ++kx) {
const int32_t ix = base_x + kx*args.d0;
const device const float * src_elem = src_row + (ulong) ix * stride_w;
const device const TK * w_elem = w_row + (ulong) kx * k_stride_w;

acc += (*src_elem) * (float) (*w_elem);
}
}
}
}

device float * dst_ptr = dst +
(ulong) n * dst_stride_n +
(ulong) oc * dst_stride_c +
(ulong) oh * dst_stride_h +
(ulong) ow * dst_stride_w;
*dst_ptr = acc;
}
}

template [[host_name("kernel_conv_2d_f32_f32")]]
kernel void kernel_conv_2d<float>(
constant ggml_metal_kargs_conv_2d & args,
device const float * weights,
device const float * src,
device float * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tgpg[[threadgroups_per_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]);

template [[host_name("kernel_conv_2d_f16_f32")]]
kernel void kernel_conv_2d<half>(
constant ggml_metal_kargs_conv_2d & args,
device const half * weights,
device const float * src,
device float * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tgpg[[threadgroups_per_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]);

typedef void (conv_transpose_1d_t)(
constant ggml_metal_kargs_conv_transpose_1d & args,
device const float * src0,