Skip to content

Commit

Permalink
Add reference counting to packet buffers.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 408426639
  • Loading branch information
manninglucas authored and gvisor-bot committed Nov 8, 2021
1 parent 49d23be commit 84b38f4
Show file tree
Hide file tree
Showing 34 changed files with 177 additions and 44 deletions.
2 changes: 2 additions & 0 deletions pkg/tcpip/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,8 @@ deps_test(
"//pkg/state/wire",
"//pkg/sync",
"//pkg/waiter",
"//pkg/refsvfs2",
"//pkg/refs",
"//pkg/syserr",
"//pkg/abi/linux/errno",
"//pkg/errors",
Expand Down
12 changes: 12 additions & 0 deletions pkg/tcpip/link/channel/channel.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,18 @@ func (q *queue) ReadContext(ctx context.Context) (PacketInfo, bool) {
}

func (q *queue) Write(p PacketInfo) bool {
// q holds the PacketBuffer.

// Ideally, Write() should take a reference here, since it is adding
// the underlying PacketBuffer to the channel. However, in practice,
// calls to Read() are not necessarily symetric with calls
// to Write() (e.g writing to this endpoint and then exiting). This
// causes tests and analyzers to detect erroneous "leaks" for expected
// behavior. To prevent this, we allow the refcount to go to zero, and
// make a call to PreserveObject(), which prevents the PacketBuffer
// pooling implementation from reclaiming this instance, even when
// the refcount goes to zero.
p.Pkt.PreserveObject()
wrote := false
select {
case q.c <- p:
Expand Down
1 change: 1 addition & 0 deletions pkg/tcpip/link/fdbased/mmap.go
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,7 @@ func (d *packetMMapDispatcher) dispatch() (bool, tcpip.Error) {
pbuf := stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: buffer.View(pkt).ToVectorisedView(),
})
defer pbuf.DecRef()
if d.e.hdrSize > 0 {
if _, ok := pbuf.LinkHeader().Consume(d.e.hdrSize); !ok {
panic(fmt.Sprintf("LinkHeader().Consume(%d) must succeed", d.e.hdrSize))
Expand Down
2 changes: 2 additions & 0 deletions pkg/tcpip/link/fdbased/packet_dispatchers.go
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,7 @@ func (d *readVDispatcher) dispatch() (bool, tcpip.Error) {
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: d.buf.pullViews(n),
})
defer pkt.DecRef()

var (
p tcpip.NetworkProtocolNumber
Expand Down Expand Up @@ -289,6 +290,7 @@ func (d *recvMMsgDispatcher) dispatch() (bool, tcpip.Error) {
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: d.bufs[k].pullViews(n),
})
defer pkt.DecRef()

// Mark that this iovec has been processed.
d.msgHdrs[k].Msg.Iovlen = 0
Expand Down
1 change: 1 addition & 0 deletions pkg/tcpip/link/loopback/loopback.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ func (e *endpoint) WriteRawPacket(pkt *stack.PacketBuffer) tcpip.Error {
newPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: data,
})
defer newPkt.DecRef()
e.dispatcher.DeliverNetworkPacket("" /* remote */, "" /* local */, pkt.NetworkProtocolNumber, newPkt)

return nil
Expand Down
6 changes: 4 additions & 2 deletions pkg/tcpip/link/pipe/pipe.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,9 +59,11 @@ func (e *Endpoint) deliverPackets(r stack.RouteInfo, proto tcpip.NetworkProtocol
// avoid a deadlock when a packet triggers a response which leads the stack to
// try and take a lock it already holds.
for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
e.linked.dispatcher.DeliverNetworkPacket(r.LocalLinkAddress /* remote */, r.RemoteLinkAddress /* local */, proto, stack.NewPacketBuffer(stack.PacketBufferOptions{
newPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: buffer.NewVectorisedView(pkt.Size(), pkt.Views()),
}))
})
e.linked.dispatcher.DeliverNetworkPacket(r.LocalLinkAddress /* remote */, r.RemoteLinkAddress /* local */, proto, newPkt)
newPkt.DecRef()
}
}

