From 9bb44fbf2ec5b24fcf6ea93cff2efccea6f972e0 Mon Sep 17 00:00:00 2001 From: Sergey Matyukevich Date: Thu, 20 Jul 2023 16:28:06 -0600 Subject: [PATCH] transport: use a sync.Pool to share per-connection write buffer (#6309) --- benchmark/benchmain/main.go | 11 ++++++ benchmark/stats/stats.go | 9 ++++- dialoptions.go | 14 +++++++ internal/transport/http2_client.go | 2 +- internal/transport/http2_server.go | 2 +- internal/transport/http_util.go | 59 +++++++++++++++++++++++++--- internal/transport/keepalive_test.go | 2 +- internal/transport/transport.go | 3 ++ server.go | 16 ++++++++ 9 files changed, 107 insertions(+), 11 deletions(-) diff --git a/benchmark/benchmain/main.go b/benchmark/benchmain/main.go index f4b96a5d460..76c1a265e50 100644 --- a/benchmark/benchmain/main.go +++ b/benchmark/benchmain/main.go @@ -115,6 +115,8 @@ var ( sleepBetweenRPCs = flags.DurationSlice("sleepBetweenRPCs", []time.Duration{0}, "Configures the maximum amount of time the client should sleep between consecutive RPCs - may be a a comma-separated list") connections = flag.Int("connections", 1, "The number of connections. Each connection will handle maxConcurrentCalls RPC streams") recvBufferPool = flags.StringWithAllowedValues("recvBufferPool", recvBufferPoolNil, "Configures the shared receive buffer pool. One of: nil, simple, all", allRecvBufferPools) + sharedWriteBuffer = flags.StringWithAllowedValues("sharedWriteBuffer", toggleModeOff, + fmt.Sprintf("Configures both client and server to share write buffer - One of: %v", strings.Join(allToggleModes, ", ")), allToggleModes) logger = grpclog.Component("benchmark") ) @@ -335,6 +337,10 @@ func makeClients(bf stats.Features) ([]testgrpc.BenchmarkServiceClient, func()) if bf.ServerReadBufferSize >= 0 { sopts = append(sopts, grpc.ReadBufferSize(bf.ServerReadBufferSize)) } + if bf.SharedWriteBuffer { + opts = append(opts, grpc.WithSharedWriteBuffer(true)) + sopts = append(sopts, grpc.SharedWriteBuffer(true)) + } if bf.ServerWriteBufferSize >= 0 { sopts = append(sopts, grpc.WriteBufferSize(bf.ServerWriteBufferSize)) } @@ -603,6 +609,7 @@ type featureOpts struct { serverWriteBufferSize []int sleepBetweenRPCs []time.Duration recvBufferPools []string + sharedWriteBuffer []bool } // makeFeaturesNum returns a slice of ints of size 'maxFeatureIndex' where each @@ -651,6 +658,8 @@ func makeFeaturesNum(b *benchOpts) []int { featuresNum[i] = len(b.features.sleepBetweenRPCs) case stats.RecvBufferPool: featuresNum[i] = len(b.features.recvBufferPools) + case stats.SharedWriteBuffer: + featuresNum[i] = len(b.features.sharedWriteBuffer) default: log.Fatalf("Unknown feature index %v in generateFeatures. maxFeatureIndex is %v", i, stats.MaxFeatureIndex) } @@ -720,6 +729,7 @@ func (b *benchOpts) generateFeatures(featuresNum []int) []stats.Features { ServerWriteBufferSize: b.features.serverWriteBufferSize[curPos[stats.ServerWriteBufferSize]], SleepBetweenRPCs: b.features.sleepBetweenRPCs[curPos[stats.SleepBetweenRPCs]], RecvBufferPool: b.features.recvBufferPools[curPos[stats.RecvBufferPool]], + SharedWriteBuffer: b.features.sharedWriteBuffer[curPos[stats.SharedWriteBuffer]], } if len(b.features.reqPayloadCurves) == 0 { f.ReqSizeBytes = b.features.reqSizeBytes[curPos[stats.ReqSizeBytesIndex]] @@ -793,6 +803,7 @@ func processFlags() *benchOpts { serverWriteBufferSize: append([]int(nil), *serverWriteBufferSize...), sleepBetweenRPCs: append([]time.Duration(nil), *sleepBetweenRPCs...), recvBufferPools: setRecvBufferPool(*recvBufferPool), + sharedWriteBuffer: setToggleMode(*sharedWriteBuffer), }, } diff --git a/benchmark/stats/stats.go b/benchmark/stats/stats.go index 3989e25dbf4..e42c5b6c0f2 100644 --- a/benchmark/stats/stats.go +++ b/benchmark/stats/stats.go @@ -58,6 +58,7 @@ const ( ServerWriteBufferSize SleepBetweenRPCs RecvBufferPool + SharedWriteBuffer // MaxFeatureIndex is a place holder to indicate the total number of feature // indices we have. Any new feature indices should be added above this. @@ -129,6 +130,8 @@ type Features struct { SleepBetweenRPCs time.Duration // RecvBufferPool represents the shared recv buffer pool used. RecvBufferPool string + // SharedWriteBuffer configures whether both client and server share per-connection write buffer + SharedWriteBuffer bool } // String returns all the feature values as a string. @@ -148,13 +151,13 @@ func (f Features) String() string { "trace_%v-latency_%v-kbps_%v-MTU_%v-maxConcurrentCalls_%v-%s-%s-"+ "compressor_%v-channelz_%v-preloader_%v-clientReadBufferSize_%v-"+ "clientWriteBufferSize_%v-serverReadBufferSize_%v-serverWriteBufferSize_%v-"+ - "sleepBetweenRPCs_%v-connections_%v-recvBufferPool_%v-", + "sleepBetweenRPCs_%v-connections_%v-recvBufferPool_%v-sharedWriteBuffer_%v", f.NetworkMode, f.UseBufConn, f.EnableKeepalive, f.BenchTime, f.EnableTrace, f.Latency, f.Kbps, f.MTU, f.MaxConcurrentCalls, reqPayloadString, respPayloadString, f.ModeCompressor, f.EnableChannelz, f.EnablePreloader, f.ClientReadBufferSize, f.ClientWriteBufferSize, f.ServerReadBufferSize, f.ServerWriteBufferSize, f.SleepBetweenRPCs, f.Connections, - f.RecvBufferPool) + f.RecvBufferPool, f.SharedWriteBuffer) } // SharedFeatures returns the shared features as a pretty printable string. @@ -230,6 +233,8 @@ func (f Features) partialString(b *bytes.Buffer, wantFeatures []bool, sep, delim b.WriteString(fmt.Sprintf("SleepBetweenRPCs%v%v%v", sep, f.SleepBetweenRPCs, delim)) case RecvBufferPool: b.WriteString(fmt.Sprintf("RecvBufferPool%v%v%v", sep, f.RecvBufferPool, delim)) + case SharedWriteBuffer: + b.WriteString(fmt.Sprintf("SharedWriteBuffer%v%v%v", sep, f.SharedWriteBuffer, delim)) default: log.Fatalf("Unknown feature index %v. maxFeatureIndex is %v", i, MaxFeatureIndex) } diff --git a/dialoptions.go b/dialoptions.go index 23ea95237ea..1fd0d5c127f 100644 --- a/dialoptions.go +++ b/dialoptions.go @@ -139,6 +139,20 @@ func newJoinDialOption(opts ...DialOption) DialOption { return &joinDialOption{opts: opts} } +// WithSharedWriteBuffer allows reusing per-connection transport write buffer. +// If this option is set to true every connection will release the buffer after +// flushing the data on the wire. +// +// # Experimental +// +// Notice: This API is EXPERIMENTAL and may be changed or removed in a +// later release. +func WithSharedWriteBuffer(val bool) DialOption { + return newFuncDialOption(func(o *dialOptions) { + o.copts.SharedWriteBuffer = val + }) +} + // WithWriteBufferSize determines how much data can be batched before doing a // write on the wire. The corresponding memory allocation for this buffer will // be twice the size to keep syscalls low. The default value for this buffer is diff --git a/internal/transport/http2_client.go b/internal/transport/http2_client.go index 326bf084800..52b88c32b15 100644 --- a/internal/transport/http2_client.go +++ b/internal/transport/http2_client.go @@ -330,7 +330,7 @@ func newHTTP2Client(connectCtx, ctx context.Context, addr resolver.Address, opts readerDone: make(chan struct{}), writerDone: make(chan struct{}), goAway: make(chan struct{}), - framer: newFramer(conn, writeBufSize, readBufSize, maxHeaderListSize), + framer: newFramer(conn, writeBufSize, readBufSize, opts.SharedWriteBuffer, maxHeaderListSize), fc: &trInFlow{limit: uint32(icwz)}, scheme: scheme, activeStreams: make(map[uint32]*Stream), diff --git a/internal/transport/http2_server.go b/internal/transport/http2_server.go index f9606401289..c48091f6c0a 100644 --- a/internal/transport/http2_server.go +++ b/internal/transport/http2_server.go @@ -165,7 +165,7 @@ func NewServerTransport(conn net.Conn, config *ServerConfig) (_ ServerTransport, if config.MaxHeaderListSize != nil { maxHeaderListSize = *config.MaxHeaderListSize } - framer := newFramer(conn, writeBufSize, readBufSize, maxHeaderListSize) + framer := newFramer(conn, writeBufSize, readBufSize, config.SharedWriteBuffer, maxHeaderListSize) // Send initial settings as connection preface to client. isettings := []http2.Setting{{ ID: http2.SettingMaxFrameSize, diff --git a/internal/transport/http_util.go b/internal/transport/http_util.go index 19cbb18f5ab..add1e9b2cc0 100644 --- a/internal/transport/http_util.go +++ b/internal/transport/http_util.go @@ -30,6 +30,7 @@ import ( "net/url" "strconv" "strings" + "sync" "time" "unicode/utf8" @@ -309,6 +310,7 @@ func decodeGrpcMessageUnchecked(msg string) string { } type bufWriter struct { + pool *sync.Pool buf []byte offset int batchSize int @@ -316,12 +318,17 @@ type bufWriter struct { err error } -func newBufWriter(conn net.Conn, batchSize int) *bufWriter { - return &bufWriter{ - buf: make([]byte, batchSize*2), +func newBufWriter(conn net.Conn, batchSize int, pool *sync.Pool) *bufWriter { + w := &bufWriter{ batchSize: batchSize, conn: conn, + pool: pool, } + // this indicates that we should use non shared buf + if pool == nil { + w.buf = make([]byte, batchSize) + } + return w } func (w *bufWriter) Write(b []byte) (n int, err error) { @@ -332,19 +339,34 @@ func (w *bufWriter) Write(b []byte) (n int, err error) { n, err = w.conn.Write(b) return n, toIOError(err) } + if w.buf == nil { + b := w.pool.Get().(*[]byte) + w.buf = *b + } for len(b) > 0 { nn := copy(w.buf[w.offset:], b) b = b[nn:] w.offset += nn n += nn if w.offset >= w.batchSize { - err = w.Flush() + err = w.flushKeepBuffer() } } return n, err } func (w *bufWriter) Flush() error { + err := w.flushKeepBuffer() + // Only release the buffer if we are in a "shared" mode + if w.buf != nil && w.pool != nil { + b := w.buf + w.pool.Put(&b) + w.buf = nil + } + return err +} + +func (w *bufWriter) flushKeepBuffer() error { if w.err != nil { return w.err } @@ -381,7 +403,10 @@ type framer struct { fr *http2.Framer } -func newFramer(conn net.Conn, writeBufferSize, readBufferSize int, maxHeaderListSize uint32) *framer { +var writeBufferPoolMap map[int]*sync.Pool = make(map[int]*sync.Pool) +var writeBufferMutex sync.Mutex + +func newFramer(conn net.Conn, writeBufferSize, readBufferSize int, sharedWriteBuffer bool, maxHeaderListSize uint32) *framer { if writeBufferSize < 0 { writeBufferSize = 0 } @@ -389,7 +414,11 @@ func newFramer(conn net.Conn, writeBufferSize, readBufferSize int, maxHeaderList if readBufferSize > 0 { r = bufio.NewReaderSize(r, readBufferSize) } - w := newBufWriter(conn, writeBufferSize) + var pool *sync.Pool + if sharedWriteBuffer { + pool = getWriteBufferPool(writeBufferSize) + } + w := newBufWriter(conn, writeBufferSize, pool) f := &framer{ writer: w, fr: http2.NewFramer(w, r), @@ -403,6 +432,24 @@ func newFramer(conn net.Conn, writeBufferSize, readBufferSize int, maxHeaderList return f } +func getWriteBufferPool(writeBufferSize int) *sync.Pool { + writeBufferMutex.Lock() + defer writeBufferMutex.Unlock() + size := writeBufferSize * 2 + pool, ok := writeBufferPoolMap[size] + if ok { + return pool + } + pool = &sync.Pool{ + New: func() interface{} { + b := make([]byte, size) + return &b + }, + } + writeBufferPoolMap[size] = pool + return pool +} + // parseDialTarget returns the network and address to pass to dialer. func parseDialTarget(target string) (string, string) { net := "tcp" diff --git a/internal/transport/keepalive_test.go b/internal/transport/keepalive_test.go index a46bcf020df..8144277fb6c 100644 --- a/internal/transport/keepalive_test.go +++ b/internal/transport/keepalive_test.go @@ -191,7 +191,7 @@ func (s) TestKeepaliveServerClosesUnresponsiveClient(t *testing.T) { if n, err := conn.Write(clientPreface); err != nil || n != len(clientPreface) { t.Fatalf("conn.Write(clientPreface) failed: n=%v, err=%v", n, err) } - framer := newFramer(conn, defaultWriteBufSize, defaultReadBufSize, 0) + framer := newFramer(conn, defaultWriteBufSize, defaultReadBufSize, false, 0) if err := framer.fr.WriteSettings(http2.Setting{}); err != nil { t.Fatal("framer.WriteSettings(http2.Setting{}) failed:", err) } diff --git a/internal/transport/transport.go b/internal/transport/transport.go index aa1c896595d..3828b3e4a8d 100644 --- a/internal/transport/transport.go +++ b/internal/transport/transport.go @@ -559,6 +559,7 @@ type ServerConfig struct { InitialConnWindowSize int32 WriteBufferSize int ReadBufferSize int + SharedWriteBuffer bool ChannelzParentID *channelz.Identifier MaxHeaderListSize *uint32 HeaderTableSize *uint32 @@ -592,6 +593,8 @@ type ConnectOptions struct { WriteBufferSize int // ReadBufferSize sets the size of read buffer, which in turn determines how much data can be read at most for one read syscall. ReadBufferSize int + // SharedWriteBuffer indicates whether connections should reuse write buffer + SharedWriteBuffer bool // ChannelzParentID sets the addrConn id which initiate the creation of this client transport. ChannelzParentID *channelz.Identifier // MaxHeaderListSize sets the max (uncompressed) size of header list that is prepared to be received. diff --git a/server.go b/server.go index e076ec7143b..01b3265223c 100644 --- a/server.go +++ b/server.go @@ -170,6 +170,7 @@ type serverOptions struct { initialConnWindowSize int32 writeBufferSize int readBufferSize int + sharedWriteBuffer bool connectionTimeout time.Duration maxHeaderListSize *uint32 headerTableSize *uint32 @@ -235,6 +236,20 @@ func newJoinServerOption(opts ...ServerOption) ServerOption { return &joinServerOption{opts: opts} } +// SharedWriteBuffer allows reusing per-connection transport write buffer. +// If this option is set to true every connection will release the buffer after +// flushing the data on the wire. +// +// # Experimental +// +// Notice: This API is EXPERIMENTAL and may be changed or removed in a +// later release. +func SharedWriteBuffer(val bool) ServerOption { + return newFuncServerOption(func(o *serverOptions) { + o.sharedWriteBuffer = val + }) +} + // WriteBufferSize determines how much data can be batched before doing a write // on the wire. The corresponding memory allocation for this buffer will be // twice the size to keep syscalls low. The default value for this buffer is @@ -938,6 +953,7 @@ func (s *Server) newHTTP2Transport(c net.Conn) transport.ServerTransport { InitialConnWindowSize: s.opts.initialConnWindowSize, WriteBufferSize: s.opts.writeBufferSize, ReadBufferSize: s.opts.readBufferSize, + SharedWriteBuffer: s.opts.sharedWriteBuffer, ChannelzParentID: s.channelzID, MaxHeaderListSize: s.opts.maxHeaderListSize, HeaderTableSize: s.opts.headerTableSize,