diff --git a/p2p/protocols/context.go b/p2p/protocols/context.go new file mode 100644 index 0000000000..41e103575e --- /dev/null +++ b/p2p/protocols/context.go @@ -0,0 +1,79 @@ +package protocols + +import ( + "bufio" + "bytes" + "context" + "io/ioutil" + + "github.com/ethereum/go-ethereum/p2p" + "github.com/ethereum/go-ethereum/rlp" + "github.com/ethersphere/swarm/spancontext" + opentracing "github.com/opentracing/opentracing-go" +) + +// msgWithContext is used to propagate marshalled context alongside message payloads +type msgWithContext struct { + Context []byte + Msg []byte +} + +func encodeWithContext(ctx context.Context, msg interface{}) (interface{}, int, error) { + var b bytes.Buffer + writer := bufio.NewWriter(&b) + tracer := opentracing.GlobalTracer() + sctx := spancontext.FromContext(ctx) + if sctx != nil { + err := tracer.Inject( + sctx, + opentracing.Binary, + writer) + if err != nil { + return nil, 0, err + } + } + writer.Flush() + msgBytes, err := rlp.EncodeToBytes(msg) + if err != nil { + return nil, 0, err + } + + return &msgWithContext{ + Context: b.Bytes(), + Msg: msgBytes, + }, len(msgBytes), nil +} + +func decodeWithContext(msg p2p.Msg) (context.Context, []byte, error) { + var wmsg msgWithContext + err := msg.Decode(&wmsg) + if err != nil { + return nil, nil, err + } + + ctx := context.Background() + + if len(wmsg.Context) == 0 { + return ctx, wmsg.Msg, nil + } + + tracer := opentracing.GlobalTracer() + sctx, err := tracer.Extract(opentracing.Binary, bytes.NewReader(wmsg.Context)) + if err != nil { + return nil, nil, err + } + ctx = spancontext.WithContext(ctx, sctx) + return ctx, wmsg.Msg, nil +} + +func encodeWithoutContext(ctx context.Context, msg interface{}) (interface{}, int, error) { + return msg, 0, nil +} + +func decodeWithoutContext(msg p2p.Msg) (context.Context, []byte, error) { + b, err := ioutil.ReadAll(msg.Payload) + if err != nil { + return nil, nil, err + } + return context.Background(), b, nil +} diff --git a/p2p/protocols/protocol.go b/p2p/protocols/protocol.go index d5b68e5061..4fc77fabb0 100644 --- a/p2p/protocols/protocol.go +++ b/p2p/protocols/protocol.go @@ -29,8 +29,6 @@ devp2p subprotocols by abstracting away code standardly shared by protocols. package protocols import ( - "bufio" - "bytes" "context" "fmt" "io" @@ -42,9 +40,7 @@ import ( "github.com/ethereum/go-ethereum/metrics" "github.com/ethereum/go-ethereum/p2p" "github.com/ethereum/go-ethereum/rlp" - "github.com/ethersphere/swarm/spancontext" "github.com/ethersphere/swarm/tracing" - opentracing "github.com/opentracing/opentracing-go" ) // error codes used by this protocol scheme @@ -115,13 +111,6 @@ func errorf(code int, format string, params ...interface{}) *Error { } } -// WrappedMsg is used to propagate marshalled context alongside message payloads -type WrappedMsg struct { - Context []byte - Size uint32 - Payload []byte -} - //For accounting, the design is to allow the Spec to describe which and how its messages are priced //To access this functionality, we provide a Hook interface which will call accounting methods //NOTE: there could be more such (horizontal) hooks in the future @@ -157,6 +146,10 @@ type Spec struct { initOnce sync.Once codes map[reflect.Type]uint64 types map[uint64]reflect.Type + + // if the protocol does not allow extending the p2p msg to propagate context + // even if context not disabled, context will propagate only tracing is enabled + DisableContext bool } func (s *Spec) init() { @@ -208,17 +201,27 @@ type Peer struct { *p2p.Peer // the p2p.Peer object representing the remote rw p2p.MsgReadWriter // p2p.MsgReadWriter to send messages to and read messages from spec *Spec + encode func(context.Context, interface{}) (interface{}, int, error) + decode func(p2p.Msg) (context.Context, []byte, error) } // NewPeer constructs a new peer // this constructor is called by the p2p.Protocol#Run function // the first two arguments are the arguments passed to p2p.Protocol.Run function // the third argument is the Spec describing the protocol -func NewPeer(p *p2p.Peer, rw p2p.MsgReadWriter, spec *Spec) *Peer { +func NewPeer(peer *p2p.Peer, rw p2p.MsgReadWriter, spec *Spec) *Peer { + encode := encodeWithContext + decode := decodeWithContext + if spec == nil || spec.DisableContext || !tracing.Enabled { + encode = encodeWithoutContext + decode = decodeWithoutContext + } return &Peer{ - Peer: p, - rw: rw, - spec: spec, + Peer: peer, + rw: rw, + spec: spec, + encode: encode, + decode: decode, } } @@ -234,7 +237,6 @@ func (p *Peer) Run(handler func(ctx context.Context, msg interface{}) error) err metrics.GetOrRegisterCounter("peer.handleincoming.error", nil).Inc(1) log.Error("peer.handleIncoming", "err", err) } - return err } } @@ -256,51 +258,32 @@ func (p *Peer) Send(ctx context.Context, msg interface{}) error { metrics.GetOrRegisterCounter("peer.send", nil).Inc(1) metrics.GetOrRegisterCounter(fmt.Sprintf("peer.send.%T", msg), nil).Inc(1) - var b bytes.Buffer - if tracing.Enabled { - writer := bufio.NewWriter(&b) - - tracer := opentracing.GlobalTracer() - - sctx := spancontext.FromContext(ctx) - - if sctx != nil { - err := tracer.Inject( - sctx, - opentracing.Binary, - writer) - if err != nil { - return err - } - } - - writer.Flush() + code, found := p.spec.GetCode(msg) + if !found { + return errorf(ErrInvalidMsgType, "%v", code) } - r, err := rlp.EncodeToBytes(msg) + wmsg, size, err := p.encode(ctx, msg) if err != nil { return err } - wmsg := WrappedMsg{ - Context: b.Bytes(), - Size: uint32(len(r)), - Payload: r, + // if size is not set by the wrapper, need to serialise + if size == 0 { + r, err := rlp.EncodeToBytes(msg) + if err != nil { + return err + } + size = len(r) } - - //if the accounting hook is set, call it + // if the accounting hook is set, call it if p.spec.Hook != nil { - err := p.spec.Hook.Send(p, wmsg.Size, msg) + err = p.spec.Hook.Send(p, uint32(size), msg) if err != nil { - p.Drop() return err } } - code, found := p.spec.GetCode(msg) - if !found { - return errorf(ErrInvalidMsgType, "%v", code) - } return p2p.Send(p.rw, code, wmsg) } @@ -324,44 +307,23 @@ func (p *Peer) handleIncoming(handle func(ctx context.Context, msg interface{}) return errorf(ErrMsgTooLong, "%v > %v", msg.Size, p.spec.MaxMsgSize) } - // unmarshal wrapped msg, which might contain context - var wmsg WrappedMsg - err = msg.Decode(&wmsg) - if err != nil { - log.Error(err.Error()) - return err - } - - ctx := context.Background() - - // if tracing is enabled and the context coming within the request is - // not empty, try to unmarshal it - if tracing.Enabled && len(wmsg.Context) > 0 { - var sctx opentracing.SpanContext - - tracer := opentracing.GlobalTracer() - sctx, err = tracer.Extract( - opentracing.Binary, - bytes.NewReader(wmsg.Context)) - if err != nil { - log.Error(err.Error()) - return err - } - - ctx = spancontext.WithContext(ctx, sctx) - } - val, ok := p.spec.NewMsg(msg.Code) if !ok { return errorf(ErrInvalidMsgCode, "%v", msg.Code) } - if err := rlp.DecodeBytes(wmsg.Payload, val); err != nil { + + ctx, msgBytes, err := p.decode(msg) + if err != nil { + return errorf(ErrDecode, "%v err=%v", msg.Code, err) + } + + if err := rlp.DecodeBytes(msgBytes, val); err != nil { return errorf(ErrDecode, "<= %v: %v", msg, err) } - //if the accounting hook is set, call it + // if the accounting hook is set, call it if p.spec.Hook != nil { - err := p.spec.Hook.Receive(p, wmsg.Size, val) + err := p.spec.Hook.Receive(p, uint32(len(msgBytes)), val) if err != nil { return err } diff --git a/p2p/protocols/protocol_test.go b/p2p/protocols/protocol_test.go index f57714894b..4a640609d1 100644 --- a/p2p/protocols/protocol_test.go +++ b/p2p/protocols/protocol_test.go @@ -249,9 +249,7 @@ func TestProtocolHook(t *testing.T) { runFunc := func(p *p2p.Peer, rw p2p.MsgReadWriter) error { peer := NewPeer(p, rw, spec) ctx := context.TODO() - err := peer.Send(ctx, &dummyMsg{ - Content: "handshake"}) - + err := peer.Send(ctx, &dummyMsg{Content: "handshake"}) if err != nil { t.Fatal(err) } @@ -281,6 +279,7 @@ func TestProtocolHook(t *testing.T) { if err != nil { t.Fatal(err) } + testHook.mu.Lock() if testHook.msg == nil || testHook.msg.(*dummyMsg).Content != "handshake" { t.Fatal("Expected msg to be set, but it is not") @@ -291,8 +290,8 @@ func TestProtocolHook(t *testing.T) { if testHook.peer == nil { t.Fatal("Expected peer to be set, is nil") } - if peerId := testHook.peer.ID(); peerId != tester.Nodes[0].ID() && peerId != tester.Nodes[1].ID() { - t.Fatalf("Expected peer ID to be set correctly, but it is not (got %v, exp %v or %v", peerId, tester.Nodes[0].ID(), tester.Nodes[1].ID()) + if peerID := testHook.peer.ID(); peerID != tester.Nodes[0].ID() && peerID != tester.Nodes[1].ID() { + t.Fatalf("Expected peer ID to be set correctly, but it is not (got %v, exp %v or %v", peerID, tester.Nodes[0].ID(), tester.Nodes[1].ID()) } if testHook.size != 11 { //11 is the length of the encoded message t.Fatalf("Expected size to be %d, but it is %d ", 1, testHook.size) @@ -309,11 +308,10 @@ func TestProtocolHook(t *testing.T) { }, }) - <-testHook.waitC - if err != nil { t.Fatal(err) } + <-testHook.waitC testHook.mu.Lock() if testHook.msg == nil || testHook.msg.(*dummyMsg).Content != "response" { @@ -600,7 +598,11 @@ func (d *dummyRW) WriteMsg(msg p2p.Msg) error { } func (d *dummyRW) ReadMsg() (p2p.Msg, error) { - enc := bytes.NewReader(d.getDummyMsg()) + r, err := rlp.EncodeToBytes(d.msg) + if err != nil { + return p2p.Msg{}, err + } + enc := bytes.NewReader(r) return p2p.Msg{ Code: d.code, Size: d.size, @@ -608,16 +610,3 @@ func (d *dummyRW) ReadMsg() (p2p.Msg, error) { ReceivedAt: time.Now(), }, nil } - -func (d *dummyRW) getDummyMsg() []byte { - r, _ := rlp.EncodeToBytes(d.msg) - var b bytes.Buffer - wmsg := WrappedMsg{ - Context: b.Bytes(), - Size: uint32(len(r)), - Payload: r, - } - rr, _ := rlp.EncodeToBytes(wmsg) - d.size = uint32(len(rr)) - return rr -} diff --git a/p2p/testing/protocoltester.go b/p2p/testing/protocoltester.go index e798240a56..333ec905cd 100644 --- a/p2p/testing/protocoltester.go +++ b/p2p/testing/protocoltester.go @@ -180,8 +180,7 @@ func (m *mockNode) Run(peer *p2p.Peer, rw p2p.MsgReadWriter) error { for { select { case trig := <-m.trigger: - wmsg := Wrap(trig.Msg) - m.err <- p2p.Send(rw, trig.Code, wmsg) + m.err <- p2p.Send(rw, trig.Code, trig.Msg) case exps := <-m.expect: m.err <- expectMsgs(rw, exps) case <-m.stop: @@ -221,7 +220,7 @@ func expectMsgs(rw p2p.MsgReadWriter, exps []Expect) error { } var found bool for i, exp := range exps { - if exp.Code == msg.Code && bytes.Equal(actualContent, mustEncodeMsg(Wrap(exp.Msg))) { + if exp.Code == msg.Code && bytes.Equal(actualContent, mustEncodeMsg(exp.Msg)) { if matched[i] { return fmt.Errorf("message #%d received two times", i) } @@ -236,7 +235,7 @@ func expectMsgs(rw p2p.MsgReadWriter, exps []Expect) error { if matched[i] { continue } - expected = append(expected, fmt.Sprintf("code %d payload %x", exp.Code, mustEncodeMsg(Wrap(exp.Msg)))) + expected = append(expected, fmt.Sprintf("code %d payload %x", exp.Code, mustEncodeMsg(exp.Msg))) } return fmt.Errorf("unexpected message code %d payload %x, expected %s", msg.Code, actualContent, strings.Join(expected, " or ")) } @@ -268,17 +267,3 @@ func mustEncodeMsg(msg interface{}) []byte { } return contentEnc } - -type WrappedMsg struct { - Context []byte - Size uint32 - Payload []byte -} - -func Wrap(msg interface{}) interface{} { - data, _ := rlp.EncodeToBytes(msg) - return &WrappedMsg{ - Size: uint32(len(data)), - Payload: data, - } -}