diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index 43a8c651e5..57cea3e083 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -1764,6 +1764,30 @@ func getSockOptIP(t *kernel.Task, s socket.Socket, ep commonEndpoint, name int, return nil, err } return &ret, nil + + case linux.IP_MTU_DISCOVER: + if outLen < sizeOfInt32 { + return nil, syserr.ErrInvalidArgument + } + + v, err := ep.GetSockOptInt(tcpip.MTUDiscoverOption) + if err != nil { + return nil, syserr.TranslateNetstackError(err) + } + switch tcpip.PMTUDStrategy(v) { + case tcpip.PMTUDiscoveryWant: + v = linux.IP_PMTUDISC_WANT + case tcpip.PMTUDiscoveryDont: + v = linux.IP_PMTUDISC_DONT + case tcpip.PMTUDiscoveryDo: + v = linux.IP_PMTUDISC_DO + case tcpip.PMTUDiscoveryProbe: + v = linux.IP_PMTUDISC_PROBE + default: + panic(fmt.Errorf("unknown PMTUD option: %d", v)) + } + vP := primitive.Int32(v) + return &vP, nil } return nil, syserr.ErrProtocolNotAvailable } @@ -2578,6 +2602,28 @@ func setSockOptIP(t *kernel.Task, s socket.Socket, ep commonEndpoint, name int, log.Infof("IPT_SO_SET_ADD_COUNTERS is not supported") return nil + case linux.IP_MTU_DISCOVER: + if len(optVal) == 0 { + return nil + } + v, err := parseIntOrChar(optVal) + if err != nil { + return err + } + switch v { + case linux.IP_PMTUDISC_DONT: + v = int32(tcpip.PMTUDiscoveryDont) + case linux.IP_PMTUDISC_WANT: + v = int32(tcpip.PMTUDiscoveryWant) + case linux.IP_PMTUDISC_DO: + v = int32(tcpip.PMTUDiscoveryDo) + case linux.IP_PMTUDISC_PROBE: + v = int32(tcpip.PMTUDiscoveryProbe) + default: + return syserr.ErrNotSupported + } + return syserr.TranslateNetstackError(ep.SetSockOptInt(tcpip.MTUDiscoverOption, int(v))) + case linux.IP_ADD_SOURCE_MEMBERSHIP, linux.IP_BIND_ADDRESS_NO_PORT, linux.IP_BLOCK_SOURCE, @@ -2587,7 +2633,6 @@ func setSockOptIP(t *kernel.Task, s socket.Socket, ep commonEndpoint, name int, linux.IP_IPSEC_POLICY, linux.IP_MINTTL, linux.IP_MSFILTER, - linux.IP_MTU_DISCOVER, linux.IP_MULTICAST_ALL, linux.IP_NODEFRAG, linux.IP_OPTIONS, diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index d7eab68c60..37796905e2 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -462,19 +462,26 @@ func (e *endpoint) addIPHeader(srcAddr, dstAddr tcpip.Address, pkt *stack.Packet if length > math.MaxUint16 { return &tcpip.ErrMessageTooLong{} } - // RFC 6864 section 4.3 mandates uniqueness of ID values for non-atomic - // datagrams. Since the DF bit is never being set here, all datagrams - // are non-atomic and need an ID. - ipH.Encode(&header.IPv4Fields{ + + fields := header.IPv4Fields{ TotalLength: uint16(length), - ID: e.getID(), TTL: params.TTL, TOS: params.TOS, Protocol: uint8(params.Protocol), SrcAddr: srcAddr, DstAddr: dstAddr, Options: options, - }) + } + if params.DF { + // Treat want and do the same. + fields.Flags = header.IPv4FlagDontFragment + } else { + // RFC 6864 section 4.3 mandates uniqueness of ID values for + // non-atomic datagrams. + fields.ID = e.getID() + } + ipH.Encode(&fields) + ipH.SetChecksum(^ipH.CalculateChecksum()) pkt.NetworkProtocolNumber = ProtocolNumber return nil diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go index 5197757693..8ac23f164e 100644 --- a/pkg/tcpip/stack/registration.go +++ b/pkg/tcpip/stack/registration.go @@ -319,6 +319,9 @@ type NetworkHeaderParams struct { // TOS refers to TypeOfService or TrafficClass field of the IP-header. TOS uint8 + + // DF indicates whether the DF bit should be set. + DF bool } // GroupAddressableEndpoint is an endpoint that supports group addressing. diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index 7c21c43c1c..59fb3534d4 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -995,10 +995,13 @@ const ( UseDefaultIPv6HopLimit = -1 ) +// PMTUDStrategy is the kind of PMTUD to perform. +type PMTUDStrategy int + const ( // PMTUDiscoveryWant is a setting of the MTUDiscoverOption to use // per-route settings. - PMTUDiscoveryWant int = iota + PMTUDiscoveryWant PMTUDStrategy = iota // PMTUDiscoveryDont is a setting of the MTUDiscoverOption to disable // path MTU discovery. diff --git a/pkg/tcpip/transport/internal/network/endpoint.go b/pkg/tcpip/transport/internal/network/endpoint.go index 9d61e881e8..bb7cda6196 100644 --- a/pkg/tcpip/transport/internal/network/endpoint.go +++ b/pkg/tcpip/transport/internal/network/endpoint.go @@ -797,7 +797,7 @@ func (e *Endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) tcpip.Error { case tcpip.MTUDiscoverOption: // Return not supported if the value is not disabling path // MTU discovery. - if v != tcpip.PMTUDiscoveryDont { + if tcpip.PMTUDStrategy(v) != tcpip.PMTUDiscoveryDont { return &tcpip.ErrNotSupported{} } @@ -835,7 +835,7 @@ func (e *Endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) { switch opt { case tcpip.MTUDiscoverOption: // The only supported setting is path MTU discovery disabled. - return tcpip.PMTUDiscoveryDont, nil + return int(tcpip.PMTUDiscoveryDont), nil case tcpip.MulticastTTLOption: e.mu.Lock() diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index 062d75533b..7de3fe9e45 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -794,6 +794,7 @@ type tcpFields struct { rcvWnd seqnum.Size opts []byte txHash uint32 + df bool } func (e *Endpoint) sendSynTCP(r *stack.Route, tf tcpFields, opts header.TCPSynOptions) tcpip.Error { @@ -881,7 +882,7 @@ func sendTCPBatch(r *stack.Route, tf tcpFields, pkt *stack.PacketBuffer, gso sta buildTCPHdr(r, tf, pkt, gso) tf.seq = tf.seq.Add(seqnum.Size(packetSize)) pkt.GSOOptions = gso - if err := r.WritePacket(stack.NetworkHeaderParams{Protocol: ProtocolNumber, TTL: tf.ttl, TOS: tf.tos}, pkt); err != nil { + if err := r.WritePacket(stack.NetworkHeaderParams{Protocol: ProtocolNumber, TTL: tf.ttl, TOS: tf.tos, DF: tf.df}, pkt); err != nil { r.Stats().TCP.SegmentSendErrors.Increment() if shouldSplitPacket { pkt.DecRef() @@ -913,7 +914,7 @@ func sendTCP(r *stack.Route, tf tcpFields, pkt *stack.PacketBuffer, gso stack.GS pkt.Owner = owner buildTCPHdr(r, tf, pkt, gso) - if err := r.WritePacket(stack.NetworkHeaderParams{Protocol: ProtocolNumber, TTL: tf.ttl, TOS: tf.tos}, pkt); err != nil { + if err := r.WritePacket(stack.NetworkHeaderParams{Protocol: ProtocolNumber, TTL: tf.ttl, TOS: tf.tos, DF: tf.df}, pkt); err != nil { r.Stats().TCP.SegmentSendErrors.Increment() return err } @@ -964,6 +965,9 @@ func (e *Endpoint) makeOptions(sackBlocks []header.SACKBlock) []byte { } // sendEmptyRaw sends a TCP segment with no payload to the endpoint's peer. +// +// +checklocks:e.mu +// +checklocksalias:e.snd.ep.mu=e.mu func (e *Endpoint) sendEmptyRaw(flags header.TCPFlags, seq, ack seqnum.Value, rcvWnd seqnum.Size) tcpip.Error { pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{}) defer pkt.DecRef() @@ -972,6 +976,9 @@ func (e *Endpoint) sendEmptyRaw(flags header.TCPFlags, seq, ack seqnum.Value, rc // sendRaw sends a TCP segment to the endpoint's peer. This method takes // ownership of pkt. pkt must not have any headers set. +// +// +checklocks:e.mu +// +checklocksalias:e.snd.ep.mu=e.mu func (e *Endpoint) sendRaw(pkt *stack.PacketBuffer, flags header.TCPFlags, seq, ack seqnum.Value, rcvWnd seqnum.Size) tcpip.Error { var sackBlocks []header.SACKBlock if e.EndpointState() == StateEstablished && e.rcv.pendingRcvdSegments.Len() > 0 && (flags&header.TCPFlagAck != 0) { @@ -989,6 +996,7 @@ func (e *Endpoint) sendRaw(pkt *stack.PacketBuffer, flags header.TCPFlags, seq, ack: ack, rcvWnd: rcvWnd, opts: options, + df: e.pmtud == tcpip.PMTUDiscoveryWant || e.pmtud == tcpip.PMTUDiscoveryDo, }, pkt, e.gso) } diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index b9ae74eb00..9747cee95f 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -600,6 +600,11 @@ type Endpoint struct { // // +checklocks:mu limRdr *io.LimitedReader `state:"nosave"` + + // pmtud is the PMTUD strategy to use. + // + // +checklocks:mu + pmtud tcpip.PMTUDStrategy } // UniqueID implements stack.TransportEndpoint.UniqueID. @@ -1890,9 +1895,16 @@ func (e *Endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) tcpip.Error { e.UnlockUser() case tcpip.MTUDiscoverOption: - // Return not supported if attempting to set this option to - // anything other than path MTU discovery disabled. - if v != tcpip.PMTUDiscoveryDont { + switch v := tcpip.PMTUDStrategy(v); v { + case tcpip.PMTUDiscoveryWant, tcpip.PMTUDiscoveryDont, tcpip.PMTUDiscoveryDo: + e.LockUser() + e.pmtud = v + e.UnlockUser() + case tcpip.PMTUDiscoveryProbe: + // We don't support a way to ignore MTU updates; it's + // either on or it's off. + return &tcpip.ErrNotSupported{} + default: return &tcpip.ErrNotSupported{} } @@ -2089,9 +2101,10 @@ func (e *Endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) { return v, nil case tcpip.MTUDiscoverOption: - // Always return the path MTU discovery disabled setting since - // it's the only one supported. - return tcpip.PMTUDiscoveryDont, nil + e.LockUser() + v := e.pmtud + e.UnlockUser() + return int(v), nil case tcpip.ReceiveQueueSizeOption: return e.readyReceiveSize() diff --git a/test/syscalls/linux/tcp_socket.cc b/test/syscalls/linux/tcp_socket.cc index ab332609d6..5aea49e5d1 100644 --- a/test/syscalls/linux/tcp_socket.cc +++ b/test/syscalls/linux/tcp_socket.cc @@ -18,6 +18,7 @@ #include #include #endif // __linux__ +#include #include #include #include @@ -2327,6 +2328,49 @@ TEST_P(TcpSocketTest, GetSocketAcceptConnNonListener) { EXPECT_EQ(got, 0); } +TEST_P(TcpSocketTest, SetPMTUD) { + // IP_PMTUDISC_WANT should be default. + int got = -1; + socklen_t length = sizeof(got); + ASSERT_THAT( + getsockopt(accepted_.get(), SOL_IP, IP_MTU_DISCOVER, &got, &length), + SyscallSucceeds()); + EXPECT_EQ(got, IP_PMTUDISC_WANT); + + int set = IP_PMTUDISC_DO; + ASSERT_THAT( + setsockopt(accepted_.get(), SOL_IP, IP_MTU_DISCOVER, &set, length), + SyscallSucceeds()); + ASSERT_THAT( + getsockopt(accepted_.get(), SOL_IP, IP_MTU_DISCOVER, &got, &length), + SyscallSucceeds()); + EXPECT_EQ(got, IP_PMTUDISC_DO); + set = IP_PMTUDISC_DONT; + ASSERT_THAT( + setsockopt(accepted_.get(), SOL_IP, IP_MTU_DISCOVER, &set, length), + SyscallSucceeds()); + ASSERT_THAT( + getsockopt(accepted_.get(), SOL_IP, IP_MTU_DISCOVER, &got, &length), + SyscallSucceeds()); + EXPECT_EQ(got, IP_PMTUDISC_DONT); + + // IP_PMTUDISC_PROBE isn's upported by gVisor. + set = IP_PMTUDISC_PROBE; + if (IsRunningOnGvisor()) { + ASSERT_THAT( + setsockopt(accepted_.get(), SOL_IP, IP_MTU_DISCOVER, &set, length), + SyscallFailsWithErrno(ENOTSUP)); + } else { + ASSERT_THAT( + setsockopt(accepted_.get(), SOL_IP, IP_MTU_DISCOVER, &set, length), + SyscallSucceeds()); + ASSERT_THAT( + getsockopt(accepted_.get(), SOL_IP, IP_MTU_DISCOVER, &got, &length), + SyscallSucceeds()); + EXPECT_EQ(got, IP_PMTUDISC_PROBE); + } +} + TEST_P(SimpleTcpSocketTest, GetSocketAcceptConnWithShutdown) { // TODO(b/171345701): Fix the TCP state for listening socket on shutdown. SKIP_IF(IsRunningOnGvisor());