Check HostFilter for control connection
Fixes #1608
jameshartig committed Apr 7, 2022
1 parent 0eacd31 commit 56d43d5
Showing 3 changed files with 82 additions and 73 deletions.
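For context, this is roughly how a host filter is installed on a cluster config (a minimal sketch; the addresses are placeholders, and the APIs used here — NewCluster, HostFilterFunc, ConnectAddress — all appear in the new test below). Before this commit, per the commit title, the filter was not checked for the control connection itself.

package main

import "github.com/gocql/gocql"

func main() {
	cluster := gocql.NewCluster("192.168.1.1", "192.168.1.2")
	// Reject one node. After this commit the filter is honored for the
	// control connection as well, not just for pooled connections.
	cluster.HostFilter = gocql.HostFilterFunc(func(host *gocql.HostInfo) bool {
		return host.ConnectAddress().String() != "192.168.1.2"
	})
	session, err := cluster.CreateSession()
	if err != nil {
		panic(err)
	}
	defer session.Close()
}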
control.go: 55 additions & 68 deletions
@@ -167,28 +167,6 @@ func shuffleHosts(hosts []*HostInfo) []*HostInfo {
 	return shuffled
 }
 
-func (c *controlConn) shuffleDial(endpoints []*HostInfo) (*Conn, error) {
-	// shuffle endpoints so not all drivers will connect to the same initial
-	// node.
-	shuffled := shuffleHosts(endpoints)
-
-	cfg := *c.session.connCfg
-	cfg.disableCoalesce = true
-
-	var err error
-	for _, host := range shuffled {
-		var conn *Conn
-		conn, err = c.session.dial(c.session.ctx, host, &cfg, c)
-		if err == nil {
-			return conn, nil
-		}
-
-		c.session.logger.Printf("gocql: unable to dial control conn %v:%v: %v\n", host.ConnectAddress(), host.Port(), err)
-	}
-
-	return nil, err
-}
-
 // this is going to be version dependant and a nightmare to maintain :(
 var protocolSupportRe = regexp.MustCompile(`the lowest supported version is \d+ and the greatest is (\d+)$`)
 
@@ -249,14 +227,31 @@ func (c *controlConn) connect(hosts []*HostInfo) error {
 		return errors.New("control: no endpoints specified")
 	}
 
-	conn, err := c.shuffleDial(hosts)
-	if err != nil {
-		return fmt.Errorf("control: unable to connect to initial hosts: %v", err)
-	}
+	// shuffle endpoints so not all drivers will connect to the same initial
+	// node.
+	hosts = shuffleHosts(hosts)
 
-	if err := c.setupConn(conn); err != nil {
-		conn.Close()
-		return fmt.Errorf("control: unable to setup connection: %v", err)
+	cfg := *c.session.connCfg
+	cfg.disableCoalesce = true
+
+	var conn *Conn
+	var err error
+	for _, host := range hosts {
+		conn, err = c.session.dial(c.session.ctx, host, &cfg, c)
+		if err != nil {
+			c.session.logger.Printf("gocql: unable to dial control conn %v:%v: %v\n", host.ConnectAddress(), host.Port(), err)
+			continue
+		}
+		err = c.setupConn(conn)
+		if err == nil {
+			break
+		}
+		c.session.logger.Printf("gocql: unable setup control conn %v:%v: %v\n", host.ConnectAddress(), host.Port(), err)
+		conn.Close()
+		conn = nil
+	}
+	if conn == nil {
+		return fmt.Errorf("unable to connect to initial hosts: %v", err)
 	}
 
 	// we could fetch the initial ring here and update initial host data. So that
@@ -273,15 +268,17 @@ type connHost struct {
 }
 
 func (c *controlConn) setupConn(conn *Conn) error {
-	if err := c.registerEvents(conn); err != nil {
-		conn.Close()
-		return err
-	}
-
-	// TODO(zariel): do we need to fetch host info everytime
-	// the control conn connects? Surely we have it cached?
-	host, err := conn.localHostInfo(context.TODO())
-	if err != nil {
+	// we need up-to-date host info for the filterHost call below
+	host, err := conn.localHostInfo(context.TODO())
+	if err != nil {
+		return err
+	}
+
+	if c.session.cfg.filterHost(host) {
+		return fmt.Errorf("host was filtered: %v", host.ConnectAddress())
+	}
+
+	if err := c.registerEvents(conn); err != nil {
 		return err
 	}
 
@@ -346,50 +343,40 @@ func (c *controlConn) reconnect(refreshring bool) {
 		return
 	}
 	defer atomic.StoreInt32(&c.reconnecting, 0)
+	// TODO: simplify this function, use session.ring to get hosts instead of the
+	// connection pool
 
-	var host *HostInfo
+	hosts := c.session.ring.allHosts()
+	hosts = shuffleHosts(hosts)
+
+	// keep the old behavior of connecting to the old host first by moving it to
+	// the front of the slice
 	ch := c.getConn()
 	if ch != nil {
-		host = ch.host
-		ch.conn.Close()
-	}
-
-	var newConn *Conn
-	if host != nil {
-		// try to connect to the old host
-		conn, err := c.session.connect(c.session.ctx, host, c)
-		if err != nil {
-			// host is dead
-			// TODO: this is replicated in a few places
-			if c.session.cfg.ConvictionPolicy.AddFailure(err, host) {
-				c.session.handleNodeDown(host.ConnectAddress(), host.Port())
-			}
-		} else {
-			newConn = conn
-		}
-	}
-
-	// TODO: should have our own round-robin for hosts so that we can try each
-	// in succession and guarantee that we get a different host each time.
-	if newConn == nil {
-		host := c.session.ring.rrHost()
-		if host == nil {
-			c.connect(c.session.ring.endpoints)
-			return
-		}
-
-		var err error
-		newConn, err = c.session.connect(c.session.ctx, host, c)
-		if err != nil {
-			// TODO: add log handler for things like this
-			return
-		}
-	}
-
-	if err := c.setupConn(newConn); err != nil {
-		newConn.Close()
+		for i := range hosts {
+			if hosts[i].Equal(ch.host) {
+				hosts[0], hosts[i] = hosts[i], hosts[0]
+				break
+			}
+		}
+		ch.conn.Close()
+	}
+
+	var conn *Conn
+	var err error
+	for _, host := range hosts {
+		conn, err = c.session.connect(c.session.ctx, host, c)
+		if err != nil {
+			c.session.logger.Printf("gocql: unable to dial control conn %v:%v: %v\n", host.ConnectAddress(), host.Port(), err)
+			continue
+		}
+		err = c.setupConn(conn)
+		if err == nil {
+			break
+		}
+		c.session.logger.Printf("gocql: unable setup control conn %v:%v: %v\n", host.ConnectAddress(), host.Port(), err)
+		conn.Close()
+		conn = nil
+	}
+
+	if conn == nil {
+		c.session.logger.Printf("gocql: control unable to register events: %v\n", err)
 		return
 	}
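Note on the fix: setupConn now fetches up-to-date host info first and rejects the connection when c.session.cfg.filterHost(host) reports the host as filtered, so both retry loops above simply close the connection and move on to the next candidate. filterHost itself is not part of this diff; assuming it is the usual ClusterConfig helper wrapping the user-supplied HostFilter, its shape would be roughly:

// Sketch of the helper assumed by setupConn above (not in this diff):
// true means "filter this host out", i.e. a filter is set and it does
// not accept the host.
func (cfg *ClusterConfig) filterHost(host *HostInfo) bool {
	return !(cfg.HostFilter == nil || cfg.HostFilter.Accept(host))
}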
integration_test.go: 27 additions & 3 deletions
@@ -77,8 +77,8 @@ func TestRingDiscovery(t *testing.T) {
 	}
 }
 
-// TestHostFilter ensures that host filtering works even when we discover hosts
-func TestHostFilter(t *testing.T) {
+// TestHostFilterDiscovery ensures that host filtering works even when we discover hosts
+func TestHostFilterDiscovery(t *testing.T) {
 	clusterHosts := getClusterHosts()
 	if len(clusterHosts) < 2 {
 		t.Skip("skipping because we don't have 2 or more hosts")
@@ -98,7 +98,31 @@ func TestHostFilter(t *testing.T) {
 	session := createSessionFromCluster(cluster, t)
 	defer session.Close()
 
-	assertEqual(t, "len(rr.hosts.get()) != 0", len(clusterHosts)-1, len(rr.hosts.get()))
+	assertEqual(t, "len(clusterHosts)-1 != len(rr.hosts.get())", len(clusterHosts)-1, len(rr.hosts.get()))
+}
+
+// TestHostFilterInitial ensures that host filtering works for the initial
+// connection including the control connection
+func TestHostFilterInitial(t *testing.T) {
+	clusterHosts := getClusterHosts()
+	if len(clusterHosts) < 2 {
+		t.Skip("skipping because we don't have 2 or more hosts")
+	}
+	cluster := createCluster()
+	rr := RoundRobinHostPolicy().(*roundRobinHostPolicy)
+	cluster.PoolConfig.HostSelectionPolicy = rr
+	// we'll filter out the second host
+	filtered := clusterHosts[1]
+	cluster.HostFilter = HostFilterFunc(func(host *HostInfo) bool {
+		if host.ConnectAddress().String() == filtered {
+			return false
+		}
+		return true
+	})
+	session := createSessionFromCluster(cluster, t)
+	defer session.Close()
+
+	assertEqual(t, "len(clusterHosts)-1 != len(rr.hosts.get())", len(clusterHosts)-1, len(rr.hosts.get()))
 }
 
 func TestWriteFailure(t *testing.T) {
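TestHostFilterInitial mirrors TestHostFilterDiscovery but exercises the filter at session creation, when the control connection is first dialed. The test rolls its own HostFilterFunc; if memory serves, gocql also ships ready-made filters in filters.go that should behave the same way under this fix (sketch only, placeholder DC name and addresses, reusing the cluster variable from the example near the top):

// Keep only hosts in the given data center.
cluster.HostFilter = gocql.DataCentreHostFilter("dc1")

// Or accept an explicit allow-list of addresses.
cluster.HostFilter = gocql.WhiteListHostFilter("192.168.1.1", "192.168.1.3")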
ring.go: 0 additions & 2 deletions
@@ -24,8 +24,6 @@ type ring struct {
 }
 
 func (r *ring) rrHost() *HostInfo {
-	// TODO: should we filter hosts that get used here? These hosts will be used
-	// for the control connection, should we also provide an iterator?
 	r.mu.RLock()
 	defer r.mu.RUnlock()
 	if len(r.hostList) == 0 {
