2 changes: 2 additions & 0 deletions packages/inference/README.md
@@ -61,6 +61,7 @@ Currently, we support the following providers:
- [Sambanova](https://sambanova.ai)
- [Scaleway](https://www.scaleway.com/en/generative-apis/)
- [Together](https://together.xyz)
- [Baseten](https://baseten.co)
- [Blackforestlabs](https://blackforestlabs.ai)
- [Cohere](https://cohere.com)
- [Cerebras](https://cerebras.ai/)
@@ -97,6 +98,7 @@ Only a subset of models are supported when requesting third-party providers. You
- [Sambanova supported models](https://huggingface.co/api/partners/sambanova/models)
- [Scaleway supported models](https://huggingface.co/api/partners/scaleway/models)
- [Together supported models](https://huggingface.co/api/partners/together/models)
- [Baseten supported models](https://huggingface.co/api/partners/baseten/models)
- [Cohere supported models](https://huggingface.co/api/partners/cohere/models)
- [Cerebras supported models](https://huggingface.co/api/partners/cerebras/models)
- [Groq supported models](https://console.groq.com/docs/models)
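The endpoints above return the live HF model ID => provider model ID mappings. A minimal sketch of checking the new Baseten list (assumes a runtime with a global fetch, e.g. Node 18+; the exact response shape is whatever the partners API returns, printed as-is):

// Fetch the publicly available HF model ID => Baseten model ID mapping.
const res = await fetch("https://huggingface.co/api/partners/baseten/models");
if (!res.ok) {
	throw new Error(`Baseten mapping request failed: ${res.status}`);
}
console.log(await res.json());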
4 changes: 4 additions & 0 deletions packages/inference/src/lib/getProviderHelper.ts
@@ -1,3 +1,4 @@
import * as Baseten from "../providers/baseten.js";
import * as BlackForestLabs from "../providers/black-forest-labs.js";
import * as Cerebras from "../providers/cerebras.js";
import * as Cohere from "../providers/cohere.js";
@@ -55,6 +56,9 @@ import type { InferenceProvider, InferenceProviderOrPolicy, InferenceTask } from
import { InferenceClientInputError } from "../errors.js";

export const PROVIDERS: Record<InferenceProvider, Partial<Record<InferenceTask, TaskProviderHelper>>> = {
	baseten: {
		conversational: new Baseten.BasetenConversationalTask(),
	},
	"black-forest-labs": {
		"text-to-image": new BlackForestLabs.BlackForestLabsTextToImageTask(),
	},
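PROVIDERS is a two-level lookup table: provider name first, then task, resolving to a TaskProviderHelper instance. A hedged sketch of the lookup this change enables (illustration only; the import path is hypothetical and the real resolution logic in getProviderHelper may differ):

import { PROVIDERS } from "./lib/getProviderHelper.js"; // path for illustration

const helper = PROVIDERS["baseten"]?.["conversational"];
if (!helper) {
	throw new Error("Task 'conversational' is not supported by provider 'baseten'");
}
// helper is the BasetenConversationalTask instance registered above.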
25 changes: 25 additions & 0 deletions packages/inference/src/providers/baseten.ts
@@ -0,0 +1,25 @@
/**
 * See the registered mapping of HF model ID => Baseten model ID here:
 *
 * https://huggingface.co/api/partners/baseten/models
 *
 * This is a publicly available mapping.
 *
 * If you want to try to run inference for a new model locally before it's registered on huggingface.co,
 * you can add it to the dictionary "HARDCODED_MODEL_INFERENCE_MAPPING" in consts.ts, for dev purposes.
 *
 * - If you work at Baseten and want to update this mapping, please use the model mapping API we provide on huggingface.co
 * - If you're a community member and want to add a new supported HF model to Baseten, please open an issue on this repo
 *   and we will tag Baseten team members.
 *
 * Thanks!
 */
import { BaseConversationalTask } from "./providerHelper.js";

const BASETEN_API_BASE_URL = "https://inference.baseten.co";

export class BasetenConversationalTask extends BaseConversationalTask {
	constructor() {
		super("baseten", BASETEN_API_BASE_URL);
	}
}
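With the helper registered, the provider becomes usable from the public client. A minimal usage sketch mirroring the test added at the end of this diff (the token source is a placeholder; routing and billing specifics depend on your HF account setup):

import { InferenceClient } from "@huggingface/inference";

const client = new InferenceClient(process.env.HF_TOKEN); // placeholder token source

const out = await client.chatCompletion({
	model: "Qwen/Qwen3-235B-A22B-Instruct-2507",
	provider: "baseten",
	messages: [{ role: "user", content: "What is 5 + 3?" }],
	max_tokens: 20,
});

console.log(out.choices[0]?.message?.content);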
1 change: 1 addition & 0 deletions packages/inference/src/providers/consts.ts
@@ -18,6 +18,7 @@ export const HARDCODED_MODEL_INFERENCE_MAPPING: Record<
	 * Example:
	 * "Qwen/Qwen2.5-Coder-32B-Instruct": "Qwen2.5-Coder-32B-Instruct",
	 */
	baseten: {},
	"black-forest-labs": {},
	cerebras: {},
	cohere: {},
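For local development against a model that is not yet registered on huggingface.co, the empty baseten: {} bucket is where a hardcoded entry goes. A sketch of such an entry, using the same shape the test later in this diff installs (the model ID is only an example):

HARDCODED_MODEL_INFERENCE_MAPPING["baseten"] = {
	"Qwen/Qwen3-235B-A22B-Instruct-2507": {
		provider: "baseten",
		hfModelId: "Qwen/Qwen3-235B-A22B-Instruct-2507",
		providerId: "Qwen/Qwen3-235B-A22B-Instruct-2507",
		status: "live",
		task: "conversational",
	},
};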
1 change: 1 addition & 0 deletions packages/inference/src/types.ts
@@ -45,6 +45,7 @@ export interface Options {
export type InferenceTask = Exclude<PipelineType, "other"> | "conversational";

export const INFERENCE_PROVIDERS = [
"baseten",
"black-forest-labs",
"cerebras",
"cohere",
58 changes: 58 additions & 0 deletions packages/inference/test/InferenceClient.spec.ts
@@ -2343,4 +2343,62 @@ describe.skip("InferenceClient", () => {
		},
		TIMEOUT
	);

	describe.concurrent(
		"Baseten",
		() => {
			const client = new InferenceClient(env.HF_BASETEN_KEY ?? "dummy");

			HARDCODED_MODEL_INFERENCE_MAPPING["baseten"] = {
				"Qwen/Qwen3-235B-A22B-Instruct-2507": {
					provider: "baseten",
					hfModelId: "Qwen/Qwen3-235B-A22B-Instruct-2507",
					providerId: "Qwen/Qwen3-235B-A22B-Instruct-2507",
					status: "live",
					task: "conversational",
				},
			};

			it("chatCompletion - Qwen3 235B Instruct", async () => {
				const res = await client.chatCompletion({
					model: "Qwen/Qwen3-235B-A22B-Instruct-2507",
					provider: "baseten",
					messages: [{ role: "user", content: "What is 5 + 3?" }],
					max_tokens: 20,
				});
				if (res.choices && res.choices.length > 0) {
					const completion = res.choices[0].message?.content;
					expect(completion).toBeDefined();
					expect(typeof completion).toBe("string");
					expect(completion).toMatch(/(eight|8)/i);
				}
			});

			it("chatCompletion stream - Qwen3 235B", async () => {
				const stream = client.chatCompletionStream({
					model: "Qwen/Qwen3-235B-A22B-Instruct-2507",
					provider: "baseten",
					messages: [{ role: "user", content: "Count from 1 to 3" }],
					stream: true,
					max_tokens: 20,
				}) as AsyncGenerator<ChatCompletionStreamOutput>;

				let fullResponse = "";
				for await (const chunk of stream) {
					if (chunk.choices && chunk.choices.length > 0) {
						const content = chunk.choices[0].delta?.content;
						if (content) {
							fullResponse += content;
						}
					}
				}

				// Verify we got a meaningful response
				expect(fullResponse).toBeTruthy();
				expect(fullResponse.length).toBeGreaterThan(0);
				expect(fullResponse).toMatch(/1.*2.*3/);
			});
		},
		TIMEOUT
	);
});