Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Streaming response generation #23

Merged
merged 4 commits into from
Jan 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 81 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ _This library is not developed or endorsed by Google._
- [Multimodal input](#multimodal-input)
- [Chat Session (Multi-Turn Conversations)](#chat-session-multi-turn-conversations)
- [Chat Session with history](#chat-session-with-history)
- [Streaming responses](#streaming-responses)
- [Streaming Chat Session](#streaming-chat-session)
- [Tokens counting](#tokens-counting)
- [Listing models](#listing-models)

Expand Down Expand Up @@ -150,6 +152,85 @@ func main() {
This code will print "Hello World!" to the standard output.
```

### Streaming responses

> Requires `curl` extension to be enabled

In the streaming response, the callback function will be called whenever a response is returned from the server.

Long responses may be broken into separate responses, and you can start receiving responses faster using a content stream.

```php
$client = new GeminiAPI\Client('GEMINI_API_KEY');

$callback = function (GenerateContentResponse $response): void {
static $count = 0;

print "\nResponse #{$count}\n";
print $response->text();
$count++;
};

$client->geminiPro()->generateContentStream(
$callback,
new TextPart('PHP in less than 100 chars')
);
// Response #0
// PHP: a versatile, general-purpose scripting language for web development, popular for
// Response #1
// its simple syntax and rich library of functions.
```

### Streaming Chat Session

> Requires `curl` extension to be enabled

```php
$client = new GeminiAPI\Client('GEMINI_API_KEY');

$history = [
Content::text('Hello World in PHP', Role::User),
Content::text(
<<<TEXT
<?php
echo "Hello World!";
?>

This code will print "Hello World!" to the standard output.
TEXT,
Role::Model,
),
];
$chat = $client->geminiPro()
->startChat()
->withHistory($history);

$callback = function (GenerateContentResponse $response): void {
static $count = 0;

print "\nResponse #{$count}\n";
print $response->text();
$count++;
};

$chat->sendMessageStream($callback, new TextPart('in Go'));
```

```text
Response #0
package main

import "fmt"

func main() {

Response #1
fmt.Println("Hello World!")
}

This code will print "Hello World!" to the standard output.
```

### Embed Content

```php
Expand Down
3 changes: 3 additions & 0 deletions composer.json
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@
"phpstan/phpstan": "^1.10.50",
"phpunit/phpunit": "^10.5"
},
"suggest": {
"ext-curl": "Required for streaming responses"
},
"autoload": {
"psr-4": {
"GeminiAPI\\": "src/"
Expand Down
31 changes: 31 additions & 0 deletions src/ChatSession.php
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,37 @@ public function sendMessage(PartInterface ...$parts): GenerateContentResponse
return $response;
}

/**
* @param callable(GenerateContentResponse): void $callback
* @param PartInterface ...$parts
* @return void
*/
public function sendMessageStream(
callable $callback,
PartInterface ...$parts,
): void {
$this->history[] = new Content($parts, Role::User);

$parts = [];
$partsCollectorCallback = function (GenerateContentResponse $response) use ($callback, &$parts) {
if(!empty($response->candidates)) {
array_push($parts, ...$response->parts());
}

$callback($response);
};

$config = (new GenerationConfig())
->withCandidateCount(1);
$this->model
->withGenerationConfig($config)
->generateContentStreamWithContents($partsCollectorCallback, $this->history);

if (!empty($parts)) {
$this->history[] = new Content($parts, Role::Model);
}
}

/**
* @return Content[]
*/
Expand Down
61 changes: 61 additions & 0 deletions src/Client.php
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,15 @@

namespace GeminiAPI;

use BadMethodCallException;
use CurlHandle;
use GeminiAPI\ClientInterface as GeminiClientInterface;
use GeminiAPI\Enums\ModelName;
use GeminiAPI\Json\ObjectListParser;
use GeminiAPI\Requests\CountTokensRequest;
use GeminiAPI\Requests\EmbedContentRequest;
use GeminiAPI\Requests\GenerateContentRequest;
use GeminiAPI\Requests\GenerateContentStreamRequest;
use GeminiAPI\Requests\ListModelsRequest;
use GeminiAPI\Requests\RequestInterface;
use GeminiAPI\Responses\CountTokensResponse;
Expand All @@ -23,7 +27,14 @@
use Psr\Http\Message\StreamFactoryInterface;
use RuntimeException;

use function curl_close;
use function curl_exec;
use function curl_getinfo;
use function curl_init;
use function curl_setopt;
use function extension_loaded;
use function json_decode;
use function sprintf;

class Client implements GeminiClientInterface
{
Expand Down Expand Up @@ -76,6 +87,56 @@ public function generateContent(GenerateContentRequest $request): GenerateConten
return GenerateContentResponse::fromArray($json);
}

/**
* @param callable(GenerateContentResponse): void $callback
* @throws BadMethodCallException
* @throws RuntimeException
*/
public function generateContentStream(
GenerateContentStreamRequest $request,
callable $callback,
): void {
if (!extension_loaded('curl')) {
throw new BadMethodCallException('Gemini API requires `curl` extension for streaming responses');
}

$parser = new ObjectListParser(
/* @phpstan-ignore-next-line */
static fn (array $arr) => $callback(GenerateContentResponse::fromArray($arr)),
);

$writeFunction = static function (CurlHandle $ch, string $str) use ($request, $parser): int {
$responseCode = curl_getinfo($ch, CURLINFO_RESPONSE_CODE);

return $responseCode === 200
? $parser->consume($str)
: throw new RuntimeException(
sprintf(
'Gemini API operation failed: operation=%s, status_code=%d, response=%s',
$request->getOperation(),
$responseCode,
$str,
),
);
};

$ch = curl_init("{$this->baseUrl}/v1/{$request->getOperation()}");

if ($ch === false) {
throw new RuntimeException('Gemini API cannot initialize streaming content request');
}

curl_setopt($ch, CURLOPT_POST, true);
curl_setopt($ch, CURLOPT_POSTFIELDS, json_encode($request));
curl_setopt($ch, CURLOPT_HTTPHEADER, [
'Content-type: application/json',
self::API_KEY_HEADER_NAME . ": {$this->apiKey}",
]);
curl_setopt($ch, CURLOPT_WRITEFUNCTION, $writeFunction);
curl_exec($ch);
curl_close($ch);
}

/**
* @throws ClientExceptionInterface
*/
Expand Down
36 changes: 36 additions & 0 deletions src/GenerativeModel.php
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,12 @@

namespace GeminiAPI;

use BadMethodCallException;
use GeminiAPI\Enums\ModelName;
use GeminiAPI\Enums\Role;
use GeminiAPI\Requests\CountTokensRequest;
use GeminiAPI\Requests\GenerateContentRequest;
use GeminiAPI\Requests\GenerateContentStreamRequest;
use GeminiAPI\Responses\CountTokensResponse;
use GeminiAPI\Responses\GenerateContentResponse;
use GeminiAPI\Resources\Content;
Expand Down Expand Up @@ -58,6 +60,40 @@ public function generateContentWithContents(array $contents): GenerateContentRes
return $this->client->generateContent($request);
}

/**
* @param callable(GenerateContentResponse): void $callback
* @param PartInterface ...$parts
* @return void
* @throws BadMethodCallException
*/
public function generateContentStream(
callable $callback,
PartInterface ...$parts,
): void {
$content = new Content($parts, Role::User);

$this->generateContentStreamWithContents($callback, [$content]);
}

/**
* @param callable(GenerateContentResponse): void $callback
* @param Content[] $contents
* @return void
*/
public function generateContentStreamWithContents(callable $callback, array $contents): void
{
$this->ensureArrayOfType($contents, Content::class);

$request = new GenerateContentStreamRequest(
$this->modelName,
$contents,
$this->safetySettings,
$this->generationConfig,
);

$this->client->generateContentStream($request, $callback);
}

public function startChat(): ChatSession
{
return new ChatSession($this);
Expand Down
73 changes: 73 additions & 0 deletions src/Json/ObjectListParser.php
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
<?php

declare(strict_types=1);

namespace GeminiAPI\Json;

use RuntimeException;

class ObjectListParser
{
private int $depth = 0;
private bool $inString = false;
private bool $inEscape = false;
private string $json = '';

/** @var callable(array): void */
private $callback; // @phpstan-ignore-line

/**
* @phpstan-ignore-next-line
* @param callable(array): void $callback
*/
public function __construct(callable $callback)
{
$this->callback = $callback;
}

/**
* @param string $str
* @return int
* @throws RuntimeException
*/
public function consume(string $str): int
{
$offset = 0;
for ($i = 0; $i < strlen($str); $i++) {
if ($this->inEscape) {
$this->inEscape = false;
} elseif ($this->inString) {
if ($str[$i] === '\\') {
$this->inEscape = true;
} elseif ($str[$i] === '"') {
$this->inString = false;
}
} elseif ($str[$i] === '"') {
$this->inString = true;
} elseif ($str[$i] === '{') {
if ($this->depth === 0) {
$offset = $i;
}
$this->depth++;
} elseif ($str[$i] === '}') {
$this->depth--;
if ($this->depth === 0) {
$this->json .= substr($str, $offset, $i - $offset + 1);
$arr = json_decode($this->json, true);

if (json_last_error() !== JSON_ERROR_NONE) {
throw new RuntimeException('ObjectListParser could not decode the given message');
}

($this->callback)($arr);
$this->json = '';
$offset = $i + 1;
}
}
}

$this->json .= substr($str, $offset) ?: '';

return strlen($str);
}
}
Loading
Loading