Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 28 additions & 15 deletions adapter.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ import (

"github.com/casbin/casbin/model"
_ "github.com/go-sql-driver/mysql" // This is for MySQL initialization.

)

// DBAdapter represents the database adapter for policy persistence, can load policy from database or save policy to database.
Expand All @@ -39,18 +38,26 @@ func NewDBAdapter(driverName string, dataSourceName string) *DBAdapter {
return &a
}

func (a *DBAdapter) open() {
func (a *DBAdapter) createDatabase() error {
db, err := sql.Open(a.driverName, a.dataSourceName)
if err != nil {
panic(err)
return err
}
defer db.Close()

_, err = db.Exec("CREATE DATABASE IF NOT EXISTS casbin")
if err != nil {
return err
}
return nil
}

func (a *DBAdapter) open() {
if err := a.createDatabase(); err != nil {
panic(err)
}

db, err = sql.Open("mysql", a.dataSourceName+"casbin")
db, err := sql.Open(a.driverName, a.dataSourceName+"casbin")
if err != nil {
panic(err)
}
Expand Down Expand Up @@ -137,17 +144,17 @@ func (a *DBAdapter) LoadPolicy(model model.Model) {
}
}

func (a *DBAdapter) writeTableLine(ptype string, rule []string) {
line := "'" + ptype + "'"
for i := range rule {
line += ",'" + rule[i] + "'"
func (a *DBAdapter) writeTableLine(stm *sql.Stmt, ptype string, rule []string) {
params := make([]interface{}, 0, 5)
params = append(params, ptype)
for _, v := range rule {
params = append(params, v)
}
for i := 0; i < 4-len(rule); i++ {
line += ",''"
need := 5 - len(params)
for i := 0; i < need; i++ {
params = append(params, "")
}

_, err := a.db.Exec("insert into policy values(" + line + ")")
if err != nil {
if _, err := stm.Exec(params...); err != nil {
panic(err)
}
}
Expand All @@ -160,15 +167,21 @@ func (a *DBAdapter) SavePolicy(model model.Model) {
a.dropTable()
a.createTable()

stm, err := a.db.Prepare("insert into policy values(?, ?, ?, ?, ?)")
if err != nil {
panic(err)
}
defer stm.Close()

for ptype, ast := range model["p"] {
for _, rule := range ast.Policy {
a.writeTableLine(ptype, rule)
a.writeTableLine(stm, ptype, rule)
}
}

for ptype, ast := range model["g"] {
for _, rule := range ast.Policy {
a.writeTableLine(ptype, rule)
a.writeTableLine(stm, ptype, rule)
}
}
}