Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions pkg/sentry/inet/nlmcast.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,8 @@ func (m *McastTable) RemoveSocket(s NetlinkSocket) {
delete(m.socks[p], s)
}

func (m *McastTable) forEachMcastSock(protocol int, mcastGroup int, fn func(s NetlinkSocket)) {
// ForEachMcastSock calls fn on all Netlink sockets that are members of the given multicast group.
func (m *McastTable) ForEachMcastSock(protocol int, mcastGroup int, fn func(s NetlinkSocket)) {
m.mu.Lock()
defer m.mu.Unlock()
if _, ok := m.socks[protocol]; !ok {
Expand All @@ -122,15 +123,15 @@ func (m *McastTable) forEachMcastSock(protocol int, mcastGroup int, fn func(s Ne
// OnInterfaceChangeEvent implements InterfaceEventSubscriber.OnInterfaceChangeEvent.
func (m *McastTable) OnInterfaceChangeEvent(ctx context.Context, idx int32, i Interface) {
// Relay the event to RTNLGRP_LINK subscribers.
m.forEachMcastSock(routeProtocol, routeLinkMcastGroup, func(s NetlinkSocket) {
m.ForEachMcastSock(routeProtocol, routeLinkMcastGroup, func(s NetlinkSocket) {
s.HandleInterfaceChangeEvent(ctx, idx, i)
})
}

// OnInterfaceDeleteEvent implements InterfaceEventSubscriber.OnInterfaceDeleteEvent.
func (m *McastTable) OnInterfaceDeleteEvent(ctx context.Context, idx int32, i Interface) {
// Relay the event to RTNLGRP_LINK subscribers.
m.forEachMcastSock(routeProtocol, routeLinkMcastGroup, func(s NetlinkSocket) {
m.ForEachMcastSock(routeProtocol, routeLinkMcastGroup, func(s NetlinkSocket) {
s.HandleInterfaceDeleteEvent(ctx, idx, i)
})
}
Expand Down
1 change: 1 addition & 0 deletions pkg/sentry/socket/netlink/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ go_library(
"//pkg/sentry/kernel/auth",
"//pkg/sentry/ktime",
"//pkg/sentry/socket",
"//pkg/sentry/socket/control",
"//pkg/sentry/socket/netlink/nlmsg",
"//pkg/sentry/socket/netlink/port",
"//pkg/sentry/socket/unix",
Expand Down
8 changes: 4 additions & 4 deletions pkg/sentry/socket/netlink/netfilter/protocol.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ func (p *Protocol) Receive(ctx context.Context, s *netlink.Socket, buf []byte) *
// TODO: b/434785410 - Support batch messages.
if hdr.Type == linux.NFNL_MSG_BATCH_BEGIN {
ms := nlmsg.NewMessageSet(s.GetPortID(), hdr.Seq)
if err := p.receiveBatchMessage(ctx, s, ms, buf); err != nil {
if err := p.receiveBatchMessage(ctx, ms, buf); err != nil {
log.Debugf("Nftables: Failed to process batch message: %v", err)
netlink.DumpErrorMessage(hdr, ms, err.GetError())
}
Expand Down Expand Up @@ -1215,7 +1215,7 @@ func (p *Protocol) ProcessMessage(ctx context.Context, s *netlink.Socket, msg *n
}

// receiveBatchMessage processes a NETFILTER batch message.
func (p *Protocol) receiveBatchMessage(ctx context.Context, s *netlink.Socket, ms *nlmsg.MessageSet, buf []byte) *syserr.AnnotatedError {
func (p *Protocol) receiveBatchMessage(ctx context.Context, ms *nlmsg.MessageSet, buf []byte) *syserr.AnnotatedError {
// Linux ignores messages that are too small.
// From net/netfilter/nfnetlink.c:nfnetlink_rcv_skb_batch
if len(buf) < linux.NetlinkMessageHeaderSize+linux.SizeOfNetfilterGenMsg {
Expand Down Expand Up @@ -1254,7 +1254,7 @@ func (p *Protocol) receiveBatchMessage(ctx context.Context, s *netlink.Socket, m
// The resource ID is a 16-bit value that is stored in network byte order.
// We ensure that it is in host byte order before passing it for processing.
resID := nlmsg.NetToHostU16(nfGenMsg.ResourceID)
if err := p.processBatchMessage(ctx, s, buf, ms, hdr, resID); err != nil {
if err := p.processBatchMessage(ctx, buf, ms, hdr, resID); err != nil {
log.Debugf("Failed to process batch message: %v", err)
netlink.DumpErrorMessage(hdr, ms, err.GetError())
}
Expand All @@ -1263,7 +1263,7 @@ func (p *Protocol) receiveBatchMessage(ctx context.Context, s *netlink.Socket, m
}

// processBatchMessage processes a batch message.
func (p *Protocol) processBatchMessage(ctx context.Context, s *netlink.Socket, buf []byte, ms *nlmsg.MessageSet, batchHdr linux.NetlinkMessageHeader, subsysID uint16) *syserr.AnnotatedError {
func (p *Protocol) processBatchMessage(ctx context.Context, buf []byte, ms *nlmsg.MessageSet, batchHdr linux.NetlinkMessageHeader, subsysID uint16) *syserr.AnnotatedError {
if subsysID >= linux.NFNL_SUBSYS_COUNT {
return syserr.NewAnnotatedError(syserr.ErrInvalidArgument, fmt.Sprintf("Nftables: Unknown subsystem id %d", subsysID))
}
Expand Down
Loading
Loading