Skip to content

Commit

Permalink
feat(cli): Token Support (#3255)
Browse files Browse the repository at this point in the history
* feat(cli): Token Support

* feat(cli): Token Support

* fix(cli): updating flag name

* fix(cli): reverting flag name to token
  • Loading branch information
xoscar committed Oct 13, 2023
1 parent 6644a3c commit 0aeb269
Show file tree
Hide file tree
Showing 5 changed files with 66 additions and 34 deletions.
9 changes: 7 additions & 2 deletions cli/cmd/start_cmd.go
Expand Up @@ -2,6 +2,7 @@ package cmd

import (
"context"
"os"

agentConfig "github.com/kubeshop/tracetest/agent/config"
"github.com/kubeshop/tracetest/cli/config"
Expand All @@ -10,8 +11,9 @@ import (
)

var (
start = starter.NewStarter(configurator, resources)
saveParams = &saveParameters{}
start = starter.NewStarter(configurator, resources)
defaultToken = os.Getenv("TRACETEST_TOKEN")
saveParams = &saveParameters{}
)

var startCmd = &cobra.Command{
Expand All @@ -28,6 +30,7 @@ var startCmd = &cobra.Command{
EnvironmentID: saveParams.environmentID,
Endpoint: saveParams.endpoint,
AgentApiKey: saveParams.agentApiKey,
Token: saveParams.token,
}

cfg, err := agentConfig.LoadConfig()
Expand All @@ -49,6 +52,7 @@ func init() {
startCmd.Flags().StringVarP(&saveParams.organizationID, "organization", "", "", "organization id")
startCmd.Flags().StringVarP(&saveParams.environmentID, "environment", "", "", "environment id")
startCmd.Flags().StringVarP(&saveParams.agentApiKey, "api-key", "", "", "agent api key")
startCmd.Flags().StringVarP(&saveParams.token, "token", "", defaultToken, "token api key")
startCmd.Flags().StringVarP(&saveParams.endpoint, "endpoint", "e", config.DefaultCloudEndpoint, "set the value for the endpoint, so the CLI won't ask for this value")
rootCmd.AddCommand(startCmd)
}
Expand All @@ -58,4 +62,5 @@ type saveParameters struct {
environmentID string
endpoint string
agentApiKey string
token string
}
1 change: 1 addition & 0 deletions cli/config/config.go
Expand Up @@ -27,6 +27,7 @@ type ConfigFlags struct {
EnvironmentID string
CI bool
AgentApiKey string
Token string
}

type Config struct {
Expand Down
53 changes: 41 additions & 12 deletions cli/config/configurator.go
Expand Up @@ -6,6 +6,7 @@ import (
"net/http"
"strings"

"github.com/golang-jwt/jwt"
"github.com/kubeshop/tracetest/cli/analytics"
"github.com/kubeshop/tracetest/cli/pkg/oauth"
"github.com/kubeshop/tracetest/cli/pkg/resourcemanager"
Expand Down Expand Up @@ -99,32 +100,46 @@ func (c Configurator) Start(ctx context.Context, prev Config, flags ConfigFlags)
return Save(cfg)
}

if flags.AgentApiKey != "" {
cfg.AgentApiKey = flags.AgentApiKey
c.ShowOrganizationSelector(ctx, cfg, flags)
return nil
}
oauthEndpoint := fmt.Sprintf("%s%s", cfg.URL(), cfg.Path())

if prev.Jwt != "" {
cfg.Jwt = prev.Jwt
cfg.Token = prev.Token
}

if flags.Token != "" {
jwt, err := oauth.ExchangeToken(oauthEndpoint, flags.Token)
if err != nil {
return err
}

cfg.Jwt = jwt
cfg.Token = flags.Token

claims, err := GetTokenClaims(jwt)
if err != nil {
return err
}

flags.OrganizationID = claims["organization_id"].(string)
flags.EnvironmentID = claims["environment_id"].(string)
}

if flags.AgentApiKey != "" {
cfg.AgentApiKey = flags.AgentApiKey
c.ShowOrganizationSelector(ctx, cfg, flags)
return nil
}

confirmed := c.ui.Enter("Lets get to it! Press enter to launch a browser and authenticate:")
if !confirmed {
c.ui.Finish()
if cfg.Jwt != "" {
c.ShowOrganizationSelector(ctx, cfg, flags)
return nil
}

oauthServer := oauth.NewOAuthServer(fmt.Sprintf("%s%s", cfg.URL(), cfg.Path()), cfg.UIEndpoint)
err = oauthServer.WithOnSuccess(c.onOAuthSuccess(ctx, cfg)).
oauthServer := oauth.NewOAuthServer(oauthEndpoint, cfg.UIEndpoint)
return oauthServer.WithOnSuccess(c.onOAuthSuccess(ctx, cfg)).
WithOnFailure(c.onOAuthFailure).
GetAuthJWT()

return err
}

func (c Configurator) onOAuthSuccess(ctx context.Context, cfg Config) func(token, jwt string) {
Expand Down Expand Up @@ -182,3 +197,17 @@ func SetupHttpClient(cfg Config) *resourcemanager.HTTPClient {

return resourcemanager.NewHTTPClient(fmt.Sprintf("%s%s", cfg.URL(), cfg.Path()), extraHeaders)
}

func GetTokenClaims(tokenString string) (jwt.MapClaims, error) {
token, _, err := new(jwt.Parser).ParseUnverified(tokenString, jwt.MapClaims{})
if err != nil {
return jwt.MapClaims{}, err
}

claims, ok := token.Claims.(jwt.MapClaims)
if !ok {
return jwt.MapClaims{}, fmt.Errorf("invalid token claims")
}

return claims, nil
}
14 changes: 11 additions & 3 deletions cli/pkg/oauth/oauth.go
Expand Up @@ -22,6 +22,7 @@ type OAuthServer struct {
port int
server *http.Server
mutex sync.Mutex
ui ui.UI
}

type Option func(*OAuthServer)
Expand All @@ -30,6 +31,7 @@ func NewOAuthServer(endpoint, frontendEndpoint string) *OAuthServer {
return &OAuthServer{
endpoint: endpoint,
frontendEndpoint: frontendEndpoint,
ui: ui.DefaultUI,
}
}

Expand All @@ -44,6 +46,12 @@ func (s *OAuthServer) WithOnFailure(onFailure OnAuthFailure) *OAuthServer {
}

func (s *OAuthServer) GetAuthJWT() error {
confirmed := s.ui.Enter("Lets get to it! Press enter to launch a browser and authenticate:")
if !confirmed {
s.ui.Finish()
return nil
}

url, err := s.getUrl()
if err != nil {
return fmt.Errorf("failed to start oauth server: %w", err)
Expand All @@ -64,8 +72,8 @@ type JWTResponse struct {
Jwt string `json:"jwt"`
}

func (s *OAuthServer) ExchangeToken(token string) (string, error) {
req, err := http.NewRequest("GET", fmt.Sprintf("%s/tokens/%s/exchange", s.endpoint, token), nil)
func ExchangeToken(endpoint string, token string) (string, error) {
req, err := http.NewRequest("GET", fmt.Sprintf("%s/tokens/%s/exchange", endpoint, token), nil)
if err != nil {
return "", fmt.Errorf("failed to create request: %w", err)
}
Expand Down Expand Up @@ -133,7 +141,7 @@ func (s *OAuthServer) handleResult(r *http.Request) (string, string, error) {
return "", "", fmt.Errorf("tokenId not found")
}

jwt, err := s.ExchangeToken(tokenId)
jwt, err := ExchangeToken(s.endpoint, tokenId)
if err != nil {
return "", "", err
}
Expand Down
23 changes: 6 additions & 17 deletions cli/pkg/starter/starter.go
Expand Up @@ -6,7 +6,6 @@ import (
"errors"
"fmt"

"github.com/golang-jwt/jwt/v4"
agentConfig "github.com/kubeshop/tracetest/agent/config"
"github.com/kubeshop/tracetest/agent/initialization"

Expand All @@ -31,7 +30,11 @@ func (s *Starter) Run(ctx context.Context, cfg config.Config, flags config.Confi
s.ui.Println(`Tracetest start launches a lightweight agent. It enables you to run tests and collect traces with Tracetest.
Once started, Tracetest Agent exposes OTLP ports 4317 and 4318 to ingest traces via gRCP and HTTP.`)

return s.configurator.WithOnFinish(s.onStartAgent).Start(ctx, cfg, flags)
if flags.Token == "" || flags.AgentApiKey != "" {
s.configurator = s.configurator.WithOnFinish(s.onStartAgent)
}

return s.configurator.Start(ctx, cfg, flags)
}

func (s *Starter) onStartAgent(ctx context.Context, cfg config.Config) {
Expand Down Expand Up @@ -132,7 +135,7 @@ func (s *Starter) StartAgent(ctx context.Context, endpoint, agentApiKey, uiEndpo
isStarted = true
}

claims, err := s.getTokenClaims(session.Token)
claims, err := config.GetTokenClaims(session.Token)
if err != nil {
return err
}
Expand Down Expand Up @@ -160,17 +163,3 @@ You can`
}
return nil
}

func (s *Starter) getTokenClaims(tokenString string) (jwt.MapClaims, error) {
token, _, err := new(jwt.Parser).ParseUnverified(tokenString, jwt.MapClaims{})
if err != nil {
return jwt.MapClaims{}, err
}

claims, ok := token.Claims.(jwt.MapClaims)
if !ok {
return jwt.MapClaims{}, fmt.Errorf("invalid token claims")
}

return claims, nil
}

0 comments on commit 0aeb269

Please sign in to comment.