Skip to content

Commit

Permalink
Enum fixes (typing and inheritance)
Browse files Browse the repository at this point in the history
Addresses issues tiangolo#96 & tiangolo#164
  • Loading branch information
chriswhite199 committed Nov 24, 2021
1 parent 02da85c commit c3f8340
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 13 deletions.
17 changes: 4 additions & 13 deletions sqlmodel/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,18 +31,9 @@
from pydantic.main import BaseConfig, ModelMetaclass, validate_model
from pydantic.typing import ForwardRef, NoArgAnyCallable, resolve_annotations
from pydantic.utils import ROOT_KEY, Representation
from sqlalchemy import (
Boolean,
Column,
Date,
DateTime,
Float,
ForeignKey,
Integer,
Interval,
Numeric,
inspect,
)
from sqlalchemy import Boolean, Column, Date, DateTime
from sqlalchemy import Enum as sa_Enum
from sqlalchemy import Float, ForeignKey, Integer, Interval, Numeric, inspect
from sqlalchemy.orm import RelationshipProperty, declared_attr, registry, relationship
from sqlalchemy.orm.attributes import set_attribute
from sqlalchemy.orm.decl_api import DeclarativeMeta
Expand Down Expand Up @@ -389,7 +380,7 @@ def get_sqlachemy_type(field: ModelField) -> Any:
if issubclass(field.type_, time):
return Time
if issubclass(field.type_, Enum):
return Enum
return sa_Enum(field.type_)
if issubclass(field.type_, bytes):
return LargeBinary
if issubclass(field.type_, Decimal):
Expand Down
72 changes: 72 additions & 0 deletions tests/test_enums.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
import enum
import uuid

from sqlalchemy import create_mock_engine
from sqlalchemy.sql.type_api import TypeEngine
from sqlmodel import Field, SQLModel

"""
Tests related to Enums
Associated issues:
* https://github.com/tiangolo/sqlmodel/issues/96
* https://github.com/tiangolo/sqlmodel/issues/164
"""


class MyEnum1(enum.Enum):
A = "A"
B = "B"


class MyEnum2(enum.Enum):
C = "C"
D = "D"


class BaseModel(SQLModel):
id: uuid.UUID = Field(primary_key=True)
enum_field: MyEnum2


class FlatModel(SQLModel, table=True):
id: uuid.UUID = Field(primary_key=True)
enum_field: MyEnum1


class InheritModel(BaseModel, table=True):
pass


def pg_dump(sql: TypeEngine, *args, **kwargs):
dialect = sql.compile(dialect=postgres_engine.dialect)
sql_str = str(dialect).rstrip()
if sql_str:
print(sql_str + ";")


def sqlite_dump(sql: TypeEngine, *args, **kwargs):
dialect = sql.compile(dialect=sqlite_engine.dialect)
sql_str = str(dialect).rstrip()
if sql_str:
print(sql_str + ";")


postgres_engine = create_mock_engine("postgresql://", pg_dump)
sqlite_engine = create_mock_engine("sqlite://", sqlite_dump)


def test_postgres_ddl_sql(capsys):
SQLModel.metadata.create_all(bind=postgres_engine, checkfirst=False)

captured = capsys.readouterr()
assert "CREATE TYPE myenum1 AS ENUM ('A', 'B');" in captured.out
assert "CREATE TYPE myenum2 AS ENUM ('C', 'D');" in captured.out


def test_sqlite_ddl_sql(capsys):
SQLModel.metadata.create_all(bind=sqlite_engine, checkfirst=False)

captured = capsys.readouterr()
assert "enum_field VARCHAR(1) NOT NULL" in captured.out
assert "CREATE TYPE" not in captured.out

0 comments on commit c3f8340

Please sign in to comment.