llama : fix Mamba inference for pipeline parallelism
compilade committed Mar 12, 2024
1 parent 1ac668e commit 3e06fca
Showing 1 changed file with 80 additions and 54 deletions.
134 changes: 80 additions & 54 deletions llama.cpp
@@ -2082,7 +2082,7 @@ struct llama_context {
     struct ggml_tensor * inp_mean;   // F32 [n_batch, n_batch]
     struct ggml_tensor * inp_cls;    // I32 [n_batch]
     struct ggml_tensor * inp_s_copy; // I32 [kv_size]
-    struct ggml_tensor * inp_s_mask; // F32 [kv_size]
+    struct ggml_tensor * inp_s_mask; // F32 [1, kv_size]
     struct ggml_tensor * inp_s_seq;  // I32 [kv_size, n_batch]

 #ifdef GGML_USE_MPI
@@ -5518,6 +5518,9 @@ struct llm_build_context {
         lctx.inp_K_shift = nullptr;
         lctx.inp_mean    = nullptr;
         lctx.inp_cls     = nullptr;
+        lctx.inp_s_copy  = nullptr;
+        lctx.inp_s_mask  = nullptr;
+        lctx.inp_s_seq   = nullptr;
     }

     void free() {
@@ -5559,14 +5562,14 @@ struct llm_build_context {

         GGML_ASSERT(kv_self.recurrent);

-        lctx.inp_s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, kv_self.size);
+        struct ggml_tensor * state_copy = build_inp_s_copy();

         for (int il = 0; il < n_layer; ++il) {
             struct ggml_tensor * conv_states = ggml_reshape_2d(ctx0, kv_self.k_l[il], hparams.n_embd_k_s(), kv_self.size);
             struct ggml_tensor * ssm_states  = ggml_reshape_2d(ctx0, kv_self.v_l[il], hparams.n_embd_v_s(), kv_self.size);

-            conv_states = ggml_get_rows(ctx0, conv_states, lctx.inp_s_copy);
-            ssm_states  = ggml_get_rows(ctx0, ssm_states,  lctx.inp_s_copy);
+            conv_states = ggml_get_rows(ctx0, conv_states, state_copy);
+            ssm_states  = ggml_get_rows(ctx0, ssm_states,  state_copy);

             // TODO: name the intermediate tensors with cb()

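The copy itself is expressed as a row gather: ggml_get_rows(ctx0, states, state_copy) produces a tensor whose row i is row state_copy[i] of the reshaped per-cell states. A toy standalone restatement of that gather, written against plain ggml (the function and parameter names are invented for this sketch, not llama.cpp code):

#include "ggml.h"

// Row i of the result is row copy_ids[i] of `states`, i.e. each destination
// KV cell pulls its recurrent state from the source cell named in copy_ids.
static struct ggml_tensor * gather_states(
        struct ggml_context * ctx,
        struct ggml_tensor  * states,     // F32 [d_state, n_cells], one row per KV cell
        struct ggml_tensor  * copy_ids) { // I32 [n_cells], copy_ids[i] = source cell for cell i
    return ggml_get_rows(ctx, states, copy_ids);
}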
@@ -5665,6 +5668,27 @@ struct llm_build_context {
         return lctx.inp_cls;
     }

+    struct ggml_tensor * build_inp_s_copy() {
+        lctx.inp_s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, kv_self.size);
+        cb(lctx.inp_s_copy, "inp_s_copy", -1);
+        ggml_set_input(lctx.inp_s_copy);
+        return lctx.inp_s_copy;
+    }
+
+    struct ggml_tensor * build_inp_s_mask() {
+        lctx.inp_s_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 1, n_kv);
+        cb(lctx.inp_s_mask, "inp_s_mask", -1);
+        ggml_set_input(lctx.inp_s_mask);
+        return lctx.inp_s_mask;
+    }
+
+    struct ggml_tensor * build_inp_s_seq() {
+        lctx.inp_s_seq = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, n_kv, n_tokens);
+        cb(lctx.inp_s_seq, "inp_s_seq", -1);
+        ggml_set_input(lctx.inp_s_seq);
+        return lctx.inp_s_seq;
+    }
+
     struct ggml_cgraph * build_llama() {
         struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);

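build_inp_s_mask() gives the mask the shape [1, n_kv], matching the updated comment in the struct above, which lets a single ggml_mul broadcast it across a per-cell state tensor. A hedged sketch of that use, assuming the states are laid out with one column per KV cell; the helper name is invented and the exact expression inside build_mamba() is not shown in this diff:

#include "ggml.h"

// Zero out the recurrent state of selected cells by broadcasting a
// [1, n_cells] mask over [d_state, n_cells] states: a 0.0f entry clears
// that cell's column, a 1.0f entry leaves it untouched.
static struct ggml_tensor * apply_state_mask(
        struct ggml_context * ctx,
        struct ggml_tensor  * states,  // F32 [d_state, n_cells]
        struct ggml_tensor  * mask) {  // F32 [1, n_cells]
    return ggml_mul(ctx, states, mask);
}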
@@ -8148,12 +8172,8 @@ struct llm_build_context {
         // {n_embd, n_tokens}
         inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);

-        struct ggml_tensor * state_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 1, n_kv);
-        struct ggml_tensor * state_seq  = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, n_tokens);
-        lctx.inp_s_mask = state_mask;
-        lctx.inp_s_seq  = state_seq;
-        ggml_set_input(state_mask);
-        ggml_set_input(state_seq);
+        struct ggml_tensor * state_mask = build_inp_s_mask();
+        struct ggml_tensor * state_seq  = build_inp_s_seq();

         for (int il = 0; il < n_layer; ++il) {
             // (ab)using the KV cache to store the states
@@ -8508,7 +8528,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
         ggml_backend_tensor_set(lctx.inp_embd, batch.embd, 0, n_tokens*n_embd*ggml_element_size(lctx.inp_embd));
     }

-    if (batch.pos) {
+    if (batch.pos && lctx.inp_pos) {
         const int64_t n_tokens = batch.n_tokens;

         ggml_backend_tensor_set(lctx.inp_pos, batch.pos, 0, n_tokens*ggml_element_size(lctx.inp_pos));
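The extra lctx.inp_pos check works together with the nullptr resets added earlier in this commit: every inp_* pointer starts out null when a graph is built, so a null tensor in llama_set_inputs simply means the current graph has no such input (Mamba, for instance, builds no inp_pos). A minimal sketch of that guard pattern as a hypothetical helper, not upstream code:

#include "ggml.h"
#include "ggml-backend.h"

// Write host data into a graph input only when the current graph actually
// created that input and the batch actually provides the data.
static void set_input_i32(struct ggml_tensor * dst, const int32_t * src, int64_t n) {
    if (dst == nullptr || src == nullptr) {
        return; // input not present in this graph, or not provided by the batch
    }
    GGML_ASSERT(ggml_nelements(dst) >= n);
    ggml_backend_tensor_set(dst, src, 0, n*ggml_element_size(dst));
}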
@@ -8519,61 +8539,63 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
         "non-causal attention with generative models is not supported"
     );

-    // NOTE: hparams.causal_attn indicates the model is capable of generation and uses the kv cache.
-    if (cparams.causal_attn) {
-        const int64_t n_kv     = kv_self.n;
-        const int64_t n_tokens = batch.n_tokens;
+    if (lctx.inp_KQ_mask) {
+        // NOTE: hparams.causal_attn indicates the model is capable of generation and uses the kv cache.
+        if (cparams.causal_attn) {
+            const int64_t n_kv     = kv_self.n;
+            const int64_t n_tokens = batch.n_tokens;

-        assert(ggml_backend_buffer_is_host(lctx.inp_KQ_mask->buffer));
+            assert(ggml_backend_buffer_is_host(lctx.inp_KQ_mask->buffer));

-        float * data = (float *) lctx.inp_KQ_mask->data;
+            float * data = (float *) lctx.inp_KQ_mask->data;

-        // For causal attention, use only the previous KV cells
-        // of the correct sequence for each token of the batch.
-        // It's assumed that if a token in the batch has multiple sequences, they are equivalent.
-        for (int h = 0; h < 1; ++h) {
-            for (int j = 0; j < n_tokens; ++j) {
-                const llama_pos    pos    = batch.pos[j];
-                const llama_seq_id seq_id = batch.seq_id[j][0];
+            // For causal attention, use only the previous KV cells
+            // of the correct sequence for each token of the batch.
+            // It's assumed that if a token in the batch has multiple sequences, they are equivalent.
+            for (int h = 0; h < 1; ++h) {
+                for (int j = 0; j < n_tokens; ++j) {
+                    const llama_pos    pos    = batch.pos[j];
+                    const llama_seq_id seq_id = batch.seq_id[j][0];

-                for (int i = 0; i < n_kv; ++i) {
-                    float f;
-                    if (!lctx.kv_self.cells[i].has_seq_id(seq_id) || lctx.kv_self.cells[i].pos > pos) {
-                        f = -INFINITY;
-                    } else {
-                        f = 0.0f;
-                    }
-                    data[h*(n_kv*n_tokens) + j*n_kv + i] = f;
-                }
-            }
-        }
-    } else {
-        // when using kv cache, the mask needs to match the kv cache size
-        const int64_t n_tokens = batch.n_tokens;
-        const int64_t n_stride = hparams.causal_attn ? kv_self.n : n_tokens;
+                    for (int i = 0; i < n_kv; ++i) {
+                        float f;
+                        if (!lctx.kv_self.cells[i].has_seq_id(seq_id) || lctx.kv_self.cells[i].pos > pos) {
+                            f = -INFINITY;
+                        } else {
+                            f = 0.0f;
+                        }
+                        data[h*(n_kv*n_tokens) + j*n_kv + i] = f;
+                    }
+                }
+            }
+        } else {
+            // when using kv cache, the mask needs to match the kv cache size
+            const int64_t n_tokens = batch.n_tokens;
+            const int64_t n_stride = hparams.causal_attn ? kv_self.n : n_tokens;

-        assert(ggml_backend_buffer_is_host(lctx.inp_KQ_mask->buffer));
+            assert(ggml_backend_buffer_is_host(lctx.inp_KQ_mask->buffer));

-        float * data = (float *) lctx.inp_KQ_mask->data;
+            float * data = (float *) lctx.inp_KQ_mask->data;

-        for (int h = 0; h < 1; ++h) {
-            for (int j = 0; j < n_tokens; ++j) {
-                const llama_seq_id seq_id = batch.seq_id[j][0];
+            for (int h = 0; h < 1; ++h) {
+                for (int j = 0; j < n_tokens; ++j) {
+                    const llama_seq_id seq_id = batch.seq_id[j][0];

-                for (int i = 0; i < n_tokens; ++i) {
-                    float f = -INFINITY;
-                    for (int s = 0; s < batch.n_seq_id[i]; ++s) {
-                        if (batch.seq_id[i][s] == seq_id) {
-                            f = 0.0f;
-                            break;
-                        }
-                    }
+                    for (int i = 0; i < n_tokens; ++i) {
+                        float f = -INFINITY;
+                        for (int s = 0; s < batch.n_seq_id[i]; ++s) {
+                            if (batch.seq_id[i][s] == seq_id) {
+                                f = 0.0f;
+                                break;
+                            }
+                        }

-                    data[h*(n_tokens*n_tokens) + j*n_stride + i] = f;
-                }
+                        data[h*(n_tokens*n_tokens) + j*n_stride + i] = f;
+                    }

-                for (int i = n_tokens; i < n_stride; ++i) {
-                    data[h*(n_tokens*n_tokens) + j*n_stride + i] = -INFINITY;
-                }
-            }
-        }
-    }
+                    for (int i = n_tokens; i < n_stride; ++i) {
+                        data[h*(n_tokens*n_tokens) + j*n_stride + i] = -INFINITY;
+                    }
+                }
+            }
+        }
+    }
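Wrapping the block in if (lctx.inp_KQ_mask) lets models that build no attention mask at all, such as Mamba, skip it entirely. For the mask itself, the causal branch above sets, for h == 0, data[j*n_kv + i] to 0.0f when KV cell i belongs to token j's sequence and is not in its future, and to -INFINITY otherwise. A toy standalone restatement with plain arrays in place of the KV cache (simplified to one sequence id per cell and per token, which is an assumption of this sketch):

#include <cmath>
#include <cstddef>
#include <vector>

// KV cell i is visible to batch token j iff it belongs to the same sequence
// and its position is not in the future. The layout matches
// data[h*(n_kv*n_tokens) + j*n_kv + i] with h == 0.
std::vector<float> build_causal_mask(const std::vector<int> & cell_pos,
                                     const std::vector<int> & cell_seq,
                                     const std::vector<int> & tok_pos,
                                     const std::vector<int> & tok_seq) {
    const size_t n_kv     = cell_pos.size();
    const size_t n_tokens = tok_pos.size();
    std::vector<float> mask(n_kv*n_tokens, -INFINITY);
    for (size_t j = 0; j < n_tokens; ++j) {
        for (size_t i = 0; i < n_kv; ++i) {
            if (cell_seq[i] == tok_seq[j] && cell_pos[i] <= tok_pos[j]) {
                mask[j*n_kv + i] = 0.0f;
            }
        }
    }
    return mask;
}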
@@ -9272,11 +9294,15 @@ static void llama_kv_cache_update_internal(struct llama_context & lctx) {
     }

     if (lctx.kv_self.recurrent && lctx.kv_self.do_copy) {
-        llama_set_s_copy(lctx);
-
         {
+            ggml_backend_sched_reset(lctx.sched);
+
             ggml_cgraph * gf = llama_build_graph_s_copy(lctx);

+            ggml_backend_sched_alloc_graph(lctx.sched, gf);
+
+            llama_set_s_copy(lctx);
+
             llama_graph_compute(lctx, gf, lctx.cparams.n_threads);

             need_reserve = true;
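The reordering matters because, with the backend scheduler used for pipeline parallelism, graph inputs only receive their buffers once ggml_backend_sched_alloc_graph() has run; llama_set_s_copy() therefore has to write the inp_s_copy data after allocation, not before the graph is even built, and ggml_backend_sched_reset() first releases the previous graph's allocations. A standalone ggml sketch of the same build, allocate, set-inputs, compute ordering, using the CPU backend and a plain graph allocator instead of llama.cpp's scheduler (illustrative only, not code from the commit):

#include "ggml.h"
#include "ggml-alloc.h"
#include "ggml-backend.h"

#include <stdio.h>

// Tensors are created in a no_alloc context, so they have no data until the
// graph allocator assigns buffers; writing an input before that point would
// have nowhere to go.
int main(void) {
    ggml_backend_t backend = ggml_backend_cpu_init();

    struct ggml_init_params params = {
        /*.mem_size   =*/ ggml_tensor_overhead()*8 + ggml_graph_overhead(),
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ true, // tensor data is allocated later, by the graph allocator
    };
    struct ggml_context * ctx = ggml_init(params);

    struct ggml_tensor * inp = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4);
    ggml_set_input(inp); // mark as a graph input, as build_inp_s_copy() does

    struct ggml_tensor * out = ggml_scale(ctx, inp, 2.0f);

    struct ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, out);

    // allocate the graph first ...
    ggml_gallocr_t alloc = ggml_gallocr_new(ggml_backend_get_default_buffer_type(backend));
    ggml_gallocr_alloc_graph(alloc, gf);

    // ... and only then write the input data
    const float data[4] = {1.0f, 2.0f, 3.0f, 4.0f};
    ggml_backend_tensor_set(inp, data, 0, sizeof(data));

    ggml_backend_graph_compute(backend, gf);

    float result[4];
    ggml_backend_tensor_get(out, result, 0, sizeof(result));
    printf("%.1f %.1f %.1f %.1f\n", result[0], result[1], result[2], result[3]);

    ggml_gallocr_free(alloc);
    ggml_free(ctx);
    ggml_backend_free(backend);
    return 0;
}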