Skip to content

Commit

Permalink
[CLI] configure cmd ask for target server (#3585)
Browse files Browse the repository at this point in the history
  • Loading branch information
schoren committed Jan 31, 2024
1 parent 28999b9 commit 9f7f26f
Show file tree
Hide file tree
Showing 7 changed files with 113 additions and 59 deletions.
3 changes: 2 additions & 1 deletion Makefile
@@ -1,4 +1,5 @@
VERSION?=dev
export VERSION?=dev
export TRACETEST_DEFAULT_CLOUD_ENDPOINT=https://app.tracetest.io
TAG?=$(VERSION)
GORELEASER_VERSION=1.23.0-pro

Expand Down
2 changes: 1 addition & 1 deletion agent/config/flags.go
Expand Up @@ -8,7 +8,7 @@ const (
)

type Flags struct {
Endpoint string
ServerURL string
OrganizationID string
EnvironmentID string
CI bool
Expand Down
2 changes: 1 addition & 1 deletion agent/runner/runner.go
Expand Up @@ -57,7 +57,7 @@ Once started, Tracetest Agent exposes OTLP ports 4317 and 4318 to ingest traces

s.logger = logger

return s.configurator.Start(ctx, cfg, flags)
return s.configurator.Start(ctx, &cfg, flags)
}

func (s *Runner) onStartAgent(ctx context.Context, cfg config.Config) {
Expand Down
9 changes: 2 additions & 7 deletions cli/cmd/configure_cmd.go
Expand Up @@ -26,13 +26,9 @@ var configureCmd = &cobra.Command{
flags := agentConfig.Flags{
CI: configParams.CI,
}
config, err := config.LoadConfig("")
if err != nil {
return "", err
}

if flagProvided(cmd, "server-url") || flagProvided(cmd, "endpoint") {
flags.Endpoint = configParams.ServerURL
flags.ServerURL = configParams.ServerURL
}

if flagProvided(cmd, "token") {
Expand All @@ -47,8 +43,7 @@ var configureCmd = &cobra.Command{
flags.OrganizationID = configParams.OrganizationID
}

err = configurator.Start(ctx, config, flags)
return "", err
return "", configurator.Start(ctx, nil, flags)
})),
PostRun: teardownCommand,
}
Expand Down
2 changes: 1 addition & 1 deletion cli/cmd/start_cmd.go
Expand Up @@ -30,7 +30,7 @@ var startCmd = &cobra.Command{
flags := agentConfig.Flags{
OrganizationID: saveParams.organizationID,
EnvironmentID: saveParams.environmentID,
Endpoint: saveParams.endpoint,
ServerURL: saveParams.endpoint,
AgentApiKey: saveParams.agentApiKey,
Token: saveParams.token,
Mode: agentConfig.Mode(saveParams.mode),
Expand Down
5 changes: 5 additions & 0 deletions cli/config/config.go
Expand Up @@ -38,6 +38,11 @@ type Config struct {
UIEndpoint string `yaml:"uIEndpoint,omitempty"`
}

func (c Config) OAuthEndpoint() string {
return fmt.Sprintf("%s%s", c.URL(), c.Path())

}

func (c Config) URL() string {
if c.Scheme == "" || c.Endpoint == "" {
return ""
Expand Down
149 changes: 101 additions & 48 deletions cli/config/configurator.go
Expand Up @@ -7,13 +7,10 @@ import (
"strings"

"github.com/golang-jwt/jwt"

agentConfig "github.com/kubeshop/tracetest/agent/config"

"github.com/kubeshop/tracetest/cli/analytics"
"github.com/kubeshop/tracetest/cli/pkg/oauth"
"github.com/kubeshop/tracetest/cli/pkg/resourcemanager"

cliUI "github.com/kubeshop/tracetest/cli/ui"
)

Expand Down Expand Up @@ -42,111 +39,167 @@ func (c Configurator) WithOnFinish(onFinish onFinishFn) Configurator {
return c
}

func (c Configurator) Start(ctx context.Context, prev Config, flags agentConfig.Flags) error {
func (c Configurator) Start(ctx context.Context, prev *Config, flags agentConfig.Flags) error {
c.flags = flags
serverURL := getFirstValidString(flags.Endpoint, prev.UIEndpoint, DefaultCloudEndpoint)
if serverURL == "" {
path := ""
if prev.ServerPath != nil {
path = *prev.ServerPath
serverURL, err := c.getServerURL(prev, flags)
if err != nil {
return err
}

cfg, err := c.createConfig(serverURL)
if err != nil {
return err
}

cfg, err, isOSS := c.populateConfigWithVersionInfo(ctx, cfg)
if err != nil {
return err
}

if isOSS {
// we don't need anything else for OSS
return nil
}

if flags.CI {
err = Save(cfg)
if err != nil {
return err
}
serverURL = c.ui.TextInput("Enter your Tracetest server URL", fmt.Sprintf("%s%s", prev.URL(), path))
return nil
}

if err := ValidateServerURL(serverURL); err != nil {
_, err = c.handleOAuth(ctx, cfg, prev, flags)
if err != nil {
return err
}

return nil
}

func (c Configurator) getServerURL(prev *Config, flags agentConfig.Flags) (string, error) {
var prevUIEndpoint string
if prev != nil {
prevUIEndpoint = prev.UIEndpoint
}
serverURL := getFirstValidString(flags.ServerURL, prevUIEndpoint)
if serverURL == "" {
serverURL = c.ui.TextInput("What tracetest server do you want to use?", DefaultCloudEndpoint)
}

if err := ValidateServerURL(serverURL); err != nil {
return "", err
}

return serverURL, nil
}

func (c Configurator) createConfig(serverURL string) (Config, error) {
scheme, endpoint, path, err := ParseServerURL(serverURL)
if err != nil {
return err
return Config{}, err
}

if strings.Contains(serverURL, DefaultCloudDomain) {
path = &DefaultCloudPath
}

cfg := Config{
return Config{
Scheme: scheme,
Endpoint: endpoint,
ServerPath: path,
}
}, nil
}

func (c Configurator) populateConfigWithVersionInfo(ctx context.Context, cfg Config) (_ Config, _ error, isOSS bool) {
client := GetAPIClient(cfg)
version, err := getVersionMetadata(ctx, client)
if err != nil {
return fmt.Errorf("cannot get version metadata: %w", err)
return Config{}, fmt.Errorf("cannot get version metadata: %w", err), false
}

serverType := version.GetType()
if serverType == "oss" {
err := Save(cfg)
if err != nil {
return fmt.Errorf("could not save configuration: %w", err)
return Config{}, fmt.Errorf("could not save configuration: %w", err), false
}

c.ui.Success("Successfully configured Tracetest CLI")
return nil
return cfg, nil, true
}

cfg.AgentEndpoint = version.GetAgentEndpoint()
cfg.UIEndpoint = version.GetUiEndpoint()
cfg.Scheme, cfg.Endpoint, cfg.ServerPath, err = ParseServerURL(version.GetApiEndpoint())
if err != nil {
return fmt.Errorf("cannot parse server url: %w", err)
return Config{}, fmt.Errorf("cannot parse server url: %w", err), false
}

if flags.CI {
return Save(cfg)
}

oauthEndpoint := fmt.Sprintf("%s%s", cfg.URL(), cfg.Path())
return cfg, nil, false
}

if prev.Jwt != "" {
func (c Configurator) handleOAuth(ctx context.Context, cfg Config, prev *Config, flags agentConfig.Flags) (Config, error) {
if prev != nil && prev.Jwt != "" {
cfg.Jwt = prev.Jwt
cfg.Token = prev.Token
}

if flags.Token != "" {
jwt, err := oauth.ExchangeToken(oauthEndpoint, flags.Token)
if err != nil {
return err
}

cfg.Jwt = jwt
cfg.Token = flags.Token

claims, err := GetTokenClaims(jwt)
var err error
cfg, err = c.exchangeToken(cfg, flags.Token)
if err != nil {
return err
}

organizationId := claims["organization_id"].(string)
environmentId := claims["environment_id"].(string)

if organizationId != "" {
flags.OrganizationID = organizationId
}
if environmentId != "" {
flags.EnvironmentID = environmentId
return Config{}, err
}
}

if flags.AgentApiKey != "" {
cfg.AgentApiKey = flags.AgentApiKey
c.ShowOrganizationSelector(ctx, cfg, flags)
return nil
return cfg, nil
}

if cfg.Jwt != "" {
c.ShowOrganizationSelector(ctx, cfg, flags)
return nil
return cfg, nil
}

oauthServer := oauth.NewOAuthServer(oauthEndpoint, cfg.UIEndpoint)
return oauthServer.WithOnSuccess(c.onOAuthSuccess(ctx, cfg)).
oauthServer := oauth.NewOAuthServer(cfg.OAuthEndpoint(), cfg.UIEndpoint)
err := oauthServer.WithOnSuccess(c.onOAuthSuccess(ctx, cfg)).
WithOnFailure(c.onOAuthFailure).
GetAuthJWT()
if err != nil {
return Config{}, err
}

return cfg, nil
}

func (c Configurator) exchangeToken(cfg Config, token string) (Config, error) {
jwt, err := oauth.ExchangeToken(cfg.OAuthEndpoint(), token)
if err != nil {
return Config{}, err
}

cfg.Jwt = jwt
cfg.Token = token

claims, err := GetTokenClaims(jwt)
if err != nil {
return Config{}, err
}

organizationId := claims["organization_id"].(string)
environmentId := claims["environment_id"].(string)

if organizationId != "" {
c.flags.OrganizationID = organizationId
}
if environmentId != "" {
c.flags.EnvironmentID = environmentId
}

return cfg, nil
}

func getFirstValidString(values ...string) string {
Expand Down

0 comments on commit 9f7f26f

Please sign in to comment.