@@ -107,11 +107,17 @@ struct l2tp_net {
107107 /* Lock for write access to l2tp_tunnel_idr */
108108 spinlock_t l2tp_tunnel_idr_lock ;
109109 struct idr l2tp_tunnel_idr ;
110- struct hlist_head l2tp_session_hlist [L2TP_HASH_SIZE_2 ];
111- /* Lock for write access to l2tp_session_hlist */
112- spinlock_t l2tp_session_hlist_lock ;
110+ /* Lock for write access to l2tp_v3_session_idr/htable */
111+ spinlock_t l2tp_session_idr_lock ;
112+ struct idr l2tp_v3_session_idr ;
113+ struct hlist_head l2tp_v3_session_htable [16 ];
113114};
114115
116+ static inline unsigned long l2tp_v3_session_hashkey (struct sock * sk , u32 session_id )
117+ {
118+ return ((unsigned long )sk ) + session_id ;
119+ }
120+
115121#if IS_ENABLED (CONFIG_IPV6 )
116122static bool l2tp_sk_is_v6 (struct sock * sk )
117123{
@@ -125,17 +131,6 @@ static inline struct l2tp_net *l2tp_pernet(const struct net *net)
125131 return net_generic (net , l2tp_net_id );
126132}
127133
128- /* Session hash global list for L2TPv3.
129- * The session_id SHOULD be random according to RFC3931, but several
130- * L2TP implementations use incrementing session_ids. So we do a real
131- * hash on the session_id, rather than a simple bitmask.
132- */
133- static inline struct hlist_head *
134- l2tp_session_id_hash_2 (struct l2tp_net * pn , u32 session_id )
135- {
136- return & pn -> l2tp_session_hlist [hash_32 (session_id , L2TP_HASH_BITS_2 )];
137- }
138-
139134/* Session hash list.
140135 * The session_id SHOULD be random according to RFC2661, but several
141136 * L2TP implementations (Cisco and Microsoft) use incrementing
@@ -262,26 +257,40 @@ struct l2tp_session *l2tp_tunnel_get_session(struct l2tp_tunnel *tunnel,
262257}
263258EXPORT_SYMBOL_GPL (l2tp_tunnel_get_session );
264259
265- struct l2tp_session * l2tp_session_get (const struct net * net , u32 session_id )
260+ struct l2tp_session * l2tp_v3_session_get (const struct net * net , struct sock * sk , u32 session_id )
266261{
267- struct hlist_head * session_list ;
262+ const struct l2tp_net * pn = l2tp_pernet ( net ) ;
268263 struct l2tp_session * session ;
269264
270- session_list = l2tp_session_id_hash_2 (l2tp_pernet (net ), session_id );
271-
272265 rcu_read_lock_bh ();
273- hlist_for_each_entry_rcu (session , session_list , global_hlist )
274- if (session -> session_id == session_id ) {
275- l2tp_session_inc_refcount (session );
276- rcu_read_unlock_bh ();
266+ session = idr_find (& pn -> l2tp_v3_session_idr , session_id );
267+ if (session && !hash_hashed (& session -> hlist ) &&
268+ refcount_inc_not_zero (& session -> ref_count )) {
269+ rcu_read_unlock_bh ();
270+ return session ;
271+ }
277272
278- return session ;
273+ /* If we get here and session is non-NULL, the session_id
274+ * collides with one in another tunnel. If sk is non-NULL,
275+ * find the session matching sk.
276+ */
277+ if (session && sk ) {
278+ unsigned long key = l2tp_v3_session_hashkey (sk , session -> session_id );
279+
280+ hash_for_each_possible_rcu (pn -> l2tp_v3_session_htable , session ,
281+ hlist , key ) {
282+ if (session -> tunnel -> sock == sk &&
283+ refcount_inc_not_zero (& session -> ref_count )) {
284+ rcu_read_unlock_bh ();
285+ return session ;
286+ }
279287 }
288+ }
280289 rcu_read_unlock_bh ();
281290
282291 return NULL ;
283292}
284- EXPORT_SYMBOL_GPL (l2tp_session_get );
293+ EXPORT_SYMBOL_GPL (l2tp_v3_session_get );
285294
286295struct l2tp_session * l2tp_session_get_nth (struct l2tp_tunnel * tunnel , int nth )
287296{
@@ -313,12 +322,12 @@ struct l2tp_session *l2tp_session_get_by_ifname(const struct net *net,
313322 const char * ifname )
314323{
315324 struct l2tp_net * pn = l2tp_pernet (net );
316- int hash ;
325+ unsigned long session_id , tmp ;
317326 struct l2tp_session * session ;
318327
319328 rcu_read_lock_bh ();
320- for ( hash = 0 ; hash < L2TP_HASH_SIZE_2 ; hash ++ ) {
321- hlist_for_each_entry_rcu (session , & pn -> l2tp_session_hlist [ hash ], global_hlist ) {
329+ idr_for_each_entry_ul ( & pn -> l2tp_v3_session_idr , session , tmp , session_id ) {
330+ if (session ) {
322331 if (!strcmp (session -> ifname , ifname )) {
323332 l2tp_session_inc_refcount (session );
324333 rcu_read_unlock_bh ();
@@ -334,13 +343,106 @@ struct l2tp_session *l2tp_session_get_by_ifname(const struct net *net,
334343}
335344EXPORT_SYMBOL_GPL (l2tp_session_get_by_ifname );
336345
346+ static void l2tp_session_coll_list_add (struct l2tp_session_coll_list * clist ,
347+ struct l2tp_session * session )
348+ {
349+ l2tp_session_inc_refcount (session );
350+ WARN_ON_ONCE (session -> coll_list );
351+ session -> coll_list = clist ;
352+ spin_lock (& clist -> lock );
353+ list_add (& session -> clist , & clist -> list );
354+ spin_unlock (& clist -> lock );
355+ }
356+
357+ static int l2tp_session_collision_add (struct l2tp_net * pn ,
358+ struct l2tp_session * session1 ,
359+ struct l2tp_session * session2 )
360+ {
361+ struct l2tp_session_coll_list * clist ;
362+
363+ lockdep_assert_held (& pn -> l2tp_session_idr_lock );
364+
365+ if (!session2 )
366+ return - EEXIST ;
367+
368+ /* If existing session is in IP-encap tunnel, refuse new session */
369+ if (session2 -> tunnel -> encap == L2TP_ENCAPTYPE_IP )
370+ return - EEXIST ;
371+
372+ clist = session2 -> coll_list ;
373+ if (!clist ) {
374+ /* First collision. Allocate list to manage the collided sessions
375+ * and add the existing session to the list.
376+ */
377+ clist = kmalloc (sizeof (* clist ), GFP_ATOMIC );
378+ if (!clist )
379+ return - ENOMEM ;
380+
381+ spin_lock_init (& clist -> lock );
382+ INIT_LIST_HEAD (& clist -> list );
383+ refcount_set (& clist -> ref_count , 1 );
384+ l2tp_session_coll_list_add (clist , session2 );
385+ }
386+
387+ /* If existing session isn't already in the session hlist, add it. */
388+ if (!hash_hashed (& session2 -> hlist ))
389+ hash_add (pn -> l2tp_v3_session_htable , & session2 -> hlist ,
390+ session2 -> hlist_key );
391+
392+ /* Add new session to the hlist and collision list */
393+ hash_add (pn -> l2tp_v3_session_htable , & session1 -> hlist ,
394+ session1 -> hlist_key );
395+ refcount_inc (& clist -> ref_count );
396+ l2tp_session_coll_list_add (clist , session1 );
397+
398+ return 0 ;
399+ }
400+
401+ static void l2tp_session_collision_del (struct l2tp_net * pn ,
402+ struct l2tp_session * session )
403+ {
404+ struct l2tp_session_coll_list * clist = session -> coll_list ;
405+ unsigned long session_key = session -> session_id ;
406+ struct l2tp_session * session2 ;
407+
408+ lockdep_assert_held (& pn -> l2tp_session_idr_lock );
409+
410+ hash_del (& session -> hlist );
411+
412+ if (clist ) {
413+ /* Remove session from its collision list. If there
414+ * are other sessions with the same ID, replace this
415+ * session's IDR entry with that session, otherwise
416+ * remove the IDR entry. If this is the last session,
417+ * the collision list data is freed.
418+ */
419+ spin_lock (& clist -> lock );
420+ list_del_init (& session -> clist );
421+ session2 = list_first_entry_or_null (& clist -> list , struct l2tp_session , clist );
422+ if (session2 ) {
423+ void * old = idr_replace (& pn -> l2tp_v3_session_idr , session2 , session_key );
424+
425+ WARN_ON_ONCE (IS_ERR_VALUE (old ));
426+ } else {
427+ void * removed = idr_remove (& pn -> l2tp_v3_session_idr , session_key );
428+
429+ WARN_ON_ONCE (removed != session );
430+ }
431+ session -> coll_list = NULL ;
432+ spin_unlock (& clist -> lock );
433+ if (refcount_dec_and_test (& clist -> ref_count ))
434+ kfree (clist );
435+ l2tp_session_dec_refcount (session );
436+ }
437+ }
438+
337439int l2tp_session_register (struct l2tp_session * session ,
338440 struct l2tp_tunnel * tunnel )
339441{
442+ struct l2tp_net * pn = l2tp_pernet (tunnel -> l2tp_net );
340443 struct l2tp_session * session_walk ;
341- struct hlist_head * g_head ;
342444 struct hlist_head * head ;
343- struct l2tp_net * pn ;
445+ u32 session_key ;
344446 int err ;
345447
346448 head = l2tp_session_id_hash (tunnel , session -> session_id );
@@ -358,39 +460,45 @@ int l2tp_session_register(struct l2tp_session *session,
358460 }
359461
360462 if (tunnel -> version == L2TP_HDR_VER_3 ) {
361- pn = l2tp_pernet (tunnel -> l2tp_net );
362- g_head = l2tp_session_id_hash_2 (pn , session -> session_id );
363-
364- spin_lock_bh (& pn -> l2tp_session_hlist_lock );
365-
463+ session_key = session -> session_id ;
464+ spin_lock_bh (& pn -> l2tp_session_idr_lock );
465+ err = idr_alloc_u32 (& pn -> l2tp_v3_session_idr , NULL ,
466+ & session_key , session_key , GFP_ATOMIC );
366467 /* IP encap expects session IDs to be globally unique, while
367- * UDP encap doesn't.
468+ * UDP encap doesn't. This isn't per the RFC, which says that
469+ * sessions are identified only by the session ID, but is to
470+ * support existing userspace which depends on it.
368471 */
369- hlist_for_each_entry (session_walk , g_head , global_hlist )
370- if (session_walk -> session_id == session -> session_id &&
371- (session_walk -> tunnel -> encap == L2TP_ENCAPTYPE_IP ||
372- tunnel -> encap == L2TP_ENCAPTYPE_IP )) {
373- err = - EEXIST ;
374- goto err_tlock_pnlock ;
375- }
472+ if (err == - ENOSPC && tunnel -> encap == L2TP_ENCAPTYPE_UDP ) {
473+ struct l2tp_session * session2 ;
376474
377- l2tp_tunnel_inc_refcount (tunnel );
378- hlist_add_head_rcu (& session -> global_hlist , g_head );
379-
380- spin_unlock_bh (& pn -> l2tp_session_hlist_lock );
381- } else {
382- l2tp_tunnel_inc_refcount (tunnel );
475+ session2 = idr_find (& pn -> l2tp_v3_session_idr ,
476+ session_key );
477+ err = l2tp_session_collision_add (pn , session , session2 );
478+ }
479+ spin_unlock_bh (& pn -> l2tp_session_idr_lock );
480+ if (err == - ENOSPC )
481+ err = - EEXIST ;
383482 }
384483
484+ if (err )
485+ goto err_tlock ;
486+
487+ l2tp_tunnel_inc_refcount (tunnel );
488+
385489 hlist_add_head_rcu (& session -> hlist , head );
386490 spin_unlock_bh (& tunnel -> hlist_lock );
387491
492+ if (tunnel -> version == L2TP_HDR_VER_3 ) {
493+ spin_lock_bh (& pn -> l2tp_session_idr_lock );
494+ idr_replace (& pn -> l2tp_v3_session_idr , session , session_key );
495+ spin_unlock_bh (& pn -> l2tp_session_idr_lock );
496+ }
497+
388498 trace_register_session (session );
389499
390500 return 0 ;
391501
392- err_tlock_pnlock :
393- spin_unlock_bh (& pn -> l2tp_session_hlist_lock );
394502err_tlock :
395503 spin_unlock_bh (& tunnel -> hlist_lock );
396504
@@ -1218,13 +1326,19 @@ static void l2tp_session_unhash(struct l2tp_session *session)
12181326 hlist_del_init_rcu (& session -> hlist );
12191327 spin_unlock_bh (& tunnel -> hlist_lock );
12201328
1221- /* For L2TPv3 we have a per-net hash : remove from there, too */
1222- if (tunnel -> version != L2TP_HDR_VER_2 ) {
1329+ /* For L2TPv3 we have a per-net IDR : remove from there, too */
1330+ if (tunnel -> version == L2TP_HDR_VER_3 ) {
12231331 struct l2tp_net * pn = l2tp_pernet (tunnel -> l2tp_net );
1224-
1225- spin_lock_bh (& pn -> l2tp_session_hlist_lock );
1226- hlist_del_init_rcu (& session -> global_hlist );
1227- spin_unlock_bh (& pn -> l2tp_session_hlist_lock );
1332+ struct l2tp_session * removed = session ;
1333+
1334+ spin_lock_bh (& pn -> l2tp_session_idr_lock );
1335+ if (hash_hashed (& session -> hlist ))
1336+ l2tp_session_collision_del (pn , session );
1337+ else
1338+ removed = idr_remove (& pn -> l2tp_v3_session_idr ,
1339+ session -> session_id );
1340+ WARN_ON_ONCE (removed && removed != session );
1341+ spin_unlock_bh (& pn -> l2tp_session_idr_lock );
12281342 }
12291343
12301344 synchronize_rcu ();
@@ -1649,8 +1763,9 @@ struct l2tp_session *l2tp_session_create(int priv_size, struct l2tp_tunnel *tunn
16491763
16501764 skb_queue_head_init (& session -> reorder_q );
16511765
1766+ session -> hlist_key = l2tp_v3_session_hashkey (tunnel -> sock , session -> session_id );
16521767 INIT_HLIST_NODE (& session -> hlist );
1653- INIT_HLIST_NODE (& session -> global_hlist );
1768+ INIT_LIST_HEAD (& session -> clist );
16541769
16551770 if (cfg ) {
16561771 session -> pwtype = cfg -> pw_type ;
@@ -1683,15 +1798,12 @@ EXPORT_SYMBOL_GPL(l2tp_session_create);
16831798static __net_init int l2tp_init_net (struct net * net )
16841799{
16851800 struct l2tp_net * pn = net_generic (net , l2tp_net_id );
1686- int hash ;
16871801
16881802 idr_init (& pn -> l2tp_tunnel_idr );
16891803 spin_lock_init (& pn -> l2tp_tunnel_idr_lock );
16901804
1691- for (hash = 0 ; hash < L2TP_HASH_SIZE_2 ; hash ++ )
1692- INIT_HLIST_HEAD (& pn -> l2tp_session_hlist [hash ]);
1693-
1694- spin_lock_init (& pn -> l2tp_session_hlist_lock );
1805+ idr_init (& pn -> l2tp_v3_session_idr );
1806+ spin_lock_init (& pn -> l2tp_session_idr_lock );
16951807
16961808 return 0 ;
16971809}
@@ -1701,7 +1813,6 @@ static __net_exit void l2tp_exit_net(struct net *net)
17011813 struct l2tp_net * pn = l2tp_pernet (net );
17021814 struct l2tp_tunnel * tunnel = NULL ;
17031815 unsigned long tunnel_id , tmp ;
1704- int hash ;
17051816
17061817 rcu_read_lock_bh ();
17071818 idr_for_each_entry_ul (& pn -> l2tp_tunnel_idr , tunnel , tmp , tunnel_id ) {
@@ -1714,8 +1825,7 @@ static __net_exit void l2tp_exit_net(struct net *net)
17141825 flush_workqueue (l2tp_wq );
17151826 rcu_barrier ();
17161827
1717- for (hash = 0 ; hash < L2TP_HASH_SIZE_2 ; hash ++ )
1718- WARN_ON_ONCE (!hlist_empty (& pn -> l2tp_session_hlist [hash ]));
1828+ idr_destroy (& pn -> l2tp_v3_session_idr );
17191829 idr_destroy (& pn -> l2tp_tunnel_idr );
17201830}
17211831
0 commit comments