Skip to content

Commit

Permalink
p2p/discover: fix update logic in handleAddNode (#29836)
Browse files Browse the repository at this point in the history
It seems the semantic differences between addFoundNode and addInboundNode were lost in
#29572. My understanding is addFoundNode is for a node you have not contacted directly
(and are unsure if is available) whereas addInboundNode is for adding nodes that have
contacted the local node and we can verify they are active.

handleAddNode seems to be the consolidation of those two methods, yet it bumps the node in
the bucket (updating it's IP addr) even if the node was not an inbound. This PR fixes
this. It wasn't originally caught in tests like TestTable_addSeenNode because the
manipulation of the node object actually modified the node value used by the test.

New logic is added to reject non-inbound updates unless the sequence number of the
(signed) ENR increases. Inbound updates, which are published by the updated node itself,
are always accepted. If an inbound update changes the endpoint, the node will be
revalidated on an expedited schedule.

Co-authored-by: Felix Lange <fjl@twurst.com>
  • Loading branch information
lightclient and fjl committed May 28, 2024
1 parent 171430c commit cc22e0c
Show file tree
Hide file tree
Showing 4 changed files with 187 additions and 59 deletions.
46 changes: 33 additions & 13 deletions p2p/discover/table.go
Original file line number Diff line number Diff line change
Expand Up @@ -513,8 +513,9 @@ func (tab *Table) handleAddNode(req addNodeOp) bool {
}

b := tab.bucket(req.node.ID())
if tab.bumpInBucket(b, req.node.Node) {
// Already in bucket, update record.
n, _ := tab.bumpInBucket(b, req.node.Node, req.isInbound)
if n != nil {
// Already in bucket.
return false
}
if len(b.entries) >= bucketSize {
Expand Down Expand Up @@ -605,26 +606,45 @@ func (tab *Table) deleteInBucket(b *bucket, id enode.ID) *node {
return rep
}

// bumpInBucket updates the node record of n in the bucket.
func (tab *Table) bumpInBucket(b *bucket, newRecord *enode.Node) bool {
// bumpInBucket updates a node record if it exists in the bucket.
// The second return value reports whether the node's endpoint (IP/port) was updated.
func (tab *Table) bumpInBucket(b *bucket, newRecord *enode.Node, isInbound bool) (n *node, endpointChanged bool) {
i := slices.IndexFunc(b.entries, func(elem *node) bool {
return elem.ID() == newRecord.ID()
})
if i == -1 {
return false
return nil, false // not in bucket
}
n = b.entries[i]

// For inbound updates (from the node itself) we accept any change, even if it sets
// back the sequence number. For found nodes (!isInbound), seq has to advance. Note
// this check also ensures found discv4 nodes (which always have seq=0) can't be
// updated.
if newRecord.Seq() <= n.Seq() && !isInbound {
return n, false
}

if !newRecord.IP().Equal(b.entries[i].IP()) {
// Endpoint has changed, ensure that the new IP fits into table limits.
tab.removeIP(b, b.entries[i].IP())
// Check endpoint update against IP limits.
ipchanged := newRecord.IPAddr() != n.IPAddr()
portchanged := newRecord.UDP() != n.UDP()
if ipchanged {
tab.removeIP(b, n.IP())
if !tab.addIP(b, newRecord.IP()) {
// It doesn't, put the previous one back.
tab.addIP(b, b.entries[i].IP())
return false
// It doesn't fit with the limit, put the previous record back.
tab.addIP(b, n.IP())
return n, false
}
}
b.entries[i].Node = newRecord
return true

// Apply update.
n.Node = newRecord
if ipchanged || portchanged {
// Ensure node is revalidated quickly for endpoint changes.
tab.revalidation.nodeEndpointChanged(tab, n)
return n, true
}
return n, false
}

func (tab *Table) handleTrackRequest(op trackRequestOp) {
Expand Down
25 changes: 13 additions & 12 deletions p2p/discover/table_reval.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ import (

const never = mclock.AbsTime(math.MaxInt64)

const slowRevalidationFactor = 3

// tableRevalidation implements the node revalidation process.
// It tracks all nodes contained in Table, and schedules sending PING to them.
type tableRevalidation struct {
Expand All @@ -48,7 +50,7 @@ func (tr *tableRevalidation) init(cfg *Config) {
tr.fast.interval = cfg.PingInterval
tr.fast.name = "fast"
tr.slow.nextTime = never
tr.slow.interval = cfg.PingInterval * 3
tr.slow.interval = cfg.PingInterval * slowRevalidationFactor
tr.slow.name = "slow"
}

Expand All @@ -65,6 +67,12 @@ func (tr *tableRevalidation) nodeRemoved(n *node) {
n.revalList.remove(n)
}

// nodeEndpointChanged is called when a change in IP or port is detected.
func (tr *tableRevalidation) nodeEndpointChanged(tab *Table, n *node) {
n.isValidatedLive = false
tr.moveToList(&tr.fast, n, tab.cfg.Clock.Now(), &tab.rand)
}

// run performs node revalidation.
// It returns the next time it should be invoked, which is used in the Table main loop
// to schedule a timer. However, run can be called at any time.
Expand Down Expand Up @@ -146,11 +154,11 @@ func (tr *tableRevalidation) handleResponse(tab *Table, resp revalidationRespons
defer tab.mutex.Unlock()

if !resp.didRespond {
// Revalidation failed.
n.livenessChecks /= 3
if n.livenessChecks <= 0 {
tab.deleteInBucket(b, n.ID())
} else {
tab.log.Debug("Node revalidation failed", "b", b.index, "id", n.ID(), "checks", n.livenessChecks, "q", n.revalList.name)
tr.moveToList(&tr.fast, n, now, &tab.rand)
}
return
Expand All @@ -159,22 +167,15 @@ func (tr *tableRevalidation) handleResponse(tab *Table, resp revalidationRespons
// The node responded.
n.livenessChecks++
n.isValidatedLive = true
tab.log.Debug("Node revalidated", "b", b.index, "id", n.ID(), "checks", n.livenessChecks, "q", n.revalList.name)
var endpointChanged bool
if resp.newRecord != nil {
endpointChanged = tab.bumpInBucket(b, resp.newRecord)
if endpointChanged {
// If the node changed its advertised endpoint, the updated ENR is not served
// until it has been revalidated.
n.isValidatedLive = false
}
_, endpointChanged = tab.bumpInBucket(b, resp.newRecord, false)
}
tab.log.Debug("Revalidated node", "b", b.index, "id", n.ID(), "checks", n.livenessChecks, "q", n.revalList)

// Move node over to slow queue after first validation.
// Node moves to slow list if it passed and hasn't changed.
if !endpointChanged {
tr.moveToList(&tr.slow, n, now, &tab.rand)
} else {
tr.moveToList(&tr.fast, n, now, &tab.rand)
}
}

Expand Down
53 changes: 51 additions & 2 deletions p2p/discover/table_reval_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,13 @@ import (
"time"

"github.com/ethereum/go-ethereum/common/mclock"
"github.com/ethereum/go-ethereum/p2p/enode"
"github.com/ethereum/go-ethereum/p2p/enr"
)

// This test checks that revalidation can handle a node disappearing while
// a request is active.
func TestRevalidationNodeRemoved(t *testing.T) {
func TestRevalidation_nodeRemoved(t *testing.T) {
var (
clock mclock.Simulated
transport = newPingRecorder()
Expand All @@ -35,7 +37,7 @@ func TestRevalidationNodeRemoved(t *testing.T) {
)
defer db.Close()

// Fill a bucket.
// Add a node to the table.
node := nodeAtDistance(tab.self().ID(), 255, net.IP{77, 88, 99, 1})
tab.handleAddNode(addNodeOp{node: node})

Expand Down Expand Up @@ -68,3 +70,50 @@ func TestRevalidationNodeRemoved(t *testing.T) {
t.Fatal("removed node contained in revalidation list")
}
}

// This test checks that nodes with an updated endpoint remain in the fast revalidation list.
func TestRevalidation_endpointUpdate(t *testing.T) {
var (
clock mclock.Simulated
transport = newPingRecorder()
tab, db = newInactiveTestTable(transport, Config{Clock: &clock})
tr = &tab.revalidation
)
defer db.Close()

// Add node to table.
node := nodeAtDistance(tab.self().ID(), 255, net.IP{77, 88, 99, 1})
tab.handleAddNode(addNodeOp{node: node})

// Update the record in transport, including endpoint update.
record := node.Record()
record.Set(enr.IP{100, 100, 100, 100})
record.Set(enr.UDP(9999))
nodev2 := enode.SignNull(record, node.ID())
transport.updateRecord(nodev2)

// Start a revalidation request. Schedule once to get the next start time,
// then advance the clock to that point and schedule again to start.
next := tr.run(tab, clock.Now())
clock.Run(time.Duration(next + 1))
tr.run(tab, clock.Now())
if len(tr.activeReq) != 1 {
t.Fatal("revalidation request did not start:", tr.activeReq)
}

// Now finish the revalidation request.
var resp revalidationResponse
select {
case resp = <-tab.revalResponseCh:
case <-time.After(1 * time.Second):
t.Fatal("timed out waiting for revalidation")
}
tr.handleResponse(tab, resp)

if !tr.fast.contains(node.ID()) {
t.Fatal("node not contained in fast revalidation list")
}
if node.isValidatedLive {
t.Fatal("node is marked live after endpoint change")
}
}
122 changes: 90 additions & 32 deletions p2p/discover/table_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ func waitForRevalidationPing(t *testing.T, transport *pingRecorder, tab *Table,
simclock := tab.cfg.Clock.(*mclock.Simulated)
maxAttempts := tab.len() * 8
for i := 0; i < maxAttempts; i++ {
simclock.Run(tab.cfg.PingInterval)
simclock.Run(tab.cfg.PingInterval * slowRevalidationFactor)
p := transport.waitPing(2 * time.Second)
if p == nil {
t.Fatal("Table did not send revalidation ping")
Expand Down Expand Up @@ -275,7 +275,7 @@ func (*closeTest) Generate(rand *rand.Rand, size int) reflect.Value {
return reflect.ValueOf(t)
}

func TestTable_addVerifiedNode(t *testing.T) {
func TestTable_addInboundNode(t *testing.T) {
tab, db := newTestTable(newPingRecorder(), Config{})
<-tab.initDone
defer db.Close()
Expand All @@ -286,29 +286,26 @@ func TestTable_addVerifiedNode(t *testing.T) {
n2 := nodeAtDistance(tab.self().ID(), 256, net.IP{88, 77, 66, 2})
tab.addFoundNode(n1)
tab.addFoundNode(n2)
bucket := tab.bucket(n1.ID())
checkBucketContent(t, tab, []*enode.Node{n1.Node, n2.Node})

// Verify bucket content:
bcontent := []*node{n1, n2}
if !reflect.DeepEqual(unwrapNodes(bucket.entries), unwrapNodes(bcontent)) {
t.Fatalf("wrong bucket content: %v", bucket.entries)
}

// Add a changed version of n2.
// Add a changed version of n2. The bucket should be updated.
newrec := n2.Record()
newrec.Set(enr.IP{99, 99, 99, 99})
newn2 := wrapNode(enode.SignNull(newrec, n2.ID()))
tab.addInboundNode(newn2)

// Check that bucket is updated correctly.
newBcontent := []*node{n1, newn2}
if !reflect.DeepEqual(unwrapNodes(bucket.entries), unwrapNodes(newBcontent)) {
t.Fatalf("wrong bucket content after update: %v", bucket.entries)
}
checkIPLimitInvariant(t, tab)
n2v2 := enode.SignNull(newrec, n2.ID())
tab.addInboundNode(wrapNode(n2v2))
checkBucketContent(t, tab, []*enode.Node{n1.Node, n2v2})

// Try updating n2 without sequence number change. The update is accepted
// because it's inbound.
newrec = n2.Record()
newrec.Set(enr.IP{100, 100, 100, 100})
newrec.SetSeq(n2.Seq())
n2v3 := enode.SignNull(newrec, n2.ID())
tab.addInboundNode(wrapNode(n2v3))
checkBucketContent(t, tab, []*enode.Node{n1.Node, n2v3})
}

func TestTable_addSeenNode(t *testing.T) {
func TestTable_addFoundNode(t *testing.T) {
tab, db := newTestTable(newPingRecorder(), Config{})
<-tab.initDone
defer db.Close()
Expand All @@ -319,23 +316,84 @@ func TestTable_addSeenNode(t *testing.T) {
n2 := nodeAtDistance(tab.self().ID(), 256, net.IP{88, 77, 66, 2})
tab.addFoundNode(n1)
tab.addFoundNode(n2)
checkBucketContent(t, tab, []*enode.Node{n1.Node, n2.Node})

// Verify bucket content:
bcontent := []*node{n1, n2}
if !reflect.DeepEqual(tab.bucket(n1.ID()).entries, bcontent) {
t.Fatalf("wrong bucket content: %v", tab.bucket(n1.ID()).entries)
}

// Add a changed version of n2.
// Add a changed version of n2. The bucket should be updated.
newrec := n2.Record()
newrec.Set(enr.IP{99, 99, 99, 99})
newn2 := wrapNode(enode.SignNull(newrec, n2.ID()))
tab.addFoundNode(newn2)
n2v2 := enode.SignNull(newrec, n2.ID())
tab.addFoundNode(wrapNode(n2v2))
checkBucketContent(t, tab, []*enode.Node{n1.Node, n2v2})

// Try updating n2 without a sequence number change.
// The update should not be accepted.
newrec = n2.Record()
newrec.Set(enr.IP{100, 100, 100, 100})
newrec.SetSeq(n2.Seq())
n2v3 := enode.SignNull(newrec, n2.ID())
tab.addFoundNode(wrapNode(n2v3))
checkBucketContent(t, tab, []*enode.Node{n1.Node, n2v2})
}

// Check that bucket content is unchanged.
if !reflect.DeepEqual(tab.bucket(n1.ID()).entries, bcontent) {
t.Fatalf("wrong bucket content after update: %v", tab.bucket(n1.ID()).entries)
// This test checks that discv4 nodes can update their own endpoint via PING.
func TestTable_addInboundNodeUpdateV4Accept(t *testing.T) {
tab, db := newTestTable(newPingRecorder(), Config{})
<-tab.initDone
defer db.Close()
defer tab.close()

// Add a v4 node.
key, _ := crypto.HexToECDSA("dd3757a8075e88d0f2b1431e7d3c5b1562e1c0aab9643707e8cbfcc8dae5cfe3")
n1 := enode.NewV4(&key.PublicKey, net.IP{88, 77, 66, 1}, 9000, 9000)
tab.addInboundNode(wrapNode(n1))
checkBucketContent(t, tab, []*enode.Node{n1})

// Add an updated version with changed IP.
// The update will be accepted because it is inbound.
n1v2 := enode.NewV4(&key.PublicKey, net.IP{99, 99, 99, 99}, 9000, 9000)
tab.addInboundNode(wrapNode(n1v2))
checkBucketContent(t, tab, []*enode.Node{n1v2})
}

// This test checks that discv4 node entries will NOT be updated when a
// changed record is found.
func TestTable_addFoundNodeV4UpdateReject(t *testing.T) {
tab, db := newTestTable(newPingRecorder(), Config{})
<-tab.initDone
defer db.Close()
defer tab.close()

// Add a v4 node.
key, _ := crypto.HexToECDSA("dd3757a8075e88d0f2b1431e7d3c5b1562e1c0aab9643707e8cbfcc8dae5cfe3")
n1 := enode.NewV4(&key.PublicKey, net.IP{88, 77, 66, 1}, 9000, 9000)
tab.addFoundNode(wrapNode(n1))
checkBucketContent(t, tab, []*enode.Node{n1})

// Add an updated version with changed IP.
// The update won't be accepted because it isn't inbound.
n1v2 := enode.NewV4(&key.PublicKey, net.IP{99, 99, 99, 99}, 9000, 9000)
tab.addFoundNode(wrapNode(n1v2))
checkBucketContent(t, tab, []*enode.Node{n1})
}

func checkBucketContent(t *testing.T, tab *Table, nodes []*enode.Node) {
t.Helper()

b := tab.bucket(nodes[0].ID())
if reflect.DeepEqual(unwrapNodes(b.entries), nodes) {
return
}
t.Log("wrong bucket content. have nodes:")
for _, n := range b.entries {
t.Logf(" %v (seq=%v, ip=%v)", n.ID(), n.Seq(), n.IP())
}
t.Log("want nodes:")
for _, n := range nodes {
t.Logf(" %v (seq=%v, ip=%v)", n.ID(), n.Seq(), n.IP())
}
t.FailNow()

// Also check IP limits.
checkIPLimitInvariant(t, tab)
}

Expand Down

0 comments on commit cc22e0c

Please sign in to comment.