diff --git a/bpf/fd_to_address_tracepoints.c b/bpf/fd_to_address_tracepoints.c index 5bb9baf..57c91bc 100644 --- a/bpf/fd_to_address_tracepoints.c +++ b/bpf/fd_to_address_tracepoints.c @@ -29,7 +29,7 @@ struct sys_enter_accept4_ctx { SEC("tracepoint/syscalls/sys_enter_accept4") void sys_enter_accept4(struct sys_enter_accept4_ctx* ctx) { - __u64 id = bpf_get_current_pid_tgid(); + __u64 id = tracer_get_current_pid_tgid(); if (!should_watch(id >> 32)) { return; @@ -55,7 +55,7 @@ struct sys_exit_accept4_ctx { SEC("tracepoint/syscalls/sys_exit_accept4") void sys_exit_accept4(struct sys_exit_accept4_ctx* ctx) { - __u64 id = bpf_get_current_pid_tgid(); + __u64 id = tracer_get_current_pid_tgid(); if (!should_watch(id >> 32)) { return; @@ -122,7 +122,7 @@ struct sys_enter_connect_ctx { SEC("tracepoint/syscalls/sys_enter_connect") void sys_enter_connect(struct sys_enter_connect_ctx* ctx) { - __u64 id = bpf_get_current_pid_tgid(); + __u64 id = tracer_get_current_pid_tgid(); if (!should_watch(id >> 32)) { return; @@ -149,7 +149,7 @@ struct sys_exit_connect_ctx { SEC("tracepoint/syscalls/sys_exit_connect") void sys_exit_connect(struct sys_exit_connect_ctx* ctx) { - __u64 id = bpf_get_current_pid_tgid(); + __u64 id = tracer_get_current_pid_tgid(); if (!should_watch(id >> 32)) { return; diff --git a/bpf/fd_tracepoints.c b/bpf/fd_tracepoints.c index 6be6e87..9196afe 100644 --- a/bpf/fd_tracepoints.c +++ b/bpf/fd_tracepoints.c @@ -59,7 +59,7 @@ static __always_inline void fd_tracepoints_handle_go(struct sys_enter_read_write SEC("tracepoint/syscalls/sys_enter_read") void sys_enter_read(struct sys_enter_read_write_ctx* ctx) { - __u64 id = bpf_get_current_pid_tgid(); + __u64 id = tracer_get_current_pid_tgid(); if (!should_watch(id >> 32)) { return; @@ -76,7 +76,7 @@ void sys_enter_read(struct sys_enter_read_write_ctx* ctx) { SEC("tracepoint/syscalls/sys_enter_write") void sys_enter_write(struct sys_enter_read_write_ctx* ctx) { - __u64 id = bpf_get_current_pid_tgid(); + __u64 id = tracer_get_current_pid_tgid(); if (!should_watch(id >> 32)) { return; @@ -93,7 +93,7 @@ void sys_enter_write(struct sys_enter_read_write_ctx* ctx) { SEC("tracepoint/syscalls/sys_exit_read") void sys_exit_read(struct sys_exit_read_write_ctx* ctx) { - __u64 id = bpf_get_current_pid_tgid(); + __u64 id = tracer_get_current_pid_tgid(); // Delete from go map. The value is not used after exiting this syscall. // Keep value in openssl map. bpf_map_delete_elem(&go_kernel_read_context, &id); @@ -101,7 +101,7 @@ void sys_exit_read(struct sys_exit_read_write_ctx* ctx) { SEC("tracepoint/syscalls/sys_exit_write") void sys_exit_write(struct sys_exit_read_write_ctx* ctx) { - __u64 id = bpf_get_current_pid_tgid(); + __u64 id = tracer_get_current_pid_tgid(); // Delete from go map. The value is not used after exiting this syscall. // Keep value in openssl map. bpf_map_delete_elem(&go_kernel_write_context, &id); diff --git a/bpf/go_uprobes.c b/bpf/go_uprobes.c index 3987deb..e45d044 100644 --- a/bpf/go_uprobes.c +++ b/bpf/go_uprobes.c @@ -175,7 +175,7 @@ static __always_inline int go_crypto_tls_get_fd_from_tcp_conn(struct pt_regs* ct } static __always_inline void go_crypto_tls_uprobe(struct pt_regs* ctx, struct bpf_map_def* go_context, enum ABI abi) { - __u64 pid_tgid = bpf_get_current_pid_tgid(); + __u64 pid_tgid = tracer_get_current_pid_tgid(); __u64 pid = pid_tgid >> 32; if (!should_target(pid)) { return; @@ -251,7 +251,7 @@ static __always_inline void go_crypto_tls_uprobe(struct pt_regs* ctx, struct bpf } static __always_inline void go_crypto_tls_ex_uprobe(struct pt_regs* ctx, struct bpf_map_def* go_context, struct bpf_map_def* go_user_kernel_context, __u32 flags, enum ABI abi) { - __u64 pid_tgid = bpf_get_current_pid_tgid(); + __u64 pid_tgid = tracer_get_current_pid_tgid(); __u64 pid = pid_tgid >> 32; if (!should_target(pid)) { return; diff --git a/bpf/include/pids.h b/bpf/include/pids.h index 67ed82e..9adf4ca 100644 --- a/bpf/include/pids.h +++ b/bpf/include/pids.h @@ -19,6 +19,34 @@ int _pid_in_map(struct bpf_map_def* pmap, __u32 pid) { return shouldTargetGlobally != NULL && *shouldTargetGlobally == 1; } +const volatile __u64 TRACER_NS_INO = 0; +#define TRACER_NAMESPACES_MAX 4 +static __always_inline __u64 tracer_get_current_pid_tgid() { + unsigned int inum; + + __u64 base_pid_tgid = bpf_get_current_pid_tgid(); + + if (TRACER_NS_INO == 0) { + return base_pid_tgid; + } + + struct task_struct* task = (struct task_struct*)bpf_get_current_task(); + + int level = BPF_CORE_READ(task, group_leader, nsproxy, pid_ns_for_children, level); + + for (int i = 0; i < TRACER_NAMESPACES_MAX; i++) { + if ((level - i) < 0) { + break; + } + inum = BPF_CORE_READ(task, group_leader, thread_pid, numbers[level - i].ns, ns.inum); + if (inum == TRACER_NS_INO) { + __u64 ret = BPF_CORE_READ(task, group_leader, thread_pid, numbers[level - i].nr); + ret = (ret << 32) | (base_pid_tgid & 0xFFFFFFFF); + return ret; + } + } + return base_pid_tgid; +} int should_target(__u32 pid) { return _pid_in_map(&target_pids_map, pid); diff --git a/bpf/openssl_uprobes.c b/bpf/openssl_uprobes.c index bc30f97..1b6437d 100644 --- a/bpf/openssl_uprobes.c +++ b/bpf/openssl_uprobes.c @@ -42,7 +42,7 @@ static __always_inline int get_count_bytes(struct pt_regs* ctx, struct ssl_info* static __always_inline void ssl_uprobe(struct pt_regs* ctx, void* ssl, void* buffer, int num, struct bpf_map_def* map_fd, size_t* count_ptr) { long err; - __u64 id = bpf_get_current_pid_tgid(); + __u64 id = tracer_get_current_pid_tgid(); if (!should_target(id >> 32)) { return; @@ -61,7 +61,7 @@ static __always_inline void ssl_uprobe(struct pt_regs* ctx, void* ssl, void* buf } static __always_inline void ssl_uretprobe(struct pt_regs* ctx, struct bpf_map_def* map_fd, __u32 flags) { - __u64 id = bpf_get_current_pid_tgid(); + __u64 id = tracer_get_current_pid_tgid(); if (!should_target(id >> 32)) { return; diff --git a/bpf/tcp_kprobes.c b/bpf/tcp_kprobes.c index 01890db..df700b7 100644 --- a/bpf/tcp_kprobes.c +++ b/bpf/tcp_kprobes.c @@ -79,7 +79,7 @@ static void __always_inline tcp_kprobes_forward_openssl(struct ssl_info* info_pt static __always_inline void tcp_kprobe(struct pt_regs* ctx, struct bpf_map_def* map_fd_openssl, struct bpf_map_def* map_fd_go_kernel, struct bpf_map_def* map_fd_go_user_kernel) { long err; - __u64 id = bpf_get_current_pid_tgid(); + __u64 id = tracer_get_current_pid_tgid(); if (!should_watch(id >> 32)) { return; @@ -108,12 +108,12 @@ static __always_inline void tcp_kprobe(struct pt_regs* ctx, struct bpf_map_def* SEC("kprobe/tcp_sendmsg") void BPF_KPROBE(tcp_sendmsg) { - __u64 id = bpf_get_current_pid_tgid(); + __u64 id = tracer_get_current_pid_tgid(); tcp_kprobe(ctx, &openssl_write_context, &go_kernel_write_context, &go_user_kernel_write_context); } SEC("kprobe/tcp_recvmsg") void BPF_KPROBE(tcp_recvmsg) { - __u64 id = bpf_get_current_pid_tgid(); + __u64 id = tracer_get_current_pid_tgid(); tcp_kprobe(ctx, &openssl_read_context, &go_kernel_read_context, &go_user_kernel_read_context); } diff --git a/packet_sniffer.go b/packet_sniffer.go index 8c73241..01bc48f 100644 --- a/packet_sniffer.go +++ b/packet_sniffer.go @@ -13,38 +13,49 @@ import ( type packetFilter struct { ingressFilterProgram *ebpf.Program egressFilterProgram *ebpf.Program + ingressPullProgram *ebpf.Program + egressPullProgram *ebpf.Program attachedPods map[string][2]link.Link tcClient TcClient } func newPacketFilter(ingressFilterProgram, egressFilterProgram, pullIngress, pullEgress *ebpf.Program, pktsRingBuffer *ebpf.Map) (*packetFilter, error) { - var ifaces []int - links, err := netlink.LinkList() - if err != nil { - return nil, err - } - for _, link := range links { - ifaces = append(ifaces, link.Attrs().Index) - } - tcClient := &TcClientImpl{ TcPackage: &TcPackageImpl{}, } - for _, l := range ifaces { - if err := tcClient.SetupTC(l, pullIngress.FD(), pullEgress.FD()); err != nil { - return nil, err - } - } pf := &packetFilter{ ingressFilterProgram: ingressFilterProgram, egressFilterProgram: egressFilterProgram, + ingressPullProgram: pullIngress, + egressPullProgram: pullEgress, attachedPods: make(map[string][2]link.Link), tcClient: tcClient, } + pf.update() return pf, nil } +func (p *packetFilter) update() { + var ifaces []int + links, err := netlink.LinkList() + if err != nil { + log.Error().Err(err).Msg("Get link list failed:") + return + } + for _, link := range links { + ifaces = append(ifaces, link.Attrs().Index) + } + + for _, l := range ifaces { + if err := p.tcClient.SetupTC(l, p.ingressPullProgram.FD(), p.egressPullProgram.FD()); err != nil { + log.Error().Int("link", l).Err(err).Msg("Setup TC failed:") + continue + } + log.Info().Int("link", l).Msg("Attached TC programs:") + } +} + func (p *packetFilter) close() { _ = p.tcClient.CleanTC() for uuid := range p.attachedPods { @@ -64,13 +75,13 @@ func (t *packetFilter) AttachPod(uuid, cgroupV2Path string) error { return err } t.attachedPods[uuid] = [2]link.Link{lIngress, lEgress} - log.Info().Str("pod", uuid).Msg("Attaching pod:") //XXX + log.Info().Str("pod", uuid).Msg("Attaching pod:") return nil } func (t *packetFilter) DetachPod(uuid string) error { - log.Info().Str("pod", uuid).Msg("Detaching pod:") //XXX + log.Info().Str("pod", uuid).Msg("Detaching pod:") p, ok := t.attachedPods[uuid] if !ok { return fmt.Errorf("pod not attached") diff --git a/tls_process_discoverer.go b/tls_process_discoverer.go index 889e70e..5f53c0c 100644 --- a/tls_process_discoverer.go +++ b/tls_process_discoverer.go @@ -25,6 +25,9 @@ type podInfo struct { var numberRegex = regexp.MustCompile("[0-9]+") func (t *Tracer) updateTargets(addedWatchedPods []v1.Pod, removedWatchedPods []v1.Pod, addedTargetedPods []v1.Pod, removedTargetedPods []v1.Pod) error { + if t.packetFilter != nil { + t.packetFilter.update() + } for _, pod := range removedTargetedPods { if t.packetFilter != nil { if err := t.packetFilter.DetachPod(string(pod.UID)); err == nil { diff --git a/tracer.go b/tracer.go index 6e2f632..cc53612 100644 --- a/tracer.go +++ b/tracer.go @@ -3,6 +3,8 @@ package main import ( "fmt" + "bytes" + "os" "strconv" "syscall" @@ -50,6 +52,44 @@ type pidOffset struct { offset uint64 } +type BpfObjectsImpl struct { + bpfObjs tracerObjects + specs *ebpf.CollectionSpec +} + +func (objs *BpfObjectsImpl) loadBpfObjects(bpfConstants map[string]uint64) error { + var err error + opts := ebpf.CollectionOptions{ + Programs: ebpf.ProgramOptions{ + LogSize: ebpf.DefaultVerifierLogSize * 32, + }, + } + + reader := bytes.NewReader(_TracerBytes) + objs.specs, err = ebpf.LoadCollectionSpecFromReader(reader) + if err != nil { + return err + } + + consts := make(map[string]interface{}) + for k, v := range bpfConstants { + consts[k] = v + } + err = objs.specs.RewriteConstants(consts) + if err != nil { + return err + } + + err = objs.specs.LoadAndAssign(&objs.bpfObjs, &opts) + if err != nil { + var ve *ebpf.VerifierError + if errors.As(err, &ve) { + log.Error().Msg(fmt.Sprintf("Got verifier error: %+v", ve)) + } + } + return err +} + func (t *Tracer) Init( chunksBufferSize int, logBufferSize int, @@ -81,25 +121,34 @@ func (t *Tracer) Init( log.Info().Msg(fmt.Sprintf("Detected Linux kernel version: %s cgroups version: %v", kernelVersion, cgroupsVersion)) - t.bpfObjects = tracerObjects{} // TODO: cilium/ebpf does not support .kconfig Therefore; for now, we load object files according to kernel version. if kernel.CompareKernelVersion(*kernelVersion, kernel.VersionInfo{Kernel: 4, Major: 6, Minor: 0}) < 1 { + t.bpfObjects = tracerObjects{} if err := loadTracer46Objects(&t.bpfObjects, nil); err != nil { return errors.Wrap(err, 0) } } else { - opts := ebpf.CollectionOptions{ - Programs: ebpf.ProgramOptions{ - LogSize: ebpf.DefaultVerifierLogSize * 32, - }, + var hostProcIno uint64 + fileInfo, err := os.Stat("/hostproc/1/ns/pid") + if err != nil { + // services like "apparmor" on EKS can reject access to system pid information + log.Warn().Err(err).Msg("Get host netns failed") + } else { + hostProcIno = fileInfo.Sys().(*syscall.Stat_t).Ino + log.Info().Uint64("ns", hostProcIno).Msg("Setting host ns") } - if err := loadTracerObjects(&t.bpfObjects, &opts); err != nil { - var ve *ebpf.VerifierError - if errors.As(err, &ve) { - log.Error().Msg(fmt.Sprintf("Got verifier error: %+v", ve)) - } - return errors.Wrap(err, 0) + + objs := &BpfObjectsImpl{} + + bpfConsts := map[string]uint64{ + "TRACER_NS_INO": hostProcIno, + } + err = objs.loadBpfObjects(bpfConsts) + if err != nil { + log.Error().Msg(fmt.Sprintf("load bpf objects failed: %v", err)) + return err } + t.bpfObjects = objs.bpfObjs } t.syscallHooks = syscallHooks{}