Skip to content
This repository has been archived by the owner on Aug 2, 2021. It is now read-only.

p2p/protocols, p2p/testing; conditional propagation of context #1648

Merged
merged 2 commits into from
Aug 16, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 79 additions & 0 deletions p2p/protocols/context.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
package protocols

import (
"bufio"
"bytes"
"context"
"io/ioutil"

"github.com/ethereum/go-ethereum/p2p"
"github.com/ethereum/go-ethereum/rlp"
"github.com/ethersphere/swarm/spancontext"
opentracing "github.com/opentracing/opentracing-go"
)

// msgWithContext is used to propagate marshalled context alongside message payloads
type msgWithContext struct {
Context []byte
Msg []byte
}

func encodeWithContext(ctx context.Context, msg interface{}) (interface{}, int, error) {
var b bytes.Buffer
writer := bufio.NewWriter(&b)
tracer := opentracing.GlobalTracer()
sctx := spancontext.FromContext(ctx)
if sctx != nil {
err := tracer.Inject(
sctx,
opentracing.Binary,
writer)
if err != nil {
return nil, 0, err
}
}
writer.Flush()
msgBytes, err := rlp.EncodeToBytes(msg)
if err != nil {
return nil, 0, err
}

return &msgWithContext{
Context: b.Bytes(),
Msg: msgBytes,
}, len(msgBytes), nil
}

func decodeWithContext(msg p2p.Msg) (context.Context, []byte, error) {
var wmsg msgWithContext
err := msg.Decode(&wmsg)
if err != nil {
return nil, nil, err
}

ctx := context.Background()

if len(wmsg.Context) == 0 {
return ctx, wmsg.Msg, nil
}

tracer := opentracing.GlobalTracer()
sctx, err := tracer.Extract(opentracing.Binary, bytes.NewReader(wmsg.Context))
if err != nil {
return nil, nil, err
}
ctx = spancontext.WithContext(ctx, sctx)
return ctx, wmsg.Msg, nil
}

func encodeWithoutContext(ctx context.Context, msg interface{}) (interface{}, int, error) {
return msg, 0, nil
}

func decodeWithoutContext(msg p2p.Msg) (context.Context, []byte, error) {
b, err := ioutil.ReadAll(msg.Payload)
if err != nil {
return nil, nil, err
}
return context.Background(), b, nil
}
118 changes: 40 additions & 78 deletions p2p/protocols/protocol.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,6 @@ devp2p subprotocols by abstracting away code standardly shared by protocols.
package protocols

import (
"bufio"
"bytes"
"context"
"fmt"
"io"
Expand All @@ -42,9 +40,7 @@ import (
"github.com/ethereum/go-ethereum/metrics"
"github.com/ethereum/go-ethereum/p2p"
"github.com/ethereum/go-ethereum/rlp"
"github.com/ethersphere/swarm/spancontext"
"github.com/ethersphere/swarm/tracing"
opentracing "github.com/opentracing/opentracing-go"
)

// error codes used by this protocol scheme
Expand Down Expand Up @@ -115,13 +111,6 @@ func errorf(code int, format string, params ...interface{}) *Error {
}
}

// WrappedMsg is used to propagate marshalled context alongside message payloads
type WrappedMsg struct {
Context []byte
Size uint32
Payload []byte
}

//For accounting, the design is to allow the Spec to describe which and how its messages are priced
//To access this functionality, we provide a Hook interface which will call accounting methods
//NOTE: there could be more such (horizontal) hooks in the future
Expand Down Expand Up @@ -157,6 +146,10 @@ type Spec struct {
initOnce sync.Once
codes map[reflect.Type]uint64
types map[uint64]reflect.Type

// if the protocol does not allow extending the p2p msg to propagate context
// even if context not disabled, context will propagate only tracing is enabled
DisableContext bool
}

