Implement configurable context length #1749

Merged: 1 commit, merged on Dec 16, 2023
7 changes: 5 additions & 2 deletions gpt4all-backend/bert.cpp
@@ -714,8 +714,9 @@ Bert::~Bert() {
bert_free(d_ptr->ctx);
}

bool Bert::loadModel(const std::string &modelPath)
bool Bert::loadModel(const std::string &modelPath, int n_ctx)
{
(void)n_ctx;
d_ptr->ctx = bert_load_from_file(modelPath.c_str());
d_ptr->n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
d_ptr->modelLoaded = d_ptr->ctx != nullptr;
@@ -728,8 +729,10 @@ bool Bert::isModelLoaded() const
return d_ptr->modelLoaded;
}

size_t Bert::requiredMem(const std::string &/*modelPath*/)
size_t Bert::requiredMem(const std::string &modelPath, int n_ctx)
{
(void)modelPath;
(void)n_ctx;
return 0;
}

4 changes: 2 additions & 2 deletions gpt4all-backend/bert_impl.h
@@ -18,9 +18,9 @@ class Bert : public LLModel {

bool supportsEmbedding() const override { return true; }
bool supportsCompletion() const override { return true; }
bool loadModel(const std::string &modelPath) override;
bool loadModel(const std::string &modelPath, int n_ctx) override;
bool isModelLoaded() const override;
size_t requiredMem(const std::string &modelPath) override;
size_t requiredMem(const std::string &modelPath, int n_ctx) override;
size_t stateSize() const override;
size_t saveState(uint8_t *dest) const override;
size_t restoreState(const uint8_t *src) override;
6 changes: 4 additions & 2 deletions gpt4all-backend/gptj.cpp
@@ -676,15 +676,17 @@ GPTJ::GPTJ()
d_ptr->modelLoaded = false;
}

size_t GPTJ::requiredMem(const std::string &modelPath) {
size_t GPTJ::requiredMem(const std::string &modelPath, int n_ctx) {
(void)n_ctx;
gptj_model dummy_model;
gpt_vocab dummy_vocab;
size_t mem_req;
gptj_model_load(modelPath, dummy_model, dummy_vocab, &mem_req);
return mem_req;
}

bool GPTJ::loadModel(const std::string &modelPath) {
bool GPTJ::loadModel(const std::string &modelPath, int n_ctx) {
(void)n_ctx;
std::mt19937 rng(time(NULL));
d_ptr->rng = rng;

4 changes: 2 additions & 2 deletions gpt4all-backend/gptj_impl.h
@@ -17,9 +17,9 @@ class GPTJ : public LLModel {

bool supportsEmbedding() const override { return false; }
bool supportsCompletion() const override { return true; }
bool loadModel(const std::string &modelPath) override;
bool loadModel(const std::string &modelPath, int n_ctx) override;
bool isModelLoaded() const override;
size_t requiredMem(const std::string &modelPath) override;
size_t requiredMem(const std::string &modelPath, int n_ctx) override;
size_t stateSize() const override;
size_t saveState(uint8_t *dest) const override;
size_t restoreState(const uint8_t *src) override;
56 changes: 35 additions & 21 deletions gpt4all-backend/llamamodel.cpp
@@ -120,7 +120,8 @@ struct llama_file_hparams {
enum llama_ftype ftype = LLAMA_FTYPE_MOSTLY_F16;
};

size_t LLamaModel::requiredMem(const std::string &modelPath) {
size_t LLamaModel::requiredMem(const std::string &modelPath, int n_ctx) {
// TODO(cebtenzzre): update to GGUF
auto fin = std::ifstream(modelPath, std::ios::binary);
fin.seekg(0, std::ios_base::end);
size_t filesize = fin.tellg();
@@ -137,40 +138,31 @@ size_t LLamaModel::requiredMem(const std::string &modelPath) {
fin.read(reinterpret_cast<char*>(&hparams.n_layer), sizeof(hparams.n_layer));
fin.read(reinterpret_cast<char*>(&hparams.n_rot), sizeof(hparams.n_rot));
fin.read(reinterpret_cast<char*>(&hparams.ftype), sizeof(hparams.ftype));
const size_t n_ctx = 2048;
const size_t kvcache_element_size = 2; // fp16
const size_t est_kvcache_size = hparams.n_embd * hparams.n_layer * 2u * n_ctx * kvcache_element_size;
return filesize + est_kvcache_size;
}

bool LLamaModel::loadModel(const std::string &modelPath)
bool LLamaModel::loadModel(const std::string &modelPath, int n_ctx)
{
gpt_params params;

// load the model
if (n_ctx < 8) {
std::cerr << "warning: minimum context size is 8, using minimum size.\n";
n_ctx = 8;
}

// -- load the model --

d_ptr->model_params = llama_model_default_params();

d_ptr->model_params.use_mmap = params.use_mmap;
d_ptr->model_params.use_mmap = params.use_mmap;
#if defined (__APPLE__)
d_ptr->model_params.use_mlock = true;
d_ptr->model_params.use_mlock = true;
#else
d_ptr->model_params.use_mlock = params.use_mlock;
d_ptr->model_params.use_mlock = params.use_mlock;
#endif

d_ptr->ctx_params = llama_context_default_params();

d_ptr->ctx_params.n_ctx = 2048;
d_ptr->ctx_params.seed = params.seed;
d_ptr->ctx_params.f16_kv = params.memory_f16;

// The new batch API provides space for n_vocab*n_tokens logits. Tell llama.cpp early
// that we want this many logits so the state serializes consistently.
d_ptr->ctx_params.logits_all = true;

d_ptr->n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
d_ptr->ctx_params.n_threads = d_ptr->n_threads;
d_ptr->ctx_params.n_threads_batch = d_ptr->n_threads;

#ifdef GGML_USE_METAL
if (llama_verbose()) {
std::cerr << "llama.cpp: using Metal" << std::endl;
@@ -197,6 +189,28 @@ bool LLamaModel::loadModel(const std::string &modelPath)
return false;
}

const int n_ctx_train = llama_n_ctx_train(d_ptr->model);
if (n_ctx > n_ctx_train) {
std::cerr << "warning: model was trained on only " << n_ctx_train << " context tokens ("
<< n_ctx << " specified)\n";
}

// -- initialize the context --

d_ptr->ctx_params = llama_context_default_params();

d_ptr->ctx_params.n_ctx = n_ctx;
d_ptr->ctx_params.seed = params.seed;
d_ptr->ctx_params.f16_kv = params.memory_f16;

// The new batch API provides space for n_vocab*n_tokens logits. Tell llama.cpp early
// that we want this many logits so the state serializes consistently.
d_ptr->ctx_params.logits_all = true;

d_ptr->n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
d_ptr->ctx_params.n_threads = d_ptr->n_threads;
d_ptr->ctx_params.n_threads_batch = d_ptr->n_threads;

d_ptr->ctx = llama_new_context_with_model(d_ptr->model, d_ptr->ctx_params);
if (!d_ptr->ctx) {
#ifdef GGML_USE_KOMPUTE
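The memory estimate in requiredMem above approximates the fp16 KV cache as n_embd * n_layer * 2 (K and V) * n_ctx * 2 bytes and adds it to the model file size, so the estimate now scales with the configurable n_ctx instead of a hardcoded 2048. A minimal standalone sketch of that arithmetic, using hypothetical 7B-class hyperparameters (n_embd = 4096, n_layer = 32) purely for illustration:

#include <cstddef>
#include <cstdio>

// Rough KV-cache estimate mirroring LLamaModel::requiredMem: one K and one V
// tensor per layer, fp16 elements (2 bytes each), scaled by the context length.
static size_t est_kvcache_size(size_t n_embd, size_t n_layer, size_t n_ctx) {
    const size_t kvcache_element_size = 2; // fp16
    return n_embd * n_layer * 2u * n_ctx * kvcache_element_size;
}

int main() {
    const size_t n_embd = 4096, n_layer = 32; // hypothetical 7B-class values
    std::printf("n_ctx=2048: %zu MiB\n", est_kvcache_size(n_embd, n_layer, 2048) >> 20); // 1024 MiB
    std::printf("n_ctx=4096: %zu MiB\n", est_kvcache_size(n_embd, n_layer, 4096) >> 20); // 2048 MiB
    return 0;
}

Doubling n_ctx doubles the cache estimate, which is why the hardcoded 2048 mentioned in the TODO further down leads to underestimation whenever a larger window is requested.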
4 changes: 2 additions & 2 deletions gpt4all-backend/llamamodel_impl.h
Expand Up @@ -17,9 +17,9 @@ class LLamaModel : public LLModel {

bool supportsEmbedding() const override { return false; }
bool supportsCompletion() const override { return true; }
bool loadModel(const std::string &modelPath) override;
bool loadModel(const std::string &modelPath, int n_ctx) override;
bool isModelLoaded() const override;
size_t requiredMem(const std::string &modelPath) override;
size_t requiredMem(const std::string &modelPath, int n_ctx) override;
size_t stateSize() const override;
size_t saveState(uint8_t *dest) const override;
size_t restoreState(const uint8_t *src) override;
10 changes: 8 additions & 2 deletions gpt4all-backend/llmodel.cpp
@@ -138,7 +138,7 @@ const LLModel::Implementation* LLModel::Implementation::implementation(const cha
return nullptr;
}

LLModel *LLModel::Implementation::construct(const std::string &modelPath, std::string buildVariant) {
LLModel *LLModel::Implementation::construct(const std::string &modelPath, std::string buildVariant, int n_ctx) {
if (!has_at_least_minimal_hardware()) {
std::cerr << "LLModel ERROR: CPU does not support AVX\n";
return nullptr;
@@ -154,7 +154,11 @@ LLModel *LLModel::Implementation::construct(const std::string &modelPath, std::s
if(impl) {
LLModel* metalimpl = impl->m_construct();
metalimpl->m_implementation = impl;
size_t req_mem = metalimpl->requiredMem(modelPath);
/* TODO(cebtenzzre): after we fix requiredMem, we should change this to happen at
* load time, not construct time. right now n_ctx is incorrectly hardcoded 2048 in
* most (all?) places where this is called, causing underestimation of required
* memory. */
Comment on lines +157 to +160

Member Author:

@apage43 Do you think it would be relatively easy to switch this to a load-time check instead of a construct-time one? It doesn't matter so much right now since it's not working anyway (unresolved fallout from the switch to GGUF).

Member:

The reason it's construct-time is so that we do the fallback to CPU transparently: callers of construct passing "auto" just get the CPU implementation if the mem req is too high for Metal.

If it's changed to fail at load time, callers will have to handle that fallback themselves - which is likely fine, but would need to be done in all the bindings.

Member Author:

The chat UI is already doing load-time fallback for Vulkan. And this is really the only way to do it, because it's the user code that decides which GPU to use, which is of course initialized after a backend/implementation is available. We should make sure the bindings are capable of this too.

I think it would make sense to only ever dlopen one build of llamamodel-mainline on Apple silicon, as there's nothing we are currently doing that the Metal build isn't capable of.

size_t req_mem = metalimpl->requiredMem(modelPath, n_ctx);
float req_to_total = (float) req_mem / (float) total_mem;
// on a 16GB M2 Mac a 13B q4_0 (0.52) works for me but a 13B q4_K_M (0.55) does not
if (req_to_total >= 0.53) {
@@ -165,6 +169,8 @@ LLModel *LLModel::Implementation::construct(const std::string &modelPath, std::s
}
}
}
#else
(void)n_ctx;
#endif

if (!impl) {
6 changes: 3 additions & 3 deletions gpt4all-backend/llmodel.h
@@ -37,7 +37,7 @@ class LLModel {
static bool isImplementation(const Dlhandle&);
static const std::vector<Implementation>& implementationList();
static const Implementation *implementation(const char *fname, const std::string& buildVariant);
static LLModel *construct(const std::string &modelPath, std::string buildVariant = "auto");
static LLModel *construct(const std::string &modelPath, std::string buildVariant = "auto", int n_ctx = 2048);
static std::vector<GPUDevice> availableGPUDevices();
static void setImplementationsSearchPath(const std::string& path);
static const std::string& implementationsSearchPath();
@@ -74,9 +74,9 @@ class LLModel {

virtual bool supportsEmbedding() const = 0;
virtual bool supportsCompletion() const = 0;
virtual bool loadModel(const std::string &modelPath) = 0;
virtual bool loadModel(const std::string &modelPath, int n_ctx) = 0;
virtual bool isModelLoaded() const = 0;
virtual size_t requiredMem(const std::string &modelPath) = 0;
virtual size_t requiredMem(const std::string &modelPath, int n_ctx) = 0;
virtual size_t stateSize() const { return 0; }
virtual size_t saveState(uint8_t */*dest*/) const { return 0; }
virtual size_t restoreState(const uint8_t */*src*/) { return 0; }
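A minimal caller-side sketch of the updated C++ interface, assuming only the declarations shown in this header; the helper name and its default n_ctx value are illustrative, and error reporting is elided:

#include "llmodel.h"

#include <string>

// Construct an implementation and load the model with a caller-chosen context
// window. construct() receives n_ctx for its construct-time Metal memory check,
// and loadModel() receives it again to size the inference context.
LLModel *load_with_context(const std::string &modelPath, int n_ctx = 4096) {
    LLModel *model = LLModel::Implementation::construct(modelPath, "auto", n_ctx);
    if (!model)
        return nullptr;
    if (!model->loadModel(modelPath, n_ctx)) {
        delete model;
        return nullptr;
    }
    return model;
}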
8 changes: 4 additions & 4 deletions gpt4all-backend/llmodel_c.cpp
@@ -47,16 +47,16 @@ void llmodel_model_destroy(llmodel_model model) {
delete reinterpret_cast<LLModelWrapper*>(model);
}

size_t llmodel_required_mem(llmodel_model model, const char *model_path)
size_t llmodel_required_mem(llmodel_model model, const char *model_path, int n_ctx)
{
LLModelWrapper *wrapper = reinterpret_cast<LLModelWrapper*>(model);
return wrapper->llModel->requiredMem(model_path);
return wrapper->llModel->requiredMem(model_path, n_ctx);
}

bool llmodel_loadModel(llmodel_model model, const char *model_path)
bool llmodel_loadModel(llmodel_model model, const char *model_path, int n_ctx)
{
LLModelWrapper *wrapper = reinterpret_cast<LLModelWrapper*>(model);
return wrapper->llModel->loadModel(model_path);
return wrapper->llModel->loadModel(model_path, n_ctx);
}

bool llmodel_isModelLoaded(llmodel_model model)
6 changes: 4 additions & 2 deletions gpt4all-backend/llmodel_c.h
@@ -110,17 +110,19 @@ void llmodel_model_destroy(llmodel_model model);
* Estimate RAM requirement for a model file
* @param model A pointer to the llmodel_model instance.
* @param model_path A string representing the path to the model file.
* @param n_ctx Maximum size of context window
* @return size greater than 0 if the model was parsed successfully, 0 if file could not be parsed.
*/
size_t llmodel_required_mem(llmodel_model model, const char *model_path);
size_t llmodel_required_mem(llmodel_model model, const char *model_path, int n_ctx);

/**
* Load a model from a file.
* @param model A pointer to the llmodel_model instance.
* @param model_path A string representing the path to the model file.
* @param n_ctx Maximum size of context window
* @return true if the model was loaded successfully, false otherwise.
*/
bool llmodel_loadModel(llmodel_model model, const char *model_path);
bool llmodel_loadModel(llmodel_model model, const char *model_path, int n_ctx);

/**
* Check if a model is loaded.
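The same flow through the C API, as a minimal sketch that uses only the functions visible in this diff; the llmodel_model handle is assumed to have been created elsewhere (for example via llmodel_model_create2, as the bindings below do), and the example is written as C++ for consistency with the sketch above:

#include "llmodel_c.h"

#include <cstdio>

// Report the RAM estimate for the requested context window, then load with it.
bool try_load(llmodel_model model, const char *model_path, int n_ctx) {
    size_t req_mem = llmodel_required_mem(model, model_path, n_ctx);
    std::fprintf(stderr, "estimated RAM for n_ctx=%d: %zu bytes\n", n_ctx, req_mem);
    if (!llmodel_loadModel(model, model_path, n_ctx))
        return false;
    return llmodel_isModelLoaded(model);
}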
2 changes: 1 addition & 1 deletion gpt4all-bindings/csharp/Gpt4All/Bindings/LLModel.cs
@@ -188,7 +188,7 @@ public bool IsLoaded()
/// <returns>true if the model was loaded successfully, false otherwise.</returns>
public bool Load(string modelPath)
{
return NativeMethods.llmodel_loadModel(_handle, modelPath);
return NativeMethods.llmodel_loadModel(_handle, modelPath, 2048);
}

protected void Destroy()
3 changes: 2 additions & 1 deletion gpt4all-bindings/csharp/Gpt4All/Bindings/NativeMethods.cs
@@ -70,7 +70,8 @@ internal static unsafe partial class NativeMethods
[return: MarshalAs(UnmanagedType.I1)]
public static extern bool llmodel_loadModel(
[NativeTypeName("llmodel_model")] IntPtr model,
[NativeTypeName("const char *")][MarshalAs(UnmanagedType.LPUTF8Str)] string model_path);
[NativeTypeName("const char *")][MarshalAs(UnmanagedType.LPUTF8Str)] string model_path,
[NativeTypeName("int32_t")] int n_ctx);

[DllImport("libllmodel", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true)]

@@ -39,7 +39,7 @@ private IGpt4AllModel CreateModel(string modelPath)
var handle = NativeMethods.llmodel_model_create2(modelPath, "auto", out error);
_logger.LogDebug("Model created handle=0x{ModelHandle:X8}", handle);
_logger.LogInformation("Model loading started");
var loadedSuccessfully = NativeMethods.llmodel_loadModel(handle, modelPath);
var loadedSuccessfully = NativeMethods.llmodel_loadModel(handle, modelPath, 2048);
_logger.LogInformation("Model loading completed success={ModelLoadSuccess}", loadedSuccessfully);
if (!loadedSuccessfully)
{
2 changes: 1 addition & 1 deletion gpt4all-bindings/golang/binding.cpp
@@ -23,7 +23,7 @@ void* load_model(const char *fname, int n_threads) {
fprintf(stderr, "%s: error '%s'\n", __func__, new_error);
return nullptr;
}
if (!llmodel_loadModel(model, fname)) {
if (!llmodel_loadModel(model, fname, 2048)) {
llmodel_model_destroy(model);
return nullptr;
}
@@ -195,7 +195,7 @@ public LLModel(Path modelPath) {
if(model == null) {
throw new IllegalStateException("Could not load, gpt4all backend returned error: " + error.getValue().getString(0));
}
library.llmodel_loadModel(model, modelPathAbs);
library.llmodel_loadModel(model, modelPathAbs, 2048);

if(!library.llmodel_isModelLoaded(model)){
throw new IllegalStateException("The model " + modelName + " could not be loaded");
@@ -61,7 +61,7 @@ public LLModelPromptContext(jnr.ffi.Runtime runtime) {

Pointer llmodel_model_create2(String model_path, String build_variant, PointerByReference error);
void llmodel_model_destroy(Pointer model);
boolean llmodel_loadModel(Pointer model, String model_path);
boolean llmodel_loadModel(Pointer model, String model_path, int n_ctx);
boolean llmodel_isModelLoaded(Pointer model);
@u_int64_t long llmodel_get_state_size(Pointer model);
@u_int64_t long llmodel_save_state_data(Pointer model, Pointer dest);
4 changes: 2 additions & 2 deletions gpt4all-bindings/python/gpt4all/__init__.py
@@ -1,2 +1,2 @@
from .gpt4all import Embed4All, GPT4All # noqa
from .pyllmodel import LLModel # noqa
from .gpt4all import Embed4All as Embed4All, GPT4All as GPT4All
from .pyllmodel import LLModel as LLModel
10 changes: 6 additions & 4 deletions gpt4all-bindings/python/gpt4all/gpt4all.py
@@ -69,6 +69,7 @@ def __init__(
allow_download: bool = True,
n_threads: Optional[int] = None,
device: Optional[str] = "cpu",
n_ctx: int = 2048,
verbose: bool = False,
):
"""
@@ -90,15 +91,16 @@ def __init__(
Default is "cpu".

Note: If a selected GPU device does not have sufficient RAM to accommodate the model, an error will be thrown, and the GPT4All instance will be rendered invalid. It's advised to ensure the device has enough memory before initiating the model.
n_ctx: Maximum size of context window
verbose: If True, print debug messages.
"""
self.model_type = model_type
self.model = pyllmodel.LLModel()
# Retrieve model and download if allowed
self.config: ConfigType = self.retrieve_model(model_name, model_path=model_path, allow_download=allow_download, verbose=verbose)
if device is not None:
if device != "cpu":
self.model.init_gpu(model_path=self.config["path"], device=device)
self.model.load_model(self.config["path"])
if device is not None and device != "cpu":
self.model.init_gpu(model_path=self.config["path"], device=device, n_ctx=n_ctx)
self.model.load_model(self.config["path"], n_ctx)
# Set n_threads
if n_threads is not None:
self.model.set_thread_count(n_threads)