Skip to content

Commit

Permalink
Add streaming response generation
Browse files Browse the repository at this point in the history
Closes #21
  • Loading branch information
erdemkose committed Jan 5, 2024
1 parent 3eafe66 commit cda893e
Show file tree
Hide file tree
Showing 6 changed files with 253 additions and 0 deletions.
29 changes: 29 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ _This library is not developed or endorsed by Google._
- [Multimodal input](#multimodal-input)
- [Chat Session (Multi-Turn Conversations)](#chat-session-multi-turn-conversations)
- [Chat Session with history](#chat-session-with-history)
- [Streaming responses](#streaming-responses)
- [Tokens counting](#tokens-counting)
- [Listing models](#listing-models)

Expand Down Expand Up @@ -150,6 +151,34 @@ func main() {
This code will print "Hello World!" to the standard output.
```

### Streaming responses

In the streaming response, the callback function will be called whenever a response is returned from the server.

Long responses may be broken into separate responses, and you can start receiving responses faster using a content stream.

```php
$client = new GeminiAPI\Client('GEMINI_API_KEY');

$callback = function (GenerateContentResponse $response): void {
static $count = 0;

print "\nResponse #{$count}\n";
print $response->text();
$count++;
};

$client->geminiPro()->generateContentStream(
$callback,
new TextPart('PHP in less than 100 chars')
);
// Response #0
// PHP: a versatile, general-purpose scripting language for web development, popular for
// Response #1
// its simple syntax and rich library of functions.
```


### Embed Content

```php
Expand Down
3 changes: 3 additions & 0 deletions composer.json
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@
"phpstan/phpstan": "^1.10.50",
"phpunit/phpunit": "^10.5"
},
"suggest": {
"ext-curl": "Required for streaming responses"
},
"autoload": {
"psr-4": {
"GeminiAPI\\": "src/"
Expand Down
43 changes: 43 additions & 0 deletions src/Client.php
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,15 @@

namespace GeminiAPI;

use BadMethodCallException;
use CurlHandle;
use GeminiAPI\ClientInterface as GeminiClientInterface;
use GeminiAPI\Enums\ModelName;
use GeminiAPI\Json\ObjectListParser;
use GeminiAPI\Requests\CountTokensRequest;
use GeminiAPI\Requests\EmbedContentRequest;
use GeminiAPI\Requests\GenerateContentRequest;
use GeminiAPI\Requests\GenerateContentStreamRequest;
use GeminiAPI\Requests\ListModelsRequest;
use GeminiAPI\Requests\RequestInterface;
use GeminiAPI\Responses\CountTokensResponse;
Expand All @@ -23,6 +27,11 @@
use Psr\Http\Message\StreamFactoryInterface;
use RuntimeException;

use function curl_close;
use function curl_exec;
use function curl_init;
use function curl_setopt;
use function extension_loaded;
use function json_decode;

class Client implements GeminiClientInterface
Expand Down Expand Up @@ -76,6 +85,40 @@ public function generateContent(GenerateContentRequest $request): GenerateConten
return GenerateContentResponse::fromArray($json);
}

/**
* @param callable(GenerateContentResponse): void $callback
* @throws BadMethodCallException
* @throws RuntimeException
*/
public function generateContentStream(
GenerateContentStreamRequest $request,
callable $callback,
): void {
if (!extension_loaded('curl')) {
throw new BadMethodCallException('Gemini API requires `curl` extension for streaming responses');
}

$parser = new ObjectListParser(
/* @phpstan-ignore-next-line */
static fn (array $arr) => $callback(GenerateContentResponse::fromArray($arr)),
);

$ch = curl_init("{$this->baseUrl}/v1/{$request->getOperation()}");
curl_setopt($ch, CURLOPT_POST, true);

Check failure on line 107 in src/Client.php

View workflow job for this annotation

GitHub Actions / Code Quality (8.1)

Parameter #1 $handle of function curl_setopt expects CurlHandle, CurlHandle|false given.
curl_setopt($ch, CURLOPT_POSTFIELDS, json_encode($request));

Check failure on line 108 in src/Client.php

View workflow job for this annotation

GitHub Actions / Code Quality (8.1)

Parameter #1 $handle of function curl_setopt expects CurlHandle, CurlHandle|false given.
curl_setopt($ch, CURLOPT_HTTPHEADER, [

Check failure on line 109 in src/Client.php

View workflow job for this annotation

GitHub Actions / Code Quality (8.1)

Parameter #1 $handle of function curl_setopt expects CurlHandle, CurlHandle|false given.
'Content-type: application/json',
self::API_KEY_HEADER_NAME . ": {$this->apiKey}",
]);
curl_setopt(
$ch,

Check failure on line 114 in src/Client.php

View workflow job for this annotation

GitHub Actions / Code Quality (8.1)

Parameter #1 $handle of function curl_setopt expects CurlHandle, CurlHandle|false given.
CURLOPT_WRITEFUNCTION,
static fn (CurlHandle $ch, string $str): int => $parser->consume($str),
);
curl_exec($ch);

Check failure on line 118 in src/Client.php

View workflow job for this annotation

GitHub Actions / Code Quality (8.1)

Parameter #1 $handle of function curl_exec expects CurlHandle, CurlHandle|false given.
curl_close($ch);

Check failure on line 119 in src/Client.php

View workflow job for this annotation

GitHub Actions / Code Quality (8.1)

Parameter #1 $handle of function curl_close expects CurlHandle, CurlHandle|false given.
}

/**
* @throws ClientExceptionInterface
*/
Expand Down
24 changes: 24 additions & 0 deletions src/GenerativeModel.php
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,12 @@

namespace GeminiAPI;

use BadMethodCallException;
use GeminiAPI\Enums\ModelName;
use GeminiAPI\Enums\Role;
use GeminiAPI\Requests\CountTokensRequest;
use GeminiAPI\Requests\GenerateContentRequest;
use GeminiAPI\Requests\GenerateContentStreamRequest;
use GeminiAPI\Responses\CountTokensResponse;
use GeminiAPI\Responses\GenerateContentResponse;
use GeminiAPI\Resources\Content;
Expand Down Expand Up @@ -58,6 +60,28 @@ public function generateContentWithContents(array $contents): GenerateContentRes
return $this->client->generateContent($request);
}

/**
* @param callable(GenerateContentResponse): void $callback
* @param PartInterface ...$parts
* @return void
* @throws BadMethodCallException
*/
public function generateContentStream(
callable $callback,
PartInterface ...$parts,
): void {
$content = new Content($parts, Role::User);

$request = new GenerateContentStreamRequest(
$this->modelName,
[$content],
$this->safetySettings,
$this->generationConfig,
);

$this->client->generateContentStream($request, $callback);
}

public function startChat(): ChatSession
{
return new ChatSession($this);
Expand Down
73 changes: 73 additions & 0 deletions src/Json/ObjectListParser.php
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
<?php

declare(strict_types=1);

namespace GeminiAPI\Json;

use RuntimeException;

class ObjectListParser
{
private int $depth = 0;
private bool $inString = false;
private bool $inEscape = false;
private string $json = '';

/** @var callable(array): void */
private $callback; // @phpstan-ignore-line

/**
* @phpstan-ignore-next-line
* @param callable(array): void $callback
*/
public function __construct(callable $callback)
{
$this->callback = $callback;
}

/**
* @param string $str
* @return int
* @throws RuntimeException
*/
public function consume(string $str): int
{
$offset = 0;
for ($i = 0; $i < strlen($str); $i++) {
if ($this->inEscape) {
$this->inEscape = false;
} elseif ($this->inString) {
if ($str[$i] === '\\') {
$this->inEscape = true;
} elseif ($str[$i] === '"') {
$this->inString = false;
}
} elseif ($str[$i] === '"') {
$this->inString = true;
} elseif ($str[$i] === '{') {
if ($this->depth === 0) {
$offset = $i;
}
$this->depth++;
} elseif ($str[$i] === '}') {
$this->depth--;
if ($this->depth === 0) {
$this->json .= substr($str, $offset, $i - $offset + 1);
$arr = json_decode($this->json, true);

if (json_last_error() !== JSON_ERROR_NONE) {
throw new RuntimeException('ObjectListParser could not decode the given message');
}

($this->callback)($arr);
$this->json = '';
$offset = $i + 1;
}
}
}

$this->json .= substr($str, $offset) ?: '';

return strlen($str);
}
}
81 changes: 81 additions & 0 deletions src/Requests/GenerateContentStreamRequest.php
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
<?php

declare(strict_types=1);

namespace GeminiAPI\Requests;

use GeminiAPI\Enums\ModelName;
use GeminiAPI\GenerationConfig;
use GeminiAPI\SafetySetting;
use GeminiAPI\Traits\ArrayTypeValidator;
use GeminiAPI\Resources\Content;
use JsonSerializable;

use function json_encode;

class GenerateContentStreamRequest implements JsonSerializable, RequestInterface
{
use ArrayTypeValidator;

/**
* @param ModelName $modelName
* @param Content[] $contents
* @param SafetySetting[] $safetySettings
* @param GenerationConfig|null $generationConfig
*/
public function __construct(
public readonly ModelName $modelName,
public readonly array $contents,
public readonly array $safetySettings = [],
public readonly ?GenerationConfig $generationConfig = null,
) {
$this->ensureArrayOfType($this->contents, Content::class);
$this->ensureArrayOfType($this->safetySettings, SafetySetting::class);
}

public function getOperation(): string
{
return "{$this->modelName->value}:streamGenerateContent";
}

public function getHttpMethod(): string
{
return 'POST';
}

public function getHttpPayload(): string
{
return (string) $this;
}

/**
* @return array{
* model: string,
* contents: Content[],
* safetySettings?: SafetySetting[],
* generationConfig?: GenerationConfig,
* }
*/
public function jsonSerialize(): array
{
$arr = [
'model' => $this->modelName->value,
'contents' => $this->contents,
];

if (!empty($this->safetySettings)) {
$arr['safetySettings'] = $this->safetySettings;
}

if ($this->generationConfig) {
$arr['generationConfig'] = $this->generationConfig;
}

return $arr;
}

public function __toString(): string
{
return json_encode($this) ?: '';
}
}

0 comments on commit cda893e

Please sign in to comment.