122122struct pppol2tp_session {
123123 int owner ; /* pid that opened the socket */
124124
125- struct sock * sock ; /* Pointer to the session
125+ struct mutex sk_lock ; /* Protects .sk */
126+ struct sock __rcu * sk ; /* Pointer to the session
126127 * PPPoX socket */
128+ struct sock * __sk ; /* Copy of .sk, for cleanup */
129+ struct rcu_head rcu ; /* For asynchronous release */
127130 struct sock * tunnel_sock ; /* Pointer to the tunnel UDP
128131 * socket */
129132 int flags ; /* accessed by PPPIOCGFLAGS.
@@ -138,6 +141,24 @@ static const struct ppp_channel_ops pppol2tp_chan_ops = {
138141
139142static const struct proto_ops pppol2tp_ops ;
140143
144+ /* Retrieves the pppol2tp socket associated to a session.
145+ * A reference is held on the returned socket, so this function must be paired
146+ * with sock_put().
147+ */
148+ static struct sock * pppol2tp_session_get_sock (struct l2tp_session * session )
149+ {
150+ struct pppol2tp_session * ps = l2tp_session_priv (session );
151+ struct sock * sk ;
152+
153+ rcu_read_lock ();
154+ sk = rcu_dereference (ps -> sk );
155+ if (sk )
156+ sock_hold (sk );
157+ rcu_read_unlock ();
158+
159+ return sk ;
160+ }
161+
141162/* Helpers to obtain tunnel/session contexts from sockets.
142163 */
143164static inline struct l2tp_session * pppol2tp_sock_to_session (struct sock * sk )
@@ -224,7 +245,8 @@ static void pppol2tp_recv(struct l2tp_session *session, struct sk_buff *skb, int
224245 /* If the socket is bound, send it in to PPP's input queue. Otherwise
225246 * queue it on the session socket.
226247 */
227- sk = ps -> sock ;
248+ rcu_read_lock ();
249+ sk = rcu_dereference (ps -> sk );
228250 if (sk == NULL )
229251 goto no_sock ;
230252
@@ -247,30 +269,16 @@ static void pppol2tp_recv(struct l2tp_session *session, struct sk_buff *skb, int
247269 kfree_skb (skb );
248270 }
249271 }
272+ rcu_read_unlock ();
250273
251274 return ;
252275
253276no_sock :
277+ rcu_read_unlock ();
254278 l2tp_info (session , L2TP_MSG_DATA , "%s: no socket\n" , session -> name );
255279 kfree_skb (skb );
256280}
257281
258- static void pppol2tp_session_sock_hold (struct l2tp_session * session )
259- {
260- struct pppol2tp_session * ps = l2tp_session_priv (session );
261-
262- if (ps -> sock )
263- sock_hold (ps -> sock );
264- }
265-
266- static void pppol2tp_session_sock_put (struct l2tp_session * session )
267- {
268- struct pppol2tp_session * ps = l2tp_session_priv (session );
269-
270- if (ps -> sock )
271- sock_put (ps -> sock );
272- }
273-
274282/************************************************************************
275283 * Transmit handling
276284 ***********************************************************************/
@@ -431,14 +439,16 @@ static int pppol2tp_xmit(struct ppp_channel *chan, struct sk_buff *skb)
431439 */
432440static void pppol2tp_session_close (struct l2tp_session * session )
433441{
434- struct pppol2tp_session * ps = l2tp_session_priv (session );
435- struct sock * sk = ps -> sock ;
436- struct socket * sock = sk -> sk_socket ;
442+ struct sock * sk ;
437443
438444 BUG_ON (session -> magic != L2TP_SESSION_MAGIC );
439445
440- if (sock )
441- inet_shutdown (sock , SEND_SHUTDOWN );
446+ sk = pppol2tp_session_get_sock (session );
447+ if (sk ) {
448+ if (sk -> sk_socket )
449+ inet_shutdown (sk -> sk_socket , SEND_SHUTDOWN );
450+ sock_put (sk );
451+ }
442452
443453 /* Don't let the session go away before our socket does */
444454 l2tp_session_inc_refcount (session );
@@ -461,6 +471,14 @@ static void pppol2tp_session_destruct(struct sock *sk)
461471 }
462472}
463473
474+ static void pppol2tp_put_sk (struct rcu_head * head )
475+ {
476+ struct pppol2tp_session * ps ;
477+
478+ ps = container_of (head , typeof (* ps ), rcu );
479+ sock_put (ps -> __sk );
480+ }
481+
464482/* Called when the PPPoX socket (session) is closed.
465483 */
466484static int pppol2tp_release (struct socket * sock )
@@ -486,11 +504,24 @@ static int pppol2tp_release(struct socket *sock)
486504
487505 session = pppol2tp_sock_to_session (sk );
488506
489- /* Purge any queued data */
490507 if (session != NULL ) {
508+ struct pppol2tp_session * ps ;
509+
491510 __l2tp_session_unhash (session );
492511 l2tp_session_queue_purge (session );
493- sock_put (sk );
512+
513+ ps = l2tp_session_priv (session );
514+ mutex_lock (& ps -> sk_lock );
515+ ps -> __sk = rcu_dereference_protected (ps -> sk ,
516+ lockdep_is_held (& ps -> sk_lock ));
517+ RCU_INIT_POINTER (ps -> sk , NULL );
518+ mutex_unlock (& ps -> sk_lock );
519+ call_rcu (& ps -> rcu , pppol2tp_put_sk );
520+
521+ /* Rely on the sock_put() call at the end of the function for
522+ * dropping the reference held by pppol2tp_sock_to_session().
523+ * The last reference will be dropped by pppol2tp_put_sk().
524+ */
494525 }
495526 release_sock (sk );
496527
@@ -557,12 +588,14 @@ static int pppol2tp_create(struct net *net, struct socket *sock, int kern)
557588static void pppol2tp_show (struct seq_file * m , void * arg )
558589{
559590 struct l2tp_session * session = arg ;
560- struct pppol2tp_session * ps = l2tp_session_priv (session );
591+ struct sock * sk ;
592+
593+ sk = pppol2tp_session_get_sock (session );
594+ if (sk ) {
595+ struct pppox_sock * po = pppox_sk (sk );
561596
562- if (ps ) {
563- struct pppox_sock * po = pppox_sk (ps -> sock );
564- if (po )
565- seq_printf (m , " interface %s\n" , ppp_dev_name (& po -> chan ));
597+ seq_printf (m , " interface %s\n" , ppp_dev_name (& po -> chan ));
598+ sock_put (sk );
566599 }
567600}
568601#endif
@@ -693,13 +726,17 @@ static int pppol2tp_connect(struct socket *sock, struct sockaddr *uservaddr,
693726 /* Using a pre-existing session is fine as long as it hasn't
694727 * been connected yet.
695728 */
696- if (ps -> sock ) {
729+ mutex_lock (& ps -> sk_lock );
730+ if (rcu_dereference_protected (ps -> sk ,
731+ lockdep_is_held (& ps -> sk_lock ))) {
732+ mutex_unlock (& ps -> sk_lock );
697733 error = - EEXIST ;
698734 goto end ;
699735 }
700736
701737 /* consistency checks */
702738 if (ps -> tunnel_sock != tunnel -> sock ) {
739+ mutex_unlock (& ps -> sk_lock );
703740 error = - EEXIST ;
704741 goto end ;
705742 }
@@ -716,19 +753,21 @@ static int pppol2tp_connect(struct socket *sock, struct sockaddr *uservaddr,
716753 goto end ;
717754 }
718755
756+ ps = l2tp_session_priv (session );
757+ mutex_init (& ps -> sk_lock );
719758 l2tp_session_inc_refcount (session );
759+
760+ mutex_lock (& ps -> sk_lock );
720761 error = l2tp_session_register (session , tunnel );
721762 if (error < 0 ) {
763+ mutex_unlock (& ps -> sk_lock );
722764 kfree (session );
723765 goto end ;
724766 }
725767 drop_refcnt = true;
726768 }
727769
728- /* Associate session with its PPPoL2TP socket */
729- ps = l2tp_session_priv (session );
730770 ps -> owner = current -> pid ;
731- ps -> sock = sk ;
732771 ps -> tunnel_sock = tunnel -> sock ;
733772
734773 session -> recv_skb = pppol2tp_recv ;
@@ -737,12 +776,6 @@ static int pppol2tp_connect(struct socket *sock, struct sockaddr *uservaddr,
737776 session -> show = pppol2tp_show ;
738777#endif
739778
740- /* We need to know each time a skb is dropped from the reorder
741- * queue.
742- */
743- session -> ref = pppol2tp_session_sock_hold ;
744- session -> deref = pppol2tp_session_sock_put ;
745-
746779 /* If PMTU discovery was enabled, use the MTU that was discovered */
747780 dst = sk_dst_get (tunnel -> sock );
748781 if (dst != NULL ) {
@@ -776,12 +809,17 @@ static int pppol2tp_connect(struct socket *sock, struct sockaddr *uservaddr,
776809 po -> chan .mtu = session -> mtu ;
777810
778811 error = ppp_register_net_channel (sock_net (sk ), & po -> chan );
779- if (error )
812+ if (error ) {
813+ mutex_unlock (& ps -> sk_lock );
780814 goto end ;
815+ }
781816
782817out_no_ppp :
783818 /* This is how we get the session context from the socket. */
784819 sk -> sk_user_data = session ;
820+ rcu_assign_pointer (ps -> sk , sk );
821+ mutex_unlock (& ps -> sk_lock );
822+
785823 sk -> sk_state = PPPOX_CONNECTED ;
786824 l2tp_info (session , L2TP_MSG_CONTROL , "%s: created\n" ,
787825 session -> name );
@@ -827,6 +865,7 @@ static int pppol2tp_session_create(struct net *net, struct l2tp_tunnel *tunnel,
827865 }
828866
829867 ps = l2tp_session_priv (session );
868+ mutex_init (& ps -> sk_lock );
830869 ps -> tunnel_sock = tunnel -> sock ;
831870
832871 error = l2tp_session_register (session , tunnel );
@@ -998,12 +1037,10 @@ static int pppol2tp_session_ioctl(struct l2tp_session *session,
9981037 "%s: pppol2tp_session_ioctl(cmd=%#x, arg=%#lx)\n" ,
9991038 session -> name , cmd , arg );
10001039
1001- sk = ps -> sock ;
1040+ sk = pppol2tp_session_get_sock ( session ) ;
10021041 if (!sk )
10031042 return - EBADR ;
10041043
1005- sock_hold (sk );
1006-
10071044 switch (cmd ) {
10081045 case SIOCGIFMTU :
10091046 err = - ENXIO ;
@@ -1279,7 +1316,6 @@ static int pppol2tp_session_setsockopt(struct sock *sk,
12791316 int optname , int val )
12801317{
12811318 int err = 0 ;
1282- struct pppol2tp_session * ps = l2tp_session_priv (session );
12831319
12841320 switch (optname ) {
12851321 case PPPOL2TP_SO_RECVSEQ :
@@ -1300,8 +1336,8 @@ static int pppol2tp_session_setsockopt(struct sock *sk,
13001336 }
13011337 session -> send_seq = !!val ;
13021338 {
1303- struct sock * ssk = ps -> sock ;
1304- struct pppox_sock * po = pppox_sk ( ssk );
1339+ struct pppox_sock * po = pppox_sk ( sk ) ;
1340+
13051341 po -> chan .hdrlen = val ? PPPOL2TP_L2TP_HDR_SIZE_SEQ :
13061342 PPPOL2TP_L2TP_HDR_SIZE_NOSEQ ;
13071343 }
@@ -1640,8 +1676,9 @@ static void pppol2tp_seq_session_show(struct seq_file *m, void *v)
16401676{
16411677 struct l2tp_session * session = v ;
16421678 struct l2tp_tunnel * tunnel = session -> tunnel ;
1643- struct pppol2tp_session * ps = l2tp_session_priv (session );
1644- struct pppox_sock * po = pppox_sk (ps -> sock );
1679+ unsigned char state ;
1680+ char user_data_ok ;
1681+ struct sock * sk ;
16451682 u32 ip = 0 ;
16461683 u16 port = 0 ;
16471684
@@ -1651,16 +1688,23 @@ static void pppol2tp_seq_session_show(struct seq_file *m, void *v)
16511688 port = ntohs (inet -> inet_sport );
16521689 }
16531690
1691+ sk = pppol2tp_session_get_sock (session );
1692+ if (sk ) {
1693+ state = sk -> sk_state ;
1694+ user_data_ok = (session == sk -> sk_user_data ) ? 'Y' : 'N' ;
1695+ } else {
1696+ state = 0 ;
1697+ user_data_ok = 'N' ;
1698+ }
1699+
16541700 seq_printf (m , " SESSION '%s' %08X/%d %04X/%04X -> "
16551701 "%04X/%04X %d %c\n" ,
16561702 session -> name , ip , port ,
16571703 tunnel -> tunnel_id ,
16581704 session -> session_id ,
16591705 tunnel -> peer_tunnel_id ,
16601706 session -> peer_session_id ,
1661- ps -> sock -> sk_state ,
1662- (session == ps -> sock -> sk_user_data ) ?
1663- 'Y' : 'N' );
1707+ state , user_data_ok );
16641708 seq_printf (m , " %d/%d/%c/%c/%s %08x %u\n" ,
16651709 session -> mtu , session -> mru ,
16661710 session -> recv_seq ? 'R' : '-' ,
@@ -1677,8 +1721,12 @@ static void pppol2tp_seq_session_show(struct seq_file *m, void *v)
16771721 atomic_long_read (& session -> stats .rx_bytes ),
16781722 atomic_long_read (& session -> stats .rx_errors ));
16791723
1680- if (po )
1724+ if (sk ) {
1725+ struct pppox_sock * po = pppox_sk (sk );
1726+
16811727 seq_printf (m , " interface %s\n" , ppp_dev_name (& po -> chan ));
1728+ sock_put (sk );
1729+ }
16821730}
16831731
16841732static int pppol2tp_seq_show (struct seq_file * m , void * v )
0 commit comments