From ab0c13f796c5668b57051ad2685210e0ff5166be Mon Sep 17 00:00:00 2001 From: Preston Van Loon Date: Sun, 16 Feb 2020 10:18:48 -0700 Subject: [PATCH] Check attestation bitlist length in aggregation to prevent panic (#4876) * Check attestation bitlist length in aggregation to prevent panic * Add case for overlap too Co-authored-by: prylabs-bulldozer[bot] <58059840+prylabs-bulldozer[bot]@users.noreply.github.com> --- beacon-chain/core/helpers/attestation.go | 14 ++++++++- beacon-chain/core/helpers/attestation_test.go | 29 +++++++++++++++++++ 2 files changed, 42 insertions(+), 1 deletion(-) diff --git a/beacon-chain/core/helpers/attestation.go b/beacon-chain/core/helpers/attestation.go index a0677568c27..9e5448b400c 100644 --- a/beacon-chain/core/helpers/attestation.go +++ b/beacon-chain/core/helpers/attestation.go @@ -16,6 +16,10 @@ var ( // ErrAttestationAggregationBitsOverlap is returned when two attestations aggregation // bits overlap with each other. ErrAttestationAggregationBitsOverlap = errors.New("overlapping aggregation bits") + + // ErrAttestationAggregationBitsDifferentLen is returned when two attestation aggregation bits + // have different lengths. + ErrAttestationAggregationBitsDifferentLen = errors.New("different bitlist lengths") ) // AggregateAttestations such that the minimal number of attestations are returned. @@ -32,7 +36,7 @@ func AggregateAttestations(atts []*ethpb.Attestation) ([]*ethpb.Attestation, err } for j := i + 1; j < len(atts); j++ { b := atts[j] - if !a.AggregationBits.Overlaps(b.AggregationBits) { + if a.AggregationBits.Len() == b.AggregationBits.Len() && !a.AggregationBits.Overlaps(b.AggregationBits) { var err error a, err = AggregateAttestation(a, b) if err != nil { @@ -50,6 +54,11 @@ func AggregateAttestations(atts []*ethpb.Attestation) ([]*ethpb.Attestation, err for i, a := range atts { for j := i + 1; j < len(atts); j++ { b := atts[j] + + if a.AggregationBits.Len() != b.AggregationBits.Len() { + continue + } + if a.AggregationBits.Contains(b.AggregationBits) { // If b is fully contained in a, then b can be removed. atts = append(atts[:j], atts[j+1:]...) @@ -74,6 +83,9 @@ var signatureFromBytes = bls.SignatureFromBytes // AggregateAttestation aggregates attestations a1 and a2 together. func AggregateAttestation(a1 *ethpb.Attestation, a2 *ethpb.Attestation) (*ethpb.Attestation, error) { + if a1.AggregationBits.Len() != a2.AggregationBits.Len() { + return nil, ErrAttestationAggregationBitsDifferentLen + } if a1.AggregationBits.Overlaps(a2.AggregationBits) { return nil, ErrAttestationAggregationBitsOverlap } diff --git a/beacon-chain/core/helpers/attestation_test.go b/beacon-chain/core/helpers/attestation_test.go index 4bd3c5d1829..298d1b83aa6 100644 --- a/beacon-chain/core/helpers/attestation_test.go +++ b/beacon-chain/core/helpers/attestation_test.go @@ -62,6 +62,22 @@ func TestAggregateAttestation_OverlapFails(t *testing.T) { } } +func TestAggregateAttestation_DiffLengthFails(t *testing.T) { + tests := []struct { + a1 *ethpb.Attestation + a2 *ethpb.Attestation + }{ + {a1: ðpb.Attestation{AggregationBits: bitfield.Bitlist{0x0F}}, + a2: ðpb.Attestation{AggregationBits: bitfield.Bitlist{0x11}}}, + } + for _, tt := range tests { + _, err := helpers.AggregateAttestation(tt.a1, tt.a2) + if err != helpers.ErrAttestationAggregationBitsDifferentLen { + t.Error("Did not receive wanted error") + } + } +} + func bitlistWithAllBitsSet(length uint64) bitfield.Bitlist { b := bitfield.NewBitlist(length) for i := uint64(0); i < length; i++ { @@ -167,6 +183,19 @@ func TestAggregateAttestations(t *testing.T) { {0b00000011, 0b1}, }, }, + { + name: "attestations with different bitlist lengths", + inputs: []bitfield.Bitlist{ + {0b00000011, 0b10}, + {0b00000111, 0b100}, + {0b00000100, 0b1}, + }, + want: []bitfield.Bitlist{ + {0b00000011, 0b10}, + {0b00000111, 0b100}, + {0b00000100, 0b1}, + }, + }, } var makeAttestationsFromBitlists = func(bl []bitfield.Bitlist) []*ethpb.Attestation {