Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: Added instrumentation for openai chat completion creation (#1862)
- Loading branch information
Showing
10 changed files
with
535 additions
and
5 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,147 @@ | ||
/* | ||
* Copyright 2023 New Relic Corporation. All rights reserved. | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
'use strict' | ||
const { openAiHeaders, openAiApiKey } = require('../../lib/symbols') | ||
const { | ||
LlmChatCompletionMessage, | ||
LlmChatCompletionSummary | ||
} = require('../../lib/llm-events/openai') | ||
|
||
const MIN_VERSION = '4.0.0' | ||
const semver = require('semver') | ||
|
||
/**
 * Decides whether the openai instrumentation should be skipped entirely.
 *
 * Instrumentation proceeds only when the `feature_flag.openai_instrumentation`
 * config flag is enabled AND the installed openai package version satisfies
 * the minimum supported version (>= 4.0.0).
 *
 * @param {object} config agent config
 * @param {Shim} shim instance of shim
 * @returns {boolean} flag if instrumentation should be skipped
 */
function shouldSkipInstrumentation(config, shim) {
  // TODO: Remove when we release full support for OpenAI
  const flagEnabled = config?.feature_flag?.openai_instrumentation
  if (!flagEnabled) {
    shim.logger.debug('config.feature_flag.openai_instrumentation is disabled.')
    return true
  }

  // Resolve the installed package version relative to the instrumented module.
  const pkg = shim.require('./package.json')
  return semver.lt(pkg.version, MIN_VERSION)
}
|
||
/**
 * Entry point for openai instrumentation. Wires up segment decoration and
 * LLM event recording for chat-completion creation.
 *
 * @param {Agent} agent the New Relic agent
 * @param {object} openai the resolved openai module export
 * @param {string} moduleName name of the instrumented module ('openai')
 * @param {Shim} shim instance of shim
 */
module.exports = function initialize(agent, openai, moduleName, shim) {
  // Bail out early when the feature flag is off or the package is too old.
  if (shouldSkipInstrumentation(agent.config, shim)) {
    shim.logger.debug(
      `${moduleName} instrumentation support is for versions >=${MIN_VERSION}. Skipping instrumentation.`
    )
    return
  }

  /**
   * Adds apiKey and response headers to the active segment
   * on symbols
   *
   * @param {object} result from openai request
   * @param {string} apiKey api key from openai client
   */
  function decorateSegment(result, apiKey) {
    const segment = shim.getActiveSegment()

    if (segment) {
      segment[openAiApiKey] = apiKey
      // Headers arrive as an iterable of [key, value] pairs; convert to a
      // plain object. When absent, the symbol is assigned a falsy value.
      segment[openAiHeaders] =
        result?.response?.headers && Object.fromEntries(result.response.headers)
    }
  }

  /**
   * Enqueues a LLM event to the custom event aggregator
   *
   * @param {string} type of LLM event
   * @param {object} msg LLM event
   */
  function recordEvent(type, msg) {
    agent.customEventAggregator.add([{ type, timestamp: Date.now() }, msg])
  }

  /**
   * Instrumentation is only done to get the response headers and attach
   * to the active segment as openai hides the headers from the functions we are
   * trying to instrument
   */
  shim.wrap(openai.prototype, 'makeRequest', function wrapRequest(shim, makeRequest) {
    return function wrappedRequest() {
      const apiKey = this.apiKey
      const result = makeRequest.apply(this, arguments)
      // Attach handlers without chaining: the ORIGINAL promise is returned to
      // user code so its resolution value/rejection reason are untouched.
      result.then(
        (data) => {
          // add headers on resolve
          decorateSegment(data, apiKey)
        },
        (data) => {
          // add headers on reject — openai error objects can also carry
          // response headers
          decorateSegment(data, apiKey)
        }
      )
      return result
    }
  })

  /**
   * Instruments chat completion creation
   * and creates the LLM events
   *
   * **Note**: Currently only for promises. streams will come later
   */
  shim.record(
    openai.Chat.Completions.prototype,
    'create',
    function wrapCreate(shim, create, name, args) {
      const [request] = args
      return {
        name: 'AI/OpenAI/Chat/Completions/Create',
        promise: true,
        opaque: true,
        // eslint-disable-next-line max-params
        after(_shim, _fn, _name, err, response, segment) {
          // Copy the values stashed on the segment by decorateSegment onto the
          // response so the Llm* event constructors can read them.
          // NOTE(review): when `err` is set, `response` may be undefined and
          // these assignments would throw — error handling is tracked in
          // https://github.com/newrelic/node-newrelic/issues/1845; confirm.
          response.headers = segment[openAiHeaders]
          response.api_key = segment[openAiApiKey]

          // TODO: add LlmErrorMessage on failure
          // and exit
          // See: https://github.com/newrelic/node-newrelic/issues/1845
          // if (err) {}

          const completionSummary = new LlmChatCompletionSummary({
            agent,
            segment,
            request,
            response
          })

          // One LlmChatCompletionMessage event per message in the request.
          request.messages.forEach((_msg, index) => {
            const completionMsg = new LlmChatCompletionMessage({
              agent,
              segment,
              request,
              response,
              index
            })

            recordEvent('LlmChatCompletionMessage', completionMsg)
          })

          recordEvent('LlmChatCompletionSummary', completionSummary)

          // cleanup keys on response before returning to user code
          delete response.api_key
          delete response.headers
        }
      }
    }
  )
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,92 @@ | ||
/* | ||
* Copyright 2023 New Relic Corporation. All rights reserved. | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
'use strict' | ||
|
||
const { test } = require('tap') | ||
const helper = require('../../lib/agent_helper') | ||
const GenericShim = require('../../../lib/shim/shim') | ||
const sinon = require('sinon') | ||
|
||
test('openai unit tests', (t) => {
  t.autoend()

  t.beforeEach(function (t) {
    const sandbox = sinon.createSandbox()
    const agent = helper.loadMockedAgent()
    agent.config.feature_flag = { openai_instrumentation: true }
    const shim = new GenericShim(agent, 'openai')
    // Stub module resolution so version gating can be controlled per test.
    sandbox.stub(shim, 'require')
    shim.require.returns({ version: '4.0.0' })
    sandbox.stub(shim.logger, 'debug')

    t.context.agent = agent
    t.context.shim = shim
    t.context.sandbox = sandbox
    t.context.initialize = require('../../../lib/instrumentation/openai')
  })

  t.afterEach(function (t) {
    helper.unloadAgent(t.context.agent)
    t.context.sandbox.restore()
  })

  /**
   * Builds a minimal stand-in for the openai module export with the
   * prototypes the instrumentation wraps.
   *
   * @returns {Function} mock OpenAI constructor
   */
  function getMockModule() {
    function Completions() {}
    Completions.prototype.create = function () {}
    function OpenAI() {}
    OpenAI.prototype.makeRequest = function () {}
    OpenAI.Chat = { Completions }
    return OpenAI
  }

  // Fixed typo in title: "openapi" -> "openai"
  t.test('should instrument openai if >= 4.0.0', (t) => {
    const { shim, agent, initialize } = t.context
    const MockOpenAi = getMockModule()
    initialize(agent, MockOpenAi, 'openai', shim)
    t.equal(shim.logger.debug.callCount, 0, 'should not log debug messages')
    const isWrapped = shim.isWrapped(MockOpenAi.Chat.Completions.prototype.create)
    t.equal(isWrapped, true, 'should wrap chat completions create')
    t.end()
  })

  t.test('should not register instrumentation if openai is < 4.0.0', (t) => {
    const { shim, agent, initialize } = t.context
    const MockOpenAi = getMockModule()
    shim.require.returns({ version: '3.7.0' })
    initialize(agent, MockOpenAi, 'openai', shim)
    // Only the version-gate message is logged here (the feature flag is on),
    // so exactly one debug call is expected. (Message previously said 2.)
    t.equal(shim.logger.debug.callCount, 1, 'should log 1 debug message')
    t.equal(
      shim.logger.debug.args[0][0],
      'openai instrumentation support is for versions >=4.0.0. Skipping instrumentation.'
    )
    const isWrapped = shim.isWrapped(MockOpenAi.Chat.Completions.prototype.create)
    t.equal(isWrapped, false, 'should not wrap chat completions create')
    t.end()
  })

  t.test(
    'should not register instrumentation if feature_flag.openai_instrumentation is false',
    (t) => {
      const { shim, agent, initialize } = t.context
      const MockOpenAi = getMockModule()
      agent.config.feature_flag = { openai_instrumentation: false }

      initialize(agent, MockOpenAi, 'openai', shim)
      // Both the flag-disabled message and the version-gate message fire.
      t.equal(shim.logger.debug.callCount, 2, 'should log 2 debug messages')
      t.equal(
        shim.logger.debug.args[0][0],
        'config.feature_flag.openai_instrumentation is disabled.'
      )
      t.equal(
        shim.logger.debug.args[1][0],
        'openai instrumentation support is for versions >=4.0.0. Skipping instrumentation.'
      )
      const isWrapped = shim.isWrapped(MockOpenAi.Chat.Completions.prototype.create)
      t.equal(isWrapped, false, 'should not wrap chat completions create')
      t.end()
    }
  )
})
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.