Skip to content

Commit

Permalink
fix(acl): fix wrong type name (#560)
Browse files Browse the repository at this point in the history
Because

- Wrong type name for model

This commit

- fix model acl type name
- add missing users pinning in `repository`
  • Loading branch information
heiruwu committed May 7, 2024
1 parent 7b17393 commit 89d09a5
Show file tree
Hide file tree
Showing 8 changed files with 95 additions and 40 deletions.
2 changes: 1 addition & 1 deletion cmd/main/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ func main() {
}
aclClient := acl.NewACLClient(fgaClient, fgaReplicaClient, redisClient)

repo := repository.NewRepository(db)
repo := repository.NewRepository(db, redisClient)

serv := service.NewService(repo, mgmtPublicServiceClient, mgmtPrivateServiceClient, artifactPrivateServiceClient, redisClient, temporalClient, rayService, &aclClient)

Expand Down
2 changes: 1 addition & 1 deletion cmd/worker/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ func main() {
initTemporalNamespace(ctx, tempClient)
}

cw := modelWorker.NewWorker(repository.NewRepository(db), redisClient, rayService)
cw := modelWorker.NewWorker(repository.NewRepository(db, redisClient), redisClient, rayService)

w := worker.New(tempClient, modelWorker.TaskQueue, worker.Options{})

Expand Down
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ require (
gorm.io/datatypes v1.1.0
gorm.io/driver/postgres v1.4.5
gorm.io/gorm v1.25.2
gorm.io/plugin/dbresolver v1.5.1
)

require (
Expand Down
3 changes: 3 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -2825,6 +2825,7 @@ gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gorm.io/datatypes v1.1.0 h1:EVp1Z28N4ACpYFK1nHboEIJGIFfjY7vLeieDk8jSHJA=
gorm.io/datatypes v1.1.0/go.mod h1:SH2K9R+2RMjuX1CkCONrPwoe9JzVv2hkQvEu4bXGojE=
gorm.io/driver/mysql v1.4.3/go.mod h1:sSIebwZAVPiT+27jK9HIwvsqOGKx3YMPmrA3mBJR10c=
gorm.io/driver/mysql v1.4.4 h1:MX0K9Qvy0Na4o7qSC/YI7XxqUw5KDw01umqgID+svdQ=
gorm.io/driver/mysql v1.4.4/go.mod h1:BCg8cKI+R0j/rZRQxeKis/forqRwRSYOR8OM3Wo6hOM=
gorm.io/driver/postgres v1.0.8/go.mod h1:4eOzrI1MUfm6ObJU/UcmbXyiHSs8jSwH95G5P5dxcAg=
Expand All @@ -2840,6 +2841,8 @@ gorm.io/gorm v1.23.8/go.mod h1:l2lP/RyAtc1ynaTjFksBde/O8v9oOGIApu2/xRitmZk=
gorm.io/gorm v1.24.1-0.20221019064659-5dd2bb482755/go.mod h1:DVrVomtaYTbqs7gB/x2uVvqnXzv0nqjB396B8cG4dBA=
gorm.io/gorm v1.25.2 h1:gs1o6Vsa+oVKG/a9ElL3XgyGfghFfkKA2SInQaCyMho=
gorm.io/gorm v1.25.2/go.mod h1:L4uxeKpfBml98NYqVqwAdmV1a2nBtAec/cf3fpucW/k=
gorm.io/plugin/dbresolver v1.5.1 h1:s9Dj9f7r+1rE3nx/Ywzc85nXptUEaeOO0pt27xdopM8=
gorm.io/plugin/dbresolver v1.5.1/go.mod h1:l4Cn87EHLEYuqUncpEeTC2tTJQkjngPSD+lo8hIvcT0=
gotest.tools v2.2.0+incompatible h1:VsBPFP1AI068pPrMxtb/S8Zkgf9xEmTLJjfM+P5UIEo=
gotest.tools v2.2.0+incompatible/go.mod h1:DsYFclhRJ6vuDpmuTbkuFWG+y2sxOXAzmJt81HFBacw=
gotest.tools/v3 v3.0.2/go.mod h1:3SzNCllyD9/Y+b5r9JIKQ474KzkZyqLqEfYqMsX94Bk=
Expand Down
4 changes: 2 additions & 2 deletions pkg/acl/acl.go
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ func (c *ACLClient) SetModelPermission(ctx context.Context, modelUID uuid.UUID,
{
User: user,
Relation: role,
Object: fmt.Sprintf("model:%s", modelUID.String()),
Object: fmt.Sprintf("model_:%s", modelUID.String()),
},
},
},
Expand All @@ -180,7 +180,7 @@ func (c *ACLClient) DeleteModelPermission(ctx context.Context, modelUID uuid.UUI
{
User: user,
Relation: role,
Object: fmt.Sprintf("model:%s", modelUID.String()),
Object: fmt.Sprintf("model_:%s", modelUID.String()),
},
},
},
Expand Down
2 changes: 0 additions & 2 deletions pkg/handler/public.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,6 @@ func (h *PublicHandler) ListModels(ctx context.Context, req *modelPB.ListModelsR
span,
logUUID.String(),
eventName,
custom_otel.SetEventResource(pbModels),
custom_otel.SetEventMessage(fmt.Sprintf("%s done", eventName)),
)))

Expand Down Expand Up @@ -296,7 +295,6 @@ func (h *PublicHandler) listNamespaceModels(ctx context.Context, req ListNamespa
span,
logUUID.String(),
eventName,
custom_otel.SetEventResource(pbModels),
custom_otel.SetEventMessage(fmt.Sprintf("%s done", eventName)),
)))

Expand Down
73 changes: 63 additions & 10 deletions pkg/repository/repository.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,22 @@ package repository

import (
"context"
"errors"
"fmt"
"time"

"github.com/go-redis/redis/v9"
"github.com/gofrs/uuid"
"go.einride.tech/aip/filtering"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"gorm.io/gorm"
"gorm.io/gorm/clause"
"gorm.io/plugin/dbresolver"

"github.com/instill-ai/model-backend/config"
"github.com/instill-ai/model-backend/internal/resource"
"github.com/instill-ai/model-backend/pkg/constant"
"github.com/instill-ai/model-backend/pkg/datamodel"
"github.com/instill-ai/x/paginate"
"github.com/instill-ai/x/sterr"
Expand Down Expand Up @@ -57,14 +63,33 @@ const DefaultPageSize = 10
const MaxPageSize = 100

type repository struct {
db *gorm.DB
db *gorm.DB
redisClient *redis.Client
}

func NewRepository(db *gorm.DB) Repository {
// NewRepository initiates a repository instance
func NewRepository(db *gorm.DB, redisClient *redis.Client) Repository {
return &repository{
db: db,
db: db,
redisClient: redisClient,
}
}
func (r *repository) checkPinnedUser(ctx context.Context, db *gorm.DB, _ string) *gorm.DB {
userUID := resource.GetRequestSingleHeader(ctx, constant.HeaderUserUIDKey)
// If the user is pinned, we will use the primary database for querying.
if !errors.Is(r.redisClient.Get(ctx, fmt.Sprintf("db_pin_user:%s:%s", userUID, "model")).Err(), redis.Nil) {
db = db.Clauses(dbresolver.Write)
}
return db
}

func (r *repository) pinUser(ctx context.Context, _ string) {
userUID := resource.GetRequestSingleHeader(ctx, constant.HeaderUserUIDKey)
// To solve the read-after-write inconsistency problem,
// we will direct the user to read from the primary database for a certain time frame
// to ensure that the data is synchronized from the primary DB to the replica DB.
_ = r.redisClient.Set(ctx, fmt.Sprintf("db_pin_user:%s:%s", userUID, "model"), time.Now(), time.Duration(config.Config.Database.Replica.ReplicationTimeFrame)*time.Second)
}

func (r *repository) listModels(ctx context.Context, where string, whereArgs []any, pageSize int64, pageToken string, isBasicView bool, filter filtering.Filter, uidAllowList []uuid.UUID, showDeleted bool) (models []*datamodel.Model, totalSize int64, nextPageToken string, err error) {

Expand Down Expand Up @@ -253,14 +278,22 @@ func (r *repository) GetModelByUIDAdmin(ctx context.Context, uid uuid.UUID, isBa
}

func (r *repository) CreateNamespaceModel(ctx context.Context, ownerPermalink string, model *datamodel.Model) error {
if result := r.db.Model(&datamodel.Model{}).Create(model); result.Error != nil {

r.pinUser(ctx, "model")
db := r.checkPinnedUser(ctx, r.db, "model")

if result := db.Model(&datamodel.Model{}).Create(model); result.Error != nil {
return result.Error
}
return nil
}

func (r *repository) UpdateNamespaceModelByID(ctx context.Context, ownerPermalink string, id string, model *datamodel.Model) error {
if result := r.db.Model(&datamodel.Model{}).

r.pinUser(ctx, "model")
db := r.checkPinnedUser(ctx, r.db, "model")

if result := db.Model(&datamodel.Model{}).
Where("(id = ? AND owner = ?)", id, ownerPermalink).
Updates(model); result.Error != nil {
return result.Error
Expand All @@ -271,7 +304,11 @@ func (r *repository) UpdateNamespaceModelByID(ctx context.Context, ownerPermalin
}

func (r *repository) UpdateNamespaceModelIDByID(ctx context.Context, ownerPermalink string, id string, newID string) error {
if result := r.db.Model(&datamodel.Model{}).

r.pinUser(ctx, "model")
db := r.checkPinnedUser(ctx, r.db, "model")

if result := db.Model(&datamodel.Model{}).
Where("(id = ? AND owner = ?)", id, ownerPermalink).
Update("id", newID); result.Error != nil {
return result.Error
Expand All @@ -282,7 +319,11 @@ func (r *repository) UpdateNamespaceModelIDByID(ctx context.Context, ownerPermal
}

func (r *repository) DeleteNamespaceModelByID(ctx context.Context, ownerPermalink string, id string) error {
result := r.db.Model(&datamodel.Model{}).

r.pinUser(ctx, "model")
db := r.checkPinnedUser(ctx, r.db, "model")

result := db.Model(&datamodel.Model{}).
Where("(id = ? AND owner = ?)", id, ownerPermalink).
Delete(&datamodel.Model{})

Expand All @@ -298,23 +339,35 @@ func (r *repository) DeleteNamespaceModelByID(ctx context.Context, ownerPermalin
}

func (r *repository) CreateModelPrediction(ctx context.Context, prediction *datamodel.ModelPrediction) error {
if result := r.db.Model(&datamodel.ModelPrediction{}).Create(&prediction); result.Error != nil {

r.pinUser(ctx, "model")
db := r.checkPinnedUser(ctx, r.db, "model")

if result := db.Model(&datamodel.ModelPrediction{}).Create(&prediction); result.Error != nil {
return result.Error
}

return nil
}

func (r *repository) CreateModelVersion(ctx context.Context, ownerPermalink string, version *datamodel.ModelVersion) error {
if result := r.db.Model(&datamodel.ModelVersion{}).Create(&version); result.Error != nil {

r.pinUser(ctx, "model")
db := r.checkPinnedUser(ctx, r.db, "model")

if result := db.Model(&datamodel.ModelVersion{}).Create(&version); result.Error != nil {
return result.Error
}

return nil
}

func (r *repository) DeleteModelVersion(ctx context.Context, ownerPermalink string, version *datamodel.ModelVersion) error {
result := r.db.Model(&datamodel.ModelVersion{}).

r.pinUser(ctx, "model")
db := r.checkPinnedUser(ctx, r.db, "model")

result := db.Model(&datamodel.ModelVersion{}).
Where("(name = ? AND version = ?)", version.Name, version.Version).
Delete(&datamodel.ModelVersion{})

Expand Down
Loading

0 comments on commit 89d09a5

Please sign in to comment.