Skip to content

Commit

Permalink
fix issue #8
Browse files Browse the repository at this point in the history
  • Loading branch information
forestwanglin committed Feb 3, 2024
1 parent 30c718f commit bfdb98d
Show file tree
Hide file tree
Showing 12 changed files with 264 additions and 122 deletions.
16 changes: 8 additions & 8 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

OpenAi API for Java. Including all API from OpenAI official document, and the counting token method.

[![GitHub version](https://img.shields.io/static/v1?label=version&message=v3.8.20240202&color=blue)](https://github.com/forestwanglin/openai-java)
[![GitHub version](https://img.shields.io/static/v1?label=version&message=v3.8.20240203&color=blue)](https://github.com/forestwanglin/openai-java)
[![License](https://img.shields.io/static/v1?label=license&message=MIT&color=orange)](https://github.com/forestwanglin/openai-java/blob/main/LICENSE)
[![License](https://img.shields.io/static/v1?label=license&message=MIT&color=orange)](https://github.com/forestwanglin/openai-java/blob/main/LICENSE)

Expand Down Expand Up @@ -56,7 +56,7 @@ OpenAi API for Java. Including all API from OpenAI official document, and the co
<dependency>
<groupId>xyz.felh</groupId>
<artifactId>service</artifactId>
<version>3.8.20240202</version>
<version>3.8.20240203</version>
</dependency>
```

Expand All @@ -65,22 +65,22 @@ OpenAi API for Java. Including all API from OpenAI official document, and the co
<dependency>
<groupId>xyz.felh</groupId>
<artifactId>jtokkit</artifactId>
<version>3.8.20240202</version>
<version>3.8.20240203</version>
</dependency>
```

### Gradle

```yaml
implementation group: 'xyz.felh', name: 'service', version: '3.8.20240202'
implementation group: 'xyz.felh', name: 'jtokkit', version: '3.8.20240202'
implementation group: 'xyz.felh', name: 'service', version: '3.8.20240203'
implementation group: 'xyz.felh', name: 'jtokkit', version: '3.8.20240203'
```

### sbt

```javascript
libraryDependencies += "xyz.felh" % "service" % "3.8.20240202"
libraryDependencies += "xyz.felh" % "jtokkit" % "3.8.20240202"
libraryDependencies += "xyz.felh" % "service" % "3.8.20240203"
libraryDependencies += "xyz.felh" % "jtokkit" % "3.8.20240203"
```

## Example (Spring Boot 3)
Expand All @@ -92,7 +92,7 @@ libraryDependencies += "xyz.felh" % "jtokkit" % "3.8.20240202"
<dependency>
<groupId>xyz.felh</groupId>
<artifactId>service</artifactId>
<version>3.8.20240202</version>
<version>3.8.20240203</version>
</dependency>
```

Expand Down
2 changes: 1 addition & 1 deletion core/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
<parent>
<groupId>xyz.felh</groupId>
<artifactId>openai-java</artifactId>
<version>3.8.20240202</version>
<version>3.8.20240203</version>
</parent>

<artifactId>core</artifactId>
Expand Down
4 changes: 2 additions & 2 deletions jtokkit/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
<parent>
<groupId>xyz.felh</groupId>
<artifactId>openai-java</artifactId>
<version>3.8.20240202</version>
<version>3.8.20240203</version>
</parent>

<artifactId>jtokkit</artifactId>
Expand All @@ -28,7 +28,7 @@
<dependency>
<groupId>xyz.felh</groupId>
<artifactId>core</artifactId>
<version>3.8.20240202</version>
<version>3.8.20240203</version>
</dependency>

<dependency>
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
package xyz.felh.openai.jtokkit.utils;

import com.alibaba.fastjson2.JSONArray;
import com.alibaba.fastjson2.JSONObject;

import java.util.ArrayList;
import java.util.List;
import java.util.stream.Collectors;

public class ArgumentFormat {

public static String formatArguments(String arguments) {
List<String> lines = new ArrayList<>();
lines.add("{");
JSONObject jsonObject = JSONObject.parseObject(arguments);
List<String> properties = new ArrayList<>();
for (String fieldName : jsonObject.keySet()) {
properties.add(String.format("\"%s\":%s", fieldName, formatValue(jsonObject.get(fieldName))));
}
lines.add(String.join(",\n", properties));
lines.add("}");
return String.join("\n", lines);
}

private static String formatValue(Object value) {
if (value instanceof String || value instanceof Number) {
return String.format("\"%s\"", value);
}
if (value instanceof JSONArray array) {
String result = "[";
if (!array.isEmpty()) {
result += array.stream().map(it -> {
if (it instanceof Number || it instanceof String) {
return String.format("\"%s\"", it);
} else {
return "\"\"";
}
}).collect(Collectors.joining(","));
}
result += "]";
return result;
}
return "\"\"";
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@ public static String formatFunctionDefinitions(List<Tool> tools) {
lines.add("");
}
lines.add("} // namespace functions");
// log.info("\n" + String.join("\n", lines));
return String.join("\n", lines);
}

Expand Down
119 changes: 66 additions & 53 deletions jtokkit/src/main/java/xyz/felh/openai/jtokkit/utils/TikTokenUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@


import com.alibaba.fastjson2.JSONObject;
import com.alibaba.fastjson2.JSONWriter;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.SerializationUtils;
import xyz.felh.openai.chat.ChatCompletion;
Expand All @@ -27,6 +26,9 @@
import java.io.IOException;
import java.net.URL;
import java.util.*;
import java.util.function.Consumer;
import java.util.function.ToIntBiFunction;
import java.util.function.ToIntFunction;

@Slf4j
public class TikTokenUtils {
Expand Down Expand Up @@ -243,6 +245,12 @@ public static int estimateTokensInMessages(String modelName, List<ChatMessage> m

public static int estimateTokensInMessages(String modelName, List<ChatMessage> messages, List<Tool> tools) {
int tokens = 0;
int toolMessageSize = (int) messages.stream().filter(it -> it.getRole() == ChatMessageRole.TOOL).count();
// size = 1, equal
// size = 2 - 5; 3 - 7, 4 - 9
if (toolMessageSize > 1) {
tokens += toolMessageSize * 2 + 1;
}
boolean paddedSystem = false;
for (ChatMessage message : messages) {
ChatMessage msg = SerializationUtils.clone(message);
Expand All @@ -252,7 +260,7 @@ public static int estimateTokensInMessages(String modelName, List<ChatMessage> m
}
paddedSystem = true;
}
tokens += estimateTokensInMessage(modelName, msg);
tokens += estimateTokensInMessage(modelName, msg, toolMessageSize);
}
// Each completion (vs message) seems to carry a 3-token overhead
tokens += 3;
Expand All @@ -267,67 +275,76 @@ public static int estimateTokensInMessages(String modelName, List<ChatMessage> m
* @param message 消息体
* @return tokens数量
*/
public static int estimateTokensInMessage(String modelName, ChatMessage message) {
public static int estimateTokensInMessage(String modelName, ChatMessage message, int toolMessageSize) {
Encoding encoding = getEncoding(modelName);
int tokens = 0;
// role
tokens += tokens(encoding, message.getRole().value());

// content
if (message.getContent() instanceof String) {
tokens += tokens(encoding, message.getContent().toString());
if (message.getRole() == ChatMessageRole.TOOL) {
if (toolMessageSize == 1) {
tokens += tokens(encoding, message.getContent().toString());
} else {
tokens += tokens(encoding, ToolContentFormat.format(message.getContent()));
}
} else {
List<ChatMessage.ContentItem> items = ListUtils.castList(message.getContent(), ChatMessage.ContentItem.class);
if (Preconditions.isNotBlank(items)) {
for (ChatMessage.ContentItem item : items) {
if (item.getType() == ChatMessage.ContentType.TEXT) {
// 不需要计算type
tokens += tokens(encoding, item.getText());
} else if (item.getType() == ChatMessage.ContentType.IMAGE_URL) {
ChatMessage.ImageUrl imageUrl = item.getImageUrl();
// https://openai.com/pricing
if (imageUrl.getDetail() == ChatMessage.ImageUrlDetail.LOW) {
tokens += 85;
} else if (imageUrl.getDetail() == ChatMessage.ImageUrlDetail.HIGH) {
tokens += 85;
int width = 0;
int height = 0;
if (imageUrl.getUrl().startsWith("f")) {
// base64
Base64.Decoder decoder = Base64.getDecoder();
try {
String b64 = imageUrl.getUrl();
b64 = b64.substring(b64.indexOf(";base64,") + 8);
b64 = b64.substring(0, b64.length() - 1);
byte[] bytes = decoder.decode(b64);
ByteArrayInputStream inputStream = new ByteArrayInputStream(bytes);
BufferedImage bi = ImageIO.read(inputStream);
if (Preconditions.isNotBlank(inputStream)) {
inputStream.close();
if (message.getContent() instanceof String) {
tokens += tokens(encoding, message.getContent().toString());
} else {
List<ChatMessage.ContentItem> items = ListUtils.castList(message.getContent(), ChatMessage.ContentItem.class);
if (Preconditions.isNotBlank(items)) {
for (ChatMessage.ContentItem item : items) {
if (item.getType() == ChatMessage.ContentType.TEXT) {
// 不需要计算type
tokens += tokens(encoding, item.getText());
} else if (item.getType() == ChatMessage.ContentType.IMAGE_URL) {
ChatMessage.ImageUrl imageUrl = item.getImageUrl();
// https://openai.com/pricing
if (imageUrl.getDetail() == ChatMessage.ImageUrlDetail.LOW) {
tokens += 85;
} else if (imageUrl.getDetail() == ChatMessage.ImageUrlDetail.HIGH) {
tokens += 85;
int width = 0;
int height = 0;
if (imageUrl.getUrl().startsWith("f")) {
// base64
Base64.Decoder decoder = Base64.getDecoder();
try {
String b64 = imageUrl.getUrl();
b64 = b64.substring(b64.indexOf(";base64,") + 8);
b64 = b64.substring(0, b64.length() - 1);
byte[] bytes = decoder.decode(b64);
ByteArrayInputStream inputStream = new ByteArrayInputStream(bytes);
BufferedImage bi = ImageIO.read(inputStream);
if (Preconditions.isNotBlank(inputStream)) {
inputStream.close();
}
width = bi.getWidth();
height = bi.getHeight();
} catch (Exception e) {
log.error("image to base64 error", e);
}
} else {
// image url
try {
BufferedImage bi = ImageIO.read(new URL(imageUrl.getUrl()));
width = bi.getWidth();
height = bi.getHeight();
} catch (IOException e) {
throw new RuntimeException(e);
}
width = bi.getWidth();
height = bi.getHeight();
} catch (Exception e) {
log.error("image to base64 error", e);
}
} else {
// image url
try {
BufferedImage bi = ImageIO.read(new URL(imageUrl.getUrl()));
width = bi.getWidth();
height = bi.getHeight();
} catch (IOException e) {
throw new RuntimeException(e);
}
// 1 per 512x512
int tiles = (int) Math.ceil(width / 512.0) * (int) Math.ceil(height / 512.0);
tokens += 170 * tiles;
}
// 1 per 512x512
int tiles = (int) Math.ceil(width / 512.0) * (int) Math.ceil(height / 512.0);
tokens += 170 * tiles;
}
}
}
}
}

// name 如果是 tool的时候不计算 name
if (Preconditions.isNotBlank(message.getName()) && message.getRole() != ChatMessageRole.TOOL) {
tokens += tokens(encoding, message.getName()) + 1; // +1 for the name
Expand All @@ -341,11 +358,7 @@ public static int estimateTokensInMessage(String modelName, ChatMessage message)
tokens += tokens(encoding, toolCall.getFunction().getName());
}
if (Preconditions.isNotBlank(toolCall.getFunction().getArguments())) {
// 这个地方要特殊处理,按照标准print args,然后再计算tokens
String args = JSONObject.toJSONString(JSONObject.parseObject(toolCall.getFunction().getArguments()),
JSONWriter.Feature.PrettyFormat);
args = args.replaceAll("\\t", "");
tokens += tokens(encoding, args);
tokens += tokens(encoding, ArgumentFormat.formatArguments(toolCall.getFunction().getArguments()));
}
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
package xyz.felh.openai.jtokkit.utils;

import com.alibaba.fastjson2.JSONObject;

public class ToolContentFormat {

public static boolean isJSONString(String content) {
try {
JSONObject.parseObject(content);
return true;
} catch (Exception ignored) {
}
return false;
}

public static String format(Object content) {
try {
JSONObject.parseObject(content.toString());
return ArgumentFormat.formatArguments(content.toString());
} catch (Exception ex) {
// error
}
return content.toString();
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ public void testTokens() {
new ChatMessage(ChatMessageRole.USER, "Count 1 to 3"),
new ChatMessage(ChatMessageRole.ASSISTANT, "1,2, 3"),
new ChatMessage(ChatMessageRole.USER, " 中国和美国距离有多远?😄😄😄✅ "));
log.info("{}", TikTokenUtils.estimateTokensInMessage(ChatCompletion.Model.GPT_4.getName(), messages.get(0)));
log.info("{}", TikTokenUtils.estimateTokensInMessage(ChatCompletion.Model.GPT_4.getName(), messages.get(0), 0));
}

}
2 changes: 1 addition & 1 deletion pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

<groupId>xyz.felh</groupId>
<artifactId>openai-java</artifactId>
<version>3.8.20240202</version>
<version>3.8.20240203</version>
<packaging>pom</packaging>

<name>ChatGPT of OpenAI API for Java</name>
Expand Down
6 changes: 3 additions & 3 deletions service/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
<parent>
<groupId>xyz.felh</groupId>
<artifactId>openai-java</artifactId>
<version>3.8.20240202</version>
<version>3.8.20240203</version>
</parent>

<artifactId>service</artifactId>
Expand All @@ -31,7 +31,7 @@
<dependency>
<groupId>xyz.felh</groupId>
<artifactId>core</artifactId>
<version>3.8.20240202</version>
<version>3.8.20240203</version>
</dependency>

<dependency>
Expand Down Expand Up @@ -97,7 +97,7 @@
<dependency>
<groupId>xyz.felh</groupId>
<artifactId>jtokkit</artifactId>
<version>3.8.20240202</version>
<version>3.8.20240203</version>
<scope>test</scope>
</dependency>

Expand Down
Loading

0 comments on commit bfdb98d

Please sign in to comment.