Skip to content

Commit

Permalink
Client-side tls1.3 ticket/PSK support
Browse files Browse the repository at this point in the history
  • Loading branch information
ctz committed Dec 16, 2016
1 parent 057f368 commit 9638375
Show file tree
Hide file tree
Showing 6 changed files with 346 additions and 71 deletions.
241 changes: 186 additions & 55 deletions src/client_hs.rs
Expand Up @@ -13,7 +13,8 @@ use msgs::handshake::{ProtocolNameList, ConvertProtocolNameList};
use msgs::handshake::{CertificatePayloadTLS13, CertificateEntry};
use msgs::handshake::ServerKeyExchangePayload;
use msgs::handshake::DigitallySignedStruct;
use msgs::enums::ClientCertificateType;
use msgs::handshake::{PresharedKeyIdentity, PresharedKeyOffer};
use msgs::enums::{ClientCertificateType, PskKeyExchangeMode};
use msgs::codec::Codec;
use msgs::persist;
use msgs::ccs::ChangeCipherSpecPayload;
Expand All @@ -22,6 +23,7 @@ use session::{SessionSecrets, MessageCipherChange};
use key_schedule::{KeySchedule, SecretKind};
use cipher::MessageCipher;
use suites;
use hash_hs;
use verify;
use rand;
use error::TLSError;
Expand Down Expand Up @@ -80,17 +82,51 @@ fn randomise_sessionid_for_ticket(csv: &mut persist::ClientSessionValue) {
}
}

/// This implements the horrifying TLS1.3 hack where PSK binders have a
/// data dependency on the message they are contained within.
pub fn fill_in_psk_binder(sess: &mut ClientSessionImpl, hmp: &mut HandshakeMessagePayload) {
/* We need to know the hash function of the suite we're trying to resume into. */
let resuming = sess.handshake_data.resuming_session.as_ref().unwrap();
let suite_hash = sess.find_cipher_suite(&resuming.cipher_suite).unwrap().get_hash();

/* The binder is calculated over the clienthello, but doesn't include itself or its
* length, or the length of its container. */
let encoding = hmp.get_encoding();
let binder_len = suite_hash.output_len;
let binder_plaintext = &encoding[0 .. encoding.len() - binder_len - 2 - 1];
let handshake_hash = sess.handshake_data.transcript.get_hash_given(suite_hash,
binder_plaintext);

let mut empty_hash_ctx = hash_hs::HandshakeHash::new();
empty_hash_ctx.start_hash(suite_hash);
let empty_hash = empty_hash_ctx.get_current_hash();

/* Run a fake key_schedule to simulate what the server will do if it choses
* to resume. */
let mut key_schedule = KeySchedule::new(suite_hash);
key_schedule.input_secret(&resuming.master_secret.0);
let base_key = key_schedule.derive(SecretKind::ResumptionPSKBinderKey, &empty_hash);
let real_binder = key_schedule.sign_verify_data(&base_key, &handshake_hash);

match hmp.payload {
HandshakePayload::ClientHello(ref mut ch) => {
ch.set_psk_binder(real_binder);
},
_ => {}
};
}

