Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Community/Partners: Add support for Perplexity AI #17024

Merged
merged 33 commits into from
Mar 1, 2024
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
f816626
Add Perplexity boilerplate
atherfawaz Feb 4, 2024
994af05
Sanitize
atherfawaz Feb 4, 2024
e0d467f
Add tests
atherfawaz Feb 4, 2024
c011593
Lint
atherfawaz Feb 4, 2024
9cf23c1
Complete unit tests
atherfawaz Feb 4, 2024
cf30da5
Remove unused fn
atherfawaz Feb 4, 2024
d808172
Update perplexties
atherfawaz Feb 5, 2024
7e9642d
Merge branch 'master' into master
atherfawaz Feb 5, 2024
cb6b5cc
cr
baskaryan Feb 5, 2024
78a1ee8
cr
baskaryan Feb 5, 2024
9464b3c
fmt
baskaryan Feb 5, 2024
5577809
Merge branch 'master' into atherfawaz/master
baskaryan Feb 6, 2024
ea246ca
fmt
baskaryan Feb 6, 2024
e44a0c5
fmt
baskaryan Feb 6, 2024
11d5c02
nit
baskaryan Feb 6, 2024
3e2c640
Add docs
atherfawaz Feb 6, 2024
7ba1a7f
Merge branch 'master' into master
atherfawaz Feb 6, 2024
62239b2
Lint
atherfawaz Feb 7, 2024
a54ebe0
Merge branch 'master' into master
atherfawaz Feb 8, 2024
28150d3
Merge branch 'master' into atherfawaz/master
efriis Feb 10, 2024
346c03b
PerplexityLLM
efriis Feb 10, 2024
ec1995e
LLM
efriis Feb 10, 2024
84aeda5
Merge branch 'master' into master
atherfawaz Feb 19, 2024
29fd27d
Move module import inside
atherfawaz Feb 21, 2024
6669968
Merge branch 'master' into master
atherfawaz Feb 21, 2024
fc43713
Merge branch 'master' into master
atherfawaz Feb 23, 2024
faf4110
Merge branch 'master' into master
atherfawaz Feb 24, 2024
809725d
Remove PerplexityLLM
atherfawaz Feb 24, 2024
0079edd
Lint, remove LLM notebook
atherfawaz Feb 26, 2024
d59c62d
Merge branch 'master' into master
atherfawaz Feb 26, 2024
3f67e0e
Lint, remove LLM notebook
atherfawaz Feb 26, 2024
003822b
Merge branch 'master' into master
atherfawaz Mar 1, 2024
c323b7c
Remove unused import
atherfawaz Mar 1, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions libs/community/langchain_community/chat_models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
from langchain_community.chat_models.ollama import ChatOllama
from langchain_community.chat_models.openai import ChatOpenAI
from langchain_community.chat_models.pai_eas_endpoint import PaiEasChatEndpoint
from langchain_community.chat_models.perplexity import ChatPerplexity
from langchain_community.chat_models.promptlayer_openai import PromptLayerChatOpenAI
from langchain_community.chat_models.sparkllm import ChatSparkLLM
from langchain_community.chat_models.tongyi import ChatTongyi
Expand Down Expand Up @@ -95,4 +96,5 @@
"VolcEngineMaasChat",
"GPTRouter",
"ChatZhipuAI",
"ChatPerplexity",
]
275 changes: 275 additions & 0 deletions libs/community/langchain_community/chat_models/perplexity.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,275 @@
"""Wrapper around Perplexity APIs."""
from __future__ import annotations

import logging
from typing import (
Any,
Dict,
Iterator,
List,
Mapping,
Optional,
Tuple,
Type,
Union,
)

