diff --git a/botstory/ast/library.py b/botstory/ast/library.py index d57b25f..f484f7d 100644 --- a/botstory/ast/library.py +++ b/botstory/ast/library.py @@ -2,10 +2,12 @@ import itertools from . import parser +from .. import di logger = logging.getLogger(__name__) +@di.desc(reg=False) class StoriesLibrary: """ storage of all available stories diff --git a/botstory/ast/processor.py b/botstory/ast/processor.py index 96c8d58..669e8df 100644 --- a/botstory/ast/processor.py +++ b/botstory/ast/processor.py @@ -2,41 +2,28 @@ import inspect from . import parser, callable, forking -from .. import matchers +from .. import di, matchers from ..integrations import mocktracker logger = logging.getLogger(__name__) +@di.desc(reg=False) class StoryProcessor: def __init__(self, parser_instance, library, middlewares=[]): - self.interfaces = [] self.library = library self.middlewares = middlewares self.parser_instance = parser_instance - self.storage = None self.tracker = mocktracker.MockTracker() - def add_interface(self, interface): - if self.storage: - interface.add_storage(self.storage) - self.interfaces.append(interface) - interface.processor = self - - def add_storage(self, storage): - self.storage = storage - for interface in self.interfaces: - interface.add_storage(storage) - + @di.inject() def add_tracker(self, tracker): logger.debug('add_tracker') logger.debug(tracker) + if not tracker: + return self.tracker = tracker - def clear(self): - self.interfaces = [] - self.storage = None - async def match_message(self, message): """ because match_message is recursive we split function to @@ -49,6 +36,8 @@ async def match_message(self, message): logger.debug('> match_message <') logger.debug('') logger.debug(' {} '.format(message)) + logger.debug('self.tracker') + logger.debug(self.tracker) self.tracker.new_message( user=message and message['user'], data=message['data'], @@ -151,6 +140,9 @@ async def process_story(self, session, message, compiled_story, idx=0, story_arg story_part = story_line[idx] logger.debug(' going to call: {}'.format(story_part.__name__)) + + logger.debug('self.tracker') + logger.debug(self.tracker) self.tracker.story( user=message and message['user'], story_name=compiled_story.topic, diff --git a/botstory/ast/users.py b/botstory/ast/users.py index fddda35..f9a9464 100644 --- a/botstory/ast/users.py +++ b/botstory/ast/users.py @@ -1,10 +1,12 @@ import logging -from ..integrations.mocktracker import tracker +from .. import di +from ..integrations.mocktracker import tracker as tracker_module logger = logging.getLogger(__name__) _tracker = None +@di.inject() def add_tracker(tracker): logger.debug('add_tracker') logger.debug(tracker) @@ -26,7 +28,7 @@ def on_new_user_comes(user): def clear(): global _tracker - _tracker = tracker.MockTracker() + _tracker = tracker_module.MockTracker() clear() diff --git a/botstory/chat.py b/botstory/chat.py index 577bbcd..59621b8 100644 --- a/botstory/chat.py +++ b/botstory/chat.py @@ -70,13 +70,6 @@ async def send_text_message_to_all_interfaces(*args, **kwargs): return res -def add_http(http): - logger.debug('add_http') - logger.debug(http) - for _, interface in interfaces.items(): - interface.add_http(http) - - def add_interface(interface): logger.debug('add_interface') logger.debug(interface) diff --git a/botstory/di/__init__.py b/botstory/di/__init__.py new file mode 100644 index 0000000..662b4f0 --- /dev/null +++ b/botstory/di/__init__.py @@ -0,0 +1,14 @@ +from . import desciption as desc_module, inject as inject_module, injector_service + +__all__ = [] + +injector = injector_service.Injector() + +bind = injector.bind +child_scope = injector.child_scope +clear_instances = injector.clear_instances +desc = desc_module.desc +get = injector.get +inject = inject_module.inject + +__all__.extend([bind, child_scope, clear_instances, desc, inject]) diff --git a/botstory/di/desciption.py b/botstory/di/desciption.py new file mode 100644 index 0000000..c03dc09 --- /dev/null +++ b/botstory/di/desciption.py @@ -0,0 +1,25 @@ +import inspect +from .parser import camel_case_to_underscore +from .. import di + + +def desc(t=None, reg=True): + """ + Describe Class Dependency + + :param reg: should we register this class as well + :param t: custom type as well + :return: + """ + + def decorated_fn(cls): + if not inspect.isclass(cls): + return NotImplemented('For now we can only describe classes') + name = t or camel_case_to_underscore(cls.__name__)[0] + if reg: + di.injector.register(name, cls) + else: + di.injector.describe(name, cls) + return cls + + return decorated_fn diff --git a/botstory/di/inject.py b/botstory/di/inject.py new file mode 100644 index 0000000..1c583cf --- /dev/null +++ b/botstory/di/inject.py @@ -0,0 +1,21 @@ +import inspect + +from .parser import camel_case_to_underscore + +from .. import di + + +def inject(t=None): + def decorated_fn(fn): + if inspect.isclass(fn): + name = t or camel_case_to_underscore(fn.__name__)[0] + print('register {} on name {}'.format(fn, name)) + di.injector.register(name, fn) + elif inspect.isfunction(fn): + di.injector.requires(fn) + else: + # I'm not sure whether it possible case + raise NotImplementedError('try decorate {}'.format(fn)) + return fn + + return decorated_fn diff --git a/botstory/di/inject_test.py b/botstory/di/inject_test.py new file mode 100644 index 0000000..18b92c2 --- /dev/null +++ b/botstory/di/inject_test.py @@ -0,0 +1,138 @@ +import pytest +from .. import di + + +def test_inject_decorator(): + with di.child_scope(): + @di.inject() + class OneClass: + def __init__(self): + pass + + assert isinstance(di.injector.get('one_class'), OneClass) + + +def test_bind_singleton_instance_by_default(): + with di.child_scope(): + @di.inject() + class OneClass: + def __init__(self): + pass + + assert di.injector.get('one_class') == di.injector.get('one_class') + + +def test_inject_into_method_of_class(): + with di.child_scope(): + @di.inject() + class OuterClass: + @di.inject() + def inner(self, inner_class): + self.inner_class = inner_class + + @di.inject() + class InnerClass: + pass + + outer = di.injector.get('outer_class') + assert isinstance(outer, OuterClass) + assert isinstance(outer.inner_class, InnerClass) + + +def test_bind_should_inject_deps_in_decorated_methods_(): + with di.child_scope(): + @di.inject() + class OuterClass: + @di.inject() + def inner(self, inner_class): + self.inner_class = inner_class + + @di.inject() + class InnerClass: + pass + + outer = di.bind(OuterClass()) + assert isinstance(outer, OuterClass) + assert isinstance(outer.inner_class, InnerClass) + + +def test_inject_default_value_if_we_dont_have_dep(): + with di.child_scope(): + @di.inject() + class OuterClass: + @di.inject() + def inner(self, inner_class='Hello World!'): + self.inner_class = inner_class + + outer = di.bind(OuterClass()) + assert isinstance(outer, OuterClass) + assert outer.inner_class == 'Hello World!' + + +def test_no_autoupdate_deps_on_new_instance_comes(): + with di.child_scope(): + @di.inject() + class OuterClass: + @di.inject() + def inner(self, inner_class=None): + self.inner_class = inner_class + + outer = di.bind(OuterClass(), auto=False) + + @di.inject() + class InnerClass: + pass + + assert isinstance(outer, OuterClass) + assert outer.inner_class is None + + +def test_autoupdate_deps_on_new_instance_comes(): + with di.child_scope(): + @di.inject() + class OuterClass: + @di.inject() + def inner(self, inner_class=None): + self.inner_class = inner_class + + outer = di.bind(OuterClass(), auto=True) + + @di.inject() + class InnerClass: + pass + + assert isinstance(outer, OuterClass) + assert isinstance(outer.inner_class, InnerClass) + + +def test_fail_on_cyclic_deps(): + with di.child_scope(): + @di.inject() + class FirstClass: + @di.inject() + def deps(self, second_class=None): + self.second_class = second_class + + @di.inject() + class SecondClass: + @di.inject() + def deps(self, first_class=None): + self.first_class = first_class + + first_class = di.injector.get('first_class') + assert isinstance(first_class.second_class, SecondClass) + assert isinstance(first_class.second_class.first_class, FirstClass) + + +def test_custom_type(): + with di.child_scope(): + @di.inject('qwerty') + class OneClass: + pass + + assert isinstance(di.injector.get('qwerty'), OneClass) + + +def test_fail_on_incorrect_using(): + with pytest.raises(NotImplementedError): + di.inject()('qwerty') diff --git a/botstory/di/injector_service.py b/botstory/di/injector_service.py new file mode 100644 index 0000000..c0bf145 --- /dev/null +++ b/botstory/di/injector_service.py @@ -0,0 +1,247 @@ +import logging +import inspect + +from . import parser + +logger = logging.getLogger(__name__) + + +class Scope: + def __init__(self, name, parent=None): + self.storage = {} + # instances that will auto update on each new instance come + self.auto_bind_list = [] + # description of classes + self.described = {} + # functions that waits for deps + self.requires_fns = {} + # all instances that are singletones + self.singleton_cache = {} + self.name = name + # parent scope + self.parent = parent + + def get(self, type_name): + try: + item = self.storage[type_name] + if inspect.isclass(item): + return item() + else: + return item + except KeyError as err: + if self.parent: + return self.parent.get(type_name) + raise err + + def get_instance(self, type_name): + try: + return self.singleton_cache[type_name] + except KeyError as err: + if self.parent: + return self.parent.get_instance(type_name) + raise err + + def describe(self, type_name, cls): + self.described[cls] = { + 'type': type_name, + } + + def get_auto_bind_list(self): + yield from self.auto_bind_list + if self.parent: + yield from self.parent.get_auto_bind_list() + + def auto_bind(self, instance): + self.auto_bind_list.append(instance) + + def get_description(self, value): + try: + return self.described.get(value, self.described[type(value)]) + except KeyError as err: + if self.parent: + return self.parent.get_description(value) + raise err + + def store_instance(self, type_name, instance): + self.singleton_cache[type_name] = instance + + def get_endpoint_deps(self, method_ptr): + return self.requires_fns.get(method_ptr, {}).items() or \ + self.parent and self.parent.get_endpoint_deps(method_ptr) or \ + [] + + def store_deps_endpoint(self, fn, deps): + self.requires_fns[fn] = deps + + def clear(self): + self.auto_bind_list = [] + self.described = {} + self.requires_fns = {} + self.singleton_cache = {} + + def clear_instances(self): + self.auto_bind_list = [] + self.singleton_cache = {} + self.storage = {} + + def register(self, type_name, value): + if type_name in self.storage: + self.remove_type(type_name) + + self.storage[type_name] = value + + def remove_type(self, type_name): + self.storage.pop(type_name, True) + instance = self.singleton_cache.pop(type_name, True) + # TODO: clear self.requires_fns and self.auto_update_list maybe we can use type(instance)? + + def __repr__(self): + return ' {}'.format(self.name, { + 'storage': self.storage, + 'auto_update_list': self.auto_bind_list, + 'described': self.described, + 'requires_fns': self.requires_fns, + 'singleton_cache': self.singleton_cache, + }) + + +def empty_array_if_empty(default): + return [default] if default is not inspect.Parameter.empty else [] + + +class MissedDescriptionError(Exception): + pass + + +class Injector: + def __init__(self): + self.root = Scope('root') + self.current_scope = self.root + + def describe(self, type_name, cls): + """ + add description of class + + :param type_name: + :param cls: + :return: + """ + + self.current_scope.describe(type_name, cls) + + def register(self, type_name=None, instance=None): + print() + print('register {} = {}'.format(type_name, instance)) + if not isinstance(type_name, str) and type_name is not None: + raise ValueError('type_name parameter should be string or None') + if type_name is None: + try: + desc = self.current_scope.get_description(instance) + except KeyError: + # TODO: should raise exception + # raise MissedDescriptionError('{} was not registered'.format(instance)) + # print('self.described') + # print(self.current_scope.described) + # print('type(instance)') + # print(type(instance)) + # print('self.described.get(type(instance))') + # print(self.current_scope.described.get(type(instance))) + print('{} was not registered'.format(instance)) + return None + type_name = desc['type'] + # print('> before store {} = {}'.format(type_name, instance)) + self.current_scope.register(type_name, instance) + for wait_instance in self.current_scope.get_auto_bind_list(): + self.bind(wait_instance) + + def requires(self, fn): + fn_sig = inspect.signature(fn) + self.current_scope.store_deps_endpoint(fn, { + key: {'default': default for default in empty_array_if_empty(fn_sig.parameters[key].default)} + for key in fn_sig.parameters.keys() if key != 'self'}) + + def entrypoint_deps(self, method_ptr): + # we should have something to inject + # registered instance, class or default value + if not any(self.get(dep) or 'default' in dep_spec + for dep, dep_spec + in self.current_scope.get_endpoint_deps(method_ptr)): + # otherwise we should inject anything + return {} + + return {dep: self.get(dep) or dep_spec['default'] + for dep, dep_spec in self.current_scope.get_endpoint_deps(method_ptr)} + + def bind(self, instance, auto=False): + """ + Bind deps to instance + + :param instance: + :param auto: follow update of DI and refresh binds once we will get something new + :return: + """ + methods = [ + (m, cls.__dict__[m]) + for cls in inspect.getmro(type(instance)) + for m in cls.__dict__ if inspect.isfunction(cls.__dict__[m]) + ] + + try: + deps_of_endpoints = [(method_ptr, self.entrypoint_deps(method_ptr)) + for (method_name, method_ptr) in methods] + + for (method_ptr, method_deps) in deps_of_endpoints: + if len(method_deps) > 0: + method_ptr(instance, **method_deps) + except KeyError: + pass + + if auto and instance not in self.current_scope.get_auto_bind_list(): + self.current_scope.auto_bind(instance) + + return instance + + def get(self, type_name): + type_name = parser.kebab_to_underscore(type_name) + try: + return self.current_scope.get_instance(type_name) + except KeyError: + try: + instance = self.current_scope.get(type_name) + except KeyError: + # TODO: sometimes we should fail loudly in this case + return None + + self.current_scope.store_instance(type_name, instance) + instance = self.bind(instance) + + return instance + + def child_scope(self, name='undefined'): + return ChildScopeBuilder(self, self.current_scope, name) + + def clear_instances(self): + self.current_scope.clear_instances() + + def add_scope(self, scope): + self.current_scope = scope + + def remove_scope(self, scope): + assert self.current_scope == scope + self.current_scope = scope.parent + + +class ChildScopeBuilder: + def __init__(self, injector, parent, name): + self.injector = injector + self.name = name + self.parent = parent + + def __enter__(self): + self.scope = Scope(self.name, self.parent) + self.injector.add_scope(self.scope) + return self.scope + + def __exit__(self, exc_type, exc_val, exc_tb): + self.scope.clear() + self.injector.remove_scope(self.scope) diff --git a/botstory/di/injector_service_test.py b/botstory/di/injector_service_test.py new file mode 100644 index 0000000..4ae7c70 --- /dev/null +++ b/botstory/di/injector_service_test.py @@ -0,0 +1,182 @@ +import pytest +from .. import di + + +def test_injector_get(): + with di.child_scope(): + di.injector.register('once_instance', 'Hello World!') + assert di.injector.get('once_instance') == 'Hello World!' + + +def test_lazy_description_should_not_register_class(): + with di.child_scope(): + @di.desc(reg=False) + class OneClass: + pass + + assert di.injector.get('one_class') is None + + +def test_lazy_description_should_simplify_registration(): + with di.child_scope(): + @di.desc(reg=False) + class OneClass: + pass + + di.injector.register(instance=OneClass()) + + assert isinstance(di.injector.get('one_class'), OneClass) + + +def test_not_lazy_description_should_simplify_registration(): + with di.child_scope(): + @di.desc(reg=True) + class OneClass: + pass + + assert isinstance(di.injector.get('one_class'), OneClass) + + +def test_fail_if_type_is_not_string(): + with di.child_scope(): + class OneClass: + pass + + with pytest.raises(ValueError): + di.injector.register(OneClass) + + +def test_kebab_string_style_is_synonym_to_underscore(): + with di.child_scope(): + @di.desc() + class OneClass: + pass + + assert isinstance(di.injector.get('one-class'), OneClass) + + +def test_later_binding(): + with di.child_scope(): + @di.desc() + class OuterClass: + @di.inject() + def deps(self, test_class): + self.test_class = test_class + + @di.desc('test_class', reg=False) + class InnerClass: + pass + + outer = OuterClass() + di.injector.register(instance=outer) + di.bind(outer, auto=True) + + inner = InnerClass() + di.injector.register(instance=inner) + + assert outer.test_class == inner + + +def test_overwrite_previous_singleton_instance(): + with di.child_scope(): + @di.desc('test_class') + class FirstClass: + pass + + first_class = di.get('test_class') + + @di.desc('test_class') + class SecondClass: + pass + + second_class = di.get('test_class') + + assert first_class != second_class + assert isinstance(first_class, FirstClass) + assert isinstance(second_class, SecondClass) + + +def test_inherit_scope(): + with di.child_scope('first'): + @di.desc() + class First: + pass + + with di.child_scope('second'): + @di.desc() + class Second: + @di.inject() + def deps(self, first): + self.first = first + + second = di.get('second') + assert isinstance(second, Second) + assert isinstance(second.first, First) + + +def test_do_not_call_deps_endpoint_before_we_have_all_needed_deps(): + with di.child_scope(): + @di.desc() + class Container: + def __init__(self): + self.one = 'undefined' + self.two = 'undefined' + + @di.inject() + def deps(self, one, two): + self.one = one + self.two = two + + container = di.get('container') + assert container.one == 'undefined' + assert container.two == 'undefined' + + @di.desc() + class One: + pass + + di.bind(container) + assert container.one == 'undefined' + assert container.two == 'undefined' + + @di.desc() + class Two: + pass + + di.bind(container) + assert isinstance(container.one, One) + assert isinstance(container.two, Two) + + +def test_do_not_call_deps_endpoint_before_we_have_all_needed_deps_or_default_value(): + with di.child_scope(): + @di.desc() + class Container: + def __init__(self): + self.one = 'undefined' + self.two = 'undefined' + + @di.inject() + def deps(self, one, two='default'): + self.one = one + self.two = two + + container = di.get('container') + assert container.one == 'undefined' + assert container.two == 'undefined' + + @di.desc() + class One: + pass + + di.bind(container) + assert isinstance(container.one, One) + assert container.two == 'default' + + @di.desc() + class Two: + pass + + di.bind(container) + assert isinstance(container.one, One) + assert isinstance(container.two, Two) diff --git a/botstory/di/parser.py b/botstory/di/parser.py new file mode 100644 index 0000000..4aabe7c --- /dev/null +++ b/botstory/di/parser.py @@ -0,0 +1,37 @@ +import re + + +def camel_case_to_underscore(class_name): + """Converts normal class names into normal arg names. + Normal class names are assumed to be CamelCase with an optional leading + underscore. Normal arg names are assumed to be lower_with_underscores. + Args: + class_name: a class name, e.g., "FooBar" or "_FooBar" + Returns: + all likely corresponding arg names, e.g., ["foo_bar"] + + based on: + + """ + parts = [] + rest = class_name + if rest.startswith('_'): + rest = rest[1:] + while True: + m = re.match(r'([A-Z][a-z]+)(.*)', rest) + if m is None: + break + parts.append(m.group(1)) + rest = m.group(2) + if not parts: + return [] + return ['_'.join(part.lower() for part in parts)] + + +def kebab_to_underscore(s): + """ + Convert kebab-styled-string to underscore_styled_string + :param s: + :return: + """ + return s.replace('-', '_') diff --git a/botstory/di/parser_test.py b/botstory/di/parser_test.py new file mode 100644 index 0000000..1d549ab --- /dev/null +++ b/botstory/di/parser_test.py @@ -0,0 +1,17 @@ +from .parser import camel_case_to_underscore, kebab_to_underscore + + +def test_camelcase_to_underscore(): + assert camel_case_to_underscore('ClassName')[0] == 'class_name' + + +def test_remove_leading_underscore(): + assert camel_case_to_underscore('_ClassName')[0] == 'class_name' + + +def test_should_return_empty_array_if_no_any_class_name_here(): + assert camel_case_to_underscore('_qwerty') == [] + + +def test_kebab_to_underscore(): + assert kebab_to_underscore('hello-world') == 'hello_world' diff --git a/botstory/integrations/aiohttp/aiohttp.py b/botstory/integrations/aiohttp/aiohttp.py index ac2e156..f8cec34 100644 --- a/botstory/integrations/aiohttp/aiohttp.py +++ b/botstory/integrations/aiohttp/aiohttp.py @@ -7,6 +7,7 @@ from yarl import URL from ..commonhttp import errors as common_errors, statuses +from ... import di logger = logging.getLogger(__name__) @@ -24,9 +25,8 @@ def is_ok(status): return 200 <= status < 400 +@di.desc('http', reg=False) class AioHttpInterface: - type = 'interface.aiohttp' - def __init__(self, host='0.0.0.0', port=None, shutdown_timeout=60.0, ssl_context=None, backlog=128, auto_start=True, @@ -154,15 +154,16 @@ async def method(self, method_type, session, url, **kwargs): message=await resp.text(), ) except Exception as err: - logger.warn('Exception: status: {status}, message: {message}, type: {type}, method: {method}, url: {url}, {kwargs}' - .format(status=getattr(err, 'code', None), - message=getattr(err, 'message', None), - type=type(err), - method=method_name, - url=url, - kwargs=kwargs, - ) - ) + logger.warn( + 'Exception: status: {status}, message: {message}, type: {type}, method: {method}, url: {url}, {kwargs}' + .format(status=getattr(err, 'code', None), + message=getattr(err, 'message', None), + type=type(err), + method=method_name, + url=url, + kwargs=kwargs, + ) + ) raise err return resp diff --git a/botstory/integrations/aiohttp/aiohttp_test.py b/botstory/integrations/aiohttp/aiohttp_test.py index 82d336d..efe527e 100644 --- a/botstory/integrations/aiohttp/aiohttp_test.py +++ b/botstory/integrations/aiohttp/aiohttp_test.py @@ -2,8 +2,15 @@ import json import pytest from . import AioHttpInterface +from .. import aiohttp from ..commonhttp import errors from ..tests import fake_server +from ... import di, story + + +def teardown_function(function): + story.clear() + @pytest.fixture def webhook_handler(): @@ -13,6 +20,7 @@ def webhook_handler(): 'text': json.dumps({'message': 'Ok!'}), }) + @pytest.mark.asyncio async def test_post(event_loop): async with fake_server.FakeFacebook(event_loop) as server: @@ -179,3 +187,16 @@ async def mock_middleware_handler(request): assert handler_stub.called finally: await http.stop() + + +def test_get_as_deps(): + story.use(aiohttp.AioHttpInterface()) + + with di.child_scope('http'): + @di.desc() + class OneClass: + @di.inject() + def deps(self, http): + self.http = http + + assert isinstance(di.injector.get('one_class').http, aiohttp.AioHttpInterface) diff --git a/botstory/integrations/fb/messenger.py b/botstory/integrations/fb/messenger.py index 08d61fd..8e683e3 100644 --- a/botstory/integrations/fb/messenger.py +++ b/botstory/integrations/fb/messenger.py @@ -2,14 +2,16 @@ import logging from . import validate from .. import commonhttp +from ... import di from ...middlewares import option from ...ast import users logger = logging.getLogger(__name__) +@di.desc('fb', reg=False) class FBInterface: - type = 'interface.facebook' + type = 'facebook' def __init__(self, api_uri='https://graph.facebook.com/v2.6', @@ -37,9 +39,39 @@ def __init__(self, self.library = None self.http = None - self.processor = None + self.story_processor = None self.storage = None + @di.inject() + def add_library(self, stories_library): + logger.debug('add_library') + logger.debug(stories_library) + self.library = stories_library + + @di.inject() + def add_http(self, http): + """ + inject http provider + + :param http: + :return: + """ + logger.debug('add_http') + logger.debug(http) + self.http = http + + @di.inject() + def add_processor(self, story_processor): + logger.debug('add_processor') + logger.debug(story_processor) + self.story_processor = story_processor + + @di.inject() + def add_storage(self, storage): + logger.debug('add_storage') + logger.debug(storage) + self.storage = storage + async def send_text_message(self, recipient, text, options=None): """ async send message to the facebook user (recipient) @@ -80,24 +112,6 @@ async def send_text_message(self, recipient, text, options=None): 'message': message, }) - def add_http(self, http): - """ - inject http provider - - :param http: - :return: - """ - logger.debug('add_http') - logger.debug(http) - self.http = http - if self.webhook: - http.webhook(self.webhook, self.handle, self.webhook_token) - - def add_storage(self, storage): - logger.debug('add_storage') - logger.debug(storage) - self.storage = storage - async def request_profile(self, facebook_user_id): """ Make request to facebook @@ -198,13 +212,13 @@ async def handle(self, data): message['data'] = data - await self.processor.match_message(message) + await self.story_processor.match_message(message) elif 'postback' in m: message['data'] = { 'option': m['postback']['payload'], } - await self.processor.match_message(message) + await self.story_processor.match_message(message) elif 'delivery' in m: logger.debug('delivery notification') elif 'read' in m: @@ -241,6 +255,10 @@ async def setup(self): await self.remove_greeting_call_to_action_payload() await self.set_greeting_call_to_action_payload(option.OnStart.DEFAULT_OPTION_PAYLOAD) + async def start(self): + if self.webhook and self.http: + self.http.webhook(self.webhook, self.handle, self.webhook_token) + async def replace_greeting_text(self, message): """ delete greeting text before diff --git a/botstory/integrations/fb/messenger_test.py b/botstory/integrations/fb/messenger_test.py index e984e94..e3ccbef 100644 --- a/botstory/integrations/fb/messenger_test.py +++ b/botstory/integrations/fb/messenger_test.py @@ -4,8 +4,8 @@ import pytest from . import messenger -from .. import commonhttp, mockdb, mockhttp -from ... import chat, story, utils +from .. import commonhttp, fb, mockdb, mockhttp +from ... import chat, di, story, utils from ...middlewares import any, option logger = logging.getLogger(__name__) @@ -13,8 +13,7 @@ def teardown_function(function): logger.debug('tear down!') - story.stories_library.clear() - chat.interfaces = {} + story.clear() @pytest.mark.asyncio @@ -24,6 +23,8 @@ async def test_send_text_message(): interface = story.use(messenger.FBInterface(page_access_token='qwerty1')) mock_http = story.use(mockhttp.MockHttpInterface()) + await story.start() + await interface.send_text_message( recipient=user, text='hi!', options=None ) @@ -133,6 +134,8 @@ async def test_setup_webhook(): )) mock_http = story.use(mockhttp.MockHttpInterface()) + await story.start() + mock_http.webhook.assert_called_with( '/webhook', fb_interface.handle, @@ -308,6 +311,8 @@ async def builder(): storage = story.use(mockdb.MockDB()) fb = story.use(messenger.FBInterface(page_access_token='qwerty')) + await story.start() + await storage.set_session(session) await storage.set_user(user) @@ -573,7 +578,7 @@ async def test_can_set_greeting_text_before_inject_http(): mock_http = story.use(mockhttp.MockHttpInterface()) - await fb_interface.setup() + await story.setup() # give few a moment for lazy initialization of greeting text await asyncio.sleep(0.1) @@ -601,7 +606,7 @@ async def test_can_set_greeting_text_in_constructor(): mock_http = story.use(mockhttp.MockHttpInterface()) - await fb.setup() + await story.setup() # give few a moment for lazy initialization of greeting text await asyncio.sleep(0.1) @@ -738,7 +743,7 @@ async def test_can_set_persistent_menu_before_http(): mock_http = story.use(mockhttp.MockHttpInterface()) - await fb_interface.setup() + await story.setup() # give few a moment for lazy initialization of greeting text await asyncio.sleep(0.1) @@ -766,7 +771,7 @@ async def test_can_set_persistent_menu_before_http(): @pytest.mark.asyncio async def test_can_set_persistent_menu_inside_of_constructor(): - fb = story.use(messenger.FBInterface( + story.use(messenger.FBInterface( page_access_token='qwerty15', persistent_menu=[{ 'type': 'postback', @@ -781,7 +786,7 @@ async def test_can_set_persistent_menu_inside_of_constructor(): mock_http = story.use(mockhttp.MockHttpInterface()) - await fb.setup() + await story.setup() # give few a moment for lazy initialization of greeting text await asyncio.sleep(0.1) @@ -835,3 +840,32 @@ async def test_remove_persistent_menu(): 'thread_state': 'existing_thread' } ) + + +def test_get_fb_as_deps(): + story.use(messenger.FBInterface()) + + with di.child_scope(): + @di.desc() + class OneClass: + @di.inject() + def deps(self, fb): + self.fb = fb + + assert isinstance(di.injector.get('one_class').fb, messenger.FBInterface) + + +def test_bind_fb_deps(): + story.use(messenger.FBInterface()) + story.use(mockdb.MockDB()) + story.use(mockhttp.MockHttpInterface()) + + with di.child_scope(): + @di.desc() + class OneClass: + @di.inject() + def deps(self, fb): + self.fb = fb + + assert isinstance(di.injector.get('one_class').fb.http, mockhttp.MockHttpInterface) + assert isinstance(di.injector.get('one_class').fb.storage, mockdb.MockDB) diff --git a/botstory/integrations/ga/tracker.py b/botstory/integrations/ga/tracker.py index b7f11d9..4b3c2af 100644 --- a/botstory/integrations/ga/tracker.py +++ b/botstory/integrations/ga/tracker.py @@ -1,21 +1,26 @@ +""" +pageview: [ page path ] +event: category, action, [ label [, value ] ] +social: network, action [, target ] +timing: category, variable, time [, label ] +""" + import functools import json +import logging from .universal_analytics.tracker import Tracker +from ... import di from ...utils import queue +logger = logging.getLogger(__name__) + +@di.desc('tracker', reg=False) class GAStatistics: - type = 'interface.tracker' - """ - pageview: [ page path ] - event: category, action, [ label [, value ] ] - social: network, action [, target ] - timing: category, variable, time [, label ] - """ def __init__(self, - tracking_id, + tracking_id=None, story_tracking_template='{story}/{part}', new_message_tracking_template='receive: {data}', ): @@ -26,7 +31,13 @@ def __init__(self, self.story_tracking_template = story_tracking_template self.new_message_tracking_template = new_message_tracking_template + @staticmethod + def __hash__(): + return hash('ga.tracker') + def get_tracker(self, user): + logger.debug('get_tracker') + logger.debug(Tracker) return Tracker( account=self.tracking_id, client_id=user and user['_id'], diff --git a/botstory/integrations/ga/tracker_test.py b/botstory/integrations/ga/tracker_test.py index 6f60446..a272604 100644 --- a/botstory/integrations/ga/tracker_test.py +++ b/botstory/integrations/ga/tracker_test.py @@ -5,8 +5,9 @@ from unittest import mock from . import GAStatistics, tracker -from .. import fb, mockdb, mockhttp -from ... import story, utils +from .. import fb, ga, mockdb, mockhttp +from ... import di, story, story_test, utils +from ..ga import tracker_test def setup_function(): @@ -98,7 +99,7 @@ def greeting(message): story.use(mockdb.MockDB()) facebook = story.use(fb.FBInterface()) story.use(mockhttp.MockHttpInterface()) - story.use(GAStatistics(tracking_id='UA-XXXXX-Y')) + story.use(ga.GAStatistics(tracking_id='UA-XXXXX-Y')) await story.start() await facebook.handle({ @@ -123,13 +124,26 @@ def greeting(message): await asyncio.sleep(0.1) tracker_mock.send.assert_has_calls([ - mock.call('event', - 'new_user', 'start', 'new user starts chat' - ), - mock.call('pageview', - 'receive: {}'.format(json.dumps({'text': {'raw': 'hi!'}})), - ), + # mock.call('event', + # 'new_user', 'start', 'new user starts chat' + # ), + # mock.call('pageview', + # 'receive: {}'.format(json.dumps({'text': {'raw': 'hi!'}})), + # ), mock.call('pageview', 'one_story/greeting', ), ]) + + +def test_get_as_deps(): + story.use(ga.GAStatistics()) + + with di.child_scope(): + @di.desc() + class OneClass: + @di.inject() + def deps(self, tracker): + self.tracker = tracker + + assert isinstance(di.injector.get('one_class').tracker, ga.GAStatistics) diff --git a/botstory/integrations/mockdb/db.py b/botstory/integrations/mockdb/db.py index e745628..3cec89f 100644 --- a/botstory/integrations/mockdb/db.py +++ b/botstory/integrations/mockdb/db.py @@ -1,10 +1,9 @@ import aiohttp -from ... import utils +from ... import di, utils +@di.desc('storage', reg=False) class MockDB: - type = 'interface.session_storage' - def __init__(self): self.session = None self.user = None diff --git a/botstory/integrations/mockdb/db_test.py b/botstory/integrations/mockdb/db_test.py new file mode 100644 index 0000000..a5b8b93 --- /dev/null +++ b/botstory/integrations/mockdb/db_test.py @@ -0,0 +1,16 @@ +import pytest +from .. import mockdb +from ... import di, story + + +def test_get_mockdb_as_dep(): + story.use(mockdb.MockDB()) + + with di.child_scope(): + @di.desc() + class OneClass: + @di.inject() + def deps(self, storage): + self.storage = storage + + assert isinstance(di.injector.get('one_class').storage, mockdb.MockDB) diff --git a/botstory/integrations/mockhttp/mockhttp.py b/botstory/integrations/mockhttp/mockhttp.py index c6a1202..7ed3e60 100644 --- a/botstory/integrations/mockhttp/mockhttp.py +++ b/botstory/integrations/mockhttp/mockhttp.py @@ -3,6 +3,8 @@ import json from unittest import mock +from ... import di + def stub(name=None): """ @@ -15,9 +17,8 @@ def stub(name=None): return mock.MagicMock(spec=lambda *args, **kwargs: None, name=name) +@di.desc('http', reg=False) class MockHttpInterface: - type = 'interface.http' - def __init__(self, get={}, get_raise=sentinel, post=True, post_raise=sentinel, diff --git a/botstory/integrations/mockhttp/mockhttp_test.py b/botstory/integrations/mockhttp/mockhttp_test.py new file mode 100644 index 0000000..24e692e --- /dev/null +++ b/botstory/integrations/mockhttp/mockhttp_test.py @@ -0,0 +1,16 @@ +import pytest +from .. import mockhttp +from ... import di, story + + +def test_get_mockhttp_as_dep(): + story.use(mockhttp.MockHttpInterface()) + + with di.child_scope(): + @di.desc() + class OneClass: + @di.inject() + def deps(self, http): + self.http = http + + assert isinstance(di.injector.get('one_class').http, mockhttp.MockHttpInterface) diff --git a/botstory/integrations/mocktracker/tracker.py b/botstory/integrations/mocktracker/tracker.py index 0bd3c35..c1fff75 100644 --- a/botstory/integrations/mocktracker/tracker.py +++ b/botstory/integrations/mocktracker/tracker.py @@ -1,11 +1,11 @@ import logging +from ... import di logger = logging.getLogger(__name__) +@di.desc('tracker', reg=False) class MockTracker: - type = 'interface.tracker' - def event(self, *args, **kwargs): logging.debug('event') logging.debug(kwargs) diff --git a/botstory/integrations/mocktracker/tracker_test.py b/botstory/integrations/mocktracker/tracker_test.py index a57ecfc..9c14c75 100644 --- a/botstory/integrations/mocktracker/tracker_test.py +++ b/botstory/integrations/mocktracker/tracker_test.py @@ -1,4 +1,12 @@ +import pytest + from . import tracker +from .. import mocktracker +from ... import di, story + + +def teardown_function(function): + story.clear() def test_event(): @@ -19,3 +27,16 @@ def test_new_user(): def test_story(): t = tracker.MockTracker() t.story() + + +def test_get_mock_tracker_as_dep(): + story.use(mocktracker.MockTracker()) + + with di.child_scope(): + @di.desc() + class OneClass: + @di.inject() + def deps(self, tracker): + self.tracker = tracker + + assert isinstance(di.injector.get('one_class').tracker, mocktracker.MockTracker) diff --git a/botstory/integrations/mongodb/db.py b/botstory/integrations/mongodb/db.py index e85a282..c79ce7e 100644 --- a/botstory/integrations/mongodb/db.py +++ b/botstory/integrations/mongodb/db.py @@ -1,13 +1,13 @@ import asyncio import logging from motor import motor_asyncio +from ... import di logger = logging.getLogger(__name__) +@di.desc('storage', reg=False) class MongodbInterface: - type = 'interface.session_storage' - """ https://github.com/mongodb/motor """ diff --git a/botstory/integrations/mongodb/db_test.py b/botstory/integrations/mongodb/db_test.py index 6db00b5..7cf4abf 100644 --- a/botstory/integrations/mongodb/db_test.py +++ b/botstory/integrations/mongodb/db_test.py @@ -3,15 +3,16 @@ import pytest from . import db +from .. import mongodb from ..fb import messenger -from ... import story, utils +from ... import di, story, utils logger = logging.getLogger(__name__) def teardown_function(function): logger.debug('tear down!') - story.stories_library.clear() + story.clear() @pytest.fixture @@ -119,3 +120,16 @@ async def test_start_should_open_connection_and_close_on_stop(): assert not db_interface.session_collection assert not db_interface.user_collection assert not db_interface.db + + +def test_get_mongodb_as_dep(): + story.use(mongodb.MongodbInterface()) + + with di.child_scope(): + @di.desc() + class OneClass: + @di.inject() + def deps(self, storage): + self.storage = storage + + assert isinstance(di.injector.get('one_class').storage, mongodb.MongodbInterface) diff --git a/botstory/integrations/tests/integration_test.py b/botstory/integrations/tests/integration_test.py index d280935..27df6ef 100644 --- a/botstory/integrations/tests/integration_test.py +++ b/botstory/integrations/tests/integration_test.py @@ -1,19 +1,15 @@ import logging import os + import pytest from . import fake_server -from .. import fb, aiohttp, mongodb, mockhttp -from ... import story, chat, utils +from .. import aiohttp, fb, mongodb, mockhttp +from ... import chat, di, story, utils logger = logging.getLogger(__name__) -def setup_function(): - logger.debug('setup!') - story.clear() - - def teardown_function(function): logger.debug('tear down!') story.clear() @@ -32,10 +28,12 @@ async def builder(db, no_session=False, no_user=False): await db.set_session(session) story.use(db) - interface = story.use(fb.FBInterface(page_access_token='qwerty')) + fb_interface = story.use(fb.FBInterface(page_access_token='qwerty')) http = story.use(mockhttp.MockHttpInterface()) - return interface, http, user + await story.start() + + return fb_interface, http, user return builder @@ -67,7 +65,7 @@ async def test_facebook_interface_should_use_aiohttp_to_post_message(event_loop) async with server.session() as server_session: # 1) setup app - try: + # try: story.use(fb.FBInterface( webhook_url='/webhook', )) @@ -92,8 +90,8 @@ async def test_facebook_interface_should_use_aiohttp_to_post_message(event_loop) 'text': 'Pryvit!' } } - finally: - await story.stop() + # finally: + # await story.stop() @pytest.mark.asyncio @@ -197,7 +195,6 @@ def greeting(message): trigger.passed() await story.setup() - await story.start() http.delete.assert_called_with( 'https://graph.facebook.com/v2.6/me/thread_settings', @@ -226,6 +223,8 @@ def greeting(message): } ) + await story.start() + await facebook.handle({ 'object': 'page', 'entry': [{ diff --git a/botstory/middlewares/location/location_test.py b/botstory/middlewares/location/location_test.py index 6f0e297..83f7edd 100644 --- a/botstory/middlewares/location/location_test.py +++ b/botstory/middlewares/location/location_test.py @@ -6,11 +6,11 @@ def teardown_function(function): print('tear down!') - story.stories_library.clear() + story.clear() @pytest.mark.asyncio -async def test_should_trigger_on_any_location(): +async def test_should_trigger_on_any_location(event_loop): trigger = SimpleTrigger() session = build_fake_session() user = build_fake_user() @@ -21,6 +21,9 @@ def one_story(): def then(message): trigger.passed() + assert not event_loop.is_closed() + await story.start(event_loop) + await answer.location({'lat': 1, 'lng': 1}, session, user) assert trigger.is_triggered diff --git a/botstory/story.py b/botstory/story.py index 98b492e..b6012ae 100644 --- a/botstory/story.py +++ b/botstory/story.py @@ -1,7 +1,7 @@ import asyncio import logging -from . import chat +from . import chat, di from .ast import callable as callable_module, common, \ forking, library, parser, processor, users @@ -75,26 +75,12 @@ def use(middleware): middlewares.append(middleware) - # TODO: maybe it is good time to start using DI (dependency injection) + di.injector.register(instance=middleware) + di.bind(middleware, auto=True) + # TODO: should use DI somehow if check_spec(['send_text_message'], middleware): chat.add_interface(middleware) - # TODO: should find more elegant way to inject library to fb interface - # or information whether we have On Start story - middleware.library = stories_library - - if check_spec(['handle'], middleware): - story_processor_instance.add_interface(middleware) - - if check_spec(['get_user', 'set_user', 'get_session', 'set_session'], middleware): - story_processor_instance.add_storage(middleware) - - if check_spec(['post', 'webhook'], middleware): - chat.add_http(middleware) - - if middleware.type == 'interface.tracker': - story_processor_instance.add_tracker(middleware) - users.add_tracker(middleware) return middleware @@ -108,7 +94,6 @@ def clear(clear_library=True): :return: """ - story_processor_instance.clear() if clear_library: stories_library.clear() chat.clear() @@ -117,23 +102,34 @@ def clear(clear_library=True): global middlewares middlewares = [] + di.clear_instances() + + +def register(): + di.injector.register(instance=story_processor_instance) + di.injector.register(instance=stories_library) + di.injector.bind(story_processor_instance, auto=True) + di.injector.bind(stories_library, auto=True) + -async def setup(): - await _do_for_each_extension('setup') +async def setup(event_loop=None): + register() + await _do_for_each_extension('setup', event_loop) -async def start(): - await _do_for_each_extension('start') +async def start(event_loop=None): + register() + await _do_for_each_extension('start', event_loop) -async def stop(): - await _do_for_each_extension('stop') +async def stop(event_loop=None): + await _do_for_each_extension('stop', event_loop) -async def _do_for_each_extension(command): +async def _do_for_each_extension(command, even_loop): await asyncio.gather( - *[getattr(m, command)() for m in middlewares if hasattr(m, command)] - ) + *[getattr(m, command)() for m in middlewares if hasattr(m, command)], + loop=even_loop) def forever(loop):