Skip to content

Commit

Permalink
net: convert sock.sk_wmem_alloc from atomic_t to refcount_t
Browse files Browse the repository at this point in the history
refcount_t type and corresponding API should be
used instead of atomic_t when the variable is used as
a reference counter. This allows to avoid accidental
refcounter overflows that might lead to use-after-free
situations.

Signed-off-by: Elena Reshetova <elena.reshetova@intel.com>
Signed-off-by: Hans Liljestrand <ishkamiel@gmail.com>
Signed-off-by: Kees Cook <keescook@chromium.org>
Signed-off-by: David Windsor <dwindsor@gmail.com>
Signed-off-by: David S. Miller <davem@davemloft.net>
  • Loading branch information
ereshetova authored and davem330 committed Jul 1, 2017
1 parent 2638595 commit 14afee4
Show file tree
Hide file tree
Showing 37 changed files with 74 additions and 85 deletions.
12 changes: 1 addition & 11 deletions drivers/atm/fore200e.c
Original file line number Diff line number Diff line change
Expand Up @@ -924,12 +924,7 @@ fore200e_tx_irq(struct fore200e* fore200e)
else {
dev_kfree_skb_any(entry->skb);
}
#if 1
/* race fixed by the above incarnation mechanism, but... */
if (atomic_read(&sk_atm(vcc)->sk_wmem_alloc) < 0) {
atomic_set(&sk_atm(vcc)->sk_wmem_alloc, 0);
}
#endif

/* check error condition */
if (*entry->status & STATUS_ERROR)
atomic_inc(&vcc->stats->tx_err);
Expand Down Expand Up @@ -1130,13 +1125,9 @@ fore200e_push_rpd(struct fore200e* fore200e, struct atm_vcc* vcc, struct rpd* rp
return -ENOMEM;
}

ASSERT(atomic_read(&sk_atm(vcc)->sk_wmem_alloc) >= 0);

vcc->push(vcc, skb);
atomic_inc(&vcc->stats->rx);

ASSERT(atomic_read(&sk_atm(vcc)->sk_wmem_alloc) >= 0);

return 0;
}

Expand Down Expand Up @@ -1572,7 +1563,6 @@ fore200e_send(struct atm_vcc *vcc, struct sk_buff *skb)
unsigned long flags;

ASSERT(vcc);
ASSERT(atomic_read(&sk_atm(vcc)->sk_wmem_alloc) >= 0);
ASSERT(fore200e);
ASSERT(fore200e_vcc);

Expand Down
2 changes: 1 addition & 1 deletion drivers/atm/he.c
Original file line number Diff line number Diff line change
Expand Up @@ -2395,7 +2395,7 @@ he_close(struct atm_vcc *vcc)
* TBRQ, the host issues the close command to the adapter.
*/

