Skip to content

Commit

Permalink
protocol: move magic exchange to version payload
Browse files Browse the repository at this point in the history
closes #889
  • Loading branch information
AnnaShaleva committed May 21, 2020
1 parent 1317666 commit e2d7560
Show file tree
Hide file tree
Showing 5 changed files with 31 additions and 19 deletions.
8 changes: 1 addition & 7 deletions pkg/network/message.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package network
import (
"fmt"

"github.com/nspcc-dev/neo-go/pkg/config"
"github.com/nspcc-dev/neo-go/pkg/consensus"
"github.com/nspcc-dev/neo-go/pkg/core/block"
"github.com/nspcc-dev/neo-go/pkg/core/transaction"
Expand All @@ -19,8 +18,6 @@ const (

// Message is the complete message send between nodes.
type Message struct {
// NetMode of the node that sends this message.
Magic config.NetMode

// Command is utf8 code, of which the length is 12 bytes,
// the extra part is filled with 0.
Expand Down Expand Up @@ -61,7 +58,7 @@ const (
)

// NewMessage returns a new message with the given payload.
func NewMessage(magic config.NetMode, cmd CommandType, p payload.Payload) *Message {
func NewMessage(cmd CommandType, p payload.Payload) *Message {
var (
size uint32
)
Expand All @@ -77,7 +74,6 @@ func NewMessage(magic config.NetMode, cmd CommandType, p payload.Payload) *Messa
}

return &Message{
Magic: magic,
Command: cmdToByteArray(cmd),
Length: size,
Payload: p,
Expand Down Expand Up @@ -133,7 +129,6 @@ func (m *Message) CommandType() CommandType {

// Decode decodes a Message from the given reader.
func (m *Message) Decode(br *io.BinReader) error {
m.Magic = config.NetMode(br.ReadU32LE())
br.ReadBytes(m.Command[:])
m.Length = br.ReadU32LE()
if br.Err != nil {
Expand Down Expand Up @@ -191,7 +186,6 @@ func (m *Message) decodePayload(br *io.BinReader) error {

// Encode encodes a Message to any given BinWriter.
func (m *Message) Encode(br *io.BinWriter) error {
br.WriteU32LE(uint32(m.Magic))
br.WriteBytes(m.Command[:])
br.WriteU32LE(m.Length)
if m.Payload != nil {
Expand Down
9 changes: 8 additions & 1 deletion pkg/network/payload/version.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ package payload
import (
"time"

"github.com/nspcc-dev/neo-go/pkg/config"

"github.com/nspcc-dev/neo-go/pkg/io"
)

Expand All @@ -21,6 +23,8 @@ const (

// Version payload.
type Version struct {
// NetMode of the node
Magic config.NetMode
// currently the version of the protocol is 0
Version uint32
// currently 1
Expand All @@ -40,8 +44,9 @@ type Version struct {
}

// NewVersion returns a pointer to a Version payload.
func NewVersion(id uint32, p uint16, ua string, h uint32, r bool) *Version {
func NewVersion(magic config.NetMode, id uint32, p uint16, ua string, h uint32, r bool) *Version {
return &Version{
Magic: magic,
Version: 0,
Services: nodePeerService,
Timestamp: uint32(time.Now().UTC().Unix()),
Expand All @@ -55,6 +60,7 @@ func NewVersion(id uint32, p uint16, ua string, h uint32, r bool) *Version {

// DecodeBinary implements Serializable interface.
func (p *Version) DecodeBinary(br *io.BinReader) {
p.Magic = config.NetMode(br.ReadU32LE())
p.Version = br.ReadU32LE()
p.Services = br.ReadU64LE()
p.Timestamp = br.ReadU32LE()
Expand All @@ -67,6 +73,7 @@ func (p *Version) DecodeBinary(br *io.BinReader) {

// EncodeBinary implements Serializable interface.
func (p *Version) EncodeBinary(br *io.BinWriter) {
br.WriteU32LE(uint32(p.Magic))
br.WriteU32LE(p.Version)
br.WriteU64LE(p.Services)
br.WriteU32LE(p.Timestamp)
Expand Down
5 changes: 4 additions & 1 deletion pkg/network/payload/version_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,21 @@ package payload
import (
"testing"

"github.com/nspcc-dev/neo-go/pkg/config"

"github.com/nspcc-dev/neo-go/pkg/internal/testserdes"
"github.com/stretchr/testify/assert"
)

func TestVersionEncodeDecode(t *testing.T) {
var magic config.NetMode = 56753
var port uint16 = 3000
var id uint32 = 13337
useragent := "/NEO:0.0.1/"
var height uint32 = 100500
var relay = true

version := NewVersion(id, port, useragent, height, relay)
version := NewVersion(magic, id, port, useragent, height, relay)
versionDecoded := &Version{}
testserdes.EncodeDecodeBinary(t, version, versionDecoded)

Expand Down
14 changes: 7 additions & 7 deletions pkg/network/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ func NewServer(config ServerConfig, chain blockchainer.Blockchainer, log *zap.Lo
// MkMsg creates a new message based on the server configured network and given
// parameters.
func (s *Server) MkMsg(cmd CommandType, p payload.Payload) *Message {
return NewMessage(s.Net, cmd, p)
return NewMessage(cmd, p)
}

// ID returns the servers ID.
Expand Down Expand Up @@ -354,6 +354,7 @@ func (s *Server) HandshakedPeersCount() int {
// getVersionMsg returns current version message.
func (s *Server) getVersionMsg() *Message {
payload := payload.NewVersion(
s.Net,
s.id,
s.Port,
s.UserAgent,
Expand Down Expand Up @@ -406,6 +407,11 @@ func (s *Server) handleVersionCmd(p Peer, version *payload.Version) error {
if s.id == version.Nonce {
return errIdenticalID
}
// Make sure both server and peer are operating on
// the same network.
if s.Net != version.Magic {
return errInvalidNetwork
}
peerAddr := p.PeerAddr().String()
s.discovery.RegisterConnectedAddr(peerAddr)
s.lock.RLock()
Expand Down Expand Up @@ -673,12 +679,6 @@ func (s *Server) handleMessage(peer Peer, msg *Message) error {
zap.Stringer("addr", peer.RemoteAddr()),
zap.String("type", string(msg.CommandType())))

// Make sure both server and peer are operating on
// the same network.
if msg.Magic != s.Net {
return errInvalidNetwork
}

if peer.Handshaked() {
if inv, ok := msg.Payload.(*payload.Inventory); ok {
if !inv.Type.Valid() || len(inv.Hashes) == 0 {
Expand Down
14 changes: 11 additions & 3 deletions pkg/network/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ func TestVerackAfterHandleVersionCmd(t *testing.T) {
p.messageHandler = func(t *testing.T, msg *Message) {
assert.Equal(t, CMDVerack, msg.CommandType())
}
version := payload.NewVersion(1337, 3000, "/NEO-GO/", 0, true)
version := payload.NewVersion(0, 1337, 3000, "/NEO-GO/", 0, true)

require.NoError(t, s.handleVersionCmd(p, version))
}
Expand All @@ -59,6 +59,7 @@ func TestServerNotSendsVerack(t *testing.T) {
p2 = newLocalPeer(t, s)
)
s.id = 1
s.Net = 56753
finished := make(chan struct{})
go func() {
s.run()
Expand All @@ -76,13 +77,20 @@ func TestServerNotSendsVerack(t *testing.T) {
s.register <- p

// identical id's
version := payload.NewVersion(1, 3000, "/NEO-GO/", 0, true)
version := payload.NewVersion(56753, 1, 3000, "/NEO-GO/", 0, true)
err := s.handleVersionCmd(p, version)
assert.NotNil(t, err)
assert.Equal(t, errIdenticalID, err)

// Different IDs, make handshake pass.
// Different IDs, but also different magics
version.Nonce = 2
version.Magic = 56752
err = s.handleVersionCmd(p, version)
assert.NotNil(t, err)
assert.Equal(t, errInvalidNetwork, err)

// Different IDs and same network, make handshake pass.
version.Magic = 56753
require.NoError(t, s.handleVersionCmd(p, version))
require.NoError(t, p.HandleVersionAck())
require.Equal(t, true, p.Handshaked())
Expand Down

0 comments on commit e2d7560

Please sign in to comment.