Skip to content

Commit

Permalink
Support custom system prompts from the user (#399)
Browse files Browse the repository at this point in the history
* Support custom system prompts from the user

* linter

* types & lint
  • Loading branch information
nsarrazin committed Aug 21, 2023
1 parent 447c0ca commit cd6894d
Show file tree
Hide file tree
Showing 10 changed files with 98 additions and 11 deletions.
5 changes: 3 additions & 2 deletions src/lib/buildPrompt.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@ import { ObjectId } from "mongodb";
export async function buildPrompt(
messages: Pick<Message, "from" | "content">[],
model: BackendModel,
webSearchId?: string
webSearchId?: string,
preprompt?: string
): Promise<string> {
if (webSearchId) {
const webSearch = await collections.webSearches.findOne({
Expand All @@ -33,7 +34,7 @@ export async function buildPrompt(

return (
model
.chatPromptRender({ messages })
.chatPromptRender({ messages, preprompt })
// Not super precise, but it's truncated in the model's backend anyway
.split(" ")
.slice(-(model.parameters?.truncate ?? 0))
Expand Down
75 changes: 72 additions & 3 deletions src/lib/components/ModelsModal.svelte
Original file line number Diff line number Diff line change
Expand Up @@ -10,26 +10,56 @@
import { enhance } from "$app/forms";
import { base } from "$app/paths";
import CarbonEdit from "~icons/carbon/edit";
import CarbonSave from "~icons/carbon/save";
import CarbonRestart from "~icons/carbon/restart";
export let settings: LayoutData["settings"];
export let models: Array<Model>;
let selectedModelId = settings.activeModel;
const dispatch = createEventDispatcher<{ close: void }>();
let expanded = false;
function onToggle() {
if (expanded) {
settings.customPrompts[selectedModelId] = value;
}
expanded = !expanded;
}
let value = "";
function onModelChange() {
value =
settings.customPrompts[selectedModelId] ??
models.filter((el) => el.id === selectedModelId)[0].preprompt ??
"";
}
$: selectedModelId, onModelChange();
</script>

<Modal width="max-w-lg" on:close>
<form
action="{base}/settings"
method="post"
on:submit={() => {
if (expanded) {
onToggle();
}
}}
use:enhance={() => {
dispatch("close");
}}
class="flex w-full flex-col gap-5 p-6"
>
{#each Object.entries(settings).filter(([k]) => k !== "activeModel") as [key, val]}
{#each Object.entries(settings).filter(([k]) => !(k == "activeModel" || k === "customPrompts")) as [key, val]}
<input type="hidden" name={key} value={val} />
{/each}
<input type="hidden" name="customPrompts" value={JSON.stringify(settings.customPrompts)} />
<div class="flex items-start justify-between text-xl font-semibold text-gray-800">
<h2>Models</h2>
<button type="button" class="group" on:click={() => dispatch("close")}>
Expand All @@ -39,8 +69,9 @@

<div class="space-y-4">
{#each models as model}
{@const active = model.id === selectedModelId}
<div
class="rounded-xl border border-gray-100 {model.id === selectedModelId
class="rounded-xl border border-gray-100 {active
? 'bg-gradient-to-r from-primary-200/40 via-primary-500/10'
: ''}"
>
Expand All @@ -61,11 +92,49 @@
{/if}
</span>
<CarbonCheckmark
class="-mr-1 -mt-1 ml-auto shrink-0 text-xl {model.id === selectedModelId
class="-mr-1 -mt-1 ml-auto shrink-0 text-xl {active
? 'text-primary-400'
: 'text-transparent group-hover:text-gray-200'}"
/>
</label>
{#if active}
<div class=" overflow-hidden rounded-xl px-3 pb-2">
<div class="flex flex-row flex-nowrap gap-2 pb-1">
<div class="text-xs font-semibold text-gray-500">System Prompt</div>
{#if expanded}
<button
class="text-gray-500 hover:text-gray-900"
on:click|preventDefault={onToggle}
>
<CarbonSave class="text-sm " />
</button>
<button
class="text-gray-500 hover:text-gray-900"
on:click|preventDefault={() => {
value = model.preprompt ?? "";
}}
>
<CarbonRestart class="text-sm " />
</button>
{:else}
<button
class=" text-gray-500 hover:text-gray-900"
on:click|preventDefault={onToggle}
>
<CarbonEdit class="text-sm " />
</button>
{/if}
</div>
<textarea
enterkeyhint="send"
tabindex="0"
rows="1"
class="h-20 w-full resize-none scroll-p-3 overflow-x-hidden overflow-y-scroll rounded-md border border-gray-300 bg-transparent p-1 text-xs outline-none focus:ring-0 focus-visible:ring-0"
bind:value
hidden={!expanded}
/>
</div>
{/if}
<ModelCardMetadata {model} />
</div>
{/each}
Expand Down
2 changes: 1 addition & 1 deletion src/lib/components/chat/ChatIntroduction.svelte
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@
</div>
</div>
{#if currentModelMetadata.promptExamples}
<div class="lg:col-span-3 lg:mt-12">
<div class="lg:col-span-3 lg:mt-6">
<p class="mb-3 text-gray-600 dark:text-gray-300">Examples</p>
<div class="grid gap-3 lg:grid-cols-3 lg:gap-5">
{#each currentModelMetadata.promptExamples as example}
Expand Down
6 changes: 4 additions & 2 deletions src/lib/server/models.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ import type {
import { compileTemplate } from "$lib/utils/template";
import { z } from "zod";

type Optional<T, K extends keyof T> = Pick<Partial<T>, K> & Omit<T, K>;

const sagemakerEndpoint = z.object({
host: z.literal("sagemaker"),
url: z.string().url(),
Expand Down Expand Up @@ -57,7 +59,7 @@ const modelsRaw = z
assistantMessageToken: z.string().default(""),
assistantMessageEndToken: z.string().default(""),
messageEndToken: z.string().default(""),
preprompt: z.string().default(""),
preprompt: z.string().min(1).optional(),
prepromptUrl: z.string().url().optional(),
chatPromptTemplate: z
.string()
Expand Down Expand Up @@ -148,7 +150,7 @@ export const oldModels = OLD_MODELS
.map((m) => ({ ...m, id: m.id || m.name, displayName: m.displayName || m.name }))
: [];

export type BackendModel = (typeof models)[0];
export type BackendModel = Optional<(typeof models)[0], "preprompt">;
export type Endpoint = z.infer<typeof endpoint>;

export const defaultModel = models[0];
Expand Down
1 change: 1 addition & 0 deletions src/lib/types/Model.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,5 @@ export type Model = Pick<
| "description"
| "modelUrl"
| "datasetUrl"
| "preprompt"
>;
3 changes: 3 additions & 0 deletions src/lib/types/Settings.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@ export interface Settings extends Timestamps {
shareConversationsWithModelAuthors: boolean;
ethicsModalAcceptedAt: Date | null;
activeModel: string;

// model name and system prompts
customPrompts?: Record<string, string>;
}

// TODO: move this to a constant file along with other constants
Expand Down
3 changes: 2 additions & 1 deletion src/lib/types/Template.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import type { Message } from "./Message";

export type LegacyParamatersTemplateInput = {
preprompt: string;
preprompt?: string;
userMessageToken: string;
userMessageEndToken: string;
assistantMessageToken: string;
Expand All @@ -10,6 +10,7 @@ export type LegacyParamatersTemplateInput = {

export type ChatTemplateInput = {
messages: Pick<Message, "from" | "content">[];
preprompt?: string;
};

export type WebSearchSummaryTemplateInput = {
Expand Down
2 changes: 2 additions & 0 deletions src/routes/+layout.server.ts
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ export const load: LayoutServerLoad = async ({ locals, depends, url }) => {
ethicsModalAcceptedAt: settings?.ethicsModalAcceptedAt ?? null,
activeModel: settings?.activeModel ?? DEFAULT_SETTINGS.activeModel,
searchEnabled: !!(SERPAPI_KEY || SERPER_API_KEY),
customPrompts: settings?.customPrompts ?? {},
},
models: models.map((model) => ({
id: model.id,
Expand All @@ -74,6 +75,7 @@ export const load: LayoutServerLoad = async ({ locals, depends, url }) => {
description: model.description,
promptExamples: model.promptExamples,
parameters: model.parameters,
preprompt: model.preprompt,
})),
oldModels,
user: locals.user && {
Expand Down
9 changes: 8 additions & 1 deletion src/routes/conversation/[id]/+server.ts
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ export async function POST({ request, fetch, locals, params }) {
}

const model = models.find((m) => m.id === conv.model);
const settings = await collections.settings.findOne(authCondition(locals));

if (!model) {
throw error(410, "Model not available anymore");
Expand Down Expand Up @@ -97,7 +98,13 @@ export async function POST({ request, fetch, locals, params }) {
];
})() satisfies Message[];

const prompt = await buildPrompt(messages, model, web_search_id);
const prompt = await buildPrompt(
messages,
model,
web_search_id,
settings?.customPrompts?.[model.id]
);

const randomEndpoint = modelEndpoint(model);

const abortController = new AbortController();
Expand Down
3 changes: 2 additions & 1 deletion src/routes/settings/+page.server.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,13 @@ export const actions = {
.default(DEFAULT_SETTINGS.shareConversationsWithModelAuthors),
ethicsModalAccepted: z.boolean({ coerce: true }).optional(),
activeModel: validateModel(models),
customPrompts: z.record(z.string()).default({}),
})
.parse({
shareConversationsWithModelAuthors: formData.get("shareConversationsWithModelAuthors"),
ethicsModalAccepted: formData.get("ethicsModalAccepted"),
activeModel: formData.get("activeModel") ?? DEFAULT_SETTINGS.activeModel,
customPrompts: JSON.parse(formData.get("customPrompts")?.toString() ?? "{}"),
});

await collections.settings.updateOne(
Expand All @@ -40,7 +42,6 @@ export const actions = {
upsert: true,
}
);

throw redirect(303, request.headers.get("referer") || `${base}/`);
},
};

0 comments on commit cd6894d

Please sign in to comment.