Skip to content
This repository has been archived by the owner on Oct 17, 2018. It is now read-only.

Commit

Permalink
Add enabled flag for base iterator
Browse files Browse the repository at this point in the history
  • Loading branch information
Jerome Froelich committed Apr 20, 2017
1 parent 772843d commit 7fddb75
Show file tree
Hide file tree
Showing 7 changed files with 75 additions and 10 deletions.
25 changes: 24 additions & 1 deletion protocol/msgpack/aggregated_iterator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,24 @@ func TestAggregatedIteratorDecodeMetricMoreFieldsThanExpected(t *testing.T) {
validateAggregatedDecodeResults(t, it, []metricWithPolicy{input}, io.EOF)
}

func TestAggregatedIteratorPolicyDecompressedNotEnabledError(t *testing.T) {
input := metricWithPolicy{
metric: testMetric,
policy: testPolicy,
}

// Use an encoder which compresses testPolicy
enc := testAggregatedEncoder(t, testBaseEncoderOptions).(*aggregatedEncoder)

testAggregatedEncode(t, enc, input.metric.(aggregated.Metric), input.policy)
require.NoError(t, enc.err())

// Use an iterator which does not have decompression enabled
it := testAggregatedIterator(t, enc.Encoder().Buffer(), nil)

validateAggregatedDecodeResults(t, it, nil, errPolicyDecompressionNotEnabled)
}

