From 24bc68ef8e7617671a1f8d2bef4ebb60b726d57e Mon Sep 17 00:00:00 2001 From: annatisch Date: Wed, 25 May 2016 11:52:58 -0700 Subject: [PATCH] [Python] LRO and AAD updates and fixes (#1078) * Fixes for token refresh, and external token support * auth test updates * LRO refactor and PATCH fix * if cleanup * Separated refresh logic --- .../AcceptanceTests/lro_tests.py | 10 +- .../msrestazure/azure_active_directory.py | 83 +++++- .../msrestazure/azure_operation.py | 275 ++++++----------- .../Python/msrestazure/test/unittest_auth.py | 20 +- .../msrestazure/test/unittest_operation.py | 278 +++++++++++++++++- 5 files changed, 444 insertions(+), 222 deletions(-) diff --git a/AutoRest/Generators/Python/Azure.Python.Tests/AcceptanceTests/lro_tests.py b/AutoRest/Generators/Python/Azure.Python.Tests/AcceptanceTests/lro_tests.py index 39ef060fbe527..8d1e19912a256 100644 --- a/AutoRest/Generators/Python/Azure.Python.Tests/AcceptanceTests/lro_tests.py +++ b/AutoRest/Generators/Python/Azure.Python.Tests/AcceptanceTests/lro_tests.py @@ -169,13 +169,11 @@ def test_lro_happy_paths(self): process = self.client.lr_os.delete_provisioning202_accepted200_succeeded() self.assertEqual("Succeeded", process.result().provisioning_state) - # TODO: In C# this doesn't raise - self.assertRaisesWithMessage("Long running operation failed with status 'Canceled'", - self.client.lr_os.delete_provisioning202_deletingcanceled200().result) + result = self.client.lr_os.delete_provisioning202_deletingcanceled200().result() + self.assertEqual(result.provisioning_state, 'Canceled') - # TODO: In C# this doesn't raise - self.assertRaisesWithMessage("Long running operation failed with status 'Failed'", - self.client.lr_os.delete_provisioning202_deleting_failed200().result) + result = self.client.lr_os.delete_provisioning202_deleting_failed200().result() + self.assertEqual(result.provisioning_state, 'Failed') self.assertIsNone(self.client.lr_os.post202_no_retry204(product).result()) diff --git a/ClientRuntimes/Python/msrestazure/msrestazure/azure_active_directory.py b/ClientRuntimes/Python/msrestazure/msrestazure/azure_active_directory.py index 80cbbb7fb5e23..cda68088d8495 100644 --- a/ClientRuntimes/Python/msrestazure/msrestazure/azure_active_directory.py +++ b/ClientRuntimes/Python/msrestazure/msrestazure/azure_active_directory.py @@ -25,6 +25,7 @@ # -------------------------------------------------------------------------- import ast +import re import time try: from urlparse import urlparse, parse_qs @@ -92,7 +93,7 @@ def _https(uri, *extra): return _build_url(uri, extra, 'https') -class AADMixin(object): +class AADMixin(OAuthTokenAuthentication): """Mixin for Authentication object. Provides some AAD functionality: - State validation @@ -107,6 +108,7 @@ class AADMixin(object): _resource = 'https://management.core.windows.net/' _china_resource = "https://management.core.chinacloudapi.cn/" _keyring = "AzureAAD" + _case = re.compile('([a-z0-9])([A-Z])') def _configure(self, **kwargs): """Configure authentication endpoint. @@ -153,7 +155,17 @@ def _check_state(self, response): raise ValueError( "State received from server does not match that of request.") + def _convert_token(self, token): + """Convert token fields from camel case. + + :param dict token: An authentication token. + :rtype: dict + """ + return {self._case.sub(r'\1_\2', k).lower(): v + for k, v in token.items()} + def _parse_token(self): + # TODO: We could also check expires_on and use to update expires_in if self.token.get('expires_at'): countdown = float(self.token['expires_at']) - time.time() self.token['expires_in'] = countdown @@ -216,7 +228,70 @@ def clear_cached_token(self): raise_with_traceback(KeyError, "Unable to clear token.") -class UserPassCredentials(OAuthTokenAuthentication, AADMixin): +class AADRefreshMixin(object): + """ + Additional token refresh logic + """ + + def refresh_session(self): + """Return updated session if token has expired, attempts to + refresh using newly acquired token. + + :rtype: requests.Session. + """ + if self.token.get('refresh_token'): + try: + return self.signed_session() + except Expired: + pass + self.set_token() + return self.signed_session() + + +class AADTokenCredentials(AADMixin): + """ + Credentials objects for AAD token retrieved through external process + e.g. Python ADAL lib. + + Optional kwargs may include: + - china (bool): Configure auth for China-based service, + default is 'False'. + - tenant (str): Alternative tenant, default is 'common'. + - auth_uri (str): Alternative authentication endpoint. + - token_uri (str): Alternative token retrieval endpoint. + - resource (str): Alternative authentication resource, default + is 'https://management.core.windows.net/'. + - verify (bool): Verify secure connection, default is 'True'. + - keyring (str): Name of local token cache, default is 'AzureAAD'. + - cached (bool): If true, will not attempt to collect a token, + which can then be populated later from a cached token. + + :param dict token: Authentication token. + :param str client_id: Client ID, if not set, Xplat Client ID + will be used. + """ + + def __init__(self, token, client_id=None, **kwargs): + if not client_id: + # Default to Xplat Client ID. + client_id = '04b07795-8ddb-461a-bbee-02f9e1bf7b46' + super(AADTokenCredentials, self).__init__(client_id, None) + self._configure(**kwargs) + if not kwargs.get('cached'): + self.token = self._convert_token(token) + self.signed_session() + + @classmethod + def retrieve_session(cls, client_id=None): + """Create AADTokenCredentials from a cached token if it has not + yet expired. + """ + session = cls(None, None, client_id=client_id, cached=True) + session._retrieve_stored_token() + return session + + +class UserPassCredentials(AADRefreshMixin, AADMixin): """Credentials object for Headless Authentication, i.e. AAD authentication via username and password. @@ -298,7 +373,7 @@ def set_token(self): self.token = token -class ServicePrincipalCredentials(OAuthTokenAuthentication, AADMixin): +class ServicePrincipalCredentials(AADRefreshMixin, AADMixin): """Credentials object for Service Principle Authentication. Authenticates via a Client ID and Secret. @@ -361,7 +436,7 @@ def set_token(self): self.token = token -class InteractiveCredentials(OAuthTokenAuthentication, AADMixin): +class InteractiveCredentials(AADMixin): """Credentials object for Interactive/Web App Authentication. Requires that an AAD Client be configured with a redirect URL. diff --git a/ClientRuntimes/Python/msrestazure/msrestazure/azure_operation.py b/ClientRuntimes/Python/msrestazure/msrestazure/azure_operation.py index 0386b3a30b5cb..172858b700856 100644 --- a/ClientRuntimes/Python/msrestazure/msrestazure/azure_operation.py +++ b/ClientRuntimes/Python/msrestazure/msrestazure/azure_operation.py @@ -24,6 +24,7 @@ # # -------------------------------------------------------------------------- +import re import threading import time try: @@ -92,11 +93,20 @@ def __eq__(self, other): return self.__dict__ == other.__dict__ -class LongRunningOperationMixin(object): - """LongRunningOperation Mixin +class LongRunningOperation(object): + """LongRunningOperation Provides default logic for interpreting operation responses and status updates. """ + _convert = re.compile('([a-z0-9])([A-Z])') + + def __init__(self, response, outputs): + self.method = response.request.method + self.status = "" + self.resource = None + self.get_outputs = outputs + self.async_url = None + self.location_url = None def _validate(self, url): """Validate header url @@ -112,6 +122,19 @@ def _validate(self, url): raise ValueError("Invalid URL header") return url + def _check_status(self, response): + """Check response status code is valid for a Put or Patch + reqest. Must be 200, 202, or 204. + + :raises: BadStatus if invalid status. + """ + code = response.status_code + if code in [200, 202] or (code == 201 and self.method == 'PUT') or \ + (code == 204 and self.method in ['DELETE', 'POST']): + return + raise BadStatus( + "Invalid return status for {!r} operation".format(self.method)) + def _is_empty(self, response): """Check if response body contains meaningful content. @@ -125,7 +148,8 @@ def _is_empty(self, response): body = response.json() return not body except ValueError: - raise DeserializationError("Response json invalid") + raise DeserializationError( + "Error occurred in deserializing the response body.") def _deserialize(self, response): """Attempt to deserialize resource from response. @@ -136,17 +160,15 @@ def _deserialize(self, response): succeeded. """ self.resource = self.get_outputs(response) - - try: - if failed(self.resource.provisioning_state): - self.status = self.resource.provisioning_state + if self.method == 'PUT': + resource_status = self._get_resource_status() + if failed(resource_status): + self.status = resource_status raise OperationFailed("Operation failed or cancelled") - elif succeeded(self.resource.provisioning_state): + elif succeeded(resource_status): raise OperationFinished("Operation succeeded") - elif self.resource.provisioning_state: - self.status = self.resource.provisioning_state - except AttributeError: - pass + elif resource_status: + self.status = resource_status def _get_body_status(self, response): """Attempt to find status info in response body. @@ -157,7 +179,6 @@ def _get_body_status(self, response): """ if self._is_empty(response): return None - body = response.json() return body.get('status') @@ -183,21 +204,25 @@ def _object_from_response(self, response): :param requests.Response response: latest REST call response. """ body = response.json() - state = body.get('properties', body).get('provisioningState') - - if self.resource is None: + body = {self._convert.sub(r'\1_\2', k).lower(): v + for k, v in body.items()} + properties = body.get('properties') + if properties: + properties = {self._convert.sub(r'\1_\2', k).lower(): v + for k, v in properties.items()} + del body['properties'] + body.update(properties) + self.resource = SimpleResource(**body) + else: self.resource = SimpleResource(**body) - elif state: - if hasattr(self.resource, 'provisioning_state'): - self.resource.provisioning_state = state def _process_status(self, response): """Process response based on specific status code. :param requests.Response response: latest REST call response. """ - method = getattr(self, '_status_' + str(response.status_code)) - method(response) + process = getattr(self, '_status_' + str(response.status_code)) + process(response) def _status_200(self, response): """Process response with status code 200. @@ -206,11 +231,11 @@ def _status_200(self, response): """ status = self._get_body_status(response) self.status = status if status else 'Succeeded' - if not status: - try: - # Even if this fails, status '200' should be successful. - self._deserialize(response) - except CloudError: + try: + # Even if this fails, status '200' should be successful. + self._deserialize(response) + except CloudError: + if self.method in ['PUT', 'PATCH'] and not status: self._object_from_response(response) def _status_201(self, response): @@ -250,6 +275,9 @@ def is_done(self): :rtype: bool """ + if (self.async_url or not self.resource) and \ + self.method in ['PUT', 'PATCH']: + return False resouce_state = self._get_resource_status() try: return self.status.lower() == resouce_state.lower() @@ -263,71 +291,22 @@ def get_initial_status(self, response): :param requests.Response response: initial REST call response. """ self._check_status(response) - if response.status_code == 204: - self._status_204(response) - return - try: - self._deserialize(response) - if self.status: - return - except CloudError as err: - raise BadStatus(str(err)) - + if response.status_code in [200, 202, 204]: + self._process_status(response) status = self._get_body_status(response) if status: self.status = status - if response.status_code in [200, 202]: - self._process_status(response) - - def get_retry(self, response, *args): - """Retrieve the URL that will be polled for status. First looks for - 'azure-asyncoperation' header, if not found or invalid, check for - 'location' header. - - :param requests.Response response: latest REST call response. - """ try: - self.async_url = self._validate( - response.headers.get('azure-asyncoperation')) - - # Return if we have a url, in case location header raises error. - if self.async_url: - return - except ValueError: - pass # We can ignore as location header may still be valid. - self.location_url = self._validate(response.headers.get('location')) - - -class PostDeleteOperation(LongRunningOperationMixin): - """LongRunningOperation object for a POST or DELETE request. - - :param requests.Response response: initial REST call response. - :param callable outputs: Function to deserialize operation resource. - """ - - def __init__(self, response, outputs): - self.method = response.request.method - self.status = "" - self.resource = None - self.get_outputs = outputs - self.async_url = None - self.location_url = None - - def _check_status(self, response): - """Check response status code is valid for a Put or Patch - reqest. Must be 200, 202, or 204. - - :raises: BadStatus if invalid status. - """ - if response.status_code not in [200, 202, 204]: - raise BadStatus( - "Invalid return status for 'POST' or 'DELETE' call") + self._deserialize(response) + except CloudError: + pass def get_status_from_location(self, response): """Process the latest status update retrieved from a 'location' header. :param requests.Response response: latest REST call response. + :raises: BadResponse if response has no body and not status 202. """ self._check_status(response) self._process_status(response) @@ -340,11 +319,10 @@ def get_status_from_resource(self, response): :raises: BadResponse if status not 200 or 204. """ self._check_status(response) - if response.status_code in [200, 204]: - self._process_status(response) - else: - raise BadResponse('Location header is missing from ' - 'long running operation.') + if self._is_empty(response) and self.method in ['PUT', 'PATCH']: + raise BadResponse('The response from long running ' + 'operation does not contain a body.') + self._process_status(response) def get_status_from_async(self, response): """Process the latest status update retrieved from a @@ -367,18 +345,23 @@ def get_status_from_async(self, response): except CloudError: pass # Not all 'accept' statuses will deserialize. - def _object_from_response(self, response): - """For a POST of DELETE request, there's no need to attempt - resource deserialization. - """ - pass + def get_retry(self, response, *args): + """Retrieve the URL that will be polled for status. First looks for + 'azure-asyncoperation' header, if not found or invalid, check for + 'location' header. - def get_retry(self, response, first_call): - """Add addtional logic to super get_retry to accommodate POST - calls which must fail if no 'Location' or 'Async' headers are found - and status code is 202. + :param requests.Response response: latest REST call response. """ - super(PostDeleteOperation, self).get_retry(response) + try: + self.async_url = self._validate( + response.headers.get('azure-asyncoperation')) + + # Return if we have a url, in case location header raises error. + if self.async_url: + return + except ValueError: + pass # We can ignore as location header may still be valid. + self.location_url = self._validate(response.headers.get('location')) if not self.location_url and not self.async_url: code = response.status_code if code == 202 and self.method == 'POST': @@ -386,89 +369,6 @@ def get_retry(self, response, first_call): 'Location header is missing from long running operation.') -class PutPatchOperation(LongRunningOperationMixin): - """LongRunningOperation object for a PUT or PATCH request. - - :param requests.Response response: initial REST call response. - :param callable outputs: Function to deserialize operation resource. - """ - - def __init__(self, response, outputs): - self.status = "" - self.resource = None - self.get_outputs = outputs - self.async_url = None - self.location_url = None - - def _check_status(self, response): - """Check response status code is valid for a Put or Patch - reqest. Must be 200, 201, or 202. - - :raises: BadStatus if invalid status. - """ - if response.status_code not in [200, 201, 202]: - raise BadStatus("Invalid return status for 'PUT' or 'PATCH' call") - - def is_done(self): - """Check whether the operation can be considered complete. - For a PUT or PATCH function, result should include a deserialized - payload. - - :rtype: bool - """ - is_done = super(PutPatchOperation, self).is_done() - if not self.resource: - return False - return is_done - - def get_status_from_location(self, response): - """Process the latest status update retrieved from a 'location' - header. - - :param requests.Response response: latest REST call response. - :raises: BadResponse if response has no body and not status 202. - """ - self._check_status(response) - if response.status_code == 202: - self._status_202(response) - else: - if self._is_empty(response): - raise BadResponse('The response from long running ' - 'operation does not contain a body.') - self._process_status(response) - - def get_status_from_resource(self, response): - """Process the latest status update retrieved from the same URL as - the previous request. - - :param requests.Response response: latest REST call response. - :raises: BadResponse if response has no body. - """ - self._check_status(response) - if self._is_empty(response): - raise BadResponse('The response from long running operation ' - 'does not contain a body.') - - self._status_200(response) - - def get_status_from_async(self, response): - """Process the latest status update retrieved from a - 'azure-asyncoperation' header. - - :param requests.Response response: latest REST call response. - :raises: BadResponse if response has no body, or body does not - contain status. - """ - self._check_status(response) - if self._is_empty(response): - raise BadResponse('The response from long running operation ' - 'does not contain a body.') - - self.status = self._get_body_status(response) - if not self.status: - raise BadResponse("No status found in body") - - class AzureOperationPoller(object): """Initiates long running operation and polls status in separate thread. @@ -484,11 +384,6 @@ class AzureOperationPoller(object): argument, a completed LongRunningOperation (optional). """ - operations = {'PUT': PutPatchOperation, - 'PATCH': PutPatchOperation, - 'POST': PostDeleteOperation, - 'DELETE': PostDeleteOperation} - def __init__(self, send_cmd, output_cmd, update_cmd, timeout=30): self._timeout = timeout self._response = None @@ -512,15 +407,9 @@ def _start(self, send_cmd, update_cmd, output_cmd): """ try: self._response = send_cmd() - try: - op_type = self.operations[self._response.request.method] - self._operation = op_type(self._response, output_cmd) - self._operation.get_initial_status(self._response) - except KeyError: - error = "Request type {!r} is not a valid polling request" - raise TypeError(error.format(self._response.request.method)) - else: - self._poll(update_cmd) + self._operation = LongRunningOperation(self._response, output_cmd) + self._operation.get_initial_status(self._response) + self._poll(update_cmd) except BadStatus: self._operation.status = 'Failed' diff --git a/ClientRuntimes/Python/msrestazure/test/unittest_auth.py b/ClientRuntimes/Python/msrestazure/test/unittest_auth.py index e7fe73f67b2ec..363b4b835d5ee 100644 --- a/ClientRuntimes/Python/msrestazure/test/unittest_auth.py +++ b/ClientRuntimes/Python/msrestazure/test/unittest_auth.py @@ -92,7 +92,7 @@ def test_https(self): def test_check_state(self): - mix = AADMixin() + mix = AADMixin(None, None) mix.state = "abc" with self.assertRaises(ValueError): @@ -107,10 +107,22 @@ def test_check_state(self): mix._check_state("server?test&state=abcd&") mix._check_state("server?test&state=abc&") + def test_convert_token(self): + + mix = AADMixin(None, None) + token = {'access_token':'abc', 'expires_on':123, 'refresh_token':'asd'} + self.assertEqual(mix._convert_token(token), token) + + caps = {'accessToken':'abc', 'expiresOn':123, 'refreshToken':'asd'} + self.assertEqual(mix._convert_token(caps), token) + + caps = {'ACCessToken':'abc', 'Expires_On':123, 'REFRESH_TOKEN':'asd'} + self.assertEqual(mix._convert_token(caps), token) + @mock.patch('msrestazure.azure_active_directory.keyring') def test_store_token(self, mock_keyring): - mix = AADMixin() + mix = AADMixin(None, None) mix.cred_store = "store_name" mix.store_key = "client_id" mix._default_token_cache({'token_type':'1', 'access_token':'2'}) @@ -122,7 +134,7 @@ def test_store_token(self, mock_keyring): @mock.patch('msrestazure.azure_active_directory.keyring') def test_clear_token(self, mock_keyring): - mix = AADMixin() + mix = AADMixin(None, None) mix.cred_store = "store_name" mix.store_key = "client_id" mix.clear_cached_token() @@ -133,7 +145,7 @@ def test_clear_token(self, mock_keyring): @mock.patch('msrestazure.azure_active_directory.keyring') def test_credentials_get_stored_auth(self, mock_keyring): - mix = AADMixin() + mix = AADMixin(None, None) mix.cred_store = "store_name" mix.store_key = "client_id" mix.signed_session = mock.Mock() diff --git a/ClientRuntimes/Python/msrestazure/test/unittest_operation.py b/ClientRuntimes/Python/msrestazure/test/unittest_operation.py index f1a273785d681..01879ecf060e8 100644 --- a/ClientRuntimes/Python/msrestazure/test/unittest_operation.py +++ b/ClientRuntimes/Python/msrestazure/test/unittest_operation.py @@ -25,39 +25,287 @@ #-------------------------------------------------------------------------- import json +import re import unittest try: from unittest import mock except ImportError: import mock -from requests import Response +from requests import Request, Response from msrest import Deserializer -from msrest.exceptions import RequestException +from msrest.exceptions import RequestException, DeserializationError from msrestazure.azure_exceptions import CloudError from msrestazure.azure_operation import ( - PostDeleteOperation, - PutPatchOperation, - AzureOperationPoller) + LongRunningOperation, + AzureOperationPoller, + BadStatus, + SimpleResource) +class BadEndpointError(Exception): + pass + +TEST_NAME = 'foo' +RESPONSE_BODY = {'properties':{'provisioningState': 'InProgress'}} +ASYNC_BODY = json.dumps({ 'status': 'Succeeded' }) +ASYNC_URL = 'http://dummyurlFromAzureAsyncOPHeader_Return200' +LOCATION_BODY = json.dumps({ 'status': 'Succeeded', 'name': TEST_NAME }) +LOCATION_URL = 'http://dummyurlurlFromLocationHeader_Return200' +RESOURCE_BODY = json.dumps({ 'status': 'Succeeded', 'name': TEST_NAME }) +RESOURCE_URL = 'http://subscriptions/sub1/resourcegroups/g1/resourcetype1/resource1' +ERROR = 'http://dummyurl_ReturnError' +POLLING_STATUS = 200 class TestLongRunningOperation(unittest.TestCase): - def test_long_running_operation(self): - + convert = re.compile('([a-z0-9])([A-Z])') + + @staticmethod + def mock_send(method, status, headers, body=None): response = mock.create_autospec(Response) - response.status_code = 400 - response.reason = 'BadRequest' + response.request = mock.create_autospec(Request) + response.request.method = method + response.request.url = RESOURCE_URL + response.status_code = status + response.headers = headers + content = body if body else RESPONSE_BODY + response.content = json.dumps(content) + response.json = lambda: json.loads(response.content) + return lambda: response + + @staticmethod + def mock_update(url, headers=None): + response = mock.create_autospec(Response) + response.request = mock.create_autospec(Request) + response.request.method = 'GET' + + if url == ASYNC_URL: + response.request.url = url + response.status_code = POLLING_STATUS + response.content = ASYNC_BODY + response.randomFieldFromPollAsyncOpHeader = None + + elif url == LOCATION_URL: + response.request.url = url + response.status_code = POLLING_STATUS + response.content = LOCATION_BODY + response.randomFieldFromPollLocationHeader = None + + elif url == ERROR: + raise BadEndpointError("boom") - message = { - 'code': '500', - 'message': {'value': 'Bad Request\nRequest:34875\nTime:1999-12-31T23:59:59-23:59'}, - 'values': {'invalid_attribute':'data'} - } + elif url == RESOURCE_URL: + response.request.url = url + response.status_code = POLLING_STATUS + response.content = RESOURCE_BODY - response.content = json.dumps(message) + else: + raise Exception('URL does not match') response.json = lambda: json.loads(response.content) + return response + + @staticmethod + def mock_outputs(response): + body = response.json() + body = {TestLongRunningOperation.convert.sub(r'\1_\2', k).lower(): v + for k, v in body.items()} + properties = body.get('properties') + if properties: + properties = {TestLongRunningOperation.convert.sub(r'\1_\2', k).lower(): v + for k, v in properties.items()} + del body['properties'] + body.update(properties) + resource = SimpleResource(**body) + else: + resource = SimpleResource(**body) + return resource + + def test_long_running_put(self): + #TODO: Test custom header field + + # Test throw on non LRO related status code + response = TestLongRunningOperation.mock_send('PUT', 1000, {}) + op = LongRunningOperation(response(), lambda x:None) + with self.assertRaises(BadStatus): + op.get_initial_status(response()) + with self.assertRaises(CloudError): + AzureOperationPoller(response, + TestLongRunningOperation.mock_outputs, + TestLongRunningOperation.mock_update, 0).result() + + # Test polling from azure-asyncoperation header + response = TestLongRunningOperation.mock_send( + 'PUT', 201, + {'azure-asyncoperation': ASYNC_URL}) + poll = AzureOperationPoller(response, + TestLongRunningOperation.mock_outputs, + TestLongRunningOperation.mock_update, 0) + self.assertEqual(poll.result().name, TEST_NAME) + self.assertFalse(hasattr(poll._response, 'randomFieldFromPollAsyncOpHeader')) + + # Test polling location header + response = TestLongRunningOperation.mock_send( + 'PUT', 201, + {'location': LOCATION_URL}) + poll = AzureOperationPoller(response, + TestLongRunningOperation.mock_outputs, + TestLongRunningOperation.mock_update, 0) + self.assertEqual(poll.result().name, TEST_NAME) + self.assertIsNone(poll._response.randomFieldFromPollLocationHeader) + + # Test fail to poll from azure-asyncoperation header + response = TestLongRunningOperation.mock_send( + 'PUT', 201, + {'azure-asyncoperation': ERROR}) + with self.assertRaises(BadEndpointError): + poll = AzureOperationPoller(response, + TestLongRunningOperation.mock_outputs, + TestLongRunningOperation.mock_update, 0).result() + + # Test fail to poll from location header + response = TestLongRunningOperation.mock_send( + 'PUT', 201, + {'location': ERROR}) + with self.assertRaises(BadEndpointError): + poll = AzureOperationPoller(response, + TestLongRunningOperation.mock_outputs, + TestLongRunningOperation.mock_update, 0).result() + + def test_long_running_patch(self): + + # Test polling from location header + response = TestLongRunningOperation.mock_send( + 'PATCH', 202, + {'location': LOCATION_URL}, + body={'properties':{'provisioningState': 'Succeeded'}}) + poll = AzureOperationPoller(response, + TestLongRunningOperation.mock_outputs, + TestLongRunningOperation.mock_update, 0) + self.assertEqual(poll.result().name, TEST_NAME) + self.assertIsNone(poll._response.randomFieldFromPollLocationHeader) + + # Test polling from azure-asyncoperation header + response = TestLongRunningOperation.mock_send( + 'PATCH', 202, + {'azure-asyncoperation': ASYNC_URL}, + body={'properties':{'provisioningState': 'Succeeded'}}) + poll = AzureOperationPoller(response, + TestLongRunningOperation.mock_outputs, + TestLongRunningOperation.mock_update, 0) + self.assertEqual(poll.result().name, TEST_NAME) + self.assertFalse(hasattr(poll._response, 'randomFieldFromPollAsyncOpHeader')) + + # Test fail to poll from azure-asyncoperation header + response = TestLongRunningOperation.mock_send( + 'PATCH', 202, + {'azure-asyncoperation': ERROR}) + with self.assertRaises(BadEndpointError): + poll = AzureOperationPoller(response, + TestLongRunningOperation.mock_outputs, + TestLongRunningOperation.mock_update, 0).result() + + # Test fail to poll from location header + response = TestLongRunningOperation.mock_send( + 'PATCH', 202, + {'location': ERROR}) + with self.assertRaises(BadEndpointError): + poll = AzureOperationPoller(response, + TestLongRunningOperation.mock_outputs, + TestLongRunningOperation.mock_update, 0).result() + + def test_long_running_post_delete(self): + + # Test throw on non LRO related status code + response = TestLongRunningOperation.mock_send('POST', 201, {}) + op = LongRunningOperation(response(), lambda x:None) + with self.assertRaises(BadStatus): + op.get_initial_status(response()) + with self.assertRaises(CloudError): + AzureOperationPoller(response, + TestLongRunningOperation.mock_outputs, + TestLongRunningOperation.mock_update, 0).result() + + # Test polling from azure-asyncoperation header + response = TestLongRunningOperation.mock_send( + 'POST', 202, + {'azure-asyncoperation': ASYNC_URL}, + body={'properties':{'provisioningState': 'Succeeded'}}) + poll = AzureOperationPoller(response, + TestLongRunningOperation.mock_outputs, + TestLongRunningOperation.mock_update, 0) + poll.wait() + #self.assertIsNone(poll.result()) + self.assertIsNone(poll._response.randomFieldFromPollAsyncOpHeader) + + # Test polling from location header + response = TestLongRunningOperation.mock_send( + 'POST', 202, + {'location': LOCATION_URL}, + body={'properties':{'provisioningState': 'Succeeded'}}) + poll = AzureOperationPoller(response, + TestLongRunningOperation.mock_outputs, + TestLongRunningOperation.mock_update, 0) + self.assertEqual(poll.result().name, TEST_NAME) + self.assertIsNone(poll._response.randomFieldFromPollLocationHeader) + + # Test fail to poll from azure-asyncoperation header + response = TestLongRunningOperation.mock_send( + 'POST', 202, + {'azure-asyncoperation': ERROR}) + with self.assertRaises(BadEndpointError): + poll = AzureOperationPoller(response, + TestLongRunningOperation.mock_outputs, + TestLongRunningOperation.mock_update, 0).result() + + # Test fail to poll from location header + response = TestLongRunningOperation.mock_send( + 'POST', 202, + {'location': ERROR}) + with self.assertRaises(BadEndpointError): + poll = AzureOperationPoller(response, + TestLongRunningOperation.mock_outputs, + TestLongRunningOperation.mock_update, 0).result() + + def test_long_running_negative(self): + global LOCATION_BODY + global POLLING_STATUS + + # Test LRO PUT throws for invalid json + LOCATION_BODY = '{' + response = TestLongRunningOperation.mock_send( + 'POST', 202, + {'location': LOCATION_URL}) + poll = AzureOperationPoller(response, + TestLongRunningOperation.mock_outputs, + TestLongRunningOperation.mock_update, 0) + with self.assertRaises(DeserializationError): + poll.wait() + + LOCATION_BODY = '{\'"}' + response = TestLongRunningOperation.mock_send( + 'POST', 202, + {'location': LOCATION_URL}) + poll = AzureOperationPoller(response, + TestLongRunningOperation.mock_outputs, + TestLongRunningOperation.mock_update, 0) + with self.assertRaises(DeserializationError): + poll.wait() + + LOCATION_BODY = '{' + POLLING_STATUS = 203 + response = TestLongRunningOperation.mock_send( + 'POST', 202, + {'location': LOCATION_URL}) + poll = AzureOperationPoller(response, + TestLongRunningOperation.mock_outputs, + TestLongRunningOperation.mock_update, 0) + with self.assertRaises(CloudError): # TODO: Node.js raises on deserialization + poll.wait() + + LOCATION_BODY = json.dumps({ 'status': 'Succeeded', 'name': TEST_NAME }) + POLLING_STATUS = 200 + if __name__ == '__main__':