Skip to content

Commit

Permalink
fix: bug in group, role, and policy (#147)
Browse files Browse the repository at this point in the history
* fix: bug in group, role, and policy

* fix: error message role update
  • Loading branch information
mabdh committed Aug 4, 2022
1 parent 4e280ae commit 88226c0
Show file tree
Hide file tree
Showing 7 changed files with 58 additions and 45 deletions.
4 changes: 2 additions & 2 deletions core/role/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ func NewService(repository Repository) *Service {
func (s Service) Create(ctx context.Context, toCreate Role) (Role, error) {
roleID, err := s.repository.Create(ctx, toCreate)
if err != nil {
return Role{}, nil
return Role{}, err
}
return s.repository.Get(ctx, roleID)
}
Expand All @@ -31,7 +31,7 @@ func (s Service) List(ctx context.Context) ([]Role, error) {
func (s Service) Update(ctx context.Context, toUpdate Role) (Role, error) {
roleID, err := s.repository.Update(ctx, toUpdate)
if err != nil {
return Role{}, nil
return Role{}, err
}
return s.repository.Get(ctx, roleID)
}
21 changes: 11 additions & 10 deletions internal/api/v1beta1/group.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,6 @@ func (h Handler) CreateGroup(ctx context.Context, request *shieldv1beta1.CreateG
}

newGroup, err := h.groupService.Create(ctx, grp)

if err != nil {
logger.Error(err.Error())
return nil, grpcInternalServerError
Expand Down Expand Up @@ -228,18 +227,20 @@ func (h Handler) UpdateGroup(ctx context.Context, request *shieldv1beta1.UpdateG
var updatedGroup group.Group
if uuid.IsValid(request.GetId()) {
updatedGroup, err = h.groupService.Update(ctx, group.Group{
ID: request.GetId(),
Name: request.GetBody().GetName(),
Slug: request.GetBody().GetSlug(),
Organization: organization.Organization{ID: request.GetBody().GetOrgId()},
Metadata: metaDataMap,
ID: request.GetId(),
Name: request.GetBody().GetName(),
Slug: request.GetBody().GetSlug(),
Organization: organization.Organization{ID: request.GetBody().GetOrgId()},
OrganizationID: request.GetBody().GetOrgId(),
Metadata: metaDataMap,
})
} else {
updatedGroup, err = h.groupService.Update(ctx, group.Group{
Name: request.GetBody().GetName(),
Slug: request.GetId(),
Organization: organization.Organization{ID: request.GetBody().GetOrgId()},
Metadata: metaDataMap,
Name: request.GetBody().GetName(),
Slug: request.GetId(),
Organization: organization.Organization{ID: request.GetBody().GetOrgId()},
OrganizationID: request.GetBody().GetOrgId(),
Metadata: metaDataMap,
})
}
if err != nil {
Expand Down
5 changes: 2 additions & 3 deletions internal/api/v1beta1/namespace.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,10 +100,9 @@ func (h Handler) UpdateNamespace(ctx context.Context, request *shieldv1beta1.Upd
logger := grpczap.Extract(ctx)

updatedNS, err := h.namespaceService.Update(ctx, namespace.Namespace{
ID: request.GetBody().GetId(),
Name: request.GetBody().GetName(),
ID: request.GetId(),
Name: request.GetBody().Name,
})

if err != nil {
logger.Error(err.Error())
switch {
Expand Down
20 changes: 9 additions & 11 deletions internal/api/v1beta1/policy.go
Original file line number Diff line number Diff line change
Expand Up @@ -118,26 +118,24 @@ func (h Handler) UpdatePolicy(ctx context.Context, request *shieldv1beta1.Update
NamespaceID: request.GetBody().GetNamespaceId(),
ActionID: request.GetBody().GetActionId(),
})

if err != nil {
logger.Error(err.Error())
return nil, grpcInternalServerError
switch {
case errors.Is(err, policy.ErrNotExist):
return nil, grpcPolicyNotFoundErr
case errors.Is(err, policy.ErrConflict):
return nil, grpcConflictError
default:
return nil, grpcInternalServerError
}
}

for _, p := range updatedPolices {
policyPB, err := transformPolicyToPB(p)
if err != nil {
logger.Error(err.Error())
switch {
case errors.Is(err, policy.ErrNotExist):
return nil, grpcPolicyNotFoundErr
case errors.Is(err, policy.ErrConflict):
return nil, grpcConflictError
default:
return nil, grpcInternalServerError
}
return nil, grpcInternalServerError
}

policies = append(policies, &policyPB)
}

Expand Down
10 changes: 5 additions & 5 deletions internal/api/v1beta1/role.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,17 +116,17 @@ func (h Handler) UpdateRole(ctx context.Context, request *shieldv1beta1.UpdateRo
}

updatedRole, err := h.roleService.Update(ctx, role.Role{
ID: request.GetBody().GetId(),
Name: request.GetBody().GetName(),
Types: request.GetBody().GetTypes(),
NamespaceID: request.GetBody().GetNamespaceId(),
ID: request.GetId(),
Name: request.GetBody().Name,
Types: request.GetBody().Types,
NamespaceID: request.GetBody().NamespaceId,
Metadata: metaDataMap,
})
if err != nil {
logger.Error(err.Error())
switch {
case errors.Is(err, role.ErrNotExist):
return nil, grpcProjectNotFoundErr
return nil, grpcRoleNotFoundErr
case errors.Is(err, role.ErrConflict):
return nil, grpcConflictError
default:
Expand Down
27 changes: 18 additions & 9 deletions internal/store/postgres/group_repository.go
Original file line number Diff line number Diff line change
Expand Up @@ -458,28 +458,37 @@ func (r GroupRepository) ListUserGroupSlugRelations(ctx context.Context, userID
}

func (r GroupRepository) ListUserGroups(ctx context.Context, userID string, roleID string) ([]group.Group, error) {
if roleID == "" || userID == "" {
if userID == "" {
return nil, group.ErrInvalidID
}

query, params, err := dialect.Select(
sqlStatement := dialect.Select(
goqu.I("g.id").As("id"),
goqu.I("g.metadata").As("metadata"),
goqu.I("g.name").As("name"),
goqu.I("g.slug").As("slug"),
goqu.I("g.updated_at").As("updated_at"),
goqu.I("g.created_at").As("created_at"),
goqu.I("g.org_id").As("org_id"),
).From(goqu.L("relations r")).
).
From(goqu.L("relations r")).
Join(goqu.L("groups g"), goqu.On(
goqu.I("g.id").Cast("VARCHAR").
Eq(goqu.I("r.object_id")),
)).Where(goqu.Ex{
"r.object_namespace_id": namespace.DefinitionTeam.ID,
"subject_namespace_id": namespace.DefinitionUser.ID,
"subject_id": userID,
"role_id": roleID,
}).ToSQL()
)).
Where(goqu.Ex{
"r.object_namespace_id": namespace.DefinitionTeam.ID,
"subject_namespace_id": namespace.DefinitionUser.ID,
"subject_id": userID,
})

if roleID != "" {
sqlStatement = sqlStatement.Where(goqu.Ex{
"role_id": roleID,
})
}

query, params, err := sqlStatement.ToSQL()
if err != nil {
return []group.Group{}, fmt.Errorf("%w: %s", queryErr, err)
}
Expand Down
16 changes: 11 additions & 5 deletions internal/store/postgres/group_repository_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -788,6 +788,17 @@ func (s *GroupRepositoryTestSuite) TestListUserGroups() {
},
},
},
{
Description: "should not return error if role id is empty",
UserID: s.users[0].ID,
ExpectedGroups: []group.Group{
{
Name: "group1",
Slug: "group-1",
OrganizationID: s.orgs[0].ID,
},
},
},
{
Description: "should get empty groups if there is none",
UserID: s.users[1].ID,
Expand All @@ -798,11 +809,6 @@ func (s *GroupRepositoryTestSuite) TestListUserGroups() {
RoleID: role.DefinitionTeamMember.ID,
ErrString: group.ErrInvalidID.Error(),
},
{
Description: "should get error if group id is empty",
UserID: s.users[0].ID,
ErrString: group.ErrInvalidID.Error(),
},
}

for _, tc := range testCases {
Expand Down

0 comments on commit 88226c0

Please sign in to comment.