Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
178 changes: 178 additions & 0 deletions src/DotNetCore.CAP.AmazonSQS/AmazonPolicyExtensions.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,178 @@
using System;
using System.Collections.Generic;
using System.Linq;
using Amazon.Auth.AccessControlPolicy;
using Amazon.Auth.AccessControlPolicy.ActionIdentifiers;

namespace DotNetCore.CAP.AmazonSQS
{
public static class AmazonPolicyExtensions
{
/// <summary>
/// Check to see if the policy for the queue has already given permission to the topic.
/// </summary>
/// <param name="policy"></param>
/// <param name="topicArn"></param>
/// <param name="sqsQueueArn"></param>
/// <returns></returns>
public static bool HasSqsPermission(this Policy policy, string topicArn, string sqsQueueArn)
{
foreach (var statement in policy.Statements)
{
var containsResource = statement.Resources.Any(r => r.Id.Equals(sqsQueueArn));

if (!containsResource)
{
continue;
}

foreach (var condition in statement.Conditions)
{
if ((string.Equals(condition.Type, ConditionFactory.StringComparisonType.StringLike.ToString(), StringComparison.OrdinalIgnoreCase) ||
string.Equals(condition.Type, ConditionFactory.StringComparisonType.StringEquals.ToString(), StringComparison.OrdinalIgnoreCase) ||
string.Equals(condition.Type, ConditionFactory.ArnComparisonType.ArnEquals.ToString(), StringComparison.OrdinalIgnoreCase) ||
string.Equals(condition.Type, ConditionFactory.ArnComparisonType.ArnLike.ToString(), StringComparison.OrdinalIgnoreCase)) &&
string.Equals(condition.ConditionKey, ConditionFactory.SOURCE_ARN_CONDITION_KEY, StringComparison.OrdinalIgnoreCase) &&
condition.Values.Contains(topicArn))
{
return true;
}
}
}

return false;
}

/// <summary>
/// Add statement to the SQS policy that gives the SNS topics access to send a message to the queue.
/// </summary>
/// <code>
/// {
/// "Version": "2012-10-17",
/// "Statement": [
/// {
/// "Effect": "Allow",
/// "Principal": {
/// "AWS": "*"
/// },
/// "Action": "sqs:SendMessage",
/// "Resource": "arn:aws:sqs:us-east-1:MyQueue",
/// "Condition": {
/// "ArnLike": {
/// "aws:SourceArn": [
/// "arn:aws:sns:us-east-1:FirstTopic",
/// "arn:aws:sns:us-east-1:SecondTopic"
/// ]
/// }
/// }
/// }]
/// }
/// </code>
/// <param name="policy"></param>
/// <param name="topicArns"></param>
/// <param name="sqsQueueArn"></param>
public static void AddSqsPermissions(this Policy policy, IEnumerable<string> topicArns, string sqsQueueArn)
{
var statement = new Statement(Statement.StatementEffect.Allow);
statement.Actions.Add(SQSActionIdentifiers.SendMessage);
statement.Resources.Add(new Resource(sqsQueueArn));
statement.Principals.Add(new Principal("*"));
foreach (var topicArn in topicArns)
{
statement.Conditions.Add(ConditionFactory.NewSourceArnCondition(topicArn));
}

policy.Statements.Add(statement);
}

/// <summary>
/// Compact SQS access policy
/// </summary>
/// <para>
/// Transforms policies with multiple similar statements:
/// <code>
/// {
/// "Version": "2012-10-17",
/// "Statement": [
/// {
/// "Effect": "Allow",
/// "Principal": {
/// "AWS": "*"
/// },
/// "Action": "sqs:SendMessage",
/// "Resource": "arn:aws:sqs:us-east-1:MyQueue",
/// "Condition": {
/// "ArnLike": {
/// "aws:SourceArn": "arn:aws:sns:us-east-1:FirstTopic"
/// }
/// }
/// },
/// {
/// "Effect": "Allow",
/// "Principal": {
/// "AWS": "*"
/// },
/// "Action": "sqs:SendMessage",
/// "Resource": "arn:aws:sqs:us-east-1:MyQueue",
/// "Condition": {
/// "ArnLike": {
/// "aws:SourceArn": "arn:aws:sns:us-east-1:SecondTopic"
/// }
/// }
/// }]
/// }
/// </code>
/// into compacted single statement:
/// <code>
/// {
/// "Version": "2012-10-17",
/// "Statement": [
/// {
/// "Effect": "Allow",
/// "Principal": {
/// "AWS": "*"
/// },
/// "Action": "sqs:SendMessage",
/// "Resource": "arn:aws:sqs:us-east-1:MyQueue",
/// "Condition": {
/// "ArnLike": {
/// "aws:SourceArn": [
/// "arn:aws:sns:us-east-1:FirstTopic",
/// "arn:aws:sns:us-east-1:SecondTopic"
/// ]
/// }
/// }
/// }]
/// }
/// </code>
/// </para>
/// <param name="policy"></param>
/// <param name="sqsQueueArn"></param>
public static void CompactSqsPermissions(this Policy policy, string sqsQueueArn)
{
var statementsToCompact = policy.Statements
.Where(s => s.Effect == Statement.StatementEffect.Allow)
.Where(s => s.Actions.All(a => string.Equals(a.ActionName, SQSActionIdentifiers.SendMessage.ActionName, StringComparison.OrdinalIgnoreCase)))
.Where(s => s.Resources.All(r => string.Equals(r.Id, sqsQueueArn, StringComparison.OrdinalIgnoreCase)))
.Where(s => s.Principals.All(r => string.Equals(r.Id, "*", StringComparison.OrdinalIgnoreCase)))
.ToList();

if (statementsToCompact.Count < 2)
{
return;
}

var topicArns = new HashSet<string>();
foreach (var statement in statementsToCompact)
{
policy.Statements.Remove(statement);
foreach (var topicArn in statement.Conditions.SelectMany(c => c.Values))
{
topicArns.Add(topicArn);
}
}

policy.AddSqsPermissions(topicArns, sqsQueueArn);
}
}
}
58 changes: 50 additions & 8 deletions src/DotNetCore.CAP.AmazonSQS/AmazonSQSConsumerClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
using System.Text.Json;
using System.Threading;
using System.Threading.Tasks;
using Amazon.Auth.AccessControlPolicy;
using Amazon.SimpleNotificationService;
using Amazon.SimpleNotificationService.Model;
using Amazon.SQS;
Expand Down Expand Up @@ -42,28 +43,41 @@ public AmazonSQSConsumerClient(string groupId, IOptions<AmazonSQSOptions> option

public BrokerAddress BrokerAddress => new BrokerAddress("AmazonSQS", _queueUrl);

public void Subscribe(IEnumerable<string> topics)
public ICollection<string> FetchTopics(IEnumerable<string> topicNames)
{
if (topics == null)
if (topicNames == null)
{
throw new ArgumentNullException(nameof(topics));
throw new ArgumentNullException(nameof(topicNames));
}

Connect(initSNS: true, initSQS: false);

var topicArns = new List<string>();
foreach (var topic in topics)
foreach (var topic in topicNames)
{
var createTopicRequest = new CreateTopicRequest(topic.NormalizeForAws());

var createTopicResponse = _snsClient.CreateTopicAsync(createTopicRequest).GetAwaiter().GetResult();

topicArns.Add(createTopicResponse.TopicArn);
}

GenerateSqsAccessPolicyAsync(topicArns)
.GetAwaiter().GetResult();

Connect(initSNS: false, initSQS: true);
return topicArns;
}

_snsClient.SubscribeQueueToTopicsAsync(topicArns, _sqsClient, _queueUrl)
public void Subscribe(IEnumerable<string> topics)
{
if (topics == null)
{
throw new ArgumentNullException(nameof(topics));
}

Connect();

_snsClient.SubscribeQueueToTopicsAsync(topics.ToList(), _sqsClient, _queueUrl)
.GetAwaiter().GetResult();
}

Expand Down Expand Up @@ -207,6 +221,34 @@ private Task MessageNotInflightLog(string exceptionMessage)
return Task.CompletedTask;
}

