5454#include <linux/slab.h>
5555#include <linux/uaccess.h>
5656#include <linux/memcontrol.h>
57+ #include <linux/res_counter.h>
5758
5859#include <linux/filter.h>
5960#include <linux/rculist_nulls.h>
@@ -168,6 +169,7 @@ struct sock_common {
168169 /* public: */
169170};
170171
172+ struct cg_proto ;
171173/**
172174 * struct sock - network layer representation of sockets
173175 * @__sk_common: shared layout with inet_timewait_sock
@@ -228,6 +230,7 @@ struct sock_common {
228230 * @sk_security: used by security modules
229231 * @sk_mark: generic packet mark
230232 * @sk_classid: this socket's cgroup classid
233+ * @sk_cgrp: this socket's cgroup-specific proto data
231234 * @sk_write_pending: a write to stream socket waits to start
232235 * @sk_state_change: callback to indicate change in the state of the sock
233236 * @sk_data_ready: callback to indicate there is data to be processed
@@ -342,6 +345,7 @@ struct sock {
342345#endif
343346 __u32 sk_mark ;
344347 u32 sk_classid ;
348+ struct cg_proto * sk_cgrp ;
345349 void (* sk_state_change )(struct sock * sk );
346350 void (* sk_data_ready )(struct sock * sk , int bytes );
347351 void (* sk_write_space )(struct sock * sk );
@@ -838,6 +842,37 @@ struct proto {
838842#ifdef SOCK_REFCNT_DEBUG
839843 atomic_t socks ;
840844#endif
845+ #ifdef CONFIG_CGROUP_MEM_RES_CTLR_KMEM
846+ /*
847+ * cgroup specific init/deinit functions. Called once for all
848+ * protocols that implement it, from cgroups populate function.
849+ * This function has to setup any files the protocol want to
850+ * appear in the kmem cgroup filesystem.
851+ */
852+ int (* init_cgroup )(struct cgroup * cgrp ,
853+ struct cgroup_subsys * ss );
854+ void (* destroy_cgroup )(struct cgroup * cgrp ,
855+ struct cgroup_subsys * ss );
856+ struct cg_proto * (* proto_cgroup )(struct mem_cgroup * memcg );
857+ #endif
858+ };
859+
/*
 * Per-cgroup view of a protocol's memory-accounting state.  Each memcg
 * that limits socket memory gets one cg_proto per protocol; the pointers
 * below alias the cgroup-local counters that the sk_* accounting helpers
 * consult in place of the global ones in struct proto.
 */
