refactor: Remove default model from the language model options (#498)
davidmigloz committed Jul 20, 2024
1 parent faa23ae commit 44363e4
Showing 33 changed files with 767 additions and 302 deletions.
packages/langchain_anthropic/lib/src/chat_models/chat_anthropic.dart
@@ -154,7 +154,8 @@ class ChatAnthropic extends BaseChatModel<ChatAnthropicOptions> {
     final Map<String, dynamic>? queryParams,
     final http.Client? client,
     super.defaultOptions = const ChatAnthropicOptions(
-      model: 'claude-3-5-sonnet-20240620',
+      model: defaultModel,
+      maxTokens: defaultMaxTokens,
     ),
     this.encoding = 'cl100k_base',
   }) : _client = a.AnthropicClient(
@@ -177,6 +178,12 @@ class ChatAnthropic extends BaseChatModel<ChatAnthropicOptions> {
   @override
   String get modelType => 'anthropic-chat';
 
+  /// The default model to use unless another is specified.
+  static const defaultModel = 'claude-3-5-sonnet-20240620';
+
+  /// The default max tokens to use unless another is specified.
+  static const defaultMaxTokens = 1024;
+
   @override
   Future<ChatResult> invoke(
     final PromptValue input, {
@@ -187,7 +194,6 @@ class ChatAnthropic extends BaseChatModel<ChatAnthropicOptions> {
         input.toChatMessages(),
         options: options,
         defaultOptions: defaultOptions,
-        throwNullModelError: throwNullModelError,
       ),
     );
     return completion.toChatResult();
@@ -205,7 +211,6 @@ class ChatAnthropic extends BaseChatModel<ChatAnthropicOptions> {
           options: options,
           defaultOptions: defaultOptions,
           stream: true,
-          throwNullModelError: throwNullModelError,
         ),
       )
       .transform(MessageStreamEventTransformer());
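With the hard-coded default gone from the options and exposed as static constants instead, omitting the model no longer throws: it falls back to `ChatAnthropic.defaultModel` and `ChatAnthropic.defaultMaxTokens`. A minimal usage sketch (the `apiKey` parameter name and import paths are assumptions, not shown in this diff):

```dart
import 'package:langchain_anthropic/langchain_anthropic.dart';
import 'package:langchain_core/prompts.dart';

Future<void> main() async {
  // No model specified anywhere: the request resolves to
  // ChatAnthropic.defaultModel ('claude-3-5-sonnet-20240620')
  // with ChatAnthropic.defaultMaxTokens (1024).
  final chat = ChatAnthropic(apiKey: 'ANTHROPIC_API_KEY'); // assumed param name
  final res = await chat.invoke(PromptValue.string('Hello!'));
  print(res.output.content);
}
```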
8 changes: 5 additions & 3 deletions packages/langchain_anthropic/lib/src/chat_models/mappers.dart
@@ -9,6 +9,7 @@ import 'package:langchain_core/language_models.dart';
 import 'package:langchain_core/tools.dart';
 import 'package:rxdart/rxdart.dart' show WhereNotNullExtension;
 
+import 'chat_anthropic.dart';
 import 'types.dart';
 
 /// Creates a [CreateMessageRequest] from the given input.
@@ -17,7 +18,6 @@ a.CreateMessageRequest createMessageRequest(
   required final ChatAnthropicOptions? options,
   required final ChatAnthropicOptions defaultOptions,
   final bool stream = false,
-  required Never Function() throwNullModelError,
 }) {
   final systemMsg = messages.firstOrNull is SystemChatMessage
       ? messages.firstOrNull?.contentAsString
@@ -31,10 +31,12 @@ a.CreateMessageRequest createMessageRequest(

   return a.CreateMessageRequest(
     model: a.Model.modelId(
-      options?.model ?? defaultOptions.model ?? throwNullModelError(),
+      options?.model ?? defaultOptions.model ?? ChatAnthropic.defaultModel,
     ),
     messages: messagesDtos,
-    maxTokens: options?.maxTokens ?? defaultOptions.maxTokens ?? 1024,
+    maxTokens: options?.maxTokens ??
+        defaultOptions.maxTokens ??
+        ChatAnthropic.defaultMaxTokens,
     stopSequences: options?.stopSequences ?? defaultOptions.stopSequences,
     system: systemMsg,
     temperature: options?.temperature ?? defaultOptions.temperature,
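The resolution chain in `createMessageRequest` now bottoms out in the package constant instead of an error. A sketch of the same three-level fallback, extracted into a hypothetical helper for illustration (`resolveModel` is not part of the package):

```dart
import 'package:langchain_anthropic/langchain_anthropic.dart';

/// Hypothetical helper mirroring the fallback above: per-call options win,
/// then the instance's defaultOptions, then the package-level constant.
String resolveModel(
  final ChatAnthropicOptions? options,
  final ChatAnthropicOptions defaultOptions,
) =>
    options?.model ?? defaultOptions.model ?? ChatAnthropic.defaultModel;

void main() {
  // No model set anywhere: falls back to ChatAnthropic.defaultModel.
  print(resolveModel(null, const ChatAnthropicOptions()));
  // Instance default only.
  print(resolveModel(
    null,
    const ChatAnthropicOptions(model: 'claude-3-haiku-20240307'),
  ));
  // A per-call option overrides the instance default.
  print(resolveModel(
    const ChatAnthropicOptions(model: 'claude-3-opus-20240229'),
    const ChatAnthropicOptions(model: 'claude-3-haiku-20240307'),
  ));
}
```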
61 changes: 45 additions & 16 deletions packages/langchain_anthropic/lib/src/chat_models/types.dart
@@ -1,14 +1,28 @@
+import 'package:collection/collection.dart';
 import 'package:langchain_core/chat_models.dart';
 import 'package:langchain_core/tools.dart';
+import 'package:meta/meta.dart';
 
 /// {@template chat_anthropic_options}
 /// Options to pass into the Anthropic Chat Model.
+///
+/// Available models:
+/// - `claude-3-5-sonnet-20240620`
+/// - `claude-3-haiku-20240307`
+/// - `claude-3-opus-20240229`
+/// - `claude-3-sonnet-20240229`
+/// - `claude-2.0`
+/// - `claude-2.1`
+///
+/// Mind that the list may be outdated.
+/// See https://docs.anthropic.com/en/docs/about-claude/models for the latest list.
 /// {@endtemplate}
+@immutable
 class ChatAnthropicOptions extends ChatModelOptions {
   /// {@macro chat_anthropic_options}
   const ChatAnthropicOptions({
-    this.model = 'claude-3-5-sonnet-20240620',
-    this.maxTokens = 1024,
+    super.model,
+    this.maxTokens,
     this.stopSequences,
     this.temperature,
     this.topK,
@@ -19,20 +33,6 @@ class ChatAnthropicOptions extends ChatModelOptions {
     super.concurrencyLimit,
   });
 
-  /// ID of the model to use (e.g. 'claude-3-5-sonnet-20240620').
-  ///
-  /// Available models:
-  /// - `claude-3-5-sonnet-20240620`
-  /// - `claude-3-haiku-20240307`
-  /// - `claude-3-opus-20240229`
-  /// - `claude-3-sonnet-20240229`
-  /// - `claude-2.0`
-  /// - `claude-2.1`
-  ///
-  /// Mind that the list may be outdated.
-  /// See https://docs.anthropic.com/en/docs/about-claude/models for the latest list.
-  final String? model;
-
   /// The maximum number of tokens to generate before stopping.
   ///
   /// Note that our models may stop _before_ reaching this maximum. This parameter
@@ -113,4 +113,33 @@ class ChatAnthropicOptions extends ChatModelOptions {
       concurrencyLimit: concurrencyLimit ?? this.concurrencyLimit,
     );
   }
+
+  @override
+  bool operator ==(covariant final ChatAnthropicOptions other) {
+    return model == other.model &&
+        maxTokens == other.maxTokens &&
+        const ListEquality<String>()
+            .equals(stopSequences, other.stopSequences) &&
+        temperature == other.temperature &&
+        topK == other.topK &&
+        topP == other.topP &&
+        userId == other.userId &&
+        const ListEquality<ToolSpec>().equals(tools, other.tools) &&
+        toolChoice == other.toolChoice &&
+        concurrencyLimit == other.concurrencyLimit;
+  }
+
+  @override
+  int get hashCode {
+    return model.hashCode ^
+        maxTokens.hashCode ^
+        const ListEquality<String>().hash(stopSequences) ^
+        temperature.hashCode ^
+        topK.hashCode ^
+        topP.hashCode ^
+        userId.hashCode ^
+        const ListEquality<ToolSpec>().hash(tools) ^
+        toolChoice.hashCode ^
+        concurrencyLimit.hashCode;
+  }
 }
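The new `==` and `hashCode` overrides give the options value semantics: two independently constructed instances with the same fields compare equal and can be used interchangeably as map keys. A quick sketch:

```dart
import 'package:langchain_anthropic/langchain_anthropic.dart';

void main() {
  final a = ChatAnthropicOptions(temperature: 0.2, maxTokens: 256);
  final b = ChatAnthropicOptions(temperature: 0.2, maxTokens: 256);
  assert(a == b); // field-by-field comparison, list fields included
  assert(a.hashCode == b.hashCode); // consistent with ==
}
```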
1 change: 1 addition & 0 deletions packages/langchain_anthropic/pubspec.yaml
@@ -21,6 +21,7 @@ dependencies:
   http: ^1.1.0
   langchain_core: 0.3.3
   langchain_tiktoken: ^1.0.1
+  meta: ^1.11.0
   rxdart: ^0.27.7
 
 dev_dependencies:
3 changes: 2 additions & 1 deletion packages/langchain_core/lib/src/chat_models/types.dart
@@ -10,9 +10,10 @@ import '../tools/base.dart';
 class ChatModelOptions extends LanguageModelOptions {
   /// {@macro chat_model_options}
   const ChatModelOptions({
-    super.concurrencyLimit,
+    super.model,
     this.tools,
     this.toolChoice,
+    super.concurrencyLimit,
   });
 
   /// A list of tools the model may call.
31 changes: 0 additions & 31 deletions packages/langchain_core/lib/src/language_models/base.dart
@@ -1,5 +1,3 @@
-import 'package:meta/meta.dart';
-
 import '../langchain/base.dart';
 import '../prompts/types.dart';
 import 'types.dart';
@@ -58,33 +56,4 @@ abstract class BaseLanguageModel<

   @override
   String toString() => modelType;
-
-  /// Throws an error if the model id is not specified.
-  @protected
-  Never throwNullModelError() {
-    throw ArgumentError('''
-Null model in $runtimeType.
-You need to specify the id of model to use either in `$runtimeType.defaultOptions`
-or in the options passed when invoking the model.
-Example:
-```
-// In defaultOptions
-final model = $runtimeType(
-  defaultOptions: ${runtimeType}Options(
-    model: 'model-id',
-  ),
-);
-// Or when invoking the model
-final res = await model.invoke(
-  prompt,
-  options: ${runtimeType}Options(
-    model: 'model-id',
-  ),
-);
-```
-''');
-  }
 }
5 changes: 5 additions & 0 deletions packages/langchain_core/lib/src/language_models/types.dart
@@ -10,8 +10,13 @@ import '../langchain/types.dart';
 abstract class LanguageModelOptions extends BaseLangChainOptions {
   /// {@macro language_model_options}
   const LanguageModelOptions({
+    this.model,
     super.concurrencyLimit,
   });
+
+  /// ID of the language model to use.
+  /// Check the provider's documentation for available models.
+  final String? model;
 }
 
 /// {@template language_model}
1 change: 1 addition & 0 deletions packages/langchain_core/lib/src/llms/types.dart
@@ -9,6 +9,7 @@ import '../language_models/types.dart';
 class LLMOptions extends LanguageModelOptions {
   /// {@macro llm_options}
   const LLMOptions({
+    super.model,
     super.concurrencyLimit,
   });
 }
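With `model` declared once on `LanguageModelOptions` and simply forwarded by subclasses like `LLMOptions`, provider-agnostic code can read the model id without downcasting to a concrete options type. A small illustrative sketch (the `describe` helper is hypothetical, and the model id is just an example string):

```dart
import 'package:langchain_core/language_models.dart';
import 'package:langchain_core/llms.dart';

/// Hypothetical helper: works for any provider's options subclass.
String describe(final LanguageModelOptions options) =>
    'model: ${options.model ?? '(provider default)'}';

void main() {
  print(describe(const LLMOptions(model: 'gpt-3.5-turbo-instruct')));
  print(describe(const LLMOptions())); // model: (provider default)
}
```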
@@ -154,7 +154,7 @@ class ChatFirebaseVertexAI extends BaseChatModel<ChatFirebaseVertexAIOptions> {
   /// - [ChatFirebaseVertexAI.location]
   ChatFirebaseVertexAI({
     super.defaultOptions = const ChatFirebaseVertexAIOptions(
-      model: 'gemini-1.5-flash',
+      model: defaultModel,
     ),
     this.app,
     this.appCheck,
@@ -188,15 +188,18 @@ class ChatFirebaseVertexAI extends BaseChatModel<ChatFirebaseVertexAIOptions> {
   /// A UUID generator.
   late final Uuid _uuid = const Uuid();
 
-  @override
-  String get modelType => 'chat-firebase-vertex-ai';
-
   /// The current model set in [_firebaseClient];
   String _currentModel;
 
   /// The current system instruction set in [_firebaseClient];
   String? _currentSystemInstruction;
 
+  @override
+  String get modelType => 'chat-firebase-vertex-ai';
+
+  /// The default model to use unless another is specified.
+  static const defaultModel = 'gemini-1.5-flash';
+
   @override
   Future<ChatResult> invoke(
     final PromptValue input, {
@@ -329,8 +332,7 @@ class ChatFirebaseVertexAI extends BaseChatModel<ChatFirebaseVertexAIOptions> {
     final List<ChatMessage> messages,
     final ChatFirebaseVertexAIOptions? options,
   ) {
-    final model =
-        options?.model ?? defaultOptions.model ?? throwNullModelError();
+    final model = options?.model ?? defaultOptions.model ?? defaultModel;
 
     final systemInstruction = messages.firstOrNull is SystemChatMessage
         ? messages.firstOrNull?.contentAsString
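The Firebase Vertex AI integration follows the same pattern: an unspecified model resolves to `ChatFirebaseVertexAI.defaultModel`, and a per-call option still takes precedence. A hedged sketch (assumes Firebase has been initialised beforehand; the import path is an assumption, and 'gemini-1.5-pro' is just an example id):

```dart
import 'package:langchain_core/prompts.dart';
import 'package:langchain_firebase/langchain_firebase.dart';

Future<void> main() async {
  // Falls back to ChatFirebaseVertexAI.defaultModel ('gemini-1.5-flash').
  final chat = ChatFirebaseVertexAI();
  // Override the model for a single call; the `_currentModel` field above
  // suggests the underlying client is switched when the model changes.
  final res = await chat.invoke(
    PromptValue.string('Hello!'),
    options: const ChatFirebaseVertexAIOptions(model: 'gemini-1.5-pro'),
  );
  print(res.output.content);
}
```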