Skip to content

Commit

Permalink
Add unit tests for testing hubble args/flags handling
Browse files Browse the repository at this point in the history
Signed-off-by: Chance Zibolski <chance.zibolski@gmail.com>
  • Loading branch information
chancez committed Jan 25, 2023
1 parent 904285a commit 950e9c5
Show file tree
Hide file tree
Showing 10 changed files with 341 additions and 39 deletions.
108 changes: 108 additions & 0 deletions cmd/cli_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
// SPDX-License-Identifier: Apache-2.0
// Copyright Authors of Hubble

package cmd

import (
"bytes"
"context"
"io/ioutil"
"testing"

"github.com/cilium/cilium/api/v1/observer"
"github.com/cilium/hubble/cmd/observe"
"github.com/spf13/viper"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

var expectedObserveHelp string

func init() {
// Override the client so that it always returns an IOReaderObserver with no flows.
observe.GetHubbleClientFunc = func(_ context.Context, _ *viper.Viper) (client observer.ObserverClient, cleanup func() error, err error) {
cleanup = func() error { return nil }
return observe.NewIOReaderObserver(bytes.NewBuffer([]byte(``))), cleanup, nil
}

// Separate file because it has more sensitive whitespace/is bigger than we
// would want to include inline with the tests.
b, err := ioutil.ReadFile("observe_help.txt")
if err != nil {
panic(err)
}
expectedObserveHelp = string(b)
}

func TestTestHubbleObserve(t *testing.T) {
tests := []struct {
name string
args []string
expectErr error
expectedOutput string
}{
{
name: "observe no flags",
args: []string{"observe"},
},
{
name: "observe formatting flags",
args: []string{"observe", "-o", "json"},
},
{
name: "observe server flags",
args: []string{"observe", "--server", "foo.example.org", "--tls", "--tls-allow-insecure"},
},
{
name: "observe filter flags",
args: []string{"observe", "--from-pod", "foo/test-pod-1234", "--type", "l7"},
},
{
name: "help",
args: []string{"--help"},
expectedOutput: `Hubble is a utility to observe and inspect recent Cilium routed traffic in a cluster.
Usage:
hubble [command]
Available Commands:
completion Generate the autocompletion script for the specified shell
config Modify or view hubble config
help Help about any command
list List Hubble objects
observe Observe flows of a Hubble server
status Display status of Hubble server
version Display detailed version information
Global Flags:
--config string Optional config file (default "/Users/chancezibolski/Library/Application Support/hubble/config.yaml")
-D, --debug Enable debug messages
Get help:
-h, --help Help for any command or subcommand
Use "hubble [command] --help" for more information about a command.
`,
},
{
name: "observe help",
args: []string{"observe", "--help"},
expectedOutput: expectedObserveHelp,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var b bytes.Buffer
cli := New()
cli.SetOut(&b)
cli.SetArgs(tt.args)
err := cli.Execute()
require.Equal(t, tt.expectErr, err)
output := b.String()
if tt.expectedOutput != "" {
assert.Equal(t, tt.expectedOutput, output, "expected output does not match")
}
})
}
}
2 changes: 1 addition & 1 deletion cmd/observe/agent_events.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ func newAgentEventsCommand(vp *viper.Viper, flagSets ...*pflag.FlagSet) *cobra.C
Short: "Observe Cilium agent events",
RunE: func(cmd *cobra.Command, _ []string) error {
debug := vp.GetBool(config.KeyDebug)
if err := handleEventsArgs(debug); err != nil {
if err := handleEventsArgs(cmd.OutOrStdout(), debug); err != nil {
return err
}
req, err := getAgentEventsRequest()
Expand Down
2 changes: 1 addition & 1 deletion cmd/observe/debug_events.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ func newDebugEventsCommand(vp *viper.Viper, flagSets ...*pflag.FlagSet) *cobra.C
Short: "Observe Cilium debug events",
RunE: func(cmd *cobra.Command, _ []string) error {
debug := vp.GetBool(config.KeyDebug)
if err := handleEventsArgs(debug); err != nil {
if err := handleEventsArgs(cmd.OutOrStdout(), debug); err != nil {
return err
}
req, err := getDebugEventsRequest()
Expand Down
4 changes: 3 additions & 1 deletion cmd/observe/events.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,17 @@ package observe

import (
"fmt"
"io"

"github.com/cilium/hubble/cmd/common/config"
hubprinter "github.com/cilium/hubble/pkg/printer"
hubtime "github.com/cilium/hubble/pkg/time"
)

func handleEventsArgs(debug bool) error {
func handleEventsArgs(writer io.Writer, debug bool) error {
// initialize the printer with any options that were passed in
var opts = []hubprinter.Option{
hubprinter.Writer(writer),
hubprinter.WithTimeFormat(hubtime.FormatNameToLayout(formattingOpts.timeFormat)),
}

Expand Down
44 changes: 27 additions & 17 deletions cmd/observe/flows.go
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,28 @@ func getFlowFiltersYAML(req *observer.GetFlowsRequest) (string, error) {
return string(out), nil
}

// GetHubbleClientFunc is primarily used to mock out the hubble client in some unit tests.
var GetHubbleClientFunc = func(ctx context.Context, vp *viper.Viper) (client observer.ObserverClient, cleanup func() error, err error) {
fi, err := os.Stdin.Stat()
if err != nil {
return nil, nil, err
}
if fi.Mode()&os.ModeCharDevice == 0 {
// read flows from stdin
client = NewIOReaderObserver(os.Stdin)
logger.Logger.Debug("Reading flows from stdin")
} else {
// read flows from a hubble server
hubbleConn, err := conn.New(ctx, vp.GetString(config.KeyServer), vp.GetDuration(config.KeyTimeout))
if err != nil {
return nil, nil, err
}
cleanup = hubbleConn.Close
client = observer.NewObserverClient(hubbleConn)
}
return client, cleanup, nil
}

func newFlowsCmd(vp *viper.Viper, ofilter *flowFilter) *cobra.Command {
observeCmd := &cobra.Command{
Example: `* Piping flows to hubble observe
Expand All @@ -160,7 +182,7 @@ individual pods, services, TCP connections, DNS queries, HTTP requests and
more.`,
RunE: func(cmd *cobra.Command, args []string) error {
debug := vp.GetBool(config.KeyDebug)
if err := handleFlowArgs(ofilter, debug); err != nil {
if err := handleFlowArgs(cmd.OutOrStdout(), ofilter, debug); err != nil {
return err
}
req, err := getFlowsRequest(ofilter, vp.GetStringSlice(allowlistFlag), vp.GetStringSlice(denylistFlag))
Expand All @@ -179,24 +201,11 @@ more.`,
ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt, os.Kill)
defer cancel()

var client observer.ObserverClient
fi, err := os.Stdin.Stat()
client, cleanup, err := GetHubbleClientFunc(ctx, vp)
if err != nil {
return err
}
if fi.Mode()&os.ModeCharDevice == 0 {
// read flows from stdin
client = newIOReaderObserver(os.Stdin)
logger.Logger.Debug("Reading flows from stdin")
} else {
// read flows from a hubble server
hubbleConn, err := conn.New(ctx, vp.GetString(config.KeyServer), vp.GetDuration(config.KeyTimeout))
if err != nil {
return err
}
defer hubbleConn.Close()
client = observer.NewObserverClient(hubbleConn)
}
defer cleanup()

logger.Logger.WithField("request", req).Debug("Sending GetFlows request")
if err := getFlows(ctx, client, req); err != nil {
Expand Down Expand Up @@ -591,13 +600,14 @@ more.`,
return observeCmd
}

func handleFlowArgs(ofilter *flowFilter, debug bool) (err error) {
func handleFlowArgs(writer io.Writer, ofilter *flowFilter, debug bool) (err error) {
if ofilter.blacklisting {
return errors.New("trailing --not found in the arguments")
}

// initialize the printer with any options that were passed in
var opts = []hubprinter.Option{
hubprinter.Writer(writer),
hubprinter.WithTimeFormat(hubtime.FormatNameToLayout(formattingOpts.timeFormat)),
hubprinter.WithColor(formattingOpts.color),
}
Expand Down
9 changes: 5 additions & 4 deletions cmd/observe/flows_filter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
package observe

import (
"os"
"strconv"
"testing"

Expand Down Expand Up @@ -64,7 +65,7 @@ func TestTrailingNot(t *testing.T) {
})
require.NoError(t, err)

err = handleFlowArgs(f, false)
err = handleFlowArgs(os.Stdout, f, false)
require.Error(t, err)
assert.Contains(t, err.Error(), "trailing --not")
}
Expand All @@ -82,7 +83,7 @@ func TestFilterDispatch(t *testing.T) {
"-t", "l7", // int:129 in cilium-land
}))

require.NoError(t, handleFlowArgs(f, false))
require.NoError(t, handleFlowArgs(os.Stdout, f, false))
if diff := cmp.Diff(
[]*flowpb.FlowFilter{
{
Expand Down Expand Up @@ -127,7 +128,7 @@ func TestFilterLeftRight(t *testing.T) {
"--node-name", "k8s*",
}))

require.NoError(t, handleFlowArgs(f, false))
require.NoError(t, handleFlowArgs(os.Stdout, f, false))

if diff := cmp.Diff(
[]*flowpb.FlowFilter{
Expand Down Expand Up @@ -198,7 +199,7 @@ func TestFilterType(t *testing.T) {
"-t", "agent:service-deleted",
}))

require.NoError(t, handleFlowArgs(f, false))
require.NoError(t, handleFlowArgs(os.Stdout, f, false))
if diff := cmp.Diff(
[]*flowpb.FlowFilter{
{
Expand Down
25 changes: 16 additions & 9 deletions cmd/observe/io_reader_observer.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,35 +18,42 @@ import (
"google.golang.org/protobuf/encoding/protojson"
)

// ioReaderObserver implements ObserverClient interface. It reads flows
// IOReaderObserver implements ObserverClient interface. It reads flows
// in jsonpb format from an io.Reader.
type ioReaderObserver struct {
type IOReaderObserver struct {
scanner *bufio.Scanner
}

func newIOReaderObserver(reader io.Reader) *ioReaderObserver {
return &ioReaderObserver{
// NewIOReaderObserver reads flows in jsonpb format from an io.Reader and
// returns a IOReaderObserver that implements the ObserverClient interface.
func NewIOReaderObserver(reader io.Reader) *IOReaderObserver {
return &IOReaderObserver{
scanner: bufio.NewScanner(reader),
}
}

func (o *ioReaderObserver) GetFlows(ctx context.Context, in *observer.GetFlowsRequest, _ ...grpc.CallOption) (observer.Observer_GetFlowsClient, error) {
// GetFlows returns flows
func (o *IOReaderObserver) GetFlows(ctx context.Context, in *observer.GetFlowsRequest, _ ...grpc.CallOption) (observer.Observer_GetFlowsClient, error) {
return newIOReaderClient(ctx, o.scanner, in)
}

func (o *ioReaderObserver) GetAgentEvents(_ context.Context, _ *observer.GetAgentEventsRequest, _ ...grpc.CallOption) (observer.Observer_GetAgentEventsClient, error) {
// GetAgentEvents is not implemented, and will throw an error if used.
func (o *IOReaderObserver) GetAgentEvents(_ context.Context, _ *observer.GetAgentEventsRequest, _ ...grpc.CallOption) (observer.Observer_GetAgentEventsClient, error) {
return nil, status.Errorf(codes.Unimplemented, "GetAgentEvents not implemented")
}

func (o *ioReaderObserver) GetDebugEvents(_ context.Context, _ *observer.GetDebugEventsRequest, _ ...grpc.CallOption) (observer.Observer_GetDebugEventsClient, error) {
// GetDebugEvents is not implemented, and will throw an error if used.
func (o *IOReaderObserver) GetDebugEvents(_ context.Context, _ *observer.GetDebugEventsRequest, _ ...grpc.CallOption) (observer.Observer_GetDebugEventsClient, error) {
return nil, status.Errorf(codes.Unimplemented, "GetDebugEvents not implemented")
}

func (o *ioReaderObserver) GetNodes(_ context.Context, _ *observer.GetNodesRequest, _ ...grpc.CallOption) (*observer.GetNodesResponse, error) {
// GetNodes is not implemented, and will throw an error if used.
func (o *IOReaderObserver) GetNodes(_ context.Context, _ *observer.GetNodesRequest, _ ...grpc.CallOption) (*observer.GetNodesResponse, error) {
return nil, status.Errorf(codes.Unimplemented, "GetNodes not implemented")
}

func (o *ioReaderObserver) ServerStatus(_ context.Context, _ *observer.ServerStatusRequest, _ ...grpc.CallOption) (*observer.ServerStatusResponse, error) {
// ServerStatus is not implemented, and will throw an error if used.
func (o *IOReaderObserver) ServerStatus(_ context.Context, _ *observer.ServerStatusRequest, _ ...grpc.CallOption) (*observer.ServerStatusResponse, error) {
return nil, status.Errorf(codes.Unimplemented, "ServerStatus not implemented")
}

Expand Down
6 changes: 3 additions & 3 deletions cmd/observe/io_reader_observer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ func Test_getFlowsBasic(t *testing.T) {
assert.NoError(t, err)
flowStrings = append(flowStrings, string(b))
}
server := newIOReaderObserver(strings.NewReader(strings.Join(flowStrings, "\n") + "\n"))
server := NewIOReaderObserver(strings.NewReader(strings.Join(flowStrings, "\n") + "\n"))
req := observer.GetFlowsRequest{}
client, err := server.GetFlows(context.Background(), &req)
assert.NoError(t, err)
Expand Down Expand Up @@ -56,7 +56,7 @@ func Test_getFlowsTimeRange(t *testing.T) {
assert.NoError(t, err)
flowStrings = append(flowStrings, string(b))
}
server := newIOReaderObserver(strings.NewReader(strings.Join(flowStrings, "\n") + "\n"))
server := NewIOReaderObserver(strings.NewReader(strings.Join(flowStrings, "\n") + "\n"))
req := observer.GetFlowsRequest{
Since: &timestamppb.Timestamp{Seconds: 50},
Until: &timestamppb.Timestamp{Seconds: 150},
Expand Down Expand Up @@ -91,7 +91,7 @@ func Test_getFlowsFilter(t *testing.T) {
assert.NoError(t, err)
flowStrings = append(flowStrings, string(b))
}
server := newIOReaderObserver(strings.NewReader(strings.Join(flowStrings, "\n") + "\n"))
server := NewIOReaderObserver(strings.NewReader(strings.Join(flowStrings, "\n") + "\n"))
req := observer.GetFlowsRequest{
Whitelist: []*flow.FlowFilter{
{
Expand Down
Loading

0 comments on commit 950e9c5

Please sign in to comment.