Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions AUTHORS
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ Kevin Malachowski <kevin at chowski.com>
Kieron Woodhouse <kieron.woodhouse at infosum.com>
Lance Tian <lance6716 at gmail.com>
Lennart Rudolph <lrudolph at hmc.edu>
Lefteris Zafiris <zaf at fastmail.com>
Leonardo YongUk Kim <dalinaum at gmail.com>
Linh Tran Tuan <linhduonggnu at gmail.com>
Lion Yang <lion at aosc.xyz>
Expand Down
8 changes: 8 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,14 @@ Default: false

Toggles zlib compression. false by default.

##### `compressLevel`

```
Type: decimal number
Valid Values: 1-9
Default: 2
```

##### `interpolateParams`

```
Expand Down
63 changes: 23 additions & 40 deletions compress.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,73 +10,56 @@ package mysql

import (
"bytes"
"compress/zlib"
"fmt"
"io"
"sync"
)

var (
zrPool *sync.Pool // Do not use directly. Use zDecompress() instead.
zwPool *sync.Pool // Do not use directly. Use zCompress() instead.
"github.com/klauspost/compress/zlib"
)

func init() {
zrPool = &sync.Pool{
New: func() any { return nil },
}
zwPool = &sync.Pool{
New: func() any {
zw, err := zlib.NewWriterLevel(new(bytes.Buffer), 2)
if err != nil {
panic(err) // compress/zlib return non-nil error only if level is invalid
}
return zw
},
}
}

func zDecompress(src []byte, dst *bytes.Buffer) (int, error) {
func (c *compIO) zDecompress(src []byte) (int, error) {
br := bytes.NewReader(src)
var zr io.ReadCloser
var err error

if a := zrPool.Get(); a == nil {
if zr, err = zlib.NewReader(br); err != nil {
if c.zr == nil {
c.zr, err = zlib.NewReader(br)
if err != nil {
return 0, err
}
} else {
zr = a.(io.ReadCloser)
if err := zr.(zlib.Resetter).Reset(br, nil); err != nil {
err = c.zr.(zlib.Resetter).Reset(br, nil)
if err != nil {
return 0, err
}
}

n, _ := dst.ReadFrom(zr) // ignore err because zr.Close() will return it again.
err = zr.Close() // zr.Close() may return chuecksum error.
zrPool.Put(zr)
n, _ := c.buff.ReadFrom(c.zr) // ignore err because zr.Close() will return it again.
err = c.zr.Close() // zr.Close() may return chuecksum error.
return int(n), err
}

func zCompress(src []byte, dst io.Writer) error {
zw := zwPool.Get().(*zlib.Writer)
zw.Reset(dst)
if _, err := zw.Write(src); err != nil {
func (c *compIO) zCompress(src []byte) error {
c.zw.Reset(&c.buff)
if _, err := c.zw.Write(src); err != nil {
return err
}
err := zw.Close()
zwPool.Put(zw)
err := c.zw.Close()
return err
}

type compIO struct {
mc *mysqlConn
buff bytes.Buffer
zw *zlib.Writer
zr io.ReadCloser
}

func newCompIO(mc *mysqlConn) *compIO {
w, err := zlib.NewWriterLevel(new(bytes.Buffer), mc.cfg.compressLevel)
if err != nil {
panic(err) // compress/zlib return non-nil error only if level is invalid
}
return &compIO{
mc: mc,
zw: w,
zr: nil,
}
}

Expand Down Expand Up @@ -133,7 +116,7 @@ func (c *compIO) readCompressedPacket() error {

// use existing capacity in bytesBuf if possible
c.buff.Grow(uncompressedLength)
nread, err := zDecompress(comprData, &c.buff)
nread, err := c.zDecompress(comprData)
if err != nil {
return err
}
Expand Down Expand Up @@ -167,7 +150,7 @@ func (c *compIO) writePackets(packets []byte) (int, error) {
buf.Write(payload)
uncompressedLen = 0
} else {
err := zCompress(payload, buf)
err := c.zCompress(payload)
if debug && err != nil {
fmt.Printf("zCompress error: %v", err)
}
Expand Down
23 changes: 20 additions & 3 deletions dsn.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ import (
"time"
)

const defaultCompressionLevel = 2

var (
errInvalidDSNUnescaped = errors.New("invalid DSN: did you forget to escape a param value?")
errInvalidDSNAddr = errors.New("invalid DSN: network address not terminated (missing closing brace)")
Expand Down Expand Up @@ -75,7 +77,8 @@ type Config struct {
// unexported fields. new options should be come here.
// boolean first. alphabetical order.

compress bool // Enable zlib compression
compress bool // Enable zlib compression
compressLevel int // Compression level

beforeConnect func(context.Context, *Config) error // Invoked before a connection is established
pubKey *rsa.PublicKey // Server public key
Expand All @@ -95,6 +98,7 @@ func NewConfig() *Config {
Logger: defaultLogger,
AllowNativePasswords: true,
CheckConnLiveness: true,
compressLevel: defaultCompressionLevel,
}
return cfg
}
Expand Down Expand Up @@ -127,10 +131,14 @@ func BeforeConnect(fn func(context.Context, *Config) error) Option {
}
}

// EnableCompress sets the compression mode.
func EnableCompression(yes bool) Option {
// EnableCompress sets the compression mode and level.
func EnableCompression(yes bool, level int) Option {
return func(cfg *Config) error {
cfg.compress = yes
cfg.compressLevel = defaultCompressionLevel
if level > 0 {
cfg.compressLevel = level
}
return nil
}
}
Expand Down Expand Up @@ -563,6 +571,15 @@ func parseDSNParams(cfg *Config, params string) (err error) {
if !isBool {
return errors.New("invalid bool value: " + value)
}
// Compression level
case "compressLevel":
cfg.compressLevel, err = strconv.Atoi(value)
if err != nil {
return
}
if cfg.compressLevel < 0 || cfg.compressLevel > 9 {
return errors.New("invalid compress level: " + value)
}

// Enable client side placeholder substitution
case "interpolateParams":
Expand Down
6 changes: 6 additions & 0 deletions dsn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,12 @@ var testDSNs = []struct {
}, {
"tcp(127.0.0.1)/dbname",
&Config{Net: "tcp", Addr: "127.0.0.1:3306", DBName: "dbname", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, Logger: defaultLogger, AllowNativePasswords: true, CheckConnLiveness: true},
}, {
"tcp(127.0.0.1)/dbname?compress=true,compressLevel=6",
&Config{Net: "tcp", Addr: "127.0.0.1:3306", DBName: "dbname", compress: true, compressLevel: 6, Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, Logger: defaultLogger, AllowNativePasswords: true, CheckConnLiveness: true},
}, {
"tcp(127.0.0.1)/dbname?compress=true",
&Config{Net: "tcp", Addr: "127.0.0.1:3306", DBName: "dbname", compress: true, compressLevel: defaultCompressionLevel, Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, Logger: defaultLogger, AllowNativePasswords: true, CheckConnLiveness: true},
}, {
"tcp(de:ad:be:ef::ca:fe)/dbname",
&Config{Net: "tcp", Addr: "[de:ad:be:ef::ca:fe]:3306", DBName: "dbname", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, Logger: defaultLogger, AllowNativePasswords: true, CheckConnLiveness: true},
Expand Down
2 changes: 2 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,5 @@ module github.com/go-sql-driver/mysql
go 1.22.0

require filippo.io/edwards25519 v1.1.0

require github.com/klauspost/compress v1.18.0
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
@@ -1,2 +1,4 @@
filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA=
filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4=
github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo=
github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ=