Skip to content

Commit

Permalink
Add langserve endpoint (#1009)
Browse files Browse the repository at this point in the history
* Add support for langserve endpoints

* Add support for langserve endpoints

* Fix linting

* Fix linting issues

* Fix issue import

---------

Co-authored-by: antoniora <antonio.ramos@adyen.com>
  • Loading branch information
antonioramos1 and antoniora committed Apr 16, 2024
1 parent 4538f1d commit f12455d
Show file tree
Hide file tree
Showing 4 changed files with 153 additions and 0 deletions.
18 changes: 18 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -618,6 +618,24 @@ MODELS=`[
```

##### LangServe

LangChain applications deployed with LangServe can be called with the following config, where `url` is the base route of the deployed chain (requests are sent to its `/stream` route):

```
MODELS=`[
//...
{
"name": "summarization-chain", //model-name
"endpoints" : [{
"type": "langserve",
"url" : "http://127.0.0.1:8100",
}]
},
]`
```

### Custom endpoint authorization

#### Basic and Bearer
Expand Down
5 changes: 5 additions & 0 deletions src/lib/server/endpoints/endpoints.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@ import endpointCloudflare, {
endpointCloudflareParametersSchema,
} from "./cloudflare/endpointCloudflare";
import { endpointCohere, endpointCohereParametersSchema } from "./cohere/endpointCohere";
import endpointLangserve, {
endpointLangserveParametersSchema,
} from "./langserve/endpointLangserve";

// parameters passed when generating text
export interface EndpointParameters {
Expand Down Expand Up @@ -48,6 +51,7 @@ export const endpoints = {
vertex: endpointVertex,
cloudflare: endpointCloudflare,
cohere: endpointCohere,
langserve: endpointLangserve,
};

export const endpointSchema = z.discriminatedUnion("type", [
Expand All @@ -60,5 +64,6 @@ export const endpointSchema = z.discriminatedUnion("type", [
endpointVertexParametersSchema,
endpointCloudflareParametersSchema,
endpointCohereParametersSchema,
endpointLangserveParametersSchema,
]);
export default endpoints;
128 changes: 128 additions & 0 deletions src/lib/server/endpoints/langserve/endpointLangserve.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
import { buildPrompt } from "$lib/buildPrompt";
import { z } from "zod";
import type { Endpoint } from "../endpoints";
import type { TextGenerationStreamOutput } from "@huggingface/inference";

/**
 * Zod schema for the configuration of a `langserve` endpoint entry.
 * `endpointLangserve` parses its input with this schema before use.
 */
export const endpointLangserveParametersSchema = z.object({
// Positive integer, defaults to 1 when omitted.
// NOTE(review): presumably the relative weight used when choosing among a model's endpoints — confirm against the caller.
weight: z.number().int().positive().default(1),
// Accepts any value; `endpointLangserve` forwards it unchanged to `buildPrompt`.
model: z.any(),
// Literal tag identifying this endpoint type.
type: z.literal("langserve"),
// Must be a valid URL; `endpointLangserve` POSTs to `${url}/stream`.
url: z.string().url(),
});

/**
 * Builds an endpoint that generates text by calling a LangServe deployment.
 *
 * The returned endpoint renders the conversation into a single prompt with
 * `buildPrompt`, POSTs `{ input: { text: prompt } }` as JSON to `${url}/stream`
 * and converts the server-sent-event response into a stream of
 * `TextGenerationStreamOutput` items:
 * - each `data` event whose JSON payload is a non-empty string yields one
 *   regular token carrying that string;
 * - the `end` event yields a single special token whose `generated_text` is the
 *   concatenation of all emitted strings, then finishes the stream;
 * - every other event type is ignored.
 *
 * @param input - Endpoint configuration, validated against
 *   `endpointLangserveParametersSchema`.
 * @returns The endpoint function producing the async token stream.
 * @throws Error when the HTTP response status is not OK (thrown by the returned
 *   endpoint; the message includes the response body). Schema validation
 *   errors are thrown immediately by this function.
 */
export function endpointLangserve(
	input: z.input<typeof endpointLangserveParametersSchema>
): Endpoint {
	const { url, model } = endpointLangserveParametersSchema.parse(input);

	return async ({ messages, preprompt, continueMessage }) => {
		const prompt = await buildPrompt({
			messages,
			continueMessage,
			preprompt,
			model,
		});

		const r = await fetch(`${url}/stream`, {
			method: "POST",
			headers: {
				"Content-Type": "application/json",
			},
			body: JSON.stringify({
				input: { text: prompt },
			}),
		});

		if (!r.ok) {
			throw new Error(`Failed to generate text: ${await r.text()}`);
		}

		// Decode the byte stream to text before splitting it into lines.
		const reader = r.body?.pipeThrough(new TextDecoderStream()).getReader();

		return (async function* () {
			// No response body: nothing to stream, so the generator ends without tokens.
			if (!reader) {
				return;
			}

			let generatedText = "";
			let tokenId = 0;
			// Text received so far that has not yet been terminated by a newline.
			let buffer = "";
			// Fields of the event currently being assembled. They are tracked per
			// event (not per network chunk) because a chunk may hold several events
			// or only part of one.
			let eventName = "";
			let dataLines: string[] = [];
			let warnedNonString = false;

			try {
				let chunk = await reader.read();
				while (!chunk.done) {
					buffer += chunk.value;

					let newlineIndex = buffer.indexOf("\n");
					while (newlineIndex !== -1) {
						// trim() also drops the "\r" of "\r\n" line endings.
						const line = buffer.slice(0, newlineIndex).trim();
						buffer = buffer.slice(newlineIndex + 1);
						newlineIndex = buffer.indexOf("\n");

						if (line !== "") {
							const colon = line.indexOf(":");
							if (colon <= 0) {
								// Comment line (leading ":") or a field without a value.
								continue;
							}
							const field = line.slice(0, colon);
							const value = line.slice(colon + 1).trim();

							if (field === "event") {
								eventName = value;
								if (eventName === "end") {
									// Emit the closing token exactly once, as soon as the end
									// event is announced, and stop reading.
									yield {
										token: {
											id: tokenId++,
											text: "",
											logprob: 0,
											special: true,
										},
										generated_text: generatedText,
										details: null,
									} satisfies TextGenerationStreamOutput;
									return;
								}
							} else if (field === "data") {
								dataLines.push(value);
							}
							continue;
						}

						// A blank line terminates the current event: dispatch it and reset.
						const event = eventName;
						const payload = dataLines.join("\n");
						eventName = "";
						dataLines = [];

						if (event !== "data" || payload === "") {
							continue;
						}

						let data: unknown;
						try {
							data = JSON.parse(payload);
						} catch (e) {
							console.error("Failed to parse JSON", e);
							console.error("Problematic JSON string:", payload);
							continue; // Skip this event and keep processing the stream
						}

						if (data === null || data === undefined || data === "") {
							continue;
						}

						if (typeof data !== "string") {
							// Only plain-string chunks are supported; coercing anything
							// else would emit text such as "[object Object]". Warn once
							// to avoid flooding the log.
							if (!warnedNonString) {
								console.error("Ignoring non-string LangServe chunk:", payload);
								warnedNonString = true;
							}
							continue;
						}

						generatedText += data;
						const output: TextGenerationStreamOutput = {
							token: {
								id: tokenId++,
								text: data,
								logprob: 0,
								special: false,
							},
							generated_text: null,
							details: null,
						};
						yield output;
					}

					chunk = await reader.read();
				}
			} finally {
				// Runs on normal completion, on errors and when the consumer stops
				// iterating early, so the underlying connection is always released.
				await reader.cancel().catch(() => undefined);
			}
		})();
	};
}

export default endpointLangserve;
2 changes: 2 additions & 0 deletions src/lib/server/models.ts
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,8 @@ const addEndpoint = (m: Awaited<ReturnType<typeof processModel>>) => ({
return await endpoints.cloudflare(args);
case "cohere":
return await endpoints.cohere(args);
case "langserve":
return await endpoints.langserve(args);
default:
// for legacy reason
return endpoints.tgi(args);
Expand Down

0 comments on commit f12455d

Please sign in to comment.