feat: Add tool calling support in ollama_dart (#504)

davidmigloz committed Jul 24, 2024
1 parent 30d05a6 commit 1ffdb41

Showing 18 changed files with 1,540 additions and 126 deletions.
@@ -101,7 +101,7 @@ extension ChatResultMapper on GenerateChatCompletionResponse {
     return ChatResult(
       id: id,
       output: AIChatMessage(
-        content: message?.content ?? '',
+        content: message.content,
       ),
       finishReason: _mapFinishReason(doneReason),
       metadata: {

99 changes: 93 additions & 6 deletions packages/ollama_dart/README.md
@@ -17,17 +17,24 @@ Unofficial Dart client for [Ollama](https://ollama.ai/) API.
 **Supported endpoints:**
 
 - Completions (with streaming support)
-- Chat completions
+- Chat completions (with streaming and tool calling support)
 - Embeddings
 - Models
 - Blobs
 - Version
 
 ## Table of contents
 
 - [Usage](#usage)
   * [Completions](#completions)
+    + [Generate completion](#generate-completion)
+    + [Stream completion](#stream-completion)
   * [Chat completions](#chat-completions)
+    + [Generate chat completion](#generate-chat-completion)
+    + [Stream chat completion](#stream-chat-completion)
+    + [Tool calling](#tool-calling)
   * [Embeddings](#embeddings)
+    + [Generate embedding](#generate-embedding)
   * [Models](#models)
     + [Create model](#create-model)
     + [List models](#list-models)
@@ -54,7 +61,7 @@ Refer to the [documentation](https://github.com/jmorganca/ollama/blob/main/docs/
 
 Given a prompt, the model will generate a response.
 
-**Generate completion:**
+#### Generate completion
 
 ```dart
 final generated = await client.generateCompletion(
@@ -67,7 +74,7 @@ print(generated.response);
 // The sky appears blue because of a phenomenon called Rayleigh scattering...
 ```
 
-**Stream completion:**
+#### Stream completion
 
 ```dart
 final stream = client.generateCompletionStream(
@@ -88,7 +95,7 @@ print(text);
 
 Given a prompt, the model will generate a response in a chat format.
 
-**Generate chat completion:**
+#### Generate chat completion
 
 ```dart
 final res = await client.generateChatCompletion(
@@ -111,7 +118,7 @@ print(res);
 // Message(role: MessageRole.assistant, content: 123456789)
 ```
 
-**Stream chat completion:**
+#### Stream chat completion
 
 ```dart
 final stream = client.generateChatCompletionStream(
@@ -139,11 +146,91 @@ print(text);
 // 123456789
 ```
 
+#### Tool calling
+
+Tool calling allows a model to respond to a given prompt by generating output that matches a user-defined schema, which you can then use to call the tools in your code and return the results to the model to complete the conversation.
+
+**Notes:**
+- Tool calling requires Ollama 0.2.8 or newer.
+- Streaming tool calls is not supported at the moment.
+- Not all models support tool calls. Check the Ollama catalogue for models that have the `Tools` tag (e.g. [`llama3.1`](https://ollama.com/library/llama3.1)).
+
+```dart
+const tool = Tool(
+  function: ToolFunction(
+    name: 'get_current_weather',
+    description: 'Get the current weather in a given location',
+    parameters: {
+      'type': 'object',
+      'properties': {
+        'location': {
+          'type': 'string',
+          'description': 'The city and country, e.g. San Francisco, US',
+        },
+        'unit': {
+          'type': 'string',
+          'description': 'The unit of temperature to return',
+          'enum': ['celsius', 'fahrenheit'],
+        },
+      },
+      'required': ['location'],
+    },
+  ),
+);
+const userMsg = Message(
+  role: MessageRole.user,
+  content: 'What’s the weather like in Barcelona in celsius?',
+);
+final res1 = await client.generateChatCompletion(
+  request: GenerateChatCompletionRequest(
+    model: 'llama3.1',
+    messages: [userMsg],
+    tools: [tool],
+  ),
+);
+print(res1.message.toolCalls);
+// [
+//   ToolCall(
+//     function:
+//       ToolCallFunction(
+//         name: get_current_weather,
+//         arguments: {
+//           location: Barcelona, ES,
+//           unit: celsius
+//         }
+//       )
+//   )
+// ]
+// Call your tool here. For this example, we'll just mock the response.
+const toolResult = '{"location": "Barcelona, ES", "temperature": 20, "unit": "celsius"}';
+// Submit the response of the tool call to the model
+final res2 = await client.generateChatCompletion(
+  request: GenerateChatCompletionRequest(
+    model: 'llama3.1',
+    messages: [
+      userMsg,
+      res1.message,
+      Message(
+        role: MessageRole.tool,
+        content: toolResult,
+      ),
+    ],
+  ),
+);
+print(res2.message.content);
+// The current weather in Barcelona is 20°C.
+```
+
 ### Embeddings
 
 Given a prompt, the model will generate an embedding representing the prompt.
 
-**Generate embedding:**
+#### Generate embedding
 
 ```dart
 final generated = await client.generateEmbedding(

80 changes: 77 additions & 3 deletions packages/ollama_dart/example/ollama_dart_example.dart
@@ -11,6 +11,7 @@ Future<void> main() async {
   await _generateChatCompletion(client);
   await _generateChatCompletionWithHistory(client);
   await _generateChatCompletionStream(client);
+  await _generateChatToolCalling(client);
 
   // Embeddings
   await _generateEmbedding(client);
@@ -86,7 +87,7 @@ Future<void> _generateChatCompletion(final OllamaClient client) async {
       ],
     ),
   );
-  print(generated.message?.content);
+  print(generated.message.content);
 }
 
 Future<void> _generateChatCompletionWithHistory(
@@ -111,7 +112,7 @@ Future<void> _generateChatCompletionWithHistory(
       ],
     ),
   );
-  print(generated.message?.content);
+  print(generated.message.content);
 }
 
 Future<void> _generateChatCompletionStream(final OllamaClient client) async {
@@ -132,11 +133,84 @@ Future<void> _generateChatCompletionStream(final OllamaClient client) async {
   );
   String text = '';
   await for (final res in stream) {
-    text += (res.message?.content ?? '').trim();
+    text += res.message.content.trim();
   }
   print(text);
 }
 
+Future<void> _generateChatToolCalling(final OllamaClient client) async {
+  const tool = Tool(
+    function: ToolFunction(
+      name: 'get_current_weather',
+      description: 'Get the current weather in a given location',
+      parameters: {
+        'type': 'object',
+        'properties': {
+          'location': {
+            'type': 'string',
+            'description': 'The city and country, e.g. San Francisco, US',
+          },
+          'unit': {
+            'type': 'string',
+            'description': 'The unit of temperature to return',
+            'enum': ['celsius', 'fahrenheit'],
+          },
+        },
+        'required': ['location'],
+      },
+    ),
+  );
+
+  const userMsg = Message(
+    role: MessageRole.user,
+    content: 'What’s the weather like in Barcelona in celsius?',
+  );
+
+  final res1 = await client.generateChatCompletion(
+    request: const GenerateChatCompletionRequest(
+      model: 'llama3.1',
+      messages: [userMsg],
+      tools: [tool],
+      keepAlive: 1,
+    ),
+  );
+
+  print(res1.message.toolCalls);
+  // [
+  //   ToolCall(
+  //     function:
+  //       ToolCallFunction(
+  //         name: get_current_weather,
+  //         arguments: {
+  //           location: Barcelona, ES,
+  //           unit: celsius
+  //         }
+  //       )
+  //   )
+  // ]
+
+  // Call your tool here. For this example, we'll just mock the response.
+  const toolResult =
+      '{"location": "Barcelona, ES", "temperature": 20, "unit": "celsius"}';
+
+  // Submit the response of the tool call to the model
+  final res2 = await client.generateChatCompletion(
+    request: GenerateChatCompletionRequest(
+      model: 'llama3.1',
+      messages: [
+        userMsg,
+        res1.message,
+        const Message(
+          role: MessageRole.tool,
+          content: toolResult,
+        ),
+      ],
+    ),
+  );
+  print(res2.message.content);
+  // The current weather in Barcelona is 20°C.
+}
+
 Future<void> _generateEmbedding(final OllamaClient client) async {
   final generated = await client.generateEmbedding(
     request: const GenerateEmbeddingRequest(

@@ -47,6 +47,9 @@ class GenerateChatCompletionRequest with _$GenerateChatCompletionRequest {
     /// - If set to 0, the model will be unloaded immediately once finished.
     /// - If not set, the model will stay loaded for 5 minutes by default
     @JsonKey(name: 'keep_alive', includeIfNull: false) int? keepAlive,
+
+    /// A list of tools the model may call.
+    @JsonKey(includeIfNull: false) List<Tool>? tools,
   }) = _GenerateChatCompletionRequest;
 
   /// Object construction from a JSON representation
@@ -60,7 +63,8 @@ class GenerateChatCompletionRequest with _$GenerateChatCompletionRequest {
     'format',
     'options',
     'stream',
-    'keep_alive'
+    'keep_alive',
+    'tools'
   ];
 
   /// Perform validations on the schema property values
@@ -77,6 +81,7 @@ class GenerateChatCompletionRequest with _$GenerateChatCompletionRequest {
       'options': options,
       'stream': stream,
       'keep_alive': keepAlive,
+      'tools': tools,
     };
   }
 }
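
For reference, a minimal usage sketch of the new `tools` parameter (not part of this commit; it assumes the generated classes shown above are exported by `package:ollama_dart`):

```dart
// Sketch only: assumes the generated classes in this diff are exported by
// package:ollama_dart. Because `tools` is declared with includeIfNull: false,
// it is simply omitted from the JSON payload when null.
import 'package:ollama_dart/ollama_dart.dart';

void main() {
  const request = GenerateChatCompletionRequest(
    model: 'llama3.1',
    messages: [
      Message(role: MessageRole.user, content: 'What is 2 + 2?'),
    ],
    tools: [
      Tool(
        function: ToolFunction(
          name: 'add', // hypothetical tool, for illustration only
          description: 'Add two integers',
          parameters: {
            'type': 'object',
            'properties': {
              'a': {'type': 'integer'},
              'b': {'type': 'integer'},
            },
            'required': ['a', 'b'],
          },
        ),
      ),
    ],
    keepAlive: 0, // unload the model immediately once finished
  );
  // The serialized request now carries a `tools` key.
  print(request.toJson());
}
```
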
@@ -16,18 +16,18 @@ class GenerateChatCompletionResponse with _$GenerateChatCompletionResponse {
   /// Factory constructor for GenerateChatCompletionResponse
   const factory GenerateChatCompletionResponse({
     /// A message in the chat endpoint
-    @JsonKey(includeIfNull: false) Message? message,
+    required Message message,
 
     /// The model name.
     ///
     /// Model names follow a `model:tag` format. Some examples are `orca-mini:3b-q4_1` and `llama3:70b`. The tag is optional and, if not provided, will default to `latest`. The tag is used to identify a specific version.
-    @JsonKey(includeIfNull: false) String? model,
+    required String model,
 
     /// Date on which a model was created.
-    @JsonKey(name: 'created_at', includeIfNull: false) String? createdAt,
+    @JsonKey(name: 'created_at') required String createdAt,
 
     /// Whether the response has completed.
-    @JsonKey(includeIfNull: false) bool? done,
+    required bool done,
 
     /// Reason why the model is done generating a response.
     @JsonKey(

14 changes: 13 additions & 1 deletion packages/ollama_dart/lib/src/generated/schema/message.dart
@@ -23,14 +23,23 @@ class Message with _$Message {
 
     /// (optional) a list of Base64-encoded images to include in the message (for multimodal models such as llava)
     @JsonKey(includeIfNull: false) List<String>? images,
+
+    /// A list of tools the model wants to call.
+    @JsonKey(name: 'tool_calls', includeIfNull: false)
+    List<ToolCall>? toolCalls,
   }) = _Message;
 
   /// Object construction from a JSON representation
   factory Message.fromJson(Map<String, dynamic> json) =>
       _$MessageFromJson(json);
 
   /// List of all property names of schema
-  static const List<String> propertyNames = ['role', 'content', 'images'];
+  static const List<String> propertyNames = [
+    'role',
+    'content',
+    'images',
+    'tool_calls'
+  ];
 
   /// Perform validations on the schema property values
   String? validateSchema() {
@@ -43,6 +52,7 @@ class Message with _$Message {
       'role': role,
       'content': content,
       'images': images,
+      'tool_calls': toolCalls,
     };
   }
 }
@@ -59,4 +69,6 @@ enum MessageRole {
   user,
   @JsonValue('assistant')
   assistant,
+  @JsonValue('tool')
+  tool,
 }
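
And a rough sketch of consuming the new `toolCalls` field and `tool` role (again not from this commit; the null-aware access assumes `ToolCall.function` follows the same optional-field pattern as the rest of this generated schema):

```dart
// Sketch under the assumption that ToolCall.function is optional, like the
// other optional fields in this generated schema.
import 'package:ollama_dart/ollama_dart.dart';

/// Builds a tool-role reply for each tool call the model requested.
List<Message> answerToolCalls(Message assistantMessage) {
  final replies = <Message>[];
  for (final call in assistantMessage.toolCalls ?? const <ToolCall>[]) {
    if (call.function?.name == 'get_current_weather') {
      final location = call.function?.arguments['location'];
      replies.add(
        Message(
          role: MessageRole.tool, // role added in this commit
          // Mocked result; a real client would call a weather API here.
          content: '{"location": "$location", "temperature": 20}',
        ),
      );
    }
  }
  return replies;
}
```
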
6 changes: 6 additions & 0 deletions packages/ollama_dart/lib/src/generated/schema/schema.dart
@@ -19,6 +19,12 @@ part 'generate_chat_completion_request.dart';
 part 'generate_chat_completion_response.dart';
 part 'done_reason.dart';
 part 'message.dart';
+part 'tool.dart';
+part 'tool_function.dart';
+part 'tool_function_params.dart';
+part 'tool_call.dart';
+part 'tool_call_function.dart';
+part 'tool_call_function_args.dart';
 part 'generate_embedding_request.dart';
 part 'generate_embedding_response.dart';
 part 'create_model_request.dart';