ask the system prompt from the Playground creation form #643

Merged
113 changes: 45 additions & 68 deletions packages/backend/src/managers/playgroundV2Manager.spec.ts
@@ -76,7 +76,7 @@ test('submit should throw an error if the server is stopped', async () => {
} as unknown as InferenceServer,
]);
const manager = new PlaygroundV2Manager(webviewMock, inferenceManagerMock, taskRegistryMock);
await manager.createPlayground('playground 1', { id: 'model1' } as ModelInfo, 'tracking-1');
await manager.createPlayground('playground 1', { id: 'model1' } as ModelInfo, '', 'tracking-1');

vi.mocked(inferenceManagerMock.getServers).mockReturnValue([
{
@@ -107,7 +107,7 @@ test('submit should throw an error if the server is unhealthy', async () => {
} as unknown as InferenceServer,
]);
const manager = new PlaygroundV2Manager(webviewMock, inferenceManagerMock, taskRegistryMock);
await manager.createPlayground('p1', { id: 'model1' } as ModelInfo, 'tracking-1');
await manager.createPlayground('p1', { id: 'model1' } as ModelInfo, '', 'tracking-1');
const playgroundId = manager.getPlaygrounds()[0].id;
await expect(manager.submit(playgroundId, 'dummyUserInput', '')).rejects.toThrowError(
'Inference server is not healthy, currently status: unhealthy.',
@@ -133,12 +133,42 @@ test('create playground should create conversation.', async () => {
]);
const manager = new PlaygroundV2Manager(webviewMock, inferenceManagerMock, taskRegistryMock);
expect(manager.getConversations().length).toBe(0);
await manager.createPlayground('playground 1', { id: 'model-1' } as ModelInfo, 'tracking-1');
await manager.createPlayground('playground 1', { id: 'model-1' } as ModelInfo, '', 'tracking-1');

const conversations = manager.getConversations();
expect(conversations.length).toBe(1);
});

test('create playground called with a system prompt should create conversation with a system message.', async () => {
vi.mocked(inferenceManagerMock.getServers).mockReturnValue([
{
status: 'running',
health: {
Status: 'healthy',
},
models: [
{
id: 'dummyModelId',
file: {
file: 'dummyModelFile',
},
},
],
} as unknown as InferenceServer,
]);
const manager = new PlaygroundV2Manager(webviewMock, inferenceManagerMock, taskRegistryMock);
expect(manager.getConversations().length).toBe(0);
await manager.createPlayground('playground 1', { id: 'model-1' } as ModelInfo, 'a system prompt', 'tracking-1');

const conversations = manager.getConversations();
expect(conversations.length).toBe(1);
const conversation = conversations[0];
expect(conversation.messages).toHaveLength(1);
const systemMessage = conversation.messages[0];
expect(systemMessage.role).toEqual('system');
expect(systemMessage.content).toEqual('a system prompt');
});

test('valid submit should create IPlaygroundMessage and notify the webview', async () => {
vi.mocked(inferenceManagerMock.getServers).mockReturnValue([
{
@@ -169,7 +199,7 @@ test('valid submit should create IPlaygroundMessage and notify the webview', asy
} as unknown as OpenAI);

const manager = new PlaygroundV2Manager(webviewMock, inferenceManagerMock, taskRegistryMock);
await manager.createPlayground('playground 1', { id: 'dummyModelId' } as ModelInfo, 'tracking-1');
await manager.createPlayground('playground 1', { id: 'dummyModelId' } as ModelInfo, undefined, 'tracking-1');

const date = new Date(2000, 1, 1, 13);
vi.setSystemTime(date);
@@ -208,65 +238,6 @@ test('valid submit should create IPlaygroundMessage and notify the webview', asy
});
});

test.each(['', 'my system prompt'])(
'valid submit should send a message with system prompt if non empty, system prompt is "%s"}',
async (systemPrompt: string) => {
vi.mocked(inferenceManagerMock.getServers).mockReturnValue([
{
status: 'running',
health: {
Status: 'healthy',
},
models: [
{
id: 'dummyModelId',
file: {
file: 'dummyModelFile',
},
},
],
connection: {
port: 8888,
},
} as unknown as InferenceServer,
]);
const createMock = vi.fn().mockResolvedValue([]);
vi.mocked(OpenAI).mockReturnValue({
chat: {
completions: {
create: createMock,
},
},
} as unknown as OpenAI);

const manager = new PlaygroundV2Manager(webviewMock, inferenceManagerMock, taskRegistryMock);
await manager.createPlayground('playground 1', { id: 'dummyModelId' } as ModelInfo, 'tracking-1');

const playgrounds = manager.getPlaygrounds();
await manager.submit(playgrounds[0].id, 'dummyUserInput', systemPrompt);

const messages: unknown[] = [
{
content: 'dummyUserInput',
id: expect.any(String),
role: 'user',
timestamp: expect.any(Number),
},
];
if (systemPrompt) {
messages.push({
content: 'my system prompt',
role: 'system',
});
}
expect(createMock).toHaveBeenCalledWith({
messages,
model: 'dummyModelFile',
stream: true,
});
},
);

test('submit should send options', async () => {
vi.mocked(inferenceManagerMock.getServers).mockReturnValue([
{
@@ -297,7 +268,7 @@ test('submit should send options', async () => {
} as unknown as OpenAI);

const manager = new PlaygroundV2Manager(webviewMock, inferenceManagerMock, taskRegistryMock);
await manager.createPlayground('playground 1', { id: 'dummyModelId' } as ModelInfo, 'tracking-1');
await manager.createPlayground('playground 1', { id: 'dummyModelId' } as ModelInfo, undefined, 'tracking-1');

const playgrounds = manager.getPlaygrounds();
await manager.submit(playgrounds[0].id, 'dummyUserInput', '', { temperature: 0.123, max_tokens: 45, top_p: 0.345 });
@@ -334,6 +305,7 @@ test('creating a new playground should send new playground to frontend', async (
id: 'model-1',
name: 'Model 1',
} as unknown as ModelInfo,
'',
'tracking-1',
);
expect(webviewMock.postMessage).toHaveBeenCalledWith({
@@ -357,6 +329,7 @@ test('creating a new playground with no name should send new playground to front
id: 'model-1',
name: 'Model 1',
} as unknown as ModelInfo,
'',
'tracking-1',
);
expect(webviewMock.postMessage).toHaveBeenCalledWith({
@@ -381,6 +354,7 @@ test('creating a new playground with no model served should start an inference s
id: 'model-1',
name: 'Model 1',
} as unknown as ModelInfo,
'',
'tracking-1',
);
expect(createInferenceServerMock).toHaveBeenCalledWith(
@@ -417,6 +391,7 @@ test('creating a new playground with the model already served should not start a
id: 'model-1',
name: 'Model 1',
} as unknown as ModelInfo,
'',
'tracking-1',
);
expect(createInferenceServerMock).not.toHaveBeenCalled();
@@ -445,6 +420,7 @@ test('creating a new playground with the model server stopped should start the i
id: 'model-1',
name: 'Model 1',
} as unknown as ModelInfo,
'',
'tracking-1',
);
expect(createInferenceServerMock).not.toHaveBeenCalled();
@@ -462,6 +438,7 @@ test('delete conversation should delete the conversation', async () => {
id: 'model-1',
name: 'Model 1',
} as unknown as ModelInfo,
'',
'tracking-1',
);

@@ -484,9 +461,9 @@ test('requestCreatePlayground should call createPlayground and createTask, then
});
const createPlaygroundSpy = vi.spyOn(manager, 'createPlayground').mockResolvedValue('playground-1');

const id = await manager.requestCreatePlayground('a name', { id: 'model-1' } as ModelInfo);
const id = await manager.requestCreatePlayground('a name', { id: 'model-1' } as ModelInfo, '');

expect(createPlaygroundSpy).toHaveBeenCalledWith('a name', { id: 'model-1' } as ModelInfo, expect.any(String));
expect(createPlaygroundSpy).toHaveBeenCalledWith('a name', { id: 'model-1' } as ModelInfo, '', expect.any(String));
expect(createTaskMock).toHaveBeenCalledWith('Creating Playground environment', 'loading', {
trackingId: id,
});
@@ -513,9 +490,9 @@ test('requestCreatePlayground should call createPlayground and createTask, then
});
const createPlaygroundSpy = vi.spyOn(manager, 'createPlayground').mockRejectedValue(new Error('an error'));

const id = await manager.requestCreatePlayground('a name', { id: 'model-1' } as ModelInfo);
const id = await manager.requestCreatePlayground('a name', { id: 'model-1' } as ModelInfo, '');

expect(createPlaygroundSpy).toHaveBeenCalledWith('a name', { id: 'model-1' } as ModelInfo, expect.any(String));
expect(createPlaygroundSpy).toHaveBeenCalledWith('a name', { id: 'model-1' } as ModelInfo, '', expect.any(String));
expect(createTaskMock).toHaveBeenCalledWith('Creating Playground environment', 'loading', {
trackingId: id,
});
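Not part of this PR — a hedged sketch of a companion test covering the opposite branch: an undefined system prompt should leave the new conversation empty. It reuses the mocks (`webviewMock`, `inferenceManagerMock`, `taskRegistryMock`) declared earlier in this spec file and mirrors the new test added above:

```ts
test('create playground without a system prompt should create an empty conversation.', async () => {
  vi.mocked(inferenceManagerMock.getServers).mockReturnValue([
    {
      status: 'running',
      health: {
        Status: 'healthy',
      },
      models: [
        {
          id: 'dummyModelId',
          file: {
            file: 'dummyModelFile',
          },
        },
      ],
    } as unknown as InferenceServer,
  ]);
  const manager = new PlaygroundV2Manager(webviewMock, inferenceManagerMock, taskRegistryMock);
  await manager.createPlayground('playground 1', { id: 'model-1' } as ModelInfo, undefined, 'tracking-1');

  const conversations = manager.getConversations();
  expect(conversations).toHaveLength(1);
  // No system message is submitted when the prompt is undefined (or empty).
  expect(conversations[0].messages).toHaveLength(0);
});
```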
22 changes: 18 additions & 4 deletions packages/backend/src/managers/playgroundV2Manager.ts
@@ -22,7 +22,7 @@ import type { ChatCompletionChunk, ChatCompletionMessageParam } from 'openai/src
import type { ModelOptions } from '@shared/src/models/IModelOptions';
import type { Stream } from 'openai/streaming';
import { ConversationRegistry } from '../registries/conversationRegistry';
import type { Conversation, PendingChat, UserChat } from '@shared/src/models/IPlaygroundMessage';
import type { Conversation, PendingChat, SystemPrompt, UserChat } from '@shared/src/models/IPlaygroundMessage';
import type { PlaygroundV2 } from '@shared/src/models/IPlaygroundV2';
import { Publisher } from '../utils/Publisher';
import { Messages } from '@shared/Messages';
@@ -54,13 +54,13 @@ export class PlaygroundV2Manager extends Publisher<PlaygroundV2[]> implements Di
this.notify();
}

async requestCreatePlayground(name: string, model: ModelInfo): Promise<string> {
async requestCreatePlayground(name: string, model: ModelInfo, systemPrompt?: string): Promise<string> {
const trackingId: string = getRandomString();
const task = this.taskRegistry.createTask('Creating Playground environment', 'loading', {
trackingId: trackingId,
});

this.createPlayground(name, model, trackingId)
this.createPlayground(name, model, systemPrompt, trackingId)
.then((playgroundId: string) => {
this.taskRegistry.updateTask({
...task,
@@ -94,7 +94,12 @@ export class PlaygroundV2Manager extends Publisher<PlaygroundV2[]> implements Di
return trackingId;
}

async createPlayground(name: string, model: ModelInfo, trackingId: string): Promise<string> {
async createPlayground(
name: string,
model: ModelInfo,
systemPrompt: string | undefined,
trackingId: string,
): Promise<string> {
const id = `${this.#playgroundCounter++}`;

if (!name) {
@@ -103,6 +108,15 @@ export class PlaygroundV2Manager extends Publisher<PlaygroundV2[]> implements Di

this.#conversationRegistry.createConversation(id);

if (systemPrompt) {
this.#conversationRegistry.submit(id, {
content: systemPrompt,
role: 'system',
id: this.getUniqueId(),
timestamp: Date.now(),
} as SystemPrompt);
}

// create/start inference server if necessary
const servers = this.inferenceManager.getServers();
const server = servers.find(s => s.models.map(mi => mi.id).includes(model.id));
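The manager now imports a `SystemPrompt` type, and Playground.svelte (below) imports an `isSystemPrompt` guard, both from `@shared/src/models/IPlaygroundMessage` — a file not included in this diff. A minimal sketch of the shapes these changes appear to assume; field names are inferred from how the message is built and rendered here, not taken from the actual shared file:

```ts
// Sketch of the shared message model assumed by this PR; the real definitions live in
// packages/shared/src/models/IPlaygroundMessage.ts, which is not shown in this diff.
export interface ChatMessage {
  id: string;
  role: 'system' | 'user' | 'assistant';
  content?: string;
  timestamp: number;
}

export interface SystemPrompt extends ChatMessage {
  role: 'system';
  content: string;
}

// Type guard used by Playground.svelte to pick the styling and paragraphs for system messages.
export function isSystemPrompt(msg: ChatMessage): msg is SystemPrompt {
  return msg.role === 'system';
}
```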
4 changes: 2 additions & 2 deletions packages/backend/src/studio-api-impl.ts
@@ -74,9 +74,9 @@ export class StudioApiImpl implements StudioAPI {
});
}

async requestCreatePlayground(name: string, model: ModelInfo): Promise<string> {
async requestCreatePlayground(name: string, model: ModelInfo, systemPrompt?: string): Promise<string> {
try {
return this.playgroundV2.requestCreatePlayground(name, model);
return this.playgroundV2.requestCreatePlayground(name, model, systemPrompt);
} catch (err: unknown) {
console.error('Something went wrong while trying to create playground environment', err);
throw err;
20 changes: 5 additions & 15 deletions packages/frontend/src/pages/Playground.svelte
@@ -8,6 +8,7 @@ import {
isPendingChat,
isUserChat,
type AssistantChat,
isSystemPrompt,
} from '@shared/src/models/IPlaygroundMessage';
import NavPage from '../lib/NavPage.svelte';
import { playgrounds } from '../stores/playgrounds-v2';
@@ -46,7 +47,7 @@ $: {
}

const roleNames = {
system: 'System',
system: 'System prompt',
user: 'User',
assistant: 'Assistant',
};
@@ -61,9 +62,8 @@ function getMessageParagraphs(message: ChatMessage): string[] {
.join('')
.split('\n');
}
} else if (isUserChat(message)) {
const msg = message as UserChat;
return msg.content?.split('\n') ?? [];
} else if (isUserChat(message) || isSystemPrompt(message)) {
return message.content?.split('\n') ?? [];
}
return [];
}
@@ -130,6 +130,7 @@ function elapsedTime(msg: AssistantChat): string {
<div
class="p-4 rounded-md"
class:bg-charcoal-400="{isUserChat(message)}"
class:bg-charcoal-800="{isSystemPrompt(message)}"
class:bg-charcoal-900="{isAssistantChat(message)}"
class:ml-8="{isAssistantChat(message)}"
class:mr-8="{isUserChat(message)}">
@@ -152,17 +153,6 @@ function elapsedTime(msg: AssistantChat): string {
</svelte:fragment>
<svelte:fragment slot="details">
<div class="text-gray-800 text-xs">Next prompt will use these settings</div>
<div class="bg-charcoal-600 w-full rounded-md text-xs p-4">
<div class="mb-4">System Prompt</div>
<div class="w-full">
<textarea
bind:value="{systemPrompt}"
class="p-2 w-full outline-none bg-charcoal-500 rounded-sm text-gray-700 placeholder-gray-700"
rows="4"
placeholder="Provide system prompt to define general context, instructions or guidelines to be used with each query"
></textarea>
</div>
</div>
<div class="bg-charcoal-600 w-full rounded-md text-xs p-4">
<div class="mb-4 flex flex-col">Model Parameters</div>
<div class="flex flex-col space-y-4">
13 changes: 12 additions & 1 deletion packages/frontend/src/pages/PlaygroundCreate.svelte
@@ -18,6 +18,7 @@ let localModels: ModelInfo[];
$: localModels = $modelsInfo.filter(model => model.file);
$: availModels = $modelsInfo.filter(model => !model.file);
let modelId: string | undefined = undefined;
let systemPrompt: string | undefined = undefined;
let submitted: boolean = false;
let playgroundName: string;

@@ -55,7 +56,8 @@ async function submit() {
// disable submit button
submitted = true;
try {
trackingId = await studioClient.requestCreatePlayground(playgroundName, model);
// Using || and not && as we want to have the empty string systemPrompt passed as undefined
trackingId = await studioClient.requestCreatePlayground(playgroundName, model, systemPrompt || undefined);
} catch (err: unknown) {
trackingId = undefined;
console.error('Something wrong while trying to create the playground.', err);
@@ -161,6 +163,15 @@ onDestroy(() => {
</div>
</div>
{/if}

<label for="model" class="pt-4 block mb-2 text-sm font-bold text-gray-400">System prompt</label>
<textarea
aria-label="system-prompt-textarea"
bind:value="{systemPrompt}"
class="w-full p-2 outline-none text-sm bg-charcoal-600 rounded-sm text-gray-700 placeholder-gray-700"
rows="4"
placeholder="Optionally provide system prompt to define general context, instructions or guidelines to be used with each query"
></textarea>
</div>
<footer>
<div class="w-full flex flex-col">
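The `systemPrompt || undefined` coercion in `submit()` is what keeps an empty textarea from creating a system message: `''` is falsy, so it collapses to `undefined` before reaching the manager. A tiny illustrative sketch of that behaviour, with hypothetical values not taken from the PR:

```ts
// '' || undefined          -> undefined      : an untouched or cleared textarea sends no system prompt
// 'Be concise.' || undefined -> 'Be concise.' : a filled textarea is passed through as-is
function normalizeSystemPrompt(value: string | undefined): string | undefined {
  return value || undefined;
}
```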
2 changes: 1 addition & 1 deletion packages/shared/src/StudioAPI.ts
@@ -143,7 +143,7 @@ export abstract class StudioAPI {
*/
abstract createSnippet(options: RequestOptions, language: string, variant: string): Promise<string>;

abstract requestCreatePlayground(name: string, model: ModelInfo): Promise<string>;
abstract requestCreatePlayground(name: string, model: ModelInfo, systemPrompt?: string): Promise<string>;

abstract getPlaygroundsV2(): Promise<PlaygroundV2[]>;

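With the widened `requestCreatePlayground` signature, the creation form can pass the optional prompt straight through the API. A hedged usage sketch from the webview side — `studioClient` is assumed to be the RPC proxy implementing `StudioAPI` used elsewhere in packages/frontend, and the argument values are illustrative only:

```ts
// Hypothetical call site; mirrors what PlaygroundCreate.svelte does on submit.
const trackingId = await studioClient.requestCreatePlayground(
  'my playground',                 // display name
  selectedModel,                   // a ModelInfo chosen from the models store
  'Answer only with valid JSON.',  // optional system prompt; pass undefined to skip it
);
```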