diff --git a/README.md b/README.md index fa454d3c1..c0d76b7af 100644 --- a/README.md +++ b/README.md @@ -86,6 +86,46 @@ assert resp.name == "Jason" assert resp.age == 25 ``` +### Using Cohere Models + +Make sure to install `cohere` and set your system environment variable with `export CO_API_KEY=`. + +``` +pip install cohere +``` + +```python +import instructor +import cohere +from pydantic import BaseModel + + +class User(BaseModel): + name: str + age: int + + +client = instructor.from_cohere(cohere.Client()) + +# note that client.chat.completions.create will also work +resp = client.chat.completions.create( + model="command-r-plus", + max_tokens=1024, + messages=[ + { + "role": "user", + "content": "Extract Jason is 25 years old.", + } + ], + response_model=User, +) + +assert isinstance(resp, User) +assert resp.name == "Jason" +assert resp.age == 25 +``` + + ### Using Litellm ```python @@ -152,7 +192,7 @@ Now if you use a IDE, you can see the type is correctly infered. ### Handling async: `await create` -This will also work correctly with asynchronous clients. +This will also work correctly with asynchronous clients. ```python import openai diff --git a/docs/examples/document_segmentation.md b/docs/examples/document_segmentation.md new file mode 100644 index 000000000..e67260e37 --- /dev/null +++ b/docs/examples/document_segmentation.md @@ -0,0 +1,143 @@ +# Document Segmentation + +In this guide, we demonstrate how to do document segmentation using structured output from an LLM. We'll be using [command-r-plus](https://docs.cohere.com/docs/command-r-plus) - one of Cohere's latest LLMs with 128k context length and testing the approach on an article explaining the Transformer architecture. Same approach to document segmentation can be applied to any other domain where we need to break down a complex long document into smaller chunks. + +!!! tips "Motivation" + Sometimes we need a way to split the document into meaningful parts that center around a signle key concept/idea. Simple length-based / rule-based text-splitters are not reliable enough. Consider the cases where documents contain code snippets or math equations - we don't want to split those on `'\n\n'` or have to write extensive rules for different types of documents. It turns out that LLMs with sufficiently long context length are well suited for this task. + +## Defining the Data Structures + +First, we need to define a **`Section`** class for each of the document's segments. **`StructuredDocument`** class will then encapsulate a list of these sections. + +Note that in order to avoid LLM regenerating the content of each section, we can simply enumerate each line of the input document and then ask LLM to segment it by providing start-end line numbers for each section. + +```python +from pydantic import BaseModel, Field +from typing import List, Dict, Any + +class Section(BaseModel): + title: str = Field(description="main topic of this section of the document") + start_index: int = Field(description="line number where the section begins") + end_index: int = Field(description="line number where the section ends") + + +class StructuredDocument(BaseModel): + """obtains meaningful sections, each centered around a single concept/topic""" + sections: List[Section] = Field(description="a list of sections of the document") +``` + +## Document Preprocessing + +Preprocess the input `document` by prepending each line with its number. + +```python +def doc_with_lines(document): + document_lines = document.split("\n") + document_with_line_numbers = "" + line2text = {} + for i, line in enumerate(document_lines): + document_with_line_numbers += f"[{i}] {line}\n" + line2text[i] = line + return document_with_line_numbers, line2text +``` + +## Segmentation + +Next use a Cohere client to extract `StructuredDocument` from the preprocessed doc. + +```python +import instructor +import cohere + +# Apply the patch to the cohere client +# enables response_model keyword +client = instructor.from_cohere(cohere.Client()) + + +system_prompt = f"""\ +You are a world class educator working on organizing your lecture notes. +Read the document below and extract a StructuredDocument object from it where each section of the document is centered around a single concept/topic that can be taught in one lesson. +Each line of the document is marked with its line number in square brackets (e.g. [1], [2], [3], etc). Use the line numbers to indicate section start and end. +""" + + +def get_structured_document(document_with_line_numbers) -> StructuredDocument: + return client.chat.completions.create( + model="command-r-plus", + response_model=StructuredDocument, + messages=[ + { + "role": "system", + "content": system_prompt, + }, + { + "role": "user", + "content": document_with_line_numbers, + }, + ], + ) # type: ignore +``` + + +Next, we need to get back the section text based on the start/end indices and our `line2text` dict from the preprocessing step. + +```python +def get_sections_text(structured_doc, line2text): + segments = [] + for s in structured_doc.sections: + contents = [] + for line_id in range(s.start_index, s.end_index): + contents.append(line2text.get(line_id, '')) + segments.append({ + "title": s.title, + "content": "\n".join(contents), + "start": s.start_index, + "end": s.end_index + }) + return segments +``` + + +## Example + +Here's an example of using these classes and functions to segment a tutorial on Transformers from [Sebastian Raschka](https://sebastianraschka.com/blog/2023/self-attention-from-scratch.html). We can use `trafilatura` package to scrape the web page content of the article. + +```python +from trafilatura import fetch_url, extract + + +url='https://sebastianraschka.com/blog/2023/self-attention-from-scratch.html' +downloaded = fetch_url(url) +document = extract(downloaded) + + +document_with_line_numbers, line2text = doc_with_lines(document) +structured_doc = get_structured_document(document_with_line_numbers) +segments = get_sections_text(structured_doc, line2text) +``` + +``` +print(segments[5]['title']) +""" +Introduction to Multi-Head Attention +""" +print(segments[5]['content']) +""" +Multi-Head Attention +In the very first figure, at the top of this article, we saw that transformers use a module called multi-head attention. How does that relate to the self-attention mechanism (scaled-dot product attention) we walked through above? +In the scaled dot-product attention, the input sequence was transformed using three matrices representing the query, key, and value. These three matrices can be considered as a single attention head in the context of multi-head attention. The figure below summarizes this single attention head we covered previously: +As its name implies, multi-head attention involves multiple such heads, each consisting of query, key, and value matrices. This concept is similar to the use of multiple kernels in convolutional neural networks. +To illustrate this in code, suppose we have 3 attention heads, so we now extend the \(d' \times d\) dimensional weight matrices so \(3 \times d' \times d\): +In: +h = 3 +multihead_W_query = torch.nn.Parameter(torch.rand(h, d_q, d)) +multihead_W_key = torch.nn.Parameter(torch.rand(h, d_k, d)) +multihead_W_value = torch.nn.Parameter(torch.rand(h, d_v, d)) +Consequently, each query element is now \(3 \times d_q\) dimensional, where \(d_q=24\) (here, let’s keep the focus on the 3rd element corresponding to index position 2): +In: +multihead_query_2 = multihead_W_query.matmul(x_2) +print(multihead_query_2.shape) +Out: +torch.Size([3, 24]) +""" +``` diff --git a/docs/examples/index.md b/docs/examples/index.md index 2ccb4b2b0..26feb1eea 100644 --- a/docs/examples/index.md +++ b/docs/examples/index.md @@ -18,5 +18,6 @@ 14. [How to use local models from Ollama](ollama.md) 15. [How to store responses in a database with SQLModel](sqlmodel.md) 16. [How to use groqcloud api](groq.md) +17. [How to do document segmentation using LLMs?](document_segmentation.md) Explore more! diff --git a/docs/hub/cohere.md b/docs/hub/cohere.md new file mode 100644 index 000000000..7f6fa7e72 --- /dev/null +++ b/docs/hub/cohere.md @@ -0,0 +1,85 @@ +# Structured Outputs with Cohere + +If you want to try this example using `instructor hub`, you can pull it by running + +```bash +instructor hub pull --slug cohere --py > cohere_example.py +``` + +You can now use any of the Cohere's [command models](https://docs.cohere.com/docs/models) with the `instructor` library to get structured outputs. + +You'll need a cohere API key which can be obtained by signing up [here](https://dashboard.cohere.com/) and gives you [free](https://cohere.com/pricing), rate-limited usage for learning and prototyping. + +## Setup +``` +pip install cohere +``` +Export your key: +``` +export CO_API_KEY= +``` + +## Example + +```python +from pydantic import BaseModel, Field +from typing import List +import cohere +import instructor + + +# Patching the Cohere client with the instructor for enhanced capabilities +client = instructor.from_cohere( + cohere.Client(), + max_tokens=1000, + model="command-r-plus", +) + + +class Person(BaseModel): + name: str = Field(description="name of the person") + country_of_origin: str = Field(description="country of origin of the person") + + +class Group(BaseModel): + group_name: str = Field(description="name of the group") + members: List[Person] = Field(description="list of members in the group") + + +task = """\ +Given the following text, create a Group object for 'The Beatles' band + +Text: +The Beatles were an English rock band formed in Liverpool in 1960. With a line-up comprising John Lennon, Paul McCartney, George Harrison and Ringo Starr, they are regarded as the most influential band of all time. The group were integral to the development of 1960s counterculture and popular music's recognition as an art form. +""" +group = client.messages.create( + response_model=Group, + messages=[{"role": "user", "content": task}], + temperature=0, +) + +print(group.model_dump_json(indent=2)) +""" +{ + "group_name": "The Beatles", + "members": [ + { + "name": "John Lennon", + "country_of_origin": "England" + }, + { + "name": "Paul McCartney", + "country_of_origin": "England" + }, + { + "name": "George Harrison", + "country_of_origin": "England" + }, + { + "name": "Ringo Starr", + "country_of_origin": "England" + } + ] +} +""" +``` \ No newline at end of file diff --git a/examples/cohere/cohere.py b/examples/cohere/cohere.py new file mode 100644 index 000000000..fc330024e --- /dev/null +++ b/examples/cohere/cohere.py @@ -0,0 +1,60 @@ +import cohere +import instructor +from pydantic import BaseModel, Field +from typing import List + + +# Patching the Cohere client with the instructor for enhanced capabilities +client = instructor.from_cohere( + cohere.Client(), + max_tokens=1000, + model="command-r-plus", +) + + +class Person(BaseModel): + name: str = Field(description="name of the person") + country_of_origin: str = Field(description="country of origin of the person") + + +class Group(BaseModel): + group_name: str = Field(description="name of the group") + members: List[Person] = Field(description="list of members in the group") + + +task = """\ +Given the following text, create a Group object for 'The Beatles' band + +Text: +The Beatles were an English rock band formed in Liverpool in 1960. With a line-up comprising John Lennon, Paul McCartney, George Harrison and Ringo Starr, they are regarded as the most influential band of all time. The group were integral to the development of 1960s counterculture and popular music's recognition as an art form. +""" +group = client.messages.create( + response_model=Group, + messages=[{"role": "user", "content": task}], + temperature=0, +) + +print(group.model_dump_json(indent=2)) +""" +{ + "group_name": "The Beatles", + "members": [ + { + "name": "John Lennon", + "country_of_origin": "England" + }, + { + "name": "Paul McCartney", + "country_of_origin": "England" + }, + { + "name": "George Harrison", + "country_of_origin": "England" + }, + { + "name": "Ringo Starr", + "country_of_origin": "England" + } + ] +} +""" \ No newline at end of file diff --git a/examples/extract-table/run_vision.py b/examples/extract-table/run_vision.py index 27a1d8713..f14a14260 100644 --- a/examples/extract-table/run_vision.py +++ b/examples/extract-table/run_vision.py @@ -1,6 +1,6 @@ from openai import OpenAI from io import StringIO -from typing import Annotated, Any, List, Generator +from typing import Annotated, Any, List from pydantic import ( BaseModel, BeforeValidator, diff --git a/examples/extract-table/test.py b/examples/extract-table/test.py new file mode 100644 index 000000000..3ec708cf5 --- /dev/null +++ b/examples/extract-table/test.py @@ -0,0 +1,106 @@ +from pydantic import BaseModel + +from openai import OpenAI +import instructor + +client = OpenAI() + +client = instructor.from_openai(client) + + +class User(BaseModel): + name: str + email: str + + +class MeetingInfo(BaseModel): + user: User + date: str + location: str + budget: int + deadline: str + + +data = """ +Jason Liu jason@gmail.com +Meeting Date: 2024-01-01 +Meeting Location: 1234 Main St +Meeting Budget: $1000 +Meeting Deadline: 2024-01-31 +""" +stream1 = client.chat.completions.create_partial( + model="gpt-4", + response_model=MeetingInfo, + messages=[ + { + "role": "user", + "content": f"Get the information about the meeting and the users {data}", + }, + ], + stream=True, +) # type: ignore + +for message in stream1: + print(message) +""" +ser={} date=None location=None budget=None deadline=None +user={} date=None location=None budget=None deadline=None +user={} date=None location=None budget=None deadline=None +user={} date=None location=None budget=None deadline=None +user=PartialUser(name=None, email=None) date=None location=None budget=None deadline=None +user=PartialUser(name=None, email=None) date=None location=None budget=None deadline=None +user=PartialUser(name=None, email=None) date=None location=None budget=None deadline=None +user=PartialUser(name=None, email=None) date=None location=None budget=None deadline=None +user=PartialUser(name=None, email=None) date=None location=None budget=None deadline=None +user=PartialUser(name=None, email=None) date=None location=None budget=None deadline=None +user=PartialUser(name='Jason Liu', email=None) date=None location=None budget=None deadline=None +user=PartialUser(name='Jason Liu', email=None) date=None location=None budget=None deadline=None +user=PartialUser(name='Jason Liu', email=None) date=None location=None budget=None deadline=None +user=PartialUser(name='Jason Liu', email=None) date=None location=None budget=None deadline=None +user=PartialUser(name='Jason Liu', email=None) date=None location=None budget=None deadline=None +user=PartialUser(name='Jason Liu', email=None) date=None location=None budget=None deadline=None +user=PartialUser(name='Jason Liu', email=None) date=None location=None budget=None deadline=None +user=PartialUser(name='Jason Liu', email=None) date=None location=None budget=None deadline=None +user=PartialUser(name='Jason Liu', email=None) date=None location=None budget=None deadline=None +user=PartialUser(name='Jason Liu', email='jason@gmail.com') date=None location=None budget=None deadline=None +user=PartialUser(name='Jason Liu', email='jason@gmail.com') date=None location=None budget=None deadline=None +user=PartialUser(name='Jason Liu', email='jason@gmail.com') date=None location=None budget=None deadline=None +user=PartialUser(name='Jason Liu', email='jason@gmail.com') date=None location=None budget=None deadline=None +user=PartialUser(name='Jason Liu', email='jason@gmail.com') date=None location=None budget=None deadline=None +user=PartialUser(name='Jason Liu', email='jason@gmail.com') date=None location=None budget=None deadline=None +user=PartialUser(name='Jason Liu', email='jason@gmail.com') date=None location=None budget=None deadline=None +user=PartialUser(name='Jason Liu', email='jason@gmail.com') date=None location=None budget=None deadline=None +user=PartialUser(name='Jason Liu', email='jason@gmail.com') date=None location=None budget=None deadline=None +user=PartialUser(name='Jason Liu', email='jason@gmail.com') date=None location=None budget=None deadline=None +user=PartialUser(name='Jason Liu', email='jason@gmail.com') date=None location=None budget=None deadline=None +user=PartialUser(name='Jason Liu', email='jason@gmail.com') date=None location=None budget=None deadline=None +user=PartialUser(name='Jason Liu', email='jason@gmail.com') date='2024-01-01' location=None budget=None deadline=None +user=PartialUser(name='Jason Liu', email='jason@gmail.com') date='2024-01-01' location=None budget=None deadline=None +user=PartialUser(name='Jason Liu', email='jason@gmail.com') date='2024-01-01' location=None budget=None deadline=None +user=PartialUser(name='Jason Liu', email='jason@gmail.com') date='2024-01-01' location=None budget=None deadline=None +user=PartialUser(name='Jason Liu', email='jason@gmail.com') date='2024-01-01' location=None budget=None deadline=None +user=PartialUser(name='Jason Liu', email='jason@gmail.com') date='2024-01-01' location=None budget=None deadline=None +user=PartialUser(name='Jason Liu', email='jason@gmail.com') date='2024-01-01' location=None budget=None deadline=None +user=PartialUser(name='Jason Liu', email='jason@gmail.com') date='2024-01-01' location=None budget=None deadline=None +user=PartialUser(name='Jason Liu', email='jason@gmail.com') date='2024-01-01' location=None budget=None deadline=None +user=PartialUser(name='Jason Liu', email='jason@gmail.com') date='2024-01-01' location='1234 Main St' budget=None deadline=None +user=PartialUser(name='Jason Liu', email='jason@gmail.com') date='2024-01-01' location='1234 Main St' budget=None deadline=None +user=PartialUser(name='Jason Liu', email='jason@gmail.com') date='2024-01-01' location='1234 Main St' budget=None deadline=None +user=PartialUser(name='Jason Liu', email='jason@gmail.com') date='2024-01-01' location='1234 Main St' budget=None deadline=None +user=PartialUser(name='Jason Liu', email='jason@gmail.com') date='2024-01-01' location='1234 Main St' budget=None deadline=None +user=PartialUser(name='Jason Liu', email='jason@gmail.com') date='2024-01-01' location='1234 Main St' budget=100 deadline=None +user=PartialUser(name='Jason Liu', email='jason@gmail.com') date='2024-01-01' location='1234 Main St' budget=1000 deadline=None +user=PartialUser(name='Jason Liu', email='jason@gmail.com') date='2024-01-01' location='1234 Main St' budget=1000 deadline=None +user=PartialUser(name='Jason Liu', email='jason@gmail.com') date='2024-01-01' location='1234 Main St' budget=1000 deadline=None +user=PartialUser(name='Jason Liu', email='jason@gmail.com') date='2024-01-01' location='1234 Main St' budget=1000 deadline=None +user=PartialUser(name='Jason Liu', email='jason@gmail.com') date='2024-01-01' location='1234 Main St' budget=1000 deadline=None +user=PartialUser(name='Jason Liu', email='jason@gmail.com') date='2024-01-01' location='1234 Main St' budget=1000 deadline=None +user=PartialUser(name='Jason Liu', email='jason@gmail.com') date='2024-01-01' location='1234 Main St' budget=1000 deadline=None +user=PartialUser(name='Jason Liu', email='jason@gmail.com') date='2024-01-01' location='1234 Main St' budget=1000 deadline=None +user=PartialUser(name='Jason Liu', email='jason@gmail.com') date='2024-01-01' location='1234 Main St' budget=1000 deadline=None +user=PartialUser(name='Jason Liu', email='jason@gmail.com') date='2024-01-01' location='1234 Main St' budget=1000 deadline=None +user=PartialUser(name='Jason Liu', email='jason@gmail.com') date='2024-01-01' location='1234 Main St' budget=1000 deadline=None +user=PartialUser(name='Jason Liu', email='jason@gmail.com') date='2024-01-01' location='1234 Main St' budget=1000 deadline=None +user=PartialUser(name='Jason Liu', email='jason@gmail.com') date='2024-01-01' location='1234 Main St' budget=1000 deadline='2024-01-31' +user=PartialUser(name='Jason Liu', email='jason@gmail.com') date='2024-01-01' location='1234 Main St' budget=1000 deadline='2024-01-31' +""" diff --git a/instructor/__init__.py b/instructor/__init__.py index 3c4e79d94..b7bac032c 100644 --- a/instructor/__init__.py +++ b/instructor/__init__.py @@ -56,3 +56,8 @@ from .client_groq import from_groq __all__ += ["from_groq"] + +if importlib.util.find_spec("cohere") is not None: + from .client_cohere import from_cohere + + __all__ += ["from_cohere"] diff --git a/instructor/client_cohere.py b/instructor/client_cohere.py new file mode 100644 index 000000000..8064694bb --- /dev/null +++ b/instructor/client_cohere.py @@ -0,0 +1,85 @@ +import cohere +import instructor +from functools import wraps +from typing import ( + ParamSpec, + Type, + TypeVar, + overload, +) +from pydantic import BaseModel +from instructor.process_response import handle_response_model +from instructor.retry import retry_async + + +T_Model = TypeVar("T_Model", bound=BaseModel) +T_ParamSpec = ParamSpec("T_ParamSpec") + + +@overload +def from_cohere( + client: cohere.Client, + mode: instructor.Mode = instructor.Mode.COHERE_TOOLS, + **kwargs, +) -> instructor.Instructor: ... + + +@overload +def from_cohere( + client: cohere.AsyncClient, + mode: instructor.Mode = instructor.Mode.COHERE_TOOLS, + **kwargs, +) -> instructor.AsyncInstructor: ... + + +def from_cohere( + client: cohere.Client | cohere.AsyncClient, + mode: instructor.Mode = instructor.Mode.COHERE_TOOLS, + **kwargs, +): + assert mode in { + instructor.Mode.COHERE_TOOLS, + }, "Mode be one of {instructor.Mode.COHERE_TOOLS}" + + assert isinstance( + client, (cohere.Client, cohere.AsyncClient) + ), "Client must be an instance of cohere.Cohere or cohere.AsyncCohere" + + if isinstance(client, cohere.Client): + return instructor.Instructor( + client=client, + create=instructor.patch(create=client.chat, mode=mode), + provider=instructor.Provider.COHERE, + mode=mode, + **kwargs, + ) + + @wraps(client.chat) + async def new_create_async( + response_model: Type[T_Model] = None, + validation_context: dict = None, + max_retries: int = 1, + *args: T_ParamSpec.args, + **kwargs: T_ParamSpec.kwargs, + ) -> T_Model: + response_model, new_kwargs = handle_response_model( + response_model=response_model, mode=mode, **kwargs + ) + response = await retry_async( + func=client.chat, + response_model=response_model, + validation_context=validation_context, + max_retries=max_retries, + args=args, + kwargs=new_kwargs, + mode=mode, + ) # type: ignore + return response + + return instructor.AsyncInstructor( + client=client, + create=new_create_async, + provider=instructor.Provider.COHERE, + mode=mode, + **kwargs, + ) diff --git a/instructor/function_calls.py b/instructor/function_calls.py index 99db37daf..227bd4be0 100644 --- a/instructor/function_calls.py +++ b/instructor/function_calls.py @@ -95,6 +95,9 @@ def from_response( if mode == Mode.ANTHROPIC_JSON: return cls.parse_anthropic_json(completion, validation_context, strict) + if mode == Mode.COHERE_TOOLS: + return cls.parse_cohere_tools(completion, validation_context, strict) + if completion.choices[0].finish_reason == "length": raise IncompleteOutputException() @@ -137,6 +140,19 @@ def parse_anthropic_json( extra_text, context=validation_context, strict=strict ) + @classmethod + def parse_cohere_tools( + cls: Type[BaseModel], + completion, + validation_context: Optional[Dict[str, Any]] = None, + strict: Optional[bool] = None, + ) -> BaseModel: + text = completion.text + extra_text = extract_json_from_codeblock(text) + return cls.model_validate_json( + extra_text, context=validation_context, strict=strict + ) + @classmethod def parse_functions( cls: Type[BaseModel], diff --git a/instructor/mode.py b/instructor/mode.py index c2d31baf6..f40d02dc2 100644 --- a/instructor/mode.py +++ b/instructor/mode.py @@ -14,6 +14,7 @@ class Mode(enum.Enum): JSON_SCHEMA = "json_schema_mode" ANTHROPIC_TOOLS = "anthropic_tools" ANTHROPIC_JSON = "anthropic_json" + COHERE_TOOLS = "cohere_tools" def __new__(cls, value: str) -> "Mode": member = object.__new__(cls) diff --git a/instructor/process_response.py b/instructor/process_response.py index 3ea312127..727d5c069 100644 --- a/instructor/process_response.py +++ b/instructor/process_response.py @@ -308,7 +308,7 @@ def handle_response_model( new_kwargs["system"] += f""" You must only response in JSON format that adheres to the following schema: - + {json.dumps(response_model.model_json_schema(), indent=2)} @@ -325,6 +325,29 @@ def handle_response_model( # consecutive user messages into a single message new_kwargs["messages"] = merge_consecutive_messages(new_kwargs["messages"]) + elif mode == Mode.COHERE_TOOLS: + instruction = f"""\ +Extract a valid {response_model.__name__} object based on the chat history and the json schema below. +{response_model.model_json_schema()} +The JSON schema was obtained by running: +```python +schema = {response_model.__name__}.model_json_schema() +``` + +The output must be a valid JSON object that `{response_model.__name__}.model_validate_json()` can successfully parse. +""" + messages = new_kwargs.pop("messages", []) + chat_history = [] + for message in messages: + # format in Cohere's ChatMessage format + chat_history.append( + { + "role": message["role"], + "message": message["content"], + } + ) + new_kwargs["message"] = instruction + new_kwargs["chat_history"] = chat_history else: raise ValueError(f"Invalid patch mode: {mode}") diff --git a/instructor/retry.py b/instructor/retry.py index 3431856a6..e997685fd 100644 --- a/instructor/retry.py +++ b/instructor/retry.py @@ -69,6 +69,12 @@ def reask_messages(response: ChatCompletion, mode: Mode, exception: Exception): "content": f"""Validation Errors found:\n{exception}\nRecall the function correctly, fix the errors found in the following attempt:\n{response.content[0].text}""", } return + if mode == Mode.COHERE_TOOLS: + yield { + "role": "user", + "content": f"Validation Error found:\n{exception}\nRecall the function correctly, fix the errors", + } + return yield dump_message(response.choices[0].message) # TODO: Give users more control on configuration diff --git a/instructor/utils.py b/instructor/utils.py index 01c0b4e27..4f3c84b95 100644 --- a/instructor/utils.py +++ b/instructor/utils.py @@ -21,6 +21,7 @@ class Provider(Enum): ANYSCALE = "anyscale" TOGETHER = "together" GROQ = "groq" + COHERE = "cohere" UNKNOWN = "unknown" @@ -35,6 +36,8 @@ def get_provider(base_url: str) -> Provider: return Provider.GROQ elif "openai" in str(base_url): return Provider.OPENAI + elif "cohere" in str(base_url): + return Provider.COHERE return Provider.UNKNOWN diff --git a/mkdocs.yml b/mkdocs.yml index 4c492f186..2ce078d67 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -143,6 +143,7 @@ nav: - Ollama: 'examples/ollama.md' - SQLModel Integration: 'examples/sqlmodel.md' - Including Examples in Prompt: 'examples/examples.md' + - Document Segmentation: 'examples/document_segmentation.md' - Blog: - "blog/index.md" - Concepts: @@ -177,6 +178,7 @@ nav: - Using Together Compute: 'hub/together.md' - Using Anyscale: 'hub/anyscale.md' - Using Groq: 'hub/groq.md' + - Using Cohere: 'hub/cohere.md' - Batch Async Classification w/ Langsmith: 'hub/batch_classification_langsmith.md' - Action Items: 'hub/action_items.md' - Partial Streaming: 'hub/partial_streaming.md' diff --git a/poetry.lock b/poetry.lock index 08c8ec053..9e8a16490 100644 --- a/poetry.lock +++ b/poetry.lock @@ -554,6 +554,26 @@ files = [ [package.dependencies] colorama = {version = "*", markers = "platform_system == \"Windows\""} +[[package]] +name = "cohere" +version = "5.2.4" +description = "" +optional = false +python-versions = "<4.0,>=3.8" +files = [ + {file = "cohere-5.2.4-py3-none-any.whl", hash = "sha256:50e8cbd009a6d6f6ce7127a0b62c50d3dfcfdb853f681f0ac315cfa70599fee4"}, + {file = "cohere-5.2.4.tar.gz", hash = "sha256:2bf6e905773116ad3fff348e054e4ce1a1830092a63cb48fa8180beda4cbb96a"}, +] + +[package.dependencies] +fastavro = ">=1.9.4,<2.0.0" +httpx = ">=0.21.2" +pydantic = ">=1.9.2" +requests = ">=2.0.0,<3.0.0" +tokenizers = ">=0.15.2,<0.16.0" +types-requests = ">=2.0.0,<3.0.0" +typing_extensions = ">=4.0.0" + [[package]] name = "colorama" version = "0.4.6" @@ -808,6 +828,52 @@ typing-extensions = ">=4.8.0" [package.extras] all = ["email-validator (>=2.0.0)", "httpx (>=0.23.0)", "itsdangerous (>=1.1.0)", "jinja2 (>=2.11.2)", "orjson (>=3.2.1)", "pydantic-extra-types (>=2.0.0)", "pydantic-settings (>=2.0.0)", "python-multipart (>=0.0.7)", "pyyaml (>=5.3.1)", "ujson (>=4.0.1,!=4.0.2,!=4.1.0,!=4.2.0,!=4.3.0,!=5.0.0,!=5.1.0)", "uvicorn[standard] (>=0.12.0)"] +[[package]] +name = "fastavro" +version = "1.9.4" +description = "Fast read/write of AVRO files" +optional = false +python-versions = ">=3.8" +files = [ + {file = "fastavro-1.9.4-cp310-cp310-macosx_11_0_x86_64.whl", hash = "sha256:60cb38f07462a7fb4e4440ed0de67d3d400ae6b3d780f81327bebde9aa55faef"}, + {file = "fastavro-1.9.4-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:063d01d197fc929c20adc09ca9f0ca86d33ac25ee0963ce0b438244eee8315ae"}, + {file = "fastavro-1.9.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:87a9053fcfbc895f2a16a4303af22077e3a8fdcf1cd5d6ed47ff2ef22cbba2f0"}, + {file = "fastavro-1.9.4-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:02bf1276b7326397314adf41b34a4890f6ffa59cf7e0eb20b9e4ab0a143a1598"}, + {file = "fastavro-1.9.4-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:56bed9eca435389a8861e6e2d631ec7f8f5dda5b23f93517ac710665bd34ca29"}, + {file = "fastavro-1.9.4-cp310-cp310-win_amd64.whl", hash = "sha256:0cd2099c8c672b853e0b20c13e9b62a69d3fbf67ee7c59c7271ba5df1680310d"}, + {file = "fastavro-1.9.4-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:af8c6d8c43a02b5569c093fc5467469541ac408c79c36a5b0900d3dd0b3ba838"}, + {file = "fastavro-1.9.4-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e4a138710bd61580324d23bc5e3df01f0b82aee0a76404d5dddae73d9e4c723f"}, + {file = "fastavro-1.9.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:903d97418120ca6b6a7f38a731166c1ccc2c4344ee5e0470d09eb1dc3687540a"}, + {file = "fastavro-1.9.4-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:c443eeb99899d062dbf78c525e4614dd77e041a7688fa2710c224f4033f193ae"}, + {file = "fastavro-1.9.4-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:ac26ab0774d1b2b7af6d8f4300ad20bbc4b5469e658a02931ad13ce23635152f"}, + {file = "fastavro-1.9.4-cp311-cp311-win_amd64.whl", hash = "sha256:cf7247874c22be856ba7d1f46a0f6e0379a6025f1a48a7da640444cbac6f570b"}, + {file = "fastavro-1.9.4-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:68912f2020e1b3d70557260b27dd85fb49a4fc6bfab18d384926127452c1da4c"}, + {file = "fastavro-1.9.4-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6925ce137cdd78e109abdb0bc33aad55de6c9f2d2d3036b65453128f2f5f5b92"}, + {file = "fastavro-1.9.4-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8b928cd294e36e35516d0deb9e104b45be922ba06940794260a4e5dbed6c192a"}, + {file = "fastavro-1.9.4-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:90c9838bc4c991ffff5dd9d88a0cc0030f938b3fdf038cdf6babde144b920246"}, + {file = "fastavro-1.9.4-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:eca6e54da571b06a3c5a72dbb7212073f56c92a6fbfbf847b91c347510f8a426"}, + {file = "fastavro-1.9.4-cp312-cp312-win_amd64.whl", hash = "sha256:a4b02839ac261100cefca2e2ad04cdfedc556cb66b5ec735e0db428e74b399de"}, + {file = "fastavro-1.9.4-cp38-cp38-macosx_11_0_x86_64.whl", hash = "sha256:4451ee9a305a73313a1558d471299f3130e4ecc10a88bf5742aa03fb37e042e6"}, + {file = "fastavro-1.9.4-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a8524fccfb379565568c045d29b2ebf71e1f2c0dd484aeda9fe784ef5febe1a8"}, + {file = "fastavro-1.9.4-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:33d0a00a6e09baa20f6f038d7a2ddcb7eef0e7a9980e947a018300cb047091b8"}, + {file = "fastavro-1.9.4-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:23d7e5b29c9bf6f26e8be754b2c8b919838e506f78ef724de7d22881696712fc"}, + {file = "fastavro-1.9.4-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:2e6ab3ee53944326460edf1125b2ad5be2fadd80f7211b13c45fa0c503b4cf8d"}, + {file = "fastavro-1.9.4-cp38-cp38-win_amd64.whl", hash = "sha256:64d335ec2004204c501f8697c385d0a8f6b521ac82d5b30696f789ff5bc85f3c"}, + {file = "fastavro-1.9.4-cp39-cp39-macosx_11_0_x86_64.whl", hash = "sha256:7e05f44c493e89e73833bd3ff3790538726906d2856f59adc8103539f4a1b232"}, + {file = "fastavro-1.9.4-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:253c63993250bff4ee7b11fb46cf3a4622180a783bedc82a24c6fdcd1b10ca2a"}, + {file = "fastavro-1.9.4-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:24d6942eb1db14640c2581e0ecd1bbe0afc8a83731fcd3064ae7f429d7880cb7"}, + {file = "fastavro-1.9.4-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:d47bb66be6091cd48cfe026adcad11c8b11d7d815a2949a1e4ccf03df981ca65"}, + {file = "fastavro-1.9.4-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:c293897f12f910e58a1024f9c77f565aa8e23b36aafda6ad8e7041accc57a57f"}, + {file = "fastavro-1.9.4-cp39-cp39-win_amd64.whl", hash = "sha256:f05d2afcb10a92e2a9e580a3891f090589b3e567fdc5641f8a46a0b084f120c3"}, + {file = "fastavro-1.9.4.tar.gz", hash = "sha256:56b8363e360a1256c94562393dc7f8611f3baf2b3159f64fb2b9c6b87b14e876"}, +] + +[package.extras] +codecs = ["cramjam", "lz4", "zstandard"] +lz4 = ["lz4"] +snappy = ["cramjam"] +zstandard = ["zstandard"] + [[package]] name = "fastjsonschema" version = "2.19.1" @@ -1418,13 +1484,13 @@ test-ui = ["calysto-bash"] [[package]] name = "litellm" -version = "1.34.27" +version = "1.34.34" description = "Library to easily interface with LLM API providers" optional = false python-versions = "!=2.7.*,!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,!=3.7.*,>=3.8" files = [ - {file = "litellm-1.34.27-py3-none-any.whl", hash = "sha256:77c7084f4089cfc661e70730f2302ec3ddd015ae56f77a055aea0b681343e81d"}, - {file = "litellm-1.34.27.tar.gz", hash = "sha256:03bd0abaca4317e003daa011766815b97420efb6584413a8735de6bdc0835041"}, + {file = "litellm-1.34.34-py3-none-any.whl", hash = "sha256:c9eefd4b5adec3c2e6d0ab765a4fcebd475a895c7e417f47f8e677410b607f51"}, + {file = "litellm-1.34.34.tar.gz", hash = "sha256:d11c9d5296d052a9e5e1187ac7b33683f3a581740abc4de6a9c327d3f3c7187c"}, ] [package.dependencies] @@ -3659,6 +3725,20 @@ rich = ">=10.11.0" shellingham = ">=1.3.0" typing-extensions = ">=3.7.4.3" +[[package]] +name = "types-requests" +version = "2.31.0.20240406" +description = "Typing stubs for requests" +optional = false +python-versions = ">=3.8" +files = [ + {file = "types-requests-2.31.0.20240406.tar.gz", hash = "sha256:4428df33c5503945c74b3f42e82b181e86ec7b724620419a2966e2de604ce1a1"}, + {file = "types_requests-2.31.0.20240406-py3-none-any.whl", hash = "sha256:6216cdac377c6b9a040ac1c0404f7284bd13199c0e1bb235f4324627e8898cf5"}, +] + +[package.dependencies] +urllib3 = ">=2" + [[package]] name = "typing-extensions" version = "4.11.0" @@ -3892,8 +3972,9 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p [extras] anthropic = ["anthropic", "xmltodict"] +cohere = ["cohere"] groq = ["groq"] -test-docs = ["anthropic", "diskcache", "fastapi", "groq", "litellm", "pandas", "pydantic_extra_types", "redis", "tabulate"] +test-docs = ["anthropic", "cohere", "diskcache", "fastapi", "groq", "litellm", "pandas", "pydantic_extra_types", "redis", "tabulate"] [metadata] lock-version = "2.0" diff --git a/pyproject.toml b/pyproject.toml index f2109b845..57c2f6747 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,11 +30,13 @@ litellm = { version = "^1.0.0", optional = true } anthropic = { version = "^0.23.1", optional = true } xmltodict = { version = "^0.13.0", optional = true } groq = { version = "^0.4.2", optional = true } +cohere = { version = "^5.1.8", optional = true } [tool.poetry.extras] anthropic = ["anthropic", "xmltodict"] groq = ["groq"] -test-docs = ["fastapi", "redis", "diskcache", "pandas", "tabulate", "pydantic_extra_types", "litellm", "anthropic", "groq"] +cohere = ["cohere"] +test-docs = ["fastapi", "redis", "diskcache", "pandas", "tabulate", "pydantic_extra_types", "litellm", "anthropic", "groq", "cohere"] [tool.poetry.scripts] instructor = "instructor.cli.cli:app" @@ -61,7 +63,7 @@ anthropic = "^0.23.1" [tool.poetry.group.test-docs.dependencies] fastapi = "^0.109.2" -redis = "^5.0.1" +redis = "^5.0.1" diskcache = "^5.6.3" pandas = "^2.2.0" tabulate = "^0.9.0" @@ -71,7 +73,7 @@ anthropic = "^0.23.1" xmltodict = "^0.13.0" groq = "^0.4.2" phonenumbers = "^8.13.33" - +cohere = "^5.1.8" [build-system] requires = ["poetry-core"] diff --git a/requirements.txt b/requirements.txt index cd007fd47..f2ffccc85 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,4 +6,5 @@ aiohttp ruff==0.1.7 pre-commit==3.5.0 mypy==1.7.1 -typer \ No newline at end of file +typer +cohere \ No newline at end of file diff --git a/tests/test_new_client.py b/tests/test_new_client.py index 2ab98b795..a139e106c 100644 --- a/tests/test_new_client.py +++ b/tests/test_new_client.py @@ -1,9 +1,11 @@ +import cohere import os import openai import instructor import anthropic -from pydantic import BaseModel import pytest +from pydantic import BaseModel, Field +from typing import List class User(BaseModel): @@ -244,3 +246,95 @@ async def test_async_client_anthropic_bedrock_response(): ) assert user.name == "Jason" assert user.age == 10 + + +@pytest.mark.skip(reason="Skipping if Cohere API is not available") +def test_client_cohere_response(): + client = cohere.Client() + instructor_client = instructor.from_cohere( + client, + max_tokens=1000, + model="command-r-plus", + ) + + user = instructor_client.messages.create( + response_model=User, + messages=[{"role": "user", "content": "Jason is 10"}], + temperature=0, + ) + assert user.name == "Jason" + assert user.age == 10 + + + +@pytest.mark.skip(reason="Skipping if Cohere API is not available") +def test_client_cohere_response_with_nested_classes(): + client = cohere.Client() + instructor_client = instructor.from_cohere( + client, + max_tokens=1000, + model="command-r-plus", + ) + + class Person(BaseModel): + name: str = Field(description="name of the person") + country_of_origin: str = Field(description="country of origin of the person") + + class Group(BaseModel): + group_name: str = Field(description="name of the group") + members: List[Person] = Field(description="list of members in the group") + + task = """\ + Given the following text, create a Group object for 'The Beatles' band + + Text: + The Beatles were an English rock band formed in Liverpool in 1960. With a line-up comprising John Lennon, Paul McCartney, George Harrison and Ringo Starr, they are regarded as the most influential band of all time. The group were integral to the development of 1960s counterculture and popular music's recognition as an art form. + """ + group = instructor_client.messages.create( + response_model=Group, + messages=[{"role": "user", "content": task}], + temperature=0, + ) + assert group.group_name == "The Beatles" + assert len(group.members) == 4 + assert group.members[0].name == "John Lennon" + assert group.members[1].name == "Paul McCartney" + assert group.members[2].name == "George Harrison" + assert group.members[3].name == "Ringo Starr" + + +@pytest.mark.skip(reason="Skipping if Cohere API is not available") +@pytest.mark.asyncio +async def test_client_cohere_async(): + client = cohere.AsyncClient() + instructor_client = instructor.from_cohere( + client, + max_tokens=1000, + model="command-r-plus", + ) + + class Person(BaseModel): + name: str = Field(description="name of the person") + country_of_origin: str = Field(description="country of origin of the person") + + class Group(BaseModel): + group_name: str = Field(description="name of the group") + members: List[Person] = Field(description="list of members in the group") + + task = """\ + Given the following text, create a Group object for 'The Beatles' band + + Text: + The Beatles were an English rock band formed in Liverpool in 1960. With a line-up comprising John Lennon, Paul McCartney, George Harrison and Ringo Starr, they are regarded as the most influential band of all time. The group were integral to the development of 1960s counterculture and popular music's recognition as an art form. + """ + group = await instructor_client.messages.create( + response_model=Group, + messages=[{"role": "user", "content": task}], + temperature=0, + ) + assert group.group_name == "The Beatles" + assert len(group.members) == 4 + assert group.members[0].name == "John Lennon" + assert group.members[1].name == "Paul McCartney" + assert group.members[2].name == "George Harrison" + assert group.members[3].name == "Ringo Starr"