From 5ca51622d9bb2833195bd617ad75bcf185cebb7d Mon Sep 17 00:00:00 2001 From: Constantine Peresypkin Date: Mon, 1 Dec 2025 10:17:16 -0800 Subject: [PATCH] tcpip: fix IP_MTU_DISCOVER flag for tcp and udp --- pkg/tcpip/network/ipv4/ipv4.go | 4 +--- .../transport/internal/network/endpoint.go | 23 +++++++++++++++---- pkg/tcpip/transport/tcp/connect.go | 2 +- pkg/tcpip/transport/tcp/endpoint.go | 21 +++++++---------- pkg/tcpip/transport/udp/endpoint.go | 17 ++++++++++++++ 5 files changed, 45 insertions(+), 22 deletions(-) diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index 8d530b8033..c06452f017 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -609,9 +609,7 @@ func (e *endpoint) writePacketPostRouting(r *stack.Route, pkt *stack.PacketBuffe if packetMustBeFragmented(pkt, networkMTU) { h := header.IPv4(pkt.NetworkHeader().Slice()) - if h.Flags()&header.IPv4FlagDontFragment != 0 && pkt.NetworkPacketInfo.IsForwardedPacket { - // TODO(gvisor.dev/issue/5919): Handle error condition in which DontFragment - // is set but the packet must be fragmented for the non-forwarding case. + if h.Flags()&header.IPv4FlagDontFragment != 0 { return &tcpip.ErrMessageTooLong{} } sent, remain, err := e.handleFragments(r, networkMTU, pkt, func(fragPkt *stack.PacketBuffer) tcpip.Error { diff --git a/pkg/tcpip/transport/internal/network/endpoint.go b/pkg/tcpip/transport/internal/network/endpoint.go index c7522aac8a..cf46186797 100644 --- a/pkg/tcpip/transport/internal/network/endpoint.go +++ b/pkg/tcpip/transport/internal/network/endpoint.go @@ -74,6 +74,8 @@ type Endpoint struct { ipv4TOS uint8 // +checklocks:mu ipv6TClass uint8 + // +checklocks:mu + pmtudStrategy tcpip.PMTUDStrategy // Lock ordering: mu > infoMu. infoMu sync.RWMutex `state:"nosave"` @@ -358,10 +360,15 @@ func (c *WriteContext) WritePacket(pkt *stack.PacketBuffer, headerIncluded bool) expOptVal = c.e.ops.GetExperimentOptionValue() } + c.e.mu.RLock() + pmtud := c.e.pmtudStrategy + c.e.mu.RUnlock() + err := c.route.WritePacket(stack.NetworkHeaderParams{ Protocol: c.e.transProto, TTL: c.ttl, TOS: c.tos, + DF: pmtud == tcpip.PMTUDiscoveryWant || pmtud == tcpip.PMTUDiscoveryDo || pmtud == tcpip.PMTUDiscoveryProbe, ExperimentOptionValue: expOptVal, }, pkt) @@ -840,9 +847,13 @@ func (e *Endpoint) GetRemoteAddress() (tcpip.FullAddress, bool) { func (e *Endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) tcpip.Error { switch opt { case tcpip.MTUDiscoverOption: - // Return not supported if the value is not disabling path - // MTU discovery. - if tcpip.PMTUDStrategy(v) != tcpip.PMTUDiscoveryDont { + strategy := tcpip.PMTUDStrategy(v) + switch strategy { + case tcpip.PMTUDiscoveryWant, tcpip.PMTUDiscoveryDont, tcpip.PMTUDiscoveryDo, tcpip.PMTUDiscoveryProbe: + e.mu.Lock() + e.pmtudStrategy = strategy + e.mu.Unlock() + default: return &tcpip.ErrNotSupported{} } @@ -891,8 +902,10 @@ func (e *Endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) tcpip.Error { func (e *Endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) { switch opt { case tcpip.MTUDiscoverOption: - // The only supported setting is path MTU discovery disabled. - return int(tcpip.PMTUDiscoveryDont), nil + e.mu.Lock() + v := int(e.pmtudStrategy) + e.mu.Unlock() + return v, nil case tcpip.MulticastTTLOption: e.mu.Lock() diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index d31ce112f6..9e8e0962ef 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -1042,7 +1042,7 @@ func (e *Endpoint) sendRaw(pkt *stack.PacketBuffer, flags header.TCPFlags, seq, ack: ack, rcvWnd: rcvWnd, opts: options, - df: e.pmtud == tcpip.PMTUDiscoveryWant || e.pmtud == tcpip.PMTUDiscoveryDo, + df: e.pmtud == tcpip.PMTUDiscoveryWant || e.pmtud == tcpip.PMTUDiscoveryDo || e.pmtud == tcpip.PMTUDiscoveryProbe, expOptVal: expOptVal, }, pkt, e.gso) } diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index 7e12775a4d..611c78c2c9 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -1895,18 +1895,9 @@ func (e *Endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) tcpip.Error { e.UnlockUser() case tcpip.MTUDiscoverOption: - switch v := tcpip.PMTUDStrategy(v); v { - case tcpip.PMTUDiscoveryWant, tcpip.PMTUDiscoveryDont, tcpip.PMTUDiscoveryDo: - e.LockUser() - e.pmtud = v - e.UnlockUser() - case tcpip.PMTUDiscoveryProbe: - // We don't support a way to ignore MTU updates; it's - // either on or it's off. - return &tcpip.ErrNotSupported{} - default: - return &tcpip.ErrNotSupported{} - } + e.LockUser() + e.pmtud = tcpip.PMTUDStrategy(v) + e.UnlockUser() case tcpip.IPv4TTLOption: e.LockUser() @@ -2965,7 +2956,11 @@ func (e *Endpoint) HandleError(transErr stack.TransportError, pkt *stack.PacketB // TODO(gvisor.dev/issues/5270): Handle all transport errors. switch transErr.Kind() { case stack.PacketTooBigTransportError: - handlePacketTooBig(transErr.Info()) + if e.pmtud == tcpip.PMTUDiscoveryProbe { + e.onICMPError(&tcpip.ErrMessageTooLong{}, transErr, pkt) + } else { + handlePacketTooBig(transErr.Info()) + } case stack.DestinationHostUnreachableTransportError: e.onICMPError(&tcpip.ErrHostUnreachable{}, transErr, pkt) case stack.DestinationNetworkUnreachableTransportError: diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index bd9dbbf0de..7d45bc485d 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -520,6 +520,21 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcp } if err := udpInfo.ctx.WritePacket(pkt, false /* headerIncluded */); err != nil { e.stack.Stats().UDP.PacketSendErrors.Increment() + if _, ok := err.(*tcpip.ErrMessageTooLong); ok { + so := e.SocketOptions() + if so.GetIPv4RecvError() || so.GetIPv6RecvError() { + so.QueueLocalErr( + err, + udpInfo.ctx.PacketInfo().NetProto, + udpInfo.ctx.MTU(), + tcpip.FullAddress{ + Addr: udpInfo.ctx.PacketInfo().RemoteAddress, + Port: udpInfo.remotePort, + }, + nil, + ) + } + } return 0, err } @@ -1050,6 +1065,8 @@ func (e *endpoint) HandleError(transErr stack.TransportError, pkt *stack.PacketB if e.net.State() == transport.DatagramEndpointStateConnected { e.onICMPError(&tcpip.ErrConnectionRefused{}, transErr, pkt) } + case stack.PacketTooBigTransportError: + e.onICMPError(&tcpip.ErrMessageTooLong{}, transErr, pkt) } }