Skip to content

Commit

Permalink
Barnes's review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
ekr committed Dec 21, 2017
1 parent 83d07fc commit a9a7197
Show file tree
Hide file tree
Showing 5 changed files with 22 additions and 24 deletions.
14 changes: 6 additions & 8 deletions client-state-machine.go
Expand Up @@ -441,10 +441,8 @@ func (state ClientStateWaitSH) Next(hr handshakeMessageReader) (HandshakeState,
&serverKeyShare,
})
if err != nil {
if err != nil {
logf(logTypeHandshake, "[ClientWaitSH] Error processing extensions [%v]", err)
return nil, nil, AlertDecodeError
}
logf(logTypeHandshake, "[ClientWaitSH] Error processing extensions [%v]", err)
return nil, nil, AlertDecodeError
}

if foundExts[ExtensionTypePreSharedKey] && (serverPSK.SelectedIdentity == 0) {
Expand Down Expand Up @@ -564,7 +562,7 @@ func (state ClientStateWaitEE) Next(hr handshakeMessageReader) (HandshakeState,
}

ee := EncryptedExtensionsBody{}
if err := SafeUnmarshal(&ee, hm.body); err != nil {
if err := safeUnmarshal(&ee, hm.body); err != nil {
logf(logTypeHandshake, "[ClientStateWaitEE] Error decoding message: %v", err)
return nil, nil, AlertDecodeError
}
Expand Down Expand Up @@ -742,7 +740,7 @@ func (state ClientStateWaitCert) Next(hr handshakeMessageReader) (HandshakeState
}

cert := &CertificateBody{}
if err := SafeUnmarshal(cert, hm.body); err != nil {
if err := safeUnmarshal(cert, hm.body); err != nil {
logf(logTypeHandshake, "[ClientStateWaitCert] Error decoding message: %v", err)
return nil, nil, AlertDecodeError
}
Expand Down Expand Up @@ -799,7 +797,7 @@ func (state ClientStateWaitCV) Next(hr handshakeMessageReader) (HandshakeState,
}

certVerify := CertificateVerifyBody{}
if err := SafeUnmarshal(&certVerify, hm.body); err != nil {
if err := safeUnmarshal(&certVerify, hm.body); err != nil {
logf(logTypeHandshake, "[ClientStateWaitCV] Error decoding message: %v", err)
return nil, nil, AlertDecodeError
}
Expand Down Expand Up @@ -879,7 +877,7 @@ func (state ClientStateWaitFinished) Next(hr handshakeMessageReader) (HandshakeS
logf(logTypeCrypto, "server finished data: [%d] %x", len(serverFinishedData), serverFinishedData)

fin := &FinishedBody{VerifyDataLen: len(serverFinishedData)}
if err := SafeUnmarshal(fin, hm.body); err != nil {
if err := safeUnmarshal(fin, hm.body); err != nil {
logf(logTypeHandshake, "[ClientStateWaitFinished] Error decoding message: %v", err)
return nil, nil, AlertDecodeError
}
Expand Down
4 changes: 2 additions & 2 deletions extensions.go
Expand Up @@ -86,7 +86,7 @@ func (el ExtensionList) Parse(dsts []ExtensionBody) (map[ExtensionType]bool, err
return nil, fmt.Errorf("Duplicate extension of type [%v]", dst.Type())
}

err := SafeUnmarshal(dst, ext.ExtensionData)
err := safeUnmarshal(dst, ext.ExtensionData)
if err != nil {
return nil, err
}
Expand All @@ -102,7 +102,7 @@ func (el ExtensionList) Parse(dsts []ExtensionBody) (map[ExtensionType]bool, err
func (el ExtensionList) Find(dst ExtensionBody) (bool, error) {
for _, ext := range el {
if ext.ExtensionType == dst.Type() {
err := SafeUnmarshal(dst, ext.ExtensionData)
err := safeUnmarshal(dst, ext.ExtensionData)
if err != nil {
return true, err
}
Expand Down
10 changes: 5 additions & 5 deletions handshake-layer.go
Expand Up @@ -99,7 +99,7 @@ func (hm HandshakeMessage) ToBody() (HandshakeMessageBody, error) {
return body, fmt.Errorf("tls.handshakemessage: Unsupported body type")
}

err := SafeUnmarshal(body, hm.body)
err := safeUnmarshal(body, hm.body)
return body, err
}

Expand Down Expand Up @@ -480,17 +480,17 @@ func decodeUint(in []byte, size int) (uint64, []byte) {
return val, in[size:]
}

type marshalledPdu interface {
type marshalledPDU interface {
Marshal() ([]byte, error)
Unmarshal(data []byte) (int, error)
}

func SafeUnmarshal(pdu marshalledPdu, data []byte) error {
l, err := pdu.Unmarshal(data)
func safeUnmarshal(pdu marshalledPDU, data []byte) error {
read, err := pdu.Unmarshal(data)
if err != nil {
return err
}
if len(data) != l {
if len(data) != read {
return fmt.Errorf("Invalid encoding: Extra data not consumed")
}
return nil
Expand Down
10 changes: 5 additions & 5 deletions handshake-messages_test.go
Expand Up @@ -705,13 +705,13 @@ func TestEndOfEarlyDataMarshalUnmarshal(t *testing.T) {
assertDeepEquals(t, eoed, endOfEarlyDataValidIn)
}

func TestSafeUnmarshal(t *testing.T) {
func TestsafeUnmarshal(t *testing.T) {
chValid := unhex(chValidHex)
tooLong := append(chValid, 0)
var ch ClientHelloBody

// Check that SafeUnmarshal works normally
err := SafeUnmarshal(&ch, chValid)
// Check that safeUnmarshal works normally
err := safeUnmarshal(&ch, chValid)
assertNotError(t, err, "Failed to unmarshal ClientHello")

// Test successful unmarshal
Expand All @@ -720,7 +720,7 @@ func TestSafeUnmarshal(t *testing.T) {
assertEquals(t, read, len(chValid))
assertDeepEquals(t, ch, chValidIn)

// Now test that SafeUnmarshal barfs
err = SafeUnmarshal(&ch, tooLong)
// Now test that safeUnmarshal barfs
err = safeUnmarshal(&ch, tooLong)
assertError(t, err, "Unmarshalled something too long")
}
8 changes: 4 additions & 4 deletions server-state-machine.go
Expand Up @@ -93,7 +93,7 @@ func (state ServerStateStart) Next(hr handshakeMessageReader) (HandshakeState, [
}

ch := &ClientHelloBody{}
if err := SafeUnmarshal(ch, hm.body); err != nil {
if err := safeUnmarshal(ch, hm.body); err != nil {
logf(logTypeHandshake, "[ServerStateStart] Error decoding message: %v", err)
return nil, nil, AlertDecodeError
}
Expand Down Expand Up @@ -888,7 +888,7 @@ func (state ServerStateWaitCert) Next(hr handshakeMessageReader) (HandshakeState
}

cert := &CertificateBody{}
if err := SafeUnmarshal(cert, hm.body); err != nil {
if err := safeUnmarshal(cert, hm.body); err != nil {
logf(logTypeHandshake, "[ServerStateWaitCert] Unexpected message")
return nil, nil, AlertDecodeError
}
Expand Down Expand Up @@ -964,7 +964,7 @@ func (state ServerStateWaitCV) Next(hr handshakeMessageReader) (HandshakeState,
}

certVerify := &CertificateVerifyBody{}
if err := SafeUnmarshal(certVerify, hm.body); err != nil {
if err := safeUnmarshal(certVerify, hm.body); err != nil {
logf(logTypeHandshake, "[ServerStateWaitCert] Error decoding message %v", err)
return nil, nil, AlertDecodeError
}
Expand Down Expand Up @@ -1038,7 +1038,7 @@ func (state ServerStateWaitFinished) Next(hr handshakeMessageReader) (HandshakeS
}

fin := &FinishedBody{VerifyDataLen: state.cryptoParams.Hash.Size()}
if err := SafeUnmarshal(fin, hm.body); err != nil {
if err := safeUnmarshal(fin, hm.body); err != nil {
logf(logTypeHandshake, "[ServerStateWaitFinished] Error decoding message %v", err)
return nil, nil, AlertDecodeError
}
Expand Down

0 comments on commit a9a7197

Please sign in to comment.