Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@ func main() {
// If it doesn't exist, the adapter will create it automatically.
// a := mongodbadapter.NewAdapter("127.0.0.1:27017/abc")

// Or you can appoint the table name "rule" like this.
// a := mongodbadapter.NewAdapter("127.0.0.1:27017/abc", "rule")

e := casbin.NewEnforcer("examples/rbac_model.conf", a)

// Load the policy from DB.
Expand Down
21 changes: 16 additions & 5 deletions adapter.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,11 +49,18 @@ func finalizer(a *adapter) {

// NewAdapter is the constructor for Adapter. If database name is not provided
// in the Mongo URL, 'casbin' will be used as database name.
func NewAdapter(url string) persist.Adapter {
// Can use provided table name.
func NewAdapter(url string, tableName ...string) persist.Adapter {
a := &adapter{url: url}

// Open the DB, create it if not existed.
a.open()
if len(tableName) == 0 {
a.open()
} else if len(tableName) == 1 {
a.open(tableName[0])
} else {
panic(errors.New("invalid parameter: tableName"))
}

// Call the destructor when the object is released.
runtime.SetFinalizer(a, finalizer)
Expand All @@ -68,7 +75,8 @@ func NewFilteredAdapter(url string) persist.FilteredAdapter {
return NewAdapter(url).(*adapter)
}

func (a *adapter) open() {
func (a *adapter) open(tableName ...string) {
var collection *mgo.Collection
dI, err := mgo.ParseURL(a.url)
if err != nil {
panic(err)
Expand All @@ -91,11 +99,14 @@ func (a *adapter) open() {
}

db := session.DB(dI.Database)
collection := db.C("casbin_rule")
if len(tableName) == 0 {
collection = db.C("casbin_rule")
} else {
collection = db.C(tableName[0])
}

a.session = session
a.collection = collection

indexes := []string{"ptype", "v0", "v1", "v2", "v3", "v4", "v5"}
for _, k := range indexes {
if err := a.collection.EnsureIndexKey(k); err != nil {
Expand Down