Skip to content

Commit

Permalink
bugfix: add cloud account type to run_assesment (#12)
Browse files Browse the repository at this point in the history
Co-authored-by: Sam Fourgeaud <samuel.fourgeaud@onixnet.com>
  • Loading branch information
samforger and sam-fourgeaud-onix committed May 10, 2023
1 parent 2b9909c commit 96633c6
Showing 1 changed file with 24 additions and 12 deletions.
36 changes: 24 additions & 12 deletions dome9/dome9.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@ class Dome9(object):
def __init__(self, key=None, secret=None, endpoint='https://api.dome9.com', apiVersion='v2'):
self.key = None
self.secret = None
self.headers = {'Content-Type': 'application/json', 'Accept': 'application/json'}
self.headers = {'Content-Type': 'application/json',
'Accept': 'application/json'}
self.endpoint = endpoint + '/{}/'.format(apiVersion)
self._load_credentials(key, secret)

Expand All @@ -48,15 +49,20 @@ def _request(self, method, route, payload=None): # noqa: C901 (lint ignore)

try:
if method == 'get':
res = requests.get(url=url, params=_payload, headers=self.headers, auth=(self.key, self.secret))
res = requests.get(
url=url, params=_payload, headers=self.headers, auth=(self.key, self.secret))
elif method == 'post':
res = requests.post(url=url, data=_payload, headers=self.headers, auth=(self.key, self.secret))
res = requests.post(
url=url, data=_payload, headers=self.headers, auth=(self.key, self.secret))
elif method == 'patch':
res = requests.patch(url=url, json=_payload, headers=self.headers, auth=(self.key, self.secret))
res = requests.patch(
url=url, json=_payload, headers=self.headers, auth=(self.key, self.secret))
elif method == 'put':
res = requests.put(url=url, data=_payload, headers=self.headers, auth=(self.key, self.secret))
res = requests.put(
url=url, data=_payload, headers=self.headers, auth=(self.key, self.secret))
elif method == 'delete':
res = requests.delete(url=url, params=_payload, headers=self.headers, auth=(self.key, self.secret))
res = requests.delete(
url=url, params=_payload, headers=self.headers, auth=(self.key, self.secret))
return bool(res.status_code == 204)

except requests.ConnectionError as ex:
Expand All @@ -68,9 +74,11 @@ def _request(self, method, route, payload=None): # noqa: C901 (lint ignore)
if res.content:
jsonObject = res.json()
except Exception as ex:
err = {'code': res.status_code, 'message': getattr(ex, 'message', ''), 'content': res.content}
err = {'code': res.status_code, 'message': getattr(
ex, 'message', ''), 'content': res.content}
else:
err = {'code': res.status_code, 'message': res.reason, 'content': res.content}
err = {'code': res.status_code,
'message': res.reason, 'content': res.content}

if err:
raise Exception(err)
Expand Down Expand Up @@ -222,15 +230,17 @@ def list_protected_assets(self, textSearch="", filters=[], pageSize=1000):
.. literalinclude:: schemas/ProtectedAsset.json
"""
results = {}
pagination = {"pageSize": pageSize, "filter": {"fields": filters, 'freeTextPhrase': textSearch}}
pagination = {"pageSize": pageSize, "filter": {
"fields": filters, 'freeTextPhrase': textSearch}}
rsp = self._post(route='protected-asset/search', payload=pagination)
results = rsp

self.list_protected_assets()

while rsp['searchAfter']:
pagination['searchAfter'] = rsp['searchAfter']
rsp = self._post(route='protected-asset/search', payload=pagination)
rsp = self._post(route='protected-asset/search',
payload=pagination)
results['assets'].extend(rsp['assets'])

return results
Expand Down Expand Up @@ -430,12 +440,13 @@ def delete_exclusion(self, exclusionId):
# ------------------ Assessments ------------------
# --------------------------------------------------

def run_assessment(self, rulesetId, cloudAccountId, region=None):
def run_assessment(self, rulesetId, cloudAccountId, cloudAccountType, region=None):
"""Run compliance assessments on Cloud Accounts, and get the results
Args:
rulesetId (str): Id of the Compliance Policy Ruleset to run
cloudAccountId (str): Id of the Cloud Account
cloudAccountType (str): Type of the Cloud Account (Google, Aws, Azure, Kubernetes, ...)
region (str, optional): Set a specific region. Defaults to None.
Returns:
Expand All @@ -446,7 +457,8 @@ def run_assessment(self, rulesetId, cloudAccountId, region=None):
"""
bundle = {
'id': rulesetId,
'CloudAccountId': cloudAccountId
'CloudAccountId': cloudAccountId,
'cloudAccountType': cloudAccountType
}
if region:
bundle['region'] = region
Expand Down

0 comments on commit 96633c6

Please sign in to comment.