@@ -361,25 +361,15 @@ static void smc_destruct(struct sock *sk)
361361 return ;
362362}
363363
364- static struct sock * smc_sock_alloc (struct net * net , struct socket * sock ,
365- int protocol )
364+ void smc_sk_init (struct net * net , struct sock * sk , int protocol )
366365{
367- struct smc_sock * smc ;
368- struct proto * prot ;
369- struct sock * sk ;
370-
371- prot = (protocol == SMCPROTO_SMC6 ) ? & smc_proto6 : & smc_proto ;
372- sk = sk_alloc (net , PF_SMC , GFP_KERNEL , prot , 0 );
373- if (!sk )
374- return NULL ;
366+ struct smc_sock * smc = smc_sk (sk );
375367
376- sock_init_data (sock , sk ); /* sets sk_refcnt to 1 */
377368 sk -> sk_state = SMC_INIT ;
378369 sk -> sk_destruct = smc_destruct ;
379370 sk -> sk_protocol = protocol ;
380371 WRITE_ONCE (sk -> sk_sndbuf , 2 * READ_ONCE (net -> smc .sysctl_wmem ));
381372 WRITE_ONCE (sk -> sk_rcvbuf , 2 * READ_ONCE (net -> smc .sysctl_rmem ));
382- smc = smc_sk (sk );
383373 INIT_WORK (& smc -> tcp_listen_work , smc_tcp_listen_work );
384374 INIT_WORK (& smc -> connect_work , smc_connect_work );
385375 INIT_DELAYED_WORK (& smc -> conn .tx_work , smc_tx_work );
@@ -389,6 +379,24 @@ static struct sock *smc_sock_alloc(struct net *net, struct socket *sock,
389379 sk -> sk_prot -> hash (sk );
390380 mutex_init (& smc -> clcsock_release_lock );
391381 smc_init_saved_callbacks (smc );
382+ smc -> limit_smc_hs = net -> smc .limit_smc_hs ;
383+ smc -> use_fallback = false; /* assume rdma capability first */
384+ smc -> fallback_rsn = 0 ;
385+ }
386+
387+ static struct sock * smc_sock_alloc (struct net * net , struct socket * sock ,
388+ int protocol )
389+ {
390+ struct proto * prot ;
391+ struct sock * sk ;
392+
393+ prot = (protocol == SMCPROTO_SMC6 ) ? & smc_proto6 : & smc_proto ;
394+ sk = sk_alloc (net , PF_SMC , GFP_KERNEL , prot , 0 );
395+ if (!sk )
396+ return NULL ;
397+
398+ sock_init_data (sock , sk ); /* sets sk_refcnt to 1 */
399+ smc_sk_init (net , sk , protocol );
392400
393401 return sk ;
394402}
@@ -3303,6 +3311,31 @@ static const struct proto_ops smc_sock_ops = {
33033311 .splice_read = smc_splice_read ,
33043312};
33053313
3314+ int smc_create_clcsk (struct net * net , struct sock * sk , int family )
3315+ {
3316+ struct smc_sock * smc = smc_sk (sk );
3317+ int rc ;
3318+
3319+ rc = sock_create_kern (net , family , SOCK_STREAM , IPPROTO_TCP ,
3320+ & smc -> clcsock );
3321+ if (rc ) {
3322+ sk_common_release (sk );
3323+ return rc ;
3324+ }
3325+
3326+ /* smc_clcsock_release() does not wait smc->clcsock->sk's
3327+ * destruction; its sk_state might not be TCP_CLOSE after
3328+ * smc->sk is close()d, and TCP timers can be fired later,
3329+ * which need net ref.
3330+ */
3331+ sk = smc -> clcsock -> sk ;
3332+ __netns_tracker_free (net , & sk -> ns_tracker , false);
3333+ sk -> sk_net_refcnt = 1 ;
3334+ get_net_track (net , & sk -> ns_tracker , GFP_KERNEL );
3335+ sock_inuse_add (net , 1 );
3336+ return 0 ;
3337+ }
3338+
33063339static int __smc_create (struct net * net , struct socket * sock , int protocol ,
33073340 int kern , struct socket * clcsock )
33083341{
@@ -3328,35 +3361,12 @@ static int __smc_create(struct net *net, struct socket *sock, int protocol,
33283361
33293362 /* create internal TCP socket for CLC handshake and fallback */
33303363 smc = smc_sk (sk );
3331- smc -> use_fallback = false; /* assume rdma capability first */
3332- smc -> fallback_rsn = 0 ;
3333-
3334- /* default behavior from limit_smc_hs in every net namespace */
3335- smc -> limit_smc_hs = net -> smc .limit_smc_hs ;
33363364
33373365 rc = 0 ;
3338- if (!clcsock ) {
3339- rc = sock_create_kern (net , family , SOCK_STREAM , IPPROTO_TCP ,
3340- & smc -> clcsock );
3341- if (rc ) {
3342- sk_common_release (sk );
3343- goto out ;
3344- }
3345-
3346- /* smc_clcsock_release() does not wait smc->clcsock->sk's
3347- * destruction; its sk_state might not be TCP_CLOSE after
3348- * smc->sk is close()d, and TCP timers can be fired later,
3349- * which need net ref.
3350- */
3351- sk = smc -> clcsock -> sk ;
3352- __netns_tracker_free (net , & sk -> ns_tracker , false);
3353- sk -> sk_net_refcnt = 1 ;
3354- get_net_track (net , & sk -> ns_tracker , GFP_KERNEL );
3355- sock_inuse_add (net , 1 );
3356- } else {
3366+ if (clcsock )
33573367 smc -> clcsock = clcsock ;
3358- }
3359-
3368+ else
3369+ rc = smc_create_clcsk ( net , sk , family );
33603370out :
33613371 return rc ;
33623372}
0 commit comments