Expand Down
4 changes: 1 addition & 3 deletions pkg/tcpip/link/qdisc/fifo/endpoint.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,9 +94,7 @@ func (q *queueDispatcher) dispatchLoop() {
// We pass a protocol of zero here because each packet carries its
// NetworkProtocol.
q.lower.WritePackets(stack.RouteInfo{}, batch, 0 /* protocol */)
for pkt := batch.Front(); pkt != nil; pkt = pkt.Next() {
batch.Remove(pkt)
}
batch.DecRef()
batch.Reset()
}
}
Expand Down
8 changes: 4 additions & 4 deletions pkg/tcpip/link/qdisc/fifo/packet_buffer_queue.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,13 +54,13 @@ func (q *packetBufferQueue) setLimit(limit int) {
// enqueue adds the given packet to the queue.
//
// Returns true when the PacketBuffer is successfully added to the queue, in
// which case ownership of the reference is transferred to the queue. And
// returns false if the queue is full, in which case ownership is retained by
// the caller.
// which case the queue acquires a reference to the PacketBuffer, and
// returns false if the queue is full.
func (q *packetBufferQueue) enqueue(s *stack.PacketBuffer) bool {
q.mu.Lock()
r := q.used < q.limit
if r {
s.IncRef()
q.list.PushBack(s)
q.used++
}
Expand All @@ -70,7 +70,7 @@ func (q *packetBufferQueue) enqueue(s *stack.PacketBuffer) bool {
}

// dequeue removes and returns the next PacketBuffer from queue, if one exists.
// Ownership is transferred to the caller.
// Caller is responsible for calling DecRef on the PacketBuffer.
func (q *packetBufferQueue) dequeue() *stack.PacketBuffer {
q.mu.Lock()
s := q.list.Front()
Expand Down
1 change: 1 addition & 0 deletions pkg/tcpip/link/sharedmem/sharedmem.go
Original file line number Diff line number Diff line change
Expand Up @@ -413,6 +413,7 @@ func (e *endpoint) dispatchLoop(d stack.NetworkDispatcher) {
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: buffer.View(b).ToVectorisedView(),
})
defer pkt.DecRef()

var src, dst tcpip.LinkAddress
var proto tcpip.NetworkProtocolNumber
Expand Down
4 changes: 4 additions & 0 deletions pkg/tcpip/link/sharedmem/sharedmem_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,7 @@ func (e *serverEndpoint) dispatchLoop(d stack.NetworkDispatcher) {
if e.addr != "" {
hdr, ok := pkt.LinkHeader().Consume(header.EthernetMinimumSize)
if !ok {
pkt.DecRef()
continue
}
eth := header.Ethernet(hdr)
Expand All @@ -323,6 +324,7 @@ func (e *serverEndpoint) dispatchLoop(d stack.NetworkDispatcher) {
// IP version information is at the first octet, so pulling up 1 byte.
h, ok := pkt.Data().PullUp(1)
if !ok {
pkt.DecRef()
continue
}
switch header.IPVersion(h) {
Expand All @@ -331,11 +333,13 @@ func (e *serverEndpoint) dispatchLoop(d stack.NetworkDispatcher) {
case header.IPv6Version:
proto = header.IPv6ProtocolNumber
default:
pkt.DecRef()
continue
}
}
// Send packet up the stack.
d.DeliverNetworkPacket(src, dst, proto, pkt)
pkt.DecRef()
}

e.mu.Lock()
Expand Down
1 change: 1 addition & 0 deletions pkg/tcpip/link/sniffer/sniffer.go
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,7 @@ func logPacket(prefix string, dir direction, protocol tcpip.NetworkProtocolNumbe
vv := buffer.NewVectorisedView(pkt.Size(), pkt.Views())
vv.TrimFront(len(pkt.LinkHeader().View()))
pkt = stack.NewPacketBuffer(stack.PacketBufferOptions{Data: vv})
defer pkt.DecRef()
switch protocol {
case header.IPv4ProtocolNumber:
if ok := parse.IPv4(pkt); !ok {
Expand Down
2 changes: 2 additions & 0 deletions pkg/tcpip/link/tun/device.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ func (d *Device) Release(ctx context.Context) {

// Decrease refcount if there is an endpoint associated with this file.
if d.endpoint != nil {
d.endpoint.Drain()
d.endpoint.RemoveNotify(d.notifyHandle)
d.endpoint.DecRef(ctx)
d.endpoint = nil
Expand Down Expand Up @@ -231,6 +232,7 @@ func (d *Device) Write(data []byte) (int64, error) {
ReserveHeaderBytes: len(ethHdr),
Data: buffer.View(data).ToVectorisedView(),
})
defer pkt.DecRef()
copy(pkt.LinkHeader().Push(len(ethHdr)), ethHdr)
endpoint.InjectLinkAddr(protocol, remote, pkt)
return dataLen, nil
Expand Down
2 changes: 2 additions & 0 deletions pkg/tcpip/network/arp/arp.go
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,7 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
respPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
ReserveHeaderBytes: int(e.nic.MaxHeaderLength()) + header.ARPSize,
})
defer respPkt.DecRef()
packet := header.ARP(respPkt.NetworkHeader().Push(header.ARPSize))
respPkt.NetworkProtocolNumber = ProtocolNumber
packet.SetIPv4OverEthernet()
Expand Down Expand Up @@ -339,6 +340,7 @@ func (e *endpoint) sendARPRequest(localAddr, targetAddr tcpip.Address, remoteLin
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
ReserveHeaderBytes: int(e.MaxHeaderLength()),
})
defer pkt.DecRef()
h := header.ARP(pkt.NetworkHeader().Push(header.ARPSize))
pkt.NetworkProtocolNumber = ProtocolNumber
h.SetIPv4OverEthernet()
Expand Down
2 changes: 2 additions & 0 deletions pkg/tcpip/network/internal/fragmentation/reassembler.go
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,7 @@ func (r *reassembler) process(first, last uint16, more bool, proto uint8, pkt *s
r.proto = proto
}

pkt.IncRef()
break
}
if !holeFound {
Expand All @@ -166,6 +167,7 @@ func (r *reassembler) process(first, last uint16, more bool, proto uint8, pkt *s
})

resPkt := r.holes[0].pkt
resPkt.DecRef()
for i := 1; i < len(r.holes); i++ {
stack.MergeFragment(resPkt, r.holes[i].pkt)
}
Expand Down
4 changes: 3 additions & 1 deletion pkg/tcpip/network/ipv4/icmp.go
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) {
case header.ICMPv4Echo:
received.echoRequest.Increment()

// DeliverTransportPacket will take ownership of pkt so don't use it beyond
// DeliverTransportPacket may modify pkt so don't use it beyond
// this point. Make a deep copy of the data before pkt gets sent as we will
// be modifying fields. Both the ICMP header (with its type modified to
// EchoReply) and payload are reused in the reply packet.
Expand Down Expand Up @@ -320,6 +320,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) {
ReserveHeaderBytes: int(r.MaxHeaderLength()),
Data: replyVV,
})
defer replyPkt.DecRef()
replyPkt.TransportProtocolNumber = header.ICMPv4ProtocolNumber

if err := r.WriteHeaderIncludedPacket(replyPkt); err != nil {
Expand Down Expand Up @@ -667,6 +668,7 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) tcpip
ReserveHeaderBytes: int(route.MaxHeaderLength()) + header.ICMPv4MinimumSize,
Data: payload,
})
defer icmpPkt.DecRef()

