diff --git a/conn.go b/conn.go
index 374d95199..a18effee8 100644
--- a/conn.go
+++ b/conn.go
@@ -1435,7 +1435,7 @@ func (c *Conn) awaitSchemaAgreement(ctx context.Context) (err error) {
 		}
 
 		for _, row := range rows {
-			host, err := c.session.hostInfoFromMap(row, &HostInfo{connectAddress: c.host.ConnectAddress(), port: c.session.cfg.Port})
+			host, err := c.session.hostInfoFromMap(row, c.host.ConnectAddress(), c.session.cfg.Port)
 			if err != nil {
 				goto cont
 			}
@@ -1495,7 +1495,7 @@ func (c *Conn) localHostInfo(ctx context.Context) (*HostInfo, error) {
 	port := c.conn.RemoteAddr().(*net.TCPAddr).Port
 
 	// TODO(zariel): avoid doing this here
-	host, err := c.session.hostInfoFromMap(row, &HostInfo{connectAddress: c.host.connectAddress, port: port})
+	host, err := c.session.hostInfoFromMap(row, c.host.connectAddress, port)
 	if err != nil {
 		return nil, err
 	}
diff --git a/host_source.go b/host_source.go
index 811b356f2..c4a54cea6 100644
--- a/host_source.go
+++ b/host_source.go
@@ -412,10 +412,11 @@ func checkSystemSchema(control *controlConn) (bool, error) {
 
 // Given a map that represents a row from either system.local or system.peers
 // return as much information as we can in *HostInfo
-func (s *Session) hostInfoFromMap(row map[string]interface{}, host *HostInfo) (*HostInfo, error) {
+func (s *Session) hostInfoFromMap(row map[string]interface{}, connAddr net.IP, connPort int) (*HostInfo, error) {
 	const assertErrorMsg = "Assertion failed for %s"
 	var ok bool
 
+	host := HostInfo{connectAddress: connAddr, port: connPort}
 	// Default to our connected port if the cluster doesn't have port information
 	for key, value := range row {
 		switch key {
@@ -516,7 +517,7 @@ func (s *Session) hostInfoFromMap(row map[string]interface{}, host *HostInfo) (*
 	host.connectAddress = ip
 	host.port = port
 
-	return host, nil
+	return &host, nil
 }
 
 // Ask the control node for host info on all it's known peers
@@ -539,7 +540,7 @@ func (r *ringDescriber) getClusterPeerInfo() ([]*HostInfo, error) {
 	for _, row := range rows {
 		// extract all available info about the peer
-		host, err := r.session.hostInfoFromMap(row, &HostInfo{port: r.session.cfg.Port})
+		host, err := r.session.hostInfoFromMap(row, nil, r.session.cfg.Port)
 		if err != nil {
 			return nil, err
 		} else if !isValidPeer(host) {
@@ -601,7 +602,7 @@ func (r *ringDescriber) getHostInfo(ip net.IP, port int) (*HostInfo, error) {
 	}
 
 	for _, row := range rows {
-		h, err := r.session.hostInfoFromMap(row, &HostInfo{port: port})
+		h, err := r.session.hostInfoFromMap(row, nil, port)
 		if err != nil {
 			return nil, err
 		}
diff --git a/policies.go b/policies.go
index 1446decac..c3751789b 100644
--- a/policies.go
+++ b/policies.go
@@ -604,7 +604,6 @@ func (t *tokenAwareHostPolicy) Pick(qry ExecutableQuery) NextHost {
 	}
 
 	if fallbackIter == nil {
-		// fallback
 		fallbackIter = t.fallback.Pick(qry)
 	}
 
@@ -770,8 +769,8 @@ type dcAwareRR struct {
 }
 
 // DCAwareRoundRobinPolicy is a host selection policies which will prioritize and
-// return hosts which are in the local datacentre before returning hosts in all
-// other datercentres
+// return hosts which are in the local datacenter before returning hosts in all
+// other datacenters
 func DCAwareRoundRobinPolicy(localDC string) HostSelectionPolicy {
 	return &dcAwareRR{local: localDC}
 }
@@ -850,6 +849,61 @@ func (d *dcAwareRR) Pick(q ExecutableQuery) NextHost {
 	return roundRobbin(int(nextStartOffset), d.localHosts.get(), d.remoteHosts.get())
 }
 
+type rackAwareRRHostPolicy struct {
+	datacenter      string
+	rack            string
+	rackHosts       cowHostList
+	datacenterHosts cowHostList
+	otherHosts      cowHostList
+	lastUsedHostIdx uint64
+}
+
+// RackAwareRRHostPolicy is a host selection policy that prioritizes hosts from
+// the local rack first, then hosts from the local datacenter, and finally all other hosts.
+func RackAwareRRHostPolicy(datacenter, rack string) HostSelectionPolicy {
+	return &rackAwareRRHostPolicy{datacenter: datacenter, rack: rack}
+}
+
+func (d *rackAwareRRHostPolicy) Init(*Session)                       {}
+func (d *rackAwareRRHostPolicy) KeyspaceChanged(KeyspaceUpdateEvent) {}
+func (d *rackAwareRRHostPolicy) SetPartitioner(p string)             {}
+
+func (d *rackAwareRRHostPolicy) IsLocal(host *HostInfo) bool {
+	return host.DataCenter() == d.datacenter && host.Rack() == d.rack
+}
+
+func (d *rackAwareRRHostPolicy) AddHost(host *HostInfo) {
+	if host.DataCenter() == d.datacenter {
+		if host.Rack() == d.rack {
+			d.rackHosts.add(host)
+			return
+		}
+		d.datacenterHosts.add(host)
+		return
+	}
+	d.otherHosts.add(host)
+}
+
+func (d *rackAwareRRHostPolicy) RemoveHost(host *HostInfo) {
+	if host.DataCenter() == d.datacenter {
+		if host.Rack() == d.rack {
+			d.rackHosts.remove(host.ConnectAddress())
+			return
+		}
+		d.datacenterHosts.remove(host.ConnectAddress())
+		return
+	}
+	d.otherHosts.remove(host.ConnectAddress())
+}
+
+func (d *rackAwareRRHostPolicy) HostUp(host *HostInfo)   { d.AddHost(host) }
+func (d *rackAwareRRHostPolicy) HostDown(host *HostInfo) { d.RemoveHost(host) }
+
+func (d *rackAwareRRHostPolicy) Pick(q ExecutableQuery) NextHost {
+	nextStartOffset := atomic.AddUint64(&d.lastUsedHostIdx, 1)
+	return roundRobbin(int(nextStartOffset), d.rackHosts.get(), d.datacenterHosts.get(), d.otherHosts.get())
+}
+
 // ReadyPolicy defines a policy for when a HostSelectionPolicy can be used. After
 // each host connects during session initialization, the Ready method will be
 // called. If you only need a single Host to be up you can wrap a
@@ -969,6 +1023,8 @@ type SpeculativeExecutionPolicy interface {
 
 type NonSpeculativeExecution struct{}
 
+var nonSpeculativeExecution NonSpeculativeExecution
+
 func (sp NonSpeculativeExecution) Attempts() int { return 0 } // No additional attempts
 func (sp NonSpeculativeExecution) Delay() time.Duration { return 1 } // The delay. Must be positive to be used in a ticker.
 
diff --git a/policies_test.go b/policies_test.go
index 826b2cd9d..645ccad80 100644
--- a/policies_test.go
+++ b/policies_test.go
@@ -12,6 +12,7 @@ import (
 	"time"
 
 	"github.com/hailocab/go-hostpool"
+	"github.com/stretchr/testify/require"
 )
 
 // Tests of the round-robin host selection policy implementation
@@ -619,3 +620,134 @@ func TestHostPolicy_TokenAware_NetworkStrategy(t *testing.T) {
 	iterCheck(t, iter, "6")
 	iterCheck(t, iter, "8")
 }
+
+func TestHostPolicy_RackAwareRR_IsLocal(t *testing.T) {
+	p := RackAwareRRHostPolicy("dc2", "rack2")
+
+	require.False(t, p.IsLocal(&HostInfo{dataCenter: "dc1", rack: "rack1"}))
+	require.False(t, p.IsLocal(&HostInfo{dataCenter: "dc1", rack: "rack2"}))
+	require.False(t, p.IsLocal(&HostInfo{dataCenter: "dc2", rack: "rack1"}))
+	require.True(t, p.IsLocal(&HostInfo{dataCenter: "dc2", rack: "rack2"}))
+}
+
+// Hosts from the same rack are preferred.
+func TestHostPolicy_RackAwareRR_SameRack(t *testing.T) {
+	p := RackAwareRRHostPolicy("dc2", "rack2")
+
+	hosts := []*HostInfo{
+		{hostId: "1", connectAddress: net.ParseIP("10.0.0.1"), dataCenter: "dc1", rack: "rack1"},
+		{hostId: "2", connectAddress: net.ParseIP("10.0.0.2"), dataCenter: "dc1", rack: "rack2"},
+		{hostId: "3", connectAddress: net.ParseIP("10.0.0.3"), dataCenter: "dc1", rack: "rack2"},
+		{hostId: "4", connectAddress: net.ParseIP("10.0.0.4"), dataCenter: "dc2", rack: "rack1"},
+		{hostId: "5", connectAddress: net.ParseIP("10.0.0.5"), dataCenter: "dc2", rack: "rack2"},
+		{hostId: "6", connectAddress: net.ParseIP("10.0.0.6"), dataCenter: "dc2", rack: "rack2"},
+		{hostId: "7", connectAddress: net.ParseIP("10.0.0.7"), dataCenter: "dc3", rack: "rack1"},
+		{hostId: "8", connectAddress: net.ParseIP("10.0.0.8"), dataCenter: "dc3", rack: "rack2"},
+		{hostId: "9", connectAddress: net.ParseIP("10.0.0.9"), dataCenter: "dc3", rack: "rack2"},
+	}
+	for _, host := range hosts {
+		p.AddHost(host)
+	}
+
+	got := make(map[string]bool, len(hosts))
+	order := make([]*HostInfo, 0, len(hosts))
+
+	// Make sure the same host is not returned twice.
+	it := p.Pick(nil)
+	for h := it(); h != nil; h = it() {
+		id := h.Info().hostId
+		require.Falsef(t, got[id], "got duplicate host %s", id)
+		got[id] = true
+		order = append(order, h.Info())
+	}
+	// Make sure all available hosts have been offered.
+	require.Len(t, order, len(hosts))
+
+	require.Equal(t, "dc2", order[0].DataCenter())
+	require.Equal(t, "dc2", order[1].DataCenter())
+	require.Equal(t, "dc2", order[2].DataCenter())
+	require.NotEqual(t, "dc2", order[3].DataCenter())
+
+	require.Equal(t, "rack2", order[0].Rack())
+	require.Equal(t, "rack2", order[1].Rack())
+	require.NotEqual(t, "rack2", order[2].Rack())
+}
+
+// If there are no hosts in the same rack, then hosts from the same datacenter
+// are preferred.
+func TestHostPolicy_RackAwareRR_SameDatacenter(t *testing.T) {
+	p := RackAwareRRHostPolicy("dc2", "rack2")
+
+	hosts := []*HostInfo{
+		{hostId: "1", connectAddress: net.ParseIP("10.0.0.1"), dataCenter: "dc1", rack: "rack1"},
+		{hostId: "2", connectAddress: net.ParseIP("10.0.0.2"), dataCenter: "dc1", rack: "rack2"},
+		{hostId: "3", connectAddress: net.ParseIP("10.0.0.3"), dataCenter: "dc1", rack: "rack2"},
+		{hostId: "4", connectAddress: net.ParseIP("10.0.0.4"), dataCenter: "dc2", rack: "rack1"},
+		{hostId: "5", connectAddress: net.ParseIP("10.0.0.5"), dataCenter: "dc2", rack: "rack4"},
+		{hostId: "6", connectAddress: net.ParseIP("10.0.0.6"), dataCenter: "dc2", rack: "rack4"},
+		{hostId: "7", connectAddress: net.ParseIP("10.0.0.7"), dataCenter: "dc3", rack: "rack1"},
+		{hostId: "8", connectAddress: net.ParseIP("10.0.0.8"), dataCenter: "dc3", rack: "rack2"},
+		{hostId: "9", connectAddress: net.ParseIP("10.0.0.9"), dataCenter: "dc3", rack: "rack2"},
+	}
+	for _, host := range hosts {
+		p.AddHost(host)
+	}
+
+	got := make(map[string]bool, len(hosts))
+	order := make([]*HostInfo, 0, len(hosts))
+
+	// Make sure the same host is not returned twice.
+	it := p.Pick(nil)
+	for h := it(); h != nil; h = it() {
+		id := h.Info().hostId
+		require.Falsef(t, got[id], "got duplicate host %s", id)
+		got[id] = true
+		order = append(order, h.Info())
+	}
+	// Make sure all available hosts have been offered.
+	require.Len(t, order, len(hosts))
+
+	require.Equal(t, "dc2", order[0].DataCenter())
+	require.Equal(t, "dc2", order[1].DataCenter())
+	require.Equal(t, "dc2", order[2].DataCenter())
+	require.NotEqual(t, "dc2", order[3].DataCenter())
+
+	require.NotEqual(t, "rack2", order[0].Rack())
+}
+
+// If there are no hosts from the same datacenter, then the exact order of
+// other hosts does not matter.
+func TestHostPolicy_RackAwareRR_Remote(t *testing.T) {
+	p := RackAwareRRHostPolicy("dc2", "rack2")
+
+	hosts := []*HostInfo{
+		{hostId: "1", connectAddress: net.ParseIP("10.0.0.1"), dataCenter: "dc1", rack: "rack1"},
+		{hostId: "2", connectAddress: net.ParseIP("10.0.0.2"), dataCenter: "dc1", rack: "rack2"},
+		{hostId: "3", connectAddress: net.ParseIP("10.0.0.3"), dataCenter: "dc1", rack: "rack2"},
+		{hostId: "4", connectAddress: net.ParseIP("10.0.0.4"), dataCenter: "dc4", rack: "rack1"},
+		{hostId: "5", connectAddress: net.ParseIP("10.0.0.5"), dataCenter: "dc4", rack: "rack2"},
+		{hostId: "6", connectAddress: net.ParseIP("10.0.0.6"), dataCenter: "dc4", rack: "rack2"},
+		{hostId: "7", connectAddress: net.ParseIP("10.0.0.7"), dataCenter: "dc3", rack: "rack1"},
+		{hostId: "8", connectAddress: net.ParseIP("10.0.0.8"), dataCenter: "dc3", rack: "rack2"},
+		{hostId: "9", connectAddress: net.ParseIP("10.0.0.9"), dataCenter: "dc3", rack: "rack2"},
+	}
+	for _, host := range hosts {
+		p.AddHost(host)
+	}
+
+	got := make(map[string]bool, len(hosts))
+	order := make([]*HostInfo, 0, len(hosts))
+
+	// Make sure the same host is not returned twice.
+	it := p.Pick(nil)
+	for h := it(); h != nil; h = it() {
+		id := h.Info().hostId
+		require.Falsef(t, got[id], "got duplicate host %s", id)
+		got[id] = true
+		order = append(order, h.Info())
+	}
+	// Make sure all available hosts have been offered.
+	require.Len(t, order, len(hosts))
+
+	require.NotEqual(t, "dc2", order[0].DataCenter())
+}
diff --git a/session.go b/session.go
index 8165abd1f..1080f5696 100644
--- a/session.go
+++ b/session.go
@@ -875,7 +875,7 @@ type Query struct {
 
 	disableAutoPage bool
 
-	// getKeyspace is field so that it can be overriden in tests
+	// getKeyspace is a field so that it can be overridden in tests
 	getKeyspace func() string
 
 	// used by control conn queries to prevent triggering a write to systems
@@ -896,10 +896,10 @@ func (q *Query) defaultsFromSession() {
 	q.serialCons = s.cfg.SerialConsistency
 	q.defaultTimestamp = s.cfg.DefaultTimestamp
 	q.idempotent = s.cfg.DefaultIdempotence
-	q.metrics = &queryMetrics{m: make(map[string]*hostMetrics)}
-
-	q.spec = &NonSpeculativeExecution{}
 	s.mu.RUnlock()
+
+	q.metrics = &queryMetrics{m: make(map[string]*hostMetrics)}
+	q.spec = &nonSpeculativeExecution
 }
 
 // Statement returns the statement that was used to generate this query.
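
Usage sketch (not part of the patch): a minimal example of how the new RackAwareRRHostPolicy could be wired into a cluster configuration, assuming the upstream import path github.com/gocql/gocql and the existing ClusterConfig.PoolConfig.HostSelectionPolicy hook; the contact point, datacenter, and rack names are placeholders.

package main

import (
	"log"

	"github.com/gocql/gocql"
)

func main() {
	// Prefer hosts in dc1/rack1, then the rest of dc1, then all remaining hosts.
	cluster := gocql.NewCluster("10.0.0.1")
	cluster.PoolConfig.HostSelectionPolicy = gocql.RackAwareRRHostPolicy("dc1", "rack1")

	session, err := cluster.CreateSession()
	if err != nil {
		log.Fatal(err)
	}
	defer session.Close()
}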