Skip to content

Commit

Permalink
Allow custom streaming client
Browse files Browse the repository at this point in the history
  • Loading branch information
erdemkose committed Jan 5, 2024
1 parent 055a60e commit 6c60090
Show file tree
Hide file tree
Showing 3 changed files with 135 additions and 13 deletions.
82 changes: 81 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,10 @@ _This library is not developed or endorsed by Google._
- [Streaming Chat Session](#streaming-chat-session)
- [Tokens counting](#tokens-counting)
- [Listing models](#listing-models)
- [Advanced Usages](#advanced-usages)
- [Safety Settings and Generation Configuration](#safety-settings-and-generation-configuration)
- [Using your own HTTP client](#using-your-own-http-client)
- [Using your own HTTP client for streaming responses](#using-your-own-http-client-for-streaming-responses)

## Installation

Expand Down Expand Up @@ -173,7 +177,7 @@ $callback = function (GenerateContentResponse $response): void {

$client->geminiPro()->generateContentStream(
$callback,
new TextPart('PHP in less than 100 chars')
[new TextPart('PHP in less than 100 chars')],
);
// Response #0
// PHP: a versatile, general-purpose scripting language for web development, popular for
Expand Down Expand Up @@ -287,3 +291,79 @@ print_r($response->models);
// )
//]
```

### Advanced Usages

#### Safety Settings and Generation Configuration

```php
$client = new GeminiAPI\Client('GEMINI_API_KEY');
$safetySetting = new GeminiAPI\SafetySetting(
HarmCategory::HARM_CATEGORY_HATE_SPEECH,
HarmBlockThreshold::BLOCK_LOW_AND_ABOVE,
);
$generationConfig = (new GeminiAPI\GenerationConfig())
->withCandidateCount(1)
->withMaxOutputTokens(40)
->withTemperature(0.5)
->withTopK(40)
->withTopP(0.6)
->withStopSequences(['STOP']);

$response = $client->geminiPro()
->withAddedSafetySetting($safetySetting)
->withGenerationConfig($generationConfig)
->generateContent(
new TextPart('PHP in less than 100 chars')
);
```

#### Using your own HTTP client

```php
$guzzle = new GuzzleHttp\Client([
'proxy' => 'http://localhost:8125',
]);
$client = new GeminiAPI\Client('GEMINI_API_KEY', $guzzle);

$response = $client->geminiPro()->generateContent(
new TextPart('PHP in less than 100 chars')
);
```

#### Using your own HTTP client for streaming responses

> Requires `curl` extension to be enabled
Since streaming responses are fetched using `curl` extension, they cannot use the custom HTTP client passed to the Gemini Client.
You need to pass a `CurlHandler` if you want to override connection options.

The following curl options will be overwritten by the Gemini Client.

- `CURLOPT_URL`
- `CURLOPT_POST`
- `CURLOPT_POSTFIELDS`
- `CURLOPT_WRITEFUNCTION`

You can also pass the headers you want to be used in the requests.

```php
$client = new GeminiAPI\Client('GEMINI_API_KEY');

$callback = function (GenerateContentResponse $response): void {
print $response->text();
};

$ch = curl_init();
curl_setopt($ch, CURLOPT_PROXY, 'http://localhost:8125');

$client->withRequestHeaders([
'User-Agent' => 'My Gemini-backed app'
])
->geminiPro()
->generateContentStream(
$callback,
[new TextPart('PHP in less than 100 chars')],
$ch,
);
```
44 changes: 39 additions & 5 deletions src/Client.php
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
use Psr\Http\Message\StreamFactoryInterface;
use RuntimeException;

use function array_map;
use function curl_close;
use function curl_exec;
use function curl_getinfo;
Expand All @@ -35,10 +36,17 @@
use function extension_loaded;
use function json_decode;
use function sprintf;
use function strtolower;

class Client implements GeminiClientInterface
{
private string $baseUrl = 'https://generativelanguage.googleapis.com';

/**
* @var array<string, string>
*/
private array $requestHeaders = [];

public function __construct(
private readonly string $apiKey,
private ?HttpClientInterface $client = null,
Expand Down Expand Up @@ -88,13 +96,16 @@ public function generateContent(GenerateContentRequest $request): GenerateConten
}

/**
* @param GenerateContentStreamRequest $request
* @param callable(GenerateContentResponse): void $callback
* @param CurlHandle|null $curl
* @throws BadMethodCallException
* @throws RuntimeException
*/
public function generateContentStream(
GenerateContentStreamRequest $request,
callable $callback,
?CurlHandle $curl = null,
): void {
if (!extension_loaded('curl')) {
throw new BadMethodCallException('Gemini API requires `curl` extension for streaming responses');
Expand All @@ -120,18 +131,25 @@ public function generateContentStream(
);
};

$ch = curl_init("{$this->baseUrl}/v1/{$request->getOperation()}");
$ch = $curl ?? curl_init();

if ($ch === false) {
throw new RuntimeException('Gemini API cannot initialize streaming content request');
}

$headers = $this->requestHeaders + [
'content-type' => 'application/json',
self::API_KEY_HEADER_NAME => $this->apiKey,
];
$headerLines = [];
foreach ($headers as $name => $value) {
$headerLines[] = "{$name}: {$value}";
}

curl_setopt($ch, CURLOPT_URL, "{$this->baseUrl}/v1/{$request->getOperation()}");
curl_setopt($ch, CURLOPT_POST, true);
curl_setopt($ch, CURLOPT_POSTFIELDS, json_encode($request));
curl_setopt($ch, CURLOPT_HTTPHEADER, [
'Content-type: application/json',
self::API_KEY_HEADER_NAME . ": {$this->apiKey}",
]);
curl_setopt($ch, CURLOPT_HTTPHEADER, $headerLines);
curl_setopt($ch, CURLOPT_WRITEFUNCTION, $writeFunction);
curl_exec($ch);
curl_close($ch);
Expand Down Expand Up @@ -179,6 +197,18 @@ public function withBaseUrl(string $baseUrl): self
return $clone;
}

/**
* @param array<string, string> $headers
* @return self
*/
public function withRequestHeaders(array $headers): self
{
$clone = clone $this;
$clone->requestHeaders = array_map(strtolower(...), $headers);

return $clone;
}

/**
* @throws ClientExceptionInterface
*/
Expand All @@ -193,6 +223,10 @@ private function doRequest(RequestInterface $request): string
->createRequest($request->getHttpMethod(), $uri)
->withAddedHeader(self::API_KEY_HEADER_NAME, $this->apiKey);

foreach ($this->requestHeaders as $name => $value) {
$httpRequest = $httpRequest->withAddedHeader($name, $value);
}

$payload = $request->getHttpPayload();
if (!empty($payload)) {
$stream = $this->streamFactory->createStream($payload);
Expand Down
22 changes: 15 additions & 7 deletions src/GenerativeModel.php
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
namespace GeminiAPI;

use BadMethodCallException;
use CurlHandle;
use GeminiAPI\Enums\ModelName;
use GeminiAPI\Enums\Role;
use GeminiAPI\Requests\CountTokensRequest;
Expand Down Expand Up @@ -62,26 +63,33 @@ public function generateContentWithContents(array $contents): GenerateContentRes

/**
* @param callable(GenerateContentResponse): void $callback
* @param PartInterface ...$parts
* @param PartInterface[] $parts
* @param CurlHandle|null $ch
* @return void
* @throws BadMethodCallException
*/
public function generateContentStream(
callable $callback,
PartInterface ...$parts,
array $parts,
?CurlHandle $ch = null,
): void {
$this->ensureArrayOfType($parts, PartInterface::class);

$content = new Content($parts, Role::User);

$this->generateContentStreamWithContents($callback, [$content]);
$this->generateContentStreamWithContents($callback, [$content], $ch);
}

/**
* @param callable(GenerateContentResponse): void $callback
* @param Content[] $contents
* @param CurlHandle|null $ch
* @return void
*/
public function generateContentStreamWithContents(callable $callback, array $contents): void
{
public function generateContentStreamWithContents(
callable $callback,
array $contents,
?CurlHandle $ch = null,
): void {
$this->ensureArrayOfType($contents, Content::class);

$request = new GenerateContentStreamRequest(
Expand All @@ -91,7 +99,7 @@ public function generateContentStreamWithContents(callable $callback, array $con
$this->generationConfig,
);

$this->client->generateContentStream($request, $callback);
$this->client->generateContentStream($request, $callback, $ch);
}

public function startChat(): ChatSession
Expand Down

0 comments on commit 6c60090

Please sign in to comment.