Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 23 additions & 10 deletions app/controlplane/internal/biz/casclient.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,8 @@ type CASClient interface {
CASDownloader
}

type CASClientFactory func(conf *conf.Bootstrap_CASServer, token string) (casclient.DownloaderUploader, error)
// Function that returns a CAS client including a connection closer method
type CASClientFactory func(conf *conf.Bootstrap_CASServer, token string) (casclient.DownloaderUploader, func(), error)
type CASClientOpts func(u *CASClientUseCase)

func WithClientFactory(f CASClientFactory) CASClientOpts {
Expand All @@ -62,20 +63,29 @@ func WithClientFactory(f CASClientFactory) CASClientOpts {
}

func NewCASClientUseCase(credsProvider *CASCredentialsUseCase, config *conf.Bootstrap_CASServer, l log.Logger, opts ...CASClientOpts) *CASClientUseCase {
helper := servicelogger.ScopedHelper(l, "biz/cas-client")

// generate a client from the given configuration
defaultCasClientFactory := func(conf *conf.Bootstrap_CASServer, token string) (casclient.DownloaderUploader, error) {
defaultCasClientFactory := func(conf *conf.Bootstrap_CASServer, token string) (casclient.DownloaderUploader, func(), error) {
conn, err := grpcconn.New(conf.GetGrpc().GetAddr(), token, conf.GetInsecure())
if err != nil {
return nil, fmt.Errorf("failed to create grpc connection: %w", err)
return nil, nil, fmt.Errorf("failed to create grpc connection: %w", err)
}

closerFn := func() {
err := conn.Close()
if err != nil {
helper.Error(err)
}
}

return casclient.New(conn), nil
return casclient.New(conn), closerFn, err
}

uc := &CASClientUseCase{
credsProvider: credsProvider,
casServerConf: config,
logger: servicelogger.ScopedHelper(l, "biz/cas-client"),
logger: helper,
casClientFactory: defaultCasClientFactory,
}

Expand All @@ -91,10 +101,11 @@ func (uc *CASClientUseCase) Upload(ctx context.Context, secretID string, content
uc.logger.Infow("msg", "upload initialized", "filename", filename, "digest", digest)

// client with temporary set of credentials
client, err := uc.casAPIClient(secretID, casJWT.Uploader)
client, closeFn, err := uc.casAPIClient(secretID, casJWT.Uploader)
if err != nil {
return fmt.Errorf("failed to create cas client: %w", err)
}
defer closeFn()

status, err := client.Upload(ctx, content, filename, digest)
if err != nil {
Expand All @@ -109,10 +120,11 @@ func (uc *CASClientUseCase) Upload(ctx context.Context, secretID string, content
func (uc *CASClientUseCase) Download(ctx context.Context, secretID string, w io.Writer, digest string) error {
uc.logger.Infow("msg", "download initialized", "digest", digest)

client, err := uc.casAPIClient(secretID, casJWT.Downloader)
client, closeFn, err := uc.casAPIClient(secretID, casJWT.Downloader)
if err != nil {
return fmt.Errorf("failed to create cas client: %w", err)
}
defer closeFn()

if err := client.Download(ctx, w, digest); err != nil {
return fmt.Errorf("failed to download content: %w", err)
Expand All @@ -124,10 +136,10 @@ func (uc *CASClientUseCase) Download(ctx context.Context, secretID string, w io.
}

// create a client with a temporary set of credentials for a specific operation
func (uc *CASClientUseCase) casAPIClient(secretID string, role casJWT.Role) (casclient.DownloaderUploader, error) {
func (uc *CASClientUseCase) casAPIClient(secretID string, role casJWT.Role) (casclient.DownloaderUploader, func(), error) {
token, err := uc.credsProvider.GenerateTemporaryCredentials(secretID, role)
if err != nil {
return nil, fmt.Errorf("failed to generate temporary credentials: %w", err)
return nil, nil, fmt.Errorf("failed to generate temporary credentials: %w", err)
}

// Initialize connection to CAS server
Expand All @@ -145,10 +157,11 @@ func (uc *CASClientUseCase) IsReady(ctx context.Context) (bool, error) {
return false, fmt.Errorf("invalid CAS client configuration: %w", err)
}

c, err := uc.casClientFactory(uc.casServerConf, "")
c, closeFn, err := uc.casClientFactory(uc.casServerConf, "")
if err != nil {
return false, fmt.Errorf("failed to create CAS client: %w", err)
}
defer closeFn()

return c.IsReady(ctx)
}
4 changes: 2 additions & 2 deletions app/controlplane/internal/biz/casclient_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,10 +65,10 @@ func TestIsReady(t *testing.T) {

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
clientProvider := func(conf *conf.Bootstrap_CASServer, token string) (casclient.DownloaderUploader, error) {
clientProvider := func(conf *conf.Bootstrap_CASServer, token string) (casclient.DownloaderUploader, func(), error) {
c := mocks.NewDownloaderUploader(t)
c.On("IsReady", mock.Anything).Return(tc.casReady, nil)
return c, nil
return c, func() {}, nil
}
uc := biz.NewCASClientUseCase(nil, tc.config, nil, biz.WithClientFactory(clientProvider))

Expand Down