diff --git a/api/restHandler/UserRestHandler.go b/api/restHandler/UserRestHandler.go index 39e10eabc3..ee29355911 100644 --- a/api/restHandler/UserRestHandler.go +++ b/api/restHandler/UserRestHandler.go @@ -26,6 +26,7 @@ import ( "github.com/devtron-labs/devtron/pkg/user" "github.com/devtron-labs/devtron/util/rbac" "github.com/devtron-labs/devtron/util/response" + "github.com/go-pg/pg" "github.com/gorilla/mux" "go.uber.org/zap" "gopkg.in/go-playground/validator.v9" @@ -108,6 +109,32 @@ func (handler UserRestHandlerImpl) CreateUser(w http.ResponseWriter, r *http.Req return } } + + // auth check inside groups + if len(userInfo.Groups) > 0 { + groupRoles, err := handler.roleGroupService.FetchRolesForGroups(userInfo.Groups) + if err != nil && err != pg.ErrNoRows { + handler.logger.Errorw("service err, UpdateUser", "err", err, "payload", userInfo) + writeJsonResp(w, err, "", http.StatusInternalServerError) + return + } + + if groupRoles != nil && len(groupRoles) > 0 { + for _, groupRole := range groupRoles { + if len(groupRole.Team) > 0 { + if ok := handler.enforcer.Enforce(token, rbac.ResourceUser, rbac.ActionCreate, groupRole.Team); !ok { + response.WriteResponse(http.StatusForbidden, "FORBIDDEN", w, errors.New("unauthorized")) + return + } + } + } + } else { + if ok := handler.enforcer.Enforce(token, rbac.ResourceUser, rbac.ActionCreate, "*"); !ok { + response.WriteResponse(http.StatusForbidden, "FORBIDDEN", w, errors.New("unauthorized")) + return + } + } + } //RBAC enforcer Ends handler.logger.Infow("request payload, CreateUser ", "payload", userInfo) @@ -167,6 +194,32 @@ func (handler UserRestHandlerImpl) UpdateUser(w http.ResponseWriter, r *http.Req return } } + + // auth check inside groups + if len(userInfo.Groups) > 0 { + groupRoles, err := handler.roleGroupService.FetchRolesForGroups(userInfo.Groups) + if err != nil && err != pg.ErrNoRows { + handler.logger.Errorw("service err, UpdateUser", "err", err, "payload", userInfo) + writeJsonResp(w, err, "", http.StatusInternalServerError) + return + } + + if groupRoles != nil && len(groupRoles) > 0 { + for _, groupRole := range groupRoles { + if len(groupRole.Team) > 0 { + if ok := handler.enforcer.Enforce(token, rbac.ResourceUser, rbac.ActionUpdate, groupRole.Team); !ok { + response.WriteResponse(http.StatusForbidden, "FORBIDDEN", w, errors.New("unauthorized")) + return + } + } + } + } else { + if ok := handler.enforcer.Enforce(token, rbac.ResourceUser, rbac.ActionUpdate, "*"); !ok { + response.WriteResponse(http.StatusForbidden, "FORBIDDEN", w, errors.New("unauthorized")) + return + } + } + } //RBAC enforcer Ends if userInfo.EmailId == "admin" { diff --git a/internal/sql/repository/RoleGroupRepository.go b/internal/sql/repository/RoleGroupRepository.go index d9777a4d8d..64c7ca4020 100644 --- a/internal/sql/repository/RoleGroupRepository.go +++ b/internal/sql/repository/RoleGroupRepository.go @@ -37,6 +37,8 @@ type RoleGroupRepository interface { GetRoleGroupRoleMappingByRoleGroupId(roleGroupId int32) ([]*RoleGroupRoleMapping, error) DeleteRoleGroupRoleMapping(model *RoleGroupRoleMapping, tx *pg.Tx) (bool, error) GetConnection() (dbConnection *pg.DB) + GetRoleGroupListByNames(groupNames []string) ([]*RoleGroup, error) + GetRoleGroupRoleMappingByRoleGroupIds(roleGroupIds []int32) ([]*RoleModel, error) } type RoleGroupRepositoryImpl struct { @@ -63,7 +65,7 @@ type RoleGroupRoleMapping struct { Id int `sql:"id,pk"` RoleGroupId int32 `sql:"role_group_id,notnull"` RoleId int `sql:"role_id,notnull"` - //User UserModel + RoleModel *RoleModel models.AuditLog } @@ -154,3 +156,21 @@ func (impl RoleGroupRepositoryImpl) DeleteRoleGroupRoleMapping(model *RoleGroupR } return true, nil } + +func (impl RoleGroupRepositoryImpl) GetRoleGroupListByNames(groupNames []string) ([]*RoleGroup, error) { + var model []*RoleGroup + err := impl.dbConnection.Model(&model).Where("name in (?)", pg.In(groupNames)).Where("active = ?", true).Order("updated_on desc").Select() + return model, err +} + +func (impl RoleGroupRepositoryImpl) GetRoleGroupRoleMappingByRoleGroupIds(roleGroupIds []int32) ([]*RoleModel, error) { + var roleModels []*RoleModel + query := "SELECT r.* from roles r" + + " INNER JOIN role_group_role_mapping rgm on rgm.role_id=r.id" + + " WHERE rgm.role_group_id in (?);" + _, err := impl.dbConnection.Query(&roleModels, query, pg.In(roleGroupIds)) + if err != nil { + return roleModels, err + } + return roleModels, nil +} diff --git a/pkg/user/RoleGroupService.go b/pkg/user/RoleGroupService.go index ec3a883307..580c8f1629 100644 --- a/pkg/user/RoleGroupService.go +++ b/pkg/user/RoleGroupService.go @@ -26,6 +26,7 @@ import ( "github.com/devtron-labs/devtron/internal/constants" "github.com/devtron-labs/devtron/internal/sql/repository" "github.com/devtron-labs/devtron/internal/util" + "github.com/go-pg/pg" "github.com/gorilla/sessions" "go.uber.org/zap" "strings" @@ -40,6 +41,7 @@ type RoleGroupService interface { FetchRoleGroups() ([]*bean.RoleGroup, error) FetchRoleGroupsByName(name string) ([]*bean.RoleGroup, error) DeleteRoleGroup(model *bean.RoleGroup) (bool, error) + FetchRolesForGroups(groupNames []string) ([]*bean.RoleFilter, error) } type RoleGroupServiceImpl struct { @@ -525,3 +527,44 @@ func (impl RoleGroupServiceImpl) DeleteRoleGroup(bean *bean.RoleGroup) (bool, er return true, nil } + +func (impl RoleGroupServiceImpl) FetchRolesForGroups(groupNames []string) ([]*bean.RoleFilter, error) { + if len(groupNames) == 0 { + return nil, nil + } + roleGroups, err := impl.roleGroupRepository.GetRoleGroupListByNames(groupNames) + if err != nil && err != pg.ErrNoRows { + impl.logger.Errorw("error while fetching user from db", "error", err) + return nil, err + } + if err == pg.ErrNoRows { + impl.logger.Warnw("no result found for role groups", "groups", groupNames) + return nil, nil + } + + var roleGroupIds []int32 + for _, roleGroup := range roleGroups { + roleGroupIds = append(roleGroupIds, roleGroup.Id) + } + + roles, err := impl.roleGroupRepository.GetRoleGroupRoleMappingByRoleGroupIds(roleGroupIds) + if err != nil && err != pg.ErrNoRows { + impl.logger.Errorw("error while fetching user from db", "error", err) + return nil, err + } + list := make([]*bean.RoleFilter, 0) + if roles == nil { + return list, nil + } + for _, role := range roles { + bean := &bean.RoleFilter{ + EntityName: role.EntityName, + Entity: role.Entity, + Action: role.Action, + Environment: role.Environment, + Team: role.Team, + } + list = append(list, bean) + } + return list, nil +}