feat: Added setLlmTokenCountCallback API endpoint to register a callback for calculating token count when none is provided #2065

Merged · merged 2 commits · Mar 6, 2024
35 changes: 34 additions & 1 deletion api.js
@@ -1548,7 +1548,7 @@
* @param {string} params.traceId Identifier for the feedback event.
* Obtained from {@link getTraceMetadata}.
* @param {string} params.category A tag for the event.
* @param {string} params.rating An indicator of how useful the message was.

[GitHub Actions lint warning on line 1551 in api.js: The type 'getTraceMetadata' is undefined]
* @param {string} [params.message] The message that triggered the event.
* @param {object} [params.metadata] Additional key-value pairs to associate
* with the recorded event.
@@ -1839,7 +1839,7 @@
)
metric.incrementCallCount()

- if (!this.shim.isFunction(callback) || this.shim.isPromise(callback)) {
+ if (!this.shim.isFunction(callback) || this.shim.isAsyncFunction(callback)) {
logger.warn(
'Error Group callback must be a synchronous function, Error Group attribute will not be added'
)
@@ -1849,4 +1849,37 @@
this.agent.errors.errorGroupCallback = callback
}

/**
* Registers a callback which will be used for calculating token counts on Llm events when they are not
* available. This function will typically only be used if `ai_monitoring.record_content.enabled` is false
* and you want to still capture token counts for Llm events.
*
* Provided callbacks must return an integer value for the token count for a given piece of content.
*
* @param {Function} callback - synchronous function called to calculate token count for content.
* @example
* // @param {string} model - name of model (e.g. gpt-3.5-turbo)
* // @param {string} content - prompt or completion response
* function tokenCallback(model, content) {
Contributor:
The value of content is extremely varied. I can see us fielding plenty of issues asking what it will be.

Member Author:
It must be a string. In the OpenAI cases they are always strings. I know for LangChain they aren't always strings, but we don't assign tokens. Looking at Bedrock, they are strings as well.

Member Author:
It shouldn't be the parsed body; it should be the key in the body that contains the content.

* // calculate tokens based on model and content
* // return token count
* return 40
* }
*/
API.prototype.setLlmTokenCountCallback = function setLlmTokenCountCallback(callback) {
const metric = this.agent.metrics.getOrCreateMetric(
NAMES.SUPPORTABILITY.API + '/setLlmTokenCountCallback'
)
metric.incrementCallCount()

if (!this.shim.isFunction(callback) || this.shim.isAsyncFunction(callback)) {
logger.warn(
'Llm token count callback must be a synchronous function, callback will not be registered.'
)
return
}

this.agent.llm.tokenCountCallback = callback
}

module.exports = API
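As a usage sketch (not part of the diff): once this API is published on the agent, an application could register a callback along these lines. It assumes the standard require('newrelic') entry point, and the length-based count is only an illustrative stand-in for a real tokenizer.

'use strict'

const newrelic = require('newrelic')

// Register a synchronous callback so token counts can still be reported
// when provider usage data is not available.
newrelic.setLlmTokenCountCallback(function tokenCallback(model, content) {
  if (typeof content !== 'string') {
    return 0
  }

  // Rough heuristic (~4 characters per token); a real implementation would
  // use a tokenizer appropriate for the given model.
  return Math.ceil(content.length / 4)
})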
2 changes: 1 addition & 1 deletion lib/instrumentation/restify.js
@@ -93,7 +93,7 @@ function wrapMiddleware(shim, middleware, _name, route) {
})

const wrappedMw = shim.recordMiddleware(middleware, spec)
- if (middleware.constructor.name === 'AsyncFunction') {
+ if (shim.isAsyncFunction(middleware)) {
return async function asyncShim() {
return wrappedMw.apply(this, arguments)
}
8 changes: 6 additions & 2 deletions lib/llm-events/openai/chat-completion-message.js
@@ -20,9 +20,13 @@ module.exports = class LlmChatCompletionMessage extends LlmEvent {
}

if (this.is_response) {
- this.token_count = response?.usage?.completion_tokens
+ this.token_count =
+   response?.usage?.completion_tokens ||
+   agent.llm?.tokenCountCallback?.(this['response.model'], message?.content)
} else {
- this.token_count = response?.usage?.prompt_tokens
+ this.token_count =
+   response?.usage?.prompt_tokens ||
+   agent.llm?.tokenCountCallback?.(request.model || request.engine, message?.content)
}
}
}
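For clarity, a standalone sketch of the fallback order implemented above: the provider-reported usage value wins, and the registered callback is only consulted when that value is missing (or zero, since || is used). The helper below is illustrative only, not code from the agent.

// Illustrative only: mirrors the `usage || callback` fallback used above.
function resolveTokenCount(usageCount, tokenCountCallback, model, content) {
  return usageCount || tokenCountCallback?.(model, content)
}

resolveTokenCount(53, () => 99, 'gpt-3.5-turbo', 'hello')         // 53 (usage wins)
resolveTokenCount(undefined, () => 99, 'gpt-3.5-turbo', 'hello')  // 99 (callback used)
resolveTokenCount(undefined, undefined, 'gpt-3.5-turbo', 'hello') // undefined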
4 changes: 3 additions & 1 deletion lib/llm-events/openai/embedding.js
@@ -14,6 +14,8 @@ module.exports = class LlmEmbedding extends LlmEvent {
if (agent.config.ai_monitoring.record_content.enabled === true) {
this.input = request.input?.toString()
}
- this.token_count = response?.usage?.prompt_tokens
+ this.token_count =
+   response?.usage?.prompt_tokens ||
+   agent.llm?.tokenCountCallback?.(this['request.model'], request.input?.toString())
}
}
15 changes: 15 additions & 0 deletions lib/shim/shim.js
@@ -125,6 +125,7 @@ Shim.prototype.getName = getName
Shim.prototype.isObject = isObject
Shim.prototype.isFunction = isFunction
Shim.prototype.isPromise = isPromise
Shim.prototype.isAsyncFunction = isAsyncFunction
Shim.prototype.isString = isString
Shim.prototype.isNumber = isNumber
Shim.prototype.isBoolean = isBoolean
@@ -1345,6 +1346,20 @@ function isPromise(obj) {
return obj && typeof obj.then === 'function'
}

/**
* Determines if a function is an async function.
* Note that it does not test whether the function's return value is a
* promise; a synchronous function that returns a promise is not async.
*
* @memberof Shim.prototype
* @param {Function} fn - function to test
* @returns {boolean} True if the function is an async function
*/
function isAsyncFunction(fn) {
return fn.constructor.name === 'AsyncFunction'
}

/**
* Determines if the given value is null.
*
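As a side note on the new helper (a sketch, not part of the diff): the check relies on the function's constructor name, so only functions declared with the async keyword are detected; a synchronous function that merely returns a promise is not flagged.

async function declaredAsync() {
  return 1
}

function returnsPromise() {
  // Synchronous function that happens to return a promise.
  return Promise.resolve(1)
}

const isAsync = (fn) => fn.constructor.name === 'AsyncFunction'

isAsync(declaredAsync)  // true
isAsync(returnsPromise) // false -- not detected, even though it returns a promise
isAsync(() => {})       // false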
48 changes: 47 additions & 1 deletion test/unit/api/api-llm.test.js
@@ -28,7 +28,8 @@ tap.test('Agent API LLM methods', (t) => {
loggerMock.warn.reset()
const agent = helper.loadMockedAgent()
t.context.api = new API(agent)
- t.context.api.agent.config.ai_monitoring.enabled = true
+ agent.config.ai_monitoring.enabled = true
+ t.context.agent = agent
})

t.afterEach((t) => {
@@ -119,4 +120,49 @@
})
})
})

t.test('setLlmTokenCount should register callback to calculate token counts', async (t) => {
const { api, agent } = t.context
function callback(model, content) {
if (model === 'foo' && content === 'bar') {
return 10
}

return 1
}
api.setLlmTokenCountCallback(callback)
t.same(agent.llm.tokenCountCallback, callback)
})

t.test('should not store token count callback if it is async', async (t) => {
const { api, agent } = t.context
async function callback(model, content) {
return await new Promise((resolve) => {
if (model === 'foo' && content === 'bar') {
resolve(10)
}
})
}
api.setLlmTokenCountCallback(callback)
t.same(agent.llm.tokenCountCallback, undefined)
t.equal(loggerMock.warn.callCount, 1)
t.equal(
loggerMock.warn.args[0][0],
'Llm token count callback must be a synchronous function, callback will not be registered.'
)
})

t.test(
'should not store token count callback if callback is not actually a function',
async (t) => {
const { api, agent } = t.context
api.setLlmTokenCountCallback({ unit: 'test' })
t.same(agent.llm.tokenCountCallback, undefined)
t.equal(loggerMock.warn.callCount, 1)
t.equal(
loggerMock.warn.args[0][0],
'Llm token count callback must be a synchronous function, callback will not be registered.'
)
}
)
})
4 changes: 2 additions & 2 deletions test/unit/api/api-set-error-group-callback.test.js
@@ -84,8 +84,8 @@ tap.test('Agent API = set Error Group callback', (t) => {
})

t.test('should not attach the callback when async function', (t) => {
- function callback() {
-   return new Promise((resolve) => {
+ async function callback() {
+   return await new Promise((resolve) => {
setTimeout(() => {
resolve()
}, 200)
2 changes: 1 addition & 1 deletion test/unit/api/stub.test.js
@@ -8,7 +8,7 @@
const tap = require('tap')
const API = require('../../../stub_api')

- const EXPECTED_API_COUNT = 34
+ const EXPECTED_API_COUNT = 35

tap.test('Agent API - Stubbed Agent API', (t) => {
t.autoend()
113 changes: 111 additions & 2 deletions test/unit/llm-events/openai/chat-completion-message.test.js
@@ -11,8 +11,6 @@ const helper = require('../../../lib/agent_helper')
const { req, chatRes, getExpectedResult } = require('./common')

tap.test('LlmChatCompletionMessage', (t) => {
t.autoend()

let agent
t.beforeEach(() => {
agent = helper.loadMockedAgent()
@@ -104,4 +102,115 @@
t.end()
})
})

t.test('should use token_count from tokenCountCallback for prompt message', (t) => {
const api = helper.getAgentApi()
const expectedCount = 4
function cb(model, content) {
t.equal(model, 'gpt-3.5-turbo-0613')
t.equal(content, 'What is a woodchuck?')
return expectedCount
}
api.setLlmTokenCountCallback(cb)
helper.runInTransaction(agent, () => {
api.startSegment('fakeSegment', false, () => {
const segment = api.shim.getActiveSegment()
const summaryId = 'chat-summary-id'
delete chatRes.usage
const chatMessageEvent = new LlmChatCompletionMessage({
agent,
segment,
request: req,
response: chatRes,
completionId: summaryId,
message: req.messages[0],
index: 0
})
t.equal(chatMessageEvent.token_count, expectedCount)
t.end()
})
})
})

t.test('should use token_count from tokenCountCallback for completion messages', (t) => {
const api = helper.getAgentApi()
const expectedCount = 4
function cb(model, content) {
t.equal(model, 'gpt-3.5-turbo-0613')
t.equal(content, 'a lot')
return expectedCount
}
api.setLlmTokenCountCallback(cb)
helper.runInTransaction(agent, () => {
api.startSegment('fakeSegment', false, () => {
const segment = api.shim.getActiveSegment()
const summaryId = 'chat-summary-id'
delete chatRes.usage
const chatMessageEvent = new LlmChatCompletionMessage({
agent,
segment,
request: req,
response: chatRes,
completionId: summaryId,
message: chatRes.choices[0].message,
index: 2
})
t.equal(chatMessageEvent.token_count, expectedCount)
t.end()
})
})
})

t.test('should not set token_count if not set in usage nor a callback registered', (t) => {
const api = helper.getAgentApi()
helper.runInTransaction(agent, () => {
api.startSegment('fakeSegment', false, () => {
const segment = api.shim.getActiveSegment()
const summaryId = 'chat-summary-id'
delete chatRes.usage
const chatMessageEvent = new LlmChatCompletionMessage({
agent,
segment,
request: req,
response: chatRes,
completionId: summaryId,
message: chatRes.choices[0].message,
index: 2
})
t.equal(chatMessageEvent.token_count, undefined)
t.end()
})
})
})

t.test(
'should not set token_count if not set in usage nor a callback registered returns count',
(t) => {
const api = helper.getAgentApi()
function cb() {
// empty cb
}
api.setLlmTokenCountCallback(cb)
helper.runInTransaction(agent, () => {
api.startSegment('fakeSegment', false, () => {
const segment = api.shim.getActiveSegment()
const summaryId = 'chat-summary-id'
delete chatRes.usage
const chatMessageEvent = new LlmChatCompletionMessage({
agent,
segment,
request: req,
response: chatRes,
completionId: summaryId,
message: chatRes.choices[0].message,
index: 2
})
t.equal(chatMessageEvent.token_count, undefined)
t.end()
})
})
}
)

t.end()
})
50 changes: 50 additions & 0 deletions test/unit/llm-events/openai/embedding.test.js
@@ -113,4 +113,54 @@
t.end()
})
})

t.test('should calculate token count from tokenCountCallback', (t) => {
const req = {
input: 'This is my test input',
model: 'gpt-3.5-turbo-0613'
}

const api = helper.getAgentApi()

function cb(model, content) {
if (model === req.model) {
return content.length
}
}

api.setLlmTokenCountCallback(cb)
helper.runInTransaction(agent, () => {
const segment = api.shim.getActiveSegment()
delete res.usage
const embeddingEvent = new LlmEmbedding({
agent,
segment,
request: req,
response: res
})
t.equal(embeddingEvent.token_count, 21)
t.end()
})
})

t.test('should not set token count when not present in usage nor tokenCountCallback', (t) => {
const req = {
input: 'This is my test input',
model: 'gpt-3.5-turbo-0613'
}

const api = helper.getAgentApi()
helper.runInTransaction(agent, () => {
const segment = api.shim.getActiveSegment()
delete res.usage
const embeddingEvent = new LlmEmbedding({
agent,
segment,
request: req,
response: res
})
t.equal(embeddingEvent.token_count, undefined)
t.end()
})
})
})