diff --git a/include/net/mptcp.h b/include/net/mptcp.h index db188a2af6fd2..75bfb49b5104c 100644 --- a/include/net/mptcp.h +++ b/include/net/mptcp.h @@ -344,6 +344,10 @@ struct mptcp_cb { u32 orig_window_clamp; }; +struct mptcp_hmacsha1_pool { + struct hash_desc hmacsha1_desc; +}; + #define MPTCP_SUB_CAPABLE 0 #define MPTCP_SUB_LEN_CAPABLE_SYN 12 #define MPTCP_SUB_LEN_CAPABLE_SYN_ALIGN 12 @@ -826,8 +830,9 @@ void mptcp_select_initial_window(int __space, __u32 mss, __u32 *rcv_wnd, const struct sock *sk); unsigned int mptcp_current_mss(struct sock *meta_sk); int mptcp_select_size(const struct sock *meta_sk, bool sg); +void mptcp_alloc_hmacsha1_pool(void); void mptcp_key_sha1(u64 key, u32 *token, u64 *idsn); -void mptcp_hmac_sha1(u8 *key_1, u8 *key_2, u32 *hash_out, int arg_num, ...); +void mptcp_hmac_sha1(u8 *key_1, u8 *key_2, u8 *hash_out, int arg_num, ...); void mptcp_clean_rtx_infinite(const struct sk_buff *skb, struct sock *sk); void mptcp_fin(struct sock *meta_sk); void mptcp_retransmit_timer(struct sock *meta_sk); @@ -906,6 +911,11 @@ void mptcp_get_default_scheduler(char *name); int mptcp_set_default_scheduler(const char *name); extern struct mptcp_sched_ops mptcp_sched_default; +static inline void mptcp_put_hmacsha1_pool(void) +{ + local_bh_enable(); +} + /* Initializes function-pointers and MPTCP-flags */ static inline void mptcp_init_tcp_sock(struct sock *sk) { diff --git a/net/ipv4/tcp.c b/net/ipv4/tcp.c index b4e957203356d..cbea62861ba10 100644 --- a/net/ipv4/tcp.c +++ b/net/ipv4/tcp.c @@ -443,6 +443,7 @@ void tcp_init_sock(struct sock *sk) /* Initialize MPTCP-specific stuff and function-pointers */ mptcp_init_tcp_sock(sk); + mptcp_alloc_hmacsha1_pool(); local_bh_disable(); sock_update_memcg(sk); diff --git a/net/mptcp/mptcp_ctrl.c b/net/mptcp/mptcp_ctrl.c index cfd9cae2364e3..194200cef4e36 100644 --- a/net/mptcp/mptcp_ctrl.c +++ b/net/mptcp/mptcp_ctrl.c @@ -76,6 +76,10 @@ bool mptcp_init_failed __read_mostly; struct static_key mptcp_static_key = STATIC_KEY_INIT_FALSE; EXPORT_SYMBOL(mptcp_static_key); +bool mptcp_hmacsha1_pool_populated; +static DEFINE_PER_CPU(struct mptcp_hmacsha1_pool, mptcp_hmacsha1_pool); +static DEFINE_MUTEX(mptcp_hmacsha1_mutex); + static int proc_mptcp_path_manager(struct ctl_table *ctl, int write, void __user *buffer, size_t *lenp, loff_t *ppos) @@ -380,6 +384,50 @@ static void mptcp_set_key_sk(const struct sock *sk) &tp->mptcp_loc_token, NULL); } +static void __mptcp_alloc_hmacsha1_pool(void) +{ + int cpu; + + for_each_possible_cpu(cpu) { + if (!per_cpu(mptcp_hmacsha1_pool, cpu).hmacsha1_desc.tfm) { + struct crypto_hash *hash; + + hash = crypto_alloc_hash("hmac(sha1)", 0, CRYPTO_ALG_ASYNC); + if (IS_ERR_OR_NULL(hash)) + return; + per_cpu(mptcp_hmacsha1_pool, cpu).hmacsha1_desc.tfm = hash; + } + } + /* before setting mptcp_hmacsha1_pool_populated, we must commit all + * writes to memory. See smp_rmb() in mptcp_get_hmacsha1_pool() + */ + smp_wmb(); + mptcp_hmacsha1_pool_populated = true; +} + +void mptcp_alloc_hmacsha1_pool(void) +{ + if (unlikely(!mptcp_hmacsha1_pool_populated)) { + mutex_lock(&mptcp_hmacsha1_mutex); + if (!mptcp_hmacsha1_pool_populated) + __mptcp_alloc_hmacsha1_pool(); + mutex_unlock(&mptcp_hmacsha1_mutex); + } +} +EXPORT_SYMBOL(mptcp_alloc_hmacsha1_pool); + +struct mptcp_hmacsha1_pool *mptcp_get_hmacsha1_pool(void) +{ + local_bh_disable(); + if (mptcp_hmacsha1_pool_populated) { + /* coupled with smp_wmb() in __mptcp_alloc_hmacsha1_pool() */ + smp_rmb(); + return this_cpu_ptr(&mptcp_hmacsha1_pool); + } + local_bh_enable(); + return NULL; +} + #ifdef HAVE_JUMP_LABEL /* We are not allowed to call static_key_slow_dec() from irq context * If mptcp_enable/disable_static_key() is called from irq context, @@ -798,81 +846,49 @@ void mptcp_key_sha1(u64 key, u32 *token, u64 *idsn) *idsn = *((u64 *)&mptcp_hashed_key[3]); } -void mptcp_hmac_sha1(u8 *key_1, u8 *key_2, u32 *hash_out, int arg_num, ...) +void mptcp_hmac_sha1(u8 *key_1, u8 *key_2, u8 *hash_out, int arg_num, ...) { - u32 workspace[SHA_WORKSPACE_WORDS]; - u8 input[128]; /* 2 512-bit blocks */ + struct mptcp_hmacsha1_pool *sp; + struct scatterlist sg; + u8 key[16]; int i; - int index; int length; u8 *msg; va_list list; - memset(workspace, 0, sizeof(workspace)); - /* Initialize result placeholder */ - memset(hash_out, 0, sizeof(hash_out)); + sp = mptcp_get_hmacsha1_pool(); + if (!sp) + goto clear_hmac_noput; + sp->hmacsha1_desc.flags = 0; + + memcpy(&key[0], key_1, 8); + memcpy(&key[8], key_2, 8); - /* Generate key xored with ipad */ - memset(input, 0x36, 64); - for (i = 0; i < 8; i++) - input[i] ^= key_1[i]; - for (i = 0; i < 8; i++) - input[i + 8] ^= key_2[i]; + if (crypto_hash_setkey(sp->hmacsha1_desc.tfm, (u8 *)key, 16)) + goto clear_hmac; + if (crypto_hash_init(&sp->hmacsha1_desc)) + goto clear_hmac; va_start(list, arg_num); - index = 64; - for (i = 0; i < arg_num; i++) { + for (i = 0; i < (arg_num); i++) { length = va_arg(list, int); msg = va_arg(list, u8 *); - if (index + length > 125) { - /* The message is too long */ - return; - } - memcpy(&input[index], msg, length); - index += length; + sg_init_one(&sg, msg, length); + if (crypto_hash_update(&sp->hmacsha1_desc, &sg, length)) + goto clear_hmac; } va_end(list); - input[index] = 0x80; /* Padding: First bit after message = 1 */ - memset(&input[index + 1], 0, (126 - index)); - - /* Padding: Length of the message = 512 + message length (bits) */ - input[126] = 0x02; - input[127] = ((index - 64) * 8); /* Message length (bits) */ - - sha_init(hash_out); - sha_transform(hash_out, input, workspace); - memset(workspace, 0, sizeof(workspace)); - - sha_transform(hash_out, &input[64], workspace); - memset(workspace, 0, sizeof(workspace)); - - for (i = 0; i < 5; i++) - hash_out[i] = cpu_to_be32(hash_out[i]); - - /* Prepare second part of hmac */ - memset(input, 0x5C, 64); - for (i = 0; i < 8; i++) - input[i] ^= key_1[i]; - for (i = 0; i < 8; i++) - input[i + 8] ^= key_2[i]; - - memcpy(&input[64], hash_out, 20); - input[84] = 0x80; - memset(&input[85], 0, 41); - - /* Padding: Length of the message = 512 + 160 bits */ - input[126] = 0x02; - input[127] = 0xA0; - - sha_init(hash_out); - sha_transform(hash_out, input, workspace); - memset(workspace, 0, sizeof(workspace)); - - sha_transform(hash_out, &input[64], workspace); + if (crypto_hash_final(&sp->hmacsha1_desc, hash_out)) + goto clear_hmac; + mptcp_put_hmacsha1_pool(); + return; - for (i = 0; i < 5; i++) - hash_out[i] = cpu_to_be32(hash_out[i]); +clear_hmac: + mptcp_put_hmacsha1_pool(); +clear_hmac_noput: + memset((u8 *)hash_out, 0, 20); + return; } static void mptcp_mpcb_inherit_sockopts(struct sock *meta_sk, struct sock *master_sk) @@ -2045,7 +2061,7 @@ struct sock *mptcp_check_req_child(struct sock *meta_sk, struct sock *child, mptcp_hmac_sha1((u8 *)&mpcb->mptcp_rem_key, (u8 *)&mpcb->mptcp_loc_key, - (u32 *)hash_mac_check, 2, + (u8 *)hash_mac_check, 2, 4, (u8 *)&mtreq->mptcp_rem_nonce, 4, (u8 *)&mtreq->mptcp_loc_nonce); @@ -2278,7 +2294,7 @@ void mptcp_join_reqsk_init(struct mptcp_cb *mpcb, const struct request_sock *req mptcp_hmac_sha1((u8 *)&mpcb->mptcp_loc_key, (u8 *)&mpcb->mptcp_rem_key, - (u32 *)mptcp_hash_mac, 2, + (u8 *)mptcp_hash_mac, 2, 4, (u8 *)&mtreq->mptcp_loc_nonce, 4, (u8 *)&mtreq->mptcp_rem_nonce); mtreq->mptcp_hash_tmac = *(u64 *)mptcp_hash_mac; diff --git a/net/mptcp/mptcp_fullmesh.c b/net/mptcp/mptcp_fullmesh.c index 81a8d767d8d44..9d4a542e97a52 100644 --- a/net/mptcp/mptcp_fullmesh.c +++ b/net/mptcp/mptcp_fullmesh.c @@ -1570,7 +1570,7 @@ static void full_mesh_addr_signal(struct sock *sk, unsigned *size, *(u64 *)no_key = 0; mptcp_hmac_sha1((u8 *)&mpcb->mptcp_loc_key, (u8 *)no_key, - (u32 *)mptcp_hash_mac, 2, + (u8 *)mptcp_hash_mac, 2, 1, (u8 *)&mptcp_local->locaddr4[ind].loc4_id, 4, (u8 *)&opts->add_addr4.addr.s_addr); opts->add_addr4.trunc_mac = *(u64 *)mptcp_hash_mac; @@ -1611,7 +1611,7 @@ static void full_mesh_addr_signal(struct sock *sk, unsigned *size, *(u64 *)no_key = 0; mptcp_hmac_sha1((u8 *)&mpcb->mptcp_loc_key, (u8 *)no_key, - (u32 *)mptcp_hash_mac, 2, + (u8 *)mptcp_hash_mac, 2, 1, (u8 *)&mptcp_local->locaddr6[ind].loc6_id, 16, (u8 *)&opts->add_addr6.addr.s6_addr); opts->add_addr6.trunc_mac = *(u64 *)mptcp_hash_mac; diff --git a/net/mptcp/mptcp_input.c b/net/mptcp/mptcp_input.c index ef634dac9453a..fa3b4143bc867 100644 --- a/net/mptcp/mptcp_input.c +++ b/net/mptcp/mptcp_input.c @@ -1940,7 +1940,7 @@ static void mptcp_handle_add_addr(const unsigned char *ptr, struct sock *sk) } mptcp_hmac_sha1((u8 *)&mpcb->mptcp_rem_key, (u8 *)no_key, - (u32 *)hash_mac_check, msg_parts, + (u8 *)hash_mac_check, msg_parts, 1, (u8 *)&mpadd->addr_id, 4, (u8 *)&mpadd->u.v4.addr.s_addr, 2, (u8 *)&mpadd->u.v4.port); @@ -1975,7 +1975,7 @@ static void mptcp_handle_add_addr(const unsigned char *ptr, struct sock *sk) } mptcp_hmac_sha1((u8 *)&mpcb->mptcp_rem_key, (u8 *)no_key, - (u32 *)hash_mac_check, msg_parts, + (u8 *)hash_mac_check, msg_parts, 1, (u8 *)&mpadd->addr_id, 16, (u8 *)&mpadd->u.v6.addr.s6_addr, 2, (u8 *)&mpadd->u.v6.port); @@ -2292,7 +2292,7 @@ int mptcp_rcv_synsent_state_process(struct sock *sk, struct sock **skptr, mptcp_hmac_sha1((u8 *)&mpcb->mptcp_rem_key, (u8 *)&mpcb->mptcp_loc_key, - (u32 *)hash_mac_check, 2, + (u8 *)hash_mac_check, 2, 4, (u8 *)&tp->mptcp->rx_opt.mptcp_recv_nonce, 4, (u8 *)&tp->mptcp->mptcp_loc_nonce); if (memcmp(hash_mac_check, @@ -2310,7 +2310,7 @@ int mptcp_rcv_synsent_state_process(struct sock *sk, struct sock **skptr, mptcp_hmac_sha1((u8 *)&mpcb->mptcp_loc_key, (u8 *)&mpcb->mptcp_rem_key, - (u32 *)&tp->mptcp->sender_mac[0], 2, + (u8 *)&tp->mptcp->sender_mac[0], 2, 4, (u8 *)&tp->mptcp->mptcp_loc_nonce, 4, (u8 *)&tp->mptcp->rx_opt.mptcp_recv_nonce);