Skip to content

Commit

Permalink
minor refactor to reduce code copy
Browse files Browse the repository at this point in the history
  • Loading branch information
heiderje-vmware committed Dec 8, 2023
1 parent b408043 commit 56a1dc7
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 22 deletions.
23 changes: 6 additions & 17 deletions src/cbc_sdk/platform/alerts.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
"""Alert Models"""

MAX_RESULTS_LIMIT = 10000
REQUEST_IGNORED_KEYS = ["_doc_class", "_cb", "_count_valid", "_total_results", "_query_builder", "_sortcriteria"]


class Alert(PlatformModel):
Expand Down Expand Up @@ -990,14 +991,10 @@ def get_alert_search_query(self):
Note:
Does not preserve sort criterion
"""
ignored_keys = ["_doc_class", "_cb", "_count_valid", "_total_results", "_sortcriteria", "_query_builder"]
alert_search_query = self._cb.select(Alert)
for key, value in vars(alert_search_query).items():
if hasattr(self._request, key) and key not in ignored_keys:
if hasattr(self._request, key) and key not in REQUEST_IGNORED_KEYS:
setattr(alert_search_query, key, self._request.__getattribute__(key))
key = "_time_range"
if hasattr(self._request, key):
setattr(alert_search_query, key, self._request.__getattribute__(key))

alert_search_query.add_criteria(self._request._group_by.lower(), self.most_recent_alert["threat_id"])
return alert_search_query
Expand Down Expand Up @@ -1038,6 +1035,7 @@ def __init__(self, doc_class, cb):
self._query_builder = QueryBuilder()
self._criteria = {}
self._time_filters = {}
self._time_range = {}
self._exclusions = {}
self._time_exclusion_filters = {}
self._sortcriteria = {}
Expand Down Expand Up @@ -1120,7 +1118,6 @@ def set_time_range(self, *args, **kwargs):
else:
# everything before this is only for backwards compatibility, once v6 deprecates all the other
# checks can be removed
self._time_range = {}
self._time_range = time_filter
return self

Expand Down Expand Up @@ -1304,7 +1301,7 @@ def _build_request(self, from_row, max_rows, add_sort=True):
request["query"] = query

request["rows"] = self._batch_size
if hasattr(self, "_time_range"):
if self._time_range != {}:
request["time_range"] = self._time_range
if from_row > 1:
request["start"] = from_row
Expand Down Expand Up @@ -1601,14 +1598,10 @@ def set_group_by(self, field):
Note:
Does not preserve sort criterion
"""
ignored_keys = ["_doc_class", "_cb", "_count_valid", "_total_results", "_query_builder", "_sortcriteria"]
grouped_alert_search_query = self._cb.select(GroupedAlert)
for key, value in vars(grouped_alert_search_query).items():
if hasattr(self, key) and key not in ignored_keys:
if hasattr(self, key) and key not in REQUEST_IGNORED_KEYS:
setattr(grouped_alert_search_query, key, self.__getattribute__(key))
key = "_time_range"
if hasattr(self, key):
setattr(grouped_alert_search_query, key, self.__getattribute__(key))
grouped_alert_search_query.set_group_by(field)

return grouped_alert_search_query
Expand Down Expand Up @@ -1655,14 +1648,10 @@ def get_alert_search_query(self):
Note: Does not preserve sort criterion
"""
ignored_keys = ["_doc_class", "_cb", "_count_valid", "_total_results", "_query_builder", "_sortcriteria"]
alert_search_query = self._cb.select(Alert)
for key, value in vars(alert_search_query).items():
if hasattr(self, key) and key not in ignored_keys:
if hasattr(self, key) and key not in REQUEST_IGNORED_KEYS:
setattr(alert_search_query, key, self.__getattribute__(key))
key = "_time_range"
if hasattr(self, key):
setattr(alert_search_query, key, self.__getattribute__(key))

return alert_search_query

Expand Down
10 changes: 5 additions & 5 deletions src/tests/unit/platform/test_alertsv7_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,7 @@
DeviceControlAlert,
GroupedAlert,
Process,
Job,
AlertSearchQuery,
GroupedAlertSearchQuery
Job
)
from cbc_sdk.rest_api import CBCloudAPI
from tests.unit.fixtures.CBCSDKMock import CBCSDKMock
Expand Down Expand Up @@ -2121,7 +2119,8 @@ def test_grouped_alert_search_query_to_alert_search_query(cbcsdk_mock):
delattr(alert_search_query, "_query_builder")
delattr(expected_alert_search_query, "_query_builder")

assert isinstance(alert_search_query, AlertSearchQuery)
assert alert_search_query.__module__ == "cbc_sdk.platform.alerts" and type(alert_search_query).__name__ == \
"AlertSearchQuery"
assert vars(alert_search_query) == vars(expected_alert_search_query)


Expand All @@ -2140,5 +2139,6 @@ def test_alert_search_query_to_grouped_alert_search_query(cbcsdk_mock):
delattr(grouped_alert_search_query, "_query_builder")
delattr(expected_grouped_alert_search_query, "_query_builder")

assert isinstance(grouped_alert_search_query, GroupedAlertSearchQuery)
assert grouped_alert_search_query.__module__ == "cbc_sdk.platform.alerts" and type(grouped_alert_search_query).\
__name__ == "GroupedAlertSearchQuery"
assert vars(grouped_alert_search_query) == vars(expected_grouped_alert_search_query)

0 comments on commit 56a1dc7

Please sign in to comment.