From 7fdac7d2c70c5cf1eab4bd781356cb4fd9d9a759 Mon Sep 17 00:00:00 2001 From: Miguel Martinez Trivino Date: Sat, 16 Dec 2023 09:59:55 +0100 Subject: [PATCH 1/5] feat: allow upload/downloads using API token Signed-off-by: Miguel Martinez Trivino --- app/controlplane/internal/biz/casmapping.go | 58 +++++++++++-------- .../biz/casmapping_integration_test.go | 18 +++--- .../internal/service/cascredential.go | 17 ++++-- .../internal/service/casredirect.go | 17 ++++-- app/controlplane/internal/service/context.go | 10 +--- app/controlplane/internal/service/service.go | 20 +++++++ 6 files changed, 91 insertions(+), 49 deletions(-) diff --git a/app/controlplane/internal/biz/casmapping.go b/app/controlplane/internal/biz/casmapping.go index aad864d75..e5268e5d8 100644 --- a/app/controlplane/internal/biz/casmapping.go +++ b/app/controlplane/internal/biz/casmapping.go @@ -79,12 +79,12 @@ func (uc *CASMappingUseCase) FindByDigest(ctx context.Context, digest string) ([ return uc.repo.FindByDigest(ctx, digest) } -// FindCASMappingForDownload returns the CASMapping appropriate for the given digest and user +// FindCASMappingForDownloadByUser returns the CASMapping appropriate for the given digest and user // This means, in order // 1 - Any mapping that points to an organization which the user is member of // 1.1 If there are multiple mappings, it will pick the default one or the first one // 2 - Any mapping that is public -func (uc *CASMappingUseCase) FindCASMappingForDownload(ctx context.Context, digest string, userID string) (*CASMapping, error) { +func (uc *CASMappingUseCase) FindCASMappingForDownloadByUser(ctx context.Context, digest string, userID string) (*CASMapping, error) { uc.logger.Infow("msg", "finding cas mapping for download", "digest", digest, "user", userID) userUUID, err := uuid.Parse(userID) @@ -92,29 +92,48 @@ func (uc *CASMappingUseCase) FindCASMappingForDownload(ctx context.Context, dige return nil, NewErrInvalidUUID(err) } - if _, err = cr_v1.NewHash(digest); err != nil { + // Load organizations for the given user + memberships, err := uc.membershipRepo.FindByUser(ctx, userUUID) + if err != nil { + return nil, fmt.Errorf("failed to list memberships: %w", err) + } + + userOrgs := make([]string, 0, len(memberships)) + for _, m := range memberships { + userOrgs = append(userOrgs, m.OrganizationID.String()) + } + + return uc.FindCASMappingForDownloadByOrg(ctx, digest, userOrgs) +} + +func (uc *CASMappingUseCase) FindCASMappingForDownloadByOrg(ctx context.Context, digest string, orgs []string) (*CASMapping, error) { + if _, err := cr_v1.NewHash(digest); err != nil { return nil, NewErrValidation(fmt.Errorf("invalid digest format: %w", err)) } + if len(orgs) == 0 { + return nil, NewErrValidationStr("no organizations provided") + } + // 1 - All CAS mappings for the given digest mappings, err := uc.repo.FindByDigest(ctx, digest) if err != nil { return nil, fmt.Errorf("failed to list cas mappings: %w", err) } - uc.logger.Debugw("msg", fmt.Sprintf("found %d entries globally", len(mappings)), "digest", digest, "user", userID) + uc.logger.Debugw("msg", fmt.Sprintf("found %d entries globally", len(mappings)), "digest", digest, "orgs", orgs) if len(mappings) == 0 { return nil, NewErrNotFound("digest not found in any mapping") } - // 2 - CAS mappings that the user has access to. - // This means any mapping that points to an organization which the user is member of - userMappings, err := filterByUser(ctx, mappings, userUUID, uc.membershipRepo) + // 2 - CAS mappings associated with the given list of orgs + orgMappings, err := filterByOrgs(mappings, orgs) if err != nil { return nil, fmt.Errorf("failed to load mappings associated to an user: %w", err) - } else if len(userMappings) > 0 { - result := defaultOrFirst(userMappings) - uc.logger.Infow("msg", "mapping found!", "digest", digest, "user", userID, "casBackend", result.CASBackend.ID, "default", result.CASBackend.Default, "public", result.Public) + } else if len(orgMappings) > 0 { + result := defaultOrFirst(orgMappings) + + uc.logger.Infow("msg", "mapping found!", "digest", digest, "orgs", orgs, "casBackend", result.CASBackend.ID, "default", result.CASBackend.Default, "public", result.Public) return result, nil } @@ -122,30 +141,23 @@ func (uc *CASMappingUseCase) FindCASMappingForDownload(ctx context.Context, dige publicMappings := filterByPublic(mappings) // The user has not access to neither proprietary nor public mappings if len(publicMappings) == 0 { - uc.logger.Warnw("msg", "digest exist but user does not have access to it", "digest", digest, "user", userID) + uc.logger.Warnw("msg", "digest exist but user does not have access to it", "digest", digest, "orgs", orgs) return nil, NewErrUnauthorized(errors.New("unauthorized access to the artifact")) } // Pick the appropriate mapping from multiple ones result := defaultOrFirst(publicMappings) - uc.logger.Infow("msg", "mapping found!", "digest", digest, "user", userID, "casBackend", result.CASBackend.ID, "default", result.CASBackend.Default, "public", result.Public) + uc.logger.Infow("msg", "mapping found!", "digest", digest, "orgs", orgs, "casBackend", result.CASBackend.ID, "default", result.CASBackend.Default, "public", result.Public) return result, nil } -// get the casMapping based on -// 1 - the mapping is part of an organization an user has access to -// 2 - if there is more than one, pick the default if possible -func filterByUser(ctx context.Context, mappings []*CASMapping, userID uuid.UUID, mRepo MembershipRepo) ([]*CASMapping, error) { +// Extract only the mappings associated with a list of orgs +func filterByOrgs(mappings []*CASMapping, orgs []string) ([]*CASMapping, error) { result := make([]*CASMapping, 0) - memberships, err := mRepo.FindByUser(ctx, userID) - if err != nil { - return nil, fmt.Errorf("failed to list memberships: %w", err) - } - for _, mapping := range mappings { - for _, m := range memberships { - if mapping.OrgID == m.OrganizationID { + for _, o := range orgs { + if mapping.OrgID.String() == o { result = append(result, mapping) } } diff --git a/app/controlplane/internal/biz/casmapping_integration_test.go b/app/controlplane/internal/biz/casmapping_integration_test.go index 73f84b8ff..6cd0be26c 100644 --- a/app/controlplane/internal/biz/casmapping_integration_test.go +++ b/app/controlplane/internal/biz/casmapping_integration_test.go @@ -39,7 +39,7 @@ const ( invalidDigest = "sha256:deadbeef" ) -func (s *casMappingIntegrationSuite) TestCASMappingForDownload() { +func (s *casMappingIntegrationSuite) TestCASMappingForDownlod() { // Let's create 3 CASMappings: // 1. Digest: validDigest, CASBackend: casBackend1, WorkflowRunID: workflowRun // 2. Digest: validDigest, CASBackend: casBackend2, WorkflowRunID: workflowRun @@ -60,55 +60,55 @@ func (s *casMappingIntegrationSuite) TestCASMappingForDownload() { // Since the userOrg1And2 is member of org1 and org2, she should be able to download // both validDigest and validDigest2 from two different orgs s.Run("userOrg1And2 can download validDigest from org1", func() { - mapping, err := s.CASMapping.FindCASMappingForDownload(context.TODO(), validDigest, s.userOrg1And2.ID) + mapping, err := s.CASMapping.FindCASMappingForDownloadByUser(context.TODO(), validDigest, s.userOrg1And2.ID) s.NoError(err) s.NotNil(mapping) s.Equal(s.casBackend1.ID, mapping.CASBackend.ID) }) s.Run("userOrg1And2 can download validDigest2 from org2", func() { - mapping, err := s.CASMapping.FindCASMappingForDownload(context.TODO(), validDigest2, s.userOrg1And2.ID) + mapping, err := s.CASMapping.FindCASMappingForDownloadByUser(context.TODO(), validDigest2, s.userOrg1And2.ID) s.NoError(err) s.NotNil(mapping) s.Equal(s.casBackend2.ID, mapping.CASBackend.ID) }) s.Run("userOrg1And2 can not download validDigest3 from org3", func() { - mapping, err := s.CASMapping.FindCASMappingForDownload(context.TODO(), validDigest3, s.userOrg1And2.ID) + mapping, err := s.CASMapping.FindCASMappingForDownloadByUser(context.TODO(), validDigest3, s.userOrg1And2.ID) s.Error(err) s.Nil(mapping) }) s.Run("userOrg1And2 can download validDigestPublic from org3", func() { - mapping, err := s.CASMapping.FindCASMappingForDownload(context.TODO(), validDigestPublic, s.userOrg1And2.ID) + mapping, err := s.CASMapping.FindCASMappingForDownloadByUser(context.TODO(), validDigestPublic, s.userOrg1And2.ID) s.NoError(err) s.NotNil(mapping) s.Equal(s.casBackend3.ID, mapping.CASBackend.ID) }) s.Run("userOrg2 can download validDigest2 from org2", func() { - mapping, err := s.CASMapping.FindCASMappingForDownload(context.TODO(), validDigest2, s.userOrg2.ID) + mapping, err := s.CASMapping.FindCASMappingForDownloadByUser(context.TODO(), validDigest2, s.userOrg2.ID) s.NoError(err) s.NotNil(mapping) s.Equal(s.casBackend2.ID, mapping.CASBackend.ID) }) s.Run("userOrg2 can download validDigestPublic from org3", func() { - mapping, err := s.CASMapping.FindCASMappingForDownload(context.TODO(), validDigestPublic, s.userOrg2.ID) + mapping, err := s.CASMapping.FindCASMappingForDownloadByUser(context.TODO(), validDigestPublic, s.userOrg2.ID) s.NoError(err) s.NotNil(mapping) s.Equal(s.casBackend3.ID, mapping.CASBackend.ID) }) s.Run("userOrg2 can download validDigest from org2", func() { - mapping, err := s.CASMapping.FindCASMappingForDownload(context.TODO(), validDigest, s.userOrg2.ID) + mapping, err := s.CASMapping.FindCASMappingForDownloadByUser(context.TODO(), validDigest, s.userOrg2.ID) s.NoError(err) s.NotNil(mapping) s.Equal(s.casBackend2.ID, mapping.CASBackend.ID) }) s.Run("userOrg2 can not download invalidDigest", func() { - mapping, err := s.CASMapping.FindCASMappingForDownload(context.TODO(), invalidDigest, s.userOrg2.ID) + mapping, err := s.CASMapping.FindCASMappingForDownloadByUser(context.TODO(), invalidDigest, s.userOrg2.ID) s.Error(err) s.Nil(mapping) }) diff --git a/app/controlplane/internal/service/cascredential.go b/app/controlplane/internal/service/cascredential.go index 701e3bb1f..74c7f31a0 100644 --- a/app/controlplane/internal/service/cascredential.go +++ b/app/controlplane/internal/service/cascredential.go @@ -48,8 +48,7 @@ func NewCASCredentialsService(casUC *biz.CASCredentialsUseCase, casmUC *biz.CASM // Get will generate temporary credentials to be used against the CAS service for the current organization func (s *CASCredentialsService) Get(ctx context.Context, req *pb.CASCredentialsServiceGetRequest) (*pb.CASCredentialsServiceGetResponse, error) { - // TODO: Add support API-Token-based authentication - currentUser, err := requireCurrentUser(ctx) + currentUser, currentAPIToken, err := requireCurrentUserOrAPIToken(ctx) if err != nil { return nil, err } @@ -72,12 +71,20 @@ func (s *CASCredentialsService) Get(ctx context.Context, req *pb.CASCredentialsS if err != nil && !biz.IsNotFound(err) { return nil, sl.LogAndMaskErr(err, s.log) } else if backend == nil { - return nil, errors.NotFound("not found", "main repository not found") + return nil, errors.NotFound("not found", "main CAS backend not found") } + // Try to find the proper backend where the artifact is stored if role == casJWT.Downloader { - // Try to find the proper backend where the artifact is stored - mapping, err := s.casMappingUC.FindCASMappingForDownload(ctx, req.Digest, currentUser.ID) + var mapping *biz.CASMapping + // If we are logged in as a user, we'll try to find a mapping for that user + if currentUser != nil { + mapping, err = s.casMappingUC.FindCASMappingForDownloadByUser(ctx, req.Digest, currentUser.ID) + // otherwise, we'll try to find a mapping for the current API token associated orgs + } else if currentAPIToken != nil { + mapping, err = s.casMappingUC.FindCASMappingForDownloadByOrg(ctx, req.Digest, []string{currentOrg.ID}) + } + // If we can't find a mapping, we'll use the default backend if err != nil && !biz.IsNotFound(err) && !biz.IsErrUnauthorized(err) { if biz.IsErrValidation(err) { diff --git a/app/controlplane/internal/service/casredirect.go b/app/controlplane/internal/service/casredirect.go index 01ea9f296..c35f08f31 100644 --- a/app/controlplane/internal/service/casredirect.go +++ b/app/controlplane/internal/service/casredirect.go @@ -68,14 +68,23 @@ func NewCASRedirectService(casmUC *biz.CASMappingUseCase, casCredsUC *biz.CASCre // The URL includes a JWT token that is used to authenticate the request, this token has all the information required to validate the request // The result would look like "https://cas.chainloop.dev/download/sha256:[DIGEST]?t=tokenJWT func (s *CASRedirectService) GetDownloadURL(ctx context.Context, req *pb.GetDownloadURLRequest) (*pb.GetDownloadURLResponse, error) { - // TODO: Add support API-Token-based authentication - currentUser, err := requireCurrentUser(ctx) + currentUser, currentAPIToken, err := requireCurrentUserOrAPIToken(ctx) if err != nil { return nil, err } - // Find the CAS backend that should be used for the download, if any - mapping, err := s.casMappingUC.FindCASMappingForDownload(ctx, req.Digest, currentUser.ID) + currentOrg, err := requireCurrentOrg(ctx) + if err != nil { + return nil, err + } + + var mapping *biz.CASMapping + if currentUser != nil { + mapping, err = s.casMappingUC.FindCASMappingForDownloadByUser(ctx, req.Digest, currentUser.ID) + } else if currentAPIToken != nil { + mapping, err = s.casMappingUC.FindCASMappingForDownloadByOrg(ctx, req.Digest, []string{currentOrg.ID}) + } + if err != nil { // We don't want to leak the fact that the asset exists but the user does not have permissions // that's why we return a generic 404 in unauthorized scenarios too diff --git a/app/controlplane/internal/service/context.go b/app/controlplane/internal/service/context.go index 3c5b126fe..ca3dcea14 100644 --- a/app/controlplane/internal/service/context.go +++ b/app/controlplane/internal/service/context.go @@ -45,14 +45,8 @@ func (s *ContextService) Current(ctx context.Context, _ *pb.ContextServiceCurren return nil, err } - // load either user or API token - currentUser, err := requireCurrentUser(ctx) - if err != nil && !errors.IsNotFound(err) { - return nil, err - } - - currentAPIToken, err := requireAPIToken(ctx) - if err != nil && !errors.IsNotFound(err) { + currentUser, currentAPIToken, err := requireCurrentUserOrAPIToken(ctx) + if err != nil { return nil, err } diff --git a/app/controlplane/internal/service/service.go b/app/controlplane/internal/service/service.go index afdd4d775..a9ed3c55d 100644 --- a/app/controlplane/internal/service/service.go +++ b/app/controlplane/internal/service/service.go @@ -68,6 +68,26 @@ func requireAPIToken(ctx context.Context) (*usercontext.APIToken, error) { return token, nil } +func requireCurrentUserOrAPIToken(ctx context.Context) (*usercontext.User, *usercontext.APIToken, error) { + user, err := requireCurrentUser(ctx) + if err != nil && !errors.IsNotFound(err) { + return nil, nil, err + } + + apiToken, err := requireAPIToken(ctx) + if err != nil && !errors.IsNotFound(err) { + return nil, nil, err + } + + // NOTE: we shouldn't get to this point since the middleware should have already catched this + // Adding the check here for defensivity and testing purposes + if user == nil && apiToken == nil { + return nil, nil, errors.Forbidden("authz required", "logged in user nor API token found") + } + + return user, apiToken, nil +} + func requireCurrentOrg(ctx context.Context) (*usercontext.Org, error) { currentOrg := usercontext.CurrentOrg(ctx) if currentOrg == nil { From 41fc7bb3646d8b70a8af0681f63d4bdd9ebc6ece Mon Sep 17 00:00:00 2001 From: Miguel Martinez Trivino Date: Sat, 16 Dec 2023 15:59:25 +0100 Subject: [PATCH 2/5] feat: allow upload/downloads using API token Signed-off-by: Miguel Martinez Trivino --- .../biz/casmapping_integration_test.go | 63 +++++--- .../internal/service/service_test.go | 150 ++++++++++++++++++ .../usercontext/allowlist_middleware_test.go | 2 +- .../usercontext/apitoken_middleware.go | 6 +- .../usercontext/currentuser_middleware.go | 8 +- 5 files changed, 202 insertions(+), 27 deletions(-) create mode 100644 app/controlplane/internal/service/service_test.go diff --git a/app/controlplane/internal/biz/casmapping_integration_test.go b/app/controlplane/internal/biz/casmapping_integration_test.go index 6cd0be26c..a6945d0cc 100644 --- a/app/controlplane/internal/biz/casmapping_integration_test.go +++ b/app/controlplane/internal/biz/casmapping_integration_test.go @@ -39,25 +39,7 @@ const ( invalidDigest = "sha256:deadbeef" ) -func (s *casMappingIntegrationSuite) TestCASMappingForDownlod() { - // Let's create 3 CASMappings: - // 1. Digest: validDigest, CASBackend: casBackend1, WorkflowRunID: workflowRun - // 2. Digest: validDigest, CASBackend: casBackend2, WorkflowRunID: workflowRun - // 3. Digest: validDigest2, CASBackend: casBackend2, WorkflowRunID: workflowRun - // 4. Digest: validDigest3, CASBackend: casBackend3, WorkflowRunID: workflowRun - // 4. Digest: validDigestPublic, CASBackend: casBackend3, WorkflowRunID: workflowRunPublic - _, err := s.CASMapping.Create(context.TODO(), validDigest, s.casBackend1.ID.String(), s.workflowRun.ID.String()) - require.NoError(s.T(), err) - _, err = s.CASMapping.Create(context.TODO(), validDigest, s.casBackend2.ID.String(), s.workflowRun.ID.String()) - require.NoError(s.T(), err) - _, err = s.CASMapping.Create(context.TODO(), validDigest2, s.casBackend2.ID.String(), s.workflowRun.ID.String()) - require.NoError(s.T(), err) - _, err = s.CASMapping.Create(context.TODO(), validDigest3, s.casBackend3.ID.String(), s.workflowRun.ID.String()) - require.NoError(s.T(), err) - _, err = s.CASMapping.Create(context.TODO(), validDigestPublic, s.casBackend3.ID.String(), s.publicWorkflowRun.ID.String()) - require.NoError(s.T(), err) - - // Since the userOrg1And2 is member of org1 and org2, she should be able to download +func (s *casMappingIntegrationSuite) TestCASMappingForDownloadUser() { // both validDigest and validDigest2 from two different orgs s.Run("userOrg1And2 can download validDigest from org1", func() { mapping, err := s.CASMapping.FindCASMappingForDownloadByUser(context.TODO(), validDigest, s.userOrg1And2.ID) @@ -114,6 +96,30 @@ func (s *casMappingIntegrationSuite) TestCASMappingForDownlod() { }) } +func (s *casMappingIntegrationSuite) TestCASMappingForDownloadByOrg() { + ctx := context.Background() + // both validDigest and validDigest2 from two different orgs + s.Run("validDigest is in org1", func() { + mapping, err := s.CASMapping.FindCASMappingForDownloadByOrg(ctx, validDigest, []string{s.org1.ID}) + s.NoError(err) + s.NotNil(mapping) + s.Equal(s.casBackend1.ID, mapping.CASBackend.ID) + }) + + s.Run("validDigestPublic is available from any org", func() { + mapping, err := s.CASMapping.FindCASMappingForDownloadByOrg(ctx, validDigestPublic, []string{uuid.NewString()}) + s.NoError(err) + s.NotNil(mapping) + s.Equal(s.casBackend3.ID, mapping.CASBackend.ID) + }) + + s.Run("can't find an invalid digest", func() { + mapping, err := s.CASMapping.FindCASMappingForDownloadByOrg(ctx, invalidDigest, []string{s.org1.ID}) + s.Error(err) + s.Nil(mapping) + }) +} + func (s *casMappingIntegrationSuite) TestFindByDigest() { // 1. Digest: validDigest, CASBackend: casBackend1, WorkflowRunID: workflowRun // 2. Digest: validDigest2, CASBackend: casBackend1, WorkflowRunID: workflowRun @@ -362,6 +368,25 @@ func (s *casMappingIntegrationSuite) SetupTest() { assert.NoError(err) _, err = s.Membership.Create(ctx, s.org2.ID, s.userOrg2.ID, true) assert.NoError(err) + + // Let's create 3 CASMappings: + // 1. Digest: validDigest, CASBackend: casBackend1, WorkflowRunID: workflowRun + // 2. Digest: validDigest, CASBackend: casBackend2, WorkflowRunID: workflowRun + // 3. Digest: validDigest2, CASBackend: casBackend2, WorkflowRunID: workflowRun + // 4. Digest: validDigest3, CASBackend: casBackend3, WorkflowRunID: workflowRun + // 4. Digest: validDigestPublic, CASBackend: casBackend3, WorkflowRunID: workflowRunPublic + _, err = s.CASMapping.Create(context.TODO(), validDigest, s.casBackend1.ID.String(), s.workflowRun.ID.String()) + require.NoError(s.T(), err) + _, err = s.CASMapping.Create(context.TODO(), validDigest, s.casBackend2.ID.String(), s.workflowRun.ID.String()) + require.NoError(s.T(), err) + _, err = s.CASMapping.Create(context.TODO(), validDigest2, s.casBackend2.ID.String(), s.workflowRun.ID.String()) + require.NoError(s.T(), err) + _, err = s.CASMapping.Create(context.TODO(), validDigest3, s.casBackend3.ID.String(), s.workflowRun.ID.String()) + require.NoError(s.T(), err) + _, err = s.CASMapping.Create(context.TODO(), validDigestPublic, s.casBackend3.ID.String(), s.publicWorkflowRun.ID.String()) + require.NoError(s.T(), err) + + // Since the userOrg1And2 is member of org1 and org2, she should be able to download } func TestCASMappingIntegration(t *testing.T) { diff --git a/app/controlplane/internal/service/service_test.go b/app/controlplane/internal/service/service_test.go new file mode 100644 index 000000000..1b47e8850 --- /dev/null +++ b/app/controlplane/internal/service/service_test.go @@ -0,0 +1,150 @@ +// +// Copyright 2023 The Chainloop Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package service + +import ( + "context" + "testing" + + "github.com/chainloop-dev/chainloop/app/controlplane/internal/usercontext" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestRequireCurrentUser(t *testing.T) { + t.Parallel() + + ctx := context.Background() + + t.Run("no user", func(t *testing.T) { + _, err := requireCurrentUser(ctx) + assert.Error(t, err) + }) + + t.Run("with user", func(t *testing.T) { + want := &usercontext.User{} + ctx = usercontext.WithCurrentUser(ctx, want) + u, err := requireCurrentUser(ctx) + assert.NoError(t, err) + require.Equal(t, want, u) + }) +} + +func TestRequireCurrentOrg(t *testing.T) { + t.Parallel() + + ctx := context.Background() + + t.Run("no org", func(t *testing.T) { + _, err := requireCurrentOrg(ctx) + assert.Error(t, err) + }) + + t.Run("with org", func(t *testing.T) { + want := &usercontext.Org{} + ctx = usercontext.WithCurrentOrg(ctx, want) + o, err := requireCurrentOrg(ctx) + assert.NoError(t, err) + require.Equal(t, want, o) + }) +} + +func TestRequireAPIToken(t *testing.T) { + t.Parallel() + + ctx := context.Background() + + t.Run("no token", func(t *testing.T) { + _, err := requireAPIToken(ctx) + assert.Error(t, err) + }) + + t.Run("with token", func(t *testing.T) { + want := &usercontext.APIToken{} + ctx = usercontext.WithCurrentAPIToken(ctx, want) + got, err := requireAPIToken(ctx) + assert.NoError(t, err) + require.Equal(t, want, got) + }) +} + +func TestRequireCurrentUserOrAPIToken(t *testing.T) { + t.Parallel() + + ctx := context.Background() + + tesCases := []struct { + name string + hasUser bool + hasToken bool + wantErr bool + }{ + { + name: "no user nor token", + hasUser: false, + hasToken: false, + wantErr: true, + }, + { + name: "with user", + hasUser: true, + hasToken: false, + wantErr: false, + }, + { + name: "with token", + hasUser: false, + hasToken: true, + wantErr: false, + }, + } + + for _, tc := range tesCases { + t.Run(tc.name, func(t *testing.T) { + ctx = context.Background() + wantUser := &usercontext.User{} + wantToken := &usercontext.APIToken{} + + if tc.hasUser { + ctx = usercontext.WithCurrentUser(ctx, wantUser) + } + + if tc.hasToken { + ctx = usercontext.WithCurrentAPIToken(ctx, wantToken) + } + + gotUser, gotToken, err := requireCurrentUserOrAPIToken(ctx) + if tc.wantErr { + assert.Error(t, err) + return + } + + assert.NoError(t, err) + if tc.hasUser { + require.Equal(t, wantUser, gotUser) + } else { + assert.Nil(t, gotUser) + } + + if tc.hasToken { + require.Equal(t, wantToken, gotToken) + } else { + assert.Nil(t, gotToken) + } + }) + } + +} diff --git a/app/controlplane/internal/usercontext/allowlist_middleware_test.go b/app/controlplane/internal/usercontext/allowlist_middleware_test.go index c24b9df68..0b29905f7 100644 --- a/app/controlplane/internal/usercontext/allowlist_middleware_test.go +++ b/app/controlplane/internal/usercontext/allowlist_middleware_test.go @@ -59,7 +59,7 @@ func TestCheckUserInAllowList(t *testing.T) { m := CheckUserInAllowList(tc.allowList) ctx := context.Background() if tc.user != nil { - ctx = withCurrentUser(ctx, tc.user) + ctx = WithCurrentUser(ctx, tc.user) } _, err := m(emptyHandler)(ctx, nil) diff --git a/app/controlplane/internal/usercontext/apitoken_middleware.go b/app/controlplane/internal/usercontext/apitoken_middleware.go index fc5afc907..00ef7f6f1 100644 --- a/app/controlplane/internal/usercontext/apitoken_middleware.go +++ b/app/controlplane/internal/usercontext/apitoken_middleware.go @@ -35,7 +35,7 @@ type APIToken struct { CreatedAt *time.Time } -func withCurrentAPIToken(ctx context.Context, token *APIToken) context.Context { +func WithCurrentAPIToken(ctx context.Context, token *APIToken) context.Context { return context.WithValue(ctx, currentAPITokenCtxKey{}, token) } @@ -114,7 +114,7 @@ func setCurrentOrgAndAPIToken(ctx context.Context, apiTokenUC *biz.APITokenUseCa return nil, errors.New("organization not found") } - ctx = withCurrentOrg(ctx, &Org{Name: org.Name, ID: org.ID, CreatedAt: org.CreatedAt}) - ctx = withCurrentAPIToken(ctx, &APIToken{ID: token.ID.String(), CreatedAt: token.CreatedAt}) + ctx = WithCurrentOrg(ctx, &Org{Name: org.Name, ID: org.ID, CreatedAt: org.CreatedAt}) + ctx = WithCurrentAPIToken(ctx, &APIToken{ID: token.ID.String(), CreatedAt: token.CreatedAt}) return ctx, nil } diff --git a/app/controlplane/internal/usercontext/currentuser_middleware.go b/app/controlplane/internal/usercontext/currentuser_middleware.go index a387afaad..caa9aeddc 100644 --- a/app/controlplane/internal/usercontext/currentuser_middleware.go +++ b/app/controlplane/internal/usercontext/currentuser_middleware.go @@ -41,7 +41,7 @@ type Org struct { CreatedAt *time.Time } -func withCurrentUser(ctx context.Context, user *User) context.Context { +func WithCurrentUser(ctx context.Context, user *User) context.Context { return context.WithValue(ctx, currentUserCtxKey{}, user) } @@ -55,7 +55,7 @@ func CurrentUser(ctx context.Context) *User { return res.(*User) } -func withCurrentOrg(ctx context.Context, org *Org) context.Context { +func WithCurrentOrg(ctx context.Context, org *Org) context.Context { return context.WithValue(ctx, currentOrgCtxKey{}, org) } @@ -133,8 +133,8 @@ func setCurrentOrgAndUser(ctx context.Context, userUC biz.UserOrgFinder, userID return nil, errors.New("org not found") } - ctx = withCurrentOrg(ctx, &Org{Name: org.Name, ID: org.ID, CreatedAt: org.CreatedAt}) - ctx = withCurrentUser(ctx, &User{Email: u.Email, ID: u.ID, CreatedAt: u.CreatedAt}) + ctx = WithCurrentOrg(ctx, &Org{Name: org.Name, ID: org.ID, CreatedAt: org.CreatedAt}) + ctx = WithCurrentUser(ctx, &User{Email: u.Email, ID: u.ID, CreatedAt: u.CreatedAt}) return ctx, nil } From 1abee3be15eb22692f1de329050baf9d342c3582 Mon Sep 17 00:00:00 2001 From: Miguel Martinez Trivino Date: Sat, 16 Dec 2023 16:02:53 +0100 Subject: [PATCH 3/5] feat: allow upload/downloads using API token Signed-off-by: Miguel Martinez Trivino --- app/controlplane/internal/service/service_test.go | 1 - 1 file changed, 1 deletion(-) diff --git a/app/controlplane/internal/service/service_test.go b/app/controlplane/internal/service/service_test.go index 1b47e8850..48eaa5a17 100644 --- a/app/controlplane/internal/service/service_test.go +++ b/app/controlplane/internal/service/service_test.go @@ -146,5 +146,4 @@ func TestRequireCurrentUserOrAPIToken(t *testing.T) { } }) } - } From 61766cdffb2e2bfca1f83405ceec95434550ccef Mon Sep 17 00:00:00 2001 From: Miguel Martinez Trivino Date: Sat, 16 Dec 2023 16:12:42 +0100 Subject: [PATCH 4/5] feat: allow upload/downloads using API token Signed-off-by: Miguel Martinez Trivino --- .../biz/casmapping_integration_test.go | 42 ++++++++++--------- 1 file changed, 23 insertions(+), 19 deletions(-) diff --git a/app/controlplane/internal/biz/casmapping_integration_test.go b/app/controlplane/internal/biz/casmapping_integration_test.go index a6945d0cc..99c0e22ba 100644 --- a/app/controlplane/internal/biz/casmapping_integration_test.go +++ b/app/controlplane/internal/biz/casmapping_integration_test.go @@ -40,6 +40,24 @@ const ( ) func (s *casMappingIntegrationSuite) TestCASMappingForDownloadUser() { + // Let's create 3 CASMappings: + // 1. Digest: validDigest, CASBackend: casBackend1, WorkflowRunID: workflowRun + // 2. Digest: validDigest, CASBackend: casBackend2, WorkflowRunID: workflowRun + // 3. Digest: validDigest2, CASBackend: casBackend2, WorkflowRunID: workflowRun + // 4. Digest: validDigest3, CASBackend: casBackend3, WorkflowRunID: workflowRun + // 4. Digest: validDigestPublic, CASBackend: casBackend3, WorkflowRunID: workflowRunPublic + _, err := s.CASMapping.Create(context.TODO(), validDigest, s.casBackend1.ID.String(), s.workflowRun.ID.String()) + require.NoError(s.T(), err) + _, err = s.CASMapping.Create(context.TODO(), validDigest, s.casBackend2.ID.String(), s.workflowRun.ID.String()) + require.NoError(s.T(), err) + _, err = s.CASMapping.Create(context.TODO(), validDigest2, s.casBackend2.ID.String(), s.workflowRun.ID.String()) + require.NoError(s.T(), err) + _, err = s.CASMapping.Create(context.TODO(), validDigest3, s.casBackend3.ID.String(), s.workflowRun.ID.String()) + require.NoError(s.T(), err) + _, err = s.CASMapping.Create(context.TODO(), validDigestPublic, s.casBackend3.ID.String(), s.publicWorkflowRun.ID.String()) + require.NoError(s.T(), err) + + // Since the userOrg1And2 is member of org1 and org2, she should be able to download // both validDigest and validDigest2 from two different orgs s.Run("userOrg1And2 can download validDigest from org1", func() { mapping, err := s.CASMapping.FindCASMappingForDownloadByUser(context.TODO(), validDigest, s.userOrg1And2.ID) @@ -98,6 +116,11 @@ func (s *casMappingIntegrationSuite) TestCASMappingForDownloadUser() { func (s *casMappingIntegrationSuite) TestCASMappingForDownloadByOrg() { ctx := context.Background() + _, err := s.CASMapping.Create(ctx, validDigest, s.casBackend1.ID.String(), s.workflowRun.ID.String()) + require.NoError(s.T(), err) + _, err = s.CASMapping.Create(ctx, validDigestPublic, s.casBackend3.ID.String(), s.publicWorkflowRun.ID.String()) + require.NoError(s.T(), err) + // both validDigest and validDigest2 from two different orgs s.Run("validDigest is in org1", func() { mapping, err := s.CASMapping.FindCASMappingForDownloadByOrg(ctx, validDigest, []string{s.org1.ID}) @@ -368,25 +391,6 @@ func (s *casMappingIntegrationSuite) SetupTest() { assert.NoError(err) _, err = s.Membership.Create(ctx, s.org2.ID, s.userOrg2.ID, true) assert.NoError(err) - - // Let's create 3 CASMappings: - // 1. Digest: validDigest, CASBackend: casBackend1, WorkflowRunID: workflowRun - // 2. Digest: validDigest, CASBackend: casBackend2, WorkflowRunID: workflowRun - // 3. Digest: validDigest2, CASBackend: casBackend2, WorkflowRunID: workflowRun - // 4. Digest: validDigest3, CASBackend: casBackend3, WorkflowRunID: workflowRun - // 4. Digest: validDigestPublic, CASBackend: casBackend3, WorkflowRunID: workflowRunPublic - _, err = s.CASMapping.Create(context.TODO(), validDigest, s.casBackend1.ID.String(), s.workflowRun.ID.String()) - require.NoError(s.T(), err) - _, err = s.CASMapping.Create(context.TODO(), validDigest, s.casBackend2.ID.String(), s.workflowRun.ID.String()) - require.NoError(s.T(), err) - _, err = s.CASMapping.Create(context.TODO(), validDigest2, s.casBackend2.ID.String(), s.workflowRun.ID.String()) - require.NoError(s.T(), err) - _, err = s.CASMapping.Create(context.TODO(), validDigest3, s.casBackend3.ID.String(), s.workflowRun.ID.String()) - require.NoError(s.T(), err) - _, err = s.CASMapping.Create(context.TODO(), validDigestPublic, s.casBackend3.ID.String(), s.publicWorkflowRun.ID.String()) - require.NoError(s.T(), err) - - // Since the userOrg1And2 is member of org1 and org2, she should be able to download } func TestCASMappingIntegration(t *testing.T) { From 85050680f721637ed3106b95582ecb20a519759b Mon Sep 17 00:00:00 2001 From: Miguel Martinez Trivino Date: Sun, 17 Dec 2023 12:51:47 +0100 Subject: [PATCH 5/5] feat: allow use discovery endpoint with api-token Signed-off-by: Miguel Martinez Trivino --- app/controlplane/internal/biz/referrer.go | 8 +++- .../internal/biz/referrer_integration_test.go | 38 +++++++++---------- app/controlplane/internal/service/referrer.go | 26 +++++++++++-- 3 files changed, 48 insertions(+), 24 deletions(-) diff --git a/app/controlplane/internal/biz/referrer.go b/app/controlplane/internal/biz/referrer.go index 1c63e4a2c..b3cd2c89d 100644 --- a/app/controlplane/internal/biz/referrer.go +++ b/app/controlplane/internal/biz/referrer.go @@ -144,10 +144,10 @@ func (s *ReferrerUseCase) ExtractAndPersist(ctx context.Context, att *dsse.Envel return nil } -// GetFromRoot returns the referrer identified by the provided content digest, including its first-level references +// GetFromRootUser returns the referrer identified by the provided content digest, including its first-level references // For example if sha:deadbeef represents an attestation, the result will contain the attestation + materials associated to it // It only returns referrers that belong to organizations the user is member of -func (s *ReferrerUseCase) GetFromRoot(ctx context.Context, digest, rootKind, userID string) (*StoredReferrer, error) { +func (s *ReferrerUseCase) GetFromRootUser(ctx context.Context, digest, rootKind, userID string) (*StoredReferrer, error) { userUUID, err := uuid.Parse(userID) if err != nil { return nil, NewErrInvalidUUID(err) @@ -166,6 +166,10 @@ func (s *ReferrerUseCase) GetFromRoot(ctx context.Context, digest, rootKind, use orgIDs = append(orgIDs, m.OrganizationID) } + return s.GetFromRoot(ctx, digest, rootKind, orgIDs) +} + +func (s *ReferrerUseCase) GetFromRoot(ctx context.Context, digest, rootKind string, orgIDs []uuid.UUID) (*StoredReferrer, error) { filters := make([]GetFromRootFilter, 0) if rootKind != "" { filters = append(filters, WithKind(rootKind)) diff --git a/app/controlplane/internal/biz/referrer_integration_test.go b/app/controlplane/internal/biz/referrer_integration_test.go index 9ba258067..a0d88a028 100644 --- a/app/controlplane/internal/biz/referrer_integration_test.go +++ b/app/controlplane/internal/biz/referrer_integration_test.go @@ -53,7 +53,7 @@ func (s *referrerIntegrationTestSuite) TestGetFromRootInPublicSharedIndex() { s.T().Run("storing it associated with a private workflow keeps it private and not in the index", func(t *testing.T) { err = s.sharedEnabledUC.ExtractAndPersist(ctx, envelope, s.workflow1.ID.String()) require.NoError(s.T(), err) - ref, err := s.Referrer.GetFromRoot(ctx, wantReferrerAtt.Digest, "", s.user.ID) + ref, err := s.Referrer.GetFromRootUser(ctx, wantReferrerAtt.Digest, "", s.user.ID) s.NoError(err) s.False(ref.InPublicWorkflow) res, err := s.sharedEnabledUC.GetFromRootInPublicSharedIndex(ctx, wantReferrerAtt.Digest, "") @@ -69,7 +69,7 @@ func (s *referrerIntegrationTestSuite) TestGetFromRootInPublicSharedIndex() { err = s.sharedEnabledUC.ExtractAndPersist(ctx, envelope, s.workflow2.ID.String()) require.NoError(s.T(), err) // It's marked as public in the internal index - ref, err := s.sharedEnabledUC.GetFromRoot(ctx, wantReferrerAtt.Digest, "", s.user.ID) + ref, err := s.sharedEnabledUC.GetFromRootUser(ctx, wantReferrerAtt.Digest, "", s.user.ID) s.NoError(err) s.True(ref.InPublicWorkflow) @@ -165,21 +165,21 @@ func (s *referrerIntegrationTestSuite) TestExtractAndPersists() { s.T().Run("it can store properly the first time", func(t *testing.T) { err := s.Referrer.ExtractAndPersist(ctx, envelope, s.workflow1.ID.String()) s.NoError(err) - prevStoredRef, err = s.Referrer.GetFromRoot(ctx, wantReferrerAtt.Digest, "", s.user.ID) + prevStoredRef, err = s.Referrer.GetFromRootUser(ctx, wantReferrerAtt.Digest, "", s.user.ID) s.NoError(err) }) s.T().Run("and it's idempotent", func(t *testing.T) { err := s.Referrer.ExtractAndPersist(ctx, envelope, s.workflow1.ID.String()) s.NoError(err) - ref, err := s.Referrer.GetFromRoot(ctx, wantReferrerAtt.Digest, "", s.user.ID) + ref, err := s.Referrer.GetFromRootUser(ctx, wantReferrerAtt.Digest, "", s.user.ID) s.NoError(err) // Check it's the same referrer than previously retrieved, including timestamps s.Equal(prevStoredRef, ref) }) s.T().Run("contains all the info", func(t *testing.T) { - got, err := s.Referrer.GetFromRoot(ctx, wantReferrerAtt.Digest, "", s.user.ID) + got, err := s.Referrer.GetFromRootUser(ctx, wantReferrerAtt.Digest, "", s.user.ID) s.NoError(err) // parent i.e attestation s.Equal(wantReferrerAtt.Digest, got.Digest) @@ -198,14 +198,14 @@ func (s *referrerIntegrationTestSuite) TestExtractAndPersists() { }) s.T().Run("can get sha1 digests too", func(t *testing.T) { - got, err := s.Referrer.GetFromRoot(ctx, wantReferrerCommit.Digest, "", s.user.ID) + got, err := s.Referrer.GetFromRootUser(ctx, wantReferrerCommit.Digest, "", s.user.ID) s.NoError(err) s.Equal(wantReferrerCommit.Digest, got.Digest) }) s.T().Run("can't be accessed by a second user in another org", func(t *testing.T) { // the user2 has not access to org1 - got, err := s.Referrer.GetFromRoot(ctx, wantReferrerAtt.Digest, "", s.user2.ID) + got, err := s.Referrer.GetFromRootUser(ctx, wantReferrerAtt.Digest, "", s.user2.ID) s.True(biz.IsNotFound(err)) s.Nil(got) }) @@ -213,7 +213,7 @@ func (s *referrerIntegrationTestSuite) TestExtractAndPersists() { s.T().Run("but another workflow can be attached", func(t *testing.T) { err = s.Referrer.ExtractAndPersist(ctx, envelope, s.workflow2.ID.String()) s.NoError(err) - got, err := s.Referrer.GetFromRoot(ctx, wantReferrerAtt.Digest, "", s.user.ID) + got, err := s.Referrer.GetFromRootUser(ctx, wantReferrerAtt.Digest, "", s.user.ID) s.NoError(err) require.Len(t, got.OrgIDs, 2) s.Contains(got.OrgIDs, s.org1UUID) @@ -222,7 +222,7 @@ func (s *referrerIntegrationTestSuite) TestExtractAndPersists() { // and it's idempotent (no new orgs added) err = s.Referrer.ExtractAndPersist(ctx, envelope, s.workflow2.ID.String()) s.NoError(err) - got, err = s.Referrer.GetFromRoot(ctx, wantReferrerAtt.Digest, "", s.user.ID) + got, err = s.Referrer.GetFromRootUser(ctx, wantReferrerAtt.Digest, "", s.user.ID) s.NoError(err) require.Len(t, got.OrgIDs, 2) s.Equal([]uuid.UUID{s.org1UUID, s.org2UUID}, got.OrgIDs) @@ -232,13 +232,13 @@ func (s *referrerIntegrationTestSuite) TestExtractAndPersists() { s.T().Run("and now user2 has access to it since it has access to workflow2 in org2", func(t *testing.T) { err = s.Referrer.ExtractAndPersist(ctx, envelope, s.workflow2.ID.String()) s.NoError(err) - got, err := s.Referrer.GetFromRoot(ctx, wantReferrerAtt.Digest, "", s.user2.ID) + got, err := s.Referrer.GetFromRootUser(ctx, wantReferrerAtt.Digest, "", s.user2.ID) s.NoError(err) require.Len(t, got.OrgIDs, 2) }) s.T().Run("you can ask for info about materials that are subjects", func(t *testing.T) { - got, err := s.Referrer.GetFromRoot(ctx, wantReferrerContainerImage.Digest, "", s.user.ID) + got, err := s.Referrer.GetFromRootUser(ctx, wantReferrerContainerImage.Digest, "", s.user.ID) s.NoError(err) // parent i.e attestation s.Equal(wantReferrerContainerImage.Digest, got.Digest) @@ -252,7 +252,7 @@ func (s *referrerIntegrationTestSuite) TestExtractAndPersists() { }) s.T().Run("it might not have references", func(t *testing.T) { - got, err := s.Referrer.GetFromRoot(ctx, wantReferrerSarif.Digest, "", s.user.ID) + got, err := s.Referrer.GetFromRootUser(ctx, wantReferrerSarif.Digest, "", s.user.ID) s.NoError(err) // parent i.e attestation s.Equal(wantReferrerSarif.Digest, got.Digest) @@ -262,7 +262,7 @@ func (s *referrerIntegrationTestSuite) TestExtractAndPersists() { }) s.T().Run("or it does not exist", func(t *testing.T) { - got, err := s.Referrer.GetFromRoot(ctx, "sha256:aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa", "", s.user.ID) + got, err := s.Referrer.GetFromRootUser(ctx, "sha256:aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa", "", s.user.ID) s.True(biz.IsNotFound(err)) s.Nil(got) }) @@ -287,20 +287,20 @@ func (s *referrerIntegrationTestSuite) TestExtractAndPersists() { s.NoError(err) // but retrieval should fail. In the future we will ask the user to provide the artifact type in these cases of ambiguity - got, err := s.Referrer.GetFromRoot(ctx, wantReferrerSarif.Digest, "", s.user.ID) + got, err := s.Referrer.GetFromRootUser(ctx, wantReferrerSarif.Digest, "", s.user.ID) s.Nil(got) s.ErrorContains(err, "present in 2 kinds") }) s.T().Run("it should not fail on retrieval if we filter out by one kind", func(t *testing.T) { // but retrieval should fail. In the future we will ask the user to provide the artifact type in these cases of ambiguity - got, err := s.Referrer.GetFromRoot(ctx, wantReferrerSarif.Digest, "SARIF", s.user.ID) + got, err := s.Referrer.GetFromRootUser(ctx, wantReferrerSarif.Digest, "SARIF", s.user.ID) s.NoError(err) s.Equal(wantReferrerSarif.Digest, got.Digest) s.Equal(true, got.Downloadable) s.Equal("SARIF", got.Kind) - got, err = s.Referrer.GetFromRoot(ctx, wantReferrerSarif.Digest, "ARTIFACT", s.user.ID) + got, err = s.Referrer.GetFromRootUser(ctx, wantReferrerSarif.Digest, "ARTIFACT", s.user.ID) s.NoError(err) s.Equal(wantReferrerSarif.Digest, got.Digest) s.Equal(true, got.Downloadable) @@ -309,7 +309,7 @@ func (s *referrerIntegrationTestSuite) TestExtractAndPersists() { s.T().Run("now there should a container image pointing to two attestations", func(t *testing.T) { // but retrieval should fail. In the future we will ask the user to provide the artifact type in these cases of ambiguity - got, err := s.Referrer.GetFromRoot(ctx, wantReferrerContainerImage.Digest, "", s.user.ID) + got, err := s.Referrer.GetFromRootUser(ctx, wantReferrerContainerImage.Digest, "", s.user.ID) s.NoError(err) // it should be referenced by two attestations since it's subject of both require.Len(t, got.References, 2) @@ -320,7 +320,7 @@ func (s *referrerIntegrationTestSuite) TestExtractAndPersists() { }) s.T().Run("if all associated workflows are private, the referrer is private", func(t *testing.T) { - got, err := s.Referrer.GetFromRoot(ctx, wantReferrerAtt.Digest, "", s.user.ID) + got, err := s.Referrer.GetFromRootUser(ctx, wantReferrerAtt.Digest, "", s.user.ID) s.NoError(err) s.False(got.InPublicWorkflow) s.Equal([]uuid.UUID{s.workflow1.ID, s.workflow2.ID}, got.WorkflowIDs) @@ -334,7 +334,7 @@ func (s *referrerIntegrationTestSuite) TestExtractAndPersists() { _, err := s.Workflow.Update(ctx, s.org1.ID, s.workflow1.ID.String(), &biz.WorkflowUpdateOpts{Public: toPtrBool(true)}) require.NoError(t, err) - got, err := s.Referrer.GetFromRoot(ctx, wantReferrerAtt.Digest, "", s.user.ID) + got, err := s.Referrer.GetFromRootUser(ctx, wantReferrerAtt.Digest, "", s.user.ID) s.NoError(err) s.True(got.InPublicWorkflow) for _, r := range got.References { diff --git a/app/controlplane/internal/service/referrer.go b/app/controlplane/internal/service/referrer.go index 01f0e8121..3d6628f59 100644 --- a/app/controlplane/internal/service/referrer.go +++ b/app/controlplane/internal/service/referrer.go @@ -17,9 +17,11 @@ package service import ( "context" + "fmt" pb "github.com/chainloop-dev/chainloop/app/controlplane/api/controlplane/v1" "github.com/chainloop-dev/chainloop/app/controlplane/internal/biz" + "github.com/google/uuid" "google.golang.org/protobuf/types/known/timestamppb" ) @@ -38,18 +40,36 @@ func NewReferrerService(uc *biz.ReferrerUseCase, opts ...NewOpt) *ReferrerServic } func (s *ReferrerService) DiscoverPrivate(ctx context.Context, req *pb.ReferrerServiceDiscoverPrivateRequest) (*pb.ReferrerServiceDiscoverPrivateResponse, error) { - currentUser, err := requireCurrentUser(ctx) + currentUser, currentToken, err := requireCurrentUserOrAPIToken(ctx) if err != nil { return nil, err } - res, err := s.referrerUC.GetFromRoot(ctx, req.GetDigest(), req.GetKind(), currentUser.ID) + currentOrg, err := requireCurrentOrg(ctx) + if err != nil { + return nil, err + } + + // if we are logged in as user we find the referrer from the user + // otherwise for the current organization associated with the API token + var referrer *biz.StoredReferrer + if currentUser != nil { + referrer, err = s.referrerUC.GetFromRootUser(ctx, req.GetDigest(), req.GetKind(), currentUser.ID) + } else if currentToken != nil { + var orgUUID uuid.UUID + orgUUID, err = uuid.Parse(currentOrg.ID) + if err != nil { + return nil, fmt.Errorf("invalid org UUID: %w", err) + } + + referrer, err = s.referrerUC.GetFromRoot(ctx, req.GetDigest(), req.GetKind(), []uuid.UUID{orgUUID}) + } if err != nil { return nil, handleUseCaseErr("referrer discovery", err, s.log) } return &pb.ReferrerServiceDiscoverPrivateResponse{ - Result: bizReferrerToPb(res), + Result: bizReferrerToPb(referrer), }, nil }