Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
154 changes: 152 additions & 2 deletions gorp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1720,7 +1720,7 @@ func TestColumnFilter(t *testing.T) {
}
}

func TestTypeConversionExample(t *testing.T) {
func TestTypeConversionDBMapExample(t *testing.T) {
dbmap := initDBMap(t)
defer dropAndClose(dbmap)

Expand Down Expand Up @@ -1821,12 +1821,13 @@ func TestTypeConversionExample(t *testing.T) {
t.Errorf(`Select failed: %s`, err)
}

_, err = dbmap.QueryContext(context.Background(),
rows, err := dbmap.QueryContext(context.Background(),
`select `+name+` from type_conv_test where `+pj+`=`+bv0+` and `+name+`=`+bv1,
personJSON, hi2)
if err != nil {
t.Errorf(`Select failed: %s`, err)
}
_ = rows.Close()

row := dbmap.QueryRowContext(context.Background(),
`select `+name+` from type_conv_test where `+pj+`=`+bv0+` and `+name+`=`+bv1,
Expand All @@ -1847,6 +1848,155 @@ func TestTypeConversionExample(t *testing.T) {
}
}

func TestTypeConversionTransactionExample(t *testing.T) {
dbmap := initDBMap(t)
defer dropAndClose(dbmap)

p := Person{FName: "Bob", LName: "Smith"}
tc := &TypeConversionExample{-1, p, CustomStringType("hi")}
_insert(dbmap, tc)

expected := &TypeConversionExample{1, p, CustomStringType("hi")}
tc2 := _get(dbmap, TypeConversionExample{}, tc.Id).(*TypeConversionExample)
if !reflect.DeepEqual(expected, tc2) {
t.Errorf("tc2 %v != %v", expected, tc2)
}

hi2 := CustomStringType("hi2")
tc2.Name = hi2
tc2.PersonJSON = Person{FName: "Jane", LName: "Doe"}
_update(dbmap, tc2)

expected = &TypeConversionExample{1, tc2.PersonJSON, CustomStringType("hi2")}
tc3 := _get(dbmap, TypeConversionExample{}, tc.Id).(*TypeConversionExample)
if !reflect.DeepEqual(expected, tc3) {
t.Errorf("tc3 %v != %v", expected, tc3)
}

d := dbmap.Dialect
pj := d.QuoteField("PersonJSON")
id := d.QuoteField("Id")
name := d.QuoteField("Name")
bv0 := d.BindVar(0)
bv1 := d.BindVar(1)

// Test that the Person argument to Select goes through the
// type converter
var holder TypeConversionExample
personJSON := Person{FName: "Jane", LName: "Doe"}
ctx := context.Background()
tx, err := dbmap.BeginTx(ctx)
if err != nil {
t.Errorf("begin tx: %v", err)
return
}
defer tx.Rollback()

_, err = tx.Select(ctx,
holder,
`select * from type_conv_test where `+pj+`=`+bv0+` and `+name+`=`+bv1,
personJSON, hi2)
if err != nil {
t.Errorf(`Select failed: %s`, err)
}

err = tx.SelectOne(ctx,
&holder,
`select * from type_conv_test where `+pj+`=`+bv0+` and `+name+`=`+bv1,
personJSON, hi2)
if err != nil {
t.Errorf(`Select failed: %s`, err)
}

_, err = tx.SelectInt(ctx,
`select `+id+` from type_conv_test where `+pj+`=`+bv0+` and `+name+`=`+bv1,
personJSON, hi2)
if err != nil {
t.Errorf(`Select failed: %s`, err)
}

_, err = tx.SelectInt(ctx,
`select `+id+` from type_conv_test where `+pj+`=`+bv0+` and `+name+`=`+bv1,
personJSON, hi2)
if err != nil {
t.Errorf(`Select failed: %s`, err)
}

_, err = tx.SelectNullInt(ctx,
`select `+id+` from type_conv_test where `+pj+`=`+bv0+` and `+name+`=`+bv1,
personJSON, hi2)
if err != nil {
t.Errorf(`Select failed: %s`, err)
}

_, err = tx.SelectFloat(ctx,
`select `+id+` from type_conv_test where `+pj+`=`+bv0+` and `+name+`=`+bv1,
personJSON, hi2)
if err != nil {
t.Errorf(`Select failed: %s`, err)
}

_, err = tx.SelectNullFloat(ctx,
`select `+id+` from type_conv_test where `+pj+`=`+bv0+` and `+name+`=`+bv1,
personJSON, hi2)
if err != nil {
t.Errorf(`Select failed: %s`, err)
}

_, err = tx.SelectStr(ctx,
`select `+name+` from type_conv_test where `+pj+`=`+bv0+` and `+name+`=`+bv1,
personJSON, hi2)
if err != nil {
t.Errorf(`Select failed: %s`, err)
}

_, err = tx.SelectNullStr(ctx,
`select `+name+` from type_conv_test where `+pj+`=`+bv0+` and `+name+`=`+bv1,
personJSON, hi2)
if err != nil {
t.Errorf(`Select failed: %s`, err)
}

rows, err := tx.QueryContext(ctx,
`select `+name+` from type_conv_test where `+pj+`=`+bv0+` and `+name+`=`+bv1,
personJSON, hi2)
if err != nil {
t.Errorf(`Select failed: %s`, err)
}
_ = rows.Close()

row := tx.QueryRowContext(ctx,
`select `+name+` from type_conv_test where `+pj+`=`+bv0+` and `+name+`=`+bv1,
personJSON, hi2)
if row == nil || row.Err() != nil {
t.Errorf(`QueryRowContext failed: %s`, row.Err())
}
// Must consume the row to release the connection.
var gotName string
err = row.Scan(&gotName)
if err != nil {
t.Errorf(`QueryRowContext failed: %s`, err)
}

_, err = tx.ExecContext(ctx,
`select `+name+` from type_conv_test where `+pj+`=`+bv0+` and `+name+`=`+bv1,
personJSON, hi2)
if err != nil {
t.Errorf(`Select failed: %s`, err)
}

// We must rollback to release the transaction's connection before we can
// delete the row below.
err = tx.Rollback()
if err != nil {
t.Errorf("rollback failed: %v", err)
}

if _del(dbmap, tc) != 1 {
t.Errorf("Did not delete row with Id: %d", tc.Id)
}
}

func TestWithEmbeddedStruct(t *testing.T) {
dbmap := initDBMap(t)
defer dropAndClose(dbmap)
Expand Down
55 changes: 55 additions & 0 deletions transaction.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,11 @@ func (t *Transaction) Select(ctx context.Context, i interface{}, query string, a
expandSliceArgs(&query, args...)
}

args, err := t.dbmap.convertArgs(args...)
if err != nil {
return nil, err
}

return hookedselect(ctx, t.dbmap, t, i, query, args...)
}

Expand All @@ -60,6 +65,11 @@ func (t *Transaction) ExecContext(ctx context.Context, query string, args ...int
expandSliceArgs(&query, args...)
}

args, err := t.dbmap.convertArgs(args...)
if err != nil {
return nil, err
}

if t.dbmap.logger != nil {
now := time.Now()
defer t.dbmap.trace(now, query, args...)
Expand All @@ -73,6 +83,11 @@ func (t *Transaction) SelectInt(ctx context.Context, query string, args ...inter
expandSliceArgs(&query, args...)
}

args, err := t.dbmap.convertArgs(args...)
if err != nil {
return 0, err
}

return SelectInt(ctx, t, query, args...)
}

Expand All @@ -82,6 +97,11 @@ func (t *Transaction) SelectNullInt(ctx context.Context, query string, args ...i
expandSliceArgs(&query, args...)
}

args, err := t.dbmap.convertArgs(args...)
if err != nil {
return sql.NullInt64{}, err
}

return SelectNullInt(ctx, t, query, args...)
}

Expand All @@ -91,6 +111,11 @@ func (t *Transaction) SelectFloat(ctx context.Context, query string, args ...int
expandSliceArgs(&query, args...)
}

args, err := t.dbmap.convertArgs(args...)
if err != nil {
return 0, err
}

return SelectFloat(ctx, t, query, args...)
}

Expand All @@ -100,6 +125,11 @@ func (t *Transaction) SelectNullFloat(ctx context.Context, query string, args ..
expandSliceArgs(&query, args...)
}

args, err := t.dbmap.convertArgs(args...)
if err != nil {
return sql.NullFloat64{}, err
}

return SelectNullFloat(ctx, t, query, args...)
}

