Skip to content

Commit

Permalink
make PMTUD on by default and settable via sockopt
Browse files Browse the repository at this point in the history
We've supported PMTUD for a long time and just never turned it on.

Addresses #10344.

PiperOrigin-RevId: 632186215
  • Loading branch information
kevinGC authored and gvisor-bot committed May 9, 2024
1 parent d19a74b commit f628cb5
Show file tree
Hide file tree
Showing 8 changed files with 141 additions and 18 deletions.
47 changes: 46 additions & 1 deletion pkg/sentry/socket/netstack/netstack.go
Original file line number Diff line number Diff line change
Expand Up @@ -1764,6 +1764,30 @@ func getSockOptIP(t *kernel.Task, s socket.Socket, ep commonEndpoint, name int,
return nil, err
}
return &ret, nil

case linux.IP_MTU_DISCOVER:
if outLen < sizeOfInt32 {
return nil, syserr.ErrInvalidArgument
}

v, err := ep.GetSockOptInt(tcpip.MTUDiscoverOption)
if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
switch tcpip.PMTUDStrategy(v) {
case tcpip.PMTUDiscoveryWant:
v = linux.IP_PMTUDISC_WANT
case tcpip.PMTUDiscoveryDont:
v = linux.IP_PMTUDISC_DONT
case tcpip.PMTUDiscoveryDo:
v = linux.IP_PMTUDISC_DO
case tcpip.PMTUDiscoveryProbe:
v = linux.IP_PMTUDISC_PROBE
default:
panic(fmt.Errorf("unknown PMTUD option: %d", v))
}
vP := primitive.Int32(v)
return &vP, nil
}
return nil, syserr.ErrProtocolNotAvailable
}
Expand Down Expand Up @@ -2578,6 +2602,28 @@ func setSockOptIP(t *kernel.Task, s socket.Socket, ep commonEndpoint, name int,
log.Infof("IPT_SO_SET_ADD_COUNTERS is not supported")
return nil

case linux.IP_MTU_DISCOVER:
if len(optVal) == 0 {
return nil
}
v, err := parseIntOrChar(optVal)
if err != nil {
return err
}
switch v {
case linux.IP_PMTUDISC_DONT:
v = int32(tcpip.PMTUDiscoveryDont)
case linux.IP_PMTUDISC_WANT:
v = int32(tcpip.PMTUDiscoveryWant)
case linux.IP_PMTUDISC_DO:
v = int32(tcpip.PMTUDiscoveryDo)
case linux.IP_PMTUDISC_PROBE:
v = int32(tcpip.PMTUDiscoveryProbe)
default:
return syserr.ErrNotSupported
}
return syserr.TranslateNetstackError(ep.SetSockOptInt(tcpip.MTUDiscoverOption, int(v)))

case linux.IP_ADD_SOURCE_MEMBERSHIP,
linux.IP_BIND_ADDRESS_NO_PORT,
linux.IP_BLOCK_SOURCE,
Expand All @@ -2587,7 +2633,6 @@ func setSockOptIP(t *kernel.Task, s socket.Socket, ep commonEndpoint, name int,
linux.IP_IPSEC_POLICY,
linux.IP_MINTTL,
linux.IP_MSFILTER,
linux.IP_MTU_DISCOVER,
linux.IP_MULTICAST_ALL,
linux.IP_NODEFRAG,
linux.IP_OPTIONS,
Expand Down
19 changes: 13 additions & 6 deletions pkg/tcpip/network/ipv4/ipv4.go
Original file line number Diff line number Diff line change
Expand Up @@ -462,19 +462,26 @@ func (e *endpoint) addIPHeader(srcAddr, dstAddr tcpip.Address, pkt *stack.Packet
if length > math.MaxUint16 {
return &tcpip.ErrMessageTooLong{}
}
// RFC 6864 section 4.3 mandates uniqueness of ID values for non-atomic
// datagrams. Since the DF bit is never being set here, all datagrams
// are non-atomic and need an ID.
ipH.Encode(&header.IPv4Fields{

fields := header.IPv4Fields{
TotalLength: uint16(length),
ID: e.getID(),
TTL: params.TTL,
TOS: params.TOS,
Protocol: uint8(params.Protocol),
SrcAddr: srcAddr,
DstAddr: dstAddr,
Options: options,
})
}
if params.DF {
// Treat the want and do PMTUD strategies the same: both set the DF bit.
fields.Flags = header.IPv4FlagDontFragment
} else {
// RFC 6864 section 4.3 mandates uniqueness of ID values for
// non-atomic datagrams.
fields.ID = e.getID()
}
ipH.Encode(&fields)

ipH.SetChecksum(^ipH.CalculateChecksum())
pkt.NetworkProtocolNumber = ProtocolNumber
return nil
Expand Down
3 changes: 3 additions & 0 deletions pkg/tcpip/stack/registration.go
Original file line number Diff line number Diff line change
Expand Up @@ -319,6 +319,9 @@ type NetworkHeaderParams struct {

// TOS refers to TypeOfService or TrafficClass field of the IP-header.
TOS uint8

// DF indicates whether the DF bit should be set.
DF bool
}

// GroupAddressableEndpoint is an endpoint that supports group addressing.
Expand Down
5 changes: 4 additions & 1 deletion pkg/tcpip/tcpip.go
Original file line number Diff line number Diff line change
Expand Up @@ -995,10 +995,13 @@ const (
UseDefaultIPv6HopLimit = -1
)

// PMTUDStrategy is the kind of path MTU discovery (PMTUD) to perform. Its
// values are the settings of the MTUDiscoverOption.
type PMTUDStrategy int

const (
// PMTUDiscoveryWant is a setting of the MTUDiscoverOption to use
// per-route settings.
PMTUDiscoveryWant int = iota
PMTUDiscoveryWant PMTUDStrategy = iota

// PMTUDiscoveryDont is a setting of the MTUDiscoverOption to disable
// path MTU discovery.
Expand Down
4 changes: 2 additions & 2 deletions pkg/tcpip/transport/internal/network/endpoint.go
Original file line number Diff line number Diff line change
Expand Up @@ -797,7 +797,7 @@ func (e *Endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) tcpip.Error {
case tcpip.MTUDiscoverOption:
// Return not supported if the value is not disabling path
// MTU discovery.
if v != tcpip.PMTUDiscoveryDont {
if tcpip.PMTUDStrategy(v) != tcpip.PMTUDiscoveryDont {
return &tcpip.ErrNotSupported{}
}

Expand Down Expand Up @@ -835,7 +835,7 @@ func (e *Endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) {
switch opt {
case tcpip.MTUDiscoverOption:
// The only supported setting is path MTU discovery disabled.
return tcpip.PMTUDiscoveryDont, nil
return int(tcpip.PMTUDiscoveryDont), nil

case tcpip.MulticastTTLOption:
e.mu.Lock()
Expand Down
12 changes: 10 additions & 2 deletions pkg/tcpip/transport/tcp/connect.go
Original file line number Diff line number Diff line change
Expand Up @@ -794,6 +794,7 @@ type tcpFields struct {
rcvWnd seqnum.Size
opts []byte
txHash uint32
df bool
}

func (e *Endpoint) sendSynTCP(r *stack.Route, tf tcpFields, opts header.TCPSynOptions) tcpip.Error {
Expand Down Expand Up @@ -881,7 +882,7 @@ func sendTCPBatch(r *stack.Route, tf tcpFields, pkt *stack.PacketBuffer, gso sta
buildTCPHdr(r, tf, pkt, gso)
tf.seq = tf.seq.Add(seqnum.Size(packetSize))
pkt.GSOOptions = gso
if err := r.WritePacket(stack.NetworkHeaderParams{Protocol: ProtocolNumber, TTL: tf.ttl, TOS: tf.tos}, pkt); err != nil {
if err := r.WritePacket(stack.NetworkHeaderParams{Protocol: ProtocolNumber, TTL: tf.ttl, TOS: tf.tos, DF: tf.df}, pkt); err != nil {
r.Stats().TCP.SegmentSendErrors.Increment()
if shouldSplitPacket {
pkt.DecRef()
Expand Down Expand Up @@ -913,7 +914,7 @@ func sendTCP(r *stack.Route, tf tcpFields, pkt *stack.PacketBuffer, gso stack.GS
pkt.Owner = owner
buildTCPHdr(r, tf, pkt, gso)

if err := r.WritePacket(stack.NetworkHeaderParams{Protocol: ProtocolNumber, TTL: tf.ttl, TOS: tf.tos}, pkt); err != nil {
if err := r.WritePacket(stack.NetworkHeaderParams{Protocol: ProtocolNumber, TTL: tf.ttl, TOS: tf.tos, DF: tf.df}, pkt); err != nil {
r.Stats().TCP.SegmentSendErrors.Increment()
return err
}
Expand Down Expand Up @@ -964,6 +965,9 @@ func (e *Endpoint) makeOptions(sackBlocks []header.SACKBlock) []byte {
}

// sendEmptyRaw sends a TCP segment with no payload to the endpoint's peer.
//
// +checklocks:e.mu
// +checklocksalias:e.snd.ep.mu=e.mu
func (e *Endpoint) sendEmptyRaw(flags header.TCPFlags, seq, ack seqnum.Value, rcvWnd seqnum.Size) tcpip.Error {
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{})
defer pkt.DecRef()
Expand All @@ -972,6 +976,9 @@ func (e *Endpoint) sendEmptyRaw(flags header.TCPFlags, seq, ack seqnum.Value, rc

// sendRaw sends a TCP segment to the endpoint's peer. This method takes
// ownership of pkt. pkt must not have any headers set.
//
// +checklocks:e.mu
// +checklocksalias:e.snd.ep.mu=e.mu
func (e *Endpoint) sendRaw(pkt *stack.PacketBuffer, flags header.TCPFlags, seq, ack seqnum.Value, rcvWnd seqnum.Size) tcpip.Error {
var sackBlocks []header.SACKBlock
if e.EndpointState() == StateEstablished && e.rcv.pendingRcvdSegments.Len() > 0 && (flags&header.TCPFlagAck != 0) {
Expand All @@ -989,6 +996,7 @@ func (e *Endpoint) sendRaw(pkt *stack.PacketBuffer, flags header.TCPFlags, seq,
ack: ack,
rcvWnd: rcvWnd,
opts: options,
df: e.pmtud == tcpip.PMTUDiscoveryWant || e.pmtud == tcpip.PMTUDiscoveryDo,
}, pkt, e.gso)
}

Expand Down
25 changes: 19 additions & 6 deletions pkg/tcpip/transport/tcp/endpoint.go
Original file line number Diff line number Diff line change
Expand Up @@ -600,6 +600,11 @@ type Endpoint struct {
//
// +checklocks:mu
limRdr *io.LimitedReader `state:"nosave"`

// pmtud is the PMTUD strategy to use.
//
// +checklocks:mu
pmtud tcpip.PMTUDStrategy
}

// UniqueID implements stack.TransportEndpoint.UniqueID.
Expand Down Expand Up @@ -1890,9 +1895,16 @@ func (e *Endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) tcpip.Error {
e.UnlockUser()

case tcpip.MTUDiscoverOption:
// Return not supported if attempting to set this option to
// anything other than path MTU discovery disabled.
if v != tcpip.PMTUDiscoveryDont {
switch v := tcpip.PMTUDStrategy(v); v {
case tcpip.PMTUDiscoveryWant, tcpip.PMTUDiscoveryDont, tcpip.PMTUDiscoveryDo:
e.LockUser()
e.pmtud = v
e.UnlockUser()
case tcpip.PMTUDiscoveryProbe:
// We don't support a way to ignore MTU updates; it's
// either on or it's off.
return &tcpip.ErrNotSupported{}
default:
return &tcpip.ErrNotSupported{}
}

Expand Down Expand Up @@ -2089,9 +2101,10 @@ func (e *Endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) {
return v, nil

case tcpip.MTUDiscoverOption:
// Always return the path MTU discovery disabled setting since
// it's the only one supported.
return tcpip.PMTUDiscoveryDont, nil
e.LockUser()
v := e.pmtud
e.UnlockUser()
return int(v), nil

case tcpip.ReceiveQueueSizeOption:
return e.readyReceiveSize()
Expand Down
44 changes: 44 additions & 0 deletions test/syscalls/linux/tcp_socket.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include <linux/filter.h>
#include <sys/epoll.h>
#endif // __linux__
#include <errno.h>
#include <netinet/in.h>
#include <netinet/tcp.h>
#include <poll.h>
Expand Down Expand Up @@ -2327,6 +2328,49 @@ TEST_P(TcpSocketTest, GetSocketAcceptConnNonListener) {
EXPECT_EQ(got, 0);
}

// Verifies that IP_MTU_DISCOVER can be read and changed on the fixture's
// accepted_ socket: the default is IP_PMTUDISC_WANT, IP_PMTUDISC_DO and
// IP_PMTUDISC_DONT round-trip through setsockopt/getsockopt, and
// IP_PMTUDISC_PROBE fails with ENOTSUP on gVisor but round-trips otherwise.
TEST_P(TcpSocketTest, SetPMTUD) {
  // IP_PMTUDISC_WANT should be default.
  int got = -1;
  socklen_t length = sizeof(got);
  ASSERT_THAT(
      getsockopt(accepted_.get(), SOL_IP, IP_MTU_DISCOVER, &got, &length),
      SyscallSucceeds());
  EXPECT_EQ(got, IP_PMTUDISC_WANT);

  // length is a value-result argument of getsockopt, so it is reset before
  // each getsockopt call and never reused as the setsockopt option length;
  // setsockopt is always given sizeof(set) explicitly.
  int set = IP_PMTUDISC_DO;
  ASSERT_THAT(setsockopt(accepted_.get(), SOL_IP, IP_MTU_DISCOVER, &set,
                         sizeof(set)),
              SyscallSucceeds());
  length = sizeof(got);
  ASSERT_THAT(
      getsockopt(accepted_.get(), SOL_IP, IP_MTU_DISCOVER, &got, &length),
      SyscallSucceeds());
  EXPECT_EQ(got, IP_PMTUDISC_DO);

  set = IP_PMTUDISC_DONT;
  ASSERT_THAT(setsockopt(accepted_.get(), SOL_IP, IP_MTU_DISCOVER, &set,
                         sizeof(set)),
              SyscallSucceeds());
  length = sizeof(got);
  ASSERT_THAT(
      getsockopt(accepted_.get(), SOL_IP, IP_MTU_DISCOVER, &got, &length),
      SyscallSucceeds());
  EXPECT_EQ(got, IP_PMTUDISC_DONT);

  // IP_PMTUDISC_PROBE isn't supported by gVisor.
  set = IP_PMTUDISC_PROBE;
  if (IsRunningOnGvisor()) {
    ASSERT_THAT(setsockopt(accepted_.get(), SOL_IP, IP_MTU_DISCOVER, &set,
                           sizeof(set)),
                SyscallFailsWithErrno(ENOTSUP));
  } else {
    ASSERT_THAT(setsockopt(accepted_.get(), SOL_IP, IP_MTU_DISCOVER, &set,
                           sizeof(set)),
                SyscallSucceeds());
    length = sizeof(got);
    ASSERT_THAT(
        getsockopt(accepted_.get(), SOL_IP, IP_MTU_DISCOVER, &got, &length),
        SyscallSucceeds());
    EXPECT_EQ(got, IP_PMTUDISC_PROBE);
  }
}

TEST_P(SimpleTcpSocketTest, GetSocketAcceptConnWithShutdown) {
// TODO(b/171345701): Fix the TCP state for listening socket on shutdown.
SKIP_IF(IsRunningOnGvisor());
Expand Down

0 comments on commit f628cb5

Please sign in to comment.