From aee641de0c7375dd5a9a8e3f1ff3432b97484953 Mon Sep 17 00:00:00 2001 From: Ivan Romanov Date: Tue, 3 Dec 2019 13:01:40 +0300 Subject: [PATCH] add kill query option --- clickhouse_test.go | 9 ++++++ config.go | 6 ++++ config_test.go | 7 +++-- conn.go | 78 +++++++++++++++++++++++++++++++++++++--------- conn_test.go | 28 +++++++++++++++-- go.mod | 7 ++++- go.sum | 9 ++++++ 7 files changed, 125 insertions(+), 19 deletions(-) diff --git a/clickhouse_test.go b/clickhouse_test.go index 4340943..7aa45b7 100644 --- a/clickhouse_test.go +++ b/clickhouse_test.go @@ -51,6 +51,7 @@ type chSuite struct { suite.Suite conn *sql.DB connWithCompression *sql.DB + connWithKillQuery *sql.DB } func (s *chSuite) SetupSuite() { @@ -67,6 +68,10 @@ func (s *chSuite) SetupSuite() { connWithCompression, err := sql.Open("clickhouse", dsn+"?enable_http_compression=1") s.Require().NoError(err) s.connWithCompression = connWithCompression + + connWithKillQuery, err := sql.Open("clickhouse", dsn+"?kill_query=1&read_timeout=1s") + s.Require().NoError(err) + s.connWithKillQuery = connWithKillQuery } func (s *chSuite) TearDownSuite() { @@ -77,6 +82,10 @@ func (s *chSuite) TearDownSuite() { s.connWithCompression.Close() _, err = s.connWithCompression.Query("SELECT 1") s.EqualError(err, "sql: database is closed") + + s.connWithKillQuery.Close() + _, err = s.connWithKillQuery.Query("SELECT 1") + s.EqualError(err, "sql: database is closed") } func (d *dbInit) Do(conn *sql.DB) error { diff --git a/config.go b/config.go index 0c02e70..0bd93c2 100644 --- a/config.go +++ b/config.go @@ -25,6 +25,7 @@ type Config struct { GzipCompression bool Params map[string]string TLSConfig string + KillQueryOnErr bool // kill query on the server side if we have error from transport } // NewConfig creates a new config with default values @@ -64,6 +65,9 @@ func (cfg *Config) FormatDSN() string { if cfg.Debug { query.Set("debug", "1") } + if cfg.KillQueryOnErr { + query.Set("kill_query", "1") + } u.RawQuery = query.Encode() return u.String() @@ -157,6 +161,8 @@ func parseDSNParams(cfg *Config, params map[string][]string) (err error) { cfg.Params[k] = v[0] case "tls_config": cfg.TLSConfig = v[0] + case "kill_query": + cfg.KillQueryOnErr, err = strconv.ParseBool(v[0]) default: cfg.Params[k] = v[0] } diff --git a/config_test.go b/config_test.go index 7b79ad6..81c12fc 100644 --- a/config_test.go +++ b/config_test.go @@ -9,7 +9,7 @@ import ( func TestParseDSN(t *testing.T) { dsn := "http://username:password@localhost:8123/test?timeout=1s&idle_timeout=2s&read_timeout=3s" + - "&write_timeout=4s&location=Local&max_execution_time=10&debug=1" + "&write_timeout=4s&location=Local&max_execution_time=10&debug=1&kill_query=1" cfg, err := ParseDSN(dsn) if assert.NoError(t, err) { assert.Equal(t, "username", cfg.User) @@ -23,6 +23,7 @@ func TestParseDSN(t *testing.T) { assert.Equal(t, 4*time.Second, cfg.WriteTimeout) assert.Equal(t, time.Local, cfg.Location) assert.True(t, cfg.Debug) + assert.True(t, cfg.KillQueryOnErr) assert.Equal(t, map[string]string{"max_execution_time": "10"}, cfg.Params) } } @@ -35,6 +36,7 @@ func TestDefaultConfig(t *testing.T) { assert.Empty(t, cfg.User) assert.Empty(t, cfg.Password) assert.False(t, cfg.Debug) + assert.False(t, cfg.KillQueryOnErr) assert.Equal(t, time.UTC, cfg.Location) assert.EqualValues(t, 0, cfg.ReadTimeout) assert.EqualValues(t, 0, cfg.WriteTimeout) @@ -60,7 +62,7 @@ func TestParseWrongDSN(t *testing.T) { func TestFormatDSN(t *testing.T) { dsn := "http://username:password@localhost:8123/test?timeout=1s&idle_timeout=2s&read_timeout=3s" + - "&write_timeout=4s&location=Europe%2FMoscow&max_execution_time=10&debug=1" + "&write_timeout=4s&location=Europe%2FMoscow&max_execution_time=10&debug=1&kill_query=1" cfg, err := ParseDSN(dsn) if assert.NoError(t, err) { dsn2 := cfg.FormatDSN() @@ -73,6 +75,7 @@ func TestFormatDSN(t *testing.T) { assert.Contains(t, dsn2, "location=Europe%2FMoscow") assert.Contains(t, dsn2, "max_execution_time=10") assert.Contains(t, dsn2, "debug=1") + assert.Contains(t, dsn2, "kill_query=1") } } diff --git a/conn.go b/conn.go index 9537621..a505fd3 100644 --- a/conn.go +++ b/conn.go @@ -4,6 +4,8 @@ import ( "context" "database/sql" "database/sql/driver" + "errors" + "fmt" "io" "log" "net" @@ -13,6 +15,8 @@ import ( "strings" "sync/atomic" "time" + + uuid "github.com/satori/go.uuid" ) type key int @@ -27,6 +31,11 @@ const ( queryIDParamName = "query_id" ) +// errors +var ( + errEmptyQueryID = errors.New("query id is empty") +) + // conn implements an interface sql.Conn type conn struct { url *url.URL @@ -40,6 +49,7 @@ type conn struct { stmts []*stmt logger *log.Logger closed int32 + killQueryOnErr bool } func newConn(cfg *Config) *conn { @@ -52,6 +62,7 @@ func newConn(cfg *Config) *conn { location: cfg.Location, useDBLocation: cfg.UseDBLocation, useGzipCompression: cfg.GzipCompression, + killQueryOnErr: cfg.KillQueryOnErr, transport: &http.Transport{ DialContext: (&net.Dialer{ Timeout: cfg.Timeout, @@ -172,6 +183,31 @@ func (c *conn) beginTx(ctx context.Context) (driver.Tx, error) { return c, nil } +func (c *conn) killQuery(req *http.Request, args []driver.Value) error { + if !c.killQueryOnErr { + return nil + } + queryID := req.URL.Query().Get(queryIDParamName) + if queryID == "" { + return errEmptyQueryID + } + query := fmt.Sprintf("KILL QUERY WHERE query_id='%s'", queryID) + ctx, cancelFunc := context.WithTimeout(context.Background(), c.transport.ResponseHeaderTimeout) + defer cancelFunc() + req, err := c.buildRequest(ctx, query, args, false) + if err != nil { + return err + } + body, err := c.doRequest(ctx, req) + if err != nil { + return err + } + if body != nil { + body.Close() + } + return nil +} + func (c *conn) query(ctx context.Context, query string, args []driver.Value) (driver.Rows, error) { if atomic.LoadInt32(&c.closed) != 0 { return nil, driver.ErrBadConn @@ -182,6 +218,12 @@ func (c *conn) query(ctx context.Context, query string, args []driver.Value) (dr } body, err := c.doRequest(ctx, req) if err != nil { + if _, ok := err.(*Error); !ok && err != driver.ErrBadConn { + killErr := c.killQuery(req, args) + if killErr != nil { + c.log("error from killQuery", killErr) + } + } return nil, err } @@ -248,26 +290,34 @@ func (c *conn) buildRequest(ctx context.Context, query string, params []driver.V } c.log("query: ", query) req, err := http.NewRequest(method, c.url.String(), strings.NewReader(query)) + if err != nil { + return nil, err + } // http.Transport ignores url.User argument, handle it here - if err == nil && c.user != nil { + if c.user != nil { p, _ := c.user.Password() req.SetBasicAuth(c.user.Username(), p) } + var queryID, quotaKey string if ctx != nil { - quotaKey, quotaOk := ctx.Value(QuotaKey).(string) - queryID, queryOk := ctx.Value(QueryID).(string) - if quotaOk || queryOk { - reqQuery := req.URL.Query() - if quotaOk { - reqQuery.Add(quotaKeyParamName, quotaKey) - } - if queryOk && len(queryID) > 0 { - reqQuery.Add(queryIDParamName, queryID) - } - req.URL.RawQuery = reqQuery.Encode() - } + quotaKey, _ = ctx.Value(QuotaKey).(string) + queryID, _ = ctx.Value(QueryID).(string) + } + + if c.killQueryOnErr && queryID == "" { + queryID = uuid.NewV4().String() } - return req, err + + reqQuery := req.URL.Query() + if quotaKey != "" { + reqQuery.Add(quotaKeyParamName, quotaKey) + } + if queryID != "" { + reqQuery.Add(queryIDParamName, queryID) + } + req.URL.RawQuery = reqQuery.Encode() + + return req, nil } func (c *conn) prepare(query string) (*stmt, error) { diff --git a/conn_test.go b/conn_test.go index fbd8948..c8b65e7 100644 --- a/conn_test.go +++ b/conn_test.go @@ -4,10 +4,12 @@ import ( "context" "database/sql" "database/sql/driver" + "fmt" "net/http" "testing" "time" + uuid "github.com/satori/go.uuid" "github.com/stretchr/testify/suite" ) @@ -73,6 +75,9 @@ func (s *connSuite) TestQuery() { // Tests on connections with enabled compression doTests(s.connWithCompression) + + // Tests on connection with enabled kill connection + doTests(s.connWithKillQuery) } func (s *connSuite) TestExec() { @@ -171,6 +176,25 @@ func (s *connSuite) TestServerError() { s.Contains(srvErr.Error(), "Code: 62, Message: Syntax error:") } +func (s *connSuite) TestServerKillQuery() { + queryID := uuid.NewV4().String() + ctx := context.WithValue(context.Background(), queryIDParamName, queryID) + _, err := s.connWithKillQuery.QueryContext(ctx, "SELECT sleep(2)") + s.Error(err) + s.Contains(err.Error(), "net/http: timeout awaiting response headers") + rows := s.connWithKillQuery.QueryRow(fmt.Sprintf("SELECT count(query_id) FROM system.processes where query_id='%s'", queryID)) + var amount int + err = rows.Scan(&amount) + s.NoError(err) + s.Equal(0, amount) + + _, err = s.connWithKillQuery.QueryContext(ctx, "SELECT sleep(0.5)") + s.NoError(err) + + _, err = s.conn.QueryContext(ctx, "SELECT sleep(2)") + s.NoError(err) +} + func (s *connSuite) TestBuildRequestReadonlyWithAuth() { cfg := NewConfig() cfg.User = "user" @@ -250,7 +274,7 @@ func (s *connSuite) TestBuildRequestWithQuotaKey() { }{ { "", - cn.url.String() + ""a_key=", + cn.url.String(), }, { "quota-key", @@ -295,7 +319,7 @@ func (s *connSuite) TestBuildRequestWithQueryIdAndQuotaKey() { { "", "", - cn.url.String() + ""a_key=", + cn.url.String(), }, { "quota-key", diff --git a/go.mod b/go.mod index 046726f..a924df6 100644 --- a/go.mod +++ b/go.mod @@ -1,3 +1,8 @@ module github.com/mailru/go-clickhouse -require github.com/stretchr/testify v1.3.0 +require ( + github.com/kr/pretty v0.1.0 // indirect + github.com/satori/go.uuid v1.2.0 + github.com/stretchr/testify v1.3.0 + gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 // indirect +) diff --git a/go.sum b/go.sum index 380091e..a832c19 100644 --- a/go.sum +++ b/go.sum @@ -1,8 +1,17 @@ github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= +github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= +github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= +github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= +github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/satori/go.uuid v1.2.0 h1:0uYX9dsZ2yD7q2RtLRtPSdGDWzjeM3TbMJP9utgA0ww= +github.com/satori/go.uuid v1.2.0/go.mod h1:dA0hQrYB0VpLJoorglMZABFdXlWrHn1NEOzdhQKdks0= github.com/stretchr/objx v0.1.0 h1:4G4v2dO3VZwixGIRoQ5Lfboy6nUhCyYzaqnIAPPhYs4= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.3.0 h1:TivCn/peBQ7UY8ooIcPgZFpTNSz0Q2U6UrFlUfqbe0Q= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo= +gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=