Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
217 changes: 182 additions & 35 deletions db.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,32 @@ type DbMap struct {

TypeConverter TypeConverter

tables []*TableMap
logger GorpLogger
logPrefix string
tables []*TableMap
tablesDynamic map[string]*TableMap // tables that use same go-struct and different db table names
logger GorpLogger
logPrefix string
}

func (m *DbMap) dynamicTableAdd(tableName string, tbl *TableMap) {
if nil == m.tablesDynamic {
m.tablesDynamic = make(map[string]*TableMap)
}
m.tablesDynamic[tableName] = tbl
}

func (m *DbMap) dynamicTableFind(tableName string) (*TableMap, bool) {
if nil == m.tablesDynamic {
return nil, false
}
tbl, found := m.tablesDynamic[tableName]
return tbl, found
}

func (m *DbMap) dynamicTableMap() map[string]*TableMap {
if nil == m.tablesDynamic {
m.tablesDynamic = make(map[string]*TableMap)
}
return m.tablesDynamic
}

func (m *DbMap) CreateIndex() error {
Expand All @@ -52,36 +75,52 @@ func (m *DbMap) CreateIndex() error {
dialect := reflect.TypeOf(m.Dialect)
for _, table := range m.tables {
for _, index := range table.indexes {

s := bytes.Buffer{}
s.WriteString("create")
if index.Unique {
s.WriteString(" unique")
}
s.WriteString(" index")
s.WriteString(fmt.Sprintf(" %s on %s", index.IndexName, table.TableName))
if dname := dialect.Name(); dname == "PostgresDialect" && index.IndexType != "" {
s.WriteString(fmt.Sprintf(" %s %s", m.Dialect.CreateIndexSuffix(), index.IndexType))
}
s.WriteString(" (")
for x, col := range index.columns {
if x > 0 {
s.WriteString(", ")
}
s.WriteString(m.Dialect.QuoteField(col))
err = m.createIndexImpl(dialect, table, index)
if err != nil {
break
}
s.WriteString(")")
}
}

if dname := dialect.Name(); dname == "MySQLDialect" && index.IndexType != "" {
s.WriteString(fmt.Sprintf(" %s %s", m.Dialect.CreateIndexSuffix(), index.IndexType))
}
s.WriteString(";")
_, err = m.Exec(s.String())
for _, table := range m.dynamicTableMap() {
for _, index := range table.indexes {
err = m.createIndexImpl(dialect, table, index)
if err != nil {
break
}
}
}

return err
}

func (m *DbMap) createIndexImpl(dialect reflect.Type,
table *TableMap,
index *IndexMap) error {
s := bytes.Buffer{}
s.WriteString("create")
if index.Unique {
s.WriteString(" unique")
}
s.WriteString(" index")
s.WriteString(fmt.Sprintf(" %s on %s", index.IndexName, table.TableName))
if dname := dialect.Name(); dname == "PostgresDialect" && index.IndexType != "" {
s.WriteString(fmt.Sprintf(" %s %s", m.Dialect.CreateIndexSuffix(), index.IndexType))
}
s.WriteString(" (")
for x, col := range index.columns {
if x > 0 {
s.WriteString(", ")
}
s.WriteString(m.Dialect.QuoteField(col))
}
s.WriteString(")")

if dname := dialect.Name(); dname == "MySQLDialect" && index.IndexType != "" {
s.WriteString(fmt.Sprintf(" %s %s", m.Dialect.CreateIndexSuffix(), index.IndexType))
}
s.WriteString(";")
_, err := m.Exec(s.String())
return err
}

Expand Down Expand Up @@ -155,6 +194,36 @@ func (m *DbMap) AddTableWithNameAndSchema(i interface{}, schema string, name str
return tmap
}

// AddTableDynamic registers the given interface type with gorp.
// The table name will be dynamically determined at runtime by
// using the GetTableName method on DynamicTable interface
func (m *DbMap) AddTableDynamic(inp DynamicTable, schema string) *TableMap {

val := reflect.ValueOf(inp)
elm := val.Elem()
t := elm.Type()
name := inp.TableName()
if "" == name {
panic("Missing table name in DynamicTable instance")
}

// Check if there is another dynamic table with the same name
if _, found := m.dynamicTableFind(name); found {
panic(fmt.Sprintf("A table with the same name %v already exists", name))
}

tmap := &TableMap{gotype: t, TableName: name, SchemaName: schema, dbmap: m}
var primaryKey []*ColumnMap
tmap.Columns, primaryKey = m.readStructColumns(t)
if len(primaryKey) > 0 {
tmap.keys = append(tmap.keys, primaryKey...)
}

m.dynamicTableAdd(name, tmap)

return tmap
}

func (m *DbMap) readStructColumns(t reflect.Type) (cols []*ColumnMap, primaryKey []*ColumnMap) {
primaryKey = make([]*ColumnMap, 0)
n := t.NumField()
Expand Down Expand Up @@ -306,23 +375,44 @@ func (m *DbMap) createTables(ifNotExists bool) error {
sql := table.SqlForCreate(ifNotExists)
_, err = m.Exec(sql)
if err != nil {
break
return err
}
}

for _, tbl := range m.dynamicTableMap() {
sql := tbl.SqlForCreate(ifNotExists)
_, err = m.Exec(sql)
if err != nil {
return err
}
}

return err
}

// DropTable drops an individual table.
// Returns an error when the table does not exist.
func (m *DbMap) DropTable(table interface{}) error {
t := reflect.TypeOf(table)
return m.dropTable(t, false)

tableName := ""
if dyn, ok := table.(DynamicTable); ok {
tableName = dyn.TableName()
}

return m.dropTable(t, tableName, false)
}

// DropTableIfExists drops an individual table when the table exists.
func (m *DbMap) DropTableIfExists(table interface{}) error {
t := reflect.TypeOf(table)
return m.dropTable(t, true)

tableName := ""
if dyn, ok := table.(DynamicTable); ok {
tableName = dyn.TableName()
}

return m.dropTable(t, tableName, true)
}

// DropTables iterates through TableMaps registered to this DbMap and
Expand All @@ -347,12 +437,20 @@ func (m *DbMap) dropTables(addIfExists bool) (err error) {
return err
}
}

for _, table := range m.dynamicTableMap() {
err = m.dropTableImpl(table, addIfExists)
if err != nil {
return err
}
}

return err
}

