Skip to content

Commit

Permalink
chore: add template_with_user view to include user contextual data (#…
Browse files Browse the repository at this point in the history
…8568)

* chore: Refactor template sql queries to use new view
* TemplateWithUser -> Template
* Add unit test to enforce good view
  • Loading branch information
Emyrk committed Jul 19, 2023
1 parent cdbae29 commit aceedef
Show file tree
Hide file tree
Showing 25 changed files with 453 additions and 358 deletions.
23 changes: 13 additions & 10 deletions coderd/database/dbauthz/dbauthz.go
Original file line number Diff line number Diff line change
Expand Up @@ -1808,9 +1808,12 @@ func (q *querier) InsertReplica(ctx context.Context, arg database.InsertReplicaP
return q.db.InsertReplica(ctx, arg)
}

func (q *querier) InsertTemplate(ctx context.Context, arg database.InsertTemplateParams) (database.Template, error) {
func (q *querier) InsertTemplate(ctx context.Context, arg database.InsertTemplateParams) error {
obj := rbac.ResourceTemplate.InOrg(arg.OrganizationID)
return insert(q.log, q.auth, obj, q.db.InsertTemplate)(ctx, arg)
if err := q.authorizeContext(ctx, rbac.ActionCreate, obj); err != nil {
return err
}
return q.db.InsertTemplate(ctx, arg)
}

func (q *querier) InsertTemplateVersion(ctx context.Context, arg database.InsertTemplateVersionParams) (database.TemplateVersion, error) {
Expand Down Expand Up @@ -2134,13 +2137,13 @@ func (q *querier) UpdateReplica(ctx context.Context, arg database.UpdateReplicaP
return q.db.UpdateReplica(ctx, arg)
}

func (q *querier) UpdateTemplateACLByID(ctx context.Context, arg database.UpdateTemplateACLByIDParams) (database.Template, error) {
// UpdateTemplateACL uses the ActionCreate action. Only users that can create the template
// may update the ACL.
func (q *querier) UpdateTemplateACLByID(ctx context.Context, arg database.UpdateTemplateACLByIDParams) error {
fetch := func(ctx context.Context, arg database.UpdateTemplateACLByIDParams) (database.Template, error) {
return q.db.GetTemplateByID(ctx, arg.ID)
}
return fetchAndQuery(q.log, q.auth, rbac.ActionCreate, fetch, q.db.UpdateTemplateACLByID)(ctx, arg)
// UpdateTemplateACL uses the ActionCreate action. Only users that can create the template
// may update the ACL.
return fetchAndExec(q.log, q.auth, rbac.ActionCreate, fetch, q.db.UpdateTemplateACLByID)(ctx, arg)
}

func (q *querier) UpdateTemplateActiveVersionByID(ctx context.Context, arg database.UpdateTemplateActiveVersionByIDParams) error {
Expand All @@ -2155,18 +2158,18 @@ func (q *querier) UpdateTemplateDeletedByID(ctx context.Context, arg database.Up
return q.SoftDeleteTemplateByID(ctx, arg.ID)
}

func (q *querier) UpdateTemplateMetaByID(ctx context.Context, arg database.UpdateTemplateMetaByIDParams) (database.Template, error) {
func (q *querier) UpdateTemplateMetaByID(ctx context.Context, arg database.UpdateTemplateMetaByIDParams) error {
fetch := func(ctx context.Context, arg database.UpdateTemplateMetaByIDParams) (database.Template, error) {
return q.db.GetTemplateByID(ctx, arg.ID)
}
return updateWithReturn(q.log, q.auth, fetch, q.db.UpdateTemplateMetaByID)(ctx, arg)
return update(q.log, q.auth, fetch, q.db.UpdateTemplateMetaByID)(ctx, arg)
}

func (q *querier) UpdateTemplateScheduleByID(ctx context.Context, arg database.UpdateTemplateScheduleByIDParams) (database.Template, error) {
func (q *querier) UpdateTemplateScheduleByID(ctx context.Context, arg database.UpdateTemplateScheduleByIDParams) error {
fetch := func(ctx context.Context, arg database.UpdateTemplateScheduleByIDParams) (database.Template, error) {
return q.db.GetTemplateByID(ctx, arg.ID)
}
return updateWithReturn(q.log, q.auth, fetch, q.db.UpdateTemplateScheduleByID)(ctx, arg)
return update(q.log, q.auth, fetch, q.db.UpdateTemplateScheduleByID)(ctx, arg)
}

func (q *querier) UpdateTemplateVersionByID(ctx context.Context, arg database.UpdateTemplateVersionByIDParams) (database.TemplateVersion, error) {
Expand Down
2 changes: 1 addition & 1 deletion coderd/database/dbauthz/dbauthz_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -772,7 +772,7 @@ func (s *MethodTestSuite) TestTemplate() {
t1 := dbgen.Template(s.T(), db, database.Template{})
check.Args(database.UpdateTemplateACLByIDParams{
ID: t1.ID,
}).Asserts(t1, rbac.ActionCreate).Returns(t1)
}).Asserts(t1, rbac.ActionCreate)
}))
s.Run("UpdateTemplateActiveVersionByID", s.Subtest(func(db database.Store, check *expects) {
t1 := dbgen.Template(s.T(), db, database.Template{
Expand Down
81 changes: 52 additions & 29 deletions coderd/database/dbfake/dbfake.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ func New() database.Store {
workspaceResourceMetadata: make([]database.WorkspaceResourceMetadatum, 0),
provisionerJobs: make([]database.ProvisionerJob, 0),
templateVersions: make([]database.TemplateVersion, 0),
templates: make([]database.Template, 0),
templates: make([]database.TemplateTable, 0),
workspaceAgentStats: make([]database.WorkspaceAgentStat, 0),
workspaceAgentLogs: make([]database.WorkspaceAgentStartupLog, 0),
workspaceBuilds: make([]database.WorkspaceBuild, 0),
Expand Down Expand Up @@ -130,7 +130,7 @@ type data struct {
templateVersions []database.TemplateVersion
templateVersionParameters []database.TemplateVersionParameter
templateVersionVariables []database.TemplateVersionVariable
templates []database.Template
templates []database.TemplateTable
workspaceAgents []database.WorkspaceAgent
workspaceAgentMetadata []database.WorkspaceAgentMetadatum
workspaceAgentLogs []database.WorkspaceAgentStartupLog
Expand Down Expand Up @@ -446,12 +446,37 @@ func (q *FakeQuerier) getLatestWorkspaceBuildByWorkspaceIDNoLock(_ context.Conte
func (q *FakeQuerier) getTemplateByIDNoLock(_ context.Context, id uuid.UUID) (database.Template, error) {
for _, template := range q.templates {
if template.ID == id {
return template.DeepCopy(), nil
return q.templateWithUserNoLock(template), nil
}
}
return database.Template{}, sql.ErrNoRows
}

func (q *FakeQuerier) templatesWithUserNoLock(tpl []database.TemplateTable) []database.Template {
cpy := make([]database.Template, 0, len(tpl))
for _, t := range tpl {
cpy = append(cpy, q.templateWithUserNoLock(t))
}
return cpy
}

func (q *FakeQuerier) templateWithUserNoLock(tpl database.TemplateTable) database.Template {
var user database.User
for _, _user := range q.users {
if _user.ID == tpl.CreatedBy {
user = _user
break
}
}
var withUser database.Template
// This is a cheeky way to copy the fields over without explicitly listing them all.
d, _ := json.Marshal(tpl)
_ = json.Unmarshal(d, &withUser)
withUser.CreatedByUsername = user.Username
withUser.CreatedByAvatarURL = user.AvatarURL.String
return withUser
}

func (q *FakeQuerier) getTemplateVersionByIDNoLock(_ context.Context, templateVersionID uuid.UUID) (database.TemplateVersion, error) {
for _, templateVersion := range q.templateVersions {
if templateVersion.ID != templateVersionID {
Expand Down Expand Up @@ -1869,7 +1894,7 @@ func (q *FakeQuerier) GetTemplateByOrganizationAndName(_ context.Context, arg da
if template.Deleted != arg.Deleted {
continue
}
return template.DeepCopy(), nil
return q.templateWithUserNoLock(template), nil
}
return database.Template{}, sql.ErrNoRows
}
Expand Down Expand Up @@ -2092,17 +2117,14 @@ func (q *FakeQuerier) GetTemplates(_ context.Context) ([]database.Template, erro
defer q.mutex.RUnlock()

templates := slices.Clone(q.templates)
for i := range templates {
templates[i] = templates[i].DeepCopy()
}
slices.SortFunc(templates, func(i, j database.Template) bool {
slices.SortFunc(templates, func(i, j database.TemplateTable) bool {
if i.Name != j.Name {
return i.Name < j.Name
}
return i.ID.String() < j.ID.String()
})

return templates, nil
return q.templatesWithUserNoLock(templates), nil
}

func (q *FakeQuerier) GetTemplatesWithFilter(ctx context.Context, arg database.GetTemplatesWithFilterParams) ([]database.Template, error) {
Expand Down Expand Up @@ -3436,16 +3458,16 @@ func (q *FakeQuerier) InsertReplica(_ context.Context, arg database.InsertReplic
return replica, nil
}

func (q *FakeQuerier) InsertTemplate(_ context.Context, arg database.InsertTemplateParams) (database.Template, error) {
func (q *FakeQuerier) InsertTemplate(_ context.Context, arg database.InsertTemplateParams) error {
if err := validateDatabaseType(arg); err != nil {
return database.Template{}, err
return err
}

q.mutex.Lock()
defer q.mutex.Unlock()

//nolint:gosimple
template := database.Template{
template := database.TemplateTable{
ID: arg.ID,
CreatedAt: arg.CreatedAt,
UpdatedAt: arg.UpdatedAt,
Expand All @@ -3464,7 +3486,7 @@ func (q *FakeQuerier) InsertTemplate(_ context.Context, arg database.InsertTempl
AllowUserAutostop: true,
}
q.templates = append(q.templates, template)
return template.DeepCopy(), nil
return nil
}

func (q *FakeQuerier) InsertTemplateVersion(_ context.Context, arg database.InsertTemplateVersionParams) (database.TemplateVersion, error) {
Expand Down Expand Up @@ -4172,9 +4194,9 @@ func (q *FakeQuerier) UpdateReplica(_ context.Context, arg database.UpdateReplic
return database.Replica{}, sql.ErrNoRows
}

func (q *FakeQuerier) UpdateTemplateACLByID(_ context.Context, arg database.UpdateTemplateACLByIDParams) (database.Template, error) {
func (q *FakeQuerier) UpdateTemplateACLByID(_ context.Context, arg database.UpdateTemplateACLByIDParams) error {
if err := validateDatabaseType(arg); err != nil {
return database.Template{}, err
return err
}

q.mutex.Lock()
Expand All @@ -4186,11 +4208,11 @@ func (q *FakeQuerier) UpdateTemplateACLByID(_ context.Context, arg database.Upda
template.UserACL = arg.UserACL

q.templates[i] = template
return template.DeepCopy(), nil
return nil
}
}

return database.Template{}, sql.ErrNoRows
return sql.ErrNoRows
}

func (q *FakeQuerier) UpdateTemplateActiveVersionByID(_ context.Context, arg database.UpdateTemplateActiveVersionByIDParams) error {
Expand Down Expand Up @@ -4233,9 +4255,9 @@ func (q *FakeQuerier) UpdateTemplateDeletedByID(_ context.Context, arg database.
return sql.ErrNoRows
}

func (q *FakeQuerier) UpdateTemplateMetaByID(_ context.Context, arg database.UpdateTemplateMetaByIDParams) (database.Template, error) {
func (q *FakeQuerier) UpdateTemplateMetaByID(_ context.Context, arg database.UpdateTemplateMetaByIDParams) error {
if err := validateDatabaseType(arg); err != nil {
return database.Template{}, err
return err
}

q.mutex.Lock()
Expand All @@ -4251,15 +4273,15 @@ func (q *FakeQuerier) UpdateTemplateMetaByID(_ context.Context, arg database.Upd
tpl.Description = arg.Description
tpl.Icon = arg.Icon
q.templates[idx] = tpl
return tpl.DeepCopy(), nil
return nil
}

return database.Template{}, sql.ErrNoRows
return sql.ErrNoRows
}

func (q *FakeQuerier) UpdateTemplateScheduleByID(_ context.Context, arg database.UpdateTemplateScheduleByIDParams) (database.Template, error) {
func (q *FakeQuerier) UpdateTemplateScheduleByID(_ context.Context, arg database.UpdateTemplateScheduleByIDParams) error {
if err := validateDatabaseType(arg); err != nil {
return database.Template{}, err
return err
}

q.mutex.Lock()
Expand All @@ -4278,10 +4300,10 @@ func (q *FakeQuerier) UpdateTemplateScheduleByID(_ context.Context, arg database
tpl.InactivityTTL = arg.InactivityTTL
tpl.LockedTTL = arg.LockedTTL
q.templates[idx] = tpl
return tpl.DeepCopy(), nil
return nil
}

return database.Template{}, sql.ErrNoRows
return sql.ErrNoRows
}

func (q *FakeQuerier) UpdateTemplateVersionByID(_ context.Context, arg database.UpdateTemplateVersionByIDParams) (database.TemplateVersion, error) {
Expand Down Expand Up @@ -4984,7 +5006,8 @@ func (q *FakeQuerier) GetAuthorizedTemplates(ctx context.Context, arg database.G
}

var templates []database.Template
for _, template := range q.templates {
for _, templateTable := range q.templates {
template := q.templateWithUserNoLock(templateTable)
if prepared != nil && prepared.Authorize(ctx, template.RBACObject()) != nil {
continue
}
Expand Down Expand Up @@ -5012,7 +5035,7 @@ func (q *FakeQuerier) GetAuthorizedTemplates(ctx context.Context, arg database.G
continue
}
}
templates = append(templates, template.DeepCopy())
templates = append(templates, template)
}
if len(templates) > 0 {
slices.SortFunc(templates, func(i, j database.Template) bool {
Expand All @@ -5031,7 +5054,7 @@ func (q *FakeQuerier) GetTemplateGroupRoles(_ context.Context, id uuid.UUID) ([]
q.mutex.RLock()
defer q.mutex.RUnlock()

var template database.Template
var template database.TemplateTable
for _, t := range q.templates {
if t.ID == id {
template = t
Expand Down Expand Up @@ -5068,7 +5091,7 @@ func (q *FakeQuerier) GetTemplateUserRoles(_ context.Context, id uuid.UUID) ([]d
q.mutex.RLock()
defer q.mutex.RUnlock()

var template database.Template
var template database.TemplateTable
for _, t := range q.templates {
if t.ID == id {
template = t
Expand Down
8 changes: 6 additions & 2 deletions coderd/database/dbgen/dbgen.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,9 @@ func AuditLog(t testing.TB, db database.Store, seed database.AuditLog) database.
}

func Template(t testing.TB, db database.Store, seed database.Template) database.Template {
template, err := db.InsertTemplate(genCtx, database.InsertTemplateParams{
ID: takeFirst(seed.ID, uuid.New()),
id := takeFirst(seed.ID, uuid.New())
err := db.InsertTemplate(genCtx, database.InsertTemplateParams{
ID: id,
CreatedAt: takeFirst(seed.CreatedAt, database.Now()),
UpdatedAt: takeFirst(seed.UpdatedAt, database.Now()),
OrganizationID: takeFirst(seed.OrganizationID, uuid.New()),
Expand All @@ -79,6 +80,9 @@ func Template(t testing.TB, db database.Store, seed database.Template) database.
AllowUserCancelWorkspaceJobs: seed.AllowUserCancelWorkspaceJobs,
})
require.NoError(t, err, "insert template")

template, err := db.GetTemplateByID(context.Background(), id)
require.NoError(t, err, "get template")
return template
}

Expand Down
24 changes: 12 additions & 12 deletions coderd/database/dbmetrics/dbmetrics.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit aceedef

Please sign in to comment.