working inference
hazelnutcloud committed May 10, 2024
1 parent bc4b614 commit 6960f3b
Showing 6 changed files with 164 additions and 21 deletions.
5 changes: 4 additions & 1 deletion godot/examples/simple/TextEdit.gd
@@ -10,8 +10,11 @@ func _gui_input(event: InputEvent) -> void:
		accept_event()
	if keycode == KEY_ENTER | KEY_MASK_SHIFT and event.is_pressed():
		insert_text_at_caret("\n")
		accept_event()

func _on_button_pressed() -> void:
	handle_submit()

func handle_submit() -> void:
	submit.emit(text)
	text = ""

7 changes: 4 additions & 3 deletions godot/examples/simple/simple.gd
@@ -16,6 +16,7 @@ func handle_input(input: String) -> void:
	var id = llama_context.request_completion(input)
	print("request id: ", id)

	var chunk = await llama_context.completion_generated
	print("new chunk: ", chunk)


func _on_llama_context_completion_generated(chunk: Dictionary) -> void:
	print("new chunk: ", chunk)
14 changes: 4 additions & 10 deletions godot/examples/simple/simple.tscn
@@ -3,8 +3,8 @@
[ext_resource type="LlamaModel" path="res://models/Phi-3-mini-128k-instruct.Q5_K_M.gguf" id="1_ff70a"]
[ext_resource type="Texture2D" uid="uid://dplw232htshgc" path="res://addons/godot-llama-cpp/assets/godot-llama-cpp-1024x1024.svg" id="1_gjsev"]
[ext_resource type="Script" path="res://examples/simple/simple.gd" id="1_sruc3"]
[ext_resource type="PackedScene" uid="uid://t862t0v8ht2q" path="res://examples/simple/message.tscn" id="2_7iip7"]
[ext_resource type="Script" path="res://examples/simple/TextEdit.gd" id="2_7usqw"]
[ext_resource type="Script" path="res://examples/simple/form.gd" id="2_p1ih5"]

[node name="Node" type="Node"]
script = ExtResource("1_sruc3")
@@ -41,20 +41,14 @@ layout_mode = 2
size_flags_horizontal = 3
size_flags_vertical = 3

[node name="RichTextLabel" type="RichTextLabel" parent="Panel/MarginContainer/VBoxContainer/ScrollContainer/MessagesContainer"]
[node name="RichTextLabel2" parent="Panel/MarginContainer/VBoxContainer/ScrollContainer/MessagesContainer" instance=ExtResource("2_7iip7")]
layout_mode = 2
focus_mode = 2
text = "How can I help you?"
fit_content = true
scroll_active = false
selection_enabled = true

[node name="HBoxContainer" type="HBoxContainer" parent="Panel/MarginContainer/VBoxContainer"]
layout_mode = 2
script = ExtResource("2_p1ih5")

[node name="TextEdit" type="TextEdit" parent="Panel/MarginContainer/VBoxContainer/HBoxContainer"]
unique_name_in_owner = true
custom_minimum_size = Vector2(2.08165e-12, 100)
layout_mode = 2
size_flags_horizontal = 3
@@ -73,5 +67,5 @@ model = ExtResource("1_ff70a")
unique_name_in_owner = true

[connection signal="submit" from="Panel/MarginContainer/VBoxContainer/HBoxContainer/TextEdit" to="." method="_on_text_edit_submit"]
[connection signal="pressed" from="Panel/MarginContainer/VBoxContainer/HBoxContainer/Button" to="Panel/MarginContainer/VBoxContainer/HBoxContainer" method="_on_button_pressed"]
[connection signal="pressed" from="Panel/MarginContainer/VBoxContainer/HBoxContainer/Button" to="Panel/MarginContainer/VBoxContainer/HBoxContainer/Button" method="_on_pressed"]
[connection signal="pressed" from="Panel/MarginContainer/VBoxContainer/HBoxContainer/Button" to="Panel/MarginContainer/VBoxContainer/HBoxContainer/TextEdit" method="_on_button_pressed"]
[connection signal="completion_generated" from="LlamaContext" to="." method="_on_llama_context_completion_generated"]
1 change: 1 addition & 0 deletions godot/project.godot
@@ -11,6 +11,7 @@ config_version=5
[application]

config/name="godot-llama-cpp"
run/main_scene="res://examples/simple/simple.tscn"
config/features=PackedStringArray("4.2", "Forward Plus")
config/icon="res://icon.svg"

148 changes: 142 additions & 6 deletions src/llama_context.cpp
Expand Up @@ -2,12 +2,13 @@
#include "common.h"
#include "llama.h"
#include "llama_model.h"
#include <algorithm>
#include <godot_cpp/classes/engine.hpp>
#include <godot_cpp/classes/os.hpp>
#include <godot_cpp/classes/worker_thread_pool.hpp>
#include <godot_cpp/core/class_db.hpp>
#include <godot_cpp/variant/utility_functions.hpp>
#include <godot_cpp/variant/dictionary.hpp>
#include <godot_cpp/variant/utility_functions.hpp>

using namespace godot;

@@ -24,6 +25,10 @@ void LlamaContext::_bind_methods() {
ClassDB::bind_method(D_METHOD("set_n_ctx", "n_ctx"), &LlamaContext::set_n_ctx);
ClassDB::add_property("LlamaContext", PropertyInfo(Variant::INT, "n_ctx"), "set_n_ctx", "get_n_ctx");

ClassDB::bind_method(D_METHOD("get_n_len"), &LlamaContext::get_n_len);
ClassDB::bind_method(D_METHOD("set_n_len", "n_len"), &LlamaContext::set_n_len);
ClassDB::add_property("LlamaContext", PropertyInfo(Variant::INT, "n_len"), "set_n_len", "get_n_len");

ClassDB::bind_method(D_METHOD("request_completion", "prompt"), &LlamaContext::request_completion);
ClassDB::bind_method(D_METHOD("__thread_loop"), &LlamaContext::__thread_loop);

@@ -63,6 +68,9 @@ void LlamaContext::_ready() {
UtilityFunctions::printerr(vformat("%s: Failed to initialize llama context, null ctx", __func__));
return;
}

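// Sampling state comes from llama.cpp's common sampling helpers (common.h);
// it is reset after each completed request.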
sampling_ctx = llama_sampling_init(sampling_params);

UtilityFunctions::print(vformat("%s: Context initialized", __func__));

thread->start(callable_mp(this, &LlamaContext::__thread_loop));
@@ -73,6 +81,10 @@ void LlamaContext::__thread_loop() {
semaphore->wait();

mutex->lock();
if (exit_thread) {
mutex->unlock();
break;
}
if (completion_requests.size() == 0) {
mutex->unlock();
continue;
@@ -83,10 +95,115 @@

UtilityFunctions::print(vformat("%s: Running completion for prompt id: %d", __func__, req.id));

Dictionary chunk;
chunk["id"] = req.id;
chunk["text"] = "Hello, world!";
call_deferred("emit_signal", "completion_generated", chunk);
std::vector<llama_token> request_tokens;
request_tokens = ::llama_tokenize(ctx, req.prompt.utf8().get_data(), true);

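// Find the longest prefix the new prompt shares with the tokens already in
// the context, so only the divergent suffix has to be decoded again.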
size_t shared_prefix_idx = 0;
auto diff = std::mismatch(context_tokens.begin(), context_tokens.end(), request_tokens.begin(), request_tokens.end());
if (diff.first != context_tokens.end()) {
shared_prefix_idx = std::distance(context_tokens.begin(), diff.first);
} else {
shared_prefix_idx = std::min(context_tokens.size(), request_tokens.size());
}

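// Evict everything past the shared prefix from sequence 0 of the KV cache
// (p1 = -1 means "through the end of the sequence").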
bool rm_success = llama_kv_cache_seq_rm(ctx, 0, shared_prefix_idx, -1);
if (!rm_success) {
UtilityFunctions::printerr(vformat("%s: Failed to remove tokens from kv cache", __func__));
Dictionary response;
response["id"] = req.id;
response["error"] = "Failed to remove tokens from kv cache";
call_deferred("emit_signal", "completion_generated", response);
continue;
}
context_tokens.erase(context_tokens.begin() + shared_prefix_idx, context_tokens.end());
request_tokens.erase(request_tokens.begin(), request_tokens.begin() + shared_prefix_idx);

uint32_t batch_size = std::min(ctx_params.n_batch, (uint32_t)request_tokens.size());

llama_batch batch = llama_batch_init(batch_size, 0, 1);

// chunk request_tokens into sequences of size batch_size
std::vector<std::vector<llama_token>> sequences;
for (size_t i = 0; i < request_tokens.size(); i += batch_size) {
sequences.push_back(std::vector<llama_token>(request_tokens.begin() + i, request_tokens.begin() + std::min(i + batch_size, request_tokens.size())));
}

int curr_token_pos = context_tokens.size();
bool decode_failed = false;

for (size_t i = 0; i < sequences.size(); i++) {
llama_batch_clear(batch);

std::vector<llama_token> sequence = sequences[i];

for (size_t j = 0; j < sequence.size(); j++) {
// curr_token_pos already advances once per token, so it alone gives the position.
llama_batch_add(batch, sequence[j], curr_token_pos, { 0 }, false);
curr_token_pos++;
}

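// Only the last token of the final chunk needs logits; that is the position
// the first sampled token comes from.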
if (i == sequences.size() - 1) {
batch.logits[batch.n_tokens - 1] = true;
}

if (llama_decode(ctx, batch) != 0) {
decode_failed = true;
break;
}
}

if (decode_failed) {
Dictionary response;
response["id"] = req.id;
response["error"] = "llama_decode() failed";
call_deferred("emit_signal", "completion_generated", response);
continue;
}

context_tokens.insert(context_tokens.end(), request_tokens.begin(), request_tokens.end());

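// Generation loop: sample a token, stream it out as a chunk, then decode it
// so the context advances by one position.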
while (true) {
if (exit_thread) {
return;
}
llama_token new_token_id = llama_sampling_sample(sampling_ctx, ctx, NULL, batch.n_tokens - 1);
llama_sampling_accept(sampling_ctx, ctx, new_token_id, true);

Dictionary response;
response["id"] = req.id;

if (llama_token_is_eog(model->model, new_token_id) || curr_token_pos == n_len) {
response["done"] = true;
call_deferred("emit_signal", "completion_generated", response);
break;
}

context_tokens.push_back(new_token_id);

response["text"] = llama_token_to_piece(ctx, new_token_id).c_str();
response["done"] = false;
call_deferred("emit_signal", "completion_generated", response);

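// Feed the sampled token back in as a single-token batch for the next step.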
llama_batch_clear(batch);

llama_batch_add(batch, new_token_id, curr_token_pos, { 0 }, true);

curr_token_pos++;

if (llama_decode(ctx, batch) != 0) {
decode_failed = true;
break;
}
}

if (decode_failed) {
Dictionary response;
response["id"] = req.id;
response["error"] = "llama_decode() failed";
call_deferred("emit_signal", "completion_generated", response);
continue;
}

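// Clear sampler state (e.g. repetition-penalty history) so the next request
// starts fresh.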
llama_sampling_reset(sampling_ctx);
}
}

@@ -134,7 +251,26 @@ void LlamaContext::set_n_ctx(int n_ctx) {
ctx_params.n_ctx = n_ctx;
}

LlamaContext::~LlamaContext() {
int LlamaContext::get_n_len() {
return n_len;
}
void LlamaContext::set_n_len(int n_len) {
this->n_len = n_len;
}

void LlamaContext::_exit_tree() {
if (Engine::get_singleton()->is_editor_hint()) {
return;
}

mutex->lock();
exit_thread = true;
mutex->unlock();

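// Wake the worker so it can observe exit_thread and leave its loop.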
semaphore->post();

thread->wait_to_finish();

if (ctx) {
llama_free(ctx);
}
10 changes: 9 additions & 1 deletion src/llama_context.h
@@ -2,6 +2,7 @@
#define LLAMA_CONTEXT_H

#include "llama.h"
#include "common.h"
#include "llama_model.h"
#include <godot_cpp/classes/mutex.hpp>
#include <godot_cpp/classes/node.hpp>
@@ -21,13 +22,18 @@ class LlamaContext : public Node {
private:
Ref<LlamaModel> model;
llama_context *ctx = nullptr;
llama_sampling_context *sampling_ctx = nullptr;
llama_context_params ctx_params;
llama_sampling_params sampling_params;
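// Hard cap on total positions (prompt + generated) for one completion.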
int n_len = 1024;
int request_id = 0;
Vector<completion_request> completion_requests;

Ref<Thread> thread;
Ref<Semaphore> semaphore;
Ref<Mutex> mutex;
std::vector<llama_token> context_tokens;
bool exit_thread = false;

protected:
static void _bind_methods();
@@ -43,11 +49,13 @@ class LlamaContext : public Node {
void set_seed(int seed);
int get_n_ctx();
void set_n_ctx(int n_ctx);
int get_n_len();
void set_n_len(int n_len);

virtual PackedStringArray _get_configuration_warnings() const override;
virtual void _ready() override;
virtual void _exit_tree() override;
LlamaContext();
~LlamaContext();
};
} //namespace godot
