Skip to content

Commit

Permalink
net/tcp: Disable TCP-MD5 static key on tcp_md5sig_info destruction
Browse files Browse the repository at this point in the history
To do that, separate two scenarios:
- where it's the first MD5 key on the system, which means that enabling
  of the static key may need to sleep;
- copying of an existing key from a listening socket to the request
  socket upon receiving a signed TCP segment, where static key was
  already enabled (when the key was added to the listening socket).

Now the life-time of the static branch for TCP-MD5 is until:
- last tcp_md5sig_info is destroyed
- last socket in time-wait state with MD5 key is closed.

Which means that after all sockets with TCP-MD5 keys are gone, the
system gets back the performance of disabled md5-key static branch.

Signed-off-by: Dmitry Safonov <dima@arista.com>
  • Loading branch information
0x7f454c46 authored and intel-lab-lkp committed Jul 26, 2022
1 parent b6cfe5c commit a4ee3ec
Show file tree
Hide file tree
Showing 6 changed files with 55 additions and 28 deletions.
10 changes: 7 additions & 3 deletions include/net/tcp.h
Expand Up @@ -1663,23 +1663,27 @@ int tcp_v4_md5_hash_skb(char *md5_hash, const struct tcp_md5sig_key *key,
const struct sock *sk, const struct sk_buff *skb);
int tcp_md5_do_add(struct sock *sk, const union tcp_md5_addr *addr,
int family, u8 prefixlen, int l3index, u8 flags,
const u8 *newkey, u8 newkeylen, gfp_t gfp);
const u8 *newkey, u8 newkeylen);
int tcp_md5_key_copy(struct sock *sk, const union tcp_md5_addr *addr,
int family, u8 prefixlen, int l3index,
struct tcp_md5sig_key *key);

int tcp_md5_do_del(struct sock *sk, const union tcp_md5_addr *addr,
int family, u8 prefixlen, int l3index, u8 flags);
struct tcp_md5sig_key *tcp_v4_md5_lookup(const struct sock *sk,
const struct sock *addr_sk);

#ifdef CONFIG_TCP_MD5SIG
#include <linux/jump_label.h>
extern struct static_key_false tcp_md5_needed;
extern struct static_key_false_deferred tcp_md5_needed;
struct tcp_md5sig_key *__tcp_md5_do_lookup(const struct sock *sk, int l3index,
const union tcp_md5_addr *addr,
int family);
static inline struct tcp_md5sig_key *
tcp_md5_do_lookup(const struct sock *sk, int l3index,
const union tcp_md5_addr *addr, int family)
{
if (!static_branch_unlikely(&tcp_md5_needed))
if (!static_branch_unlikely(&tcp_md5_needed.key))
return NULL;
return __tcp_md5_do_lookup(sk, l3index, addr, family);
}
Expand Down
5 changes: 1 addition & 4 deletions net/ipv4/tcp.c
Expand Up @@ -4404,11 +4404,8 @@ bool tcp_alloc_md5sig_pool(void)
if (unlikely(!tcp_md5sig_pool_populated)) {
mutex_lock(&tcp_md5sig_mutex);

if (!tcp_md5sig_pool_populated) {
if (!tcp_md5sig_pool_populated)
__tcp_alloc_md5sig_pool();
if (tcp_md5sig_pool_populated)
static_branch_inc(&tcp_md5_needed);
}

mutex_unlock(&tcp_md5sig_mutex);
}
Expand Down
45 changes: 35 additions & 10 deletions net/ipv4/tcp_ipv4.c
Expand Up @@ -1044,7 +1044,7 @@ static void tcp_v4_reqsk_destructor(struct request_sock *req)
* We need to maintain these in the sk structure.
*/

DEFINE_STATIC_KEY_FALSE(tcp_md5_needed);
DEFINE_STATIC_KEY_DEFERRED_FALSE(tcp_md5_needed, HZ);
EXPORT_SYMBOL(tcp_md5_needed);

static bool better_md5_match(struct tcp_md5sig_key *old, struct tcp_md5sig_key *new)
Expand Down Expand Up @@ -1171,9 +1171,9 @@ static int tcp_md5sig_info_add(struct sock *sk, gfp_t gfp)
}

