Skip to content

Commit

Permalink
Update pkg/osv to allow overriding the http client / transport. (#357)
Browse files Browse the repository at this point in the history
I don't know if you all intended `pkg/osv` to be a general purpose
https://osv.dev/ client library, but we are using it as that here:
https://github.com/guacsec/guac/blob/main/pkg/certifier/osv/osv.go

It would be nice to override the http client used to insert custom
transports, etc.

Signed-off-by: Jeff Mendoza <jlm@jlm.name>
  • Loading branch information
jeffmendoza committed May 1, 2023
1 parent 8af3d62 commit 190aea2
Showing 1 changed file with 21 additions and 3 deletions.
24 changes: 21 additions & 3 deletions pkg/osv/osv.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,12 @@ func checkResponseError(resp *http.Response) error {

// MakeRequest sends a batched query to osv.dev
func MakeRequest(request BatchedQuery) (*BatchedResponse, error) {
return MakeRequestWithClient(request, http.DefaultClient)
}

// MakeRequestWithClient sends a batched query to osv.dev with the provided
// http client.
func MakeRequestWithClient(request BatchedQuery, client *http.Client) (*BatchedResponse, error) {
// API has a limit of 1000 bulk query per request
queryChunks := chunkBy(request.Queries, maxQueriesPerRequest)
var totalOsvResp BatchedResponse
Expand All @@ -140,7 +146,7 @@ func MakeRequest(request BatchedQuery) (*BatchedResponse, error) {
resp, err := makeRetryRequest(func() (*http.Response, error) {
// We do not need a specific context
//nolint:noctx
return http.Post(QueryEndpoint, "application/json", requestBuf)
return client.Post(QueryEndpoint, "application/json", requestBuf)
})
if err != nil {
return nil, err
Expand All @@ -166,9 +172,15 @@ func MakeRequest(request BatchedQuery) (*BatchedResponse, error) {

// Get a Vulnerability for the given ID.
func Get(id string) (*models.Vulnerability, error) {
return GetWithClient(id, http.DefaultClient)
}

// GetWithClient gets a Vulnerability for the given ID with the provided http
// client.
func GetWithClient(id string, client *http.Client) (*models.Vulnerability, error) {
resp, err := makeRetryRequest(func() (*http.Response, error) {
//nolint:noctx
return http.Get(GetEndpoint + "/" + id)
return client.Get(GetEndpoint + "/" + id)
})
if err != nil {
return nil, err
Expand All @@ -192,6 +204,12 @@ func Get(id string) (*models.Vulnerability, error) {
// Hydrate fills the results of the batched response with the full
// Vulnerability details.
func Hydrate(resp *BatchedResponse) (*HydratedBatchedResponse, error) {
return HydrateWithClient(resp, http.DefaultClient)
}

// HydrateWithClient fills the results of the batched response with the full
// Vulnerability details using the provided http client.
func HydrateWithClient(resp *BatchedResponse, client *http.Client) (*HydratedBatchedResponse, error) {
hydrated := HydratedBatchedResponse{}
ctx := context.TODO()
// Preallocate the array to avoid slice reallocations when inserting later
Expand All @@ -211,7 +229,7 @@ func Hydrate(resp *BatchedResponse) (*HydratedBatchedResponse, error) {
}

go func(id string, batchIdx int, resultIdx int) {
vuln, err := Get(id)
vuln, err := GetWithClient(id, client)
if err != nil {
errChan <- err
} else {
Expand Down

0 comments on commit 190aea2

Please sign in to comment.