Skip to content

Commit

Permalink
Add support for using a different base model
Browse files Browse the repository at this point in the history
  • Loading branch information
slaren committed Apr 16, 2023
1 parent 57627f0 commit c150e1b
Show file tree
Hide file tree
Showing 8 changed files with 148 additions and 33 deletions.
7 changes: 7 additions & 0 deletions examples/common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,12 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
}
params.lora_adapter = argv[i];
params.use_mmap = false;
} else if (arg == "--lora-base") {
if (++i >= argc) {
invalid_param = true;
break;
}
params.lora_base = argv[i];
} else if (arg == "-i" || arg == "--interactive") {
params.interactive = true;
} else if (arg == "--embedding") {
Expand Down Expand Up @@ -250,6 +256,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
fprintf(stderr, " --mtest compute maximum memory usage\n");
fprintf(stderr, " --verbose-prompt print prompt before generation\n");
fprintf(stderr, " --lora FNAME apply LoRA adapter (implies --no-mmap)\n");
fprintf(stderr, " --lora-base FNAME optional model to use as a base for the layers modified by the LoRA adapter\n");
fprintf(stderr, " -m FNAME, --model FNAME\n");
fprintf(stderr, " model path (default: %s)\n", params.model.c_str());
fprintf(stderr, "\n");
Expand Down
1 change: 1 addition & 0 deletions examples/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ struct gpt_params {
std::vector<std::string> antiprompt; // string upon seeing which more user input is prompted

std::string lora_adapter = ""; // lora adapter path
std::string lora_base = ""; // base model path for the lora adapter

bool memory_f16 = true; // use f16 instead of f32 for memory kv
bool random_prompt = false; // do not randomize prompt if none provided
Expand Down
5 changes: 4 additions & 1 deletion examples/main/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,10 @@ int main(int argc, char ** argv) {
}

if (!params.lora_adapter.empty()) {
int err = llama_apply_lora_from_file(ctx, params.lora_adapter.c_str(), params.n_threads);
int err = llama_apply_lora_from_file(ctx,
params.lora_adapter.c_str(),
params.lora_base.empty() ? NULL : params.lora_base.c_str(),
params.n_threads);
if (err != 0) {
fprintf(stderr, "%s: error: failed to apply lora adapter\n", __func__);
return 1;
Expand Down
5 changes: 4 additions & 1 deletion examples/perplexity/perplexity.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,10 @@ int main(int argc, char ** argv) {
}

if (!params.lora_adapter.empty()) {
int err = llama_apply_lora_from_file(ctx, params.lora_adapter.c_str(), params.n_threads);
int err = llama_apply_lora_from_file(ctx,
params.lora_adapter.c_str(),
params.lora_base.empty() ? NULL : params.lora_base.c_str(),
params.n_threads);
if (err != 0) {
fprintf(stderr, "%s: error: failed to apply lora adapter\n", __func__);
return 1;
Expand Down
36 changes: 36 additions & 0 deletions ggml.c
Original file line number Diff line number Diff line change
Expand Up @@ -5461,6 +5461,27 @@ static void ggml_compute_forward_dup_f16(
}
}
}
} else if (dst->type == GGML_TYPE_Q4_0 || dst->type == GGML_TYPE_Q4_1) {
quantize_row_q_t const quantize_row_q = quantize_fns[dst->type].quantize_row_q;
size_t id = 0;
uint8_t * dst_ptr = (uint8_t *) dst->data;
size_t dst_row_size = nb0 * (ne00 / GGML_BLCK_SIZE[dst->type]);
// todo: use work buffer
float * src0_f32 = (float *) alloca(ne00 * sizeof(float));

for (int i03 = 0; i03 < ne03; i03++) {
for (int i02 = 0; i02 < ne02; i02++) {
for (int i01 = 0; i01 < ne01; i01++) {
const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
// convert to f32 and quantize
for (int i00 = 0; i00 < ne00; i00++) {
src0_f32[i00] = GGML_FP16_TO_FP32(src0_ptr[i00]);
}
quantize_row_q(src0_f32, dst_ptr + id, ne00);
id += dst_row_size;
}
}
}
} else {
GGML_ASSERT(false); // TODO: implement
}
Expand Down Expand Up @@ -5653,6 +5674,21 @@ static void ggml_compute_forward_dup_f32(
}
}
}
} else if (dst->type == GGML_TYPE_Q4_0 || dst->type == GGML_TYPE_Q4_1) {
quantize_row_q_t const quantize_row_q = quantize_fns[dst->type].quantize_row_q;
size_t id = 0;
uint8_t * dst_ptr = (uint8_t *) dst->data;
size_t dst_row_size = nb0 * (ne00 / GGML_BLCK_SIZE[dst->type]);

for (int i03 = 0; i03 < ne03; i03++) {
for (int i02 = 0; i02 < ne02; i02++) {
for (int i01 = 0; i01 < ne01; i01++) {
const float * src0_ptr = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
quantize_row_q(src0_ptr, dst_ptr + id, ne00);
id += dst_row_size;
}
}
}
} else {
GGML_ASSERT(false); // TODO: implement
}
Expand Down
92 changes: 75 additions & 17 deletions llama.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
// Defines fileno on msys:
#ifndef _GNU_SOURCE
#define _GNU_SOURCE
#include <cstdint>
#include <cstdio>
#endif