func TestAggregatedIteratorUnrecognizedCompressedPolicyError(t *testing.T) {
input := metricWithPolicy{
metric: testMetric,
Expand All @@ -180,7 +198,12 @@ func TestAggregatedIteratorUnrecognizedCompressedPolicyError(t *testing.T) {
require.NoError(t, enc.err())

// Use an iterator which does not have testPolicy in it's decompressor
it := testAggregatedIterator(t, enc.Encoder().Buffer(), nil)
baseItOpts := baseIteratorOptions{
enabled: true,
decompressor: policy.NewNoopDecompressor(),
}

it := testAggregatedIterator(t, enc.Encoder().Buffer(), baseItOpts)

id, _ := testBaseEncoderOptions.PolicyCompressor().ID(input.policy)
validateAggregatedDecodeResults(t, it, nil, fmt.Errorf("unrecognized compression policy id: %v", id))
Expand Down
1 change: 1 addition & 0 deletions protocol/msgpack/aggregated_roundtrip_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ var (
compressor: testCompressor,
}
testBaseIteratorOptions = baseIteratorOptions{
enabled: true,
decompressor: testDecompressor,
}
)
Expand Down
11 changes: 8 additions & 3 deletions protocol/msgpack/base_encoder.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,14 @@ func newBaseEncoder(encoder BufferedEncoder, opts BaseEncoderOptions) encoderBas
opts = NewBaseEncoderOptions()
}

compressor := opts.PolicyCompressor()
if compressor == nil {
compressor = policy.NewNoopCompressor()
}

enc := &baseEncoder{
bufEncoder: encoder,
policyCompressor: opts.PolicyCompressor(),
policyCompressor: compressor,
policyCompressionEnabled: opts.PolicyCompressionEnabled(),
}

Expand Down Expand Up @@ -131,7 +136,7 @@ func (enc *baseEncoder) encodeResolution(resolution policy.Resolution) {
return
}
// Otherwise encode the entire resolution object.
// TODO(xichen): validate the resolution before putting it on the wire
// TODO(xichen): validate the resolution before putting it on the wire.
enc.encodeNumObjectFields(numFieldsForType(unknownResolutionType))
enc.encodeObjectType(unknownResolutionType)
enc.encodeVarintFn(int64(resolution.Window))
Expand All @@ -150,7 +155,7 @@ func (enc *baseEncoder) encodeRetention(retention policy.Retention) {
return
}
// Otherwise encode the entire retention object.
// TODO(xichen): validate the retention before putting it on the wire
// TODO(xichen): validate the retention before putting it on the wire.
enc.encodeNumObjectFields(numFieldsForType(unknownRetentionType))
enc.encodeObjectType(unknownRetentionType)
enc.encodeVarintFn(int64(retention))
Expand Down
27 changes: 21 additions & 6 deletions protocol/msgpack/base_iterator.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ package msgpack

import (
"bytes"
"errors"
"fmt"
"io"
"time"
Expand All @@ -34,23 +35,32 @@ import (
)

var (
emptyReader *bytes.Buffer
emptyReader *bytes.Buffer
errPolicyDecompressionNotEnabled = errors.New("policy decompression is not enabled but recieved compressed policy id")
)

// baseIterator is the base iterator that provides common decoding APIs.
type baseIterator struct {
decoder *msgpack.Decoder // internal decoder that does the actual decoding
policyDecompressor policy.Decompressor // decompressor used to decompress policies
decodeErr error // error encountered during decoding
decoder *msgpack.Decoder // internal decoder that does the actual decoding
policyDecompressor policy.Decompressor // decompressor used to decompress policies
policyDecompressionEnabled bool // flag indicating whether policy decompression is enabled
decodeErr error // error encountered during decoding
}

func newBaseIterator(reader io.Reader, opts BaseIteratorOptions) iteratorBase {
if opts == nil {
opts = NewBaseIteratorOptions()
}

decompressor := opts.PolicyDecompressor()
if decompressor == nil {
decompressor = policy.NewNoopDecompressor()
}

return &baseIterator{
decoder: msgpack.NewDecoder(reader),
policyDecompressor: opts.PolicyDecompressor(),
decoder: msgpack.NewDecoder(reader),
policyDecompressor: decompressor,
policyDecompressionEnabled: opts.PolicyDecompressionEnabled(),
}
}

Expand Down Expand Up @@ -85,6 +95,11 @@ func (it *baseIterator) decodePolicy() policy.Policy {
it.skip(numActualFields - numExpectedFields)
return p
case compressedPolicyType:
if !it.policyDecompressionEnabled {
it.decodeErr = errPolicyDecompressionNotEnabled
return policy.EmptyPolicy
}

id := it.decodeVarint()
if it.decodeErr != nil {
return policy.EmptyPolicy
Expand Down
11 changes: 11 additions & 0 deletions protocol/msgpack/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ func (o baseEncoderOptions) PolicyCompressor() policy.Compressor {
}

type baseIteratorOptions struct {
enabled bool
decompressor policy.Decompressor
}

Expand All @@ -79,6 +80,16 @@ func NewBaseIteratorOptions() BaseIteratorOptions {
}
}

func (o baseIteratorOptions) SetPolicyDecompressionEnabled(enabled bool) BaseIteratorOptions {
opts := o
opts.enabled = enabled
return opts
}

func (o baseIteratorOptions) PolicyDecompressionEnabled() bool {
return o.enabled
}

func (o baseIteratorOptions) SetPolicyDecompressor(decompressor policy.Decompressor) BaseIteratorOptions {
opts := o
opts.decompressor = decompressor
Expand Down
8 changes: 8 additions & 0 deletions protocol/msgpack/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,14 @@ type iteratorBase interface {

// BaseIteratorOptions provide options for base iterators
type BaseIteratorOptions interface {
// SetPolicyDecompressionEnabled determines whether the iterator will attempt
// to decode policies that have been compressed
SetPolicyDecompressionEnabled(enabled bool) BaseIteratorOptions

// PolicyDecompressionEnabled returns whether the iterator will attempt to
// decode policies that have been compressed
PolicyDecompressionEnabled() bool

// SetPolicyDecompressor sets the policy Decompressor that will be used to
// decompress policies
SetPolicyDecompressor(decompressor policy.Decompressor) BaseIteratorOptions
Expand Down
2 changes: 2 additions & 0 deletions protocol/msgpack/unaggregated_roundtrip_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@ var (
compressor: testCompressorDefaultPolicies,
}
testDefaultPoliciesDecompressionOptions = baseIteratorOptions{
enabled: true,
decompressor: testDecompressorDefaultPolicies,
}

Expand All @@ -148,6 +149,7 @@ var (
compressor: testCompressorCustomPolicies,
}
testCustomPoliciesDecompressionOptions = baseIteratorOptions{
enabled: true,
decompressor: testDecompressorCustomPolicies,
}
)
Expand Down

0 comments on commit 7fddb75

Please sign in to comment.