From 8b0e1620109a51f06306a0f2643d61fa8f1d5631 Mon Sep 17 00:00:00 2001 From: Amaindex Date: Wed, 13 May 2026 23:45:02 -0700 Subject: [PATCH] netstack: let ICMP echo handlers consume requests IPv4 echo requests are delivered to the transport dispatcher before netstack builds the in-stack echo reply, but a per-stack default handler cannot tell the ICMP endpoint that it has taken ownership of the request. This leaves embedders that install a custom ICMP handler with no direct way to suppress the built-in reply. Add a transport dispatcher extension that reports whether the per-stack default transport handler consumed the packet, and use that signal in the IPv4 and IPv6 ICMP echo paths. For IPv6, deliver echo requests through the same endpoint/default-handler path before deciding whether to synthesize the built-in reply. Keep ordinary endpoint delivery distinct from default-handler ownership so registered ICMP endpoints and raw sockets can observe echo requests without suppressing the normal reply path. This also preserves the IPv4 temporary-address behavior introduced by the earlier echo-handling fix. Signed-off-by: Zi Li Signed-off-by: Amaindex --- pkg/tcpip/network/ipv4/icmp.go | 19 +- pkg/tcpip/network/ipv4/ipv4_test.go | 313 ++++++++++++++++++++++++++++ pkg/tcpip/network/ipv6/icmp.go | 27 ++- pkg/tcpip/network/ipv6/icmp_test.go | 227 ++++++++++++++++++++ pkg/tcpip/stack/nic.go | 27 ++- pkg/tcpip/stack/registration.go | 15 ++ 6 files changed, 612 insertions(+), 16 deletions(-) diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go index 980fade11c..d8baf462d5 100644 --- a/pkg/tcpip/network/ipv4/icmp.go +++ b/pkg/tcpip/network/ipv4/icmp.go @@ -360,14 +360,21 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) { localAddressTemporary := pkt.NetworkPacketInfo.LocalAddressTemporary localAddressBroadcast := pkt.NetworkPacketInfo.LocalAddressBroadcast - // It's possible that a raw socket or custom defaultHandler expects to - // receive this packet. - e.dispatcher.DeliverTransportPacket(header.ICMPv4ProtocolNumber, pkt) + // It's possible that a raw socket or per-stack default handler expects + // to receive this packet. + defaultHandlerHandled := false + if dispatcher, ok := e.dispatcher.(stack.TransportDispatcherWithDefaultHandlerResult); ok { + _, defaultHandlerHandled = dispatcher.DeliverTransportPacketWithDefaultHandlerResult(header.ICMPv4ProtocolNumber, pkt) + } else { + e.dispatcher.DeliverTransportPacket(header.ICMPv4ProtocolNumber, pkt) + } pkt = nil - // Skip direct ICMP echo reply if the packet was received with a temporary - // address, allowing custom handlers to take over. - if localAddressTemporary { + // Skip the built-in ICMP echo reply if the request was consumed by a + // per-stack default handler. Also preserve the IPv4 behavior for + // temporary local addresses: the packet is delivered above, but the + // stack does not synthesize an echo reply for it. + if defaultHandlerHandled || localAddressTemporary { return } diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go index ec1a0e3d63..df69487c29 100644 --- a/pkg/tcpip/network/ipv4/ipv4_test.go +++ b/pkg/tcpip/network/ipv4/ipv4_test.go @@ -3886,6 +3886,319 @@ func TestCloseLocking(t *testing.T) { }() } +func TestICMPEchoDefaultHandlerControlsReply(t *testing.T) { + var ( + localAddr = tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: tcpip.AddrFromSlice(net.ParseIP("192.168.0.1").To4()), + PrefixLen: 24, + }, + } + remoteAddr = tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: tcpip.AddrFromSlice(net.ParseIP("192.168.0.2").To4()), + PrefixLen: 24, + }, + } + ) + + tests := []struct { + name string + installHandler bool + handled bool + wantHandlerCalled bool + wantReply bool + }{ + { + name: "no default handler", + wantReply: true, + }, + { + name: "default handler handled", + installHandler: true, + handled: true, + wantHandlerCalled: true, + wantReply: false, + }, + { + name: "default handler not handled", + installHandler: true, + handled: false, + wantHandlerCalled: true, + wantReply: true, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + clock := faketime.NewManualClock() + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol4}, + Clock: clock, + }) + defer func() { + s.Close() + s.Wait() + refs.DoRepeatedLeakCheck() + }() + + const ident = 1234 + handlerCalled := false + if test.installHandler { + s.SetTransportProtocolHandler(icmp.ProtocolNumber4, func(id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool { + handlerCalled = true + if got := id.LocalPort; got != ident { + t.Errorf("got id.LocalPort = %d, want = %d", got, ident) + } + return test.handled + }) + } + + e := channel.New(1, defaultMTU, "") + defer e.Close() + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err) + } + if err := s.AddProtocolAddress(nicID, localAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", nicID, localAddr, err) + } + s.SetRouteTable([]tcpip.Route{{ + Destination: localAddr.AddressWithPrefix.Subnet(), + NIC: nicID, + }}) + + totalLength := header.IPv4MinimumSize + header.ICMPv4MinimumSize + hdr := prependable.New(totalLength) + icmpH := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize)) + icmpH.SetIdent(ident) + icmpH.SetType(header.ICMPv4Echo) + icmpH.SetCode(header.ICMPv4UnusedCode) + icmpH.SetChecksum(0) + icmpH.SetChecksum(^checksum.Checksum(icmpH, 0)) + ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize)) + ip.Encode(&header.IPv4Fields{ + TotalLength: uint16(totalLength), + Protocol: uint8(icmp.ProtocolNumber4), + TTL: ipv4.DefaultTTL, + SrcAddr: remoteAddr.AddressWithPrefix.Address, + DstAddr: localAddr.AddressWithPrefix.Address, + }) + ip.SetChecksum(^ip.CalculateChecksum()) + echoPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Payload: buffer.MakeWithData(hdr.View()), + }) + e.InjectInbound(header.IPv4ProtocolNumber, echoPkt) + echoPkt.DecRef() + clock.RunImmediatelyScheduledJobs() + + if got, want := handlerCalled, test.wantHandlerCalled; got != want { + t.Fatalf("got handlerCalled = %t, want = %t", got, want) + } + + p := e.Read() + if !test.wantReply { + if p != nil { + p.DecRef() + t.Fatalf("got unexpected ICMP echo reply") + } + return + } + if p == nil { + t.Fatalf("expected ICMP echo reply") + } + defer p.DecRef() + payload := stack.PayloadSince(p.NetworkHeader()) + defer payload.Release() + checker.IPv4(t, payload, + checker.SrcAddr(localAddr.AddressWithPrefix.Address), + checker.DstAddr(remoteAddr.AddressWithPrefix.Address), + checker.ICMPv4( + checker.ICMPv4Type(header.ICMPv4EchoReply), + checker.ICMPv4Code(header.ICMPv4UnusedCode))) + }) + } +} + +func TestICMPEchoRegisteredEndpointDoesNotSuppressReply(t *testing.T) { + localAddr := tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: tcpip.AddrFromSlice(net.ParseIP("192.168.0.1").To4()), + PrefixLen: 24, + }, + } + remoteAddr := tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: tcpip.AddrFromSlice(net.ParseIP("192.168.0.2").To4()), + PrefixLen: 24, + }, + } + + clock := faketime.NewManualClock() + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol4}, + Clock: clock, + }) + defer func() { + s.Close() + s.Wait() + refs.DoRepeatedLeakCheck() + }() + + e := channel.New(1, defaultMTU, "") + defer e.Close() + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err) + } + if err := s.AddProtocolAddress(nicID, localAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", nicID, localAddr, err) + } + s.SetRouteTable([]tcpip.Route{{ + Destination: localAddr.AddressWithPrefix.Subnet(), + NIC: nicID, + }}) + + const ident = 1234 + var wq waiter.Queue + ep, err := s.NewEndpoint(icmp.ProtocolNumber4, ipv4.ProtocolNumber, &wq) + if err != nil { + t.Fatalf("s.NewEndpoint(%d, %d, _) = %s", icmp.ProtocolNumber4, ipv4.ProtocolNumber, err) + } + defer ep.Close() + if err := ep.Bind(tcpip.FullAddress{Addr: localAddr.AddressWithPrefix.Address, Port: ident}); err != nil { + t.Fatalf("ep.Bind(...) = %s", err) + } + + handlerCalled := false + s.SetTransportProtocolHandler(icmp.ProtocolNumber4, func(stack.TransportEndpointID, *stack.PacketBuffer) bool { + handlerCalled = true + return true + }) + + totalLength := header.IPv4MinimumSize + header.ICMPv4MinimumSize + hdr := prependable.New(totalLength) + icmpH := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize)) + icmpH.SetIdent(ident) + icmpH.SetType(header.ICMPv4Echo) + icmpH.SetCode(header.ICMPv4UnusedCode) + icmpH.SetChecksum(0) + icmpH.SetChecksum(^checksum.Checksum(icmpH, 0)) + ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize)) + ip.Encode(&header.IPv4Fields{ + TotalLength: uint16(totalLength), + Protocol: uint8(icmp.ProtocolNumber4), + TTL: ipv4.DefaultTTL, + SrcAddr: remoteAddr.AddressWithPrefix.Address, + DstAddr: localAddr.AddressWithPrefix.Address, + }) + ip.SetChecksum(^ip.CalculateChecksum()) + echoPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Payload: buffer.MakeWithData(hdr.View()), + }) + e.InjectInbound(header.IPv4ProtocolNumber, echoPkt) + echoPkt.DecRef() + clock.RunImmediatelyScheduledJobs() + + if handlerCalled { + t.Fatalf("default handler was unexpectedly called") + } + + p := e.Read() + if p == nil { + t.Fatalf("expected ICMP echo reply") + } + defer p.DecRef() + payload := stack.PayloadSince(p.NetworkHeader()) + defer payload.Release() + checker.IPv4(t, payload, + checker.SrcAddr(localAddr.AddressWithPrefix.Address), + checker.DstAddr(remoteAddr.AddressWithPrefix.Address), + checker.ICMPv4( + checker.ICMPv4Type(header.ICMPv4EchoReply), + checker.ICMPv4Code(header.ICMPv4UnusedCode))) +} + +func TestICMPEchoTemporaryAddressSuppressesReply(t *testing.T) { + assignedAddr := tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: tcpip.AddrFromSlice(net.ParseIP("192.168.0.1").To4()), + PrefixLen: 24, + }, + } + temporaryAddr := tcpip.AddrFromSlice(net.ParseIP("192.168.0.99").To4()) + remoteAddr := tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: tcpip.AddrFromSlice(net.ParseIP("192.168.0.2").To4()), + PrefixLen: 24, + }, + } + + clock := faketime.NewManualClock() + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol4}, + Clock: clock, + }) + defer func() { + s.Close() + s.Wait() + refs.DoRepeatedLeakCheck() + }() + + e := channel.New(1, defaultMTU, "") + defer e.Close() + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err) + } + if err := s.AddProtocolAddress(nicID, assignedAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", nicID, assignedAddr, err) + } + if err := s.SetPromiscuousMode(nicID, true); err != nil { + t.Fatalf("s.SetPromiscuousMode(%d, true): %s", nicID, err) + } + s.SetRouteTable([]tcpip.Route{{ + Destination: assignedAddr.AddressWithPrefix.Subnet(), + NIC: nicID, + }}) + + const ident = 1234 + totalLength := header.IPv4MinimumSize + header.ICMPv4MinimumSize + hdr := prependable.New(totalLength) + icmpH := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize)) + icmpH.SetIdent(ident) + icmpH.SetType(header.ICMPv4Echo) + icmpH.SetCode(header.ICMPv4UnusedCode) + icmpH.SetChecksum(0) + icmpH.SetChecksum(^checksum.Checksum(icmpH, 0)) + ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize)) + ip.Encode(&header.IPv4Fields{ + TotalLength: uint16(totalLength), + Protocol: uint8(icmp.ProtocolNumber4), + TTL: ipv4.DefaultTTL, + SrcAddr: remoteAddr.AddressWithPrefix.Address, + DstAddr: temporaryAddr, + }) + ip.SetChecksum(^ip.CalculateChecksum()) + echoPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Payload: buffer.MakeWithData(hdr.View()), + }) + e.InjectInbound(header.IPv4ProtocolNumber, echoPkt) + echoPkt.DecRef() + clock.RunImmediatelyScheduledJobs() + + if p := e.Read(); p != nil { + p.DecRef() + t.Fatalf("got unexpected ICMP echo reply") + } +} + func TestIcmpRateLimit(t *testing.T) { var ( host1IPv4Addr = tcpip.ProtocolAddress{ diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go index eb85f4505c..b921d9d8eb 100644 --- a/pkg/tcpip/network/ipv6/icmp.go +++ b/pkg/tcpip/network/ipv6/icmp.go @@ -654,6 +654,27 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r case header.ICMPv6EchoRequest: received.echoRequest.Increment() + replyPayload := pkt.Data().ToBuffer() + replyHeader := make([]byte, header.ICMPv6EchoMinimumSize) + copy(replyHeader, h[:header.ICMPv6EchoMinimumSize]) + + // It's possible that a raw socket or per-stack default handler expects + // to receive this packet. + defaultHandlerHandled := false + if dispatcher, ok := e.dispatcher.(stack.TransportDispatcherWithDefaultHandlerResult); ok { + _, defaultHandlerHandled = dispatcher.DeliverTransportPacketWithDefaultHandlerResult(header.ICMPv6ProtocolNumber, pkt) + } else { + e.dispatcher.DeliverTransportPacket(header.ICMPv6ProtocolNumber, pkt) + } + pkt = nil + + // Skip the built-in ICMP echo reply if the request was consumed by a + // per-stack default handler. + if defaultHandlerHandled { + replyPayload.Release() + return + } + // As per RFC 4291 section 2.7, multicast addresses must not be used as // source addresses in IPv6 packets. localAddr := dstAddr @@ -664,23 +685,25 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r r, err := e.protocol.stack.FindRoute(e.nic.ID(), localAddr, srcAddr, ProtocolNumber, false /* multicastLoop */) if err != nil { // If we cannot find a route to the destination, silently drop the packet. + replyPayload.Release() return } defer r.Release() if !e.protocol.allowICMPReply(header.ICMPv6EchoReply) { sent.rateLimited.Increment() + replyPayload.Release() return } replyPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ ReserveHeaderBytes: int(r.MaxHeaderLength()) + header.ICMPv6EchoMinimumSize, - Payload: pkt.Data().ToBuffer(), + Payload: replyPayload, }) defer replyPkt.DecRef() icmp := header.ICMPv6(replyPkt.TransportHeader().Push(header.ICMPv6EchoMinimumSize)) replyPkt.TransportProtocolNumber = header.ICMPv6ProtocolNumber - copy(icmp, h) + copy(icmp, replyHeader) icmp.SetType(header.ICMPv6EchoReply) replyData := replyPkt.Data() icmp.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go index dc4fc3b1f7..47de440d3e 100644 --- a/pkg/tcpip/network/ipv6/icmp_test.go +++ b/pkg/tcpip/network/ipv6/icmp_test.go @@ -229,6 +229,233 @@ func (c *testContext) cleanup() { refs.DoRepeatedLeakCheck() } +func TestICMPEchoDefaultHandlerControlsReply(t *testing.T) { + var ( + localAddr = tcpip.ProtocolAddress{ + Protocol: ProtocolNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: tcpip.AddrFromSlice(net.ParseIP("a::1").To16()), + PrefixLen: 64, + }, + } + remoteAddr = tcpip.ProtocolAddress{ + Protocol: ProtocolNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: tcpip.AddrFromSlice(net.ParseIP("a::2").To16()), + PrefixLen: 64, + }, + } + ) + + tests := []struct { + name string + installHandler bool + handled bool + wantHandlerCalled bool + wantReply bool + }{ + { + name: "no default handler", + wantReply: true, + }, + { + name: "default handler handled", + installHandler: true, + handled: true, + wantHandlerCalled: true, + wantReply: false, + }, + { + name: "default handler not handled", + installHandler: true, + handled: false, + wantHandlerCalled: true, + wantReply: true, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + c := newTestContext() + defer c.cleanup() + s := c.s + + const ident = 1234 + handlerCalled := false + if test.installHandler { + s.SetTransportProtocolHandler(icmp.ProtocolNumber6, func(id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool { + handlerCalled = true + if got := id.LocalPort; got != ident { + t.Errorf("got id.LocalPort = %d, want = %d", got, ident) + } + return test.handled + }) + } + + e := channel.New(1, defaultMTU, "") + defer e.Close() + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err) + } + if err := s.AddProtocolAddress(nicID, localAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", nicID, localAddr, err) + } + s.SetRouteTable([]tcpip.Route{{ + Destination: localAddr.AddressWithPrefix.Subnet(), + NIC: nicID, + }}) + + totalLen := header.IPv6MinimumSize + header.ICMPv6MinimumSize + hdr := prependable.New(totalLen) + icmpH := header.ICMPv6(hdr.Prepend(header.ICMPv6MinimumSize)) + icmpH.SetIdent(ident) + icmpH.SetType(header.ICMPv6EchoRequest) + icmpH.SetCode(header.ICMPv6UnusedCode) + icmpH.SetChecksum(0) + icmpH.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ + Header: icmpH, + Src: remoteAddr.AddressWithPrefix.Address, + Dst: localAddr.AddressWithPrefix.Address, + })) + ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + ip.Encode(&header.IPv6Fields{ + PayloadLength: header.ICMPv6MinimumSize, + TransportProtocol: icmp.ProtocolNumber6, + HopLimit: DefaultTTL, + SrcAddr: remoteAddr.AddressWithPrefix.Address, + DstAddr: localAddr.AddressWithPrefix.Address, + }) + echoPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Payload: buffer.MakeWithData(hdr.View()), + }) + e.InjectInbound(ProtocolNumber, echoPkt) + echoPkt.DecRef() + c.clock.RunImmediatelyScheduledJobs() + + if got, want := handlerCalled, test.wantHandlerCalled; got != want { + t.Fatalf("got handlerCalled = %t, want = %t", got, want) + } + + p := e.Read() + if !test.wantReply { + if p != nil { + p.DecRef() + t.Fatalf("got unexpected ICMP echo reply") + } + return + } + if p == nil { + t.Fatalf("expected ICMP echo reply") + } + defer p.DecRef() + payload := stack.PayloadSince(p.NetworkHeader()) + defer payload.Release() + checker.IPv6(t, payload, + checker.SrcAddr(localAddr.AddressWithPrefix.Address), + checker.DstAddr(remoteAddr.AddressWithPrefix.Address), + checker.ICMPv6( + checker.ICMPv6Type(header.ICMPv6EchoReply), + checker.ICMPv6Code(header.ICMPv6UnusedCode))) + }) + } +} + +func TestICMPEchoRegisteredEndpointDoesNotSuppressReply(t *testing.T) { + localAddr := tcpip.ProtocolAddress{ + Protocol: ProtocolNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: tcpip.AddrFromSlice(net.ParseIP("a::1").To16()), + PrefixLen: 64, + }, + } + remoteAddr := tcpip.ProtocolAddress{ + Protocol: ProtocolNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: tcpip.AddrFromSlice(net.ParseIP("a::2").To16()), + PrefixLen: 64, + }, + } + + c := newTestContext() + defer c.cleanup() + s := c.s + + e := channel.New(1, defaultMTU, "") + defer e.Close() + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err) + } + if err := s.AddProtocolAddress(nicID, localAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", nicID, localAddr, err) + } + s.SetRouteTable([]tcpip.Route{{ + Destination: localAddr.AddressWithPrefix.Subnet(), + NIC: nicID, + }}) + + const ident = 1234 + var wq waiter.Queue + ep, err := s.NewEndpoint(icmp.ProtocolNumber6, ProtocolNumber, &wq) + if err != nil { + t.Fatalf("s.NewEndpoint(%d, %d, _) = %s", icmp.ProtocolNumber6, ProtocolNumber, err) + } + defer ep.Close() + if err := ep.Bind(tcpip.FullAddress{Addr: localAddr.AddressWithPrefix.Address, Port: ident}); err != nil { + t.Fatalf("ep.Bind(...) = %s", err) + } + + handlerCalled := false + s.SetTransportProtocolHandler(icmp.ProtocolNumber6, func(stack.TransportEndpointID, *stack.PacketBuffer) bool { + handlerCalled = true + return true + }) + + totalLen := header.IPv6MinimumSize + header.ICMPv6MinimumSize + hdr := prependable.New(totalLen) + icmpH := header.ICMPv6(hdr.Prepend(header.ICMPv6MinimumSize)) + icmpH.SetIdent(ident) + icmpH.SetType(header.ICMPv6EchoRequest) + icmpH.SetCode(header.ICMPv6UnusedCode) + icmpH.SetChecksum(0) + icmpH.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ + Header: icmpH, + Src: remoteAddr.AddressWithPrefix.Address, + Dst: localAddr.AddressWithPrefix.Address, + })) + ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + ip.Encode(&header.IPv6Fields{ + PayloadLength: header.ICMPv6MinimumSize, + TransportProtocol: icmp.ProtocolNumber6, + HopLimit: DefaultTTL, + SrcAddr: remoteAddr.AddressWithPrefix.Address, + DstAddr: localAddr.AddressWithPrefix.Address, + }) + echoPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Payload: buffer.MakeWithData(hdr.View()), + }) + e.InjectInbound(ProtocolNumber, echoPkt) + echoPkt.DecRef() + c.clock.RunImmediatelyScheduledJobs() + + if handlerCalled { + t.Fatalf("default handler was unexpectedly called") + } + + p := e.Read() + if p == nil { + t.Fatalf("expected ICMP echo reply") + } + defer p.DecRef() + payload := stack.PayloadSince(p.NetworkHeader()) + defer payload.Release() + checker.IPv6(t, payload, + checker.SrcAddr(localAddr.AddressWithPrefix.Address), + checker.DstAddr(remoteAddr.AddressWithPrefix.Address), + checker.ICMPv6( + checker.ICMPv6Type(header.ICMPv6EchoReply), + checker.ICMPv6Code(header.ICMPv6UnusedCode))) +} + func TestICMPCounts(t *testing.T) { c := newTestContext() defer c.cleanup() diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index c4275623f9..112b240088 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -841,23 +841,34 @@ func (n *nic) DeliverLinkPacket(protocol tcpip.NetworkProtocolNumber, pkt *Packe // DeliverTransportPacket delivers the packets to the appropriate transport // protocol endpoint. func (n *nic) DeliverTransportPacket(protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer) TransportPacketDisposition { + res, _ := n.deliverTransportPacket(protocol, pkt) + return res +} + +// DeliverTransportPacketWithDefaultHandlerResult implements +// TransportDispatcherWithDefaultHandlerResult.DeliverTransportPacketWithDefaultHandlerResult. +func (n *nic) DeliverTransportPacketWithDefaultHandlerResult(protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer) (TransportPacketDisposition, bool) { + return n.deliverTransportPacket(protocol, pkt) +} + +func (n *nic) deliverTransportPacket(protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer) (TransportPacketDisposition, bool) { state, ok := n.stack.transportProtocols[protocol] if !ok { n.stats.unknownL4ProtocolRcvdPacketCounts.Increment(uint64(protocol)) - return TransportPacketProtocolUnreachable + return TransportPacketProtocolUnreachable, false } transProto := state.proto if len(pkt.TransportHeader().Slice()) == 0 { n.stats.malformedL4RcvdPackets.Increment() - return TransportPacketHandled + return TransportPacketHandled, false } srcPort, dstPort, err := transProto.ParsePorts(pkt.TransportHeader().Slice()) if err != nil { n.stats.malformedL4RcvdPackets.Increment() - return TransportPacketHandled + return TransportPacketHandled, false } netProto, ok := n.stack.networkProtocols[pkt.NetworkProtocolNumber] @@ -873,13 +884,13 @@ func (n *nic) DeliverTransportPacket(protocol tcpip.TransportProtocolNumber, pkt RemoteAddress: src, } if n.stack.demux.deliverPacket(protocol, pkt, id) { - return TransportPacketHandled + return TransportPacketHandled, false } // Try to deliver to per-stack default handler. if state.defaultHandler != nil { if state.defaultHandler(id, pkt) { - return TransportPacketHandled + return TransportPacketHandled, true } } @@ -889,11 +900,11 @@ func (n *nic) DeliverTransportPacket(protocol tcpip.TransportProtocolNumber, pkt switch res := transProto.HandleUnknownDestinationPacket(id, pkt); res { case UnknownDestinationPacketMalformed: n.stats.malformedL4RcvdPackets.Increment() - return TransportPacketHandled + return TransportPacketHandled, false case UnknownDestinationPacketUnhandled: - return TransportPacketDestinationPortUnreachable + return TransportPacketDestinationPortUnreachable, false case UnknownDestinationPacketHandled: - return TransportPacketHandled + return TransportPacketHandled, false default: panic(fmt.Sprintf("unrecognized result from HandleUnknownDestinationPacket = %d", res)) } diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go index 2c23eee23c..05052ede81 100644 --- a/pkg/tcpip/stack/registration.go +++ b/pkg/tcpip/stack/registration.go @@ -365,6 +365,21 @@ type TransportDispatcher interface { DeliverRawPacket(tcpip.TransportProtocolNumber, *PacketBuffer) } +// TransportDispatcherWithDefaultHandlerResult extends TransportDispatcher with +// default-handler-specific delivery metadata. +type TransportDispatcherWithDefaultHandlerResult interface { + TransportDispatcher + + // DeliverTransportPacketWithDefaultHandlerResult delivers packets to the + // appropriate transport protocol endpoint and reports whether the packet was + // specifically handled by the per-stack default transport protocol handler. + // + // pkt.NetworkHeader must be set before calling this method. + // + // DeliverTransportPacketWithDefaultHandlerResult may modify the packet. + DeliverTransportPacketWithDefaultHandlerResult(tcpip.TransportProtocolNumber, *PacketBuffer) (TransportPacketDisposition, bool) +} + // PacketLooping specifies where an outbound packet should be sent. type PacketLooping byte