drivers/infiniband/core/iwcm.c (24 changes: 14 additions & 10 deletions)
@@ -367,8 +367,7 @@ EXPORT_SYMBOL(iw_cm_disconnect);
/*
* CM_ID <-- DESTROYING
*
- * Clean up all resources associated with the connection and release
- * the initial reference taken by iw_create_cm_id.
+ * Clean up all resources associated with the connection.
*/
static void destroy_cm_id(struct iw_cm_id *cm_id)
{
@@ -439,19 +438,22 @@ static void destroy_cm_id(struct iw_cm_id *cm_id)
iwpm_remove_mapinfo(&cm_id->local_addr, &cm_id->m_local_addr);
iwpm_remove_mapping(&cm_id->local_addr, RDMA_NL_IWCM);
}
-
- (void)iwcm_deref_id(cm_id_priv);
}

/*
- * This function is only called by the application thread and cannot
- * be called by the event thread. The function will wait for all
- * references to be released on the cm_id and then kfree the cm_id
- * object.
+ * Destroy cm_id. If the cm_id still has other references, wait for all
+ * references to be released on the cm_id and then release the initial
+ * reference taken by iw_create_cm_id.
*/
void iw_destroy_cm_id(struct iw_cm_id *cm_id)
{
+ struct iwcm_id_private *cm_id_priv;
+
+ cm_id_priv = container_of(cm_id, struct iwcm_id_private, id);
destroy_cm_id(cm_id);
+ if (refcount_read(&cm_id_priv->refcount) > 1)
+ flush_workqueue(iwcm_wq);
+ iwcm_deref_id(cm_id_priv);
}
EXPORT_SYMBOL(iw_destroy_cm_id);
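These hunks move the final reference drop out of destroy_cm_id(): iw_create_cm_id() still takes the initial reference, but it is now released in iw_destroy_cm_id() (after flushing iwcm_wq if other references remain) rather than inside destroy_cm_id(). For reference, a minimal caller sketch of that assumed lifetime; the function and its parameters are hypothetical, and only iw_create_cm_id()/iw_destroy_cm_id() are the real API:

/* Hypothetical caller, not part of this patch: the initial reference taken by
 * iw_create_cm_id() is dropped in iw_destroy_cm_id(), not in destroy_cm_id().
 */
static int example_cm_id_lifetime(struct ib_device *dev,
                                  iw_cm_handler handler, void *ctx)
{
        struct iw_cm_id *cm_id;

        cm_id = iw_create_cm_id(dev, handler, ctx);     /* refcount = 1 */
        if (IS_ERR(cm_id))
                return PTR_ERR(cm_id);

        /* ... iw_cm_listen()/iw_cm_connect(), event handling ... */

        iw_destroy_cm_id(cm_id);  /* tear down state, flush iwcm_wq, drop ref */
        return 0;
}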

@@ -1034,8 +1036,10 @@ static void cm_work_handler(struct work_struct *_work)

if (!test_bit(IWCM_F_DROP_EVENTS, &cm_id_priv->flags)) {
ret = process_event(cm_id_priv, &levent);
- if (ret)
+ if (ret) {
destroy_cm_id(&cm_id_priv->id);
+ WARN_ON_ONCE(iwcm_deref_id(cm_id_priv));
+ }
} else
pr_debug("dropping event %d\n", levent.event);
if (iwcm_deref_id(cm_id_priv))
@@ -1188,7 +1192,7 @@ static int __init iw_cm_init(void)
if (ret)
return ret;

- iwcm_wq = alloc_ordered_workqueue("iw_cm_wq", 0);
+ iwcm_wq = alloc_ordered_workqueue("iw_cm_wq", WQ_MEM_RECLAIM);
if (!iwcm_wq)
goto err_alloc;

include/net/sock.h (21 changes: 15 additions & 6 deletions)
@@ -1033,6 +1033,12 @@ static inline void sk_wmem_queued_add(struct sock *sk, int val)
WRITE_ONCE(sk->sk_wmem_queued, sk->sk_wmem_queued + val);
}

+ static inline void sk_forward_alloc_add(struct sock *sk, int val)
+ {
+ /* Paired with lockless reads of sk->sk_forward_alloc */
+ WRITE_ONCE(sk->sk_forward_alloc, sk->sk_forward_alloc + val);
+ }
+
void sk_stream_write_space(struct sock *sk);

/* OOB backlog add */
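The new sk_forward_alloc_add() helper is the write side of an annotated data-race pattern: updates still happen with the socket appropriately locked, but they go through WRITE_ONCE() so that lockless readers (such as the sk_forward_alloc_get() change below and mptcp_forward_alloc_get() at the end of this diff) can pair with READ_ONCE() without load/store tearing. A minimal sketch of that pairing; both example functions are hypothetical:

/* Hypothetical sketch, not from the patch: one locked writer and one
 * lockless reader of sk->sk_forward_alloc.
 */
static void example_writer(struct sock *sk, int bytes)
{
        lock_sock(sk);
        sk_forward_alloc_add(sk, -bytes);        /* WRITE_ONCE() inside */
        release_sock(sk);
}

static int example_lockless_reader(const struct sock *sk)
{
        return READ_ONCE(sk->sk_forward_alloc);  /* pairs with the helper */
}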
@@ -1273,7 +1279,9 @@ struct proto {
unsigned int inuse_idx;
#endif

+ #if IS_ENABLED(CONFIG_MPTCP)
int (*forward_alloc_get)(const struct sock *sk);
+ #endif

bool (*stream_memory_free)(const struct sock *sk, int wake);
bool (*sock_is_readable)(struct sock *sk);
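Guarding forward_alloc_get with CONFIG_MPTCP removes the member, and the extra test in sk_forward_alloc_get(), from builds without MPTCP; MPTCP (the hook's in-tree user) keeps pointing the member at mptcp_forward_alloc_get(), whose body is updated at the end of this diff. A hypothetical registration sketch, with every identifier other than the struct proto member invented for illustration:

/* Hypothetical protocol, not from this patch: only the relevant member is shown. */
static int example_forward_alloc_get(const struct sock *sk)
{
        return READ_ONCE(sk->sk_forward_alloc);
}

static struct proto example_prot = {
        .name              = "EXAMPLE",
#if IS_ENABLED(CONFIG_MPTCP)
        .forward_alloc_get = example_forward_alloc_get,
#endif
};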
@@ -1363,10 +1371,11 @@ INDIRECT_CALLABLE_DECLARE(bool tcp_stream_memory_free(const struct sock *sk, int

static inline int sk_forward_alloc_get(const struct sock *sk)
{
- if (!sk->sk_prot->forward_alloc_get)
- return sk->sk_forward_alloc;
-
- return sk->sk_prot->forward_alloc_get(sk);
+ #if IS_ENABLED(CONFIG_MPTCP)
+ if (sk->sk_prot->forward_alloc_get)
+ return sk->sk_prot->forward_alloc_get(sk);
+ #endif
+ return READ_ONCE(sk->sk_forward_alloc);
}

static inline bool __sk_stream_memory_free(const struct sock *sk, int wake)
@@ -1665,7 +1674,7 @@ static inline void sk_mem_charge(struct sock *sk, int size)
{
if (!sk_has_account(sk))
return;
- sk->sk_forward_alloc -= size;
+ sk_forward_alloc_add(sk, -size);
}

/* the following macros control memory reclaiming in mptcp_rmem_uncharge()
@@ -1677,7 +1686,7 @@ static inline void sk_mem_uncharge(struct sock *sk, int size)
{
if (!sk_has_account(sk))
return;
- sk->sk_forward_alloc += size;
+ sk_forward_alloc_add(sk, size);
sk_mem_reclaim(sk);
}

net/bpf/test_run.c (6 changes: 2 additions & 4 deletions)
@@ -727,12 +727,10 @@ static void *bpf_test_init(const union bpf_attr *kattr, u32 user_size,
void __user *data_in = u64_to_user_ptr(kattr->test.data_in);
void *data;

- if (size < ETH_HLEN || size > PAGE_SIZE - headroom - tailroom)
+ if (user_size < ETH_HLEN || user_size > PAGE_SIZE - headroom - tailroom)
return ERR_PTR(-EINVAL);

- if (user_size > size)
- return ERR_PTR(-EMSGSIZE);

size = SKB_DATA_ALIGN(size);
data = kzalloc(size + headroom + tailroom, GFP_USER);
if (!data)
return ERR_PTR(-ENOMEM);
net/core/sock.c (8 changes: 4 additions & 4 deletions)
@@ -1021,7 +1021,7 @@ static int sock_reserve_memory(struct sock *sk, int bytes)
mem_cgroup_uncharge_skmem(sk->sk_memcg, pages);
return -ENOMEM;
}
- sk->sk_forward_alloc += pages << PAGE_SHIFT;
+ sk_forward_alloc_add(sk, pages << PAGE_SHIFT);

sk->sk_reserved_mem += pages << PAGE_SHIFT;

@@ -2974,10 +2974,10 @@ int __sk_mem_schedule(struct sock *sk, int size, int kind)
{
int ret, amt = sk_mem_pages(size);

- sk->sk_forward_alloc += amt << PAGE_SHIFT;
+ sk_forward_alloc_add(sk, amt << PAGE_SHIFT);
ret = __sk_mem_raise_allocated(sk, size, amt, kind);
if (!ret)
- sk->sk_forward_alloc -= amt << PAGE_SHIFT;
+ sk_forward_alloc_add(sk, -(amt << PAGE_SHIFT));
return ret;
}
EXPORT_SYMBOL(__sk_mem_schedule);
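The two converted lines keep __sk_mem_schedule()'s existing flow: charge whole pages to sk_forward_alloc up front, then take the charge back if __sk_mem_raise_allocated() refuses it. A small standalone sketch of that page-rounding arithmetic; PAGE_SHIFT = 12 is assumed and sk_mem_pages() is re-implemented locally for illustration:

#include <assert.h>
#include <stdio.h>

#define PAGE_SHIFT 12
#define PAGE_SIZE  (1 << PAGE_SHIFT)

static int sk_mem_pages(int amt)
{
        return (amt + PAGE_SIZE - 1) >> PAGE_SHIFT;  /* round up to whole pages */
}

int main(void)
{
        int forward_alloc = 0;
        int size = 3000;                     /* bytes the socket wants to commit */
        int amt = sk_mem_pages(size);        /* -> 1 page */

        forward_alloc += amt << PAGE_SHIFT;  /* optimistic charge: +4096 */
        assert(forward_alloc == 4096);

        int raise_failed = 1;                /* pretend the global limit was hit */
        if (raise_failed)
                forward_alloc -= amt << PAGE_SHIFT;  /* revert, as in the patch */
        assert(forward_alloc == 0);

        printf("amt=%d page(s), forward_alloc=%d\n", amt, forward_alloc);
        return 0;
}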
@@ -3010,7 +3010,7 @@ EXPORT_SYMBOL(__sk_mem_reduce_allocated);
void __sk_mem_reclaim(struct sock *sk, int amount)
{
amount >>= PAGE_SHIFT;
- sk->sk_forward_alloc -= amount << PAGE_SHIFT;
+ sk_forward_alloc_add(sk, -(amount << PAGE_SHIFT));
__sk_mem_reduce_allocated(sk, amount);
}
EXPORT_SYMBOL(__sk_mem_reclaim);
net/ipv4/tcp_output.c (2 changes: 1 addition & 1 deletion)
@@ -3386,7 +3386,7 @@ void sk_forced_mem_schedule(struct sock *sk, int size)
if (delta <= 0)
return;
amt = sk_mem_pages(delta);
- sk->sk_forward_alloc += amt << PAGE_SHIFT;
+ sk_forward_alloc_add(sk, amt << PAGE_SHIFT);
sk_memory_allocated_add(sk, amt);

if (mem_cgroup_sockets_enabled && sk->sk_memcg)
net/ipv4/udp.c (18 changes: 8 additions & 10 deletions)
@@ -1436,12 +1436,12 @@ static bool udp_skb_has_head_state(struct sk_buff *skb)
}

/* fully reclaim rmem/fwd memory allocated for skb */
- static void udp_rmem_release(struct sock *sk, int size, int partial,
- bool rx_queue_lock_held)
+ static void udp_rmem_release(struct sock *sk, unsigned int size,
+ int partial, bool rx_queue_lock_held)
{
struct udp_sock *up = udp_sk(sk);
struct sk_buff_head *sk_queue;
- int amt;
+ unsigned int amt;

if (likely(partial)) {
up->forward_deficit += size;
@@ -1461,10 +1461,8 @@ static void udp_rmem_release(struct sock *sk, int size, int partial,
if (!rx_queue_lock_held)
spin_lock(&sk_queue->lock);

-
- sk->sk_forward_alloc += size;
- amt = (sk->sk_forward_alloc - partial) & ~(PAGE_SIZE - 1);
- sk->sk_forward_alloc -= amt;
+ amt = (size + sk->sk_forward_alloc - partial) & ~(PAGE_SIZE - 1);
+ sk_forward_alloc_add(sk, size - amt);

if (amt)
__sk_mem_reduce_allocated(sk, amt >> PAGE_SHIFT);
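The rewritten pair of lines folds the old three-step update (add size, round the total down to whole pages, subtract the rounded amount) into a single computation, so sk_forward_alloc is now written once through the annotated helper. A small standalone check of the equivalence, using made-up example values:

#include <assert.h>
#include <stdio.h>

#define PAGE_SIZE 4096U

int main(void)
{
        unsigned int fwd = 1000, size = 10000, partial = 1;  /* assumed values */

        /* old: fwd += size; amt = (fwd - partial) & ~(PAGE_SIZE - 1); fwd -= amt; */
        unsigned int old_fwd = fwd + size;
        unsigned int old_amt = (old_fwd - partial) & ~(PAGE_SIZE - 1);
        old_fwd -= old_amt;

        /* new: amt = (size + fwd - partial) & ~(PAGE_SIZE - 1); fwd += size - amt; */
        unsigned int new_amt = (size + fwd - partial) & ~(PAGE_SIZE - 1);
        unsigned int new_fwd = fwd + (size - new_amt);

        assert(old_amt == new_amt && old_fwd == new_fwd);  /* 8192 and 2808 here */
        printf("amt=%u fwd=%u\n", new_amt, new_fwd);
        return 0;
}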
@@ -1570,7 +1568,7 @@ int __udp_enqueue_schedule_skb(struct sock *sk, struct sk_buff *skb)
sk->sk_forward_alloc += delta;
}

- sk->sk_forward_alloc -= size;
+ sk_forward_alloc_add(sk, -size);

/* no need to setup a destructor, we will explicitly release the
* forward allocated memory on dequeue
@@ -1648,7 +1646,7 @@ EXPORT_SYMBOL_GPL(skb_consume_udp);

static struct sk_buff *__first_packet_length(struct sock *sk,
struct sk_buff_head *rcvq,
- int *total)
+ unsigned int *total)
{
struct sk_buff *skb;

@@ -1681,8 +1679,8 @@ static int first_packet_length(struct sock *sk)
{
struct sk_buff_head *rcvq = &udp_sk(sk)->reader_queue;
struct sk_buff_head *sk_queue = &sk->sk_receive_queue;
+ unsigned int total = 0;
struct sk_buff *skb;
- int total = 0;
int res;

spin_lock_bh(&rcvq->lock);
net/mptcp/protocol.c (6 changes: 3 additions & 3 deletions)
@@ -1794,7 +1794,7 @@ static int mptcp_sendmsg(struct sock *sk, struct msghdr *msg, size_t len)
}

/* data successfully copied into the write queue */
- sk->sk_forward_alloc -= total_ts;
+ sk_forward_alloc_add(sk, -total_ts);
copied += psize;
dfrag->data_len += psize;
frag_truesize += psize;
@@ -3198,7 +3198,7 @@ void mptcp_destroy_common(struct mptcp_sock *msk, unsigned int flags)
/* move all the rx fwd alloc into the sk_mem_reclaim_final in
* inet_sock_destruct() will dispose it
*/
- sk->sk_forward_alloc += msk->rmem_fwd_alloc;
+ sk_forward_alloc_add(sk, msk->rmem_fwd_alloc);
msk->rmem_fwd_alloc = 0;
mptcp_token_destroy(msk);
mptcp_pm_free_anno_list(msk);
@@ -3479,7 +3479,7 @@ static void mptcp_shutdown(struct sock *sk, int how)

static int mptcp_forward_alloc_get(const struct sock *sk)
{
- return sk->sk_forward_alloc + mptcp_sk(sk)->rmem_fwd_alloc;
+ return READ_ONCE(sk->sk_forward_alloc) + mptcp_sk(sk)->rmem_fwd_alloc;
}

static int mptcp_ioctl_outq(const struct mptcp_sock *msk, u64 v)