From b2d0351a9f6c082d94d2995e1a4ddf5e282007b7 Mon Sep 17 00:00:00 2001 From: wirepair Date: Sun, 2 Apr 2017 21:04:18 +0900 Subject: [PATCH 01/11] initial commit for go port --- go/buffer.go | 241 +++++++++++++++++++++++++++ go/buffer_test.go | 266 ++++++++++++++++++++++++++++++ go/challenge_token.go | 36 +++++ go/challenge_token_test.go | 9 ++ go/client.go | 38 +++++ go/config.go | 19 +++ go/connect_token.go | 207 ++++++++++++++++++++++++ go/connect_token_test.go | 61 +++++++ go/crypto.go | 37 +++++ go/encryption_manager.go | 1 + go/packet.go | 324 +++++++++++++++++++++++++++++++++++++ go/queue.go | 1 + go/server.go | 1 + go/simulator.go | 1 + go/sizes.go | 56 +++++++ go/socket.go | 73 +++++++++ 16 files changed, 1371 insertions(+) create mode 100644 go/buffer.go create mode 100644 go/buffer_test.go create mode 100644 go/challenge_token.go create mode 100644 go/challenge_token_test.go create mode 100644 go/client.go create mode 100644 go/config.go create mode 100644 go/connect_token.go create mode 100644 go/connect_token_test.go create mode 100644 go/crypto.go create mode 100644 go/encryption_manager.go create mode 100644 go/packet.go create mode 100644 go/queue.go create mode 100644 go/server.go create mode 100644 go/simulator.go create mode 100644 go/sizes.go create mode 100644 go/socket.go diff --git a/go/buffer.go b/go/buffer.go new file mode 100644 index 0000000..b1f5cec --- /dev/null +++ b/go/buffer.go @@ -0,0 +1,241 @@ +package netcode + +import ( + "math" + "io" +) + +type Buffer struct { + Buf []byte + Pos int +} + +func NewBuffer(size int) *Buffer { + b := &Buffer{} + b.Buf = make([]byte, size) + return b +} + +func NewBufferFromBytes(buf []byte) *Buffer { + b := &Buffer{} + b.Buf = buf + return b +} + +func (b *Buffer) Copy() *Buffer { + c := NewBufferFromBytes(b.Buf) + return c +} + +func (b *Buffer) Len() int { + return len(b.Buf) +} + +func (b *Buffer) Bytes() []byte { + return b.Buf +} + +// GetByte decodes a little-endian byte +func (b *Buffer) GetByte() (byte, error) { + return b.GetUint8() +} + +// GetBytes returns a byte slice possibly smaller than length if bytes are not available from the +// reader. +func (b *Buffer) GetBytes(length int) ([]byte, error) { + if len(b.Buf) < length { + return nil, io.EOF + } + value := b.Buf[b.Pos:b.Pos+length] + b.Pos += length + return value, nil +} + +// GetUint8 decodes a little-endian uint8 from the buffer +func (b *Buffer) GetUint8() (uint8, error) { + if b.Pos + SizeUint8 > len(b.Buf) { + return 0, io.EOF + } + buf := b.Buf[b.Pos:b.Pos+SizeUint8] + b.Pos++ + return uint8(buf[0]), nil +} + +// GetUint16 decodes a little-endian uint16 from the buffer +func (b *Buffer) GetUint16() (uint16, error) { + var n uint16 + buf, err := b.GetBytes(SizeUint16) + n |= uint16(buf[0]) + n |= uint16(buf[1]) << 8 + return n, err +} + +// GetUint32 decodes a little-endian uint32 from the buffer +func (b *Buffer) GetUint32() (uint32, error) { + var n uint32 + buf, err := b.GetBytes(SizeUint32) + n |= uint32(buf[0]) + n |= uint32(buf[1]) << 8 + n |= uint32(buf[2]) << 16 + n |= uint32(buf[3]) << 24 + return n, err +} + +// GetUint64 decodes a little-endian uint64 from the buffer +func (b *Buffer) GetUint64() (uint64, error) { + var n uint64 + buf, err := b.GetBytes(SizeUint64) + n |= uint64(buf[0]) + n |= uint64(buf[1]) << 8 + n |= uint64(buf[2]) << 16 + n |= uint64(buf[3]) << 24 + n |= uint64(buf[4]) << 32 + n |= uint64(buf[5]) << 40 + n |= uint64(buf[6]) << 48 + n |= uint64(buf[7]) << 56 + return n, err +} + +// GetInt8 decodes a little-endian int8 from the buffer +func (b *Buffer) GetInt8() (int8, error) { + if b.Pos + 1 > len(b.Buf) { + return 0, io.EOF + } + buf := b.Buf[b.Pos:b.Pos+SizeInt8] + return int8(buf[0]), nil +} + +// GetInt16 decodes a little-endian int16 from the buffer +func (b *Buffer) GetInt16() (int16, error) { + var n int16 + buf, err := b.GetBytes(SizeInt16) + n |= int16(buf[0]) + n |= int16(buf[1]) << 8 + return n, err +} + +// GetInt32 decodes a little-endian int32 from the buffer +func (b *Buffer) GetInt32() (int32, error) { + var n int32 + buf, err := b.GetBytes(SizeInt32) + n |= int32(buf[0]) + n |= int32(buf[1]) << 8 + n |= int32(buf[2]) << 16 + n |= int32(buf[3]) << 24 + return n, err +} + +// GetInt64 decodes a little-endian int64 from the buffer +func (b *Buffer) GetInt64() (int64, error) { + var n int64 + buf, err := b.GetBytes(SizeInt64) + n |= int64(buf[0]) + n |= int64(buf[1]) << 8 + n |= int64(buf[2]) << 16 + n |= int64(buf[3]) << 24 + n |= int64(buf[4]) << 32 + n |= int64(buf[5]) << 40 + n |= int64(buf[6]) << 48 + n |= int64(buf[7]) << 56 + return n, err +} + + +// WriteByte encodes a little-endian uint8 into the buffer. +func (b *Buffer) WriteByte(n byte) { + b.Buf[b.Pos] = uint8(n) + b.Pos++ +} + +// WriteBytes encodes a little-endian byte slice into the buffer +func (b *Buffer) WriteBytes(src []byte) { + for i := 0; i < len(src); i+=1 { + b.WriteByte(uint8(src[i])) + } +} + +// WriteBytes encodes a little-endian byte slice into the buffer +func (b *Buffer) WriteBytesN(src []byte, length int) { + for i := 0; i < length; i+=1 { + b.WriteByte(uint8(src[i])) + } +} + + +// WriteUint8 encodes a little-endian uint8 into the buffer. +func (b *Buffer) WriteUint8(n uint8) { + b.Buf[b.Pos] = byte(n) + b.Pos++ +} + +// WriteUint16 encodes a little-endian uint16 into the buffer. +func (b *Buffer) WriteUint16(n uint16) { + b.Buf[b.Pos] = byte(n) + b.Pos++ + b.Buf[b.Pos] = byte(n >> 8) + b.Pos++ +} + +// WriteUint32 encodes a little-endian uint32 into the buffer. +func (b *Buffer) WriteUint32(n uint32) { + b.Buf[b.Pos] = byte(n) + b.Pos++ + b.Buf[b.Pos] = byte(n >> 8) + b.Pos++ + b.Buf[b.Pos] = byte(n >> 16) + b.Pos++ + b.Buf[b.Pos] = byte(n >> 24) + b.Pos++ +} + +// WriteUint64 encodes a little-endian uint64 into the buffer. +func (b *Buffer) WriteUint64(n uint64) { + for i := uint(0); i < uint(SizeUint64); i++ { + b.Buf[b.Pos] = byte(n >> (i * 8)) + b.Pos++ + } +} + +// WriteInt8 encodes a little-endian int8 into the buffer. +func (b *Buffer) WriteInt8(n int8) { + b.Buf[b.Pos] = byte(n) + b.Pos++ +} + +// WriteInt16 encodes a little-endian int16 into the buffer. +func (b *Buffer) WriteInt16(n int16) { + b.Buf[b.Pos] = byte(n) + b.Pos++ + b.Buf[b.Pos] = byte(n >> 8) + b.Pos++ +} + +// WriteInt32 encodes a little-endian int32 into the buffer. +func (b *Buffer) WriteInt32(n int32) { + b.Buf[b.Pos] = byte(n) + b.Pos++ + b.Buf[b.Pos] = byte(n >> 8) + b.Pos++ + b.Buf[b.Pos] = byte(n >> 16) + b.Pos++ + b.Buf[b.Pos] = byte(n >> 24) + b.Pos++ +} + +// WriteInt64 encodes a little-endian int64 into the buffer. +func (b *Buffer) WriteInt64(n int64) { + for i := uint(0); i < uint(SizeInt64); i++ { + b.Buf[b.Pos] = byte(n >> (i * 8)) + b.Pos++ + } +} + +// WriteFloat32 encodes a little-endian float32 into the buffer. +func (b *Buffer) WriteFloat32(n float32) { + b.WriteUint32(math.Float32bits(n)) +} + +// WriteFloat64 encodes a little-endian float64 into the buffer. +func (b *Buffer) WriteFloat64(buf []byte, n float64) { + b.WriteUint64(math.Float64bits(n)) +} diff --git a/go/buffer_test.go b/go/buffer_test.go new file mode 100644 index 0000000..4cdeb29 --- /dev/null +++ b/go/buffer_test.go @@ -0,0 +1,266 @@ +package netcode + +import ( + "testing" +) + +func TestBuffer(t *testing.T) { + b := NewBuffer(10) + b.WriteByte('a') + b.WriteBytesN([]byte("bcdefghij"), 9) + + if string(b.Buf) != "abcdefghij" { + t.Fatalf("error should have written 'abcdefghij' got '%s'\n", string(b.Buf)) + } + +} + +func TestBuffer_Copy(t *testing.T) { + b := NewBuffer(10) + b.WriteByte('a') + b.WriteBytesN([]byte("bcdefghij"), 9) + + r := b.Copy() + if r.Len() != b.Len() { + t.Fatalf("expected copy length to be same got: %d and %d\n", r.Len(), b.Len()) + } + data, err := r.GetBytes(10) + if err != nil { + t.Fatalf("error reading bytes from copy: %s\n", err) + } + + t.Logf("%s\n", string(data)) + +} + +func TestBuffer_GetByte(t *testing.T) { + buf := make([]byte, 1) + buf[0] = 0xfe + b := NewBufferFromBytes(buf) + val, err := b.GetByte() + + if err != nil { + t.Fatal(err) + } + + if val != 0xfe { + t.Fatalf("expected 0xfe got: %x\n", val) + } +} + + +func TestBuffer_GetBytes(t *testing.T) { + buf := make([]byte, 2) + buf[0] = 'a' + buf[1] = 'b' + b := NewBufferFromBytes(buf) + + val, err := b.GetBytes(2) + + if err != nil { + t.Fatal(err) + } + + if string(val) != "ab" { + t.Fatalf("expected ab got: %s\n", val) + } + + b = NewBufferFromBytes(buf) + + val, err = b.GetBytes(3) + if err == nil { + t.Fatal("expected EOF") + } + +} + + +func TestBuffer_GetInt8(t *testing.T) { + writer := NewBuffer(SizeInt8) + writer.WriteInt8(0x0f) + reader := writer.Copy() + + val, err := reader.GetInt8() + + if err != nil { + t.Fatal(err) + } + + if val != 0xf { + t.Fatalf("expected 0xf got: %x\n", val) + } + + buf := make([]byte, SizeInt8) + buf[0] = 0xff + b := NewBufferFromBytes(buf) + val, err = b.GetInt8() + if err != nil { + t.Fatal(err) + } + + if val != -1 { + t.Fatalf("expected -1 got: %x\n", val) + } +} + +func TestBuffer_GetInt16(t *testing.T) { + writer := NewBuffer(SizeInt16) + writer.WriteInt16(0x0fff) + reader := writer.Copy() + val, err := reader.GetInt16() + + if err != nil { + t.Fatal(err) + } + + if val != 0x0fff { + t.Fatalf("expected 0x0fff got: %x\n", val) + } + + buf := make([]byte, SizeInt16) + buf[0] = 0xff + buf[1] = 0xff + b := NewBufferFromBytes(buf) + val, err = b.GetInt16() + if err != nil { + t.Fatal(err) + } + + if val != -1 { + t.Fatalf("expected -1 got: %x\n", val) + } +} + + +func TestBuffer_GetInt32(t *testing.T) { + writer := NewBuffer(SizeInt32) + writer.WriteInt32(0x0fffffff) + reader := writer.Copy() + + val, err := reader.GetInt32() + if err != nil { + t.Fatal(err) + } + + if val != 0x0fffffff { + t.Fatalf("expected 0x0fffffff got: %x\n", val) + } + + buf := make([]byte, SizeInt32) + buf[0] = 0xff + buf[1] = 0xff + buf[2] = 0xff + buf[3] = 0xff + b := NewBufferFromBytes(buf) + val, err = b.GetInt32() + if err != nil { + t.Fatal(err) + } + + if val != -1 { + t.Fatalf("expected -1 got: %x\n", val) + } +} + +func TestBuffer_GetInt64(t *testing.T) { + writer := NewBuffer(SizeInt64) + writer.WriteInt64(0xf3f3f3f3f3f3) + reader := writer.Copy() + + val, err := reader.GetInt64() + + if err != nil { + t.Fatal(err) + } + + if val != 0xf3f3f3f3f3f3 { + t.Fatalf("expected 0xf3f3f3f3f3f3 got: %x\n", val) + } +} + +func TestBuffer_GetUint8(t *testing.T) { + writer := NewBuffer(SizeUint8) + writer.WriteUint8(0xff) + reader := writer.Copy() + + val, err := reader.GetUint8() + + if err != nil { + t.Fatal(err) + } + + if val != 0xff { + t.Fatalf("expected 0xff got: %x\n", val) + } +} + +func TestBuffer_GetUint16(t *testing.T) { + writer := NewBuffer(SizeUint16) + writer.WriteUint16(0xffff) + reader := writer.Copy() + + val, err := reader.GetUint16() + if err != nil { + t.Fatal(err) + } + + if val != 0xffff { + t.Fatalf("expected 0xffff got: %x\n", val) + } +} + +func TestBuffer_GetUint32(t *testing.T) { + writer := NewBuffer(SizeUint32) + writer.WriteUint32(0xffffffff) + reader := writer.Copy() + + val, err := reader.GetUint32() + if err != nil { + t.Fatal(err) + } + + if val != 0xffffffff { + t.Fatalf("expected 0xffffffff got: %x\n", val) + } +} + +func TestBuffer_GetUint64(t *testing.T) { + writer := NewBuffer(SizeUint64) + writer.WriteUint64(0xffffffffffffffff) + reader := writer.Copy() + + val, err := reader.GetUint64() + + if err != nil { + t.Fatal(err) + } + + if val != 0xffffffffffffffff { + t.Fatalf("expected 0xffffffffffffffff got: %x\n", val) + } +} + +func TestBuffer_Len(t *testing.T) { + b := NewBuffer(10) + b.WriteByte('a') + b.WriteBytesN([]byte("bcdefghij"), 9) + + if b.Len() != 10 { + t.Fatalf("expected length of 10 got: %d\n", b.Len()) + } +} + +func TestBuffer_WriteBytes(t *testing.T) { + w := NewBuffer(10) + w.WriteBytes([]byte("0123456789")) + r := w.Copy() + val, err :=r.GetBytes(10) + if err != nil { + t.Fatal(err) + } + + if string(val) != "0123456789" { + t.Fatalf("expected 0123456789 got: %s %d\n",val, len(val)) + } + +} diff --git a/go/challenge_token.go b/go/challenge_token.go new file mode 100644 index 0000000..67dd485 --- /dev/null +++ b/go/challenge_token.go @@ -0,0 +1,36 @@ +package netcode + + +type Token interface { + Encrypt() + Decrypt() + Read(buffer []byte, length uint) + Write(buffer []byte, length uint) +} + +type ChallengeToken struct { + ClientId uint64 + UserData [USER_DATA_BYTES]byte +} + +func NewChallengeToken() *ChallengeToken { + token := &ChallengeToken{} + return token +} + +func (t *ChallengeToken) Encrypt(buffer []byte, length uint, sequence uint64, key []byte) error { + return nil +} + +func (t *ChallengeToken) Decrypt(buffer []byte, length uint, sequence uint64, key []byte) error { + return nil +} + +func (t *ChallengeToken) Read(buffer []byte, length uint) { + +} + +func (t *ChallengeToken) Write(buffer []byte, length uint) { + +} + diff --git a/go/challenge_token_test.go b/go/challenge_token_test.go new file mode 100644 index 0000000..96864fb --- /dev/null +++ b/go/challenge_token_test.go @@ -0,0 +1,9 @@ +package netcode + +import ( + "testing" +) + +func TestNewChallengeToken(t *testing.T) { + +} \ No newline at end of file diff --git a/go/client.go b/go/client.go new file mode 100644 index 0000000..80e9fd0 --- /dev/null +++ b/go/client.go @@ -0,0 +1,38 @@ +package netcode + +import ( + "crypto/rand" + "math/big" +) + +type Client struct { + Id uint64 + config *Config +} + +func NewClient(config *Config) *Client { + c := &Client{config: config} + + return c +} + +func (c *Client) Init(sequence uint64) error { + id, err := rand.Int(rand.Reader, big.NewInt(64)) + if err != nil { + return err + } + + c.Id = id.Uint64() + + token := NewConnectToken() + if err := token.Generate(c.config, sequence, c.Id); err != nil { + return err + } + + return nil +} + +func (c *Client) Connect() error { + return nil +} + diff --git a/go/config.go b/go/config.go new file mode 100644 index 0000000..13442fd --- /dev/null +++ b/go/config.go @@ -0,0 +1,19 @@ +package netcode + +import "net" + +type Config struct { + ServerAddrs []net.UDPAddr + TokenExpiry uint64 + ProtocolId uint64 + PrivateKey []byte +} + +func NewConfig(serverAddrs []net.UDPAddr, expiry, protocolId uint64, privateKey []byte) *Config { + c := &Config{} + c.ServerAddrs = serverAddrs + c.TokenExpiry = expiry + c.ProtocolId = protocolId + c.PrivateKey = privateKey + return c +} diff --git a/go/connect_token.go b/go/connect_token.go new file mode 100644 index 0000000..66e71a6 --- /dev/null +++ b/go/connect_token.go @@ -0,0 +1,207 @@ +package netcode + +import ( + "net" + "errors" + "strconv" + "time" + "log" +) + +const ( + ADDRESS_NONE = iota + ADDRESS_IPV4 + ADDRESS_IPV6 +) + + +// Token used for connecting +type ConnectToken struct { + ClientId uint64 // client identifier + ServerAddresses []net.UDPAddr // list of server addresses this client may connect to + ClientKey []byte // client to server key + ServerKey []byte // server to client key + UserData []byte // user data + ExpireTimestamp uint64 + TokenData *Buffer // connect token data +} + +// create a new empty token +func NewConnectToken() *ConnectToken { + token := &ConnectToken{} + return token +} + +// Generates the token with the supplied configuration values and clientId. +func (token *ConnectToken) Generate(config *Config, sequence, clientId uint64) error { + var err error + + token.ClientId = clientId + token.ServerAddresses = config.ServerAddrs + + if token.UserData, err = RandomBytes(USER_DATA_BYTES); err != nil { + return err + } + + if token.ClientKey, err = GenerateKey(); err != nil { + return err + } + + if token.ServerKey, err = GenerateKey(); err != nil { + return err + } + + if token.TokenData, err = WriteToken(token); err != nil { + return err + } + + creationTime := time.Now().Unix() + token.ExpireTimestamp = uint64(creationTime) + config.TokenExpiry + token.Encrypt(config.ProtocolId, sequence, config.PrivateKey) + + return nil +} + +// Encrypts the token.TokenData +func (token *ConnectToken) Encrypt(protocolId, sequence uint64, privateKey []byte) error { + additionalData, nonce := buildCryptData(protocolId, token.ExpireTimestamp, sequence) + + if err := EncryptAead(&token.TokenData.Buf, additionalData.Bytes(), nonce.Bytes(), privateKey); err != nil { + return err + } + log.Printf("after encrypt: %#v\n", token.TokenData.Bytes()) + return nil +} + +func (token *ConnectToken) Decrypt(protocolId, sequence uint64, privateKey []byte) error { + var err error + + additionalData, nonce := buildCryptData(protocolId, token.ExpireTimestamp, sequence) + + if token.TokenData.Buf, err = DecryptAead(token.TokenData.Bytes(), additionalData.Bytes(), nonce.Bytes(), privateKey); err != nil { + return err + } + return nil +} + +func buildCryptData(protocolId, expireTimestamp, sequence uint64) (*Buffer, *Buffer) { + additionalData := NewBuffer(VERSION_INFO_BYTES+8+8) + additionalData.WriteBytes([]byte(VERSION_INFO)) + additionalData.WriteUint64(protocolId) + additionalData.WriteUint64(expireTimestamp) + + nonce := NewBuffer(SizeUint64) + nonce.WriteUint64(sequence) + + return additionalData, nonce +} + +// Writes the token data to the TokenData buffer and returns to caller +func WriteToken(token *ConnectToken) (*Buffer, error) { + data := NewBuffer(CONNECT_TOKEN_PRIVATE_BYTES) + data.WriteUint64(token.ClientId) + data.WriteUint32(uint32(len(token.ServerAddresses))) + + for _, addr := range token.ServerAddresses { + host, port, err := net.SplitHostPort(addr.String()) + if err != nil { + return nil, errors.New("invalid port for host: " + addr.String()) + } + + parsed := net.ParseIP(host) + if parsed == nil { + return nil, errors.New("invalid ip address") + } + + if len(parsed) == 4 { + data.WriteUint8(uint8(ADDRESS_IPV4)) + + } else { + data.WriteUint8(uint8(ADDRESS_IPV6)) + } + + for i := 0; i < len(parsed); i +=1 { + data.WriteUint8(parsed[i]) + } + + p, err := strconv.ParseUint(port, 10, 16) + if err != nil { + return nil, err + } + data.WriteUint16(uint16(p)) + } + data.WriteBytesN(token.ClientKey, KEY_BYTES) + data.WriteBytesN(token.ServerKey, KEY_BYTES) + data.WriteBytesN(token.UserData, USER_DATA_BYTES) + return data, nil +} + +// Takes in a slice of bytes and generates a new ConnectToken. +func ReadToken(tokenBuffer []byte) (*ConnectToken, error) { + var err error + var servers uint32 + var ipBytes []byte + + token := NewConnectToken() + buffer := NewBufferFromBytes(tokenBuffer) + + if token.ClientId, err = buffer.GetUint64(); err != nil { + return nil, err + } + + servers, err = buffer.GetUint32() + if err != nil { + return nil, err + } + + if servers <= 0 { + return nil, errors.New("empty servers") + } + + if servers > MAX_SERVERS_PER_CONNECT { + return nil, errors.New("too many servers") + } + + var i uint32 + token.ServerAddresses = make([]net.UDPAddr, servers) + + for i = 0; i < servers; i+=1 { + serverType, err := buffer.GetUint8() + if err != nil { + return nil, err + } + + if serverType == ADDRESS_IPV4 { + ipBytes, err = buffer.GetBytes(4) + } else if serverType == ADDRESS_IPV6 { + ipBytes, err = buffer.GetBytes(16) + } else { + return nil, errors.New("unknown ip address") + } + + if err != nil { + return nil, err + } + + ip := net.IP(ipBytes) + port, err := buffer.GetUint16() + if err != nil { + return nil, errors.New("invalid port") + } + token.ServerAddresses[i] = net.UDPAddr{IP: ip, Port: int(port)} + } + + if token.ClientKey, err = buffer.GetBytes(KEY_BYTES); err != nil { + return nil, errors.New("error reading client to server key") + } + + if token.ServerKey, err = buffer.GetBytes(KEY_BYTES); err != nil { + return nil, errors.New("error reading server to client key") + } + + if token.UserData, err = buffer.GetBytes(USER_DATA_BYTES); err != nil { + return nil, errors.New("error reading user data") + } + + return token, nil +} \ No newline at end of file diff --git a/go/connect_token_test.go b/go/connect_token_test.go new file mode 100644 index 0000000..1bb91f6 --- /dev/null +++ b/go/connect_token_test.go @@ -0,0 +1,61 @@ +package netcode + +import ( + "testing" + "net" + "bytes" +) + +const ( + TEST_PROTOCOL_ID = 0x1122334455667788 + TEST_CONNECT_TOKEN_EXPIRY = 30 + TEST_SERVER_PORT = 40000 + TEST_CLIENT_ID = 0x1 + TEST_SEQUENCE_START = 1000 +) + +var TEST_PRIVATE_KEY = []byte{0x60, 0x6a, 0xbe, 0x6e, 0xc9, 0x19, 0x10, 0xea, + 0x9a, 0x65, 0x62, 0xf6, 0x6f, 0x2b, 0x30, 0xe4, + 0x43, 0x71, 0xd6, 0x2c, 0xd1, 0x99, 0x27, 0x26, + 0x6b, 0x3c, 0x60, 0xf4, 0xb7, 0x15, 0xab, 0xa1 } + +func TestNewConnectToken(t *testing.T) { + token := NewConnectToken() + server := net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 40000} + servers := make([]net.UDPAddr, 1) + servers[0] = server + config := NewConfig(servers, TEST_CONNECT_TOKEN_EXPIRY, TEST_PROTOCOL_ID, TEST_PRIVATE_KEY) + err := token.Generate(config, TEST_SEQUENCE_START, TEST_CLIENT_ID) + if err != nil { + t.Fatalf("error generating token") + } + + err = token.Decrypt(config.ProtocolId, TEST_SEQUENCE_START, config.PrivateKey) + if err != nil { + t.Fatalf("error decrypting token: %s\n", err) + } + + token2, err := ReadToken(token.TokenData.Buf) + + if token.ClientId != token2.ClientId { + t.Fatalf("clientIds do not match expected %d got %d", token.ClientId, token2.ClientId) + } + + if len(token.ServerAddresses) != len(token2.ServerAddresses) { + t.Fatalf("time stamps do not match expected %d got %d", len(token.ServerAddresses), len(token2.ServerAddresses)) + } + + // TODO verify server addresses + + if bytes.Compare(token.ClientKey, token2.ClientKey) != 0 { + t.Fatalf("ClientKey do not match expected %v got %v", token.ClientKey, token2.ClientKey) + } + + if bytes.Compare(token.ServerKey, token2.ServerKey) != 0 { + t.Fatalf("ServerKey do not match expected %v got %v", token.ServerKey, token2.ServerKey) + } + + if bytes.Compare(token.UserData, token2.UserData) != 0 { + t.Fatalf("UserData do not match expected %v got %v", token.UserData, token2.UserData) + } +} \ No newline at end of file diff --git a/go/crypto.go b/go/crypto.go new file mode 100644 index 0000000..c37fd57 --- /dev/null +++ b/go/crypto.go @@ -0,0 +1,37 @@ +package netcode + +import ( + "crypto/rand" + "github.com/codahale/chacha20poly1305" + "log" +) + +func RandomBytes(bytes int) ([]byte, error) { + b := make([]byte, bytes) + _, err := rand.Read(b) + return b, err +} + +func GenerateKey() ([]byte, error) { + return RandomBytes(KEY_BYTES) +} + + +func EncryptAead(message *[]byte, additional []byte, nonce, key []byte) error { + aead, err := chacha20poly1305.New(key) + if err != nil { + return err + } + log.Printf("before seal: %#v\n", message) + *message = aead.Seal(nil, nonce, *message, additional) + return nil +} + +func DecryptAead(message []byte, additional []byte, nonce, key []byte) ([]byte, error) { + aead, err := chacha20poly1305.New(key) + + if err != nil { + return nil, err + } + return aead.Open(nil, nonce, message, additional) +} \ No newline at end of file diff --git a/go/encryption_manager.go b/go/encryption_manager.go new file mode 100644 index 0000000..99ad84a --- /dev/null +++ b/go/encryption_manager.go @@ -0,0 +1 @@ +package netcode diff --git a/go/packet.go b/go/packet.go new file mode 100644 index 0000000..17edb71 --- /dev/null +++ b/go/packet.go @@ -0,0 +1,324 @@ +package netcode + +import ( + "log" + "errors" +) + +type PacketType uint8 + +const MAX_CLIENTS = 60 +const CONNECT_TOKEN_PRIVATE_BYTES = 1024 +const CHALLENGE_TOKEN_BYTES = 300 +const VERSION_INFO_BYTES = 13 +const USER_DATA_BYTES = 256 +const MAX_PACKET_BYTES = 1220 +const MAX_PAYLOAD_BYTES = 1200 +const MAX_ADDRESS_STRING_LENGTH = 256 +const PACKET_QUEUE_SIZE = 256 +const REPLAY_PROTECTION_BUFFER_SIZE = 256 +const CLIENT_MAX_RECEIVE_PACKETS = 64 +const SERVER_MAX_RECEIVE_PACKETS = ( 64 * MAX_CLIENTS ) + +const KEY_BYTES = 32 +const MAC_BYTES = 16 +const NONCE_BYTES = 8 +const MAX_SERVERS_PER_CONNECT = 32 + +const VERSION_INFO = "NETCODE 1.00" +const PACKET_SEND_RATE = 10.0 +const TIMEOUT_SECONDS = 5.0 +const NUM_DISCONNECT_PACKETS = 10 + + +const ( + ConnectionRequest PacketType = iota + ConnectionDenied + ConnectionChallenge + ConnectionResponse + ConnectionKeepAlive + ConnectionPayload + ConnectionDisconnect + ConnectionNumPackets +) + +type Packet interface { + GetType() PacketType +} + +type RequestPacket struct { + Type PacketType + VersionInfo []byte + ProtocolId uint64 + ConnectTokenExpireTimestamp uint64 + ConnectTokenSequence uint64 + ConnectTokenData []byte +} + +func (p *RequestPacket) GetType() PacketType { + return ConnectionRequest +} + +type DeniedPacket struct { + Type PacketType +} + +func (p *DeniedPacket) GetType() PacketType { + return ConnectionDenied +} + +type ChallengePacket struct { + Type PacketType + ChallengeTokenSequence uint64 + ChallengeTokenData []byte +} + +func (p *ChallengePacket) GetType() PacketType { + return ConnectionChallenge +} + +type ResponsePacket struct { + Type PacketType + ChallengeTokenSequence uint64 + ChallengeTokenData []byte +} + +func (p *ResponsePacket) GetType() PacketType { + return ConnectionResponse +} + +type KeepAlivePacket struct { + Type PacketType + ClientIndex uint + MaxClients uint +} + +func (p *KeepAlivePacket) GetType() PacketType { + return ConnectionKeepAlive +} + + +type PayloadPacket struct { + Type PacketType + PayloadBytes uint32 + PayloadData []byte + // ... +} + +func (p *PayloadPacket) GetType() PacketType { + return ConnectionPayload +} + +func NewPayloadPacket(payload_bytes uint32) *PayloadPacket { + return &PayloadPacket{Type: ConnectionPayload, PayloadBytes: payload_bytes} +} + +type DisconnectPacket struct { + Type PacketType +} + +type Context struct { + WritePacketKey []byte + ReadPacketKey []byte +} + +type ReplayProtection struct { + MostRecentSequence uint64 + ReceivedPacket []uint64 +} + +func (r *ReplayProtection) Reset() { + r.MostRecentSequence = 0 + //MemsetUint64(r.ReceivedPacket, 0xFF) +} + +func (r *ReplayProtection) AlreadyReceived(sequence uint64) int { + if (sequence & 1 << 63) == 0 { + return 0 + } + + if sequence + REPLAY_PROTECTION_BUFFER_SIZE <= r.MostRecentSequence { + return 1 + } + + if sequence > r.MostRecentSequence { + r.MostRecentSequence = sequence + } + + index := ( sequence % REPLAY_PROTECTION_BUFFER_SIZE) + + if r.ReceivedPacket[index] == 0xFFFFFFFFFFFFFFFF { + r.ReceivedPacket[index] = sequence + return 0 + } + + if r.ReceivedPacket[index] >= sequence { + return 1 + } + + r.ReceivedPacket[index] = sequence + return 0 +} + +func SequenceNumberBytesRequired(sequence uint64) int { + var mask uint64 + mask = 0xFF00000000000000 + i := 0 + for ; i < 7; i+=1 { + if (sequence & mask == 0) { + break + } + mask >>= 8 + } + return 8 - i +} + +func WritePacket(packet Packet, buffer *Buffer, buffer_length uint, sequence uint64, write_packet_key []byte, protocol_id uint64) (int, error) { + var p Packet + var start *Buffer + + packetType := packet.GetType() + + if packetType == ConnectionRequest { + + p, ok := packet.(*RequestPacket) + if !ok { + return -1, nil + } + start = NewBufferFromBytes(buffer.Bytes()) + buffer.WriteUint8(uint8(ConnectionRequest)) + buffer.WriteBytesN(p.VersionInfo, VERSION_INFO_BYTES) + buffer.WriteUint64(p.ProtocolId) + buffer.WriteUint64(p.ConnectTokenExpireTimestamp) + buffer.WriteUint64(p.ConnectTokenSequence) + buffer.WriteBytesN(p.ConnectTokenData, CONNECT_TOKEN_PRIVATE_BYTES) + return buffer.Len() - start.Len(), nil + } + + // *** encrypted packets *** + + // write the prefix byte (this is a combination of the packet type and number of sequence bytes) + start = NewBufferFromBytes(buffer.Bytes()) + sequence_bytes := SequenceNumberBytesRequired(sequence) + + prefix_byte := uint8(p.GetType()) | uint8(sequence_bytes << 4) + buffer.WriteUint8(prefix_byte) + + sequence_temp := sequence + + for i := 0; i < sequence_bytes; i+=1 { + buffer.WriteUint8(uint8(sequence_temp & 0xFF)) + sequence_temp >>= 8 + } + + //encrypted_start := NewBufferFromBytes(buffer.Buf.Bytes()) + + switch (p.GetType()) { + case ConnectionDenied: + // ... + case ConnectionChallenge: + p, ok := packet.(*ChallengePacket) + if !ok { + return -1, nil + } + buffer.WriteUint64(p.ChallengeTokenSequence) + buffer.WriteBytesN(p.ChallengeTokenData, CHALLENGE_TOKEN_BYTES) + case ConnectionResponse: + p, ok := packet.(*ResponsePacket) + if !ok { + return -1, nil + } + buffer.WriteUint64(p.ChallengeTokenSequence) + buffer.WriteBytesN(p.ChallengeTokenData, CHALLENGE_TOKEN_BYTES) + case ConnectionKeepAlive: + p, ok := packet.(*KeepAlivePacket) + if !ok { + return -1, nil + } + buffer.WriteUint32(uint32(p.ClientIndex)) + buffer.WriteUint32(uint32(p.MaxClients)) + case ConnectionPayload: + p, ok := packet.(*PayloadPacket) + if !ok { + return -1, nil + } + buffer.WriteBytesN([]byte(p.PayloadData), int(p.PayloadBytes)) + case ConnectionDisconnect: + // ... + } + //encrypted_finish := buffer + + + // encrypt the per-packet packet written with the prefix byte, protocol id and version as the associated data. this must match to decrypt. + additional_data := NewBuffer(VERSION_INFO_BYTES+8+1) + additional_data.WriteBytesN([]byte(VERSION_INFO), VERSION_INFO_BYTES) + + nonce := NewBuffer(8) + + nonce.WriteUint64(sequence) + + //err := EncryptAead(encrypted_start, len(encrypted_finish) - len(encrypted_start), additional_data, len(additional_data), nonce, write_packet_key) + //if err != nil { + // return -1, err + //} + + // buffer += MAC_BYTES ??? + + return buffer.Len() - start.Len(), nil +} + +func ReadPacket(buffer *Buffer, buffer_length int, sequence uint64, read_packet_key []byte, protocol_id uint64, current_timestamp uint64, private_key []byte, allowed_packets []byte, replay_protection *ReplayProtection) (Packet, error) { + var packet Packet + sequence = 0 + if buffer_length < 1 { + return nil, errors.New("invalid buffer length") + } + + //start := NewBufferFromBytes(buffer.Buf.Bytes()) + + prefix_byte, err := buffer.GetUint8() + if err != nil { + return nil, errors.New("invalid buffer length") + } + + if PacketType(prefix_byte) == ConnectionRequest { + if allowed_packets[0] != 0 { + return nil, errors.New("ignored connection request packet. packet type is not allowed\n") + } + + if buffer_length != 1 + VERSION_INFO_BYTES + 8 + 8 + 8 + CONNECT_TOKEN_PRIVATE_BYTES { + return nil, errors.New("ignored connection request packet. bad packet length\n") + } + + if private_key == nil { + return nil, errors.New("ignored connection request packet. no private key\n") + } + + version_info, err := buffer.GetBytes(VERSION_INFO_BYTES) + if err != nil { + return nil, errors.New("ignored connection request packet. bad version info\n") + } + + if string(version_info) != VERSION_INFO { + return nil, errors.New("ignored connection request packet. bad version info\n") + } + + id, err := buffer.GetUint64() + if err != nil || id != protocol_id { + return nil, errors.New("ignored connection request packet. wrong protocol id\n") + } + + expire, err := buffer.GetUint64() + if err != nil || expire <= current_timestamp { + return nil, errors.New("ignored connection request packet. connect token expired\n") + } + + token_sequence, err := buffer.GetUint64() + if err != nil { + return nil, err + } + log.Print(token_sequence) + return packet, nil + } + return packet, nil +} \ No newline at end of file diff --git a/go/queue.go b/go/queue.go new file mode 100644 index 0000000..99ad84a --- /dev/null +++ b/go/queue.go @@ -0,0 +1 @@ +package netcode diff --git a/go/server.go b/go/server.go new file mode 100644 index 0000000..99ad84a --- /dev/null +++ b/go/server.go @@ -0,0 +1 @@ +package netcode diff --git a/go/simulator.go b/go/simulator.go new file mode 100644 index 0000000..99ad84a --- /dev/null +++ b/go/simulator.go @@ -0,0 +1 @@ +package netcode diff --git a/go/sizes.go b/go/sizes.go new file mode 100644 index 0000000..ad7a684 --- /dev/null +++ b/go/sizes.go @@ -0,0 +1,56 @@ +package netcode + +// Taken from https://raw.githubusercontent.com/google/flatbuffers/master/go/sizes.go +import ( + "unsafe" +) + +const ( + // See http://golang.org/ref/spec#Numeric_types + + // SizeUint8 is the byte size of a uint8. + SizeUint8 = 1 + // SizeUint16 is the byte size of a uint16. + SizeUint16 = 2 + // SizeUint32 is the byte size of a uint32. + SizeUint32 = 4 + // SizeUint64 is the byte size of a uint64. + SizeUint64 = 8 + + // SizeInt8 is the byte size of a int8. + SizeInt8 = 1 + // SizeInt16 is the byte size of a int16. + SizeInt16 = 2 + // SizeInt32 is the byte size of a int32. + SizeInt32 = 4 + // SizeInt64 is the byte size of a int64. + SizeInt64 = 8 + + // SizeFloat32 is the byte size of a float32. + SizeFloat32 = 4 + // SizeFloat64 is the byte size of a float64. + SizeFloat64 = 8 + + // SizeByte is the byte size of a byte. + // The `byte` type is aliased (by Go definition) to uint8. + SizeByte = 1 + + // SizeBool is the byte size of a bool. + // The `bool` type is aliased (by flatbuffers convention) to uint8. + SizeBool = 1 + + // SizeSOffsetT is the byte size of an SOffsetT. + // The `SOffsetT` type is aliased (by flatbuffers convention) to int32. + SizeSOffsetT = 4 + // SizeUOffsetT is the byte size of an UOffsetT. + // The `UOffsetT` type is aliased (by flatbuffers convention) to uint32. + SizeUOffsetT = 4 + // SizeVOffsetT is the byte size of an VOffsetT. + // The `VOffsetT` type is aliased (by flatbuffers convention) to uint16. + SizeVOffsetT = 2 +) + +// byteSliceToString converts a []byte to string without a heap allocation. +func byteSliceToString(b []byte) string { + return *(*string)(unsafe.Pointer(&b)) +} \ No newline at end of file diff --git a/go/socket.go b/go/socket.go new file mode 100644 index 0000000..029f679 --- /dev/null +++ b/go/socket.go @@ -0,0 +1,73 @@ +package netcode + +import ( + "net" +) + +const ( + SOCKET_ERROR_NONE = iota + SOCKET_ERROR_CREATE_FAILED + SOCKET_ERROR_SET_NON_BLOCKING_FAILED + SOCKET_ERROR_SOCKOPT_IPV6_ONLY_FAILED + SOCKET_ERROR_SOCKOPT_RCVBUF_FAILED + SOCKET_ERROR_SOCKOPT_SNDBUF_FAILED + SOCKET_ERROR_BIND_IPV4_FAILED + SOCKET_ERROR_BIND_IPV6_FAILED + SOCKET_ERROR_GET_SOCKNAME_IPV4_FAILED + SOCKET_ERROR_GET_SOCKNAME_IPV6_FAILED +) + +type Socket struct { + Address *net.UDPAddr + Conn *net.UDPConn +} + +func NewSocket() *Socket { + s := &Socket{} + return s +} + +func (s *Socket) Create(address *net.UDPAddr, sendsize, recvsize int) error { + conn, err := net.ListenUDP(address.Network(), address) + if err != nil { + return err + } + + if err := conn.SetReadBuffer(recvsize); err != nil { + return err + } + + if err := conn.SetWriteBuffer(sendsize); err != nil { + return err + } + + s.Conn = conn + return nil +} + +func (s *Socket) Send(destination *net.UDPAddr, data []byte) error { + if s.Conn == nil { + return nil + } + + length, err := s.Conn.WriteTo(data, destination) + if err != nil { + return err + } + + if length != len(data) { + // error writing all data + return nil + } + return nil +} + +func (s *Socket) Recv(source *net.Addr, data []byte, maxsize uint) error { + return nil +} + + +func (s *Socket) Destroy() { + s.Conn.Close() +} + From 21f865b915817eb83ded29b61d2a9e2e280bd1a2 Mon Sep 17 00:00:00 2001 From: wirepair Date: Sun, 2 Apr 2017 21:11:02 +0900 Subject: [PATCH 02/11] add some comments --- go/connect_token.go | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/go/connect_token.go b/go/connect_token.go index 66e71a6..d51a84a 100644 --- a/go/connect_token.go +++ b/go/connect_token.go @@ -73,6 +73,7 @@ func (token *ConnectToken) Encrypt(protocolId, sequence uint64, privateKey []byt return nil } +// Decrypts the tokendata and assigns it back to the backing buffer func (token *ConnectToken) Decrypt(protocolId, sequence uint64, privateKey []byte) error { var err error @@ -84,6 +85,7 @@ func (token *ConnectToken) Decrypt(protocolId, sequence uint64, privateKey []byt return nil } +// builds the additional data and nonce necessary for encryption and decryption. func buildCryptData(protocolId, expireTimestamp, sequence uint64) (*Buffer, *Buffer) { additionalData := NewBuffer(VERSION_INFO_BYTES+8+8) additionalData.WriteBytes([]byte(VERSION_INFO)) @@ -162,10 +164,9 @@ func ReadToken(tokenBuffer []byte) (*ConnectToken, error) { return nil, errors.New("too many servers") } - var i uint32 token.ServerAddresses = make([]net.UDPAddr, servers) - for i = 0; i < servers; i+=1 { + for i := 0; i < int(servers); i+=1 { serverType, err := buffer.GetUint8() if err != nil { return nil, err From d5db653b36c24defdc8f95eaffa22533497d8d30 Mon Sep 17 00:00:00 2001 From: wirepair Date: Sun, 2 Apr 2017 21:15:38 +0900 Subject: [PATCH 03/11] fix buffer check for errors actually check errors and return before attempting to access buffer values. --- go/buffer.go | 30 ++++++++++++++++++++++++------ 1 file changed, 24 insertions(+), 6 deletions(-) diff --git a/go/buffer.go b/go/buffer.go index b1f5cec..f39418e 100644 --- a/go/buffer.go +++ b/go/buffer.go @@ -65,26 +65,35 @@ func (b *Buffer) GetUint8() (uint8, error) { func (b *Buffer) GetUint16() (uint16, error) { var n uint16 buf, err := b.GetBytes(SizeUint16) + if err != nil { + return 0, nil + } n |= uint16(buf[0]) n |= uint16(buf[1]) << 8 - return n, err + return n, nil } // GetUint32 decodes a little-endian uint32 from the buffer func (b *Buffer) GetUint32() (uint32, error) { var n uint32 buf, err := b.GetBytes(SizeUint32) + if err != nil { + return 0, nil + } n |= uint32(buf[0]) n |= uint32(buf[1]) << 8 n |= uint32(buf[2]) << 16 n |= uint32(buf[3]) << 24 - return n, err + return n, nil } // GetUint64 decodes a little-endian uint64 from the buffer func (b *Buffer) GetUint64() (uint64, error) { var n uint64 buf, err := b.GetBytes(SizeUint64) + if err != nil { + return 0, nil + } n |= uint64(buf[0]) n |= uint64(buf[1]) << 8 n |= uint64(buf[2]) << 16 @@ -93,7 +102,7 @@ func (b *Buffer) GetUint64() (uint64, error) { n |= uint64(buf[5]) << 40 n |= uint64(buf[6]) << 48 n |= uint64(buf[7]) << 56 - return n, err + return n, nil } // GetInt8 decodes a little-endian int8 from the buffer @@ -109,26 +118,35 @@ func (b *Buffer) GetInt8() (int8, error) { func (b *Buffer) GetInt16() (int16, error) { var n int16 buf, err := b.GetBytes(SizeInt16) + if err != nil { + return 0, nil + } n |= int16(buf[0]) n |= int16(buf[1]) << 8 - return n, err + return n, nil } // GetInt32 decodes a little-endian int32 from the buffer func (b *Buffer) GetInt32() (int32, error) { var n int32 buf, err := b.GetBytes(SizeInt32) + if err != nil { + return 0, nil + } n |= int32(buf[0]) n |= int32(buf[1]) << 8 n |= int32(buf[2]) << 16 n |= int32(buf[3]) << 24 - return n, err + return n, nil } // GetInt64 decodes a little-endian int64 from the buffer func (b *Buffer) GetInt64() (int64, error) { var n int64 buf, err := b.GetBytes(SizeInt64) + if err != nil { + return 0, nil + } n |= int64(buf[0]) n |= int64(buf[1]) << 8 n |= int64(buf[2]) << 16 @@ -137,7 +155,7 @@ func (b *Buffer) GetInt64() (int64, error) { n |= int64(buf[5]) << 40 n |= int64(buf[6]) << 48 n |= int64(buf[7]) << 56 - return n, err + return n, nil } From a19a68668cb69109c4b3e9ae31f9d5b709fe8429 Mon Sep 17 00:00:00 2001 From: wirepair Date: Mon, 3 Apr 2017 20:53:30 +0900 Subject: [PATCH 04/11] implement challenge token and replay protection --- go/buffer.go | 18 +++++++-- go/challenge_token.go | 67 +++++++++++++++++++++++--------- go/challenge_token_test.go | 39 +++++++++++++++++++ go/client.go | 1 - go/crypto.go | 7 +++- go/packet.go | 75 ++++++++++-------------------------- go/packet_test.go | 9 +++++ go/queue.go | 43 +++++++++++++++++++++ go/replay_protection.go | 56 +++++++++++++++++++++++++++ go/replay_protection_test.go | 66 +++++++++++++++++++++++++++++++ 10 files changed, 303 insertions(+), 78 deletions(-) create mode 100644 go/packet_test.go create mode 100644 go/replay_protection.go create mode 100644 go/replay_protection_test.go diff --git a/go/buffer.go b/go/buffer.go index f39418e..c43864d 100644 --- a/go/buffer.go +++ b/go/buffer.go @@ -5,36 +5,48 @@ import ( "io" ) +// Buffer is a helper struct for serializing and deserializing as the caller +// does not need to externally manage where in the buffer they are currently reading +// or writing to. type Buffer struct { - Buf []byte - Pos int + Buf []byte // the backing byte slice + Pos int // current position in read/write } - +// Creates a new Buffer with a backing byte slice of the provided size func NewBuffer(size int) *Buffer { b := &Buffer{} b.Buf = make([]byte, size) return b } +// Creates a new buffer from a byte slice func NewBufferFromBytes(buf []byte) *Buffer { b := &Buffer{} b.Buf = buf return b } +// Returns a copy of Buffer func (b *Buffer) Copy() *Buffer { c := NewBufferFromBytes(b.Buf) return c } +// Gets the length of the backing byte slice func (b *Buffer) Len() int { return len(b.Buf) } +// Returns the backing byte slice func (b *Buffer) Bytes() []byte { return b.Buf } +// Resets the position back to beginning of buffer +func (b *Buffer) Reset() { + b.Pos = 0 +} + // GetByte decodes a little-endian byte func (b *Buffer) GetByte() (byte, error) { return b.GetUint8() diff --git a/go/challenge_token.go b/go/challenge_token.go index 67dd485..835bc85 100644 --- a/go/challenge_token.go +++ b/go/challenge_token.go @@ -1,36 +1,67 @@ package netcode - -type Token interface { - Encrypt() - Decrypt() - Read(buffer []byte, length uint) - Write(buffer []byte, length uint) -} - +// Challenge tokens are used in certain packet types type ChallengeToken struct { - ClientId uint64 - UserData [USER_DATA_BYTES]byte + ClientId uint64 // the clientId associated with this token + UserData *Buffer // the userdata payload + TokenData *Buffer // the serialized payload container } -func NewChallengeToken() *ChallengeToken { +// Creates a new empty challenge token with only the clientId set +func NewChallengeToken(clientId uint64) *ChallengeToken { token := &ChallengeToken{} + token.ClientId = clientId + token.UserData = NewBuffer(USER_DATA_BYTES) + token.TokenData = NewBuffer(CHALLENGE_TOKEN_BYTES) return token } -func (t *ChallengeToken) Encrypt(buffer []byte, length uint, sequence uint64, key []byte) error { - return nil +// Encrypts the TokenData buffer with the sequence nonce and provided key +func (t *ChallengeToken) Encrypt(sequence uint64, key []byte) error { + nonce := NewBuffer(SizeUint64) + nonce.WriteUint64(sequence) + + return EncryptAead(&t.TokenData.Buf, nil, nonce.Bytes(), key) } -func (t *ChallengeToken) Decrypt(buffer []byte, length uint, sequence uint64, key []byte) error { - return nil +// Decrypts the TokenData buffer with the sequence nonce and provided key, updating the +// internal TokenData buffer +func (t *ChallengeToken) Decrypt(sequence uint64, key []byte) error { + var err error + nonce := NewBuffer(SizeUint64) + nonce.WriteUint64(sequence) + t.TokenData.Buf, err = DecryptAead(t.TokenData.Buf, nil, nonce.Bytes(), key) + return err } -func (t *ChallengeToken) Read(buffer []byte, length uint) { +// Serializes the client id and userData, also sets the UserData buffer. +func (t *ChallengeToken) Write(userData []byte) { + t.UserData.WriteBytes(userData) + t.TokenData.WriteUint64(t.ClientId) + t.TokenData.WriteBytes(userData) } -func (t *ChallengeToken) Write(buffer []byte, length uint) { +// Generates a new ChallengeToken from the provided buffer byte slice. Only sets the ClientId +// and UserData buffer, does not update the TokenData buffer. +func ReadChallengeToken(buffer []byte) (*ChallengeToken, error) { + var err error + var clientId uint64 + var userData []byte + tokenBuffer := NewBufferFromBytes(buffer) -} + clientId, err = tokenBuffer.GetUint64() + if err != nil { + return nil, err + } + token := NewChallengeToken(clientId) + + userData, err = tokenBuffer.GetBytes(USER_DATA_BYTES) + if err != nil { + return nil, err + } + token.UserData.WriteBytes(userData) + token.UserData.Reset() + return token, nil +} diff --git a/go/challenge_token_test.go b/go/challenge_token_test.go index 96864fb..0295cd6 100644 --- a/go/challenge_token_test.go +++ b/go/challenge_token_test.go @@ -2,8 +2,47 @@ package netcode import ( "testing" + "bytes" ) func TestNewChallengeToken(t *testing.T) { + var err error + var userData []byte + + token := NewChallengeToken(TEST_CLIENT_ID) + if userData, err = RandomBytes(USER_DATA_BYTES); err != nil { + t.Fatalf("error generating random data\n") + } + token.Write(userData) + + var sequence uint64 + sequence = 1000 + key, err := GenerateKey() + if err != nil { + t.Fatalf("error generating key\n") + } + + if err := token.Encrypt(sequence, key); err != nil { + t.Fatalf("error encrypting challenge token: %s\n", err) + } + + if err := token.Decrypt(sequence, key); err != nil { + t.Fatalf("error decrypting challenge token: %s\n", err) + } + + newToken, err := ReadChallengeToken(token.TokenData.Buf) + if err != nil { + t.Fatalf("error reading token data %s\n", err) + } + + if newToken.ClientId != token.ClientId { + t.Fatalf("token client id did not match, expected %d got %d\n", token.ClientId, newToken.ClientId) + } + + if bytes.Compare(newToken.UserData.Buf, token.UserData.Buf) != 0 { + t.Fatalf("user data did not match expected\n %#v\ngot\n%#v!", token.UserData.Buf, newToken.UserData.Buf) + } + + } \ No newline at end of file diff --git a/go/client.go b/go/client.go index 80e9fd0..68a10b5 100644 --- a/go/client.go +++ b/go/client.go @@ -12,7 +12,6 @@ type Client struct { func NewClient(config *Config) *Client { c := &Client{config: config} - return c } diff --git a/go/crypto.go b/go/crypto.go index c37fd57..16efe42 100644 --- a/go/crypto.go +++ b/go/crypto.go @@ -6,17 +6,19 @@ import ( "log" ) +// Generates random bytes func RandomBytes(bytes int) ([]byte, error) { b := make([]byte, bytes) _, err := rand.Read(b) return b, err } +// Generates a random key of KEY_BYTES func GenerateKey() ([]byte, error) { return RandomBytes(KEY_BYTES) } - +// Encrypts the message in place with the nonce and key and optional additional buffer func EncryptAead(message *[]byte, additional []byte, nonce, key []byte) error { aead, err := chacha20poly1305.New(key) if err != nil { @@ -24,9 +26,12 @@ func EncryptAead(message *[]byte, additional []byte, nonce, key []byte) error { } log.Printf("before seal: %#v\n", message) *message = aead.Seal(nil, nonce, *message, additional) + log.Printf("after seal: %#v\n", message) return nil } +// Encrypts the message with the nonce and key and optional additional buffer returning a copy +// byte slice func DecryptAead(message []byte, additional []byte, nonce, key []byte) ([]byte, error) { aead, err := chacha20poly1305.New(key) diff --git a/go/packet.go b/go/packet.go index 17edb71..7ed74ef 100644 --- a/go/packet.go +++ b/go/packet.go @@ -15,7 +15,6 @@ const USER_DATA_BYTES = 256 const MAX_PACKET_BYTES = 1220 const MAX_PAYLOAD_BYTES = 1200 const MAX_ADDRESS_STRING_LENGTH = 256 -const PACKET_QUEUE_SIZE = 256 const REPLAY_PROTECTION_BUFFER_SIZE = 256 const CLIENT_MAX_RECEIVE_PACKETS = 64 const SERVER_MAX_RECEIVE_PACKETS = ( 64 * MAX_CLIENTS ) @@ -109,8 +108,11 @@ func (p *PayloadPacket) GetType() PacketType { return ConnectionPayload } -func NewPayloadPacket(payload_bytes uint32) *PayloadPacket { - return &PayloadPacket{Type: ConnectionPayload, PayloadBytes: payload_bytes} +func NewPayloadPacket(payloadBytes uint32) *PayloadPacket { + packet := &PayloadPacket{Type: ConnectionPayload} + packet.PayloadBytes = payloadBytes + packet.PayloadData = make([]byte, payloadBytes) + return packet } type DisconnectPacket struct { @@ -122,57 +124,6 @@ type Context struct { ReadPacketKey []byte } -type ReplayProtection struct { - MostRecentSequence uint64 - ReceivedPacket []uint64 -} - -func (r *ReplayProtection) Reset() { - r.MostRecentSequence = 0 - //MemsetUint64(r.ReceivedPacket, 0xFF) -} - -func (r *ReplayProtection) AlreadyReceived(sequence uint64) int { - if (sequence & 1 << 63) == 0 { - return 0 - } - - if sequence + REPLAY_PROTECTION_BUFFER_SIZE <= r.MostRecentSequence { - return 1 - } - - if sequence > r.MostRecentSequence { - r.MostRecentSequence = sequence - } - - index := ( sequence % REPLAY_PROTECTION_BUFFER_SIZE) - - if r.ReceivedPacket[index] == 0xFFFFFFFFFFFFFFFF { - r.ReceivedPacket[index] = sequence - return 0 - } - - if r.ReceivedPacket[index] >= sequence { - return 1 - } - - r.ReceivedPacket[index] = sequence - return 0 -} - -func SequenceNumberBytesRequired(sequence uint64) int { - var mask uint64 - mask = 0xFF00000000000000 - i := 0 - for ; i < 7; i+=1 { - if (sequence & mask == 0) { - break - } - mask >>= 8 - } - return 8 - i -} - func WritePacket(packet Packet, buffer *Buffer, buffer_length uint, sequence uint64, write_packet_key []byte, protocol_id uint64) (int, error) { var p Packet var start *Buffer @@ -199,7 +150,7 @@ func WritePacket(packet Packet, buffer *Buffer, buffer_length uint, sequence uin // write the prefix byte (this is a combination of the packet type and number of sequence bytes) start = NewBufferFromBytes(buffer.Bytes()) - sequence_bytes := SequenceNumberBytesRequired(sequence) + sequence_bytes := sequenceNumberBytesRequired(sequence) prefix_byte := uint8(p.GetType()) | uint8(sequence_bytes << 4) buffer.WriteUint8(prefix_byte) @@ -321,4 +272,18 @@ func ReadPacket(buffer *Buffer, buffer_length int, sequence uint64, read_packet_ return packet, nil } return packet, nil +} + + +func sequenceNumberBytesRequired(sequence uint64) int { + var mask uint64 + mask = 0xFF00000000000000 + i := 0 + for ; i < 7; i+=1 { + if (sequence & mask == 0) { + break + } + mask >>= 8 + } + return 8 - i } \ No newline at end of file diff --git a/go/packet_test.go b/go/packet_test.go new file mode 100644 index 0000000..13e5b95 --- /dev/null +++ b/go/packet_test.go @@ -0,0 +1,9 @@ +package netcode + +import ( + "testing" +) + +func TestReadPacket(t *testing.T) { + +} \ No newline at end of file diff --git a/go/queue.go b/go/queue.go index 99ad84a..6dfc01a 100644 --- a/go/queue.go +++ b/go/queue.go @@ -1 +1,44 @@ package netcode + +const PACKET_QUEUE_SIZE = 256 + +type Queue struct { + NumPackets int + StartIndex int + Packets []Packet +} + +func NewQueue() *Queue { + q := &Queue{} + q.Packets = make([]Packet, PACKET_QUEUE_SIZE) + return q +} + +func (q *Queue) Clear() { + q.NumPackets = 0 + q.StartIndex = 0 + q.Packets = make([]Packet, PACKET_QUEUE_SIZE) +} + +func (q *Queue) Push(packet Packet) int { + if q.NumPackets == PACKET_QUEUE_SIZE { + return 0 + } + + index := (q.StartIndex + q.NumPackets) % PACKET_QUEUE_SIZE + q.Packets[index] = packet + q.NumPackets++ + return 1 +} + +func (q *Queue) Pop() Packet { + if q.NumPackets == 0 { + return nil + } + + packet := q.Packets[q.StartIndex] + q.StartIndex = ( q.StartIndex + 1 ) % PACKET_QUEUE_SIZE + q.NumPackets-- + return packet +} + diff --git a/go/replay_protection.go b/go/replay_protection.go new file mode 100644 index 0000000..f07fc2f --- /dev/null +++ b/go/replay_protection.go @@ -0,0 +1,56 @@ +package netcode + +// Our type to hold replay protection of packet sequences +type ReplayProtection struct { + MostRecentSequence uint64 // last sequence recv'd + ReceivedPacket []uint64 // slice of REPLAY_PROTECTION_BUFFER_SIZE worth of packet sequences +} + +// Initializes a new ReplayProtection with the ReceivedPacket buffer elements all set to 0xFFFFFFFFFFFFFFFF +func NewReplayProtection() *ReplayProtection { + r := &ReplayProtection{} + r.ReceivedPacket = make([]uint64, REPLAY_PROTECTION_BUFFER_SIZE) + r.Reset() + return r +} + +// Clears out the most recent sequence and resets the entire packet buffer to 0xFFFFFFFFFFFFFFFF +func (r *ReplayProtection) Reset() { + r.MostRecentSequence = 0 + clearPacketBuffer(r.ReceivedPacket) +} + +// Tests that the sequence has not already been recv'd, adding it to the buffer if it's new. +func (r *ReplayProtection) AlreadyReceived(sequence uint64) int { + if sequence & (uint64(1 << 63)) != 0 { + return 0 + } + + if sequence + REPLAY_PROTECTION_BUFFER_SIZE <= r.MostRecentSequence { + return 1 + } + + if sequence > r.MostRecentSequence { + r.MostRecentSequence = sequence + } + + index := sequence % REPLAY_PROTECTION_BUFFER_SIZE + + if r.ReceivedPacket[index] == 0xFFFFFFFFFFFFFFFF { + r.ReceivedPacket[index] = sequence + return 0 + } + + if r.ReceivedPacket[index] >= sequence { + return 1 + } + + r.ReceivedPacket[index] = sequence + return 0 +} + +func clearPacketBuffer(packets []uint64) { + for i := 0; i < len(packets); i+=1 { + packets[i] = 0xFFFFFFFFFFFFFFFF + } +} \ No newline at end of file diff --git a/go/replay_protection_test.go b/go/replay_protection_test.go new file mode 100644 index 0000000..dca619f --- /dev/null +++ b/go/replay_protection_test.go @@ -0,0 +1,66 @@ +package netcode + +import ( + "testing" +) + +func TestReplayProtection_Reset(t *testing.T) { + r := NewReplayProtection() + for _, p := range r.ReceivedPacket { + if p != 0xFFFFFFFFFFFFFFFF { + t.Fatalf("packet was not reset") + } + } +} + +func TestReplayProtection(t *testing.T) { + r := NewReplayProtection() + for i := 0; i < 2; i+=1 { + r.Reset() + if r.MostRecentSequence != 0 { + t.Fatalf("sequence was not 0") + } + + // sequence numbers with high bit set should be ignored + sequence := uint64(1 << 63) + if r.AlreadyReceived(sequence) != 0 { + t.Fatalf("sequence numbers with high bit set should be ignored") + } + + if r.MostRecentSequence != 0 { + t.Fatalf("sequence was not 0 after high-bit check got: 0x%x\n", r.MostRecentSequence) + } + + // the first time we receive packets, they should not be already received + maxSequence := uint64(REPLAY_PROTECTION_BUFFER_SIZE * 4) + for sequence = 0; sequence < maxSequence; sequence+=1 { + if r.AlreadyReceived(sequence) != 0 { + t.Fatalf("the first time we receive packets, they should not be already received") + } + } + + // old packets outside buffer should be considered already received + if r.AlreadyReceived(0) != 1 { + t.Fatalf("old packets outside buffer should be considered already received") + } + + // packets received a second time should be flagged already received + for sequence = maxSequence - 10; sequence < maxSequence; sequence+=1 { + if r.AlreadyReceived(sequence) != 1 { + t.Fatalf("packets received a second time should be flagged already received") + } + } + + // jumping ahead to a much higher sequence should be considered not already received + if r.AlreadyReceived(maxSequence+REPLAY_PROTECTION_BUFFER_SIZE) != 0 { + t.Fatalf("jumping ahead to a much higher sequence should be considered not already received") + } + + // old packets should be considered already received + for sequence = 0; sequence < maxSequence; sequence+=1 { + if r.AlreadyReceived(sequence) != 1 { + t.Fatalf("old packets should be considered already received") + } + } + } +} \ No newline at end of file From b81bbc009a033457d85c088a7308afd54c176a4c Mon Sep 17 00:00:00 2001 From: wirepair Date: Wed, 5 Apr 2017 08:08:42 +0900 Subject: [PATCH 05/11] implement first round of packet and fix to connect token --- go/client.go | 4 +- go/connect_token.go | 118 ++++++++++----- go/connect_token_test.go | 45 ++++-- go/crypto.go | 11 +- go/packet.go | 315 +++++++++++++++++++++++++++++++-------- go/packet_test.go | 89 ++++++++++- 6 files changed, 456 insertions(+), 126 deletions(-) diff --git a/go/client.go b/go/client.go index 68a10b5..31e4e7e 100644 --- a/go/client.go +++ b/go/client.go @@ -3,6 +3,7 @@ package netcode import ( "crypto/rand" "math/big" + "time" ) type Client struct { @@ -22,9 +23,10 @@ func (c *Client) Init(sequence uint64) error { } c.Id = id.Uint64() + currentTimestamp := uint64(time.Now().Unix()) token := NewConnectToken() - if err := token.Generate(c.config, sequence, c.Id); err != nil { + if err := token.Generate(c.config, c.Id, currentTimestamp, sequence); err != nil { return err } diff --git a/go/connect_token.go b/go/connect_token.go index d51a84a..897c6cf 100644 --- a/go/connect_token.go +++ b/go/connect_token.go @@ -4,7 +4,6 @@ import ( "net" "errors" "strconv" - "time" "log" ) @@ -17,13 +16,22 @@ const ( // Token used for connecting type ConnectToken struct { - ClientId uint64 // client identifier - ServerAddresses []net.UDPAddr // list of server addresses this client may connect to + VersionInfo []byte + ProtocolId uint64 + CreateTimestamp uint64 + ExpireTimestamp uint64 + Sequence uint64 + PrivateData *ConnectTokenPrivate + TimeoutSeconds int +} + +type ConnectTokenPrivate struct { + ClientId uint64 + ServerAddrs []net.UDPAddr // list of server addresses this client may connect to ClientKey []byte // client to server key ServerKey []byte // server to client key UserData []byte // user data - ExpireTimestamp uint64 - TokenData *Buffer // connect token data + TokenData *Buffer // used to store the serialized buffer } // create a new empty token @@ -32,31 +40,51 @@ func NewConnectToken() *ConnectToken { return token } -// Generates the token with the supplied configuration values and clientId. -func (token *ConnectToken) Generate(config *Config, sequence, clientId uint64) error { +func (token *ConnectToken) ServerKey() []byte { + return token.PrivateData.ServerKey +} + +func (token *ConnectToken) ClientKey() []byte { + return token.PrivateData.ClientKey +} + +// list of server addresses this client may connect to +func (token *ConnectToken) ServerAddresses() []net.UDPAddr { + return token.PrivateData.ServerAddrs +} + +func (token *ConnectToken) ClientId() uint64 { + return token.PrivateData.ClientId +} + +// Generates the token with the supplied configuration values +func (token *ConnectToken) Generate(config *Config, clientId, currentTimestamp, sequence uint64) error { var err error - token.ClientId = clientId - token.ServerAddresses = config.ServerAddrs + privateData := &ConnectTokenPrivate{} + token.PrivateData = privateData - if token.UserData, err = RandomBytes(USER_DATA_BYTES); err != nil { + privateData.ClientId = clientId + privateData.ServerAddrs = config.ServerAddrs + + if privateData.UserData, err = RandomBytes(USER_DATA_BYTES); err != nil { return err } - if token.ClientKey, err = GenerateKey(); err != nil { + if privateData.ClientKey, err = GenerateKey(); err != nil { return err } - if token.ServerKey, err = GenerateKey(); err != nil { + if privateData.ServerKey, err = GenerateKey(); err != nil { return err } - if token.TokenData, err = WriteToken(token); err != nil { + if privateData.TokenData, err = WriteConnectToken(token); err != nil { return err } - creationTime := time.Now().Unix() - token.ExpireTimestamp = uint64(creationTime) + config.TokenExpiry + token.CreateTimestamp = currentTimestamp + token.ExpireTimestamp = token.CreateTimestamp + config.TokenExpiry token.Encrypt(config.ProtocolId, sequence, config.PrivateKey) return nil @@ -65,11 +93,10 @@ func (token *ConnectToken) Generate(config *Config, sequence, clientId uint64) e // Encrypts the token.TokenData func (token *ConnectToken) Encrypt(protocolId, sequence uint64, privateKey []byte) error { additionalData, nonce := buildCryptData(protocolId, token.ExpireTimestamp, sequence) - - if err := EncryptAead(&token.TokenData.Buf, additionalData.Bytes(), nonce.Bytes(), privateKey); err != nil { + if err := EncryptAead(&token.PrivateData.TokenData.Buf, additionalData.Bytes(), nonce.Bytes(), privateKey); err != nil { return err } - log.Printf("after encrypt: %#v\n", token.TokenData.Bytes()) + log.Printf("after encrypt: %#v\n", token.PrivateData.TokenData) return nil } @@ -78,8 +105,7 @@ func (token *ConnectToken) Decrypt(protocolId, sequence uint64, privateKey []byt var err error additionalData, nonce := buildCryptData(protocolId, token.ExpireTimestamp, sequence) - - if token.TokenData.Buf, err = DecryptAead(token.TokenData.Bytes(), additionalData.Bytes(), nonce.Bytes(), privateKey); err != nil { + if token.PrivateData.TokenData.Buf, err = DecryptAead(token.PrivateData.TokenData.Bytes(), additionalData.Bytes(), nonce.Bytes(), privateKey); err != nil { return err } return nil @@ -89,6 +115,7 @@ func (token *ConnectToken) Decrypt(protocolId, sequence uint64, privateKey []byt func buildCryptData(protocolId, expireTimestamp, sequence uint64) (*Buffer, *Buffer) { additionalData := NewBuffer(VERSION_INFO_BYTES+8+8) additionalData.WriteBytes([]byte(VERSION_INFO)) + log.Printf("buildCryptData %x %x\n", protocolId, expireTimestamp) additionalData.WriteUint64(protocolId) additionalData.WriteUint64(expireTimestamp) @@ -99,12 +126,12 @@ func buildCryptData(protocolId, expireTimestamp, sequence uint64) (*Buffer, *Buf } // Writes the token data to the TokenData buffer and returns to caller -func WriteToken(token *ConnectToken) (*Buffer, error) { +func WriteConnectToken(token *ConnectToken) (*Buffer, error) { data := NewBuffer(CONNECT_TOKEN_PRIVATE_BYTES) - data.WriteUint64(token.ClientId) - data.WriteUint32(uint32(len(token.ServerAddresses))) + data.WriteUint64(token.PrivateData.ClientId) + data.WriteUint32(uint32(len(token.PrivateData.ServerAddrs))) - for _, addr := range token.ServerAddresses { + for _, addr := range token.ServerAddresses() { host, port, err := net.SplitHostPort(addr.String()) if err != nil { return nil, errors.New("invalid port for host: " + addr.String()) @@ -132,26 +159,34 @@ func WriteToken(token *ConnectToken) (*Buffer, error) { } data.WriteUint16(uint16(p)) } - data.WriteBytesN(token.ClientKey, KEY_BYTES) - data.WriteBytesN(token.ServerKey, KEY_BYTES) - data.WriteBytesN(token.UserData, USER_DATA_BYTES) + data.WriteBytesN(token.PrivateData.ClientKey, KEY_BYTES) + data.WriteBytesN(token.PrivateData.ServerKey, KEY_BYTES) + data.WriteBytesN(token.PrivateData.UserData, USER_DATA_BYTES) return data, nil } -// Takes in a slice of bytes and generates a new ConnectToken. -func ReadToken(tokenBuffer []byte) (*ConnectToken, error) { +// Takes in a slice of bytes and generates a new ConnectToken after decryption. +func ReadConnectToken(tokenBuffer []byte, protocolId, expireTimestamp, sequence uint64, privateKey []byte) (*ConnectToken, error) { var err error var servers uint32 var ipBytes []byte token := NewConnectToken() - buffer := NewBufferFromBytes(tokenBuffer) + token.PrivateData = &ConnectTokenPrivate{} + token.ExpireTimestamp = expireTimestamp + + token.PrivateData.TokenData = NewBufferFromBytes(tokenBuffer) + if err := token.Decrypt(protocolId, sequence, privateKey); err != nil { + return nil, errors.New("error decrypting connection token: " + err.Error()) + } - if token.ClientId, err = buffer.GetUint64(); err != nil { + if token.PrivateData.ClientId, err = token.PrivateData.TokenData.GetUint64(); err != nil { return nil, err } - servers, err = buffer.GetUint32() + log.Printf("clientid: %x\n", token.PrivateData.ClientId) + + servers, err = token.PrivateData.TokenData.GetUint32() if err != nil { return nil, err } @@ -161,21 +196,22 @@ func ReadToken(tokenBuffer []byte) (*ConnectToken, error) { } if servers > MAX_SERVERS_PER_CONNECT { + log.Printf("got %d expected %d\n", servers, MAX_SERVERS_PER_CONNECT) return nil, errors.New("too many servers") } - token.ServerAddresses = make([]net.UDPAddr, servers) + token.PrivateData.ServerAddrs = make([]net.UDPAddr, servers) for i := 0; i < int(servers); i+=1 { - serverType, err := buffer.GetUint8() + serverType, err := token.PrivateData.TokenData.GetUint8() if err != nil { return nil, err } if serverType == ADDRESS_IPV4 { - ipBytes, err = buffer.GetBytes(4) + ipBytes, err = token.PrivateData.TokenData.GetBytes(4) } else if serverType == ADDRESS_IPV6 { - ipBytes, err = buffer.GetBytes(16) + ipBytes, err = token.PrivateData.TokenData.GetBytes(16) } else { return nil, errors.New("unknown ip address") } @@ -185,22 +221,22 @@ func ReadToken(tokenBuffer []byte) (*ConnectToken, error) { } ip := net.IP(ipBytes) - port, err := buffer.GetUint16() + port, err := token.PrivateData.TokenData.GetUint16() if err != nil { return nil, errors.New("invalid port") } - token.ServerAddresses[i] = net.UDPAddr{IP: ip, Port: int(port)} + token.PrivateData.ServerAddrs[i] = net.UDPAddr{IP: ip, Port: int(port)} } - if token.ClientKey, err = buffer.GetBytes(KEY_BYTES); err != nil { + if token.PrivateData.ClientKey, err = token.PrivateData.TokenData.GetBytes(KEY_BYTES); err != nil { return nil, errors.New("error reading client to server key") } - if token.ServerKey, err = buffer.GetBytes(KEY_BYTES); err != nil { + if token.PrivateData.ServerKey, err = token.PrivateData.TokenData.GetBytes(KEY_BYTES); err != nil { return nil, errors.New("error reading server to client key") } - if token.UserData, err = buffer.GetBytes(USER_DATA_BYTES); err != nil { + if token.PrivateData.UserData, err = token.PrivateData.TokenData.GetBytes(USER_DATA_BYTES); err != nil { return nil, errors.New("error reading user data") } diff --git a/go/connect_token_test.go b/go/connect_token_test.go index 1bb91f6..5964742 100644 --- a/go/connect_token_test.go +++ b/go/connect_token_test.go @@ -3,7 +3,8 @@ package netcode import ( "testing" "net" - "bytes" + //"bytes" + "time" ) const ( @@ -24,38 +25,52 @@ func TestNewConnectToken(t *testing.T) { server := net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 40000} servers := make([]net.UDPAddr, 1) servers[0] = server + config := NewConfig(servers, TEST_CONNECT_TOKEN_EXPIRY, TEST_PROTOCOL_ID, TEST_PRIVATE_KEY) - err := token.Generate(config, TEST_SEQUENCE_START, TEST_CLIENT_ID) + currentTimestamp := uint64(time.Now().Unix()) + + err := token.Generate(config, TEST_CLIENT_ID, currentTimestamp, TEST_SEQUENCE_START) if err != nil { - t.Fatalf("error generating token") + t.Fatalf("error generating and encrypting token") } - err = token.Decrypt(config.ProtocolId, TEST_SEQUENCE_START, config.PrivateKey) + _, err = ReadConnectToken(token.PrivateData.TokenData.Bytes(), config.ProtocolId, currentTimestamp+config.TokenExpiry, TEST_SEQUENCE_START, config.PrivateKey) if err != nil { - t.Fatalf("error decrypting token: %s\n", err) + t.Fatalf("error reading connect token %s", err) } - token2, err := ReadToken(token.TokenData.Buf) + //err = token.Decrypt(config.ProtocolId, TEST_SEQUENCE_START, config.PrivateKey) + //if err != nil { + // t.Fatalf("error decrypting: %s\n", err) + //} + + //err = token.Decrypt(config.ProtocolId, TEST_SEQUENCE_START, config.PrivateKey) + //if err != nil { + // t.Fatalf("error decrypting token: %s\n", err) + //} + - if token.ClientId != token2.ClientId { + /* + if token.ClientId() != token2.ClientId() { t.Fatalf("clientIds do not match expected %d got %d", token.ClientId, token2.ClientId) } - if len(token.ServerAddresses) != len(token2.ServerAddresses) { - t.Fatalf("time stamps do not match expected %d got %d", len(token.ServerAddresses), len(token2.ServerAddresses)) + if len(token.ServerAddresses()) != len(token2.ServerAddresses()) { + t.Fatalf("time stamps do not match expected %d got %d", len(token.ServerAddresses()), len(token2.ServerAddresses())) } // TODO verify server addresses - if bytes.Compare(token.ClientKey, token2.ClientKey) != 0 { - t.Fatalf("ClientKey do not match expected %v got %v", token.ClientKey, token2.ClientKey) + if bytes.Compare(token.ClientKey(), token2.ClientKey()) != 0 { + t.Fatalf("ClientKey do not match expected %v got %v", token.ClientKey(), token2.ClientKey()) } - if bytes.Compare(token.ServerKey, token2.ServerKey) != 0 { - t.Fatalf("ServerKey do not match expected %v got %v", token.ServerKey, token2.ServerKey) + if bytes.Compare(token.ServerKey(), token2.ServerKey()) != 0 { + t.Fatalf("ServerKey do not match expected %v got %v", token.ServerKey(), token2.ServerKey()) } - if bytes.Compare(token.UserData, token2.UserData) != 0 { - t.Fatalf("UserData do not match expected %v got %v", token.UserData, token2.UserData) + if bytes.Compare(token.PrivateData.UserData, token2.PrivateData.UserData) != 0 { + t.Fatalf("UserData do not match expected %v got %v", token.PrivateData.UserData, token2.PrivateData.UserData) } + */ } \ No newline at end of file diff --git a/go/crypto.go b/go/crypto.go index 16efe42..a28f00b 100644 --- a/go/crypto.go +++ b/go/crypto.go @@ -19,22 +19,21 @@ func GenerateKey() ([]byte, error) { } // Encrypts the message in place with the nonce and key and optional additional buffer -func EncryptAead(message *[]byte, additional []byte, nonce, key []byte) error { +func EncryptAead(message *[]byte, additional, nonce, key []byte) error { aead, err := chacha20poly1305.New(key) if err != nil { return err } - log.Printf("before seal: %#v\n", message) + log.Printf("before encrypt len: %d\n", len(*message)) *message = aead.Seal(nil, nonce, *message, additional) - log.Printf("after seal: %#v\n", message) + return nil } -// Encrypts the message with the nonce and key and optional additional buffer returning a copy +// Decrypts the message with the nonce and key and optional additional buffer returning a copy // byte slice -func DecryptAead(message []byte, additional []byte, nonce, key []byte) ([]byte, error) { +func DecryptAead(message []byte, additional, nonce, key []byte) ([]byte, error) { aead, err := chacha20poly1305.New(key) - if err != nil { return nil, err } diff --git a/go/packet.go b/go/packet.go index 7ed74ef..923bb25 100644 --- a/go/packet.go +++ b/go/packet.go @@ -1,8 +1,9 @@ package netcode import ( - "log" "errors" + "strconv" + "log" ) type PacketType uint8 @@ -24,7 +25,7 @@ const MAC_BYTES = 16 const NONCE_BYTES = 8 const MAX_SERVERS_PER_CONNECT = 32 -const VERSION_INFO = "NETCODE 1.00" +const VERSION_INFO = "NETCODE 1.00\x00" const PACKET_SEND_RATE = 10.0 const TIMEOUT_SECONDS = 5.0 const NUM_DISCONNECT_PACKETS = 10 @@ -41,6 +42,17 @@ const ( ConnectionNumPackets ) +var packetTypeMap = map[PacketType]string { + ConnectionRequest: "CONNECTION_REQUEST", + ConnectionDenied: "CONNECTION_DENIED", + ConnectionChallenge: "CONNECTION_CHALLENGE", + ConnectionResponse: "CONNECTION_RESPONSE", + ConnectionKeepAlive: "CONNECTION_KEEPALIVE", + ConnectionPayload: "CONNECTION_PAYLOAD", + ConnectionDisconnect: "CONNECTION_DISCONNECT", + ConnectionNumPackets: "CONNECTION_NUMPACKETS", +} + type Packet interface { GetType() PacketType } @@ -51,6 +63,7 @@ type RequestPacket struct { ProtocolId uint64 ConnectTokenExpireTimestamp uint64 ConnectTokenSequence uint64 + Token *ConnectToken ConnectTokenData []byte } @@ -59,7 +72,6 @@ func (p *RequestPacket) GetType() PacketType { } type DeniedPacket struct { - Type PacketType } func (p *DeniedPacket) GetType() PacketType { @@ -88,8 +100,8 @@ func (p *ResponsePacket) GetType() PacketType { type KeepAlivePacket struct { Type PacketType - ClientIndex uint - MaxClients uint + ClientIndex uint32 + MaxClients uint32 } func (p *KeepAlivePacket) GetType() PacketType { @@ -119,52 +131,59 @@ type DisconnectPacket struct { Type PacketType } +func (p *DisconnectPacket) GetType() PacketType { + return ConnectionDisconnect +} + type Context struct { WritePacketKey []byte ReadPacketKey []byte } -func WritePacket(packet Packet, buffer *Buffer, buffer_length uint, sequence uint64, write_packet_key []byte, protocol_id uint64) (int, error) { +func WritePacket(packet Packet, buffer *Buffer, sequence uint64, writePacketKey []byte, protocolId uint64) (int, error) { var p Packet - var start *Buffer packetType := packet.GetType() if packetType == ConnectionRequest { - + // connection request packet: first byte is zero p, ok := packet.(*RequestPacket) if !ok { - return -1, nil + return -1, errors.New("invalid packet type, expecting request packet") } - start = NewBufferFromBytes(buffer.Bytes()) buffer.WriteUint8(uint8(ConnectionRequest)) - buffer.WriteBytesN(p.VersionInfo, VERSION_INFO_BYTES) + buffer.WriteBytes(p.VersionInfo) buffer.WriteUint64(p.ProtocolId) buffer.WriteUint64(p.ConnectTokenExpireTimestamp) buffer.WriteUint64(p.ConnectTokenSequence) buffer.WriteBytesN(p.ConnectTokenData, CONNECT_TOKEN_PRIVATE_BYTES) - return buffer.Len() - start.Len(), nil + if buffer.Pos != 1 + 13 + 8 + 8 + 8 + CONNECT_TOKEN_PRIVATE_BYTES { + return -1, errors.New("invalid buffer size") + } + return buffer.Pos, nil } // *** encrypted packets *** // write the prefix byte (this is a combination of the packet type and number of sequence bytes) - start = NewBufferFromBytes(buffer.Bytes()) - sequence_bytes := sequenceNumberBytesRequired(sequence) + sequenceBytes := sequenceNumberBytesRequired(sequence) + if (sequenceBytes < 1 || sequenceBytes > 8) { + return -1, errors.New("invalid sequence bytes, must be between [1-8]") + } - prefix_byte := uint8(p.GetType()) | uint8(sequence_bytes << 4) - buffer.WriteUint8(prefix_byte) + prefixByte := uint8(p.GetType()) | uint8(sequenceBytes << 4) + buffer.WriteUint8(prefixByte) - sequence_temp := sequence + sequenceTemp := sequence - for i := 0; i < sequence_bytes; i+=1 { - buffer.WriteUint8(uint8(sequence_temp & 0xFF)) - sequence_temp >>= 8 + for i := 0; i < sequenceBytes; i+=1 { + buffer.WriteUint8(uint8(sequenceTemp & 0xFF)) + sequenceTemp >>= 8 } - //encrypted_start := NewBufferFromBytes(buffer.Buf.Bytes()) - - switch (p.GetType()) { + encryptedStart := buffer.Pos + // write packet data according to type. this data will be encrypted. + switch p.GetType() { case ConnectionDenied: // ... case ConnectionChallenge: @@ -197,80 +216,252 @@ func WritePacket(packet Packet, buffer *Buffer, buffer_length uint, sequence uin case ConnectionDisconnect: // ... } - //encrypted_finish := buffer - + encryptedFinish := buffer.Pos // encrypt the per-packet packet written with the prefix byte, protocol id and version as the associated data. this must match to decrypt. - additional_data := NewBuffer(VERSION_INFO_BYTES+8+1) - additional_data.WriteBytesN([]byte(VERSION_INFO), VERSION_INFO_BYTES) + additionalData := NewBuffer(VERSION_INFO_BYTES+8+1) + additionalData.WriteBytesN([]byte(VERSION_INFO), VERSION_INFO_BYTES) + additionalData.WriteUint64(protocolId) + additionalData.WriteUint8(prefixByte) nonce := NewBuffer(8) - nonce.WriteUint64(sequence) - //err := EncryptAead(encrypted_start, len(encrypted_finish) - len(encrypted_start), additional_data, len(additional_data), nonce, write_packet_key) - //if err != nil { - // return -1, err - //} - - // buffer += MAC_BYTES ??? - - return buffer.Len() - start.Len(), nil + err := EncryptAead(&buffer.Buf[encryptedStart:encryptedFinish], additionalData.Bytes(), nonce.Bytes(), writePacketKey) + if err != nil { + return -1, err + } + return buffer.Pos + MAC_BYTES, nil } -func ReadPacket(buffer *Buffer, buffer_length int, sequence uint64, read_packet_key []byte, protocol_id uint64, current_timestamp uint64, private_key []byte, allowed_packets []byte, replay_protection *ReplayProtection) (Packet, error) { - var packet Packet - sequence = 0 - if buffer_length < 1 { +func ReadPacket(packetData []byte, packetLen int, sequence uint64, readPacketKey []byte, protocolId uint64, currentTimestamp uint64, privateKey, allowedPackets []byte, replayProtection *ReplayProtection) (Packet, error) { + + if packetLen < 1 { return nil, errors.New("invalid buffer length") } - //start := NewBufferFromBytes(buffer.Buf.Bytes()) + packetBuffer := NewBufferFromBytes(packetData) - prefix_byte, err := buffer.GetUint8() + prefixByte, err := packetBuffer.GetUint8() if err != nil { return nil, errors.New("invalid buffer length") } - if PacketType(prefix_byte) == ConnectionRequest { - if allowed_packets[0] != 0 { - return nil, errors.New("ignored connection request packet. packet type is not allowed\n") + if PacketType(prefixByte) == ConnectionRequest { + return readRequestPacket(packetBuffer, packetLen, protocolId, currentTimestamp, allowedPackets, privateKey) + } + // *** encrypted packets *** + + if readPacketKey == nil { + return nil, errors.New("empty packet key") + } + + if packetLen < 1 + 1 + MAC_BYTES { + return nil, errors.New("ignored encrypted packet. packet is too small to be valid") + } + + packetType := prefixByte & 0xF + + if PacketType(packetType) >= ConnectionNumPackets { + return nil, errors.New("ignored encrypted packet. packet type " + packetTypeMap[PacketType(packetType)] + " is invalid") + } + + if allowedPackets[packetType] == 0 { + return nil, errors.New("ignored encrypted packet. packet type " + packetTypeMap[PacketType(packetType)] + " is invalid") + } + + sequenceBytes := prefixByte >> 4 + if sequenceBytes < 1 || sequenceBytes > 8 { + return nil, errors.New("ignored encrypted packet. sequence bytes is out of range [1,8]") + } + + if packetLen < 1 + int(sequenceBytes) + MAC_BYTES { + return nil, errors.New("ignored encrypted packet. buffer is too small for sequence bytes + encryption mac") + } + + var i uint8 + // read variable length sequence number [1,8] + for i = 0; i < sequenceBytes; i+=1 { + val, err := packetBuffer.GetUint8() + if err != nil { + return nil, err } + sequence |= uint64((val) << ( 8 * i )) + } - if buffer_length != 1 + VERSION_INFO_BYTES + 8 + 8 + 8 + CONNECT_TOKEN_PRIVATE_BYTES { - return nil, errors.New("ignored connection request packet. bad packet length\n") + // replay protection (optional) + if replayProtection != nil && PacketType(packetType) >= ConnectionKeepAlive { + if replayProtection.AlreadyReceived(sequence) == 1 { + v := strconv.FormatUint(sequence, 10) + return nil, errors.New("ignored connection payload packet. sequence " + v + " already received (replay protection)") } + } + + // decrypt the per-packet type data + additionalData := NewBuffer(VERSION_INFO_BYTES+8+1) + additionalData.WriteBytes([]byte(VERSION_INFO)) + additionalData.WriteUint64(protocolId) + additionalData.WriteUint8(prefixByte) + + nonce := NewBuffer(SizeUint64) + nonce.WriteUint64(sequence) + + encryptedSize := packetLen - packetBuffer.Pos + if encryptedSize < MAC_BYTES { + return nil, errors.New("ignored encrypted packet. encrypted payload is too small") + } + + encryptedBuff, err := packetBuffer.GetBytes(encryptedSize) + if err != nil { + return nil, errors.New("ignored encrypted packet. encrypted payload is too small") + } + + decryptedBuff, err := DecryptAead(encryptedBuff, additionalData.Bytes(), nonce.Bytes(), readPacketKey) + if err != nil { + return nil, errors.New("ignored encrypted packet. failed to decrypt: " + err.Error()) + } + + decryptedSize := encryptedSize - MAC_BYTES + + // process the per-packet type data that was just decrypted + return processPacket(PacketType(packetType), decryptedBuff, decryptedSize) +} - if private_key == nil { - return nil, errors.New("ignored connection request packet. no private key\n") +func processPacket(packetType PacketType, decrypted []byte, decryptedSize int) (Packet, error) { + var err error + decryptedBuff := NewBufferFromBytes(decrypted) + + switch (packetType) { + case ConnectionDenied: + if decryptedSize != 0 { + return nil, errors.New("ignored connection denied packet. decrypted packet data is wrong size") + } + return &DeniedPacket{}, nil + case ConnectionChallenge: + if decryptedSize != 8 + CHALLENGE_TOKEN_BYTES { + return nil, errors.New("ignored connection challenge packet. decrypted packet data is wrong size") } - version_info, err := buffer.GetBytes(VERSION_INFO_BYTES) + packet := &ChallengePacket{} + packet.ChallengeTokenSequence, err = decryptedBuff.GetUint64() if err != nil { - return nil, errors.New("ignored connection request packet. bad version info\n") + return nil, errors.New("error reading challenge token sequence") } - if string(version_info) != VERSION_INFO { - return nil, errors.New("ignored connection request packet. bad version info\n") + packet.ChallengeTokenData, err = decryptedBuff.GetBytes(CHALLENGE_TOKEN_BYTES) + if err != nil { + return nil, errors.New("error reading challenge token data") + } + return packet, nil + case ConnectionResponse: + if decryptedSize != 8 + CHALLENGE_TOKEN_BYTES { + return nil, errors.New("ignored connection response packet. decrypted packet data is wrong size") } - id, err := buffer.GetUint64() - if err != nil || id != protocol_id { - return nil, errors.New("ignored connection request packet. wrong protocol id\n") + packet := &ResponsePacket{} + packet.ChallengeTokenSequence, err = decryptedBuff.GetUint64() + if err != nil { + return nil, errors.New("error reading response token sequence") } - expire, err := buffer.GetUint64() - if err != nil || expire <= current_timestamp { - return nil, errors.New("ignored connection request packet. connect token expired\n") + packet.ChallengeTokenData, err = decryptedBuff.GetBytes(CHALLENGE_TOKEN_BYTES) + if err != nil { + return nil, errors.New("error reading response token data") + } + return packet, nil + case ConnectionKeepAlive: + if decryptedSize != 8 { + return nil, errors.New("ignored connection keep alive packet. decrypted packet data is wrong size") + } + packet := &KeepAlivePacket{} + packet.ClientIndex, err = decryptedBuff.GetUint32() + if err != nil { + return nil, errors.New("error reading keepalive client index") } - token_sequence, err := buffer.GetUint64() + packet.MaxClients, err = decryptedBuff.GetUint32() if err != nil { - return nil, err + return nil, errors.New("error reading keepalive max clients") } - log.Print(token_sequence) return packet, nil + case ConnectionPayload: + if decryptedSize < 1 { + return nil, errors.New("ignored connection payload packet. payload is too small") + } + + if decryptedSize > MAX_PAYLOAD_BYTES { + return nil, errors.New("ignored connection payload packet. payload is too large") + } + + packet := NewPayloadPacket(uint32(decryptedSize)) + copy(packet.PayloadData, decryptedBuff.Bytes()) + return packet, nil + case ConnectionDisconnect: + if decryptedSize != 0 { + return nil, errors.New("ignored connection disconnect packet. decrypted packet data is wrong size") + } + packet := &DisconnectPacket{} + return packet, nil + } + + return nil, errors.New("unknown packet type") +} + +// Reads the RequestPacket type returning the packet after deserializing +func readRequestPacket(packetBuffer *Buffer, packetLen int, protocolId, currentTimestamp uint64, allowedPackets []byte, privateKey []byte) (Packet, error) { + var err error + packet := &RequestPacket{} + + if allowedPackets[0] == 0 { + return nil, errors.New("ignored connection request packet. packet type is not allowed") + } + + if packetLen != 1 + VERSION_INFO_BYTES + 8 + 8 + 8 + CONNECT_TOKEN_PRIVATE_BYTES { + log.Printf("packetLen: %d, expected: %d\n", packetLen, 1 + VERSION_INFO_BYTES + 8 + 8 + 8 + CONNECT_TOKEN_PRIVATE_BYTES) + return nil, errors.New("ignored connection request packet. bad packet length") + } + + if privateKey == nil { + return nil, errors.New("ignored connection request packet. no private key\n") + } + + packet.VersionInfo, err = packetBuffer.GetBytes(VERSION_INFO_BYTES) + if err != nil { + return nil, errors.New("ignored connection request packet. bad version info\n") + } + + if string(packet.VersionInfo) != VERSION_INFO { + return nil, errors.New("ignored connection request packet. bad version info\n") } + + packet.ProtocolId, err = packetBuffer.GetUint64() + if err != nil || packet.ProtocolId != protocolId { + return nil, errors.New("ignored connection request packet. wrong protocol id\n") + } + + packet.ConnectTokenExpireTimestamp, err = packetBuffer.GetUint64() + if err != nil || packet.ConnectTokenExpireTimestamp <= currentTimestamp { + return nil, errors.New("ignored connection request packet. connect token expired\n") + } + + packet.ConnectTokenSequence, err = packetBuffer.GetUint64() + if err != nil { + return nil, err + } + + var tokenBuffer []byte + tokenBuffer, err = packetBuffer.GetBytes(CONNECT_TOKEN_PRIVATE_BYTES) + if err != nil { + return nil, err + } + log.Printf("len tokenBuffer: %d, pos: %d\n", len(tokenBuffer), packetBuffer.Pos) + log.Printf("tokenBuffer: %x %#v\n", packet.ConnectTokenExpireTimestamp, tokenBuffer) + packet.Token, err = ReadConnectToken(tokenBuffer, packet.ProtocolId, packet.ConnectTokenExpireTimestamp, packet.ConnectTokenSequence, privateKey) + if err != nil { + return nil, err + } + + packet.ConnectTokenData = packet.Token.PrivateData.TokenData.Buf return packet, nil } diff --git a/go/packet_test.go b/go/packet_test.go index 13e5b95..b2ec051 100644 --- a/go/packet_test.go +++ b/go/packet_test.go @@ -2,8 +2,95 @@ package netcode import ( "testing" + "net" + "time" + "bytes" ) func TestReadPacket(t *testing.T) { -} \ No newline at end of file +} + +func TestConnectionRequestPacket(t *testing.T) { + addr := net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: TEST_SERVER_PORT} + serverAddrs := make([]net.UDPAddr, 1) + serverAddrs[0] = addr + config := NewConfig(serverAddrs, TEST_CONNECT_TOKEN_EXPIRY, TEST_PROTOCOL_ID, TEST_PRIVATE_KEY) + + connectToken := NewConnectToken() + currentTimestamp := uint64(time.Now().Unix()) + + if err := connectToken.Generate(config, TEST_CLIENT_ID, currentTimestamp, TEST_SEQUENCE_START); err != nil { + t.Fatalf("error generating connect token: %s\n", err) + } + + // setup a connection request packet wrapping the encrypted connect token + inputPacket := &RequestPacket{} + inputPacket.VersionInfo = []byte(VERSION_INFO) + inputPacket.ProtocolId = TEST_PROTOCOL_ID + inputPacket.ConnectTokenExpireTimestamp = uint64(time.Now().Unix() + 30) + inputPacket.ConnectTokenSequence = TEST_SEQUENCE_START + inputPacket.ConnectTokenData = connectToken.PrivateData.TokenData.Bytes() + + // write the connection request packet to a buffer + buffer := NewBuffer(2048) + + packetKey, err := GenerateKey() + if err != nil { + t.Fatalf("error generating key") + } + + bytesWritten, err := WritePacket(inputPacket, buffer, TEST_SEQUENCE_START, packetKey, TEST_PROTOCOL_ID) + if err != nil { + t.Fatalf("error writing packet: %s\n", err) + } + + if bytesWritten <= 0 { + t.Fatalf("did not write any bytes for this packet") + } + + // read the connection request packet back in from the buffer (the connect token data is decrypted as part of the read packet validation) + var sequence uint64 + sequence = TEST_SEQUENCE_START + + allowedPackets := make([]byte, ConnectionNumPackets) + for i := 0; i < len(allowedPackets); i+=1 { + allowedPackets[i] = 1 + } + + buffer.Reset() + outputPacket, err := ReadPacket(buffer.Buf, bytesWritten, sequence, packetKey, TEST_PROTOCOL_ID, uint64(time.Now().Unix()), TEST_PRIVATE_KEY, allowedPackets, nil) + if err != nil { + t.Fatalf("error reading packet: %s\n", err) + } + + if outputPacket.GetType() != ConnectionRequest { + t.Fatal("packet output was not a connection request") + } + + output, ok := outputPacket.(*RequestPacket) + if !ok { + t.Fatalf("error casting to connection request packet") + } + + if bytes.Compare(inputPacket.VersionInfo, output.VersionInfo) != 0 { + t.Fatalf("version info did not match") + } + + if inputPacket.ProtocolId != output.ProtocolId { + t.Fatalf("ProtocolId did not match") + } + + if inputPacket.ConnectTokenExpireTimestamp != output.ConnectTokenExpireTimestamp { + t.Fatalf("ConnectTokenExpireTimestamp did not match") + } + + if inputPacket.ConnectTokenSequence != output.ConnectTokenSequence { + t.Fatalf("ConnectTokenSequence did not match") + } + + if bytes.Compare(inputPacket.Token.PrivateData.TokenData.Buf, output.Token.PrivateData.TokenData.Buf) != 0 { + t.Fatalf("TokenData did not match") + } + +} From f41d3733f0f9d39c710381b0993ff3a606a9b2db Mon Sep 17 00:00:00 2001 From: wirepair Date: Wed, 5 Apr 2017 15:11:52 +0900 Subject: [PATCH 06/11] finish impl & test for RequestPacket --- go/challenge_token.go | 23 ++-- go/challenge_token_test.go | 12 +- go/connect_token.go | 226 ++++++++++++++++++++----------------- go/connect_token_test.go | 62 ++++++---- go/crypto.go | 6 +- go/packet.go | 115 ++++++++++--------- go/packet_test.go | 79 +++++++++---- go/simulator.go | 11 ++ 8 files changed, 308 insertions(+), 226 deletions(-) diff --git a/go/challenge_token.go b/go/challenge_token.go index 835bc85..38e0ce0 100644 --- a/go/challenge_token.go +++ b/go/challenge_token.go @@ -12,38 +12,35 @@ func NewChallengeToken(clientId uint64) *ChallengeToken { token := &ChallengeToken{} token.ClientId = clientId token.UserData = NewBuffer(USER_DATA_BYTES) - token.TokenData = NewBuffer(CHALLENGE_TOKEN_BYTES) return token } // Encrypts the TokenData buffer with the sequence nonce and provided key -func (t *ChallengeToken) Encrypt(sequence uint64, key []byte) error { +func EncryptChallengeToken(tokenBuffer *[]byte, sequence uint64, key []byte) error { nonce := NewBuffer(SizeUint64) nonce.WriteUint64(sequence) - - return EncryptAead(&t.TokenData.Buf, nil, nonce.Bytes(), key) + return EncryptAead(tokenBuffer, nil, nonce.Bytes(), key) } // Decrypts the TokenData buffer with the sequence nonce and provided key, updating the // internal TokenData buffer -func (t *ChallengeToken) Decrypt(sequence uint64, key []byte) error { - var err error +func DecryptChallengeToken(tokenBuffer []byte, sequence uint64, key []byte) ([]byte, error) { nonce := NewBuffer(SizeUint64) nonce.WriteUint64(sequence) - t.TokenData.Buf, err = DecryptAead(t.TokenData.Buf, nil, nonce.Bytes(), key) - return err + return DecryptAead(tokenBuffer, nil, nonce.Bytes(), key) } // Serializes the client id and userData, also sets the UserData buffer. -func (t *ChallengeToken) Write(userData []byte) { +func (t *ChallengeToken) Write(userData []byte) []byte { + tokenData := NewBuffer(CHALLENGE_TOKEN_BYTES) t.UserData.WriteBytes(userData) - - t.TokenData.WriteUint64(t.ClientId) - t.TokenData.WriteBytes(userData) + tokenData.WriteUint64(t.ClientId) + tokenData.WriteBytes(userData) + return tokenData.Buf } // Generates a new ChallengeToken from the provided buffer byte slice. Only sets the ClientId -// and UserData buffer, does not update the TokenData buffer. +// and UserData buffer. func ReadChallengeToken(buffer []byte) (*ChallengeToken, error) { var err error var clientId uint64 diff --git a/go/challenge_token_test.go b/go/challenge_token_test.go index 0295cd6..34081e1 100644 --- a/go/challenge_token_test.go +++ b/go/challenge_token_test.go @@ -8,12 +8,13 @@ import ( func TestNewChallengeToken(t *testing.T) { var err error var userData []byte + var decryptedBuffer []byte token := NewChallengeToken(TEST_CLIENT_ID) if userData, err = RandomBytes(USER_DATA_BYTES); err != nil { t.Fatalf("error generating random data\n") } - token.Write(userData) + tokenBuffer := token.Write(userData) var sequence uint64 sequence = 1000 @@ -22,15 +23,15 @@ func TestNewChallengeToken(t *testing.T) { t.Fatalf("error generating key\n") } - if err := token.Encrypt(sequence, key); err != nil { + if err := EncryptChallengeToken(&tokenBuffer, sequence, key); err != nil { t.Fatalf("error encrypting challenge token: %s\n", err) } - if err := token.Decrypt(sequence, key); err != nil { + if decryptedBuffer, err = DecryptChallengeToken(tokenBuffer, sequence, key); err != nil { t.Fatalf("error decrypting challenge token: %s\n", err) } - newToken, err := ReadChallengeToken(token.TokenData.Buf) + newToken, err := ReadChallengeToken(decryptedBuffer) if err != nil { t.Fatalf("error reading token data %s\n", err) } @@ -42,7 +43,4 @@ func TestNewChallengeToken(t *testing.T) { if bytes.Compare(newToken.UserData.Buf, token.UserData.Buf) != 0 { t.Fatalf("user data did not match expected\n %#v\ngot\n%#v!", token.UserData.Buf, newToken.UserData.Buf) } - - - } \ No newline at end of file diff --git a/go/connect_token.go b/go/connect_token.go index 897c6cf..5b52d54 100644 --- a/go/connect_token.go +++ b/go/connect_token.go @@ -25,15 +25,6 @@ type ConnectToken struct { TimeoutSeconds int } -type ConnectTokenPrivate struct { - ClientId uint64 - ServerAddrs []net.UDPAddr // list of server addresses this client may connect to - ClientKey []byte // client to server key - ServerKey []byte // server to client key - UserData []byte // user data - TokenData *Buffer // used to store the serialized buffer -} - // create a new empty token func NewConnectToken() *ConnectToken { token := &ConnectToken{} @@ -57,76 +48,127 @@ func (token *ConnectToken) ClientId() uint64 { return token.PrivateData.ClientId } -// Generates the token with the supplied configuration values -func (token *ConnectToken) Generate(config *Config, clientId, currentTimestamp, sequence uint64) error { - var err error - privateData := &ConnectTokenPrivate{} - token.PrivateData = privateData +type ConnectTokenPrivate struct { + ClientId uint64 + ServerAddrs []net.UDPAddr // list of server addresses this client may connect to + ClientKey []byte // client to server key + ServerKey []byte // server to client key + UserData []byte // used to store user data + TokenData *Buffer // used to store the serialized buffer +} - privateData.ClientId = clientId - privateData.ServerAddrs = config.ServerAddrs +func NewConnectTokenPrivate() *ConnectTokenPrivate { + p := &ConnectTokenPrivate{} + p.TokenData = NewBuffer(CONNECT_TOKEN_PRIVATE_BYTES) + return p +} - if privateData.UserData, err = RandomBytes(USER_DATA_BYTES); err != nil { +func (p *ConnectTokenPrivate) Read() error { + var err error + + + if p.ClientId, err = p.TokenData.GetUint64(); err != nil { return err } - if privateData.ClientKey, err = GenerateKey(); err != nil { + if err := p.readServerData(); err != nil { return err } - if privateData.ServerKey, err = GenerateKey(); err != nil { - return err + if p.ClientKey, err = p.TokenData.GetBytes(KEY_BYTES); err != nil { + return errors.New("error reading client to server key") } - if privateData.TokenData, err = WriteConnectToken(token); err != nil { - return err + if p.ServerKey, err = p.TokenData.GetBytes(KEY_BYTES); err != nil { + return errors.New("error reading server to client key") } - token.CreateTimestamp = currentTimestamp - token.ExpireTimestamp = token.CreateTimestamp + config.TokenExpiry - token.Encrypt(config.ProtocolId, sequence, config.PrivateKey) + if p.UserData, err = p.TokenData.GetBytes(USER_DATA_BYTES); err != nil { + return errors.New("error reading user data") + } return nil } -// Encrypts the token.TokenData -func (token *ConnectToken) Encrypt(protocolId, sequence uint64, privateKey []byte) error { - additionalData, nonce := buildCryptData(protocolId, token.ExpireTimestamp, sequence) - if err := EncryptAead(&token.PrivateData.TokenData.Buf, additionalData.Bytes(), nonce.Bytes(), privateKey); err != nil { +func (p *ConnectTokenPrivate) readServerData() error { + var err error + var servers uint32 + var ipBytes []byte + + servers, err = p.TokenData.GetUint32() + if err != nil { return err } - log.Printf("after encrypt: %#v\n", token.PrivateData.TokenData) + + if servers <= 0 { + return errors.New("empty servers") + } + + if servers > MAX_SERVERS_PER_CONNECT { + log.Printf("got %d expected %d\n", servers, MAX_SERVERS_PER_CONNECT) + return errors.New("too many servers") + } + + p.ServerAddrs = make([]net.UDPAddr, servers) + + for i := 0; i < int(servers); i+=1 { + serverType, err := p.TokenData.GetUint8() + if err != nil { + return err + } + + if serverType == ADDRESS_IPV4 { + ipBytes, err = p.TokenData.GetBytes(4) + } else if serverType == ADDRESS_IPV6 { + ipBytes, err = p.TokenData.GetBytes(16) + } else { + return errors.New("unknown ip address") + } + + if err != nil { + return err + } + + ip := net.IP(ipBytes) + port, err := p.TokenData.GetUint16() + if err != nil { + return errors.New("invalid port") + } + p.ServerAddrs[i] = net.UDPAddr{IP: ip, Port: int(port)} + } return nil } -// Decrypts the tokendata and assigns it back to the backing buffer -func (token *ConnectToken) Decrypt(protocolId, sequence uint64, privateKey []byte) error { +// Generates the token with the supplied configuration values +func (token *ConnectToken) Generate(config *Config, clientId, currentTimestamp, sequence uint64) error { var err error - additionalData, nonce := buildCryptData(protocolId, token.ExpireTimestamp, sequence) - if token.PrivateData.TokenData.Buf, err = DecryptAead(token.PrivateData.TokenData.Bytes(), additionalData.Bytes(), nonce.Bytes(), privateKey); err != nil { + privateData := &ConnectTokenPrivate{} + token.PrivateData = privateData + + privateData.ClientId = clientId + privateData.ServerAddrs = config.ServerAddrs + + if privateData.UserData, err = RandomBytes(USER_DATA_BYTES); err != nil { return err } - return nil -} -// builds the additional data and nonce necessary for encryption and decryption. -func buildCryptData(protocolId, expireTimestamp, sequence uint64) (*Buffer, *Buffer) { - additionalData := NewBuffer(VERSION_INFO_BYTES+8+8) - additionalData.WriteBytes([]byte(VERSION_INFO)) - log.Printf("buildCryptData %x %x\n", protocolId, expireTimestamp) - additionalData.WriteUint64(protocolId) - additionalData.WriteUint64(expireTimestamp) + if privateData.ClientKey, err = GenerateKey(); err != nil { + return err + } - nonce := NewBuffer(SizeUint64) - nonce.WriteUint64(sequence) + if privateData.ServerKey, err = GenerateKey(); err != nil { + return err + } - return additionalData, nonce + token.CreateTimestamp = currentTimestamp + token.ExpireTimestamp = token.CreateTimestamp + config.TokenExpiry + return nil } -// Writes the token data to the TokenData buffer and returns to caller -func WriteConnectToken(token *ConnectToken) (*Buffer, error) { +// Writes the token data to a byte slice and returns to caller +func (token *ConnectToken) Write() ([]byte, error) { data := NewBuffer(CONNECT_TOKEN_PRIVATE_BYTES) data.WriteUint64(token.PrivateData.ClientId) data.WriteUint32(uint32(len(token.PrivateData.ServerAddrs))) @@ -162,83 +204,55 @@ func WriteConnectToken(token *ConnectToken) (*Buffer, error) { data.WriteBytesN(token.PrivateData.ClientKey, KEY_BYTES) data.WriteBytesN(token.PrivateData.ServerKey, KEY_BYTES) data.WriteBytesN(token.PrivateData.UserData, USER_DATA_BYTES) - return data, nil + return data.Buf, nil } // Takes in a slice of bytes and generates a new ConnectToken after decryption. func ReadConnectToken(tokenBuffer []byte, protocolId, expireTimestamp, sequence uint64, privateKey []byte) (*ConnectToken, error) { var err error - var servers uint32 - var ipBytes []byte + var privateData []byte token := NewConnectToken() - token.PrivateData = &ConnectTokenPrivate{} token.ExpireTimestamp = expireTimestamp - token.PrivateData.TokenData = NewBufferFromBytes(tokenBuffer) - if err := token.Decrypt(protocolId, sequence, privateKey); err != nil { + if privateData, err = DecryptConnectTokenPrivate(tokenBuffer, protocolId, expireTimestamp, sequence, privateKey); err != nil { return nil, errors.New("error decrypting connection token: " + err.Error()) } - if token.PrivateData.ClientId, err = token.PrivateData.TokenData.GetUint64(); err != nil { + private := NewConnectTokenPrivate() + private.TokenData = NewBufferFromBytes(privateData) + if err = private.Read(); err != nil { return nil, err } + token.PrivateData = private + return token, nil +} - log.Printf("clientid: %x\n", token.PrivateData.ClientId) - - servers, err = token.PrivateData.TokenData.GetUint32() - if err != nil { - return nil, err - } - - if servers <= 0 { - return nil, errors.New("empty servers") - } - - if servers > MAX_SERVERS_PER_CONNECT { - log.Printf("got %d expected %d\n", servers, MAX_SERVERS_PER_CONNECT) - return nil, errors.New("too many servers") - } - - token.PrivateData.ServerAddrs = make([]net.UDPAddr, servers) - - for i := 0; i < int(servers); i+=1 { - serverType, err := token.PrivateData.TokenData.GetUint8() - if err != nil { - return nil, err - } - - if serverType == ADDRESS_IPV4 { - ipBytes, err = token.PrivateData.TokenData.GetBytes(4) - } else if serverType == ADDRESS_IPV6 { - ipBytes, err = token.PrivateData.TokenData.GetBytes(16) - } else { - return nil, errors.New("unknown ip address") - } - - if err != nil { - return nil, err - } +// Encrypts the supplied buffer for the token private parts +func EncryptConnectTokenPrivate(privateData *[]byte, protocolId, expireTimestamp, sequence uint64, privateKey []byte) error { + additionalData, nonce := buildCryptData(protocolId, expireTimestamp, sequence) - ip := net.IP(ipBytes) - port, err := token.PrivateData.TokenData.GetUint16() - if err != nil { - return nil, errors.New("invalid port") - } - token.PrivateData.ServerAddrs[i] = net.UDPAddr{IP: ip, Port: int(port)} + if err := EncryptAead(privateData, additionalData, nonce, privateKey); err != nil { + return err } + return nil +} - if token.PrivateData.ClientKey, err = token.PrivateData.TokenData.GetBytes(KEY_BYTES); err != nil { - return nil, errors.New("error reading client to server key") - } +// Decrypts the supplied privateData buffer and generates a new ConnectTokenPrivate instance +func DecryptConnectTokenPrivate(privateData []byte, protocolId, expireTimestamp, sequence uint64, privateKey []byte) ([]byte, error) { + additionalData, nonce := buildCryptData(protocolId, expireTimestamp, sequence) + return DecryptAead(privateData, additionalData, nonce, privateKey) +} - if token.PrivateData.ServerKey, err = token.PrivateData.TokenData.GetBytes(KEY_BYTES); err != nil { - return nil, errors.New("error reading server to client key") - } +// builds the additional data and nonce necessary for encryption and decryption. +func buildCryptData(protocolId, expireTimestamp, sequence uint64) ([]byte, []byte) { + additionalData := NewBuffer(VERSION_INFO_BYTES+8+8) + additionalData.WriteBytes([]byte(VERSION_INFO)) + additionalData.WriteUint64(protocolId) + additionalData.WriteUint64(expireTimestamp) - if token.PrivateData.UserData, err = token.PrivateData.TokenData.GetBytes(USER_DATA_BYTES); err != nil { - return nil, errors.New("error reading user data") - } + nonce := NewBuffer(SizeUint64) + nonce.WriteUint64(sequence) - return token, nil + return additionalData.Buf, nonce.Buf } \ No newline at end of file diff --git a/go/connect_token_test.go b/go/connect_token_test.go index 5964742..1d8f23e 100644 --- a/go/connect_token_test.go +++ b/go/connect_token_test.go @@ -3,7 +3,7 @@ package netcode import ( "testing" "net" - //"bytes" + "bytes" "time" ) @@ -21,7 +21,7 @@ var TEST_PRIVATE_KEY = []byte{0x60, 0x6a, 0xbe, 0x6e, 0xc9, 0x19, 0x10, 0xea, 0x6b, 0x3c, 0x60, 0xf4, 0xb7, 0x15, 0xab, 0xa1 } func TestNewConnectToken(t *testing.T) { - token := NewConnectToken() + token1 := NewConnectToken() server := net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 40000} servers := make([]net.UDPAddr, 1) servers[0] = server @@ -29,48 +29,60 @@ func TestNewConnectToken(t *testing.T) { config := NewConfig(servers, TEST_CONNECT_TOKEN_EXPIRY, TEST_PROTOCOL_ID, TEST_PRIVATE_KEY) currentTimestamp := uint64(time.Now().Unix()) - err := token.Generate(config, TEST_CLIENT_ID, currentTimestamp, TEST_SEQUENCE_START) + err := token1.Generate(config, TEST_CLIENT_ID, currentTimestamp, TEST_SEQUENCE_START) if err != nil { t.Fatalf("error generating and encrypting token") } - _, err = ReadConnectToken(token.PrivateData.TokenData.Bytes(), config.ProtocolId, currentTimestamp+config.TokenExpiry, TEST_SEQUENCE_START, config.PrivateKey) + private, err := token1.Write() + if err != nil { + t.Fatalf("error writing token private data") + } + + EncryptConnectTokenPrivate(&private, TEST_PROTOCOL_ID, uint64(currentTimestamp + config.TokenExpiry), TEST_SEQUENCE_START, config.PrivateKey) + + token2, err := ReadConnectToken(private, config.ProtocolId, currentTimestamp+config.TokenExpiry, TEST_SEQUENCE_START, config.PrivateKey) if err != nil { t.Fatalf("error reading connect token %s", err) } - //err = token.Decrypt(config.ProtocolId, TEST_SEQUENCE_START, config.PrivateKey) - //if err != nil { - // t.Fatalf("error decrypting: %s\n", err) - //} + compareTokens(token1, token2, t) - //err = token.Decrypt(config.ProtocolId, TEST_SEQUENCE_START, config.PrivateKey) - //if err != nil { - // t.Fatalf("error decrypting token: %s\n", err) - //} + private2, err := token2.Write() + if err != nil { + t.Fatalf("error writing token2 buffer") + } + EncryptConnectTokenPrivate(&private2, TEST_PROTOCOL_ID, uint64(currentTimestamp + config.TokenExpiry), TEST_SEQUENCE_START, config.PrivateKey) - /* - if token.ClientId() != token2.ClientId() { - t.Fatalf("clientIds do not match expected %d got %d", token.ClientId, token2.ClientId) + if bytes.Compare(private, private2) != 0 { + t.Fatalf("encrypted private bits didn't match %v and %v\n", private, private2) } +} - if len(token.ServerAddresses()) != len(token2.ServerAddresses()) { - t.Fatalf("time stamps do not match expected %d got %d", len(token.ServerAddresses()), len(token2.ServerAddresses())) +func compareTokens(token1, token2 *ConnectToken, t *testing.T) { + if token1.ClientId() != token2.ClientId() { + t.Fatalf("clientIds do not match expected %d got %d", token1.ClientId, token2.ClientId) } - // TODO verify server addresses + if len(token1.ServerAddresses()) != len(token2.ServerAddresses()) { + t.Fatalf("time stamps do not match expected %d got %d", len(token1.ServerAddresses()), len(token2.ServerAddresses())) + } - if bytes.Compare(token.ClientKey(), token2.ClientKey()) != 0 { - t.Fatalf("ClientKey do not match expected %v got %v", token.ClientKey(), token2.ClientKey()) + token1Servers := token1.ServerAddresses() + token2Servers := token2.ServerAddresses() + for i := 0; i < len(token1.ServerAddresses()); i+=1 { + if bytes.Compare([]byte(token1Servers[i].IP), []byte(token2Servers[i].IP)) != 0 { + t.Fatalf("server addresses did not match: expected %v got %v\n", token1Servers[i], token2Servers[i]) + } } - if bytes.Compare(token.ServerKey(), token2.ServerKey()) != 0 { - t.Fatalf("ServerKey do not match expected %v got %v", token.ServerKey(), token2.ServerKey()) + if bytes.Compare(token1.ClientKey(), token2.ClientKey()) != 0 { + t.Fatalf("ClientKey do not match expected %v got %v", token1.ClientKey(), token2.ClientKey()) } - if bytes.Compare(token.PrivateData.UserData, token2.PrivateData.UserData) != 0 { - t.Fatalf("UserData do not match expected %v got %v", token.PrivateData.UserData, token2.PrivateData.UserData) + if bytes.Compare(token1.ServerKey(), token2.ServerKey()) != 0 { + t.Fatalf("ServerKey do not match expected %v got %v", token1.ServerKey(), token2.ServerKey()) } - */ + } \ No newline at end of file diff --git a/go/crypto.go b/go/crypto.go index a28f00b..827ac43 100644 --- a/go/crypto.go +++ b/go/crypto.go @@ -3,6 +3,8 @@ package netcode import ( "crypto/rand" "github.com/codahale/chacha20poly1305" + //"log" + "crypto/sha1" "log" ) @@ -24,9 +26,8 @@ func EncryptAead(message *[]byte, additional, nonce, key []byte) error { if err != nil { return err } - log.Printf("before encrypt len: %d\n", len(*message)) *message = aead.Seal(nil, nonce, *message, additional) - + log.Printf("AFTER ENCRYPT: %x %x %x %x\n", sha1.Sum(*message), sha1.Sum(additional), sha1.Sum(nonce), sha1.Sum(key)) return nil } @@ -37,5 +38,6 @@ func DecryptAead(message []byte, additional, nonce, key []byte) ([]byte, error) if err != nil { return nil, err } + log.Printf("BEFORE DECRYPT: %x %x %x %x\n", sha1.Sum(message), sha1.Sum(additional), sha1.Sum(nonce), sha1.Sum(key)) return aead.Open(nil, nonce, message, additional) } \ No newline at end of file diff --git a/go/packet.go b/go/packet.go index 923bb25..ebfd9e5 100644 --- a/go/packet.go +++ b/go/packet.go @@ -4,10 +4,9 @@ import ( "errors" "strconv" "log" + "crypto/sha1" ) -type PacketType uint8 - const MAX_CLIENTS = 60 const CONNECT_TOKEN_PRIVATE_BYTES = 1024 const CHALLENGE_TOKEN_BYTES = 300 @@ -26,11 +25,10 @@ const NONCE_BYTES = 8 const MAX_SERVERS_PER_CONNECT = 32 const VERSION_INFO = "NETCODE 1.00\x00" -const PACKET_SEND_RATE = 10.0 -const TIMEOUT_SECONDS = 5.0 -const NUM_DISCONNECT_PACKETS = 10 +type PacketType uint8 + const ( ConnectionRequest PacketType = iota ConnectionDenied @@ -39,8 +37,10 @@ const ( ConnectionKeepAlive ConnectionPayload ConnectionDisconnect - ConnectionNumPackets + ) +// not a packet type, but value is last packetType+1 +const ConnectionNumPackets = ConnectionDisconnect+1 var packetTypeMap = map[PacketType]string { ConnectionRequest: "CONNECTION_REQUEST", @@ -50,7 +50,6 @@ var packetTypeMap = map[PacketType]string { ConnectionKeepAlive: "CONNECTION_KEEPALIVE", ConnectionPayload: "CONNECTION_PAYLOAD", ConnectionDisconnect: "CONNECTION_DISCONNECT", - ConnectionNumPackets: "CONNECTION_NUMPACKETS", } type Packet interface { @@ -58,13 +57,12 @@ type Packet interface { } type RequestPacket struct { - Type PacketType VersionInfo []byte ProtocolId uint64 ConnectTokenExpireTimestamp uint64 ConnectTokenSequence uint64 Token *ConnectToken - ConnectTokenData []byte + ConnectTokenData []byte // the encrypted Token after Write -> Encrypt } func (p *RequestPacket) GetType() PacketType { @@ -79,7 +77,6 @@ func (p *DeniedPacket) GetType() PacketType { } type ChallengePacket struct { - Type PacketType ChallengeTokenSequence uint64 ChallengeTokenData []byte } @@ -89,7 +86,6 @@ func (p *ChallengePacket) GetType() PacketType { } type ResponsePacket struct { - Type PacketType ChallengeTokenSequence uint64 ChallengeTokenData []byte } @@ -99,7 +95,6 @@ func (p *ResponsePacket) GetType() PacketType { } type KeepAlivePacket struct { - Type PacketType ClientIndex uint32 MaxClients uint32 } @@ -110,10 +105,8 @@ func (p *KeepAlivePacket) GetType() PacketType { type PayloadPacket struct { - Type PacketType PayloadBytes uint32 PayloadData []byte - // ... } func (p *PayloadPacket) GetType() PacketType { @@ -121,30 +114,22 @@ func (p *PayloadPacket) GetType() PacketType { } func NewPayloadPacket(payloadBytes uint32) *PayloadPacket { - packet := &PayloadPacket{Type: ConnectionPayload} + packet := &PayloadPacket{} packet.PayloadBytes = payloadBytes packet.PayloadData = make([]byte, payloadBytes) return packet } type DisconnectPacket struct { - Type PacketType } func (p *DisconnectPacket) GetType() PacketType { return ConnectionDisconnect } -type Context struct { - WritePacketKey []byte - ReadPacketKey []byte -} - func WritePacket(packet Packet, buffer *Buffer, sequence uint64, writePacketKey []byte, protocolId uint64) (int, error) { - var p Packet - packetType := packet.GetType() - + // TODO: this should be moved to writePacketData provided packet prefix can be safely ignored/added if packetType == ConnectionRequest { // connection request packet: first byte is zero p, ok := packet.(*RequestPacket) @@ -156,19 +141,41 @@ func WritePacket(packet Packet, buffer *Buffer, sequence uint64, writePacketKey buffer.WriteUint64(p.ProtocolId) buffer.WriteUint64(p.ConnectTokenExpireTimestamp) buffer.WriteUint64(p.ConnectTokenSequence) - buffer.WriteBytesN(p.ConnectTokenData, CONNECT_TOKEN_PRIVATE_BYTES) - if buffer.Pos != 1 + 13 + 8 + 8 + 8 + CONNECT_TOKEN_PRIVATE_BYTES { - return -1, errors.New("invalid buffer size") + buffer.WriteBytes(p.ConnectTokenData) // write the encrypted connection token private data + if buffer.Pos != 1 + 13 + 8 + 8 + 8 + CONNECT_TOKEN_PRIVATE_BYTES + MAC_BYTES { + return -1, errors.New("invalid buffer size written") } return buffer.Pos, nil } - + panic("WritePacket should not get here") // *** encrypted packets *** + prefixByte, err := writePacketPrefix(packet, buffer, sequence) + if err != nil { + return -1, err + } + + encryptedStart := buffer.Pos + if err := writePacketData(packet, buffer); err != nil { + return -1, err + } + encryptedFinish := buffer.Pos - // write the prefix byte (this is a combination of the packet type and number of sequence bytes) + additionalData, nonce := packetCryptData(prefixByte, protocolId, sequence) + + // slice up the buffer for the bits we will encrypt + encryptedBuffer := buffer.Buf[encryptedStart:encryptedFinish] + if err := EncryptAead(&encryptedBuffer, additionalData, nonce, writePacketKey); err != nil { + return -1, err + } + + return buffer.Pos + MAC_BYTES, nil +} + +// write the prefix byte (this is a combination of the packet type and number of sequence bytes) +func writePacketPrefix(p Packet, buffer *Buffer, sequence uint64) (uint8, error) { sequenceBytes := sequenceNumberBytesRequired(sequence) if (sequenceBytes < 1 || sequenceBytes > 8) { - return -1, errors.New("invalid sequence bytes, must be between [1-8]") + return 0, errors.New("invalid sequence bytes, must be between [1-8]") } prefixByte := uint8(p.GetType()) | uint8(sequenceBytes << 4) @@ -180,58 +187,57 @@ func WritePacket(packet Packet, buffer *Buffer, sequence uint64, writePacketKey buffer.WriteUint8(uint8(sequenceTemp & 0xFF)) sequenceTemp >>= 8 } + return prefixByte, nil +} - encryptedStart := buffer.Pos - // write packet data according to type. this data will be encrypted. - switch p.GetType() { +// write packet data according to type. this data will be encrypted. +func writePacketData(packet Packet, buffer *Buffer) error { + switch packet.GetType() { case ConnectionDenied: // ... case ConnectionChallenge: p, ok := packet.(*ChallengePacket) if !ok { - return -1, nil + return errors.New("invalid packet type") } buffer.WriteUint64(p.ChallengeTokenSequence) buffer.WriteBytesN(p.ChallengeTokenData, CHALLENGE_TOKEN_BYTES) case ConnectionResponse: p, ok := packet.(*ResponsePacket) if !ok { - return -1, nil + return errors.New("invalid packet type") } buffer.WriteUint64(p.ChallengeTokenSequence) buffer.WriteBytesN(p.ChallengeTokenData, CHALLENGE_TOKEN_BYTES) case ConnectionKeepAlive: p, ok := packet.(*KeepAlivePacket) if !ok { - return -1, nil + return errors.New("invalid packet type") } buffer.WriteUint32(uint32(p.ClientIndex)) buffer.WriteUint32(uint32(p.MaxClients)) case ConnectionPayload: p, ok := packet.(*PayloadPacket) if !ok { - return -1, nil + return errors.New("invalid packet type") } buffer.WriteBytesN([]byte(p.PayloadData), int(p.PayloadBytes)) case ConnectionDisconnect: // ... } - encryptedFinish := buffer.Pos + return nil +} - // encrypt the per-packet packet written with the prefix byte, protocol id and version as the associated data. this must match to decrypt. +// used for encrypting the per-packet packet written with the prefix byte, protocol id and version as the associated data. this must match to decrypt. +func packetCryptData(prefixByte uint8, protocolId, sequence uint64) ([]byte, []byte) { additionalData := NewBuffer(VERSION_INFO_BYTES+8+1) additionalData.WriteBytesN([]byte(VERSION_INFO), VERSION_INFO_BYTES) additionalData.WriteUint64(protocolId) additionalData.WriteUint8(prefixByte) - nonce := NewBuffer(8) + nonce := NewBuffer(SizeUint64) nonce.WriteUint64(sequence) - - err := EncryptAead(&buffer.Buf[encryptedStart:encryptedFinish], additionalData.Bytes(), nonce.Bytes(), writePacketKey) - if err != nil { - return -1, err - } - return buffer.Pos + MAC_BYTES, nil + return additionalData.Buf, nonce.Buf } func ReadPacket(packetData []byte, packetLen int, sequence uint64, readPacketKey []byte, protocolId uint64, currentTimestamp uint64, privateKey, allowedPackets []byte, replayProtection *ReplayProtection) (Packet, error) { @@ -250,6 +256,7 @@ func ReadPacket(packetData []byte, packetLen int, sequence uint64, readPacketKey if PacketType(prefixByte) == ConnectionRequest { return readRequestPacket(packetBuffer, packetLen, protocolId, currentTimestamp, allowedPackets, privateKey) } + panic("ReadPacket should not get here") // *** encrypted packets *** if readPacketKey == nil { @@ -315,7 +322,7 @@ func ReadPacket(packetData []byte, packetLen int, sequence uint64, readPacketKey if err != nil { return nil, errors.New("ignored encrypted packet. encrypted payload is too small") } - + log.Printf("encryptedBuff: %#v\n", encryptedBuff) decryptedBuff, err := DecryptAead(encryptedBuff, additionalData.Bytes(), nonce.Bytes(), readPacketKey) if err != nil { return nil, errors.New("ignored encrypted packet. failed to decrypt: " + err.Error()) @@ -416,7 +423,7 @@ func readRequestPacket(packetBuffer *Buffer, packetLen int, protocolId, currentT return nil, errors.New("ignored connection request packet. packet type is not allowed") } - if packetLen != 1 + VERSION_INFO_BYTES + 8 + 8 + 8 + CONNECT_TOKEN_PRIVATE_BYTES { + if packetLen != 1 + VERSION_INFO_BYTES + 8 + 8 + 8 + CONNECT_TOKEN_PRIVATE_BYTES + MAC_BYTES { log.Printf("packetLen: %d, expected: %d\n", packetLen, 1 + VERSION_INFO_BYTES + 8 + 8 + 8 + CONNECT_TOKEN_PRIVATE_BYTES) return nil, errors.New("ignored connection request packet. bad packet length") } @@ -448,20 +455,24 @@ func readRequestPacket(packetBuffer *Buffer, packetLen int, protocolId, currentT if err != nil { return nil, err } + log.Printf("expireTime %d sequence %d\n", packet.ConnectTokenExpireTimestamp, packet.ConnectTokenSequence) + + if packetBuffer.Pos != 1 + VERSION_INFO_BYTES + 8 + 8 + 8 { + return nil, errors.New(" invalid length of packet buffer read") + } var tokenBuffer []byte - tokenBuffer, err = packetBuffer.GetBytes(CONNECT_TOKEN_PRIVATE_BYTES) + tokenBuffer, err = packetBuffer.GetBytes(CONNECT_TOKEN_PRIVATE_BYTES+MAC_BYTES) if err != nil { return nil, err } - log.Printf("len tokenBuffer: %d, pos: %d\n", len(tokenBuffer), packetBuffer.Pos) - log.Printf("tokenBuffer: %x %#v\n", packet.ConnectTokenExpireTimestamp, tokenBuffer) + log.Printf("len of tokenBuffer: %d hash: %x\n", len(tokenBuffer), sha1.Sum(tokenBuffer)) + packet.Token, err = ReadConnectToken(tokenBuffer, packet.ProtocolId, packet.ConnectTokenExpireTimestamp, packet.ConnectTokenSequence, privateKey) if err != nil { return nil, err } - packet.ConnectTokenData = packet.Token.PrivateData.TokenData.Buf return packet, nil } diff --git a/go/packet_test.go b/go/packet_test.go index b2ec051..9c83951 100644 --- a/go/packet_test.go +++ b/go/packet_test.go @@ -5,6 +5,7 @@ import ( "net" "time" "bytes" + "crypto/sha1" ) func TestReadPacket(t *testing.T) { @@ -12,27 +13,14 @@ func TestReadPacket(t *testing.T) { } func TestConnectionRequestPacket(t *testing.T) { - addr := net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: TEST_SERVER_PORT} - serverAddrs := make([]net.UDPAddr, 1) - serverAddrs[0] = addr - config := NewConfig(serverAddrs, TEST_CONNECT_TOKEN_EXPIRY, TEST_PROTOCOL_ID, TEST_PRIVATE_KEY) - - connectToken := NewConnectToken() - currentTimestamp := uint64(time.Now().Unix()) - - if err := connectToken.Generate(config, TEST_CLIENT_ID, currentTimestamp, TEST_SEQUENCE_START); err != nil { - t.Fatalf("error generating connect token: %s\n", err) + connectTokenKey, err := GenerateKey() + if err != nil { + t.Fatalf("error generating connect token key: %s\n", err) } - - // setup a connection request packet wrapping the encrypted connect token - inputPacket := &RequestPacket{} - inputPacket.VersionInfo = []byte(VERSION_INFO) - inputPacket.ProtocolId = TEST_PROTOCOL_ID - inputPacket.ConnectTokenExpireTimestamp = uint64(time.Now().Unix() + 30) - inputPacket.ConnectTokenSequence = TEST_SEQUENCE_START - inputPacket.ConnectTokenData = connectToken.PrivateData.TokenData.Bytes() - + inputPacket, decryptedToken := testBuildRequestPacket(connectTokenKey, t) + t.Logf("decrypted len: %#d\n", len(decryptedToken)) // write the connection request packet to a buffer + buffer := NewBuffer(2048) packetKey, err := GenerateKey() @@ -49,6 +37,7 @@ func TestConnectionRequestPacket(t *testing.T) { t.Fatalf("did not write any bytes for this packet") } + // read the connection request packet back in from the buffer (the connect token data is decrypted as part of the read packet validation) var sequence uint64 sequence = TEST_SEQUENCE_START @@ -59,9 +48,11 @@ func TestConnectionRequestPacket(t *testing.T) { } buffer.Reset() - outputPacket, err := ReadPacket(buffer.Buf, bytesWritten, sequence, packetKey, TEST_PROTOCOL_ID, uint64(time.Now().Unix()), TEST_PRIVATE_KEY, allowedPackets, nil) + //t.Logf("before read: %#v %d\n", buffer.Buf[:bytesWritten], len(buffer.Buf[:bytesWritten])) + outputPacket, err := ReadPacket(buffer.Buf, bytesWritten, sequence, packetKey, TEST_PROTOCOL_ID, uint64(time.Now().Unix()), connectTokenKey, allowedPackets, nil) if err != nil { t.Fatalf("error reading packet: %s\n", err) + } if outputPacket.GetType() != ConnectionRequest { @@ -89,8 +80,54 @@ func TestConnectionRequestPacket(t *testing.T) { t.Fatalf("ConnectTokenSequence did not match") } - if bytes.Compare(inputPacket.Token.PrivateData.TokenData.Buf, output.Token.PrivateData.TokenData.Buf) != 0 { + if bytes.Compare(decryptedToken, output.Token.PrivateData.TokenData.Buf) != 0 { t.Fatalf("TokenData did not match") } } + +func testBuildRequestPacket(connectTokenKey []byte, t *testing.T) (*RequestPacket, []byte) { + addr := net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: TEST_SERVER_PORT} + serverAddrs := make([]net.UDPAddr, 1) + serverAddrs[0] = addr + config := NewConfig(serverAddrs, TEST_CONNECT_TOKEN_EXPIRY, TEST_PROTOCOL_ID, TEST_PRIVATE_KEY) + + connectToken := NewConnectToken() + currentTimestamp := uint64(time.Now().Unix()) + expireTimestamp := uint64(time.Now().Unix()) + config.TokenExpiry + + if err := connectToken.Generate(config, TEST_CLIENT_ID, currentTimestamp, TEST_SEQUENCE_START); err != nil { + t.Fatalf("error generating connect token: %s\n", err) + } + + privateData, err := connectToken.Write() + if err != nil { + t.Fatalf("error writing private data: %s\n", err) + } + + if err := EncryptConnectTokenPrivate(&privateData, TEST_PROTOCOL_ID, expireTimestamp, TEST_SEQUENCE_START, connectTokenKey); err != nil { + t.Fatalf("error encrypting connect token private %s\n", err) + } + + t.Logf("after encrypt test: %x\n", sha1.Sum(privateData)) + + decryptedToken, err := DecryptConnectTokenPrivate(privateData, TEST_PROTOCOL_ID, expireTimestamp, TEST_SEQUENCE_START, connectTokenKey) + if err != nil { + t.Fatalf("error decrypting connect token: %s", err) + } + + t.Logf("build request private data len: %d\n", len(privateData)) + _, err = ReadConnectToken(privateData, TEST_PROTOCOL_ID, expireTimestamp, TEST_SEQUENCE_START, connectTokenKey) + if err != nil { + t.Fatalf("error reading connect token: %s\n", err) + } + // setup a connection request packet wrapping the encrypted connect token + inputPacket := &RequestPacket{} + inputPacket.VersionInfo = []byte(VERSION_INFO) + inputPacket.ProtocolId = TEST_PROTOCOL_ID + inputPacket.ConnectTokenExpireTimestamp = expireTimestamp + inputPacket.ConnectTokenSequence = TEST_SEQUENCE_START + inputPacket.Token = connectToken + inputPacket.ConnectTokenData = privateData + return inputPacket, decryptedToken +} \ No newline at end of file diff --git a/go/simulator.go b/go/simulator.go index 99ad84a..79411da 100644 --- a/go/simulator.go +++ b/go/simulator.go @@ -1 +1,12 @@ package netcode + + +const PACKET_SEND_RATE = 10.0 +const TIMEOUT_SECONDS = 5.0 +const NUM_DISCONNECT_PACKETS = 10 + + +type Context struct { + WritePacketKey []byte + ReadPacketKey []byte +} From 68d387279f0bdbf246e4dd54d63a30395baa8ce8 Mon Sep 17 00:00:00 2001 From: wirepair Date: Wed, 5 Apr 2017 22:46:31 +0900 Subject: [PATCH 07/11] implement packet and packet tests also start untangling connect token and connect token private --- go/connect_token.go | 177 ++++++------------ go/connect_token_private.go | 110 +++++++++++ go/connect_token_test.go | 42 +++++ go/crypto.go | 5 - go/packet.go | 48 +++-- go/packet_test.go | 354 +++++++++++++++++++++++++++++++++++- 6 files changed, 574 insertions(+), 162 deletions(-) create mode 100644 go/connect_token_private.go diff --git a/go/connect_token.go b/go/connect_token.go index 5b52d54..d79efcd 100644 --- a/go/connect_token.go +++ b/go/connect_token.go @@ -5,6 +5,7 @@ import ( "errors" "strconv" "log" + "go/token" ) const ( @@ -13,6 +14,7 @@ const ( ADDRESS_IPV6 ) +const CONNECT_TOKEN_BYTES = 2048 // Token used for connecting type ConnectToken struct { @@ -48,96 +50,24 @@ func (token *ConnectToken) ClientId() uint64 { return token.PrivateData.ClientId } - -type ConnectTokenPrivate struct { - ClientId uint64 - ServerAddrs []net.UDPAddr // list of server addresses this client may connect to - ClientKey []byte // client to server key - ServerKey []byte // server to client key - UserData []byte // used to store user data - TokenData *Buffer // used to store the serialized buffer -} - -func NewConnectTokenPrivate() *ConnectTokenPrivate { - p := &ConnectTokenPrivate{} - p.TokenData = NewBuffer(CONNECT_TOKEN_PRIVATE_BYTES) - return p -} - -func (p *ConnectTokenPrivate) Read() error { - var err error - - - if p.ClientId, err = p.TokenData.GetUint64(); err != nil { - return err - } - - if err := p.readServerData(); err != nil { - return err - } - - if p.ClientKey, err = p.TokenData.GetBytes(KEY_BYTES); err != nil { - return errors.New("error reading client to server key") - } - - if p.ServerKey, err = p.TokenData.GetBytes(KEY_BYTES); err != nil { - return errors.New("error reading server to client key") - } - - if p.UserData, err = p.TokenData.GetBytes(USER_DATA_BYTES); err != nil { - return errors.New("error reading user data") - } - - return nil -} - -func (p *ConnectTokenPrivate) readServerData() error { - var err error - var servers uint32 - var ipBytes []byte - - servers, err = p.TokenData.GetUint32() +func (token *ConnectToken) Write() ([]byte, error) { + buffer := NewBuffer(CONNECT_TOKEN_BYTES) + buffer.WriteBytes([]byte(VERSION_INFO)) + buffer.WriteUint64(token.ProtocolId) + buffer.WriteUint64(token.CreateTimestamp) + buffer.WriteUint64(token.ExpireTimestamp) + buffer.WriteUint64(token.Sequence) + + privateData, err := token.PrivateData.Write() if err != nil { - return err - } - - if servers <= 0 { - return errors.New("empty servers") - } - - if servers > MAX_SERVERS_PER_CONNECT { - log.Printf("got %d expected %d\n", servers, MAX_SERVERS_PER_CONNECT) - return errors.New("too many servers") + return nil, err } + buffer.WriteBytes(privateData) - p.ServerAddrs = make([]net.UDPAddr, servers) - - for i := 0; i < int(servers); i+=1 { - serverType, err := p.TokenData.GetUint8() - if err != nil { - return err - } - - if serverType == ADDRESS_IPV4 { - ipBytes, err = p.TokenData.GetBytes(4) - } else if serverType == ADDRESS_IPV6 { - ipBytes, err = p.TokenData.GetBytes(16) - } else { - return errors.New("unknown ip address") - } - - if err != nil { - return err - } - - ip := net.IP(ipBytes) - port, err := p.TokenData.GetUint16() - if err != nil { - return errors.New("invalid port") - } - p.ServerAddrs[i] = net.UDPAddr{IP: ip, Port: int(port)} + if err := writeServerData(buffer, token.PrivateData.ServerAddrs, token.PrivateData.ClientKey, token.PrivateData.ServerKey, token.PrivateData.UserData); err != nil { + return nil, err } - return nil + return buffer.Buf, nil } // Generates the token with the supplied configuration values @@ -167,45 +97,7 @@ func (token *ConnectToken) Generate(config *Config, clientId, currentTimestamp, return nil } -// Writes the token data to a byte slice and returns to caller -func (token *ConnectToken) Write() ([]byte, error) { - data := NewBuffer(CONNECT_TOKEN_PRIVATE_BYTES) - data.WriteUint64(token.PrivateData.ClientId) - data.WriteUint32(uint32(len(token.PrivateData.ServerAddrs))) - - for _, addr := range token.ServerAddresses() { - host, port, err := net.SplitHostPort(addr.String()) - if err != nil { - return nil, errors.New("invalid port for host: " + addr.String()) - } - - parsed := net.ParseIP(host) - if parsed == nil { - return nil, errors.New("invalid ip address") - } - - if len(parsed) == 4 { - data.WriteUint8(uint8(ADDRESS_IPV4)) - - } else { - data.WriteUint8(uint8(ADDRESS_IPV6)) - } - - for i := 0; i < len(parsed); i +=1 { - data.WriteUint8(parsed[i]) - } - p, err := strconv.ParseUint(port, 10, 16) - if err != nil { - return nil, err - } - data.WriteUint16(uint16(p)) - } - data.WriteBytesN(token.PrivateData.ClientKey, KEY_BYTES) - data.WriteBytesN(token.PrivateData.ServerKey, KEY_BYTES) - data.WriteBytesN(token.PrivateData.UserData, USER_DATA_BYTES) - return data.Buf, nil -} // Takes in a slice of bytes and generates a new ConnectToken after decryption. func ReadConnectToken(tokenBuffer []byte, protocolId, expireTimestamp, sequence uint64, privateKey []byte) (*ConnectToken, error) { @@ -255,4 +147,41 @@ func buildCryptData(protocolId, expireTimestamp, sequence uint64) ([]byte, []byt nonce.WriteUint64(sequence) return additionalData.Buf, nonce.Buf +} + +func writeServerData(buffer *Buffer, serverAddrs []net.UDPAddr, clientKey, serverKey, userData []byte) error { + buffer.WriteUint32(uint32(len(serverAddrs))) + + for _, addr := range serverAddrs { + host, port, err := net.SplitHostPort(addr.String()) + if err != nil { + return errors.New("invalid port for host: " + addr.String()) + } + + parsed := net.ParseIP(host) + if parsed == nil { + return errors.New("invalid ip address") + } + + if len(parsed) == 4 { + buffer.WriteUint8(uint8(ADDRESS_IPV4)) + + } else { + buffer.WriteUint8(uint8(ADDRESS_IPV6)) + } + + for i := 0; i < len(parsed); i +=1 { + buffer.WriteUint8(parsed[i]) + } + + p, err := strconv.ParseUint(port, 10, 16) + if err != nil { + return err + } + buffer.WriteUint16(uint16(p)) + } + buffer.WriteBytesN(clientKey, KEY_BYTES) + buffer.WriteBytesN(serverKey, KEY_BYTES) + buffer.WriteBytesN(userData, USER_DATA_BYTES) + return nil } \ No newline at end of file diff --git a/go/connect_token_private.go b/go/connect_token_private.go new file mode 100644 index 0000000..5ab8941 --- /dev/null +++ b/go/connect_token_private.go @@ -0,0 +1,110 @@ +package netcode + +import ( + "net" + "errors" + "strconv" + "log" +) + +type ConnectTokenPrivate struct { + ClientId uint64 + ServerAddrs []net.UDPAddr // list of server addresses this client may connect to + ClientKey []byte // client to server key + ServerKey []byte // server to client key + UserData []byte // used to store user data + TokenData *Buffer // used to store the serialized buffer +} + +func NewConnectTokenPrivate() *ConnectTokenPrivate { + p := &ConnectTokenPrivate{} + p.TokenData = NewBuffer(CONNECT_TOKEN_PRIVATE_BYTES) + return p +} + +// Reads the token properties from the internal TokenData buffer. +func (p *ConnectTokenPrivate) Read() error { + var err error + + if p.ClientId, err = p.TokenData.GetUint64(); err != nil { + return err + } + + if err := p.readServerData(); err != nil { + return err + } + + if p.ClientKey, err = p.TokenData.GetBytes(KEY_BYTES); err != nil { + return errors.New("error reading client to server key") + } + + if p.ServerKey, err = p.TokenData.GetBytes(KEY_BYTES); err != nil { + return errors.New("error reading server to client key") + } + + if p.UserData, err = p.TokenData.GetBytes(USER_DATA_BYTES); err != nil { + return errors.New("error reading user data") + } + + return nil +} + +func (p *ConnectTokenPrivate) readServerData() error { + var err error + var servers uint32 + var ipBytes []byte + + servers, err = p.TokenData.GetUint32() + if err != nil { + return err + } + + if servers <= 0 { + return errors.New("empty servers") + } + + if servers > MAX_SERVERS_PER_CONNECT { + log.Printf("got %d expected %d\n", servers, MAX_SERVERS_PER_CONNECT) + return errors.New("too many servers") + } + + p.ServerAddrs = make([]net.UDPAddr, servers) + + for i := 0; i < int(servers); i+=1 { + serverType, err := p.TokenData.GetUint8() + if err != nil { + return err + } + + if serverType == ADDRESS_IPV4 { + ipBytes, err = p.TokenData.GetBytes(4) + } else if serverType == ADDRESS_IPV6 { + ipBytes, err = p.TokenData.GetBytes(16) + } else { + return errors.New("unknown ip address") + } + + if err != nil { + return err + } + + ip := net.IP(ipBytes) + port, err := p.TokenData.GetUint16() + if err != nil { + return errors.New("invalid port") + } + p.ServerAddrs[i] = net.UDPAddr{IP: ip, Port: int(port)} + } + return nil +} + +// Writes the token data to a byte slice and returns to caller +func (token *ConnectTokenPrivate) Write() ([]byte, error) { + data := NewBuffer(CONNECT_TOKEN_PRIVATE_BYTES) + data.WriteUint64(token.ClientId) + + if err := writeServerData(data, token.ServerAddrs, token.ClientKey, token.ServerKey, token.UserData); err != nil { + return nil, err + } + return data.Buf, nil +} \ No newline at end of file diff --git a/go/connect_token_test.go b/go/connect_token_test.go index 1d8f23e..5add773 100644 --- a/go/connect_token_test.go +++ b/go/connect_token_test.go @@ -5,6 +5,7 @@ import ( "net" "bytes" "time" + "go/token" ) const ( @@ -58,6 +59,47 @@ func TestNewConnectToken(t *testing.T) { if bytes.Compare(private, private2) != 0 { t.Fatalf("encrypted private bits didn't match %v and %v\n", private, private2) } +} + +func TestConnectTokenPublic(t *testing.T) { + token1 := NewConnectToken() + server := net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 40000} + servers := make([]net.UDPAddr, 1) + servers[0] = server + + key, err := GenerateKey() + if err != nil { + t.Fatalf("error generating key %s\n", key) + } + + config := NewConfig(servers, TEST_CONNECT_TOKEN_EXPIRY, TEST_PROTOCOL_ID, key) + currentTimestamp := uint64(time.Now().Unix()) + + err = token1.Generate(config, TEST_CLIENT_ID, currentTimestamp, TEST_SEQUENCE_START) + if err != nil { + t.Fatalf("error generating and encrypting token") + } + + private, err := token1.Write() + if err != nil { + t.Fatalf("error writing token private data") + } + + // write it to a buffer + EncryptConnectTokenPrivate(&private, TEST_PROTOCOL_ID, uint64(currentTimestamp + config.TokenExpiry), TEST_SEQUENCE_START, config.PrivateKey) + + // set misc public token properties + token1.TimeoutSeconds = int(TIMEOUT_SECONDS) + + tokenData, err := token1.Write() + if err != nil { + t.Fatalf("error writing token: %s\n", err) + } + + + + + } func compareTokens(token1, token2 *ConnectToken, t *testing.T) { diff --git a/go/crypto.go b/go/crypto.go index 827ac43..43b210d 100644 --- a/go/crypto.go +++ b/go/crypto.go @@ -3,9 +3,6 @@ package netcode import ( "crypto/rand" "github.com/codahale/chacha20poly1305" - //"log" - "crypto/sha1" - "log" ) // Generates random bytes @@ -27,7 +24,6 @@ func EncryptAead(message *[]byte, additional, nonce, key []byte) error { return err } *message = aead.Seal(nil, nonce, *message, additional) - log.Printf("AFTER ENCRYPT: %x %x %x %x\n", sha1.Sum(*message), sha1.Sum(additional), sha1.Sum(nonce), sha1.Sum(key)) return nil } @@ -38,6 +34,5 @@ func DecryptAead(message []byte, additional, nonce, key []byte) ([]byte, error) if err != nil { return nil, err } - log.Printf("BEFORE DECRYPT: %x %x %x %x\n", sha1.Sum(message), sha1.Sum(additional), sha1.Sum(nonce), sha1.Sum(key)) return aead.Open(nil, nonce, message, additional) } \ No newline at end of file diff --git a/go/packet.go b/go/packet.go index ebfd9e5..8bf65f1 100644 --- a/go/packet.go +++ b/go/packet.go @@ -4,7 +4,6 @@ import ( "errors" "strconv" "log" - "crypto/sha1" ) const MAX_CLIENTS = 60 @@ -142,12 +141,12 @@ func WritePacket(packet Packet, buffer *Buffer, sequence uint64, writePacketKey buffer.WriteUint64(p.ConnectTokenExpireTimestamp) buffer.WriteUint64(p.ConnectTokenSequence) buffer.WriteBytes(p.ConnectTokenData) // write the encrypted connection token private data - if buffer.Pos != 1 + 13 + 8 + 8 + 8 + CONNECT_TOKEN_PRIVATE_BYTES + MAC_BYTES { + if buffer.Pos != 1 + 13 + 8 + 8 + 8 + CONNECT_TOKEN_PRIVATE_BYTES + MAC_BYTES { return -1, errors.New("invalid buffer size written") } return buffer.Pos, nil } - panic("WritePacket should not get here") + // *** encrypted packets *** prefixByte, err := writePacketPrefix(packet, buffer, sequence) if err != nil { @@ -161,20 +160,24 @@ func WritePacket(packet Packet, buffer *Buffer, sequence uint64, writePacketKey encryptedFinish := buffer.Pos additionalData, nonce := packetCryptData(prefixByte, protocolId, sequence) - + log.Printf("data to encrypt size: %d = %d - %d\n", encryptedFinish-encryptedStart, encryptedFinish, encryptedStart) // slice up the buffer for the bits we will encrypt encryptedBuffer := buffer.Buf[encryptedStart:encryptedFinish] if err := EncryptAead(&encryptedBuffer, additionalData, nonce, writePacketKey); err != nil { return -1, err } - return buffer.Pos + MAC_BYTES, nil + // hack to reset Pos to write in the encrypted buffer to avoid allocations/append() calls + buffer.Pos = encryptedStart + buffer.WriteBytes(encryptedBuffer) + log.Printf("buffer written so far plus mac: %d\n", buffer.Pos) + return buffer.Pos, nil // in c, we do Pos + MAC_BYTES but the WriteBytes will update Pos to include it } // write the prefix byte (this is a combination of the packet type and number of sequence bytes) func writePacketPrefix(p Packet, buffer *Buffer, sequence uint64) (uint8, error) { sequenceBytes := sequenceNumberBytesRequired(sequence) - if (sequenceBytes < 1 || sequenceBytes > 8) { + if sequenceBytes < 1 || sequenceBytes > 8 { return 0, errors.New("invalid sequence bytes, must be between [1-8]") } @@ -183,7 +186,8 @@ func writePacketPrefix(p Packet, buffer *Buffer, sequence uint64) (uint8, error) sequenceTemp := sequence - for i := 0; i < sequenceBytes; i+=1 { + var i uint8 + for ; i < sequenceBytes; i+=1 { buffer.WriteUint8(uint8(sequenceTemp & 0xFF)) sequenceTemp >>= 8 } @@ -221,7 +225,9 @@ func writePacketData(packet Packet, buffer *Buffer) error { if !ok { return errors.New("invalid packet type") } + log.Printf("writing %d payload bytes pre: %d\n", p.PayloadBytes, buffer.Pos) buffer.WriteBytesN([]byte(p.PayloadData), int(p.PayloadBytes)) + log.Printf("writing %d payload bytes post: %d\n", p.PayloadBytes, buffer.Pos) case ConnectionDisconnect: // ... } @@ -241,7 +247,6 @@ func packetCryptData(prefixByte uint8, protocolId, sequence uint64) ([]byte, []b } func ReadPacket(packetData []byte, packetLen int, sequence uint64, readPacketKey []byte, protocolId uint64, currentTimestamp uint64, privateKey, allowedPackets []byte, replayProtection *ReplayProtection) (Packet, error) { - if packetLen < 1 { return nil, errors.New("invalid buffer length") } @@ -256,9 +261,8 @@ func ReadPacket(packetData []byte, packetLen int, sequence uint64, readPacketKey if PacketType(prefixByte) == ConnectionRequest { return readRequestPacket(packetBuffer, packetLen, protocolId, currentTimestamp, allowedPackets, privateKey) } - panic("ReadPacket should not get here") - // *** encrypted packets *** + // *** encrypted packets *** if readPacketKey == nil { return nil, errors.New("empty packet key") } @@ -305,13 +309,7 @@ func ReadPacket(packetData []byte, packetLen int, sequence uint64, readPacketKey } // decrypt the per-packet type data - additionalData := NewBuffer(VERSION_INFO_BYTES+8+1) - additionalData.WriteBytes([]byte(VERSION_INFO)) - additionalData.WriteUint64(protocolId) - additionalData.WriteUint8(prefixByte) - - nonce := NewBuffer(SizeUint64) - nonce.WriteUint64(sequence) + additionalData, nonce := packetCryptData(prefixByte, protocolId, sequence) encryptedSize := packetLen - packetBuffer.Pos if encryptedSize < MAC_BYTES { @@ -322,8 +320,8 @@ func ReadPacket(packetData []byte, packetLen int, sequence uint64, readPacketKey if err != nil { return nil, errors.New("ignored encrypted packet. encrypted payload is too small") } - log.Printf("encryptedBuff: %#v\n", encryptedBuff) - decryptedBuff, err := DecryptAead(encryptedBuff, additionalData.Bytes(), nonce.Bytes(), readPacketKey) + + decryptedBuff, err := DecryptAead(encryptedBuff, additionalData, nonce, readPacketKey) if err != nil { return nil, errors.New("ignored encrypted packet. failed to decrypt: " + err.Error()) } @@ -334,11 +332,12 @@ func ReadPacket(packetData []byte, packetLen int, sequence uint64, readPacketKey return processPacket(PacketType(packetType), decryptedBuff, decryptedSize) } +// Processes the packet after decryption has occurred. func processPacket(packetType PacketType, decrypted []byte, decryptedSize int) (Packet, error) { var err error decryptedBuff := NewBufferFromBytes(decrypted) - switch (packetType) { + switch packetType { case ConnectionDenied: if decryptedSize != 0 { return nil, errors.New("ignored connection denied packet. decrypted packet data is wrong size") @@ -424,7 +423,6 @@ func readRequestPacket(packetBuffer *Buffer, packetLen int, protocolId, currentT } if packetLen != 1 + VERSION_INFO_BYTES + 8 + 8 + 8 + CONNECT_TOKEN_PRIVATE_BYTES + MAC_BYTES { - log.Printf("packetLen: %d, expected: %d\n", packetLen, 1 + VERSION_INFO_BYTES + 8 + 8 + 8 + CONNECT_TOKEN_PRIVATE_BYTES) return nil, errors.New("ignored connection request packet. bad packet length") } @@ -455,7 +453,6 @@ func readRequestPacket(packetBuffer *Buffer, packetLen int, protocolId, currentT if err != nil { return nil, err } - log.Printf("expireTime %d sequence %d\n", packet.ConnectTokenExpireTimestamp, packet.ConnectTokenSequence) if packetBuffer.Pos != 1 + VERSION_INFO_BYTES + 8 + 8 + 8 { return nil, errors.New(" invalid length of packet buffer read") @@ -466,7 +463,6 @@ func readRequestPacket(packetBuffer *Buffer, packetLen int, protocolId, currentT if err != nil { return nil, err } - log.Printf("len of tokenBuffer: %d hash: %x\n", len(tokenBuffer), sha1.Sum(tokenBuffer)) packet.Token, err = ReadConnectToken(tokenBuffer, packet.ProtocolId, packet.ConnectTokenExpireTimestamp, packet.ConnectTokenSequence, privateKey) if err != nil { @@ -477,12 +473,12 @@ func readRequestPacket(packetBuffer *Buffer, packetLen int, protocolId, currentT } -func sequenceNumberBytesRequired(sequence uint64) int { +func sequenceNumberBytesRequired(sequence uint64) uint8 { var mask uint64 mask = 0xFF00000000000000 - i := 0 + var i uint8 for ; i < 7; i+=1 { - if (sequence & mask == 0) { + if (sequence & mask != 0) { break } mask >>= 8 diff --git a/go/packet_test.go b/go/packet_test.go index 9c83951..e37b67c 100644 --- a/go/packet_test.go +++ b/go/packet_test.go @@ -5,20 +5,66 @@ import ( "net" "time" "bytes" - "crypto/sha1" ) func TestReadPacket(t *testing.T) { } +func TestSequence(t *testing.T) { + seq := sequenceNumberBytesRequired(0) + if seq != 1 { + t.Fatal("expected 0, got: ", seq) + } + + seq = sequenceNumberBytesRequired(0x11) + if seq != 1 { + t.Fatal("expected 1, got: ", seq) + } + + seq = sequenceNumberBytesRequired(0x1122) + if seq != 2 { + t.Fatal("expected 2, got: ", seq) + } + + seq = sequenceNumberBytesRequired(0x112233) + if seq != 3 { + t.Fatal("expected 3, got: ", seq) + } + + seq = sequenceNumberBytesRequired(0x11223344) + if seq != 4 { + t.Fatal("expected 4, got: ", seq) + } + + seq = sequenceNumberBytesRequired(0x1122334455) + if seq != 5 { + t.Fatal("expected 5, got: ", seq) + } + + seq = sequenceNumberBytesRequired(0x112233445566) + if seq != 6 { + t.Fatal("expected 6, got: ", seq) + } + + seq = sequenceNumberBytesRequired(0x11223344556677) + if seq != 7 { + t.Fatal("expected 7, got: ", seq) + } + + seq = sequenceNumberBytesRequired(0x1122334455667788) + if seq != 8 { + t.Fatal("expected 8, got: ", seq) + } + +} + func TestConnectionRequestPacket(t *testing.T) { connectTokenKey, err := GenerateKey() if err != nil { t.Fatalf("error generating connect token key: %s\n", err) } inputPacket, decryptedToken := testBuildRequestPacket(connectTokenKey, t) - t.Logf("decrypted len: %#d\n", len(decryptedToken)) // write the connection request packet to a buffer buffer := NewBuffer(2048) @@ -37,7 +83,6 @@ func TestConnectionRequestPacket(t *testing.T) { t.Fatalf("did not write any bytes for this packet") } - // read the connection request packet back in from the buffer (the connect token data is decrypted as part of the read packet validation) var sequence uint64 sequence = TEST_SEQUENCE_START @@ -48,7 +93,6 @@ func TestConnectionRequestPacket(t *testing.T) { } buffer.Reset() - //t.Logf("before read: %#v %d\n", buffer.Buf[:bytesWritten], len(buffer.Buf[:bytesWritten])) outputPacket, err := ReadPacket(buffer.Buf, bytesWritten, sequence, packetKey, TEST_PROTOCOL_ID, uint64(time.Now().Unix()), connectTokenKey, allowedPackets, nil) if err != nil { t.Fatalf("error reading packet: %s\n", err) @@ -83,7 +127,305 @@ func TestConnectionRequestPacket(t *testing.T) { if bytes.Compare(decryptedToken, output.Token.PrivateData.TokenData.Buf) != 0 { t.Fatalf("TokenData did not match") } +} + +func TestConnectionDeniedPacket(t *testing.T) { + // setup a connection denied packet + inputPacket := &DeniedPacket{} + + buffer := NewBuffer(MAX_PACKET_BYTES) + + packetKey, err := GenerateKey() + if err != nil { + t.Fatalf("error generating key") + } + + // write the packet to a buffer + bytesWritten, err := WritePacket(inputPacket, buffer, TEST_SEQUENCE_START, packetKey, TEST_PROTOCOL_ID) + if err != nil { + t.Fatalf("error writing packet: %s\n", err) + } + + if bytesWritten <= 0 { + t.Fatalf("did not write any bytes for this packet") + } + + var sequence uint64 + sequence = TEST_SEQUENCE_START + + allowedPackets := make([]byte, ConnectionNumPackets) + for i := 0; i < len(allowedPackets); i+=1 { + allowedPackets[i] = 1 + } + + buffer.Reset() + outputPacket, err := ReadPacket(buffer.Buf, bytesWritten, sequence, packetKey, TEST_PROTOCOL_ID, uint64(time.Now().Unix()), nil, allowedPackets, nil) + if err != nil { + t.Fatalf("error reading packet: %s\n", err) + } + if outputPacket.GetType() != ConnectionDenied { + t.Fatalf("did not get a denied packet after read") + } +} + +func TestConnectionChallengePacket(t *testing.T) { + var err error + + // setup a connection challenge packet + inputPacket := &ChallengePacket{} + inputPacket.ChallengeTokenSequence = 0 + inputPacket.ChallengeTokenData, err = RandomBytes(CHALLENGE_TOKEN_BYTES) + if err != nil { + t.Fatalf("error generating random bytes") + } + + buffer := NewBuffer(MAX_PACKET_BYTES) + + packetKey, err := GenerateKey() + if err != nil { + t.Fatalf("error generating key") + } + + // write the packet to a buffer + bytesWritten, err := WritePacket(inputPacket, buffer, TEST_SEQUENCE_START, packetKey, TEST_PROTOCOL_ID) + if err != nil { + t.Fatalf("error writing packet: %s\n", err) + } + + if bytesWritten <= 0 { + t.Fatalf("did not write any bytes for this packet") + } + + var sequence uint64 + sequence = TEST_SEQUENCE_START + + allowedPackets := make([]byte, ConnectionNumPackets) + for i := 0; i < len(allowedPackets); i+=1 { + allowedPackets[i] = 1 + } + + buffer.Reset() + outputPacket, err := ReadPacket(buffer.Buf, bytesWritten, sequence, packetKey, TEST_PROTOCOL_ID, uint64(time.Now().Unix()), nil, allowedPackets, nil) + if err != nil { + t.Fatalf("error reading packet: %s\n", err) + } + + challenge, ok := outputPacket.(*ChallengePacket) + if !ok { + t.Fatalf("did not get a challenge packet after read") + } + + if inputPacket.ChallengeTokenSequence != challenge.ChallengeTokenSequence { + t.Fatalf("input and output sequence differed, expected %d got %d\n", inputPacket.ChallengeTokenSequence, challenge.ChallengeTokenSequence) + } + + if bytes.Compare(inputPacket.ChallengeTokenData, challenge.ChallengeTokenData) != 0 { + t.Fatalf("challenge token data was not equal\n") + } +} + +func TestConnectionResponsePacket(t *testing.T) { + var err error + + // setup a connection challenge packet + inputPacket := &ResponsePacket{} + inputPacket.ChallengeTokenSequence = 0 + inputPacket.ChallengeTokenData, err = RandomBytes(CHALLENGE_TOKEN_BYTES) + if err != nil { + t.Fatalf("error generating random bytes") + } + + buffer := NewBuffer(MAX_PACKET_BYTES) + + packetKey, err := GenerateKey() + if err != nil { + t.Fatalf("error generating key") + } + + // write the packet to a buffer + bytesWritten, err := WritePacket(inputPacket, buffer, TEST_SEQUENCE_START, packetKey, TEST_PROTOCOL_ID) + if err != nil { + t.Fatalf("error writing packet: %s\n", err) + } + + if bytesWritten <= 0 { + t.Fatalf("did not write any bytes for this packet") + } + + var sequence uint64 + sequence = TEST_SEQUENCE_START + + allowedPackets := make([]byte, ConnectionNumPackets) + for i := 0; i < len(allowedPackets); i+=1 { + allowedPackets[i] = 1 + } + + buffer.Reset() + outputPacket, err := ReadPacket(buffer.Buf, bytesWritten, sequence, packetKey, TEST_PROTOCOL_ID, uint64(time.Now().Unix()), nil, allowedPackets, nil) + if err != nil { + t.Fatalf("error reading packet: %s\n", err) + } + + response, ok := outputPacket.(*ResponsePacket) + if !ok { + t.Fatalf("did not get a response packet after read") + } + + if inputPacket.ChallengeTokenSequence != response.ChallengeTokenSequence { + t.Fatalf("input and output sequence differed, expected %d got %d\n", inputPacket.ChallengeTokenSequence, response.ChallengeTokenSequence) + } + + if bytes.Compare(inputPacket.ChallengeTokenData, response.ChallengeTokenData) != 0 { + t.Fatalf("challenge token data was not equal\n") + } +} + + +func TestConnectionKeepAlivePacket(t *testing.T) { + var err error + + // setup a connection challenge packet + inputPacket := &KeepAlivePacket{} + inputPacket.ClientIndex = 10 + inputPacket.MaxClients = 16 + + buffer := NewBuffer(MAX_PACKET_BYTES) + + packetKey, err := GenerateKey() + if err != nil { + t.Fatalf("error generating key") + } + + // write the packet to a buffer + bytesWritten, err := WritePacket(inputPacket, buffer, TEST_SEQUENCE_START, packetKey, TEST_PROTOCOL_ID) + if err != nil { + t.Fatalf("error writing packet: %s\n", err) + } + + if bytesWritten <= 0 { + t.Fatalf("did not write any bytes for this packet") + } + + var sequence uint64 + sequence = TEST_SEQUENCE_START + + allowedPackets := make([]byte, ConnectionNumPackets) + for i := 0; i < len(allowedPackets); i+=1 { + allowedPackets[i] = 1 + } + + buffer.Reset() + outputPacket, err := ReadPacket(buffer.Buf, bytesWritten, sequence, packetKey, TEST_PROTOCOL_ID, uint64(time.Now().Unix()), nil, allowedPackets, nil) + if err != nil { + t.Fatalf("error reading packet: %s\n", err) + } + + keepalive, ok := outputPacket.(*KeepAlivePacket) + if !ok { + t.Fatalf("did not get a response packet after read") + } + + if inputPacket.ClientIndex != keepalive.ClientIndex { + t.Fatalf("input and output index differed, expected %d got %d\n", inputPacket.ClientIndex, keepalive.ClientIndex) + } + + if inputPacket.MaxClients != keepalive.MaxClients { + t.Fatalf("input and output maxclients differed, expected %d got %d\n", inputPacket.MaxClients, keepalive.MaxClients) + } +} + +func TestConnectionPayloadPacket(t *testing.T) { + var err error + + // setup a connection challenge packet + inputPacket := NewPayloadPacket(MAX_PAYLOAD_BYTES) + inputPacket.PayloadData, err = RandomBytes(MAX_PAYLOAD_BYTES) + if err != nil { + t.Fatalf("error generating random payload data: %s\n", err) + } + + buffer := NewBuffer(MAX_PACKET_BYTES) + + packetKey, err := GenerateKey() + if err != nil { + t.Fatalf("error generating key") + } + + // write the packet to a buffer + bytesWritten, err := WritePacket(inputPacket, buffer, TEST_SEQUENCE_START, packetKey, TEST_PROTOCOL_ID) + if err != nil { + t.Fatalf("error writing packet: %s\n", err) + } + + if bytesWritten <= 0 { + t.Fatalf("did not write any bytes for this packet") + } + + var sequence uint64 + sequence = TEST_SEQUENCE_START + + allowedPackets := make([]byte, ConnectionNumPackets) + for i := 0; i < len(allowedPackets); i+=1 { + allowedPackets[i] = 1 + } + buffer.Reset() + outputPacket, err := ReadPacket(buffer.Buf, bytesWritten, sequence, packetKey, TEST_PROTOCOL_ID, uint64(time.Now().Unix()), nil, allowedPackets, nil) + if err != nil { + t.Fatalf("error reading packet: %s\n", err) + } + + payload, ok := outputPacket.(*PayloadPacket) + if !ok { + t.Fatalf("did not get a payload packet after read") + } + + if inputPacket.PayloadBytes != payload.PayloadBytes { + t.Fatalf("input and output index differed, expected %d got %d\n", inputPacket.PayloadBytes, payload.PayloadBytes) + } + + if bytes.Compare(inputPacket.PayloadData, payload.PayloadData) != 0 { + t.Fatalf("input and output payload differed, expected %v got %v\n", inputPacket.PayloadData, payload.PayloadData) + } +} + +func TestDisconnectPacket(t *testing.T) { + inputPacket := &DisconnectPacket{} + buffer := NewBuffer(MAX_PACKET_BYTES) + + packetKey, err := GenerateKey() + if err != nil { + t.Fatalf("error generating key") + } + + // write the packet to a buffer + bytesWritten, err := WritePacket(inputPacket, buffer, TEST_SEQUENCE_START, packetKey, TEST_PROTOCOL_ID) + if err != nil { + t.Fatalf("error writing packet: %s\n", err) + } + + if bytesWritten <= 0 { + t.Fatalf("did not write any bytes for this packet") + } + + var sequence uint64 + sequence = TEST_SEQUENCE_START + + allowedPackets := make([]byte, ConnectionNumPackets) + for i := 0; i < len(allowedPackets); i+=1 { + allowedPackets[i] = 1 + } + + buffer.Reset() + outputPacket, err := ReadPacket(buffer.Buf, bytesWritten, sequence, packetKey, TEST_PROTOCOL_ID, uint64(time.Now().Unix()), nil, allowedPackets, nil) + if err != nil { + t.Fatalf("error reading packet: %s\n", err) + } + + _, ok := outputPacket.(*DisconnectPacket) + if !ok { + t.Fatalf("did not get a disconnect packet after read") + } } func testBuildRequestPacket(connectTokenKey []byte, t *testing.T) (*RequestPacket, []byte) { @@ -109,18 +451,16 @@ func testBuildRequestPacket(connectTokenKey []byte, t *testing.T) (*RequestPacke t.Fatalf("error encrypting connect token private %s\n", err) } - t.Logf("after encrypt test: %x\n", sha1.Sum(privateData)) - decryptedToken, err := DecryptConnectTokenPrivate(privateData, TEST_PROTOCOL_ID, expireTimestamp, TEST_SEQUENCE_START, connectTokenKey) if err != nil { t.Fatalf("error decrypting connect token: %s", err) } - t.Logf("build request private data len: %d\n", len(privateData)) _, err = ReadConnectToken(privateData, TEST_PROTOCOL_ID, expireTimestamp, TEST_SEQUENCE_START, connectTokenKey) if err != nil { t.Fatalf("error reading connect token: %s\n", err) } + // setup a connection request packet wrapping the encrypted connect token inputPacket := &RequestPacket{} inputPacket.VersionInfo = []byte(VERSION_INFO) From cbe92145252af52e89296e39a2354b0013e639ce Mon Sep 17 00:00:00 2001 From: wirepair Date: Thu, 6 Apr 2017 16:44:10 +0900 Subject: [PATCH 08/11] make packets a bit more object-y --- go/client.go | 7 +- go/config.go | 14 +- go/connect_token.go | 203 ++++------ go/connect_token_private.go | 128 +++--- go/connect_token_private_test.go | 92 +++++ go/connect_token_shared.go | 126 ++++++ go/connect_token_test.go | 141 +++---- go/packet.go | 655 +++++++++++++++++-------------- go/packet_test.go | 204 ++++------ 9 files changed, 862 insertions(+), 708 deletions(-) create mode 100644 go/connect_token_private_test.go create mode 100644 go/connect_token_shared.go diff --git a/go/client.go b/go/client.go index 31e4e7e..3a00d16 100644 --- a/go/client.go +++ b/go/client.go @@ -3,11 +3,10 @@ package netcode import ( "crypto/rand" "math/big" - "time" ) type Client struct { - Id uint64 + Id uint64 config *Config } @@ -23,10 +22,9 @@ func (c *Client) Init(sequence uint64) error { } c.Id = id.Uint64() - currentTimestamp := uint64(time.Now().Unix()) token := NewConnectToken() - if err := token.Generate(c.config, c.Id, currentTimestamp, sequence); err != nil { + if err := token.Generate(c.config, sequence); err != nil { return err } @@ -36,4 +34,3 @@ func (c *Client) Init(sequence uint64) error { func (c *Client) Connect() error { return nil } - diff --git a/go/config.go b/go/config.go index 13442fd..18db612 100644 --- a/go/config.go +++ b/go/config.go @@ -3,17 +3,21 @@ package netcode import "net" type Config struct { - ServerAddrs []net.UDPAddr - TokenExpiry uint64 - ProtocolId uint64 - PrivateKey []byte + ClientId uint64 + ServerAddrs []net.UDPAddr + TokenExpiry uint64 + TimeoutSeconds uint32 + ProtocolId uint64 + PrivateKey []byte } -func NewConfig(serverAddrs []net.UDPAddr, expiry, protocolId uint64, privateKey []byte) *Config { +func NewConfig(serverAddrs []net.UDPAddr, timeoutSeconds uint32, expiry, clientId, protocolId uint64, privateKey []byte) *Config { c := &Config{} + c.ClientId = clientId c.ServerAddrs = serverAddrs c.TokenExpiry = expiry c.ProtocolId = protocolId c.PrivateKey = privateKey + c.TimeoutSeconds = timeoutSeconds return c } diff --git a/go/connect_token.go b/go/connect_token.go index d79efcd..c2dd9f9 100644 --- a/go/connect_token.go +++ b/go/connect_token.go @@ -1,11 +1,9 @@ package netcode import ( - "net" "errors" - "strconv" - "log" - "go/token" + "strings" + "time" ) const ( @@ -18,170 +16,129 @@ const CONNECT_TOKEN_BYTES = 2048 // Token used for connecting type ConnectToken struct { - VersionInfo []byte - ProtocolId uint64 + sharedTokenData + VersionInfo []byte + ProtocolId uint64 CreateTimestamp uint64 ExpireTimestamp uint64 - Sequence uint64 - PrivateData *ConnectTokenPrivate - TimeoutSeconds int + Sequence uint64 + PrivateData *ConnectTokenPrivate + TimeoutSeconds uint32 } -// create a new empty token +// Create a new empty token and empty private token func NewConnectToken() *ConnectToken { token := &ConnectToken{} + token.PrivateData = NewConnectTokenPrivate() return token } -func (token *ConnectToken) ServerKey() []byte { - return token.PrivateData.ServerKey -} +// Generates the token and private token data with the supplied config values and sequence id. +func (token *ConnectToken) Generate(config *Config, sequence uint64) error { + token.CreateTimestamp = uint64(time.Now().Unix()) + token.ExpireTimestamp = token.CreateTimestamp + config.TokenExpiry + token.VersionInfo = []byte(VERSION_INFO) + token.ProtocolId = config.ProtocolId + token.TimeoutSeconds = config.TimeoutSeconds + token.Sequence = sequence -func (token *ConnectToken) ClientKey() []byte { - return token.PrivateData.ClientKey -} + userData, err := RandomBytes(USER_DATA_BYTES) + if err != nil { + return err + } -// list of server addresses this client may connect to -func (token *ConnectToken) ServerAddresses() []net.UDPAddr { - return token.PrivateData.ServerAddrs -} + if err = token.PrivateData.Generate(config, userData); err != nil { + return err + } + + // copy directly from the private token since we don't want to generate 2 different keys + token.ClientKey = token.PrivateData.ClientKey + token.ServerKey = token.PrivateData.ServerKey + token.ServerAddrs = token.PrivateData.ServerAddrs + + if _, err = token.PrivateData.Write(); err != nil { + return err + } + + if err = token.PrivateData.Encrypt(token.ProtocolId, token.ExpireTimestamp, sequence, config.PrivateKey); err != nil { + return err + } -func (token *ConnectToken) ClientId() uint64 { - return token.PrivateData.ClientId + return nil } +// Writes the ConnectToken and previously encrypted ConnectTokenPrivate data to a byte slice func (token *ConnectToken) Write() ([]byte, error) { buffer := NewBuffer(CONNECT_TOKEN_BYTES) - buffer.WriteBytes([]byte(VERSION_INFO)) + buffer.WriteBytes(token.VersionInfo) buffer.WriteUint64(token.ProtocolId) buffer.WriteUint64(token.CreateTimestamp) buffer.WriteUint64(token.ExpireTimestamp) buffer.WriteUint64(token.Sequence) - privateData, err := token.PrivateData.Write() - if err != nil { - return nil, err - } - buffer.WriteBytes(privateData) + // assumes private token has already been encrypted + buffer.WriteBytes(token.PrivateData.Buffer()) - if err := writeServerData(buffer, token.PrivateData.ServerAddrs, token.PrivateData.ClientKey, token.PrivateData.ServerKey, token.PrivateData.UserData); err != nil { + if err := token.WriteShared(buffer); err != nil { return nil, err } + + buffer.WriteUint32(token.TimeoutSeconds) return buffer.Buf, nil } -// Generates the token with the supplied configuration values -func (token *ConnectToken) Generate(config *Config, clientId, currentTimestamp, sequence uint64) error { +// Takes in a slice of decrypted connect token bytes and generates a new ConnectToken. +// Note that the ConnectTokenPrivate is still encrypted at this point. +func ReadConnectToken(tokenBuffer []byte) (*ConnectToken, error) { var err error + var privateData []byte - privateData := &ConnectTokenPrivate{} - token.PrivateData = privateData - - privateData.ClientId = clientId - privateData.ServerAddrs = config.ServerAddrs + buffer := NewBufferFromBytes(tokenBuffer) + token := NewConnectToken() - if privateData.UserData, err = RandomBytes(USER_DATA_BYTES); err != nil { - return err + if token.VersionInfo, err = buffer.GetBytes(VERSION_INFO_BYTES); err != nil { + return nil, errors.New("read connect token data has bad version info " + err.Error()) } - if privateData.ClientKey, err = GenerateKey(); err != nil { - return err + if strings.Compare(VERSION_INFO, string(token.VersionInfo)) != 0 { + return nil, errors.New("read connect token data has bad version info: " + string(token.VersionInfo)) } - if privateData.ServerKey, err = GenerateKey(); err != nil { - return err + if token.ProtocolId, err = buffer.GetUint64(); err != nil { + return nil, errors.New("read connect token data has bad protocol id " + err.Error()) } - token.CreateTimestamp = currentTimestamp - token.ExpireTimestamp = token.CreateTimestamp + config.TokenExpiry - return nil -} - - - -// Takes in a slice of bytes and generates a new ConnectToken after decryption. -func ReadConnectToken(tokenBuffer []byte, protocolId, expireTimestamp, sequence uint64, privateKey []byte) (*ConnectToken, error) { - var err error - var privateData []byte - - token := NewConnectToken() - token.ExpireTimestamp = expireTimestamp + if token.CreateTimestamp, err = buffer.GetUint64(); err != nil { + return nil, errors.New("read connect token data has bad create timestamp " + err.Error()) + } - if privateData, err = DecryptConnectTokenPrivate(tokenBuffer, protocolId, expireTimestamp, sequence, privateKey); err != nil { - return nil, errors.New("error decrypting connection token: " + err.Error()) + if token.ExpireTimestamp, err = buffer.GetUint64(); err != nil { + return nil, errors.New("read connect token data has bad expire timestamp " + err.Error()) } - private := NewConnectTokenPrivate() - private.TokenData = NewBufferFromBytes(privateData) - if err = private.Read(); err != nil { - return nil, err + if token.CreateTimestamp > token.ExpireTimestamp { + return nil, errors.New("expire timestamp is > create timestamp") } - token.PrivateData = private - return token, nil -} -// Encrypts the supplied buffer for the token private parts -func EncryptConnectTokenPrivate(privateData *[]byte, protocolId, expireTimestamp, sequence uint64, privateKey []byte) error { - additionalData, nonce := buildCryptData(protocolId, expireTimestamp, sequence) + if token.Sequence, err = buffer.GetUint64(); err != nil { + return nil, errors.New("read connect data has bad sequence " + err.Error()) + } - if err := EncryptAead(privateData, additionalData, nonce, privateKey); err != nil { - return err + if privateData, err = buffer.GetBytes(CONNECT_TOKEN_PRIVATE_BYTES + MAC_BYTES); err != nil { + return nil, errors.New("read connect data has bad private data " + err.Error()) } - return nil -} -// Decrypts the supplied privateData buffer and generates a new ConnectTokenPrivate instance -func DecryptConnectTokenPrivate(privateData []byte, protocolId, expireTimestamp, sequence uint64, privateKey []byte) ([]byte, error) { - additionalData, nonce := buildCryptData(protocolId, expireTimestamp, sequence) - return DecryptAead(privateData, additionalData, nonce, privateKey) -} + // it is still encrypted at this point. + token.PrivateData.TokenData = NewBufferFromBytes(privateData) -// builds the additional data and nonce necessary for encryption and decryption. -func buildCryptData(protocolId, expireTimestamp, sequence uint64) ([]byte, []byte) { - additionalData := NewBuffer(VERSION_INFO_BYTES+8+8) - additionalData.WriteBytes([]byte(VERSION_INFO)) - additionalData.WriteUint64(protocolId) - additionalData.WriteUint64(expireTimestamp) + // reads servers, client and server key + if err = token.ReadShared(buffer); err != nil { + return nil, err + } - nonce := NewBuffer(SizeUint64) - nonce.WriteUint64(sequence) + if token.TimeoutSeconds, err = buffer.GetUint32(); err != nil { + return nil, err + } - return additionalData.Buf, nonce.Buf + return token, nil } - -func writeServerData(buffer *Buffer, serverAddrs []net.UDPAddr, clientKey, serverKey, userData []byte) error { - buffer.WriteUint32(uint32(len(serverAddrs))) - - for _, addr := range serverAddrs { - host, port, err := net.SplitHostPort(addr.String()) - if err != nil { - return errors.New("invalid port for host: " + addr.String()) - } - - parsed := net.ParseIP(host) - if parsed == nil { - return errors.New("invalid ip address") - } - - if len(parsed) == 4 { - buffer.WriteUint8(uint8(ADDRESS_IPV4)) - - } else { - buffer.WriteUint8(uint8(ADDRESS_IPV6)) - } - - for i := 0; i < len(parsed); i +=1 { - buffer.WriteUint8(parsed[i]) - } - - p, err := strconv.ParseUint(port, 10, 16) - if err != nil { - return err - } - buffer.WriteUint16(uint16(p)) - } - buffer.WriteBytesN(clientKey, KEY_BYTES) - buffer.WriteBytesN(serverKey, KEY_BYTES) - buffer.WriteBytesN(userData, USER_DATA_BYTES) - return nil -} \ No newline at end of file diff --git a/go/connect_token_private.go b/go/connect_token_private.go index 5ab8941..b153960 100644 --- a/go/connect_token_private.go +++ b/go/connect_token_private.go @@ -1,27 +1,45 @@ package netcode import ( - "net" "errors" - "strconv" - "log" ) +// The private parts of a connect token type ConnectTokenPrivate struct { - ClientId uint64 - ServerAddrs []net.UDPAddr // list of server addresses this client may connect to - ClientKey []byte // client to server key - ServerKey []byte // server to client key - UserData []byte // used to store user data - TokenData *Buffer // used to store the serialized buffer + sharedTokenData // holds the server addresses, client <-> server keys + ClientId uint64 // id for this token + UserData []byte // used to store user data + TokenData *Buffer // used to store the serialized/encrypted buffer } +// Create a new connect token private with an empty TokenData buffer func NewConnectTokenPrivate() *ConnectTokenPrivate { p := &ConnectTokenPrivate{} p.TokenData = NewBuffer(CONNECT_TOKEN_PRIVATE_BYTES) return p } +// Create a new connect token private with an pre-set, encrypted buffer +// Caller is expected to call Decrypt() and Read() to set the instances properties +func NewConnectTokenPrivateEncrypted(buffer []byte) *ConnectTokenPrivate { + p := &ConnectTokenPrivate{} + p.TokenData = NewBufferFromBytes(buffer) + return p +} + +// Helper to return the internal []byte of the private data +func (p *ConnectTokenPrivate) Buffer() []byte { + return p.TokenData.Buf +} + +// Reads the configuration values to set various properties of this private token data +// and requires a supplied userData slice. +func (p *ConnectTokenPrivate) Generate(config *Config, userData []byte) error { + p.ClientId = config.ClientId + p.UserData = userData + return p.GenerateShared(config) +} + // Reads the token properties from the internal TokenData buffer. func (p *ConnectTokenPrivate) Read() error { var err error @@ -30,18 +48,10 @@ func (p *ConnectTokenPrivate) Read() error { return err } - if err := p.readServerData(); err != nil { + if err = p.ReadShared(p.TokenData); err != nil { return err } - if p.ClientKey, err = p.TokenData.GetBytes(KEY_BYTES); err != nil { - return errors.New("error reading client to server key") - } - - if p.ServerKey, err = p.TokenData.GetBytes(KEY_BYTES); err != nil { - return errors.New("error reading server to client key") - } - if p.UserData, err = p.TokenData.GetBytes(USER_DATA_BYTES); err != nil { return errors.New("error reading user data") } @@ -49,62 +59,50 @@ func (p *ConnectTokenPrivate) Read() error { return nil } -func (p *ConnectTokenPrivate) readServerData() error { - var err error - var servers uint32 - var ipBytes []byte +// Writes the token data to our TokenData buffer and alternatively returns the buffer to caller. +func (p *ConnectTokenPrivate) Write() ([]byte, error) { + p.TokenData.WriteUint64(p.ClientId) - servers, err = p.TokenData.GetUint32() - if err != nil { - return err + if err := p.WriteShared(p.TokenData); err != nil { + return nil, err } - if servers <= 0 { - return errors.New("empty servers") - } + p.TokenData.WriteBytesN(p.UserData, USER_DATA_BYTES) + return p.TokenData.Buf, nil +} - if servers > MAX_SERVERS_PER_CONNECT { - log.Printf("got %d expected %d\n", servers, MAX_SERVERS_PER_CONNECT) - return errors.New("too many servers") - } +// Encrypts, in place, the TokenData buffer, assumes Write() has already been called. +func (token *ConnectTokenPrivate) Encrypt(protocolId, expireTimestamp, sequence uint64, privateKey []byte) error { + additionalData, nonce := buildTokenCryptData(protocolId, expireTimestamp, sequence) - p.ServerAddrs = make([]net.UDPAddr, servers) - - for i := 0; i < int(servers); i+=1 { - serverType, err := p.TokenData.GetUint8() - if err != nil { - return err - } - - if serverType == ADDRESS_IPV4 { - ipBytes, err = p.TokenData.GetBytes(4) - } else if serverType == ADDRESS_IPV6 { - ipBytes, err = p.TokenData.GetBytes(16) - } else { - return errors.New("unknown ip address") - } - - if err != nil { - return err - } - - ip := net.IP(ipBytes) - port, err := p.TokenData.GetUint16() - if err != nil { - return errors.New("invalid port") - } - p.ServerAddrs[i] = net.UDPAddr{IP: ip, Port: int(port)} + if err := EncryptAead(&token.TokenData.Buf, additionalData, nonce, privateKey); err != nil { + return err } return nil } -// Writes the token data to a byte slice and returns to caller -func (token *ConnectTokenPrivate) Write() ([]byte, error) { - data := NewBuffer(CONNECT_TOKEN_PRIVATE_BYTES) - data.WriteUint64(token.ClientId) +// Decrypts the internal TokenData buffer, assumes that TokenData has been populated with the encrypted data +// (most likely via NewConnectTokenPrivateEncrypted(...)). Optionally returns the decrypted buffer to caller. +func (token *ConnectTokenPrivate) Decrypt(protocolId, expireTimestamp, sequence uint64, privateKey []byte) ([]byte, error) { + var err error + additionalData, nonce := buildTokenCryptData(protocolId, expireTimestamp, sequence) - if err := writeServerData(data, token.ServerAddrs, token.ClientKey, token.ServerKey, token.UserData); err != nil { + if token.TokenData.Buf, err = DecryptAead(token.TokenData.Buf, additionalData, nonce, privateKey); err != nil { return nil, err } - return data.Buf, nil -} \ No newline at end of file + token.TokenData.Reset() // reset for reads + return token.TokenData.Buf, nil +} + +// Builds the additional data and nonce necessary for encryption and decryption. +func buildTokenCryptData(protocolId, expireTimestamp, sequence uint64) ([]byte, []byte) { + additionalData := NewBuffer(VERSION_INFO_BYTES + 8 + 8) + additionalData.WriteBytes([]byte(VERSION_INFO)) + additionalData.WriteUint64(protocolId) + additionalData.WriteUint64(expireTimestamp) + + nonce := NewBuffer(SizeUint64) + nonce.WriteUint64(sequence) + + return additionalData.Buf, nonce.Buf +} diff --git a/go/connect_token_private_test.go b/go/connect_token_private_test.go new file mode 100644 index 0000000..1ad62ff --- /dev/null +++ b/go/connect_token_private_test.go @@ -0,0 +1,92 @@ +package netcode + +import ( + "bytes" + "net" + "testing" + "time" +) + +func TestConnectTokenPrivate(t *testing.T) { + token1 := NewConnectTokenPrivate() + server := net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 40000} + servers := make([]net.UDPAddr, 1) + servers[0] = server + + config := NewConfig(servers, TEST_TIMEOUT_SECONDS, TEST_CONNECT_TOKEN_EXPIRY, TEST_CLIENT_ID, TEST_PROTOCOL_ID, TEST_PRIVATE_KEY) + currentTimestamp := uint64(time.Now().Unix()) + expireTimestamp := uint64(currentTimestamp + config.TokenExpiry) + + userData, err := RandomBytes(USER_DATA_BYTES) + if err != nil { + t.Fatalf("error generating random bytes: %s\n", err) + } + + if err := token1.Generate(config, userData); err != nil { + t.Fatalf("error generating and encrypting token") + } + + if _, err := token1.Write(); err != nil { + t.Fatalf("error writing token private data") + } + + if err := token1.Encrypt(config.ProtocolId, expireTimestamp, TEST_SEQUENCE_START, config.PrivateKey); err != nil { + t.Fatalf("error encrypting token: %s\n", err) + } + + token2 := NewConnectTokenPrivate() + token2.TokenData = NewBufferFromBytes(token1.Buffer()) + + if _, err := token2.Decrypt(config.ProtocolId, expireTimestamp, TEST_SEQUENCE_START, config.PrivateKey); err != nil { + t.Fatalf("error decrypting token: %s", err) + } + + if err := token2.Read(); err != nil { + t.Fatalf("error reading token: %s\n", err) + } + + testComparePrivateTokens(token1, token2, t) + + token2.TokenData.Reset() + if _, err = token2.Write(); err != nil { + t.Fatalf("error writing token2 buffer") + } + + if err := token2.Encrypt(config.ProtocolId, expireTimestamp, TEST_SEQUENCE_START, config.PrivateKey); err != nil { + t.Fatalf("error encrypting second token: %s\n", err) + } + + if len(token1.Buffer()) != len(token2.Buffer()) { + t.Fatalf("encrypted buffer lengths did not match %d and %d\n", len(token1.Buffer()), len(token2.Buffer())) + } + + if bytes.Compare(token1.Buffer(), token2.Buffer()) != 0 { + t.Fatalf("encrypted private bits didn't match\n%#v\n and\n%#v\n", token1.Buffer(), token2.Buffer()) + } +} + +func testComparePrivateTokens(token1, token2 *ConnectTokenPrivate, t *testing.T) { + if token1.ClientId != token2.ClientId { + t.Fatalf("clientIds do not match expected %d got %d", token1.ClientId, token2.ClientId) + } + + if len(token1.ServerAddrs) != len(token2.ServerAddrs) { + t.Fatalf("time stamps do not match expected %d got %d", len(token1.ServerAddrs), len(token2.ServerAddrs)) + } + + token1Servers := token1.ServerAddrs + token2Servers := token2.ServerAddrs + for i := 0; i < len(token1.ServerAddrs); i += 1 { + if bytes.Compare([]byte(token1Servers[i].IP), []byte(token2Servers[i].IP)) != 0 { + t.Fatalf("server addresses did not match: expected %v got %v\n", token1Servers[i], token2Servers[i]) + } + } + + if bytes.Compare(token1.ClientKey, token2.ClientKey) != 0 { + t.Fatalf("ClientKey do not match expected %v got %v", token1.ClientKey, token2.ClientKey) + } + + if bytes.Compare(token1.ServerKey, token2.ServerKey) != 0 { + t.Fatalf("ServerKey do not match expected %v got %v", token1.ServerKey, token2.ServerKey) + } +} diff --git a/go/connect_token_shared.go b/go/connect_token_shared.go new file mode 100644 index 0000000..306bcfb --- /dev/null +++ b/go/connect_token_shared.go @@ -0,0 +1,126 @@ +package netcode + +import ( + "errors" + "net" + "strconv" +) + +// This struct contains data that is shared in both public and private parts of the +// connect token. +type sharedTokenData struct { + ServerAddrs []net.UDPAddr // list of server addresses this client may connect to + ClientKey []byte // client to server key + ServerKey []byte // server to client key +} + +// Reads and validates the servers and client <-> server keys. +func (shared *sharedTokenData) ReadShared(buffer *Buffer) error { + var err error + var servers uint32 + var ipBytes []byte + + servers, err = buffer.GetUint32() + if err != nil { + return err + } + + if servers <= 0 { + return errors.New("empty servers") + } + + if servers > MAX_SERVERS_PER_CONNECT { + return errors.New("too many servers") + } + + shared.ServerAddrs = make([]net.UDPAddr, servers) + + for i := 0; i < int(servers); i += 1 { + serverType, err := buffer.GetUint8() + if err != nil { + return err + } + + if serverType == ADDRESS_IPV4 { + ipBytes, err = buffer.GetBytes(4) + } else if serverType == ADDRESS_IPV6 { + ipBytes, err = buffer.GetBytes(16) + } else { + return errors.New("unknown ip address") + } + + if err != nil { + return err + } + + ip := net.IP(ipBytes) + port, err := buffer.GetUint16() + if err != nil { + return errors.New("invalid port") + } + shared.ServerAddrs[i] = net.UDPAddr{IP: ip, Port: int(port)} + } + + if shared.ClientKey, err = buffer.GetBytes(KEY_BYTES); err != nil { + return err + } + + if shared.ServerKey, err = buffer.GetBytes(KEY_BYTES); err != nil { + return err + } + + return nil +} + +// Writes the servers and client <-> server keys to the supplied buffer +func (shared *sharedTokenData) WriteShared(buffer *Buffer) error { + buffer.WriteUint32(uint32(len(shared.ServerAddrs))) + + for _, addr := range shared.ServerAddrs { + host, port, err := net.SplitHostPort(addr.String()) + if err != nil { + return errors.New("invalid port for host: " + addr.String()) + } + + parsed := net.ParseIP(host) + if parsed == nil { + return errors.New("invalid ip address") + } + + if len(parsed) == 4 { + buffer.WriteUint8(uint8(ADDRESS_IPV4)) + + } else { + buffer.WriteUint8(uint8(ADDRESS_IPV6)) + } + + for i := 0; i < len(parsed); i += 1 { + buffer.WriteUint8(parsed[i]) + } + + p, err := strconv.ParseUint(port, 10, 16) + if err != nil { + return err + } + buffer.WriteUint16(uint16(p)) + } + buffer.WriteBytesN(shared.ClientKey, KEY_BYTES) + buffer.WriteBytesN(shared.ServerKey, KEY_BYTES) + return nil +} + +// Generates the shared data, should only really be called by ConnectTokenPrivate +// since the same data will be copied/referenced by ConnectToken +func (shared *sharedTokenData) GenerateShared(config *Config) error { + var err error + + shared.ServerAddrs = config.ServerAddrs + if shared.ClientKey, err = GenerateKey(); err != nil { + return err + } + + if shared.ServerKey, err = GenerateKey(); err != nil { + return err + } + return nil +} diff --git a/go/connect_token_test.go b/go/connect_token_test.go index 5add773..6650836 100644 --- a/go/connect_token_test.go +++ b/go/connect_token_test.go @@ -1,130 +1,119 @@ package netcode import ( - "testing" - "net" "bytes" - "time" - "go/token" + "net" + "testing" ) const ( - TEST_PROTOCOL_ID = 0x1122334455667788 + TEST_PROTOCOL_ID = 0x1122334455667788 TEST_CONNECT_TOKEN_EXPIRY = 30 - TEST_SERVER_PORT = 40000 - TEST_CLIENT_ID = 0x1 - TEST_SEQUENCE_START = 1000 + TEST_SERVER_PORT = 40000 + TEST_CLIENT_ID = 0x1 + TEST_SEQUENCE_START = 1000 + TEST_TIMEOUT_SECONDS = 1 ) var TEST_PRIVATE_KEY = []byte{0x60, 0x6a, 0xbe, 0x6e, 0xc9, 0x19, 0x10, 0xea, - 0x9a, 0x65, 0x62, 0xf6, 0x6f, 0x2b, 0x30, 0xe4, - 0x43, 0x71, 0xd6, 0x2c, 0xd1, 0x99, 0x27, 0x26, - 0x6b, 0x3c, 0x60, 0xf4, 0xb7, 0x15, 0xab, 0xa1 } + 0x9a, 0x65, 0x62, 0xf6, 0x6f, 0x2b, 0x30, 0xe4, + 0x43, 0x71, 0xd6, 0x2c, 0xd1, 0x99, 0x27, 0x26, + 0x6b, 0x3c, 0x60, 0xf4, 0xb7, 0x15, 0xab, 0xa1} + +func TestConnectToken(t *testing.T) { + var err error + var tokenBuffer []byte + var key []byte -func TestNewConnectToken(t *testing.T) { - token1 := NewConnectToken() + inToken := NewConnectToken() server := net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 40000} servers := make([]net.UDPAddr, 1) servers[0] = server - config := NewConfig(servers, TEST_CONNECT_TOKEN_EXPIRY, TEST_PROTOCOL_ID, TEST_PRIVATE_KEY) - currentTimestamp := uint64(time.Now().Unix()) - - err := token1.Generate(config, TEST_CLIENT_ID, currentTimestamp, TEST_SEQUENCE_START) - if err != nil { - t.Fatalf("error generating and encrypting token") - } - - private, err := token1.Write() - if err != nil { - t.Fatalf("error writing token private data") + if key, err = GenerateKey(); err != nil { + t.Fatalf("error generating key %s\n", key) } - EncryptConnectTokenPrivate(&private, TEST_PROTOCOL_ID, uint64(currentTimestamp + config.TokenExpiry), TEST_SEQUENCE_START, config.PrivateKey) + config := NewConfig(servers, TEST_TIMEOUT_SECONDS, TEST_CONNECT_TOKEN_EXPIRY, TEST_CLIENT_ID, TEST_PROTOCOL_ID, key) - token2, err := ReadConnectToken(private, config.ProtocolId, currentTimestamp+config.TokenExpiry, TEST_SEQUENCE_START, config.PrivateKey) + // generate will write & encrypt the ConnectTokenPrivate + err = inToken.Generate(config, TEST_SEQUENCE_START) if err != nil { - t.Fatalf("error reading connect token %s", err) + t.Fatalf("error generating") } - compareTokens(token1, token2, t) + // Writes the entire ConnectToken (including Private) + if tokenBuffer, err = inToken.Write(); err != nil { + t.Fatalf("error writing token: %s\n", err) + } - private2, err := token2.Write() + outToken, err := ReadConnectToken(tokenBuffer) if err != nil { - t.Fatalf("error writing token2 buffer") + t.Fatalf("error re-reading back token buffer: %s\n", err) } - EncryptConnectTokenPrivate(&private2, TEST_PROTOCOL_ID, uint64(currentTimestamp + config.TokenExpiry), TEST_SEQUENCE_START, config.PrivateKey) - - if bytes.Compare(private, private2) != 0 { - t.Fatalf("encrypted private bits didn't match %v and %v\n", private, private2) + if string(inToken.VersionInfo) != string(outToken.VersionInfo) { + t.Fatalf("version info did not match expected: %s got: %s\n", inToken.VersionInfo, outToken.VersionInfo) } -} - -func TestConnectTokenPublic(t *testing.T) { - token1 := NewConnectToken() - server := net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 40000} - servers := make([]net.UDPAddr, 1) - servers[0] = server - key, err := GenerateKey() - if err != nil { - t.Fatalf("error generating key %s\n", key) + if inToken.ProtocolId != outToken.ProtocolId { + t.Fatalf("ProtocolId did not match expected: %s got: %s\n", inToken.ProtocolId, outToken.ProtocolId) } - config := NewConfig(servers, TEST_CONNECT_TOKEN_EXPIRY, TEST_PROTOCOL_ID, key) - currentTimestamp := uint64(time.Now().Unix()) - - err = token1.Generate(config, TEST_CLIENT_ID, currentTimestamp, TEST_SEQUENCE_START) - if err != nil { - t.Fatalf("error generating and encrypting token") + if inToken.CreateTimestamp != outToken.CreateTimestamp { + t.Fatalf("CreateTimestamp did not match expected: %s got: %s\n", inToken.CreateTimestamp, outToken.CreateTimestamp) } - private, err := token1.Write() - if err != nil { - t.Fatalf("error writing token private data") + if inToken.ExpireTimestamp != outToken.ExpireTimestamp { + t.Fatalf("ExpireTimestamp did not match expected: %s got: %s\n", inToken.ExpireTimestamp, outToken.ExpireTimestamp) } - // write it to a buffer - EncryptConnectTokenPrivate(&private, TEST_PROTOCOL_ID, uint64(currentTimestamp + config.TokenExpiry), TEST_SEQUENCE_START, config.PrivateKey) + if inToken.Sequence != outToken.Sequence { + t.Fatalf("Sequence did not match expected: %s got: %s\n", inToken.Sequence, outToken.Sequence) + } - // set misc public token properties - token1.TimeoutSeconds = int(TIMEOUT_SECONDS) + testCompareTokens(inToken, outToken, t) - tokenData, err := token1.Write() - if err != nil { - t.Fatalf("error writing token: %s\n", err) + if bytes.Compare(inToken.PrivateData.Buffer(), outToken.PrivateData.Buffer()) != 0 { + t.Fatalf("encrypted private data of tokens did not match\n%#v\n%#v", inToken.PrivateData.Buffer(), outToken.PrivateData.Buffer()) } + // need to decrypt the private tokens before we can compare + if _, err := outToken.PrivateData.Decrypt(config.ProtocolId, outToken.ExpireTimestamp, outToken.Sequence, key); err != nil { + t.Fatalf("error decrypting private out token data: %s\n", err) + } + if _, err := inToken.PrivateData.Decrypt(config.ProtocolId, inToken.ExpireTimestamp, inToken.Sequence, key); err != nil { + t.Fatalf("error decrypting private in token data: %s\n", err) + } + // and re-read to set the properties in outToken private + if err := outToken.PrivateData.Read(); err != nil { + t.Fatalf("error reading private data %s", err) + } + testComparePrivateTokens(inToken.PrivateData, outToken.PrivateData, t) } -func compareTokens(token1, token2 *ConnectToken, t *testing.T) { - if token1.ClientId() != token2.ClientId() { - t.Fatalf("clientIds do not match expected %d got %d", token1.ClientId, token2.ClientId) +func testCompareTokens(token1, token2 *ConnectToken, t *testing.T) { + if len(token1.ServerAddrs) != len(token2.ServerAddrs) { + t.Fatalf("time stamps do not match expected %d got %d", len(token1.ServerAddrs), len(token2.ServerAddrs)) } - if len(token1.ServerAddresses()) != len(token2.ServerAddresses()) { - t.Fatalf("time stamps do not match expected %d got %d", len(token1.ServerAddresses()), len(token2.ServerAddresses())) - } - - token1Servers := token1.ServerAddresses() - token2Servers := token2.ServerAddresses() - for i := 0; i < len(token1.ServerAddresses()); i+=1 { + token1Servers := token1.ServerAddrs + token2Servers := token2.ServerAddrs + for i := 0; i < len(token1.ServerAddrs); i += 1 { if bytes.Compare([]byte(token1Servers[i].IP), []byte(token2Servers[i].IP)) != 0 { t.Fatalf("server addresses did not match: expected %v got %v\n", token1Servers[i], token2Servers[i]) } } - if bytes.Compare(token1.ClientKey(), token2.ClientKey()) != 0 { - t.Fatalf("ClientKey do not match expected %v got %v", token1.ClientKey(), token2.ClientKey()) + if bytes.Compare(token1.ClientKey, token2.ClientKey) != 0 { + t.Fatalf("ClientKey do not match expected %v got %v", token1.ClientKey, token2.ClientKey) } - if bytes.Compare(token1.ServerKey(), token2.ServerKey()) != 0 { - t.Fatalf("ServerKey do not match expected %v got %v", token1.ServerKey(), token2.ServerKey()) + if bytes.Compare(token1.ServerKey, token2.ServerKey) != 0 { + t.Fatalf("ServerKey do not match expected %v got %v", token1.ServerKey, token2.ServerKey) } - -} \ No newline at end of file +} diff --git a/go/packet.go b/go/packet.go index 8bf65f1..084c815 100644 --- a/go/packet.go +++ b/go/packet.go @@ -2,8 +2,8 @@ package netcode import ( "errors" - "strconv" "log" + "strconv" ) const MAX_CLIENTS = 60 @@ -16,7 +16,7 @@ const MAX_PAYLOAD_BYTES = 1200 const MAX_ADDRESS_STRING_LENGTH = 256 const REPLAY_PROTECTION_BUFFER_SIZE = 256 const CLIENT_MAX_RECEIVE_PACKETS = 64 -const SERVER_MAX_RECEIVE_PACKETS = ( 64 * MAX_CLIENTS ) +const SERVER_MAX_RECEIVE_PACKETS = (64 * MAX_CLIENTS) const KEY_BYTES = 32 const MAC_BYTES = 16 @@ -25,7 +25,6 @@ const MAX_SERVERS_PER_CONNECT = 32 const VERSION_INFO = "NETCODE 1.00\x00" - type PacketType uint8 const ( @@ -36,32 +35,115 @@ const ( ConnectionKeepAlive ConnectionPayload ConnectionDisconnect - ) + // not a packet type, but value is last packetType+1 -const ConnectionNumPackets = ConnectionDisconnect+1 - -var packetTypeMap = map[PacketType]string { - ConnectionRequest: "CONNECTION_REQUEST", - ConnectionDenied: "CONNECTION_DENIED", - ConnectionChallenge: "CONNECTION_CHALLENGE", - ConnectionResponse: "CONNECTION_RESPONSE", - ConnectionKeepAlive: "CONNECTION_KEEPALIVE", - ConnectionPayload: "CONNECTION_PAYLOAD", +const ConnectionNumPackets = ConnectionDisconnect + 1 + +var packetTypeMap = map[PacketType]string{ + ConnectionRequest: "CONNECTION_REQUEST", + ConnectionDenied: "CONNECTION_DENIED", + ConnectionChallenge: "CONNECTION_CHALLENGE", + ConnectionResponse: "CONNECTION_RESPONSE", + ConnectionKeepAlive: "CONNECTION_KEEPALIVE", + ConnectionPayload: "CONNECTION_PAYLOAD", ConnectionDisconnect: "CONNECTION_DISCONNECT", } type Packet interface { GetType() PacketType + Write(buffer *Buffer, protocolId, sequence uint64, writePacketKey []byte) (int, error) + Read(packetBuffer *Buffer, packetLen int, protocolId, currentTimestamp uint64, readPacketKey, privateKey, allowedPackets []byte, replayProtection *ReplayProtection) error } type RequestPacket struct { - VersionInfo []byte - ProtocolId uint64 + VersionInfo []byte + ProtocolId uint64 ConnectTokenExpireTimestamp uint64 - ConnectTokenSequence uint64 - Token *ConnectToken - ConnectTokenData []byte // the encrypted Token after Write -> Encrypt + ConnectTokenSequence uint64 + Token *ConnectTokenPrivate + ConnectTokenData []byte // the encrypted Token after Write -> Encrypt +} + +// Writes the RequestPacket data to a supplied buffer and returns the length of bytes written to it. +func (p *RequestPacket) Write(buffer *Buffer, protocolId, sequence uint64, writePacketKey []byte) (int, error) { + buffer.WriteUint8(uint8(ConnectionRequest)) + buffer.WriteBytes(p.VersionInfo) + buffer.WriteUint64(p.ProtocolId) + buffer.WriteUint64(p.ConnectTokenExpireTimestamp) + buffer.WriteUint64(p.ConnectTokenSequence) + buffer.WriteBytes(p.ConnectTokenData) // write the encrypted connection token private data + if buffer.Pos != 1+13+8+8+8+CONNECT_TOKEN_PRIVATE_BYTES+MAC_BYTES { + return -1, errors.New("invalid buffer size written") + } + return buffer.Pos, nil +} + +// Reads a request packet and decrypts the connect token private data +func (p *RequestPacket) Read(packetBuffer *Buffer, packetLen int, protocolId, currentTimestamp uint64, readPacketKey, privateKey, allowedPackets []byte, replayProtection *ReplayProtection) error { + var err error + var packetType uint8 + + if packetType, err = packetBuffer.GetUint8(); err != nil || PacketType(packetType) != ConnectionRequest { + return errors.New("invalid packet type") + } + + if allowedPackets[0] == 0 { + return errors.New("ignored connection request packet. packet type is not allowed") + } + + if packetLen != 1+VERSION_INFO_BYTES+8+8+8+CONNECT_TOKEN_PRIVATE_BYTES+MAC_BYTES { + return errors.New("ignored connection request packet. bad packet length") + } + + if privateKey == nil { + return errors.New("ignored connection request packet. no private key\n") + } + + p.VersionInfo, err = packetBuffer.GetBytes(VERSION_INFO_BYTES) + if err != nil { + return errors.New("ignored connection request packet. bad version info invalid bytes returned\n") + } + + if string(p.VersionInfo) != VERSION_INFO { + return errors.New("ignored connection request packet. bad version info did not match\n") + } + + p.ProtocolId, err = packetBuffer.GetUint64() + if err != nil || p.ProtocolId != protocolId { + return errors.New("ignored connection request packet. wrong protocol id\n") + } + + p.ConnectTokenExpireTimestamp, err = packetBuffer.GetUint64() + if err != nil || p.ConnectTokenExpireTimestamp <= currentTimestamp { + return errors.New("ignored connection request packet. connect token expired\n") + } + + p.ConnectTokenSequence, err = packetBuffer.GetUint64() + if err != nil { + return err + } + + if packetBuffer.Pos != 1+VERSION_INFO_BYTES+8+8+8 { + return errors.New("invalid length of packet buffer read") + } + + var tokenBuffer []byte + tokenBuffer, err = packetBuffer.GetBytes(CONNECT_TOKEN_PRIVATE_BYTES + MAC_BYTES) + if err != nil { + return err + } + + p.Token = NewConnectTokenPrivateEncrypted(tokenBuffer) + if _, err := p.Token.Decrypt(p.ProtocolId, p.ConnectTokenExpireTimestamp, p.ConnectTokenSequence, privateKey); err != nil { + return errors.New("error decrypting connect token private data: " + err.Error()) + } + + if err := p.Token.Read(); err != nil { + return errors.New("error reading decrypted connect token private data: " + err.Error()) + } + + return nil } func (p *RequestPacket) GetType() PacketType { @@ -71,13 +153,71 @@ func (p *RequestPacket) GetType() PacketType { type DeniedPacket struct { } +func (p *DeniedPacket) Write(buffer *Buffer, protocolId, sequence uint64, writePacketKey []byte) (int, error) { + prefixByte, err := writePacketPrefix(p, buffer, sequence) + if err != nil { + return -1, err + } + + // denied packets are empty + return encryptPacket(buffer, buffer.Pos, buffer.Pos, prefixByte, protocolId, sequence, writePacketKey) +} + +func (p *DeniedPacket) Read(packetBuffer *Buffer, packetLen int, protocolId, currentTimestamp uint64, readPacketKey, privateKey, allowedPackets []byte, replayProtection *ReplayProtection) error { + decryptedBuf, err := decryptPacket(packetBuffer, packetLen, protocolId, currentTimestamp, readPacketKey, privateKey, allowedPackets, replayProtection) + if err != nil { + return err + } + + if decryptedBuf.Len() != 0 { + return errors.New("ignored connection denied packet. decrypted packet data is wrong size") + } + return nil +} + func (p *DeniedPacket) GetType() PacketType { return ConnectionDenied } type ChallengePacket struct { ChallengeTokenSequence uint64 - ChallengeTokenData []byte + ChallengeTokenData []byte +} + +func (p *ChallengePacket) Write(buffer *Buffer, protocolId, sequence uint64, writePacketKey []byte) (int, error) { + prefixByte, err := writePacketPrefix(p, buffer, sequence) + if err != nil { + return -1, err + } + + encryptedStart := buffer.Pos + buffer.WriteUint64(p.ChallengeTokenSequence) + buffer.WriteBytesN(p.ChallengeTokenData, CHALLENGE_TOKEN_BYTES) + encryptedFinish := buffer.Pos + return encryptPacket(buffer, encryptedStart, encryptedFinish, prefixByte, protocolId, sequence, writePacketKey) +} + +func (p *ChallengePacket) Read(packetBuffer *Buffer, packetLen int, protocolId, currentTimestamp uint64, readPacketKey, privateKey, allowedPackets []byte, replayProtection *ReplayProtection) error { + decryptedBuf, err := decryptPacket(packetBuffer, packetLen, protocolId, currentTimestamp, readPacketKey, privateKey, allowedPackets, replayProtection) + if err != nil { + return err + } + + if decryptedBuf.Len() != 8+CHALLENGE_TOKEN_BYTES { + return errors.New("ignored connection challenge packet. decrypted packet data is wrong size") + } + + p.ChallengeTokenSequence, err = decryptedBuf.GetUint64() + if err != nil { + return errors.New("error reading challenge token sequence") + } + + p.ChallengeTokenData, err = decryptedBuf.GetBytes(CHALLENGE_TOKEN_BYTES) + if err != nil { + return errors.New("error reading challenge token data") + } + + return nil } func (p *ChallengePacket) GetType() PacketType { @@ -86,230 +226,189 @@ func (p *ChallengePacket) GetType() PacketType { type ResponsePacket struct { ChallengeTokenSequence uint64 - ChallengeTokenData []byte + ChallengeTokenData []byte } -func (p *ResponsePacket) GetType() PacketType { - return ConnectionResponse -} +func (p *ResponsePacket) Write(buffer *Buffer, protocolId, sequence uint64, writePacketKey []byte) (int, error) { + prefixByte, err := writePacketPrefix(p, buffer, sequence) + if err != nil { + return -1, err + } -type KeepAlivePacket struct { - ClientIndex uint32 - MaxClients uint32 + encryptedStart := buffer.Pos + buffer.WriteUint64(p.ChallengeTokenSequence) + buffer.WriteBytesN(p.ChallengeTokenData, CHALLENGE_TOKEN_BYTES) + encryptedFinish := buffer.Pos + return encryptPacket(buffer, encryptedStart, encryptedFinish, prefixByte, protocolId, sequence, writePacketKey) } -func (p *KeepAlivePacket) GetType() PacketType { - return ConnectionKeepAlive -} +func (p *ResponsePacket) Read(packetBuffer *Buffer, packetLen int, protocolId, currentTimestamp uint64, readPacketKey, privateKey, allowedPackets []byte, replayProtection *ReplayProtection) error { + decryptedBuf, err := decryptPacket(packetBuffer, packetLen, protocolId, currentTimestamp, readPacketKey, privateKey, allowedPackets, replayProtection) + if err != nil { + return err + } + if decryptedBuf.Len() != 8+CHALLENGE_TOKEN_BYTES { + return errors.New("ignored connection challenge packet. decrypted packet data is wrong size") + } -type PayloadPacket struct { - PayloadBytes uint32 - PayloadData []byte -} + p.ChallengeTokenSequence, err = decryptedBuf.GetUint64() + if err != nil { + return errors.New("error reading challenge token sequence") + } -func (p *PayloadPacket) GetType() PacketType { - return ConnectionPayload -} + p.ChallengeTokenData, err = decryptedBuf.GetBytes(CHALLENGE_TOKEN_BYTES) + if err != nil { + return errors.New("error reading challenge token data") + } -func NewPayloadPacket(payloadBytes uint32) *PayloadPacket { - packet := &PayloadPacket{} - packet.PayloadBytes = payloadBytes - packet.PayloadData = make([]byte, payloadBytes) - return packet + return nil } -type DisconnectPacket struct { +func (p *ResponsePacket) GetType() PacketType { + return ConnectionResponse } -func (p *DisconnectPacket) GetType() PacketType { - return ConnectionDisconnect +type KeepAlivePacket struct { + ClientIndex uint32 + MaxClients uint32 } -func WritePacket(packet Packet, buffer *Buffer, sequence uint64, writePacketKey []byte, protocolId uint64) (int, error) { - packetType := packet.GetType() - // TODO: this should be moved to writePacketData provided packet prefix can be safely ignored/added - if packetType == ConnectionRequest { - // connection request packet: first byte is zero - p, ok := packet.(*RequestPacket) - if !ok { - return -1, errors.New("invalid packet type, expecting request packet") - } - buffer.WriteUint8(uint8(ConnectionRequest)) - buffer.WriteBytes(p.VersionInfo) - buffer.WriteUint64(p.ProtocolId) - buffer.WriteUint64(p.ConnectTokenExpireTimestamp) - buffer.WriteUint64(p.ConnectTokenSequence) - buffer.WriteBytes(p.ConnectTokenData) // write the encrypted connection token private data - if buffer.Pos != 1 + 13 + 8 + 8 + 8 + CONNECT_TOKEN_PRIVATE_BYTES + MAC_BYTES { - return -1, errors.New("invalid buffer size written") - } - return buffer.Pos, nil - } - - // *** encrypted packets *** - prefixByte, err := writePacketPrefix(packet, buffer, sequence) +func (p *KeepAlivePacket) Write(buffer *Buffer, protocolId, sequence uint64, writePacketKey []byte) (int, error) { + prefixByte, err := writePacketPrefix(p, buffer, sequence) if err != nil { return -1, err } encryptedStart := buffer.Pos - if err := writePacketData(packet, buffer); err != nil { - return -1, err - } + buffer.WriteUint32(uint32(p.ClientIndex)) + buffer.WriteUint32(uint32(p.MaxClients)) encryptedFinish := buffer.Pos + return encryptPacket(buffer, encryptedStart, encryptedFinish, prefixByte, protocolId, sequence, writePacketKey) +} - additionalData, nonce := packetCryptData(prefixByte, protocolId, sequence) - log.Printf("data to encrypt size: %d = %d - %d\n", encryptedFinish-encryptedStart, encryptedFinish, encryptedStart) - // slice up the buffer for the bits we will encrypt - encryptedBuffer := buffer.Buf[encryptedStart:encryptedFinish] - if err := EncryptAead(&encryptedBuffer, additionalData, nonce, writePacketKey); err != nil { - return -1, err +func (p *KeepAlivePacket) Read(packetBuffer *Buffer, packetLen int, protocolId, currentTimestamp uint64, readPacketKey, privateKey, allowedPackets []byte, replayProtection *ReplayProtection) error { + decryptedBuf, err := decryptPacket(packetBuffer, packetLen, protocolId, currentTimestamp, readPacketKey, privateKey, allowedPackets, replayProtection) + if err != nil { + return err } - // hack to reset Pos to write in the encrypted buffer to avoid allocations/append() calls - buffer.Pos = encryptedStart - buffer.WriteBytes(encryptedBuffer) - log.Printf("buffer written so far plus mac: %d\n", buffer.Pos) - return buffer.Pos, nil // in c, we do Pos + MAC_BYTES but the WriteBytes will update Pos to include it -} + if decryptedBuf.Len() != 8 { + return errors.New("ignored connection keep alive packet. decrypted packet data is wrong size") + } -// write the prefix byte (this is a combination of the packet type and number of sequence bytes) -func writePacketPrefix(p Packet, buffer *Buffer, sequence uint64) (uint8, error) { - sequenceBytes := sequenceNumberBytesRequired(sequence) - if sequenceBytes < 1 || sequenceBytes > 8 { - return 0, errors.New("invalid sequence bytes, must be between [1-8]") + p.ClientIndex, err = decryptedBuf.GetUint32() + if err != nil { + return errors.New("error reading keepalive client index") } - prefixByte := uint8(p.GetType()) | uint8(sequenceBytes << 4) - buffer.WriteUint8(prefixByte) + p.MaxClients, err = decryptedBuf.GetUint32() + if err != nil { + return errors.New("error reading keepalive max clients") + } - sequenceTemp := sequence + return nil +} - var i uint8 - for ; i < sequenceBytes; i+=1 { - buffer.WriteUint8(uint8(sequenceTemp & 0xFF)) - sequenceTemp >>= 8 - } - return prefixByte, nil +func (p *KeepAlivePacket) GetType() PacketType { + return ConnectionKeepAlive } -// write packet data according to type. this data will be encrypted. -func writePacketData(packet Packet, buffer *Buffer) error { - switch packet.GetType() { - case ConnectionDenied: - // ... - case ConnectionChallenge: - p, ok := packet.(*ChallengePacket) - if !ok { - return errors.New("invalid packet type") - } - buffer.WriteUint64(p.ChallengeTokenSequence) - buffer.WriteBytesN(p.ChallengeTokenData, CHALLENGE_TOKEN_BYTES) - case ConnectionResponse: - p, ok := packet.(*ResponsePacket) - if !ok { - return errors.New("invalid packet type") - } - buffer.WriteUint64(p.ChallengeTokenSequence) - buffer.WriteBytesN(p.ChallengeTokenData, CHALLENGE_TOKEN_BYTES) - case ConnectionKeepAlive: - p, ok := packet.(*KeepAlivePacket) - if !ok { - return errors.New("invalid packet type") - } - buffer.WriteUint32(uint32(p.ClientIndex)) - buffer.WriteUint32(uint32(p.MaxClients)) - case ConnectionPayload: - p, ok := packet.(*PayloadPacket) - if !ok { - return errors.New("invalid packet type") - } - log.Printf("writing %d payload bytes pre: %d\n", p.PayloadBytes, buffer.Pos) - buffer.WriteBytesN([]byte(p.PayloadData), int(p.PayloadBytes)) - log.Printf("writing %d payload bytes post: %d\n", p.PayloadBytes, buffer.Pos) - case ConnectionDisconnect: - // ... - } - return nil +type PayloadPacket struct { + PayloadBytes uint32 + PayloadData []byte } -// used for encrypting the per-packet packet written with the prefix byte, protocol id and version as the associated data. this must match to decrypt. -func packetCryptData(prefixByte uint8, protocolId, sequence uint64) ([]byte, []byte) { - additionalData := NewBuffer(VERSION_INFO_BYTES+8+1) - additionalData.WriteBytesN([]byte(VERSION_INFO), VERSION_INFO_BYTES) - additionalData.WriteUint64(protocolId) - additionalData.WriteUint8(prefixByte) +func (p *PayloadPacket) GetType() PacketType { + return ConnectionPayload +} - nonce := NewBuffer(SizeUint64) - nonce.WriteUint64(sequence) - return additionalData.Buf, nonce.Buf +func NewPayloadPacket(payloadData []byte) *PayloadPacket { + packet := &PayloadPacket{} + packet.PayloadBytes = uint32(len(payloadData)) + packet.PayloadData = payloadData + return packet } -func ReadPacket(packetData []byte, packetLen int, sequence uint64, readPacketKey []byte, protocolId uint64, currentTimestamp uint64, privateKey, allowedPackets []byte, replayProtection *ReplayProtection) (Packet, error) { - if packetLen < 1 { - return nil, errors.New("invalid buffer length") +func (p *PayloadPacket) Write(buffer *Buffer, protocolId, sequence uint64, writePacketKey []byte) (int, error) { + prefixByte, err := writePacketPrefix(p, buffer, sequence) + if err != nil { + return -1, err } - packetBuffer := NewBufferFromBytes(packetData) + encryptedStart := buffer.Pos + buffer.WriteBytesN([]byte(p.PayloadData), int(p.PayloadBytes)) + encryptedFinish := buffer.Pos + return encryptPacket(buffer, encryptedStart, encryptedFinish, prefixByte, protocolId, sequence, writePacketKey) +} - prefixByte, err := packetBuffer.GetUint8() +func (p *PayloadPacket) Read(packetBuffer *Buffer, packetLen int, protocolId, currentTimestamp uint64, readPacketKey, privateKey, allowedPackets []byte, replayProtection *ReplayProtection) error { + decryptedBuf, err := decryptPacket(packetBuffer, packetLen, protocolId, currentTimestamp, readPacketKey, privateKey, allowedPackets, replayProtection) if err != nil { - return nil, errors.New("invalid buffer length") + return err } - if PacketType(prefixByte) == ConnectionRequest { - return readRequestPacket(packetBuffer, packetLen, protocolId, currentTimestamp, allowedPackets, privateKey) + decryptedSize := uint32(decryptedBuf.Len()) + if decryptedSize < 1 { + return errors.New("ignored connection payload packet. payload is too small") } - // *** encrypted packets *** - if readPacketKey == nil { - return nil, errors.New("empty packet key") + if decryptedSize > MAX_PAYLOAD_BYTES { + return errors.New("ignored connection payload packet. payload is too large") } - if packetLen < 1 + 1 + MAC_BYTES { - return nil, errors.New("ignored encrypted packet. packet is too small to be valid") - } + p.PayloadBytes = decryptedSize + p.PayloadData = decryptedBuf.Bytes() + return nil +} - packetType := prefixByte & 0xF +type DisconnectPacket struct { +} - if PacketType(packetType) >= ConnectionNumPackets { - return nil, errors.New("ignored encrypted packet. packet type " + packetTypeMap[PacketType(packetType)] + " is invalid") +func (p *DisconnectPacket) Write(buffer *Buffer, protocolId, sequence uint64, writePacketKey []byte) (int, error) { + prefixByte, err := writePacketPrefix(p, buffer, sequence) + if err != nil { + return -1, err } - if allowedPackets[packetType] == 0 { - return nil, errors.New("ignored encrypted packet. packet type " + packetTypeMap[PacketType(packetType)] + " is invalid") + // denied packets are empty + return encryptPacket(buffer, buffer.Pos, buffer.Pos, prefixByte, protocolId, sequence, writePacketKey) +} + +func (p *DisconnectPacket) Read(packetBuffer *Buffer, packetLen int, protocolId, currentTimestamp uint64, readPacketKey, privateKey, allowedPackets []byte, replayProtection *ReplayProtection) error { + decryptedBuf, err := decryptPacket(packetBuffer, packetLen, protocolId, currentTimestamp, readPacketKey, privateKey, allowedPackets, replayProtection) + if err != nil { + return err } - sequenceBytes := prefixByte >> 4 - if sequenceBytes < 1 || sequenceBytes > 8 { - return nil, errors.New("ignored encrypted packet. sequence bytes is out of range [1,8]") + if decryptedBuf.Len() != 0 { + return errors.New("ignored connection denied packet. decrypted packet data is wrong size") } + return nil +} - if packetLen < 1 + int(sequenceBytes) + MAC_BYTES { - return nil, errors.New("ignored encrypted packet. buffer is too small for sequence bytes + encryption mac") +func (p *DisconnectPacket) GetType() PacketType { + return ConnectionDisconnect +} + +func decryptPacket(packetBuffer *Buffer, packetLen int, protocolId, currentTimestamp uint64, readPacketKey, privateKey, allowedPackets []byte, replayProtection *ReplayProtection) (*Buffer, error) { + var packetSequence uint64 + + prefixByte, err := packetBuffer.GetUint8() + if err != nil { + return nil, errors.New("invalid buffer length") } - var i uint8 - // read variable length sequence number [1,8] - for i = 0; i < sequenceBytes; i+=1 { - val, err := packetBuffer.GetUint8() - if err != nil { - return nil, err - } - sequence |= uint64((val) << ( 8 * i )) + if packetSequence, err = readSequence(packetBuffer, packetLen, prefixByte); err != nil { + return nil, err } - // replay protection (optional) - if replayProtection != nil && PacketType(packetType) >= ConnectionKeepAlive { - if replayProtection.AlreadyReceived(sequence) == 1 { - v := strconv.FormatUint(sequence, 10) - return nil, errors.New("ignored connection payload packet. sequence " + v + " already received (replay protection)") - } + if err := validateSequence(packetLen, prefixByte, packetSequence, readPacketKey, allowedPackets, replayProtection); err != nil { + return nil, err } // decrypt the per-packet type data - additionalData, nonce := packetCryptData(prefixByte, protocolId, sequence) + additionalData, nonce := packetCryptData(prefixByte, protocolId, packetSequence) encryptedSize := packetLen - packetBuffer.Pos if encryptedSize < MAC_BYTES { @@ -326,162 +425,120 @@ func ReadPacket(packetData []byte, packetLen int, sequence uint64, readPacketKey return nil, errors.New("ignored encrypted packet. failed to decrypt: " + err.Error()) } - decryptedSize := encryptedSize - MAC_BYTES - - // process the per-packet type data that was just decrypted - return processPacket(PacketType(packetType), decryptedBuff, decryptedSize) + return NewBufferFromBytes(decryptedBuff), nil } -// Processes the packet after decryption has occurred. -func processPacket(packetType PacketType, decrypted []byte, decryptedSize int) (Packet, error) { - var err error - decryptedBuff := NewBufferFromBytes(decrypted) - - switch packetType { - case ConnectionDenied: - if decryptedSize != 0 { - return nil, errors.New("ignored connection denied packet. decrypted packet data is wrong size") - } - return &DeniedPacket{}, nil - case ConnectionChallenge: - if decryptedSize != 8 + CHALLENGE_TOKEN_BYTES { - return nil, errors.New("ignored connection challenge packet. decrypted packet data is wrong size") - } - - packet := &ChallengePacket{} - packet.ChallengeTokenSequence, err = decryptedBuff.GetUint64() - if err != nil { - return nil, errors.New("error reading challenge token sequence") - } - - packet.ChallengeTokenData, err = decryptedBuff.GetBytes(CHALLENGE_TOKEN_BYTES) - if err != nil { - return nil, errors.New("error reading challenge token data") - } - return packet, nil - case ConnectionResponse: - if decryptedSize != 8 + CHALLENGE_TOKEN_BYTES { - return nil, errors.New("ignored connection response packet. decrypted packet data is wrong size") - } +func readSequence(packetBuffer *Buffer, packetLen int, prefixByte uint8) (uint64, error) { + var sequence uint64 - packet := &ResponsePacket{} - packet.ChallengeTokenSequence, err = decryptedBuff.GetUint64() - if err != nil { - return nil, errors.New("error reading response token sequence") - } + sequenceBytes := prefixByte >> 4 + if sequenceBytes < 1 || sequenceBytes > 8 { + return 0, errors.New("ignored encrypted packet. sequence bytes is out of range [1,8]") + } - packet.ChallengeTokenData, err = decryptedBuff.GetBytes(CHALLENGE_TOKEN_BYTES) - if err != nil { - return nil, errors.New("error reading response token data") - } - return packet, nil - case ConnectionKeepAlive: - if decryptedSize != 8 { - return nil, errors.New("ignored connection keep alive packet. decrypted packet data is wrong size") - } - packet := &KeepAlivePacket{} - packet.ClientIndex, err = decryptedBuff.GetUint32() - if err != nil { - return nil, errors.New("error reading keepalive client index") - } + if packetLen < 1+int(sequenceBytes)+MAC_BYTES { + return 0, errors.New("ignored encrypted packet. buffer is too small for sequence bytes + encryption mac") + } - packet.MaxClients, err = decryptedBuff.GetUint32() + var i uint8 + // read variable length sequence number [1,8] + for i = 0; i < sequenceBytes; i += 1 { + val, err := packetBuffer.GetUint8() if err != nil { - return nil, errors.New("error reading keepalive max clients") - } - return packet, nil - case ConnectionPayload: - if decryptedSize < 1 { - return nil, errors.New("ignored connection payload packet. payload is too small") - } - - if decryptedSize > MAX_PAYLOAD_BYTES { - return nil, errors.New("ignored connection payload packet. payload is too large") + return 0, err } - - packet := NewPayloadPacket(uint32(decryptedSize)) - copy(packet.PayloadData, decryptedBuff.Bytes()) - return packet, nil - case ConnectionDisconnect: - if decryptedSize != 0 { - return nil, errors.New("ignored connection disconnect packet. decrypted packet data is wrong size") - } - packet := &DisconnectPacket{} - return packet, nil + sequence |= (uint64(val) << (8 * i)) } - - return nil, errors.New("unknown packet type") + return sequence, nil } -// Reads the RequestPacket type returning the packet after deserializing -func readRequestPacket(packetBuffer *Buffer, packetLen int, protocolId, currentTimestamp uint64, allowedPackets []byte, privateKey []byte) (Packet, error) { - var err error - packet := &RequestPacket{} +// Validates the data prior to encrypted blob is valid before we bother attempting to decrypt. +func validateSequence(packetLen int, prefixByte uint8, sequence uint64, readPacketKey, allowedPackets []byte, replayProtection *ReplayProtection) error { - if allowedPackets[0] == 0 { - return nil, errors.New("ignored connection request packet. packet type is not allowed") + if readPacketKey == nil { + return errors.New("empty packet key") } - if packetLen != 1 + VERSION_INFO_BYTES + 8 + 8 + 8 + CONNECT_TOKEN_PRIVATE_BYTES + MAC_BYTES { - return nil, errors.New("ignored connection request packet. bad packet length") + if packetLen < 1+1+MAC_BYTES { + return errors.New("ignored encrypted packet. packet is too small to be valid") } - if privateKey == nil { - return nil, errors.New("ignored connection request packet. no private key\n") + packetType := prefixByte & 0xF + if PacketType(packetType) >= ConnectionNumPackets { + return errors.New("ignored encrypted packet. packet type " + packetTypeMap[PacketType(packetType)] + " is invalid") } - packet.VersionInfo, err = packetBuffer.GetBytes(VERSION_INFO_BYTES) - if err != nil { - return nil, errors.New("ignored connection request packet. bad version info\n") + if allowedPackets[packetType] == 0 { + return errors.New("ignored encrypted packet. packet type " + packetTypeMap[PacketType(packetType)] + " is invalid") } - if string(packet.VersionInfo) != VERSION_INFO { - return nil, errors.New("ignored connection request packet. bad version info\n") + // replay protection (optional) + if replayProtection != nil && PacketType(packetType) >= ConnectionKeepAlive { + if replayProtection.AlreadyReceived(sequence) == 1 { + v := strconv.FormatUint(sequence, 10) + return errors.New("ignored connection payload packet. sequence " + v + " already received (replay protection)") + } } + return nil +} - packet.ProtocolId, err = packetBuffer.GetUint64() - if err != nil || packet.ProtocolId != protocolId { - return nil, errors.New("ignored connection request packet. wrong protocol id\n") +// write the prefix byte (this is a combination of the packet type and number of sequence bytes) +func writePacketPrefix(p Packet, buffer *Buffer, sequence uint64) (uint8, error) { + sequenceBytes := sequenceNumberBytesRequired(sequence) + if sequenceBytes < 1 || sequenceBytes > 8 { + return 0, errors.New("invalid sequence bytes, must be between [1-8]") } - packet.ConnectTokenExpireTimestamp, err = packetBuffer.GetUint64() - if err != nil || packet.ConnectTokenExpireTimestamp <= currentTimestamp { - return nil, errors.New("ignored connection request packet. connect token expired\n") - } + prefixByte := uint8(p.GetType()) | uint8(sequenceBytes<<4) + buffer.WriteUint8(prefixByte) - packet.ConnectTokenSequence, err = packetBuffer.GetUint64() - if err != nil { - return nil, err - } + sequenceTemp := sequence - if packetBuffer.Pos != 1 + VERSION_INFO_BYTES + 8 + 8 + 8 { - return nil, errors.New(" invalid length of packet buffer read") + var i uint8 + for ; i < sequenceBytes; i += 1 { + log.Printf("sequenceTemp %d: %x\n", i, uint8(sequenceTemp&0xFF)) + buffer.WriteUint8(uint8(sequenceTemp & 0xFF)) + sequenceTemp >>= 8 } + return prefixByte, nil +} - var tokenBuffer []byte - tokenBuffer, err = packetBuffer.GetBytes(CONNECT_TOKEN_PRIVATE_BYTES+MAC_BYTES) - if err != nil { - return nil, err - } +func encryptPacket(buffer *Buffer, encryptedStart, encryptedFinish int, prefixByte uint8, protocolId, sequence uint64, writePacketKey []byte) (int, error) { + // slice up the buffer for the bits we will encrypt + encryptedBuffer := buffer.Buf[encryptedStart:encryptedFinish] - packet.Token, err = ReadConnectToken(tokenBuffer, packet.ProtocolId, packet.ConnectTokenExpireTimestamp, packet.ConnectTokenSequence, privateKey) - if err != nil { - return nil, err + additionalData, nonce := packetCryptData(prefixByte, protocolId, sequence) + if err := EncryptAead(&encryptedBuffer, additionalData, nonce, writePacketKey); err != nil { + return -1, err } - return packet, nil + buffer.Pos = encryptedStart // reset position to start of where the encrypted data goes + buffer.WriteBytes(encryptedBuffer) + return buffer.Pos, nil // in c, we do Pos + MAC_BYTES but the WriteBytes will update buffer.Pos to include it } +// used for encrypting the per-packet packet written with the prefix byte, protocol id and version as the associated data. this must match to decrypt. +func packetCryptData(prefixByte uint8, protocolId, sequence uint64) ([]byte, []byte) { + additionalData := NewBuffer(VERSION_INFO_BYTES + 8 + 1) + additionalData.WriteBytesN([]byte(VERSION_INFO), VERSION_INFO_BYTES) + additionalData.WriteUint64(protocolId) + additionalData.WriteUint8(prefixByte) + nonce := NewBuffer(SizeUint64) + nonce.WriteUint64(sequence) + return additionalData.Buf, nonce.Buf +} + +// Depending on size of sequence number, we need to reserve N bytes func sequenceNumberBytesRequired(sequence uint64) uint8 { var mask uint64 mask = 0xFF00000000000000 var i uint8 - for ; i < 7; i+=1 { - if (sequence & mask != 0) { + for ; i < 7; i += 1 { + if sequence&mask != 0 { break } mask >>= 8 } return 8 - i -} \ No newline at end of file +} diff --git a/go/packet_test.go b/go/packet_test.go index e37b67c..f339e12 100644 --- a/go/packet_test.go +++ b/go/packet_test.go @@ -1,16 +1,12 @@ package netcode import ( - "testing" + "bytes" "net" + "testing" "time" - "bytes" ) -func TestReadPacket(t *testing.T) { - -} - func TestSequence(t *testing.T) { seq := sequenceNumberBytesRequired(0) if seq != 1 { @@ -65,8 +61,8 @@ func TestConnectionRequestPacket(t *testing.T) { t.Fatalf("error generating connect token key: %s\n", err) } inputPacket, decryptedToken := testBuildRequestPacket(connectTokenKey, t) - // write the connection request packet to a buffer + // write the connection request packet to a buffer buffer := NewBuffer(2048) packetKey, err := GenerateKey() @@ -74,7 +70,7 @@ func TestConnectionRequestPacket(t *testing.T) { t.Fatalf("error generating key") } - bytesWritten, err := WritePacket(inputPacket, buffer, TEST_SEQUENCE_START, packetKey, TEST_PROTOCOL_ID) + bytesWritten, err := inputPacket.Write(buffer, TEST_PROTOCOL_ID, TEST_SEQUENCE_START, packetKey) if err != nil { t.Fatalf("error writing packet: %s\n", err) } @@ -83,48 +79,35 @@ func TestConnectionRequestPacket(t *testing.T) { t.Fatalf("did not write any bytes for this packet") } - // read the connection request packet back in from the buffer (the connect token data is decrypted as part of the read packet validation) - var sequence uint64 - sequence = TEST_SEQUENCE_START - allowedPackets := make([]byte, ConnectionNumPackets) - for i := 0; i < len(allowedPackets); i+=1 { + for i := 0; i < len(allowedPackets); i += 1 { allowedPackets[i] = 1 } + outputPacket := &RequestPacket{} + buffer.Reset() - outputPacket, err := ReadPacket(buffer.Buf, bytesWritten, sequence, packetKey, TEST_PROTOCOL_ID, uint64(time.Now().Unix()), connectTokenKey, allowedPackets, nil) - if err != nil { + if err := outputPacket.Read(buffer, bytesWritten, TEST_PROTOCOL_ID, uint64(time.Now().Unix()), packetKey, connectTokenKey, allowedPackets, nil); err != nil { t.Fatalf("error reading packet: %s\n", err) - } - if outputPacket.GetType() != ConnectionRequest { - t.Fatal("packet output was not a connection request") - } - - output, ok := outputPacket.(*RequestPacket) - if !ok { - t.Fatalf("error casting to connection request packet") - } - - if bytes.Compare(inputPacket.VersionInfo, output.VersionInfo) != 0 { + if bytes.Compare(inputPacket.VersionInfo, outputPacket.VersionInfo) != 0 { t.Fatalf("version info did not match") } - if inputPacket.ProtocolId != output.ProtocolId { + if inputPacket.ProtocolId != outputPacket.ProtocolId { t.Fatalf("ProtocolId did not match") } - if inputPacket.ConnectTokenExpireTimestamp != output.ConnectTokenExpireTimestamp { + if inputPacket.ConnectTokenExpireTimestamp != outputPacket.ConnectTokenExpireTimestamp { t.Fatalf("ConnectTokenExpireTimestamp did not match") } - if inputPacket.ConnectTokenSequence != output.ConnectTokenSequence { + if inputPacket.ConnectTokenSequence != outputPacket.ConnectTokenSequence { t.Fatalf("ConnectTokenSequence did not match") } - if bytes.Compare(decryptedToken, output.Token.PrivateData.TokenData.Buf) != 0 { + if bytes.Compare(decryptedToken, outputPacket.Token.TokenData.Buf) != 0 { t.Fatalf("TokenData did not match") } } @@ -141,7 +124,7 @@ func TestConnectionDeniedPacket(t *testing.T) { } // write the packet to a buffer - bytesWritten, err := WritePacket(inputPacket, buffer, TEST_SEQUENCE_START, packetKey, TEST_PROTOCOL_ID) + bytesWritten, err := inputPacket.Write(buffer, TEST_PROTOCOL_ID, TEST_SEQUENCE_START, packetKey) if err != nil { t.Fatalf("error writing packet: %s\n", err) } @@ -150,19 +133,17 @@ func TestConnectionDeniedPacket(t *testing.T) { t.Fatalf("did not write any bytes for this packet") } - var sequence uint64 - sequence = TEST_SEQUENCE_START - allowedPackets := make([]byte, ConnectionNumPackets) - for i := 0; i < len(allowedPackets); i+=1 { + for i := 0; i < len(allowedPackets); i += 1 { allowedPackets[i] = 1 } + outputPacket := &DeniedPacket{} buffer.Reset() - outputPacket, err := ReadPacket(buffer.Buf, bytesWritten, sequence, packetKey, TEST_PROTOCOL_ID, uint64(time.Now().Unix()), nil, allowedPackets, nil) - if err != nil { + if err := outputPacket.Read(buffer, bytesWritten, TEST_PROTOCOL_ID, uint64(time.Now().Unix()), packetKey, nil, allowedPackets, nil); err != nil { t.Fatalf("error reading packet: %s\n", err) } + if outputPacket.GetType() != ConnectionDenied { t.Fatalf("did not get a denied packet after read") } @@ -187,7 +168,7 @@ func TestConnectionChallengePacket(t *testing.T) { } // write the packet to a buffer - bytesWritten, err := WritePacket(inputPacket, buffer, TEST_SEQUENCE_START, packetKey, TEST_PROTOCOL_ID) + bytesWritten, err := inputPacket.Write(buffer, TEST_PROTOCOL_ID, TEST_SEQUENCE_START, packetKey) if err != nil { t.Fatalf("error writing packet: %s\n", err) } @@ -196,30 +177,22 @@ func TestConnectionChallengePacket(t *testing.T) { t.Fatalf("did not write any bytes for this packet") } - var sequence uint64 - sequence = TEST_SEQUENCE_START - allowedPackets := make([]byte, ConnectionNumPackets) - for i := 0; i < len(allowedPackets); i+=1 { + for i := 0; i < len(allowedPackets); i += 1 { allowedPackets[i] = 1 } + outputPacket := &ChallengePacket{} buffer.Reset() - outputPacket, err := ReadPacket(buffer.Buf, bytesWritten, sequence, packetKey, TEST_PROTOCOL_ID, uint64(time.Now().Unix()), nil, allowedPackets, nil) - if err != nil { + if err := outputPacket.Read(buffer, bytesWritten, TEST_PROTOCOL_ID, uint64(time.Now().Unix()), packetKey, nil, allowedPackets, nil); err != nil { t.Fatalf("error reading packet: %s\n", err) } - challenge, ok := outputPacket.(*ChallengePacket) - if !ok { - t.Fatalf("did not get a challenge packet after read") - } - - if inputPacket.ChallengeTokenSequence != challenge.ChallengeTokenSequence { - t.Fatalf("input and output sequence differed, expected %d got %d\n", inputPacket.ChallengeTokenSequence, challenge.ChallengeTokenSequence) + if inputPacket.ChallengeTokenSequence != outputPacket.ChallengeTokenSequence { + t.Fatalf("input and output sequence differed, expected %d got %d\n", inputPacket.ChallengeTokenSequence, outputPacket.ChallengeTokenSequence) } - if bytes.Compare(inputPacket.ChallengeTokenData, challenge.ChallengeTokenData) != 0 { + if bytes.Compare(inputPacket.ChallengeTokenData, outputPacket.ChallengeTokenData) != 0 { t.Fatalf("challenge token data was not equal\n") } } @@ -227,7 +200,7 @@ func TestConnectionChallengePacket(t *testing.T) { func TestConnectionResponsePacket(t *testing.T) { var err error - // setup a connection challenge packet + // setup a connection response packet inputPacket := &ResponsePacket{} inputPacket.ChallengeTokenSequence = 0 inputPacket.ChallengeTokenData, err = RandomBytes(CHALLENGE_TOKEN_BYTES) @@ -243,7 +216,7 @@ func TestConnectionResponsePacket(t *testing.T) { } // write the packet to a buffer - bytesWritten, err := WritePacket(inputPacket, buffer, TEST_SEQUENCE_START, packetKey, TEST_PROTOCOL_ID) + bytesWritten, err := inputPacket.Write(buffer, TEST_PROTOCOL_ID, TEST_SEQUENCE_START, packetKey) if err != nil { t.Fatalf("error writing packet: %s\n", err) } @@ -252,35 +225,26 @@ func TestConnectionResponsePacket(t *testing.T) { t.Fatalf("did not write any bytes for this packet") } - var sequence uint64 - sequence = TEST_SEQUENCE_START - allowedPackets := make([]byte, ConnectionNumPackets) - for i := 0; i < len(allowedPackets); i+=1 { + for i := 0; i < len(allowedPackets); i += 1 { allowedPackets[i] = 1 } + outputPacket := &ResponsePacket{} buffer.Reset() - outputPacket, err := ReadPacket(buffer.Buf, bytesWritten, sequence, packetKey, TEST_PROTOCOL_ID, uint64(time.Now().Unix()), nil, allowedPackets, nil) - if err != nil { + if err := outputPacket.Read(buffer, bytesWritten, TEST_PROTOCOL_ID, uint64(time.Now().Unix()), packetKey, nil, allowedPackets, nil); err != nil { t.Fatalf("error reading packet: %s\n", err) } - response, ok := outputPacket.(*ResponsePacket) - if !ok { - t.Fatalf("did not get a response packet after read") - } - - if inputPacket.ChallengeTokenSequence != response.ChallengeTokenSequence { - t.Fatalf("input and output sequence differed, expected %d got %d\n", inputPacket.ChallengeTokenSequence, response.ChallengeTokenSequence) + if inputPacket.ChallengeTokenSequence != outputPacket.ChallengeTokenSequence { + t.Fatalf("input and output sequence differed, expected %d got %d\n", inputPacket.ChallengeTokenSequence, outputPacket.ChallengeTokenSequence) } - if bytes.Compare(inputPacket.ChallengeTokenData, response.ChallengeTokenData) != 0 { - t.Fatalf("challenge token data was not equal\n") + if bytes.Compare(inputPacket.ChallengeTokenData, outputPacket.ChallengeTokenData) != 0 { + t.Fatalf("response challenge token data was not equal\n") } } - func TestConnectionKeepAlivePacket(t *testing.T) { var err error @@ -297,7 +261,7 @@ func TestConnectionKeepAlivePacket(t *testing.T) { } // write the packet to a buffer - bytesWritten, err := WritePacket(inputPacket, buffer, TEST_SEQUENCE_START, packetKey, TEST_PROTOCOL_ID) + bytesWritten, err := inputPacket.Write(buffer, TEST_PROTOCOL_ID, TEST_SEQUENCE_START, packetKey) if err != nil { t.Fatalf("error writing packet: %s\n", err) } @@ -306,44 +270,35 @@ func TestConnectionKeepAlivePacket(t *testing.T) { t.Fatalf("did not write any bytes for this packet") } - var sequence uint64 - sequence = TEST_SEQUENCE_START - allowedPackets := make([]byte, ConnectionNumPackets) - for i := 0; i < len(allowedPackets); i+=1 { + for i := 0; i < len(allowedPackets); i += 1 { allowedPackets[i] = 1 } + outputPacket := &KeepAlivePacket{} buffer.Reset() - outputPacket, err := ReadPacket(buffer.Buf, bytesWritten, sequence, packetKey, TEST_PROTOCOL_ID, uint64(time.Now().Unix()), nil, allowedPackets, nil) - if err != nil { + if err := outputPacket.Read(buffer, bytesWritten, TEST_PROTOCOL_ID, uint64(time.Now().Unix()), packetKey, nil, allowedPackets, nil); err != nil { t.Fatalf("error reading packet: %s\n", err) } - keepalive, ok := outputPacket.(*KeepAlivePacket) - if !ok { - t.Fatalf("did not get a response packet after read") + if inputPacket.ClientIndex != outputPacket.ClientIndex { + t.Fatalf("input and output index differed, expected %d got %d\n", inputPacket.ClientIndex, outputPacket.ClientIndex) } - if inputPacket.ClientIndex != keepalive.ClientIndex { - t.Fatalf("input and output index differed, expected %d got %d\n", inputPacket.ClientIndex, keepalive.ClientIndex) - } - - if inputPacket.MaxClients != keepalive.MaxClients { - t.Fatalf("input and output maxclients differed, expected %d got %d\n", inputPacket.MaxClients, keepalive.MaxClients) + if inputPacket.MaxClients != outputPacket.MaxClients { + t.Fatalf("input and output maxclients differed, expected %d got %d\n", inputPacket.MaxClients, outputPacket.MaxClients) } } func TestConnectionPayloadPacket(t *testing.T) { var err error - - // setup a connection challenge packet - inputPacket := NewPayloadPacket(MAX_PAYLOAD_BYTES) - inputPacket.PayloadData, err = RandomBytes(MAX_PAYLOAD_BYTES) + payloadData, err := RandomBytes(MAX_PAYLOAD_BYTES) if err != nil { t.Fatalf("error generating random payload data: %s\n", err) } + inputPacket := NewPayloadPacket(payloadData) + buffer := NewBuffer(MAX_PACKET_BYTES) packetKey, err := GenerateKey() @@ -352,7 +307,7 @@ func TestConnectionPayloadPacket(t *testing.T) { } // write the packet to a buffer - bytesWritten, err := WritePacket(inputPacket, buffer, TEST_SEQUENCE_START, packetKey, TEST_PROTOCOL_ID) + bytesWritten, err := inputPacket.Write(buffer, TEST_PROTOCOL_ID, TEST_SEQUENCE_START, packetKey) if err != nil { t.Fatalf("error writing packet: %s\n", err) } @@ -361,31 +316,24 @@ func TestConnectionPayloadPacket(t *testing.T) { t.Fatalf("did not write any bytes for this packet") } - var sequence uint64 - sequence = TEST_SEQUENCE_START - allowedPackets := make([]byte, ConnectionNumPackets) - for i := 0; i < len(allowedPackets); i+=1 { + for i := 0; i < len(allowedPackets); i += 1 { allowedPackets[i] = 1 } buffer.Reset() - outputPacket, err := ReadPacket(buffer.Buf, bytesWritten, sequence, packetKey, TEST_PROTOCOL_ID, uint64(time.Now().Unix()), nil, allowedPackets, nil) - if err != nil { - t.Fatalf("error reading packet: %s\n", err) - } + outputPacket := &PayloadPacket{} - payload, ok := outputPacket.(*PayloadPacket) - if !ok { - t.Fatalf("did not get a payload packet after read") + if err := outputPacket.Read(buffer, bytesWritten, TEST_PROTOCOL_ID, uint64(time.Now().Unix()), packetKey, nil, allowedPackets, nil); err != nil { + t.Fatalf("error reading packet: %s\n", err) } - if inputPacket.PayloadBytes != payload.PayloadBytes { - t.Fatalf("input and output index differed, expected %d got %d\n", inputPacket.PayloadBytes, payload.PayloadBytes) + if inputPacket.PayloadBytes != outputPacket.PayloadBytes { + t.Fatalf("input and output index differed, expected %d got %d\n", inputPacket.PayloadBytes, outputPacket.PayloadBytes) } - if bytes.Compare(inputPacket.PayloadData, payload.PayloadData) != 0 { - t.Fatalf("input and output payload differed, expected %v got %v\n", inputPacket.PayloadData, payload.PayloadData) + if bytes.Compare(inputPacket.PayloadData, outputPacket.PayloadData) != 0 { + t.Fatalf("input and output payload differed, expected %v got %v\n", inputPacket.PayloadData, outputPacket.PayloadData) } } @@ -399,7 +347,7 @@ func TestDisconnectPacket(t *testing.T) { } // write the packet to a buffer - bytesWritten, err := WritePacket(inputPacket, buffer, TEST_SEQUENCE_START, packetKey, TEST_PROTOCOL_ID) + bytesWritten, err := inputPacket.Write(buffer, TEST_PROTOCOL_ID, TEST_SEQUENCE_START, packetKey) if err != nil { t.Fatalf("error writing packet: %s\n", err) } @@ -408,66 +356,52 @@ func TestDisconnectPacket(t *testing.T) { t.Fatalf("did not write any bytes for this packet") } - var sequence uint64 - sequence = TEST_SEQUENCE_START - allowedPackets := make([]byte, ConnectionNumPackets) - for i := 0; i < len(allowedPackets); i+=1 { + for i := 0; i < len(allowedPackets); i += 1 { allowedPackets[i] = 1 } buffer.Reset() - outputPacket, err := ReadPacket(buffer.Buf, bytesWritten, sequence, packetKey, TEST_PROTOCOL_ID, uint64(time.Now().Unix()), nil, allowedPackets, nil) - if err != nil { + outputPacket := &DisconnectPacket{} + if err := outputPacket.Read(buffer, bytesWritten, TEST_PROTOCOL_ID, uint64(time.Now().Unix()), packetKey, nil, allowedPackets, nil); err != nil { t.Fatalf("error reading packet: %s\n", err) } - _, ok := outputPacket.(*DisconnectPacket) - if !ok { - t.Fatalf("did not get a disconnect packet after read") - } } func testBuildRequestPacket(connectTokenKey []byte, t *testing.T) (*RequestPacket, []byte) { addr := net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: TEST_SERVER_PORT} serverAddrs := make([]net.UDPAddr, 1) serverAddrs[0] = addr - config := NewConfig(serverAddrs, TEST_CONNECT_TOKEN_EXPIRY, TEST_PROTOCOL_ID, TEST_PRIVATE_KEY) + config := NewConfig(serverAddrs, TEST_TIMEOUT_SECONDS, TEST_CONNECT_TOKEN_EXPIRY, TEST_CLIENT_ID, TEST_PROTOCOL_ID, connectTokenKey) connectToken := NewConnectToken() - currentTimestamp := uint64(time.Now().Unix()) - expireTimestamp := uint64(time.Now().Unix()) + config.TokenExpiry - if err := connectToken.Generate(config, TEST_CLIENT_ID, currentTimestamp, TEST_SEQUENCE_START); err != nil { + if err := connectToken.Generate(config, TEST_SEQUENCE_START); err != nil { t.Fatalf("error generating connect token: %s\n", err) } - privateData, err := connectToken.Write() + _, err := connectToken.Write() if err != nil { t.Fatalf("error writing private data: %s\n", err) } - if err := EncryptConnectTokenPrivate(&privateData, TEST_PROTOCOL_ID, expireTimestamp, TEST_SEQUENCE_START, connectTokenKey); err != nil { - t.Fatalf("error encrypting connect token private %s\n", err) - } - - decryptedToken, err := DecryptConnectTokenPrivate(privateData, TEST_PROTOCOL_ID, expireTimestamp, TEST_SEQUENCE_START, connectTokenKey) + decryptedToken, err := connectToken.PrivateData.Decrypt(TEST_PROTOCOL_ID, connectToken.ExpireTimestamp, TEST_SEQUENCE_START, connectTokenKey) if err != nil { t.Fatalf("error decrypting connect token: %s", err) } - - _, err = ReadConnectToken(privateData, TEST_PROTOCOL_ID, expireTimestamp, TEST_SEQUENCE_START, connectTokenKey) - if err != nil { - t.Fatalf("error reading connect token: %s\n", err) + // need to re-encrypt the private data + if err := connectToken.PrivateData.Encrypt(TEST_PROTOCOL_ID, connectToken.ExpireTimestamp, TEST_SEQUENCE_START, connectTokenKey); err != nil { + t.Fatalf("error re-encrypting connect private token: %s\n", err) } // setup a connection request packet wrapping the encrypted connect token inputPacket := &RequestPacket{} inputPacket.VersionInfo = []byte(VERSION_INFO) inputPacket.ProtocolId = TEST_PROTOCOL_ID - inputPacket.ConnectTokenExpireTimestamp = expireTimestamp + inputPacket.ConnectTokenExpireTimestamp = connectToken.ExpireTimestamp inputPacket.ConnectTokenSequence = TEST_SEQUENCE_START - inputPacket.Token = connectToken - inputPacket.ConnectTokenData = privateData + inputPacket.Token = connectToken.PrivateData + inputPacket.ConnectTokenData = connectToken.PrivateData.Buffer() return inputPacket, decryptedToken -} \ No newline at end of file +} From 50d187428b76104e0594fa133bf114ef4d3aa6ae Mon Sep 17 00:00:00 2001 From: wirepair Date: Fri, 7 Apr 2017 08:47:12 +0900 Subject: [PATCH 09/11] move files to netcode directory for naming purposes --- go/netcode/README.md | 21 +++++++++++++++++++ go/{ => netcode}/buffer.go | 0 go/{ => netcode}/buffer_test.go | 0 go/{ => netcode}/challenge_token.go | 0 go/{ => netcode}/challenge_token_test.go | 0 go/{ => netcode}/client.go | 0 go/{ => netcode}/config.go | 0 go/{ => netcode}/connect_token.go | 0 go/{ => netcode}/connect_token_private.go | 0 .../connect_token_private_test.go | 0 go/{ => netcode}/connect_token_shared.go | 0 go/{ => netcode}/connect_token_test.go | 0 go/{ => netcode}/crypto.go | 0 go/{ => netcode}/encryption_manager.go | 0 go/{ => netcode}/packet.go | 0 go/{ => netcode}/packet_test.go | 0 go/{ => netcode}/queue.go | 0 go/{ => netcode}/replay_protection.go | 0 go/{ => netcode}/replay_protection_test.go | 0 go/{ => netcode}/server.go | 0 go/{ => netcode}/simulator.go | 0 go/{ => netcode}/sizes.go | 0 go/{ => netcode}/socket.go | 0 23 files changed, 21 insertions(+) create mode 100644 go/netcode/README.md rename go/{ => netcode}/buffer.go (100%) rename go/{ => netcode}/buffer_test.go (100%) rename go/{ => netcode}/challenge_token.go (100%) rename go/{ => netcode}/challenge_token_test.go (100%) rename go/{ => netcode}/client.go (100%) rename go/{ => netcode}/config.go (100%) rename go/{ => netcode}/connect_token.go (100%) rename go/{ => netcode}/connect_token_private.go (100%) rename go/{ => netcode}/connect_token_private_test.go (100%) rename go/{ => netcode}/connect_token_shared.go (100%) rename go/{ => netcode}/connect_token_test.go (100%) rename go/{ => netcode}/crypto.go (100%) rename go/{ => netcode}/encryption_manager.go (100%) rename go/{ => netcode}/packet.go (100%) rename go/{ => netcode}/packet_test.go (100%) rename go/{ => netcode}/queue.go (100%) rename go/{ => netcode}/replay_protection.go (100%) rename go/{ => netcode}/replay_protection_test.go (100%) rename go/{ => netcode}/server.go (100%) rename go/{ => netcode}/simulator.go (100%) rename go/{ => netcode}/sizes.go (100%) rename go/{ => netcode}/socket.go (100%) diff --git a/go/netcode/README.md b/go/netcode/README.md new file mode 100644 index 0000000..8efc2e3 --- /dev/null +++ b/go/netcode/README.md @@ -0,0 +1,21 @@ +Draft Implementation of netcode.io for Go +========================================= + +This is the main repository for the Go implementation of [netcode.io](https://netcode.io). This repository and the API are highly violatile until the client and server implementations have been completed. + +## Dependencies +codehale's implementation of [chacha20poly130](https://github.com/codahale/chacha20poly1305). While I would have liked to use [https://godoc.org/golang.org/x/crypto/chacha20poly1305](https://godoc.org/golang.org/x/crypto/chacha20poly1305) it only implements the IETF version with nonce size of 12 bytes. + +## Documentation +[godocs](https://godoc.org/github.com/networkprotocol/netcode.io/go/netcode/) (todo, this may not work due to godocs probably expecting the package to be the top level of the repository). + +## TODO +- Implement Client +- Implement Server + +## Completed +- Implemented packet and token portion of protocol + + +## Author +[Isaac Dawson](https://github.com/wirepair) \ No newline at end of file diff --git a/go/buffer.go b/go/netcode/buffer.go similarity index 100% rename from go/buffer.go rename to go/netcode/buffer.go diff --git a/go/buffer_test.go b/go/netcode/buffer_test.go similarity index 100% rename from go/buffer_test.go rename to go/netcode/buffer_test.go diff --git a/go/challenge_token.go b/go/netcode/challenge_token.go similarity index 100% rename from go/challenge_token.go rename to go/netcode/challenge_token.go diff --git a/go/challenge_token_test.go b/go/netcode/challenge_token_test.go similarity index 100% rename from go/challenge_token_test.go rename to go/netcode/challenge_token_test.go diff --git a/go/client.go b/go/netcode/client.go similarity index 100% rename from go/client.go rename to go/netcode/client.go diff --git a/go/config.go b/go/netcode/config.go similarity index 100% rename from go/config.go rename to go/netcode/config.go diff --git a/go/connect_token.go b/go/netcode/connect_token.go similarity index 100% rename from go/connect_token.go rename to go/netcode/connect_token.go diff --git a/go/connect_token_private.go b/go/netcode/connect_token_private.go similarity index 100% rename from go/connect_token_private.go rename to go/netcode/connect_token_private.go diff --git a/go/connect_token_private_test.go b/go/netcode/connect_token_private_test.go similarity index 100% rename from go/connect_token_private_test.go rename to go/netcode/connect_token_private_test.go diff --git a/go/connect_token_shared.go b/go/netcode/connect_token_shared.go similarity index 100% rename from go/connect_token_shared.go rename to go/netcode/connect_token_shared.go diff --git a/go/connect_token_test.go b/go/netcode/connect_token_test.go similarity index 100% rename from go/connect_token_test.go rename to go/netcode/connect_token_test.go diff --git a/go/crypto.go b/go/netcode/crypto.go similarity index 100% rename from go/crypto.go rename to go/netcode/crypto.go diff --git a/go/encryption_manager.go b/go/netcode/encryption_manager.go similarity index 100% rename from go/encryption_manager.go rename to go/netcode/encryption_manager.go diff --git a/go/packet.go b/go/netcode/packet.go similarity index 100% rename from go/packet.go rename to go/netcode/packet.go diff --git a/go/packet_test.go b/go/netcode/packet_test.go similarity index 100% rename from go/packet_test.go rename to go/netcode/packet_test.go diff --git a/go/queue.go b/go/netcode/queue.go similarity index 100% rename from go/queue.go rename to go/netcode/queue.go diff --git a/go/replay_protection.go b/go/netcode/replay_protection.go similarity index 100% rename from go/replay_protection.go rename to go/netcode/replay_protection.go diff --git a/go/replay_protection_test.go b/go/netcode/replay_protection_test.go similarity index 100% rename from go/replay_protection_test.go rename to go/netcode/replay_protection_test.go diff --git a/go/server.go b/go/netcode/server.go similarity index 100% rename from go/server.go rename to go/netcode/server.go diff --git a/go/simulator.go b/go/netcode/simulator.go similarity index 100% rename from go/simulator.go rename to go/netcode/simulator.go diff --git a/go/sizes.go b/go/netcode/sizes.go similarity index 100% rename from go/sizes.go rename to go/netcode/sizes.go diff --git a/go/socket.go b/go/netcode/socket.go similarity index 100% rename from go/socket.go rename to go/netcode/socket.go From 24373382171f48df08dc242d812701272ed0ab0c Mon Sep 17 00:00:00 2001 From: wirepair Date: Fri, 7 Apr 2017 09:02:29 +0900 Subject: [PATCH 10/11] format code --- go/netcode/buffer.go | 25 +++++++++--------- go/netcode/buffer_test.go | 9 ++----- go/netcode/challenge_token.go | 4 +-- go/netcode/config.go | 14 +++++----- go/netcode/connect_token.go | 18 +++++++------ go/netcode/connect_token_private.go | 4 +-- go/netcode/connect_token_shared.go | 2 +- go/netcode/packet.go | 40 +++++++++++++++++++---------- go/netcode/replay_protection.go | 14 +++++----- go/netcode/sizes.go | 2 +- go/netcode/socket.go | 4 +-- 11 files changed, 72 insertions(+), 64 deletions(-) diff --git a/go/netcode/buffer.go b/go/netcode/buffer.go index c43864d..7fb2504 100644 --- a/go/netcode/buffer.go +++ b/go/netcode/buffer.go @@ -1,8 +1,8 @@ package netcode import ( - "math" "io" + "math" ) // Buffer is a helper struct for serializing and deserializing as the caller @@ -10,8 +10,9 @@ import ( // or writing to. type Buffer struct { Buf []byte // the backing byte slice - Pos int // current position in read/write + Pos int // current position in read/write } + // Creates a new Buffer with a backing byte slice of the provided size func NewBuffer(size int) *Buffer { b := &Buffer{} @@ -58,17 +59,17 @@ func (b *Buffer) GetBytes(length int) ([]byte, error) { if len(b.Buf) < length { return nil, io.EOF } - value := b.Buf[b.Pos:b.Pos+length] + value := b.Buf[b.Pos : b.Pos+length] b.Pos += length return value, nil } // GetUint8 decodes a little-endian uint8 from the buffer func (b *Buffer) GetUint8() (uint8, error) { - if b.Pos + SizeUint8 > len(b.Buf) { + if b.Pos+SizeUint8 > len(b.Buf) { return 0, io.EOF } - buf := b.Buf[b.Pos:b.Pos+SizeUint8] + buf := b.Buf[b.Pos : b.Pos+SizeUint8] b.Pos++ return uint8(buf[0]), nil } @@ -118,11 +119,11 @@ func (b *Buffer) GetUint64() (uint64, error) { } // GetInt8 decodes a little-endian int8 from the buffer -func (b *Buffer) GetInt8() (int8, error) { - if b.Pos + 1 > len(b.Buf) { +func (b *Buffer) GetInt8() (int8, error) { + if b.Pos+1 > len(b.Buf) { return 0, io.EOF } - buf := b.Buf[b.Pos:b.Pos+SizeInt8] + buf := b.Buf[b.Pos : b.Pos+SizeInt8] return int8(buf[0]), nil } @@ -170,7 +171,6 @@ func (b *Buffer) GetInt64() (int64, error) { return n, nil } - // WriteByte encodes a little-endian uint8 into the buffer. func (b *Buffer) WriteByte(n byte) { b.Buf[b.Pos] = uint8(n) @@ -179,19 +179,18 @@ func (b *Buffer) WriteByte(n byte) { // WriteBytes encodes a little-endian byte slice into the buffer func (b *Buffer) WriteBytes(src []byte) { - for i := 0; i < len(src); i+=1 { + for i := 0; i < len(src); i += 1 { b.WriteByte(uint8(src[i])) } } // WriteBytes encodes a little-endian byte slice into the buffer func (b *Buffer) WriteBytesN(src []byte, length int) { - for i := 0; i < length; i+=1 { + for i := 0; i < length; i += 1 { b.WriteByte(uint8(src[i])) } } - // WriteUint8 encodes a little-endian uint8 into the buffer. func (b *Buffer) WriteUint8(n uint8) { b.Buf[b.Pos] = byte(n) @@ -266,6 +265,6 @@ func (b *Buffer) WriteFloat32(n float32) { } // WriteFloat64 encodes a little-endian float64 into the buffer. -func (b *Buffer) WriteFloat64(buf []byte, n float64) { +func (b *Buffer) WriteFloat64(buf []byte, n float64) { b.WriteUint64(math.Float64bits(n)) } diff --git a/go/netcode/buffer_test.go b/go/netcode/buffer_test.go index 4cdeb29..4f315c0 100644 --- a/go/netcode/buffer_test.go +++ b/go/netcode/buffer_test.go @@ -29,8 +29,6 @@ func TestBuffer_Copy(t *testing.T) { t.Fatalf("error reading bytes from copy: %s\n", err) } - t.Logf("%s\n", string(data)) - } func TestBuffer_GetByte(t *testing.T) { @@ -48,7 +46,6 @@ func TestBuffer_GetByte(t *testing.T) { } } - func TestBuffer_GetBytes(t *testing.T) { buf := make([]byte, 2) buf[0] = 'a' @@ -74,7 +71,6 @@ func TestBuffer_GetBytes(t *testing.T) { } - func TestBuffer_GetInt8(t *testing.T) { writer := NewBuffer(SizeInt8) writer.WriteInt8(0x0f) @@ -131,7 +127,6 @@ func TestBuffer_GetInt16(t *testing.T) { } } - func TestBuffer_GetInt32(t *testing.T) { writer := NewBuffer(SizeInt32) writer.WriteInt32(0x0fffffff) @@ -254,13 +249,13 @@ func TestBuffer_WriteBytes(t *testing.T) { w := NewBuffer(10) w.WriteBytes([]byte("0123456789")) r := w.Copy() - val, err :=r.GetBytes(10) + val, err := r.GetBytes(10) if err != nil { t.Fatal(err) } if string(val) != "0123456789" { - t.Fatalf("expected 0123456789 got: %s %d\n",val, len(val)) + t.Fatalf("expected 0123456789 got: %s %d\n", val, len(val)) } } diff --git a/go/netcode/challenge_token.go b/go/netcode/challenge_token.go index 38e0ce0..6dc5c63 100644 --- a/go/netcode/challenge_token.go +++ b/go/netcode/challenge_token.go @@ -2,8 +2,8 @@ package netcode // Challenge tokens are used in certain packet types type ChallengeToken struct { - ClientId uint64 // the clientId associated with this token - UserData *Buffer // the userdata payload + ClientId uint64 // the clientId associated with this token + UserData *Buffer // the userdata payload TokenData *Buffer // the serialized payload container } diff --git a/go/netcode/config.go b/go/netcode/config.go index 18db612..2e6ff68 100644 --- a/go/netcode/config.go +++ b/go/netcode/config.go @@ -2,15 +2,17 @@ package netcode import "net" +// A configuration container for various properties that are passed to packets type Config struct { - ClientId uint64 - ServerAddrs []net.UDPAddr - TokenExpiry uint64 - TimeoutSeconds uint32 - ProtocolId uint64 - PrivateKey []byte + ClientId uint64 // client id used in packet generation + ServerAddrs []net.UDPAddr // list of server addresses + TokenExpiry uint64 // when the token expires, current time + this value. + TimeoutSeconds uint32 // timeout in seconds for connect token + ProtocolId uint64 // the protocol id used between server <-> client + PrivateKey []byte // the private key used for encryption } +// Creates a new config holder for ease of passing around to packet generation and client/servers func NewConfig(serverAddrs []net.UDPAddr, timeoutSeconds uint32, expiry, clientId, protocolId uint64, privateKey []byte) *Config { c := &Config{} c.ClientId = clientId diff --git a/go/netcode/connect_token.go b/go/netcode/connect_token.go index c2dd9f9..be489ea 100644 --- a/go/netcode/connect_token.go +++ b/go/netcode/connect_token.go @@ -6,24 +6,26 @@ import ( "time" ) +// ip types used in serialization of server addresses const ( ADDRESS_NONE = iota ADDRESS_IPV4 ADDRESS_IPV6 ) +// number of bytes for connect tokens const CONNECT_TOKEN_BYTES = 2048 // Token used for connecting type ConnectToken struct { - sharedTokenData - VersionInfo []byte - ProtocolId uint64 - CreateTimestamp uint64 - ExpireTimestamp uint64 - Sequence uint64 - PrivateData *ConnectTokenPrivate - TimeoutSeconds uint32 + sharedTokenData // a shared container holding the server addresses, client and server keys + VersionInfo []byte // the version information for client <-> server communications + ProtocolId uint64 // protocol id for communications + CreateTimestamp uint64 // when this token was created + ExpireTimestamp uint64 // when this token expires + Sequence uint64 // the sequence id + PrivateData *ConnectTokenPrivate // reference to the private parts of this connect token + TimeoutSeconds uint32 // timeout of connect token in seconds } // Create a new empty token and empty private token diff --git a/go/netcode/connect_token_private.go b/go/netcode/connect_token_private.go index b153960..e2feb11 100644 --- a/go/netcode/connect_token_private.go +++ b/go/netcode/connect_token_private.go @@ -74,7 +74,6 @@ func (p *ConnectTokenPrivate) Write() ([]byte, error) { // Encrypts, in place, the TokenData buffer, assumes Write() has already been called. func (token *ConnectTokenPrivate) Encrypt(protocolId, expireTimestamp, sequence uint64, privateKey []byte) error { additionalData, nonce := buildTokenCryptData(protocolId, expireTimestamp, sequence) - if err := EncryptAead(&token.TokenData.Buf, additionalData, nonce, privateKey); err != nil { return err } @@ -85,8 +84,8 @@ func (token *ConnectTokenPrivate) Encrypt(protocolId, expireTimestamp, sequence // (most likely via NewConnectTokenPrivateEncrypted(...)). Optionally returns the decrypted buffer to caller. func (token *ConnectTokenPrivate) Decrypt(protocolId, expireTimestamp, sequence uint64, privateKey []byte) ([]byte, error) { var err error - additionalData, nonce := buildTokenCryptData(protocolId, expireTimestamp, sequence) + additionalData, nonce := buildTokenCryptData(protocolId, expireTimestamp, sequence) if token.TokenData.Buf, err = DecryptAead(token.TokenData.Buf, additionalData, nonce, privateKey); err != nil { return nil, err } @@ -103,6 +102,5 @@ func buildTokenCryptData(protocolId, expireTimestamp, sequence uint64) ([]byte, nonce := NewBuffer(SizeUint64) nonce.WriteUint64(sequence) - return additionalData.Buf, nonce.Buf } diff --git a/go/netcode/connect_token_shared.go b/go/netcode/connect_token_shared.go index 306bcfb..65c8949 100644 --- a/go/netcode/connect_token_shared.go +++ b/go/netcode/connect_token_shared.go @@ -14,7 +14,7 @@ type sharedTokenData struct { ServerKey []byte // server to client key } -// Reads and validates the servers and client <-> server keys. +// Reads and validates the servers, client <-> server keys. func (shared *sharedTokenData) ReadShared(buffer *Buffer) error { var err error var servers uint32 diff --git a/go/netcode/packet.go b/go/netcode/packet.go index 084c815..572fe90 100644 --- a/go/netcode/packet.go +++ b/go/netcode/packet.go @@ -25,6 +25,7 @@ const MAX_SERVERS_PER_CONNECT = 32 const VERSION_INFO = "NETCODE 1.00\x00" +// Used for determining the type of packet, part of the serialization protocol type PacketType uint8 const ( @@ -37,9 +38,7 @@ const ( ConnectionDisconnect ) -// not a packet type, but value is last packetType+1 -const ConnectionNumPackets = ConnectionDisconnect + 1 - +// reference map of packet -> string values var packetTypeMap = map[PacketType]string{ ConnectionRequest: "CONNECTION_REQUEST", ConnectionDenied: "CONNECTION_DENIED", @@ -50,19 +49,24 @@ var packetTypeMap = map[PacketType]string{ ConnectionDisconnect: "CONNECTION_DISCONNECT", } +// not a packet type, but value is last packetType+1 +const ConnectionNumPackets = ConnectionDisconnect + 1 + +// Packet interface supporting reading and writing. type Packet interface { - GetType() PacketType - Write(buffer *Buffer, protocolId, sequence uint64, writePacketKey []byte) (int, error) - Read(packetBuffer *Buffer, packetLen int, protocolId, currentTimestamp uint64, readPacketKey, privateKey, allowedPackets []byte, replayProtection *ReplayProtection) error + GetType() PacketType // returns the packet type + Write(buffer *Buffer, protocolId, sequence uint64, writePacketKey []byte) (int, error) // writes the packet data to the supplied buffer. + Read(packetBuffer *Buffer, packetLen int, protocolId, currentTimestamp uint64, readPacketKey, privateKey, allowedPackets []byte, replayProtection *ReplayProtection) error // reads in data from the supplied buffer to set the packet properties } +// The connection request packet type RequestPacket struct { - VersionInfo []byte - ProtocolId uint64 - ConnectTokenExpireTimestamp uint64 - ConnectTokenSequence uint64 - Token *ConnectTokenPrivate - ConnectTokenData []byte // the encrypted Token after Write -> Encrypt + VersionInfo []byte // version information of communications + ProtocolId uint64 // protocol id used in communications + ConnectTokenExpireTimestamp uint64 // when the connect token expires + ConnectTokenSequence uint64 // the sequence id of this token + Token *ConnectTokenPrivate // reference to the private parts of this packet + ConnectTokenData []byte // the encrypted Token after Write -> Encrypt } // Writes the RequestPacket data to a supplied buffer and returns the length of bytes written to it. @@ -150,6 +154,7 @@ func (p *RequestPacket) GetType() PacketType { return ConnectionRequest } +// Denied packet type, contains no information type DeniedPacket struct { } @@ -179,6 +184,7 @@ func (p *DeniedPacket) GetType() PacketType { return ConnectionDenied } +// Challenge packet containing token data and the sequence id used type ChallengePacket struct { ChallengeTokenSequence uint64 ChallengeTokenData []byte @@ -224,6 +230,7 @@ func (p *ChallengePacket) GetType() PacketType { return ConnectionChallenge } +// Response packet, containing the token data and sequence id type ResponsePacket struct { ChallengeTokenSequence uint64 ChallengeTokenData []byte @@ -269,6 +276,7 @@ func (p *ResponsePacket) GetType() PacketType { return ConnectionResponse } +// used for heart beats type KeepAlivePacket struct { ClientIndex uint32 MaxClients uint32 @@ -314,6 +322,7 @@ func (p *KeepAlivePacket) GetType() PacketType { return ConnectionKeepAlive } +// Contains user supplied payload data between server <-> client type PayloadPacket struct { PayloadBytes uint32 PayloadData []byte @@ -323,6 +332,7 @@ func (p *PayloadPacket) GetType() PacketType { return ConnectionPayload } +// Helper function to create a new payload packet with the supplied buffer func NewPayloadPacket(payloadData []byte) *PayloadPacket { packet := &PayloadPacket{} packet.PayloadBytes = uint32(len(payloadData)) @@ -362,6 +372,7 @@ func (p *PayloadPacket) Read(packetBuffer *Buffer, packetLen int, protocolId, cu return nil } +// Signals to server/client to disconnect, contains no data. type DisconnectPacket struct { } @@ -391,6 +402,7 @@ func (p *DisconnectPacket) GetType() PacketType { return ConnectionDisconnect } +// Decrypts the packet after reading in the prefix byte and sequence id. Used for all PacketTypes except RequestPacket. Returns a buffer containing the decrypted data func decryptPacket(packetBuffer *Buffer, packetLen int, protocolId, currentTimestamp uint64, readPacketKey, privateKey, allowedPackets []byte, replayProtection *ReplayProtection) (*Buffer, error) { var packetSequence uint64 @@ -428,6 +440,7 @@ func decryptPacket(packetBuffer *Buffer, packetLen int, protocolId, currentTimes return NewBufferFromBytes(decryptedBuff), nil } +// Reads and verifies the sequence id func readSequence(packetBuffer *Buffer, packetLen int, prefixByte uint8) (uint64, error) { var sequence uint64 @@ -452,7 +465,7 @@ func readSequence(packetBuffer *Buffer, packetLen int, prefixByte uint8) (uint64 return sequence, nil } -// Validates the data prior to encrypted blob is valid before we bother attempting to decrypt. +// Validates the data prior to the encrypted segment before we bother attempting to decrypt. func validateSequence(packetLen int, prefixByte uint8, sequence uint64, readPacketKey, allowedPackets []byte, replayProtection *ReplayProtection) error { if readPacketKey == nil { @@ -503,6 +516,7 @@ func writePacketPrefix(p Packet, buffer *Buffer, sequence uint64) (uint8, error) return prefixByte, nil } +// Encrypts the packet data of the supplied buffer between encryptedStart and encrypedFinish. func encryptPacket(buffer *Buffer, encryptedStart, encryptedFinish int, prefixByte uint8, protocolId, sequence uint64, writePacketKey []byte) (int, error) { // slice up the buffer for the bits we will encrypt encryptedBuffer := buffer.Buf[encryptedStart:encryptedFinish] diff --git a/go/netcode/replay_protection.go b/go/netcode/replay_protection.go index f07fc2f..1487ead 100644 --- a/go/netcode/replay_protection.go +++ b/go/netcode/replay_protection.go @@ -2,8 +2,8 @@ package netcode // Our type to hold replay protection of packet sequences type ReplayProtection struct { - MostRecentSequence uint64 // last sequence recv'd - ReceivedPacket []uint64 // slice of REPLAY_PROTECTION_BUFFER_SIZE worth of packet sequences + MostRecentSequence uint64 // last sequence recv'd + ReceivedPacket []uint64 // slice of REPLAY_PROTECTION_BUFFER_SIZE worth of packet sequences } // Initializes a new ReplayProtection with the ReceivedPacket buffer elements all set to 0xFFFFFFFFFFFFFFFF @@ -22,15 +22,15 @@ func (r *ReplayProtection) Reset() { // Tests that the sequence has not already been recv'd, adding it to the buffer if it's new. func (r *ReplayProtection) AlreadyReceived(sequence uint64) int { - if sequence & (uint64(1 << 63)) != 0 { + if sequence&(uint64(1<<63)) != 0 { return 0 } - if sequence + REPLAY_PROTECTION_BUFFER_SIZE <= r.MostRecentSequence { + if sequence+REPLAY_PROTECTION_BUFFER_SIZE <= r.MostRecentSequence { return 1 } - if sequence > r.MostRecentSequence { + if sequence > r.MostRecentSequence { r.MostRecentSequence = sequence } @@ -50,7 +50,7 @@ func (r *ReplayProtection) AlreadyReceived(sequence uint64) int { } func clearPacketBuffer(packets []uint64) { - for i := 0; i < len(packets); i+=1 { + for i := 0; i < len(packets); i += 1 { packets[i] = 0xFFFFFFFFFFFFFFFF } -} \ No newline at end of file +} diff --git a/go/netcode/sizes.go b/go/netcode/sizes.go index ad7a684..7340bd2 100644 --- a/go/netcode/sizes.go +++ b/go/netcode/sizes.go @@ -53,4 +53,4 @@ const ( // byteSliceToString converts a []byte to string without a heap allocation. func byteSliceToString(b []byte) string { return *(*string)(unsafe.Pointer(&b)) -} \ No newline at end of file +} diff --git a/go/netcode/socket.go b/go/netcode/socket.go index 029f679..5e1c6ec 100644 --- a/go/netcode/socket.go +++ b/go/netcode/socket.go @@ -19,7 +19,7 @@ const ( type Socket struct { Address *net.UDPAddr - Conn *net.UDPConn + Conn *net.UDPConn } func NewSocket() *Socket { @@ -66,8 +66,6 @@ func (s *Socket) Recv(source *net.Addr, data []byte, maxsize uint) error { return nil } - func (s *Socket) Destroy() { s.Conn.Close() } - From 434477c82d7436efc675f1913c7198ac4a98a337 Mon Sep 17 00:00:00 2001 From: wirepair Date: Fri, 7 Apr 2017 09:07:03 +0900 Subject: [PATCH 11/11] fix unit test and remove debug logging --- go/netcode/buffer_test.go | 5 +++++ go/netcode/packet.go | 2 -- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/go/netcode/buffer_test.go b/go/netcode/buffer_test.go index 4f315c0..8423559 100644 --- a/go/netcode/buffer_test.go +++ b/go/netcode/buffer_test.go @@ -24,11 +24,16 @@ func TestBuffer_Copy(t *testing.T) { if r.Len() != b.Len() { t.Fatalf("expected copy length to be same got: %d and %d\n", r.Len(), b.Len()) } + data, err := r.GetBytes(10) if err != nil { t.Fatalf("error reading bytes from copy: %s\n", err) } + if string(data) != "abcdefghij" { + t.Fatalf("error expeced: %s got %d\n", "abcdefghij", string(data)) + } + } func TestBuffer_GetByte(t *testing.T) { diff --git a/go/netcode/packet.go b/go/netcode/packet.go index 572fe90..7cef99e 100644 --- a/go/netcode/packet.go +++ b/go/netcode/packet.go @@ -2,7 +2,6 @@ package netcode import ( "errors" - "log" "strconv" ) @@ -509,7 +508,6 @@ func writePacketPrefix(p Packet, buffer *Buffer, sequence uint64) (uint8, error) var i uint8 for ; i < sequenceBytes; i += 1 { - log.Printf("sequenceTemp %d: %x\n", i, uint8(sequenceTemp&0xFF)) buffer.WriteUint8(uint8(sequenceTemp & 0xFF)) sequenceTemp >>= 8 }