diff --git a/adapter.go b/adapter.go index f61f88d..b3ed4d1 100644 --- a/adapter.go +++ b/adapter.go @@ -20,7 +20,6 @@ import ( "github.com/casbin/casbin/model" _ "github.com/go-sql-driver/mysql" // This is for MySQL initialization. - ) // DBAdapter represents the database adapter for policy persistence, can load policy from database or save policy to database. @@ -39,18 +38,26 @@ func NewDBAdapter(driverName string, dataSourceName string) *DBAdapter { return &a } -func (a *DBAdapter) open() { +func (a *DBAdapter) createDatabase() error { db, err := sql.Open(a.driverName, a.dataSourceName) if err != nil { - panic(err) + return err } + defer db.Close() _, err = db.Exec("CREATE DATABASE IF NOT EXISTS casbin") if err != nil { + return err + } + return nil +} + +func (a *DBAdapter) open() { + if err := a.createDatabase(); err != nil { panic(err) } - db, err = sql.Open("mysql", a.dataSourceName+"casbin") + db, err := sql.Open(a.driverName, a.dataSourceName+"casbin") if err != nil { panic(err) } @@ -137,17 +144,17 @@ func (a *DBAdapter) LoadPolicy(model model.Model) { } } -func (a *DBAdapter) writeTableLine(ptype string, rule []string) { - line := "'" + ptype + "'" - for i := range rule { - line += ",'" + rule[i] + "'" +func (a *DBAdapter) writeTableLine(stm *sql.Stmt, ptype string, rule []string) { + params := make([]interface{}, 0, 5) + params = append(params, ptype) + for _, v := range rule { + params = append(params, v) } - for i := 0; i < 4-len(rule); i++ { - line += ",''" + need := 5 - len(params) + for i := 0; i < need; i++ { + params = append(params, "") } - - _, err := a.db.Exec("insert into policy values(" + line + ")") - if err != nil { + if _, err := stm.Exec(params...); err != nil { panic(err) } } @@ -160,15 +167,21 @@ func (a *DBAdapter) SavePolicy(model model.Model) { a.dropTable() a.createTable() + stm, err := a.db.Prepare("insert into policy values(?, ?, ?, ?, ?)") + if err != nil { + panic(err) + } + defer stm.Close() + for ptype, ast := range model["p"] { for _, rule := range ast.Policy { - a.writeTableLine(ptype, rule) + a.writeTableLine(stm, ptype, rule) } } for ptype, ast := range model["g"] { for _, rule := range ast.Policy { - a.writeTableLine(ptype, rule) + a.writeTableLine(stm, ptype, rule) } } }