revert tokenize ffi (#4761)
* Revert "use `int32_t` for call to tokenize (#4738)"

This reverts commit 763bb65.

* Revert "vocab only"

This reverts commit bf54c84.

* Revert "use ffi for tokenizing/detokenizing"

This reverts commit 26a00a0.
mxyng committed Jun 1, 2024
1 parent f6b622c commit 829ff87
Showing 3 changed files with 144 additions and 72 deletions.
llm/ext_server/server.cpp (43 additions, 0 deletions)
@@ -2625,6 +2625,21 @@ static json format_partial_response(
return res;
}

static json format_tokenizer_response(const std::vector<llama_token> &tokens)
{
return json {
{"tokens", tokens}
};
}

static json format_detokenized_response(std::string content)
{
return json {
{"content", content}
};
}


static void log_server_request(const httplib::Request &req, const httplib::Response &res)
{
// skip GH copilot requests when using default port
@@ -3114,6 +3129,34 @@ int main(int argc, char **argv) {
}
});

svr.Post("/tokenize", [&llama](const httplib::Request &req, httplib::Response &res)
{
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
const json body = json::parse(req.body);
std::vector<llama_token> tokens;
if (body.count("content") != 0)
{
tokens = llama.tokenize(body["content"], false);
}
const json data = format_tokenizer_response(tokens);
return res.set_content(data.dump(), "application/json; charset=utf-8");
});

svr.Post("/detokenize", [&llama](const httplib::Request &req, httplib::Response &res)
{
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
const json body = json::parse(req.body);
std::string content;
if (body.count("tokens") != 0)
{
const std::vector<llama_token> tokens = body["tokens"];
content = tokens_to_str(llama.ctx, tokens.cbegin(), tokens.cend());
}

const json data = format_detokenized_response(content);
return res.set_content(data.dump(), "application/json; charset=utf-8");
});

svr.Post("/embedding", [&llama](const httplib::Request &req, httplib::Response &res)
{
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
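
For reference, a minimal client sketch of the HTTP interface these handlers restore: it posts a prompt to /tokenize and reads back the token ids. The port (8080), the prompt, and the standalone main program are assumptions for illustration; the real runner's port is assigned when llm/server.go starts it.

package main

import (
	"bytes"
	"encoding/json"
	"fmt"
	"net/http"
)

func main() {
	// Assumed port for illustration; the runner's actual port is chosen at startup.
	payload, err := json.Marshal(map[string]string{"content": "why is the sky blue?"})
	if err != nil {
		panic(err)
	}

	resp, err := http.Post("http://127.0.0.1:8080/tokenize", "application/json", bytes.NewReader(payload))
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	// The handler above responds with {"tokens": [...]}.
	var out struct {
		Tokens []int `json:"tokens"`
	}
	if err := json.NewDecoder(resp.Body).Decode(&out); err != nil {
		panic(err)
	}
	fmt.Println("token count:", len(out.Tokens))
}
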
llm/llm.go (0 additions, 60 deletions)
@@ -12,7 +12,6 @@ package llm
import "C"
import (
"fmt"
"strings"
"unsafe"
)

@@ -38,62 +37,3 @@ func Quantize(infile, outfile string, ftype fileType) error {

return nil
}

type llamaModel struct {
m *C.struct_llama_model
}

func newLlamaModel(p string) *llamaModel {
cs := C.CString(p)
defer C.free(unsafe.Pointer(cs))

params := C.llama_model_default_params()
params.vocab_only = true

return &llamaModel{
C.llama_load_model_from_file(cs, params),
}
}

func (llm *llamaModel) Close() {
C.llama_free_model(llm.m)
}

func (llm *llamaModel) Tokenize(s string) []int {
cs := C.CString(s)
defer C.free(unsafe.Pointer(cs))

ltokens := make([]C.llama_token, len(s)+2)
n := C.llama_tokenize(
llm.m,
cs,
C.int32_t(len(s)),
&ltokens[0],
C.int32_t(len(ltokens)),
false,
true,
)

if n < 0 {
return nil
}

tokens := make([]int, n)
for i := 0; i < int(n); i++ {
tokens[i] = int(ltokens[i])
}

return tokens
}

func (llm *llamaModel) Detokenize(i32s []int) string {
var sb strings.Builder
for _, i32 := range i32s {
c := make([]byte, 512)
if n := C.llama_token_to_piece(llm.m, C.llama_token(i32), (*C.char)(unsafe.Pointer(&c[0])), C.int(len(c)), false); n > 0 {
sb.WriteString(unsafe.String(&c[0], n))
}
}

return sb.String()
}
llm/server.go (101 additions, 12 deletions)
@@ -57,8 +57,6 @@ type llmServer struct {
loadDuration time.Duration // Record how long it took the model to load
loadProgress float32

*llamaModel

sem *semaphore.Weighted
}

@@ -311,7 +309,6 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
totalLayers: ggml.KV().BlockCount() + 1,
gpuCount: gpuCount,
done: make(chan error, 1),
llamaModel: newLlamaModel(model),
}

s.cmd.Env = os.Environ()
@@ -849,12 +846,12 @@ func (s *llmServer) Embedding(ctx context.Context, prompt string) ([]float64, er
return nil, fmt.Errorf("unexpected server status: %s", status.ToString())
}

var b bytes.Buffer
if err := json.NewEncoder(&b).Encode(EmbeddingRequest{Content: prompt}); err != nil {
data, err := json.Marshal(TokenizeRequest{Content: prompt})
if err != nil {
return nil, fmt.Errorf("error marshaling embed data: %w", err)
}

req, err := http.NewRequestWithContext(ctx, http.MethodPost, fmt.Sprintf("http://127.0.0.1:%d/embedding", s.port), &b)
req, err := http.NewRequestWithContext(ctx, http.MethodPost, fmt.Sprintf("http://127.0.0.1:%d/embedding", s.port), bytes.NewBuffer(data))
if err != nil {
return nil, fmt.Errorf("error creating embed request: %w", err)
}
@@ -884,12 +881,108 @@ func (s *llmServer) Embedding(ctx context.Context, prompt string) ([]float64, er
return embedding.Embedding, nil
}

type TokenizeRequest struct {
Content string `json:"content"`
}

type TokenizeResponse struct {
Tokens []int `json:"tokens"`
}

func (s *llmServer) Tokenize(ctx context.Context, content string) ([]int, error) {
return s.llamaModel.Tokenize(content), nil
// Make sure the server is ready
status, err := s.getServerStatus(ctx)
if err != nil {
return nil, err
} else if status != ServerStatusReady && status != ServerStatusNoSlotsAvailable {
return nil, fmt.Errorf("unexpected server status: %s", status.ToString())
}

data, err := json.Marshal(TokenizeRequest{Content: content})
if err != nil {
return nil, fmt.Errorf("marshaling encode data: %w", err)
}

req, err := http.NewRequestWithContext(ctx, http.MethodPost, fmt.Sprintf("http://127.0.0.1:%d/tokenize", s.port), bytes.NewBuffer(data))
if err != nil {
return nil, fmt.Errorf("encode request: %w", err)
}
req.Header.Set("Content-Type", "application/json")

resp, err := http.DefaultClient.Do(req)
if err != nil {
return nil, fmt.Errorf("do encode request: %w", err)
}
defer resp.Body.Close()

body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("read encode request: %w", err)
}

if resp.StatusCode >= 400 {
log.Printf("llm encode error: %s", body)
return nil, fmt.Errorf("%s", body)
}

var encoded TokenizeResponse
if err := json.Unmarshal(body, &encoded); err != nil {
return nil, fmt.Errorf("unmarshal encode response: %w", err)
}

return encoded.Tokens, nil
}

type DetokenizeRequest struct {
Tokens []int `json:"tokens"`
}

type DetokenizeResponse struct {
Content string `json:"content"`
}

func (s *llmServer) Detokenize(ctx context.Context, tokens []int) (string, error) {
return s.llamaModel.Detokenize(tokens), nil
// Make sure the server is ready
status, err := s.getServerStatus(ctx)
if err != nil {
return "", err
} else if status != ServerStatusReady && status != ServerStatusNoSlotsAvailable {
return "", fmt.Errorf("unexpected server status: %s", status.ToString())
}

data, err := json.Marshal(DetokenizeRequest{Tokens: tokens})
if err != nil {
return "", fmt.Errorf("marshaling decode data: %w", err)
}

req, err := http.NewRequestWithContext(ctx, http.MethodPost, fmt.Sprintf("http://127.0.0.1:%d/detokenize", s.port), bytes.NewBuffer(data))
if err != nil {
return "", fmt.Errorf("decode request: %w", err)
}
req.Header.Set("Content-Type", "application/json")

resp, err := http.DefaultClient.Do(req)
if err != nil {
return "", fmt.Errorf("do decode request: %w", err)
}
defer resp.Body.Close()

body, err := io.ReadAll(resp.Body)
if err != nil {
return "", fmt.Errorf("read decode request: %w", err)
}

if resp.StatusCode >= 400 {
log.Printf("llm decode error: %s", body)
return "", fmt.Errorf("%s", body)
}

var decoded DetokenizeResponse
if err := json.Unmarshal(body, &decoded); err != nil {
return "", fmt.Errorf("unmarshal encode response: %w", err)
}

return decoded.Content, nil
}

func (s *llmServer) Close() error {
@@ -907,10 +1000,6 @@ func (s *llmServer) Close() error {
slog.Debug("llama server stopped")
}

if s.llamaModel != nil {
s.llamaModel.Close()
}

return nil
}

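Taken together, a rough sketch of how these reverted methods are used from inside the llm package: tokenize a prompt over HTTP, then detokenize the result. The roundTrip helper and the already-running *llmServer value are assumptions for illustration, not code from this change.

// Sketch only: assumes s was created by NewLlamaServer and its runner is ready.
func roundTrip(ctx context.Context, s *llmServer, prompt string) error {
	tokens, err := s.Tokenize(ctx, prompt)
	if err != nil {
		return fmt.Errorf("tokenize: %w", err)
	}

	text, err := s.Detokenize(ctx, tokens)
	if err != nil {
		return fmt.Errorf("detokenize: %w", err)
	}

	// Both calls now round-trip through the runner's /tokenize and /detokenize
	// endpoints rather than the in-process llama bindings removed in llm/llm.go.
	slog.Debug("tokenize round trip", "count", len(tokens), "text", text)
	return nil
}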
