Skip to content

Commit

Permalink
Add an optional authorization layer to RPCMiddleware
Browse files Browse the repository at this point in the history
  • Loading branch information
stefano-maggiolo committed Jul 16, 2015
1 parent a157a27 commit 8475009
Show file tree
Hide file tree
Showing 4 changed files with 120 additions and 17 deletions.
17 changes: 14 additions & 3 deletions cms/io/web_rpc.py
Expand Up @@ -3,6 +3,7 @@

# Contest Management System - http://cms-dev.github.io/
# Copyright © 2013 Luca Wehrstedt <luca.wehrstedt@gmail.com>
# Copyright © 2015 Stefano Maggiolo <s.maggiolo@gmail.com>
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
Expand Down Expand Up @@ -68,14 +69,18 @@ class RPCMiddleware(object):
string describing the error that occured (if any).
"""
def __init__(self, service):
def __init__(self, service, auth=None):
"""Create an HTTP-to-RPC proxy for the given service.
service (Service): the service this application is running for.
Will usually be the AdminWebServer.
auth (function|None): a function taking the environ of a request
and returning whether the request is allowed. If not present,
all requests are allowed.
"""
self._service = service
self._auth = auth
self._url_map = Map([Rule("/<service>/<int:shard>/<method>",
methods=["POST"], endpoint="rpc")],
encoding_errors="strict")
Expand Down Expand Up @@ -105,11 +110,17 @@ def wsgi_app(self, environ, start_response):

assert endpoint == "rpc"

response = Response()

if self._auth is not None and not self._auth(environ):
response.status_code = 403
response.mimetype = "plain/text"
response.data = "Request not allowed."
return response

request = Request(environ)
request.encoding_errors = "strict"

response = Response()

remote_service = ServiceCoord(args['service'], args['shard'])

if remote_service not in self._service.remote_services:
Expand Down
17 changes: 9 additions & 8 deletions cms/io/web_service.py
Expand Up @@ -61,12 +61,21 @@ def __init__(self, listen_port, handlers, parameters, shard=0,

static_files = parameters.pop('static_files', [])
rpc_enabled = parameters.pop('rpc_enabled', False)
rpc_auth = parameters.pop('rpc_auth', None)
auth_middleware = parameters.pop('auth_middleware', None)
is_proxy_used = parameters.pop('is_proxy_used', False)

self.wsgi_app = tornado.wsgi.WSGIApplication(handlers, **parameters)
self.wsgi_app.service = self

for entry in static_files:
self.wsgi_app = SharedDataMiddleware(
self.wsgi_app, {"/static": entry})

if rpc_enabled:
self.wsgi_app = DispatcherMiddleware(
self.wsgi_app, {"/rpc": RPCMiddleware(self, rpc_auth)})

# Remove any authentication header that a user may try to fake.
self.wsgi_app = HeaderRewriterFix(
self.wsgi_app,
Expand All @@ -75,14 +84,6 @@ def __init__(self, listen_port, handlers, parameters, shard=0,
if auth_middleware is not None:
self.wsgi_app = auth_middleware(self.wsgi_app)

for entry in static_files:
self.wsgi_app = SharedDataMiddleware(
self.wsgi_app, {"/static": entry})

if rpc_enabled:
self.wsgi_app = DispatcherMiddleware(
self.wsgi_app, {"/rpc": RPCMiddleware(self)})

# If is_proxy_used is set to True we'll use the content of the
# X-Forwarded-For HTTP header (if provided) to determine the
# client IP address, ignoring the one the request came from.
Expand Down
90 changes: 90 additions & 0 deletions cms/server/admin/rpc_authorization.py
@@ -0,0 +1,90 @@
#!/usr/bin/env python2
# -*- coding: utf-8 -*-

# Contest Management System - http://cms-dev.github.io/
# Copyright © 2015 Stefano Maggiolo <s.maggiolo@gmail.com>
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

"""Handle authorization for the RPC calls coming from AWS.
"""

from __future__ import absolute_import
from __future__ import print_function
from __future__ import unicode_literals

from cms.db import Admin, SessionGen
from cms.io import WebService


AUTHENTICATED_USER_HEADER_IN_ENV = "HTTP_" + \
WebService.AUTHENTICATED_USER_HEADER.upper().replace('-', '_')


RPCS_ALLOWED_FOR_AUTHENTICATED = [
("ResourceService", "get_resources"),
("EvaluationService", "workers_status"),
("EvaluationService", "submissions_status"),
("EvaluationService", "queue_status"),
("LogService", "last_messages"),
]


RPCS_ALLOWED_FOR_MESSAGING = RPCS_ALLOWED_FOR_AUTHENTICATED + []


RPCS_ALLOWED_FOR_ALL = RPCS_ALLOWED_FOR_MESSAGING + [
("ResourceService", "kill_resources"),
("ResourceService", "toggle_autorestart"),
("EvaluationService", "enable_worker"),
("EvaluationService", "disable_worker"),
("EvaluationService", "invalidate_submission"),
]


def rpc_authorization_checker(environ):
"""Return whether to accept the request.
environ ({}): WSGI environ object with the request metadata.
return (bool): whether to accept the request or not.
"""
admin_id = int(environ.get(AUTHENTICATED_USER_HEADER_IN_ENV, None))
path_info = environ.get("PATH_INFO", "").strip("/").split("/")

if admin_id is None or len(path_info) < 3:
return False

service = path_info[-3]
# We don't check on shard = path_info[-2].
method = path_info[-1]

with SessionGen() as session:
# Load admin.
admin = session.query(Admin)\
.filter(Admin.id == admin_id)\
.first()
if admin is None:
return False

if admin.permission_all:
return (service, method) in RPCS_ALLOWED_FOR_ALL

elif admin.permission_messaging:
return (service, method) in RPCS_ALLOWED_FOR_MESSAGING

else:
return (service, method) in RPCS_ALLOWED_FOR_AUTHENTICATED
13 changes: 7 additions & 6 deletions cms/server/admin/server.py
Expand Up @@ -44,11 +44,15 @@
from cmscommon.datetime import make_timestamp

from .handlers import HANDLERS

from .rpc_authorization import rpc_authorization_checker

logger = logging.getLogger(__name__)


AUTHENTICATED_USER_HEADER_IN_ENV = "HTTP_" + \
WebService.AUTHENTICATED_USER_HEADER.upper().replace('-', '_')


class AdminWebServer(WebService):
"""Service that runs the web server serving the managers.
Expand All @@ -64,6 +68,7 @@ def __init__(self, shard):
"debug": config.tornado_debug,
"auth_middleware": AWSAuthMiddleware,
"rpc_enabled": True,
"rpc_auth": rpc_authorization_checker,
}
super(AdminWebServer, self).__init__(
config.admin_listen_port,
Expand Down Expand Up @@ -121,9 +126,6 @@ class AWSAuthMiddleware(object):
"""

AUTHENTICATED_USER_HEADER_IN_ENV = "HTTP_" + \
WebService.AUTHENTICATED_USER_HEADER.upper().replace('-', '_')

# Header that the underlying WSGI application can set to ask this
# middleware to create a new cookie, refresh an existing cookie,
# or delete it.
Expand All @@ -144,8 +146,7 @@ def __call__(self, environ, start_response):
request = Request(environ)
admin_id = self._authenticate(request)
if admin_id is not None:
environ[
AWSAuthMiddleware.AUTHENTICATED_USER_HEADER_IN_ENV] = admin_id
environ[AUTHENTICATED_USER_HEADER_IN_ENV] = admin_id
response = self._app(
environ,
self._build_start_response(environ, start_response, admin_id))
Expand Down

0 comments on commit 8475009

Please sign in to comment.