diff --git a/af_ktls.c b/af_ktls.c index f5be10d..4ed6532 100644 --- a/af_ktls.c +++ b/af_ktls.c @@ -26,6 +26,8 @@ #include #include #include +#include +#include #include "af_ktls.h" @@ -43,9 +45,9 @@ #define KTLS_NONCE_SIZE 8 #define KTLS_DATA_PAGES (KTLS_MAX_PAYLOAD_SIZE / PAGE_SIZE) -// +1 for header, +1 for tag +/* +1 for header, +1 for tag */ #define KTLS_VEC_SIZE (KTLS_DATA_PAGES + 2) -// +1 for aad, +1 for tag, +1 for chaining +/* +1 for aad, +1 for tag, +1 for chaining */ #define KTLS_SG_DATA_SIZE (KTLS_DATA_PAGES + 3) /* @@ -66,22 +68,24 @@ #define KTLS_TLS_1_2_MAJOR 0x03 #define KTLS_TLS_1_2_MINOR 0x03 -// nonce explicit offset in a record +/* nonce explicit offset in a record */ #define KTLS_TLS_NONCE_OFFSET KTLS_TLS_HEADER_SIZE /* * DTLS related stuff */ #define KTLS_DTLS_HEADER_SIZE 13 -#define KTLS_DTLS_PREPEND_SIZE (KTLS_DTLS_HEADER_SIZE + KTLS_NONCE_SIZE) -#define KTLS_DTLS_OVERHEAD (KTLS_DTLS_PREPEND_SIZE + KTLS_TAG_SIZE) +#define KTLS_DTLS_PREPEND_SIZE (KTLS_DTLS_HEADER_SIZE \ + + KTLS_NONCE_SIZE) +#define KTLS_DTLS_OVERHEAD (KTLS_DTLS_PREPEND_SIZE \ + + KTLS_TAG_SIZE) #define KTLS_DTLS_1_2_MAJOR 0xFE #define KTLS_DTLS_1_2_MINOR 0xFD -// we are handling epoch and seq num as one unit +/* we are handling epoch and seq num as one unit */ #define KTLS_DTLS_SEQ_NUM_OFFSET 3 -// nonce explicit offset in a record +/* nonce explicit offset in a record */ #define KTLS_DTLS_NONCE_OFFSET KTLS_DTLS_HEADER_SIZE /* @@ -130,16 +134,10 @@ (KTLS_TLS_NONCE_OFFSET) : \ (KTLS_DTLS_NONCE_OFFSET)) -/* - * Asynchrous receive handling - */ -#define TLS_CACHE_DISCARD(T) (T->recv_occupied = 0) -#define TLS_CACHE_SIZE(T) (T->recv_occupied) -#define TLS_CACHE_SET_SIZE(T, S) (T->recv_occupied = S) -//#define KTLS_DEBUG +/*#define KTLS_DEBUG */ -#if 1 // TODO: remove once ready to use +#if 1 /* TODO: remove once ready to use */ #ifdef KTLS_DEBUG # define xprintk(...) (do_xprintk(__VA_ARGS__)) # define print_hex(...) (do_print_hex(__VA_ARGS__)) @@ -150,22 +148,25 @@ #define UNUSED(X) ((void) X) - void do_xprintk(const char *fmt, ...) { + void do_xprintk(const char *fmt, ...) 
+ { va_list va; + va_start(va, fmt); - printk("tls: "); + pr_debug("tls: "); vprintk(fmt, va); - printk("\n"); + pr_debug("\n"); va_end(va); } - void do_print_hex(const unsigned char * key, unsigned int keysize) { + void do_print_hex(const unsigned char *key, unsigned int keysize) + { int i = 0; - printk("kdls: hex: "); + pr_debug("kdls: hex: "); for (i = 0; i < keysize; i++) - printk("%02X", (unsigned char)key[i]); - printk("\n"); + pr_debug("%02X", (unsigned char)key[i]); + pr_debug("\n"); } #endif @@ -181,6 +182,13 @@ struct tls_key { size_t saltlen; }; +struct tls_rx_msg { + int full_len; + int accum_len; + int offset; + int early_eaten; +}; + struct tls_sock { /* struct sock must be the very first member */ struct sock sk; @@ -217,29 +225,19 @@ struct tls_sock { struct scatterlist sgaad_send[2]; struct scatterlist sgtag_send[2]; - /* - * Receiving context, rx_lock has to be acquired before socket lock to - * avoid deadlock - */ - struct mutex rx_lock; - struct scatterlist sg_rx_data[KTLS_SG_DATA_SIZE]; - struct kvec vec_recv[KTLS_VEC_SIZE]; - char header_recv[MAX(KTLS_TLS_PREPEND_SIZE, KTLS_DTLS_PREPEND_SIZE)]; + /* Receive */ + struct sk_buff *rx_skb_head; + struct sk_buff **rx_skb_nextp; + unsigned int rx_need_bytes; + struct scatterlist sgin[ALG_MAX_PAGES+1]; char aad_recv[KTLS_AAD_SPACE_SIZE]; - char tag_recv[KTLS_TAG_SIZE]; - struct page *pages_recv; - struct af_alg_sgl sgl_recv[UIO_MAXIOV]; - struct scatterlist sgaad_recv[2]; - struct scatterlist sgtag_recv[2]; + char header_recv[MAX(KTLS_TLS_PREPEND_SIZE, KTLS_DTLS_PREPEND_SIZE)]; /* - * Asynchronous work to cache one record + * Asynchronous work in case of low memory */ struct work_struct recv_work; void (*saved_sk_data_ready)(struct sock *sk); - struct scatterlist sg_rx_async_work[KTLS_SG_DATA_SIZE]; - struct page *pages_work; - size_t recv_occupied; /* * our cipher type and its crypto API representation (e.g. 
"gcm(aes)") @@ -278,19 +276,39 @@ struct tls_sock { struct page *page; } sendpage_ctx; - // TODO: remove once finished benchmarking + /* TODO: remove once finished benchmarking */ unsigned parallel_count_stat; }; + +static inline struct tls_rx_msg *tls_rx_msg(struct sk_buff *skb) +{ + return (struct tls_rx_msg *)((void *)skb->cb + + offsetof(struct qdisc_skb_cb, data)); +} static inline struct tls_sock *tls_sk(struct sock *sk) { return (struct tls_sock *)sk; } +static int tls_do_decryption(const struct tls_sock *tsk, + struct scatterlist *sgin, + struct scatterlist *sgout, + size_t data_len); + +static inline void tls_make_aad(struct tls_sock *tsk, + int recv, + char *buf, + size_t size, + char *nonce_explicit); + +static int tls_post_process(const struct tls_sock *tsk, struct sk_buff *skb); + static void increment_seqno(char *s) { u64 *seqno = (u64 *) s; u64 seq_h = be64_to_cpu(*seqno); + seq_h++; *seqno = cpu_to_be64(seq_h); } @@ -299,7 +317,8 @@ static void tls_free_sendpage_ctx(struct tls_sock *tsk) { size_t i; struct scatterlist *sg; - xprintk("--> %s", __FUNCTION__); + + xprintk("--> %s", __func__); sg = tsk->sendpage_ctx.sg; @@ -326,7 +345,7 @@ static void tls_update_senpage_ctx(struct tls_sock *tsk, size_t size) struct scatterlist *sg; struct scatterlist *sg_start; - xprintk("--> %s", __FUNCTION__); + xprintk("--> %s", __func__); sg = tsk->sendpage_ctx.sg; @@ -340,7 +359,7 @@ static void tls_update_senpage_ctx(struct tls_sock *tsk, size_t size) walked_size = 0; sg_start = sg; put_count = 0; - while (put_count < tsk->sendpage_ctx.used && \ + while (put_count < tsk->sendpage_ctx.used && walked_size + sg_start->length <= size) { walked_size += sg_start->length; put_page(sg_page(sg_start)); @@ -348,7 +367,9 @@ static void tls_update_senpage_ctx(struct tls_sock *tsk, size_t size) put_count++; } - // adjust length and offset so we can send with right offset next time + /* adjust length and offset so we can send with right offset next + * time + */ sg_start->offset += (size - walked_size); sg_start->length -= (size - walked_size); tsk->sendpage_ctx.current_size -= size; @@ -359,30 +380,591 @@ static void tls_update_senpage_ctx(struct tls_sock *tsk, size_t size) * can use whole scatterlist next time */ memmove(sg, sg_start, - (KTLS_SG_DATA_SIZE - 1 - put_count)*sizeof(tsk->sendpage_ctx.sg[0])); + (KTLS_SG_DATA_SIZE - 1 - put_count)*sizeof( + tsk->sendpage_ctx.sg[0])); sg_mark_end(&sg[tsk->sendpage_ctx.used]); } -#include "dtls-window.c" +/*Must be called with socket callback locked */ +static void tls_unattach(struct tls_sock *tsk) +{ + tsk->socket->sk->sk_data_ready = tsk->saved_sk_data_ready; + tsk->socket->sk->sk_user_data = NULL; +} + +static void tls_err_abort(struct tls_sock *tsk) +{ + struct sock *sk; + + sk = (struct sock *)tsk; + tsk->rx_stopped = 1; + sk->sk_err = -EBADMSG; + sk->sk_error_report(sk); + tsk->saved_sk_data_ready(tsk->socket->sk); + tls_unattach(tsk); +} + +static int decrypt_skb(struct tls_sock *tsk, struct sk_buff *skb) +{ + int ret, nsg; + size_t prepend, overhead; + struct tls_rx_msg *rxm; + + prepend = IS_TLS(tsk) ? KTLS_TLS_PREPEND_SIZE : KTLS_DTLS_PREPEND_SIZE; + overhead = IS_TLS(tsk) ? KTLS_TLS_OVERHEAD : KTLS_DTLS_OVERHEAD; + rxm = tls_rx_msg(skb); + + + sg_init_table(tsk->sgin, ARRAY_SIZE(tsk->sgin)); + sg_set_buf(&tsk->sgin[0], tsk->aad_recv, sizeof(tsk->aad_recv)); + + /* + * TODO: So what exactly happens if skb_to_sgvec causes more + * than ALG_MAX_PAGES fragments? Consider allocating kernel + * pages + * tls_read_size already copied headers and aad. 
Therefore + * this simply needs to pass the encrypted data + message + */ + nsg = skb_to_sgvec(skb, &tsk->sgin[1], rxm->offset + + prepend, + rxm->full_len - prepend); + + /* + * The number of sg entries passed into decryption must not exceed + * ALG_MAX_PAGES. The aad takes the first sg, so the + * payload may use at most ALG_MAX_PAGES - 1 + */ + if (nsg > ALG_MAX_PAGES - 1) { + ret = -EBADMSG; + goto decryption_fail; + } + + tls_make_aad(tsk, 1, tsk->aad_recv, + rxm->full_len - overhead, + tsk->iv_recv); + + /* + * Decrypt in place. + * After this function call, the decrypted data will start at + * rxm->offset. We must therefore account for the fact that + * the lengths of skbuff_in and skbuff_out are different + */ + + ret = tls_do_decryption(tsk, + tsk->sgin, + tsk->sgin, + rxm->full_len - overhead); + + if (ret < 0) + goto decryption_fail; + + + ret = tls_post_process(tsk, skb); + + if (ret < 0) + goto decryption_fail; + + return 0; +decryption_fail: + return ret; +} +/* + * Returns the length of the unencrypted message, plus overhead + * Note that this function also populates tsk->header_recv, which is later + * used for decryption + * TODO: Revisit + * In TLS we automatically bail if we see a non-TLS message. In DTLS + * we should determine if we got a corrupted message vs. a control msg + * Right now, if the TLS magic byte got corrupted, it would incorrectly + * be treated as a non-TLS message + * Returns 0 if more data is necessary to determine length + * Returns <0 if an error occurred + */ +static inline ssize_t tls_read_size(struct tls_sock *tsk, struct sk_buff *skb) +{ + int ret; + size_t data_len = 0; + size_t datagram_len; + size_t prepend; + char first_byte; + char *header; + struct tls_rx_msg *rxm; + + prepend = IS_TLS(tsk) ? KTLS_TLS_PREPEND_SIZE : KTLS_DTLS_PREPEND_SIZE; + header = tsk->header_recv; + xprintk("--> %s", __func__); + + rxm = tls_rx_msg(skb); + + ret = skb_copy_bits(skb, rxm->offset, &first_byte, 1); + if (ret < 0) + goto read_failure; + + /* Check the first byte to see if it's a TLS record */ + if (first_byte != KTLS_RECORD_DATA) { + ret = -EBADMSG; + goto read_failure; + } + + /* + * We have a TLS record. Check that msglen is long enough to read + * the length of the record. + * We must not check this before checking the first byte, since + * that would cause unencrypted + * messages shorter than KTLS_TLS_PREPEND_SIZE to not be read + */ + if (rxm->offset + prepend > skb->len) { + ret = 0; + goto read_failure; + } + + /* + * Copy the header to use later in decryption. + * An optimization could be to zero-copy, but you'd + * have to be able to walk frag_lists. This function call + * takes care of that. + * Overhead is relatively small (13 bytes for TLS, 21 for DTLS) + */ + ret = skb_copy_bits(skb, rxm->offset, tsk->header_recv, + prepend); + + if (ret < 0) + goto read_failure; + + if (IS_TLS(tsk)) { + data_len = ((header[4] & 0xFF) | (header[3] << 8)); + data_len = data_len - KTLS_TAG_SIZE - KTLS_IV_SIZE; + datagram_len = data_len + KTLS_TLS_OVERHEAD; + } else { + data_len = ((header[12] & 0xFF) | (header[11] << 8)); + data_len = data_len - KTLS_TAG_SIZE - KTLS_IV_SIZE; + datagram_len = data_len + KTLS_DTLS_OVERHEAD; + } + + if (data_len > KTLS_MAX_PAYLOAD_SIZE) { + ret = -E2BIG; + goto read_failure; + } + return datagram_len; + +read_failure: + return ret; +} + +/* Lower socket lock held */ +/* Returns the number of bytes used */ +/* A lot of this code was copied from the KCM code.
Consider abstracting this */ +static int tls_tcp_recv(read_descriptor_t *desc, struct sk_buff *orig_skb, + unsigned int orig_offset, size_t orig_len) +{ + struct tls_sock *tsk = (struct tls_sock *)desc->arg.data; + struct sock *sk = (struct sock *)tsk; + struct tls_rx_msg *rxm; + struct sk_buff *head, *skb; + size_t eaten = 0, uneaten; + ssize_t extra; + int err; + bool cloned_orig = false; + + head = tsk->rx_skb_head; + if (head) { + /* We seem to be in the middle of a message */ + + rxm = tls_rx_msg(head); + + if (unlikely(rxm->early_eaten)) { + /* Already some number of bytes on the receive sock + * data saved in rx_skb_head, just indicate they + * are consumed. + */ + eaten = orig_len <= rxm->early_eaten ? + orig_len : rxm->early_eaten; + rxm->early_eaten -= eaten; + + return eaten; + } + + if (unlikely(orig_offset)) { + /* Getting data with a non-zero offset when a message is + * in progress is not expected. If it does happen, we + * need to clone and pull since we can't deal with + * offsets in the skbs for a message expect in the head. + */ + orig_skb = skb_clone(orig_skb, GFP_ATOMIC); + if (!orig_skb) { + desc->error = -ENOMEM; + return 0; + } + if (!pskb_pull(orig_skb, orig_offset)) { + kfree_skb(orig_skb); + desc->error = -ENOMEM; + return 0; + } + cloned_orig = true; + orig_offset = 0; + } + + /*We are appending to head. Unshare the frag list */ + if (!tsk->rx_skb_nextp) { + err = skb_unclone(head, GFP_ATOMIC); + if (err) { + desc->error = err; + return 0; + } + + if (unlikely(skb_shinfo(head)->frag_list)) { + /* We can't append to an sk_buff that already + * has a frag_list. We create a new head, point + * the frag_list of that to the old head, and + * then are able to use the old head->next for + * appending to the message. + */ + if (WARN_ON(head->next)) { + desc->error = -EINVAL; + return 0; + } + + skb = alloc_skb(0, GFP_ATOMIC); + if (!skb) { + desc->error = -ENOMEM; + return 0; + } + skb->len = head->len; + skb->data_len = head->len; + skb->truesize = head->truesize; + *tls_rx_msg(skb) = *tls_rx_msg(head); + tsk->rx_skb_nextp = &head->next; + skb_shinfo(skb)->frag_list = head; + tsk->rx_skb_head = skb; + head = skb; + } else { + tsk->rx_skb_nextp = + &skb_shinfo(head)->frag_list; + } + } + } + + while (eaten < orig_len) { + int ret; + /* Always clone since we will consume something */ + skb = skb_clone(orig_skb, GFP_ATOMIC); + if (!skb) { + desc->error = -ENOMEM; + break; + } + + uneaten = orig_len - eaten; + + head = tsk->rx_skb_head; + /*head is null */ + if (!head) { + head = skb; + tsk->rx_skb_head = head; + /* Will set rx_skb_nextp on next packet if needed */ + tsk->rx_skb_nextp = NULL; + rxm = tls_rx_msg(head); + memset(rxm, 0, sizeof(*rxm)); + rxm->offset = orig_offset + eaten; + } else { + /*head not null */ + + rxm = tls_rx_msg(head); + *tsk->rx_skb_nextp = skb; + tsk->rx_skb_nextp = &skb->next; + head->data_len += skb->len; + head->len += skb->len; + head->truesize += skb->truesize; + } + if (!rxm->full_len) { + ssize_t len; + + len = tls_read_size(tsk, head); + + /*Is this a sane packet? 
*/ + if (!len) { + /* Need more header to determine length */ + + rxm->accum_len += uneaten; + eaten += uneaten; + WARN_ON(eaten != orig_len); + break; + } else if (len < 0) { + /* Data does not appear to be a TLS record + * Make userspace handle it + */ + goto decryption_fail; + } else if (len <= (ssize_t)head->len - + skb->len - rxm->offset) { + /* Length must be into new skb (and also + * greater than zero) + */ + goto decryption_fail; + } + + rxm->full_len = len; + } + + extra = (ssize_t)(rxm->accum_len + uneaten) - rxm->full_len; + + if (extra < 0) { + /* Message not complete yet. */ + if (rxm->full_len - rxm->accum_len > + tcp_inq((struct sock *)tsk)) { + /* Don't have the whole messages in the socket + * buffer. Set tsk->rx_need_bytes to wait for + * the rest of the message. Also, set "early + * eaten" since we've already buffered the skb + * but don't consume yet per tcp_read_sock. + * If function returns 0, does not consume + */ + + /* Wait. Why doesn't this code path just set + * eaten? Then tcp_read_sock will eat and + * profit! + */ + if (!rxm->accum_len) { + /* Start RX timer for new message */ + /*kcm_start_rx_timer(tsk); */ + } + + tsk->rx_need_bytes = rxm->full_len - + rxm->accum_len; + rxm->accum_len += uneaten; + rxm->early_eaten = uneaten; + desc->count = 0; /* Stop reading socket */ + break; + } + rxm->accum_len += uneaten; + eaten += uneaten; + WARN_ON(eaten != orig_len); + break; + } + + /* Positive extra indicates ore bytes than needed for the + * message + */ + + WARN_ON(extra > uneaten); + + ret = decrypt_skb(tsk, head); + if (ret < 0) + goto decryption_fail; + + /* Hurray, we have a new message! */ + tsk->rx_skb_head = NULL; + eaten += (uneaten - extra); + + sock_queue_rcv_skb(sk, head); + } +done: + if (cloned_orig) + kfree_skb(orig_skb); + return eaten; +decryption_fail: + kfree_skb(skb); + desc->error = -EBADMSG; + tsk->rx_skb_head = NULL; + desc->count = 0; + goto done; +} + +static int tls_tcp_read_sock(struct tls_sock *tsk) +{ + read_descriptor_t desc; + + desc.arg.data = tsk; + desc.error = 0; + desc.count = 1; /* give more than one skb per call */ + + /* sk should be locked here, so okay to do tcp_read_sock */ + tcp_read_sock(tsk->socket->sk, &desc, tls_tcp_recv); + + return desc.error; +} + +static void do_tls_data_ready(struct tls_sock *tsk) +{ + int ret; + + if (tsk->rx_need_bytes) { + if (tcp_inq(tsk->socket->sk) >= tsk->rx_need_bytes) + tsk->rx_need_bytes = 0; + else + return; + } + + ret = tls_tcp_read_sock(tsk); + if (ret == -ENOMEM) /* No memory. Do it later */ + queue_work(tls_wq, &tsk->recv_work); + /*queue_delayed_work(tls_wq, &tsk->recv_work, 0); */ + + /* TLS couldn't handle this message. Pass it directly to userspace */ + else if (ret == -EBADMSG) + tls_err_abort(tsk); + +} + +/* Called with lower socket held */ static void tls_data_ready(struct sock *sk) { struct tls_sock *tsk; - xprintk("--> %s", __FUNCTION__); + xprintk("--> %s", __func__); + + read_lock_bh(&sk->sk_callback_lock); + //TODO: forgot to lock tsk? + tsk = (struct tls_sock *)sk->sk_user_data; + if (unlikely(!tsk || tsk->rx_stopped)) + goto out; + + if (!KTLS_RECV_READY(tsk)) { + queue_work(tls_wq, &tsk->recv_work); + goto out; + } + + do_tls_data_ready(tsk); + +out: + read_unlock_bh(&sk->sk_callback_lock); +} + +#include "dtls-window.c" + +//Loop through the SKBs. 
Decrypt each one and, if valid, add it to recv queue +static int dtls_udp_read_sock(struct tls_sock *tsk) +{ + struct sk_buff *p, *next, *skb; + int ret = 0; + + skb_queue_walk_safe(&tsk->socket->sk->sk_receive_queue, p, next) { + + ssize_t len; + struct tls_rx_msg *rxm; + + rxm = tls_rx_msg(p); + memset(rxm, 0, sizeof(*rxm)); + + /* For UDP, set the offset such that the headers are ignored. + * Full_len is length of skb minus the headers + */ + rxm->full_len = p->len - sizeof(struct udphdr); + rxm->offset = sizeof(struct udphdr); + len = tls_read_size(tsk, p); + + /*Is this a sane packet? */ + WARN_ON(!len); + if (!len) + goto record_pop; + if (len < 0) { + if (len == -EBADMSG) { + /* Data does not appear to be a TLS record + * Make userspace handle it + * TODO: Of course, we are using DTLS therefore + * it may be that the headers were corrupted + * + */ + ret = -EBADMSG; + break; + } + /* Failed for some other reason. Drop the packet */ + goto record_pop; + } + if (dtls_window(tsk, tsk->header_recv + KTLS_DTLS_SEQ_NUM_OFFSET) < 0) + goto record_pop; + + skb = skb_clone(p, GFP_ATOMIC); + if (!skb) { + ret = -ENOMEM; + break; + } + ret = decrypt_skb(tsk, skb); + + if (ret < 0) + goto decryption_fail; + + sock_queue_rcv_skb((struct sock *)tsk, skb); + goto record_pop; +decryption_fail: + kfree_skb(skb); +record_pop: + skb_unlink(p, &tsk->socket->sk->sk_receive_queue); + kfree_skb(p); + } + return ret; + +} + +static void do_dtls_data_ready(struct tls_sock *tsk) +{ + int ret; + + ret = dtls_udp_read_sock(tsk); + if (ret == -ENOMEM) /* No memory. Do it later */ + queue_work(tls_wq, &tsk->recv_work); + /*queue_delayed_work(tls_wq, &tsk->recv_work, 0); */ + /* TLS couldn't handle this message. Pass it directly to userspace */ + else if (ret == -EBADMSG) + tls_err_abort(tsk); +} +/* Called with lower socket held */ +static void dtls_data_ready(struct sock *sk) +{ + struct tls_sock *tsk; + xprintk("--> %s", __func__); read_lock_bh(&sk->sk_callback_lock); tsk = (struct tls_sock *)sk->sk_user_data; - if (unlikely(!tsk || tsk->rx_stopped)) { + if (unlikely(!tsk || tsk->rx_stopped)) + goto out; + + if (!KTLS_RECV_READY(tsk)) { + queue_work(tls_wq, &tsk->recv_work); goto out; } - queue_work(tls_wq, &tsk->recv_work); - out: + do_dtls_data_ready(tsk); +out: read_unlock_bh(&sk->sk_callback_lock); } +static void do_tls_sock_rx_work(struct tls_sock *tsk) +{ + struct sock *sk = tsk->socket->sk; + + /* We need the read lock to synchronize with tls_sock_tcp_data_ready. + * We need the socket lock for calling tcp_read_sock. 
+ */ + lock_sock(sk); + read_lock_bh(&sk->sk_callback_lock); + + if (unlikely(sk->sk_user_data != tsk)) + goto out; + + if (unlikely(tsk->rx_stopped)) + goto out; + + if (!KTLS_RECV_READY(tsk)) { + queue_work(tls_wq, &tsk->recv_work); + goto out; + } + + if (IS_TLS(tsk)) + do_tls_data_ready(tsk); + else + do_dtls_data_ready(tsk); + +out: + read_unlock_bh(&sk->sk_callback_lock); + release_sock(sk); +} + +static void tls_rx_work(struct work_struct *w) +{ + do_tls_sock_rx_work(container_of(w, struct tls_sock, recv_work)); +} + static int tls_set_iv(struct socket *sock, int recv, char __user *src, @@ -393,7 +975,7 @@ static int tls_set_iv(struct socket *sock, struct sock *sk; struct tls_sock *tsk; - xprintk("--> %s", __FUNCTION__); + xprintk("--> %s", __func__); sk = sock->sk; tsk = tls_sk(sk); @@ -425,7 +1007,7 @@ static int tls_init_aead(struct tls_sock *tsk, int recv) char keyval[KTLS_KEY_SIZE + KTLS_SALT_SIZE]; size_t keyval_len; - xprintk("--> %s", __FUNCTION__); + xprintk("--> %s", __func__); k = recv ? &tsk->key_recv : &tsk->key_send; aead = recv ? tsk->aead_recv : tsk->aead_send; @@ -452,6 +1034,7 @@ static int tls_init_aead(struct tls_sock *tsk, int recv) return ret ?: 0; } +/*TODO: No lock? */ static int tls_set_key(struct socket *sock, int recv, char __user *src, @@ -461,7 +1044,7 @@ static int tls_set_key(struct socket *sock, struct tls_sock *tsk; struct tls_key *k; - xprintk("--> %s", __FUNCTION__); + xprintk("--> %s", __func__); tsk = tls_sk(sock->sk); @@ -474,8 +1057,7 @@ static int tls_set_key(struct socket *sock, k = recv ? &tsk->key_recv : &tsk->key_send; if (src_len > k->keylen) { - if (k->keylen) - kfree(k->key); + kfree(k->key); k->key = kmalloc(src_len, GFP_KERNEL); if (!k->key) return -ENOMEM; @@ -502,7 +1084,7 @@ static int tls_set_salt(struct socket *sock, struct tls_sock *tsk; struct tls_key *k; - xprintk("--> %s", __FUNCTION__); + xprintk("--> %s", __func__); tsk = tls_sk(sock->sk); @@ -529,7 +1111,7 @@ static int tls_set_mtu(struct socket *sock, char __user *src, size_t src_len) size_t mtu; struct tls_sock *tsk; - xprintk("--> %s", __FUNCTION__); + xprintk("--> %s", __func__); tsk = tls_sk(sock->sk); @@ -553,6 +1135,18 @@ static int tls_set_mtu(struct socket *sock, char __user *src, size_t src_len) return mtu; } +static void tls_do_unattach(struct socket *sock) +{ + struct tls_sock *tsk; + struct sock *sk; + + tsk = tls_sk(sock->sk); + sk = tsk->socket->sk; + + read_lock_bh(&sk->sk_callback_lock); + tls_err_abort(tsk); + read_unlock_bh(&sk->sk_callback_lock); +} static int tls_setsockopt(struct socket *sock, int level, int optname, char __user *optval, @@ -561,7 +1155,7 @@ static int tls_setsockopt(struct socket *sock, int ret; struct tls_sock *tsk; - xprintk("--> %s", __FUNCTION__); + xprintk("--> %s", __func__); tsk = tls_sk(sock->sk); if (level != AF_KTLS) @@ -575,40 +1169,35 @@ static int tls_setsockopt(struct socket *sock, switch (optname) { - case KTLS_SET_IV_RECV: - ret = tls_set_iv(sock, 1, optval, optlen); - break; - case KTLS_SET_KEY_RECV: - ret = tls_set_key(sock, 1, optval, optlen); - break; - case KTLS_SET_SALT_RECV: - ret = tls_set_salt(sock, 1, optval, optlen); - break; - case KTLS_SET_IV_SEND: - ret = tls_set_iv(sock, 0, optval, optlen); - break; - case KTLS_SET_KEY_SEND: - ret = tls_set_key(sock, 0, optval, optlen); - break; - case KTLS_SET_SALT_SEND: - ret = tls_set_salt(sock, 0, optval, optlen); - break; - case KTLS_SET_MTU: - ret = tls_set_mtu(sock, optval, optlen); - break; - default: - break; + case KTLS_SET_IV_RECV: + ret = tls_set_iv(sock, 1, 
optval, optlen); + break; + case KTLS_SET_KEY_RECV: + ret = tls_set_key(sock, 1, optval, optlen); + break; + case KTLS_SET_SALT_RECV: + ret = tls_set_salt(sock, 1, optval, optlen); + break; + case KTLS_SET_IV_SEND: + ret = tls_set_iv(sock, 0, optval, optlen); + break; + case KTLS_SET_KEY_SEND: + ret = tls_set_key(sock, 0, optval, optlen); + break; + case KTLS_SET_SALT_SEND: + ret = tls_set_salt(sock, 0, optval, optlen); + break; + case KTLS_SET_MTU: + ret = tls_set_mtu(sock, optval, optlen); + break; + case KTLS_UNATTACH: + tls_do_unattach(sock); + ret = 0; + break; + default: + break; } - /* - * We need to discard cache every time there is a change on socket - * not to be in an invalid state - */ - TLS_CACHE_DISCARD(tsk); - /* - * The same applies to DTLS window - */ - DTLS_WINDOW_INIT(tsk->dtls_window); setsockopt_end: release_sock(sock->sk); return ret < 0 ? ret : 0; @@ -622,7 +1211,7 @@ static int tls_get_iv(const struct tls_sock *tsk, int ret; char *iv; - xprintk("--> %s", __FUNCTION__); + xprintk("--> %s", __func__); if (dst_len < KTLS_IV_SIZE) return -ENOMEM; @@ -647,7 +1236,7 @@ static int tls_get_key(const struct tls_sock *tsk, int ret; const struct tls_key *k; - xprintk("--> %s", __FUNCTION__); + xprintk("--> %s", __func__); k = recv ? &tsk->key_recv : &tsk->key_send; @@ -670,7 +1259,7 @@ static int tls_get_salt(const struct tls_sock *tsk, int ret; const struct tls_key *k; - xprintk("--> %s", __FUNCTION__); + xprintk("--> %s", __func__); k = recv ? &tsk->key_recv : &tsk->key_send; @@ -696,7 +1285,7 @@ static int tls_getsockopt(struct socket *sock, size_t mtu; const struct tls_sock *tsk; - xprintk("--> %s", __FUNCTION__); + xprintk("--> %s", __func__); tsk = tls_sk(sock->sk); @@ -772,7 +1361,7 @@ static inline void tls_make_prepend(struct tls_sock *tsk, { size_t pkt_len; - xprintk("--> %s", __FUNCTION__); + xprintk("--> %s", __func__); pkt_len = plaintext_len + KTLS_IV_SIZE + KTLS_TAG_SIZE; @@ -804,9 +1393,9 @@ static inline void tls_make_aad(struct tls_sock *tsk, size_t size, char *nonce_explicit) { - xprintk("--> %s", __FUNCTION__); + xprintk("--> %s", __func__); - // has to be zero padded according to RFC5288 + /* has to be zero padded according to RFC5288 */ memset(buf, 0, KTLS_AAD_SPACE_SIZE); memcpy(buf, nonce_explicit, KTLS_NONCE_SIZE); @@ -818,26 +1407,6 @@ static inline void tls_make_aad(struct tls_sock *tsk, buf[12] = size & 0xFF; } -static inline void tls_pop_record(struct tls_sock *tsk, size_t data_len) -{ - int ret; - struct msghdr msg = {}; - - xprintk("--> %s", __FUNCTION__); - - if (IS_TCP(tsk->socket)) { - ret = kernel_recvmsg(tsk->socket, &msg, - tsk->vec_recv, KTLS_VEC_SIZE, - KTLS_RECORD_SIZE(tsk, data_len), MSG_TRUNC); - WARN_ON(ret != KTLS_RECORD_SIZE(tsk, data_len)); - } else { /* UDP */ - ret = kernel_recvmsg(tsk->socket, - &msg, tsk->vec_recv, KTLS_VEC_SIZE, - /*size*/0, /*flags*/0); - WARN_ON(ret != 0); - } -} - static int tls_do_encryption(struct tls_sock *tsk, struct scatterlist *sgin, struct scatterlist *sgout, @@ -849,7 +1418,7 @@ static int tls_do_encryption(struct tls_sock *tsk, struct aead_request *aead_req = (void *)aead_req_data; struct af_alg_completion completion; - xprintk("--> %s", __FUNCTION__); + xprintk("--> %s", __func__); aead_request_set_tfm(aead_req, tsk->aead_send); aead_request_set_ad(aead_req, KTLS_PADDED_AAD_SIZE); @@ -860,13 +1429,15 @@ static int tls_do_encryption(struct tls_sock *tsk, &completion); } +/*TODO: Avoid kernel_sendmsg */ static int tls_sendmsg(struct socket *sock, struct msghdr *msg, size_t size) { struct tls_sock 
*tsk; unsigned int i; unsigned int cnt = 0; int ret = 0; - xprintk("--> %s", __FUNCTION__); + + xprintk("--> %s", __func__); tsk = tls_sk(sock->sk); lock_sock(sock->sk); @@ -881,26 +1452,30 @@ static int tls_sendmsg(struct socket *sock, struct msghdr *msg, size_t size) goto send_end; } - // TODO: handle flags, see issue #4 + /* TODO: handle flags, see issue #4 */ tls_make_aad(tsk, 0, tsk->aad_send, size, tsk->iv_send); while (iov_iter_count(&msg->msg_iter)) { size_t seglen = iov_iter_count(&msg->msg_iter); - int len = af_alg_make_sg(&tsk->sgl_send[cnt], &msg->msg_iter, seglen); + int len = af_alg_make_sg(&tsk->sgl_send[cnt], + &msg->msg_iter, seglen); if (len < 0) goto send_end; ret += len; if (cnt) - af_alg_link_sg(&tsk->sgl_send[cnt-1], &tsk->sgl_send[cnt]); + af_alg_link_sg(&tsk->sgl_send[cnt-1], + &tsk->sgl_send[cnt]); iov_iter_advance(&msg->msg_iter, len); cnt++; } sg_unmark_end(&tsk->sgaad_send[1]); sg_chain(tsk->sgaad_send, 2, tsk->sgl_send[0].sg); - sg_unmark_end(tsk->sgl_send[cnt-1].sg + tsk->sgl_send[cnt-1].npages - 1); - sg_chain(tsk->sgl_send[cnt-1].sg, tsk->sgl_send[cnt-1].npages + 1, tsk->sgtag_send); + sg_unmark_end(tsk->sgl_send[cnt-1].sg + + tsk->sgl_send[cnt-1].npages - 1); + sg_chain(tsk->sgl_send[cnt-1].sg, tsk->sgl_send[cnt-1].npages + 1, + tsk->sgtag_send); ret = tls_do_encryption(tsk, tsk->sgaad_send, tsk->sg_tx_data, size); if (ret < 0) @@ -935,13 +1510,13 @@ static int tls_do_decryption(const struct tls_sock *tsk, struct aead_request *aead_req = (void *)aead_req_data; struct af_alg_completion completion; - xprintk("--> %s", __FUNCTION__); + xprintk("--> %s", __func__); aead_request_set_tfm(aead_req, tsk->aead_recv); aead_request_set_ad(aead_req, KTLS_PADDED_AAD_SIZE); aead_request_set_crypt(aead_req, sgin, sgout, data_len + KTLS_TAG_SIZE, - (u8*)tsk->header_recv + KTLS_NONCE_OFFSET(tsk)); + (u8 *)tsk->header_recv + KTLS_NONCE_OFFSET(tsk)); ret = af_alg_wait_for_completion( crypto_aead_decrypt(aead_req), @@ -950,416 +1525,276 @@ static int tls_do_decryption(const struct tls_sock *tsk, return ret; } -static int tls_post_process(const struct tls_sock *tsk, struct scatterlist *sgl) +static int tls_post_process(const struct tls_sock *tsk, struct sk_buff *skb) { - /* Placeholder for future extensibility via BPF or something similar. - * On several occasions we need to act based on the contents of the - * decrypted data (e.g., give control directly to userspace with a special - * error, etc). This will allow for protocols like Openconnect VPN to - * use this framework and handle specially the packets which are not data. + size_t prepend, overhead; + struct tls_rx_msg *rxm; + + prepend = IS_TLS(tsk) ? KTLS_TLS_PREPEND_SIZE : KTLS_DTLS_PREPEND_SIZE; + overhead = IS_TLS(tsk) ? KTLS_TLS_OVERHEAD : KTLS_DTLS_OVERHEAD; + rxm = tls_rx_msg(skb); + + /* The crypto API does the following transformation. + * Before: + * AAD(13) | DATA | TAG + * After: + * AAD(13) | DECRYPTED | TAG + * The AAD and TAG is left untouched. However we don't want that + * returned to the user. 
Therefore we fix the offsets and lengths */ - + rxm->offset += prepend; + rxm->full_len -= overhead; + increment_seqno(tsk->iv_recv); return 0; } -static inline ssize_t tls_peek_data(struct tls_sock *tsk, unsigned flags) -{ - int ret; - ssize_t peeked_size; - size_t data_len = 0; - size_t datagram_len; - struct msghdr msg = {}; - char *header; +static unsigned int tls_poll(struct file *file, struct socket *sock, + struct poll_table_struct *wait) { + + unsigned int ret; + struct tls_sock *tsk; + unsigned int mask; + struct sock *sk; - xprintk("--> %s", __FUNCTION__); + sk = sock->sk; + tsk = tls_sk(sock->sk); /* - * we need to peek first, so we know what will be received, we have to - * handle DTLS window here as well, since this is the only function that - * does actual recv + * Call POLL on the underlying socket, which will call sock_poll_wait + * on the underlying socket. Used for POLLOUT and POLLHUP. + * TODO: Should we be passing the underlying socket's file in? */ - do { - peeked_size = kernel_recvmsg(tsk->socket, &msg, - tsk->vec_recv, KTLS_VEC_SIZE, - KTLS_RECORD_SIZE(tsk, KTLS_MAX_PAYLOAD_SIZE), - MSG_PEEK | flags); - - if (peeked_size < 0) { - ret = peeked_size; - goto peek_failure; - } + ret = tsk->socket->ops->poll(tsk->socket->file, tsk->socket, wait); - header = tsk->header_recv; - // we handle only application data, let user space decide what - // to do otherwise - // - if (header[0] != KTLS_RECORD_DATA) { - ret = -EBADF; - goto peek_failure; - } - - if (IS_TLS(tsk)) { - data_len = ((header[4] & 0xFF) | (header[3] << 8)); - data_len = data_len - KTLS_TAG_SIZE - KTLS_IV_SIZE; - datagram_len = data_len + KTLS_TLS_OVERHEAD; - } else { - data_len = ((header[12] & 0xFF) | (header[11] << 8)); - data_len = data_len - KTLS_TAG_SIZE - KTLS_IV_SIZE; - datagram_len = data_len + KTLS_DTLS_OVERHEAD; - } + /* + * Clear the POLLIN bits. Data available in the underlying socket is not + * necessarily ready to be read. The data could still be in the process + * of decryption, or it could be meant for the original fd. + */ + ret &= ~(POLLIN | POLLRDNORM); - if (data_len > KTLS_MAX_PAYLOAD_SIZE) { - ret = -EMSGSIZE; - goto peek_failure; - } + /* + * Used for POLLIN. + * Call generic POLL on the TLS socket, which works for any socket provided + * the socket receive queue only ever holds data that is ready to receive. + * Data ready to be read is stored in KTLS's sk_receive_queue + */ + mask = datagram_poll(file, sock, wait); - if (IS_TLS(tsk)) { - if (datagram_len > peeked_size) { - ret = -EFAULT; // TODO: consider returning ENOMEM - goto peek_failure; - } - } else { - if (datagram_len != peeked_size) { - ret = -EFAULT; - goto peek_failure; - } - } - } while (IS_DTLS(tsk) && - !dtls_window(tsk, tsk->header_recv + KTLS_DTLS_SEQ_NUM_OFFSET)); + /* + * Clear the POLLOUT and POLLHUP bits. Even if KTLS is ready to send, data + * won't be sent if the underlying socket is not ready. In addition, + * even if KTLS was initialized as a stream socket, it's not actually + * connected to anything, so we ignore its POLLHUP. + * Also, we don't support priority band writes in KTLS + */ + mask &= ~(POLLOUT | POLLWRNORM | POLLHUP); - return data_len; + ret |= mask; -peek_failure: + /* + * POLLERR should be returned if either socket has a pending error.
+ * We don't support high-priority data atm, so clear those bits + */ + ret &= ~(POLLWRBAND | POLLRDBAND); return ret; } -static void tls_rx_async_work(struct work_struct *w) +static struct sk_buff *tls_wait_data(struct tls_sock *tsk, int flags, + long timeo, int *err) { - int ret; - ssize_t data_len; + struct sk_buff *skb; struct sock *sk; - struct tls_sock *tsk = container_of(w, struct tls_sock, recv_work); - - sk = (struct sock*) tsk; - xprintk("--> %s", __FUNCTION__); + sk = (struct sock *)tsk; - if (!KTLS_RECV_READY(tsk)) - return; - - if (mutex_trylock(&tsk->rx_lock)) { - lock_sock(sk); - read_lock_bh(&sk->sk_callback_lock); - - if (!tsk->socket || tsk->rx_stopped) { - goto rx_work_end; + while (!(skb = skb_peek(&sk->sk_receive_queue))) { + /* Don't clear sk_err since recvmsg may not return + * it immediately. Instead, clear it after the next + * attach + */ + if (sk->sk_err) { + *err = sk->sk_err; + return NULL; } - // already occupied? - if (TLS_CACHE_SIZE(tsk) != 0) - goto rx_work_end; - - tsk->parallel_count_stat++; // TODO: remove - - data_len = tls_peek_data(tsk, MSG_DONTWAIT); - // nothing to process (-EAGAIN) or other error? let user space - // ask for it (do not cache errors) - if (data_len <= 0) - goto rx_work_end; - - tls_make_aad(tsk, 1, tsk->aad_recv, data_len, - tsk->iv_recv); - - ret = tls_do_decryption(tsk, tsk->sg_rx_data, - tsk->sg_rx_async_work, data_len); - if (ret < 0) - goto rx_work_end; - - TLS_CACHE_SET_SIZE(tsk, data_len); + if (sock_flag(sk, SOCK_DONE)) + return NULL; -rx_work_end: - read_unlock_bh(&sk->sk_callback_lock); - release_sock(sk); - mutex_unlock(&tsk->rx_lock); - } else { - // wake up rx queue - tsk->saved_sk_data_ready(tsk->socket->sk); - } -} + if ((flags & MSG_DONTWAIT) || !timeo) { + *err = -EAGAIN; + return NULL; + } -static const struct pipe_buf_operations tls_pipe_buf_ops = { - .can_merge = 0, - .confirm = generic_pipe_buf_confirm, - .release = generic_pipe_buf_release, - .steal = generic_pipe_buf_steal, - .get = generic_pipe_buf_get, -}; + sk_wait_data(sk, &timeo, NULL); -static void tls_spd_release(struct splice_pipe_desc *spd, unsigned int i) -{ - put_page(spd->pages[i]); -} - -static int tls_splice_read_alloc(struct splice_pipe_desc *spd, - size_t data_len) { - int ret; - size_t not_allocated, to_alloc; - size_t pages_needed, i, j; - - pages_needed = data_len / PAGE_SIZE; - if (pages_needed * PAGE_SIZE < data_len) - pages_needed++; - - not_allocated = data_len; - for (i = 0; i < pages_needed; i++) { - to_alloc = min_t(size_t, PAGE_SIZE, not_allocated); - spd->pages[i] = alloc_page(GFP_KERNEL); - if (!spd->pages[i]) { - for (j = 0; j < i; j++) - __free_page(spd->pages[j]); - ret = -ENOMEM; - goto splice_read_alloc_end; + /* Handle signals */ + if (signal_pending(current)) { + *err = sock_intr_errno(timeo); + return NULL; } - - spd->partial[i].len = to_alloc; - spd->partial[i].offset = 0; - spd->partial[i].private = 0; - not_allocated -= to_alloc; } - spd->nr_pages = pages_needed; - spd->nr_pages_max = pages_needed; - - ret = pages_needed; - -splice_read_alloc_end: - return ret; - + return skb; } -static ssize_t tls_splice_read(struct socket *sock, loff_t *ppos, - struct pipe_inode_info *pipe, - size_t size, unsigned int flags) +static int tls_recvmsg(struct socket *sock, + struct msghdr *msg, + size_t len, + int flags) { - ssize_t ret; - size_t copy; - size_t to_assign, assigned; - ssize_t data_len; - size_t i; - struct scatterlist sg[KTLS_DATA_PAGES + 1]; // +1 for chaining + ssize_t copied = 0; + int err = 0; + long timeo; struct 
tls_sock *tsk; - struct page *pages[KTLS_DATA_PAGES + 2]; // +1 for header, +1 for tag - struct partial_page partial[KTLS_DATA_PAGES + 2]; - struct splice_pipe_desc spd = { - .pages = pages, - .partial = partial, - .nr_pages = 0, // assigned bellow - .nr_pages_max = 0, // assigned bellow - .flags = flags, // TODO: handle, see issue #4 - .ops = &tls_pipe_buf_ops, - .spd_release = tls_spd_release, - }; - - xprintk("--> %s", __FUNCTION__); + struct tls_rx_msg *rxm; + int ret = 0; + struct sk_buff *skb; + + xprintk("--> %s", __func__); tsk = tls_sk(sock->sk); - mutex_lock(&tsk->rx_lock); lock_sock(sock->sk); if (!KTLS_RECV_READY(tsk)) { - ret = -EBADMSG; - goto splice_read_end; + err = -EBADMSG; + goto recv_end; } - if (TLS_CACHE_SIZE(tsk) > 0) { // we already received asynchronously - data_len = TLS_CACHE_SIZE(tsk); - - ret = tls_splice_read_alloc(&spd, data_len); - if (ret < 0) - goto splice_read_end; - - for (i = 0; data_len; i++) { - copy = min_t(size_t, - tsk->sg_rx_async_work[i + 1].length, - data_len); - memcpy(page_address(spd.pages[i]), - page_address(sg_page(tsk->sg_rx_async_work + i + 1)), - copy); - - spd.partial[i].len = copy; - spd.partial[i].offset = 0; - spd.partial[i].private = 0; - data_len -= copy; - } - data_len = TLS_CACHE_SIZE(tsk); - - ret = splice_to_pipe(pipe, &spd); - - if (ret > 0) - TLS_CACHE_DISCARD(tsk); - } else { - data_len = tls_peek_data(tsk, 0); - - if (data_len < 0) { - ret = data_len; - goto splice_read_end; - } - - if (data_len > size) { - ret = -EMSGSIZE; - goto splice_read_end; - } - - ret = tls_splice_read_alloc(&spd, data_len); - if (ret < 0) - goto splice_read_end; - - // assign to sg, so we can do decryption - sg_init_table(sg, ret + 1); - to_assign = data_len; - for (i = 0; to_assign; i ++) { - assigned = min_t(size_t, PAGE_SIZE, to_assign); - sg_set_page(sg + i, spd.pages[i], assigned, 0); - to_assign -= assigned; + timeo = sock_rcvtimeo(&tsk->sk, flags & MSG_DONTWAIT); + do { + int chunk; + /*TODO: Consider helping with decryption */ + skb = tls_wait_data(tsk, flags, timeo, &err); + if (!skb) + goto recv_end; + rxm = tls_rx_msg(skb); + chunk = min_t(unsigned int, rxm->full_len, len); + err = skb_copy_datagram_msg(skb, rxm->offset, msg, chunk); + if (err < 0) + goto recv_end; + copied += chunk; + len -= chunk; + if (likely(!(flags & MSG_PEEK))) { + if (copied < rxm->full_len) { + rxm->offset += copied; + rxm->full_len -= copied; + } else { + /* Finished with message */ + skb_unlink(skb, &((struct sock *)tsk) + ->sk_receive_queue); + kfree_skb(skb); + } } - sg_chain(tsk->sgaad_recv, 2, sg); - sg_unmark_end(&sg[ret - 1]); - sg_chain(sg, ret + 1, tsk->sgtag_recv); - - tls_make_aad(tsk, 1, tsk->aad_recv, data_len, - tsk->iv_recv); - - ret = tls_do_decryption(tsk, tsk->sg_rx_data, - tsk->sgaad_recv, data_len); - if (ret < 0) - goto splice_read_end; - - ret = splice_to_pipe(pipe, &spd); - } - if (ret > 0) { - increment_seqno(tsk->iv_recv); - tls_pop_record(tsk, data_len); - } - -splice_read_end: - // restore chaining for receiving - sg_chain(tsk->sgaad_recv, 2, tsk->sgl_recv[0].sg); + } while (len); - if (ret > 0) - queue_work(tls_wq, &tsk->recv_work); +recv_end: release_sock(sock->sk); - mutex_unlock(&tsk->rx_lock); - + ret = copied ? 
: err; return ret; } -static int tls_recvmsg(struct socket *sock, +static int dtls_recvmsg(struct socket *sock, struct msghdr *msg, - size_t size, + size_t len, int flags) { - int i; - size_t copy, copied; - ssize_t data_len; + ssize_t copied = 0; + int err; struct tls_sock *tsk; + struct tls_rx_msg *rxm; int ret = 0; - unsigned int cnt = 0; + struct sk_buff *skb; - xprintk("--> %s", __FUNCTION__); + xprintk("--> %s", __func__); tsk = tls_sk(sock->sk); - mutex_lock(&tsk->rx_lock); lock_sock(sock->sk); if (!KTLS_RECV_READY(tsk)) { - ret = -EBADMSG; + err = -EBADMSG; goto recv_end; } - // TODO: handle flags, see issue #4 - - if (TLS_CACHE_SIZE(tsk) > 0) { - if (size < TLS_CACHE_SIZE(tsk)) { - ret = -ENOMEM; - goto recv_end; - } - - data_len = TLS_CACHE_SIZE(tsk); - for (i = 1; data_len; i++) { - copy = min_t(size_t, - tsk->sg_rx_async_work[i].length, - data_len); - copied = copy_page_to_iter(sg_page(tsk->sg_rx_async_work + i), - tsk->sg_rx_async_work[i].offset, - copy, - &msg->msg_iter); - if (copied < copy) { - ret = -EFAULT; - goto recv_end; - } + skb = skb_recv_datagram((struct sock *)tsk, flags & ~MSG_DONTWAIT, + flags & MSG_DONTWAIT, &err); + if (!skb) + goto recv_end; + rxm = tls_rx_msg(skb); + err = skb_copy_datagram_msg(skb, rxm->offset, msg, rxm->full_len); + if (err < 0) + goto recv_end; + copied = rxm->full_len; + if (copied > len) + msg->msg_flags |= MSG_TRUNC; + if (likely(!(flags & MSG_PEEK))) { + msg->msg_flags |= MSG_EOR; + skb_free_datagram((struct sock *)tsk, skb); + } +recv_end: - data_len -= copied; - } + release_sock(sock->sk); + ret = copied? : err; + return ret; +} +static ssize_t tls_sock_splice(struct sock *sk, + struct pipe_inode_info *pipe, + struct splice_pipe_desc *spd) +{ + int ret; - ret = TLS_CACHE_SIZE(tsk); - TLS_CACHE_DISCARD(tsk); - } else { - while (iov_iter_count(&msg->msg_iter)) { - size_t seglen = iov_iter_count(&msg->msg_iter); - int len = af_alg_make_sg(&tsk->sgl_recv[cnt], &msg->msg_iter, seglen); - if (len < 0) - goto recv_end; - ret += len; - if (cnt) - af_alg_link_sg(&tsk->sgl_recv[cnt-1], &tsk->sgl_recv[cnt]); - iov_iter_advance(&msg->msg_iter, len); - cnt++; - } - sg_unmark_end(&tsk->sgl_recv[cnt-1].sg[tsk->sgl_recv[cnt-1].npages - 1]); - sg_chain(tsk->sgl_recv[cnt-1].sg, tsk->sgl_recv[cnt-1].npages + 1, tsk->sgtag_recv); + release_sock(sk); + ret = splice_to_pipe(pipe, spd); + lock_sock(sk); - data_len = tls_peek_data(tsk, 0); + return ret; +} - if (data_len < 0) { - ret = data_len; - goto recv_end; - } +static ssize_t tls_splice_read(struct socket *sock, loff_t *ppos, + struct pipe_inode_info *pipe, + size_t len, unsigned int flags) +{ + ssize_t copied = 0; + long timeo; + struct tls_sock *tsk; + struct tls_rx_msg *rxm; + int ret = 0; + struct sk_buff *skb; + int chunk; + int err = 0; + struct sock *sk = sock->sk; - if (size < data_len) { - ret = -ENOMEM; - goto recv_end; - } + xprintk("--> %s", __func__); - tls_make_aad(tsk, 1, tsk->aad_recv, data_len, - tsk->iv_recv); + tsk = tls_sk(sk); + lock_sock(sk); - ret = tls_do_decryption(tsk, - tsk->sg_rx_data, - tsk->sgaad_recv, - data_len); - if (ret < 0) - goto recv_end; + if (!KTLS_RECV_READY(tsk)) { + err = -EBADMSG; + goto splice_read_end; + } - ret = tls_post_process(tsk, tsk->sgaad_recv); - if (ret < 0) - goto recv_end; + timeo = sock_rcvtimeo(&tsk->sk, flags & MSG_DONTWAIT); - ret = data_len; - } + skb = tls_wait_data(tsk, flags, timeo, &err); + if (!skb) + goto splice_read_end; - tls_pop_record(tsk, ret); - increment_seqno(tsk->iv_recv); + rxm = tls_rx_msg(skb); + chunk = min_t(unsigned 
int, rxm->full_len, len); + copied = skb_splice_bits(skb, sk, rxm->offset, pipe, chunk, + flags, tls_sock_splice); + if (ret < 0) + goto splice_read_end; -recv_end: - if (ret > 0) - queue_work(tls_wq, &tsk->recv_work); - for (i = 0; i < cnt; i++) - af_alg_free_sg(&tsk->sgl_recv[i]); - release_sock(sock->sk); - mutex_unlock(&tsk->rx_lock); + rxm->offset += copied; + rxm->full_len -= copied; +splice_read_end: + release_sock(sk); + ret = (copied)?copied:err; return ret; } @@ -1369,7 +1804,7 @@ static ssize_t tls_do_sendpage(struct tls_sock *tsk) size_t data_len; struct msghdr msg = {}; - xprintk("--> %s", __FUNCTION__); + xprintk("--> %s", __func__); data_len = min_t(size_t, tsk->sendpage_ctx.current_size, @@ -1403,9 +1838,9 @@ static ssize_t tls_do_sendpage(struct tls_sock *tsk) tls_free_sendpage_ctx(tsk); do_sendmsg_end: - // restore, so we can use sendmsg() + /* restore, so we can use sendmsg() */ sg_chain(tsk->sgaad_send, 2, tsk->sgl_send[0].sg); - // remove chaining to sg tag + /* remove chaining to sg tag */ sg_mark_end(&tsk->sendpage_ctx.sg[tsk->sendpage_ctx.used]); return ret; @@ -1421,7 +1856,7 @@ static ssize_t tls_sendpage(struct socket *sock, struct page *page, struct tls_sock *tsk; struct scatterlist *sg; - xprintk("--> %s", __FUNCTION__); + xprintk("--> %s", __func__); tsk = tls_sk(sock->sk); lock_sock(sock->sk); @@ -1436,7 +1871,7 @@ static ssize_t tls_sendpage(struct socket *sock, struct page *page, goto sendpage_end; } - // TODO: handle flags, see issue #4 + /* TODO: handle flags, see issue #4 */ sg = tsk->sendpage_ctx.sg; @@ -1475,13 +1910,14 @@ static ssize_t tls_sendpage(struct socket *sock, struct page *page, return ret < 0 ? ret : size; } +/*TODO: When binding, blow away all skbs from underlying socket */ static int tls_bind(struct socket *sock, struct sockaddr *uaddr, int addr_len) { int ret; struct tls_sock *tsk; struct sockaddr_ktls *sa_ktls; - xprintk("--> %s", __FUNCTION__); + xprintk("--> %s", __func__); if (uaddr == NULL || sizeof(*sa_ktls) != addr_len) return -EBADMSG; @@ -1490,28 +1926,28 @@ static int tls_bind(struct socket *sock, struct sockaddr *uaddr, int addr_len) sa_ktls = (struct sockaddr_ktls *) uaddr; switch (sa_ktls->sa_cipher) { - case KTLS_CIPHER_AES_GCM_128: - tsk->cipher_type = KTLS_CIPHER_AES_GCM_128; - tsk->cipher_crypto = "rfc5288(gcm(aes))"; - break; - default: - return -ENOENT; + case KTLS_CIPHER_AES_GCM_128: + tsk->cipher_type = KTLS_CIPHER_AES_GCM_128; + tsk->cipher_crypto = "rfc5288(gcm(aes))"; + break; + default: + return -ENOENT; } - switch(sa_ktls->sa_version) { - case KTLS_VERSION_LATEST: - /* passthrough */ - case KTLS_VERSION_1_2: - if (IS_TLS(tsk)) { - tsk->version[0] = KTLS_TLS_1_2_MAJOR; - tsk->version[1] = KTLS_TLS_1_2_MINOR; - } else { - tsk->version[0] = KTLS_DTLS_1_2_MAJOR; - tsk->version[1] = KTLS_DTLS_1_2_MINOR; - } - break; - default: - return -ENOENT; + switch (sa_ktls->sa_version) { + case KTLS_VERSION_LATEST: + /* passthrough */ + case KTLS_VERSION_1_2: + if (IS_TLS(tsk)) { + tsk->version[0] = KTLS_TLS_1_2_MAJOR; + tsk->version[1] = KTLS_TLS_1_2_MINOR; + } else { + tsk->version[0] = KTLS_DTLS_1_2_MAJOR; + tsk->version[1] = KTLS_DTLS_1_2_MINOR; + } + break; + default: + return -ENOENT; } lock_sock(sock->sk); @@ -1534,29 +1970,47 @@ static int tls_bind(struct socket *sock, struct sockaddr *uaddr, int addr_len) } xprintk("--1"); - tsk->aead_recv = crypto_alloc_aead(tsk->cipher_crypto, - CRYPTO_ALG_INTERNAL, 0); - xprintk("--1"); - if (IS_ERR(tsk->aead_recv)) { - ret = PTR_ERR(tsk->aead_recv); - tsk->aead_recv = NULL; - goto 
bind_end; + if (!tsk->aead_recv) { + tsk->aead_recv = crypto_alloc_aead(tsk->cipher_crypto, + CRYPTO_ALG_INTERNAL, 0); + xprintk("--1"); + if (IS_ERR(tsk->aead_recv)) { + ret = PTR_ERR(tsk->aead_recv); + tsk->aead_recv = NULL; + goto bind_end; + } + } - tsk->aead_send = crypto_alloc_aead(tsk->cipher_crypto, - CRYPTO_ALG_INTERNAL, 0); - if (IS_ERR(tsk->aead_send)) { - ret = PTR_ERR(tsk->aead_send); - tsk->aead_send = NULL; - goto bind_end; + if (!tsk->aead_send) { + tsk->aead_send = crypto_alloc_aead(tsk->cipher_crypto, + CRYPTO_ALG_INTERNAL, 0); + if (IS_ERR(tsk->aead_send)) { + ret = PTR_ERR(tsk->aead_send); + tsk->aead_send = NULL; + goto bind_end; + } } + ((struct sock *)tsk)->sk_err = 0; + write_lock_bh(&tsk->socket->sk->sk_callback_lock); tsk->rx_stopped = 0; tsk->saved_sk_data_ready = tsk->socket->sk->sk_data_ready; - tsk->socket->sk->sk_data_ready = tls_data_ready; + if (IS_TLS(tsk)) + tsk->socket->sk->sk_data_ready = tls_data_ready; + else + tsk->socket->sk->sk_data_ready = dtls_data_ready; tsk->socket->sk->sk_user_data = tsk; write_unlock_bh(&tsk->socket->sk->sk_callback_lock); + + /* Check if any TLS packets have come in between the time the + * handshake was completed and bin() was called. If there were, + * the packets would have woken up TCP socket waiters, not + * KTLS. Therefore, pull the packets from TCP and wake up KTLS + * if necessary + */ + do_tls_sock_rx_work(tsk); release_sock(sock->sk); return 0; @@ -1571,7 +2025,7 @@ int tls_release(struct socket *sock) { struct tls_sock *tsk; - xprintk("--> %s", __FUNCTION__); + xprintk("--> %s", __func__); tsk = tls_sk(sock->sk); @@ -1583,7 +2037,7 @@ int tls_release(struct socket *sock) return 0; } -static const struct proto_ops tls_proto_ops = { +static const struct proto_ops tls_stream_ops = { .family = PF_KTLS, .owner = THIS_MODULE, @@ -1594,7 +2048,7 @@ static const struct proto_ops tls_proto_ops = { .listen = sock_no_listen, .shutdown = sock_no_shutdown, .mmap = sock_no_mmap, - .poll = sock_no_poll, + .poll = tls_poll, .accept = sock_no_accept, .bind = tls_bind, @@ -1604,23 +2058,47 @@ static const struct proto_ops tls_proto_ops = { .recvmsg = tls_recvmsg, .sendpage = tls_sendpage, .release = tls_release, - .splice_read = tls_splice_read, + .splice_read = tls_splice_read, +}; + +static const struct proto_ops tls_dgram_ops = { + .family = PF_KTLS, + .owner = THIS_MODULE, + + .connect = sock_no_connect, + .socketpair = sock_no_socketpair, + .getname = sock_no_getname, + .ioctl = sock_no_ioctl, + .listen = sock_no_listen, + .shutdown = sock_no_shutdown, + .mmap = sock_no_mmap, + .poll = tls_poll, + .accept = sock_no_accept, + + .bind = tls_bind, + .setsockopt = tls_setsockopt, + .getsockopt = tls_getsockopt, + .sendmsg = tls_sendmsg, + .recvmsg = dtls_recvmsg, + .sendpage = tls_sendpage, + .release = tls_release, + .splice_read = tls_splice_read, }; static void tls_sock_destruct(struct sock *sk) { struct tls_sock *tsk; - xprintk("--> %s", __FUNCTION__); + xprintk("--> %s", __func__); tsk = tls_sk(sk); - // TODO: remove - printk("tls: parallel executions: %u\n", tsk->parallel_count_stat); + /* TODO: remove */ + pr_debug("tls: parallel executions: %u\n", tsk->parallel_count_stat); cancel_work_sync(&tsk->recv_work); - // restore callback and abandon socket + /* restore callback and abandon socket */ if (tsk->socket) { write_lock_bh(&tsk->socket->sk->sk_callback_lock); @@ -1633,30 +2111,21 @@ static void tls_sock_destruct(struct sock *sk) tsk->socket = NULL; } - if (tsk->iv_send) - kfree(tsk->iv_send); + /* kfree(NULL) is safe */ + 
kfree(tsk->iv_send); - if (tsk->key_send.key) - kfree(tsk->key_send.key); + kfree(tsk->key_send.key); - if (tsk->iv_recv) - kfree(tsk->iv_recv); + kfree(tsk->iv_recv); - if (tsk->key_recv.key) - kfree(tsk->key_recv.key); + kfree(tsk->key_recv.key); - if (tsk->aead_send) - crypto_free_aead(tsk->aead_send); + crypto_free_aead(tsk->aead_send); - if (tsk->aead_recv) - crypto_free_aead(tsk->aead_recv); + crypto_free_aead(tsk->aead_recv); if (tsk->pages_send) __free_pages(tsk->pages_send, KTLS_DATA_PAGES); - if (tsk->pages_recv) - __free_pages(tsk->pages_recv, KTLS_DATA_PAGES); - if (tsk->pages_work) - __free_pages(tsk->pages_work, KTLS_DATA_PAGES); } static struct proto tls_proto = { @@ -1675,10 +2144,18 @@ static int tls_create(struct net *net, struct sock *sk; struct tls_sock *tsk; - xprintk("--> %s", __FUNCTION__); + xprintk("--> %s", __func__); - if (sock->type != SOCK_DGRAM && sock->type != SOCK_STREAM) + switch (sock->type) { + case SOCK_STREAM: + sock->ops = &tls_stream_ops; + break; + case SOCK_DGRAM: + sock->ops = &tls_dgram_ops; + break; + default: return -ESOCKTNOSUPPORT; + } if (protocol != 0) return -EPROTONOSUPPORT; @@ -1687,13 +2164,12 @@ static int tls_create(struct net *net, if (!sk) return -ENOMEM; - sock->ops = &tls_proto_ops; sock_init_data(sock, sk); sk->sk_family = PF_KTLS; sk->sk_destruct = tls_sock_destruct; - // initialize stored context + /* initialize stored context */ tsk = tls_sk(sk); tsk->iv_send = NULL; @@ -1712,14 +2188,10 @@ static int tls_create(struct net *net, */ tsk->mtu_payload = KTLS_MAX_PAYLOAD_SIZE; - DTLS_WINDOW_INIT(tsk->dtls_window); - sg_init_table(tsk->sendpage_ctx.sg, KTLS_SG_DATA_SIZE); sg_mark_end(&tsk->sendpage_ctx.sg[0]); - mutex_init(&tsk->rx_lock); - - tsk->pages_send = tsk->pages_recv = tsk->pages_work = NULL; + tsk->pages_send = NULL; ret = -ENOMEM; /* @@ -1733,7 +2205,7 @@ static int tls_create(struct net *net, if (!tsk->pages_send) goto create_error; for (i = 0; i < KTLS_DATA_PAGES; i++) - // the first is HEADER + /* the first is HEADER */ sg_set_page(tsk->sg_tx_data + i + 1, tsk->pages_send + i, PAGE_SIZE, 0); @@ -1741,12 +2213,14 @@ static int tls_create(struct net *net, tsk->tag_send, sizeof(tsk->tag_send)); sg_mark_end(tsk->sg_tx_data + KTLS_SG_DATA_SIZE - 1); - // msg for sending + /* msg for sending */ tsk->vec_send[0].iov_base = tsk->header_send; tsk->vec_send[0].iov_len = IS_TLS(tsk) ? 
KTLS_TLS_PREPEND_SIZE : KTLS_DTLS_PREPEND_SIZE; for (i = 1; i <= KTLS_DATA_PAGES + 1; i++) { - tsk->vec_send[i].iov_base = page_address(sg_page(tsk->sg_tx_data + i)) + tsk->sg_tx_data[i].offset; + tsk->vec_send[i].iov_base = page_address(sg_page + (tsk->sg_tx_data + i)) + tsk-> + sg_tx_data[i].offset; tsk->vec_send[i].iov_len = tsk->sg_tx_data[i].length; } @@ -1756,68 +2230,12 @@ static int tls_create(struct net *net, sg_init_table(tsk->sgtag_send, 2); sg_set_buf(&tsk->sgaad_send[0], tsk->aad_send, sizeof(tsk->aad_send)); - // chaining to tag is performed on actual data size when sending + /* chaining to tag is performed on actual data size when sending */ sg_set_buf(&tsk->sgtag_send[0], tsk->tag_send, sizeof(tsk->tag_send)); sg_unmark_end(&tsk->sgaad_send[1]); sg_chain(tsk->sgaad_send, 2, tsk->sgl_send[0].sg); - - /* - * Preallocation for receiving - * scatterlist: AAD | data | TAG - * (for crypto AAD, aad and TAG are untouched) - * vec: HEADER | data | TAG - * async vec: HEADER| data | TAG - * - * for the async vec HEADER and TAG are reused, but chaining after async - * operation has to be restored - */ - sg_init_table(tsk->sg_rx_data, KTLS_SG_DATA_SIZE); - sg_set_buf(&tsk->sg_rx_data[0], tsk->aad_recv, sizeof(tsk->aad_recv)); - tsk->pages_recv = alloc_pages(GFP_KERNEL, KTLS_DATA_PAGES); - if (!tsk->pages_recv) - goto create_error; - for (i = 0; i < KTLS_DATA_PAGES; i++) - // the first is HEADER - sg_set_page(tsk->sg_rx_data + i + 1, tsk->pages_recv + i, PAGE_SIZE, 0); - sg_set_buf(tsk->sg_rx_data + KTLS_SG_DATA_SIZE - 2, - tsk->tag_recv, sizeof(tsk->tag_recv)); - sg_mark_end(tsk->sg_rx_data + KTLS_SG_DATA_SIZE - 1); - - // msg for receiving - tsk->vec_recv[0].iov_base = tsk->header_recv; - tsk->vec_recv[0].iov_len = IS_TLS(tsk) ? - KTLS_TLS_PREPEND_SIZE : KTLS_DTLS_PREPEND_SIZE; - for (i = 1; i <= KTLS_DATA_PAGES + 1; i++) { - tsk->vec_recv[i].iov_base = page_address(sg_page(tsk->sg_rx_data + i)) + tsk->sg_rx_data[i].offset; - tsk->vec_recv[i].iov_len = tsk->sg_rx_data[i].length; - } - - for (i = 0; i < UIO_MAXIOV; i++) - memset(&tsk->sgl_recv[i], 0, sizeof(tsk->sgl_recv[i])); - sg_init_table(tsk->sgaad_recv, 2); - sg_init_table(tsk->sgtag_recv, 2); - - sg_set_buf(&tsk->sgaad_recv[0], tsk->aad_recv, sizeof(tsk->aad_recv)); - // chaining to tag is performed on actual data size when receiving - sg_set_buf(&tsk->sgtag_recv[0], tsk->tag_recv, sizeof(tsk->tag_recv)); - - sg_unmark_end(&tsk->sgaad_recv[1]); - sg_chain(tsk->sgaad_recv, 2, tsk->sgl_recv[0].sg); - - // preallocation for asynchronous worker, where decrypted data are stored - sg_init_table(tsk->sg_rx_async_work, KTLS_SG_DATA_SIZE); - sg_set_buf(&tsk->sg_rx_async_work[0], tsk->aad_recv, sizeof(tsk->aad_recv)); - tsk->pages_work = alloc_pages(GFP_KERNEL, KTLS_DATA_PAGES); - for (i = 0; i < KTLS_DATA_PAGES; i++) - // the first is HEADER - sg_set_page(tsk->sg_rx_async_work + i + 1, tsk->pages_work + i, PAGE_SIZE, 0); - sg_set_buf(tsk->sg_rx_async_work + KTLS_SG_DATA_SIZE - 2, - tsk->tag_recv, - sizeof(tsk->tag_recv)); - sg_mark_end(tsk->sg_rx_async_work + KTLS_SG_DATA_SIZE - 1); - - INIT_WORK(&tsk->recv_work, tls_rx_async_work); + INIT_WORK(&tsk->recv_work, tls_rx_work); return 0; @@ -1836,7 +2254,8 @@ static const struct net_proto_family tls_family = { static int __init tls_init(void) { int ret = -ENOMEM; - xprintk("--> %s", __FUNCTION__); + + xprintk("--> %s", __func__); tls_wq = create_workqueue("ktls"); if (!tls_wq) @@ -1861,7 +2280,7 @@ static int __init tls_init(void) static void __exit tls_exit(void) { - xprintk("--> %s", 
__FUNCTION__); + xprintk("--> %s", __func__); sock_unregister(PF_KTLS); proto_unregister(&tls_proto); destroy_workqueue(tls_wq); diff --git a/af_ktls.h b/af_ktls.h index 995f52a..25c9130 100644 --- a/af_ktls.h +++ b/af_ktls.h @@ -37,6 +37,7 @@ #define KTLS_SET_KEY_SEND 5 #define KTLS_SET_SALT_SEND 6 #define KTLS_SET_MTU 7 +#define KTLS_UNATTACH 8 /* * setsockopt() optnames