diff --git a/go/netcode/README.md b/go/netcode/README.md new file mode 100644 index 0000000..8efc2e3 --- /dev/null +++ b/go/netcode/README.md @@ -0,0 +1,21 @@ +Draft Implementation of netcode.io for Go +========================================= + +This is the main repository for the Go implementation of [netcode.io](https://netcode.io). This repository and the API are highly violatile until the client and server implementations have been completed. + +## Dependencies +codehale's implementation of [chacha20poly130](https://github.com/codahale/chacha20poly1305). While I would have liked to use [https://godoc.org/golang.org/x/crypto/chacha20poly1305](https://godoc.org/golang.org/x/crypto/chacha20poly1305) it only implements the IETF version with nonce size of 12 bytes. + +## Documentation +[godocs](https://godoc.org/github.com/networkprotocol/netcode.io/go/netcode/) (todo, this may not work due to godocs probably expecting the package to be the top level of the repository). + +## TODO +- Implement Client +- Implement Server + +## Completed +- Implemented packet and token portion of protocol + + +## Author +[Isaac Dawson](https://github.com/wirepair) \ No newline at end of file diff --git a/go/netcode/buffer.go b/go/netcode/buffer.go new file mode 100644 index 0000000..7fb2504 --- /dev/null +++ b/go/netcode/buffer.go @@ -0,0 +1,270 @@ +package netcode + +import ( + "io" + "math" +) + +// Buffer is a helper struct for serializing and deserializing as the caller +// does not need to externally manage where in the buffer they are currently reading +// or writing to. +type Buffer struct { + Buf []byte // the backing byte slice + Pos int // current position in read/write +} + +// Creates a new Buffer with a backing byte slice of the provided size +func NewBuffer(size int) *Buffer { + b := &Buffer{} + b.Buf = make([]byte, size) + return b +} + +// Creates a new buffer from a byte slice +func NewBufferFromBytes(buf []byte) *Buffer { + b := &Buffer{} + b.Buf = buf + return b +} + +// Returns a copy of Buffer +func (b *Buffer) Copy() *Buffer { + c := NewBufferFromBytes(b.Buf) + return c +} + +// Gets the length of the backing byte slice +func (b *Buffer) Len() int { + return len(b.Buf) +} + +// Returns the backing byte slice +func (b *Buffer) Bytes() []byte { + return b.Buf +} + +// Resets the position back to beginning of buffer +func (b *Buffer) Reset() { + b.Pos = 0 +} + +// GetByte decodes a little-endian byte +func (b *Buffer) GetByte() (byte, error) { + return b.GetUint8() +} + +// GetBytes returns a byte slice possibly smaller than length if bytes are not available from the +// reader. +func (b *Buffer) GetBytes(length int) ([]byte, error) { + if len(b.Buf) < length { + return nil, io.EOF + } + value := b.Buf[b.Pos : b.Pos+length] + b.Pos += length + return value, nil +} + +// GetUint8 decodes a little-endian uint8 from the buffer +func (b *Buffer) GetUint8() (uint8, error) { + if b.Pos+SizeUint8 > len(b.Buf) { + return 0, io.EOF + } + buf := b.Buf[b.Pos : b.Pos+SizeUint8] + b.Pos++ + return uint8(buf[0]), nil +} + +// GetUint16 decodes a little-endian uint16 from the buffer +func (b *Buffer) GetUint16() (uint16, error) { + var n uint16 + buf, err := b.GetBytes(SizeUint16) + if err != nil { + return 0, nil + } + n |= uint16(buf[0]) + n |= uint16(buf[1]) << 8 + return n, nil +} + +// GetUint32 decodes a little-endian uint32 from the buffer +func (b *Buffer) GetUint32() (uint32, error) { + var n uint32 + buf, err := b.GetBytes(SizeUint32) + if err != nil { + return 0, nil + } + n |= uint32(buf[0]) + n |= uint32(buf[1]) << 8 + n |= uint32(buf[2]) << 16 + n |= uint32(buf[3]) << 24 + return n, nil +} + +// GetUint64 decodes a little-endian uint64 from the buffer +func (b *Buffer) GetUint64() (uint64, error) { + var n uint64 + buf, err := b.GetBytes(SizeUint64) + if err != nil { + return 0, nil + } + n |= uint64(buf[0]) + n |= uint64(buf[1]) << 8 + n |= uint64(buf[2]) << 16 + n |= uint64(buf[3]) << 24 + n |= uint64(buf[4]) << 32 + n |= uint64(buf[5]) << 40 + n |= uint64(buf[6]) << 48 + n |= uint64(buf[7]) << 56 + return n, nil +} + +// GetInt8 decodes a little-endian int8 from the buffer +func (b *Buffer) GetInt8() (int8, error) { + if b.Pos+1 > len(b.Buf) { + return 0, io.EOF + } + buf := b.Buf[b.Pos : b.Pos+SizeInt8] + return int8(buf[0]), nil +} + +// GetInt16 decodes a little-endian int16 from the buffer +func (b *Buffer) GetInt16() (int16, error) { + var n int16 + buf, err := b.GetBytes(SizeInt16) + if err != nil { + return 0, nil + } + n |= int16(buf[0]) + n |= int16(buf[1]) << 8 + return n, nil +} + +// GetInt32 decodes a little-endian int32 from the buffer +func (b *Buffer) GetInt32() (int32, error) { + var n int32 + buf, err := b.GetBytes(SizeInt32) + if err != nil { + return 0, nil + } + n |= int32(buf[0]) + n |= int32(buf[1]) << 8 + n |= int32(buf[2]) << 16 + n |= int32(buf[3]) << 24 + return n, nil +} + +// GetInt64 decodes a little-endian int64 from the buffer +func (b *Buffer) GetInt64() (int64, error) { + var n int64 + buf, err := b.GetBytes(SizeInt64) + if err != nil { + return 0, nil + } + n |= int64(buf[0]) + n |= int64(buf[1]) << 8 + n |= int64(buf[2]) << 16 + n |= int64(buf[3]) << 24 + n |= int64(buf[4]) << 32 + n |= int64(buf[5]) << 40 + n |= int64(buf[6]) << 48 + n |= int64(buf[7]) << 56 + return n, nil +} + +// WriteByte encodes a little-endian uint8 into the buffer. +func (b *Buffer) WriteByte(n byte) { + b.Buf[b.Pos] = uint8(n) + b.Pos++ +} + +// WriteBytes encodes a little-endian byte slice into the buffer +func (b *Buffer) WriteBytes(src []byte) { + for i := 0; i < len(src); i += 1 { + b.WriteByte(uint8(src[i])) + } +} + +// WriteBytes encodes a little-endian byte slice into the buffer +func (b *Buffer) WriteBytesN(src []byte, length int) { + for i := 0; i < length; i += 1 { + b.WriteByte(uint8(src[i])) + } +} + +// WriteUint8 encodes a little-endian uint8 into the buffer. +func (b *Buffer) WriteUint8(n uint8) { + b.Buf[b.Pos] = byte(n) + b.Pos++ +} + +// WriteUint16 encodes a little-endian uint16 into the buffer. +func (b *Buffer) WriteUint16(n uint16) { + b.Buf[b.Pos] = byte(n) + b.Pos++ + b.Buf[b.Pos] = byte(n >> 8) + b.Pos++ +} + +// WriteUint32 encodes a little-endian uint32 into the buffer. +func (b *Buffer) WriteUint32(n uint32) { + b.Buf[b.Pos] = byte(n) + b.Pos++ + b.Buf[b.Pos] = byte(n >> 8) + b.Pos++ + b.Buf[b.Pos] = byte(n >> 16) + b.Pos++ + b.Buf[b.Pos] = byte(n >> 24) + b.Pos++ +} + +// WriteUint64 encodes a little-endian uint64 into the buffer. +func (b *Buffer) WriteUint64(n uint64) { + for i := uint(0); i < uint(SizeUint64); i++ { + b.Buf[b.Pos] = byte(n >> (i * 8)) + b.Pos++ + } +} + +// WriteInt8 encodes a little-endian int8 into the buffer. +func (b *Buffer) WriteInt8(n int8) { + b.Buf[b.Pos] = byte(n) + b.Pos++ +} + +// WriteInt16 encodes a little-endian int16 into the buffer. +func (b *Buffer) WriteInt16(n int16) { + b.Buf[b.Pos] = byte(n) + b.Pos++ + b.Buf[b.Pos] = byte(n >> 8) + b.Pos++ +} + +// WriteInt32 encodes a little-endian int32 into the buffer. +func (b *Buffer) WriteInt32(n int32) { + b.Buf[b.Pos] = byte(n) + b.Pos++ + b.Buf[b.Pos] = byte(n >> 8) + b.Pos++ + b.Buf[b.Pos] = byte(n >> 16) + b.Pos++ + b.Buf[b.Pos] = byte(n >> 24) + b.Pos++ +} + +// WriteInt64 encodes a little-endian int64 into the buffer. +func (b *Buffer) WriteInt64(n int64) { + for i := uint(0); i < uint(SizeInt64); i++ { + b.Buf[b.Pos] = byte(n >> (i * 8)) + b.Pos++ + } +} + +// WriteFloat32 encodes a little-endian float32 into the buffer. +func (b *Buffer) WriteFloat32(n float32) { + b.WriteUint32(math.Float32bits(n)) +} + +// WriteFloat64 encodes a little-endian float64 into the buffer. +func (b *Buffer) WriteFloat64(buf []byte, n float64) { + b.WriteUint64(math.Float64bits(n)) +} diff --git a/go/netcode/buffer_test.go b/go/netcode/buffer_test.go new file mode 100644 index 0000000..8423559 --- /dev/null +++ b/go/netcode/buffer_test.go @@ -0,0 +1,266 @@ +package netcode + +import ( + "testing" +) + +func TestBuffer(t *testing.T) { + b := NewBuffer(10) + b.WriteByte('a') + b.WriteBytesN([]byte("bcdefghij"), 9) + + if string(b.Buf) != "abcdefghij" { + t.Fatalf("error should have written 'abcdefghij' got '%s'\n", string(b.Buf)) + } + +} + +func TestBuffer_Copy(t *testing.T) { + b := NewBuffer(10) + b.WriteByte('a') + b.WriteBytesN([]byte("bcdefghij"), 9) + + r := b.Copy() + if r.Len() != b.Len() { + t.Fatalf("expected copy length to be same got: %d and %d\n", r.Len(), b.Len()) + } + + data, err := r.GetBytes(10) + if err != nil { + t.Fatalf("error reading bytes from copy: %s\n", err) + } + + if string(data) != "abcdefghij" { + t.Fatalf("error expeced: %s got %d\n", "abcdefghij", string(data)) + } + +} + +func TestBuffer_GetByte(t *testing.T) { + buf := make([]byte, 1) + buf[0] = 0xfe + b := NewBufferFromBytes(buf) + val, err := b.GetByte() + + if err != nil { + t.Fatal(err) + } + + if val != 0xfe { + t.Fatalf("expected 0xfe got: %x\n", val) + } +} + +func TestBuffer_GetBytes(t *testing.T) { + buf := make([]byte, 2) + buf[0] = 'a' + buf[1] = 'b' + b := NewBufferFromBytes(buf) + + val, err := b.GetBytes(2) + + if err != nil { + t.Fatal(err) + } + + if string(val) != "ab" { + t.Fatalf("expected ab got: %s\n", val) + } + + b = NewBufferFromBytes(buf) + + val, err = b.GetBytes(3) + if err == nil { + t.Fatal("expected EOF") + } + +} + +func TestBuffer_GetInt8(t *testing.T) { + writer := NewBuffer(SizeInt8) + writer.WriteInt8(0x0f) + reader := writer.Copy() + + val, err := reader.GetInt8() + + if err != nil { + t.Fatal(err) + } + + if val != 0xf { + t.Fatalf("expected 0xf got: %x\n", val) + } + + buf := make([]byte, SizeInt8) + buf[0] = 0xff + b := NewBufferFromBytes(buf) + val, err = b.GetInt8() + if err != nil { + t.Fatal(err) + } + + if val != -1 { + t.Fatalf("expected -1 got: %x\n", val) + } +} + +func TestBuffer_GetInt16(t *testing.T) { + writer := NewBuffer(SizeInt16) + writer.WriteInt16(0x0fff) + reader := writer.Copy() + val, err := reader.GetInt16() + + if err != nil { + t.Fatal(err) + } + + if val != 0x0fff { + t.Fatalf("expected 0x0fff got: %x\n", val) + } + + buf := make([]byte, SizeInt16) + buf[0] = 0xff + buf[1] = 0xff + b := NewBufferFromBytes(buf) + val, err = b.GetInt16() + if err != nil { + t.Fatal(err) + } + + if val != -1 { + t.Fatalf("expected -1 got: %x\n", val) + } +} + +func TestBuffer_GetInt32(t *testing.T) { + writer := NewBuffer(SizeInt32) + writer.WriteInt32(0x0fffffff) + reader := writer.Copy() + + val, err := reader.GetInt32() + if err != nil { + t.Fatal(err) + } + + if val != 0x0fffffff { + t.Fatalf("expected 0x0fffffff got: %x\n", val) + } + + buf := make([]byte, SizeInt32) + buf[0] = 0xff + buf[1] = 0xff + buf[2] = 0xff + buf[3] = 0xff + b := NewBufferFromBytes(buf) + val, err = b.GetInt32() + if err != nil { + t.Fatal(err) + } + + if val != -1 { + t.Fatalf("expected -1 got: %x\n", val) + } +} + +func TestBuffer_GetInt64(t *testing.T) { + writer := NewBuffer(SizeInt64) + writer.WriteInt64(0xf3f3f3f3f3f3) + reader := writer.Copy() + + val, err := reader.GetInt64() + + if err != nil { + t.Fatal(err) + } + + if val != 0xf3f3f3f3f3f3 { + t.Fatalf("expected 0xf3f3f3f3f3f3 got: %x\n", val) + } +} + +func TestBuffer_GetUint8(t *testing.T) { + writer := NewBuffer(SizeUint8) + writer.WriteUint8(0xff) + reader := writer.Copy() + + val, err := reader.GetUint8() + + if err != nil { + t.Fatal(err) + } + + if val != 0xff { + t.Fatalf("expected 0xff got: %x\n", val) + } +} + +func TestBuffer_GetUint16(t *testing.T) { + writer := NewBuffer(SizeUint16) + writer.WriteUint16(0xffff) + reader := writer.Copy() + + val, err := reader.GetUint16() + if err != nil { + t.Fatal(err) + } + + if val != 0xffff { + t.Fatalf("expected 0xffff got: %x\n", val) + } +} + +func TestBuffer_GetUint32(t *testing.T) { + writer := NewBuffer(SizeUint32) + writer.WriteUint32(0xffffffff) + reader := writer.Copy() + + val, err := reader.GetUint32() + if err != nil { + t.Fatal(err) + } + + if val != 0xffffffff { + t.Fatalf("expected 0xffffffff got: %x\n", val) + } +} + +func TestBuffer_GetUint64(t *testing.T) { + writer := NewBuffer(SizeUint64) + writer.WriteUint64(0xffffffffffffffff) + reader := writer.Copy() + + val, err := reader.GetUint64() + + if err != nil { + t.Fatal(err) + } + + if val != 0xffffffffffffffff { + t.Fatalf("expected 0xffffffffffffffff got: %x\n", val) + } +} + +func TestBuffer_Len(t *testing.T) { + b := NewBuffer(10) + b.WriteByte('a') + b.WriteBytesN([]byte("bcdefghij"), 9) + + if b.Len() != 10 { + t.Fatalf("expected length of 10 got: %d\n", b.Len()) + } +} + +func TestBuffer_WriteBytes(t *testing.T) { + w := NewBuffer(10) + w.WriteBytes([]byte("0123456789")) + r := w.Copy() + val, err := r.GetBytes(10) + if err != nil { + t.Fatal(err) + } + + if string(val) != "0123456789" { + t.Fatalf("expected 0123456789 got: %s %d\n", val, len(val)) + } + +} diff --git a/go/netcode/challenge_token.go b/go/netcode/challenge_token.go new file mode 100644 index 0000000..6dc5c63 --- /dev/null +++ b/go/netcode/challenge_token.go @@ -0,0 +1,64 @@ +package netcode + +// Challenge tokens are used in certain packet types +type ChallengeToken struct { + ClientId uint64 // the clientId associated with this token + UserData *Buffer // the userdata payload + TokenData *Buffer // the serialized payload container +} + +// Creates a new empty challenge token with only the clientId set +func NewChallengeToken(clientId uint64) *ChallengeToken { + token := &ChallengeToken{} + token.ClientId = clientId + token.UserData = NewBuffer(USER_DATA_BYTES) + return token +} + +// Encrypts the TokenData buffer with the sequence nonce and provided key +func EncryptChallengeToken(tokenBuffer *[]byte, sequence uint64, key []byte) error { + nonce := NewBuffer(SizeUint64) + nonce.WriteUint64(sequence) + return EncryptAead(tokenBuffer, nil, nonce.Bytes(), key) +} + +// Decrypts the TokenData buffer with the sequence nonce and provided key, updating the +// internal TokenData buffer +func DecryptChallengeToken(tokenBuffer []byte, sequence uint64, key []byte) ([]byte, error) { + nonce := NewBuffer(SizeUint64) + nonce.WriteUint64(sequence) + return DecryptAead(tokenBuffer, nil, nonce.Bytes(), key) +} + +// Serializes the client id and userData, also sets the UserData buffer. +func (t *ChallengeToken) Write(userData []byte) []byte { + tokenData := NewBuffer(CHALLENGE_TOKEN_BYTES) + t.UserData.WriteBytes(userData) + tokenData.WriteUint64(t.ClientId) + tokenData.WriteBytes(userData) + return tokenData.Buf +} + +// Generates a new ChallengeToken from the provided buffer byte slice. Only sets the ClientId +// and UserData buffer. +func ReadChallengeToken(buffer []byte) (*ChallengeToken, error) { + var err error + var clientId uint64 + var userData []byte + tokenBuffer := NewBufferFromBytes(buffer) + + clientId, err = tokenBuffer.GetUint64() + if err != nil { + return nil, err + } + token := NewChallengeToken(clientId) + + userData, err = tokenBuffer.GetBytes(USER_DATA_BYTES) + if err != nil { + return nil, err + } + token.UserData.WriteBytes(userData) + token.UserData.Reset() + + return token, nil +} diff --git a/go/netcode/challenge_token_test.go b/go/netcode/challenge_token_test.go new file mode 100644 index 0000000..34081e1 --- /dev/null +++ b/go/netcode/challenge_token_test.go @@ -0,0 +1,46 @@ +package netcode + +import ( + "testing" + "bytes" +) + +func TestNewChallengeToken(t *testing.T) { + var err error + var userData []byte + var decryptedBuffer []byte + + token := NewChallengeToken(TEST_CLIENT_ID) + if userData, err = RandomBytes(USER_DATA_BYTES); err != nil { + t.Fatalf("error generating random data\n") + } + tokenBuffer := token.Write(userData) + + var sequence uint64 + sequence = 1000 + key, err := GenerateKey() + if err != nil { + t.Fatalf("error generating key\n") + } + + if err := EncryptChallengeToken(&tokenBuffer, sequence, key); err != nil { + t.Fatalf("error encrypting challenge token: %s\n", err) + } + + if decryptedBuffer, err = DecryptChallengeToken(tokenBuffer, sequence, key); err != nil { + t.Fatalf("error decrypting challenge token: %s\n", err) + } + + newToken, err := ReadChallengeToken(decryptedBuffer) + if err != nil { + t.Fatalf("error reading token data %s\n", err) + } + + if newToken.ClientId != token.ClientId { + t.Fatalf("token client id did not match, expected %d got %d\n", token.ClientId, newToken.ClientId) + } + + if bytes.Compare(newToken.UserData.Buf, token.UserData.Buf) != 0 { + t.Fatalf("user data did not match expected\n %#v\ngot\n%#v!", token.UserData.Buf, newToken.UserData.Buf) + } +} \ No newline at end of file diff --git a/go/netcode/client.go b/go/netcode/client.go new file mode 100644 index 0000000..3a00d16 --- /dev/null +++ b/go/netcode/client.go @@ -0,0 +1,36 @@ +package netcode + +import ( + "crypto/rand" + "math/big" +) + +type Client struct { + Id uint64 + config *Config +} + +func NewClient(config *Config) *Client { + c := &Client{config: config} + return c +} + +func (c *Client) Init(sequence uint64) error { + id, err := rand.Int(rand.Reader, big.NewInt(64)) + if err != nil { + return err + } + + c.Id = id.Uint64() + + token := NewConnectToken() + if err := token.Generate(c.config, sequence); err != nil { + return err + } + + return nil +} + +func (c *Client) Connect() error { + return nil +} diff --git a/go/netcode/config.go b/go/netcode/config.go new file mode 100644 index 0000000..2e6ff68 --- /dev/null +++ b/go/netcode/config.go @@ -0,0 +1,25 @@ +package netcode + +import "net" + +// A configuration container for various properties that are passed to packets +type Config struct { + ClientId uint64 // client id used in packet generation + ServerAddrs []net.UDPAddr // list of server addresses + TokenExpiry uint64 // when the token expires, current time + this value. + TimeoutSeconds uint32 // timeout in seconds for connect token + ProtocolId uint64 // the protocol id used between server <-> client + PrivateKey []byte // the private key used for encryption +} + +// Creates a new config holder for ease of passing around to packet generation and client/servers +func NewConfig(serverAddrs []net.UDPAddr, timeoutSeconds uint32, expiry, clientId, protocolId uint64, privateKey []byte) *Config { + c := &Config{} + c.ClientId = clientId + c.ServerAddrs = serverAddrs + c.TokenExpiry = expiry + c.ProtocolId = protocolId + c.PrivateKey = privateKey + c.TimeoutSeconds = timeoutSeconds + return c +} diff --git a/go/netcode/connect_token.go b/go/netcode/connect_token.go new file mode 100644 index 0000000..be489ea --- /dev/null +++ b/go/netcode/connect_token.go @@ -0,0 +1,146 @@ +package netcode + +import ( + "errors" + "strings" + "time" +) + +// ip types used in serialization of server addresses +const ( + ADDRESS_NONE = iota + ADDRESS_IPV4 + ADDRESS_IPV6 +) + +// number of bytes for connect tokens +const CONNECT_TOKEN_BYTES = 2048 + +// Token used for connecting +type ConnectToken struct { + sharedTokenData // a shared container holding the server addresses, client and server keys + VersionInfo []byte // the version information for client <-> server communications + ProtocolId uint64 // protocol id for communications + CreateTimestamp uint64 // when this token was created + ExpireTimestamp uint64 // when this token expires + Sequence uint64 // the sequence id + PrivateData *ConnectTokenPrivate // reference to the private parts of this connect token + TimeoutSeconds uint32 // timeout of connect token in seconds +} + +// Create a new empty token and empty private token +func NewConnectToken() *ConnectToken { + token := &ConnectToken{} + token.PrivateData = NewConnectTokenPrivate() + return token +} + +// Generates the token and private token data with the supplied config values and sequence id. +func (token *ConnectToken) Generate(config *Config, sequence uint64) error { + token.CreateTimestamp = uint64(time.Now().Unix()) + token.ExpireTimestamp = token.CreateTimestamp + config.TokenExpiry + token.VersionInfo = []byte(VERSION_INFO) + token.ProtocolId = config.ProtocolId + token.TimeoutSeconds = config.TimeoutSeconds + token.Sequence = sequence + + userData, err := RandomBytes(USER_DATA_BYTES) + if err != nil { + return err + } + + if err = token.PrivateData.Generate(config, userData); err != nil { + return err + } + + // copy directly from the private token since we don't want to generate 2 different keys + token.ClientKey = token.PrivateData.ClientKey + token.ServerKey = token.PrivateData.ServerKey + token.ServerAddrs = token.PrivateData.ServerAddrs + + if _, err = token.PrivateData.Write(); err != nil { + return err + } + + if err = token.PrivateData.Encrypt(token.ProtocolId, token.ExpireTimestamp, sequence, config.PrivateKey); err != nil { + return err + } + + return nil +} + +// Writes the ConnectToken and previously encrypted ConnectTokenPrivate data to a byte slice +func (token *ConnectToken) Write() ([]byte, error) { + buffer := NewBuffer(CONNECT_TOKEN_BYTES) + buffer.WriteBytes(token.VersionInfo) + buffer.WriteUint64(token.ProtocolId) + buffer.WriteUint64(token.CreateTimestamp) + buffer.WriteUint64(token.ExpireTimestamp) + buffer.WriteUint64(token.Sequence) + + // assumes private token has already been encrypted + buffer.WriteBytes(token.PrivateData.Buffer()) + + if err := token.WriteShared(buffer); err != nil { + return nil, err + } + + buffer.WriteUint32(token.TimeoutSeconds) + return buffer.Buf, nil +} + +// Takes in a slice of decrypted connect token bytes and generates a new ConnectToken. +// Note that the ConnectTokenPrivate is still encrypted at this point. +func ReadConnectToken(tokenBuffer []byte) (*ConnectToken, error) { + var err error + var privateData []byte + + buffer := NewBufferFromBytes(tokenBuffer) + token := NewConnectToken() + + if token.VersionInfo, err = buffer.GetBytes(VERSION_INFO_BYTES); err != nil { + return nil, errors.New("read connect token data has bad version info " + err.Error()) + } + + if strings.Compare(VERSION_INFO, string(token.VersionInfo)) != 0 { + return nil, errors.New("read connect token data has bad version info: " + string(token.VersionInfo)) + } + + if token.ProtocolId, err = buffer.GetUint64(); err != nil { + return nil, errors.New("read connect token data has bad protocol id " + err.Error()) + } + + if token.CreateTimestamp, err = buffer.GetUint64(); err != nil { + return nil, errors.New("read connect token data has bad create timestamp " + err.Error()) + } + + if token.ExpireTimestamp, err = buffer.GetUint64(); err != nil { + return nil, errors.New("read connect token data has bad expire timestamp " + err.Error()) + } + + if token.CreateTimestamp > token.ExpireTimestamp { + return nil, errors.New("expire timestamp is > create timestamp") + } + + if token.Sequence, err = buffer.GetUint64(); err != nil { + return nil, errors.New("read connect data has bad sequence " + err.Error()) + } + + if privateData, err = buffer.GetBytes(CONNECT_TOKEN_PRIVATE_BYTES + MAC_BYTES); err != nil { + return nil, errors.New("read connect data has bad private data " + err.Error()) + } + + // it is still encrypted at this point. + token.PrivateData.TokenData = NewBufferFromBytes(privateData) + + // reads servers, client and server key + if err = token.ReadShared(buffer); err != nil { + return nil, err + } + + if token.TimeoutSeconds, err = buffer.GetUint32(); err != nil { + return nil, err + } + + return token, nil +} diff --git a/go/netcode/connect_token_private.go b/go/netcode/connect_token_private.go new file mode 100644 index 0000000..e2feb11 --- /dev/null +++ b/go/netcode/connect_token_private.go @@ -0,0 +1,106 @@ +package netcode + +import ( + "errors" +) + +// The private parts of a connect token +type ConnectTokenPrivate struct { + sharedTokenData // holds the server addresses, client <-> server keys + ClientId uint64 // id for this token + UserData []byte // used to store user data + TokenData *Buffer // used to store the serialized/encrypted buffer +} + +// Create a new connect token private with an empty TokenData buffer +func NewConnectTokenPrivate() *ConnectTokenPrivate { + p := &ConnectTokenPrivate{} + p.TokenData = NewBuffer(CONNECT_TOKEN_PRIVATE_BYTES) + return p +} + +// Create a new connect token private with an pre-set, encrypted buffer +// Caller is expected to call Decrypt() and Read() to set the instances properties +func NewConnectTokenPrivateEncrypted(buffer []byte) *ConnectTokenPrivate { + p := &ConnectTokenPrivate{} + p.TokenData = NewBufferFromBytes(buffer) + return p +} + +// Helper to return the internal []byte of the private data +func (p *ConnectTokenPrivate) Buffer() []byte { + return p.TokenData.Buf +} + +// Reads the configuration values to set various properties of this private token data +// and requires a supplied userData slice. +func (p *ConnectTokenPrivate) Generate(config *Config, userData []byte) error { + p.ClientId = config.ClientId + p.UserData = userData + return p.GenerateShared(config) +} + +// Reads the token properties from the internal TokenData buffer. +func (p *ConnectTokenPrivate) Read() error { + var err error + + if p.ClientId, err = p.TokenData.GetUint64(); err != nil { + return err + } + + if err = p.ReadShared(p.TokenData); err != nil { + return err + } + + if p.UserData, err = p.TokenData.GetBytes(USER_DATA_BYTES); err != nil { + return errors.New("error reading user data") + } + + return nil +} + +// Writes the token data to our TokenData buffer and alternatively returns the buffer to caller. +func (p *ConnectTokenPrivate) Write() ([]byte, error) { + p.TokenData.WriteUint64(p.ClientId) + + if err := p.WriteShared(p.TokenData); err != nil { + return nil, err + } + + p.TokenData.WriteBytesN(p.UserData, USER_DATA_BYTES) + return p.TokenData.Buf, nil +} + +// Encrypts, in place, the TokenData buffer, assumes Write() has already been called. +func (token *ConnectTokenPrivate) Encrypt(protocolId, expireTimestamp, sequence uint64, privateKey []byte) error { + additionalData, nonce := buildTokenCryptData(protocolId, expireTimestamp, sequence) + if err := EncryptAead(&token.TokenData.Buf, additionalData, nonce, privateKey); err != nil { + return err + } + return nil +} + +// Decrypts the internal TokenData buffer, assumes that TokenData has been populated with the encrypted data +// (most likely via NewConnectTokenPrivateEncrypted(...)). Optionally returns the decrypted buffer to caller. +func (token *ConnectTokenPrivate) Decrypt(protocolId, expireTimestamp, sequence uint64, privateKey []byte) ([]byte, error) { + var err error + + additionalData, nonce := buildTokenCryptData(protocolId, expireTimestamp, sequence) + if token.TokenData.Buf, err = DecryptAead(token.TokenData.Buf, additionalData, nonce, privateKey); err != nil { + return nil, err + } + token.TokenData.Reset() // reset for reads + return token.TokenData.Buf, nil +} + +// Builds the additional data and nonce necessary for encryption and decryption. +func buildTokenCryptData(protocolId, expireTimestamp, sequence uint64) ([]byte, []byte) { + additionalData := NewBuffer(VERSION_INFO_BYTES + 8 + 8) + additionalData.WriteBytes([]byte(VERSION_INFO)) + additionalData.WriteUint64(protocolId) + additionalData.WriteUint64(expireTimestamp) + + nonce := NewBuffer(SizeUint64) + nonce.WriteUint64(sequence) + return additionalData.Buf, nonce.Buf +} diff --git a/go/netcode/connect_token_private_test.go b/go/netcode/connect_token_private_test.go new file mode 100644 index 0000000..1ad62ff --- /dev/null +++ b/go/netcode/connect_token_private_test.go @@ -0,0 +1,92 @@ +package netcode + +import ( + "bytes" + "net" + "testing" + "time" +) + +func TestConnectTokenPrivate(t *testing.T) { + token1 := NewConnectTokenPrivate() + server := net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 40000} + servers := make([]net.UDPAddr, 1) + servers[0] = server + + config := NewConfig(servers, TEST_TIMEOUT_SECONDS, TEST_CONNECT_TOKEN_EXPIRY, TEST_CLIENT_ID, TEST_PROTOCOL_ID, TEST_PRIVATE_KEY) + currentTimestamp := uint64(time.Now().Unix()) + expireTimestamp := uint64(currentTimestamp + config.TokenExpiry) + + userData, err := RandomBytes(USER_DATA_BYTES) + if err != nil { + t.Fatalf("error generating random bytes: %s\n", err) + } + + if err := token1.Generate(config, userData); err != nil { + t.Fatalf("error generating and encrypting token") + } + + if _, err := token1.Write(); err != nil { + t.Fatalf("error writing token private data") + } + + if err := token1.Encrypt(config.ProtocolId, expireTimestamp, TEST_SEQUENCE_START, config.PrivateKey); err != nil { + t.Fatalf("error encrypting token: %s\n", err) + } + + token2 := NewConnectTokenPrivate() + token2.TokenData = NewBufferFromBytes(token1.Buffer()) + + if _, err := token2.Decrypt(config.ProtocolId, expireTimestamp, TEST_SEQUENCE_START, config.PrivateKey); err != nil { + t.Fatalf("error decrypting token: %s", err) + } + + if err := token2.Read(); err != nil { + t.Fatalf("error reading token: %s\n", err) + } + + testComparePrivateTokens(token1, token2, t) + + token2.TokenData.Reset() + if _, err = token2.Write(); err != nil { + t.Fatalf("error writing token2 buffer") + } + + if err := token2.Encrypt(config.ProtocolId, expireTimestamp, TEST_SEQUENCE_START, config.PrivateKey); err != nil { + t.Fatalf("error encrypting second token: %s\n", err) + } + + if len(token1.Buffer()) != len(token2.Buffer()) { + t.Fatalf("encrypted buffer lengths did not match %d and %d\n", len(token1.Buffer()), len(token2.Buffer())) + } + + if bytes.Compare(token1.Buffer(), token2.Buffer()) != 0 { + t.Fatalf("encrypted private bits didn't match\n%#v\n and\n%#v\n", token1.Buffer(), token2.Buffer()) + } +} + +func testComparePrivateTokens(token1, token2 *ConnectTokenPrivate, t *testing.T) { + if token1.ClientId != token2.ClientId { + t.Fatalf("clientIds do not match expected %d got %d", token1.ClientId, token2.ClientId) + } + + if len(token1.ServerAddrs) != len(token2.ServerAddrs) { + t.Fatalf("time stamps do not match expected %d got %d", len(token1.ServerAddrs), len(token2.ServerAddrs)) + } + + token1Servers := token1.ServerAddrs + token2Servers := token2.ServerAddrs + for i := 0; i < len(token1.ServerAddrs); i += 1 { + if bytes.Compare([]byte(token1Servers[i].IP), []byte(token2Servers[i].IP)) != 0 { + t.Fatalf("server addresses did not match: expected %v got %v\n", token1Servers[i], token2Servers[i]) + } + } + + if bytes.Compare(token1.ClientKey, token2.ClientKey) != 0 { + t.Fatalf("ClientKey do not match expected %v got %v", token1.ClientKey, token2.ClientKey) + } + + if bytes.Compare(token1.ServerKey, token2.ServerKey) != 0 { + t.Fatalf("ServerKey do not match expected %v got %v", token1.ServerKey, token2.ServerKey) + } +} diff --git a/go/netcode/connect_token_shared.go b/go/netcode/connect_token_shared.go new file mode 100644 index 0000000..65c8949 --- /dev/null +++ b/go/netcode/connect_token_shared.go @@ -0,0 +1,126 @@ +package netcode + +import ( + "errors" + "net" + "strconv" +) + +// This struct contains data that is shared in both public and private parts of the +// connect token. +type sharedTokenData struct { + ServerAddrs []net.UDPAddr // list of server addresses this client may connect to + ClientKey []byte // client to server key + ServerKey []byte // server to client key +} + +// Reads and validates the servers, client <-> server keys. +func (shared *sharedTokenData) ReadShared(buffer *Buffer) error { + var err error + var servers uint32 + var ipBytes []byte + + servers, err = buffer.GetUint32() + if err != nil { + return err + } + + if servers <= 0 { + return errors.New("empty servers") + } + + if servers > MAX_SERVERS_PER_CONNECT { + return errors.New("too many servers") + } + + shared.ServerAddrs = make([]net.UDPAddr, servers) + + for i := 0; i < int(servers); i += 1 { + serverType, err := buffer.GetUint8() + if err != nil { + return err + } + + if serverType == ADDRESS_IPV4 { + ipBytes, err = buffer.GetBytes(4) + } else if serverType == ADDRESS_IPV6 { + ipBytes, err = buffer.GetBytes(16) + } else { + return errors.New("unknown ip address") + } + + if err != nil { + return err + } + + ip := net.IP(ipBytes) + port, err := buffer.GetUint16() + if err != nil { + return errors.New("invalid port") + } + shared.ServerAddrs[i] = net.UDPAddr{IP: ip, Port: int(port)} + } + + if shared.ClientKey, err = buffer.GetBytes(KEY_BYTES); err != nil { + return err + } + + if shared.ServerKey, err = buffer.GetBytes(KEY_BYTES); err != nil { + return err + } + + return nil +} + +// Writes the servers and client <-> server keys to the supplied buffer +func (shared *sharedTokenData) WriteShared(buffer *Buffer) error { + buffer.WriteUint32(uint32(len(shared.ServerAddrs))) + + for _, addr := range shared.ServerAddrs { + host, port, err := net.SplitHostPort(addr.String()) + if err != nil { + return errors.New("invalid port for host: " + addr.String()) + } + + parsed := net.ParseIP(host) + if parsed == nil { + return errors.New("invalid ip address") + } + + if len(parsed) == 4 { + buffer.WriteUint8(uint8(ADDRESS_IPV4)) + + } else { + buffer.WriteUint8(uint8(ADDRESS_IPV6)) + } + + for i := 0; i < len(parsed); i += 1 { + buffer.WriteUint8(parsed[i]) + } + + p, err := strconv.ParseUint(port, 10, 16) + if err != nil { + return err + } + buffer.WriteUint16(uint16(p)) + } + buffer.WriteBytesN(shared.ClientKey, KEY_BYTES) + buffer.WriteBytesN(shared.ServerKey, KEY_BYTES) + return nil +} + +// Generates the shared data, should only really be called by ConnectTokenPrivate +// since the same data will be copied/referenced by ConnectToken +func (shared *sharedTokenData) GenerateShared(config *Config) error { + var err error + + shared.ServerAddrs = config.ServerAddrs + if shared.ClientKey, err = GenerateKey(); err != nil { + return err + } + + if shared.ServerKey, err = GenerateKey(); err != nil { + return err + } + return nil +} diff --git a/go/netcode/connect_token_test.go b/go/netcode/connect_token_test.go new file mode 100644 index 0000000..6650836 --- /dev/null +++ b/go/netcode/connect_token_test.go @@ -0,0 +1,119 @@ +package netcode + +import ( + "bytes" + "net" + "testing" +) + +const ( + TEST_PROTOCOL_ID = 0x1122334455667788 + TEST_CONNECT_TOKEN_EXPIRY = 30 + TEST_SERVER_PORT = 40000 + TEST_CLIENT_ID = 0x1 + TEST_SEQUENCE_START = 1000 + TEST_TIMEOUT_SECONDS = 1 +) + +var TEST_PRIVATE_KEY = []byte{0x60, 0x6a, 0xbe, 0x6e, 0xc9, 0x19, 0x10, 0xea, + 0x9a, 0x65, 0x62, 0xf6, 0x6f, 0x2b, 0x30, 0xe4, + 0x43, 0x71, 0xd6, 0x2c, 0xd1, 0x99, 0x27, 0x26, + 0x6b, 0x3c, 0x60, 0xf4, 0xb7, 0x15, 0xab, 0xa1} + +func TestConnectToken(t *testing.T) { + var err error + var tokenBuffer []byte + var key []byte + + inToken := NewConnectToken() + server := net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 40000} + servers := make([]net.UDPAddr, 1) + servers[0] = server + + if key, err = GenerateKey(); err != nil { + t.Fatalf("error generating key %s\n", key) + } + + config := NewConfig(servers, TEST_TIMEOUT_SECONDS, TEST_CONNECT_TOKEN_EXPIRY, TEST_CLIENT_ID, TEST_PROTOCOL_ID, key) + + // generate will write & encrypt the ConnectTokenPrivate + err = inToken.Generate(config, TEST_SEQUENCE_START) + if err != nil { + t.Fatalf("error generating") + } + + // Writes the entire ConnectToken (including Private) + if tokenBuffer, err = inToken.Write(); err != nil { + t.Fatalf("error writing token: %s\n", err) + } + + outToken, err := ReadConnectToken(tokenBuffer) + if err != nil { + t.Fatalf("error re-reading back token buffer: %s\n", err) + } + + if string(inToken.VersionInfo) != string(outToken.VersionInfo) { + t.Fatalf("version info did not match expected: %s got: %s\n", inToken.VersionInfo, outToken.VersionInfo) + } + + if inToken.ProtocolId != outToken.ProtocolId { + t.Fatalf("ProtocolId did not match expected: %s got: %s\n", inToken.ProtocolId, outToken.ProtocolId) + } + + if inToken.CreateTimestamp != outToken.CreateTimestamp { + t.Fatalf("CreateTimestamp did not match expected: %s got: %s\n", inToken.CreateTimestamp, outToken.CreateTimestamp) + } + + if inToken.ExpireTimestamp != outToken.ExpireTimestamp { + t.Fatalf("ExpireTimestamp did not match expected: %s got: %s\n", inToken.ExpireTimestamp, outToken.ExpireTimestamp) + } + + if inToken.Sequence != outToken.Sequence { + t.Fatalf("Sequence did not match expected: %s got: %s\n", inToken.Sequence, outToken.Sequence) + } + + testCompareTokens(inToken, outToken, t) + + if bytes.Compare(inToken.PrivateData.Buffer(), outToken.PrivateData.Buffer()) != 0 { + t.Fatalf("encrypted private data of tokens did not match\n%#v\n%#v", inToken.PrivateData.Buffer(), outToken.PrivateData.Buffer()) + } + + // need to decrypt the private tokens before we can compare + if _, err := outToken.PrivateData.Decrypt(config.ProtocolId, outToken.ExpireTimestamp, outToken.Sequence, key); err != nil { + t.Fatalf("error decrypting private out token data: %s\n", err) + } + + if _, err := inToken.PrivateData.Decrypt(config.ProtocolId, inToken.ExpireTimestamp, inToken.Sequence, key); err != nil { + t.Fatalf("error decrypting private in token data: %s\n", err) + } + + // and re-read to set the properties in outToken private + if err := outToken.PrivateData.Read(); err != nil { + t.Fatalf("error reading private data %s", err) + } + + testComparePrivateTokens(inToken.PrivateData, outToken.PrivateData, t) + +} + +func testCompareTokens(token1, token2 *ConnectToken, t *testing.T) { + if len(token1.ServerAddrs) != len(token2.ServerAddrs) { + t.Fatalf("time stamps do not match expected %d got %d", len(token1.ServerAddrs), len(token2.ServerAddrs)) + } + + token1Servers := token1.ServerAddrs + token2Servers := token2.ServerAddrs + for i := 0; i < len(token1.ServerAddrs); i += 1 { + if bytes.Compare([]byte(token1Servers[i].IP), []byte(token2Servers[i].IP)) != 0 { + t.Fatalf("server addresses did not match: expected %v got %v\n", token1Servers[i], token2Servers[i]) + } + } + + if bytes.Compare(token1.ClientKey, token2.ClientKey) != 0 { + t.Fatalf("ClientKey do not match expected %v got %v", token1.ClientKey, token2.ClientKey) + } + + if bytes.Compare(token1.ServerKey, token2.ServerKey) != 0 { + t.Fatalf("ServerKey do not match expected %v got %v", token1.ServerKey, token2.ServerKey) + } +} diff --git a/go/netcode/crypto.go b/go/netcode/crypto.go new file mode 100644 index 0000000..43b210d --- /dev/null +++ b/go/netcode/crypto.go @@ -0,0 +1,38 @@ +package netcode + +import ( + "crypto/rand" + "github.com/codahale/chacha20poly1305" +) + +// Generates random bytes +func RandomBytes(bytes int) ([]byte, error) { + b := make([]byte, bytes) + _, err := rand.Read(b) + return b, err +} + +// Generates a random key of KEY_BYTES +func GenerateKey() ([]byte, error) { + return RandomBytes(KEY_BYTES) +} + +// Encrypts the message in place with the nonce and key and optional additional buffer +func EncryptAead(message *[]byte, additional, nonce, key []byte) error { + aead, err := chacha20poly1305.New(key) + if err != nil { + return err + } + *message = aead.Seal(nil, nonce, *message, additional) + return nil +} + +// Decrypts the message with the nonce and key and optional additional buffer returning a copy +// byte slice +func DecryptAead(message []byte, additional, nonce, key []byte) ([]byte, error) { + aead, err := chacha20poly1305.New(key) + if err != nil { + return nil, err + } + return aead.Open(nil, nonce, message, additional) +} \ No newline at end of file diff --git a/go/netcode/encryption_manager.go b/go/netcode/encryption_manager.go new file mode 100644 index 0000000..99ad84a --- /dev/null +++ b/go/netcode/encryption_manager.go @@ -0,0 +1 @@ +package netcode diff --git a/go/netcode/packet.go b/go/netcode/packet.go new file mode 100644 index 0000000..7cef99e --- /dev/null +++ b/go/netcode/packet.go @@ -0,0 +1,556 @@ +package netcode + +import ( + "errors" + "strconv" +) + +const MAX_CLIENTS = 60 +const CONNECT_TOKEN_PRIVATE_BYTES = 1024 +const CHALLENGE_TOKEN_BYTES = 300 +const VERSION_INFO_BYTES = 13 +const USER_DATA_BYTES = 256 +const MAX_PACKET_BYTES = 1220 +const MAX_PAYLOAD_BYTES = 1200 +const MAX_ADDRESS_STRING_LENGTH = 256 +const REPLAY_PROTECTION_BUFFER_SIZE = 256 +const CLIENT_MAX_RECEIVE_PACKETS = 64 +const SERVER_MAX_RECEIVE_PACKETS = (64 * MAX_CLIENTS) + +const KEY_BYTES = 32 +const MAC_BYTES = 16 +const NONCE_BYTES = 8 +const MAX_SERVERS_PER_CONNECT = 32 + +const VERSION_INFO = "NETCODE 1.00\x00" + +// Used for determining the type of packet, part of the serialization protocol +type PacketType uint8 + +const ( + ConnectionRequest PacketType = iota + ConnectionDenied + ConnectionChallenge + ConnectionResponse + ConnectionKeepAlive + ConnectionPayload + ConnectionDisconnect +) + +// reference map of packet -> string values +var packetTypeMap = map[PacketType]string{ + ConnectionRequest: "CONNECTION_REQUEST", + ConnectionDenied: "CONNECTION_DENIED", + ConnectionChallenge: "CONNECTION_CHALLENGE", + ConnectionResponse: "CONNECTION_RESPONSE", + ConnectionKeepAlive: "CONNECTION_KEEPALIVE", + ConnectionPayload: "CONNECTION_PAYLOAD", + ConnectionDisconnect: "CONNECTION_DISCONNECT", +} + +// not a packet type, but value is last packetType+1 +const ConnectionNumPackets = ConnectionDisconnect + 1 + +// Packet interface supporting reading and writing. +type Packet interface { + GetType() PacketType // returns the packet type + Write(buffer *Buffer, protocolId, sequence uint64, writePacketKey []byte) (int, error) // writes the packet data to the supplied buffer. + Read(packetBuffer *Buffer, packetLen int, protocolId, currentTimestamp uint64, readPacketKey, privateKey, allowedPackets []byte, replayProtection *ReplayProtection) error // reads in data from the supplied buffer to set the packet properties +} + +// The connection request packet +type RequestPacket struct { + VersionInfo []byte // version information of communications + ProtocolId uint64 // protocol id used in communications + ConnectTokenExpireTimestamp uint64 // when the connect token expires + ConnectTokenSequence uint64 // the sequence id of this token + Token *ConnectTokenPrivate // reference to the private parts of this packet + ConnectTokenData []byte // the encrypted Token after Write -> Encrypt +} + +// Writes the RequestPacket data to a supplied buffer and returns the length of bytes written to it. +func (p *RequestPacket) Write(buffer *Buffer, protocolId, sequence uint64, writePacketKey []byte) (int, error) { + buffer.WriteUint8(uint8(ConnectionRequest)) + buffer.WriteBytes(p.VersionInfo) + buffer.WriteUint64(p.ProtocolId) + buffer.WriteUint64(p.ConnectTokenExpireTimestamp) + buffer.WriteUint64(p.ConnectTokenSequence) + buffer.WriteBytes(p.ConnectTokenData) // write the encrypted connection token private data + if buffer.Pos != 1+13+8+8+8+CONNECT_TOKEN_PRIVATE_BYTES+MAC_BYTES { + return -1, errors.New("invalid buffer size written") + } + return buffer.Pos, nil +} + +// Reads a request packet and decrypts the connect token private data +func (p *RequestPacket) Read(packetBuffer *Buffer, packetLen int, protocolId, currentTimestamp uint64, readPacketKey, privateKey, allowedPackets []byte, replayProtection *ReplayProtection) error { + var err error + var packetType uint8 + + if packetType, err = packetBuffer.GetUint8(); err != nil || PacketType(packetType) != ConnectionRequest { + return errors.New("invalid packet type") + } + + if allowedPackets[0] == 0 { + return errors.New("ignored connection request packet. packet type is not allowed") + } + + if packetLen != 1+VERSION_INFO_BYTES+8+8+8+CONNECT_TOKEN_PRIVATE_BYTES+MAC_BYTES { + return errors.New("ignored connection request packet. bad packet length") + } + + if privateKey == nil { + return errors.New("ignored connection request packet. no private key\n") + } + + p.VersionInfo, err = packetBuffer.GetBytes(VERSION_INFO_BYTES) + if err != nil { + return errors.New("ignored connection request packet. bad version info invalid bytes returned\n") + } + + if string(p.VersionInfo) != VERSION_INFO { + return errors.New("ignored connection request packet. bad version info did not match\n") + } + + p.ProtocolId, err = packetBuffer.GetUint64() + if err != nil || p.ProtocolId != protocolId { + return errors.New("ignored connection request packet. wrong protocol id\n") + } + + p.ConnectTokenExpireTimestamp, err = packetBuffer.GetUint64() + if err != nil || p.ConnectTokenExpireTimestamp <= currentTimestamp { + return errors.New("ignored connection request packet. connect token expired\n") + } + + p.ConnectTokenSequence, err = packetBuffer.GetUint64() + if err != nil { + return err + } + + if packetBuffer.Pos != 1+VERSION_INFO_BYTES+8+8+8 { + return errors.New("invalid length of packet buffer read") + } + + var tokenBuffer []byte + tokenBuffer, err = packetBuffer.GetBytes(CONNECT_TOKEN_PRIVATE_BYTES + MAC_BYTES) + if err != nil { + return err + } + + p.Token = NewConnectTokenPrivateEncrypted(tokenBuffer) + if _, err := p.Token.Decrypt(p.ProtocolId, p.ConnectTokenExpireTimestamp, p.ConnectTokenSequence, privateKey); err != nil { + return errors.New("error decrypting connect token private data: " + err.Error()) + } + + if err := p.Token.Read(); err != nil { + return errors.New("error reading decrypted connect token private data: " + err.Error()) + } + + return nil +} + +func (p *RequestPacket) GetType() PacketType { + return ConnectionRequest +} + +// Denied packet type, contains no information +type DeniedPacket struct { +} + +func (p *DeniedPacket) Write(buffer *Buffer, protocolId, sequence uint64, writePacketKey []byte) (int, error) { + prefixByte, err := writePacketPrefix(p, buffer, sequence) + if err != nil { + return -1, err + } + + // denied packets are empty + return encryptPacket(buffer, buffer.Pos, buffer.Pos, prefixByte, protocolId, sequence, writePacketKey) +} + +func (p *DeniedPacket) Read(packetBuffer *Buffer, packetLen int, protocolId, currentTimestamp uint64, readPacketKey, privateKey, allowedPackets []byte, replayProtection *ReplayProtection) error { + decryptedBuf, err := decryptPacket(packetBuffer, packetLen, protocolId, currentTimestamp, readPacketKey, privateKey, allowedPackets, replayProtection) + if err != nil { + return err + } + + if decryptedBuf.Len() != 0 { + return errors.New("ignored connection denied packet. decrypted packet data is wrong size") + } + return nil +} + +func (p *DeniedPacket) GetType() PacketType { + return ConnectionDenied +} + +// Challenge packet containing token data and the sequence id used +type ChallengePacket struct { + ChallengeTokenSequence uint64 + ChallengeTokenData []byte +} + +func (p *ChallengePacket) Write(buffer *Buffer, protocolId, sequence uint64, writePacketKey []byte) (int, error) { + prefixByte, err := writePacketPrefix(p, buffer, sequence) + if err != nil { + return -1, err + } + + encryptedStart := buffer.Pos + buffer.WriteUint64(p.ChallengeTokenSequence) + buffer.WriteBytesN(p.ChallengeTokenData, CHALLENGE_TOKEN_BYTES) + encryptedFinish := buffer.Pos + return encryptPacket(buffer, encryptedStart, encryptedFinish, prefixByte, protocolId, sequence, writePacketKey) +} + +func (p *ChallengePacket) Read(packetBuffer *Buffer, packetLen int, protocolId, currentTimestamp uint64, readPacketKey, privateKey, allowedPackets []byte, replayProtection *ReplayProtection) error { + decryptedBuf, err := decryptPacket(packetBuffer, packetLen, protocolId, currentTimestamp, readPacketKey, privateKey, allowedPackets, replayProtection) + if err != nil { + return err + } + + if decryptedBuf.Len() != 8+CHALLENGE_TOKEN_BYTES { + return errors.New("ignored connection challenge packet. decrypted packet data is wrong size") + } + + p.ChallengeTokenSequence, err = decryptedBuf.GetUint64() + if err != nil { + return errors.New("error reading challenge token sequence") + } + + p.ChallengeTokenData, err = decryptedBuf.GetBytes(CHALLENGE_TOKEN_BYTES) + if err != nil { + return errors.New("error reading challenge token data") + } + + return nil +} + +func (p *ChallengePacket) GetType() PacketType { + return ConnectionChallenge +} + +// Response packet, containing the token data and sequence id +type ResponsePacket struct { + ChallengeTokenSequence uint64 + ChallengeTokenData []byte +} + +func (p *ResponsePacket) Write(buffer *Buffer, protocolId, sequence uint64, writePacketKey []byte) (int, error) { + prefixByte, err := writePacketPrefix(p, buffer, sequence) + if err != nil { + return -1, err + } + + encryptedStart := buffer.Pos + buffer.WriteUint64(p.ChallengeTokenSequence) + buffer.WriteBytesN(p.ChallengeTokenData, CHALLENGE_TOKEN_BYTES) + encryptedFinish := buffer.Pos + return encryptPacket(buffer, encryptedStart, encryptedFinish, prefixByte, protocolId, sequence, writePacketKey) +} + +func (p *ResponsePacket) Read(packetBuffer *Buffer, packetLen int, protocolId, currentTimestamp uint64, readPacketKey, privateKey, allowedPackets []byte, replayProtection *ReplayProtection) error { + decryptedBuf, err := decryptPacket(packetBuffer, packetLen, protocolId, currentTimestamp, readPacketKey, privateKey, allowedPackets, replayProtection) + if err != nil { + return err + } + + if decryptedBuf.Len() != 8+CHALLENGE_TOKEN_BYTES { + return errors.New("ignored connection challenge packet. decrypted packet data is wrong size") + } + + p.ChallengeTokenSequence, err = decryptedBuf.GetUint64() + if err != nil { + return errors.New("error reading challenge token sequence") + } + + p.ChallengeTokenData, err = decryptedBuf.GetBytes(CHALLENGE_TOKEN_BYTES) + if err != nil { + return errors.New("error reading challenge token data") + } + + return nil +} + +func (p *ResponsePacket) GetType() PacketType { + return ConnectionResponse +} + +// used for heart beats +type KeepAlivePacket struct { + ClientIndex uint32 + MaxClients uint32 +} + +func (p *KeepAlivePacket) Write(buffer *Buffer, protocolId, sequence uint64, writePacketKey []byte) (int, error) { + prefixByte, err := writePacketPrefix(p, buffer, sequence) + if err != nil { + return -1, err + } + + encryptedStart := buffer.Pos + buffer.WriteUint32(uint32(p.ClientIndex)) + buffer.WriteUint32(uint32(p.MaxClients)) + encryptedFinish := buffer.Pos + return encryptPacket(buffer, encryptedStart, encryptedFinish, prefixByte, protocolId, sequence, writePacketKey) +} + +func (p *KeepAlivePacket) Read(packetBuffer *Buffer, packetLen int, protocolId, currentTimestamp uint64, readPacketKey, privateKey, allowedPackets []byte, replayProtection *ReplayProtection) error { + decryptedBuf, err := decryptPacket(packetBuffer, packetLen, protocolId, currentTimestamp, readPacketKey, privateKey, allowedPackets, replayProtection) + if err != nil { + return err + } + + if decryptedBuf.Len() != 8 { + return errors.New("ignored connection keep alive packet. decrypted packet data is wrong size") + } + + p.ClientIndex, err = decryptedBuf.GetUint32() + if err != nil { + return errors.New("error reading keepalive client index") + } + + p.MaxClients, err = decryptedBuf.GetUint32() + if err != nil { + return errors.New("error reading keepalive max clients") + } + + return nil +} + +func (p *KeepAlivePacket) GetType() PacketType { + return ConnectionKeepAlive +} + +// Contains user supplied payload data between server <-> client +type PayloadPacket struct { + PayloadBytes uint32 + PayloadData []byte +} + +func (p *PayloadPacket) GetType() PacketType { + return ConnectionPayload +} + +// Helper function to create a new payload packet with the supplied buffer +func NewPayloadPacket(payloadData []byte) *PayloadPacket { + packet := &PayloadPacket{} + packet.PayloadBytes = uint32(len(payloadData)) + packet.PayloadData = payloadData + return packet +} + +func (p *PayloadPacket) Write(buffer *Buffer, protocolId, sequence uint64, writePacketKey []byte) (int, error) { + prefixByte, err := writePacketPrefix(p, buffer, sequence) + if err != nil { + return -1, err + } + + encryptedStart := buffer.Pos + buffer.WriteBytesN([]byte(p.PayloadData), int(p.PayloadBytes)) + encryptedFinish := buffer.Pos + return encryptPacket(buffer, encryptedStart, encryptedFinish, prefixByte, protocolId, sequence, writePacketKey) +} + +func (p *PayloadPacket) Read(packetBuffer *Buffer, packetLen int, protocolId, currentTimestamp uint64, readPacketKey, privateKey, allowedPackets []byte, replayProtection *ReplayProtection) error { + decryptedBuf, err := decryptPacket(packetBuffer, packetLen, protocolId, currentTimestamp, readPacketKey, privateKey, allowedPackets, replayProtection) + if err != nil { + return err + } + + decryptedSize := uint32(decryptedBuf.Len()) + if decryptedSize < 1 { + return errors.New("ignored connection payload packet. payload is too small") + } + + if decryptedSize > MAX_PAYLOAD_BYTES { + return errors.New("ignored connection payload packet. payload is too large") + } + + p.PayloadBytes = decryptedSize + p.PayloadData = decryptedBuf.Bytes() + return nil +} + +// Signals to server/client to disconnect, contains no data. +type DisconnectPacket struct { +} + +func (p *DisconnectPacket) Write(buffer *Buffer, protocolId, sequence uint64, writePacketKey []byte) (int, error) { + prefixByte, err := writePacketPrefix(p, buffer, sequence) + if err != nil { + return -1, err + } + + // denied packets are empty + return encryptPacket(buffer, buffer.Pos, buffer.Pos, prefixByte, protocolId, sequence, writePacketKey) +} + +func (p *DisconnectPacket) Read(packetBuffer *Buffer, packetLen int, protocolId, currentTimestamp uint64, readPacketKey, privateKey, allowedPackets []byte, replayProtection *ReplayProtection) error { + decryptedBuf, err := decryptPacket(packetBuffer, packetLen, protocolId, currentTimestamp, readPacketKey, privateKey, allowedPackets, replayProtection) + if err != nil { + return err + } + + if decryptedBuf.Len() != 0 { + return errors.New("ignored connection denied packet. decrypted packet data is wrong size") + } + return nil +} + +func (p *DisconnectPacket) GetType() PacketType { + return ConnectionDisconnect +} + +// Decrypts the packet after reading in the prefix byte and sequence id. Used for all PacketTypes except RequestPacket. Returns a buffer containing the decrypted data +func decryptPacket(packetBuffer *Buffer, packetLen int, protocolId, currentTimestamp uint64, readPacketKey, privateKey, allowedPackets []byte, replayProtection *ReplayProtection) (*Buffer, error) { + var packetSequence uint64 + + prefixByte, err := packetBuffer.GetUint8() + if err != nil { + return nil, errors.New("invalid buffer length") + } + + if packetSequence, err = readSequence(packetBuffer, packetLen, prefixByte); err != nil { + return nil, err + } + + if err := validateSequence(packetLen, prefixByte, packetSequence, readPacketKey, allowedPackets, replayProtection); err != nil { + return nil, err + } + + // decrypt the per-packet type data + additionalData, nonce := packetCryptData(prefixByte, protocolId, packetSequence) + + encryptedSize := packetLen - packetBuffer.Pos + if encryptedSize < MAC_BYTES { + return nil, errors.New("ignored encrypted packet. encrypted payload is too small") + } + + encryptedBuff, err := packetBuffer.GetBytes(encryptedSize) + if err != nil { + return nil, errors.New("ignored encrypted packet. encrypted payload is too small") + } + + decryptedBuff, err := DecryptAead(encryptedBuff, additionalData, nonce, readPacketKey) + if err != nil { + return nil, errors.New("ignored encrypted packet. failed to decrypt: " + err.Error()) + } + + return NewBufferFromBytes(decryptedBuff), nil +} + +// Reads and verifies the sequence id +func readSequence(packetBuffer *Buffer, packetLen int, prefixByte uint8) (uint64, error) { + var sequence uint64 + + sequenceBytes := prefixByte >> 4 + if sequenceBytes < 1 || sequenceBytes > 8 { + return 0, errors.New("ignored encrypted packet. sequence bytes is out of range [1,8]") + } + + if packetLen < 1+int(sequenceBytes)+MAC_BYTES { + return 0, errors.New("ignored encrypted packet. buffer is too small for sequence bytes + encryption mac") + } + + var i uint8 + // read variable length sequence number [1,8] + for i = 0; i < sequenceBytes; i += 1 { + val, err := packetBuffer.GetUint8() + if err != nil { + return 0, err + } + sequence |= (uint64(val) << (8 * i)) + } + return sequence, nil +} + +// Validates the data prior to the encrypted segment before we bother attempting to decrypt. +func validateSequence(packetLen int, prefixByte uint8, sequence uint64, readPacketKey, allowedPackets []byte, replayProtection *ReplayProtection) error { + + if readPacketKey == nil { + return errors.New("empty packet key") + } + + if packetLen < 1+1+MAC_BYTES { + return errors.New("ignored encrypted packet. packet is too small to be valid") + } + + packetType := prefixByte & 0xF + if PacketType(packetType) >= ConnectionNumPackets { + return errors.New("ignored encrypted packet. packet type " + packetTypeMap[PacketType(packetType)] + " is invalid") + } + + if allowedPackets[packetType] == 0 { + return errors.New("ignored encrypted packet. packet type " + packetTypeMap[PacketType(packetType)] + " is invalid") + } + + // replay protection (optional) + if replayProtection != nil && PacketType(packetType) >= ConnectionKeepAlive { + if replayProtection.AlreadyReceived(sequence) == 1 { + v := strconv.FormatUint(sequence, 10) + return errors.New("ignored connection payload packet. sequence " + v + " already received (replay protection)") + } + } + return nil +} + +// write the prefix byte (this is a combination of the packet type and number of sequence bytes) +func writePacketPrefix(p Packet, buffer *Buffer, sequence uint64) (uint8, error) { + sequenceBytes := sequenceNumberBytesRequired(sequence) + if sequenceBytes < 1 || sequenceBytes > 8 { + return 0, errors.New("invalid sequence bytes, must be between [1-8]") + } + + prefixByte := uint8(p.GetType()) | uint8(sequenceBytes<<4) + buffer.WriteUint8(prefixByte) + + sequenceTemp := sequence + + var i uint8 + for ; i < sequenceBytes; i += 1 { + buffer.WriteUint8(uint8(sequenceTemp & 0xFF)) + sequenceTemp >>= 8 + } + return prefixByte, nil +} + +// Encrypts the packet data of the supplied buffer between encryptedStart and encrypedFinish. +func encryptPacket(buffer *Buffer, encryptedStart, encryptedFinish int, prefixByte uint8, protocolId, sequence uint64, writePacketKey []byte) (int, error) { + // slice up the buffer for the bits we will encrypt + encryptedBuffer := buffer.Buf[encryptedStart:encryptedFinish] + + additionalData, nonce := packetCryptData(prefixByte, protocolId, sequence) + if err := EncryptAead(&encryptedBuffer, additionalData, nonce, writePacketKey); err != nil { + return -1, err + } + + buffer.Pos = encryptedStart // reset position to start of where the encrypted data goes + buffer.WriteBytes(encryptedBuffer) + return buffer.Pos, nil // in c, we do Pos + MAC_BYTES but the WriteBytes will update buffer.Pos to include it +} + +// used for encrypting the per-packet packet written with the prefix byte, protocol id and version as the associated data. this must match to decrypt. +func packetCryptData(prefixByte uint8, protocolId, sequence uint64) ([]byte, []byte) { + additionalData := NewBuffer(VERSION_INFO_BYTES + 8 + 1) + additionalData.WriteBytesN([]byte(VERSION_INFO), VERSION_INFO_BYTES) + additionalData.WriteUint64(protocolId) + additionalData.WriteUint8(prefixByte) + + nonce := NewBuffer(SizeUint64) + nonce.WriteUint64(sequence) + return additionalData.Buf, nonce.Buf +} + +// Depending on size of sequence number, we need to reserve N bytes +func sequenceNumberBytesRequired(sequence uint64) uint8 { + var mask uint64 + mask = 0xFF00000000000000 + var i uint8 + for ; i < 7; i += 1 { + if sequence&mask != 0 { + break + } + mask >>= 8 + } + return 8 - i +} diff --git a/go/netcode/packet_test.go b/go/netcode/packet_test.go new file mode 100644 index 0000000..f339e12 --- /dev/null +++ b/go/netcode/packet_test.go @@ -0,0 +1,407 @@ +package netcode + +import ( + "bytes" + "net" + "testing" + "time" +) + +func TestSequence(t *testing.T) { + seq := sequenceNumberBytesRequired(0) + if seq != 1 { + t.Fatal("expected 0, got: ", seq) + } + + seq = sequenceNumberBytesRequired(0x11) + if seq != 1 { + t.Fatal("expected 1, got: ", seq) + } + + seq = sequenceNumberBytesRequired(0x1122) + if seq != 2 { + t.Fatal("expected 2, got: ", seq) + } + + seq = sequenceNumberBytesRequired(0x112233) + if seq != 3 { + t.Fatal("expected 3, got: ", seq) + } + + seq = sequenceNumberBytesRequired(0x11223344) + if seq != 4 { + t.Fatal("expected 4, got: ", seq) + } + + seq = sequenceNumberBytesRequired(0x1122334455) + if seq != 5 { + t.Fatal("expected 5, got: ", seq) + } + + seq = sequenceNumberBytesRequired(0x112233445566) + if seq != 6 { + t.Fatal("expected 6, got: ", seq) + } + + seq = sequenceNumberBytesRequired(0x11223344556677) + if seq != 7 { + t.Fatal("expected 7, got: ", seq) + } + + seq = sequenceNumberBytesRequired(0x1122334455667788) + if seq != 8 { + t.Fatal("expected 8, got: ", seq) + } + +} + +func TestConnectionRequestPacket(t *testing.T) { + connectTokenKey, err := GenerateKey() + if err != nil { + t.Fatalf("error generating connect token key: %s\n", err) + } + inputPacket, decryptedToken := testBuildRequestPacket(connectTokenKey, t) + + // write the connection request packet to a buffer + buffer := NewBuffer(2048) + + packetKey, err := GenerateKey() + if err != nil { + t.Fatalf("error generating key") + } + + bytesWritten, err := inputPacket.Write(buffer, TEST_PROTOCOL_ID, TEST_SEQUENCE_START, packetKey) + if err != nil { + t.Fatalf("error writing packet: %s\n", err) + } + + if bytesWritten <= 0 { + t.Fatalf("did not write any bytes for this packet") + } + + allowedPackets := make([]byte, ConnectionNumPackets) + for i := 0; i < len(allowedPackets); i += 1 { + allowedPackets[i] = 1 + } + + outputPacket := &RequestPacket{} + + buffer.Reset() + if err := outputPacket.Read(buffer, bytesWritten, TEST_PROTOCOL_ID, uint64(time.Now().Unix()), packetKey, connectTokenKey, allowedPackets, nil); err != nil { + t.Fatalf("error reading packet: %s\n", err) + } + + if bytes.Compare(inputPacket.VersionInfo, outputPacket.VersionInfo) != 0 { + t.Fatalf("version info did not match") + } + + if inputPacket.ProtocolId != outputPacket.ProtocolId { + t.Fatalf("ProtocolId did not match") + } + + if inputPacket.ConnectTokenExpireTimestamp != outputPacket.ConnectTokenExpireTimestamp { + t.Fatalf("ConnectTokenExpireTimestamp did not match") + } + + if inputPacket.ConnectTokenSequence != outputPacket.ConnectTokenSequence { + t.Fatalf("ConnectTokenSequence did not match") + } + + if bytes.Compare(decryptedToken, outputPacket.Token.TokenData.Buf) != 0 { + t.Fatalf("TokenData did not match") + } +} + +func TestConnectionDeniedPacket(t *testing.T) { + // setup a connection denied packet + inputPacket := &DeniedPacket{} + + buffer := NewBuffer(MAX_PACKET_BYTES) + + packetKey, err := GenerateKey() + if err != nil { + t.Fatalf("error generating key") + } + + // write the packet to a buffer + bytesWritten, err := inputPacket.Write(buffer, TEST_PROTOCOL_ID, TEST_SEQUENCE_START, packetKey) + if err != nil { + t.Fatalf("error writing packet: %s\n", err) + } + + if bytesWritten <= 0 { + t.Fatalf("did not write any bytes for this packet") + } + + allowedPackets := make([]byte, ConnectionNumPackets) + for i := 0; i < len(allowedPackets); i += 1 { + allowedPackets[i] = 1 + } + + outputPacket := &DeniedPacket{} + buffer.Reset() + if err := outputPacket.Read(buffer, bytesWritten, TEST_PROTOCOL_ID, uint64(time.Now().Unix()), packetKey, nil, allowedPackets, nil); err != nil { + t.Fatalf("error reading packet: %s\n", err) + } + + if outputPacket.GetType() != ConnectionDenied { + t.Fatalf("did not get a denied packet after read") + } +} + +func TestConnectionChallengePacket(t *testing.T) { + var err error + + // setup a connection challenge packet + inputPacket := &ChallengePacket{} + inputPacket.ChallengeTokenSequence = 0 + inputPacket.ChallengeTokenData, err = RandomBytes(CHALLENGE_TOKEN_BYTES) + if err != nil { + t.Fatalf("error generating random bytes") + } + + buffer := NewBuffer(MAX_PACKET_BYTES) + + packetKey, err := GenerateKey() + if err != nil { + t.Fatalf("error generating key") + } + + // write the packet to a buffer + bytesWritten, err := inputPacket.Write(buffer, TEST_PROTOCOL_ID, TEST_SEQUENCE_START, packetKey) + if err != nil { + t.Fatalf("error writing packet: %s\n", err) + } + + if bytesWritten <= 0 { + t.Fatalf("did not write any bytes for this packet") + } + + allowedPackets := make([]byte, ConnectionNumPackets) + for i := 0; i < len(allowedPackets); i += 1 { + allowedPackets[i] = 1 + } + + outputPacket := &ChallengePacket{} + buffer.Reset() + if err := outputPacket.Read(buffer, bytesWritten, TEST_PROTOCOL_ID, uint64(time.Now().Unix()), packetKey, nil, allowedPackets, nil); err != nil { + t.Fatalf("error reading packet: %s\n", err) + } + + if inputPacket.ChallengeTokenSequence != outputPacket.ChallengeTokenSequence { + t.Fatalf("input and output sequence differed, expected %d got %d\n", inputPacket.ChallengeTokenSequence, outputPacket.ChallengeTokenSequence) + } + + if bytes.Compare(inputPacket.ChallengeTokenData, outputPacket.ChallengeTokenData) != 0 { + t.Fatalf("challenge token data was not equal\n") + } +} + +func TestConnectionResponsePacket(t *testing.T) { + var err error + + // setup a connection response packet + inputPacket := &ResponsePacket{} + inputPacket.ChallengeTokenSequence = 0 + inputPacket.ChallengeTokenData, err = RandomBytes(CHALLENGE_TOKEN_BYTES) + if err != nil { + t.Fatalf("error generating random bytes") + } + + buffer := NewBuffer(MAX_PACKET_BYTES) + + packetKey, err := GenerateKey() + if err != nil { + t.Fatalf("error generating key") + } + + // write the packet to a buffer + bytesWritten, err := inputPacket.Write(buffer, TEST_PROTOCOL_ID, TEST_SEQUENCE_START, packetKey) + if err != nil { + t.Fatalf("error writing packet: %s\n", err) + } + + if bytesWritten <= 0 { + t.Fatalf("did not write any bytes for this packet") + } + + allowedPackets := make([]byte, ConnectionNumPackets) + for i := 0; i < len(allowedPackets); i += 1 { + allowedPackets[i] = 1 + } + + outputPacket := &ResponsePacket{} + buffer.Reset() + if err := outputPacket.Read(buffer, bytesWritten, TEST_PROTOCOL_ID, uint64(time.Now().Unix()), packetKey, nil, allowedPackets, nil); err != nil { + t.Fatalf("error reading packet: %s\n", err) + } + + if inputPacket.ChallengeTokenSequence != outputPacket.ChallengeTokenSequence { + t.Fatalf("input and output sequence differed, expected %d got %d\n", inputPacket.ChallengeTokenSequence, outputPacket.ChallengeTokenSequence) + } + + if bytes.Compare(inputPacket.ChallengeTokenData, outputPacket.ChallengeTokenData) != 0 { + t.Fatalf("response challenge token data was not equal\n") + } +} + +func TestConnectionKeepAlivePacket(t *testing.T) { + var err error + + // setup a connection challenge packet + inputPacket := &KeepAlivePacket{} + inputPacket.ClientIndex = 10 + inputPacket.MaxClients = 16 + + buffer := NewBuffer(MAX_PACKET_BYTES) + + packetKey, err := GenerateKey() + if err != nil { + t.Fatalf("error generating key") + } + + // write the packet to a buffer + bytesWritten, err := inputPacket.Write(buffer, TEST_PROTOCOL_ID, TEST_SEQUENCE_START, packetKey) + if err != nil { + t.Fatalf("error writing packet: %s\n", err) + } + + if bytesWritten <= 0 { + t.Fatalf("did not write any bytes for this packet") + } + + allowedPackets := make([]byte, ConnectionNumPackets) + for i := 0; i < len(allowedPackets); i += 1 { + allowedPackets[i] = 1 + } + + outputPacket := &KeepAlivePacket{} + buffer.Reset() + if err := outputPacket.Read(buffer, bytesWritten, TEST_PROTOCOL_ID, uint64(time.Now().Unix()), packetKey, nil, allowedPackets, nil); err != nil { + t.Fatalf("error reading packet: %s\n", err) + } + + if inputPacket.ClientIndex != outputPacket.ClientIndex { + t.Fatalf("input and output index differed, expected %d got %d\n", inputPacket.ClientIndex, outputPacket.ClientIndex) + } + + if inputPacket.MaxClients != outputPacket.MaxClients { + t.Fatalf("input and output maxclients differed, expected %d got %d\n", inputPacket.MaxClients, outputPacket.MaxClients) + } +} + +func TestConnectionPayloadPacket(t *testing.T) { + var err error + payloadData, err := RandomBytes(MAX_PAYLOAD_BYTES) + if err != nil { + t.Fatalf("error generating random payload data: %s\n", err) + } + + inputPacket := NewPayloadPacket(payloadData) + + buffer := NewBuffer(MAX_PACKET_BYTES) + + packetKey, err := GenerateKey() + if err != nil { + t.Fatalf("error generating key") + } + + // write the packet to a buffer + bytesWritten, err := inputPacket.Write(buffer, TEST_PROTOCOL_ID, TEST_SEQUENCE_START, packetKey) + if err != nil { + t.Fatalf("error writing packet: %s\n", err) + } + + if bytesWritten <= 0 { + t.Fatalf("did not write any bytes for this packet") + } + + allowedPackets := make([]byte, ConnectionNumPackets) + for i := 0; i < len(allowedPackets); i += 1 { + allowedPackets[i] = 1 + } + + buffer.Reset() + outputPacket := &PayloadPacket{} + + if err := outputPacket.Read(buffer, bytesWritten, TEST_PROTOCOL_ID, uint64(time.Now().Unix()), packetKey, nil, allowedPackets, nil); err != nil { + t.Fatalf("error reading packet: %s\n", err) + } + + if inputPacket.PayloadBytes != outputPacket.PayloadBytes { + t.Fatalf("input and output index differed, expected %d got %d\n", inputPacket.PayloadBytes, outputPacket.PayloadBytes) + } + + if bytes.Compare(inputPacket.PayloadData, outputPacket.PayloadData) != 0 { + t.Fatalf("input and output payload differed, expected %v got %v\n", inputPacket.PayloadData, outputPacket.PayloadData) + } +} + +func TestDisconnectPacket(t *testing.T) { + inputPacket := &DisconnectPacket{} + buffer := NewBuffer(MAX_PACKET_BYTES) + + packetKey, err := GenerateKey() + if err != nil { + t.Fatalf("error generating key") + } + + // write the packet to a buffer + bytesWritten, err := inputPacket.Write(buffer, TEST_PROTOCOL_ID, TEST_SEQUENCE_START, packetKey) + if err != nil { + t.Fatalf("error writing packet: %s\n", err) + } + + if bytesWritten <= 0 { + t.Fatalf("did not write any bytes for this packet") + } + + allowedPackets := make([]byte, ConnectionNumPackets) + for i := 0; i < len(allowedPackets); i += 1 { + allowedPackets[i] = 1 + } + + buffer.Reset() + outputPacket := &DisconnectPacket{} + if err := outputPacket.Read(buffer, bytesWritten, TEST_PROTOCOL_ID, uint64(time.Now().Unix()), packetKey, nil, allowedPackets, nil); err != nil { + t.Fatalf("error reading packet: %s\n", err) + } + +} + +func testBuildRequestPacket(connectTokenKey []byte, t *testing.T) (*RequestPacket, []byte) { + addr := net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: TEST_SERVER_PORT} + serverAddrs := make([]net.UDPAddr, 1) + serverAddrs[0] = addr + config := NewConfig(serverAddrs, TEST_TIMEOUT_SECONDS, TEST_CONNECT_TOKEN_EXPIRY, TEST_CLIENT_ID, TEST_PROTOCOL_ID, connectTokenKey) + + connectToken := NewConnectToken() + + if err := connectToken.Generate(config, TEST_SEQUENCE_START); err != nil { + t.Fatalf("error generating connect token: %s\n", err) + } + + _, err := connectToken.Write() + if err != nil { + t.Fatalf("error writing private data: %s\n", err) + } + + decryptedToken, err := connectToken.PrivateData.Decrypt(TEST_PROTOCOL_ID, connectToken.ExpireTimestamp, TEST_SEQUENCE_START, connectTokenKey) + if err != nil { + t.Fatalf("error decrypting connect token: %s", err) + } + // need to re-encrypt the private data + if err := connectToken.PrivateData.Encrypt(TEST_PROTOCOL_ID, connectToken.ExpireTimestamp, TEST_SEQUENCE_START, connectTokenKey); err != nil { + t.Fatalf("error re-encrypting connect private token: %s\n", err) + } + + // setup a connection request packet wrapping the encrypted connect token + inputPacket := &RequestPacket{} + inputPacket.VersionInfo = []byte(VERSION_INFO) + inputPacket.ProtocolId = TEST_PROTOCOL_ID + inputPacket.ConnectTokenExpireTimestamp = connectToken.ExpireTimestamp + inputPacket.ConnectTokenSequence = TEST_SEQUENCE_START + inputPacket.Token = connectToken.PrivateData + inputPacket.ConnectTokenData = connectToken.PrivateData.Buffer() + return inputPacket, decryptedToken +} diff --git a/go/netcode/queue.go b/go/netcode/queue.go new file mode 100644 index 0000000..6dfc01a --- /dev/null +++ b/go/netcode/queue.go @@ -0,0 +1,44 @@ +package netcode + +const PACKET_QUEUE_SIZE = 256 + +type Queue struct { + NumPackets int + StartIndex int + Packets []Packet +} + +func NewQueue() *Queue { + q := &Queue{} + q.Packets = make([]Packet, PACKET_QUEUE_SIZE) + return q +} + +func (q *Queue) Clear() { + q.NumPackets = 0 + q.StartIndex = 0 + q.Packets = make([]Packet, PACKET_QUEUE_SIZE) +} + +func (q *Queue) Push(packet Packet) int { + if q.NumPackets == PACKET_QUEUE_SIZE { + return 0 + } + + index := (q.StartIndex + q.NumPackets) % PACKET_QUEUE_SIZE + q.Packets[index] = packet + q.NumPackets++ + return 1 +} + +func (q *Queue) Pop() Packet { + if q.NumPackets == 0 { + return nil + } + + packet := q.Packets[q.StartIndex] + q.StartIndex = ( q.StartIndex + 1 ) % PACKET_QUEUE_SIZE + q.NumPackets-- + return packet +} + diff --git a/go/netcode/replay_protection.go b/go/netcode/replay_protection.go new file mode 100644 index 0000000..1487ead --- /dev/null +++ b/go/netcode/replay_protection.go @@ -0,0 +1,56 @@ +package netcode + +// Our type to hold replay protection of packet sequences +type ReplayProtection struct { + MostRecentSequence uint64 // last sequence recv'd + ReceivedPacket []uint64 // slice of REPLAY_PROTECTION_BUFFER_SIZE worth of packet sequences +} + +// Initializes a new ReplayProtection with the ReceivedPacket buffer elements all set to 0xFFFFFFFFFFFFFFFF +func NewReplayProtection() *ReplayProtection { + r := &ReplayProtection{} + r.ReceivedPacket = make([]uint64, REPLAY_PROTECTION_BUFFER_SIZE) + r.Reset() + return r +} + +// Clears out the most recent sequence and resets the entire packet buffer to 0xFFFFFFFFFFFFFFFF +func (r *ReplayProtection) Reset() { + r.MostRecentSequence = 0 + clearPacketBuffer(r.ReceivedPacket) +} + +// Tests that the sequence has not already been recv'd, adding it to the buffer if it's new. +func (r *ReplayProtection) AlreadyReceived(sequence uint64) int { + if sequence&(uint64(1<<63)) != 0 { + return 0 + } + + if sequence+REPLAY_PROTECTION_BUFFER_SIZE <= r.MostRecentSequence { + return 1 + } + + if sequence > r.MostRecentSequence { + r.MostRecentSequence = sequence + } + + index := sequence % REPLAY_PROTECTION_BUFFER_SIZE + + if r.ReceivedPacket[index] == 0xFFFFFFFFFFFFFFFF { + r.ReceivedPacket[index] = sequence + return 0 + } + + if r.ReceivedPacket[index] >= sequence { + return 1 + } + + r.ReceivedPacket[index] = sequence + return 0 +} + +func clearPacketBuffer(packets []uint64) { + for i := 0; i < len(packets); i += 1 { + packets[i] = 0xFFFFFFFFFFFFFFFF + } +} diff --git a/go/netcode/replay_protection_test.go b/go/netcode/replay_protection_test.go new file mode 100644 index 0000000..dca619f --- /dev/null +++ b/go/netcode/replay_protection_test.go @@ -0,0 +1,66 @@ +package netcode + +import ( + "testing" +) + +func TestReplayProtection_Reset(t *testing.T) { + r := NewReplayProtection() + for _, p := range r.ReceivedPacket { + if p != 0xFFFFFFFFFFFFFFFF { + t.Fatalf("packet was not reset") + } + } +} + +func TestReplayProtection(t *testing.T) { + r := NewReplayProtection() + for i := 0; i < 2; i+=1 { + r.Reset() + if r.MostRecentSequence != 0 { + t.Fatalf("sequence was not 0") + } + + // sequence numbers with high bit set should be ignored + sequence := uint64(1 << 63) + if r.AlreadyReceived(sequence) != 0 { + t.Fatalf("sequence numbers with high bit set should be ignored") + } + + if r.MostRecentSequence != 0 { + t.Fatalf("sequence was not 0 after high-bit check got: 0x%x\n", r.MostRecentSequence) + } + + // the first time we receive packets, they should not be already received + maxSequence := uint64(REPLAY_PROTECTION_BUFFER_SIZE * 4) + for sequence = 0; sequence < maxSequence; sequence+=1 { + if r.AlreadyReceived(sequence) != 0 { + t.Fatalf("the first time we receive packets, they should not be already received") + } + } + + // old packets outside buffer should be considered already received + if r.AlreadyReceived(0) != 1 { + t.Fatalf("old packets outside buffer should be considered already received") + } + + // packets received a second time should be flagged already received + for sequence = maxSequence - 10; sequence < maxSequence; sequence+=1 { + if r.AlreadyReceived(sequence) != 1 { + t.Fatalf("packets received a second time should be flagged already received") + } + } + + // jumping ahead to a much higher sequence should be considered not already received + if r.AlreadyReceived(maxSequence+REPLAY_PROTECTION_BUFFER_SIZE) != 0 { + t.Fatalf("jumping ahead to a much higher sequence should be considered not already received") + } + + // old packets should be considered already received + for sequence = 0; sequence < maxSequence; sequence+=1 { + if r.AlreadyReceived(sequence) != 1 { + t.Fatalf("old packets should be considered already received") + } + } + } +} \ No newline at end of file diff --git a/go/netcode/server.go b/go/netcode/server.go new file mode 100644 index 0000000..99ad84a --- /dev/null +++ b/go/netcode/server.go @@ -0,0 +1 @@ +package netcode diff --git a/go/netcode/simulator.go b/go/netcode/simulator.go new file mode 100644 index 0000000..79411da --- /dev/null +++ b/go/netcode/simulator.go @@ -0,0 +1,12 @@ +package netcode + + +const PACKET_SEND_RATE = 10.0 +const TIMEOUT_SECONDS = 5.0 +const NUM_DISCONNECT_PACKETS = 10 + + +type Context struct { + WritePacketKey []byte + ReadPacketKey []byte +} diff --git a/go/netcode/sizes.go b/go/netcode/sizes.go new file mode 100644 index 0000000..7340bd2 --- /dev/null +++ b/go/netcode/sizes.go @@ -0,0 +1,56 @@ +package netcode + +// Taken from https://raw.githubusercontent.com/google/flatbuffers/master/go/sizes.go +import ( + "unsafe" +) + +const ( + // See http://golang.org/ref/spec#Numeric_types + + // SizeUint8 is the byte size of a uint8. + SizeUint8 = 1 + // SizeUint16 is the byte size of a uint16. + SizeUint16 = 2 + // SizeUint32 is the byte size of a uint32. + SizeUint32 = 4 + // SizeUint64 is the byte size of a uint64. + SizeUint64 = 8 + + // SizeInt8 is the byte size of a int8. + SizeInt8 = 1 + // SizeInt16 is the byte size of a int16. + SizeInt16 = 2 + // SizeInt32 is the byte size of a int32. + SizeInt32 = 4 + // SizeInt64 is the byte size of a int64. + SizeInt64 = 8 + + // SizeFloat32 is the byte size of a float32. + SizeFloat32 = 4 + // SizeFloat64 is the byte size of a float64. + SizeFloat64 = 8 + + // SizeByte is the byte size of a byte. + // The `byte` type is aliased (by Go definition) to uint8. + SizeByte = 1 + + // SizeBool is the byte size of a bool. + // The `bool` type is aliased (by flatbuffers convention) to uint8. + SizeBool = 1 + + // SizeSOffsetT is the byte size of an SOffsetT. + // The `SOffsetT` type is aliased (by flatbuffers convention) to int32. + SizeSOffsetT = 4 + // SizeUOffsetT is the byte size of an UOffsetT. + // The `UOffsetT` type is aliased (by flatbuffers convention) to uint32. + SizeUOffsetT = 4 + // SizeVOffsetT is the byte size of an VOffsetT. + // The `VOffsetT` type is aliased (by flatbuffers convention) to uint16. + SizeVOffsetT = 2 +) + +// byteSliceToString converts a []byte to string without a heap allocation. +func byteSliceToString(b []byte) string { + return *(*string)(unsafe.Pointer(&b)) +} diff --git a/go/netcode/socket.go b/go/netcode/socket.go new file mode 100644 index 0000000..5e1c6ec --- /dev/null +++ b/go/netcode/socket.go @@ -0,0 +1,71 @@ +package netcode + +import ( + "net" +) + +const ( + SOCKET_ERROR_NONE = iota + SOCKET_ERROR_CREATE_FAILED + SOCKET_ERROR_SET_NON_BLOCKING_FAILED + SOCKET_ERROR_SOCKOPT_IPV6_ONLY_FAILED + SOCKET_ERROR_SOCKOPT_RCVBUF_FAILED + SOCKET_ERROR_SOCKOPT_SNDBUF_FAILED + SOCKET_ERROR_BIND_IPV4_FAILED + SOCKET_ERROR_BIND_IPV6_FAILED + SOCKET_ERROR_GET_SOCKNAME_IPV4_FAILED + SOCKET_ERROR_GET_SOCKNAME_IPV6_FAILED +) + +type Socket struct { + Address *net.UDPAddr + Conn *net.UDPConn +} + +func NewSocket() *Socket { + s := &Socket{} + return s +} + +func (s *Socket) Create(address *net.UDPAddr, sendsize, recvsize int) error { + conn, err := net.ListenUDP(address.Network(), address) + if err != nil { + return err + } + + if err := conn.SetReadBuffer(recvsize); err != nil { + return err + } + + if err := conn.SetWriteBuffer(sendsize); err != nil { + return err + } + + s.Conn = conn + return nil +} + +func (s *Socket) Send(destination *net.UDPAddr, data []byte) error { + if s.Conn == nil { + return nil + } + + length, err := s.Conn.WriteTo(data, destination) + if err != nil { + return err + } + + if length != len(data) { + // error writing all data + return nil + } + return nil +} + +func (s *Socket) Recv(source *net.Addr, data []byte, maxsize uint) error { + return nil +} + +func (s *Socket) Destroy() { + s.Conn.Close() +}