Commit 17f069c
expose some sampling params
hazelnutcloud committed May 16, 2024
1 parent ad808a1 · commit 17f069c
Showing 7 changed files with 98 additions and 18 deletions.
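
The commit exposes four llama.cpp sampling parameters (temperature, top_p, frequency_penalty, presence_penalty) as Godot properties on LlamaContext. A minimal GDScript sketch of how they can be set — not part of the commit; the node path and values are illustrative:

@onready var llama_context: LlamaContext = $LlamaContext  # illustrative node path

func _ready() -> void:
	llama_context.temperature = 0.9         # maps to sampling_params.temp
	llama_context.top_p = 0.95              # maps to sampling_params.top_p
	llama_context.frequency_penalty = 0.0   # maps to sampling_params.penalty_freq
	llama_context.presence_penalty = 0.0    # maps to sampling_params.penalty_present
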
21 changes: 19 additions & 2 deletions godot/addons/godot-llama-cpp/chat/chat_formatter.gd
@@ -6,12 +6,14 @@ static func apply(format: String, messages: Array) -> String:
return format_llama3(messages)
"phi3":
return format_phi3(messages)
"mistral":
return format_mistral(messages)
_:
printerr("Unknown chat format: ", format)
return ""

static func format_llama3(messages: Array) -> String:
var res = "<|begin_of_text|>"
var res = ""

for i in range(messages.size()):
match messages[i]:
@@ -27,7 +29,7 @@ static func format_llama3(messages: Array) -> String:
return res

static func format_phi3(messages: Array) -> String:
var res = "<s>"
var res = ""

for i in range(messages.size()):
match messages[i]:
@@ -37,3 +39,18 @@ static func format_phi3(messages: Array) -> String:
printerr("Invalid message at index ", i)
res += "<|assistant|>\n"
return res

static func format_mistral(messages: Array) -> String:
var res = ""

for i in range(messages.size()):
match messages[i]:
{"text": var text, "sender": var sender}:
if sender == "user":
res += "[INST] %s [/INST]" % text
else:
res += "%s</s>"
_:
printerr("Invalid message at index ", i)

return res
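
For reference, a sketch of what the new "mistral" branch produces for a short exchange — illustrative only, and assuming the assistant branch interpolates the message text:

var messages = [
	{ "sender": "user", "text": "Hello" },
	{ "sender": "assistant", "text": "Hi there!" },
]
var prompt = ChatFormatter.apply("mistral", messages)
# prompt == "[INST] Hello [/INST]Hi there!</s>"
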
2 changes: 1 addition & 1 deletion godot/addons/godot-llama-cpp/plugin.gdextension
@@ -5,7 +5,7 @@ compatibility_minimum = "4.2"

[libraries]

-macos.debug = "res://addons/godot-llama-cpp/lib/libgodot-llama-cpp-aarch64-macos-none-Debug.dylib"
+macos.debug = "res://addons/godot-llama-cpp/lib/libgodot-llama-cpp-aarch64-macos-none-ReleaseSafe.dylib"
macos.release = "res://addons/godot-llama-cpp/lib/libgodot-llama-cpp-aarch64-macos-none-ReleaseSafe.dylib"
windows.debug.x86_32 = "res://addons/godot-llama-cpp/lib/libgodot-llama-cpp.windows.template_debug.x86_32.dll"
windows.release.x86_32 = "res://addons/godot-llama-cpp/lib/libgodot-llama-cpp.windows.template_release.x86_32.dll"
7 changes: 5 additions & 2 deletions godot/examples/simple/simple.gd
@@ -9,13 +9,16 @@ func _on_text_edit_submit(input: String) -> void:
handle_input(input)

func handle_input(input: String) -> void:
var messages = [{ "sender": "system", "text": "You are a helpful assistant" }]
#var messages = [{ "sender": "system", "text": "You are a pirate chatbot who always responds in pirate speak!" }]

#var messages = [{ "sender": "system", "text": "You are a helpful chatbot assistant!" }]
var messages = []
messages.append_array(messages_container.get_children().filter(func(msg: Message): return msg.include_in_prompt).map(
func(msg: Message) -> Dictionary:
return { "text": msg.text, "sender": msg.sender }
))
messages.append({"text": input, "sender": "user"})
var prompt = ChatFormatter.apply("phi3", messages)
var prompt = ChatFormatter.apply("llama3", messages)
print("prompt: ", prompt)

var completion_id = llama_context.request_completion(prompt)
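
Completions come back through the completion_generated signal emitted by LlamaContext; the response Dictionary carries "id" plus either streamed "text" chunks with "done", or an "error" string. An illustrative handler sketch, not part of the commit:

# assumed to be connected via llama_context.completion_generated.connect(_on_completion_generated)
func _on_completion_generated(response: Dictionary) -> void:
	if response.has("error"):
		printerr(response["error"])
		return
	if not response["done"]:
		print(response["text"])  # next streamed piece of the completion
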
5 changes: 3 additions & 2 deletions godot/examples/simple/simple.tscn
@@ -4,7 +4,7 @@
[ext_resource type="Script" path="res://examples/simple/simple.gd" id="1_sruc3"]
[ext_resource type="PackedScene" uid="uid://t862t0v8ht2q" path="res://examples/simple/message.tscn" id="2_7iip7"]
[ext_resource type="Script" path="res://examples/simple/TextEdit.gd" id="2_7usqw"]
[ext_resource type="LlamaModel" path="res://models/Phi-3-mini-128k-instruct.Q5_K_M.gguf" id="5_qpeda"]
[ext_resource type="LlamaModel" path="res://models/meta-llama-3-8b-instruct.Q5_K_M.gguf" id="5_qov1l"]

[node name="Node" type="Node"]
script = ExtResource("1_sruc3")
@@ -68,7 +68,8 @@ icon = ExtResource("1_gjsev")
expand_icon = true

[node name="LlamaContext" type="LlamaContext" parent="."]
model = ExtResource("5_qpeda")
model = ExtResource("5_qov1l")
temperature = 0.9
unique_name_in_owner = true

[connection signal="submit" from="Panel/MarginContainer/VBoxContainer/HBoxContainer/TextEdit" to="." method="_on_text_edit_submit"]
2 changes: 1 addition & 1 deletion llama.cpp
Submodule llama.cpp updated 58 files
+77 −29 .github/workflows/build.yml
+19 −1 CMakeLists.txt
+45 −0 CMakePresets.json
+4 −1 README.md
+16 −0 cmake/arm64-windows-llvm.cmake
+6 −0 cmake/arm64-windows-msvc.cmake
+10 −0 common/common.cpp
+1 −0 common/common.h
+1 −1 common/grammar-parser.cpp
+6 −6 common/json-schema-to-grammar.cpp
+5 −5 common/log.h
+3 −3 convert-hf-to-gguf-update.py
+31 −47 convert-hf-to-gguf.py
+155 −25 convert.py
+3 −0 examples/CMakeLists.txt
+1 −0 examples/embedding/embedding.cpp
+21 −6 examples/llava/llava-cli.cpp
+0 −15 examples/llava/llava.cpp
+59 −1 examples/perplexity/README.md
+3 −1 examples/quantize/README.md
+2 −0 examples/rpc/CMakeLists.txt
+74 −0 examples/rpc/README.md
+130 −0 examples/rpc/rpc-server.cpp
+1 −1 examples/server/README.md
+8 −7 examples/server/bench/bench.py
+7 −0 examples/server/server.cpp
+5 −2 examples/server/tests/features/steps/steps.py
+1 −1 examples/server/utils.hpp
+0 −1 ggml-backend.c
+11 −2 ggml-cuda.cu
+4 −0 ggml-cuda/common.cuh
+47 −0 ggml-cuda/fattn-common.cuh
+430 −0 ggml-cuda/fattn-vec-f16.cu
+5 −0 ggml-cuda/fattn-vec-f16.cuh
+384 −0 ggml-cuda/fattn-vec-f32.cu
+3 −0 ggml-cuda/fattn-vec-f32.cuh
+15 −453 ggml-cuda/fattn.cu
+33 −30 ggml-cuda/upscale.cu
+7 −0 ggml-impl.h
+48 −35 ggml-metal.m
+33 −41 ggml-metal.metal
+2,195 −27 ggml-quants.c
+1,023 −0 ggml-rpc.cpp
+24 −0 ggml-rpc.h
+5 −24 ggml-sycl.cpp
+306 −161 ggml.c
+16 −2 ggml.h
+1 −0 gguf-py/gguf/__init__.py
+11 −5 gguf-py/gguf/gguf_writer.py
+20 −9 gguf-py/gguf/lazy.py
+109 −0 gguf-py/gguf/quants.py
+230 −108 llama.cpp
+3 −0 llama.h
+4 −0 scripts/sync-ggml-am.sh
+1 −1 scripts/sync-ggml.last
+2 −0 scripts/sync-ggml.sh
+45 −17 tests/test-backend-ops.cpp
+46 −0 tests/test-grammar-integration.cpp
71 changes: 61 additions & 10 deletions src/llama_context.cpp
@@ -21,6 +21,22 @@ void LlamaContext::_bind_methods() {
ClassDB::bind_method(D_METHOD("set_seed", "seed"), &LlamaContext::set_seed);
ClassDB::add_property("LlamaContext", PropertyInfo(Variant::INT, "seed"), "set_seed", "get_seed");

ClassDB::bind_method(D_METHOD("get_temperature"), &LlamaContext::get_temperature);
ClassDB::bind_method(D_METHOD("set_temperature", "temperature"), &LlamaContext::set_temperature);
ClassDB::add_property("LlamaContext", PropertyInfo(Variant::FLOAT, "temperature"), "set_temperature", "get_temperature");

ClassDB::bind_method(D_METHOD("get_top_p"), &LlamaContext::get_top_p);
ClassDB::bind_method(D_METHOD("set_top_p", "top_p"), &LlamaContext::set_top_p);
ClassDB::add_property("LlamaContext", PropertyInfo(Variant::FLOAT, "top_p"), "set_top_p", "get_top_p");

ClassDB::bind_method(D_METHOD("get_frequency_penalty"), &LlamaContext::get_frequency_penalty);
ClassDB::bind_method(D_METHOD("set_frequency_penalty", "frequency_penalty"), &LlamaContext::set_frequency_penalty);
ClassDB::add_property("LlamaContext", PropertyInfo(Variant::FLOAT, "frequency_penalty"), "set_frequency_penalty", "get_frequency_penalty");

ClassDB::bind_method(D_METHOD("get_presence_penalty"), &LlamaContext::get_presence_penalty);
ClassDB::bind_method(D_METHOD("set_presence_penalty", "presence_penalty"), &LlamaContext::set_presence_penalty);
ClassDB::add_property("LlamaContext", PropertyInfo(Variant::FLOAT, "presence_penalty"), "set_presence_penalty", "get_presence_penalty");

ClassDB::bind_method(D_METHOD("get_n_ctx"), &LlamaContext::get_n_ctx);
ClassDB::bind_method(D_METHOD("set_n_ctx", "n_ctx"), &LlamaContext::set_n_ctx);
ClassDB::add_property("LlamaContext", PropertyInfo(Variant::INT, "n_ctx"), "set_n_ctx", "get_n_ctx");
@@ -106,13 +122,13 @@ void LlamaContext::__thread_loop() {
shared_prefix_idx = std::min(context_tokens.size(), request_tokens.size());
}

-bool rm_success = llama_kv_cache_seq_rm(ctx, 0, shared_prefix_idx, -1);
+bool rm_success = llama_kv_cache_seq_rm(ctx, -1, shared_prefix_idx, -1);
if (!rm_success) {
UtilityFunctions::printerr(vformat("%s: Failed to remove tokens from kv cache", __func__));
Dictionary response;
response["id"] = req.id;
response["error"] = "Failed to remove tokens from kv cache";
-call_deferred("emit_signal", "completion_generated", response);
+call_thread_safe("emit_signal", "completion_generated", response);
continue;
}
context_tokens.erase(context_tokens.begin() + shared_prefix_idx, context_tokens.end());
@@ -128,6 +144,14 @@ void LlamaContext::__thread_loop() {
sequences.push_back(std::vector<llama_token>(request_tokens.begin() + i, request_tokens.begin() + std::min(i + batch_size, request_tokens.size())));
}

printf("Request tokens: \n");
for (auto sequence : sequences) {
for (auto token : sequence) {
printf("%s", llama_token_to_piece(ctx, token).c_str());
}
}
printf("\n");

int curr_token_pos = context_tokens.size();
bool decode_failed = false;

@@ -155,7 +179,7 @@ void LlamaContext::__thread_loop() {
Dictionary response;
response["id"] = req.id;
response["error"] = "llama_decode() failed";
-call_deferred("emit_signal", "completion_generated", response);
+call_thread_safe("emit_signal", "completion_generated", response);
continue;
}

@@ -171,17 +195,17 @@ void LlamaContext::__thread_loop() {
Dictionary response;
response["id"] = req.id;

+context_tokens.push_back(new_token_id);

if (llama_token_is_eog(model->model, new_token_id) || curr_token_pos == n_len) {
response["done"] = true;
-call_deferred("emit_signal", "completion_generated", response);
+call_thread_safe("emit_signal", "completion_generated", response);
break;
}

-context_tokens.push_back(new_token_id);

response["text"] = llama_token_to_piece(ctx, new_token_id).c_str();
response["done"] = false;
-call_deferred("emit_signal", "completion_generated", response);
+call_thread_safe("emit_signal", "completion_generated", response);

llama_batch_clear(batch);

@@ -199,11 +223,9 @@ void LlamaContext::__thread_loop() {
Dictionary response;
response["id"] = req.id;
response["error"] = "llama_decode() failed";
-call_deferred("emit_signal", "completion_generated", response);
+call_thread_safe("emit_signal", "completion_generated", response);
continue;
}

-llama_sampling_reset(sampling_ctx);
}
}

@@ -258,6 +280,34 @@ void LlamaContext::set_n_len(int n_len) {
this->n_len = n_len;
}

float LlamaContext::get_temperature() {
return sampling_params.temp;
}
void LlamaContext::set_temperature(float temperature) {
sampling_params.temp = temperature;
}

float LlamaContext::get_top_p() {
return sampling_params.top_p;
}
void LlamaContext::set_top_p(float top_p) {
sampling_params.top_p = top_p;
}

float LlamaContext::get_frequency_penalty() {
return sampling_params.penalty_freq;
}
void LlamaContext::set_frequency_penalty(float frequency_penalty) {
sampling_params.penalty_freq = frequency_penalty;
}

float LlamaContext::get_presence_penalty() {
return sampling_params.penalty_present;
}
void LlamaContext::set_presence_penalty(float presence_penalty) {
sampling_params.penalty_present = presence_penalty;
}

void LlamaContext::_exit_tree() {
if (Engine::get_singleton()->is_editor_hint()) {
return;
@@ -275,5 +325,6 @@ void LlamaContext::_exit_tree() {
llama_free(ctx);
}

+llama_sampling_free(sampling_ctx);
llama_backend_free();
}
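
The new setters write directly into sampling_params, so values set from GDScript apply to subsequent sampling. The ranges below follow general llama.cpp conventions and are not enforced by these setters; the snippet is illustrative:

llama_context.temperature = 0.8        # lower = more deterministic, higher = more random
llama_context.top_p = 0.95             # nucleus-sampling cutoff in (0.0, 1.0]; 1.0 disables it
llama_context.frequency_penalty = 0.1  # 0.0 disables the penalty
llama_context.presence_penalty = 0.1   # 0.0 disables the penalty
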
8 changes: 8 additions & 0 deletions src/llama_context.h
@@ -51,6 +51,14 @@ class LlamaContext : public Node {
void set_n_ctx(int n_ctx);
int get_n_len();
void set_n_len(int n_len);
float get_temperature();
void set_temperature(float temperature);
float get_top_p();
void set_top_p(float top_p);
float get_frequency_penalty();
void set_frequency_penalty(float frequency_penalty);
float get_presence_penalty();
void set_presence_penalty(float presence_penalty);

virtual PackedStringArray _get_configuration_warnings() const override;
virtual void _ready() override;