func (s *Spec) init() {
Expand Down Expand Up @@ -208,17 +201,27 @@ type Peer struct {
*p2p.Peer // the p2p.Peer object representing the remote
rw p2p.MsgReadWriter // p2p.MsgReadWriter to send messages to and read messages from
spec *Spec
encode func(context.Context, interface{}) (interface{}, int, error)
decode func(p2p.Msg) (context.Context, []byte, error)
}

// NewPeer constructs a new peer
// this constructor is called by the p2p.Protocol#Run function
// the first two arguments are the arguments passed to p2p.Protocol.Run function
// the third argument is the Spec describing the protocol
func NewPeer(p *p2p.Peer, rw p2p.MsgReadWriter, spec *Spec) *Peer {
func NewPeer(peer *p2p.Peer, rw p2p.MsgReadWriter, spec *Spec) *Peer {
encode := encodeWithContext
decode := decodeWithContext
if spec == nil || spec.DisableContext || !tracing.Enabled {
encode = encodeWithoutContext
decode = decodeWithoutContext
}
return &Peer{
Peer: p,
rw: rw,
spec: spec,
Peer: peer,
rw: rw,
spec: spec,
encode: encode,
decode: decode,
}
}

Expand All @@ -234,7 +237,6 @@ func (p *Peer) Run(handler func(ctx context.Context, msg interface{}) error) err
metrics.GetOrRegisterCounter("peer.handleincoming.error", nil).Inc(1)
log.Error("peer.handleIncoming", "err", err)
}

return err
}
}
Expand All @@ -256,51 +258,32 @@ func (p *Peer) Send(ctx context.Context, msg interface{}) error {
metrics.GetOrRegisterCounter("peer.send", nil).Inc(1)
metrics.GetOrRegisterCounter(fmt.Sprintf("peer.send.%T", msg), nil).Inc(1)

var b bytes.Buffer
if tracing.Enabled {
writer := bufio.NewWriter(&b)

tracer := opentracing.GlobalTracer()

sctx := spancontext.FromContext(ctx)

if sctx != nil {
err := tracer.Inject(
sctx,
opentracing.Binary,
writer)
if err != nil {
return err
}
}

writer.Flush()
code, found := p.spec.GetCode(msg)
if !found {
return errorf(ErrInvalidMsgType, "%v", code)
}

r, err := rlp.EncodeToBytes(msg)
wmsg, size, err := p.encode(ctx, msg)
if err != nil {
return err
}

wmsg := WrappedMsg{
Context: b.Bytes(),
Size: uint32(len(r)),
Payload: r,
// if size is not set by the wrapper, need to serialise
if size == 0 {
r, err := rlp.EncodeToBytes(msg)
if err != nil {
return err
}
size = len(r)
}

//if the accounting hook is set, call it
// if the accounting hook is set, call it
if p.spec.Hook != nil {
err := p.spec.Hook.Send(p, wmsg.Size, msg)
err = p.spec.Hook.Send(p, uint32(size), msg)
if err != nil {
p.Drop()
return err
}
}

code, found := p.spec.GetCode(msg)
if !found {
return errorf(ErrInvalidMsgType, "%v", code)
}
return p2p.Send(p.rw, code, wmsg)
}

Expand All @@ -324,44 +307,23 @@ func (p *Peer) handleIncoming(handle func(ctx context.Context, msg interface{})
return errorf(ErrMsgTooLong, "%v > %v", msg.Size, p.spec.MaxMsgSize)
}

// unmarshal wrapped msg, which might contain context
var wmsg WrappedMsg
err = msg.Decode(&wmsg)
if err != nil {
log.Error(err.Error())
return err
}

ctx := context.Background()

// if tracing is enabled and the context coming within the request is
// not empty, try to unmarshal it
if tracing.Enabled && len(wmsg.Context) > 0 {
var sctx opentracing.SpanContext

tracer := opentracing.GlobalTracer()
sctx, err = tracer.Extract(
opentracing.Binary,
bytes.NewReader(wmsg.Context))
if err != nil {
log.Error(err.Error())
return err
}

ctx = spancontext.WithContext(ctx, sctx)
}

val, ok := p.spec.NewMsg(msg.Code)
if !ok {
return errorf(ErrInvalidMsgCode, "%v", msg.Code)
}
if err := rlp.DecodeBytes(wmsg.Payload, val); err != nil {

ctx, msgBytes, err := p.decode(msg)
if err != nil {
return errorf(ErrDecode, "%v err=%v", msg.Code, err)
}

if err := rlp.DecodeBytes(msgBytes, val); err != nil {
return errorf(ErrDecode, "<= %v: %v", msg, err)
}

//if the accounting hook is set, call it
// if the accounting hook is set, call it
if p.spec.Hook != nil {
err := p.spec.Hook.Receive(p, wmsg.Size, val)
err := p.spec.Hook.Receive(p, uint32(len(msgBytes)), val)
if err != nil {
return err
}
Expand Down
31 changes: 10 additions & 21 deletions p2p/protocols/protocol_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -249,9 +249,7 @@ func TestProtocolHook(t *testing.T) {
runFunc := func(p *p2p.Peer, rw p2p.MsgReadWriter) error {
peer := NewPeer(p, rw, spec)
ctx := context.TODO()
err := peer.Send(ctx, &dummyMsg{
Content: "handshake"})

err := peer.Send(ctx, &dummyMsg{Content: "handshake"})
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -281,6 +279,7 @@ func TestProtocolHook(t *testing.T) {
if err != nil {
t.Fatal(err)
}

testHook.mu.Lock()
if testHook.msg == nil || testHook.msg.(*dummyMsg).Content != "handshake" {
t.Fatal("Expected msg to be set, but it is not")
Expand All @@ -291,8 +290,8 @@ func TestProtocolHook(t *testing.T) {
if testHook.peer == nil {
t.Fatal("Expected peer to be set, is nil")
}
if peerId := testHook.peer.ID(); peerId != tester.Nodes[0].ID() && peerId != tester.Nodes[1].ID() {
t.Fatalf("Expected peer ID to be set correctly, but it is not (got %v, exp %v or %v", peerId, tester.Nodes[0].ID(), tester.Nodes[1].ID())
if peerID := testHook.peer.ID(); peerID != tester.Nodes[0].ID() && peerID != tester.Nodes[1].ID() {
t.Fatalf("Expected peer ID to be set correctly, but it is not (got %v, exp %v or %v", peerID, tester.Nodes[0].ID(), tester.Nodes[1].ID())
}
if testHook.size != 11 { //11 is the length of the encoded message
t.Fatalf("Expected size to be %d, but it is %d ", 1, testHook.size)
Expand All @@ -309,11 +308,10 @@ func TestProtocolHook(t *testing.T) {
},
})

<-testHook.waitC

if err != nil {
t.Fatal(err)
}
<-testHook.waitC

testHook.mu.Lock()
if testHook.msg == nil || testHook.msg.(*dummyMsg).Content != "response" {
Expand Down Expand Up @@ -600,24 +598,15 @@ func (d *dummyRW) WriteMsg(msg p2p.Msg) error {
}

func (d *dummyRW) ReadMsg() (p2p.Msg, error) {
enc := bytes.NewReader(d.getDummyMsg())
r, err := rlp.EncodeToBytes(d.msg)
if err != nil {
return p2p.Msg{}, err
}
enc := bytes.NewReader(r)
return p2p.Msg{
Code: d.code,
Size: d.size,
Payload: enc,
ReceivedAt: time.Now(),
}, nil
}

func (d *dummyRW) getDummyMsg() []byte {
r, _ := rlp.EncodeToBytes(d.msg)
var b bytes.Buffer
wmsg := WrappedMsg{
Context: b.Bytes(),
Size: uint32(len(r)),
Payload: r,
}
rr, _ := rlp.EncodeToBytes(wmsg)
d.size = uint32(len(rr))
return rr
}
Loading