Skip to content

Commit

Permalink
[v9] Fix listing all nodes in tsh (#19823)
Browse files Browse the repository at this point in the history
[v9] Fix listing all nodes in tsh  (#19823)

Usage of channels was flipped, we tried to write to collecting channel,
but nobody was reading from it, so we blocked forever. Now using simpler
version with mutex for synchronization, and doing it for db listings as
well for consistency.
  • Loading branch information
AntonAM committed Jan 4, 2023
1 parent 3c404fc commit 849adba
Show file tree
Hide file tree
Showing 4 changed files with 116 additions and 53 deletions.
27 changes: 9 additions & 18 deletions tool/tsh/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import (
"os"
"sort"
"strings"
"sync"

"golang.org/x/sync/errgroup"

Expand Down Expand Up @@ -145,20 +146,9 @@ func listDatabasesAllClusters(cf *CLIConf) error {
group, groupCtx := errgroup.WithContext(cf.Context)
group.SetLimit(4)

dbListingsResultChan := make(chan databaseListings)
dbListingsCollectChan := make(chan databaseListings)
go func() {
var dbListings databaseListings
for {
select {
case items := <-dbListingsCollectChan:
dbListings = append(dbListings, items...)
case <-groupCtx.Done():
dbListingsResultChan <- dbListings
return
}
}
}()
// mu guards access to dbListings
var mu sync.Mutex
var dbListings databaseListings

err := forEachProfile(cf, func(tc *client.TeleportClient, profile *client.ProfileStatus) error {
group.Go(func() error {
Expand All @@ -173,7 +163,6 @@ func listDatabasesAllClusters(cf *CLIConf) error {
return trace.Wrap(err)
}

var dbListings databaseListings
for _, site := range sites {
databases, err := proxy.FindDatabasesByFiltersForCluster(groupCtx, *tc.DefaultResourceFilter(), site.Name)
if err != nil {
Expand All @@ -188,17 +177,20 @@ func listDatabasesAllClusters(cf *CLIConf) error {
}
}

localDBListings := make(databaseListings, 0, len(databases))
for _, database := range databases {
dbListings = append(dbListings, databaseListing{
localDBListings = append(localDBListings, databaseListing{
Proxy: profile.ProxyURL.Host,
Cluster: site.Name,
roleSet: roleSet,
Database: database,
})
}
mu.Lock()
dbListings = append(dbListings, localDBListings...)
mu.Unlock()
}

dbListingsCollectChan <- dbListings
return nil
})
return nil
Expand All @@ -211,7 +203,6 @@ func listDatabasesAllClusters(cf *CLIConf) error {
return trace.Wrap(err)
}

dbListings := <-dbListingsResultChan
sort.Sort(dbListings)

profile, err := client.StatusCurrent(cf.HomePath, cf.Proxy, cf.IdentityFileIn)
Expand Down
54 changes: 54 additions & 0 deletions tool/tsh/proxy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import (
"path"
"path/filepath"
"strconv"
"strings"
"testing"
"time"

Expand Down Expand Up @@ -385,6 +386,59 @@ func TestTSHConfigConnectWithOpenSSHClient(t *testing.T) {
}
}

// TestList verifies "tsh ls" functionality
func TestList(t *testing.T) {
t.Parallel()

isInsecure := lib.IsInsecureDevMode()
lib.SetInsecureDevMode(true)
t.Cleanup(func() {
lib.SetInsecureDevMode(isInsecure)
})

s := newTestSuite(t,
withRootConfigFunc(func(cfg *service.Config) {
cfg.Version = defaults.TeleportConfigVersionV2
cfg.Auth.NetworkingConfig.SetProxyListenerMode(types.ProxyListenerMode_Multiplex)
}),
withLeafCluster(),
)
rootNodeAddress, err := s.root.NodeSSHAddr()
require.NoError(t, err)

testCases := []struct {
description string
command []string
resultNodes []string
}{
{
description: "List root cluster nodes",
command: []string{"ls"},
resultNodes: []string{"localnode " + rootNodeAddress.String()},
},
{
description: "List all clusters nodes",
command: []string{"ls", "-R"},
resultNodes: []string{"leaf1 localnode", "localhost localnode"},
},
}

for _, test := range testCases {
t.Run(test.description, func(t *testing.T) {
tshHome, _ := mustLogin(t, s)
stdout := new(bytes.Buffer)

err := Run(context.Background(), test.command, setHomePath(tshHome), setOverrideStdout(stdout))

require.NoError(t, err)
require.Equal(t, len(test.resultNodes), len(strings.Split(stdout.String(), "\n"))-4) // 4 - unimportant new lines
for _, node := range test.resultNodes {
require.Contains(t, stdout.String(), node)
}
})
}
}

func createAgent(t *testing.T) string {
t.Helper()

Expand Down
80 changes: 45 additions & 35 deletions tool/tsh/tsh.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import (
"runtime"
"sort"
"strings"
"sync"
"syscall"
"time"

Expand Down Expand Up @@ -1594,7 +1595,7 @@ func onListNodes(cf *CLIConf) error {
return nodes[i].GetHostname() < nodes[j].GetHostname()
})

if err := printNodes(nodes, cf.Format, cf.Verbose); err != nil {
if err := printNodes(nodes, cf); err != nil {
return trace.Wrap(err)
}

Expand Down Expand Up @@ -1633,20 +1634,9 @@ func listNodesAllClusters(cf *CLIConf) error {
group, groupCtx := errgroup.WithContext(cf.Context)
group.SetLimit(4)

nodeListingsResultChan := make(chan nodeListings)
nodeListingsCollectChan := make(chan nodeListings)
go func() {
var listings nodeListings
for {
select {
case items := <-nodeListingsResultChan:
listings = append(listings, items...)
case <-groupCtx.Done():
nodeListingsCollectChan <- listings
return
}
}
}()
// mu guards access to listings
var mu sync.Mutex
var listings nodeListings

err := forEachProfile(cf, func(tc *client.TeleportClient, profile *client.ProfileStatus) error {
group.Go(func() error {
Expand All @@ -1661,23 +1651,25 @@ func listNodesAllClusters(cf *CLIConf) error {
return trace.Wrap(err)
}

var listings nodeListings
for _, site := range sites {
nodes, err := proxy.FindNodesByFiltersForCluster(groupCtx, *tc.DefaultResourceFilter(), site.Name)
if err != nil {
return trace.Wrap(err)
}

localListings := make(nodeListings, 0, len(nodes))
for _, node := range nodes {
listings = append(listings, nodeListing{
localListings = append(localListings, nodeListing{
Proxy: profile.ProxyURL.Host,
Cluster: site.Name,
Node: node,
})
}
mu.Lock()
listings = append(listings, localListings...)
mu.Unlock()
}

nodeListingsCollectChan <- listings
return nil
})

Expand All @@ -1691,26 +1683,30 @@ func listNodesAllClusters(cf *CLIConf) error {
return trace.Wrap(err)
}

listings := <-nodeListingsResultChan
sort.Sort(listings)

format := strings.ToLower(cf.Format)
switch format {
case teleport.Text, "":
printNodesWithClusters(listings, cf.Verbose)
if err := printNodesWithClusters(listings, cf.Verbose, cf.Stdout()); err != nil {
return trace.Wrap(err)
}
case teleport.JSON, teleport.YAML:
out, err := serializeNodesWithClusters(listings, format)
if err != nil {
return trace.Wrap(err)
}
fmt.Println(out)

if _, err := fmt.Fprintln(cf.Stdout(), out); err != nil {
return trace.Wrap(err)
}
default:

}
return nil
}

func printNodesWithClusters(nodes []nodeListing, verbose bool) {
func printNodesWithClusters(nodes []nodeListing, verbose bool, output io.Writer) error {
var rows [][]string
for _, n := range nodes {
rows = append(rows, getNodeRow(n.Proxy, n.Cluster, n.Node, verbose))
Expand All @@ -1721,7 +1717,10 @@ func printNodesWithClusters(nodes []nodeListing, verbose bool) {
} else {
t = asciitable.MakeTableWithTruncatedColumn([]string{"Proxy", "Cluster", "Node Name", "Address", "Labels"}, rows, "Labels")
}
fmt.Println(t.AsBuffer().String())
if _, err := fmt.Fprintln(output, t.AsBuffer().String()); err != nil {
return trace.Wrap(err)
}
return nil
}

func serializeNodesWithClusters(nodes []nodeListing, format string) (string, error) {
Expand Down Expand Up @@ -1840,23 +1839,30 @@ func executeAccessRequest(cf *CLIConf, tc *client.TeleportClient) error {
return trace.Wrap(<-errChan)
}

func printNodes(nodes []types.Server, format string, verbose bool) error {
format = strings.ToLower(format)
func printNodes(nodes []types.Server, conf *CLIConf) error {
format := strings.ToLower(conf.Format)
switch format {
case teleport.Text, "":
printNodesAsText(nodes, verbose)
if err := printNodesAsText(conf.Stdout(), nodes, conf.Verbose); err != nil {
return trace.Wrap(err)
}
case teleport.JSON, teleport.YAML:
out, err := serializeNodes(nodes, format)
if err != nil {
return trace.Wrap(err)
}
fmt.Println(out)
if _, err := fmt.Fprintln(conf.Stdout(), out); err != nil {
return trace.Wrap(err)
}
case teleport.Names:
for _, n := range nodes {
fmt.Println(n.GetHostname())
if _, err := fmt.Fprintln(conf.Stdout(), n.GetHostname()); err != nil {
return trace.Wrap(err)
}
}
default:
return trace.BadParameter("unsupported format %q", format)

}

return nil
Expand Down Expand Up @@ -1898,7 +1904,7 @@ func getNodeRow(proxy, cluster string, node types.Server, verbose bool) []string
return row
}

func printNodesAsText(nodes []types.Server, verbose bool) {
func printNodesAsText(output io.Writer, nodes []types.Server, verbose bool) error {
var rows [][]string
for _, n := range nodes {
rows = append(rows, getNodeRow("", "", n, verbose))
Expand All @@ -1914,7 +1920,11 @@ func printNodesAsText(nodes []types.Server, verbose bool) {
case false:
t = asciitable.MakeTableWithTruncatedColumn([]string{"Node Name", "Address", "Labels"}, rows, "Labels")
}
fmt.Println(t.AsBuffer().String())
if _, err := fmt.Fprintln(output, t.AsBuffer().String()); err != nil {
return trace.Wrap(err)
}

return nil
}

func sortedLabels(labels map[string]string) string {
Expand Down Expand Up @@ -2331,11 +2341,11 @@ func onSSH(cf *CLIConf) error {
nodes = append(nodes, node)
}
}
fmt.Fprintf(os.Stderr, "error: ambiguous host could match multiple nodes\n\n")
printNodesAsText(nodes, true)
fmt.Fprintf(os.Stderr, "Hint: try addressing the node by unique id (ex: tsh ssh user@node-id)\n")
fmt.Fprintf(os.Stderr, "Hint: use 'tsh ls -v' to list all nodes with their unique ids\n")
fmt.Fprintf(os.Stderr, "\n")
fmt.Fprintf(cf.Stderr(), "error: ambiguous host could match multiple nodes\n\n")
printNodesAsText(cf.Stderr(), nodes, true)
fmt.Fprintf(cf.Stderr(), "Hint: try addressing the node by unique id (ex: tsh ssh user@node-id)\n")
fmt.Fprintf(cf.Stderr(), "Hint: use 'tsh ls -v' to list all nodes with their unique ids\n")
fmt.Fprintf(cf.Stderr(), "\n")
return trace.Wrap(&exitCodeError{code: 1})
}
// exit with the same exit status as the failed command:
Expand Down
8 changes: 8 additions & 0 deletions tool/tsh/tsh_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"bytes"
"context"
"fmt"
"io"
"io/ioutil"
"net"
"net/url"
Expand Down Expand Up @@ -1574,6 +1575,13 @@ func mockSSOLogin(t *testing.T, authServer *auth.Server, user types.User) client
}
}

func setOverrideStdout(stdout io.Writer) cliOption {
return func(cf *CLIConf) error {
cf.overrideStdout = stdout
return nil
}
}

func setHomePath(path string) cliOption {
return func(cf *CLIConf) error {
cf.HomePath = path
Expand Down

0 comments on commit 849adba

Please sign in to comment.