Skip to content

Commit 5e133a0

Browse files
authored
Merge pull request #21 from yiyuan-he/llm-support-v1
Set Up LLM Inference Attributes Auto-Instrumentation Java v1
2 parents 749d025 + 0309a76 commit 5e133a0

File tree

4 files changed

+510
-30
lines changed

4 files changed

+510
-30
lines changed

instrumentation/aws-sdk/aws-sdk-1.11/library/src/main/java/io/opentelemetry/instrumentation/awssdk/v1_11/AwsExperimentalAttributes.java

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,23 @@ final class AwsExperimentalAttributes {
3030
stringKey("gen_ai.request.model");
3131
static final AttributeKey<String> AWS_BEDROCK_SYSTEM = stringKey("gen_ai.system");
3232

33+
static final AttributeKey<String> GEN_AI_REQUEST_MAX_TOKENS =
34+
stringKey("gen_ai.request.max_tokens");
35+
36+
static final AttributeKey<String> GEN_AI_REQUEST_TEMPERATURE =
37+
stringKey("gen_ai.request.temperature");
38+
39+
static final AttributeKey<String> GEN_AI_REQUEST_TOP_P = stringKey("gen_ai.request.top_p");
40+
41+
static final AttributeKey<String> GEN_AI_RESPONSE_FINISH_REASONS =
42+
stringKey("gen_ai.response.finish_reasons");
43+
44+
static final AttributeKey<String> GEN_AI_USAGE_INPUT_TOKENS =
45+
stringKey("gen_ai.usage.input_tokens");
46+
47+
static final AttributeKey<String> GEN_AI_USAGE_OUTPUT_TOKENS =
48+
stringKey("gen_ai.usage.output_tokens");
49+
3350
static final AttributeKey<String> AWS_STATE_MACHINE_ARN =
3451
stringKey("aws.stepfunctions.state_machine.arn");
3552

instrumentation/aws-sdk/aws-sdk-1.11/library/src/main/java/io/opentelemetry/instrumentation/awssdk/v1_11/AwsSdkExperimentalAttributesExtractor.java

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,12 @@
2525
import static io.opentelemetry.instrumentation.awssdk.v1_11.AwsExperimentalAttributes.AWS_STEP_FUNCTIONS_ACTIVITY_ARN;
2626
import static io.opentelemetry.instrumentation.awssdk.v1_11.AwsExperimentalAttributes.AWS_STREAM_NAME;
2727
import static io.opentelemetry.instrumentation.awssdk.v1_11.AwsExperimentalAttributes.AWS_TABLE_NAME;
28+
import static io.opentelemetry.instrumentation.awssdk.v1_11.AwsExperimentalAttributes.GEN_AI_REQUEST_MAX_TOKENS;
29+
import static io.opentelemetry.instrumentation.awssdk.v1_11.AwsExperimentalAttributes.GEN_AI_REQUEST_TEMPERATURE;
30+
import static io.opentelemetry.instrumentation.awssdk.v1_11.AwsExperimentalAttributes.GEN_AI_REQUEST_TOP_P;
31+
import static io.opentelemetry.instrumentation.awssdk.v1_11.AwsExperimentalAttributes.GEN_AI_RESPONSE_FINISH_REASONS;
32+
import static io.opentelemetry.instrumentation.awssdk.v1_11.AwsExperimentalAttributes.GEN_AI_USAGE_INPUT_TOKENS;
33+
import static io.opentelemetry.instrumentation.awssdk.v1_11.AwsExperimentalAttributes.GEN_AI_USAGE_OUTPUT_TOKENS;
2834

2935
import com.amazonaws.AmazonWebServiceResponse;
3036
import com.amazonaws.Request;
@@ -144,6 +150,14 @@ private static void bedrockOnStart(
144150
Function<Object, String> getter = RequestAccess::getModelId;
145151
String modelId = getter.apply(originalRequest);
146152
attributes.put(AWS_BEDROCK_RUNTIME_MODEL_ID, modelId);
153+
154+
setAttribute(
155+
attributes, GEN_AI_REQUEST_MAX_TOKENS, originalRequest, RequestAccess::getMaxTokens);
156+
setAttribute(
157+
attributes, GEN_AI_REQUEST_TEMPERATURE, originalRequest, RequestAccess::getTemperature);
158+
setAttribute(attributes, GEN_AI_REQUEST_TOP_P, originalRequest, RequestAccess::getTopP);
159+
setAttribute(
160+
attributes, GEN_AI_USAGE_INPUT_TOKENS, originalRequest, RequestAccess::getInputTokens);
147161
break;
148162
default:
149163
break;
@@ -173,6 +187,17 @@ private static void bedrockOnEnd(
173187
setAttribute(attributes, AWS_AGENT_ID, awsResp, RequestAccess::getAgentId);
174188
setAttribute(attributes, AWS_KNOWLEDGE_BASE_ID, awsResp, RequestAccess::getKnowledgeBaseId);
175189
break;
190+
case BEDROCK_RUNTIME_SERVICE:
191+
if (!Objects.equals(awsResp.getClass().getSimpleName(), "InvokeModelResult")) {
192+
break;
193+
}
194+
195+
setAttribute(attributes, GEN_AI_USAGE_INPUT_TOKENS, awsResp, RequestAccess::getInputTokens);
196+
setAttribute(
197+
attributes, GEN_AI_USAGE_OUTPUT_TOKENS, awsResp, RequestAccess::getOutputTokens);
198+
setAttribute(
199+
attributes, GEN_AI_RESPONSE_FINISH_REASONS, awsResp, RequestAccess::getFinishReasons);
200+
break;
176201
default:
177202
break;
178203
}

instrumentation/aws-sdk/aws-sdk-1.11/library/src/main/java/io/opentelemetry/instrumentation/awssdk/v1_11/RequestAccess.java

Lines changed: 251 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,17 @@
55

66
package io.opentelemetry.instrumentation.awssdk.v1_11;
77

8+
import com.fasterxml.jackson.databind.JsonNode;
9+
import com.fasterxml.jackson.databind.ObjectMapper;
10+
import java.io.IOException;
811
import java.lang.invoke.MethodHandle;
912
import java.lang.invoke.MethodHandles;
1013
import java.lang.invoke.MethodType;
1114
import java.lang.reflect.Method;
15+
import java.nio.ByteBuffer;
16+
import java.util.Arrays;
17+
import java.util.Objects;
18+
import java.util.stream.Stream;
1219
import javax.annotation.Nullable;
1320

1421
final class RequestAccess {
@@ -21,6 +28,213 @@ protected RequestAccess computeValue(Class<?> type) {
2128
}
2229
};
2330

31+
private static final ObjectMapper objectMapper = new ObjectMapper();
32+
33+
@Nullable
34+
private static JsonNode parseTargetBody(ByteBuffer buffer) {
35+
try {
36+
byte[] bytes;
37+
// Create duplicate to avoid mutating the original buffer position
38+
ByteBuffer duplicate = buffer.duplicate();
39+
if (buffer.hasArray()) {
40+
bytes =
41+
Arrays.copyOfRange(
42+
duplicate.array(),
43+
duplicate.arrayOffset(),
44+
duplicate.arrayOffset() + duplicate.remaining());
45+
} else {
46+
bytes = new byte[buffer.remaining()];
47+
buffer.get(bytes);
48+
}
49+
return objectMapper.readTree(bytes);
50+
} catch (IOException e) {
51+
return null;
52+
}
53+
}
54+
55+
@Nullable
56+
private static JsonNode getJsonBody(Object target) {
57+
if (target == null) {
58+
return null;
59+
}
60+
61+
RequestAccess access = REQUEST_ACCESSORS.get(target.getClass());
62+
ByteBuffer bodyBuffer = invokeOrNullGeneric(access.getBody, target, ByteBuffer.class);
63+
if (bodyBuffer == null) {
64+
return null;
65+
}
66+
67+
return parseTargetBody(bodyBuffer);
68+
}
69+
70+
@Nullable
71+
private static String findFirstMatchingPath(JsonNode jsonBody, String... paths) {
72+
if (jsonBody == null) {
73+
return null;
74+
}
75+
76+
return Stream.of(paths)
77+
.map(
78+
path -> {
79+
JsonNode node = jsonBody.at(path);
80+
if (node != null && !node.isMissingNode()) {
81+
return node.asText();
82+
}
83+
return null;
84+
})
85+
.filter(Objects::nonNull)
86+
.findFirst()
87+
.orElse(null);
88+
}
89+
90+
@Nullable
91+
private static String approximateTokenCount(JsonNode jsonBody, String... textPaths) {
92+
if (jsonBody == null) {
93+
return null;
94+
}
95+
96+
return Stream.of(textPaths)
97+
.map(
98+
path -> {
99+
JsonNode node = jsonBody.at(path);
100+
if (node != null && !node.isMissingNode()) {
101+
int tokenEstimate = (int) Math.ceil(node.asText().length() / 6.0);
102+
return Integer.toString(tokenEstimate);
103+
}
104+
return null;
105+
})
106+
.filter(Objects::nonNull)
107+
.findFirst()
108+
.orElse(null);
109+
}
110+
111+
// Model -> Path Mapping:
112+
// Amazon Titan -> "/textGenerationConfig/maxTokenCount"
113+
// Anthropic Claude -> "/max_tokens"
114+
// Cohere Command -> "/max_tokens"
115+
// Cohere Command R -> "/max_tokens"
116+
// AI21 Jamba -> "/max_tokens"
117+
// Meta Llama -> "/max_gen_len"
118+
// Mistral AI -> "/max_tokens"
119+
@Nullable
120+
static String getMaxTokens(Object target) {
121+
return findFirstMatchingPath(
122+
getJsonBody(target), "/textGenerationConfig/maxTokenCount", "/max_tokens", "/max_gen_len");
123+
}
124+
125+
// Model -> Path Mapping:
126+
// Amazon Titan -> "/textGenerationConfig/temperature"
127+
// Anthropic Claude -> "/temperature"
128+
// Cohere Command -> "/temperature"
129+
// Cohere Command R -> "/temperature"
130+
// AI21 Jamba -> "/temperature"
131+
// Meta Llama -> "/temperature"
132+
// Mistral AI -> "/temperature"
133+
@Nullable
134+
static String getTemperature(Object target) {
135+
return findFirstMatchingPath(
136+
getJsonBody(target), "/textGenerationConfig/temperature", "/temperature");
137+
}
138+
139+
// Model -> Path Mapping:
140+
// Amazon Titan -> "/textGenerationConfig/topP"
141+
// Anthropic Claude -> "/top_p"
142+
// Cohere Command -> "/p"
143+
// Cohere Command R -> "/p"
144+
// AI21 Jamba -> "/top_p"
145+
// Meta Llama -> "/top_p"
146+
// Mistral AI -> "/top_p"
147+
@Nullable
148+
static String getTopP(Object target) {
149+
return findFirstMatchingPath(getJsonBody(target), "/textGenerationConfig/topP", "/top_p", "/p");
150+
}
151+
152+
// Model -> Path Mapping:
153+
// Amazon Titan -> "/inputTextTokenCount"
154+
// Anthropic Claude -> "/usage/input_tokens"
155+
// Cohere Command -> "/prompt"
156+
// Cohere Command R -> "/message"
157+
// AI21 Jamba -> "/usage/prompt_tokens"
158+
// Meta Llama -> "/prompt_token_count"
159+
// Mistral AI -> "/prompt"
160+
@Nullable
161+
static String getInputTokens(Object target) {
162+
JsonNode jsonBody = getJsonBody(target);
163+
if (jsonBody == null) {
164+
return null;
165+
}
166+
167+
// Try direct tokens counts first
168+
String directCount =
169+
findFirstMatchingPath(
170+
jsonBody,
171+
"/inputTextTokenCount",
172+
"/usage/input_tokens",
173+
"/usage/prompt_tokens",
174+
"/prompt_token_count");
175+
176+
if (directCount != null) {
177+
return directCount;
178+
}
179+
180+
// Fall back to token approximation
181+
return approximateTokenCount(jsonBody, "/prompt", "/message");
182+
}
183+
184+
// Model -> Path Mapping:
185+
// Amazon Titan -> "/results/0/tokenCount"
186+
// Anthropic Claude -> "/usage/output_tokens"
187+
// Cohere Command -> "/generations/0/text"
188+
// Cohere Command R -> "/text"
189+
// AI21 Jamba -> "/usage/completion_tokens"
190+
// Meta Llama -> "/generation_token_count"
191+
// Mistral AI -> "/outputs/0/text"
192+
@Nullable
193+
static String getOutputTokens(Object target) {
194+
JsonNode jsonBody = getJsonBody(target);
195+
if (jsonBody == null) {
196+
return null;
197+
}
198+
199+
// Try direct token counts first
200+
String directCount =
201+
findFirstMatchingPath(
202+
jsonBody,
203+
"/results/0/tokenCount",
204+
"/usage/output_tokens",
205+
"/usage/completion_tokens",
206+
"/generation_token_count");
207+
208+
if (directCount != null) {
209+
return directCount;
210+
}
211+
212+
return approximateTokenCount(jsonBody, "/outputs/0/text", "/text");
213+
}
214+
215+
// Model -> Path Mapping:
216+
// Amazon Titan -> "/results/0/completionReason"
217+
// Anthropic Claude -> "/stop_reason"
218+
// Cohere Command -> "/generations/0/finish_reason"
219+
// Cohere Command R -> "/finish_reason"
220+
// AI21 Jamba -> "/choices/0/finish_reason"
221+
// Meta Llama -> "/stop_reason"
222+
// Mistral AI -> "/outputs/0/stop_reason"
223+
@Nullable
224+
static String getFinishReasons(Object target) {
225+
String finishReason =
226+
findFirstMatchingPath(
227+
getJsonBody(target),
228+
"/results/0/completionReason",
229+
"/stop_reason",
230+
"/generations/0/finish_reason",
231+
"/choices/0/finish_reason",
232+
"/outputs/0/stop_reason",
233+
"/finish_reason");
234+
235+
return finishReason != null ? "[" + finishReason + "]" : null;
236+
}
237+
24238
@Nullable
25239
static String getLambdaName(Object request) {
26240
if (request == null) {
@@ -185,6 +399,19 @@ private static String invokeOrNull(@Nullable MethodHandle method, Object obj) {
185399
}
186400
}
187401

402+
@Nullable
403+
private static <T> T invokeOrNullGeneric(
404+
@Nullable MethodHandle method, Object obj, Class<T> returnType) {
405+
if (method == null) {
406+
return null;
407+
}
408+
try {
409+
return returnType.cast(method.invoke(obj));
410+
} catch (Throwable e) {
411+
return null;
412+
}
413+
}
414+
188415
@Nullable private final MethodHandle getBucketName;
189416
@Nullable private final MethodHandle getQueueUrl;
190417
@Nullable private final MethodHandle getQueueName;
@@ -195,6 +422,7 @@ private static String invokeOrNull(@Nullable MethodHandle method, Object obj) {
195422
@Nullable private final MethodHandle getDataSourceId;
196423
@Nullable private final MethodHandle getGuardrailId;
197424
@Nullable private final MethodHandle getModelId;
425+
@Nullable private final MethodHandle getBody;
198426
@Nullable private final MethodHandle getStateMachineArn;
199427
@Nullable private final MethodHandle getStepFunctionsActivityArn;
200428
@Nullable private final MethodHandle getSnsTopicArn;
@@ -203,29 +431,31 @@ private static String invokeOrNull(@Nullable MethodHandle method, Object obj) {
203431
@Nullable private final MethodHandle getLambdaResourceId;
204432

205433
private RequestAccess(Class<?> clz) {
206-
getBucketName = findAccessorOrNull(clz, "getBucketName");
207-
getQueueUrl = findAccessorOrNull(clz, "getQueueUrl");
208-
getQueueName = findAccessorOrNull(clz, "getQueueName");
209-
getStreamName = findAccessorOrNull(clz, "getStreamName");
210-
getTableName = findAccessorOrNull(clz, "getTableName");
211-
getAgentId = findAccessorOrNull(clz, "getAgentId");
212-
getKnowledgeBaseId = findAccessorOrNull(clz, "getKnowledgeBaseId");
213-
getDataSourceId = findAccessorOrNull(clz, "getDataSourceId");
214-
getGuardrailId = findAccessorOrNull(clz, "getGuardrailId");
215-
getModelId = findAccessorOrNull(clz, "getModelId");
216-
getStateMachineArn = findAccessorOrNull(clz, "getStateMachineArn");
217-
getStepFunctionsActivityArn = findAccessorOrNull(clz, "getActivityArn");
218-
getSnsTopicArn = findAccessorOrNull(clz, "getTopicArn");
219-
getSecretArn = findAccessorOrNull(clz, "getARN");
220-
getLambdaName = findAccessorOrNull(clz, "getFunctionName");
221-
getLambdaResourceId = findAccessorOrNull(clz, "getUUID");
222-
}
223-
224-
@Nullable
225-
private static MethodHandle findAccessorOrNull(Class<?> clz, String methodName) {
434+
getBucketName = findAccessorOrNull(clz, "getBucketName", String.class);
435+
getQueueUrl = findAccessorOrNull(clz, "getQueueUrl", String.class);
436+
getQueueName = findAccessorOrNull(clz, "getQueueName", String.class);
437+
getStreamName = findAccessorOrNull(clz, "getStreamName", String.class);
438+
getTableName = findAccessorOrNull(clz, "getTableName", String.class);
439+
getAgentId = findAccessorOrNull(clz, "getAgentId", String.class);
440+
getKnowledgeBaseId = findAccessorOrNull(clz, "getKnowledgeBaseId", String.class);
441+
getDataSourceId = findAccessorOrNull(clz, "getDataSourceId", String.class);
442+
getGuardrailId = findAccessorOrNull(clz, "getGuardrailId", String.class);
443+
getModelId = findAccessorOrNull(clz, "getModelId", String.class);
444+
getBody = findAccessorOrNull(clz, "getBody", ByteBuffer.class);
445+
getStateMachineArn = findAccessorOrNull(clz, "getStateMachineArn", String.class);
446+
getStepFunctionsActivityArn = findAccessorOrNull(clz, "getActivityArn", String.class);
447+
getSnsTopicArn = findAccessorOrNull(clz, "getTopicArn", String.class);
448+
getSecretArn = findAccessorOrNull(clz, "getARN", String.class);
449+
getLambdaName = findAccessorOrNull(clz, "getFunctionName", String.class);
450+
getLambdaResourceId = findAccessorOrNull(clz, "getUUID", String.class);
451+
}
452+
453+
@Nullable
454+
private static MethodHandle findAccessorOrNull(
455+
Class<?> clz, String methodName, Class<?> returnType) {
226456
try {
227457
return MethodHandles.publicLookup()
228-
.findVirtual(clz, methodName, MethodType.methodType(String.class));
458+
.findVirtual(clz, methodName, MethodType.methodType(returnType));
229459
} catch (Throwable t) {
230460
return null;
231461
}

0 commit comments

Comments
 (0)