From ca92de4400d4351ba3b505b3ced2e42d9fd6323f Mon Sep 17 00:00:00 2001 From: Kevin Krakauer Date: Wed, 3 Apr 2024 10:42:19 -0700 Subject: [PATCH] netstack: fix failure to notify writers when cwnd decreases PiperOrigin-RevId: 621571685 --- pkg/tcpip/transport/tcp/endpoint.go | 10 +- pkg/tcpip/transport/tcp/test/e2e/tcp_test.go | 130 ++++++++++++++++++ .../transport/tcp/testing/context/context.go | 16 ++- 3 files changed, 149 insertions(+), 7 deletions(-) diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index cf0b0093b9..6f3cc457a8 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -2974,11 +2974,11 @@ func (e *Endpoint) updateSndBufferUsage(v int) { notify = notify && e.sndQueueInfo.SndBufUsed < int(newSndBufSz)>>1 e.sndQueueInfo.sndQueueMu.Unlock() - if notify { - // Set the new send buffer size calculated from auto tuning. - e.ops.SetSendBufferSize(newSndBufSz, false /* notify */) - e.waiterQueue.Notify(waiter.WritableEvents) - } + // if notify { + // Set the new send buffer size calculated from auto tuning. + e.ops.SetSendBufferSize(newSndBufSz, false /* notify */) + e.waiterQueue.Notify(waiter.WritableEvents) + // } } // readyToRead is called by the protocol goroutine when a new segment is ready diff --git a/pkg/tcpip/transport/tcp/test/e2e/tcp_test.go b/pkg/tcpip/transport/tcp/test/e2e/tcp_test.go index 4b863b5816..e4849f1a96 100644 --- a/pkg/tcpip/transport/tcp/test/e2e/tcp_test.go +++ b/pkg/tcpip/transport/tcp/test/e2e/tcp_test.go @@ -9052,6 +9052,136 @@ func TestSendBufferTuning(t *testing.T) { } } +func TestSendBufferTuningRTO(t *testing.T) { + const maxPayload = 536 + const mtu = header.TCPMinimumSize + header.IPv4MinimumSize + e2e.MaxTCPOptionSize + maxPayload + const packetOverheadFactor = 2 + + testCases := []struct { + name string + autoTuningDisabled bool + }{ + {"autoTuningDisabled", true}, + {"autoTuningEnabled", false}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + clock := faketime.NewManualClock() + // tsNow := func() uint32 { + // return uint32(clock.NowMonotonic().Sub(tcpip.MonotonicTime{}).Milliseconds()) + // } + // Advance the clock so that NowMonotonic is non-zero. + clock.Advance(time.Second) + c := context.NewWithOpts(t, context.Options{ + EnableV4: true, + EnableV6: true, + MTU: mtu, + Clock: clock, + }) + defer c.Cleanup() + + // Set the stack option for send buffer size. + const defaultSndBufSz = maxPayload * tcp.InitialCwnd + const maxSndBufSz = defaultSndBufSz * 10 + { + opt := tcpip.TCPSendBufferSizeRangeOption{Min: 1, Default: defaultSndBufSz, Max: maxSndBufSz} + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { + t.Fatalf("SetTransportProtocolOption(%d, &%#v): %s", tcp.ProtocolNumber, opt, err) + } + } + + c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) + + oldSz := c.EP.SocketOptions().GetSendBufferSize() + if oldSz != defaultSndBufSz { + t.Fatalf("Wrong send buffer size got %d want %d", oldSz, defaultSndBufSz) + } + + if tc.autoTuningDisabled { + c.EP.SocketOptions().SetSendBufferSize(defaultSndBufSz, true /* notify */) + } + + data := make([]byte, maxPayload) + for i := range data { + data[i] = byte(i) + } + + w, ch := waiter.NewChannelEntry(waiter.WritableEvents) + c.WQ.EventRegister(&w) + defer c.WQ.EventUnregister(&w) + + bytesRead := 0 + for { + // Packets will be sent till the send buffer + // size is reached. + var r bytes.Reader + r.Reset(data[bytesRead : bytesRead+maxPayload]) + _, err := c.EP.Write(&r, tcpip.WriteOptions{}) + if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { + break + } + + c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, 0) + bytesRead += maxPayload + data = append(data, data...) + } + + // Send an ACK and wait for connection to become writable again. + c.SendAck(seqnum.Value(context.TestInitialSequenceNumber).Add(1), bytesRead) + select { + case <-ch: + if err := c.EP.LastError(); err != nil { + t.Fatalf("Write failed: %s", err) + } + case <-time.After(1 * time.Second): + t.Fatalf("Timed out waiting for connection") + } + + // The window has grown. Write until it's full. + oldBytesRead := bytesRead + var maxSeq seqnum.Value + for { + // Packets will be sent till the send buffer + // size is reached. + var r bytes.Reader + r.Reset(data[:maxPayload]) + n, err := c.EP.Write(&r, tcpip.WriteOptions{}) + if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { + break + } + + // c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, 0) + bytesRead += int(n) + if pkt := c.MaybeGetPacket(500 * time.Millisecond); pkt != nil { + defer pkt.Release() + tcp := header.TCP(header.IPv4(pkt.AsSlice()).Payload()) + maxSeq = seqnum.Value(tcp.SequenceNumber()) // + len? + } + // if seq > maxSeq { + // maxSeq = seq + // } + // data = append(data, data...) + } + t.Logf("wrote %d additional bytes to fill the buffer", bytesRead-oldBytesRead) + + // Cause an RTO. + clock.Advance(time.Second + tcp.MaxRTO) + + // Ensure that sending an ACK causes the send buffer to drain. + c.SendAck(seqnum.Value(context.TestInitialSequenceNumber).Add(1), int(maxSeq)-int(c.IRS)) + select { + case <-ch: + if err := c.EP.LastError(); err != nil { + t.Fatalf("Write failed: %s", err) + } + case <-time.After(1 * time.Second): + t.Fatalf("Timed out waiting for connection") + } + }) + } +} + func TestTimestampSynCookies(t *testing.T) { clock := faketime.NewManualClock() tsNow := func() uint32 { diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go index f0c968968f..7b32523822 100644 --- a/pkg/tcpip/transport/tcp/testing/context/context.go +++ b/pkg/tcpip/transport/tcp/testing/context/context.go @@ -369,10 +369,22 @@ func (c *Context) GetPacketWithTimeout(timeout time.Duration) *buffer.View { // addresses. func (c *Context) GetPacket() *buffer.View { c.t.Helper() + return c.getPacket(5*time.Second, func() { + c.t.Fatalf("Packet wasn't written out") + }) +} - p := c.GetPacketWithTimeout(5 * time.Second) +func (c *Context) MaybeGetPacket(to time.Duration) *buffer.View { + c.t.Helper() + return c.getPacket(to, func() {}) +} + +func (c *Context) getPacket(to time.Duration, onNil func()) *buffer.View { + c.t.Helper() + + p := c.GetPacketWithTimeout(to) if p == nil { - c.t.Fatalf("Packet wasn't written out") + onNil() return nil }