Skip to content

Commit

Permalink
Fix bug in references field of vuln templates API
Browse files Browse the repository at this point in the history
It didn't correctly handle values other than a comma-separated string
nor empty strings
A combination of bugs in the tests, the load_references method and
marshmallow 2.13.6 made the test altough the implementation was
incorrect.

See marshmallow-code/marshmallow#395 for more
info on the marshmallow bug
  • Loading branch information
cript0nauta committed Oct 27, 2017
1 parent fd1adf2 commit 23d1a65
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 7 deletions.
12 changes: 10 additions & 2 deletions server/api/modules/vulnerability_template.py
Expand Up @@ -8,7 +8,7 @@
FilterSet,
operators,
)
from marshmallow import fields
from marshmallow import fields, ValidationError
from marshmallow.validate import (
OneOf,
)
Expand Down Expand Up @@ -44,7 +44,15 @@ def get_references(self, obj):
return ', '.join(map(lambda ref_tmpl: ref_tmpl.name, obj.reference_template_instances))

def load_references(self, value):
return value.split(',')
if not isinstance(value, (unicode, str)):
raise ValidationError('references must be a string')
if len(value) == 0:
# Required because "".split(",") == [""]
return []
references = value.split(',')
if any(len(ref) == 0 for ref in references):
raise ValidationError('Empty name detected in reference')
return references


class VulnerabilityTemplateFilterSet(FilterSet):
Expand Down
31 changes: 26 additions & 5 deletions test_cases/test_api_vulnerability_template.py
Expand Up @@ -60,16 +60,18 @@ def _create_post_data_vulnerability_template(self, references):

def test_create_new_vulnerability_template(self, session, test_client):
vuln_count_previous = session.query(VulnerabilityTemplate).count()
raw_data = self._create_post_data_vulnerability_template(references=[])
raw_data = self._create_post_data_vulnerability_template(references='')
res = test_client.post('/v2/vulnerability_template/', data=raw_data)
assert res.status_code == 201
assert res.json['_id'] == 6
assert vuln_count_previous + 1 == session.query(VulnerabilityTemplate).count()
vuln_template = VulnerabilityTemplate.query.get(6)
assert vuln_template.references == set()

def test_update_vulnerability_template(self, session, test_client):
template = self.factory.create()
session.commit()
raw_data = self._create_post_data_vulnerability_template(references=[])
raw_data = self._create_post_data_vulnerability_template(references='')
res = test_client.put('/v2/vulnerability_template/{0}/'.format(template.id), data=raw_data)
assert res.status_code == 200
updated_template = session.query(VulnerabilityTemplate).filter_by(id=template.id).first()
Expand All @@ -79,13 +81,32 @@ def test_update_vulnerability_template(self, session, test_client):
assert updated_template.description == raw_data['description']
assert updated_template.references == set([])

@pytest.mark.parametrize('references', [
',',
',,',
'a,',
['a', 'b', 'c'],
['a', 'b', ''],
[],
{"a": 1},
{}
])
def test_400_on_invalid_reference(self, session, test_client, references):
template = self.factory.create()
session.commit()
raw_data = self._create_post_data_vulnerability_template(
references=references)
res = test_client.put('/v2/vulnerability_template/{0}/'.format(
template.id), data=raw_data)
assert res.status_code == 400

def test_update_vulnerabiliy_template_change_refs(self, session, test_client):
template = self.factory.create()
for ref_name in set(['old1', 'old2']):
ref = ReferenceTemplateFactory.create(name=ref_name)
self.first_object.reference_template_instances.add(ref)
session.commit()
raw_data = self._create_post_data_vulnerability_template(references=['new_ref', 'another_ref'])
raw_data = self._create_post_data_vulnerability_template(references='new_ref,another_ref')
res = test_client.put('/v2/vulnerability_template/{0}/'.format(template.id), data=raw_data)
assert res.status_code == 200
updated_template = session.query(VulnerabilityTemplate).filter_by(id=template.id).first()
Expand All @@ -97,7 +118,7 @@ def test_update_vulnerabiliy_template_change_refs(self, session, test_client):

def test_create_new_vulnerability_template_with_references(self, session, test_client):
vuln_count_previous = session.query(VulnerabilityTemplate).count()
raw_data = self._create_post_data_vulnerability_template(references=['ref1', 'ref2'])
raw_data = self._create_post_data_vulnerability_template(references='ref1,ref2')
res = test_client.post('/v2/vulnerability_template/', data=raw_data)
assert res.status_code == 201
assert res.json['_id'] == 6
Expand All @@ -111,4 +132,4 @@ def test_delete_vuln_template(self, session, test_client):
res = test_client.delete('/v2/vulnerability_template/{0}/'.format(template.id))

assert res.status_code == 204
assert vuln_count_previous - 1 == session.query(VulnerabilityTemplate).count()
assert vuln_count_previous - 1 == session.query(VulnerabilityTemplate).count()

0 comments on commit 23d1a65

Please sign in to comment.