Skip to content
This repository
Browse code

First stab at django-like generic foreign keys

  • Loading branch information...
commit 075bf0fc92a88882bc1a6c892589a65818e7bb8b 1 parent 2fc4e82
Charles Leifer authored October 17, 2012
98  playhouse/gfk.py
... ...
@@ -0,0 +1,98 @@
  1
+"""
  2
+class Tag(Model):
  3
+    tag = CharField()
  4
+    object_type = CharField(null=True)
  5
+    object_id = IntegerField(null=True)
  6
+    object = GFKField('object_type', 'object_id')
  7
+
  8
+class Blog(Model):
  9
+    tags = ReverseGFK(Tag, 'object_type', 'object_id')
  10
+
  11
+tag.object -> should be a blog
  12
+blog.tags -> select query of tags for ``blog`` instance
  13
+Blog.tags -> select query of all tags for Blog instances
  14
+"""
  15
+from peewee import *
  16
+from peewee import FieldDescriptor, SelectQuery, UpdateQuery, Model as _Model, BaseModel as _BaseModel
  17
+
  18
+
  19
+all_models = set()
  20
+table_cache = {}
  21
+
  22
+
  23
+class BaseModel(_BaseModel):
  24
+    def __new__(cls, name, bases, attrs):
  25
+        cls = super(BaseModel, cls).__new__(cls, name, bases, attrs)
  26
+        all_models.add(cls)
  27
+        return cls
  28
+
  29
+class Model(_Model):
  30
+    __metaclass__ = BaseModel
  31
+
  32
+def get_model(tbl_name):
  33
+    if tbl_name not in table_cache:
  34
+        for model in all_models:
  35
+            if model._meta.db_table == tbl_name:
  36
+                table_cache[tbl_name] = model
  37
+                break
  38
+    return table_cache.get(tbl_name)
  39
+
  40
+class GFKField(object):
  41
+    def __init__(self, model_type_field='object_type', model_id_field='object_id'):
  42
+        self.model_type_field = model_type_field
  43
+        self.model_id_field = model_id_field
  44
+        self.att_name = '.'.join((self.model_type_field, self.model_id_field))
  45
+
  46
+    def __get__(self, instance, instance_type=None):
  47
+        if instance:
  48
+            if self.att_name not in instance._obj_cache:
  49
+                inst_data = instance._data
  50
+                if inst_data.get(self.model_type_field) and inst_data.get(self.model_id_field):
  51
+                    tbl_name = instance._data[self.model_type_field]
  52
+                    model_class = get_model(tbl_name)
  53
+                    if not model_class:
  54
+                        raise AttributeError('Model for table "%s" not found in GFK lookup' % tbl_name)
  55
+
  56
+                    instance._obj_cache[self.att_name] = model_class.select().where(
  57
+                        model_class._meta.primary_key == instance._data[self.model_id_field]
  58
+                    ).get()
  59
+            return instance._obj_cache.get(self.att_name)
  60
+        return self.field
  61
+
  62
+    def __set__(self, instance, value):
  63
+        instance._obj_cache[self.att_name] = value
  64
+        instance._data[self.model_type_field] = value._meta.db_table
  65
+        instance._data[self.model_id_field] = value.get_id()
  66
+
  67
+class ReverseGFK(object):
  68
+    def __init__(self, model, model_type_field='object_type', model_id_field='object_id'):
  69
+        self.model_class = model
  70
+        self.model_type_field = model._meta.fields[model_type_field]
  71
+        self.model_id_field = model._meta.fields[model_id_field]
  72
+
  73
+    def __get__(self, instance, instance_type=None):
  74
+        if instance:
  75
+            return self.model_class.select().where(
  76
+                (self.model_type_field == instance._meta.db_table) &
  77
+                (self.model_id_field == instance.get_id())
  78
+            )
  79
+        else:
  80
+            return self.model_class.select().where(
  81
+                self.model_type_field == instance_type._meta.db_table
  82
+            )
  83
+
  84
+    def __set__(self, instance, value):
  85
+        mtv = instance._meta.db_table
  86
+        miv = instance.get_id()
  87
+        if isinstance(value, SelectQuery) and value.model_class == self.model_class:
  88
+            uq = UpdateQuery(self.model_class, {
  89
+                self.model_type_field: mtv,
  90
+                self.model_id_field: miv,
  91
+            }).where(value._where).execute()
  92
+        elif all(map(lambda i: isinstance(i, self.model_class), value)):
  93
+            for obj in value:
  94
+                setattr(obj, self.model_type_field.name, mtv)
  95
+                setattr(obj, self.model_id_field.name, miv)
  96
+                obj.save()
  97
+        else:
  98
+            raise ValueError('ReverseGFK field unable to handle "%s"' % value)
137  playhouse/tests_gfk.py
... ...
@@ -0,0 +1,137 @@
  1
+import unittest
  2
+
  3
+from peewee import *
  4
+from playhouse.gfk import *
  5
+
  6
+
  7
+db = SqliteDatabase(':memory:')
  8
+
  9
+class BaseModel(Model):
  10
+    class Meta:
  11
+        database = db
  12
+
  13
+    def add_tag(self, tag):
  14
+        t = Tag(tag=tag)
  15
+        t.object = self
  16
+        t.save()
  17
+        return t
  18
+
  19
+class Tag(BaseModel):
  20
+    tag = CharField()
  21
+
  22
+    object_type = CharField(null=True)
  23
+    object_id = IntegerField(null=True)
  24
+    object = GFKField()
  25
+
  26
+    class Meta:
  27
+        order_by = ('tag',)
  28
+
  29
+
  30
+class Appetizer(BaseModel):
  31
+    name = CharField()
  32
+    tags = ReverseGFK(Tag)
  33
+
  34
+class Entree(BaseModel):
  35
+    name = CharField()
  36
+    tags = ReverseGFK(Tag)
  37
+
  38
+class Dessert(BaseModel):
  39
+    name = CharField()
  40
+    tags = ReverseGFK(Tag)
  41
+
  42
+
  43
+
  44
+class GFKTestCase(unittest.TestCase):
  45
+    data = {
  46
+        Appetizer: (
  47
+            ('wings', ('fried', 'spicy')),
  48
+            ('mozzarella sticks', ('fried', 'sweet')),
  49
+            ('potstickers', ('fried',)),
  50
+            ('edamame', ('salty',)),
  51
+        ),
  52
+        Entree: (
  53
+            ('phad thai', ('spicy',)),
  54
+            ('fried chicken', ('fried', 'salty')),
  55
+            ('tacos', ('fried', 'spicy')),
  56
+        ),
  57
+        Dessert: (
  58
+            ('sundae', ('sweet',)),
  59
+            ('churro', ('fried', 'sweet')),
  60
+        )
  61
+    }
  62
+    def setUp(self):
  63
+        Tag.create_table(True)
  64
+        Appetizer.create_table(True)
  65
+        Entree.create_table(True)
  66
+        Dessert.create_table(True)
  67
+
  68
+    def tearDown(self):
  69
+        Tag.drop_table()
  70
+        Appetizer.drop_table()
  71
+        Entree.drop_table()
  72
+        Dessert.drop_table()
  73
+
  74
+    def create(self):
  75
+        for model, foods in self.data.items():
  76
+            for name, tags in foods:
  77
+                inst = model.create(name=name)
  78
+                for tag in tags:
  79
+                    inst.add_tag(tag)
  80
+
  81
+    def test_creation(self):
  82
+        t = Tag.create(tag='a tag')
  83
+        t.object = t
  84
+        t.save()
  85
+
  86
+        t_db = Tag.get(Tag.id == t.id)
  87
+        self.assertEqual(t_db.object_id, t_db.get_id())
  88
+        self.assertEqual(t_db.object_type, 'tag')
  89
+        self.assertEqual(t_db.object, t_db)
  90
+
  91
+    def test_gfk_api(self):
  92
+        self.create()
  93
+
  94
+        # test instance api
  95
+        for model, foods in self.data.items():
  96
+            for food, tags in foods:
  97
+                inst = model.get(model.name == food)
  98
+                self.assertEqual([t.tag for t in inst.tags], list(tags))
  99
+
  100
+        # test class api and ``object`` api
  101
+        apps_tags = [(t.tag, t.object.name) for t in Appetizer.tags.order_by(Tag.id)]
  102
+        data_tags = []
  103
+        for food, tags in self.data[Appetizer]:
  104
+            for t in tags:
  105
+                data_tags.append((t, food))
  106
+
  107
+        self.assertEqual(apps_tags, data_tags)
  108
+
  109
+    def test_missing(self):
  110
+        t = Tag.create(tag='sour')
  111
+        self.assertEqual(t.object, None)
  112
+
  113
+        t.object_type = 'appetizer'
  114
+        t.object_id = 1
  115
+        # accessing the descriptor will raise a DoesNotExist
  116
+        self.assertRaises(Appetizer.DoesNotExist, getattr, t, 'object')
  117
+
  118
+        t.object_type = 'unknown'
  119
+        t.object_id = 1
  120
+        self.assertRaises(AttributeError, getattr, t, 'object')
  121
+
  122
+    def test_set_reverse(self):
  123
+        # assign query
  124
+        e = Entree.create(name='phad thai')
  125
+        s = Tag.create(tag='spicy')
  126
+        p = Tag.create(tag='peanuts')
  127
+        t = Tag.create(tag='thai')
  128
+        b = Tag.create(tag='beverage')
  129
+
  130
+        e.tags = Tag.select().where(Tag.tag != 'beverage')
  131
+        self.assertEqual([t.tag for t in e.tags], ['peanuts', 'spicy', 'thai'])
  132
+
  133
+        e = Entree.create(name='panang curry')
  134
+        c = Tag.create(tag='coconut')
  135
+
  136
+        e.tags = [p, t, c, s]
  137
+        self.assertEqual([t.tag for t in e.tags], ['coconut', 'peanuts', 'spicy', 'thai'])
3  runtests.py
@@ -32,8 +32,9 @@ def get_option_parser():
32 32
 
33 33
     if options.all or options.extra:
34 34
         modules = [tests]
35  
-        from playhouse import tests_signals
  35
+        from playhouse import tests_signals, tests_gfk
36 36
         modules.append(tests_signals)
  37
+        modules.append(tests_gfk)
37 38
 
38 39
         #from playhouse import tests as extras_tests
39 40
         #modules.append(extras_tests)

0 notes on commit 075bf0f

Please sign in to comment.
Something went wrong with that request. Please try again.