icmpPkt.TransportProtocolNumber = header.ICMPv4ProtocolNumber

Expand Down
1 change: 1 addition & 0 deletions pkg/tcpip/network/ipv4/igmp.go
Original file line number Diff line number Diff line change
Expand Up @@ -322,6 +322,7 @@ func (igmp *igmpState) writePacket(destAddress tcpip.Address, groupAddress tcpip
ReserveHeaderBytes: int(igmp.ep.MaxHeaderLength()),
Data: buffer.View(igmpData).ToVectorisedView(),
})
defer pkt.DecRef()

addressEndpoint := igmp.ep.acquireOutgoingPrimaryAddressRLocked(destAddress, false /* allowExpired */)
if addressEndpoint == nil {
Expand Down
6 changes: 5 additions & 1 deletion pkg/tcpip/network/ipv4/ipv4.go
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ func (e *endpoint) HandleLinkResolutionFailure(pkt *stack.PacketBuffer) {
pkt = stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: buffer.NewVectorisedView(pkt.Size(), pkt.Views()),
})
defer pkt.DecRef()
pkt.NICID = e.nic.ID()
pkt.NetworkProtocolNumber = ProtocolNumber
// Use the same control type as an ICMPv4 destination host unreachable error
Expand Down Expand Up @@ -534,6 +535,7 @@ func (e *endpoint) WritePackets(r *stack.Route, pkts stack.PacketBufferList, par
// removed once the fragmentation is done.
originalPkt := pkt
if _, _, err := e.handleFragments(r, networkMTU, pkt, func(fragPkt *stack.PacketBuffer) tcpip.Error {
fragPkt.IncRef()
// Modify the packet list in place with the new fragments.
pkts.InsertAfter(pkt, fragPkt)
pkt = fragPkt
Expand Down Expand Up @@ -751,10 +753,11 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) ip.ForwardingError {
}

// We need to do a deep copy of the IP packet because
// WriteHeaderIncludedPacket takes ownership of the packet buffer, but we do
// WriteHeaderIncludedPacket may modify the packet buffer, but we do
// not own it.
newPkt := pkt.DeepCopyForForwarding(int(r.MaxHeaderLength()))
newHdr := header.IPv4(newPkt.NetworkHeader().View())
defer newPkt.DecRef()

// As per RFC 791 page 30, Time to Live,
//
Expand Down Expand Up @@ -859,6 +862,7 @@ func (e *endpoint) handleLocalPacket(pkt *stack.PacketBuffer, canSkipRXChecksum
stats.PacketsReceived.Increment()

pkt = pkt.CloneToInbound()
defer pkt.DecRef()
pkt.RXTransportChecksumValidated = canSkipRXChecksum

h, ok := e.protocol.parseAndValidate(pkt)
Expand Down
3 changes: 3 additions & 0 deletions pkg/tcpip/network/ipv6/icmp.go
Original file line number Diff line number Diff line change
Expand Up @@ -525,6 +525,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
ReserveHeaderBytes: int(r.MaxHeaderLength()) + neighborAdvertSize,
})
defer pkt.DecRef()
pkt.TransportProtocolNumber = header.ICMPv6ProtocolNumber
packet := header.ICMPv6(pkt.TransportHeader().Push(neighborAdvertSize))
packet.SetType(header.ICMPv6NeighborAdvert)
Expand Down Expand Up @@ -675,6 +676,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r
ReserveHeaderBytes: int(r.MaxHeaderLength()) + header.ICMPv6EchoMinimumSize,
Data: pkt.Data().ExtractVV(),
})
defer replyPkt.DecRef()
icmp := header.ICMPv6(replyPkt.TransportHeader().Push(header.ICMPv6EchoMinimumSize))
pkt.TransportProtocolNumber = header.ICMPv6ProtocolNumber
copy(icmp, h)
Expand Down Expand Up @@ -1213,6 +1215,7 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) tcpip
ReserveHeaderBytes: int(route.MaxHeaderLength()) + header.ICMPv6ErrorHeaderSize,
Data: payload,
})
defer newPkt.DecRef()
newPkt.TransportProtocolNumber = header.ICMPv6ProtocolNumber

