Skip to content

Commit

Permalink
sock: introduce sk_prot->update_proto()
Browse files Browse the repository at this point in the history
Currently sockmap calls into each protocol directly to obtain the
replacement struct proto and install it. This cannot work when the
protocol is built as a module, as is the case for AF_UNIX, for example.

Introduce a new op, sk->sk_prot->update_proto(), so that each protocol
can implement its own way of replacing the struct proto.

Cc: John Fastabend <john.fastabend@gmail.com>
Cc: Daniel Borkmann <daniel@iogearbox.net>
Cc: Jakub Sitnicki <jakub@cloudflare.com>
Cc: Lorenz Bauer <lmb@cloudflare.com>
Signed-off-by: Cong Wang <cong.wang@bytedance.com>
  • Loading branch information
Cong Wang committed Mar 1, 2021
1 parent 4a99cd3 commit 55994db
Show file tree
Hide file tree
Showing 12 changed files with 56 additions and 45 deletions.
18 changes: 3 additions & 15 deletions include/linux/skmsg.h
Expand Up @@ -98,6 +98,7 @@ struct sk_psock {
void (*saved_close)(struct sock *sk, long timeout);
void (*saved_write_space)(struct sock *sk);
void (*saved_data_ready)(struct sock *sk);
int (*saved_update_proto)(struct sock *sk, bool restore);
struct proto *sk_proto;
struct sk_psock_work_state work_state;
struct work_struct work;
Expand Down Expand Up @@ -350,25 +351,12 @@ static inline void sk_psock_cork_free(struct sk_psock *psock)
}
}

/*
 * Point sk->sk_prot at @ops.
 *
 * @sk:    socket whose sk_prot pointer is overwritten
 * @psock: not referenced in this body
 * @ops:   struct proto to store into sk->sk_prot
 *
 * The store goes through WRITE_ONCE() because, per the note below,
 * sk_prot is read without a lock in sk_clone_lock().
 */
static inline void sk_psock_update_proto(struct sock *sk,
struct sk_psock *psock,
struct proto *ops)
{
/* Pairs with lockless read in sk_clone_lock() */
WRITE_ONCE(sk->sk_prot, ops);
}

static inline void sk_psock_restore_proto(struct sock *sk,
struct sk_psock *psock)
{
sk->sk_prot->unhash = psock->saved_unhash;
if (inet_csk_has_ulp(sk)) {
tcp_update_ulp(sk, psock->sk_proto, psock->saved_write_space);
} else {
sk->sk_write_space = psock->saved_write_space;
/* Pairs with lockless read in sk_clone_lock() */
WRITE_ONCE(sk->sk_prot, psock->sk_proto);
}
if (psock->saved_update_proto)
psock->saved_update_proto(sk, true);
}

