In [3]:
import json
import instructor

from openai import AsyncOpenAI
from helpers import dicts_to_df
from datetime import date
from pydantic import BaseModel, Field


# Date window paired with a free-text reasoning field.
# Documented with `#` comments rather than a class docstring, because pydantic
# copies class docstrings into the generated JSON schema.
class DateRange(BaseModel):
    # Required reasoning text; declared before the dates it justifies.
    chain_of_thought: str = Field(
        ...,
        description="Think step by step to plan what is the best time range to search in",
    )
    # Required bounds of the window. No ordering between them is enforced here.
    start: date
    end: date


# Structured form of a search query: a rewritten query string plus the
# publication date range to search in.
# Documented with `#` comments rather than a class docstring, because pydantic
# copies class docstrings into the generated JSON schema.
class Query(BaseModel):
    rewritten_query: str = Field(
        ...,
        description="Rewrite the query to make it more specific",
    )
    published_daterange: DateRange = Field(
        ...,
        description="Effective date range to search in",
    )

    def report(self):
        """Return the model's fields as a dict with a trailing ``"usage"`` entry.

        The usage entry is the dumped ``usage`` object of ``self._raw_response``.
        That attribute is not declared on this class; presumably the patched
        client attaches it to parsed responses (TODO confirm), so calling this
        on a hand-built instance raises ``AttributeError``.
        """
        return {
            **self.model_dump(),
            "usage": self._raw_response.usage.model_dump(),
        }



# Use a separate client, built on AsyncOpenAI, for the async calls.
# This highlights how it differs from a synchronous client and shows that both can be used.
aclient = instructor.patch(AsyncOpenAI())


async def expand_query(
    q, *, model: str = "gpt-3.5-turbo", temp: float = 0
) -> Query:
    """Send one search query to the chat model and request a ``Query`` back.

    Args:
        q: The raw query; interpolated into the user message as ``query: {q}``.
        model: Model name forwarded to the completion call.
        temp: Sampling temperature forwarded to the completion call.

    Returns:
        The awaited result of ``aclient.chat.completions.create`` called with
        ``response_model=Query``.
    """
    # The system prompt embeds today's date at call time.
    system_message = {
        "role": "system",
        "content": f"You're a query understanding system for the Metafor Systems search engine. Today is {date.today()}. Here are some tips: ...",
    }
    user_message = {"role": "user", "content": f"query: {q}"}

    expanded = await aclient.chat.completions.create(
        model=model,
        temperature=temp,
        response_model=Query,
        messages=[system_message, user_message],
    )
    return expanded

In [7]:
import asyncio
import time
import pandas as pd
import wandb

model = "gpt-3.5-turbo"
temp = 1

run = wandb.init(
    project="query",
    config={"model": model, "temp": temp},
)

test_queries = [
    "latest developments in artificial intelligence last 3 weeks",
    "renewable energy trends past month",
    "quantum computing advancements last 2 months",
    "biotechnology updates last 10 days",
]
start = time.perf_counter()
queries = await asyncio.gather(
    *[expand_query(q, model=model, temp=temp) for q in test_queries]
)
duration = time.perf_counter() - start

with open("schema.json", "w+") as f:
    schema = Query.model_json_schema()
    json.dump(schema, f, indent=2)

with open("results.jsonlines", "w+") as f:
    for query in queries:
        f.write(query.model_dump_json() + "\n")

df = dicts_to_df([q.report() for q in queries])
df["input"] = test_queries
df.to_csv("results.csv")


run.log({"schema": wandb.Table(dataframe=pd.DataFrame([{"schema": schema}]))})

run.log(
    {
        "usage_total_tokens": df["usage_total_tokens"].sum(),
        "usage_completion_tokens": df["usage_completion_tokens"].sum(),
        "usage_prompt_tokens": df["usage_prompt_tokens"].sum(),
        "duration (s)": duration,
        "average duration (s)": duration / len(queries),
        "n_queries": len(queries),
    }
)


run.log(
    {
        "results": wandb.Table(dataframe=df),
    }
)

files = wandb.Artifact("data", type="dataset")

files.add_file("schema.json")
files.add_file("results.jsonlines")
files.add_file("results.csv")


run.log_artifact(files)
run.finish()

0,1
average duration (s),▁
duration (s),▁
n_queries,▁
usage_completion_tokens,▁
usage_prompt_tokens,▁
usage_total_tokens,▁

0,1
average duration (s),0.59857
duration (s),2.39428
n_queries,4.0
usage_completion_tokens,200.0
usage_prompt_tokens,780.0
usage_total_tokens,980.0
