From b1791806dceae48f730b22a32fb07f306a54e88d Mon Sep 17 00:00:00 2001 From: Slava Razum Date: Tue, 20 Dec 2022 21:46:42 +0200 Subject: [PATCH 1/7] Stream tests --- tests/Pest.php | 3 +- tests/Resources/Completions.php | 37 +++++++++++++++++++++++ tests/Streams/EventStream.php | 41 ++++++++++++++++++++++++++ tests/Streams/Stream.php | 21 +++++++++++++ tests/Transporters/HttpTransporter.php | 30 +++++++++++++++---- 5 files changed, 125 insertions(+), 7 deletions(-) create mode 100644 tests/Streams/EventStream.php create mode 100644 tests/Streams/Stream.php diff --git a/tests/Pest.php b/tests/Pest.php index 72156a21..12e4ebb6 100644 --- a/tests/Pest.php +++ b/tests/Pest.php @@ -1,13 +1,14 @@ completionTokens->toBe(16) ->totalTokens->toBe(17); }); + +test('create stream', function () { + $client = mockClient('POST', 'completions', [ + 'model' => 'da-vince', + 'prompt' => 'hi', + 'stream' => true, + ], new EventStream(Utils::streamFor('data: ' . json_encode(completion())))); + + $result = $client->completions()->create([ + 'model' => 'da-vince', + 'prompt' => 'hi', + 'stream' => true, + ]); + + expect($result)->toBeInstanceOf(Stream::class); + + $data = iterator_to_array($result->read())[0]; + + expect($data) + ->id->toBe('cmpl-5uS6a68SwurhqAqLBpZtibIITICna') + ->object->toBe('text_completion') + ->created->toBe(1664136088) + ->model->toBe('davinci') + ->choices->toBeArray()->toHaveCount(1) + ->choices->each->toBeInstanceOf(CreateResponseChoice::class) + ->usage->toBeInstanceOf(CreateResponseUsage::class); + + expect($data->choices[0]) + ->text->toBe("el, she elaborates more on the Corruptor's role, suggesting K") + ->index->toBe(0) + ->logprobs->toBe(null) + ->finishReason->toBe('length'); +}); + diff --git a/tests/Streams/EventStream.php b/tests/Streams/EventStream.php new file mode 100644 index 00000000..d116c316 --- /dev/null +++ b/tests/Streams/EventStream.php @@ -0,0 +1,41 @@ +read())[0])->toBe(['text' => 'Hey!']); +}); + +test('skips empty stream lines', function () { + $stream = new EventStream(Utils::streamFor("data: {\"text\": \"Hey\"}\n\ndata: {\"text\": \"there!\"}")); + + expect(iterator_to_array($stream->read()))->toBe([['text' => 'Hey'], ['text' => 'there!']]); +}); + +test('aborts after done message', function () { + $stream = new EventStream(Utils::streamFor("data: {\"text\": \"Hey\"}\ndata: [DONE]\ndata: {\"text\": \"there!\"}")); + + expect(iterator_to_array($stream->read()))->toBe([['text' => 'Hey']]); +}); + +test('stream message serialization error', function () { + $stream = new EventStream(Utils::streamFor("data: invalid")); + + iterator_to_array($stream->read()); +})->throws(UnserializableResponse::class); + +test('stream message error', function () { + $stream = new EventStream(Utils::streamFor('data: ' . json_encode(['error' => [ + 'message' => 'Something went wrong.', + 'type' => 'invalid_request_error', + 'param' => null, + 'code' => 'invalid_request_error', + ]]))); + + iterator_to_array($stream->read()); +})->throws(ErrorException::class, 'Something went wrong'); diff --git a/tests/Streams/Stream.php b/tests/Streams/Stream.php new file mode 100644 index 00000000..3ee26e7e --- /dev/null +++ b/tests/Streams/Stream.php @@ -0,0 +1,21 @@ +read()))->toBe([ + ['text' => 'Hey ', 'test' => true], + ['text' => 'there!', 'test' => true], + ]); +}); diff --git a/tests/Transporters/HttpTransporter.php b/tests/Transporters/HttpTransporter.php index b8806214..83243435 100644 --- a/tests/Transporters/HttpTransporter.php +++ b/tests/Transporters/HttpTransporter.php @@ -1,8 +1,10 @@ client = Mockery::mock(ClientInterface::class); @@ -34,7 +35,7 @@ ])); $this->client - ->shouldReceive('sendRequest') + ->shouldReceive('send') ->once() ->withArgs(function (Psr7Request $request) { expect($request->getMethod())->toBe('GET') @@ -62,7 +63,7 @@ ])); $this->client - ->shouldReceive('sendRequest') + ->shouldReceive('send') ->once() ->andReturn($response); @@ -78,6 +79,23 @@ ]); }); +test('request object stream response', function () { + $payload = Payload::create('completions', ['stream' => true]); + + $response = new Response(200, ['Content-Type' => 'text/event-stream'], json_encode([ + 'qdwq' + ])); + + $this->client + ->shouldReceive('send') + ->once() + ->andReturn($response); + + $response = $this->http->requestObject($payload, true); + + expect($response)->toBeInstanceOf(Stream::class); +}); + test('request object server errors', function () { $payload = Payload::list('models'); @@ -91,7 +109,7 @@ ])); $this->client - ->shouldReceive('sendRequest') + ->shouldReceive('send') ->once() ->andReturn($response); @@ -111,7 +129,7 @@ $headers = Headers::withAuthorization(ApiToken::from('foo')); $this->client - ->shouldReceive('sendRequest') + ->shouldReceive('send') ->once() ->andThrow(new ConnectException('Could not resolve host.', $payload->toRequest($baseUri, $headers))); @@ -128,7 +146,7 @@ $response = new Response(200, [], 'err'); $this->client - ->shouldReceive('sendRequest') + ->shouldReceive('send') ->once() ->andReturn($response); From 6ee0aa8d8de3f1187e598003308d99af2bb47cb9 Mon Sep 17 00:00:00 2001 From: Slava Razum Date: Tue, 20 Dec 2022 21:47:28 +0200 Subject: [PATCH 2/7] Stream support for completion endpoint --- src/Contracts/Stream.php | 13 ++++++ src/Contracts/Transporter.php | 2 +- src/Resources/Completions.php | 14 ++++-- src/Responses/Completions/CreateResponse.php | 6 +-- .../Completions/CreateResponseChoice.php | 2 +- src/Streams/EventStream.php | 46 +++++++++++++++++++ src/Streams/Stream.php | 22 +++++++++ src/Transporters/HttpTransporter.php | 15 +++--- 8 files changed, 105 insertions(+), 15 deletions(-) create mode 100644 src/Contracts/Stream.php create mode 100644 src/Streams/EventStream.php create mode 100644 src/Streams/Stream.php diff --git a/src/Contracts/Stream.php b/src/Contracts/Stream.php new file mode 100644 index 00000000..f3a12290 --- /dev/null +++ b/src/Contracts/Stream.php @@ -0,0 +1,13 @@ +> + */ + public function read(): iterable; +} diff --git a/src/Contracts/Transporter.php b/src/Contracts/Transporter.php index e3c2eb8a..40b6f717 100644 --- a/src/Contracts/Transporter.php +++ b/src/Contracts/Transporter.php @@ -21,7 +21,7 @@ interface Transporter * * @throws ErrorException|UnserializableResponse|TransporterException */ - public function requestObject(Payload $payload): array; + public function requestObject(Payload $payload, bool $stream = false): array|Stream; /** * Sends a content request to a server. diff --git a/src/Resources/Completions.php b/src/Resources/Completions.php index 69cfd569..f8da3791 100644 --- a/src/Resources/Completions.php +++ b/src/Resources/Completions.php @@ -4,7 +4,9 @@ namespace OpenAI\Resources; +use OpenAI\Contracts\Stream as StreamContract; use OpenAI\Responses\Completions\CreateResponse; +use OpenAI\Streams\Stream; use OpenAI\ValueObjects\Transporter\Payload; final class Completions @@ -18,13 +20,17 @@ final class Completions * * @param array $parameters */ - public function create(array $parameters): CreateResponse + public function create(array $parameters): CreateResponse|StreamContract { $payload = Payload::create('completions', $parameters); - /** @var array{id: string, object: string, created: int, model: string, choices: array, token_logprobs: array, top_logprobs: array|null, text_offset: array}|null, finish_reason: string}>, usage: array{prompt_tokens: int, completion_tokens: int, total_tokens: int}} $result */ - $result = $this->transporter->requestObject($payload); + /** @var array{id: string, object: string, created: int, model: string, choices: array, token_logprobs: array, top_logprobs: array|null, text_offset: array}|null, finish_reason: string}>, usage: array{prompt_tokens: int, completion_tokens: int, total_tokens: int}}|StreamContract $result */ + $result = $this->transporter->requestObject($payload, $parameters['stream'] ?? false); - return CreateResponse::from($result); + if (is_array($result)) { + return CreateResponse::from($result); + } + + return new Stream($result, fn (array $result): CreateResponse => CreateResponse::from($result)); } } diff --git a/src/Responses/Completions/CreateResponse.php b/src/Responses/Completions/CreateResponse.php index 882c6677..6695b14a 100644 --- a/src/Responses/Completions/CreateResponse.php +++ b/src/Responses/Completions/CreateResponse.php @@ -26,7 +26,7 @@ private function __construct( public readonly int $created, public readonly string $model, public readonly array $choices, - public readonly CreateResponseUsage $usage, + public readonly ?CreateResponseUsage $usage, ) { } @@ -47,7 +47,7 @@ public static function from(array $attributes): self $attributes['created'], $attributes['model'], $choices, - CreateResponseUsage::from($attributes['usage']) + isset($attributes['usage']) ? CreateResponseUsage::from($attributes['usage']) : null ); } @@ -65,7 +65,7 @@ public function toArray(): array static fn (CreateResponseChoice $result): array => $result->toArray(), $this->choices, ), - 'usage' => $this->usage->toArray(), + 'usage' => $this->usage?->toArray(), ]; } } diff --git a/src/Responses/Completions/CreateResponseChoice.php b/src/Responses/Completions/CreateResponseChoice.php index 6fa722fc..f7a84ab2 100644 --- a/src/Responses/Completions/CreateResponseChoice.php +++ b/src/Responses/Completions/CreateResponseChoice.php @@ -10,7 +10,7 @@ private function __construct( public readonly string $text, public readonly int $index, public readonly ?CreateResponseChoiceLogprobs $logprobs, - public readonly string $finishReason, + public readonly ?string $finishReason, ) { } diff --git a/src/Streams/EventStream.php b/src/Streams/EventStream.php new file mode 100644 index 00000000..b34b1fab --- /dev/null +++ b/src/Streams/EventStream.php @@ -0,0 +1,46 @@ +stream->eof()) { + $line = Utils::readLine($this->stream); + + if (! str_starts_with($line, 'data:')) { + continue; + } + + $rawData = substr(strstr($line, 'data: '), 6); + + if ($rawData === "[DONE]\n") { + break; + } + + try { + $data = json_decode($rawData, true, 512, JSON_THROW_ON_ERROR); + } catch (JsonException $jsonException) { + throw new UnserializableResponse($jsonException); + } + + if (isset($data['error'])) { + throw new ErrorException($data['error']); + } + + yield $data; + } + } +} diff --git a/src/Streams/Stream.php b/src/Streams/Stream.php new file mode 100644 index 00000000..8579fffb --- /dev/null +++ b/src/Streams/Stream.php @@ -0,0 +1,22 @@ +stream->read() as $data) { + yield ($this->callback)($data); + } + } +} diff --git a/src/Transporters/HttpTransporter.php b/src/Transporters/HttpTransporter.php index 4913804a..49b66e01 100644 --- a/src/Transporters/HttpTransporter.php +++ b/src/Transporters/HttpTransporter.php @@ -4,16 +4,18 @@ namespace OpenAI\Transporters; +use GuzzleHttp\ClientInterface; use JsonException; +use OpenAI\Contracts\Stream; use OpenAI\Contracts\Transporter; use OpenAI\Exceptions\ErrorException; use OpenAI\Exceptions\TransporterException; use OpenAI\Exceptions\UnserializableResponse; +use OpenAI\Streams\EventStream; use OpenAI\ValueObjects\Transporter\BaseUri; use OpenAI\ValueObjects\Transporter\Headers; use OpenAI\ValueObjects\Transporter\Payload; use Psr\Http\Client\ClientExceptionInterface; -use Psr\Http\Client\ClientInterface; /** * @internal @@ -34,21 +36,22 @@ public function __construct( /** * {@inheritDoc} */ - public function requestObject(Payload $payload): array + public function requestObject(Payload $payload, bool $stream = false): array|Stream { $request = $payload->toRequest($this->baseUri, $this->headers); try { - $response = $this->client->sendRequest($request); + $response = $this->client->send($request, ['stream' => $stream]); } catch (ClientExceptionInterface $clientException) { throw new TransporterException($clientException); } - $contents = $response->getBody()->getContents(); + if ($response->getHeaderLine('Content-Type') === 'text/event-stream') { + return new EventStream($response->getBody()); + } try { - /** @var array{error?: array{message: string, type: string, code: string}} $response */ - $response = json_decode($contents, true, 512, JSON_THROW_ON_ERROR); + $response = json_decode($response->getBody()->getContents(), true, 512, JSON_THROW_ON_ERROR); } catch (JsonException $jsonException) { throw new UnserializableResponse($jsonException); } From 497e5c79bec58d422148ad5e361a986a1a8b27a3 Mon Sep 17 00:00:00 2001 From: Slava Razum Date: Thu, 22 Dec 2022 03:25:06 +0200 Subject: [PATCH 3/7] Lint --- src/Streams/EventStream.php | 3 ++- tests/Resources/Completions.php | 3 +-- tests/Streams/EventStream.php | 4 ++-- tests/Transporters/HttpTransporter.php | 2 +- 4 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/Streams/EventStream.php b/src/Streams/EventStream.php index b34b1fab..1b3cbb9d 100644 --- a/src/Streams/EventStream.php +++ b/src/Streams/EventStream.php @@ -13,7 +13,8 @@ final class EventStream implements Stream { public function __construct( private readonly StreamInterface $stream, - ) {} + ) { + } public function read(): iterable { diff --git a/tests/Resources/Completions.php b/tests/Resources/Completions.php index a25f522d..bbdb50a8 100644 --- a/tests/Resources/Completions.php +++ b/tests/Resources/Completions.php @@ -45,7 +45,7 @@ 'model' => 'da-vince', 'prompt' => 'hi', 'stream' => true, - ], new EventStream(Utils::streamFor('data: ' . json_encode(completion())))); + ], new EventStream(Utils::streamFor('data: '.json_encode(completion())))); $result = $client->completions()->create([ 'model' => 'da-vince', @@ -72,4 +72,3 @@ ->logprobs->toBe(null) ->finishReason->toBe('length'); }); - diff --git a/tests/Streams/EventStream.php b/tests/Streams/EventStream.php index d116c316..dac8460c 100644 --- a/tests/Streams/EventStream.php +++ b/tests/Streams/EventStream.php @@ -24,13 +24,13 @@ }); test('stream message serialization error', function () { - $stream = new EventStream(Utils::streamFor("data: invalid")); + $stream = new EventStream(Utils::streamFor('data: invalid')); iterator_to_array($stream->read()); })->throws(UnserializableResponse::class); test('stream message error', function () { - $stream = new EventStream(Utils::streamFor('data: ' . json_encode(['error' => [ + $stream = new EventStream(Utils::streamFor('data: '.json_encode(['error' => [ 'message' => 'Something went wrong.', 'type' => 'invalid_request_error', 'param' => null, diff --git a/tests/Transporters/HttpTransporter.php b/tests/Transporters/HttpTransporter.php index 83243435..852f15cc 100644 --- a/tests/Transporters/HttpTransporter.php +++ b/tests/Transporters/HttpTransporter.php @@ -83,7 +83,7 @@ $payload = Payload::create('completions', ['stream' => true]); $response = new Response(200, ['Content-Type' => 'text/event-stream'], json_encode([ - 'qdwq' + 'qdwq', ])); $this->client From a1f2d8ad42976199595c69508e3d90f3fd32484f Mon Sep 17 00:00:00 2001 From: Slava Razum Date: Thu, 22 Dec 2022 03:32:21 +0200 Subject: [PATCH 4/7] Stream response based on Generator --- src/Contracts/Transporter.php | 3 ++- src/Resources/Completions.php | 16 +++++++++++----- src/Streams/Stream.php | 22 ---------------------- src/Transporters/HttpTransporter.php | 6 +++--- tests/Pest.php | 3 +-- tests/Resources/Completions.php | 6 +++--- tests/Streams/Stream.php | 21 --------------------- tests/Transporters/HttpTransporter.php | 3 +-- 8 files changed, 21 insertions(+), 59 deletions(-) delete mode 100644 src/Streams/Stream.php delete mode 100644 tests/Streams/Stream.php diff --git a/src/Contracts/Transporter.php b/src/Contracts/Transporter.php index 40b6f717..3a378b13 100644 --- a/src/Contracts/Transporter.php +++ b/src/Contracts/Transporter.php @@ -4,6 +4,7 @@ namespace OpenAI\Contracts; +use Generator; use OpenAI\Exceptions\ErrorException; use OpenAI\Exceptions\TransporterException; use OpenAI\Exceptions\UnserializableResponse; @@ -21,7 +22,7 @@ interface Transporter * * @throws ErrorException|UnserializableResponse|TransporterException */ - public function requestObject(Payload $payload, bool $stream = false): array|Stream; + public function requestObject(Payload $payload, bool $stream = false): array|Generator; /** * Sends a content request to a server. diff --git a/src/Resources/Completions.php b/src/Resources/Completions.php index f8da3791..3d0aaa6a 100644 --- a/src/Resources/Completions.php +++ b/src/Resources/Completions.php @@ -4,9 +4,8 @@ namespace OpenAI\Resources; -use OpenAI\Contracts\Stream as StreamContract; +use Generator; use OpenAI\Responses\Completions\CreateResponse; -use OpenAI\Streams\Stream; use OpenAI\ValueObjects\Transporter\Payload; final class Completions @@ -20,17 +19,24 @@ final class Completions * * @param array $parameters */ - public function create(array $parameters): CreateResponse|StreamContract + public function create(array $parameters): CreateResponse|Generator { $payload = Payload::create('completions', $parameters); - /** @var array{id: string, object: string, created: int, model: string, choices: array, token_logprobs: array, top_logprobs: array|null, text_offset: array}|null, finish_reason: string}>, usage: array{prompt_tokens: int, completion_tokens: int, total_tokens: int}}|StreamContract $result */ + /** @var array{id: string, object: string, created: int, model: string, choices: array, token_logprobs: array, top_logprobs: array|null, text_offset: array}|null, finish_reason: string}>, usage: array{prompt_tokens: int, completion_tokens: int, total_tokens: int}}|Generator $result */ $result = $this->transporter->requestObject($payload, $parameters['stream'] ?? false); if (is_array($result)) { return CreateResponse::from($result); } - return new Stream($result, fn (array $result): CreateResponse => CreateResponse::from($result)); + return $this->stream($result); + } + + private function stream(Generator $stream): Generator + { + foreach ($stream as $data) { + yield CreateResponse::from($data); + } } } diff --git a/src/Streams/Stream.php b/src/Streams/Stream.php deleted file mode 100644 index 8579fffb..00000000 --- a/src/Streams/Stream.php +++ /dev/null @@ -1,22 +0,0 @@ -stream->read() as $data) { - yield ($this->callback)($data); - } - } -} diff --git a/src/Transporters/HttpTransporter.php b/src/Transporters/HttpTransporter.php index 49b66e01..7131f25e 100644 --- a/src/Transporters/HttpTransporter.php +++ b/src/Transporters/HttpTransporter.php @@ -4,9 +4,9 @@ namespace OpenAI\Transporters; +use Generator; use GuzzleHttp\ClientInterface; use JsonException; -use OpenAI\Contracts\Stream; use OpenAI\Contracts\Transporter; use OpenAI\Exceptions\ErrorException; use OpenAI\Exceptions\TransporterException; @@ -36,7 +36,7 @@ public function __construct( /** * {@inheritDoc} */ - public function requestObject(Payload $payload, bool $stream = false): array|Stream + public function requestObject(Payload $payload, bool $stream = false): array|Generator { $request = $payload->toRequest($this->baseUri, $this->headers); @@ -47,7 +47,7 @@ public function requestObject(Payload $payload, bool $stream = false): array|Str } if ($response->getHeaderLine('Content-Type') === 'text/event-stream') { - return new EventStream($response->getBody()); + return (new EventStream($response->getBody()))->read(); } try { diff --git a/tests/Pest.php b/tests/Pest.php index 12e4ebb6..1c3b2dca 100644 --- a/tests/Pest.php +++ b/tests/Pest.php @@ -1,14 +1,13 @@ 'da-vince', 'prompt' => 'hi', 'stream' => true, - ], new EventStream(Utils::streamFor('data: '.json_encode(completion())))); + ], (new EventStream(Utils::streamFor('data: '.json_encode(completion()))))->read()); $result = $client->completions()->create([ 'model' => 'da-vince', @@ -53,9 +53,9 @@ 'stream' => true, ]); - expect($result)->toBeInstanceOf(Stream::class); + expect($result)->toBeInstanceOf(Generator::class); - $data = iterator_to_array($result->read())[0]; + $data = iterator_to_array($result)[0]; expect($data) ->id->toBe('cmpl-5uS6a68SwurhqAqLBpZtibIITICna') diff --git a/tests/Streams/Stream.php b/tests/Streams/Stream.php deleted file mode 100644 index 3ee26e7e..00000000 --- a/tests/Streams/Stream.php +++ /dev/null @@ -1,21 +0,0 @@ -read()))->toBe([ - ['text' => 'Hey ', 'test' => true], - ['text' => 'there!', 'test' => true], - ]); -}); diff --git a/tests/Transporters/HttpTransporter.php b/tests/Transporters/HttpTransporter.php index 852f15cc..144b68ab 100644 --- a/tests/Transporters/HttpTransporter.php +++ b/tests/Transporters/HttpTransporter.php @@ -4,7 +4,6 @@ use GuzzleHttp\Exception\ConnectException; use GuzzleHttp\Psr7\Request as Psr7Request; use GuzzleHttp\Psr7\Response; -use OpenAI\Contracts\Stream; use OpenAI\Enums\Transporter\ContentType; use OpenAI\Exceptions\ErrorException; use OpenAI\Exceptions\TransporterException; @@ -93,7 +92,7 @@ $response = $this->http->requestObject($payload, true); - expect($response)->toBeInstanceOf(Stream::class); + expect($response)->toBeInstanceOf(Generator::class); }); test('request object server errors', function () { From 3e3d94bda84491f5ef200112bcfc762bf3d7c435 Mon Sep 17 00:00:00 2001 From: Slava Razum Date: Thu, 12 Jan 2023 17:48:45 +0200 Subject: [PATCH 5/7] - Completions\CreateResponse now allows to iterate over the stream - New Transporter Response object which works with raw or stream responses --- src/Contracts/Stream.php | 2 +- src/Contracts/Transporter.php | 6 +- src/Resources/Completions.php | 18 ++--- src/Resources/Edits.php | 2 +- src/Resources/Embeddings.php | 2 +- src/Resources/Files.php | 8 +-- src/Resources/FineTunes.php | 10 +-- src/Resources/Images.php | 6 +- src/Resources/Models.php | 6 +- src/Resources/Moderations.php | 2 +- src/Responses/Completions/CreateResponse.php | 51 +++++++++++++- src/Streams/EventStream.php | 20 ++---- src/Transporters/HttpTransporter.php | 31 +++------ src/ValueObjects/Transporter/Payload.php | 8 +++ src/ValueObjects/Transporter/Response.php | 71 ++++++++++++++++++++ tests/Pest.php | 27 +++++++- tests/Resources/Completions.php | 6 +- tests/Streams/EventStream.php | 29 +++----- tests/Transporters/HttpTransporter.php | 11 +-- 19 files changed, 211 insertions(+), 105 deletions(-) create mode 100644 src/ValueObjects/Transporter/Response.php diff --git a/src/Contracts/Stream.php b/src/Contracts/Stream.php index f3a12290..cdab4125 100644 --- a/src/Contracts/Stream.php +++ b/src/Contracts/Stream.php @@ -7,7 +7,7 @@ interface Stream /** * Iterates over the event-stream data. * - * @return iterable> + * @return iterable> */ public function read(): iterable; } diff --git a/src/Contracts/Transporter.php b/src/Contracts/Transporter.php index 3a378b13..816cc554 100644 --- a/src/Contracts/Transporter.php +++ b/src/Contracts/Transporter.php @@ -4,11 +4,11 @@ namespace OpenAI\Contracts; -use Generator; use OpenAI\Exceptions\ErrorException; use OpenAI\Exceptions\TransporterException; use OpenAI\Exceptions\UnserializableResponse; use OpenAI\ValueObjects\Transporter\Payload; +use OpenAI\ValueObjects\Transporter\Response; /** * @internal @@ -17,12 +17,10 @@ interface Transporter { /** * Sends a request to a server. - ** - * @return array * * @throws ErrorException|UnserializableResponse|TransporterException */ - public function requestObject(Payload $payload, bool $stream = false): array|Generator; + public function requestObject(Payload $payload): Response; /** * Sends a content request to a server. diff --git a/src/Resources/Completions.php b/src/Resources/Completions.php index 3d0aaa6a..12175cfe 100644 --- a/src/Resources/Completions.php +++ b/src/Resources/Completions.php @@ -19,24 +19,16 @@ final class Completions * * @param array $parameters */ - public function create(array $parameters): CreateResponse|Generator + public function create(array $parameters): CreateResponse { $payload = Payload::create('completions', $parameters); - /** @var array{id: string, object: string, created: int, model: string, choices: array, token_logprobs: array, top_logprobs: array|null, text_offset: array}|null, finish_reason: string}>, usage: array{prompt_tokens: int, completion_tokens: int, total_tokens: int}}|Generator $result */ - $result = $this->transporter->requestObject($payload, $parameters['stream'] ?? false); + $response = $this->transporter->requestObject($payload); - if (is_array($result)) { - return CreateResponse::from($result); + if ($response->isStream()) { + return CreateResponse::fromStream($response->stream()); } - return $this->stream($result); - } - - private function stream(Generator $stream): Generator - { - foreach ($stream as $data) { - yield CreateResponse::from($data); - } + return CreateResponse::from($response->object()); } } diff --git a/src/Resources/Edits.php b/src/Resources/Edits.php index b8d4340b..0a653972 100644 --- a/src/Resources/Edits.php +++ b/src/Resources/Edits.php @@ -23,7 +23,7 @@ public function create(array $parameters): CreateResponse $payload = Payload::create('edits', $parameters); /** @var array{object: string, created: int, choices: array, usage: array{prompt_tokens: int, completion_tokens: int, total_tokens: int}} $result */ - $result = $this->transporter->requestObject($payload); + $result = $this->transporter->requestObject($payload)->object(); return CreateResponse::from($result); } diff --git a/src/Resources/Embeddings.php b/src/Resources/Embeddings.php index c2709d21..28e500ca 100644 --- a/src/Resources/Embeddings.php +++ b/src/Resources/Embeddings.php @@ -23,7 +23,7 @@ public function create(array $parameters): CreateResponse $payload = Payload::create('embeddings', $parameters); /** @var array{object: string, data: array, index: int}>, usage: array{prompt_tokens: int, total_tokens: int}} $result */ - $result = $this->transporter->requestObject($payload); + $result = $this->transporter->requestObject($payload)->object(); return CreateResponse::from($result); } diff --git a/src/Resources/Files.php b/src/Resources/Files.php index 18964c25..43791d1b 100644 --- a/src/Resources/Files.php +++ b/src/Resources/Files.php @@ -24,7 +24,7 @@ public function list(): ListResponse $payload = Payload::list('files'); /** @var array{object: string, data: array|null}>} $result */ - $result = $this->transporter->requestObject($payload); + $result = $this->transporter->requestObject($payload)->object(); return ListResponse::from($result); } @@ -39,7 +39,7 @@ public function retrieve(string $file): RetrieveResponse $payload = Payload::retrieve('files', $file); /** @var array{id: string, object: string, created_at: int, bytes: int, filename: string, purpose: string, status: string, status_details: array|null} $result */ - $result = $this->transporter->requestObject($payload); + $result = $this->transporter->requestObject($payload)->object(); return RetrieveResponse::from($result); } @@ -68,7 +68,7 @@ public function upload(array $parameters): CreateResponse $payload = Payload::upload('files', $parameters); /** @var array{id: string, object: string, created_at: int, bytes: int, filename: string, purpose: string, status: string, status_details: array|null} $result */ - $result = $this->transporter->requestObject($payload); + $result = $this->transporter->requestObject($payload)->object(); return CreateResponse::from($result); } @@ -83,7 +83,7 @@ public function delete(string $file): DeleteResponse $payload = Payload::delete('files', $file); /** @var array{id: string, object: string, deleted: bool} $result */ - $result = $this->transporter->requestObject($payload); + $result = $this->transporter->requestObject($payload)->object(); return DeleteResponse::from($result); } diff --git a/src/Resources/FineTunes.php b/src/Resources/FineTunes.php index f4e81203..4bdbac7a 100644 --- a/src/Resources/FineTunes.php +++ b/src/Resources/FineTunes.php @@ -27,7 +27,7 @@ public function create(array $parameters): RetrieveResponse $payload = Payload::create('fine-tunes', $parameters); /** @var array{id: string, object: string, model: string, created_at: int, events: array, fine_tuned_model: ?string, hyperparams: array{batch_size: ?int, learning_rate_multiplier: ?float, n_epochs: int, prompt_loss_weight: float}, organization_id: string, result_files: array|null}>, status: string, validation_files: array|null}>, training_files: array|null}>, updated_at: int} $result */ - $result = $this->transporter->requestObject($payload); + $result = $this->transporter->requestObject($payload)->object(); return RetrieveResponse::from($result); } @@ -42,7 +42,7 @@ public function list(): ListResponse $payload = Payload::list('fine-tunes'); /** @var array{object: string, data: array, fine_tuned_model: ?string, hyperparams: array{batch_size: ?int, learning_rate_multiplier: ?float, n_epochs: int, prompt_loss_weight: float}, organization_id: string, result_files: array|null}>, status: string, validation_files: array|null}>, training_files: array|null}>, updated_at: int}>} $result */ - $result = $this->transporter->requestObject($payload); + $result = $this->transporter->requestObject($payload)->object(); return ListResponse::from($result); } @@ -57,7 +57,7 @@ public function retrieve(string $fineTuneId): RetrieveResponse $payload = Payload::retrieve('fine-tunes', $fineTuneId); /** @var array{id: string, object: string, model: string, created_at: int, events: array, fine_tuned_model: ?string, hyperparams: array{batch_size: ?int, learning_rate_multiplier: ?float, n_epochs: int, prompt_loss_weight: float}, organization_id: string, result_files: array|null}>, status: string, validation_files: array|null}>, training_files: array|null}>, updated_at: int} $result */ - $result = $this->transporter->requestObject($payload); + $result = $this->transporter->requestObject($payload)->object(); return RetrieveResponse::from($result); } @@ -72,7 +72,7 @@ public function cancel(string $fineTuneId): RetrieveResponse $payload = Payload::cancel('fine-tunes', $fineTuneId); /** @var array{id: string, object: string, model: string, created_at: int, events: array, fine_tuned_model: ?string, hyperparams: array{batch_size: ?int, learning_rate_multiplier: ?float, n_epochs: int, prompt_loss_weight: float}, organization_id: string, result_files: array|null}>, status: string, validation_files: array|null}>, training_files: array|null}>, updated_at: int} $result */ - $result = $this->transporter->requestObject($payload); + $result = $this->transporter->requestObject($payload)->object(); return RetrieveResponse::from($result); } @@ -87,7 +87,7 @@ public function listEvents(string $fineTuneId): ListEventsResponse $payload = Payload::retrieve('fine-tunes', $fineTuneId, '/events'); /** @var array{object: string, data: array} $result */ - $result = $this->transporter->requestObject($payload); + $result = $this->transporter->requestObject($payload)->object(); return ListEventsResponse::from($result); } diff --git a/src/Resources/Images.php b/src/Resources/Images.php index dabff0ba..418107a4 100644 --- a/src/Resources/Images.php +++ b/src/Resources/Images.php @@ -25,7 +25,7 @@ public function create(array $parameters): CreateResponse $payload = Payload::create('images/generations', $parameters); /** @var array{created: int, data: array} $result */ - $result = $this->transporter->requestObject($payload); + $result = $this->transporter->requestObject($payload)->object(); return CreateResponse::from($result); } @@ -42,7 +42,7 @@ public function edit(array $parameters): EditResponse $payload = Payload::upload('images/edits', $parameters); /** @var array{created: int, data: array} $result */ - $result = $this->transporter->requestObject($payload); + $result = $this->transporter->requestObject($payload)->object(); return EditResponse::from($result); } @@ -59,7 +59,7 @@ public function variation(array $parameters): VariationResponse $payload = Payload::upload('images/variations', $parameters); /** @var array{created: int, data: array} $result */ - $result = $this->transporter->requestObject($payload); + $result = $this->transporter->requestObject($payload)->object(); return VariationResponse::from($result); } diff --git a/src/Resources/Models.php b/src/Resources/Models.php index cd87fe73..b642b14a 100644 --- a/src/Resources/Models.php +++ b/src/Resources/Models.php @@ -23,7 +23,7 @@ public function list(): ListResponse $payload = Payload::list('models'); /** @var array{object: string, data: array, root: string, parent: ?string}>} $result */ - $result = $this->transporter->requestObject($payload); + $result = $this->transporter->requestObject($payload)->object(); return ListResponse::from($result); } @@ -38,7 +38,7 @@ public function retrieve(string $model): RetrieveResponse $payload = Payload::retrieve('models', $model); /** @var array{id: string, object: string, created: int, owned_by: string, permission: array, root: string, parent: ?string} $result */ - $result = $this->transporter->requestObject($payload); + $result = $this->transporter->requestObject($payload)->object(); return RetrieveResponse::from($result); } @@ -53,7 +53,7 @@ public function delete(string $model): DeleteResponse $payload = Payload::delete('models', $model); /** @var array{id: string, object: string, deleted: bool} $result */ - $result = $this->transporter->requestObject($payload); + $result = $this->transporter->requestObject($payload)->object(); return DeleteResponse::from($result); } diff --git a/src/Resources/Moderations.php b/src/Resources/Moderations.php index 14823b3f..c4cab140 100644 --- a/src/Resources/Moderations.php +++ b/src/Resources/Moderations.php @@ -23,7 +23,7 @@ public function create(array $parameters): CreateResponse $payload = Payload::create('moderations', $parameters); /** @var array{id: string, model: string, results: array, category_scores: array, flagged: bool}>} $result */ - $result = $this->transporter->requestObject($payload); + $result = $this->transporter->requestObject($payload)->object(); return CreateResponse::from($result); } diff --git a/src/Responses/Completions/CreateResponse.php b/src/Responses/Completions/CreateResponse.php index 6695b14a..d9c6304b 100644 --- a/src/Responses/Completions/CreateResponse.php +++ b/src/Responses/Completions/CreateResponse.php @@ -4,19 +4,23 @@ namespace OpenAI\Responses\Completions; +use Generator; +use Iterator; use OpenAI\Contracts\Response; use OpenAI\Responses\Concerns\ArrayAccessible; /** * @implements Response, token_logprobs: array, top_logprobs: array|null, text_offset: array}|null, finish_reason: string}>, usage: array{prompt_tokens: int, completion_tokens: int|null, total_tokens: int}}> */ -final class CreateResponse implements Response +final class CreateResponse implements Response, Iterator { /** * @use ArrayAccessible, token_logprobs: array, top_logprobs: array|null, text_offset: array}|null, finish_reason: string}>, usage: array{prompt_tokens: int, completion_tokens: int|null, total_tokens: int}}> */ use ArrayAccessible; + private Generator $stream; + /** * @param array $choices */ @@ -51,6 +55,14 @@ public static function from(array $attributes): self ); } + public static function fromStream(Generator $stream): self + { + /** @var array{id: string, object: string, created: int, model: string, choices: array, token_logprobs: array, top_logprobs: array|null, text_offset: array}|null, finish_reason: string}>, usage: array{prompt_tokens: int, completion_tokens: int|null, total_tokens: int}} $attributes */ + $attributes = $stream->current(); + + return self::from($attributes)->withStream($stream); + } + /** * {@inheritDoc} */ @@ -68,4 +80,41 @@ public function toArray(): array 'usage' => $this->usage?->toArray(), ]; } + + public function current(): mixed + { + return self::from($this->stream->current()); + } + + public function next(): void + { + $this->stream->next(); + } + + public function key(): mixed + { + return $this->stream->key(); + } + + public function valid(): bool + { + return $this->stream->valid(); + } + + public function rewind(): void + { + $this->stream->rewind(); + } + + public function withStream(Generator $stream): self + { + $this->stream = $stream; + + return $this; + } + + public function isStream(): bool + { + return isset($this->stream); + } } diff --git a/src/Streams/EventStream.php b/src/Streams/EventStream.php index 1b3cbb9d..e3eefb77 100644 --- a/src/Streams/EventStream.php +++ b/src/Streams/EventStream.php @@ -2,11 +2,9 @@ namespace OpenAI\Streams; +use Generator; use GuzzleHttp\Psr7\Utils; -use JsonException; use OpenAI\Contracts\Stream; -use OpenAI\Exceptions\ErrorException; -use OpenAI\Exceptions\UnserializableResponse; use Psr\Http\Message\StreamInterface; final class EventStream implements Stream @@ -16,7 +14,7 @@ public function __construct( ) { } - public function read(): iterable + public function read(): Generator { while (! $this->stream->eof()) { $line = Utils::readLine($this->stream); @@ -25,22 +23,12 @@ public function read(): iterable continue; } - $rawData = substr(strstr($line, 'data: '), 6); + $data = trim(substr(strstr($line, 'data: '), 6)); - if ($rawData === "[DONE]\n") { + if ($data === "[DONE]") { break; } - try { - $data = json_decode($rawData, true, 512, JSON_THROW_ON_ERROR); - } catch (JsonException $jsonException) { - throw new UnserializableResponse($jsonException); - } - - if (isset($data['error'])) { - throw new ErrorException($data['error']); - } - yield $data; } } diff --git a/src/Transporters/HttpTransporter.php b/src/Transporters/HttpTransporter.php index 7131f25e..cb542fd1 100644 --- a/src/Transporters/HttpTransporter.php +++ b/src/Transporters/HttpTransporter.php @@ -4,17 +4,15 @@ namespace OpenAI\Transporters; -use Generator; use GuzzleHttp\ClientInterface; use JsonException; use OpenAI\Contracts\Transporter; use OpenAI\Exceptions\ErrorException; use OpenAI\Exceptions\TransporterException; -use OpenAI\Exceptions\UnserializableResponse; -use OpenAI\Streams\EventStream; use OpenAI\ValueObjects\Transporter\BaseUri; use OpenAI\ValueObjects\Transporter\Headers; use OpenAI\ValueObjects\Transporter\Payload; +use OpenAI\ValueObjects\Transporter\Response; use Psr\Http\Client\ClientExceptionInterface; /** @@ -36,31 +34,18 @@ public function __construct( /** * {@inheritDoc} */ - public function requestObject(Payload $payload, bool $stream = false): array|Generator + public function requestObject(Payload $payload): Response { - $request = $payload->toRequest($this->baseUri, $this->headers); - try { - $response = $this->client->send($request, ['stream' => $stream]); + return new Response( + $this->client->send( + $payload->toRequest($this->baseUri, $this->headers), + ['stream' => $payload->isStream()] + ) + ); } catch (ClientExceptionInterface $clientException) { throw new TransporterException($clientException); } - - if ($response->getHeaderLine('Content-Type') === 'text/event-stream') { - return (new EventStream($response->getBody()))->read(); - } - - try { - $response = json_decode($response->getBody()->getContents(), true, 512, JSON_THROW_ON_ERROR); - } catch (JsonException $jsonException) { - throw new UnserializableResponse($jsonException); - } - - if (isset($response['error'])) { - throw new ErrorException($response['error']); - } - - return $response; } /** diff --git a/src/ValueObjects/Transporter/Payload.php b/src/ValueObjects/Transporter/Payload.php index 298b20b3..9ae7c99b 100644 --- a/src/ValueObjects/Transporter/Payload.php +++ b/src/ValueObjects/Transporter/Payload.php @@ -30,6 +30,14 @@ private function __construct( // .. } + /** + * Determines whether the stream response was requested. + */ + public function isStream(): bool + { + return $this->parameters['stream'] ?? false; + } + /** * Creates a new Payload value object from the given parameters. */ diff --git a/src/ValueObjects/Transporter/Response.php b/src/ValueObjects/Transporter/Response.php new file mode 100644 index 00000000..60f8bb40 --- /dev/null +++ b/src/ValueObjects/Transporter/Response.php @@ -0,0 +1,71 @@ +response->getHeaderLine('Content-Type') === 'text/event-stream'; + } + + /** + * Returns decoded response object. + * + * @return array + */ + public function object(): array + { + return $this->decode($this->response->getBody()->getContents()); + } + + /** + * Returns stream generator with decoded response objects. + * + * @return Generator> + */ + public function stream(): Generator + { + foreach ((new EventStream($this->response->getBody()))->read() as $data) { + yield $this->decode($data); + } + } + + /** + * Decode raw content to json. + * + * @return array + */ + private function decode(string $contents): array + { + $result = []; + + try { + $result = json_decode($contents, true, 512, JSON_THROW_ON_ERROR); + } catch (JsonException $jsonException) { + throw new UnserializableResponse($jsonException); + } + + if (isset($result['error'])) { + throw new ErrorException($result['error']); + } + + return $result; + } +} diff --git a/tests/Pest.php b/tests/Pest.php index 1c3b2dca..5035e0fa 100644 --- a/tests/Pest.php +++ b/tests/Pest.php @@ -1,15 +1,40 @@ shouldReceive('getBody')->andReturn(Utils::streamFor( + is_array($response) ? json_encode($response) : $response + )); + $responseMock->shouldReceive('getHeaderLine')->andReturn('application/json'); + } elseif ($response instanceof Generator) { + $responseMock->shouldReceive('getBody')->andReturn(Utils::streamFor( + 'data: '.implode("\n\ndata: ", array_map( + fn ($value) => json_encode($value), + iterator_to_array($response) + ) + ))); + + $responseMock->shouldReceive('getHeaderLine')->andReturn('text/event-stream'); + } + + } $transporter ->shouldReceive($methodName) @@ -22,7 +47,7 @@ function mockClient(string $method, string $resource, array $params, array|strin return $request->getMethod() === $method && $request->getUri()->getPath() === "/v1/$resource"; - })->andReturn($response); + })->andReturn($responseMock ? new Response($responseMock) : $response); return new Client($transporter); } diff --git a/tests/Resources/Completions.php b/tests/Resources/Completions.php index 57a9a22b..3bfc3b12 100644 --- a/tests/Resources/Completions.php +++ b/tests/Resources/Completions.php @@ -45,7 +45,7 @@ 'model' => 'da-vince', 'prompt' => 'hi', 'stream' => true, - ], (new EventStream(Utils::streamFor('data: '.json_encode(completion()))))->read()); + ], (fn () => yield completion())()); $result = $client->completions()->create([ 'model' => 'da-vince', @@ -53,10 +53,12 @@ 'stream' => true, ]); - expect($result)->toBeInstanceOf(Generator::class); + expect($result)->toBeInstanceOf(CreateResponse::class); $data = iterator_to_array($result)[0]; + ray($result); + expect($data) ->id->toBe('cmpl-5uS6a68SwurhqAqLBpZtibIITICna') ->object->toBe('text_completion') diff --git a/tests/Streams/EventStream.php b/tests/Streams/EventStream.php index dac8460c..0f1fd9ac 100644 --- a/tests/Streams/EventStream.php +++ b/tests/Streams/EventStream.php @@ -6,36 +6,23 @@ use OpenAI\Streams\EventStream; test('read data stream', function () { - $stream = new EventStream(Utils::streamFor('data: {"text": "Hey!"}')); + $jsonString = '{"text": "Hey!"}'; - expect(iterator_to_array($stream->read())[0])->toBe(['text' => 'Hey!']); + $stream = new EventStream(Utils::streamFor("data: $jsonString")); + + $arr = iterator_to_array($stream->read()); + + expect($arr[0])->toBe($jsonString); }); test('skips empty stream lines', function () { $stream = new EventStream(Utils::streamFor("data: {\"text\": \"Hey\"}\n\ndata: {\"text\": \"there!\"}")); - expect(iterator_to_array($stream->read()))->toBe([['text' => 'Hey'], ['text' => 'there!']]); + expect(iterator_to_array($stream->read()))->toBe(['{"text": "Hey"}', '{"text": "there!"}']); }); test('aborts after done message', function () { $stream = new EventStream(Utils::streamFor("data: {\"text\": \"Hey\"}\ndata: [DONE]\ndata: {\"text\": \"there!\"}")); - expect(iterator_to_array($stream->read()))->toBe([['text' => 'Hey']]); + expect(iterator_to_array($stream->read()))->toBe(['{"text": "Hey"}']); }); - -test('stream message serialization error', function () { - $stream = new EventStream(Utils::streamFor('data: invalid')); - - iterator_to_array($stream->read()); -})->throws(UnserializableResponse::class); - -test('stream message error', function () { - $stream = new EventStream(Utils::streamFor('data: '.json_encode(['error' => [ - 'message' => 'Something went wrong.', - 'type' => 'invalid_request_error', - 'param' => null, - 'code' => 'invalid_request_error', - ]]))); - - iterator_to_array($stream->read()); -})->throws(ErrorException::class, 'Something went wrong'); diff --git a/tests/Transporters/HttpTransporter.php b/tests/Transporters/HttpTransporter.php index 144b68ab..29955f90 100644 --- a/tests/Transporters/HttpTransporter.php +++ b/tests/Transporters/HttpTransporter.php @@ -13,6 +13,7 @@ use OpenAI\ValueObjects\Transporter\BaseUri; use OpenAI\ValueObjects\Transporter\Headers; use OpenAI\ValueObjects\Transporter\Payload; +use OpenAI\ValueObjects\Transporter\Response as TransporterResponse; beforeEach(function () { $this->client = Mockery::mock(ClientInterface::class); @@ -68,7 +69,7 @@ $response = $this->http->requestObject($payload); - expect($response)->toBe([ + expect($response->object())->toBe([ [ 'text' => 'Hey!', 'index' => 0, @@ -90,9 +91,9 @@ ->once() ->andReturn($response); - $response = $this->http->requestObject($payload, true); + $response = $this->http->requestObject($payload); - expect($response)->toBeInstanceOf(Generator::class); + expect($response)->toBeInstanceOf(TransporterResponse::class); }); test('request object server errors', function () { @@ -112,7 +113,7 @@ ->once() ->andReturn($response); - expect(fn () => $this->http->requestObject($payload)) + expect(fn () => $this->http->requestObject($payload)->object()) ->toThrow(function (ErrorException $e) { expect($e->getMessage())->toBe('Incorrect API key provided: foo. You can find your API key at https://beta.openai.com.') ->and($e->getErrorMessage())->toBe('Incorrect API key provided: foo. You can find your API key at https://beta.openai.com.') @@ -149,7 +150,7 @@ ->once() ->andReturn($response); - $this->http->requestObject($payload); + $this->http->requestObject($payload)->object(); })->throws(UnserializableResponse::class, 'Syntax error'); test('request content', function () { From e5b8529774cff6768029945091e2a2661bf74045 Mon Sep 17 00:00:00 2001 From: Slava Razum Date: Fri, 13 Jan 2023 02:26:25 +0200 Subject: [PATCH 6/7] Quality improvements --- src/Contracts/Stream.php | 2 +- src/Resources/Completions.php | 11 +++++++--- src/Responses/Completions/CreateResponse.php | 22 ++++++++++++++----- .../Completions/CreateResponseChoice.php | 4 ++-- src/Responses/Edits/CreateResponse.php | 2 +- src/Streams/EventStream.php | 9 ++++++-- src/Transporters/HttpTransporter.php | 2 +- src/ValueObjects/Transporter/Payload.php | 2 +- src/ValueObjects/Transporter/Response.php | 12 +++++----- tests/Pest.php | 5 ++--- tests/Resources/Completions.php | 3 --- tests/Streams/EventStream.php | 2 -- tests/Transporters/HttpTransporter.php | 8 +++---- 13 files changed, 50 insertions(+), 34 deletions(-) diff --git a/src/Contracts/Stream.php b/src/Contracts/Stream.php index cdab4125..60d2aa05 100644 --- a/src/Contracts/Stream.php +++ b/src/Contracts/Stream.php @@ -7,7 +7,7 @@ interface Stream /** * Iterates over the event-stream data. * - * @return iterable> + * @return iterable */ public function read(): iterable; } diff --git a/src/Resources/Completions.php b/src/Resources/Completions.php index 12175cfe..cff99cb7 100644 --- a/src/Resources/Completions.php +++ b/src/Resources/Completions.php @@ -4,7 +4,6 @@ namespace OpenAI\Resources; -use Generator; use OpenAI\Responses\Completions\CreateResponse; use OpenAI\ValueObjects\Transporter\Payload; @@ -26,9 +25,15 @@ public function create(array $parameters): CreateResponse $response = $this->transporter->requestObject($payload); if ($response->isStream()) { - return CreateResponse::fromStream($response->stream()); + /** @var \Generator, token_logprobs: array, top_logprobs: array|null, text_offset: array}|null, finish_reason: string|null}>, usage: array{prompt_tokens: int, completion_tokens: int|null, total_tokens: int}}> $stream */ + $stream = $response->stream(); + + return CreateResponse::fromStream($stream); } - return CreateResponse::from($response->object()); + /** @var array{id: string, object: string, created: int, model: string, choices: array, token_logprobs: array, top_logprobs: array|null, text_offset: array}|null, finish_reason: string}>, usage: array{prompt_tokens: int, completion_tokens: int|null, total_tokens: int}|null} $object */ + $object = $response->object(); + + return CreateResponse::from($object); } } diff --git a/src/Responses/Completions/CreateResponse.php b/src/Responses/Completions/CreateResponse.php index d9c6304b..d1a2430c 100644 --- a/src/Responses/Completions/CreateResponse.php +++ b/src/Responses/Completions/CreateResponse.php @@ -10,15 +10,19 @@ use OpenAI\Responses\Concerns\ArrayAccessible; /** - * @implements Response, token_logprobs: array, top_logprobs: array|null, text_offset: array}|null, finish_reason: string}>, usage: array{prompt_tokens: int, completion_tokens: int|null, total_tokens: int}}> + * @implements Iterator, token_logprobs: array, top_logprobs: array|null, text_offset: array}|null, finish_reason: string|null}>, usage: array{prompt_tokens: int, completion_tokens: int|null, total_tokens: int}|null}>> + * @implements Response, token_logprobs: array, top_logprobs: array|null, text_offset: array}|null, finish_reason: string|null}>, usage: array{prompt_tokens: int, completion_tokens: int|null, total_tokens: int}|null}> */ final class CreateResponse implements Response, Iterator { /** - * @use ArrayAccessible, token_logprobs: array, top_logprobs: array|null, text_offset: array}|null, finish_reason: string}>, usage: array{prompt_tokens: int, completion_tokens: int|null, total_tokens: int}}> + * @use ArrayAccessible, token_logprobs: array, top_logprobs: array|null, text_offset: array}|null, finish_reason: string}>, usage: array{prompt_tokens: int, completion_tokens: int|null, total_tokens: int}|null}> */ use ArrayAccessible; + /** + * @var Generator, token_logprobs: array, top_logprobs: array|null, text_offset: array}|null, finish_reason: string|null}>, usage: array{prompt_tokens: int, completion_tokens: int|null, total_tokens: int}|null}> + */ private Generator $stream; /** @@ -37,7 +41,7 @@ private function __construct( /** * Acts as static factory, and returns a new Response instance. * - * @param array{id: string, object: string, created: int, model: string, choices: array, token_logprobs: array, top_logprobs: array|null, text_offset: array}|null, finish_reason: string}>, usage: array{prompt_tokens: int, completion_tokens: int|null, total_tokens: int}} $attributes + * @param array{id: string, object: string, created: int, model: string, choices: array, token_logprobs: array, top_logprobs: array|null, text_offset: array}|null, finish_reason: string|null}>, usage: array{prompt_tokens: int, completion_tokens: int|null, total_tokens: int}|null} $attributes */ public static function from(array $attributes): self { @@ -55,9 +59,12 @@ public static function from(array $attributes): self ); } + /** + * @param Generator, token_logprobs: array, top_logprobs: array|null, text_offset: array}|null, finish_reason: string|null}>, usage: array{prompt_tokens: int, completion_tokens: int|null, total_tokens: int}}> $stream + */ public static function fromStream(Generator $stream): self { - /** @var array{id: string, object: string, created: int, model: string, choices: array, token_logprobs: array, top_logprobs: array|null, text_offset: array}|null, finish_reason: string}>, usage: array{prompt_tokens: int, completion_tokens: int|null, total_tokens: int}} $attributes */ + /** @var array{id: string, object: string, created: int, model: string, choices: array, token_logprobs: array, top_logprobs: array|null, text_offset: array}|null, finish_reason: string|null}>, usage: array{prompt_tokens: int, completion_tokens: int|null, total_tokens: int}} $attributes */ $attributes = $stream->current(); return self::from($attributes)->withStream($stream); @@ -81,7 +88,7 @@ public function toArray(): array ]; } - public function current(): mixed + public function current(): self { return self::from($this->stream->current()); } @@ -91,7 +98,7 @@ public function next(): void $this->stream->next(); } - public function key(): mixed + public function key(): int { return $this->stream->key(); } @@ -106,6 +113,9 @@ public function rewind(): void $this->stream->rewind(); } + /** + * @param Generator, token_logprobs: array, top_logprobs: array|null, text_offset: array}|null, finish_reason: string|null}>, usage: array{prompt_tokens: int, completion_tokens: int|null, total_tokens: int}}> $stream + */ public function withStream(Generator $stream): self { $this->stream = $stream; diff --git a/src/Responses/Completions/CreateResponseChoice.php b/src/Responses/Completions/CreateResponseChoice.php index f7a84ab2..1c6f93dd 100644 --- a/src/Responses/Completions/CreateResponseChoice.php +++ b/src/Responses/Completions/CreateResponseChoice.php @@ -15,7 +15,7 @@ private function __construct( } /** - * @param array{text: string, index: int, logprobs: array{tokens: array, token_logprobs: array, top_logprobs: array|null, text_offset: array}|null, finish_reason: string} $attributes + * @param array{text: string, index: int, logprobs: array{tokens: array, token_logprobs: array, top_logprobs: array|null, text_offset: array}|null, finish_reason: string|null} $attributes */ public static function from(array $attributes): self { @@ -28,7 +28,7 @@ public static function from(array $attributes): self } /** - * @return array{text: string, index: int, logprobs: array{tokens: array, token_logprobs: array, top_logprobs: array|null, text_offset: array}|null, finish_reason: string} + * @return array{text: string, index: int, logprobs: array{tokens: array, token_logprobs: array, top_logprobs: array|null, text_offset: array}|null, finish_reason: string|null} */ public function toArray(): array { diff --git a/src/Responses/Edits/CreateResponse.php b/src/Responses/Edits/CreateResponse.php index 846bbf8b..21ebd75d 100644 --- a/src/Responses/Edits/CreateResponse.php +++ b/src/Responses/Edits/CreateResponse.php @@ -13,7 +13,7 @@ final class CreateResponse implements Response { /** - * @use ArrayAccessible, usage: array{prompt_tokens: int, completion_tokens: int, total_tokens: int}}> + * @use ArrayAccessible, usage: array{prompt_tokens: int, completion_tokens: int, total_tokens: int}}> */ use ArrayAccessible; diff --git a/src/Streams/EventStream.php b/src/Streams/EventStream.php index e3eefb77..2cd14476 100644 --- a/src/Streams/EventStream.php +++ b/src/Streams/EventStream.php @@ -14,6 +14,11 @@ public function __construct( ) { } + /** + * Iterates over the event-stream data. + * + * @return Generator + */ public function read(): Generator { while (! $this->stream->eof()) { @@ -23,9 +28,9 @@ public function read(): Generator continue; } - $data = trim(substr(strstr($line, 'data: '), 6)); + $data = trim(substr((string) strstr($line, 'data:'), 5)); - if ($data === "[DONE]") { + if ($data === '[DONE]') { break; } diff --git a/src/Transporters/HttpTransporter.php b/src/Transporters/HttpTransporter.php index cb542fd1..d2d8de94 100644 --- a/src/Transporters/HttpTransporter.php +++ b/src/Transporters/HttpTransporter.php @@ -56,7 +56,7 @@ public function requestContent(Payload $payload): string $request = $payload->toRequest($this->baseUri, $this->headers); try { - $response = $this->client->sendRequest($request); + $response = $this->client->send($request); } catch (ClientExceptionInterface $clientException) { throw new TransporterException($clientException); } diff --git a/src/ValueObjects/Transporter/Payload.php b/src/ValueObjects/Transporter/Payload.php index 9ae7c99b..aaacea38 100644 --- a/src/ValueObjects/Transporter/Payload.php +++ b/src/ValueObjects/Transporter/Payload.php @@ -35,7 +35,7 @@ private function __construct( */ public function isStream(): bool { - return $this->parameters['stream'] ?? false; + return (bool) ($this->parameters['stream'] ?? false); } /** diff --git a/src/ValueObjects/Transporter/Response.php b/src/ValueObjects/Transporter/Response.php index 60f8bb40..d87e501a 100644 --- a/src/ValueObjects/Transporter/Response.php +++ b/src/ValueObjects/Transporter/Response.php @@ -9,13 +9,14 @@ use OpenAI\Streams\EventStream; use Psr\Http\Message\ResponseInterface; -class Response +final class Response { /** * Creates a new Response value object. */ public function __construct(private readonly ResponseInterface $response) - {} + { + } /** * Determines whether the stream response was requested. @@ -28,7 +29,7 @@ public function isStream(): bool /** * Returns decoded response object. * - * @return array + * @return array */ public function object(): array { @@ -38,7 +39,7 @@ public function object(): array /** * Returns stream generator with decoded response objects. * - * @return Generator> + * @return Generator> */ public function stream(): Generator { @@ -50,13 +51,14 @@ public function stream(): Generator /** * Decode raw content to json. * - * @return array + * @return array */ private function decode(string $contents): array { $result = []; try { + /** @var array{error?: array{message: string, type: string, code: string}} $result */ $result = json_decode($contents, true, 512, JSON_THROW_ON_ERROR); } catch (JsonException $jsonException) { throw new UnserializableResponse($jsonException); diff --git a/tests/Pest.php b/tests/Pest.php index 5035e0fa..ff542b33 100644 --- a/tests/Pest.php +++ b/tests/Pest.php @@ -28,12 +28,11 @@ function mockClient(string $method, string $resource, array $params, array|strin 'data: '.implode("\n\ndata: ", array_map( fn ($value) => json_encode($value), iterator_to_array($response) - ) - ))); + ) + ))); $responseMock->shouldReceive('getHeaderLine')->andReturn('text/event-stream'); } - } $transporter diff --git a/tests/Resources/Completions.php b/tests/Resources/Completions.php index 3bfc3b12..b9939218 100644 --- a/tests/Resources/Completions.php +++ b/tests/Resources/Completions.php @@ -1,11 +1,8 @@ client - ->shouldReceive('sendRequest') + ->shouldReceive('send') ->once() ->withArgs(function (Psr7Request $request) { expect($request->getMethod())->toBe('GET') @@ -182,7 +182,7 @@ $response = new Response(200, [], 'My response content'); $this->client - ->shouldReceive('sendRequest') + ->shouldReceive('send') ->once() ->andReturn($response); @@ -198,7 +198,7 @@ $headers = Headers::withAuthorization(ApiToken::from('foo')); $this->client - ->shouldReceive('sendRequest') + ->shouldReceive('send') ->once() ->andThrow(new ConnectException('Could not resolve host.', $payload->toRequest($baseUri, $headers))); @@ -222,7 +222,7 @@ ])); $this->client - ->shouldReceive('sendRequest') + ->shouldReceive('send') ->once() ->andReturn($response); From 7f422d23ce0a705bf6168ae72ec99b041762b477 Mon Sep 17 00:00:00 2001 From: Slava Razum Date: Fri, 13 Jan 2023 06:03:03 +0200 Subject: [PATCH 7/7] Type improvements --- src/Resources/Completions.php | 4 ++-- src/Responses/Completions/CreateResponse.php | 6 +++--- src/ValueObjects/Transporter/Payload.php | 3 +-- src/ValueObjects/Transporter/Response.php | 8 ++++---- tests/Transporters/HttpTransporter.php | 3 +-- 5 files changed, 11 insertions(+), 13 deletions(-) diff --git a/src/Resources/Completions.php b/src/Resources/Completions.php index cff99cb7..4c759dee 100644 --- a/src/Resources/Completions.php +++ b/src/Resources/Completions.php @@ -25,13 +25,13 @@ public function create(array $parameters): CreateResponse $response = $this->transporter->requestObject($payload); if ($response->isStream()) { - /** @var \Generator, token_logprobs: array, top_logprobs: array|null, text_offset: array}|null, finish_reason: string|null}>, usage: array{prompt_tokens: int, completion_tokens: int|null, total_tokens: int}}> $stream */ + /** @var \Generator, token_logprobs: array, top_logprobs: array|null, text_offset: array}|null, finish_reason: null}>, usage: null}> $stream */ $stream = $response->stream(); return CreateResponse::fromStream($stream); } - /** @var array{id: string, object: string, created: int, model: string, choices: array, token_logprobs: array, top_logprobs: array|null, text_offset: array}|null, finish_reason: string}>, usage: array{prompt_tokens: int, completion_tokens: int|null, total_tokens: int}|null} $object */ + /** @var array{id: string, object: string, created: int, model: string, choices: array, token_logprobs: array, top_logprobs: array|null, text_offset: array}|null, finish_reason: string}>, usage: array{prompt_tokens: int, completion_tokens: int|null, total_tokens: int}} $object */ $object = $response->object(); return CreateResponse::from($object); diff --git a/src/Responses/Completions/CreateResponse.php b/src/Responses/Completions/CreateResponse.php index d1a2430c..a42e3146 100644 --- a/src/Responses/Completions/CreateResponse.php +++ b/src/Responses/Completions/CreateResponse.php @@ -60,11 +60,11 @@ public static function from(array $attributes): self } /** - * @param Generator, token_logprobs: array, top_logprobs: array|null, text_offset: array}|null, finish_reason: string|null}>, usage: array{prompt_tokens: int, completion_tokens: int|null, total_tokens: int}}> $stream + * @param Generator, token_logprobs: array, top_logprobs: array|null, text_offset: array}|null, finish_reason: null}>, usage: null}> $stream */ public static function fromStream(Generator $stream): self { - /** @var array{id: string, object: string, created: int, model: string, choices: array, token_logprobs: array, top_logprobs: array|null, text_offset: array}|null, finish_reason: string|null}>, usage: array{prompt_tokens: int, completion_tokens: int|null, total_tokens: int}} $attributes */ + /** @var array{id: string, object: string, created: int, model: string, choices: array, token_logprobs: array, top_logprobs: array|null, text_offset: array}|null, finish_reason: null}>, usage: null} $attributes */ $attributes = $stream->current(); return self::from($attributes)->withStream($stream); @@ -114,7 +114,7 @@ public function rewind(): void } /** - * @param Generator, token_logprobs: array, top_logprobs: array|null, text_offset: array}|null, finish_reason: string|null}>, usage: array{prompt_tokens: int, completion_tokens: int|null, total_tokens: int}}> $stream + * @param Generator, token_logprobs: array, top_logprobs: array|null, text_offset: array}|null, finish_reason: null}>, usage: null}> $stream */ public function withStream(Generator $stream): self { diff --git a/src/ValueObjects/Transporter/Payload.php b/src/ValueObjects/Transporter/Payload.php index aaacea38..00f0c1a5 100644 --- a/src/ValueObjects/Transporter/Payload.php +++ b/src/ValueObjects/Transporter/Payload.php @@ -6,7 +6,6 @@ use GuzzleHttp\Psr7\MultipartStream; use GuzzleHttp\Psr7\Request as Psr7Request; -use OpenAI\Contracts\Request; use OpenAI\Enums\Transporter\ContentType; use OpenAI\Enums\Transporter\Method; use OpenAI\ValueObjects\ResourceUri; @@ -35,7 +34,7 @@ private function __construct( */ public function isStream(): bool { - return (bool) ($this->parameters['stream'] ?? false); + return isset($this->parameters['stream']) && is_bool($this->parameters['stream']) && $this->parameters['stream']; } /** diff --git a/src/ValueObjects/Transporter/Response.php b/src/ValueObjects/Transporter/Response.php index d87e501a..206dc18c 100644 --- a/src/ValueObjects/Transporter/Response.php +++ b/src/ValueObjects/Transporter/Response.php @@ -29,7 +29,7 @@ public function isStream(): bool /** * Returns decoded response object. * - * @return array + * @return array */ public function object(): array { @@ -39,7 +39,7 @@ public function object(): array /** * Returns stream generator with decoded response objects. * - * @return Generator> + * @return Generator> */ public function stream(): Generator { @@ -49,9 +49,9 @@ public function stream(): Generator } /** - * Decode raw content to json. + * Decode object from raw json string. * - * @return array + * @return array */ private function decode(string $contents): array { diff --git a/tests/Transporters/HttpTransporter.php b/tests/Transporters/HttpTransporter.php index f6263605..6b46479f 100644 --- a/tests/Transporters/HttpTransporter.php +++ b/tests/Transporters/HttpTransporter.php @@ -13,7 +13,6 @@ use OpenAI\ValueObjects\Transporter\BaseUri; use OpenAI\ValueObjects\Transporter\Headers; use OpenAI\ValueObjects\Transporter\Payload; -use OpenAI\ValueObjects\Transporter\Response as TransporterResponse; beforeEach(function () { $this->client = Mockery::mock(ClientInterface::class); @@ -93,7 +92,7 @@ $response = $this->http->requestObject($payload); - expect($response)->toBeInstanceOf(TransporterResponse::class); + expect($response->isStream())->toBeTrue(); }); test('request object server errors', function () {