Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 16 additions & 24 deletions adapter.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
"gopkg.in/mgo.v2"
)

// CasbinRule represents a rule in Casbin.
type CasbinRule struct {
PType string
V0 string
Expand All @@ -34,34 +35,22 @@ type CasbinRule struct {

// Adapter represents the MongoDB adapter for policy storage.
type Adapter struct {
url string
dbSpecified bool
session *mgo.Session
collection *mgo.Collection
url string
session *mgo.Session
collection *mgo.Collection
}

// finalizer is the destructor for Adapter.
func finalizer(a *Adapter) {
a.close()
}

// NewAdapter is the constructor for Adapter.
// dbSpecified is an optional bool parameter. The default value is false.
// It's up to whether you have specified an existing DB in dataSourceName.
// If dbSpecified == true, you need to make sure the DB in dataSourceName exists.
// If dbSpecified == false, the adapter will automatically create a DB named "casbin".
func NewAdapter(url string, dbSpecified ...bool) *Adapter {
// NewAdapter is the constructor for Adapter. If database name is not provided
// in the Mongo URL, 'casbin' will be used as database name.
func NewAdapter(url string) *Adapter {
a := &Adapter{}
a.url = url

if len(dbSpecified) == 0 {
a.dbSpecified = false
} else if len(dbSpecified) == 1 {
a.dbSpecified = dbSpecified[0]
} else {
panic(errors.New("invalid parameter: dbSpecified"))
}

// Open the DB, create it if not existed.
a.open()

Expand Down Expand Up @@ -132,18 +121,21 @@ func (a *Adapter) createIndice() {
}

func (a *Adapter) open() {
session, err := mgo.Dial(a.url)
dI, err := mgo.ParseURL(a.url)
if err != nil {
panic(err)
}

var db *mgo.Database
if a.dbSpecified {
db = session.DB("")
} else {
db = session.DB("casbin")
if dI.Database == "" {
dI.Database = "casbin"
}

session, err := mgo.DialWithInfo(dI)
if err != nil {
panic(err)
}

db := session.DB(dI.Database)
collection := db.C("casbin_rule")

a.session = session
Expand Down