@@ -143,11 +143,13 @@ static bool lookup_subflow_by_daddr(const struct list_head *list,
143143 return false;
144144}
145145
146- static struct mptcp_pm_addr_entry *
146+ static bool
147147select_local_address (const struct pm_nl_pernet * pernet ,
148- const struct mptcp_sock * msk )
148+ const struct mptcp_sock * msk ,
149+ struct mptcp_pm_addr_entry * new_entry )
149150{
150- struct mptcp_pm_addr_entry * entry , * ret = NULL ;
151+ struct mptcp_pm_addr_entry * entry ;
152+ bool found = false;
151153
152154 msk_owned_by_me (msk );
153155
@@ -159,17 +161,21 @@ select_local_address(const struct pm_nl_pernet *pernet,
159161 if (!test_bit (entry -> addr .id , msk -> pm .id_avail_bitmap ))
160162 continue ;
161163
162- ret = entry ;
164+ * new_entry = * entry ;
165+ found = true;
163166 break ;
164167 }
165168 rcu_read_unlock ();
166- return ret ;
169+
170+ return found ;
167171}
168172
169- static struct mptcp_pm_addr_entry *
170- select_signal_address (struct pm_nl_pernet * pernet , const struct mptcp_sock * msk )
173+ static bool
174+ select_signal_address (struct pm_nl_pernet * pernet , const struct mptcp_sock * msk ,
175+ struct mptcp_pm_addr_entry * new_entry )
171176{
172- struct mptcp_pm_addr_entry * entry , * ret = NULL ;
177+ struct mptcp_pm_addr_entry * entry ;
178+ bool found = false;
173179
174180 rcu_read_lock ();
175181 /* do not keep any additional per socket state, just signal
@@ -184,11 +190,13 @@ select_signal_address(struct pm_nl_pernet *pernet, const struct mptcp_sock *msk)
184190 if (!(entry -> flags & MPTCP_PM_ADDR_FLAG_SIGNAL ))
185191 continue ;
186192
187- ret = entry ;
193+ * new_entry = * entry ;
194+ found = true;
188195 break ;
189196 }
190197 rcu_read_unlock ();
191- return ret ;
198+
199+ return found ;
192200}
193201
194202unsigned int mptcp_pm_get_add_addr_signal_max (const struct mptcp_sock * msk )
@@ -512,9 +520,10 @@ __lookup_addr(struct pm_nl_pernet *pernet, const struct mptcp_addr_info *info)
512520
513521static void mptcp_pm_create_subflow_or_signal_addr (struct mptcp_sock * msk )
514522{
515- struct mptcp_pm_addr_entry * local , * signal_and_subflow = NULL ;
516523 struct sock * sk = (struct sock * )msk ;
524+ struct mptcp_pm_addr_entry local ;
517525 unsigned int add_addr_signal_max ;
526+ bool signal_and_subflow = false;
518527 unsigned int local_addr_max ;
519528 struct pm_nl_pernet * pernet ;
520529 unsigned int subflows_max ;
@@ -565,23 +574,22 @@ static void mptcp_pm_create_subflow_or_signal_addr(struct mptcp_sock *msk)
565574 if (msk -> pm .addr_signal & BIT (MPTCP_ADD_ADDR_SIGNAL ))
566575 return ;
567576
568- local = select_signal_address (pernet , msk );
569- if (!local )
577+ if (!select_signal_address (pernet , msk , & local ))
570578 goto subflow ;
571579
572580 /* If the alloc fails, we are on memory pressure, not worth
573581 * continuing, and trying to create subflows.
574582 */
575- if (!mptcp_pm_alloc_anno_list (msk , & local -> addr ))
583+ if (!mptcp_pm_alloc_anno_list (msk , & local . addr ))
576584 return ;
577585
578- __clear_bit (local -> addr .id , msk -> pm .id_avail_bitmap );
586+ __clear_bit (local . addr .id , msk -> pm .id_avail_bitmap );
579587 msk -> pm .add_addr_signaled ++ ;
580- mptcp_pm_announce_addr (msk , & local -> addr , false);
588+ mptcp_pm_announce_addr (msk , & local . addr , false);
581589 mptcp_pm_nl_addr_send_ack (msk );
582590
583- if (local -> flags & MPTCP_PM_ADDR_FLAG_SUBFLOW )
584- signal_and_subflow = local ;
591+ if (local . flags & MPTCP_PM_ADDR_FLAG_SUBFLOW )
592+ signal_and_subflow = true ;
585593 }
586594
587595subflow :
@@ -592,26 +600,22 @@ static void mptcp_pm_create_subflow_or_signal_addr(struct mptcp_sock *msk)
592600 bool fullmesh ;
593601 int i , nr ;
594602
595- if (signal_and_subflow ) {
596- local = signal_and_subflow ;
597- signal_and_subflow = NULL ;
598- } else {
599- local = select_local_address (pernet , msk );
600- if (!local )
601- break ;
602- }
603+ if (signal_and_subflow )
604+ signal_and_subflow = false;
605+ else if (!select_local_address (pernet , msk , & local ))
606+ break ;
603607
604- fullmesh = !!(local -> flags & MPTCP_PM_ADDR_FLAG_FULLMESH );
608+ fullmesh = !!(local . flags & MPTCP_PM_ADDR_FLAG_FULLMESH );
605609
606610 msk -> pm .local_addr_used ++ ;
607- __clear_bit (local -> addr .id , msk -> pm .id_avail_bitmap );
608- nr = fill_remote_addresses_vec (msk , & local -> addr , fullmesh , addrs );
611+ __clear_bit (local . addr .id , msk -> pm .id_avail_bitmap );
612+ nr = fill_remote_addresses_vec (msk , & local . addr , fullmesh , addrs );
609613 if (nr == 0 )
610614 continue ;
611615
612616 spin_unlock_bh (& msk -> pm .lock );
613617 for (i = 0 ; i < nr ; i ++ )
614- __mptcp_subflow_connect (sk , & local -> addr , & addrs [i ]);
618+ __mptcp_subflow_connect (sk , & local . addr , & addrs [i ]);
615619 spin_lock_bh (& msk -> pm .lock );
616620 }
617621 mptcp_pm_nl_check_work_pending (msk );
@@ -636,13 +640,16 @@ static unsigned int fill_local_addresses_vec(struct mptcp_sock *msk,
636640{
637641 struct sock * sk = (struct sock * )msk ;
638642 struct mptcp_pm_addr_entry * entry ;
643+ struct mptcp_addr_info mpc_addr ;
639644 struct pm_nl_pernet * pernet ;
640645 unsigned int subflows_max ;
641646 int i = 0 ;
642647
643648 pernet = pm_nl_get_pernet_from_msk (msk );
644649 subflows_max = mptcp_pm_get_subflows_max (msk );
645650
651+ mptcp_local_address ((struct sock_common * )msk , & mpc_addr );
652+
646653 rcu_read_lock ();
647654 list_for_each_entry_rcu (entry , & pernet -> local_addr_list , list ) {
648655 if (!(entry -> flags & MPTCP_PM_ADDR_FLAG_FULLMESH ))
@@ -653,7 +660,13 @@ static unsigned int fill_local_addresses_vec(struct mptcp_sock *msk,
653660
654661 if (msk -> pm .subflows < subflows_max ) {
655662 msk -> pm .subflows ++ ;
656- addrs [i ++ ] = entry -> addr ;
663+ addrs [i ] = entry -> addr ;
664+
665+ /* Special case for ID0: set the correct ID */
666+ if (mptcp_addresses_equal (& entry -> addr , & mpc_addr , entry -> addr .port ))
667+ addrs [i ].id = 0 ;
668+
669+ i ++ ;
657670 }
658671 }
659672 rcu_read_unlock ();
@@ -829,25 +842,27 @@ static void mptcp_pm_nl_rm_addr_or_subflow(struct mptcp_sock *msk,
829842 mptcp_close_ssk (sk , ssk , subflow );
830843 spin_lock_bh (& msk -> pm .lock );
831844
832- removed = true ;
845+ removed |= subflow -> request_join ;
833846 if (rm_type == MPTCP_MIB_RMSUBFLOW )
834847 __MPTCP_INC_STATS (sock_net (sk ), rm_type );
835848 }
836- if (rm_type == MPTCP_MIB_RMSUBFLOW )
837- __set_bit (rm_id ? rm_id : msk -> mpc_endpoint_id , msk -> pm .id_avail_bitmap );
838- else if (rm_type == MPTCP_MIB_RMADDR )
849+
850+ if (rm_type == MPTCP_MIB_RMADDR )
839851 __MPTCP_INC_STATS (sock_net (sk ), rm_type );
852+
840853 if (!removed )
841854 continue ;
842855
843856 if (!mptcp_pm_is_kernel (msk ))
844857 continue ;
845858
846- if (rm_type == MPTCP_MIB_RMADDR ) {
847- msk -> pm .add_addr_accepted -- ;
848- WRITE_ONCE (msk -> pm .accept_addr , true);
849- } else if (rm_type == MPTCP_MIB_RMSUBFLOW ) {
850- msk -> pm .local_addr_used -- ;
859+ if (rm_type == MPTCP_MIB_RMADDR && rm_id &&
860+ !WARN_ON_ONCE (msk -> pm .add_addr_accepted == 0 )) {
861+ /* Note: if the subflow has been closed before, this
862+ * add_addr_accepted counter will not be decremented.
863+ */
864+ if (-- msk -> pm .add_addr_accepted < mptcp_pm_get_add_addr_accept_max (msk ))
865+ WRITE_ONCE (msk -> pm .accept_addr , true);
851866 }
852867 }
853868}
@@ -857,8 +872,8 @@ static void mptcp_pm_nl_rm_addr_received(struct mptcp_sock *msk)
857872 mptcp_pm_nl_rm_addr_or_subflow (msk , & msk -> pm .rm_list_rx , MPTCP_MIB_RMADDR );
858873}
859874
860- void mptcp_pm_nl_rm_subflow_received (struct mptcp_sock * msk ,
861- const struct mptcp_rm_list * rm_list )
875+ static void mptcp_pm_nl_rm_subflow_received (struct mptcp_sock * msk ,
876+ const struct mptcp_rm_list * rm_list )
862877{
863878 mptcp_pm_nl_rm_addr_or_subflow (msk , rm_list , MPTCP_MIB_RMSUBFLOW );
864879}
@@ -1393,6 +1408,10 @@ int mptcp_pm_nl_get_flags_and_ifindex_by_id(struct mptcp_sock *msk, unsigned int
13931408 struct sock * sk = (struct sock * )msk ;
13941409 struct net * net = sock_net (sk );
13951410
1411+ /* No entries with ID 0 */
1412+ if (id == 0 )
1413+ return 0 ;
1414+
13961415 rcu_read_lock ();
13971416 entry = __lookup_addr_by_id (pm_nl_get_pernet (net ), id );
13981417 if (entry ) {
@@ -1431,13 +1450,24 @@ static bool mptcp_pm_remove_anno_addr(struct mptcp_sock *msk,
14311450 ret = remove_anno_list_by_saddr (msk , addr );
14321451 if (ret || force ) {
14331452 spin_lock_bh (& msk -> pm .lock );
1434- msk -> pm .add_addr_signaled -= ret ;
1453+ if (ret ) {
1454+ __set_bit (addr -> id , msk -> pm .id_avail_bitmap );
1455+ msk -> pm .add_addr_signaled -- ;
1456+ }
14351457 mptcp_pm_remove_addr (msk , & list );
14361458 spin_unlock_bh (& msk -> pm .lock );
14371459 }
14381460 return ret ;
14391461}
14401462
1463+ static void __mark_subflow_endp_available (struct mptcp_sock * msk , u8 id )
1464+ {
1465+ /* If it was marked as used, and not ID 0, decrement local_addr_used */
1466+ if (!__test_and_set_bit (id ? : msk -> mpc_endpoint_id , msk -> pm .id_avail_bitmap ) &&
1467+ id && !WARN_ON_ONCE (msk -> pm .local_addr_used == 0 ))
1468+ msk -> pm .local_addr_used -- ;
1469+ }
1470+
14411471static int mptcp_nl_remove_subflow_and_signal_addr (struct net * net ,
14421472 const struct mptcp_pm_addr_entry * entry )
14431473{
@@ -1466,8 +1496,19 @@ static int mptcp_nl_remove_subflow_and_signal_addr(struct net *net,
14661496 remove_subflow = lookup_subflow_by_saddr (& msk -> conn_list , addr );
14671497 mptcp_pm_remove_anno_addr (msk , addr , remove_subflow &&
14681498 !(entry -> flags & MPTCP_PM_ADDR_FLAG_IMPLICIT ));
1469- if (remove_subflow )
1470- mptcp_pm_remove_subflow (msk , & list );
1499+
1500+ if (remove_subflow ) {
1501+ spin_lock_bh (& msk -> pm .lock );
1502+ mptcp_pm_nl_rm_subflow_received (msk , & list );
1503+ spin_unlock_bh (& msk -> pm .lock );
1504+ }
1505+
1506+ if (entry -> flags & MPTCP_PM_ADDR_FLAG_SUBFLOW ) {
1507+ spin_lock_bh (& msk -> pm .lock );
1508+ __mark_subflow_endp_available (msk , list .ids [0 ]);
1509+ spin_unlock_bh (& msk -> pm .lock );
1510+ }
1511+
14711512 release_sock (sk );
14721513
14731514next :
@@ -1502,6 +1543,7 @@ static int mptcp_nl_remove_id_zero_address(struct net *net,
15021543 spin_lock_bh (& msk -> pm .lock );
15031544 mptcp_pm_remove_addr (msk , & list );
15041545 mptcp_pm_nl_rm_subflow_received (msk , & list );
1546+ __mark_subflow_endp_available (msk , 0 );
15051547 spin_unlock_bh (& msk -> pm .lock );
15061548 release_sock (sk );
15071549
@@ -1605,14 +1647,17 @@ static void mptcp_pm_remove_addrs_and_subflows(struct mptcp_sock *msk,
16051647 alist .ids [alist .nr ++ ] = entry -> addr .id ;
16061648 }
16071649
1650+ spin_lock_bh (& msk -> pm .lock );
16081651 if (alist .nr ) {
1609- spin_lock_bh (& msk -> pm .lock );
16101652 msk -> pm .add_addr_signaled -= alist .nr ;
16111653 mptcp_pm_remove_addr (msk , & alist );
1612- spin_unlock_bh (& msk -> pm .lock );
16131654 }
16141655 if (slist .nr )
1615- mptcp_pm_remove_subflow (msk , & slist );
1656+ mptcp_pm_nl_rm_subflow_received (msk , & slist );
1657+ /* Reset counters: maybe some subflows have been removed before */
1658+ bitmap_fill (msk -> pm .id_avail_bitmap , MPTCP_PM_MAX_ADDR_ID + 1 );
1659+ msk -> pm .local_addr_used = 0 ;
1660+ spin_unlock_bh (& msk -> pm .lock );
16161661}
16171662
16181663static void mptcp_nl_remove_addrs_list (struct net * net ,
@@ -1900,6 +1945,7 @@ static void mptcp_pm_nl_fullmesh(struct mptcp_sock *msk,
19001945
19011946 spin_lock_bh (& msk -> pm .lock );
19021947 mptcp_pm_nl_rm_subflow_received (msk , & list );
1948+ __mark_subflow_endp_available (msk , list .ids [0 ]);
19031949 mptcp_pm_create_subflow_or_signal_addr (msk );
19041950 spin_unlock_bh (& msk -> pm .lock );
19051951}
0 commit comments