From b0c5cb534480109c4757cbcee395ef04dbf60437 Mon Sep 17 00:00:00 2001 From: Ben Hoyt Date: Thu, 4 Oct 2012 15:45:58 +1300 Subject: [PATCH] Initial commit from blog entry --- .gitattributes | 1 + .gitignore | 1 + README.md | 8 ++ mro.py | 337 +++++++++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 347 insertions(+) create mode 100644 .gitattributes create mode 100644 .gitignore create mode 100644 README.md create mode 100644 mro.py diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 0000000..176a458 --- /dev/null +++ b/.gitattributes @@ -0,0 +1 @@ +* text=auto diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..0d20b64 --- /dev/null +++ b/.gitignore @@ -0,0 +1 @@ +*.pyc diff --git a/README.md b/README.md new file mode 100644 index 0000000..be63ecf --- /dev/null +++ b/README.md @@ -0,0 +1,8 @@ +MRO: Map Rows to Objects with web.py +==================================== + +MRO is not an ORM. Well, not really -- it's a (slightly too) minimalist mapper +of rows to objects for use with Python and web.py. + +Read my [original blog entry](http://blog.brush.co.nz/2010/01/mro/) or check +out the [source](https://github.com/benhoyt/mro/blob/master/mro.py). diff --git a/mro.py b/mro.py new file mode 100644 index 0000000..7a22e38 --- /dev/null +++ b/mro.py @@ -0,0 +1,337 @@ +""" MRO: Map Rows to Objects with web.py. + +See docstrings and the UserTest class for examples. Or read my blog entry about +MRO which lives at: + http://blog.brush.co.nz/2010/01/mro/ + +MRO is released under the 3-clause "New BSD license", and is copyright (c) 2009 +Brush Technology. See the full text of the license at: + http://www.opensource.org/licenses/bsd-license.php + +""" + +import datetime +import web + +class Column(object): + """ Defines a column in a Table object. """ + + def __init__(self, sql_type=None, indexed=False, primary_key=False, secondary_key=False, **kwargs): + """ Define a column inside a Table object with an SQL type of sql_type + (if not given, the column's _sql_type attribute is used). + + If "indexed" is True, an index for this column will be created. If + "indexed" is a string, it will be used as an SQL indexing function, + for example indexed='LOWER(username)'. + + "primary_key" is True if this column is an (integer) primary key, + or "secondary_key" can be True if this column is a unique, string + secondary key (like username, email address, or slug). + + Other keyword args are converted to SQL constraints as follows: + underscores in key name are replaced with spaces and it's converted + to uppercase. If value is a bool and True, key name is added to the + constraints, otherwise value is converted to a string and appended + to the constraint. For example: + + >>> column = Column('TEXT', not_null=True, default="'NZD'") + >>> print column._sql_type, column._constraints + TEXT DEFAULT 'NZD' NOT NULL + """ + if sql_type is not None: + self._sql_type = sql_type + self._indexed = indexed + constraints = [] + self._primary_key = primary_key + if primary_key: + constraints.append('PRIMARY KEY') + self._secondary_key = secondary_key + if secondary_key: + constraints.append('NOT NULL UNIQUE') + self._indexed = True + for name, value in sorted(kwargs.iteritems()): + name = name.replace('_', ' ').upper() + if isinstance(value, bool) and value: + constraints.append(name) + else: + constraints.append('%s %s' % (name, value)) + self._constraints = ' '.join(constraints) + +class Serial(Column): + _sql_type = 'SERIAL' + +class Integer(Column): + _sql_type = 'INTEGER' + +class String(Column): + _sql_type = 'TEXT' + +class Date(Column): + _sql_type = 'DATE' + +class Timestamp(Column): + _sql_type = 'TIMESTAMP WITHOUT TIME ZONE' + +class Inet(Column): + _sql_type = 'INET' + +class Table(object): + """ Defines a database table with its columns. See the UserTest class for + an example, and see __init__'s docstring for examples of how to use the + constructor. + """ + + def __init__(self, _init=None, _fromdb=False, _test=False, **kwargs): + """ Initialise a row object of this table from the given _init data. + If _init is None, fields are taken from kwargs. If _init is a dict, + fields are taken from it. Otherwise _init is assumed to be a + primary key if it's an int or a secondary key if it's a non-int, + and fields are loaded from the database (raising a KeyError if no + rows match the given key). + + _fromdb is used internally to signal that these values have been + loaded from the database. + + >>> UserTest() + UserTest() + >>> UserTest({'username': 'bill', 'hash': '4321'}) + UserTest(hash='4321', username='bill') + >>> u = UserTest(username='bob', hash='1234') + >>> u + UserTest(hash='1234', username='bob') + >>> print u.save(_test=True) + INSERT INTO users (username, hash) VALUES ('bob', '1234') + + >>> u = UserTest(5, _test=True) + SELECT * FROM users WHERE id = 5 + >>> u = UserTest('bob', _test=True) + SELECT * FROM users WHERE username = 'bob' + + >>> u = UserTest('baduser', _test=[{}, {}]) + Traceback (most recent call last): + ... + KeyError: "no users (or more than one) with username of 'baduser'" + """ + self._changed = set() + self._init_columns() + if _init is None: + _init = kwargs + elif not isinstance(_init, dict): + key_name = self._primary_key if isinstance(_init, int) else self._secondary_key + select = web.select(self._table, where='%s = $key' % key_name, vars={'key': _init}, _test=_test) + if _test: + print select + select = _test if not isinstance(_test, bool) else [{}] + rows = list(select) + if len(rows) != 1: + raise KeyError('no %s (or more than one) with %s of %r' % (self._table, key_name, _init)) + _init = rows[0] + self.setattrs(_init) + if _fromdb: + self._changed.clear() + + @classmethod + def get(cls, key, _test=False): + """ Get and return a single row from the database given a primary or + secondary key, returning None if no rows match the given key + (unlike __init__, which raises a KeyError). + + >>> u = UserTest.get(5, _test=True) + SELECT * FROM users WHERE id = 5 + >>> u = UserTest.get('bob', _test=True) + SELECT * FROM users WHERE username = 'bob' + >>> print UserTest.get('baduser', _test=[{}, {}]) + SELECT * FROM users WHERE username = 'baduser' + None + """ + try: + return cls(key, _fromdb=True, _test=_test) + except KeyError: + return None + + @classmethod + def select(cls, _test=False, **kwargs): + """ Select and return multiple rows from the database via the web.py + SQL-like query given via kwargs. For example: + + >>> print UserTest.select(where='username LIKE $u', vars={'u': 'jo%'}, order='username', limit=5, _test=True) + SELECT * FROM users WHERE username LIKE 'jo%' ORDER BY username LIMIT 5 + """ + select = web.select(cls._table, _test=_test, **kwargs) + return select if _test else [cls(row, _fromdb=True) for row in select] + + @classmethod + def _column_sql(cls, name, column): + """ Return the SQL (column_sql, index_sql) required to create the given + column and its index, if any. """ + sql = '%s %s%s' % (name, column._sql_type, ' ' + column._constraints if column._constraints else '') + index = '' + if column._indexed: + func = name if isinstance(column._indexed, bool) else column._indexed + index = 'CREATE INDEX %s_%s_idx ON %s (%s);' % (cls._table, name, cls._table, func) + else: + index = '' + return sql, index + + @classmethod + def create(cls, _test=False): + """ Create the table and its indexes based on the column description. + + >>> print UserTest.create(_test=True) + CREATE TABLE users ( + hash TEXT, + id SERIAL PRIMARY KEY, + time TIMESTAMP WITHOUT TIME ZONE DEFAULT now() NOT NULL, + username TEXT NOT NULL UNIQUE); + CREATE INDEX users_username_idx ON users (username); + """ + columns = [] + indexes = [] + for name, value in sorted(cls.__dict__.iteritems()): + if isinstance(value, Column): + column, index = cls._column_sql(name, value) + columns.append(' ' + column) + if index: + indexes.append(index) + sql = 'CREATE TABLE %s (\n' % cls._table + ',\n'.join(columns) + ');\n' + '\n'.join(indexes) + return web.query(sql, _test=_test) + + @classmethod + def add_column(cls, name, _test=False): + """ Add a column to the table (after table has been created). + + >>> print UserTest.add_column('username', _test=True) + ALTER TABLE users ADD COLUMN username TEXT NOT NULL UNIQUE; + CREATE INDEX users_username_idx ON users (username); + >>> print UserTest.add_column('hash', _test=True) + ALTER TABLE users ADD COLUMN hash TEXT; + """ + column = getattr(cls, name) + column, index = cls._column_sql(name, column) + sql = 'ALTER TABLE %s ADD COLUMN %s;' % (cls._table, column) + if index: + sql += '\n' + index + return web.query(sql, _test=_test) + + def save(self, _test=False): + """ Save this row to the database: update row (only changed fields) if + primary key attribute has been set, otherwise insert a new row. + + >>> u = UserTest(username='bob', hash='asdf') + >>> print u.save(_test=True) + INSERT INTO users (username, hash) VALUES ('bob', 'asdf') + >>> u = UserTest(id=5) + >>> u.username = 'bill' + >>> print u.save(_test=True) + UPDATE users SET username = 'bill' WHERE id = 5 + """ + if not isinstance(getattr(self, self._primary_key), (Column, type(None))): + return self.update(key_name=self._primary_key, _test=_test) + else: + return self.insert(_test=_test) + + def insert(self, _test=False): + """ Insert current row as a new row into the database. + + >>> u = UserTest(username='bob', hash='asdf') + >>> print u.insert(_test=True) + INSERT INTO users (username, hash) VALUES ('bob', 'asdf') + """ + changed = self._changed_values() + if self._primary_key in changed: + del changed[self._primary_key] + return web.insert(self._table, _test=_test, **changed) + + def update(self, key_name=None, _test=False): + """ Update row (only changed fields). Primary key is used unless + another key_name is specified. + + >>> u = UserTest(id=5) + >>> u.username = 'bill' + >>> print u.update(_test=True) + UPDATE users SET username = 'bill' WHERE id = 5 + """ + changed = self._changed_values() + if key_name is None: + key_name = self._primary_key + if key_name in changed: + del changed[key_name] + return web.update(self._table, where='%s = $key' % key_name, vars={'key': getattr(self, key_name)}, _test=_test, **changed) + + def delete(self, _test=False): + """ Delete this row based on its primary key. + + >>> u = UserTest(id=5) + >>> print u.delete(_test=True) + DELETE FROM users WHERE id = 5 + """ + key_name = self._primary_key + return web.delete(self._table, where='%s = $key' % key_name, vars={'key': getattr(self, key_name)}, _test=_test) + + def setattrs(self, d): + """ Set fields of self from key/value pairs in given dict. + + >>> u = UserTest() + >>> u + UserTest() + >>> u.setattrs({'username': 'bob', 'hash': '1234'}) + >>> u + UserTest(hash='1234', username='bob') + """ + for name, value in d.iteritems(): + setattr(self, name, value) + + def _init_columns(self): + """ Initialise self's columns list and primary/secondary key fields. """ + cls = self.__class__ + self._columns = [] + for name, value in sorted(cls.__dict__.iteritems()): + if isinstance(value, Column): + self._columns.append((name, value)) + if value._primary_key: + self._primary_key = name + elif value._secondary_key: + self._secondary_key = name + + def __setattr__(self, name, value): + """ Override __setattr__ so we can tell which values have been changed + for insert or update. + """ + object.__setattr__(self, name, value) + if isinstance(getattr(self.__class__, name, None), Column): + self._changed.add(name) + + def _changed_values(self): + """ Return a list of changed values as (name, value) pairs. """ + return dict((name, getattr(self, name)) for name in self._changed) + + def __str__(self): + """ Return a more or less human-readable string representation of + given row, showing all fields that have been set. + + >>> u = UserTest(id=5) + >>> u.username = 'bob' + >>> u.hash = 'asdf' + >>> print str(u) + UserTest(hash='asdf', id=5, username='bob') + """ + args = [] + for name, column in self._columns: + value = getattr(self, name) + if not isinstance(value, Column): + args.append('%s=%r' % (name, value)) + return '%s(%s)' % (self.__class__.__name__, ', '.join(args)) + + __repr__ = __str__ + +class UserTest(Table): + """ Example "users" table, used by doctests. """ + _table = 'users' + id = Serial(primary_key=True) + username = String(secondary_key=True) + hash = String() + time = Timestamp(not_null=True, default='now()') + +if __name__ == '__main__': + import doctest + doctest.testmod()