icmpHdr := header.ICMPv6(newPkt.TransportHeader().Push(header.ICMPv6DstUnreachableMinimumSize))
Expand Down
4 changes: 4 additions & 0 deletions pkg/tcpip/network/ipv6/ipv6.go
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,7 @@ func (e *endpoint) HandleLinkResolutionFailure(pkt *stack.PacketBuffer) {
pkt = stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: buffer.NewVectorisedView(pkt.Size(), pkt.Views()),
})
defer pkt.DecRef()
pkt.NICID = e.nic.ID()
pkt.NetworkProtocolNumber = ProtocolNumber
e.handleControl(&icmpv6DestinationAddressUnreachableSockError{}, pkt)
Expand Down Expand Up @@ -855,6 +856,7 @@ func (e *endpoint) WritePackets(r *stack.Route, pkts stack.PacketBufferList, par
// removed once the fragmentation is done.
originalPkt := pb
if _, _, err := e.handleFragments(r, networkMTU, pb, params.Protocol, func(fragPkt *stack.PacketBuffer) tcpip.Error {
fragPkt.IncRef()
// Modify the packet list in place with the new fragments.
pkts.InsertAfter(pb, fragPkt)
pb = fragPkt
Expand Down Expand Up @@ -1025,6 +1027,7 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) ip.ForwardingError {
// WriteHeaderIncludedPacket takes ownership of the packet buffer, but we do
// not own it.
newPkt := pkt.DeepCopyForForwarding(int(r.MaxHeaderLength()))
defer newPkt.DecRef()
newHdr := header.IPv6(newPkt.NetworkHeader().View())

// As per RFC 8200 section 3,
Expand Down Expand Up @@ -1118,6 +1121,7 @@ func (e *endpoint) handleLocalPacket(pkt *stack.PacketBuffer, canSkipRXChecksum
stats.PacketsReceived.Increment()

pkt = pkt.CloneToInbound()
defer pkt.DecRef()
pkt.RXTransportChecksumValidated = canSkipRXChecksum

h, ok := e.protocol.parseAndValidate(pkt)
Expand Down
1 change: 1 addition & 0 deletions pkg/tcpip/network/ipv6/mld.go
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,7 @@ func (mld *mldState) writePacket(destAddress, groupAddress tcpip.Address, mldTyp
ReserveHeaderBytes: int(mld.ep.MaxHeaderLength()) + extensionHeaders.Length(),
Data: buffer.View(icmp).ToVectorisedView(),
})
defer pkt.DecRef()

if err := addIPHeader(localAddress, destAddress, pkt, stack.NetworkHeaderParams{
Protocol: header.ICMPv6ProtocolNumber,
Expand Down
2 changes: 2 additions & 0 deletions pkg/tcpip/network/ipv6/ndp.go
Original file line number Diff line number Diff line change
Expand Up @@ -1807,6 +1807,7 @@ func (ndp *ndpState) startSolicitingRouters() {
ReserveHeaderBytes: int(ndp.ep.MaxHeaderLength()),
Data: buffer.View(icmpData).ToVectorisedView(),
})
defer pkt.DecRef()

sent := ndp.ep.stats.icmp.packetsSent
if err := addIPHeader(localAddr, header.IPv6AllRoutersLinkLocalMulticastAddress, pkt, stack.NetworkHeaderParams{
Expand Down Expand Up @@ -1924,6 +1925,7 @@ func (e *endpoint) sendNDPNS(srcAddr, dstAddr, targetAddr tcpip.Address, remoteL
ReserveHeaderBytes: int(e.MaxHeaderLength()),
Data: buffer.View(icmp).ToVectorisedView(),
})
defer pkt.DecRef()

if err := addIPHeader(srcAddr, dstAddr, pkt, stack.NetworkHeaderParams{
Protocol: header.ICMPv6ProtocolNumber,
Expand Down
13 changes: 13 additions & 0 deletions pkg/tcpip/stack/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,17 @@ go_template_instance(
},
)

go_template_instance(
name = "packet_buffer_refs",
out = "packet_buffer_refs.go",
package = "stack",
prefix = "packetBuffer",
template = "//pkg/refsvfs2:refs_template",
types = {
"T": "PacketBuffer",
},
)

go_library(
name = "stack",
srcs = [
Expand All @@ -59,6 +70,7 @@ go_library(
"nud.go",
"packet_buffer.go",
"packet_buffer_list.go",
"packet_buffer_refs.go",
"packet_buffer_unsafe.go",
"pending_packets.go",
"rand.go",
Expand All @@ -78,6 +90,7 @@ go_library(
"//pkg/ilist",
"//pkg/log",
"//pkg/rand",
"//pkg/refsvfs2",
"//pkg/sleep",
"//pkg/sync",
"//pkg/tcpip",
Expand Down

0 comments on commit 84b38f4

Please sign in to comment.