diff --git a/pkg/tcpip/transport/internal/network/endpoint.go b/pkg/tcpip/transport/internal/network/endpoint.go index 9d61e881e8..960d7d862a 100644 --- a/pkg/tcpip/transport/internal/network/endpoint.go +++ b/pkg/tcpip/transport/internal/network/endpoint.go @@ -274,7 +274,34 @@ func (c *WriteContext) TryNewPacketBuffer(reserveHdrBytes int, data buffer.Buffe if !e.hasSendSpaceRLocked() { return nil } + return c.newPacketBufferLocked(reserveHdrBytes, data) +} + +// TryNewPacketBufferFromPayloader returns a new packet buffer iff the endpoint's send buffer +// is not full. Otherwise, data from `payloader` isn't read. +// +// If this method returns nil, the caller should wait for the endpoint to become +// writable. +func (c *WriteContext) TryNewPacketBufferFromPayloader(reserveHdrBytes int, payloader tcpip.Payloader) *stack.PacketBuffer { + e := c.e + + e.sendBufferSizeInUseMu.Lock() + defer e.sendBufferSizeInUseMu.Unlock() + if !e.hasSendSpaceRLocked() { + return nil + } + var data buffer.Buffer + if _, err := data.WriteFromReader(payloader, int64(payloader.Len())); err != nil { + data.Release() + return nil + } + return c.newPacketBufferLocked(reserveHdrBytes, data) +} + +// +checklocks:c.e.sendBufferSizeInUseMu +func (c *WriteContext) newPacketBufferLocked(reserveHdrBytes int, data buffer.Buffer) *stack.PacketBuffer { + e := c.e // Note that we allow oversubscription - if there is any space at all in the // send buffer, we accept the full packet which may be larger than the space // available. This is because if the endpoint reports that it is writable, diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index 0c21be861d..6c4de0a103 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -21,7 +21,6 @@ import ( "math" "time" - "gvisor.dev/gvisor/pkg/buffer" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/checksum" @@ -436,16 +435,8 @@ func (e *endpoint) prepareForWrite(p tcpip.Payloader, opts tcpip.WriteOptions) ( return udpPacketInfo{}, &tcpip.ErrMessageTooLong{} } - var buf buffer.Buffer - if _, err := buf.WriteFromReader(p, int64(p.Len())); err != nil { - buf.Release() - ctx.Release() - return udpPacketInfo{}, &tcpip.ErrBadBuffer{} - } - return udpPacketInfo{ ctx: ctx, - data: buf, localPort: e.localPort, remotePort: dst.Port, }, nil @@ -473,9 +464,9 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcp } defer udpInfo.ctx.Release() - dataSz := udpInfo.data.Size() + dataSz := p.Len() pktInfo := udpInfo.ctx.PacketInfo() - pkt := udpInfo.ctx.TryNewPacketBuffer(header.UDPMinimumSize+int(pktInfo.MaxHeaderLength), udpInfo.data) + pkt := udpInfo.ctx.TryNewPacketBufferFromPayloader(header.UDPMinimumSize+int(pktInfo.MaxHeaderLength), p) if pkt == nil { return 0, &tcpip.ErrWouldBlock{} } @@ -593,7 +584,6 @@ func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) tcpip.Error { // udpPacketInfo holds information needed to send a UDP packet. type udpPacketInfo struct { ctx network.WriteContext - data buffer.Buffer localPort uint16 remotePort uint16 }