Skip to content

Commit

Permalink
feat(Context.php): add getModelParameters method to Context class
Browse files Browse the repository at this point in the history
feat(InvalidArgumentException.php): add InvalidArgumentException class
feat(LLamaCPP.php): add validation for embedding generation in createEmbedding method

test(LLamaCPPTest.php): add test case for createEmbedding method with embedding set to false and expect InvalidArgumentException
  • Loading branch information
kambo-1st committed Apr 24, 2023
1 parent 741b9f0 commit 5762c8f
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 0 deletions.
8 changes: 8 additions & 0 deletions src/Context.php
Original file line number Diff line number Diff line change
Expand Up @@ -49,4 +49,12 @@ public function getCtx(): CData
{
return $this->ctx;
}

/**
* @return ModelParameters
*/
public function getModelParameters(): ModelParameters
{
return $this->modelParameters;
}
}
7 changes: 7 additions & 0 deletions src/Exception/InvalidArgumentException.php
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
<?php

namespace Kambo\LLamaCPP\Exception;

final class InvalidArgumentException extends LLamaCppException
{
}
5 changes: 5 additions & 0 deletions src/LLamaCPP.php
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
use Symfony\Component\EventDispatcher\EventDispatcherInterface;
use Kambo\LLamaCPP\Native\LLamaCPPFFI;
use Kambo\LLamaCPP\Events\TokenGeneratedEvent;
use Kambo\LLamaCPP\Exception\InvalidArgumentException;
use Generator;

use function strlen;
Expand Down Expand Up @@ -95,6 +96,10 @@ public function generateAll(string $prompt, ?GenerationParameters $generation =

public function createEmbedding(string $text, int $noOfThreads = 10): array
{
if (!$this->context->getModelParameters()->isEmbedding()) {
throw new InvalidArgumentException('Generation must of embedding must be turned on.');
}

$input = $this->ffi->newArray('llama_token', strlen($text));
$nOfTok = $this->ffi->llama_tokenize($this->context->getCtx(), $text, $input, strlen($text), true);

Expand Down
24 changes: 24 additions & 0 deletions tests/LLamaCPPTest.php
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
use Kambo\LLamaCPP\Context;
use Kambo\LLamaCPP\LLamaCPP;
use Kambo\LLamaCPP\Native\LLamaCPPFFI;
use Kambo\LLamaCPP\Parameters\ModelParameters;
use Kambo\LLamaCPP\Exception\InvalidArgumentException;
use Symfony\Component\EventDispatcher\EventDispatcherInterface;
use Generator;
use FFI;
Expand Down Expand Up @@ -106,6 +108,13 @@ public function testGenerateAll()

public function testCreateEmbedding()
{
$modelParameters = new ModelParameters(
modelPath: 'test',
embedding: true,
);
$this->contextMock->method('getModelParameters')
->willReturn($modelParameters);

$this->contextMock->method('getCtx')
->willReturn(FFI::new('int'));

Expand Down Expand Up @@ -139,4 +148,19 @@ public function testCreateEmbedding()
$result
);
}

public function testCreateEmbeddingFail()
{
$modelParameters = new ModelParameters(
modelPath: 'test',
embedding: false,
);
$this->contextMock->method('getModelParameters')
->willReturn($modelParameters);

$llamaCPP = new LLamaCPP($this->contextMock, $this->eventDispatcherMock, $this->ffiMock);

$this->expectException(InvalidArgumentException::class);
$llamaCPP->createEmbedding('test');
}
}

0 comments on commit 5762c8f

Please sign in to comment.