skip-unused: disable skipping on ROCm / when LLAMA_USE_HIPBLAS
ochafik committed Aug 23, 2023
1 parent c77ed60 commit 5ee8597
Showing 1 changed file with 14 additions and 1 deletion.
llama.cpp: 15 changes (14 additions & 1 deletion)
@@ -32,6 +32,12 @@
 #endif
 #endif
 
+// TODO: Fix unused logit skipping crashes on ROCm
+// (see https://github.com/ggerganov/llama.cpp/pull/2700#issuecomment-1689548127)
+#ifndef LLAMA_USE_HIPBLAS
+#define LLAMA_SKIP_UNUSED_LOGITS
+#endif
+
 #include <array>
 #include <ctime>
 #include <cinttypes>
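
For readers skimming the diff: the new guard makes the logit-skipping optimisation opt-out, keyed off the hipBLAS macro, so ROCm builds compile the old code path. A minimal standalone sketch of the same preprocessor pattern (the macro names are taken from the diff; the printouts are illustrative only, not llama.cpp code):

    // Sketch: LLAMA_SKIP_UNUSED_LOGITS is defined unless the hipBLAS/ROCm
    // backend macro is present. Compile with -DLLAMA_USE_HIPBLAS to see the
    // skip-disabled branch.
    #include <cstdio>

    #ifndef LLAMA_USE_HIPBLAS
    #define LLAMA_SKIP_UNUSED_LOGITS
    #endif

    int main() {
    #ifdef LLAMA_SKIP_UNUSED_LOGITS
        std::printf("unused-logit skipping enabled (non-hipBLAS build)\n");
    #else
        std::printf("unused-logit skipping disabled (hipBLAS/ROCm build)\n");
    #endif
        return 0;
    }
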
@@ -1594,6 +1600,7 @@ static struct ggml_cgraph * llama_build_graph(
             ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, v));
         }
 
+#ifdef LLAMA_SKIP_UNUSED_LOGITS
         if (il == n_layer - 1 && !lctx.logits_all)
         {
             // From here on, we only care about the last token and its logits.
@@ -1614,6 +1621,7 @@
             n_past += N - 1;
             N = 1;
         }
+#endif // LLAMA_SKIP_UNUSED_LOGITS
 
         struct ggml_tensor * tmpq = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
         offload_func_kq(tmpq);
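
The guarded block above is the core of the optimisation: at the last layer, when the caller does not ask for logits of every token (!lctx.logits_all), the remaining work is restricted to the final token, which is why the visible tail bumps n_past by N - 1 and sets N to 1 (the elided lines presumably take a view of the last row of the hidden state). A rough standalone illustration of that index arithmetic, using plain arrays and toy sizes instead of ggml tensors:

    // Toy model of the "keep only the last token" narrowing. The sizes and
    // variable names are illustrative; this is not the actual graph code.
    #include <cstdio>
    #include <vector>

    int main() {
        const int n_embd = 4;   // hidden size (toy value)
        int       N      = 3;   // tokens in the current batch
        int       n_past = 10;  // tokens already stored in the KV cache

        // Row-major [N, n_embd] hidden state; row i belongs to token i.
        std::vector<float> hidden(N * n_embd, 0.0f);

        // Narrow to the last row only; the earlier rows have already been
        // written to the KV cache, so they are folded into n_past.
        const float * last_row = hidden.data() + (N - 1) * n_embd;
        n_past += N - 1;
        N = 1;

        std::printf("rows left: %d, n_past: %d, last row offset: %td floats\n",
                    N, n_past, last_row - hidden.data());
        return 0;
    }
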
@@ -1920,9 +1928,14 @@ static bool llama_eval_internal(
             memcpy(logits_out.data(), (float *) ggml_get_data(res), sizeof(float)*n_vocab*N);
         } else {
             // return result for just the last token
-            GGML_ASSERT(ggml_nelements(res) == n_vocab);
             logits_out.resize(n_vocab);
+#ifdef LLAMA_SKIP_UNUSED_LOGITS
+            GGML_ASSERT(ggml_nelements(res) == n_vocab);
             memcpy(logits_out.data(), (float *) ggml_get_data(res), sizeof(float)*n_vocab);
+#else
+            GGML_ASSERT(ggml_nelements(res) == n_vocab * N);
+            memcpy(logits_out.data(), (float *) ggml_get_data(res) + (n_vocab*(N-1)), sizeof(float)*n_vocab);
+#endif
         }
     }
 
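On the output side, the copy of the last token's logits now depends on whether the graph skipped the other tokens: with skipping, res holds exactly n_vocab floats; without it (the ROCm path), it holds n_vocab*N floats and the last token's block starts at offset n_vocab*(N-1). A minimal standalone model of the two copy paths (plain vectors stand in for the ggml result tensor; sizes are toy values, not llama.cpp defaults):

    // Compile with or without -DLLAMA_SKIP_UNUSED_LOGITS to exercise the two
    // branches. "res" is modelled as a flat float buffer of logits.
    #include <cassert>
    #include <cstring>
    #include <vector>

    int main() {
        const int n_vocab = 8;  // toy vocabulary size
        const int N       = 3;  // tokens in the batch
        std::vector<float> logits_out(n_vocab);

    #ifdef LLAMA_SKIP_UNUSED_LOGITS
        // Skipping enabled: the graph produced logits for the last token only.
        (void) N;  // N is irrelevant here; only one token's logits exist
        std::vector<float> res(n_vocab, 1.0f);
        assert((int) res.size() == n_vocab);
        std::memcpy(logits_out.data(), res.data(), sizeof(float) * n_vocab);
    #else
        // Skipping disabled (e.g. ROCm): logits exist for all N tokens, so the
        // last token's block starts at offset n_vocab * (N - 1).
        std::vector<float> res(n_vocab * N, 1.0f);
        assert((int) res.size() == n_vocab * N);
        std::memcpy(logits_out.data(), res.data() + n_vocab * (N - 1),
                    sizeof(float) * n_vocab);
    #endif
        return 0;
    }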
