From 1a4ccd1e7d1907e340ce609cc6ba8d0543ee3421 Mon Sep 17 00:00:00 2001 From: Kelvin Dupont <116877495+kndpt@users.noreply.github.com> Date: Fri, 19 Jan 2024 13:42:35 +0100 Subject: [PATCH] refactor: Update safe_mode and max temperature in Mistral chat (#300) - `safe_mode` was renamed to `safe_prompt` - The max temperature is now 1 instead of 2 --- .../lib/src/chat_models/chat_mistralai.dart | 2 +- .../lib/src/chat_models/models/models.dart | 8 +-- .../test/chat_models/chat_mistralai_test.dart | 4 +- .../schema/chat_completion_request.dart | 14 ++--- .../src/generated/schema/schema.freezed.dart | 60 ++++++++++--------- .../lib/src/generated/schema/schema.g.dart | 4 +- .../oas/mistral_openapi_curated.yaml | 9 +-- .../oas/mistral_openapi_official.yaml | 50 ++++++++-------- .../mistralai_dart_chat_completions_test.dart | 4 +- 9 files changed, 79 insertions(+), 76 deletions(-) diff --git a/packages/langchain_mistralai/lib/src/chat_models/chat_mistralai.dart b/packages/langchain_mistralai/lib/src/chat_models/chat_mistralai.dart index 09957556..ffd648ef 100644 --- a/packages/langchain_mistralai/lib/src/chat_models/chat_mistralai.dart +++ b/packages/langchain_mistralai/lib/src/chat_models/chat_mistralai.dart @@ -231,7 +231,7 @@ class ChatMistralAI extends BaseChatModel { temperature: options?.temperature ?? defaultOptions.temperature, topP: options?.topP ?? defaultOptions.topP, maxTokens: options?.maxTokens ?? defaultOptions.maxTokens, - safeMode: options?.safeMode ?? defaultOptions.safeMode, + safePrompt: options?.safePrompt ?? defaultOptions.safePrompt, randomSeed: options?.randomSeed ?? defaultOptions.randomSeed, stream: stream, ); diff --git a/packages/langchain_mistralai/lib/src/chat_models/models/models.dart b/packages/langchain_mistralai/lib/src/chat_models/models/models.dart index 666e3a59..4e32b71c 100644 --- a/packages/langchain_mistralai/lib/src/chat_models/models/models.dart +++ b/packages/langchain_mistralai/lib/src/chat_models/models/models.dart @@ -10,7 +10,7 @@ class ChatMistralAIOptions extends ChatModelOptions { this.temperature, this.topP, this.maxTokens, - this.safeMode, + this.safePrompt, this.randomSeed, }); @@ -40,7 +40,7 @@ class ChatMistralAIOptions extends ChatModelOptions { final int? maxTokens; /// Whether to inject a safety prompt before all conversations. - final bool? safeMode; + final bool? safePrompt; /// The seed to use for random sampling. /// If set, different calls will generate deterministic results. @@ -53,7 +53,7 @@ class ChatMistralAIOptions extends ChatModelOptions { final double? temperature, final double? topP, final int? maxTokens, - final bool? safeMode, + final bool? safePrompt, final int? randomSeed, }) { return ChatMistralAIOptions( @@ -61,7 +61,7 @@ class ChatMistralAIOptions extends ChatModelOptions { temperature: temperature ?? this.temperature, topP: topP ?? this.topP, maxTokens: maxTokens ?? this.maxTokens, - safeMode: safeMode ?? this.safeMode, + safePrompt: safePrompt ?? this.safePrompt, randomSeed: randomSeed ?? 
this.randomSeed, ); } diff --git a/packages/langchain_mistralai/test/chat_models/chat_mistralai_test.dart b/packages/langchain_mistralai/test/chat_models/chat_mistralai_test.dart index 5ae1bba4..a5849dcc 100644 --- a/packages/langchain_mistralai/test/chat_models/chat_mistralai_test.dart +++ b/packages/langchain_mistralai/test/chat_models/chat_mistralai_test.dart @@ -27,7 +27,7 @@ void main() { temperature: 0.1, topP: 0.5, maxTokens: 10, - safeMode: true, + safePrompt: true, randomSeed: 1234, ); @@ -35,7 +35,7 @@ void main() { expect(chatModel.defaultOptions.temperature, 0.1); expect(chatModel.defaultOptions.topP, 0.5); expect(chatModel.defaultOptions.maxTokens, 10); - expect(chatModel.defaultOptions.safeMode, true); + expect(chatModel.defaultOptions.safePrompt, true); expect(chatModel.defaultOptions.randomSeed, 1234); }); diff --git a/packages/mistralai_dart/lib/src/generated/schema/chat_completion_request.dart b/packages/mistralai_dart/lib/src/generated/schema/chat_completion_request.dart index 93dd7729..2947b543 100644 --- a/packages/mistralai_dart/lib/src/generated/schema/chat_completion_request.dart +++ b/packages/mistralai_dart/lib/src/generated/schema/chat_completion_request.dart @@ -18,10 +18,10 @@ class ChatCompletionRequest with _$ChatCompletionRequest { /// ID of the model to use. You can use the [List Available Models](https://docs.mistral.ai/api#operation/listModels) API to see all of your available models, or see our [Model overview](https://docs.mistral.ai/models) for model descriptions. @_ChatCompletionModelConverter() required ChatCompletionModel model, - /// The prompt(s) to generate completions for, encoded as a list of dict with role and content. + /// The prompt(s) to generate completions for, encoded as a list of dict with role and content. The first prompt role should be `user` or `system`. required List messages, - /// What sampling temperature to use, between 0.0 and 2.0. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic. + /// What sampling temperature to use, between 0.0 and 1.0. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic. /// /// We generally recommend altering this or `top_p` but not both. @JsonKey(includeIfNull: false) @Default(0.7) double? temperature, @@ -40,9 +40,9 @@ class ChatCompletionRequest with _$ChatCompletionRequest { @JsonKey(includeIfNull: false) @Default(false) bool? stream, /// Whether to inject a safety prompt before all conversations. - @JsonKey(name: 'safe_mode', includeIfNull: false) + @JsonKey(name: 'safe_prompt', includeIfNull: false) @Default(false) - bool? safeMode, + bool? safePrompt, /// The seed to use for random sampling. If set, different calls will generate deterministic results. @JsonKey(name: 'random_seed', includeIfNull: false) int? 
randomSeed, @@ -60,14 +60,14 @@ class ChatCompletionRequest with _$ChatCompletionRequest { 'top_p', 'max_tokens', 'stream', - 'safe_mode', + 'safe_prompt', 'random_seed' ]; /// Validation constants static const temperatureDefaultValue = 0.7; static const temperatureMinValue = 0.0; - static const temperatureMaxValue = 2.0; + static const temperatureMaxValue = 1.0; static const topPDefaultValue = 1.0; static const topPMinValue = 0.0; static const topPMaxValue = 1.0; @@ -102,7 +102,7 @@ class ChatCompletionRequest with _$ChatCompletionRequest { 'top_p': topP, 'max_tokens': maxTokens, 'stream': stream, - 'safe_mode': safeMode, + 'safe_prompt': safePrompt, 'random_seed': randomSeed, }; } diff --git a/packages/mistralai_dart/lib/src/generated/schema/schema.freezed.dart b/packages/mistralai_dart/lib/src/generated/schema/schema.freezed.dart index b7f5102b..b06fef54 100644 --- a/packages/mistralai_dart/lib/src/generated/schema/schema.freezed.dart +++ b/packages/mistralai_dart/lib/src/generated/schema/schema.freezed.dart @@ -25,11 +25,11 @@ mixin _$ChatCompletionRequest { @_ChatCompletionModelConverter() ChatCompletionModel get model => throw _privateConstructorUsedError; - /// The prompt(s) to generate completions for, encoded as a list of dict with role and content. + /// The prompt(s) to generate completions for, encoded as a list of dict with role and content. The first prompt role should be `user` or `system`. List get messages => throw _privateConstructorUsedError; - /// What sampling temperature to use, between 0.0 and 2.0. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic. + /// What sampling temperature to use, between 0.0 and 1.0. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic. /// /// We generally recommend altering this or `top_p` but not both. @JsonKey(includeIfNull: false) @@ -52,8 +52,8 @@ mixin _$ChatCompletionRequest { bool? get stream => throw _privateConstructorUsedError; /// Whether to inject a safety prompt before all conversations. - @JsonKey(name: 'safe_mode', includeIfNull: false) - bool? get safeMode => throw _privateConstructorUsedError; + @JsonKey(name: 'safe_prompt', includeIfNull: false) + bool? get safePrompt => throw _privateConstructorUsedError; /// The seed to use for random sampling. If set, different calls will generate deterministic results. @JsonKey(name: 'random_seed', includeIfNull: false) @@ -78,7 +78,7 @@ abstract class $ChatCompletionRequestCopyWith<$Res> { @JsonKey(name: 'top_p', includeIfNull: false) double? topP, @JsonKey(name: 'max_tokens', includeIfNull: false) int? maxTokens, @JsonKey(includeIfNull: false) bool? stream, - @JsonKey(name: 'safe_mode', includeIfNull: false) bool? safeMode, + @JsonKey(name: 'safe_prompt', includeIfNull: false) bool? safePrompt, @JsonKey(name: 'random_seed', includeIfNull: false) int? randomSeed}); $ChatCompletionModelCopyWith<$Res> get model; @@ -104,7 +104,7 @@ class _$ChatCompletionRequestCopyWithImpl<$Res, Object? topP = freezed, Object? maxTokens = freezed, Object? stream = freezed, - Object? safeMode = freezed, + Object? safePrompt = freezed, Object? randomSeed = freezed, }) { return _then(_value.copyWith( @@ -132,9 +132,9 @@ class _$ChatCompletionRequestCopyWithImpl<$Res, ? _value.stream : stream // ignore: cast_nullable_to_non_nullable as bool?, - safeMode: freezed == safeMode - ? 
_value.safeMode - : safeMode // ignore: cast_nullable_to_non_nullable + safePrompt: freezed == safePrompt + ? _value.safePrompt + : safePrompt // ignore: cast_nullable_to_non_nullable as bool?, randomSeed: freezed == randomSeed ? _value.randomSeed @@ -168,7 +168,7 @@ abstract class _$$ChatCompletionRequestImplCopyWith<$Res> @JsonKey(name: 'top_p', includeIfNull: false) double? topP, @JsonKey(name: 'max_tokens', includeIfNull: false) int? maxTokens, @JsonKey(includeIfNull: false) bool? stream, - @JsonKey(name: 'safe_mode', includeIfNull: false) bool? safeMode, + @JsonKey(name: 'safe_prompt', includeIfNull: false) bool? safePrompt, @JsonKey(name: 'random_seed', includeIfNull: false) int? randomSeed}); @override @@ -193,7 +193,7 @@ class __$$ChatCompletionRequestImplCopyWithImpl<$Res> Object? topP = freezed, Object? maxTokens = freezed, Object? stream = freezed, - Object? safeMode = freezed, + Object? safePrompt = freezed, Object? randomSeed = freezed, }) { return _then(_$ChatCompletionRequestImpl( @@ -221,9 +221,9 @@ class __$$ChatCompletionRequestImplCopyWithImpl<$Res> ? _value.stream : stream // ignore: cast_nullable_to_non_nullable as bool?, - safeMode: freezed == safeMode - ? _value.safeMode - : safeMode // ignore: cast_nullable_to_non_nullable + safePrompt: freezed == safePrompt + ? _value.safePrompt + : safePrompt // ignore: cast_nullable_to_non_nullable as bool?, randomSeed: freezed == randomSeed ? _value.randomSeed @@ -243,7 +243,8 @@ class _$ChatCompletionRequestImpl extends _ChatCompletionRequest { @JsonKey(name: 'top_p', includeIfNull: false) this.topP = 1.0, @JsonKey(name: 'max_tokens', includeIfNull: false) this.maxTokens, @JsonKey(includeIfNull: false) this.stream = false, - @JsonKey(name: 'safe_mode', includeIfNull: false) this.safeMode = false, + @JsonKey(name: 'safe_prompt', includeIfNull: false) + this.safePrompt = false, @JsonKey(name: 'random_seed', includeIfNull: false) this.randomSeed}) : _messages = messages, super._(); @@ -256,10 +257,10 @@ class _$ChatCompletionRequestImpl extends _ChatCompletionRequest { @_ChatCompletionModelConverter() final ChatCompletionModel model; - /// The prompt(s) to generate completions for, encoded as a list of dict with role and content. + /// The prompt(s) to generate completions for, encoded as a list of dict with role and content. The first prompt role should be `user` or `system`. final List _messages; - /// The prompt(s) to generate completions for, encoded as a list of dict with role and content. + /// The prompt(s) to generate completions for, encoded as a list of dict with role and content. The first prompt role should be `user` or `system`. @override List get messages { if (_messages is EqualUnmodifiableListView) return _messages; @@ -267,7 +268,7 @@ class _$ChatCompletionRequestImpl extends _ChatCompletionRequest { return EqualUnmodifiableListView(_messages); } - /// What sampling temperature to use, between 0.0 and 2.0. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic. + /// What sampling temperature to use, between 0.0 and 1.0. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic. /// /// We generally recommend altering this or `top_p` but not both. @override @@ -295,8 +296,8 @@ class _$ChatCompletionRequestImpl extends _ChatCompletionRequest { /// Whether to inject a safety prompt before all conversations. 
@override - @JsonKey(name: 'safe_mode', includeIfNull: false) - final bool? safeMode; + @JsonKey(name: 'safe_prompt', includeIfNull: false) + final bool? safePrompt; /// The seed to use for random sampling. If set, different calls will generate deterministic results. @override @@ -305,7 +306,7 @@ class _$ChatCompletionRequestImpl extends _ChatCompletionRequest { @override String toString() { - return 'ChatCompletionRequest(model: $model, messages: $messages, temperature: $temperature, topP: $topP, maxTokens: $maxTokens, stream: $stream, safeMode: $safeMode, randomSeed: $randomSeed)'; + return 'ChatCompletionRequest(model: $model, messages: $messages, temperature: $temperature, topP: $topP, maxTokens: $maxTokens, stream: $stream, safePrompt: $safePrompt, randomSeed: $randomSeed)'; } @override @@ -321,8 +322,8 @@ class _$ChatCompletionRequestImpl extends _ChatCompletionRequest { (identical(other.maxTokens, maxTokens) || other.maxTokens == maxTokens) && (identical(other.stream, stream) || other.stream == stream) && - (identical(other.safeMode, safeMode) || - other.safeMode == safeMode) && + (identical(other.safePrompt, safePrompt) || + other.safePrompt == safePrompt) && (identical(other.randomSeed, randomSeed) || other.randomSeed == randomSeed)); } @@ -337,7 +338,7 @@ class _$ChatCompletionRequestImpl extends _ChatCompletionRequest { topP, maxTokens, stream, - safeMode, + safePrompt, randomSeed); @JsonKey(ignore: true) @@ -364,7 +365,8 @@ abstract class _ChatCompletionRequest extends ChatCompletionRequest { @JsonKey(name: 'top_p', includeIfNull: false) final double? topP, @JsonKey(name: 'max_tokens', includeIfNull: false) final int? maxTokens, @JsonKey(includeIfNull: false) final bool? stream, - @JsonKey(name: 'safe_mode', includeIfNull: false) final bool? safeMode, + @JsonKey(name: 'safe_prompt', includeIfNull: false) + final bool? safePrompt, @JsonKey(name: 'random_seed', includeIfNull: false) final int? randomSeed}) = _$ChatCompletionRequestImpl; const _ChatCompletionRequest._() : super._(); @@ -379,11 +381,11 @@ abstract class _ChatCompletionRequest extends ChatCompletionRequest { ChatCompletionModel get model; @override - /// The prompt(s) to generate completions for, encoded as a list of dict with role and content. + /// The prompt(s) to generate completions for, encoded as a list of dict with role and content. The first prompt role should be `user` or `system`. List get messages; @override - /// What sampling temperature to use, between 0.0 and 2.0. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic. + /// What sampling temperature to use, between 0.0 and 1.0. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic. /// /// We generally recommend altering this or `top_p` but not both. @JsonKey(includeIfNull: false) @@ -410,8 +412,8 @@ abstract class _ChatCompletionRequest extends ChatCompletionRequest { @override /// Whether to inject a safety prompt before all conversations. - @JsonKey(name: 'safe_mode', includeIfNull: false) - bool? get safeMode; + @JsonKey(name: 'safe_prompt', includeIfNull: false) + bool? get safePrompt; @override /// The seed to use for random sampling. If set, different calls will generate deterministic results. 
diff --git a/packages/mistralai_dart/lib/src/generated/schema/schema.g.dart b/packages/mistralai_dart/lib/src/generated/schema/schema.g.dart index 115fdd22..1a5b0aec 100644 --- a/packages/mistralai_dart/lib/src/generated/schema/schema.g.dart +++ b/packages/mistralai_dart/lib/src/generated/schema/schema.g.dart @@ -19,7 +19,7 @@ _$ChatCompletionRequestImpl _$$ChatCompletionRequestImplFromJson( topP: (json['top_p'] as num?)?.toDouble() ?? 1.0, maxTokens: json['max_tokens'] as int?, stream: json['stream'] as bool? ?? false, - safeMode: json['safe_mode'] as bool? ?? false, + safePrompt: json['safe_prompt'] as bool? ?? false, randomSeed: json['random_seed'] as int?, ); @@ -40,7 +40,7 @@ Map _$$ChatCompletionRequestImplToJson( writeNotNull('top_p', instance.topP); writeNotNull('max_tokens', instance.maxTokens); writeNotNull('stream', instance.stream); - writeNotNull('safe_mode', instance.safeMode); + writeNotNull('safe_prompt', instance.safePrompt); writeNotNull('random_seed', instance.randomSeed); return val; } diff --git a/packages/mistralai_dart/oas/mistral_openapi_curated.yaml b/packages/mistralai_dart/oas/mistral_openapi_curated.yaml index bd0d9a07..1adc44b4 100644 --- a/packages/mistralai_dart/oas/mistral_openapi_curated.yaml +++ b/packages/mistralai_dart/oas/mistral_openapi_curated.yaml @@ -102,19 +102,20 @@ components: messages: description: > The prompt(s) to generate completions for, encoded as a list of dict - with role and content. + with role and content. The first prompt role should be `user` or + `system`. type: array items: $ref: '#/components/schemas/ChatCompletionMessage' temperature: type: number minimum: 0 - maximum: 2 + maximum: 1 default: 0.7 example: 0.7 nullable: true description: > - What sampling temperature to use, between 0.0 and 2.0. Higher values + What sampling temperature to use, between 0.0 and 1.0. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic. @@ -156,7 +157,7 @@ components: stream terminated by a data: [DONE] message. Otherwise, the server will hold the request open until the timeout or until completion, with the response containing the full result as JSON. - safe_mode: + safe_prompt: type: boolean default: false nullable: true diff --git a/packages/mistralai_dart/oas/mistral_openapi_official.yaml b/packages/mistralai_dart/oas/mistral_openapi_official.yaml index af28bb27..821097fb 100644 --- a/packages/mistralai_dart/oas/mistral_openapi_official.yaml +++ b/packages/mistralai_dart/oas/mistral_openapi_official.yaml @@ -105,7 +105,8 @@ components: messages: description: > The prompt(s) to generate completions for, encoded as a list of dict - with role and content. + with role and content. The first prompt role should be `user` or + `system`. type: array items: type: object @@ -113,8 +114,9 @@ components: role: type: string enum: + - system - user - - agent + - assistant content: type: string example: @@ -123,12 +125,12 @@ components: temperature: type: number minimum: 0 - maximum: 2 + maximum: 1 default: 0.7 example: 0.7 nullable: true description: > - What sampling temperature to use, between 0.0 and 2.0. Higher values + What sampling temperature to use, between 0.0 and 1.0. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic. @@ -170,7 +172,7 @@ components: stream terminated by a data: [DONE] message. 
Otherwise, the server will hold the request open until the timeout or until completion, with the response containing the full result as JSON. - safe_mode: + safe_prompt: type: boolean default: false description: | @@ -212,26 +214,24 @@ components: type: integer example: 0 message: - type: array - items: - type: object - properties: - role: - type: string - enum: - - user - - assistant - example: assistant - content: - type: string - example: >- - I don't have a favorite condiment as I don't consume - food or condiments. However, I can tell you that many - people enjoy using ketchup, mayonnaise, hot sauce, soy - sauce, or mustard as condiments to enhance the flavor of - their meals. Some people also enjoy using herbs, spices, - or vinegars as condiments. Ultimately, the best - condiment is a matter of personal preference. + type: object + properties: + role: + type: string + enum: + - user + - assistant + example: assistant + content: + type: string + example: >- + I don't have a favorite condiment as I don't consume food + or condiments. However, I can tell you that many people + enjoy using ketchup, mayonnaise, hot sauce, soy sauce, or + mustard as condiments to enhance the flavor of their + meals. Some people also enjoy using herbs, spices, or + vinegars as condiments. Ultimately, the best condiment is + a matter of personal preference. finish_reason: type: string enum: diff --git a/packages/mistralai_dart/test/mistralai_dart_chat_completions_test.dart b/packages/mistralai_dart/test/mistralai_dart_chat_completions_test.dart index 7f7e4838..82aa0344 100644 --- a/packages/mistralai_dart/test/mistralai_dart_chat_completions_test.dart +++ b/packages/mistralai_dart/test/mistralai_dart_chat_completions_test.dart @@ -138,10 +138,10 @@ void main() { expect(choice1.message?.content, choice2.message?.content); }); - test('Test response safe_mode on', () async { + test('Test response safe_prompt on', () async { const request = ChatCompletionRequest( model: ChatCompletionModel.model(ChatCompletionModels.mistralTiny), - safeMode: true, + safePrompt: true, messages: [ ChatCompletionMessage( role: ChatCompletionMessageRole.user,
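For code that consumes these packages, the change boils down to a parameter rename plus a tighter temperature range. Below is a minimal migration sketch, not part of the patch itself: the import paths assume the usual package entry points, and the concrete values are illustrative; only the renamed `safePrompt`/`safe_prompt` fields and the new 0.0–1.0 temperature range come from this patch.

import 'package:langchain_mistralai/langchain_mistralai.dart';
import 'package:mistralai_dart/mistralai_dart.dart' as mistralai;

void main() {
  // High-level options: `safeMode` is now `safePrompt`, and temperature
  // should stay within 0.0–1.0 (previously 0.0–2.0).
  final options = ChatMistralAIOptions(
    temperature: 0.7,
    topP: 0.5,
    maxTokens: 10,
    safePrompt: true, // was: safeMode: true
    randomSeed: 1234,
  );

  // Low-level client: the field is now serialized as `safe_prompt`
  // (was `safe_mode`) and temperatureMaxValue is now 1.0.
  final request = mistralai.ChatCompletionRequest(
    model: mistralai.ChatCompletionModel.model(
      mistralai.ChatCompletionModels.mistralTiny,
    ),
    safePrompt: true,
    temperature: 0.7,
    messages: [
      mistralai.ChatCompletionMessage(
        role: mistralai.ChatCompletionMessageRole.user,
        content: 'Hello!',
      ),
    ],
  );

  print(options);
  print(request);
}

Existing callers that still pass `safeMode` or a temperature above 1.0 will need to be updated along these lines when picking up this release.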