Skip to content

Commit

Permalink
Further improve ship tests
Browse files Browse the repository at this point in the history
  • Loading branch information
DerAndereAndi committed Jan 4, 2024
1 parent 5de4005 commit 194d3c8
Show file tree
Hide file tree
Showing 9 changed files with 130 additions and 57 deletions.
9 changes: 3 additions & 6 deletions ship/connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -109,12 +109,10 @@ func (c *ShipConnection) ApprovePendingHandshake() {

// HELLO_OK
c.stopHandshakeTimer()
c.setState(SmeHelloStateReadyInit, nil)
c.handleState(false, nil)
c.setAndHandleState(SmeHelloStateReadyInit)

// TODO: check if we need to do some validations before moving on to the next state
c.setState(SmeHelloStateOk, nil)
c.handleState(false, nil)
c.setAndHandleState(SmeHelloStateOk)
}

// invoked when pairing for a pending request is denied
Expand All @@ -129,8 +127,7 @@ func (c *ShipConnection) AbortPendingHandshake() {
// TODO: Move this into hs_hello.go and add tests

c.stopHandshakeTimer()
c.setState(SmeHelloStateAbort, nil)
c.handleState(false, nil)
c.setAndHandleState(SmeHelloStateAbort)
}

// report removing a connection
Expand Down
12 changes: 8 additions & 4 deletions ship/handshake.go
Original file line number Diff line number Diff line change
Expand Up @@ -167,12 +167,10 @@ func (c *ShipConnection) handleState(timeout bool, message []byte) {
c.handshakeProtocol_smeProtHStateClientListenChoice(message)

case SmeProtHStateClientOk:
c.setState(SmePinStateCheckInit, nil)
c.handleState(false, nil)
c.setAndHandleState(SmePinStateCheckInit)

case SmeProtHStateServerOk:
c.setState(SmePinStateCheckInit, nil)
c.handleState(false, nil)
c.setAndHandleState(SmePinStateCheckInit)

// smePinState

Expand All @@ -192,6 +190,12 @@ func (c *ShipConnection) handleState(timeout bool, message []byte) {
}
}

// set a state and trigger handling it
func (c *ShipConnection) setAndHandleState(state ShipMessageExchangeState) {
c.setState(state, nil)
c.handleState(false, nil)
}

