21 changes: 10 additions & 11 deletions examples/save-load-state/save-load-state.cpp
@@ -45,13 +45,13 @@ int main(int argc, char ** argv) {
     // save state (rng, logits, embedding and kv_cache) to file
     {
         std::vector<uint8_t> state_mem(llama_get_state_size(ctx));
+        const size_t written = llama_copy_state_data(ctx, state_mem.data());

-        {
-            FILE *fp_write = fopen("dump_state.bin", "wb");
-            llama_copy_state_data(ctx, state_mem.data()); // could also copy directly to memory mapped file
-            fwrite(state_mem.data(), 1, state_mem.size(), fp_write);
-            fclose(fp_write);
-        }
+        FILE *fp_write = fopen("dump_state.bin", "wb");
+        fwrite(state_mem.data(), 1, written, fp_write);
+        fclose(fp_write);
+
+        fprintf(stderr, "%s : serialized state into %zd out of a maximum of %zd bytes\n", __func__, written, state_mem.size());
     }

     // save state (last tokens)
@@ -100,18 +100,17 @@ int main(int argc, char ** argv) {
         std::vector<uint8_t> state_mem(llama_get_state_size(ctx2));

         FILE * fp_read = fopen("dump_state.bin", "rb");
+        const size_t read = fread(state_mem.data(), 1, state_mem.size(), fp_read);
+        fclose(fp_read);

-        const size_t ret = fread(state_mem.data(), 1, state_mem.size(), fp_read);
-        if (ret != state_mem.size()) {
+        if (read != llama_set_state_data(ctx2, state_mem.data())) {
             fprintf(stderr, "\n%s : failed to read state\n", __func__);
             llama_free(ctx2);
             llama_free_model(model);
             return 1;
         }

-        llama_set_state_data(ctx2, state_mem.data());
-
-        fclose(fp_read);
+        fprintf(stderr, "%s : deserialized state from %zd out of a maximum of %zd bytes\n", __func__, read, state_mem.size());
     }

     // restore state (last tokens)
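Note: llama_get_state_size() now reports a worst-case size rather than the exact number of serialized bytes, so the example tracks the counts returned by llama_copy_state_data() and llama_set_state_data(). A minimal sketch of the resulting contract, assuming ctx and ctx2 are contexts built from the same model as in the example:

    std::vector<uint8_t> buf(llama_get_state_size(ctx));           // worst-case capacity
    const size_t written = llama_copy_state_data(ctx, buf.data()); // bytes actually used

    // restoring consumes exactly the bytes that were produced
    const size_t consumed = llama_set_state_data(ctx2, buf.data());
    GGML_ASSERT(consumed == written);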
51 changes: 17 additions & 34 deletions llama.cpp
@@ -9795,12 +9795,8 @@ struct llama_context * llama_new_context_with_model(
                 ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f));
     }

-    // resized during inference
-    if (params.logits_all) {
-        ctx->logits.reserve(cparams.n_ctx*hparams.n_vocab);
-    } else {
-        ctx->logits.reserve(hparams.n_vocab);
-    }
+    // resized during inference, reserve maximum
+    ctx->logits.reserve(hparams.n_vocab*cparams.n_batch);

     if (params.embedding){
         ctx->embedding.resize(hparams.n_embd);
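Note: this single bound works because one llama_decode call evaluates at most cparams.n_batch tokens, each producing at most hparams.n_vocab logits, so n_vocab*n_batch floats covers the logits_all case as well as the default single-token case.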
@@ -10149,8 +10145,8 @@ size_t llama_get_state_size(const struct llama_context * ctx) {
     // for reference, std::mt19937(1337) serializes to 6701 bytes.
     const size_t s_rng_size = sizeof(size_t);
     const size_t s_rng = LLAMA_MAX_RNG_STATE;
-    const size_t s_logits_capacity = sizeof(size_t);
+    // assume worst case for logits although only currently set ones are serialized
     const size_t s_logits_size = sizeof(size_t);
     const size_t s_logits = ctx->logits.capacity() * sizeof(float);
     const size_t s_embedding_size = sizeof(size_t);
     const size_t s_embedding = ctx->embedding.size() * sizeof(float);
@@ -10161,7 +10157,6 @@ size_t llama_get_state_size(const struct llama_context * ctx) {
     const size_t s_total = (
         + s_rng_size
         + s_rng
-        + s_logits_capacity
         + s_logits_size
         + s_logits
         + s_embedding_size
@@ -10230,37 +10225,27 @@ struct llama_data_file_context : llama_data_context {
 static void llama_copy_state_data_internal(struct llama_context * ctx, llama_data_context * data_ctx) {
     // copy rng
     {
-        std::stringstream rng_ss;
+        std::ostringstream rng_ss;
         rng_ss << ctx->rng;

-        const size_t rng_size = rng_ss.str().size();
-        char rng_buf[LLAMA_MAX_RNG_STATE];
+        const std::string & rng_str = rng_ss.str();
+        const size_t rng_size = rng_str.size();

-        memset(&rng_buf[0], 0, LLAMA_MAX_RNG_STATE);
-        memcpy(&rng_buf[0], rng_ss.str().data(), rng_ss.str().size());
+        GGML_ASSERT(rng_size <= LLAMA_MAX_RNG_STATE);

         data_ctx->write(&rng_size, sizeof(rng_size));
-        data_ctx->write(&rng_buf[0], LLAMA_MAX_RNG_STATE);
+        data_ctx->write(rng_str.data(), rng_size);
     }

     // copy logits
     {
-        const size_t logits_cap = ctx->logits.capacity();
         const size_t logits_size = ctx->logits.size();

-        data_ctx->write(&logits_cap, sizeof(logits_cap));
         data_ctx->write(&logits_size, sizeof(logits_size));

         if (logits_size) {
             data_ctx->write(ctx->logits.data(), logits_size * sizeof(float));
         }
-
-        // If there is a gap between the size and the capacity, write padding
-        size_t padding_size = (logits_cap - logits_size) * sizeof(float);
-        if (padding_size > 0) {
-            std::vector<uint8_t> padding(padding_size, 0); // Create a buffer filled with zeros
-            data_ctx->write(padding.data(), padding_size);
-        }
     }

     // copy embeddings
@@ -10370,34 +10355,32 @@ size_t llama_set_state_data(struct llama_context * ctx, uint8_t * src) {
     // set rng
     {
         size_t rng_size;
-        char rng_buf[LLAMA_MAX_RNG_STATE];
-
         memcpy(&rng_size, inp, sizeof(rng_size)); inp += sizeof(rng_size);
-        memcpy(&rng_buf[0], inp, LLAMA_MAX_RNG_STATE); inp += LLAMA_MAX_RNG_STATE;
+
+        GGML_ASSERT(rng_size <= LLAMA_MAX_RNG_STATE);

-        std::stringstream rng_ss;
-        rng_ss.str(std::string(&rng_buf[0], rng_size));
+        std::string rng_str((char *)inp, rng_size); inp += rng_size;

+        std::istringstream rng_ss(rng_str);
         rng_ss >> ctx->rng;

         GGML_ASSERT(!rng_ss.fail());
     }

     // set logits
     {
-        size_t logits_cap;
         size_t logits_size;

-        memcpy(&logits_cap, inp, sizeof(logits_cap)); inp += sizeof(logits_cap);
         memcpy(&logits_size, inp, sizeof(logits_size)); inp += sizeof(logits_size);

-        GGML_ASSERT(ctx->logits.capacity() == logits_cap);
+        GGML_ASSERT(ctx->logits.capacity() >= logits_size);

         if (logits_size) {
             ctx->logits.resize(logits_size);

             memcpy(ctx->logits.data(), inp, logits_size * sizeof(float));
             inp += logits_size * sizeof(float);
         }
-
-        inp += logits_cap * sizeof(float);
     }

     // set embeddings
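For reference, a sketch of the serialized state layout after this change, reconstructed from llama_copy_state_data_internal above (the kv cache section is unchanged by this PR):

    // size_t rng_size;                  // length of the mt19937 text state
    // char   rng[rng_size];             // no longer zero-padded to LLAMA_MAX_RNG_STATE
    // size_t logits_size;               // number of logits currently set
    // float  logits[logits_size];       // capacity field and size-to-capacity padding removed
    // size_t embedding_size;
    // float  embedding[embedding_size];
    // ...                               // kv cache data follows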
2 changes: 1 addition & 1 deletion llama.h
@@ -43,7 +43,7 @@
 #define LLAMA_FILE_MAGIC_GGSN 0x6767736eu // 'ggsn'

 #define LLAMA_SESSION_MAGIC LLAMA_FILE_MAGIC_GGSN
-#define LLAMA_SESSION_VERSION 3
+#define LLAMA_SESSION_VERSION 4

 #if defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST) || defined(GGML_USE_METAL)
 // Defined when llama.cpp is compiled with support for offloading model layers to GPU.
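The version bump makes session files written by older builds fail to load instead of deserializing garbage. A minimal sketch of a pre-flight check a caller could run before llama_load_session_file(), assuming the GGSN header starts with the 4-byte magic followed by the 4-byte version:

    #include <cstdint>
    #include <cstdio>
    #include "llama.h"

    static bool session_version_ok(const char * path) {
        FILE * f = fopen(path, "rb");
        if (!f) {
            return false;
        }
        uint32_t magic = 0, version = 0;
        const bool ok = fread(&magic,   sizeof(magic),   1, f) == 1
                     && fread(&version, sizeof(version), 1, f) == 1;
        fclose(f);
        return ok && magic == LLAMA_SESSION_MAGIC && version == LLAMA_SESSION_VERSION;
    }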