Skip to content

Commit

Permalink
Bugfix/is usage audit (#108)
Browse files Browse the repository at this point in the history
* fix alert_client list checks

* improve checking for list-like arguments

* add vars to the Py42Response getitem/setitem error messages

* update tests

* changelog
  • Loading branch information
timabrmsn committed May 13, 2020
1 parent 3e01765 commit e8b5f30
Show file tree
Hide file tree
Showing 8 changed files with 55 additions and 23 deletions.
11 changes: 11 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,17 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
The intended audience of this file is for py42 consumers -- as such, changes that don't affect
how a consumer would use the library (e.g. adding unit tests, updating documentation, etc) are not captured here.

## Unreleased

### Changed

- The following methods that required either a single str or list of string argument can now also accept a tuple of strings:
- `py42._internal.clients.alerts.AlertClient.get_details`
- `py42._internal.clients.alerts.AlertClient.resolve`
- `py42._internal.clients.alerts.AlertClient.reopen`
- `py42._internal.clients.detection_list_user.DetectionListUserClient.add_risk_tags`
- `py42._internal.clients.detection_list_user.DetectionListUserClient.remove_risk_tags`

## 1.1.3 - 2020-05-12

### Changed
Expand Down
6 changes: 3 additions & 3 deletions src/py42/_internal/clients/alerts.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def search(self, query):
return self._session.post(uri, data=query)

def get_details(self, alert_ids, tenant_id=None):
if type(alert_ids) is not list:
if not isinstance(alert_ids, (list, tuple)):
alert_ids = [alert_ids]
tenant_id = tenant_id if tenant_id else self._user_context.get_current_tenant_id()
uri = self._uri_prefix.format(u"query-details")
Expand All @@ -27,7 +27,7 @@ def get_details(self, alert_ids, tenant_id=None):
return _convert_observation_json_strings_to_objects(results)

def resolve(self, alert_ids, tenant_id=None, reason=None):
if type(alert_ids) is not list:
if not isinstance(alert_ids, (list, tuple)):
alert_ids = [alert_ids]
tenant_id = tenant_id if tenant_id else self._user_context.get_current_tenant_id()
reason = reason if reason else u""
Expand All @@ -36,7 +36,7 @@ def resolve(self, alert_ids, tenant_id=None, reason=None):
return self._session.post(uri, data=json.dumps(data))

def reopen(self, alert_ids, tenant_id=None, reason=None):
if type(alert_ids) is not list:
if not isinstance(alert_ids, (list, tuple)):
alert_ids = [alert_ids]
tenant_id = tenant_id if tenant_id else self._user_context.get_current_tenant_id()
uri = self._uri_prefix.format(u"reopen-alert")
Expand Down
8 changes: 4 additions & 4 deletions src/py42/_internal/clients/detection_list_user.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,13 +112,13 @@ def add_risk_tags(self, user_id, tags):
Args:
user_id (str or int): The user_id whose tag(s) needs to be updated.
tags (str or list of str ): A single tag or multiple tags in a list to be added.
e.g u"tag1" or ["tag1", "tag2"], for python version 2.X, pass u"str" instead of "str"
e.g "tag1" or ["tag1", "tag2"]
Returns:
:class:`py42.response.Py42Response`
"""

if type(tags) is str:
if not isinstance(tags, (list, tuple)):
tags = [tags]

data = {
Expand All @@ -135,12 +135,12 @@ def remove_risk_tags(self, user_id, tags):
Args:
user_id (str or int): The user_id whose tag(s) needs to be removed.
tags (str or list of str ): A single tag or multiple tags in a list to be removed.
e.g u"tag1" or ["tag1", "tag2"], for python version 2.X, pass u"str" instead of "str"
e.g "tag1" or ["tag1", "tag2"].
Returns:
:class:`py42.response.Py42Response`
"""
if type(tags) is str:
if not isinstance(tags, (list, tuple)):
tags = [tags]

data = {
Expand Down
2 changes: 1 addition & 1 deletion src/py42/modules/securitydata.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ def _try_get_security_detection_event_client(self, plan_storage_info):
def _get_security_detection_events(
self, plan_storage_infos, cursor, include_files, event_types, min_timestamp, max_timestamp
):
if type(plan_storage_infos) is not list:
if not isinstance(plan_storage_infos, (list, tuple)):
plan_storage_infos = [plan_storage_infos]

# get the storage node client for each plan
Expand Down
8 changes: 6 additions & 2 deletions src/py42/response.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,19 @@ def __getitem__(self, key):
return self._data_root[key]
except TypeError as e:
data_root_type = type(self._data_root)
message = u"The Py42Response root is of type {}, but __getitem__ got a key of {}, which is incompatible."
message = u"The Py42Response root is of type {0}, but __getitem__ got a key of {1}, which is incompatible.".format(
data_root_type, key
)
raise Py42Error(message)

def __setitem__(self, key, value):
try:
self._data_root[key] = value
except TypeError as e:
data_root_type = type(self._data_root)
message = u"The Py42Response root is of type {}, but __setitem__ got a key of {} and value of {}, which is incompatible."
message = u"The Py42Response root is of type {0}, but __setitem__ got a key of {1} and value of {2}, which is incompatible.".format(
data_root_type, key, value
)
raise Py42Error(message)

def __iter__(self):
Expand Down
9 changes: 6 additions & 3 deletions tests/_internal/clients/test_detection_list_user.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,19 +96,22 @@ def test_update_notes_posts_expected_data(self, mock_session, user_context, mock
and posted_data["notes"] == "Test"
)

def test_add_risk_tag_posts_expected_data(self, mock_session, user_context, mock_user_client):
@pytest.mark.parametrize("tags", ["test_tag", ("test_tag",), ["test_tag"]])
def test_add_risk_tag_posts_expected_data(
self, mock_session, user_context, mock_user_client, tags
):
detection_list_user_client = DetectionListUserClient(
mock_session, user_context, mock_user_client
)
detection_list_user_client.add_risk_tags("942897397520289999", u"Test")
detection_list_user_client.add_risk_tags("942897397520289999", tags)

posted_data = json.loads(mock_session.post.call_args[1]["data"])
assert mock_session.post.call_count == 1
assert mock_session.post.call_args[0][0] == "/svc/api/v2/user/addriskfactors"
assert (
posted_data["tenantId"] == user_context.get_current_tenant_id()
and posted_data["userId"] == "942897397520289999"
and posted_data["riskFactors"] == ["Test"]
and posted_data["riskFactors"] == ["test_tag"]
)

def test_remove_risk_tag_posts_expected_data(
Expand Down
15 changes: 9 additions & 6 deletions tests/clients/test_alerts.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,13 +112,14 @@ def test_get_details_when_not_given_tenant_id_posts_expected_data(
and post_data["alertIds"][1] == "ALERT_ID_2"
)

@pytest.mark.parametrize("alert_id", ["ALERT_ID_1", ("ALERT_ID_1",), ["ALERT_ID_1"]])
def test_get_details_when_given_single_alert_id_posts_expected_data(
self, mock_session, user_context, successful_post, py42_response
self, mock_session, user_context, successful_post, py42_response, alert_id
):
py42_response.text = TEST_PARSEABLE_ALERT_DETAIL_RESPONSE
mock_session.post.return_value = py42_response
alert_client = AlertClient(mock_session, user_context)
alert_client.get_details("ALERT_ID_1")
alert_client.get_details(alert_id)
post_data = json.loads(mock_session.post.call_args[1]["data"])
assert (
post_data["tenantId"] == TENANT_ID_FROM_RESPONSE
Expand Down Expand Up @@ -189,11 +190,12 @@ def test_resolve_when_not_given_tenant_id_posts_expected_data(
and post_data["alertIds"][1] == "ALERT_ID_2"
)

@pytest.mark.parametrize("alert_id", ["ALERT_ID_1", ("ALERT_ID_1",), ["ALERT_ID_1"]])
def test_resolve_when_given_single_alert_id_posts_expected_data(
self, mock_session, user_context, successful_post
self, mock_session, user_context, successful_post, alert_id
):
alert_client = AlertClient(mock_session, user_context)
alert_client.resolve("ALERT_ID_1")
alert_client.resolve(alert_id)
post_data = json.loads(mock_session.post.call_args[1]["data"])
assert (
post_data["tenantId"] == TENANT_ID_FROM_RESPONSE
Expand Down Expand Up @@ -232,11 +234,12 @@ def test_reopen_when_not_given_tenant_id_posts_expected_data(
and post_data["alertIds"][1] == "ALERT_ID_2"
)

@pytest.mark.parametrize("alert_id", ["ALERT_ID_1", ("ALERT_ID_1",), ["ALERT_ID_1"]])
def test_reopen_when_given_single_alert_id_posts_expected_data(
self, mock_session, user_context, successful_post
self, mock_session, user_context, successful_post, alert_id
):
alert_client = AlertClient(mock_session, user_context)
alert_client.reopen("ALERT_ID_1")
alert_client.reopen(alert_id)
post_data = json.loads(mock_session.post.call_args[1]["data"])
assert (
post_data["tenantId"] == TENANT_ID_FROM_RESPONSE
Expand Down
19 changes: 15 additions & 4 deletions tests/modules/test_securitydata.py
Original file line number Diff line number Diff line change
Expand Up @@ -447,8 +447,21 @@ def test_get_all_user_security_events_when_multiple_plans_with_cursors_returned_
pass
assert mock_storage_security_client.get_plan_security_events.call_count == 4

@pytest.mark.parametrize(
"plan_storage_info",
[
PlanStorageInfo("111111111111111111", "41", "4"),
(PlanStorageInfo("111111111111111111", "41", "4"),),
[PlanStorageInfo("111111111111111111", "41", "4")],
],
)
def test_get_all_plan_security_events_calls_security_client_with_expected_params(
self, mocker, security_client, storage_client_factory, microservice_client_factory
self,
mocker,
security_client,
storage_client_factory,
microservice_client_factory,
plan_storage_info,
):
mock_storage_client = mocker.MagicMock(spec=StorageClient)
mock_storage_security_client = mocker.MagicMock(spec=StorageSecurityClient)
Expand All @@ -460,9 +473,7 @@ def test_get_all_plan_security_events_calls_security_client_with_expected_params
security_module = SecurityModule(
security_client, storage_client_factory, microservice_client_factory
)
for _, _ in security_module.get_all_plan_security_events(
PlanStorageInfo("111111111111111111", "41", "4")
):
for _, _ in security_module.get_all_plan_security_events(plan_storage_info):
pass
mock_storage_security_client.get_plan_security_events.assert_called_once_with(
"111111111111111111",
Expand Down

0 comments on commit e8b5f30

Please sign in to comment.