Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions ggml/src/ggml-metal/ggml-metal-ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3605,8 +3605,6 @@ int ggml_metal_op_argsort(ggml_metal_op_t ctx, int idx) {
ggml_metal_encoder_set_buffer (enc, bid_dst, 2);
ggml_metal_encoder_set_buffer (enc, bid_tmp, 3);

ggml_metal_encoder_set_threadgroup_memory_size(enc, 0, 0);

ggml_metal_encoder_dispatch_threadgroups(enc, nm*ne01, ne02, ne03, nth, 1, 1);

std::swap(bid_dst, bid_tmp);
Expand Down
126 changes: 87 additions & 39 deletions ggml/src/ggml-metal/ggml-metal.metal
Original file line number Diff line number Diff line change
Expand Up @@ -4628,12 +4628,13 @@ kernel void kernel_argsort_merge_f32_i32(
uint3 tgpig[[threadgroup_position_in_grid]],
ushort3 tpitg[[thread_position_in_threadgroup]],
ushort3 ntg[[threads_per_threadgroup]]) {
int im = tgpig[0] / args.ne01;
int i01 = tgpig[0] % args.ne01;
int i02 = tgpig[1];
int i03 = tgpig[2];

const int start = im * (2*args.len);
const int im = tgpig[0] / args.ne01;
const int i01 = tgpig[0] % args.ne01;
const int i02 = tgpig[1];
const int i03 = tgpig[2];

const int start = im * (2 * args.len);

const int len0 = MIN(args.len, MAX(0, args.ne00 - (int)(start)));
const int len1 = MIN(args.len, MAX(0, args.ne00 - (int)(start + args.len)));
Expand All @@ -4657,54 +4658,101 @@ kernel void kernel_argsort_merge_f32_i32(
+ args.nb02*i02
+ args.nb03*i03);

for (int k = tpitg.x; k < (int) total; k += ntg.x) {
// find partition (i,j) such that i+j = k
int low = k > len1 ? k - len1 : 0;
int high = MIN(k, len0);
if (total == 0) {
return;
}

while (low < high) {
const int mid = (low + high) >> 1;
const int chunk = (total + ntg.x - 1) / ntg.x;

const int32_t idx0 = tmp0[mid];
const int32_t idx1 = tmp1[k - mid - 1];
const int k0 = tpitg.x * chunk;
const int k1 = min(k0 + chunk, total);

const float val0 = src0_row[idx0];
const float val1 = src0_row[idx1];
if (k0 >= total) {
return;
}

if (order == GGML_SORT_ORDER_ASC) {
if (val0 <= val1) {
low = mid + 1;
} else {
high = mid;
}
} else {
if (val0 >= val1) {
low = mid + 1;
} else {
high = mid;
}
}
int low = k0 > len1 ? k0 - len1 : 0;
int high = MIN(k0, len0);

// binary-search partition (i, j) such that i + j = k
while (low < high) {
const int mid = (low + high) >> 1;

const int32_t idx0 = tmp0[mid];
const int32_t idx1 = tmp1[k0 - mid - 1];

const float val0 = src0_row[idx0];
const float val1 = src0_row[idx1];

bool take_left;
if (order == GGML_SORT_ORDER_ASC) {
take_left = (val0 <= val1);
} else {
take_left = (val0 >= val1);
}

const int i = low;
const int j = k - i;
if (take_left) {
low = mid + 1;
} else {
high = mid;
}
}

int i = low;
int j = k0 - i;

// keep the merge fronts into registers
int32_t idx0 = 0;
float val0 = 0.0f;
if (i < len0) {
idx0 = tmp0[i];
val0 = src0_row[idx0];
}

int32_t idx1 = 0;
float val1 = 0.0f;
if (j < len1) {
idx1 = tmp1[j];
val1 = src0_row[idx1];
}

for (int k = k0; k < k1; ++k) {
int32_t out_idx;

if (i >= len0) {
out_idx = tmp1[j];
while (k < k1) {
dst[k++] = tmp1[j++];
}
break;
} else if (j >= len1) {
out_idx = tmp0[i];
while (k < k1) {
dst[k++] = tmp0[i++];
}
break;
} else {
const int32_t idx0 = tmp0[i];
const int32_t idx1 = tmp1[j];
bool take_left;

const float val0 = src0_row[idx0];
const float val1 = src0_row[idx1];
if (order == GGML_SORT_ORDER_ASC) {
take_left = (val0 <= val1);
} else {
take_left = (val0 >= val1);
}

out_idx = (order == GGML_SORT_ORDER_ASC)
? (val0 <= val1 ? idx0 : idx1)
: (val0 >= val1 ? idx0 : idx1);
if (take_left) {
out_idx = idx0;
++i;
if (i < len0) {
idx0 = tmp0[i];
val0 = src0_row[idx0];
}
} else {
out_idx = idx1;
++j;
if (j < len1) {
idx1 = tmp1[j];
val1 = src0_row[idx1];
}
}
}

dst[k] = out_idx;
Expand Down
Loading