#include "llama_util.h"
Expand Down Expand Up @@ -1759,8 +1761,7 @@ int llama_model_quantize(
}
}

int llama_apply_lora_from_file(struct llama_context * ctx, const char * path_lora, int n_threads) {
// TODO: refactor all of this after PR #801
int llama_apply_lora_from_file_internal(struct llama_context * ctx, const char * path_lora, const char * path_base_model, int n_threads) {
fprintf(stderr, "%s: applying lora adapter from '%s' - please wait ...\n", __func__, path_lora);

auto & model = ctx->model;
Expand Down Expand Up @@ -1801,13 +1802,13 @@ int llama_apply_lora_from_file(struct llama_context * ctx, const char * path_lor

// create a temporary ggml context to store the lora tensors
// todo: calculate size from biggest possible tensor
std::vector<uint8_t> buf(1024ull * 1024ull * 1024ull);
std::vector<uint8_t> lora_buf(1024ull * 1024ull * 1024ull);
struct ggml_init_params params;
params.mem_size = buf.size();
params.mem_buffer = buf.data();
params.mem_size = lora_buf.size();
params.mem_buffer = lora_buf.data();
params.no_alloc = false;

ggml_context* lora_ctx = ggml_init(params);
ggml_context * lora_ctx = ggml_init(params);
std::unordered_map<std::string, struct ggml_tensor *> lora_tensors;

// create a name -> tensor map of the model to accelerate lookups
Expand All @@ -1816,6 +1817,32 @@ int llama_apply_lora_from_file(struct llama_context * ctx, const char * path_lor
model_tensors.insert(kv);
}


// load base model
std::unique_ptr<llama_model_loader> model_loader;
ggml_context * base_ctx = NULL;
llama_buffer base_buf;
if (path_base_model) {
fprintf(stderr, "%s: loading base model from '%s'\n", __func__, path_base_model);
model_loader.reset(new llama_model_loader(path_base_model, /*use_mmap*/ true, /*vocab_only*/ false));

size_t ctx_size, mmapped_size;
model_loader->calc_sizes(&ctx_size, &mmapped_size);
base_buf.resize(ctx_size);

ggml_init_params base_params;
base_params.mem_size = base_buf.size;
base_params.mem_buffer = base_buf.addr;
base_params.no_alloc = model_loader->use_mmap;

base_ctx = ggml_init(base_params);

model_loader->ggml_ctx = base_ctx;

// maybe this should in llama_model_loader
model_loader->mapping.reset(new llama_mmap(&model_loader->file_loaders.at(0)->file, false));
}

fprintf(stderr, "%s: ", __func__);

// read tensors and apply
Expand Down Expand Up @@ -1892,13 +1919,31 @@ int llama_apply_lora_from_file(struct llama_context * ctx, const char * path_lor
if (lora_tensors.find(base_name + ".loraA") != lora_tensors.end() &&
lora_tensors.find(base_name + ".loraB") != lora_tensors.end()) {

ggml_tensor * tensor = model_tensors[base_name];
ggml_tensor * dest_t = model_tensors[base_name];
ggml_tensor * base_t;
if (model_loader) {
// load from base model
if (model_loader->tensors_map.name_to_idx.find(base_name) == model_loader->tensors_map.name_to_idx.end()) {
fprintf(stderr, "%s: error: tensor '%s' not found in base model\n", __func__, base_name.c_str());
return 1;
}
size_t idx = model_loader->tensors_map.name_to_idx[base_name];
llama_load_tensor & lt = model_loader->tensors_map.tensors[idx];
base_t = model_loader->get_tensor(base_name, { (uint32_t)dest_t->ne[0], (uint32_t)dest_t->ne[1] });
lt.data = (uint8_t *) lt.ggml_tensor->data;
model_loader->load_data_for(lt);
lt.ggml_tensor->data = lt.data;
}
else {
base_t = dest_t;
}

ggml_tensor * loraA = lora_tensors[base_name + ".loraA"];
ggml_tensor * loraB = lora_tensors[base_name + ".loraB"];

if (tensor->ne[0] != loraA->ne[1] || tensor->ne[1] != loraB->ne[1]) {
if (base_t->ne[0] != loraA->ne[1] || base_t->ne[1] != loraB->ne[1]) {
fprintf(stderr, "%s: incompatible tensor dimensions (%" PRId64 " and %" PRId64 ");"
" are you sure that this adapter is for this model?\n", __func__, tensor->ne[0], loraA->ne[1]);
" are you sure that this adapter is for this model?\n", __func__, base_t->ne[0], loraA->ne[1]);
return 1;
}

Expand All @@ -1910,14 +1955,14 @@ int llama_apply_lora_from_file(struct llama_context * ctx, const char * path_lor
BA = ggml_scale(lora_ctx, BA, scale_tensor);
}

//printf("%s: (B)(%d %d %d %d) x (A)(%d %d %d %d) => (BA)(%d %d %d %d) + (T)(%d %d %d %d)\n",
// base_name.c_str(),
// (int)loraB->ne[0], (int)loraB->ne[1], (int)loraB->ne[2], (int)loraB->ne[3],
// (int)loraA->ne[0], (int)loraA->ne[1], (int)loraA->ne[2], (int)loraA->ne[3],
// (int)BA->ne[0], (int)BA->ne[1], (int)BA->ne[2], (int)BA->ne[3],
// (int)tensor->ne[0], (int)tensor->ne[1], (int)tensor->ne[2], (int)tensor->ne[3]
//);
ggml_tensor * r = ggml_add_inplace(lora_ctx, tensor, BA);
ggml_tensor * r;
if (base_t == dest_t) {
r = ggml_add_inplace(lora_ctx, dest_t, BA);
}
else {
r = ggml_add(lora_ctx, base_t, BA);
r = ggml_cpy(lora_ctx, r, dest_t);
}

struct ggml_cgraph gf = ggml_build_forward(r);
gf.n_threads = n_threads;
Expand All @@ -1934,14 +1979,27 @@ int llama_apply_lora_from_file(struct llama_context * ctx, const char * path_lor
}
}

// TODO: this should be in a destructor, it will leak on failure
ggml_free(lora_ctx);
if (base_ctx) {
ggml_free(base_ctx);
}

const int64_t t_lora_us = ggml_time_us() - t_start_lora_us;
fprintf(stderr, " done (%.2f ms)\n", t_lora_us / 1000.0);

return 0;
}

int llama_apply_lora_from_file(struct llama_context * ctx, const char * path_lora, const char * path_base_model, int n_threads) {
try {
return llama_apply_lora_from_file_internal(ctx, path_lora, path_base_model, n_threads);
} catch (const std::string & err) {
fprintf(stderr, "%s: failed to apply lora adapter: %s\n", __func__, err.c_str());
return 1;
}
}

// Returns the KV cache that will contain the context for the
// ongoing prediction with the model.
const uint8_t * llama_get_kv_cache(struct llama_context * ctx) {
Expand Down
7 changes: 5 additions & 2 deletions llama.h
Original file line number Diff line number Diff line change
Expand Up @@ -97,12 +97,15 @@ extern "C" {
enum llama_ftype ftype);

// Apply a LoRA adapter to a loaded model
// The model needs to be reloaded before applying a new adapter, otherwise
// the adapter will the applied on top of the previous one
// path_base_model is the path to a higher quality model to use as a base for
// the layers modified by the adapter. Can be NULL to use the current loaded model.
// The model needs to be reloaded before applying a new adapter, otherwise the adapter
// will be applied on top of the previous one
// Returns 0 on success
LLAMA_API int llama_apply_lora_from_file(
struct llama_context * ctx,
const char * path_lora,
const char * path_base_model,
int n_threads);

// Returns the KV cache that will contain the context for the
Expand Down
28 changes: 16 additions & 12 deletions llama_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ struct llama_mmap {
#ifdef _POSIX_MAPPED_FILES
static constexpr bool SUPPORTED = true;

llama_mmap(struct llama_file * file) {
llama_mmap(struct llama_file * file, bool prefetch = true) {
size = file->size;
int fd = fileno(file->fp);
int flags = MAP_SHARED;
Expand All @@ -181,10 +181,12 @@ struct llama_mmap {
throw format("mmap failed: %s", strerror(errno));
}

// Advise the kernel to preload the mapped memory
if (madvise(addr, file->size, MADV_WILLNEED)) {
fprintf(stderr, "warning: madvise(.., MADV_WILLNEED) failed: %s\n",
strerror(errno));
if (prefetch) {
// Advise the kernel to preload the mapped memory
if (madvise(addr, file->size, MADV_WILLNEED)) {
fprintf(stderr, "warning: madvise(.., MADV_WILLNEED) failed: %s\n",
strerror(errno));
}
}
}

Expand Down Expand Up @@ -216,13 +218,15 @@ struct llama_mmap {
}

#if _WIN32_WINNT >= _WIN32_WINNT_WIN8
// Advise the kernel to preload the mapped memory
WIN32_MEMORY_RANGE_ENTRY range;
range.VirtualAddress = addr;
range.NumberOfBytes = (SIZE_T)size;
if (!PrefetchVirtualMemory(GetCurrentProcess(), 1, &range, 0)) {
fprintf(stderr, "warning: PrefetchVirtualMemory failed: %s\n",
llama_format_win_err(GetLastError()).c_str());
if (prefetch) {
// Advise the kernel to preload the mapped memory
WIN32_MEMORY_RANGE_ENTRY range;
range.VirtualAddress = addr;
range.NumberOfBytes = (SIZE_T)size;
if (!PrefetchVirtualMemory(GetCurrentProcess(), 1, &range, 0)) {
fprintf(stderr, "warning: PrefetchVirtualMemory failed: %s\n",
llama_format_win_err(GetLastError()).c_str());
}
}
#else
#pragma message("warning: You are building for pre-Windows 8; prefetch not supported")
Expand Down

0 comments on commit c150e1b

Please sign in to comment.