/
compression.go
157 lines (128 loc) · 3.62 KB
/
compression.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package compressor
import (
"bytes"
"compress/zlib"
"io"
"github.com/golang/snappy"
"github.com/mongodb/mongo-go-driver/core/wiremessage"
)
// Compressor is implemented by every algorithm that can compress and decompress wire
// messages. It is used when sending messages to, and receiving messages from, the server.
type Compressor interface {
	// Name returns the algorithm's string name (for example "snappy" or "zlib").
	Name() string

	// CompressorID returns the wiremessage.CompressorID that identifies the algorithm.
	CompressorID() wiremessage.CompressorID

	// CompressBytes compresses src and returns the compressed bytes. Implementations may
	// use dest as scratch space for the output.
	CompressBytes(src, dest []byte) ([]byte, error)

	// UncompressBytes decompresses src and returns the decompressed bytes. Implementations
	// may write the output into dest.
	UncompressBytes(src, dest []byte) ([]byte, error)
}
// writer is a minimal io.Writer that accumulates everything written to it in buf. Bytes
// are written into buf's existing backing array while it has spare capacity, so a
// caller-supplied slice can be reused as the output destination.
type writer struct {
	buf []byte
}

// Write appends p to w.buf. When the remaining capacity is too small, the backing array is
// replaced with one of capacity 2*cap(w.buf)+len(p). It always reports len(p) bytes
// written and a nil error.
func (w *writer) Write(p []byte) (n int, err error) {
	used := len(w.buf)
	need := used + len(p)

	if need > cap(w.buf) {
		grown := make([]byte, used, 2*cap(w.buf)+len(p))
		copy(grown, w.buf)
		w.buf = grown
	}

	w.buf = w.buf[:need]
	copy(w.buf[used:], p)
	return len(p), nil
}
// SnappyCompressor compresses and decompresses data with the snappy algorithm. It has no
// fields, so its zero value is ready to use.
type SnappyCompressor struct{}
// ZlibCompressor compresses and decompresses data with the zlib algorithm. Build one with
// CreateZlib: the zero value has no zlib writer, so compressing with it would dereference
// a nil *zlib.Writer.
type ZlibCompressor struct {
	// level is the compression level the zlib writer was created with.
	level int

	// zlibWriter is shared by every CompressBytes call and is Reset onto a new destination
	// each time. No locking guards it, so one ZlibCompressor must not compress
	// concurrently.
	zlibWriter *zlib.Writer
}
// CompressBytes snappy-encodes src and returns the encoded block. dest is truncated to
// zero length before it is passed to snappy.Encode, so its previous contents are ignored.
// The returned error is always nil.
//
// NOTE(review): snappy.Encode appears to decide whether to reuse dst based on len(dst),
// not cap(dst). If so, truncating dest means a fresh allocation on every call. Confirm
// before relying on dest reuse.
func (s *SnappyCompressor) CompressBytes(src, dest []byte) ([]byte, error) {
	return snappy.Encode(dest[:0], src), nil
}
// UncompressBytes decodes the snappy block in src. dest is offered to snappy.Decode as
// the output buffer. The decoded bytes and any decode error are returned exactly as
// snappy.Decode reports them.
func (s *SnappyCompressor) UncompressBytes(src, dest []byte) ([]byte, error) {
	return snappy.Decode(dest, src)
}
// CompressorID reports wiremessage.CompressorSnappy, the identifier of snappy compression.
func (s *SnappyCompressor) CompressorID() wiremessage.CompressorID {
	var id wiremessage.CompressorID = wiremessage.CompressorSnappy
	return id
}
// Name reports "snappy", the string name of this compressor.
func (s *SnappyCompressor) Name() string {
	const name = "snappy"
	return name
}
// CompressBytes uses zlib to compress src and returns the complete zlib stream.
//
// dest is scratch space for the output. Its contents are discarded, and its backing
// array is written into (and becomes the result) while it has enough capacity. Otherwise
// a larger array is allocated. Callers must use the returned slice, not dest: dest itself
// is never resliced to the compressed length.
//
// The compressor's single zlib.Writer is Reset on every call without locking, so a
// ZlibCompressor must not be used for concurrent CompressBytes calls.
func (z *ZlibCompressor) CompressBytes(src, dest []byte) ([]byte, error) {
	// Keep a handle on the destination writer. The compressed bytes accumulate in out.buf,
	// not in the caller's dest header, which would otherwise stay zero-length.
	out := &writer{buf: dest[:0]}
	z.zlibWriter.Reset(out)

	if _, err := z.zlibWriter.Write(src); err != nil {
		// Best-effort close so the writer is finished. The Write error is the one reported.
		_ = z.zlibWriter.Close()
		return out.buf, err
	}

	// Close flushes buffered data and writes the stream trailer. Without it the output is
	// incomplete.
	if err := z.zlibWriter.Close(); err != nil {
		return out.buf, err
	}
	return out.buf, nil
}
// UncompressBytes uses zlib to decompress src into dest and returns dest.
//
// dest must already be exactly as long as the decompressed data. UncompressBytes fills
// all of dest, then confirms that the zlib stream ends at that point. Reading to the end
// of the stream is what makes compress/zlib verify the trailing checksum, so corrupted
// input is reported instead of silently accepted.
//
// Errors:
//   - the io.ReadFull error (io.EOF or io.ErrUnexpectedEOF) when the stream holds fewer
//     bytes than len(dest);
//   - io.ErrShortBuffer when the stream holds more bytes than len(dest);
//   - the zlib/flate error (e.g. zlib.ErrChecksum, zlib.ErrHeader) when src is not a
//     valid zlib stream.
func (z *ZlibCompressor) UncompressBytes(src, dest []byte) ([]byte, error) {
	zlibReader, err := zlib.NewReader(bytes.NewReader(src))
	if err != nil {
		return dest, err
	}
	// Close's error is ignored: any stream error it could report has already been
	// returned by the reads below.
	defer func() {
		_ = zlibReader.Close()
	}()

	if _, err = io.ReadFull(zlibReader, dest); err != nil {
		return dest, err
	}

	// dest is full, so the stream must end here. io.ReadFull discards an error that
	// arrives with the bytes that fill the buffer. A one-byte probe forces the reader to
	// the end of the stream and surfaces the checksum result: io.EOF means it matched.
	var probe [1]byte
	_, err = io.ReadFull(zlibReader, probe[:])
	switch err {
	case io.EOF:
		return dest, nil
	case nil:
		// More decompressed data than dest can hold.
		return dest, io.ErrShortBuffer
	default:
		return dest, err
	}
}
// CompressorID reports wiremessage.CompressorZLib, the identifier of zlib compression.
func (z *ZlibCompressor) CompressorID() wiremessage.CompressorID {
	var id wiremessage.CompressorID = wiremessage.CompressorZLib
	return id
}
// Name reports "zlib", the string name of this compressor.
func (z *ZlibCompressor) Name() string {
	const name = "zlib"
	return name
}
// CreateSnappy returns a new snappy Compressor. SnappyCompressor has no configuration, so
// creation cannot fail.
func CreateSnappy() Compressor {
	c := new(SnappyCompressor)
	return c
}
// CreateZlib creates a zlib Compressor that compresses at the given level.
//
// Any negative level is replaced by wiremessage.DefaultZlibLevel. Other levels are passed
// unchanged to zlib.NewWriterLevel, which rejects levels above zlib.BestCompression with
// an error. On error the returned Compressor is nil.
func CreateZlib(level int) (Compressor, error) {
	if level < 0 {
		level = wiremessage.DefaultZlibLevel
	}

	// The writer's initial destination is a throwaway, because CompressBytes resets the
	// writer onto the caller's buffer before each use. Creating the writer here validates
	// level up front.
	var compressBuf bytes.Buffer
	zlibWriter, err := zlib.NewWriterLevel(&compressBuf, level)
	if err != nil {
		// Returning a ZlibCompressor with a nil writer would panic on first use, so hand
		// back nothing.
		return nil, err
	}
	return &ZlibCompressor{
		level:      level,
		zlibWriter: zlibWriter,
	}, nil
}