diff --git a/src/Contracts/Stream.php b/src/Contracts/Stream.php new file mode 100644 index 00000000..60d2aa05 --- /dev/null +++ b/src/Contracts/Stream.php @@ -0,0 +1,13 @@ + + */ + public function read(): iterable; +} diff --git a/src/Contracts/Transporter.php b/src/Contracts/Transporter.php index e3c2eb8a..816cc554 100644 --- a/src/Contracts/Transporter.php +++ b/src/Contracts/Transporter.php @@ -8,6 +8,7 @@ use OpenAI\Exceptions\TransporterException; use OpenAI\Exceptions\UnserializableResponse; use OpenAI\ValueObjects\Transporter\Payload; +use OpenAI\ValueObjects\Transporter\Response; /** * @internal @@ -16,12 +17,10 @@ interface Transporter { /** * Sends a request to a server. - ** - * @return array * * @throws ErrorException|UnserializableResponse|TransporterException */ - public function requestObject(Payload $payload): array; + public function requestObject(Payload $payload): Response; /** * Sends a content request to a server. diff --git a/src/Resources/Completions.php b/src/Resources/Completions.php index 69cfd569..4c759dee 100644 --- a/src/Resources/Completions.php +++ b/src/Resources/Completions.php @@ -22,9 +22,18 @@ public function create(array $parameters): CreateResponse { $payload = Payload::create('completions', $parameters); - /** @var array{id: string, object: string, created: int, model: string, choices: array, token_logprobs: array, top_logprobs: array|null, text_offset: array}|null, finish_reason: string}>, usage: array{prompt_tokens: int, completion_tokens: int, total_tokens: int}} $result */ - $result = $this->transporter->requestObject($payload); + $response = $this->transporter->requestObject($payload); - return CreateResponse::from($result); + if ($response->isStream()) { + /** @var \Generator, token_logprobs: array, top_logprobs: array|null, text_offset: array}|null, finish_reason: null}>, usage: null}> $stream */ + $stream = $response->stream(); + + return CreateResponse::fromStream($stream); + } + + /** @var array{id: string, object: string, created: int, model: string, choices: array, token_logprobs: array, top_logprobs: array|null, text_offset: array}|null, finish_reason: string}>, usage: array{prompt_tokens: int, completion_tokens: int|null, total_tokens: int}} $object */ + $object = $response->object(); + + return CreateResponse::from($object); } } diff --git a/src/Resources/Edits.php b/src/Resources/Edits.php index b8d4340b..0a653972 100644 --- a/src/Resources/Edits.php +++ b/src/Resources/Edits.php @@ -23,7 +23,7 @@ public function create(array $parameters): CreateResponse $payload = Payload::create('edits', $parameters); /** @var array{object: string, created: int, choices: array, usage: array{prompt_tokens: int, completion_tokens: int, total_tokens: int}} $result */ - $result = $this->transporter->requestObject($payload); + $result = $this->transporter->requestObject($payload)->object(); return CreateResponse::from($result); } diff --git a/src/Resources/Embeddings.php b/src/Resources/Embeddings.php index c2709d21..28e500ca 100644 --- a/src/Resources/Embeddings.php +++ b/src/Resources/Embeddings.php @@ -23,7 +23,7 @@ public function create(array $parameters): CreateResponse $payload = Payload::create('embeddings', $parameters); /** @var array{object: string, data: array, index: int}>, usage: array{prompt_tokens: int, total_tokens: int}} $result */ - $result = $this->transporter->requestObject($payload); + $result = $this->transporter->requestObject($payload)->object(); return CreateResponse::from($result); } diff --git a/src/Resources/Files.php b/src/Resources/Files.php index 18964c25..43791d1b 100644 --- a/src/Resources/Files.php +++ b/src/Resources/Files.php @@ -24,7 +24,7 @@ public function list(): ListResponse $payload = Payload::list('files'); /** @var array{object: string, data: array|null}>} $result */ - $result = $this->transporter->requestObject($payload); + $result = $this->transporter->requestObject($payload)->object(); return ListResponse::from($result); } @@ -39,7 +39,7 @@ public function retrieve(string $file): RetrieveResponse $payload = Payload::retrieve('files', $file); /** @var array{id: string, object: string, created_at: int, bytes: int, filename: string, purpose: string, status: string, status_details: array|null} $result */ - $result = $this->transporter->requestObject($payload); + $result = $this->transporter->requestObject($payload)->object(); return RetrieveResponse::from($result); } @@ -68,7 +68,7 @@ public function upload(array $parameters): CreateResponse $payload = Payload::upload('files', $parameters); /** @var array{id: string, object: string, created_at: int, bytes: int, filename: string, purpose: string, status: string, status_details: array|null} $result */ - $result = $this->transporter->requestObject($payload); + $result = $this->transporter->requestObject($payload)->object(); return CreateResponse::from($result); } @@ -83,7 +83,7 @@ public function delete(string $file): DeleteResponse $payload = Payload::delete('files', $file); /** @var array{id: string, object: string, deleted: bool} $result */ - $result = $this->transporter->requestObject($payload); + $result = $this->transporter->requestObject($payload)->object(); return DeleteResponse::from($result); } diff --git a/src/Resources/FineTunes.php b/src/Resources/FineTunes.php index f4e81203..4bdbac7a 100644 --- a/src/Resources/FineTunes.php +++ b/src/Resources/FineTunes.php @@ -27,7 +27,7 @@ public function create(array $parameters): RetrieveResponse $payload = Payload::create('fine-tunes', $parameters); /** @var array{id: string, object: string, model: string, created_at: int, events: array, fine_tuned_model: ?string, hyperparams: array{batch_size: ?int, learning_rate_multiplier: ?float, n_epochs: int, prompt_loss_weight: float}, organization_id: string, result_files: array|null}>, status: string, validation_files: array|null}>, training_files: array|null}>, updated_at: int} $result */ - $result = $this->transporter->requestObject($payload); + $result = $this->transporter->requestObject($payload)->object(); return RetrieveResponse::from($result); } @@ -42,7 +42,7 @@ public function list(): ListResponse $payload = Payload::list('fine-tunes'); /** @var array{object: string, data: array, fine_tuned_model: ?string, hyperparams: array{batch_size: ?int, learning_rate_multiplier: ?float, n_epochs: int, prompt_loss_weight: float}, organization_id: string, result_files: array|null}>, status: string, validation_files: array|null}>, training_files: array|null}>, updated_at: int}>} $result */ - $result = $this->transporter->requestObject($payload); + $result = $this->transporter->requestObject($payload)->object(); return ListResponse::from($result); } @@ -57,7 +57,7 @@ public function retrieve(string $fineTuneId): RetrieveResponse $payload = Payload::retrieve('fine-tunes', $fineTuneId); /** @var array{id: string, object: string, model: string, created_at: int, events: array, fine_tuned_model: ?string, hyperparams: array{batch_size: ?int, learning_rate_multiplier: ?float, n_epochs: int, prompt_loss_weight: float}, organization_id: string, result_files: array|null}>, status: string, validation_files: array|null}>, training_files: array|null}>, updated_at: int} $result */ - $result = $this->transporter->requestObject($payload); + $result = $this->transporter->requestObject($payload)->object(); return RetrieveResponse::from($result); } @@ -72,7 +72,7 @@ public function cancel(string $fineTuneId): RetrieveResponse $payload = Payload::cancel('fine-tunes', $fineTuneId); /** @var array{id: string, object: string, model: string, created_at: int, events: array, fine_tuned_model: ?string, hyperparams: array{batch_size: ?int, learning_rate_multiplier: ?float, n_epochs: int, prompt_loss_weight: float}, organization_id: string, result_files: array|null}>, status: string, validation_files: array|null}>, training_files: array|null}>, updated_at: int} $result */ - $result = $this->transporter->requestObject($payload); + $result = $this->transporter->requestObject($payload)->object(); return RetrieveResponse::from($result); } @@ -87,7 +87,7 @@ public function listEvents(string $fineTuneId): ListEventsResponse $payload = Payload::retrieve('fine-tunes', $fineTuneId, '/events'); /** @var array{object: string, data: array} $result */ - $result = $this->transporter->requestObject($payload); + $result = $this->transporter->requestObject($payload)->object(); return ListEventsResponse::from($result); } diff --git a/src/Resources/Images.php b/src/Resources/Images.php index dabff0ba..418107a4 100644 --- a/src/Resources/Images.php +++ b/src/Resources/Images.php @@ -25,7 +25,7 @@ public function create(array $parameters): CreateResponse $payload = Payload::create('images/generations', $parameters); /** @var array{created: int, data: array} $result */ - $result = $this->transporter->requestObject($payload); + $result = $this->transporter->requestObject($payload)->object(); return CreateResponse::from($result); } @@ -42,7 +42,7 @@ public function edit(array $parameters): EditResponse $payload = Payload::upload('images/edits', $parameters); /** @var array{created: int, data: array} $result */ - $result = $this->transporter->requestObject($payload); + $result = $this->transporter->requestObject($payload)->object(); return EditResponse::from($result); } @@ -59,7 +59,7 @@ public function variation(array $parameters): VariationResponse $payload = Payload::upload('images/variations', $parameters); /** @var array{created: int, data: array} $result */ - $result = $this->transporter->requestObject($payload); + $result = $this->transporter->requestObject($payload)->object(); return VariationResponse::from($result); } diff --git a/src/Resources/Models.php b/src/Resources/Models.php index cd87fe73..b642b14a 100644 --- a/src/Resources/Models.php +++ b/src/Resources/Models.php @@ -23,7 +23,7 @@ public function list(): ListResponse $payload = Payload::list('models'); /** @var array{object: string, data: array, root: string, parent: ?string}>} $result */ - $result = $this->transporter->requestObject($payload); + $result = $this->transporter->requestObject($payload)->object(); return ListResponse::from($result); } @@ -38,7 +38,7 @@ public function retrieve(string $model): RetrieveResponse $payload = Payload::retrieve('models', $model); /** @var array{id: string, object: string, created: int, owned_by: string, permission: array, root: string, parent: ?string} $result */ - $result = $this->transporter->requestObject($payload); + $result = $this->transporter->requestObject($payload)->object(); return RetrieveResponse::from($result); } @@ -53,7 +53,7 @@ public function delete(string $model): DeleteResponse $payload = Payload::delete('models', $model); /** @var array{id: string, object: string, deleted: bool} $result */ - $result = $this->transporter->requestObject($payload); + $result = $this->transporter->requestObject($payload)->object(); return DeleteResponse::from($result); } diff --git a/src/Resources/Moderations.php b/src/Resources/Moderations.php index 14823b3f..c4cab140 100644 --- a/src/Resources/Moderations.php +++ b/src/Resources/Moderations.php @@ -23,7 +23,7 @@ public function create(array $parameters): CreateResponse $payload = Payload::create('moderations', $parameters); /** @var array{id: string, model: string, results: array, category_scores: array, flagged: bool}>} $result */ - $result = $this->transporter->requestObject($payload); + $result = $this->transporter->requestObject($payload)->object(); return CreateResponse::from($result); } diff --git a/src/Responses/Completions/CreateResponse.php b/src/Responses/Completions/CreateResponse.php index 882c6677..a42e3146 100644 --- a/src/Responses/Completions/CreateResponse.php +++ b/src/Responses/Completions/CreateResponse.php @@ -4,19 +4,27 @@ namespace OpenAI\Responses\Completions; +use Generator; +use Iterator; use OpenAI\Contracts\Response; use OpenAI\Responses\Concerns\ArrayAccessible; /** - * @implements Response, token_logprobs: array, top_logprobs: array|null, text_offset: array}|null, finish_reason: string}>, usage: array{prompt_tokens: int, completion_tokens: int|null, total_tokens: int}}> + * @implements Iterator, token_logprobs: array, top_logprobs: array|null, text_offset: array}|null, finish_reason: string|null}>, usage: array{prompt_tokens: int, completion_tokens: int|null, total_tokens: int}|null}>> + * @implements Response, token_logprobs: array, top_logprobs: array|null, text_offset: array}|null, finish_reason: string|null}>, usage: array{prompt_tokens: int, completion_tokens: int|null, total_tokens: int}|null}> */ -final class CreateResponse implements Response +final class CreateResponse implements Response, Iterator { /** - * @use ArrayAccessible, token_logprobs: array, top_logprobs: array|null, text_offset: array}|null, finish_reason: string}>, usage: array{prompt_tokens: int, completion_tokens: int|null, total_tokens: int}}> + * @use ArrayAccessible, token_logprobs: array, top_logprobs: array|null, text_offset: array}|null, finish_reason: string}>, usage: array{prompt_tokens: int, completion_tokens: int|null, total_tokens: int}|null}> */ use ArrayAccessible; + /** + * @var Generator, token_logprobs: array, top_logprobs: array|null, text_offset: array}|null, finish_reason: string|null}>, usage: array{prompt_tokens: int, completion_tokens: int|null, total_tokens: int}|null}> + */ + private Generator $stream; + /** * @param array $choices */ @@ -26,14 +34,14 @@ private function __construct( public readonly int $created, public readonly string $model, public readonly array $choices, - public readonly CreateResponseUsage $usage, + public readonly ?CreateResponseUsage $usage, ) { } /** * Acts as static factory, and returns a new Response instance. * - * @param array{id: string, object: string, created: int, model: string, choices: array, token_logprobs: array, top_logprobs: array|null, text_offset: array}|null, finish_reason: string}>, usage: array{prompt_tokens: int, completion_tokens: int|null, total_tokens: int}} $attributes + * @param array{id: string, object: string, created: int, model: string, choices: array, token_logprobs: array, top_logprobs: array|null, text_offset: array}|null, finish_reason: string|null}>, usage: array{prompt_tokens: int, completion_tokens: int|null, total_tokens: int}|null} $attributes */ public static function from(array $attributes): self { @@ -47,10 +55,21 @@ public static function from(array $attributes): self $attributes['created'], $attributes['model'], $choices, - CreateResponseUsage::from($attributes['usage']) + isset($attributes['usage']) ? CreateResponseUsage::from($attributes['usage']) : null ); } + /** + * @param Generator, token_logprobs: array, top_logprobs: array|null, text_offset: array}|null, finish_reason: null}>, usage: null}> $stream + */ + public static function fromStream(Generator $stream): self + { + /** @var array{id: string, object: string, created: int, model: string, choices: array, token_logprobs: array, top_logprobs: array|null, text_offset: array}|null, finish_reason: null}>, usage: null} $attributes */ + $attributes = $stream->current(); + + return self::from($attributes)->withStream($stream); + } + /** * {@inheritDoc} */ @@ -65,7 +84,47 @@ public function toArray(): array static fn (CreateResponseChoice $result): array => $result->toArray(), $this->choices, ), - 'usage' => $this->usage->toArray(), + 'usage' => $this->usage?->toArray(), ]; } + + public function current(): self + { + return self::from($this->stream->current()); + } + + public function next(): void + { + $this->stream->next(); + } + + public function key(): int + { + return $this->stream->key(); + } + + public function valid(): bool + { + return $this->stream->valid(); + } + + public function rewind(): void + { + $this->stream->rewind(); + } + + /** + * @param Generator, token_logprobs: array, top_logprobs: array|null, text_offset: array}|null, finish_reason: null}>, usage: null}> $stream + */ + public function withStream(Generator $stream): self + { + $this->stream = $stream; + + return $this; + } + + public function isStream(): bool + { + return isset($this->stream); + } } diff --git a/src/Responses/Completions/CreateResponseChoice.php b/src/Responses/Completions/CreateResponseChoice.php index 6fa722fc..1c6f93dd 100644 --- a/src/Responses/Completions/CreateResponseChoice.php +++ b/src/Responses/Completions/CreateResponseChoice.php @@ -10,12 +10,12 @@ private function __construct( public readonly string $text, public readonly int $index, public readonly ?CreateResponseChoiceLogprobs $logprobs, - public readonly string $finishReason, + public readonly ?string $finishReason, ) { } /** - * @param array{text: string, index: int, logprobs: array{tokens: array, token_logprobs: array, top_logprobs: array|null, text_offset: array}|null, finish_reason: string} $attributes + * @param array{text: string, index: int, logprobs: array{tokens: array, token_logprobs: array, top_logprobs: array|null, text_offset: array}|null, finish_reason: string|null} $attributes */ public static function from(array $attributes): self { @@ -28,7 +28,7 @@ public static function from(array $attributes): self } /** - * @return array{text: string, index: int, logprobs: array{tokens: array, token_logprobs: array, top_logprobs: array|null, text_offset: array}|null, finish_reason: string} + * @return array{text: string, index: int, logprobs: array{tokens: array, token_logprobs: array, top_logprobs: array|null, text_offset: array}|null, finish_reason: string|null} */ public function toArray(): array { diff --git a/src/Responses/Edits/CreateResponse.php b/src/Responses/Edits/CreateResponse.php index 846bbf8b..21ebd75d 100644 --- a/src/Responses/Edits/CreateResponse.php +++ b/src/Responses/Edits/CreateResponse.php @@ -13,7 +13,7 @@ final class CreateResponse implements Response { /** - * @use ArrayAccessible, usage: array{prompt_tokens: int, completion_tokens: int, total_tokens: int}}> + * @use ArrayAccessible, usage: array{prompt_tokens: int, completion_tokens: int, total_tokens: int}}> */ use ArrayAccessible; diff --git a/src/Streams/EventStream.php b/src/Streams/EventStream.php new file mode 100644 index 00000000..2cd14476 --- /dev/null +++ b/src/Streams/EventStream.php @@ -0,0 +1,40 @@ + + */ + public function read(): Generator + { + while (! $this->stream->eof()) { + $line = Utils::readLine($this->stream); + + if (! str_starts_with($line, 'data:')) { + continue; + } + + $data = trim(substr((string) strstr($line, 'data:'), 5)); + + if ($data === '[DONE]') { + break; + } + + yield $data; + } + } +} diff --git a/src/Transporters/HttpTransporter.php b/src/Transporters/HttpTransporter.php index 4913804a..d2d8de94 100644 --- a/src/Transporters/HttpTransporter.php +++ b/src/Transporters/HttpTransporter.php @@ -4,16 +4,16 @@ namespace OpenAI\Transporters; +use GuzzleHttp\ClientInterface; use JsonException; use OpenAI\Contracts\Transporter; use OpenAI\Exceptions\ErrorException; use OpenAI\Exceptions\TransporterException; -use OpenAI\Exceptions\UnserializableResponse; use OpenAI\ValueObjects\Transporter\BaseUri; use OpenAI\ValueObjects\Transporter\Headers; use OpenAI\ValueObjects\Transporter\Payload; +use OpenAI\ValueObjects\Transporter\Response; use Psr\Http\Client\ClientExceptionInterface; -use Psr\Http\Client\ClientInterface; /** * @internal @@ -34,30 +34,18 @@ public function __construct( /** * {@inheritDoc} */ - public function requestObject(Payload $payload): array + public function requestObject(Payload $payload): Response { - $request = $payload->toRequest($this->baseUri, $this->headers); - try { - $response = $this->client->sendRequest($request); + return new Response( + $this->client->send( + $payload->toRequest($this->baseUri, $this->headers), + ['stream' => $payload->isStream()] + ) + ); } catch (ClientExceptionInterface $clientException) { throw new TransporterException($clientException); } - - $contents = $response->getBody()->getContents(); - - try { - /** @var array{error?: array{message: string, type: string, code: string}} $response */ - $response = json_decode($contents, true, 512, JSON_THROW_ON_ERROR); - } catch (JsonException $jsonException) { - throw new UnserializableResponse($jsonException); - } - - if (isset($response['error'])) { - throw new ErrorException($response['error']); - } - - return $response; } /** @@ -68,7 +56,7 @@ public function requestContent(Payload $payload): string $request = $payload->toRequest($this->baseUri, $this->headers); try { - $response = $this->client->sendRequest($request); + $response = $this->client->send($request); } catch (ClientExceptionInterface $clientException) { throw new TransporterException($clientException); } diff --git a/src/ValueObjects/Transporter/Payload.php b/src/ValueObjects/Transporter/Payload.php index 298b20b3..00f0c1a5 100644 --- a/src/ValueObjects/Transporter/Payload.php +++ b/src/ValueObjects/Transporter/Payload.php @@ -6,7 +6,6 @@ use GuzzleHttp\Psr7\MultipartStream; use GuzzleHttp\Psr7\Request as Psr7Request; -use OpenAI\Contracts\Request; use OpenAI\Enums\Transporter\ContentType; use OpenAI\Enums\Transporter\Method; use OpenAI\ValueObjects\ResourceUri; @@ -30,6 +29,14 @@ private function __construct( // .. } + /** + * Determines whether the stream response was requested. + */ + public function isStream(): bool + { + return isset($this->parameters['stream']) && is_bool($this->parameters['stream']) && $this->parameters['stream']; + } + /** * Creates a new Payload value object from the given parameters. */ diff --git a/src/ValueObjects/Transporter/Response.php b/src/ValueObjects/Transporter/Response.php new file mode 100644 index 00000000..206dc18c --- /dev/null +++ b/src/ValueObjects/Transporter/Response.php @@ -0,0 +1,73 @@ +response->getHeaderLine('Content-Type') === 'text/event-stream'; + } + + /** + * Returns decoded response object. + * + * @return array + */ + public function object(): array + { + return $this->decode($this->response->getBody()->getContents()); + } + + /** + * Returns stream generator with decoded response objects. + * + * @return Generator> + */ + public function stream(): Generator + { + foreach ((new EventStream($this->response->getBody()))->read() as $data) { + yield $this->decode($data); + } + } + + /** + * Decode object from raw json string. + * + * @return array + */ + private function decode(string $contents): array + { + $result = []; + + try { + /** @var array{error?: array{message: string, type: string, code: string}} $result */ + $result = json_decode($contents, true, 512, JSON_THROW_ON_ERROR); + } catch (JsonException $jsonException) { + throw new UnserializableResponse($jsonException); + } + + if (isset($result['error'])) { + throw new ErrorException($result['error']); + } + + return $result; + } +} diff --git a/tests/Pest.php b/tests/Pest.php index 72156a21..ff542b33 100644 --- a/tests/Pest.php +++ b/tests/Pest.php @@ -1,15 +1,39 @@ shouldReceive('getBody')->andReturn(Utils::streamFor( + is_array($response) ? json_encode($response) : $response + )); + $responseMock->shouldReceive('getHeaderLine')->andReturn('application/json'); + } elseif ($response instanceof Generator) { + $responseMock->shouldReceive('getBody')->andReturn(Utils::streamFor( + 'data: '.implode("\n\ndata: ", array_map( + fn ($value) => json_encode($value), + iterator_to_array($response) + ) + ))); + + $responseMock->shouldReceive('getHeaderLine')->andReturn('text/event-stream'); + } + } $transporter ->shouldReceive($methodName) @@ -22,7 +46,7 @@ function mockClient(string $method, string $resource, array $params, array|strin return $request->getMethod() === $method && $request->getUri()->getPath() === "/v1/$resource"; - })->andReturn($response); + })->andReturn($responseMock ? new Response($responseMock) : $response); return new Client($transporter); } diff --git a/tests/Resources/Completions.php b/tests/Resources/Completions.php index 1eb5f627..b9939218 100644 --- a/tests/Resources/Completions.php +++ b/tests/Resources/Completions.php @@ -36,3 +36,38 @@ ->completionTokens->toBe(16) ->totalTokens->toBe(17); }); + +test('create stream', function () { + $client = mockClient('POST', 'completions', [ + 'model' => 'da-vince', + 'prompt' => 'hi', + 'stream' => true, + ], (fn () => yield completion())()); + + $result = $client->completions()->create([ + 'model' => 'da-vince', + 'prompt' => 'hi', + 'stream' => true, + ]); + + expect($result)->toBeInstanceOf(CreateResponse::class); + + $data = iterator_to_array($result)[0]; + + ray($result); + + expect($data) + ->id->toBe('cmpl-5uS6a68SwurhqAqLBpZtibIITICna') + ->object->toBe('text_completion') + ->created->toBe(1664136088) + ->model->toBe('davinci') + ->choices->toBeArray()->toHaveCount(1) + ->choices->each->toBeInstanceOf(CreateResponseChoice::class) + ->usage->toBeInstanceOf(CreateResponseUsage::class); + + expect($data->choices[0]) + ->text->toBe("el, she elaborates more on the Corruptor's role, suggesting K") + ->index->toBe(0) + ->logprobs->toBe(null) + ->finishReason->toBe('length'); +}); diff --git a/tests/Streams/EventStream.php b/tests/Streams/EventStream.php new file mode 100644 index 00000000..48d9be63 --- /dev/null +++ b/tests/Streams/EventStream.php @@ -0,0 +1,26 @@ +read()); + + expect($arr[0])->toBe($jsonString); +}); + +test('skips empty stream lines', function () { + $stream = new EventStream(Utils::streamFor("data: {\"text\": \"Hey\"}\n\ndata: {\"text\": \"there!\"}")); + + expect(iterator_to_array($stream->read()))->toBe(['{"text": "Hey"}', '{"text": "there!"}']); +}); + +test('aborts after done message', function () { + $stream = new EventStream(Utils::streamFor("data: {\"text\": \"Hey\"}\ndata: [DONE]\ndata: {\"text\": \"there!\"}")); + + expect(iterator_to_array($stream->read()))->toBe(['{"text": "Hey"}']); +}); diff --git a/tests/Transporters/HttpTransporter.php b/tests/Transporters/HttpTransporter.php index b8806214..6b46479f 100644 --- a/tests/Transporters/HttpTransporter.php +++ b/tests/Transporters/HttpTransporter.php @@ -1,5 +1,6 @@ client = Mockery::mock(ClientInterface::class); @@ -34,7 +34,7 @@ ])); $this->client - ->shouldReceive('sendRequest') + ->shouldReceive('send') ->once() ->withArgs(function (Psr7Request $request) { expect($request->getMethod())->toBe('GET') @@ -62,13 +62,13 @@ ])); $this->client - ->shouldReceive('sendRequest') + ->shouldReceive('send') ->once() ->andReturn($response); $response = $this->http->requestObject($payload); - expect($response)->toBe([ + expect($response->object())->toBe([ [ 'text' => 'Hey!', 'index' => 0, @@ -78,6 +78,23 @@ ]); }); +test('request object stream response', function () { + $payload = Payload::create('completions', ['stream' => true]); + + $response = new Response(200, ['Content-Type' => 'text/event-stream'], json_encode([ + 'qdwq', + ])); + + $this->client + ->shouldReceive('send') + ->once() + ->andReturn($response); + + $response = $this->http->requestObject($payload); + + expect($response->isStream())->toBeTrue(); +}); + test('request object server errors', function () { $payload = Payload::list('models'); @@ -91,11 +108,11 @@ ])); $this->client - ->shouldReceive('sendRequest') + ->shouldReceive('send') ->once() ->andReturn($response); - expect(fn () => $this->http->requestObject($payload)) + expect(fn () => $this->http->requestObject($payload)->object()) ->toThrow(function (ErrorException $e) { expect($e->getMessage())->toBe('Incorrect API key provided: foo. You can find your API key at https://beta.openai.com.') ->and($e->getErrorMessage())->toBe('Incorrect API key provided: foo. You can find your API key at https://beta.openai.com.') @@ -111,7 +128,7 @@ $headers = Headers::withAuthorization(ApiToken::from('foo')); $this->client - ->shouldReceive('sendRequest') + ->shouldReceive('send') ->once() ->andThrow(new ConnectException('Could not resolve host.', $payload->toRequest($baseUri, $headers))); @@ -128,11 +145,11 @@ $response = new Response(200, [], 'err'); $this->client - ->shouldReceive('sendRequest') + ->shouldReceive('send') ->once() ->andReturn($response); - $this->http->requestObject($payload); + $this->http->requestObject($payload)->object(); })->throws(UnserializableResponse::class, 'Syntax error'); test('request content', function () { @@ -143,7 +160,7 @@ ])); $this->client - ->shouldReceive('sendRequest') + ->shouldReceive('send') ->once() ->withArgs(function (Psr7Request $request) { expect($request->getMethod())->toBe('GET') @@ -164,7 +181,7 @@ $response = new Response(200, [], 'My response content'); $this->client - ->shouldReceive('sendRequest') + ->shouldReceive('send') ->once() ->andReturn($response); @@ -180,7 +197,7 @@ $headers = Headers::withAuthorization(ApiToken::from('foo')); $this->client - ->shouldReceive('sendRequest') + ->shouldReceive('send') ->once() ->andThrow(new ConnectException('Could not resolve host.', $payload->toRequest($baseUri, $headers))); @@ -204,7 +221,7 @@ ])); $this->client - ->shouldReceive('sendRequest') + ->shouldReceive('send') ->once() ->andReturn($response);