Skip to content

Commit

Permalink
feat(embedding): add createEmbedding method to LLamaCPP class
Browse files Browse the repository at this point in the history
feat(embedding): add llama_n_embd and llama_get_embeddings methods to LLamaCPPFFI class

test(LLamaCPPTest.php): add testCreateEmbedding method to test createEmbedding function in LLamaCPP class
  • Loading branch information
kambo-1st committed Apr 24, 2023
1 parent 5977e98 commit 741b9f0
Show file tree
Hide file tree
Showing 4 changed files with 94 additions and 0 deletions.
21 changes: 21 additions & 0 deletions examples/embedding.php
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
<?php

// Example: compute an embedding vector for a piece of text.
//
// Requires a model converted to the ggjt format placed in ./models and a
// context created with `embedding: true` (otherwise the native library
// returns no embeddings).

declare(strict_types=1);

require_once __DIR__ . '/../vendor/autoload.php';

use Kambo\LLamaCPP\LLamaCPP;
use Kambo\LLamaCPP\Context;
use Kambo\LLamaCPP\Parameters\ModelParameters;

// Plain input text to embed (this is an embedding example, not text generation).
$text = 'The quick brown fox jumps over the lazy dog.';

$context = Context::createWithParameter(
    new ModelParameters(
        modelPath: __DIR__ . '/models/ggjt-model.bin',
        // Must be enabled so llama_get_embeddings() returns a vector.
        embedding: true,
    )
);
$llama = new LLamaCPP($context);

$embeddings = $llama->createEmbedding($text);

var_dump($embeddings);
20 changes: 20 additions & 0 deletions src/LLamaCPP.php
Original file line number Diff line number Diff line change
Expand Up @@ -92,4 +92,24 @@ public function generateAll(string $prompt, ?GenerationParameters $generation =

return implode('', $tokens);
}

/**
 * Compute the embedding vector for the given text.
 *
 * Tokenizes the text, evaluates the tokens one by one, and then reads the
 * resulting embedding out of the native context. The context must have been
 * created with embeddings enabled, otherwise the native library has no
 * embedding buffer to return.
 *
 * @param string $text        Text to embed.
 * @param int    $noOfThreads Number of threads used for evaluation.
 *
 * @return array<int, float> The embedding values, one entry per dimension.
 *
 * @throws \RuntimeException When tokenization fails or embeddings are unavailable.
 */
public function createEmbedding(string $text, int $noOfThreads = 10): array
{
    // One extra slot: tokenization with add_bos=true may prepend a BOS token,
    // so the token count can exceed the byte length of the input.
    $maxTokens = strlen($text) + 1;
    $input     = $this->ffi->newArray('llama_token', $maxTokens);

    $nOfTok = $this->ffi->llama_tokenize($this->context->getCtx(), $text, $input, $maxTokens, true);
    if ($nOfTok < 0) {
        // llama_tokenize signals failure with a negative count.
        throw new \RuntimeException('Tokenization of the input text failed.');
    }

    for ($i = 0; $i < $nOfTok; $i++) {
        $this->ffi->llama_eval($this->context->getCtx(), $input + $i, 1, $i, $noOfThreads);
    }

    $nCount    = $this->ffi->llama_n_embd($this->context->getCtx());
    $embedding = $this->ffi->llama_get_embeddings($this->context->getCtx());
    if ($embedding === null) {
        // llama_get_embeddings() is nullable; dereferencing null CData would crash.
        throw new \RuntimeException(
            'No embeddings are available; was the model loaded with embedding support enabled?'
        );
    }

    $embeddings = [];
    for ($i = 0; $i < $nCount; $i++) {
        $embeddings[] = $embedding[$i];
    }

    return $embeddings;
}
}
17 changes: 17 additions & 0 deletions src/Native/LLamaCPPFFI.php
Original file line number Diff line number Diff line change
Expand Up @@ -214,4 +214,21 @@ public function llama_token_eos()
{
return $this->fii->llama_token_eos();
}

/**
 * Get the dimensionality of the model's embedding vector.
 *
 * Thin pass-through to the native llama_n_embd() function.
 *
 * @param CData $ctx Native llama context handle.
 *
 * @return int Number of values in one embedding vector.
 */
public function llama_n_embd(CData $ctx): int
{
    return $this->fii->llama_n_embd($ctx);
}

/**
 * Get the embeddings for the input evaluated in the given context.
 *
 * Thin pass-through to the native llama_get_embeddings() function. The
 * returned CData points to an array of llama_n_embd() floats.
 *
 * @param CData $ctx Native llama context handle.
 *
 * @return ?CData Pointer to the embedding values, or null — presumably when
 *                the context was not created with embeddings enabled
 *                (NOTE(review): null semantics come from the native library;
 *                confirm against the llama.cpp API). Callers must null-check.
 */
public function llama_get_embeddings(CData $ctx): ?CData
{
    return $this->fii->llama_get_embeddings($ctx);
}
}
36 changes: 36 additions & 0 deletions tests/LLamaCPPTest.php
Original file line number Diff line number Diff line change
Expand Up @@ -103,4 +103,40 @@ public function testGenerateAll()

$this->assertEquals('test', $result);
}

/**
 * createEmbedding() should return the values exposed by the native
 * llama_get_embeddings() buffer as a plain PHP array.
 */
public function testCreateEmbedding()
{
    $this->contextMock->method('getCtx')
        ->willReturn(FFI::new('int'));

    $this->ffiMock->method('newArray')
        ->willReturn(FFI::new('int[100]'));

    // Pretend the input tokenizes to a single token.
    $this->ffiMock->method('llama_tokenize')
        ->willReturn(1);

    $this->ffiMock->method('llama_eval');

    // Five embedding dimensions...
    $this->ffiMock->method('llama_n_embd')
        ->willReturn(5);

    // ...backed by a float buffer, matching the real native type
    // (the production code reads floats, not ints).
    $testArray = FFI::new('float[5]');
    foreach ([5, 2, 3, 4, 1] as $key => $value) {
        $testArray[$key] = (float) $value;
    }
    $this->ffiMock->method('llama_get_embeddings')
        ->willReturn($testArray);

    $llamaCPP = new LLamaCPP($this->contextMock, $this->eventDispatcherMock, $this->ffiMock);

    $result = $llamaCPP->createEmbedding('test');

    $this->assertEquals(
        [5, 2, 3, 4, 1],
        $result
    );
}
}

0 comments on commit 741b9f0

Please sign in to comment.