From 1b4865ae8f246ca4a8af5fa6bb457bbf3475b9e0 Mon Sep 17 00:00:00 2001 From: Sebastian Choren Date: Fri, 16 Feb 2024 11:51:42 -0300 Subject: [PATCH] feat(cli): reauthenticate user in case of invalid token (#3643) --- cli/cmd/configure_cmd.go | 3 +- cli/cmd/dashboard_cmd.go | 3 +- cli/cmd/middleware.go | 62 ++++++++++++++++++++++++++++--- cli/cmd/resource_apply_cmd.go | 3 +- cli/cmd/resource_delete_cmd.go | 3 +- cli/cmd/resource_export_cmd.go | 3 +- cli/cmd/resource_get_cmd.go | 3 +- cli/cmd/resource_list_cmd.go | 3 +- cli/cmd/resource_run_cmd.go | 3 +- cli/cmd/root.go | 12 ++++-- cli/cmd/start_cmd.go | 19 ++++++++-- cli/cmd/version_cmd.go | 4 +- cli/config/config.go | 44 +++++++++++++++++++--- cli/config/configurator.go | 31 ++++++++++++---- cli/pkg/resourcemanager/client.go | 15 +++++--- 15 files changed, 165 insertions(+), 46 deletions(-) diff --git a/cli/cmd/configure_cmd.go b/cli/cmd/configure_cmd.go index ef9d8582f0..934df57c1e 100644 --- a/cli/cmd/configure_cmd.go +++ b/cli/cmd/configure_cmd.go @@ -21,8 +21,7 @@ var configureCmd = &cobra.Command{ Short: "Configure your tracetest CLI", Long: "Configure your tracetest CLI", PreRun: setupLogger, - Run: WithResultHandler(WithParamsHandler(configParams)(func(cmd *cobra.Command, _ []string) (string, error) { - ctx := context.Background() + Run: WithResultHandler(WithParamsHandler(configParams)(func(ctx context.Context, cmd *cobra.Command, _ []string) (string, error) { flags := agentConfig.Flags{ CI: configParams.CI, } diff --git a/cli/cmd/dashboard_cmd.go b/cli/cmd/dashboard_cmd.go index a1598a2188..c2f50e94b6 100644 --- a/cli/cmd/dashboard_cmd.go +++ b/cli/cmd/dashboard_cmd.go @@ -1,6 +1,7 @@ package cmd import ( + "context" "fmt" "github.com/kubeshop/tracetest/cli/ui" @@ -13,7 +14,7 @@ var dashboardCmd = &cobra.Command{ Short: "Opens the Tracetest Dashboard URL", Long: "Opens the Tracetest Dashboard URL", PreRun: setupCommand(), - Run: WithResultHandler(func(_ *cobra.Command, _ []string) (string, error) { + Run: WithResultHandler(func(_ context.Context, _ *cobra.Command, _ []string) (string, error) { if cliConfig.IsEmpty() { return "", fmt.Errorf("missing Tracetest endpoint configuration") } diff --git a/cli/cmd/middleware.go b/cli/cmd/middleware.go index 473ebffb0a..d738ecd357 100644 --- a/cli/cmd/middleware.go +++ b/cli/cmd/middleware.go @@ -1,25 +1,55 @@ package cmd import ( + "context" "errors" "fmt" "os" + "github.com/kubeshop/tracetest/cli/config" "github.com/kubeshop/tracetest/cli/pkg/resourcemanager" + "github.com/kubeshop/tracetest/cli/ui" "github.com/spf13/cobra" ) -type RunFn func(cmd *cobra.Command, args []string) (string, error) +type RunFn func(ctx context.Context, cmd *cobra.Command, args []string) (string, error) type CobraRunFn func(cmd *cobra.Command, args []string) type MiddlewareWrapper func(RunFn) RunFn +func rootCtx(cmd *cobra.Command) context.Context { + // cobra does not correctly progpagate rootcmd context to sub commands, + // so we need to manually traverse the command tree to find the root context + if cmd == nil { + return nil + } + + var ( + ctx = cmd.Context() + p = cmd.Parent() + ) + if cmd.Parent() == nil { + return ctx + } + for { + ctx = p.Context() + p = p.Parent() + if p == nil { + break + } + } + return ctx +} + func WithResultHandler(runFn RunFn) CobraRunFn { return func(cmd *cobra.Command, args []string) { - res, err := runFn(cmd, args) + // we need the root cmd context in case of an error caused rerun + ctx := rootCtx(cmd) + + res, err := runFn(ctx, cmd, args) if err != nil { - OnError(err) + handleError(ctx, err) return } @@ -29,6 +59,28 @@ func WithResultHandler(runFn RunFn) CobraRunFn { } } +func handleError(ctx context.Context, err error) { + reqErr := resourcemanager.RequestError{} + if errors.As(err, &reqErr) && reqErr.IsAuthError { + handleAuthError(ctx) + } else { + OnError(err) + } +} + +func handleAuthError(ctx context.Context) { + ui.DefaultUI.Warning("Your authentication token has expired, please log in again.") + configurator. + WithOnFinish(func(ctx context.Context, _ config.Config) { + retryCommand(ctx) + }). + ExecuteUserLogin(ctx, cliConfig) +} + +func retryCommand(ctx context.Context) { + handleRootExecErr(rootCmd.ExecuteContext(ctx)) +} + type errorMessageRenderer interface { Render() } @@ -66,7 +118,7 @@ func handleErrorMessage(err error) string { func WithParamsHandler(validators ...Validator) MiddlewareWrapper { return func(runFn RunFn) RunFn { - return func(cmd *cobra.Command, args []string) (string, error) { + return func(ctx context.Context, cmd *cobra.Command, args []string) (string, error) { errors := make([]error, 0) for _, validator := range validators { @@ -82,7 +134,7 @@ func WithParamsHandler(validators ...Validator) MiddlewareWrapper { return "", fmt.Errorf(errorText) } - return runFn(cmd, args) + return runFn(ctx, cmd, args) } } } diff --git a/cli/cmd/resource_apply_cmd.go b/cli/cmd/resource_apply_cmd.go index fa91fb884f..2f9e902499 100644 --- a/cli/cmd/resource_apply_cmd.go +++ b/cli/cmd/resource_apply_cmd.go @@ -21,9 +21,8 @@ func init() { Short: "Apply resources", Long: "Apply (create/update) resources to your Tracetest server", PreRun: setupCommand(), - Run: WithResourceMiddleware(func(_ *cobra.Command, args []string) (string, error) { + Run: WithResourceMiddleware(func(ctx context.Context, _ *cobra.Command, args []string) (string, error) { resourceType := resourceParams.ResourceName - ctx := context.Background() resourceClient, err := resources.Get(resourceType) if err != nil { diff --git a/cli/cmd/resource_delete_cmd.go b/cli/cmd/resource_delete_cmd.go index 13fd1e1ee5..ab291527a6 100644 --- a/cli/cmd/resource_delete_cmd.go +++ b/cli/cmd/resource_delete_cmd.go @@ -21,9 +21,8 @@ func init() { Short: "Delete resources", Long: "Delete resources from your Tracetest server", PreRun: setupCommand(), - Run: WithResourceMiddleware(func(_ *cobra.Command, args []string) (string, error) { + Run: WithResourceMiddleware(func(ctx context.Context, _ *cobra.Command, args []string) (string, error) { resourceType := resourceParams.ResourceName - ctx := context.Background() resourceClient, err := resources.Get(resourceType) if err != nil { diff --git a/cli/cmd/resource_export_cmd.go b/cli/cmd/resource_export_cmd.go index cba48c3d7e..a9d33cefeb 100644 --- a/cli/cmd/resource_export_cmd.go +++ b/cli/cmd/resource_export_cmd.go @@ -21,9 +21,8 @@ func init() { Long: "Export a resource from your Tracetest server", Short: "Export resource", PreRun: setupCommand(), - Run: WithResourceMiddleware(func(_ *cobra.Command, args []string) (string, error) { + Run: WithResourceMiddleware(func(ctx context.Context, _ *cobra.Command, args []string) (string, error) { resourceType := resourceParams.ResourceName - ctx := context.Background() resourceClient, err := resources.Get(resourceType) if err != nil { diff --git a/cli/cmd/resource_get_cmd.go b/cli/cmd/resource_get_cmd.go index b10dece6d5..09760770f4 100644 --- a/cli/cmd/resource_get_cmd.go +++ b/cli/cmd/resource_get_cmd.go @@ -20,9 +20,8 @@ func init() { Short: "Get resource", Long: "Get a resource from your Tracetest server", PreRun: setupCommand(), - Run: WithResourceMiddleware(func(_ *cobra.Command, args []string) (string, error) { + Run: WithResourceMiddleware(func(ctx context.Context, _ *cobra.Command, args []string) (string, error) { resourceType := resourceParams.ResourceName - ctx := context.Background() resourceClient, err := resources.Get(resourceType) if err != nil { diff --git a/cli/cmd/resource_list_cmd.go b/cli/cmd/resource_list_cmd.go index 59f6bcb7d1..a2536b36b2 100644 --- a/cli/cmd/resource_list_cmd.go +++ b/cli/cmd/resource_list_cmd.go @@ -19,9 +19,8 @@ func init() { Short: "List resources", Long: "List resources from your Tracetest server", PreRun: setupCommand(), - Run: WithResourceMiddleware(func(_ *cobra.Command, args []string) (string, error) { + Run: WithResourceMiddleware(func(ctx context.Context, _ *cobra.Command, args []string) (string, error) { resourceType := resourceParams.ResourceName - ctx := context.Background() resourceClient, err := resources.Get(resourceType) if err != nil { diff --git a/cli/cmd/resource_run_cmd.go b/cli/cmd/resource_run_cmd.go index 39e3c1ffbe..1eb2022b4a 100644 --- a/cli/cmd/resource_run_cmd.go +++ b/cli/cmd/resource_run_cmd.go @@ -24,8 +24,7 @@ func init() { Short: "run resources", Long: "run resources", PreRun: setupCommand(WithOptionalResourceName()), - Run: WithResourceMiddleware(func(_ *cobra.Command, args []string) (string, error) { - ctx := context.Background() + Run: WithResourceMiddleware(func(ctx context.Context, _ *cobra.Command, args []string) (string, error) { resourceType, err := getResourceType(runParams, resourceParams) if err != nil { return "", err diff --git a/cli/cmd/root.go b/cli/cmd/root.go index 72a1c3fc40..b73ad1d65f 100644 --- a/cli/cmd/root.go +++ b/cli/cmd/root.go @@ -31,10 +31,16 @@ var rootCmd = &cobra.Command{ } func Execute() { - if err := rootCmd.Execute(); err != nil { - fmt.Fprintln(os.Stderr, err) - ExitCLI(1) + handleRootExecErr(rootCmd.Execute()) +} + +func handleRootExecErr(err error) { + if err == nil { + ExitCLI(0) } + + fmt.Fprintln(os.Stderr, err) + ExitCLI(1) } func ExitCLI(errorCode int) { diff --git a/cli/cmd/start_cmd.go b/cli/cmd/start_cmd.go index d32095ee02..955ff48f07 100644 --- a/cli/cmd/start_cmd.go +++ b/cli/cmd/start_cmd.go @@ -4,14 +4,16 @@ import ( "context" "os" + "github.com/davecgh/go-spew/spew" agentConfig "github.com/kubeshop/tracetest/agent/config" "github.com/kubeshop/tracetest/agent/runner" "github.com/kubeshop/tracetest/agent/ui" + "github.com/kubeshop/tracetest/cli/config" "github.com/spf13/cobra" ) var ( - agentRunner = runner.NewRunner(configurator, resources, ui.DefaultUI) + agentRunner = runner.NewRunner(configurator.WithErrorHandler(handleError), resources, ui.DefaultUI) defaultToken = os.Getenv("TRACETEST_TOKEN") defaultEndpoint = os.Getenv("TRACETEST_SERVER_URL") defaultAPIKey = os.Getenv("TRACETEST_API_KEY") @@ -24,9 +26,7 @@ var startCmd = &cobra.Command{ Short: "Start Tracetest", Long: "Start using Tracetest", PreRun: setupCommand(SkipConfigValidation(), SkipVersionMismatchCheck()), - Run: WithResultHandler((func(_ *cobra.Command, _ []string) (string, error) { - ctx := context.Background() - + Run: WithResultHandler((func(ctx context.Context, _ *cobra.Command, _ []string) (string, error) { flags := agentConfig.Flags{ OrganizationID: saveParams.organizationID, EnvironmentID: saveParams.environmentID, @@ -37,6 +37,17 @@ var startCmd = &cobra.Command{ LogLevel: saveParams.logLevel, } + // override organization and environment id from context. + // this happens when auto rerunning the cmd after relogin + if orgID := config.ContextGetOrganizationID(ctx); orgID != "" { + flags.OrganizationID = orgID + } + if envID := config.ContextGetEnvironmentID(ctx); envID != "" { + flags.EnvironmentID = envID + } + + spew.Dump(flags) + cfg, err := agentConfig.LoadConfig() if err != nil { return "", err diff --git a/cli/cmd/version_cmd.go b/cli/cmd/version_cmd.go index 2df624019d..04026d6903 100644 --- a/cli/cmd/version_cmd.go +++ b/cli/cmd/version_cmd.go @@ -1,6 +1,8 @@ package cmd import ( + "context" + "github.com/spf13/cobra" ) @@ -10,7 +12,7 @@ var versionCmd = &cobra.Command{ Short: "Display this CLI tool version", Long: "Display this CLI tool version", PreRun: setupCommand(), - Run: WithResultHandler(func(_ *cobra.Command, _ []string) (string, error) { + Run: WithResultHandler(func(_ context.Context, _ *cobra.Command, _ []string) (string, error) { return versionText, nil }), PostRun: teardownCommand, diff --git a/cli/config/config.go b/cli/config/config.go index 86f1a77343..d95d5fe146 100644 --- a/cli/config/config.go +++ b/cli/config/config.go @@ -1,6 +1,7 @@ package config import ( + "context" "encoding/json" "fmt" "os" @@ -143,15 +144,45 @@ func ParseServerURL(serverURL string) (scheme, endpoint, serverPath string, err return url.Scheme, url.Host, url.Path, nil } -func Save(config Config) error { +type orgIDKeyType struct{} +type envIDKeyType struct{} + +var orgIDKey = orgIDKeyType{} +var envIDKey = envIDKeyType{} + +func ContextWithOrganizationID(ctx context.Context, orgID string) context.Context { + return context.WithValue(ctx, orgIDKey, orgID) +} + +func ContextWithEnvironmentID(ctx context.Context, envID string) context.Context { + return context.WithValue(ctx, envIDKey, envID) +} + +func ContextGetOrganizationID(ctx context.Context) string { + v := ctx.Value(orgIDKey) + if v == nil { + return "" + } + return v.(string) +} + +func ContextGetEnvironmentID(ctx context.Context) string { + v := ctx.Value(envIDKey) + if v == nil { + return "" + } + return v.(string) +} + +func Save(ctx context.Context, config Config) (context.Context, error) { configPath, err := GetConfigurationPath() if err != nil { - return fmt.Errorf("could not get configuration path: %w", err) + return ctx, fmt.Errorf("could not get configuration path: %w", err) } configYml, err := yaml.Marshal(config) if err != nil { - return fmt.Errorf("could not marshal configuration into yml: %w", err) + return ctx, fmt.Errorf("could not marshal configuration into yml: %w", err) } if _, err := os.Stat(configPath); os.IsNotExist(err) { @@ -159,10 +190,13 @@ func Save(config Config) error { } err = os.WriteFile(configPath, configYml, 0755) if err != nil { - return fmt.Errorf("could not write file: %w", err) + return ctx, fmt.Errorf("could not write file: %w", err) } - return nil + ctx = ContextWithOrganizationID(ctx, config.OrganizationID) + ctx = ContextWithEnvironmentID(ctx, config.EnvironmentID) + + return ctx, nil } func GetConfigurationPath() (string, error) { diff --git a/cli/config/configurator.go b/cli/config/configurator.go index aa695bf224..a74f73719b 100644 --- a/cli/config/configurator.go +++ b/cli/config/configurator.go @@ -20,6 +20,7 @@ type Configurator struct { resources *resourcemanager.Registry ui cliUI.UI onFinish onFinishFn + errorHandlerFn errorHandlerFn flags agentConfig.Flags finalServerURL string } @@ -34,6 +35,9 @@ func NewConfigurator(resources *resourcemanager.Registry) Configurator { ui.Success("Successfully configured Tracetest CLI") ui.Finish() }, + errorHandlerFn: func(ctx context.Context, err error) { + ui.Exit(err.Error()) + }, flags: agentConfig.Flags{}, } } @@ -43,6 +47,13 @@ func (c Configurator) WithOnFinish(onFinish onFinishFn) Configurator { return c } +type errorHandlerFn func(ctx context.Context, err error) + +func (c Configurator) WithErrorHandler(fn errorHandlerFn) Configurator { + c.errorHandlerFn = fn + return c +} + func (c Configurator) Start(ctx context.Context, prev *Config, flags agentConfig.Flags) error { c.flags = flags serverURL, err := c.getServerURL(prev, flags) @@ -67,7 +78,7 @@ func (c Configurator) Start(ctx context.Context, prev *Config, flags agentConfig } if flags.CI { - err = Save(cfg) + _, err = Save(ctx, cfg) if err != nil { return err } @@ -143,7 +154,7 @@ func (c Configurator) populateConfigWithVersionInfo(ctx context.Context, cfg Con serverType := version.GetType() if serverType == "oss" { - err := Save(cfg) + _, err = Save(ctx, cfg) if err != nil { return Config{}, fmt.Errorf("could not save configuration: %w", err), false } @@ -187,6 +198,10 @@ func (c Configurator) handleOAuth(ctx context.Context, cfg Config, prev *Config, return cfg, nil } + return c.ExecuteUserLogin(ctx, cfg) +} + +func (c Configurator) ExecuteUserLogin(ctx context.Context, cfg Config) (Config, error) { oauthServer := oauth.NewOAuthServer(cfg.OAuthEndpoint(), cfg.UIEndpoint) err := oauthServer.WithOnSuccess(c.onOAuthSuccess(ctx, cfg)). WithOnFailure(c.onOAuthFailure). @@ -195,7 +210,7 @@ func (c Configurator) handleOAuth(ctx context.Context, cfg Config, prev *Config, return Config{}, err } - return cfg, nil + return cfg, err } func (c Configurator) exchangeToken(cfg Config, token string) (Config, error) { @@ -245,7 +260,7 @@ func (c Configurator) onOAuthSuccess(ctx context.Context, cfg Config) func(token } func (c Configurator) onOAuthFailure(err error) { - c.ui.Exit(err.Error()) + c.errorHandlerFn(context.Background(), err) } func (c Configurator) ShowOrganizationSelector(ctx context.Context, cfg Config, flags agentConfig.Flags) { @@ -253,7 +268,7 @@ func (c Configurator) ShowOrganizationSelector(ctx context.Context, cfg Config, if cfg.OrganizationID == "" && flags.AgentApiKey == "" { orgID, err := c.organizationSelector(ctx, cfg) if err != nil { - c.ui.Exit(err.Error()) + c.errorHandlerFn(ctx, err) return } @@ -264,16 +279,16 @@ func (c Configurator) ShowOrganizationSelector(ctx context.Context, cfg Config, if cfg.EnvironmentID == "" && flags.AgentApiKey == "" { envID, err := c.environmentSelector(ctx, cfg) if err != nil { - c.ui.Exit(err.Error()) + c.errorHandlerFn(ctx, err) return } cfg.EnvironmentID = envID } - err := Save(cfg) + ctx, err := Save(ctx, cfg) if err != nil { - c.ui.Exit(err.Error()) + c.errorHandlerFn(ctx, err) return } diff --git a/cli/pkg/resourcemanager/client.go b/cli/pkg/resourcemanager/client.go index 0f3ba2c900..5e66d42d81 100644 --- a/cli/pkg/resourcemanager/client.go +++ b/cli/pkg/resourcemanager/client.go @@ -134,8 +134,9 @@ var ErrNotFound = RequestError{ } type RequestError struct { - Code int `json:"code"` - Message string `json:"error"` + Code int `json:"code"` + Message string `json:"error"` + IsAuthError bool `json:"isAuthError"` } type alternateRequestError struct { @@ -152,6 +153,10 @@ func (e RequestError) Is(target error) bool { return ok && t.Code == e.Code } +func isAuthError(resp *http.Response) bool { + return resp.StatusCode == http.StatusUnauthorized +} + func isSuccessResponse(resp *http.Response) bool { // successfull http status codes are 2xx return resp.StatusCode >= 200 && resp.StatusCode < 300 @@ -188,9 +193,9 @@ func parseRequestError(resp *http.Response, format Format) error { if err != nil { return fmt.Errorf("cannot parse response body: %w", err) } - return RequestError{ - Code: alternateReqError.Status, - Message: alternateReqError.Detail, + Code: alternateReqError.Status, + Message: alternateReqError.Detail, + IsAuthError: isAuthError(resp), } }