from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models.chat_models import (
BaseChatModel,
generate_from_stream,
)
from langchain_core.messages import (
AIMessage,
AIMessageChunk,
BaseMessage,
BaseMessageChunk,
ChatMessage,
ChatMessageChunk,
FunctionMessageChunk,
HumanMessage,
HumanMessageChunk,
SystemMessage,
SystemMessageChunk,
ToolMessageChunk,
)
from langchain_core.outputs import (
ChatGeneration,
ChatGenerationChunk,
ChatResult,
GenerationChunk,
)
from langchain_core.pydantic_v1 import Field, root_validator
from langchain_core.utils import get_from_dict_or_env, get_pydantic_field_names

logger = logging.getLogger(__name__)


class ChatPerplexity(BaseChatModel):
"""`Perplexity AI` Chat models API.

To use, you should have the ``openai`` python package installed, and the
environment variable ``PPLX_API_KEY`` set to your API key.
Any parameters that are valid to be passed to the openai.create call can be passed
in, even if not explicitly saved on this class.

Example:
.. code-block:: python

from langchain.chat_models import ChatPerplexity

chat = ChatPerplexity()
"""

client: Any #: :meta private:
temperature: float = 0.7
"""What sampling temperature to use."""
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
"""Holds any model parameters valid for `create` call not explicitly specified."""
pplx_api_key: Optional[str] = None
"""Base URL path for API requests,
leave blank if not using a proxy or service emulator."""
request_timeout: Optional[Union[float, Tuple[float, float]]] = None
"""Timeout for requests to PerplexityChat completion API. Default is 600 seconds."""
max_retries: int = 6
"""Maximum number of retries to make when generating."""
streaming: bool = False
"""Whether to stream the results or not."""
max_tokens: Optional[int] = None
"""Maximum number of tokens to generate."""
model: str = Field(default="pplx-70b-online", alias="model")

class Config:
"""Configuration for this pydantic object."""

allow_population_by_field_name = True

@property
def lc_secrets(self) -> Dict[str, str]:
return {"pplx_api_key": "PPLX_API_KEY"}

@root_validator(pre=True, allow_reuse=True)
def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
"""Build extra kwargs from additional params that were passed in."""
all_required_field_names = get_pydantic_field_names(cls)
extra = values.get("model_kwargs", {})
for field_name in list(values):
if field_name in extra:
raise ValueError(f"Found {field_name} supplied twice.")
if field_name not in all_required_field_names:
logger.warning(
f"""WARNING! {field_name} is not a default parameter.
{field_name} was transferred to model_kwargs.
Please confirm that {field_name} is what you intended."""
)
extra[field_name] = values.pop(field_name)

invalid_model_kwargs = all_required_field_names.intersection(extra.keys())
if invalid_model_kwargs:
raise ValueError(
f"Parameters {invalid_model_kwargs} should be specified explicitly. "
f"Instead they were passed in as part of `model_kwargs` parameter."
)

values["model_kwargs"] = extra
return values

