Skip to content

Commit

Permalink
core[minor],community[minor]: Upgrade all @root_validator() to @pre_i…
Browse files Browse the repository at this point in the history
…nit (#23841)

This PR introduces a @pre_init decorator that's a @root_validator(pre=True) but with all the defaults populated!
  • Loading branch information
eyurtsev committed Jul 8, 2024
1 parent f152d6e commit 2c180d6
Show file tree
Hide file tree
Showing 114 changed files with 439 additions and 276 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@

from langchain_core._api.deprecation import deprecated
from langchain_core.outputs import ChatResult
from langchain_core.pydantic_v1 import BaseModel, Field, root_validator
from langchain_core.utils import get_from_dict_or_env
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.utils import get_from_dict_or_env, pre_init

from langchain_community.chat_models.openai import ChatOpenAI
from langchain_community.utils.openai import is_openai_v1
Expand Down Expand Up @@ -106,7 +106,7 @@ def get_lc_namespace(cls) -> List[str]:
"""Get the namespace of the langchain object."""
return ["langchain", "chat_models", "azure_openai"]

@root_validator()
@pre_init
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment."""
if values["n"] < 1:
Expand Down
5 changes: 2 additions & 3 deletions libs/community/langchain_community/chat_models/edenai.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,10 @@
Extra,
Field,
SecretStr,
root_validator,
)
from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough
from langchain_core.tools import BaseTool
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env, pre_init
from langchain_core.utils.function_calling import convert_to_openai_tool

from langchain_community.utilities.requests import Requests
Expand Down Expand Up @@ -300,7 +299,7 @@ class Config:

extra = Extra.forbid

@root_validator()
@pre_init
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key exists in environment."""
values["edenai_api_key"] = convert_to_secret_str(
Expand Down
6 changes: 3 additions & 3 deletions libs/community/langchain_community/chat_models/google_palm.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@
ChatGeneration,
ChatResult,
)
from langchain_core.pydantic_v1 import BaseModel, SecretStr, root_validator
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
from langchain_core.pydantic_v1 import BaseModel, SecretStr
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env, pre_init
from tenacity import (
before_sleep_log,
retry,
Expand Down Expand Up @@ -261,7 +261,7 @@ def get_lc_namespace(cls) -> List[str]:
"""Get the namespace of the langchain object."""
return ["langchain", "chat_models", "google_palm"]

@root_validator()
@pre_init
def validate_environment(cls, values: Dict) -> Dict:
"""Validate api key, python package exists, temperature, top_p, and top_k."""
google_api_key = convert_to_secret_str(
Expand Down
3 changes: 2 additions & 1 deletion libs/community/langchain_community/chat_models/hunyuan.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
convert_to_secret_str,
get_from_dict_or_env,
get_pydantic_field_names,
pre_init,
)

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -190,7 +191,7 @@ def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
values["model_kwargs"] = extra
return values

@root_validator()
@pre_init
def validate_environment(cls, values: Dict) -> Dict:
values["hunyuan_api_base"] = get_from_dict_or_env(
values,
Expand Down
3 changes: 2 additions & 1 deletion libs/community/langchain_community/chat_models/jinachat.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
convert_to_secret_str,
get_from_dict_or_env,
get_pydantic_field_names,
pre_init,
)
from tenacity import (
before_sleep_log,
Expand Down Expand Up @@ -218,7 +219,7 @@ def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
values["model_kwargs"] = extra
return values

@root_validator()
@pre_init
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment."""
values["jinachat_api_key"] = convert_to_secret_str(
Expand Down
6 changes: 4 additions & 2 deletions libs/community/langchain_community/chat_models/kinetica.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, List, Optional, cast

from langchain_core.utils import pre_init

if TYPE_CHECKING:
import gpudb

Expand All @@ -24,7 +26,7 @@
)
from langchain_core.output_parsers.transform import BaseOutputParser
from langchain_core.outputs import ChatGeneration, ChatResult, Generation
from langchain_core.pydantic_v1 import BaseModel, Field, root_validator
from langchain_core.pydantic_v1 import BaseModel, Field

LOG = logging.getLogger(__name__)

Expand Down Expand Up @@ -341,7 +343,7 @@ class ChatKinetica(BaseChatModel):
kdbc: Any = Field(exclude=True)
""" Kinetica DB connection. """

@root_validator()
@pre_init
def validate_environment(cls, values: Dict) -> Dict:
"""Pydantic object validator."""

Expand Down
6 changes: 3 additions & 3 deletions libs/community/langchain_community/chat_models/konko.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@
)
from langchain_core.messages import AIMessageChunk, BaseMessage
from langchain_core.outputs import ChatGenerationChunk, ChatResult
from langchain_core.pydantic_v1 import Field, SecretStr, root_validator
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
from langchain_core.pydantic_v1 import Field, SecretStr
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env, pre_init

from langchain_community.adapters.openai import (
convert_message_to_dict,
Expand Down Expand Up @@ -85,7 +85,7 @@ def is_lc_serializable(cls) -> bool:
max_tokens: int = 20
"""Maximum number of tokens to generate."""

@root_validator()
@pre_init
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment."""
values["konko_api_key"] = convert_to_secret_str(
Expand Down
6 changes: 3 additions & 3 deletions libs/community/langchain_community/chat_models/litellm.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,10 @@
ChatGenerationChunk,
ChatResult,
)
from langchain_core.pydantic_v1 import BaseModel, Field, root_validator
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.runnables import Runnable
from langchain_core.tools import BaseTool
from langchain_core.utils import get_from_dict_or_env
from langchain_core.utils import get_from_dict_or_env, pre_init
from langchain_core.utils.function_calling import convert_to_openai_tool

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -249,7 +249,7 @@ def _completion_with_retry(**kwargs: Any) -> Any:

return _completion_with_retry(**kwargs)

@root_validator()
@pre_init
def validate_environment(cls, values: Dict) -> Dict:
"""Validate api key, python package exists, temperature, top_p, and top_k."""
try:
Expand Down
4 changes: 2 additions & 2 deletions libs/community/langchain_community/chat_models/moonshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@

from typing import Dict

from langchain_core.pydantic_v1 import root_validator
from langchain_core.utils import (
convert_to_secret_str,
get_from_dict_or_env,
pre_init,
)

from langchain_community.chat_models import ChatOpenAI
Expand All @@ -29,7 +29,7 @@ class MoonshotChat(MoonshotCommon, ChatOpenAI): # type: ignore[misc]
moonshot = MoonshotChat(model="moonshot-v1-8k")
"""

@root_validator()
@pre_init
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that the environment is set up correctly."""
values["moonshot_api_key"] = convert_to_secret_str(
Expand Down
6 changes: 3 additions & 3 deletions libs/community/langchain_community/chat_models/octoai.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@

from typing import Dict

from langchain_core.pydantic_v1 import Field, SecretStr, root_validator
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
from langchain_core.pydantic_v1 import Field, SecretStr
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env, pre_init

from langchain_community.chat_models.openai import ChatOpenAI
from langchain_community.utils.openai import is_openai_v1
Expand Down Expand Up @@ -48,7 +48,7 @@ def lc_secrets(self) -> Dict[str, str]:
def is_lc_serializable(cls) -> bool:
return False

@root_validator()
@pre_init
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment."""
values["octoai_api_base"] = get_from_dict_or_env(
Expand Down
3 changes: 2 additions & 1 deletion libs/community/langchain_community/chat_models/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
from langchain_core.utils import (
get_from_dict_or_env,
get_pydantic_field_names,
pre_init,
)

from langchain_community.adapters.openai import (
Expand Down Expand Up @@ -274,7 +275,7 @@ def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
values["model_kwargs"] = extra
return values

@root_validator()
@pre_init
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment."""
if values["n"] < 1:
Expand Down
5 changes: 2 additions & 3 deletions libs/community/langchain_community/chat_models/premai.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,8 @@
Extra,
Field,
SecretStr,
root_validator,
)
from langchain_core.utils import get_from_dict_or_env
from langchain_core.utils import get_from_dict_or_env, pre_init

if TYPE_CHECKING:
from premai.api.chat_completions.v1_chat_completions_create import (
Expand Down Expand Up @@ -249,7 +248,7 @@ class Config:
allow_population_by_field_name = True
arbitrary_types_allowed = True

@root_validator()
@pre_init
def validate_environments(cls, values: Dict) -> Dict:
"""Validate that the package is installed and that the API token is valid"""
try:
Expand Down
3 changes: 2 additions & 1 deletion libs/community/langchain_community/chat_models/snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
convert_to_secret_str,
get_from_dict_or_env,
get_pydantic_field_names,
pre_init,
)
from langchain_core.utils.utils import build_extra_kwargs

Expand Down Expand Up @@ -135,7 +136,7 @@ def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
)
return values

@root_validator()
@pre_init
def validate_environment(cls, values: Dict) -> Dict:
try:
from snowflake.snowpark import Session
Expand Down
6 changes: 3 additions & 3 deletions libs/community/langchain_community/chat_models/solar.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
from typing import Dict

from langchain_core._api import deprecated
from langchain_core.pydantic_v1 import Field, root_validator
from langchain_core.utils import get_from_dict_or_env
from langchain_core.pydantic_v1 import Field
from langchain_core.utils import get_from_dict_or_env, pre_init

from langchain_community.chat_models import ChatOpenAI
from langchain_community.llms.solar import SOLAR_SERVICE_URL_BASE, SolarCommon
Expand Down Expand Up @@ -37,7 +37,7 @@ class Config:
arbitrary_types_allowed = True
extra = "ignore"

@root_validator()
@pre_init
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that the environment is set up correctly."""
values["solar_api_key"] = get_from_dict_or_env(
Expand Down
6 changes: 3 additions & 3 deletions libs/community/langchain_community/chat_models/tongyi.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,10 +49,10 @@
ChatGenerationChunk,
ChatResult,
)
from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator
from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr
from langchain_core.runnables import Runnable
from langchain_core.tools import BaseTool
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env, pre_init
from langchain_core.utils.function_calling import convert_to_openai_tool
from requests.exceptions import HTTPError
from tenacity import (
Expand Down Expand Up @@ -431,7 +431,7 @@ def _llm_type(self) -> str:
"""Return type of llm."""
return "tongyi"

@root_validator()
@pre_init
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment."""
values["dashscope_api_key"] = convert_to_secret_str(
Expand Down
4 changes: 2 additions & 2 deletions libs/community/langchain_community/chat_models/vertexai.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
SystemMessage,
)
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.pydantic_v1 import root_validator
from langchain_core.utils import pre_init

from langchain_community.llms.vertexai import (
_VertexAICommon,
Expand Down Expand Up @@ -225,7 +225,7 @@ def get_lc_namespace(cls) -> List[str]:
"""Get the namespace of the langchain object."""
return ["langchain", "chat_models", "vertexai"]

@root_validator()
@pre_init
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that the python package exists in environment."""
is_gemini = is_gemini_model(values["model_name"])
Expand Down
3 changes: 2 additions & 1 deletion libs/community/langchain_community/chat_models/yuan2.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
from langchain_core.utils import (
get_from_dict_or_env,
get_pydantic_field_names,
pre_init,
)
from tenacity import (
before_sleep_log,
Expand Down Expand Up @@ -166,7 +167,7 @@ def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
values["model_kwargs"] = extra
return values

@root_validator()
@pre_init
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment."""
values["yuan2_api_key"] = get_from_dict_or_env(
Expand Down
6 changes: 3 additions & 3 deletions libs/community/langchain_community/embeddings/anyscale.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@

from typing import Dict

from langchain_core.pydantic_v1 import Field, SecretStr, root_validator
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
from langchain_core.pydantic_v1 import Field, SecretStr
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env, pre_init

from langchain_community.embeddings.openai import OpenAIEmbeddings
from langchain_community.utils.openai import is_openai_v1
Expand Down Expand Up @@ -34,7 +34,7 @@ def lc_secrets(self) -> Dict[str, str]:
"anyscale_api_key": "ANYSCALE_API_KEY",
}

@root_validator()
@pre_init
def validate_environment(cls, values: dict) -> dict:
"""Validate that api key and python package exists in environment."""
values["anyscale_api_key"] = convert_to_secret_str(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
from typing import Any, Dict, List, Optional

from langchain_core.embeddings import Embeddings
from langchain_core.pydantic_v1 import BaseModel, Field, root_validator
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env, pre_init

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -48,7 +48,7 @@ class QianfanEmbeddingsEndpoint(BaseModel, Embeddings):
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
"""extra params for model invoke using with `do`."""

@root_validator()
@pre_init
def validate_environment(cls, values: Dict) -> Dict:
"""
Validate whether qianfan_ak and qianfan_sk in the environment variables or
Expand Down
6 changes: 3 additions & 3 deletions libs/community/langchain_community/embeddings/deepinfra.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@

import requests
from langchain_core.embeddings import Embeddings
from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator
from langchain_core.utils import get_from_dict_or_env
from langchain_core.pydantic_v1 import BaseModel, Extra
from langchain_core.utils import get_from_dict_or_env, pre_init

DEFAULT_MODEL_ID = "sentence-transformers/clip-ViT-B-32"
MAX_BATCH_SIZE = 1024
Expand Down Expand Up @@ -59,7 +59,7 @@ class Config:

extra = Extra.forbid

@root_validator()
@pre_init
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment."""
deepinfra_api_token = get_from_dict_or_env(
Expand Down
Loading

0 comments on commit 2c180d6

Please sign in to comment.