// Implementation of dropping a single table.
func (m *DbMap) dropTable(t reflect.Type, addIfExists bool) error {
table := tableOrNil(m, t)
func (m *DbMap) dropTable(t reflect.Type, name string, addIfExists bool) error {
table := tableOrNil(m, t, name)
if table == nil {
return fmt.Errorf("table %s was not registered", table.TableName)
}
Expand Down Expand Up @@ -382,6 +480,14 @@ func (m *DbMap) TruncateTables() error {
err = e
}
}

for _, table := range m.dynamicTableMap() {
_, e := m.Exec(fmt.Sprintf("%s %s;", m.Dialect.TruncateClause(), m.Dialect.QuotedTableForQuery(table.SchemaName, table.TableName)))
if e != nil {
err = e
}
}

return err
}

Expand Down Expand Up @@ -547,7 +653,7 @@ func (m *DbMap) Begin() (*Transaction, error) {
// If no table is mapped to that type an error is returned.
// If checkPK is true and the mapped table has no registered PKs, an error is returned.
func (m *DbMap) TableFor(t reflect.Type, checkPK bool) (*TableMap, error) {
table := tableOrNil(m, t)
table := tableOrNil(m, t, "")
if table == nil {
return nil, fmt.Errorf("no table found for type: %v", t.Name())
}
Expand All @@ -561,6 +667,27 @@ func (m *DbMap) TableFor(t reflect.Type, checkPK bool) (*TableMap, error) {
return table, nil
}

// TableForDynamic returns the *TableMap for the dynamic table corresponding
// to the input tablename
// If no table is mapped to that tablename an error is returned.
// If checkPK is true and the mapped table has no registered PKs, an error is returned.
func (m *DbMap) TableForDynamic(tableName string, checkPK bool) (*TableMap, error) {

table, found := m.dynamicTableFind(tableName)

if false == found {
return nil, fmt.Errorf("gorp: no table found for name: %v", tableName)
}

if checkPK && len(table.keys) < 1 {
e := fmt.Sprintf("gorp: no keys defined for table: %s",
table.TableName)
return nil, errors.New(e)
}

return table, nil
}

// Prepare creates a prepared statement for later queries or executions.
// Multiple queries or executions may be run concurrently from the returned statement.
// This is equivalent to running: Prepare() using database/sql
Expand All @@ -572,7 +699,17 @@ func (m *DbMap) Prepare(query string) (*sql.Stmt, error) {
return m.Db.Prepare(query)
}

func tableOrNil(m *DbMap, t reflect.Type) *TableMap {
func tableOrNil(m *DbMap, t reflect.Type, name string) *TableMap {

if "" != name {
// Search by table name (dynamic tables)
if table, found := m.dynamicTableFind(name); found {
return table
} else {
return nil
}
}

for i := range m.tables {
table := m.tables[i]
if table.gotype == t {
Expand All @@ -590,8 +727,18 @@ func (m *DbMap) tableForPointer(ptr interface{}, checkPK bool) (*TableMap, refle
return nil, reflect.Value{}, errors.New(e)
}
elem := ptrv.Elem()
etype := reflect.TypeOf(elem.Interface())
t, err := m.TableFor(etype, checkPK)
ifc := elem.Interface()
var t *TableMap
var err error
tableName := ""
if dyn, isDyn := ptr.(DynamicTable); isDyn {
tableName = dyn.TableName()
t, err = m.TableForDynamic(tableName, checkPK)
} else {
etype := reflect.TypeOf(ifc)
t, err = m.TableFor(etype, checkPK)
}

if err != nil {
return nil, reflect.Value{}, err
}
Expand Down
29 changes: 26 additions & 3 deletions gorp.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,14 @@ type SqlExecutor interface {
queryRow(query string, args ...interface{}) *sql.Row
}

// DynamicTable allows the users of gorp to dynamically
// use different database table names during runtime
// while sharing the same golang struct for in-memory data
type DynamicTable interface {
TableName() string
SetTableName(string)
}

// Compile-time check that DbMap and Transaction implement the SqlExecutor
// interface.
var _, _ SqlExecutor = &DbMap{}, &Transaction{}
Expand Down Expand Up @@ -220,13 +228,13 @@ func expandNamedQuery(m *DbMap, query string, keyGetter func(key string) reflect
}), args
}

func columnToFieldIndex(m *DbMap, t reflect.Type, cols []string) ([][]int, error) {
func columnToFieldIndex(m *DbMap, t reflect.Type, name string, cols []string) ([][]int, error) {
colToFieldIndex := make([][]int, len(cols))

// check if type t is a mapped table - if so we'll
// check the table for column aliasing below
tableMapped := false
table := tableOrNil(m, t)
table := tableOrNil(m, t, name)
if table != nil {
tableMapped = true
}
Expand Down Expand Up @@ -335,14 +343,29 @@ func get(m *DbMap, exec SqlExecutor, i interface{},
return nil, err
}

table, err := m.TableFor(t, true)
var table *TableMap
tableName := ""
var dyn DynamicTable
isDynamic := false
if dyn, isDynamic = i.(DynamicTable); isDynamic {
tableName = dyn.TableName()
table, err = m.TableForDynamic(tableName, true)
} else {
table, err = m.TableFor(t, true)
}

if err != nil {
return nil, err
}

plan := table.bindGet()

v := reflect.New(t)
if true == isDynamic {
retDyn := v.Interface().(DynamicTable)
retDyn.SetTableName(tableName)
}

dest := make([]interface{}, len(plan.argFields))

conv := m.TypeConverter
Expand Down
Loading