@root_validator(allow_reuse=True)
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment."""
values["pplx_api_key"] = get_from_dict_or_env(
values, "pplx_api_key", "PPLX_API_KEY"
)
try:
import openai # noqa: F401
except ImportError:
raise ImportError(
"Could not import openai python package. "
"Please install it with `pip install openai`."
)
try:
values["client"] = openai.OpenAI(
api_key=values["pplx_api_key"], base_url="https://api.perplexity.ai"
)
except AttributeError:
raise ValueError(
"`openai` has no `ChatCompletion` attribute, this is likely "
"due to an old version of the openai package. Try upgrading it "
"with `pip install --upgrade openai`."
)
return values

@property
def _default_params(self) -> Dict[str, Any]:
"""Get the default parameters for calling PerplexityChat API."""
return {
"request_timeout": self.request_timeout,
"max_tokens": self.max_tokens,
"stream": self.streaming,
"temperature": self.temperature,
**self.model_kwargs,
}

def _convert_message_to_dict(self, message: BaseMessage) -> Dict[str, Any]:
if isinstance(message, ChatMessage):
message_dict = {"role": message.role, "content": message.content}
elif isinstance(message, SystemMessage):
message_dict = {"role": "system", "content": message.content}
elif isinstance(message, HumanMessage):
message_dict = {"role": "user", "content": message.content}
elif isinstance(message, AIMessage):
message_dict = {"role": "assistant", "content": message.content}
else:
raise TypeError(f"Got unknown type {message}")
return message_dict

def _create_message_dicts(
self, messages: List[BaseMessage], stop: Optional[List[str]]
) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
params = dict(self._invocation_params)
if stop is not None:
if "stop" in params:
raise ValueError("`stop` found in both the input and default params.")
params["stop"] = stop
message_dicts = [self._convert_message_to_dict(m) for m in messages]
return message_dicts, params

def _convert_delta_to_message_chunk(
self, _dict: Mapping[str, Any], default_class: Type[BaseMessageChunk]
) -> BaseMessageChunk:
role = _dict.get("role")
content = _dict.get("content") or ""
additional_kwargs: Dict = {}
if _dict.get("function_call"):
function_call = dict(_dict["function_call"])
if "name" in function_call and function_call["name"] is None:
function_call["name"] = ""
additional_kwargs["function_call"] = function_call
if _dict.get("tool_calls"):
additional_kwargs["tool_calls"] = _dict["tool_calls"]

if role == "user" or default_class == HumanMessageChunk:
return HumanMessageChunk(content=content)
elif role == "assistant" or default_class == AIMessageChunk:
return AIMessageChunk(content=content, additional_kwargs=additional_kwargs)
elif role == "system" or default_class == SystemMessageChunk:
return SystemMessageChunk(content=content)
elif role == "function" or default_class == FunctionMessageChunk:
return FunctionMessageChunk(content=content, name=_dict["name"])
elif role == "tool" or default_class == ToolMessageChunk:
return ToolMessageChunk(content=content, tool_call_id=_dict["tool_call_id"])
elif role or default_class == ChatMessageChunk:
return ChatMessageChunk(content=content, role=role)
else:
return default_class(content=content)

def _stream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[GenerationChunk]:
message_dicts, params = self._create_message_dicts(messages, stop)
params = {**params, **kwargs}
default_chunk_class = AIMessageChunk

if stop:
params["stop_sequences"] = stop
stream_resp = self.client.chat.completions.create(
model=params["model"], messages=message_dicts, stream=True
)
for chunk in stream_resp:
if not isinstance(chunk, dict):
chunk = chunk.dict()
if len(chunk["choices"]) == 0:
continue
choice = chunk["choices"][0]
chunk = self._convert_delta_to_message_chunk(
choice["delta"], default_chunk_class
)
finish_reason = choice.get("finish_reason")
generation_info = (
dict(finish_reason=finish_reason) if finish_reason is not None else None
)
default_chunk_class = chunk.__class__
chunk = ChatGenerationChunk(message=chunk, generation_info=generation_info)
yield chunk
if run_manager:
run_manager.on_llm_new_token(chunk.text, chunk=chunk)

def _generate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
if self.streaming:
stream_iter = self._stream(
messages, stop=stop, run_manager=run_manager, **kwargs
)
if stream_iter:
return generate_from_stream(stream_iter)
message_dicts, params = self._create_message_dicts(messages, stop)
params = {**params, **kwargs}
response = self.client.chat.completions.create(
model=params["model"], messages=message_dicts
)
message = AIMessage(content=response.choices[0].message.content)
return ChatResult(generations=[ChatGeneration(message=message)])

@property
def _invocation_params(self) -> Mapping[str, Any]:
"""Get the parameters used to invoke the model."""
pplx_creds: Dict[str, Any] = {
"api_key": self.pplx_api_key,
"api_base": "https://api.perplexity.ai",
"model": self.model,
}
return {**pplx_creds, **self._default_params}

@property
def _llm_type(self) -> str:
"""Return type of chat model."""
return "perplexitychat"
Loading
Loading