Skip to content

Commit

Permalink
quic: remove streams from the conn when done
Browse files Browse the repository at this point in the history
When a stream has been fully shut down--the peer has closed
its end and acked every frame we will send for it--remove
it from the Conn's set of active streams.

We do the actual removal on the conn's loop, so stream cleanup
can access conn state without worrying about locking.

For golang/go#58547

Change-Id: Id9715693649929b07d303f0c4b3a782d135f0326
Reviewed-on: https://go-review.googlesource.com/c/net/+/524296
Reviewed-by: Jonathan Amsterdam <jba@google.com>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
  • Loading branch information
neild committed Sep 1, 2023
1 parent 03d5e62 commit 97384c1
Show file tree
Hide file tree
Showing 6 changed files with 315 additions and 44 deletions.
33 changes: 33 additions & 0 deletions internal/quic/atomic_bits.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
// Copyright 2023 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

//go:build go1.21

package quic

import "sync/atomic"

// atomicBits is an atomic uint32 that supports setting individual bits.
type atomicBits[T ~uint32] struct {
bits atomic.Uint32
}

// set sets the bits in mask to the corresponding bits in v.
// It returns the new value.
func (a *atomicBits[T]) set(v, mask T) T {
if v&^mask != 0 {
panic("BUG: bits in v are not in mask")
}
for {
o := a.bits.Load()
n := (o &^ uint32(mask)) | uint32(v)
if a.bits.CompareAndSwap(o, n) {
return T(n)
}
}
}

