Fix symbols resolver race condition (#2665)
kolesnikovae committed Nov 13, 2023
1 parent 50654a8 commit be7bc5d
Showing 4 changed files with 202 additions and 103 deletions.
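
As the diffs show, the fix replaces the ad-hoc errgroup fetch and WaitGroup release logic in block_reader.go with a transactional fetchTx helper that releases every already-acquired object when any fetch fails, reworks resolver.go so that ownership of a PartitionReader is handed off to the consumer over a channel and released exactly once even when the context is canceled mid-resolution, and adds a Test_Resolver_Cancellation stress test that cancels at varying depths.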
90 changes: 59 additions & 31 deletions pkg/phlaredb/symdb/block_reader.go
@@ -189,7 +189,6 @@ func (r *Reader) partition(ctx context.Context, partition uint64) (*partition, e
return nil, ErrPartitionNotFound
}
if err := p.init(ctx); err != nil {
p.Release()
return nil, err
}
return p, nil
@@ -205,20 +204,22 @@ type partition struct {
strings parquetTableRange[string, *schemav1.StringPersister]
}

func (p *partition) init(ctx context.Context) error {
g, ctx := errgroup.WithContext(ctx)
func (p *partition) init(ctx context.Context) (err error) { return p.tx().fetch(ctx) }

func (p *partition) Release() { p.tx().release() }

func (p *partition) tx() *fetchTx {
tx := make(fetchTx, 0, len(p.stacktraceChunks)+4)
for _, c := range p.stacktraceChunks {
c := c
g.Go(func() error { return c.fetch(ctx) })
tx.append(c)
}
if p.reader.index.Header.Version > FormatV1 {
g.Go(func() error { return p.locations.fetch(ctx) })
g.Go(func() error { return p.mappings.fetch(ctx) })
g.Go(func() error { return p.functions.fetch(ctx) })
g.Go(func() error { return p.strings.fetch(ctx) })
tx.append(&p.locations)
tx.append(&p.mappings)
tx.append(&p.functions)
tx.append(&p.strings)
}
err := g.Wait()
return err
return &tx
}

func (p *partition) Symbols() *Symbols {
@@ -231,26 +232,6 @@ func (p *partition) Symbols() *Symbols {
}
}

func (p *partition) Release() {
var wg sync.WaitGroup
wg.Add(len(p.stacktraceChunks))
for _, c := range p.stacktraceChunks {
c := c
go func() {
c.release()
wg.Done()
}()
}
if p.reader.index.Header.Version > FormatV1 {
wg.Add(4)
go func() { p.locations.release(); wg.Done() }()
go func() { p.mappings.release(); wg.Done() }()
go func() { p.functions.release(); wg.Done() }()
go func() { p.strings.release(); wg.Done() }()
}
wg.Wait()
}

func (p *partition) WriteStats(s *PartitionStats) {
var nodes uint32
for _, c := range p.stacktraceChunks {
@@ -482,3 +463,50 @@ func (t *parquetTableRange[M, P]) release() {
t.s = nil
})
}

// fetchTx facilitates fetching multiple objects in a transactional manner:
// if one of the objects has failed, all the remaining ones are released.
type fetchTx []fetch

type fetch interface {
fetch(context.Context) error
release()
}

func (tx *fetchTx) append(x fetch) { *tx = append(*tx, x) }

func (tx *fetchTx) fetch(ctx context.Context) (err error) {
defer func() {
if err != nil {
tx.release()
}
}()
g, ctx := errgroup.WithContext(ctx)
for i, x := range *tx {
i := i
x := x
g.Go(func() error {
fErr := x.fetch(ctx)
if fErr != nil {
(*tx)[i] = nil
}
return fErr
})
}
return g.Wait()
}

func (tx *fetchTx) release() {
var wg sync.WaitGroup
wg.Add(len(*tx))
for _, x := range *tx {
x := x
go func() {
defer wg.Done()
if x != nil {
x.release()
}
}()
}
wg.Wait()
}
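
As a rough illustration of the transactional semantics, here is a minimal sketch of how fetchTx is driven. The toyFetcher type and exampleFetchTx function are hypothetical, assumed to sit in the same package as the code above (with the usual context and fmt imports); they are not part of the change.

// toyFetcher is a hypothetical stand-in for a stacktrace chunk or a
// parquetTableRange: it implements the fetch interface defined above.
type toyFetcher struct {
	name string
	fail bool
}

func (f *toyFetcher) fetch(context.Context) error {
	if f.fail {
		return fmt.Errorf("%s: fetch failed", f.name)
	}
	return nil
}

func (f *toyFetcher) release() { /* free the loaded data */ }

func exampleFetchTx(ctx context.Context) error {
	// Build the transaction the same way partition.tx does: one entry
	// per object that must be loaded before the partition is usable.
	tx := make(fetchTx, 0, 3)
	tx.append(&toyFetcher{name: "stacktraces"})
	tx.append(&toyFetcher{name: "locations"})
	tx.append(&toyFetcher{name: "strings", fail: true})
	// All members are fetched concurrently; because "strings" fails,
	// the members that did load are released before the error is
	// returned, so the caller never holds a partially fetched set.
	return tx.fetch(ctx)
}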
144 changes: 72 additions & 72 deletions pkg/phlaredb/symdb/resolver.go
@@ -44,8 +44,9 @@ func WithMaxConcurrent(n int) ResolverOption {
}

type lazyPartition struct {
id uint64
reader chan PartitionReader
samples map[uint32]int64
c chan *Symbols
err chan error
done chan struct{}
}
@@ -101,112 +102,111 @@ func (r *Resolver) Partition(partition uint64) map[uint32]int64 {
return p.samples
}
p = &lazyPartition{
id: partition,
samples: make(map[uint32]int64),
err: make(chan error),
done: make(chan struct{}),
c: make(chan *Symbols, 1),
reader: make(chan PartitionReader, 1),
}
r.p[partition] = p
r.m.Unlock()
r.g.Go(func() error {
pr, err := r.s.Partition(r.ctx, partition)
if err != nil {
r.span.LogFields(log.String("err", err.Error()))
select {
case <-r.ctx.Done():
return r.ctx.Err()
case p.err <- err:
return err
}
}
defer pr.Release()
p.c <- pr.Symbols()
return r.acquirePartition(p)
})
// r.g.Wait() is only called at Resolver.Release.
return p.samples
}

func (r *Resolver) acquirePartition(p *lazyPartition) error {
pr, err := r.s.Partition(r.ctx, p.id)
if err != nil {
r.span.LogFields(log.String("err", err.Error()))
select {
case <-r.ctx.Done():
return r.ctx.Err()
case <-p.done:
return nil
case p.err <- err:
// Signal the partition receiver
// about the failure, so it won't
// block and return early.
return err
}
})
return p.samples
}
// We've acquired the partition and must release it
// once resolution finished or canceled.
select {
case p.reader <- pr:
// We transferred ownership to the recipient,
// which is now responsible for releasing the
// partition.
<-p.done
case <-r.ctx.Done():
// We still own the partition and must release
// it on our own. It's guaranteed that the p.reader receiver
// has no access to the partition.
pr.Release()
return r.ctx.Err()
}
return nil
}

func (r *Resolver) Tree() (*model.Tree, error) {
span, ctx := opentracing.StartSpanFromContext(r.ctx, "Resolver.Tree")
defer span.Finish()

g, ctx := errgroup.WithContext(ctx)
g.SetLimit(r.c)

var tm sync.Mutex
var lock sync.Mutex
tree := new(model.Tree)

for _, p := range r.p {
p := p
g.Go(func() error {
defer close(p.done)
select {
case <-ctx.Done():
return ctx.Err()
case err := <-p.err:
return err
case symbols := <-p.c:
samples := schemav1.NewSamplesFromMap(p.samples)
rt, err := symbols.Tree(ctx, samples)
if err != nil {
return err
}
tm.Lock()
tree.Merge(rt)
tm.Unlock()
}
return nil
})
}
if err := g.Wait(); err != nil {
return nil, err
}

return tree, nil
err := r.withSymbols(ctx, func(symbols *Symbols, samples schemav1.Samples) error {
resolved, err := symbols.Tree(ctx, samples)
if err != nil {
return err
}
lock.Lock()
tree.Merge(resolved)
lock.Unlock()
return nil
})
return tree, err
}

func (r *Resolver) Profile() (*profile.Profile, error) {
span, ctx := opentracing.StartSpanFromContext(r.ctx, "Resolver.Profile")
defer span.Finish()
var lock sync.Mutex
profiles := make([]*profile.Profile, 0, len(r.p))
err := r.withSymbols(ctx, func(symbols *Symbols, samples schemav1.Samples) error {
resolved, err := symbols.Profile(ctx, samples)
if err != nil {
return err
}
lock.Lock()
profiles = append(profiles, resolved)
lock.Unlock()
return nil
})
if err != nil {
return nil, err
}
return profile.Merge(profiles)
}

func (r *Resolver) withSymbols(ctx context.Context, fn func(*Symbols, schemav1.Samples) error) error {
g, ctx := errgroup.WithContext(ctx)
g.SetLimit(r.c)

var rm sync.Mutex
profiles := make([]*profile.Profile, 0, len(r.p))

for _, p := range r.p {
p := p
g.Go(func() error {
defer close(p.done)
select {
case <-ctx.Done():
return ctx.Err()
case err := <-p.err:
return err
case symbols := <-p.c:
samples := schemav1.NewSamplesFromMap(p.samples)
rp, err := symbols.Profile(ctx, samples)
if err != nil {
return err
}
rm.Lock()
profiles = append(profiles, rp)
rm.Unlock()
case <-ctx.Done():
return ctx.Err()
case pr := <-p.reader:
defer pr.Release()
return fn(pr.Symbols(), schemav1.NewSamplesFromMap(p.samples))
}
return nil
})
}
if err := g.Wait(); err != nil {
return nil, err
}

return profile.Merge(profiles)
return g.Wait()
}

func (r *Symbols) Tree(ctx context.Context, samples schemav1.Samples) (*model.Tree, error) {
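
The select over p.reader and r.ctx.Done() above is the heart of the race fix: exactly one goroutine ends up owning, and therefore releasing, the PartitionReader. Below is a simplified, self-contained sketch of that handoff pattern; the resource, produce, and consume names are generic placeholders, not the actual Resolver types.

package main

import (
	"context"
	"fmt"
)

// resource stands in for a PartitionReader: it must be released exactly
// once by whichever goroutine ends up owning it.
type resource struct{ name string }

func (r *resource) Release() { fmt.Println("released", r.name) }

// produce acquires the resource and tries to hand it over. If the context
// is canceled before the consumer takes it, the producer still owns the
// resource and releases it itself.
func produce(ctx context.Context, out chan<- *resource, done <-chan struct{}) error {
	res := &resource{name: "partition"}
	select {
	case out <- res:
		// Ownership transferred: the consumer is responsible for
		// releasing it; wait until it signals completion.
		<-done
		return nil
	case <-ctx.Done():
		// The consumer never took it: release it here.
		res.Release()
		return ctx.Err()
	}
}

// consume either takes ownership of the resource or gives up on
// cancellation, and signals completion by closing done.
func consume(ctx context.Context, in <-chan *resource, done chan<- struct{}) error {
	defer close(done)
	select {
	case <-ctx.Done():
		return ctx.Err()
	case res := <-in:
		defer res.Release()
		// ... resolve symbols using res ...
		return nil
	}
}

func main() {
	ctx := context.Background()
	ch := make(chan *resource, 1)
	done := make(chan struct{})
	go func() { _ = produce(ctx, ch, done) }()
	if err := consume(ctx, ch, done); err != nil {
		fmt.Println("error:", err)
	}
}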
68 changes: 68 additions & 0 deletions pkg/phlaredb/symdb/resolver_test.go
@@ -3,6 +3,8 @@ package symdb
import (
"context"
"io"
"sync"
"sync/atomic"
"testing"

"github.com/stretchr/testify/mock"
@@ -130,6 +132,40 @@ func Test_Resolver_Error_Propagation(t *testing.T) {
r.Release()
}

func Test_Resolver_Cancellation(t *testing.T) {
s := newBlockSuite(t, [][]string{{"testdata/profile.pb.gz"}})
defer s.teardown()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

const (
workers = 10
iterations = 10
depth = 5
)

var wg sync.WaitGroup
wg.Add(workers)

for i := 0; i < workers; i++ {
go func() {
defer wg.Done()
for j := 0; j < iterations; j++ {
for d := 0; d < depth; d++ {
func() {
r := NewResolver(contextCancelAfter(ctx, int64(d)), s.reader)
defer r.Release()
r.AddSamples(0, s.indexed[0][0].Samples)
_, _ = r.Tree()
}()
}
}
}()
}

wg.Wait()
}

type mockSymbolsReader struct{ mock.Mock }

func (m *mockSymbolsReader) Partition(ctx context.Context, partition uint64) (PartitionReader, error) {
@@ -142,3 +178,35 @@ func (m *mockSymbolsReader) Load(ctx context.Context) error {
args := m.Called(ctx)
return args.Error(0)
}

type fakeContext struct {
context.Context
once sync.Once
ch chan struct{}
c atomic.Int64
n int64
}

func contextCancelAfter(ctx context.Context, n int64) context.Context {
return &fakeContext{
ch: make(chan struct{}),
Context: ctx,
n: n,
}
}

func (f *fakeContext) Done() <-chan struct{} {
if f.c.Add(1) > f.n {
f.once.Do(func() {
close(f.ch)
})
}
return f.ch
}

func (f *fakeContext) Err() error {
if f.c.Load() > f.n {
return context.Canceled
}
return f.Context.Err()
}
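
For clarity, a hypothetical sketch of how the helper behaves in isolation (assumed to sit next to the test code above, with a fmt import added; the sketchContextCancelAfter function is not part of the commit): Done() reports cancellation only once it has been polled more than n times, which is what lets Test_Resolver_Cancellation cancel at progressively deeper points of the resolution path.

// sketchContextCancelAfter illustrates the helper with n = 1: the first
// poll of Done() leaves the context alive, the second poll trips it.
func sketchContextCancelAfter() {
	ctx := contextCancelAfter(context.Background(), 1)

	select {
	case <-ctx.Done():
		fmt.Println("canceled on first poll") // not reached
	default:
		fmt.Println("still alive after first poll") // counter is 1, not > n
	}

	select {
	case <-ctx.Done():
		fmt.Println("canceled on second poll") // counter is 2 > n: channel closed
	default:
		fmt.Println("still alive after second poll") // not reached
	}
}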