feat: Add support for done reason in ollama_dart (#413)
davidmigloz committed May 11, 2024
1 parent 97acab4 commit cc5b1b0
Showing 8 changed files with 115 additions and 32 deletions.
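The net effect is easiest to see from the consumer side: `done_reason` changes from a raw string to a typed `DoneReason` enum. A minimal sketch of reading the field after this commit (assuming a local Ollama server on the default port; the model name and prompt are illustrative):

```dart
import 'package:ollama_dart/ollama_dart.dart';

Future<void> main() async {
  final client = OllamaClient(); // defaults to http://localhost:11434/api

  final res = await client.generateChatCompletion(
    request: const GenerateChatCompletionRequest(
      model: 'mistral:latest', // illustrative model name
      messages: [
        Message(role: MessageRole.user, content: 'Why is the sky blue?'),
      ],
    ),
  );

  // Previously a String?; now a nullable enum.
  print(res.doneReason); // e.g. DoneReason.stop
  client.endSession();
}
```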
14 changes: 7 additions & 7 deletions packages/ollama_dart/example/ollama_dart_example.dart
@@ -150,7 +150,7 @@ Future<void> _generateEmbedding(final OllamaClient client) async {
Future<void> _createModel(final OllamaClient client) async {
final res = await client.createModel(
request: const CreateModelRequest(
- name: 'mario',
+ model: 'mario',
modelfile:
'FROM mistral:latest\nSYSTEM You are mario from Super Mario Bros.',
),
@@ -161,7 +161,7 @@ Future<void> _createModel(final OllamaClient client) async {
Future<void> _createModelStream(final OllamaClient client) async {
final stream = client.createModelStream(
request: const CreateModelRequest(
- name: 'mario',
+ model: 'mario',
modelfile:
'FROM mistral:latest\nSYSTEM You are mario from Super Mario Bros.',
),
@@ -178,21 +178,21 @@ Future<void> _listModels(final OllamaClient client) async {

Future<void> _showModelInfo(final OllamaClient client) async {
final res = await client.showModelInfo(
- request: const ModelInfoRequest(name: 'mistral:latest'),
+ request: const ModelInfoRequest(model: 'mistral:latest'),
);
print(res);
}

Future<void> _pullModel(final OllamaClient client) async {
final res = await client.pullModel(
- request: const PullModelRequest(name: 'yarn-llama2:13b-128k-q4_1'),
+ request: const PullModelRequest(model: 'yarn-llama2:13b-128k-q4_1'),
);
print(res.status);
}

Future<void> _pullModelStream(final OllamaClient client) async {
final stream = client.pullModelStream(
- request: const PullModelRequest(name: 'yarn-llama2:13b-128k-q4_1'),
+ request: const PullModelRequest(model: 'yarn-llama2:13b-128k-q4_1'),
);
await for (final res in stream) {
print(res.status);
@@ -201,14 +201,14 @@ Future<void> _pullModelStream(final OllamaClient client) async {

Future<void> _pushModel(final OllamaClient client) async {
final res = await client.pushModel(
- request: const PushModelRequest(name: 'mattw/pygmalion:latest'),
+ request: const PushModelRequest(model: 'mattw/pygmalion:latest'),
);
print(res.status);
}

Future<void> _pushModelStream(final OllamaClient client) async {
final stream = client.pushModelStream(
- request: const PushModelRequest(name: 'mattw/pygmalion:latest'),
+ request: const PushModelRequest(model: 'mattw/pygmalion:latest'),
);
await for (final res in stream) {
print(res.status);
19 changes: 19 additions & 0 deletions packages/ollama_dart/lib/src/generated/schema/done_reason.dart
@@ -0,0 +1,19 @@
// coverage:ignore-file
// GENERATED CODE - DO NOT MODIFY BY HAND
// ignore_for_file: type=lint
// ignore_for_file: invalid_annotation_target
part of ollama_schema;

// ==========================================
// ENUM: DoneReason
// ==========================================

/// Reason why the model is done generating a response.
enum DoneReason {
@JsonValue('stop')
stop,
@JsonValue('length')
length,
@JsonValue('load')
load,
}
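With the reason modeled as a closed enum, call sites can handle every case exhaustively. A sketch using a Dart 3 switch expression (the descriptions paraphrase the documented meanings of the values):

```dart
import 'package:ollama_dart/ollama_dart.dart';

String describe(DoneReason? reason) => switch (reason) {
      DoneReason.stop => 'generation hit a stop token',
      DoneReason.length => 'the token limit (num_predict) was reached',
      DoneReason.load => 'an empty request was sent just to load the model',
      null => 'server did not report a reason (or it was unrecognized)',
    };
```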
packages/ollama_dart/lib/src/generated/schema/generate_chat_completion_response.dart
@@ -29,8 +29,13 @@ class GenerateChatCompletionResponse with _$GenerateChatCompletionResponse {
/// Whether the response has completed.
@JsonKey(includeIfNull: false) bool? done,

- /// Reason the response is done.
- @JsonKey(name: 'done_reason', includeIfNull: false) String? doneReason,
+ /// Reason why the model is done generating a response.
+ @JsonKey(
+   name: 'done_reason',
+   includeIfNull: false,
+   unknownEnumValue: JsonKey.nullForUndefinedEnumValue,
+ )
+ DoneReason? doneReason,

/// Time spent generating the response.
@JsonKey(name: 'total_duration', includeIfNull: false) int? totalDuration,
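The `unknownEnumValue: JsonKey.nullForUndefinedEnumValue` annotation keeps deserialization forward-compatible: if the server reports a `done_reason` this client version does not know, the field decodes to `null` instead of throwing. A sketch of that behavior (hand-written JSON maps for illustration; `'some_future_reason'` is a hypothetical value):

```dart
import 'package:ollama_dart/ollama_dart.dart';

void main() {
  final known = GenerateChatCompletionResponse.fromJson({
    'done': true,
    'done_reason': 'length',
  });
  print(known.doneReason); // DoneReason.length

  final unknown = GenerateChatCompletionResponse.fromJson({
    'done': true,
    'done_reason': 'some_future_reason', // hypothetical wire value
  });
  print(unknown.doneReason); // null — unknown values degrade gracefully
}
```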
1 change: 1 addition & 0 deletions packages/ollama_dart/lib/src/generated/schema/schema.dart
@@ -16,6 +16,7 @@ part 'response_format.dart';
part 'generate_completion_response.dart';
part 'generate_chat_completion_request.dart';
part 'generate_chat_completion_response.dart';
+ part 'done_reason.dart';
part 'message.dart';
part 'generate_embedding_request.dart';
part 'generate_embedding_response.dart';
56 changes: 40 additions & 16 deletions packages/ollama_dart/lib/src/generated/schema/schema.freezed.dart
@@ -2496,9 +2496,12 @@ mixin _$GenerateChatCompletionResponse {
@JsonKey(includeIfNull: false)
bool? get done => throw _privateConstructorUsedError;

- /// Reason the response is done.
- @JsonKey(name: 'done_reason', includeIfNull: false)
- String? get doneReason => throw _privateConstructorUsedError;
+ /// Reason why the model is done generating a response.
+ @JsonKey(
+   name: 'done_reason',
+   includeIfNull: false,
+   unknownEnumValue: JsonKey.nullForUndefinedEnumValue)
+ DoneReason? get doneReason => throw _privateConstructorUsedError;

/// Time spent generating the response.
@JsonKey(name: 'total_duration', includeIfNull: false)
@@ -2543,7 +2546,11 @@ abstract class $GenerateChatCompletionResponseCopyWith<$Res> {
@JsonKey(includeIfNull: false) String? model,
@JsonKey(name: 'created_at', includeIfNull: false) String? createdAt,
@JsonKey(includeIfNull: false) bool? done,
- @JsonKey(name: 'done_reason', includeIfNull: false) String? doneReason,
+ @JsonKey(
+   name: 'done_reason',
+   includeIfNull: false,
+   unknownEnumValue: JsonKey.nullForUndefinedEnumValue)
+ DoneReason? doneReason,
@JsonKey(name: 'total_duration', includeIfNull: false) int? totalDuration,
@JsonKey(name: 'load_duration', includeIfNull: false) int? loadDuration,
@JsonKey(name: 'prompt_eval_count', includeIfNull: false)
@@ -2602,7 +2609,7 @@ class _$GenerateChatCompletionResponseCopyWithImpl<$Res,
doneReason: freezed == doneReason
? _value.doneReason
: doneReason // ignore: cast_nullable_to_non_nullable
- as String?,
+ as DoneReason?,
totalDuration: freezed == totalDuration
? _value.totalDuration
: totalDuration // ignore: cast_nullable_to_non_nullable
@@ -2657,7 +2664,11 @@ abstract class _$$GenerateChatCompletionResponseImplCopyWith<$Res>
@JsonKey(includeIfNull: false) String? model,
@JsonKey(name: 'created_at', includeIfNull: false) String? createdAt,
@JsonKey(includeIfNull: false) bool? done,
- @JsonKey(name: 'done_reason', includeIfNull: false) String? doneReason,
+ @JsonKey(
+   name: 'done_reason',
+   includeIfNull: false,
+   unknownEnumValue: JsonKey.nullForUndefinedEnumValue)
+ DoneReason? doneReason,
@JsonKey(name: 'total_duration', includeIfNull: false) int? totalDuration,
@JsonKey(name: 'load_duration', includeIfNull: false) int? loadDuration,
@JsonKey(name: 'prompt_eval_count', includeIfNull: false)
@@ -2716,7 +2727,7 @@ class __$$GenerateChatCompletionResponseImplCopyWithImpl<$Res>
doneReason: freezed == doneReason
? _value.doneReason
: doneReason // ignore: cast_nullable_to_non_nullable
- as String?,
+ as DoneReason?,
totalDuration: freezed == totalDuration
? _value.totalDuration
: totalDuration // ignore: cast_nullable_to_non_nullable
@@ -2754,7 +2765,11 @@ class _$GenerateChatCompletionResponseImpl
@JsonKey(includeIfNull: false) this.model,
@JsonKey(name: 'created_at', includeIfNull: false) this.createdAt,
@JsonKey(includeIfNull: false) this.done,
- @JsonKey(name: 'done_reason', includeIfNull: false) this.doneReason,
+ @JsonKey(
+   name: 'done_reason',
+   includeIfNull: false,
+   unknownEnumValue: JsonKey.nullForUndefinedEnumValue)
+ this.doneReason,
@JsonKey(name: 'total_duration', includeIfNull: false) this.totalDuration,
@JsonKey(name: 'load_duration', includeIfNull: false) this.loadDuration,
@JsonKey(name: 'prompt_eval_count', includeIfNull: false)
@@ -2791,10 +2806,13 @@ class _$GenerateChatCompletionResponseImpl
@JsonKey(includeIfNull: false)
final bool? done;

- /// Reason the response is done.
+ /// Reason why the model is done generating a response.
@override
- @JsonKey(name: 'done_reason', includeIfNull: false)
- final String? doneReason;
+ @JsonKey(
+   name: 'done_reason',
+   includeIfNull: false,
+   unknownEnumValue: JsonKey.nullForUndefinedEnumValue)
+ final DoneReason? doneReason;

/// Time spent generating the response.
@override
@@ -2897,8 +2915,11 @@ abstract class _GenerateChatCompletionResponse
@JsonKey(name: 'created_at', includeIfNull: false)
final String? createdAt,
@JsonKey(includeIfNull: false) final bool? done,
- @JsonKey(name: 'done_reason', includeIfNull: false)
- final String? doneReason,
+ @JsonKey(
+   name: 'done_reason',
+   includeIfNull: false,
+   unknownEnumValue: JsonKey.nullForUndefinedEnumValue)
+ final DoneReason? doneReason,
@JsonKey(name: 'total_duration', includeIfNull: false)
final int? totalDuration,
@JsonKey(name: 'load_duration', includeIfNull: false)
@@ -2939,9 +2960,12 @@ abstract class _GenerateChatCompletionResponse
bool? get done;
@override

- /// Reason the response is done.
- @JsonKey(name: 'done_reason', includeIfNull: false)
- String? get doneReason;
+ /// Reason why the model is done generating a response.
+ @JsonKey(
+   name: 'done_reason',
+   includeIfNull: false,
+   unknownEnumValue: JsonKey.nullForUndefinedEnumValue)
+ DoneReason? get doneReason;
@override

/// Time spent generating the response.
12 changes: 10 additions & 2 deletions packages/ollama_dart/lib/src/generated/schema/schema.g.dart

Some generated files are not rendered by default.
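For orientation, json_serializable keeps a map between enum members and their wire strings in the generated file; a hedged sketch of that relationship (the map name and shape follow the generator's usual `_$<Enum>EnumMap` convention, not the actual unrendered file):

```dart
import 'package:ollama_dart/ollama_dart.dart';

// Sketch of the mapping the unrendered schema.g.dart maintains between
// DoneReason members and the strings Ollama sends on the wire.
const doneReasonWireValues = {
  DoneReason.stop: 'stop',
  DoneReason.length: 'length',
  DoneReason.load: 'load',
};

void main() {
  // Round-trip check: every enum member has exactly one wire value.
  for (final reason in DoneReason.values) {
    print('${doneReasonWireValues[reason]} -> $reason');
  }
}
```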

11 changes: 8 additions & 3 deletions packages/ollama_dart/oas/ollama-curated.yaml
@@ -573,9 +573,7 @@ components:
description: Whether the response has completed.
example: true
done_reason:
- type: string
- nullable: true
- description: Reason the response is done.
+ $ref: '#/components/schemas/DoneReason'
total_duration:
type: integer
format: int64
@@ -604,6 +602,13 @@
format: int64
description: Time in nanoseconds spent generating the response.
example: 1325948000
+ DoneReason:
+   type: string
+   description: Reason why the model is done generating a response.
+   enum:
+     - stop # The generation hit a stop token.
+     - length # The maximum num_tokens was reached.
+     - load # The request was sent with an empty body to load the model.
Message:
type: object
description: A message in the chat endpoint
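Per the comment on `load`, Ollama treats a chat request with an empty message list as a model-preload call, so the response should report `load` as its done reason. A hedged sketch of that flow (not exercised by the tests in this commit; the model name is illustrative):

```dart
import 'package:ollama_dart/ollama_dart.dart';

Future<void> preload(OllamaClient client) async {
  final res = await client.generateChatCompletion(
    request: const GenerateChatCompletionRequest(
      model: 'mistral:latest', // illustrative
      messages: [], // empty body: load the model without generating
    ),
  );
  print(res.doneReason); // expected: DoneReason.load
}
```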
25 changes: 23 additions & 2 deletions packages/ollama_dart/test/ollama_dart_chat_test.dart
@@ -52,6 +52,7 @@ void main() {
isNotEmpty,
);
expect(response.done, isTrue);
+ expect(response.doneReason, DoneReason.stop);
expect(response.totalDuration, greaterThan(0));
expect(response.promptEvalCount, greaterThan(0));
expect(response.evalCount, greaterThan(0));
@@ -118,8 +119,7 @@ void main() {
Message(
role: MessageRole.user,
content: 'List the numbers from 1 to 9 in order. '
- 'Output ONLY the numbers in one line without any spaces or commas. '
- 'NUMBERS: ',
+ 'Output ONLY the numbers without spaces or commas.',
),
],
options: RequestOptions(stop: ['4']),
@@ -128,6 +128,27 @@
final generation = res.message?.content.replaceAll(RegExp(r'[\s\n]'), '');
expect(generation, contains('123'));
expect(generation, isNot(contains('456789')));
+ expect(res.doneReason, DoneReason.stop);
});

+ test('Test call chat completions API with max tokens', () async {
+   final res = await client.generateChatCompletion(
+     request: const GenerateChatCompletionRequest(
+       model: defaultModel,
+       messages: [
+         Message(
+           role: MessageRole.system,
+           content: 'You are a helpful assistant.',
+         ),
+         Message(
+           role: MessageRole.user,
+           content: 'List the numbers from 1 to 9 in order.',
+         ),
+       ],
+       options: RequestOptions(numPredict: 1),
+     ),
+   );
+   expect(res.doneReason, DoneReason.length);
+ });

test('Test call chat completions API with image', () async {