Expand All @@ -109,6 +139,11 @@ func (t *Transaction) SelectStr(ctx context.Context, query string, args ...inter
expandSliceArgs(&query, args...)
}

args, err := t.dbmap.convertArgs(args...)
if err != nil {
return "", err
}

return SelectStr(ctx, t, query, args...)
}

Expand All @@ -118,6 +153,11 @@ func (t *Transaction) SelectNullStr(ctx context.Context, query string, args ...i
expandSliceArgs(&query, args...)
}

args, err := t.dbmap.convertArgs(args...)
if err != nil {
return sql.NullString{}, err
}

return SelectNullStr(ctx, t, query, args...)
}

Expand All @@ -127,6 +167,11 @@ func (t *Transaction) SelectOne(ctx context.Context, holder interface{}, query s
expandSliceArgs(&query, args...)
}

args, err := t.dbmap.convertArgs(args...)
if err != nil {
return err
}

return SelectOne(ctx, t.dbmap, t, holder, query, args...)
}

Expand Down Expand Up @@ -211,6 +256,11 @@ func (t *Transaction) QueryRowContext(ctx context.Context, query string, args ..
expandSliceArgs(&query, args...)
}

args, err := t.dbmap.convertArgs(args...)
if err != nil {
return nil
}

if t.dbmap.logger != nil {
now := time.Now()
defer t.dbmap.trace(now, query, args...)
Expand All @@ -223,6 +273,11 @@ func (t *Transaction) QueryContext(ctx context.Context, q string, args ...interf
expandSliceArgs(&q, args...)
}

args, err := t.dbmap.convertArgs(args...)
if err != nil {
return nil, err
}

if t.dbmap.logger != nil {
now := time.Now()
defer t.dbmap.trace(now, q, args...)
Expand Down