From c4f1976e433a0e5f4a0d612ee45c85f689076ae7 Mon Sep 17 00:00:00 2001 From: Alex Snast Date: Sun, 9 Feb 2020 16:41:49 +0200 Subject: [PATCH] connection: interpolate json.RawMessage as string (#1058) json encoded data is represented as bytes however it should be interpolated as a string Fixes #819 --- AUTHORS | 1 + connection.go | 9 +++++++++ connection_test.go | 28 ++++++++++++++++++++++++++++ 3 files changed, 38 insertions(+) diff --git a/AUTHORS b/AUTHORS index ad5989800..0896ba1bc 100644 --- a/AUTHORS +++ b/AUTHORS @@ -13,6 +13,7 @@ Aaron Hopkins Achille Roussel +Alex Snast Alexey Palazhchenko Andrew Reid Arne Hormann diff --git a/connection.go b/connection.go index e4bb59e67..b07cd7651 100644 --- a/connection.go +++ b/connection.go @@ -12,6 +12,7 @@ import ( "context" "database/sql" "database/sql/driver" + "encoding/json" "io" "net" "strconv" @@ -271,6 +272,14 @@ func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (strin } buf = append(buf, '\'') } + case json.RawMessage: + buf = append(buf, '\'') + if mc.status&statusNoBackslashEscapes == 0 { + buf = escapeBytesBackslash(buf, v) + } else { + buf = escapeBytesQuotes(buf, v) + } + buf = append(buf, '\'') case []byte: if v == nil { buf = append(buf, "NULL"...) diff --git a/connection_test.go b/connection_test.go index 19c17ff8b..a6d677308 100644 --- a/connection_test.go +++ b/connection_test.go @@ -11,6 +11,7 @@ package mysql import ( "context" "database/sql/driver" + "encoding/json" "errors" "net" "testing" @@ -36,6 +37,33 @@ func TestInterpolateParams(t *testing.T) { } } +func TestInterpolateParamsJSONRawMessage(t *testing.T) { + mc := &mysqlConn{ + buf: newBuffer(nil), + maxAllowedPacket: maxPacketSize, + cfg: &Config{ + InterpolateParams: true, + }, + } + + buf, err := json.Marshal(struct { + Value int `json:"value"` + }{Value: 42}) + if err != nil { + t.Errorf("Expected err=nil, got %#v", err) + return + } + q, err := mc.interpolateParams("SELECT ?", []driver.Value{json.RawMessage(buf)}) + if err != nil { + t.Errorf("Expected err=nil, got %#v", err) + return + } + expected := `SELECT '{\"value\":42}'` + if q != expected { + t.Errorf("Expected: %q\nGot: %q", expected, q) + } +} + func TestInterpolateParamsTooManyPlaceholders(t *testing.T) { mc := &mysqlConn{ buf: newBuffer(nil),