Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Azure] Support Batch API #833

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
11 changes: 5 additions & 6 deletions src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -346,6 +346,7 @@ export interface AzureClientOptions extends ClientOptions {
/** API Client for interfacing with the Azure OpenAI API. */
export class AzureOpenAI extends OpenAI {
private _azureADTokenProvider: (() => Promise<string>) | undefined;
private _deployment: string | undefined;
apiVersion: string = '';
/**
* API Client for interfacing with the Azure OpenAI API.
Expand Down Expand Up @@ -412,11 +413,7 @@ export class AzureOpenAI extends OpenAI {
);
}

if (deployment) {
baseURL = `${endpoint}/openai/deployments/${deployment}`;
} else {
baseURL = `${endpoint}/openai`;
}
baseURL = `${endpoint}/openai`;
} else {
if (endpoint) {
throw new Errors.OpenAIError('baseURL and endpoint are mutually exclusive');
Expand All @@ -432,6 +429,7 @@ export class AzureOpenAI extends OpenAI {

this._azureADTokenProvider = azureADTokenProvider;
this.apiVersion = apiVersion;
this._deployment = deployment;
}

override buildRequest(options: Core.FinalRequestOptions<unknown>): {
Expand All @@ -443,7 +441,7 @@ export class AzureOpenAI extends OpenAI {
if (!Core.isObj(options.body)) {
throw new Error('Expected request body to be an object');
}
const model = options.body['model'];
const model = this._deployment || options.body['model'];
delete options.body['model'];
if (model !== undefined && !this.baseURL.includes('/deployments')) {
options.path = `/deployments/${model}${options.path}`;
Expand Down Expand Up @@ -494,6 +492,7 @@ const _deployments_endpoints = new Set([
'/audio/translations',
'/audio/speech',
'/images/generations',
'/batches',
]);

const API_KEY_SENTINEL = '<Missing Key>';
Expand Down
274 changes: 274 additions & 0 deletions tests/lib/azure.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ import { Headers } from 'openai/core';
import defaultFetch, { Response, type RequestInit, type RequestInfo } from 'node-fetch';

const apiVersion = '2024-02-15-preview';
const deployment = 'deployment';
const model = 'unused model';

describe('instantiate azure client', () => {
const env = process.env;
Expand Down Expand Up @@ -275,6 +277,278 @@ describe('instantiate azure client', () => {
describe('azure request building', () => {
const client = new AzureOpenAI({ baseURL: 'https://example.com', apiKey: 'My API Key', apiVersion });

describe('model to deployment mapping', function () {
const testFetch = async (url: RequestInfo): Promise<Response> => {
return new Response(JSON.stringify({ url }), { headers: { 'content-type': 'application/json' } });
};
describe('with client-level deployment', function () {
const client = new AzureOpenAI({
endpoint: 'https://example.com',
apiKey: 'My API Key',
apiVersion,
deployment,
fetch: testFetch,
});

test('handles Batch', async () => {
expect(
await client.batches.create({
completion_window: '24h',
endpoint: '/v1/chat/completions',
input_file_id: 'file-id',
}),
).toStrictEqual({
url: `https://example.com/openai/deployments/${deployment}/batches?api-version=${apiVersion}`,
});
});

test('handles completions', async () => {
expect(
await client.completions.create({
model,
prompt: 'prompt',
}),
).toStrictEqual({
url: `https://example.com/openai/deployments/${deployment}/completions?api-version=${apiVersion}`,
});
});

test('handles chat completions', async () => {
expect(
await client.chat.completions.create({
model,
messages: [{ role: 'system', content: 'Hello' }],
}),
).toStrictEqual({
url: `https://example.com/openai/deployments/${deployment}/chat/completions?api-version=${apiVersion}`,
});
});

test('handles embeddings', async () => {
expect(
await client.embeddings.create({
model,
input: 'input',
}),
).toStrictEqual({
url: `https://example.com/openai/deployments/${deployment}/embeddings?api-version=${apiVersion}`,
});
});

test('handles audio translations', async () => {
expect(
await client.audio.translations.create({
model,
file: { url: 'https://example.com', blob: () => 0 as any },
}),
).toStrictEqual({
url: `https://example.com/openai/deployments/${deployment}/audio/translations?api-version=${apiVersion}`,
});
});

test('handles audio transcriptions', async () => {
expect(
await client.audio.transcriptions.create({
model,
file: { url: 'https://example.com', blob: () => 0 as any },
}),
).toStrictEqual({
url: `https://example.com/openai/deployments/${deployment}/audio/transcriptions?api-version=${apiVersion}`,
});
});

test('handles text to speech', async () => {
expect(
await (
await client.audio.speech.create({
model,
input: '',
voice: 'alloy',
})
).json(),
).toStrictEqual({
url: `https://example.com/openai/deployments/${deployment}/audio/speech?api-version=${apiVersion}`,
});
});

test('handles image generation', async () => {
expect(
await client.images.generate({
model,
prompt: 'prompt',
}),
).toStrictEqual({
url: `https://example.com/openai/deployments/${deployment}/images/generations?api-version=${apiVersion}`,
});
});

test('handles assistants', async () => {
expect(
await client.beta.assistants.create({
model,
}),
).toStrictEqual({
url: `https://example.com/openai/assistants?api-version=${apiVersion}`,
});
});

test('handles files', async () => {
expect(
await client.files.create({
file: { url: 'https://example.com', blob: () => 0 as any },
purpose: 'assistants',
}),
).toStrictEqual({
url: `https://example.com/openai/files?api-version=${apiVersion}`,
});
});

test('handles fine tuning', async () => {
expect(
await client.fineTuning.jobs.create({
model,
training_file: '',
}),
).toStrictEqual({
url: `https://example.com/openai/fine_tuning/jobs?api-version=${apiVersion}`,
});
});
});

describe('with no client-level deployment', function () {
const client = new AzureOpenAI({
endpoint: 'https://example.com',
apiKey: 'My API Key',
apiVersion,
fetch: testFetch,
});

test('Batch is not handled', async () => {
expect(
await client.batches.create({
completion_window: '24h',
endpoint: '/v1/chat/completions',
input_file_id: 'file-id',
}),
).toStrictEqual({
url: `https://example.com/openai/batches?api-version=${apiVersion}`,
});
});

test('handles completions', async () => {
expect(
await client.completions.create({
model: deployment,
prompt: 'prompt',
}),
).toStrictEqual({
url: `https://example.com/openai/deployments/${deployment}/completions?api-version=${apiVersion}`,
});
});

test('handles chat completions', async () => {
expect(
await client.chat.completions.create({
model: deployment,
messages: [{ role: 'system', content: 'Hello' }],
}),
).toStrictEqual({
url: `https://example.com/openai/deployments/${deployment}/chat/completions?api-version=${apiVersion}`,
});
});

test('handles embeddings', async () => {
expect(
await client.embeddings.create({
model: deployment,
input: 'input',
}),
).toStrictEqual({
url: `https://example.com/openai/deployments/${deployment}/embeddings?api-version=${apiVersion}`,
});
});

test('Audio translations is not handled', async () => {
expect(
await client.audio.translations.create({
model: deployment,
file: { url: 'https://example.com', blob: () => 0 as any },
}),
).toStrictEqual({
url: `https://example.com/openai/audio/translations?api-version=${apiVersion}`,
});
});

test('Audio transcriptions is not handled', async () => {
expect(
await client.audio.transcriptions.create({
model: deployment,
file: { url: 'https://example.com', blob: () => 0 as any },
}),
).toStrictEqual({
url: `https://example.com/openai/audio/transcriptions?api-version=${apiVersion}`,
});
});

test('handles text to speech', async () => {
expect(
await (
await client.audio.speech.create({
model: deployment,
input: '',
voice: 'alloy',
})
).json(),
).toStrictEqual({
url: `https://example.com/openai/deployments/${deployment}/audio/speech?api-version=${apiVersion}`,
});
});

test('handles image generation', async () => {
expect(
await client.images.generate({
model: deployment,
prompt: 'prompt',
}),
).toStrictEqual({
url: `https://example.com/openai/deployments/${deployment}/images/generations?api-version=${apiVersion}`,
});
});

test('handles assistants', async () => {
expect(
await client.beta.assistants.create({
model,
}),
).toStrictEqual({
url: `https://example.com/openai/assistants?api-version=${apiVersion}`,
});
});

test('handles files', async () => {
expect(
await client.files.create({
file: { url: 'https://example.com', blob: () => 0 as any },
purpose: 'assistants',
}),
).toStrictEqual({
url: `https://example.com/openai/files?api-version=${apiVersion}`,
});
});

test('handles fine tuning', async () => {
expect(
await client.fineTuning.jobs.create({
model,
training_file: '',
}),
).toStrictEqual({
url: `https://example.com/openai/fine_tuning/jobs?api-version=${apiVersion}`,
});
});
});
});

describe('Content-Length', () => {
test('handles multi-byte characters', () => {
const { req } = client.buildRequest({ path: '/foo', method: 'post', body: { value: '—' } });
Expand Down