diff --git a/internal/transport/http_util.go b/internal/transport/http_util.go index 39cef3bd442e..f609c6c66595 100644 --- a/internal/transport/http_util.go +++ b/internal/transport/http_util.go @@ -317,28 +317,32 @@ func newBufWriter(conn net.Conn, batchSize int, pool *sync.Pool) *bufWriter { return w } -func (w *bufWriter) Write(b []byte) (n int, err error) { +func (w *bufWriter) Write(b []byte) (int, error) { if w.err != nil { return 0, w.err } if w.batchSize == 0 { // Buffer has been disabled. - n, err = w.conn.Write(b) + n, err := w.conn.Write(b) return n, toIOError(err) } if w.buf == nil { b := w.pool.Get().(*[]byte) w.buf = *b } + written := 0 for len(b) > 0 { - nn := copy(w.buf[w.offset:], b) - b = b[nn:] - w.offset += nn - n += nn - if w.offset >= w.batchSize { - err = w.flushKeepBuffer() + copied := copy(w.buf[w.offset:], b) + b = b[copied:] + written += copied + w.offset += copied + if w.offset < w.batchSize { + continue + } + if err := w.flushKeepBuffer(); err != nil { + return written, err } } - return n, err + return written, nil } func (w *bufWriter) Flush() error { diff --git a/internal/transport/http_util_test.go b/internal/transport/http_util_test.go index cc7807670b62..5a259d43cdc2 100644 --- a/internal/transport/http_util_test.go +++ b/internal/transport/http_util_test.go @@ -19,7 +19,10 @@ package transport import ( + "errors" "fmt" + "io" + "net" "reflect" "testing" "time" @@ -215,6 +218,39 @@ func (s) TestParseDialTarget(t *testing.T) { } } +type badNetworkConn struct { + net.Conn +} + +func (c *badNetworkConn) Write([]byte) (int, error) { + return 0, io.EOF +} + +// This test ensures Write() on a broken network connection does not lead to +// an infinite loop. See https://github.com/grpc/grpc-go/issues/7389 for more details. +func (s) TestWriteBadConnection(t *testing.T) { + data := []byte("test_data") + // Configure the bufWriter with a batchsize that results in data being flushed + // to the underlying conn, midway through Write(). + writeBufferSize := (len(data) - 1) / 2 + writer := newBufWriter(&badNetworkConn{}, writeBufferSize, getWriteBufferPool(writeBufferSize)) + + errCh := make(chan error, 1) + go func() { + _, err := writer.Write(data) + errCh <- err + }() + + select { + case <-time.After(time.Second): + t.Fatalf("Write() did not return in time") + case err := <-errCh: + if !errors.Is(err, io.EOF) { + t.Fatalf("Write() = %v, want error presence = %v", err, io.EOF) + } + } +} + func BenchmarkDecodeGrpcMessage(b *testing.B) { input := "Hello, %E4%B8%96%E7%95%8C" want := "Hello, 世界"