func (a *atomicBits[T]) load() T {
return T(a.bits.Load())
}
62 changes: 45 additions & 17 deletions internal/quic/conn_streams.go
Original file line number Diff line number Diff line change
Expand Up @@ -185,24 +185,46 @@ func (c *Conn) appendStreamFrames(w *packetWriter, pnum packetNumber, pto bool)
for {
s := c.streams.sendHead
const pto = false
if !s.appendInFrames(w, pnum, pto) {
return false

state := s.state.load()
if state&streamInSend != 0 {
s.ingate.lock()
ok := s.appendInFramesLocked(w, pnum, pto)
state = s.inUnlockNoQueue()
if !ok {
return false
}
}
avail := w.avail()
if !s.appendOutFrames(w, pnum, pto) {
// We've sent some data for this stream, but it still has more to send.
// If the stream got a reasonable chance to put data in a packet,
// advance sendHead to the next stream in line, to avoid starvation.
// We'll come back to this stream after going through the others.
//
// If the packet was already mostly out of space, leave sendHead alone
// and come back to this stream again on the next packet.
if avail > 512 {
c.streams.sendHead = s.next
c.streams.sendTail = s

if state&streamOutSend != 0 {
avail := w.avail()
s.outgate.lock()
ok := s.appendOutFramesLocked(w, pnum, pto)
state = s.outUnlockNoQueue()
if !ok {
// We've sent some data for this stream, but it still has more to send.
// If the stream got a reasonable chance to put data in a packet,
// advance sendHead to the next stream in line, to avoid starvation.
// We'll come back to this stream after going through the others.
//
// If the packet was already mostly out of space, leave sendHead alone
// and come back to this stream again on the next packet.
if avail > 512 {
c.streams.sendHead = s.next
c.streams.sendTail = s
}
return false
}
return false
}

if state == streamInDone|streamOutDone {
// Stream is finished, remove it from the conn.
s.state.set(streamConnRemoved, streamConnRemoved)
delete(c.streams.streams, s.id)

// TODO: Provide the peer with additional stream quota (MAX_STREAMS).
}

next := s.next
s.next = nil
if (next == s) != (s == c.streams.sendTail) {
Expand Down Expand Up @@ -231,10 +253,16 @@ func (c *Conn) appendStreamFramesPTO(w *packetWriter, pnum packetNumber) bool {
defer c.streams.sendMu.Unlock()
for _, s := range c.streams.streams {
const pto = true
if !s.appendInFrames(w, pnum, pto) {
s.ingate.lock()
inOK := s.appendInFramesLocked(w, pnum, pto)
s.inUnlockNoQueue()
if !inOK {
return false
}
if !s.appendOutFrames(w, pnum, pto) {
s.outgate.lock()
outOK := s.appendOutFramesLocked(w, pnum, pto)
s.outUnlockNoQueue()
if !outOK {
return false
}
}
Expand Down
89 changes: 89 additions & 0 deletions internal/quic/conn_streams_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ package quic

import (
"context"
"fmt"
"io"
"testing"
)

Expand Down Expand Up @@ -253,3 +255,90 @@ func TestStreamsWriteQueueFairness(t *testing.T) {
}
}
}

func TestStreamsShutdown(t *testing.T) {
// These tests verify that a stream is removed from the Conn's map of live streams
// after it is fully shut down.
//
// Each case consists of a setup step, after which one stream should exist,
// and a shutdown step, after which no streams should remain in the Conn.
for _, test := range []struct {
name string
side streamSide
styp streamType
setup func(*testing.T, *testConn, *Stream)
shutdown func(*testing.T, *testConn, *Stream)
}{{
name: "closed",
side: localStream,
styp: uniStream,
setup: func(t *testing.T, tc *testConn, s *Stream) {
s.CloseContext(canceledContext())
},
shutdown: func(t *testing.T, tc *testConn, s *Stream) {
tc.writeAckForAll()
},
}, {
name: "local close",
side: localStream,
styp: bidiStream,
setup: func(t *testing.T, tc *testConn, s *Stream) {
tc.writeFrames(packetType1RTT, debugFrameResetStream{
id: s.id,
})
s.CloseContext(canceledContext())
},
shutdown: func(t *testing.T, tc *testConn, s *Stream) {
tc.writeAckForAll()
},
}, {
name: "remote reset",
side: localStream,
styp: bidiStream,
setup: func(t *testing.T, tc *testConn, s *Stream) {
s.CloseContext(canceledContext())
tc.wantIdle("all frames after CloseContext are ignored")
tc.writeAckForAll()
},
shutdown: func(t *testing.T, tc *testConn, s *Stream) {
tc.writeFrames(packetType1RTT, debugFrameResetStream{
id: s.id,
})
},
}, {
name: "local close",
side: remoteStream,
styp: uniStream,
setup: func(t *testing.T, tc *testConn, s *Stream) {
ctx := canceledContext()
tc.writeFrames(packetType1RTT, debugFrameStream{
id: s.id,
fin: true,
})
if n, err := s.ReadContext(ctx, make([]byte, 16)); n != 0 || err != io.EOF {
t.Errorf("ReadContext() = %v, %v; want 0, io.EOF", n, err)
}
},
shutdown: func(t *testing.T, tc *testConn, s *Stream) {
s.CloseRead()
},
}} {
name := fmt.Sprintf("%v/%v/%v", test.side, test.styp, test.name)
t.Run(name, func(t *testing.T) {
tc, s := newTestConnAndStream(t, serverSide, test.side, test.styp,
permissiveTransportParameters)
tc.ignoreFrame(frameTypeStreamBase)
tc.ignoreFrame(frameTypeStopSending)
test.setup(t, tc, s)
tc.wantIdle("conn should be idle after setup")
if got, want := len(tc.conn.streams.streams), 1; got != want {
t.Fatalf("after setup: %v streams in Conn's map; want %v", got, want)
}
test.shutdown(t, tc, s)
tc.wantIdle("conn should be idle after shutdown")
if got, want := len(tc.conn.streams.streams), 0; got != want {
t.Fatalf("after shutdown: %v streams in Conn's map; want %v", got, want)
}
})
}
}
2 changes: 2 additions & 0 deletions internal/quic/conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -394,6 +394,7 @@ func (tc *testConn) writeFrames(ptype packetType, frames ...debugFrame) {
// writeAckForAll sends the Conn a datagram containing an ack for all packets up to the
// last one received.
func (tc *testConn) writeAckForAll() {
tc.t.Helper()
if tc.lastPacket == nil {
return
}
Expand All @@ -405,6 +406,7 @@ func (tc *testConn) writeAckForAll() {
// writeAckForLatest sends the Conn a datagram containing an ack for the
// most recent packet received.
func (tc *testConn) writeAckForLatest() {
tc.t.Helper()
if tc.lastPacket == nil {
return
}
Expand Down
Loading

0 comments on commit 97384c1

Please sign in to comment.