adding provider specific transformers + better logging + fixing anyscale (#132)
roodboi committed Mar 5, 2024
1 parent 076aaa6 commit f65672c
Showing 7 changed files with 131 additions and 29 deletions.
5 changes: 5 additions & 0 deletions .changeset/stupid-ducks-act.md
@@ -0,0 +1,5 @@
---
"@instructor-ai/instructor": minor
---

adding meta to standard completions as well and including usage - also added more verbose debug logs and new provider-specific transformers to handle discrepancies in various APIs
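
For orientation, here is a minimal sketch of what the new _meta field looks like from the consumer side after this change. The schema, model name, and prompt are placeholders for illustration, not part of this commit:

import Instructor from "@instructor-ai/instructor"
import OpenAI from "openai"
import { z } from "zod"

// Hypothetical schema purely for illustration.
const UserSchema = z.object({ name: z.string(), age: z.number() })

const client = Instructor({
  client: new OpenAI({ apiKey: process.env.OPENAI_API_KEY }),
  mode: "TOOLS"
})

const user = await client.chat.completions.create({
  messages: [{ role: "user", content: "Jason Liu is 30 years old" }],
  model: "gpt-4",
  response_model: { schema: UserSchema, name: "User" },
  stream: false
})

// Standard (non-streaming) completions now carry token usage on _meta.
console.log(user.name, user.age, user._meta?.usage)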
51 changes: 50 additions & 1 deletion src/constants/providers.ts
@@ -1,4 +1,7 @@
import { MODE, type Mode } from "zod-stream"
import { omit } from "@/lib"
import OpenAI from "openai"
import { z } from "zod"
import { MODE, withResponseModel, type Mode } from "zod-stream"

export const PROVIDERS = {
OAI: "OAI",
@@ -24,6 +27,52 @@ export const NON_OAI_PROVIDER_URLS = {
[PROVIDERS.OAI]: "api.openai.com"
} as const

export const PROVIDER_PARAMS_TRANSFORMERS = {
[PROVIDERS.ANYSCALE]: {
[MODE.JSON_SCHEMA]: function removeAdditionalPropertiesKeyJSONSchema<
T extends z.AnyZodObject,
P extends OpenAI.ChatCompletionCreateParams
>(params: ReturnType<typeof withResponseModel<T, "JSON_SCHEMA", P>>) {
if ("additionalProperties" in params.response_format.schema) {
return {
...params,
response_format: {
...params.response_format,
schema: omit(["additionalProperties"], params.response_format.schema)
}
}
}

return params
},
[MODE.TOOLS]: function removeAdditionalPropertiesKeyTools<
T extends z.AnyZodObject,
P extends OpenAI.ChatCompletionCreateParams
>(params: ReturnType<typeof withResponseModel<T, "TOOLS", P>>) {
if (params.tools.some(tool => tool.function?.parameters)) {
return {
...params,
tools: params.tools.map(tool => {
if (tool.function?.parameters) {
return {
...tool,
function: {
...tool.function,
parameters: omit(["additionalProperties"], tool.function.parameters)
}
}
}

return tool
})
}
}

return params
}
}
} as const

export const PROVIDER_SUPPORTED_MODES_BY_MODEL = {
[PROVIDERS.OTHER]: {
[MODE.FUNCTIONS]: ["*"],
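
As a standalone illustration of what the Anyscale transformers above do, the sketch below strips additionalProperties from a JSON_SCHEMA-mode payload. The omit helper is a local stand-in for the one imported from @/lib, the params shape is assumed from the code above, and the premise that Anyscale's API rejects the key is inferred from this fix rather than stated in the commit:

// Local stand-in for the omit helper imported from "@/lib".
const omit = (keys: string[], obj: Record<string, unknown>) =>
  Object.fromEntries(Object.entries(obj).filter(([key]) => !keys.includes(key)))

// A JSON_SCHEMA-mode payload roughly as withResponseModel might produce it
// (shape assumed for illustration).
const params = {
  model: "mistralai/Mixtral-8x7B-Instruct-v0.1",
  messages: [{ role: "user", content: "Jason Liu is 30 years old" }],
  response_format: {
    type: "json_object",
    schema: {
      type: "object",
      properties: { name: { type: "string" }, age: { type: "number" } },
      required: ["name", "age"],
      additionalProperties: false
    }
  }
}

// Mirrors PROVIDER_PARAMS_TRANSFORMERS[PROVIDERS.ANYSCALE][MODE.JSON_SCHEMA]:
// the same params, minus the key the provider rejects.
const transformed = {
  ...params,
  response_format: {
    ...params.response_format,
    schema: omit(["additionalProperties"], params.response_format.schema)
  }
}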
1 change: 1 addition & 0 deletions src/index.ts
@@ -1,3 +1,4 @@
import Instructor from "./instructor"

export * from "./types"
export default Instructor
75 changes: 58 additions & 17 deletions src/instructor.ts
@@ -6,22 +6,18 @@ import {
} from "@/types"
import OpenAI from "openai"
import { z } from "zod"
import ZodStream, {
CompletionMeta,
OAIResponseParser,
OAIStream,
withResponseModel,
type Mode
} from "zod-stream"
import ZodStream, { OAIResponseParser, OAIStream, withResponseModel, type Mode } from "zod-stream"
import { fromZodError } from "zod-validation-error"

import {
NON_OAI_PROVIDER_URLS,
Provider,
PROVIDER_PARAMS_TRANSFORMERS,
PROVIDER_SUPPORTED_MODES,
PROVIDER_SUPPORTED_MODES_BY_MODEL,
PROVIDERS
} from "./constants/providers"
import { CompletionMeta } from "./types"

const MAX_RETRIES_DEFAULT = 0

@@ -109,7 +105,9 @@ class Instructor {
let validationIssues = ""
let lastMessage: OpenAI.ChatCompletionMessageParam | null = null

const completionParams = withResponseModel({
const paramsTransformer = PROVIDER_PARAMS_TRANSFORMERS?.[this.provider]?.[this.mode]

let completionParams = withResponseModel({
params: {
...params,
stream: false
@@ -118,6 +116,10 @@
response_model
})

if (!!paramsTransformer) {
completionParams = paramsTransformer(completionParams)
}

const makeCompletionCall = async () => {
let resolvedParams = completionParams

@@ -135,17 +137,33 @@
}
}

this.log("debug", response_model.name, "making completion call with params: ", resolvedParams)
let completion: OpenAI.Chat.Completions.ChatCompletion | null = null

const completion = await this.client.chat.completions.create(resolvedParams)
try {
completion = await this.client.chat.completions.create(resolvedParams)
this.log("debug", "raw standard completion response: ", completion)
} catch (error) {
this.log(
"error",
`Error making completion call - mode: ${this.mode} | Client base URL: ${this.client.baseURL} | with params:`,
resolvedParams,
`raw error`,
error
)

throw error
}

const parsedCompletion = OAIResponseParser(
completion as OpenAI.Chat.Completions.ChatCompletion
)

try {
return JSON.parse(parsedCompletion) as z.infer<T>
const data = JSON.parse(parsedCompletion) as z.infer<T> & { _meta?: CompletionMeta }
return { ...data, _meta: { usage: completion?.usage ?? undefined } }
} catch (error) {
this.log("error", "failed to parse completion", parsedCompletion, this.mode)
throw error
}
}

@@ -173,13 +191,29 @@
return validation.data
} catch (error) {
if (attempts < max_retries) {
this.log("debug", response_model.name, "Retrying, attempt: ", attempts)
this.log("warn", response_model.name, "Validation error: ", validationIssues)
this.log(
"debug",
`response model: ${response_model.name} - Retrying, attempt: `,
attempts
)
this.log(
"warn",
`response model: ${response_model.name} - Validation issues: `,
validationIssues
)
attempts++
return await makeCompletionCallWithRetries()
} else {
this.log("debug", response_model.name, "Max attempts reached: ", attempts)
this.log("error", response_model.name, "Error: ", validationIssues)
this.log(
"debug",
`response model: ${response_model.name} - Max attempts reached: ${attempts}`
)
this.log(
"error",
`response model: ${response_model.name} - Validation issues: `,
validationIssues
)

throw error
}
}
@@ -193,13 +227,15 @@
response_model,
...params
}: ChatCompletionCreateParamsWithModel<T>): Promise<
AsyncGenerator<Partial<T> & { _meta: CompletionMeta }, void, unknown>
AsyncGenerator<Partial<T> & { _meta?: CompletionMeta }, void, unknown>
> {
if (max_retries) {
this.log("warn", "max_retries is not supported for streaming completions")
}

const completionParams = withResponseModel({
const paramsTransformer = PROVIDER_PARAMS_TRANSFORMERS?.[this.provider]?.[this.mode]

let completionParams = withResponseModel({
params: {
...params,
stream: true
@@ -208,13 +244,18 @@
mode: this.mode
})

if (paramsTransformer) {
completionParams = paramsTransformer(completionParams)
}

const streamClient = new ZodStream({
debug: this.debug ?? false
})

return streamClient.create({
completionPromise: async () => {
const completion = await this.client.chat.completions.create(completionParams)
this.log("debug", "raw stream completion response: ", completion)

return OAIStream({
res: completion
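
For completeness, a sketch of the streaming path changed above, where each partial now types _meta as optional. This assumes the client and UserSchema from the earlier sketch, and whether usage is populated mid-stream depends on the provider:

const stream = await client.chat.completions.create({
  messages: [{ role: "user", content: "Jason Liu is 30 years old" }],
  model: "gpt-4",
  response_model: { schema: UserSchema, name: "User" },
  stream: true
})

for await (const partial of stream) {
  // Each yielded value is Partial<User> plus an optional _meta payload.
  console.log(partial.name, partial._meta?.usage)
}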
11 changes: 7 additions & 4 deletions src/types/index.ts
@@ -2,13 +2,15 @@ import OpenAI from "openai"
import { Stream } from "openai/streaming"
import { z } from "zod"
import {
CompletionMeta,
CompletionMeta as ZCompletionMeta,
type Mode as ZMode,
type ResponseModel as ZResponseModel
} from "zod-stream"

export type LogLevel = "debug" | "info" | "warn" | "error"

export type CompletionMeta = Partial<ZCompletionMeta> & {
usage?: OpenAI.CompletionUsage
}
export type Mode = ZMode
export type ResponseModel<T extends z.AnyZodObject> = ZResponseModel<T>

@@ -37,7 +39,8 @@ export type ReturnTypeBasedOnParams<P> =
response_model: ResponseModel<infer T>
}
) ?
Promise<AsyncGenerator<Partial<z.infer<T>> & { _meta: CompletionMeta }, void, unknown>>
: P extends { response_model: ResponseModel<infer T> } ? Promise<z.infer<T>>
Promise<AsyncGenerator<Partial<z.infer<T>> & { _meta?: CompletionMeta }, void, unknown>>
: P extends { response_model: ResponseModel<infer T> } ?
Promise<z.infer<T> & { _meta?: CompletionMeta }>
: P extends { stream: true } ? Stream<OpenAI.Chat.Completions.ChatCompletionChunk>
: OpenAI.Chat.Completions.ChatCompletion
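
Read in isolation, the conditional type above dispatches on the presence of response_model and stream. A reduced sketch of the same shape, with simplified names that are not the library's actual exports:

import OpenAI from "openai"
import { Stream } from "openai/streaming"

type Meta = { usage?: OpenAI.CompletionUsage }

// response_model + stream -> async generator of partials with optional _meta
// response_model only     -> promise of the parsed object with optional _meta
// stream only             -> the raw OpenAI chunk stream
// neither                 -> the raw chat completion
type ReturnShape<P, T> =
  P extends { response_model: unknown; stream: true } ?
    Promise<AsyncGenerator<Partial<T> & { _meta?: Meta }, void, unknown>>
  : P extends { response_model: unknown } ? Promise<T & { _meta?: Meta }>
  : P extends { stream: true } ? Stream<OpenAI.Chat.Completions.ChatCompletionChunk>
  : OpenAI.Chat.Completions.ChatCompletion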
14 changes: 9 additions & 5 deletions tests/inference.test.ts
@@ -7,12 +7,12 @@
// 6. response_model, no stream, max_retries

import Instructor from "@/instructor"
import { type CompletionMeta } from "@/types"
import { describe, expect, test } from "bun:test"
import OpenAI from "openai"
import { Stream } from "openai/streaming"
import { type } from "ts-inference-check"
import { z } from "zod"
import { CompletionMeta } from "zod-stream"

describe("Inference Checking", () => {
const UserSchema = z.object({
@@ -61,7 +61,9 @@ describe("Inference Checking", () => {
stream: false
})

expect(type(user).strictly.is<z.infer<typeof UserSchema>>(true)).toBe(true)
expect(
type(user).strictly.is<z.infer<typeof UserSchema> & { _meta?: CompletionMeta }>(true)
).toBe(true)
})

test("response_model, stream", async () => {
@@ -79,7 +81,7 @@
Partial<{
name: string
age: number
}> & { _meta: CompletionMeta },
}> & { _meta?: CompletionMeta },
void,
unknown
>
@@ -103,7 +105,7 @@
Partial<{
name: string
age: number
}> & { _meta: CompletionMeta },
}> & { _meta?: CompletionMeta },
void,
unknown
>
@@ -120,6 +122,8 @@
max_retries: 3
})

expect(type(user).strictly.is<z.infer<typeof UserSchema>>(true)).toBe(true)
expect(
type(user).strictly.is<z.infer<typeof UserSchema> & { _meta?: CompletionMeta }>(true)
).toBe(true)
})
})
3 changes: 1 addition & 2 deletions tests/mode.test.ts
@@ -96,8 +96,7 @@ async function extractUser(model: string, mode: Mode, provider: Provider) {
messages: [{ role: "user", content: "Jason Liu is 30 years old" }],
model: model,
response_model: { schema: UserSchema, name: "User" },
max_retries: 4,
seed: provider === PROVIDERS.OAI ? 1 : undefined
max_retries: 4
})

return user
