Skip to content

Commit

Permalink
feat: Added instrumentation for openai chat completion creation (#1862)
Browse files Browse the repository at this point in the history
  • Loading branch information
bizob2828 committed Nov 15, 2023
1 parent 20e7f1d commit 34dcd70
Show file tree
Hide file tree
Showing 10 changed files with 535 additions and 5 deletions.
2 changes: 1 addition & 1 deletion README.md
Expand Up @@ -158,7 +158,7 @@ These are the steps to work on core agent features, with more detail below:
$ git clone git@github.com:your-user-name/node-newrelic.git
$ cd node-newrelic

2. Install the project's dependences:
2. Install the project's dependencies:

$ npm install

Expand Down
147 changes: 147 additions & 0 deletions lib/instrumentation/openai.js
@@ -0,0 +1,147 @@
/*
* Copyright 2023 New Relic Corporation. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*/

'use strict'
const { openAiHeaders, openAiApiKey } = require('../../lib/symbols')
const {
LlmChatCompletionMessage,
LlmChatCompletionSummary
} = require('../../lib/llm-events/openai')

const MIN_VERSION = '4.0.0'
const semver = require('semver')

/**
 * Determines whether OpenAI instrumentation should be bypassed.
 * Instrumentation proceeds only when the `feature_flag.openai_instrumentation`
 * flag is enabled AND the installed package version is >= 4.0.0.
 *
 * @param {object} config agent config
 * @param {Shim} shim instance of shim
 * @returns {boolean} flag if instrumentation should be skipped
 */
function shouldSkipInstrumentation(config, shim) {
  // TODO: Remove when we release full support for OpenAI
  const flagEnabled = config?.feature_flag?.openai_instrumentation
  if (!flagEnabled) {
    shim.logger.debug('config.feature_flag.openai_instrumentation is disabled.')
    return true
  }

  const pkg = shim.require('./package.json')
  return semver.lt(pkg.version, MIN_VERSION)
}

module.exports = function initialize(agent, openai, moduleName, shim) {
if (shouldSkipInstrumentation(agent.config, shim)) {
shim.logger.debug(
`${moduleName} instrumentation support is for versions >=${MIN_VERSION}. Skipping instrumentation.`
)
return
}

/**
* Adds apiKey and response headers to the active segment
* on symbols
*
* @param {object} result from openai request
* @param {string} apiKey api key from openai client
*/
function decorateSegment(result, apiKey) {
const segment = shim.getActiveSegment()

if (segment) {
segment[openAiApiKey] = apiKey
segment[openAiHeaders] =
result?.response?.headers && Object.fromEntries(result.response.headers)
}
}

/**
* Enqueues a LLM event to the custom event aggregator
*
* @param {string} type of LLM event
* @param {object} msg LLM event
*/
function recordEvent(type, msg) {
agent.customEventAggregator.add([{ type, timestamp: Date.now() }, msg])
}

/**
* Instrumentation is only done to get the response headers and attach
* to the active segment as openai hides the headers from the functions we are
* trying to instrument
*/
shim.wrap(openai.prototype, 'makeRequest', function wrapRequest(shim, makeRequest) {
return function wrappedRequest() {
const apiKey = this.apiKey
const result = makeRequest.apply(this, arguments)
result.then(
(data) => {
// add headers on resolve
decorateSegment(data, apiKey)
},
(data) => {
// add headers on reject
decorateSegment(data, apiKey)
}
)
return result
}
})

/**
* Instruments chat completion creation
* and creates the LLM events
*
* **Note**: Currently only for promises. streams will come later
*/
shim.record(
openai.Chat.Completions.prototype,
'create',
function wrapCreate(shim, create, name, args) {
const [request] = args
return {
name: 'AI/OpenAI/Chat/Completions/Create',
promise: true,
opaque: true,
// eslint-disable-next-line max-params
after(_shim, _fn, _name, err, response, segment) {
response.headers = segment[openAiHeaders]
response.api_key = segment[openAiApiKey]

// TODO: add LlmErrorMessage on failure
// and exit
// See: https://github.com/newrelic/node-newrelic/issues/1845
// if (err) {}

const completionSummary = new LlmChatCompletionSummary({
agent,
segment,
request,
response
})

request.messages.forEach((_msg, index) => {
const completionMsg = new LlmChatCompletionMessage({
agent,
segment,
request,
response,
index
})

recordEvent('LlmChatCompletionMessage', completionMsg)
})

recordEvent('LlmChatCompletionSummary', completionSummary)

// cleanup keys on response before returning to user code
delete response.api_key
delete response.headers
}
}
}
)
}
1 change: 1 addition & 0 deletions lib/instrumentations.js
Expand Up @@ -28,6 +28,7 @@ module.exports = function instrumentations() {
'memcached': { type: MODULE_TYPE.DATASTORE },
'mongodb': { type: MODULE_TYPE.DATASTORE },
'mysql': { module: './instrumentation/mysql' },
'openai': { type: MODULE_TYPE.GENERIC },
'@nestjs/core': { type: MODULE_TYPE.WEB_FRAMEWORK },
'pino': { module: './instrumentation/pino' },
'pg': { type: MODULE_TYPE.DATASTORE },
Expand Down
6 changes: 3 additions & 3 deletions lib/shim/shim.js
Expand Up @@ -957,12 +957,12 @@ function record(nodule, properties, recordNamer) {
return ret.then(
function onThen(val) {
segment.touch()
segDesc.after(shim, fn, name, null, val)
segDesc.after(shim, fn, name, null, val, segment)
return val
},
function onCatch(err) {
segment.touch()
segDesc.after(shim, fn, name, err, null)
segDesc.after(shim, fn, name, err, null, segment)
throw err // NOTE: This is not an error from our instrumentation.
}
)
Expand All @@ -973,7 +973,7 @@ function record(nodule, properties, recordNamer) {
throw err // Just rethrowing this error, not our error!
} finally {
if (segDesc.after && (error || !promised)) {
segDesc.after(shim, fn, name, error, ret)
segDesc.after(shim, fn, name, error, ret, segment)
}
}
}
Expand Down
2 changes: 2 additions & 0 deletions lib/symbols.js
Expand Up @@ -18,6 +18,8 @@ module.exports = {
offTheRecord: Symbol('offTheRecord'),
original: Symbol('original'),
wrapped: Symbol('shimWrapped'),
openAiHeaders: Symbol('openAiHeaders'),
openAiApiKey: Symbol('openAiApiKey'),
parentSegment: Symbol('parentSegment'),
prismaConnection: Symbol('prismaConnection'),
prismaModelCall: Symbol('modelCall'),
Expand Down
92 changes: 92 additions & 0 deletions test/unit/instrumentation/openai.test.js
@@ -0,0 +1,92 @@
/*
* Copyright 2023 New Relic Corporation. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*/

'use strict'

const { test } = require('tap')
const helper = require('../../lib/agent_helper')
const GenericShim = require('../../../lib/shim/shim')
const sinon = require('sinon')

test('openai unit tests', (t) => {
  t.autoend()

  t.beforeEach(function (t) {
    const sandbox = sinon.createSandbox()
    const agent = helper.loadMockedAgent()
    // enable the feature flag so instrumentation registers by default
    agent.config.feature_flag = { openai_instrumentation: true }
    const shim = new GenericShim(agent, 'openai')
    sandbox.stub(shim, 'require')
    shim.require.returns({ version: '4.0.0' })
    sandbox.stub(shim.logger, 'debug')

    t.context.agent = agent
    t.context.shim = shim
    t.context.sandbox = sandbox
    t.context.initialize = require('../../../lib/instrumentation/openai')
  })

  t.afterEach(function (t) {
    helper.unloadAgent(t.context.agent)
    t.context.sandbox.restore()
  })

  // minimal stand-in for the openai v4 module surface the instrumentation touches
  function getMockModule() {
    function Completions() {}
    Completions.prototype.create = function () {}
    function OpenAI() {}
    OpenAI.prototype.makeRequest = function () {}
    OpenAI.Chat = { Completions }
    return OpenAI
  }

  t.test('should instrument openai if >= 4.0.0', (t) => {
    const { shim, agent, initialize } = t.context
    const MockOpenAi = getMockModule()
    initialize(agent, MockOpenAi, 'openai', shim)
    t.equal(shim.logger.debug.callCount, 0, 'should not log debug messages')
    const isWrapped = shim.isWrapped(MockOpenAi.Chat.Completions.prototype.create)
    t.equal(isWrapped, true, 'should wrap chat completions create')
    t.end()
  })

  t.test('should not register instrumentation if openai is < 4.0.0', (t) => {
    const { shim, agent, initialize } = t.context
    const MockOpenAi = getMockModule()
    shim.require.returns({ version: '3.7.0' })
    initialize(agent, MockOpenAi, 'openai', shim)
    t.equal(shim.logger.debug.callCount, 1, 'should log 1 debug message')
    t.equal(
      shim.logger.debug.args[0][0],
      'openai instrumentation support is for versions >=4.0.0. Skipping instrumentation.'
    )
    const isWrapped = shim.isWrapped(MockOpenAi.Chat.Completions.prototype.create)
    t.equal(isWrapped, false, 'should not wrap chat completions create')
    t.end()
  })

  t.test(
    'should not register instrumentation if feature_flag.openai_instrumentation is false',
    (t) => {
      const { shim, agent, initialize } = t.context
      const MockOpenAi = getMockModule()
      agent.config.feature_flag = { openai_instrumentation: false }

      initialize(agent, MockOpenAi, 'openai', shim)
      t.equal(shim.logger.debug.callCount, 2, 'should log 2 debug messages')
      t.equal(
        shim.logger.debug.args[0][0],
        'config.feature_flag.openai_instrumentation is disabled.'
      )
      t.equal(
        shim.logger.debug.args[1][0],
        'openai instrumentation support is for versions >=4.0.0. Skipping instrumentation.'
      )
      const isWrapped = shim.isWrapped(MockOpenAi.Chat.Completions.prototype.create)
      t.equal(isWrapped, false, 'should not wrap chat completions create')
      t.end()
    }
  )
})
7 changes: 6 additions & 1 deletion test/unit/llm-events/openai/embedding.test.js
Expand Up @@ -44,7 +44,12 @@ tap.test('LlmEmbedding', (t) => {
const api = helper.getAgentApi()
const metadata = { key: 'value', meta: 'data', test: true, data: [1, 2, 3] }
api.setLlmMetadata(metadata)
const embeddingEvent = new LlmEmbedding({ agent, segment: null, request: {}, response: {} })
const embeddingEvent = new LlmEmbedding({
agent,
segment: null,
request: {},
response: {}
})
t.same(embeddingEvent.metadata, metadata)
t.end()
})
Expand Down

0 comments on commit 34dcd70

Please sign in to comment.