diff --git a/src/Connections/RoundRobinServerManager.cs b/src/Connections/RoundRobinServerManager.cs index e54e835..3cc033b 100644 --- a/src/Connections/RoundRobinServerManager.cs +++ b/src/Connections/RoundRobinServerManager.cs @@ -4,112 +4,117 @@ namespace FluentCassandra.Connections { - public class RoundRobinServerManager : IServerManager - { - private readonly object _lock = new object(); - - private List _servers; - private Queue _serverQueue; - private HashSet _blackListed; - - public RoundRobinServerManager(IConnectionBuilder builder) - { - _servers = new List(builder.Servers); - _serverQueue = new Queue(_servers); - _blackListed = new HashSet(); - } - - private bool IsBlackListed(Server server) - { - return _blackListed.Contains(server); - } - - #region IServerManager Members - - public bool HasNext - { - get { lock (_lock) { return (_serverQueue.Count - _blackListed.Count) > 0; } } - } - - public Server Next() - { - Server server; - - lock (_lock) - { - do - { - server = _serverQueue.Dequeue(); - - if (IsBlackListed(server)) - server = null; - else - _serverQueue.Enqueue(server); - } - while (_serverQueue.Count > 0 && server == null); - } - - return server; - } - - public void Add(Server server) - { - lock (_lock) - { - _servers.Add(server); - _serverQueue.Enqueue(server); - } - } - - public void ErrorOccurred(Server server, Exception exc = null) - { - Debug.WriteLineIf(exc != null, exc, "connection"); - BlackList(server); - } - - public void BlackList(Server server) - { - Debug.WriteLine(server + " has been blacklisted", "connection"); - lock (_lock) - { - _blackListed.Add(server); - } - } - - public void Remove(Server server) - { - lock (_lock) - { - _servers.Remove(server); - _serverQueue = new Queue(); - _blackListed.RemoveWhere(x => x == server); - - foreach (var s in _servers) - { - if (!_blackListed.Contains(s)) - _serverQueue.Enqueue(s); - } - } - } - - #endregion - - #region IEnumerable Members - - public IEnumerator GetEnumerator() - { - return _servers.GetEnumerator(); - } - - #endregion - - #region IEnumerable Members - - System.Collections.IEnumerator System.Collections.IEnumerable.GetEnumerator() - { - return GetEnumerator(); - } - - #endregion - } + public class RoundRobinServerManager : IServerManager + { + private readonly object _lock = new object(); + + private List _servers; + private Queue _serverQueue; + private HashSet _blackListed; + + public RoundRobinServerManager(IConnectionBuilder builder) + { + _servers = new List(builder.Servers); + _serverQueue = new Queue(_servers); + _blackListed = new HashSet(); + } + + private bool IsBlackListed(Server server) + { + return _blackListed.Contains(server); + } + + #region IServerManager Members + + public bool HasNext + { + get { lock (_lock) { return _serverQueue.Count > 0; } } + } + + public Server Next() + { + Server server = null; + + lock (_lock) + { + if (_serverQueue.Count > 0) + { + server = _serverQueue.Dequeue(); + _serverQueue.Enqueue(server); + } + } + + return server; + } + + public void Add(Server server) + { + lock (_lock) + { + _servers.Add(server); + _serverQueue.Enqueue(server); + } + } + + public void ErrorOccurred(Server server, Exception exc = null) + { + Debug.WriteLineIf(exc != null, exc, "connection"); + BlackList(server); + } + + public void BlackList(Server server) + { + Debug.WriteLine(server + " has been blacklisted", "connection"); + lock (_lock) + { + if (_blackListed.Add(server)) + { + _serverQueue.Clear(); + foreach (Server srv in _servers) + { + if (!IsBlackListed(srv)) + { + _serverQueue.Enqueue(srv); + } + } + } + } + } + + public void Remove(Server server) + { + lock (_lock) + { + _servers.Remove(server); + _serverQueue = new Queue(); + _blackListed.RemoveWhere(x => x == server); + + foreach (var s in _servers) + { + if (!_blackListed.Contains(s)) + _serverQueue.Enqueue(s); + } + } + } + + #endregion + + #region IEnumerable Members + + public IEnumerator GetEnumerator() + { + return _servers.GetEnumerator(); + } + + #endregion + + #region IEnumerable Members + + System.Collections.IEnumerator System.Collections.IEnumerable.GetEnumerator() + { + return GetEnumerator(); + } + + #endregion + } } \ No newline at end of file diff --git a/test/FluentCassandra.Tests/Connections/RoundRobinServerManagerTests.cs b/test/FluentCassandra.Tests/Connections/RoundRobinServerManagerTests.cs new file mode 100644 index 0000000..72c0bbb --- /dev/null +++ b/test/FluentCassandra.Tests/Connections/RoundRobinServerManagerTests.cs @@ -0,0 +1,82 @@ +using System; +using System.Collections.Generic; +using Xunit; + +namespace FluentCassandra.Connections.Tests +{ + public class RoundRobinServerManagerTests + { + [Fact] + public void CanBlackListAndCleanQueueTest() + { + RoundRobinServerManager target = new RoundRobinServerManager(new ConnectionBuilder("Server=unit-test-1,unit-test-2,unit-test-3")); + + Server srv = new Server("unit-test-4"); + target.Add(srv); + + bool gotServer4 = false; + + for (int i = 0; i < 4; i++) + { + Server server = target.Next(); + if (server.ToString().Equals(srv.ToString(), StringComparison.OrdinalIgnoreCase)) + { + gotServer4 = true; + break; + } + } + + Assert.True(gotServer4); + + target.BlackList(srv); + + gotServer4 = false; + for (int i = 0; i < 4; i++) + { + Server server = target.Next(); + if (server.Equals(srv)) + { + gotServer4 = true; + break; + } + } + + Assert.False(gotServer4); + } + + [Fact] + public void HasNextWithMoreThanHalfBlacklistedTest() + { + RoundRobinServerManager target = new RoundRobinServerManager(new ConnectionBuilder("Server=unit-test-1")); + + Server srv1 = null; + Server srv2 = new Server("unit-test-2"); + Server srv3 = new Server("unit-test-3"); + Server srv4 = new Server("unit-test-4"); + target.Add(srv2); + target.Add(srv3); + target.Add(srv4); + List servers = new List { new Server("unit-test-1"), srv2, srv3, srv4 }; + + for (int i = 0; i < 4; i++) + { + Server srv = target.Next(); + Assert.True(servers[i].ToString().Equals(srv.ToString(), StringComparison.OrdinalIgnoreCase)); + if(i == 0) + { + srv1 = srv; + } + } + + target.BlackList(srv2); + target.BlackList(srv3); + Assert.True(target.HasNext); + + target.BlackList(srv1); + Assert.True(target.HasNext); + + target.BlackList(srv4); + Assert.False(target.HasNext); + } + } +} diff --git a/test/FluentCassandra.Tests/FluentCassandra.Tests.csproj b/test/FluentCassandra.Tests/FluentCassandra.Tests.csproj index 41280e3..4756539 100644 --- a/test/FluentCassandra.Tests/FluentCassandra.Tests.csproj +++ b/test/FluentCassandra.Tests/FluentCassandra.Tests.csproj @@ -62,6 +62,7 @@ +