Skip to content

Commit

Permalink
Cleanup SNS exceptions. Closes #751.
Browse files Browse the repository at this point in the history
  • Loading branch information
spulec committed Mar 17, 2017
1 parent e7a3f34 commit c207963
Show file tree
Hide file tree
Showing 4 changed files with 87 additions and 3 deletions.
16 changes: 16 additions & 0 deletions moto/sns/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,19 @@ class SNSNotFoundError(RESTError):
def __init__(self, message):
super(SNSNotFoundError, self).__init__(
"NotFound", message)


class DuplicateSnsEndpointError(RESTError):
code = 400

def __init__(self, message):
super(DuplicateSnsEndpointError, self).__init__(
"DuplicateEndpoint", message)


class SnsEndpointDisabled(RESTError):
code = 400

def __init__(self, message):
super(SnsEndpointDisabled, self).__init__(
"EndpointDisabled", message)
13 changes: 12 additions & 1 deletion moto/sns/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@
from moto.core import BaseBackend, BaseModel
from moto.core.utils import iso_8601_datetime_with_milliseconds
from moto.sqs import sqs_backends
from .exceptions import SNSNotFoundError
from .exceptions import (
SNSNotFoundError, DuplicateSnsEndpointError, SnsEndpointDisabled
)
from .utils import make_arn_for_topic, make_arn_for_subscription

DEFAULT_ACCOUNT_ID = 123456789012
Expand Down Expand Up @@ -136,6 +138,10 @@ def __fixup_attributes(self):
if 'Enabled' not in self.attributes:
self.attributes['Enabled'] = True

@property
def enabled(self):
return json.loads(self.attributes.get('Enabled', 'true').lower())

@property
def arn(self):
return "arn:aws:sns:{region}:123456789012:endpoint/{platform}/{name}/{id}".format(
Expand All @@ -146,6 +152,9 @@ def arn(self):
)

def publish(self, message):
if not self.enabled:
raise SnsEndpointDisabled("Endpoint %s disabled" % self.id)

# This is where we would actually send a message
message_id = six.text_type(uuid.uuid4())
self.messages[message_id] = message
Expand Down Expand Up @@ -251,6 +260,8 @@ def delete_platform_application(self, platform_arn):
self.applications.pop(platform_arn)

def create_platform_endpoint(self, region, application, custom_user_data, token, attributes):
if any(token == endpoint.token for endpoint in self.platform_endpoints.values()):
raise DuplicateSnsEndpointError("Duplicate endpoint token: %s" % token)
platform_endpoint = PlatformEndpoint(
region, application, custom_user_data, token, attributes)
self.platform_endpoints[platform_endpoint.arn] = platform_endpoint
Expand Down
2 changes: 1 addition & 1 deletion tests/test_sns/test_application.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,7 @@ def test_publish_to_platform_endpoint():
token="some_unique_id",
custom_user_data="some user data",
attributes={
"Enabled": False,
"Enabled": True,
},
)

Expand Down
59 changes: 58 additions & 1 deletion tests/test_sns/test_application_boto3.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,35 @@ def test_create_platform_endpoint():
"arn:aws:sns:us-east-1:123456789012:endpoint/APNS/my-application/")


@mock_sns
def test_create_duplicate_platform_endpoint():
conn = boto3.client('sns', region_name='us-east-1')
platform_application = conn.create_platform_application(
Name="my-application",
Platform="APNS",
Attributes={},
)
application_arn = platform_application['PlatformApplicationArn']

endpoint = conn.create_platform_endpoint(
PlatformApplicationArn=application_arn,
Token="some_unique_id",
CustomUserData="some user data",
Attributes={
"Enabled": 'false',
},
)

endpoint = conn.create_platform_endpoint.when.called_with(
PlatformApplicationArn=application_arn,
Token="some_unique_id",
CustomUserData="some user data",
Attributes={
"Enabled": 'false',
},
).should.throw(ClientError)


@mock_sns
def test_get_list_endpoints_by_platform_application():
conn = boto3.client('sns', region_name='us-east-1')
Expand Down Expand Up @@ -256,11 +285,39 @@ def test_publish_to_platform_endpoint():
Token="some_unique_id",
CustomUserData="some user data",
Attributes={
"Enabled": 'false',
"Enabled": 'true',
},
)

endpoint_arn = endpoint['EndpointArn']

conn.publish(Message="some message",
MessageStructure="json", TargetArn=endpoint_arn)


@mock_sns
def test_publish_to_disabled_platform_endpoint():
conn = boto3.client('sns', region_name='us-east-1')
platform_application = conn.create_platform_application(
Name="my-application",
Platform="APNS",
Attributes={},
)
application_arn = platform_application['PlatformApplicationArn']

endpoint = conn.create_platform_endpoint(
PlatformApplicationArn=application_arn,
Token="some_unique_id",
CustomUserData="some user data",
Attributes={
"Enabled": 'false',
},
)

endpoint_arn = endpoint['EndpointArn']

conn.publish.when.called_with(
Message="some message",
MessageStructure="json",
TargetArn=endpoint_arn,
).should.throw(ClientError)

0 comments on commit c207963

Please sign in to comment.