@@ -39,6 +39,7 @@ struct skcipher_ctx {
3939
4040 struct af_alg_completion completion ;
4141
42+ atomic_t inflight ;
4243 unsigned used ;
4344
4445 unsigned int len ;
@@ -49,9 +50,65 @@ struct skcipher_ctx {
4950 struct ablkcipher_request req ;
5051};
5152
53+ struct skcipher_async_rsgl {
54+ struct af_alg_sgl sgl ;
55+ struct list_head list ;
56+ };
57+
58+ struct skcipher_async_req {
59+ struct kiocb * iocb ;
60+ struct skcipher_async_rsgl first_sgl ;
61+ struct list_head list ;
62+ struct scatterlist * tsg ;
63+ char iv [];
64+ };
65+
66+ #define GET_SREQ (areq , ctx ) (struct skcipher_async_req *)((char *)areq + \
67+ crypto_ablkcipher_reqsize(crypto_ablkcipher_reqtfm(&ctx->req)))
68+
69+ #define GET_REQ_SIZE (ctx ) \
70+ crypto_ablkcipher_reqsize(crypto_ablkcipher_reqtfm(&ctx->req))
71+
72+ #define GET_IV_SIZE (ctx ) \
73+ crypto_ablkcipher_ivsize(crypto_ablkcipher_reqtfm(&ctx->req))
74+
5275#define MAX_SGL_ENTS ((4096 - sizeof(struct skcipher_sg_list)) / \
5376 sizeof(struct scatterlist) - 1)
5477
78+ static void skcipher_free_async_sgls (struct skcipher_async_req * sreq )
79+ {
80+ struct skcipher_async_rsgl * rsgl , * tmp ;
81+ struct scatterlist * sgl ;
82+ struct scatterlist * sg ;
83+ int i , n ;
84+
85+ list_for_each_entry_safe (rsgl , tmp , & sreq -> list , list ) {
86+ af_alg_free_sg (& rsgl -> sgl );
87+ if (rsgl != & sreq -> first_sgl )
88+ kfree (rsgl );
89+ }
90+ sgl = sreq -> tsg ;
91+ n = sg_nents (sgl );
92+ for_each_sg (sgl , sg , n , i )
93+ put_page (sg_page (sg ));
94+
95+ kfree (sreq -> tsg );
96+ }
97+
98+ static void skcipher_async_cb (struct crypto_async_request * req , int err )
99+ {
100+ struct sock * sk = req -> data ;
101+ struct alg_sock * ask = alg_sk (sk );
102+ struct skcipher_ctx * ctx = ask -> private ;
103+ struct skcipher_async_req * sreq = GET_SREQ (req , ctx );
104+ struct kiocb * iocb = sreq -> iocb ;
105+
106+ atomic_dec (& ctx -> inflight );
107+ skcipher_free_async_sgls (sreq );
108+ kfree (req );
109+ aio_complete (iocb , err , err );
110+ }
111+
55112static inline int skcipher_sndbuf (struct sock * sk )
56113{
57114 struct alg_sock * ask = alg_sk (sk );
@@ -96,7 +153,7 @@ static int skcipher_alloc_sgl(struct sock *sk)
96153 return 0 ;
97154}
98155
99- static void skcipher_pull_sgl (struct sock * sk , int used )
156+ static void skcipher_pull_sgl (struct sock * sk , int used , int put )
100157{
101158 struct alg_sock * ask = alg_sk (sk );
102159 struct skcipher_ctx * ctx = ask -> private ;
@@ -123,8 +180,8 @@ static void skcipher_pull_sgl(struct sock *sk, int used)
123180
124181 if (sg [i ].length )
125182 return ;
126-
127- put_page (sg_page (sg + i ));
183+ if ( put )
184+ put_page (sg_page (sg + i ));
128185 sg_assign_page (sg + i , NULL );
129186 }
130187
@@ -143,7 +200,7 @@ static void skcipher_free_sgl(struct sock *sk)
143200 struct alg_sock * ask = alg_sk (sk );
144201 struct skcipher_ctx * ctx = ask -> private ;
145202
146- skcipher_pull_sgl (sk , ctx -> used );
203+ skcipher_pull_sgl (sk , ctx -> used , 1 );
147204}
148205
149206static int skcipher_wait_for_wmem (struct sock * sk , unsigned flags )
@@ -424,8 +481,149 @@ static ssize_t skcipher_sendpage(struct socket *sock, struct page *page,
424481 return err ?: size ;
425482}
426483
427- static int skcipher_recvmsg (struct socket * sock , struct msghdr * msg ,
428- size_t ignored , int flags )
484+ static int skcipher_all_sg_nents (struct skcipher_ctx * ctx )
485+ {
486+ struct skcipher_sg_list * sgl ;
487+ struct scatterlist * sg ;
488+ int nents = 0 ;
489+
490+ list_for_each_entry (sgl , & ctx -> tsgl , list ) {
491+ sg = sgl -> sg ;
492+
493+ while (!sg -> length )
494+ sg ++ ;
495+
496+ nents += sg_nents (sg );
497+ }
498+ return nents ;
499+ }
500+
501+ static int skcipher_recvmsg_async (struct socket * sock , struct msghdr * msg ,
502+ int flags )
503+ {
504+ struct sock * sk = sock -> sk ;
505+ struct alg_sock * ask = alg_sk (sk );
506+ struct skcipher_ctx * ctx = ask -> private ;
507+ struct skcipher_sg_list * sgl ;
508+ struct scatterlist * sg ;
509+ struct skcipher_async_req * sreq ;
510+ struct ablkcipher_request * req ;
511+ struct skcipher_async_rsgl * last_rsgl = NULL ;
512+ unsigned int len = 0 , tx_nents = skcipher_all_sg_nents (ctx );
513+ unsigned int reqlen = sizeof (struct skcipher_async_req ) +
514+ GET_REQ_SIZE (ctx ) + GET_IV_SIZE (ctx );
515+ int i = 0 ;
516+ int err = - ENOMEM ;
517+
518+ lock_sock (sk );
519+ req = kmalloc (reqlen , GFP_KERNEL );
520+ if (unlikely (!req ))
521+ goto unlock ;
522+
523+ sreq = GET_SREQ (req , ctx );
524+ sreq -> iocb = msg -> msg_iocb ;
525+ memset (& sreq -> first_sgl , '\0' , sizeof (struct skcipher_async_rsgl ));
526+ INIT_LIST_HEAD (& sreq -> list );
527+ sreq -> tsg = kcalloc (tx_nents , sizeof (* sg ), GFP_KERNEL );
528+ if (unlikely (!sreq -> tsg )) {
529+ kfree (req );
530+ goto unlock ;
531+ }
532+ sg_init_table (sreq -> tsg , tx_nents );
533+ memcpy (sreq -> iv , ctx -> iv , GET_IV_SIZE (ctx ));
534+ ablkcipher_request_set_tfm (req , crypto_ablkcipher_reqtfm (& ctx -> req ));
535+ ablkcipher_request_set_callback (req , CRYPTO_TFM_REQ_MAY_BACKLOG ,
536+ skcipher_async_cb , sk );
537+
538+ while (iov_iter_count (& msg -> msg_iter )) {
539+ struct skcipher_async_rsgl * rsgl ;
540+ unsigned long used ;
541+
542+ if (!ctx -> used ) {
543+ err = skcipher_wait_for_data (sk , flags );
544+ if (err )
545+ goto free ;
546+ }
547+ sgl = list_first_entry (& ctx -> tsgl ,
548+ struct skcipher_sg_list , list );
549+ sg = sgl -> sg ;
550+
551+ while (!sg -> length )
552+ sg ++ ;
553+
554+ used = min_t (unsigned long , ctx -> used ,
555+ iov_iter_count (& msg -> msg_iter ));
556+ used = min_t (unsigned long , used , sg -> length );
557+
558+ if (i == tx_nents ) {
559+ struct scatterlist * tmp ;
560+ int x ;
561+ /* Ran out of tx slots in async request
562+ * need to expand */
563+ tmp = kcalloc (tx_nents * 2 , sizeof (* tmp ),
564+ GFP_KERNEL );
565+ if (!tmp )
566+ goto free ;
567+
568+ sg_init_table (tmp , tx_nents * 2 );
569+ for (x = 0 ; x < tx_nents ; x ++ )
570+ sg_set_page (& tmp [x ], sg_page (& sreq -> tsg [x ]),
571+ sreq -> tsg [x ].length ,
572+ sreq -> tsg [x ].offset );
573+ kfree (sreq -> tsg );
574+ sreq -> tsg = tmp ;
575+ tx_nents *= 2 ;
576+ }
577+ /* Need to take over the tx sgl from ctx
578+ * to the asynch req - these sgls will be freed later */
579+ sg_set_page (sreq -> tsg + i ++ , sg_page (sg ), sg -> length ,
580+ sg -> offset );
581+
582+ if (list_empty (& sreq -> list )) {
583+ rsgl = & sreq -> first_sgl ;
584+ list_add_tail (& rsgl -> list , & sreq -> list );
585+ } else {
586+ rsgl = kzalloc (sizeof (* rsgl ), GFP_KERNEL );
587+ if (!rsgl ) {
588+ err = - ENOMEM ;
589+ goto free ;
590+ }
591+ list_add_tail (& rsgl -> list , & sreq -> list );
592+ }
593+
594+ used = af_alg_make_sg (& rsgl -> sgl , & msg -> msg_iter , used );
595+ err = used ;
596+ if (used < 0 )
597+ goto free ;
598+ if (last_rsgl )
599+ af_alg_link_sg (& last_rsgl -> sgl , & rsgl -> sgl );
600+
601+ last_rsgl = rsgl ;
602+ len += used ;
603+ skcipher_pull_sgl (sk , used , 0 );
604+ iov_iter_advance (& msg -> msg_iter , used );
605+ }
606+
607+ ablkcipher_request_set_crypt (req , sreq -> tsg , sreq -> first_sgl .sgl .sg ,
608+ len , sreq -> iv );
609+ err = ctx -> enc ? crypto_ablkcipher_encrypt (req ) :
610+ crypto_ablkcipher_decrypt (req );
611+ if (err == - EINPROGRESS ) {
612+ atomic_inc (& ctx -> inflight );
613+ err = - EIOCBQUEUED ;
614+ goto unlock ;
615+ }
616+ free :
617+ skcipher_free_async_sgls (sreq );
618+ kfree (req );
619+ unlock :
620+ skcipher_wmem_wakeup (sk );
621+ release_sock (sk );
622+ return err ;
623+ }
624+
625+ static int skcipher_recvmsg_sync (struct socket * sock , struct msghdr * msg ,
626+ int flags )
429627{
430628 struct sock * sk = sock -> sk ;
431629 struct alg_sock * ask = alg_sk (sk );
@@ -484,7 +682,7 @@ static int skcipher_recvmsg(struct socket *sock, struct msghdr *msg,
484682 goto unlock ;
485683
486684 copied += used ;
487- skcipher_pull_sgl (sk , used );
685+ skcipher_pull_sgl (sk , used , 1 );
488686 iov_iter_advance (& msg -> msg_iter , used );
489687 }
490688
@@ -497,6 +695,13 @@ static int skcipher_recvmsg(struct socket *sock, struct msghdr *msg,
497695 return copied ?: err ;
498696}
499697
698+ static int skcipher_recvmsg (struct socket * sock , struct msghdr * msg ,
699+ size_t ignored , int flags )
700+ {
701+ return (msg -> msg_iocb && !is_sync_kiocb (msg -> msg_iocb )) ?
702+ skcipher_recvmsg_async (sock , msg , flags ) :
703+ skcipher_recvmsg_sync (sock , msg , flags );
704+ }
500705
501706static unsigned int skcipher_poll (struct file * file , struct socket * sock ,
502707 poll_table * wait )
@@ -555,12 +760,25 @@ static int skcipher_setkey(void *private, const u8 *key, unsigned int keylen)
555760 return crypto_ablkcipher_setkey (private , key , keylen );
556761}
557762
763+ static void skcipher_wait (struct sock * sk )
764+ {
765+ struct alg_sock * ask = alg_sk (sk );
766+ struct skcipher_ctx * ctx = ask -> private ;
767+ int ctr = 0 ;
768+
769+ while (atomic_read (& ctx -> inflight ) && ctr ++ < 100 )
770+ msleep (100 );
771+ }
772+
558773static void skcipher_sock_destruct (struct sock * sk )
559774{
560775 struct alg_sock * ask = alg_sk (sk );
561776 struct skcipher_ctx * ctx = ask -> private ;
562777 struct crypto_ablkcipher * tfm = crypto_ablkcipher_reqtfm (& ctx -> req );
563778
779+ if (atomic_read (& ctx -> inflight ))
780+ skcipher_wait (sk );
781+
564782 skcipher_free_sgl (sk );
565783 sock_kzfree_s (sk , ctx -> iv , crypto_ablkcipher_ivsize (tfm ));
566784 sock_kfree_s (sk , ctx , ctx -> len );
@@ -592,6 +810,7 @@ static int skcipher_accept_parent(void *private, struct sock *sk)
592810 ctx -> more = 0 ;
593811 ctx -> merge = 0 ;
594812 ctx -> enc = 0 ;
813+ atomic_set (& ctx -> inflight , 0 );
595814 af_alg_init_completion (& ctx -> completion );
596815
597816 ask -> private = ctx ;
0 commit comments