Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix error reponse of ms defender connector #747

Merged
merged 4 commits into from
Dec 14, 2021
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 20 additions & 18 deletions stix_shifter_modules/msatp/stix_transmission/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,13 @@ def __init__(self, connection, configuration):
"""Initialization.
:param connection: dict, connection dict
:param configuration: dict,config dict"""

try:
self.token = Connector.generate_token(connection, configuration)
configuration['auth']['access_token'] = self.token
self.api_client = APIClient(connection, configuration)

except Exception as ex:
self.init_error = ex
self.adal_response = Connector.generate_token(connection, configuration)
if self.adal_response['success']:
configuration['auth']['access_token'] = self.adal_response['access_token']
self.api_client = APIClient(connection, configuration)
else:
self.init_error = True

@staticmethod
def _handle_errors(response, return_obj):
Expand Down Expand Up @@ -51,7 +50,7 @@ def ping_connection(self):
"""Ping the endpoint."""
return_obj = dict()
if self.init_error:
raise self.init_error
return self.adal_response
response = self.api_client.ping_box()
response_code = response.code
if 200 <= response_code < 300:
Expand All @@ -76,7 +75,7 @@ def create_results_connection(self, query, offset, length):

try:
if self.init_error:
raise self.init_error
return self.adal_response
response = self.api_client.run_search(query, offset, length)
return_obj = self._handle_errors(response, return_obj)
response_json = json.loads(return_obj["data"])
Expand Down Expand Up @@ -120,6 +119,7 @@ def generate_token(connection, configuration):
"""To generate the Token
:param connection: dict, connection dict
:param configuration: dict,config dict"""
return_obj = dict()

