Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 25 additions & 10 deletions ggml-metal.m
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@
GGML_METAL_DECL_KERNEL(soft_max_4);
GGML_METAL_DECL_KERNEL(diag_mask_inf);
GGML_METAL_DECL_KERNEL(diag_mask_inf_8);
GGML_METAL_DECL_KERNEL(get_rows_f32);
GGML_METAL_DECL_KERNEL(get_rows_f16);
GGML_METAL_DECL_KERNEL(get_rows_q4_0);
GGML_METAL_DECL_KERNEL(get_rows_q4_1);
Expand Down Expand Up @@ -145,7 +146,7 @@ @implementation GGMLMetalClass
ctx->n_buffers = 0;
ctx->concur_list_len = 0;

ctx->d_queue = dispatch_queue_create("llama.cpp", DISPATCH_QUEUE_CONCURRENT);
ctx->d_queue = dispatch_queue_create("ggml-metal", DISPATCH_QUEUE_CONCURRENT);

#ifdef GGML_SWIFT
// load the default.metallib file
Expand Down Expand Up @@ -175,7 +176,7 @@ @implementation GGMLMetalClass

//NSString * path = [[NSBundle mainBundle] pathForResource:@"../../examples/metal/metal" ofType:@"metal"];
NSBundle * bundle = [NSBundle bundleForClass:[GGMLMetalClass class]];
NSString * path = [bundle pathForResource:@"ggml-metal" ofType:@"metal"];
NSString * path = [bundle pathForResource:@"ggml-metal" ofType:@"metal"];
metal_printf("%s: loading '%s'\n", __func__, [path UTF8String]);

