Skip to content

Commit

Permalink
Merge aee641d into 02cccac
Browse files Browse the repository at this point in the history
  • Loading branch information
vano144 committed Dec 3, 2019
2 parents 02cccac + aee641d commit 93b41a0
Show file tree
Hide file tree
Showing 7 changed files with 125 additions and 19 deletions.
9 changes: 9 additions & 0 deletions clickhouse_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ type chSuite struct {
suite.Suite
conn *sql.DB
connWithCompression *sql.DB
connWithKillQuery *sql.DB
}

func (s *chSuite) SetupSuite() {
Expand All @@ -67,6 +68,10 @@ func (s *chSuite) SetupSuite() {
connWithCompression, err := sql.Open("clickhouse", dsn+"?enable_http_compression=1")
s.Require().NoError(err)
s.connWithCompression = connWithCompression

connWithKillQuery, err := sql.Open("clickhouse", dsn+"?kill_query=1&read_timeout=1s")
s.Require().NoError(err)
s.connWithKillQuery = connWithKillQuery
}

func (s *chSuite) TearDownSuite() {
Expand All @@ -77,6 +82,10 @@ func (s *chSuite) TearDownSuite() {
s.connWithCompression.Close()
_, err = s.connWithCompression.Query("SELECT 1")
s.EqualError(err, "sql: database is closed")

s.connWithKillQuery.Close()
_, err = s.connWithKillQuery.Query("SELECT 1")
s.EqualError(err, "sql: database is closed")
}

func (d *dbInit) Do(conn *sql.DB) error {
Expand Down
6 changes: 6 additions & 0 deletions config.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ type Config struct {
GzipCompression bool
Params map[string]string
TLSConfig string
KillQueryOnErr bool // kill query on the server side if we have error from transport
}

// NewConfig creates a new config with default values
Expand Down Expand Up @@ -64,6 +65,9 @@ func (cfg *Config) FormatDSN() string {
if cfg.Debug {
query.Set("debug", "1")
}
if cfg.KillQueryOnErr {
query.Set("kill_query", "1")
}

u.RawQuery = query.Encode()
return u.String()
Expand Down Expand Up @@ -157,6 +161,8 @@ func parseDSNParams(cfg *Config, params map[string][]string) (err error) {
cfg.Params[k] = v[0]
case "tls_config":
cfg.TLSConfig = v[0]
case "kill_query":
cfg.KillQueryOnErr, err = strconv.ParseBool(v[0])
default:
cfg.Params[k] = v[0]
}
Expand Down
7 changes: 5 additions & 2 deletions config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import (

func TestParseDSN(t *testing.T) {
dsn := "http://username:password@localhost:8123/test?timeout=1s&idle_timeout=2s&read_timeout=3s" +
"&write_timeout=4s&location=Local&max_execution_time=10&debug=1"
"&write_timeout=4s&location=Local&max_execution_time=10&debug=1&kill_query=1"
cfg, err := ParseDSN(dsn)
if assert.NoError(t, err) {
assert.Equal(t, "username", cfg.User)
Expand All @@ -23,6 +23,7 @@ func TestParseDSN(t *testing.T) {
assert.Equal(t, 4*time.Second, cfg.WriteTimeout)
assert.Equal(t, time.Local, cfg.Location)
assert.True(t, cfg.Debug)
assert.True(t, cfg.KillQueryOnErr)
assert.Equal(t, map[string]string{"max_execution_time": "10"}, cfg.Params)
}
}
Expand All @@ -35,6 +36,7 @@ func TestDefaultConfig(t *testing.T) {
assert.Empty(t, cfg.User)
assert.Empty(t, cfg.Password)
assert.False(t, cfg.Debug)
assert.False(t, cfg.KillQueryOnErr)
assert.Equal(t, time.UTC, cfg.Location)
assert.EqualValues(t, 0, cfg.ReadTimeout)
assert.EqualValues(t, 0, cfg.WriteTimeout)
Expand All @@ -60,7 +62,7 @@ func TestParseWrongDSN(t *testing.T) {

func TestFormatDSN(t *testing.T) {
dsn := "http://username:password@localhost:8123/test?timeout=1s&idle_timeout=2s&read_timeout=3s" +
"&write_timeout=4s&location=Europe%2FMoscow&max_execution_time=10&debug=1"
"&write_timeout=4s&location=Europe%2FMoscow&max_execution_time=10&debug=1&kill_query=1"
cfg, err := ParseDSN(dsn)
if assert.NoError(t, err) {
dsn2 := cfg.FormatDSN()
Expand All @@ -73,6 +75,7 @@ func TestFormatDSN(t *testing.T) {
assert.Contains(t, dsn2, "location=Europe%2FMoscow")
assert.Contains(t, dsn2, "max_execution_time=10")
assert.Contains(t, dsn2, "debug=1")
assert.Contains(t, dsn2, "kill_query=1")
}
}

Expand Down
78 changes: 64 additions & 14 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ import (
"context"
"database/sql"
"database/sql/driver"
"errors"
"fmt"
"io"
"log"
"net"
Expand All @@ -13,6 +15,8 @@ import (
"strings"
"sync/atomic"
"time"

uuid "github.com/satori/go.uuid"
)

type key int
Expand All @@ -27,6 +31,11 @@ const (
queryIDParamName = "query_id"
)

// errors
var (
errEmptyQueryID = errors.New("query id is empty")
)

// conn implements an interface sql.Conn
type conn struct {
url *url.URL
Expand All @@ -40,6 +49,7 @@ type conn struct {
stmts []*stmt
logger *log.Logger
closed int32
killQueryOnErr bool
}

func newConn(cfg *Config) *conn {
Expand All @@ -52,6 +62,7 @@ func newConn(cfg *Config) *conn {
location: cfg.Location,
useDBLocation: cfg.UseDBLocation,
useGzipCompression: cfg.GzipCompression,
killQueryOnErr: cfg.KillQueryOnErr,
transport: &http.Transport{
DialContext: (&net.Dialer{
Timeout: cfg.Timeout,
Expand Down Expand Up @@ -172,6 +183,31 @@ func (c *conn) beginTx(ctx context.Context) (driver.Tx, error) {
return c, nil
}

func (c *conn) killQuery(req *http.Request, args []driver.Value) error {
if !c.killQueryOnErr {
return nil
}
queryID := req.URL.Query().Get(queryIDParamName)
if queryID == "" {
return errEmptyQueryID
}
query := fmt.Sprintf("KILL QUERY WHERE query_id='%s'", queryID)
ctx, cancelFunc := context.WithTimeout(context.Background(), c.transport.ResponseHeaderTimeout)
defer cancelFunc()
req, err := c.buildRequest(ctx, query, args, false)
if err != nil {
return err
}
body, err := c.doRequest(ctx, req)
if err != nil {
return err
}
if body != nil {
body.Close()
}
return nil
}

func (c *conn) query(ctx context.Context, query string, args []driver.Value) (driver.Rows, error) {
if atomic.LoadInt32(&c.closed) != 0 {
return nil, driver.ErrBadConn
Expand All @@ -182,6 +218,12 @@ func (c *conn) query(ctx context.Context, query string, args []driver.Value) (dr
}
body, err := c.doRequest(ctx, req)
if err != nil {
if _, ok := err.(*Error); !ok && err != driver.ErrBadConn {
killErr := c.killQuery(req, args)
if killErr != nil {
c.log("error from killQuery", killErr)
}
}
return nil, err
}

Expand Down Expand Up @@ -248,26 +290,34 @@ func (c *conn) buildRequest(ctx context.Context, query string, params []driver.V
}
c.log("query: ", query)
req, err := http.NewRequest(method, c.url.String(), strings.NewReader(query))
if err != nil {
return nil, err
}
// http.Transport ignores url.User argument, handle it here
if err == nil && c.user != nil {
if c.user != nil {
p, _ := c.user.Password()
req.SetBasicAuth(c.user.Username(), p)
}
var queryID, quotaKey string
if ctx != nil {
quotaKey, quotaOk := ctx.Value(QuotaKey).(string)
queryID, queryOk := ctx.Value(QueryID).(string)
if quotaOk || queryOk {
reqQuery := req.URL.Query()
if quotaOk {
reqQuery.Add(quotaKeyParamName, quotaKey)
}
if queryOk && len(queryID) > 0 {
reqQuery.Add(queryIDParamName, queryID)
}
req.URL.RawQuery = reqQuery.Encode()
}
quotaKey, _ = ctx.Value(QuotaKey).(string)
queryID, _ = ctx.Value(QueryID).(string)
}

if c.killQueryOnErr && queryID == "" {
queryID = uuid.NewV4().String()
}
return req, err

reqQuery := req.URL.Query()
if quotaKey != "" {
reqQuery.Add(quotaKeyParamName, quotaKey)
}
if queryID != "" {
reqQuery.Add(queryIDParamName, queryID)
}
req.URL.RawQuery = reqQuery.Encode()

return req, nil
}

func (c *conn) prepare(query string) (*stmt, error) {
Expand Down
28 changes: 26 additions & 2 deletions conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,12 @@ import (
"context"
"database/sql"
"database/sql/driver"
"fmt"
"net/http"
"testing"
"time"

uuid "github.com/satori/go.uuid"
"github.com/stretchr/testify/suite"
)

Expand Down Expand Up @@ -73,6 +75,9 @@ func (s *connSuite) TestQuery() {

// Tests on connections with enabled compression
doTests(s.connWithCompression)

// Tests on connection with enabled kill connection
doTests(s.connWithKillQuery)
}

func (s *connSuite) TestExec() {
Expand Down Expand Up @@ -171,6 +176,25 @@ func (s *connSuite) TestServerError() {
s.Contains(srvErr.Error(), "Code: 62, Message: Syntax error:")
}

func (s *connSuite) TestServerKillQuery() {
queryID := uuid.NewV4().String()
ctx := context.WithValue(context.Background(), queryIDParamName, queryID)
_, err := s.connWithKillQuery.QueryContext(ctx, "SELECT sleep(2)")
s.Error(err)
s.Contains(err.Error(), "net/http: timeout awaiting response headers")
rows := s.connWithKillQuery.QueryRow(fmt.Sprintf("SELECT count(query_id) FROM system.processes where query_id='%s'", queryID))
var amount int
err = rows.Scan(&amount)
s.NoError(err)
s.Equal(0, amount)

_, err = s.connWithKillQuery.QueryContext(ctx, "SELECT sleep(0.5)")
s.NoError(err)

_, err = s.conn.QueryContext(ctx, "SELECT sleep(2)")
s.NoError(err)
}

func (s *connSuite) TestBuildRequestReadonlyWithAuth() {
cfg := NewConfig()
cfg.User = "user"
Expand Down Expand Up @@ -250,7 +274,7 @@ func (s *connSuite) TestBuildRequestWithQuotaKey() {
}{
{
"",
cn.url.String() + "&quota_key=",
cn.url.String(),
},
{
"quota-key",
Expand Down Expand Up @@ -295,7 +319,7 @@ func (s *connSuite) TestBuildRequestWithQueryIdAndQuotaKey() {
{
"",
"",
cn.url.String() + "&quota_key=",
cn.url.String(),
},
{
"quota-key",
Expand Down
7 changes: 6 additions & 1 deletion go.mod
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
module github.com/mailru/go-clickhouse

require github.com/stretchr/testify v1.3.0
require (
github.com/kr/pretty v0.1.0 // indirect
github.com/satori/go.uuid v1.2.0
github.com/stretchr/testify v1.3.0
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 // indirect
)
9 changes: 9 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
@@ -1,8 +1,17 @@
github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI=
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE=
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/satori/go.uuid v1.2.0 h1:0uYX9dsZ2yD7q2RtLRtPSdGDWzjeM3TbMJP9utgA0ww=
github.com/satori/go.uuid v1.2.0/go.mod h1:dA0hQrYB0VpLJoorglMZABFdXlWrHn1NEOzdhQKdks0=
github.com/stretchr/objx v0.1.0 h1:4G4v2dO3VZwixGIRoQ5Lfboy6nUhCyYzaqnIAPPhYs4=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.3.0 h1:TivCn/peBQ7UY8ooIcPgZFpTNSz0Q2U6UrFlUfqbe0Q=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo=
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=

0 comments on commit 93b41a0

Please sign in to comment.