diff --git a/pkg/sentry/inet/nlmcast.go b/pkg/sentry/inet/nlmcast.go index 9f3a521899..1efd7a50cb 100644 --- a/pkg/sentry/inet/nlmcast.go +++ b/pkg/sentry/inet/nlmcast.go @@ -104,7 +104,8 @@ func (m *McastTable) RemoveSocket(s NetlinkSocket) { delete(m.socks[p], s) } -func (m *McastTable) forEachMcastSock(protocol int, mcastGroup int, fn func(s NetlinkSocket)) { +// ForEachMcastSock calls fn on all Netlink sockets that are members of the given multicast group. +func (m *McastTable) ForEachMcastSock(protocol int, mcastGroup int, fn func(s NetlinkSocket)) { m.mu.Lock() defer m.mu.Unlock() if _, ok := m.socks[protocol]; !ok { @@ -122,7 +123,7 @@ func (m *McastTable) forEachMcastSock(protocol int, mcastGroup int, fn func(s Ne // OnInterfaceChangeEvent implements InterfaceEventSubscriber.OnInterfaceChangeEvent. func (m *McastTable) OnInterfaceChangeEvent(ctx context.Context, idx int32, i Interface) { // Relay the event to RTNLGRP_LINK subscribers. - m.forEachMcastSock(routeProtocol, routeLinkMcastGroup, func(s NetlinkSocket) { + m.ForEachMcastSock(routeProtocol, routeLinkMcastGroup, func(s NetlinkSocket) { s.HandleInterfaceChangeEvent(ctx, idx, i) }) } @@ -130,7 +131,7 @@ func (m *McastTable) OnInterfaceChangeEvent(ctx context.Context, idx int32, i In // OnInterfaceDeleteEvent implements InterfaceEventSubscriber.OnInterfaceDeleteEvent. func (m *McastTable) OnInterfaceDeleteEvent(ctx context.Context, idx int32, i Interface) { // Relay the event to RTNLGRP_LINK subscribers. - m.forEachMcastSock(routeProtocol, routeLinkMcastGroup, func(s NetlinkSocket) { + m.ForEachMcastSock(routeProtocol, routeLinkMcastGroup, func(s NetlinkSocket) { s.HandleInterfaceDeleteEvent(ctx, idx, i) }) } diff --git a/pkg/sentry/socket/netlink/BUILD b/pkg/sentry/socket/netlink/BUILD index 15399c9148..d29eb9a783 100644 --- a/pkg/sentry/socket/netlink/BUILD +++ b/pkg/sentry/socket/netlink/BUILD @@ -28,6 +28,7 @@ go_library( "//pkg/sentry/kernel/auth", "//pkg/sentry/ktime", "//pkg/sentry/socket", + "//pkg/sentry/socket/control", "//pkg/sentry/socket/netlink/nlmsg", "//pkg/sentry/socket/netlink/port", "//pkg/sentry/socket/unix", diff --git a/pkg/sentry/socket/netlink/netfilter/protocol.go b/pkg/sentry/socket/netlink/netfilter/protocol.go index adacd29b9e..5808ef60f3 100644 --- a/pkg/sentry/socket/netlink/netfilter/protocol.go +++ b/pkg/sentry/socket/netlink/netfilter/protocol.go @@ -82,7 +82,7 @@ func (p *Protocol) Receive(ctx context.Context, s *netlink.Socket, buf []byte) * // TODO: b/434785410 - Support batch messages. if hdr.Type == linux.NFNL_MSG_BATCH_BEGIN { ms := nlmsg.NewMessageSet(s.GetPortID(), hdr.Seq) - if err := p.receiveBatchMessage(ctx, s, ms, buf); err != nil { + if err := p.receiveBatchMessage(ctx, ms, buf); err != nil { log.Debugf("Nftables: Failed to process batch message: %v", err) netlink.DumpErrorMessage(hdr, ms, err.GetError()) } @@ -1215,7 +1215,7 @@ func (p *Protocol) ProcessMessage(ctx context.Context, s *netlink.Socket, msg *n } // receiveBatchMessage processes a NETFILTER batch message. -func (p *Protocol) receiveBatchMessage(ctx context.Context, s *netlink.Socket, ms *nlmsg.MessageSet, buf []byte) *syserr.AnnotatedError { +func (p *Protocol) receiveBatchMessage(ctx context.Context, ms *nlmsg.MessageSet, buf []byte) *syserr.AnnotatedError { // Linux ignores messages that are too small. // From net/netfilter/nfnetlink.c:nfnetlink_rcv_skb_batch if len(buf) < linux.NetlinkMessageHeaderSize+linux.SizeOfNetfilterGenMsg { @@ -1254,7 +1254,7 @@ func (p *Protocol) receiveBatchMessage(ctx context.Context, s *netlink.Socket, m // The resource ID is a 16-bit value that is stored in network byte order. // We ensure that it is in host byte order before passing it for processing. resID := nlmsg.NetToHostU16(nfGenMsg.ResourceID) - if err := p.processBatchMessage(ctx, s, buf, ms, hdr, resID); err != nil { + if err := p.processBatchMessage(ctx, buf, ms, hdr, resID); err != nil { log.Debugf("Failed to process batch message: %v", err) netlink.DumpErrorMessage(hdr, ms, err.GetError()) } @@ -1263,7 +1263,7 @@ func (p *Protocol) receiveBatchMessage(ctx context.Context, s *netlink.Socket, m } // processBatchMessage processes a batch message. -func (p *Protocol) processBatchMessage(ctx context.Context, s *netlink.Socket, buf []byte, ms *nlmsg.MessageSet, batchHdr linux.NetlinkMessageHeader, subsysID uint16) *syserr.AnnotatedError { +func (p *Protocol) processBatchMessage(ctx context.Context, buf []byte, ms *nlmsg.MessageSet, batchHdr linux.NetlinkMessageHeader, subsysID uint16) *syserr.AnnotatedError { if subsysID >= linux.NFNL_SUBSYS_COUNT { return syserr.NewAnnotatedError(syserr.ErrInvalidArgument, fmt.Sprintf("Nftables: Unknown subsystem id %d", subsysID)) } diff --git a/pkg/sentry/socket/netlink/socket.go b/pkg/sentry/socket/netlink/socket.go index e86ca0d500..7fa06038c3 100644 --- a/pkg/sentry/socket/netlink/socket.go +++ b/pkg/sentry/socket/netlink/socket.go @@ -19,6 +19,8 @@ import ( "fmt" "io" "math" + "math/bits" + "strconv" "time" "gvisor.dev/gvisor/pkg/abi/linux" @@ -35,6 +37,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/ktime" "gvisor.dev/gvisor/pkg/sentry/socket" + "gvisor.dev/gvisor/pkg/sentry/socket/control" "gvisor.dev/gvisor/pkg/sentry/socket/netlink/nlmsg" "gvisor.dev/gvisor/pkg/sentry/socket/netlink/port" "gvisor.dev/gvisor/pkg/sentry/socket/unix" @@ -105,7 +108,8 @@ type Socket struct { // Writing to it requires the per-netns table lock to be held, reading it does not. groups atomicbitops.Uint64 - // mu protects the fields below. + // mu protects the fields below. In lock order it follows the per-netns multicast + // table lock. mu sync.Mutex `state:"nosave"` // bound indicates that portid is valid. @@ -321,7 +325,7 @@ func (s *Socket) bindPort(t *kernel.Task, port int32) *syserr.Error { return nil } -func (s *Socket) checkMcastSupport(t *kernel.Task) *syserr.Error { +func (s *Socket) checkMcastSupport() *syserr.Error { // Currently only ROUTE family sockets support multicast. if s.Protocol() != linux.NETLINK_ROUTE { return syserr.ErrNotSupported @@ -330,20 +334,15 @@ func (s *Socket) checkMcastSupport(t *kernel.Task) *syserr.Error { if _, ok := s.Stack().(inet.InterfaceEventPublisher); !ok { return syserr.ErrNotSupported } - // man 7 netlink: "Only processes with an effective UID of 0 or the CAP_NET_ADMIN - // capability may send or listen to a netlink multicast group." - if !t.HasCapability(linux.CAP_NET_ADMIN) { - return syserr.ErrPermissionDenied - } return nil } // preconditions: the netlink multicast table is locked. -func (s *Socket) joinGroups(t *kernel.Task, groups uint64) *syserr.Error { +func (s *Socket) joinGroups(groups uint64) *syserr.Error { if groups&supportedGroups != groups { return syserr.ErrNotSupported } - if err := s.checkMcastSupport(t); err != nil { + if err := s.checkMcastSupport(); err != nil { return err } @@ -358,7 +357,7 @@ func (s *Socket) joinGroups(t *kernel.Task, groups uint64) *syserr.Error { } // preconditions: the netlink multicast table is locked. -func (s *Socket) joinGroup(t *kernel.Task, group uint32) *syserr.Error { +func (s *Socket) joinGroup(group uint32) *syserr.Error { if group == 0 || group > 64 { return syserr.ErrInvalidArgument } @@ -366,7 +365,7 @@ func (s *Socket) joinGroup(t *kernel.Task, group uint32) *syserr.Error { if groups&supportedGroups != groups { return syserr.ErrNotSupported } - if err := s.checkMcastSupport(t); err != nil { + if err := s.checkMcastSupport(); err != nil { return err } @@ -379,7 +378,7 @@ func (s *Socket) joinGroup(t *kernel.Task, group uint32) *syserr.Error { } // preconditions: the netlink multicast table is locked. -func (s *Socket) leaveGroup(t *kernel.Task, group uint32) *syserr.Error { +func (s *Socket) leaveGroup(group uint32) *syserr.Error { if group == 0 || group > 64 { return syserr.ErrInvalidArgument } @@ -387,7 +386,7 @@ func (s *Socket) leaveGroup(t *kernel.Task, group uint32) *syserr.Error { if groups&supportedGroups != groups { return syserr.ErrNotSupported } - if err := s.checkMcastSupport(t); err != nil { + if err := s.checkMcastSupport(); err != nil { return err } @@ -415,14 +414,8 @@ func (s *Socket) HandleInterfaceChangeEvent(ctx context.Context, idx int32, i in panic(fmt.Sprintf("Non-ROUTE netlink socket (protocol %d) cannot handle interface events", s.Protocol())) } - // s.portID is protected by s.mu. But we cannot take s.mu because it already may - // be held across sendMsg() -> Protocol.Receive() -> ProcessMessages() -> ... - // -> protocol.ProcessMessage() -> Socket.SendResponse(). The racy access to s.portID - // happens to be okay because once bound, a netlink socket's port ID is immutable. - // TODO(b/435491173): Stop holding s.mu across the chain above. - ms := nlmsg.NewMessageSet(s.portID, 0) + ms := nlmsg.NewMessageSet(s.GetPortID(), 0) routeProtocol.AddNewLinkMessage(ms, idx, i) - // TODO(b/456238795): Implement netlink ENOBUFS. s.SendResponse(ctx, ms) } @@ -433,14 +426,8 @@ func (s *Socket) HandleInterfaceDeleteEvent(ctx context.Context, idx int32, i in panic(fmt.Sprintf("Non-ROUTE netlink socket (protocol %d) cannot handle interface events", s.Protocol())) } - // s.portID is protected by s.mu. But we cannot take s.mu because it already may - // be held across sendMsg() -> Protocol.Receive() -> ProcessMessages() -> ... - // -> protocol.ProcessMessage() -> Socket.SendResponse(). The racy access to s.portID - // happens to be okay because once bound, a netlink socket's port ID is immutable. - // TODO(b/435491173): Stop holding s.mu across the chain above. - ms := nlmsg.NewMessageSet(s.portID, 0) + ms := nlmsg.NewMessageSet(s.GetPortID(), 0) routeProtocol.AddDelLinkMessage(ms, idx, i) - // TODO(b/456238795): Implement netlink ENOBUFS. s.SendResponse(ctx, ms) } @@ -454,7 +441,7 @@ func (s *Socket) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error { if a.Groups != 0 { var err *syserr.Error s.netns.NetlinkMcastTable().WithTableLocked(func() { - err = s.joinGroups(t, uint64(a.Groups)) + err = s.joinGroups(uint64(a.Groups)) }) if err != nil { return err @@ -473,7 +460,7 @@ func (s *Socket) Connect(t *kernel.Task, sockaddr []byte, blocking bool) *syserr return err } - // No support for sending to destination multicast groups yet. + // No support for connecting to destination multicast groups yet. if a.Groups != 0 { return syserr.ErrPermissionDenied } @@ -685,7 +672,7 @@ func (s *Socket) SetSockOpt(t *kernel.Task, level int, name int, opt []byte) *sy group := hostarch.ByteOrder.Uint32(opt) var err *syserr.Error s.netns.NetlinkMcastTable().WithTableLocked(func() { - err = s.joinGroup(t, group) + err = s.joinGroup(group) }) return err @@ -696,7 +683,7 @@ func (s *Socket) SetSockOpt(t *kernel.Task, level int, name int, opt []byte) *sy group := hostarch.ByteOrder.Uint32(opt) var err *syserr.Error s.netns.NetlinkMcastTable().WithTableLocked(func() { - err = s.leaveGroup(t, group) + err = s.leaveGroup(group) }) return err @@ -757,6 +744,7 @@ func (s *Socket) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, have r := unix.EndpointReader{ Ctx: t, Endpoint: s.ep, + Creds: s.Passcred(), Peek: flags&linux.MSG_PEEK != 0, } @@ -785,7 +773,10 @@ func (s *Socket) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, have if trunc { n = int64(r.MsgSize) } - return int(n), mflags, from, fromLen, socket.ControlMessages{}, syserr.FromError(err) + if srcPort, err := strconv.ParseUint(r.From.Addr, 10, 32); err == nil { + from.PortID = uint32(srcPort) + } + return int(n), mflags, from, fromLen, socket.ControlMessages{Unix: r.Control}, syserr.FromError(err) } // We'll have to block. Register for notification and keep trying to @@ -805,7 +796,10 @@ func (s *Socket) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, have if trunc { n = int64(r.MsgSize) } - return int(n), mflags, from, fromLen, socket.ControlMessages{}, syserr.FromError(err) + if srcPort, err := strconv.ParseUint(r.From.Addr, 10, 32); err == nil { + from.PortID = uint32(srcPort) + } + return int(n), mflags, from, fromLen, socket.ControlMessages{Unix: r.Control}, syserr.FromError(err) } if err := t.BlockWithDeadline(ch, haveDeadline, deadline); err != nil { @@ -837,31 +831,35 @@ func (kernelSCM) Credentials(*kernel.Task) (kernel.ThreadID, auth.UID, auth.GID) // kernelCreds is the concrete version of kernelSCM used in all creds. var kernelCreds = &kernelSCM{} -// SendResponse sends the response messages in ms back to userspace. +func (s *Socket) sendBufs(ctx context.Context, bufs [][]byte, cms transport.ControlMessages, srcPort uint32) *syserr.Error { + from := transport.Address{Addr: strconv.FormatUint(uint64(srcPort), 10)} + _, notify, err := s.connection.Send(ctx, bufs, cms, from) + // If the buffer is full, we simply drop messages, just like Linux. + // TODO(b/456238795): Implement netlink ENOBUFS. + if err != nil && err != syserr.ErrWouldBlock { + return err + } + if notify { + s.connection.SendNotify() + } + return nil +} + +// SendResponse sends the kernel's response messages in ms back to userspace. func (s *Socket) SendResponse(ctx context.Context, ms *nlmsg.MessageSet) *syserr.Error { // Linux combines multiple netlink messages into a single datagram. bufs := make([][]byte, 0, len(ms.Messages)) for _, m := range ms.Messages { bufs = append(bufs, m.Finalize()) } - - // All messages are from the kernel. cms := transport.ControlMessages{ Credentials: kernelCreds, } if len(bufs) > 0 { - // RecvMsg never receives the address, so we don't need to send - // one. - _, notify, err := s.connection.Send(ctx, bufs, cms, transport.Address{}) - // If the buffer is full, we simply drop messages, just like - // Linux. - if err != nil && err != syserr.ErrWouldBlock { + if err := s.sendBufs(ctx, bufs, cms, 0 /* srcPort */); err != nil { return err } - if notify { - s.connection.SendNotify() - } } // N.B. multi-part messages should still send NLMSG_DONE even if @@ -880,13 +878,9 @@ func (s *Socket) SendResponse(ctx context.Context, ms *nlmsg.MessageSet) *syserr // Add the dump_done_errno payload. m.Put(primitive.AllocateInt64(0)) - _, notify, err := s.connection.Send(ctx, [][]byte{m.Finalize()}, cms, transport.Address{}) - if err != nil && err != syserr.ErrWouldBlock { + if err := s.sendBufs(ctx, [][]byte{m.Finalize()}, cms, 0 /* srcPort */); err != nil { return err } - if notify { - s.connection.SendNotify() - } } return nil @@ -933,7 +927,7 @@ func (s *Socket) ProcessMessages(ctx context.Context, buf []byte) *syserr.Error continue } - ms := nlmsg.NewMessageSet(s.portID, hdr.Seq) + ms := nlmsg.NewMessageSet(s.GetPortID(), hdr.Seq) if err := s.protocol.ProcessMessage(ctx, s, msg, ms); err != nil { DumpErrorMessage(hdr, ms, err) } else if hdr.Flags&linux.NLM_F_ACK == linux.NLM_F_ACK { @@ -951,6 +945,7 @@ func (s *Socket) ProcessMessages(ctx context.Context, buf []byte) *syserr.Error // sendMsg is the core of message send, used for SendMsg and Write. func (s *Socket) sendMsg(ctx context.Context, src usermem.IOSequence, to []byte, flags int, controlMessages socket.ControlMessages) (int, *syserr.Error) { dstPort := int32(0) + dstGroup := 0 if len(to) != 0 { a, err := ExtractSockAddr(to) @@ -958,9 +953,16 @@ func (s *Socket) sendMsg(ctx context.Context, src usermem.IOSequence, to []byte, return 0, err } - // No support for multicast groups yet. if a.Groups != 0 { - return 0, syserr.ErrPermissionDenied + // man 7 netlink: "Since Linux 2.6.13, messages can't be broadcast to multiple groups" + // So we pick one like Linux. + dstGroup = bits.TrailingZeros32(a.Groups) + 1 + if err := s.checkMcastSupport(); err != nil { + return 0, err + } + if !kernel.TaskFromContext(ctx).HasCapability(linux.CAP_NET_ADMIN) { + return 0, syserr.ErrNotPermitted + } } dstPort = int32(a.PortID) @@ -973,14 +975,13 @@ func (s *Socket) sendMsg(ctx context.Context, src usermem.IOSequence, to []byte, } s.mu.Lock() - defer s.mu.Unlock() - // An unbound socket has a port ID defaulted to 0. If it has not yet been bound, // bind it to assign it a unique port ID. // From net/netlink/af_netlink.c:netlink_sendmsg if !s.bound { s.bindPort(kernel.TaskFromContext(ctx), 0) } + s.mu.Unlock() // For simplicity, and consistency with Linux, we copy in the entire // message up front. @@ -1001,7 +1002,18 @@ func (s *Socket) sendMsg(ctx context.Context, src usermem.IOSequence, to []byte, return 0, syserr.FromError(err) } - // Because we ensure the kernel is the only valid destination, + if dstGroup != 0 { + s.netns.NetlinkMcastTable().ForEachMcastSock(s.Protocol(), dstGroup, func(ns inet.NetlinkSocket) { + if ns.(*Socket) == s { + return + } + cms := transport.ControlMessages{ + Credentials: control.MakeCreds(kernel.TaskFromContext(ctx)), + } + ns.(*Socket).sendBufs(ctx, [][]byte{buf}, cms, uint32(s.GetPortID())) + }) + } + // Because we ensure the kernel is the only valid port destination, // we can start processing here. if err := s.protocol.Receive(ctx, s, buf); err != nil { return 0, err @@ -1027,5 +1039,7 @@ func (s *Socket) Type() (family int, skType linux.SockType, protocol int) { // GetPortID returns the port ID of the NETLINK socket. func (s *Socket) GetPortID() int32 { + s.mu.Lock() + defer s.mu.Unlock() return s.portID } diff --git a/test/syscalls/linux/socket_netlink_route.cc b/test/syscalls/linux/socket_netlink_route.cc index 86a438e6c4..f46cb11305 100644 --- a/test/syscalls/linux/socket_netlink_route.cc +++ b/test/syscalls/linux/socket_netlink_route.cc @@ -2167,6 +2167,263 @@ TEST(NetlinkRouteTest, LinkMulticastGroupEnobufs) { SyscallFailsWithErrno(ENOBUFS)); } +// Tests the ability of a userspace process to simultaneously send an +// RTM_SETLINK message to both the kernel and the userland RTMGRP_LINK +// subscribers. +TEST(NetlinkRouteTest, LinkMulticastGroupUserToUserSend) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_ADMIN))); + SKIP_IF(IsRunningWithHostinet()); + // TODO(gvisor.dev/issue/4595): enable cooperative save tests. + const DisableSave ds; + + const struct sockaddr_nl recv_addr = { + .nl_family = AF_NETLINK, + .nl_groups = RTMGRP_LINK, + }; + FileDescriptor nlsk_recv = + ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_ROUTE, &recv_addr)); + int passcred_val = 1; + ASSERT_THAT(setsockopt(nlsk_recv.get(), SOL_SOCKET, SO_PASSCRED, + &passcred_val, sizeof(passcred_val)), + SyscallSucceeds()); + + FileDescriptor nlsk_send = + ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_ROUTE)); + struct sockaddr_nl nlsk_send_addr = {}; + socklen_t addr_len = sizeof(nlsk_send_addr); + ASSERT_THAT(getsockname(nlsk_send.get(), (struct sockaddr*)&nlsk_send_addr, + &addr_len), + SyscallSucceeds()); + + // Send an RTM_SETLINK message to change the MTU of the loopback + // interface. + struct sockaddr_nl send_addr = { + .nl_family = AF_NETLINK, + .nl_pid = 0, // Send to the kernel. + .nl_groups = RTMGRP_LINK, // and also send to userland subscribers. + }; + const Link link = ASSERT_NO_ERRNO_AND_VALUE(LoopbackLink()); + MtuRequest mtu_request = GetMtuRequest(link, RTM_SETLINK, link.mtu + 10); + mtu_request.hdr.nlmsg_pid = getpid(); + struct iovec iov = { + .iov_base = &mtu_request, + .iov_len = sizeof(mtu_request), + }; + struct msghdr msg = { + .msg_name = &send_addr, + .msg_namelen = sizeof(send_addr), + .msg_iov = &iov, + .msg_iovlen = 1, + }; + ASSERT_THAT(RetryEINTR(sendmsg)(nlsk_send.get(), &msg, 0), SyscallSucceeds()); + + // And verify we received a response from the kernel. + constexpr int kPollTimeoutMs = 1000; + struct pollfd pfd = {.fd = nlsk_send.get(), .events = POLLIN}; + ASSERT_EQ(RetryEINTR(poll)(&pfd, 1, kPollTimeoutMs), 1) + << "nlsk_send: Did not get the expected unicast from the kernel."; + bool got_msg = false; + ASSERT_NO_ERRNO(NetlinkResponse( + nlsk_send, + [&](const struct nlmsghdr* hdr) { + EXPECT_EQ(NLMSG_ERROR, hdr->nlmsg_type); + EXPECT_EQ(hdr->nlmsg_seq, kSeq); + EXPECT_GE(hdr->nlmsg_len, sizeof(*hdr) + sizeof(struct nlmsgerr)); + const struct nlmsgerr* msg = + reinterpret_cast(NLMSG_DATA(hdr)); + ASSERT_EQ(msg->error, 0); + got_msg = true; + }, + /*expect_nlmsgerr=*/true)); + ASSERT_TRUE(got_msg) << "Did not get a response from the kernel."; + + // Now that we know the kernel has processed the message, setup a cleanup. + auto restore_mtu = Cleanup([&]() { + MtuRequest mtu_request = GetMtuRequest(link, RTM_SETLINK, link.mtu); + ASSERT_NO_ERRNO(NetlinkRequestAckOrError(nlsk_send, kSeq, &mtu_request, + sizeof(mtu_request))); + }); + + // Verify that the RTM_SETLINK we sent was also received by the RTMGRP_LINK + // subscriber. We also expect an RTM_NEWLINK message from the kernel. Hence + // the two attempts. + bool got_user_msg = false; + for (int i = 0; i < 2; ++i) { + struct pollfd pfd2 = {.fd = nlsk_recv.get(), .events = POLLIN}; + if (RetryEINTR(poll)(&pfd2, 1, kPollTimeoutMs) < 1) break; + + char control[CMSG_SPACE(sizeof(struct ucred))] = {}; + constexpr size_t kBufferSize = 4096; + std::vector buf(kBufferSize); + iov = { + .iov_base = buf.data(), + .iov_len = buf.size(), + }; + struct sockaddr_nl from_addr = {}; + msg = { + .msg_name = &from_addr, + .msg_namelen = sizeof(from_addr), + .msg_iov = &iov, + .msg_iovlen = 1, + .msg_control = control, + .msg_controllen = sizeof(control), + }; + + int len; + ASSERT_THAT(len = RetryEINTR(recvmsg)(nlsk_recv.get(), &msg, 0), + SyscallSucceeds()); + ASSERT_NE(msg.msg_flags & MSG_TRUNC, + MSG_TRUNC); // The buf we gave was big. + + struct nlmsghdr* hdr = reinterpret_cast(buf.data()); + for (; NLMSG_OK(hdr, len); hdr = NLMSG_NEXT(hdr, len)) { + ASSERT_TRUE(NLMSG_OK(hdr, len)); + // Ignore the kernel's message. + if (hdr->nlmsg_type == RTM_NEWLINK) { + EXPECT_EQ(from_addr.nl_pid, 0); + continue; + } + + ASSERT_EQ(hdr->nlmsg_type, RTM_SETLINK); + ASSERT_EQ(hdr->nlmsg_pid, getpid()); + EXPECT_EQ(from_addr.nl_family, AF_NETLINK); + EXPECT_EQ(from_addr.nl_pid, nlsk_send_addr.nl_pid); + + struct ucred creds; + struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg); + ASSERT_NE(cmsg, nullptr); + ASSERT_EQ(cmsg->cmsg_len, CMSG_LEN(sizeof(creds))); + ASSERT_EQ(cmsg->cmsg_level, SOL_SOCKET); + ASSERT_EQ(cmsg->cmsg_type, SCM_CREDENTIALS); + memcpy(&creds, CMSG_DATA(cmsg), sizeof(creds)); + EXPECT_EQ(creds.pid, getpid()); + got_user_msg = true; + } + } + EXPECT_TRUE(got_user_msg) << "RTMGRP_LINK subscriber did not receive the " + "expected RTM_SETLINK message."; +} + +TEST(NetlinkRouteTest, LinkMulticastGroupCapNetAdmin) { + SKIP_IF(IsRunningWithHostinet()); + // TODO(gvisor.dev/issue/4595): enable cooperative save tests. + const DisableSave ds; + AutoCapability cap(CAP_NET_ADMIN, false); + + FileDescriptor nlsk = + ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_ROUTE)); + struct sockaddr_nl send_addr = { + .nl_family = AF_NETLINK, + .nl_pid = 0, + .nl_groups = RTMGRP_LINK, + }; + + const Link link = ASSERT_NO_ERRNO_AND_VALUE(LoopbackLink()); + MtuRequest mtu_request = GetMtuRequest(link, RTM_SETLINK, link.mtu + 10); + struct iovec iov = { + .iov_base = &mtu_request, + .iov_len = sizeof(mtu_request), + }; + struct msghdr msg = { + .msg_name = &send_addr, + .msg_namelen = sizeof(send_addr), + .msg_iov = &iov, + .msg_iovlen = 1, + }; + + EXPECT_THAT(RetryEINTR(sendmsg)(nlsk.get(), &msg, 0), + SyscallFailsWithErrno(EPERM)); + + // But joining a multicast group should not require CAP_NET_ADMIN. + struct sockaddr_nl mcast_addr = { + .nl_family = AF_NETLINK, + .nl_groups = RTMGRP_LINK, + }; + ASSERT_NO_ERRNO(NetlinkBoundSocket(NETLINK_ROUTE, &mcast_addr)); +} + +TEST(NetlinkRouteTest, LinkMulticastGroupNoSelfBroadcast) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_ADMIN))); + SKIP_IF(IsRunningWithHostinet()); + // TODO(gvisor.dev/issue/4595): enable cooperative save tests. + const DisableSave ds; + + const struct sockaddr_nl linkgrp_addr = { + .nl_family = AF_NETLINK, + .nl_groups = RTMGRP_LINK, + }; + FileDescriptor nlsk_send = ASSERT_NO_ERRNO_AND_VALUE( + NetlinkBoundSocket(NETLINK_ROUTE, &linkgrp_addr)); + FileDescriptor nlsk_recv = ASSERT_NO_ERRNO_AND_VALUE( + NetlinkBoundSocket(NETLINK_ROUTE, &linkgrp_addr)); + + // Create an RTM_GETLINK request. + struct request { + struct nlmsghdr hdr; + struct ifinfomsg ifm; + }; + request req = {}; + req.hdr = { + .nlmsg_len = sizeof(req), + .nlmsg_type = RTM_GETLINK, + .nlmsg_flags = NLM_F_REQUEST | NLM_F_DUMP, + .nlmsg_seq = kSeq, + }; + req.ifm.ifi_family = AF_UNSPEC; + + // And send it to both the kernel and the RTMGRP_LINK group. + struct iovec iov = { + .iov_base = &req, + .iov_len = sizeof(req), + }; + struct sockaddr_nl send_addr = linkgrp_addr; + ASSERT_EQ(send_addr.nl_pid, 0); + struct msghdr msg = { + .msg_name = &send_addr, + .msg_namelen = sizeof(send_addr), + .msg_iov = &iov, + .msg_iovlen = 1, + }; + ASSERT_THAT(RetryEINTR(sendmsg)(nlsk_send.get(), &msg, 0), SyscallSucceeds()); + + // The sender should only receive the kernel's response, not its own + // broadcast. + constexpr int kPollTimeoutMs = 500; + bool got_response = false; + for (int i = 0; i < 2; ++i) { + struct pollfd pfd = {.fd = nlsk_send.get(), .events = POLLIN}; + if (RetryEINTR(poll)(&pfd, 1, kPollTimeoutMs) != 1) { + break; + } + ASSERT_NO_ERRNO(NetlinkResponse( + nlsk_send, + [&](const struct nlmsghdr* hdr) { + ASSERT_NE(hdr->nlmsg_type, RTM_GETLINK); // No self-broadcast. + if (hdr->nlmsg_type == RTM_NEWLINK) { + got_response = true; + } + }, + /*expect_nlmsgerr=*/false)); + } + EXPECT_TRUE(got_response) << "Did not get a response from the kernel."; + + // Whereas the bystander RTMGRP_LINK subscriber should get the broadcast. + bool got_broadcast = false; + { + struct pollfd pfd = {.fd = nlsk_recv.get(), .events = POLLIN}; + ASSERT_EQ(RetryEINTR(poll)(&pfd, 1, kPollTimeoutMs), 1) + << "RTMGRP_LINK subscriber did not get the broadcast."; + ASSERT_NO_ERRNO(NetlinkResponse( + nlsk_recv, + [&](const struct nlmsghdr* hdr) { + ASSERT_EQ(hdr->nlmsg_type, RTM_GETLINK); + got_broadcast = true; + }, + /*expect_nlmsgerr=*/false)); + } + EXPECT_TRUE(got_broadcast) << "Did not get the broadcast."; +} + } // namespace } // namespace testing