Isolate /add endpoint behind a flag
Add an `allow_add` flag to the `blaze.server.Server` constructor that
determines whether the `/add` route is available; it defaults to `False`.

If the flag is not set, the URL mapping for `/add` is not defined at all, so
any request to it results in a 404 Not Found error.
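A minimal usage sketch, assuming a local pandas DataFrame as the served
dataset; the `accounts` name is illustrative and not part of this commit:

    import pandas as pd
    from blaze.server import Server

    accounts = pd.DataFrame({'name': ['Alice', 'Bob'], 'balance': [100, 200]})

    # Without allow_add=True, the /add route is never registered and a
    # POST to /add returns 404 Not Found.
    server = Server({'accounts': accounts}, allow_add=True)
    server.run()  # serves on the default Blaze port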
Rami Chowdhury committed Apr 19, 2016
1 parent 030160b commit 1b462d4
Showing 2 changed files with 98 additions and 51 deletions.
19 changes: 16 additions & 3 deletions blaze/server/server.py
@@ -47,6 +47,7 @@ class RC(object):

BAD_REQUEST = 400
UNAUTHORIZED = 401
NOT_FOUND = 404
FORBIDDEN = 403
CONFLICT = 409
UNPROCESSABLE_ENTITY = 422
@@ -109,6 +110,13 @@ def _register_api(app, options, first_registration=False):
profiler_output,
profile_by_default)

# Allowing users to dynamically add datasets to the Blaze server can be
# dangerous, so we only expose the method if specifically requested
allow_add = _get_option('allow_add', options, False)
if allow_add:
app.add_url_rule('/add', 'addserver', addserver,
methods=['POST', 'HEAD', 'OPTIONS'])

# Call the original register function.
Blueprint.register(api, app, options, first_registration)

@@ -254,6 +262,10 @@ class Server(object):
Run the profiler on any computation that does not explicitly set
"profile": false.
This requires `allow_profiler=True`.
allow_add : bool, optional
Expose an `/add` endpoint to allow datasets to be dynamically added to
the server. Since this increases the risk of security holes, it defaults
to `False`.
Examples
--------
@@ -274,7 +286,8 @@ def __init__(self,
authorization=None,
allow_profiler=False,
profiler_output=None,
profile_by_default=False):
profile_by_default=False,
allow_add=False):
if isinstance(data, collections.Mapping):
data = valmap(lambda v: v.data if isinstance(v, _Data) else v,
data)
@@ -289,7 +302,8 @@ def __init__(self,
authorization=authorization,
allow_profiler=allow_profiler,
profiler_output=profiler_output,
profile_by_default=profile_by_default)
profile_by_default=profile_by_default,
allow_add=allow_add)
self.data = data

def run(self, port=DEFAULT_PORT, retry=False, **kwargs):
@@ -578,7 +592,6 @@ def log_time(start=time()):
return serial.dumps(response)


@api.route('/add', methods=['POST', 'HEAD', 'OPTIONS'])
@cross_origin(origins='*', methods=['POST', 'HEAD', 'OPTIONS'])
@authorization
@check_request
130 changes: 82 additions & 48 deletions blaze/server/tests/test_server.py
@@ -80,6 +80,13 @@ def server():
return s


@pytest.fixture(scope='module')
def add_server():
s = Server(tdata, all_formats, allow_add=True)
s.app.testing = True
return s


@pytest.yield_fixture(params=[None, tdata])
def temp_server(request):
"""For when we want to mutate the server"""
@@ -89,17 +96,32 @@ def temp_server(request):
with s.app.test_client() as c:
yield c

@pytest.yield_fixture(params=[None, tdata])
def temp_add_server(request):
"""For when we want to mutate the server, and also add datasets to it."""
data = request.param
s = Server(copy(data), formats=all_formats, allow_add=True)
s.app.testing = True
with s.app.test_client() as c:
yield c


@pytest.yield_fixture
def test(server):
with server.app.test_client() as c:
yield c


@pytest.yield_fixture
def test_add(add_server):
with add_server.app.test_client() as c:
yield c


@pytest.yield_fixture
def iris_server():
iris = CSV(example('iris.csv'))
s = Server(iris, all_formats)
s = Server(iris, all_formats, allow_add=True)
s.app.testing = True
with s.app.test_client() as c:
yield c
@@ -413,8 +435,8 @@ def test_cors_datashape(test):
assert 'POST' not in res.headers['Allow']


def test_cors_add(test):
res = test.options('/add')
def test_cors_add(test_add):
res = test_add.options('/add')
assert res.status_code == RC.OK
assert 'HEAD' in res.headers['Allow']
assert 'POST' in res.headers['Allow']
@@ -514,18 +536,29 @@ def test_isin(test, serial):


@pytest.mark.parametrize('serial', all_formats)
def test_add_data_to_server(temp_server, serial):
# add data
def test_add_default_not_allowed(temp_server, serial):
iris_path = example('iris.csv')
blob = serial.dumps({'iris': iris_path})
response1 = temp_server.post('/add',
headers=mimetype(serial),
data=blob)
assert 'NOT FOUND' in response1.status
assert response1.status_code == RC.NOT_FOUND


@pytest.mark.parametrize('serial', all_formats)
def test_add_data_to_server(temp_add_server, serial):
# add data
iris_path = example('iris.csv')
blob = serial.dumps({'iris': iris_path})
response1 = temp_add_server.post('/add',
headers=mimetype(serial),
data=blob)
assert 'CREATED' in response1.status
assert response1.status_code == RC.CREATED

# check for expected server datashape
response2 = temp_server.get('/datashape')
response2 = temp_add_server.get('/datashape')
expected2 = discover({'iris': data(iris_path)})
response_dshape = datashape.dshape(response2.data.decode('utf-8'))
assert_dshape_equal(response_dshape.measure.dict['iris'],
@@ -535,9 +568,9 @@ def test_add_data_to_server(temp_server, serial):
t = data({'iris': data(iris_path)})
expr = t.iris.petal_length.sum()

response3 = temp_server.post('/compute',
data=serial.dumps({'expr': to_tree(expr)}),
headers=mimetype(serial))
response3 = temp_add_server.post('/compute',
data=serial.dumps({'expr': to_tree(expr)}),
headers=mimetype(serial))

result3 = serial.data_loads(serial.loads(response3.data)['data'])
expected3 = compute(expr, {'iris': data(iris_path)})
@@ -556,104 +589,105 @@ def test_cant_add_data_to_server(iris_server, serial):


@pytest.mark.parametrize('serial', all_formats)
def test_add_data_twice_error(temp_server, serial):
def test_add_data_twice_error(temp_add_server, serial):
# add iris
iris_path = example('iris.csv')
payload = serial.dumps({'iris': iris_path})
temp_server.post('/add',
headers=mimetype(serial),
data=payload)
temp_add_server.post('/add',
headers=mimetype(serial),
data=payload)

# Try to add to existing 'iris'
resp = temp_server.post('/add',
headers=mimetype(serial),
data=payload)
resp = temp_add_server.post('/add',
headers=mimetype(serial),
data=payload)
assert resp.status_code == RC.CONFLICT

# Verify the server still serves the original 'iris'.
ds = datashape.dshape(temp_server.get('/datashape').data.decode('utf-8'))
response_ds = temp_add_server.get('/datashape').data.decode('utf-8')
ds = datashape.dshape(response_ds)
t = symbol('t', ds)
query = {'expr': to_tree(t.iris)}
resp = temp_server.post('/compute',
data=serial.dumps(query),
headers=mimetype(serial))
resp = temp_add_server.post('/compute',
data=serial.dumps(query),
headers=mimetype(serial))
assert resp.status_code == RC.OK


@pytest.mark.parametrize('serial', all_formats)
def test_add_two_data_sets_at_once_error(temp_server, serial):
def test_add_two_data_sets_at_once_error(temp_add_server, serial):
# Try to add two things at once
payload = serial.dumps({'foo': 'iris.csv',
'bar': 'iris.csv'})
resp = temp_server.post('/add',
headers=mimetype(serial),
data=payload)
resp = temp_add_server.post('/add',
headers=mimetype(serial),
data=payload)
assert resp.status_code == RC.UNPROCESSABLE_ENTITY


@pytest.mark.parametrize('serial', all_formats)
def test_add_bunk_data_error(temp_server, serial):
def test_add_bunk_data_error(temp_add_server, serial):
# Try to add bunk data
payload = serial.dumps({'foo': None})
resp = temp_server.post('/add',
headers=mimetype(serial),
data=payload)
resp = temp_add_server.post('/add',
headers=mimetype(serial),
data=payload)
assert resp.status_code == RC.UNPROCESSABLE_ENTITY


@pytest.mark.parametrize('serial', all_formats)
def test_bad_add_payload(temp_server, serial):
def test_bad_add_payload(temp_add_server, serial):
# try adding more data to server
blob = serial.dumps('This is not a mutable mapping.')
response1 = temp_server.post('/add',
headers=mimetype(serial),
data=blob)
response1 = temp_add_server.post('/add',
headers=mimetype(serial),
data=blob)
assert response1.status_code == RC.UNPROCESSABLE_ENTITY


@pytest.mark.parametrize('serial', all_formats)
def test_add_expanded_payload(temp_server, serial):
def test_add_expanded_payload(temp_add_server, serial):
# Ensure that the expanded payload format is accepted by the server
iris_path = example('iris.csv')
blob = serial.dumps({'iris': {'source': iris_path,
'kwargs': {'delimiter': ','}}})
response1 = temp_server.post('/add',
headers=mimetype(serial),
data=blob)
response1 = temp_add_server.post('/add',
headers=mimetype(serial),
data=blob)
assert 'CREATED' in response1.status
assert response1.status_code == RC.CREATED


@pytest.mark.parametrize('serial', all_formats)
def test_add_expanded_payload_with_imports(temp_server, serial):
def test_add_expanded_payload_with_imports(temp_add_server, serial):
# Ensure that the expanded payload format is accepted by the server
iris_path = example('iris.csv')
blob = serial.dumps({'iris': {'source': iris_path,
'kwargs': {'delimiter': ','},
'imports': ['csv']}})
response1 = temp_server.post('/add',
headers=mimetype(serial),
data=blob)
response1 = temp_add_server.post('/add',
headers=mimetype(serial),
data=blob)
assert 'CREATED' in response1.status
assert response1.status_code == RC.CREATED


@pytest.mark.parametrize('serial', all_formats)
def test_add_expanded_payload_has_effect(temp_server, serial):
def test_add_expanded_payload_has_effect(temp_add_server, serial):
# Ensure that the expanded payload format actually passes the arguments
# through to the resource constructor
iris_path = example('iris-latin1.tsv')
csv_kwargs = {'delimiter': '\t', 'encoding': 'iso-8859-1'}
blob = serial.dumps({'iris': {'source': iris_path,
'kwargs': csv_kwargs}})
response1 = temp_server.post('/add',
headers=mimetype(serial),
data=blob)
response1 = temp_add_server.post('/add',
headers=mimetype(serial),
data=blob)
assert 'CREATED' in response1.status
assert response1.status_code == RC.CREATED

# check for expected server datashape
response2 = temp_server.get('/datashape')
response2 = temp_add_server.get('/datashape')
expected2 = discover({'iris': data(iris_path, **csv_kwargs)})
response_dshape = datashape.dshape(response2.data.decode('utf-8'))
assert_dshape_equal(response_dshape.measure.dict['iris'],
@@ -663,9 +697,9 @@ def test_add_expanded_payload_has_effect(temp_server, serial):
t = data({'iris': data(iris_path, **csv_kwargs)})
expr = t.iris.petal_length.sum()

response3 = temp_server.post('/compute',
data=serial.dumps({'expr': to_tree(expr)}),
headers=mimetype(serial))
response3 = temp_add_server.post('/compute',
data=serial.dumps({'expr': to_tree(expr)}),
headers=mimetype(serial))

result3 = serial.data_loads(serial.loads(response3.data)['data'])
expected3 = compute(expr, {'iris': data(iris_path, **csv_kwargs)})
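For completeness, a hedged sketch of probing the gated endpoint from a Flask
test client, mirroring `test_cors_add` above; the in-memory DataFrame is an
assumption, not part of this commit:

    import pandas as pd
    from blaze.server import Server

    df = pd.DataFrame({'petal_length': [1.4, 4.7, 5.1]})

    s = Server({'iris': df}, allow_add=True)
    s.app.testing = True
    with s.app.test_client() as c:
        res = c.options('/add')
        assert res.status_code == 200          # 404 if allow_add is False
        assert 'POST' in res.headers['Allow']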
