Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add organization ID, and remove FKs and sequences #2896

Merged
merged 6 commits into from
Aug 11, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
7 changes: 1 addition & 6 deletions internal/access/access_test.go
Expand Up @@ -41,13 +41,8 @@ func setupAccessTestContext(t *testing.T) (*gin.Context, *gorm.DB, *models.Provi
c, _ := gin.CreateTestContext(httptest.NewRecorder())
c.Set("db", db)

err := data.SaveSettings(db, &models.Settings{
pdevine marked this conversation as resolved.
Show resolved Hide resolved
LengthMin: 8,
})
assert.NilError(t, err)

admin := &models.Identity{Name: "admin@example.com"}
err = data.CreateIdentity(db, admin)
err := data.CreateIdentity(db, admin)
assert.NilError(t, err)

c.Set("identity", admin)
Expand Down
5 changes: 5 additions & 0 deletions internal/access/credential_test.go
Expand Up @@ -86,6 +86,11 @@ func TestUpdateCredentials(t *testing.T) {
_, err = CreateCredential(c, *user)
assert.NilError(t, err)

err = data.SaveSettings(db, &models.Settings{
LengthMin: 8,
})
assert.NilError(t, err)

t.Run("Update user credentials IS single use password", func(t *testing.T) {
err := UpdateCredential(c, user, "newPassword")
assert.NilError(t, err)
Expand Down
7 changes: 6 additions & 1 deletion internal/access/passwordreset_test.go
Expand Up @@ -13,12 +13,17 @@ import (
func TestPasswordResetFlow(t *testing.T) {
c, db, _ := setupAccessTestContext(t)

err := data.SaveSettings(db, &models.Settings{
LengthMin: 8,
})
assert.NilError(t, err)

user := &models.Identity{
Name: "joe@example.com",
}

// setup user
err := CreateIdentity(c, user)
err = CreateIdentity(c, user)
assert.NilError(t, err)

err = data.CreateCredential(db, &models.Credential{
Expand Down
8 changes: 4 additions & 4 deletions internal/access/signup.go
Expand Up @@ -20,22 +20,22 @@ func SignupEnabled(c *gin.Context) (bool, error) {
db := getDB(c)

// use Unscoped because deleting identities, providers or grants should not re-enable signup
identities, err := data.Count[models.Identity](db.Unscoped(), data.NotName(models.InternalInfraConnectorIdentityName))
identities, err := data.GlobalCount[models.Identity](db.Unscoped(), data.NotName(models.InternalInfraConnectorIdentityName))
if err != nil {
return false, err
}

providers, err := data.Count[models.Provider](db.Unscoped(), data.NotProviderKind(models.ProviderKindInfra))
providers, err := data.GlobalCount[models.Provider](db.Unscoped(), data.NotProviderKind(models.ProviderKindInfra))
if err != nil {
return false, err
}

grants, err := data.Count[models.Grant](db.Unscoped(), data.NotPrivilege(models.InfraConnectorRole))
grants, err := data.GlobalCount[models.Grant](db.Unscoped(), data.NotPrivilege(models.InfraConnectorRole))
if err != nil {
return false, err
}

accessKeys, err := data.Count[models.AccessKey](db.Unscoped())
accessKeys, err := data.GlobalCount[models.AccessKey](db.Unscoped())
if err != nil {
return false, err
}
Expand Down
90 changes: 59 additions & 31 deletions internal/server/data/data.go
Expand Up @@ -105,6 +105,10 @@ func get[T models.Modelable](db *gorm.DB, selectors ...SelectorFunc) (*T, error)
}

result := new(T)
if isOrgMember(result) {
db = ByOrgID(OrgFromContext(db.Statement.Context).ID)(db)
BruceMacD marked this conversation as resolved.
Show resolved Hide resolved
}

if err := db.Model((*T)(nil)).First(result).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, internal.ErrNotFound
Expand All @@ -116,11 +120,34 @@ func get[T models.Modelable](db *gorm.DB, selectors ...SelectorFunc) (*T, error)
return result, nil
}

func setOrg(db *gorm.DB, model any) {
member, ok := model.(orgMember)
if !ok {
return
}
Comment on lines +124 to +127
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When we are ready to, we can make SetOrganizationID a required method, and implement no-op implementations on the other tables. That gives us the same safeguards as putting it in the base model, without adding unused organization_id fields.


org := OrgFromContext(db.Statement.Context)
member.SetOrganizationID(org.ID)
}

type orgMember interface {
IsOrganizationMember()
SetOrganizationID(id uid.ID)
}
ssoroka marked this conversation as resolved.
Show resolved Hide resolved

func isOrgMember(model any) bool {
_, ok := model.(orgMember)
return ok
}

func list[T models.Modelable](db *gorm.DB, p *models.Pagination, selectors ...SelectorFunc) ([]T, error) {
db = db.Order(getDefaultSortFromType((*T)(nil)))
for _, selector := range selectors {
db = selector(db)
}
if isOrgMember(new(T)) {
db = ByOrgID(OrgFromContext(db.Statement.Context).ID)(db)
}

if p != nil {
var count int64
Expand All @@ -141,11 +168,14 @@ func list[T models.Modelable](db *gorm.DB, p *models.Pagination, selectors ...Se
}

func save[T models.Modelable](db *gorm.DB, model *T) error {
setOrg(db, model)
err := db.Save(model).Error
return handleError(err)
}

func add[T models.Modelable](db *gorm.DB, model *T) error {
setOrg(db, model)

var err error
if db.Name() == "postgres" {
// failures on postgres need to be rolled back in order to
Expand Down Expand Up @@ -247,18 +277,25 @@ func handleError(err error) error {
}

func delete[T models.Modelable](db *gorm.DB, id uid.ID) error {
if isOrgMember(new(T)) {
db = ByOrgID(OrgFromContext(db.Statement.Context).ID)(db)
}
return db.Delete(new(T), id).Error
}

func deleteAll[T models.Modelable](db *gorm.DB, selectors ...SelectorFunc) error {
for _, selector := range selectors {
db = selector(db)
}
if isOrgMember(new(T)) {
db = ByOrgID(OrgFromContext(db.Statement.Context).ID)(db)
}

return db.Delete(new(T)).Error
}

func Count[T models.Modelable](db *gorm.DB, selectors ...SelectorFunc) (int64, error) {
// GlobalCount gives the count of all records, not scoped by org.
func GlobalCount[T models.Modelable](db *gorm.DB, selectors ...SelectorFunc) (int64, error) {
for _, selector := range selectors {
db = selector(db)
}
Expand All @@ -271,53 +308,44 @@ func Count[T models.Modelable](db *gorm.DB, selectors ...SelectorFunc) (int64, e
return count, nil
}

var infraProviderCache *models.Provider

// InfraProvider is a lazy-loaded cached reference to the infra provider. The
// cache lasts for the entire lifetime of the process, so any test or test
// helper that calls InfraProvider must call InvalidateCache to clean up.
func InfraProvider(db *gorm.DB) *models.Provider {
if infraProviderCache == nil {
infra, err := get[models.Provider](db, ByProviderKind(models.ProviderKindInfra))
if err != nil {
if errors.Is(err, internal.ErrNotFound) {
p := &models.Provider{Name: models.InternalInfraProviderName, Kind: models.ProviderKindInfra}
if err := add(db, p); err != nil {
logging.L.Panic().Err(err).Msg("failed to create infra provider")
}
return p
org := OrgFromContext(db.Statement.Context)
infra, err := get[models.Provider](db, ByProviderKind(models.ProviderKindInfra), ByOrgID(org.ID))
BruceMacD marked this conversation as resolved.
Show resolved Hide resolved
if err != nil {
if errors.Is(err, internal.ErrNotFound) {
p := &models.Provider{
BruceMacD marked this conversation as resolved.
Show resolved Hide resolved
Name: models.InternalInfraProviderName,
Kind: models.ProviderKindInfra,
OrganizationMember: models.OrganizationMember{OrganizationID: org.ID},
}
if err := add(db, p); err != nil {
logging.L.Panic().Err(err).Msg("failed to create infra provider")
}
logging.L.Panic().Err(err).Msg("failed to retrieve infra provider")
return nil // unreachable, the line above panics
return p
}

infraProviderCache = infra
logging.L.Panic().Err(err).Msg("failed to retrieve infra provider")
return nil // unreachable, the line above panics
}

return infraProviderCache
return infra
}

var infraConnectorCache *models.Identity

// InfraConnectorIdentity is a lazy-loaded reference to the connector identity.
// The cache lasts for the entire lifetime of the process, so any test or test
// helper that calls InfraConnectorIdentity must call InvalidateCache to clean up.
func InfraConnectorIdentity(db *gorm.DB) *models.Identity {
if infraConnectorCache == nil {
connector, err := GetIdentity(db, ByName(models.InternalInfraConnectorIdentityName))
if err != nil {
logging.L.Panic().Err(err).Msg("failed to retrieve connector identity")
return nil // unreachable, the line above panics
}

infraConnectorCache = connector
org := OrgFromContext(db.Statement.Context)
connector, err := GetIdentity(db, ByName(models.InternalInfraConnectorIdentityName), ByOrgID(org.ID))
if err != nil {
logging.L.Panic().Err(err).Msg("failed to retrieve connector identity")
return nil // unreachable, the line above panics
}

return infraConnectorCache
return connector
}

// InvalidateCache is used to clear references to frequently used resources
func InvalidateCache() {
infraProviderCache = nil
infraConnectorCache = nil
}
13 changes: 13 additions & 0 deletions internal/server/data/data_test.go
Expand Up @@ -214,3 +214,16 @@ func TestCreateTransactionError(t *testing.T) {
assert.NilError(t, err)
})
}

func TestSetOrg(t *testing.T) {
model := &models.AccessKey{}
org := &models.Organization{}
org.ID = 123456

db := &gorm.DB{}
db.Statement = &gorm.Statement{
Context: WithOrg(context.Background(), org),
}
setOrg(db, model)
assert.Equal(t, model.OrganizationID, uid.ID(123456))
}