pub fn emit_client_hello(sess: &mut ClientSessionImpl) {
/* Do we have a SessionID or ticket cached for this host? */
sess.handshake_data.resuming_session = find_session(sess);
let (session_id, ticket) = if sess.handshake_data.resuming_session.is_some() {
let (session_id, ticket, resume_version) = if sess.handshake_data.resuming_session.is_some() {
let mut resuming = sess.handshake_data.resuming_session.as_mut().unwrap();
randomise_sessionid_for_ticket(resuming);
info!("Resuming session");
(resuming.session_id.clone(), resuming.ticket.0.clone())
(resuming.session_id.clone(), resuming.ticket.0.clone(), resuming.version)
} else {
info!("Not resuming any session");
(SessionID::empty(), Vec::new())
(SessionID::empty(), Vec::new(), ProtocolVersion::Unknown(0))
};

let support_tls12 = sess.config.versions.contains(&ProtocolVersion::TLSv1_2);
Expand Down Expand Up @@ -129,42 +165,75 @@ pub fn emit_client_hello(sess: &mut ClientSessionImpl) {
exts.push(ClientExtension::KeyShare(key_shares));
}

if sess.config.enable_tickets {
if support_tls13 && sess.config.enable_tickets {
let psk_modes = vec![ PskKeyExchangeMode::KE, PskKeyExchangeMode::DHE_KE ];
exts.push(ClientExtension::PresharedKeyModes(psk_modes));
}

if !sess.config.alpn_protocols.is_empty() {
exts.push(ClientExtension::Protocols(ProtocolNameList::from_strings(&sess.config.alpn_protocols)));
}

let fill_in_binder = if sess.config.enable_tickets && resume_version == ProtocolVersion::TLSv1_2 {
/* If we have a ticket, include it. Otherwise, request one. */
if ticket.is_empty() {
exts.push(ClientExtension::SessionTicketRequest);
} else {
exts.push(ClientExtension::SessionTicketOffer(Payload::new(ticket)));
}
}

if !sess.config.alpn_protocols.is_empty() {
exts.push(ClientExtension::Protocols(ProtocolNameList::from_strings(&sess.config.alpn_protocols)));
}
false
} else if support_tls13 && sess.config.enable_tickets
&& resume_version == ProtocolVersion::TLSv1_3 && !ticket.is_empty() {
/* Finally, and only for TLS1.3 with a ticket resumption, include a binder
* for our ticket. This must go last.
*
* Include an empty binder. It gets filled in below because it depends on
* the message it's contained in (!!!). */
let (obfuscated_ticket_age, suite) = {
let resuming = sess.handshake_data.resuming_session
.as_ref()
.unwrap();
(resuming.get_obfuscated_ticket_age(), resuming.cipher_suite)
};

let binder_len = sess.find_cipher_suite(&suite).unwrap().get_hash().output_len;
let binder = vec![0u8; binder_len];

let psk_identity = PresharedKeyIdentity::new(ticket, obfuscated_ticket_age);
let psk_ext = PresharedKeyOffer::new(psk_identity, binder);
exts.push(ClientExtension::PresharedKey(psk_ext));
true
} else {
false
};

/* Note what extensions we sent. */
sess.handshake_data.sent_extensions = exts.iter()
.map(|ext| ext.get_type())
.collect();

let mut chp = HandshakeMessagePayload {
typ: HandshakeType::ClientHello,
payload: HandshakePayload::ClientHello(
ClientHelloPayload {
client_version: ProtocolVersion::TLSv1_2,
random: Random::from_slice(&sess.handshake_data.randoms.client),
session_id: session_id,
cipher_suites: sess.get_cipher_suites(),
compression_methods: vec![Compression::Null],
extensions: exts
}
)
};

if fill_in_binder {
fill_in_psk_binder(sess, &mut chp);
}

let ch = Message {
typ: ContentType::Handshake,
version: ProtocolVersion::TLSv1_2,
payload: MessagePayload::Handshake(
HandshakeMessagePayload {
typ: HandshakeType::ClientHello,
payload: HandshakePayload::ClientHello(
ClientHelloPayload {
client_version: ProtocolVersion::TLSv1_2,
random: Random::from_slice(&sess.handshake_data.randoms.client),
session_id: session_id,
cipher_suites: sess.get_cipher_suites(),
compression_methods: vec![Compression::Null],
extensions: exts
}
)
}
)
payload: MessagePayload::Handshake(chp)
};

debug!("Sending ClientHello {:#?}", ch);
Expand Down Expand Up @@ -204,25 +273,56 @@ fn find_key_share(sess: &mut ClientSessionImpl, group: NamedGroup) -> Result<sui

fn start_handshake_traffic(sess: &mut ClientSessionImpl, server_hello: &ServerHelloPayload)
-> Result<(), TLSError> {
let their_key_share = try!(
server_hello.get_key_share()
.ok_or_else(|| {
sess.common.send_fatal_alert(AlertDescription::MissingExtension);
TLSError::PeerMisbehavedError("missing key share".to_string())
})
);

let our_key_share = try!(find_key_share(sess, their_key_share.group));
let shared = try!(
our_key_share.complete(&their_key_share.payload.0)
.ok_or_else(|| TLSError::PeerMisbehavedError("key exchange failed".to_string()))
);

let suite = sess.common.get_suite();
let hash = suite.get_hash();
let mut key_schedule = KeySchedule::new(hash);
key_schedule.input_empty(); /* TODO: insert PSK here */
key_schedule.input_secret(&shared.premaster_secret);

// PSK_KE means allowing a missing server key_share
// here, but critically only if resuming from something.
let mut skip_key_share = false;

if let Some(selected_psk) = server_hello.get_psk_index() {
if let Some(ref resuming) = sess.handshake_data.resuming_session {
if suite.suite != resuming.cipher_suite {
return Err(TLSError::PeerMisbehavedError("server resuming wrong suite".to_string()));
}

if selected_psk != 0 {
return Err(TLSError::PeerMisbehavedError("server selected invalid psk".to_string()));
}

info!("Resuming using PSK");
key_schedule.input_secret(&resuming.master_secret.0);
skip_key_share = server_hello.get_key_share().is_none();
} else {
return Err(TLSError::PeerMisbehavedError("server selected unoffered psk".to_string()));
}
} else {
info!("Not resuming");
key_schedule.input_empty();
sess.handshake_data.resuming_session.take();
}

if skip_key_share {
info!("Server didn't contribute DH share");
key_schedule.input_empty();
} else {
let their_key_share = try!(
server_hello.get_key_share()
.ok_or_else(|| {
sess.common.send_fatal_alert(AlertDescription::MissingExtension);
TLSError::PeerMisbehavedError("missing key share".to_string())
})
);

let our_key_share = try!(find_key_share(sess, their_key_share.group));
let shared = try!(
our_key_share.complete(&their_key_share.payload.0)
.ok_or_else(|| TLSError::PeerMisbehavedError("key exchange failed".to_string()))
);

key_schedule.input_secret(&shared.premaster_secret);
}

let handshake_hash = sess.handshake_data.transcript.get_current_hash();
let write_key = key_schedule.derive(SecretKind::ClientHandshakeTrafficSecret, &handshake_hash);
Expand Down Expand Up @@ -366,7 +466,11 @@ fn handle_encrypted_extensions(sess: &mut ClientSessionImpl, m: Message) -> Resu

try!(process_alpn_protocol(sess, exts.get_alpn_protocol()));

Ok(ConnState::ExpectCertificateOrCertReq)
if sess.handshake_data.resuming_session.is_some() {
Ok(ConnState::ExpectFinished)
} else {
Ok(ConnState::ExpectCertificateOrCertReq)
}
}

pub static EXPECT_ENCRYPTED_EXTENSIONS: Handler = Handler {
Expand Down Expand Up @@ -844,7 +948,9 @@ fn save_session(sess: &mut ClientSessionImpl) {

let scs = sess.common.get_suite();
let master_secret = sess.secrets.as_ref().unwrap().get_master_secret();
let value = persist::ClientSessionValue::new(&scs.suite,
let version = sess.get_protocol_version().unwrap();
let value = persist::ClientSessionValue::new(version,
scs.suite,
&sess.handshake_data.session_id,
ticket,
master_secret);
Expand All @@ -860,14 +966,6 @@ fn save_session(sess: &mut ClientSessionImpl) {
}
}

fn handle_finished(sess: &mut ClientSessionImpl, m: Message) -> Result<ConnState, TLSError> {
if sess.common.is_tls13 {
handle_finished_tls13(sess, m)
} else {
handle_finished_tls12(sess, m)
}
}

fn emit_certificate_tls13(sess: &mut ClientSessionImpl) {
let context = sess.handshake_data.client_auth_context.take()
.unwrap_or_else(|| Vec::new());
Expand All @@ -893,7 +991,6 @@ fn emit_certificate_tls13(sess: &mut ClientSessionImpl) {
}
)
};

sess.handshake_data.transcript.add_message(&m);
sess.common.send_msg(m, true);
}
Expand Down Expand Up @@ -937,7 +1034,7 @@ fn emit_certverify_tls13(sess: &mut ClientSessionImpl) -> Result<(), TLSError> {
fn emit_finished_tls13(sess: &mut ClientSessionImpl) {
let handshake_hash = sess.handshake_data.transcript.get_current_hash();
let verify_data = sess.common.get_key_schedule()
.sign_verify_data(SecretKind::ClientHandshakeTrafficSecret, &handshake_hash);
.sign_finish(SecretKind::ClientHandshakeTrafficSecret, &handshake_hash);
let verify_data_payload = Payload::new(verify_data);

let m = Message {
Expand All @@ -960,7 +1057,7 @@ fn handle_finished_tls13(sess: &mut ClientSessionImpl, m: Message) -> Result<Con

let handshake_hash = sess.handshake_data.transcript.get_current_hash();
let expect_verify_data = sess.common.get_key_schedule()
.sign_verify_data(SecretKind::ServerHandshakeTrafficSecret, &handshake_hash);
.sign_finish(SecretKind::ServerHandshakeTrafficSecret, &handshake_hash);

use ring;
try!(
Expand Down Expand Up @@ -1022,6 +1119,14 @@ fn handle_finished_tls12(sess: &mut ClientSessionImpl, m: Message) -> Result<Con
Ok(ConnState::TrafficTLS12)
}

fn handle_finished(sess: &mut ClientSessionImpl, m: Message) -> Result<ConnState, TLSError> {
if sess.common.is_tls13 {
handle_finished_tls13(sess, m)
} else {
handle_finished_tls12(sess, m)
}
}

fn handle_finished_resume(sess: &mut ClientSessionImpl, m: Message) -> Result<ConnState, TLSError> {
let next_state = try!(handle_finished(sess, m));

Expand Down Expand Up @@ -1067,14 +1172,40 @@ fn handle_traffic_tls13(sess: &mut ClientSessionImpl, m: Message) -> Result<Conn
if m.is_content_type(ContentType::ApplicationData) {
try!(handle_traffic(sess, m));
} else if m.is_handshake_type(HandshakeType::NewSessionTicket) {
info!("Ignoring TLS1.3 NewSessionTicket message {:?}", m);
try!(handle_new_ticket_tls13(sess, m));
} else if m.is_handshake_type(HandshakeType::KeyUpdate) {
try!(handle_key_update(sess, m));
}

Ok(ConnState::TrafficTLS13)
}

fn handle_new_ticket_tls13(sess: &mut ClientSessionImpl, m: Message) -> Result<(), TLSError> {
let nst = extract_handshake!(m, HandshakePayload::NewSessionTicketTLS13).unwrap();
let handshake_hash = sess.handshake_data.transcript.get_current_hash();
let secret = sess.common.get_key_schedule().derive(SecretKind::ResumptionMasterSecret,
&handshake_hash);
let value = persist::ClientSessionValue::new(ProtocolVersion::TLSv1_3,
sess.common.get_suite().suite,
&SessionID::empty(),
nst.ticket.0.clone(),
secret);
let value_buf = value.get_encoding();

let key = persist::ClientSessionKey::for_dns_name(&sess.handshake_data.dns_name);
let key_buf = key.get_encoding();

let mut persist = sess.config.session_persistence.lock().unwrap();
let worked = persist.put(key_buf, value_buf);

if worked {
info!("Ticket saved");
} else {
info!("Ticket not saved");
}
Ok(())
}

fn handle_key_update(sess: &mut ClientSessionImpl, m: Message) -> Result<(), TLSError> {
let kur = extract_handshake!(m, HandshakePayload::KeyUpdate).unwrap();
sess.common.process_key_update(kur, SecretKind::ServerApplicationTrafficSecret)
Expand Down
20 changes: 18 additions & 2 deletions src/hash_hs.rs
Expand Up @@ -83,11 +83,27 @@ impl HandshakeHash {
self
}

/// Get the hash value if we were to hash `extra` too,
/// using hash function `hash`.
pub fn get_hash_given(&self,
hash: &'static digest::Algorithm,
extra: &[u8]) -> Vec<u8> {
debug_assert!(self.ctx.is_none());

let mut ctx = digest::Context::new(hash);
ctx.update(&self.buffer);
ctx.update(extra);
let hash = ctx.finish();
let mut ret = Vec::new();
ret.extend_from_slice(hash.as_ref());
ret
}

/// Get the current hash value.
pub fn get_current_hash(&self) -> Vec<u8> {
let h = self.ctx.as_ref().unwrap().clone().finish();
let hash = self.ctx.as_ref().unwrap().clone().finish();
let mut ret = Vec::new();
ret.extend_from_slice(h.as_ref());
ret.extend_from_slice(hash.as_ref());
ret
}

Expand Down

0 comments on commit 9638375

Please sign in to comment.