4141#include <linux/mm.h>
4242#include <net/strparser.h>
4343#include <net/tcp.h>
44+ #include <linux/ptr_ring.h>
45+ #include <net/inet_common.h>
4446
4547#define SOCK_CREATE_FLAG_MASK \
4648 (BPF_F_NUMA_NODE | BPF_F_RDONLY | BPF_F_WRONLY)
@@ -82,6 +84,7 @@ struct smap_psock {
8284 int sg_size ;
8385 int eval ;
8486 struct sk_msg_buff * cork ;
87+ struct list_head ingress ;
8588
8689 struct strparser strp ;
8790 struct bpf_prog * bpf_tx_msg ;
@@ -103,6 +106,8 @@ struct smap_psock {
103106};
104107
105108static void smap_release_sock (struct smap_psock * psock , struct sock * sock );
109+ static int bpf_tcp_recvmsg (struct sock * sk , struct msghdr * msg , size_t len ,
110+ int nonblock , int flags , int * addr_len );
106111static int bpf_tcp_sendmsg (struct sock * sk , struct msghdr * msg , size_t size );
107112static int bpf_tcp_sendpage (struct sock * sk , struct page * page ,
108113 int offset , size_t size , int flags );
@@ -112,6 +117,21 @@ static inline struct smap_psock *smap_psock_sk(const struct sock *sk)
112117 return rcu_dereference_sk_user_data (sk );
113118}
114119
120+ static bool bpf_tcp_stream_read (const struct sock * sk )
121+ {
122+ struct smap_psock * psock ;
123+ bool empty = true;
124+
125+ rcu_read_lock ();
126+ psock = smap_psock_sk (sk );
127+ if (unlikely (!psock ))
128+ goto out ;
129+ empty = list_empty (& psock -> ingress );
130+ out :
131+ rcu_read_unlock ();
132+ return !empty ;
133+ }
134+
115135static struct proto tcp_bpf_proto ;
116136static int bpf_tcp_init (struct sock * sk )
117137{
@@ -135,6 +155,8 @@ static int bpf_tcp_init(struct sock *sk)
135155 if (psock -> bpf_tx_msg ) {
136156 tcp_bpf_proto .sendmsg = bpf_tcp_sendmsg ;
137157 tcp_bpf_proto .sendpage = bpf_tcp_sendpage ;
158+ tcp_bpf_proto .recvmsg = bpf_tcp_recvmsg ;
159+ tcp_bpf_proto .stream_memory_read = bpf_tcp_stream_read ;
138160 }
139161
140162 sk -> sk_prot = & tcp_bpf_proto ;
@@ -170,6 +192,7 @@ static void bpf_tcp_close(struct sock *sk, long timeout)
170192{
171193 void (* close_fun )(struct sock * sk , long timeout );
172194 struct smap_psock_map_entry * e , * tmp ;
195+ struct sk_msg_buff * md , * mtmp ;
173196 struct smap_psock * psock ;
174197 struct sock * osk ;
175198
@@ -188,6 +211,12 @@ static void bpf_tcp_close(struct sock *sk, long timeout)
188211 close_fun = psock -> save_close ;
189212
190213 write_lock_bh (& sk -> sk_callback_lock );
214+ list_for_each_entry_safe (md , mtmp , & psock -> ingress , list ) {
215+ list_del (& md -> list );
216+ free_start_sg (psock -> sock , md );
217+ kfree (md );
218+ }
219+
191220 list_for_each_entry_safe (e , tmp , & psock -> maps , list ) {
192221 osk = cmpxchg (e -> entry , sk , NULL );
193222 if (osk == sk ) {
@@ -468,13 +497,80 @@ static unsigned int smap_do_tx_msg(struct sock *sk,
468497 return _rc ;
469498}
470499
/* Queue up to @apply_bytes (all bytes when @apply_bytes is 0) of the
 * scatterlist in @md onto @psock->ingress, charging the memory to @sk.
 * Called on the BPF_F_INGRESS redirect path: instead of transmitting,
 * the data becomes readable on @sk via bpf_tcp_recvmsg().
 *
 * A fresh sk_msg_buff @r is built that aliases md's sg entries.
 * Ownership rules per entry:
 *  - entry fully consumed: md->sg_start is advanced past it, so the
 *    page reference md held is transferred to @r (no get_page needed);
 *  - entry partially consumed: both md and @r keep pointing at the
 *    page, so an extra reference is taken with get_page().
 *
 * Returns 0 on success (even on a partial queue once some bytes were
 * charged), -ENOMEM if the allocation or the very first memory charge
 * fails.
 */
static int bpf_tcp_ingress(struct sock *sk, int apply_bytes,
			   struct smap_psock *psock,
			   struct sk_msg_buff *md, int flags)
{
	bool apply = apply_bytes;	/* 0 means "no byte limit" */
	size_t size, copied = 0;
	struct sk_msg_buff *r;
	int err = 0, i;

	r = kzalloc(sizeof(struct sk_msg_buff), __GFP_NOWARN | GFP_KERNEL);
	if (unlikely(!r))
		return -ENOMEM;

	lock_sock(sk);
	r->sg_start = md->sg_start;
	i = md->sg_start;

	do {
		/* Alias md's entry; length/offset are fixed up below. */
		r->sg_data[i] = md->sg_data[i];

		/* Clamp this entry to the remaining apply budget. */
		size = (apply && apply_bytes < md->sg_data[i].length) ?
			apply_bytes : md->sg_data[i].length;

		if (!sk_wmem_schedule(sk, size)) {
			/* Only report -ENOMEM if nothing was queued yet;
			 * otherwise deliver what we have charged so far.
			 */
			if (!copied)
				err = -ENOMEM;
			break;
		}

		sk_mem_charge(sk, size);
		r->sg_data[i].length = size;
		md->sg_data[i].length -= size;
		md->sg_data[i].offset += size;
		copied += size;

		if (md->sg_data[i].length) {
			/* Partial consume: md keeps the entry too, so the
			 * page now has two users — take a reference.
			 */
			get_page(sg_page(&r->sg_data[i]));
			r->sg_end = (i + 1) == MAX_SKB_FRAGS ? 0 : i + 1;
		} else {
			/* Fully consumed: advance with wrap-around; the
			 * page reference moves from md to r.
			 */
			i++;
			if (i == MAX_SKB_FRAGS)
				i = 0;
			r->sg_end = i;
		}

		if (apply) {
			apply_bytes -= size;
			if (!apply_bytes)
				break;
		}
	} while (i != md->sg_end);

	/* Record how far the caller's buffer was consumed. */
	md->sg_start = i;

	if (!err) {
		list_add_tail(&r->list, &psock->ingress);
		/* Wake readers blocked in recvmsg()/poll(). */
		sk->sk_data_ready(sk);
	} else {
		/* NOTE(review): on the first-iteration -ENOMEM path,
		 * r->sg_data[i] already aliases md's entry without an
		 * extra page reference — confirm free_start_sg() cannot
		 * drop a reference md still owns here.
		 */
		free_start_sg(sk, r);
		kfree(r);
	}

	release_sock(sk);
	return err;
}
565+
471566static int bpf_tcp_sendmsg_do_redirect (struct sock * sk , int send ,
472567 struct sk_msg_buff * md ,
473568 int flags )
474569{
475570 struct smap_psock * psock ;
476571 struct scatterlist * sg ;
477572 int i , err , free = 0 ;
573+ bool ingress = !!(md -> flags & BPF_F_INGRESS );
478574
479575 sg = md -> sg_data ;
480576
@@ -487,9 +583,14 @@ static int bpf_tcp_sendmsg_do_redirect(struct sock *sk, int send,
487583 goto out_rcu ;
488584
489585 rcu_read_unlock ();
490- lock_sock (sk );
491- err = bpf_tcp_push (sk , send , md , flags , false);
492- release_sock (sk );
586+
587+ if (ingress ) {
588+ err = bpf_tcp_ingress (sk , send , psock , md , flags );
589+ } else {
590+ lock_sock (sk );
591+ err = bpf_tcp_push (sk , send , md , flags , false);
592+ release_sock (sk );
593+ }
493594 smap_release_sock (psock , sk );
494595 if (unlikely (err ))
495596 goto out ;
@@ -623,6 +724,89 @@ static int bpf_exec_tx_verdict(struct smap_psock *psock,
623724 return err ;
624725}
625726
/* recvmsg handler installed on sockmap sockets with a tx_msg program:
 * drain sk_msg_buffs queued on psock->ingress by bpf_tcp_ingress()
 * before falling back to regular TCP receive processing.
 *
 * Behavior notes (visible in this code):
 *  - If sk_receive_queue is non-empty, plain tcp_recvmsg() is used and
 *    the ingress list is not touched on this call.
 *  - The loop returns as soon as the ingress list is empty; there is no
 *    blocking wait here, and @nonblock is only forwarded on the
 *    tcp_recvmsg() fallback paths.
 *  - MSG_PEEK is never checked: ingress data is always consumed
 *    (lengths decremented, pages released) regardless of @flags.
 */
static int bpf_tcp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
			   int nonblock, int flags, int *addr_len)
{
	struct iov_iter *iter = &msg->msg_iter;
	struct smap_psock *psock;
	int copied = 0;

	if (unlikely(flags & MSG_ERRQUEUE))
		return inet_recv_error(sk, msg, len, addr_len);

	rcu_read_lock();
	psock = smap_psock_sk(sk);
	if (unlikely(!psock))
		goto out;

	/* Pin the psock so it survives outside the RCU section; a zero
	 * refcount means it is already being torn down.
	 */
	if (unlikely(!refcount_inc_not_zero(&psock->refcnt)))
		goto out;
	rcu_read_unlock();

	if (!skb_queue_empty(&sk->sk_receive_queue))
		return tcp_recvmsg(sk, msg, len, nonblock, flags, addr_len);

	lock_sock(sk);
	while (copied != len) {
		struct scatterlist *sg;
		struct sk_msg_buff *md;
		int i;

		md = list_first_entry_or_null(&psock->ingress,
					      struct sk_msg_buff, list);
		if (unlikely(!md))
			break;
		i = md->sg_start;
		do {
			struct page *page;
			int n, copy;

			sg = &md->sg_data[i];
			copy = sg->length;
			page = sg_page(sg);

			/* Don't copy past the user buffer. */
			if (copied + copy > len)
				copy = len - copied;

			n = copy_page_to_iter(page, sg->offset, copy, iter);
			if (n != copy) {
				/* Short copy to userspace: remember how far
				 * we got, drop locks/refs and bail.
				 */
				md->sg_start = i;
				release_sock(sk);
				smap_release_sock(psock, sk);
				return -EFAULT;
			}

			copied += copy;
			sg->offset += copy;
			sg->length -= copy;
			/* Uncharge what bpf_tcp_ingress() charged. */
			sk_mem_uncharge(sk, copy);

			if (!sg->length) {
				/* Entry drained: advance with wrap-around
				 * and release its page reference.
				 */
				i++;
				if (i == MAX_SKB_FRAGS)
					i = 0;
				put_page(page);
			}
			if (copied == len)
				break;
		} while (i != md->sg_end);
		md->sg_start = i;

		/* All entries drained (pages already put above), so the
		 * buffer itself can be freed without free_start_sg().
		 */
		if (!sg->length && md->sg_start == md->sg_end) {
			list_del(&md->list);
			kfree(md);
		}
	}

	release_sock(sk);
	smap_release_sock(psock, sk);
	return copied;
out:
	/* No psock (or dying psock): behave like a plain TCP socket. */
	rcu_read_unlock();
	return tcp_recvmsg(sk, msg, len, nonblock, flags, addr_len);
}
808+
809+
626810static int bpf_tcp_sendmsg (struct sock * sk , struct msghdr * msg , size_t size )
627811{
628812 int flags = msg -> msg_flags | MSG_NO_SHARED_FRAGS ;
@@ -1107,6 +1291,7 @@ static void sock_map_remove_complete(struct bpf_stab *stab)
11071291static void smap_gc_work (struct work_struct * w )
11081292{
11091293 struct smap_psock_map_entry * e , * tmp ;
1294+ struct sk_msg_buff * md , * mtmp ;
11101295 struct smap_psock * psock ;
11111296
11121297 psock = container_of (w , struct smap_psock , gc_work );
@@ -1131,6 +1316,12 @@ static void smap_gc_work(struct work_struct *w)
11311316 kfree (psock -> cork );
11321317 }
11331318
1319+ list_for_each_entry_safe (md , mtmp , & psock -> ingress , list ) {
1320+ list_del (& md -> list );
1321+ free_start_sg (psock -> sock , md );
1322+ kfree (md );
1323+ }
1324+
11341325 list_for_each_entry_safe (e , tmp , & psock -> maps , list ) {
11351326 list_del (& e -> list );
11361327 kfree (e );
@@ -1160,6 +1351,7 @@ static struct smap_psock *smap_init_psock(struct sock *sock,
11601351 INIT_WORK (& psock -> tx_work , smap_tx_work );
11611352 INIT_WORK (& psock -> gc_work , smap_gc_work );
11621353 INIT_LIST_HEAD (& psock -> maps );
1354+ INIT_LIST_HEAD (& psock -> ingress );
11631355 refcount_set (& psock -> refcnt , 1 );
11641356
11651357 rcu_assign_sk_user_data (sock , psock );
0 commit comments