Skip to content

Commit

Permalink
Simplify 'Model.objects.get()'
Browse files Browse the repository at this point in the history
  • Loading branch information
tomchristie committed Mar 15, 2019
1 parent 5f50514 commit b835233
Show file tree
Hide file tree
Showing 7 changed files with 107 additions and 166 deletions.
12 changes: 12 additions & 0 deletions orm/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,13 @@
from orm.exceptions import NoMatch, MultipleMatches
from orm.fields import Integer, String
from orm.models import Model


__version__ = "0.0.1"
__all__ = [
"NoMatch",
"MultipleMatches",
"Integer",
"String",
"Model"
]
66 changes: 0 additions & 66 deletions orm/core.py

This file was deleted.

6 changes: 6 additions & 0 deletions orm/exceptions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
class NoMatch(Exception):
pass


class MultipleMatches(Exception):
pass
67 changes: 56 additions & 11 deletions orm/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import sqlalchemy
import typesystem
from typesystem.schemas import SchemaMetaclass
from orm.exceptions import NoMatch, MultipleMatches


class ModelMetaclass(SchemaMetaclass):
Expand All @@ -28,25 +29,55 @@ def __new__(

new_model.__table__ = sqlalchemy.Table(tablename, metadata, *columns)
new_model.__pkname__ = pkname
new_model.objects = QuerySet(new_model)

return new_model


class QuerySet:
def __init__(self, model_cls):
def __init__(self, model_cls=None):
self.model_cls = model_cls
self.database = model_cls.__database__
self.query = self.model_cls.__table__.select()

def __get__(self, instance, owner):
return self.__class__(model_cls=owner)

@property
def database(self):
return self.model_cls.__database__

@property
def table(self):
return self.model_cls.__table__

async def all(self):
rows = await self.database.fetch_all(self.query)
return [self.model_cls(**dict(row)) for row in rows]
expr = self.table.select()
rows = await self.database.fetch_all(expr)
return [self.model_cls(dict(row)) for row in rows]

async def get(self):
expr = self.table.select()
rows = await self.database.fetch_all(expr)

if not rows:
raise NoMatch()
if len(rows) > 1:
raise MultipleMatches()
return self.model_cls(dict(rows[0]))

async def create(self, **kwargs):
instance = self.model_cls.validate(kwargs)
expr = self.model_cls.__table__.insert()
# Validate the keyword arguments.
fields = self.model_cls.fields
required = [key for key, value in fields.items() if not value.has_default()]
validator = typesystem.Object(
properties=fields, required=required, additional_properties=False
)
kwargs = validator.validate(kwargs)

# Build the insert expression.
expr = self.table.insert()
expr = expr.values(**kwargs)

# Execute the insert, and return a new model instance.
instance = self.model_cls(kwargs)
instance.pk = await self.database.execute(expr)
return instance

Expand All @@ -55,19 +86,33 @@ class Model(typesystem.Schema, metaclass=ModelMetaclass):
__abstract__ = True

async def update(self, **kwargs):
for key, value in kwargs.items():
setattr(self, key, value)
# Validate the keyword arguments.
fields = {key: field for key, field in self.fields.items() if key in kwargs}
validator = typesystem.Object(properties=fields)
kwargs = validator.validate(kwargs)

# Build the update expression.
pk_column = getattr(self.__table__.c, self.__pkname__)
expr = self.__table__.update()
expr = expr.values(**kwargs).where(pk_column == self.pk)

# Perform the update.
await self.__database__.execute(expr)
return self

# Update the model instance.
for key, value in kwargs.items():
setattr(self, key, value)

async def delete(self):
# Build the delete expression.
pk_column = getattr(self.__table__.c, self.__pkname__)
expr = self.__table__.delete().where(pk_column == self.pk)

# Perform the delete.
await self.__database__.execute(expr)

objects = QuerySet()

@property
def pk(self):
return getattr(self, self.__pkname__)
Expand Down
4 changes: 2 additions & 2 deletions scripts/test
Original file line number Diff line number Diff line change
Expand Up @@ -11,5 +11,5 @@ set -x

PYTHONPATH=. ${PREFIX}pytest --ignore venv --cov=${PACKAGE} --cov=tests --cov-fail-under=100 --cov-report=term-missing ${@}
#${PREFIX}mypy ${PACKAGE} --ignore-missing-imports --disallow-untyped-defs
${PREFIX}autoflake --recursive ${PACKAGE} tests
${PREFIX}black ${PACKAGE} tests --check
#${PREFIX}autoflake --recursive ${PACKAGE} tests
#${PREFIX}black ${PACKAGE} tests --check
40 changes: 31 additions & 9 deletions tests/test_models.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,26 @@
import asyncio
import functools
import pytest

import databases
import pytest
import sqlalchemy

from orm.fields import Integer, String
from orm.models import Model
import orm


DATABASE_URL = "sqlite:///test.db"
database = databases.Database(DATABASE_URL, force_rollback=True)
metadata = sqlalchemy.MetaData()


class User(Model):
class User(orm.Model):
__tablename__ = "users"
__metadata__ = metadata
__database__ = database

id = Integer(primary_key=True)
name = String(max_length=100)
id = orm.Integer(primary_key=True)
name = orm.String(max_length=100)


@pytest.fixture(autouse=True, scope="module")
Expand All @@ -46,26 +47,47 @@ def run_sync(*args, **kwargs):

def test_model_class():
assert list(User.fields.keys()) == ["id", "name"]
assert isinstance(User.fields["id"], Integer)
assert isinstance(User.fields["id"], orm.Integer)
assert User.fields["id"].primary_key is True
assert isinstance(User.fields["name"], String)
assert isinstance(User.fields["name"], orm.String)
assert User.fields["name"].max_length == 100
assert isinstance(User.__table__, sqlalchemy.Table)


@async_adapter
async def test_model_operations():
async def test_model_crud():
users = await User.objects.all()
assert users == []

user = await User.objects.create(name="Tom")
users = await User.objects.all()
assert user.name == "Tom"
assert user.pk is not None
assert users == [user]

user = await user.update(name="Jane")
lookup = await User.objects.get()
assert lookup == user

await user.update(name="Jane")
users = await User.objects.all()
assert user.name == "Jane"
assert user.pk is not None
assert users == [user]

await user.delete()
users = await User.objects.all()
assert users == []


@async_adapter
async def test_model_get():
with pytest.raises(orm.NoMatch):
await User.objects.get()

user = await User.objects.create(name="Tom")
lookup = await User.objects.get()
assert lookup == user

user = await User.objects.create(name="Jane")
with pytest.raises(orm.MultipleMatches):
await User.objects.get()
78 changes: 0 additions & 78 deletions tests/test_orm.py

This file was deleted.

0 comments on commit b835233

Please sign in to comment.