Skip to content
Browse files

Update for Go tip API changes.

  • Loading branch information...
1 parent e7db323 commit 3efa8cf1e1952de8e9657fc31408a0528bfebe58 @bradfitz committed
View
14 src/github.com/bmizerany/pq.go/conn.go
@@ -122,7 +122,7 @@ func New(rwc io.ReadWriteCloser, params proto.Values, pw string) (*Conn, error)
panic("not reached")
}
-func (cn *Conn) Exec(query string, args []interface{}) (driver.Result, error) {
+func (cn *Conn) Exec(query string, args []driver.Value) (driver.Result, error) {
if len(args) == 0 {
err := cn.p.SimpleQuery(query)
if err != nil {
@@ -294,7 +294,7 @@ func (stmt *Stmt) NumInput() int {
return len(stmt.params)
}
-func (stmt *Stmt) Exec(args []interface{}) (driver.Result, error) {
+func (stmt *Stmt) Exec(args []driver.Value) (driver.Result, error) {
// NOTE: should return []drive.Result, because a PS can have more
// than one statement and recv more than one tag.
rows, err := stmt.Query(args)
@@ -317,9 +317,13 @@ func (stmt *Stmt) Exec(args []interface{}) (driver.Result, error) {
return driver.RowsAffected(0), nil
}
-func (stmt *Stmt) Query(args []interface{}) (driver.Rows, error) {
+func (stmt *Stmt) Query(args []driver.Value) (driver.Rows, error) {
// For now, we'll just say they're strings
- sargs := encodeParams(args)
+ iargs := make([]interface{}, len(args))
+ for i, a := range args {
+ iargs[i] = a
+ }
+ sargs := encodeParams(iargs)
err := stmt.p.Bind(stmt.Name, stmt.Name, sargs...)
if err != nil {
@@ -388,7 +392,7 @@ func (r *Rows) Columns() []string {
return r.names
}
-func (r *Rows) Next(dest []interface{}) (err error) {
+func (r *Rows) Next(dest []driver.Value) (err error) {
if r.done {
return io.EOF
}
View
8 src/github.com/mattn/go-sqlite3/sqlite3.go
@@ -146,7 +146,7 @@ func (s *SQLiteStmt) NumInput() int {
return int(C.sqlite3_bind_parameter_count(s.s))
}
-func (s *SQLiteStmt) bind(args []interface{}) error {
+func (s *SQLiteStmt) bind(args []driver.Value) error {
rv := C.sqlite3_reset(s.s)
if rv != C.SQLITE_ROW && rv != C.SQLITE_OK && rv != C.SQLITE_DONE {
return errors.New(C.GoString(C.sqlite3_errmsg(s.c.db)))
@@ -195,7 +195,7 @@ func (s *SQLiteStmt) bind(args []interface{}) error {
return nil
}
-func (s *SQLiteStmt) Query(args []interface{}) (driver.Rows, error) {
+func (s *SQLiteStmt) Query(args []driver.Value) (driver.Rows, error) {
if err := s.bind(args); err != nil {
return nil, err
}
@@ -214,7 +214,7 @@ func (r *SQLiteResult) RowsAffected() (int64, error) {
return int64(C.sqlite3_changes(r.s.c.db)), nil
}
-func (s *SQLiteStmt) Exec(args []interface{}) (driver.Result, error) {
+func (s *SQLiteStmt) Exec(args []driver.Value) (driver.Result, error) {
if err := s.bind(args); err != nil {
return nil, err
}
@@ -245,7 +245,7 @@ func (rc *SQLiteRows) Columns() []string {
return rc.cols
}
-func (rc *SQLiteRows) Next(dest []interface{}) error {
+func (rc *SQLiteRows) Next(dest []driver.Value) error {
rv := C.sqlite3_step(rc.s.s)
if rv != C.SQLITE_ROW {
return errors.New(C.GoString(C.sqlite3_errmsg(rc.s.c.db)))
View
12 src/github.com/ziutek/mymysql/godrv/driver.go
@@ -67,7 +67,11 @@ func (s stmt) NumInput() int {
return s.my.NumParam()
}
-func (s stmt) run(args []interface{}) (rowsRes, error) {
+func (s stmt) run(vargs []driver.Value) (rowsRes, error) {
+ args := make([]interface{}, len(vargs))
+ for i, a := range vargs {
+ args[i] = a
+ }
res, err := s.my.Run(args...)
if err != nil {
return rowsRes{nil}, err
@@ -75,11 +79,11 @@ func (s stmt) run(args []interface{}) (rowsRes, error) {
return rowsRes{res}, nil
}
-func (s stmt) Exec(args []interface{}) (driver.Result, error) {
+func (s stmt) Exec(args []driver.Value) (driver.Result, error) {
return s.run(args)
}
-func (s stmt) Query(args []interface{}) (driver.Rows, error) {
+func (s stmt) Query(args []driver.Value) (driver.Rows, error) {
return s.run(args)
}
@@ -114,7 +118,7 @@ func (r rowsRes) Close() error {
}
// DATE, DATETIME, TIMESTAMP are treated as they are in Local time zone
-func (r rowsRes) Next(dest []interface{}) error {
+func (r rowsRes) Next(dest []driver.Value) error {
row, err := r.my.GetRow()
if err != nil {
return err
View
26 src/github.com/ziutek/mymysql/mysql/row.go
@@ -11,6 +11,8 @@ import (
"time"
)
+var errRange = errors.New("mysql: value out of range")
+
// Result row - contains values for any column of received row.
//
// If row is a result of ordinary text query, its element can be
@@ -78,22 +80,22 @@ func (tr Row) IntErr(nn int) (val int, err error) {
if data >= int64(_MIN_INT) && data <= int64(_MAX_INT) {
val = int(data)
} else {
- err = &strconv.NumError{fn, fmt.Sprint(data), os.ERANGE}
+ err = &strconv.NumError{fn, fmt.Sprint(data), errRange}
}
case uint32:
if data <= uint32(_MAX_INT) {
val = int(data)
} else {
- err = &strconv.NumError{fn, fmt.Sprint(data), os.ERANGE}
+ err = &strconv.NumError{fn, fmt.Sprint(data), errRange}
}
case uint64:
if data <= uint64(_MAX_INT) {
val = int(data)
} else {
- err = &strconv.NumError{fn, fmt.Sprint(data), os.ERANGE}
+ err = &strconv.NumError{fn, fmt.Sprint(data), errRange}
}
default:
- err = &strconv.NumError{fn, fmt.Sprint(data), os.EINVAL}
+ err = &strconv.NumError{fn, fmt.Sprint(data), os.ErrInvalid}
}
return
}
@@ -138,17 +140,17 @@ func (tr Row) UintErr(nn int) (val uint, err error) {
if data <= uint64(_MAX_UINT) {
val = uint(data)
} else {
- err = &strconv.NumError{fn, fmt.Sprint(data), os.ERANGE}
+ err = &strconv.NumError{fn, fmt.Sprint(data), errRange}
}
case int8, int16, int32, int64:
v := reflect.ValueOf(data).Int()
if v >= 0 && v <= int64(_MAX_UINT) {
val = uint(v)
} else {
- err = &strconv.NumError{fn, fmt.Sprint(data), os.ERANGE}
+ err = &strconv.NumError{fn, fmt.Sprint(data), errRange}
}
default:
- err = &strconv.NumError{fn, fmt.Sprint(data), os.EINVAL}
+ err = &strconv.NumError{fn, fmt.Sprint(data), os.ErrInvalid}
}
return
}
@@ -329,7 +331,7 @@ func (tr Row) BoolErr(nn int) (val bool, err error) {
case uint64:
val = (data != 0)
default:
- err = &strconv.NumError{fn, fmt.Sprint(data), os.EINVAL}
+ err = &strconv.NumError{fn, fmt.Sprint(data), os.ErrInvalid}
}
return
}
@@ -361,13 +363,13 @@ func (tr Row) Int64Err(nn int) (val int64, err error) {
case uint64, uint32, uint16, uint8:
u := reflect.ValueOf(data).Uint()
if u > math.MaxInt64 {
- err = &strconv.NumError{fn, fmt.Sprint(data), os.ERANGE}
+ err = &strconv.NumError{fn, fmt.Sprint(data), errRange}
}
val = int64(u)
case []byte:
val, err = strconv.ParseInt(string(data), 10, 64)
default:
- err = &strconv.NumError{fn, fmt.Sprint(data), os.EINVAL}
+ err = &strconv.NumError{fn, fmt.Sprint(data), os.ErrInvalid}
}
return
}
@@ -401,13 +403,13 @@ func (tr Row) Uint64Err(nn int) (val uint64, err error) {
case int64, int32, int16, int8:
i := reflect.ValueOf(data).Int()
if i < 0 {
- err = &strconv.NumError{fn, fmt.Sprint(data), os.ERANGE}
+ err = &strconv.NumError{fn, fmt.Sprint(data), errRange}
}
val = uint64(i)
case []byte:
val, err = strconv.ParseUint(string(data), 10, 64)
default:
- err = &strconv.NumError{fn, fmt.Sprint(data), os.EINVAL}
+ err = &strconv.NumError{fn, fmt.Sprint(data), os.ErrInvalid}
}
return
}
View
18 src/sqltest/sql_test.go
@@ -140,8 +140,8 @@ func (mdb *mysqlDB) RunTest(t *testing.T, fn func(params)) {
if user == "" {
user = "root"
}
- pass, err := os.Getenverror("GOSQLTEST_MYSQL_PASS")
- if err != nil {
+ pass, ok := getenvOk("GOSQLTEST_MYSQL_PASS")
+ if !ok {
pass = "root"
}
dbName := "gosqltest"
@@ -265,3 +265,17 @@ func testTxQuery(t params) {
t.Fatal(err)
}
}
+
+func getenvOk(k string) (v string, ok bool) {
+ v = os.Getenv(k)
+ if v != "" {
+ return v, true
+ }
+ keq := k + "="
+ for _, kv := range os.Environ() {
+ if kv == keq {
+ return "", true
+ }
+ }
+ return "", false
+}

0 comments on commit 3efa8cf

Please sign in to comment.
Something went wrong with that request. Please try again.