diff --git a/.coveragerc b/.coveragerc new file mode 100644 index 0000000..d543141 --- /dev/null +++ b/.coveragerc @@ -0,0 +1,3 @@ +[run] +omit = test/* +source = feeds diff --git a/.travis.yml b/.travis.yml index a1359a5..45c6e70 100644 --- a/.travis.yml +++ b/.travis.yml @@ -5,8 +5,6 @@ python: - 3.6 services: - docker -# env: -# global: before_install: - sudo apt-get -qq update @@ -17,4 +15,7 @@ install: - pip install -r dev-requirements.txt script: - - make test \ No newline at end of file + - make test + +after_script: + - coveralls \ No newline at end of file diff --git a/Makefile b/Makefile index df3bf04..dd1d607 100644 --- a/Makefile +++ b/Makefile @@ -1,17 +1,17 @@ install: pip install -r requirements.txt -build-docs: +docs: -rm -r docs -rm -r docsource/internal_apis mkdir -p docs - sphinx-apidoc --separate -o docsource/internal_apis src + sphinx-apidoc --separate -o docsource/internal_apis feeds test: - # flake8 feeds - pytest --verbose test --cov feeds + flake8 feeds + pytest --verbose test --cov --cov-report html feeds -s start: - gunicorn --worker-class gevent --timeout 300 --workers 10 --bind :5000 feeds.server:app + gunicorn --worker-class gevent --timeout 300 --workers 17 --bind :5000 feeds.server:app -.PHONY: test \ No newline at end of file +.PHONY: test docs \ No newline at end of file diff --git a/TODO.md b/TODO.md new file mode 100644 index 0000000..8e278a7 --- /dev/null +++ b/TODO.md @@ -0,0 +1,53 @@ +### Notification +* handle plain string or structured note context + * service-dependent +* subclass, or otherwise mark context for each service + * avoid subclassing, so we don't need to modify feeds whenever a new service is added + +### NotificationFeed +* combine with global feed for maintenance, etc. 
+ +### NotificatonManager +* handle fanouts in a more consistent way + +### TimelineStorage +* abstract caching +* filters on + * seen + * service + * level +* add remove notification + +### ActivityStorage +* caching to avoid lookups +* expire activities after configured time + +### Storage +* prototype with MongoDB adapter + +### Build & Deployment +* build Dockerfile +* build DockerCompose file +* Dockerize +* Sphinx docs +* make service token and encrypt + +### Server +* add params to GET notifications endpoint + +### Test Interface +* make one. +* maybe jump straight into KBase-UI module. + +### Actor +* validate actor properly +* include groups as an actor + +### Object +* validate object of notification where appropriate +* context dependent + +### Docs +* Do so, in general +* annotate deploy.cfg.example + diff --git a/deploy.cfg.example b/deploy.cfg.example deleted file mode 100644 index 75f8542..0000000 --- a/deploy.cfg.example +++ /dev/null @@ -1,6 +0,0 @@ -[feeds] -redis-host=localhost -redis-port=6379 -redis-user= -redis-pw= -auth-url=https://ci.kbase.us/services/auth \ No newline at end of file diff --git a/deployment/conf/.templates/deploy.cfg.templ b/deployment/conf/.templates/deploy.cfg.templ new file mode 100644 index 0000000..581f137 --- /dev/null +++ b/deployment/conf/.templates/deploy.cfg.templ @@ -0,0 +1,35 @@ +[feeds] +# DB info +# db-engine - allowed values = redis, mongodb. Others will raise an error on startup. +db-engine = {{ default .Env.db_engine "mongodb" }} + +# db-name - name of the database to use. default = "feeds". +db-name = {{ default .Env.db_name "feeds" }} + +# Other db info. The usual - host, port, user, and password. You know the drill. +db-host = {{ default .Env.db_host "ci-mongo" }} +db-port = {{ default .Env.db_port "27017" }} +db-user = {{ default .Env.db_user "feedsserv" }} +db-pw = {{ default .Env.db_pw "fake_password" }} + +# admins are allowed to use their auth tokens to create global notifications. 
+# examples would be notices about KBase downtime or events. +admins = wjriehl,scanon,kkeller,drakemm + +# fake user name for the global feed. Should be something that's not a valid +# user name. +global-feed = {{ default .Env.global_feed "_global_" }} + +# Default lifetime for each notification in days. Notes older than this won't be +# returned without explicitly looking them up by either their id or external key +# (when given). +lifespan = {{ default .Env.lifespan "30" }} + +# In debug mode, auth is mostly ignored. +# Useful for testing, etc. +# SET TO FALSE IN PRODUCTION! +debug = False + +auth-url = {{ default .Env.auth_url "https://ci.kbase.us/services/auth" }} +workspace-url = {{ default .Env.workspace_url "https://ci.kbase.us/services/ws" }} +groups-url = {{ default .Env.groups_url "https://ci.kbase.us/services/groups" }} diff --git a/deployment/deploy.cfg.example b/deployment/deploy.cfg.example new file mode 100644 index 0000000..8abd533 --- /dev/null +++ b/deployment/deploy.cfg.example @@ -0,0 +1,36 @@ +[feeds] +# DB info +# db-engine - allowed values = redis, mongodb. Others will raise an error on startup. +db-engine=mongodb + +# db-name - name of the database to use. default = "feeds". +db-name=feeds + +# Other db info. The usual - host, port, user, and password. You know the drill. +db-host=localhost +db-port=6379 +db-user= +db-pw= + +# Service urls +auth-url=https://ci.kbase.us/services/auth +workspace-url=https://ci.kbase.us/services/ws +groups-url=https://ci.kbase.us/services/groups + +# admins are allowed to use their auth tokens to create global notifications. +# examples would be notices about KBase downtime or events. +admins=wjriehl,scanon,kkeller,mmdrake + +# fake user name for the global feed. Should be something that's not a valid +# user name. +global-feed=_global_ + +# Default lifetime for each notification in days. 
Notes older than this won't be +# returned without explicitly looking them up by either their id or external key +# (when given). +lifespan=30 + +# In debug mode, auth is effectively ignored. +# Useful for testing, etc. +# SET TO FALSE IN PRODUCTION! +debug=False \ No newline at end of file diff --git a/deployment/docker-compose.yml b/deployment/docker-compose.yml new file mode 100644 index 0000000..2c35243 --- /dev/null +++ b/deployment/docker-compose.yml @@ -0,0 +1,28 @@ +services: + feeds: + ports: + - "5000":"5000" + environment: + - db-engine=mongodb + - db-name=feeds + - db-host=localhost + - db-port=27017 + - auth-url=https://ci.kbase.us/services/auth + - workspace-url=https://ci.kbase.us/services/ws + - groups-url=https://ci.kbase.us/services/groups + - AUTH_TOKEN=fake_token + command: + - "-wait" + - "tcp://ci-mongo:27017" + - "-timeout" + - "-template" + - "/kb/module/deployment/conf/.templates/deploy.cfg.templ:/kb/module/deploy.cfg" + - "make start" + depends_on: ["ci-mongo"] + + ci-mongo: + image: mongo:2 + command: + - "--smallfiles" + ports: + - "27017:27017" diff --git a/deployment/push2dockerhub.sh b/deployment/push2dockerhub.sh new file mode 100644 index 0000000..e69de29 diff --git a/dev-requirements.txt b/dev-requirements.txt index 524d7d2..a45c6b0 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -1,4 +1,6 @@ coverage==4.5.1 pytest-cov==2.6.0 flake8==3.5.0 -pytest==3.8.2 \ No newline at end of file +pytest==3.8.2 +coveralls==1.5.1 +requests-mock==1.5.2 \ No newline at end of file diff --git a/feeds/activity/base.py b/feeds/activity/base.py index 135e1b9..c2888bd 100644 --- a/feeds/activity/base.py +++ b/feeds/activity/base.py @@ -1,6 +1,11 @@ +from abc import abstractmethod + + class BaseActivity(object): """ Common parent class for Activity and Notification. Activity will be done later. But a Notification is an Activity. 
""" - pass \ No newline at end of file + @abstractmethod + def to_dict(self): + pass diff --git a/feeds/activity/notification.py b/feeds/activity/notification.py index e5b273e..ccf41f7 100644 --- a/feeds/activity/notification.py +++ b/feeds/activity/notification.py @@ -1,12 +1,21 @@ from .base import BaseActivity import uuid import json -from datetime import datetime from ..util import epoch_ms from .. import verbs +from ..actor import validate_actor +from .. import notification_level +from feeds.exceptions import ( + InvalidExpirationError, + InvalidNotificationError +) +import datetime +from feeds.config import get_config + class Notification(BaseActivity): - def __init__(self, actor, verb, note_object, source, target=None, context={}): + def __init__(self, actor: str, verb, note_object: str, source: str, level='alert', + target: list=None, context: dict=None, expires: int=None, external_key: str=None): """ A notification is roughly of this form: actor, verb, object, (target) @@ -31,6 +40,10 @@ def __init__(self, actor, verb, note_object, source, target=None, context={}): :param source: source service for the note. String. :param target: target of the note. Optional. Should be a user id or group id if present. :param context: freeform context of the note. key-value pairs. 
+ :param validate: if True, runs _validate immediately + :param expires: if not None, set a new expiration date - should be an int, ms since epoch + :param external_key: an optional special key given by the service that created the + notification TODO: * decide on global ids for admin use @@ -40,23 +53,178 @@ def __init__(self, actor, verb, note_object, source, target=None, context={}): * validate target is valid * validate context fits """ + assert actor is not None, "actor must not be None" + assert verb is not None, "verb must not be None" + assert note_object is not None, "note_object must not be None" + assert source is not None, "source must not be None" + assert level is not None, "level must not be None" + assert target is None or isinstance(target, list), "target must be either a list or None" + assert context is None or isinstance(context, dict), "context must be either a dict or None" + self.id = str(uuid.uuid4()) self.actor = actor self.verb = verbs.translate_verb(verb) self.object = note_object self.source = source self.target = target self.context = context - self.time = epoch_ms() # int timestamp down to millisecond + self.level = notification_level.translate_level(level) + self.created = epoch_ms() # int timestamp down to millisecond + if expires is None: + expires = self._default_lifespan() + self.created + self.validate_expiration(expires, self.created) + self.expires = expires + self.external_key = external_key + + def validate(self): + """ + Validates whether the notification fields are accurate. Should be called before + sending a new notification to storage. + """ + self.validate_expiration(self.expires, self.created) + validate_actor(self.actor) + + def validate_expiration(self, expires: int, created: int): + """ + Validates whether the expiration time is valid and after the created time. + If yes, returns True. If not, raises an InvalidExpirationError. + """ + # Just validate that the time looks like a real time in epoch millis. 
+ try: + datetime.datetime.fromtimestamp(expires/1000) + except (TypeError, ValueError): + raise InvalidExpirationError( + "Expiration time should be the number " + "of milliseconds since the epoch" + ) + if expires <= created: + raise InvalidExpirationError( + "Notifications should expire sometime after they are created" + ) + + def _default_lifespan(self) -> int: + """ + Returns the default lifespan of this notification in ms. + """ + return get_config().lifespan * 24 * 60 * 60 * 1000 + + def to_dict(self) -> dict: + """ + Returns a dict form of the Notification. + Useful for storing in a document store, returns the id of each verb and level. + Less useful, but not terrible, for returning to a user. + """ + dict_form = { + "id": self.id, + "actor": self.actor, + "verb": self.verb.id, + "object": self.object, + "source": self.source, + "context": self.context, + "target": self.target, + "level": self.level.id, + "created": self.created, + "expires": self.expires, + "external_key": self.external_key + } + return dict_form + + def user_view(self) -> dict: + """ + Returns a view of the Notification that's intended for the user. + That means we leave out the target and external keys. + """ + view = { + "id": self.id, + "actor": self.actor, + "verb": self.verb.past_tense, + "object": self.object, + "source": self.source, + "context": self.context, + "level": self.level.name, + "created": self.created, + "expires": self.expires + } + return view + + def serialize(self) -> str: + """ + Serializes this notification to a string for caching / simple storage. + Assumes it's been validated. + Just dumps it all to a json string. 
+ """ + serial = { + "i": self.id, + "a": self.actor, + "v": self.verb.id, + "o": self.object, + "s": self.source, + "t": self.target, + "l": self.level.id, + "c": self.created, + "e": self.expires, + "x": self.external_key, + "n": self.context + } + return json.dumps(serial, separators=(',', ':')) - def _validate(self): + @classmethod + def deserialize(cls, serial: str): """ - Validates whether the notification fields are accurate. Should be called before sending a new notification to storage. + Deserializes and returns a new Notification instance. """ - self.validate_actor(self.actor) + try: + assert serial + except AssertionError: + raise InvalidNotificationError("Can't deserialize an input of 'None'") + try: + struct = json.loads(serial) + except json.JSONDecodeError: + raise InvalidNotificationError("Can only deserialize a JSON string") + required_keys = set(['a', 'v', 'o', 's', 'l', 't', 'c', 'i', 'e']) + missing_keys = required_keys.difference(struct.keys()) + if missing_keys: + raise InvalidNotificationError('Missing keys: {}'.format(missing_keys)) + deserial = cls( + struct['a'], + str(struct['v']), + struct['o'], + struct['s'], + level=str(struct['l']), + target=struct.get('t'), + context=struct.get('n'), + external_key=struct.get('x') + ) + deserial.created = struct['c'] + deserial.id = struct['i'] + deserial.expires = struct['e'] + return deserial - def validate_actor(self): + @classmethod + def from_dict(cls, serial: dict): """ - TODO: add group validation. only users are actors for now. - TODO: migrate to base class for users + Returns a new Notification from a serialized dictionary (e.g. 
used in Mongo) """ - pass \ No newline at end of file + try: + assert serial is not None and isinstance(serial, dict) + except AssertionError: + raise InvalidNotificationError("Can only run 'from_dict' on a dict.") + required_keys = set([ + 'actor', 'verb', 'object', 'source', 'level', 'created', 'expires', 'id' + ]) + missing_keys = required_keys.difference(set(serial.keys())) + if missing_keys: + raise InvalidNotificationError('Missing keys: {}'.format(missing_keys)) + deserial = cls( + serial['actor'], + str(serial['verb']), + serial['object'], + serial['source'], + level=str(serial['level']), + target=serial.get('target'), + context=serial.get('context'), + external_key=serial.get('external_key') + ) + deserial.created = serial['created'] + deserial.expires = serial['expires'] + deserial.id = serial['id'] + return deserial diff --git a/feeds/actor.py b/feeds/actor.py new file mode 100644 index 0000000..b3e9d12 --- /dev/null +++ b/feeds/actor.py @@ -0,0 +1,16 @@ +""" +A module for defining actors. +TODO: decide whether to use a class, or just a validated string. I'm leaning toward string. +""" +from .auth import validate_user_id +from .exceptions import InvalidActorError + + +def validate_actor(actor): + """ + TODO: groups can be actors, too, when that's ready. + """ + if validate_user_id(actor): + return True + else: + raise InvalidActorError("Actor '{}' is not a real user.".format(actor)) diff --git a/feeds/auth.py b/feeds/auth.py index ea968b9..8b8176f 100644 --- a/feeds/auth.py +++ b/feeds/auth.py @@ -16,10 +16,12 @@ Cache, TTLCache ) -AUTH_URL = get_config().auth_url +config = get_config() +AUTH_URL = config.auth_url AUTH_API_PATH = '/api/V2/' CACHE_EXPIRE_TIME = 300 # seconds + class TokenCache(TTLCache): """ Extends the TTLCache to handle KBase auth tokens. 
@@ -28,12 +30,15 @@ class TokenCache(TTLCache): """ def __getitem__(self, key, cache_getitem=Cache.__getitem__): token = super(TokenCache, self).__getitem__(key, cache_getitem=cache_getitem) - if token.get('expires', 0) < epoch_ms(): + if token.get('expires', 0) <= epoch_ms(): return self.__missing__(key) else: return token + __token_cache = TokenCache(1000, CACHE_EXPIRE_TIME) +__user_cache = TTLCache(1000, CACHE_EXPIRE_TIME) + def validate_service_token(token): """ @@ -41,26 +46,38 @@ def validate_service_token(token): If invalid, raises an InvalidTokenError. If any other errors occur, raises a TokenLookupError. + Also returns a valid response - the token's user - if that user is in the + configured list of admins. + TODO: I know this is going to be rife with issues. The name of the token doesn't have to be the service. But as long as it's a Service token, then it came from in KBase, so everything should be ok. - TODO: Add 'source' to PUT notification endpoint. """ token = __fetch_token(token) if token.get('type') == 'Service': return token.get('name') + elif token.get('user') in config.admins: + return token.get('user') else: raise InvalidTokenError("Token is not a Service token!") + def validate_user_token(token): """ Validates a user auth token. - If valid, does nothing. If invalid, raises an InvalidTokenError. + If valid, returns the user id. If invalid, raises an InvalidTokenError. + If debug is True, always validates and returns a nonsense user name """ - __fetch_token(token) + return __fetch_token(token)['user'] + def validate_user_id(user_id): - return validate_user_ids([user_id]) + """ + Validates whether a SINGLE user is real or not. + Returns a boolean. + """ + return user_id in validate_user_ids([user_id]) + def validate_user_ids(user_ids): """ @@ -69,8 +86,25 @@ def validate_user_ids(user_ids): key is a user that exists, each value is their user name. Raises an HTTPError if something bad happens. 
""" - r = __auth_request('users?list={}'.format(','.join(user_ids))) - return json.loads(r.content) + users = dict() + # fetch ones we know of from the cache + for user_id in user_ids: + try: + users[user_id] = __user_cache[user_id] + except KeyError: + pass + # now we have a partial list. the ones that weren't found will + # not be in the users dict. Use set difference to find the + # remaining user ids. + filtered_users = set(user_ids).difference(set(users)) + if not filtered_users: + return users + r = __auth_request('users?list={}'.format(','.join(filtered_users)), config.auth_token) + found_users = json.loads(r.content) + __user_cache.update(found_users) + users.update(found_users) + return users + def __fetch_token(token): """ @@ -90,12 +124,13 @@ def __fetch_token(token): except requests.HTTPError as e: _handle_errors(e) + def __auth_request(path, token): """ Makes a request of the auth server after cramming the token in a header. Only makes GET requests, since that's all we should need. 
""" - headers = {'Authorization', token} + headers = {'Authorization': token} r = requests.get(AUTH_URL + AUTH_API_PATH + path, headers=headers) # the requests that fail based on the token (401, 403) get returned for the # calling function to turn into an informative error @@ -103,6 +138,7 @@ def __auth_request(path, token): r.raise_for_status() return r + def _handle_errors(err): if err.response.status_code == 401: err_content = json.loads(err.response.content) diff --git a/feeds/config.py b/feeds/config.py index 97528eb..a56f2ed 100644 --- a/feeds/config.py +++ b/feeds/config.py @@ -1,7 +1,6 @@ import os import configparser from .exceptions import ConfigError -import logging DEFAULT_CONFIG_PATH = "deploy.cfg" ENV_CONFIG_PATH = "FEEDS_CONFIG" @@ -9,23 +8,25 @@ ENV_AUTH_TOKEN = "AUTH_TOKEN" INI_SECTION = "feeds" -DB_HOST = "redis-host" -DB_HOST_PORT = "redis-port" -DB_USER = "redis-user" -DB_PW = "redis-pw" -AUTH_URL = "auth-url" + +KEY_DB_HOST = "db-host" +KEY_DB_PORT = "db-port" +KEY_DB_USER = "db-user" +KEY_DB_PW = "db-pw" +KEY_DB_NAME = "db-name" +KEY_DB_ENGINE = "db-engine" +KEY_AUTH_URL = "auth-url" +KEY_ADMIN_LIST = "admins" +KEY_GLOBAL_FEED = "global-feed" +KEY_DEBUG = "debug" +KEY_LIFESPAN = "lifespan" + class FeedsConfig(object): """ Loads a config set from the root deploy.cfg file. This should be in ini format. 
Keys of note are: - - redis-host - redis-port - redis-user - redis-pw - auth-url """ def __init__(self): @@ -36,16 +37,37 @@ def __init__(self): config_file = self._find_config_path() cfg = self._load_config(config_file) if not cfg.has_section(INI_SECTION): - raise ConfigError("Error parsing config file: section {} not found!".format(INI_SECTION)) - self.redis_host = self._get_line(cfg, DB_HOST) - self.redis_port = self._get_line(cfg, DB_HOST_PORT) - self.redis_user = self._get_line(cfg, DB_USER, required=False) - self.redis_pw = self._get_line(cfg, DB_PW, required=False) - self.auth_url = self._get_line(cfg, AUTH_URL) + raise ConfigError( + "Error parsing config file: section {} not found!".format(INI_SECTION) + ) + self.db_engine = self._get_line(cfg, KEY_DB_ENGINE) + self.db_host = self._get_line(cfg, KEY_DB_HOST) + self.db_port = self._get_line(cfg, KEY_DB_PORT) + try: + self.db_port = int(self.db_port) + except ValueError: + raise ConfigError("{} must be an int! Got {}".format(KEY_DB_PORT, self.db_port)) + self.db_user = self._get_line(cfg, KEY_DB_USER, required=False) + self.db_pw = self._get_line(cfg, KEY_DB_PW, required=False) + self.db_name = self._get_line(cfg, KEY_DB_NAME, required=False) + self.global_feed = self._get_line(cfg, KEY_GLOBAL_FEED) + self.auth_url = self._get_line(cfg, KEY_AUTH_URL) + self.admins = self._get_line(cfg, KEY_ADMIN_LIST).split(",") + self.lifespan = self._get_line(cfg, KEY_LIFESPAN) + try: + self.lifespan = int(self._get_line(cfg, KEY_LIFESPAN)) + except ValueError: + raise ConfigError("{} must be an int! Got {}".format(KEY_LIFESPAN, self.lifespan)) + self.debug = self._get_line(cfg, KEY_DEBUG, required=False) + if not self.debug or self.debug.lower() != "true": + self.debug = False + else: + self.debug = True def _find_config_path(self): """ - A little helper to test whether a given file path, or one given by an environment variable, exists. 
+ A little helper to test whether a given file path, or one given by an + environment variable, exists. """ for env in [ENV_CONFIG_PATH, ENV_CONFIG_BACKUP]: env_path = os.environ.get(env) @@ -88,10 +110,12 @@ def _get_line(self, config, key, required=True): raise ConfigError("Required option {} has no value!".format(key)) return val + __config = None -def get_config(): + +def get_config(from_disk=False): global __config if not __config: __config = FeedsConfig() - return __config \ No newline at end of file + return __config diff --git a/feeds/exceptions.py b/feeds/exceptions.py index 80df8fe..0cce2a5 100644 --- a/feeds/exceptions.py +++ b/feeds/exceptions.py @@ -1,17 +1,17 @@ -from requests import HTTPError - class ConfigError(Exception): """ Raised when there's a problem with the service configuration. """ pass + class MissingVerbError(Exception): """ Raised when trying to convert from string -> registered verb, but the string's wrong. """ pass + class InvalidTokenError(Exception): """ Raised when finding out that a user or service auth token is invalid. @@ -23,6 +23,7 @@ def __init__(self, msg=None, http_error=None): super(InvalidTokenError, self).__init__(msg) self.http_error = http_error + class TokenLookupError(Exception): """ Raised when having problems looking up an auth token. Wraps HTTPError. @@ -33,8 +34,66 @@ def __init__(self, msg=None, http_error=None): super(TokenLookupError, self).__init__(msg) self.http_error = http_error + class InvalidActorError(Exception): """ Raised when an actor doesn't exist in the system as either a user or Group. """ - pass \ No newline at end of file + pass + + +class MissingTokenError(Exception): + """ + Raised when a request header doesn't have a token, but needs one. + """ + pass + + +class IllegalParameterError(Exception): + """ + Raised if a request receives an unexpected parameter format. E.g., + a JSON list instead of a JSON object. 
+ """ + pass + + +class MissingParameterError(Exception): + """ + Raised if a request is missing required parameters, but is otherwise well-formed. + """ + pass + + +class MissingLevelError(Exception): + """ + Raised if looking for a Notification Level that doesn't exist. + """ + pass + + +class ActivityStorageError(Exception): + """ + Raised if an activity is failed to be stored in a database. + """ + pass + + +class ActivityRetrievalError(Exception): + """ + Raised if the service fails to retrieve an activity from a database. + """ + pass + + +class InvalidExpirationError(Exception): + """ + Raised when trying to give a Notification an invalid expiration time. + """ + pass + + +class InvalidNotificationError(Exception): + """ + Raised when trying to deserialize a Notification that has been stored badly. + """ + pass diff --git a/feeds/feeds/__init__.py b/feeds/feeds/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/feeds/feeds/base.py b/feeds/feeds/base.py new file mode 100644 index 0000000..4f6a014 --- /dev/null +++ b/feeds/feeds/base.py @@ -0,0 +1,7 @@ +class BaseFeed(object): + """ + A feed should keep track of a user's activities. It should know how to add to them, fetch them, + and store them. It does NOT know how to fan those out to other feeds. It's just a really + fancy, database-powered list of Activities. 
+ """ + pass diff --git a/feeds/feeds/notification/__init__.py b/feeds/feeds/notification/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/feeds/feeds/notification/notification_feed.py b/feeds/feeds/notification/notification_feed.py new file mode 100644 index 0000000..89e70ac --- /dev/null +++ b/feeds/feeds/notification/notification_feed.py @@ -0,0 +1,67 @@ +from ..base import BaseFeed +from feeds.activity.notification import Notification +from feeds.storage.mongodb.activity_storage import MongoActivityStorage +from feeds.storage.mongodb.timeline_storage import MongoTimelineStorage +from cachetools import TTLCache +import logging + + +class NotificationFeed(BaseFeed): + def __init__(self, user_id): + self.user_id = user_id + self.timeline_storage = MongoTimelineStorage(self.user_id) + self.activity_storage = MongoActivityStorage() + self.timeline = None + self.cache = TTLCache(1000, 600) + + def _update_timeline(self): + """ + Updates a local user timeline cache. This is a list of activity ids + that are used for fetching from activity storage (for now). Sorted + by newest first. + + TODO: add metadata to timeline storage - type and verb, first. + """ + logging.getLogger(__name__).info('Fetching timeline for ' + self.user_id) + self.timeline = self.timeline_storage.get_timeline() + + def get_notifications(self, count=10): + return self.get_activities(count=count) + + def get_activities(self, count=10): + """ + Returns a selection of activities. + :param count: Maximum number of Notifications to return (default 10) + """ + # steps. + # 0. If in cache, return them. <-- later + # 1. Get storage adapter. + # 2. Query it for recent activities from this user. + # 3. Cache them here. + # 4. Return them. 
+ if count < 1 or not isinstance(count, int): + raise ValueError('Count must be an integer > 0') + serial_notes = self.timeline_storage.get_timeline(count=count) + note_list = [Notification.from_dict(note) for note in serial_notes] + return note_list + + def mark_activities(self, activity_ids, seen=False): + """ + Marks the given list of activities as either seen (True) or unseen (False). + """ + pass + + def add_notification(self, note): + return self.add_activity(note) + + def add_activity(self, note): + """ + Adds an activity to this user's feed + """ + self.timeline_storage.add_to_timeline(note) + + def add_activities(self): + """ + Adds several activities to this user's feed. + """ + pass diff --git a/feeds/logger.py b/feeds/logger.py new file mode 100644 index 0000000..1fdd3ec --- /dev/null +++ b/feeds/logger.py @@ -0,0 +1,13 @@ +import logging + + +def get_log(name): + return logging.getLogger(__name__) + + +def log(name, msg, *args, level=logging.INFO): + logging.getLogger(__name__).log(level, msg, *args) + + +def log_error(name, error): + log(name, "Exception: " + str(error) + error, level=logging.ERROR) diff --git a/feeds/managers/__init__.py b/feeds/managers/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/feeds/managers/base.py b/feeds/managers/base.py new file mode 100644 index 0000000..66361d3 --- /dev/null +++ b/feeds/managers/base.py @@ -0,0 +1,9 @@ +class BaseManager(object): + def __init__(self): + pass + + def get_target_users(self, activity): + """ + TODO: Abstract some basic functionality here for the generic activity type. + """ + return [] diff --git a/feeds/managers/notification_manager.py b/feeds/managers/notification_manager.py new file mode 100644 index 0000000..bb077da --- /dev/null +++ b/feeds/managers/notification_manager.py @@ -0,0 +1,44 @@ +""" +A NotificationManager manages adding new notifications. 
+This is separate from a feed - notifications get added by the KBase system to +user feeds, based on the content and context of the Notification. + +See also the docs in NotificationFeed. +""" + +from .base import BaseManager +from ..storage.mongodb.activity_storage import MongoActivityStorage +from feeds.config import get_config + + +class NotificationManager(BaseManager): + def __init__(self): + # init storage + pass + + def add_notification(self, note): + """ + Adds a new notification. + Triggers validation first. + """ + note.validate() # any errors get raised to be caught by the server. + target_users = self.get_target_users(note) + # add the notification to the database. + activity_storage = MongoActivityStorage() + activity_storage.add_to_storage(note, target_users) + + def get_target_users(self, note): + """ + This is gonna get complex. + The target users are a combination of: + - the list in note.target + - workspace admins (if it comes from a workspace) + - everyone, if it's global - mark as _global_ feed. + TODO: add adapters, maybe subclass notifications to handle each source? 
+ """ + user_list = list() + if note.target: + user_list = user_list + note.target + elif note.source == 'kbase': + user_list.append(get_config().global_feed) + return user_list diff --git a/feeds/notification_level.py b/feeds/notification_level.py new file mode 100644 index 0000000..63da3e0 --- /dev/null +++ b/feeds/notification_level.py @@ -0,0 +1,81 @@ +from .exceptions import MissingLevelError + +_level_register = dict() + + +def register(level): + if not issubclass(level, Level): + raise TypeError("Can only register Level subclasses") + + if level.id is None: + raise ValueError("A level must have an id") + elif str(level.id) in _level_register: + raise ValueError("The level id '{}' is already taken by {}".format( + level.id, _level_register[str(level.id)].name + )) + + if level.name is None: + raise ValueError("A level must have a name") + elif level.name.lower() in _level_register: + raise ValueError("The level '{}' is already registered!".format(level.name)) + + _level_register.update({ + str(level.id): level, + level.name.lower(): level + }) + + +def get_level(key): + key = str(key) + if key.lower() in _level_register: + return _level_register[key]() + else: + raise MissingLevelError('Level "{}" not found.'.format(key)) + + +def translate_level(level): + """ + Allows level to be either an id, a name, or a Level. + Regardless, returns the Level instance, or raises a MissingLevelError + + :param level: Either a string or a Level. 
(stringify numerical ids before looking them up) + """ + if isinstance(level, int): + return get_level(str(level)) + elif isinstance(level, str): + return get_level(level) + elif isinstance(level, Level): + return get_level(level.name) + else: + raise TypeError("Must be either a subclass of Level or a string.") + + +class Level(object): + id = 0 + name = None + + +class Alert(Level): + id = 1 + name = 'alert' + + +class Warning(Level): + id = 2 + name = 'warning' + + +class Error(Level): + id = 3 + name = 'error' + + +class Request(Level): + id = 4 + name = 'request' + + +register(Alert) +register(Warning) +register(Error) +register(Request) diff --git a/feeds/server.py b/feeds/server.py index e01cd5c..0f78eae 100644 --- a/feeds/server.py +++ b/feeds/server.py @@ -1,33 +1,103 @@ -import os import json import flask from flask import ( Flask, request ) +import traceback +import logging +from http.client import responses from flask.logging import default_handler from .util import epoch_ms -from .config import FeedsConfig -import logging +from .config import get_config +from .auth import ( + validate_service_token, + validate_user_token +) +from .exceptions import ( + MissingTokenError, + InvalidTokenError, + TokenLookupError, + IllegalParameterError, + MissingParameterError +) +from feeds.managers.notification_manager import NotificationManager +from feeds.activity.notification import Notification +from feeds.feeds.notification.notification_feed import NotificationFeed VERSION = "0.0.1" + def _initialize_logging(): root = logging.getLogger() root.addHandler(default_handler) root.setLevel('INFO') -def _initialize_config(): - # TODO - include config for: - # * database access - return FeedsConfig() -def _log(msg, *args): - logging.getLogger(__name__).info(msg, *args) +def _log(msg, *args, level=logging.INFO): + logging.getLogger(__name__).log(level, msg, *args) + + +def _log_error(error): + formatted_error = ''.join( + traceback.format_exception( + etype=type(error), + 
value=error, + tb=error.__traceback__) + ) + _log("Exception: " + formatted_error, level=logging.ERROR) + + +def _get_auth_token(request, required=True): + token = request.headers.get('Authorization') + if not token and required: + raise MissingTokenError() + return token + + +def _make_error(error, msg, status_code): + _log("%s %s", status_code, msg) + err_response = { + "http_code": status_code, + "http_status": responses[status_code], + "message": msg, + "time": epoch_ms() + } + return (flask.jsonify({'error': err_response}), status_code) + + +def _get_notification_params(params): + """ + Parses and verifies all the notification params are present. + Raises a MissingParameter error otherwise. + """ + # * `actor` - a user or org id. + # * `type` - one of the type keywords (see below, TBD (as of 10/8)) + # * `target` - optional, a user or org id. - always receives this notification + # * `object` - object of the notice. For invitations, the group to be invited to. + # For narratives, the narrative UPA. + # * `level` - alert, error, warning, or request. + # * `content` - optional, content of the notification, otherwise it'll be + # autogenerated from the info above. + # * `global` - true or false. If true, gets added to the global notification feed + # and everyone gets a copy. 
+ + if not isinstance(params, dict): + raise IllegalParameterError('Expected a JSON object as an input.') + required_list = ['actor', 'verb', 'target', 'object', 'level'] + missing = [r for r in required_list if r not in params] + if missing: + raise MissingParameterError("Missing parameter{} - {}".format( + "s" if len(missing) > 1 else '', + ", ".join(missing) + )) + # TODO - add more checks + return params + def create_app(test_config=None): _initialize_logging() - _initialize_config() + cfg = get_config() app = Flask(__name__, instance_relative_config=True) if test_config is None: @@ -35,6 +105,17 @@ def create_app(test_config=None): else: app.config.from_mapping(test_config) + @app.before_request + def preprocess_request(): + _log('%s %s', request.method, request.path) + pass + + @app.after_request + def postprocess_request(response): + _log('%s %s %s %s', request.method, request.path, response.status_code, + request.headers.get('User-Agent')) + return response + @app.route('/', methods=['GET']) def root(): return flask.jsonify({ @@ -43,8 +124,9 @@ def root(): "servertime": epoch_ms() }) - @app.route('/api/V1/notifications/', methods=['GET']) + @app.route('/api/V1/notifications', methods=['GET']) def get_notifications(): + # TODO: add filtering """ General flow should be: 1. 
validate/authenticate user @@ -54,32 +136,40 @@ def get_notifications(): # dummy code below max_notes = request.args.get('n', default=10, type=int) rev_sort = request.args.get('rev', default=0, type=int) - rev_sort = False if rev_sort==0 else True - level_filter = request.args.get('f', default=None, type=str) - include_seen = request.args.get('seen', default=0, type=int) - include_seen = False if include_seen==0 else True - return json.dumps({ - "max_notes": max_notes, - "rev_sort": rev_sort, - "level_filter": level_filter, - "include_seen": include_seen - }) + rev_sort = False if rev_sort == 0 else True + # level_filter = request.args.get('f', default=None, type=str) + include_seen = request.args.get('seen', default=1, type=int) + include_seen = False if include_seen == 0 else True + # return json.dumps({ + # "max_notes": max_notes, + # "rev_sort": rev_sort, + # "level_filter": level_filter, + # "include_seen": include_seen + # }) + user_id = validate_user_token(_get_auth_token(request)) + _log('Getting feed for {}'.format(user_id)) + feed = NotificationFeed(user_id) + notes = feed.get_notifications(count=max_notes) + return_list = list() + for note in notes: + return_list.append(note.user_view()) + return (flask.jsonify(return_list), 200) @app.route('/api/V1/notification/', methods=['GET']) def get_single_notification(note_id): raise NotImplementedError() - @app.route('/api/V1/notifications/unsee/', methods=['POST']) + @app.route('/api/V1/notifications/unsee', methods=['POST']) def mark_notifications_unseen(): """Form data should have a list of notification ids to mark as unseen""" raise NotImplementedError() - @app.route('/api/V1/notifications/see/', methods=['POST']) + @app.route('/api/V1/notifications/see', methods=['POST']) def mark_notifications_seen(): """Form data should have a list of notifications to mark as seen""" raise NotImplementedError() - @app.route('/api/V1/notification/', methods=['PUT']) + @app.route('/api/V1/notification', methods=['POST', 
'PUT']) def add_notification(): """ Adds a new notification for other users to see. @@ -87,16 +177,81 @@ def add_notification(): * `actor` - a user or org id. * `type` - one of the type keywords (see below, TBD (as of 10/8)) * `target` - optional, a user or org id. - always receives this notification - * `object` - object of the notice. For invitations, the group to be invited to. For narratives, the narrative UPA. + * `object` - object of the notice. For invitations, the group to be invited to. + For narratives, the narrative UPA. * `level` - alert, error, warning, or request. - * `content` - optional, content of the notification, otherwise it'll be autogenerated from the info above. - * `global` - true or false. If true, gets added to the global notification feed and everyone gets a copy. + * `content` - optional, content of the notification, otherwise it'll be + autogenerated from the info above. + * `global` - true or false. If true, gets added to the global notification + feed and everyone gets a copy. - This also requires a service token as an Authorization header. Once validated, will be used - as the Source of the notification, and used in logic to determine which feeds get notified. + This also requires a service token as an Authorization header. + Once validated, will be used as the Source of the notification, + and used in logic to determine which feeds get notified. """ - raise NotImplementedError() + if not cfg.debug: + token = _get_auth_token(request) + service = validate_service_token(token) # can also be an admin user + if not service: + raise InvalidTokenError("Token must come from a service, not a user!") + params = _get_notification_params(json.loads(request.get_data())) + # create a Notification from params. 
+ new_note = Notification( + params.get('actor'), + params.get('verb'), + params.get('object'), + params.get('source'), + params.get('level'), + target=params.get('target'), + context=params.get('context') + ) + # pass it to the NotificationManager to dole out to its audience feeds. + manager = NotificationManager() + manager.add_notification(new_note) + # on success, return the notification id and info. + return (flask.jsonify({'id': new_note.id}), 200) + + @app.errorhandler(IllegalParameterError) + @app.errorhandler(json.JSONDecodeError) + def handle_illegal_parameter(err): + _log_error(err) + return _make_error(err, "Incorrect data format", 400) + + @app.errorhandler(InvalidTokenError) + def handle_invalid_token(err): + _log_error(err) + return _make_error(err, "Invalid token", 401) + + @app.errorhandler(MissingTokenError) + def handle_missing_token(err): + _log_error(err) + return _make_error(err, "Authentication token required", 403) + + @app.errorhandler(404) + def not_found(err): + return _make_error(err, "Path {} not found.".format(request.path), 404) + + @app.errorhandler(405) + def handle_not_allowed(err): + _log_error(err) + return _make_error(err, "Method not allowed", 405) + + @app.errorhandler(MissingParameterError) + def handle_missing_params(err): + _log_error(err) + return _make_error(err, str(err), 422) + + @app.errorhandler(TokenLookupError) + def handle_auth_service_error(err): + _log_error(err) + return _make_error(err, "Unable to fetch authentication information", 500) + + @app.errorhandler(Exception) + def general_error(err): + _log_error(err) + return _make_error(err, str(err), 500) return app -app = create_app() \ No newline at end of file + +app = create_app() diff --git a/feeds/storage/__init__.py b/feeds/storage/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/feeds/storage/base.py b/feeds/storage/base.py new file mode 100644 index 0000000..1bccd96 --- /dev/null +++ b/feeds/storage/base.py @@ -0,0 +1,38 @@ +class 
BaseStorage(object): + def __init__(self): + pass + + def serialize(self): + raise NotImplementedError() + + def deserialize(self): + raise NotImplementedError() + + +class ActivityStorage(BaseStorage): + def __init__(self): + pass + + def add_to_storage(self, activities): + raise NotImplementedError() + + def get_from_storage(self, activity_ids): + raise NotImplementedError() + + def remove_from_storage(self, activity_ids): + raise NotImplementedError() + + +class TimelineStorage(BaseStorage): + def __init__(self, user_id): + assert user_id + self.user_id = user_id + + def add_to_timeline(self, activity): + raise NotImplementedError() + + def get_timeline(self, count=10): + raise NotImplementedError() + + def remove_from_timeline(self, activity_ids): + raise NotImplementedError() diff --git a/feeds/storage/mongodb/__init__.py b/feeds/storage/mongodb/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/feeds/storage/mongodb/activity_storage.py b/feeds/storage/mongodb/activity_storage.py new file mode 100644 index 0000000..ca1afa6 --- /dev/null +++ b/feeds/storage/mongodb/activity_storage.py @@ -0,0 +1,38 @@ +from typing import List +from ..base import ActivityStorage +from .connection import get_feeds_collection +from feeds.exceptions import ( + ActivityStorageError +) +from pymongo.errors import PyMongoError + + +class MongoActivityStorage(ActivityStorage): + def add_to_storage(self, activity, target_users: List[str]): + """ + Adds a single activity to the MongoDB. + Returns None if successful. + Raises an ActivityStorageError if it fails. 
+ """ + coll = get_feeds_collection() + act_doc = activity.to_dict() + act_doc["users"] = target_users + act_doc["unseen"] = target_users + try: + coll.insert_one(act_doc) + except PyMongoError as e: + raise ActivityStorageError("Failed to store activity: " + str(e)) + + def get_from_storage(self, activity_ids): + pass + + def remove_from_storage(self, activity_ids): + raise NotImplementedError() + + def change_seen_mark(self, act_id: str, user: str, seen: bool): + """ + :param act_id: activity id + :user: user id + :seen: whether or not it's been seen. Boolean. + """ + raise NotImplementedError() diff --git a/feeds/storage/mongodb/connection.py b/feeds/storage/mongodb/connection.py new file mode 100644 index 0000000..4e9c549 --- /dev/null +++ b/feeds/storage/mongodb/connection.py @@ -0,0 +1,74 @@ +from pymongo import ( + MongoClient, + ASCENDING, + DESCENDING +) +from feeds.config import get_config +import feeds.logger as log + +_connection = None + +_COL_NOTIFICATIONS = "notifications" + +# Searches to support: +# 1. Lookup by activity id. Easy. +# 2. Lookup all by user, include docs where user is not in unseen, sort by time. +# 3. Lookup all by user, ignore docs where user is not in unseen, sort by time. +# 4. Lookup all by user, sort by source, then sort by time. +# 5. Lookup all by user, sort by type, then sort by time. +# 6. Lookup all by user, sort by source, then sort by time, then sort by time. +# 7. Aggregations... later. Maybe part of the Timeline class. 
+ +_INDEXES = [ + [("act_id", ASCENDING)], + + # sort by creation date + [("created", DESCENDING)], + + # sort by target users + [("users", ASCENDING)], + + # sort by unseen users + [("unseen", ASCENDING)], + + # sort by source, then creation date + [("users", ASCENDING), ("source", ASCENDING), ("created", DESCENDING)], + + # sort by level, then creation date + [("users", ASCENDING), ("level", ASCENDING), ("created", DESCENDING)] +] + + +def get_feeds_collection(): + conn = get_mongo_connection() + return conn.get_collection(_COL_NOTIFICATIONS) + + +def get_mongo_connection(): + global _connection + if _connection is None: + _connection = FeedsMongoConnection() + return _connection + + +class FeedsMongoConnection(object): + def __init__(self): + self.cfg = get_config() + log.log(__name__, "opening MongoDB connection {} {}".format( + self.cfg.db_host, self.cfg.db_port) + ) + self.conn = MongoClient(host=self.cfg.db_host, port=self.cfg.db_port) + self.db = self.conn[self.cfg.db_name] + self._setup_indexes() + self._setup_schema() + + def get_collection(self, collection_name): + return self.db[collection_name] + + def _setup_indexes(self): + coll = self.get_collection(_COL_NOTIFICATIONS) + for index in _INDEXES: + coll.create_index(index) + + def _setup_schema(self): + pass diff --git a/feeds/storage/mongodb/timeline_storage.py b/feeds/storage/mongodb/timeline_storage.py new file mode 100644 index 0000000..876b794 --- /dev/null +++ b/feeds/storage/mongodb/timeline_storage.py @@ -0,0 +1,34 @@ +import pymongo +from ..base import TimelineStorage +from .connection import get_feeds_collection + + +class MongoTimelineStorage(TimelineStorage): + def add_to_timeline(self, activity): + raise NotImplementedError() + + def get_timeline(self, count=10, include_seen=False, level=None, verb=None, sort=None): + """ + :param count: int > 0 + :param include_seen: boolean + :param level: Level or None + :param verb: Verb or None + """ + # TODO: add filtering + # TODO: input validation 
+ coll = get_feeds_collection() + query = { + "users": [self.user_id] + } + if not include_seen: + query['unseen'] = [self.user_id] + if level is not None: + query['level'] = level.id + if verb is not None: + query['verb'] = verb.id + timeline = coll.find(query).sort("created", pymongo.DESCENDING) + serial_notes = [note for note in timeline] + return serial_notes + + def remove_from_timeline(self, activity_ids): + raise NotImplementedError() diff --git a/feeds/storage/redis/__init__.py b/feeds/storage/redis/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/feeds/storage/redis/activity_storage.py b/feeds/storage/redis/activity_storage.py new file mode 100644 index 0000000..7e35079 --- /dev/null +++ b/feeds/storage/redis/activity_storage.py @@ -0,0 +1,63 @@ +from ..base import ActivityStorage +from .connection import get_redis_connection +# from cachetools import TTLCache +from .util import get_activity_key +from collections import defaultdict + +""" +Activities get added to Redis like this: +Each activity comes from a source - that source is used as the key. +Each activity gets a unique id that it knows how to make. +Each activity can also be serialized/deserialized into a string. Maybe it gets pickled. +So now we have the makings of a hash. +Each key = activity's unique id. +Each value = serialized version of activity. + +Should be small and fast. + +If we make a MongoDB adapter, it can work with documents. +""" + +# activity_cache = TTLCache(1000, 600) + + +class RedisActivityStorage(ActivityStorage): + def serialize(self): + raise NotImplementedError() + + def deserialize(self): + raise NotImplementedError() + + def add_to_storage(self, activity): + # namespaces notes under the first character of their id. + # should help with sharding, if we need to. 
+ key = get_activity_key(activity) + r = get_redis_connection() + r.hset(key, activity.id, activity.serialize()) + + def get_from_storage(self, activity_ids): + # returns a list of serialized strings. We don't know what subclass + # to return as here, so it's up to the calling Manager or Feed to deserialize. + + # first, map the activity_ids onto their stored hash keys + lookup_map = defaultdict(list) + for id_ in activity_ids: + lookup_map[get_activity_key(id_)].append(id_) + r = get_redis_connection() + acts = dict() # act id -> Activity + for key in lookup_map: + id_list = lookup_map[key] + serial_acts = r.hmget(key, lookup_map[key]) + # The above are the same length (redis fills in None for missing keys) + # just smash them into a dict + for idx, id_ in enumerate(id_list): + acts[id_] = serial_acts[idx] + # now, just map the acts dict back onto the original activity_ids list + # to maintain the order + ret_list = list() + for id_ in activity_ids: + ret_list.append(acts[id_]) + return ret_list + + def remove_from_storage(self, activity_ids): + raise NotImplementedError() diff --git a/feeds/storage/redis/connection.py b/feeds/storage/redis/connection.py new file mode 100644 index 0000000..abdc73f --- /dev/null +++ b/feeds/storage/redis/connection.py @@ -0,0 +1,38 @@ +import redis +from feeds.config import get_config + +connection_pool = None + + +def get_redis_connection(server_name='default'): + ''' + Gets the specified redis connection + ''' + global connection_pool + + if connection_pool is None: + connection_pool = setup_redis() + + return redis.StrictRedis(connection_pool=connection_pool) + + +def setup_redis(): + ''' + Starts the connection pool for the configured redis server + ''' + config = get_config() + pool = redis.ConnectionPool( + host=config.db_host, + port=config.db_port, + password=config.db_pw, + db=config.db_name + + # decode_responses=config.get('decode_responses', True), + # # connection options + # socket_timeout=config.get('socket_timeout', 
None), + # socket_connect_timeout=config.get('socket_connect_timeout', None), + # socket_keepalive=config.get('socket_keepalive', False), + # socket_keepalive_options=config.get('socket_keepalive_options', None), + # retry_on_timeout=config.get('retry_on_timeout', False), + ) + return pool diff --git a/feeds/storage/redis/timeline_storage.py b/feeds/storage/redis/timeline_storage.py new file mode 100644 index 0000000..1885c2c --- /dev/null +++ b/feeds/storage/redis/timeline_storage.py @@ -0,0 +1,26 @@ +from ..base import TimelineStorage +from .connection import get_redis_connection +from .util import ( + get_user_key +) + + +class RedisTimelineStorage(TimelineStorage): + # TODO: CACHING!! + + def get_timeline(self, count=10): + """ + Gets the user's timeline of activity ids. + """ + r = get_redis_connection() + user_key = get_user_key(self.user_id) + user_timeline = r.zrevrange(user_key, 0, count) + return user_timeline + + def add_to_timeline(self, activity): + r = get_redis_connection() + feed_key = get_user_key(self.user_id) + r.zadd(feed_key, activity.time, activity.id) + + def remove_from_timeline(self, activity): + raise NotImplementedError() diff --git a/feeds/storage/redis/util.py b/feeds/storage/redis/util.py new file mode 100644 index 0000000..9c4c198 --- /dev/null +++ b/feeds/storage/redis/util.py @@ -0,0 +1,19 @@ +USER_FEED_KEY = "feed:user:{}" +ACTIVITY_STORAGE_KEY = "notes:{}" + + +def get_user_key(user): + return USER_FEED_KEY.format(user) + + +def get_note_id(note): + return "{}-{}".format(note.source, note.id) + + +def get_activity_key(activity): + if isinstance(activity, bytes): + return ACTIVITY_STORAGE_KEY.format(activity.decode('utf-8')[0]) + elif hasattr(activity, "id"): + return ACTIVITY_STORAGE_KEY.format(activity.id[0]) + else: + return ACTIVITY_STORAGE_KEY.format(activity[0]) diff --git a/feeds/util.py b/feeds/util.py index f58cfa2..74d3aa8 100644 --- a/feeds/util.py +++ b/feeds/util.py @@ -1,4 +1,5 @@ from datetime import datetime + def 
epoch_ms(): - return int(datetime.utcnow().timestamp()*1000) \ No newline at end of file + return int(datetime.utcnow().timestamp()*1000) diff --git a/feeds/verbs.py b/feeds/verbs.py index 9deda5c..731b175 100644 --- a/feeds/verbs.py +++ b/feeds/verbs.py @@ -2,6 +2,7 @@ _verb_register = dict() + def register(verb): if not issubclass(verb, Verb): raise TypeError("Can only register Verb subclasses") @@ -9,7 +10,8 @@ def register(verb): if verb.id is None: raise ValueError("A verb must have an id") elif str(verb.id) in _verb_register: - raise ValueError("The verb id '{}' is already taken by {}".format(verb.id, _verb_register[str(verb.id)].infinitive)) + raise ValueError("The verb id '{}' is already taken by {}".format( + verb.id, _verb_register[str(verb.id)].infinitive)) if verb.infinitive is None: raise ValueError("A verb must have an infinitive form") @@ -27,6 +29,7 @@ def register(verb): verb.past_tense.lower(): verb }) + def translate_verb(verb): """ Translates a given verb into a verb object. @@ -36,13 +39,16 @@ def translate_verb(verb): - if it's a verb that's registered, return it - if it's not a Verb or a str, raise a TypeError """ - if isinstance(verb, str): + if isinstance(verb, int): + return get_verb(str(verb)) + elif isinstance(verb, str): return get_verb(verb) - elif issubclass(verb, Verb): + elif isinstance(verb, Verb): return get_verb(verb.infinitive) else: raise TypeError("Must be either a subclass of Verb or a string.") + def get_verb(key): # if they're both None, fail. # otherwise, look it up. 
@@ -54,6 +60,7 @@ def get_verb(key): else: raise MissingVerbError('Verb "{}" not found.'.format(key)) + class Verb(object): id = None infinitive = None @@ -65,51 +72,61 @@ def __str__(self): def serialize(self): return self.id + class Invite(Verb): id = 1 infinitive = "invite" past_tense = "invited" + class Accept(Verb): id = 2 infinitive = "accept" past_tense = "accepted" + class Reject(Verb): id = 3 infinitive = "reject" past_tense = "rejected" + class Share(Verb): id = 4 infinitive = "share" past_tense = "shared" + class Unshare(Verb): id = 5 infinitive = "unshare" past_tense = "unshared" + class Join(Verb): id = 6 infinitive = "join" past_tense = "joined" + class Leave(Verb): id = 7 infinitive = "leave" past_tense = "left" + class Request(Verb): id = 8 infinitive = "request" past_tense = "requested" + class Update(Verb): id = 9 infinitive = "update" past_tense = "updated" + register(Invite) register(Accept) register(Reject) diff --git a/requirements.txt b/requirements.txt index c45ca79..92d2711 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,4 +2,6 @@ flask==1.0.2 requests==2.19.1 gunicorn==19.9.0 gevent==1.3.7 -cachetools==2.1.0 \ No newline at end of file +cachetools==2.1.0 +redis==2.10.6 +pymongo==3.7.2 \ No newline at end of file diff --git a/test/activity/__init__.py b/test/activity/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/test/activity/test_notification.py b/test/activity/test_notification.py new file mode 100644 index 0000000..c6d3ee2 --- /dev/null +++ b/test/activity/test_notification.py @@ -0,0 +1,333 @@ +import pytest +import json +from feeds.activity.notification import Notification +import uuid +from feeds.util import epoch_ms +from ..conftest import test_config +from ..util import assert_is_uuid +from feeds.exceptions import ( + MissingVerbError, + MissingLevelError, + InvalidExpirationError, + InvalidNotificationError +) + +cfg = test_config() + +# some dummy "good" inputs for testing +actor = "test_actor" 
+verb_inf = "invite" +verb_past = "invited" +verb_id = 1 +note_object = "foo" +source = "groups" +level_name = "warning" +level_id = 2 +target = ["target_actor"] +context = {"some": "context"} +expires = epoch_ms() + (10 * 24 * 60 * 60 * 1000) # 10 days +external_key = "an_external_key" + +def assert_note_ok(note, **kwargs): + keys = [ + 'actor', 'object', 'source', 'target', 'context', 'expires', 'external_key' + ] + for k in keys: + if k in kwargs: + assert getattr(note, k) == kwargs[k] + if 'verb_id' in kwargs: + assert note.verb.id == int(kwargs['verb_id']) + if 'verb_inf' in kwargs: + assert note.verb.infinitive == kwargs['verb_inf'] + if 'level_id' in kwargs: + assert note.level.id == int(kwargs['level_id']) + if 'level_name' in kwargs: + assert note.level.name == kwargs['level_name'] + if 'expires' not in kwargs: + assert note.expires == note.created + (int(cfg.get('feeds', 'lifespan')) * 24 * 60 * 60 * 1000) + assert note.created < note.expires + assert_is_uuid(note.id) + +def test_note_new_ok_no_kwargs(): + note = Notification(actor, verb_inf, note_object, source) + assert_note_ok(note, actor=actor, verb_inf=verb_inf, object=note_object, source=source) + +def test_note_new_diff_levels(): + assert_args = { + "actor": actor, + "verb_inf": verb_inf, + "object": note_object, + "source": source + } + for name in ['alert', 'warning', 'request', 'error']: + note = Notification(actor, verb_inf, note_object, source, level=name) + test_args = assert_args.copy() + test_args['level_name'] = name + assert_note_ok(note, **test_args) + for id_ in ['1', '2', '3', '4']: + note = Notification(actor, verb_inf, note_object, source, level=id_) + test_args = assert_args.copy() + test_args['level_id'] = id_ + assert_note_ok(note, **test_args) + + +def test_note_new_target(): + note = Notification(actor, verb_inf, note_object, source, target=target) + assert_note_ok(note, actor=actor, verb_inf=verb_inf, + object=note_object, source=source, target=target) + + +def 
test_note_new_context(): + note = Notification(actor, verb_inf, note_object, source, context=context) + assert_note_ok(note, actor=actor, verb_inf=verb_inf, + object=note_object, source=source, context=context) + + +def test_note_new_expiration(): + note = Notification(actor, verb_inf, note_object, source, expires=expires) + assert_note_ok(note, actor=actor, verb_inf=verb_inf, + object=note_object, source=source, expires=expires) + + +def test_note_new_external_key(): + note = Notification(actor, verb_inf, note_object, source, external_key=external_key) + assert_note_ok(note, actor=actor, verb_inf=verb_inf, + object=note_object, source=source, external_key=external_key) + + +def test_note_new_bad_actor(): + # TODO: Should only fail on validate - shouldn't do a lookup whenever a new note is made. + # also, shouldn't be None. + with pytest.raises(AssertionError) as e: + Notification(None, verb_inf, note_object, source) + assert "actor must not be None" in str(e.value) + + +def test_note_new_bad_verb(): + with pytest.raises(AssertionError) as e: + Notification(actor, None, note_object, source) + assert "verb must not be None" in str(e.value) + + with pytest.raises(MissingVerbError) as e: + Notification(actor, "foobar", note_object, source) + assert 'Verb "foobar" not found' in str(e.value) + + +def test_note_new_bad_object(): + # TODO: Also test object validation itself later. + with pytest.raises(AssertionError) as e: + Notification(actor, verb_inf, None, source) + assert 'note_object must not be None' in str(e.value) + +def test_note_new_bad_source(): + # TODO: Validate sources as being real. 
+ with pytest.raises(AssertionError) as e: + Notification(actor, verb_inf, note_object, None) + assert 'source must not be None' in str(e.value) + + +def test_note_new_bad_level(): + with pytest.raises(AssertionError) as e: + Notification(actor, verb_inf, note_object, source, level=None) + assert "level must not be None" in str(e.value) + + with pytest.raises(MissingLevelError) as e: + Notification(actor, verb_inf, note_object, source, level="foobar") + assert 'Level "foobar" not found' in str(e.value) + + +def test_note_new_bad_target(): + bad_targets = [{}, "foo", 123, False] + for bad in bad_targets: + with pytest.raises(AssertionError) as e: + Notification(actor, verb_inf, note_object, source, target=bad) + assert "target must be either a list or None" in str(e.value) + + +def test_note_new_bad_context(): + bad_context = [[], "foo", 123, False] + for bad in bad_context: + with pytest.raises(AssertionError) as e: + Notification(actor, verb_inf, note_object, source, context=bad) + assert "context must be either a dict or None" in str(e.value) + + +def test_note_new_bad_expires(): + bad_expires = ["foo", {}, []] + for bad in bad_expires: + with pytest.raises(InvalidExpirationError) as e: + Notification(actor, verb_inf, note_object, source, expires=bad) + assert "Expiration time should be the number of milliseconds" in str(e.value) + bad_expires = [123, True, False] + for bad in bad_expires: + with pytest.raises(InvalidExpirationError) as e: + Notification(actor, verb_inf, note_object, source, expires=bad) + assert "Notifications should expire sometime after they are created" in str(e.value) + + +def test_validate_ok(requests_mock): + user_id = "foo" + user_display = "Foo Bar" + requests_mock.get('{}/api/V2/users?list={}'.format(cfg.get('feeds', 'auth-url'), user_id), text=json.dumps({user_id: user_display})) + note = Notification(user_id, verb_inf, note_object, source) + # If this doesn't throw any errors, then it passes! 
+ note.validate() + + +def test_validate_bad(requests_mock): + user_id = "foo" + requests_mock.get('{}/api/V2/users?list={}'.format(cfg.get('feeds', 'auth-url'), user_id), text=json.dumps({})) + note = Notification(user_id, verb_inf, note_object, source) + # If this doesn't throw any errors, then it passes! + note.validate() + + +def test_default_lifespan(): + note = Notification(actor, verb_inf, note_object, source) + lifespan = int(cfg.get('feeds', 'lifespan')) + assert note.expires - note.created == lifespan * 24 * 60 * 60 * 1000 + + +def test_to_dict(): + note = Notification(actor, verb_inf, note_object, source, level=level_name) + d = note.to_dict() + assert d["actor"] == actor + assert d["verb"] == verb_id + assert d["object"] == note_object + assert d["source"] == source + assert isinstance(d["expires"], int) and d["expires"] == note.expires + assert isinstance(d["created"], int) and d["created"] == note.created + assert d["target"] is None + assert d["context"] is None + assert d["level"] == level_id + assert d["external_key"] is None + + +def test_user_view(): + note = Notification(actor, verb_inf, note_object, source, level=level_id) + v = note.user_view() + assert v["actor"] == actor + assert v["verb"] == verb_past + assert v["object"] == note_object + assert v["source"] == source + assert isinstance(v["expires"], int) and v["expires"] == note.expires + assert isinstance(v["created"], int) and v["created"] == note.created + assert "target" not in v + assert v["context"] is None + assert v["level"] == level_name + assert "external_key" not in v + + +def test_from_dict(): + act_id = str(uuid.uuid4()) + verb = [verb_id, str(verb_id), verb_inf, verb_past] + level = [level_id, level_name, str(level_id)] + d = { + "actor": actor, + "object": note_object, + "source": source, + "expires": 1234567890111, + "created": 1234567890000, + "target": target, + "context": context, + "external_key": external_key, + "id": act_id + } + for v in verb: + for l in level: + 
note_d = d.copy() + note_d.update({'level': l, 'verb': v}) + note = Notification.from_dict(note_d) + assert_note_ok(note, **note_d) + + +def test_from_dict_missing_keys(): + d = { + "actor": actor + } + with pytest.raises(InvalidNotificationError) as e: + Notification.from_dict(d) + assert "Missing keys" in str(e.value) + + with pytest.raises(InvalidNotificationError) as e: + Notification.from_dict(None) + assert "Can only run 'from_dict' on a dict" in str(e.value) + + +def test_serialization(): + note = Notification(actor, verb_inf, note_object, source, level=level_id) + serial = note.serialize() + json_serial = json.loads(serial) + # serial = { + # "i": self.id, + # "a": self.actor, + # "v": self.verb.id, + # "o": self.object, + # "s": self.source, + # "t": self.target, + # "l": self.level.id, + # "c": self.created, + # "e": self.expires, + # "x": self.external_key + # } + assert "i" in json_serial + assert_is_uuid(json_serial['i']) + assert "a" in json_serial and json_serial['a'] == actor + assert "v" in json_serial and json_serial['v'] == verb_id + assert "o" in json_serial and json_serial['o'] == note_object + assert "s" in json_serial and json_serial['s'] == source + assert "l" in json_serial and json_serial['l'] == level_id + assert "c" in json_serial and json_serial['c'] == note.created + assert "e" in json_serial and json_serial['e'] == note.expires + assert "n" in json_serial and json_serial['n'] == None + assert "x" in json_serial and json_serial['x'] == None + assert "t" in json_serial and json_serial['t'] == None + + +def test_serialization_all_kwargs(): + note = Notification(actor, verb_inf, note_object, source, level=level_id, + target=target, external_key=external_key, context=context) + serial = note.serialize() + json_serial = json.loads(serial) + assert "i" in json_serial + assert_is_uuid(json_serial['i']) + assert "a" in json_serial and json_serial['a'] == actor + assert "v" in json_serial and json_serial['v'] == verb_id + assert "o" in 
json_serial and json_serial['o'] == note_object + assert "s" in json_serial and json_serial['s'] == source + assert "l" in json_serial and json_serial['l'] == level_id + assert "c" in json_serial and json_serial['c'] == note.created + assert "e" in json_serial and json_serial['e'] == note.expires + assert "n" in json_serial and json_serial['n'] == context + assert "x" in json_serial and json_serial['x'] == external_key + assert "t" in json_serial and json_serial['t'] == target + + +def test_deserialization(): + note = Notification(actor, verb_inf, note_object, source, level=level_id, + target=target, external_key=external_key, context=context) + serial = note.serialize() + note2 = Notification.deserialize(serial) + assert note2.id == note.id + assert note2.actor == note.actor + assert note2.verb.id == note.verb.id + assert note2.object == note.object + assert note2.source == note.source + assert note2.level.id == note.level.id + assert note2.target == note.target + assert note2.external_key == note.external_key + assert note2.context == note.context + + +def test_deserialize_bad(): + with pytest.raises(InvalidNotificationError) as e: + Notification.deserialize(None) + assert "Can't deserialize an input of 'None'" in str(e.value) + + with pytest.raises(InvalidNotificationError) as e: + Notification.deserialize(json.dumps({'a': actor})) + assert "Missing keys" in str(e.value) + + with pytest.raises(InvalidNotificationError) as e: + Notification.deserialize("foo") + assert "Can only deserialize a JSON string" in str(e.value) \ No newline at end of file diff --git a/test/conftest.py b/test/conftest.py new file mode 100644 index 0000000..b93db29 --- /dev/null +++ b/test/conftest.py @@ -0,0 +1,19 @@ +import os +import configparser + +def pytest_sessionstart(session): + os.environ['AUTH_TOKEN'] = 'foo' + os.environ['FEEDS_CONFIG'] = os.path.join(os.path.dirname(__file__), 'test.cfg') + +def pytest_sessionfinish(session, exitstatus): + pass + +def test_config(): + """ + 
Returns a ConfigParser. + Because I'm lazy. + """ + cfg = configparser.ConfigParser() + with open(os.environ['FEEDS_CONFIG'], 'r') as f: + cfg.read_file(f) + return cfg \ No newline at end of file diff --git a/test/feeds/__init__.py b/test/feeds/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/test/feeds/notification/test_notification_feed.py b/test/feeds/notification/test_notification_feed.py new file mode 100644 index 0000000..e69de29 diff --git a/test/managers/__init__.py b/test/managers/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/test/storage/__init__.py b/test/storage/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/test/storage/mongodb/test_mongo_activity_storage.py b/test/storage/mongodb/test_mongo_activity_storage.py new file mode 100644 index 0000000..e69de29 diff --git a/test/storage/mongodb/test_mongo_connection.py b/test/storage/mongodb/test_mongo_connection.py new file mode 100644 index 0000000..e69de29 diff --git a/test/storage/mongodb/test_mongo_timeline_storage.py b/test/storage/mongodb/test_mongo_timeline_storage.py new file mode 100644 index 0000000..e69de29 diff --git a/test/test.cfg b/test/test.cfg new file mode 100644 index 0000000..c6ad7e3 --- /dev/null +++ b/test/test.cfg @@ -0,0 +1,12 @@ +[feeds] +db-engine=mongodb +db-host=localhost +db-port=27017 +db-user= +db-pw= +db-name=feeds +auth-url=http://localhost/auth +admins=feeds_admin +global-feed=_kbase_ +debug=True +lifespan=30 \ No newline at end of file diff --git a/test/test_actor.py b/test/test_actor.py new file mode 100644 index 0000000..48faae5 --- /dev/null +++ b/test/test_actor.py @@ -0,0 +1,22 @@ +import pytest +import requests +import json +import os +from feeds.actor import validate_actor +from .conftest import test_config +from feeds.exceptions import InvalidActorError + +cfg = test_config() +def test_validate_actor(requests_mock): + user_id = "foo" + user_display = "Foo Bar" + 
requests_mock.get('{}/api/V2/users?list={}'.format(cfg.get('feeds', 'auth-url'), user_id), text=json.dumps({user_id: user_display})) + assert validate_actor(user_id) + + +def test_validate_actor_fail(requests_mock): + user_id = "foo2" + requests_mock.get('{}/api/V2/users?list={}'.format(cfg.get('feeds', 'auth-url'), user_id), text=json.dumps({})) + with pytest.raises(InvalidActorError) as e: + validate_actor(user_id) + assert "Actor '{}' is not a real user".format(user_id) in str(e.value) diff --git a/test/test_auth.py b/test/test_auth.py new file mode 100644 index 0000000..e69de29 diff --git a/test/test_config.py b/test/test_config.py index 9c7c581..c11e089 100644 --- a/test/test_config.py +++ b/test/test_config.py @@ -4,85 +4,109 @@ from unittest import mock from pathlib import Path import os +from tempfile import mkstemp # TODO - more error checking FAKE_AUTH_TOKEN = "I'm an auth token!" -def write_test_cfg(path, cfg_lines): - if os.path.exists(path): - raise ValueError("Not gonna overwrite some existing file with this test stuff!") - with open(path, "w") as f: - f.write("\n".join(cfg_lines)) - -def test_config_from_env_ok(): - cfg_lines = [ - '[feeds]', - 'redis-host=foo', - 'redis-port=bar', - 'auth-url=baz' - ] - cfg_path = "fake_test_config_delete_me.cfg" - write_test_cfg(cfg_path, cfg_lines) +GOOD_CONFIG = [ + '[feeds]', + 'db-engine=redis', + 'db-host=foo', + 'db-port=5', + 'auth-url=baz', + 'global-feed=global', + 'admins=admin1,admin2,admin3', + 'lifespan=30' +] +@pytest.fixture(scope="function") +def dummy_auth_token(): + backup_token = os.environ.get('AUTH_TOKEN') os.environ['AUTH_TOKEN'] = FAKE_AUTH_TOKEN + yield + del os.environ['AUTH_TOKEN'] + if backup_token is not None: + os.environ['AUTH_TOKEN'] = backup_token + +@pytest.fixture(scope="function") +def dummy_config(): + (f, fname) = mkstemp(text=True) + + def _write_test_cfg(cfg_lines): + with open(fname, 'w') as cfg: + cfg.write("\n".join(cfg_lines)) + return fname + yield _write_test_cfg + 
os.remove(fname)
+
+
+def test_config_from_env_ok(dummy_config, dummy_auth_token):
+    cfg_path = dummy_config(GOOD_CONFIG)
+
+    feeds_config_backup = os.environ.get('FEEDS_CONFIG')
     os.environ['FEEDS_CONFIG'] = cfg_path
     cfg = config.FeedsConfig()
     assert cfg.auth_url == 'baz'
-    assert cfg.redis_host == 'foo'
-    assert cfg.redis_port == 'bar'
-
+    assert cfg.db_host == 'foo'
+    assert cfg.db_port == 5
     del os.environ['FEEDS_CONFIG']
+
+    kb_dep_config = os.environ.get('KB_DEPLOYMENT_CONFIG')
     os.environ['KB_DEPLOYMENT_CONFIG'] = cfg_path
     cfg = config.FeedsConfig()
     assert cfg.auth_url == 'baz'
-    assert cfg.redis_host == 'foo'
-    assert cfg.redis_port == 'bar'
-
+    assert cfg.db_host == 'foo'
+    assert cfg.db_port == 5
     del os.environ['KB_DEPLOYMENT_CONFIG']
-    del os.environ['AUTH_TOKEN']
-    os.remove(cfg_path)
+    if kb_dep_config is not None:
+        os.environ['KB_DEPLOYMENT_CONFIG'] = kb_dep_config
+
+    if feeds_config_backup is not None:
+        os.environ['FEEDS_CONFIG'] = feeds_config_backup
 
 
-def test_config_from_env_errors():
-    os.environ['AUTH_TOKEN'] = FAKE_AUTH_TOKEN
+def test_config_from_env_errors(dummy_config, dummy_auth_token):
     cfg_lines = [
         '[not-feeds]',
-        'redis-host=foo'
+        'db-host=foo'
     ]
-    cfg_path = "fake_test_config_delete_me.cfg"
-    write_test_cfg(cfg_path, cfg_lines)
+
+    cfg_path = dummy_config(cfg_lines)
+    path_backup = os.environ.get('FEEDS_CONFIG')
     os.environ['FEEDS_CONFIG'] = cfg_path
     with pytest.raises(ConfigError) as e:
         config.FeedsConfig()
     assert "Error parsing config file: section feeds not found!" in str(e.value)
-
-    del os.environ['AUTH_TOKEN']
     del os.environ['FEEDS_CONFIG']
-    os.remove(cfg_path)
+    if path_backup is not None:
+        os.environ['FEEDS_CONFIG'] = path_backup
 
 
 def test_config_from_env_no_auth():
+    backup_token = os.environ.get('AUTH_TOKEN')
+    if 'AUTH_TOKEN' in os.environ:
+        del os.environ['AUTH_TOKEN']
+
     with pytest.raises(RuntimeError) as e:
         config.FeedsConfig()
     assert "The AUTH_TOKEN environment variable must be set!" 
in str(e.value) + if backup_token is not None: + os.environ['AUTH_TOKEN'] = backup_token -def test_get_config(): - cfg_lines = [ - '[feeds]', - 'redis-host=foo', - 'redis-port=bar', - 'auth-url=baz' - ] - cfg_path = "fake_test_config_delete_me.cfg" - write_test_cfg(cfg_path, cfg_lines) +def test_get_config(dummy_config, dummy_auth_token): + cfg_path = dummy_config(GOOD_CONFIG) + + path_backup = os.environ.get('FEEDS_CONFIG') os.environ['FEEDS_CONFIG'] = cfg_path - os.environ['AUTH_TOKEN'] = FAKE_AUTH_TOKEN + config.__config = None cfg = config.get_config() - assert cfg.redis_host == 'foo' - assert cfg.redis_port == 'bar' + assert cfg.db_host == 'foo' + assert cfg.db_port == 5 assert cfg.auth_url == 'baz' assert cfg.auth_token == FAKE_AUTH_TOKEN - os.remove("fake_test_config_delete_me.cfg") del os.environ['FEEDS_CONFIG'] - del os.environ['AUTH_TOKEN'] \ No newline at end of file + if path_backup is not None: + os.environ['FEEDS_CONFIG'] = path_backup + config.__config = None \ No newline at end of file diff --git a/test/test_notification.py b/test/test_notification.py deleted file mode 100644 index 415cff3..0000000 --- a/test/test_notification.py +++ /dev/null @@ -1,13 +0,0 @@ -import pytest -from feeds.activity.notification import Notification -import uuid - -def test_basic_notification(): - assert True - # n = Notification('foo', 'bar', 'baz') - # assert n.actor == 'foo' - # assert n.note_type == 'bar' - # assert n.object == 'baz' - # assert n.content == {} - # assert n.target == None - # assert validate_uuid(n.id) \ No newline at end of file diff --git a/test/test_notification_level.py b/test/test_notification_level.py new file mode 100644 index 0000000..7eabdb9 --- /dev/null +++ b/test/test_notification_level.py @@ -0,0 +1,92 @@ +import pytest +import feeds.notification_level as level +from feeds.exceptions import ( + MissingLevelError +) + +def test_register_level_ok(): + class TestLevel(level.Level): + id=666 + name="test" + level.register(TestLevel) + assert 
'666' in level._level_register + assert level._level_register['666'] == TestLevel + assert 'test' in level._level_register + assert level._level_register['test'] == TestLevel + +def test_register_level_bad(): + class NoId(level.Level): + id=None + name="noid" + + with pytest.raises(ValueError) as e: + level.register(NoId) + assert "A level must have an id" in str(e.value) + + class NoName(level.Level): + id=667 + + with pytest.raises(ValueError) as e: + level.register(NoName) + assert "A level must have a name" in str(e.value) + + class DuplicateId(level.Level): + id='1' + name='duplicate' + + with pytest.raises(ValueError) as e: + level.register(DuplicateId) + assert "The level id '1' is already taken by alert" in str(e.value) + + class DuplicateName(level.Level): + id=668 + name="warning" + + with pytest.raises(ValueError) as e: + level.register(DuplicateName) + assert "The level 'warning' is already registered" in str(e.value) + + with pytest.raises(TypeError) as e: + level.register(str) + assert "Can only register Level subclasses" in str(e.value) + + with pytest.raises(ValueError) as e: + level.register(level.Alert) + assert "The level id '1' is already taken by alert" in str(e.value) + + +def test_get_level(): + l = level.get_level('warning') + assert isinstance(l, level.Warning) + assert l.id == level.Warning.id + assert l.name == level.Warning.name + + missing = "not_a_real_level" + with pytest.raises(MissingLevelError) as e: + level.get_level(missing) + assert 'Level "{}" not found'.format(missing) in str(e.value) + + +def test_translate_level(): + l = level.Alert() + l_trans = level.translate_level(l) + assert isinstance(l_trans, level.Alert) + + l = level.translate_level(1) + assert isinstance(l, level.Alert) + assert l.name == 'alert' + + l = level.translate_level('1') + assert isinstance(l, level.Alert) + assert l.name == 'alert' + + l = level.translate_level('alert') + assert isinstance(l, level.Alert) + + with pytest.raises(MissingLevelError) as e: + 
level.translate_level('foo') + assert 'Level "foo" not found' in str(e.value) + + with pytest.raises(TypeError) as e: + level.translate_level([]) + assert 'Must be either a subclass of Level or a string' in str(e.value) diff --git a/test/test_verbs.py b/test/test_verbs.py index a135535..dba41cb 100644 --- a/test/test_verbs.py +++ b/test/test_verbs.py @@ -85,4 +85,28 @@ def test_get_verb_fail(): def test_serialize(): v = verbs.get_verb('invite') - assert v.serialize() == 1 \ No newline at end of file + assert v.serialize() == 1 + +def test_translate_verb(): + v = verbs.Request() + v_trans = verbs.translate_verb(v) + assert isinstance(v_trans, verbs.Request) + + v = verbs.translate_verb(1) + assert isinstance(v, verbs.Invite) + assert v.infinitive == 'invite' + + v = verbs.translate_verb('1') + assert isinstance(v, verbs.Invite) + assert v.infinitive == 'invite' + + l = verbs.translate_verb('invite') + assert isinstance(l, verbs.Invite) + + with pytest.raises(MissingVerbError) as e: + verbs.translate_verb('foo') + assert 'Verb "foo" not found' in str(e.value) + + with pytest.raises(TypeError) as e: + verbs.translate_verb([]) + assert 'Must be either a subclass of Verb or a string' in str(e.value) diff --git a/test/util.py b/test/util.py index e69de29..a517730 100644 --- a/test/util.py +++ b/test/util.py @@ -0,0 +1,5 @@ +import uuid + +def assert_is_uuid(s): + # raises a ValueError if not. Good enough for testing. + uuid.UUID(s) \ No newline at end of file