Create signalbus.py, move code all over the place
epandurski committed Feb 9, 2019
1 parent d4fe69a · commit 3075497
Showing 9 changed files with 322 additions and 303 deletions.
275 changes: 9 additions & 266 deletions flask_signalbus/__init__.py
@@ -1,266 +1,9 @@
"""
Adds to Flask-SQLAlchemy the capability to atomically send
messages (signals) over a message bus.
"""

import time
import logging
from sqlalchemy import event, inspect, and_
from sqlalchemy.exc import DBAPIError
from flask_signalbus.utils import retry_on_deadlock, get_db_error_code, DEADLOCK_ERROR_CODES, DBSerializationError

__all__ = ['SignalBus', 'SignalBusMixin']


SIGNALS_TO_FLUSH_SESSION_INFO_KEY = 'flask_signalbus__signals_to_flush'
FLUSHMANY_LIMIT = 1000


def _raise_error_if_not_signal_model(model):
if not hasattr(model, 'send_signalbus_message'):
raise RuntimeError(
'{} can not be flushed because it does not have a'
' "send_signalbus_message" method.'
)


class SignalBusMixin(object):
"""A **mixin class** that can be used to extend the
`flask_sqlalchemy.SQLAlchemy` class to handle signals.
For example::
from flask import Flask
from flask_sqlalchemy import SQLAlchemy
from flask_signalbus import SignalBusMixin
class CustomSQLAlchemy(SignalBusMixin, SQLAlchemy):
pass
app = Flask(__name__)
db = CustomSQLAlchemy(app)
db.signalbus.flush()
"""

def init_app(self, app, *args, **kwargs):
super(SignalBusMixin, self).init_app(app, *args, **kwargs)
self.signalbus._init_app(app)

@property
def signalbus(self):
"""The associated `SignalBus` object."""

try:
signalbus = self.__signalbus
except AttributeError:
signalbus = self.__signalbus = SignalBus(self, init_app=False)
return signalbus


class SignalBus(object):
"""Instances of this class automatically send signal messages that
have been recorded in the SQL database, over a message
bus. Normally, the sending of the recorded messages (if there are
any) is done after each transaction commit, but it also can be
triggered explicitly by a command.
:param db: The `flask_sqlalchemy.SQLAlchemy` instance
For example::
from flask_sqlalchemy import SQLAlchemy
from flask_signalbus import SignalBus
app = Flask(__name__)
db = SQLAlchemy(app)
signalbus = SignalBus(db)
signalbus.flush()
"""

def __init__(self, db, init_app=True):
self.db = db
self.signal_session = self.db.create_scoped_session({'expire_on_commit': False})
self.logger = logging.getLogger(__name__)
self._autoflush = True
retry = retry_on_deadlock(self.signal_session, retries=10, max_wait=1.0)
self._flush_signals_with_retry = retry(self._flush_signals)
event.listen(self.db.session, 'transient_to_pending', self._transient_to_pending_handler)
event.listen(self.db.session, 'after_commit', self._safe_after_commit_handler)
event.listen(self.db.session, 'after_rollback', self._after_rollback_handler)
if init_app:
if db.app is None:
raise RuntimeError(
'No application found. The SQLAlchemy instance passed to'
' SignalBus should be constructed with an application.'
)
self._init_app(db.app)

@property
def autoflush(self):
"""Setting this property to `False` instructs the `SignalBus` instance
to not automatically flush pending signals after each
transaction commit. Setting it back to `True` restores the
default behavior.
"""

return self._autoflush

@autoflush.setter
def autoflush(self, value):
self._autoflush = bool(value)

def get_signal_models(self):
"""Return all signal types in a list.
:rtype: list(`signal-model`)
"""

base = self.db.Model
return [
cls for cls in base._decl_class_registry.values() if (
isinstance(cls, type)
and issubclass(cls, base)
and hasattr(cls, 'send_signalbus_message')
)
]

def flush(self, models=None, wait=3.0):
"""Send all pending signals over the message bus.
:param models: If passed, flushes only signals of the specified types.
:type models: list(`signal-model`) or `None`
:param float wait: The number of seconds the method will wait
after obtaining the list of pending signals, to allow
concurrent senders to complete
:return: The total number of signals that have been sent
"""

models_to_flush = self.get_signal_models() if models is None else models
pks_to_flush = {}
try:
for model in models_to_flush:
_raise_error_if_not_signal_model(model)
m = inspect(model)
pk_attrs = [m.get_property_by_column(c).class_attribute for c in m.primary_key]
pks_to_flush[model] = self.signal_session.query(*pk_attrs).all()
self.signal_session.rollback()
time.sleep(wait)
return sum(
self._flush_signals_with_retry(model, pk_values_set=set(pks_to_flush[model]))
for model in models_to_flush
)
finally:
self.signal_session.remove()

def flushmany(self):
"""Send a potentially huge number of pending signals over the message bus.
This method assumes that the number of pending signals might
be huge, so that they might not fit into memory. However,
`SignalBus.flushmany` is not very smart in handling concurrent
senders. It is mostly useful when recovering from long periods
of disconnectedness from the message bus.
:return: The total number of signals that have been sent
"""

models_to_flush = self.get_signal_models()
try:
return sum(self._flushmany_signals(model) for model in models_to_flush)
finally:
self.signal_session.remove()

def _init_app(self, app):
from . import cli

if not hasattr(app, 'extensions'):
app.extensions = {}
if app.extensions.get('signalbus') not in [None, self]:
raise RuntimeError('Can not attach more than one SignalBus to one application.')
app.extensions['signalbus'] = self
app.cli.add_command(cli.signalbus)

@app.teardown_appcontext
def shutdown_signal_session(response_or_exc):
self.signal_session.remove()
return response_or_exc

def _transient_to_pending_handler(self, session, instance):
model = type(instance)
if hasattr(model, 'send_signalbus_message') and getattr(model, 'signalbus_autoflush', True):
signals_to_flush = session.info.setdefault(SIGNALS_TO_FLUSH_SESSION_INFO_KEY, set())
signals_to_flush.add(instance)

def _after_commit_handler(self, session):
signals_to_flush = session.info.pop(SIGNALS_TO_FLUSH_SESSION_INFO_KEY, set())
if self.autoflush and signals_to_flush:
signals = [self.signal_session.merge(s, load=False) for s in signals_to_flush]
for signal in signals:
try:
signal.send_signalbus_message()
except Exception:
self.logger.exception('Caught error while sending %s.', signal)
self.signal_session.rollback()
return
self.signal_session.delete(signal)
self.signal_session.commit()
self.signal_session.expire_all()

def _after_rollback_handler(self, session):
session.info.pop(SIGNALS_TO_FLUSH_SESSION_INFO_KEY, None)

def _safe_after_commit_handler(self, session):
try:
return self._after_commit_handler(session)
except DBAPIError as e:
if get_db_error_code(e.orig) not in DEADLOCK_ERROR_CODES:
self.logger.exception('Caught database error during autoflush.')
self.signal_session.rollback()

def _get_lock_for_update(self, pk_attrs, pk_values):
clause = and_(*[attr == value for attr, value in zip(pk_attrs, pk_values)])
query = self.signal_session.query(*pk_attrs).filter(clause).with_for_update()
return query.one_or_none() is not None

def _flushmany_signals(self, model):
self.logger.warning('Flushing %s in "flushmany" mode.', model.__name__)
sent_count = 0
while True:
n = self._flush_signals(model, max_count=FLUSHMANY_LIMIT)
sent_count += n
if n < FLUSHMANY_LIMIT:
break
return sent_count

def _flush_signals(self, model, pk_values_set=None, max_count=None):
query = self.signal_session.query(model)
if max_count is None:
self.logger.info('Flushing %s.', model.__name__)
else:
query = query.limit(max_count)
signals = query.all()
self.signal_session.commit()
burst_count = int(getattr(model, 'signalbus_burst_count', 1))
sent_count = 0
m = inspect(model)
pk_attrs = [m.get_property_by_column(c).class_attribute for c in m.primary_key]
for signal in signals:
pk_values = m.primary_key_from_instance(signal)
if pk_values_set is None or pk_values in pk_values_set:
if not self._get_lock_for_update(pk_attrs, pk_values):
# The row has been deleted by a concurrent sender.
continue
signal.send_signalbus_message()
self.signal_session.delete(signal)
sent_count += 1
if sent_count % burst_count == 0:
self.signal_session.commit()
self.signal_session.commit()
self.signal_session.expire_all()
return sent_count
from flask_signalbus.signalbus import ( # noqa: F401
SignalBus,
SignalBusMixin,
)

from flask_signalbus.atomic import ( # noqa: F401
AtomicProceduresMixin,
ShardingKeyGenerationMixin,
)
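
For reference, the classes re-exported above now live in flask_signalbus/signalbus.py, so existing `from flask_signalbus import SignalBus, SignalBusMixin` imports keep working after the move. The sketch below is not part of this commit; it only illustrates how the relocated code is meant to be used, assuming the `CustomSQLAlchemy` setup from the `SignalBusMixin` docstring above. The `MessageSignal` model and the `deliver_to_message_broker()` helper are illustrative assumptions, not names from this repository::

    def deliver_to_message_broker(payload):
        # Stand-in for real message-bus code (for example, an AMQP publish).
        print('publishing:', payload)

    # A "signal model" is any mapped class with a send_signalbus_message()
    # method; signalbus_autoflush and signalbus_burst_count are the optional
    # attributes honored by the flushing code shown in this diff.
    class MessageSignal(db.Model):
        id = db.Column(db.Integer, primary_key=True)
        payload = db.Column(db.Text, nullable=False)

        signalbus_autoflush = True    # flush right after each commit
        signalbus_burst_count = 100   # delete sent rows in bursts of 100

        def send_signalbus_message(self):
            deliver_to_message_broker(self.payload)

    # Record a signal in a normal transaction. With autoflush enabled it is
    # sent (and its row deleted) just after the commit; an explicit flush()
    # delivers whatever is still pending.
    db.session.add(MessageSignal(payload='{"event": "account_created"}'))
    db.session.commit()
    db.signalbus.flush()
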
20 changes: 14 additions & 6 deletions flask_signalbus/atomic.py
@@ -1,13 +1,21 @@
"""
Adds to Flask-SQLAlchemy simple but powerful utilities for
creating consistent and correct database APIs.
"""

import os
import struct
from functools import wraps
from contextlib import contextmanager
from sqlalchemy.sql.expression import and_
from sqlalchemy.inspection import inspect
from sqlalchemy.exc import IntegrityError
from flask_signalbus import DBSerializationError, retry_on_deadlock
from flask_signalbus.utils import DBSerializationError, retry_on_deadlock

__all__ = ['AtomicProceduresMixin', 'ShardingKeyGenerationMixin']


ATOMIC_FLAG_SESSION_INFO_KEY = 'flask_signalbus__atomic_flag'
_ATOMIC_FLAG_SESSION_INFO_KEY = 'flask_signalbus__atomic_flag'


@contextmanager
@@ -110,10 +118,10 @@ def result():

session = self.session
session_info = session.info
assert not session_info.get(ATOMIC_FLAG_SESSION_INFO_KEY), \
assert not session_info.get(_ATOMIC_FLAG_SESSION_INFO_KEY), \
'"execute_atomic" calls can not be nested'
func = retry_on_deadlock(session)(__func__)
session_info[ATOMIC_FLAG_SESSION_INFO_KEY] = True
session_info[_ATOMIC_FLAG_SESSION_INFO_KEY] = True
try:
result = func(*args, **kwargs)
session.commit()
@@ -122,7 +130,7 @@ def result():
session.rollback()
raise
finally:
session_info[ATOMIC_FLAG_SESSION_INFO_KEY] = False
session_info[_ATOMIC_FLAG_SESSION_INFO_KEY] = False

def modification(self, func):
"""Raise assertion error if `func` is called outside of atomic block.
Expand All @@ -136,7 +144,7 @@ def modification(self, func):

@wraps(func)
def wrapper(*args, **kwargs):
assert self.session.info.get(ATOMIC_FLAG_SESSION_INFO_KEY), \
assert self.session.info.get(_ATOMIC_FLAG_SESSION_INFO_KEY), \
'calls to "{}" must be wrapped in "execute_atomic"'.format(func.__name__)
return func(*args, **kwargs)

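The atomic.py changes above only rename the session-info flag to a private name (_ATOMIC_FLAG_SESSION_INFO_KEY) and import the deadlock helpers from flask_signalbus.utils instead of the package root. The sketch below is not part of this commit; it illustrates how that flag is exercised, assuming a `db` object whose class mixes in `AtomicProceduresMixin` and that `execute_atomic` can be applied as a decorator, as the `def result():` context lines above suggest. `Account` and `add_to_balance` are illustrative names, not definitions from this repository::

    class Account(db.Model):
        id = db.Column(db.Integer, primary_key=True)
        balance = db.Column(db.Integer, nullable=False, default=0)

    # modification() refuses to run unless the private atomic flag is set,
    # i.e. unless we are already inside an execute_atomic call.
    @db.modification
    def add_to_balance(account_id, amount):
        account = Account.query.get(account_id)
        account.balance += amount

    # execute_atomic sets the flag, runs the function with deadlock retries,
    # commits on success, rolls back on error, and rejects nested calls.
    @db.execute_atomic
    def result():
        add_to_balance(1, 100)
        return 'done'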
