From b4bbc78b58320a2d055ad02f1d4a6e19b8340062 Mon Sep 17 00:00:00 2001 From: Romain Dorgueil Date: Tue, 8 Nov 2016 13:37:55 +0100 Subject: [PATCH] models/getters: refactor Getter/Filter to be more flexible, use it in TextDimensionMixin. --- windflow/models/mixins.py | 16 ++++++------- windflow/models/utils.py | 48 ++++++++++++++++++++++++++------------- 2 files changed, 39 insertions(+), 25 deletions(-) diff --git a/windflow/models/mixins.py b/windflow/models/mixins.py index 149376c..019c42c 100644 --- a/windflow/models/mixins.py +++ b/windflow/models/mixins.py @@ -1,6 +1,7 @@ from sqlalchemy import Column, DateTime, Integer, String, func from sqlalchemy.ext.declarative import declared_attr -from windflow.models.utils import modelmethod +from windflow.models.utils import Filter, Getter +from windflow.utils import generate_repr_method, generate_str_method class TimestampableMixin(): @@ -25,12 +26,9 @@ class TextDimensionMixin(TimestampableMixin): def __tablename__(cls): return 'dim_' + cls.__name__.lower() - @modelmethod - def get(cls, session, value, mock=False): - obj = ((not mock) and session.query(cls).filter_by(value=value).first()) or cls(value=value) - if session: - session.add(obj) - return obj + @Getter(Filter('value', str)) + def get(cls, session, filters): + return session.query(cls).filter_by(**filters).first() - def __str__(self): - return self.value + __str__ = generate_str_method('value') + __repr__ = generate_repr_method('value') diff --git a/windflow/models/utils.py b/windflow/models/utils.py index 0c8581d..d7c3f88 100644 --- a/windflow/models/utils.py +++ b/windflow/models/utils.py @@ -57,10 +57,11 @@ def getter(cls, session, *values, create=True, **defaults): class Filter: - def __init__(self, name, instanceof, factory=None): + def __init__(self, name, instanceof, factory=None, required=True): self.name = name self.instanceof = instanceof self.factory = factory or instanceof + self.required = required def __call__(self, session, value, **kwargs): try: @@ -78,15 +79,6 @@ def _get_cache_dict(cls, session): return cache_holder._unique_cache -def _apply_filters(filters, values, create=True, session=None): - return { - filter.name: ( - values[i] if isinstance(values[i], filter.instanceof) - else filter(session, values[i], create=create) - ) for i, filter in enumerate(filters) - } - - class Getter: """ Flexible decorator to generate model getters. @@ -104,16 +96,40 @@ def mark_as_getter(cls, f): f.apply = lambda *args, **kwargs: functools.partial(f, *args, **kwargs) return f - def __wrapped_call__(self, model, session, *values, create=True, **defaults): + def apply_filters(self, filters, values, create=True, session=None, **defaults): + filtered = {} + i = 0 + + for filter in filters: + if filter.required: + try: + value = values[i] + except IndexError as e: + raise TypeError('Value for required filter "{}" is missing.'.format(filter.name)) + i += 1 + elif filter.name in defaults: + value = defaults.pop(filter.name) + else: + continue + + if filter.instanceof and isinstance(value, filter.instanceof): + filtered[filter.name] = value + else: + filtered[filter.name] = filter(session, value, create=create) + + return filtered, defaults + + def __wrapped_call__(self, model, session, *values, create=True, **named_values_and_defaults): assert session is None or isinstance(session, ( Session, scoped_session)), 'If provided, session should be an sqlalchemy session object.' - # cache ? - cache = _get_cache_dict(model, session) - cache_key = (model,) + values - # compute filter values (includes call to related models) - values = _apply_filters(self.filters, values, create=create, session=session) + values, defaults = self.apply_filters(self.filters, values, create=create, session=session, + **named_values_and_defaults) + + # get cache dictionary and key + cache = _get_cache_dict(model, session) + cache_key = (model,) + tuple(values.items()) # if no cache, delegate to real (decorated) Getter method if not cache_key in cache: