From 3b5143aca871cb5244a58eee1c5b20af4863a82e Mon Sep 17 00:00:00 2001 From: Miguel Martinez Trivino Date: Sun, 16 Apr 2023 16:32:03 +0200 Subject: [PATCH 1/2] fix: close GRPC connection Signed-off-by: Miguel Martinez Trivino --- app/controlplane/internal/biz/casclient.go | 22 +++++++++++-------- .../internal/biz/casclient_test.go | 4 ++-- 2 files changed, 15 insertions(+), 11 deletions(-) diff --git a/app/controlplane/internal/biz/casclient.go b/app/controlplane/internal/biz/casclient.go index c0a0072f7..513fb4b6c 100644 --- a/app/controlplane/internal/biz/casclient.go +++ b/app/controlplane/internal/biz/casclient.go @@ -52,7 +52,8 @@ type CASClient interface { CASDownloader } -type CASClientFactory func(conf *conf.Bootstrap_CASServer, token string) (casclient.DownloaderUploader, error) +// Function that returns a CAS client including a connection closer method +type CASClientFactory func(conf *conf.Bootstrap_CASServer, token string) (casclient.DownloaderUploader, func() error, error) type CASClientOpts func(u *CASClientUseCase) func WithClientFactory(f CASClientFactory) CASClientOpts { @@ -63,13 +64,13 @@ func WithClientFactory(f CASClientFactory) CASClientOpts { func NewCASClientUseCase(credsProvider *CASCredentialsUseCase, config *conf.Bootstrap_CASServer, l log.Logger, opts ...CASClientOpts) *CASClientUseCase { // generate a client from the given configuration - defaultCasClientFactory := func(conf *conf.Bootstrap_CASServer, token string) (casclient.DownloaderUploader, error) { + defaultCasClientFactory := func(conf *conf.Bootstrap_CASServer, token string) (casclient.DownloaderUploader, func() error, error) { conn, err := grpcconn.New(conf.GetGrpc().GetAddr(), token, conf.GetInsecure()) if err != nil { - return nil, fmt.Errorf("failed to create grpc connection: %w", err) + return nil, nil, fmt.Errorf("failed to create grpc connection: %w", err) } - return casclient.New(conn), nil + return casclient.New(conn), conn.Close, err } uc := &CASClientUseCase{ @@ -91,10 +92,11 @@ func (uc *CASClientUseCase) Upload(ctx context.Context, secretID string, content uc.logger.Infow("msg", "upload initialized", "filename", filename, "digest", digest) // client with temporary set of credentials - client, err := uc.casAPIClient(secretID, casJWT.Uploader) + client, closeFn, err := uc.casAPIClient(secretID, casJWT.Uploader) if err != nil { return fmt.Errorf("failed to create cas client: %w", err) } + defer closeFn() status, err := client.Upload(ctx, content, filename, digest) if err != nil { @@ -109,10 +111,11 @@ func (uc *CASClientUseCase) Upload(ctx context.Context, secretID string, content func (uc *CASClientUseCase) Download(ctx context.Context, secretID string, w io.Writer, digest string) error { uc.logger.Infow("msg", "download initialized", "digest", digest) - client, err := uc.casAPIClient(secretID, casJWT.Downloader) + client, closeFn, err := uc.casAPIClient(secretID, casJWT.Downloader) if err != nil { return fmt.Errorf("failed to create cas client: %w", err) } + defer closeFn() if err := client.Download(ctx, w, digest); err != nil { return fmt.Errorf("failed to download content: %w", err) @@ -124,10 +127,10 @@ func (uc *CASClientUseCase) Download(ctx context.Context, secretID string, w io. } // create a client with a temporary set of credentials for a specific operation -func (uc *CASClientUseCase) casAPIClient(secretID string, role casJWT.Role) (casclient.DownloaderUploader, error) { +func (uc *CASClientUseCase) casAPIClient(secretID string, role casJWT.Role) (casclient.DownloaderUploader, func() error, error) { token, err := uc.credsProvider.GenerateTemporaryCredentials(secretID, role) if err != nil { - return nil, fmt.Errorf("failed to generate temporary credentials: %w", err) + return nil, nil, fmt.Errorf("failed to generate temporary credentials: %w", err) } // Initialize connection to CAS server @@ -145,10 +148,11 @@ func (uc *CASClientUseCase) IsReady(ctx context.Context) (bool, error) { return false, fmt.Errorf("invalid CAS client configuration: %w", err) } - c, err := uc.casClientFactory(uc.casServerConf, "") + c, closeFn, err := uc.casClientFactory(uc.casServerConf, "") if err != nil { return false, fmt.Errorf("failed to create CAS client: %w", err) } + defer closeFn() return c.IsReady(ctx) } diff --git a/app/controlplane/internal/biz/casclient_test.go b/app/controlplane/internal/biz/casclient_test.go index c3a2a8959..cc7fd7dd6 100644 --- a/app/controlplane/internal/biz/casclient_test.go +++ b/app/controlplane/internal/biz/casclient_test.go @@ -65,10 +65,10 @@ func TestIsReady(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - clientProvider := func(conf *conf.Bootstrap_CASServer, token string) (casclient.DownloaderUploader, error) { + clientProvider := func(conf *conf.Bootstrap_CASServer, token string) (casclient.DownloaderUploader, func() error, error) { c := mocks.NewDownloaderUploader(t) c.On("IsReady", mock.Anything).Return(tc.casReady, nil) - return c, nil + return c, func() error { return nil }, nil } uc := biz.NewCASClientUseCase(nil, tc.config, nil, biz.WithClientFactory(clientProvider)) From ace30896babc710b1f00365fe63ce879a3528e12 Mon Sep 17 00:00:00 2001 From: Miguel Martinez Trivino Date: Sun, 16 Apr 2023 16:48:22 +0200 Subject: [PATCH 2/2] Fix linter Signed-off-by: Miguel Martinez Trivino --- app/controlplane/internal/biz/casclient.go | 19 ++++++++++++++----- .../internal/biz/casclient_test.go | 4 ++-- 2 files changed, 16 insertions(+), 7 deletions(-) diff --git a/app/controlplane/internal/biz/casclient.go b/app/controlplane/internal/biz/casclient.go index 513fb4b6c..f9ac9e399 100644 --- a/app/controlplane/internal/biz/casclient.go +++ b/app/controlplane/internal/biz/casclient.go @@ -53,7 +53,7 @@ type CASClient interface { } // Function that returns a CAS client including a connection closer method -type CASClientFactory func(conf *conf.Bootstrap_CASServer, token string) (casclient.DownloaderUploader, func() error, error) +type CASClientFactory func(conf *conf.Bootstrap_CASServer, token string) (casclient.DownloaderUploader, func(), error) type CASClientOpts func(u *CASClientUseCase) func WithClientFactory(f CASClientFactory) CASClientOpts { @@ -63,20 +63,29 @@ func WithClientFactory(f CASClientFactory) CASClientOpts { } func NewCASClientUseCase(credsProvider *CASCredentialsUseCase, config *conf.Bootstrap_CASServer, l log.Logger, opts ...CASClientOpts) *CASClientUseCase { + helper := servicelogger.ScopedHelper(l, "biz/cas-client") + // generate a client from the given configuration - defaultCasClientFactory := func(conf *conf.Bootstrap_CASServer, token string) (casclient.DownloaderUploader, func() error, error) { + defaultCasClientFactory := func(conf *conf.Bootstrap_CASServer, token string) (casclient.DownloaderUploader, func(), error) { conn, err := grpcconn.New(conf.GetGrpc().GetAddr(), token, conf.GetInsecure()) if err != nil { return nil, nil, fmt.Errorf("failed to create grpc connection: %w", err) } - return casclient.New(conn), conn.Close, err + closerFn := func() { + err := conn.Close() + if err != nil { + helper.Error(err) + } + } + + return casclient.New(conn), closerFn, err } uc := &CASClientUseCase{ credsProvider: credsProvider, casServerConf: config, - logger: servicelogger.ScopedHelper(l, "biz/cas-client"), + logger: helper, casClientFactory: defaultCasClientFactory, } @@ -127,7 +136,7 @@ func (uc *CASClientUseCase) Download(ctx context.Context, secretID string, w io. } // create a client with a temporary set of credentials for a specific operation -func (uc *CASClientUseCase) casAPIClient(secretID string, role casJWT.Role) (casclient.DownloaderUploader, func() error, error) { +func (uc *CASClientUseCase) casAPIClient(secretID string, role casJWT.Role) (casclient.DownloaderUploader, func(), error) { token, err := uc.credsProvider.GenerateTemporaryCredentials(secretID, role) if err != nil { return nil, nil, fmt.Errorf("failed to generate temporary credentials: %w", err) diff --git a/app/controlplane/internal/biz/casclient_test.go b/app/controlplane/internal/biz/casclient_test.go index cc7fd7dd6..e48cba9e4 100644 --- a/app/controlplane/internal/biz/casclient_test.go +++ b/app/controlplane/internal/biz/casclient_test.go @@ -65,10 +65,10 @@ func TestIsReady(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - clientProvider := func(conf *conf.Bootstrap_CASServer, token string) (casclient.DownloaderUploader, func() error, error) { + clientProvider := func(conf *conf.Bootstrap_CASServer, token string) (casclient.DownloaderUploader, func(), error) { c := mocks.NewDownloaderUploader(t) c.On("IsReady", mock.Anything).Return(tc.casReady, nil) - return c, func() error { return nil }, nil + return c, func() {}, nil } uc := biz.NewCASClientUseCase(nil, tc.config, nil, biz.WithClientFactory(clientProvider))