Skip to content

Commit

Permalink
Optimize Copying of Fields (prysmaticlabs#4811)
Browse files Browse the repository at this point in the history
* add new changes

* memory pool

* add test

* final optimization

* preston's review
  • Loading branch information
nisdas committed Feb 10, 2020
1 parent 18fbdd5 commit 4f654d3
Show file tree
Hide file tree
Showing 8 changed files with 113 additions and 10 deletions.
9 changes: 9 additions & 0 deletions beacon-chain/blockchain/process_attestation_helpers.go
@@ -1,6 +1,7 @@
package blockchain

import (
"bytes"
"context"
"fmt"

Expand Down Expand Up @@ -28,6 +29,14 @@ func (s *Service) getAttPreState(ctx context.Context, c *ethpb.Checkpoint) (*sta
return cachedState, nil
}

headRoot, err := s.HeadRoot(ctx)
if err != nil {
return nil, errors.Wrapf(err, "could not get head root")
}
if bytes.Equal(headRoot, c.Root) {
return s.HeadState(ctx)
}

baseState, err := s.beaconDB.State(ctx, bytesutil.ToBytes32(c.Root))
if err != nil {
return nil, errors.Wrapf(err, "could not get pre state for slot %d", helpers.StartSlot(c.Epoch))
Expand Down
8 changes: 7 additions & 1 deletion beacon-chain/blockchain/process_block_helpers.go
Expand Up @@ -73,7 +73,13 @@ func (s *Service) verifyBlkPreState(ctx context.Context, b *ethpb.BeaconBlock) (
}
return preState.Copy(), nil
}

headRoot, err := s.HeadRoot(ctx)
if err != nil {
return nil, errors.Wrapf(err, "could not get head root")
}
if bytes.Equal(headRoot, b.ParentRoot) {
return s.HeadState(ctx)
}
preState, err := s.beaconDB.State(ctx, bytesutil.ToBytes32(b.ParentRoot))
if err != nil {
return nil, errors.Wrapf(err, "could not get pre state for slot %d", b.Slot)
Expand Down
1 change: 1 addition & 0 deletions beacon-chain/state/BUILD.bazel
Expand Up @@ -19,6 +19,7 @@ go_library(
"//proto/beacon/p2p/v1:go_default_library",
"//shared/bytesutil:go_default_library",
"//shared/hashutil:go_default_library",
"//shared/memorypool:go_default_library",
"//shared/params:go_default_library",
"//shared/stateutil:go_default_library",
"@com_github_gogo_protobuf//proto:go_default_library",
Expand Down
35 changes: 27 additions & 8 deletions beacon-chain/state/setters.go
Expand Up @@ -9,6 +9,7 @@ import (
"github.com/prysmaticlabs/go-bitfield"
pbp2p "github.com/prysmaticlabs/prysm/proto/beacon/p2p/v1"
"github.com/prysmaticlabs/prysm/shared/hashutil"
"github.com/prysmaticlabs/prysm/shared/memorypool"
)

type fieldIndex int
Expand Down Expand Up @@ -39,6 +40,9 @@ const (
previousJustifiedCheckpoint
currentJustifiedCheckpoint
finalizedCheckpoint
// validatorIdxMap is not part of the state, but is used so as to be able to keep
// track of references to it to allow for efficient copy on write.
validatorIdxMap
)

// SetGenesisTime for the beacon state.
Expand Down Expand Up @@ -308,14 +312,21 @@ func (b *BeaconState) UpdateValidatorAtIndex(idx uint64, val *ethpb.Validator) e
// SetValidatorIndexByPubkey updates the validator index mapping maintained internally to
// a given input 48-byte, public key.
func (b *BeaconState) SetValidatorIndexByPubkey(pubKey [48]byte, validatorIdx uint64) {
// Copy on write since this is a shared map.
m := b.validatorIndexMap()
idxMap := b.valIdxMap
b.lock.RLock()
if b.sharedFieldReferences[validatorIdxMap].refs > 1 {
// copy-on-write for idx map
idxMap = b.validatorIndexMap()
b.sharedFieldReferences[validatorIdxMap].refs--
b.sharedFieldReferences[validatorIdxMap] = &reference{refs: 1}
}
b.lock.RUnlock()

b.lock.Lock()
defer b.lock.Unlock()

m[pubKey] = validatorIdx
b.valIdxMap = m
idxMap[pubKey] = validatorIdx
b.valIdxMap = idxMap
}

// SetBalances for the beacon state. This PR updates the entire
Expand Down Expand Up @@ -381,7 +392,9 @@ func (b *BeaconState) UpdateRandaoMixesAtIndex(val []byte, idx uint64) error {
b.lock.RLock()
mixes := b.state.RandaoMixes
if refs := b.sharedFieldReferences[randaoMixes].refs; refs > 1 {
mixes = b.RandaoMixes()
newMixes := memorypool.GetDoubleByteSlice(len(mixes))
copy(newMixes, mixes)
mixes = newMixes
b.sharedFieldReferences[randaoMixes].refs--
b.sharedFieldReferences[randaoMixes] = &reference{refs: 1}
}
Expand Down Expand Up @@ -492,7 +505,9 @@ func (b *BeaconState) AppendCurrentEpochAttestations(val *pbp2p.PendingAttestati

atts := b.state.CurrentEpochAttestations
if b.sharedFieldReferences[currentEpochAttestations].refs > 1 {
atts = b.CurrentEpochAttestations()
copiedAtts := make([]*pbp2p.PendingAttestation, len(atts), len(atts)+1)
copy(copiedAtts, atts)
atts = copiedAtts
b.sharedFieldReferences[currentEpochAttestations].refs--
b.sharedFieldReferences[currentEpochAttestations] = &reference{refs: 1}
}
Expand All @@ -512,7 +527,9 @@ func (b *BeaconState) AppendPreviousEpochAttestations(val *pbp2p.PendingAttestat
b.lock.RLock()
atts := b.state.PreviousEpochAttestations
if b.sharedFieldReferences[previousEpochAttestations].refs > 1 {
atts = b.PreviousEpochAttestations()
copiedAtts := make([]*pbp2p.PendingAttestation, len(atts), len(atts)+1)
copy(copiedAtts, atts)
atts = copiedAtts
b.sharedFieldReferences[previousEpochAttestations].refs--
b.sharedFieldReferences[previousEpochAttestations] = &reference{refs: 1}
}
Expand All @@ -532,7 +549,9 @@ func (b *BeaconState) AppendValidator(val *ethpb.Validator) error {
b.lock.RLock()
vals := b.state.Validators
if b.sharedFieldReferences[validators].refs > 1 {
vals = b.Validators()
copiedVals := make([]*ethpb.Validator, len(b.state.Validators), len(b.state.Validators)+1)
copy(copiedVals, b.state.Validators)
vals = copiedVals
b.sharedFieldReferences[validators].refs--
b.sharedFieldReferences[validators] = &reference{refs: 1}
}
Expand Down
13 changes: 12 additions & 1 deletion beacon-chain/state/types.go
Expand Up @@ -12,6 +12,7 @@ import (
pbp2p "github.com/prysmaticlabs/prysm/proto/beacon/p2p/v1"
"github.com/prysmaticlabs/prysm/shared/bytesutil"
"github.com/prysmaticlabs/prysm/shared/hashutil"
"github.com/prysmaticlabs/prysm/shared/memorypool"
"github.com/prysmaticlabs/prysm/shared/params"
"github.com/prysmaticlabs/prysm/shared/stateutil"
)
Expand Down Expand Up @@ -73,6 +74,7 @@ func InitializeFromProtoUnsafe(st *pbp2p.BeaconState) (*BeaconState, error) {
b.sharedFieldReferences[validators] = &reference{refs: 1}
b.sharedFieldReferences[balances] = &reference{refs: 1}
b.sharedFieldReferences[historicalRoots] = &reference{refs: 1}
b.sharedFieldReferences[validatorIdxMap] = &reference{refs: 1}

return b, nil
}
Expand Down Expand Up @@ -141,8 +143,11 @@ func (b *BeaconState) Copy() *BeaconState {

// Finalizer runs when dst is being destroyed in garbage collection.
runtime.SetFinalizer(dst, func(b *BeaconState) {
for _, v := range b.sharedFieldReferences {
for i, v := range b.sharedFieldReferences {
v.refs--
if i == randaoMixes && v.refs == 0 {
memorypool.PutDoubleByteSlice(b.state.RandaoMixes)
}
}
})

Expand All @@ -166,6 +171,12 @@ func (b *BeaconState) HashTreeRoot() ([32]byte, error) {
}

for field := range b.dirtyFields {
// do not compute root for field
// thats not part of the state.
if field == validatorIdxMap {
delete(b.dirtyFields, field)
continue
}
root, err := b.rootSelector(field)
if err != nil {
return [32]byte{}, err
Expand Down
14 changes: 14 additions & 0 deletions shared/memorypool/BUILD.bazel
@@ -0,0 +1,14 @@
load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test")

go_library(
name = "go_default_library",
srcs = ["memorypool.go"],
importpath = "github.com/prysmaticlabs/prysm/shared/memorypool",
visibility = ["//visibility:public"],
)

go_test(
name = "go_default_test",
srcs = ["memorypool_test.go"],
embed = [":go_default_library"],
)
27 changes: 27 additions & 0 deletions shared/memorypool/memorypool.go
@@ -0,0 +1,27 @@
package memorypool

import "sync"

// DoubleByteSlicePool represents the memory pool
// for 2d byte slices
var DoubleByteSlicePool = new(sync.Pool)

// GetDoubleByteSlice retrieves the 2d byte slice of
// the desired size from the memory pool.
func GetDoubleByteSlice(size int) [][]byte {
rawObj := DoubleByteSlicePool.Get()
if rawObj == nil {
return make([][]byte, size)
}
byteSlice := rawObj.([][]byte)
if len(byteSlice) >= size {
return byteSlice[:size]
}
return append(byteSlice, make([][]byte, size-len(byteSlice))...)
}

// PutDoubleByteSlice places the provided 2d byte slice
// in the memory pool
func PutDoubleByteSlice(data [][]byte) {
DoubleByteSlicePool.Put(data)
}
16 changes: 16 additions & 0 deletions shared/memorypool/memorypool_test.go
@@ -0,0 +1,16 @@
package memorypool

import (
"testing"
)

func TestRoundTripMemoryRetrieval(t *testing.T) {
byteSlice := make([][]byte, 1000)
PutDoubleByteSlice(byteSlice)
newSlice := GetDoubleByteSlice(1000)

if len(newSlice) != 1000 {
t.Errorf("Wanted same slice object, but got different object. "+
"Wanted slice with length %d but got length %d", 1000, len(newSlice))
}
}

0 comments on commit 4f654d3

Please sign in to comment.