82 changes: 39 additions & 43 deletions packages/sdk/server-ai/__tests__/LDAIClientImpl.test.ts
@@ -11,36 +11,24 @@ const mockLdClient: jest.Mocked<LDClientMin> = {

const testContext: LDContext = { kind: 'user', key: 'test-user' };

it('interpolates template variables', () => {
const client = new LDAIClientImpl(mockLdClient);
const template = 'Hello {{name}}, your score is {{score}}';
const variables = { name: 'John', score: 42 };

const result = client.interpolateTemplate(template, variables);
expect(result).toBe('Hello John, your score is 42');
});

it('handles empty variables in template interpolation', () => {
const client = new LDAIClientImpl(mockLdClient);
const template = 'Hello {{name}}';
const variables = {};

const result = client.interpolateTemplate(template, variables);
expect(result).toBe('Hello ');
});

it('returns model config with interpolated prompts', async () => {
it('returns config with interpolated messages', async () => {
const client = new LDAIClientImpl(mockLdClient);
const key = 'test-flag';
const defaultValue: LDAIDefaults = {
model: { modelId: 'test', name: 'test-model' },
prompt: [],
model: { id: 'test', parameters: { name: 'test-model' } },
messages: [],
enabled: true,
};

const mockVariation = {
model: { modelId: 'example-provider', name: 'imagination', temperature: 0.7, maxTokens: 4096 },
prompt: [
model: {
id: 'example-model',
parameters: { name: 'imagination', temperature: 0.7, maxTokens: 4096 },
},
provider: {
id: 'example-provider',
},
messages: [
{ role: 'system', content: 'Hello {{name}}' },
{ role: 'user', content: 'Score: {{score}}' },
],
@@ -53,11 +41,17 @@ it('returns model config with interpolated prompts', async () => {
mockLdClient.variation.mockResolvedValue(mockVariation);

const variables = { name: 'John', score: 42 };
const result = await client.modelConfig(key, testContext, defaultValue, variables);
const result = await client.config(key, testContext, defaultValue, variables);

expect(result).toEqual({
model: { modelId: 'example-provider', name: 'imagination', temperature: 0.7, maxTokens: 4096 },
prompt: [
model: {
id: 'example-model',
parameters: { name: 'imagination', temperature: 0.7, maxTokens: 4096 },
},
provider: {
id: 'example-provider',
},
messages: [
{ role: 'system', content: 'Hello John' },
{ role: 'user', content: 'Score: 42' },
],
@@ -66,46 +60,46 @@ it('includes context in variables for prompt interpolation', async () => {
});
});

it('includes context in variables for prompt interpolation', async () => {
it('includes context in variables for messages interpolation', async () => {
const client = new LDAIClientImpl(mockLdClient);
const key = 'test-flag';
const defaultValue: LDAIDefaults = {
model: { modelId: 'test', name: 'test-model' },
prompt: [],
model: { id: 'test', parameters: { name: 'test-model' } },
messages: [],
};

const mockVariation = {
prompt: [{ role: 'system', content: 'User key: {{ldctx.key}}' }],
messages: [{ role: 'system', content: 'User key: {{ldctx.key}}' }],
_ldMeta: { versionKey: 'v1', enabled: true },
};

mockLdClient.variation.mockResolvedValue(mockVariation);

const result = await client.modelConfig(key, testContext, defaultValue);
const result = await client.config(key, testContext, defaultValue);

expect(result.prompt?.[0].content).toBe('User key: test-user');
expect(result.messages?.[0].content).toBe('User key: test-user');
});

it('handles missing metadata in variation', async () => {
const client = new LDAIClientImpl(mockLdClient);
const key = 'test-flag';
const defaultValue: LDAIDefaults = {
model: { modelId: 'test', name: 'test-model' },
prompt: [],
model: { id: 'test', parameters: { name: 'test-model' } },
messages: [],
};

const mockVariation = {
model: { modelId: 'example-provider', name: 'imagination' },
prompt: [{ role: 'system', content: 'Hello' }],
model: { id: 'example-provider', parameters: { name: 'imagination' } },
messages: [{ role: 'system', content: 'Hello' }],
};

mockLdClient.variation.mockResolvedValue(mockVariation);

const result = await client.modelConfig(key, testContext, defaultValue);
const result = await client.config(key, testContext, defaultValue);

expect(result).toEqual({
model: { modelId: 'example-provider', name: 'imagination' },
prompt: [{ role: 'system', content: 'Hello' }],
model: { id: 'example-provider', parameters: { name: 'imagination' } },
messages: [{ role: 'system', content: 'Hello' }],
tracker: expect.any(Object),
enabled: false,
});
@@ -115,18 +109,20 @@ it('passes the default value to the underlying client', async () => {
const client = new LDAIClientImpl(mockLdClient);
const key = 'non-existent-flag';
const defaultValue: LDAIDefaults = {
model: { modelId: 'default-model', name: 'default' },
prompt: [{ role: 'system', content: 'Default prompt' }],
model: { id: 'default-model', parameters: { name: 'default' } },
provider: { id: 'default-provider' },
messages: [{ role: 'system', content: 'Default messages' }],
enabled: true,
};

mockLdClient.variation.mockResolvedValue(defaultValue);

const result = await client.modelConfig(key, testContext, defaultValue);
const result = await client.config(key, testContext, defaultValue);

expect(result).toEqual({
model: defaultValue.model,
prompt: defaultValue.prompt,
messages: defaultValue.messages,
provider: defaultValue.provider,
tracker: expect.any(Object),
enabled: false,
});
@@ -90,7 +90,7 @@ it('tracks OpenAI usage', async () => {
const PROMPT_TOKENS = 49;
const COMPLETION_TOKENS = 51;

await tracker.trackOpenAI(async () => ({
await tracker.trackOpenAIMetrics(async () => ({
usage: {
total_tokens: TOTAL_TOKENS,
prompt_tokens: PROMPT_TOKENS,
@@ -151,7 +151,7 @@ it('tracks Bedrock conversation with successful response', () => {
},
};

tracker.trackBedrockConverse(response);
tracker.trackBedrockConverseMetrics(response);

expect(mockTrack).toHaveBeenCalledWith(
'$ld:ai:generation',
@@ -198,7 +198,7 @@ it('tracks Bedrock conversation with error response', () => {

// TODO: We may want a track failure.

tracker.trackBedrockConverse(response);
tracker.trackBedrockConverseMetrics(response);

expect(mockTrack).not.toHaveBeenCalled();
});
14 changes: 7 additions & 7 deletions packages/sdk/server-ai/examples/bedrock/src/index.ts
@@ -48,12 +48,12 @@ async function main() {

const aiClient = initAi(ldClient);

const aiConfig = await aiClient.modelConfig(
const aiConfig = await aiClient.config(
aiConfigKey!,
context,
{
model: {
modelId: 'my-default-model',
id: 'my-default-model',
},
enabled: true,
},
@@ -63,14 +63,14 @@
);
const { tracker } = aiConfig;

const completion = tracker.trackBedrockConverse(
const completion = tracker.trackBedrockConverseMetrics(
await awsClient.send(
new ConverseCommand({
modelId: aiConfig.model?.modelId ?? 'no-model',
messages: mapPromptToConversation(aiConfig.prompt ?? []),
modelId: aiConfig.model?.id ?? 'no-model',
messages: mapPromptToConversation(aiConfig.messages ?? []),
inferenceConfig: {
temperature: aiConfig.model?.temperature ?? 0.5,
maxTokens: aiConfig.model?.maxTokens ?? 4096,
temperature: (aiConfig.model?.parameters?.temperature as number) ?? 0.5,
maxTokens: (aiConfig.model?.parameters?.maxTokens as number) ?? 4096,
},
}),
),
14 changes: 7 additions & 7 deletions packages/sdk/server-ai/examples/openai/src/index.ts
@@ -46,24 +46,24 @@ async function main(): Promise<void> {

const aiClient = initAi(ldClient);

const aiConfig = await aiClient.modelConfig(
const aiConfig = await aiClient.config(
aiConfigKey,
context,
{
model: {
modelId: 'gpt-4',
id: 'gpt-4',
},
},
{ myVariable: 'My User Defined Variable' },
);

const { tracker } = aiConfig;
const completion = await tracker.trackOpenAI(async () =>
const completion = await tracker.trackOpenAIMetrics(async () =>
client.chat.completions.create({
messages: aiConfig.prompt || [],
model: aiConfig.model?.modelId || 'gpt-4',
temperature: aiConfig.model?.temperature ?? 0.5,
max_tokens: aiConfig.model?.maxTokens ?? 4096,
messages: aiConfig.messages || [],
model: aiConfig.model?.id || 'gpt-4',
temperature: (aiConfig.model?.parameters?.temperature as number) ?? 0.5,
max_tokens: (aiConfig.model?.parameters?.maxTokens as number) ?? 4096,
}),
);

18 changes: 11 additions & 7 deletions packages/sdk/server-ai/src/LDAIClientImpl.ts
@@ -2,7 +2,7 @@ import * as Mustache from 'mustache';

import { LDContext } from '@launchdarkly/js-server-sdk-common';

import { LDAIConfig, LDAIDefaults, LDMessage, LDModelConfig } from './api/config';
import { LDAIConfig, LDAIDefaults, LDMessage, LDModelConfig, LDProviderConfig } from './api/config';
import { LDAIClient } from './api/LDAIClient';
import { LDAIConfigTrackerImpl } from './LDAIConfigTrackerImpl';
import { LDClientMin } from './LDClientMin';
@@ -21,18 +21,19 @@ interface LDMeta {
*/
interface VariationContent {
model?: LDModelConfig;
prompt?: LDMessage[];
messages?: LDMessage[];
provider?: LDProviderConfig;
_ldMeta?: LDMeta;
}

export class LDAIClientImpl implements LDAIClient {
constructor(private _ldClient: LDClientMin) {}

interpolateTemplate(template: string, variables: Record<string, unknown>): string {
private _interpolateTemplate(template: string, variables: Record<string, unknown>): string {
return Mustache.render(template, variables, undefined, { escape: (item: any) => item });
}

async modelConfig(
async config(
key: string,
context: LDContext,
defaultValue: LDAIDefaults,
@@ -57,12 +58,15 @@ export class LDAIClientImpl implements LDAIClient {
if (value.model) {
config.model = { ...value.model };
}
if (value.provider) {
config.provider = { ...value.provider };
}
const allVariables = { ...variables, ldctx: context };

if (value.prompt) {
config.prompt = value.prompt.map((entry: any) => ({
if (value.messages) {
config.messages = value.messages.map((entry: any) => ({
...entry,
content: this.interpolateTemplate(entry.content, allVariables),
content: this._interpolateTemplate(entry.content, allVariables),
}));
}

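For reference, a minimal sketch of what the private interpolation helper above does, assuming the mustache package from the import and the ldctx convention shown in the tests; the variable values here are illustrative only:

import * as Mustache from 'mustache';

// Render with HTML escaping disabled, matching the escape override in _interpolateTemplate.
const interpolate = (template: string, variables: Record<string, unknown>): string =>
  Mustache.render(template, variables, undefined, { escape: (item: any) => item });

// Caller-supplied variables are merged with the evaluation context under the `ldctx` key.
const allVariables = { name: 'John', ldctx: { kind: 'user', key: 'test-user' } };
interpolate('Hello {{name}}, your key is {{ldctx.key}}', allVariables);
// => 'Hello John, your key is test-user'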
4 changes: 2 additions & 2 deletions packages/sdk/server-ai/src/LDAIConfigTrackerImpl.ts
@@ -45,7 +45,7 @@ export class LDAIConfigTrackerImpl implements LDAIConfigTracker {
this._ldClient.track('$ld:ai:generation', this._context, this._getTrackData(), 1);
}

async trackOpenAI<
async trackOpenAIMetrics<
TRes extends {
usage?: {
total_tokens?: number;
@@ -62,7 +62,7 @@ export class LDAIConfigTrackerImpl implements LDAIConfigTracker {
return result;
}

trackBedrockConverse<
trackBedrockConverseMetrics<
TRes extends {
$metadata: { httpStatusCode?: number };
metrics?: { latencyMs?: number };
11 changes: 1 addition & 10 deletions packages/sdk/server-ai/src/api/LDAIClient.ts
@@ -6,15 +6,6 @@ import { LDAIConfig, LDAIDefaults } from './config/LDAIConfig';
* Interface for performing AI operations using LaunchDarkly.
*/
export interface LDAIClient {
/**
* Parses and interpolates a template string with the provided variables.
*
* @param template The template string to be parsed and interpolated.
* @param variables An object containing the variables to be used for interpolation.
* @returns The interpolated string.
*/
interpolateTemplate(template: string, variables: Record<string, unknown>): string;

/**
* Retrieves and processes an AI configuration based on the provided key, LaunchDarkly context,
* and variables. This includes the model configuration and the processed prompts.
@@ -67,7 +58,7 @@ export interface LDAIClient {
* }
* ```
*/
modelConfig(
config(
key: string,
context: LDContext,
defaultValue: LDAIDefaults,
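A hedged usage sketch of the renamed method, pieced together from the examples in this PR; the flag key, context, and interpolation variables are placeholders, and aiClient and openAIClient are assumed to be initialized as in the examples above:

// `config` replaces `modelConfig`; defaults now use `id`, `parameters`, and `messages`.
const aiConfig = await aiClient.config(
  'my-ai-config-key',                      // hypothetical flag key
  { kind: 'user', key: 'example-user' },
  {
    model: { id: 'gpt-4', parameters: { temperature: 0.5 } },
    messages: [{ role: 'system', content: 'Hello {{name}}' }],
    enabled: false,
  },
  { name: 'Sandy' },                       // interpolation variables
);

// The tracker exposes the renamed metric helpers.
const completion = await aiConfig.tracker.trackOpenAIMetrics(async () =>
  openAIClient.chat.completions.create({
    model: aiConfig.model?.id ?? 'gpt-4',
    messages: aiConfig.messages ?? [],
  }),
);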
24 changes: 15 additions & 9 deletions packages/sdk/server-ai/src/api/config/LDAIConfig.ts
@@ -7,23 +7,24 @@ export interface LDModelConfig {
/**
* The ID of the model.
*/
modelId: string;
id: string;

/**
* Tuning parameter for randomness versus determinism. Exact effect will be determined by the
* model.
* Model specific parameters.
*/
temperature?: number;
parameters?: { [index: string]: unknown };

/**
* The maximum number of tokens.
* Additional user-specified parameters.
*/
maxTokens?: number;
custom?: { [index: string]: unknown };
}

export interface LDProviderConfig {
/**
* And additional model specific information.
* The ID of the provider.
*/
[index: string]: unknown;
id: string;
}

/**
@@ -51,7 +52,12 @@ export interface LDAIConfig {
/**
* Optional prompt data.
*/
prompt?: LDMessage[];
messages?: LDMessage[];

/**
* Optional configuration for the provider.
*/
provider?: LDProviderConfig;

/**
* A tracker which can be used to generate analytics.
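To summarize the reshaped configuration types, a small illustrative value; the field contents are placeholders rather than values from the PR:

const model: LDModelConfig = {
  id: 'example-model',                                 // previously `modelId`
  parameters: { temperature: 0.7, maxTokens: 4096 },   // model-specific tuning, previously top-level fields
  custom: { team: 'example-team' },                    // additional user-specified parameters (hypothetical)
};

const provider: LDProviderConfig = { id: 'example-provider' };

const defaults: LDAIDefaults = {
  model,
  provider,
  messages: [{ role: 'system', content: 'Hello {{name}}' }], // previously `prompt`
  enabled: true,
};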