Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 53 additions & 0 deletions api/restHandler/UserRestHandler.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import (
"github.com/devtron-labs/devtron/pkg/user"
"github.com/devtron-labs/devtron/util/rbac"
"github.com/devtron-labs/devtron/util/response"
"github.com/go-pg/pg"
"github.com/gorilla/mux"
"go.uber.org/zap"
"gopkg.in/go-playground/validator.v9"
Expand Down Expand Up @@ -108,6 +109,32 @@ func (handler UserRestHandlerImpl) CreateUser(w http.ResponseWriter, r *http.Req
return
}
}

// auth check inside groups
if len(userInfo.Groups) > 0 {
groupRoles, err := handler.roleGroupService.FetchRolesForGroups(userInfo.Groups)
if err != nil && err != pg.ErrNoRows {
handler.logger.Errorw("service err, UpdateUser", "err", err, "payload", userInfo)
writeJsonResp(w, err, "", http.StatusInternalServerError)
return
}

if groupRoles != nil && len(groupRoles) > 0 {
for _, groupRole := range groupRoles {
if len(groupRole.Team) > 0 {
if ok := handler.enforcer.Enforce(token, rbac.ResourceUser, rbac.ActionCreate, groupRole.Team); !ok {
response.WriteResponse(http.StatusForbidden, "FORBIDDEN", w, errors.New("unauthorized"))
return
}
}
}
} else {
if ok := handler.enforcer.Enforce(token, rbac.ResourceUser, rbac.ActionCreate, "*"); !ok {
response.WriteResponse(http.StatusForbidden, "FORBIDDEN", w, errors.New("unauthorized"))
return
}
}
}
//RBAC enforcer Ends

handler.logger.Infow("request payload, CreateUser ", "payload", userInfo)
Expand Down Expand Up @@ -167,6 +194,32 @@ func (handler UserRestHandlerImpl) UpdateUser(w http.ResponseWriter, r *http.Req
return
}
}

// auth check inside groups
if len(userInfo.Groups) > 0 {
groupRoles, err := handler.roleGroupService.FetchRolesForGroups(userInfo.Groups)
if err != nil && err != pg.ErrNoRows {
handler.logger.Errorw("service err, UpdateUser", "err", err, "payload", userInfo)
writeJsonResp(w, err, "", http.StatusInternalServerError)
return
}

if groupRoles != nil && len(groupRoles) > 0 {
for _, groupRole := range groupRoles {
if len(groupRole.Team) > 0 {
if ok := handler.enforcer.Enforce(token, rbac.ResourceUser, rbac.ActionUpdate, groupRole.Team); !ok {
response.WriteResponse(http.StatusForbidden, "FORBIDDEN", w, errors.New("unauthorized"))
return
}
}
}
} else {
if ok := handler.enforcer.Enforce(token, rbac.ResourceUser, rbac.ActionUpdate, "*"); !ok {
response.WriteResponse(http.StatusForbidden, "FORBIDDEN", w, errors.New("unauthorized"))
return
}
}
}
//RBAC enforcer Ends

if userInfo.EmailId == "admin" {
Expand Down
22 changes: 21 additions & 1 deletion internal/sql/repository/RoleGroupRepository.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ type RoleGroupRepository interface {
GetRoleGroupRoleMappingByRoleGroupId(roleGroupId int32) ([]*RoleGroupRoleMapping, error)
DeleteRoleGroupRoleMapping(model *RoleGroupRoleMapping, tx *pg.Tx) (bool, error)
GetConnection() (dbConnection *pg.DB)
GetRoleGroupListByNames(groupNames []string) ([]*RoleGroup, error)
GetRoleGroupRoleMappingByRoleGroupIds(roleGroupIds []int32) ([]*RoleModel, error)
}

type RoleGroupRepositoryImpl struct {
Expand All @@ -63,7 +65,7 @@ type RoleGroupRoleMapping struct {
Id int `sql:"id,pk"`
RoleGroupId int32 `sql:"role_group_id,notnull"`
RoleId int `sql:"role_id,notnull"`
//User UserModel
RoleModel *RoleModel
models.AuditLog
}

Expand Down Expand Up @@ -154,3 +156,21 @@ func (impl RoleGroupRepositoryImpl) DeleteRoleGroupRoleMapping(model *RoleGroupR
}
return true, nil
}

func (impl RoleGroupRepositoryImpl) GetRoleGroupListByNames(groupNames []string) ([]*RoleGroup, error) {
var model []*RoleGroup
err := impl.dbConnection.Model(&model).Where("name in (?)", pg.In(groupNames)).Where("active = ?", true).Order("updated_on desc").Select()
return model, err
}

func (impl RoleGroupRepositoryImpl) GetRoleGroupRoleMappingByRoleGroupIds(roleGroupIds []int32) ([]*RoleModel, error) {
var roleModels []*RoleModel
query := "SELECT r.* from roles r" +
" INNER JOIN role_group_role_mapping rgm on rgm.role_id=r.id" +
" WHERE rgm.role_group_id in (?);"
_, err := impl.dbConnection.Query(&roleModels, query, pg.In(roleGroupIds))
if err != nil {
return roleModels, err
}
return roleModels, nil
}
43 changes: 43 additions & 0 deletions pkg/user/RoleGroupService.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import (
"github.com/devtron-labs/devtron/internal/constants"
"github.com/devtron-labs/devtron/internal/sql/repository"
"github.com/devtron-labs/devtron/internal/util"
"github.com/go-pg/pg"
"github.com/gorilla/sessions"
"go.uber.org/zap"
"strings"
Expand All @@ -40,6 +41,7 @@ type RoleGroupService interface {
FetchRoleGroups() ([]*bean.RoleGroup, error)
FetchRoleGroupsByName(name string) ([]*bean.RoleGroup, error)
DeleteRoleGroup(model *bean.RoleGroup) (bool, error)
FetchRolesForGroups(groupNames []string) ([]*bean.RoleFilter, error)
}

type RoleGroupServiceImpl struct {
Expand Down Expand Up @@ -525,3 +527,44 @@ func (impl RoleGroupServiceImpl) DeleteRoleGroup(bean *bean.RoleGroup) (bool, er

return true, nil
}

func (impl RoleGroupServiceImpl) FetchRolesForGroups(groupNames []string) ([]*bean.RoleFilter, error) {
if len(groupNames) == 0 {
return nil, nil
}
roleGroups, err := impl.roleGroupRepository.GetRoleGroupListByNames(groupNames)
if err != nil && err != pg.ErrNoRows {
impl.logger.Errorw("error while fetching user from db", "error", err)
return nil, err
}
if err == pg.ErrNoRows {
impl.logger.Warnw("no result found for role groups", "groups", groupNames)
return nil, nil
}

var roleGroupIds []int32
for _, roleGroup := range roleGroups {
roleGroupIds = append(roleGroupIds, roleGroup.Id)
}

roles, err := impl.roleGroupRepository.GetRoleGroupRoleMappingByRoleGroupIds(roleGroupIds)
if err != nil && err != pg.ErrNoRows {
impl.logger.Errorw("error while fetching user from db", "error", err)
return nil, err
}
list := make([]*bean.RoleFilter, 0)
if roles == nil {
return list, nil
}
for _, role := range roles {
bean := &bean.RoleFilter{
EntityName: role.EntityName,
Entity: role.Entity,
Action: role.Action,
Environment: role.Environment,
Team: role.Team,
}
list = append(list, bean)
}
return list, nil
}