Skip to content

Commit

Permalink
Enable Session-Id headers for slaves
Browse files Browse the repository at this point in the history
The master already returns a session id in response headers, and the
API will raise a 412 if the Session-Id request header specifies a
differing session id.

This change moves that logic one level up so that slaves will also
start generating and returning a slave-specific session id in their
responses. The master now also tracks the session id for each slave.

A subsequent change will add verification when the master sends
requests to the slave to ensure that the master is talking to the
correct slave instance.
  • Loading branch information
josephharrington committed Apr 10, 2016
1 parent 6b6ee21 commit 2a724f5
Show file tree
Hide file tree
Showing 6 changed files with 33 additions and 29 deletions.
5 changes: 3 additions & 2 deletions app/master/cluster_master.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,12 +108,13 @@ def get_slave(self, slave_id=None, slave_url=None):

raise ItemNotFoundError('Requested slave ({}) does not exist.'.format(slave_id))

def connect_slave(self, slave_url, num_executors):
def connect_slave(self, slave_url, num_executors, slave_session_id=None):
"""
Connect a slave to this master.
:type slave_url: str
:type num_executors: int
:type slave_session_id: str | None
:return: The response with the slave id of the slave.
:rtype: dict[str, str]
"""
Expand All @@ -138,7 +139,7 @@ def connect_slave(self, slave_url, num_executors):
self._logger.info('Failed to find build {} that was running on {}', old_slave.current_build_id,
slave_url)

