Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion sqlmodel/_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,13 @@ def get_field_metadata(field: Any) -> Any:
return FakeMetadata()

def post_init_field_info(field_info: FieldInfo) -> None:
return None
if IS_PYDANTIC_V2:
if field_info.alias and not field_info.validation_alias:
field_info.validation_alias = field_info.alias
if field_info.alias and not field_info.serialization_alias:
field_info.serialization_alias = field_info.alias
else:
field_info._validate() # type: ignore[attr-defined]

# Dummy to make it importable
def _calculate_keys(
Expand Down
102 changes: 66 additions & 36 deletions sqlmodel/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,8 @@ def Field(
*,
default_factory: Optional[NoArgAnyCallable] = None,
alias: Optional[str] = None,
validation_alias: Optional[str] = None,
serialization_alias: Optional[str] = None,
title: Optional[str] = None,
description: Optional[str] = None,
exclude: Union[
Expand Down Expand Up @@ -260,6 +262,8 @@ def Field(
*,
default_factory: Optional[NoArgAnyCallable] = None,
alias: Optional[str] = None,
validation_alias: Optional[str] = None,
serialization_alias: Optional[str] = None,
title: Optional[str] = None,
description: Optional[str] = None,
exclude: Union[
Expand Down Expand Up @@ -314,6 +318,8 @@ def Field(
*,
default_factory: Optional[NoArgAnyCallable] = None,
alias: Optional[str] = None,
validation_alias: Optional[str] = None,
serialization_alias: Optional[str] = None,
title: Optional[str] = None,
description: Optional[str] = None,
exclude: Union[
Expand Down Expand Up @@ -349,6 +355,8 @@ def Field(
*,
default_factory: Optional[NoArgAnyCallable] = None,
alias: Optional[str] = None,
validation_alias: Optional[str] = None,
serialization_alias: Optional[str] = None,
title: Optional[str] = None,
description: Optional[str] = None,
exclude: Union[
Expand Down Expand Up @@ -387,43 +395,65 @@ def Field(
schema_extra: Optional[Dict[str, Any]] = None,
) -> Any:
current_schema_extra = schema_extra or {}
field_info = FieldInfo(
default,
default_factory=default_factory,
alias=alias,
title=title,
description=description,
exclude=exclude,
include=include,
const=const,
gt=gt,
ge=ge,
lt=lt,
le=le,
multiple_of=multiple_of,
max_digits=max_digits,
decimal_places=decimal_places,
min_items=min_items,
max_items=max_items,
unique_items=unique_items,
min_length=min_length,
max_length=max_length,
allow_mutation=allow_mutation,
regex=regex,
discriminator=discriminator,
repr=repr,
primary_key=primary_key,
foreign_key=foreign_key,
ondelete=ondelete,
unique=unique,
nullable=nullable,
index=index,
sa_type=sa_type,
sa_column=sa_column,
sa_column_args=sa_column_args,
sa_column_kwargs=sa_column_kwargs,
field_info_kwargs = {
"alias": alias,
"title": title,
"description": description,
"exclude": exclude,
"include": include,
"const": const,
"gt": gt,
"ge": ge,
"lt": lt,
"le": le,
"multiple_of": multiple_of,
"max_digits": max_digits,
"decimal_places": decimal_places,
"min_items": min_items,
"max_items": max_items,
"unique_items": unique_items,
"min_length": min_length,
"max_length": max_length,
"allow_mutation": allow_mutation,
"regex": regex,
"discriminator": discriminator,
"repr": repr,
"primary_key": primary_key,
"foreign_key": foreign_key,
"ondelete": ondelete,
"unique": unique,
"nullable": nullable,
"index": index,
"sa_type": sa_type,
"sa_column": sa_column,
"sa_column_args": sa_column_args,
"sa_column_kwargs": sa_column_kwargs,
**current_schema_extra,
)
}
if IS_PYDANTIC_V2:
# Add Pydantic v2 specific parameters
field_info_kwargs.update(
{
"validation_alias": validation_alias,
"serialization_alias": serialization_alias,
}
)
field_info = FieldInfo(
default,
default_factory=default_factory,
**field_info_kwargs,
)
else:
if validation_alias:
raise RuntimeError("validation_alias is not supported in Pydantic v1")
if serialization_alias:
raise RuntimeError("serialization_alias is not supported in Pydantic v1")
field_info = FieldInfo(
default,
default_factory=default_factory,
**field_info_kwargs,
)

post_init_field_info(field_info)
return field_info

Expand Down
176 changes: 176 additions & 0 deletions tests/test_aliases.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
from typing import Type, Union

import pytest
from pydantic import VERSION, BaseModel, ValidationError
from pydantic import Field as PField
from sqlmodel import Field, SQLModel

# -----------------------------------------------------------------------------------
# Models


class PydanticUser(BaseModel):
full_name: str = PField(alias="fullName")


class SQLModelUser(SQLModel):
full_name: str = Field(alias="fullName")


# Models with config (validate_by_name=True)


if VERSION.startswith("2."):

class PydanticUserWithConfig(PydanticUser):
model_config = {"validate_by_name": True}

class SQLModelUserWithConfig(SQLModelUser):
model_config = {"validate_by_name": True}

else:

class PydanticUserWithConfig(PydanticUser):
class Config:
allow_population_by_field_name = True

class SQLModelUserWithConfig(SQLModelUser):
class Config:
allow_population_by_field_name = True


# -----------------------------------------------------------------------------------
# Tests

# Test validate by name


@pytest.mark.parametrize("model", [PydanticUser, SQLModelUser])
def test_create_with_field_name(model: Union[Type[PydanticUser], Type[SQLModelUser]]):
with pytest.raises(ValidationError):
model(full_name="Alice")


@pytest.mark.parametrize("model", [PydanticUserWithConfig, SQLModelUserWithConfig])
def test_create_with_field_name_with_config(
model: Union[Type[PydanticUserWithConfig], Type[SQLModelUserWithConfig]],
):
user = model(full_name="Alice")
assert user.full_name == "Alice"


# Test validate by alias


@pytest.mark.parametrize(
"model",
[PydanticUser, SQLModelUser, PydanticUserWithConfig, SQLModelUserWithConfig],
)
def test_create_with_alias(
model: Union[
Type[PydanticUser],
Type[SQLModelUser],
Type[PydanticUserWithConfig],
Type[SQLModelUserWithConfig],
],
):
user = model(fullName="Bob") # using alias
assert user.full_name == "Bob"


# Test validate by name and alias


@pytest.mark.parametrize("model", [PydanticUserWithConfig, SQLModelUserWithConfig])
def test_create_with_both_prefers_alias(
model: Union[Type[PydanticUserWithConfig], Type[SQLModelUserWithConfig]],
):
user = model(full_name="IGNORED", fullName="Charlie")
assert user.full_name == "Charlie" # alias should take precedence


# Test serialize


@pytest.mark.parametrize("model", [PydanticUser, SQLModelUser])
def test_dict_default_uses_field_names(
model: Union[Type[PydanticUser], Type[SQLModelUser]],
):
user = model(fullName="Dana")
data = user.dict()
assert "full_name" in data
assert "fullName" not in data
assert data["full_name"] == "Dana"


# Test serialize by alias


@pytest.mark.parametrize("model", [PydanticUser, SQLModelUser])
def test_dict_default_uses_aliases(
model: Union[Type[PydanticUser], Type[SQLModelUser]],
):
user = model(fullName="Dana")
data = user.dict(by_alias=True)
assert "fullName" in data
assert "full_name" not in data
assert data["fullName"] == "Dana"


# Test json by alias


@pytest.mark.parametrize("model", [PydanticUser, SQLModelUser])
def test_json_by_alias(
model: Union[Type[PydanticUser], Type[SQLModelUser]],
):
user = model(fullName="Frank")
json_data = user.json(by_alias=True)
assert ('"fullName":"Frank"' in json_data) or ('"fullName": "Frank"' in json_data)
assert "full_name" not in json_data


# Pydantic v2 specific models - only define if we're running Pydantic v2
if VERSION.startswith("2."):

class PydanticUserV2(BaseModel):
first_name: str = PField(
validation_alias="firstName", serialization_alias="f_name"
)

class SQLModelUserV2(SQLModel):
first_name: str = Field(
validation_alias="firstName", serialization_alias="f_name"
)
else:
# Dummy classes for Pydantic v1 to prevent import errors
PydanticUserV2 = None
SQLModelUserV2 = None


@pytest.mark.skipif(
not VERSION.startswith("2."),
reason="validation_alias and serialization_alias are not supported in Pydantic v1",
)
@pytest.mark.parametrize("model", [PydanticUserV2, SQLModelUserV2])
def test_create_with_validation_alias(
model: Union[Type[PydanticUserV2], Type[SQLModelUserV2]],
):
user = model(firstName="John")
assert user.first_name == "John"


@pytest.mark.skipif(
not VERSION.startswith("2."),
reason="validation_alias and serialization_alias are not supported in Pydantic v1",
)
@pytest.mark.parametrize("model", [PydanticUserV2, SQLModelUserV2])
def test_serialize_with_serialization_alias(
model: Union[Type[PydanticUserV2], Type[SQLModelUserV2]],
):
user = model(firstName="Jane")
data = user.dict(by_alias=True)
assert "f_name" in data
assert "firstName" not in data
assert "first_name" not in data
assert data["f_name"] == "Jane"
Loading