Skip to content

Commit

Permalink
Added support for gorm/v2 (#36)
Browse files Browse the repository at this point in the history
* Added support for gorm/v2

Signed-off-by: l1ghtman2k <aibek.zhil@gmail.com>

* Add gorm-adapter/v3 to the documentation

Signed-off-by: l1ghtman2k <aibek.zhil@gmail.com>

* Return a modified finalizer

Signed-off-by: l1ghtman2k <aibek.zhil@gmail.com>
  • Loading branch information
L1ghtman2k committed Jul 22, 2020
1 parent 9ebbe7e commit 57a0c40
Show file tree
Hide file tree
Showing 5 changed files with 222 additions and 163 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ Gorm Adapter
[![Gitter](https://badges.gitter.im/Join%20Chat.svg)](https://gitter.im/casbin/lobby)
[![Sourcegraph](https://sourcegraph.com/github.com/casbin/gorm-adapter/-/badge.svg)](https://sourcegraph.com/github.com/casbin/gorm-adapter?badge)

Gorm Adapter is the [Gorm](https://github.com/jinzhu/gorm) adapter for [Casbin](https://github.com/casbin/casbin). With this library, Casbin can load policy from Gorm supported database or save policy to it.
Gorm Adapter is the [Gorm](https://gorm.io/gorm) adapter for [Casbin](https://github.com/casbin/casbin). With this library, Casbin can load policy from Gorm supported database or save policy to it.

Based on [Officially Supported Databases](http://jinzhu.me/gorm/database.html), The current supported databases are:

Expand All @@ -31,7 +31,7 @@ package main

import (
"github.com/casbin/casbin/v2"
gormadapter "github.com/casbin/gorm-adapter/v2"
gormadapter "github.com/casbin/gorm-adapter/v3"
_ "github.com/go-sql-driver/mysql"
)

Expand Down
68 changes: 39 additions & 29 deletions adapter.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,16 @@ package gormadapter

import (
"errors"
"runtime"
"strings"

"github.com/casbin/casbin/v2/model"
"github.com/casbin/casbin/v2/persist"
"github.com/jinzhu/gorm"
"github.com/lib/pq"
"gorm.io/driver/mysql"
"gorm.io/driver/postgres"
"gorm.io/driver/sqlite"
"gorm.io/driver/sqlserver"
"gorm.io/gorm"
"runtime"
"strings"
)

var tablePrefix string
Expand Down Expand Up @@ -64,7 +67,11 @@ type Adapter struct {

// finalizer is the destructor for Adapter.
func finalizer(a *Adapter) {
err := a.db.Close()
sqlDB, err := a.db.DB()
if err != nil {
panic(err)
}
err = sqlDB.Close()
if err != nil {
panic(err)
}
Expand Down Expand Up @@ -131,74 +138,77 @@ func NewAdapterByDB(db *gorm.DB) (*Adapter, error) {
return a, nil
}

func (a *Adapter) createDatabase() error {
func openDBConnection(driverName, dataSourceName string) (*gorm.DB, error) {
var err error
var db *gorm.DB
if a.driverName == "postgres" {
db, err = gorm.Open(a.driverName, a.dataSourceName+" dbname=postgres")
if driverName == "postgres" {
db, err = gorm.Open(postgres.Open(dataSourceName+" dbname=postgres"), &gorm.Config{})
} else if driverName == "mysql" {
db, err = gorm.Open(mysql.Open(dataSourceName), &gorm.Config{})
} else if driverName == "sqlite3" {
db, err = gorm.Open(sqlite.Open(dataSourceName), &gorm.Config{})
} else if driverName == "sqlserver" {
db, err = gorm.Open(sqlserver.Open(dataSourceName), &gorm.Config{})
} else {
db, err = gorm.Open(a.driverName, a.dataSourceName)
return nil, errors.New("database dialect is not supported")
}
if err != nil {
return err
return nil, err
}
return db, err
}

func (a *Adapter) createDatabase() error {
var err error
db, err := openDBConnection(a.driverName, a.dataSourceName)
if err != nil {
return err
}
if a.driverName == "postgres" {
if err = db.Exec("CREATE DATABASE casbin").Error; err != nil {
// 42P04 is duplicate_database
if err.(*pq.Error).Code == "42P04" {
db.Close()
return nil
}
}
} else if a.driverName != "sqlite3" {
err = db.Exec("CREATE DATABASE IF NOT EXISTS casbin").Error
}
if err != nil {
db.Close()
return err
}

return db.Close()
return nil
}

func (a *Adapter) open() error {
var err error
var db *gorm.DB

if a.dbSpecified {
db, err = gorm.Open(a.driverName, a.dataSourceName)
db, err = openDBConnection(a.driverName, a.dataSourceName)
if err != nil {
return err
}
} else {
if err = a.createDatabase(); err != nil {
return err
}

if a.driverName == "postgres" {
db, err = gorm.Open(a.driverName, a.dataSourceName+" dbname=casbin")
db, err = openDBConnection(a.driverName, a.dataSourceName+" dbname=casbin")
} else if a.driverName == "sqlite3" {
db, err = gorm.Open(a.driverName, a.dataSourceName)
db, err = openDBConnection(a.driverName, a.dataSourceName)
} else {
db, err = gorm.Open(a.driverName, a.dataSourceName+"casbin")
db, err = openDBConnection(a.driverName, a.dataSourceName+"casbin")
}
if err != nil {
return err
}
}

a.db = db

return a.createTable()
}

func (a *Adapter) close() error {
err := a.db.Close()
if err != nil {
return err
}

a.db = nil
return nil
}
Expand All @@ -209,15 +219,15 @@ func (a *Adapter) getTableInstance() *CasbinRule {
}

func (a *Adapter) createTable() error {
if a.db.HasTable(a.getTableInstance()) {
if a.db.Migrator().HasTable(a.getTableInstance()) {
return nil
}

return a.db.CreateTable(a.getTableInstance()).Error
return a.db.Migrator().CreateTable(a.getTableInstance())
}

func (a *Adapter) dropTable() error {
return a.db.DropTable(a.getTableInstance()).Error
return a.db.Migrator().DropTable(a.getTableInstance())
}

func loadPolicyLine(line CasbinRule, model model.Model) {
Expand Down
11 changes: 5 additions & 6 deletions adapter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,19 +15,18 @@
package gormadapter

import (
"gorm.io/driver/mysql"
"gorm.io/driver/postgres"
"log"
"os"
"testing"

"github.com/casbin/casbin/v2"
"github.com/casbin/casbin/v2/util"
_ "github.com/go-sql-driver/mysql"
"github.com/jinzhu/gorm"
_ "github.com/jinzhu/gorm/dialects/mssql"
_ "github.com/jinzhu/gorm/dialects/postgres"
_ "github.com/jinzhu/gorm/dialects/sqlite"
_ "github.com/lib/pq"
"github.com/stretchr/testify/assert"
"gorm.io/gorm"
)

func testGetPolicy(t *testing.T, e *casbin.Enforcer, res [][]string) {
Expand Down Expand Up @@ -199,7 +198,7 @@ func TestAdapters(t *testing.T) {
testAutoSave(t, a)
testSaveLoad(t, a)

db, err := gorm.Open("mysql", "root:@tcp(127.0.0.1:3306)/casbin")
db, err := gorm.Open(mysql.Open("root:@tcp(127.0.0.1:3306)/casbin"), &gorm.Config{})
if err != nil {
panic(err)
}
Expand All @@ -210,7 +209,7 @@ func TestAdapters(t *testing.T) {
a = initAdapterWithGormInstance(t, db)
testFilteredPolicy(t, a)

db, err = gorm.Open("postgres", "user=postgres host=127.0.0.1 port=5432 sslmode=disable dbname=casbin")
db, err = gorm.Open(postgres.Open("user=postgres host=127.0.0.1 port=5432 sslmode=disable dbname=casbin"), &gorm.Config{})
if err != nil {
panic(err)
}
Expand Down
15 changes: 10 additions & 5 deletions go.mod
Original file line number Diff line number Diff line change
@@ -1,11 +1,16 @@
module github.com/casbin/gorm-adapter/v2
module github.com/casbin/gorm-adapter/v3

go 1.12

require (
github.com/casbin/casbin/v2 v2.2.2
github.com/go-sql-driver/mysql v1.4.1
github.com/jinzhu/gorm v1.9.12
github.com/lib/pq v1.1.1
github.com/stretchr/testify v1.3.0
github.com/go-sql-driver/mysql v1.5.0
github.com/lib/pq v1.3.0
github.com/mattn/go-sqlite3 v2.0.1+incompatible // indirect
github.com/stretchr/testify v1.5.1
gorm.io/driver/mysql v0.3.0
gorm.io/driver/postgres v0.2.6
gorm.io/driver/sqlite v1.0.8
gorm.io/driver/sqlserver v0.2.4
gorm.io/gorm v0.2.23
)

0 comments on commit 57a0c40

Please sign in to comment.