Skip to content

Commit c274af2

Browse files
edumazetdavem330
authored andcommitted
inet: introduce inet->inet_flags
Various inet fields are currently racy. do_ip_setsockopt() and do_ip_getsockopt() are mostly holding the socket lock, but some (fast) paths do not. Use a new inet->inet_flags to hold atomic bits in the series. Remove inet->cmsg_flags, and use instead 9 bits from inet_flags. Signed-off-by: Eric Dumazet <edumazet@google.com> Acked-by: Soheil Hassas Yeganeh <soheil@google.com> Reviewed-by: Simon Horman <horms@kernel.org> Signed-off-by: David S. Miller <davem@davemloft.net>
1 parent 936db83 commit c274af2

File tree

8 files changed

+83
-71
lines changed

8 files changed

+83
-71
lines changed

include/net/inet_sock.h

Lines changed: 44 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,7 @@ struct rtable;
194194
* @inet_rcv_saddr - Bound local IPv4 addr
195195
* @inet_dport - Destination port
196196
* @inet_num - Local port
197+
* @inet_flags - various atomic flags
197198
* @inet_saddr - Sending source
198199
* @uc_ttl - Unicast TTL
199200
* @inet_sport - Source port
@@ -218,11 +219,11 @@ struct inet_sock {
218219
#define inet_dport sk.__sk_common.skc_dport
219220
#define inet_num sk.__sk_common.skc_num
220221

222+
unsigned long inet_flags;
221223
__be32 inet_saddr;
222224
__s16 uc_ttl;
223-
__u16 cmsg_flags;
224-
struct ip_options_rcu __rcu *inet_opt;
225225
__be16 inet_sport;
226+
struct ip_options_rcu __rcu *inet_opt;
226227
__u16 inet_id;
227228

228229
__u8 tos;
@@ -259,16 +260,48 @@ struct inet_sock {
259260
#define IPCORK_OPT 1 /* ip-options has been held in ipcork.opt */
260261
#define IPCORK_ALLFRAG 2 /* always fragment (for ipv6 for now) */
261262

263+
enum {
264+
INET_FLAGS_PKTINFO = 0,
265+
INET_FLAGS_TTL = 1,
266+
INET_FLAGS_TOS = 2,
267+
INET_FLAGS_RECVOPTS = 3,
268+
INET_FLAGS_RETOPTS = 4,
269+
INET_FLAGS_PASSSEC = 5,
270+
INET_FLAGS_ORIGDSTADDR = 6,
271+
INET_FLAGS_CHECKSUM = 7,
272+
INET_FLAGS_RECVFRAGSIZE = 8,
273+
};
274+
262275
/* cmsg flags for inet */
263-
#define IP_CMSG_PKTINFO BIT(0)
264-
#define IP_CMSG_TTL BIT(1)
265-
#define IP_CMSG_TOS BIT(2)
266-
#define IP_CMSG_RECVOPTS BIT(3)
267-
#define IP_CMSG_RETOPTS BIT(4)
268-
#define IP_CMSG_PASSSEC BIT(5)
269-
#define IP_CMSG_ORIGDSTADDR BIT(6)
270-
#define IP_CMSG_CHECKSUM BIT(7)
271-
#define IP_CMSG_RECVFRAGSIZE BIT(8)
276+
#define IP_CMSG_PKTINFO BIT(INET_FLAGS_PKTINFO)
277+
#define IP_CMSG_TTL BIT(INET_FLAGS_TTL)
278+
#define IP_CMSG_TOS BIT(INET_FLAGS_TOS)
279+
#define IP_CMSG_RECVOPTS BIT(INET_FLAGS_RECVOPTS)
280+
#define IP_CMSG_RETOPTS BIT(INET_FLAGS_RETOPTS)
281+
#define IP_CMSG_PASSSEC BIT(INET_FLAGS_PASSSEC)
282+
#define IP_CMSG_ORIGDSTADDR BIT(INET_FLAGS_ORIGDSTADDR)
283+
#define IP_CMSG_CHECKSUM BIT(INET_FLAGS_CHECKSUM)
284+
#define IP_CMSG_RECVFRAGSIZE BIT(INET_FLAGS_RECVFRAGSIZE)
285+
286+
#define IP_CMSG_ALL (IP_CMSG_PKTINFO | IP_CMSG_TTL | \
287+
IP_CMSG_TOS | IP_CMSG_RECVOPTS | \
288+
IP_CMSG_RETOPTS | IP_CMSG_PASSSEC | \
289+
IP_CMSG_ORIGDSTADDR | IP_CMSG_CHECKSUM | \
290+
IP_CMSG_RECVFRAGSIZE)
291+
292+
static inline unsigned long inet_cmsg_flags(const struct inet_sock *inet)
293+
{
294+
return READ_ONCE(inet->inet_flags) & IP_CMSG_ALL;
295+
}
296+
297+
#define inet_test_bit(nr, sk) \
298+
test_bit(INET_FLAGS_##nr, &inet_sk(sk)->inet_flags)
299+
#define inet_set_bit(nr, sk) \
300+
set_bit(INET_FLAGS_##nr, &inet_sk(sk)->inet_flags)
301+
#define inet_clear_bit(nr, sk) \
302+
clear_bit(INET_FLAGS_##nr, &inet_sk(sk)->inet_flags)
303+
#define inet_assign_bit(nr, sk, val) \
304+
assign_bit(INET_FLAGS_##nr, &inet_sk(sk)->inet_flags, val)
272305

273306
static inline bool sk_is_inet(struct sock *sk)
274307
{

net/ipv4/ip_sockglue.c

Lines changed: 31 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -171,8 +171,10 @@ static void ip_cmsg_recv_dstaddr(struct msghdr *msg, struct sk_buff *skb)
171171
void ip_cmsg_recv_offset(struct msghdr *msg, struct sock *sk,
172172
struct sk_buff *skb, int tlen, int offset)
173173
{
174-
struct inet_sock *inet = inet_sk(sk);
175-
unsigned int flags = inet->cmsg_flags;
174+
unsigned long flags = inet_cmsg_flags(inet_sk(sk));
175+
176+
if (!flags)
177+
return;
176178

177179
/* Ordered by supposed usage frequency */
178180
if (flags & IP_CMSG_PKTINFO) {
@@ -568,7 +570,7 @@ int ip_recv_error(struct sock *sk, struct msghdr *msg, int len, int *addr_len)
568570
if (ipv4_datagram_support_cmsg(sk, skb, serr->ee.ee_origin)) {
569571
sin->sin_family = AF_INET;
570572
sin->sin_addr.s_addr = ip_hdr(skb)->saddr;
571-
if (inet_sk(sk)->cmsg_flags)
573+
if (inet_cmsg_flags(inet_sk(sk)))
572574
ip_cmsg_recv(msg, skb);
573575
}
574576

@@ -635,7 +637,7 @@ EXPORT_SYMBOL(ip_sock_set_mtu_discover);
635637
void ip_sock_set_pktinfo(struct sock *sk)
636638
{
637639
lock_sock(sk);
638-
inet_sk(sk)->cmsg_flags |= IP_CMSG_PKTINFO;
640+
inet_set_bit(PKTINFO, sk);
639641
release_sock(sk);
640642
}
641643
EXPORT_SYMBOL(ip_sock_set_pktinfo);
@@ -990,67 +992,43 @@ int do_ip_setsockopt(struct sock *sk, int level, int optname,
990992
break;
991993
}
992994
case IP_PKTINFO:
993-
if (val)
994-
inet->cmsg_flags |= IP_CMSG_PKTINFO;
995-
else
996-
inet->cmsg_flags &= ~IP_CMSG_PKTINFO;
995+
inet_assign_bit(PKTINFO, sk, val);
997996
break;
998997
case IP_RECVTTL:
999-
if (val)
1000-
inet->cmsg_flags |= IP_CMSG_TTL;
1001-
else
1002-
inet->cmsg_flags &= ~IP_CMSG_TTL;
998+
inet_assign_bit(TTL, sk, val);
1003999
break;
10041000
case IP_RECVTOS:
1005-
if (val)
1006-
inet->cmsg_flags |= IP_CMSG_TOS;
1007-
else
1008-
inet->cmsg_flags &= ~IP_CMSG_TOS;
1001+
inet_assign_bit(TOS, sk, val);
10091002
break;
10101003
case IP_RECVOPTS:
1011-
if (val)
1012-
inet->cmsg_flags |= IP_CMSG_RECVOPTS;
1013-
else
1014-
inet->cmsg_flags &= ~IP_CMSG_RECVOPTS;
1004+
inet_assign_bit(RECVOPTS, sk, val);
10151005
break;
10161006
case IP_RETOPTS:
1017-
if (val)
1018-
inet->cmsg_flags |= IP_CMSG_RETOPTS;
1019-
else
1020-
inet->cmsg_flags &= ~IP_CMSG_RETOPTS;
1007+
inet_assign_bit(RETOPTS, sk, val);
10211008
break;
10221009
case IP_PASSSEC:
1023-
if (val)
1024-
inet->cmsg_flags |= IP_CMSG_PASSSEC;
1025-
else
1026-
inet->cmsg_flags &= ~IP_CMSG_PASSSEC;
1010+
inet_assign_bit(PASSSEC, sk, val);
10271011
break;
10281012
case IP_RECVORIGDSTADDR:
1029-
if (val)
1030-
inet->cmsg_flags |= IP_CMSG_ORIGDSTADDR;
1031-
else
1032-
inet->cmsg_flags &= ~IP_CMSG_ORIGDSTADDR;
1013+
inet_assign_bit(ORIGDSTADDR, sk, val);
10331014
break;
10341015
case IP_CHECKSUM:
10351016
if (val) {
1036-
if (!(inet->cmsg_flags & IP_CMSG_CHECKSUM)) {
1017+
if (!(inet_test_bit(CHECKSUM, sk))) {
10371018
inet_inc_convert_csum(sk);
1038-
inet->cmsg_flags |= IP_CMSG_CHECKSUM;
1019+
inet_set_bit(CHECKSUM, sk);
10391020
}
10401021
} else {
1041-
if (inet->cmsg_flags & IP_CMSG_CHECKSUM) {
1022+
if (inet_test_bit(CHECKSUM, sk)) {
10421023
inet_dec_convert_csum(sk);
1043-
inet->cmsg_flags &= ~IP_CMSG_CHECKSUM;
1024+
inet_clear_bit(CHECKSUM, sk);
10441025
}
10451026
}
10461027
break;
10471028
case IP_RECVFRAGSIZE:
10481029
if (sk->sk_type != SOCK_RAW && sk->sk_type != SOCK_DGRAM)
10491030
goto e_inval;
1050-
if (val)
1051-
inet->cmsg_flags |= IP_CMSG_RECVFRAGSIZE;
1052-
else
1053-
inet->cmsg_flags &= ~IP_CMSG_RECVFRAGSIZE;
1031+
inet_assign_bit(RECVFRAGSIZE, sk, val);
10541032
break;
10551033
case IP_TOS: /* This sets both TOS and Precedence */
10561034
__ip_sock_set_tos(sk, val);
@@ -1415,7 +1393,7 @@ int do_ip_setsockopt(struct sock *sk, int level, int optname,
14151393
void ipv4_pktinfo_prepare(const struct sock *sk, struct sk_buff *skb)
14161394
{
14171395
struct in_pktinfo *pktinfo = PKTINFO_SKB_CB(skb);
1418-
bool prepare = (inet_sk(sk)->cmsg_flags & IP_CMSG_PKTINFO) ||
1396+
bool prepare = inet_test_bit(PKTINFO, sk) ||
14191397
ipv6_sk_rxinfo(sk);
14201398

14211399
if (prepare && skb_rtable(skb)) {
@@ -1601,31 +1579,31 @@ int do_ip_getsockopt(struct sock *sk, int level, int optname,
16011579
return 0;
16021580
}
16031581
case IP_PKTINFO:
1604-
val = (inet->cmsg_flags & IP_CMSG_PKTINFO) != 0;
1582+
val = inet_test_bit(PKTINFO, sk);
16051583
break;
16061584
case IP_RECVTTL:
1607-
val = (inet->cmsg_flags & IP_CMSG_TTL) != 0;
1585+
val = inet_test_bit(TTL, sk);
16081586
break;
16091587
case IP_RECVTOS:
1610-
val = (inet->cmsg_flags & IP_CMSG_TOS) != 0;
1588+
val = inet_test_bit(TOS, sk);
16111589
break;
16121590
case IP_RECVOPTS:
1613-
val = (inet->cmsg_flags & IP_CMSG_RECVOPTS) != 0;
1591+
val = inet_test_bit(RECVOPTS, sk);
16141592
break;
16151593
case IP_RETOPTS:
1616-
val = (inet->cmsg_flags & IP_CMSG_RETOPTS) != 0;
1594+
val = inet_test_bit(RETOPTS, sk);
16171595
break;
16181596
case IP_PASSSEC:
1619-
val = (inet->cmsg_flags & IP_CMSG_PASSSEC) != 0;
1597+
val = inet_test_bit(PASSSEC, sk);
16201598
break;
16211599
case IP_RECVORIGDSTADDR:
1622-
val = (inet->cmsg_flags & IP_CMSG_ORIGDSTADDR) != 0;
1600+
val = inet_test_bit(ORIGDSTADDR, sk);
16231601
break;
16241602
case IP_CHECKSUM:
1625-
val = (inet->cmsg_flags & IP_CMSG_CHECKSUM) != 0;
1603+
val = inet_test_bit(CHECKSUM, sk);
16261604
break;
16271605
case IP_RECVFRAGSIZE:
1628-
val = (inet->cmsg_flags & IP_CMSG_RECVFRAGSIZE) != 0;
1606+
val = inet_test_bit(RECVFRAGSIZE, sk);
16291607
break;
16301608
case IP_TOS:
16311609
val = inet->tos;
@@ -1737,19 +1715,19 @@ int do_ip_getsockopt(struct sock *sk, int level, int optname,
17371715
msg.msg_controllen = len;
17381716
msg.msg_flags = in_compat_syscall() ? MSG_CMSG_COMPAT : 0;
17391717

1740-
if (inet->cmsg_flags & IP_CMSG_PKTINFO) {
1718+
if (inet_test_bit(PKTINFO, sk)) {
17411719
struct in_pktinfo info;
17421720

17431721
info.ipi_addr.s_addr = inet->inet_rcv_saddr;
17441722
info.ipi_spec_dst.s_addr = inet->inet_rcv_saddr;
17451723
info.ipi_ifindex = inet->mc_index;
17461724
put_cmsg(&msg, SOL_IP, IP_PKTINFO, sizeof(info), &info);
17471725
}
1748-
if (inet->cmsg_flags & IP_CMSG_TTL) {
1726+
if (inet_test_bit(TTL, sk)) {
17491727
int hlim = inet->mc_ttl;
17501728
put_cmsg(&msg, SOL_IP, IP_TTL, sizeof(hlim), &hlim);
17511729
}
1752-
if (inet->cmsg_flags & IP_CMSG_TOS) {
1730+
if (inet_test_bit(TOS, sk)) {
17531731
int tos = inet->rcv_tos;
17541732
put_cmsg(&msg, SOL_IP, IP_TOS, sizeof(tos), &tos);
17551733
}

net/ipv4/ping.c

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -894,7 +894,7 @@ int ping_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, int flags,
894894
*addr_len = sizeof(*sin);
895895
}
896896

897-
if (isk->cmsg_flags)
897+
if (inet_cmsg_flags(isk))
898898
ip_cmsg_recv(msg, skb);
899899

900900
#if IS_ENABLED(CONFIG_IPV6)
@@ -921,7 +921,8 @@ int ping_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, int flags,
921921
if (skb->protocol == htons(ETH_P_IPV6) &&
922922
inet6_sk(sk)->rxopt.all)
923923
pingv6_ops.ip6_datagram_recv_specific_ctl(sk, msg, skb);
924-
else if (skb->protocol == htons(ETH_P_IP) && isk->cmsg_flags)
924+
else if (skb->protocol == htons(ETH_P_IP) &&
925+
inet_cmsg_flags(isk))
925926
ip_cmsg_recv(msg, skb);
926927
#endif
927928
} else {

net/ipv4/raw.c

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -767,7 +767,7 @@ static int raw_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
767767
memset(&sin->sin_zero, 0, sizeof(sin->sin_zero));
768768
*addr_len = sizeof(*sin);
769769
}
770-
if (inet->cmsg_flags)
770+
if (inet_cmsg_flags(inet))
771771
ip_cmsg_recv(msg, skb);
772772
if (flags & MSG_TRUNC)
773773
copied = skb->len;

net/ipv4/udp.c

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1870,7 +1870,7 @@ int udp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, int flags,
18701870
if (udp_sk(sk)->gro_enabled)
18711871
udp_cmsg_recv(msg, sk, skb);
18721872

1873-
if (inet->cmsg_flags)
1873+
if (inet_cmsg_flags(inet))
18741874
ip_cmsg_recv_offset(msg, sk, skb, sizeof(struct udphdr), off);
18751875

18761876
err = copied;

net/ipv6/datagram.c

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -524,7 +524,7 @@ int ipv6_recv_error(struct sock *sk, struct msghdr *msg, int len, int *addr_len)
524524
} else {
525525
ipv6_addr_set_v4mapped(ip_hdr(skb)->saddr,
526526
&sin->sin6_addr);
527-
if (inet_sk(sk)->cmsg_flags)
527+
if (inet_cmsg_flags(inet_sk(sk)))
528528
ip_cmsg_recv(msg, skb);
529529
}
530530
}

net/ipv6/udp.c

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -420,7 +420,7 @@ int udpv6_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
420420
ip6_datagram_recv_common_ctl(sk, msg, skb);
421421

422422
if (is_udp4) {
423-
if (inet->cmsg_flags)
423+
if (inet_cmsg_flags(inet))
424424
ip_cmsg_recv_offset(msg, sk, skb,
425425
sizeof(struct udphdr), off);
426426
} else {

net/l2tp/l2tp_ip.c

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -552,7 +552,7 @@ static int l2tp_ip_recvmsg(struct sock *sk, struct msghdr *msg,
552552
memset(&sin->sin_zero, 0, sizeof(sin->sin_zero));
553553
*addr_len = sizeof(*sin);
554554
}
555-
if (inet->cmsg_flags)
555+
if (inet_cmsg_flags(inet))
556556
ip_cmsg_recv(msg, skb);
557557
if (flags & MSG_TRUNC)
558558
copied = skb->len;

0 commit comments

Comments
 (0)