Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 69 additions & 3 deletions sqlmodel/_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,37 @@ def _calculate_keys(
) -> Optional[AbstractSet[str]]: # pragma: no cover
return None

def validate_access_primary_key_autotype(
self: InstanceOrType["SQLModel"], name: str, value: Any
) -> None:
"""
Pydantic v2
Validates if the attribute being accessed is a primary key with an auto type and has not been set.

Args:
self (InstanceOrType["SQLModel"]): The instance or type of SQLModel.
name (str): The name of the attribute being accessed.
value (Any): The value of the attribute being accessed.

Raises:
ValueError: If the attribute is a primary key with an auto type and has not been set.

Returns:
None
"""
if name != "model_fields":
model_fields = object.__getattribute__(self, "model_fields")
field = model_fields.get(name)
if (
field is not None
and isinstance(field, FieldInfo)
and hasattr(field, "primary_key")
):
if field.primary_key and field.annotation is int and value is None:
raise ValueError(
f"Primary key attribute '{name}' has not been set, please commit() it first."
)

def sqlmodel_table_construct(
*,
self_instance: _TSQLModel,
Expand Down Expand Up @@ -386,15 +417,15 @@ class SQLModelConfig(BaseConfig): # type: ignore[no-redef]
def get_config_value(
*, model: InstanceOrType["SQLModel"], parameter: str, default: Any = None
) -> Any:
return getattr(model.__config__, parameter, default) # type: ignore[union-attr]
return getattr(model.__config__, parameter, default)

def set_config_value(
*,
model: InstanceOrType["SQLModel"],
parameter: str,
value: Any,
) -> None:
setattr(model.__config__, parameter, value) # type: ignore
setattr(model.__config__, parameter, value)

def get_model_fields(model: InstanceOrType[BaseModel]) -> Dict[str, "FieldInfo"]:
return model.__fields__ # type: ignore
Expand Down Expand Up @@ -499,6 +530,41 @@ def _calculate_keys(

return keys

def validate_access_primary_key_autotype(
self: InstanceOrType["SQLModel"], name: str, value: Any
) -> None:
"""
Pydantic v1
Validates if the attribute being accessed is a primary key with an auto type and has not been set.

Args:
self (InstanceOrType["SQLModel"]): The instance or type of SQLModel.
name (str): The name of the attribute being accessed.
value (Any): The value of the attribute being accessed.

Raises:
ValueError: If the attribute is a primary key with an auto type and has not been set.

Returns:
None
"""
if name != "__fields__":
fields = object.__getattribute__(self, "__fields__")
field = fields.get(name)
if (
field is not None
and isinstance(field.field_info, FieldInfo)
and hasattr(field.field_info, "primary_key")
):
if (
field.field_info.primary_key
and field.annotation is int
and value is None
):
raise ValueError(
f"Primary key attribute '{name}' has not been set, please commit() it first."
)

def sqlmodel_validate(
cls: Type[_TSQLModel],
obj: Any,
Expand Down Expand Up @@ -542,7 +608,7 @@ def sqlmodel_validate(
setattr(m, key, value)
# Continue with standard Pydantic logic
object.__setattr__(m, "__fields_set__", fields_set)
m._init_private_attributes() # type: ignore[attr-defined] # noqa
m._init_private_attributes()
return m

def sqlmodel_init(*, self: "SQLModel", data: Dict[str, Any]) -> None:
Expand Down
7 changes: 7 additions & 0 deletions sqlmodel/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@
set_config_value,
sqlmodel_init,
sqlmodel_validate,
validate_access_primary_key_autotype,
)
from .sql.sqltypes import GUID, AutoString

Expand Down Expand Up @@ -732,6 +733,12 @@ def __setattr__(self, name: str, value: Any) -> None:
if name not in self.__sqlmodel_relationships__:
super().__setattr__(name, value)

def __getattribute__(self, name: str) -> Any:
# Access attributes safely using object.__getattribute__ to avoid recursion
value = object.__getattribute__(self, name)
validate_access_primary_key_autotype(self, name, value)
return value

def __repr_args__(self) -> Sequence[Tuple[Optional[str], Any]]:
# Don't show SQLAlchemy private attributes
return [
Expand Down