slave = Slave(slave_url, num_executors)
slave = Slave(slave_url, num_executors, slave_session_id)
self._all_slaves_by_url[slave_url] = slave
self._slave_allocator.add_idle_slave(slave)
self._logger.info('Slave on {} connected to master with {} executors. (id: {})',
Expand Down
5 changes: 4 additions & 1 deletion app/master/slave.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,11 @@ class Slave(object):
API_VERSION = 'v1'
_slave_id_counter = Counter()

def __init__(self, slave_url, num_executors):
def __init__(self, slave_url, num_executors, slave_session_id=None):
"""
:type slave_url: str
:type num_executors: int
:type slave_session_id: str
"""
self.url = slave_url
self.num_executors = num_executors
Expand All @@ -27,12 +28,14 @@ def __init__(self, slave_url, num_executors):
self._is_alive = True
self._is_in_shutdown_mode = False
self._slave_api = UrlBuilder(slave_url, self.API_VERSION)
self._session_id = slave_session_id
self._logger = log.get_logger(__name__)

def api_representation(self):
return {
'url': self.url,
'id': self.id,
'session_id': self._session_id,
'num_executors': self.num_executors,
'num_executors_in_use': self.num_executors_in_use(),
'current_build_id': self.current_build_id,
Expand Down
3 changes: 3 additions & 0 deletions app/slave/cluster_slave.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from app.util.network import Network
from app.util.safe_thread import SafeThread
from app.util.secret import Secret
from app.util.session_id import SessionId
from app.util.single_use_coin import SingleUseCoin
from app.util.unhandled_exception_handler import UnhandledExceptionHandler
from app.util.url_builder import UrlBuilder
Expand Down Expand Up @@ -67,6 +68,7 @@ def api_representation(self):
'current_build_id': self._current_build_id,
'slave_id': self._slave_id,
'executors': executors_representation,
'session_id': SessionId.get(),
}

def get_status(self):
Expand Down Expand Up @@ -226,6 +228,7 @@ def connect_to_master(self, master_url=None):
data = {
'slave': '{}:{}'.format(self.host, self.port),
'num_executors': self._num_executors,
'session_id': SessionId.get()
}
response = self._network.post(connect_url, data=data)
self._slave_id = int(response.json().get('slave_id'))
Expand Down
8 changes: 5 additions & 3 deletions app/util/log.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

from app.util import autoversioning, fs
from app.util.conf.configuration import Configuration
from app.util.session_id import SessionId


# This custom format string takes care of setting field widths to make logs more aligned and readable.
Expand Down Expand Up @@ -123,7 +124,7 @@ def configure_logging(log_level=None, log_file=None, simplified_console_logs=Fal
event_handler.log_application_summary()


def application_summary(logfile_count):
def application_summary(logfile_count): # todo: move this method to app_info.py
"""
Return a string summarizing general info about the application. This will be output at the start of every logfile.
Expand All @@ -133,8 +134,9 @@ def application_summary(logfile_count):
separator = '*' * 50
summary_lines = [
' ClusterRunner',
' * Version: {}'.format(autoversioning.get_version()),
' * PID: {}'.format(os.getpid()),
' * Version: {}'.format(autoversioning.get_version()),
' * PID: {}'.format(os.getpid()),
' * Session id: {}'.format(SessionId.get()),
]
if logfile_count > 1:
summary_lines.append(' * Logfile count: {}'.format(logfile_count))
Expand Down
15 changes: 15 additions & 0 deletions app/web_framework/cluster_base_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from app.util.conf.configuration import Configuration
from app.util.exceptions import AuthenticationError, BadRequestError, ItemNotFoundError, ItemNotReadyError, PreconditionFailedError
from app.util.network import ENCODED_BODY
from app.util.session_id import SessionId


# pylint: disable=attribute-defined-outside-init
Expand Down Expand Up @@ -73,6 +74,8 @@ def prepare(self):
"""
Called at the beginning of a request before `get`/`post`/etc.
"""
self._check_expected_session_id()

# Decode an encoded body, if present. Otherwise fall back to decoding the raw request body. See the comments in
# the util.network.Network class for more information about why we're doing this.
try:
Expand All @@ -82,6 +85,17 @@ def prepare(self):
except ValueError as ex:
raise BadRequestError('Invalid JSON in request body.') from ex

def _check_expected_session_id(self):
"""
If the request has specified the session id, which is optional, and the session id does not match
the current instance's session id, then the requester is asking for a resource that has expired and
no longer exists.
"""
session_id = self.request.headers.get(SessionId.SESSION_HEADER_KEY)

if session_id is not None and session_id != SessionId.get():
raise PreconditionFailedError('Specified session id: {} has expired and is invalid.'.format(session_id))

def options(self, *args, **kwargs):
"""
Enable OPTIONS on all endpoints by default (preflight AJAX requests requires this).
Expand Down Expand Up @@ -124,6 +138,7 @@ def get_child_routes(self):

def set_default_headers(self):
self.set_header('Content-Type', 'application/json')
self.set_header(SessionId.SESSION_HEADER_KEY, SessionId.get())

request_origin = self.request.headers.get('Origin') # usually only set when making API request from a browser
if request_origin and self._is_request_origin_allowed(request_origin):
Expand Down
26 changes: 3 additions & 23 deletions app/web_framework/cluster_master_application.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,7 @@
from app.util import log
from app.util.conf.configuration import Configuration
from app.util.decorators import authenticated
from app.util.exceptions import ItemNotFoundError, PreconditionFailedError
from app.util.session_id import SessionId
from app.util.exceptions import ItemNotFoundError
from app.util.url_builder import UrlBuilder
from app.web_framework.cluster_application import ClusterApplication
from app.web_framework.cluster_base_handler import ClusterBaseAPIHandler, ClusterBaseHandler
Expand Down Expand Up @@ -74,26 +73,6 @@ def initialize(self, route_node=None, cluster_master=None):
self._cluster_master = cluster_master
super().initialize(route_node)

def prepare(self):
"""
If the request has specified the session id, which is optional, and the session id does not match
the current instance's session id, then the client is asking for a resource that has expired and
no longer exists.
"""
session_id = self.request.headers.get(SessionId.SESSION_HEADER_KEY)

if session_id is not None and session_id != SessionId.get():
raise PreconditionFailedError('Specified session id: {} has expired and is invalid.'.format(session_id))

super().prepare()

def set_default_headers(self):
"""
Inject the session id in the header.
"""
self.set_header(SessionId.SESSION_HEADER_KEY, SessionId.get())
super().set_default_headers()


class _RootHandler(_ClusterMasterBaseAPIHandler):
pass
Expand Down Expand Up @@ -296,7 +275,8 @@ class _SlavesHandler(_ClusterMasterBaseAPIHandler):
def post(self):
slave_url = self.decoded_body.get('slave')
num_executors = int(self.decoded_body.get('num_executors'))
response = self._cluster_master.connect_slave(slave_url, num_executors)
session_id = self.decoded_body.get('session_id')
response = self._cluster_master.connect_slave(slave_url, num_executors, session_id)
self._write_status(response, status_code=201)

def get(self):
Expand Down

0 comments on commit 2a724f5

Please sign in to comment.