static inline void sk_psock_set_state(struct sk_psock *psock,
Expand Down
3 changes: 3 additions & 0 deletions include/net/sock.h
Expand Up @@ -1184,6 +1184,9 @@ struct proto {
void (*unhash)(struct sock *sk);
void (*rehash)(struct sock *sk);
int (*get_port)(struct sock *sk, unsigned short snum);
#ifdef CONFIG_BPF_SYSCALL
int (*update_proto)(struct sock *sk, bool restore);
#endif

/* Keeping track of sockets in use */
#ifdef CONFIG_PROC_FS
Expand Down
1 change: 1 addition & 0 deletions include/net/tcp.h
Expand Up @@ -2203,6 +2203,7 @@ struct sk_psock;

#ifdef CONFIG_BPF_SYSCALL
struct proto *tcp_bpf_get_proto(struct sock *sk, struct sk_psock *psock);
int tcp_bpf_update_proto(struct sock *sk, bool restore);
void tcp_bpf_clone(const struct sock *sk, struct sock *newsk);
#endif /* CONFIG_BPF_SYSCALL */

Expand Down
1 change: 1 addition & 0 deletions include/net/udp.h
Expand Up @@ -518,6 +518,7 @@ static inline struct sk_buff *udp_rcv_segment(struct sock *sk,
#ifdef CONFIG_BPF_SYSCALL
struct sk_psock;
struct proto *udp_bpf_get_proto(struct sock *sk, struct sk_psock *psock);
int udp_bpf_update_proto(struct sock *sk, bool restore);
#endif

#endif /* _UDP_H */
5 changes: 0 additions & 5 deletions net/core/skmsg.c
Expand Up @@ -563,11 +563,6 @@ struct sk_psock *sk_psock_init(struct sock *sk, int node)

write_lock_bh(&sk->sk_callback_lock);

if (inet_csk_has_ulp(sk)) {
psock = ERR_PTR(-EINVAL);
goto out;
}

if (sk->sk_user_data) {
psock = ERR_PTR(-EBUSY);
goto out;
Expand Down
24 changes: 4 additions & 20 deletions net/core/sock_map.c
Expand Up @@ -184,26 +184,10 @@ static void sock_map_unref(struct sock *sk, void *link_raw)

static int sock_map_init_proto(struct sock *sk, struct sk_psock *psock)
{
struct proto *prot;

switch (sk->sk_type) {
case SOCK_STREAM:
prot = tcp_bpf_get_proto(sk, psock);
break;

case SOCK_DGRAM:
prot = udp_bpf_get_proto(sk, psock);
break;

default:
if (!sk->sk_prot->update_proto)
return -EINVAL;
}

if (IS_ERR(prot))
return PTR_ERR(prot);

sk_psock_update_proto(sk, psock, prot);
return 0;
psock->saved_update_proto = sk->sk_prot->update_proto;
return sk->sk_prot->update_proto(sk, false);
}

static struct sk_psock *sock_map_psock_get_checked(struct sock *sk)
Expand Down Expand Up @@ -570,7 +554,7 @@ static bool sock_map_redirect_allowed(const struct sock *sk)

static bool sock_map_sk_is_suitable(const struct sock *sk)
{
return sk_is_tcp(sk) || sk_is_udp(sk);
return !!sk->sk_prot->update_proto;
}

static bool sock_map_sk_state_allowed(const struct sock *sk)
Expand Down
23 changes: 20 additions & 3 deletions net/ipv4/tcp_bpf.c
Expand Up @@ -601,19 +601,36 @@ static int tcp_bpf_assert_proto_ops(struct proto *ops)
ops->sendpage == tcp_sendpage ? 0 : -ENOTSUPP;
}

struct proto *tcp_bpf_get_proto(struct sock *sk, struct sk_psock *psock)
int tcp_bpf_update_proto(struct sock *sk, bool restore)
{
struct sk_psock *psock = sk_psock(sk);
int family = sk->sk_family == AF_INET6 ? TCP_BPF_IPV6 : TCP_BPF_IPV4;
int config = psock->progs.msg_parser ? TCP_BPF_TX : TCP_BPF_BASE;

if (restore) {
if (inet_csk_has_ulp(sk)) {
tcp_update_ulp(sk, psock->sk_proto, psock->saved_write_space);
} else {
sk->sk_write_space = psock->saved_write_space;
/* Pairs with lockless read in sk_clone_lock() */
WRITE_ONCE(sk->sk_prot, psock->sk_proto);
}
return 0;
}

if (inet_csk_has_ulp(sk))
return -EINVAL;

if (sk->sk_family == AF_INET6) {
if (tcp_bpf_assert_proto_ops(psock->sk_proto))
return ERR_PTR(-EINVAL);
return -EINVAL;

tcp_bpf_check_v6_needs_rebuild(psock->sk_proto);
}

return &tcp_bpf_prots[family][config];
/* Pairs with lockless read in sk_clone_lock() */
WRITE_ONCE(sk->sk_prot, &tcp_bpf_prots[family][config]);
return 0;
}

/* If a child got cloned from a listening socket that had tcp_bpf
Expand Down
3 changes: 3 additions & 0 deletions net/ipv4/tcp_ipv4.c
Expand Up @@ -2806,6 +2806,9 @@ struct proto tcp_prot = {
.hash = inet_hash,
.unhash = inet_unhash,
.get_port = inet_csk_get_port,
#ifdef CONFIG_BPF_SYSCALL
.update_proto = tcp_bpf_update_proto,
#endif
.enter_memory_pressure = tcp_enter_memory_pressure,
.leave_memory_pressure = tcp_leave_memory_pressure,
.stream_memory_free = tcp_stream_memory_free,
Expand Down
3 changes: 3 additions & 0 deletions net/ipv4/udp.c
Expand Up @@ -2849,6 +2849,9 @@ struct proto udp_prot = {
.unhash = udp_lib_unhash,
.rehash = udp_v4_rehash,
.get_port = udp_v4_get_port,
#ifdef CONFIG_BPF_SYSCALL
.update_proto = udp_bpf_update_proto,
#endif
.memory_allocated = &udp_memory_allocated,
.sysctl_mem = sysctl_udp_mem,
.sysctl_wmem_offset = offsetof(struct net, ipv4.sysctl_udp_wmem_min),
Expand Down
14 changes: 12 additions & 2 deletions net/ipv4/udp_bpf.c
Expand Up @@ -41,12 +41,22 @@ static int __init udp_bpf_v4_build_proto(void)
}
core_initcall(udp_bpf_v4_build_proto);

struct proto *udp_bpf_get_proto(struct sock *sk, struct sk_psock *psock)
int udp_bpf_update_proto(struct sock *sk, bool restore)
{
int family = sk->sk_family == AF_INET ? UDP_BPF_IPV4 : UDP_BPF_IPV6;
struct sk_psock *psock = sk_psock(sk);

if (restore) {
sk->sk_write_space = psock->saved_write_space;
/* Pairs with lockless read in sk_clone_lock() */
WRITE_ONCE(sk->sk_prot, psock->sk_proto);
return 0;
}

if (sk->sk_family == AF_INET6)
udp_bpf_check_v6_needs_rebuild(psock->sk_proto);

return &udp_bpf_prots[family];
/* Pairs with lockless read in sk_clone_lock() */
WRITE_ONCE(sk->sk_prot, &udp_bpf_prots[family]);
return 0;
}
3 changes: 3 additions & 0 deletions net/ipv6/tcp_ipv6.c
Expand Up @@ -2134,6 +2134,9 @@ struct proto tcpv6_prot = {
.hash = inet6_hash,
.unhash = inet_unhash,
.get_port = inet_csk_get_port,
#ifdef CONFIG_BPF_SYSCALL
.update_proto = tcp_bpf_update_proto,
#endif
.enter_memory_pressure = tcp_enter_memory_pressure,
.leave_memory_pressure = tcp_leave_memory_pressure,
.stream_memory_free = tcp_stream_memory_free,
Expand Down
3 changes: 3 additions & 0 deletions net/ipv6/udp.c
Expand Up @@ -1713,6 +1713,9 @@ struct proto udpv6_prot = {
.unhash = udp_lib_unhash,
.rehash = udp_v6_rehash,
.get_port = udp_v6_get_port,
#ifdef CONFIG_BPF_SYSCALL
.update_proto = udp_bpf_update_proto,
#endif
.memory_allocated = &udp_memory_allocated,
.sysctl_mem = sysctl_udp_mem,
.sysctl_wmem_offset = offsetof(struct net, ipv4.sysctl_udp_wmem_min),
Expand Down

0 comments on commit 55994db

Please sign in to comment.