Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

tiktoken override #6697

Merged
merged 3 commits into from
Jun 26, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 22 additions & 9 deletions langchain/chat_models/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,16 @@ def lc_serializable(self) -> bool:
"""Number of chat completions to generate for each prompt."""
max_tokens: Optional[int] = None
"""Maximum number of tokens to generate."""
tiktoken_model_name: Optional[str] = None
"""The model name to pass to tiktoken when using this class.
Tiktoken is used to count the number of tokens in documents to constrain
them to be under a certain limit. By default, when set to None, this will
be the same as the embedding model name. However, there are some cases
where you may want to use this Embedding class with a model name not
supported by tiktoken. This can include when using Azure embeddings or
when using one of the many model providers that expose an OpenAI-like
API but with different models. In those cases, in order to avoid erroring
when tiktoken is called, you can specify a model name to use here."""

class Config:
"""Configuration for this pydantic object."""
Expand Down Expand Up @@ -448,15 +458,18 @@ def _llm_type(self) -> str:

def _get_encoding_model(self) -> Tuple[str, tiktoken.Encoding]:
tiktoken_ = _import_tiktoken()
model = self.model_name
if model == "gpt-3.5-turbo":
# gpt-3.5-turbo may change over time.
# Returning num tokens assuming gpt-3.5-turbo-0301.
model = "gpt-3.5-turbo-0301"
elif model == "gpt-4":
# gpt-4 may change over time.
# Returning num tokens assuming gpt-4-0314.
model = "gpt-4-0314"
if self.tiktoken_model_name is not None:
model = self.tiktoken_model_name
else:
model = self.model_name
if model == "gpt-3.5-turbo":
# gpt-3.5-turbo may change over time.
# Returning num tokens assuming gpt-3.5-turbo-0301.
model = "gpt-3.5-turbo-0301"
elif model == "gpt-4":
# gpt-4 may change over time.
# Returning num tokens assuming gpt-4-0314.
model = "gpt-4-0314"
# Returns the number of tokens used by a list of messages.
try:
encoding = tiktoken_.encoding_for_model(model)
Expand Down
26 changes: 24 additions & 2 deletions langchain/embeddings/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,16 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
request_timeout: Optional[Union[float, Tuple[float, float]]] = None
"""Timeout in seconds for the OpenAPI request."""
headers: Any = None
tiktoken_model_name: Optional[str] = None
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should add docstring for the reference

"""The model name to pass to tiktoken when using this class.
Tiktoken is used to count the number of tokens in documents to constrain
them to be under a certain limit. By default, when set to None, this will
be the same as the embedding model name. However, there are some cases
where you may want to use this Embedding class with a model name not
supported by tiktoken. This can include when using Azure embeddings or
when using one of the many model providers that expose an OpenAI-like
API but with different models. In those cases, in order to avoid erroring
when tiktoken is called, you can specify a model name to use here."""

class Config:
"""Configuration for this pydantic object."""
Expand Down Expand Up @@ -265,7 +275,13 @@ def _get_len_safe_embeddings(

tokens = []
indices = []
encoding = tiktoken.model.encoding_for_model(self.model)
model_name = self.tiktoken_model_name or self.model
try:
encoding = tiktoken.encoding_for_model(model_name)
except KeyError:
logger.warning("Warning: model not found. Using cl100k_base encoding.")
model = "cl100k_base"
encoding = tiktoken.get_encoding(model)
for i, text in enumerate(texts):
if self.model.endswith("001"):
# See: https://github.com/openai/openai-python/issues/418#issuecomment-1525939500
Expand Down Expand Up @@ -329,7 +345,13 @@ async def _aget_len_safe_embeddings(

tokens = []
indices = []
encoding = tiktoken.model.encoding_for_model(self.model)
model_name = self.tiktoken_model_name or self.model
try:
encoding = tiktoken.encoding_for_model(model_name)
except KeyError:
logger.warning("Warning: model not found. Using cl100k_base encoding.")
model = "cl100k_base"
encoding = tiktoken.get_encoding(model)
for i, text in enumerate(texts):
if self.model.endswith("001"):
# See: https://github.com/openai/openai-python/issues/418#issuecomment-1525939500
Expand Down
18 changes: 17 additions & 1 deletion langchain/llms/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,16 @@ def lc_serializable(self) -> bool:
"""Set of special tokens that are allowed。"""
disallowed_special: Union[Literal["all"], Collection[str]] = "all"
"""Set of special tokens that are not allowed。"""
tiktoken_model_name: Optional[str] = None
"""The model name to pass to tiktoken when using this class.
Tiktoken is used to count the number of tokens in documents to constrain
them to be under a certain limit. By default, when set to None, this will
be the same as the embedding model name. However, there are some cases
where you may want to use this Embedding class with a model name not
supported by tiktoken. This can include when using Azure embeddings or
when using one of the many model providers that expose an OpenAI-like
API but with different models. In those cases, in order to avoid erroring
when tiktoken is called, you can specify a model name to use here."""

def __new__(cls, **data: Any) -> Union[OpenAIChat, BaseOpenAI]: # type: ignore
"""Initialize the OpenAI object."""
Expand Down Expand Up @@ -491,7 +501,13 @@ def get_token_ids(self, text: str) -> List[int]:
"Please install it with `pip install tiktoken`."
)

enc = tiktoken.encoding_for_model(self.model_name)
model_name = self.tiktoken_model_name or self.model_name
try:
enc = tiktoken.encoding_for_model(model_name)
except KeyError:
logger.warning("Warning: model not found. Using cl100k_base encoding.")
model = "cl100k_base"
enc = tiktoken.get_encoding(model)

return enc.encode(
text,
Expand Down