diff --git a/app/controlplane/internal/service/project.go b/app/controlplane/internal/service/project.go index d365e1946..4a3e56b4d 100644 --- a/app/controlplane/internal/service/project.go +++ b/app/controlplane/internal/service/project.go @@ -56,7 +56,7 @@ func (s *ProjectService) APITokenCreate(ctx context.Context, req *pb.ProjectServ *expiresIn = req.ExpiresIn.AsDuration() } - token, err := s.APITokenUseCase.Create(ctx, req.Name, req.Description, expiresIn, currentOrg.ID, biz.APITokenWithProjectID(project.ID)) + token, err := s.APITokenUseCase.Create(ctx, req.Name, req.Description, expiresIn, currentOrg.ID, biz.APITokenWithProject(project)) if err != nil { return nil, handleUseCaseErr(err, s.log) } @@ -81,7 +81,7 @@ func (s *ProjectService) APITokenList(ctx context.Context, req *pb.ProjectServic return nil, err } - tokens, err := s.APITokenUseCase.List(ctx, currentOrg.ID, req.IncludeRevoked, biz.APITokenWithProjectID(project.ID)) + tokens, err := s.APITokenUseCase.List(ctx, currentOrg.ID, req.IncludeRevoked, biz.APITokenWithProject(project)) if err != nil { return nil, handleUseCaseErr(err, s.log) } @@ -106,7 +106,7 @@ func (s *ProjectService) APITokenRevoke(ctx context.Context, req *pb.ProjectServ return nil, err } - t, err := s.APITokenUseCase.FindByNameInOrg(ctx, currentOrg.ID, req.Name, biz.APITokenWithProjectID(project.ID)) + t, err := s.APITokenUseCase.FindByNameInOrg(ctx, currentOrg.ID, req.Name, biz.APITokenWithProject(project)) if err != nil { return nil, handleUseCaseErr(err, s.log) } diff --git a/app/controlplane/internal/service/service.go b/app/controlplane/internal/service/service.go index abe0e4d20..e60a7e9f3 100644 --- a/app/controlplane/internal/service/service.go +++ b/app/controlplane/internal/service/service.go @@ -17,6 +17,7 @@ package service import ( "context" + "fmt" "io" "github.com/chainloop-dev/chainloop/app/controlplane/internal/usercontext" @@ -163,18 +164,35 @@ func (s *service) authorizeResource(ctx context.Context, op *authz.Policy, resou return nil } - // Apply RBAC + // 1 - Authorize using API token + // For now we only support API tokens to authorize project resourceTypes + // NOTE we do not run s.enforcer here because API tokens do not have roles associated with resourceTypes + // the authorization has happened at the API level and we do not have attribute-based policies in casbin yet + if token := entities.CurrentAPIToken(ctx); token != nil { + if resourceType == authz.ResourceTypeProject && token.ProjectID != nil && token.ProjectID.String() == resourceID.String() { + s.log.Debugw("msg", "authorized using API token", "resource_id", resourceID.String(), "resource_type", resourceType, "token_name", token.Name, "token_id", token.ID) + return nil + } + + return errors.Forbidden("forbidden", fmt.Errorf("operation not allowed: This auth token is valid only with the project %q", *token.ProjectName).Error()) + } + + // 2 - We are a user + // find the resource membership that matches the resource type and ID + // for example admin in project1, then apply RBAC enforcement m := entities.CurrentMembership(ctx) - // check for specific resource role for _, rm := range m.Resources { if rm.ResourceType == resourceType && rm.ResourceID == resourceID { pass, err := s.enforcer.Enforce(string(rm.Role), op) if err != nil { return handleUseCaseErr(err, s.log) } + if !pass { return errors.Forbidden("forbidden", "operation not allowed") } + + s.log.Debugw("msg", "authorized using user membership", "resource_id", resourceID.String(), "resource_type", resourceType, "role", rm.Role, "membership_id", rm.MembershipID, "user_id", m.UserID) return nil } } @@ -212,6 +230,15 @@ func (s *service) visibleProjects(ctx context.Context) []uuid.UUID { projects := make([]uuid.UUID, 0) + // 1 - Check if we are using an API token + if token := entities.CurrentAPIToken(ctx); token != nil { + if token.ProjectID != nil { + projects = append(projects, *token.ProjectID) + } + return projects + } + + // 2 - We are a user m := entities.CurrentMembership(ctx) for _, rm := range m.Resources { if rm.ResourceType == authz.ResourceTypeProject { @@ -222,8 +249,16 @@ func (s *service) visibleProjects(ctx context.Context) []uuid.UUID { return projects } -// RBAC feature is enabled if the user has the `Org Member` role. +// RBAC feature is enabled if we are using a project scoped token or +// it is a user with org role member func rbacEnabled(ctx context.Context) bool { + // it's an API token + token := entities.CurrentAPIToken(ctx) + if token != nil { + return token.ProjectID != nil + } + + // we have an user currentSubject := usercontext.CurrentAuthzSubject(ctx) return currentSubject == string(authz.RoleOrgMember) } diff --git a/app/controlplane/internal/service/workflow.go b/app/controlplane/internal/service/workflow.go index 34fedbebd..298304cc7 100644 --- a/app/controlplane/internal/service/workflow.go +++ b/app/controlplane/internal/service/workflow.go @@ -1,5 +1,5 @@ // -// Copyright 2024 The Chainloop Authors. +// Copyright 2024-2025 The Chainloop Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/app/controlplane/internal/usercontext/apitoken_middleware.go b/app/controlplane/internal/usercontext/apitoken_middleware.go index 226aa4b01..eb5f1b46a 100644 --- a/app/controlplane/internal/usercontext/apitoken_middleware.go +++ b/app/controlplane/internal/usercontext/apitoken_middleware.go @@ -71,12 +71,15 @@ func WithCurrentAPITokenAndOrgMiddleware(apiTokenUC *biz.APITokenUseCase, orgUC return nil, errors.New("error mapping the API-token claims") } - ctx, err = setCurrentOrgAndAPIToken(ctx, apiTokenUC, orgUC, tokenID) + // Project ID is optional + projectID, _ := genericClaims["project_id"].(string) + + ctx, err = setCurrentOrgAndAPIToken(ctx, apiTokenUC, orgUC, tokenID, projectID) if err != nil { return nil, fmt.Errorf("error setting current org and user: %w", err) } - logger.Infow("msg", "[authN] processed credentials", "id", tokenID, "type", "API-token") + logger.Infow("msg", "[authN] processed credentials", "id", tokenID, "type", "API-token", "projectID", projectID) } return handler(ctx, req) @@ -120,7 +123,7 @@ func WithAttestationContextFromAPIToken(apiTokenUC *biz.APITokenUseCase, orgUC * return nil, fmt.Errorf("error extracting organization from APIToken: %w", err) } - ctx, err = setCurrentOrgAndAPIToken(ctx, apiTokenUC, orgUC, tokenID) + ctx, err = setCurrentOrgAndAPIToken(ctx, apiTokenUC, orgUC, tokenID, claims.ProjectID) if err != nil { return nil, fmt.Errorf("error setting current org and user: %w", err) } @@ -157,7 +160,7 @@ func setRobotAccountFromAPIToken(ctx context.Context, apiTokenUC *biz.APITokenUs } // Set the current organization and API-Token in the context -func setCurrentOrgAndAPIToken(ctx context.Context, apiTokenUC *biz.APITokenUseCase, orgUC *biz.OrganizationUseCase, tokenID string) (context.Context, error) { +func setCurrentOrgAndAPIToken(ctx context.Context, apiTokenUC *biz.APITokenUseCase, orgUC *biz.OrganizationUseCase, tokenID, projectIDInClaim string) (context.Context, error) { if tokenID == "" { return nil, errors.New("error retrieving the key ID from the API token") } @@ -170,6 +173,11 @@ func setCurrentOrgAndAPIToken(ctx context.Context, apiTokenUC *biz.APITokenUseCa return nil, errors.New("API token not found") } + // Make sure that the projectID that comes in the token claim matches the one in the DB + if projectIDInClaim != "" && token.ProjectID.String() != projectIDInClaim { + return nil, errors.New("API token project mismatch") + } + // Note: Expiration time does not need to be checked because that's done at the JWT // verification layer, which happens before this middleware is called if token.RevokedAt != nil { @@ -186,7 +194,15 @@ func setCurrentOrgAndAPIToken(ctx context.Context, apiTokenUC *biz.APITokenUseCa // Set the current organization and API-Token in the context ctx = entities.WithCurrentOrg(ctx, &entities.Org{Name: org.Name, ID: org.ID, CreatedAt: org.CreatedAt}) - ctx = entities.WithCurrentAPIToken(ctx, &entities.APIToken{ID: token.ID.String(), CreatedAt: token.CreatedAt, Token: token.JWT}) + + ctx = entities.WithCurrentAPIToken(ctx, &entities.APIToken{ + ID: token.ID.String(), + Name: token.Name, + CreatedAt: token.CreatedAt, + Token: token.JWT, + ProjectID: token.ProjectID, + ProjectName: token.ProjectName, + }) // Set the authorization subject that will be used to check the policies subjectAPIToken := authz.SubjectAPIToken{ID: token.ID.String()} diff --git a/app/controlplane/internal/usercontext/currentorganization_middleware.go b/app/controlplane/internal/usercontext/currentorganization_middleware.go index 394790395..94f54c18a 100644 --- a/app/controlplane/internal/usercontext/currentorganization_middleware.go +++ b/app/controlplane/internal/usercontext/currentorganization_middleware.go @@ -105,10 +105,11 @@ func setCurrentMembershipsForUser(ctx context.Context, u *entities.User, members Role: m.Role, ResourceType: m.ResourceType, ResourceID: m.ResourceID, + MembershipID: m.ID, }) } - membership = &entities.Membership{Resources: resourceMemberships} + membership = &entities.Membership{UserID: uuid.MustParse(u.ID), Resources: resourceMemberships} membershipsCache.Add(u.ID, membership) } diff --git a/app/controlplane/internal/usercontext/entities/apitoken.go b/app/controlplane/internal/usercontext/entities/apitoken.go index 5ff78174c..efdab169e 100644 --- a/app/controlplane/internal/usercontext/entities/apitoken.go +++ b/app/controlplane/internal/usercontext/entities/apitoken.go @@ -1,5 +1,5 @@ // -// Copyright 2024 The Chainloop Authors. +// Copyright 2024-2025 The Chainloop Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -18,12 +18,18 @@ package entities import ( "context" "time" + + "github.com/google/uuid" ) type APIToken struct { - ID string - CreatedAt *time.Time - Token string + ID string + // Token Name + Name string + CreatedAt *time.Time + Token string + ProjectID *uuid.UUID + ProjectName *string } func WithCurrentAPIToken(ctx context.Context, token *APIToken) context.Context { diff --git a/app/controlplane/internal/usercontext/entities/memberships.go b/app/controlplane/internal/usercontext/entities/memberships.go index b8cd9cc9c..042f78559 100644 --- a/app/controlplane/internal/usercontext/entities/memberships.go +++ b/app/controlplane/internal/usercontext/entities/memberships.go @@ -23,10 +23,12 @@ import ( ) type Membership struct { + UserID uuid.UUID Resources []*ResourceMembership } type ResourceMembership struct { + MembershipID uuid.UUID Role authz.Role ResourceType authz.ResourceType ResourceID uuid.UUID diff --git a/app/controlplane/pkg/biz/apitoken.go b/app/controlplane/pkg/biz/apitoken.go index 8b2ea6948..030d3737a 100644 --- a/app/controlplane/pkg/biz/apitoken.go +++ b/app/controlplane/pkg/biz/apitoken.go @@ -128,15 +128,15 @@ func NewAPITokenUseCase(apiTokenRepo APITokenRepo, jwtConfig *APITokenJWTConfig, } type apiTokenOptions struct { - projectID *uuid.UUID + project *Project showOnlySystemTokens bool } type APITokenUseCaseOpt func(*apiTokenOptions) -func APITokenWithProjectID(projectID uuid.UUID) APITokenUseCaseOpt { +func APITokenWithProject(project *Project) APITokenUseCaseOpt { return func(o *apiTokenOptions) { - o.projectID = &projectID + o.project = project } } @@ -181,9 +181,15 @@ func (uc *APITokenUseCase) Create(ctx context.Context, name string, description return nil, fmt.Errorf("finding organization: %w", err) } + // If a project is provided, we store it in the token + var projectID *uuid.UUID + if options.project != nil { + projectID = ToPtr(options.project.ID) + } + // NOTE: the expiration time is stored just for reference, it's also encoded in the JWT // We store it since Chainloop will not have access to the JWT to check the expiration once created - token, err := uc.apiTokenRepo.Create(ctx, name, description, expiresAt, orgUUID, options.projectID) + token, err := uc.apiTokenRepo.Create(ctx, name, description, expiresAt, orgUUID, projectID) if err != nil { if IsErrAlreadyExists(err) { return nil, NewErrAlreadyExistsStr("name already taken") @@ -191,8 +197,22 @@ func (uc *APITokenUseCase) Create(ctx context.Context, name string, description return nil, fmt.Errorf("storing token: %w", err) } + generationOpts := &apitoken.GenerateJWTOptions{ + OrgID: token.OrganizationID, + OrgName: org.Name, + KeyID: token.ID, + KeyName: name, + ExpiresAt: expiresAt, + } + + if projectID != nil { + generationOpts.ProjectID = ToPtr(options.project.ID) + generationOpts.ProjectName = ToPtr(options.project.Name) + } + // generate the JWT - token.JWT, err = uc.jwtBuilder.GenerateJWT(token.OrganizationID.String(), org.Name, token.ID.String(), expiresAt) + token.JWT, err = uc.jwtBuilder.GenerateJWT(generationOpts) + if err != nil { return nil, fmt.Errorf("generating jwt: %w", err) } @@ -233,8 +253,16 @@ func (uc *APITokenUseCase) RegenerateJWT(ctx context.Context, tokenID uuid.UUID, return nil, fmt.Errorf("finding organization: %w", err) } + generationOpts := &apitoken.GenerateJWTOptions{ + OrgID: token.OrganizationID, + OrgName: org.Name, + KeyID: token.ID, + KeyName: token.Name, + ExpiresAt: &expiresAt, + } + // generate the JWT - token.JWT, err = uc.jwtBuilder.GenerateJWT(token.OrganizationID.String(), org.Name, token.ID.String(), &expiresAt) + token.JWT, err = uc.jwtBuilder.GenerateJWT(generationOpts) if err != nil { return nil, fmt.Errorf("generating jwt: %w", err) } @@ -258,7 +286,12 @@ func (uc *APITokenUseCase) List(ctx context.Context, orgID string, includeRevoke return nil, NewErrInvalidUUID(err) } - return uc.apiTokenRepo.List(ctx, &orgUUID, options.projectID, includeRevoked, options.showOnlySystemTokens) + var projectID *uuid.UUID + if options.project != nil { + projectID = ToPtr(options.project.ID) + } + + return uc.apiTokenRepo.List(ctx, &orgUUID, projectID, includeRevoked, options.showOnlySystemTokens) } func (uc *APITokenUseCase) Revoke(ctx context.Context, orgID, id string) error { @@ -308,7 +341,12 @@ func (uc *APITokenUseCase) FindByNameInOrg(ctx context.Context, orgID, name stri return nil, NewErrInvalidUUID(err) } - t, err := uc.apiTokenRepo.FindByNameInOrg(ctx, orgUUID, name, options.projectID) + var projectID *uuid.UUID + if options.project != nil { + projectID = ToPtr(options.project.ID) + } + + t, err := uc.apiTokenRepo.FindByNameInOrg(ctx, orgUUID, name, projectID) if err != nil { return nil, fmt.Errorf("finding token: %w", err) } diff --git a/app/controlplane/pkg/biz/apitoken_integration_test.go b/app/controlplane/pkg/biz/apitoken_integration_test.go index 0989095c3..ba29a5bf4 100644 --- a/app/controlplane/pkg/biz/apitoken_integration_test.go +++ b/app/controlplane/pkg/biz/apitoken_integration_test.go @@ -66,7 +66,7 @@ func (s *apiTokenTestSuite) TestCreate() { }) s.Run("happy path with project", func() { - token, err := s.APIToken.Create(ctx, randomName(), nil, nil, s.org.ID, biz.APITokenWithProjectID(s.p1.ID)) + token, err := s.APIToken.Create(ctx, randomName(), nil, nil, s.org.ID, biz.APITokenWithProject(s.p1)) s.NoError(err) s.Equal(s.org.ID, token.OrganizationID.String()) s.Equal(s.p1.ID, *token.ProjectID) @@ -78,7 +78,7 @@ func (s *apiTokenTestSuite) TestCreate() { name string tokenName string wantErrMsg string - projectID *uuid.UUID + project *biz.Project }{ { name: "name missing", @@ -107,17 +107,17 @@ func (s *apiTokenTestSuite) TestCreate() { { name: "tokens in projects can have the same name", tokenName: "my-name", - projectID: &s.p1.ID, + project: s.p1, }, { name: "tokens in different projects too", tokenName: "my-name", - projectID: &s.p2.ID, + project: s.p2, }, { name: "can't be duplicated in the same project", tokenName: "my-name", - projectID: &s.p1.ID, + project: s.p1, wantErrMsg: "name already taken", }, } @@ -125,8 +125,8 @@ func (s *apiTokenTestSuite) TestCreate() { for _, tc := range testCases { s.Run(tc.name, func() { var opts []biz.APITokenUseCaseOpt - if tc.projectID != nil { - opts = append(opts, biz.APITokenWithProjectID(*tc.projectID)) + if tc.project != nil { + opts = append(opts, biz.APITokenWithProject(tc.project)) } token, err := s.APIToken.Create(ctx, tc.tokenName, nil, nil, s.org.ID, opts...) @@ -273,7 +273,7 @@ func (s *apiTokenTestSuite) TestList() { s.Run("can return only for a specific project", func() { var err error - tokens, err := s.APIToken.List(ctx, s.org.ID, false, biz.APITokenWithProjectID(s.p1.ID)) + tokens, err := s.APIToken.List(ctx, s.org.ID, false, biz.APITokenWithProject(s.p1)) s.NoError(err) require.Len(s.T(), tokens, 2) s.Equal(s.t4.ID, tokens[0].ID) @@ -369,8 +369,8 @@ func (s *apiTokenTestSuite) SetupTest() { require.NoError(s.T(), err) // Create 2 tokens for project 1 - s.t4, err = s.APIToken.Create(ctx, randomName(), nil, nil, s.org.ID, biz.APITokenWithProjectID(s.p1.ID)) + s.t4, err = s.APIToken.Create(ctx, randomName(), nil, nil, s.org.ID, biz.APITokenWithProject(s.p1)) require.NoError(s.T(), err) - s.t5, err = s.APIToken.Create(ctx, randomName(), nil, nil, s.org.ID, biz.APITokenWithProjectID(s.p1.ID)) + s.t5, err = s.APIToken.Create(ctx, randomName(), nil, nil, s.org.ID, biz.APITokenWithProject(s.p1)) require.NoError(s.T(), err) } diff --git a/app/controlplane/pkg/jwt/apitoken/apitoken.go b/app/controlplane/pkg/jwt/apitoken/apitoken.go index 5fa50fff6..a9be059bd 100644 --- a/app/controlplane/pkg/jwt/apitoken/apitoken.go +++ b/app/controlplane/pkg/jwt/apitoken/apitoken.go @@ -1,5 +1,5 @@ // -// Copyright 2024 The Chainloop Authors. +// Copyright 2024-2025 The Chainloop Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -20,6 +20,7 @@ import ( "time" "github.com/golang-jwt/jwt/v4" + "github.com/google/uuid" ) var SigningMethod = jwt.SigningMethodHS256 @@ -66,22 +67,58 @@ func NewBuilder(opts ...NewOpt) (*Builder, error) { return b, nil } +type GenerateJWTOptions struct { + OrgID uuid.UUID + OrgName string + KeyID uuid.UUID + KeyName string + ProjectID *uuid.UUID + ProjectName *string + ExpiresAt *time.Time +} + // GenerateJWT creates a new JWT token for the given organization and keyID -func (ra *Builder) GenerateJWT(orgID, orgName, keyID string, expiresAt *time.Time) (string, error) { +func (ra *Builder) GenerateJWT(opts *GenerateJWTOptions) (string, error) { + if opts == nil { + return "", errors.New("options are required") + } + + if opts.OrgID == uuid.Nil { + return "", errors.New("orgID is required") + } + + if opts.OrgName == "" { + return "", errors.New("orgName is required") + } + + if opts.KeyID == uuid.Nil { + return "", errors.New("keyID is required") + } + + if opts.KeyName == "" { + return "", errors.New("keyName is required") + } + claims := CustomClaims{ - orgID, - orgName, - jwt.RegisteredClaims{ + OrgID: opts.OrgID.String(), + OrgName: opts.OrgName, + KeyName: opts.KeyName, + RegisteredClaims: jwt.RegisteredClaims{ // Key identifier so we can check its revocation status - ID: keyID, + ID: opts.KeyID.String(), Issuer: ra.issuer, Audience: jwt.ClaimStrings{Audience}, }, } + if opts.ProjectID != nil { + claims.ProjectID = opts.ProjectID.String() + claims.ProjectName = *opts.ProjectName + } + // optional expiration value, i.e 30 days - if expiresAt != nil { - claims.ExpiresAt = jwt.NewNumericDate(*expiresAt) + if opts.ExpiresAt != nil { + claims.ExpiresAt = jwt.NewNumericDate(*opts.ExpiresAt) } resultToken := jwt.NewWithClaims(SigningMethod, claims) @@ -89,7 +126,10 @@ func (ra *Builder) GenerateJWT(orgID, orgName, keyID string, expiresAt *time.Tim } type CustomClaims struct { - OrgID string `json:"org_id"` - OrgName string `json:"org_name"` + OrgID string `json:"org_id"` + OrgName string `json:"org_name"` + KeyName string `json:"token_name"` + ProjectID string `json:"project_id,omitempty"` + ProjectName string `json:"project_name,omitempty"` jwt.RegisteredClaims } diff --git a/app/controlplane/pkg/jwt/apitoken/apitoken_test.go b/app/controlplane/pkg/jwt/apitoken/apitoken_test.go index 27a28bbbb..e04aaec39 100644 --- a/app/controlplane/pkg/jwt/apitoken/apitoken_test.go +++ b/app/controlplane/pkg/jwt/apitoken/apitoken_test.go @@ -1,5 +1,5 @@ // -// Copyright 2024 The Chainloop Authors. +// Copyright 2024-2025 The Chainloop Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -20,6 +20,7 @@ import ( "time" "github.com/golang-jwt/jwt/v4" + "github.com/google/uuid" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -69,30 +70,127 @@ func TestNewBuilder(t *testing.T) { func TestGenerateJWT(t *testing.T) { const hmacSecret = "my-secret" + testCases := []struct { + name string + opts *GenerateJWTOptions + wantErr bool + }{ + { + name: "no project", + opts: &GenerateJWTOptions{ + OrgID: uuid.MustParse("123e4567-e89b-12d3-a456-426614174000"), + OrgName: "org-name", + KeyName: "key-name", + KeyID: uuid.MustParse("123e4567-e89b-12d3-a456-426614174000"), + ExpiresAt: toPtr(time.Now().Add(1 * time.Hour)), + }, + }, + { + name: "no expiration", + opts: &GenerateJWTOptions{ + OrgID: uuid.MustParse("123e4567-e89b-12d3-a456-426614174000"), + OrgName: "org-name", + KeyName: "key-name", + KeyID: uuid.MustParse("123e4567-e89b-12d3-a456-426614174000"), + }, + }, + { + name: "with project", + opts: &GenerateJWTOptions{ + OrgID: uuid.MustParse("123e4567-e89b-12d3-a456-426614174000"), + OrgName: "org-name", + KeyName: "key-name", + KeyID: uuid.MustParse("123e4567-e89b-12d3-a456-426614174000"), + ProjectID: toPtr(uuid.MustParse("123e4567-e89b-12d3-a456-426614174000")), + ProjectName: toPtr("project-name"), + ExpiresAt: toPtr(time.Now().Add(1 * time.Hour)), + }, + }, + { + name: "missing orgID", + opts: &GenerateJWTOptions{ + OrgName: "org-name", + KeyName: "key-name", + KeyID: uuid.MustParse("123e4567-e89b-12d3-a456-426614174000"), + ExpiresAt: toPtr(time.Now().Add(1 * time.Hour)), + }, + wantErr: true, + }, + { + name: "missing orgName", + opts: &GenerateJWTOptions{ + OrgID: uuid.MustParse("123e4567-e89b-12d3-a456-426614174000"), + KeyName: "key-name", + KeyID: uuid.MustParse("123e4567-e89b-12d3-a456-426614174000"), + ExpiresAt: toPtr(time.Now().Add(1 * time.Hour)), + }, + wantErr: true, + }, + { + name: "missing keyID", + opts: &GenerateJWTOptions{ + OrgID: uuid.MustParse("123e4567-e89b-12d3-a456-426614174000"), + OrgName: "org-name", + KeyName: "key-name", + ExpiresAt: toPtr(time.Now().Add(1 * time.Hour)), + }, + wantErr: true, + }, + { + name: "missing keyName", + opts: &GenerateJWTOptions{ + OrgID: uuid.MustParse("123e4567-e89b-12d3-a456-426614174000"), + OrgName: "org-name", + KeyID: uuid.MustParse("123e4567-e89b-12d3-a456-426614174000"), + ExpiresAt: toPtr(time.Now().Add(1 * time.Hour)), + }, + wantErr: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + b, err := NewBuilder(WithIssuer("my-issuer"), WithKeySecret(hmacSecret)) + require.NoError(t, err) + + token, err := b.GenerateJWT(tc.opts) + if tc.wantErr { + require.Error(t, err) + return + } + + require.NoError(t, err) + assert.NotEmpty(t, token) + + claims := &CustomClaims{} + tokenInfo, err := jwt.ParseWithClaims(token, claims, func(_ *jwt.Token) (interface{}, error) { + return []byte(hmacSecret), nil + }) + + require.NoError(t, err) + assert.True(t, tokenInfo.Valid) + assert.Equal(t, tc.opts.OrgID.String(), claims.OrgID) + assert.Equal(t, tc.opts.OrgName, claims.OrgName) + assert.Equal(t, tc.opts.KeyID.String(), claims.ID) + assert.Equal(t, tc.opts.KeyName, claims.KeyName) + + if tc.opts.ProjectID != nil { + assert.Equal(t, tc.opts.ProjectID.String(), claims.ProjectID) + assert.Equal(t, *tc.opts.ProjectName, claims.ProjectName) + } else { + assert.Empty(t, claims.ProjectID) + assert.Empty(t, claims.ProjectName) + } - b, err := NewBuilder(WithIssuer("my-issuer"), WithKeySecret(hmacSecret)) - require.NoError(t, err) - - token, err := b.GenerateJWT("org-id", "org-name", "key-id", toPtrTime(time.Now().Add(1*time.Hour))) - assert.NoError(t, err) - assert.NotEmpty(t, token) - - // Verify signature and check claims - claims := &CustomClaims{} - tokenInfo, err := jwt.ParseWithClaims(token, claims, func(_ *jwt.Token) (interface{}, error) { - return []byte(hmacSecret), nil - }) - - require.NoError(t, err) - assert.True(t, tokenInfo.Valid) - assert.Equal(t, "org-id", claims.OrgID) - assert.Equal(t, "org-name", claims.OrgName) - assert.Equal(t, "key-id", claims.ID) - assert.Equal(t, "my-issuer", claims.Issuer) - assert.Contains(t, claims.Audience, Audience) - assert.NotNil(t, claims.ExpiresAt) + if tc.opts.ExpiresAt != nil { + assert.True(t, claims.ExpiresAt.After(time.Now())) + } else { + assert.Nil(t, claims.ExpiresAt) + } + }) + } } -func toPtrTime(t time.Time) *time.Time { +func toPtr[T any](t T) *T { return &t }