diff --git a/beacon-chain/operations/attestations/prepare_forkchoice.go b/beacon-chain/operations/attestations/prepare_forkchoice.go index bf54123bfd92..70043ccf29df 100644 --- a/beacon-chain/operations/attestations/prepare_forkchoice.go +++ b/beacon-chain/operations/attestations/prepare_forkchoice.go @@ -101,17 +101,23 @@ func (s *Service) seen(att *ethpb.Attestation) (bool, error) { if err != nil { return false, err } + incomingBits := att.AggregationBits savedBits, ok := s.forkChoiceProcessedRoots.Get(string(attRoot[:])) if ok { savedBitlist, ok := savedBits.(bitfield.Bitlist) if !ok { return false, errors.New("not a bit field") } - if savedBitlist.Len() == att.AggregationBits.Len() && savedBitlist.Overlaps(att.AggregationBits) { - return true, nil + if savedBitlist.Len() == incomingBits.Len() { + // Returns true if the node has seen all the bits in the new bit field of the incoming attestation. + if savedBitlist.Contains(incomingBits) { + return true, nil + } + // Update the bit fields by Or'ing them with the new ones. + incomingBits = incomingBits.Or(savedBitlist) } } - s.forkChoiceProcessedRoots.Set(string(attRoot[:]), att.AggregationBits, 1 /*cost*/) + s.forkChoiceProcessedRoots.Set(string(attRoot[:]), incomingBits, 1 /*cost*/) return false, nil } diff --git a/beacon-chain/operations/attestations/prepare_forkchoice_test.go b/beacon-chain/operations/attestations/prepare_forkchoice_test.go index 5623f3fa496b..a2c02f129a0e 100644 --- a/beacon-chain/operations/attestations/prepare_forkchoice_test.go +++ b/beacon-chain/operations/attestations/prepare_forkchoice_test.go @@ -220,7 +220,7 @@ func TestSeenAttestations_PresentInCache(t *testing.T) { t.Fatal(err) } - att1 := ðpb.Attestation{Data: ðpb.AttestationData{}, Signature: []byte{'A'}, AggregationBits: bitfield.Bitlist{0x03}} + att1 := ðpb.Attestation{Data: ðpb.AttestationData{}, Signature: []byte{'A'}, AggregationBits: bitfield.Bitlist{0x13} /* 0b00010011 */} got, err := s.seen(att1) if err != nil { t.Fatal(err) @@ -231,12 +231,86 @@ func TestSeenAttestations_PresentInCache(t *testing.T) { time.Sleep(100 * time.Millisecond) - att2 := ðpb.Attestation{Data: ðpb.AttestationData{}, Signature: []byte{'A'}, AggregationBits: bitfield.Bitlist{0x03}} + att2 := ðpb.Attestation{Data: ðpb.AttestationData{}, Signature: []byte{'A'}, AggregationBits: bitfield.Bitlist{0x17} /* 0b00010111 */} got, err = s.seen(att2) if err != nil { t.Fatal(err) } + if got { + t.Error("Wanted false, got true") + } + + time.Sleep(100 * time.Millisecond) + + att3 := ðpb.Attestation{Data: ðpb.AttestationData{}, Signature: []byte{'A'}, AggregationBits: bitfield.Bitlist{0x17} /* 0b00010111 */} + got, err = s.seen(att3) + if err != nil { + t.Fatal(err) + } if !got { t.Error("Wanted true, got false") } } + +func TestService_seen(t *testing.T) { + // Attestation are checked in order of this list. + tests := []struct { + att *ethpb.Attestation + want bool + }{ + { + att: ðpb.Attestation{ + AggregationBits: bitfield.Bitlist{0b11011}, + Data: ðpb.AttestationData{Slot: 1}, + }, + want: false, + }, + { + att: ðpb.Attestation{ + AggregationBits: bitfield.Bitlist{0b11011}, + Data: ðpb.AttestationData{Slot: 1}, + }, + want: true, // Exact same attestation should return true + }, + { + att: ðpb.Attestation{ + AggregationBits: bitfield.Bitlist{0b10101}, + Data: ðpb.AttestationData{Slot: 1}, + }, + want: false, // Haven't seen the bit at index 2 yet. + }, + { + att: ðpb.Attestation{ + AggregationBits: bitfield.Bitlist{0b11111}, + Data: ðpb.AttestationData{Slot: 1}, + }, + want: true, // We've full committee at this point. + }, + { + att: ðpb.Attestation{ + AggregationBits: bitfield.Bitlist{0b11111}, + Data: ðpb.AttestationData{Slot: 2}, + }, + want: false, // Different root is different bitlist. + }, + { + att: ðpb.Attestation{ + AggregationBits: bitfield.Bitlist{0b11111001}, + Data: ðpb.AttestationData{Slot: 1}, + }, + want: false, // Sanity test that an attestation of different lengths does not panic. + }, + } + + s, err := NewService(context.Background(), &Config{Pool: NewPool()}) + if err != nil { + t.Fatal(err) + } + + for i, tt := range tests { + if got, _ := s.seen(tt.att); got != tt.want { + t.Errorf("Test %d failed. Got=%v want=%v", i, got, tt.want) + } + time.Sleep(10) // Sleep briefly for cache to routine to buffer. + } +}