# Add support for cohere models #585

Merged
## Commits (15)

- 6e1bbfb Add support for cohere (lidiyam)
- ba6c246 Add tests (lidiyam)
- f1281c1 ask for json directly (lidiyam)
- 8fb352a Skip cohere tests (lidiyam)
- 932c209 clean up (lidiyam)
- c373abb Add support for async client (lidiyam)
- b8e8e69 Resolve merge conflict (lidiyam)
- 66cbed1 bump (jxnl)
- d962128 Add cohere docs (lidiyam)
- 4aa845e merge conflicts (lidiyam)
- 128defb add cohere to examples (lidiyam)
- 04c9408 Document Segmentation with Cohere example (lidiyam)
- a98ac66 resolve conflicts (lidiyam)
- e6b46d7 Add links to mkdocs (lidiyam)
- 420e7f4 update mkdocs (lidiyam)
# Document Segmentation

In this guide, we demonstrate how to do document segmentation using structured output from an LLM. We'll be using [command-r-plus](https://docs.cohere.com/docs/command-r-plus), one of Cohere's latest LLMs with a 128k context length, and testing the approach on an article explaining the Transformer architecture. The same approach to document segmentation can be applied to any other domain where we need to break a complex, long document into smaller chunks.

!!! tips "Motivation"
    Sometimes we need a way to split a document into meaningful parts that each center around a single key concept or idea. Simple length-based or rule-based text splitters are not reliable enough. Consider documents that contain code snippets or math equations: we don't want to split those on `'\n\n'` or have to write extensive rules for different types of documents. It turns out that LLMs with sufficiently long context lengths are well suited for this task.

## Defining the Data Structures

First, we define a **`Section`** class for each of the document's segments. The **`StructuredDocument`** class will then encapsulate a list of these sections.

Note that to avoid having the LLM regenerate the content of each section, we can simply enumerate each line of the input document and then ask the LLM to segment it by providing start and end line numbers for each section.
||
```python | ||
from pydantic import BaseModel, Field | ||
from typing import List, Dict, Any | ||
|
||
class Section(BaseModel): | ||
title: str = Field(description="main topic of this section of the document") | ||
start_index: int = Field(description="line number where the section begins") | ||
end_index: int = Field(description="line number where the section ends") | ||
|
||
|
||
class StructuredDocument(BaseModel): | ||
"""obtains meaningful sections, each centered around a single concept/topic""" | ||
sections: List[Section] = Field(description="a list of sections of the document") | ||
``` | ||
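To make the payoff of this schema concrete, a valid model response is just JSON with titles and line ranges, never the section text itself, which keeps the output short. The snippet below (with a hypothetical one-section response, and the classes repeated so it runs standalone) shows Pydantic validating such a response directly:

```python
from pydantic import BaseModel, Field
from typing import List


class Section(BaseModel):
    title: str = Field(description="main topic of this section of the document")
    start_index: int = Field(description="line number where the section begins")
    end_index: int = Field(description="line number where the section ends")


class StructuredDocument(BaseModel):
    """obtains meaningful sections, each centered around a single concept/topic"""

    sections: List[Section] = Field(description="a list of sections of the document")


# A hypothetical raw model response: only a title and a line range
raw = '{"sections": [{"title": "Introduction", "start_index": 0, "end_index": 12}]}'
doc = StructuredDocument.model_validate_json(raw)
print(doc.sections[0].title)
# Introduction
```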
## Document Preprocessing

Preprocess the input `document` by prepending each line with its number.
||
```python | ||
def doc_with_lines(document): | ||
document_lines = document.split("\n") | ||
document_with_line_numbers = "" | ||
line2text = {} | ||
for i, line in enumerate(document_lines): | ||
document_with_line_numbers += f"[{i}] {line}\n" | ||
line2text[i] = line | ||
return document_with_line_numbers, line2text | ||
``` | ||
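A quick sanity check on a tiny, made-up document shows the numbered form the LLM will see and the lookup dict we keep for later (the function is repeated here so the snippet runs standalone):

```python
def doc_with_lines(document):
    # prepend each line with its index and remember the index -> text mapping
    document_lines = document.split("\n")
    document_with_line_numbers = ""
    line2text = {}
    for i, line in enumerate(document_lines):
        document_with_line_numbers += f"[{i}] {line}\n"
        line2text[i] = line
    return document_with_line_numbers, line2text


sample = "Title\n\nFirst paragraph."  # hypothetical three-line document
numbered, line2text = doc_with_lines(sample)
print(numbered)
# [0] Title
# [1]
# [2] First paragraph.
```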
## Segmentation

Next, use a Cohere client to extract a `StructuredDocument` from the preprocessed document.
```python
import instructor
import cohere

# Apply the patch to the cohere client
# enables response_model keyword
client = instructor.from_cohere(cohere.Client())


system_prompt = """\
You are a world class educator working on organizing your lecture notes.
Read the document below and extract a StructuredDocument object from it where each section of the document is centered around a single concept/topic that can be taught in one lesson.
Each line of the document is marked with its line number in square brackets (e.g. [1], [2], [3], etc). Use the line numbers to indicate section start and end.
"""


def get_structured_document(document_with_line_numbers) -> StructuredDocument:
    return client.chat.completions.create(
        model="command-r-plus",
        response_model=StructuredDocument,
        messages=[
            {
                "role": "system",
                "content": system_prompt,
            },
            {
                "role": "user",
                "content": document_with_line_numbers,
            },
        ],
    )  # type: ignore
```
Next, we recover each section's text using the start/end indices and the `line2text` dict from the preprocessing step.
```python
def get_sections_text(structured_doc, line2text):
    segments = []
    for s in structured_doc.sections:
        contents = []
        for line_id in range(s.start_index, s.end_index):
            contents.append(line2text.get(line_id, ""))
        segments.append(
            {
                "title": s.title,
                "content": "\n".join(contents),
                "start": s.start_index,
                "end": s.end_index,
            }
        )
    return segments
```
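To see the reassembly step in isolation, here is a small sketch that feeds `get_sections_text` a hand-built segmentation (using `SimpleNamespace` as a stand-in for the Pydantic models; the six-line document and section boundaries are made up). Note that `range` treats `end_index` as exclusive:

```python
from types import SimpleNamespace


def get_sections_text(structured_doc, line2text):
    # repeated here so the snippet runs standalone
    segments = []
    for s in structured_doc.sections:
        contents = []
        for line_id in range(s.start_index, s.end_index):
            contents.append(line2text.get(line_id, ""))
        segments.append(
            {
                "title": s.title,
                "content": "\n".join(contents),
                "start": s.start_index,
                "end": s.end_index,
            }
        )
    return segments


# Hypothetical six-line document and a two-section segmentation of it
line2text = {0: "Intro", 1: "Some text", 2: "", 3: "Method", 4: "Details", 5: "End"}
structured_doc = SimpleNamespace(
    sections=[
        SimpleNamespace(title="Introduction", start_index=0, end_index=3),
        SimpleNamespace(title="Method", start_index=3, end_index=6),
    ]
)
segments = get_sections_text(structured_doc, line2text)
print(segments[1]["content"])
# Method
# Details
# End
```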
## Example

Here's an example of using these classes and functions to segment a tutorial on Transformers by [Sebastian Raschka](https://sebastianraschka.com/blog/2023/self-attention-from-scratch.html). We can use the `trafilatura` package to scrape the web page content of the article.
||
```python | ||
from trafilatura import fetch_url, extract | ||
|
||
|
||
url='https://sebastianraschka.com/blog/2023/self-attention-from-scratch.html' | ||
downloaded = fetch_url(url) | ||
document = extract(downloaded) | ||
|
||
|
||
document_with_line_numbers, line2text = doc_with_lines(document) | ||
structured_doc = get_structured_document(document_with_line_numbers) | ||
segments = get_sections_text(structured_doc, line2text) | ||
``` | ||
```python
print(segments[5]["title"])
"""
Introduction to Multi-Head Attention
"""

print(segments[5]["content"])
"""
Multi-Head Attention
In the very first figure, at the top of this article, we saw that transformers use a module called multi-head attention. How does that relate to the self-attention mechanism (scaled-dot product attention) we walked through above?
In the scaled dot-product attention, the input sequence was transformed using three matrices representing the query, key, and value. These three matrices can be considered as a single attention head in the context of multi-head attention. The figure below summarizes this single attention head we covered previously:
As its name implies, multi-head attention involves multiple such heads, each consisting of query, key, and value matrices. This concept is similar to the use of multiple kernels in convolutional neural networks.
To illustrate this in code, suppose we have 3 attention heads, so we now extend the \(d' \times d\) dimensional weight matrices so \(3 \times d' \times d\):
In:
h = 3
multihead_W_query = torch.nn.Parameter(torch.rand(h, d_q, d))
multihead_W_key = torch.nn.Parameter(torch.rand(h, d_k, d))
multihead_W_value = torch.nn.Parameter(torch.rand(h, d_v, d))
Consequently, each query element is now \(3 \times d_q\) dimensional, where \(d_q=24\) (here, let's keep the focus on the 3rd element corresponding to index position 2):
In:
multihead_query_2 = multihead_W_query.matmul(x_2)
print(multihead_query_2.shape)
Out:
torch.Size([3, 24])
"""
```
# Structured Outputs with Cohere

If you want to try this example using `instructor hub`, you can pull it by running

```bash
instructor hub pull --slug cohere --py > cohere_example.py
```

You can now use any of Cohere's [command models](https://docs.cohere.com/docs/models) with the `instructor` library to get structured outputs.

You'll need a Cohere API key, which can be obtained by signing up [here](https://dashboard.cohere.com/) and gives you [free](https://cohere.com/pricing), rate-limited usage for learning and prototyping.
## Setup

```bash
pip install cohere
```

Export your key:

```bash
export CO_API_KEY=<YOUR_COHERE_API_KEY>
```
## Example

```python
from pydantic import BaseModel, Field
from typing import List
import cohere
import instructor


# Patching the Cohere client with instructor for enhanced capabilities
client = instructor.from_cohere(
    cohere.Client(),
    max_tokens=1000,
    model="command-r-plus",
)


class Person(BaseModel):
    name: str = Field(description="name of the person")
    country_of_origin: str = Field(description="country of origin of the person")


class Group(BaseModel):
    group_name: str = Field(description="name of the group")
    members: List[Person] = Field(description="list of members in the group")


task = """\
Given the following text, create a Group object for 'The Beatles' band

Text:
The Beatles were an English rock band formed in Liverpool in 1960. With a line-up comprising John Lennon, Paul McCartney, George Harrison and Ringo Starr, they are regarded as the most influential band of all time. The group were integral to the development of 1960s counterculture and popular music's recognition as an art form.
"""

group = client.messages.create(
    response_model=Group,
    messages=[{"role": "user", "content": task}],
    temperature=0,
)

print(group.model_dump_json(indent=2))
"""
{
  "group_name": "The Beatles",
  "members": [
    {
      "name": "John Lennon",
      "country_of_origin": "England"
    },
    {
      "name": "Paul McCartney",
      "country_of_origin": "England"
    },
    {
      "name": "George Harrison",
      "country_of_origin": "England"
    },
    {
      "name": "Ringo Starr",
      "country_of_origin": "England"
    }
  ]
}
"""
```
Review comment: you'll have to add this in the mkdocs.yml file, or even reference it in examples/index.md.