diff --git a/README.md b/README.md index 3ba5ba38e2..f3e11dfa15 100644 --- a/README.md +++ b/README.md @@ -158,7 +158,7 @@ These are the steps to work on core agent features, with more detail below: $ git clone git@github.com:your-user-name/node-newrelic.git $ cd node-newrelic -2. Install the project's dependences: +2. Install the project's dependencies: $ npm install diff --git a/lib/instrumentation/openai.js b/lib/instrumentation/openai.js new file mode 100644 index 0000000000..4d5f27c3c1 --- /dev/null +++ b/lib/instrumentation/openai.js @@ -0,0 +1,147 @@ +/* + * Copyright 2023 New Relic Corporation. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +'use strict' +const { openAiHeaders, openAiApiKey } = require('../../lib/symbols') +const { + LlmChatCompletionMessage, + LlmChatCompletionSummary +} = require('../../lib/llm-events/openai') + +const MIN_VERSION = '4.0.0' +const semver = require('semver') + +/** + * Checks if we should skip instrumentation. + * Currently it checks if `feature_flag.openai_instrumentation` is true + * and the package version >= 4.0.0 + * + * @param {object} config agent config + * @param {Shim} shim instance of shim + * @returns {boolean} flag if instrumentation should be skipped + */ +function shouldSkipInstrumentation(config, shim) { + // TODO: Remove when we release full support for OpenAI + if (!config?.feature_flag?.openai_instrumentation) { + shim.logger.debug('config.feature_flag.openai_instrumentation is disabled.') + return true + } + + const { version: pkgVersion } = shim.require('./package.json') + return semver.lt(pkgVersion, MIN_VERSION) +} + +module.exports = function initialize(agent, openai, moduleName, shim) { + if (shouldSkipInstrumentation(agent.config, shim)) { + shim.logger.debug( + `${moduleName} instrumentation support is for versions >=${MIN_VERSION}. 
Skipping instrumentation.` + ) + return + } + + /** + * Adds apiKey and response headers to the active segment + * on symbols + * + * @param {object} result from openai request + * @param {string} apiKey api key from openai client + */ + function decorateSegment(result, apiKey) { + const segment = shim.getActiveSegment() + + if (segment) { + segment[openAiApiKey] = apiKey + segment[openAiHeaders] = + result?.response?.headers && Object.fromEntries(result.response.headers) + } + } + + /** + * Enqueues a LLM event to the custom event aggregator + * + * @param {string} type of LLM event + * @param {object} msg LLM event + */ + function recordEvent(type, msg) { + agent.customEventAggregator.add([{ type, timestamp: Date.now() }, msg]) + } + + /** + * Instrumentation is only done to get the response headers and attach + * to the active segment as openai hides the headers from the functions we are + * trying to instrument + */ + shim.wrap(openai.prototype, 'makeRequest', function wrapRequest(shim, makeRequest) { + return function wrappedRequest() { + const apiKey = this.apiKey + const result = makeRequest.apply(this, arguments) + result.then( + (data) => { + // add headers on resolve + decorateSegment(data, apiKey) + }, + (data) => { + // add headers on reject + decorateSegment(data, apiKey) + } + ) + return result + } + }) + + /** + * Instruments chat completion creation + * and creates the LLM events + * + * **Note**: Currently only for promises. 
streams will come later + */ + shim.record( + openai.Chat.Completions.prototype, + 'create', + function wrapCreate(shim, create, name, args) { + const [request] = args + return { + name: 'AI/OpenAI/Chat/Completions/Create', + promise: true, + opaque: true, + // eslint-disable-next-line max-params + after(_shim, _fn, _name, err, response, segment) { + response.headers = segment[openAiHeaders] + response.api_key = segment[openAiApiKey] + + // TODO: add LlmErrorMessage on failure + // and exit + // See: https://github.com/newrelic/node-newrelic/issues/1845 + // if (err) {} + + const completionSummary = new LlmChatCompletionSummary({ + agent, + segment, + request, + response + }) + + request.messages.forEach((_msg, index) => { + const completionMsg = new LlmChatCompletionMessage({ + agent, + segment, + request, + response, + index + }) + + recordEvent('LlmChatCompletionMessage', completionMsg) + }) + + recordEvent('LlmChatCompletionSummary', completionSummary) + + // cleanup keys on response before returning to user code + delete response.api_key + delete response.headers + } + } + } + ) +} diff --git a/lib/instrumentations.js b/lib/instrumentations.js index 5058ac3ee8..3af2b365cd 100644 --- a/lib/instrumentations.js +++ b/lib/instrumentations.js @@ -28,6 +28,7 @@ module.exports = function instrumentations() { 'memcached': { type: MODULE_TYPE.DATASTORE }, 'mongodb': { type: MODULE_TYPE.DATASTORE }, 'mysql': { module: './instrumentation/mysql' }, + 'openai': { type: MODULE_TYPE.GENERIC }, '@nestjs/core': { type: MODULE_TYPE.WEB_FRAMEWORK }, 'pino': { module: './instrumentation/pino' }, 'pg': { type: MODULE_TYPE.DATASTORE }, diff --git a/lib/shim/shim.js b/lib/shim/shim.js index 96d3df4784..75d680916c 100644 --- a/lib/shim/shim.js +++ b/lib/shim/shim.js @@ -957,12 +957,12 @@ function record(nodule, properties, recordNamer) { return ret.then( function onThen(val) { segment.touch() - segDesc.after(shim, fn, name, null, val) + segDesc.after(shim, fn, name, null, val, 
segment) return val }, function onCatch(err) { segment.touch() - segDesc.after(shim, fn, name, err, null) + segDesc.after(shim, fn, name, err, null, segment) throw err // NOTE: This is not an error from our instrumentation. } ) @@ -973,7 +973,7 @@ function record(nodule, properties, recordNamer) { throw err // Just rethrowing this error, not our error! } finally { if (segDesc.after && (error || !promised)) { - segDesc.after(shim, fn, name, error, ret) + segDesc.after(shim, fn, name, error, ret, segment) } } } diff --git a/lib/symbols.js b/lib/symbols.js index 6d3403b600..2b3e74081e 100644 --- a/lib/symbols.js +++ b/lib/symbols.js @@ -18,6 +18,8 @@ module.exports = { offTheRecord: Symbol('offTheRecord'), original: Symbol('original'), wrapped: Symbol('shimWrapped'), + openAiHeaders: Symbol('openAiHeaders'), + openAiApiKey: Symbol('openAiApiKey'), parentSegment: Symbol('parentSegment'), prismaConnection: Symbol('prismaConnection'), prismaModelCall: Symbol('modelCall'), diff --git a/test/unit/instrumentation/openai.test.js b/test/unit/instrumentation/openai.test.js new file mode 100644 index 0000000000..6bcb521a9d --- /dev/null +++ b/test/unit/instrumentation/openai.test.js @@ -0,0 +1,92 @@ +/* + * Copyright 2023 New Relic Corporation. All rights reserved. 
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+'use strict'
+
+const { test } = require('tap')
+const helper = require('../../lib/agent_helper')
+const GenericShim = require('../../../lib/shim/shim')
+const sinon = require('sinon')
+
+test('openai unit tests', (t) => {
+  t.autoend()
+
+  t.beforeEach(function (t) {
+    const sandbox = sinon.createSandbox()
+    const agent = helper.loadMockedAgent()
+    agent.config.feature_flag = { openai_instrumentation: true }
+    const shim = new GenericShim(agent, 'openai')
+    sandbox.stub(shim, 'require')
+    shim.require.returns({ version: '4.0.0' })
+    sandbox.stub(shim.logger, 'debug')
+
+    t.context.agent = agent
+    t.context.shim = shim
+    t.context.sandbox = sandbox
+    t.context.initialize = require('../../../lib/instrumentation/openai')
+  })
+
+  t.afterEach(function (t) {
+    helper.unloadAgent(t.context.agent)
+    t.context.sandbox.restore()
+  })
+
+  function getMockModule() {
+    function Completions() {}
+    Completions.prototype.create = function () {}
+    function OpenAI() {}
+    OpenAI.prototype.makeRequest = function () {}
+    OpenAI.Chat = { Completions }
+    return OpenAI
+  }
+
+  t.test('should instrument openai if >= 4.0.0', (t) => {
+    const { shim, agent, initialize } = t.context
+    const MockOpenAi = getMockModule()
+    initialize(agent, MockOpenAi, 'openai', shim)
+    t.equal(shim.logger.debug.callCount, 0, 'should not log debug messages')
+    const isWrapped = shim.isWrapped(MockOpenAi.Chat.Completions.prototype.create)
+    t.equal(isWrapped, true, 'should wrap chat completions create')
+    t.end()
+  })
+
+  t.test('should not register instrumentation if openai is < 4.0.0', (t) => {
+    const { shim, agent, initialize } = t.context
+    const MockOpenAi = getMockModule()
+    shim.require.returns({ version: '3.7.0' })
+    initialize(agent, MockOpenAi, 'openai', shim)
+    t.equal(shim.logger.debug.callCount, 1, 'should log 1 debug message')
+    t.equal(
+      shim.logger.debug.args[0][0],
+      'openai instrumentation support is for versions >=4.0.0. 
Skipping instrumentation.' + ) + const isWrapped = shim.isWrapped(MockOpenAi.Chat.Completions.prototype.create) + t.equal(isWrapped, false, 'should not wrap chat completions create') + t.end() + }) + + t.test( + 'should not register instrumentation if feature_flag.openai_instrumentation is false', + (t) => { + const { shim, agent, initialize } = t.context + const MockOpenAi = getMockModule() + agent.config.feature_flag = { openai_instrumentation: false } + + initialize(agent, MockOpenAi, 'openai', shim) + t.equal(shim.logger.debug.callCount, 2, 'should log 2 debug messages') + t.equal( + shim.logger.debug.args[0][0], + 'config.feature_flag.openai_instrumentation is disabled.' + ) + t.equal( + shim.logger.debug.args[1][0], + 'openai instrumentation support is for versions >=4.0.0. Skipping instrumentation.' + ) + const isWrapped = shim.isWrapped(MockOpenAi.Chat.Completions.prototype.create) + t.equal(isWrapped, false, 'should not wrap chat completions create') + t.end() + } + ) +}) diff --git a/test/unit/llm-events/openai/embedding.test.js b/test/unit/llm-events/openai/embedding.test.js index 14b17931e8..98dd115db4 100644 --- a/test/unit/llm-events/openai/embedding.test.js +++ b/test/unit/llm-events/openai/embedding.test.js @@ -44,7 +44,12 @@ tap.test('LlmEmbedding', (t) => { const api = helper.getAgentApi() const metadata = { key: 'value', meta: 'data', test: true, data: [1, 2, 3] } api.setLlmMetadata(metadata) - const embeddingEvent = new LlmEmbedding({ agent, segment: null, request: {}, response: {} }) + const embeddingEvent = new LlmEmbedding({ + agent, + segment: null, + request: {}, + response: {} + }) t.same(embeddingEvent.metadata, metadata) t.end() }) diff --git a/test/unit/shim/shim.test.js b/test/unit/shim/shim.test.js index a3eac7442b..80a740727c 100644 --- a/test/unit/shim/shim.test.js +++ b/test/unit/shim/shim.test.js @@ -960,6 +960,64 @@ tap.test('Shim', function (t) { t.end() }) }) + + t.test('should call after hook on record when function is done 
executing', function (t) { + helper.runInTransaction(agent, function () { + function testAfter() { + return 'result' + } + const wrapped = shim.record(testAfter, function () { + return { + name: 'test segment', + callback: shim.LAST, + after(...args) { + t.equal(args.length, 6, 'should have 6 args to after hook') + const [, fn, fnName, err, val, segment] = args + t.equal(segment.name, 'test segment') + t.not(err) + t.same(fn, testAfter) + t.equal(fnName, testAfter.name) + t.equal(val, 'result') + } + } + }) + t.doesNotThrow(function () { + wrapped() + }) + t.end() + }) + }) + + t.test( + 'should call after hook on record when the function is done executing after failure', + function (t) { + const err = new Error('test err') + helper.runInTransaction(agent, function () { + function testAfter() { + throw err + } + const wrapped = shim.record(testAfter, function () { + return { + name: 'test segment', + callback: shim.LAST, + after(...args) { + t.equal(args.length, 6, 'should have 6 args to after hook') + const [, fn, fnName, expectedErr, val, segment] = args + t.equal(segment.name, 'test segment') + t.same(expectedErr, err) + t.equal(val, undefined) + t.same(fn, testAfter) + t.equal(fnName, testAfter.name) + } + } + }) + t.throws(function () { + wrapped() + }) + t.end() + }) + } + ) }) t.test('#record with a stream', function (t) { @@ -1285,6 +1343,66 @@ tap.test('Shim', function (t) { promise.reject(result) }, 5) }) + + t.test('should call after hook when promise resolves', (t) => { + const name = 'test segment' + const result = { returned: true } + const wrapped = shim.record(toWrap, function () { + return { + name, + promise: true, + after(...args) { + t.equal(args.length, 6, 'should have 6 args to after hook') + const [, fn, fnName, err, val, segment] = args + t.same(fn, toWrap) + t.equal(fnName, toWrap.name) + t.not(err) + t.same(val, result) + t.equal(segment.name, name) + t.end() + } + } + }) + + helper.runInTransaction(agent, function () { + const ret = 
wrapped() + t.ok(ret instanceof Object.getPrototypeOf(promise).constructor) + }) + + setTimeout(function () { + promise.resolve(result) + }, 5) + }) + + t.test('should call after hook when promise reject', (t) => { + const name = 'test segment' + const result = { returned: true } + const wrapped = shim.record(toWrap, function () { + return { + name, + promise: true, + after(...args) { + t.equal(args.length, 6, 'should have 6 args to after hook') + const [, fn, fnName, err, val, segment] = args + t.same(fn, toWrap) + t.equal(fnName, toWrap.name) + t.same(err, result) + t.not(val) + t.equal(segment.name, name) + t.end() + } + } + }) + + helper.runInTransaction(agent, function () { + const ret = wrapped() + t.ok(ret instanceof Object.getPrototypeOf(promise).constructor) + }) + + setTimeout(function () { + promise.reject(result) + }, 5) + }) }) t.test('#record wrapper when called without a transaction', function (t) { diff --git a/test/versioned/openai/openai.tap.js b/test/versioned/openai/openai.tap.js new file mode 100644 index 0000000000..b7dfbf3554 --- /dev/null +++ b/test/versioned/openai/openai.tap.js @@ -0,0 +1,144 @@ +/* + * Copyright 2023 New Relic Corporation. All rights reserved. 
+ * SPDX-License-Identifier: Apache-2.0 + */ + +'use strict' + +const tap = require('tap') +const helper = require('../../lib/agent_helper') +const createOpenAIMockServer = require('../../lib/openai-mock-server') +// TODO: remove config once we fully release OpenAI instrumentation +const config = { + feature_flag: { + openai_instrumentation: true + } +} + +tap.test('OpenAI instrumentation', (t) => { + t.autoend() + + t.before(async () => { + const { host, port, server } = await createOpenAIMockServer() + t.context.server = server + t.context.agent = helper.instrumentMockedAgent(config) + const OpenAI = require('openai') + t.context.client = new OpenAI({ + apiKey: 'fake-versioned-test-key', + baseURL: `http://${host}:${port}` + }) + }) + + t.afterEach(() => { + t.context.agent.customEventAggregator.clear() + }) + + t.teardown(() => { + t.context?.server?.close() + t.context.agent && helper.unloadAgent(t.context.agent) + }) + + t.test('should create chat completion span on successful chat completion create', (test) => { + const { client, agent } = t.context + helper.runInTransaction(agent, async (tx) => { + const results = await client.chat.completions.create({ + messages: [{ role: 'user', content: 'You are a mathematician.' }] + }) + + test.not(results.headers, 'should remove response headers from user result') + test.not(results.api_key, 'should remove api_key from user result') + test.equal(results.choices[0].message.content, '1 plus 2 is 3.') + + const [span] = tx.trace.root.children + test.equal(span.name, 'AI/OpenAI/Chat/Completions/Create') + test.end() + }) + }) + + t.test('should create chat completion message and summary for every message sent', (test) => { + const { client, agent } = t.context + helper.runInTransaction(agent, async (tx) => { + await client.chat.completions.create({ + max_tokens: 100, + temperature: 0.5, + model: 'gpt-3.5-turbo-0613', + messages: [ + { role: 'user', content: 'You are a mathematician.' 
}, + { role: 'user', content: 'What does 1 plus 1 equal?' } + ] + }) + + const events = agent.customEventAggregator.events.toArray() + test.equal(events.length, 3, 'should create a chat completion message and summary event') + const chatMsgs = events.filter(([{ type }]) => type === 'LlmChatCompletionMessage') + const expectedChatMsg = { + 'appName': 'New Relic for Node.js tests', + 'request_id': '49dbbffbd3c3f4612aa48def69059aad', + 'trace_id': tx.traceId, + 'span_id': tx.trace.root.children[0].id, + 'transaction_id': tx.id, + 'response.model': 'gpt-3.5-turbo-0613', + 'vendor': 'openAI', + 'ingest_source': 'Node', + 'role': 'user', + 'completion_id': /[a-f0-9]{36}/ + } + + chatMsgs.forEach((msg) => { + if (msg[1].sequence === 0) { + expectedChatMsg.sequence = 0 + ;(expectedChatMsg.id = 'chatcmpl-87sb95K4EF2nuJRcTs43Tm9ntTeat-0'), + (expectedChatMsg.content = 'You are a mathematician.') + } else { + expectedChatMsg.sequence = 1 + ;(expectedChatMsg.id = 'chatcmpl-87sb95K4EF2nuJRcTs43Tm9ntTeat-1'), + (expectedChatMsg.content = 'What does 1 plus 1 equal?') + } + + test.equal(msg[0].type, 'LlmChatCompletionMessage') + test.match(msg[1], expectedChatMsg, 'should match chat completion message') + }) + + const chatSummary = events.filter(([{ type }]) => type === 'LlmChatCompletionSummary')[0] + test.equal(chatSummary[0].type, 'LlmChatCompletionSummary') + const expectedChatSummary = { + 'id': /[a-f0-9]{36}/, + 'appName': 'New Relic for Node.js tests', + 'request_id': '49dbbffbd3c3f4612aa48def69059aad', + 'trace_id': tx.traceId, + 'span_id': tx.trace.root.children[0].id, + 'transaction_id': tx.id, + 'response.model': 'gpt-3.5-turbo-0613', + 'vendor': 'openAI', + 'ingest_source': 'Node', + 'request.model': 'gpt-3.5-turbo-0613', + 'duration': tx.trace.root.children[0].getExclusiveDurationInMillis(), + 'api_key_last_four_digits': 'sk--key', + 'response.organization': 'new-relic-nkmd8b', + 'response.usage.total_tokens': 64, + 'response.usage.prompt_tokens': 53, + 
'response.headers.llmVersion': '2020-10-01',
+        'response.headers.ratelimitLimitRequests': '200',
+        'response.headers.ratelimitLimitTokens': '40000',
+        'response.headers.ratelimitResetTokens': '90ms',
+        'response.headers.ratelimitRemainingTokens': '39940',
+        'response.headers.ratelimitRemainingRequests': '199',
+        'response.number_of_messages': 3,
+        'response.usage.completion_tokens': 11,
+        'response.choices.finish_reason': 'stop'
+      }
+      test.match(chatSummary[1], expectedChatSummary, 'should match chat summary message')
+      test.end()
+    })
+  })
+
+  t.test('should not create llm events when not in a transaction', async (test) => {
+    const { client, agent } = t.context
+    await client.chat.completions.create({
+      messages: [{ role: 'user', content: 'You are a mathematician.' }]
+    })
+
+    const events = agent.customEventAggregator.events.toArray()
+    test.equal(events.length, 0, 'should not create llm events')
+  })
+})
diff --git a/test/versioned/openai/package.json b/test/versioned/openai/package.json
new file mode 100644
index 0000000000..3a9993adf8
--- /dev/null
+++ b/test/versioned/openai/package.json
@@ -0,0 +1,21 @@
+{
+  "name": "openai-tests",
+  "version": "0.0.0",
+  "private": true,
+  "engines": {
+    "node": ">=16"
+  },
+  "tests": [
+    {
+      "engines": {
+        "node": ">=16"
+      },
+      "dependencies": {
+        "openai": ">=4.0.0"
+      },
+      "files": [
+        "openai.tap.js"
+      ]
+    }
+  ]
+}