NSString * src = [NSString stringWithContentsOfFile:path encoding:NSUTF8StringEncoding error:&error];
Expand Down Expand Up @@ -224,6 +225,7 @@ @implementation GGMLMetalClass
GGML_METAL_ADD_KERNEL(soft_max_4);
GGML_METAL_ADD_KERNEL(diag_mask_inf);
GGML_METAL_ADD_KERNEL(diag_mask_inf_8);
GGML_METAL_ADD_KERNEL(get_rows_f32);
GGML_METAL_ADD_KERNEL(get_rows_f16);
GGML_METAL_ADD_KERNEL(get_rows_q4_0);
GGML_METAL_ADD_KERNEL(get_rows_q4_1);
Expand Down Expand Up @@ -293,7 +295,9 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
GGML_METAL_DEL_KERNEL(gelu);
GGML_METAL_DEL_KERNEL(soft_max);
GGML_METAL_DEL_KERNEL(soft_max_4);
GGML_METAL_DEL_KERNEL(diag_mask_inf);
GGML_METAL_DEL_KERNEL(diag_mask_inf_8);
GGML_METAL_DEL_KERNEL(get_rows_f32);
GGML_METAL_DEL_KERNEL(get_rows_f16);
GGML_METAL_DEL_KERNEL(get_rows_q4_0);
GGML_METAL_DEL_KERNEL(get_rows_q4_1);
Expand Down Expand Up @@ -386,6 +390,7 @@ int ggml_metal_if_optimized(struct ggml_metal_context * ctx) {
for (int i = 0; i < ctx->n_buffers; ++i) {
const int64_t ioffs = (int64_t) t->data - (int64_t) ctx->buffers[i].data;

//metal_printf("ioffs = %10ld, tsize = %10ld, sum = %10ld, ctx->buffers[%d].size = %10ld, name = %s\n", ioffs, tsize, ioffs + tsize, i, ctx->buffers[i].size, ctx->buffers[i].name);
if (ioffs >= 0 && ioffs + tsize <= (int64_t) ctx->buffers[i].size) {
*offs = (size_t) ioffs;

Expand Down Expand Up @@ -723,13 +728,15 @@ void ggml_metal_graph_compute(
case GGML_OP_ADD:
{
GGML_ASSERT(ggml_is_contiguous(src0));
GGML_ASSERT(ggml_is_contiguous(src1));

// utilize float4
GGML_ASSERT(ne00 % 4 == 0);
const int64_t nb = ne00/4;

if (ggml_nelements(src1) == ne10) {
// src1 is a row
GGML_ASSERT(ne11 == 1);
[encoder setComputePipelineState:ctx->pipeline_add_row];
} else {
[encoder setComputePipelineState:ctx->pipeline_add];
Expand All @@ -746,13 +753,15 @@ void ggml_metal_graph_compute(
case GGML_OP_MUL:
{
GGML_ASSERT(ggml_is_contiguous(src0));
GGML_ASSERT(ggml_is_contiguous(src1));

// utilize float4
GGML_ASSERT(ne00 % 4 == 0);
const int64_t nb = ne00/4;

if (ggml_nelements(src1) == ne10) {
// src1 is a row
GGML_ASSERT(ne11 == 1);
[encoder setComputePipelineState:ctx->pipeline_mul_row];
} else {
[encoder setComputePipelineState:ctx->pipeline_mul];
Expand All @@ -768,6 +777,8 @@ void ggml_metal_graph_compute(
} break;
case GGML_OP_SCALE:
{
GGML_ASSERT(ggml_is_contiguous(src0));

const float scale = *(const float *) src1->data;

[encoder setComputePipelineState:ctx->pipeline_scale];
Expand Down Expand Up @@ -867,8 +878,8 @@ void ggml_metal_graph_compute(

// for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
// AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
if (ggml_is_contiguous(src0) &&
ggml_is_contiguous(src1) &&
if (!ggml_is_transposed(src0) &&
!ggml_is_transposed(src1) &&
src1t == GGML_TYPE_F32 &&
[ctx->device supportsFamily:MTLGPUFamilyApple7] &&
ne00%32 == 0 &&
Expand All @@ -893,9 +904,12 @@ void ggml_metal_graph_compute(
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:5];
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:6];
[encoder setBytes:&ne12 length:sizeof(ne12) atIndex:7];
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:8];
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:9];
[encoder setBytes:&gqa length:sizeof(gqa) atIndex:10];
[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:8];
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:9];
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:10];
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:11];
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:12];
[encoder setBytes:&gqa length:sizeof(gqa) atIndex:13];
[encoder setThreadgroupMemoryLength:8192 atIndex:0];
[encoder dispatchThreadgroups:MTLSizeMake( (ne11+31)/32, (ne01+63) / 64, ne12) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
} else {
Expand Down Expand Up @@ -1045,6 +1059,7 @@ void ggml_metal_graph_compute(
case GGML_OP_GET_ROWS:
{
switch (src0->type) {
case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_get_rows_f32]; break;
case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_get_rows_f16]; break;
case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_0]; break;
case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_1]; break;
Expand All @@ -1060,9 +1075,9 @@ void ggml_metal_graph_compute(
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
[encoder setBytes:&(src0->ne[0]) length:sizeof( int64_t) atIndex:3];
[encoder setBytes:&(src0->nb[1]) length:sizeof(uint64_t) atIndex:4];
[encoder setBytes:&(dst->nb[1]) length:sizeof(uint64_t) atIndex:5];
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:3];
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:4];
[encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:5];

const int64_t n = ggml_nelements(src1);

Expand Down
96 changes: 61 additions & 35 deletions ggml-metal.metal
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ kernel void kernel_add_row(
device const float4 * src0,
device const float4 * src1,
device float4 * dst,
constant int64_t & nb,
constant int64_t & nb,
uint tpig[[thread_position_in_grid]]) {
dst[tpig] = src0[tpig] + src1[tpig % nb];
}
Expand Down Expand Up @@ -1321,7 +1321,6 @@ kernel void kernel_mul_mat_q3_K_f32(
dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = sumf1[row];
}
}

}
#else
kernel void kernel_mul_mat_q3_K_f32(
Expand Down Expand Up @@ -1865,6 +1864,15 @@ kernel void kernel_mul_mat_q6_K_f32(

//============================= templates and their specializations =============================

// NOTE: this is not dequantizing - we are simply fitting the template
template <typename type4x4>
void dequantize_f32(device const float4x4 * src, short il, thread type4x4 & reg) {
float4x4 temp = *(((device float4x4 *)src));
for (int i = 0; i < 16; i++){
reg[i/4][i%4] = temp[i/4][i%4];
}
}

template <typename type4x4>
void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg) {
half4x4 temp = *(((device half4x4 *)src));
Expand All @@ -1875,7 +1883,6 @@ void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg)

template <typename type4x4>
void dequantize_q4_0(device const block_q4_0 *xb, short il, thread type4x4 & reg) {

device const uint16_t * qs = ((device const uint16_t *)xb + 1);
const float d1 = il ? (xb->d / 16.h) : xb->d;
const float d2 = d1 / 256.f;
Expand All @@ -1887,12 +1894,10 @@ void dequantize_q4_0(device const block_q4_0 *xb, short il, thread type4x4 & reg
reg[i/2][2*(i%2)+0] = d1 * (qs[i] & mask0) + md;
reg[i/2][2*(i%2)+1] = d2 * (qs[i] & mask1) + md;
}

}

template <typename type4x4>
void dequantize_q4_1(device const block_q4_1 *xb, short il, thread type4x4 & reg) {

device const uint16_t * qs = ((device const uint16_t *)xb + 2);
const float d1 = il ? (xb->d / 16.h) : xb->d;
const float d2 = d1 / 256.f;
Expand Down Expand Up @@ -1964,7 +1969,6 @@ void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg
for (int i = 0; i < 16; ++i) {
reg[i/4][i%4] = dl * (q[i] & mask) - (h[i] & m ? 0 : ml);
}

#else
float kcoef = il&1 ? 1.f/16.f : 1.f;
uint16_t kmask = il&1 ? 0xF0 : 0x0F;
Expand Down Expand Up @@ -2110,22 +2114,25 @@ kernel void kernel_get_rows(
// each block_q contains 16*nl weights
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
kernel void kernel_mul_mm(device const uchar * src0,
device const float * src1,
device float * dst,
constant int64_t & ne00,
constant int64_t & ne02,
constant int64_t & nb01,
constant int64_t & nb02,
constant int64_t & ne12,
constant int64_t & ne0,
constant int64_t & ne1,
constant uint & gqa,
threadgroup uchar * shared_memory [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiitg[[thread_index_in_threadgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {

threadgroup half * sa = ((threadgroup half *)shared_memory);
device const uchar * src1,
device float * dst,
constant int64_t & ne00,
constant int64_t & ne02,
constant int64_t & nb01,
constant int64_t & nb02,
constant int64_t & ne12,
constant int64_t & nb10,
constant int64_t & nb11,
constant int64_t & nb12,
constant int64_t & ne0,
constant int64_t & ne1,
constant uint & gqa,
threadgroup uchar * shared_memory [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiitg[[thread_index_in_threadgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {

threadgroup half * sa = (threadgroup half *)(shared_memory);
threadgroup float * sb = (threadgroup float *)(shared_memory + 4096);

const uint r0 = tgpig.y;
Expand All @@ -2138,18 +2145,23 @@ kernel void kernel_mul_mm(device const uchar * src0,
short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1;
short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1;

simdgroup_half8x8 ma[4];
simdgroup_half8x8 ma[4];
simdgroup_float8x8 mb[2];
simdgroup_float8x8 c_res[8];
for (int i = 0; i < 8; i++){
c_res[i] = make_filled_simdgroup_matrix<float, 8>(0.f);
}

short il = (tiitg % THREAD_PER_ROW);
uint offset0 = im/gqa*nb02; ushort offset1 = il/nl;
device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01 + offset0) + offset1;
device const float * y = src1 + (r1 * BLOCK_SIZE_N + thread_col) * ne00 \
+ BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL) + im * ne00 * ne1;

uint offset0 = im/gqa*nb02;
ushort offset1 = il/nl;

device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01 + offset0) + offset1;
device const float * y = (device const float *)(src1
+ nb12 * im
+ nb11 * (r1 * BLOCK_SIZE_N + thread_col)
+ nb10 * (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL)));

for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) {
//load data and store to threadgroup memory
Expand Down Expand Up @@ -2229,6 +2241,7 @@ kernel void kernel_mul_mm(device const uchar * src0,
typedef void (get_rows_t)(device const void *, device const int *, device float *, constant int64_t &, \
constant uint64_t &, constant uint64_t &, uint, uint, uint);

template [[host_name("kernel_get_rows_f32")]] kernel get_rows_t kernel_get_rows<float4x4, 1, dequantize_f32>;
template [[host_name("kernel_get_rows_f16")]] kernel get_rows_t kernel_get_rows<half4x4, 1, dequantize_f16>;
template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_t kernel_get_rows<block_q4_0, 2, dequantize_q4_0>;
template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_t kernel_get_rows<block_q4_1, 2, dequantize_q4_1>;
Expand All @@ -2239,14 +2252,27 @@ template [[host_name("kernel_get_rows_q4_K")]] kernel get_rows_t kernel_get_rows
template [[host_name("kernel_get_rows_q5_K")]] kernel get_rows_t kernel_get_rows<block_q5_K, QK_NL, dequantize_q5_K>;
template [[host_name("kernel_get_rows_q6_K")]] kernel get_rows_t kernel_get_rows<block_q6_K, QK_NL, dequantize_q6_K>;

typedef void (mat_mm_t)(device const uchar *, device const float *, device float *, constant int64_t &,\
constant int64_t &, constant int64_t &, constant int64_t &, constant int64_t &, \
constant int64_t &, constant int64_t &, constant uint &, threadgroup uchar *, uint3, uint, uint);

template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm<half4x4, 1, dequantize_f16>;
template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_0, 2, dequantize_q4_0>;
template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_1, 2, dequantize_q4_1>;
template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q8_0, 2, dequantize_q8_0>;
typedef void (mat_mm_t)(
device const uchar * src0,
device const uchar * src1,
device float * dst,
constant int64_t & ne00,
constant int64_t & ne02,
constant int64_t & nb01,
constant int64_t & nb02,
constant int64_t & ne12,
constant int64_t & nb10,
constant int64_t & nb11,
constant int64_t & nb12,
constant int64_t & ne0,
constant int64_t & ne1,
constant uint & gqa,
threadgroup uchar *, uint3, uint, uint);

template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm<half4x4, 1, dequantize_f16>;
template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_0, 2, dequantize_q4_0>;
template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_1, 2, dequantize_q4_1>;
template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q8_0, 2, dequantize_q8_0>;
template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q2_K, QK_NL, dequantize_q2_K>;
template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q3_K, QK_NL, dequantize_q3_K>;
template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_K, QK_NL, dequantize_q4_K>;
Expand Down
20 changes: 16 additions & 4 deletions ggml.c
Original file line number Diff line number Diff line change
Expand Up @@ -4303,10 +4303,21 @@ int64_t ggml_nrows(const struct ggml_tensor * tensor) {
}

size_t ggml_nbytes(const struct ggml_tensor * tensor) {
size_t nbytes = tensor->ne[0]*tensor->nb[0]/ggml_blck_size(tensor->type);
for (int i = 1; i < GGML_MAX_DIMS; ++i) {
nbytes += (tensor->ne[i] - 1)*tensor->nb[i];
size_t nbytes;
size_t blck_size = ggml_blck_size(tensor->type);
if (blck_size == 1) {
nbytes = ggml_type_size(tensor->type);
for (int i = 0; i < GGML_MAX_DIMS; ++i) {
nbytes += (tensor->ne[i] - 1)*tensor->nb[i];
}
}
else {
nbytes = tensor->ne[0]*tensor->nb[0]/blck_size;
for (int i = 1; i < GGML_MAX_DIMS; ++i) {
nbytes += (tensor->ne[i] - 1)*tensor->nb[i];
}
}

return nbytes;
}

Expand Down Expand Up @@ -18340,7 +18351,8 @@ void ggml_graph_print(const struct ggml_cgraph * cgraph) {
GGML_PRINT(" - %3d: [ %5" PRId64 ", %5" PRId64 "] %8s\n",
i,
node->ne[0], node->ne[1],
ggml_op_name(node->op));
ggml_op_name(node->op),
ggml_get_name(node));
}

for (int i = 0; i < GGML_OP_COUNT; i++) {
Expand Down
4 changes: 0 additions & 4 deletions llama.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2969,10 +2969,6 @@ static bool llama_eval_internal(
if (lctx.ctx_metal) {
ggml_metal_set_n_cb (lctx.ctx_metal, n_threads);
ggml_metal_graph_compute(lctx.ctx_metal, gf);
ggml_metal_get_tensor (lctx.ctx_metal, res);
if (!lctx.embedding.empty()) {
ggml_metal_get_tensor(lctx.ctx_metal, embeddings);
}
} else {
ggml_graph_compute_helper(lctx.work_buffer, gf, n_threads);
}
Expand Down