Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pydantic V2 migration #9193

Merged
merged 3 commits into from
May 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions mindsdb/integrations/handlers/chromadb_handler/settings.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import difflib

from pydantic import BaseModel, Extra, root_validator
from pydantic import BaseModel, model_validator


class ChromaHandlerConfig(BaseModel):
Expand All @@ -15,13 +15,13 @@ class ChromaHandlerConfig(BaseModel):
password: str = None

class Config:
extra = Extra.forbid
extra = "forbid"

@root_validator(pre=True, allow_reuse=True, skip_on_failure=True)
@model_validator(mode="before", skip_on_failure=True)
def check_param_typos(cls, values):
"""Check if there are any typos in the parameters."""

expected_params = cls.__fields__.keys()
expected_params = cls.model_fields.keys()
for key in values.keys():
if key not in expected_params:
close_matches = difflib.get_close_matches(
Expand All @@ -35,7 +35,7 @@ def check_param_typos(cls, values):
raise ValueError(f"Unexpected parameter '{key}'.")
return values

@root_validator(allow_reuse=True, skip_on_failure=True)
@model_validator(skip_on_failure=True)
def check_config(cls, values):
"""Check if config is valid."""

Expand Down
4 changes: 2 additions & 2 deletions mindsdb/integrations/handlers/email_handler/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ class EmailSearchOptions(BaseModel):
since_email_id: str = None

class Config:
schema_extra = {
json_schema_extra = {
"example": {
"mailbox": "INBOX",
"subject": "Test",
Expand All @@ -47,7 +47,7 @@ class EmailConnectionDetails(BaseModel):
smtp_port: int = 587

class Config:
schema_extra = {
json_schema_extra = {
"example": {
"email": "joe@bloggs.com",
"password": "password",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def _get_chat_model_params(self, args: Dict, pred_args: Dict) -> Dict:
p: get_api_key(p, args, self.engine_storage, strict=False) for p in SUPPORTED_PROVIDERS
}
llm_config = get_llm_config(args.get('provider', self._get_llm_provider(args)), model_config)
config_dict = llm_config.dict()
config_dict = llm_config.model_dump()
config_dict = {k: v for k, v in config_dict.items() if v is not None}
return config_dict

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def create(self, target: str, df: pd.DataFrame = None, args: Optional[Dict] = No
model_parameters = ModelParameters(**args["model_params"])

# store model parameters
args["model_params"] = model_parameters.dict()
args["model_params"] = model_parameters.model_dump()

rec_preprocessor = RecommenderPreprocessor(
interaction_data=df,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def create(
input_args.update({k: v for k, v in ml_engine_args.items()})

# validate args
export_args = CompletionParameters(**input_args).dict()
export_args = CompletionParameters(**input_args).model_dump()

# store args
self.model_storage.json_set("args", export_args)
Expand All @@ -64,7 +64,7 @@ def predict(self, df: pd.DataFrame = None, args: dict = None):
input_args = self.model_storage.json_get("args")

# validate args
args = CompletionParameters(**input_args).dict()
args = CompletionParameters(**input_args).model_dump()

# build messages
self._build_messages(args, df)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def _unpack_config(self):
"""
try:
config = ConnectionConfig(**self.connection_data)
return config.dict(exclude_unset=True)
return config.model_dump(exclude_unset=True)
except ValueError as e:
raise ValueError(str(e))

Expand Down
4 changes: 2 additions & 2 deletions mindsdb/integrations/handlers/mysql_handler/settings.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from pydantic import BaseModel, AnyUrl, root_validator
from pydantic import BaseModel, AnyUrl, model_validator
from urllib.parse import urlparse


Expand All @@ -11,7 +11,7 @@ class ConnectionConfig(BaseModel):
password: str = None
database: str = None

@root_validator(pre=True)
@model_validator(mode="before")
def check_db_params(cls, values):
"""Ensures either URL is provided or all individual parameters are provided."""
url = values.get('url')
Expand Down
2 changes: 1 addition & 1 deletion mindsdb/integrations/handlers/palm_handler/palm_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def create(self, target, args=None, **kwargs):
f"Invalid operation mode. Please use one of {self.supported_modes}"
)

self.model_storage.json_set("args", args_model.dict())
self.model_storage.json_set("args", args_model.model_dump())

def predict(self, df, args=None):
"""
Expand Down
2 changes: 1 addition & 1 deletion mindsdb/integrations/handlers/rag_handler/rag.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def __init__(self, args: RAGBaseParameters):

if isinstance(args, RAGHandlerParameters):

llm_config = {"llm_config": args.llm_params.dict()}
llm_config = {"llm_config": args.llm_params.model_dump()}

llm_loader = LLMLoader(**llm_config)

Expand Down
14 changes: 7 additions & 7 deletions mindsdb/integrations/handlers/rag_handler/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ class OpenAIParameters(LLMParameters):
model_id: str = Field(default="gpt-3.5-turbo-instruct", title="model name")
n: int = Field(default=1, title="number of responses to return")

@validator("model_id", allow_reuse=True)
@field_validator("model_id")
def openai_model_must_be_supported(cls, v, values):
supported_models = get_available_openai_model_ids(values)
if v not in supported_models:
Expand All @@ -238,7 +238,7 @@ class WriterLLMParameters(LLMParameters):
callbacks: List[StreamingStdOutCallbackHandler] = [StreamingStdOutCallbackHandler()]
verbose: bool = False

@validator("model_id", allow_reuse=True)
@field_validator("model_id")
def writer_model_must_be_supported(cls, v, values):
supported_models = get_available_writer_model_ids(values)
if v not in supported_models:
Expand Down Expand Up @@ -307,7 +307,7 @@ class Config:
arbitrary_types_allowed = True
use_enum_values = True

@validator("prompt_template", allow_reuse=True)
@field_validator("prompt_template")
def prompt_format_must_be_valid(cls, v):
if "{context}" not in v or "{question}" not in v:
raise InvalidPromptTemplate(
Expand All @@ -316,11 +316,11 @@ def prompt_format_must_be_valid(cls, v):
)
return v

@validator("vector_store_name", allow_reuse=True)
@field_validator("vector_store_name")
def name_must_be_lower(cls, v):
return v.lower()

@validator("vector_store_name", allow_reuse=True)
@field_validator("vector_store_name")
def vector_store_must_be_supported(cls, v):
if not is_valid_store(v):
raise UnsupportedVectorStore(
Expand All @@ -335,7 +335,7 @@ class RAGHandlerParameters(RAGBaseParameters):
llm_type: str
llm_params: LLMParameters

@validator("llm_type", allow_reuse=True)
@field_validator("llm_type")
def llm_type_must_be_supported(cls, v):
if v not in SUPPORTED_LLMS:
raise UnsupportedLLM(f"'llm_type' must be one of {SUPPORTED_LLMS}, got {v}")
Expand Down Expand Up @@ -437,7 +437,7 @@ def on_create_build_llm_params(

llm_params = {"llm_name": args["llm_type"]}

for param in llm_config_class.__fields__.keys():
for param in llm_config_class.model_fields.keys():
if param in args:
llm_params[param] = args.pop(param)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def create(self, target, df=None, args=None, **kwargs):
args = args["using"]

valid_args = Parameters(**args)
self.model_storage.json_set("args", valid_args.dict())
self.model_storage.json_set("args", valid_args.model_dump())

def predict(self, df, args=None):
"""loads persisted embeddings model and gets embeddings on input text column(s)"""
Expand Down
14 changes: 7 additions & 7 deletions mindsdb/integrations/handlers/twelve_labs_handler/settings.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import List, Optional

from pydantic import BaseModel, Extra, root_validator
from pydantic import BaseModel, model_validator
from pydantic_settings import BaseSettings

from mindsdb.integrations.handlers.utilities.validation_utilities import ParameterValidationUtilities
Expand Down Expand Up @@ -80,9 +80,9 @@ class TwelveLabsHandlerModel(BaseModel):
prompt: Optional[str] = None

class Config:
extra = Extra.forbid
extra = "forbid"

@root_validator(pre=True, allow_reuse=True, skip_on_failure=True)
@model_validator(mode="before", skip_on_failure=True)
def check_param_typos(cls, values):
"""
Root validator to check if there are any typos in the parameters.
Expand All @@ -102,7 +102,7 @@ def check_param_typos(cls, values):

return values

@root_validator(pre=True, allow_reuse=True, skip_on_failure=True)
@model_validator(mode="before", skip_on_failure=True)
def check_for_valid_task(cls, values):
"""
Root validator to check if the task provided is valid.
Expand All @@ -127,7 +127,7 @@ def check_for_valid_task(cls, values):

return values

@root_validator(pre=True, allow_reuse=True, skip_on_failure=True)
@model_validator(mode="before", skip_on_failure=True)
def check_for_valid_engine_options(cls, values):
"""
Root validator to check if the options specified for particular engines are valid.
Expand All @@ -154,7 +154,7 @@ def check_for_valid_engine_options(cls, values):

return values

@root_validator(allow_reuse=True, skip_on_failure=True)
@model_validator(skip_on_failure=True)
def check_for_video_urls_or_video_files(cls, values):
"""
Root validator to check if video_urls or video_files have been provided.
Expand Down Expand Up @@ -183,7 +183,7 @@ def check_for_video_urls_or_video_files(cls, values):

return values

@root_validator(allow_reuse=True, skip_on_failure=True)
@model_validator(skip_on_failure=True)
def check_for_task_specific_parameters(cls, values):
"""
Root validator to check if task has been provided along with the other relevant parameters for each task.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
class ParameterValidationUtilities:
@staticmethod
def validate_parameter_spelling(handler_cls, parameters):
expected_params = handler_cls.__fields__.keys()
expected_params = handler_cls.model_fields.keys()
for key in parameters.keys():
if key not in expected_params:
close_matches = difflib.get_close_matches(
Expand Down
2 changes: 1 addition & 1 deletion mindsdb/integrations/handlers/writer_handler/rag.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,4 @@ def __init__(self, args: WriterHandlerParameters):

super().__init__(args)

self.llm = Writer(**args.llm_params.dict())
self.llm = Writer(**args.llm_params.model_dump())
6 changes: 3 additions & 3 deletions mindsdb/integrations/handlers/writer_handler/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ class Config:
arbitrary_types_allowed = True
use_enum_values = True

@validator("generation_evaluation_metrics", allow_reuse=True)
@field_validator("generation_evaluation_metrics")
def generation_evaluation_metrics_must_be_supported(cls, v):
for metric in v:
if metric not in GENERATION_METRICS:
Expand All @@ -69,7 +69,7 @@ def generation_evaluation_metrics_must_be_supported(cls, v):
)
return v

@validator("retrieval_evaluation_metrics", allow_reuse=True)
@field_validator("retrieval_evaluation_metrics")
def retrieval_evaluation_metrics_must_be_supported(cls, v):
for metric in v:
if metric not in GENERATION_METRICS:
Expand All @@ -78,7 +78,7 @@ def retrieval_evaluation_metrics_must_be_supported(cls, v):
)
return v

@validator("evaluation_type", allow_reuse=True)
@field_validator("evaluation_type")
def evaluation_type_must_be_supported(cls, v):
if v not in SUPPORTED_EVALUATION_TYPES:
raise ValueError(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def extract_llm_params(args):
"""extract llm params from input query args"""

llm_params = {}
for param in WriterLLMParameters.__fields__:
for param in WriterLLMParameters.model_fields:
if param in args:
llm_params[param] = args.pop(param)

Expand Down
2 changes: 1 addition & 1 deletion mindsdb/integrations/utilities/rag/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,4 +134,4 @@ class Config:

@classmethod
def get_field_names(cls):
return list(cls.__fields__.keys())
return list(cls.model_fields.keys())
Loading