Skip to content

Commit

Permalink
feat: Include usage stats when streaming with OpenAI and ChatOpenAI (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
davidmigloz committed May 9, 2024
1 parent e76dd70 commit 5e2b0ec
Show file tree
Hide file tree
Showing 9 changed files with 97 additions and 58 deletions.
5 changes: 4 additions & 1 deletion packages/langchain_core/lib/src/chat_models/types.dart
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,10 @@ class ChatResult extends LanguageModelResult<AIChatMessage> {
return ChatResult(
id: other.id,
output: output.concat(other.output),
finishReason: other.finishReason,
finishReason: finishReason != FinishReason.unspecified &&
other.finishReason == FinishReason.unspecified
? finishReason
: other.finishReason,
metadata: {
...metadata,
...other.metadata,
Expand Down
36 changes: 19 additions & 17 deletions packages/langchain_core/lib/src/language_models/types.dart
Original file line number Diff line number Diff line change
Expand Up @@ -134,23 +134,25 @@ class LanguageModelUsage {
/// Merges this usage with another by summing the values.
///
/// A field is `null` in the result only when it is `null` in BOTH
/// operands. Otherwise a missing value is treated as 0, so that a
/// partial usage report (e.g. a streaming chunk that only carries
/// totals in its final event) is not discarded when concatenated
/// with chunks that carry no usage at all.
LanguageModelUsage concat(final LanguageModelUsage other) {
  return LanguageModelUsage(
    promptTokens: promptTokens == null && other.promptTokens == null
        ? null
        : (promptTokens ?? 0) + (other.promptTokens ?? 0),
    promptBillableCharacters: promptBillableCharacters == null &&
            other.promptBillableCharacters == null
        ? null
        : (promptBillableCharacters ?? 0) +
            (other.promptBillableCharacters ?? 0),
    responseTokens: responseTokens == null && other.responseTokens == null
        ? null
        : (responseTokens ?? 0) + (other.responseTokens ?? 0),
    responseBillableCharacters: responseBillableCharacters == null &&
            other.responseBillableCharacters == null
        ? null
        : (responseBillableCharacters ?? 0) +
            (other.responseBillableCharacters ?? 0),
    totalTokens: totalTokens == null && other.totalTokens == null
        ? null
        : (totalTokens ?? 0) + (other.totalTokens ?? 0),
  );
}

Expand Down
5 changes: 4 additions & 1 deletion packages/langchain_core/lib/src/llms/types.dart
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,10 @@ class LLMResult extends LanguageModelResult<String> {
return LLMResult(
id: other.id,
output: output + other.output,
finishReason: other.finishReason,
finishReason: finishReason != FinishReason.unspecified &&
other.finishReason == FinishReason.unspecified
? finishReason
: other.finishReason,
metadata: {
...metadata,
...other.metadata,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,7 @@ class ChatOpenAI extends BaseChatModel<ChatOpenAIOptions> {
request: _createChatCompletionRequest(
input.toChatMessages(),
options: options,
stream: true,
),
)
.map(
Expand All @@ -272,6 +273,7 @@ class ChatOpenAI extends BaseChatModel<ChatOpenAIOptions> {
CreateChatCompletionRequest _createChatCompletionRequest(
final List<ChatMessage> messages, {
final ChatOpenAIOptions? options,
final bool stream = false,
}) {
final messagesDtos = messages.toChatCompletionMessages();
final toolsDtos = options?.tools?.toChatCompletionTool() ??
Expand Down Expand Up @@ -304,6 +306,8 @@ class ChatOpenAI extends BaseChatModel<ChatOpenAIOptions> {
temperature: options?.temperature ?? defaultOptions.temperature,
topP: options?.topP ?? defaultOptions.topP,
user: options?.user ?? defaultOptions.user,
streamOptions:
stream ? const ChatCompletionStreamOptions(includeUsage: true) : null,
);
}

Expand Down
35 changes: 18 additions & 17 deletions packages/langchain_openai/lib/src/chat_models/mappers.dart
Original file line number Diff line number Diff line change
Expand Up @@ -157,14 +157,14 @@ extension CreateChatCompletionResponseMapper on CreateChatCompletionResponse {
arguments: args,
);
}
}

/// Maps an OpenAI [CompletionUsage] DTO to a [LanguageModelUsage].
///
/// All fields are `null` when [usage] is `null` (the OpenAI API omits
/// usage on most streaming chunks), so callers can concatenate results
/// without losing the usage-bearing chunk's values.
LanguageModelUsage _mapUsage(final CompletionUsage? usage) {
  return LanguageModelUsage(
    promptTokens: usage?.promptTokens,
    responseTokens: usage?.completionTokens,
    totalTokens: usage?.totalTokens,
  );
}

extension ChatToolListMapper on List<ToolSpec> {
Expand Down Expand Up @@ -206,23 +206,24 @@ extension ChatToolChoiceMapper on ChatToolChoice {
extension CreateChatCompletionStreamResponseMapper
on CreateChatCompletionStreamResponse {
/// Converts a streaming chunk into a [ChatResult].
///
/// [id] is supplied by the caller (the id of the first chunk) so that
/// successive results can be concatenated under one id.
///
/// Uses `firstOrNull` and null-aware access throughout: when usage
/// reporting is enabled (`stream_options.include_usage`) the final
/// chunk sent by the OpenAI API carries an empty `choices` list and
/// only a `usage` payload.
ChatResult toChatResult(final String id) {
  final choice = choices.firstOrNull;
  final delta = choice?.delta;
  return ChatResult(
    id: id,
    output: AIChatMessage(
      content: delta?.content ?? '',
      toolCalls: delta?.toolCalls
              ?.map(_mapMessageToolCall)
              .toList(growable: false) ??
          const [],
    ),
    finishReason: _mapFinishReason(choice?.finishReason),
    metadata: {
      'created': created,
      // Nullable on stream chunks — only include when present so
      // concat() does not overwrite real values with nulls.
      if (model != null) 'model': model,
      if (systemFingerprint != null) 'system_fingerprint': systemFingerprint,
    },
    usage: _mapUsage(usage),
    streaming: true,
  );
}
Expand Down
25 changes: 21 additions & 4 deletions packages/langchain_openai/lib/src/llms/mappers.dart
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,35 @@ import 'package:openai_dart/openai_dart.dart';

extension CreateCompletionResponseMapper on CreateCompletionResponse {
List<LLMResult> toLLMResults({final bool streaming = false}) {
final metadata = {
'created': created,
'model': model,
if (systemFingerprint != null) 'system_fingerprint': systemFingerprint,
};
final totalUsage = _mapUsage(usage);
if (choices.isEmpty) {
return [
LLMResult(
id: '$id:0',
output: '',
finishReason: FinishReason.unspecified,
metadata: metadata,
usage: totalUsage,
streaming: streaming,
),
];
}

return choices
.mapIndexed(
(final index, final choice) => LLMResult(
id: '$id:$index',
output: choice.text,
finishReason: _mapFinishReason(choice.finishReason),
metadata: {
'created': created,
'model': model,
'system_fingerprint': systemFingerprint,
'logprobs': choice.logprobs?.toJson(),
...metadata,
if (choice.logprobs != null)
'logprobs': choice.logprobs?.toJson(),
},
usage: totalUsage,
streaming: streaming,
Expand Down
4 changes: 4 additions & 0 deletions packages/langchain_openai/lib/src/llms/openai.dart
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,7 @@ class OpenAI extends BaseLLM<OpenAIOptions> {
request: _createCompletionRequest(
[input.toString()],
options: options,
stream: true,
),
)
.map(
Expand All @@ -297,6 +298,7 @@ class OpenAI extends BaseLLM<OpenAIOptions> {
CreateCompletionRequest _createCompletionRequest(
final List<String> prompts, {
final OpenAIOptions? options,
final bool stream = false,
}) {
return CreateCompletionRequest(
model: CompletionModel.modelId(
Expand All @@ -320,6 +322,8 @@ class OpenAI extends BaseLLM<OpenAIOptions> {
temperature: options?.temperature ?? defaultOptions.temperature,
topP: options?.topP ?? defaultOptions.topP,
user: options?.user ?? defaultOptions.user,
streamOptions:
stream ? const ChatCompletionStreamOptions(includeUsage: true) : null,
);
}

Expand Down
28 changes: 16 additions & 12 deletions packages/langchain_openai/test/chat_models/chat_openai_test.dart
Original file line number Diff line number Diff line change
Expand Up @@ -241,29 +241,33 @@ void main() {
});

test('Test ChatOpenAI streaming', () async {
  final promptTemplate = ChatPromptTemplate.fromTemplates(const [
    (
      ChatMessageType.system,
      'You are a helpful assistant that replies only with numbers '
          'in order without any spaces or commas',
    ),
    (ChatMessageType.human, 'List the numbers from 1 to {max_num}'),
  ]);
  final chat = ChatOpenAI(apiKey: openaiApiKey);

  final chain = promptTemplate.pipe(chat);
  final stream = chain.stream({'max_num': '9'});

  // Accumulate raw ChatResult chunks (instead of piping through a
  // string parser) so the merged usage stats can be asserted too.
  ChatResult? result;
  int count = 0;
  await for (final ChatResult res in stream) {
    result = result?.concat(res) ?? res;
    count++;
  }
  expect(count, greaterThan(1));
  expect(
    result!.output.content.replaceAll(RegExp(r'[\s\n]'), ''),
    contains('123456789'),
  );
  // Usage is only populated when stream_options.include_usage is set;
  // a non-zero count proves the final usage chunk was merged in.
  expect(result.usage.promptTokens, greaterThan(0));
  expect(result.usage.responseTokens, greaterThan(0));
  expect(result.usage.totalTokens, greaterThan(0));
});

test('Test ChatOpenAI streaming with functions', () async {
Expand Down
13 changes: 7 additions & 6 deletions packages/langchain_openai/test/llms/openai_test.dart
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ library; // Uses dart:io
import 'dart:io';

import 'package:langchain_core/llms.dart';
import 'package:langchain_core/output_parsers.dart';
import 'package:langchain_core/prompts.dart';
import 'package:langchain_openai/langchain_openai.dart';
import 'package:test/test.dart';
Expand Down Expand Up @@ -122,20 +121,22 @@ void main() {
final promptTemplate = PromptTemplate.fromTemplate(
'List the numbers from 1 to {max_num} in order without any spaces or commas',
);
const stringOutputParser = StringOutputParser<LLMResult>();

final chain = promptTemplate.pipe(llm).pipe(stringOutputParser);
final chain = promptTemplate.pipe(llm);

final stream = chain.stream({'max_num': '9'});

String content = '';
LLMResult? result;
int count = 0;
await for (final res in stream) {
content += res;
result = result?.concat(res) ?? res;
count++;
}
expect(count, greaterThan(1));
expect(content, contains('123456789'));
expect(result!.output, contains('123456789'));
expect(result.usage.promptTokens, greaterThan(0));
expect(result.usage.responseTokens, greaterThan(0));
expect(result.usage.totalTokens, greaterThan(0));
});

test('Test response seed', () async {
Expand Down

0 comments on commit 5e2b0ec

Please sign in to comment.