/
lightpatch.go
182 lines (152 loc) · 3.65 KB
/
lightpatch.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
// Package lightpatch generates and applies patch files. A description of the patch file
// format is included in the README.
package lightpatch
import (
"bufio"
"bytes"
"encoding/binary"
"errors"
"fmt"
"hash/crc32"
"io"
"io/ioutil"
"time"
)
// Op codes that begin each record of a patch. OpCopy, OpInsert and OpDelete are
// followed by a uvarint length; OpCRC is followed by a 4-byte checksum.
const (
	// OpCopy copies the next length bytes of the original input to the output.
	OpCopy byte = 'C'
	// OpInsert writes the length bytes that follow it in the patch to the output.
	OpInsert byte = 'I'
	// OpDelete skips the next length bytes of the original input.
	OpDelete byte = 'D'
	// OpCRC is followed by the big-endian CRC-32 (IEEE) of the complete output and
	// must be the last record of a patch.
	OpCRC byte = 'K'

	// DefaultTimeout is the diff time budget that MakePatch passes to
	// MakePatchTimeout.
	DefaultTimeout = 5 * time.Second
)
// Sentinel errors returned by ApplyPatch; compare with errors.Is.
var (
	// ErrCRC is returned by ApplyPatch when the checksum stored in the patch does
	// not match the CRC-32 of the output it produced.
	ErrCRC = errors.New("CRC mismatch")
	// ErrExtraData is returned by ApplyPatch when the patch contains further bytes
	// after its CRC record.
	ErrExtraData = errors.New("unexpected data following CRC")
)
// MakePatch generates a diff that changes before into after and writes the encoded
// patch to output. It is equivalent to MakePatchTimeout with DefaultTimeout as the
// time budget.
func MakePatch(before, after io.Reader, output io.Writer) error {
	return MakePatchTimeout(before, after, output, DefaultTimeout)
}
// MakePatchTimeout generates a diff that changes before into after and writes the
// encoded patch to patch. timeout is the max time to try to make an efficient patch.
// The operation still succeeds if timeout is reached, with perhaps a less compact
// patch. If timeout is 0 the function will take as long as it needs to complete.
//
// Both inputs are read fully into memory. The patch is a sequence of records: an op
// byte and a uvarint length, with the bytes themselves included only for OpInsert.
// The last record is OpCRC followed by the big-endian CRC-32 (IEEE) of after.
func MakePatchTimeout(before, after io.Reader, patch io.Writer, timeout time.Duration) error {
	beforeBytes, err := ioutil.ReadAll(before)
	if err != nil {
		return err
	}
	afterBytes, err := ioutil.ReadAll(after)
	if err != nil {
		return err
	}

	diffs := diffMain(beforeBytes, afterBytes, timeout)

	// If the inputs are very different, the encoded diffs can be larger than simply
	// inserting all of after. Fall back to that "naive" patch when it is shorter.
	naiveDiff := []diff{{Type: OpInsert, Text: afterBytes}}
	if encodedLen(naiveDiff) < encodedLen(diffs) {
		diffs = naiveDiff
	}

	// Each record is emitted as several tiny writes (op byte, length, payload).
	// Buffer them so an unbuffered destination such as an *os.File is not hit with
	// one write call per field. bufio.Writer keeps the first write error and
	// returns it from every later Write and from Flush.
	w := bufio.NewWriter(patch)
	var lenBuf [binary.MaxVarintLen64]byte
	for _, d := range diffs {
		if err := w.WriteByte(d.Type); err != nil {
			return err
		}
		n := binary.PutUvarint(lenBuf[:], uint64(len(d.Text)))
		if _, err := w.Write(lenBuf[:n]); err != nil {
			return err
		}
		// Copy and delete records carry only a length; their bytes come from before.
		if d.Type == OpInsert {
			if _, err := w.Write(d.Text); err != nil {
				return err
			}
		}
	}

	// Trailer: OpCRC followed by the CRC-32 (IEEE) of after, appended big-endian by
	// Sum. ApplyPatch checks it against the bytes it produces.
	crc := crc32.NewIEEE()
	crc.Write(afterBytes) // hash.Hash.Write never returns an error
	if _, err := w.Write(crc.Sum([]byte{OpCRC})); err != nil {
		return err
	}
	return w.Flush()
}
// ApplyPatch reads before, applies the edits from patch, and writes the output to
// after.
//
// patch is read as a sequence of records, each starting with an op byte:
//   - OpCopy, OpInsert and OpDelete are followed by a uvarint length. OpCopy copies
//     that many bytes from before to after, OpInsert copies that many bytes from
//     patch itself, and OpDelete skips that many bytes of before.
//   - OpCRC is followed by the 4-byte CRC-32 (IEEE) of everything written to after.
//     A mismatch yields ErrCRC; any record after it yields ErrExtraData.
//
// If patch or before ends partway through a record, the error is
// io.ErrUnexpectedEOF. An unknown op byte or an oversized length yields a
// descriptive error. Output already written to after is not rolled back on error.
//
// NOTE(review): a patch with no OpCRC record is accepted without any integrity
// check, and bytes of before left over after the last record are ignored. Confirm
// against the README that both are intended.
func ApplyPatch(before, patch io.Reader, after io.Writer) error {
	// Largest int as a uint64. Bigger lengths would wrap negative in the int/int64
	// conversions below, and io.CopyN treats a negative count as "copy nothing" and
	// reports success, so such records are rejected up front.
	const maxLen = uint64(^uint(0) >> 1)

	// Inside a record, running out of input means it was truncated. Report that as
	// io.ErrUnexpectedEOF so callers cannot mistake it for a normal end of stream.
	truncated := func(err error) error {
		if err == io.EOF {
			return io.ErrUnexpectedEOF
		}
		return err
	}

	crc := crc32.NewIEEE()
	after = io.MultiWriter(after, crc) // checksum exactly the bytes we emit
	beforeBR := bufio.NewReader(before)
	patchBR := bufio.NewReader(patch)
	crcRead := false

	for {
		op, err := patchBR.ReadByte()
		if err == io.EOF {
			return nil // clean end of patch, between records
		}
		if err != nil {
			return err
		}
		if crcRead {
			return ErrExtraData
		}

		if op == OpCRC {
			patchCRC := make([]byte, crc32.Size)
			if _, err := io.ReadFull(patchBR, patchCRC); err != nil {
				return truncated(err)
			}
			if !bytes.Equal(patchCRC, crc.Sum(nil)) {
				return ErrCRC
			}
			crcRead = true
			continue
		}
		// Reject unknown ops before reading a length that may not exist.
		if op != OpCopy && op != OpInsert && op != OpDelete {
			return fmt.Errorf("unexpected operation byte: %x", op)
		}

		tl, err := binary.ReadUvarint(patchBR)
		if err != nil {
			return truncated(err)
		}
		if tl > maxLen {
			return fmt.Errorf("operation length %d exceeds maximum %d", tl, maxLen)
		}

		switch op {
		case OpCopy:
			_, err = io.CopyN(after, beforeBR, int64(tl))
		case OpInsert:
			_, err = io.CopyN(after, patchBR, int64(tl))
		case OpDelete:
			_, err = beforeBR.Discard(int(tl))
		}
		if err != nil {
			return truncated(err)
		}
	}
}
// encodedLen reports how many bytes diffs occupy once serialized in the patch
// format: one op byte, the uvarint-encoded text length, and the text itself for
// OpInsert records only. The trailing CRC record is not counted.
func encodedLen(diffs []diff) int {
	size := 0
	for i := range diffs {
		d := &diffs[i]
		size++ // op byte
		// Length prefix, sized the way binary.PutUvarint writes it: one byte per
		// 7-bit group, and never fewer than one byte.
		for rest := uint(len(d.Text)); ; rest >>= 7 {
			size++
			if rest < 0x80 {
				break
			}
		}
		if d.Type == OpInsert {
			size += len(d.Text) // only inserts carry their payload inline
		}
	}
	return size
}