Skip to content

HTTPS clone URL

Subversion checkout URL

You can clone with HTTPS or Subversion.

Download ZIP
Browse files

First stab at django-like generic foreign keys

  • Loading branch information...
commit 075bf0fc92a88882bc1a6c892589a65818e7bb8b 1 parent 2fc4e82
Charles Leifer authored
Showing with 237 additions and 1 deletion.
  1. +98 −0 playhouse/gfk.py
  2. +137 −0 playhouse/tests_gfk.py
  3. +2 −1  runtests.py
98 playhouse/gfk.py
View
@@ -0,0 +1,98 @@
+"""
+class Tag(Model):
+ tag = CharField()
+ object_type = CharField(null=True)
+ object_id = IntegerField(null=True)
+ object = GFKField('object_type', 'object_id')
+
+class Blog(Model):
+ tags = ReverseGFK(Tag, 'object_type', 'object_id')
+
+tag.object -> should be a blog
+blog.tags -> select query of tags for ``blog`` instance
+Blog.tags -> select query of all tags for Blog instances
+"""
+from peewee import *
+from peewee import FieldDescriptor, SelectQuery, UpdateQuery, Model as _Model, BaseModel as _BaseModel
+
+
+all_models = set()
+table_cache = {}
+
+
+class BaseModel(_BaseModel):
+ def __new__(cls, name, bases, attrs):
+ cls = super(BaseModel, cls).__new__(cls, name, bases, attrs)
+ all_models.add(cls)
+ return cls
+
+class Model(_Model):
+ __metaclass__ = BaseModel
+
+def get_model(tbl_name):
+ if tbl_name not in table_cache:
+ for model in all_models:
+ if model._meta.db_table == tbl_name:
+ table_cache[tbl_name] = model
+ break
+ return table_cache.get(tbl_name)
+
+class GFKField(object):
+ def __init__(self, model_type_field='object_type', model_id_field='object_id'):
+ self.model_type_field = model_type_field
+ self.model_id_field = model_id_field
+ self.att_name = '.'.join((self.model_type_field, self.model_id_field))
+
+ def __get__(self, instance, instance_type=None):
+ if instance:
+ if self.att_name not in instance._obj_cache:
+ inst_data = instance._data
+ if inst_data.get(self.model_type_field) and inst_data.get(self.model_id_field):
+ tbl_name = instance._data[self.model_type_field]
+ model_class = get_model(tbl_name)
+ if not model_class:
+ raise AttributeError('Model for table "%s" not found in GFK lookup' % tbl_name)
+
+ instance._obj_cache[self.att_name] = model_class.select().where(
+ model_class._meta.primary_key == instance._data[self.model_id_field]
+ ).get()
+ return instance._obj_cache.get(self.att_name)
+ return self.field
+
+ def __set__(self, instance, value):
+ instance._obj_cache[self.att_name] = value
+ instance._data[self.model_type_field] = value._meta.db_table
+ instance._data[self.model_id_field] = value.get_id()
+
+class ReverseGFK(object):
+ def __init__(self, model, model_type_field='object_type', model_id_field='object_id'):
+ self.model_class = model
+ self.model_type_field = model._meta.fields[model_type_field]
+ self.model_id_field = model._meta.fields[model_id_field]
+
+ def __get__(self, instance, instance_type=None):
+ if instance:
+ return self.model_class.select().where(
+ (self.model_type_field == instance._meta.db_table) &
+ (self.model_id_field == instance.get_id())
+ )
+ else:
+ return self.model_class.select().where(
+ self.model_type_field == instance_type._meta.db_table
+ )
+
+ def __set__(self, instance, value):
+ mtv = instance._meta.db_table
+ miv = instance.get_id()
+ if isinstance(value, SelectQuery) and value.model_class == self.model_class:
+ uq = UpdateQuery(self.model_class, {
+ self.model_type_field: mtv,
+ self.model_id_field: miv,
+ }).where(value._where).execute()
+ elif all(map(lambda i: isinstance(i, self.model_class), value)):
+ for obj in value:
+ setattr(obj, self.model_type_field.name, mtv)
+ setattr(obj, self.model_id_field.name, miv)
+ obj.save()
+ else:
+ raise ValueError('ReverseGFK field unable to handle "%s"' % value)
137 playhouse/tests_gfk.py
View
@@ -0,0 +1,137 @@
+import unittest
+
+from peewee import *
+from playhouse.gfk import *
+
+
+db = SqliteDatabase(':memory:')
+
+class BaseModel(Model):
+ class Meta:
+ database = db
+
+ def add_tag(self, tag):
+ t = Tag(tag=tag)
+ t.object = self
+ t.save()
+ return t
+
+class Tag(BaseModel):
+ tag = CharField()
+
+ object_type = CharField(null=True)
+ object_id = IntegerField(null=True)
+ object = GFKField()
+
+ class Meta:
+ order_by = ('tag',)
+
+
+class Appetizer(BaseModel):
+ name = CharField()
+ tags = ReverseGFK(Tag)
+
+class Entree(BaseModel):
+ name = CharField()
+ tags = ReverseGFK(Tag)
+
+class Dessert(BaseModel):
+ name = CharField()
+ tags = ReverseGFK(Tag)
+
+
+
+class GFKTestCase(unittest.TestCase):
+ data = {
+ Appetizer: (
+ ('wings', ('fried', 'spicy')),
+ ('mozzarella sticks', ('fried', 'sweet')),
+ ('potstickers', ('fried',)),
+ ('edamame', ('salty',)),
+ ),
+ Entree: (
+ ('phad thai', ('spicy',)),
+ ('fried chicken', ('fried', 'salty')),
+ ('tacos', ('fried', 'spicy')),
+ ),
+ Dessert: (
+ ('sundae', ('sweet',)),
+ ('churro', ('fried', 'sweet')),
+ )
+ }
+ def setUp(self):
+ Tag.create_table(True)
+ Appetizer.create_table(True)
+ Entree.create_table(True)
+ Dessert.create_table(True)
+
+ def tearDown(self):
+ Tag.drop_table()
+ Appetizer.drop_table()
+ Entree.drop_table()
+ Dessert.drop_table()
+
+ def create(self):
+ for model, foods in self.data.items():
+ for name, tags in foods:
+ inst = model.create(name=name)
+ for tag in tags:
+ inst.add_tag(tag)
+
+ def test_creation(self):
+ t = Tag.create(tag='a tag')
+ t.object = t
+ t.save()
+
+ t_db = Tag.get(Tag.id == t.id)
+ self.assertEqual(t_db.object_id, t_db.get_id())
+ self.assertEqual(t_db.object_type, 'tag')
+ self.assertEqual(t_db.object, t_db)
+
+ def test_gfk_api(self):
+ self.create()
+
+ # test instance api
+ for model, foods in self.data.items():
+ for food, tags in foods:
+ inst = model.get(model.name == food)
+ self.assertEqual([t.tag for t in inst.tags], list(tags))
+
+ # test class api and ``object`` api
+ apps_tags = [(t.tag, t.object.name) for t in Appetizer.tags.order_by(Tag.id)]
+ data_tags = []
+ for food, tags in self.data[Appetizer]:
+ for t in tags:
+ data_tags.append((t, food))
+
+ self.assertEqual(apps_tags, data_tags)
+
+ def test_missing(self):
+ t = Tag.create(tag='sour')
+ self.assertEqual(t.object, None)
+
+ t.object_type = 'appetizer'
+ t.object_id = 1
+ # accessing the descriptor will raise a DoesNotExist
+ self.assertRaises(Appetizer.DoesNotExist, getattr, t, 'object')
+
+ t.object_type = 'unknown'
+ t.object_id = 1
+ self.assertRaises(AttributeError, getattr, t, 'object')
+
+ def test_set_reverse(self):
+ # assign query
+ e = Entree.create(name='phad thai')
+ s = Tag.create(tag='spicy')
+ p = Tag.create(tag='peanuts')
+ t = Tag.create(tag='thai')
+ b = Tag.create(tag='beverage')
+
+ e.tags = Tag.select().where(Tag.tag != 'beverage')
+ self.assertEqual([t.tag for t in e.tags], ['peanuts', 'spicy', 'thai'])
+
+ e = Entree.create(name='panang curry')
+ c = Tag.create(tag='coconut')
+
+ e.tags = [p, t, c, s]
+ self.assertEqual([t.tag for t in e.tags], ['coconut', 'peanuts', 'spicy', 'thai'])
3  runtests.py
View
@@ -32,8 +32,9 @@ def get_option_parser():
if options.all or options.extra:
modules = [tests]
- from playhouse import tests_signals
+ from playhouse import tests_signals, tests_gfk
modules.append(tests_signals)
+ modules.append(tests_gfk)
#from playhouse import tests as extras_tests
#modules.append(extras_tests)
Please sign in to comment.
Something went wrong with that request. Please try again.