diff --git a/net/mptcp/pm_userspace.c b/net/mptcp/pm_userspace.c index a7e82d39e6aa..cffd61ea4489 100644 --- a/net/mptcp/pm_userspace.c +++ b/net/mptcp/pm_userspace.c @@ -64,6 +64,7 @@ static int mptcp_userspace_pm_append_new_local_addr(struct mptcp_sock *msk, 1); list_add_tail_rcu(&e->list, &msk->pm.userspace_pm_local_addr_list); msk->pm.local_addr_used++; + refcount_set(&e->refcnt, 1); ret = e->addr.id; append_err: @@ -96,12 +97,11 @@ static int mptcp_userspace_pm_delete_local_addr(struct mptcp_sock *msk, entry = mptcp_userspace_pm_get_entry(msk, &addr->addr); if (entry) { - /* TODO: a refcount is needed because the entry can - * be used multiple times (e.g. fullmesh mode). - */ - list_del_rcu(&entry->list); - kfree(entry); - msk->pm.local_addr_used--; + if (!refcount_dec_not_one(&entry->refcnt)) { + list_del_rcu(&entry->list); + kfree(entry); + msk->pm.local_addr_used--; + } return 0; } @@ -213,6 +213,11 @@ int mptcp_pm_nl_announce_doit(struct sk_buff *skb, struct genl_info *info) spin_lock_bh(&msk->pm.lock); if (mptcp_pm_alloc_anno_list(msk, &addr_val.addr)) { + struct mptcp_pm_addr_entry *entry; + + entry = mptcp_userspace_pm_get_entry(msk, &addr_val.addr); + if (entry && !refcount_inc_not_zero(&entry->refcnt)) + pr_debug("userspace pm uninitialized entry"); msk->pm.add_addr_signaled++; mptcp_pm_announce_addr(msk, &addr_val.addr, false); mptcp_pm_nl_addr_send_ack(msk); @@ -318,8 +323,10 @@ int mptcp_pm_nl_remove_doit(struct sk_buff *skb, struct genl_info *info) mptcp_pm_remove_addrs(msk, &free_list); - list_del_rcu(&match->list); - kfree(match); + if (!refcount_dec_not_one(&match->refcnt)) { + list_del_rcu(&match->list); + kfree(match); + } release_sock(sk); @@ -403,10 +410,16 @@ int mptcp_pm_nl_subflow_create_doit(struct sk_buff *skb, struct genl_info *info) release_sock(sk); spin_lock_bh(&msk->pm.lock); - if (err) + if (err) { mptcp_userspace_pm_delete_local_addr(msk, &local); - else + } else { + struct mptcp_pm_addr_entry *entry; + + entry = mptcp_userspace_pm_get_entry(msk, &local.addr); + if (entry && !refcount_inc_not_zero(&entry->refcnt)) + pr_debug("userspace pm uninitialized entry"); msk->pm.subflows++; + } spin_unlock_bh(&msk->pm.lock); create_err: diff --git a/net/mptcp/protocol.h b/net/mptcp/protocol.h index 53ab6bfcdbfa..2006ab22992e 100644 --- a/net/mptcp/protocol.h +++ b/net/mptcp/protocol.h @@ -8,6 +8,7 @@ #define __MPTCP_PROTOCOL_H #include +#include #include #include #include @@ -242,6 +243,7 @@ struct mptcp_pm_addr_entry { u8 flags; int ifindex; struct socket *lsk; + refcount_t refcnt; }; struct mptcp_data_frag {