From 6dafbca4959ca02cc967e07cd0a9c386fde57a9d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Maciej=20Bi=C5=82as?= Date: Wed, 29 Aug 2018 15:54:12 +0200 Subject: [PATCH] Compression (gzip) support --- clickhouse_test.go | 12 +++++++++++- config.go | 38 +++++++++++++++++++++----------------- conn.go | 43 +++++++++++++++++++++++++++++-------------- conn_test.go | 34 +++++++++++++++++++++------------- 4 files changed, 82 insertions(+), 45 deletions(-) diff --git a/clickhouse_test.go b/clickhouse_test.go index 65796ca..19da900 100644 --- a/clickhouse_test.go +++ b/clickhouse_test.go @@ -45,7 +45,8 @@ type dbInit struct { type chSuite struct { suite.Suite - conn *sql.DB + conn *sql.DB + connWithCompression *sql.DB } func (s *chSuite) SetupSuite() { @@ -53,16 +54,25 @@ func (s *chSuite) SetupSuite() { if len(dsn) == 0 { dsn = "http://localhost:8123/default" } + conn, err := sql.Open("clickhouse", dsn) s.Require().NoError(err) s.Require().NoError(initialzer.Do(conn)) s.conn = conn + + connWithCompression, err := sql.Open("clickhouse", dsn+"?enable_http_compression=1") + s.Require().NoError(err) + s.connWithCompression = connWithCompression } func (s *chSuite) TearDownSuite() { s.conn.Close() _, err := s.conn.Query("SELECT 1") s.EqualError(err, "sql: database is closed") + + s.connWithCompression.Close() + _, err = s.connWithCompression.Query("SELECT 1") + s.EqualError(err, "sql: database is closed") } func (d *dbInit) Do(conn *sql.DB) error { diff --git a/config.go b/config.go index b3ca589..ffc888e 100644 --- a/config.go +++ b/config.go @@ -10,19 +10,20 @@ import ( // Config is a configuration parsed from a DSN string type Config struct { - User string - Password string - Scheme string - Host string - Database string - Timeout time.Duration - IdleTimeout time.Duration - ReadTimeout time.Duration - WriteTimeout time.Duration - Location *time.Location - Debug bool - UseDBLocation bool - Params map[string]string + User string + Password string + Scheme string + Host string + Database string + Timeout time.Duration + IdleTimeout time.Duration + ReadTimeout time.Duration + WriteTimeout time.Duration + Location *time.Location + Debug bool + UseDBLocation bool + GzipCompression bool + Params map[string]string } // NewConfig creates a new config with default values @@ -32,6 +33,7 @@ func NewConfig() *Config { Host: "localhost:8123", IdleTimeout: time.Hour, Location: time.UTC, + Params: make(map[string]string), } } @@ -55,6 +57,9 @@ func (cfg *Config) FormatDSN() string { if cfg.Location != time.UTC && cfg.Location != nil { query.Set("location", cfg.Location.String()) } + if cfg.GzipCompression { + query.Set("enable_http_compression", "1") + } if cfg.Debug { query.Set("debug", "1") } @@ -147,11 +152,10 @@ func parseDSNParams(cfg *Config, params map[string][]string) (err error) { cfg.Debug, err = strconv.ParseBool(v[0]) case "default_format", "query", "database": err = fmt.Errorf("unknown option '%s'", k) + case "enable_http_compression": + cfg.GzipCompression, err = strconv.ParseBool(v[0]) + cfg.Params[k] = v[0] default: - // lazy init - if cfg.Params == nil { - cfg.Params = make(map[string]string) - } cfg.Params[k] = v[0] } if err != nil { diff --git a/conn.go b/conn.go index 0717539..d74eb67 100644 --- a/conn.go +++ b/conn.go @@ -1,6 +1,7 @@ package clickhouse import ( + "compress/gzip" "context" "database/sql" "database/sql/driver" @@ -17,16 +18,17 @@ import ( // conn implements an interface sql.Conn type conn struct { - url *url.URL - user *url.Userinfo - location *time.Location - useDBLocation bool - transport *http.Transport - cancel context.CancelFunc - txCtx context.Context - stmts []*stmt - logger *log.Logger - closed int32 + url *url.URL + user *url.Userinfo + location *time.Location + useDBLocation bool + useGzipCompression bool + transport *http.Transport + cancel context.CancelFunc + txCtx context.Context + stmts []*stmt + logger *log.Logger + closed int32 } func newConn(cfg *Config) *conn { @@ -35,9 +37,10 @@ func newConn(cfg *Config) *conn { logger = log.New(os.Stderr, "clickhouse: ", log.LstdFlags) } c := &conn{ - url: cfg.url(map[string]string{"default_format": "TabSeparatedWithNamesAndTypes"}, false), - location: cfg.Location, - useDBLocation: cfg.UseDBLocation, + url: cfg.url(map[string]string{"default_format": "TabSeparatedWithNamesAndTypes"}, false), + location: cfg.Location, + useDBLocation: cfg.UseDBLocation, + useGzipCompression: cfg.GzipCompression, transport: &http.Transport{ DialContext: (&net.Dialer{ Timeout: cfg.Timeout, @@ -210,7 +213,15 @@ func (c *conn) doRequest(ctx context.Context, req *http.Request) (io.ReadCloser, return nil, err } - return resp.Body, nil + respBody := resp.Body + if resp.Header.Get("Content-Encoding") == "gzip" { + respBody, err = gzip.NewReader(respBody) + if err != nil { + return nil, err + } + } + + return respBody, nil } func (c *conn) buildRequest(query string, params []driver.Value, readonly bool) (*http.Request, error) { @@ -235,6 +246,10 @@ func (c *conn) buildRequest(query string, params []driver.Value, readonly bool) p, _ := c.user.Password() req.SetBasicAuth(c.user.Username(), p) } + if c.useGzipCompression { + req.Header.Set("Accept-Encoding", "gzip") + } + return req, err } diff --git a/conn_test.go b/conn_test.go index 6a747bc..9146588 100644 --- a/conn_test.go +++ b/conn_test.go @@ -44,22 +44,30 @@ func (s *connSuite) TestQuery() { }, } - for _, tc := range testCases { - rows, err := s.conn.Query(tc.query, tc.args...) - if !s.NoError(err) { - continue - } - if len(tc.expected) == 0 { - s.False(rows.Next()) - s.NoError(rows.Err()) - } else { - v, err := scanValues(rows, tc.expected[0]) - if s.NoError(err) { - s.Equal(tc.expected, v) + doTests := func(conn *sql.DB) { + for _, tc := range testCases { + rows, err := conn.Query(tc.query, tc.args...) + if !s.NoError(err) { + continue } + if len(tc.expected) == 0 { + s.False(rows.Next()) + s.NoError(rows.Err()) + } else { + v, err := scanValues(rows, tc.expected[0]) + if s.NoError(err) { + s.Equal(tc.expected, v) + } + } + s.NoError(rows.Close()) } - s.NoError(rows.Close()) } + + // Tests on regular connection + doTests(s.conn) + + // Tests on connections with enabled compression + doTests(s.connWithCompression) } func (s *connSuite) TestExec() {