
Commit 6dd4255

updating to tee the primary stream if stream usage is enabled - so we can extract usage and include in _meta (#176)
roodboi committed May 17, 2024
1 parent 3fb0b08 commit 6dd4255
Showing 28 changed files with 90 additions and 44 deletions.
5 changes: 5 additions & 0 deletions .changeset/calm-knives-sin.md
@@ -0,0 +1,5 @@
+ ---
+ "@instructor-ai/instructor": minor
+ ---
+
+ add ability to include usage from streams by teeing stream when option is present
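Below is a hedged sketch of what this option looks like from the consumer side, assuming a standard Instructor client and the `stream_options` flag this commit checks for (the schema, prompt, and model are illustrative, not part of the changeset):

```ts
import Instructor from "@instructor-ai/instructor"
import OpenAI from "openai"
import { z } from "zod"

const client = Instructor({
  client: new OpenAI({ apiKey: process.env.OPENAI_API_KEY }),
  mode: "TOOLS"
})

const UserSchema = z.object({ name: z.string(), age: z.number() })

const userStream = await client.chat.completions.create({
  messages: [{ role: "user", content: "Jason Liu is 30 years old" }],
  model: "gpt-4o",
  stream: true,
  // presence of stream_options is what triggers the tee added in this commit
  stream_options: { include_usage: true },
  response_model: { schema: UserSchema, name: "User" }
})

let usage: unknown
for await (const chunk of userStream) {
  // usage is merged into _meta once the tee'd branch has seen OpenAI's final usage chunk
  usage = chunk._meta?.usage ?? usage
}
console.log(usage) // { prompt_tokens, completion_tokens, total_tokens }
```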
Binary file modified bun.lockb
2 changes: 1 addition & 1 deletion docs/concepts/streaming.md
@@ -61,7 +61,7 @@ A follow-up meeting is scheduled for January 25th at 3 PM GMT to finalize the ag

const extractionStream = await client.chat.completions.create({
messages: [{ role: "user", content: textBlock }],
model: "gpt-4-turbo",
model: "gpt-4o",
response_model: {
schema: ExtractionValuesSchema,
name: "value extraction"
2 changes: 1 addition & 1 deletion docs/examples/action_items.md
@@ -66,7 +66,7 @@ const extractActionItems = async (data: string): Promise<ActionItems | undefined
"content": `Create the action items for the following transcript: ${data}`,
},
],
model: "gpt-4-turbo",
model: "gpt-4o",
response_model: { schema: ActionItemsSchema },
max_tokens: 1000,
temperature: 0.0,
2 changes: 1 addition & 1 deletion docs/examples/query_decomposition.md
@@ -65,7 +65,7 @@ const createQueryPlan = async (question: string): Promise<QueryPlan | undefined>
"content": `Consider: ${question}\nGenerate the correct query plan.`,
},
],
model: "gpt-4-turbo",
model: "gpt-4o",
response_model: { schema: QueryPlanSchema },
max_tokens: 1000,
temperature: 0.0,
8 changes: 4 additions & 4 deletions docs/examples/self_correction.md
@@ -44,7 +44,7 @@ const question = "What is the meaning of life?"
const context = "According to the devil the meaning of live is to live a life of sin and debauchery."

await instructor.chat.completions.create({
model: "gpt-4",
model: "gpt-4o",
max_retries: 0,
response_model: { schema: QuestionAnswer, name: "Question and Answer" },
messages: [
@@ -82,14 +82,14 @@ const QuestionAnswer = z.object({
question: z.string(),
answer: z.string().superRefine(
LLMValidator(instructor, statement, {
model: "gpt-4"
model: "gpt-4o"
})
)
})

try {
await instructor.chat.completions.create({
model: "gpt-4",
model: "gpt-4o",
max_retries: 0,
response_model: { schema: QuestionAnswer, name: "Question and Answer" },
messages: [
@@ -132,7 +132,7 @@ By adding the `max_retries` parameter, we can retry the request with corrections
```ts
try {
await instructor.chat.completions.create({
model: "gpt-4",
model: "gpt-4o",
max_retries: 2,
response_model: { schema: QuestionAnswer, name: "Question and Answer" },
messages: [
2 changes: 1 addition & 1 deletion examples/action_items/index.ts
@@ -45,7 +45,7 @@ const extractActionItems = async (data: string) => {
content: `Create the action items for the following transcript: ${data}`
}
],
model: "gpt-4-turbo",
model: "gpt-4o",
response_model: { schema: ActionItemsSchema, name: "ActionItems" },
max_tokens: 1000,
temperature: 0.0,
2 changes: 1 addition & 1 deletion examples/extract_user/index.ts
@@ -19,7 +19,7 @@ const client = Instructor({

const user = await client.chat.completions.create({
messages: [{ role: "user", content: "Jason Liu is 30 years old" }],
model: "gpt-4",
model: "gpt-4o",
response_model: {
schema: UserSchema,
name: "User"
2 changes: 1 addition & 1 deletion examples/extract_user/properties.ts
@@ -27,7 +27,7 @@ const client = Instructor({

const user = await client.chat.completions.create({
messages: [{ role: "user", content: "Happy Potter" }],
model: "gpt-4",
model: "gpt-4o",
response_model: { schema: UserSchema, name: "User" },
max_retries: 3,
seed: 1
2 changes: 1 addition & 1 deletion examples/extract_user_stream/index.ts
@@ -53,7 +53,7 @@ let extraction = {}

const extractionStream = await client.chat.completions.create({
messages: [{ role: "user", content: textBlock }],
model: "gpt-4-turbo",
model: "gpt-4o",
response_model: {
schema: ExtractionValuesSchema,
name: "value extraction"
4 changes: 2 additions & 2 deletions examples/llm-validator/index.ts
@@ -16,7 +16,7 @@ const QuestionAnswer = z.object({
question: z.string(),
answer: z.string().superRefine(
LLMValidator(instructor, statement, {
model: "gpt-4-turbo"
model: "gpt-4o"
})
)
})
@@ -25,7 +25,7 @@ const question = "What is the meaning of life?"

const check = async (context: string) => {
return await instructor.chat.completions.create({
model: "gpt-4-turbo",
model: "gpt-4o",
max_retries: 2,
response_model: { schema: QuestionAnswer, name: "Question and Answer" },
messages: [
2 changes: 1 addition & 1 deletion examples/query_decomposition/index.ts
@@ -38,7 +38,7 @@ const createQueryPlan = async (question: string) => {
content: `Consider: ${question}\nGenerate the correct query plan.`
}
],
model: "gpt-4-turbo",
model: "gpt-4o",
response_model: { schema: QueryPlanSchema, name: "Query Plan Decomposition" },
max_tokens: 1000,
temperature: 0.0,
2 changes: 1 addition & 1 deletion examples/query_expansions/run.ts
@@ -73,7 +73,7 @@ const runExtraction = async (query: string) => {
{ role: "system", content: systemPrompt },
{ role: "user", content: query }
],
model: "gpt-4",
model: "gpt-4o",
response_model: {
schema: ExtractionValuesSchema,
name: "value_extraction"
4 changes: 2 additions & 2 deletions examples/query_expansions/run_sync.ts
@@ -95,7 +95,7 @@ export const runExtractionStream = async (query: string) => {
{ role: "system", content: systemPrompt },
{ role: "user", content: query }
],
model: "gpt-4",
model: "gpt-4o",
response_model: {
schema: SearchQuery,
name: "value_extraction"
@@ -124,7 +124,7 @@ const runExtraction = async (query: string) => {
{ role: "system", content: systemPrompt },
{ role: "user", content: query }
],
model: "gpt-4",
model: "gpt-4o",
response_model: {
schema: Response,
name: "Respond"
2 changes: 1 addition & 1 deletion examples/resolving-complex-entitities/index.ts
@@ -59,7 +59,7 @@ const askAi = async (input: string) => {
content: input
}
],
model: "gpt-4",
model: "gpt-4o",
response_model: { schema: DocumentExtractionSchema, name: "Document Extraction" },
max_retries: 3,
seed: 1
3 changes: 2 additions & 1 deletion package.json
@@ -51,7 +51,7 @@
},
"homepage": "https://github.com/instructor-ai/instructor-js#readme",
"dependencies": {
"zod-stream": "1.0.2",
"zod-stream": "1.0.3",
"zod-validation-error": "^2.1.0"
},
"peerDependencies": {
@@ -76,6 +76,7 @@
"eslint-plugin-prettier": "^5.1.2",
"husky": "^8.0.3",
"llm-polyglot": "1.0.0",
"openai": "latest",
"prettier": "latest",
"ts-inference-check": "^0.3.0",
"tsup": "^8.0.1",
2 changes: 1 addition & 1 deletion src/constants/providers.ts
@@ -103,7 +103,7 @@ export const PROVIDER_SUPPORTED_MODES_BY_MODEL = {
[PROVIDERS.OAI]: {
[MODE.FUNCTIONS]: ["*"],
[MODE.TOOLS]: ["*"],
[MODE.JSON]: ["gpt-3.5-turbo-1106", "gpt-4-turbo", "gpt-4-0125-preview", "gpt-4-turbo-preview"],
[MODE.JSON]: ["*"],
[MODE.MD_JSON]: ["*"]
},
[PROVIDERS.TOGETHER]: {
41 changes: 38 additions & 3 deletions src/instructor.ts
@@ -9,6 +9,7 @@ import {
ReturnTypeBasedOnParams
} from "@/types"
import OpenAI from "openai"
+ import { Stream } from "openai/streaming.mjs"
import { z, ZodError } from "zod"
import ZodStream, { OAIResponseParser, OAIStream, withResponseModel, type Mode } from "zod-stream"
import { fromZodError } from "zod-validation-error"
@@ -266,10 +267,10 @@ class Instructor<C extends GenericClient | OpenAI> {
return makeCompletionCallWithRetries()
}

- private async chatCompletionStream<T extends z.AnyZodObject>(
+ private async *chatCompletionStream<T extends z.AnyZodObject>(
{ max_retries, response_model, ...params }: ChatCompletionCreateParamsWithModel<T>,
requestOptions?: ClientTypeChatCompletionRequestOptions<C>
- ): Promise<AsyncGenerator<Partial<T> & { _meta?: CompletionMeta }, void, unknown>> {
+ ): AsyncGenerator<Partial<T> & { _meta?: CompletionMeta }, void, unknown> {
if (max_retries) {
this.log("warn", "max_retries is not supported for streaming completions")
}
@@ -293,7 +294,16 @@
debug: this.debug ?? false
})

- return streamClient.create({
+ async function checkForUsage(reader: Stream<OpenAI.ChatCompletionChunk>) {
+   for await (const chunk of reader) {
+     if ("usage" in chunk) {
+       streamUsage = chunk.usage as CompletionMeta["usage"]
+     }
+   }
+ }
+
+ let streamUsage: CompletionMeta["usage"] | undefined
+ const structuredStream = await streamClient.create({
completionPromise: async () => {
if (this.client.chat?.completions?.create) {
const completion = await this.client.chat.completions.create(
@@ -306,6 +316,21 @@

this.log("debug", "raw stream completion response: ", completion)

+ if (
+   this.provider === "OAI" &&
+   completionParams?.stream &&
+   "stream_options" in completionParams &&
+   completion instanceof Stream
+ ) {
+   const [completion1, completion2] = completion.tee()
+
+   checkForUsage(completion1)
+
+   return OAIStream({
+     res: completion2
+   })
+ }
+
return OAIStream({
res: completion as unknown as AsyncIterable<OpenAI.ChatCompletionChunk>
})
@@ -315,6 +340,16 @@
},
response_model
})

+ for await (const chunk of structuredStream) {
+   yield {
+     ...chunk,
+     _meta: {
+       usage: streamUsage ?? undefined,
+       ...(chunk?._meta ?? {})
+     }
+   }
+ }
}

private isChatCompletionCreateParamsWithModel<T extends z.AnyZodObject>(
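To restate the core of the change above in isolation: the handler tees the raw OpenAI stream so one branch can be drained for the usage chunk while the other feeds the structured-output parser. A minimal sketch of that pattern against the plain openai SDK (prompt and model are placeholders, not code from this repo):

```ts
import OpenAI from "openai"

const oai = new OpenAI({ apiKey: process.env.OPENAI_API_KEY })

const stream = await oai.chat.completions.create({
  model: "gpt-4o",
  messages: [{ role: "user", content: "Say hello" }],
  stream: true,
  // asks OpenAI to append a final chunk carrying token usage
  stream_options: { include_usage: true }
})

// tee() splits the stream into two independent iterators over the same chunks
const [usageBranch, contentBranch] = stream.tee()

let usage: OpenAI.CompletionUsage | undefined

// drain one branch in the background purely to capture the usage chunk
const usageDone = (async () => {
  for await (const chunk of usageBranch) {
    if (chunk.usage) usage = chunk.usage
  }
})()

// consume the other branch for content as usual
for await (const chunk of contentBranch) {
  process.stdout.write(chunk.choices[0]?.delta?.content ?? "")
}

await usageDone
console.log("\nusage:", usage)
```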
2 changes: 1 addition & 1 deletion src/types/index.ts
@@ -88,7 +88,7 @@ export type ReturnTypeBasedOnParams<C, P> =
response_model: ResponseModel<infer T>
}
) ?
- Promise<AsyncGenerator<Partial<z.infer<T>> & { _meta?: CompletionMeta }, void, unknown>>
+ AsyncGenerator<Partial<z.infer<T>> & { _meta?: CompletionMeta }, void, unknown>
: P extends { response_model: ResponseModel<infer T> } ?
Promise<z.infer<T> & { _meta?: CompletionMeta }>
: C extends OpenAI ?
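One consequence of the type change above: the streaming overload is now typed as returning the generator itself rather than a promise of one. Existing call sites that `await` the result keep working, since awaiting a non-thenable simply resolves to the value. A sketch, reusing the hypothetical `client` and `UserSchema` from the earlier example:

```ts
const stream = await client.chat.completions.create({
  messages: [{ role: "user", content: "Jason Liu is 30 years old" }],
  model: "gpt-4o",
  stream: true,
  response_model: { schema: UserSchema, name: "User" }
})

for await (const partial of stream) {
  // partial is Partial<z.infer<typeof UserSchema>> & { _meta?: CompletionMeta }
  console.log(partial)
}
```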
6 changes: 3 additions & 3 deletions tests/extract.test.ts
@@ -21,7 +21,7 @@ async function extractUser() {

const user = await client.chat.completions.create({
messages: [{ role: "user", content: "Jason Liu is 30 years old" }],
model: "gpt-4-turbo",
model: "gpt-4o",
response_model: { schema: UserSchema, name: "User" },
seed: 1
})
@@ -49,7 +49,7 @@ async function extractUserValidated() {

const user = await client.chat.completions.create({
messages: [{ role: "user", content: "Jason Liu is 30 years old" }],
model: "gpt-4",
model: "gpt-4o",
response_model: { schema: UserSchema, name: "User" },
max_retries: 3,
seed: 1
@@ -82,7 +82,7 @@ async function extractUserMany() {

const user = await client.chat.completions.create({
messages: [{ role: "user", content: "Jason is 30 years old, Sarah is 12" }],
model: "gpt-4-turbo",
model: "gpt-4o",
response_model: { schema: UsersSchema, name: "Users" },
max_retries: 3,
seed: 1
6 changes: 3 additions & 3 deletions tests/functions.test.ts
@@ -21,7 +21,7 @@ async function extractUser() {

const user = await client.chat.completions.create({
messages: [{ role: "user", content: "Jason Liu is 30 years old" }],
model: "gpt-4-turbo",
model: "gpt-4o",
response_model: { schema: UserSchema, name: "User" },
seed: 1
})
@@ -52,7 +52,7 @@ async function extractUserValidated() {

const user = await client.chat.completions.create({
messages: [{ role: "user", content: "Jason Liu is 30 years old" }],
model: "gpt-4-turbo",
model: "gpt-4o",
response_model: { schema: UserSchema, name: "User" },
max_retries: 3,
seed: 1
@@ -85,7 +85,7 @@ async function extractUserMany() {

const user = await client.chat.completions.create({
messages: [{ role: "user", content: "Jason is 30 years old, Sarah is 12" }],
model: "gpt-4-turbo",
model: "gpt-4o",
response_model: { schema: UsersSchema, name: "Users" },
max_retries: 3,
seed: 1
12 changes: 6 additions & 6 deletions tests/inference.test.ts
@@ -33,7 +33,7 @@ describe("Inference Checking", () => {
test("no response_model, no stream", async () => {
const user = await client.chat.completions.create({
messages: [{ role: "user", content: "Jason Liu is 30 years old" }],
model: "gpt-4-turbo",
model: "gpt-4o",
seed: 1,
stream: false
})
@@ -44,7 +44,7 @@
test("no response_model, stream", async () => {
const userStream = await client.chat.completions.create({
messages: [{ role: "user", content: "Jason Liu is 30 years old" }],
model: "gpt-4-turbo",
model: "gpt-4o",
seed: 1,
stream: true
})
@@ -57,7 +57,7 @@
test("response_model, no stream", async () => {
const user = await client.chat.completions.create({
messages: [{ role: "user", content: "Jason Liu is 30 years old" }],
model: "gpt-4-turbo",
model: "gpt-4o",
response_model: { schema: UserSchema, name: "User" },
seed: 1,
stream: false
@@ -71,7 +71,7 @@
test("response_model, stream", async () => {
const userStream = await client.chat.completions.create({
messages: [{ role: "user", content: "Jason Liu is 30 years old" }],
model: "gpt-4-turbo",
model: "gpt-4o",
response_model: { schema: UserSchema, name: "User" },
seed: 1,
stream: true
@@ -94,7 +94,7 @@
test("response_model, stream, max_retries", async () => {
const userStream = await client.chat.completions.create({
messages: [{ role: "user", content: "Jason Liu is 30 years old" }],
model: "gpt-4-turbo",
model: "gpt-4o",
response_model: { schema: UserSchema, name: "User" },
seed: 1,
stream: true,
@@ -118,7 +118,7 @@
test("response_model, no stream, max_retries", async () => {
const user = await client.chat.completions.create({
messages: [{ role: "user", content: "Jason Liu is 30 years old" }],
model: "gpt-4-turbo",
model: "gpt-4o",
response_model: { schema: UserSchema, name: "User" },
seed: 1,
max_retries: 3
2 changes: 1 addition & 1 deletion tests/maybe.test.ts
@@ -24,7 +24,7 @@ async function maybeExtractUser(content: string) {

const user = await client.chat.completions.create({
messages: [{ role: "user", content: "Extract " + content }],
model: "gpt-4",
model: "gpt-4o",
response_model: { schema: MaybeUserSchema, name: "User" },
max_retries: 3,
seed: 1