Skip to content

Commit

Permalink
Various fixes and improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
evertheylen committed Aug 29, 2016
1 parent b9a70df commit b6aade3
Show file tree
Hide file tree
Showing 3 changed files with 90 additions and 24 deletions.
74 changes: 59 additions & 15 deletions sparrow/entity.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,10 @@ def __str__(self):
str: "VARCHAR",
float: "DOUBLE PRECISION",
bool: "BOOL",
datetime.datetime: "TIMESTAMP" # but consider perhaps amount of seconds since UNIX epoch
datetime.datetime: "TIMESTAMP", # but consider perhaps amount of seconds since UNIX epoch
}


class Type:
def __init__(self, python_type, sql_type=None):
if sql_type is not None:
Expand All @@ -86,9 +87,24 @@ def to_sql(self, obj: "self.python_type"):
def from_sql(self, obj):
return obj

constraint = None

def __str__(self):
return "Type({}, {})".format(str(self.python_type), self.sql_type)

class StaticType(Type):
pass

class _Json(StaticType):
@staticmethod
def to_sql(obj):
return json.dumps(obj)

@staticmethod
def from_sql(obj):
return json.loads(obj)

Json = _Json(str)

class List(Type):
def __init__(self, inner_type):
Expand All @@ -98,7 +114,7 @@ def __init__(self, inner_type):
self.inner_type = inner_type
self.sql_type = inner_type.sql_type + "[]"

def to_sql(self, obj: "self.python_type", f=lambda o: repr(o)):
def to_sql(self, obj: "self.python_type", f=repr):
return "{" + ", ".join([f(o) for o in obj]) + "}"

def __str__(self):
Expand All @@ -111,6 +127,12 @@ def __init__(self, *args):
self.inv_options = {val: num for num, val in enumerate(self.options)}
self.python_type = str

@property
def constraint(self):
def _constraint(val, _self=self):
return val in self.options
return _constraint