/* This can be called on a newly created socket, from other files */
int tcp_md5_do_add(struct sock *sk, const union tcp_md5_addr *addr,
int family, u8 prefixlen, int l3index, u8 flags,
const u8 *newkey, u8 newkeylen, gfp_t gfp)
int __tcp_md5_do_add(struct sock *sk, const union tcp_md5_addr *addr,
int family, u8 prefixlen, int l3index, u8 flags,
const u8 *newkey, u8 newkeylen, gfp_t gfp)
{
/* Add Key to the list */
struct tcp_md5sig_key *key;
Expand All @@ -1200,9 +1200,6 @@ int tcp_md5_do_add(struct sock *sk, const union tcp_md5_addr *addr,
return 0;
}

if (tcp_md5sig_info_add(sk, gfp))
return -ENOMEM;

md5sig = rcu_dereference_protected(tp->md5sig_info,
lockdep_sock_is_held(sk));

Expand All @@ -1226,8 +1223,36 @@ int tcp_md5_do_add(struct sock *sk, const union tcp_md5_addr *addr,
hlist_add_head_rcu(&key->node, &md5sig->head);
return 0;
}

int tcp_md5_do_add(struct sock *sk, const union tcp_md5_addr *addr,
int family, u8 prefixlen, int l3index, u8 flags,
const u8 *newkey, u8 newkeylen)
{
if (tcp_md5sig_info_add(sk, GFP_KERNEL))
return -ENOMEM;

static_branch_inc(&tcp_md5_needed.key);

return __tcp_md5_do_add(sk, addr, family, prefixlen, l3index, flags,
newkey, newkeylen, GFP_KERNEL);
}
EXPORT_SYMBOL(tcp_md5_do_add);

int tcp_md5_key_copy(struct sock *sk, const union tcp_md5_addr *addr,
int family, u8 prefixlen, int l3index,
struct tcp_md5sig_key *key)
{
if (tcp_md5sig_info_add(sk, sk_gfp_mask(sk, GFP_ATOMIC)))
return -ENOMEM;

atomic_inc(&tcp_md5_needed.key.key.enabled);

return __tcp_md5_do_add(sk, addr, family, prefixlen, l3index,
key->flags, key->key, key->keylen,
sk_gfp_mask(sk, GFP_ATOMIC));
}
EXPORT_SYMBOL(tcp_md5_key_copy);

int tcp_md5_do_del(struct sock *sk, const union tcp_md5_addr *addr, int family,
u8 prefixlen, int l3index, u8 flags)
{
Expand Down Expand Up @@ -1314,7 +1339,7 @@ static int tcp_v4_parse_md5_keys(struct sock *sk, int optname,
return -EINVAL;

return tcp_md5_do_add(sk, addr, AF_INET, prefixlen, l3index, flags,
cmd.tcpm_key, cmd.tcpm_keylen, GFP_KERNEL);
cmd.tcpm_key, cmd.tcpm_keylen);
}

static int tcp_v4_md5_hash_headers(struct tcp_md5sig_pool *hp,
Expand Down Expand Up @@ -1571,8 +1596,7 @@ struct sock *tcp_v4_syn_recv_sock(const struct sock *sk, struct sk_buff *skb,
* memory, then we end up not copying the key
* across. Shucks.
*/
tcp_md5_do_add(newsk, addr, AF_INET, 32, l3index, key->flags,
key->key, key->keylen, GFP_ATOMIC);
tcp_md5_key_copy(newsk, addr, AF_INET, 32, l3index, key);
sk_gso_disable(newsk);
}
#endif
Expand Down Expand Up @@ -2260,6 +2284,7 @@ void tcp_v4_destroy_sock(struct sock *sk)
tcp_clear_md5_list(sk);
kfree_rcu(rcu_dereference_protected(tp->md5sig_info, 1), rcu);
tp->md5sig_info = NULL;
static_branch_slow_dec_deferred(&tcp_md5_needed);
}
#endif

Expand Down
9 changes: 6 additions & 3 deletions net/ipv4/tcp_minisocks.c
Expand Up @@ -291,13 +291,14 @@ void tcp_time_wait(struct sock *sk, int state, int timeo)
*/
do {
tcptw->tw_md5_key = NULL;
if (static_branch_unlikely(&tcp_md5_needed)) {
if (static_branch_unlikely(&tcp_md5_needed.key)) {
struct tcp_md5sig_key *key;

key = tp->af_specific->md5_lookup(sk, sk);
if (key) {
tcptw->tw_md5_key = kmemdup(key, sizeof(*key), GFP_ATOMIC);
BUG_ON(tcptw->tw_md5_key && !tcp_alloc_md5sig_pool());
atomic_inc(&tcp_md5_needed.key.key.enabled);
}
}
} while (0);
Expand Down Expand Up @@ -337,11 +338,13 @@ EXPORT_SYMBOL(tcp_time_wait);
void tcp_twsk_destructor(struct sock *sk)
{
#ifdef CONFIG_TCP_MD5SIG
if (static_branch_unlikely(&tcp_md5_needed)) {
if (static_branch_unlikely(&tcp_md5_needed.key)) {
struct tcp_timewait_sock *twsk = tcp_twsk(sk);

if (twsk->tw_md5_key)
if (twsk->tw_md5_key) {
kfree_rcu(twsk->tw_md5_key, rcu);
static_branch_slow_dec_deferred(&tcp_md5_needed);
}
}
#endif
}
Expand Down
4 changes: 2 additions & 2 deletions net/ipv4/tcp_output.c
Expand Up @@ -766,7 +766,7 @@ static unsigned int tcp_syn_options(struct sock *sk, struct sk_buff *skb,

*md5 = NULL;
#ifdef CONFIG_TCP_MD5SIG
if (static_branch_unlikely(&tcp_md5_needed) &&
if (static_branch_unlikely(&tcp_md5_needed.key) &&
rcu_access_pointer(tp->md5sig_info)) {
*md5 = tp->af_specific->md5_lookup(sk, sk);
if (*md5) {
Expand Down Expand Up @@ -922,7 +922,7 @@ static unsigned int tcp_established_options(struct sock *sk, struct sk_buff *skb

*md5 = NULL;
#ifdef CONFIG_TCP_MD5SIG
if (static_branch_unlikely(&tcp_md5_needed) &&
if (static_branch_unlikely(&tcp_md5_needed.key) &&
rcu_access_pointer(tp->md5sig_info)) {
*md5 = tp->af_specific->md5_lookup(sk, sk);
if (*md5) {
Expand Down
10 changes: 4 additions & 6 deletions net/ipv6/tcp_ipv6.c
Expand Up @@ -658,12 +658,11 @@ static int tcp_v6_parse_md5_keys(struct sock *sk, int optname,
if (ipv6_addr_v4mapped(&sin6->sin6_addr))
return tcp_md5_do_add(sk, (union tcp_md5_addr *)&sin6->sin6_addr.s6_addr32[3],
AF_INET, prefixlen, l3index, flags,
cmd.tcpm_key, cmd.tcpm_keylen,
GFP_KERNEL);
cmd.tcpm_key, cmd.tcpm_keylen);

return tcp_md5_do_add(sk, (union tcp_md5_addr *)&sin6->sin6_addr,
AF_INET6, prefixlen, l3index, flags,
cmd.tcpm_key, cmd.tcpm_keylen, GFP_KERNEL);
cmd.tcpm_key, cmd.tcpm_keylen);
}

static int tcp_v6_md5_hash_headers(struct tcp_md5sig_pool *hp,
Expand Down Expand Up @@ -1359,9 +1358,8 @@ static struct sock *tcp_v6_syn_recv_sock(const struct sock *sk, struct sk_buff *
* memory, then we end up not copying the key
* across. Shucks.
*/
tcp_md5_do_add(newsk, (union tcp_md5_addr *)&newsk->sk_v6_daddr,
AF_INET6, 128, l3index, key->flags, key->key, key->keylen,
sk_gfp_mask(sk, GFP_ATOMIC));
tcp_md5_key_copy(newsk, (union tcp_md5_addr *)&newsk->sk_v6_daddr,
AF_INET6, 128, l3index, key);
}
#endif

Expand Down

0 comments on commit a4ee3ec

Please sign in to comment.