Skip to content

Commit

Permalink
Add custom tablename support (#30)
Browse files Browse the repository at this point in the history
* Fix custom tablename support
  • Loading branch information
CyJaySong committed Jul 14, 2020
1 parent 0bb458f commit 0ca20dc
Showing 1 changed file with 72 additions and 21 deletions.
93 changes: 72 additions & 21 deletions adapter.go
Expand Up @@ -25,14 +25,22 @@ import (
"xorm.io/xorm"
)

func (the *CasbinRule) TableName() string {
if len(the.tableName) == 0 {
return "casbin_rule"
}
return the.tableName
}

type CasbinRule struct {
PType string `xorm:"varchar(100) index not null default ''"`
V0 string `xorm:"varchar(100) index not null default ''"`
V1 string `xorm:"varchar(100) index not null default ''"`
V2 string `xorm:"varchar(100) index not null default ''"`
V3 string `xorm:"varchar(100) index not null default ''"`
V4 string `xorm:"varchar(100) index not null default ''"`
V5 string `xorm:"varchar(100) index not null default ''"`
PType string `xorm:"varchar(100) index not null default ''"`
V0 string `xorm:"varchar(100) index not null default ''"`
V1 string `xorm:"varchar(100) index not null default ''"`
V2 string `xorm:"varchar(100) index not null default ''"`
V3 string `xorm:"varchar(100) index not null default ''"`
V4 string `xorm:"varchar(100) index not null default ''"`
V5 string `xorm:"varchar(100) index not null default ''"`
tableName string `xorm:"-" json:"-"`
}

// Adapter represents the Xorm adapter for policy storage.
Expand All @@ -42,6 +50,7 @@ type Adapter struct {
dbSpecified bool
isFiltered bool
engine *xorm.Engine
tableName string
}

type Filter struct {
Expand All @@ -68,9 +77,37 @@ func finalizer(a *Adapter) {
// If dbSpecified == true, you need to make sure the DB in dataSourceName exists.
// If dbSpecified == false, the adapter will automatically create a DB named "casbin".
func NewAdapter(driverName string, dataSourceName string, dbSpecified ...bool) (*Adapter, error) {
a := &Adapter{}
a.driverName = driverName
a.dataSourceName = dataSourceName
a := &Adapter{
driverName: driverName,
dataSourceName: dataSourceName,
}

if len(dbSpecified) == 0 {
a.dbSpecified = false
} else if len(dbSpecified) == 1 {
a.dbSpecified = dbSpecified[0]
} else {
return nil, errors.New("invalid parameter: dbSpecified")
}

// Open the DB, create it if not existed.
err := a.open()
if err != nil {
return nil, err
}

// Call the destructor when the object is released.
runtime.SetFinalizer(a, finalizer)

return a, nil
}

func NewAdapterWithTableName(driverName string, dataSourceName string, tableName string, dbSpecified ...bool) (*Adapter, error) {
a := &Adapter{
driverName: driverName,
dataSourceName: dataSourceName,
tableName: tableName,
}

if len(dbSpecified) == 0 {
a.dbSpecified = false
Expand Down Expand Up @@ -105,6 +142,20 @@ func NewAdapterByEngine(engine *xorm.Engine) (*Adapter, error) {
return a, nil
}

func NewAdapterByEngineWithTableName(engine *xorm.Engine, tableName string) (*Adapter, error) {
a := &Adapter{
engine: engine,
tableName: tableName,
}

err := a.createTable()
if err != nil {
return nil, err
}

return a, nil
}

func (a *Adapter) createDatabase() error {
var err error
var engine *xorm.Engine
Expand All @@ -121,15 +172,15 @@ func (a *Adapter) createDatabase() error {
if _, err = engine.Exec("CREATE DATABASE casbin"); err != nil {
// 42P04 is duplicate_database
if pqerr, ok := err.(*pq.Error); ok && pqerr.Code == "42P04" {
engine.Close()
_ = engine.Close()
return nil
}
}
} else if a.driverName != "sqlite3" {
_, err = engine.Exec("CREATE DATABASE IF NOT EXISTS casbin")
}
if err != nil {
engine.Close()
_ = engine.Close()
return err
}

Expand Down Expand Up @@ -178,11 +229,11 @@ func (a *Adapter) close() error {
}

func (a *Adapter) createTable() error {
return a.engine.Sync2(new(CasbinRule))
return a.engine.Sync2(&CasbinRule{tableName: a.tableName})
}

func (a *Adapter) dropTable() error {
return a.engine.DropTables(new(CasbinRule))
return a.engine.DropTables(&CasbinRule{tableName: a.tableName})
}

func loadPolicyLine(line *CasbinRule, model model.Model) {
Expand Down Expand Up @@ -220,8 +271,8 @@ func (a *Adapter) LoadPolicy(model model.Model) error {
return nil
}

func savePolicyLine(ptype string, rule []string) *CasbinRule {
line := &CasbinRule{PType: ptype}
func (a *Adapter) savePolicyLine(ptype string, rule []string) *CasbinRule {
line := &CasbinRule{PType: ptype, tableName: a.tableName}

l := len(rule)
if l > 0 {
Expand Down Expand Up @@ -261,14 +312,14 @@ func (a *Adapter) SavePolicy(model model.Model) error {

for ptype, ast := range model["p"] {
for _, rule := range ast.Policy {
line := savePolicyLine(ptype, rule)
line := a.savePolicyLine(ptype, rule)
lines = append(lines, line)
}
}

for ptype, ast := range model["g"] {
for _, rule := range ast.Policy {
line := savePolicyLine(ptype, rule)
line := a.savePolicyLine(ptype, rule)
lines = append(lines, line)
}
}
Expand All @@ -279,21 +330,21 @@ func (a *Adapter) SavePolicy(model model.Model) error {

// AddPolicy adds a policy rule to the storage.
func (a *Adapter) AddPolicy(sec string, ptype string, rule []string) error {
line := savePolicyLine(ptype, rule)
line := a.savePolicyLine(ptype, rule)
_, err := a.engine.Insert(line)
return err
}

// RemovePolicy removes a policy rule from the storage.
func (a *Adapter) RemovePolicy(sec string, ptype string, rule []string) error {
line := savePolicyLine(ptype, rule)
line := a.savePolicyLine(ptype, rule)
_, err := a.engine.Delete(line)
return err
}

// RemoveFilteredPolicy removes policy rules that match the filter from the storage.
func (a *Adapter) RemoveFilteredPolicy(sec string, ptype string, fieldIndex int, fieldValues ...string) error {
line := &CasbinRule{PType: ptype}
line := &CasbinRule{PType: ptype, tableName: a.tableName}

idx := fieldIndex + len(fieldValues)
if fieldIndex <= 0 && idx > 0 {
Expand Down

0 comments on commit 0ca20dc

Please sign in to comment.