while (((tx_inuse = atomic_read(&sk_atm(vcc)->sk_wmem_alloc)) > 1) &&
while (((tx_inuse = refcount_read(&sk_atm(vcc)->sk_wmem_alloc)) > 1) &&
(retry < MAX_RETRY)) {
msleep(sleep);
if (sleep < 250)
Expand Down
4 changes: 2 additions & 2 deletions drivers/atm/idt77252.c
Original file line number Diff line number Diff line change
Expand Up @@ -724,7 +724,7 @@ push_on_scq(struct idt77252_dev *card, struct vc_map *vc, struct sk_buff *skb)
struct sock *sk = sk_atm(vcc);

vc->estimator->cells += (skb->len + 47) / 48;
if (atomic_read(&sk->sk_wmem_alloc) >
if (refcount_read(&sk->sk_wmem_alloc) >
(sk->sk_sndbuf >> 1)) {
u32 cps = vc->estimator->maxcps;

Expand Down Expand Up @@ -2009,7 +2009,7 @@ idt77252_send_oam(struct atm_vcc *vcc, void *cell, int flags)
atomic_inc(&vcc->stats->tx_err);
return -ENOMEM;
}
atomic_add(skb->truesize, &sk_atm(vcc)->sk_wmem_alloc);
refcount_add(skb->truesize, &sk_atm(vcc)->sk_wmem_alloc);

skb_put_data(skb, cell, 52);

Expand Down
2 changes: 1 addition & 1 deletion include/linux/atmdev.h
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,7 @@ static inline void atm_return(struct atm_vcc *vcc,int truesize)

static inline int atm_may_send(struct atm_vcc *vcc,unsigned int size)
{
return (size + atomic_read(&sk_atm(vcc)->sk_wmem_alloc)) <
return (size + refcount_read(&sk_atm(vcc)->sk_wmem_alloc)) <
sk_atm(vcc)->sk_sndbuf;
}

Expand Down
8 changes: 4 additions & 4 deletions include/net/sock.h
Original file line number Diff line number Diff line change
Expand Up @@ -390,7 +390,7 @@ struct sock {

/* ===== cache line for TX ===== */
int sk_wmem_queued;
atomic_t sk_wmem_alloc;
refcount_t sk_wmem_alloc;
unsigned long sk_tsq_flags;
struct sk_buff *sk_send_head;
struct sk_buff_head sk_write_queue;
Expand Down Expand Up @@ -1911,7 +1911,7 @@ static inline int skb_copy_to_page_nocache(struct sock *sk, struct iov_iter *fro
*/
static inline int sk_wmem_alloc_get(const struct sock *sk)
{
return atomic_read(&sk->sk_wmem_alloc) - 1;
return refcount_read(&sk->sk_wmem_alloc) - 1;
}

/**
Expand Down Expand Up @@ -2055,7 +2055,7 @@ static inline unsigned long sock_wspace(struct sock *sk)
int amt = 0;

if (!(sk->sk_shutdown & SEND_SHUTDOWN)) {
amt = sk->sk_sndbuf - atomic_read(&sk->sk_wmem_alloc);
amt = sk->sk_sndbuf - refcount_read(&sk->sk_wmem_alloc);
if (amt < 0)
amt = 0;
}
Expand Down Expand Up @@ -2136,7 +2136,7 @@ bool sk_page_frag_refill(struct sock *sk, struct page_frag *pfrag);
*/
static inline bool sock_writeable(const struct sock *sk)
{
return atomic_read(&sk->sk_wmem_alloc) < (sk->sk_sndbuf >> 1);
return refcount_read(&sk->sk_wmem_alloc) < (sk->sk_sndbuf >> 1);
}

static inline gfp_t gfp_any(void)
Expand Down
2 changes: 1 addition & 1 deletion net/atm/br2684.c
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,7 @@ static int br2684_xmit_vcc(struct sk_buff *skb, struct net_device *dev,

ATM_SKB(skb)->vcc = atmvcc = brvcc->atmvcc;
pr_debug("atm_skb(%p)->vcc(%p)->dev(%p)\n", skb, atmvcc, atmvcc->dev);
atomic_add(skb->truesize, &sk_atm(atmvcc)->sk_wmem_alloc);
refcount_add(skb->truesize, &sk_atm(atmvcc)->sk_wmem_alloc);
ATM_SKB(skb)->atm_options = atmvcc->atm_options;
dev->stats.tx_packets++;
dev->stats.tx_bytes += skb->len;
Expand Down
2 changes: 1 addition & 1 deletion net/atm/clip.c
Original file line number Diff line number Diff line change
Expand Up @@ -381,7 +381,7 @@ static netdev_tx_t clip_start_xmit(struct sk_buff *skb,
memcpy(here, llc_oui, sizeof(llc_oui));
((__be16 *) here)[3] = skb->protocol;
}
atomic_add(skb->truesize, &sk_atm(vcc)->sk_wmem_alloc);
refcount_add(skb->truesize, &sk_atm(vcc)->sk_wmem_alloc);
ATM_SKB(skb)->atm_options = vcc->atm_options;
entry->vccs->last_use = jiffies;
pr_debug("atm_skb(%p)->vcc(%p)->dev(%p)\n", skb, vcc, vcc->dev);
Expand Down
10 changes: 5 additions & 5 deletions net/atm/common.c
Original file line number Diff line number Diff line change
Expand Up @@ -80,9 +80,9 @@ static void vcc_sock_destruct(struct sock *sk)
printk(KERN_DEBUG "%s: rmem leakage (%d bytes) detected.\n",
__func__, atomic_read(&sk->sk_rmem_alloc));

if (atomic_read(&sk->sk_wmem_alloc))
if (refcount_read(&sk->sk_wmem_alloc))
printk(KERN_DEBUG "%s: wmem leakage (%d bytes) detected.\n",
__func__, atomic_read(&sk->sk_wmem_alloc));
__func__, refcount_read(&sk->sk_wmem_alloc));
}

static void vcc_def_wakeup(struct sock *sk)
Expand All @@ -101,7 +101,7 @@ static inline int vcc_writable(struct sock *sk)
struct atm_vcc *vcc = atm_sk(sk);

return (vcc->qos.txtp.max_sdu +
atomic_read(&sk->sk_wmem_alloc)) <= sk->sk_sndbuf;
refcount_read(&sk->sk_wmem_alloc)) <= sk->sk_sndbuf;
}

static void vcc_write_space(struct sock *sk)
Expand Down Expand Up @@ -156,7 +156,7 @@ int vcc_create(struct net *net, struct socket *sock, int protocol, int family, i
memset(&vcc->local, 0, sizeof(struct sockaddr_atmsvc));
memset(&vcc->remote, 0, sizeof(struct sockaddr_atmsvc));
vcc->qos.txtp.max_sdu = 1 << 16; /* for meta VCs */
atomic_set(&sk->sk_wmem_alloc, 1);
refcount_set(&sk->sk_wmem_alloc, 1);
atomic_set(&sk->sk_rmem_alloc, 0);
vcc->push = NULL;
vcc->pop = NULL;
Expand Down Expand Up @@ -630,7 +630,7 @@ int vcc_sendmsg(struct socket *sock, struct msghdr *m, size_t size)
goto out;
}
pr_debug("%d += %d\n", sk_wmem_alloc_get(sk), skb->truesize);
atomic_add(skb->truesize, &sk->sk_wmem_alloc);
refcount_add(skb->truesize, &sk->sk_wmem_alloc);

skb->dev = NULL; /* for paths shared with net_device interfaces */
ATM_SKB(skb)->atm_options = vcc->atm_options;
Expand Down
4 changes: 2 additions & 2 deletions net/atm/lec.c
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ lec_send(struct atm_vcc *vcc, struct sk_buff *skb)
ATM_SKB(skb)->vcc = vcc;
ATM_SKB(skb)->atm_options = vcc->atm_options;

atomic_add(skb->truesize, &sk_atm(vcc)->sk_wmem_alloc);
refcount_add(skb->truesize, &sk_atm(vcc)->sk_wmem_alloc);
if (vcc->send(vcc, skb) < 0) {
dev->stats.tx_dropped++;
return;
Expand Down Expand Up @@ -345,7 +345,7 @@ static int lec_atm_send(struct atm_vcc *vcc, struct sk_buff *skb)
int i;
char *tmp; /* FIXME */

atomic_sub(skb->truesize, &sk_atm(vcc)->sk_wmem_alloc);
WARN_ON(refcount_sub_and_test(skb->truesize, &sk_atm(vcc)->sk_wmem_alloc));
mesg = (struct atmlec_msg *)skb->data;
tmp = skb->data;
tmp += sizeof(struct atmlec_msg);
Expand Down
4 changes: 2 additions & 2 deletions net/atm/mpc.c
Original file line number Diff line number Diff line change
Expand Up @@ -555,7 +555,7 @@ static int send_via_shortcut(struct sk_buff *skb, struct mpoa_client *mpc)
sizeof(struct llc_snap_hdr));
}

atomic_add(skb->truesize, &sk_atm(entry->shortcut)->sk_wmem_alloc);
refcount_add(skb->truesize, &sk_atm(entry->shortcut)->sk_wmem_alloc);
ATM_SKB(skb)->atm_options = entry->shortcut->atm_options;
entry->shortcut->send(entry->shortcut, skb);
entry->packets_fwded++;
Expand Down Expand Up @@ -911,7 +911,7 @@ static int msg_from_mpoad(struct atm_vcc *vcc, struct sk_buff *skb)

struct mpoa_client *mpc = find_mpc_by_vcc(vcc);
struct k_message *mesg = (struct k_message *)skb->data;
atomic_sub(skb->truesize, &sk_atm(vcc)->sk_wmem_alloc);
WARN_ON(refcount_sub_and_test(skb->truesize, &sk_atm(vcc)->sk_wmem_alloc));

if (mpc == NULL) {
pr_info("no mpc found\n");
Expand Down
2 changes: 1 addition & 1 deletion net/atm/pppoatm.c
Original file line number Diff line number Diff line change
Expand Up @@ -350,7 +350,7 @@ static int pppoatm_send(struct ppp_channel *chan, struct sk_buff *skb)
return 1;
}

atomic_add(skb->truesize, &sk_atm(ATM_SKB(skb)->vcc)->sk_wmem_alloc);
refcount_add(skb->truesize, &sk_atm(ATM_SKB(skb)->vcc)->sk_wmem_alloc);
ATM_SKB(skb)->atm_options = ATM_SKB(skb)->vcc->atm_options;
pr_debug("atm_skb(%p)->vcc(%p)->dev(%p)\n",
skb, ATM_SKB(skb)->vcc, ATM_SKB(skb)->vcc->dev);
Expand Down
2 changes: 1 addition & 1 deletion net/atm/raw.c
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ static void atm_pop_raw(struct atm_vcc *vcc, struct sk_buff *skb)

pr_debug("(%d) %d -= %d\n",
vcc->vci, sk_wmem_alloc_get(sk), skb->truesize);
atomic_sub(skb->truesize, &sk->sk_wmem_alloc);
WARN_ON(refcount_sub_and_test(skb->truesize, &sk->sk_wmem_alloc));
dev_kfree_skb_any(skb);
sk->sk_write_space(sk);
}
Expand Down
2 changes: 1 addition & 1 deletion net/atm/signaling.c
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ static int sigd_send(struct atm_vcc *vcc, struct sk_buff *skb)
struct sock *sk;

