Skip to content

Commit

Permalink
feat: Added instrumentation for chat completion streams (#1884)
Browse files Browse the repository at this point in the history
  • Loading branch information
bizob2828 committed Nov 27, 2023
1 parent 4ddfd81 commit 404e317
Show file tree
Hide file tree
Showing 15 changed files with 964 additions and 537 deletions.
116 changes: 88 additions & 28 deletions lib/instrumentation/openai.js
Expand Up @@ -14,6 +14,7 @@ const {
const LlmTrackedIds = require('../../lib/llm-events/tracked-ids')

const MIN_VERSION = '4.0.0'
const MIN_STREAM_VERSION = '4.12.2'
const { AI } = require('../../lib/metrics/names')
const semver = require('semver')

Expand Down Expand Up @@ -135,18 +136,27 @@ module.exports = function initialize(agent, openai, moduleName, shim) {
* @param {TraceSegment} params.segment active segment from chat completion
* @param {object} params.request chat completion params
* @param {object} params.response chat completion response
* @param {boolean} [params.hasError=false] indicates that the completion
* resulted in some sort of error case
* @returns {LlmChatCompletionSummary} A summary object.
* @param {Error} [params.err] error raised during the chat completion, if any
*/
function recordChatCompletionMessages({ segment, request, response, hasError = false }) {
function recordChatCompletionMessages({ segment, request, response, err }) {
if (!response) {
// If we get an error, it is possible that `response = null`.
// In that case, we define it to be an empty object.
response = {}
}

response.headers = segment[openAiHeaders]
response.api_key = segment[openAiApiKey]
const tx = segment.transaction
// explicitly end segment to get consistent duration
// for both LLM events and the segment
segment.end()
const completionSummary = new LlmChatCompletionSummary({
agent,
segment,
request,
response,
withError: hasError
withError: err != null
})

// Only take the first response message and append to input messages
Expand All @@ -168,7 +178,60 @@ module.exports = function initialize(agent, openai, moduleName, shim) {

recordEvent('LlmChatCompletionSummary', completionSummary)

return completionSummary
if (err) {
const llmError = new LlmErrorMessage({ cause: err, summary: completionSummary, response })
shim.agent.errors.add(segment.transaction, err, llmError)
}

delete response.headers
delete response.api_key
}

/**
 * Chat completions `create` can return a stream once the promise resolves.
 * This wraps the stream's `iterator` method (an async generator function):
 * we invoke the original iterator, intercept every chunk, and yield it
 * through to the consumer unchanged. Once the stream completes (or errors)
 * we assemble the final message from the accumulated deltas and record the
 * chat completion events.
 *
 * @param {object} params function params
 * @param {object} params.request chat completion params
 * @param {object} params.response chat completion stream instance
 * @param {TraceSegment} params.segment active segment from chat completion
 */
function instrumentStream({ request, response, segment }) {
  shim.wrap(response, 'iterator', function wrapIterator(shim, orig) {
    return async function* wrappedIterator() {
      let content = ''
      let role = ''
      let chunk
      let err
      try {
        const iterator = orig.apply(this, arguments)

        for await (chunk of iterator) {
          // the role is only present on the first delta chunk of a stream
          if (chunk.choices?.[0]?.delta?.role) {
            role = chunk.choices[0].delta.role
          }

          content += chunk.choices?.[0]?.delta?.content ?? ''
          yield chunk
        }
      } catch (streamErr) {
        err = streamErr
        throw err
      } finally {
        // Guard: the stream may error before yielding a single chunk
        // (`chunk` is undefined), or the last chunk may lack `choices`.
        // Dereferencing unconditionally would throw a TypeError here and
        // mask the original stream error.
        if (chunk?.choices?.[0]) {
          chunk.choices[0].message = { role, content }
        }
        // update segment duration since we want to extend the time it took to
        // handle the stream
        segment.touch()

        recordChatCompletionMessages({
          segment,
          request,
          response: chunk,
          err
        })
      }
    }
  })
}

/**
Expand All @@ -182,34 +245,28 @@ module.exports = function initialize(agent, openai, moduleName, shim) {
'create',
function wrapCreate(shim, create, name, args) {
const [request] = args
if (request.stream && semver.lt(shim.pkgVersion, MIN_STREAM_VERSION)) {
shim.logger.warn(
`Instrumenting chat completion streams is only supported with openai version ${MIN_STREAM_VERSION}+.`
)
return
}

return {
name: `${AI.OPEN_AI}/Chat/Completions/Create`,
promise: true,
// eslint-disable-next-line max-params
after(_shim, _fn, _name, err, response, segment) {
if (!response) {
// If we get an error, it is possible that `response = null`.
// In that case, we define it to be an empty object.
response = {}
if (request.stream) {
instrumentStream({ request, response, segment })
} else {
recordChatCompletionMessages({
segment,
request,
response,
err
})
}
response.headers = segment[openAiHeaders]
response.api_key = segment[openAiApiKey]

const summary = recordChatCompletionMessages({
segment,
request,
response,
hasError: err != null
})

if (err) {
const llmError = new LlmErrorMessage({ cause: err, summary, response })
shim.agent.errors.add(segment.transaction, err, llmError)
}

// cleanup keys on response before returning to user code
delete response.api_key
delete response.headers
}
}
}
Expand All @@ -236,6 +293,9 @@ module.exports = function initialize(agent, openai, moduleName, shim) {
}
response.headers = segment[openAiHeaders]
response.api_key = segment[openAiApiKey]
// explicitly end segment to get consistent duration
// for both LLM events and the segment
segment.end()
const embedding = new LlmEmbedding({
agent,
segment,
Expand Down
2 changes: 1 addition & 1 deletion lib/llm-events/openai/event.js
Expand Up @@ -27,7 +27,7 @@ module.exports = class LlmEvent {
*/
if (responseAttrs) {
this['request.model'] = request.model || request.engine
this.duration = segment?.getExclusiveDurationInMillis()
this.duration = segment?.getDurationInMillis()
this.api_key_last_four_digits = response?.api_key && `sk-${response.api_key.slice(-4)}`
this.responseAttrs(response)
}
Expand Down
2 changes: 1 addition & 1 deletion lib/shim/shim.js
Expand Up @@ -30,7 +30,7 @@ const fnApply = Function.prototype.apply
* @param {string} moduleName - The name of the module being instrumented.
* @param {string} resolvedName - The full path to the loaded module.
* @param {string} shimName - Used to persist shim ids across different instances. This is
* @param {string} pkgVersion - version of module
* @param {string} pkgVersion - version of package getting instrumented
* applicable to instrumentation that complements each other across libraries (i.e. koa + koa-route/koa-router)
*/
function Shim(agent, moduleName, resolvedName, shimName, pkgVersion) {
Expand Down
20 changes: 18 additions & 2 deletions test/unit/instrumentation/openai.test.js
Expand Up @@ -20,6 +20,7 @@ test('openai unit tests', (t) => {
const shim = new GenericShim(agent, 'openai')
shim.pkgVersion = '4.0.0'
sandbox.stub(shim.logger, 'debug')
sandbox.stub(shim.logger, 'warn')

t.context.agent = agent
t.context.shim = shim
Expand All @@ -34,12 +35,12 @@ test('openai unit tests', (t) => {

function getMockModule() {
function Completions() {}
Completions.prototype.create = function () {}
Completions.prototype.create = async function () {}
function OpenAI() {}
OpenAI.prototype.makeRequest = function () {}
OpenAI.Chat = { Completions }
OpenAI.Embeddings = function () {}
OpenAI.Embeddings.prototype.create = function () {}
OpenAI.Embeddings.prototype.create = async function () {}
return OpenAI
}

Expand All @@ -53,6 +54,21 @@ test('openai unit tests', (t) => {
t.end()
})

// Streaming instrumentation requires openai >= 4.12.2 (MIN_STREAM_VERSION);
// on older versions the instrumentation should bail out with a warning
// instead of wrapping the stream.
t.test('should not instrument chat completion streams if < 4.12.2', async (t) => {
  const { shim, agent, initialize } = t.context
  // simulate an openai version just below the streaming support threshold
  shim.pkgVersion = '4.12.0'
  const MockOpenAi = getMockModule()
  initialize(agent, MockOpenAi, 'openai', shim)
  const completions = new MockOpenAi.Chat.Completions()

  await completions.create({ stream: true })
  // the stubbed logger captures the warning emitted by the version check
  t.equal(
    shim.logger.warn.args[0][0],
    'Instrumenting chat completion streams is only supported with openai version 4.12.2+.'
  )
  t.end()
})

t.test('should not register instrumentation if openai is < 4.0.0', (t) => {
const { shim, agent, initialize } = t.context
const MockOpenAi = getMockModule()
Expand Down
Expand Up @@ -27,6 +27,7 @@ tap.test('LlmChatCompletionSummary', (t) => {
helper.runInTransaction(agent, (tx) => {
api.startSegment('fakeSegment', false, () => {
const segment = api.shim.getActiveSegment()
segment.end()
const chatSummaryEvent = new LlmChatCompletionSummary({
agent,
segment,
Expand Down
2 changes: 1 addition & 1 deletion test/unit/llm-events/openai/common.js
Expand Up @@ -58,7 +58,7 @@ function getExpectedResult(tx, event, type, completionId) {
'ingest_source': 'Node'
}
const resKeys = {
'duration': trace.children[0].getExclusiveDurationInMillis(),
'duration': trace.children[0].getDurationInMillis(),
'request.model': 'gpt-3.5-turbo-0613',
'api_key_last_four_digits': 'sk-7890',
'response.organization': 'new-relic',
Expand Down
1 change: 1 addition & 0 deletions test/unit/llm-events/openai/embedding.test.js
Expand Up @@ -32,6 +32,7 @@ tap.test('LlmEmbedding', (t) => {
helper.runInTransaction(agent, (tx) => {
api.startSegment('fakeSegment', false, () => {
const segment = api.shim.getActiveSegment()
segment.end()
const embeddingEvent = new LlmEmbedding({ agent, segment, request: req, response: res })
const expected = getExpectedResult(tx, embeddingEvent, 'embedding')
t.same(embeddingEvent, expected)
Expand Down

0 comments on commit 404e317

Please sign in to comment.