diff --git a/gorp_test.go b/gorp_test.go index b135fff..e3b3693 100644 --- a/gorp_test.go +++ b/gorp_test.go @@ -1720,7 +1720,7 @@ func TestColumnFilter(t *testing.T) { } } -func TestTypeConversionExample(t *testing.T) { +func TestTypeConversionDBMapExample(t *testing.T) { dbmap := initDBMap(t) defer dropAndClose(dbmap) @@ -1821,12 +1821,13 @@ func TestTypeConversionExample(t *testing.T) { t.Errorf(`Select failed: %s`, err) } - _, err = dbmap.QueryContext(context.Background(), + rows, err := dbmap.QueryContext(context.Background(), `select `+name+` from type_conv_test where `+pj+`=`+bv0+` and `+name+`=`+bv1, personJSON, hi2) if err != nil { t.Errorf(`Select failed: %s`, err) } + _ = rows.Close() row := dbmap.QueryRowContext(context.Background(), `select `+name+` from type_conv_test where `+pj+`=`+bv0+` and `+name+`=`+bv1, @@ -1847,6 +1848,155 @@ func TestTypeConversionExample(t *testing.T) { } } +func TestTypeConversionTransactionExample(t *testing.T) { + dbmap := initDBMap(t) + defer dropAndClose(dbmap) + + p := Person{FName: "Bob", LName: "Smith"} + tc := &TypeConversionExample{-1, p, CustomStringType("hi")} + _insert(dbmap, tc) + + expected := &TypeConversionExample{1, p, CustomStringType("hi")} + tc2 := _get(dbmap, TypeConversionExample{}, tc.Id).(*TypeConversionExample) + if !reflect.DeepEqual(expected, tc2) { + t.Errorf("tc2 %v != %v", expected, tc2) + } + + hi2 := CustomStringType("hi2") + tc2.Name = hi2 + tc2.PersonJSON = Person{FName: "Jane", LName: "Doe"} + _update(dbmap, tc2) + + expected = &TypeConversionExample{1, tc2.PersonJSON, CustomStringType("hi2")} + tc3 := _get(dbmap, TypeConversionExample{}, tc.Id).(*TypeConversionExample) + if !reflect.DeepEqual(expected, tc3) { + t.Errorf("tc3 %v != %v", expected, tc3) + } + + d := dbmap.Dialect + pj := d.QuoteField("PersonJSON") + id := d.QuoteField("Id") + name := d.QuoteField("Name") + bv0 := d.BindVar(0) + bv1 := d.BindVar(1) + + // Test that the Person argument to Select goes through the + // type converter + var holder TypeConversionExample + personJSON := Person{FName: "Jane", LName: "Doe"} + ctx := context.Background() + tx, err := dbmap.BeginTx(ctx) + if err != nil { + t.Errorf("begin tx: %v", err) + return + } + defer tx.Rollback() + + _, err = tx.Select(ctx, + holder, + `select * from type_conv_test where `+pj+`=`+bv0+` and `+name+`=`+bv1, + personJSON, hi2) + if err != nil { + t.Errorf(`Select failed: %s`, err) + } + + err = tx.SelectOne(ctx, + &holder, + `select * from type_conv_test where `+pj+`=`+bv0+` and `+name+`=`+bv1, + personJSON, hi2) + if err != nil { + t.Errorf(`Select failed: %s`, err) + } + + _, err = tx.SelectInt(ctx, + `select `+id+` from type_conv_test where `+pj+`=`+bv0+` and `+name+`=`+bv1, + personJSON, hi2) + if err != nil { + t.Errorf(`Select failed: %s`, err) + } + + _, err = tx.SelectInt(ctx, + `select `+id+` from type_conv_test where `+pj+`=`+bv0+` and `+name+`=`+bv1, + personJSON, hi2) + if err != nil { + t.Errorf(`Select failed: %s`, err) + } + + _, err = tx.SelectNullInt(ctx, + `select `+id+` from type_conv_test where `+pj+`=`+bv0+` and `+name+`=`+bv1, + personJSON, hi2) + if err != nil { + t.Errorf(`Select failed: %s`, err) + } + + _, err = tx.SelectFloat(ctx, + `select `+id+` from type_conv_test where `+pj+`=`+bv0+` and `+name+`=`+bv1, + personJSON, hi2) + if err != nil { + t.Errorf(`Select failed: %s`, err) + } + + _, err = tx.SelectNullFloat(ctx, + `select `+id+` from type_conv_test where `+pj+`=`+bv0+` and `+name+`=`+bv1, + personJSON, hi2) + if err != nil { + t.Errorf(`Select failed: %s`, err) + } + + _, err = tx.SelectStr(ctx, + `select `+name+` from type_conv_test where `+pj+`=`+bv0+` and `+name+`=`+bv1, + personJSON, hi2) + if err != nil { + t.Errorf(`Select failed: %s`, err) + } + + _, err = tx.SelectNullStr(ctx, + `select `+name+` from type_conv_test where `+pj+`=`+bv0+` and `+name+`=`+bv1, + personJSON, hi2) + if err != nil { + t.Errorf(`Select failed: %s`, err) + } + + rows, err := tx.QueryContext(ctx, + `select `+name+` from type_conv_test where `+pj+`=`+bv0+` and `+name+`=`+bv1, + personJSON, hi2) + if err != nil { + t.Errorf(`Select failed: %s`, err) + } + _ = rows.Close() + + row := tx.QueryRowContext(ctx, + `select `+name+` from type_conv_test where `+pj+`=`+bv0+` and `+name+`=`+bv1, + personJSON, hi2) + if row == nil || row.Err() != nil { + t.Errorf(`QueryRowContext failed: %s`, row.Err()) + } + // Must consume the row to release the connection. + var gotName string + err = row.Scan(&gotName) + if err != nil { + t.Errorf(`QueryRowContext failed: %s`, err) + } + + _, err = tx.ExecContext(ctx, + `select `+name+` from type_conv_test where `+pj+`=`+bv0+` and `+name+`=`+bv1, + personJSON, hi2) + if err != nil { + t.Errorf(`Select failed: %s`, err) + } + + // We must rollback to release the transaction's connection before we can + // delete the row below. + err = tx.Rollback() + if err != nil { + t.Errorf("rollback failed: %v", err) + } + + if _del(dbmap, tc) != 1 { + t.Errorf("Did not delete row with Id: %d", tc.Id) + } +} + func TestWithEmbeddedStruct(t *testing.T) { dbmap := initDBMap(t) defer dropAndClose(dbmap) diff --git a/transaction.go b/transaction.go index 27eecc4..33e2717 100644 --- a/transaction.go +++ b/transaction.go @@ -51,6 +51,11 @@ func (t *Transaction) Select(ctx context.Context, i interface{}, query string, a expandSliceArgs(&query, args...) } + args, err := t.dbmap.convertArgs(args...) + if err != nil { + return nil, err + } + return hookedselect(ctx, t.dbmap, t, i, query, args...) } @@ -60,6 +65,11 @@ func (t *Transaction) ExecContext(ctx context.Context, query string, args ...int expandSliceArgs(&query, args...) } + args, err := t.dbmap.convertArgs(args...) + if err != nil { + return nil, err + } + if t.dbmap.logger != nil { now := time.Now() defer t.dbmap.trace(now, query, args...) @@ -73,6 +83,11 @@ func (t *Transaction) SelectInt(ctx context.Context, query string, args ...inter expandSliceArgs(&query, args...) } + args, err := t.dbmap.convertArgs(args...) + if err != nil { + return 0, err + } + return SelectInt(ctx, t, query, args...) } @@ -82,6 +97,11 @@ func (t *Transaction) SelectNullInt(ctx context.Context, query string, args ...i expandSliceArgs(&query, args...) } + args, err := t.dbmap.convertArgs(args...) + if err != nil { + return sql.NullInt64{}, err + } + return SelectNullInt(ctx, t, query, args...) } @@ -91,6 +111,11 @@ func (t *Transaction) SelectFloat(ctx context.Context, query string, args ...int expandSliceArgs(&query, args...) } + args, err := t.dbmap.convertArgs(args...) + if err != nil { + return 0, err + } + return SelectFloat(ctx, t, query, args...) } @@ -100,6 +125,11 @@ func (t *Transaction) SelectNullFloat(ctx context.Context, query string, args .. expandSliceArgs(&query, args...) } + args, err := t.dbmap.convertArgs(args...) + if err != nil { + return sql.NullFloat64{}, err + } + return SelectNullFloat(ctx, t, query, args...) } @@ -109,6 +139,11 @@ func (t *Transaction) SelectStr(ctx context.Context, query string, args ...inter expandSliceArgs(&query, args...) } + args, err := t.dbmap.convertArgs(args...) + if err != nil { + return "", err + } + return SelectStr(ctx, t, query, args...) } @@ -118,6 +153,11 @@ func (t *Transaction) SelectNullStr(ctx context.Context, query string, args ...i expandSliceArgs(&query, args...) } + args, err := t.dbmap.convertArgs(args...) + if err != nil { + return sql.NullString{}, err + } + return SelectNullStr(ctx, t, query, args...) } @@ -127,6 +167,11 @@ func (t *Transaction) SelectOne(ctx context.Context, holder interface{}, query s expandSliceArgs(&query, args...) } + args, err := t.dbmap.convertArgs(args...) + if err != nil { + return err + } + return SelectOne(ctx, t.dbmap, t, holder, query, args...) } @@ -211,6 +256,11 @@ func (t *Transaction) QueryRowContext(ctx context.Context, query string, args .. expandSliceArgs(&query, args...) } + args, err := t.dbmap.convertArgs(args...) + if err != nil { + return nil + } + if t.dbmap.logger != nil { now := time.Now() defer t.dbmap.trace(now, query, args...) @@ -223,6 +273,11 @@ func (t *Transaction) QueryContext(ctx context.Context, q string, args ...interf expandSliceArgs(&q, args...) } + args, err := t.dbmap.convertArgs(args...) + if err != nil { + return nil, err + } + if t.dbmap.logger != nil { now := time.Now() defer t.dbmap.trace(now, q, args...)