diff --git a/Makefile b/Makefile index 3a3905a646..d367f780f7 100644 --- a/Makefile +++ b/Makefile @@ -1,4 +1,5 @@ -VERSION?=dev +export VERSION?=dev +export TRACETEST_DEFAULT_CLOUD_ENDPOINT=https://app.tracetest.io TAG?=$(VERSION) GORELEASER_VERSION=1.23.0-pro diff --git a/agent/config/flags.go b/agent/config/flags.go index 44e5e3d62f..068de3a81f 100644 --- a/agent/config/flags.go +++ b/agent/config/flags.go @@ -8,7 +8,7 @@ const ( ) type Flags struct { - Endpoint string + ServerURL string OrganizationID string EnvironmentID string CI bool diff --git a/agent/runner/runner.go b/agent/runner/runner.go index 49f7c8941f..6b753efd10 100644 --- a/agent/runner/runner.go +++ b/agent/runner/runner.go @@ -57,7 +57,7 @@ Once started, Tracetest Agent exposes OTLP ports 4317 and 4318 to ingest traces s.logger = logger - return s.configurator.Start(ctx, cfg, flags) + return s.configurator.Start(ctx, &cfg, flags) } func (s *Runner) onStartAgent(ctx context.Context, cfg config.Config) { diff --git a/cli/cmd/configure_cmd.go b/cli/cmd/configure_cmd.go index 6417c46ae1..ef9d8582f0 100644 --- a/cli/cmd/configure_cmd.go +++ b/cli/cmd/configure_cmd.go @@ -26,13 +26,9 @@ var configureCmd = &cobra.Command{ flags := agentConfig.Flags{ CI: configParams.CI, } - config, err := config.LoadConfig("") - if err != nil { - return "", err - } if flagProvided(cmd, "server-url") || flagProvided(cmd, "endpoint") { - flags.Endpoint = configParams.ServerURL + flags.ServerURL = configParams.ServerURL } if flagProvided(cmd, "token") { @@ -47,8 +43,7 @@ var configureCmd = &cobra.Command{ flags.OrganizationID = configParams.OrganizationID } - err = configurator.Start(ctx, config, flags) - return "", err + return "", configurator.Start(ctx, nil, flags) })), PostRun: teardownCommand, } diff --git a/cli/cmd/start_cmd.go b/cli/cmd/start_cmd.go index 94a064c2f2..2a3a11af16 100644 --- a/cli/cmd/start_cmd.go +++ b/cli/cmd/start_cmd.go @@ -30,7 +30,7 @@ var startCmd = &cobra.Command{ flags := agentConfig.Flags{ OrganizationID: saveParams.organizationID, EnvironmentID: saveParams.environmentID, - Endpoint: saveParams.endpoint, + ServerURL: saveParams.endpoint, AgentApiKey: saveParams.agentApiKey, Token: saveParams.token, Mode: agentConfig.Mode(saveParams.mode), diff --git a/cli/config/config.go b/cli/config/config.go index 89ba800d98..2f9ab95645 100644 --- a/cli/config/config.go +++ b/cli/config/config.go @@ -38,6 +38,11 @@ type Config struct { UIEndpoint string `yaml:"uIEndpoint,omitempty"` } +func (c Config) OAuthEndpoint() string { + return fmt.Sprintf("%s%s", c.URL(), c.Path()) + +} + func (c Config) URL() string { if c.Scheme == "" || c.Endpoint == "" { return "" diff --git a/cli/config/configurator.go b/cli/config/configurator.go index 71797a13d5..1f7caa9220 100644 --- a/cli/config/configurator.go +++ b/cli/config/configurator.go @@ -7,13 +7,10 @@ import ( "strings" "github.com/golang-jwt/jwt" - agentConfig "github.com/kubeshop/tracetest/agent/config" - "github.com/kubeshop/tracetest/cli/analytics" "github.com/kubeshop/tracetest/cli/pkg/oauth" "github.com/kubeshop/tracetest/cli/pkg/resourcemanager" - cliUI "github.com/kubeshop/tracetest/cli/ui" ) @@ -42,111 +39,167 @@ func (c Configurator) WithOnFinish(onFinish onFinishFn) Configurator { return c } -func (c Configurator) Start(ctx context.Context, prev Config, flags agentConfig.Flags) error { +func (c Configurator) Start(ctx context.Context, prev *Config, flags agentConfig.Flags) error { c.flags = flags - serverURL := getFirstValidString(flags.Endpoint, prev.UIEndpoint, DefaultCloudEndpoint) - if serverURL == "" { - path := "" - if prev.ServerPath != nil { - path = *prev.ServerPath + serverURL, err := c.getServerURL(prev, flags) + if err != nil { + return err + } + + cfg, err := c.createConfig(serverURL) + if err != nil { + return err + } + + cfg, err, isOSS := c.populateConfigWithVersionInfo(ctx, cfg) + if err != nil { + return err + } + + if isOSS { + // we don't need anything else for OSS + return nil + } + + if flags.CI { + err = Save(cfg) + if err != nil { + return err } - serverURL = c.ui.TextInput("Enter your Tracetest server URL", fmt.Sprintf("%s%s", prev.URL(), path)) + return nil } - if err := ValidateServerURL(serverURL); err != nil { + _, err = c.handleOAuth(ctx, cfg, prev, flags) + if err != nil { return err } + return nil +} + +func (c Configurator) getServerURL(prev *Config, flags agentConfig.Flags) (string, error) { + var prevUIEndpoint string + if prev != nil { + prevUIEndpoint = prev.UIEndpoint + } + serverURL := getFirstValidString(flags.ServerURL, prevUIEndpoint) + if serverURL == "" { + serverURL = c.ui.TextInput("What tracetest server do you want to use?", DefaultCloudEndpoint) + } + + if err := ValidateServerURL(serverURL); err != nil { + return "", err + } + + return serverURL, nil +} + +func (c Configurator) createConfig(serverURL string) (Config, error) { scheme, endpoint, path, err := ParseServerURL(serverURL) if err != nil { - return err + return Config{}, err } if strings.Contains(serverURL, DefaultCloudDomain) { path = &DefaultCloudPath } - cfg := Config{ + return Config{ Scheme: scheme, Endpoint: endpoint, ServerPath: path, - } + }, nil +} +func (c Configurator) populateConfigWithVersionInfo(ctx context.Context, cfg Config) (_ Config, _ error, isOSS bool) { client := GetAPIClient(cfg) version, err := getVersionMetadata(ctx, client) if err != nil { - return fmt.Errorf("cannot get version metadata: %w", err) + return Config{}, fmt.Errorf("cannot get version metadata: %w", err), false } serverType := version.GetType() if serverType == "oss" { err := Save(cfg) if err != nil { - return fmt.Errorf("could not save configuration: %w", err) + return Config{}, fmt.Errorf("could not save configuration: %w", err), false } c.ui.Success("Successfully configured Tracetest CLI") - return nil + return cfg, nil, true } cfg.AgentEndpoint = version.GetAgentEndpoint() cfg.UIEndpoint = version.GetUiEndpoint() cfg.Scheme, cfg.Endpoint, cfg.ServerPath, err = ParseServerURL(version.GetApiEndpoint()) if err != nil { - return fmt.Errorf("cannot parse server url: %w", err) + return Config{}, fmt.Errorf("cannot parse server url: %w", err), false } - if flags.CI { - return Save(cfg) - } - - oauthEndpoint := fmt.Sprintf("%s%s", cfg.URL(), cfg.Path()) + return cfg, nil, false +} - if prev.Jwt != "" { +func (c Configurator) handleOAuth(ctx context.Context, cfg Config, prev *Config, flags agentConfig.Flags) (Config, error) { + if prev != nil && prev.Jwt != "" { cfg.Jwt = prev.Jwt cfg.Token = prev.Token } if flags.Token != "" { - jwt, err := oauth.ExchangeToken(oauthEndpoint, flags.Token) - if err != nil { - return err - } - - cfg.Jwt = jwt - cfg.Token = flags.Token - - claims, err := GetTokenClaims(jwt) + var err error + cfg, err = c.exchangeToken(cfg, flags.Token) if err != nil { - return err - } - - organizationId := claims["organization_id"].(string) - environmentId := claims["environment_id"].(string) - - if organizationId != "" { - flags.OrganizationID = organizationId - } - if environmentId != "" { - flags.EnvironmentID = environmentId + return Config{}, err } } if flags.AgentApiKey != "" { cfg.AgentApiKey = flags.AgentApiKey c.ShowOrganizationSelector(ctx, cfg, flags) - return nil + return cfg, nil } if cfg.Jwt != "" { c.ShowOrganizationSelector(ctx, cfg, flags) - return nil + return cfg, nil } - oauthServer := oauth.NewOAuthServer(oauthEndpoint, cfg.UIEndpoint) - return oauthServer.WithOnSuccess(c.onOAuthSuccess(ctx, cfg)). + oauthServer := oauth.NewOAuthServer(cfg.OAuthEndpoint(), cfg.UIEndpoint) + err := oauthServer.WithOnSuccess(c.onOAuthSuccess(ctx, cfg)). WithOnFailure(c.onOAuthFailure). GetAuthJWT() + if err != nil { + return Config{}, err + } + + return cfg, nil +} + +func (c Configurator) exchangeToken(cfg Config, token string) (Config, error) { + jwt, err := oauth.ExchangeToken(cfg.OAuthEndpoint(), token) + if err != nil { + return Config{}, err + } + + cfg.Jwt = jwt + cfg.Token = token + + claims, err := GetTokenClaims(jwt) + if err != nil { + return Config{}, err + } + + organizationId := claims["organization_id"].(string) + environmentId := claims["environment_id"].(string) + + if organizationId != "" { + c.flags.OrganizationID = organizationId + } + if environmentId != "" { + c.flags.EnvironmentID = environmentId + } + + return cfg, nil } func getFirstValidString(values ...string) string {