struct cg_proto {
	void			(*enter_memory_pressure)(struct sock *sk);
	struct res_counter	*memory_allocated;	/* Current allocated memory. */
	struct percpu_counter	*sockets_allocated;	/* Current number of sockets. */
	int			*memory_pressure;
	long			*sysctl_mem;
	/*
	 * memcg field is used to find which memcg we belong directly
	 * Each memcg struct can hold more than one cg_proto, so container_of
	 * won't really cut.
	 *
	 * The elegant solution would be having an inverse function to
	 * proto_cgroup in struct proto, but that means polluting the structure
	 * for everybody, instead of just for memcg users.
	 */
	struct mem_cgroup	*memcg;
};
842877
843878extern int proto_register (struct proto * prot , int alloc_slab );
@@ -856,7 +891,7 @@ static inline void sk_refcnt_debug_dec(struct sock *sk)
856891 sk -> sk_prot -> name , sk , atomic_read (& sk -> sk_prot -> socks ));
857892}
858893
859- static inline void sk_refcnt_debug_release (const struct sock * sk )
894+ inline void sk_refcnt_debug_release (const struct sock * sk )
860895{
861896 if (atomic_read (& sk -> sk_refcnt ) != 1 )
862897 printk (KERN_DEBUG "Destruction of the %s socket %p delayed, refcnt=%d\n" ,
@@ -868,6 +903,24 @@ static inline void sk_refcnt_debug_release(const struct sock *sk)
868903#define sk_refcnt_debug_release (sk ) do { } while (0)
869904#endif /* SOCK_REFCNT_DEBUG */
870905
#ifdef CONFIG_CGROUP_MEM_RES_CTLR_KMEM
/* Static-branch key: patched on only while some memcg limits socket memory,
 * so the memcg paths in the sk_* helpers below cost nothing otherwise. */
extern struct jump_label_key memcg_socket_limit_enabled;
/* Return this protocol's cg_proto one level up the memcg hierarchy
 * (NULL at the root), used to walk ancestors in the accounting helpers. */
static inline struct cg_proto *parent_cg_proto(struct proto *proto,
					       struct cg_proto *cg_proto)
{
	return proto->proto_cgroup(parent_mem_cgroup(cg_proto->memcg));
}
#define mem_cgroup_sockets_enabled static_branch(&memcg_socket_limit_enabled)
#else
/* !CONFIG_CGROUP_MEM_RES_CTLR_KMEM: memcg socket accounting compiled out. */
#define mem_cgroup_sockets_enabled 0
static inline struct cg_proto *parent_cg_proto(struct proto *proto,
					       struct cg_proto *cg_proto)
{
	return NULL;
}
#endif
922+
923+
871924static inline bool sk_has_memory_pressure (const struct sock * sk )
872925{
873926 return sk -> sk_prot -> memory_pressure != NULL ;
@@ -877,59 +930,147 @@ static inline bool sk_under_memory_pressure(const struct sock *sk)
877930{
878931 if (!sk -> sk_prot -> memory_pressure )
879932 return false;
933+
934+ if (mem_cgroup_sockets_enabled && sk -> sk_cgrp )
935+ return !!* sk -> sk_cgrp -> memory_pressure ;
936+
880937 return !!* sk -> sk_prot -> memory_pressure ;
881938}
882939
883940static inline void sk_leave_memory_pressure (struct sock * sk )
884941{
885942 int * memory_pressure = sk -> sk_prot -> memory_pressure ;
886943
887- if (memory_pressure && * memory_pressure )
944+ if (!memory_pressure )
945+ return ;
946+
947+ if (* memory_pressure )
888948 * memory_pressure = 0 ;
949+
950+ if (mem_cgroup_sockets_enabled && sk -> sk_cgrp ) {
951+ struct cg_proto * cg_proto = sk -> sk_cgrp ;
952+ struct proto * prot = sk -> sk_prot ;
953+
954+ for (; cg_proto ; cg_proto = parent_cg_proto (prot , cg_proto ))
955+ if (* cg_proto -> memory_pressure )
956+ * cg_proto -> memory_pressure = 0 ;
957+ }
958+
889959}
890960
891961static inline void sk_enter_memory_pressure (struct sock * sk )
892962{
893- if (sk -> sk_prot -> enter_memory_pressure )
894- sk -> sk_prot -> enter_memory_pressure (sk );
963+ if (!sk -> sk_prot -> enter_memory_pressure )
964+ return ;
965+
966+ if (mem_cgroup_sockets_enabled && sk -> sk_cgrp ) {
967+ struct cg_proto * cg_proto = sk -> sk_cgrp ;
968+ struct proto * prot = sk -> sk_prot ;
969+
970+ for (; cg_proto ; cg_proto = parent_cg_proto (prot , cg_proto ))
971+ cg_proto -> enter_memory_pressure (sk );
972+ }
973+
974+ sk -> sk_prot -> enter_memory_pressure (sk );
895975}
896976
897977static inline long sk_prot_mem_limits (const struct sock * sk , int index )
898978{
899979 long * prot = sk -> sk_prot -> sysctl_mem ;
980+ if (mem_cgroup_sockets_enabled && sk -> sk_cgrp )
981+ prot = sk -> sk_cgrp -> sysctl_mem ;
900982 return prot [index ];
901983}
902984
985+ static inline void memcg_memory_allocated_add (struct cg_proto * prot ,
986+ unsigned long amt ,
987+ int * parent_status )
988+ {
989+ struct res_counter * fail ;
990+ int ret ;
991+
992+ ret = res_counter_charge (prot -> memory_allocated ,
993+ amt << PAGE_SHIFT , & fail );
994+
995+ if (ret < 0 )
996+ * parent_status = OVER_LIMIT ;
997+ }
998+
999+ static inline void memcg_memory_allocated_sub (struct cg_proto * prot ,
1000+ unsigned long amt )
1001+ {
1002+ res_counter_uncharge (prot -> memory_allocated , amt << PAGE_SHIFT );
1003+ }
1004+
1005+ static inline u64 memcg_memory_allocated_read (struct cg_proto * prot )
1006+ {
1007+ u64 ret ;
1008+ ret = res_counter_read_u64 (prot -> memory_allocated , RES_USAGE );
1009+ return ret >> PAGE_SHIFT ;
1010+ }
1011+
9031012static inline long
9041013sk_memory_allocated (const struct sock * sk )
9051014{
9061015 struct proto * prot = sk -> sk_prot ;
1016+ if (mem_cgroup_sockets_enabled && sk -> sk_cgrp )
1017+ return memcg_memory_allocated_read (sk -> sk_cgrp );
1018+
9071019 return atomic_long_read (prot -> memory_allocated );
9081020}
9091021
/*
 * Charge @amt pages of socket-buffer memory and return the resulting
 * total for the scope the caller will test limits against.  With memcg
 * socket accounting active, the cgroup res_counter is charged first
 * (setting *parent_status to OVER_LIMIT on failure — see
 * sk_memory_allocated_sub(), which then skips the memcg uncharge),
 * the protocol-global counter is bumped regardless so root accounting
 * stays consistent, and the cgroup usage is returned.
 */
static inline long
sk_memory_allocated_add(struct sock *sk, int amt, int *parent_status)
{
	struct proto *prot = sk->sk_prot;

	if (mem_cgroup_sockets_enabled && sk->sk_cgrp) {
		memcg_memory_allocated_add(sk->sk_cgrp, amt, parent_status);
		/* update the root cgroup regardless */
		atomic_long_add_return(amt, prot->memory_allocated);
		return memcg_memory_allocated_read(sk->sk_cgrp);
	}

	return atomic_long_add_return(amt, prot->memory_allocated);
}
9161036
9171037static inline void
918- sk_memory_allocated_sub (struct sock * sk , int amt )
1038+ sk_memory_allocated_sub (struct sock * sk , int amt , int parent_status )
9191039{
9201040 struct proto * prot = sk -> sk_prot ;
1041+
1042+ if (mem_cgroup_sockets_enabled && sk -> sk_cgrp &&
1043+ parent_status != OVER_LIMIT ) /* Otherwise was uncharged already */
1044+ memcg_memory_allocated_sub (sk -> sk_cgrp , amt );
1045+
9211046 atomic_long_sub (amt , prot -> memory_allocated );
9221047}
9231048
9241049static inline void sk_sockets_allocated_dec (struct sock * sk )
9251050{
9261051 struct proto * prot = sk -> sk_prot ;
1052+
1053+ if (mem_cgroup_sockets_enabled && sk -> sk_cgrp ) {
1054+ struct cg_proto * cg_proto = sk -> sk_cgrp ;
1055+
1056+ for (; cg_proto ; cg_proto = parent_cg_proto (prot , cg_proto ))
1057+ percpu_counter_dec (cg_proto -> sockets_allocated );
1058+ }
1059+
9271060 percpu_counter_dec (prot -> sockets_allocated );
9281061}
9291062
9301063static inline void sk_sockets_allocated_inc (struct sock * sk )
9311064{
9321065 struct proto * prot = sk -> sk_prot ;
1066+
1067+ if (mem_cgroup_sockets_enabled && sk -> sk_cgrp ) {
1068+ struct cg_proto * cg_proto = sk -> sk_cgrp ;
1069+
1070+ for (; cg_proto ; cg_proto = parent_cg_proto (prot , cg_proto ))
1071+ percpu_counter_inc (cg_proto -> sockets_allocated );
1072+ }
1073+
9331074 percpu_counter_inc (prot -> sockets_allocated );
9341075}
9351076
@@ -938,6 +1079,9 @@ sk_sockets_allocated_read_positive(struct sock *sk)
9381079{
9391080 struct proto * prot = sk -> sk_prot ;
9401081
1082+ if (mem_cgroup_sockets_enabled && sk -> sk_cgrp )
1083+ return percpu_counter_sum_positive (sk -> sk_cgrp -> sockets_allocated );
1084+
9411085 return percpu_counter_sum_positive (prot -> sockets_allocated );
9421086}
9431087
0 commit comments