
refactor: Update safe_mode and max temperature in Mistral chat (#300)
- `safe_mode` was renamed to `safe_prompt`
- The maximum `temperature` is now 1 instead of 2 (see the usage sketch below)
kndpt committed Jan 19, 2024
1 parent cff2c58 commit 1a4ccd1
Showing 9 changed files with 79 additions and 76 deletions.
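For context, a minimal usage sketch of the renamed option on the high-level wrapper follows. The `langchain_mistralai` import path and the `apiKey`/`defaultOptions` constructor parameters are assumptions not shown in this commit; `ChatMistralAIOptions.safePrompt`, the `defaultOptions` getter, and `copyWith` appear in the diff below.

```dart
import 'dart:io';

import 'package:langchain_mistralai/langchain_mistralai.dart';

void main() {
  final chatModel = ChatMistralAI(
    apiKey: Platform.environment['MISTRAL_API_KEY'],
    defaultOptions: ChatMistralAIOptions(
      temperature: 0.7, // now validated against a maximum of 1 instead of 2
      safePrompt: true, // formerly `safeMode`
    ),
  );

  // Per-call overrides keep the renamed field as well.
  final overridden = chatModel.defaultOptions.copyWith(temperature: 0.2);
  print(overridden.safePrompt); // true
}
```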
@@ -231,7 +231,7 @@ class ChatMistralAI extends BaseChatModel<ChatMistralAIOptions> {
temperature: options?.temperature ?? defaultOptions.temperature,
topP: options?.topP ?? defaultOptions.topP,
maxTokens: options?.maxTokens ?? defaultOptions.maxTokens,
safeMode: options?.safeMode ?? defaultOptions.safeMode,
safePrompt: options?.safePrompt ?? defaultOptions.safePrompt,
randomSeed: options?.randomSeed ?? defaultOptions.randomSeed,
stream: stream,
);
@@ -10,7 +10,7 @@ class ChatMistralAIOptions extends ChatModelOptions {
this.temperature,
this.topP,
this.maxTokens,
this.safeMode,
this.safePrompt,
this.randomSeed,
});

@@ -40,7 +40,7 @@ class ChatMistralAIOptions extends ChatModelOptions {
final int? maxTokens;

/// Whether to inject a safety prompt before all conversations.
final bool? safeMode;
final bool? safePrompt;

/// The seed to use for random sampling.
/// If set, different calls will generate deterministic results.
@@ -53,15 +53,15 @@ class ChatMistralAIOptions extends ChatModelOptions {
final double? temperature,
final double? topP,
final int? maxTokens,
final bool? safeMode,
final bool? safePrompt,
final int? randomSeed,
}) {
return ChatMistralAIOptions(
model: model ?? this.model,
temperature: temperature ?? this.temperature,
topP: topP ?? this.topP,
maxTokens: maxTokens ?? this.maxTokens,
safeMode: safeMode ?? this.safeMode,
safePrompt: safePrompt ?? this.safePrompt,
randomSeed: randomSeed ?? this.randomSeed,
);
}
@@ -27,15 +27,15 @@ void main() {
temperature: 0.1,
topP: 0.5,
maxTokens: 10,
safeMode: true,
safePrompt: true,
randomSeed: 1234,
);

expect(chatModel.defaultOptions.model, 'foo');
expect(chatModel.defaultOptions.temperature, 0.1);
expect(chatModel.defaultOptions.topP, 0.5);
expect(chatModel.defaultOptions.maxTokens, 10);
expect(chatModel.defaultOptions.safeMode, true);
expect(chatModel.defaultOptions.safePrompt, true);
expect(chatModel.defaultOptions.randomSeed, 1234);
});

@@ -18,10 +18,10 @@ class ChatCompletionRequest with _$ChatCompletionRequest {
/// ID of the model to use. You can use the [List Available Models](https://docs.mistral.ai/api#operation/listModels) API to see all of your available models, or see our [Model overview](https://docs.mistral.ai/models) for model descriptions.
@_ChatCompletionModelConverter() required ChatCompletionModel model,

/// The prompt(s) to generate completions for, encoded as a list of dict with role and content.
/// The prompt(s) to generate completions for, encoded as a list of dict with role and content. The first prompt role should be `user` or `system`.
required List<ChatCompletionMessage> messages,

/// What sampling temperature to use, between 0.0 and 2.0. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic.
/// What sampling temperature to use, between 0.0 and 1.0. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic.
///
/// We generally recommend altering this or `top_p` but not both.
@JsonKey(includeIfNull: false) @Default(0.7) double? temperature,
@@ -40,9 +40,9 @@ class ChatCompletionRequest with _$ChatCompletionRequest {
@JsonKey(includeIfNull: false) @Default(false) bool? stream,

/// Whether to inject a safety prompt before all conversations.
@JsonKey(name: 'safe_mode', includeIfNull: false)
@JsonKey(name: 'safe_prompt', includeIfNull: false)
@Default(false)
bool? safeMode,
bool? safePrompt,

/// The seed to use for random sampling. If set, different calls will generate deterministic results.
@JsonKey(name: 'random_seed', includeIfNull: false) int? randomSeed,
@@ -60,14 +60,14 @@ class ChatCompletionRequest with _$ChatCompletionRequest {
'top_p',
'max_tokens',
'stream',
'safe_mode',
'safe_prompt',
'random_seed'
];

/// Validation constants
static const temperatureDefaultValue = 0.7;
static const temperatureMinValue = 0.0;
static const temperatureMaxValue = 2.0;
static const temperatureMaxValue = 1.0;
static const topPDefaultValue = 1.0;
static const topPMinValue = 0.0;
static const topPMaxValue = 1.0;
@@ -102,7 +102,7 @@ class ChatCompletionRequest with _$ChatCompletionRequest {
'top_p': topP,
'max_tokens': maxTokens,
'stream': stream,
'safe_mode': safeMode,
'safe_prompt': safePrompt,
'random_seed': randomSeed,
};
}
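Since the validation constants above are public statics on the request class, callers can clamp user input to the new range before building a request. A small sketch relying only on the constants shown in this hunk:

```dart
import 'package:mistralai_dart/mistralai_dart.dart';

/// Clamps a caller-supplied temperature into the supported range.
double clampTemperature(double value) {
  return value
      .clamp(
        ChatCompletionRequest.temperatureMinValue, // 0.0
        ChatCompletionRequest.temperatureMaxValue, // 1.0 after this change
      )
      .toDouble();
}

void main() {
  print(clampTemperature(1.4)); // prints 1.0: values above the new maximum are capped
}
```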
@@ -25,11 +25,11 @@ mixin _$ChatCompletionRequest {
@_ChatCompletionModelConverter()
ChatCompletionModel get model => throw _privateConstructorUsedError;

/// The prompt(s) to generate completions for, encoded as a list of dict with role and content.
/// The prompt(s) to generate completions for, encoded as a list of dict with role and content. The first prompt role should be `user` or `system`.
List<ChatCompletionMessage> get messages =>
throw _privateConstructorUsedError;

/// What sampling temperature to use, between 0.0 and 2.0. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic.
/// What sampling temperature to use, between 0.0 and 1.0. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic.
///
/// We generally recommend altering this or `top_p` but not both.
@JsonKey(includeIfNull: false)
@@ -52,8 +52,8 @@ mixin _$ChatCompletionRequest {
bool? get stream => throw _privateConstructorUsedError;

/// Whether to inject a safety prompt before all conversations.
@JsonKey(name: 'safe_mode', includeIfNull: false)
bool? get safeMode => throw _privateConstructorUsedError;
@JsonKey(name: 'safe_prompt', includeIfNull: false)
bool? get safePrompt => throw _privateConstructorUsedError;

/// The seed to use for random sampling. If set, different calls will generate deterministic results.
@JsonKey(name: 'random_seed', includeIfNull: false)
@@ -78,7 +78,7 @@ abstract class $ChatCompletionRequestCopyWith<$Res> {
@JsonKey(name: 'top_p', includeIfNull: false) double? topP,
@JsonKey(name: 'max_tokens', includeIfNull: false) int? maxTokens,
@JsonKey(includeIfNull: false) bool? stream,
@JsonKey(name: 'safe_mode', includeIfNull: false) bool? safeMode,
@JsonKey(name: 'safe_prompt', includeIfNull: false) bool? safePrompt,
@JsonKey(name: 'random_seed', includeIfNull: false) int? randomSeed});

$ChatCompletionModelCopyWith<$Res> get model;
@@ -104,7 +104,7 @@ class _$ChatCompletionRequestCopyWithImpl<$Res,
Object? topP = freezed,
Object? maxTokens = freezed,
Object? stream = freezed,
Object? safeMode = freezed,
Object? safePrompt = freezed,
Object? randomSeed = freezed,
}) {
return _then(_value.copyWith(
@@ -132,9 +132,9 @@ class _$ChatCompletionRequestCopyWithImpl<$Res,
? _value.stream
: stream // ignore: cast_nullable_to_non_nullable
as bool?,
safeMode: freezed == safeMode
? _value.safeMode
: safeMode // ignore: cast_nullable_to_non_nullable
safePrompt: freezed == safePrompt
? _value.safePrompt
: safePrompt // ignore: cast_nullable_to_non_nullable
as bool?,
randomSeed: freezed == randomSeed
? _value.randomSeed
@@ -168,7 +168,7 @@ abstract class _$$ChatCompletionRequestImplCopyWith<$Res>
@JsonKey(name: 'top_p', includeIfNull: false) double? topP,
@JsonKey(name: 'max_tokens', includeIfNull: false) int? maxTokens,
@JsonKey(includeIfNull: false) bool? stream,
@JsonKey(name: 'safe_mode', includeIfNull: false) bool? safeMode,
@JsonKey(name: 'safe_prompt', includeIfNull: false) bool? safePrompt,
@JsonKey(name: 'random_seed', includeIfNull: false) int? randomSeed});

@override
@@ -193,7 +193,7 @@ class __$$ChatCompletionRequestImplCopyWithImpl<$Res>
Object? topP = freezed,
Object? maxTokens = freezed,
Object? stream = freezed,
Object? safeMode = freezed,
Object? safePrompt = freezed,
Object? randomSeed = freezed,
}) {
return _then(_$ChatCompletionRequestImpl(
@@ -221,9 +221,9 @@ class __$$ChatCompletionRequestImplCopyWithImpl<$Res>
? _value.stream
: stream // ignore: cast_nullable_to_non_nullable
as bool?,
safeMode: freezed == safeMode
? _value.safeMode
: safeMode // ignore: cast_nullable_to_non_nullable
safePrompt: freezed == safePrompt
? _value.safePrompt
: safePrompt // ignore: cast_nullable_to_non_nullable
as bool?,
randomSeed: freezed == randomSeed
? _value.randomSeed
@@ -243,7 +243,8 @@ class _$ChatCompletionRequestImpl extends _ChatCompletionRequest {
@JsonKey(name: 'top_p', includeIfNull: false) this.topP = 1.0,
@JsonKey(name: 'max_tokens', includeIfNull: false) this.maxTokens,
@JsonKey(includeIfNull: false) this.stream = false,
@JsonKey(name: 'safe_mode', includeIfNull: false) this.safeMode = false,
@JsonKey(name: 'safe_prompt', includeIfNull: false)
this.safePrompt = false,
@JsonKey(name: 'random_seed', includeIfNull: false) this.randomSeed})
: _messages = messages,
super._();
@@ -256,18 +257,18 @@ class _$ChatCompletionRequestImpl extends _ChatCompletionRequest {
@_ChatCompletionModelConverter()
final ChatCompletionModel model;

/// The prompt(s) to generate completions for, encoded as a list of dict with role and content.
/// The prompt(s) to generate completions for, encoded as a list of dict with role and content. The first prompt role should be `user` or `system`.
final List<ChatCompletionMessage> _messages;

/// The prompt(s) to generate completions for, encoded as a list of dict with role and content.
/// The prompt(s) to generate completions for, encoded as a list of dict with role and content. The first prompt role should be `user` or `system`.
@override
List<ChatCompletionMessage> get messages {
if (_messages is EqualUnmodifiableListView) return _messages;
// ignore: implicit_dynamic_type
return EqualUnmodifiableListView(_messages);
}

/// What sampling temperature to use, between 0.0 and 2.0. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic.
/// What sampling temperature to use, between 0.0 and 1.0. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic.
///
/// We generally recommend altering this or `top_p` but not both.
@override
@@ -295,8 +296,8 @@ class _$ChatCompletionRequestImpl extends _ChatCompletionRequest {

/// Whether to inject a safety prompt before all conversations.
@override
@JsonKey(name: 'safe_mode', includeIfNull: false)
final bool? safeMode;
@JsonKey(name: 'safe_prompt', includeIfNull: false)
final bool? safePrompt;

/// The seed to use for random sampling. If set, different calls will generate deterministic results.
@override
@@ -305,7 +306,7 @@ class _$ChatCompletionRequestImpl extends _ChatCompletionRequest {

@override
String toString() {
return 'ChatCompletionRequest(model: $model, messages: $messages, temperature: $temperature, topP: $topP, maxTokens: $maxTokens, stream: $stream, safeMode: $safeMode, randomSeed: $randomSeed)';
return 'ChatCompletionRequest(model: $model, messages: $messages, temperature: $temperature, topP: $topP, maxTokens: $maxTokens, stream: $stream, safePrompt: $safePrompt, randomSeed: $randomSeed)';
}

@override
@@ -321,8 +322,8 @@ class _$ChatCompletionRequestImpl extends _ChatCompletionRequest {
(identical(other.maxTokens, maxTokens) ||
other.maxTokens == maxTokens) &&
(identical(other.stream, stream) || other.stream == stream) &&
(identical(other.safeMode, safeMode) ||
other.safeMode == safeMode) &&
(identical(other.safePrompt, safePrompt) ||
other.safePrompt == safePrompt) &&
(identical(other.randomSeed, randomSeed) ||
other.randomSeed == randomSeed));
}
@@ -337,7 +338,7 @@ class _$ChatCompletionRequestImpl extends _ChatCompletionRequest {
topP,
maxTokens,
stream,
safeMode,
safePrompt,
randomSeed);

@JsonKey(ignore: true)
@@ -364,7 +365,8 @@ abstract class _ChatCompletionRequest extends ChatCompletionRequest {
@JsonKey(name: 'top_p', includeIfNull: false) final double? topP,
@JsonKey(name: 'max_tokens', includeIfNull: false) final int? maxTokens,
@JsonKey(includeIfNull: false) final bool? stream,
@JsonKey(name: 'safe_mode', includeIfNull: false) final bool? safeMode,
@JsonKey(name: 'safe_prompt', includeIfNull: false)
final bool? safePrompt,
@JsonKey(name: 'random_seed', includeIfNull: false)
final int? randomSeed}) = _$ChatCompletionRequestImpl;
const _ChatCompletionRequest._() : super._();
@@ -379,11 +381,11 @@ abstract class _ChatCompletionRequest extends ChatCompletionRequest {
ChatCompletionModel get model;
@override

/// The prompt(s) to generate completions for, encoded as a list of dict with role and content.
/// The prompt(s) to generate completions for, encoded as a list of dict with role and content. The first prompt role should be `user` or `system`.
List<ChatCompletionMessage> get messages;
@override

/// What sampling temperature to use, between 0.0 and 2.0. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic.
/// What sampling temperature to use, between 0.0 and 1.0. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic.
///
/// We generally recommend altering this or `top_p` but not both.
@JsonKey(includeIfNull: false)
Expand All @@ -410,8 +412,8 @@ abstract class _ChatCompletionRequest extends ChatCompletionRequest {
@override

/// Whether to inject a safety prompt before all conversations.
@JsonKey(name: 'safe_mode', includeIfNull: false)
bool? get safeMode;
@JsonKey(name: 'safe_prompt', includeIfNull: false)
bool? get safePrompt;
@override

/// The seed to use for random sampling. If set, different calls will generate deterministic results.

Some generated files are not rendered by default.

9 changes: 5 additions & 4 deletions packages/mistralai_dart/oas/mistral_openapi_curated.yaml
@@ -102,19 +102,20 @@ components:
messages:
description: >
The prompt(s) to generate completions for, encoded as a list of dict
with role and content.
with role and content. The first prompt role should be `user` or
`system`.
type: array
items:
$ref: '#/components/schemas/ChatCompletionMessage'
temperature:
type: number
minimum: 0
maximum: 2
maximum: 1
default: 0.7
example: 0.7
nullable: true
description: >
What sampling temperature to use, between 0.0 and 2.0. Higher values
What sampling temperature to use, between 0.0 and 1.0. Higher values
like 0.8 will make the output more random, while lower values like
0.2 will make it more focused and deterministic.
@@ -156,7 +157,7 @@ components:
stream terminated by a data: [DONE] message. Otherwise, the server
will hold the request open until the timeout or until completion,
with the response containing the full result as JSON.
safe_mode:
safe_prompt:
type: boolean
default: false
nullable: true
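To illustrate the wire format after this change, here is a hedged sketch against the low-level `mistralai_dart` client: `safePrompt` serializes as `safe_prompt`, and `temperature` must stay between 0.0 and 1.0. The `ChatCompletionModel.model(ChatCompletionModels.mistralTiny)` and `ChatCompletionMessageRole.user` constructor names are assumptions about the generated API rather than part of this diff.

```dart
import 'package:mistralai_dart/mistralai_dart.dart';

void main() {
  final request = ChatCompletionRequest(
    model: ChatCompletionModel.model(ChatCompletionModels.mistralTiny),
    messages: [
      ChatCompletionMessage(
        role: ChatCompletionMessageRole.user,
        content: 'Hello!',
      ),
    ],
    temperature: 0.7, // must be between 0.0 and 1.0 after this change
    safePrompt: true, // sent on the wire as `safe_prompt`
  );
  // The JSON payload now contains `safe_prompt` instead of `safe_mode`.
  print(request.toJson());
}
```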
