Skip to content

Commit 254e509

Browse files
authored
fix: close GRPC connection (#70)
Signed-off-by: Miguel Martinez Trivino <miguel@chainloop.dev>
1 parent 291b931 commit 254e509

File tree

2 files changed

+25
-12
lines changed

2 files changed

+25
-12
lines changed

app/controlplane/internal/biz/casclient.go

Lines changed: 23 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,8 @@ type CASClient interface {
5252
CASDownloader
5353
}
5454

55-
type CASClientFactory func(conf *conf.Bootstrap_CASServer, token string) (casclient.DownloaderUploader, error)
55+
// Function that returns a CAS client including a connection closer method
56+
type CASClientFactory func(conf *conf.Bootstrap_CASServer, token string) (casclient.DownloaderUploader, func(), error)
5657
type CASClientOpts func(u *CASClientUseCase)
5758

5859
func WithClientFactory(f CASClientFactory) CASClientOpts {
@@ -62,20 +63,29 @@ func WithClientFactory(f CASClientFactory) CASClientOpts {
6263
}
6364

6465
func NewCASClientUseCase(credsProvider *CASCredentialsUseCase, config *conf.Bootstrap_CASServer, l log.Logger, opts ...CASClientOpts) *CASClientUseCase {
66+
helper := servicelogger.ScopedHelper(l, "biz/cas-client")
67+
6568
// generate a client from the given configuration
66-
defaultCasClientFactory := func(conf *conf.Bootstrap_CASServer, token string) (casclient.DownloaderUploader, error) {
69+
defaultCasClientFactory := func(conf *conf.Bootstrap_CASServer, token string) (casclient.DownloaderUploader, func(), error) {
6770
conn, err := grpcconn.New(conf.GetGrpc().GetAddr(), token, conf.GetInsecure())
6871
if err != nil {
69-
return nil, fmt.Errorf("failed to create grpc connection: %w", err)
72+
return nil, nil, fmt.Errorf("failed to create grpc connection: %w", err)
73+
}
74+
75+
closerFn := func() {
76+
err := conn.Close()
77+
if err != nil {
78+
helper.Error(err)
79+
}
7080
}
7181

72-
return casclient.New(conn), nil
82+
return casclient.New(conn), closerFn, err
7383
}
7484

7585
uc := &CASClientUseCase{
7686
credsProvider: credsProvider,
7787
casServerConf: config,
78-
logger: servicelogger.ScopedHelper(l, "biz/cas-client"),
88+
logger: helper,
7989
casClientFactory: defaultCasClientFactory,
8090
}
8191

@@ -91,10 +101,11 @@ func (uc *CASClientUseCase) Upload(ctx context.Context, secretID string, content
91101
uc.logger.Infow("msg", "upload initialized", "filename", filename, "digest", digest)
92102

93103
// client with temporary set of credentials
94-
client, err := uc.casAPIClient(secretID, casJWT.Uploader)
104+
client, closeFn, err := uc.casAPIClient(secretID, casJWT.Uploader)
95105
if err != nil {
96106
return fmt.Errorf("failed to create cas client: %w", err)
97107
}
108+
defer closeFn()
98109

99110
status, err := client.Upload(ctx, content, filename, digest)
100111
if err != nil {
@@ -109,10 +120,11 @@ func (uc *CASClientUseCase) Upload(ctx context.Context, secretID string, content
109120
func (uc *CASClientUseCase) Download(ctx context.Context, secretID string, w io.Writer, digest string) error {
110121
uc.logger.Infow("msg", "download initialized", "digest", digest)
111122

112-
client, err := uc.casAPIClient(secretID, casJWT.Downloader)
123+
client, closeFn, err := uc.casAPIClient(secretID, casJWT.Downloader)
113124
if err != nil {
114125
return fmt.Errorf("failed to create cas client: %w", err)
115126
}
127+
defer closeFn()
116128

117129
if err := client.Download(ctx, w, digest); err != nil {
118130
return fmt.Errorf("failed to download content: %w", err)
@@ -124,10 +136,10 @@ func (uc *CASClientUseCase) Download(ctx context.Context, secretID string, w io.
124136
}
125137

126138
// create a client with a temporary set of credentials for a specific operation
127-
func (uc *CASClientUseCase) casAPIClient(secretID string, role casJWT.Role) (casclient.DownloaderUploader, error) {
139+
func (uc *CASClientUseCase) casAPIClient(secretID string, role casJWT.Role) (casclient.DownloaderUploader, func(), error) {
128140
token, err := uc.credsProvider.GenerateTemporaryCredentials(secretID, role)
129141
if err != nil {
130-
return nil, fmt.Errorf("failed to generate temporary credentials: %w", err)
142+
return nil, nil, fmt.Errorf("failed to generate temporary credentials: %w", err)
131143
}
132144

133145
// Initialize connection to CAS server
@@ -145,10 +157,11 @@ func (uc *CASClientUseCase) IsReady(ctx context.Context) (bool, error) {
145157
return false, fmt.Errorf("invalid CAS client configuration: %w", err)
146158
}
147159

148-
c, err := uc.casClientFactory(uc.casServerConf, "")
160+
c, closeFn, err := uc.casClientFactory(uc.casServerConf, "")
149161
if err != nil {
150162
return false, fmt.Errorf("failed to create CAS client: %w", err)
151163
}
164+
defer closeFn()
152165

153166
return c.IsReady(ctx)
154167
}

app/controlplane/internal/biz/casclient_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,10 +65,10 @@ func TestIsReady(t *testing.T) {
6565

6666
for _, tc := range testCases {
6767
t.Run(tc.name, func(t *testing.T) {
68-
clientProvider := func(conf *conf.Bootstrap_CASServer, token string) (casclient.DownloaderUploader, error) {
68+
clientProvider := func(conf *conf.Bootstrap_CASServer, token string) (casclient.DownloaderUploader, func(), error) {
6969
c := mocks.NewDownloaderUploader(t)
7070
c.On("IsReady", mock.Anything).Return(tc.casReady, nil)
71-
return c, nil
71+
return c, func() {}, nil
7272
}
7373
uc := biz.NewCASClientUseCase(nil, tc.config, nil, biz.WithClientFactory(clientProvider))
7474

0 commit comments

Comments
 (0)