private async Task GenerateSqsAccessPolicyAsync(IEnumerable<string> topicArns)
{
Connect(initSNS: false, initSQS: true);

var queueAttributes = await _sqsClient.GetAttributesAsync(_queueUrl).ConfigureAwait(false);

var sqsQueueArn = queueAttributes["QueueArn"];

var policy = queueAttributes.TryGetValue("Policy", out var policyStr) && !string.IsNullOrEmpty(policyStr)
? Policy.FromJson(policyStr)
: new Policy();

var topicArnsToAllow = topicArns
.Where(a => !policy.HasSqsPermission(a, sqsQueueArn))
.ToList();

if (!topicArnsToAllow.Any())
{
return;
}

policy.AddSqsPermissions(topicArnsToAllow, sqsQueueArn);
policy.CompactSqsPermissions(sqsQueueArn);

var setAttributes = new Dictionary<string, string> { { "Policy", policy.ToJson() } };
await _sqsClient.SetAttributesAsync(_queueUrl, setAttributes).ConfigureAwait(false);
}

#endregion
}
}
11 changes: 9 additions & 2 deletions src/DotNetCore.CAP/Internal/IConsumerRegister.Default.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@
// Licensed under the MIT License. See License.txt in the project root for license information.

using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Runtime.Serialization;
using System.Threading;
using System.Threading.Tasks;
using DotNetCore.CAP.Diagnostics;
Expand Down Expand Up @@ -67,8 +67,15 @@ public void Start()

foreach (var matchGroup in groupingMatches)
{
ICollection<string> topics;
using (var client = _consumerClientFactory.Create(matchGroup.Key))
{
topics = client.FetchTopics(matchGroup.Value.Select(x => x.TopicName));
}

for (int i = 0; i < _options.ConsumerThreadCount; i++)
{
var topicIds = topics.Select(t => t);
Task.Factory.StartNew(() =>
{
try
Expand All @@ -79,7 +86,7 @@ public void Start()

RegisterMessageProcessor(client);

client.Subscribe(matchGroup.Value.Select(x => x.TopicName));
client.Subscribe(topicIds);

client.Listening(_pollingDelay, _cts.Token);
}
Expand Down
11 changes: 11 additions & 0 deletions src/DotNetCore.CAP/Transport/IConsumerClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using DotNetCore.CAP.Messages;
using JetBrains.Annotations;
Expand All @@ -17,6 +18,16 @@ public interface IConsumerClient : IDisposable
{
BrokerAddress BrokerAddress { get; }

/// <summary>
/// Create (if necessary) and get topic identifiers
/// </summary>
/// <param name="topicNames">Names of the requested topics</param>
/// <returns>Topic identifiers</returns>
ICollection<string> FetchTopics(IEnumerable<string> topicNames)
{
return topicNames.ToList();
}

/// <summary>
/// Subscribe to a set of topics to the message queue
/// </summary>
Expand Down