// SHIP handshake is approved, now set the new state and the SPINE read handler
func (c *ShipConnection) approveHandshake() {
// Report to SPINE local device about this remote device connection
Expand Down
63 changes: 63 additions & 0 deletions ship/hs_access_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,24 @@ func (s *AccessSuite) Test_Request() {
shutdownTest(sut)
}

func (s *AccessSuite) Test_Request_Invalid() {
sut, _ := initTest(ShipRoleClient)

sut.setState(SmeAccessMethodsRequest, nil)

accessMsg := model.MessageProtocolHandshake{}
msg, err := sut.shipMessage(model.MsgTypeControl, accessMsg)
assert.Nil(s.T(), err)
assert.NotNil(s.T(), msg)

sut.handleState(false, msg)

assert.Equal(s.T(), false, sut.handshakeTimerRunning)
assert.Equal(s.T(), SmeStateError, sut.getState())

shutdownTest(sut)
}

func (s *AccessSuite) Test_Methods_Ok() {
sut, data := initTest(ShipRoleClient)

Expand All @@ -74,6 +92,27 @@ func (s *AccessSuite) Test_Methods_Ok() {
shutdownTest(sut)
}

func (s *AccessSuite) Test_Methods_NoID() {
sut, data := initTest(ShipRoleClient)

sut.setState(SmeAccessMethodsRequest, nil)

accessMsg := model.AccessMethods{
AccessMethods: model.AccessMethodsType{},
}
msg, err := sut.shipMessage(model.MsgTypeControl, accessMsg)
assert.Nil(s.T(), err)
assert.NotNil(s.T(), msg)

sut.handleState(false, msg)

assert.Equal(s.T(), false, sut.handshakeTimerRunning)
assert.Equal(s.T(), SmeStateError, sut.getState())
assert.Nil(s.T(), data.lastMessage())

shutdownTest(sut)
}

func (s *AccessSuite) Test_Methods_WrongShipID() {
sut, data := initTest(ShipRoleClient)

Expand All @@ -96,3 +135,27 @@ func (s *AccessSuite) Test_Methods_WrongShipID() {

shutdownTest(sut)
}

func (s *AccessSuite) Test_Methods_NoShipID() {
sut, _ := initTest(ShipRoleClient)

sut.remoteShipID = ""

sut.setState(SmeAccessMethodsRequest, nil)

accessMsg := model.AccessMethods{
AccessMethods: model.AccessMethodsType{
Id: util.Ptr(""),
},
}
msg, err := sut.shipMessage(model.MsgTypeControl, accessMsg)
assert.Nil(s.T(), err)
assert.NotNil(s.T(), msg)

sut.handleState(false, msg)

assert.Equal(s.T(), false, sut.handshakeTimerRunning)
assert.Equal(s.T(), SmeStateComplete, sut.getState())

shutdownTest(sut)
}
45 changes: 15 additions & 30 deletions ship/hs_hello.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,7 @@ import (
// SME_HELLO_STATE_READY_INIT
func (c *ShipConnection) handshakeHello_Init() {
if err := c.handshakeHelloSend(model.ConnectionHelloPhaseTypeReady, tHelloInit, false); err != nil {
c.setState(SmeHelloStateAbort, nil)
c.handleState(false, nil)
c.setAndHandleState(SmeHelloStateAbort)
return
}

Expand All @@ -30,8 +29,7 @@ func (c *ShipConnection) handshakeHello_ReadyListen(timeout bool, message []byte

var helloReturnMsg model.ConnectionHello
if err := c.processShipJsonMessage(message, &helloReturnMsg); err != nil {
c.setState(SmeHelloStateAbort, nil)
c.handleState(false, nil)
c.setAndHandleState(SmeHelloStateAbort)
return
}

Expand Down Expand Up @@ -65,25 +63,22 @@ func (c *ShipConnection) handshakeHello_ReadyListen(timeout bool, message []byte
// TODO: what to do if this is false?

case model.ConnectionHelloPhaseTypeAborted:
c.setState(SmeHelloStateRemoteAbortDone, nil)
c.handleState(false, nil)
c.setAndHandleState(SmeHelloStateRemoteAbortDone)

return

default:
// don't accept any other responses
logging.Log.Errorf("Unexpected connection hello phase: %s", hello.Phase)
c.setState(SmeHelloStateAbort, nil)
c.handleState(false, nil)
c.setAndHandleState(SmeHelloStateAbort)
return
}

c.handleState(false, nil)
}

func (c *ShipConnection) handshakeHello_ReadyTimeout() {
c.setState(SmeHelloStateAbort, nil)
c.handleState(false, nil)
c.setAndHandleState(SmeHelloStateAbort)
}

// SME_HELLO_ABORT
Expand All @@ -95,8 +90,7 @@ func (c *ShipConnection) handshakeHello_Abort() {
return
}

c.setState(SmeHelloStateAbortDone, nil)
c.handleState(false, nil)
c.setAndHandleState(SmeHelloStateAbortDone)
}

// SME_HELLO_PENDING_INIT
Expand All @@ -109,8 +103,7 @@ func (c *ShipConnection) handshakeHello_PendingInit() {
c.setState(SmeHelloStatePendingListen, nil)

if !c.serviceDataProvider.AllowWaitingForTrust(c.remoteShipID) {
c.setState(SmeHelloStateAbort, nil)
c.handleState(false, nil)
c.setAndHandleState(SmeHelloStateAbort)
}
}

Expand All @@ -130,8 +123,7 @@ func (c *ShipConnection) handshakeHello_PendingListen(timeout bool, message []by

var helloReturnMsg model.ConnectionHello
if err := c.processShipJsonMessage(message, &helloReturnMsg); err != nil {
c.setState(SmeHelloStateAbort, nil)
c.handleState(false, nil)
c.setAndHandleState(SmeHelloStateAbort)
return
}

Expand All @@ -140,8 +132,7 @@ func (c *ShipConnection) handshakeHello_PendingListen(timeout bool, message []by
switch hello.Phase {
case model.ConnectionHelloPhaseTypeReady:
if hello.Waiting == nil {
c.setState(SmeHelloStateAbort, nil)
c.handleState(false, nil)
c.setAndHandleState(SmeHelloStateAbort)
return
}

Expand All @@ -163,8 +154,7 @@ func (c *ShipConnection) handshakeHello_PendingListen(timeout bool, message []by
if newDuration < tHelloProlongMin {
// I interpret 13.4.4.1.3 Page 64 Line 1550-1553 as this resulting in a timeout state
// TODO: verify this
c.setState(SmeHelloStateAbort, nil)
c.handleState(false, nil)
c.setAndHandleState(SmeHelloStateAbort)
}

case model.ConnectionHelloPhaseTypePending:
Expand All @@ -188,8 +178,7 @@ func (c *ShipConnection) handshakeHello_PendingListen(timeout bool, message []by
if newDuration < tHelloProlongMin {
// I interpret 13.4.4.1.3 Page 64 Line 1557-1560 as this resulting in a timeout state
// TODO: verify this
c.setState(SmeHelloStateAbort, nil)
c.handleState(false, nil)
c.setAndHandleState(SmeHelloStateAbort)
}

return
Expand All @@ -204,19 +193,16 @@ func (c *ShipConnection) handshakeHello_PendingListen(timeout bool, message []by
return
}

c.setState(SmeHelloStateAbort, nil)
c.handleState(false, nil)
c.setAndHandleState(SmeHelloStateAbort)

case model.ConnectionHelloPhaseTypeAborted:
c.setState(SmeHelloStateRemoteAbortDone, nil)
c.handleState(false, nil)
c.setAndHandleState(SmeHelloStateRemoteAbortDone)
return

default:
// don't accept any other responses
logging.Log.Errorf("Unexpected connection hello phase: %s", hello.Phase)
c.setState(SmeHelloStateAbort, nil)
c.handleState(false, nil)
c.setAndHandleState(SmeHelloStateAbort)
return
}

Expand All @@ -235,8 +221,7 @@ func (c *ShipConnection) handshakeHello_PendingProlongationRequest() {

func (c *ShipConnection) handshakeHello_PendingTimeout() {
if c.getHandshakeTimerType() != timeoutTimerTypeSendProlongationRequest {
c.setState(SmeHelloStateAbort, nil)
c.handleState(false, nil)
c.setAndHandleState(SmeHelloStateAbort)
return
}

Expand Down
6 changes: 2 additions & 4 deletions ship/hs_init.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,7 @@ func (c *ShipConnection) handshakeInit_cmiStateServerWait(message []byte) {
return
}

c.setState(SmeHelloState, nil)
c.handleState(false, nil)
c.setAndHandleState(SmeHelloState)
}

// CMI_STATE_CLIENT_WAIT
Expand All @@ -51,8 +50,7 @@ func (c *ShipConnection) handshakeInit_cmiStateClientWait(message []byte) {
return
}

c.setState(SmeHelloState, nil)
c.handleState(false, nil)
c.setAndHandleState(SmeHelloState)
}

// CMI_STATE_SERVER_EVALUATE
Expand Down
3 changes: 1 addition & 2 deletions ship/hs_pin.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,7 @@ func (c *ShipConnection) handshakePin_smePinStateCheckListen(message []byte) {

switch connectionPinState.ConnectionPinState.PinState {
case model.PinStateTypeNone:
c.setState(SmePinStateCheckOk, nil)
c.handleState(false, nil)
c.setAndHandleState(SmePinStateCheckOk)
case model.PinStateTypeRequired:
c.endHandshakeWithError(errors.New("Got pin state: required (unsupported)"))
case model.PinStateTypeOptional:
Expand Down
6 changes: 2 additions & 4 deletions ship/hs_prot.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,8 +82,7 @@ func (c *ShipConnection) handshakeProtocol_smeProtHStateServerListenConfirm(mess

c.stopHandshakeTimer()

c.setState(SmeProtHStateServerOk, nil)
c.handleState(false, nil)
c.setAndHandleState(SmeProtHStateServerOk)
}

func (c *ShipConnection) handshakeProtocol_smeProtHStateClientInit() {
Expand Down Expand Up @@ -158,8 +157,7 @@ func (c *ShipConnection) handshakeProtocol_smeProtHStateClientListenChoice(messa
return
}

c.setState(SmeProtHStateClientOk, nil)
c.handleState(false, nil)
c.setAndHandleState(SmeProtHStateClientOk)
}

func (c *ShipConnection) abortProtocolHandshake(err model.MessageProtocolHandshakeErrorErrorType) {
Expand Down
16 changes: 9 additions & 7 deletions ship/websocket.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ func (w *websocketConnection) writeShipPump() {
_ = w.conn.SetWriteDeadline(time.Now().Add(writeWait))
w.muxConWrite.Unlock()
if !ok {
logging.Log.Debug(w.remoteSki, "Ship write channel closed")
logging.Log.Debug(w.remoteSki, "ship write channel closed")
// The write channel has been closed
_ = w.writeMessage(websocket.CloseMessage, []byte{})
return
Expand All @@ -117,9 +117,7 @@ func (w *websocketConnection) writeShipPump() {
return
}

logging.Log.Debug(w.remoteSki, "error writing to websocket: ", err)
w.setConnClosedError(err)
w.dataProcessing.ReportConnectionError(err)
w.closeWithError(err, "error writing to websocket: ")
return
}

Expand Down Expand Up @@ -148,13 +146,17 @@ func (w *websocketConnection) handlePing() {
_ = w.conn.SetWriteDeadline(time.Now().Add(writeWait))
w.muxConWrite.Unlock()
if err := w.writeMessage(websocket.PingMessage, nil); err != nil {
logging.Log.Debug(w.remoteSki, "error writing to websocket: ", err)
w.setConnClosedError(err)
w.dataProcessing.ReportConnectionError(err)
w.closeWithError(err, "error writing to websocket: ")
return
}
}

func (w *websocketConnection) closeWithError(err error, reason string) {
logging.Log.Debug(w.remoteSki, reason, err)
w.setConnClosedError(err)
w.dataProcessing.ReportConnectionError(err)
}

// readShipPump checks for messages from the websocket connection
func (w *websocketConnection) readShipPump() {
_ = w.conn.SetReadDeadline(time.Now().Add(pongWait))
Expand Down
Loading

0 comments on commit 194d3c8

Please sign in to comment.