msg = (struct atmsvc_msg *) skb->data;
atomic_sub(skb->truesize, &sk_atm(vcc)->sk_wmem_alloc);
WARN_ON(refcount_sub_and_test(skb->truesize, &sk_atm(vcc)->sk_wmem_alloc));
vcc = *(struct atm_vcc **) &msg->vcc;
pr_debug("%d (0x%lx)\n", (int)msg->type, (unsigned long)vcc);
sk = sk_atm(vcc);
Expand Down
2 changes: 1 addition & 1 deletion net/caif/caif_socket.c
Original file line number Diff line number Diff line change
Expand Up @@ -1013,7 +1013,7 @@ static const struct proto_ops caif_stream_ops = {
static void caif_sock_destructor(struct sock *sk)
{
struct caifsock *cf_sk = container_of(sk, struct caifsock, sk);
caif_assert(!atomic_read(&sk->sk_wmem_alloc));
caif_assert(!refcount_read(&sk->sk_wmem_alloc));
caif_assert(sk_unhashed(sk));
caif_assert(!sk->sk_socket);
if (!sock_flag(sk, SOCK_DEAD)) {
Expand Down
2 changes: 1 addition & 1 deletion net/core/datagram.c
Original file line number Diff line number Diff line change
Expand Up @@ -614,7 +614,7 @@ int zerocopy_sg_from_iter(struct sk_buff *skb, struct iov_iter *from)
skb->data_len += copied;
skb->len += copied;
skb->truesize += truesize;
atomic_add(truesize, &skb->sk->sk_wmem_alloc);
refcount_add(truesize, &skb->sk->sk_wmem_alloc);
while (copied) {
int size = min_t(int, copied, PAGE_SIZE - start);
skb_fill_page_desc(skb, frag++, pages[n], start, size);
Expand Down
2 changes: 1 addition & 1 deletion net/core/skbuff.c
Original file line number Diff line number Diff line change
Expand Up @@ -3024,7 +3024,7 @@ int skb_append_datato_frags(struct sock *sk, struct sk_buff *skb,
get_page(pfrag->page);

skb->truesize += copy;
atomic_add(copy, &sk->sk_wmem_alloc);
refcount_add(copy, &sk->sk_wmem_alloc);
skb->len += copy;
skb->data_len += copy;
offset += copy;
Expand Down
26 changes: 13 additions & 13 deletions net/core/sock.c
Original file line number Diff line number Diff line change
Expand Up @@ -1528,7 +1528,7 @@ struct sock *sk_alloc(struct net *net, int family, gfp_t priority,
if (likely(sk->sk_net_refcnt))
get_net(net);
sock_net_set(sk, net);
atomic_set(&sk->sk_wmem_alloc, 1);
refcount_set(&sk->sk_wmem_alloc, 1);

mem_cgroup_sk_alloc(sk);
cgroup_sk_alloc(&sk->sk_cgrp_data);
Expand All @@ -1552,7 +1552,7 @@ static void __sk_destruct(struct rcu_head *head)
sk->sk_destruct(sk);

filter = rcu_dereference_check(sk->sk_filter,
atomic_read(&sk->sk_wmem_alloc) == 0);
refcount_read(&sk->sk_wmem_alloc) == 0);
if (filter) {
sk_filter_uncharge(sk, filter);
RCU_INIT_POINTER(sk->sk_filter, NULL);
Expand Down Expand Up @@ -1602,7 +1602,7 @@ void sk_free(struct sock *sk)
* some packets are still in some tx queue.
* If not null, sock_wfree() will call __sk_free(sk) later
*/
if (atomic_dec_and_test(&sk->sk_wmem_alloc))
if (refcount_dec_and_test(&sk->sk_wmem_alloc))
__sk_free(sk);
}
EXPORT_SYMBOL(sk_free);
Expand Down Expand Up @@ -1659,7 +1659,7 @@ struct sock *sk_clone_lock(const struct sock *sk, const gfp_t priority)
/*
* sk_wmem_alloc set to one (see sk_free() and sock_wfree())
*/
atomic_set(&newsk->sk_wmem_alloc, 1);
refcount_set(&newsk->sk_wmem_alloc, 1);
atomic_set(&newsk->sk_omem_alloc, 0);
sk_init_common(newsk);

Expand Down Expand Up @@ -1787,15 +1787,15 @@ void sock_wfree(struct sk_buff *skb)
* Keep a reference on sk_wmem_alloc, this will be released
* after sk_write_space() call
*/
atomic_sub(len - 1, &sk->sk_wmem_alloc);
WARN_ON(refcount_sub_and_test(len - 1, &sk->sk_wmem_alloc));
sk->sk_write_space(sk);
len = 1;
}
/*
* if sk_wmem_alloc reaches 0, we must finish what sk_free()
* could not do because of in-flight packets
*/
if (atomic_sub_and_test(len, &sk->sk_wmem_alloc))
if (refcount_sub_and_test(len, &sk->sk_wmem_alloc))
__sk_free(sk);
}
EXPORT_SYMBOL(sock_wfree);
Expand All @@ -1807,7 +1807,7 @@ void __sock_wfree(struct sk_buff *skb)
{
struct sock *sk = skb->sk;

if (atomic_sub_and_test(skb->truesize, &sk->sk_wmem_alloc))
if (refcount_sub_and_test(skb->truesize, &sk->sk_wmem_alloc))
__sk_free(sk);
}

Expand All @@ -1829,7 +1829,7 @@ void skb_set_owner_w(struct sk_buff *skb, struct sock *sk)
* is enough to guarantee sk_free() wont free this sock until
* all in-flight packets are completed
*/
atomic_add(skb->truesize, &sk->sk_wmem_alloc);
refcount_add(skb->truesize, &sk->sk_wmem_alloc);
}
EXPORT_SYMBOL(skb_set_owner_w);

Expand All @@ -1852,7 +1852,7 @@ void skb_orphan_partial(struct sk_buff *skb)
struct sock *sk = skb->sk;

if (atomic_inc_not_zero(&sk->sk_refcnt)) {
atomic_sub(skb->truesize, &sk->sk_wmem_alloc);
WARN_ON(refcount_sub_and_test(skb->truesize, &sk->sk_wmem_alloc));
skb->destructor = sock_efree;
}
} else {
Expand Down Expand Up @@ -1912,7 +1912,7 @@ EXPORT_SYMBOL(sock_i_ino);
struct sk_buff *sock_wmalloc(struct sock *sk, unsigned long size, int force,
gfp_t priority)
{
if (force || atomic_read(&sk->sk_wmem_alloc) < sk->sk_sndbuf) {
if (force || refcount_read(&sk->sk_wmem_alloc) < sk->sk_sndbuf) {
struct sk_buff *skb = alloc_skb(size, priority);
if (skb) {
skb_set_owner_w(skb, sk);
Expand Down Expand Up @@ -1987,7 +1987,7 @@ static long sock_wait_for_wmem(struct sock *sk, long timeo)
break;
set_bit(SOCK_NOSPACE, &sk->sk_socket->flags);
prepare_to_wait(sk_sleep(sk), &wait, TASK_INTERRUPTIBLE);
if (atomic_read(&sk->sk_wmem_alloc) < sk->sk_sndbuf)
if (refcount_read(&sk->sk_wmem_alloc) < sk->sk_sndbuf)
break;
if (sk->sk_shutdown & SEND_SHUTDOWN)
break;
Expand Down Expand Up @@ -2310,7 +2310,7 @@ int __sk_mem_raise_allocated(struct sock *sk, int size, int amt, int kind)
if (sk->sk_type == SOCK_STREAM) {
if (sk->sk_wmem_queued < prot->sysctl_wmem[0])
return 1;
} else if (atomic_read(&sk->sk_wmem_alloc) <
} else if (refcount_read(&sk->sk_wmem_alloc) <
prot->sysctl_wmem[0])
return 1;
}
Expand Down Expand Up @@ -2577,7 +2577,7 @@ static void sock_def_write_space(struct sock *sk)
/* Do not wake up a writer until he can make "significant"
* progress. --DaveM
*/
if ((atomic_read(&sk->sk_wmem_alloc) << 1) <= sk->sk_sndbuf) {
if ((refcount_read(&sk->sk_wmem_alloc) << 1) <= sk->sk_sndbuf) {
wq = rcu_dereference(sk->sk_wq);
if (skwq_has_sleeper(wq))
wake_up_interruptible_sync_poll(&wq->wait, POLLOUT |
Expand Down
2 changes: 1 addition & 1 deletion net/ipv4/af_inet.c
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ void inet_sock_destruct(struct sock *sk)
}

WARN_ON(atomic_read(&sk->sk_rmem_alloc));
WARN_ON(atomic_read(&sk->sk_wmem_alloc));
WARN_ON(refcount_read(&sk->sk_wmem_alloc));
WARN_ON(sk->sk_wmem_queued);
WARN_ON(sk->sk_forward_alloc);

Expand Down
2 changes: 1 addition & 1 deletion net/ipv4/esp4.c
Original file line number Diff line number Diff line change
Expand Up @@ -307,7 +307,7 @@ int esp_output_head(struct xfrm_state *x, struct sk_buff *skb, struct esp_info *
skb->data_len += tailen;
skb->truesize += tailen;
if (sk)
atomic_add(tailen, &sk->sk_wmem_alloc);
refcount_add(tailen, &sk->sk_wmem_alloc);

goto out;
}
Expand Down
Loading

0 comments on commit 14afee4

Please sign in to comment.