Skip to content

Commit

Permalink
net: Update bhash2 when socket's rcv saddr changes
Browse files Browse the repository at this point in the history
Commit d5a42de ("net: Add a second bind table hashed by port and
address") added a second bind table, bhash2, that hashes by a socket's port
and rcv address.

However, there are two cases where the socket's rcv saddr can change
after it has been binded:

1) The case where there is a bind() call on "::" (IPADDR_ANY) and then
a connect() call. The kernel will assign the socket an address when it
handles the connect()

2) In inet_sk_reselect_saddr(), which is called when rerouting fails
when rebuilding the sk header (invoked by inet_sk_rebuild_header)

In these two cases, we need to update the bhash2 table by removing the
entry for the old address, and adding a new entry reflecting the updated
address.

Reported-by: syzbot+015d756bbd1f8b5c8f09@syzkaller.appspotmail.com
Fixes: d5a42de ("net: Add a second bind table hashed by port and address")
Signed-off-by: Joanne Koong <joannelkoong@gmail.com>
Reviewed-by: Eric Dumazet <edumzet@google.com>
  • Loading branch information
joannekoong authored and intel-lab-lkp committed Jun 1, 2022
1 parent 7e062cd commit d4e9d3a
Show file tree
Hide file tree
Showing 9 changed files with 96 additions and 18 deletions.
6 changes: 4 additions & 2 deletions include/net/inet_hashtables.h
Expand Up @@ -448,11 +448,13 @@ static inline void sk_rcv_saddr_set(struct sock *sk, __be32 addr)
}

int __inet_hash_connect(struct inet_timewait_death_row *death_row,
struct sock *sk, u64 port_offset,
struct sock *sk, u64 port_offset, bool prev_inaddr_any,
int (*check_established)(struct inet_timewait_death_row *,
struct sock *, __u16,
struct inet_timewait_sock **));

int inet_hash_connect(struct inet_timewait_death_row *death_row,
struct sock *sk);
struct sock *sk, bool prev_inaddr_any);

int inet_bhash2_update_saddr(struct sock *sk);
#endif /* _INET_HASHTABLES_H */
2 changes: 1 addition & 1 deletion include/net/ipv6.h
Expand Up @@ -1187,7 +1187,7 @@ int inet6_compat_ioctl(struct socket *sock, unsigned int cmd,
unsigned long arg);

int inet6_hash_connect(struct inet_timewait_death_row *death_row,
struct sock *sk);
struct sock *sk, bool prev_inaddr_any);
int inet6_sendmsg(struct socket *sock, struct msghdr *msg, size_t size);
int inet6_recvmsg(struct socket *sock, struct msghdr *msg, size_t size,
int flags);
Expand Down
10 changes: 7 additions & 3 deletions net/dccp/ipv4.c
Expand Up @@ -47,12 +47,13 @@ int dccp_v4_connect(struct sock *sk, struct sockaddr *uaddr, int addr_len)
const struct sockaddr_in *usin = (struct sockaddr_in *)uaddr;
struct inet_sock *inet = inet_sk(sk);
struct dccp_sock *dp = dccp_sk(sk);
struct ip_options_rcu *inet_opt;
__be16 orig_sport, orig_dport;
bool prev_inaddr_any = false;
__be32 daddr, nexthop;
struct flowi4 *fl4;
struct rtable *rt;
int err;
struct ip_options_rcu *inet_opt;

dp->dccps_role = DCCP_ROLE_CLIENT;

Expand Down Expand Up @@ -89,8 +90,11 @@ int dccp_v4_connect(struct sock *sk, struct sockaddr *uaddr, int addr_len)
if (inet_opt == NULL || !inet_opt->opt.srr)
daddr = fl4->daddr;

if (inet->inet_saddr == 0)
if (inet->inet_saddr == 0) {
inet->inet_saddr = fl4->saddr;
prev_inaddr_any = true;
}

sk_rcv_saddr_set(sk, inet->inet_saddr);
inet->inet_dport = usin->sin_port;
sk_daddr_set(sk, daddr);
Expand All @@ -105,7 +109,7 @@ int dccp_v4_connect(struct sock *sk, struct sockaddr *uaddr, int addr_len)
* complete initialization after this.
*/
dccp_set_state(sk, DCCP_REQUESTING);
err = inet_hash_connect(&dccp_death_row, sk);
err = inet_hash_connect(&dccp_death_row, sk, prev_inaddr_any);
if (err != 0)
goto failure;

