Skip to content

Commit

Permalink
Add support for composite primary keys
Browse files Browse the repository at this point in the history
  • Loading branch information
gvx committed Jan 18, 2021
1 parent fd8ad77 commit 86f9d2c
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 8 deletions.
11 changes: 11 additions & 0 deletions tests/test_wurm.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,10 @@ class NoRowid(wurm.WithoutRowid):
key: wurm.Primary[str]
count: int

@dataclass
class CompositeKey(wurm.WithoutRowid):
part_one: wurm.Primary[int]
part_two: wurm.Primary[int]

@pytest.fixture
def connection():
Expand Down Expand Up @@ -289,3 +293,10 @@ def test_no_subclass_nonabstract_table():
@dataclass
class Point3D(Point):
z: int

def test_composite_key(connection):
CompositeKey(1, 2).insert()
CompositeKey(2, 1).insert()
CompositeKey(1, 1).insert()
with pytest.raises(wurm.WurmError):
CompositeKey(1, 2).insert()
8 changes: 3 additions & 5 deletions wurm/sql.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
from dataclasses import fields

from .typemaps import sql_type_for

def create(table):
return f'create table if not exists {table.__table_name__}({", ".join(make_field(f) for f in fields(table))})'
return f'create table if not exists {table.__table_name__}({", ".join(make_field(name, ty) for name, ty in table.__fields_info__.items())}, PRIMARY KEY ({", ".join(table.__primary_key__)}))'

def make_field(f):
return f'{f.name} {sql_type_for(f.type)}'
def make_field(name, ty):
return f'{name} {sql_type_for(ty)}'

def count(table, where=None):
if not where:
Expand Down
4 changes: 1 addition & 3 deletions wurm/typemaps.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,9 +76,7 @@ def sql_type_for(python_type):
postfix = ''
if get_origin(python_type) is Annotated:
python_type, *rest = get_args(python_type)
if any(_PrimaryMarker is arg for arg in rest):
postfix = ' PRIMARY KEY'
elif any(_UniqueMarker is arg for arg in rest):
if any(_UniqueMarker is arg for arg in rest):
postfix = ' UNIQUE'
return TYPE_MAPPING[python_type].sql_type + postfix

Expand Down

0 comments on commit 86f9d2c

Please sign in to comment.