Skip to content

Commit

Permalink
Kill ShardingKeyGenerationMixin
Browse files Browse the repository at this point in the history
  • Loading branch information
epandurski committed Feb 10, 2019
1 parent cdb666b commit 69f613f
Show file tree
Hide file tree
Showing 6 changed files with 40 additions and 46 deletions.
1 change: 0 additions & 1 deletion flask_signalbus/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,4 @@

from flask_signalbus.atomic import ( # noqa: F401
AtomicProceduresMixin,
ShardingKeyGenerationMixin,
)
70 changes: 31 additions & 39 deletions flask_signalbus/atomic.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,14 @@
creating consistent and correct database APIs.
"""

import os
import struct
from functools import wraps
from contextlib import contextmanager
from sqlalchemy.sql.expression import and_
from sqlalchemy.inspection import inspect
from sqlalchemy.exc import IntegrityError
from flask_signalbus.utils import DBSerializationError, retry_on_deadlock

__all__ = ['AtomicProceduresMixin', 'ShardingKeyGenerationMixin']
__all__ = ['AtomicProceduresMixin']


_ATOMIC_FLAG_SESSION_INFO_KEY = 'flask_signalbus__atomic_flag'
Expand Down Expand Up @@ -47,6 +45,34 @@ def _get_pk_values(cls, instance_or_pk):
instance_or_pk = inspect(cls).primary_key_from_instance(instance_or_pk)
return instance_or_pk if isinstance(instance_or_pk, tuple) else (instance_or_pk,)

@classmethod
def _conjure_instance(cls, *args, **kwargs):
"""Continuously try to create an instance, flush it to the database, and return it.
This is useful, for example, when a constructor is defined
that generates a random primary key, which is not guaranteed
to be unique.
Note: This method uses database savepoints to recover after
unsuccessful database flush. It will not work correctly on
databases that do not support savepoints.
"""

session = cls._flask_signalbus_sa.session
tries = kwargs.pop('__tries', 50)
for _ in range(tries):
instance = cls(*args, **kwargs)
session.begin_nested()
session.add(instance)
try:
session.commit()
except IntegrityError:
session.rollback()
continue
return instance
raise RuntimeError('Can not conjure an instance.')


class AtomicProceduresMixin(object):
"""Adds utility functions to :class:`~flask_sqlalchemy.SQLAlchemy` and the declarative base.
Expand Down Expand Up @@ -125,7 +151,7 @@ def wrapper(*args, **kwargs):

return wrapper

def execute_atomic(self, __func__, *args, **kwargs):
def execute_atomic(self, __func, *args, **kwargs):
"""A decorator that executes a function in an atomic block.
Example::
Expand All @@ -149,7 +175,7 @@ def result():
"""

return self.atomic(__func__)(*args, **kwargs)
return self.atomic(__func)(*args, **kwargs)

@contextmanager
def retry_on_integrity_error(self):
Expand Down Expand Up @@ -183,37 +209,3 @@ def retry_on_integrity_error(self):
session.flush()
except IntegrityError:
raise DBSerializationError


class ShardingKeyGenerationMixin(object):
"""Adds random sharding key generation functionality to a model.
The model should be defined as follows::
class SomeModelName(ShardingKeyGenerationMixin, db.Model):
sharding_key_value = db.Column(db.BigInteger, primary_key=True, autoincrement=False)
"""

def __init__(self, sharding_key_value=None):
modulo = 1 << 63
if sharding_key_value is None:
sharding_key_value = struct.unpack('>q', os.urandom(8))[0] % modulo or 1
assert 0 < sharding_key_value < modulo
self.sharding_key_value = sharding_key_value

@classmethod
def generate(cls, sharding_key_value=None, tries=50):
"""Create a unique instance and return its `sharding_key_value`."""

session = cls._flask_signalbus_sa.session
for _ in range(tries):
instance = cls(sharding_key_value=sharding_key_value)
session.begin_nested()
session.add(instance)
try:
session.commit()
except IntegrityError:
session.rollback()
continue
return instance.sharding_key_value
raise RuntimeError('Can not generate a unique sharding key.')
Empty file removed tests/__init__.py
Empty file.
8 changes: 5 additions & 3 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,10 +133,12 @@ class NonSignal(db.Model):


@pytest.fixture
def ShardingKey(db):
class ShardingKey(fsb.ShardingKeyGenerationMixin, db.Model):
def ShardingKey(atomic_db):
db = atomic_db

class ShardingKey(db.Model):
__tablename__ = 'test_sharding_key'
sharding_key_value = db.Column(db.BigInteger, primary_key=True, autoincrement=False)
id = db.Column(db.Integer, primary_key=True)

db.create_all()
yield ShardingKey
Expand Down
5 changes: 3 additions & 2 deletions tests/test_atomic.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,7 @@ def test_get_pk_values(atomic_db, AtomicModel):
assert AtomicModel._get_pk_values((1,)) == (1,)


@pytest.mark.skip('SQLite does not support savepoints')
def test_create_sharding_key(ShardingKey):
assert ShardingKey().sharding_key_value
assert hasattr(ShardingKey, 'generate')
id_ = ShardingKey._conjure_instance().id
assert type(id_) is int
2 changes: 1 addition & 1 deletion tests/test_signalbus.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import pytest
import flask_sqlalchemy as fsa
from flask_signalbus import SignalBus
from .conftest import SignalBusAlchemy
from conftest import SignalBusAlchemy


def test_create_signalbus_alchemy(app):
Expand Down

0 comments on commit 69f613f

Please sign in to comment.