diff --git a/run_all_tests.py b/run_all_tests.py index 65ae1b54f..af6d6c68e 100644 --- a/run_all_tests.py +++ b/run_all_tests.py @@ -46,5 +46,6 @@ def run_tests(): successful = False finally: from flask import current_app + current_app.running_context.executor.shutdown_pool() sys.exit(not successful) diff --git a/tests/util/mock_objects.py b/tests/util/mock_objects.py index cf1eb1686..704a373b5 100644 --- a/tests/util/mock_objects.py +++ b/tests/util/mock_objects.py @@ -130,7 +130,7 @@ def __init__(self, current_app): def send(self, packet): with self.current_app.app_context(): - self.send_callback(packet) + self._send_callback(packet) def _increment_execution_count(self): global workflows_executed diff --git a/walkoff.py b/walkoff.py index 517eeafa4..1587f9fdb 100644 --- a/walkoff.py +++ b/walkoff.py @@ -9,19 +9,18 @@ import walkoff import walkoff.config - +from scripts.compose_api import compose_api +from walkoff.multiprocessedexecutor.multiprocessedexecutor import spawn_worker_processes from walkoff.server.app import create_app logger = logging.getLogger('walkoff') def run(app, host, port): - from walkoff.multiprocessedexecutor.multiprocessedexecutor import spawn_worker_processes print_banner() pids = spawn_worker_processes() monkey.patch_all() - from scripts.compose_api import compose_api compose_api() app.running_context.executor.initialize_threading(app, pids) diff --git a/walkoff/cache.py b/walkoff/cache.py index d63d2a6dc..11cf6dad1 100644 --- a/walkoff/cache.py +++ b/walkoff/cache.py @@ -1,20 +1,21 @@ import logging import os -import os.path +import pickle import sqlite3 +import threading +from copy import deepcopy from datetime import timedelta from functools import partial from weakref import WeakSet + +import os.path from diskcache import FanoutCache, DEFAULT_SETTINGS, Cache from diskcache.core import DBNAME from gevent import sleep from gevent.event import AsyncResult, Event -import threading -import walkoff.config -import pickle -from copy import deepcopy from six import string_types, binary_type -import json + +import walkoff.config logger = logging.getLogger(__name__) @@ -23,7 +24,6 @@ except ImportError: from cStringIO import StringIO as BytesIO - unsubscribe_message = b'__UNSUBSCRIBE__' """(str): The message used to unsubscribe from and close a PubSub channel """ @@ -41,6 +41,7 @@ class DiskSubscription(object): Args: channel (str): The channel name associated with this subscription """ + def __init__(self, channel): self.channel = channel self._listener = None @@ -169,7 +170,8 @@ def __push_to_subscribers(self, channel, value): @staticmethod def __get_value(value): - if value == unsubscribe_message or isinstance(value, string_types) or isinstance(value, int) or isinstance(value, float): + if value == unsubscribe_message or isinstance(value, string_types) or isinstance(value, int) or isinstance( + value, float): return value if isinstance(value, binary_type): return value.decode('utf-8') @@ -231,6 +233,7 @@ class DiskCacheAdapter(object): retry (bool, optional): Should this database retry timed out transactions? Default to True **settings: Other setting which will be passsed to the `cache` attribute on initialization """ + def __init__(self, directory, shards=8, timeout=0.01, retry=True, **settings): self.directory = directory self.retry = retry @@ -430,7 +433,7 @@ def _convert_expire_to_seconds(time): Returns: (float): The expiration time in seconds """ - return time.total_seconds() if isinstance(time, timedelta) else float(time)/1000. + return time.total_seconds() if isinstance(time, timedelta) else float(time) / 1000. def shutdown(self): """Shuts down the connection to the cache diff --git a/walkoff/config.py b/walkoff/config.py index e7d7a777a..7a4cb91ea 100644 --- a/walkoff/config.py +++ b/walkoff/config.py @@ -2,10 +2,12 @@ import logging import logging.config import sys -from os.path import isfile, join, abspath import warnings import yaml +from os.path import isfile, join, abspath + +from walkoff.appgateway import cache_apps logger = logging.getLogger(__name__) @@ -136,6 +138,9 @@ class Config(object): @classmethod def load_config(cls, config_path=None): """ Loads Walkoff configuration from JSON file + + Args: + config_path (str): Optional path to the config. Defaults to the CONFIG_PATH class variable. """ if config_path: cls.CONFIG_PATH = config_path @@ -158,8 +163,7 @@ def load_config(cls, config_path=None): @classmethod def write_values_to_file(cls, keys=None): - """ Writes the current walkoff configuration to a file - """ + """ Writes the current walkoff configuration to a file""" if keys is None: keys = [key for key in dir(cls) if not key.startswith('__')] @@ -173,10 +177,8 @@ def write_values_to_file(cls, keys=None): def initialize(config_path=None): - """Loads the config file, loads the app cache, and loads the app APIs into memory - """ + """Loads the config file, loads the app cache, and loads the app APIs into memory""" Config.load_config(config_path) setup_logger() - from walkoff.appgateway import cache_apps cache_apps(Config.APPS_PATH) load_app_apis() diff --git a/walkoff/messaging/utils.py b/walkoff/messaging/utils.py index dc2909fe2..8a216afab 100644 --- a/walkoff/messaging/utils.py +++ b/walkoff/messaging/utils.py @@ -1,12 +1,13 @@ import json import logging +from flask import current_app + import walkoff.messaging from walkoff.events import WalkoffEvent from walkoff.extensions import db from walkoff.serverdb import Role, User from walkoff.serverdb.message import Message -from flask import current_app logger = logging.getLogger(__name__) diff --git a/walkoff/multiprocessedexecutor/multiprocessedexecutor.py b/walkoff/multiprocessedexecutor/multiprocessedexecutor.py index cfb05979d..36534380d 100644 --- a/walkoff/multiprocessedexecutor/multiprocessedexecutor.py +++ b/walkoff/multiprocessedexecutor/multiprocessedexecutor.py @@ -9,16 +9,16 @@ import gevent import zmq.green as zmq -from walkoff.executiondb import ExecutionDatabase +import walkoff.config from walkoff.events import WalkoffEvent +from walkoff.executiondb import ExecutionDatabase from walkoff.executiondb import WorkflowStatusEnum from walkoff.executiondb.saved_workflow import SavedWorkflow from walkoff.executiondb.workflow import Workflow from walkoff.executiondb.workflowresults import WorkflowStatus -from walkoff.multiprocessedexecutor.workflowexecutioncontroller import WorkflowExecutionController, Receiver from walkoff.multiprocessedexecutor.threadauthenticator import ThreadAuthenticator from walkoff.multiprocessedexecutor.worker import Worker -import walkoff.config +from walkoff.multiprocessedexecutor.workflowexecutioncontroller import WorkflowExecutionController, Receiver logger = logging.getLogger(__name__) @@ -58,7 +58,8 @@ def initialize_threading(self, app, pids=None): """Initialize the multiprocessing communication threads, allowing for parallel execution of workflows. Args: - pids (list, optional): Optional list of spawned processes. Defaults to None + app (FlaskApp): The current_app object + pids (list[Process], optional): Optional list of spawned processes. Defaults to None """ if not (os.path.exists(walkoff.config.Config.ZMQ_PUBLIC_KEYS_PATH) and @@ -82,6 +83,11 @@ def initialize_threading(self, app, pids=None): logger.debug('Controller threading initialized') def wait_and_reset(self, num_workflows): + """Waits for all of the workflows to be completed + + Args: + num_workflows (int): The number of workflows to wait for + """ timeout = 0 shutdown = 10 @@ -94,8 +100,7 @@ def wait_and_reset(self, num_workflows): self.receiver.workflows_executed = 0 def shutdown_pool(self): - """Shuts down the threadpool. - """ + """Shuts down the threadpool""" self.manager.send_exit_to_worker_comms() if len(self.pids) > 0: for p in self.pids: @@ -120,8 +125,7 @@ def shutdown_pool(self): return def cleanup_threading(self): - """Once the threadpool has been shutdown, clear out all of the data structures used in the pool. - """ + """Once the threadpool has been shutdown, clear out all of the data structures used in the pool""" self.pids = [] self.receiver_thread = None self.workflows_executed = 0 @@ -130,18 +134,18 @@ def cleanup_threading(self): self.receiver = None def execute_workflow(self, workflow_id, execution_id_in=None, start=None, start_arguments=None, resume=False): - """Executes a workflow. + """Executes a workflow Args: workflow_id (Workflow): The Workflow to be executed. - execution_id_in (str, optional): The optional execution ID to provide for the workflow. Should only be + execution_id_in (UUID, optional): The optional execution ID to provide for the workflow. Should only be used (and is required) when resuming a workflow. Must be valid UUID4. Defaults to None. - start (str, optional): The ID of the first, or starting action. Defaults to None. + start (UUID, optional): The ID of the first, or starting action. Defaults to None. start_arguments (list[Argument]): The arguments to the starting action of the workflow. Defaults to None. resume (bool, optional): Optional boolean to resume a previously paused workflow. Defaults to False. Returns: - The execution ID of the Workflow. + (UUID): The execution ID of the Workflow. """ workflow = self.execution_db.session.query(Workflow).filter_by(id=workflow_id).first() if not workflow: @@ -166,7 +170,10 @@ def pause_workflow(self, execution_id): """Pauses a workflow that is currently executing. Args: - execution_id (str): The execution id of the workflow. + execution_id (UUID): The execution id of the workflow. + + Returns: + (bool): True if Workflow successfully paused, False otherwise """ workflow_status = self.execution_db.session.query(WorkflowStatus).filter_by( execution_id=execution_id).first() @@ -181,7 +188,10 @@ def resume_workflow(self, execution_id): """Resumes a workflow that is currently paused. Args: - execution_id (str): The execution id of the workflow. + execution_id (UUID): The execution id of the workflow. + + Returns: + (bool): True if workflow successfully resumed, False otherwise """ workflow_status = self.execution_db.session.query(WorkflowStatus).filter_by( execution_id=execution_id).first() @@ -202,10 +212,13 @@ def resume_workflow(self, execution_id): return False def abort_workflow(self, execution_id): - """Abort a workflow. + """Abort a workflow Args: - execution_id (str): The execution id of the workflow. + execution_id (UUID): The execution id of the workflow. + + Returns: + (bool): True if successfully aborted workflow, False otherwise """ workflow_status = self.execution_db.session.query(WorkflowStatus).filter_by( execution_id=execution_id).first() @@ -231,10 +244,13 @@ def resume_trigger_step(self, execution_id, data_in, arguments=None): """Resumes a workflow awaiting trigger data, if the conditions are met. Args: - execution_id (str): The execution ID of the workflow + execution_id (UUID): The execution ID of the workflow data_in (dict): The data to send to the trigger arguments (list[Argument], optional): Optional list of new Arguments for the trigger action. Defaults to None. + + Returns: + (bool): True if successfully resumed trigger step, false otherwise """ saved_state = self.execution_db.session.query(SavedWorkflow).filter_by( workflow_execution_id=execution_id).first() @@ -273,7 +289,7 @@ def get_waiting_workflows(self): """Gets a list of the execution IDs of workflows currently awaiting data to be sent to a trigger. Returns: - A list of execution IDs of workflows currently awaiting data to be sent to a trigger. + (list[UUID]): A list of execution IDs of workflows currently awaiting data to be sent to a trigger. """ self.execution_db.session.expire_all() wf_statuses = self.execution_db.session.query(WorkflowStatus).filter_by( @@ -284,10 +300,10 @@ def get_workflow_status(self, execution_id): """Gets the current status of a workflow by its execution ID Args: - execution_id (str): The execution ID of the workflow + execution_id (UUID): The execution ID of the workflow Returns: - The status of the workflow + (int): The status of the workflow """ workflow_status = self.execution_db.session.query(WorkflowStatus).filter_by( execution_id=execution_id).first() @@ -304,10 +320,27 @@ def _log_and_send_event(self, event, sender=None, data=None): event.send(sender, data=data) def create_case(self, case_id, subscriptions): + """Creates a Case + + Args: + case_id (int): The ID of the Case + subscriptions (list[Subscription]): List of Subscriptions to subscribe to + """ self.manager.create_case(case_id, subscriptions) def update_case(self, case_id, subscriptions): + """Updates a Case + + Args: + case_id (int): The ID of the Case + subscriptions (list[Subscription]): List of Subscriptions to subscribe to + """ self.manager.create_case(case_id, subscriptions) def delete_case(self, case_id): + """Deletes a Case + + Args: + case_id (int): The ID of the Case to delete + """ self.manager.delete_case(case_id) diff --git a/walkoff/multiprocessedexecutor/proto_helpers.py b/walkoff/multiprocessedexecutor/proto_helpers.py index 0f555c677..98b2ac3f1 100644 --- a/walkoff/multiprocessedexecutor/proto_helpers.py +++ b/walkoff/multiprocessedexecutor/proto_helpers.py @@ -1,5 +1,6 @@ import json import logging + from six import string_types from walkoff.events import EventType, WalkoffEvent @@ -18,7 +19,7 @@ def convert_to_protobuf(sender, workflow, **kwargs): kwargs (dict, optional): A dict of extra fields, such as data, callback_name, etc. Returns: - The newly formed protobuf object, serialized as a string to send over the ZMQ socket. + (str): The newly formed protobuf object, serialized as a string to send over the ZMQ socket. """ event = kwargs['event'] data = kwargs['data'] if 'data' in kwargs else None @@ -41,6 +42,13 @@ def convert_to_protobuf(sender, workflow, **kwargs): def convert_workflow_to_proto(packet, sender, data=None): + """Converts a Workflow object to a protobuf object + + Args: + packet (Message): The protobuf packet to add the Workflow to + sender (Workflow): The Workflow to add to the packet + data (dict): Any additional data to add to the protobuf packet + """ packet.type = Message.WORKFLOWPACKET workflow_packet = packet.workflow_packet if 'data' is not None: @@ -49,6 +57,14 @@ def convert_workflow_to_proto(packet, sender, data=None): def convert_send_message_to_protobuf(packet, message, workflow, **kwargs): + """Converts a Message object to a protobuf object + + Args: + packet (protobuf): The protobuf packet + message (Message): The Message object to be converted + workflow (Workflow): The Workflow relating to this Message + **kwargs (dict, optional): Any additional arguments + """ packet.type = Message.USERMESSAGE message_packet = packet.message_packet message_packet.subject = message.pop('subject', '') @@ -63,6 +79,14 @@ def convert_send_message_to_protobuf(packet, message, workflow, **kwargs): def convert_log_message_to_protobuf(packet, sender, workflow, **kwargs): + """Converts a logging message to protobuf + + Args: + packet (protobuf): The protobuf packet + sender (Action): The Action from which this logging message originated + workflow (Workflow): The Workflow under which this Action falls + **kwargs (dict, optional): Any additional arguments + """ packet.type = Message.LOGMESSAGE logging_packet = packet.logging_packet logging_packet.name = sender.name @@ -74,6 +98,14 @@ def convert_log_message_to_protobuf(packet, sender, workflow, **kwargs): def convert_action_to_proto(packet, sender, workflow, data=None): + """Converts an Action to protobuf + + Args: + packet (protobuf): The protobuf packet + sender (Action): The Action + workflow (Workflow): The WOrkflow under which this Action falls + data (dict, optional): Any additional data. Defaults to None. + """ packet.type = Message.ACTIONPACKET action_packet = packet.action_packet if 'data' is not None: @@ -84,6 +116,12 @@ def convert_action_to_proto(packet, sender, workflow, data=None): def add_sender_to_action_packet_proto(action_packet, sender): + """Adds a sender to a protobuf packet + + Args: + action_packet (protobuf): The protobuf packet + sender (Action): The sender + """ action_packet.sender.name = sender.name action_packet.sender.id = str(sender.id) action_packet.sender.execution_id = sender.get_execution_id() @@ -94,6 +132,12 @@ def add_sender_to_action_packet_proto(action_packet, sender): def add_arguments_to_action_proto(action_packet, sender): + """Adds Arguments to the Action protobuf packet + + Args: + action_packet (protobuf): The protobuf packet + sender (Action): The Action under which fall the Arguments + """ for argument in sender.arguments: arg = action_packet.sender.arguments.add() arg.name = argument.name @@ -101,6 +145,12 @@ def add_arguments_to_action_proto(action_packet, sender): def set_argument_proto(arg_proto, arg_obj): + """Sets up the Argument protobuf + + Args: + arg_proto (protobuf): The Argument protobuf field + arg_obj (Argument): The Argument object + """ arg_proto.name = arg_obj.name for field in ('value', 'reference', 'selection'): val = getattr(arg_obj, field) @@ -115,12 +165,25 @@ def set_argument_proto(arg_proto, arg_obj): def add_workflow_to_proto(packet, workflow): + """Adds a Workflow to a protobuf packet + + Args: + packet (protobuf): The protobuf packet + workflow (Workflow): The Workflow object to add to the protobuf message + """ packet.name = workflow.name packet.id = str(workflow.id) packet.execution_id = str(workflow.get_execution_id()) def convert_branch_transform_condition_to_proto(packet, sender, workflow): + """Converts a Branch, Transform, or Condition to protobuf + + Args: + packet (protobuf): The protobuf packet + sender (Branch|Transform|Condition): The object to be converted to protobuf + workflow (Workflow): The Workflow under which the object falls + """ packet.type = Message.GENERALPACKET general_packet = packet.general_packet general_packet.sender.id = str(sender.id) diff --git a/walkoff/multiprocessedexecutor/worker.py b/walkoff/multiprocessedexecutor/worker.py index 3053f313e..0cae5ab08 100644 --- a/walkoff/multiprocessedexecutor/worker.py +++ b/walkoff/multiprocessedexecutor/worker.py @@ -3,41 +3,49 @@ import signal import threading import time -from enum import Enum from collections import namedtuple +from threading import Lock + import nacl.bindings import nacl.utils import zmq import zmq.auth as auth from concurrent.futures import ThreadPoolExecutor +from enum import Enum from google.protobuf.json_format import MessageToDict from nacl.public import PrivateKey, Box -from walkoff.executiondb import ExecutionDatabase -from walkoff.case.database import CaseDatabase +import walkoff.cache +import walkoff.config from walkoff.appgateway.appinstancerepo import AppInstanceRepo +from walkoff.case.database import CaseDatabase +from walkoff.case.logger import CaseLogger +from walkoff.case.subscription import Subscription, SubscriptionCache from walkoff.events import WalkoffEvent +from walkoff.executiondb import ExecutionDatabase from walkoff.executiondb.argument import Argument from walkoff.executiondb.saved_workflow import SavedWorkflow from walkoff.executiondb.workflow import Workflow +from walkoff.multiprocessedexecutor.proto_helpers import convert_to_protobuf from walkoff.proto.build.data_pb2 import CommunicationPacket, ExecuteWorkflowMessage, CaseControl, \ WorkflowControl -import walkoff.cache -from walkoff.case.logger import CaseLogger -from walkoff.case.subscription import Subscription, SubscriptionCache -from threading import Lock -from walkoff.multiprocessedexecutor.proto_helpers import convert_to_protobuf -import walkoff.config logger = logging.getLogger(__name__) class WorkflowResultsHandler(object): - def __init__(self, socket_id, client_secret_key, client_public_key, server_public_key, zmq_results_address, execution_db, case_logger): - """Initialize a Workflow object, which will be executing workflows. + def __init__(self, socket_id, client_secret_key, client_public_key, server_public_key, zmq_results_address, + execution_db, case_logger): + """Initialize a WorkflowResultsHandler object, which will be executing workflows. Args: - id_ (str): The ID of the worker. Needed for ZMQ socket communication. + socket_id (str): The ID for the results socket + client_secret_key (str): The secret key for the client + client_public_key (str): The public key for the client + server_public_key (str): The public key for the server + zmq_results_address (str): The address for the ZMQ results socket + execution_db (ExecutionDatabase): An ExecutionDatabase connection object + case_logger (CaseLoger): A CaseLogger instance """ self.results_sock = zmq.Context().socket(zmq.PUSH) self.results_sock.identity = socket_id @@ -51,6 +59,8 @@ def __init__(self, socket_id, client_secret_key, client_public_key, server_publi self.case_logger = case_logger def shutdown(self): + """Shuts down the results socket and tears down the ExecutionDatabase + """ self.results_sock.close() self.execution_db.tear_down() @@ -59,7 +69,8 @@ def handle_event(self, workflow, sender, **kwargs): callback in the main thread. Args: - sender (execution element): The execution element that sent the signal. + workflow (Workflow): The Workflow object that triggered the event + sender (ExecutionElement): The execution element that sent the signal. kwargs (dict): Any extra data to send. """ event = kwargs['event'] @@ -102,11 +113,14 @@ class CaseCommunicationMessageType(Enum): class WorkflowCommunicationReceiver(object): def __init__(self, socket_id, client_secret_key, client_public_key, server_public_key, zmq_communication_address): - """Initialize a Workflow object, which will be executing workflows. + """Initialize a WorkflowCommunicationReceiver object, which will receive messages on the comm socket Args: - id_ (str): The ID of the worker. Needed for ZMQ socket communication. - worker_environment_setup (func, optional): Function to setup globals in the worker. + socket_id (str): The socket ID for the ZMQ communication socket + client_secret_key (str): The secret key for the client + client_public_key (str): The public key for the client + server_public_key (str): The public key for the server + zmq_communication_address (str): The IP address for the ZMQ communication socket """ self.comm_sock = zmq.Context().socket(zmq.SUB) self.comm_sock.identity = socket_id @@ -118,12 +132,13 @@ def __init__(self, socket_id, client_secret_key, client_public_key, server_publi self.exit = False def shutdown(self): + """Shuts down the object by setting self.exit to True and closing the communication socket + """ self.exit = True self.comm_sock.close() def receive_communications(self): - """Constantly receives data from the ZMQ socket and handles it accordingly. - """ + """Constantly receives data from the ZMQ socket and handles it accordingly""" while not self.exit: try: @@ -172,12 +187,22 @@ def _format_case_message_data(message): class WorkflowReceiver(object): def __init__(self, key, server_key, cache_config): + """Initializes a WorkflowReceiver object, which receives workflow execution requests and ships them off to a + worker to execute + + Args: + key (PrivateKey): The NaCl PrivateKey generated by the Worker + server_key (PrivateKey): The NaCl PrivateKey generated by the Worker + cache_config (dict): Cache configuration + """ self.key = key self.server_key = server_key self.cache = walkoff.cache.make_cache(cache_config) self.exit = False def shutdown(self): + """Shuts down the object by setting self.exit to True and shutting down the cache + """ self.exit = True self.cache.shutdown() @@ -205,10 +230,10 @@ def receive_workflows(self): class Worker(object): def __init__(self, id_, config_path): - """Initialize a Workflow object, which will be executing workflows. + """Initialize a Workfer object, which will be managing the execution of Workflows Args: - id_ (str): The ID of the worker. Needed for ZMQ socket communication. + id_ (str): The ID of the worker config_path (str): The path to the configuration file to be loaded """ self.id_ = id_ @@ -217,7 +242,6 @@ def __init__(self, id_, config_path): signal.signal(signal.SIGABRT, self.exit_handler) if os.name == 'nt': - import apps # need this import walkoff.config.initialize(config_path=config_path) else: walkoff.config.Config.load_config(config_path) @@ -279,8 +303,7 @@ def handle_data_sent(sender, **kwargs): self.receive_workflows() def exit_handler(self, signum, frame): - """Clean up upon receiving a SIGINT or SIGABT. - """ + """Clean up upon receiving a SIGINT or SIGABT""" self.thread_exit = True self.workflow_receiver.shutdown() if self.threadpool: @@ -307,7 +330,14 @@ def __is_pool_at_capacity(self): return len(self.workflows) >= self.capacity def execute_workflow_worker(self, workflow_id, workflow_execution_id, start, start_arguments=None, resume=False): - """Execute a workflow. + """Execute a workflow + + Args: + workflow_id (UUID): The ID of the Workflow to be executed + workflow_execution_id (UUID): The execution ID of the Workflow to be executed + start (UUID): The ID of the starting Action + start_arguments (list[Argument], optional): Optional list of starting Arguments. Defaults to None + resume (bool, optional): Optional boolean to signify that this Workflow is being resumed. Defaults to False. """ self.execution_db.session.expire_all() workflow = self.execution_db.session.query(Workflow).filter_by(id=workflow_id).first() @@ -333,8 +363,7 @@ def execute_workflow_worker(self, workflow_id, workflow_execution_id, start, sta self.workflows.pop(threading.current_thread().name) def receive_communications(self): - """Constantly receives data from the ZMQ socket and handles it accordingly. - """ + """Constantly receives data from the ZMQ socket and handles it accordingly""" for message in self.workflow_communication_receiver.receive_communications(): if message.type == WorkerCommunicationMessageType.workflow: self._handle_workflow_control_communication(message.data) @@ -362,7 +391,7 @@ def on_data_sent(self, sender, **kwargs): callback in the main thread. Args: - sender (execution element): The execution element that sent the signal. + sender (ExecutionElement): The execution element that sent the signal. kwargs (dict): Any extra data to send. """ workflow = self._get_current_workflow() diff --git a/walkoff/multiprocessedexecutor/workflowexecutioncontroller.py b/walkoff/multiprocessedexecutor/workflowexecutioncontroller.py index 79bcd6c8b..2f7ca8dde 100644 --- a/walkoff/multiprocessedexecutor/workflowexecutioncontroller.py +++ b/walkoff/multiprocessedexecutor/workflowexecutioncontroller.py @@ -11,11 +11,11 @@ from nacl.public import PrivateKey, Box from six import string_types +import walkoff.config from walkoff.events import WalkoffEvent, EventType -from walkoff.proto.build.data_pb2 import Message, CommunicationPacket, ExecuteWorkflowMessage, CaseControl, WorkflowControl from walkoff.helpers import json_dumps_or_string -import walkoff.config -from flask import current_app +from walkoff.proto.build.data_pb2 import Message, CommunicationPacket, ExecuteWorkflowMessage, CaseControl, \ + WorkflowControl logger = logging.getLogger(__name__) @@ -23,6 +23,9 @@ class WorkflowExecutionController: def __init__(self, cache): """Initialize a LoadBalancer object, which manages workflow execution. + + Args: + cache (Cache): The Cache object """ server_secret_file = os.path.join(walkoff.config.Config.ZMQ_PRIVATE_KEYS_PATH, "server.key_secret") server_public, server_secret = auth.load_certificate(server_secret_file) @@ -43,10 +46,11 @@ def add_workflow(self, workflow_id, workflow_execution_id, start=None, start_arg """Adds a workflow ID to the queue to be executed. Args: - workflow_id (int): The ID of the workflow to be executed. - workflow_execution_id (str): The execution ID of the workflow to be executed. - start (str, optional): The ID of the first, or starting action. Defaults to None. - start_arguments (list[Argument]): The arguments to the starting action of the workflow. Defaults to None. + workflow_id (UUID): The ID of the workflow to be executed. + workflow_execution_id (UUID): The execution ID of the workflow to be executed. + start (UUID, optional): The ID of the first, or starting action. Defaults to None. + start_arguments (list[Argument], optional): The arguments to the starting action of the workflow. Defaults + to None. resume (bool, optional): Optional boolean to resume a previously paused workflow. Defaults to False. """ message = ExecuteWorkflowMessage() @@ -67,7 +71,7 @@ def pause_workflow(self, workflow_execution_id): """Pauses a workflow currently executing. Args: - workflow_execution_id (str): The execution ID of the workflow. + workflow_execution_id (UUID): The execution ID of the workflow. """ logger.info('Pausing workflow {0}'.format(workflow_execution_id)) message = self._create_workflow_control_message(WorkflowControl.PAUSE, workflow_execution_id) @@ -77,7 +81,7 @@ def abort_workflow(self, workflow_execution_id): """Aborts a workflow currently executing. Args: - workflow_execution_id (str): The execution ID of the workflow. + workflow_execution_id (UUID): The execution ID of the workflow. """ logger.info('Aborting workflow {0}'.format(workflow_execution_id)) message = self._create_workflow_control_message(WorkflowControl.ABORT, workflow_execution_id) @@ -92,8 +96,7 @@ def _create_workflow_control_message(control_type, workflow_execution_id): return message def send_exit_to_worker_comms(self): - """Sends the exit message over the communication sockets, otherwise worker receiver threads will hang - """ + """Sends the exit message over the communication sockets, otherwise worker receiver threads will hang""" message = CommunicationPacket() message.type = CommunicationPacket.EXIT self._send_message(message) @@ -112,14 +115,31 @@ def _set_arguments_for_proto(message, arguments): setattr(arg, field, val) def create_case(self, case_id, subscriptions): + """Creates a Case + + Args: + case_id (int): The ID of the Case + subscriptions (list[Subscription]): List of Subscriptions to subscribe to + """ message = self._create_case_update_message(case_id, CaseControl.CREATE, subscriptions=subscriptions) self._send_message(message) def update_case(self, case_id, subscriptions): + """Updates a Case + + Args: + case_id (int): The ID of the Case + subscriptions (list[Subscription]): List of Subscriptions to subscribe to + """ message = self._create_case_update_message(case_id, CaseControl.UPDATE, subscriptions=subscriptions) self._send_message(message) def delete_case(self, case_id): + """Deletes a Case + + Args: + case_id (int): The ID of the Case to delete + """ message = self._create_case_update_message(case_id, CaseControl.DELETE) self._send_message(message) @@ -143,7 +163,7 @@ def _send_message(self, message): class Receiver: def __init__(self, current_app): - """Initialize a Receiver object, which will receive callbacks from the execution elements. + """Initialize a Receiver object, which will receive callbacks from the ExecutionElements. Args: current_app (Flask.App): The current Flask app @@ -164,9 +184,7 @@ def __init__(self, current_app): self.current_app = current_app def receive_results(self): - """Keep receiving results from execution elements over a ZMQ socket, and trigger the callbacks. - """ - + """Keep receiving results from execution elements over a ZMQ socket, and trigger the callbacks""" while True: if self.thread_exit: break @@ -177,11 +195,12 @@ def receive_results(self): continue with self.current_app.app_context(): - self.send_callback(message_bytes) + self._send_callback(message_bytes) self.results_sock.close() - def send_callback(self, message_bytes): + def _send_callback(self, message_bytes): + message_outer = Message() message_outer.ParseFromString(message_bytes) callback_name = message_outer.event_name @@ -231,6 +250,14 @@ def _increment_execution_count(self): def format_message_event_data(message): + """Formats a Message + + Args: + message (Message): The Message to be formatted + + Returns: + (dict): The formatted Message object + """ return {'users': message.users, 'roles': message.roles, 'requires_reauth': message.requires_reauth, diff --git a/walkoff/server/app.py b/walkoff/server/app.py index 62d0eb3bb..d6376d09e 100644 --- a/walkoff/server/app.py +++ b/walkoff/server/app.py @@ -1,14 +1,20 @@ import logging import connexion +from flask import Blueprint from jinja2 import FileSystemLoader +import interfaces +import walkoff.config +from walkoff.extensions import db, jwt +from walkoff.helpers import import_submodules +from walkoff.server import context +from walkoff.server.blueprints import custominterface, workflowresults, notifications, console, root + logger = logging.getLogger(__name__) def register_blueprints(flaskapp): - from walkoff.server.blueprints import custominterface, workflowresults, notifications, console, root - flaskapp.register_blueprint(custominterface.custom_interface_page, url_prefix='/custominterfaces/') flaskapp.register_blueprint(workflowresults.workflowresults_page, url_prefix='/api/streams/workflowqueue') flaskapp.register_blueprint(notifications.notifications_page, url_prefix='/api/streams/messages') @@ -20,7 +26,6 @@ def register_blueprints(flaskapp): def __get_blueprints_in_module(module): - from flask import Blueprint blueprints = [getattr(module, field) for field in dir(module) if (not field.startswith('__') and isinstance(getattr(module, field), Blueprint))] @@ -28,8 +33,7 @@ def __get_blueprints_in_module(module): def __register_blueprint(flaskapp, blueprint, url_prefix): - from interfaces import AppBlueprint - if isinstance(blueprint, AppBlueprint): + if isinstance(blueprint, interfaces.AppBlueprint): blueprint.cache = flaskapp.running_context.cache url_prefix = '{0}{1}'.format(url_prefix, blueprint.url_suffix) if blueprint.url_suffix else url_prefix blueprint.url_prefix = url_prefix @@ -43,8 +47,6 @@ def __register_app_blueprints(flaskapp, app_name, blueprints): def __register_all_app_blueprints(flaskapp): - from walkoff.helpers import import_submodules - import interfaces imported_apps = import_submodules(interfaces) for interface_name, interfaces_module in imported_apps.items(): try: @@ -58,10 +60,6 @@ def __register_all_app_blueprints(flaskapp): def create_app(app_config): - import walkoff.config - from walkoff.server import context - from walkoff.extensions import db, jwt - connexion_app = connexion.App(__name__, specification_dir='../api/') _app = connexion_app.app _app.jinja_loader = FileSystemLoader(['walkoff/templates']) @@ -74,6 +72,4 @@ def create_app(app_config): _app.running_context = context.Context(walkoff.config.Config) register_blueprints(_app) - import walkoff.server.workflowresults # Don't delete this import - import walkoff.messaging.utils # Don't delete this import return _app diff --git a/walkoff/server/blueprints/console.py b/walkoff/server/blueprints/console.py index c38b6c0f9..e8aff3912 100644 --- a/walkoff/server/blueprints/console.py +++ b/walkoff/server/blueprints/console.py @@ -1,7 +1,8 @@ +import logging + from walkoff.events import WalkoffEvent from walkoff.security import jwt_required_in_query from walkoff.sse import SseStream, StreamableBlueprint -import logging console_stream = SseStream('console_results') console_page = StreamableBlueprint('console_page', __name__, streams=(console_stream,)) @@ -29,6 +30,3 @@ def console_log_callback(sender, **kwargs): @jwt_required_in_query('access_token') def stream_console_events(): return console_stream.stream() - - - diff --git a/walkoff/server/blueprints/notifications.py b/walkoff/server/blueprints/notifications.py index 53f0edcb1..27ead2a95 100644 --- a/walkoff/server/blueprints/notifications.py +++ b/walkoff/server/blueprints/notifications.py @@ -7,7 +7,6 @@ from walkoff.security import jwt_required_in_query from walkoff.sse import FilteredSseStream, StreamableBlueprint - sse_stream = FilteredSseStream('notifications') notifications_page = StreamableBlueprint('notifications_page', __name__, streams=[sse_stream]) diff --git a/walkoff/server/blueprints/root.py b/walkoff/server/blueprints/root.py index f9ba74e39..fac5699b8 100644 --- a/walkoff/server/blueprints/root.py +++ b/walkoff/server/blueprints/root.py @@ -1,16 +1,16 @@ import logging import os +from flask import current_app from flask import render_template, send_from_directory, Blueprint +from sqlalchemy.exc import SQLAlchemyError import walkoff.config -from sqlalchemy.exc import SQLAlchemyError -from walkoff.server.problem import Problem -from walkoff.server.returncodes import SERVER_ERROR -from walkoff.extensions import db from walkoff import helpers from walkoff.executiondb.device import App -from flask import current_app +from walkoff.extensions import db +from walkoff.server.problem import Problem +from walkoff.server.returncodes import SERVER_ERROR logger = logging.getLogger(__name__) diff --git a/walkoff/server/blueprints/workflowresults.py b/walkoff/server/blueprints/workflowresults.py index 603be3a90..4829ecc52 100644 --- a/walkoff/server/blueprints/workflowresults.py +++ b/walkoff/server/blueprints/workflowresults.py @@ -1,6 +1,7 @@ from datetime import datetime from flask import current_app + from walkoff.events import WalkoffEvent from walkoff.executiondb import ActionStatusEnum, WorkflowStatusEnum from walkoff.executiondb.workflowresults import WorkflowStatus diff --git a/walkoff/server/context.py b/walkoff/server/context.py index 470d717a2..c08c64e9e 100644 --- a/walkoff/server/context.py +++ b/walkoff/server/context.py @@ -1,3 +1,12 @@ +import walkoff.cache +import walkoff.case.database +import walkoff.executiondb +import walkoff.multiprocessedexecutor.multiprocessedexecutor as executor +import walkoff.scheduler +from walkoff.case.logger import CaseLogger +from walkoff.case.subscription import SubscriptionCache + + class Context(object): def __init__(self, config): @@ -7,14 +16,6 @@ def __init__(self, config): Args: config (Config): A config object """ - import walkoff.multiprocessedexecutor.multiprocessedexecutor as executor - import walkoff.scheduler - from walkoff.case.logger import CaseLogger - import walkoff.case.database - import walkoff.cache - from walkoff.case.subscription import SubscriptionCache - import walkoff.executiondb - self.execution_db = walkoff.executiondb.ExecutionDatabase(config.EXECUTION_DB_TYPE, config.EXECUTION_DB_PATH) self.case_db = walkoff.case.database.CaseDatabase(config.CASE_DB_TYPE, config.CASE_DB_PATH) diff --git a/walkoff/server/endpoints/cases.py b/walkoff/server/endpoints/cases.py index c79aa275b..fc2252954 100644 --- a/walkoff/server/endpoints/cases.py +++ b/walkoff/server/endpoints/cases.py @@ -4,12 +4,12 @@ from flask_jwt_extended import jwt_required import walkoff.case.database as case_database +from walkoff.case.subscription import Subscription from walkoff.security import permissions_accepted_for_resources, ResourcePermissions from walkoff.server.decorators import with_resource_factory from walkoff.server.problem import Problem from walkoff.server.returncodes import * from walkoff.serverdb import db -from walkoff.case.subscription import Subscription from walkoff.serverdb.casesubscription import CaseSubscription try: diff --git a/walkoff/server/endpoints/devices.py b/walkoff/server/endpoints/devices.py index 6ff43cc73..8fa9fab87 100644 --- a/walkoff/server/endpoints/devices.py +++ b/walkoff/server/endpoints/devices.py @@ -18,7 +18,8 @@ with_device = with_resource_factory( 'device', - lambda device_id: current_app.running_context.execution_db.session.query(Device).filter(Device.id == device_id).first()) + lambda device_id: current_app.running_context.execution_db.session.query(Device).filter( + Device.id == device_id).first()) def get_device_json_with_app_name(device): @@ -154,7 +155,7 @@ def __func(device): def _update_device(device, update_device_json, validate_required=True): fields = ({field['name']: field['value'] for field in update_device_json['fields']} - if 'fields' in update_device_json else None) + if 'fields' in update_device_json else None) app = update_device_json['app_name'] device_type = update_device_json['type'] if 'type' in update_device_json else device.type try: @@ -171,7 +172,6 @@ def _update_device(device, update_device_json, validate_required=True): device.update_from_json(update_device_json, complete_object=validate_required) current_app.running_context.execution_db.session.commit() device_json = get_device_json_with_app_name(device) - # remove_configuration_keys_from_device_json(device_json) return device_json, SUCCESS diff --git a/walkoff/server/endpoints/metadata.py b/walkoff/server/endpoints/metadata.py index 1e0842021..33e2bb560 100644 --- a/walkoff/server/endpoints/metadata.py +++ b/walkoff/server/endpoints/metadata.py @@ -52,8 +52,8 @@ def validate_path(directory, filename): def read_client_file(filename): - file = validate_path(walkoff.config.Config.CLIENT_PATH, filename) - if file is not None: - return send_file(file), 200 + f = validate_path(walkoff.config.Config.CLIENT_PATH, filename) + if f is not None: + return send_file(f), 200 else: return {"error": "invalid path"}, 463 diff --git a/walkoff/server/endpoints/metrics.py b/walkoff/server/endpoints/metrics.py index 16231970e..95f865f1c 100644 --- a/walkoff/server/endpoints/metrics.py +++ b/walkoff/server/endpoints/metrics.py @@ -1,9 +1,9 @@ -from flask_jwt_extended import jwt_required from flask import current_app +from flask_jwt_extended import jwt_required +from walkoff.executiondb.metrics import AppMetric, WorkflowMetric from walkoff.security import permissions_accepted_for_resources, ResourcePermissions from walkoff.server.returncodes import * -from walkoff.executiondb.metrics import AppMetric, WorkflowMetric def read_app_metrics(): diff --git a/walkoff/server/endpoints/playbooks.py b/walkoff/server/endpoints/playbooks.py index b6a898205..17183122e 100644 --- a/walkoff/server/endpoints/playbooks.py +++ b/walkoff/server/endpoints/playbooks.py @@ -8,7 +8,7 @@ from walkoff.executiondb.playbook import Playbook from walkoff.executiondb.workflow import Workflow -from walkoff.helpers import InvalidExecutionElement, regenerate_workflow_ids +from walkoff.helpers import regenerate_workflow_ids from walkoff.security import permissions_accepted_for_resources, ResourcePermissions from walkoff.server.decorators import with_resource_factory, validate_resource_exists_factory, is_valid_uid from walkoff.server.returncodes import * @@ -329,7 +329,8 @@ def delete_workflow(workflow_id): @permissions_accepted_for_resources(ResourcePermissions('playbooks', ['delete'])) @with_workflow('delete', workflow_id) def __func(workflow): - playbook = current_app.running_context.execution_db.session.query(Playbook).filter_by(id=workflow.playbook_id).first() + playbook = current_app.running_context.execution_db.session.query(Playbook).filter_by( + id=workflow.playbook_id).first() playbook_workflows = len(playbook.workflows) - 1 workflow = current_app.running_context.execution_db.session.query(Workflow).filter_by(id=workflow_id).first() current_app.running_context.execution_db.session.delete(workflow) @@ -365,7 +366,8 @@ def __func(workflow): regenerate_workflow_ids(workflow_json) if current_app.running_context.execution_db.session.query(exists().where(Playbook.id == playbook_id)).scalar(): - playbook = current_app.running_context.execution_db.session.query(Playbook).filter_by(id=playbook_id).first() + playbook = current_app.running_context.execution_db.session.query(Playbook).filter_by( + id=playbook_id).first() else: current_app.running_context.execution_db.session.rollback() current_app.logger.error('Could not copy workflow {}. Playbook does not exist'.format(playbook_id)) diff --git a/walkoff/server/endpoints/triggers.py b/walkoff/server/endpoints/triggers.py index 4e548422c..7e15ced41 100644 --- a/walkoff/server/endpoints/triggers.py +++ b/walkoff/server/endpoints/triggers.py @@ -1,8 +1,8 @@ from flask import request, current_app from flask_jwt_extended import jwt_required, get_jwt_identity, get_jwt_claims -from walkoff.messaging.utils import log_action_taken_on_message from walkoff.executiondb.argument import Argument +from walkoff.messaging.utils import log_action_taken_on_message from walkoff.security import permissions_accepted_for_resources, ResourcePermissions from walkoff.server.returncodes import * from walkoff.serverdb.message import Message diff --git a/walkoff/server/endpoints/workflowqueue.py b/walkoff/server/endpoints/workflowqueue.py index 582234c03..5d0b5516a 100644 --- a/walkoff/server/endpoints/workflowqueue.py +++ b/walkoff/server/endpoints/workflowqueue.py @@ -18,11 +18,13 @@ def does_workflow_exist(workflow_id): def does_execution_id_exist(execution_id): - return current_app.running_context.execution_db.session.query(exists().where(WorkflowStatus.execution_id == execution_id)).scalar() + return current_app.running_context.execution_db.session.query( + exists().where(WorkflowStatus.execution_id == execution_id)).scalar() def workflow_status_getter(execution_id): - return current_app.running_context.execution_db.session.query(WorkflowStatus).filter_by(execution_id=execution_id).first() + return current_app.running_context.execution_db.session.query(WorkflowStatus).filter_by( + execution_id=execution_id).first() def workflow_getter(workflow_id): @@ -31,7 +33,6 @@ def workflow_getter(workflow_id): with_workflow = with_resource_factory('workflow', workflow_getter, validator=is_valid_uid) - with_workflow_status = with_resource_factory('workflow', workflow_status_getter, validator=is_valid_uid) validate_workflow_is_registered = validate_resource_exists_factory('workflow', does_workflow_exist) validate_execution_id_is_registered = validate_resource_exists_factory('workflow', does_execution_id_exist) @@ -104,7 +105,8 @@ def __func(workflow): 'Cannot execute workflow.', 'Some arguments are invalid. Reason: {}'.format(errors)) - execution_id = current_app.running_context.executor.execute_workflow(workflow_id, start=start, start_arguments=arguments) + execution_id = current_app.running_context.executor.execute_workflow(workflow_id, start=start, + start_arguments=arguments) current_app.logger.info('Executed workflow {0}'.format(workflow_id)) return {'id': execution_id}, SUCCESS_ASYNC diff --git a/walkoff/server/workflowresults.py b/walkoff/server/workflowresults.py index afa14a485..a60c85bde 100644 --- a/walkoff/server/workflowresults.py +++ b/walkoff/server/workflowresults.py @@ -10,153 +10,144 @@ @WalkoffEvent.WorkflowExecutionPending.connect def __workflow_pending(sender, **kwargs): - with current_app.app_context(): - current_app.running_context.execution_db.session.expire_all() - workflow_status = current_app.running_context.execution_db.session.query(WorkflowStatus).filter_by( - execution_id=sender['execution_id']).first() - if workflow_status: - workflow_status.status = WorkflowStatusEnum.pending - else: - workflow_status = WorkflowStatus(sender['execution_id'], sender['id'], sender['name']) - current_app.running_context.execution_db.session.add(workflow_status) - current_app.running_context.execution_db.session.commit() + current_app.running_context.execution_db.session.expire_all() + workflow_status = current_app.running_context.execution_db.session.query(WorkflowStatus).filter_by( + execution_id=sender['execution_id']).first() + if workflow_status: + workflow_status.status = WorkflowStatusEnum.pending + else: + workflow_status = WorkflowStatus(sender['execution_id'], sender['id'], sender['name']) + current_app.running_context.execution_db.session.add(workflow_status) + current_app.running_context.execution_db.session.commit() @WalkoffEvent.WorkflowExecutionStart.connect def __workflow_started_callback(sender, **kwargs): - with current_app.app_context(): - current_app.running_context.execution_db.session.expire_all() - workflow_status = current_app.running_context.execution_db.session.query(WorkflowStatus).filter_by( - execution_id=sender['execution_id']).first() - workflow_status.running() - current_app.running_context.execution_db.session.commit() + current_app.running_context.execution_db.session.expire_all() + workflow_status = current_app.running_context.execution_db.session.query(WorkflowStatus).filter_by( + execution_id=sender['execution_id']).first() + workflow_status.running() + current_app.running_context.execution_db.session.commit() @WalkoffEvent.WorkflowPaused.connect def __workflow_paused_callback(sender, **kwargs): - with current_app.app_context(): - current_app.running_context.execution_db.session.expire_all() - workflow_status = current_app.running_context.execution_db.session.query(WorkflowStatus).filter_by( - execution_id=sender['execution_id']).first() - workflow_status.paused() - current_app.running_context.execution_db.session.commit() + current_app.running_context.execution_db.session.expire_all() + workflow_status = current_app.running_context.execution_db.session.query(WorkflowStatus).filter_by( + execution_id=sender['execution_id']).first() + workflow_status.paused() + current_app.running_context.execution_db.session.commit() @WalkoffEvent.TriggerActionAwaitingData.connect def __workflow_awaiting_data_callback(sender, **kwargs): - with current_app.app_context(): - workflow_execution_id = kwargs['data']['workflow']['execution_id'] - current_app.running_context.execution_db.session.expire_all() - workflow_status = current_app.running_context.execution_db.session.query(WorkflowStatus).filter_by( - execution_id=workflow_execution_id).first() - workflow_status.awaiting_data() - current_app.running_context.execution_db.session.commit() + workflow_execution_id = kwargs['data']['workflow']['execution_id'] + current_app.running_context.execution_db.session.expire_all() + workflow_status = current_app.running_context.execution_db.session.query(WorkflowStatus).filter_by( + execution_id=workflow_execution_id).first() + workflow_status.awaiting_data() + current_app.running_context.execution_db.session.commit() @WalkoffEvent.WorkflowShutdown.connect def __workflow_ended_callback(sender, **kwargs): - with current_app.app_context(): - current_app.running_context.execution_db.session.expire_all() - workflow_status = current_app.running_context.execution_db.session.query(WorkflowStatus).filter_by( - execution_id=sender['execution_id']).first() - workflow_status.completed() - current_app.running_context.execution_db.session.commit() - - saved_state = current_app.running_context.execution_db.session.query(SavedWorkflow).filter_by( - workflow_execution_id=sender['execution_id']).first() - if saved_state: - current_app.running_context.execution_db.session.delete(saved_state) - - # Update metrics - execution_time = (workflow_status.completed_at - workflow_status.started_at).total_seconds() - - workflow_metric = current_app.running_context.execution_db.session.query(WorkflowMetric).filter_by(workflow_id=sender['id']).first() - if workflow_metric is None: - workflow_metric = WorkflowMetric(sender['id'], sender['name'], execution_time) - current_app.running_context.execution_db.session.add(workflow_metric) - else: - workflow_metric.update(execution_time) + current_app.running_context.execution_db.session.expire_all() + workflow_status = current_app.running_context.execution_db.session.query(WorkflowStatus).filter_by( + execution_id=sender['execution_id']).first() + workflow_status.completed() + current_app.running_context.execution_db.session.commit() - current_app.running_context.execution_db.session.commit() + saved_state = current_app.running_context.execution_db.session.query(SavedWorkflow).filter_by( + workflow_execution_id=sender['execution_id']).first() + if saved_state: + current_app.running_context.execution_db.session.delete(saved_state) + + # Update metrics + execution_time = (workflow_status.completed_at - workflow_status.started_at).total_seconds() + + workflow_metric = current_app.running_context.execution_db.session.query(WorkflowMetric).filter_by( + workflow_id=sender['id']).first() + if workflow_metric is None: + workflow_metric = WorkflowMetric(sender['id'], sender['name'], execution_time) + current_app.running_context.execution_db.session.add(workflow_metric) + else: + workflow_metric.update(execution_time) + + current_app.running_context.execution_db.session.commit() @WalkoffEvent.WorkflowAborted.connect def __workflow_aborted(sender, **kwargs): - with current_app.app_context(): - current_app.running_context.execution_db.session.expire_all() - workflow_status = current_app.running_context.execution_db.session.query(WorkflowStatus).filter_by( - execution_id=sender['execution_id']).first() - workflow_status.aborted() + current_app.running_context.execution_db.session.expire_all() + workflow_status = current_app.running_context.execution_db.session.query(WorkflowStatus).filter_by( + execution_id=sender['execution_id']).first() + workflow_status.aborted() - saved_state = current_app.running_context.execution_db.session.query(SavedWorkflow).filter_by( - workflow_execution_id=sender['execution_id']).first() - if saved_state: - current_app.running_context.execution_db.session.delete(saved_state) + saved_state = current_app.running_context.execution_db.session.query(SavedWorkflow).filter_by( + workflow_execution_id=sender['execution_id']).first() + if saved_state: + current_app.running_context.execution_db.session.delete(saved_state) - current_app.running_context.execution_db.session.commit() + current_app.running_context.execution_db.session.commit() @WalkoffEvent.ActionStarted.connect def __action_start_callback(sender, **kwargs): - with current_app.app_context(): - workflow_execution_id = kwargs['data']['workflow']['execution_id'] - current_app.running_context.execution_db.session.expire_all() - action_status = current_app.running_context.execution_db.session.query(ActionStatus).filter_by( - execution_id=sender['execution_id']).first() - if action_status: - action_status.status = ActionStatusEnum.executing - else: - workflow_status = current_app.running_context.execution_db.session.query(WorkflowStatus).filter_by( - execution_id=workflow_execution_id).first() - arguments = sender['arguments'] if 'arguments' in sender else [] - action_status = ActionStatus(sender['execution_id'], sender['id'], sender['name'], sender['app_name'], - sender['action_name'], json.dumps(arguments)) - workflow_status.add_action_status(action_status) - current_app.running_context.execution_db.session.add(action_status) + workflow_execution_id = kwargs['data']['workflow']['execution_id'] + current_app.running_context.execution_db.session.expire_all() + action_status = current_app.running_context.execution_db.session.query(ActionStatus).filter_by( + execution_id=sender['execution_id']).first() + if action_status: + action_status.status = ActionStatusEnum.executing + else: + workflow_status = current_app.running_context.execution_db.session.query(WorkflowStatus).filter_by( + execution_id=workflow_execution_id).first() + arguments = sender['arguments'] if 'arguments' in sender else [] + action_status = ActionStatus(sender['execution_id'], sender['id'], sender['name'], sender['app_name'], + sender['action_name'], json.dumps(arguments)) + workflow_status.add_action_status(action_status) + current_app.running_context.execution_db.session.add(action_status) - current_app.running_context.execution_db.session.commit() + current_app.running_context.execution_db.session.commit() @WalkoffEvent.ActionExecutionSuccess.connect def __action_execution_success_callback(sender, **kwargs): - with current_app.app_context(): - current_app.running_context.execution_db.session.expire_all() - action_status = current_app.running_context.execution_db.session.query(ActionStatus).filter_by( - execution_id=sender['execution_id']).first() - action_status.completed_success(kwargs['data']['data']) + current_app.running_context.execution_db.session.expire_all() + action_status = current_app.running_context.execution_db.session.query(ActionStatus).filter_by( + execution_id=sender['execution_id']).first() + action_status.completed_success(kwargs['data']['data']) - # Update metrics - __update_success_action_tracker(action_status) + # Update metrics + __update_success_action_tracker(action_status) - current_app.running_context.execution_db.session.commit() + current_app.running_context.execution_db.session.commit() @WalkoffEvent.ActionExecutionError.connect def __action_execution_error_callback(sender, **kwargs): - with current_app.app_context(): - current_app.running_context.execution_db.session.expire_all() - action_status = current_app.running_context.execution_db.session.query(ActionStatus).filter_by( - execution_id=sender['execution_id']).first() + current_app.running_context.execution_db.session.expire_all() + action_status = current_app.running_context.execution_db.session.query(ActionStatus).filter_by( + execution_id=sender['execution_id']).first() - action_status.completed_failure(kwargs['data']['data']) + action_status.completed_failure(kwargs['data']['data']) - # Update metrics - __update_error_action_tracker(action_status) - current_app.running_context.execution_db.session.commit() + # Update metrics + __update_error_action_tracker(action_status) + current_app.running_context.execution_db.session.commit() @WalkoffEvent.ActionArgumentsInvalid.connect def __action_args_invalid_callback(sender, **kwargs): - with current_app.app_context(): - current_app.running_context.execution_db.session.expire_all() - action_status = current_app.running_context.execution_db.session.query(ActionStatus).filter_by( - execution_id=sender['execution_id']).first() + current_app.running_context.execution_db.session.expire_all() + action_status = current_app.running_context.execution_db.session.query(ActionStatus).filter_by( + execution_id=sender['execution_id']).first() - action_status.completed_failure(kwargs['data']['data']) + action_status.completed_failure(kwargs['data']['data']) - # Update metrics - __update_error_action_tracker(action_status) - current_app.running_context.execution_db.session.commit() + # Update metrics + __update_error_action_tracker(action_status) + current_app.running_context.execution_db.session.commit() def __update_success_action_tracker(action_status): @@ -168,7 +159,8 @@ def __update_error_action_tracker(action_status): def __update_action_tracker(status, action_status): - app_metric = current_app.running_context.execution_db.session.query(AppMetric).filter_by(app=action_status.app_name).first() + app_metric = current_app.running_context.execution_db.session.query(AppMetric).filter_by( + app=action_status.app_name).first() if app_metric is None: app_metric = AppMetric(action_status.app_name) current_app.running_context.execution_db.session.add(app_metric) diff --git a/walkoff/serverdb/__init__.py b/walkoff/serverdb/__init__.py index 753a50a1a..e03f53c9a 100644 --- a/walkoff/serverdb/__init__.py +++ b/walkoff/serverdb/__init__.py @@ -36,6 +36,7 @@ def initialize_default_resources_admin(): + """Initializes the default resources for an admin user""" admin = Role.query.filter(Role.id == 1).first() if not admin: admin = Role("admin", resources=default_resource_permissions_admin) @@ -46,6 +47,7 @@ def initialize_default_resources_admin(): def initialize_default_resources_guest(): + """Initializes the default resources for a guest user""" guest = Role.query.filter(Role.name == "guest").first() if not guest: guest = Role("guest", resources=default_resource_permissions_guest) @@ -56,15 +58,15 @@ def initialize_default_resources_guest(): def get_roles_by_resource_permissions(resource_permission): - resource = resource_permission.resource + r = resource_permission.resource permissions = resource_permission.permissions roles = [] for permission in permissions: roles.extend(Role.query.join(Role.resources).join(Resource.permissions).filter( - Resource.name == resource, Permission.name == permission).all()) + Resource.name == r, Permission.name == permission).all()) - return {role.id for role in roles} + return {role_obj.id for role_obj in roles} def set_resources_for_role(role_name, resources): @@ -75,8 +77,8 @@ def set_resources_for_role(role_name, resources): resources (dict[resource:list[permission]): A dictionary containing the name of the resource, with the value being a list of permission names """ - role = Role.query.filter(Role.name == role_name).first() - role.set_resources(resources) + r = Role.query.filter(Role.name == role_name).first() + r.set_resources(resources) def clear_resources_for_role(role_name): @@ -85,12 +87,17 @@ def clear_resources_for_role(role_name): Args: role_name (str): The name of the role. """ - role = Role.query.filter(Role.name == role_name).first() - role.resources = [] + r = Role.query.filter(Role.name == role_name).first() + r.resources = [] db.session.commit() def get_all_available_resource_actions(): + """Gets a list of all of the available resource actions + + Returns: + (list[dict]): A list of dicts containing the resource name and the actions available for that resource + """ resource_actions = [] for resource_perm in default_resource_permissions_admin: resource_actions.append( @@ -104,16 +111,16 @@ def add_user(username, password, roles=None): Args: username (str): The username for the User. password (str): The password for the User. - roles (list[int]): A list of roles for the User. + roles (list[int], optional): A list of roles for the User. Defaults to None. Returns: - The new User object if successful, else None. + (User): The new User object if successful, else None. """ if User.query.filter_by(username=username).first() is None: - user = User(username, password, roles=roles) - db.session.add(user) + u = User(username, password, roles=roles) + db.session.add(u) db.session.commit() - return user + return u else: return None diff --git a/walkoff/serverdb/message.py b/walkoff/serverdb/message.py index 5da6a44b1..048c8fb37 100644 --- a/walkoff/serverdb/message.py +++ b/walkoff/serverdb/message.py @@ -14,7 +14,6 @@ db.Column('user_id', db.Integer, db.ForeignKey('user.id')), db.Column('message_id', db.Integer, db.ForeignKey('message.id'))) - role_messages_association = db.Table('role_messages', db.Column('role_id', db.Integer, db.ForeignKey('role.id')), db.Column('message_id', db.Integer, db.ForeignKey('message.id'))) diff --git a/walkoff/serverdb/role.py b/walkoff/serverdb/role.py index 23e01ebf1..4e8191bf5 100644 --- a/walkoff/serverdb/role.py +++ b/walkoff/serverdb/role.py @@ -58,7 +58,7 @@ def as_json(self, with_users=False): Role in the JSON representation. Defaults to False. Returns: - The dictionary representation of the Role object. + (dict): The dictionary representation of the Role object. """ out = {"id": self.id, "name": self.name, diff --git a/walkoff/serverdb/tokens.py b/walkoff/serverdb/tokens.py index fa75aab42..1c409a251 100644 --- a/walkoff/serverdb/tokens.py +++ b/walkoff/serverdb/tokens.py @@ -17,7 +17,7 @@ def as_json(self): """Get the JSON representation of a BlacklistedToken object. Returns: - The JSON representation of a BlacklistedToken object. + (dict): The JSON representation of a BlacklistedToken object. """ return { 'id': self.id, @@ -28,7 +28,10 @@ def as_json(self): def revoke_token(decoded_token): - """Adds a new token to the database. It is not revoked when it is added. + """Adds a new token to the database. It is not revoked when it is added + + Args: + decoded_token (dict): The decoded token """ jti = decoded_token['jti'] user_identity = decoded_token[current_app.config['JWT_IDENTITY_CLAIM']] @@ -51,7 +54,7 @@ def is_token_revoked(decoded_token): it was created. Returns: - True if the token is revoked, False otherwise. + (bool): True if the token is revoked, False otherwise. """ jti = decoded_token['jti'] token = BlacklistedToken.query.filter_by(jti=jti).first() @@ -59,7 +62,11 @@ def is_token_revoked(decoded_token): def approve_token(token_id, user): - """Approves the given token. + """Approves the given token + + Args: + token_id (int): The ID of the token + user (User): The User """ token = BlacklistedToken.query.filter_by(id=token_id, user_identity=user).first() if token is not None: @@ -69,14 +76,14 @@ def approve_token(token_id, user): def prune_if_necessary(): + """Prunes the database if necessary""" current_app.running_context.cache.incr("number_of_operations") if current_app.running_context.cache.get("number_of_operations") >= prune_frequency: prune_database() def prune_database(): - """Delete tokens that have expired from the database. - """ + """Delete tokens that have expired from the database""" now = datetime.now() expired = BlacklistedToken.query.filter(BlacklistedToken.expires < now).all() for token in expired: diff --git a/walkoff/serverdb/user.py b/walkoff/serverdb/user.py index 7a9d70d17..62dd92314 100644 --- a/walkoff/serverdb/user.py +++ b/walkoff/serverdb/user.py @@ -35,7 +35,7 @@ def __init__(self, name, password, roles=None): Args: name (str): The username for the User. password (str): The password for the User. - roles (list[int]): List of Role ids for the User. Defaults to None. + roles (list[int], optional): List of Role ids for the User. Defaults to None. """ self.username = name self._password = pbkdf2_sha512.hash(password) @@ -46,6 +46,9 @@ def __init__(self, name, password, roles=None): @hybrid_property def password(self): """Returns the password for the user. + + Returns: + (str): The password """ return self._password @@ -65,7 +68,7 @@ def verify_password(self, password_attempt): password_attempt(str): The input password. Returns: - True if the passwords match, False if not. + (bool): True if the passwords match, False if not. """ return pbkdf2_sha512.verify(password_attempt, self._password) @@ -98,8 +101,7 @@ def login(self, ip_address): self.login_count += 1 def logout(self): - """Tracks login/logout information for the User upon logging out. - """ + """Tracks login/logout information for the User upon logging out""" if self.login_count > 0: self.login_count -= 1 else: @@ -113,7 +115,7 @@ def has_role(self, role): role (int): The ID of the Role. Returns: - True if the User has the Role, False otherwise. + (bool): True if the User has the Role, False otherwise. """ return role in [role.id for role in self.roles] @@ -125,7 +127,7 @@ def as_json(self, with_user_history=False): representation of the User. Defaults to False. Returns: - The dictionary representation of a User object. + (dict): The dictionary representation of a User object. """ out = {"id": self.id, "username": self.username, diff --git a/walkoff/sse.py b/walkoff/sse.py index 781a06075..675eecd32 100644 --- a/walkoff/sse.py +++ b/walkoff/sse.py @@ -1,10 +1,12 @@ +import collections +import json from functools import wraps + from flask import Response, Blueprint -import json -from walkoff.cache import unsubscribe_message -import collections from six import string_types, binary_type +from walkoff.cache import unsubscribe_message + class StreamableBlueprint(Blueprint): """Blueprint which has streams. @@ -14,6 +16,7 @@ class StreamableBlueprint(Blueprint): Attributes: stream (dict{str: Stream}): A lookup for the streams of this blueprint. The key is the channel of the stream """ + def __init__(self, *args, **kwargs): """ Kwargs: @@ -50,6 +53,7 @@ class SseEvent(object): data: The data related to this SSE """ + def __init__(self, event, data): self.event = event self.data = data @@ -100,6 +104,7 @@ class SseStream(object): cache (:obj:, optional): The cache to use for this SSE stream. Defaults to the `walkoff.cache.cache` used throughout Walkoff """ + def __init__(self, channel, cache=None): self.channel = channel self.cache = cache @@ -119,12 +124,14 @@ def push(self, event=''): Returns: (func): The decorated function """ + def decorator(func): @wraps(func) def wrapper(*args, **kwargs): response = func(*args, **kwargs) self._publish_response(response, event) return response + return wrapper return decorator @@ -228,6 +235,7 @@ class FilteredSseStream(SseStream): cache (:obj:, optional): The cache to use for this SSE stream. Defaults to the `walkoff.cache.cache` used throughout Walkoff """ + def __init__(self, channel, cache=None): super(FilteredSseStream, self).__init__(channel, cache) @@ -330,6 +338,7 @@ class InterfaceSseStream(SseStream): channel (str): The name of the channel cache (optional): The cache object used for this SSE stream """ + def __init__(self, interface, channel, cache=None): super(InterfaceSseStream, self).__init__(create_interface_channel_name(interface, channel), cache=cache) self.interface = interface @@ -346,6 +355,7 @@ class FilteredInterfaceSseStream(FilteredSseStream): channel (str): The name of the channel cache (optional): The cache object used for this SSE stream """ + def __init__(self, interface, channel, cache=None): super(FilteredInterfaceSseStream, self).__init__(create_interface_channel_name(interface, channel), cache=cache) self.interface = interface