44 changes: 35 additions & 9 deletions docs/plugins/ollama.md
@@ -21,21 +21,22 @@ example:
ollama pull gemma
```

To use this plugin, specify it when you call `configureGenkit()`.
To use this plugin, specify it when you call `genkit()`:

```js
```typescript
import { genkit } from 'genkit';
import { ollama } from 'genkitx-ollama';

export default configureGenkit({
const ai = genkit({
  plugins: [
    ollama({
      models: [
        {
          name: 'gemma',
          type: 'generate', // type: 'chat' | 'generate' | undefined
          type: 'generate', // Options: 'chat' | 'generate' | undefined (defaults to 'chat')
        },
      ],
      serverAddress: 'http://127.0.0.1:11434', // default local address
      serverAddress: 'http://127.0.0.1:11434', // default local Ollama server address
    }),
  ],
});
@@ -64,7 +65,7 @@ the Google Auth library:
```js
import { GoogleAuth } from 'google-auth-library';
import { ollama, OllamaPluginParams } from 'genkitx-ollama';
import { configureGenkit, isDevEnv } from '@genkit-ai/core';
import { genkit } from 'genkit';
import { isDevEnv } from '@genkit-ai/core';

const ollamaCommon = { models: [{ name: 'gemma:2b' }] };

@@ -82,7 +83,7 @@ const ollamaProd = {
},
} as OllamaPluginParams;

export default configureGenkit({
const ai = genkit({
  plugins: [
    ollama(isDevEnv() ? ollamaDev : ollamaProd),
  ],
@@ -117,8 +118,33 @@ This plugin doesn't statically export model references. Specify one of the
models you configured using a string identifier:

```js
const llmResponse = await generate({
  model: 'ollama/gemma',
const llmResponse = await ai.generate({
  model: 'ollama/gemma:2b',
  prompt: 'Tell me a joke.',
});
```

## Embedders

The Ollama plugin supports embeddings, which can be used for similarity searches and other NLP tasks.

```typescript
const ai = genkit({
  plugins: [
    ollama({
      serverAddress: 'http://localhost:11434',
      embedders: [{ name: 'nomic-embed-text', dimensions: 768 }],
    }),
  ],
});

async function getEmbedding() {
  const embedding = await ai.embed({
    embedder: 'ollama/nomic-embed-text',
    content: 'Some text to embed!',
  });

  return embedding;
}

getEmbedding().then((e) => console.log(e));
```
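As a rough illustration of the similarity-search use case mentioned above, the sketch below ranks candidate texts against a query by cosine similarity. It assumes the `ai` instance configured above and that `ai.embed` resolves to a plain numeric vector, as the previous example implies; `cosineSimilarity` and `rankBySimilarity` are illustrative helpers, not part of the plugin:

```typescript
// Sketch only: ranks candidate texts by cosine similarity to a query.
// Assumes the `ai` instance from the configuration above, and that
// `ai.embed` resolves to a number[] vector for this embedder.
function cosineSimilarity(a: number[], b: number[]): number {
  let dot = 0;
  let normA = 0;
  let normB = 0;
  for (let i = 0; i < a.length; i++) {
    dot += a[i] * b[i];
    normA += a[i] * a[i];
    normB += b[i] * b[i];
  }
  return dot / (Math.sqrt(normA) * Math.sqrt(normB));
}

async function rankBySimilarity(query: string, candidates: string[]) {
  const queryVec = await ai.embed({
    embedder: 'ollama/nomic-embed-text',
    content: query,
  });
  const scored = await Promise.all(
    candidates.map(async (text) => {
      const vec = await ai.embed({
        embedder: 'ollama/nomic-embed-text',
        content: text,
      });
      return { text, score: cosineSimilarity(queryVec, vec) };
    })
  );
  // Highest similarity first.
  return scored.sort((a, b) => b.score - a.score);
}
```

Note that the query and the candidates must be embedded with the same model for the cosine scores to be meaningful.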
1 change: 1 addition & 0 deletions js/plugins/ollama/package.json
@@ -34,6 +34,7 @@
"devDependencies": {
"@types/node": "^20.11.16",
"npm-run-all": "^4.1.5",
"ollama": "^0.5.9",
"rimraf": "^6.0.1",
"tsup": "^8.0.2",
"tsx": "^4.7.0",
120 changes: 74 additions & 46 deletions js/plugins/ollama/src/embeddings.ts
@@ -13,19 +13,50 @@
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
import { Genkit } from 'genkit';
import { logger } from 'genkit/logging';
import { OllamaPluginParams } from './index.js';
import { Document, Genkit } from 'genkit';
import { EmbedRequest, EmbedResponse } from 'ollama';
import { DefineOllamaEmbeddingParams, RequestHeaders } from './types.js';

interface OllamaEmbeddingPrediction {
  embedding: number[];
}
async function toOllamaEmbedRequest(
  modelName: string,
  dimensions: number,
  documents: Document[],
  serverAddress: string,
  requestHeaders?: RequestHeaders
): Promise<{
  url: string;
  requestPayload: EmbedRequest;
  headers: Record<string, string>;
}> {
  const requestPayload: EmbedRequest = {
    model: modelName,
    input: documents.map((doc) => doc.text),
  };

  // Resolve headers: requestHeaders may be a static record or an async function.
  const extraHeaders = requestHeaders
    ? typeof requestHeaders === 'function'
      ? await requestHeaders({
          serverAddress,
          model: {
            name: modelName,
            dimensions,
          },
          embedRequest: requestPayload,
        })
      : requestHeaders
    : {};

interface DefineOllamaEmbeddingParams {
  name: string;
  modelName: string;
  dimensions: number;
  options: OllamaPluginParams;
  const headers = {
    'Content-Type': 'application/json',
    ...extraHeaders, // Add any dynamic headers
  };

  return {
    url: `${serverAddress}/api/embed`,
    requestPayload,
    headers,
  };
}

export function defineOllamaEmbedder(
@@ -34,50 +65,47 @@ export function defineOllamaEmbedder(
) {
  return ai.defineEmbedder(
    {
      name,
      name: `ollama/${name}`,
      info: {
        label: 'Ollama Embedding - ' + modelName,
        label: 'Ollama Embedding - ' + name,
        dimensions,
        supports: {
          // TODO: do any ollama models support other modalities?
          input: ['text'],
        },
      },
    },
    async (input) => {
      const serverAddress = options.serverAddress;
      const responses = await Promise.all(
        input.map(async (i) => {
          const requestPayload = {
            model: modelName,
            prompt: i.text,
          };
          let res: Response;
          try {
            res = await fetch(`${serverAddress}/api/embeddings`, {
              method: 'POST',
              headers: {
                'Content-Type': 'application/json',
              },
              body: JSON.stringify(requestPayload),
            });
          } catch (e) {
            logger.error('Failed to fetch Ollama embedding');
            throw new Error(`Error fetching embedding from Ollama: ${e}`);
          }
          if (!res.ok) {
            logger.error('Failed to fetch Ollama embedding');
            throw new Error(
              `Error fetching embedding from Ollama: ${res.statusText}`
            );
          }
          const responseData = (await res.json()) as OllamaEmbeddingPrediction;
          return responseData;
        })
    async (input, config) => {
      const serverAddress = config?.serverAddress || options.serverAddress;

      const { url, requestPayload, headers } = await toOllamaEmbedRequest(
        modelName,
        dimensions,
        input,
        serverAddress,
        options.requestHeaders
      );
      return {
        embeddings: responses,
      };

      const response: Response = await fetch(url, {
        method: 'POST',
        headers,
        body: JSON.stringify(requestPayload),
      });

      if (!response.ok) {
        throw new Error(
          `Error fetching embedding from Ollama: ${response.statusText}`
        );
      }

      const payload: EmbedResponse = await response.json();

      const embeddings: { embedding: number[] }[] = [];

      for (const embedding of payload.embeddings) {
        embeddings.push({ embedding });
      }
      return { embeddings };
    }
  );
}
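For reference, here is a hedged sketch of how the refactored embedder could be wired up directly with dynamic request headers, matching the shape `toOllamaEmbedRequest` consumes above. It assumes `DefineOllamaEmbeddingParams` and `RequestHeaders` in `types.ts` match what this diff uses; `fetchAuthToken` is a hypothetical stand-in for a real credential lookup (such as the Google Auth flow shown in the docs):

```typescript
import { genkit } from 'genkit';
import { defineOllamaEmbedder } from 'genkitx-ollama';

const ai = genkit({ plugins: [] });

// Hypothetical helper for this sketch; substitute real auth as needed.
async function fetchAuthToken(serverAddress: string): Promise<string> {
  return process.env.OLLAMA_TOKEN ?? '';
}

// Sketch only: the requestHeaders function receives the server address,
// model info, and outgoing EmbedRequest, per toOllamaEmbedRequest above.
defineOllamaEmbedder(ai, {
  name: 'nomic-embed-text',
  modelName: 'nomic-embed-text',
  dimensions: 768,
  options: {
    serverAddress: 'http://127.0.0.1:11434',
    models: [],
    requestHeaders: async ({ serverAddress }) => ({
      Authorization: `Bearer ${await fetchAuthToken(serverAddress)}`,
    }),
  },
});
```

With the naming scheme in this diff, the embedder registers as `ollama/nomic-embed-text`.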
75 changes: 29 additions & 46 deletions js/plugins/ollama/src/index.ts
@@ -25,40 +25,24 @@ import {
} from 'genkit/model';
import { GenkitPlugin, genkitPlugin } from 'genkit/plugin';
import { defineOllamaEmbedder } from './embeddings';
import {
  ApiType,
  ModelDefinition,
  OllamaPluginParams,
  RequestHeaders,
} from './types';

type ApiType = 'chat' | 'generate';

type RequestHeaders =
  | Record<string, string>
  | ((
      params: { serverAddress: string; model: ModelDefinition },
      request: GenerateRequest
    ) => Promise<Record<string, string> | void>);

type ModelDefinition = { name: string; type?: ApiType };
type EmbeddingModelDefinition = { name: string; dimensions: number };

export interface OllamaPluginParams {
  models: ModelDefinition[];
  embeddingModels?: EmbeddingModelDefinition[];

  /**
   * ollama server address.
   */
  serverAddress: string;

  requestHeaders?: RequestHeaders;
}
export { defineOllamaEmbedder };

export function ollama(params: OllamaPluginParams): GenkitPlugin {
  return genkitPlugin('ollama', async (ai: Genkit) => {
    const serverAddress = params?.serverAddress;
    params.models.map((model) =>
    const serverAddress = params.serverAddress;
    params.models?.map((model) =>
      ollamaModel(ai, model, serverAddress, params.requestHeaders)
    );
    params.embeddingModels?.map((model) =>
    params.embedders?.map((model) =>
      defineOllamaEmbedder(ai, {
        name: `${ollama}/model.name`,
        name: model.name,
        modelName: model.name,
        dimensions: model.dimensions,
        options: params,
@@ -85,20 +69,20 @@ function ollamaModel(
    },
    async (input, streamingCallback) => {
      const options: Record<string, any> = {};
      if (input.config?.hasOwnProperty('temperature')) {
        options.temperature = input.config?.temperature;
      if (input.config?.temperature !== undefined) {
        options.temperature = input.config.temperature;
      }
      if (input.config?.hasOwnProperty('topP')) {
        options.top_p = input.config?.topP;
      if (input.config?.topP !== undefined) {
        options.top_p = input.config.topP;
      }
      if (input.config?.hasOwnProperty('topK')) {
        options.top_k = input.config?.topK;
      if (input.config?.topK !== undefined) {
        options.top_k = input.config.topK;
      }
      if (input.config?.hasOwnProperty('stopSequences')) {
        options.stop = input.config?.stopSequences?.join('');
      if (input.config?.stopSequences !== undefined) {
        options.stop = input.config.stopSequences.join('');
      }
      if (input.config?.hasOwnProperty('maxOutputTokens')) {
        options.num_predict = input.config?.maxOutputTokens;
      if (input.config?.maxOutputTokens !== undefined) {
        options.num_predict = input.config.maxOutputTokens;
      }
      const type = model.type ?? 'chat';
      const request = toOllamaRequest(
@@ -137,13 +121,12 @@
        );
      } catch (e) {
        const cause = (e as any).cause;
        if (cause) {
          if (
            cause instanceof Error &&
            cause.message?.includes('ECONNREFUSED')
          ) {
            cause.message += '. Make sure ollama server is running.';
          }
        if (
          cause &&
          cause instanceof Error &&
          cause.message?.includes('ECONNREFUSED')
        ) {
          cause.message += '. Make sure the Ollama server is running.';
          throw cause;
        }
        throw e;
@@ -225,11 +208,11 @@ function toOllamaRequest(
  type: ApiType,
  stream: boolean
) {
  const request = {
  const request: any = {
    model: name,
    options,
    stream,
  } as any;
  };
  if (type === 'chat') {
    const messages: Message[] = [];
    input.messages.forEach((m) => {
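The tail of `toOllamaRequest` is truncated in this diff. For orientation, here is a hedged sketch of the payload shape the function assembles for a chat-type model, based on the fields visible above (`model`, `options`, `stream`, and the `messages` array) and the shape of the Ollama `/api/chat` API; the exact message mapping is not shown here:

```typescript
// Sketch only: approximate request built when type === 'chat'.
// Field names follow the Ollama /api/chat API and the fields
// visible in this diff; the message mapping itself is truncated above.
const exampleChatRequest = {
  model: 'gemma',
  stream: false,
  options: { temperature: 0.7 },
  messages: [{ role: 'user', content: 'Tell me a joke.' }],
};
```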