authority_url = ('https://login.windows.net/' +
configuration['auth']['tenant'])
Expand All @@ -129,17 +129,19 @@ def generate_token(connection, configuration):
context = adal.AuthenticationContext(
authority_url, validate_authority=configuration['auth']['tenant'] != 'adfs',
)
token = context.acquire_token_with_client_credentials(
response_dict = context.acquire_token_with_client_credentials(
resource,
configuration['auth']['clientId'],
configuration['auth']['clientSecret'])

token_value = token['accessToken']
return token_value

return_obj['success'] = True
return_obj['access_token'] = response_dict['accessToken']
except Exception as ex:
return_obj = dict()
if ex.error_response:
ErrorResponder.fill_error(return_obj, ex.error_response, ['reason'])
Connector.logger.error("Token generation Failed: " + return_obj)
raise ex
if ex.__class__.__name__ == 'AdalError':
response_dict = ex.error_response
ErrorResponder.fill_error(return_obj, response_dict, ['error_description'])
else:
ErrorResponder.fill_error(return_obj, message=str(ex))
Connector.logger.error("Token generation Failed: " + str(ex.error_response))

return return_obj
33 changes: 19 additions & 14 deletions stix_shifter_modules/msatp/tests/stix_transmission/test_msatp.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,15 @@ def __init__(self, response_code, obj):
def read(self):
return bytearray(self.object, 'utf-8')

class AdalMockResponse:

@patch('stix_shifter_modules.msatp.stix_transmission.connector.Connector.generate_token')
@staticmethod
def acquire_token_with_client_credentials(resource, client_id, client_secret):
context_response = dict()
context_response['accessToken'] = 'abc12345'
return context_response

@patch('stix_shifter_modules.msatp.stix_transmission.connector.adal.AuthenticationContext')
@patch('stix_shifter_modules.msatp.stix_transmission.api_client.APIClient.__init__')
class TestMSATPConnection(unittest.TestCase):
def config(self):
Expand All @@ -35,7 +42,7 @@ def connection(self):

def test_is_async(self, mock_api_client, mock_generate_token):
mock_api_client.return_value = None
mock_generate_token.return_value = 'test'
mock_generate_token.return_value = AdalMockResponse
entry_point = EntryPoint(self.connection(), self.config())
check_async = entry_point.is_async()

Expand All @@ -45,25 +52,24 @@ def test_is_async(self, mock_api_client, mock_generate_token):
def test_ping_endpoint(self, mock_ping_response, mock_api_client, mock_generate_token):

mock_api_client.return_value = None
mock_generate_token.return_value = None
mock_generate_token.return_value = AdalMockResponse
mocked_return_value = '["mock", "placeholder"]'

mock_ping_response.return_value = MSATPMockResponse(200, mocked_return_value)
print(str(self.connection))
print(str(self.config))
transmission = stix_transmission.StixTransmission('msatp', self.connection(), self.config())
ping_response = transmission.ping()

assert ping_response is not None
assert ping_response['success']

'''
@patch('stix_shifter_modules.msatp.stix_transmission.api_client.APIClient.ping_box')
def test_ping_endpoint_exception(self, mock_ping_response, mock_api_client, mock_generate_token):
mock_api_client.return_value = None
mock_generate_token.return_value = None
mock_generate_token.return_value = AdalMockResponse
mocked_return_value = '["mock", "placeholder"]'
mock_ping_response.return_value = MSATPMockResponse(200, mocked_return_value)
mock_ping_response.return_value = MSATPMockResponse(400, mocked_return_value)
mock_ping_response.side_effect = Exception('exception')

transmission = stix_transmission.StixTransmission('msatp', self.connection(), self.config())
Expand All @@ -72,12 +78,11 @@ def test_ping_endpoint_exception(self, mock_ping_response, mock_api_client, mock
assert ping_response is not None
assert ping_response['success'] is False
assert ping_response['code'] == ErrorCode.TRANSMISSION_UNKNOWN.value
'''

def test_query_connection(self, mock_api_client, mock_generate_token):

mock_api_client.return_value = None
mock_generate_token.return_value = None
mock_generate_token.return_value = AdalMockResponse

query = "(find withsource = TableName in (DeviceNetworkEvents) where Timestamp >= datetime(" \
"2019-09-24T16:32:32.993821Z) and Timestamp < datetime(2019-09-24T16:37:32.993821Z) | order by " \
Expand All @@ -96,7 +101,7 @@ def test_results_file_response(self, mock_results_response, mock_api_client, moc


mock_api_client.return_value = None
mock_generate_token.return_value = None
mock_generate_token.return_value = AdalMockResponse
mocked_return_value = """{
"Results": [{
"TableName": "DeviceFileEvents",
Expand Down Expand Up @@ -130,7 +135,7 @@ def test_results_registry_response(self, mock_results_response, mock_api_client,


mock_api_client.return_value = None
mock_generate_token.return_value = None
mock_generate_token.return_value = AdalMockResponse
mocked_return_value = """{"Results": [{"TableName": "DeviceRegistryEvents","Timestamp": "2019-10-10T10:43:07.2363291Z","DeviceId":
"db40e68dd7358aa450081343587941ce96ca4777","DeviceName": "testmachine1","ActionType": "RegistryValueSet",
"RegistryKey": "HKEY_LOCAL_MACHINE\\\\SYSTEM\\\\ControlSet001\\\\Services\\\\WindowsAzureGuestAgent",
Expand Down Expand Up @@ -181,7 +186,7 @@ def test_results_response_exception(self, mock_results_response, mock_api_client
def test_query_flow(self, mock_results_response, mock_api_client, mock_generate_token):

mock_api_client.return_value = None
mock_generate_token.return_value = None
mock_generate_token.return_value = AdalMockResponse
results_mock = """{
"Results": [{
"TableName": "DeviceFileEvents",
Expand Down Expand Up @@ -225,7 +230,7 @@ def test_query_flow(self, mock_results_response, mock_api_client, mock_generate_

def test_delete_query(self, mock_api_client, mock_generate_token):
mock_api_client.return_value = None
mock_generate_token.return_value = None
mock_generate_token.return_value = AdalMockResponse

search_id = '(find withsource = TableName in (DeviceFileEvents) where Timestamp >= datetime(' \
'2019-09-01T08:43:10.003Z) and Timestamp < datetime(2019-10-01T10:43:10.003Z) | order by ' \
Expand All @@ -242,7 +247,7 @@ def test_status_query(self, mock_api_client, mock_generate_token):


mock_api_client.return_value = None
mock_generate_token.return_value = None
mock_generate_token.return_value = AdalMockResponse

search_id = '(find withsource = TableName in (DeviceFileEvents) where Timestamp >= datetime(' \
'2019-09-01T08:43:10.003Z) and Timestamp < datetime(2019-10-01T10:43:10.003Z) | order by ' \
Expand Down