Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat : add col2im kernels #808

Draft
wants to merge 5 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions include/ggml/ggml.h
Original file line number Diff line number Diff line change
Expand Up @@ -464,6 +464,7 @@ extern "C" {
GGML_OP_CLAMP,
GGML_OP_CONV_TRANSPOSE_1D,
GGML_OP_IM2COL,
GGML_OP_COL2IM,
GGML_OP_CONV_TRANSPOSE_2D,
GGML_OP_POOL_1D,
GGML_OP_POOL_2D,
Expand Down Expand Up @@ -1560,6 +1561,19 @@ extern "C" {
bool is_2D,
enum ggml_type dst_type);

GGML_API struct ggml_tensor * ggml_col2im(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b,
int s0,
int s1,
int p0,
int p1,
int d0,
int d1,
bool is_2D,
enum ggml_type dst_type);

GGML_API struct ggml_tensor * ggml_conv_depthwise_2d(
struct ggml_context * ctx,
struct ggml_tensor * a,
Expand Down
56 changes: 56 additions & 0 deletions src/ggml-metal.m
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,8 @@
GGML_METAL_KERNEL_TYPE_ALIBI_F32,
GGML_METAL_KERNEL_TYPE_IM2COL_F16,
GGML_METAL_KERNEL_TYPE_IM2COL_F32,
GGML_METAL_KERNEL_TYPE_COL2IM_F16,
GGML_METAL_KERNEL_TYPE_COL2IM_F32,
GGML_METAL_KERNEL_TYPE_UPSCALE_F32,
GGML_METAL_KERNEL_TYPE_PAD_F32,
GGML_METAL_KERNEL_TYPE_ARANGE_F32,
Expand Down Expand Up @@ -597,6 +599,8 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ALIBI_F32, alibi_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F16, im2col_f16, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F32, im2col_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_COL2IM_F16, col2im_f16, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_COL2IM_F32, col2im_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UPSCALE_F32, upscale_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_PAD_F32, pad_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32, timestep_embedding_f32, true);
Expand Down Expand Up @@ -724,6 +728,7 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
case GGML_OP_ALIBI:
case GGML_OP_ROPE:
case GGML_OP_IM2COL:
case GGML_OP_COL2IM:
return true;
case GGML_OP_POOL_1D:
case GGML_OP_POOL_2D:
Expand Down Expand Up @@ -2293,6 +2298,57 @@ static enum ggml_status ggml_metal_graph_compute(
[encoder setBytes:&d0 length:sizeof( int32_t) atIndex:11];
[encoder setBytes:&d1 length:sizeof( int32_t) atIndex:12];

[encoder dispatchThreadgroups:MTLSizeMake(IC, OH, OW) threadsPerThreadgroup:MTLSizeMake(N, KH, KW)];
} break;
case GGML_OP_COL2IM:
{
GGML_ASSERT(src0->type == GGML_TYPE_F16);
GGML_ASSERT(src1->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_F32);

const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
const int32_t p0 = ((const int32_t *)(dst->op_params))[2];
const int32_t p1 = ((const int32_t *)(dst->op_params))[3];
const int32_t d0 = ((const int32_t *)(dst->op_params))[4];
const int32_t d1 = ((const int32_t *)(dst->op_params))[5];
const bool is_2D = ((const int32_t *)(dst->op_params))[6] == 1;

const int32_t N = src1->ne[is_2D ? 3 : 2];
const int32_t IC = src0->ne[is_2D ? 3 : 2];
const int32_t IH = is_2D ? src1->ne[2] : 1;
const int32_t IW = src1->ne[1];

const int32_t KH = is_2D ? src0->ne[1] : 1;
const int32_t KW = src0->ne[0];

const int32_t OH = is_2D ? dst->ne[1] : 1;
const int32_t OW = dst->ne[0];

const int32_t CHW = IC * KH * KW;

id<MTLComputePipelineState> pipeline = nil;

switch (dst->type) {
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F32].pipeline; break;
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F16].pipeline; break;
default: GGML_ASSERT(false);
};

[encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
// TODO:
[encoder setBytes:&OW length:sizeof( int32_t) atIndex:4];
[encoder setBytes:&OH length:sizeof( int32_t) atIndex:5];
[encoder setBytes:&CHW length:sizeof( int32_t) atIndex:6];
[encoder setBytes:&s0 length:sizeof( int32_t) atIndex:7];
[encoder setBytes:&s1 length:sizeof( int32_t) atIndex:8];
[encoder setBytes:&p0 length:sizeof( int32_t) atIndex:9];
[encoder setBytes:&p1 length:sizeof( int32_t) atIndex:10];
[encoder setBytes:&d0 length:sizeof( int32_t) atIndex:11];
[encoder setBytes:&d1 length:sizeof( int32_t) atIndex:12];

[encoder dispatchThreadgroups:MTLSizeMake(IC, OH, OW) threadsPerThreadgroup:MTLSizeMake(N, KH, KW)];
} break;
case GGML_OP_UPSCALE:
Expand Down
55 changes: 55 additions & 0 deletions src/ggml-metal.metal
Original file line number Diff line number Diff line change
Expand Up @@ -1837,6 +1837,61 @@ kernel void kernel_im2col(
template [[host_name("kernel_im2col_f32")]] kernel im2col_t kernel_im2col<float>;
template [[host_name("kernel_im2col_f16")]] kernel im2col_t kernel_im2col<half>;

typedef void (col2im_t)(
device const float * x,
device char * dst,
constant int32_t & IW,
constant int32_t & IH,
constant int32_t & CHW,
constant int32_t & s0,
constant int32_t & s1,
constant int32_t & p0,
constant int32_t & p1,
constant int32_t & d0,
constant int32_t & d1,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tgpg[[threadgroups_per_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]);

template <typename T>
kernel void kernel_col2im(
device const float * x,
device char * dst,
constant int32_t & OW,
constant int32_t & OH,
constant int32_t & CHW,
constant int32_t & s0,
constant int32_t & s1,
constant int32_t & p0,
constant int32_t & p1,
constant int32_t & d0,
constant int32_t & d1,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tgpg[[threadgroups_per_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]) {
const int32_t iow = tgpig[2] * s0 + tpitg[2] * d0 - p0;
const int32_t ioh = tgpig[1] * s1 + tpitg[1] * d1 - p1;

const int32_t offset_dst =
(tgpig[0] * tgpg[0] * tgpg[1] + tgpg[1] * tgpig[0]) * tgpg[2] +
ioh * tgpg[2] + iow;

const int32_t offset_src =
(tgpig[0] * ) * CHW;


device T * pdst = (device T *) (dst);

if (ioh >= 0 && ioh < OH && iow >= 0 && iow < OW) {
pdst[offset_dst] += x[offset_src];
}
}

template [[host_name("kernel_col2im_f32")]] kernel col2im_t kernel_col2im<float>;
template [[host_name("kernel_col2im_f16")]] kernel col2im_t kernel_col2im<half>;

kernel void kernel_upscale_f32(
device const char * src0,
device char * dst,
Expand Down
Loading