Expand Down
4 changes: 3 additions & 1 deletion net/dccp/ipv6.c
Expand Up @@ -824,6 +824,7 @@ static int dccp_v6_connect(struct sock *sk, struct sockaddr *uaddr,
struct ipv6_pinfo *np = inet6_sk(sk);
struct dccp_sock *dp = dccp_sk(sk);
struct in6_addr *saddr = NULL, *final_p, final;
bool prev_inaddr_any = false;
struct ipv6_txoptions *opt;
struct flowi6 fl6;
struct dst_entry *dst;
Expand Down Expand Up @@ -936,6 +937,7 @@ static int dccp_v6_connect(struct sock *sk, struct sockaddr *uaddr,
if (saddr == NULL) {
saddr = &fl6.saddr;
sk->sk_v6_rcv_saddr = *saddr;
prev_inaddr_any = true;
}

/* set the source address */
Expand All @@ -951,7 +953,7 @@ static int dccp_v6_connect(struct sock *sk, struct sockaddr *uaddr,
inet->inet_dport = usin->sin6_port;

dccp_set_state(sk, DCCP_REQUESTING);
err = inet6_hash_connect(&dccp_death_row, sk);
err = inet6_hash_connect(&dccp_death_row, sk, prev_inaddr_any);
if (err)
goto late_failure;

Expand Down
7 changes: 6 additions & 1 deletion net/ipv4/af_inet.c
Expand Up @@ -1221,10 +1221,11 @@ static int inet_sk_reselect_saddr(struct sock *sk)
struct inet_sock *inet = inet_sk(sk);
__be32 old_saddr = inet->inet_saddr;
__be32 daddr = inet->inet_daddr;
struct ip_options_rcu *inet_opt;
struct flowi4 *fl4;
struct rtable *rt;
__be32 new_saddr;
struct ip_options_rcu *inet_opt;
int err;

inet_opt = rcu_dereference_protected(inet->inet_opt,
lockdep_sock_is_held(sk));
Expand Down Expand Up @@ -1253,6 +1254,10 @@ static int inet_sk_reselect_saddr(struct sock *sk)

inet->inet_saddr = inet->inet_rcv_saddr = new_saddr;

err = inet_bhash2_update_saddr(sk);
if (err)
return err;

/*
* XXX The only one ugly spot where we need to
* XXX really change the sockets identity after
Expand Down
69 changes: 64 additions & 5 deletions net/ipv4/inet_hashtables.c
Expand Up @@ -826,6 +826,54 @@ inet_bind2_bucket_find(struct inet_hashinfo *hinfo, struct net *net,
return bhash2;
}

/* the lock for the socket's corresponding bhash entry must be held */
int __inet_bhash2_update_saddr(struct sock *sk, struct inet_hashinfo *hinfo,
struct net *net, int port, int l3mdev)
{
struct inet_bind2_hashbucket *head2;
struct inet_bind2_bucket *tb2;

tb2 = inet_bind2_bucket_find(hinfo, net, port, l3mdev, sk,
&head2);
if (!tb2) {
tb2 = inet_bind2_bucket_create(hinfo->bind2_bucket_cachep,
net, head2, port, l3mdev, sk);
if (!tb2)
return -ENOMEM;
}

/* Remove the socket's old entry from bhash2 */
__sk_del_bind2_node(sk);

sk_add_bind2_node(sk, &tb2->owners);
inet_csk(sk)->icsk_bind2_hash = tb2;

return 0;
}

/* This should be called if/when a socket's rcv saddr changes after it has
* been binded.
*/
int inet_bhash2_update_saddr(struct sock *sk)
{
struct inet_hashinfo *hinfo = sk->sk_prot->h.hashinfo;
int l3mdev = inet_sk_bound_l3mdev(sk);
struct inet_bind_hashbucket *head;
int port = inet_sk(sk)->inet_num;
struct net *net = sock_net(sk);
int err;

head = &hinfo->bhash[inet_bhashfn(net, port, hinfo->bhash_size)];

spin_lock_bh(&head->lock);

err = __inet_bhash2_update_saddr(sk, hinfo, net, port, l3mdev);

spin_unlock_bh(&head->lock);

return err;
}

/* RFC 6056 3.3.4. Algorithm 4: Double-Hash Port Selection Algorithm
* Note that we use 32bit integers (vs RFC 'short integers')
* because 2^16 is not a multiple of num_ephemeral and this
Expand All @@ -840,7 +888,7 @@ inet_bind2_bucket_find(struct inet_hashinfo *hinfo, struct net *net,
static u32 *table_perturb;

int __inet_hash_connect(struct inet_timewait_death_row *death_row,
struct sock *sk, u64 port_offset,
struct sock *sk, u64 port_offset, bool prev_inaddr_any,
int (*check_established)(struct inet_timewait_death_row *,
struct sock *, __u16, struct inet_timewait_sock **))
{
Expand All @@ -858,11 +906,24 @@ int __inet_hash_connect(struct inet_timewait_death_row *death_row,
int l3mdev;
u32 index;

l3mdev = inet_sk_bound_l3mdev(sk);

if (port) {
head = &hinfo->bhash[inet_bhashfn(net, port,
hinfo->bhash_size)];
tb = inet_csk(sk)->icsk_bind_hash;

spin_lock_bh(&head->lock);

if (prev_inaddr_any) {
ret = __inet_bhash2_update_saddr(sk, hinfo, net, port,
l3mdev);
if (ret) {
spin_unlock_bh(&head->lock);
return ret;
}
}

if (sk_head(&tb->owners) == sk && !sk->sk_bind_node.next) {
inet_ehash_nolisten(sk, NULL, NULL);
spin_unlock_bh(&head->lock);
Expand All @@ -875,8 +936,6 @@ int __inet_hash_connect(struct inet_timewait_death_row *death_row,
return ret;
}

l3mdev = inet_sk_bound_l3mdev(sk);

inet_get_local_port_range(net, &low, &high);
high++; /* [32768, 60999] -> [32768, 61000[ */
remaining = high - low;
Expand Down Expand Up @@ -987,13 +1046,13 @@ int __inet_hash_connect(struct inet_timewait_death_row *death_row,
* Bind a port for a connect operation and hash it.
*/
int inet_hash_connect(struct inet_timewait_death_row *death_row,
struct sock *sk)
struct sock *sk, bool prev_inaddr_any)
{
u64 port_offset = 0;

if (!inet_sk(sk)->inet_num)
port_offset = inet_sk_port_offset(sk);
return __inet_hash_connect(death_row, sk, port_offset,
return __inet_hash_connect(death_row, sk, port_offset, prev_inaddr_any,
__inet_check_established);
}
EXPORT_SYMBOL_GPL(inet_hash_connect);
Expand Down
8 changes: 6 additions & 2 deletions net/ipv4/tcp_ipv4.c
Expand Up @@ -203,6 +203,7 @@ int tcp_v4_connect(struct sock *sk, struct sockaddr *uaddr, int addr_len)
struct inet_sock *inet = inet_sk(sk);
struct tcp_sock *tp = tcp_sk(sk);
__be16 orig_sport, orig_dport;
bool prev_inaddr_any = false;
__be32 daddr, nexthop;
struct flowi4 *fl4;
struct rtable *rt;
Expand Down Expand Up @@ -246,8 +247,11 @@ int tcp_v4_connect(struct sock *sk, struct sockaddr *uaddr, int addr_len)
if (!inet_opt || !inet_opt->opt.srr)
daddr = fl4->daddr;

if (!inet->inet_saddr)
if (!inet->inet_saddr) {
inet->inet_saddr = fl4->saddr;
prev_inaddr_any = true;
}

sk_rcv_saddr_set(sk, inet->inet_saddr);

if (tp->rx_opt.ts_recent_stamp && inet->inet_daddr != daddr) {
Expand All @@ -273,7 +277,7 @@ int tcp_v4_connect(struct sock *sk, struct sockaddr *uaddr, int addr_len)
* complete initialization after this.
*/
tcp_set_state(sk, TCP_SYN_SENT);
err = inet_hash_connect(tcp_death_row, sk);
err = inet_hash_connect(tcp_death_row, sk, prev_inaddr_any);
if (err)
goto failure;

Expand Down
4 changes: 2 additions & 2 deletions net/ipv6/inet6_hashtables.c
Expand Up @@ -317,13 +317,13 @@ static u64 inet6_sk_port_offset(const struct sock *sk)
}

int inet6_hash_connect(struct inet_timewait_death_row *death_row,
struct sock *sk)
struct sock *sk, bool prev_inaddr_any)
{
u64 port_offset = 0;

if (!inet_sk(sk)->inet_num)
port_offset = inet6_sk_port_offset(sk);
return __inet_hash_connect(death_row, sk, port_offset,
return __inet_hash_connect(death_row, sk, port_offset, prev_inaddr_any,
__inet6_check_established);
}
EXPORT_SYMBOL_GPL(inet6_hash_connect);
Expand Down
4 changes: 3 additions & 1 deletion net/ipv6/tcp_ipv6.c
Expand Up @@ -152,6 +152,7 @@ static int tcp_v6_connect(struct sock *sk, struct sockaddr *uaddr,
struct ipv6_pinfo *np = tcp_inet6_sk(sk);
struct tcp_sock *tp = tcp_sk(sk);
struct in6_addr *saddr = NULL, *final_p, final;
bool prev_inaddr_any = false;
struct ipv6_txoptions *opt;
struct flowi6 fl6;
struct dst_entry *dst;
Expand Down Expand Up @@ -289,6 +290,7 @@ static int tcp_v6_connect(struct sock *sk, struct sockaddr *uaddr,
if (!saddr) {
saddr = &fl6.saddr;
sk->sk_v6_rcv_saddr = *saddr;
prev_inaddr_any = true;
}

/* set the source address */
Expand All @@ -309,7 +311,7 @@ static int tcp_v6_connect(struct sock *sk, struct sockaddr *uaddr,

tcp_set_state(sk, TCP_SYN_SENT);
tcp_death_row = sock_net(sk)->ipv4.tcp_death_row;
err = inet6_hash_connect(tcp_death_row, sk);
err = inet6_hash_connect(tcp_death_row, sk, prev_inaddr_any);
if (err)
goto late_failure;

Expand Down

0 comments on commit d4e9d3a

Please sign in to comment.