Skip to content

Commit

Permalink
修复openai调用多次function会发生错误
Browse files Browse the repository at this point in the history
  • Loading branch information
hejianjun committed Jan 29, 2024
1 parent 5fdb97b commit 462efdb
Showing 1 changed file with 39 additions and 27 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -49,24 +49,34 @@ public OpenAIEventSourceListener(SseEmitter sseEmitter, List<Message> messages,
}

public static List<ToolCalls> mergeToolCallsLists(List<ToolCalls> list1, List<ToolCalls> list2) {
List<ToolCalls> mergedList = new ArrayList<>();
int size = Math.max(list1.size(), list2.size());

for (int i = 0; i < size; i++) {
ToolCalls item1 = i < list1.size() ? list1.get(i) : null;
ToolCalls item2 = i < list2.size() ? list2.get(i) : null;

mergedList.add(mergeToolCalls(item1, item2));
List<ToolCalls> mergedList = new ArrayList<>(list1);
if (list2.isEmpty()) {
return mergedList;
}
ToolCalls item2 = list2.get(0);
boolean isMerged = false;
// 反向遍历
for (int i = list1.size() - 1; i >= 0; i--) {
ToolCalls item1 = list1.get(i);
if (item2.getId() == null || Objects.equals(item1.getId(), item2.getId())) {
mergedList.set(i, mergeToolCalls(item1, item2));
isMerged = true;
break;
}
}
if (!isMerged) {
// 如果 list2 中的对象与 list1 中的任何对象都不匹配,则作为新对象添加
mergedList.add(item2);
}

return mergedList;
}

private static ToolCalls mergeToolCalls(ToolCalls tc1, ToolCalls tc2) {
if (tc1 == null) return tc2;
if (tc2 == null) return tc1;

String id = mergeStrings(tc1.getId(), tc2.getId());
// 相同的逻辑,只是当 id 为 null 时进行合并
String id = tc1.getId() != null ? tc1.getId() : tc2.getId();
String type = mergeStrings(tc1.getType(), tc2.getType());
ToolCallFunction function = mergeToolCallFunctions(tc1.getFunction(), tc2.getFunction());

Expand Down Expand Up @@ -122,25 +132,27 @@ public void onEvent(EventSource eventSource, String id, String type, String data
messages.add(Message.builder()
.toolCalls(toolCalls)
.role(BaseMessage.Role.ASSISTANT).build());
for (ToolCalls toolCall : toolCalls) {
String callId = toolCall.getId();

ToolCallFunction function = toolCall.getFunction();
if (function != null && Objects.nonNull(function.getArguments())) {
String functionName = function.getName();
JSONObject arguments = JSONObject.parse(function.getArguments());
if ("get_table_columns".equals(functionName)) {
Chat2DBContext.putContext(connectInfo);
MetaData metaSchema = Chat2DBContext.getMetaData();
String ddl = metaSchema.tableDDL(Chat2DBContext.getConnection(), connectInfo.getDatabaseName(), connectInfo.getSchemaName(), arguments.getString("table_name"));
messages.add(Message.builder().role(BaseMessage.Role.TOOL)
.toolCallId(callId)
.name(functionName)
.content(ddl)
.build());
Chat2DBContext.removeContext();
Chat2DBContext.putContext(connectInfo);
try {
for (ToolCalls toolCall : toolCalls) {
String callId = toolCall.getId();
ToolCallFunction function = toolCall.getFunction();
if (function != null && Objects.nonNull(function.getArguments())) {
String functionName = function.getName();
JSONObject arguments = JSONObject.parse(function.getArguments());
if ("get_table_columns".equals(functionName)) {
MetaData metaSchema = Chat2DBContext.getMetaData();
String ddl = metaSchema.tableDDL(Chat2DBContext.getConnection(), connectInfo.getDatabaseName(), connectInfo.getSchemaName(), arguments.getString("table_name"));
messages.add(Message.builder().role(BaseMessage.Role.TOOL)
.toolCallId(callId)
.name(functionName)
.content(ddl)
.build());
}
}
}
} finally {
Chat2DBContext.removeContext();
}
OpenAIClient.getInstance().streamChatCompletion(messages, this);
toolCalls.clear();
Expand Down

0 comments on commit 462efdb

Please sign in to comment.