Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(instructor): introduce ANTHROPIC_JSON mode #542

Merged
merged 9 commits into from
Mar 29, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions docs/blog/posts/anthropic.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ authors:

A special shoutout to [Shreya](https://twitter.com/shreyaw_) for her contributions to the anthropic support. As of now, all features are operational with the exception of streaming support.

For those eager to experiment, simply patch the client with `ANTHROPIC_TOOLS`, which will enable you to leverage the `anthropic` client for making requests.
For those eager to experiment, simply patch the client with `ANTHROPIC_JSON`, which will enable you to leverage the `anthropic` client for making requests.

```
pip install instructor[anthropic]
Expand All @@ -28,7 +28,7 @@ import instructor
# Patching the Anthropics client with the instructor for enhanced capabilities
anthropic_client = instructor.patch(
create=anthropic.Anthropic().messages.create,
mode=instructor.Mode.ANTHROPIC_TOOLS
mode=instructor.Mode.ANTHROPIC_JSON
)

class Properties(BaseModel):
Expand Down
30 changes: 30 additions & 0 deletions docs/concepts/patching.md
Original file line number Diff line number Diff line change
Expand Up @@ -77,3 +77,33 @@ from openai import OpenAI

client = instructor.patch(OpenAI(), mode=instructor.Mode.MD_JSON)
```

## Anthropic JSON Mode

Anthropic JSON mode instructs the Anthropic client to return responses as JSON; enable it by setting the `mode` parameter to `instructor.Mode.ANTHROPIC_JSON` when patching the client.

```python
import instructor
from anthropic import Anthropic
from pydantic import BaseModel

create = instructor.patch(
    create=Anthropic().messages.create,
    mode=instructor.Mode.ANTHROPIC_JSON
)

class User(BaseModel):
    name: str
    age: int

resp = create(
model="claude-3-haiku-20240307",
max_tokens=1024,
messages=[
{
"role": "user",
"content": "Create a user",
}
],
response_model=User,
)
```
32 changes: 0 additions & 32 deletions examples/classification/test_run.py

This file was deleted.

8 changes: 7 additions & 1 deletion instructor/function_calls.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,13 @@ def from_response(
assert hasattr(completion, "content")
return xml_to_model(cls, extract_xml(completion.content[0].text)) # type:ignore

assert hasattr(completion, "choices")
if mode == Mode.ANTHROPIC_JSON:
assert hasattr(completion, "content")
text = completion.content[0].text # type: ignore
extra_text = extract_json_from_codeblock(text)
return cls.model_validate_json(extra_text)

assert hasattr(completion, "choices"), "No choices in completion"

if completion.choices[0].finish_reason == "length":
logger.error("Incomplete output detected, should increase max_tokens")
Expand Down
1 change: 1 addition & 0 deletions instructor/mode.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ class Mode(enum.Enum):
MD_JSON = "markdown_json_mode"
JSON_SCHEMA = "json_schema_mode"
ANTHROPIC_TOOLS = "anthropic_tools"
ANTHROPIC_JSON = "anthropic_json"

def __new__(cls, value: str) -> "Mode":
member = object.__new__(cls)
Expand Down
32 changes: 30 additions & 2 deletions instructor/process_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from openai.types.chat import ChatCompletion
from pydantic import BaseModel


import json
import inspect
import logging
from typing import (
Expand Down Expand Up @@ -245,7 +245,7 @@ def handle_response_model(
As a genius expert, your task is to understand the content and provide
the parsed objects in json that match the following json_schema:\n

{response_model.model_json_schema()}
{json.dumps(response_model.model_json_schema(), indent=2)}

Make sure to return an instance of the JSON, not the schema itself
"""
Expand Down Expand Up @@ -305,6 +305,34 @@ def handle_response_model(
new_kwargs["system"] = f"{system_prompt}\n{new_kwargs['system']}"
else:
new_kwargs["system"] = system_prompt
elif mode == Mode.ANTHROPIC_JSON:
# Anthropic expects the system message to be a string, so we first extract any system messages
openai_system_messages = [
message["content"]
for message in new_kwargs.get("messages", [])
if message["role"] == "system"
]

new_kwargs["system"] = (
new_kwargs.get("system", "")
+ "\n\n"
+ "\n\n".join(openai_system_messages)
)

new_kwargs["system"] += f"""
You must only respond in JSON format that adheres to the following schema:

<JSON_SCHEMA>
{json.dumps(response_model.model_json_schema(), indent=2)}
</JSON_SCHEMA>
"""
new_kwargs["system"] = dedent(new_kwargs["system"])

new_kwargs["messages"] = [
message
for message in new_kwargs.get("messages", [])
if message["role"] != "system"
]
else:
raise ValueError(f"Invalid patch mode: {mode}")

Expand Down
67 changes: 0 additions & 67 deletions tests/anthropic/test_anthropic.py

This file was deleted.

82 changes: 77 additions & 5 deletions tests/anthropic/test_simple.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import anthropic
import instructor
from pydantic import BaseModel
from typing import List
from typing import List, Literal
from enum import Enum

create = instructor.patch(
create=anthropic.Anthropic().messages.create, mode=instructor.Mode.ANTHROPIC_TOOLS
create=anthropic.Anthropic().messages.create, mode=instructor.Mode.ANTHROPIC_JSON
)


Expand Down Expand Up @@ -63,14 +64,14 @@ class User(BaseModel):
assert resp.address.street_name == "First Avenue"


def test_list():
def test_list_str():
class User(BaseModel):
name: str
age: int
family: List[str]

resp = create(
model="claude-3-opus-20240229", # Fails with claude-3-haiku-20240307
model="claude-3-haiku-20240307",
max_tokens=1024,
max_retries=0,
messages=[
Expand All @@ -88,6 +89,54 @@ class User(BaseModel):
assert isinstance(member, str)


def test_enum():
class Role(str, Enum):
ADMIN = "admin"
USER = "user"

class User(BaseModel):
name: str
role: Role

resp = create(
model="claude-3-haiku-20240307",
max_tokens=1024,
max_retries=0,
messages=[
{
"role": "user",
"content": "Create a user for a model with a name and role of admin.",
}
],
response_model=User,
) # type: ignore

assert isinstance(resp, User)
assert resp.role == Role.ADMIN


def test_literal():
class User(BaseModel):
name: str
role: Literal["admin", "user"]

resp = create(
model="claude-3-haiku-20240307",
max_tokens=1024,
max_retries=0,
messages=[
{
"role": "user",
"content": "Create a admin user for a model with a name and role.",
}
],
response_model=User,
) # type: ignore

assert isinstance(resp, User)
assert resp.role == "admin"


def test_nested_list():
class Properties(BaseModel):
key: str
Expand All @@ -99,7 +148,7 @@ class User(BaseModel):
properties: List[Properties]

resp = create(
model="claude-3-opus-20240229", # Fails with claude-3-haiku-20240307
model="claude-3-haiku-20240307",
max_tokens=1024,
max_retries=0,
messages=[
Expand All @@ -114,3 +163,26 @@ class User(BaseModel):
assert isinstance(resp, User)
for property in resp.properties:
assert isinstance(property, Properties)


def test_system_messages_allcaps():
class User(BaseModel):
name: str
age: int

resp = create(
model="claude-3-haiku-20240307",
max_tokens=1024,
max_retries=0,
messages=[
{"role": "system", "content": "EVERYTHING MUST BE IN ALL CAPS"},
{
"role": "user",
"content": "Create a user for a model with a name and age.",
},
],
response_model=User,
) # type: ignore

assert isinstance(resp, User)
assert resp.name.isupper()
Loading