Skip to content
Permalink
Browse files Browse the repository at this point in the history
TLS protocol: add handshake state validation
  • Loading branch information
spouliot authored and migueldeicaza committed Mar 6, 2015
1 parent 48992d4 commit 1509226
Show file tree
Hide file tree
Showing 4 changed files with 59 additions and 21 deletions.
Expand Up @@ -129,6 +129,7 @@ private HandshakeMessage createClientHandshakeMessage(HandshakeType type)
HandshakeType type, byte[] buffer)
{
ClientContext context = (ClientContext)this.context;
var last = context.LastHandshakeMsg;

switch (type)
{
Expand All @@ -148,30 +149,52 @@ private HandshakeMessage createClientHandshakeMessage(HandshakeType type)
return null;

case HandshakeType.ServerHello:
if (last != HandshakeType.HelloRequest)
break;
return new TlsServerHello(this.context, buffer);

// Optional
case HandshakeType.Certificate:
if (last != HandshakeType.ServerHello)
break;
return new TlsServerCertificate(this.context, buffer);

// Optional
case HandshakeType.ServerKeyExchange:
return new TlsServerKeyExchange(this.context, buffer);
// only for RSA_EXPORT
if (last == HandshakeType.Certificate && context.Current.Cipher.IsExportable)
return new TlsServerKeyExchange(this.context, buffer);
break;

// Optional
case HandshakeType.CertificateRequest:
return new TlsServerCertificateRequest(this.context, buffer);
if (last == HandshakeType.ServerKeyExchange || last == HandshakeType.Certificate)
return new TlsServerCertificateRequest(this.context, buffer);
break;

case HandshakeType.ServerHelloDone:
return new TlsServerHelloDone(this.context, buffer);
if (last == HandshakeType.CertificateRequest || last == HandshakeType.Certificate || last == HandshakeType.ServerHello)
return new TlsServerHelloDone(this.context, buffer);
break;

case HandshakeType.Finished:
return new TlsServerFinished(this.context, buffer);

// depends if a full (ServerHelloDone) or an abbreviated handshake (ServerHello) is being done
bool check = context.AbbreviatedHandshake ? (last == HandshakeType.ServerHello) : (last == HandshakeType.ServerHelloDone);
// ChangeCipherSpecDone is not an handshake message (it's a content type) but still needs to be happens before finished
if (check && context.ChangeCipherSpecDone) {
context.ChangeCipherSpecDone = false;
return new TlsServerFinished (this.context, buffer);
}
break;

default:
throw new TlsException(
AlertDescription.UnexpectedMessage,
String.Format(CultureInfo.CurrentUICulture,
"Unknown server handshake message received ({0})",
type.ToString()));
}
throw new TlsException (AlertDescription.HandshakeFailiure, String.Format ("Protocol error, unexpected protocol transition from {0} to {1}", last, type));
}

#endregion
Expand Down
2 changes: 2 additions & 0 deletions mcs/class/Mono.Security/Mono.Security.Protocol.Tls/Context.cs
Expand Up @@ -122,6 +122,8 @@ public bool ProtocolNegotiated
set { this.protocolNegotiated = value; }
}

public bool ChangeCipherSpecDone { get; set; }

public SecurityProtocolType SecurityProtocol
{
get
Expand Down
Expand Up @@ -88,6 +88,8 @@ protected virtual void ProcessChangeCipherSpec ()
} else {
ctx.StartSwitchingSecurityParameters (false);
}

ctx.ChangeCipherSpecDone = true;
}

public virtual HandshakeMessage GetMessage(HandshakeType type)
Expand Down Expand Up @@ -348,9 +350,6 @@ private void InternalReceiveRecordCallback(IAsyncResult asyncResult)
// Try to read the Record Content Type
int type = internalResult.InitialBuffer[0];

// Set last handshake message received to None
this.context.LastHandshakeMsg = HandshakeType.ClientHello;

ContentType contentType = (ContentType)type;
byte[] buffer = this.ReadRecordBuffer(type, record);
if (buffer == null)
Expand Down Expand Up @@ -458,9 +457,6 @@ public byte[] ReceiveRecord(Stream record)
// Try to read the Record Content Type
int type = recordTypeBuffer[0];

// Set last handshake message received to None
this.context.LastHandshakeMsg = HandshakeType.ClientHello;

ContentType contentType = (ContentType)type;
byte[] buffer = this.ReadRecordBuffer(type, record);
if (buffer == null)
Expand Down
Expand Up @@ -33,6 +33,8 @@ namespace Mono.Security.Protocol.Tls
{
internal class ServerRecordProtocol : RecordProtocol
{
TlsClientCertificate cert;

#region Constructors

public ServerRecordProtocol(
Expand Down Expand Up @@ -93,30 +95,45 @@ protected override void ProcessHandshakeMessage(TlsStream handMsg)
private HandshakeMessage createClientHandshakeMessage(
HandshakeType type, byte[] buffer)
{
var last = context.LastHandshakeMsg;
switch (type)
{
case HandshakeType.ClientHello:
return new TlsClientHello(this.context, buffer);

case HandshakeType.Certificate:
return new TlsClientCertificate(this.context, buffer);
if (last != HandshakeType.ClientHello)
break;
cert = new TlsClientCertificate(this.context, buffer);
return cert;

case HandshakeType.ClientKeyExchange:
return new TlsClientKeyExchange(this.context, buffer);
if (last == HandshakeType.ClientHello || last == HandshakeType.Certificate)
return new TlsClientKeyExchange(this.context, buffer);
break;

case HandshakeType.CertificateVerify:
return new TlsClientCertificateVerify(this.context, buffer);
if (last == HandshakeType.ClientKeyExchange && cert != null)
return new TlsClientCertificateVerify(this.context, buffer);
break;

case HandshakeType.Finished:
return new TlsClientFinished(this.context, buffer);

// Certificates are optional, but if provided, they should send a CertificateVerify
bool check = (cert == null) ? (last == HandshakeType.ClientKeyExchange) : (last == HandshakeType.CertificateVerify);
// ChangeCipherSpecDone is not an handshake message (it's a content type) but still needs to be happens before finished
if (check && context.ChangeCipherSpecDone) {
context.ChangeCipherSpecDone = false;
return new TlsClientFinished(this.context, buffer);
}
break;

default:
throw new TlsException(
AlertDescription.UnexpectedMessage,
String.Format(CultureInfo.CurrentUICulture,
"Unknown server handshake message received ({0})",
type.ToString()));
throw new TlsException(AlertDescription.UnexpectedMessage, String.Format(CultureInfo.CurrentUICulture,
"Unknown server handshake message received ({0})",
type.ToString()));
break;
}
throw new TlsException (AlertDescription.HandshakeFailiure, String.Format ("Protocol error, unexpected protocol transition from {0} to {1}", last, type));
}

private HandshakeMessage createServerHandshakeMessage(
Expand Down

0 comments on commit 1509226

Please sign in to comment.