From 358d6196ecbb18d8d282858c7c521f96ecc997cc Mon Sep 17 00:00:00 2001 From: Jared Van Bortel Date: Wed, 13 Dec 2023 16:23:11 -0500 Subject: [PATCH] implement configurable context length --- gpt4all-backend/bert.cpp | 7 +- gpt4all-backend/bert_impl.h | 4 +- gpt4all-backend/gptj.cpp | 6 +- gpt4all-backend/gptj_impl.h | 4 +- gpt4all-backend/llamamodel.cpp | 56 ++++++----- gpt4all-backend/llamamodel_impl.h | 4 +- gpt4all-backend/llmodel.cpp | 10 +- gpt4all-backend/llmodel.h | 6 +- gpt4all-backend/llmodel_c.cpp | 8 +- gpt4all-backend/llmodel_c.h | 6 +- .../csharp/Gpt4All/Bindings/LLModel.cs | 2 +- .../csharp/Gpt4All/Bindings/NativeMethods.cs | 3 +- .../Gpt4All/Model/Gpt4AllModelFactory.cs | 2 +- gpt4all-bindings/golang/binding.cpp | 2 +- .../java/com/hexadevlabs/gpt4all/LLModel.java | 2 +- .../hexadevlabs/gpt4all/LLModelLibrary.java | 2 +- gpt4all-bindings/python/gpt4all/__init__.py | 4 +- gpt4all-bindings/python/gpt4all/gpt4all.py | 10 +- gpt4all-bindings/python/gpt4all/pyllmodel.py | 89 +++++++++--------- .../python/gpt4all/tests/test_gpt4all.py | 2 +- gpt4all-bindings/typescript/index.cc | 4 +- gpt4all-chat/chatgpt.cpp | 6 +- gpt4all-chat/chatgpt.h | 4 +- gpt4all-chat/chatlistmodel.cpp | 2 +- gpt4all-chat/chatllm.cpp | 29 ++++-- gpt4all-chat/embllm.cpp | 4 +- gpt4all-chat/modellist.cpp | 20 ++++ gpt4all-chat/modellist.h | 6 ++ gpt4all-chat/mysettings.cpp | 23 +++++ gpt4all-chat/mysettings.h | 7 ++ gpt4all-chat/qml/ModelSettings.qml | 92 ++++++++++++++----- 31 files changed, 291 insertions(+), 135 deletions(-) diff --git a/gpt4all-backend/bert.cpp b/gpt4all-backend/bert.cpp index f9e16cb459eb..2424d72c61eb 100644 --- a/gpt4all-backend/bert.cpp +++ b/gpt4all-backend/bert.cpp @@ -714,8 +714,9 @@ Bert::~Bert() { bert_free(d_ptr->ctx); } -bool Bert::loadModel(const std::string &modelPath) +bool Bert::loadModel(const std::string &modelPath, int n_ctx) { + (void)n_ctx; d_ptr->ctx = bert_load_from_file(modelPath.c_str()); d_ptr->n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency()); d_ptr->modelLoaded = d_ptr->ctx != nullptr; @@ -728,8 +729,10 @@ bool Bert::isModelLoaded() const return d_ptr->modelLoaded; } -size_t Bert::requiredMem(const std::string &/*modelPath*/) +size_t Bert::requiredMem(const std::string &modelPath, int n_ctx) { + (void)modelPath; + (void)n_ctx; return 0; } diff --git a/gpt4all-backend/bert_impl.h b/gpt4all-backend/bert_impl.h index d1cc99f4ace4..b39e77e58d9f 100644 --- a/gpt4all-backend/bert_impl.h +++ b/gpt4all-backend/bert_impl.h @@ -18,9 +18,9 @@ class Bert : public LLModel { bool supportsEmbedding() const override { return true; } bool supportsCompletion() const override { return true; } - bool loadModel(const std::string &modelPath) override; + bool loadModel(const std::string &modelPath, int n_ctx) override; bool isModelLoaded() const override; - size_t requiredMem(const std::string &modelPath) override; + size_t requiredMem(const std::string &modelPath, int n_ctx) override; size_t stateSize() const override; size_t saveState(uint8_t *dest) const override; size_t restoreState(const uint8_t *src) override; diff --git a/gpt4all-backend/gptj.cpp b/gpt4all-backend/gptj.cpp index 7825e6bc8e35..074ef5dcb1c8 100644 --- a/gpt4all-backend/gptj.cpp +++ b/gpt4all-backend/gptj.cpp @@ -676,7 +676,8 @@ GPTJ::GPTJ() d_ptr->modelLoaded = false; } -size_t GPTJ::requiredMem(const std::string &modelPath) { +size_t GPTJ::requiredMem(const std::string &modelPath, int n_ctx) { + (void)n_ctx; gptj_model dummy_model; gpt_vocab dummy_vocab; size_t mem_req; @@ -684,7 +685,8 @@ size_t GPTJ::requiredMem(const std::string &modelPath) { return mem_req; } -bool GPTJ::loadModel(const std::string &modelPath) { +bool GPTJ::loadModel(const std::string &modelPath, int n_ctx) { + (void)n_ctx; std::mt19937 rng(time(NULL)); d_ptr->rng = rng; diff --git a/gpt4all-backend/gptj_impl.h b/gpt4all-backend/gptj_impl.h index e2b1826e2c9a..c2100b24c476 100644 --- a/gpt4all-backend/gptj_impl.h +++ b/gpt4all-backend/gptj_impl.h @@ -17,9 +17,9 @@ class GPTJ : public LLModel { bool supportsEmbedding() const override { return false; } bool supportsCompletion() const override { return true; } - bool loadModel(const std::string &modelPath) override; + bool loadModel(const std::string &modelPath, int n_ctx) override; bool isModelLoaded() const override; - size_t requiredMem(const std::string &modelPath) override; + size_t requiredMem(const std::string &modelPath, int n_ctx) override; size_t stateSize() const override; size_t saveState(uint8_t *dest) const override; size_t restoreState(const uint8_t *src) override; diff --git a/gpt4all-backend/llamamodel.cpp b/gpt4all-backend/llamamodel.cpp index 882674e3c46d..cc566b432f79 100644 --- a/gpt4all-backend/llamamodel.cpp +++ b/gpt4all-backend/llamamodel.cpp @@ -120,7 +120,8 @@ struct llama_file_hparams { enum llama_ftype ftype = LLAMA_FTYPE_MOSTLY_F16; }; -size_t LLamaModel::requiredMem(const std::string &modelPath) { +size_t LLamaModel::requiredMem(const std::string &modelPath, int n_ctx) { + // TODO(cebtenzzre): update to GGUF auto fin = std::ifstream(modelPath, std::ios::binary); fin.seekg(0, std::ios_base::end); size_t filesize = fin.tellg(); @@ -137,40 +138,31 @@ size_t LLamaModel::requiredMem(const std::string &modelPath) { fin.read(reinterpret_cast(&hparams.n_layer), sizeof(hparams.n_layer)); fin.read(reinterpret_cast(&hparams.n_rot), sizeof(hparams.n_rot)); fin.read(reinterpret_cast(&hparams.ftype), sizeof(hparams.ftype)); - const size_t n_ctx = 2048; const size_t kvcache_element_size = 2; // fp16 const size_t est_kvcache_size = hparams.n_embd * hparams.n_layer * 2u * n_ctx * kvcache_element_size; return filesize + est_kvcache_size; } -bool LLamaModel::loadModel(const std::string &modelPath) +bool LLamaModel::loadModel(const std::string &modelPath, int n_ctx) { gpt_params params; - // load the model + if (n_ctx < 8) { + std::cerr << "warning: minimum context size is 8, using minimum size.\n"; + n_ctx = 8; + } + + // -- load the model -- + d_ptr->model_params = llama_model_default_params(); - d_ptr->model_params.use_mmap = params.use_mmap; + d_ptr->model_params.use_mmap = params.use_mmap; #if defined (__APPLE__) - d_ptr->model_params.use_mlock = true; + d_ptr->model_params.use_mlock = true; #else - d_ptr->model_params.use_mlock = params.use_mlock; + d_ptr->model_params.use_mlock = params.use_mlock; #endif - d_ptr->ctx_params = llama_context_default_params(); - - d_ptr->ctx_params.n_ctx = 2048; - d_ptr->ctx_params.seed = params.seed; - d_ptr->ctx_params.f16_kv = params.memory_f16; - - // The new batch API provides space for n_vocab*n_tokens logits. Tell llama.cpp early - // that we want this many logits so the state serializes consistently. - d_ptr->ctx_params.logits_all = true; - - d_ptr->n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency()); - d_ptr->ctx_params.n_threads = d_ptr->n_threads; - d_ptr->ctx_params.n_threads_batch = d_ptr->n_threads; - #ifdef GGML_USE_METAL if (llama_verbose()) { std::cerr << "llama.cpp: using Metal" << std::endl; @@ -197,6 +189,28 @@ bool LLamaModel::loadModel(const std::string &modelPath) return false; } + const int n_ctx_train = llama_n_ctx_train(d_ptr->model); + if (n_ctx > n_ctx_train) { + std::cerr << "warning: model was trained on only " << n_ctx_train << " context tokens (" + << n_ctx << " specified)\n"; + } + + // -- initialize the context -- + + d_ptr->ctx_params = llama_context_default_params(); + + d_ptr->ctx_params.n_ctx = n_ctx; + d_ptr->ctx_params.seed = params.seed; + d_ptr->ctx_params.f16_kv = params.memory_f16; + + // The new batch API provides space for n_vocab*n_tokens logits. Tell llama.cpp early + // that we want this many logits so the state serializes consistently. + d_ptr->ctx_params.logits_all = true; + + d_ptr->n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency()); + d_ptr->ctx_params.n_threads = d_ptr->n_threads; + d_ptr->ctx_params.n_threads_batch = d_ptr->n_threads; + d_ptr->ctx = llama_new_context_with_model(d_ptr->model, d_ptr->ctx_params); if (!d_ptr->ctx) { #ifdef GGML_USE_KOMPUTE diff --git a/gpt4all-backend/llamamodel_impl.h b/gpt4all-backend/llamamodel_impl.h index d708ddac1821..c32b2413c90a 100644 --- a/gpt4all-backend/llamamodel_impl.h +++ b/gpt4all-backend/llamamodel_impl.h @@ -17,9 +17,9 @@ class LLamaModel : public LLModel { bool supportsEmbedding() const override { return false; } bool supportsCompletion() const override { return true; } - bool loadModel(const std::string &modelPath) override; + bool loadModel(const std::string &modelPath, int n_ctx) override; bool isModelLoaded() const override; - size_t requiredMem(const std::string &modelPath) override; + size_t requiredMem(const std::string &modelPath, int n_ctx) override; size_t stateSize() const override; size_t saveState(uint8_t *dest) const override; size_t restoreState(const uint8_t *src) override; diff --git a/gpt4all-backend/llmodel.cpp b/gpt4all-backend/llmodel.cpp index cb7cfd865b23..2431129f6969 100644 --- a/gpt4all-backend/llmodel.cpp +++ b/gpt4all-backend/llmodel.cpp @@ -138,7 +138,7 @@ const LLModel::Implementation* LLModel::Implementation::implementation(const cha return nullptr; } -LLModel *LLModel::Implementation::construct(const std::string &modelPath, std::string buildVariant) { +LLModel *LLModel::Implementation::construct(const std::string &modelPath, std::string buildVariant, int n_ctx) { if (!has_at_least_minimal_hardware()) { std::cerr << "LLModel ERROR: CPU does not support AVX\n"; return nullptr; @@ -154,7 +154,11 @@ LLModel *LLModel::Implementation::construct(const std::string &modelPath, std::s if(impl) { LLModel* metalimpl = impl->m_construct(); metalimpl->m_implementation = impl; - size_t req_mem = metalimpl->requiredMem(modelPath); + /* TODO(cebtenzzre): after we fix requiredMem, we should change this to happen at + * load time, not construct time. right now n_ctx is incorrectly hardcoded 2048 in + * most (all?) places where this is called, causing underestimation of required + * memory. */ + size_t req_mem = metalimpl->requiredMem(modelPath, n_ctx); float req_to_total = (float) req_mem / (float) total_mem; // on a 16GB M2 Mac a 13B q4_0 (0.52) works for me but a 13B q4_K_M (0.55) does not if (req_to_total >= 0.53) { @@ -165,6 +169,8 @@ LLModel *LLModel::Implementation::construct(const std::string &modelPath, std::s } } } + #else + (void)n_ctx; #endif if (!impl) { diff --git a/gpt4all-backend/llmodel.h b/gpt4all-backend/llmodel.h index f11c4c74cbc2..f1551cfb1f77 100644 --- a/gpt4all-backend/llmodel.h +++ b/gpt4all-backend/llmodel.h @@ -37,7 +37,7 @@ class LLModel { static bool isImplementation(const Dlhandle&); static const std::vector& implementationList(); static const Implementation *implementation(const char *fname, const std::string& buildVariant); - static LLModel *construct(const std::string &modelPath, std::string buildVariant = "auto"); + static LLModel *construct(const std::string &modelPath, std::string buildVariant = "auto", int n_ctx = 2048); static std::vector availableGPUDevices(); static void setImplementationsSearchPath(const std::string& path); static const std::string& implementationsSearchPath(); @@ -74,9 +74,9 @@ class LLModel { virtual bool supportsEmbedding() const = 0; virtual bool supportsCompletion() const = 0; - virtual bool loadModel(const std::string &modelPath) = 0; + virtual bool loadModel(const std::string &modelPath, int n_ctx) = 0; virtual bool isModelLoaded() const = 0; - virtual size_t requiredMem(const std::string &modelPath) = 0; + virtual size_t requiredMem(const std::string &modelPath, int n_ctx) = 0; virtual size_t stateSize() const { return 0; } virtual size_t saveState(uint8_t */*dest*/) const { return 0; } virtual size_t restoreState(const uint8_t */*src*/) { return 0; } diff --git a/gpt4all-backend/llmodel_c.cpp b/gpt4all-backend/llmodel_c.cpp index 38b03ea09d89..c8af2ca3d49f 100644 --- a/gpt4all-backend/llmodel_c.cpp +++ b/gpt4all-backend/llmodel_c.cpp @@ -47,16 +47,16 @@ void llmodel_model_destroy(llmodel_model model) { delete reinterpret_cast(model); } -size_t llmodel_required_mem(llmodel_model model, const char *model_path) +size_t llmodel_required_mem(llmodel_model model, const char *model_path, int n_ctx) { LLModelWrapper *wrapper = reinterpret_cast(model); - return wrapper->llModel->requiredMem(model_path); + return wrapper->llModel->requiredMem(model_path, n_ctx); } -bool llmodel_loadModel(llmodel_model model, const char *model_path) +bool llmodel_loadModel(llmodel_model model, const char *model_path, int n_ctx) { LLModelWrapper *wrapper = reinterpret_cast(model); - return wrapper->llModel->loadModel(model_path); + return wrapper->llModel->loadModel(model_path, n_ctx); } bool llmodel_isModelLoaded(llmodel_model model) diff --git a/gpt4all-backend/llmodel_c.h b/gpt4all-backend/llmodel_c.h index e9b370c22bbb..dcd53f2ec788 100644 --- a/gpt4all-backend/llmodel_c.h +++ b/gpt4all-backend/llmodel_c.h @@ -110,17 +110,19 @@ void llmodel_model_destroy(llmodel_model model); * Estimate RAM requirement for a model file * @param model A pointer to the llmodel_model instance. * @param model_path A string representing the path to the model file. + * @param n_ctx Maximum size of context window * @return size greater than 0 if the model was parsed successfully, 0 if file could not be parsed. */ -size_t llmodel_required_mem(llmodel_model model, const char *model_path); +size_t llmodel_required_mem(llmodel_model model, const char *model_path, int n_ctx); /** * Load a model from a file. * @param model A pointer to the llmodel_model instance. * @param model_path A string representing the path to the model file. + * @param n_ctx Maximum size of context window * @return true if the model was loaded successfully, false otherwise. */ -bool llmodel_loadModel(llmodel_model model, const char *model_path); +bool llmodel_loadModel(llmodel_model model, const char *model_path, int n_ctx); /** * Check if a model is loaded. diff --git a/gpt4all-bindings/csharp/Gpt4All/Bindings/LLModel.cs b/gpt4all-bindings/csharp/Gpt4All/Bindings/LLModel.cs index 55defe0921be..583380cb8b95 100644 --- a/gpt4all-bindings/csharp/Gpt4All/Bindings/LLModel.cs +++ b/gpt4all-bindings/csharp/Gpt4All/Bindings/LLModel.cs @@ -188,7 +188,7 @@ public bool IsLoaded() /// true if the model was loaded successfully, false otherwise. public bool Load(string modelPath) { - return NativeMethods.llmodel_loadModel(_handle, modelPath); + return NativeMethods.llmodel_loadModel(_handle, modelPath, 2048); } protected void Destroy() diff --git a/gpt4all-bindings/csharp/Gpt4All/Bindings/NativeMethods.cs b/gpt4all-bindings/csharp/Gpt4All/Bindings/NativeMethods.cs index cc43f3662aab..c6ea9e111920 100644 --- a/gpt4all-bindings/csharp/Gpt4All/Bindings/NativeMethods.cs +++ b/gpt4all-bindings/csharp/Gpt4All/Bindings/NativeMethods.cs @@ -70,7 +70,8 @@ public static extern IntPtr llmodel_model_create2( [return: MarshalAs(UnmanagedType.I1)] public static extern bool llmodel_loadModel( [NativeTypeName("llmodel_model")] IntPtr model, - [NativeTypeName("const char *")][MarshalAs(UnmanagedType.LPUTF8Str)] string model_path); + [NativeTypeName("const char *")][MarshalAs(UnmanagedType.LPUTF8Str)] string model_path, + [NativeTypeName("int32_t")] int n_ctx); [DllImport("libllmodel", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true)] diff --git a/gpt4all-bindings/csharp/Gpt4All/Model/Gpt4AllModelFactory.cs b/gpt4all-bindings/csharp/Gpt4All/Model/Gpt4AllModelFactory.cs index 02c5c5888d21..9e81e5af33b9 100644 --- a/gpt4all-bindings/csharp/Gpt4All/Model/Gpt4AllModelFactory.cs +++ b/gpt4all-bindings/csharp/Gpt4All/Model/Gpt4AllModelFactory.cs @@ -39,7 +39,7 @@ private IGpt4AllModel CreateModel(string modelPath) var handle = NativeMethods.llmodel_model_create2(modelPath, "auto", out error); _logger.LogDebug("Model created handle=0x{ModelHandle:X8}", handle); _logger.LogInformation("Model loading started"); - var loadedSuccessfully = NativeMethods.llmodel_loadModel(handle, modelPath); + var loadedSuccessfully = NativeMethods.llmodel_loadModel(handle, modelPath, 2048); _logger.LogInformation("Model loading completed success={ModelLoadSuccess}", loadedSuccessfully); if (!loadedSuccessfully) { diff --git a/gpt4all-bindings/golang/binding.cpp b/gpt4all-bindings/golang/binding.cpp index 253282e5ed8e..1ccb9e61f134 100644 --- a/gpt4all-bindings/golang/binding.cpp +++ b/gpt4all-bindings/golang/binding.cpp @@ -23,7 +23,7 @@ void* load_model(const char *fname, int n_threads) { fprintf(stderr, "%s: error '%s'\n", __func__, new_error); return nullptr; } - if (!llmodel_loadModel(model, fname)) { + if (!llmodel_loadModel(model, fname, 2048)) { llmodel_model_destroy(model); return nullptr; } diff --git a/gpt4all-bindings/java/src/main/java/com/hexadevlabs/gpt4all/LLModel.java b/gpt4all-bindings/java/src/main/java/com/hexadevlabs/gpt4all/LLModel.java index b5a7fca2b0c3..6c0d053ee0e2 100644 --- a/gpt4all-bindings/java/src/main/java/com/hexadevlabs/gpt4all/LLModel.java +++ b/gpt4all-bindings/java/src/main/java/com/hexadevlabs/gpt4all/LLModel.java @@ -195,7 +195,7 @@ public LLModel(Path modelPath) { if(model == null) { throw new IllegalStateException("Could not load, gpt4all backend returned error: " + error.getValue().getString(0)); } - library.llmodel_loadModel(model, modelPathAbs); + library.llmodel_loadModel(model, modelPathAbs, 2048); if(!library.llmodel_isModelLoaded(model)){ throw new IllegalStateException("The model " + modelName + " could not be loaded"); diff --git a/gpt4all-bindings/java/src/main/java/com/hexadevlabs/gpt4all/LLModelLibrary.java b/gpt4all-bindings/java/src/main/java/com/hexadevlabs/gpt4all/LLModelLibrary.java index 42dde345de18..b2d48e34d297 100644 --- a/gpt4all-bindings/java/src/main/java/com/hexadevlabs/gpt4all/LLModelLibrary.java +++ b/gpt4all-bindings/java/src/main/java/com/hexadevlabs/gpt4all/LLModelLibrary.java @@ -61,7 +61,7 @@ public LLModelPromptContext(jnr.ffi.Runtime runtime) { Pointer llmodel_model_create2(String model_path, String build_variant, PointerByReference error); void llmodel_model_destroy(Pointer model); - boolean llmodel_loadModel(Pointer model, String model_path); + boolean llmodel_loadModel(Pointer model, String model_path, int n_ctx); boolean llmodel_isModelLoaded(Pointer model); @u_int64_t long llmodel_get_state_size(Pointer model); @u_int64_t long llmodel_save_state_data(Pointer model, Pointer dest); diff --git a/gpt4all-bindings/python/gpt4all/__init__.py b/gpt4all-bindings/python/gpt4all/__init__.py index f4dfa4bff49d..391fab0298f8 100644 --- a/gpt4all-bindings/python/gpt4all/__init__.py +++ b/gpt4all-bindings/python/gpt4all/__init__.py @@ -1,2 +1,2 @@ -from .gpt4all import Embed4All, GPT4All # noqa -from .pyllmodel import LLModel # noqa +from .gpt4all import Embed4All as Embed4All, GPT4All as GPT4All +from .pyllmodel import LLModel as LLModel diff --git a/gpt4all-bindings/python/gpt4all/gpt4all.py b/gpt4all-bindings/python/gpt4all/gpt4all.py index f2494ebb905f..50c3c88e4c5d 100644 --- a/gpt4all-bindings/python/gpt4all/gpt4all.py +++ b/gpt4all-bindings/python/gpt4all/gpt4all.py @@ -69,6 +69,7 @@ def __init__( allow_download: bool = True, n_threads: Optional[int] = None, device: Optional[str] = "cpu", + n_ctx: int = 2048, verbose: bool = False, ): """ @@ -90,15 +91,16 @@ def __init__( Default is "cpu". Note: If a selected GPU device does not have sufficient RAM to accommodate the model, an error will be thrown, and the GPT4All instance will be rendered invalid. It's advised to ensure the device has enough memory before initiating the model. + n_ctx: Maximum size of context window + verbose: If True, print debug messages. """ self.model_type = model_type self.model = pyllmodel.LLModel() # Retrieve model and download if allowed self.config: ConfigType = self.retrieve_model(model_name, model_path=model_path, allow_download=allow_download, verbose=verbose) - if device is not None: - if device != "cpu": - self.model.init_gpu(model_path=self.config["path"], device=device) - self.model.load_model(self.config["path"]) + if device is not None and device != "cpu": + self.model.init_gpu(model_path=self.config["path"], device=device, n_ctx=n_ctx) + self.model.load_model(self.config["path"], n_ctx) # Set n_threads if n_threads is not None: self.model.set_thread_count(n_threads) diff --git a/gpt4all-bindings/python/gpt4all/pyllmodel.py b/gpt4all-bindings/python/gpt4all/pyllmodel.py index 6f58d672b9c4..f3a1ee8e9153 100644 --- a/gpt4all-bindings/python/gpt4all/pyllmodel.py +++ b/gpt4all-bindings/python/gpt4all/pyllmodel.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import ctypes import importlib.resources import logging @@ -7,6 +9,7 @@ import subprocess import sys import threading +from enum import Enum from queue import Queue from typing import Callable, Iterable, List @@ -72,9 +75,9 @@ class LLModelGPUDevice(ctypes.Structure): llmodel.llmodel_model_destroy.argtypes = [ctypes.c_void_p] llmodel.llmodel_model_destroy.restype = None -llmodel.llmodel_loadModel.argtypes = [ctypes.c_void_p, ctypes.c_char_p] +llmodel.llmodel_loadModel.argtypes = [ctypes.c_void_p, ctypes.c_char_p, ctypes.c_int] llmodel.llmodel_loadModel.restype = ctypes.c_bool -llmodel.llmodel_required_mem.argtypes = [ctypes.c_void_p, ctypes.c_char_p] +llmodel.llmodel_required_mem.argtypes = [ctypes.c_void_p, ctypes.c_char_p, ctypes.c_int] llmodel.llmodel_required_mem.restype = ctypes.c_size_t llmodel.llmodel_isModelLoaded.argtypes = [ctypes.c_void_p] llmodel.llmodel_isModelLoaded.restype = ctypes.c_bool @@ -114,7 +117,7 @@ class LLModelGPUDevice(ctypes.Structure): llmodel.llmodel_threadCount.argtypes = [ctypes.c_void_p] llmodel.llmodel_threadCount.restype = ctypes.c_int32 -llmodel.llmodel_set_implementation_search_path(str(MODEL_LIB_PATH).replace("\\", r"\\").encode("utf-8")) +llmodel.llmodel_set_implementation_search_path(str(MODEL_LIB_PATH).replace("\\", r"\\").encode()) llmodel.llmodel_available_gpu_devices.argtypes = [ctypes.c_void_p, ctypes.c_size_t, ctypes.POINTER(ctypes.c_int32)] llmodel.llmodel_available_gpu_devices.restype = ctypes.POINTER(LLModelGPUDevice) @@ -143,10 +146,16 @@ def _create_model(model_path: bytes) -> ctypes.c_void_p: err = ctypes.c_char_p() model = llmodel.llmodel_model_create2(model_path, b"auto", ctypes.byref(err)) if model is None: - raise ValueError(f"Unable to instantiate model: {err.decode()}") + s = err.value + raise ValueError("Unable to instantiate model: {'null' if s is None else s.decode()}") return model +# Symbol to terminate from generator +class Sentinel(Enum): + TERMINATING_SYMBOL = 0 + + class LLModel: """ Base class and universal wrapper for GPT4All language models @@ -173,12 +182,16 @@ def __del__(self): if self.model is not None: self.llmodel_lib.llmodel_model_destroy(self.model) - def memory_needed(self, model_path: str) -> int: - model_path_enc = model_path.encode("utf-8") - self.model = _create_model(model_path_enc) - return llmodel.llmodel_required_mem(self.model, model_path_enc) + def memory_needed(self, model_path: str, n_ctx: int) -> int: + self.model = None + return self._memory_needed(model_path, n_ctx) + + def _memory_needed(self, model_path: str, n_ctx: int) -> int: + if self.model is None: + self.model = _create_model(model_path.encode()) + return llmodel.llmodel_required_mem(self.model, model_path.encode(), n_ctx) - def list_gpu(self, model_path: str) -> list: + def list_gpu(self, model_path: str, n_ctx: int) -> list[LLModelGPUDevice]: """ Lists available GPU devices that satisfy the model's memory requirements. @@ -186,45 +199,41 @@ def list_gpu(self, model_path: str) -> list: ---------- model_path : str Path to the model. + n_ctx : int + Maximum size of context window Returns ------- list A list of LLModelGPUDevice structures representing available GPU devices. """ - if self.model is not None: - model_path_enc = model_path.encode("utf-8") - mem_required = llmodel.llmodel_required_mem(self.model, model_path_enc) - else: - mem_required = self.memory_needed(model_path) + mem_required = self._memory_needed(model_path, n_ctx) + return self._list_gpu(mem_required) + + def _list_gpu(self, mem_required: int) -> list[LLModelGPUDevice]: num_devices = ctypes.c_int32(0) devices_ptr = self.llmodel_lib.llmodel_available_gpu_devices(self.model, mem_required, ctypes.byref(num_devices)) if not devices_ptr: raise ValueError("Unable to retrieve available GPU devices") - devices = [devices_ptr[i] for i in range(num_devices.value)] - return devices + return devices_ptr[:num_devices.value] - def init_gpu(self, model_path: str, device: str): - if self.model is not None: - model_path_enc = model_path.encode("utf-8") - mem_required = llmodel.llmodel_required_mem(self.model, model_path_enc) - else: - mem_required = self.memory_needed(model_path) - device_enc = device.encode("utf-8") - success = self.llmodel_lib.llmodel_gpu_init_gpu_device_by_string(self.model, mem_required, device_enc) + def init_gpu(self, model_path: str, device: str, n_ctx: int): + mem_required = self._memory_needed(model_path, n_ctx) + + success = self.llmodel_lib.llmodel_gpu_init_gpu_device_by_string(self.model, mem_required, device.encode()) if not success: # Retrieve all GPUs without considering memory requirements. num_devices = ctypes.c_int32(0) all_devices_ptr = self.llmodel_lib.llmodel_available_gpu_devices(self.model, 0, ctypes.byref(num_devices)) if not all_devices_ptr: raise ValueError("Unable to retrieve list of all GPU devices") - all_gpus = [all_devices_ptr[i].name.decode('utf-8') for i in range(num_devices.value)] + all_gpus = [d.name.decode() for d in all_devices_ptr[:num_devices.value]] # Retrieve GPUs that meet the memory requirements using list_gpu - available_gpus = [device.name.decode('utf-8') for device in self.list_gpu(model_path)] + available_gpus = [device.name.decode() for device in self._list_gpu(mem_required)] # Identify GPUs that are unavailable due to insufficient memory or features - unavailable_gpus = set(all_gpus) - set(available_gpus) + unavailable_gpus = set(all_gpus).difference(available_gpus) # Formulate the error message error_msg = "Unable to initialize model on GPU: '{}'.".format(device) @@ -232,7 +241,7 @@ def init_gpu(self, model_path: str, device: str): error_msg += "\nUnavailable GPUs due to insufficient memory or features: {}.".format(unavailable_gpus) raise ValueError(error_msg) - def load_model(self, model_path: str) -> bool: + def load_model(self, model_path: str, n_ctx: int) -> bool: """ Load model from a file. @@ -240,15 +249,16 @@ def load_model(self, model_path: str) -> bool: ---------- model_path : str Model filepath + n_ctx : int + Maximum size of context window Returns ------- True if model loaded successfully, False otherwise """ - model_path_enc = model_path.encode("utf-8") - self.model = _create_model(model_path_enc) + self.model = _create_model(model_path.encode()) - llmodel.llmodel_loadModel(self.model, model_path_enc) + llmodel.llmodel_loadModel(self.model, model_path.encode(), n_ctx) filename = os.path.basename(model_path) self.model_name = os.path.splitext(filename)[0] @@ -312,7 +322,7 @@ def generate_embedding(self, text: str) -> List[float]: raise ValueError("Text must not be None or empty") embedding_size = ctypes.c_size_t() - c_text = ctypes.c_char_p(text.encode('utf-8')) + c_text = ctypes.c_char_p(text.encode()) embedding_ptr = llmodel.llmodel_embedding(self.model, c_text, ctypes.byref(embedding_size)) embedding_array = [embedding_ptr[i] for i in range(embedding_size.value)] llmodel.llmodel_free_embedding(embedding_ptr) @@ -357,7 +367,7 @@ def prompt_model( prompt, ) - prompt_bytes = prompt.encode("utf-8") + prompt_bytes = prompt.encode() prompt_ptr = ctypes.c_char_p(prompt_bytes) self._set_context( @@ -385,10 +395,7 @@ def prompt_model( def prompt_model_streaming( self, prompt: str, callback: ResponseCallbackType = empty_response_callback, **kwargs ) -> Iterable[str]: - # Symbol to terminate from generator - TERMINATING_SYMBOL = object() - - output_queue: Queue = Queue() + output_queue: Queue[str | Sentinel] = Queue() # Put response tokens into an output queue def _generator_callback_wrapper(callback: ResponseCallbackType) -> ResponseCallbackType: @@ -405,7 +412,7 @@ def _generator_callback(token_id: int, response: str): def run_llmodel_prompt(prompt: str, callback: ResponseCallbackType, **kwargs): self.prompt_model(prompt, callback, **kwargs) - output_queue.put(TERMINATING_SYMBOL) + output_queue.put(Sentinel.TERMINATING_SYMBOL) # Kick off llmodel_prompt in separate thread so we can return generator # immediately @@ -419,7 +426,7 @@ def run_llmodel_prompt(prompt: str, callback: ResponseCallbackType, **kwargs): # Generator while True: response = output_queue.get() - if response is TERMINATING_SYMBOL: + if isinstance(response, Sentinel): break yield response @@ -442,7 +449,7 @@ def _raw_callback(token_id: int, response: bytes) -> bool: else: # beginning of a byte sequence if len(self.buffer) > 0: - decoded.append(self.buffer.decode('utf-8', 'replace')) + decoded.append(self.buffer.decode(errors='replace')) self.buffer.clear() @@ -451,7 +458,7 @@ def _raw_callback(token_id: int, response: bytes) -> bool: if self.buff_expecting_cont_bytes <= 0: # received the whole sequence or an out of place continuation byte - decoded.append(self.buffer.decode('utf-8', 'replace')) + decoded.append(self.buffer.decode(errors='replace')) self.buffer.clear() self.buff_expecting_cont_bytes = 0 diff --git a/gpt4all-bindings/python/gpt4all/tests/test_gpt4all.py b/gpt4all-bindings/python/gpt4all/tests/test_gpt4all.py index 5b3c3fba19f1..679b385becd2 100644 --- a/gpt4all-bindings/python/gpt4all/tests/test_gpt4all.py +++ b/gpt4all-bindings/python/gpt4all/tests/test_gpt4all.py @@ -117,7 +117,7 @@ def test_empty_embedding(): def test_download_model(tmp_path: Path): import gpt4all.gpt4all old_default_dir = gpt4all.gpt4all.DEFAULT_MODEL_DIRECTORY - gpt4all.gpt4all.DEFAULT_MODEL_DIRECTORY = tmp_path # temporary pytest directory to ensure a download happens + gpt4all.gpt4all.DEFAULT_MODEL_DIRECTORY = str(tmp_path) # temporary pytest directory to ensure a download happens try: model = GPT4All(model_name='ggml-all-MiniLM-L6-v2-f16.bin') model_path = tmp_path / model.config['filename'] diff --git a/gpt4all-bindings/typescript/index.cc b/gpt4all-bindings/typescript/index.cc index 8a4792362e43..a65ea31c0223 100644 --- a/gpt4all-bindings/typescript/index.cc +++ b/gpt4all-bindings/typescript/index.cc @@ -28,7 +28,7 @@ Napi::Function NodeModelWrapper::GetClass(Napi::Env env) { Napi::Value NodeModelWrapper::GetRequiredMemory(const Napi::CallbackInfo& info) { auto env = info.Env(); - return Napi::Number::New(env, static_cast( llmodel_required_mem(GetInference(), full_model_path.c_str()) )); + return Napi::Number::New(env, static_cast( llmodel_required_mem(GetInference(), full_model_path.c_str(), 2048) )); } Napi::Value NodeModelWrapper::GetGpuDevices(const Napi::CallbackInfo& info) @@ -163,7 +163,7 @@ Napi::Value NodeModelWrapper::GetRequiredMemory(const Napi::CallbackInfo& info) } } - auto success = llmodel_loadModel(GetInference(), full_weight_path.c_str()); + auto success = llmodel_loadModel(GetInference(), full_weight_path.c_str(), 2048); if(!success) { Napi::Error::New(env, "Failed to load model at given path").ThrowAsJavaScriptException(); return; diff --git a/gpt4all-chat/chatgpt.cpp b/gpt4all-chat/chatgpt.cpp index 11b0be49627f..98d241dda89c 100644 --- a/gpt4all-chat/chatgpt.cpp +++ b/gpt4all-chat/chatgpt.cpp @@ -20,15 +20,17 @@ ChatGPT::ChatGPT() { } -size_t ChatGPT::requiredMem(const std::string &modelPath) +size_t ChatGPT::requiredMem(const std::string &modelPath, int n_ctx) { Q_UNUSED(modelPath); + Q_UNUSED(n_ctx); return 0; } -bool ChatGPT::loadModel(const std::string &modelPath) +bool ChatGPT::loadModel(const std::string &modelPath, int n_ctx) { Q_UNUSED(modelPath); + Q_UNUSED(n_ctx); return true; } diff --git a/gpt4all-chat/chatgpt.h b/gpt4all-chat/chatgpt.h index 0f835bee33ad..7bb3912f55af 100644 --- a/gpt4all-chat/chatgpt.h +++ b/gpt4all-chat/chatgpt.h @@ -48,9 +48,9 @@ class ChatGPT : public QObject, public LLModel { bool supportsEmbedding() const override { return false; } bool supportsCompletion() const override { return true; } - bool loadModel(const std::string &modelPath) override; + bool loadModel(const std::string &modelPath, int n_ctx) override; bool isModelLoaded() const override; - size_t requiredMem(const std::string &modelPath) override; + size_t requiredMem(const std::string &modelPath, int n_ctx) override; size_t stateSize() const override; size_t saveState(uint8_t *dest) const override; size_t restoreState(const uint8_t *src) override; diff --git a/gpt4all-chat/chatlistmodel.cpp b/gpt4all-chat/chatlistmodel.cpp index e12edd87aa1d..0b295fab6247 100644 --- a/gpt4all-chat/chatlistmodel.cpp +++ b/gpt4all-chat/chatlistmodel.cpp @@ -5,7 +5,7 @@ #include #define CHAT_FORMAT_MAGIC 0xF5D553CC -#define CHAT_FORMAT_VERSION 6 +#define CHAT_FORMAT_VERSION 7 class MyChatListModel: public ChatListModel { }; Q_GLOBAL_STATIC(MyChatListModel, chatListModelInstance) diff --git a/gpt4all-chat/chatllm.cpp b/gpt4all-chat/chatllm.cpp index 78f73cd45c50..57cdc96d4519 100644 --- a/gpt4all-chat/chatllm.cpp +++ b/gpt4all-chat/chatllm.cpp @@ -248,14 +248,16 @@ bool ChatLLM::loadModel(const ModelInfo &modelInfo) m_llModelInfo.model = model; } else { + // TODO: make configurable in UI + auto n_ctx = MySettings::globalInstance()->modelContextLength(modelInfo); + m_ctx.n_ctx = n_ctx; + + std::string buildVariant = "auto"; #if defined(Q_OS_MAC) && defined(__arm__) if (m_forceMetal) - m_llModelInfo.model = LLMImplementation::construct(filePath.toStdString(), "metal"); - else - m_llModelInfo.model = LLMImplementation::construct(filePath.toStdString(), "auto"); -#else - m_llModelInfo.model = LLModel::Implementation::construct(filePath.toStdString(), "auto"); + buildVariant = "metal"; #endif + m_llModelInfo.model = LLModel::Implementation::construct(filePath.toStdString(), buildVariant, n_ctx); if (m_llModelInfo.model) { // Update the settings that a model is being loaded and update the device list @@ -267,7 +269,7 @@ bool ChatLLM::loadModel(const ModelInfo &modelInfo) if (requestedDevice == "CPU") { emit reportFallbackReason(""); // fallback not applicable } else { - const size_t requiredMemory = m_llModelInfo.model->requiredMem(filePath.toStdString()); + const size_t requiredMemory = m_llModelInfo.model->requiredMem(filePath.toStdString(), n_ctx); std::vector availableDevices = m_llModelInfo.model->availableGPUDevices(requiredMemory); LLModel::GPUDevice *device = nullptr; @@ -296,14 +298,14 @@ bool ChatLLM::loadModel(const ModelInfo &modelInfo) // Report which device we're actually using emit reportDevice(actualDevice); - bool success = m_llModelInfo.model->loadModel(filePath.toStdString()); + bool success = m_llModelInfo.model->loadModel(filePath.toStdString(), n_ctx); if (actualDevice == "CPU") { // we asked llama.cpp to use the CPU } else if (!success) { // llama_init_from_file returned nullptr emit reportDevice("CPU"); emit reportFallbackReason("
GPU loading failed (out of VRAM?)"); - success = m_llModelInfo.model->loadModel(filePath.toStdString()); + success = m_llModelInfo.model->loadModel(filePath.toStdString(), n_ctx); } else if (!m_llModelInfo.model->usingGPUDevice()) { // ggml_vk_init was not called in llama.cpp // We might have had to fallback to CPU after load if the model is not possible to accelerate @@ -763,6 +765,8 @@ bool ChatLLM::handleRestoreStateFromTextRecalculate(bool isRecalc) return false; } +// this function serialized the cached model state to disk. +// we want to also serialize n_ctx, and read it at load time. bool ChatLLM::serialize(QDataStream &stream, int version, bool serializeKV) { if (version > 1) { @@ -790,6 +794,9 @@ bool ChatLLM::serialize(QDataStream &stream, int version, bool serializeKV) stream << responseLogits; } stream << m_ctx.n_past; + if (version >= 6) { + stream << m_ctx.n_ctx; + } stream << quint64(m_ctx.logits.size()); stream.writeRawData(reinterpret_cast(m_ctx.logits.data()), m_ctx.logits.size() * sizeof(float)); stream << quint64(m_ctx.tokens.size()); @@ -839,6 +846,12 @@ bool ChatLLM::deserialize(QDataStream &stream, int version, bool deserializeKV, stream >> n_past; if (!discardKV) m_ctx.n_past = n_past; + if (version >= 6) { + uint32_t n_ctx; + stream >> n_ctx; + if (!discardKV) m_ctx.n_ctx = n_ctx; + } + quint64 logitsSize; stream >> logitsSize; if (!discardKV) { diff --git a/gpt4all-chat/embllm.cpp b/gpt4all-chat/embllm.cpp index 37a48294c569..7be2d3480a5c 100644 --- a/gpt4all-chat/embllm.cpp +++ b/gpt4all-chat/embllm.cpp @@ -29,8 +29,8 @@ bool EmbeddingLLM::loadModel() return false; } - m_model = LLModel::Implementation::construct(filePath.toStdString(), "auto"); - bool success = m_model->loadModel(filePath.toStdString()); + m_model = LLModel::Implementation::construct(filePath.toStdString()); + bool success = m_model->loadModel(filePath.toStdString(), 2048); if (!success) { qWarning() << "WARNING: Could not load sbert"; delete m_model; diff --git a/gpt4all-chat/modellist.cpp b/gpt4all-chat/modellist.cpp index 3403cb8c07bf..bb68836130e5 100644 --- a/gpt4all-chat/modellist.cpp +++ b/gpt4all-chat/modellist.cpp @@ -97,6 +97,17 @@ void ModelInfo::setPromptBatchSize(int s) m_promptBatchSize = s; } +int ModelInfo::contextLength() const +{ + return MySettings::globalInstance()->modelContextLength(*this); +} + +void ModelInfo::setContextLength(int l) +{ + if (isClone) MySettings::globalInstance()->setModelContextLength(*this, l, isClone /*force*/); + m_contextLength = l; +} + double ModelInfo::repeatPenalty() const { return MySettings::globalInstance()->modelRepeatPenalty(*this); @@ -274,6 +285,7 @@ ModelList::ModelList() connect(MySettings::globalInstance(), &MySettings::topKChanged, this, &ModelList::updateDataForSettings); connect(MySettings::globalInstance(), &MySettings::maxLengthChanged, this, &ModelList::updateDataForSettings); connect(MySettings::globalInstance(), &MySettings::promptBatchSizeChanged, this, &ModelList::updateDataForSettings); + connect(MySettings::globalInstance(), &MySettings::contextLengthChanged, this, &ModelList::updateDataForSettings); connect(MySettings::globalInstance(), &MySettings::repeatPenaltyChanged, this, &ModelList::updateDataForSettings); connect(MySettings::globalInstance(), &MySettings::repeatPenaltyTokensChanged, this, &ModelList::updateDataForSettings);; connect(MySettings::globalInstance(), &MySettings::promptTemplateChanged, this, &ModelList::updateDataForSettings); @@ -525,6 +537,8 @@ QVariant ModelList::dataInternal(const ModelInfo *info, int role) const return info->maxLength(); case PromptBatchSizeRole: return info->promptBatchSize(); + case ContextLengthRole: + return info->contextLength(); case RepeatPenaltyRole: return info->repeatPenalty(); case RepeatPenaltyTokensRole: @@ -740,6 +754,7 @@ QString ModelList::clone(const ModelInfo &model) updateData(id, ModelList::TopKRole, model.topK()); updateData(id, ModelList::MaxLengthRole, model.maxLength()); updateData(id, ModelList::PromptBatchSizeRole, model.promptBatchSize()); + updateData(id, ModelList::ContextLengthRole, model.contextLength()); updateData(id, ModelList::RepeatPenaltyRole, model.repeatPenalty()); updateData(id, ModelList::RepeatPenaltyTokensRole, model.repeatPenaltyTokens()); updateData(id, ModelList::PromptTemplateRole, model.promptTemplate()); @@ -1106,6 +1121,8 @@ void ModelList::parseModelsJsonFile(const QByteArray &jsonData, bool save) updateData(id, ModelList::MaxLengthRole, obj["maxLength"].toInt()); if (obj.contains("promptBatchSize")) updateData(id, ModelList::PromptBatchSizeRole, obj["promptBatchSize"].toInt()); + if (obj.contains("contextLength")) + updateData(id, ModelList::ContextLengthRole, obj["contextLength"].toInt()); if (obj.contains("repeatPenalty")) updateData(id, ModelList::RepeatPenaltyRole, obj["repeatPenalty"].toDouble()); if (obj.contains("repeatPenaltyTokens")) @@ -1198,6 +1215,8 @@ void ModelList::updateModelsFromSettings() const int maxLength = settings.value(g + "/maxLength").toInt(); Q_ASSERT(settings.contains(g + "/promptBatchSize")); const int promptBatchSize = settings.value(g + "/promptBatchSize").toInt(); + Q_ASSERT(settings.contains(g + "/contextLength")); + const int contextLength = settings.value(g + "/contextLength").toInt(); Q_ASSERT(settings.contains(g + "/repeatPenalty")); const double repeatPenalty = settings.value(g + "/repeatPenalty").toDouble(); Q_ASSERT(settings.contains(g + "/repeatPenaltyTokens")); @@ -1216,6 +1235,7 @@ void ModelList::updateModelsFromSettings() updateData(id, ModelList::TopKRole, topK); updateData(id, ModelList::MaxLengthRole, maxLength); updateData(id, ModelList::PromptBatchSizeRole, promptBatchSize); + updateData(id, ModelList::ContextLengthRole, contextLength); updateData(id, ModelList::RepeatPenaltyRole, repeatPenalty); updateData(id, ModelList::RepeatPenaltyTokensRole, repeatPenaltyTokens); updateData(id, ModelList::PromptTemplateRole, promptTemplate); diff --git a/gpt4all-chat/modellist.h b/gpt4all-chat/modellist.h index 536f3a995789..c314540744ad 100644 --- a/gpt4all-chat/modellist.h +++ b/gpt4all-chat/modellist.h @@ -39,6 +39,7 @@ struct ModelInfo { Q_PROPERTY(int topK READ topK WRITE setTopK) Q_PROPERTY(int maxLength READ maxLength WRITE setMaxLength) Q_PROPERTY(int promptBatchSize READ promptBatchSize WRITE setPromptBatchSize) + Q_PROPERTY(int contextLength READ contextLength WRITE setContextLength) Q_PROPERTY(double repeatPenalty READ repeatPenalty WRITE setRepeatPenalty) Q_PROPERTY(int repeatPenaltyTokens READ repeatPenaltyTokens WRITE setRepeatPenaltyTokens) Q_PROPERTY(QString promptTemplate READ promptTemplate WRITE setPromptTemplate) @@ -94,6 +95,8 @@ struct ModelInfo { void setMaxLength(int l); int promptBatchSize() const; void setPromptBatchSize(int s); + int contextLength() const; + void setContextLength(int l); double repeatPenalty() const; void setRepeatPenalty(double p); int repeatPenaltyTokens() const; @@ -112,6 +115,7 @@ struct ModelInfo { int m_topK = 40; int m_maxLength = 4096; int m_promptBatchSize = 128; + int m_contextLength = 2048; double m_repeatPenalty = 1.18; int m_repeatPenaltyTokens = 64; QString m_promptTemplate = "### Human:\n%1\n### Assistant:\n"; @@ -227,6 +231,7 @@ class ModelList : public QAbstractListModel TopKRole, MaxLengthRole, PromptBatchSizeRole, + ContextLengthRole, RepeatPenaltyRole, RepeatPenaltyTokensRole, PromptTemplateRole, @@ -269,6 +274,7 @@ class ModelList : public QAbstractListModel roles[TopKRole] = "topK"; roles[MaxLengthRole] = "maxLength"; roles[PromptBatchSizeRole] = "promptBatchSize"; + roles[ContextLengthRole] = "contextLength"; roles[RepeatPenaltyRole] = "repeatPenalty"; roles[RepeatPenaltyTokensRole] = "repeatPenaltyTokens"; roles[PromptTemplateRole] = "promptTemplate"; diff --git a/gpt4all-chat/mysettings.cpp b/gpt4all-chat/mysettings.cpp index 3183a3217f61..5f5c7b801b62 100644 --- a/gpt4all-chat/mysettings.cpp +++ b/gpt4all-chat/mysettings.cpp @@ -90,6 +90,7 @@ void MySettings::restoreModelDefaults(const ModelInfo &model) setModelTopK(model, model.m_topK);; setModelMaxLength(model, model.m_maxLength); setModelPromptBatchSize(model, model.m_promptBatchSize); + setModelContextLength(model, model.m_contextLength); setModelRepeatPenalty(model, model.m_repeatPenalty); setModelRepeatPenaltyTokens(model, model.m_repeatPenaltyTokens); setModelPromptTemplate(model, model.m_promptTemplate); @@ -280,6 +281,28 @@ void MySettings::setModelPromptBatchSize(const ModelInfo &m, int s, bool force) emit promptBatchSizeChanged(m); } +int MySettings::modelContextLength(const ModelInfo &m) const +{ + QSettings setting; + setting.sync(); + return setting.value(QString("model-%1").arg(m.id()) + "/contextLength", m.m_contextLength).toInt(); +} + +void MySettings::setModelContextLength(const ModelInfo &m, int l, bool force) +{ + if (modelContextLength(m) == l && !force) + return; + + QSettings setting; + if (m.m_contextLength == l && !m.isClone) + setting.remove(QString("model-%1").arg(m.id()) + "/contextLength"); + else + setting.setValue(QString("model-%1").arg(m.id()) + "/contextLength", l); + setting.sync(); + if (!force) + emit contextLengthChanged(m); +} + double MySettings::modelRepeatPenalty(const ModelInfo &m) const { QSettings setting; diff --git a/gpt4all-chat/mysettings.h b/gpt4all-chat/mysettings.h index 54b8b6e62294..3287f413a562 100644 --- a/gpt4all-chat/mysettings.h +++ b/gpt4all-chat/mysettings.h @@ -1,6 +1,8 @@ #ifndef MYSETTINGS_H #define MYSETTINGS_H +#include + #include #include @@ -59,6 +61,8 @@ class MySettings : public QObject Q_INVOKABLE void setModelPromptTemplate(const ModelInfo &m, const QString &t, bool force = false); QString modelSystemPrompt(const ModelInfo &m) const; Q_INVOKABLE void setModelSystemPrompt(const ModelInfo &m, const QString &p, bool force = false); + int modelContextLength(const ModelInfo &m) const; + Q_INVOKABLE void setModelContextLength(const ModelInfo &m, int s, bool force = false); // Application settings int threadCount() const; @@ -79,6 +83,8 @@ class MySettings : public QObject void setForceMetal(bool b); QString device() const; void setDevice(const QString &u); + int32_t contextLength() const; + void setContextLength(int32_t value); // Release/Download settings QString lastVersionStarted() const; @@ -114,6 +120,7 @@ class MySettings : public QObject void topKChanged(const ModelInfo &model); void maxLengthChanged(const ModelInfo &model); void promptBatchSizeChanged(const ModelInfo &model); + void contextLengthChanged(const ModelInfo &model); void repeatPenaltyChanged(const ModelInfo &model); void repeatPenaltyTokensChanged(const ModelInfo &model); void promptTemplateChanged(const ModelInfo &model); diff --git a/gpt4all-chat/qml/ModelSettings.qml b/gpt4all-chat/qml/ModelSettings.qml index a0ae6cb7ca51..c9f46735ab02 100644 --- a/gpt4all-chat/qml/ModelSettings.qml +++ b/gpt4all-chat/qml/ModelSettings.qml @@ -349,13 +349,61 @@ MySettingsTab { rowSpacing: 10 columnSpacing: 10 + Label { + id: contextLengthLabel + visible: !root.currentModelInfo.isChatGPT + text: qsTr("Context Length:") + font.pixelSize: theme.fontSizeLarge + color: theme.textColor + Layout.row: 0 + Layout.column: 0 + } + MyTextField { + id: contextLengthField + visible: !root.currentModelInfo.isChatGPT + text: root.currentModelInfo.contextLength + color: theme.textColor + font.pixelSize: theme.fontSizeLarge + ToolTip.text: qsTr("Maximum combined prompt/response tokens before information is lost.\nUsing more context than the model was trained on will yield poor results.\nNOTE: Does not take effect until you RESTART GPT4All or SWITCH MODELS.") + ToolTip.visible: hovered + Layout.row: 0 + Layout.column: 1 + validator: IntValidator { + bottom: 1 + } + Connections { + target: MySettings + function onContextLengthChanged() { + contextLengthField.text = root.currentModelInfo.contextLength; + } + } + Connections { + target: root + function onCurrentModelInfoChanged() { + contextLengthField.text = root.currentModelInfo.contextLength; + } + } + onEditingFinished: { + var val = parseInt(text) + if (!isNaN(val)) { + MySettings.setModelContextLength(root.currentModelInfo, val) + focus = false + } else { + text = root.currentModelInfo.contextLength + } + } + Accessible.role: Accessible.EditableText + Accessible.name: contextLengthLabel.text + Accessible.description: ToolTip.text + } + Label { id: tempLabel text: qsTr("Temperature:") color: theme.textColor font.pixelSize: theme.fontSizeLarge - Layout.row: 0 - Layout.column: 0 + Layout.row: 1 + Layout.column: 2 } MyTextField { @@ -365,8 +413,8 @@ MySettingsTab { font.pixelSize: theme.fontSizeLarge ToolTip.text: qsTr("Temperature increases the chances of choosing less likely tokens.\nNOTE: Higher temperature gives more creative but less predictable outputs.") ToolTip.visible: hovered - Layout.row: 0 - Layout.column: 1 + Layout.row: 1 + Layout.column: 3 validator: DoubleValidator { locale: "C" } @@ -400,8 +448,8 @@ MySettingsTab { text: qsTr("Top P:") color: theme.textColor font.pixelSize: theme.fontSizeLarge - Layout.row: 0 - Layout.column: 2 + Layout.row: 2 + Layout.column: 0 } MyTextField { id: topPField @@ -410,8 +458,8 @@ MySettingsTab { font.pixelSize: theme.fontSizeLarge ToolTip.text: qsTr("Only the most likely tokens up to a total probability of top_p can be chosen.\nNOTE: Prevents choosing highly unlikely tokens, aka Nucleus Sampling") ToolTip.visible: hovered - Layout.row: 0 - Layout.column: 3 + Layout.row: 2 + Layout.column: 1 validator: DoubleValidator { locale: "C" } @@ -446,8 +494,8 @@ MySettingsTab { text: qsTr("Top K:") color: theme.textColor font.pixelSize: theme.fontSizeLarge - Layout.row: 1 - Layout.column: 0 + Layout.row: 2 + Layout.column: 2 } MyTextField { id: topKField @@ -457,8 +505,8 @@ MySettingsTab { font.pixelSize: theme.fontSizeLarge ToolTip.text: qsTr("Only the top K most likely tokens will be chosen from") ToolTip.visible: hovered - Layout.row: 1 - Layout.column: 1 + Layout.row: 2 + Layout.column: 3 validator: IntValidator { bottom: 1 } @@ -493,7 +541,7 @@ MySettingsTab { text: qsTr("Max Length:") color: theme.textColor font.pixelSize: theme.fontSizeLarge - Layout.row: 1 + Layout.row: 0 Layout.column: 2 } MyTextField { @@ -504,7 +552,7 @@ MySettingsTab { font.pixelSize: theme.fontSizeLarge ToolTip.text: qsTr("Maximum length of response in tokens") ToolTip.visible: hovered - Layout.row: 1 + Layout.row: 0 Layout.column: 3 validator: IntValidator { bottom: 1 @@ -541,7 +589,7 @@ MySettingsTab { text: qsTr("Prompt Batch Size:") font.pixelSize: theme.fontSizeLarge color: theme.textColor - Layout.row: 2 + Layout.row: 1 Layout.column: 0 } MyTextField { @@ -552,7 +600,7 @@ MySettingsTab { font.pixelSize: theme.fontSizeLarge ToolTip.text: qsTr("Amount of prompt tokens to process at once.\nNOTE: Higher values can speed up reading prompts but will use more RAM") ToolTip.visible: hovered - Layout.row: 2 + Layout.row: 1 Layout.column: 1 validator: IntValidator { bottom: 1 @@ -588,8 +636,8 @@ MySettingsTab { text: qsTr("Repeat Penalty:") color: theme.textColor font.pixelSize: theme.fontSizeLarge - Layout.row: 2 - Layout.column: 2 + Layout.row: 3 + Layout.column: 0 } MyTextField { id: repeatPenaltyField @@ -599,8 +647,8 @@ MySettingsTab { font.pixelSize: theme.fontSizeLarge ToolTip.text: qsTr("Amount to penalize repetitiveness of the output") ToolTip.visible: hovered - Layout.row: 2 - Layout.column: 3 + Layout.row: 3 + Layout.column: 1 validator: DoubleValidator { locale: "C" } @@ -636,7 +684,7 @@ MySettingsTab { color: theme.textColor font.pixelSize: theme.fontSizeLarge Layout.row: 3 - Layout.column: 0 + Layout.column: 2 } MyTextField { id: repeatPenaltyTokenField @@ -647,7 +695,7 @@ MySettingsTab { ToolTip.text: qsTr("How far back in output to apply repeat penalty") ToolTip.visible: hovered Layout.row: 3 - Layout.column: 1 + Layout.column: 3 validator: IntValidator { bottom: 1 }