Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
167 changes: 83 additions & 84 deletions go/netcode/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,21 +78,25 @@ func NewClient(config *Config) *Client {
c := &Client{config: config}
c.lastPacketRecvTime = time.Now().Add(-time.Second)
c.lastPacketSendTime = time.Now().Add(-time.Second)
c.state = StateDisconnected
c.setState(StateDisconnected)
c.shouldDisconnect = false
c.challengeData = make([]byte, CHALLENGE_TOKEN_BYTES)

c.replayProtection = NewReplayProtection()
c.connectToken = NewConnectToken()
c.context = &Context{}
c.packetQueue = NewPacketQueue()
c.packetQueue = NewPacketQueue(PACKET_QUEUE_SIZE)
return c
}

func (c *Client) State() ClientState {
func (c *Client) getState() ClientState {
return c.state
}

func (c *Client) setState(newState ClientState) {
c.state = newState
}

func (c *Client) Init() error {
c.startTime = time.Now()
return c.connectToken.Generate(c.config, c.sequence)
Expand All @@ -104,11 +108,9 @@ func (c *Client) Connect() error {
c.serverIndex = 0
c.serverAddress = &c.connectToken.ServerAddrs[0]

c.conn = NewNetcodeConn(c.serverAddress)
if err = c.conn.Create(); err != nil {
return err
}
if err != nil {
c.conn = NewNetcodeConn()
c.conn.SetRecvHandler(c.onPacketData)
if err = c.conn.Dial(c.serverAddress); err != nil {
return err
}

Expand All @@ -117,7 +119,36 @@ func (c *Client) Connect() error {

c.Reset()
c.setState(StateSendingConnectionRequest)
return err
return nil
}

func (c *Client) Close() error {
return c.conn.Close()
}

func (c *Client) Reset() {
c.lastPacketRecvTime = time.Now().Add(-time.Second)
c.lastPacketSendTime = time.Now().Add(-time.Second)
c.shouldDisconnect = false
c.shouldDisconnectState = StateDisconnected
c.challengeData = make([]byte, CHALLENGE_TOKEN_BYTES)
c.challengeSequence = 0
c.replayProtection.Reset()
}

func (c *Client) resetConnectionData(newState ClientState) {
c.sequence = 0
c.clientIndex = 0
c.maxClients = 0
c.startTime = time.Now()
c.serverIndex = 0
c.serverAddress = nil
c.connectToken = nil
c.context = nil
c.setState(newState)
c.Reset()
c.packetQueue.Clear()

}

func (c *Client) connectNextServer() bool {
Expand All @@ -136,13 +167,13 @@ func (c *Client) connectNextServer() bool {
func (c *Client) Update(t time.Time) {
log.Printf("Update\n")
c.time = t
c.Recv()

if err := c.Send(); err != nil {
if err := c.send(); err != nil {
log.Fatalf("error sending packet: %s\n", err)
}

if c.state > StateDisconnected && c.state < StateConnected {
state := c.getState()
if state > StateDisconnected && state < StateConnected {
expire := c.connectToken.ExpireTimestamp - c.connectToken.CreateTimestamp
if uint64(c.startTime.Unix())+expire <= uint64(c.time.Unix()) {
c.Disconnect(StateTokenExpired, false)
Expand All @@ -161,7 +192,7 @@ func (c *Client) Update(t time.Time) {

timeout := c.lastPacketRecvTime.Unix() + int64(c.connectToken.TimeoutSeconds)

switch c.state {
switch c.getState() {
case StateSendingConnectionRequest:
if timeout < c.time.Unix() {
if c.connectNextServer() {
Expand All @@ -184,11 +215,11 @@ func (c *Client) Update(t time.Time) {
}

func (c *Client) Disconnect(reason ClientState, sendDisconnect bool) error {
if c.state <= StateDisconnected {
if c.getState() <= StateDisconnected {
return nil
}

if sendDisconnect && c.state > StateDisconnected {
if sendDisconnect && c.getState() > StateDisconnected {
for i := 0; i < NUM_DISCONNECT_PACKETS; i += 1 {
packet := &DisconnectPacket{}
c.sendPacket(packet)
Expand All @@ -198,13 +229,22 @@ func (c *Client) Disconnect(reason ClientState, sendDisconnect bool) error {
return nil
}

func (c *Client) Send() error {
func (c *Client) SendData(payloadData []byte) error {
log.Printf("sending data\n")
if c.getState() != StateConnected {
return errors.New("client not connected, unable to send packet")
}
p := NewPayloadPacket(payloadData)
return c.sendPacket(p)
}

func (c *Client) send() error {
// check our send rate prior to bother sending
if c.lastPacketSendTime.Unix()+(1/PACKET_SEND_RATE) >= c.time.Unix() {
return nil
}

switch c.state {
switch c.getState() {
case StateSendingConnectionRequest:
p := &RequestPacket{}
p.VersionInfo = c.connectToken.VersionInfo
Expand All @@ -231,15 +271,6 @@ func (c *Client) Send() error {
return nil
}

func (c *Client) SendData(payloadData []byte) error {
log.Printf("sending data\n")
if c.state != StateConnected {
return errors.New("client not connected, unable to send packet")
}
p := NewPayloadPacket(payloadData)
return c.sendPacket(p)
}

func (c *Client) sendPacket(packet Packet) error {
buffer := NewBuffer(MAX_PACKET_BYTES)
packet_bytes, err := packet.Write(buffer, c.connectToken.ProtocolId, c.sequence, c.context.WritePacketKey)
Expand All @@ -261,12 +292,23 @@ func (c *Client) RecvData() []byte {
return p.PayloadData
}

func (c *Client) Recv() error {
// called asynchronously whenever a new packet of data arrives from the NetcodeConn.
func (c *Client) onPacketData(packetData []byte, from *net.UDPAddr) {
var err error
var size int
var packetData []byte
var sequence uint64

if c.serverAddress.String() != from.String() || c.serverAddress.Port != from.Port {
log.Printf("unknown address sent us data")
return
}

size = len(packetData)
if len(packetData) == 0 {
log.Printf("unable to read from socket, 0 bytes returned")
return
}

allowedPackets := make([]byte, ConnectionNumPackets)
allowedPackets[ConnectionDenied] = 1
allowedPackets[ConnectionChallenge] = 1
Expand All @@ -275,38 +317,28 @@ func (c *Client) Recv() error {
allowedPackets[ConnectionDisconnect] = 1

timestamp := uint64(time.Now().Unix())
log.Printf("read %d from socket\n", len(packetData))

for {
log.Printf("recv loop entered\n")
//packetData := make([]byte, MAX_PACKET_BYTES)
packetData = c.conn.Recv()
size = len(packetData)
if len(packetData) == 0 {
return errors.New("unable to read from socket, 0 bytes returned")
}
log.Printf("read %d from socket\n", len(packetData))

packet := NewPacket(packetData)
packetBuffer := NewBufferFromBytes(packetData)
log.Printf("calling packet %#v .Read\n", packetBuffer)
if err = packet.Read(packetBuffer, size, c.config.ProtocolId, timestamp, c.context.ReadPacketKey, nil, allowedPackets, c.replayProtection); err != nil {
return err
}

c.processPacket(packet, sequence)
packet := NewPacket(packetData)
packetBuffer := NewBufferFromBytes(packetData)
if err = packet.Read(packetBuffer, size, c.config.ProtocolId, timestamp, c.context.ReadPacketKey, nil, allowedPackets, c.replayProtection); err != nil {
log.Printf("error reading packet: %s\n", err)
}

c.processPacket(packet, sequence)
}

func (c *Client) processPacket(packet Packet, sequence uint64) {
log.Printf("processing packet of type: %s\n", packetTypeMap[packet.GetType()])
state := c.getState()
switch packet.GetType() {
case ConnectionDenied:
if c.state == StateSendingConnectionRequest || c.state == StateSendingConnectionResponse {
if state == StateSendingConnectionRequest || state == StateSendingConnectionResponse {
c.shouldDisconnect = true
c.shouldDisconnectState = StateConnectionDenied
}
case ConnectionChallenge:
if c.state != StateSendingConnectionRequest {
if state != StateSendingConnectionRequest {
return
}

Expand All @@ -324,20 +356,20 @@ func (c *Client) processPacket(packet Packet, sequence uint64) {
}

log.Printf("client received connection keep alive packet from server\n")
if c.state == StateSendingConnectionResponse {
if state == StateSendingConnectionResponse {
c.clientIndex = p.ClientIndex
c.maxClients = p.MaxClients
c.setState(StateConnected)
log.Printf("client connected to server\n")
}
case ConnectionPayload:
if c.state != StateConnected {
if state != StateConnected {
return
}
log.Printf("got payload packet.\n")
c.packetQueue.Push(packet)
case ConnectionDisconnect:
if c.state != StateConnected {
if state != StateConnected {
return
}
c.shouldDisconnect = true
Expand All @@ -348,36 +380,3 @@ func (c *Client) processPacket(packet Packet, sequence uint64) {
// always update last packet recv time for valid packets.
c.lastPacketRecvTime = c.time
}

func (c *Client) Close() error {
return c.conn.Close()
}

func (c *Client) setState(newState ClientState) {
c.state = newState
}

func (c *Client) Reset() {
c.lastPacketRecvTime = time.Now().Add(-time.Second)
c.lastPacketSendTime = time.Now().Add(-time.Second)
c.shouldDisconnect = false
c.shouldDisconnectState = StateDisconnected
c.challengeData = make([]byte, CHALLENGE_TOKEN_BYTES)
c.challengeSequence = 0
c.replayProtection.Reset()
}

func (c *Client) resetConnectionData(newState ClientState) {
c.sequence = 0
c.clientIndex = 0
c.maxClients = 0
c.startTime = time.Now()
c.serverIndex = 0
c.serverAddress = nil
c.connectToken = nil
c.context = nil
c.state = newState
c.Reset()
c.packetQueue.Clear()

}
10 changes: 7 additions & 3 deletions go/netcode/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,12 +63,15 @@ func TestClientCommunications(t *testing.T) {

packetData := make([]byte, 1200)
count := 0
// fake game loop
for {

if count == 10 {
t.Fatalf("error communicating with server")
}
timestamp := time.Now()
c.Update(timestamp)
fmt.Println("sending update")
if c.state == StateConnected {
if c.getState() == StateConnected {
c.SendData(packetData)
fmt.Println("sent data")
}
Expand All @@ -78,7 +81,8 @@ func TestClientCommunications(t *testing.T) {
if payload := c.RecvData(); payload == nil {
break
} else {
fmt.Printf("recv'd payload: %#v\n", payload)
fmt.Printf("recv'd payload: of %d bytes\n", len(payload))
return
}
}
time.Sleep(deltaTime)
Expand Down
5 changes: 4 additions & 1 deletion go/netcode/connect_token.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package netcode

import (
"errors"
"log"
"strings"
"time"
)
Expand Down Expand Up @@ -36,6 +37,7 @@ func NewConnectToken() *ConnectToken {
}

// Generates the token and private token data with the supplied config values and sequence id.
// This will also write and encrypt the private token
func (token *ConnectToken) Generate(config *Config, sequence uint64) error {
token.CreateTimestamp = uint64(time.Now().Unix())
token.ExpireTimestamp = token.CreateTimestamp + config.TokenExpiry
Expand Down Expand Up @@ -125,8 +127,9 @@ func ReadConnectToken(tokenBuffer []byte) (*ConnectToken, error) {
if token.Sequence, err = buffer.GetUint64(); err != nil {
return nil, errors.New("read connect data has bad sequence " + err.Error())
}
log.Printf("sequence: %x\n", token.Sequence)

if privateData, err = buffer.GetBytes(CONNECT_TOKEN_PRIVATE_BYTES + MAC_BYTES); err != nil {
if privateData, err = buffer.GetBytes(CONNECT_TOKEN_PRIVATE_BYTES); err != nil {
return nil, errors.New("read connect data has bad private data " + err.Error())
}

Expand Down
15 changes: 12 additions & 3 deletions go/netcode/connect_token_private_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,9 +77,7 @@ func testComparePrivateTokens(token1, token2 *ConnectTokenPrivate, t *testing.T)
token1Servers := token1.ServerAddrs
token2Servers := token2.ServerAddrs
for i := 0; i < len(token1.ServerAddrs); i += 1 {
if bytes.Compare([]byte(token1Servers[i].IP), []byte(token2Servers[i].IP)) != 0 {
t.Fatalf("server addresses did not match: expected %v got %v\n", token1Servers[i], token2Servers[i])
}
testCompareAddrs(token1Servers[i], token2Servers[i], t)
}

if bytes.Compare(token1.ClientKey, token2.ClientKey) != 0 {
Expand All @@ -90,3 +88,14 @@ func testComparePrivateTokens(token1, token2 *ConnectTokenPrivate, t *testing.T)
t.Fatalf("ServerKey do not match expected %v got %v", token1.ServerKey, token2.ServerKey)
}
}

func testCompareAddrs(addr1, addr2 net.UDPAddr, t *testing.T) {
if addr1.IP.String() != addr2.IP.String() {
t.Fatalf("ip addresses were not equal: %s and %s\n", addr1.IP.String(), addr2.IP.String())
}

if addr1.Port != addr2.Port {
t.Fatalf("server ports did not match: expected %s got %s\n", addr1.Port, addr2.Port)
}

}
Loading