def __postinit__(self):
self.__postinited__ = True
self._create_type_command = RawSql("CREATE TYPE {s.name} AS ENUM ({opt})".format(
Expand Down Expand Up @@ -185,9 +207,15 @@ def _set_overloads(use_own: bool):
class Property(Queryable):
def __init__(self, typ, constraint: types.FunctionType = None, sql_extra: str = "",
required: bool = True, json: bool = True):
if not isinstance(typ, Type):
if not isinstance(typ, (Type, StaticType)):
typ = Type(typ)
self.type = typ
if self.type.constraint is not None:
if constraint is None:
constraint = self.type.constraint
else:
constraint = lambda val: self.type.constraint(val) and constraint(val)

self.constraint = constraint
self.sql_extra = sql_extra
self.required = required
Expand Down Expand Up @@ -540,7 +568,7 @@ def __metainit__(obj, db_args=None, json_dict=None, **kwargs):
for (i, p) in enumerate(init_raw_ref_properties, len(init_properties)):
obj.__dict__[p.dataname] = db_args[i]
elif json_dict is not None:
used = set()
#used = set()
obj.in_db = False
# Init from a (more raw) dictionary, possibly unsafe
for (p, constrained) in init_properties:
Expand Down Expand Up @@ -652,7 +680,7 @@ def __call__(self, *args, **kwargs):

class Entity(metaclass=MetaEntity):
"""Central class for an Entity.
NOTE: Be careful with changing the key as it will fuck with caching.
WARNING: Be careful with changing the key as it will fuck with caching.
Basically, don't do it.
"""

Expand All @@ -661,8 +689,10 @@ class Entity(metaclass=MetaEntity):
def __init__(self, *args, **kwargs):
self.__metainit__(*args, **kwargs)

async def insert(self, db: Database, replace=False):
async def insert(self, db: Database = None, replace=False):
"""Insert in database."""
if db is None:
db = GlobalDb.get()

if self.key is None:
assert type(self)._incomplete
Expand All @@ -676,7 +706,9 @@ async def insert(self, db: Database, replace=False):
else:
await self._simple_insert(db, replace)

async def _simple_insert(self, db: Database, replace=False):
async def _simple_insert(self, db: Database = None, replace=False):
if db is None:
db = GlobalDb.get()
self.check()
assert not self.in_db
cls = type(self)
Expand All @@ -701,9 +733,10 @@ async def _simple_insert(self, db: Database, replace=False):
if val in rt_ref.ref.cache:
rt_ref.ref.cache[val].new_reference(rt_ref, self)

async def update(self, db):
async def update(self, db=None):
"""Update object in the database."""

if db is None:
db = GlobalDb.get()
self.check()
assert self.in_db
dct = {}
Expand All @@ -714,9 +747,10 @@ async def update(self, db):
await type(self)._update_command.with_data(**dct).exec(db)


async def delete(self, db):
async def delete(self, db=None):
"""Delete object from the database."""

if db is None:
db = GlobalDb.get()
assert self.in_db
dct = {}
for p in type(self).key.referencing_props():
Expand Down Expand Up @@ -750,8 +784,11 @@ def get(cls: MetaEntity, *where_clauses: list) -> Sql:
return Select(cls, where_clauses)

@classmethod
async def find_by_key(cls: MetaEntity, key, db: Database) -> "cls":
async def find_by_key(cls: MetaEntity, key, db: Database = None) -> "cls":
"""Works different from `get`, as it will immediatly return the object"""
if db is None:
db = GlobalDb.get()

try:
return cls.cache[key]
except KeyError:
Expand Down Expand Up @@ -817,17 +854,24 @@ def __init__(self, *args, **kwargs):
self._listeners = set()
super(RTEntity, self).__init__(*args, **kwargs)

async def update(self, db: Database):
async def update(self, db: Database = None):
if db is None:
db = GlobalDb.get()

await super(RTEntity, self).update(db)
for l in self._listeners:
l.update(self)

def send_update(self, db):
def send_update(self, db = None):
if db is None:
db = GlobalDb.get()
"""To manually send messages to all listeners. Won't save to database."""
for l in self._listeners:
l.update(self)

async def delete(self, db):
async def delete(self, db = None):
if db is None:
db = GlobalDb.get()
await super(RTEntity, self).delete(db)
for l in self._listeners:
l.delete(self)
Expand Down
4 changes: 3 additions & 1 deletion sparrow/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,14 @@ class SparrowModel:
"""
The central class that keeps everything together.
"""
def __init__(self, ioloop, db_args, classes, debug=True, db=None):
def __init__(self, ioloop, db_args, classes, debug=True, db=None, set_global_db=False):
self.ioloop = ioloop
if db is not None:
self.db = db
else:
self.db = Database(ioloop, **db_args)
if set_global_db:
GlobalDb.set(self.db)
self.classes = classes
self.debug = debug

Expand Down
36 changes: 28 additions & 8 deletions sparrow/sql.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,4 @@


"""
"""

from functools import wraps
import copy

Expand Down Expand Up @@ -45,19 +39,41 @@ def __init__(self, ioloop, dbname, user="postgres", password="postgres", host="l
dbname=dbname, user=user, password=password, host=host, port=port)
self.pdb = momoko.Pool(dsn=dsn, size=momoko_poolsize, ioloop=ioloop)
self.pdb.connect()


async def get_cursor(self, statement: "Sql", unsafe_dict: dict):
statement = str(statement)
cursor = await self.pdb.execute(statement, unsafe_dict)
return cursor

class GlobalDb:
db = None

@classmethod
def get(cls):
return cls.db

@classmethod
def set(cls, db):
cls.db = db

@classmethod
def globalize(cls, db):
if db is not None:
return cls.db

# Helper classes for building a query
# -----------------------------------

class Unsafe:
"""Wrapper for unsafe data. (For data that needs to inserted later, use Field.)"""

def __new__(typ, val, *args, **kwargs):
if isinstance(val, Unsafe):
return val
obj = object.__new__(typ, *args, **kwargs)
return obj

def __init__(self, value):
self.key = str(id(self))
self.text = "%({0})s".format(self.key)
Expand Down Expand Up @@ -140,7 +156,9 @@ def count(self):

def _wrapper_sqlresult(method):
@wraps(method)
async def wrapper(self, db: Database, *args, **kwargs):
async def wrapper(self, db: Database = None, *args, **kwargs):
if db is None:
db = GlobalDb.get()
result = await self.exec(db)
return method(result, *args, **kwargs)
wrapper.__doc__ += "\n\nWrapped version, first argument is the database."
Expand All @@ -162,8 +180,10 @@ def __preinit__(self):
# By default, there is no class
cls = None

async def exec(self, db: Database):
async def exec(self, db: Database = None):
"""Execute the SQL statement on the given database."""
if db is None:
db = GlobalDb.get()
try:
return SqlResult(await db.get_cursor(str(self), self.data), self)
except psycopg2.Error as e:
Expand Down

0 comments on commit b6aade3

Please sign in to comment.