diff --git a/libs/community/langchain_community/chat_models/sparkllm.py b/libs/community/langchain_community/chat_models/sparkllm.py index 7d122f0a0c89d0..20dc1380c9d658 100644 --- a/libs/community/langchain_community/chat_models/sparkllm.py +++ b/libs/community/langchain_community/chat_models/sparkllm.py @@ -43,6 +43,9 @@ logger = logging.getLogger(__name__) +SPARK_API_URL = "wss://spark-api.xf-yun.com/v3.5/chat" +SPARK_LLM_DOMAIN = "generalv3.5" + def _convert_message_to_dict(message: BaseMessage) -> dict: if isinstance(message, ChatMessage): @@ -108,7 +111,7 @@ class ChatSparkLLM(BaseChatModel): Extra infos: 1. Get app_id, api_key, api_secret from the iFlyTek Open Platform Console: https://console.xfyun.cn/services/bm35 - 2. By default, iFlyTek Spark LLM V3.0 is invoked. + 2. By default, iFlyTek Spark LLM V3.5 is invoked. If you need to invoke other versions, please configure the corresponding parameters(spark_api_url and spark_llm_domain) according to the document: https://www.xfyun.cn/doc/spark/Web.html @@ -134,17 +137,31 @@ def lc_secrets(self) -> Dict[str, str]: } client: Any = None #: :meta private: - spark_app_id: Optional[str] = None + spark_app_id: Optional[str] = Field(default=None, alias="app_id") + """Automatically inferred from env var `IFLYTEK_SPARK_APP_ID` + if not provided.""" spark_api_key: Optional[str] = Field(default=None, alias="api_key") - spark_api_secret: Optional[str] = None - spark_api_url: Optional[str] = None - spark_llm_domain: Optional[str] = None + """Automatically inferred from env var `IFLYTEK_SPARK_API_KEY` + if not provided.""" + spark_api_secret: Optional[str] = Field(default=None, alias="api_secret") + """Automatically inferred from env var `IFLYTEK_SPARK_API_SECRET` + if not provided.""" + spark_api_url: Optional[str] = Field(default=None, alias="api_url") + """Base URL path for API requests, leave blank if not using a proxy or service + emulator.""" + spark_llm_domain: Optional[str] = Field(default=None, alias="model") + """Model name to use.""" spark_user_id: str = "lc_user" streaming: bool = False + """Whether to stream the results or not.""" request_timeout: int = Field(30, alias="timeout") + """request timeout for chat http requests""" temperature: float = Field(default=0.5) + """What sampling temperature to use.""" top_k: int = 4 + """What search sampling control to use.""" model_kwargs: Dict[str, Any] = Field(default_factory=dict) + """Holds any model parameters valid for API call not explicitly specified.""" class Config: """Configuration for this pydantic object.""" @@ -199,13 +216,13 @@ def validate_environment(cls, values: Dict) -> Dict: values, "spark_api_url", "IFLYTEK_SPARK_API_URL", - "wss://spark-api.xf-yun.com/v3.1/chat", + SPARK_API_URL, ) values["spark_llm_domain"] = get_from_dict_or_env( values, "spark_llm_domain", "IFLYTEK_SPARK_LLM_DOMAIN", - "generalv3", + SPARK_LLM_DOMAIN, ) # put extra params into model_kwargs values["model_kwargs"]["temperature"] = values["temperature"] or cls.temperature @@ -307,12 +324,10 @@ def __init__( "Please install it with `pip install websocket-client`." ) - self.api_url = ( - "wss://spark-api.xf-yun.com/v3.1/chat" if not api_url else api_url - ) + self.api_url = SPARK_API_URL if not api_url else api_url self.app_id = app_id self.model_kwargs = model_kwargs - self.spark_domain = spark_domain or "generalv3" + self.spark_domain = spark_domain or SPARK_LLM_DOMAIN self.queue: Queue[Dict] = Queue() self.blocking_message = {"content": "", "role": "assistant"} self.api_key = api_key