Skip to content

Commit

Permalink
fix: do not call read auth model on batchcheck and write (#78)
Browse files Browse the repository at this point in the history
* fix: do not call read auth model on batchcheck and write

* refactor: remove inlining of model id check, fix returning for non transaction and improve tests

* test: add tests for 400 when transaction mode is disabled

* refactor: apply suggestions from code review

Co-authored-by: Raghd Hamzeh <raghd.hamzeh@auth0.com>

---------

Co-authored-by: Raghd Hamzeh <raghd.hamzeh@auth0.com>
  • Loading branch information
ewanharris and rhamzeh committed Apr 30, 2024
1 parent b6a8881 commit c50b18f
Show file tree
Hide file tree
Showing 2 changed files with 197 additions and 123 deletions.
54 changes: 28 additions & 26 deletions client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -461,19 +461,6 @@ func (client *OpenFgaClient) getAuthorizationModelId(authorizationModelId *strin
return &modelId, nil
}

// helper function to validate the connection (i.e., get token)
func (client *OpenFgaClient) checkValidApiConnection(ctx _context.Context, authorizationModelId *string) error {
if authorizationModelId != nil && *authorizationModelId != "" {
_, _, err := client.OpenFgaApi.ReadAuthorizationModel(ctx, *authorizationModelId).Execute()
return err
} else {
_, err := client.ReadAuthorizationModels(ctx).Options(ClientReadAuthorizationModelsOptions{
PageSize: fgaSdk.PtrInt32(1),
}).Execute()
return err
}
}

/* Stores */

// / ListStores
Expand Down Expand Up @@ -1400,17 +1387,13 @@ func (client *OpenFgaClient) WriteExecute(request SdkClientWriteRequestInterface
}

writeGroup, ctx := errgroup.WithContext(request.GetContext())
err = client.checkValidApiConnection(ctx, authorizationModelId)
if err != nil {
return nil, err
}

writeGroup.SetLimit(int(maxParallelReqs))
writeResponses := make([]ClientWriteResponse, len(writeChunks))
for index, writeBody := range writeChunks {
index, writeBody := index, writeBody
writeGroup.Go(func() error {
singleResponse, _ := client.WriteExecute(&SdkClientWriteRequest{
singleResponse, err := client.WriteExecute(&SdkClientWriteRequest{
ctx: ctx,
Client: client,
body: &ClientWriteRequest{
Expand All @@ -1421,13 +1404,21 @@ func (client *OpenFgaClient) WriteExecute(request SdkClientWriteRequestInterface
},
})

if _, ok := err.(fgaSdk.FgaApiAuthenticationError); ok {
return err
}

writeResponses[index] = *singleResponse

return nil
})
}

_ = writeGroup.Wait()
err = writeGroup.Wait()
// If an error was returned then it will be an authentication error so we want to return
if err != nil {
return &response, err
}

var deleteChunkSize = int(maxPerChunk)
var deleteChunks [][]ClientTupleKeyWithoutCondition
Expand All @@ -1445,7 +1436,7 @@ func (client *OpenFgaClient) WriteExecute(request SdkClientWriteRequestInterface
for index, deleteBody := range deleteChunks {
index, deleteBody := index, deleteBody
deleteGroup.Go(func() error {
singleResponse, _ := client.WriteExecute(&SdkClientWriteRequest{
singleResponse, err := client.WriteExecute(&SdkClientWriteRequest{
ctx: ctx,
Client: client,
body: &ClientWriteRequest{
Expand All @@ -1456,13 +1447,21 @@ func (client *OpenFgaClient) WriteExecute(request SdkClientWriteRequestInterface
},
})

if _, ok := err.(fgaSdk.FgaApiAuthenticationError); ok {
return err
}

deleteResponses[index] = *singleResponse

return nil
})
}

_ = deleteGroup.Wait()
err = deleteGroup.Wait()
if err != nil {
// If an error was returned then it will be an authentication error so we want to return
return &response, err
}

for _, writeResponse := range writeResponses {
for _, writeSingleResponse := range writeResponse.Writes {
Expand Down Expand Up @@ -1695,9 +1694,11 @@ func (client *OpenFgaClient) CheckExecute(request SdkClientCheckRequestInterface
}
}
authorizationModelId, err := client.getAuthorizationModelId(request.GetAuthorizationModelIdOverride())

if err != nil {
return nil, err
}

requestBody := fgaSdk.CheckRequest{
TupleKey: fgaSdk.CheckRequestTupleKey{
User: request.GetBody().User,
Expand Down Expand Up @@ -1801,15 +1802,11 @@ func (client *OpenFgaClient) BatchCheckExecute(request SdkClientBatchCheckReques
var numOfChecks = len(*request.GetBody())
response := make(ClientBatchCheckResponse, numOfChecks)
authorizationModelId, err := client.getAuthorizationModelId(request.GetAuthorizationModelIdOverride())

if err != nil {
return nil, err
}

group.Go(func() error {
// if the connection is probelmatic, we will return error to the overall
// response rather than individual response
return client.checkValidApiConnection(ctx, authorizationModelId)
})
for index, checkBody := range *request.GetBody() {
index, checkBody := index, checkBody
group.Go(func() error {
Expand All @@ -1822,6 +1819,11 @@ func (client *OpenFgaClient) BatchCheckExecute(request SdkClientBatchCheckReques
},
})

if _, ok := err.(fgaSdk.FgaApiAuthenticationError); ok {
return err

}

response[index] = ClientBatchCheckSingleResponse{
Request: checkBody,
ClientCheckResponse: *singleResponse,
Expand Down
Loading

0 comments on commit c50b18f

Please sign in to comment.