Skip to content

Commit

Permalink
ssh: eliminate some goroutine leaks in tests and examples
Browse files Browse the repository at this point in the history
This should fix the "Log in goroutine" panic seen in
https://build.golang.org/log/e42bf69fc002113dbccfe602a6c67fd52e8f31df,
as well as a few other related leaks. It also helps to verify that
none of the functions under test deadlock unexpectedly.

See https://go.dev/wiki/CodeReviewComments#goroutine-lifetimes.

Updates golang/go#58901.

Change-Id: Ica943444db381ae1accb80b101ea646e28ebf4f9
Reviewed-on: https://go-review.googlesource.com/c/crypto/+/541095
Auto-Submit: Bryan Mills <bcmills@google.com>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Reviewed-by: Nicola Murino <nicola.murino@gmail.com>
Reviewed-by: Heschi Kreinick <heschi@google.com>
  • Loading branch information
Bryan C. Mills authored and gopherbot committed Nov 9, 2023
1 parent eb61739 commit ff15cd5
Show file tree
Hide file tree
Showing 3 changed files with 124 additions and 56 deletions.
18 changes: 16 additions & 2 deletions ssh/example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import (
"os"
"path/filepath"
"strings"
"sync"

"golang.org/x/crypto/ssh"
"golang.org/x/crypto/ssh/terminal"
Expand Down Expand Up @@ -98,8 +99,15 @@ func ExampleNewServerConn() {
}
log.Printf("logged in with key %s", conn.Permissions.Extensions["pubkey-fp"])

var wg sync.WaitGroup
defer wg.Wait()

// The incoming Request channel must be serviced.
go ssh.DiscardRequests(reqs)
wg.Add(1)
go func() {
ssh.DiscardRequests(reqs)
wg.Done()
}()

// Service the incoming Channel channel.
for newChannel := range chans {
Expand All @@ -119,16 +127,22 @@ func ExampleNewServerConn() {
// Sessions have out-of-band requests such as "shell",
// "pty-req" and "env". Here we handle only the
// "shell" request.
wg.Add(1)
go func(in <-chan *ssh.Request) {
for req := range in {
req.Reply(req.Type == "shell", nil)
}
wg.Done()
}(requests)

term := terminal.NewTerminal(channel, "> ")

wg.Add(1)
go func() {
defer channel.Close()
defer func() {
channel.Close()
wg.Done()
}()
for {
line, err := term.ReadLine()
if err != nil {
Expand Down
91 changes: 49 additions & 42 deletions ssh/mux_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ import (
"io"
"sync"
"testing"
"time"
)

func muxPair() (*mux, *mux) {
Expand Down Expand Up @@ -112,7 +111,11 @@ func TestMuxReadWrite(t *testing.T) {

magic := "hello world"
magicExt := "hello stderr"
var wg sync.WaitGroup
t.Cleanup(wg.Wait)
wg.Add(1)
go func() {
defer wg.Done()
_, err := s.Write([]byte(magic))
if err != nil {
t.Errorf("Write: %v", err)
Expand Down Expand Up @@ -152,13 +155,15 @@ func TestMuxChannelOverflow(t *testing.T) {
defer writer.Close()
defer mux.Close()

wDone := make(chan int, 1)
var wg sync.WaitGroup
t.Cleanup(wg.Wait)
wg.Add(1)
go func() {
defer wg.Done()
if _, err := writer.Write(make([]byte, channelWindowSize)); err != nil {
t.Errorf("could not fill window: %v", err)
}
writer.Write(make([]byte, 1))
wDone <- 1
}()
writer.remoteWin.waitWriterBlocked()

Expand All @@ -175,7 +180,6 @@ func TestMuxChannelOverflow(t *testing.T) {
if _, err := reader.SendRequest("hello", true, nil); err == nil {
t.Errorf("SendRequest succeeded.")
}
<-wDone
}

func TestMuxChannelCloseWriteUnblock(t *testing.T) {
Expand All @@ -184,20 +188,21 @@ func TestMuxChannelCloseWriteUnblock(t *testing.T) {
defer writer.Close()
defer mux.Close()

wDone := make(chan int, 1)
var wg sync.WaitGroup
t.Cleanup(wg.Wait)
wg.Add(1)
go func() {
defer wg.Done()
if _, err := writer.Write(make([]byte, channelWindowSize)); err != nil {
t.Errorf("could not fill window: %v", err)
}
if _, err := writer.Write(make([]byte, 1)); err != io.EOF {
t.Errorf("got %v, want EOF for unblock write", err)
}
wDone <- 1
}()

writer.remoteWin.waitWriterBlocked()
reader.Close()
<-wDone
}

func TestMuxConnectionCloseWriteUnblock(t *testing.T) {
Expand All @@ -206,28 +211,34 @@ func TestMuxConnectionCloseWriteUnblock(t *testing.T) {
defer writer.Close()
defer mux.Close()

wDone := make(chan int, 1)
var wg sync.WaitGroup
t.Cleanup(wg.Wait)
wg.Add(1)
go func() {
defer wg.Done()
if _, err := writer.Write(make([]byte, channelWindowSize)); err != nil {
t.Errorf("could not fill window: %v", err)
}
if _, err := writer.Write(make([]byte, 1)); err != io.EOF {
t.Errorf("got %v, want EOF for unblock write", err)
}
wDone <- 1
}()

writer.remoteWin.waitWriterBlocked()
mux.Close()
<-wDone
}

func TestMuxReject(t *testing.T) {
client, server := muxPair()
defer server.Close()
defer client.Close()

var wg sync.WaitGroup
t.Cleanup(wg.Wait)
wg.Add(1)
go func() {
defer wg.Done()

ch, ok := <-server.incomingChannels
if !ok {
t.Error("cannot accept channel")
Expand Down Expand Up @@ -267,6 +278,7 @@ func TestMuxChannelRequest(t *testing.T) {

var received int
var wg sync.WaitGroup
t.Cleanup(wg.Wait)
wg.Add(1)
go func() {
for r := range server.incomingRequests {
Expand Down Expand Up @@ -295,7 +307,6 @@ func TestMuxChannelRequest(t *testing.T) {
}
if ok {
t.Errorf("SendRequest(no): %v", ok)

}

client.Close()
Expand Down Expand Up @@ -389,27 +400,18 @@ func TestMuxUnknownChannelRequests(t *testing.T) {

// Wait for the server to send the keepalive message and receive back a
// response.
select {
case err := <-kDone:
if err != nil {
t.Fatal(err)
}
case <-time.After(10 * time.Second):
t.Fatalf("server never received ack")
if err := <-kDone; err != nil {
t.Fatal(err)
}

// Confirm client hasn't closed.
if _, _, err := client.SendRequest("keepalive@golang.org", true, nil); err != nil {
t.Fatalf("failed to send keepalive: %v", err)
}

select {
case err := <-kDone:
if err != nil {
t.Fatal(err)
}
case <-time.After(10 * time.Second):
t.Fatalf("server never shut down")
// Wait for the server to shut down.
if err := <-kDone; err != nil {
t.Fatal(err)
}
}

Expand Down Expand Up @@ -525,11 +527,7 @@ func TestMuxClosedChannel(t *testing.T) {
defer ch.Close()

// Wait for the server to close the channel and send the keepalive.
select {
case <-kDone:
case <-time.After(10 * time.Second):
t.Fatalf("server never received ack")
}
<-kDone

// Make sure the channel closed.
if _, ok := <-ch.incomingRequests; ok {
Expand All @@ -541,22 +539,29 @@ func TestMuxClosedChannel(t *testing.T) {
t.Fatalf("failed to send keepalive: %v", err)
}

select {
case <-kDone:
case <-time.After(10 * time.Second):
t.Fatalf("server never shut down")
}
// Wait for the server to shut down.
<-kDone
}

func TestMuxGlobalRequest(t *testing.T) {
var sawPeek bool
var wg sync.WaitGroup
defer func() {
wg.Wait()
if !sawPeek {
t.Errorf("never saw 'peek' request")
}
}()

clientMux, serverMux := muxPair()
defer serverMux.Close()
defer clientMux.Close()

var seen bool
wg.Add(1)
go func() {
defer wg.Done()
for r := range serverMux.incomingRequests {
seen = seen || r.Type == "peek"
sawPeek = sawPeek || r.Type == "peek"
if r.WantReply {
err := r.Reply(r.Type == "yes",
append([]byte(r.Type), r.Payload...))
Expand Down Expand Up @@ -586,10 +591,6 @@ func TestMuxGlobalRequest(t *testing.T) {
t.Errorf("SendRequest(\"no\", true, \"a\"): %v %v %v",
ok, data, err)
}

if !seen {
t.Errorf("never saw 'peek' request")
}
}

func TestMuxGlobalRequestUnblock(t *testing.T) {
Expand Down Expand Up @@ -739,7 +740,13 @@ func TestMuxMaxPacketSize(t *testing.T) {
t.Errorf("could not send packet")
}

go a.SendRequest("hello", false, nil)
var wg sync.WaitGroup
t.Cleanup(wg.Wait)
wg.Add(1)
go func() {
a.SendRequest("hello", false, nil)
wg.Done()
}()

_, ok := <-b.incomingRequests
if ok {
Expand Down
Loading

0 comments on commit ff15cd5

Please sign in to comment.